Skip to main content

gritty/
connect.rs

1use anyhow::{Context, bail};
2use std::os::fd::OwnedFd;
3use std::os::unix::fs::OpenOptionsExt;
4use std::os::unix::io::AsRawFd;
5use std::path::{Path, PathBuf};
6use std::process::Stdio;
7use std::time::{Duration, Instant};
8use tokio::process::{Child, Command};
9use tracing::{debug, info, warn};
10
11// ---------------------------------------------------------------------------
12// Destination parsing
13// ---------------------------------------------------------------------------
14
15#[derive(Debug, Clone, PartialEq, Eq)]
16struct Destination {
17    user: Option<String>,
18    host: String,
19    port: Option<u16>,
20}
21
22impl Destination {
23    fn parse(s: &str) -> anyhow::Result<Self> {
24        if s.is_empty() {
25            bail!("empty destination");
26        }
27
28        let (user, remainder) = if let Some(at) = s.find('@') {
29            let u = &s[..at];
30            if u.is_empty() {
31                bail!("empty user in destination: {s}");
32            }
33            (Some(u.to_string()), &s[at + 1..])
34        } else {
35            (None, s)
36        };
37
38        let (host, port) = if let Some(colon) = remainder.rfind(':') {
39            let h = &remainder[..colon];
40            let p = remainder[colon + 1..]
41                .parse::<u16>()
42                .with_context(|| format!("invalid port in destination: {s}"))?;
43            (h.to_string(), Some(p))
44        } else {
45            (remainder.to_string(), None)
46        };
47
48        if host.is_empty() {
49            bail!("empty host in destination: {s}");
50        }
51
52        Ok(Self { user, host, port })
53    }
54
55    /// Build the SSH destination string (`user@host` or just `host`).
56    fn ssh_dest(&self) -> String {
57        match &self.user {
58            Some(u) => format!("{u}@{}", self.host),
59            None => self.host.clone(),
60        }
61    }
62
63    /// Common SSH args for port, if set.
64    fn port_args(&self) -> Vec<String> {
65        match self.port {
66            Some(p) => vec!["-p".to_string(), p.to_string()],
67            None => vec![],
68        }
69    }
70}
71
72// ---------------------------------------------------------------------------
73// SSH helpers
74// ---------------------------------------------------------------------------
75
/// Tunnel-specific SSH `-o` options (keepalive, cleanup, failure behavior).
///
/// `tunnel_command` appends these after `base_ssh_args`, i.e. after any
/// user-supplied `-o` options. Per ssh_config(5) ssh uses the first value it
/// obtains for each option, so a user-supplied option of the same name wins.
const TUNNEL_SSH_OPTS: &[&str] = &[
    // Keepalive: probe every 3s and give up after 2 unanswered probes, so a dead
    // connection makes ssh exit (presumably with status 255, which
    // `tunnel_monitor` treats as transient and respawns).
    "ServerAliveInterval=3",
    "ServerAliveCountMax=2",
    // Remove a leftover socket file at the local forward path before binding.
    "StreamLocalBindUnlink=yes",
    // Exit rather than keep running without the requested -L forward.
    "ExitOnForwardFailure=yes",
    // Prevent user config from leaking forwarding or connection sharing
    // into the tunnel (gritty handles agent forwarding separately).
    "ControlPath=none",
    "ForwardAgent=no",
    "ForwardX11=no",
];

