Skip to main content

fresh/services/remote/
connection.rs

1//! SSH connection management
2//!
3//! Handles spawning SSH process and bootstrapping the Python agent.
4
5use crate::services::remote::channel::AgentChannel;
6use crate::services::remote::protocol::AgentResponse;
7use crate::services::remote::AGENT_SOURCE;
8use std::path::PathBuf;
9use std::process::Stdio;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12
13/// Error type for SSH connection
14#[derive(Debug, thiserror::Error)]
15pub enum SshError {
16    #[error("Failed to spawn SSH process ({0}). Is the `ssh` command installed and in your PATH?")]
17    SpawnFailed(#[from] std::io::Error),
18
19    #[error("Agent failed to start: {0}")]
20    AgentStartFailed(String),
21
22    #[error("Protocol version mismatch: expected {expected}, got {got}")]
23    VersionMismatch { expected: u32, got: u32 },
24
25    #[error("Connection closed")]
26    ConnectionClosed,
27
28    #[error("Authentication failed")]
29    AuthenticationFailed,
30}
31
/// SSH connection parameters.
///
/// `user` and `host` are joined as `user@host` and handed to `ssh` as a single
/// command-line argument; `port` and `identity_file` become `-p` / `-i` when set.
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Remote login name.
    pub user: String,
    /// Remote host name or address.
    pub host: String,
    /// TCP port; `None` leaves the choice to `ssh`.
    pub port: Option<u16>,
    /// Private key file passed via `-i`; `None` leaves the choice to `ssh`.
    pub identity_file: Option<PathBuf>,
}

impl ConnectionParams {
    /// Parse a connection string like "user@host" or "user@host:port".
    ///
    /// Returns `None` when the string has no `@`, when the user or host part
    /// is empty, or when the user part starts with `-`.
    ///
    /// A trailing `:suffix` is only treated as a port when it parses as a
    /// `u16`; otherwise the whole string is kept and the suffix stays part of
    /// the host. `identity_file` is always `None` — callers set it afterwards.
    pub fn parse(s: &str) -> Option<Self> {
        let (destination, port) = match s.rsplit_once(':') {
            Some((head, tail)) => match tail.parse::<u16>() {
                Ok(port) => (head, Some(port)),
                Err(_) => (s, None),
            },
            None => (s, None),
        };

        // Split on the LAST '@', like OpenSSH does, so the stored fields agree
        // with how `ssh` will interpret the `user@host` argument we build.
        let (user, host) = destination.rsplit_once('@')?;
        if user.is_empty() || host.is_empty() {
            return None;
        }

        // `user@host` is passed to `ssh` as one argv entry. A user beginning
        // with '-' would make that entry look like an option
        // (e.g. `-oProxyCommand=...@host`), so refuse it outright.
        if user.starts_with('-') {
            return None;
        }

        Some(Self {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        })
    }
}

