1use crate::services::process_hidden::HideWindow;
6use crate::services::remote::channel::AgentChannel;
7use crate::services::remote::protocol::AgentResponse;
8use crate::services::remote::AGENT_SOURCE;
9use std::path::PathBuf;
10use std::process::Stdio;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::process::{Child, Command};
13
14#[derive(Debug, thiserror::Error)]
16pub enum SshError {
17 #[error("Failed to spawn SSH process ({0}). Is the `ssh` command installed and in your PATH?")]
18 SpawnFailed(#[from] std::io::Error),
19
20 #[error("Agent failed to start: {0}")]
21 AgentStartFailed(String),
22
23 #[error("Protocol version mismatch: expected {expected}, got {got}")]
24 VersionMismatch { expected: u32, got: u32 },
25
26 #[error("Connection closed")]
27 ConnectionClosed,
28
29 #[error("Authentication failed")]
30 AuthenticationFailed,
31}
32
/// Target of an SSH connection: `user@host`, plus an optional port and
/// identity (private key) file.
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Remote login name (the part before `@`).
    pub user: String,
    /// Remote host name or address (the part after `@`, minus any port).
    pub host: String,
    /// Port passed to `ssh -p`; when `None`, no `-p` flag is given.
    pub port: Option<u16>,
    /// Key file passed to `ssh -i`; when `None`, no `-i` flag is given.
    pub identity_file: Option<PathBuf>,
}

impl ConnectionParams {
    /// Parses a `user@host` or `user@host:port` string.
    ///
    /// The user is everything before the first `@`; the rest is the host. A
    /// trailing `:port` is split off only when it parses as a `u16` *and* the
    /// remaining host is not an unbracketed IPv6 literal:
    ///
    /// - `user@host:22` → host `host`, port `22`
    /// - `user@::1`, `user@2001:db8::1` → the whole address is the host, no port
    /// - `user@[::1]:2222` → host `[::1]` (brackets kept), port `2222`
    /// - `user@host:ssh` → host `host:ssh`, no port (suffix is not a number)
    ///
    /// Returns `None` if there is no `@` or if the user or host is empty.
    /// `identity_file` is always `None`.
    pub fn parse(s: &str) -> Option<Self> {
        let (user, host_port) = s.split_once('@')?;
        let (host, port) = Self::split_port(host_port);
        if user.is_empty() || host.is_empty() {
            return None;
        }

        Some(Self {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        })
    }

    /// Splits an optional trailing `:port` off `host_port`, leaving the input
    /// untouched when no valid port can be split off.
    fn split_port(host_port: &str) -> (&str, Option<u16>) {
        let (host, port) = match host_port.rsplit_once(':') {
            Some(parts) => parts,
            None => return (host_port, None),
        };

        // Another colon in the host means the whole thing is an IPv6 literal;
        // its last group is only a port if the address is bracketed.
        let bracketed = host.starts_with('[') && host.ends_with(']');
        if host.contains(':') && !bracketed {
            return (host_port, None);
        }

        match port.parse::<u16>() {
            Ok(port) => (host, Some(port)),
            Err(_) => (host_port, None),
        }
    }
}

