fresh/services/remote/
connection.rs1use crate::services::remote::channel::AgentChannel;
6use crate::services::remote::protocol::AgentResponse;
7use crate::services::remote::AGENT_SOURCE;
8use std::path::PathBuf;
9use std::process::Stdio;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12
13#[derive(Debug, thiserror::Error)]
15pub enum SshError {
16 #[error("Failed to spawn SSH process: {0}")]
17 SpawnFailed(#[from] std::io::Error),
18
19 #[error("Agent failed to start: {0}")]
20 AgentStartFailed(String),
21
22 #[error("Protocol version mismatch: expected {expected}, got {got}")]
23 VersionMismatch { expected: u32, got: u32 },
24
25 #[error("Connection closed")]
26 ConnectionClosed,
27
28 #[error("Authentication failed")]
29 AuthenticationFailed,
30}
31
/// Where and how to reach a remote host over SSH.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionParams {
    /// Remote login name (everything before the last `@`).
    pub user: String,
    /// Remote host name or address; IPv6 literals are stored without brackets.
    pub host: String,
    /// Explicit port for `ssh -p`; `None` leaves the choice to ssh.
    pub port: Option<u16>,
    /// Private key for `ssh -i`, if any.
    pub identity_file: Option<PathBuf>,
}

impl ConnectionParams {
    /// Parses a connection string of the form `user@host[:port]`.
    ///
    /// - The user/host split happens at the *last* `@`, as OpenSSH does, so user
    ///   names may themselves contain `@` (`alice@corp@host`).
    /// - IPv6 literals may be given bare without a port (`user@::1`) or bracketed
    ///   (`user@[::1]`, `user@[::1]:2222`); brackets are stripped from `host`.
    /// - For other hosts, a single trailing `:suffix` is taken as the port only if it
    ///   parses as `u16`; otherwise it stays part of the host.
    ///
    /// Returns `None` if there is no `@`, the user or host is empty, or a bracketed
    /// host is malformed. The result never has an `identity_file`.
    pub fn parse(s: &str) -> Option<Self> {
        let (user, host_port) = s.rsplit_once('@')?;
        let (host, port) = Self::split_host_port(host_port)?;
        if user.is_empty() || host.is_empty() {
            return None;
        }

        Some(Self {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        })
    }

    /// Splits `host[:port]` into its parts, understanding bracketed IPv6 literals.
    fn split_host_port(s: &str) -> Option<(&str, Option<u16>)> {
        if let Some(bracketed) = s.strip_prefix('[') {
            // `[addr]` or `[addr]:port`; anything else after `]` is malformed.
            let (addr, after) = bracketed.split_once(']')?;
            if after.is_empty() {
                return Some((addr, None));
            }
            let port = after.strip_prefix(':')?.parse::<u16>().ok()?;
            return Some((addr, Some(port)));
        }

        match s.split_once(':') {
            // Exactly one colon: `host:port` when the suffix is a valid port number.
            Some((host, port)) if !port.contains(':') => match port.parse::<u16>() {
                Ok(port) => Some((host, Some(port))),
                Err(_) => Some((s, None)),
            },
            // No colon, or several (a bare IPv6 literal): the whole string is the host.
            _ => Some((s, None)),
        }
    }
}

