1use crate::services::process_hidden::HideWindow;
6use crate::services::remote::channel::AgentChannel;
7use crate::services::remote::protocol::AgentResponse;
8use crate::services::remote::AGENT_SOURCE;
9use std::path::PathBuf;
10use std::process::Stdio;
11use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
12use tokio::process::{Child, ChildStderr, Command};
13
/// Errors produced while starting an SSH session and bootstrapping the remote agent.
#[derive(Debug, thiserror::Error)]
pub enum SshError {
    /// A local process could not be spawned.
    ///
    /// NOTE: because of `#[from]`, any `std::io::Error` propagated with `?`
    /// ends up in this variant, not only failures to spawn `ssh`.
    #[error("Failed to spawn SSH process ({0}). Is the `ssh` command installed and in your PATH?")]
    SpawnFailed(#[from] std::io::Error),

    /// The agent did not come up: the session ended early, or its ready
    /// message was missing or malformed. The string carries the details.
    #[error("Agent failed to start: {0}")]
    AgentStartFailed(String),

    /// The agent's ready message reported a protocol version other than the
    /// one this build expects.
    #[error("Protocol version mismatch: expected {expected}, got {got}")]
    VersionMismatch { expected: u32, got: u32 },

    /// The connection was closed (not constructed in the code shown here).
    #[error("Connection closed")]
    ConnectionClosed,

    /// Authentication failed (not constructed in the code shown here).
    #[error("Authentication failed")]
    AuthenticationFailed,
}
32
/// Where and how to reach a remote host over SSH.
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Login user; `None` leaves the choice to ssh.
    pub user: Option<String>,
    /// Host name or address. [`ConnectionParams::parse`] stores IPv6 literals
    /// without brackets.
    pub host: String,
    /// Port passed to ssh via `-p`; `None` uses ssh's default.
    pub port: Option<u16>,
    /// Private key file passed to ssh via `-i`.
    pub identity_file: Option<PathBuf>,
    /// Extra ssh arguments, placed before the destination on the command line.
    pub extra_args: Vec<String>,
}

impl ConnectionParams {
    /// Parses `[ssh://][user@]host[:port]`.
    ///
    /// The user part ends at the first `@`. IPv6 literals may be given bare
    /// (`fe80::1`, which cannot carry a port) or in brackets (`[fe80::1]`,
    /// `[fe80::1]:2222`); brackets are not stored. A trailing `:suffix` that is
    /// not a valid `u16` is kept as part of the host.
    ///
    /// Returns `None` for an empty host, an empty user (`@host`), or a
    /// malformed bracketed host (`[::1]junk`).
    pub fn parse(s: &str) -> Option<Self> {
        let s = s.strip_prefix("ssh://").unwrap_or(s);

        let (user, host_port) = match s.split_once('@') {
            Some((u, rest)) => (Some(u), rest),
            None => (None, s),
        };
        let (host, port) = Self::split_host_port(host_port)?;

        if host.is_empty() || user == Some("") {
            return None;
        }

        Some(Self {
            user: user.map(str::to_string),
            host: host.to_string(),
            port,
            identity_file: None,
            extra_args: Vec::new(),
        })
    }

    /// Splits `host[:port]`, understanding bracketed and bare IPv6 literals.
    ///
    /// Returns `None` only for a bracketed host with a malformed tail.
    fn split_host_port(s: &str) -> Option<(&str, Option<u16>)> {
        if let Some(rest) = s.strip_prefix('[') {
            let (addr, tail) = rest.split_once(']')?;
            if tail.is_empty() {
                return Some((addr, None));
            }
            let port = tail.strip_prefix(':')?.parse::<u16>().ok()?;
            return Some((addr, Some(port)));
        }

        match s.rsplit_once(':') {
            // Exactly one ':' separates host and port. More than one means a
            // bare IPv6 literal, whose last group must not be taken as a port.
            Some((host, port)) if !host.contains(':') => match port.parse::<u16>() {
                Ok(port) => Some((host, Some(port))),
                Err(_) => Some((s, None)),
            },
            _ => Some((s, None)),
        }
    }

    /// The destination argument handed to ssh: `user@host`, or just `host`
    /// when no (or an empty) user is set. The port is not included.
    pub fn ssh_target(&self) -> String {
        match &self.user {
            Some(user) if !user.is_empty() => format!("{user}@{}", self.host),
            _ => self.host.clone(),
        }
    }
}

