Skip to main content

gritty/
connect.rs

1use anyhow::{Context, bail};
2use std::os::fd::OwnedFd;
3use std::os::unix::fs::OpenOptionsExt;
4use std::path::{Path, PathBuf};
5use std::process::Stdio;
6use std::time::{Duration, Instant};
7use tokio::process::{Child, Command};
8use tracing::{debug, info, warn};
9
10// ---------------------------------------------------------------------------
11// Destination parsing
12// ---------------------------------------------------------------------------
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15struct Destination {
16    user: Option<String>,
17    host: String,
18    port: Option<u16>,
19}
20
21impl Destination {
22    fn parse(s: &str) -> anyhow::Result<Self> {
23        if s.is_empty() {
24            bail!("empty destination");
25        }
26
27        let (user, remainder) = if let Some(at) = s.find('@') {
28            let u = &s[..at];
29            if u.is_empty() {
30                bail!("empty user in destination: {s}");
31            }
32            (Some(u.to_string()), &s[at + 1..])
33        } else {
34            (None, s)
35        };
36
37        // Handle bracketed IPv6: [::1] or [::1]:port
38        let (host, port) = if remainder.starts_with('[') {
39            if let Some(bracket) = remainder.find(']') {
40                let h = &remainder[1..bracket];
41                let after = &remainder[bracket + 1..];
42                let p = if let Some(rest) = after.strip_prefix(':') {
43                    Some(
44                        rest.parse::<u16>()
45                            .with_context(|| format!("invalid port in destination: {s}"))?,
46                    )
47                } else if after.is_empty() {
48                    None
49                } else {
50                    bail!("unexpected characters after bracketed host: {s}");
51                };
52                (h.to_string(), p)
53            } else {
54                bail!("unclosed bracket in destination: {s}");
55            }
56        } else if remainder.matches(':').count() > 1 {
57            // Multiple colons without brackets -- bare IPv6 address (no port)
58            (remainder.to_string(), None)
59        } else if let Some(colon) = remainder.rfind(':') {
60            let h = &remainder[..colon];
61            match remainder[colon + 1..].parse::<u16>() {
62                Ok(p) => (h.to_string(), Some(p)),
63                Err(_) => (remainder.to_string(), None),
64            }
65        } else {
66            (remainder.to_string(), None)
67        };
68
69        if host.is_empty() {
70            bail!("empty host in destination: {s}");
71        }
72
73        Ok(Self { user, host, port })
74    }
75
76    /// Build the SSH destination string (`user@host` or just `host`).
77    fn ssh_dest(&self) -> String {
78        match &self.user {
79            Some(u) => format!("{u}@{}", self.host),
80            None => self.host.clone(),
81        }
82    }
83
84    /// Common SSH args for port, if set.
85    fn port_args(&self) -> Vec<String> {
86        match self.port {
87            Some(p) => vec!["-p".to_string(), p.to_string()],
88            None => vec![],
89        }
90    }
91}
92
93/// Reject connection names that could cause path traversal or corruption.
94fn validate_connection_name(name: &str) -> anyhow::Result<()> {
95    if name.is_empty() {
96        bail!("connection name must not be empty");
97    }
98    if name.contains('/') || name.contains('\0') || name.contains("..") {
99        bail!("invalid connection name: {name:?}");
100    }
101    Ok(())
102}
103
104// ---------------------------------------------------------------------------
105// SSH helpers
106// ---------------------------------------------------------------------------
107
/// Tunnel-specific SSH `-o` options (keepalive, cleanup, failure behavior).
///
/// `tunnel_command` passes each entry as its own `-o <entry>` pair, after the
/// arguments produced by `base_ssh_args`.
const TUNNEL_SSH_OPTS: &[&str] = &[
    "ServerAliveInterval=3",
    "ServerAliveCountMax=2",
    "StreamLocalBindUnlink=yes",
    "ExitOnForwardFailure=yes",
    // Prevent user config from leaking forwarding or connection sharing
    // into the tunnel (gritty handles agent forwarding separately).
    "ControlPath=none",
    "ForwardAgent=no",
    "ForwardX11=no",
];
120
/// PATH prefix prepended to remote commands so gritty is discoverable
/// in non-interactive SSH shells.
///
/// `remote_exec_command` interpolates this into a double-quoted `PATH="..."`
/// assignment, so `$HOME` and the trailing `$PATH` are expanded by the remote
/// shell; the trailing `$PATH` keeps the remote's existing entries as a fallback.
const REMOTE_PATH_PREFIX: &str = "$HOME/bin:$HOME/.local/bin:$HOME/.cargo/bin:$HOME/.nix-profile/bin:/usr/local/bin:/opt/homebrew/bin:/snap/bin:$PATH";
124
125/// Build the common SSH args that precede the destination in every invocation:
126/// port, user-supplied options, ConnectTimeout, and BatchMode (background only).
127fn base_ssh_args(dest: &Destination, extra_ssh_opts: &[String], foreground: bool) -> Vec<String> {
128    let mut args = Vec::new();
129    args.extend(dest.port_args());
130    for opt in extra_ssh_opts {
131        args.push("-o".into());
132        args.push(opt.clone());
133    }
134    args.push("-o".into());
135    args.push("ConnectTimeout=5".into());
136    if !foreground {
137        args.push("-o".into());
138        args.push("BatchMode=yes".into());
139    }
140    args
141}
142
143/// Build the SSH command for remote execution (without stdio config).
144fn remote_exec_command(
145    dest: &Destination,
146    remote_cmd: &str,
147    extra_ssh_opts: &[String],
148    foreground: bool,
149) -> Command {
150    let mut preamble = format!("PATH=\"{REMOTE_PATH_PREFIX}\"");
151    if let Ok(dir) = std::env::var("GRITTY_SOCKET_DIR") {
152        preamble.push_str(&format!("; export GRITTY_SOCKET_DIR=\"{dir}\""));
153    }
154    let wrapped_cmd = format!("{preamble}; {remote_cmd}");
155    let mut cmd = Command::new("ssh");
156    cmd.args(base_ssh_args(dest, extra_ssh_opts, foreground));
157    cmd.arg(dest.ssh_dest());
158    cmd.arg(&wrapped_cmd);
159    cmd
160}
161
162/// Run a command on the remote host via SSH, returning stdout.
163///
164/// Stderr is always piped so we can include SSH errors in our error messages.
165/// SSH interactive prompts use `/dev/tty` directly, not stderr.
166/// In background mode, `BatchMode=yes` is set so SSH fails fast instead of hanging.
167async fn remote_exec(
168    dest: &Destination,
169    remote_cmd: &str,
170    extra_ssh_opts: &[String],
171    foreground: bool,
172) -> anyhow::Result<String> {
173    debug!("ssh {}: {remote_cmd}", dest.ssh_dest());
174
175    let mut cmd = remote_exec_command(dest, remote_cmd, extra_ssh_opts, foreground);
176    cmd.stdout(Stdio::piped());
177    cmd.stderr(Stdio::piped());
178    cmd.stdin(Stdio::null());
179
180    let output = cmd.output().await.context("failed to run ssh")?;
181
182    if !output.status.success() {
183        let stderr = String::from_utf8_lossy(&output.stderr);
184        let stderr = stderr.trim();
185        debug!("ssh failed (status {}): {stderr}", output.status);
186        if stderr.contains("command not found") || stderr.contains("No such file") {
187            bail!(
188                "gritty not found on remote host -- install it there with: cargo install gritty-cli"
189            );
190        }
191        let diag = format_ssh_diag(dest, extra_ssh_opts, foreground);
192        if stderr.is_empty() {
193            bail!("ssh command failed (exit {})\n  to diagnose: {diag}", output.status);
194        }
195        bail!("ssh command failed: {stderr}\n  to diagnose: {diag}");
196    }
197
198    let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
199    debug!("ssh output: {stdout}");
200    Ok(stdout)
201}
202
203/// Format a diagnostic SSH command for display in error messages.
204/// Mirrors `base_ssh_args` so the suggestion matches what was actually run.
205fn format_ssh_diag(dest: &Destination, extra_ssh_opts: &[String], foreground: bool) -> String {
206    let mut parts = vec!["ssh".to_string()];
207    for arg in base_ssh_args(dest, extra_ssh_opts, foreground) {
208        parts.push(shell_quote(&arg));
209    }
210    parts.push(dest.ssh_dest());
211    parts.join(" ")
212}
213
214/// Shell-quote a string if it contains characters that need quoting.
215/// Used only for display (--dry-run output), never for command execution.
216fn shell_quote(s: &str) -> String {
217    shlex::try_quote(s).map(|c| c.into_owned()).unwrap_or_else(|_| s.to_string())
218}
219
220/// Format a tokio Command as a shell string for display.
221fn format_command(cmd: &Command) -> String {
222    let std_cmd = cmd.as_std();
223    let prog = std_cmd.get_program().to_string_lossy();
224    let args: Vec<_> = std_cmd.get_args().map(|a| shell_quote(&a.to_string_lossy())).collect();
225    if args.is_empty() { prog.to_string() } else { format!("{prog} {}", args.join(" ")) }
226}
227
228/// Build the SSH tunnel command with hardened options.
229///
230/// Stderr is always piped so we can capture SSH errors on failure.
231/// (SSH interactive prompts use `/dev/tty` directly, not stderr.)
232fn tunnel_command(
233    dest: &Destination,
234    local_sock: &Path,
235    remote_sock: &str,
236    extra_ssh_opts: &[String],
237    foreground: bool,
238) -> Command {
239    let mut cmd = Command::new("ssh");
240    cmd.args(base_ssh_args(dest, extra_ssh_opts, foreground));
241    for opt in TUNNEL_SSH_OPTS {
242        cmd.arg("-o").arg(opt);
243    }
244    cmd.args(["-N", "-T"]);
245    let forward = format!("{}:{}", local_sock.display(), remote_sock);
246    cmd.arg("-L").arg(forward);
247    cmd.arg(dest.ssh_dest());
248    cmd.stdout(Stdio::null());
249    cmd.stderr(Stdio::piped());
250    cmd.stdin(Stdio::null());
251    cmd
252}
253
254/// Spawn the SSH tunnel, returning the child process.
255async fn spawn_tunnel(
256    dest: &Destination,
257    local_sock: &Path,
258    remote_sock: &str,
259    extra_ssh_opts: &[String],
260    foreground: bool,
261) -> anyhow::Result<Child> {
262    debug!("tunnel: {} -> {}:{}", local_sock.display(), dest.ssh_dest(), remote_sock,);
263    let mut cmd = tunnel_command(dest, local_sock, remote_sock, extra_ssh_opts, foreground);
264    cmd.kill_on_drop(true);
265    let child = cmd.spawn().context("failed to spawn ssh tunnel")?;
266    debug!("ssh tunnel pid: {:?}", child.id());
267    Ok(child)
268}
269
270/// Poll until the local socket is connectable (200ms interval).
271async fn wait_for_socket(path: &Path, timeout: Duration) -> anyhow::Result<()> {
272    let deadline = Instant::now() + timeout;
273    loop {
274        if std::os::unix::net::UnixStream::connect(path).is_ok() {
275            return Ok(());
276        }
277        if Instant::now() >= deadline {
278            bail!("timeout waiting for SSH tunnel socket at {}", path.display());
279        }
280        tokio::time::sleep(Duration::from_millis(200)).await;
281    }
282}
283
/// Background task: monitor SSH child, respawn on transient failure.
/// Uses exponential backoff (1s..60s) and never gives up on transient errors.
///
/// The task ends when `stop` is cancelled, when waiting on the child fails,
/// or when the child exits with a code other than 255. Respawned tunnels are
/// always started with `foreground = false`, so `base_ssh_args` adds
/// `BatchMode=yes`.
async fn tunnel_monitor(
    mut child: Child,
    dest: Destination,
    local_sock: PathBuf,
    remote_sock: String,
    extra_ssh_opts: Vec<String>,
    stop: tokio_util::sync::CancellationToken,
) {
    let mut backoff = Duration::from_secs(1);
    const MAX_BACKOFF: Duration = Duration::from_secs(60);
    /// If the tunnel stays alive for this long, reset the backoff.
    const HEALTHY_THRESHOLD: Duration = Duration::from_secs(30);

    loop {
        // Taken at the top of every iteration, i.e. when this task starts and
        // right after each respawn attempt.
        let spawned_at = Instant::now();

        tokio::select! {
            // Shutdown requested: kill the current child and stop monitoring.
            _ = stop.cancelled() => {
                let _ = child.kill().await;
                return;
            }
            status = child.wait() => {
                let status = match status {
                    Ok(s) => s,
                    Err(e) => {
                        warn!("failed to wait on ssh tunnel: {e}");
                        return;
                    }
                };

                // The child may have exited because shutdown is in progress.
                if stop.is_cancelled() {
                    return;
                }

                let code = status.code();
                debug!("ssh tunnel exited: {:?}", code);

                // Non-transient failure: don't retry
                // SSH exit 255 = connection error (transient). Signal-killed = no code.
                // Everything else (auth failure, config error) = bail.
                if let Some(c) = code
                    && c != 255
                {
                    warn!("ssh tunnel exited with code {c} (not retrying)");
                    return;
                }

                // Reset backoff if the tunnel was alive long enough
                if spawned_at.elapsed() >= HEALTHY_THRESHOLD {
                    backoff = Duration::from_secs(1);
                }

                info!("ssh tunnel died, retrying in {}s", backoff.as_secs());

                // Sleep for the backoff period, but remain responsive to shutdown.
                tokio::select! {
                    _ = tokio::time::sleep(backoff) => {}
                    _ = stop.cancelled() => return,
                }

                // Double the delay for the next failure, capped at MAX_BACKOFF.
                backoff = (backoff * 2).min(MAX_BACKOFF);

                match spawn_tunnel(&dest, &local_sock, &remote_sock, &extra_ssh_opts, false).await {
                    Ok(new_child) => {
                        info!("ssh tunnel respawned");
                        child = new_child;
                    }
                    Err(e) => {
                        warn!("failed to respawn ssh tunnel: {e}");
                        // Even spawn failure is retried -- the tunnel process
                        // might just not be available momentarily.
                        // NOTE(review): `child` is still the exited process here, so
                        // the next iteration waits on it again; this presumably
                        // relies on `Child::wait` yielding the same status -- confirm.
                        continue;
                    }
                }
            }
        }
    }
}
363
364// ---------------------------------------------------------------------------
365// Remote server management
366// ---------------------------------------------------------------------------
367
/// Remote shell snippet run by `ensure_remote_ready` when auto-start is allowed.
///
/// It resolves the socket path, starts `gritty server` if `gritty ls local`
/// fails, prints the socket path, and then prints the protocol version.
///
/// Only `gritty protocol-version` is allowed to fail (remotes that lack the
/// subcommand), so its `|| true` is confined to a brace group. In a shell
/// AND-OR list `&&` and `||` have equal precedence and associate left to
/// right, so an ungrouped trailing `|| true` would turn *any* earlier failure
/// (gritty missing, server failing to start) into exit status 0 with empty
/// output, hiding the real error from `remote_exec`.
const REMOTE_ENSURE_CMD: &str = "\
    SOCK=$(gritty socket-path) && \
    (gritty ls local >/dev/null 2>&1 || \
     { gritty server && sleep 0.3; }) && \
    echo \"$SOCK\" && \
    { gritty protocol-version 2>/dev/null || true; }";
