Skip to main content

fresh/services/remote/
connection.rs

1//! SSH connection management
2//!
3//! Handles spawning SSH process and bootstrapping the Python agent.
4
5use crate::services::process_hidden::HideWindow;
6use crate::services::remote::channel::AgentChannel;
7use crate::services::remote::protocol::AgentResponse;
8use crate::services::remote::AGENT_SOURCE;
9use std::path::PathBuf;
10use std::process::Stdio;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::process::{Child, Command};
13
14/// Error type for SSH connection
15#[derive(Debug, thiserror::Error)]
16pub enum SshError {
17    #[error("Failed to spawn SSH process ({0}). Is the `ssh` command installed and in your PATH?")]
18    SpawnFailed(#[from] std::io::Error),
19
20    #[error("Agent failed to start: {0}")]
21    AgentStartFailed(String),
22
23    #[error("Protocol version mismatch: expected {expected}, got {got}")]
24    VersionMismatch { expected: u32, got: u32 },
25
26    #[error("Connection closed")]
27    ConnectionClosed,
28
29    #[error("Authentication failed")]
30    AuthenticationFailed,
31}
32
/// SSH connection parameters.
///
/// IPv6 literals are kept in `host` *without* surrounding brackets so the value
/// can be used directly as the host part of an `ssh user@host` argument.
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Remote login name.
    pub user: String,
    /// Remote host name or address (IPv6 literals are stored unbracketed).
    pub host: String,
    /// TCP port to connect to; `None` leaves the choice to `ssh`.
    pub port: Option<u16>,
    /// Private key file handed to `ssh -i`, if any.
    pub identity_file: Option<PathBuf>,
}

impl ConnectionParams {
    /// Parse a connection string.
    ///
    /// Accepted forms:
    /// - `user@host`
    /// - `user@host:port`
    /// - `user@[ipv6]` and `user@[ipv6]:port` (brackets are stripped)
    /// - `user@ipv6` — a bare IPv6 literal; it can never carry a port because
    ///   the trailing group would be indistinguishable from one
    ///
    /// Returns `None` when there is no `@`, or when the user or host part is
    /// empty. A trailing `:text` that is not a valid `u16` is left as part of
    /// the host (and `port` stays `None`), as before.
    pub fn parse(s: &str) -> Option<Self> {
        // Split off the user first so that colons are only ever interpreted
        // inside the host part.
        let (user, target) = s.split_once('@')?;
        let (host, port) = Self::split_host_port(target);

        if user.is_empty() || host.is_empty() {
            return None;
        }

        Some(Self {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        })
    }

    /// Split the text after `@` into a host and an optional port.
    fn split_host_port(target: &str) -> (&str, Option<u16>) {
        // Bracketed IPv6 literal: "[addr]" or "[addr]:port".
        if let Some(rest) = target.strip_prefix('[') {
            if let Some((addr, tail)) = rest.split_once(']') {
                if tail.is_empty() {
                    return (addr, None);
                }
                let port = tail.strip_prefix(':').and_then(|p| p.parse::<u16>().ok());
                if port.is_some() {
                    return (addr, port);
                }
            }
            // Malformed bracket syntax: keep the text untouched as the host.
            return (target, None);
        }

        // Two or more colons without brackets can only be a bare IPv6 literal;
        // treating the last group as a port would corrupt the address.
        if target.matches(':').count() > 1 {
            return (target, None);
        }

        match target.rsplit_once(':') {
            Some((host, port_text)) => match port_text.parse::<u16>() {
                Ok(port) => (host, Some(port)),
                Err(_) => (target, None),
            },
            None => (target, None),
        }
    }
}

