1use crate::protocol::{Frame, FrameCodec, PROTOCOL_VERSION, SessionEntry};
2use crate::server::{self, ClientConn, SessionMetadata};
3use futures_util::{SinkExt, StreamExt};
4use std::collections::HashMap;
5use std::os::fd::OwnedFd;
6use std::path::{Path, PathBuf};
7use std::sync::atomic::Ordering;
8use std::sync::{Arc, OnceLock};
9use std::time::Duration;
10use tokio::net::UnixStream;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13use tokio_util::codec::Framed;
14use tracing::{error, info, warn};
15
/// Upper bound on how long a single frame write to a control client may take
/// before the client is treated as unresponsive (see `timed_send`).
const SEND_TIMEOUT: Duration = Duration::from_millis(5_000);
17
18async fn timed_send(
21 framed: &mut Framed<UnixStream, FrameCodec>,
22 frame: Frame,
23) -> Result<(), std::io::Error> {
24 match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
25 Ok(Ok(())) => Ok(()),
26 Ok(Err(e)) => {
27 warn!("control send error: {e}");
28 Err(e)
29 }
30 Err(_) => {
31 warn!("control send timed out");
32 Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "send timed out"))
33 }
34 }
35}
36
37struct SessionState {
38 handle: JoinHandle<anyhow::Result<()>>,
39 metadata: Arc<OnceLock<SessionMetadata>>,
40 client_tx: mpsc::UnboundedSender<ClientConn>,
41 name: Option<String>,
42}
43
44pub fn socket_dir() -> PathBuf {
47 if let Ok(dir) = std::env::var("GRITTY_SOCKET_DIR") {
48 return PathBuf::from(dir);
49 }
50 if let Some(proj) = directories::ProjectDirs::from("", "", "gritty") {
51 if let Some(runtime) = proj.runtime_dir() {
52 return runtime.to_path_buf();
53 }
54 }
55 let uid = unsafe { libc::getuid() };
56 PathBuf::from(format!("/tmp/gritty-{uid}"))
57}
58
59pub fn control_socket_path() -> PathBuf {
61 socket_dir().join("ctl.sock")
62}
63
/// Path of the daemon's PID file: `daemon.pid` in the same directory as `ctl_path`.
pub fn pid_file_path(ctl_path: &Path) -> PathBuf {
    let mut pid_path = ctl_path.to_path_buf();
    pid_path.set_file_name("daemon.pid");
    pid_path
}
68
69fn reap_sessions(sessions: &mut HashMap<u32, SessionState>) {
70 sessions.retain(|id, state| {
71 if state.handle.is_finished() {
72 info!(id, "session ended");
73 false
74 } else {
75 true
76 }
77 });
78}
79
80fn resolve_session(
82 sessions: &HashMap<u32, SessionState>,
83 target: &str,
84 last_attached: Option<u32>,
85) -> Option<u32> {
86 if target == "-" {
88 return last_attached.filter(|id| sessions.contains_key(id));
89 }
90 for (&id, state) in sessions {
92 if state.name.as_deref() == Some(target) {
93 return Some(id);
94 }
95 }
96 if let Ok(id) = target.parse::<u32>()
98 && sessions.contains_key(&id)
99 {
100 return Some(id);
101 }
102 None
103}
104
/// Returns the command name of the foreground process group on the
/// controlling terminal of `shell_pid`, or `"-"` when it cannot be determined.
///
/// Reads `tpgid` from `/proc/<shell_pid>/stat`, then the `comm` of the process
/// whose PID equals that group ID. Any failure yields `"-"`: a missing `/proc`,
/// an exited process, or no controlling terminal (`tpgid` of -1 or 0).
fn foreground_process(shell_pid: u32) -> String {
    std::fs::read_to_string(format!("/proc/{shell_pid}/stat"))
        .ok()
        .and_then(|stat| tpgid_from_stat(&stat))
        .and_then(|tpgid| std::fs::read_to_string(format!("/proc/{tpgid}/comm")).ok())
        .map(|comm| comm.trim().to_string())
        .unwrap_or_else(|| "-".to_string())
}

