1use anyhow::{Context, Result};
2use openssh::{KnownHosts, Session, Stdio};
3use std::collections::HashMap;
4
5use crate::config::Server;
6
/// A collection of open SSH sessions, keyed by server name.
///
/// Built with [`SshPool::connect`] (every configured server) or
/// [`SshPool::connect_one`] (a single server); commands are then dispatched to a
/// session by looking up its name.
pub struct SshPool {
    /// Live sessions, keyed by the server name supplied when connecting.
    sessions: HashMap<String, Session>,
}
10
11impl SshPool {
12 pub async fn connect(servers: &HashMap<String, Server>) -> Result<Self> {
13 let mut sessions = HashMap::new();
14 for (name, server) in servers {
15 let ssh_target = server.ip.as_deref().unwrap_or(&server.host);
16 let dest = format!("ssh://{}@{}", server.user, ssh_target);
17 let session = Session::connect(&dest, KnownHosts::Strict)
18 .await
19 .with_context(|| format!("Failed to connect to {name}"))?;
20 sessions.insert(name.clone(), session);
21 }
22 Ok(Self { sessions })
23 }
24
25 pub async fn connect_one(name: &str, server: &Server) -> Result<Self> {
26 let ssh_target = server.ip.as_deref().unwrap_or(&server.host);
27 let dest = format!("ssh://{}@{}", server.user, ssh_target);
28 let session = Session::connect(&dest, KnownHosts::Strict)
29 .await
30 .with_context(|| format!("Failed to connect to {name}"))?;
31 let mut sessions = HashMap::new();
32 sessions.insert(name.to_string(), session);
33 Ok(Self { sessions })
34 }
35
36 pub fn get(&self, server: &str) -> Result<&Session> {
37 self.sessions
38 .get(server)
39 .with_context(|| format!("No connection to server '{server}'"))
40 }
41
42 pub async fn exec(&self, server: &str, cmd: &str) -> Result<String> {
43 let session = self.get(server)?;
44 let output = session
45 .command("sh")
46 .arg("-c")
47 .arg(cmd)
48 .stdout(Stdio::piped())
49 .stderr(Stdio::piped())
50 .output()
51 .await
52 .with_context(|| format!("Failed to run command on {server}"))?;
53
54 if !output.status.success() {
55 anyhow::bail!(
56 "Command failed on {} (exit {}): {}\nstderr: {}",
57 server,
58 output.status,
59 cmd,
60 String::from_utf8_lossy(&output.stderr).trim()
61 );
62 }
63 Ok(String::from_utf8_lossy(&output.stdout).to_string())
64 }
65
66 pub async fn exec_streaming(
67 &self,
68 server: &str,
69 cmd: &str,
70 ) -> Result<openssh::Child<&Session>> {
71 let session = self.get(server)?;
72 let child = session
73 .command("sh")
74 .arg("-c")
75 .arg(cmd)
76 .stdout(Stdio::piped())
77 .stderr(Stdio::piped())
78 .spawn()
79 .await
80 .with_context(|| format!("Failed to run command on {server}"))?;
81 Ok(child)
82 }
83
84 pub async fn upload_file(&self, server: &str, remote_path: &str, content: &str) -> Result<()> {
85 let session = self.get(server)?;
86 let escaped = content.replace('\'', "'\\''");
87 let cmd = format!("cat > {remote_path} <<'FLOW_EOF'\n{escaped}\nFLOW_EOF");
88 let output = session
89 .command("sh")
90 .arg("-c")
91 .arg(&cmd)
92 .output()
93 .await
94 .with_context(|| format!("Failed to upload to {server}:{remote_path}"))?;
95
96 if !output.status.success() {
97 anyhow::bail!(
98 "Failed to write {}:{}: {}",
99 server,
100 remote_path,
101 String::from_utf8_lossy(&output.stderr)
102 );
103 }
104 Ok(())
105 }
106
107 pub async fn close(self) -> Result<()> {
108 for (_, session) in self.sessions {
109 session.close().await?;
110 }
111 Ok(())
112 }
113}