/// Formats as `user@host`, followed by `:port` when a port is set — the same
/// shape that [`ConnectionParams::parse`] accepts.
impl std::fmt::Display for ConnectionParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}@{}", self.user, self.host)?;
        match self.port {
            Some(port) => write!(f, ":{}", port),
            None => Ok(()),
        }
    }
}
77
78/// Active SSH connection with bootstrapped agent
79pub struct SshConnection {
80    /// SSH child process
81    process: Child,
82    /// Communication channel with agent (wrapped in Arc for sharing)
83    channel: std::sync::Arc<AgentChannel>,
84    /// Connection parameters
85    params: ConnectionParams,
86}
87
88impl SshConnection {
89    /// Establish a new SSH connection and bootstrap the agent
90    pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
91        let mut cmd = Command::new("ssh");
92
93        // Don't check host key strictly for ease of use
94        cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
95        // Allow password prompts - SSH will use the terminal for this
96        // Note: We inherit stderr so SSH can prompt for password if needed
97
98        if let Some(port) = params.port {
99            cmd.arg("-p").arg(port.to_string());
100        }
101
102        if let Some(ref identity) = params.identity_file {
103            cmd.arg("-i").arg(identity);
104        }
105
106        cmd.arg(format!("{}@{}", params.user, params.host));
107
108        // Bootstrap the agent using Python itself to read the exact byte count.
109        // This avoids requiring bash or other shell utilities on the remote.
110        // Python reads exactly N bytes (the agent code), execs it, and the agent
111        // then continues reading from stdin for protocol messages.
112        //
113        // Note: SSH passes the remote command through a shell, so we need to
114        // properly quote the Python code. We use double quotes for the outer
115        // shell and avoid problematic characters in the Python code.
116        let agent_len = AGENT_SOURCE.len();
117        let bootstrap = format!(
118            "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
119            agent_len
120        );
121        cmd.arg(bootstrap);
122
123        cmd.stdin(Stdio::piped());
124        cmd.stdout(Stdio::piped());
125        // Inherit stderr so SSH can prompt for password on the terminal
126        cmd.stderr(Stdio::inherit());
127
128        let mut child = cmd.spawn()?;
129
130        // Get handles
131        let mut stdin = child
132            .stdin
133            .take()
134            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
135        let stdout = child
136            .stdout
137            .take()
138            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
139        // Note: stderr is inherited so SSH can prompt for password on the terminal
140
141        // Send the agent code (exact byte count)
142        stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
143        stdin.flush().await?;
144
145        // Create buffered reader for stdout
146        let mut reader = BufReader::new(stdout);
147
148        // Wait for ready message from agent
149        // No timeout needed - all failure modes (auth failure, network issues, etc.)
150        // result in SSH exiting and us getting EOF. User can Ctrl+C if needed.
151        let mut ready_line = String::new();
152        match reader.read_line(&mut ready_line).await {
153            Ok(0) => {
154                return Err(ssh_eof_error(&mut child, &params).await);
155            }
156            Ok(_) => {}
157            Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
158        }
159
160        let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
161            SshError::AgentStartFailed(format!(
162                "invalid ready message '{}': {}",
163                ready_line.trim(),
164                e
165            ))
166        })?;
167
168        if !ready.is_ready() {
169            return Err(SshError::AgentStartFailed(
170                "agent did not send ready message".to_string(),
171            ));
172        }
173
174        // Check protocol version
175        let version = ready.version.unwrap_or(0);
176        if version != crate::services::remote::protocol::PROTOCOL_VERSION {
177            return Err(SshError::VersionMismatch {
178                expected: crate::services::remote::protocol::PROTOCOL_VERSION,
179                got: version,
180            });
181        }
182
183        // Create channel (takes ownership of stdin for writing)
184        let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
185
186        Ok(Self {
187            process: child,
188            channel,
189            params,
190        })
191    }
192
193    /// Get the communication channel as an Arc for sharing
194    pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
195        self.channel.clone()
196    }
197
198    /// Get connection parameters
199    pub fn params(&self) -> &ConnectionParams {
200        &self.params
201    }
202
203    /// Check if the connection is still alive
204    pub fn is_connected(&self) -> bool {
205        self.channel.is_connected()
206    }
207
208    /// Get the connection string for display
209    pub fn connection_string(&self) -> String {
210        self.params.to_string()
211    }
212}
213
214impl Drop for SshConnection {
215    fn drop(&mut self) {
216        // Best-effort kill of the SSH process during cleanup.
217        // If it fails (process already exited, permission error, etc.)
218        // there's nothing we can do in a Drop impl — the OS will clean
219        // up the zombie when our process exits.
220        if let Ok(()) = self.process.start_kill() {}
221    }
222}
223
/// Pause between reconnection attempts when no other value is configured.
const DEFAULT_RECONNECT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);

/// Settings for the background reconnect task.
pub struct ReconnectConfig {
    /// Delay used both between connectivity polls and between reconnection
    /// attempts.
    pub interval: std::time::Duration,
}

