1use crate::protocol::{Frame, FrameCodec, SessionEntry};
2use crate::server::{self, ClientConn, SessionMetadata};
3use futures_util::{SinkExt, StreamExt};
4use std::collections::HashMap;
5use std::os::fd::OwnedFd;
6use std::path::{Path, PathBuf};
7use std::sync::atomic::Ordering;
8use std::sync::{Arc, OnceLock};
9use tokio::sync::mpsc;
10use tokio::task::JoinHandle;
11use tokio_util::codec::Framed;
12use tracing::{error, info, warn};
13
/// Per-session bookkeeping kept by the daemon's control loop.
struct SessionState {
    /// Join handle of the spawned task that runs `server::run` for this session.
    handle: JoinHandle<anyhow::Result<()>>,
    /// Metadata cell shared with the session task; reads as empty until it has been set
    /// (presumably by the session task itself — the writer is not visible in this file).
    metadata: Arc<OnceLock<SessionMetadata>>,
    /// Channel used to hand accepted client connections over to the session task.
    client_tx: mpsc::UnboundedSender<ClientConn>,
    /// Optional session name; `None` when the creating client supplied an empty name.
    name: Option<String>,
}
20
21pub fn socket_dir() -> PathBuf {
24 if let Ok(xdg) = std::env::var("XDG_RUNTIME_DIR") {
25 PathBuf::from(xdg).join("gritty")
26 } else {
27 let uid = unsafe { libc::getuid() };
28 PathBuf::from(format!("/tmp/gritty-{uid}"))
29 }
30}
31
32pub fn control_socket_path() -> PathBuf {
34 socket_dir().join("ctl.sock")
35}
36
/// Returns the PID-file path that sits next to `ctl_path`: the same path with
/// its final component replaced by `daemon.pid`.
pub fn pid_file_path(ctl_path: &Path) -> PathBuf {
    let mut pid_path = ctl_path.to_path_buf();
    pid_path.set_file_name("daemon.pid");
    pid_path
}
41
42fn reap_sessions(sessions: &mut HashMap<u32, SessionState>) {
43 sessions.retain(|id, state| {
44 if state.handle.is_finished() {
45 info!(id, "session ended");
46 false
47 } else {
48 true
49 }
50 });
51}
52
53fn resolve_session(sessions: &HashMap<u32, SessionState>, target: &str) -> Option<u32> {
55 for (&id, state) in sessions {
57 if state.name.as_deref() == Some(target) {
58 return Some(id);
59 }
60 }
61 if let Ok(id) = target.parse::<u32>()
63 && sessions.contains_key(&id)
64 {
65 return Some(id);
66 }
67 None
68}
69
70fn build_session_entries(sessions: &HashMap<u32, SessionState>) -> Vec<SessionEntry> {
71 let mut entries: Vec<_> = sessions
72 .iter()
73 .map(|(&id, state)| {
74 if let Some(meta) = state.metadata.get() {
75 SessionEntry {
76 id: id.to_string(),
77 name: state.name.clone().unwrap_or_default(),
78 pty_path: meta.pty_path.clone(),
79 shell_pid: meta.shell_pid,
80 created_at: meta.created_at,
81 attached: meta.attached.load(Ordering::Relaxed),
82 last_heartbeat: meta.last_heartbeat.load(Ordering::Relaxed),
83 }
84 } else {
85 SessionEntry {
86 id: id.to_string(),
87 name: state.name.clone().unwrap_or_default(),
88 pty_path: String::new(),
89 shell_pid: 0,
90 created_at: 0,
91 attached: false,
92 last_heartbeat: 0,
93 }
94 }
95 })
96 .collect();
97 entries.sort_by_key(|e| e.id.parse::<u32>().unwrap_or(u32::MAX));
98 entries
99}
100
101fn shutdown(sessions: &mut HashMap<u32, SessionState>, ctl_path: &Path) {
102 for (id, state) in sessions.drain() {
103 state.handle.abort();
104 info!(id, "session aborted");
105 }
106 let _ = std::fs::remove_file(ctl_path);
107 let _ = std::fs::remove_file(pid_file_path(ctl_path));
108}
109
110pub async fn run(ctl_path: &Path, ready_fd: Option<OwnedFd>) -> anyhow::Result<()> {
116 unsafe {
118 libc::umask(0o077);
119 }
120
121 if let Some(parent) = ctl_path.parent() {
123 crate::security::secure_create_dir_all(parent)?;
124 }
125
126 let listener = crate::security::bind_unix_listener(ctl_path)?;
127 info!(path = %ctl_path.display(), "daemon listening");
128
129 if let Some(fd) = ready_fd {
131 use std::io::Write;
132 let mut f = std::fs::File::from(fd);
133 let _ = f.write_all(&[1]);
134 }
136
137 let pid_path = pid_file_path(ctl_path);
139 std::fs::write(&pid_path, std::process::id().to_string())?;
140
141 let mut sessions: HashMap<u32, SessionState> = HashMap::new();
142 let mut next_id: u32 = 0;
143
144 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
146 let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?;
147
148 loop {
149 reap_sessions(&mut sessions);
150
151 let stream = tokio::select! {
152 result = listener.accept() => {
153 let (stream, _addr) = result?;
154 stream
155 }
156 _ = sigterm.recv() => {
157 info!("SIGTERM received, shutting down");
158 shutdown(&mut sessions, ctl_path);
159 break;
160 }
161 _ = sigint.recv() => {
162 info!("SIGINT received, shutting down");
163 shutdown(&mut sessions, ctl_path);
164 break;
165 }
166 };
167
168 if let Err(e) = crate::security::verify_peer_uid(&stream) {
169 warn!("{e}");
170 continue;
171 }
172 let mut framed = Framed::new(stream, FrameCodec);
173
174 let Some(Ok(frame)) = framed.next().await else {
176 continue;
177 };
178
179 match frame {
180 Frame::NewSession { name } => {
181 let name_opt = if name.is_empty() { None } else { Some(name) };
184 if let Some(ref n) = name_opt {
185 if n.bytes().any(|b| b.is_ascii_control()) {
186 let _ = framed
187 .send(Frame::Error {
188 message: "session name must not contain control characters"
189 .to_string(),
190 })
191 .await;
192 continue;
193 }
194 let dup = sessions.values().any(|s| s.name.as_deref() == Some(n));
195 if dup {
196 let _ = framed
197 .send(Frame::Error {
198 message: format!("session name already exists: {n}"),
199 })
200 .await;
201 continue;
202 }
203 }
204
205 let id = next_id;
206 next_id += 1;
207
208 let (client_tx, client_rx) = mpsc::unbounded_channel();
209 let metadata = Arc::new(OnceLock::new());
210 let meta_clone = Arc::clone(&metadata);
211 let sock_dir = ctl_path.parent().expect("ctl_path must have a parent");
212 let agent_socket_path = sock_dir.join(format!("agent-{id}.sock"));
213 let open_socket_path = sock_dir.join(format!("open-{id}.sock"));
214 let handle = tokio::spawn(async move {
215 server::run(client_rx, meta_clone, agent_socket_path, open_socket_path).await
216 });
217
218 sessions.insert(
219 id,
220 SessionState {
221 handle,
222 metadata,
223 client_tx: client_tx.clone(),
224 name: name_opt.clone(),
225 },
226 );
227
228 info!(id, name = ?name_opt, "session created");
229
230 let _ = framed.send(Frame::SessionCreated { id: id.to_string() }).await;
231
232 let _ = client_tx.send(ClientConn::Active(framed));
234 }
235 Frame::Attach { session } => {
236 reap_sessions(&mut sessions);
237 if let Some(id) = resolve_session(&sessions, &session) {
238 let state = &sessions[&id];
239 if state.client_tx.is_closed() {
240 sessions.remove(&id);
241 let _ = framed
242 .send(Frame::Error { message: format!("no such session: {session}") })
243 .await;
244 } else {
245 let _ = framed.send(Frame::Ok).await;
246 let _ = state.client_tx.send(ClientConn::Active(framed));
247 }
248 } else {
249 let _ = framed
250 .send(Frame::Error { message: format!("no such session: {session}") })
251 .await;
252 }
253 }
254 Frame::Tail { session } => {
255 reap_sessions(&mut sessions);
256 if let Some(id) = resolve_session(&sessions, &session) {
257 let state = &sessions[&id];
258 if state.client_tx.is_closed() {
259 sessions.remove(&id);
260 let _ = framed
261 .send(Frame::Error { message: format!("no such session: {session}") })
262 .await;
263 } else {
264 let _ = framed.send(Frame::Ok).await;
265 let _ = state.client_tx.send(ClientConn::Tail(framed));
266 }
267 } else {
268 let _ = framed
269 .send(Frame::Error { message: format!("no such session: {session}") })
270 .await;
271 }
272 }
273 Frame::ListSessions => {
274 reap_sessions(&mut sessions);
275 let entries = build_session_entries(&sessions);
276 let _ = framed.send(Frame::SessionInfo { sessions: entries }).await;
277 }
278 Frame::KillSession { session } => {
279 reap_sessions(&mut sessions);
280 if let Some(id) = resolve_session(&sessions, &session) {
281 let state = sessions.remove(&id).unwrap();
282 state.handle.abort();
283 info!(id, "session killed");
284 let _ = framed.send(Frame::Ok).await;
285 } else {
286 let _ = framed
287 .send(Frame::Error { message: format!("no such session: {session}") })
288 .await;
289 }
290 }
291 Frame::KillServer => {
292 info!("kill-server received, shutting down");
293 shutdown(&mut sessions, ctl_path);
294 let _ = framed.send(Frame::Ok).await;
295 break;
296 }
297 other => {
298 error!(?other, "unexpected frame on control socket");
299 let _ = framed
300 .send(Frame::Error { message: "unexpected frame type".to_string() })
301 .await;
302 }
303 }
304 }
305
306 Ok(())
307}