1use crate::protocol::{Frame, FrameCodec, SessionEntry};
2use crate::server::{self, SessionMetadata};
3use futures_util::{SinkExt, StreamExt};
4use std::collections::HashMap;
5use std::os::fd::OwnedFd;
6use std::path::{Path, PathBuf};
7use std::sync::atomic::Ordering;
8use std::sync::{Arc, OnceLock};
9use tokio::net::UnixStream;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12use tokio_util::codec::Framed;
13use tracing::{error, info, warn};
14
/// Daemon-side bookkeeping for a single session.
struct SessionState {
    /// Handle to the spawned task that runs `server::run` for this session;
    /// polled for completion when reaping and aborted on kill/shutdown.
    handle: JoinHandle<anyhow::Result<()>>,
    /// Metadata slot shared (via `Arc`) with the session task. It is read when
    /// listing sessions and may still be unset at that point.
    metadata: Arc<OnceLock<SessionMetadata>>,
    /// Channel used to hand framed client connections over to the session task.
    client_tx: mpsc::UnboundedSender<Framed<UnixStream, FrameCodec>>,
    /// Optional session name; `None` when the client supplied an empty name.
    name: Option<String>,
}
21
22pub fn socket_dir() -> PathBuf {
25 if let Ok(xdg) = std::env::var("XDG_RUNTIME_DIR") {
26 PathBuf::from(xdg).join("gritty")
27 } else {
28 let uid = unsafe { libc::getuid() };
29 PathBuf::from(format!("/tmp/gritty-{uid}"))
30 }
31}
32
33pub fn control_socket_path() -> PathBuf {
35 socket_dir().join("ctl.sock")
36}
37
/// Returns the PID file path that sits next to the control socket:
/// `ctl_path` with its final component replaced by `daemon.pid`.
fn pid_file_path(ctl_path: &Path) -> PathBuf {
    let mut pid_path = ctl_path.to_path_buf();
    pid_path.set_file_name("daemon.pid");
    pid_path
}
42
43fn reap_sessions(sessions: &mut HashMap<u32, SessionState>) {
44 sessions.retain(|id, state| {
45 if state.handle.is_finished() {
46 info!(id, "session ended");
47 false
48 } else {
49 true
50 }
51 });
52}
53
54fn resolve_session(sessions: &HashMap<u32, SessionState>, target: &str) -> Option<u32> {
56 for (&id, state) in sessions {
58 if state.name.as_deref() == Some(target) {
59 return Some(id);
60 }
61 }
62 if let Ok(id) = target.parse::<u32>()
64 && sessions.contains_key(&id)
65 {
66 return Some(id);
67 }
68 None
69}
70
71fn build_session_entries(sessions: &HashMap<u32, SessionState>) -> Vec<SessionEntry> {
72 let mut entries: Vec<_> = sessions
73 .iter()
74 .map(|(&id, state)| {
75 if let Some(meta) = state.metadata.get() {
76 SessionEntry {
77 id: id.to_string(),
78 name: state.name.clone().unwrap_or_default(),
79 pty_path: meta.pty_path.clone(),
80 shell_pid: meta.shell_pid,
81 created_at: meta.created_at,
82 attached: meta.attached.load(Ordering::Relaxed),
83 last_heartbeat: meta.last_heartbeat.load(Ordering::Relaxed),
84 }
85 } else {
86 SessionEntry {
87 id: id.to_string(),
88 name: state.name.clone().unwrap_or_default(),
89 pty_path: String::new(),
90 shell_pid: 0,
91 created_at: 0,
92 attached: false,
93 last_heartbeat: 0,
94 }
95 }
96 })
97 .collect();
98 entries.sort_by_key(|e| e.id.parse::<u32>().unwrap_or(u32::MAX));
99 entries
100}
101
102fn shutdown(sessions: &mut HashMap<u32, SessionState>, ctl_path: &Path) {
103 for (id, state) in sessions.drain() {
104 state.handle.abort();
105 info!(id, "session aborted");
106 }
107 let _ = std::fs::remove_file(ctl_path);
108 let _ = std::fs::remove_file(pid_file_path(ctl_path));
109}
110
111pub async fn run(ctl_path: &Path, ready_fd: Option<OwnedFd>) -> anyhow::Result<()> {
117 unsafe {
119 libc::umask(0o077);
120 }
121
122 if let Some(parent) = ctl_path.parent() {
124 crate::security::secure_create_dir_all(parent)?;
125 }
126
127 let listener = crate::security::bind_unix_listener(ctl_path)?;
128 info!(path = %ctl_path.display(), "daemon listening");
129
130 if let Some(fd) = ready_fd {
132 use std::io::Write;
133 let mut f = std::fs::File::from(fd);
134 let _ = f.write_all(&[1]);
135 }
137
138 let pid_path = pid_file_path(ctl_path);
140 std::fs::write(&pid_path, std::process::id().to_string())?;
141
142 let mut sessions: HashMap<u32, SessionState> = HashMap::new();
143 let mut next_id: u32 = 0;
144
145 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
147 let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?;
148
149 loop {
150 reap_sessions(&mut sessions);
151
152 let stream = tokio::select! {
153 result = listener.accept() => {
154 let (stream, _addr) = result?;
155 stream
156 }
157 _ = sigterm.recv() => {
158 info!("SIGTERM received, shutting down");
159 shutdown(&mut sessions, ctl_path);
160 break;
161 }
162 _ = sigint.recv() => {
163 info!("SIGINT received, shutting down");
164 shutdown(&mut sessions, ctl_path);
165 break;
166 }
167 };
168
169 if let Err(e) = crate::security::verify_peer_uid(&stream) {
170 warn!("{e}");
171 continue;
172 }
173 let mut framed = Framed::new(stream, FrameCodec);
174
175 let Some(Ok(frame)) = framed.next().await else {
177 continue;
178 };
179
180 match frame {
181 Frame::NewSession { name } => {
182 let name_opt = if name.is_empty() { None } else { Some(name) };
184 if let Some(ref n) = name_opt {
185 let dup = sessions.values().any(|s| s.name.as_deref() == Some(n));
186 if dup {
187 let _ = framed
188 .send(Frame::Error {
189 message: format!("session name already exists: {n}"),
190 })
191 .await;
192 continue;
193 }
194 }
195
196 let id = next_id;
197 next_id += 1;
198
199 let (client_tx, client_rx) = mpsc::unbounded_channel();
200 let metadata = Arc::new(OnceLock::new());
201 let meta_clone = Arc::clone(&metadata);
202 let handle = tokio::spawn(async move { server::run(client_rx, meta_clone).await });
203
204 sessions.insert(
205 id,
206 SessionState {
207 handle,
208 metadata,
209 client_tx: client_tx.clone(),
210 name: name_opt.clone(),
211 },
212 );
213
214 info!(id, name = ?name_opt, "session created");
215
216 let _ = framed.send(Frame::SessionCreated { id: id.to_string() }).await;
217
218 let _ = client_tx.send(framed);
220 }
221 Frame::Attach { session } => {
222 reap_sessions(&mut sessions);
223 if let Some(id) = resolve_session(&sessions, &session) {
224 let state = &sessions[&id];
225 if state.client_tx.is_closed() {
226 sessions.remove(&id);
227 let _ = framed
228 .send(Frame::Error { message: format!("no such session: {session}") })
229 .await;
230 } else {
231 let _ = framed.send(Frame::Ok).await;
232 let _ = state.client_tx.send(framed);
233 }
234 } else {
235 let _ = framed
236 .send(Frame::Error { message: format!("no such session: {session}") })
237 .await;
238 }
239 }
240 Frame::ListSessions => {
241 reap_sessions(&mut sessions);
242 let entries = build_session_entries(&sessions);
243 let _ = framed.send(Frame::SessionInfo { sessions: entries }).await;
244 }
245 Frame::KillSession { session } => {
246 reap_sessions(&mut sessions);
247 if let Some(id) = resolve_session(&sessions, &session) {
248 let state = sessions.remove(&id).unwrap();
249 state.handle.abort();
250 info!(id, "session killed");
251 let _ = framed.send(Frame::Ok).await;
252 } else {
253 let _ = framed
254 .send(Frame::Error { message: format!("no such session: {session}") })
255 .await;
256 }
257 }
258 Frame::KillServer => {
259 info!("kill-server received, shutting down");
260 shutdown(&mut sessions, ctl_path);
261 let _ = framed.send(Frame::Ok).await;
262 break;
263 }
264 other => {
265 error!(?other, "unexpected frame on control socket");
266 let _ = framed
267 .send(Frame::Error { message: "unexpected frame type".to_string() })
268 .await;
269 }
270 }
271 }
272
273 Ok(())
274}