impl std::fmt::Display for ConnectionParams {
    /// Formats as `user@host`, appending `:port` when a port is set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}@{}", self.user, self.host)?;
        if let Some(port) = self.port {
            write!(f, ":{}", port)?;
        }
        Ok(())
    }
}
78
79pub struct SshConnection {
81 process: Child,
83 channel: std::sync::Arc<AgentChannel>,
85 params: ConnectionParams,
87}
88
89impl SshConnection {
90 pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
92 let mut cmd = Command::new("ssh");
93
94 cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
96 if let Some(port) = params.port {
100 cmd.arg("-p").arg(port.to_string());
101 }
102
103 if let Some(ref identity) = params.identity_file {
104 cmd.arg("-i").arg(identity);
105 }
106
107 cmd.arg(format!("{}@{}", params.user, params.host));
108
109 let agent_len = AGENT_SOURCE.len();
118 let bootstrap = format!(
119 "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
120 agent_len
121 );
122 cmd.arg(bootstrap);
123
124 cmd.stdin(Stdio::piped());
125 cmd.stdout(Stdio::piped());
126 cmd.stderr(Stdio::inherit());
128 cmd.hide_window();
129
130 let mut child = cmd.spawn()?;
131
132 let mut stdin = child
134 .stdin
135 .take()
136 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
137 let stdout = child
138 .stdout
139 .take()
140 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
141 stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
145 stdin.flush().await?;
146
147 let mut reader = BufReader::new(stdout);
149
150 let mut ready_line = String::new();
154 match reader.read_line(&mut ready_line).await {
155 Ok(0) => {
156 return Err(ssh_eof_error(&mut child, ¶ms).await);
157 }
158 Ok(_) => {}
159 Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
160 }
161
162 let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
163 SshError::AgentStartFailed(format!(
164 "invalid ready message '{}': {}",
165 ready_line.trim(),
166 e
167 ))
168 })?;
169
170 if !ready.is_ready() {
171 return Err(SshError::AgentStartFailed(
172 "agent did not send ready message".to_string(),
173 ));
174 }
175
176 let version = ready.version.unwrap_or(0);
178 if version != crate::services::remote::protocol::PROTOCOL_VERSION {
179 return Err(SshError::VersionMismatch {
180 expected: crate::services::remote::protocol::PROTOCOL_VERSION,
181 got: version,
182 });
183 }
184
185 let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
187
188 Ok(Self {
189 process: child,
190 channel,
191 params,
192 })
193 }
194
195 pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
197 self.channel.clone()
198 }
199
200 pub fn params(&self) -> &ConnectionParams {
202 &self.params
203 }
204
205 pub fn is_connected(&self) -> bool {
207 self.channel.is_connected()
208 }
209
210 pub fn connection_string(&self) -> String {
212 self.params.to_string()
213 }
214}
215
216impl Drop for SshConnection {
217 fn drop(&mut self) {
218 if let Ok(()) = self.process.start_kill() {}
223 }
224}
225
/// Default delay between connection health checks and between reconnection
/// attempts (5 seconds).
const DEFAULT_RECONNECT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);

/// Tuning for the background reconnect task.
///
/// Derives `Debug` and `Clone` so it can be logged and reused across several
/// reconnect tasks.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// How long to sleep between connection health checks and before each
    /// reconnection attempt.
    pub interval: std::time::Duration,
}

