1use crate::services::process_hidden::HideWindow;
6use crate::services::remote::channel::AgentChannel;
7use crate::services::remote::protocol::AgentResponse;
8use crate::services::remote::AGENT_SOURCE;
9use std::path::PathBuf;
10use std::process::Stdio;
11use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
12use tokio::process::{Child, ChildStderr, Command};
13
/// Errors produced while starting or talking to the remote (or local) agent.
#[derive(Debug, thiserror::Error)]
pub enum SshError {
    /// A child process (`ssh`, or `python3` for the local-agent helpers) could
    /// not be spawned.
    ///
    /// NOTE: because of `#[from]`, any `std::io::Error` propagated with `?` in a
    /// function returning `SshError` also lands here, so this variant (and its
    /// message) can wrap I/O failures that happen after the spawn succeeded.
    #[error("Failed to spawn SSH process ({0}). Is the `ssh` command installed and in your PATH?")]
    SpawnFailed(#[from] std::io::Error),

    /// The agent did not complete its startup handshake. The message carries
    /// the detail: ssh exit status / stderr, a read error, or why the ready
    /// message was rejected.
    #[error("Agent failed to start: {0}")]
    AgentStartFailed(String),

    /// The agent's ready message reported a protocol version different from
    /// the one this client expects (`expected`).
    #[error("Protocol version mismatch: expected {expected}, got {got}")]
    VersionMismatch { expected: u32, got: u32 },

    /// The connection was closed. Not constructed in this file.
    #[error("Connection closed")]
    ConnectionClosed,

    /// Authentication failed. Not constructed in this file.
    #[error("Authentication failed")]
    AuthenticationFailed,
}
32
/// Where and how to reach a remote host with `ssh`.
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Remote user name; omitted from the ssh target when `None` or empty.
    pub user: Option<String>,
    /// Remote host name or address.
    pub host: String,
    /// SSH port, passed to `ssh` as `-p` when set.
    pub port: Option<u16>,
    /// Private key file, passed to `ssh` as `-i` when set.
    pub identity_file: Option<PathBuf>,
    /// Extra arguments passed verbatim to `ssh`, placed before the target.
    pub extra_args: Vec<String>,
}
48
49impl ConnectionParams {
50 pub fn parse(s: &str) -> Option<Self> {
53 let s = s.strip_prefix("ssh://").unwrap_or(s);
54 let (user_host, port) = if let Some((uh, p)) = s.rsplit_once(':') {
55 if let Ok(port) = p.parse::<u16>() {
56 (uh, Some(port))
57 } else {
58 (s, None)
59 }
60 } else {
61 (s, None)
62 };
63
64 let (user, host) = match user_host.split_once('@') {
65 Some((u, h)) => (Some(u.to_string()), h),
66 None => (None, user_host),
67 };
68 if host.is_empty() || user.as_deref() == Some("") {
69 return None;
70 }
71
72 Some(Self {
73 user,
74 host: host.to_string(),
75 port,
76 identity_file: None,
77 extra_args: Vec::new(),
78 })
79 }
80
81 pub fn ssh_target(&self) -> String {
84 match &self.user {
85 Some(user) if !user.is_empty() => format!("{user}@{}", self.host),
86 _ => self.host.clone(),
87 }
88 }
89}
90
91impl std::fmt::Display for ConnectionParams {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 match self.port {
94 Some(port) => write!(f, "{}:{}", self.ssh_target(), port),
95 None => write!(f, "{}", self.ssh_target()),
96 }
97 }
98}
99
/// An SSH session running the remote agent, created by `SshConnection::connect`.
pub struct SshConnection {
    /// The local `ssh` child process. It is sent a kill when this value is
    /// dropped, and was spawned with `kill_on_drop(true)`.
    process: Child,
    /// Request channel to the agent; handed out as `Arc` clones by `channel()`.
    channel: std::sync::Arc<AgentChannel>,
    /// The parameters this connection was made with.
    params: ConnectionParams,
}
109
110impl SshConnection {
111 pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
113 let mut cmd = Command::new("ssh");
114
115 cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
117
118 if let Some(port) = params.port {
119 cmd.arg("-p").arg(port.to_string());
120 }
121
122 if let Some(ref identity) = params.identity_file {
123 cmd.arg("-i").arg(identity);
124 }
125
126 cmd.args(¶ms.extra_args);
127 cmd.arg(params.ssh_target());
128
129 let agent_len = AGENT_SOURCE.len();
138 let bootstrap = format!(
139 "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
140 agent_len
141 );
142 cmd.arg(bootstrap);
143
144 cmd.stdin(Stdio::piped());
145 cmd.stdout(Stdio::piped());
146 cmd.stderr(Stdio::piped());
155 cmd.kill_on_drop(true);
162 cmd.hide_window();
163
164 let mut child = cmd.spawn()?;
165
166 let mut stdin = child
168 .stdin
169 .take()
170 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
171 let stdout = child
172 .stdout
173 .take()
174 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
175 let stderr = child.stderr.take();
176
177 if stdin.write_all(AGENT_SOURCE.as_bytes()).await.is_err() || stdin.flush().await.is_err() {
184 return Err(ssh_eof_error(&mut child, ¶ms, stderr).await);
185 }
186
187 let mut reader = BufReader::new(stdout);
189
190 let mut ready_line = String::new();
194 match reader.read_line(&mut ready_line).await {
195 Ok(0) => {
196 return Err(ssh_eof_error(&mut child, ¶ms, stderr).await);
197 }
198 Ok(_) => {}
199 Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
200 }
201
202 if let Some(mut stderr) = stderr {
207 tokio::spawn(async move {
208 let mut sink = tokio::io::sink();
209 #[allow(clippy::let_underscore_must_use)]
212 let _ = tokio::io::copy(&mut stderr, &mut sink).await;
213 });
214 }
215
216 let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
217 SshError::AgentStartFailed(format!(
218 "invalid ready message '{}': {}",
219 ready_line.trim(),
220 e
221 ))
222 })?;
223
224 if !ready.is_ready() {
225 return Err(SshError::AgentStartFailed(
226 "agent did not send ready message".to_string(),
227 ));
228 }
229
230 let version = ready.version.unwrap_or(0);
232 if version != crate::services::remote::protocol::PROTOCOL_VERSION {
233 return Err(SshError::VersionMismatch {
234 expected: crate::services::remote::protocol::PROTOCOL_VERSION,
235 got: version,
236 });
237 }
238
239 let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
241
242 Ok(Self {
243 process: child,
244 channel,
245 params,
246 })
247 }
248
249 pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
251 self.channel.clone()
252 }
253
254 pub fn params(&self) -> &ConnectionParams {
256 &self.params
257 }
258
259 pub fn is_connected(&self) -> bool {
261 self.channel.is_connected()
262 }
263
264 pub fn connection_string(&self) -> String {
266 self.params.to_string()
267 }
268}
269
270impl Drop for SshConnection {
271 fn drop(&mut self) {
272 if let Ok(()) = self.process.start_kill() {}
277 }
278}
279
/// Default reconnect interval, used by `ReconnectConfig::default()` (and
/// therefore by `spawn_reconnect_task`).
const DEFAULT_RECONNECT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);

