1use std::io::Write as _;
12use std::path::PathBuf;
13use std::sync::Arc;
14
15#[cfg(unix)]
16use std::os::unix::fs::PermissionsExt;
17
18use async_trait::async_trait;
19use serde::{Deserialize, Serialize};
20use tokio::io::{AsyncReadExt, AsyncWriteExt};
21use tokio::net::{UnixListener, UnixStream};
22
/// Largest frame payload, in bytes, that `read_frame` / `write_frame` accept (8 MiB).
pub const MAX_FRAME_BYTES: usize = 8 * 1024 * 1024;

// The frame header encodes the payload length as a big-endian `u32`, so the cap must
// fit in one; raising it past `u32::MAX` would otherwise silently corrupt headers.
const _: () = assert!(MAX_FRAME_BYTES <= u32::MAX as usize);

/// Default number of seconds shutdown waits for in-flight connections to finish
/// (overridable via `KHIVE_DRAIN_TIMEOUT_SECS`).
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 10;
27
/// Returns the per-user khive state directory, `$HOME/.khive`.
///
/// When `HOME` is unset or not valid Unicode the directory is resolved
/// relative to the current working directory (`./.khive`).
fn khive_dir() -> PathBuf {
    let mut dir = match std::env::var("HOME") {
        Ok(home) => PathBuf::from(home),
        Err(_) => PathBuf::from("."),
    };
    dir.push(".khive");
    dir
}
34
35pub fn socket_path() -> PathBuf {
39 if let Ok(p) = std::env::var("KHIVE_SOCKET") {
40 if !p.is_empty() {
41 return PathBuf::from(p);
42 }
43 }
44 khive_dir().join("khived.sock")
45}
46
47pub fn pid_path() -> PathBuf {
51 if let Ok(p) = std::env::var("KHIVE_PID") {
52 if !p.is_empty() {
53 return PathBuf::from(p);
54 }
55 }
56 khive_dir().join("khived.pid")
57}
58
59#[derive(Serialize, Deserialize)]
63pub struct DaemonRequestFrame {
64 pub ops: String,
65 pub presentation: Option<String>,
66 pub presentation_per_op: Option<Vec<Option<String>>>,
67 pub namespace: String,
68 #[serde(default)]
73 pub config_id: String,
74}
75
76#[derive(Serialize, Deserialize)]
78pub struct DaemonResponseFrame {
79 pub ok: bool,
80 pub result: Option<String>,
81 pub error: Option<String>,
82 pub namespace_mismatch: bool,
83 #[serde(default)]
87 pub config_mismatch: bool,
88 #[serde(default)]
95 pub served_config_id: Option<String>,
96}
97
98pub async fn read_frame(stream: &mut UnixStream) -> std::io::Result<Vec<u8>> {
102 let mut len_buf = [0u8; 4];
103 stream.read_exact(&mut len_buf).await?;
104 let len = u32::from_be_bytes(len_buf) as usize;
105 if len > MAX_FRAME_BYTES {
106 return Err(std::io::Error::new(
107 std::io::ErrorKind::InvalidData,
108 format!("daemon frame of {len} bytes exceeds {MAX_FRAME_BYTES} cap"),
109 ));
110 }
111 let mut buf = vec![0u8; len];
112 stream.read_exact(&mut buf).await?;
113 Ok(buf)
114}
115
116pub async fn write_frame(stream: &mut UnixStream, payload: &[u8]) -> std::io::Result<()> {
118 if payload.len() > MAX_FRAME_BYTES {
119 return Err(std::io::Error::new(
120 std::io::ErrorKind::InvalidData,
121 format!(
122 "daemon frame of {} bytes exceeds {MAX_FRAME_BYTES} cap",
123 payload.len()
124 ),
125 ));
126 }
127 let len = (payload.len() as u32).to_be_bytes();
128 stream.write_all(&len).await?;
129 stream.write_all(payload).await?;
130 stream.flush().await?;
131 Ok(())
132}
133
/// Backend that serves requests arriving on the daemon socket.
///
/// `run_daemon` clones the dispatcher once per accepted connection and moves
/// each clone into a spawned task. That is why implementors must be
/// `Clone + Send + Sync + 'static`.
#[async_trait]
pub trait DaemonDispatch: Clone + Send + Sync + 'static {
    /// Executes a request's `ops` with its presentation options.
    ///
    /// This is only called after the request's namespace and config id both
    /// match this dispatcher's. `Ok(output)` is returned to the client as a
    /// successful response's `result`. `Err(message)` becomes the response's
    /// `error`, with `ok = false`.
    async fn dispatch(
        &self,
        ops: String,
        presentation: Option<String>,
        presentation_per_op: Option<Vec<Option<String>>>,
    ) -> Result<String, String>;

    /// Background warm-up, started once on its own task right after the daemon
    /// begins listening.
    ///
    /// The accept loop does not wait for it, so requests may be dispatched
    /// before it completes.
    async fn warm_all(&self);

    /// Namespace this daemon serves.
    ///
    /// Requests naming a different namespace are answered with
    /// `namespace_mismatch`.
    fn namespace(&self) -> &str;

    /// Identity of the configuration this daemon serves.
    ///
    /// Requests carrying a different `config_id` are answered with
    /// `config_mismatch`. The value is also echoed as `served_config_id` in
    /// every response.
    fn config_id(&self) -> &str;
}
162
163async fn handle_conn<D: DaemonDispatch>(mut stream: UnixStream, dispatcher: D) {
166 let raw = match read_frame(&mut stream).await {
167 Ok(r) => r,
168 Err(e) => {
169 tracing::debug!(error = %e, "failed to read daemon request frame");
170 return;
171 }
172 };
173 let frame: DaemonRequestFrame = match serde_json::from_slice(&raw) {
174 Ok(f) => f,
175 Err(e) => {
176 tracing::debug!(error = %e, "failed to decode daemon request frame");
177 return;
178 }
179 };
180
181 let served_config_id = Some(dispatcher.config_id().to_string());
182 let resp = if frame.namespace != dispatcher.namespace() {
183 DaemonResponseFrame {
184 ok: false,
185 result: None,
186 error: None,
187 namespace_mismatch: true,
188 config_mismatch: false,
189 served_config_id,
190 }
191 } else if frame.config_id != dispatcher.config_id() {
192 DaemonResponseFrame {
193 ok: false,
194 result: None,
195 error: None,
196 namespace_mismatch: false,
197 config_mismatch: true,
198 served_config_id,
199 }
200 } else {
201 match dispatcher
202 .dispatch(frame.ops, frame.presentation, frame.presentation_per_op)
203 .await
204 {
205 Ok(result) => DaemonResponseFrame {
206 ok: true,
207 result: Some(result),
208 error: None,
209 namespace_mismatch: false,
210 config_mismatch: false,
211 served_config_id,
212 },
213 Err(e) => DaemonResponseFrame {
214 ok: false,
215 result: None,
216 error: Some(e),
217 namespace_mismatch: false,
218 config_mismatch: false,
219 served_config_id,
220 },
221 }
222 };
223
224 match serde_json::to_vec(&resp) {
225 Ok(payload) => {
226 if let Err(e) = write_frame(&mut stream, &payload).await {
227 tracing::debug!(error = %e, "failed to write daemon response frame");
228 }
229 }
230 Err(e) => tracing::warn!(error = %e, "failed to serialize daemon response frame"),
231 }
232}
233
234pub async fn run_daemon<D: DaemonDispatch>(dispatcher: D) -> anyhow::Result<()> {
237 let sock = socket_path();
238 let pid_file = pid_path();
239
240 if let Some(parent) = sock.parent() {
241 std::fs::create_dir_all(parent)?;
242 #[cfg(unix)]
243 if let Err(e) = std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700)) {
244 tracing::warn!(error = %e, path = ?parent, "failed to chmod 0700 khive dir");
245 }
246 }
247
248 if !cleanup_stale_daemon(&sock, &pid_file).await {
249 tracing::info!("a responsive khived is already running; exiting");
250 return Ok(());
251 }
252
253 let listener = UnixListener::bind(&sock)?;
254 #[cfg(unix)]
255 if let Err(e) = std::fs::set_permissions(&sock, std::fs::Permissions::from_mode(0o600)) {
256 tracing::warn!(error = %e, path = ?sock, "failed to chmod 0600 socket");
257 }
258
259 write_pid_file(&pid_file)?;
260 tracing::info!(socket = ?sock, pid = std::process::id(), "khived listening");
261
262 {
263 let warm = dispatcher.clone();
264 tokio::spawn(async move {
265 warm.warm_all().await;
266 });
267 }
268
269 let active = Arc::new(std::sync::atomic::AtomicUsize::new(0));
270
271 let shutdown = async {
272 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
276 .expect("install SIGTERM handler");
277 let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
278 .expect("install SIGINT handler");
279 tokio::select! {
280 _ = sigterm.recv() => tracing::info!("received SIGTERM"),
281 _ = sigint.recv() => tracing::info!("received SIGINT"),
282 }
283 };
284
285 tokio::select! {
286 _ = async {
287 loop {
288 match listener.accept().await {
289 Ok((stream, _)) => {
290 let d = dispatcher.clone();
291 let active = Arc::clone(&active);
292 tokio::spawn(async move {
293 active.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
294 handle_conn(stream, d).await;
295 active.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
296 });
297 }
298 Err(e) => tracing::error!(error = %e, "accept failed"),
299 }
300 }
301 } => {}
302 _ = shutdown => {}
303 }
304
305 drain(&active).await;
306
307 let _ = std::fs::remove_file(&sock);
308 let _ = std::fs::remove_file(&pid_file);
309 tracing::info!("khived stopped");
310 Ok(())
311}
312
313fn is_process_running(pid: u32) -> bool {
316 let Ok(pid) = i32::try_from(pid) else {
317 return false;
318 };
319 if pid <= 0 {
320 return false;
321 }
322 unsafe { libc::kill(pid, 0) == 0 }
324}
325
326async fn cleanup_stale_daemon(sock: &std::path::Path, pid_file: &std::path::Path) -> bool {
327 if let Ok(pid_str) = std::fs::read_to_string(pid_file) {
328 if let Ok(pid) = pid_str.trim().parse::<u32>() {
329 if pid != std::process::id()
330 && is_process_running(pid)
331 && sock.exists()
332 && UnixStream::connect(sock).await.is_ok()
333 {
334 return false;
335 }
336 }
337 }
338 if sock.exists() {
339 if let Err(e) = std::fs::remove_file(sock) {
340 tracing::warn!(error = %e, path = ?sock, "failed to remove stale socket");
341 }
342 }
343 if pid_file.exists() {
344 if let Err(e) = std::fs::remove_file(pid_file) {
345 tracing::warn!(error = %e, path = ?pid_file, "failed to remove stale PID file");
346 }
347 }
348 true
349}
350
/// Records the current process ID in `pid_file`.
///
/// The ID is written as decimal digits with no trailing newline. The file is
/// created if missing and truncated if present. On Unix a newly created file
/// is requested with mode `0o600`, subject to the process umask.
///
/// # Errors
///
/// Propagates errors from opening or writing the file.
fn write_pid_file(pid_file: &std::path::Path) -> std::io::Result<()> {
    let pid_text = std::process::id().to_string();
    let mut options = std::fs::OpenOptions::new();
    options.create(true).truncate(true).write(true);
    #[cfg(unix)]
    std::os::unix::fs::OpenOptionsExt::mode(&mut options, 0o600);
    options.open(pid_file)?.write_all(pid_text.as_bytes())
}
363
364async fn drain(active: &std::sync::atomic::AtomicUsize) {
365 use std::sync::atomic::Ordering;
366 if active.load(Ordering::Relaxed) == 0 {
367 return;
368 }
369 let deadline = tokio::time::Instant::now() + drain_timeout();
370 while active.load(Ordering::Relaxed) > 0 {
371 if tokio::time::Instant::now() >= deadline {
372 tracing::warn!(
373 remaining = active.load(Ordering::Relaxed),
374 "drain timeout reached; forcing shutdown"
375 );
376 break;
377 }
378 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
379 }
380}
381
382fn drain_timeout() -> std::time::Duration {
383 let secs = std::env::var("KHIVE_DRAIN_TIMEOUT_SECS")
384 .ok()
385 .and_then(|v| v.parse::<u64>().ok())
386 .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS);
387 std::time::Duration::from_secs(secs)
388}
389
/// Interprets the environment variable `key` as a boolean flag.
///
/// - Unset or non-Unicode variables are `false`.
/// - After trimming surrounding whitespace, the empty string, `"0"`, and any
///   ASCII-case variant of `"false"` are `false`.
/// - Every other value is `true`.
pub fn env_truthy(key: &str) -> bool {
    let Ok(raw) = std::env::var(key) else {
        return false;
    };
    match raw.trim() {
        "" | "0" => false,
        other => !other.eq_ignore_ascii_case("false"),
    }
}
399
#[cfg(test)]
mod tests {
    use super::*;

