1use std::io::Write as _;
12use std::path::PathBuf;
13use std::sync::Arc;
14
15#[cfg(unix)]
16use std::os::unix::fs::PermissionsExt;
17
18use async_trait::async_trait;
19use serde::{Deserialize, Serialize};
20use tokio::io::{AsyncReadExt, AsyncWriteExt};
21use tokio::net::{UnixListener, UnixStream};
22
/// Upper bound, in bytes, on the payload of a single daemon frame (8 MiB).
///
/// [`read_frame`] rejects an announced length above this cap before it
/// allocates a buffer, and [`write_frame`] refuses to send a larger payload.
pub const MAX_FRAME_BYTES: usize = 8 * 1024 * 1024;

/// Drain timeout, in seconds, used by `drain_timeout` when the
/// `KHIVE_DRAIN_TIMEOUT_SECS` environment variable is unset or does not parse
/// as a `u64`.
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 10;
27
/// Returns the per-user khive state directory, `$HOME/.khive`.
///
/// `HOME` is read with [`std::env::var_os`] so that a home directory whose
/// path is not valid UTF-8 is still honoured instead of being silently
/// replaced by the fallback. When `HOME` is not set at all, the directory is
/// resolved relative to the current working directory (`./.khive`).
fn khive_dir() -> PathBuf {
    std::env::var_os("HOME")
        .map_or_else(|| PathBuf::from("."), PathBuf::from)
        .join(".khive")
}
34
35pub fn socket_path() -> PathBuf {
39 if let Ok(p) = std::env::var("KHIVE_SOCKET") {
40 if !p.is_empty() {
41 return PathBuf::from(p);
42 }
43 }
44 khive_dir().join("khived.sock")
45}
46
47pub fn pid_path() -> PathBuf {
51 if let Ok(p) = std::env::var("KHIVE_PID") {
52 if !p.is_empty() {
53 return PathBuf::from(p);
54 }
55 }
56 khive_dir().join("khived.pid")
57}
58
/// Request sent by a client to the daemon.
///
/// The connection handler in this file decodes it from the JSON payload of one
/// length-prefixed frame.
#[derive(Serialize, Deserialize)]
pub struct DaemonRequestFrame {
    /// Operations payload; forwarded unchanged to [`DaemonDispatch::dispatch`].
    pub ops: String,
    /// Optional presentation setting; forwarded unchanged to the dispatcher.
    pub presentation: Option<String>,
    /// Optional list of presentation settings; forwarded unchanged to the
    /// dispatcher.
    // NOTE(review): presumably one entry per operation in `ops` — confirm.
    pub presentation_per_op: Option<Vec<Option<String>>>,
    /// Namespace the client expects. The request is only dispatched when this
    /// equals the daemon's own [`DaemonDispatch::namespace`].
    pub namespace: String,
}
69
/// Response returned by the daemon for one request.
///
/// The field descriptions below reflect how the connection handler in this
/// file populates the frame before JSON-encoding it.
#[derive(Serialize, Deserialize)]
pub struct DaemonResponseFrame {
    /// `true` only when the dispatcher returned `Ok`.
    pub ok: bool,
    /// Dispatcher output on success; `None` otherwise.
    pub result: Option<String>,
    /// Dispatcher error on failure; `None` on success and on a namespace
    /// mismatch.
    pub error: Option<String>,
    /// `true` when the request's namespace differed from the daemon's, in
    /// which case nothing was dispatched.
    pub namespace_mismatch: bool,
}
78
79pub async fn read_frame(stream: &mut UnixStream) -> std::io::Result<Vec<u8>> {
83 let mut len_buf = [0u8; 4];
84 stream.read_exact(&mut len_buf).await?;
85 let len = u32::from_be_bytes(len_buf) as usize;
86 if len > MAX_FRAME_BYTES {
87 return Err(std::io::Error::new(
88 std::io::ErrorKind::InvalidData,
89 format!("daemon frame of {len} bytes exceeds {MAX_FRAME_BYTES} cap"),
90 ));
91 }
92 let mut buf = vec![0u8; len];
93 stream.read_exact(&mut buf).await?;
94 Ok(buf)
95}
96
97pub async fn write_frame(stream: &mut UnixStream, payload: &[u8]) -> std::io::Result<()> {
99 if payload.len() > MAX_FRAME_BYTES {
100 return Err(std::io::Error::new(
101 std::io::ErrorKind::InvalidData,
102 format!(
103 "daemon frame of {} bytes exceeds {MAX_FRAME_BYTES} cap",
104 payload.len()
105 ),
106 ));
107 }
108 let len = (payload.len() as u32).to_be_bytes();
109 stream.write_all(&len).await?;
110 stream.write_all(payload).await?;
111 stream.flush().await?;
112 Ok(())
113}
114
/// Backend that executes requests on behalf of the daemon.
///
/// `run_daemon` clones the implementor once per accepted connection and once
/// for the background warm-up task, and moves those clones into spawned tasks.
#[async_trait]
pub trait DaemonDispatch: Clone + Send + Sync + 'static {
    /// Executes one request.
    ///
    /// The arguments are the `ops`, `presentation` and `presentation_per_op`
    /// fields of the incoming [`DaemonRequestFrame`], passed through
    /// unchanged. `Ok` becomes the response's `result`; `Err` becomes its
    /// `error`.
    async fn dispatch(
        &self,
        ops: String,
        presentation: Option<String>,
        presentation_per_op: Option<Vec<Option<String>>>,
    ) -> Result<String, String>;

    /// Called once, in a background task, right after the daemon starts
    /// listening.
    // NOTE(review): presumably pre-loads whatever `dispatch` needs — confirm
    // against the implementors.
    async fn warm_all(&self);

    /// Namespace this dispatcher serves. A request carrying a different
    /// namespace is answered with `namespace_mismatch` and is not dispatched.
    fn namespace(&self) -> &str;
}
137
138async fn handle_conn<D: DaemonDispatch>(mut stream: UnixStream, dispatcher: D) {
141 let raw = match read_frame(&mut stream).await {
142 Ok(r) => r,
143 Err(e) => {
144 tracing::debug!(error = %e, "failed to read daemon request frame");
145 return;
146 }
147 };
148 let frame: DaemonRequestFrame = match serde_json::from_slice(&raw) {
149 Ok(f) => f,
150 Err(e) => {
151 tracing::debug!(error = %e, "failed to decode daemon request frame");
152 return;
153 }
154 };
155
156 let resp = if frame.namespace != dispatcher.namespace() {
157 DaemonResponseFrame {
158 ok: false,
159 result: None,
160 error: None,
161 namespace_mismatch: true,
162 }
163 } else {
164 match dispatcher
165 .dispatch(frame.ops, frame.presentation, frame.presentation_per_op)
166 .await
167 {
168 Ok(result) => DaemonResponseFrame {
169 ok: true,
170 result: Some(result),
171 error: None,
172 namespace_mismatch: false,
173 },
174 Err(e) => DaemonResponseFrame {
175 ok: false,
176 result: None,
177 error: Some(e),
178 namespace_mismatch: false,
179 },
180 }
181 };
182
183 match serde_json::to_vec(&resp) {
184 Ok(payload) => {
185 if let Err(e) = write_frame(&mut stream, &payload).await {
186 tracing::debug!(error = %e, "failed to write daemon response frame");
187 }
188 }
189 Err(e) => tracing::warn!(error = %e, "failed to serialize daemon response frame"),
190 }
191}
192
193pub async fn run_daemon<D: DaemonDispatch>(dispatcher: D) -> anyhow::Result<()> {
196 let sock = socket_path();
197 let pid_file = pid_path();
198
199 if let Some(parent) = sock.parent() {
200 std::fs::create_dir_all(parent)?;
201 #[cfg(unix)]
202 if let Err(e) = std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700)) {
203 tracing::warn!(error = %e, path = ?parent, "failed to chmod 0700 khive dir");
204 }
205 }
206
207 if !cleanup_stale_daemon(&sock, &pid_file).await {
208 tracing::info!("a responsive khived is already running; exiting");
209 return Ok(());
210 }
211
212 let listener = UnixListener::bind(&sock)?;
213 #[cfg(unix)]
214 if let Err(e) = std::fs::set_permissions(&sock, std::fs::Permissions::from_mode(0o600)) {
215 tracing::warn!(error = %e, path = ?sock, "failed to chmod 0600 socket");
216 }
217
218 write_pid_file(&pid_file)?;
219 tracing::info!(socket = ?sock, pid = std::process::id(), "khived listening");
220
221 {
222 let warm = dispatcher.clone();
223 tokio::spawn(async move {
224 warm.warm_all().await;
225 });
226 }
227
228 let active = Arc::new(std::sync::atomic::AtomicUsize::new(0));
229
230 let shutdown = async {
231 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
232 .expect("install SIGTERM handler");
233 let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
234 .expect("install SIGINT handler");
235 tokio::select! {
236 _ = sigterm.recv() => tracing::info!("received SIGTERM"),
237 _ = sigint.recv() => tracing::info!("received SIGINT"),
238 }
239 };
240
241 tokio::select! {
242 _ = async {
243 loop {
244 match listener.accept().await {
245 Ok((stream, _)) => {
246 let d = dispatcher.clone();
247 let active = Arc::clone(&active);
248 tokio::spawn(async move {
249 active.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
250 handle_conn(stream, d).await;
251 active.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
252 });
253 }
254 Err(e) => tracing::error!(error = %e, "accept failed"),
255 }
256 }
257 } => {}
258 _ = shutdown => {}
259 }
260
261 drain(&active).await;
262
263 let _ = std::fs::remove_file(&sock);
264 let _ = std::fs::remove_file(&pid_file);
265 tracing::info!("khived stopped");
266 Ok(())
267}
268
269fn is_process_running(pid: u32) -> bool {
272 let Ok(pid) = i32::try_from(pid) else {
273 return false;
274 };
275 if pid <= 0 {
276 return false;
277 }
278 unsafe { libc::kill(pid, 0) == 0 }
280}
281
282async fn cleanup_stale_daemon(sock: &std::path::Path, pid_file: &std::path::Path) -> bool {
283 if let Ok(pid_str) = std::fs::read_to_string(pid_file) {
284 if let Ok(pid) = pid_str.trim().parse::<u32>() {
285 if pid != std::process::id()
286 && is_process_running(pid)
287 && sock.exists()
288 && UnixStream::connect(sock).await.is_ok()
289 {
290 return false;
291 }
292 }
293 }
294 if sock.exists() {
295 if let Err(e) = std::fs::remove_file(sock) {
296 tracing::warn!(error = %e, path = ?sock, "failed to remove stale socket");
297 }
298 }
299 if pid_file.exists() {
300 if let Err(e) = std::fs::remove_file(pid_file) {
301 tracing::warn!(error = %e, path = ?pid_file, "failed to remove stale PID file");
302 }
303 }
304 true
305}
306
/// Writes this process's PID, as decimal text with no trailing newline, to
/// `pid_file`.
///
/// The file is created if missing and truncated if present. On Unix a newly
/// created file is requested with mode `0600`.
///
/// # Errors
///
/// Propagates any error from opening or writing the file.
fn write_pid_file(pid_file: &std::path::Path) -> std::io::Result<()> {
    let mut options = std::fs::OpenOptions::new();
    options.create(true).truncate(true).write(true);
    #[cfg(unix)]
    std::os::unix::fs::OpenOptionsExt::mode(&mut options, 0o600);

    let pid_text = std::process::id().to_string();
    options.open(pid_file)?.write_all(pid_text.as_bytes())
}
319
320async fn drain(active: &std::sync::atomic::AtomicUsize) {
321 use std::sync::atomic::Ordering;
322 if active.load(Ordering::Relaxed) == 0 {
323 return;
324 }
325 let deadline = tokio::time::Instant::now() + drain_timeout();
326 while active.load(Ordering::Relaxed) > 0 {
327 if tokio::time::Instant::now() >= deadline {
328 tracing::warn!(
329 remaining = active.load(Ordering::Relaxed),
330 "drain timeout reached; forcing shutdown"
331 );
332 break;
333 }
334 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
335 }
336}
337
338fn drain_timeout() -> std::time::Duration {
339 let secs = std::env::var("KHIVE_DRAIN_TIMEOUT_SECS")
340 .ok()
341 .and_then(|v| v.parse::<u64>().ok())
342 .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS);
343 std::time::Duration::from_secs(secs)
344}
345
/// Reports whether the environment variable `key` holds a "truthy" value.
///
/// The value is trimmed first. An unset (or unreadable) variable, an empty
/// value, `0`, and `false` in any letter case are falsy; anything else is
/// truthy.
pub fn env_truthy(key: &str) -> bool {
    let Ok(raw) = std::env::var(key) else {
        return false;
    };
    let value = raw.trim();
    !(value.is_empty() || value == "0" || value.eq_ignore_ascii_case("false"))
}