1use crate::services::remote::channel::AgentChannel;
6use crate::services::remote::protocol::AgentResponse;
7use crate::services::remote::AGENT_SOURCE;
8use std::path::PathBuf;
9use std::process::Stdio;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12
13#[derive(Debug, thiserror::Error)]
15pub enum SshError {
16 #[error("Failed to spawn SSH process ({0}). Is the `ssh` command installed and in your PATH?")]
17 SpawnFailed(#[from] std::io::Error),
18
19 #[error("Agent failed to start: {0}")]
20 AgentStartFailed(String),
21
22 #[error("Protocol version mismatch: expected {expected}, got {got}")]
23 VersionMismatch { expected: u32, got: u32 },
24
25 #[error("Connection closed")]
26 ConnectionClosed,
27
28 #[error("Authentication failed")]
29 AuthenticationFailed,
30}
31
/// How to reach a remote host over SSH: `user@host[:port]` plus an optional identity file.
#[derive(Clone, Debug)]
pub struct ConnectionParams {
    /// Remote login name.
    pub user: String,
    /// Remote host name or address.
    pub host: String,
    /// Explicit SSH port, passed to ssh as `-p` when set.
    pub port: Option<u16>,
    /// Private key file, passed to ssh as `-i` when set.
    pub identity_file: Option<PathBuf>,
}
40
41impl ConnectionParams {
42 pub fn parse(s: &str) -> Option<Self> {
44 let (user_host, port) = if let Some((uh, p)) = s.rsplit_once(':') {
45 if let Ok(port) = p.parse::<u16>() {
46 (uh, Some(port))
47 } else {
48 (s, None)
49 }
50 } else {
51 (s, None)
52 };
53
54 let (user, host) = user_host.split_once('@')?;
55 if user.is_empty() || host.is_empty() {
56 return None;
57 }
58
59 Some(Self {
60 user: user.to_string(),
61 host: host.to_string(),
62 port,
63 identity_file: None,
64 })
65 }
66}
67
68impl std::fmt::Display for ConnectionParams {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 if let Some(port) = self.port {
71 write!(f, "{}@{}:{}", self.user, self.host, port)
72 } else {
73 write!(f, "{}@{}", self.user, self.host)
74 }
75 }
76}
77
78pub struct SshConnection {
80 process: Child,
82 channel: std::sync::Arc<AgentChannel>,
84 params: ConnectionParams,
86}
87
88impl SshConnection {
89 pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
91 let mut cmd = Command::new("ssh");
92
93 cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
95 if let Some(port) = params.port {
99 cmd.arg("-p").arg(port.to_string());
100 }
101
102 if let Some(ref identity) = params.identity_file {
103 cmd.arg("-i").arg(identity);
104 }
105
106 cmd.arg(format!("{}@{}", params.user, params.host));
107
108 let agent_len = AGENT_SOURCE.len();
117 let bootstrap = format!(
118 "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
119 agent_len
120 );
121 cmd.arg(bootstrap);
122
123 cmd.stdin(Stdio::piped());
124 cmd.stdout(Stdio::piped());
125 cmd.stderr(Stdio::inherit());
127
128 let mut child = cmd.spawn()?;
129
130 let mut stdin = child
132 .stdin
133 .take()
134 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
135 let stdout = child
136 .stdout
137 .take()
138 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
139 stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
143 stdin.flush().await?;
144
145 let mut reader = BufReader::new(stdout);
147
148 let mut ready_line = String::new();
152 match reader.read_line(&mut ready_line).await {
153 Ok(0) => {
154 return Err(ssh_eof_error(&mut child, ¶ms).await);
155 }
156 Ok(_) => {}
157 Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
158 }
159
160 let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
161 SshError::AgentStartFailed(format!(
162 "invalid ready message '{}': {}",
163 ready_line.trim(),
164 e
165 ))
166 })?;
167
168 if !ready.is_ready() {
169 return Err(SshError::AgentStartFailed(
170 "agent did not send ready message".to_string(),
171 ));
172 }
173
174 let version = ready.version.unwrap_or(0);
176 if version != crate::services::remote::protocol::PROTOCOL_VERSION {
177 return Err(SshError::VersionMismatch {
178 expected: crate::services::remote::protocol::PROTOCOL_VERSION,
179 got: version,
180 });
181 }
182
183 let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
185
186 Ok(Self {
187 process: child,
188 channel,
189 params,
190 })
191 }
192
193 pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
195 self.channel.clone()
196 }
197
198 pub fn params(&self) -> &ConnectionParams {
200 &self.params
201 }
202
203 pub fn is_connected(&self) -> bool {
205 self.channel.is_connected()
206 }
207
208 pub fn connection_string(&self) -> String {
210 self.params.to_string()
211 }
212}
213
214impl Drop for SshConnection {
215 fn drop(&mut self) {
216 if let Ok(()) = self.process.start_kill() {}
221 }
222}
223
/// Default pause between health checks and between reconnection attempts.
const DEFAULT_RECONNECT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);