/// PATH prefix prepended to remote commands so gritty is discoverable
/// in non-interactive SSH shells.
///
/// It is spliced into `PATH="..."` (double quotes), so `$HOME` and `$PATH`
/// are expanded by the remote shell, not locally.
const REMOTE_PATH_PREFIX: &str = "$HOME/bin:$HOME/.local/bin:$HOME/.cargo/bin:$PATH";
92
93/// Build the common SSH args that precede the destination in every invocation:
94/// port, user-supplied options, ConnectTimeout, and BatchMode (background only).
95fn base_ssh_args(dest: &Destination, extra_ssh_opts: &[String], foreground: bool) -> Vec<String> {
96    let mut args = Vec::new();
97    args.extend(dest.port_args());
98    for opt in extra_ssh_opts {
99        args.push("-o".into());
100        args.push(opt.clone());
101    }
102    args.push("-o".into());
103    args.push("ConnectTimeout=5".into());
104    if !foreground {
105        args.push("-o".into());
106        args.push("BatchMode=yes".into());
107    }
108    args
109}
110
111/// Build the SSH command for remote execution (without stdio config).
112fn remote_exec_command(
113    dest: &Destination,
114    remote_cmd: &str,
115    extra_ssh_opts: &[String],
116    foreground: bool,
117) -> Command {
118    let wrapped_cmd = format!("PATH=\"{REMOTE_PATH_PREFIX}\"; {remote_cmd}");
119    let mut cmd = Command::new("ssh");
120    cmd.args(base_ssh_args(dest, extra_ssh_opts, foreground));
121    cmd.arg(dest.ssh_dest());
122    cmd.arg(&wrapped_cmd);
123    cmd
124}
125
126/// Run a command on the remote host via SSH, returning stdout.
127///
128/// Stderr is always piped so we can include SSH errors in our error messages.
129/// SSH interactive prompts use `/dev/tty` directly, not stderr.
130/// In background mode, `BatchMode=yes` is set so SSH fails fast instead of hanging.
131async fn remote_exec(
132    dest: &Destination,
133    remote_cmd: &str,
134    extra_ssh_opts: &[String],
135    foreground: bool,
136) -> anyhow::Result<String> {
137    debug!("ssh {}: {remote_cmd}", dest.ssh_dest());
138
139    let mut cmd = remote_exec_command(dest, remote_cmd, extra_ssh_opts, foreground);
140    cmd.stdout(Stdio::piped());
141    cmd.stderr(Stdio::piped());
142    cmd.stdin(Stdio::null());
143
144    let output = cmd.output().await.context("failed to run ssh")?;
145
146    if !output.status.success() {
147        let stderr = String::from_utf8_lossy(&output.stderr);
148        let stderr = stderr.trim();
149        debug!("ssh failed (status {}): {stderr}", output.status);
150        if stderr.contains("command not found") || stderr.contains("No such file") {
151            bail!("gritty not found on remote host (is it in PATH?)");
152        }
153        let diag = format_ssh_diag(dest, extra_ssh_opts, foreground);
154        if stderr.is_empty() {
155            bail!("ssh command failed (exit {})\n  to diagnose: {diag}", output.status);
156        }
157        bail!("ssh command failed: {stderr}\n  to diagnose: {diag}");
158    }
159
160    let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
161    debug!("ssh output: {stdout}");
162    Ok(stdout)
163}
164
165/// Format a diagnostic SSH command for display in error messages.
166/// Mirrors `base_ssh_args` so the suggestion matches what was actually run.
167fn format_ssh_diag(dest: &Destination, extra_ssh_opts: &[String], foreground: bool) -> String {
168    let mut parts = vec!["ssh".to_string()];
169    for arg in base_ssh_args(dest, extra_ssh_opts, foreground) {
170        parts.push(shell_quote(&arg));
171    }
172    parts.push(dest.ssh_dest());
173    parts.join(" ")
174}
175
/// Quote `s` for display as a single POSIX-shell word, leaving it bare when safe.
///
/// Used only to render commands for humans: the `--dry-run` output and the
/// "to diagnose" hint in ssh error messages. It is never used to build a
/// command that is executed. `$` is deliberately treated as safe so that
/// placeholders such as `$REMOTE_SOCK` in the dry-run output expand when
/// pasted into a shell. An empty string renders as `''`. Anything else is
/// wrapped in single quotes, with embedded `'` written as `'\''`.
fn shell_quote(s: &str) -> String {
    const SAFE_PUNCT: &[u8] = b"-_./=:@$+%,";
    let is_safe = |b: u8| b.is_ascii_alphanumeric() || SAFE_PUNCT.contains(&b);

    match s {
        "" => String::from("''"),
        _ if s.bytes().all(is_safe) => s.to_owned(),
        _ => {
            let escaped = s.replace('\'', r"'\''");
            format!("'{escaped}'")
        }
    }
}
187
188/// Format a tokio Command as a shell string for display.
189fn format_command(cmd: &Command) -> String {
190    let std_cmd = cmd.as_std();
191    let prog = std_cmd.get_program().to_string_lossy();
192    let args: Vec<_> = std_cmd.get_args().map(|a| shell_quote(&a.to_string_lossy())).collect();
193    if args.is_empty() { prog.to_string() } else { format!("{prog} {}", args.join(" ")) }
194}
195
196/// Build the SSH tunnel command with hardened options.
197///
198/// Stderr is always piped so we can capture SSH errors on failure.
199/// (SSH interactive prompts use `/dev/tty` directly, not stderr.)
200fn tunnel_command(
201    dest: &Destination,
202    local_sock: &Path,
203    remote_sock: &str,
204    extra_ssh_opts: &[String],
205    foreground: bool,
206) -> Command {
207    let mut cmd = Command::new("ssh");
208    cmd.args(base_ssh_args(dest, extra_ssh_opts, foreground));
209    for opt in TUNNEL_SSH_OPTS {
210        cmd.arg("-o").arg(opt);
211    }
212    cmd.args(["-N", "-T"]);
213    let forward = format!("{}:{}", local_sock.display(), remote_sock);
214    cmd.arg("-L").arg(forward);
215    cmd.arg(dest.ssh_dest());
216    cmd.stdout(Stdio::null());
217    cmd.stderr(Stdio::piped());
218    cmd.stdin(Stdio::null());
219    cmd
220}
221
222/// Spawn the SSH tunnel, returning the child process.
223async fn spawn_tunnel(
224    dest: &Destination,
225    local_sock: &Path,
226    remote_sock: &str,
227    extra_ssh_opts: &[String],
228    foreground: bool,
229) -> anyhow::Result<Child> {
230    debug!("tunnel: {} -> {}:{}", local_sock.display(), dest.ssh_dest(), remote_sock,);
231    let mut cmd = tunnel_command(dest, local_sock, remote_sock, extra_ssh_opts, foreground);
232    let child = cmd.spawn().context("failed to spawn ssh tunnel")?;
233    debug!("ssh tunnel pid: {:?}", child.id());
234    Ok(child)
235}
236
237/// Poll until the local socket is connectable (200ms interval).
238async fn wait_for_socket(path: &Path, timeout: Duration) -> anyhow::Result<()> {
239    let deadline = Instant::now() + timeout;
240    loop {
241        if std::os::unix::net::UnixStream::connect(path).is_ok() {
242            return Ok(());
243        }
244        if Instant::now() >= deadline {
245            bail!("timeout waiting for SSH tunnel socket at {}", path.display());
246        }
247        tokio::time::sleep(Duration::from_millis(200)).await;
248    }
249}
250
251/// Background task: monitor SSH child, respawn on transient failure.
252async fn tunnel_monitor(
253    mut child: Child,
254    dest: Destination,
255    local_sock: PathBuf,
256    remote_sock: String,
257    extra_ssh_opts: Vec<String>,
258    stop: tokio_util::sync::CancellationToken,
259) {
260    let mut exit_times: Vec<Instant> = Vec::new();
261
262    loop {
263        tokio::select! {
264            _ = stop.cancelled() => {
265                let _ = child.kill().await;
266                return;
267            }
268            status = child.wait() => {
269                let status = match status {
270                    Ok(s) => s,
271                    Err(e) => {
272                        warn!("failed to wait on ssh tunnel: {e}");
273                        return;
274                    }
275                };
276
277                if stop.is_cancelled() {
278                    return;
279                }
280
281                let code = status.code();
282                debug!("ssh tunnel exited: {:?}", code);
283
284                // Non-transient failure: don't retry
285                // SSH exit 255 = connection error (transient). Signal-killed = no code.
286                // Everything else (auth failure, config error) = bail.
287                if let Some(c) = code
288                    && c != 255
289                {
290                    warn!("ssh tunnel exited with code {c} (not retrying)");
291                    return;
292                }
293
294                // Rate limit: 5 exits in 10s = give up
295                let now = Instant::now();
296                exit_times.push(now);
297                exit_times.retain(|t| now.duration_since(*t) < Duration::from_secs(10));
298                if exit_times.len() >= 5 {
299                    warn!("ssh tunnel failing too fast (5 exits in 10s), giving up");
300                    return;
301                }
302
303                tokio::time::sleep(Duration::from_secs(1)).await;
304
305                if stop.is_cancelled() {
306                    return;
307                }
308
309                match spawn_tunnel(&dest, &local_sock, &remote_sock, &extra_ssh_opts, false).await {
310                    Ok(new_child) => {
311                        info!("ssh tunnel respawned");
312                        child = new_child;
313                    }
314                    Err(e) => {
315                        warn!("failed to respawn ssh tunnel: {e}");
316                        return;
317                    }
318                }
319            }
320        }
321    }
322}
323
324// ---------------------------------------------------------------------------
325// Remote server management
326// ---------------------------------------------------------------------------
327
/// Remote shell snippet: resolve the gritty socket path, start the server if
/// `gritty ls` fails (then wait 0.3s for it to come up), and print the path.
///
/// Every step is chained with `&&`, so the snippet fails (and prints nothing)
/// if resolving the path or starting the server fails.
const REMOTE_ENSURE_CMD: &str = concat!(
    "SOCK=$(gritty socket-path)",
    " && (gritty ls >/dev/null 2>&1 || { gritty server && sleep 0.3; })",
    " && echo \"$SOCK\"",
);
333
334/// Get the remote socket path and optionally auto-start the server.
335async fn ensure_remote_ready(
336    dest: &Destination,
337    no_server_start: bool,
338    extra_ssh_opts: &[String],
339    foreground: bool,
340) -> anyhow::Result<String> {
341    let remote_cmd = if no_server_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
342    debug!("ensuring remote server (no_server_start={no_server_start})");
343
344    let sock_path = remote_exec(dest, remote_cmd, extra_ssh_opts, foreground).await?;
345
346    if sock_path.is_empty() {
347        bail!("remote host returned empty socket path");
348    }
349
350    Ok(sock_path)
351}
352
353// ---------------------------------------------------------------------------
354// Local socket path
355// ---------------------------------------------------------------------------
356
357/// Compute a deterministic local socket path based on the destination.
358///
359/// Using the raw destination string means re-running `gritty connect user@host`
360/// produces the same socket path, so sessions that used `--ctl-socket` can
361/// auto-reconnect after a tunnel restart.
362fn local_socket_path(destination: &str) -> PathBuf {
363    crate::daemon::socket_dir().join(format!("connect-{destination}.sock"))
364}
365
366fn connect_pid_path(connection_name: &str) -> PathBuf {
367    crate::daemon::socket_dir().join(format!("connect-{connection_name}.pid"))
368}
369
370fn connect_lock_path(connection_name: &str) -> PathBuf {
371    crate::daemon::socket_dir().join(format!("connect-{connection_name}.lock"))
372}
373
374fn connect_dest_path(connection_name: &str) -> PathBuf {
375    crate::daemon::socket_dir().join(format!("connect-{connection_name}.dest"))
376}
377
378/// Compute the local socket path for a given connection name.
379/// Public so main.rs can compute the path in the parent process after daemonize.
380pub fn connection_socket_path(connection_name: &str) -> PathBuf {
381    local_socket_path(connection_name)
382}
383
384/// Extract the host component from a destination string (`[user@]host[:port]`).
385pub fn parse_host(destination: &str) -> anyhow::Result<String> {
386    Ok(Destination::parse(destination)?.host)
387}
388
389// ---------------------------------------------------------------------------
390// Lockfile-based liveness
391// ---------------------------------------------------------------------------
392
393/// Acquire an exclusive flock on the lockfile. Returns the locked fd on success.
394/// The lock is held for the lifetime of the returned `OwnedFd`.
395fn acquire_lock(lock_path: &Path) -> anyhow::Result<OwnedFd> {
396    use std::fs::OpenOptions;
397    let file = OpenOptions::new()
398        .create(true)
399        .truncate(false)
400        .write(true)
401        .mode(0o600)
402        .open(lock_path)
403        .with_context(|| format!("failed to open lockfile: {}", lock_path.display()))?;
404    let fd = OwnedFd::from(file);
405    if unsafe { libc::flock(fd.as_raw_fd(), libc::LOCK_EX) } != 0 {
406        bail!("failed to acquire lock on {}", lock_path.display());
407    }
408    Ok(fd)
409}
410
411/// Probe whether a lockfile is held by a live process.
412/// Returns true if the lock is held (process alive), false if free (process dead).
413fn is_lock_held(lock_path: &Path) -> bool {
414    use std::fs::OpenOptions;
415    let file = match OpenOptions::new().read(true).open(lock_path) {
416        Ok(f) => f,
417        Err(_) => return false,
418    };
419    // Non-blocking exclusive lock attempt: if it succeeds, the old process is dead
420    if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) } == 0 {
421        // We got the lock — old process is gone. Release it immediately (fd drop).
422        false
423    } else {
424        true // Lock held by another process
425    }
426}
427
/// Health of a tunnel, as determined by `probe_tunnel_status`.
#[derive(PartialEq, Eq, Debug)]
pub enum TunnelStatus {
    /// The lockfile is held and the local socket accepts connections.
    Healthy,
    /// The lockfile is held but the local socket does not accept connections
    /// (the tunnel is being respawned, or is still starting up).
    Reconnecting,
    /// Nobody holds the lockfile: the owning process is gone (or the lockfile
    /// does not exist).
    Stale,
}
435
436/// Probe a tunnel's status using lockfile + socket connectivity.
437fn probe_tunnel_status(name: &str) -> TunnelStatus {
438    let lock_path = connect_lock_path(name);
439    if is_lock_held(&lock_path) {
440        let sock_path = local_socket_path(name);
441        if std::os::unix::net::UnixStream::connect(&sock_path).is_ok() {
442            TunnelStatus::Healthy
443        } else {
444            TunnelStatus::Reconnecting
445        }
446    } else {
447        TunnelStatus::Stale
448    }
449}
450
451/// Clean up files for a stale tunnel (process already dead).
452/// No signals sent — the process is confirmed dead (lockfile released).
453/// Orphaned SSH children self-terminate via ServerAliveInterval/ServerAliveCountMax.
454fn read_pid_hint(name: &str) -> Option<u32> {
455    std::fs::read_to_string(connect_pid_path(name)).ok().and_then(|s| s.trim().parse().ok())
456}
457
458fn cleanup_stale_files(name: &str) {
459    let _ = std::fs::remove_file(local_socket_path(name));
460    let _ = std::fs::remove_file(connect_pid_path(name));
461    let _ = std::fs::remove_file(connect_lock_path(name));
462    let _ = std::fs::remove_file(connect_dest_path(name));
463}
464
465/// Extract tunnel connection names by globbing lock files in the socket dir.
466fn enumerate_tunnels() -> Vec<String> {
467    let dir = crate::daemon::socket_dir();
468    let Ok(entries) = std::fs::read_dir(&dir) else {
469        return Vec::new();
470    };
471    entries
472        .filter_map(|e| e.ok())
473        .filter_map(|e| {
474            let name = e.file_name().to_string_lossy().to_string();
475            if name.starts_with("connect-") && name.ends_with(".lock") {
476                Some(name["connect-".len()..name.len() - ".lock".len()].to_string())
477            } else {
478                None
479            }
480        })
481        .collect()
482}
483
484// ---------------------------------------------------------------------------
485// Cleanup guard
486// ---------------------------------------------------------------------------
487
488struct ConnectGuard {
489    child: Option<Child>,
490    local_sock: PathBuf,
491    pid_file: PathBuf,
492    lock_file: PathBuf,
493    dest_file: PathBuf,
494    _lock_fd: Option<OwnedFd>,
495    stop: tokio_util::sync::CancellationToken,
496}
497
498impl Drop for ConnectGuard {
499    fn drop(&mut self) {
500        self.stop.cancel();
501
502        if let Some(ref mut child) = self.child
503            && let Some(pid) = child.id()
504        {
505            unsafe {
506                libc::kill(pid as i32, libc::SIGTERM);
507            }
508        }
509
510        let _ = std::fs::remove_file(&self.local_sock);
511        let _ = std::fs::remove_file(&self.pid_file);
512        let _ = std::fs::remove_file(&self.lock_file);
513        let _ = std::fs::remove_file(&self.dest_file);
514        // _lock_fd drops here, releasing the flock
515    }
516}
517
518// ---------------------------------------------------------------------------
519// Public API
520// ---------------------------------------------------------------------------
521
/// Options for `run` (the `gritty connect` command).
pub struct ConnectOpts {
    /// SSH destination, `[user@]host[:port]`.
    pub destination: String,
    /// Only query the remote socket path (`gritty socket-path`); never start the server.
    pub no_server_start: bool,
    /// Extra ssh options, each passed as `-o <option>`.
    pub ssh_options: Vec<String>,
    /// Connection name used for local file names; defaults to the destination's host.
    pub name: Option<String>,
    /// Print the equivalent shell commands instead of running anything.
    pub dry_run: bool,
    /// Foreground mode: `BatchMode=yes` is not set, so ssh may prompt.
    pub foreground: bool,
}
530
531pub async fn run(opts: ConnectOpts, ready_fd: Option<OwnedFd>) -> anyhow::Result<i32> {
532    let dest = Destination::parse(&opts.destination)?;
533    let connection_name = opts.name.unwrap_or_else(|| dest.host.clone());
534    let local_sock = local_socket_path(&connection_name);
535
536    if opts.dry_run {
537        let remote_cmd =
538            if opts.no_server_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
539        let ensure_cmd = remote_exec_command(&dest, remote_cmd, &opts.ssh_options, true);
540        let tunnel_cmd =
541            tunnel_command(&dest, &local_sock, "$REMOTE_SOCK", &opts.ssh_options, true);
542
543        println!(
544            "# Get remote socket path{}",
545            if opts.no_server_start { "" } else { " and start server if needed" }
546        );
547        println!("REMOTE_SOCK=$({})", format_command(&ensure_cmd));
548        println!();
549        println!("# Start SSH tunnel");
550        println!("{}", format_command(&tunnel_cmd));
551        return Ok(0);
552    }
553
554    // 1. Ensure socket directory exists
555    let pid_file = connect_pid_path(&connection_name);
556    let lock_path = connect_lock_path(&connection_name);
557    let dest_file = connect_dest_path(&connection_name);
558    debug!("local socket: {}", local_sock.display());
559    if let Some(parent) = local_sock.parent() {
560        crate::security::secure_create_dir_all(parent)?;
561    }
562
563    // 2. Check for existing tunnel via lockfile (authoritative)
564    match probe_tunnel_status(&connection_name) {
565        TunnelStatus::Healthy => {
566            println!("{}", local_sock.display());
567            let pid_hint = read_pid_hint(&connection_name);
568            eprint!("tunnel already running (name: {connection_name})");
569            if let Some(pid) = pid_hint {
570                eprintln!(" (pid {pid})");
571                eprintln!("  to stop: gritty disconnect {connection_name}");
572            } else {
573                eprintln!();
574            }
575            eprintln!("  to use:");
576            eprintln!("    gritty new {connection_name}");
577            eprintln!("    gritty attach {connection_name} -t <name>");
578            // Signal readiness to parent even for already-running case
579            signal_ready(&ready_fd);
580            return Ok(0);
581        }
582        TunnelStatus::Reconnecting => {
583            let pid_hint = read_pid_hint(&connection_name);
584            eprint!("tunnel exists but is reconnecting (name: {connection_name})");
585            if let Some(pid) = pid_hint {
586                eprintln!(" (pid {pid})");
587            } else {
588                eprintln!();
589            }
590            eprintln!("  wait for it, or: gritty disconnect {connection_name}");
591            // Signal readiness to parent so it doesn't hang
592            signal_ready(&ready_fd);
593            return Ok(0);
594        }
595        TunnelStatus::Stale => {
596            debug!("cleaning stale tunnel files for {connection_name}");
597            cleanup_stale_files(&connection_name);
598        }
599    }
600
601    // 3. Acquire lockfile (held for entire lifetime of this process)
602    let lock_fd = acquire_lock(&lock_path)?;
603
604    // 4. Ensure remote server is running and get socket path
605    let remote_sock =
606        ensure_remote_ready(&dest, opts.no_server_start, &opts.ssh_options, opts.foreground)
607            .await?;
608    debug!(remote_sock, "remote socket path");
609
610    // 5. Spawn SSH tunnel
611    let child =
612        spawn_tunnel(&dest, &local_sock, &remote_sock, &opts.ssh_options, opts.foreground).await?;
613    let stop = tokio_util::sync::CancellationToken::new();
614
615    let mut guard = ConnectGuard {
616        child: Some(child),
617        local_sock: local_sock.clone(),
618        pid_file: pid_file.clone(),
619        lock_file: lock_path,
620        dest_file: dest_file.clone(),
621        _lock_fd: Some(lock_fd),
622        stop: stop.clone(),
623    };
624
625    // 6. Wait for local socket to become connectable (race against child exit)
626    let mut child = guard.child.take().unwrap();
627    tokio::select! {
628        result = wait_for_socket(&local_sock, Duration::from_secs(15)) => {
629            result?;
630            guard.child = Some(child);
631        }
632        status = child.wait() => {
633            let status = status.context("failed to wait on ssh tunnel")?;
634            let diag = format_ssh_diag(&dest, &opts.ssh_options, opts.foreground);
635            let msg = if let Some(mut stderr) = child.stderr.take() {
636                use tokio::io::AsyncReadExt;
637                let mut buf = String::new();
638                let _ = stderr.read_to_string(&mut buf).await;
639                let buf = buf.trim().to_string();
640                if buf.is_empty() { None } else { Some(buf) }
641            } else {
642                None
643            };
644            match msg {
645                Some(err) => bail!("ssh tunnel failed: {err}\n  to diagnose: {diag}"),
646                None => bail!("ssh tunnel exited ({status})\n  to diagnose: {diag}"),
647            }
648        }
649    }
650    debug!("tunnel socket ready");
651
652    // Write PID + dest files
653    let _ = std::fs::write(&pid_file, std::process::id().to_string());
654    let _ = std::fs::write(&dest_file, &opts.destination);
655
656    // 7. Signal readiness to parent (or print if foreground)
657    signal_ready(&ready_fd);
658
659    // 8. Hand off the child to the tunnel monitor background task
660    let original_child = guard.child.take().unwrap();
661    let monitor_handle = tokio::spawn(tunnel_monitor(
662        original_child,
663        dest,
664        local_sock.clone(),
665        remote_sock,
666        opts.ssh_options,
667        stop.clone(),
668    ));
669
670    // 9. Wait for signal or monitor death
671    let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
672    tokio::select! {
673        _ = sigterm.recv() => {}
674        _ = monitor_handle => {}
675    }
676
677    // 10. Cleanup (guard Drop handles ssh kill + file removal + lock release)
678    drop(guard);
679
680    Ok(0)
681}
682
683/// Write one readiness byte to the pipe fd (if present).
684fn signal_ready(ready_fd: &Option<OwnedFd>) {
685    if let Some(fd) = ready_fd {
686        let _ = nix::unistd::write(fd, b"\x01");
687    }
688}
689
690// ---------------------------------------------------------------------------
691// Disconnect
692// ---------------------------------------------------------------------------
693
694pub async fn disconnect(name: &str) -> anyhow::Result<()> {
695    match probe_tunnel_status(name) {
696        TunnelStatus::Stale => {
697            cleanup_stale_files(name);
698            eprintln!("tunnel already stopped: {name}");
699            return Ok(());
700        }
701        TunnelStatus::Healthy | TunnelStatus::Reconnecting => {}
702    }
703
704    // Read PID and send SIGTERM (let the process handle graceful shutdown)
705    let pid_file = connect_pid_path(name);
706    let pid = std::fs::read_to_string(&pid_file)
707        .ok()
708        .and_then(|s| s.trim().parse::<i32>().ok())
709        .ok_or_else(|| anyhow::anyhow!("cannot read PID for tunnel {name}"))?;
710
711    unsafe {
712        libc::kill(pid, libc::SIGTERM);
713    }
714
715    // Poll lock for up to 2s to confirm exit
716    let deadline = Instant::now() + Duration::from_secs(2);
717    loop {
718        tokio::time::sleep(Duration::from_millis(100)).await;
719        if !is_lock_held(&connect_lock_path(name)) {
720            cleanup_stale_files(name);
721            eprintln!("tunnel stopped: {name}");
722            return Ok(());
723        }
724        if Instant::now() >= deadline {
725            break;
726        }
727    }
728
729    // Still alive after timeout — escalate to SIGKILL + killpg
730    unsafe {
731        libc::kill(pid, libc::SIGKILL);
732        libc::killpg(pid, libc::SIGTERM);
733    }
734    tokio::time::sleep(Duration::from_millis(100)).await;
735    cleanup_stale_files(name);
736    eprintln!("tunnel killed: {name}");
737    Ok(())
738}
739
740// ---------------------------------------------------------------------------
741// List tunnels
742// ---------------------------------------------------------------------------
743
/// Description of one live tunnel, as gathered by `get_tunnel_info`.
pub struct TunnelInfo {
    /// Connection name (the `{name}` in `connect-{name}.*`).
    pub name: String,
    /// Trimmed contents of the `.dest` file, or `-` if it could not be read.
    pub destination: String,
    /// `"healthy"` or `"reconnecting"`; stale tunnels are never reported.
    pub status: String,
    /// PID from the `.pid` file, if present and parseable.
    pub pid: Option<u32>,
    /// Expected log file location (`connect-{name}.log` in the socket dir).
    /// Only the path is computed; the file may not exist.
    pub log_path: PathBuf,
}
751
752/// Gather info for all live tunnels (cleans stale ones as a side effect).
753pub fn get_tunnel_info() -> Vec<TunnelInfo> {
754    let names = enumerate_tunnels();
755    let mut infos = Vec::new();
756    for name in &names {
757        let status = probe_tunnel_status(name);
758        if status == TunnelStatus::Stale {
759            debug!("cleaning stale tunnel: {name}");
760            cleanup_stale_files(name);
761            continue;
762        }
763        let dest =
764            std::fs::read_to_string(connect_dest_path(name)).unwrap_or_else(|_| "-".to_string());
765        let status_str = match status {
766            TunnelStatus::Healthy => "healthy".to_string(),
767            TunnelStatus::Reconnecting => "reconnecting".to_string(),
768            TunnelStatus::Stale => unreachable!(),
769        };
770        infos.push(TunnelInfo {
771            name: name.clone(),
772            destination: dest.trim().to_string(),
773            status: status_str,
774            pid: read_pid_hint(name),
775            log_path: crate::daemon::socket_dir().join(format!("connect-{name}.log")),
776        });
777    }
778    infos
779}
780
781pub fn list_tunnels() {
782    let infos = get_tunnel_info();
783    if infos.is_empty() {
784        println!("no active tunnels");
785        return;
786    }
787
788    let w_name = infos.iter().map(|i| i.name.len()).max().unwrap().max(4);
789    let w_dest = infos.iter().map(|i| i.destination.len()).max().unwrap().max(11);
790
791    println!("{:<w_name$}  {:<w_dest$}  Status", "Name", "Destination");
792    for info in &infos {
793        println!("{:<w_name$}  {:<w_dest$}  {}", info.name, info.destination, info.status);
794    }
795}
796
797// ---------------------------------------------------------------------------
798// Tests
799// ---------------------------------------------------------------------------
800
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_destination_user_host() {
        let d = Destination::parse("user@host").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_only() {
        let d = Destination::parse("myhost").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "myhost");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_port() {
        let d = Destination::parse("host:2222").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_user_host_port() {
        let d = Destination::parse("user@host:2222").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_invalid_empty() {
        assert!(Destination::parse("").is_err());
    }

    #[test]
    fn parse_destination_invalid_at_only() {
        assert!(Destination::parse("@host").is_err());
    }

    #[test]
    fn parse_destination_invalid_colon_only() {
        assert!(Destination::parse(":2222").is_err());
    }

    #[test]
    fn tunnel_command_default_opts() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/run/user/1000/gritty/ctl.sock",
            &[],
            false,
        );
        let args: Vec<_> =
            cmd.as_std().get_args().map(|a| a.to_string_lossy().to_string()).collect();
        // From base_ssh_args
        assert!(args.contains(&"ConnectTimeout=5".to_string()));
        assert!(args.contains(&"BatchMode=yes".to_string()));
        // From TUNNEL_SSH_OPTS
        assert!(args.contains(&"ServerAliveInterval=3".to_string()));
        assert!(args.contains(&"StreamLocalBindUnlink=yes".to_string()));
        assert!(args.contains(&"ExitOnForwardFailure=yes".to_string()));
        assert!(args.contains(&"ControlPath=none".to_string()));
        assert!(args.contains(&"ForwardAgent=no".to_string()));
        assert!(args.contains(&"ForwardX11=no".to_string()));
        // Tunnel flags and forward
        assert!(args.contains(&"-N".to_string()));
        assert!(args.contains(&"-T".to_string()));
        assert!(args.contains(&"/tmp/local.sock:/run/user/1000/gritty/ctl.sock".to_string()));
        assert!(args.contains(&"user@host".to_string()));
    }

    #[test]
    fn tunnel_command_extra_opts() {
        let dest = Destination::parse("host:2222").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/tmp/remote.sock",
            &["ProxyJump=bastion".to_string()],
            false,
        );
        let args: Vec<_> =
            cmd.as_std().get_args().map(|a| a.to_string_lossy().to_string()).collect();
        assert!(args.contains(&"ProxyJump=bastion".to_string()));
        assert!(args.contains(&"-p".to_string()));
        assert!(args.contains(&"2222".to_string()));
        // ssh only honours the port if it immediately follows `-p`.
        assert!(
            args.windows(2).any(|w| w[0] == "-p" && w[1] == "2222"),
            "`-p` must be directly followed by the port: {args:?}"
        );
    }

    #[test]
    fn local_socket_path_format() {
        // With hostname-based naming, connect uses just the host part
        let path = local_socket_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.sock");

        let path = local_socket_path("example.com");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-example.com.sock");

        // A custom connection name is used verbatim, exactly like a hostname
        let path = local_socket_path("myproject");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-myproject.sock");
    }

    #[test]
    fn connect_pid_path_format() {
        let path = connect_pid_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.pid");

        let path = connect_pid_path("example.com");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-example.com.pid");
    }

    #[test]
    fn ssh_dest_with_user() {
        let d = Destination::parse("alice@example.com").unwrap();
        assert_eq!(d.ssh_dest(), "alice@example.com");
    }

    #[test]
    fn ssh_dest_without_user() {
        let d = Destination::parse("example.com").unwrap();
        assert_eq!(d.ssh_dest(), "example.com");
    }

    #[test]
    fn port_args_with_port() {
        let d = Destination::parse("host:9999").unwrap();
        assert_eq!(d.port_args(), vec!["-p", "9999"]);
    }

    #[test]
    fn port_args_without_port() {
        let d = Destination::parse("host").unwrap();
        assert!(d.port_args().is_empty());
    }

    #[test]
    fn shell_quote_simple() {
        assert_eq!(shell_quote("hello"), "hello");
        assert_eq!(shell_quote("-N"), "-N");
        assert_eq!(shell_quote("ServerAliveInterval=3"), "ServerAliveInterval=3");
        assert_eq!(shell_quote("user@host"), "user@host");
        assert_eq!(
            shell_quote("/tmp/local.sock:/tmp/remote.sock"),
            "/tmp/local.sock:/tmp/remote.sock"
        );
        assert_eq!(shell_quote("$REMOTE_SOCK"), "$REMOTE_SOCK");
    }

    #[test]
    fn shell_quote_needs_quoting() {
        assert_eq!(shell_quote("hello world"), "'hello world'");
        assert_eq!(shell_quote(""), "''");
        assert_eq!(shell_quote("it's"), "'it'\\''s'");
    }

    #[test]
    fn shell_quote_remote_cmd() {
        // The wrapped remote command contains spaces, quotes and semicolons, so
        // it must be single-quoted: that keeps the local shell from expanding
        // $HOME, leaving it to be expanded on the remote side instead.
        let cmd = format!("PATH=\"{REMOTE_PATH_PREFIX}\"; gritty socket-path");
        let quoted = shell_quote(&cmd);
        assert!(quoted.starts_with('\''));
        assert!(quoted.ends_with('\''));
    }

    #[test]
    fn format_command_tunnel() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd = tunnel_command(&dest, Path::new("/tmp/local.sock"), "$REMOTE_SOCK", &[], true);
        let formatted = format_command(&cmd);
        assert!(formatted.contains("ServerAliveInterval=3"));
        assert!(formatted.contains("ControlPath=none"));
        assert!(formatted.contains("ForwardAgent=no"));
        assert!(formatted.contains("-N"));
        assert!(formatted.contains("-T"));
        // Forward arg references $REMOTE_SOCK unquoted (no spaces, $ is safe)
        assert!(formatted.contains("/tmp/local.sock:$REMOTE_SOCK"));
        assert!(formatted.contains("user@host"));
    }

    #[test]
    fn format_command_remote_exec() {
        let dest = Destination::parse("user@host:2222").unwrap();
        let cmd = remote_exec_command(&dest, "gritty socket-path", &[], true);
        let formatted = format_command(&cmd);
        assert!(formatted.starts_with("ssh "));
        assert!(formatted.contains("-p 2222"));
        assert!(formatted.contains("ConnectTimeout=5"));
        assert!(formatted.contains("user@host"));
        // The wrapped remote command must carry the remote PATH prefix
        assert!(formatted.contains(&format!("PATH=\"{REMOTE_PATH_PREFIX}\"")));
    }

    #[test]
    fn format_command_remote_exec_with_extra_opts() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd =
            remote_exec_command(&dest, REMOTE_ENSURE_CMD, &["ProxyJump=bastion".to_string()], true);
        let formatted = format_command(&cmd);
        assert!(formatted.contains("ProxyJump=bastion"));
        assert!(formatted.contains("gritty socket-path"));
        assert!(formatted.contains("gritty server"));
    }

    #[test]
    fn base_ssh_args_foreground() {
        let dest = Destination::parse("user@host:2222").unwrap();
        let args = base_ssh_args(&dest, &["ProxyJump=bastion".into()], true);
        assert!(args.contains(&"-p".to_string()));
        assert!(args.contains(&"2222".to_string()));
        // ssh only honours the port if it immediately follows `-p`.
        assert!(
            args.windows(2).any(|w| w[0] == "-p" && w[1] == "2222"),
            "`-p` must be directly followed by the port: {args:?}"
        );
        assert!(args.contains(&"ProxyJump=bastion".to_string()));
        assert!(args.contains(&"ConnectTimeout=5".to_string()));
        assert!(!args.contains(&"BatchMode=yes".to_string()));
    }

    #[test]
    fn base_ssh_args_background() {
        let dest = Destination::parse("host").unwrap();
        let args = base_ssh_args(&dest, &[], false);
        assert!(args.contains(&"ConnectTimeout=5".to_string()));
        assert!(args.contains(&"BatchMode=yes".to_string()));
        assert!(!args.contains(&"-p".to_string()));
    }

    // -----------------------------------------------------------------------
    // Lockfile and tunnel lifecycle tests
    // -----------------------------------------------------------------------

    #[test]
    fn connect_lock_path_format() {
        let path = connect_lock_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.lock");
    }

    #[test]
    fn connect_dest_path_format() {
        let path = connect_dest_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.dest");
    }

    #[test]
    fn acquire_and_probe_lock() {
        let dir = tempfile::tempdir().unwrap();
        let lock_path = dir.path().join("test.lock");

        // Lock not held initially (file doesn't exist)
        assert!(!is_lock_held(&lock_path));

        // Acquire the lock; the returned fd must stay alive to keep it held
        let lock = acquire_lock(&lock_path).unwrap();

        // Now it should be held
        assert!(is_lock_held(&lock_path));

        // Release the lock by closing the fd
        drop(lock);

        // Should be free again
        assert!(!is_lock_held(&lock_path));
    }

    #[test]
    fn probe_stale_no_files() {
        // No files at all → stale
        let status = probe_tunnel_status("nonexistent-test-tunnel-xyz");
        assert_eq!(status, TunnelStatus::Stale);
    }

    #[test]
    fn cleanup_stale_files_removes_all() {
        // socket_dir() cannot be overridden from here, so this only checks that
        // cleanup_stale_files tolerates files that do not exist.
        cleanup_stale_files("nonexistent-cleanup-test-xyz");
        // No panic = success
    }

    #[test]
    fn enumerate_tunnels_empty_dir() {
        // The contents of socket_dir() are not controlled during tests, so the
        // result is not inspected; the function must simply not panic on
        // whatever filesystem state it finds.
        let _ = enumerate_tunnels();
    }

    #[test]
    fn connection_socket_path_matches_local() {
        let public_path = connection_socket_path("myhost");
        let internal_path = local_socket_path("myhost");
        assert_eq!(public_path, internal_path);
    }

    // -----------------------------------------------------------------------
    // tunnel_monitor tests
    // -----------------------------------------------------------------------

    /// Run `tunnel_monitor` against a fake destination and placeholder socket
    /// paths, bounded by a 5 second timeout.
    ///
    /// Returns `true` if the monitor future completed before the timeout.
    async fn monitor_returns_in_time(
        child: Child,
        stop: tokio_util::sync::CancellationToken,
    ) -> bool {
        let dest = Destination::parse("fake-host-test").unwrap();
        tokio::time::timeout(
            Duration::from_secs(5),
            tunnel_monitor(
                child,
                dest,
                PathBuf::from("/tmp/nonexistent.sock"),
                "/tmp/remote.sock".into(),
                vec![],
                stop,
            ),
        )
        .await
        .is_ok()
    }

    #[tokio::test]
    async fn tunnel_monitor_non_transient_exit() {
        // kill_on_drop: never leave a stray child behind if the test bails out.
        let child =
            Command::new("sh").arg("-c").arg("exit 1").kill_on_drop(true).spawn().unwrap();
        let stop = tokio_util::sync::CancellationToken::new();

        assert!(
            monitor_returns_in_time(child, stop).await,
            "monitor should return quickly on non-transient exit"
        );
    }

    #[tokio::test]
    async fn tunnel_monitor_cancellation() {
        // kill_on_drop: the 60s sleeper must not outlive the test, whether or
        // not tunnel_monitor kills it explicitly on cancellation.
        let child = Command::new("sleep").arg("60").kill_on_drop(true).spawn().unwrap();
        let stop = tokio_util::sync::CancellationToken::new();
        let stop_clone = stop.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            stop_clone.cancel();
        });

        assert!(
            monitor_returns_in_time(child, stop).await,
            "monitor should return after cancellation"
        );
    }

    #[tokio::test]
    async fn tunnel_monitor_transient_exit_checks_cancellation() {
        // Child exits with 255 (transient). Monitor sleeps 1s then checks cancellation.
        let child =
            Command::new("sh").arg("-c").arg("exit 255").kill_on_drop(true).spawn().unwrap();
        let stop = tokio_util::sync::CancellationToken::new();
        let stop_clone = stop.clone();

        // Cancel during the 1s sleep between exit and respawn attempt
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(500)).await;
            stop_clone.cancel();
        });

        assert!(
            monitor_returns_in_time(child, stop).await,
            "monitor should return after cancellation during sleep"
        );
    }

    // -----------------------------------------------------------------------
    // wait_for_socket tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn wait_for_socket_succeeds_after_delay() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("delayed.sock");
        let sock_path_clone = sock_path.clone();

        // Bind the socket after 500ms
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(500)).await;
            let _listener = tokio::net::UnixListener::bind(&sock_path_clone).unwrap();
            // Keep listener alive
            tokio::time::sleep(Duration::from_secs(30)).await;
        });

        let result = wait_for_socket(&sock_path, Duration::from_secs(5)).await;
        assert!(result.is_ok(), "should successfully connect");
    }

    #[tokio::test]
    async fn wait_for_socket_timeout() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("never.sock");
        let result = wait_for_socket(&sock_path, Duration::from_secs(1)).await;
        assert!(result.is_err());
    }
}