Skip to main content

fresh/services/remote/
connection.rs

1//! SSH connection management
2//!
3//! Handles spawning SSH process and bootstrapping the Python agent.
4
5use crate::services::process_hidden::HideWindow;
6use crate::services::remote::channel::AgentChannel;
7use crate::services::remote::protocol::AgentResponse;
8use crate::services::remote::AGENT_SOURCE;
9use std::path::PathBuf;
10use std::process::Stdio;
11use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
12use tokio::process::{Child, ChildStderr, Command};
13
14/// Error type for SSH connection
15#[derive(Debug, thiserror::Error)]
16pub enum SshError {
17    #[error("Failed to spawn SSH process ({0}). Is the `ssh` command installed and in your PATH?")]
18    SpawnFailed(#[from] std::io::Error),
19
20    #[error("Agent failed to start: {0}")]
21    AgentStartFailed(String),
22
23    #[error("Protocol version mismatch: expected {expected}, got {got}")]
24    VersionMismatch { expected: u32, got: u32 },
25
26    #[error("Connection closed")]
27    ConnectionClosed,
28
29    #[error("Authentication failed")]
30    AuthenticationFailed,
31}
32
/// SSH connection parameters
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// SSH login user. `None` lets ssh pick the user (its config / the current
    /// local user), so `host` and `ssh://host` work without a `user@`.
    pub user: Option<String>,
    /// Host name or address handed to ssh. IPv6 literals are stored without
    /// surrounding brackets (e.g. `::1`); brackets only appear in the textual
    /// `user@[addr]:port` form to delimit the port.
    pub host: String,
    /// Explicit port passed as `-p`; `None` defers to ssh's config/default.
    pub port: Option<u16>,
    /// Identity file passed as `-i`, if any.
    pub identity_file: Option<PathBuf>,
    /// Extra `ssh` arguments inserted verbatim before the target on every ssh
    /// invocation (agent connect, reconnect, interactive terminal, LSP/probe
    /// spawns), so options like `-J jump` or `-o ProxyCommand=…` apply end to
    /// end rather than only to the initial connect.
    pub extra_args: Vec<String>,
}

impl ConnectionParams {
    /// Parse a connection string like `host`, `user@host`, `user@host:port`
    /// or `user@[ipv6]:port` (a leading `ssh://` is tolerated). The user is
    /// optional.
    ///
    /// The user is split off at the *last* `@`, which is how ssh itself reads
    /// `user@host`. A trailing `:port` is only recognised when the remaining
    /// host contains no other colon, so a bare IPv6 literal such as `fe80::1`
    /// is taken as a host (not host `fe80:` on port 1); bracket it
    /// (`[fe80::1]:22`) to attach a port.
    ///
    /// Returns `None` when the host, or an explicitly given user, is empty.
    pub fn parse(s: &str) -> Option<Self> {
        let s = s.strip_prefix("ssh://").unwrap_or(s);

        let (user, host_port) = match s.rsplit_once('@') {
            Some((user, rest)) => (Some(user), rest),
            None => (None, s),
        };
        let (host, port) = Self::split_host_port(host_port);

        if host.is_empty() || user == Some("") {
            return None;
        }

        Some(Self {
            user: user.map(str::to_string),
            host: host.to_string(),
            port,
            identity_file: None,
            extra_args: Vec::new(),
        })
    }

    /// Split `host`, `host:port`, `[addr]` or `[addr]:port` into host and
    /// optional port.
    ///
    /// Anything else — a bare IPv6 literal, a non-numeric or out-of-range
    /// port — is returned whole as the host with no port, mirroring the
    /// original "not a port, keep it in the host" fallback.
    fn split_host_port(s: &str) -> (&str, Option<u16>) {
        if let Some(rest) = s.strip_prefix('[') {
            if let Some((addr, tail)) = rest.split_once(']') {
                if tail.is_empty() {
                    return (addr, None);
                }
                if let Some(port) = tail.strip_prefix(':').and_then(|p| p.parse::<u16>().ok()) {
                    return (addr, Some(port));
                }
            }
            return (s, None);
        }

        match s.rsplit_once(':') {
            // Only a single colon can separate host and port; more than one
            // means an unbracketed IPv6 literal.
            Some((host, port)) if !host.contains(':') => match port.parse::<u16>() {
                Ok(port) => (host, Some(port)),
                Err(_) => (s, None),
            },
            _ => (s, None),
        }
    }

