1use crate::protocol::{Frame, FrameCodec, PROTOCOL_VERSION, SessionEntry};
2use crate::server::{self, ClientConn, SessionMetadata};
3use futures_util::{SinkExt, StreamExt};
4use std::collections::HashMap;
5use std::os::fd::OwnedFd;
6use std::path::{Path, PathBuf};
7use std::sync::atomic::Ordering;
8use std::sync::{Arc, OnceLock};
9use std::time::Duration;
10use tokio::net::UnixStream;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13use tokio_util::codec::Framed;
14use tracing::{error, info, warn};
15
/// Upper bound on how long a single control-socket send may take before
/// `timed_send` gives up and reports a timeout.
const SEND_TIMEOUT: Duration = Duration::from_secs(5);
17
18async fn timed_send(
21 framed: &mut Framed<UnixStream, FrameCodec>,
22 frame: Frame,
23) -> Result<(), std::io::Error> {
24 match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
25 Ok(Ok(())) => Ok(()),
26 Ok(Err(e)) => {
27 warn!("control send error: {e}");
28 Err(e)
29 }
30 Err(_) => {
31 warn!("control send timed out");
32 Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "send timed out"))
33 }
34 }
35}
36
/// Bookkeeping the daemon keeps for one session, keyed by session id in `run`.
struct SessionState {
    /// Handle of the spawned task that runs the session (`server::run`).
    handle: JoinHandle<anyhow::Result<()>>,
    /// Metadata slot shared with the session task; `get()` yields nothing until
    /// it has been set (presumably by `server::run` -- confirm in `server`).
    metadata: Arc<OnceLock<SessionMetadata>>,
    /// Sending half of the channel used to hand client connections to the
    /// session task.
    client_tx: mpsc::UnboundedSender<ClientConn>,
    /// Name supplied when the session was created; `None` if the client sent an
    /// empty name.
    name: Option<String>,
}
43
44pub fn socket_dir() -> PathBuf {
47 if let Ok(xdg) = std::env::var("XDG_RUNTIME_DIR") {
48 PathBuf::from(xdg).join("gritty")
49 } else {
50 let uid = unsafe { libc::getuid() };
51 PathBuf::from(format!("/tmp/gritty-{uid}"))
52 }
53}
54
55pub fn control_socket_path() -> PathBuf {
57 socket_dir().join("ctl.sock")
58}
59
/// Returns the PID file path that belongs to the control socket at `ctl_path`:
/// the same path with its final component replaced by `daemon.pid`.
pub fn pid_file_path(ctl_path: &Path) -> PathBuf {
    let mut pid_path = ctl_path.to_path_buf();
    pid_path.set_file_name("daemon.pid");
    pid_path
}
64
65fn reap_sessions(sessions: &mut HashMap<u32, SessionState>) {
66 sessions.retain(|id, state| {
67 if state.handle.is_finished() {
68 info!(id, "session ended");
69 false
70 } else {
71 true
72 }
73 });
74}
75
76fn resolve_session(sessions: &HashMap<u32, SessionState>, target: &str) -> Option<u32> {
78 for (&id, state) in sessions {
80 if state.name.as_deref() == Some(target) {
81 return Some(id);
82 }
83 }
84 if let Ok(id) = target.parse::<u32>()
86 && sessions.contains_key(&id)
87 {
88 return Some(id);
89 }
90 None
91}
92
93fn build_session_entries(sessions: &HashMap<u32, SessionState>) -> Vec<SessionEntry> {
94 let mut entries: Vec<_> = sessions
95 .iter()
96 .map(|(&id, state)| {
97 if let Some(meta) = state.metadata.get() {
98 SessionEntry {
99 id: id.to_string(),
100 name: state.name.clone().unwrap_or_default(),
101 pty_path: meta.pty_path.clone(),
102 shell_pid: meta.shell_pid,
103 created_at: meta.created_at,
104 attached: meta.attached.load(Ordering::Relaxed),
105 last_heartbeat: meta.last_heartbeat.load(Ordering::Relaxed),
106 }
107 } else {
108 SessionEntry {
109 id: id.to_string(),
110 name: state.name.clone().unwrap_or_default(),
111 pty_path: String::new(),
112 shell_pid: 0,
113 created_at: 0,
114 attached: false,
115 last_heartbeat: 0,
116 }
117 }
118 })
119 .collect();
120 entries.sort_by_key(|e| e.id.parse::<u32>().unwrap_or(u32::MAX));
121 entries
122}
123
124fn shutdown(sessions: &mut HashMap<u32, SessionState>, ctl_path: &Path) {
125 for (id, state) in sessions.drain() {
126 state.handle.abort();
127 info!(id, "session aborted");
128 }
129 let _ = std::fs::remove_file(ctl_path);
130 let _ = std::fs::remove_file(pid_file_path(ctl_path));
131}
132
133pub async fn run(ctl_path: &Path, ready_fd: Option<OwnedFd>) -> anyhow::Result<()> {
139 unsafe {
141 libc::umask(0o077);
142 }
143
144 if let Some(parent) = ctl_path.parent() {
146 crate::security::secure_create_dir_all(parent)?;
147 }
148
149 let listener = crate::security::bind_unix_listener(ctl_path)?;
150 info!(path = %ctl_path.display(), "daemon listening");
151
152 if let Some(fd) = ready_fd {
154 use std::io::Write;
155 let mut f = std::fs::File::from(fd);
156 let _ = f.write_all(&[1]);
157 }
159
160 let pid_path = pid_file_path(ctl_path);
162 std::fs::write(&pid_path, std::process::id().to_string())?;
163
164 let mut sessions: HashMap<u32, SessionState> = HashMap::new();
165 let mut next_id: u32 = 0;
166
167 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
169 let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?;
170
171 loop {
172 reap_sessions(&mut sessions);
173
174 let stream = tokio::select! {
175 result = listener.accept() => {
176 let (stream, _addr) = result?;
177 stream
178 }
179 _ = sigterm.recv() => {
180 info!("SIGTERM received, shutting down");
181 shutdown(&mut sessions, ctl_path);
182 break;
183 }
184 _ = sigint.recv() => {
185 info!("SIGINT received, shutting down");
186 shutdown(&mut sessions, ctl_path);
187 break;
188 }
189 };
190
191 if let Err(e) = crate::security::verify_peer_uid(&stream) {
192 warn!("{e}");
193 continue;
194 }
195 let mut framed = Framed::new(stream, FrameCodec);
196
197 let hello = match tokio::time::timeout(Duration::from_secs(5), framed.next()).await {
199 Ok(Some(Ok(Frame::Hello { version }))) => version,
200 Ok(Some(Ok(_))) => {
201 let _ = timed_send(
202 &mut framed,
203 Frame::Error { message: "expected Hello handshake".to_string() },
204 )
205 .await;
206 continue;
207 }
208 Ok(Some(Err(e))) => {
209 warn!("frame decode error: {e}");
210 continue;
211 }
212 Ok(None) => continue,
213 Err(_) => {
214 warn!("control connection timed out (hello)");
215 continue;
216 }
217 };
218
219 let negotiated = hello.min(PROTOCOL_VERSION);
221 if timed_send(&mut framed, Frame::HelloAck { version: negotiated }).await.is_err() {
222 continue;
223 }
224
225 let frame = match tokio::time::timeout(Duration::from_secs(5), framed.next()).await {
227 Ok(Some(Ok(f))) => f,
228 Ok(Some(Err(e))) => {
229 warn!("frame decode error: {e}");
230 continue;
231 }
232 Ok(None) => continue,
233 Err(_) => {
234 warn!("control connection timed out");
235 continue;
236 }
237 };
238
239 match frame {
240 Frame::NewSession { name } => {
241 let name_opt = if name.is_empty() { None } else { Some(name) };
244 if let Some(ref n) = name_opt {
245 if n.bytes().any(|b| b.is_ascii_control()) {
246 let _ = timed_send(
247 &mut framed,
248 Frame::Error {
249 message: "session name must not contain control characters"
250 .to_string(),
251 },
252 )
253 .await;
254 continue;
255 }
256 let dup = sessions.values().any(|s| s.name.as_deref() == Some(n));
257 if dup {
258 let _ = timed_send(
259 &mut framed,
260 Frame::Error { message: format!("session name already exists: {n}") },
261 )
262 .await;
263 continue;
264 }
265 }
266
267 let id = next_id;
268 next_id += 1;
269
270 let (client_tx, client_rx) = mpsc::unbounded_channel();
271 let metadata = Arc::new(OnceLock::new());
272 let meta_clone = Arc::clone(&metadata);
273 let sock_dir = ctl_path.parent().expect("ctl_path must have a parent");
274 let agent_socket_path = sock_dir.join(format!("agent-{id}.sock"));
275 let open_socket_path = sock_dir.join(format!("open-{id}.sock"));
276 let handle = tokio::spawn(async move {
277 server::run(client_rx, meta_clone, agent_socket_path, open_socket_path).await
278 });
279
280 sessions.insert(
281 id,
282 SessionState {
283 handle,
284 metadata,
285 client_tx: client_tx.clone(),
286 name: name_opt.clone(),
287 },
288 );
289
290 info!(id, name = ?name_opt, "session created");
291
292 if timed_send(&mut framed, Frame::SessionCreated { id: id.to_string() })
293 .await
294 .is_err()
295 {
296 continue;
297 }
298
299 let _ = client_tx.send(ClientConn::Active(framed));
301 }
302 Frame::Attach { session } => {
303 reap_sessions(&mut sessions);
304 if let Some(id) = resolve_session(&sessions, &session) {
305 let state = &sessions[&id];
306 if state.client_tx.is_closed() {
307 sessions.remove(&id);
308 let _ = timed_send(
309 &mut framed,
310 Frame::Error { message: format!("no such session: {session}") },
311 )
312 .await;
313 } else if timed_send(&mut framed, Frame::Ok).await.is_ok() {
314 let _ = state.client_tx.send(ClientConn::Active(framed));
315 }
316 } else {
317 let _ = timed_send(
318 &mut framed,
319 Frame::Error { message: format!("no such session: {session}") },
320 )
321 .await;
322 }
323 }
324 Frame::Tail { session } => {
325 reap_sessions(&mut sessions);
326 if let Some(id) = resolve_session(&sessions, &session) {
327 let state = &sessions[&id];
328 if state.client_tx.is_closed() {
329 sessions.remove(&id);
330 let _ = timed_send(
331 &mut framed,
332 Frame::Error { message: format!("no such session: {session}") },
333 )
334 .await;
335 } else if timed_send(&mut framed, Frame::Ok).await.is_ok() {
336 let _ = state.client_tx.send(ClientConn::Tail(framed));
337 }
338 } else {
339 let _ = timed_send(
340 &mut framed,
341 Frame::Error { message: format!("no such session: {session}") },
342 )
343 .await;
344 }
345 }
346 Frame::ListSessions => {
347 reap_sessions(&mut sessions);
348 let entries = build_session_entries(&sessions);
349 let _ = timed_send(&mut framed, Frame::SessionInfo { sessions: entries }).await;
350 }
351 Frame::KillSession { session } => {
352 reap_sessions(&mut sessions);
353 if let Some(id) = resolve_session(&sessions, &session) {
354 let state = sessions.remove(&id).unwrap();
355 state.handle.abort();
356 info!(id, "session killed");
357 let _ = timed_send(&mut framed, Frame::Ok).await;
358 } else {
359 let _ = timed_send(
360 &mut framed,
361 Frame::Error { message: format!("no such session: {session}") },
362 )
363 .await;
364 }
365 }
366 Frame::KillServer => {
367 info!("kill-server received, shutting down");
368 shutdown(&mut sessions, ctl_path);
369 let _ = timed_send(&mut framed, Frame::Ok).await;
370 break;
371 }
372 other => {
373 error!(?other, "unexpected frame on control socket");
374 let _ = timed_send(
375 &mut framed,
376 Frame::Error { message: "unexpected frame type".to_string() },
377 )
378 .await;
379 }
380 }
381 }
382
383 Ok(())
384}