/// Formats as `user@host`, or `user@host:port` when a port is set.
///
/// When a port is present and the host contains `:` (an IPv6 literal), the host is
/// bracketed (`user@[::1]:22`) so the output parses back to the same value via
/// [`ConnectionParams::parse`].
impl std::fmt::Display for ConnectionParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}@", self.user)?;
        match self.port {
            Some(port) if self.host.contains(':') => write!(f, "[{}]:{}", self.host, port),
            Some(port) => write!(f, "{}:{}", self.host, port),
            None => write!(f, "{}", self.host),
        }
    }
}
76
77pub struct SshConnection {
79 process: Child,
81 channel: std::sync::Arc<AgentChannel>,
83 params: ConnectionParams,
85}
86
87impl SshConnection {
88 pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
90 let mut cmd = Command::new("ssh");
91
92 cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
94 if let Some(port) = params.port {
98 cmd.arg("-p").arg(port.to_string());
99 }
100
101 if let Some(ref identity) = params.identity_file {
102 cmd.arg("-i").arg(identity);
103 }
104
105 cmd.arg(format!("{}@{}", params.user, params.host));
106
107 let agent_len = AGENT_SOURCE.len();
116 let bootstrap = format!(
117 "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
118 agent_len
119 );
120 cmd.arg(bootstrap);
121
122 cmd.stdin(Stdio::piped());
123 cmd.stdout(Stdio::piped());
124 cmd.stderr(Stdio::inherit());
126
127 let mut child = cmd.spawn()?;
128
129 let mut stdin = child
131 .stdin
132 .take()
133 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
134 let stdout = child
135 .stdout
136 .take()
137 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
138 stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
142 stdin.flush().await?;
143
144 let mut reader = BufReader::new(stdout);
146
147 let mut ready_line = String::new();
151 match reader.read_line(&mut ready_line).await {
152 Ok(0) => {
153 return Err(SshError::AgentStartFailed(
154 "connection closed (check terminal for SSH errors)".to_string(),
155 ));
156 }
157 Ok(_) => {}
158 Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
159 }
160
161 let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
162 SshError::AgentStartFailed(format!(
163 "invalid ready message '{}': {}",
164 ready_line.trim(),
165 e
166 ))
167 })?;
168
169 if !ready.is_ready() {
170 return Err(SshError::AgentStartFailed(
171 "agent did not send ready message".to_string(),
172 ));
173 }
174
175 let version = ready.version.unwrap_or(0);
177 if version != crate::services::remote::protocol::PROTOCOL_VERSION {
178 return Err(SshError::VersionMismatch {
179 expected: crate::services::remote::protocol::PROTOCOL_VERSION,
180 got: version,
181 });
182 }
183
184 let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
186
187 Ok(Self {
188 process: child,
189 channel,
190 params,
191 })
192 }
193
194 pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
196 self.channel.clone()
197 }
198
199 pub fn params(&self) -> &ConnectionParams {
201 &self.params
202 }
203
204 pub fn is_connected(&self) -> bool {
206 self.channel.is_connected()
207 }
208
209 pub fn connection_string(&self) -> String {
211 self.params.to_string()
212 }
213}
214
215impl Drop for SshConnection {
216 fn drop(&mut self) {
217 if let Ok(()) = self.process.start_kill() {}
222 }
223}
224
225#[doc(hidden)]
230pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
231 use tokio::process::Command as TokioCommand;
232
233 let mut child = TokioCommand::new("python3")
234 .arg("-u")
235 .arg("-c")
236 .arg(AGENT_SOURCE)
237 .stdin(Stdio::piped())
238 .stdout(Stdio::piped())
239 .stderr(Stdio::piped())
240 .spawn()?;
241
242 let stdin = child
243 .stdin
244 .take()
245 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
246 let stdout = child
247 .stdout
248 .take()
249 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
250
251 let mut reader = BufReader::new(stdout);
252
253 let mut ready_line = String::new();
255 reader.read_line(&mut ready_line).await?;
256
257 let ready: AgentResponse = serde_json::from_str(&ready_line)
258 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
259
260 if !ready.is_ready() {
261 return Err(SshError::AgentStartFailed(
262 "agent did not send ready message".to_string(),
263 ));
264 }
265
266 Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
267}
268
269#[doc(hidden)]
274pub async fn spawn_local_agent_with_capacity(
275 data_channel_capacity: usize,
276) -> Result<std::sync::Arc<AgentChannel>, SshError> {
277 use tokio::process::Command as TokioCommand;
278
279 let mut child = TokioCommand::new("python3")
280 .arg("-u")
281 .arg("-c")
282 .arg(AGENT_SOURCE)
283 .stdin(Stdio::piped())
284 .stdout(Stdio::piped())
285 .stderr(Stdio::piped())
286 .spawn()?;
287
288 let stdin = child
289 .stdin
290 .take()
291 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
292 let stdout = child
293 .stdout
294 .take()
295 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
296
297 let mut reader = BufReader::new(stdout);
298
299 let mut ready_line = String::new();
301 reader.read_line(&mut ready_line).await?;
302
303 let ready: AgentResponse = serde_json::from_str(&ready_line)
304 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
305
306 if !ready.is_ready() {
307 return Err(SshError::AgentStartFailed(
308 "agent did not send ready message".to_string(),
309 ));
310 }
311
312 Ok(std::sync::Arc::new(AgentChannel::with_capacity(
313 reader,
314 stdin,
315 data_channel_capacity,
316 )))
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `ConnectionParams` with no identity file.
    fn params(user: &str, host: &str, port: Option<u16>) -> ConnectionParams {
        ConnectionParams {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        }
    }

    #[test]
    fn test_parse_connection_params() {
        // Well-formed inputs: (input, expected user, expected host, expected port).
        let accepted = [
            ("user@host", "user", "host", None),
            ("user@host:22", "user", "host", Some(22)),
        ];
        for (input, user, host, port) in accepted {
            let parsed = ConnectionParams::parse(input)
                .unwrap_or_else(|| panic!("{:?} should parse", input));
            assert_eq!(parsed.user, user, "user of {:?}", input);
            assert_eq!(parsed.host, host, "host of {:?}", input);
            assert_eq!(parsed.port, port, "port of {:?}", input);
        }

        // A missing `@`, an empty user, or an empty host must all be rejected.
        for input in ["hostonly", "@host", "user@"] {
            assert!(
                ConnectionParams::parse(input).is_none(),
                "{:?} should be rejected",
                input
            );
        }
    }

    #[test]
    fn test_connection_string() {
        let without_port = params("alice", "example.com", None);
        assert_eq!(without_port.to_string(), "alice@example.com");

        let with_port = params("bob", "server.local", Some(2222));
        assert_eq!(with_port.to_string(), "bob@server.local:2222");
    }
}