fresh/services/remote/
connection.rs1use crate::services::remote::channel::AgentChannel;
6use crate::services::remote::protocol::AgentResponse;
7use crate::services::remote::AGENT_SOURCE;
8use std::path::PathBuf;
9use std::process::Stdio;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12
/// Errors produced while starting or talking to an agent over SSH.
#[derive(Debug, thiserror::Error)]
pub enum SshError {
    /// An I/O error, converted automatically from [`std::io::Error`].
    ///
    /// Because of the `#[from]` conversion, any `std::io::Error` propagated with `?`
    /// in a function returning `SshError` ends up in this variant, not only errors
    /// from the spawn call itself.
    #[error("Failed to spawn SSH process: {0}")]
    SpawnFailed(#[from] std::io::Error),

    /// The agent could not be started or its handshake was unusable; the payload is a
    /// human-readable reason.
    #[error("Agent failed to start: {0}")]
    AgentStartFailed(String),

    /// The agent reported a protocol version (`got`) other than the one this side
    /// requires (`expected`).
    #[error("Protocol version mismatch: expected {expected}, got {got}")]
    VersionMismatch { expected: u32, got: u32 },

    /// The connection was closed.
    #[error("Connection closed")]
    ConnectionClosed,

    /// Authentication failed.
    #[error("Authentication failed")]
    AuthenticationFailed,
}
31
/// Parameters identifying an SSH destination.
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Remote login name.
    pub user: String,
    /// Remote host name or address.
    pub host: String,
    /// TCP port, if one was given explicitly.
    pub port: Option<u16>,
    /// Identity (private key) file, if one was given explicitly.
    pub identity_file: Option<PathBuf>,
}

impl ConnectionParams {
    /// Parses a `user@host` or `user@host:port` string.
    ///
    /// The text after the last `:` is taken as the port only when it parses as a `u16`;
    /// otherwise the whole input is split as `user@host` and `port` is `None`.
    /// `identity_file` is always `None` in the returned value.
    ///
    /// Returns `None` when the input has no `@`, when the user or host part is empty,
    /// or when the user part starts with `-` (such a value would be read as a
    /// command-line option if the `user@host` string were handed to `ssh`).
    pub fn parse(s: &str) -> Option<Self> {
        // Peel off a trailing `:port` only if the suffix is really a port number.
        let (user_host, port) = s
            .rsplit_once(':')
            .and_then(|(rest, suffix)| suffix.parse::<u16>().ok().map(|p| (rest, Some(p))))
            .unwrap_or((s, None));

        let (user, host) = user_host.split_once('@')?;
        if user.is_empty() || host.is_empty() || user.starts_with('-') {
            return None;
        }

        Some(Self {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        })
    }
}

