1use anyhow::{Context, bail};
2use std::os::fd::OwnedFd;
3use std::os::unix::fs::OpenOptionsExt;
4use std::os::unix::io::AsRawFd;
5use std::path::{Path, PathBuf};
6use std::process::Stdio;
7use std::time::{Duration, Instant};
8use tokio::process::{Child, Command};
9use tracing::{debug, info, warn};
10
11#[derive(Debug, Clone, PartialEq, Eq)]
16struct Destination {
17 user: Option<String>,
18 host: String,
19 port: Option<u16>,
20}
21
22impl Destination {
23 fn parse(s: &str) -> anyhow::Result<Self> {
24 if s.is_empty() {
25 bail!("empty destination");
26 }
27
28 let (user, remainder) = if let Some(at) = s.find('@') {
29 let u = &s[..at];
30 if u.is_empty() {
31 bail!("empty user in destination: {s}");
32 }
33 (Some(u.to_string()), &s[at + 1..])
34 } else {
35 (None, s)
36 };
37
38 let (host, port) = if let Some(colon) = remainder.rfind(':') {
39 let h = &remainder[..colon];
40 let p = remainder[colon + 1..]
41 .parse::<u16>()
42 .with_context(|| format!("invalid port in destination: {s}"))?;
43 (h.to_string(), Some(p))
44 } else {
45 (remainder.to_string(), None)
46 };
47
48 if host.is_empty() {
49 bail!("empty host in destination: {s}");
50 }
51
52 Ok(Self { user, host, port })
53 }
54
55 fn ssh_dest(&self) -> String {
57 match &self.user {
58 Some(u) => format!("{u}@{}", self.host),
59 None => self.host.clone(),
60 }
61 }
62
63 fn port_args(&self) -> Vec<String> {
65 match self.port {
66 Some(p) => vec!["-p".to_string(), p.to_string()],
67 None => vec![],
68 }
69 }
70}
71
/// Extra `-o` options appended (after `base_ssh_args`) to every tunnel ssh invocation.
///
/// Per ssh_config(5), ssh keeps the first value it sees for an option, so user-supplied
/// `extra_ssh_opts` (emitted earlier by `base_ssh_args`) take precedence over these.
const TUNNEL_SSH_OPTS: &[&str] = &[
    // Keepalive probes so a dead connection makes ssh exit (and the monitor respawn it).
    "ServerAliveInterval=3",
    "ServerAliveCountMax=2",
    // Replace a leftover local socket file instead of failing the -L bind.
    "StreamLocalBindUnlink=yes",
    // Exit rather than run without the forward if the -L forwarding cannot be set up.
    "ExitOnForwardFailure=yes",
    // Use a dedicated connection, never a shared ControlMaster multiplex.
    "ControlPath=none",
    // The tunnel only carries the socket forward; disable agent and X11 forwarding.
    "ForwardAgent=no",
    "ForwardX11=no",
];
88
/// PATH value set on the remote (`PATH="..."; <cmd>` in `remote_exec_command`) so that a
/// `gritty` binary installed in common per-user directories is found by non-login shells.
const REMOTE_PATH_PREFIX: &str = "$HOME/bin:$HOME/.local/bin:$HOME/.cargo/bin:$PATH";
92
93fn base_ssh_args(dest: &Destination, extra_ssh_opts: &[String], foreground: bool) -> Vec<String> {
96 let mut args = Vec::new();
97 args.extend(dest.port_args());
98 for opt in extra_ssh_opts {
99 args.push("-o".into());
100 args.push(opt.clone());
101 }
102 args.push("-o".into());
103 args.push("ConnectTimeout=5".into());
104 if !foreground {
105 args.push("-o".into());
106 args.push("BatchMode=yes".into());
107 }
108 args
109}
110
111fn remote_exec_command(
113 dest: &Destination,
114 remote_cmd: &str,
115 extra_ssh_opts: &[String],
116 foreground: bool,
117) -> Command {
118 let wrapped_cmd = format!("PATH=\"{REMOTE_PATH_PREFIX}\"; {remote_cmd}");
119 let mut cmd = Command::new("ssh");
120 cmd.args(base_ssh_args(dest, extra_ssh_opts, foreground));
121 cmd.arg(dest.ssh_dest());
122 cmd.arg(&wrapped_cmd);
123 cmd
124}
125
126async fn remote_exec(
132 dest: &Destination,
133 remote_cmd: &str,
134 extra_ssh_opts: &[String],
135 foreground: bool,
136) -> anyhow::Result<String> {
137 debug!("ssh {}: {remote_cmd}", dest.ssh_dest());
138
139 let mut cmd = remote_exec_command(dest, remote_cmd, extra_ssh_opts, foreground);
140 cmd.stdout(Stdio::piped());
141 cmd.stderr(Stdio::piped());
142 cmd.stdin(Stdio::null());
143
144 let output = cmd.output().await.context("failed to run ssh")?;
145
146 if !output.status.success() {
147 let stderr = String::from_utf8_lossy(&output.stderr);
148 let stderr = stderr.trim();
149 debug!("ssh failed (status {}): {stderr}", output.status);
150 if stderr.contains("command not found") || stderr.contains("No such file") {
151 bail!("gritty not found on remote host (is it in PATH?)");
152 }
153 let diag = format_ssh_diag(dest, extra_ssh_opts, foreground);
154 if stderr.is_empty() {
155 bail!("ssh command failed (exit {})\n to diagnose: {diag}", output.status);
156 }
157 bail!("ssh command failed: {stderr}\n to diagnose: {diag}");
158 }
159
160 let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
161 debug!("ssh output: {stdout}");
162 Ok(stdout)
163}
164
165fn format_ssh_diag(dest: &Destination, extra_ssh_opts: &[String], foreground: bool) -> String {
168 let mut parts = vec!["ssh".to_string()];
169 for arg in base_ssh_args(dest, extra_ssh_opts, foreground) {
170 parts.push(shell_quote(&arg));
171 }
172 parts.push(dest.ssh_dest());
173 parts.join(" ")
174}
175
/// Quotes `s` for display in a POSIX shell command line.
///
/// Strings made only of ASCII alphanumerics and `-_./=:@$+%,` are returned unchanged.
/// `$` is deliberately left unquoted so placeholders such as `$REMOTE_SOCK` still expand.
/// Anything else is wrapped in single quotes, with embedded `'` written as `'\''`.
fn shell_quote(s: &str) -> String {
    const SAFE_PUNCT: &[u8] = b"-_./=:@$+%,";
    let is_safe = |b: u8| b.is_ascii_alphanumeric() || SAFE_PUNCT.contains(&b);

    match s {
        "" => "''".to_string(),
        _ if s.bytes().all(is_safe) => s.to_string(),
        _ => {
            let escaped = s.replace('\'', r"'\''");
            format!("'{escaped}'")
        }
    }
}
187
188fn format_command(cmd: &Command) -> String {
190 let std_cmd = cmd.as_std();
191 let prog = std_cmd.get_program().to_string_lossy();
192 let args: Vec<_> = std_cmd.get_args().map(|a| shell_quote(&a.to_string_lossy())).collect();
193 if args.is_empty() { prog.to_string() } else { format!("{prog} {}", args.join(" ")) }
194}
195
196fn tunnel_command(
201 dest: &Destination,
202 local_sock: &Path,
203 remote_sock: &str,
204 extra_ssh_opts: &[String],
205 foreground: bool,
206) -> Command {
207 let mut cmd = Command::new("ssh");
208 cmd.args(base_ssh_args(dest, extra_ssh_opts, foreground));
209 for opt in TUNNEL_SSH_OPTS {
210 cmd.arg("-o").arg(opt);
211 }
212 cmd.args(["-N", "-T"]);
213 let forward = format!("{}:{}", local_sock.display(), remote_sock);
214 cmd.arg("-L").arg(forward);
215 cmd.arg(dest.ssh_dest());
216 cmd.stdout(Stdio::null());
217 cmd.stderr(Stdio::piped());
218 cmd.stdin(Stdio::null());
219 cmd
220}
221
222async fn spawn_tunnel(
224 dest: &Destination,
225 local_sock: &Path,
226 remote_sock: &str,
227 extra_ssh_opts: &[String],
228 foreground: bool,
229) -> anyhow::Result<Child> {
230 debug!("tunnel: {} -> {}:{}", local_sock.display(), dest.ssh_dest(), remote_sock,);
231 let mut cmd = tunnel_command(dest, local_sock, remote_sock, extra_ssh_opts, foreground);
232 let child = cmd.spawn().context("failed to spawn ssh tunnel")?;
233 debug!("ssh tunnel pid: {:?}", child.id());
234 Ok(child)
235}
236
237async fn wait_for_socket(path: &Path, timeout: Duration) -> anyhow::Result<()> {
239 let deadline = Instant::now() + timeout;
240 loop {
241 if std::os::unix::net::UnixStream::connect(path).is_ok() {
242 return Ok(());
243 }
244 if Instant::now() >= deadline {
245 bail!("timeout waiting for SSH tunnel socket at {}", path.display());
246 }
247 tokio::time::sleep(Duration::from_millis(200)).await;
248 }
249}
250
251async fn tunnel_monitor(
253 mut child: Child,
254 dest: Destination,
255 local_sock: PathBuf,
256 remote_sock: String,
257 extra_ssh_opts: Vec<String>,
258 stop: tokio_util::sync::CancellationToken,
259) {
260 let mut exit_times: Vec<Instant> = Vec::new();
261
262 loop {
263 tokio::select! {
264 _ = stop.cancelled() => {
265 let _ = child.kill().await;
266 return;
267 }
268 status = child.wait() => {
269 let status = match status {
270 Ok(s) => s,
271 Err(e) => {
272 warn!("failed to wait on ssh tunnel: {e}");
273 return;
274 }
275 };
276
277 if stop.is_cancelled() {
278 return;
279 }
280
281 let code = status.code();
282 debug!("ssh tunnel exited: {:?}", code);
283
284 if let Some(c) = code
288 && c != 255
289 {
290 warn!("ssh tunnel exited with code {c} (not retrying)");
291 return;
292 }
293
294 let now = Instant::now();
296 exit_times.push(now);
297 exit_times.retain(|t| now.duration_since(*t) < Duration::from_secs(10));
298 if exit_times.len() >= 5 {
299 warn!("ssh tunnel failing too fast (5 exits in 10s), giving up");
300 return;
301 }
302
303 tokio::time::sleep(Duration::from_secs(1)).await;
304
305 if stop.is_cancelled() {
306 return;
307 }
308
309 match spawn_tunnel(&dest, &local_sock, &remote_sock, &extra_ssh_opts, false).await {
310 Ok(new_child) => {
311 info!("ssh tunnel respawned");
312 child = new_child;
313 }
314 Err(e) => {
315 warn!("failed to respawn ssh tunnel: {e}");
316 return;
317 }
318 }
319 }
320 }
321 }
322}
323
/// Remote shell snippet that resolves the server socket path and starts the server if needed.
///
/// 1. Captures `gritty socket-path` into `SOCK`.
/// 2. If `gritty ls` fails, runs `gritty server` and sleeps 0.3s.
/// 3. Echoes `$SOCK` as the last thing written to stdout.
///
/// The `&&` chain means any failing step aborts the snippet with a non-zero status.
const REMOTE_ENSURE_CMD: &str = "\
 SOCK=$(gritty socket-path) && \
 (gritty ls >/dev/null 2>&1 || \
 { gritty server && sleep 0.3; }) && \
 echo \"$SOCK\"";