impl std::fmt::Display for ConnectionParams {
    /// Format as `user@host` or `user@host:port`.
    ///
    /// When a port is present and the host contains a colon (IPv6 literal),
    /// the host is wrapped in brackets so the output can be parsed back by
    /// [`ConnectionParams::parse`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let needs_brackets = self.host.contains(':') && !self.host.starts_with('[');
        match (self.port, needs_brackets) {
            (Some(port), true) => write!(f, "{}@[{}]:{}", self.user, self.host, port),
            (Some(port), false) => write!(f, "{}@{}:{}", self.user, self.host, port),
            (None, _) => write!(f, "{}@{}", self.user, self.host),
        }
    }
}
78
79/// Active SSH connection with bootstrapped agent
80pub struct SshConnection {
81    /// SSH child process
82    process: Child,
83    /// Communication channel with agent (wrapped in Arc for sharing)
84    channel: std::sync::Arc<AgentChannel>,
85    /// Connection parameters
86    params: ConnectionParams,
87}
88
89impl SshConnection {
90    /// Establish a new SSH connection and bootstrap the agent
91    pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
92        let mut cmd = Command::new("ssh");
93
94        // Don't check host key strictly for ease of use
95        cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
96        // Allow password prompts - SSH will use the terminal for this
97        // Note: We inherit stderr so SSH can prompt for password if needed
98
99        if let Some(port) = params.port {
100            cmd.arg("-p").arg(port.to_string());
101        }
102
103        if let Some(ref identity) = params.identity_file {
104            cmd.arg("-i").arg(identity);
105        }
106
107        cmd.arg(format!("{}@{}", params.user, params.host));
108
109        // Bootstrap the agent using Python itself to read the exact byte count.
110        // This avoids requiring bash or other shell utilities on the remote.
111        // Python reads exactly N bytes (the agent code), execs it, and the agent
112        // then continues reading from stdin for protocol messages.
113        //
114        // Note: SSH passes the remote command through a shell, so we need to
115        // properly quote the Python code. We use double quotes for the outer
116        // shell and avoid problematic characters in the Python code.
117        let agent_len = AGENT_SOURCE.len();
118        let bootstrap = format!(
119            "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
120            agent_len
121        );
122        cmd.arg(bootstrap);
123
124        cmd.stdin(Stdio::piped());
125        cmd.stdout(Stdio::piped());
126        // Inherit stderr so SSH can prompt for password on the terminal
127        cmd.stderr(Stdio::inherit());
128        cmd.hide_window();
129
130        let mut child = cmd.spawn()?;
131
132        // Get handles
133        let mut stdin = child
134            .stdin
135            .take()
136            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
137        let stdout = child
138            .stdout
139            .take()
140            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
141        // Note: stderr is inherited so SSH can prompt for password on the terminal
142
143        // Send the agent code (exact byte count)
144        stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
145        stdin.flush().await?;
146
147        // Create buffered reader for stdout
148        let mut reader = BufReader::new(stdout);
149
150        // Wait for ready message from agent
151        // No timeout needed - all failure modes (auth failure, network issues, etc.)
152        // result in SSH exiting and us getting EOF. User can Ctrl+C if needed.
153        let mut ready_line = String::new();
154        match reader.read_line(&mut ready_line).await {
155            Ok(0) => {
156                return Err(ssh_eof_error(&mut child, &params).await);
157            }
158            Ok(_) => {}
159            Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
160        }
161
162        let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
163            SshError::AgentStartFailed(format!(
164                "invalid ready message '{}': {}",
165                ready_line.trim(),
166                e
167            ))
168        })?;
169
170        if !ready.is_ready() {
171            return Err(SshError::AgentStartFailed(
172                "agent did not send ready message".to_string(),
173            ));
174        }
175
176        // Check protocol version
177        let version = ready.version.unwrap_or(0);
178        if version != crate::services::remote::protocol::PROTOCOL_VERSION {
179            return Err(SshError::VersionMismatch {
180                expected: crate::services::remote::protocol::PROTOCOL_VERSION,
181                got: version,
182            });
183        }
184
185        // Create channel (takes ownership of stdin for writing)
186        let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
187
188        Ok(Self {
189            process: child,
190            channel,
191            params,
192        })
193    }
194
195    /// Get the communication channel as an Arc for sharing
196    pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
197        self.channel.clone()
198    }
199
200    /// Get connection parameters
201    pub fn params(&self) -> &ConnectionParams {
202        &self.params
203    }
204
205    /// Check if the connection is still alive
206    pub fn is_connected(&self) -> bool {
207        self.channel.is_connected()
208    }
209
210    /// Get the connection string for display
211    pub fn connection_string(&self) -> String {
212        self.params.to_string()
213    }
214}
215
216impl Drop for SshConnection {
217    fn drop(&mut self) {
218        // Best-effort kill of the SSH process during cleanup.
219        // If it fails (process already exited, permission error, etc.)
220        // there's nothing we can do in a Drop impl — the OS will clean
221        // up the zombie when our process exits.
222        if let Ok(()) = self.process.start_kill() {}
223    }
224}
225
/// Default interval between reconnection attempts.
const DEFAULT_RECONNECT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);

