Skip to main content

gritty/
daemon.rs

1use crate::protocol::{Frame, FrameCodec, PROTOCOL_VERSION, SessionEntry};
2use crate::server::{self, ClientConn, SessionMetadata};
3use futures_util::{SinkExt, StreamExt};
4use std::collections::HashMap;
5use std::os::fd::OwnedFd;
6use std::path::{Path, PathBuf};
7use std::sync::atomic::Ordering;
8use std::sync::{Arc, OnceLock};
9use std::time::Duration;
10use tokio::net::UnixStream;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13use tokio_util::codec::Framed;
14use tracing::{error, info, warn};
15
16const SEND_TIMEOUT: Duration = Duration::from_secs(5);
17
18/// Send a frame with a timeout. Returns `Ok(())` on success, `Err` on
19/// send failure or timeout (error is logged before returning).
20async fn timed_send(
21    framed: &mut Framed<UnixStream, FrameCodec>,
22    frame: Frame,
23) -> Result<(), std::io::Error> {
24    match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
25        Ok(Ok(())) => Ok(()),
26        Ok(Err(e)) => {
27            warn!("control send error: {e}");
28            Err(e)
29        }
30        Err(_) => {
31            warn!("control send timed out");
32            Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "send timed out"))
33        }
34    }
35}
36
37struct SessionState {
38    handle: JoinHandle<anyhow::Result<()>>,
39    metadata: Arc<OnceLock<SessionMetadata>>,
40    client_tx: mpsc::UnboundedSender<ClientConn>,
41    name: Option<String>,
42}
43
44/// Returns the base directory for gritty sockets.
45/// Prefers $XDG_RUNTIME_DIR/gritty, falls back to /tmp/gritty-$UID.
46pub fn socket_dir() -> PathBuf {
47    if let Ok(xdg) = std::env::var("XDG_RUNTIME_DIR") {
48        PathBuf::from(xdg).join("gritty")
49    } else {
50        let uid = unsafe { libc::getuid() };
51        PathBuf::from(format!("/tmp/gritty-{uid}"))
52    }
53}
54
55/// Returns the daemon socket path.
56pub fn control_socket_path() -> PathBuf {
57    socket_dir().join("ctl.sock")
58}
59
/// Returns the PID file path: `daemon.pid` in the same directory as the
/// control socket at `ctl_path`.
pub fn pid_file_path(ctl_path: &Path) -> PathBuf {
    let mut pid_path = ctl_path.to_path_buf();
    // Swaps the final component (the socket's file name) for the PID file name.
    pid_path.set_file_name("daemon.pid");
    pid_path
}
64
65fn reap_sessions(sessions: &mut HashMap<u32, SessionState>) {
66    sessions.retain(|id, state| {
67        if state.handle.is_finished() {
68            info!(id, "session ended");
69            false
70        } else {
71            true
72        }
73    });
74}
75
76/// Resolve a session identifier (name or id string) to a session id.
77fn resolve_session(sessions: &HashMap<u32, SessionState>, target: &str) -> Option<u32> {
78    // Try name match first
79    for (&id, state) in sessions {
80        if state.name.as_deref() == Some(target) {
81            return Some(id);
82        }
83    }
84    // Then try parsing as numeric id
85    if let Ok(id) = target.parse::<u32>()
86        && sessions.contains_key(&id)
87    {
88        return Some(id);
89    }
90    None
91}
92
93fn build_session_entries(sessions: &HashMap<u32, SessionState>) -> Vec<SessionEntry> {
94    let mut entries: Vec<_> = sessions
95        .iter()
96        .map(|(&id, state)| {
97            if let Some(meta) = state.metadata.get() {
98                SessionEntry {
99                    id: id.to_string(),
100                    name: state.name.clone().unwrap_or_default(),
101                    pty_path: meta.pty_path.clone(),
102                    shell_pid: meta.shell_pid,
103                    created_at: meta.created_at,
104                    attached: meta.attached.load(Ordering::Relaxed),
105                    last_heartbeat: meta.last_heartbeat.load(Ordering::Relaxed),
106                }
107            } else {
108                SessionEntry {
109                    id: id.to_string(),
110                    name: state.name.clone().unwrap_or_default(),
111                    pty_path: String::new(),
112                    shell_pid: 0,
113                    created_at: 0,
114                    attached: false,
115                    last_heartbeat: 0,
116                }
117            }
118        })
119        .collect();
120    entries.sort_by_key(|e| e.id.parse::<u32>().unwrap_or(u32::MAX));
121    entries
122}
123
124fn shutdown(sessions: &mut HashMap<u32, SessionState>, ctl_path: &Path) {
125    for (id, state) in sessions.drain() {
126        state.handle.abort();
127        info!(id, "session aborted");
128    }
129    let _ = std::fs::remove_file(ctl_path);
130    let _ = std::fs::remove_file(pid_file_path(ctl_path));
131}
132
133/// Run the daemon, listening on its socket.
134///
135/// If `ready_fd` is provided, a single byte is written to it after the socket
136/// is bound, then the fd is dropped. This unblocks the parent process after
137/// `daemonize()` forks.
138pub async fn run(ctl_path: &Path, ready_fd: Option<OwnedFd>) -> anyhow::Result<()> {
139    // Restrictive umask for all files/sockets created by the daemon
140    unsafe {
141        libc::umask(0o077);
142    }
143
144    // Ensure parent directory exists with secure permissions
145    if let Some(parent) = ctl_path.parent() {
146        crate::security::secure_create_dir_all(parent)?;
147    }
148
149    let listener = crate::security::bind_unix_listener(ctl_path)?;
150    info!(path = %ctl_path.display(), "daemon listening");
151
152    // Signal readiness to parent (daemonize pipe)
153    if let Some(fd) = ready_fd {
154        use std::io::Write;
155        let mut f = std::fs::File::from(fd);
156        let _ = f.write_all(&[1]);
157        // f drops here, closing the pipe
158    }
159
160    // Write PID file
161    let pid_path = pid_file_path(ctl_path);
162    std::fs::write(&pid_path, std::process::id().to_string())?;
163
164    let mut sessions: HashMap<u32, SessionState> = HashMap::new();
165    let mut next_id: u32 = 0;
166
167    // Signal handlers
168    let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
169    let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?;
170
171    loop {
172        reap_sessions(&mut sessions);
173
174        let stream = tokio::select! {
175            result = listener.accept() => {
176                let (stream, _addr) = result?;
177                stream
178            }
179            _ = sigterm.recv() => {
180                info!("SIGTERM received, shutting down");
181                shutdown(&mut sessions, ctl_path);
182                break;
183            }
184            _ = sigint.recv() => {
185                info!("SIGINT received, shutting down");
186                shutdown(&mut sessions, ctl_path);
187                break;
188            }
189        };
190
191        if let Err(e) = crate::security::verify_peer_uid(&stream) {
192            warn!("{e}");
193            continue;
194        }
195        let mut framed = Framed::new(stream, FrameCodec);
196
197        // Read Hello handshake (5s timeout)
198        let hello = match tokio::time::timeout(Duration::from_secs(5), framed.next()).await {
199            Ok(Some(Ok(Frame::Hello { version }))) => version,
200            Ok(Some(Ok(_))) => {
201                let _ = timed_send(
202                    &mut framed,
203                    Frame::Error { message: "expected Hello handshake".to_string() },
204                )
205                .await;
206                continue;
207            }
208            Ok(Some(Err(e))) => {
209                warn!("frame decode error: {e}");
210                continue;
211            }
212            Ok(None) => continue,
213            Err(_) => {
214                warn!("control connection timed out (hello)");
215                continue;
216            }
217        };
218
219        // Send HelloAck with negotiated version
220        let negotiated = hello.min(PROTOCOL_VERSION);
221        if timed_send(&mut framed, Frame::HelloAck { version: negotiated }).await.is_err() {
222            continue;
223        }
224
225        // Read control frame (5s timeout)
226        let frame = match tokio::time::timeout(Duration::from_secs(5), framed.next()).await {
227            Ok(Some(Ok(f))) => f,
228            Ok(Some(Err(e))) => {
229                warn!("frame decode error: {e}");
230                continue;
231            }
232            Ok(None) => continue,
233            Err(_) => {
234                warn!("control connection timed out");
235                continue;
236            }
237        };
238
239        match frame {
240            Frame::NewSession { name } => {
241                // Reject names containing control characters (prevents wire
242                // format corruption in tab/newline-delimited SessionInfo)
243                let name_opt = if name.is_empty() { None } else { Some(name) };
244                if let Some(ref n) = name_opt {
245                    if n.bytes().any(|b| b.is_ascii_control()) {
246                        let _ = timed_send(
247                            &mut framed,
248                            Frame::Error {
249                                message: "session name must not contain control characters"
250                                    .to_string(),
251                            },
252                        )
253                        .await;
254                        continue;
255                    }
256                    let dup = sessions.values().any(|s| s.name.as_deref() == Some(n));
257                    if dup {
258                        let _ = timed_send(
259                            &mut framed,
260                            Frame::Error { message: format!("session name already exists: {n}") },
261                        )
262                        .await;
263                        continue;
264                    }
265                }
266
267                let id = next_id;
268                next_id += 1;
269
270                let (client_tx, client_rx) = mpsc::unbounded_channel();
271                let metadata = Arc::new(OnceLock::new());
272                let meta_clone = Arc::clone(&metadata);
273                let sock_dir = ctl_path.parent().expect("ctl_path must have a parent");
274                let agent_socket_path = sock_dir.join(format!("agent-{id}.sock"));
275                let open_socket_path = sock_dir.join(format!("open-{id}.sock"));
276                let handle = tokio::spawn(async move {
277                    server::run(client_rx, meta_clone, agent_socket_path, open_socket_path).await
278                });
279
280                sessions.insert(
281                    id,
282                    SessionState {
283                        handle,
284                        metadata,
285                        client_tx: client_tx.clone(),
286                        name: name_opt.clone(),
287                    },
288                );
289
290                info!(id, name = ?name_opt, "session created");
291
292                if timed_send(&mut framed, Frame::SessionCreated { id: id.to_string() })
293                    .await
294                    .is_err()
295                {
296                    continue;
297                }
298
299                // Hand off connection to session for auto-attach
300                let _ = client_tx.send(ClientConn::Active(framed));
301            }
302            Frame::Attach { session } => {
303                reap_sessions(&mut sessions);
304                if let Some(id) = resolve_session(&sessions, &session) {
305                    let state = &sessions[&id];
306                    if state.client_tx.is_closed() {
307                        sessions.remove(&id);
308                        let _ = timed_send(
309                            &mut framed,
310                            Frame::Error { message: format!("no such session: {session}") },
311                        )
312                        .await;
313                    } else if timed_send(&mut framed, Frame::Ok).await.is_ok() {
314                        let _ = state.client_tx.send(ClientConn::Active(framed));
315                    }
316                } else {
317                    let _ = timed_send(
318                        &mut framed,
319                        Frame::Error { message: format!("no such session: {session}") },
320                    )
321                    .await;
322                }
323            }
324            Frame::Tail { session } => {
325                reap_sessions(&mut sessions);
326                if let Some(id) = resolve_session(&sessions, &session) {
327                    let state = &sessions[&id];
328                    if state.client_tx.is_closed() {
329                        sessions.remove(&id);
330                        let _ = timed_send(
331                            &mut framed,
332                            Frame::Error { message: format!("no such session: {session}") },
333                        )
334                        .await;
335                    } else if timed_send(&mut framed, Frame::Ok).await.is_ok() {
336                        let _ = state.client_tx.send(ClientConn::Tail(framed));
337                    }
338                } else {
339                    let _ = timed_send(
340                        &mut framed,
341                        Frame::Error { message: format!("no such session: {session}") },
342                    )
343                    .await;
344                }
345            }
346            Frame::ListSessions => {
347                reap_sessions(&mut sessions);
348                let entries = build_session_entries(&sessions);
349                let _ = timed_send(&mut framed, Frame::SessionInfo { sessions: entries }).await;
350            }
351            Frame::KillSession { session } => {
352                reap_sessions(&mut sessions);
353                if let Some(id) = resolve_session(&sessions, &session) {
354                    let state = sessions.remove(&id).unwrap();
355                    state.handle.abort();
356                    info!(id, "session killed");
357                    let _ = timed_send(&mut framed, Frame::Ok).await;
358                } else {
359                    let _ = timed_send(
360                        &mut framed,
361                        Frame::Error { message: format!("no such session: {session}") },
362                    )
363                    .await;
364                }
365            }
366            Frame::KillServer => {
367                info!("kill-server received, shutting down");
368                shutdown(&mut sessions, ctl_path);
369                let _ = timed_send(&mut framed, Frame::Ok).await;
370                break;
371            }
372            other => {
373                error!(?other, "unexpected frame on control socket");
374                let _ = timed_send(
375                    &mut framed,
376                    Frame::Error { message: "unexpected frame type".to_string() },
377                )
378                .await;
379            }
380        }
381    }
382
383    Ok(())
384}