333
334async fn ensure_remote_ready(
336 dest: &Destination,
337 no_server_start: bool,
338 extra_ssh_opts: &[String],
339 foreground: bool,
340) -> anyhow::Result<String> {
341 let remote_cmd = if no_server_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
342 debug!("ensuring remote server (no_server_start={no_server_start})");
343
344 let sock_path = remote_exec(dest, remote_cmd, extra_ssh_opts, foreground).await?;
345
346 if sock_path.is_empty() {
347 bail!("remote host returned empty socket path");
348 }
349
350 Ok(sock_path)
351}
352
353fn local_socket_path(destination: &str) -> PathBuf {
363 crate::daemon::socket_dir().join(format!("connect-{destination}.sock"))
364}
365
366fn connect_pid_path(connection_name: &str) -> PathBuf {
367 crate::daemon::socket_dir().join(format!("connect-{connection_name}.pid"))
368}
369
370fn connect_lock_path(connection_name: &str) -> PathBuf {
371 crate::daemon::socket_dir().join(format!("connect-{connection_name}.lock"))
372}
373
374fn connect_dest_path(connection_name: &str) -> PathBuf {
375 crate::daemon::socket_dir().join(format!("connect-{connection_name}.dest"))
376}
377
378pub fn connection_socket_path(connection_name: &str) -> PathBuf {
381 local_socket_path(connection_name)
382}
383
384pub fn parse_host(destination: &str) -> anyhow::Result<String> {
386 Ok(Destination::parse(destination)?.host)
387}
388
389fn acquire_lock(lock_path: &Path) -> anyhow::Result<OwnedFd> {
396 use std::fs::OpenOptions;
397 let file = OpenOptions::new()
398 .create(true)
399 .truncate(false)
400 .write(true)
401 .mode(0o600)
402 .open(lock_path)
403 .with_context(|| format!("failed to open lockfile: {}", lock_path.display()))?;
404 let fd = OwnedFd::from(file);
405 if unsafe { libc::flock(fd.as_raw_fd(), libc::LOCK_EX) } != 0 {
406 bail!("failed to acquire lock on {}", lock_path.display());
407 }
408 Ok(fd)
409}
410
411fn is_lock_held(lock_path: &Path) -> bool {
414 use std::fs::OpenOptions;
415 let file = match OpenOptions::new().read(true).open(lock_path) {
416 Ok(f) => f,
417 Err(_) => return false,
418 };
419 if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) } == 0 {
421 false
423 } else {
424 true }
426}
427
#[derive(Debug, PartialEq, Eq)]
/// State of a named tunnel as determined by `probe_tunnel_status`.
pub enum TunnelStatus {
    /// The lock is held and the local socket accepts connections.
    Healthy,
    /// The lock is held but the local socket does not accept connections (e.g. mid-respawn).
    Reconnecting,
    /// Nobody holds the lock; any leftover files are stale.
    Stale,
}
435
436fn probe_tunnel_status(name: &str) -> TunnelStatus {
438 let lock_path = connect_lock_path(name);
439 if is_lock_held(&lock_path) {
440 let sock_path = local_socket_path(name);
441 if std::os::unix::net::UnixStream::connect(&sock_path).is_ok() {
442 TunnelStatus::Healthy
443 } else {
444 TunnelStatus::Reconnecting
445 }
446 } else {
447 TunnelStatus::Stale
448 }
449}
450
451fn read_pid_hint(name: &str) -> Option<u32> {
455 std::fs::read_to_string(connect_pid_path(name)).ok().and_then(|s| s.trim().parse().ok())
456}
457
458fn cleanup_stale_files(name: &str) {
459 let _ = std::fs::remove_file(local_socket_path(name));
460 let _ = std::fs::remove_file(connect_pid_path(name));
461 let _ = std::fs::remove_file(connect_lock_path(name));
462 let _ = std::fs::remove_file(connect_dest_path(name));
463}
464
465fn enumerate_tunnels() -> Vec<String> {
467 let dir = crate::daemon::socket_dir();
468 let Ok(entries) = std::fs::read_dir(&dir) else {
469 return Vec::new();
470 };
471 entries
472 .filter_map(|e| e.ok())
473 .filter_map(|e| {
474 let name = e.file_name().to_string_lossy().to_string();
475 if name.starts_with("connect-") && name.ends_with(".lock") {
476 Some(name["connect-".len()..name.len() - ".lock".len()].to_string())
477 } else {
478 None
479 }
480 })
481 .collect()
482}
483
484struct ConnectGuard {
489 child: Option<Child>,
490 local_sock: PathBuf,
491 pid_file: PathBuf,
492 lock_file: PathBuf,
493 dest_file: PathBuf,
494 _lock_fd: Option<OwnedFd>,
495 stop: tokio_util::sync::CancellationToken,
496}
497
498impl Drop for ConnectGuard {
499 fn drop(&mut self) {
500 self.stop.cancel();
501
502 if let Some(ref mut child) = self.child
503 && let Some(pid) = child.id()
504 {
505 unsafe {
506 libc::kill(pid as i32, libc::SIGTERM);
507 }
508 }
509
510 let _ = std::fs::remove_file(&self.local_sock);
511 let _ = std::fs::remove_file(&self.pid_file);
512 let _ = std::fs::remove_file(&self.lock_file);
513 let _ = std::fs::remove_file(&self.dest_file);
514 }
516}
517
/// Options for `run` (the `connect` command).
pub struct ConnectOpts {
    /// `[user@]host[:port]`, parsed by `Destination::parse`.
    pub destination: String,
    /// Only query `gritty socket-path` on the remote; do not try to start a server.
    pub no_server_start: bool,
    /// Extra ssh options, each passed as `-o <opt>`.
    pub ssh_options: Vec<String>,
    /// Connection name for local files; defaults to the destination host.
    pub name: Option<String>,
    /// Print the equivalent shell commands instead of connecting.
    pub dry_run: bool,
    /// When true, `BatchMode=yes` is not added to the initial ssh invocations.
    pub foreground: bool,
}
530
531pub async fn run(opts: ConnectOpts, ready_fd: Option<OwnedFd>) -> anyhow::Result<i32> {
532 let dest = Destination::parse(&opts.destination)?;
533 let connection_name = opts.name.unwrap_or_else(|| dest.host.clone());
534 let local_sock = local_socket_path(&connection_name);
535
536 if opts.dry_run {
537 let remote_cmd =
538 if opts.no_server_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
539 let ensure_cmd = remote_exec_command(&dest, remote_cmd, &opts.ssh_options, true);
540 let tunnel_cmd =
541 tunnel_command(&dest, &local_sock, "$REMOTE_SOCK", &opts.ssh_options, true);
542
543 println!(
544 "# Get remote socket path{}",
545 if opts.no_server_start { "" } else { " and start server if needed" }
546 );
547 println!("REMOTE_SOCK=$({})", format_command(&ensure_cmd));
548 println!();
549 println!("# Start SSH tunnel");
550 println!("{}", format_command(&tunnel_cmd));
551 return Ok(0);
552 }
553
554 let pid_file = connect_pid_path(&connection_name);
556 let lock_path = connect_lock_path(&connection_name);
557 let dest_file = connect_dest_path(&connection_name);
558 debug!("local socket: {}", local_sock.display());
559 if let Some(parent) = local_sock.parent() {
560 crate::security::secure_create_dir_all(parent)?;
561 }
562
563 match probe_tunnel_status(&connection_name) {
565 TunnelStatus::Healthy => {
566 println!("{}", local_sock.display());
567 let pid_hint = read_pid_hint(&connection_name);
568 eprint!("tunnel already running (name: {connection_name})");
569 if let Some(pid) = pid_hint {
570 eprintln!(" (pid {pid})");
571 eprintln!(" to stop: gritty disconnect {connection_name}");
572 } else {
573 eprintln!();
574 }
575 eprintln!(" to use:");
576 eprintln!(" gritty new {connection_name}");
577 eprintln!(" gritty attach {connection_name} -t <name>");
578 signal_ready(&ready_fd);
580 return Ok(0);
581 }
582 TunnelStatus::Reconnecting => {
583 let pid_hint = read_pid_hint(&connection_name);
584 eprint!("tunnel exists but is reconnecting (name: {connection_name})");
585 if let Some(pid) = pid_hint {
586 eprintln!(" (pid {pid})");
587 } else {
588 eprintln!();
589 }
590 eprintln!(" wait for it, or: gritty disconnect {connection_name}");
591 signal_ready(&ready_fd);
593 return Ok(0);
594 }
595 TunnelStatus::Stale => {
596 debug!("cleaning stale tunnel files for {connection_name}");
597 cleanup_stale_files(&connection_name);
598 }
599 }
600
601 let lock_fd = acquire_lock(&lock_path)?;
603
604 let remote_sock =
606 ensure_remote_ready(&dest, opts.no_server_start, &opts.ssh_options, opts.foreground)
607 .await?;
608 debug!(remote_sock, "remote socket path");
609
610 let child =
612 spawn_tunnel(&dest, &local_sock, &remote_sock, &opts.ssh_options, opts.foreground).await?;
613 let stop = tokio_util::sync::CancellationToken::new();
614
615 let mut guard = ConnectGuard {
616 child: Some(child),
617 local_sock: local_sock.clone(),
618 pid_file: pid_file.clone(),
619 lock_file: lock_path,
620 dest_file: dest_file.clone(),
621 _lock_fd: Some(lock_fd),
622 stop: stop.clone(),
623 };
624
625 let mut child = guard.child.take().unwrap();
627 tokio::select! {
628 result = wait_for_socket(&local_sock, Duration::from_secs(15)) => {
629 result?;
630 guard.child = Some(child);
631 }
632 status = child.wait() => {
633 let status = status.context("failed to wait on ssh tunnel")?;
634 let diag = format_ssh_diag(&dest, &opts.ssh_options, opts.foreground);
635 let msg = if let Some(mut stderr) = child.stderr.take() {
636 use tokio::io::AsyncReadExt;
637 let mut buf = String::new();
638 let _ = stderr.read_to_string(&mut buf).await;
639 let buf = buf.trim().to_string();
640 if buf.is_empty() { None } else { Some(buf) }
641 } else {
642 None
643 };
644 match msg {
645 Some(err) => bail!("ssh tunnel failed: {err}\n to diagnose: {diag}"),
646 None => bail!("ssh tunnel exited ({status})\n to diagnose: {diag}"),
647 }
648 }
649 }
650 debug!("tunnel socket ready");
651
652 let _ = std::fs::write(&pid_file, std::process::id().to_string());
654 let _ = std::fs::write(&dest_file, &opts.destination);
655
656 signal_ready(&ready_fd);
658
659 let original_child = guard.child.take().unwrap();
661 let monitor_handle = tokio::spawn(tunnel_monitor(
662 original_child,
663 dest,
664 local_sock.clone(),
665 remote_sock,
666 opts.ssh_options,
667 stop.clone(),
668 ));
669
670 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
672 tokio::select! {
673 _ = sigterm.recv() => {}
674 _ = monitor_handle => {}
675 }
676
677 drop(guard);
679
680 Ok(0)
681}
682
683fn signal_ready(ready_fd: &Option<OwnedFd>) {
685 if let Some(fd) = ready_fd {
686 let _ = nix::unistd::write(fd, b"\x01");
687 }
688}
689
690pub async fn disconnect(name: &str) -> anyhow::Result<()> {
695 match probe_tunnel_status(name) {
696 TunnelStatus::Stale => {
697 cleanup_stale_files(name);
698 eprintln!("tunnel already stopped: {name}");
699 return Ok(());
700 }
701 TunnelStatus::Healthy | TunnelStatus::Reconnecting => {}
702 }
703
704 let pid_file = connect_pid_path(name);
706 let pid = std::fs::read_to_string(&pid_file)
707 .ok()
708 .and_then(|s| s.trim().parse::<i32>().ok())
709 .ok_or_else(|| anyhow::anyhow!("cannot read PID for tunnel {name}"))?;
710
711 unsafe {
712 libc::kill(pid, libc::SIGTERM);
713 }
714
715 let deadline = Instant::now() + Duration::from_secs(2);
717 loop {
718 tokio::time::sleep(Duration::from_millis(100)).await;
719 if !is_lock_held(&connect_lock_path(name)) {
720 cleanup_stale_files(name);
721 eprintln!("tunnel stopped: {name}");
722 return Ok(());
723 }
724 if Instant::now() >= deadline {
725 break;
726 }
727 }
728
729 unsafe {
731 libc::kill(pid, libc::SIGKILL);
732 libc::killpg(pid, libc::SIGTERM);
733 }
734 tokio::time::sleep(Duration::from_millis(100)).await;
735 cleanup_stale_files(name);
736 eprintln!("tunnel killed: {name}");
737 Ok(())
738}
739
/// Summary of one live (non-stale) tunnel, as produced by `get_tunnel_info`.
pub struct TunnelInfo {
    /// Connection name (the `{name}` in `connect-{name}.*`).
    pub name: String,
    /// Contents of the `.dest` file (trimmed), or `"-"` if it could not be read.
    pub destination: String,
    /// `"healthy"` or `"reconnecting"`.
    pub status: String,
    /// Pid recorded in the `.pid` file, if readable.
    pub pid: Option<u32>,
    /// `connect-{name}.log` in the socket directory (not checked for existence).
    pub log_path: PathBuf,
}
751
752pub fn get_tunnel_info() -> Vec<TunnelInfo> {
754 let names = enumerate_tunnels();
755 let mut infos = Vec::new();
756 for name in &names {
757 let status = probe_tunnel_status(name);
758 if status == TunnelStatus::Stale {
759 debug!("cleaning stale tunnel: {name}");
760 cleanup_stale_files(name);
761 continue;
762 }
763 let dest =
764 std::fs::read_to_string(connect_dest_path(name)).unwrap_or_else(|_| "-".to_string());
765 let status_str = match status {
766 TunnelStatus::Healthy => "healthy".to_string(),
767 TunnelStatus::Reconnecting => "reconnecting".to_string(),
768 TunnelStatus::Stale => unreachable!(),
769 };
770 infos.push(TunnelInfo {
771 name: name.clone(),
772 destination: dest.trim().to_string(),
773 status: status_str,
774 pid: read_pid_hint(name),
775 log_path: crate::daemon::socket_dir().join(format!("connect-{name}.log")),
776 });
777 }
778 infos
779}
780
781pub fn list_tunnels() {
782 let infos = get_tunnel_info();
783 if infos.is_empty() {
784 println!("no active tunnels");
785 return;
786 }
787
788 let w_name = infos.iter().map(|i| i.name.len()).max().unwrap().max(4);
789 let w_dest = infos.iter().map(|i| i.destination.len()).max().unwrap().max(11);
790
791 println!("{:<w_name$} {:<w_dest$} Status", "Name", "Destination");
792 for info in &infos {
793 println!("{:<w_name$} {:<w_dest$} {}", info.name, info.destination, info.status);
794 }
795}
796
#[cfg(test)]
// Unit tests for destination parsing, ssh command construction, path helpers,
// lock probing, and the async tunnel helpers.
mod tests {
    use super::*;