374
375/// Get the remote socket path and optionally auto-start the server.
376/// Returns (socket_path, remote_protocol_version).
377async fn ensure_remote_ready(
378    dest: &Destination,
379    no_server_start: bool,
380    extra_ssh_opts: &[String],
381    foreground: bool,
382) -> anyhow::Result<(String, Option<u16>)> {
383    let remote_cmd = if no_server_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
384    debug!("ensuring remote server (no_server_start={no_server_start})");
385
386    let output = remote_exec(dest, remote_cmd, extra_ssh_opts, foreground).await?;
387
388    // Output is "socket_path\nversion" (version line may be absent for old remotes)
389    let mut lines = output.lines();
390    let sock_path = lines.next().unwrap_or("").to_string();
391    let remote_version = lines.next().and_then(|s| s.trim().parse::<u16>().ok());
392
393    if sock_path.is_empty() {
394        bail!("remote host returned empty socket path");
395    }
396
397    Ok((sock_path, remote_version))
398}
399
400// ---------------------------------------------------------------------------
401// Local socket path
402// ---------------------------------------------------------------------------
403
404/// Compute a deterministic local socket path based on the destination.
405///
406/// Using the raw destination string means re-running `gritty connect user@host`
407/// produces the same socket path, so sessions that used `--ctl-socket` can
408/// auto-reconnect after a tunnel restart.
409fn local_socket_path(destination: &str) -> PathBuf {
410    crate::daemon::socket_dir().join(format!("connect-{destination}.sock"))
411}
412
413fn connect_pid_path(connection_name: &str) -> PathBuf {
414    crate::daemon::socket_dir().join(format!("connect-{connection_name}.pid"))
415}
416
417fn connect_lock_path(connection_name: &str) -> PathBuf {
418    crate::daemon::socket_dir().join(format!("connect-{connection_name}.lock"))
419}
420
421fn connect_dest_path(connection_name: &str) -> PathBuf {
422    crate::daemon::socket_dir().join(format!("connect-{connection_name}.dest"))
423}
424
425/// Compute the local socket path for a given connection name.
426/// Public so main.rs can compute the path in the parent process after daemonize.
427pub fn connection_socket_path(connection_name: &str) -> PathBuf {
428    local_socket_path(connection_name)
429}
430
431/// Extract the host component from a destination string (`[user@]host[:port]`).
432pub fn parse_host(destination: &str) -> anyhow::Result<String> {
433    Ok(Destination::parse(destination)?.host)
434}
435
436/// Synchronous SSH connectivity check -- call before daemonizing to catch
437/// host-key prompts and password prompts while the terminal is still attached.
438pub fn preflight_ssh(dest_str: &str, ssh_options: &[String]) -> anyhow::Result<()> {
439    let dest = Destination::parse(dest_str)?;
440    let mut cmd = std::process::Command::new("ssh");
441    cmd.args(dest.port_args());
442    for opt in ssh_options {
443        cmd.arg("-o");
444        cmd.arg(opt);
445    }
446    cmd.args(["-o", "BatchMode=yes", "-o", "ConnectTimeout=5"]);
447    cmd.arg(dest.ssh_dest());
448    cmd.arg("true");
449    cmd.stdin(std::process::Stdio::null());
450    cmd.stdout(std::process::Stdio::null());
451    cmd.stderr(std::process::Stdio::null());
452    let status = cmd.status().context("failed to run ssh")?;
453    if !status.success() {
454        bail!(
455            "SSH cannot connect non-interactively to {}\n  \
456             if SSH needs a password or host key accept, use: gritty connect --foreground {}",
457            dest.ssh_dest(),
458            dest_str
459        );
460    }
461    Ok(())
462}
463
464// ---------------------------------------------------------------------------
465// Lockfile-based liveness
466// ---------------------------------------------------------------------------
467
468/// Acquire an exclusive flock on the lockfile. Returns the RAII lock on success.
469/// The lock is held for the lifetime of the returned `Flock`.
470fn acquire_lock(lock_path: &Path) -> anyhow::Result<nix::fcntl::Flock<std::fs::File>> {
471    use std::fs::OpenOptions;
472    let file = OpenOptions::new()
473        .create(true)
474        .truncate(false)
475        .write(true)
476        .mode(0o600)
477        .open(lock_path)
478        .with_context(|| format!("failed to open lockfile: {}", lock_path.display()))?;
479    nix::fcntl::Flock::lock(file, nix::fcntl::FlockArg::LockExclusive)
480        .map_err(|(_, e)| e)
481        .with_context(|| format!("failed to acquire lock on {}", lock_path.display()))
482}
483
484/// Probe whether a lockfile is held by a live process.
485/// Returns true if the lock is held (process alive), false if free (process dead).
486fn is_lock_held(lock_path: &Path) -> bool {
487    use std::fs::OpenOptions;
488    let file = match OpenOptions::new().read(true).open(lock_path) {
489        Ok(f) => f,
490        Err(_) => return false,
491    };
492    // Non-blocking exclusive lock attempt: if it succeeds, the old process is dead.
493    // The lock is released immediately when the Flock drops.
494    nix::fcntl::Flock::lock(file, nix::fcntl::FlockArg::LockExclusiveNonblock).is_err()
495}
496
/// Tunnel health status.
///
/// Produced by `probe_tunnel_status` from the lockfile state and a connection
/// attempt to the local socket.
#[derive(Debug, PartialEq, Eq)]
pub enum TunnelStatus {
    /// The lockfile is held and the local socket accepts a connection.
    Healthy,
    /// The lockfile is held but connecting to the local socket fails.
    Reconnecting,
    /// The lockfile is not held (or cannot be opened).
    Stale,
}
504
505/// Probe a tunnel's status using lockfile + socket connectivity.
506fn probe_tunnel_status(name: &str) -> TunnelStatus {
507    let lock_path = connect_lock_path(name);
508    if is_lock_held(&lock_path) {
509        let sock_path = local_socket_path(name);
510        if std::os::unix::net::UnixStream::connect(&sock_path).is_ok() {
511            TunnelStatus::Healthy
512        } else {
513            TunnelStatus::Reconnecting
514        }
515    } else {
516        TunnelStatus::Stale
517    }
518}
519
520/// Clean up files for a stale tunnel (process already dead).
521/// No signals sent — the process is confirmed dead (lockfile released).
522/// Orphaned SSH children self-terminate via ServerAliveInterval/ServerAliveCountMax.
523fn read_pid_hint(name: &str) -> Option<u32> {
524    std::fs::read_to_string(connect_pid_path(name)).ok().and_then(|s| s.trim().parse().ok())
525}
526
527fn cleanup_stale_files(name: &str) {
528    let _ = std::fs::remove_file(local_socket_path(name));
529    let _ = std::fs::remove_file(connect_pid_path(name));
530    let _ = std::fs::remove_file(connect_lock_path(name));
531    let _ = std::fs::remove_file(connect_dest_path(name));
532}
533
534/// Extract tunnel connection names by globbing lock files in the socket dir.
535fn enumerate_tunnels() -> Vec<String> {
536    let dir = crate::daemon::socket_dir();
537    let Ok(entries) = std::fs::read_dir(&dir) else {
538        return Vec::new();
539    };
540    entries
541        .filter_map(|e| e.ok())
542        .filter_map(|e| {
543            let name = e.file_name().to_string_lossy().to_string();
544            if name.starts_with("connect-") && name.ends_with(".lock") {
545                Some(name["connect-".len()..name.len() - ".lock".len()].to_string())
546            } else {
547                None
548            }
549        })
550        .collect()
551}
552
553// ---------------------------------------------------------------------------
554// Cleanup guard
555// ---------------------------------------------------------------------------
556
557struct ConnectGuard {
558    child: Option<Child>,
559    local_sock: PathBuf,
560    pid_file: PathBuf,
561    lock_file: PathBuf,
562    dest_file: PathBuf,
563    _lock: Option<nix::fcntl::Flock<std::fs::File>>,
564    stop: tokio_util::sync::CancellationToken,
565}
566
567impl Drop for ConnectGuard {
568    fn drop(&mut self) {
569        self.stop.cancel();
570
571        if let Some(ref mut child) = self.child
572            && let Some(pid) = child.id()
573        {
574            unsafe {
575                libc::kill(pid as i32, libc::SIGTERM);
576            }
577        }
578
579        let _ = std::fs::remove_file(&self.local_sock);
580        let _ = std::fs::remove_file(&self.pid_file);
581        let _ = std::fs::remove_file(&self.lock_file);
582        let _ = std::fs::remove_file(&self.dest_file);
583        // _lock drops here, releasing the flock
584    }
585}
586
587// ---------------------------------------------------------------------------
588// Public API
589// ---------------------------------------------------------------------------
590
/// Options accepted by [`run`].
pub struct ConnectOpts {
    /// Destination string, parsed as `[user@]host[:port]`.
    pub destination: String,
    /// When set, only ask the remote for its socket path; do not try to
    /// start the remote server.
    pub no_server_start: bool,
    /// Extra SSH options; each is passed to `ssh` as `-o <option>`.
    pub ssh_options: Vec<String>,
    /// Connection name; defaults to the parsed host when `None`.
    pub name: Option<String>,
    /// Print the commands that would be run and return without connecting.
    pub dry_run: bool,
    /// Foreground mode: SSH is run without `BatchMode=yes`.
    pub foreground: bool,
    /// Downgrade a remote/local protocol version mismatch from an error to a warning.
    pub ignore_version_mismatch: bool,
}
600
/// Establish (or report) the SSH tunnel for `opts.destination` and keep it
/// alive until SIGTERM arrives or the tunnel monitor task ends.
///
/// With `opts.dry_run` the equivalent shell commands are printed and nothing
/// is executed. If a tunnel with the same connection name already holds the
/// lockfile, its state is reported and the function returns `Ok(0)` without
/// starting a second one. `ready_fd`, when present, receives the readiness
/// message written by `signal_ready`.
///
/// # Errors
/// Fails on an unparsable destination, an invalid connection name, failure to
/// create the socket directory or take the lock, remote setup failure, a
/// protocol version mismatch (unless ignored), or the tunnel failing to come up.
pub async fn run(opts: ConnectOpts, ready_fd: Option<OwnedFd>) -> anyhow::Result<i32> {
    // Restrict permissions of files created from here on to the owner only.
    unsafe {
        libc::umask(0o077);
    }

    let dest = Destination::parse(&opts.destination)?;
    let connection_name = opts.name.unwrap_or_else(|| dest.host.clone());
    validate_connection_name(&connection_name)?;
    let local_sock = local_socket_path(&connection_name);

    if opts.dry_run {
        // Both commands are built in foreground mode, i.e. without BatchMode=yes.
        let remote_cmd =
            if opts.no_server_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
        let ensure_cmd = remote_exec_command(&dest, remote_cmd, &opts.ssh_options, true);
        let tunnel_cmd =
            tunnel_command(&dest, &local_sock, "$REMOTE_SOCK", &opts.ssh_options, true);

        println!(
            "# Get remote socket path{}",
            if opts.no_server_start { "" } else { " and start server if needed" }
        );
        println!("REMOTE_SOCK=$({})", format_command(&ensure_cmd));
        println!();
        println!("# Start SSH tunnel");
        println!("{}", format_command(&tunnel_cmd));
        return Ok(0);
    }

    // 1. Ensure socket directory exists
    let pid_file = connect_pid_path(&connection_name);
    let lock_path = connect_lock_path(&connection_name);
    let dest_file = connect_dest_path(&connection_name);
    debug!("local socket: {}", local_sock.display());
    if let Some(parent) = local_sock.parent() {
        crate::security::secure_create_dir_all(parent)?;
    }

    // 2. Check for existing tunnel via lockfile (authoritative)
    match probe_tunnel_status(&connection_name) {
        TunnelStatus::Healthy => {
            // The socket path goes to stdout; the explanation goes to stderr.
            println!("{}", local_sock.display());
            let pid_hint = read_pid_hint(&connection_name);
            eprint!("tunnel already running (name: {connection_name})");
            if let Some(pid) = pid_hint {
                eprintln!(" (pid {pid})");
                eprintln!("  to stop: gritty disconnect {connection_name}");
            } else {
                eprintln!();
            }
            eprintln!("  to use:");
            eprintln!("    gritty new {connection_name}");
            eprintln!("    gritty attach {connection_name}:<session>");
            // Signal readiness to parent even for already-running case
            signal_ready(&ready_fd);
            return Ok(0);
        }
        TunnelStatus::Reconnecting => {
            let pid_hint = read_pid_hint(&connection_name);
            eprint!("tunnel exists but is reconnecting (name: {connection_name})");
            if let Some(pid) = pid_hint {
                eprintln!(" (pid {pid})");
            } else {
                eprintln!();
            }
            eprintln!("  wait for it, or: gritty disconnect {connection_name}");
            // Signal readiness to parent so it doesn't hang
            signal_ready(&ready_fd);
            return Ok(0);
        }
        TunnelStatus::Stale => {
            debug!("cleaning stale tunnel files for {connection_name}");
            cleanup_stale_files(&connection_name);
        }
    }

    // 3. Acquire lockfile (held for entire lifetime of this process)
    let lock_fd = acquire_lock(&lock_path)?;

    // 4. Ensure remote server is running and get socket path
    let (remote_sock, remote_version) =
        ensure_remote_ready(&dest, opts.no_server_start, &opts.ssh_options, opts.foreground)
            .await?;
    debug!(remote_sock, ?remote_version, "remote socket path");

    // Check protocol version compatibility
    // (skipped entirely when the remote did not report a version)
    if let Some(rv) = remote_version {
        if rv != crate::protocol::PROTOCOL_VERSION {
            let msg = format!(
                "remote protocol version ({rv}) differs from local ({}); \
                 use --ignore-version-mismatch to connect anyway",
                crate::protocol::PROTOCOL_VERSION
            );
            if opts.ignore_version_mismatch {
                warn!("{msg}");
            } else {
                bail!("{msg}");
            }
        }
    }

    // 5. Spawn SSH tunnel
    let child =
        spawn_tunnel(&dest, &local_sock, &remote_sock, &opts.ssh_options, opts.foreground).await?;
    let stop = tokio_util::sync::CancellationToken::new();

    // From here on, dropping `guard` removes the files and releases the lock.
    let mut guard = ConnectGuard {
        child: Some(child),
        local_sock: local_sock.clone(),
        pid_file: pid_file.clone(),
        lock_file: lock_path,
        dest_file: dest_file.clone(),
        _lock: Some(lock_fd),
        stop: stop.clone(),
    };

    // 6. Wait for local socket to become connectable (race against child exit)
    // The child is taken out of the guard so it can be awaited here; it is put
    // back only if the socket came up.
    let mut child = guard.child.take().unwrap();
    tokio::select! {
        result = wait_for_socket(&local_sock, Duration::from_secs(15)) => {
            result?;
            guard.child = Some(child);
        }
        status = child.wait() => {
            // SSH exited before the socket became connectable: build an error
            // from its captured stderr, falling back to the exit status.
            let status = status.context("failed to wait on ssh tunnel")?;
            let diag = format_ssh_diag(&dest, &opts.ssh_options, opts.foreground);
            let msg = if let Some(mut stderr) = child.stderr.take() {
                use tokio::io::AsyncReadExt;
                let mut buf = String::new();
                let _ = stderr.read_to_string(&mut buf).await;
                let buf = buf.trim().to_string();
                if buf.is_empty() { None } else { Some(buf) }
            } else {
                None
            };
            let fg_hint = if opts.foreground {
                String::new()
            } else {
                format!("\n  if SSH needs a password or host key accept, use: gritty connect --foreground {}", opts.destination)
            };
            match msg {
                Some(err) => bail!("ssh tunnel failed: {err}\n  to diagnose: {diag}{fg_hint}"),
                None => bail!("ssh tunnel exited ({status})\n  to diagnose: {diag}{fg_hint}"),
            }
        }
    }
    debug!("tunnel socket ready");

    // Write PID + dest files
    // (best-effort: write errors are ignored)
    let _ = std::fs::write(&pid_file, std::process::id().to_string());
    let _ = std::fs::write(&dest_file, &opts.destination);

    // 7. Signal readiness to parent (or print if foreground)
    signal_ready(&ready_fd);

    // 8. Hand off the child to the tunnel monitor background task
    let original_child = guard.child.take().unwrap();
    let monitor_handle = tokio::spawn(tunnel_monitor(
        original_child,
        dest,
        local_sock.clone(),
        remote_sock,
        opts.ssh_options,
        stop.clone(),
    ));

    // 9. Wait for signal or monitor death
    let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
    tokio::select! {
        _ = sigterm.recv() => {}
        _ = monitor_handle => {}
    }

    // 10. Cleanup (guard Drop handles ssh kill + file removal + lock release)
    drop(guard);

    Ok(0)
}
778
779/// Write readiness signal to the pipe fd: [0x01][pid: u32 LE].
780fn signal_ready(ready_fd: &Option<OwnedFd>) {
781    if let Some(fd) = ready_fd {
782        let pid = std::process::id();
783        let mut buf = [0u8; 5];
784        buf[0] = 0x01;
785        buf[1..5].copy_from_slice(&pid.to_le_bytes());
786        let _ = nix::unistd::write(fd, &buf);
787    }
788}
789
790// ---------------------------------------------------------------------------
791// Disconnect
792// ---------------------------------------------------------------------------
793
794pub async fn disconnect(name: &str) -> anyhow::Result<()> {
795    validate_connection_name(name)?;
796    match probe_tunnel_status(name) {
797        TunnelStatus::Stale => {
798            cleanup_stale_files(name);
799            eprintln!("tunnel already stopped: {name}");
800            return Ok(());
801        }
802        TunnelStatus::Healthy | TunnelStatus::Reconnecting => {}
803    }
804
805    // Read PID and send SIGTERM (let the process handle graceful shutdown)
806    let pid_file = connect_pid_path(name);
807    let pid: i32 = std::fs::read_to_string(&pid_file)
808        .ok()
809        .and_then(|s| s.trim().parse::<u32>().ok())
810        .map(|p| p as i32)
811        .ok_or_else(|| anyhow::anyhow!("cannot read PID for tunnel {name}"))?;
812
813    let lock_path = connect_lock_path(name);
814    if !is_lock_held(&lock_path) {
815        cleanup_stale_files(name);
816        eprintln!("tunnel already stopped: {name}");
817        return Ok(());
818    }
819    unsafe {
820        libc::kill(pid, libc::SIGTERM);
821    }
822
823    // Poll lock for up to 2s to confirm exit
824    let deadline = Instant::now() + Duration::from_secs(2);
825    loop {
826        tokio::time::sleep(Duration::from_millis(100)).await;
827        if !is_lock_held(&lock_path) {
828            cleanup_stale_files(name);
829            eprintln!("tunnel stopped: {name}");
830            return Ok(());
831        }
832        if Instant::now() >= deadline {
833            break;
834        }
835    }
836
837    // Still alive after timeout — escalate to SIGKILL + killpg
838    if is_lock_held(&lock_path) {
839        unsafe {
840            libc::kill(pid, libc::SIGKILL);
841            libc::killpg(pid, libc::SIGTERM);
842        }
843    }
844    tokio::time::sleep(Duration::from_millis(100)).await;
845    cleanup_stale_files(name);
846    eprintln!("tunnel killed: {name}");
847    Ok(())
848}
849
850// ---------------------------------------------------------------------------
851// List tunnels
852// ---------------------------------------------------------------------------
853
/// Snapshot of one live tunnel, as reported by `get_tunnel_info`.
#[derive(Debug, Clone)]
pub struct TunnelInfo {
    /// Connection name of the tunnel.
    pub name: String,
    /// Trimmed contents of the tunnel's destination file, or `"-"` when that
    /// file could not be read.
    pub destination: String,
    /// Human-readable state: `"healthy"` or `"reconnecting"`.
    pub status: String,
    /// PID of the tunnel process, when one could be determined.
    pub pid: Option<u32>,
    /// Location of the tunnel's log file (`connect-<name>.log` in the socket
    /// directory).
    pub log_path: PathBuf,
}
861
862/// Gather info for all live tunnels (cleans stale ones as a side effect).
863pub fn get_tunnel_info() -> Vec<TunnelInfo> {
864    let names = enumerate_tunnels();
865    let mut infos = Vec::new();
866    for name in &names {
867        let status = probe_tunnel_status(name);
868        if status == TunnelStatus::Stale {
869            debug!("cleaning stale tunnel: {name}");
870            cleanup_stale_files(name);
871            continue;
872        }
873        let dest =
874            std::fs::read_to_string(connect_dest_path(name)).unwrap_or_else(|_| "-".to_string());
875        let status_str = match status {
876            TunnelStatus::Healthy => "healthy".to_string(),
877            TunnelStatus::Reconnecting => "reconnecting".to_string(),
878            TunnelStatus::Stale => unreachable!(),
879        };
880        infos.push(TunnelInfo {
881            name: name.clone(),
882            destination: dest.trim().to_string(),
883            status: status_str,
884            pid: read_pid_hint(name),
885            log_path: crate::daemon::socket_dir().join(format!("connect-{name}.log")),
886        });
887    }
888    infos
889}
890
891pub fn list_tunnels() {
892    let infos = get_tunnel_info();
893    if infos.is_empty() {
894        println!("no active tunnels");
895        return;
896    }
897
898    let rows: Vec<Vec<String>> = infos
899        .iter()
900        .map(|i| vec![i.name.clone(), i.destination.clone(), i.status.clone()])
901        .collect();
902    crate::table::print_table(&["Name", "Destination", "Status"], &rows);
903}
904
905// ---------------------------------------------------------------------------
906// Tests
907// ---------------------------------------------------------------------------
908
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Shared helpers
    // -----------------------------------------------------------------------

    /// Parse `input`, panicking if the parser rejects it.
    fn dest(input: &str) -> Destination {
        Destination::parse(input).unwrap()
    }

    /// Assert that `input` parses into exactly these components.
    fn assert_parses(input: &str, user: Option<&str>, host: &str, port: Option<u16>) {
        let parsed = dest(input);
        assert_eq!(parsed.user.as_deref(), user);
        assert_eq!(parsed.host, host);
        assert_eq!(parsed.port, port);
    }

    /// Assert that the parser refuses `input`.
    fn assert_rejected(input: &str) {
        assert!(Destination::parse(input).is_err());
    }

    /// Collect a command's arguments as owned strings.
    fn argv(cmd: &Command) -> Vec<String> {
        cmd.as_std().get_args().map(|a| a.to_string_lossy().to_string()).collect()
    }

    /// True when `wanted` appears as a whole argument.
    fn has_arg(args: &[String], wanted: &str) -> bool {
        args.iter().any(|a| a == wanted)
    }

    /// Final path component as an owned string.
    fn file_name_of(path: &Path) -> String {
        path.file_name().unwrap().to_string_lossy().into_owned()
    }

    /// Spawn `sh -c <script>`.
    fn spawn_sh(script: &str) -> Child {
        Command::new("sh").arg("-c").arg(script).spawn().unwrap()
    }

    /// Cancel `token` from a background task after `delay_ms` milliseconds.
    fn cancel_after(token: tokio_util::sync::CancellationToken, delay_ms: u64) {
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            token.cancel();
        });
    }

    /// Run `tunnel_monitor` on `child` against a fake host and report whether
    /// it returned within five seconds.
    async fn monitor_finishes(child: Child, stop: tokio_util::sync::CancellationToken) -> bool {
        let target = dest("fake-host-test");
        let monitor = tunnel_monitor(
            child,
            target,
            PathBuf::from("/tmp/nonexistent.sock"),
            "/tmp/remote.sock".into(),
            vec![],
            stop,
        );
        tokio::time::timeout(Duration::from_secs(5), monitor).await.is_ok()
    }

    // -----------------------------------------------------------------------
    // Destination parsing
    // -----------------------------------------------------------------------

    #[test]
    fn parse_destination_user_host() {
        assert_parses("user@host", Some("user"), "host", None);
    }

    #[test]
    fn parse_destination_host_only() {
        assert_parses("myhost", None, "myhost", None);
    }

    #[test]
    fn parse_destination_host_port() {
        assert_parses("host:2222", None, "host", Some(2222));
    }

    #[test]
    fn parse_destination_user_host_port() {
        assert_parses("user@host:2222", Some("user"), "host", Some(2222));
    }

    #[test]
    fn parse_destination_invalid_empty() {
        assert_rejected("");
    }

    #[test]
    fn parse_destination_invalid_at_only() {
        assert_rejected("@host");
    }

    #[test]
    fn parse_destination_invalid_colon_only() {
        assert_rejected(":2222");
    }

    #[test]
    fn parse_destination_ipv6_bracketed() {
        assert_parses("[::1]", None, "::1", None);
    }

    #[test]
    fn parse_destination_ipv6_bracketed_port() {
        assert_parses("[::1]:2222", None, "::1", Some(2222));
    }

    #[test]
    fn parse_destination_ipv6_user_bracketed_port() {
        assert_parses("user@[fe80::1]:22", Some("user"), "fe80::1", Some(22));
    }

    #[test]
    fn parse_destination_bare_ipv6() {
        assert_parses("::1", None, "::1", None);
    }

    #[test]
    fn parse_destination_bare_ipv6_with_scope() {
        assert_parses("fe80::1", None, "fe80::1", None);
    }

    #[test]
    fn parse_destination_ipv6_unclosed_bracket() {
        assert_rejected("[::1");
    }

    // -----------------------------------------------------------------------
    // Command construction
    // -----------------------------------------------------------------------

    #[test]
    fn tunnel_command_default_opts() {
        let target = dest("user@host");
        let cmd = tunnel_command(
            &target,
            Path::new("/tmp/local.sock"),
            "/run/user/1000/gritty/ctl.sock",
            &[],
            false,
        );
        let args = argv(&cmd);
        let expected = [
            // From base_ssh_args
            "ConnectTimeout=5",
            "BatchMode=yes",
            // From TUNNEL_SSH_OPTS
            "ServerAliveInterval=3",
            "StreamLocalBindUnlink=yes",
            "ExitOnForwardFailure=yes",
            "ControlPath=none",
            "ForwardAgent=no",
            "ForwardX11=no",
            // Tunnel flags and forward
            "-N",
            "-T",
            "/tmp/local.sock:/run/user/1000/gritty/ctl.sock",
            "user@host",
        ];
        for wanted in expected {
            assert!(has_arg(&args, wanted));
        }
    }

    #[test]
    fn tunnel_command_extra_opts() {
        let target = dest("host:2222");
        let cmd = tunnel_command(
            &target,
            Path::new("/tmp/local.sock"),
            "/tmp/remote.sock",
            &["ProxyJump=bastion".to_string()],
            false,
        );
        let args = argv(&cmd);
        for wanted in ["ProxyJump=bastion", "-p", "2222"] {
            assert!(has_arg(&args, wanted));
        }
    }

    #[test]
    fn local_socket_path_format() {
        // With hostname-based naming, connect uses just the host part
        assert_eq!(file_name_of(&local_socket_path("devbox")), "connect-devbox.sock");
        assert_eq!(file_name_of(&local_socket_path("example.com")), "connect-example.com.sock");

        // Custom name override
        assert_eq!(file_name_of(&local_socket_path("myproject")), "connect-myproject.sock");
    }

    #[test]
    fn connect_pid_path_format() {
        assert_eq!(file_name_of(&connect_pid_path("devbox")), "connect-devbox.pid");
        assert_eq!(file_name_of(&connect_pid_path("example.com")), "connect-example.com.pid");
    }

    #[test]
    fn ssh_dest_with_user() {
        let target = dest("alice@example.com");
        assert_eq!(target.ssh_dest(), "alice@example.com");
    }

    #[test]
    fn ssh_dest_without_user() {
        let target = dest("example.com");
        assert_eq!(target.ssh_dest(), "example.com");
    }

    #[test]
    fn port_args_with_port() {
        let target = dest("host:9999");
        assert_eq!(target.port_args(), vec!["-p", "9999"]);
    }

    #[test]
    fn port_args_without_port() {
        let target = dest("host");
        assert!(target.port_args().is_empty());
    }

    // -----------------------------------------------------------------------
    // Shell quoting and command formatting
    // -----------------------------------------------------------------------

    #[test]
    fn shell_quote_simple() {
        // Safe strings pass through unquoted
        for plain in ["hello", "-N"] {
            assert_eq!(shell_quote(plain), plain);
        }
        // shlex quotes some chars the old impl didn't (=, @, :, $), which is correct,
        // so only check that the payload survives.
        let cases = [
            ("ServerAliveInterval=3", ["ServerAliveInterval", "3"]),
            ("user@host", ["user", "host"]),
        ];
        for (input, fragments) in cases {
            let quoted = shell_quote(input);
            assert!(fragments.iter().all(|part| quoted.contains(*part)));
        }
    }

    #[test]
    fn shell_quote_needs_quoting() {
        let spaced = shell_quote("hello world");
        assert!(spaced.starts_with('\'') || spaced.starts_with('"'));
        assert!(spaced.contains("hello world"));

        assert_eq!(shell_quote(""), "''");

        let apostrophe = shell_quote("it's");
        assert!(apostrophe.contains("it") && apostrophe.contains("s"));
    }

    #[test]
    fn shell_quote_remote_cmd() {
        // The wrapped remote command contains spaces, quotes, semicolons —
        // must be single-quoted so $HOME expands on the remote side.
        let cmd = format!("PATH=\"{REMOTE_PATH_PREFIX}\"; gritty socket-path");
        let quoted = shell_quote(&cmd);
        assert!(quoted.starts_with('\''));
        assert!(quoted.ends_with('\''));
    }

    #[test]
    fn format_command_tunnel() {
        let target = dest("user@host");
        let cmd = tunnel_command(&target, Path::new("/tmp/local.sock"), "$REMOTE_SOCK", &[], true);
        let formatted = format_command(&cmd);
        let expected = [
            "ServerAliveInterval=3",
            "ControlPath=none",
            "ForwardAgent=no",
            "-N",
            "-T",
            // Forward arg references $REMOTE_SOCK unquoted (no spaces, $ is safe)
            "/tmp/local.sock:$REMOTE_SOCK",
            "user@host",
        ];
        for fragment in expected {
            assert!(formatted.contains(fragment));
        }
    }

    #[test]
    fn format_command_remote_exec() {
        let target = dest("user@host:2222");
        let cmd = remote_exec_command(&target, "gritty socket-path", &[], true);
        let formatted = format_command(&cmd);
        assert!(formatted.starts_with("ssh "));
        for fragment in ["-p 2222", "ConnectTimeout=5", "user@host"] {
            assert!(formatted.contains(fragment));
        }
        // The wrapped command should be single-quoted (contains spaces)
        assert!(formatted.contains(&format!("PATH=\"{REMOTE_PATH_PREFIX}\"")));
    }

    #[test]
    fn format_command_remote_exec_with_extra_opts() {
        let target = dest("user@host");
        let cmd = remote_exec_command(
            &target,
            REMOTE_ENSURE_CMD,
            &["ProxyJump=bastion".to_string()],
            true,
        );
        let formatted = format_command(&cmd);
        for fragment in ["ProxyJump=bastion", "gritty socket-path", "gritty server"] {
            assert!(formatted.contains(fragment));
        }
    }

    #[test]
    fn base_ssh_args_foreground() {
        let target = dest("user@host:2222");
        let args = base_ssh_args(&target, &["ProxyJump=bastion".into()], true);
        for wanted in ["-p", "2222", "ProxyJump=bastion", "ConnectTimeout=5"] {
            assert!(has_arg(&args, wanted));
        }
        assert!(!has_arg(&args, "BatchMode=yes"));
    }

    #[test]
    fn base_ssh_args_background() {
        let target = dest("host");
        let args = base_ssh_args(&target, &[], false);
        assert!(has_arg(&args, "ConnectTimeout=5"));
        assert!(has_arg(&args, "BatchMode=yes"));
        assert!(!has_arg(&args, "-p"));
    }

    // -----------------------------------------------------------------------
    // Lockfile and tunnel lifecycle tests
    // -----------------------------------------------------------------------

    #[test]
    fn connect_lock_path_format() {
        assert_eq!(file_name_of(&connect_lock_path("devbox")), "connect-devbox.lock");
    }

    #[test]
    fn connect_dest_path_format() {
        assert_eq!(file_name_of(&connect_dest_path("devbox")), "connect-devbox.dest");
    }

    #[test]
    fn acquire_and_probe_lock() {
        let dir = tempfile::tempdir().unwrap();
        let lock_path = dir.path().join("test.lock");

        // Lock not held initially (file doesn't exist)
        assert!(!is_lock_held(&lock_path));

        {
            // Held for exactly as long as the handle is alive
            let _held = acquire_lock(&lock_path).unwrap();
            assert!(is_lock_held(&lock_path));
        }

        // Handle dropped: should be free again
        assert!(!is_lock_held(&lock_path));
    }

    #[test]
    fn probe_stale_no_files() {
        // No files at all → stale
        assert_eq!(probe_tunnel_status("nonexistent-test-tunnel-xyz"), TunnelStatus::Stale);
    }

    #[test]
    fn cleanup_stale_files_removes_all() {
        let _dir = tempfile::tempdir().unwrap();
        // We can't easily override socket_dir(), so test that cleanup_stale_files
        // at least doesn't panic on nonexistent files
        cleanup_stale_files("nonexistent-cleanup-test-xyz");
        // No panic = success
    }

    #[test]
    fn enumerate_tunnels_empty_dir() {
        // We can't control what's in socket_dir during tests, so the only
        // thing checked is that enumeration does not panic on whatever
        // filesystem state it finds.
        let _ = enumerate_tunnels();
    }

    #[test]
    fn connection_socket_path_matches_local() {
        assert_eq!(connection_socket_path("myhost"), local_socket_path("myhost"));
    }

    // -----------------------------------------------------------------------
    // tunnel_monitor tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn tunnel_monitor_non_transient_exit() {
        let child = spawn_sh("exit 1");
        let stop = tokio_util::sync::CancellationToken::new();

        let finished = monitor_finishes(child, stop).await;
        assert!(finished, "monitor should return quickly on non-transient exit");
    }

    #[tokio::test]
    async fn tunnel_monitor_cancellation() {
        let child = Command::new("sleep").arg("60").spawn().unwrap();
        let stop = tokio_util::sync::CancellationToken::new();
        cancel_after(stop.clone(), 100);

        let finished = monitor_finishes(child, stop).await;
        assert!(finished, "monitor should return after cancellation");
    }

    #[tokio::test]
    async fn tunnel_monitor_transient_exit_checks_cancellation() {
        // Child exits with 255 (transient). Monitor sleeps 1s then checks cancellation.
        let child = spawn_sh("exit 255");
        let stop = tokio_util::sync::CancellationToken::new();

        // Cancel during the 1s sleep between exit and respawn attempt
        cancel_after(stop.clone(), 500);

        let finished = monitor_finishes(child, stop).await;
        assert!(finished, "monitor should return after cancellation during sleep");
    }

    // -----------------------------------------------------------------------
    // wait_for_socket tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn wait_for_socket_succeeds_after_delay() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("delayed.sock");
        let bind_path = sock_path.clone();

        // Bind the socket after 500ms
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(500)).await;
            let _listener = tokio::net::UnixListener::bind(&bind_path).unwrap();
            // Keep listener alive
            tokio::time::sleep(Duration::from_secs(30)).await;
        });

        let outcome = wait_for_socket(&sock_path, Duration::from_secs(5)).await;
        assert!(outcome.is_ok(), "should successfully connect");
    }

    #[tokio::test]
    async fn wait_for_socket_timeout() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("never.sock");
        let outcome = wait_for_socket(&sock_path, Duration::from_secs(1)).await;
        assert!(outcome.is_err());
    }
}