/// Extracts the positive `tpgid` field (8th field) from `/proc/<pid>/stat` text.
///
/// Returns `None` if the field is missing, non-numeric, or not positive.
fn tpgid_from_stat(stat: &str) -> Option<u32> {
    // `comm` (2nd field) is parenthesised and may itself contain spaces and ')',
    // so the remaining fields start after the *last* ')'. Slicing at `close + 1`
    // is always in bounds, even when nothing follows the ')'.
    let close = stat.rfind(')')?;
    // Fields after comm: state, ppid, pgrp, session, tty_nr, tpgid, ...
    stat[close + 1..]
        .split_whitespace()
        .nth(5)?
        .parse::<u32>()
        .ok()
        .filter(|&tpgid| tpgid > 0)
}
131
132fn build_session_entries(sessions: &HashMap<u32, SessionState>) -> Vec<SessionEntry> {
133 let mut entries: Vec<_> = sessions
134 .iter()
135 .map(|(&id, state)| {
136 if let Some(meta) = state.metadata.get() {
137 SessionEntry {
138 id: id.to_string(),
139 name: state.name.clone().unwrap_or_default(),
140 pty_path: meta.pty_path.clone(),
141 shell_pid: meta.shell_pid,
142 created_at: meta.created_at,
143 attached: meta.attached.load(Ordering::Relaxed),
144 last_heartbeat: meta.last_heartbeat.load(Ordering::Relaxed),
145 foreground_cmd: foreground_process(meta.shell_pid),
146 }
147 } else {
148 SessionEntry {
149 id: id.to_string(),
150 name: state.name.clone().unwrap_or_default(),
151 pty_path: String::new(),
152 shell_pid: 0,
153 created_at: 0,
154 attached: false,
155 last_heartbeat: 0,
156 foreground_cmd: "-".to_string(),
157 }
158 }
159 })
160 .collect();
161 entries.sort_by_key(|e| e.id.parse::<u32>().unwrap_or(u32::MAX));
162 entries
163}
164
165fn shutdown(sessions: &mut HashMap<u32, SessionState>, ctl_path: &Path) {
166 for (id, state) in sessions.drain() {
167 state.handle.abort();
168 info!(id, "session aborted");
169 }
170 let _ = std::fs::remove_file(ctl_path);
171 let _ = std::fs::remove_file(pid_file_path(ctl_path));
172}
173
174async fn connection_handshake(
177 stream: UnixStream,
178 tx: mpsc::Sender<(Frame, Framed<UnixStream, FrameCodec>)>,
179) {
180 let mut framed = Framed::new(stream, FrameCodec);
181
182 let hello = match tokio::time::timeout(Duration::from_secs(5), framed.next()).await {
184 Ok(Some(Ok(Frame::Hello { version }))) => version,
185 Ok(Some(Ok(_))) => {
186 let _ = timed_send(
187 &mut framed,
188 Frame::Error { message: "expected Hello handshake".to_string() },
189 )
190 .await;
191 return;
192 }
193 Ok(Some(Err(e))) => {
194 warn!("frame decode error: {e}");
195 return;
196 }
197 Ok(None) => return,
198 Err(_) => {
199 warn!("control connection timed out (hello)");
200 return;
201 }
202 };
203
204 if hello != PROTOCOL_VERSION {
206 let _ = timed_send(
207 &mut framed,
208 Frame::Error {
209 message: format!(
210 "protocol version mismatch: client={hello} server={PROTOCOL_VERSION}; \
211 both sides must run the same version"
212 ),
213 },
214 )
215 .await;
216 return;
217 }
218 if timed_send(&mut framed, Frame::HelloAck { version: PROTOCOL_VERSION }).await.is_err() {
219 return;
220 }
221
222 let frame = match tokio::time::timeout(Duration::from_secs(5), framed.next()).await {
224 Ok(Some(Ok(f))) => f,
225 Ok(Some(Err(e))) => {
226 warn!("frame decode error: {e}");
227 return;
228 }
229 Ok(None) => return,
230 Err(_) => {
231 warn!("control connection timed out");
232 return;
233 }
234 };
235
236 let _ = tx.send((frame, framed)).await;
237}
238
239pub async fn run(ctl_path: &Path, ready_fd: Option<OwnedFd>) -> anyhow::Result<()> {
245 unsafe {
247 libc::umask(0o077);
248 }
249
250 if let Some(parent) = ctl_path.parent() {
252 crate::security::secure_create_dir_all(parent)?;
253 }
254
255 let listener = crate::security::bind_unix_listener(ctl_path)?;
256 info!(path = %ctl_path.display(), "daemon listening");
257
258 if let Some(fd) = ready_fd {
260 use std::io::Write;
261 let mut f = std::fs::File::from(fd);
262 let pid = std::process::id();
263 let mut buf = [0u8; 5];
264 buf[0] = 0x01;
265 buf[1..5].copy_from_slice(&pid.to_le_bytes());
266 let _ = f.write_all(&buf);
267 }
269
270 let pid_path = pid_file_path(ctl_path);
272 std::fs::write(&pid_path, std::process::id().to_string())?;
273
274 let mut sessions: HashMap<u32, SessionState> = HashMap::new();
275 let mut next_id: u32 = 0;
276 let mut last_attached: Option<u32> = None;
277 let session_config = crate::config::ConfigFile::load().resolve_session(None);
278 let ring_buffer_cap = session_config.ring_buffer_size as usize;
279 let oauth_tunnel_idle_timeout = session_config.oauth_tunnel_idle_timeout;
280
281 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
283 let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?;
284
285 let (conn_tx, mut conn_rx) = mpsc::channel::<(Frame, Framed<UnixStream, FrameCodec>)>(64);
287
288 loop {
289 reap_sessions(&mut sessions);
290
291 let should_break = tokio::select! {
292 result = listener.accept() => {
293 let (stream, _addr) = result?;
294 if let Err(e) = crate::security::verify_peer_uid(&stream) {
295 warn!("{e}");
296 } else {
297 let tx = conn_tx.clone();
298 tokio::spawn(connection_handshake(stream, tx));
299 }
300 false
301 }
302 Some((frame, framed)) = conn_rx.recv() => {
303 dispatch_control(
304 frame, framed, &mut sessions, &mut next_id, ctl_path, &mut last_attached,
305 ring_buffer_cap, oauth_tunnel_idle_timeout,
306 ).await
307 }
308 _ = sigterm.recv() => {
309 info!("SIGTERM received, shutting down");
310 shutdown(&mut sessions, ctl_path);
311 true
312 }
313 _ = sigint.recv() => {
314 info!("SIGINT received, shutting down");
315 shutdown(&mut sessions, ctl_path);
316 true
317 }
318 };
319
320 if should_break {
321 break;
322 }
323 }
324
325 Ok(())
326}
327
328#[allow(clippy::too_many_arguments)]
332async fn dispatch_control(
333 frame: Frame,
334 mut framed: Framed<UnixStream, FrameCodec>,
335 sessions: &mut HashMap<u32, SessionState>,
336 next_id: &mut u32,
337 ctl_path: &Path,
338 last_attached: &mut Option<u32>,
339 ring_buffer_cap: usize,
340 oauth_tunnel_idle_timeout: u64,
341) -> bool {
342 match frame {
343 Frame::NewSession { name, command } => {
344 let name_opt = if name.is_empty() { None } else { Some(name) };
346 let command_opt = if command.is_empty() { None } else { Some(command) };
347 if let Some(ref n) = name_opt {
348 if n.bytes().any(|b| b.is_ascii_control()) {
349 let _ = timed_send(
350 &mut framed,
351 Frame::Error {
352 message: "session name must not contain control characters".to_string(),
353 },
354 )
355 .await;
356 return false;
357 }
358 if n.parse::<u32>().is_ok() {
359 let _ = timed_send(
360 &mut framed,
361 Frame::Error {
362 message: "session name must not be purely numeric (ambiguous with session IDs)".to_string(),
363 },
364 )
365 .await;
366 return false;
367 }
368 let dup = sessions.values().any(|s| s.name.as_deref() == Some(n));
369 if dup {
370 let _ = timed_send(
371 &mut framed,
372 Frame::Error { message: format!("session name already exists: {n}") },
373 )
374 .await;
375 return false;
376 }
377 }
378
379 let id = *next_id;
380 *next_id += 1;
381
382 let (client_tx, client_rx) = mpsc::unbounded_channel();
383 let metadata = Arc::new(OnceLock::new());
384 let meta_clone = Arc::clone(&metadata);
385 let sock_dir = ctl_path.parent().expect("ctl_path must have a parent");
386 let agent_socket_path = sock_dir.join(format!("agent-{id}.sock"));
387 let svc_socket_path = sock_dir.join(format!("svc-{id}.sock"));
388 let name_for_server = name_opt.clone();
389 let cmd_for_server = command_opt;
390 let handle = tokio::spawn(async move {
391 server::run(
392 client_rx,
393 meta_clone,
394 agent_socket_path,
395 svc_socket_path,
396 id,
397 name_for_server,
398 cmd_for_server,
399 ring_buffer_cap,
400 oauth_tunnel_idle_timeout,
401 )
402 .await
403 });
404
405 sessions.insert(
406 id,
407 SessionState {
408 handle,
409 metadata,
410 client_tx: client_tx.clone(),
411 name: name_opt.clone(),
412 },
413 );
414
415 info!(id, name = ?name_opt, "session created");
416
417 if timed_send(&mut framed, Frame::SessionCreated { id: id.to_string() }).await.is_err()
418 {
419 return false;
420 }
421
422 *last_attached = Some(id);
424 let _ = client_tx.send(ClientConn::Active(framed));
425 false
426 }
427 Frame::Attach { session } => {
428 reap_sessions(sessions);
429 if let Some(id) = resolve_session(sessions, &session, *last_attached) {
430 let state = &sessions[&id];
431 if state.client_tx.is_closed() {
432 sessions.remove(&id);
433 let _ = timed_send(
434 &mut framed,
435 Frame::Error { message: format!("no such session: {session}") },
436 )
437 .await;
438 } else if timed_send(&mut framed, Frame::Ok).await.is_ok() {
439 *last_attached = Some(id);
440 let _ = state.client_tx.send(ClientConn::Active(framed));
441 }
442 } else {
443 let _ = timed_send(
444 &mut framed,
445 Frame::Error { message: format!("no such session: {session}") },
446 )
447 .await;
448 }
449 false
450 }
451 Frame::Tail { session } => {
452 reap_sessions(sessions);
453 if let Some(id) = resolve_session(sessions, &session, *last_attached) {
454 let state = &sessions[&id];
455 if state.client_tx.is_closed() {
456 sessions.remove(&id);
457 let _ = timed_send(
458 &mut framed,
459 Frame::Error { message: format!("no such session: {session}") },
460 )
461 .await;
462 } else if timed_send(&mut framed, Frame::Ok).await.is_ok() {
463 let _ = state.client_tx.send(ClientConn::Tail(framed));
464 }
465 } else {
466 let _ = timed_send(
467 &mut framed,
468 Frame::Error { message: format!("no such session: {session}") },
469 )
470 .await;
471 }
472 false
473 }
474 Frame::ListSessions => {
475 reap_sessions(sessions);
476 let entries = build_session_entries(sessions);
477 let _ = timed_send(&mut framed, Frame::SessionInfo { sessions: entries }).await;
478 false
479 }
480 Frame::KillSession { session } => {
481 reap_sessions(sessions);
482 if let Some(id) = resolve_session(sessions, &session, *last_attached) {
483 let state = sessions.remove(&id).unwrap();
484 state.handle.abort();
485 info!(id, "session killed");
486 let _ = timed_send(&mut framed, Frame::Ok).await;
487 } else {
488 let _ = timed_send(
489 &mut framed,
490 Frame::Error { message: format!("no such session: {session}") },
491 )
492 .await;
493 }
494 false
495 }
496 Frame::SendFile { session, .. } => {
497 reap_sessions(sessions);
498 if let Some(id) = resolve_session(sessions, &session, *last_attached) {
499 let state = &sessions[&id];
500 if state.client_tx.is_closed() {
501 sessions.remove(&id);
502 let _ = timed_send(
503 &mut framed,
504 Frame::Error { message: format!("no such session: {session}") },
505 )
506 .await;
507 } else if timed_send(&mut framed, Frame::Ok).await.is_ok() {
508 let stream = framed.into_inner();
509 let _ = state.client_tx.send(ClientConn::Send(stream));
510 }
511 } else {
512 let _ = timed_send(
513 &mut framed,
514 Frame::Error { message: format!("no such session: {session}") },
515 )
516 .await;
517 }
518 false
519 }
520 Frame::RenameSession { session, new_name } => {
521 reap_sessions(sessions);
522 if let Some(id) = resolve_session(sessions, &session, *last_attached) {
523 if new_name.is_empty() {
524 let _ = timed_send(
525 &mut framed,
526 Frame::Error { message: "new name must not be empty".to_string() },
527 )
528 .await;
529 } else if new_name.bytes().any(|b| b.is_ascii_control()) {
530 let _ = timed_send(
531 &mut framed,
532 Frame::Error {
533 message: "session name must not contain control characters".to_string(),
534 },
535 )
536 .await;
537 } else if new_name.parse::<u32>().is_ok() {
538 let _ = timed_send(
539 &mut framed,
540 Frame::Error {
541 message: "session name must not be purely numeric (ambiguous with session IDs)".to_string(),
542 },
543 )
544 .await;
545 } else if sessions.values().any(|s| s.name.as_deref() == Some(&new_name)) {
546 let _ = timed_send(
547 &mut framed,
548 Frame::Error {
549 message: format!("session name already exists: {new_name}"),
550 },
551 )
552 .await;
553 } else {
554 sessions.get_mut(&id).unwrap().name = Some(new_name.clone());
555 info!(id, new_name, "session renamed");
556 let _ = timed_send(&mut framed, Frame::Ok).await;
557 }
558 } else {
559 let _ = timed_send(
560 &mut framed,
561 Frame::Error { message: format!("no such session: {session}") },
562 )
563 .await;
564 }
565 false
566 }
567 Frame::KillServer => {
568 info!("kill-server received, shutting down");
569 shutdown(sessions, ctl_path);
570 let _ = timed_send(&mut framed, Frame::Ok).await;
571 true
572 }
573 other => {
574 error!(?other, "unexpected frame on control socket");
575 let _ = timed_send(
576 &mut framed,
577 Frame::Error { message: "unexpected frame type".to_string() },
578 )
579 .await;
580 false
581 }
582 }
583}