Skip to main content

gritty/
connect.rs

1use anyhow::{Context, bail};
2use std::os::fd::OwnedFd;
3use std::os::unix::fs::OpenOptionsExt;
4use std::os::unix::io::AsRawFd;
5use std::path::{Path, PathBuf};
6use std::process::Stdio;
7use std::time::{Duration, Instant};
8use tokio::process::{Child, Command};
9use tracing::{debug, info, warn};
10
11// ---------------------------------------------------------------------------
12// Destination parsing
13// ---------------------------------------------------------------------------
14
/// A connection target parsed from a `[user@]host[:port]` string.
#[derive(Clone, Debug, Eq, PartialEq)]
struct Destination {
    /// Login name, present only when the input contained a `user@` part.
    user: Option<String>,
    /// Host part of the input.
    host: String,
    /// Port taken from a trailing `:port`, if one was given.
    port: Option<u16>,
}
21
22impl Destination {
23    fn parse(s: &str) -> anyhow::Result<Self> {
24        if s.is_empty() {
25            bail!("empty destination");
26        }
27
28        let (user, remainder) = if let Some(at) = s.find('@') {
29            let u = &s[..at];
30            if u.is_empty() {
31                bail!("empty user in destination: {s}");
32            }
33            (Some(u.to_string()), &s[at + 1..])
34        } else {
35            (None, s)
36        };
37
38        let (host, port) = if let Some(colon) = remainder.rfind(':') {
39            let h = &remainder[..colon];
40            let p = remainder[colon + 1..]
41                .parse::<u16>()
42                .with_context(|| format!("invalid port in destination: {s}"))?;
43            (h.to_string(), Some(p))
44        } else {
45            (remainder.to_string(), None)
46        };
47
48        if host.is_empty() {
49            bail!("empty host in destination: {s}");
50        }
51
52        Ok(Self { user, host, port })
53    }
54
55    /// Build the SSH destination string (`user@host` or just `host`).
56    fn ssh_dest(&self) -> String {
57        match &self.user {
58            Some(u) => format!("{u}@{}", self.host),
59            None => self.host.clone(),
60        }
61    }
62
63    /// Common SSH args for port, if set.
64    fn port_args(&self) -> Vec<String> {
65        match self.port {
66            Some(p) => vec!["-p".to_string(), p.to_string()],
67            None => vec![],
68        }
69    }
70}
71
72/// Reject connection names that could cause path traversal or corruption.
73fn validate_connection_name(name: &str) -> anyhow::Result<()> {
74    if name.is_empty() {
75        bail!("connection name must not be empty");
76    }
77    if name.contains('/') || name.contains('\0') || name.contains("..") {
78        bail!("invalid connection name: {name:?}");
79    }
80    Ok(())
81}
82
83// ---------------------------------------------------------------------------
84// SSH helpers
85// ---------------------------------------------------------------------------
86
/// Tunnel-specific SSH `-o` options (keepalive, cleanup, failure behavior).
///
/// `tunnel_command` passes each entry to `ssh` as its own `-o <entry>` pair,
/// after the arguments produced by `base_ssh_args`.
const TUNNEL_SSH_OPTS: &[&str] = &[
    // Keepalive probing, so a dead connection makes ssh exit instead of hanging.
    "ServerAliveInterval=3",
    "ServerAliveCountMax=2",
    // Replace a leftover socket file instead of failing to bind the forward.
    "StreamLocalBindUnlink=yes",
    // Exit if the forward cannot be set up rather than running without it.
    "ExitOnForwardFailure=yes",
    // Prevent user config from leaking forwarding or connection sharing
    // into the tunnel (gritty handles agent forwarding separately).
    "ControlPath=none",
    "ForwardAgent=no",
    "ForwardX11=no",
];
99
/// `PATH` value exported in front of every remote command so that `gritty`
/// can be found in non-interactive SSH shells.
///
/// Common per-user and system install directories come first, followed by the
/// remote shell's existing `$PATH`.
const REMOTE_PATH_PREFIX: &str = concat!(
    "$HOME/bin:",
    "$HOME/.local/bin:",
    "$HOME/.cargo/bin:",
    "/usr/local/bin:",
    "/opt/homebrew/bin:",
    "$PATH",
);
104
105/// Build the common SSH args that precede the destination in every invocation:
106/// port, user-supplied options, ConnectTimeout, and BatchMode (background only).
107fn base_ssh_args(dest: &Destination, extra_ssh_opts: &[String], foreground: bool) -> Vec<String> {
108    let mut args = Vec::new();
109    args.extend(dest.port_args());
110    for opt in extra_ssh_opts {
111        args.push("-o".into());
112        args.push(opt.clone());
113    }
114    args.push("-o".into());
115    args.push("ConnectTimeout=5".into());
116    if !foreground {
117        args.push("-o".into());
118        args.push("BatchMode=yes".into());
119    }
120    args
121}
122
123/// Build the SSH command for remote execution (without stdio config).
124fn remote_exec_command(
125    dest: &Destination,
126    remote_cmd: &str,
127    extra_ssh_opts: &[String],
128    foreground: bool,
129) -> Command {
130    let wrapped_cmd = format!("PATH=\"{REMOTE_PATH_PREFIX}\"; {remote_cmd}");
131    let mut cmd = Command::new("ssh");
132    cmd.args(base_ssh_args(dest, extra_ssh_opts, foreground));
133    cmd.arg(dest.ssh_dest());
134    cmd.arg(&wrapped_cmd);
135    cmd
136}
137
138/// Run a command on the remote host via SSH, returning stdout.
139///
140/// Stderr is always piped so we can include SSH errors in our error messages.
141/// SSH interactive prompts use `/dev/tty` directly, not stderr.
142/// In background mode, `BatchMode=yes` is set so SSH fails fast instead of hanging.
143async fn remote_exec(
144    dest: &Destination,
145    remote_cmd: &str,
146    extra_ssh_opts: &[String],
147    foreground: bool,
148) -> anyhow::Result<String> {
149    debug!("ssh {}: {remote_cmd}", dest.ssh_dest());
150
151    let mut cmd = remote_exec_command(dest, remote_cmd, extra_ssh_opts, foreground);
152    cmd.stdout(Stdio::piped());
153    cmd.stderr(Stdio::piped());
154    cmd.stdin(Stdio::null());
155
156    let output = cmd.output().await.context("failed to run ssh")?;
157
158    if !output.status.success() {
159        let stderr = String::from_utf8_lossy(&output.stderr);
160        let stderr = stderr.trim();
161        debug!("ssh failed (status {}): {stderr}", output.status);
162        if stderr.contains("command not found") || stderr.contains("No such file") {
163            bail!("gritty not found on remote host (is it in PATH?)");
164        }
165        let diag = format_ssh_diag(dest, extra_ssh_opts, foreground);
166        if stderr.is_empty() {
167            bail!("ssh command failed (exit {})\n  to diagnose: {diag}", output.status);
168        }
169        bail!("ssh command failed: {stderr}\n  to diagnose: {diag}");
170    }
171
172    let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
173    debug!("ssh output: {stdout}");
174    Ok(stdout)
175}
176
177/// Format a diagnostic SSH command for display in error messages.
178/// Mirrors `base_ssh_args` so the suggestion matches what was actually run.
179fn format_ssh_diag(dest: &Destination, extra_ssh_opts: &[String], foreground: bool) -> String {
180    let mut parts = vec!["ssh".to_string()];
181    for arg in base_ssh_args(dest, extra_ssh_opts, foreground) {
182        parts.push(shell_quote(&arg));
183    }
184    parts.push(dest.ssh_dest());
185    parts.join(" ")
186}
187
/// Quote `s` for display in a POSIX shell command line.
///
/// Strings made only of ASCII alphanumerics and `-_./=:@$+%,` are returned
/// unchanged (`$` is in that set, so placeholders such as `$REMOTE_SOCK` stay
/// unquoted). The empty string becomes `''`. Anything else is wrapped in
/// single quotes, with embedded single quotes written as `'\''`.
///
/// Display only (dry-run output and diagnostic hints); never used to build a
/// command that is executed.
fn shell_quote(s: &str) -> String {
    const SAFE_PUNCTUATION: &[u8] = b"-_./=:@$+%,";
    let is_safe = |byte: u8| byte.is_ascii_alphanumeric() || SAFE_PUNCTUATION.contains(&byte);

    match s {
        "" => "''".to_string(),
        plain if plain.bytes().all(is_safe) => plain.to_string(),
        other => format!("'{}'", other.replace('\'', "'\\''")),
    }
}
199
200/// Format a tokio Command as a shell string for display.
201fn format_command(cmd: &Command) -> String {
202    let std_cmd = cmd.as_std();
203    let prog = std_cmd.get_program().to_string_lossy();
204    let args: Vec<_> = std_cmd.get_args().map(|a| shell_quote(&a.to_string_lossy())).collect();
205    if args.is_empty() { prog.to_string() } else { format!("{prog} {}", args.join(" ")) }
206}
207
208/// Build the SSH tunnel command with hardened options.
209///
210/// Stderr is always piped so we can capture SSH errors on failure.
211/// (SSH interactive prompts use `/dev/tty` directly, not stderr.)
212fn tunnel_command(
213    dest: &Destination,
214    local_sock: &Path,
215    remote_sock: &str,
216    extra_ssh_opts: &[String],
217    foreground: bool,
218) -> Command {
219    let mut cmd = Command::new("ssh");
220    cmd.args(base_ssh_args(dest, extra_ssh_opts, foreground));
221    for opt in TUNNEL_SSH_OPTS {
222        cmd.arg("-o").arg(opt);
223    }
224    cmd.args(["-N", "-T"]);
225    let forward = format!("{}:{}", local_sock.display(), remote_sock);
226    cmd.arg("-L").arg(forward);
227    cmd.arg(dest.ssh_dest());
228    cmd.stdout(Stdio::null());
229    cmd.stderr(Stdio::piped());
230    cmd.stdin(Stdio::null());
231    cmd
232}
233
234/// Spawn the SSH tunnel, returning the child process.
235async fn spawn_tunnel(
236    dest: &Destination,
237    local_sock: &Path,
238    remote_sock: &str,
239    extra_ssh_opts: &[String],
240    foreground: bool,
241) -> anyhow::Result<Child> {
242    debug!("tunnel: {} -> {}:{}", local_sock.display(), dest.ssh_dest(), remote_sock,);
243    let mut cmd = tunnel_command(dest, local_sock, remote_sock, extra_ssh_opts, foreground);
244    let child = cmd.spawn().context("failed to spawn ssh tunnel")?;
245    debug!("ssh tunnel pid: {:?}", child.id());
246    Ok(child)
247}
248
249/// Poll until the local socket is connectable (200ms interval).
250async fn wait_for_socket(path: &Path, timeout: Duration) -> anyhow::Result<()> {
251    let deadline = Instant::now() + timeout;
252    loop {
253        if std::os::unix::net::UnixStream::connect(path).is_ok() {
254            return Ok(());
255        }
256        if Instant::now() >= deadline {
257            bail!("timeout waiting for SSH tunnel socket at {}", path.display());
258        }
259        tokio::time::sleep(Duration::from_millis(200)).await;
260    }
261}
262
263/// Background task: monitor SSH child, respawn on transient failure.
264async fn tunnel_monitor(
265    mut child: Child,
266    dest: Destination,
267    local_sock: PathBuf,
268    remote_sock: String,
269    extra_ssh_opts: Vec<String>,
270    stop: tokio_util::sync::CancellationToken,
271) {
272    let mut exit_times: Vec<Instant> = Vec::new();
273
274    loop {
275        tokio::select! {
276            _ = stop.cancelled() => {
277                let _ = child.kill().await;
278                return;
279            }
280            status = child.wait() => {
281                let status = match status {
282                    Ok(s) => s,
283                    Err(e) => {
284                        warn!("failed to wait on ssh tunnel: {e}");
285                        return;
286                    }
287                };
288
289                if stop.is_cancelled() {
290                    return;
291                }
292
293                let code = status.code();
294                debug!("ssh tunnel exited: {:?}", code);
295
296                // Non-transient failure: don't retry
297                // SSH exit 255 = connection error (transient). Signal-killed = no code.
298                // Everything else (auth failure, config error) = bail.
299                if let Some(c) = code
300                    && c != 255
301                {
302                    warn!("ssh tunnel exited with code {c} (not retrying)");
303                    return;
304                }
305
306                // Rate limit: 5 exits in 10s = give up
307                let now = Instant::now();
308                exit_times.push(now);
309                exit_times.retain(|t| now.duration_since(*t) < Duration::from_secs(10));
310                if exit_times.len() >= 5 {
311                    warn!("ssh tunnel failing too fast (5 exits in 10s), giving up");
312                    return;
313                }
314
315                tokio::time::sleep(Duration::from_secs(1)).await;
316
317                if stop.is_cancelled() {
318                    return;
319                }
320
321                match spawn_tunnel(&dest, &local_sock, &remote_sock, &extra_ssh_opts, false).await {
322                    Ok(new_child) => {
323                        info!("ssh tunnel respawned");
324                        child = new_child;
325                    }
326                    Err(e) => {
327                        warn!("failed to respawn ssh tunnel: {e}");
328                        return;
329                    }
330                }
331            }
332        }
333    }
334}
335
336// ---------------------------------------------------------------------------
337// Remote server management
338// ---------------------------------------------------------------------------
339
/// Remote shell snippet that resolves the socket path, starts the server if
/// `gritty ls` fails, prints the socket path, and then prints the protocol
/// version when the remote `gritty` supports that subcommand.
///
/// The `|| true` fallback is grouped with the `protocol-version` call only.
/// Shell `&&` and `||` have equal precedence and associate left to right, so
/// an ungrouped trailing `|| true` would also swallow a failure of any earlier
/// step (e.g. `gritty` missing). The whole snippet would then exit 0 with
/// empty output.
const REMOTE_ENSURE_CMD: &str = "\
    SOCK=$(gritty socket-path) && \
    (gritty ls >/dev/null 2>&1 || \
     { gritty server && sleep 0.3; }) && \
    echo \"$SOCK\" && \
    { gritty protocol-version 2>/dev/null || true; }";