impl std::fmt::Display for ConnectionParams {
    /// Formats as `[user@]host[:port]`. An IPv6 host followed by a port is
    /// bracketed (`user@[::1]:22`) so the text is unambiguous and parses back
    /// with [`ConnectionParams::parse`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let target = self.ssh_target();
        match self.port {
            None => write!(f, "{target}"),
            Some(port) if self.host.contains(':') && !self.host.starts_with('[') => {
                // `target` always ends with the host; keep the `user@` prefix.
                let prefix = target.strip_suffix(self.host.as_str()).unwrap_or("");
                write!(f, "{prefix}[{}]:{port}", self.host)
            }
            Some(port) => write!(f, "{target}:{port}"),
        }
    }
}
99
/// A live SSH session running the remote agent.
pub struct SshConnection {
    /// The local `ssh` child process (spawned with `kill_on_drop(true)` in `connect`).
    process: Child,
    /// Request channel to the agent, built over the ssh process's stdout/stdin.
    channel: std::sync::Arc<AgentChannel>,
    /// Parameters the session was opened with.
    params: ConnectionParams,
}
109
110impl SshConnection {
111 pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
113 let mut cmd = Command::new("ssh");
114
115 cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
117
118 if let Some(port) = params.port {
119 cmd.arg("-p").arg(port.to_string());
120 }
121
122 if let Some(ref identity) = params.identity_file {
123 cmd.arg("-i").arg(identity);
124 }
125
126 cmd.args(¶ms.extra_args);
127 cmd.arg(params.ssh_target());
128
129 let agent_len = AGENT_SOURCE.len();
138 let bootstrap = format!(
139 "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
140 agent_len
141 );
142 cmd.arg(bootstrap);
143
144 cmd.stdin(Stdio::piped());
145 cmd.stdout(Stdio::piped());
146 cmd.stderr(Stdio::piped());
155 cmd.kill_on_drop(true);
162 cmd.hide_window();
163
164 let mut child = cmd.spawn()?;
165
166 let mut stdin = child
168 .stdin
169 .take()
170 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
171 let stdout = child
172 .stdout
173 .take()
174 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
175 let stderr = child.stderr.take();
176
177 if stdin.write_all(AGENT_SOURCE.as_bytes()).await.is_err() || stdin.flush().await.is_err() {
184 return Err(ssh_eof_error(&mut child, ¶ms, stderr).await);
185 }
186
187 let mut reader = BufReader::new(stdout);
189
190 let mut ready_line = String::new();
194 match reader.read_line(&mut ready_line).await {
195 Ok(0) => {
196 return Err(ssh_eof_error(&mut child, ¶ms, stderr).await);
197 }
198 Ok(_) => {}
199 Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
200 }
201
202 if let Some(mut stderr) = stderr {
207 tokio::spawn(async move {
208 let mut sink = tokio::io::sink();
209 #[allow(clippy::let_underscore_must_use)]
212 let _ = tokio::io::copy(&mut stderr, &mut sink).await;
213 });
214 }
215
216 let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
217 SshError::AgentStartFailed(format!(
218 "invalid ready message '{}': {}",
219 ready_line.trim(),
220 e
221 ))
222 })?;
223
224 if !ready.is_ready() {
225 return Err(SshError::AgentStartFailed(
226 "agent did not send ready message".to_string(),
227 ));
228 }
229
230 let version = ready.version.unwrap_or(0);
232 if version != crate::services::remote::protocol::PROTOCOL_VERSION {
233 return Err(SshError::VersionMismatch {
234 expected: crate::services::remote::protocol::PROTOCOL_VERSION,
235 got: version,
236 });
237 }
238
239 let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
241
242 Ok(Self {
243 process: child,
244 channel,
245 params,
246 })
247 }
248
249 pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
251 self.channel.clone()
252 }
253
254 pub fn params(&self) -> &ConnectionParams {
256 &self.params
257 }
258
259 pub fn is_connected(&self) -> bool {
261 self.channel.is_connected()
262 }
263
264 pub fn connection_string(&self) -> String {
266 self.params.to_string()
267 }
268}
269
270impl Drop for SshConnection {
271 fn drop(&mut self) {
272 if let Ok(()) = self.process.start_kill() {}
277 }
278}
279
/// Interval used by [`ReconnectConfig::default`]: five seconds.
const DEFAULT_RECONNECT_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5_000);

/// Tuning for [`spawn_reconnect_task_with`].
pub struct ReconnectConfig {
    /// How long to sleep between connectivity polls, and between reconnection attempts.
    pub interval: std::time::Duration,
}

