Skip to main content

gritty/
daemon.rs

1use crate::protocol::{Frame, FrameCodec, SessionEntry};
2use crate::server::{self, SessionMetadata};
3use futures_util::{SinkExt, StreamExt};
4use std::collections::HashMap;
5use std::os::fd::OwnedFd;
6use std::path::{Path, PathBuf};
7use std::sync::atomic::Ordering;
8use std::sync::{Arc, OnceLock};
9use tokio::net::UnixStream;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12use tokio_util::codec::Framed;
13use tracing::{error, info, warn};
14
/// Daemon-side bookkeeping for a single session tracked in the session map.
struct SessionState {
    /// Handle of the spawned session task; polled with `is_finished()` when
    /// reaping and aborted when the session or the whole daemon is killed.
    handle: JoinHandle<anyhow::Result<()>>,
    /// Write-once slot for session metadata. Listings fall back to placeholder
    /// values while the slot is still empty.
    // NOTE(review): presumably populated by the session task itself — confirm in `server::run`.
    metadata: Arc<OnceLock<SessionMetadata>>,
    /// Sender used to hand accepted client connections over to the session task.
    client_tx: mpsc::UnboundedSender<Framed<UnixStream, FrameCodec>>,
    /// Optional session name; `None` when the session was created with an empty name.
    name: Option<String>,
}
21
22/// Returns the base directory for gritty sockets.
23/// Prefers $XDG_RUNTIME_DIR/gritty, falls back to /tmp/gritty-$UID.
24pub fn socket_dir() -> PathBuf {
25    if let Ok(xdg) = std::env::var("XDG_RUNTIME_DIR") {
26        PathBuf::from(xdg).join("gritty")
27    } else {
28        let uid = unsafe { libc::getuid() };
29        PathBuf::from(format!("/tmp/gritty-{uid}"))
30    }
31}
32
33/// Returns the daemon socket path.
34pub fn control_socket_path() -> PathBuf {
35    socket_dir().join("ctl.sock")
36}
37
/// Derives the PID file location from the control socket path.
///
/// The final component of `ctl_path` is swapped for `daemon.pid`, so the PID
/// file sits next to the control socket.
pub fn pid_file_path(ctl_path: &Path) -> PathBuf {
    let mut pid_path = ctl_path.to_path_buf();
    pid_path.set_file_name("daemon.pid");
    pid_path
}
42
43fn reap_sessions(sessions: &mut HashMap<u32, SessionState>) {
44    sessions.retain(|id, state| {
45        if state.handle.is_finished() {
46            info!(id, "session ended");
47            false
48        } else {
49            true
50        }
51    });
52}
53
54/// Resolve a session identifier (name or id string) to a session id.
55fn resolve_session(sessions: &HashMap<u32, SessionState>, target: &str) -> Option<u32> {
56    // Try name match first
57    for (&id, state) in sessions {
58        if state.name.as_deref() == Some(target) {
59            return Some(id);
60        }
61    }
62    // Then try parsing as numeric id
63    if let Ok(id) = target.parse::<u32>()
64        && sessions.contains_key(&id)
65    {
66        return Some(id);
67    }
68    None
69}
70
71fn build_session_entries(sessions: &HashMap<u32, SessionState>) -> Vec<SessionEntry> {
72    let mut entries: Vec<_> = sessions
73        .iter()
74        .map(|(&id, state)| {
75            if let Some(meta) = state.metadata.get() {
76                SessionEntry {
77                    id: id.to_string(),
78                    name: state.name.clone().unwrap_or_default(),
79                    pty_path: meta.pty_path.clone(),
80                    shell_pid: meta.shell_pid,
81                    created_at: meta.created_at,
82                    attached: meta.attached.load(Ordering::Relaxed),
83                    last_heartbeat: meta.last_heartbeat.load(Ordering::Relaxed),
84                }
85            } else {
86                SessionEntry {
87                    id: id.to_string(),
88                    name: state.name.clone().unwrap_or_default(),
89                    pty_path: String::new(),
90                    shell_pid: 0,
91                    created_at: 0,
92                    attached: false,
93                    last_heartbeat: 0,
94                }
95            }
96        })
97        .collect();
98    entries.sort_by_key(|e| e.id.parse::<u32>().unwrap_or(u32::MAX));
99    entries
100}
101
102fn shutdown(sessions: &mut HashMap<u32, SessionState>, ctl_path: &Path) {
103    for (id, state) in sessions.drain() {
104        state.handle.abort();
105        info!(id, "session aborted");
106    }
107    let _ = std::fs::remove_file(ctl_path);
108    let _ = std::fs::remove_file(pid_file_path(ctl_path));
109}
110
111/// Run the daemon, listening on its socket.
112///
113/// If `ready_fd` is provided, a single byte is written to it after the socket
114/// is bound, then the fd is dropped. This unblocks the parent process after
115/// `daemonize()` forks.
116pub async fn run(ctl_path: &Path, ready_fd: Option<OwnedFd>) -> anyhow::Result<()> {
117    // Restrictive umask for all files/sockets created by the daemon
118    unsafe {
119        libc::umask(0o077);
120    }
121
122    // Ensure parent directory exists with secure permissions
123    if let Some(parent) = ctl_path.parent() {
124        crate::security::secure_create_dir_all(parent)?;
125    }
126
127    let listener = crate::security::bind_unix_listener(ctl_path)?;
128    info!(path = %ctl_path.display(), "daemon listening");
129
130    // Signal readiness to parent (daemonize pipe)
131    if let Some(fd) = ready_fd {
132        use std::io::Write;
133        let mut f = std::fs::File::from(fd);
134        let _ = f.write_all(&[1]);
135        // f drops here, closing the pipe
136    }
137
138    // Write PID file
139    let pid_path = pid_file_path(ctl_path);
140    std::fs::write(&pid_path, std::process::id().to_string())?;
141
142    let mut sessions: HashMap<u32, SessionState> = HashMap::new();
143    let mut next_id: u32 = 0;
144
145    // Signal handlers
146    let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
147    let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?;
148
149    loop {
150        reap_sessions(&mut sessions);
151
152        let stream = tokio::select! {
153            result = listener.accept() => {
154                let (stream, _addr) = result?;
155                stream
156            }
157            _ = sigterm.recv() => {
158                info!("SIGTERM received, shutting down");
159                shutdown(&mut sessions, ctl_path);
160                break;
161            }
162            _ = sigint.recv() => {
163                info!("SIGINT received, shutting down");
164                shutdown(&mut sessions, ctl_path);
165                break;
166            }
167        };
168
169        if let Err(e) = crate::security::verify_peer_uid(&stream) {
170            warn!("{e}");
171            continue;
172        }
173        let mut framed = Framed::new(stream, FrameCodec);
174
175        // Handle one control request per connection
176        let Some(Ok(frame)) = framed.next().await else {
177            continue;
178        };
179
180        match frame {
181            Frame::NewSession { name } => {
182                // Check for duplicate name
183                let name_opt = if name.is_empty() { None } else { Some(name) };
184                if let Some(ref n) = name_opt {
185                    let dup = sessions.values().any(|s| s.name.as_deref() == Some(n));
186                    if dup {
187                        let _ = framed
188                            .send(Frame::Error {
189                                message: format!("session name already exists: {n}"),
190                            })
191                            .await;
192                        continue;
193                    }
194                }
195
196                let id = next_id;
197                next_id += 1;
198
199                let (client_tx, client_rx) = mpsc::unbounded_channel();
200                let metadata = Arc::new(OnceLock::new());
201                let meta_clone = Arc::clone(&metadata);
202                let sock_dir = ctl_path.parent().expect("ctl_path must have a parent");
203                let agent_socket_path = sock_dir.join(format!("agent-{id}.sock"));
204                let open_socket_path = sock_dir.join(format!("open-{id}.sock"));
205                let handle = tokio::spawn(async move {
206                    server::run(client_rx, meta_clone, agent_socket_path, open_socket_path).await
207                });
208
209                sessions.insert(
210                    id,
211                    SessionState {
212                        handle,
213                        metadata,
214                        client_tx: client_tx.clone(),
215                        name: name_opt.clone(),
216                    },
217                );
218
219                info!(id, name = ?name_opt, "session created");
220
221                let _ = framed.send(Frame::SessionCreated { id: id.to_string() }).await;
222
223                // Hand off connection to session for auto-attach
224                let _ = client_tx.send(framed);
225            }
226            Frame::Attach { session } => {
227                reap_sessions(&mut sessions);
228                if let Some(id) = resolve_session(&sessions, &session) {
229                    let state = &sessions[&id];
230                    if state.client_tx.is_closed() {
231                        sessions.remove(&id);
232                        let _ = framed
233                            .send(Frame::Error { message: format!("no such session: {session}") })
234                            .await;
235                    } else {
236                        let _ = framed.send(Frame::Ok).await;
237                        let _ = state.client_tx.send(framed);
238                    }
239                } else {
240                    let _ = framed
241                        .send(Frame::Error { message: format!("no such session: {session}") })
242                        .await;
243                }
244            }
245            Frame::ListSessions => {
246                reap_sessions(&mut sessions);
247                let entries = build_session_entries(&sessions);
248                let _ = framed.send(Frame::SessionInfo { sessions: entries }).await;
249            }
250            Frame::KillSession { session } => {
251                reap_sessions(&mut sessions);
252                if let Some(id) = resolve_session(&sessions, &session) {
253                    let state = sessions.remove(&id).unwrap();
254                    state.handle.abort();
255                    info!(id, "session killed");
256                    let _ = framed.send(Frame::Ok).await;
257                } else {
258                    let _ = framed
259                        .send(Frame::Error { message: format!("no such session: {session}") })
260                        .await;
261                }
262            }
263            Frame::KillServer => {
264                info!("kill-server received, shutting down");
265                shutdown(&mut sessions, ctl_path);
266                let _ = framed.send(Frame::Ok).await;
267                break;
268            }
269            other => {
270                error!(?other, "unexpected frame on control socket");
271                let _ = framed
272                    .send(Frame::Error { message: "unexpected frame type".to_string() })
273                    .await;
274            }
275        }
276    }
277
278    Ok(())
279}