/// Settings for `spawn_reconnect_task_with`.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Delay between connection checks while connected, and before each
    /// reconnection attempt while disconnected.
    pub interval: std::time::Duration,
}

impl Default for ReconnectConfig {
    /// Uses `DEFAULT_RECONNECT_INTERVAL` (5 seconds).
    fn default() -> Self {
        Self {
            interval: DEFAULT_RECONNECT_INTERVAL,
        }
    }
}
296
297pub fn spawn_reconnect_task(
307 channel: std::sync::Arc<AgentChannel>,
308 params: ConnectionParams,
309) -> tokio::task::JoinHandle<()> {
310 let connect_fn = move || {
311 let params = params.clone();
312 async move {
313 let (reader, writer, _child) = establish_ssh_transport(¶ms).await?;
314 let reader: Box<dyn tokio::io::AsyncBufRead + Unpin + Send> = Box::new(reader);
316 let writer: Box<dyn tokio::io::AsyncWrite + Unpin + Send> = Box::new(writer);
317 Ok::<_, SshError>((reader, writer))
318 }
319 };
320
321 spawn_reconnect_task_with(
322 channel,
323 connect_fn,
324 ReconnectConfig::default(),
325 "SSH remote",
326 )
327}
328
329pub fn spawn_reconnect_task_with<F, Fut>(
336 channel: std::sync::Arc<AgentChannel>,
337 connect_fn: F,
338 config: ReconnectConfig,
339 label: &'static str,
340) -> tokio::task::JoinHandle<()>
341where
342 F: Fn() -> Fut + Send + 'static,
343 Fut: std::future::Future<
344 Output = Result<
345 (
346 Box<dyn tokio::io::AsyncBufRead + Unpin + Send>,
347 Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
348 ),
349 SshError,
350 >,
351 > + Send,
352{
353 tokio::spawn(async move {
354 loop {
355 while channel.is_connected() {
357 tokio::time::sleep(config.interval).await;
358 }
359
360 tracing::info!("{label}: connection lost, attempting reconnection...");
361
362 loop {
364 tokio::time::sleep(config.interval).await;
365
366 if !channel.is_connected() {
368 } else {
370 break;
372 }
373
374 match (connect_fn)().await {
375 Ok((reader, writer)) => {
376 tracing::info!("{label}: reconnected successfully");
377 channel.replace_transport(reader, writer).await;
378 break;
379 }
380 Err(e) => {
381 tracing::debug!("{label}: reconnection attempt failed: {e}");
382 }
383 }
384 }
385 }
386 })
387}
388
/// A 60-second interval, presumably meant as the `interval` argument to
/// `spawn_heartbeat_task` (it is not referenced elsewhere in this file).
pub const DEFAULT_HEARTBEAT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(60);
393
394pub fn spawn_heartbeat_task(
411 channel: &std::sync::Arc<AgentChannel>,
412 interval: std::time::Duration,
413) -> tokio::task::JoinHandle<()> {
414 let weak = std::sync::Arc::downgrade(channel);
415 tokio::spawn(async move {
416 loop {
417 tokio::time::sleep(interval).await;
418 let Some(channel) = weak.upgrade() else {
419 break;
420 };
421 if channel.is_connected() {
422 let _ping = channel.request("info", serde_json::json!({})).await;
427 }
428 }
429 })
430}
431
432async fn ssh_eof_error(
439 child: &mut Child,
440 params: &ConnectionParams,
441 stderr: Option<ChildStderr>,
442) -> SshError {
443 let status = tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
445
446 let hint = match status {
447 Ok(Ok(status)) => {
448 match status.code() {
449 Some(255) => format!(
453 "SSH could not connect to {}. Check that the host is \
454 reachable, the hostname is correct, and your SSH \
455 credentials are valid (exit code 255)",
456 params
457 ),
458 Some(127) => format!(
463 "Python 3 was not found on the remote host {}. \
464 Fresh's remote support requires python3 on the remote — \
465 install it there, then reconnect",
466 params
467 ),
468 Some(code) => format!(
469 "SSH process exited with code {} while connecting to {}",
470 code, params
471 ),
472 None => format!(
473 "SSH process was killed by a signal while connecting to {}",
474 params
475 ),
476 }
477 }
478 Ok(Err(e)) => format!("failed to get SSH exit status: {}", e),
479 Err(_) => {
480 if let Err(e) = child.start_kill() {
482 tracing::warn!("Failed to kill timed-out SSH process: {}", e);
483 }
484 format!(
485 "SSH process did not exit in time while connecting to {}",
486 params
487 )
488 }
489 };
490
491 match read_ssh_stderr(stderr).await {
496 Some(detail) => SshError::AgentStartFailed(format!("{hint}: {detail}")),
497 None => SshError::AgentStartFailed(hint),
498 }
499}
500
501async fn read_ssh_stderr(stderr: Option<ChildStderr>) -> Option<String> {
506 let mut stderr = stderr?;
507 let mut buf = String::new();
508 #[allow(clippy::let_underscore_must_use)]
509 let _ = tokio::time::timeout(
510 std::time::Duration::from_secs(2),
511 stderr.read_to_string(&mut buf),
512 )
513 .await;
514 buf.trim()
515 .lines()
516 .map(str::trim)
517 .filter(|line| !line.is_empty())
518 .next_back()
519 .map(str::to_string)
520}
521
522async fn establish_ssh_transport(
526 params: &ConnectionParams,
527) -> Result<
528 (
529 BufReader<tokio::process::ChildStdout>,
530 tokio::process::ChildStdin,
531 Child,
532 ),
533 SshError,
534> {
535 let mut cmd = Command::new("ssh");
536
537 cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
538 cmd.arg("-o").arg("BatchMode=yes");
540
541 if let Some(port) = params.port {
542 cmd.arg("-p").arg(port.to_string());
543 }
544
545 if let Some(ref identity) = params.identity_file {
546 cmd.arg("-i").arg(identity);
547 }
548
549 cmd.args(¶ms.extra_args);
550 cmd.arg(params.ssh_target());
551
552 let agent_len = AGENT_SOURCE.len();
553 let bootstrap = format!(
554 "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
555 agent_len
556 );
557 cmd.arg(bootstrap);
558
559 cmd.stdin(Stdio::piped());
560 cmd.stdout(Stdio::piped());
561 cmd.stderr(Stdio::null()); cmd.hide_window();
563
564 let mut child = cmd.spawn()?;
565
566 let mut stdin = child
567 .stdin
568 .take()
569 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
570 let stdout = child
571 .stdout
572 .take()
573 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
574
575 stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
577 stdin.flush().await?;
578
579 let mut reader = BufReader::new(stdout);
580
581 let mut ready_line = String::new();
583 match reader.read_line(&mut ready_line).await {
584 Ok(0) => {
585 return Err(ssh_eof_error(&mut child, params, None).await);
588 }
589 Ok(_) => {}
590 Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
591 }
592
593 let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
594 SshError::AgentStartFailed(format!(
595 "invalid ready message '{}': {}",
596 ready_line.trim(),
597 e
598 ))
599 })?;
600
601 if !ready.is_ready() {
602 return Err(SshError::AgentStartFailed(
603 "agent did not send ready message".to_string(),
604 ));
605 }
606
607 let version = ready.version.unwrap_or(0);
608 if version != crate::services::remote::protocol::PROTOCOL_VERSION {
609 return Err(SshError::VersionMismatch {
610 expected: crate::services::remote::protocol::PROTOCOL_VERSION,
611 got: version,
612 });
613 }
614
615 Ok((reader, stdin, child))
616}
617
618#[doc(hidden)]
623pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
624 use tokio::process::Command as TokioCommand;
625
626 let mut child = TokioCommand::new("python3")
627 .arg("-u")
628 .arg("-c")
629 .arg(AGENT_SOURCE)
630 .stdin(Stdio::piped())
631 .stdout(Stdio::piped())
632 .stderr(Stdio::piped())
633 .hide_window()
634 .spawn()?;
635
636 let stdin = child
637 .stdin
638 .take()
639 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
640 let stdout = child
641 .stdout
642 .take()
643 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
644
645 let mut reader = BufReader::new(stdout);
646
647 let mut ready_line = String::new();
649 reader.read_line(&mut ready_line).await?;
650
651 let ready: AgentResponse = serde_json::from_str(&ready_line)
652 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
653
654 if !ready.is_ready() {
655 return Err(SshError::AgentStartFailed(
656 "agent did not send ready message".to_string(),
657 ));
658 }
659
660 Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
661}
662
663#[doc(hidden)]
668pub async fn spawn_local_agent_with_capacity(
669 data_channel_capacity: usize,
670) -> Result<std::sync::Arc<AgentChannel>, SshError> {
671 use tokio::process::Command as TokioCommand;
672
673 let mut child = TokioCommand::new("python3")
674 .arg("-u")
675 .arg("-c")
676 .arg(AGENT_SOURCE)
677 .stdin(Stdio::piped())
678 .stdout(Stdio::piped())
679 .stderr(Stdio::piped())
680 .hide_window()
681 .spawn()?;
682
683 let stdin = child
684 .stdin
685 .take()
686 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
687 let stdout = child
688 .stdout
689 .take()
690 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
691
692 let mut reader = BufReader::new(stdout);
693
694 let mut ready_line = String::new();
696 reader.read_line(&mut ready_line).await?;
697
698 let ready: AgentResponse = serde_json::from_str(&ready_line)
699 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
700
701 if !ready.is_ready() {
702 return Err(SshError::AgentStartFailed(
703 "agent did not send ready message".to_string(),
704 ));
705 }
706
707 Ok(std::sync::Arc::new(AgentChannel::with_capacity(
708 reader,
709 stdin,
710 data_channel_capacity,
711 )))
712}
713
714#[doc(hidden)]
720pub async fn spawn_local_agent_transport() -> Result<
721 (
722 tokio::io::BufReader<tokio::process::ChildStdout>,
723 tokio::process::ChildStdin,
724 ),
725 SshError,
726> {
727 use tokio::process::Command as TokioCommand;
728
729 let mut child = TokioCommand::new("python3")
730 .arg("-u")
731 .arg("-c")
732 .arg(AGENT_SOURCE)
733 .stdin(Stdio::piped())
734 .stdout(Stdio::piped())
735 .stderr(Stdio::piped())
736 .hide_window()
737 .spawn()?;
738
739 let stdin = child
740 .stdin
741 .take()
742 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
743 let stdout = child
744 .stdout
745 .take()
746 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
747
748 let mut reader = BufReader::new(stdout);
749
750 let mut ready_line = String::new();
752 reader.read_line(&mut ready_line).await?;
753
754 let ready: AgentResponse = serde_json::from_str(&ready_line)
755 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
756
757 if !ready.is_ready() {
758 return Err(SshError::AgentStartFailed(
759 "agent did not send ready message".to_string(),
760 ));
761 }
762
763 Ok((reader, stdin))
764}
765
#[cfg(test)]
mod tests {
    use super::*;

