Skip to main content

gritty/
daemon.rs

1use crate::protocol::{Frame, FrameCodec, SessionEntry};
2use crate::server::{self, ClientConn, SessionMetadata};
3use futures_util::{SinkExt, StreamExt};
4use std::collections::HashMap;
5use std::os::fd::OwnedFd;
6use std::path::{Path, PathBuf};
7use std::sync::atomic::Ordering;
8use std::sync::{Arc, OnceLock};
9use tokio::sync::mpsc;
10use tokio::task::JoinHandle;
11use tokio_util::codec::Framed;
12use tracing::{error, info, warn};
13
14struct SessionState {
15    handle: JoinHandle<anyhow::Result<()>>,
16    metadata: Arc<OnceLock<SessionMetadata>>,
17    client_tx: mpsc::UnboundedSender<ClientConn>,
18    name: Option<String>,
19}
20
21/// Returns the base directory for gritty sockets.
22/// Prefers $XDG_RUNTIME_DIR/gritty, falls back to /tmp/gritty-$UID.
23pub fn socket_dir() -> PathBuf {
24    if let Ok(xdg) = std::env::var("XDG_RUNTIME_DIR") {
25        PathBuf::from(xdg).join("gritty")
26    } else {
27        let uid = unsafe { libc::getuid() };
28        PathBuf::from(format!("/tmp/gritty-{uid}"))
29    }
30}
31
32/// Returns the daemon socket path.
33pub fn control_socket_path() -> PathBuf {
34    socket_dir().join("ctl.sock")
35}
36
/// Returns the PID file path: `daemon.pid` in the same directory as `ctl_path`.
///
/// Only the final path component is swapped; no filesystem access happens.
pub fn pid_file_path(ctl_path: &Path) -> PathBuf {
    let mut pid_path = ctl_path.to_path_buf();
    pid_path.set_file_name("daemon.pid");
    pid_path
}
41
42fn reap_sessions(sessions: &mut HashMap<u32, SessionState>) {
43    sessions.retain(|id, state| {
44        if state.handle.is_finished() {
45            info!(id, "session ended");
46            false
47        } else {
48            true
49        }
50    });
51}
52
53/// Resolve a session identifier (name or id string) to a session id.
54fn resolve_session(sessions: &HashMap<u32, SessionState>, target: &str) -> Option<u32> {
55    // Try name match first
56    for (&id, state) in sessions {
57        if state.name.as_deref() == Some(target) {
58            return Some(id);
59        }
60    }
61    // Then try parsing as numeric id
62    if let Ok(id) = target.parse::<u32>()
63        && sessions.contains_key(&id)
64    {
65        return Some(id);
66    }
67    None
68}
69
70fn build_session_entries(sessions: &HashMap<u32, SessionState>) -> Vec<SessionEntry> {
71    let mut entries: Vec<_> = sessions
72        .iter()
73        .map(|(&id, state)| {
74            if let Some(meta) = state.metadata.get() {
75                SessionEntry {
76                    id: id.to_string(),
77                    name: state.name.clone().unwrap_or_default(),
78                    pty_path: meta.pty_path.clone(),
79                    shell_pid: meta.shell_pid,
80                    created_at: meta.created_at,
81                    attached: meta.attached.load(Ordering::Relaxed),
82                    last_heartbeat: meta.last_heartbeat.load(Ordering::Relaxed),
83                }
84            } else {
85                SessionEntry {
86                    id: id.to_string(),
87                    name: state.name.clone().unwrap_or_default(),
88                    pty_path: String::new(),
89                    shell_pid: 0,
90                    created_at: 0,
91                    attached: false,
92                    last_heartbeat: 0,
93                }
94            }
95        })
96        .collect();
97    entries.sort_by_key(|e| e.id.parse::<u32>().unwrap_or(u32::MAX));
98    entries
99}
100
101fn shutdown(sessions: &mut HashMap<u32, SessionState>, ctl_path: &Path) {
102    for (id, state) in sessions.drain() {
103        state.handle.abort();
104        info!(id, "session aborted");
105    }
106    let _ = std::fs::remove_file(ctl_path);
107    let _ = std::fs::remove_file(pid_file_path(ctl_path));
108}
109
110/// Run the daemon, listening on its socket.
111///
112/// If `ready_fd` is provided, a single byte is written to it after the socket
113/// is bound, then the fd is dropped. This unblocks the parent process after
114/// `daemonize()` forks.
115pub async fn run(ctl_path: &Path, ready_fd: Option<OwnedFd>) -> anyhow::Result<()> {
116    // Restrictive umask for all files/sockets created by the daemon
117    unsafe {
118        libc::umask(0o077);
119    }
120
121    // Ensure parent directory exists with secure permissions
122    if let Some(parent) = ctl_path.parent() {
123        crate::security::secure_create_dir_all(parent)?;
124    }
125
126    let listener = crate::security::bind_unix_listener(ctl_path)?;
127    info!(path = %ctl_path.display(), "daemon listening");
128
129    // Signal readiness to parent (daemonize pipe)
130    if let Some(fd) = ready_fd {
131        use std::io::Write;
132        let mut f = std::fs::File::from(fd);
133        let _ = f.write_all(&[1]);
134        // f drops here, closing the pipe
135    }
136
137    // Write PID file
138    let pid_path = pid_file_path(ctl_path);
139    std::fs::write(&pid_path, std::process::id().to_string())?;
140
141    let mut sessions: HashMap<u32, SessionState> = HashMap::new();
142    let mut next_id: u32 = 0;
143
144    // Signal handlers
145    let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
146    let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?;
147
148    loop {
149        reap_sessions(&mut sessions);
150
151        let stream = tokio::select! {
152            result = listener.accept() => {
153                let (stream, _addr) = result?;
154                stream
155            }
156            _ = sigterm.recv() => {
157                info!("SIGTERM received, shutting down");
158                shutdown(&mut sessions, ctl_path);
159                break;
160            }
161            _ = sigint.recv() => {
162                info!("SIGINT received, shutting down");
163                shutdown(&mut sessions, ctl_path);
164                break;
165            }
166        };
167
168        if let Err(e) = crate::security::verify_peer_uid(&stream) {
169            warn!("{e}");
170            continue;
171        }
172        let mut framed = Framed::new(stream, FrameCodec);
173
174        // Handle one control request per connection
175        let Some(Ok(frame)) = framed.next().await else {
176            continue;
177        };
178
179        match frame {
180            Frame::NewSession { name } => {
181                // Reject names containing control characters (prevents wire
182                // format corruption in tab/newline-delimited SessionInfo)
183                let name_opt = if name.is_empty() { None } else { Some(name) };
184                if let Some(ref n) = name_opt {
185                    if n.bytes().any(|b| b.is_ascii_control()) {
186                        let _ = framed
187                            .send(Frame::Error {
188                                message: "session name must not contain control characters"
189                                    .to_string(),
190                            })
191                            .await;
192                        continue;
193                    }
194                    let dup = sessions.values().any(|s| s.name.as_deref() == Some(n));
195                    if dup {
196                        let _ = framed
197                            .send(Frame::Error {
198                                message: format!("session name already exists: {n}"),
199                            })
200                            .await;
201                        continue;
202                    }
203                }
204
205                let id = next_id;
206                next_id += 1;
207
208                let (client_tx, client_rx) = mpsc::unbounded_channel();
209                let metadata = Arc::new(OnceLock::new());
210                let meta_clone = Arc::clone(&metadata);
211                let sock_dir = ctl_path.parent().expect("ctl_path must have a parent");
212                let agent_socket_path = sock_dir.join(format!("agent-{id}.sock"));
213                let open_socket_path = sock_dir.join(format!("open-{id}.sock"));
214                let handle = tokio::spawn(async move {
215                    server::run(client_rx, meta_clone, agent_socket_path, open_socket_path).await
216                });
217
218                sessions.insert(
219                    id,
220                    SessionState {
221                        handle,
222                        metadata,
223                        client_tx: client_tx.clone(),
224                        name: name_opt.clone(),
225                    },
226                );
227
228                info!(id, name = ?name_opt, "session created");
229
230                let _ = framed.send(Frame::SessionCreated { id: id.to_string() }).await;
231
232                // Hand off connection to session for auto-attach
233                let _ = client_tx.send(ClientConn::Active(framed));
234            }
235            Frame::Attach { session } => {
236                reap_sessions(&mut sessions);
237                if let Some(id) = resolve_session(&sessions, &session) {
238                    let state = &sessions[&id];
239                    if state.client_tx.is_closed() {
240                        sessions.remove(&id);
241                        let _ = framed
242                            .send(Frame::Error { message: format!("no such session: {session}") })
243                            .await;
244                    } else {
245                        let _ = framed.send(Frame::Ok).await;
246                        let _ = state.client_tx.send(ClientConn::Active(framed));
247                    }
248                } else {
249                    let _ = framed
250                        .send(Frame::Error { message: format!("no such session: {session}") })
251                        .await;
252                }
253            }
254            Frame::Tail { session } => {
255                reap_sessions(&mut sessions);
256                if let Some(id) = resolve_session(&sessions, &session) {
257                    let state = &sessions[&id];
258                    if state.client_tx.is_closed() {
259                        sessions.remove(&id);
260                        let _ = framed
261                            .send(Frame::Error { message: format!("no such session: {session}") })
262                            .await;
263                    } else {
264                        let _ = framed.send(Frame::Ok).await;
265                        let _ = state.client_tx.send(ClientConn::Tail(framed));
266                    }
267                } else {
268                    let _ = framed
269                        .send(Frame::Error { message: format!("no such session: {session}") })
270                        .await;
271                }
272            }
273            Frame::ListSessions => {
274                reap_sessions(&mut sessions);
275                let entries = build_session_entries(&sessions);
276                let _ = framed.send(Frame::SessionInfo { sessions: entries }).await;
277            }
278            Frame::KillSession { session } => {
279                reap_sessions(&mut sessions);
280                if let Some(id) = resolve_session(&sessions, &session) {
281                    let state = sessions.remove(&id).unwrap();
282                    state.handle.abort();
283                    info!(id, "session killed");
284                    let _ = framed.send(Frame::Ok).await;
285                } else {
286                    let _ = framed
287                        .send(Frame::Error { message: format!("no such session: {session}") })
288                        .await;
289                }
290            }
291            Frame::KillServer => {
292                info!("kill-server received, shutting down");
293                shutdown(&mut sessions, ctl_path);
294                let _ = framed.send(Frame::Ok).await;
295                break;
296            }
297            other => {
298                error!(?other, "unexpected frame on control socket");
299                let _ = framed
300                    .send(Frame::Error { message: "unexpected frame type".to_string() })
301                    .await;
302            }
303        }
304    }
305
306    Ok(())
307}