impl Default for ReconnectConfig {
    /// Uses a 5-second interval.
    fn default() -> Self {
        Self {
            interval: DEFAULT_RECONNECT_INTERVAL,
        }
    }
}
242
243pub fn spawn_reconnect_task(
253 channel: std::sync::Arc<AgentChannel>,
254 params: ConnectionParams,
255) -> tokio::task::JoinHandle<()> {
256 let connect_fn = move || {
257 let params = params.clone();
258 async move {
259 let (reader, writer, _child) = establish_ssh_transport(¶ms).await?;
260 let reader: Box<dyn tokio::io::AsyncBufRead + Unpin + Send> = Box::new(reader);
262 let writer: Box<dyn tokio::io::AsyncWrite + Unpin + Send> = Box::new(writer);
263 Ok::<_, SshError>((reader, writer))
264 }
265 };
266
267 spawn_reconnect_task_with(
268 channel,
269 connect_fn,
270 ReconnectConfig::default(),
271 "SSH remote",
272 )
273}
274
275pub fn spawn_reconnect_task_with<F, Fut>(
282 channel: std::sync::Arc<AgentChannel>,
283 connect_fn: F,
284 config: ReconnectConfig,
285 label: &'static str,
286) -> tokio::task::JoinHandle<()>
287where
288 F: Fn() -> Fut + Send + 'static,
289 Fut: std::future::Future<
290 Output = Result<
291 (
292 Box<dyn tokio::io::AsyncBufRead + Unpin + Send>,
293 Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
294 ),
295 SshError,
296 >,
297 > + Send,
298{
299 tokio::spawn(async move {
300 loop {
301 while channel.is_connected() {
303 tokio::time::sleep(config.interval).await;
304 }
305
306 tracing::info!("{label}: connection lost, attempting reconnection...");
307
308 loop {
310 tokio::time::sleep(config.interval).await;
311
312 if !channel.is_connected() {
314 } else {
316 break;
318 }
319
320 match (connect_fn)().await {
321 Ok((reader, writer)) => {
322 tracing::info!("{label}: reconnected successfully");
323 channel.replace_transport(reader, writer).await;
324 break;
325 }
326 Err(e) => {
327 tracing::debug!("{label}: reconnection attempt failed: {e}");
328 }
329 }
330 }
331 }
332 })
333}
334
335async fn ssh_eof_error(child: &mut Child, params: &ConnectionParams) -> SshError {
342 let status = tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
344
345 let hint = match status {
346 Ok(Ok(status)) => {
347 match status.code() {
348 Some(255) => format!(
352 "SSH could not connect to {}. Check that the host is \
353 reachable, the hostname is correct, and your SSH \
354 credentials are valid (exit code 255)",
355 params
356 ),
357 Some(127) => format!(
358 "python3 was not found on the remote host {}. \
359 Ensure Python 3 is installed on the remote machine",
360 params
361 ),
362 Some(code) => format!(
363 "SSH process exited with code {} while connecting to {}",
364 code, params
365 ),
366 None => format!(
367 "SSH process was killed by a signal while connecting to {}",
368 params
369 ),
370 }
371 }
372 Ok(Err(e)) => format!("failed to get SSH exit status: {}", e),
373 Err(_) => {
374 if let Err(e) = child.start_kill() {
376 tracing::warn!("Failed to kill timed-out SSH process: {}", e);
377 }
378 format!(
379 "SSH process did not exit in time while connecting to {}",
380 params
381 )
382 }
383 };
384
385 SshError::AgentStartFailed(hint)
386}
387
388async fn establish_ssh_transport(
392 params: &ConnectionParams,
393) -> Result<
394 (
395 BufReader<tokio::process::ChildStdout>,
396 tokio::process::ChildStdin,
397 Child,
398 ),
399 SshError,
400> {
401 let mut cmd = Command::new("ssh");
402
403 cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
404 cmd.arg("-o").arg("BatchMode=yes");
406
407 if let Some(port) = params.port {
408 cmd.arg("-p").arg(port.to_string());
409 }
410
411 if let Some(ref identity) = params.identity_file {
412 cmd.arg("-i").arg(identity);
413 }
414
415 cmd.arg(format!("{}@{}", params.user, params.host));
416
417 let agent_len = AGENT_SOURCE.len();
418 let bootstrap = format!(
419 "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
420 agent_len
421 );
422 cmd.arg(bootstrap);
423
424 cmd.stdin(Stdio::piped());
425 cmd.stdout(Stdio::piped());
426 cmd.stderr(Stdio::null()); cmd.hide_window();
428
429 let mut child = cmd.spawn()?;
430
431 let mut stdin = child
432 .stdin
433 .take()
434 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
435 let stdout = child
436 .stdout
437 .take()
438 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
439
440 stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
442 stdin.flush().await?;
443
444 let mut reader = BufReader::new(stdout);
445
446 let mut ready_line = String::new();
448 match reader.read_line(&mut ready_line).await {
449 Ok(0) => {
450 return Err(ssh_eof_error(&mut child, params).await);
451 }
452 Ok(_) => {}
453 Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
454 }
455
456 let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
457 SshError::AgentStartFailed(format!(
458 "invalid ready message '{}': {}",
459 ready_line.trim(),
460 e
461 ))
462 })?;
463
464 if !ready.is_ready() {
465 return Err(SshError::AgentStartFailed(
466 "agent did not send ready message".to_string(),
467 ));
468 }
469
470 let version = ready.version.unwrap_or(0);
471 if version != crate::services::remote::protocol::PROTOCOL_VERSION {
472 return Err(SshError::VersionMismatch {
473 expected: crate::services::remote::protocol::PROTOCOL_VERSION,
474 got: version,
475 });
476 }
477
478 Ok((reader, stdin, child))
479}
480
481#[doc(hidden)]
486pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
487 use tokio::process::Command as TokioCommand;
488
489 let mut child = TokioCommand::new("python3")
490 .arg("-u")
491 .arg("-c")
492 .arg(AGENT_SOURCE)
493 .stdin(Stdio::piped())
494 .stdout(Stdio::piped())
495 .stderr(Stdio::piped())
496 .hide_window()
497 .spawn()?;
498
499 let stdin = child
500 .stdin
501 .take()
502 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
503 let stdout = child
504 .stdout
505 .take()
506 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
507
508 let mut reader = BufReader::new(stdout);
509
510 let mut ready_line = String::new();
512 reader.read_line(&mut ready_line).await?;
513
514 let ready: AgentResponse = serde_json::from_str(&ready_line)
515 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
516
517 if !ready.is_ready() {
518 return Err(SshError::AgentStartFailed(
519 "agent did not send ready message".to_string(),
520 ));
521 }
522
523 Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
524}
525
526#[doc(hidden)]
531pub async fn spawn_local_agent_with_capacity(
532 data_channel_capacity: usize,
533) -> Result<std::sync::Arc<AgentChannel>, SshError> {
534 use tokio::process::Command as TokioCommand;
535
536 let mut child = TokioCommand::new("python3")
537 .arg("-u")
538 .arg("-c")
539 .arg(AGENT_SOURCE)
540 .stdin(Stdio::piped())
541 .stdout(Stdio::piped())
542 .stderr(Stdio::piped())
543 .hide_window()
544 .spawn()?;
545
546 let stdin = child
547 .stdin
548 .take()
549 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
550 let stdout = child
551 .stdout
552 .take()
553 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
554
555 let mut reader = BufReader::new(stdout);
556
557 let mut ready_line = String::new();
559 reader.read_line(&mut ready_line).await?;
560
561 let ready: AgentResponse = serde_json::from_str(&ready_line)
562 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
563
564 if !ready.is_ready() {
565 return Err(SshError::AgentStartFailed(
566 "agent did not send ready message".to_string(),
567 ));
568 }
569
570 Ok(std::sync::Arc::new(AgentChannel::with_capacity(
571 reader,
572 stdin,
573 data_channel_capacity,
574 )))
575}
576
577#[doc(hidden)]
583pub async fn spawn_local_agent_transport() -> Result<
584 (
585 tokio::io::BufReader<tokio::process::ChildStdout>,
586 tokio::process::ChildStdin,
587 ),
588 SshError,
589> {
590 use tokio::process::Command as TokioCommand;
591
592 let mut child = TokioCommand::new("python3")
593 .arg("-u")
594 .arg("-c")
595 .arg(AGENT_SOURCE)
596 .stdin(Stdio::piped())
597 .stdout(Stdio::piped())
598 .stderr(Stdio::piped())
599 .hide_window()
600 .spawn()?;
601
602 let stdin = child
603 .stdin
604 .take()
605 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
606 let stdout = child
607 .stdout
608 .take()
609 .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
610
611 let mut reader = BufReader::new(stdout);
612
613 let mut ready_line = String::new();
615 reader.read_line(&mut ready_line).await?;
616
617 let ready: AgentResponse = serde_json::from_str(&ready_line)
618 .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
619
620 if !ready.is_ready() {
621 return Err(SshError::AgentStartFailed(
622 "agent did not send ready message".to_string(),
623 ));
624 }
625
626 Ok((reader, stdin))
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// `parse` accepts `user@host[:port]` and rejects inputs missing either
    /// side of the `@`.
    #[test]
    fn test_parse_connection_params() {
        let accepted: [(&str, Option<u16>); 2] = [("user@host", None), ("user@host:22", Some(22))];
        for (input, expected_port) in accepted {
            let parsed = ConnectionParams::parse(input)
                .unwrap_or_else(|| panic!("{input:?} should parse"));
            assert_eq!(parsed.user, "user");
            assert_eq!(parsed.host, "host");
            assert_eq!(parsed.port, expected_port);
        }

        for rejected in ["hostonly", "@host", "user@"] {
            assert!(
                ConnectionParams::parse(rejected).is_none(),
                "{rejected:?} should be rejected"
            );
        }
    }

    /// `Display` renders `user@host`, adding `:port` only when one is set.
    #[test]
    fn test_connection_string() {
        let make = |user: &str, host: &str, port| ConnectionParams {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        };

        assert_eq!(make("alice", "example.com", None).to_string(), "alice@example.com");
        assert_eq!(
            make("bob", "server.local", Some(2222)).to_string(),
            "bob@server.local:2222"
        );
    }
}