Skip to main content

gritty/
connect.rs

1use anyhow::{Context, bail};
2use std::path::{Path, PathBuf};
3use std::process::Stdio;
4use std::time::{Duration, Instant};
5use tokio::process::{Child, Command};
6use tracing::{debug, info, warn};
7
8// ---------------------------------------------------------------------------
9// Destination parsing
10// ---------------------------------------------------------------------------
11
12#[derive(Debug, Clone, PartialEq, Eq)]
13struct Destination {
14    user: Option<String>,
15    host: String,
16    port: Option<u16>,
17}
18
19impl Destination {
20    fn parse(s: &str) -> anyhow::Result<Self> {
21        if s.is_empty() {
22            bail!("empty destination");
23        }
24
25        let (user, remainder) = if let Some(at) = s.find('@') {
26            let u = &s[..at];
27            if u.is_empty() {
28                bail!("empty user in destination: {s}");
29            }
30            (Some(u.to_string()), &s[at + 1..])
31        } else {
32            (None, s)
33        };
34
35        let (host, port) = if let Some(colon) = remainder.rfind(':') {
36            let h = &remainder[..colon];
37            let p = remainder[colon + 1..]
38                .parse::<u16>()
39                .with_context(|| format!("invalid port in destination: {s}"))?;
40            (h.to_string(), Some(p))
41        } else {
42            (remainder.to_string(), None)
43        };
44
45        if host.is_empty() {
46            bail!("empty host in destination: {s}");
47        }
48
49        Ok(Self { user, host, port })
50    }
51
52    /// Build the SSH destination string (`user@host` or just `host`).
53    fn ssh_dest(&self) -> String {
54        match &self.user {
55            Some(u) => format!("{u}@{}", self.host),
56            None => self.host.clone(),
57        }
58    }
59
60    /// Common SSH args for port, if set.
61    fn port_args(&self) -> Vec<String> {
62        match self.port {
63            Some(p) => vec!["-p".to_string(), p.to_string()],
64            None => vec![],
65        }
66    }
67}
68
69// ---------------------------------------------------------------------------
70// SSH helpers
71// ---------------------------------------------------------------------------
72
/// Hardened ssh arguments added to every tunnel invocation.
///
/// - `ServerAliveInterval=3` / `ServerAliveCountMax=2`: probe the server every 3s and
///   drop the connection after 2 unanswered probes, so a dead link is noticed quickly.
/// - `StreamLocalBindUnlink=yes`: remove a stale local socket file before binding.
/// - `ExitOnForwardFailure=yes`: exit rather than keep running without the forward.
/// - `ConnectTimeout=5`: give up connecting to the server after 5 seconds.
/// - `-N` / `-T`: run no remote command and allocate no pseudo-terminal.
#[rustfmt::skip]
const SSH_TUNNEL_OPTS: &[&str] = &[
    "-o", "ServerAliveInterval=3",
    "-o", "ServerAliveCountMax=2",
    "-o", "StreamLocalBindUnlink=yes",
    "-o", "ExitOnForwardFailure=yes",
    "-o", "ConnectTimeout=5",
    "-N",
    "-T",
];
88
89/// Run a command on the remote host via SSH, returning stdout.
90async fn remote_exec(
91    dest: &Destination,
92    remote_cmd: &str,
93    extra_ssh_opts: &[String],
94) -> anyhow::Result<String> {
95    // Prepend common binary paths — SSH non-interactive shells don't source
96    // .bashrc/.zshrc, so ~/bin etc. won't be in PATH by default.
97    let wrapped_cmd =
98        format!("PATH=\"$HOME/bin:$HOME/.local/bin:$HOME/.cargo/bin:$PATH\"; {remote_cmd}");
99
100    debug!("ssh {}: {remote_cmd}", dest.ssh_dest());
101
102    let mut cmd = Command::new("ssh");
103    cmd.args(dest.port_args());
104    for opt in extra_ssh_opts {
105        cmd.arg("-o").arg(opt);
106    }
107    cmd.arg("-o").arg("ConnectTimeout=5");
108    cmd.arg(dest.ssh_dest());
109    cmd.arg(&wrapped_cmd);
110    cmd.stdout(Stdio::piped());
111    cmd.stderr(Stdio::piped());
112    cmd.stdin(Stdio::null());
113
114    let output = cmd.output().await.context("failed to run ssh")?;
115
116    if !output.status.success() {
117        let stderr = String::from_utf8_lossy(&output.stderr);
118        let stderr = stderr.trim();
119        debug!("ssh failed (status {}): {stderr}", output.status);
120        if stderr.contains("command not found") || stderr.contains("No such file") {
121            bail!("gritty not found on remote host (is it in PATH?)");
122        }
123        bail!("ssh command failed: {stderr}");
124    }
125
126    let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
127    debug!("ssh output: {stdout}");
128    Ok(stdout)
129}
130
131/// Build the SSH tunnel command with hardened options.
132fn tunnel_command(
133    dest: &Destination,
134    local_sock: &Path,
135    remote_sock: &str,
136    extra_ssh_opts: &[String],
137) -> Command {
138    let mut cmd = Command::new("ssh");
139    cmd.args(dest.port_args());
140    cmd.args(SSH_TUNNEL_OPTS);
141    for opt in extra_ssh_opts {
142        cmd.arg("-o").arg(opt);
143    }
144    let forward = format!("{}:{}", local_sock.display(), remote_sock);
145    cmd.arg("-L").arg(forward);
146    cmd.arg(dest.ssh_dest());
147    cmd.stdout(Stdio::null());
148    cmd.stderr(Stdio::piped());
149    cmd.stdin(Stdio::null());
150    cmd
151}
152
153/// Spawn the SSH tunnel, returning the child process.
154async fn spawn_tunnel(
155    dest: &Destination,
156    local_sock: &Path,
157    remote_sock: &str,
158    extra_ssh_opts: &[String],
159) -> anyhow::Result<Child> {
160    debug!("tunnel: {} -> {}:{}", local_sock.display(), dest.ssh_dest(), remote_sock,);
161    let mut cmd = tunnel_command(dest, local_sock, remote_sock, extra_ssh_opts);
162    let child = cmd.spawn().context("failed to spawn ssh tunnel")?;
163    debug!("ssh tunnel pid: {:?}", child.id());
164    Ok(child)
165}
166
167/// Poll until the local socket is connectable (200ms interval, 15s timeout).
168async fn wait_for_socket(path: &Path) -> anyhow::Result<()> {
169    let deadline = Instant::now() + Duration::from_secs(15);
170    loop {
171        if std::os::unix::net::UnixStream::connect(path).is_ok() {
172            return Ok(());
173        }
174        if Instant::now() >= deadline {
175            bail!("timeout waiting for SSH tunnel socket at {}", path.display());
176        }
177        tokio::time::sleep(Duration::from_millis(200)).await;
178    }
179}
180
181/// Background task: monitor SSH child, respawn on transient failure.
182async fn tunnel_monitor(
183    mut child: Child,
184    dest: Destination,
185    local_sock: PathBuf,
186    remote_sock: String,
187    extra_ssh_opts: Vec<String>,
188    stop: tokio_util::sync::CancellationToken,
189) {
190    let mut exit_times: Vec<Instant> = Vec::new();
191
192    loop {
193        tokio::select! {
194            _ = stop.cancelled() => {
195                let _ = child.kill().await;
196                return;
197            }
198            status = child.wait() => {
199                let status = match status {
200                    Ok(s) => s,
201                    Err(e) => {
202                        warn!("failed to wait on ssh tunnel: {e}");
203                        return;
204                    }
205                };
206
207                if stop.is_cancelled() {
208                    return;
209                }
210
211                let code = status.code();
212                debug!("ssh tunnel exited: {:?}", code);
213
214                // Non-transient failure: don't retry
215                // SSH exit 255 = connection error (transient). Signal-killed = no code.
216                // Everything else (auth failure, config error) = bail.
217                if let Some(c) = code
218                    && c != 255
219                {
220                    warn!("ssh tunnel exited with code {c} (not retrying)");
221                    return;
222                }
223
224                // Rate limit: 5 exits in 10s = give up
225                let now = Instant::now();
226                exit_times.push(now);
227                exit_times.retain(|t| now.duration_since(*t) < Duration::from_secs(10));
228                if exit_times.len() >= 5 {
229                    warn!("ssh tunnel failing too fast (5 exits in 10s), giving up");
230                    return;
231                }
232
233                tokio::time::sleep(Duration::from_secs(1)).await;
234
235                if stop.is_cancelled() {
236                    return;
237                }
238
239                match spawn_tunnel(&dest, &local_sock, &remote_sock, &extra_ssh_opts).await {
240                    Ok(new_child) => {
241                        info!("ssh tunnel respawned");
242                        child = new_child;
243                    }
244                    Err(e) => {
245                        warn!("failed to respawn ssh tunnel: {e}");
246                        return;
247                    }
248                }
249            }
250        }
251    }
252}
253
254// ---------------------------------------------------------------------------
255// Remote daemon management
256// ---------------------------------------------------------------------------
257
/// Remote shell snippet that does three things in order:
/// 1. Resolves the daemon's socket path into `SOCK`.
/// 2. Runs `gritty daemon` (then pauses 0.3s) if `gritty ls` fails.
/// 3. Echoes `SOCK`, but only if every step of the `&&` chain succeeded.
const REMOTE_ENSURE_CMD: &str = concat!(
    "SOCK=$(gritty socket-path) && ",
    "(gritty ls >/dev/null 2>&1 || ",
    "{ gritty daemon && sleep 0.3; }) && ",
    "echo \"$SOCK\""
);
263
264/// Get the remote socket path and optionally auto-start the daemon.
265async fn ensure_remote_ready(
266    dest: &Destination,
267    no_daemon_start: bool,
268    extra_ssh_opts: &[String],
269) -> anyhow::Result<String> {
270    let remote_cmd = if no_daemon_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
271    debug!("ensuring remote daemon (no_daemon_start={no_daemon_start})");
272
273    let sock_path = remote_exec(dest, remote_cmd, extra_ssh_opts).await?;
274
275    if sock_path.is_empty() {
276        bail!("remote host returned empty socket path");
277    }
278
279    Ok(sock_path)
280}
281
282// ---------------------------------------------------------------------------
283// Local socket path
284// ---------------------------------------------------------------------------
285
286/// Compute a deterministic local socket path based on the destination.
287///
288/// Using the raw destination string means re-running `gritty connect user@host`
289/// produces the same socket path, so sessions that used `--ctl-socket` can
290/// auto-reconnect after a tunnel restart.
291fn local_socket_path(destination: &str) -> PathBuf {
292    crate::daemon::socket_dir().join(format!("connect-{destination}.sock"))
293}
294
295fn connect_pid_path(connection_name: &str) -> PathBuf {
296    crate::daemon::socket_dir().join(format!("connect-{connection_name}.pid"))
297}
298
299// ---------------------------------------------------------------------------
300// Cleanup guard
301// ---------------------------------------------------------------------------
302
303struct ConnectGuard {
304    child: Option<Child>,
305    local_sock: PathBuf,
306    pid_file: PathBuf,
307    stop: tokio_util::sync::CancellationToken,
308}
309
310impl Drop for ConnectGuard {
311    fn drop(&mut self) {
312        self.stop.cancel();
313
314        if let Some(ref mut child) = self.child
315            && let Some(pid) = child.id()
316        {
317            unsafe {
318                libc::kill(pid as i32, libc::SIGTERM);
319            }
320        }
321
322        let _ = std::fs::remove_file(&self.local_sock);
323        let _ = std::fs::remove_file(&self.pid_file);
324    }
325}
326
327// ---------------------------------------------------------------------------
328// Public API
329// ---------------------------------------------------------------------------
330
/// Options for [`run`] (`gritty connect`).
#[derive(Debug)]
pub struct ConnectOpts {
    /// Remote destination, `[user@]host[:port]`.
    pub destination: String,
    /// Only query the remote socket path; never auto-start the remote daemon.
    pub no_daemon_start: bool,
    /// Extra ssh options, each passed as `-o <option>` to every ssh invocation.
    pub ssh_options: Vec<String>,
    /// Name for the local socket/PID files; defaults to the destination's host.
    pub name: Option<String>,
}
337
338pub async fn run(opts: ConnectOpts) -> anyhow::Result<i32> {
339    let dest = Destination::parse(&opts.destination)?;
340    let connection_name = opts.name.unwrap_or_else(|| dest.host.clone());
341
342    // 1. Compute local socket path and check for existing tunnel
343    let local_sock = local_socket_path(&connection_name);
344    let pid_file = connect_pid_path(&connection_name);
345    debug!("local socket: {}", local_sock.display());
346    if let Some(parent) = local_sock.parent() {
347        crate::security::secure_create_dir_all(parent)?;
348    }
349
350    if std::os::unix::net::UnixStream::connect(&local_sock).is_ok() {
351        let pid_hint =
352            std::fs::read_to_string(&pid_file).ok().and_then(|s| s.trim().parse::<u32>().ok());
353        println!("{}", local_sock.display());
354        eprint!("tunnel already running (name: {connection_name})");
355        if let Some(pid) = pid_hint {
356            eprintln!(" (pid {pid})");
357            eprintln!("  to stop: kill {pid}");
358        } else {
359            eprintln!();
360        }
361        eprintln!("  to use:");
362        eprintln!("    gritty new {connection_name}");
363        eprintln!("    gritty attach {connection_name} -t <name>");
364        return Ok(0);
365    }
366    // Socket is stale or absent — clean up
367    let _ = std::fs::remove_file(&local_sock);
368
369    // 2. Ensure remote daemon is running and get socket path
370    eprintln!("starting remote daemon...");
371    let remote_sock = ensure_remote_ready(&dest, opts.no_daemon_start, &opts.ssh_options).await?;
372    debug!(remote_sock, "remote socket path");
373
374    // 3. Spawn SSH tunnel
375    let child = spawn_tunnel(&dest, &local_sock, &remote_sock, &opts.ssh_options).await?;
376    let stop = tokio_util::sync::CancellationToken::new();
377
378    let mut guard = ConnectGuard {
379        child: Some(child),
380        local_sock: local_sock.clone(),
381        pid_file: pid_file.clone(),
382        stop: stop.clone(),
383    };
384
385    // 4. Wait for local socket to become connectable
386    wait_for_socket(&local_sock).await?;
387    debug!("tunnel socket ready");
388
389    // Write PID file so subsequent connects can report it
390    let _ = std::fs::write(&pid_file, std::process::id().to_string());
391
392    // 5. Hand off the child to the tunnel monitor background task
393    let original_child = guard.child.take().unwrap();
394    let monitor_handle = tokio::spawn(tunnel_monitor(
395        original_child,
396        dest,
397        local_sock.clone(),
398        remote_sock,
399        opts.ssh_options,
400        stop.clone(),
401    ));
402
403    // 6. Print socket path and usage hints
404    println!("{}", local_sock.display());
405    eprintln!("tunnel ready (name: {connection_name}). to use:");
406    eprintln!("  gritty new {connection_name}");
407    eprintln!("  gritty attach {connection_name} -t <name>");
408
409    // 7. Wait for signal or monitor death
410    let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
411    tokio::select! {
412        _ = tokio::signal::ctrl_c() => {}
413        _ = sigterm.recv() => {}
414        _ = monitor_handle => {
415            eprintln!("tunnel lost");
416        }
417    }
418
419    // 8. Cleanup (guard Drop handles ssh kill + socket removal)
420    drop(guard);
421
422    Ok(0)
423}
424
425// ---------------------------------------------------------------------------
426// Tests
427// ---------------------------------------------------------------------------
428
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_destination_user_host() {
        let d = Destination::parse("user@host").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_only() {
        let d = Destination::parse("myhost").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "myhost");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_port() {
        let d = Destination::parse("host:2222").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_user_host_port() {
        let d = Destination::parse("user@host:2222").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_invalid_empty() {
        assert!(Destination::parse("").is_err());
    }

    #[test]
    fn parse_destination_invalid_at_only() {
        assert!(Destination::parse("@host").is_err());
    }

    #[test]
    fn parse_destination_invalid_colon_only() {
        assert!(Destination::parse(":2222").is_err());
    }

    #[test]
    fn parse_destination_invalid_port() {
        // Empty, non-numeric, and out-of-range (> u16::MAX) ports are all rejected.
        assert!(Destination::parse("host:").is_err());
        assert!(Destination::parse("host:abc").is_err());
        assert!(Destination::parse("host:65536").is_err());
    }

    #[test]
    fn parse_destination_invalid_empty_host_after_user() {
        assert!(Destination::parse("user@").is_err());
        assert!(Destination::parse("user@:2222").is_err());
    }

    #[test]
    fn tunnel_command_default_opts() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/run/user/1000/gritty/ctl.sock",
            &[],
        );
        let args: Vec<_> =
            cmd.as_std().get_args().map(|a| a.to_string_lossy().to_string()).collect();
        assert!(args.contains(&"ServerAliveInterval=3".to_string()));
        assert!(args.contains(&"StreamLocalBindUnlink=yes".to_string()));
        assert!(args.contains(&"ExitOnForwardFailure=yes".to_string()));
        assert!(args.contains(&"ConnectTimeout=5".to_string()));
        assert!(args.contains(&"-N".to_string()));
        assert!(args.contains(&"-T".to_string()));
        assert!(args.contains(&"/tmp/local.sock:/run/user/1000/gritty/ctl.sock".to_string()));
        assert!(args.contains(&"user@host".to_string()));
    }

    #[test]
    fn tunnel_command_extra_opts() {
        let dest = Destination::parse("host:2222").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/tmp/remote.sock",
            &["ProxyJump=bastion".to_string()],
        );
        let args: Vec<_> =
            cmd.as_std().get_args().map(|a| a.to_string_lossy().to_string()).collect();
        assert!(args.contains(&"ProxyJump=bastion".to_string()));
        assert!(args.contains(&"-p".to_string()));
        assert!(args.contains(&"2222".to_string()));
    }

    #[test]
    fn tunnel_command_arg_pairs() {
        // Flags must be immediately followed by their values, and the destination
        // must come last.
        let dest = Destination::parse("user@host:2222").unwrap();
        let cmd = tunnel_command(&dest, Path::new("/tmp/l.sock"), "/tmp/r.sock", &[]);
        let args: Vec<_> =
            cmd.as_std().get_args().map(|a| a.to_string_lossy().to_string()).collect();
        let p = args.iter().position(|a| a == "-p").unwrap();
        assert_eq!(args[p + 1], "2222");
        let l = args.iter().position(|a| a == "-L").unwrap();
        assert_eq!(args[l + 1], "/tmp/l.sock:/tmp/r.sock");
        assert_eq!(args.last().map(String::as_str), Some("user@host"));
    }

    #[test]
    fn local_socket_path_format() {
        // With hostname-based naming, connect uses just the host part
        let path = local_socket_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.sock");

        let path = local_socket_path("example.com");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-example.com.sock");

        // Custom name override
        let path = local_socket_path("myproject");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-myproject.sock");
    }

    #[test]
    fn connect_pid_path_format() {
        let path = connect_pid_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.pid");

        let path = connect_pid_path("example.com");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-example.com.pid");
    }

    #[test]
    fn ssh_dest_with_user() {
        let d = Destination::parse("alice@example.com").unwrap();
        assert_eq!(d.ssh_dest(), "alice@example.com");
    }

    #[test]
    fn ssh_dest_without_user() {
        let d = Destination::parse("example.com").unwrap();
        assert_eq!(d.ssh_dest(), "example.com");
    }

    #[test]
    fn port_args_with_port() {
        let d = Destination::parse("host:9999").unwrap();
        assert_eq!(d.port_args(), vec!["-p", "9999"]);
    }

    #[test]
    fn port_args_without_port() {
        let d = Destination::parse("host").unwrap();
        assert!(d.port_args().is_empty());
    }
}