1use anyhow::{Context, bail};
2use std::path::{Path, PathBuf};
3use std::process::Stdio;
4use std::time::{Duration, Instant};
5use tokio::process::{Child, Command};
6use tracing::{debug, info, warn};
7
8#[derive(Debug, Clone, PartialEq, Eq)]
13struct Destination {
14 user: Option<String>,
15 host: String,
16 port: Option<u16>,
17}
18
19impl Destination {
20 fn parse(s: &str) -> anyhow::Result<Self> {
21 if s.is_empty() {
22 bail!("empty destination");
23 }
24
25 let (user, remainder) = if let Some(at) = s.find('@') {
26 let u = &s[..at];
27 if u.is_empty() {
28 bail!("empty user in destination: {s}");
29 }
30 (Some(u.to_string()), &s[at + 1..])
31 } else {
32 (None, s)
33 };
34
35 let (host, port) = if let Some(colon) = remainder.rfind(':') {
36 let h = &remainder[..colon];
37 let p = remainder[colon + 1..]
38 .parse::<u16>()
39 .with_context(|| format!("invalid port in destination: {s}"))?;
40 (h.to_string(), Some(p))
41 } else {
42 (remainder.to_string(), None)
43 };
44
45 if host.is_empty() {
46 bail!("empty host in destination: {s}");
47 }
48
49 Ok(Self { user, host, port })
50 }
51
52 fn ssh_dest(&self) -> String {
54 match &self.user {
55 Some(u) => format!("{u}@{}", self.host),
56 None => self.host.clone(),
57 }
58 }
59
60 fn port_args(&self) -> Vec<String> {
62 match self.port {
63 Some(p) => vec!["-p".to_string(), p.to_string()],
64 None => vec![],
65 }
66 }
67}
68
/// Options passed to every tunnel `ssh` invocation (see `tunnel_command`).
///
/// - `ServerAliveInterval=3` / `ServerAliveCountMax=2`: probe the server after
///   3 s of silence and drop the connection after 2 unanswered probes.
/// - `StreamLocalBindUnlink=yes`: remove an existing socket file before
///   binding the local Unix-socket forward.
/// - `ExitOnForwardFailure=yes`: exit rather than stay up without the forward.
/// - `ConnectTimeout=5`: give up connecting after 5 s.
/// - `-N`: run no remote command; `-T`: do not allocate a pseudo-terminal.
#[rustfmt::skip]
const SSH_TUNNEL_OPTS: &[&str] = &[
    "-o", "ServerAliveInterval=3",
    "-o", "ServerAliveCountMax=2",
    "-o", "StreamLocalBindUnlink=yes",
    "-o", "ExitOnForwardFailure=yes",
    "-o", "ConnectTimeout=5",
    "-N",
    "-T",
];
88
89async fn remote_exec(
91 dest: &Destination,
92 remote_cmd: &str,
93 extra_ssh_opts: &[String],
94) -> anyhow::Result<String> {
95 let wrapped_cmd =
98 format!("PATH=\"$HOME/bin:$HOME/.local/bin:$HOME/.cargo/bin:$PATH\"; {remote_cmd}");
99
100 debug!("ssh {}: {remote_cmd}", dest.ssh_dest());
101
102 let mut cmd = Command::new("ssh");
103 cmd.args(dest.port_args());
104 for opt in extra_ssh_opts {
105 cmd.arg("-o").arg(opt);
106 }
107 cmd.arg("-o").arg("ConnectTimeout=5");
108 cmd.arg(dest.ssh_dest());
109 cmd.arg(&wrapped_cmd);
110 cmd.stdout(Stdio::piped());
111 cmd.stderr(Stdio::piped());
112 cmd.stdin(Stdio::null());
113
114 let output = cmd.output().await.context("failed to run ssh")?;
115
116 if !output.status.success() {
117 let stderr = String::from_utf8_lossy(&output.stderr);
118 let stderr = stderr.trim();
119 debug!("ssh failed (status {}): {stderr}", output.status);
120 if stderr.contains("command not found") || stderr.contains("No such file") {
121 bail!("gritty not found on remote host (is it in PATH?)");
122 }
123 bail!("ssh command failed: {stderr}");
124 }
125
126 let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
127 debug!("ssh output: {stdout}");
128 Ok(stdout)
129}
130
131fn tunnel_command(
133 dest: &Destination,
134 local_sock: &Path,
135 remote_sock: &str,
136 extra_ssh_opts: &[String],
137) -> Command {
138 let mut cmd = Command::new("ssh");
139 cmd.args(dest.port_args());
140 cmd.args(SSH_TUNNEL_OPTS);
141 for opt in extra_ssh_opts {
142 cmd.arg("-o").arg(opt);
143 }
144 let forward = format!("{}:{}", local_sock.display(), remote_sock);
145 cmd.arg("-L").arg(forward);
146 cmd.arg(dest.ssh_dest());
147 cmd.stdout(Stdio::null());
148 cmd.stderr(Stdio::piped());
149 cmd.stdin(Stdio::null());
150 cmd
151}
152
153async fn spawn_tunnel(
155 dest: &Destination,
156 local_sock: &Path,
157 remote_sock: &str,
158 extra_ssh_opts: &[String],
159) -> anyhow::Result<Child> {
160 debug!("tunnel: {} -> {}:{}", local_sock.display(), dest.ssh_dest(), remote_sock,);
161 let mut cmd = tunnel_command(dest, local_sock, remote_sock, extra_ssh_opts);
162 let child = cmd.spawn().context("failed to spawn ssh tunnel")?;
163 debug!("ssh tunnel pid: {:?}", child.id());
164 Ok(child)
165}
166
167async fn wait_for_socket(path: &Path) -> anyhow::Result<()> {
169 let deadline = Instant::now() + Duration::from_secs(15);
170 loop {
171 if std::os::unix::net::UnixStream::connect(path).is_ok() {
172 return Ok(());
173 }
174 if Instant::now() >= deadline {
175 bail!("timeout waiting for SSH tunnel socket at {}", path.display());
176 }
177 tokio::time::sleep(Duration::from_millis(200)).await;
178 }
179}
180
181async fn tunnel_monitor(
183 mut child: Child,
184 dest: Destination,
185 local_sock: PathBuf,
186 remote_sock: String,
187 extra_ssh_opts: Vec<String>,
188 stop: tokio_util::sync::CancellationToken,
189) {
190 let mut exit_times: Vec<Instant> = Vec::new();
191
192 loop {
193 tokio::select! {
194 _ = stop.cancelled() => {
195 let _ = child.kill().await;
196 return;
197 }
198 status = child.wait() => {
199 let status = match status {
200 Ok(s) => s,
201 Err(e) => {
202 warn!("failed to wait on ssh tunnel: {e}");
203 return;
204 }
205 };
206
207 if stop.is_cancelled() {
208 return;
209 }
210
211 let code = status.code();
212 debug!("ssh tunnel exited: {:?}", code);
213
214 if let Some(c) = code
218 && c != 255
219 {
220 warn!("ssh tunnel exited with code {c} (not retrying)");
221 return;
222 }
223
224 let now = Instant::now();
226 exit_times.push(now);
227 exit_times.retain(|t| now.duration_since(*t) < Duration::from_secs(10));
228 if exit_times.len() >= 5 {
229 warn!("ssh tunnel failing too fast (5 exits in 10s), giving up");
230 return;
231 }
232
233 tokio::time::sleep(Duration::from_secs(1)).await;
234
235 if stop.is_cancelled() {
236 return;
237 }
238
239 match spawn_tunnel(&dest, &local_sock, &remote_sock, &extra_ssh_opts).await {
240 Ok(new_child) => {
241 info!("ssh tunnel respawned");
242 child = new_child;
243 }
244 Err(e) => {
245 warn!("failed to respawn ssh tunnel: {e}");
246 return;
247 }
248 }
249 }
250 }
251 }
252}
253
/// Remote shell snippet that prints the gritty daemon socket path, starting
/// the daemon first when `gritty ls` fails.
///
/// The steps are chained with `&&`, so any failing step aborts the rest:
/// 1. capture `gritty socket-path` in `$SOCK`;
/// 2. if `gritty ls` (output discarded) fails, run `gritty daemon` and then
///    pause 0.3 s (presumably to let the daemon start listening);
/// 3. print `$SOCK` as the final line of output.
const REMOTE_ENSURE_CMD: &str = concat!(
    "SOCK=$(gritty socket-path) && ",
    "(gritty ls >/dev/null 2>&1 || { gritty daemon && sleep 0.3; }) && ",
    "echo \"$SOCK\"",
);
263
264async fn ensure_remote_ready(
266 dest: &Destination,
267 no_daemon_start: bool,
268 extra_ssh_opts: &[String],
269) -> anyhow::Result<String> {
270 let remote_cmd = if no_daemon_start { "gritty socket-path" } else { REMOTE_ENSURE_CMD };
271 debug!("ensuring remote daemon (no_daemon_start={no_daemon_start})");
272
273 let sock_path = remote_exec(dest, remote_cmd, extra_ssh_opts).await?;
274
275 if sock_path.is_empty() {
276 bail!("remote host returned empty socket path");
277 }
278
279 Ok(sock_path)
280}
281
282fn local_socket_path(destination: &str) -> PathBuf {
292 crate::daemon::socket_dir().join(format!("connect-{destination}.sock"))
293}
294
295fn connect_pid_path(connection_name: &str) -> PathBuf {
296 crate::daemon::socket_dir().join(format!("connect-{connection_name}.pid"))
297}
298
299struct ConnectGuard {
304 child: Option<Child>,
305 local_sock: PathBuf,
306 pid_file: PathBuf,
307 stop: tokio_util::sync::CancellationToken,
308}
309
310impl Drop for ConnectGuard {
311 fn drop(&mut self) {
312 self.stop.cancel();
313
314 if let Some(ref mut child) = self.child
315 && let Some(pid) = child.id()
316 {
317 unsafe {
318 libc::kill(pid as i32, libc::SIGTERM);
319 }
320 }
321
322 let _ = std::fs::remove_file(&self.local_sock);
323 let _ = std::fs::remove_file(&self.pid_file);
324 }
325}
326
/// Options for `run`, the `connect` command.
pub struct ConnectOpts {
    /// Destination as `[user@]host[:port]`.
    pub destination: String,
    /// Connection name used for the local socket and pid file names; when
    /// `None`, the destination's host is used.
    pub name: Option<String>,
    /// Extra ssh options, each passed to ssh as `-o <option>`.
    pub ssh_options: Vec<String>,
    /// When true, only look up the remote socket path (`gritty socket-path`)
    /// instead of also starting the remote daemon if it is not running.
    pub no_daemon_start: bool,
}
337
338pub async fn run(opts: ConnectOpts) -> anyhow::Result<i32> {
339 let dest = Destination::parse(&opts.destination)?;
340 let connection_name = opts.name.unwrap_or_else(|| dest.host.clone());
341
342 let local_sock = local_socket_path(&connection_name);
344 let pid_file = connect_pid_path(&connection_name);
345 debug!("local socket: {}", local_sock.display());
346 if let Some(parent) = local_sock.parent() {
347 crate::security::secure_create_dir_all(parent)?;
348 }
349
350 if std::os::unix::net::UnixStream::connect(&local_sock).is_ok() {
351 let pid_hint =
352 std::fs::read_to_string(&pid_file).ok().and_then(|s| s.trim().parse::<u32>().ok());
353 println!("{}", local_sock.display());
354 eprint!("tunnel already running (name: {connection_name})");
355 if let Some(pid) = pid_hint {
356 eprintln!(" (pid {pid})");
357 eprintln!(" to stop: kill {pid}");
358 } else {
359 eprintln!();
360 }
361 eprintln!(" to use:");
362 eprintln!(" gritty new {connection_name}");
363 eprintln!(" gritty attach {connection_name} -t <name>");
364 return Ok(0);
365 }
366 let _ = std::fs::remove_file(&local_sock);
368
369 eprintln!("starting remote daemon...");
371 let remote_sock = ensure_remote_ready(&dest, opts.no_daemon_start, &opts.ssh_options).await?;
372 debug!(remote_sock, "remote socket path");
373
374 let child = spawn_tunnel(&dest, &local_sock, &remote_sock, &opts.ssh_options).await?;
376 let stop = tokio_util::sync::CancellationToken::new();
377
378 let mut guard = ConnectGuard {
379 child: Some(child),
380 local_sock: local_sock.clone(),
381 pid_file: pid_file.clone(),
382 stop: stop.clone(),
383 };
384
385 wait_for_socket(&local_sock).await?;
387 debug!("tunnel socket ready");
388
389 let _ = std::fs::write(&pid_file, std::process::id().to_string());
391
392 let original_child = guard.child.take().unwrap();
394 let monitor_handle = tokio::spawn(tunnel_monitor(
395 original_child,
396 dest,
397 local_sock.clone(),
398 remote_sock,
399 opts.ssh_options,
400 stop.clone(),
401 ));
402
403 println!("{}", local_sock.display());
405 eprintln!("tunnel ready (name: {connection_name}). to use:");
406 eprintln!(" gritty new {connection_name}");
407 eprintln!(" gritty attach {connection_name} -t <name>");
408
409 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
411 tokio::select! {
412 _ = tokio::signal::ctrl_c() => {}
413 _ = sigterm.recv() => {}
414 _ = monitor_handle => {
415 eprintln!("tunnel lost");
416 }
417 }
418
419 drop(guard);
421
422 Ok(0)
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Collect a command's arguments as owned strings for easy assertions.
    fn arg_strings(cmd: &Command) -> Vec<String> {
        cmd.as_std().get_args().map(|a| a.to_string_lossy().into_owned()).collect()
    }

    #[test]
    fn parse_destination_user_host() {
        let d = Destination::parse("user@host").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_only() {
        let d = Destination::parse("myhost").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "myhost");
        assert_eq!(d.port, None);
    }

    #[test]
    fn parse_destination_host_port() {
        let d = Destination::parse("host:2222").unwrap();
        assert_eq!(d.user, None);
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_user_host_port() {
        let d = Destination::parse("user@host:2222").unwrap();
        assert_eq!(d.user.as_deref(), Some("user"));
        assert_eq!(d.host, "host");
        assert_eq!(d.port, Some(2222));
    }

    #[test]
    fn parse_destination_invalid_empty() {
        assert!(Destination::parse("").is_err());
    }

    #[test]
    fn parse_destination_invalid_at_only() {
        assert!(Destination::parse("@host").is_err());
    }

    #[test]
    fn parse_destination_invalid_colon_only() {
        assert!(Destination::parse(":2222").is_err());
    }

    #[test]
    fn parse_destination_invalid_port() {
        assert!(Destination::parse("host:").is_err());
        assert!(Destination::parse("host:ssh").is_err());
        assert!(Destination::parse("host:70000").is_err());
    }

    #[test]
    fn parse_destination_invalid_empty_host_after_user() {
        assert!(Destination::parse("user@").is_err());
        assert!(Destination::parse("user@:2222").is_err());
    }

    #[test]
    fn tunnel_command_default_opts() {
        let dest = Destination::parse("user@host").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/run/user/1000/gritty/ctl.sock",
            &[],
        );
        let args = arg_strings(&cmd);
        for expected in [
            "ServerAliveInterval=3",
            "StreamLocalBindUnlink=yes",
            "ExitOnForwardFailure=yes",
            "ConnectTimeout=5",
            "-N",
            "-T",
            "/tmp/local.sock:/run/user/1000/gritty/ctl.sock",
            "user@host",
        ] {
            assert!(args.iter().any(|a| a == expected), "missing {expected:?} in {args:?}");
        }
    }

    #[test]
    fn tunnel_command_extra_opts() {
        let dest = Destination::parse("host:2222").unwrap();
        let cmd = tunnel_command(
            &dest,
            Path::new("/tmp/local.sock"),
            "/tmp/remote.sock",
            &["ProxyJump=bastion".to_string()],
        );
        let args = arg_strings(&cmd);
        assert!(args.contains(&"ProxyJump=bastion".to_string()));
        assert!(args.contains(&"-p".to_string()));
        assert!(args.contains(&"2222".to_string()));
    }

    #[test]
    fn tunnel_command_flag_values_follow_their_flags() {
        let dest = Destination::parse("user@host:2222").unwrap();
        let cmd =
            tunnel_command(&dest, Path::new("/tmp/local.sock"), "/tmp/remote.sock", &[]);
        let args = arg_strings(&cmd);

        let l = args.iter().position(|a| a == "-L").expect("-L present");
        assert_eq!(args[l + 1], "/tmp/local.sock:/tmp/remote.sock");

        let p = args.iter().position(|a| a == "-p").expect("-p present");
        assert_eq!(args[p + 1], "2222");

        // The destination must be the final argument, after all options.
        assert_eq!(args.last().map(String::as_str), Some("user@host"));
    }

    #[test]
    fn local_socket_path_format() {
        for (name, file) in [
            ("devbox", "connect-devbox.sock"),
            ("example.com", "connect-example.com.sock"),
            ("myproject", "connect-myproject.sock"),
        ] {
            let path = local_socket_path(name);
            assert_eq!(path.file_name().unwrap().to_string_lossy(), file);
        }
    }

    #[test]
    fn connect_pid_path_format() {
        for (name, file) in
            [("devbox", "connect-devbox.pid"), ("example.com", "connect-example.com.pid")]
        {
            let path = connect_pid_path(name);
            assert_eq!(path.file_name().unwrap().to_string_lossy(), file);
        }
    }

    #[test]
    fn ssh_dest_with_user() {
        let d = Destination::parse("alice@example.com").unwrap();
        assert_eq!(d.ssh_dest(), "alice@example.com");
    }

    #[test]
    fn ssh_dest_without_user() {
        let d = Destination::parse("example.com").unwrap();
        assert_eq!(d.ssh_dest(), "example.com");
    }

    #[test]
    fn ssh_dest_omits_port() {
        let d = Destination::parse("alice@example.com:2222").unwrap();
        assert_eq!(d.ssh_dest(), "alice@example.com");
    }

    #[test]
    fn port_args_with_port() {
        let d = Destination::parse("host:9999").unwrap();
        assert_eq!(d.port_args(), vec!["-p", "9999"]);
    }

    #[test]
    fn port_args_without_port() {
        let d = Destination::parse("host").unwrap();
        assert!(d.port_args().is_empty());
    }
}