1use anyhow::{Context, bail};
2use std::os::fd::OwnedFd;
3use std::os::unix::fs::OpenOptionsExt;
4use std::os::unix::io::AsRawFd;
5use std::path::{Path, PathBuf};
6use std::process::Stdio;
7use std::time::{Duration, Instant};
8use tokio::process::{Child, Command};
9use tracing::{debug, info, warn};
10
/// A parsed ssh destination of the form `[user@]host[:port]` (built by `Destination::parse`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct Destination {
    /// Login user; `None` means no `user@` prefix is passed to ssh.
    user: Option<String>,
    /// Host name or address handed to ssh.
    host: String,
    /// Port passed to ssh as `-p`; `None` omits the flag.
    port: Option<u16>,
}
21
22impl Destination {
23 fn parse(s: &str) -> anyhow::Result<Self> {
24 if s.is_empty() {
25 bail!("empty destination");
26 }
27
28 let (user, remainder) = if let Some(at) = s.find('@') {
29 let u = &s[..at];
30 if u.is_empty() {
31 bail!("empty user in destination: {s}");
32 }
33 (Some(u.to_string()), &s[at + 1..])
34 } else {
35 (None, s)
36 };
37
38 let (host, port) = if let Some(colon) = remainder.rfind(':') {
39 let h = &remainder[..colon];
40 let p = remainder[colon + 1..]
41 .parse::<u16>()
42 .with_context(|| format!("invalid port in destination: {s}"))?;
43 (h.to_string(), Some(p))
44 } else {
45 (remainder.to_string(), None)
46 };
47
48 if host.is_empty() {
49 bail!("empty host in destination: {s}");
50 }
51
52 Ok(Self { user, host, port })
53 }
54
55 fn ssh_dest(&self) -> String {
57 match &self.user {
58 Some(u) => format!("{u}@{}", self.host),
59 None => self.host.clone(),
60 }
61 }
62
63 fn port_args(&self) -> Vec<String> {
65 match self.port {
66 Some(p) => vec!["-p".to_string(), p.to_string()],
67 None => vec![],
68 }
69 }
70}
71
72fn validate_connection_name(name: &str) -> anyhow::Result<()> {
74 if name.is_empty() {
75 bail!("connection name must not be empty");
76 }
77 if name.contains('/') || name.contains('\0') || name.contains("..") {
78 bail!("invalid connection name: {name:?}");
79 }
80 Ok(())
81}
82
/// `-o` options passed to every tunnel `ssh` (see `tunnel_command`), after the shared
/// options from `base_ssh_args`.
const TUNNEL_SSH_OPTS: &[&str] = &[
    // Keepalive: probe the server every 3s and give up after 2 unanswered probes, so a dead
    // link makes ssh exit (status 255, which `tunnel_monitor` treats as transient) quickly.
    "ServerAliveInterval=3",
    "ServerAliveCountMax=2",
    // Remove a leftover local socket file before binding the `-L` forward.
    "StreamLocalBindUnlink=yes",
    // Exit instead of staying connected without the forward if it cannot be set up.
    "ExitOnForwardFailure=yes",
    // Do not share a multiplexed ControlMaster connection; the tunnel owns its own session.
    "ControlPath=none",
    // The tunnel only carries the socket forward.
    "ForwardAgent=no",
    "ForwardX11=no",
];
99
/// `PATH` assigned on the remote side before running a gritty command (see
/// `remote_exec_command`), so that a gritty installed in common per-user or Homebrew
/// locations can be found. This is presumably needed because non-interactive ssh
/// sessions may not load the user's usual PATH.
///
/// `$HOME` and `$PATH` are expanded by the remote shell, not locally.
const REMOTE_PATH_PREFIX: &str =
    "$HOME/bin:$HOME/.local/bin:$HOME/.cargo/bin:/usr/local/bin:/opt/homebrew/bin:$PATH";