    // --- Destination::parse ---

    #[test]
    fn parse_destination_user_host() {
        let d = Destination::parse("user@host").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_only() {
        let d = Destination::parse("myhost").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "myhost");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_port() {
        let d = Destination::parse("host:2222").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_user_host_port() {
        let d = Destination::parse("user@host:2222").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_invalid_empty() {
        assert!(Destination::parse("").is_err());
    }

    #[test]
    fn parse_destination_invalid_at_only() {
        assert!(Destination::parse("@host").is_err());
    }

    #[test]
    fn parse_destination_invalid_colon_only() {
        assert!(Destination::parse(":2222").is_err());
    }

    // --- ssh command construction ---

    // Checks that base, tunnel-specific and forwarding arguments are all present.
    #[test]
    fn tunnel_command_default_opts() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/run/user/1000/gritty/ctl.sock",
            &[],
            false,
        );
        let args: Vec<_> =
            cmd.as_std().get_args().map(|a| a.to_string_lossy().to_string()).collect();
        assert!(args.contains(&"ConnectTimeout=5".to_string()));
        assert!(args.contains(&"BatchMode=yes".to_string()));
        assert!(args.contains(&"ServerAliveInterval=3".to_string()));
        assert!(args.contains(&"StreamLocalBindUnlink=yes".to_string()));
        assert!(args.contains(&"ExitOnForwardFailure=yes".to_string()));
        assert!(args.contains(&"ControlPath=none".to_string()));
        assert!(args.contains(&"ForwardAgent=no".to_string()));
        assert!(args.contains(&"ForwardX11=no".to_string()));
        assert!(args.contains(&"-N".to_string()));
        assert!(args.contains(&"-T".to_string()));
        assert!(args.contains(&"/tmp/local.sock:/run/user/1000/gritty/ctl.sock".to_string()));
        assert!(args.contains(&"user@host".to_string()));
    }

    #[test]
    fn tunnel_command_extra_opts() {
        let dest = Destination::parse("host:2222").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/tmp/remote.sock",
            &["ProxyJump=bastion".to_string()],
            false,
        );
        let args: Vec<_> =
            cmd.as_std().get_args().map(|a| a.to_string_lossy().to_string()).collect();
        assert!(args.contains(&"ProxyJump=bastion".to_string()));
        assert!(args.contains(&"-p".to_string()));
        assert!(args.contains(&"2222".to_string()));
    }