    /// `ConnectionParams::parse` covers `user@host`, `user@host:port`, a bare
    /// host, the `ssh://` prefix, and rejects an empty user or host.
    #[test]
    fn test_parse_connection_params() {
        let params = ConnectionParams::parse("user@host").unwrap();
        assert_eq!(params.user.as_deref(), Some("user"));
        assert_eq!(params.host, "host");
        assert_eq!(params.port, None);

        let params = ConnectionParams::parse("user@host:22").unwrap();
        assert_eq!(params.user.as_deref(), Some("user"));
        assert_eq!(params.host, "host");
        assert_eq!(params.port, Some(22));

        // Without a user, the ssh target is just the host.
        let params = ConnectionParams::parse("hostonly").unwrap();
        assert_eq!(params.user, None);
        assert_eq!(params.host, "hostonly");
        assert_eq!(params.ssh_target(), "hostonly");

        let params = ConnectionParams::parse("ssh://example.com:2222").unwrap();
        assert_eq!(params.user, None);
        assert_eq!(params.host, "example.com");
        assert_eq!(params.port, Some(2222));

        // Empty user or empty host is rejected.
        assert!(ConnectionParams::parse("@host").is_none());
        assert!(ConnectionParams::parse("user@").is_none());
    }

    /// Needs a local `python3`: `spawn_local_agent` runs the agent locally.
    /// Heartbeats must leave the channel usable, and the heartbeat task must
    /// finish once the channel is dropped.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn heartbeat_keeps_channel_warm_and_exits_on_drop() {
        let channel = spawn_local_agent().await.expect("spawn local agent");
        // A 30ms interval lets several pings happen during the 150ms sleep below.
        let handle = spawn_heartbeat_task(&channel, std::time::Duration::from_millis(30));

