Skip to main content

iron/
ssh.rs

1use anyhow::{Context, Result};
2use openssh::{KnownHosts, Session, Stdio};
3use std::collections::HashMap;
4
5use crate::config::Server;
6
7pub struct SshPool {
8    sessions: HashMap<String, Session>,
9}
10
11impl SshPool {
12    pub async fn connect(servers: &HashMap<String, Server>) -> Result<Self> {
13        let mut sessions = HashMap::new();
14        for (name, server) in servers {
15            let ssh_target = server.ip.as_deref().unwrap_or(&server.host);
16            let dest = format!("ssh://{}@{}", server.user, ssh_target);
17            let session = Session::connect(&dest, KnownHosts::Strict)
18                .await
19                .with_context(|| format!("Failed to connect to {name}"))?;
20            sessions.insert(name.clone(), session);
21        }
22        Ok(Self { sessions })
23    }
24
25    pub async fn connect_one(name: &str, server: &Server) -> Result<Self> {
26        let ssh_target = server.ip.as_deref().unwrap_or(&server.host);
27        let dest = format!("ssh://{}@{}", server.user, ssh_target);
28        let session = Session::connect(&dest, KnownHosts::Strict)
29            .await
30            .with_context(|| format!("Failed to connect to {name}"))?;
31        let mut sessions = HashMap::new();
32        sessions.insert(name.to_string(), session);
33        Ok(Self { sessions })
34    }
35
36    pub fn get(&self, server: &str) -> Result<&Session> {
37        self.sessions
38            .get(server)
39            .with_context(|| format!("No connection to server '{server}'"))
40    }
41
42    pub async fn exec(&self, server: &str, cmd: &str) -> Result<String> {
43        let session = self.get(server)?;
44        let output = session
45            .command("sh")
46            .arg("-c")
47            .arg(cmd)
48            .stdout(Stdio::piped())
49            .stderr(Stdio::piped())
50            .output()
51            .await
52            .with_context(|| format!("Failed to run command on {server}"))?;
53
54        if !output.status.success() {
55            anyhow::bail!(
56                "Command failed on {} (exit {}): {}\nstderr: {}",
57                server,
58                output.status,
59                cmd,
60                String::from_utf8_lossy(&output.stderr).trim()
61            );
62        }
63        Ok(String::from_utf8_lossy(&output.stdout).to_string())
64    }
65
66    pub async fn exec_streaming(
67        &self,
68        server: &str,
69        cmd: &str,
70    ) -> Result<openssh::Child<&Session>> {
71        let session = self.get(server)?;
72        let child = session
73            .command("sh")
74            .arg("-c")
75            .arg(cmd)
76            .stdout(Stdio::piped())
77            .stderr(Stdio::piped())
78            .spawn()
79            .await
80            .with_context(|| format!("Failed to run command on {server}"))?;
81        Ok(child)
82    }
83
84    pub async fn upload_file(&self, server: &str, remote_path: &str, content: &str) -> Result<()> {
85        let session = self.get(server)?;
86        let escaped = content.replace('\'', "'\\''");
87        let cmd = format!("cat > {remote_path} <<'FLOW_EOF'\n{escaped}\nFLOW_EOF");
88        let output = session
89            .command("sh")
90            .arg("-c")
91            .arg(&cmd)
92            .output()
93            .await
94            .with_context(|| format!("Failed to upload to {server}:{remote_path}"))?;
95
96        if !output.status.success() {
97            anyhow::bail!(
98                "Failed to write {}:{}: {}",
99                server,
100                remote_path,
101                String::from_utf8_lossy(&output.stderr)
102            );
103        }
104        Ok(())
105    }
106
107    pub async fn close(self) -> Result<()> {
108        for (_, session) in self.sessions {
109            session.close().await?;
110        }
111        Ok(())
112    }
113}