/// Formats the parameters as `user@host` or, when a port is set, `user@host:port`.
///
/// Implementing `Display` also provides `to_string()` through the standard
/// `ToString` blanket implementation, so existing `params.to_string()` calls keep
/// working.
impl std::fmt::Display for ConnectionParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}@{}", self.user, self.host)?;
        if let Some(port) = self.port {
            write!(f, ":{}", port)?;
        }
        Ok(())
    }
}
76
/// A spawned `ssh` child process together with the agent channel built on its
/// stdin/stdout.
pub struct SshConnection {
    /// Handle to the `ssh` child process; a kill is requested when the connection
    /// is dropped.
    process: Child,
    /// Shared channel to the agent.
    channel: std::sync::Arc<AgentChannel>,
    /// Parameters this connection was created with.
    params: ConnectionParams,
}
86
87impl SshConnection {
88 pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
90 let mut cmd = Command::new("ssh");
91
92 cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
94 if let Some(port) = params.port {
98 cmd.arg("-p").arg(port.to_string());
99 }
100
101 if let Some(ref identity) = params.identity_file {
102 cmd.arg("-i").arg(identity);
103 }
104
105 cmd.arg(format!("{}@{}", params.user, params.host));
106
107 let agent_len = AGENT_SOURCE.len();
116 let bootstrap = format!(
117 "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
118 agent_len
119 );
120 cmd.arg(bootstrap);
121
122 cmd.stdin(Stdio::piped());
123 cmd.stdout(Stdio::piped());
124 cmd.stderr(Stdio::inherit());
126
127 let mut child = cmd.spawn()?;
128
129 let mut stdin = child
131 .stdin
132 .take()
133 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
134 let stdout = child
135 .stdout
136 .take()
137 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
138 stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
142 stdin.flush().await?;
143
144 let mut reader = BufReader::new(stdout);
146
147 let mut ready_line = String::new();
151 match reader.read_line(&mut ready_line).await {
152 Ok(0) => {
153 return Err(SshError::AgentStartFailed(
154 "connection closed (check terminal for SSH errors)".to_string(),
155 ));
156 }
157 Ok(_) => {}
158 Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
159 }
160
161 let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
162 SshError::AgentStartFailed(format!(
163 "invalid ready message '{}': {}",
164 ready_line.trim(),
165 e
166 ))
167 })?;
168
169 if !ready.is_ready() {
170 return Err(SshError::AgentStartFailed(
171 "agent did not send ready message".to_string(),
172 ));
173 }
174
175 let version = ready.version.unwrap_or(0);
177 if version != crate::services::remote::protocol::PROTOCOL_VERSION {
178 return Err(SshError::VersionMismatch {
179 expected: crate::services::remote::protocol::PROTOCOL_VERSION,
180 got: version,
181 });
182 }
183
184 let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
186
187 Ok(Self {
188 process: child,
189 channel,
190 params,
191 })
192 }
193
194 pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
196 self.channel.clone()
197 }
198
199 pub fn params(&self) -> &ConnectionParams {
201 &self.params
202 }
203
204 pub fn is_connected(&self) -> bool {
206 self.channel.is_connected()
207 }
208
209 pub fn connection_string(&self) -> String {
211 self.params.to_string()
212 }
213}
214
215impl Drop for SshConnection {
216 fn drop(&mut self) {
217 let _ = self.process.start_kill();
219 }
220}
221
222#[doc(hidden)]
227pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
228 use tokio::process::Command as TokioCommand;
229
230 let mut child = TokioCommand::new("python3")
231 .arg("-u")
232 .arg("-c")
233 .arg(AGENT_SOURCE)
234 .stdin(Stdio::piped())
235 .stdout(Stdio::piped())
236 .stderr(Stdio::piped())
237 .spawn()?;
238
239 let stdin = child
240 .stdin
241 .take()
242 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
243 let stdout = child
244 .stdout
245 .take()
246 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
247
248 let mut reader = BufReader::new(stdout);
249
250 let mut ready_line = String::new();
252 reader.read_line(&mut ready_line).await?;
253
254 let ready: AgentResponse = serde_json::from_str(&ready_line)
255 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
256
257 if !ready.is_ready() {
258 return Err(SshError::AgentStartFailed(
259 "agent did not send ready message".to_string(),
260 ));
261 }
262
263 Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// `parse` accepts `user@host[:port]` and rejects inputs with a missing user,
    /// host, or `@`.
    #[test]
    fn test_parse_connection_params() {
        let accepted = [("user@host", None), ("user@host:22", Some(22))];
        for (input, expected_port) in accepted {
            let parsed = ConnectionParams::parse(input).unwrap();
            assert_eq!(parsed.user, "user");
            assert_eq!(parsed.host, "host");
            assert_eq!(parsed.port, expected_port);
        }

        for rejected in ["hostonly", "@host", "user@"] {
            assert!(ConnectionParams::parse(rejected).is_none());
        }
    }

    /// `to_string` renders `user@host`, appending `:port` only when a port is set.
    #[test]
    fn test_connection_string() {
        let build = |user: &str, host: &str, port: Option<u16>| ConnectionParams {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        };

        let without_port = build("alice", "example.com", None);
        assert_eq!(without_port.to_string(), "alice@example.com");

        let with_port = build("bob", "server.local", Some(2222));
        assert_eq!(with_port.to_string(), "bob@server.local:2222");
    }
}