        tokio::time::sleep(std::time::Duration::from_millis(150)).await;
        assert!(
            channel.is_connected(),
            "channel stays connected while heartbeat pings"
        );
        assert!(
            channel.request("info", serde_json::json!({})).await.is_ok(),
            "agent still answers after heartbeats"
        );

        // Dropping this Arc should make the task's Weak upgrade fail, assuming
        // it is the last strong reference (TODO confirm nothing inside
        // AgentChannel holds another).
        drop(channel);
        tokio::time::timeout(std::time::Duration::from_secs(3), handle)
            .await
            .expect("heartbeat task exits after the channel is dropped")
            .expect("heartbeat task did not panic");
    }

    /// `Display` renders `[user@]host[:port]`; `ssh_target` never includes the port.
    #[test]
    fn test_connection_string() {
        let params = ConnectionParams {
            user: Some("alice".to_string()),
            host: "example.com".to_string(),
            port: None,
            identity_file: None,
            extra_args: Vec::new(),
        };
        assert_eq!(params.to_string(), "alice@example.com");

        let params = ConnectionParams {
            user: Some("bob".to_string()),
            host: "server.local".to_string(),
            port: Some(2222),
            identity_file: None,
            extra_args: Vec::new(),
        };
        assert_eq!(params.to_string(), "bob@server.local:2222");

        // No user: neither the display form nor the ssh target has an `@`.
        let params = ConnectionParams {
            user: None,
            host: "server.local".to_string(),
            port: Some(2222),
            identity_file: None,
            extra_args: Vec::new(),
        };
        assert_eq!(params.to_string(), "server.local:2222");
        assert_eq!(params.ssh_target(), "server.local");
    }
}