    // --- path helpers (only the file name is checked; the directory is environment-dependent) ---

    #[test]
    fn local_socket_path_format() {
        let path = local_socket_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.sock");

        let path = local_socket_path("example.com");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-example.com.sock");

        let path = local_socket_path("myproject");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-myproject.sock");
    }

    #[test]
    fn connect_pid_path_format() {
        let path = connect_pid_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.pid");

        let path = connect_pid_path("example.com");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-example.com.pid");
    }

    // --- Destination helpers ---

    #[test]
    fn ssh_dest_with_user() {
        let d = Destination::parse("alice@example.com").unwrap();
        assert_eq!(d.ssh_dest(), "alice@example.com");
    }

    #[test]
    fn ssh_dest_without_user() {
        let d = Destination::parse("example.com").unwrap();
        assert_eq!(d.ssh_dest(), "example.com");
    }

    #[test]
    fn port_args_with_port() {
        let d = Destination::parse("host:9999").unwrap();
        assert_eq!(d.port_args(), vec!["-p", "9999"]);
    }

    #[test]
    fn port_args_without_port() {
        let d = Destination::parse("host").unwrap();
        assert!(d.port_args().is_empty());
    }

    // --- shell quoting / command formatting for dry runs and diagnostics ---

    #[test]
    fn shell_quote_simple() {
        assert_eq!(shell_quote("hello"), "hello");
        assert_eq!(shell_quote("-N"), "-N");
        assert_eq!(shell_quote("ServerAliveInterval=3"), "ServerAliveInterval=3");
        assert_eq!(shell_quote("user@host"), "user@host");
        assert_eq!(
            shell_quote("/tmp/local.sock:/tmp/remote.sock"),
            "/tmp/local.sock:/tmp/remote.sock"
        );
        // `$` stays unquoted so placeholders expand in the printed dry-run script.
        assert_eq!(shell_quote("$REMOTE_SOCK"), "$REMOTE_SOCK");
    }