impl Default for ReconnectConfig {
    /// Uses [`DEFAULT_RECONNECT_INTERVAL`] (5 seconds).
    fn default() -> Self {
        let interval = DEFAULT_RECONNECT_INTERVAL;
        Self { interval }
    }
}
240
241/// Spawn a background task that automatically reconnects when the channel
242/// disconnects.
243///
244/// The task monitors `channel.is_connected()` and, when false, attempts to
245/// establish a new SSH connection using the given `params`. On success, it
246/// calls `channel.replace_transport()` to hot-swap the underlying reader/writer.
247///
248/// The task runs until the channel is dropped (write_tx closed) or the
249/// returned `tokio::task::JoinHandle` is aborted.
250pub fn spawn_reconnect_task(
251    channel: std::sync::Arc<AgentChannel>,
252    params: ConnectionParams,
253) -> tokio::task::JoinHandle<()> {
254    let connect_fn = move || {
255        let params = params.clone();
256        async move {
257            let (reader, writer, _child) = establish_ssh_transport(&params).await?;
258            // Box the reader/writer so they have a uniform type
259            let reader: Box<dyn tokio::io::AsyncBufRead + Unpin + Send> = Box::new(reader);
260            let writer: Box<dyn tokio::io::AsyncWrite + Unpin + Send> = Box::new(writer);
261            Ok::<_, SshError>((reader, writer))
262        }
263    };
264
265    spawn_reconnect_task_with(
266        channel,
267        connect_fn,
268        ReconnectConfig::default(),
269        "SSH remote",
270    )
271}
272
273/// Spawn a reconnect task with a custom connection factory.
274///
275/// This is the generic version used by both production (via `spawn_reconnect_task`)
276/// and tests (with a fake connection factory). The `connect_fn` is called each
277/// time a reconnection attempt is made. It should return a `(reader, writer)` pair
278/// on success.
279pub fn spawn_reconnect_task_with<F, Fut>(
280    channel: std::sync::Arc<AgentChannel>,
281    connect_fn: F,
282    config: ReconnectConfig,
283    label: &'static str,
284) -> tokio::task::JoinHandle<()>
285where
286    F: Fn() -> Fut + Send + 'static,
287    Fut: std::future::Future<
288            Output = Result<
289                (
290                    Box<dyn tokio::io::AsyncBufRead + Unpin + Send>,
291                    Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
292                ),
293                SshError,
294            >,
295        > + Send,
296{
297    tokio::spawn(async move {
298        loop {
299            // Wait until disconnected
300            while channel.is_connected() {
301                tokio::time::sleep(config.interval).await;
302            }
303
304            tracing::info!("{label}: connection lost, attempting reconnection...");
305
306            // Retry loop
307            loop {
308                tokio::time::sleep(config.interval).await;
309
310                // Check if channel was dropped (write_tx gone)
311                if !channel.is_connected() {
312                    // Still disconnected — try to reconnect
313                } else {
314                    // Something else reconnected us (e.g., manual replace_transport)
315                    break;
316                }
317
318                match (connect_fn)().await {
319                    Ok((reader, writer)) => {
320                        tracing::info!("{label}: reconnected successfully");
321                        channel.replace_transport(reader, writer).await;
322                        break;
323                    }
324                    Err(e) => {
325                        tracing::debug!("{label}: reconnection attempt failed: {e}");
326                    }
327                }
328            }
329        }
330    })
331}
332
333/// Establish a new SSH connection and return the raw transport + child process.
334///
335/// Build a descriptive error when the SSH process closes stdout (EOF) without
336/// sending a ready message. We wait for the SSH process to exit and inspect its
337/// exit code to give the user a more actionable message than a generic
338/// "connection closed".
339async fn ssh_eof_error(child: &mut Child, params: &ConnectionParams) -> SshError {
340    // Give SSH a moment to finish so we can read its exit code.
341    let status = tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
342
343    let hint = match status {
344        Ok(Ok(status)) => {
345            match status.code() {
346                // 255 is SSH's conventional exit code for connection errors
347                // (host unreachable, connection refused, DNS failure, auth
348                // failure, etc.).
349                Some(255) => format!(
350                    "SSH could not connect to {}. Check that the host is \
351                     reachable, the hostname is correct, and your SSH \
352                     credentials are valid (exit code 255)",
353                    params
354                ),
355                Some(127) => format!(
356                    "python3 was not found on the remote host {}. \
357                     Ensure Python 3 is installed on the remote machine",
358                    params
359                ),
360                Some(code) => format!(
361                    "SSH process exited with code {} while connecting to {}",
362                    code, params
363                ),
364                None => format!(
365                    "SSH process was killed by a signal while connecting to {}",
366                    params
367                ),
368            }
369        }
370        Ok(Err(e)) => format!("failed to get SSH exit status: {}", e),
371        Err(_) => {
372            // Timed out waiting for exit — kill it so we don't leak.
373            if let Err(e) = child.start_kill() {
374                tracing::warn!("Failed to kill timed-out SSH process: {}", e);
375            }
376            format!(
377                "SSH process did not exit in time while connecting to {}",
378                params
379            )
380        }
381    };
382
383    SshError::AgentStartFailed(hint)
384}
385
386/// This is the lower-level function used by both `SshConnection::connect` and
387/// the reconnect task. It spawns an SSH process, bootstraps the Python agent,
388/// and returns the reader/writer pair ready for use with `AgentChannel`.
389async fn establish_ssh_transport(
390    params: &ConnectionParams,
391) -> Result<
392    (
393        BufReader<tokio::process::ChildStdout>,
394        tokio::process::ChildStdin,
395        Child,
396    ),
397    SshError,
398> {
399    let mut cmd = Command::new("ssh");
400
401    cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
402    // Disable password prompts for reconnection (non-interactive)
403    cmd.arg("-o").arg("BatchMode=yes");
404
405    if let Some(port) = params.port {
406        cmd.arg("-p").arg(port.to_string());
407    }
408
409    if let Some(ref identity) = params.identity_file {
410        cmd.arg("-i").arg(identity);
411    }
412
413    cmd.arg(format!("{}@{}", params.user, params.host));
414
415    let agent_len = AGENT_SOURCE.len();
416    let bootstrap = format!(
417        "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
418        agent_len
419    );
420    cmd.arg(bootstrap);
421
422    cmd.stdin(Stdio::piped());
423    cmd.stdout(Stdio::piped());
424    cmd.stderr(Stdio::null()); // No terminal for reconnection
425
426    let mut child = cmd.spawn()?;
427
428    let mut stdin = child
429        .stdin
430        .take()
431        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
432    let stdout = child
433        .stdout
434        .take()
435        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
436
437    // Send the agent code
438    stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
439    stdin.flush().await?;
440
441    let mut reader = BufReader::new(stdout);
442
443    // Wait for ready message
444    let mut ready_line = String::new();
445    match reader.read_line(&mut ready_line).await {
446        Ok(0) => {
447            return Err(ssh_eof_error(&mut child, params).await);
448        }
449        Ok(_) => {}
450        Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
451    }
452
453    let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
454        SshError::AgentStartFailed(format!(
455            "invalid ready message '{}': {}",
456            ready_line.trim(),
457            e
458        ))
459    })?;
460
461    if !ready.is_ready() {
462        return Err(SshError::AgentStartFailed(
463            "agent did not send ready message".to_string(),
464        ));
465    }
466
467    let version = ready.version.unwrap_or(0);
468    if version != crate::services::remote::protocol::PROTOCOL_VERSION {
469        return Err(SshError::VersionMismatch {
470            expected: crate::services::remote::protocol::PROTOCOL_VERSION,
471            got: version,
472        });
473    }
474
475    Ok((reader, stdin, child))
476}
477
478/// Spawn a local agent process for testing (no SSH)
479///
480/// This is used by integration tests to test the full stack without SSH.
481/// Not intended for production use.
482#[doc(hidden)]
483pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
484    use tokio::process::Command as TokioCommand;
485
486    let mut child = TokioCommand::new("python3")
487        .arg("-u")
488        .arg("-c")
489        .arg(AGENT_SOURCE)
490        .stdin(Stdio::piped())
491        .stdout(Stdio::piped())
492        .stderr(Stdio::piped())
493        .spawn()?;
494
495    let stdin = child
496        .stdin
497        .take()
498        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
499    let stdout = child
500        .stdout
501        .take()
502        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
503
504    let mut reader = BufReader::new(stdout);
505
506    // Wait for ready message
507    let mut ready_line = String::new();
508    reader.read_line(&mut ready_line).await?;
509
510    let ready: AgentResponse = serde_json::from_str(&ready_line)
511        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
512
513    if !ready.is_ready() {
514        return Err(SshError::AgentStartFailed(
515            "agent did not send ready message".to_string(),
516        ));
517    }
518
519    Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
520}
521
522/// Spawn a local Python agent with a custom data channel capacity.
523///
524/// Same as `spawn_local_agent` but allows overriding the channel capacity
525/// for stress-testing backpressure handling.
526#[doc(hidden)]
527pub async fn spawn_local_agent_with_capacity(
528    data_channel_capacity: usize,
529) -> Result<std::sync::Arc<AgentChannel>, SshError> {
530    use tokio::process::Command as TokioCommand;
531
532    let mut child = TokioCommand::new("python3")
533        .arg("-u")
534        .arg("-c")
535        .arg(AGENT_SOURCE)
536        .stdin(Stdio::piped())
537        .stdout(Stdio::piped())
538        .stderr(Stdio::piped())
539        .spawn()?;
540
541    let stdin = child
542        .stdin
543        .take()
544        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
545    let stdout = child
546        .stdout
547        .take()
548        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
549
550    let mut reader = BufReader::new(stdout);
551
552    // Wait for ready message
553    let mut ready_line = String::new();
554    reader.read_line(&mut ready_line).await?;
555
556    let ready: AgentResponse = serde_json::from_str(&ready_line)
557        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
558
559    if !ready.is_ready() {
560        return Err(SshError::AgentStartFailed(
561            "agent did not send ready message".to_string(),
562        ));
563    }
564
565    Ok(std::sync::Arc::new(AgentChannel::with_capacity(
566        reader,
567        stdin,
568        data_channel_capacity,
569    )))
570}
571
572/// Spawn a local Python agent and return the raw reader/writer transport.
573///
574/// Unlike `spawn_local_agent`, this does NOT create an `AgentChannel`. It
575/// returns the ready-to-use reader and writer so callers can feed them to
576/// `AgentChannel::replace_transport()` for reconnection testing.
577#[doc(hidden)]
578pub async fn spawn_local_agent_transport() -> Result<
579    (
580        tokio::io::BufReader<tokio::process::ChildStdout>,
581        tokio::process::ChildStdin,
582    ),
583    SshError,
584> {
585    use tokio::process::Command as TokioCommand;
586
587    let mut child = TokioCommand::new("python3")
588        .arg("-u")
589        .arg("-c")
590        .arg(AGENT_SOURCE)
591        .stdin(Stdio::piped())
592        .stdout(Stdio::piped())
593        .stderr(Stdio::piped())
594        .spawn()?;
595
596    let stdin = child
597        .stdin
598        .take()
599        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
600    let stdout = child
601        .stdout
602        .take()
603        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
604
605    let mut reader = BufReader::new(stdout);
606
607    // Wait for ready message
608    let mut ready_line = String::new();
609    reader.read_line(&mut ready_line).await?;
610
611    let ready: AgentResponse = serde_json::from_str(&ready_line)
612        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
613
614    if !ready.is_ready() {
615        return Err(SshError::AgentStartFailed(
616            "agent did not send ready message".to_string(),
617        ));
618    }
619
620    Ok((reader, stdin))
621}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// `parse` accepts `user@host[:port]` and rejects strings that lack a
    /// user or a host.
    #[test]
    fn test_parse_connection_params() {
        let without_port = ConnectionParams::parse("user@host").unwrap();
        assert_eq!(without_port.user, "user");
        assert_eq!(without_port.host, "host");
        assert_eq!(without_port.port, None);

        let with_port = ConnectionParams::parse("user@host:22").unwrap();
        assert_eq!(with_port.user, "user");
        assert_eq!(with_port.host, "host");
        assert_eq!(with_port.port, Some(22));

        for rejected in ["hostonly", "@host", "user@"] {
            assert!(ConnectionParams::parse(rejected).is_none());
        }
    }

    /// `Display` renders `user@host`, appending `:port` only when one is set.
    #[test]
    fn test_connection_string() {
        let alice = ConnectionParams {
            user: String::from("alice"),
            host: String::from("example.com"),
            port: None,
            identity_file: None,
        };
        assert_eq!(alice.to_string(), "alice@example.com");

        let bob = ConnectionParams {
            user: String::from("bob"),
            host: String::from("server.local"),
            port: Some(2222),
            identity_file: None,
        };
        assert_eq!(bob.to_string(), "bob@server.local:2222");
    }
}