104
105fn base_ssh_args(dest: &Destination, extra_ssh_opts: &[String], foreground: bool) -> Vec<String> {
108 let mut args = Vec::new();
109 args.extend(dest.port_args());
110 for opt in extra_ssh_opts {
111 args.push("-o".into());
112 args.push(opt.clone());
113 }
114 args.push("-o".into());
115 args.push("ConnectTimeout=5".into());
116 if !foreground {
117 args.push("-o".into());
118 args.push("BatchMode=yes".into());
119 }
120 args
121}
122
123fn remote_exec_command(
125 dest: &Destination,
126 remote_cmd: &str,
127 extra_ssh_opts: &[String],
128 foreground: bool,
129) -> Command {
130 let wrapped_cmd = format!("PATH=\"{REMOTE_PATH_PREFIX}\"; {remote_cmd}");
131 let mut cmd = Command::new("ssh");
132 cmd.args(base_ssh_args(dest, extra_ssh_opts, foreground));
133 cmd.arg(dest.ssh_dest());
134 cmd.arg(&wrapped_cmd);
135 cmd
136}
137
138async fn remote_exec(
144 dest: &Destination,
145 remote_cmd: &str,
146 extra_ssh_opts: &[String],
147 foreground: bool,
148) -> anyhow::Result<String> {
149 debug!("ssh {}: {remote_cmd}", dest.ssh_dest());
150
151 let mut cmd = remote_exec_command(dest, remote_cmd, extra_ssh_opts, foreground);
152 cmd.stdout(Stdio::piped());
153 cmd.stderr(Stdio::piped());
154 cmd.stdin(Stdio::null());
155
156 let output = cmd.output().await.context("failed to run ssh")?;
157
158 if !output.status.success() {
159 let stderr = String::from_utf8_lossy(&output.stderr);
160 let stderr = stderr.trim();
161 debug!("ssh failed (status {}): {stderr}", output.status);
162 if stderr.contains("command not found") || stderr.contains("No such file") {
163 bail!("gritty not found on remote host (is it in PATH?)");
164 }
165 let diag = format_ssh_diag(dest, extra_ssh_opts, foreground);
166 if stderr.is_empty() {
167 bail!("ssh command failed (exit {})\n to diagnose: {diag}", output.status);
168 }
169 bail!("ssh command failed: {stderr}\n to diagnose: {diag}");
170 }
171
172 let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
173 debug!("ssh output: {stdout}");
174 Ok(stdout)
175}
176
177fn format_ssh_diag(dest: &Destination, extra_ssh_opts: &[String], foreground: bool) -> String {
180 let mut parts = vec!["ssh".to_string()];
181 for arg in base_ssh_args(dest, extra_ssh_opts, foreground) {
182 parts.push(shell_quote(&arg));
183 }
184 parts.push(dest.ssh_dest());
185 parts.join(" ")
186}
187
/// Quotes `s` for a POSIX shell.
///
/// - The empty string becomes `''`.
/// - Strings made only of ASCII alphanumerics and `-_./=:@$+%,` are returned unchanged.
///   `$` is deliberately allowed so placeholders such as `$REMOTE_SOCK` in dry-run output
///   still expand.
/// - Anything else is wrapped in single quotes, with each embedded `'` written as `'\''`.
fn shell_quote(s: &str) -> String {
    const SAFE_PUNCT: &[u8] = b"-_./=:@$+%,";
    let is_safe = |b: u8| b.is_ascii_alphanumeric() || SAFE_PUNCT.contains(&b);

    match s {
        "" => String::from("''"),
        _ if s.bytes().all(is_safe) => s.to_owned(),
        _ => {
            let escaped = s.split('\'').collect::<Vec<_>>().join("'\\''");
            format!("'{escaped}'")
        }
    }
}
199
200fn format_command(cmd: &Command) -> String {
202 let std_cmd = cmd.as_std();
203 let prog = std_cmd.get_program().to_string_lossy();
204 let args: Vec<_> = std_cmd.get_args().map(|a| shell_quote(&a.to_string_lossy())).collect();
205 if args.is_empty() { prog.to_string() } else { format!("{prog} {}", args.join(" ")) }
206}
207
208fn tunnel_command(
213 dest: &Destination,
214 local_sock: &Path,
215 remote_sock: &str,
216 extra_ssh_opts: &[String],
217 foreground: bool,
218) -> Command {
219 let mut cmd = Command::new("ssh");
220 cmd.args(base_ssh_args(dest, extra_ssh_opts, foreground));
221 for opt in TUNNEL_SSH_OPTS {
222 cmd.arg("-o").arg(opt);
223 }
224 cmd.args(["-N", "-T"]);
225 let forward = format!("{}:{}", local_sock.display(), remote_sock);
226 cmd.arg("-L").arg(forward);
227 cmd.arg(dest.ssh_dest());
228 cmd.stdout(Stdio::null());
229 cmd.stderr(Stdio::piped());
230 cmd.stdin(Stdio::null());
231 cmd
232}
233
234async fn spawn_tunnel(
236 dest: &Destination,
237 local_sock: &Path,
238 remote_sock: &str,
239 extra_ssh_opts: &[String],
240 foreground: bool,
241) -> anyhow::Result<Child> {
242 debug!("tunnel: {} -> {}:{}", local_sock.display(), dest.ssh_dest(), remote_sock,);
243 let mut cmd = tunnel_command(dest, local_sock, remote_sock, extra_ssh_opts, foreground);
244 let child = cmd.spawn().context("failed to spawn ssh tunnel")?;
245 debug!("ssh tunnel pid: {:?}", child.id());
246 Ok(child)
247}
248
249async fn wait_for_socket(path: &Path, timeout: Duration) -> anyhow::Result<()> {
251 let deadline = Instant::now() + timeout;
252 loop {
253 if std::os::unix::net::UnixStream::connect(path).is_ok() {
254 return Ok(());
255 }
256 if Instant::now() >= deadline {
257 bail!("timeout waiting for SSH tunnel socket at {}", path.display());
258 }
259 tokio::time::sleep(Duration::from_millis(200)).await;
260 }
261}
262
263async fn tunnel_monitor(
265 mut child: Child,
266 dest: Destination,
267 local_sock: PathBuf,
268 remote_sock: String,
269 extra_ssh_opts: Vec<String>,
270 stop: tokio_util::sync::CancellationToken,
271) {
272 let mut exit_times: Vec<Instant> = Vec::new();
273
274 loop {
275 tokio::select! {
276 _ = stop.cancelled() => {
277 let _ = child.kill().await;
278 return;
279 }
280 status = child.wait() => {
281 let status = match status {
282 Ok(s) => s,
283 Err(e) => {
284 warn!("failed to wait on ssh tunnel: {e}");
285 return;
286 }
287 };
288
289 if stop.is_cancelled() {
290 return;
291 }
292
293 let code = status.code();
294 debug!("ssh tunnel exited: {:?}", code);
295
296 if let Some(c) = code
300 && c != 255
301 {
302 warn!("ssh tunnel exited with code {c} (not retrying)");
303 return;
304 }
305
306 let now = Instant::now();
308 exit_times.push(now);
309 exit_times.retain(|t| now.duration_since(*t) < Duration::from_secs(10));
310 if exit_times.len() >= 5 {
311 warn!("ssh tunnel failing too fast (5 exits in 10s), giving up");
312 return;
313 }
314
315 tokio::time::sleep(Duration::from_secs(1)).await;
316
317 if stop.is_cancelled() {
318 return;
319 }
320
321 match spawn_tunnel(&dest, &local_sock, &remote_sock, &extra_ssh_opts, false).await {
322 Ok(new_child) => {
323 info!("ssh tunnel respawned");
324 child = new_child;
325 }
326 Err(e) => {
327 warn!("failed to respawn ssh tunnel: {e}");
328 return;
329 }
330 }
331 }
332 }
333 }
334}
335
/// Remote shell snippet used by `ensure_remote_ready` (unless `--no-server-start`).
///
/// Steps:
/// 1. Resolve the server socket path.
/// 2. Start `gritty server` if `gritty ls` fails.
/// 3. Print the socket path, then the protocol version when available.
///
/// Only `gritty protocol-version` is allowed to fail, for compatibility with servers that
/// lack it, so the `|| true` is grouped around it alone. A missing gritty binary or a
/// failing `gritty server` must still produce a non-zero exit and stderr, so that
/// `remote_exec` can report the real cause instead of an empty socket path.
const REMOTE_ENSURE_CMD: &str = "\
    SOCK=$(gritty socket-path) && \
    (gritty ls >/dev/null 2>&1 || \
    { gritty server && sleep 0.3; }) && \
    echo \"$SOCK\" && \
    { gritty protocol-version 2>/dev/null || true; }";