    #[test]
    fn shell_quote_needs_quoting() {
        assert_eq!(shell_quote("hello world"), "'hello world'");
        assert_eq!(shell_quote(""), "''");
        assert_eq!(shell_quote("it's"), "'it'\\''s'");
    }

    #[test]
    fn shell_quote_remote_cmd() {
        let cmd = format!("PATH=\"{REMOTE_PATH_PREFIX}\"; gritty socket-path");
        let quoted = shell_quote(&cmd);
        assert!(quoted.starts_with('\''));
        assert!(quoted.ends_with('\''));
    }

    #[test]
    fn format_command_tunnel() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd = tunnel_command(&dest, Path::new("/tmp/local.sock"), "$REMOTE_SOCK", &[], true);
        let formatted = format_command(&cmd);
        assert!(formatted.contains("ServerAliveInterval=3"));
        assert!(formatted.contains("ControlPath=none"));
        assert!(formatted.contains("ForwardAgent=no"));
        assert!(formatted.contains("-N"));
        assert!(formatted.contains("-T"));
        assert!(formatted.contains("/tmp/local.sock:$REMOTE_SOCK"));
        assert!(formatted.contains("user@host"));
    }

    #[test]
    fn format_command_remote_exec() {
        let dest = Destination::parse("user@host:2222").unwrap();
        let cmd = remote_exec_command(&dest, "gritty socket-path", &[], true);
        let formatted = format_command(&cmd);
        assert!(formatted.starts_with("ssh "));
        assert!(formatted.contains("-p 2222"));
        assert!(formatted.contains("ConnectTimeout=5"));
        assert!(formatted.contains("user@host"));
        assert!(formatted.contains(&format!("PATH=\"{REMOTE_PATH_PREFIX}\"")));
    }

    #[test]
    fn format_command_remote_exec_with_extra_opts() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd =
            remote_exec_command(&dest, REMOTE_ENSURE_CMD, &["ProxyJump=bastion".to_string()], true);
        let formatted = format_command(&cmd);
        assert!(formatted.contains("ProxyJump=bastion"));
        assert!(formatted.contains("gritty socket-path"));
        assert!(formatted.contains("gritty server"));
    }

    // --- base_ssh_args: BatchMode only in background mode ---

    #[test]
    fn base_ssh_args_foreground() {
        let dest = Destination::parse("user@host:2222").unwrap();
        let args = base_ssh_args(&dest, &["ProxyJump=bastion".into()], true);
        assert!(args.contains(&"-p".to_string()));
        assert!(args.contains(&"2222".to_string()));
        assert!(args.contains(&"ProxyJump=bastion".to_string()));
        assert!(args.contains(&"ConnectTimeout=5".to_string()));
        assert!(!args.contains(&"BatchMode=yes".to_string()));
    }

    #[test]
    fn base_ssh_args_background() {
        let dest = Destination::parse("host").unwrap();
        let args = base_ssh_args(&dest, &[], false);
        assert!(args.contains(&"ConnectTimeout=5".to_string()));
        assert!(args.contains(&"BatchMode=yes".to_string()));
        assert!(!args.contains(&"-p".to_string()));
    }

    #[test]
    fn connect_lock_path_format() {
        let path = connect_lock_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.lock");
    }

    #[test]
    fn connect_dest_path_format() {
        let path = connect_dest_path("devbox");
        assert_eq!(path.file_name().unwrap().to_string_lossy(), "connect-devbox.dest");
    }

    // --- locking and tunnel probing ---

    // The lock is held exactly while the descriptor returned by acquire_lock is alive.
    #[test]
    fn acquire_and_probe_lock() {
        let dir = tempfile::tempdir().unwrap();
        let lock_path = dir.path().join("test.lock");

        assert!(!is_lock_held(&lock_path));

        let _fd = acquire_lock(&lock_path).unwrap();

        assert!(is_lock_held(&lock_path));

        drop(_fd);

        assert!(!is_lock_held(&lock_path));
    }

    #[test]
    fn probe_stale_no_files() {
        let status = probe_tunnel_status("nonexistent-test-tunnel-xyz");
        assert_eq!(status, TunnelStatus::Stale);
    }

    // Smoke test: removing files that do not exist must not panic (`_dir` is unused).
    #[test]
    fn cleanup_stale_files_removes_all() {
        let _dir = tempfile::tempdir().unwrap();
        cleanup_stale_files("nonexistent-cleanup-test-xyz");
    }

    // Smoke test: only checks that enumerating the real socket directory does not panic.
    #[test]
    fn enumerate_tunnels_empty_dir() {
        let names = enumerate_tunnels();
        let _ = names;
    }

    #[test]
    fn connection_socket_path_matches_local() {
        let public_path = connection_socket_path("myhost");
        let internal_path = local_socket_path("myhost");
        assert_eq!(public_path, internal_path);
    }

    // --- tunnel_monitor (uses local `sh`/`sleep` children instead of ssh) ---

    // Exit code 1 is not ssh's 255, so the monitor must return without retrying.
    #[tokio::test]
    async fn tunnel_monitor_non_transient_exit() {
        let child = Command::new("sh").arg("-c").arg("exit 1").spawn().unwrap();
        let dest = Destination::parse("fake-host-test").unwrap();
        let stop = tokio_util::sync::CancellationToken::new();

        let result = tokio::time::timeout(
            Duration::from_secs(5),
            tunnel_monitor(
                child,
                dest,
                PathBuf::from("/tmp/nonexistent.sock"),
                "/tmp/remote.sock".into(),
                vec![],
                stop,
            ),
        )
        .await;

        assert!(result.is_ok(), "monitor should return quickly on non-transient exit");
    }

    #[tokio::test]
    async fn tunnel_monitor_cancellation() {
        let child = Command::new("sleep").arg("60").spawn().unwrap();
        let dest = Destination::parse("fake-host-test").unwrap();
        let stop = tokio_util::sync::CancellationToken::new();
        let stop_clone = stop.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            stop_clone.cancel();
        });

        let result = tokio::time::timeout(
            Duration::from_secs(5),
            tunnel_monitor(
                child,
                dest,
                PathBuf::from("/tmp/nonexistent.sock"),
                "/tmp/remote.sock".into(),
                vec![],
                stop,
            ),
        )
        .await;

        assert!(result.is_ok(), "monitor should return after cancellation");
    }

    // Exit 255 is treated as transient; cancellation during the retry delay must end the monitor.
    #[tokio::test]
    async fn tunnel_monitor_transient_exit_checks_cancellation() {
        let child = Command::new("sh").arg("-c").arg("exit 255").spawn().unwrap();
        let dest = Destination::parse("fake-host-test").unwrap();
        let stop = tokio_util::sync::CancellationToken::new();
        let stop_clone = stop.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(500)).await;
            stop_clone.cancel();
        });

        let result = tokio::time::timeout(
            Duration::from_secs(5),
            tunnel_monitor(
                child,
                dest,
                PathBuf::from("/tmp/nonexistent.sock"),
                "/tmp/remote.sock".into(),
                vec![],
                stop,
            ),
        )
        .await;

        assert!(result.is_ok(), "monitor should return after cancellation during sleep");
    }

    // --- wait_for_socket ---

    // A listener bound after 500ms must be detected before the 5s timeout.
    #[tokio::test]
    async fn wait_for_socket_succeeds_after_delay() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("delayed.sock");
        let sock_path_clone = sock_path.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(500)).await;
            let _listener = tokio::net::UnixListener::bind(&sock_path_clone).unwrap();
            tokio::time::sleep(Duration::from_secs(30)).await;
        });

        let result = wait_for_socket(&sock_path, Duration::from_secs(5)).await;
        assert!(result.is_ok(), "should successfully connect");
    }

    #[tokio::test]
    async fn wait_for_socket_timeout() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("never.sock");
        let result = wait_for_socket(&sock_path, Duration::from_secs(1)).await;
        assert!(result.is_err());
    }
}