/// Configuration for the reconnect task.
///
/// Derives the usual value-type traits so configurations can be logged,
/// duplicated and compared (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReconnectConfig {
    /// How long to wait between reconnection attempts. The reconnect loop also
    /// uses it as the polling period while the channel is still connected.
    pub interval: std::time::Duration,
}

impl Default for ReconnectConfig {
    /// Use [`DEFAULT_RECONNECT_INTERVAL`] (five seconds).
    fn default() -> Self {
        Self {
            interval: DEFAULT_RECONNECT_INTERVAL,
        }
    }
}
242
243/// Spawn a background task that automatically reconnects when the channel
244/// disconnects.
245///
246/// The task monitors `channel.is_connected()` and, when false, attempts to
247/// establish a new SSH connection using the given `params`. On success, it
248/// calls `channel.replace_transport()` to hot-swap the underlying reader/writer.
249///
250/// The task runs until the channel is dropped (write_tx closed) or the
251/// returned `tokio::task::JoinHandle` is aborted.
252pub fn spawn_reconnect_task(
253    channel: std::sync::Arc<AgentChannel>,
254    params: ConnectionParams,
255) -> tokio::task::JoinHandle<()> {
256    let connect_fn = move || {
257        let params = params.clone();
258        async move {
259            let (reader, writer, _child) = establish_ssh_transport(&params).await?;
260            // Box the reader/writer so they have a uniform type
261            let reader: Box<dyn tokio::io::AsyncBufRead + Unpin + Send> = Box::new(reader);
262            let writer: Box<dyn tokio::io::AsyncWrite + Unpin + Send> = Box::new(writer);
263            Ok::<_, SshError>((reader, writer))
264        }
265    };
266
267    spawn_reconnect_task_with(
268        channel,
269        connect_fn,
270        ReconnectConfig::default(),
271        "SSH remote",
272    )
273}
274
275/// Spawn a reconnect task with a custom connection factory.
276///
277/// This is the generic version used by both production (via `spawn_reconnect_task`)
278/// and tests (with a fake connection factory). The `connect_fn` is called each
279/// time a reconnection attempt is made. It should return a `(reader, writer)` pair
280/// on success.
281pub fn spawn_reconnect_task_with<F, Fut>(
282    channel: std::sync::Arc<AgentChannel>,
283    connect_fn: F,
284    config: ReconnectConfig,
285    label: &'static str,
286) -> tokio::task::JoinHandle<()>
287where
288    F: Fn() -> Fut + Send + 'static,
289    Fut: std::future::Future<
290            Output = Result<
291                (
292                    Box<dyn tokio::io::AsyncBufRead + Unpin + Send>,
293                    Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
294                ),
295                SshError,
296            >,
297        > + Send,
298{
299    tokio::spawn(async move {
300        loop {
301            // Wait until disconnected
302            while channel.is_connected() {
303                tokio::time::sleep(config.interval).await;
304            }
305
306            tracing::info!("{label}: connection lost, attempting reconnection...");
307
308            // Retry loop
309            loop {
310                tokio::time::sleep(config.interval).await;
311
312                // Check if channel was dropped (write_tx gone)
313                if !channel.is_connected() {
314                    // Still disconnected — try to reconnect
315                } else {
316                    // Something else reconnected us (e.g., manual replace_transport)
317                    break;
318                }
319
320                match (connect_fn)().await {
321                    Ok((reader, writer)) => {
322                        tracing::info!("{label}: reconnected successfully");
323                        channel.replace_transport(reader, writer).await;
324                        break;
325                    }
326                    Err(e) => {
327                        tracing::debug!("{label}: reconnection attempt failed: {e}");
328                    }
329                }
330            }
331        }
332    })
333}
334
335/// Establish a new SSH connection and return the raw transport + child process.
336///
337/// Build a descriptive error when the SSH process closes stdout (EOF) without
338/// sending a ready message. We wait for the SSH process to exit and inspect its
339/// exit code to give the user a more actionable message than a generic
340/// "connection closed".
341async fn ssh_eof_error(child: &mut Child, params: &ConnectionParams) -> SshError {
342    // Give SSH a moment to finish so we can read its exit code.
343    let status = tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
344
345    let hint = match status {
346        Ok(Ok(status)) => {
347            match status.code() {
348                // 255 is SSH's conventional exit code for connection errors
349                // (host unreachable, connection refused, DNS failure, auth
350                // failure, etc.).
351                Some(255) => format!(
352                    "SSH could not connect to {}. Check that the host is \
353                     reachable, the hostname is correct, and your SSH \
354                     credentials are valid (exit code 255)",
355                    params
356                ),
357                Some(127) => format!(
358                    "python3 was not found on the remote host {}. \
359                     Ensure Python 3 is installed on the remote machine",
360                    params
361                ),
362                Some(code) => format!(
363                    "SSH process exited with code {} while connecting to {}",
364                    code, params
365                ),
366                None => format!(
367                    "SSH process was killed by a signal while connecting to {}",
368                    params
369                ),
370            }
371        }
372        Ok(Err(e)) => format!("failed to get SSH exit status: {}", e),
373        Err(_) => {
374            // Timed out waiting for exit — kill it so we don't leak.
375            if let Err(e) = child.start_kill() {
376                tracing::warn!("Failed to kill timed-out SSH process: {}", e);
377            }
378            format!(
379                "SSH process did not exit in time while connecting to {}",
380                params
381            )
382        }
383    };
384
385    SshError::AgentStartFailed(hint)
386}
387
388/// This is the lower-level function used by both `SshConnection::connect` and
389/// the reconnect task. It spawns an SSH process, bootstraps the Python agent,
390/// and returns the reader/writer pair ready for use with `AgentChannel`.
391async fn establish_ssh_transport(
392    params: &ConnectionParams,
393) -> Result<
394    (
395        BufReader<tokio::process::ChildStdout>,
396        tokio::process::ChildStdin,
397        Child,
398    ),
399    SshError,
400> {
401    let mut cmd = Command::new("ssh");
402
403    cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
404    // Disable password prompts for reconnection (non-interactive)
405    cmd.arg("-o").arg("BatchMode=yes");
406
407    if let Some(port) = params.port {
408        cmd.arg("-p").arg(port.to_string());
409    }
410
411    if let Some(ref identity) = params.identity_file {
412        cmd.arg("-i").arg(identity);
413    }
414
415    cmd.arg(format!("{}@{}", params.user, params.host));
416
417    let agent_len = AGENT_SOURCE.len();
418    let bootstrap = format!(
419        "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
420        agent_len
421    );
422    cmd.arg(bootstrap);
423
424    cmd.stdin(Stdio::piped());
425    cmd.stdout(Stdio::piped());
426    cmd.stderr(Stdio::null()); // No terminal for reconnection
427    cmd.hide_window();
428
429    let mut child = cmd.spawn()?;
430
431    let mut stdin = child
432        .stdin
433        .take()
434        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
435    let stdout = child
436        .stdout
437        .take()
438        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
439
440    // Send the agent code
441    stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
442    stdin.flush().await?;
443
444    let mut reader = BufReader::new(stdout);
445
446    // Wait for ready message
447    let mut ready_line = String::new();
448    match reader.read_line(&mut ready_line).await {
449        Ok(0) => {
450            return Err(ssh_eof_error(&mut child, params).await);
451        }
452        Ok(_) => {}
453        Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
454    }
455
456    let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
457        SshError::AgentStartFailed(format!(
458            "invalid ready message '{}': {}",
459            ready_line.trim(),
460            e
461        ))
462    })?;
463
464    if !ready.is_ready() {
465        return Err(SshError::AgentStartFailed(
466            "agent did not send ready message".to_string(),
467        ));
468    }
469
470    let version = ready.version.unwrap_or(0);
471    if version != crate::services::remote::protocol::PROTOCOL_VERSION {
472        return Err(SshError::VersionMismatch {
473            expected: crate::services::remote::protocol::PROTOCOL_VERSION,
474            got: version,
475        });
476    }
477
478    Ok((reader, stdin, child))
479}
480
481/// Spawn a local agent process for testing (no SSH)
482///
483/// This is used by integration tests to test the full stack without SSH.
484/// Not intended for production use.
485#[doc(hidden)]
486pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
487    use tokio::process::Command as TokioCommand;
488
489    let mut child = TokioCommand::new("python3")
490        .arg("-u")
491        .arg("-c")
492        .arg(AGENT_SOURCE)
493        .stdin(Stdio::piped())
494        .stdout(Stdio::piped())
495        .stderr(Stdio::piped())
496        .hide_window()
497        .spawn()?;
498
499    let stdin = child
500        .stdin
501        .take()
502        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
503    let stdout = child
504        .stdout
505        .take()
506        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
507
508    let mut reader = BufReader::new(stdout);
509
510    // Wait for ready message
511    let mut ready_line = String::new();
512    reader.read_line(&mut ready_line).await?;
513
514    let ready: AgentResponse = serde_json::from_str(&ready_line)
515        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
516
517    if !ready.is_ready() {
518        return Err(SshError::AgentStartFailed(
519            "agent did not send ready message".to_string(),
520        ));
521    }
522
523    Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
524}
525
526/// Spawn a local Python agent with a custom data channel capacity.
527///
528/// Same as `spawn_local_agent` but allows overriding the channel capacity
529/// for stress-testing backpressure handling.
530#[doc(hidden)]
531pub async fn spawn_local_agent_with_capacity(
532    data_channel_capacity: usize,
533) -> Result<std::sync::Arc<AgentChannel>, SshError> {
534    use tokio::process::Command as TokioCommand;
535
536    let mut child = TokioCommand::new("python3")
537        .arg("-u")
538        .arg("-c")
539        .arg(AGENT_SOURCE)
540        .stdin(Stdio::piped())
541        .stdout(Stdio::piped())
542        .stderr(Stdio::piped())
543        .hide_window()
544        .spawn()?;
545
546    let stdin = child
547        .stdin
548        .take()
549        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
550    let stdout = child
551        .stdout
552        .take()
553        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
554
555    let mut reader = BufReader::new(stdout);
556
557    // Wait for ready message
558    let mut ready_line = String::new();
559    reader.read_line(&mut ready_line).await?;
560
561    let ready: AgentResponse = serde_json::from_str(&ready_line)
562        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
563
564    if !ready.is_ready() {
565        return Err(SshError::AgentStartFailed(
566            "agent did not send ready message".to_string(),
567        ));
568    }
569
570    Ok(std::sync::Arc::new(AgentChannel::with_capacity(
571        reader,
572        stdin,
573        data_channel_capacity,
574    )))
575}
576
577/// Spawn a local Python agent and return the raw reader/writer transport.
578///
579/// Unlike `spawn_local_agent`, this does NOT create an `AgentChannel`. It
580/// returns the ready-to-use reader and writer so callers can feed them to
581/// `AgentChannel::replace_transport()` for reconnection testing.
582#[doc(hidden)]
583pub async fn spawn_local_agent_transport() -> Result<
584    (
585        tokio::io::BufReader<tokio::process::ChildStdout>,
586        tokio::process::ChildStdin,
587    ),
588    SshError,
589> {
590    use tokio::process::Command as TokioCommand;
591
592    let mut child = TokioCommand::new("python3")
593        .arg("-u")
594        .arg("-c")
595        .arg(AGENT_SOURCE)
596        .stdin(Stdio::piped())
597        .stdout(Stdio::piped())
598        .stderr(Stdio::piped())
599        .hide_window()
600        .spawn()?;
601
602    let stdin = child
603        .stdin
604        .take()
605        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
606    let stdout = child
607        .stdout
608        .take()
609        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
610
611    let mut reader = BufReader::new(stdout);
612
613    // Wait for ready message
614    let mut ready_line = String::new();
615    reader.read_line(&mut ready_line).await?;
616
617    let ready: AgentResponse = serde_json::from_str(&ready_line)
618        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
619
620    if !ready.is_ready() {
621        return Err(SshError::AgentStartFailed(
622            "agent did not send ready message".to_string(),
623        ));
624    }
625
626    Ok((reader, stdin))
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Build parameters without an identity file.
    fn make_params(user: &str, host: &str, port: Option<u16>) -> ConnectionParams {
        ConnectionParams {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        }
    }

    /// Well-formed strings yield user/host/port; malformed ones are rejected.
    #[test]
    fn test_parse_connection_params() {
        let accepted: [(&str, Option<u16>); 2] = [("user@host", None), ("user@host:22", Some(22))];
        for &(input, expected_port) in accepted.iter() {
            let parsed = ConnectionParams::parse(input).unwrap();
            assert_eq!(parsed.user, "user");
            assert_eq!(parsed.host, "host");
            assert_eq!(parsed.port, expected_port);
        }

        let rejected = ["hostonly", "@host", "user@"];
        for input in rejected.iter() {
            assert!(ConnectionParams::parse(input).is_none());
        }
    }

    /// `Display` renders `user@host`, appending `:port` only when one is set.
    #[test]
    fn test_connection_string() {
        let without_port = make_params("alice", "example.com", None);
        assert_eq!(without_port.to_string(), "alice@example.com");

        let with_port = make_params("bob", "server.local", Some(2222));
        assert_eq!(with_port.to_string(), "bob@server.local:2222");
    }
}