346
347async fn ensure_remote_ready(
350 dest: &Destination,
351 no_server_start: bool,
352 extra_ssh_opts: &[String],
353 foreground: bool,
354) -> anyhow::Result<(String, Option<u16>)> {
355 let remote_cmd = if no_server_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
356 debug!("ensuring remote server (no_server_start={no_server_start})");
357
358 let output = remote_exec(dest, remote_cmd, extra_ssh_opts, foreground).await?;
359
360 let mut lines = output.lines();
362 let sock_path = lines.next().unwrap_or("").to_string();
363 let remote_version = lines.next().and_then(|s| s.trim().parse::<u16>().ok());
364
365 if sock_path.is_empty() {
366 bail!("remote host returned empty socket path");
367 }
368
369 Ok((sock_path, remote_version))
370}
371
372fn local_socket_path(destination: &str) -> PathBuf {
382 crate::daemon::socket_dir().join(format!("connect-{destination}.sock"))
383}
384
385fn connect_pid_path(connection_name: &str) -> PathBuf {
386 crate::daemon::socket_dir().join(format!("connect-{connection_name}.pid"))
387}
388
389fn connect_lock_path(connection_name: &str) -> PathBuf {
390 crate::daemon::socket_dir().join(format!("connect-{connection_name}.lock"))
391}
392
393fn connect_dest_path(connection_name: &str) -> PathBuf {
394 crate::daemon::socket_dir().join(format!("connect-{connection_name}.dest"))
395}
396
/// Public accessor for the local forwarded socket of `connection_name`.
///
/// Identical to the module-private `local_socket_path`.
pub fn connection_socket_path(connection_name: &str) -> PathBuf {
    local_socket_path(connection_name)
}
402
403pub fn parse_host(destination: &str) -> anyhow::Result<String> {
405 Ok(Destination::parse(destination)?.host)
406}
407
408fn acquire_lock(lock_path: &Path) -> anyhow::Result<OwnedFd> {
415 use std::fs::OpenOptions;
416 let file = OpenOptions::new()
417 .create(true)
418 .truncate(false)
419 .write(true)
420 .mode(0o600)
421 .open(lock_path)
422 .with_context(|| format!("failed to open lockfile: {}", lock_path.display()))?;
423 let fd = OwnedFd::from(file);
424 if unsafe { libc::flock(fd.as_raw_fd(), libc::LOCK_EX) } != 0 {
425 bail!("failed to acquire lock on {}", lock_path.display());
426 }
427 Ok(fd)
428}
429
430fn is_lock_held(lock_path: &Path) -> bool {
433 use std::fs::OpenOptions;
434 let file = match OpenOptions::new().read(true).open(lock_path) {
435 Ok(f) => f,
436 Err(_) => return false,
437 };
438 if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) } == 0 {
440 false
442 } else {
443 true }
445}
446
/// State of a named tunnel as determined from its local files (see `probe_tunnel_status`).
#[derive(Debug, PartialEq, Eq)]
pub enum TunnelStatus {
    /// The lock file is held and the local socket accepts a connection.
    Healthy,
    /// The lock file is held, but connecting to the local socket fails.
    Reconnecting,
    /// No process holds the lock file (or it does not exist).
    Stale,
}
454
455fn probe_tunnel_status(name: &str) -> TunnelStatus {
457 let lock_path = connect_lock_path(name);
458 if is_lock_held(&lock_path) {
459 let sock_path = local_socket_path(name);
460 if std::os::unix::net::UnixStream::connect(&sock_path).is_ok() {
461 TunnelStatus::Healthy
462 } else {
463 TunnelStatus::Reconnecting
464 }
465 } else {
466 TunnelStatus::Stale
467 }
468}
469
470fn read_pid_hint(name: &str) -> Option<u32> {
474 std::fs::read_to_string(connect_pid_path(name)).ok().and_then(|s| s.trim().parse().ok())
475}
476
477fn cleanup_stale_files(name: &str) {
478 let _ = std::fs::remove_file(local_socket_path(name));
479 let _ = std::fs::remove_file(connect_pid_path(name));
480 let _ = std::fs::remove_file(connect_lock_path(name));
481 let _ = std::fs::remove_file(connect_dest_path(name));
482}
483
484fn enumerate_tunnels() -> Vec<String> {
486 let dir = crate::daemon::socket_dir();
487 let Ok(entries) = std::fs::read_dir(&dir) else {
488 return Vec::new();
489 };
490 entries
491 .filter_map(|e| e.ok())
492 .filter_map(|e| {
493 let name = e.file_name().to_string_lossy().to_string();
494 if name.starts_with("connect-") && name.ends_with(".lock") {
495 Some(name["connect-".len()..name.len() - ".lock".len()].to_string())
496 } else {
497 None
498 }
499 })
500 .collect()
501}
502
/// Owns the resources of a running `connect` and tears them down when dropped
/// (see the `Drop` impl below).
struct ConnectGuard {
    /// The tunnel ssh process while the guard owns it; `run` moves it out before handing it
    /// to `tunnel_monitor`.
    child: Option<Child>,
    /// Local forwarded socket (`connect-{name}.sock`).
    local_sock: PathBuf,
    /// `connect-{name}.pid`.
    pid_file: PathBuf,
    /// `connect-{name}.lock`.
    lock_file: PathBuf,
    /// `connect-{name}.dest`.
    dest_file: PathBuf,
    /// Descriptor from `acquire_lock`; the lock stays held until this is closed with the guard.
    _lock_fd: Option<OwnedFd>,
    /// Cancelled on drop to stop `tunnel_monitor`.
    stop: tokio_util::sync::CancellationToken,
}
516
517impl Drop for ConnectGuard {
518 fn drop(&mut self) {
519 self.stop.cancel();
520
521 if let Some(ref mut child) = self.child
522 && let Some(pid) = child.id()
523 {
524 unsafe {
525 libc::kill(pid as i32, libc::SIGTERM);
526 }
527 }
528
529 let _ = std::fs::remove_file(&self.local_sock);
530 let _ = std::fs::remove_file(&self.pid_file);
531 let _ = std::fs::remove_file(&self.lock_file);
532 let _ = std::fs::remove_file(&self.dest_file);
533 }
535}
536
/// Options for `run`.
pub struct ConnectOpts {
    /// `[user@]host[:port]`, parsed with `Destination::parse`.
    pub destination: String,
    /// Only run `gritty socket-path` remotely; never start a remote server.
    pub no_server_start: bool,
    /// Extra ssh `-o` values. They are placed before the built-in `-o` options, so ssh
    /// (which keeps the first value it sees) lets them take precedence.
    pub ssh_options: Vec<String>,
    /// Connection name; defaults to the destination's host.
    pub name: Option<String>,
    /// Print the equivalent shell commands instead of connecting.
    pub dry_run: bool,
    /// When set, the initial ssh invocations omit `BatchMode=yes`, so ssh may prompt.
    /// Respawns by `tunnel_monitor` always run in background mode.
    pub foreground: bool,
}
549
550pub async fn run(opts: ConnectOpts, ready_fd: Option<OwnedFd>) -> anyhow::Result<i32> {
551 unsafe {
552 libc::umask(0o077);
553 }
554
555 let dest = Destination::parse(&opts.destination)?;
556 let connection_name = opts.name.unwrap_or_else(|| dest.host.clone());
557 validate_connection_name(&connection_name)?;
558 let local_sock = local_socket_path(&connection_name);
559
560 if opts.dry_run {
561 let remote_cmd =
562 if opts.no_server_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
563 let ensure_cmd = remote_exec_command(&dest, remote_cmd, &opts.ssh_options, true);
564 let tunnel_cmd =
565 tunnel_command(&dest, &local_sock, "$REMOTE_SOCK", &opts.ssh_options, true);
566
567 println!(
568 "# Get remote socket path{}",
569 if opts.no_server_start { "" } else { " and start server if needed" }
570 );
571 println!("REMOTE_SOCK=$({})", format_command(&ensure_cmd));
572 println!();
573 println!("# Start SSH tunnel");
574 println!("{}", format_command(&tunnel_cmd));
575 return Ok(0);
576 }
577
578 let pid_file = connect_pid_path(&connection_name);
580 let lock_path = connect_lock_path(&connection_name);
581 let dest_file = connect_dest_path(&connection_name);
582 debug!("local socket: {}", local_sock.display());
583 if let Some(parent) = local_sock.parent() {
584 crate::security::secure_create_dir_all(parent)?;
585 }
586
587 match probe_tunnel_status(&connection_name) {
589 TunnelStatus::Healthy => {
590 println!("{}", local_sock.display());
591 let pid_hint = read_pid_hint(&connection_name);
592 eprint!("tunnel already running (name: {connection_name})");
593 if let Some(pid) = pid_hint {
594 eprintln!(" (pid {pid})");
595 eprintln!(" to stop: gritty disconnect {connection_name}");
596 } else {
597 eprintln!();
598 }
599 eprintln!(" to use:");
600 eprintln!(" gritty new {connection_name}");
601 eprintln!(" gritty attach {connection_name} -t <name>");
602 signal_ready(&ready_fd);
604 return Ok(0);
605 }
606 TunnelStatus::Reconnecting => {
607 let pid_hint = read_pid_hint(&connection_name);
608 eprint!("tunnel exists but is reconnecting (name: {connection_name})");
609 if let Some(pid) = pid_hint {
610 eprintln!(" (pid {pid})");
611 } else {
612 eprintln!();
613 }
614 eprintln!(" wait for it, or: gritty disconnect {connection_name}");
615 signal_ready(&ready_fd);
617 return Ok(0);
618 }
619 TunnelStatus::Stale => {
620 debug!("cleaning stale tunnel files for {connection_name}");
621 cleanup_stale_files(&connection_name);
622 }
623 }
624
625 let lock_fd = acquire_lock(&lock_path)?;
627
628 let (remote_sock, remote_version) =
630 ensure_remote_ready(&dest, opts.no_server_start, &opts.ssh_options, opts.foreground)
631 .await?;
632 debug!(remote_sock, ?remote_version, "remote socket path");
633
634 if let Some(rv) = remote_version {
636 if rv != crate::protocol::PROTOCOL_VERSION {
637 warn!(
638 "remote protocol version ({rv}) differs from local ({})",
639 crate::protocol::PROTOCOL_VERSION
640 );
641 }
642 }
643
644 let child =
646 spawn_tunnel(&dest, &local_sock, &remote_sock, &opts.ssh_options, opts.foreground).await?;
647 let stop = tokio_util::sync::CancellationToken::new();
648
649 let mut guard = ConnectGuard {
650 child: Some(child),
651 local_sock: local_sock.clone(),
652 pid_file: pid_file.clone(),
653 lock_file: lock_path,
654 dest_file: dest_file.clone(),
655 _lock_fd: Some(lock_fd),
656 stop: stop.clone(),
657 };
658
659 let mut child = guard.child.take().unwrap();
661 tokio::select! {
662 result = wait_for_socket(&local_sock, Duration::from_secs(15)) => {
663 result?;
664 guard.child = Some(child);
665 }
666 status = child.wait() => {
667 let status = status.context("failed to wait on ssh tunnel")?;
668 let diag = format_ssh_diag(&dest, &opts.ssh_options, opts.foreground);
669 let msg = if let Some(mut stderr) = child.stderr.take() {
670 use tokio::io::AsyncReadExt;
671 let mut buf = String::new();
672 let _ = stderr.read_to_string(&mut buf).await;
673 let buf = buf.trim().to_string();
674 if buf.is_empty() { None } else { Some(buf) }
675 } else {
676 None
677 };
678 match msg {
679 Some(err) => bail!("ssh tunnel failed: {err}\n to diagnose: {diag}"),
680 None => bail!("ssh tunnel exited ({status})\n to diagnose: {diag}"),
681 }
682 }
683 }
684 debug!("tunnel socket ready");
685
686 let _ = std::fs::write(&pid_file, std::process::id().to_string());
688 let _ = std::fs::write(&dest_file, &opts.destination);
689
690 signal_ready(&ready_fd);
692
693 let original_child = guard.child.take().unwrap();
695 let monitor_handle = tokio::spawn(tunnel_monitor(
696 original_child,
697 dest,
698 local_sock.clone(),
699 remote_sock,
700 opts.ssh_options,
701 stop.clone(),
702 ));
703
704 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
706 tokio::select! {
707 _ = sigterm.recv() => {}
708 _ = monitor_handle => {}
709 }
710
711 drop(guard);
713
714 Ok(0)
715}
716
717fn signal_ready(ready_fd: &Option<OwnedFd>) {
719 if let Some(fd) = ready_fd {
720 let _ = nix::unistd::write(fd, b"\x01");
721 }
722}
723
724pub async fn disconnect(name: &str) -> anyhow::Result<()> {
729 validate_connection_name(name)?;
730 match probe_tunnel_status(name) {
731 TunnelStatus::Stale => {
732 cleanup_stale_files(name);
733 eprintln!("tunnel already stopped: {name}");
734 return Ok(());
735 }
736 TunnelStatus::Healthy | TunnelStatus::Reconnecting => {}
737 }
738
739 let pid_file = connect_pid_path(name);
741 let pid: i32 = std::fs::read_to_string(&pid_file)
742 .ok()
743 .and_then(|s| s.trim().parse::<u32>().ok())
744 .map(|p| p as i32)
745 .ok_or_else(|| anyhow::anyhow!("cannot read PID for tunnel {name}"))?;
746
747 let lock_path = connect_lock_path(name);
748 if !is_lock_held(&lock_path) {
749 cleanup_stale_files(name);
750 eprintln!("tunnel already stopped: {name}");
751 return Ok(());
752 }
753 unsafe {
754 libc::kill(pid, libc::SIGTERM);
755 }
756
757 let deadline = Instant::now() + Duration::from_secs(2);
759 loop {
760 tokio::time::sleep(Duration::from_millis(100)).await;
761 if !is_lock_held(&lock_path) {
762 cleanup_stale_files(name);
763 eprintln!("tunnel stopped: {name}");
764 return Ok(());
765 }
766 if Instant::now() >= deadline {
767 break;
768 }
769 }
770
771 if is_lock_held(&lock_path) {
773 unsafe {
774 libc::kill(pid, libc::SIGKILL);
775 libc::killpg(pid, libc::SIGTERM);
776 }
777 }
778 tokio::time::sleep(Duration::from_millis(100)).await;
779 cleanup_stale_files(name);
780 eprintln!("tunnel killed: {name}");
781 Ok(())
782}
783
/// Snapshot of one non-stale tunnel, as built by `get_tunnel_info`.
pub struct TunnelInfo {
    /// Connection name (taken from the `connect-{name}.lock` file name).
    pub name: String,
    /// Trimmed contents of `connect-{name}.dest`, or `-` if it cannot be read.
    pub destination: String,
    /// `"healthy"` or `"reconnecting"`.
    pub status: String,
    /// Pid from `connect-{name}.pid`, if readable.
    pub pid: Option<u32>,
    /// `connect-{name}.log` in the socket dir. Its existence is not checked here; presumably
    /// whatever daemonizes `connect` writes its log there — TODO confirm.
    pub log_path: PathBuf,
}
795
796pub fn get_tunnel_info() -> Vec<TunnelInfo> {
798 let names = enumerate_tunnels();
799 let mut infos = Vec::new();
800 for name in &names {
801 let status = probe_tunnel_status(name);
802 if status == TunnelStatus::Stale {
803 debug!("cleaning stale tunnel: {name}");
804 cleanup_stale_files(name);
805 continue;
806 }
807 let dest =
808 std::fs::read_to_string(connect_dest_path(name)).unwrap_or_else(|_| "-".to_string());
809 let status_str = match status {
810 TunnelStatus::Healthy => "healthy".to_string(),
811 TunnelStatus::Reconnecting => "reconnecting".to_string(),
812 TunnelStatus::Stale => unreachable!(),
813 };
814 infos.push(TunnelInfo {
815 name: name.clone(),
816 destination: dest.trim().to_string(),
817 status: status_str,
818 pid: read_pid_hint(name),
819 log_path: crate::daemon::socket_dir().join(format!("connect-{name}.log")),
820 });
821 }
822 infos
823}
824
825pub fn list_tunnels() {
826 let infos = get_tunnel_info();
827 if infos.is_empty() {
828 println!("no active tunnels");
829 return;
830 }
831
832 let w_name = infos.iter().map(|i| i.name.len()).max().unwrap().max(4);
833 let w_dest = infos.iter().map(|i| i.destination.len()).max().unwrap().max(11);
834
835 println!("{:<w_name$} {:<w_dest$} Status", "Name", "Destination");
836 for info in &infos {
837 println!("{:<w_name$} {:<w_dest$} {}", info.name, info.destination, info.status);
838 }
839}
840
/// Unit tests for destination parsing, command construction, path naming, locking and the
/// tunnel monitor.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Destination::parse ---

    #[test]
    fn parse_destination_user_host() {
        let d = Destination::parse("user@host").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_only() {
        let d = Destination::parse("myhost").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "myhost");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_port() {
        let d = Destination::parse("host:2222").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_user_host_port() {
        let d = Destination::parse("user@host:2222").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_invalid_empty() {
        assert!(Destination::parse("").is_err());
    }

    #[test]
    fn parse_destination_invalid_at_only() {
        assert!(Destination::parse("@host").is_err());
    }

    #[test]
    fn parse_destination_invalid_colon_only() {
        assert!(Destination::parse(":2222").is_err());
    }

    // --- ssh command construction ---

    #[test]
    fn tunnel_command_default_opts() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/run/user/1000/gritty/ctl.sock",
            &[],
            false,
        );
        let args: Vec<_> =
            cmd.as_std().get_args().map(|a| a.to_string_lossy().to_string()).collect();
        assert!(args.contains(&"ConnectTimeout=5".to_string()));
        // Background mode (foreground = false) must add BatchMode.
        assert!(args.contains(&"BatchMode=yes".to_string()));
        assert!(args.contains(&"ServerAliveInterval=3".to_string()));
        assert!(args.contains(&"StreamLocalBindUnlink=yes".to_string()));
        assert!(args.contains(&"ExitOnForwardFailure=yes".to_string()));
        assert!(args.contains(&"ControlPath=none".to_string()));
        assert!(args.contains(&"ForwardAgent=no".to_string()));
        assert!(args.contains(&"ForwardX11=no".to_string()));
        assert!(args.contains(&"-N".to_string()));
        assert!(args.contains(&"-T".to_string()));
        assert!(args.contains(&"/tmp/local.sock:/run/user/1000/gritty/ctl.sock".to_string()));
        assert!(args.contains(&"user@host".to_string()));
    }

    #[test]
    fn tunnel_command_extra_opts() {
        let dest = Destination::parse("host:2222").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/tmp/remote.sock",
            &["ProxyJump=bastion".to_string()],
            false,
        );
        let args: Vec<_> =
            cmd.as_std().get_args().map(|a| a.to_string_lossy().to_string()).collect();
        assert!(args.contains(&"ProxyJump=bastion".to_string()));
        assert!(args.contains(&"-p".to_string()));
        assert!(args.contains(&"2222".to_string()));
    }

    // --- per-connection file naming ---

    #[test]
    fn local_socket_path_format() {
        let path = local_socket_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.sock");

        let path = local_socket_path("example.com");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-example.com.sock");

        let path = local_socket_path("myproject");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-myproject.sock");
    }

    #[test]
    fn connect_pid_path_format() {
        let path = connect_pid_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.pid");

        let path = connect_pid_path("example.com");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-example.com.pid");
    }

    #[test]
    fn ssh_dest_with_user() {
        let d = Destination::parse("alice@example.com").unwrap();
        assert_eq!(d.ssh_dest(), "alice@example.com");
    }

    #[test]
    fn ssh_dest_without_user() {
        let d = Destination::parse("example.com").unwrap();
        assert_eq!(d.ssh_dest(), "example.com");
    }

    #[test]
    fn port_args_with_port() {
        let d = Destination::parse("host:9999").unwrap();
        assert_eq!(d.port_args(), vec!["-p", "9999"]);
    }

    #[test]
    fn port_args_without_port() {
        let d = Destination::parse("host").unwrap();
        assert!(d.port_args().is_empty());
    }

    // --- shell quoting and dry-run formatting ---

    #[test]
    fn shell_quote_simple() {
        assert_eq!(shell_quote("hello"), "hello");
        assert_eq!(shell_quote("-N"), "-N");
        assert_eq!(shell_quote("ServerAliveInterval=3"), "ServerAliveInterval=3");
        assert_eq!(shell_quote("user@host"), "user@host");
        assert_eq!(
            shell_quote("/tmp/local.sock:/tmp/remote.sock"),
            "/tmp/local.sock:/tmp/remote.sock"
        );
        // `$` is left unquoted so dry-run placeholders still expand.
        assert_eq!(shell_quote("$REMOTE_SOCK"), "$REMOTE_SOCK");
    }

    #[test]
    fn shell_quote_needs_quoting() {
        assert_eq!(shell_quote("hello world"), "'hello world'");
        assert_eq!(shell_quote(""), "''");
        assert_eq!(shell_quote("it's"), "'it'\\''s'");
    }

    #[test]
    fn shell_quote_remote_cmd() {
        // The PATH-wrapped remote command contains quotes/spaces, so it must be single-quoted.
        let cmd = format!("PATH=\"{REMOTE_PATH_PREFIX}\"; gritty socket-path");
        let quoted = shell_quote(&cmd);
        assert!(quoted.starts_with('\''));
        assert!(quoted.ends_with('\''));
    }

    #[test]
    fn format_command_tunnel() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd = tunnel_command(&dest, Path::new("/tmp/local.sock"), "$REMOTE_SOCK", &[], true);
        let formatted = format_command(&cmd);
        assert!(formatted.contains("ServerAliveInterval=3"));
        assert!(formatted.contains("ControlPath=none"));
        assert!(formatted.contains("ForwardAgent=no"));
        assert!(formatted.contains("-N"));
        assert!(formatted.contains("-T"));
        assert!(formatted.contains("/tmp/local.sock:$REMOTE_SOCK"));
        assert!(formatted.contains("user@host"));
    }

    #[test]
    fn format_command_remote_exec() {
        let dest = Destination::parse("user@host:2222").unwrap();
        let cmd = remote_exec_command(&dest, "gritty socket-path", &[], true);
        let formatted = format_command(&cmd);
        assert!(formatted.starts_with("ssh "));
        assert!(formatted.contains("-p 2222"));
        assert!(formatted.contains("ConnectTimeout=5"));
        assert!(formatted.contains("user@host"));
        assert!(formatted.contains(&format!("PATH=\"{REMOTE_PATH_PREFIX}\"")));
    }

    #[test]
    fn format_command_remote_exec_with_extra_opts() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd =
            remote_exec_command(&dest, REMOTE_ENSURE_CMD, &["ProxyJump=bastion".to_string()], true);
        let formatted = format_command(&cmd);
        assert!(formatted.contains("ProxyJump=bastion"));
        assert!(formatted.contains("gritty socket-path"));
        assert!(formatted.contains("gritty server"));
    }

    #[test]
    fn base_ssh_args_foreground() {
        let dest = Destination::parse("user@host:2222").unwrap();
        let args = base_ssh_args(&dest, &["ProxyJump=bastion".into()], true);
        assert!(args.contains(&"-p".to_string()));
        assert!(args.contains(&"2222".to_string()));
        assert!(args.contains(&"ProxyJump=bastion".to_string()));
        assert!(args.contains(&"ConnectTimeout=5".to_string()));
        // Foreground mode must leave interactive prompts possible.
        assert!(!args.contains(&"BatchMode=yes".to_string()));
    }

    #[test]
    fn base_ssh_args_background() {
        let dest = Destination::parse("host").unwrap();
        let args = base_ssh_args(&dest, &[], false);
        assert!(args.contains(&"ConnectTimeout=5".to_string()));
        assert!(args.contains(&"BatchMode=yes".to_string()));
        assert!(!args.contains(&"-p".to_string()));
    }

    #[test]
    fn connect_lock_path_format() {
        let path = connect_lock_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.lock");
    }

    #[test]
    fn connect_dest_path_format() {
        let path = connect_dest_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.dest");
    }

    // --- locking and status probing ---

    #[test]
    fn acquire_and_probe_lock() {
        let dir = tempfile::tempdir().unwrap();
        let lock_path = dir.path().join("test.lock");

        // No file yet: not held.
        assert!(!is_lock_held(&lock_path));

        let _fd = acquire_lock(&lock_path).unwrap();

        // While the descriptor is open the lock is held.
        assert!(is_lock_held(&lock_path));

        drop(_fd);

        // Closing the descriptor releases it.
        assert!(!is_lock_held(&lock_path));
    }

    #[test]
    fn probe_stale_no_files() {
        // Uses the real socket dir; assumes no lock file exists for this made-up name.
        let status = probe_tunnel_status("nonexistent-test-tunnel-xyz");
        assert_eq!(status, TunnelStatus::Stale);
    }

    #[test]
    fn cleanup_stale_files_removes_all() {
        // Only checks that cleaning a name with no files does not panic; the tempdir is unused.
        let _dir = tempfile::tempdir().unwrap();
        cleanup_stale_files("nonexistent-cleanup-test-xyz");
    }

    #[test]
    fn enumerate_tunnels_empty_dir() {
        // Smoke test against the real socket dir; asserts nothing about its contents.
        let names = enumerate_tunnels();
        let _ = names;
    }

    #[test]
    fn connection_socket_path_matches_local() {
        let public_path = connection_socket_path("myhost");
        let internal_path = local_socket_path("myhost");
        assert_eq!(public_path, internal_path);
    }

    // --- tunnel_monitor ---

    #[tokio::test]
    async fn tunnel_monitor_non_transient_exit() {
        // Exit code 1 (not 255) must not be retried.
        let child = Command::new("sh").arg("-c").arg("exit 1").spawn().unwrap();
        let dest = Destination::parse("fake-host-test").unwrap();
        let stop = tokio_util::sync::CancellationToken::new();

        let result = tokio::time::timeout(
            Duration::from_secs(5),
            tunnel_monitor(
                child,
                dest,
                PathBuf::from("/tmp/nonexistent.sock"),
                "/tmp/remote.sock".into(),
                vec![],
                stop,
            ),
        )
        .await;

        assert!(result.is_ok(), "monitor should return quickly on non-transient exit");
    }

    #[tokio::test]
    async fn tunnel_monitor_cancellation() {
        let child = Command::new("sleep").arg("60").spawn().unwrap();
        let dest = Destination::parse("fake-host-test").unwrap();
        let stop = tokio_util::sync::CancellationToken::new();
        let stop_clone = stop.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            stop_clone.cancel();
        });

        let result = tokio::time::timeout(
            Duration::from_secs(5),
            tunnel_monitor(
                child,
                dest,
                PathBuf::from("/tmp/nonexistent.sock"),
                "/tmp/remote.sock".into(),
                vec![],
                stop,
            ),
        )
        .await;

        assert!(result.is_ok(), "monitor should return after cancellation");
    }

    #[tokio::test]
    async fn tunnel_monitor_transient_exit_checks_cancellation() {
        // Exit 255 is treated as transient, so the monitor enters its backoff before respawning.
        let child = Command::new("sh").arg("-c").arg("exit 255").spawn().unwrap();
        let dest = Destination::parse("fake-host-test").unwrap();
        let stop = tokio_util::sync::CancellationToken::new();
        let stop_clone = stop.clone();

        tokio::spawn(async move {
            // Cancel while the monitor is inside its 1s backoff.
            tokio::time::sleep(Duration::from_millis(500)).await;
            stop_clone.cancel();
        });

        let result = tokio::time::timeout(
            Duration::from_secs(5),
            tunnel_monitor(
                child,
                dest,
                PathBuf::from("/tmp/nonexistent.sock"),
                "/tmp/remote.sock".into(),
                vec![],
                stop,
            ),
        )
        .await;

        assert!(result.is_ok(), "monitor should return after cancellation during sleep");
    }

    // --- wait_for_socket ---

    #[tokio::test]
    async fn wait_for_socket_succeeds_after_delay() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("delayed.sock");
        let sock_path_clone = sock_path.clone();

        tokio::spawn(async move {
            // Bind after a delay and keep the listener alive while the waiter polls.
            tokio::time::sleep(Duration::from_millis(500)).await;
            let _listener = tokio::net::UnixListener::bind(&sock_path_clone).unwrap();
            tokio::time::sleep(Duration::from_secs(30)).await;
        });

        let result = wait_for_socket(&sock_path, Duration::from_secs(5)).await;
        assert!(result.is_ok(), "should successfully connect");
    }

    #[tokio::test]
    async fn wait_for_socket_timeout() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("never.sock");
        let result = wait_for_socket(&sock_path, Duration::from_secs(1)).await;
        assert!(result.is_err());
    }
}