/// Tuning for the background reconnection task.
///
/// It derives the standard value traits so callers can log, copy, and compare it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReconnectConfig {
    /// Sleep between two kinds of checks: `is_connected` polls while the link is up, and
    /// reconnection attempts while it is down.
    pub interval: std::time::Duration,
}

impl Default for ReconnectConfig {
    /// Uses `DEFAULT_RECONNECT_INTERVAL` (5 seconds).
    fn default() -> Self {
        Self {
            interval: DEFAULT_RECONNECT_INTERVAL,
        }
    }
}
240
241pub fn spawn_reconnect_task(
251 channel: std::sync::Arc<AgentChannel>,
252 params: ConnectionParams,
253) -> tokio::task::JoinHandle<()> {
254 let connect_fn = move || {
255 let params = params.clone();
256 async move {
257 let (reader, writer, _child) = establish_ssh_transport(¶ms).await?;
258 let reader: Box<dyn tokio::io::AsyncBufRead + Unpin + Send> = Box::new(reader);
260 let writer: Box<dyn tokio::io::AsyncWrite + Unpin + Send> = Box::new(writer);
261 Ok::<_, SshError>((reader, writer))
262 }
263 };
264
265 spawn_reconnect_task_with(
266 channel,
267 connect_fn,
268 ReconnectConfig::default(),
269 "SSH remote",
270 )
271}
272
273pub fn spawn_reconnect_task_with<F, Fut>(
280 channel: std::sync::Arc<AgentChannel>,
281 connect_fn: F,
282 config: ReconnectConfig,
283 label: &'static str,
284) -> tokio::task::JoinHandle<()>
285where
286 F: Fn() -> Fut + Send + 'static,
287 Fut: std::future::Future<
288 Output = Result<
289 (
290 Box<dyn tokio::io::AsyncBufRead + Unpin + Send>,
291 Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
292 ),
293 SshError,
294 >,
295 > + Send,
296{
297 tokio::spawn(async move {
298 loop {
299 while channel.is_connected() {
301 tokio::time::sleep(config.interval).await;
302 }
303
304 tracing::info!("{label}: connection lost, attempting reconnection...");
305
306 loop {
308 tokio::time::sleep(config.interval).await;
309
310 if !channel.is_connected() {
312 } else {
314 break;
316 }
317
318 match (connect_fn)().await {
319 Ok((reader, writer)) => {
320 tracing::info!("{label}: reconnected successfully");
321 channel.replace_transport(reader, writer).await;
322 break;
323 }
324 Err(e) => {
325 tracing::debug!("{label}: reconnection attempt failed: {e}");
326 }
327 }
328 }
329 }
330 })
331}
332
333async fn ssh_eof_error(child: &mut Child, params: &ConnectionParams) -> SshError {
340 let status = tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
342
343 let hint = match status {
344 Ok(Ok(status)) => {
345 match status.code() {
346 Some(255) => format!(
350 "SSH could not connect to {}. Check that the host is \
351 reachable, the hostname is correct, and your SSH \
352 credentials are valid (exit code 255)",
353 params
354 ),
355 Some(127) => format!(
356 "python3 was not found on the remote host {}. \
357 Ensure Python 3 is installed on the remote machine",
358 params
359 ),
360 Some(code) => format!(
361 "SSH process exited with code {} while connecting to {}",
362 code, params
363 ),
364 None => format!(
365 "SSH process was killed by a signal while connecting to {}",
366 params
367 ),
368 }
369 }
370 Ok(Err(e)) => format!("failed to get SSH exit status: {}", e),
371 Err(_) => {
372 if let Err(e) = child.start_kill() {
374 tracing::warn!("Failed to kill timed-out SSH process: {}", e);
375 }
376 format!(
377 "SSH process did not exit in time while connecting to {}",
378 params
379 )
380 }
381 };
382
383 SshError::AgentStartFailed(hint)
384}
385
386async fn establish_ssh_transport(
390 params: &ConnectionParams,
391) -> Result<
392 (
393 BufReader<tokio::process::ChildStdout>,
394 tokio::process::ChildStdin,
395 Child,
396 ),
397 SshError,
398> {
399 let mut cmd = Command::new("ssh");
400
401 cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
402 cmd.arg("-o").arg("BatchMode=yes");
404
405 if let Some(port) = params.port {
406 cmd.arg("-p").arg(port.to_string());
407 }
408
409 if let Some(ref identity) = params.identity_file {
410 cmd.arg("-i").arg(identity);
411 }
412
413 cmd.arg(format!("{}@{}", params.user, params.host));
414
415 let agent_len = AGENT_SOURCE.len();
416 let bootstrap = format!(
417 "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
418 agent_len
419 );
420 cmd.arg(bootstrap);
421
422 cmd.stdin(Stdio::piped());
423 cmd.stdout(Stdio::piped());
424 cmd.stderr(Stdio::null()); let mut child = cmd.spawn()?;
427
428 let mut stdin = child
429 .stdin
430 .take()
431 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
432 let stdout = child
433 .stdout
434 .take()
435 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
436
437 stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
439 stdin.flush().await?;
440
441 let mut reader = BufReader::new(stdout);
442
443 let mut ready_line = String::new();
445 match reader.read_line(&mut ready_line).await {
446 Ok(0) => {
447 return Err(ssh_eof_error(&mut child, params).await);
448 }
449 Ok(_) => {}
450 Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
451 }
452
453 let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
454 SshError::AgentStartFailed(format!(
455 "invalid ready message '{}': {}",
456 ready_line.trim(),
457 e
458 ))
459 })?;
460
461 if !ready.is_ready() {
462 return Err(SshError::AgentStartFailed(
463 "agent did not send ready message".to_string(),
464 ));
465 }
466
467 let version = ready.version.unwrap_or(0);
468 if version != crate::services::remote::protocol::PROTOCOL_VERSION {
469 return Err(SshError::VersionMismatch {
470 expected: crate::services::remote::protocol::PROTOCOL_VERSION,
471 got: version,
472 });
473 }
474
475 Ok((reader, stdin, child))
476}
477
478#[doc(hidden)]
483pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
484 use tokio::process::Command as TokioCommand;
485
486 let mut child = TokioCommand::new("python3")
487 .arg("-u")
488 .arg("-c")
489 .arg(AGENT_SOURCE)
490 .stdin(Stdio::piped())
491 .stdout(Stdio::piped())
492 .stderr(Stdio::piped())
493 .spawn()?;
494
495 let stdin = child
496 .stdin
497 .take()
498 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
499 let stdout = child
500 .stdout
501 .take()
502 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
503
504 let mut reader = BufReader::new(stdout);
505
506 let mut ready_line = String::new();
508 reader.read_line(&mut ready_line).await?;
509
510 let ready: AgentResponse = serde_json::from_str(&ready_line)
511 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
512
513 if !ready.is_ready() {
514 return Err(SshError::AgentStartFailed(
515 "agent did not send ready message".to_string(),
516 ));
517 }
518
519 Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
520}
521
522#[doc(hidden)]
527pub async fn spawn_local_agent_with_capacity(
528 data_channel_capacity: usize,
529) -> Result<std::sync::Arc<AgentChannel>, SshError> {
530 use tokio::process::Command as TokioCommand;
531
532 let mut child = TokioCommand::new("python3")
533 .arg("-u")
534 .arg("-c")
535 .arg(AGENT_SOURCE)
536 .stdin(Stdio::piped())
537 .stdout(Stdio::piped())
538 .stderr(Stdio::piped())
539 .spawn()?;
540
541 let stdin = child
542 .stdin
543 .take()
544 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
545 let stdout = child
546 .stdout
547 .take()
548 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
549
550 let mut reader = BufReader::new(stdout);
551
552 let mut ready_line = String::new();
554 reader.read_line(&mut ready_line).await?;
555
556 let ready: AgentResponse = serde_json::from_str(&ready_line)
557 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
558
559 if !ready.is_ready() {
560 return Err(SshError::AgentStartFailed(
561 "agent did not send ready message".to_string(),
562 ));
563 }
564
565 Ok(std::sync::Arc::new(AgentChannel::with_capacity(
566 reader,
567 stdin,
568 data_channel_capacity,
569 )))
570}
571
572#[doc(hidden)]
578pub async fn spawn_local_agent_transport() -> Result<
579 (
580 tokio::io::BufReader<tokio::process::ChildStdout>,
581 tokio::process::ChildStdin,
582 ),
583 SshError,
584> {
585 use tokio::process::Command as TokioCommand;
586
587 let mut child = TokioCommand::new("python3")
588 .arg("-u")
589 .arg("-c")
590 .arg(AGENT_SOURCE)
591 .stdin(Stdio::piped())
592 .stdout(Stdio::piped())
593 .stderr(Stdio::piped())
594 .spawn()?;
595
596 let stdin = child
597 .stdin
598 .take()
599 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
600 let stdout = child
601 .stdout
602 .take()
603 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
604
605 let mut reader = BufReader::new(stdout);
606
607 let mut ready_line = String::new();
609 reader.read_line(&mut ready_line).await?;
610
611 let ready: AgentResponse = serde_json::from_str(&ready_line)
612 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
613
614 if !ready.is_ready() {
615 return Err(SshError::AgentStartFailed(
616 "agent did not send ready message".to_string(),
617 ));
618 }
619
620 Ok((reader, stdin))
621}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds params without an identity file.
    fn make_params(user: &str, host: &str, port: Option<u16>) -> ConnectionParams {
        ConnectionParams {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        }
    }

    #[test]
    fn test_parse_connection_params() {
        let accepted: [(&str, &str, &str, Option<u16>); 2] = [
            ("user@host", "user", "host", None),
            ("user@host:22", "user", "host", Some(22)),
        ];
        for (input, user, host, port) in accepted {
            let parsed = ConnectionParams::parse(input).unwrap();
            assert_eq!(parsed.user, user);
            assert_eq!(parsed.host, host);
            assert_eq!(parsed.port, port);
        }

        for rejected in ["hostonly", "@host", "user@"] {
            assert!(ConnectionParams::parse(rejected).is_none());
        }
    }

    #[test]
    fn test_connection_string() {
        let expectations = [
            (make_params("alice", "example.com", None), "alice@example.com"),
            (make_params("bob", "server.local", Some(2222)), "bob@server.local:2222"),
        ];
        for (params, expected) in expectations {
            assert_eq!(params.to_string(), expected);
        }
    }
}