    /// The ssh target argument: `user@host` when a user is set, else bare
    /// `host` (ssh then resolves the user itself).
    pub fn ssh_target(&self) -> String {
        match &self.user {
            Some(user) if !user.is_empty() => format!("{user}@{}", self.host),
            _ => self.host.clone(),
        }
    }
}

impl std::fmt::Display for ConnectionParams {
    /// Formats as `[user@]host[:port]`. When a port is present and the host is
    /// an IPv6 literal (contains `:`), the host is bracketed so the output
    /// parses back unambiguously with `ConnectionParams::parse`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match (self.port, self.host.contains(':')) {
            (None, _) => write!(f, "{}", self.ssh_target()),
            (Some(port), false) => write!(f, "{}:{}", self.ssh_target(), port),
            (Some(port), true) => {
                if let Some(user) = self.user.as_deref().filter(|u| !u.is_empty()) {
                    write!(f, "{user}@")?;
                }
                write!(f, "[{}]:{}", self.host, port)
            }
        }
    }
}
99
100/// Active SSH connection with bootstrapped agent
101pub struct SshConnection {
102    /// SSH child process
103    process: Child,
104    /// Communication channel with agent (wrapped in Arc for sharing)
105    channel: std::sync::Arc<AgentChannel>,
106    /// Connection parameters
107    params: ConnectionParams,
108}
109
110impl SshConnection {
111    /// Establish a new SSH connection and bootstrap the agent
112    pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
113        let mut cmd = Command::new("ssh");
114
115        // Don't check host key strictly for ease of use
116        cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
117
118        if let Some(port) = params.port {
119            cmd.arg("-p").arg(port.to_string());
120        }
121
122        if let Some(ref identity) = params.identity_file {
123            cmd.arg("-i").arg(identity);
124        }
125
126        cmd.args(&params.extra_args);
127        cmd.arg(params.ssh_target());
128
129        // Bootstrap the agent using Python itself to read the exact byte count.
130        // This avoids requiring bash or other shell utilities on the remote.
131        // Python reads exactly N bytes (the agent code), execs it, and the agent
132        // then continues reading from stdin for protocol messages.
133        //
134        // Note: SSH passes the remote command through a shell, so we need to
135        // properly quote the Python code. We use double quotes for the outer
136        // shell and avoid problematic characters in the Python code.
137        let agent_len = AGENT_SOURCE.len();
138        let bootstrap = format!(
139            "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
140            agent_len
141        );
142        cmd.arg(bootstrap);
143
144        cmd.stdin(Stdio::piped());
145        cmd.stdout(Stdio::piped());
146        // Capture ssh's stderr instead of inheriting it. The editor runs a
147        // full-screen ratatui UI on the alternate screen; an inherited stderr
148        // lets ssh scribble its diagnostics ("Could not resolve hostname …")
149        // straight over the rendered UI. ratatui has no idea those cells
150        // changed, so the garbage persists until the next full repaint — the
151        // "corrupted window" users see after a bad host. We pipe stderr and
152        // fold its message into the connection error instead (see
153        // `ssh_eof_error`), so a failed connect becomes a clean status line.
154        cmd.stderr(Stdio::piped());
155        // Kill the ssh process if this connect future is dropped before it
156        // finishes (e.g. the New-Session dialog's Cancel aborts the connect
157        // task while the handshake is still hanging). Without this a hung
158        // connect would orphan the ssh child until it timed out on its own.
159        // For an established carrier `SshConnection`'s Drop also kills it; this
160        // covers the window before the connection object exists.
161        cmd.kill_on_drop(true);
162        cmd.hide_window();
163
164        let mut child = cmd.spawn()?;
165
166        // Get handles
167        let mut stdin = child
168            .stdin
169            .take()
170            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
171        let stdout = child
172            .stdout
173            .take()
174            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
175        let stderr = child.stderr.take();
176
177        // Send the agent code (exact byte count). If the carrier already died
178        // (a failed connect — e.g. the host was unreachable), this write/flush
179        // races the child's exit and can fail with a broken pipe. That pipe
180        // error isn't the actionable reason; the carrier's own stderr is. Fall
181        // through to the same EOF path so we surface "ssh: …" rather than a bare
182        // `SpawnFailed`, regardless of which side loses the race.
183        if stdin.write_all(AGENT_SOURCE.as_bytes()).await.is_err() || stdin.flush().await.is_err() {
184            return Err(ssh_eof_error(&mut child, &params, stderr).await);
185        }
186
187        // Create buffered reader for stdout
188        let mut reader = BufReader::new(stdout);
189
190        // Wait for ready message from agent
191        // No timeout needed - all failure modes (auth failure, network issues, etc.)
192        // result in SSH exiting and us getting EOF. User can Ctrl+C if needed.
193        let mut ready_line = String::new();
194        match reader.read_line(&mut ready_line).await {
195            Ok(0) => {
196                return Err(ssh_eof_error(&mut child, &params, stderr).await);
197            }
198            Ok(_) => {}
199            Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
200        }
201
202        // Connected. Drain ssh's stderr for the life of the connection so the
203        // occasional later diagnostic (host-key warnings, etc.) is discarded
204        // rather than filling the pipe or — if we'd inherited it — landing on
205        // the editor's screen.
206        if let Some(mut stderr) = stderr {
207            tokio::spawn(async move {
208                let mut sink = tokio::io::sink();
209                // Best-effort drain; the byte count / EOF error is irrelevant
210                // since we're discarding ssh's stderr for the session.
211                #[allow(clippy::let_underscore_must_use)]
212                let _ = tokio::io::copy(&mut stderr, &mut sink).await;
213            });
214        }
215
216        let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
217            SshError::AgentStartFailed(format!(
218                "invalid ready message '{}': {}",
219                ready_line.trim(),
220                e
221            ))
222        })?;
223
224        if !ready.is_ready() {
225            return Err(SshError::AgentStartFailed(
226                "agent did not send ready message".to_string(),
227            ));
228        }
229
230        // Check protocol version
231        let version = ready.version.unwrap_or(0);
232        if version != crate::services::remote::protocol::PROTOCOL_VERSION {
233            return Err(SshError::VersionMismatch {
234                expected: crate::services::remote::protocol::PROTOCOL_VERSION,
235                got: version,
236            });
237        }
238
239        // Create channel (takes ownership of stdin for writing)
240        let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
241
242        Ok(Self {
243            process: child,
244            channel,
245            params,
246        })
247    }
248
249    /// Get the communication channel as an Arc for sharing
250    pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
251        self.channel.clone()
252    }
253
254    /// Get connection parameters
255    pub fn params(&self) -> &ConnectionParams {
256        &self.params
257    }
258
259    /// Check if the connection is still alive
260    pub fn is_connected(&self) -> bool {
261        self.channel.is_connected()
262    }
263
264    /// Get the connection string for display
265    pub fn connection_string(&self) -> String {
266        self.params.to_string()
267    }
268}
269
270impl Drop for SshConnection {
271    fn drop(&mut self) {
272        // Best-effort kill of the SSH process during cleanup.
273        // If it fails (process already exited, permission error, etc.)
274        // there's nothing we can do in a Drop impl — the OS will clean
275        // up the zombie when our process exits.
276        if let Ok(()) = self.process.start_kill() {}
277    }
278}
279
/// Default interval between reconnection attempts.
const DEFAULT_RECONNECT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);