    // The test binary's own PID must always be reported as alive.
    #[test]
    fn current_process_is_running() {
        let pid = std::process::id();
        assert!(
            is_process_running(pid),
            "current process {pid} should be detected as running"
        );
    }

    // PID 0 must be rejected before `kill` is ever called.
    #[test]
    fn pid_zero_is_not_running() {
        assert!(
            !is_process_running(0),
            "pid 0 must be rejected by the guard before the unsafe call"
        );
    }

    // PIDs above i32::MAX cannot be represented as a signed pid and must be
    // reported as not running.
    #[test]
    fn very_large_pid_is_not_running() {
        assert!(
            !is_process_running(u32::MAX),
            "u32::MAX should fail i32 conversion and return false"
        );
    }

    // Covers the unset case plus "1" (truthy) and "false" / "0" (falsy).
    // Environment variables are process-global and tests run in parallel, so
    // the test uses keys unlikely to collide with anything else.
    #[test]
    fn env_truthy_recognises_set_values() {
        assert!(!env_truthy("__KHIVE_TEST_ABSENT_VAR_XYZ__"));

        let key = "__KHIVE_TEST_TRUTHY_ABC__";
        std::env::set_var(key, "1");
        assert!(env_truthy(key));
        std::env::set_var(key, "false");
        assert!(!env_truthy(key));
        std::env::set_var(key, "0");
        assert!(!env_truthy(key));
        std::env::remove_var(key);
    }
}