346
347/// Get the remote socket path and optionally auto-start the server.
348/// Returns (socket_path, remote_protocol_version).
349async fn ensure_remote_ready(
350    dest: &Destination,
351    no_server_start: bool,
352    extra_ssh_opts: &[String],
353    foreground: bool,
354) -> anyhow::Result<(String, Option<u16>)> {
355    let remote_cmd = if no_server_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
356    debug!("ensuring remote server (no_server_start={no_server_start})");
357
358    let output = remote_exec(dest, remote_cmd, extra_ssh_opts, foreground).await?;
359
360    // Output is "socket_path\nversion" (version line may be absent for old remotes)
361    let mut lines = output.lines();
362    let sock_path = lines.next().unwrap_or("").to_string();
363    let remote_version = lines.next().and_then(|s| s.trim().parse::<u16>().ok());
364
365    if sock_path.is_empty() {
366        bail!("remote host returned empty socket path");
367    }
368
369    Ok((sock_path, remote_version))
370}
371
372// ---------------------------------------------------------------------------
373// Local socket path
374// ---------------------------------------------------------------------------
375
376/// Compute a deterministic local socket path based on the destination.
377///
378/// Using the raw destination string means re-running `gritty connect user@host`
379/// produces the same socket path, so sessions that used `--ctl-socket` can
380/// auto-reconnect after a tunnel restart.
381fn local_socket_path(destination: &str) -> PathBuf {
382    crate::daemon::socket_dir().join(format!("connect-{destination}.sock"))
383}
384
385fn connect_pid_path(connection_name: &str) -> PathBuf {
386    crate::daemon::socket_dir().join(format!("connect-{connection_name}.pid"))
387}
388
389fn connect_lock_path(connection_name: &str) -> PathBuf {
390    crate::daemon::socket_dir().join(format!("connect-{connection_name}.lock"))
391}
392
393fn connect_dest_path(connection_name: &str) -> PathBuf {
394    crate::daemon::socket_dir().join(format!("connect-{connection_name}.dest"))
395}
396
397/// Compute the local socket path for a given connection name.
398/// Public so main.rs can compute the path in the parent process after daemonize.
399pub fn connection_socket_path(connection_name: &str) -> PathBuf {
400    local_socket_path(connection_name)
401}
402
403/// Extract the host component from a destination string (`[user@]host[:port]`).
404pub fn parse_host(destination: &str) -> anyhow::Result<String> {
405    Ok(Destination::parse(destination)?.host)
406}
407
408// ---------------------------------------------------------------------------
409// Lockfile-based liveness
410// ---------------------------------------------------------------------------
411
412/// Acquire an exclusive flock on the lockfile. Returns the locked fd on success.
413/// The lock is held for the lifetime of the returned `OwnedFd`.
414fn acquire_lock(lock_path: &Path) -> anyhow::Result<OwnedFd> {
415    use std::fs::OpenOptions;
416    let file = OpenOptions::new()
417        .create(true)
418        .truncate(false)
419        .write(true)
420        .mode(0o600)
421        .open(lock_path)
422        .with_context(|| format!("failed to open lockfile: {}", lock_path.display()))?;
423    let fd = OwnedFd::from(file);
424    if unsafe { libc::flock(fd.as_raw_fd(), libc::LOCK_EX) } != 0 {
425        bail!("failed to acquire lock on {}", lock_path.display());
426    }
427    Ok(fd)
428}
429
430/// Probe whether a lockfile is held by a live process.
431/// Returns true if the lock is held (process alive), false if free (process dead).
432fn is_lock_held(lock_path: &Path) -> bool {
433    use std::fs::OpenOptions;
434    let file = match OpenOptions::new().read(true).open(lock_path) {
435        Ok(f) => f,
436        Err(_) => return false,
437    };
438    // Non-blocking exclusive lock attempt: if it succeeds, the old process is dead
439    if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) } == 0 {
440        // We got the lock — old process is gone. Release it immediately (fd drop).
441        false
442    } else {
443        true // Lock held by another process
444    }
445}
446
/// Health of a tunnel as determined by `probe_tunnel_status`.
#[derive(Debug, Eq, PartialEq)]
pub enum TunnelStatus {
    /// The lockfile is held and the local socket accepts connections.
    Healthy,
    /// The lockfile is held but the local socket is not connectable.
    Reconnecting,
    /// Nobody holds the lockfile.
    Stale,
}
454
455/// Probe a tunnel's status using lockfile + socket connectivity.
456fn probe_tunnel_status(name: &str) -> TunnelStatus {
457    let lock_path = connect_lock_path(name);
458    if is_lock_held(&lock_path) {
459        let sock_path = local_socket_path(name);
460        if std::os::unix::net::UnixStream::connect(&sock_path).is_ok() {
461            TunnelStatus::Healthy
462        } else {
463            TunnelStatus::Reconnecting
464        }
465    } else {
466        TunnelStatus::Stale
467    }
468}
469
470/// Clean up files for a stale tunnel (process already dead).
471/// No signals sent — the process is confirmed dead (lockfile released).
472/// Orphaned SSH children self-terminate via ServerAliveInterval/ServerAliveCountMax.
473fn read_pid_hint(name: &str) -> Option<u32> {
474    std::fs::read_to_string(connect_pid_path(name)).ok().and_then(|s| s.trim().parse().ok())
475}
476
477fn cleanup_stale_files(name: &str) {
478    let _ = std::fs::remove_file(local_socket_path(name));
479    let _ = std::fs::remove_file(connect_pid_path(name));
480    let _ = std::fs::remove_file(connect_lock_path(name));
481    let _ = std::fs::remove_file(connect_dest_path(name));
482}
483
484/// Extract tunnel connection names by globbing lock files in the socket dir.
485fn enumerate_tunnels() -> Vec<String> {
486    let dir = crate::daemon::socket_dir();
487    let Ok(entries) = std::fs::read_dir(&dir) else {
488        return Vec::new();
489    };
490    entries
491        .filter_map(|e| e.ok())
492        .filter_map(|e| {
493            let name = e.file_name().to_string_lossy().to_string();
494            if name.starts_with("connect-") && name.ends_with(".lock") {
495                Some(name["connect-".len()..name.len() - ".lock".len()].to_string())
496            } else {
497                None
498            }
499        })
500        .collect()
501}
502
503// ---------------------------------------------------------------------------
504// Cleanup guard
505// ---------------------------------------------------------------------------
506
/// Cleanup guard for a running connect: its `Drop` impl cancels `stop`,
/// sends SIGTERM to `child` (if still held), and removes the four files.
///
/// Field order is significant only in that `_lock_fd` is dropped after the
/// `Drop` body has run, which is when the flock is released.
struct ConnectGuard {
    /// SSH tunnel process, while it is still owned by the guard.
    child: Option<Child>,
    /// Local tunnel socket path; removed on drop.
    local_sock: PathBuf,
    /// PID file path; removed on drop.
    pid_file: PathBuf,
    /// Lockfile path; removed on drop.
    lock_file: PathBuf,
    /// Destination file path; removed on drop.
    dest_file: PathBuf,
    /// Holds the flock on `lock_file` for as long as the guard lives.
    _lock_fd: Option<OwnedFd>,
    /// Cancelled on drop to tell the tunnel monitor to stop.
    stop: tokio_util::sync::CancellationToken,
}
516
517impl Drop for ConnectGuard {
518    fn drop(&mut self) {
519        self.stop.cancel();
520
521        if let Some(ref mut child) = self.child
522            && let Some(pid) = child.id()
523        {
524            unsafe {
525                libc::kill(pid as i32, libc::SIGTERM);
526            }
527        }
528
529        let _ = std::fs::remove_file(&self.local_sock);
530        let _ = std::fs::remove_file(&self.pid_file);
531        let _ = std::fs::remove_file(&self.lock_file);
532        let _ = std::fs::remove_file(&self.dest_file);
533        // _lock_fd drops here, releasing the flock
534    }
535}
536
537// ---------------------------------------------------------------------------
538// Public API
539// ---------------------------------------------------------------------------
540
/// Options accepted by [`run`].
pub struct ConnectOpts {
    /// Target in `[user@]host[:port]` form; parsed with `Destination::parse`.
    pub destination: String,
    /// When true, only `gritty socket-path` is run on the remote host instead
    /// of the snippet that also starts the server.
    pub no_server_start: bool,
    /// Extra SSH options; each is passed to `ssh` as `-o <option>`.
    pub ssh_options: Vec<String>,
    /// Connection name; `run` falls back to the destination's host when `None`.
    pub name: Option<String>,
    /// Print the equivalent shell commands and return without connecting.
    pub dry_run: bool,
    /// When false, `BatchMode=yes` is added to the SSH invocations.
    pub foreground: bool,
}
549
550pub async fn run(opts: ConnectOpts, ready_fd: Option<OwnedFd>) -> anyhow::Result<i32> {
551    unsafe {
552        libc::umask(0o077);
553    }
554
555    let dest = Destination::parse(&opts.destination)?;
556    let connection_name = opts.name.unwrap_or_else(|| dest.host.clone());
557    validate_connection_name(&connection_name)?;
558    let local_sock = local_socket_path(&connection_name);
559
560    if opts.dry_run {
561        let remote_cmd =
562            if opts.no_server_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
563        let ensure_cmd = remote_exec_command(&dest, remote_cmd, &opts.ssh_options, true);
564        let tunnel_cmd =
565            tunnel_command(&dest, &local_sock, "$REMOTE_SOCK", &opts.ssh_options, true);
566
567        println!(
568            "# Get remote socket path{}",
569            if opts.no_server_start { "" } else { " and start server if needed" }
570        );
571        println!("REMOTE_SOCK=$({})", format_command(&ensure_cmd));
572        println!();
573        println!("# Start SSH tunnel");
574        println!("{}", format_command(&tunnel_cmd));
575        return Ok(0);
576    }
577
578    // 1. Ensure socket directory exists
579    let pid_file = connect_pid_path(&connection_name);
580    let lock_path = connect_lock_path(&connection_name);
581    let dest_file = connect_dest_path(&connection_name);
582    debug!("local socket: {}", local_sock.display());
583    if let Some(parent) = local_sock.parent() {
584        crate::security::secure_create_dir_all(parent)?;
585    }
586
587    // 2. Check for existing tunnel via lockfile (authoritative)
588    match probe_tunnel_status(&connection_name) {
589        TunnelStatus::Healthy => {
590            println!("{}", local_sock.display());
591            let pid_hint = read_pid_hint(&connection_name);
592            eprint!("tunnel already running (name: {connection_name})");
593            if let Some(pid) = pid_hint {
594                eprintln!(" (pid {pid})");
595                eprintln!("  to stop: gritty disconnect {connection_name}");
596            } else {
597                eprintln!();
598            }
599            eprintln!("  to use:");
600            eprintln!("    gritty new {connection_name}");
601            eprintln!("    gritty attach {connection_name} -t <name>");
602            // Signal readiness to parent even for already-running case
603            signal_ready(&ready_fd);
604            return Ok(0);
605        }
606        TunnelStatus::Reconnecting => {
607            let pid_hint = read_pid_hint(&connection_name);
608            eprint!("tunnel exists but is reconnecting (name: {connection_name})");
609            if let Some(pid) = pid_hint {
610                eprintln!(" (pid {pid})");
611            } else {
612                eprintln!();
613            }
614            eprintln!("  wait for it, or: gritty disconnect {connection_name}");
615            // Signal readiness to parent so it doesn't hang
616            signal_ready(&ready_fd);
617            return Ok(0);
618        }
619        TunnelStatus::Stale => {
620            debug!("cleaning stale tunnel files for {connection_name}");
621            cleanup_stale_files(&connection_name);
622        }
623    }
624
625    // 3. Acquire lockfile (held for entire lifetime of this process)
626    let lock_fd = acquire_lock(&lock_path)?;
627
628    // 4. Ensure remote server is running and get socket path
629    let (remote_sock, remote_version) =
630        ensure_remote_ready(&dest, opts.no_server_start, &opts.ssh_options, opts.foreground)
631            .await?;
632    debug!(remote_sock, ?remote_version, "remote socket path");
633
634    // Check protocol version compatibility
635    if let Some(rv) = remote_version {
636        if rv != crate::protocol::PROTOCOL_VERSION {
637            warn!(
638                "remote protocol version ({rv}) differs from local ({})",
639                crate::protocol::PROTOCOL_VERSION
640            );
641        }
642    }
643
644    // 5. Spawn SSH tunnel
645    let child =
646        spawn_tunnel(&dest, &local_sock, &remote_sock, &opts.ssh_options, opts.foreground).await?;
647    let stop = tokio_util::sync::CancellationToken::new();
648
649    let mut guard = ConnectGuard {
650        child: Some(child),
651        local_sock: local_sock.clone(),
652        pid_file: pid_file.clone(),
653        lock_file: lock_path,
654        dest_file: dest_file.clone(),
655        _lock_fd: Some(lock_fd),
656        stop: stop.clone(),
657    };
658
659    // 6. Wait for local socket to become connectable (race against child exit)
660    let mut child = guard.child.take().unwrap();
661    tokio::select! {
662        result = wait_for_socket(&local_sock, Duration::from_secs(15)) => {
663            result?;
664            guard.child = Some(child);
665        }
666        status = child.wait() => {
667            let status = status.context("failed to wait on ssh tunnel")?;
668            let diag = format_ssh_diag(&dest, &opts.ssh_options, opts.foreground);
669            let msg = if let Some(mut stderr) = child.stderr.take() {
670                use tokio::io::AsyncReadExt;
671                let mut buf = String::new();
672                let _ = stderr.read_to_string(&mut buf).await;
673                let buf = buf.trim().to_string();
674                if buf.is_empty() { None } else { Some(buf) }
675            } else {
676                None
677            };
678            match msg {
679                Some(err) => bail!("ssh tunnel failed: {err}\n  to diagnose: {diag}"),
680                None => bail!("ssh tunnel exited ({status})\n  to diagnose: {diag}"),
681            }
682        }
683    }
684    debug!("tunnel socket ready");
685
686    // Write PID + dest files
687    let _ = std::fs::write(&pid_file, std::process::id().to_string());
688    let _ = std::fs::write(&dest_file, &opts.destination);
689
690    // 7. Signal readiness to parent (or print if foreground)
691    signal_ready(&ready_fd);
692
693    // 8. Hand off the child to the tunnel monitor background task
694    let original_child = guard.child.take().unwrap();
695    let monitor_handle = tokio::spawn(tunnel_monitor(
696        original_child,
697        dest,
698        local_sock.clone(),
699        remote_sock,
700        opts.ssh_options,
701        stop.clone(),
702    ));
703
704    // 9. Wait for signal or monitor death
705    let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
706    tokio::select! {
707        _ = sigterm.recv() => {}
708        _ = monitor_handle => {}
709    }
710
711    // 10. Cleanup (guard Drop handles ssh kill + file removal + lock release)
712    drop(guard);
713
714    Ok(0)
715}
716
717/// Write one readiness byte to the pipe fd (if present).
718fn signal_ready(ready_fd: &Option<OwnedFd>) {
719    if let Some(fd) = ready_fd {
720        let _ = nix::unistd::write(fd, b"\x01");
721    }
722}
723
724// ---------------------------------------------------------------------------
725// Disconnect
726// ---------------------------------------------------------------------------
727
728pub async fn disconnect(name: &str) -> anyhow::Result<()> {
729    validate_connection_name(name)?;
730    match probe_tunnel_status(name) {
731        TunnelStatus::Stale => {
732            cleanup_stale_files(name);
733            eprintln!("tunnel already stopped: {name}");
734            return Ok(());
735        }
736        TunnelStatus::Healthy | TunnelStatus::Reconnecting => {}
737    }
738
739    // Read PID and send SIGTERM (let the process handle graceful shutdown)
740    let pid_file = connect_pid_path(name);
741    let pid: i32 = std::fs::read_to_string(&pid_file)
742        .ok()
743        .and_then(|s| s.trim().parse::<u32>().ok())
744        .map(|p| p as i32)
745        .ok_or_else(|| anyhow::anyhow!("cannot read PID for tunnel {name}"))?;
746
747    let lock_path = connect_lock_path(name);
748    if !is_lock_held(&lock_path) {
749        cleanup_stale_files(name);
750        eprintln!("tunnel already stopped: {name}");
751        return Ok(());
752    }
753    unsafe {
754        libc::kill(pid, libc::SIGTERM);
755    }
756
757    // Poll lock for up to 2s to confirm exit
758    let deadline = Instant::now() + Duration::from_secs(2);
759    loop {
760        tokio::time::sleep(Duration::from_millis(100)).await;
761        if !is_lock_held(&lock_path) {
762            cleanup_stale_files(name);
763            eprintln!("tunnel stopped: {name}");
764            return Ok(());
765        }
766        if Instant::now() >= deadline {
767            break;
768        }
769    }
770
771    // Still alive after timeout — escalate to SIGKILL + killpg
772    if is_lock_held(&lock_path) {
773        unsafe {
774            libc::kill(pid, libc::SIGKILL);
775            libc::killpg(pid, libc::SIGTERM);
776        }
777    }
778    tokio::time::sleep(Duration::from_millis(100)).await;
779    cleanup_stale_files(name);
780    eprintln!("tunnel killed: {name}");
781    Ok(())
782}
783
784// ---------------------------------------------------------------------------
785// List tunnels
786// ---------------------------------------------------------------------------
787
/// Summary of one live tunnel, as produced by [`get_tunnel_info`].
pub struct TunnelInfo {
    /// Connection name, taken from the `connect-{name}.lock` file name.
    pub name: String,
    /// Trimmed contents of the `.dest` file, or `-` when it cannot be read.
    pub destination: String,
    /// Either `healthy` or `reconnecting`.
    pub status: String,
    /// PID from the `.pid` file, when it is readable and parses as `u32`.
    pub pid: Option<u32>,
    /// `connect-{name}.log` in the socket directory (existence is not checked
    /// when this value is built).
    pub log_path: PathBuf,
}
795
796/// Gather info for all live tunnels (cleans stale ones as a side effect).
797pub fn get_tunnel_info() -> Vec<TunnelInfo> {
798    let names = enumerate_tunnels();
799    let mut infos = Vec::new();
800    for name in &names {
801        let status = probe_tunnel_status(name);
802        if status == TunnelStatus::Stale {
803            debug!("cleaning stale tunnel: {name}");
804            cleanup_stale_files(name);
805            continue;
806        }
807        let dest =
808            std::fs::read_to_string(connect_dest_path(name)).unwrap_or_else(|_| "-".to_string());
809        let status_str = match status {
810            TunnelStatus::Healthy => "healthy".to_string(),
811            TunnelStatus::Reconnecting => "reconnecting".to_string(),
812            TunnelStatus::Stale => unreachable!(),
813        };
814        infos.push(TunnelInfo {
815            name: name.clone(),
816            destination: dest.trim().to_string(),
817            status: status_str,
818            pid: read_pid_hint(name),
819            log_path: crate::daemon::socket_dir().join(format!("connect-{name}.log")),
820        });
821    }
822    infos
823}
824
825pub fn list_tunnels() {
826    let infos = get_tunnel_info();
827    if infos.is_empty() {
828        println!("no active tunnels");
829        return;
830    }
831
832    let w_name = infos.iter().map(|i| i.name.len()).max().unwrap().max(4);
833    let w_dest = infos.iter().map(|i| i.destination.len()).max().unwrap().max(11);
834
835    println!("{:<w_name$}  {:<w_dest$}  Status", "Name", "Destination");
836    for info in &infos {
837        println!("{:<w_name$}  {:<w_dest$}  {}", info.name, info.destination, info.status);
838    }
839}
840
841// ---------------------------------------------------------------------------
842// Tests
843// ---------------------------------------------------------------------------
844
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Shared test helpers
    // -----------------------------------------------------------------------

    /// Snapshot a command's arguments as owned strings so tests can search them.
    fn arg_strings(cmd: &Command) -> Vec<String> {
        cmd.as_std().get_args().map(|a| a.to_string_lossy().into_owned()).collect()
    }

    /// Whether `args` contains an argument exactly equal to `needle`.
    fn has_arg(args: &[String], needle: &str) -> bool {
        args.iter().any(|a| a.as_str() == needle)
    }

    /// Assert that every entry of `expected` appears verbatim in `args`,
    /// naming the first missing argument on failure.
    #[track_caller]
    fn assert_has_args(args: &[String], expected: &[&str]) {
        for want in expected {
            assert!(has_arg(args, want), "missing argument {want:?} in {args:?}");
        }
    }

    /// Final path component as an owned string.
    fn file_name_of(path: &Path) -> String {
        path.file_name().unwrap().to_string_lossy().into_owned()
    }

    /// Run `tunnel_monitor` against `child` with the fixed fake destination and
    /// socket paths used by the monitor tests; returns `true` when the monitor
    /// finishes within the 5 second limit.
    async fn monitor_returns_in_time(
        child: Child,
        stop: tokio_util::sync::CancellationToken,
    ) -> bool {
        let dest = Destination::parse("fake-host-test").unwrap();
        tokio::time::timeout(
            Duration::from_secs(5),
            tunnel_monitor(
                child,
                dest,
                PathBuf::from("/tmp/nonexistent.sock"),
                "/tmp/remote.sock".into(),
                vec![],
                stop,
            ),
        )
        .await
        .is_ok()
    }

    /// Spawn a task that cancels a clone of `stop` once `delay` has elapsed.
    fn cancel_after(stop: &tokio_util::sync::CancellationToken, delay: Duration) {
        let token = stop.clone();
        tokio::spawn(async move {
            tokio::time::sleep(delay).await;
            token.cancel();
        });
    }

    // -----------------------------------------------------------------------
    // Destination parsing tests
    // -----------------------------------------------------------------------

    #[test]
    fn parse_destination_user_host() {
        let d = Destination::parse("user@host").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_only() {
        let d = Destination::parse("myhost").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "myhost");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_port() {
        let d = Destination::parse("host:2222").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_user_host_port() {
        let d = Destination::parse("user@host:2222").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_invalid_empty() {
        assert!(Destination::parse("").is_err());
    }

    #[test]
    fn parse_destination_invalid_at_only() {
        // Empty user part before the '@'.
        assert!(Destination::parse("@host").is_err());
    }

    #[test]
    fn parse_destination_invalid_colon_only() {
        // Empty host part before the ':'.
        assert!(Destination::parse(":2222").is_err());
    }

    // -----------------------------------------------------------------------
    // SSH command construction tests
    // -----------------------------------------------------------------------

    #[test]
    fn tunnel_command_default_opts() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/run/user/1000/gritty/ctl.sock",
            &[],
            false,
        );
        let args = arg_strings(&cmd);
        // From base_ssh_args
        assert_has_args(&args, &["ConnectTimeout=5", "BatchMode=yes"]);
        // From TUNNEL_SSH_OPTS
        assert_has_args(
            &args,
            &[
                "ServerAliveInterval=3",
                "StreamLocalBindUnlink=yes",
                "ExitOnForwardFailure=yes",
                "ControlPath=none",
                "ForwardAgent=no",
                "ForwardX11=no",
            ],
        );
        // Tunnel flags and forward
        assert_has_args(
            &args,
            &["-N", "-T", "/tmp/local.sock:/run/user/1000/gritty/ctl.sock", "user@host"],
        );
    }

    #[test]
    fn tunnel_command_extra_opts() {
        let dest = Destination::parse("host:2222").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/tmp/remote.sock",
            &["ProxyJump=bastion".to_string()],
            false,
        );
        let args = arg_strings(&cmd);
        assert_has_args(&args, &["ProxyJump=bastion", "-p", "2222"]);
    }

    #[test]
    fn local_socket_path_format() {
        // With hostname-based naming, connect uses just the host part
        assert_eq!(file_name_of(&local_socket_path("devbox")), "connect-devbox.sock");
        assert_eq!(file_name_of(&local_socket_path("example.com")), "connect-example.com.sock");

        // Custom name override
        assert_eq!(file_name_of(&local_socket_path("myproject")), "connect-myproject.sock");
    }

    #[test]
    fn connect_pid_path_format() {
        assert_eq!(file_name_of(&connect_pid_path("devbox")), "connect-devbox.pid");
        assert_eq!(file_name_of(&connect_pid_path("example.com")), "connect-example.com.pid");
    }

    #[test]
    fn ssh_dest_with_user() {
        let d = Destination::parse("alice@example.com").unwrap();
        assert_eq!(d.ssh_dest(), "alice@example.com");
    }

    #[test]
    fn ssh_dest_without_user() {
        let d = Destination::parse("example.com").unwrap();
        assert_eq!(d.ssh_dest(), "example.com");
    }

    #[test]
    fn port_args_with_port() {
        let d = Destination::parse("host:9999").unwrap();
        assert_eq!(d.port_args(), vec!["-p", "9999"]);
    }

    #[test]
    fn port_args_without_port() {
        let d = Destination::parse("host").unwrap();
        assert!(d.port_args().is_empty());
    }

    // -----------------------------------------------------------------------
    // Shell quoting and command formatting tests
    // -----------------------------------------------------------------------

    #[test]
    fn shell_quote_simple() {
        // Inputs that must come back unchanged.
        for plain in [
            "hello",
            "-N",
            "ServerAliveInterval=3",
            "user@host",
            "/tmp/local.sock:/tmp/remote.sock",
            "$REMOTE_SOCK",
        ] {
            assert_eq!(shell_quote(plain), plain);
        }
    }

    #[test]
    fn shell_quote_needs_quoting() {
        assert_eq!(shell_quote("hello world"), "'hello world'");
        assert_eq!(shell_quote(""), "''");
        assert_eq!(shell_quote("it's"), "'it'\\''s'");
    }

    #[test]
    fn shell_quote_remote_cmd() {
        // The wrapped remote command contains spaces, quotes, semicolons —
        // must be single-quoted so $HOME expands on the remote side.
        let cmd = format!("PATH=\"{REMOTE_PATH_PREFIX}\"; gritty socket-path");
        let quoted = shell_quote(&cmd);
        assert!(quoted.starts_with('\''));
        assert!(quoted.ends_with('\''));
    }

    #[test]
    fn format_command_tunnel() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd = tunnel_command(&dest, Path::new("/tmp/local.sock"), "$REMOTE_SOCK", &[], true);
        let formatted = format_command(&cmd);
        for fragment in ["ServerAliveInterval=3", "ControlPath=none", "ForwardAgent=no", "-N", "-T"]
        {
            assert!(formatted.contains(fragment), "missing {fragment:?} in {formatted:?}");
        }
        // Forward arg references $REMOTE_SOCK unquoted (no spaces, $ is safe)
        assert!(formatted.contains("/tmp/local.sock:$REMOTE_SOCK"));
        assert!(formatted.contains("user@host"));
    }

    #[test]
    fn format_command_remote_exec() {
        let dest = Destination::parse("user@host:2222").unwrap();
        let cmd = remote_exec_command(&dest, "gritty socket-path", &[], true);
        let formatted = format_command(&cmd);
        assert!(formatted.starts_with("ssh "));
        assert!(formatted.contains("-p 2222"));
        assert!(formatted.contains("ConnectTimeout=5"));
        assert!(formatted.contains("user@host"));
        // The PATH assignment prefix must appear in the formatted output.
        assert!(formatted.contains(&format!("PATH=\"{REMOTE_PATH_PREFIX}\"")));
    }

    #[test]
    fn format_command_remote_exec_with_extra_opts() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd =
            remote_exec_command(&dest, REMOTE_ENSURE_CMD, &["ProxyJump=bastion".to_string()], true);
        let formatted = format_command(&cmd);
        assert!(formatted.contains("ProxyJump=bastion"));
        assert!(formatted.contains("gritty socket-path"));
        assert!(formatted.contains("gritty server"));
    }

    #[test]
    fn base_ssh_args_foreground() {
        let dest = Destination::parse("user@host:2222").unwrap();
        let args = base_ssh_args(&dest, &["ProxyJump=bastion".into()], true);
        assert_has_args(&args, &["-p", "2222", "ProxyJump=bastion", "ConnectTimeout=5"]);
        // Foreground mode must not force batch mode.
        assert!(!has_arg(&args, "BatchMode=yes"));
    }

    #[test]
    fn base_ssh_args_background() {
        let dest = Destination::parse("host").unwrap();
        let args = base_ssh_args(&dest, &[], false);
        assert_has_args(&args, &["ConnectTimeout=5", "BatchMode=yes"]);
        // No port in the destination, so no -p flag.
        assert!(!has_arg(&args, "-p"));
    }

    // -----------------------------------------------------------------------
    // Lockfile and tunnel lifecycle tests
    // -----------------------------------------------------------------------

    #[test]
    fn connect_lock_path_format() {
        assert_eq!(file_name_of(&connect_lock_path("devbox")), "connect-devbox.lock");
    }

    #[test]
    fn connect_dest_path_format() {
        assert_eq!(file_name_of(&connect_dest_path("devbox")), "connect-devbox.dest");
    }

    #[test]
    fn acquire_and_probe_lock() {
        let dir = tempfile::tempdir().unwrap();
        let lock_path = dir.path().join("test.lock");

        // Lock not held initially (file doesn't exist)
        assert!(!is_lock_held(&lock_path));

        // Holding the returned guard keeps the lock held
        let guard = acquire_lock(&lock_path).unwrap();
        assert!(is_lock_held(&lock_path));

        // Dropping the guard frees it again
        drop(guard);
        assert!(!is_lock_held(&lock_path));
    }

    #[test]
    fn probe_stale_no_files() {
        // No files at all → stale
        let status = probe_tunnel_status("nonexistent-test-tunnel-xyz");
        assert_eq!(status, TunnelStatus::Stale);
    }

    #[test]
    fn cleanup_stale_files_removes_all() {
        // We can't easily override socket_dir(), so test that cleanup_stale_files
        // at least doesn't panic on nonexistent files.
        cleanup_stale_files("nonexistent-cleanup-test-xyz");
        // No panic = success
    }

    #[test]
    fn enumerate_tunnels_empty_dir() {
        // We can't control what's in socket_dir during tests, so the only
        // guarantee checked here is that enumeration does not panic.
        let _ = enumerate_tunnels();
    }

    #[test]
    fn connection_socket_path_matches_local() {
        let public_path = connection_socket_path("myhost");
        let internal_path = local_socket_path("myhost");
        assert_eq!(public_path, internal_path);
    }

    // -----------------------------------------------------------------------
    // tunnel_monitor tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn tunnel_monitor_non_transient_exit() {
        let child = Command::new("sh").arg("-c").arg("exit 1").spawn().unwrap();
        let stop = tokio_util::sync::CancellationToken::new();

        assert!(
            monitor_returns_in_time(child, stop).await,
            "monitor should return quickly on non-transient exit"
        );
    }

    #[tokio::test]
    async fn tunnel_monitor_cancellation() {
        let child = Command::new("sleep").arg("60").spawn().unwrap();
        let stop = tokio_util::sync::CancellationToken::new();
        cancel_after(&stop, Duration::from_millis(100));

        assert!(
            monitor_returns_in_time(child, stop).await,
            "monitor should return after cancellation"
        );
    }

    #[tokio::test]
    async fn tunnel_monitor_transient_exit_checks_cancellation() {
        // Child exits with 255 (transient). Monitor sleeps 1s then checks cancellation.
        let child = Command::new("sh").arg("-c").arg("exit 255").spawn().unwrap();
        let stop = tokio_util::sync::CancellationToken::new();

        // Cancel during the 1s sleep between exit and respawn attempt
        cancel_after(&stop, Duration::from_millis(500));

        assert!(
            monitor_returns_in_time(child, stop).await,
            "monitor should return after cancellation during sleep"
        );
    }

    // -----------------------------------------------------------------------
    // wait_for_socket tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn wait_for_socket_succeeds_after_delay() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("delayed.sock");
        let sock_path_clone = sock_path.clone();

        // Bind the socket after 500ms
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(500)).await;
            let _listener = tokio::net::UnixListener::bind(&sock_path_clone).unwrap();
            // Keep listener alive
            tokio::time::sleep(Duration::from_secs(30)).await;
        });

        let result = wait_for_socket(&sock_path, Duration::from_secs(5)).await;
        assert!(result.is_ok(), "should successfully connect");
    }

    #[tokio::test]
    async fn wait_for_socket_timeout() {
        // Nothing ever binds this path, so the wait must fail.
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("never.sock");
        let result = wait_for_socket(&sock_path, Duration::from_secs(1)).await;
        assert!(result.is_err());
    }
}
1261}