/// Configuration for the reconnect task.
///
/// `ReconnectConfig::default()` uses a 5-second interval. The type is a small
/// value, so it derives `Debug` for diagnostics and `Copy`/`PartialEq` so
/// callers and tests can reuse and compare configurations freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReconnectConfig {
    /// How long to wait between reconnection attempts. The reconnect task
    /// also uses it as the polling period while the channel is connected.
    pub interval: std::time::Duration,
}

impl Default for ReconnectConfig {
    fn default() -> Self {
        Self {
            interval: DEFAULT_RECONNECT_INTERVAL,
        }
    }
}
296
297/// Spawn a background task that automatically reconnects when the channel
298/// disconnects.
299///
300/// The task monitors `channel.is_connected()` and, when false, attempts to
301/// establish a new SSH connection using the given `params`. On success, it
302/// calls `channel.replace_transport()` to hot-swap the underlying reader/writer.
303///
304/// The task runs until the channel is dropped (write_tx closed) or the
305/// returned `tokio::task::JoinHandle` is aborted.
306pub fn spawn_reconnect_task(
307    channel: std::sync::Arc<AgentChannel>,
308    params: ConnectionParams,
309) -> tokio::task::JoinHandle<()> {
310    let connect_fn = move || {
311        let params = params.clone();
312        async move {
313            let (reader, writer, _child) = establish_ssh_transport(&params).await?;
314            // Box the reader/writer so they have a uniform type
315            let reader: Box<dyn tokio::io::AsyncBufRead + Unpin + Send> = Box::new(reader);
316            let writer: Box<dyn tokio::io::AsyncWrite + Unpin + Send> = Box::new(writer);
317            Ok::<_, SshError>((reader, writer))
318        }
319    };
320
321    spawn_reconnect_task_with(
322        channel,
323        connect_fn,
324        ReconnectConfig::default(),
325        "SSH remote",
326    )
327}
328
329/// Spawn a reconnect task with a custom connection factory.
330///
331/// This is the generic version used by both production (via `spawn_reconnect_task`)
332/// and tests (with a fake connection factory). The `connect_fn` is called each
333/// time a reconnection attempt is made. It should return a `(reader, writer)` pair
334/// on success.
335pub fn spawn_reconnect_task_with<F, Fut>(
336    channel: std::sync::Arc<AgentChannel>,
337    connect_fn: F,
338    config: ReconnectConfig,
339    label: &'static str,
340) -> tokio::task::JoinHandle<()>
341where
342    F: Fn() -> Fut + Send + 'static,
343    Fut: std::future::Future<
344            Output = Result<
345                (
346                    Box<dyn tokio::io::AsyncBufRead + Unpin + Send>,
347                    Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
348                ),
349                SshError,
350            >,
351        > + Send,
352{
353    tokio::spawn(async move {
354        loop {
355            // Wait until disconnected
356            while channel.is_connected() {
357                tokio::time::sleep(config.interval).await;
358            }
359
360            tracing::info!("{label}: connection lost, attempting reconnection...");
361
362            // Retry loop
363            loop {
364                tokio::time::sleep(config.interval).await;
365
366                // Check if channel was dropped (write_tx gone)
367                if !channel.is_connected() {
368                    // Still disconnected — try to reconnect
369                } else {
370                    // Something else reconnected us (e.g., manual replace_transport)
371                    break;
372                }
373
374                match (connect_fn)().await {
375                    Ok((reader, writer)) => {
376                        tracing::info!("{label}: reconnected successfully");
377                        channel.replace_transport(reader, writer).await;
378                        break;
379                    }
380                    Err(e) => {
381                        tracing::debug!("{label}: reconnection attempt failed: {e}");
382                    }
383                }
384            }
385        }
386    })
387}
388
389/// Default heartbeat interval. Comfortably under the smallest common
390/// load-balancer / NAT idle timeout (~5 min) so an otherwise-idle agent
391/// stream keeps generating traffic and isn't silently dropped.
392pub const DEFAULT_HEARTBEAT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(60);
393
394/// Spawn a background task that pings the agent periodically so an idle
395/// connection's stream keeps producing traffic.
396///
397/// Long-lived agent streams that sit idle (no edits, no LSP chatter) get
398/// silently dropped by ELB / NAT idle timers after a few minutes — the
399/// client never sees a FIN, so the *next* request just hangs until it
400/// times out and the UI appears frozen. A cheap periodic `info` request
401/// keeps the NAT state-table entry warm. Shared by every agent transport
402/// (SSH and `kubectl exec` alike); `info` is already handled by every
403/// agent version, so no protocol bump is needed.
404///
405/// Holds only a `Weak` reference, so the task terminates on its own once
406/// the last owner of the channel is dropped — no JoinHandle bookkeeping
407/// is required to avoid a leak (callers may still `abort()` it to stop
408/// pinging immediately when the carrier dies). Pinging while disconnected
409/// is skipped; the reconnect task owns re-establishment.
410pub fn spawn_heartbeat_task(
411    channel: &std::sync::Arc<AgentChannel>,
412    interval: std::time::Duration,
413) -> tokio::task::JoinHandle<()> {
414    let weak = std::sync::Arc::downgrade(channel);
415    tokio::spawn(async move {
416        loop {
417            tokio::time::sleep(interval).await;
418            let Some(channel) = weak.upgrade() else {
419                break;
420            };
421            if channel.is_connected() {
422                // Outcome ignored on purpose: a failed/timed-out ping
423                // already marks the channel disconnected (see `request`),
424                // and the reconnect task owns recovery from there. Bound
425                // to a named `_` to satisfy `deny(let_underscore_must_use)`.
426                let _ping = channel.request("info", serde_json::json!({})).await;
427            }
428        }
429    })
430}
431
432/// Establish a new SSH connection and return the raw transport + child process.
433///
434/// Build a descriptive error when the SSH process closes stdout (EOF) without
435/// sending a ready message. We wait for the SSH process to exit and inspect its
436/// exit code to give the user a more actionable message than a generic
437/// "connection closed".
438async fn ssh_eof_error(
439    child: &mut Child,
440    params: &ConnectionParams,
441    stderr: Option<ChildStderr>,
442) -> SshError {
443    // Give SSH a moment to finish so we can read its exit code.
444    let status = tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
445
446    let hint = match status {
447        Ok(Ok(status)) => {
448            match status.code() {
449                // 255 is SSH's conventional exit code for connection errors
450                // (host unreachable, connection refused, DNS failure, auth
451                // failure, etc.).
452                Some(255) => format!(
453                    "SSH could not connect to {}. Check that the host is \
454                     reachable, the hostname is correct, and your SSH \
455                     credentials are valid (exit code 255)",
456                    params
457                ),
458                Some(127) => format!(
459                    "python3 was not found on the remote host {}. \
460                     Ensure Python 3 is installed on the remote machine",
461                    params
462                ),
463                Some(code) => format!(
464                    "SSH process exited with code {} while connecting to {}",
465                    code, params
466                ),
467                None => format!(
468                    "SSH process was killed by a signal while connecting to {}",
469                    params
470                ),
471            }
472        }
473        Ok(Err(e)) => format!("failed to get SSH exit status: {}", e),
474        Err(_) => {
475            // Timed out waiting for exit — kill it so we don't leak.
476            if let Err(e) = child.start_kill() {
477                tracing::warn!("Failed to kill timed-out SSH process: {}", e);
478            }
479            format!(
480                "SSH process did not exit in time while connecting to {}",
481                params
482            )
483        }
484    };
485
486    // ssh writes the actionable reason ("Could not resolve hostname",
487    // "Permission denied", "Connection refused", …) to stderr. We piped it
488    // (rather than letting it corrupt the editor's screen), so fold the most
489    // specific line into the error for the status bar.
490    match read_ssh_stderr(stderr).await {
491        Some(detail) => SshError::AgentStartFailed(format!("{hint}: {detail}")),
492        None => SshError::AgentStartFailed(hint),
493    }
494}
495
496/// Read whatever a failed ssh process wrote to stderr and return its most
497/// specific (last non-empty) line. ssh has closed stdout by the time we call
498/// this and is exiting, so the read is bounded; we still cap the wait so a
499/// wedged pipe can't hang the error path.
500async fn read_ssh_stderr(stderr: Option<ChildStderr>) -> Option<String> {
501    let mut stderr = stderr?;
502    let mut buf = String::new();
503    #[allow(clippy::let_underscore_must_use)]
504    let _ = tokio::time::timeout(
505        std::time::Duration::from_secs(2),
506        stderr.read_to_string(&mut buf),
507    )
508    .await;
509    buf.trim()
510        .lines()
511        .map(str::trim)
512        .filter(|line| !line.is_empty())
513        .next_back()
514        .map(str::to_string)
515}
516
517/// This is the lower-level function used by both `SshConnection::connect` and
518/// the reconnect task. It spawns an SSH process, bootstraps the Python agent,
519/// and returns the reader/writer pair ready for use with `AgentChannel`.
520async fn establish_ssh_transport(
521    params: &ConnectionParams,
522) -> Result<
523    (
524        BufReader<tokio::process::ChildStdout>,
525        tokio::process::ChildStdin,
526        Child,
527    ),
528    SshError,
529> {
530    let mut cmd = Command::new("ssh");
531
532    cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
533    // Disable password prompts for reconnection (non-interactive)
534    cmd.arg("-o").arg("BatchMode=yes");
535
536    if let Some(port) = params.port {
537        cmd.arg("-p").arg(port.to_string());
538    }
539
540    if let Some(ref identity) = params.identity_file {
541        cmd.arg("-i").arg(identity);
542    }
543
544    cmd.args(&params.extra_args);
545    cmd.arg(params.ssh_target());
546
547    let agent_len = AGENT_SOURCE.len();
548    let bootstrap = format!(
549        "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
550        agent_len
551    );
552    cmd.arg(bootstrap);
553
554    cmd.stdin(Stdio::piped());
555    cmd.stdout(Stdio::piped());
556    cmd.stderr(Stdio::null()); // No terminal for reconnection
557    cmd.hide_window();
558
559    let mut child = cmd.spawn()?;
560
561    let mut stdin = child
562        .stdin
563        .take()
564        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
565    let stdout = child
566        .stdout
567        .take()
568        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
569
570    // Send the agent code
571    stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
572    stdin.flush().await?;
573
574    let mut reader = BufReader::new(stdout);
575
576    // Wait for ready message
577    let mut ready_line = String::new();
578    match reader.read_line(&mut ready_line).await {
579        Ok(0) => {
580            // Reconnect spawns with `stderr(Stdio::null())`, so there is no
581            // captured stderr to attach here.
582            return Err(ssh_eof_error(&mut child, params, None).await);
583        }
584        Ok(_) => {}
585        Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
586    }
587
588    let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
589        SshError::AgentStartFailed(format!(
590            "invalid ready message '{}': {}",
591            ready_line.trim(),
592            e
593        ))
594    })?;
595
596    if !ready.is_ready() {
597        return Err(SshError::AgentStartFailed(
598            "agent did not send ready message".to_string(),
599        ));
600    }
601
602    let version = ready.version.unwrap_or(0);
603    if version != crate::services::remote::protocol::PROTOCOL_VERSION {
604        return Err(SshError::VersionMismatch {
605            expected: crate::services::remote::protocol::PROTOCOL_VERSION,
606            got: version,
607        });
608    }
609
610    Ok((reader, stdin, child))
611}
612
613/// Spawn a local agent process for testing (no SSH)
614///
615/// This is used by integration tests to test the full stack without SSH.
616/// Not intended for production use.
617#[doc(hidden)]
618pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
619    use tokio::process::Command as TokioCommand;
620
621    let mut child = TokioCommand::new("python3")
622        .arg("-u")
623        .arg("-c")
624        .arg(AGENT_SOURCE)
625        .stdin(Stdio::piped())
626        .stdout(Stdio::piped())
627        .stderr(Stdio::piped())
628        .hide_window()
629        .spawn()?;
630
631    let stdin = child
632        .stdin
633        .take()
634        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
635    let stdout = child
636        .stdout
637        .take()
638        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
639
640    let mut reader = BufReader::new(stdout);
641
642    // Wait for ready message
643    let mut ready_line = String::new();
644    reader.read_line(&mut ready_line).await?;
645
646    let ready: AgentResponse = serde_json::from_str(&ready_line)
647        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
648
649    if !ready.is_ready() {
650        return Err(SshError::AgentStartFailed(
651            "agent did not send ready message".to_string(),
652        ));
653    }
654
655    Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
656}
657
658/// Spawn a local Python agent with a custom data channel capacity.
659///
660/// Same as `spawn_local_agent` but allows overriding the channel capacity
661/// for stress-testing backpressure handling.
662#[doc(hidden)]
663pub async fn spawn_local_agent_with_capacity(
664    data_channel_capacity: usize,
665) -> Result<std::sync::Arc<AgentChannel>, SshError> {
666    use tokio::process::Command as TokioCommand;
667
668    let mut child = TokioCommand::new("python3")
669        .arg("-u")
670        .arg("-c")
671        .arg(AGENT_SOURCE)
672        .stdin(Stdio::piped())
673        .stdout(Stdio::piped())
674        .stderr(Stdio::piped())
675        .hide_window()
676        .spawn()?;
677
678    let stdin = child
679        .stdin
680        .take()
681        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
682    let stdout = child
683        .stdout
684        .take()
685        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
686
687    let mut reader = BufReader::new(stdout);
688
689    // Wait for ready message
690    let mut ready_line = String::new();
691    reader.read_line(&mut ready_line).await?;
692
693    let ready: AgentResponse = serde_json::from_str(&ready_line)
694        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
695
696    if !ready.is_ready() {
697        return Err(SshError::AgentStartFailed(
698            "agent did not send ready message".to_string(),
699        ));
700    }
701
702    Ok(std::sync::Arc::new(AgentChannel::with_capacity(
703        reader,
704        stdin,
705        data_channel_capacity,
706    )))
707}
708
709/// Spawn a local Python agent and return the raw reader/writer transport.
710///
711/// Unlike `spawn_local_agent`, this does NOT create an `AgentChannel`. It
712/// returns the ready-to-use reader and writer so callers can feed them to
713/// `AgentChannel::replace_transport()` for reconnection testing.
714#[doc(hidden)]
715pub async fn spawn_local_agent_transport() -> Result<
716    (
717        tokio::io::BufReader<tokio::process::ChildStdout>,
718        tokio::process::ChildStdin,
719    ),
720    SshError,
721> {
722    use tokio::process::Command as TokioCommand;
723
724    let mut child = TokioCommand::new("python3")
725        .arg("-u")
726        .arg("-c")
727        .arg(AGENT_SOURCE)
728        .stdin(Stdio::piped())
729        .stdout(Stdio::piped())
730        .stderr(Stdio::piped())
731        .hide_window()
732        .spawn()?;
733
734    let stdin = child
735        .stdin
736        .take()
737        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
738    let stdout = child
739        .stdout
740        .take()
741        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
742
743    let mut reader = BufReader::new(stdout);
744
745    // Wait for ready message
746    let mut ready_line = String::new();
747    reader.read_line(&mut ready_line).await?;
748
749    let ready: AgentResponse = serde_json::from_str(&ready_line)
750        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
751
752    if !ready.is_ready() {
753        return Err(SshError::AgentStartFailed(
754            "agent did not send ready message".to_string(),
755        ));
756    }
757
758    Ok((reader, stdin))
759}
760
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_connection_params() {
        // (input, expected user, expected host, expected port)
        let accepted: [(&str, Option<&str>, &str, Option<u16>); 4] = [
            ("user@host", Some("user"), "host", None),
            ("user@host:22", Some("user"), "host", Some(22)),
            // User is optional: bare host and ssh:// both parse, user = None.
            ("hostonly", None, "hostonly", None),
            ("ssh://example.com:2222", None, "example.com", Some(2222)),
        ];
        for (input, user, host, port) in accepted {
            let parsed = ConnectionParams::parse(input)
                .unwrap_or_else(|| panic!("{input} should parse"));
            assert_eq!(parsed.user.as_deref(), user, "user of {input}");
            assert_eq!(parsed.host, host, "host of {input}");
            assert_eq!(parsed.port, port, "port of {input}");
        }
        assert_eq!(
            ConnectionParams::parse("hostonly").unwrap().ssh_target(),
            "hostonly"
        );

        // Empty user / empty host are still rejected.
        for rejected in ["@host", "user@"] {
            assert!(
                ConnectionParams::parse(rejected).is_none(),
                "{rejected} should be rejected"
            );
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn heartbeat_keeps_channel_warm_and_exits_on_drop() {
        use std::time::Duration;

        // Real agent over local stdio — no SSH/kubectl, same channel.
        let channel = spawn_local_agent().await.expect("spawn local agent");
        let heartbeat = spawn_heartbeat_task(&channel, Duration::from_millis(30));

        // Several beats fire in this window; the channel must stay healthy.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(
            channel.is_connected(),
            "channel stays connected while heartbeat pings"
        );
        assert!(
            channel.request("info", serde_json::json!({})).await.is_ok(),
            "agent still answers after heartbeats"
        );

        // Once the last strong ref is gone the Weak-based task must finish
        // by itself — it cannot outlive the connection.
        drop(channel);
        let finished = tokio::time::timeout(Duration::from_secs(3), heartbeat)
            .await
            .expect("heartbeat task exits after the channel is dropped");
        finished.expect("heartbeat task did not panic");
    }

    #[test]
    fn test_connection_string() {
        fn params(user: Option<&str>, host: &str, port: Option<u16>) -> ConnectionParams {
            ConnectionParams {
                user: user.map(str::to_string),
                host: host.to_string(),
                port,
                identity_file: None,
                extra_args: Vec::new(),
            }
        }

        assert_eq!(
            params(Some("alice"), "example.com", None).to_string(),
            "alice@example.com"
        );
        assert_eq!(
            params(Some("bob"), "server.local", Some(2222)).to_string(),
            "bob@server.local:2222"
        );

        // No user: the target (and display) is the bare host.
        let anonymous = params(None, "server.local", Some(2222));
        assert_eq!(anonymous.to_string(), "server.local:2222");
        assert_eq!(anonymous.ssh_target(), "server.local");
    }
}