impl Default for ReconnectConfig {
    fn default() -> Self {
        ReconnectConfig {
            interval: DEFAULT_RECONNECT_INTERVAL,
        }
    }
}
296
297pub fn spawn_reconnect_task(
307 channel: std::sync::Arc<AgentChannel>,
308 params: ConnectionParams,
309) -> tokio::task::JoinHandle<()> {
310 let connect_fn = move || {
311 let params = params.clone();
312 async move {
313 let (reader, writer, _child) = establish_ssh_transport(¶ms).await?;
314 let reader: Box<dyn tokio::io::AsyncBufRead + Unpin + Send> = Box::new(reader);
316 let writer: Box<dyn tokio::io::AsyncWrite + Unpin + Send> = Box::new(writer);
317 Ok::<_, SshError>((reader, writer))
318 }
319 };
320
321 spawn_reconnect_task_with(
322 channel,
323 connect_fn,
324 ReconnectConfig::default(),
325 "SSH remote",
326 )
327}
328
329pub fn spawn_reconnect_task_with<F, Fut>(
336 channel: std::sync::Arc<AgentChannel>,
337 connect_fn: F,
338 config: ReconnectConfig,
339 label: &'static str,
340) -> tokio::task::JoinHandle<()>
341where
342 F: Fn() -> Fut + Send + 'static,
343 Fut: std::future::Future<
344 Output = Result<
345 (
346 Box<dyn tokio::io::AsyncBufRead + Unpin + Send>,
347 Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
348 ),
349 SshError,
350 >,
351 > + Send,
352{
353 tokio::spawn(async move {
354 loop {
355 while channel.is_connected() {
357 tokio::time::sleep(config.interval).await;
358 }
359
360 tracing::info!("{label}: connection lost, attempting reconnection...");
361
362 loop {
364 tokio::time::sleep(config.interval).await;
365
366 if !channel.is_connected() {
368 } else {
370 break;
372 }
373
374 match (connect_fn)().await {
375 Ok((reader, writer)) => {
376 tracing::info!("{label}: reconnected successfully");
377 channel.replace_transport(reader, writer).await;
378 break;
379 }
380 Err(e) => {
381 tracing::debug!("{label}: reconnection attempt failed: {e}");
382 }
383 }
384 }
385 }
386 })
387}
388
/// Suggested `interval` for [`spawn_heartbeat_task`]: one minute.
pub const DEFAULT_HEARTBEAT_INTERVAL: std::time::Duration = std::time::Duration::from_millis(60_000);
393
394pub fn spawn_heartbeat_task(
411 channel: &std::sync::Arc<AgentChannel>,
412 interval: std::time::Duration,
413) -> tokio::task::JoinHandle<()> {
414 let weak = std::sync::Arc::downgrade(channel);
415 tokio::spawn(async move {
416 loop {
417 tokio::time::sleep(interval).await;
418 let Some(channel) = weak.upgrade() else {
419 break;
420 };
421 if channel.is_connected() {
422 let _ping = channel.request("info", serde_json::json!({})).await;
427 }
428 }
429 })
430}
431
432async fn ssh_eof_error(
439 child: &mut Child,
440 params: &ConnectionParams,
441 stderr: Option<ChildStderr>,
442) -> SshError {
443 let status = tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
445
446 let hint = match status {
447 Ok(Ok(status)) => {
448 match status.code() {
449 Some(255) => format!(
453 "SSH could not connect to {}. Check that the host is \
454 reachable, the hostname is correct, and your SSH \
455 credentials are valid (exit code 255)",
456 params
457 ),
458 Some(127) => format!(
459 "python3 was not found on the remote host {}. \
460 Ensure Python 3 is installed on the remote machine",
461 params
462 ),
463 Some(code) => format!(
464 "SSH process exited with code {} while connecting to {}",
465 code, params
466 ),
467 None => format!(
468 "SSH process was killed by a signal while connecting to {}",
469 params
470 ),
471 }
472 }
473 Ok(Err(e)) => format!("failed to get SSH exit status: {}", e),
474 Err(_) => {
475 if let Err(e) = child.start_kill() {
477 tracing::warn!("Failed to kill timed-out SSH process: {}", e);
478 }
479 format!(
480 "SSH process did not exit in time while connecting to {}",
481 params
482 )
483 }
484 };
485
486 match read_ssh_stderr(stderr).await {
491 Some(detail) => SshError::AgentStartFailed(format!("{hint}: {detail}")),
492 None => SshError::AgentStartFailed(hint),
493 }
494}
495
496async fn read_ssh_stderr(stderr: Option<ChildStderr>) -> Option<String> {
501 let mut stderr = stderr?;
502 let mut buf = String::new();
503 #[allow(clippy::let_underscore_must_use)]
504 let _ = tokio::time::timeout(
505 std::time::Duration::from_secs(2),
506 stderr.read_to_string(&mut buf),
507 )
508 .await;
509 buf.trim()
510 .lines()
511 .map(str::trim)
512 .filter(|line| !line.is_empty())
513 .next_back()
514 .map(str::to_string)
515}
516
517async fn establish_ssh_transport(
521 params: &ConnectionParams,
522) -> Result<
523 (
524 BufReader<tokio::process::ChildStdout>,
525 tokio::process::ChildStdin,
526 Child,
527 ),
528 SshError,
529> {
530 let mut cmd = Command::new("ssh");
531
532 cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
533 cmd.arg("-o").arg("BatchMode=yes");
535
536 if let Some(port) = params.port {
537 cmd.arg("-p").arg(port.to_string());
538 }
539
540 if let Some(ref identity) = params.identity_file {
541 cmd.arg("-i").arg(identity);
542 }
543
544 cmd.args(¶ms.extra_args);
545 cmd.arg(params.ssh_target());
546
547 let agent_len = AGENT_SOURCE.len();
548 let bootstrap = format!(
549 "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
550 agent_len
551 );
552 cmd.arg(bootstrap);
553
554 cmd.stdin(Stdio::piped());
555 cmd.stdout(Stdio::piped());
556 cmd.stderr(Stdio::null()); cmd.hide_window();
558
559 let mut child = cmd.spawn()?;
560
561 let mut stdin = child
562 .stdin
563 .take()
564 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
565 let stdout = child
566 .stdout
567 .take()
568 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
569
570 stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
572 stdin.flush().await?;
573
574 let mut reader = BufReader::new(stdout);
575
576 let mut ready_line = String::new();
578 match reader.read_line(&mut ready_line).await {
579 Ok(0) => {
580 return Err(ssh_eof_error(&mut child, params, None).await);
583 }
584 Ok(_) => {}
585 Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
586 }
587
588 let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
589 SshError::AgentStartFailed(format!(
590 "invalid ready message '{}': {}",
591 ready_line.trim(),
592 e
593 ))
594 })?;
595
596 if !ready.is_ready() {
597 return Err(SshError::AgentStartFailed(
598 "agent did not send ready message".to_string(),
599 ));
600 }
601
602 let version = ready.version.unwrap_or(0);
603 if version != crate::services::remote::protocol::PROTOCOL_VERSION {
604 return Err(SshError::VersionMismatch {
605 expected: crate::services::remote::protocol::PROTOCOL_VERSION,
606 got: version,
607 });
608 }
609
610 Ok((reader, stdin, child))
611}
612
613#[doc(hidden)]
618pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
619 use tokio::process::Command as TokioCommand;
620
621 let mut child = TokioCommand::new("python3")
622 .arg("-u")
623 .arg("-c")
624 .arg(AGENT_SOURCE)
625 .stdin(Stdio::piped())
626 .stdout(Stdio::piped())
627 .stderr(Stdio::piped())
628 .hide_window()
629 .spawn()?;
630
631 let stdin = child
632 .stdin
633 .take()
634 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
635 let stdout = child
636 .stdout
637 .take()
638 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
639
640 let mut reader = BufReader::new(stdout);
641
642 let mut ready_line = String::new();
644 reader.read_line(&mut ready_line).await?;
645
646 let ready: AgentResponse = serde_json::from_str(&ready_line)
647 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
648
649 if !ready.is_ready() {
650 return Err(SshError::AgentStartFailed(
651 "agent did not send ready message".to_string(),
652 ));
653 }
654
655 Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
656}
657
658#[doc(hidden)]
663pub async fn spawn_local_agent_with_capacity(
664 data_channel_capacity: usize,
665) -> Result<std::sync::Arc<AgentChannel>, SshError> {
666 use tokio::process::Command as TokioCommand;
667
668 let mut child = TokioCommand::new("python3")
669 .arg("-u")
670 .arg("-c")
671 .arg(AGENT_SOURCE)
672 .stdin(Stdio::piped())
673 .stdout(Stdio::piped())
674 .stderr(Stdio::piped())
675 .hide_window()
676 .spawn()?;
677
678 let stdin = child
679 .stdin
680 .take()
681 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
682 let stdout = child
683 .stdout
684 .take()
685 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
686
687 let mut reader = BufReader::new(stdout);
688
689 let mut ready_line = String::new();
691 reader.read_line(&mut ready_line).await?;
692
693 let ready: AgentResponse = serde_json::from_str(&ready_line)
694 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
695
696 if !ready.is_ready() {
697 return Err(SshError::AgentStartFailed(
698 "agent did not send ready message".to_string(),
699 ));
700 }
701
702 Ok(std::sync::Arc::new(AgentChannel::with_capacity(
703 reader,
704 stdin,
705 data_channel_capacity,
706 )))
707}
708
709#[doc(hidden)]
715pub async fn spawn_local_agent_transport() -> Result<
716 (
717 tokio::io::BufReader<tokio::process::ChildStdout>,
718 tokio::process::ChildStdin,
719 ),
720 SshError,
721> {
722 use tokio::process::Command as TokioCommand;
723
724 let mut child = TokioCommand::new("python3")
725 .arg("-u")
726 .arg("-c")
727 .arg(AGENT_SOURCE)
728 .stdin(Stdio::piped())
729 .stdout(Stdio::piped())
730 .stderr(Stdio::piped())
731 .hide_window()
732 .spawn()?;
733
734 let stdin = child
735 .stdin
736 .take()
737 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
738 let stdout = child
739 .stdout
740 .take()
741 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
742
743 let mut reader = BufReader::new(stdout);
744
745 let mut ready_line = String::new();
747 reader.read_line(&mut ready_line).await?;
748
749 let ready: AgentResponse = serde_json::from_str(&ready_line)
750 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
751
752 if !ready.is_ready() {
753 return Err(SshError::AgentStartFailed(
754 "agent did not send ready message".to_string(),
755 ));
756 }
757
758 Ok((reader, stdin))
759}
760
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_connection_params() {
        // (input, expected user, expected host, expected port) for inputs that must parse.
        let accepted: [(&str, Option<&str>, &str, Option<u16>); 4] = [
            ("user@host", Some("user"), "host", None),
            ("user@host:22", Some("user"), "host", Some(22)),
            ("hostonly", None, "hostonly", None),
            ("ssh://example.com:2222", None, "example.com", Some(2222)),
        ];
        for (input, user, host, port) in accepted {
            let parsed = ConnectionParams::parse(input)
                .unwrap_or_else(|| panic!("{input:?} should parse"));
            assert_eq!(parsed.user.as_deref(), user, "user of {input:?}");
            assert_eq!(parsed.host, host, "host of {input:?}");
            assert_eq!(parsed.port, port, "port of {input:?}");
        }

        let bare = ConnectionParams::parse("hostonly").unwrap();
        assert_eq!(bare.ssh_target(), "hostonly");

        // An empty user or an empty host is rejected.
        for rejected in ["@host", "user@"] {
            assert!(
                ConnectionParams::parse(rejected).is_none(),
                "{rejected:?} must not parse"
            );
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn heartbeat_keeps_channel_warm_and_exits_on_drop() {
        let agent = spawn_local_agent().await.expect("spawn local agent");
        let heartbeat = spawn_heartbeat_task(&agent, std::time::Duration::from_millis(30));

        // Let several heartbeat intervals elapse.
        tokio::time::sleep(std::time::Duration::from_millis(150)).await;

        assert!(
            agent.is_connected(),
            "channel stays connected while heartbeat pings"
        );
        assert!(
            agent.request("info", serde_json::json!({})).await.is_ok(),
            "agent still answers after heartbeats"
        );

        // Dropping the last strong reference must let the heartbeat task finish.
        drop(agent);
        let joined = tokio::time::timeout(std::time::Duration::from_secs(3), heartbeat)
            .await
            .expect("heartbeat task exits after the channel is dropped");
        joined.expect("heartbeat task did not panic");
    }

    #[test]
    fn test_connection_string() {
        let make = |user: Option<&str>, host: &str, port: Option<u16>| ConnectionParams {
            user: user.map(str::to_string),
            host: host.to_string(),
            port,
            identity_file: None,
            extra_args: Vec::new(),
        };

        assert_eq!(
            make(Some("alice"), "example.com", None).to_string(),
            "alice@example.com"
        );
        assert_eq!(
            make(Some("bob"), "server.local", Some(2222)).to_string(),
            "bob@server.local:2222"
        );

        // Without a user, neither the display form nor the ssh target has an `@`.
        let anonymous = make(None, "server.local", Some(2222));
        assert_eq!(anonymous.to_string(), "server.local:2222");
        assert_eq!(anonymous.ssh_target(), "server.local");
    }
}