Skip to main content

gritty/
daemon.rs

1use crate::protocol::{Frame, FrameCodec, PROTOCOL_VERSION, SessionEntry};
2use crate::server::{self, ClientConn, SessionMetadata};
3use futures_util::{SinkExt, StreamExt};
4use std::collections::HashMap;
5use std::os::fd::OwnedFd;
6use std::path::{Path, PathBuf};
7use std::sync::atomic::Ordering;
8use std::sync::{Arc, OnceLock};
9use std::time::Duration;
10use tokio::net::UnixStream;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13use tokio_util::codec::Framed;
14use tracing::{error, info, warn};
15
16const SEND_TIMEOUT: Duration = Duration::from_secs(5);
17
18/// Send a frame with a timeout. Returns `Ok(())` on success, `Err` on
19/// send failure or timeout (error is logged before returning).
20async fn timed_send(
21    framed: &mut Framed<UnixStream, FrameCodec>,
22    frame: Frame,
23) -> Result<(), std::io::Error> {
24    match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
25        Ok(Ok(())) => Ok(()),
26        Ok(Err(e)) => {
27            warn!("control send error: {e}");
28            Err(e)
29        }
30        Err(_) => {
31            warn!("control send timed out");
32            Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "send timed out"))
33        }
34    }
35}
36
37struct SessionState {
38    handle: JoinHandle<anyhow::Result<()>>,
39    metadata: Arc<OnceLock<SessionMetadata>>,
40    client_tx: mpsc::UnboundedSender<ClientConn>,
41    name: Option<String>,
42}
43
44/// Returns the base directory for gritty sockets.
45/// Prefers $GRITTY_SOCKET_DIR, then $XDG_RUNTIME_DIR/gritty, falls back to /tmp/gritty-$UID.
46pub fn socket_dir() -> PathBuf {
47    if let Ok(dir) = std::env::var("GRITTY_SOCKET_DIR") {
48        return PathBuf::from(dir);
49    }
50    if let Some(proj) = directories::ProjectDirs::from("", "", "gritty") {
51        if let Some(runtime) = proj.runtime_dir() {
52            return runtime.to_path_buf();
53        }
54    }
55    let uid = unsafe { libc::getuid() };
56    PathBuf::from(format!("/tmp/gritty-{uid}"))
57}
58
59/// Returns the daemon socket path.
60pub fn control_socket_path() -> PathBuf {
61    socket_dir().join("ctl.sock")
62}
63
/// Path of the daemon PID file: `daemon.pid` in the same directory as `ctl_path`.
pub fn pid_file_path(ctl_path: &Path) -> PathBuf {
    let mut pid_path = ctl_path.to_path_buf();
    // Swap the final component (the socket file name) for the PID file name.
    pid_path.set_file_name("daemon.pid");
    pid_path
}
68
69fn reap_sessions(sessions: &mut HashMap<u32, SessionState>) {
70    sessions.retain(|id, state| {
71        if state.handle.is_finished() {
72            info!(id, "session ended");
73            false
74        } else {
75            true
76        }
77    });
78}
79
80/// Resolve a session identifier (name, id string, or "-" for last attached) to a session id.
81fn resolve_session(
82    sessions: &HashMap<u32, SessionState>,
83    target: &str,
84    last_attached: Option<u32>,
85) -> Option<u32> {
86    // "-" means last attached session
87    if target == "-" {
88        return last_attached.filter(|id| sessions.contains_key(id));
89    }
90    // Try name match first
91    for (&id, state) in sessions {
92        if state.name.as_deref() == Some(target) {
93            return Some(id);
94        }
95    }
96    // Then try parsing as numeric id
97    if let Ok(id) = target.parse::<u32>()
98        && sessions.contains_key(&id)
99    {
100        return Some(id);
101    }
102    None
103}
104
/// Read the foreground process command name for a shell pid via /proc.
///
/// Looks up `tpgid` (field 8 of `/proc/{shell_pid}/stat`) and returns the
/// trimmed contents of `/proc/{tpgid}/comm`.
///
/// Returns "-" on any failure (non-Linux, permission denied, malformed or
/// truncated stat line, no foreground process group, etc.). Never panics.
fn foreground_process(shell_pid: u32) -> String {
    const UNKNOWN: &str = "-";

    let Ok(stat) = std::fs::read_to_string(format!("/proc/{shell_pid}/stat")) else {
        return UNKNOWN.to_string();
    };

    // Field 2 (comm) is wrapped in parentheses and may itself contain spaces
    // or ')', so anchor on the LAST closing paren.
    let Some(comm_end) = stat.rfind(')') else {
        return UNKNOWN.to_string();
    };

    // ')' is a single ASCII byte, so `comm_end + 1` is always in bounds and on
    // a char boundary, even if the line ends right there. `split_whitespace`
    // skips the separating space for us.
    //
    // Fields after comm: state(3) ppid(4) pgrp(5) session(6) tty_nr(7) tpgid(8),
    // so tpgid is the element at index 5. A tpgid of -1 (no controlling
    // terminal) fails the u32 parse; 0 is filtered out explicitly.
    let tpgid = stat[comm_end + 1..]
        .split_whitespace()
        .nth(5)
        .and_then(|field| field.parse::<u32>().ok())
        .filter(|&pgid| pgid > 0);
    let Some(tpgid) = tpgid else {
        return UNKNOWN.to_string();
    };

    std::fs::read_to_string(format!("/proc/{tpgid}/comm"))
        .map(|comm| comm.trim().to_string())
        .unwrap_or_else(|_| UNKNOWN.to_string())
}
131
132fn build_session_entries(sessions: &HashMap<u32, SessionState>) -> Vec<SessionEntry> {
133    let mut entries: Vec<_> = sessions
134        .iter()
135        .map(|(&id, state)| {
136            if let Some(meta) = state.metadata.get() {
137                SessionEntry {
138                    id: id.to_string(),
139                    name: state.name.clone().unwrap_or_default(),
140                    pty_path: meta.pty_path.clone(),
141                    shell_pid: meta.shell_pid,
142                    created_at: meta.created_at,
143                    attached: meta.attached.load(Ordering::Relaxed),
144                    last_heartbeat: meta.last_heartbeat.load(Ordering::Relaxed),
145                    foreground_cmd: foreground_process(meta.shell_pid),
146                }
147            } else {
148                SessionEntry {
149                    id: id.to_string(),
150                    name: state.name.clone().unwrap_or_default(),
151                    pty_path: String::new(),
152                    shell_pid: 0,
153                    created_at: 0,
154                    attached: false,
155                    last_heartbeat: 0,
156                    foreground_cmd: "-".to_string(),
157                }
158            }
159        })
160        .collect();
161    entries.sort_by_key(|e| e.id.parse::<u32>().unwrap_or(u32::MAX));
162    entries
163}
164
165fn shutdown(sessions: &mut HashMap<u32, SessionState>, ctl_path: &Path) {
166    for (id, state) in sessions.drain() {
167        state.handle.abort();
168        info!(id, "session aborted");
169    }
170    let _ = std::fs::remove_file(ctl_path);
171    let _ = std::fs::remove_file(pid_file_path(ctl_path));
172}
173
174/// Perform Hello/HelloAck handshake and read control frame for a single connection.
175/// Spawned as a per-connection task so slow clients don't block the accept loop.
176async fn connection_handshake(
177    stream: UnixStream,
178    tx: mpsc::Sender<(Frame, Framed<UnixStream, FrameCodec>)>,
179) {
180    let mut framed = Framed::new(stream, FrameCodec);
181
182    // Read Hello handshake (5s timeout)
183    let hello = match tokio::time::timeout(Duration::from_secs(5), framed.next()).await {
184        Ok(Some(Ok(Frame::Hello { version }))) => version,
185        Ok(Some(Ok(_))) => {
186            let _ = timed_send(
187                &mut framed,
188                Frame::Error { message: "expected Hello handshake".to_string() },
189            )
190            .await;
191            return;
192        }
193        Ok(Some(Err(e))) => {
194            warn!("frame decode error: {e}");
195            return;
196        }
197        Ok(None) => return,
198        Err(_) => {
199            warn!("control connection timed out (hello)");
200            return;
201        }
202    };
203
204    // Reject version mismatch
205    if hello != PROTOCOL_VERSION {
206        let _ = timed_send(
207            &mut framed,
208            Frame::Error {
209                message: format!(
210                    "protocol version mismatch: client={hello} server={PROTOCOL_VERSION}; \
211                     both sides must run the same version"
212                ),
213            },
214        )
215        .await;
216        return;
217    }
218    if timed_send(&mut framed, Frame::HelloAck { version: PROTOCOL_VERSION }).await.is_err() {
219        return;
220    }
221
222    // Read control frame (5s timeout)
223    let frame = match tokio::time::timeout(Duration::from_secs(5), framed.next()).await {
224        Ok(Some(Ok(f))) => f,
225        Ok(Some(Err(e))) => {
226            warn!("frame decode error: {e}");
227            return;
228        }
229        Ok(None) => return,
230        Err(_) => {
231            warn!("control connection timed out");
232            return;
233        }
234    };
235
236    let _ = tx.send((frame, framed)).await;
237}
238
239/// Run the daemon, listening on its socket.
240///
241/// If `ready_fd` is provided, a single byte is written to it after the socket
242/// is bound, then the fd is dropped. This unblocks the parent process after
243/// `daemonize()` forks.
244pub async fn run(ctl_path: &Path, ready_fd: Option<OwnedFd>) -> anyhow::Result<()> {
245    // Restrictive umask for all files/sockets created by the daemon
246    unsafe {
247        libc::umask(0o077);
248    }
249
250    // Ensure parent directory exists with secure permissions
251    if let Some(parent) = ctl_path.parent() {
252        crate::security::secure_create_dir_all(parent)?;
253    }
254
255    let listener = crate::security::bind_unix_listener(ctl_path)?;
256    info!(path = %ctl_path.display(), "daemon listening");
257
258    // Signal readiness to parent (daemonize pipe): [0x01][pid: u32 LE]
259    if let Some(fd) = ready_fd {
260        use std::io::Write;
261        let mut f = std::fs::File::from(fd);
262        let pid = std::process::id();
263        let mut buf = [0u8; 5];
264        buf[0] = 0x01;
265        buf[1..5].copy_from_slice(&pid.to_le_bytes());
266        let _ = f.write_all(&buf);
267        // f drops here, closing the pipe
268    }
269
270    // Write PID file
271    let pid_path = pid_file_path(ctl_path);
272    std::fs::write(&pid_path, std::process::id().to_string())?;
273
274    let mut sessions: HashMap<u32, SessionState> = HashMap::new();
275    let mut next_id: u32 = 0;
276    let mut last_attached: Option<u32> = None;
277    let session_config = crate::config::ConfigFile::load().resolve_session(None);
278    let ring_buffer_cap = session_config.ring_buffer_size as usize;
279    let oauth_tunnel_idle_timeout = session_config.oauth_tunnel_idle_timeout;
280
281    // Signal handlers
282    let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
283    let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?;
284
285    // Channel for handshake results -- spawned tasks send completed handshakes here
286    let (conn_tx, mut conn_rx) = mpsc::channel::<(Frame, Framed<UnixStream, FrameCodec>)>(64);
287
288    loop {
289        reap_sessions(&mut sessions);
290
291        let should_break = tokio::select! {
292            result = listener.accept() => {
293                let (stream, _addr) = result?;
294                if let Err(e) = crate::security::verify_peer_uid(&stream) {
295                    warn!("{e}");
296                } else {
297                    let tx = conn_tx.clone();
298                    tokio::spawn(connection_handshake(stream, tx));
299                }
300                false
301            }
302            Some((frame, framed)) = conn_rx.recv() => {
303                dispatch_control(
304                    frame, framed, &mut sessions, &mut next_id, ctl_path, &mut last_attached,
305                    ring_buffer_cap, oauth_tunnel_idle_timeout,
306                ).await
307            }
308            _ = sigterm.recv() => {
309                info!("SIGTERM received, shutting down");
310                shutdown(&mut sessions, ctl_path);
311                true
312            }
313            _ = sigint.recv() => {
314                info!("SIGINT received, shutting down");
315                shutdown(&mut sessions, ctl_path);
316                true
317            }
318        };
319
320        if should_break {
321            break;
322        }
323    }
324
325    Ok(())
326}
327
328/// Dispatch a single control frame. Takes ownership of the framed connection
329/// so it can be handed off to session tasks when needed. Returns `true` for
330/// KillServer (daemon should exit).
331#[allow(clippy::too_many_arguments)]
332async fn dispatch_control(
333    frame: Frame,
334    mut framed: Framed<UnixStream, FrameCodec>,
335    sessions: &mut HashMap<u32, SessionState>,
336    next_id: &mut u32,
337    ctl_path: &Path,
338    last_attached: &mut Option<u32>,
339    ring_buffer_cap: usize,
340    oauth_tunnel_idle_timeout: u64,
341) -> bool {
342    match frame {
343        Frame::NewSession { name, command } => {
344            // Reject names containing control characters
345            let name_opt = if name.is_empty() { None } else { Some(name) };
346            let command_opt = if command.is_empty() { None } else { Some(command) };
347            if let Some(ref n) = name_opt {
348                if n.bytes().any(|b| b.is_ascii_control()) {
349                    let _ = timed_send(
350                        &mut framed,
351                        Frame::Error {
352                            message: "session name must not contain control characters".to_string(),
353                        },
354                    )
355                    .await;
356                    return false;
357                }
358                if n.parse::<u32>().is_ok() {
359                    let _ = timed_send(
360                        &mut framed,
361                        Frame::Error {
362                            message: "session name must not be purely numeric (ambiguous with session IDs)".to_string(),
363                        },
364                    )
365                    .await;
366                    return false;
367                }
368                let dup = sessions.values().any(|s| s.name.as_deref() == Some(n));
369                if dup {
370                    let _ = timed_send(
371                        &mut framed,
372                        Frame::Error { message: format!("session name already exists: {n}") },
373                    )
374                    .await;
375                    return false;
376                }
377            }
378
379            let id = *next_id;
380            *next_id += 1;
381
382            let (client_tx, client_rx) = mpsc::unbounded_channel();
383            let metadata = Arc::new(OnceLock::new());
384            let meta_clone = Arc::clone(&metadata);
385            let sock_dir = ctl_path.parent().expect("ctl_path must have a parent");
386            let agent_socket_path = sock_dir.join(format!("agent-{id}.sock"));
387            let svc_socket_path = sock_dir.join(format!("svc-{id}.sock"));
388            let name_for_server = name_opt.clone();
389            let cmd_for_server = command_opt;
390            let handle = tokio::spawn(async move {
391                server::run(
392                    client_rx,
393                    meta_clone,
394                    agent_socket_path,
395                    svc_socket_path,
396                    id,
397                    name_for_server,
398                    cmd_for_server,
399                    ring_buffer_cap,
400                    oauth_tunnel_idle_timeout,
401                )
402                .await
403            });
404
405            sessions.insert(
406                id,
407                SessionState {
408                    handle,
409                    metadata,
410                    client_tx: client_tx.clone(),
411                    name: name_opt.clone(),
412                },
413            );
414
415            info!(id, name = ?name_opt, "session created");
416
417            if timed_send(&mut framed, Frame::SessionCreated { id: id.to_string() }).await.is_err()
418            {
419                return false;
420            }
421
422            // Hand off connection to session for auto-attach
423            *last_attached = Some(id);
424            let _ = client_tx.send(ClientConn::Active(framed));
425            false
426        }
427        Frame::Attach { session } => {
428            reap_sessions(sessions);
429            if let Some(id) = resolve_session(sessions, &session, *last_attached) {
430                let state = &sessions[&id];
431                if state.client_tx.is_closed() {
432                    sessions.remove(&id);
433                    let _ = timed_send(
434                        &mut framed,
435                        Frame::Error { message: format!("no such session: {session}") },
436                    )
437                    .await;
438                } else if timed_send(&mut framed, Frame::Ok).await.is_ok() {
439                    *last_attached = Some(id);
440                    let _ = state.client_tx.send(ClientConn::Active(framed));
441                }
442            } else {
443                let _ = timed_send(
444                    &mut framed,
445                    Frame::Error { message: format!("no such session: {session}") },
446                )
447                .await;
448            }
449            false
450        }
451        Frame::Tail { session } => {
452            reap_sessions(sessions);
453            if let Some(id) = resolve_session(sessions, &session, *last_attached) {
454                let state = &sessions[&id];
455                if state.client_tx.is_closed() {
456                    sessions.remove(&id);
457                    let _ = timed_send(
458                        &mut framed,
459                        Frame::Error { message: format!("no such session: {session}") },
460                    )
461                    .await;
462                } else if timed_send(&mut framed, Frame::Ok).await.is_ok() {
463                    let _ = state.client_tx.send(ClientConn::Tail(framed));
464                }
465            } else {
466                let _ = timed_send(
467                    &mut framed,
468                    Frame::Error { message: format!("no such session: {session}") },
469                )
470                .await;
471            }
472            false
473        }
474        Frame::ListSessions => {
475            reap_sessions(sessions);
476            let entries = build_session_entries(sessions);
477            let _ = timed_send(&mut framed, Frame::SessionInfo { sessions: entries }).await;
478            false
479        }
480        Frame::KillSession { session } => {
481            reap_sessions(sessions);
482            if let Some(id) = resolve_session(sessions, &session, *last_attached) {
483                let state = sessions.remove(&id).unwrap();
484                state.handle.abort();
485                info!(id, "session killed");
486                let _ = timed_send(&mut framed, Frame::Ok).await;
487            } else {
488                let _ = timed_send(
489                    &mut framed,
490                    Frame::Error { message: format!("no such session: {session}") },
491                )
492                .await;
493            }
494            false
495        }
496        Frame::SendFile { session, .. } => {
497            reap_sessions(sessions);
498            if let Some(id) = resolve_session(sessions, &session, *last_attached) {
499                let state = &sessions[&id];
500                if state.client_tx.is_closed() {
501                    sessions.remove(&id);
502                    let _ = timed_send(
503                        &mut framed,
504                        Frame::Error { message: format!("no such session: {session}") },
505                    )
506                    .await;
507                } else if timed_send(&mut framed, Frame::Ok).await.is_ok() {
508                    let stream = framed.into_inner();
509                    let _ = state.client_tx.send(ClientConn::Send(stream));
510                }
511            } else {
512                let _ = timed_send(
513                    &mut framed,
514                    Frame::Error { message: format!("no such session: {session}") },
515                )
516                .await;
517            }
518            false
519        }
520        Frame::RenameSession { session, new_name } => {
521            reap_sessions(sessions);
522            if let Some(id) = resolve_session(sessions, &session, *last_attached) {
523                if new_name.is_empty() {
524                    let _ = timed_send(
525                        &mut framed,
526                        Frame::Error { message: "new name must not be empty".to_string() },
527                    )
528                    .await;
529                } else if new_name.bytes().any(|b| b.is_ascii_control()) {
530                    let _ = timed_send(
531                        &mut framed,
532                        Frame::Error {
533                            message: "session name must not contain control characters".to_string(),
534                        },
535                    )
536                    .await;
537                } else if new_name.parse::<u32>().is_ok() {
538                    let _ = timed_send(
539                        &mut framed,
540                        Frame::Error {
541                            message: "session name must not be purely numeric (ambiguous with session IDs)".to_string(),
542                        },
543                    )
544                    .await;
545                } else if sessions.values().any(|s| s.name.as_deref() == Some(&new_name)) {
546                    let _ = timed_send(
547                        &mut framed,
548                        Frame::Error {
549                            message: format!("session name already exists: {new_name}"),
550                        },
551                    )
552                    .await;
553                } else {
554                    sessions.get_mut(&id).unwrap().name = Some(new_name.clone());
555                    info!(id, new_name, "session renamed");
556                    let _ = timed_send(&mut framed, Frame::Ok).await;
557                }
558            } else {
559                let _ = timed_send(
560                    &mut framed,
561                    Frame::Error { message: format!("no such session: {session}") },
562                )
563                .await;
564            }
565            false
566        }
567        Frame::KillServer => {
568            info!("kill-server received, shutting down");
569            shutdown(sessions, ctl_path);
570            let _ = timed_send(&mut framed, Frame::Ok).await;
571            true
572        }
573        other => {
574            error!(?other, "unexpected frame on control socket");
575            let _ = timed_send(
576                &mut framed,
577                Frame::Error { message: "unexpected frame type".to_string() },
578            )
579            .await;
580            false
581        }
582    }
583}