Skip to main content

khive_runtime/
daemon.rs

//! khived daemon server — persistent warm runtime over a Unix socket.
//!
//! The daemon binds `~/.khive/khived.sock`, accepts length-prefixed request
//! frames, dispatches them through a [`DaemonDispatch`] implementor, and serves
//! results back. It is transport-agnostic: the MCP crate provides the dispatch
//! impl, but any future client (CLI, HTTP gateway) can reuse this server.
//!
//! The client side (forwarding, auto-spawn) lives in the transport crate
//! (e.g. `khive-mcp`), not here.
10
11use std::io::Write as _;
12use std::path::PathBuf;
13use std::sync::Arc;
14
15#[cfg(unix)]
16use std::os::unix::fs::PermissionsExt;
17
18use async_trait::async_trait;
19use serde::{Deserialize, Serialize};
20use tokio::io::{AsyncReadExt, AsyncWriteExt};
21use tokio::net::{UnixListener, UnixStream};
22
/// Upper bound, in bytes, on a single frame payload — applied to frames the
/// daemon reads as well as frames it writes (8 MiB).
pub const MAX_FRAME_BYTES: usize = 8 << 20;

/// Seconds to wait for in-flight connections at shutdown when the
/// `KHIVE_DRAIN_TIMEOUT_SECS` env var is unset or does not parse as a `u64`.
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 10;
27
28// ── paths ─────────────────────────────────────────────────────────────────────
29
/// Path taken from the environment variable `key`, if it is set to a
/// non-empty value.
///
/// The raw OS string is read (rather than a UTF-8 `String`), so an override
/// that is a valid path but not valid UTF-8 is honoured instead of being
/// silently ignored.
fn env_path_override(key: &str) -> Option<PathBuf> {
    std::env::var_os(key)
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
}

/// Directory that holds the daemon's default socket and PID file:
/// `$HOME/.khive`.
///
/// Falls back to `./.khive` when `HOME` is not set.
fn khive_dir() -> PathBuf {
    std::env::var_os("HOME")
        .map_or_else(|| PathBuf::from("."), PathBuf::from)
        .join(".khive")
}

/// Unix socket path the daemon binds and clients connect to.
///
/// Overridable via the `KHIVE_SOCKET` env var (for tests and ops); an empty
/// value is treated as unset. Defaults to `khived.sock` inside the khive
/// directory.
pub fn socket_path() -> PathBuf {
    env_path_override("KHIVE_SOCKET").unwrap_or_else(|| khive_dir().join("khived.sock"))
}

/// PID file path written by the daemon.
///
/// Overridable via the `KHIVE_PID` env var; an empty value is treated as
/// unset. Defaults to `khived.pid` inside the khive directory.
pub fn pid_path() -> PathBuf {
    env_path_override("KHIVE_PID").unwrap_or_else(|| khive_dir().join("khived.pid"))
}
58
59// ── wire types ────────────────────────────────────────────────────────────────
60
61/// Request frame sent from a client to the daemon.
62#[derive(Serialize, Deserialize)]
63pub struct DaemonRequestFrame {
64    pub ops: String,
65    pub presentation: Option<String>,
66    pub presentation_per_op: Option<Vec<Option<String>>>,
67    pub namespace: String,
68    /// Fingerprint of the client's resolved runtime config (packs, db target,
69    /// embedders). The daemon rejects a request whose `config_id` differs from
70    /// its own so a restricted client (e.g. `--pack kg`, `--db :memory:`) never
71    /// dispatches through the broader default daemon. See ADR-027 / ADR-049.
72    #[serde(default)]
73    pub config_id: String,
74}
75
76/// Response frame sent from the daemon back to a client.
77#[derive(Serialize, Deserialize)]
78pub struct DaemonResponseFrame {
79    pub ok: bool,
80    pub result: Option<String>,
81    pub error: Option<String>,
82    pub namespace_mismatch: bool,
83    /// Set when the request's `config_id` does not match the daemon's. Like
84    /// `namespace_mismatch`, this signals the client to fall back to local
85    /// dispatch rather than execute under a different runtime/config.
86    #[serde(default)]
87    pub config_mismatch: bool,
88    /// The `config_id` the daemon dispatched under, echoed back so the client
89    /// can positively confirm the result came from a matching runtime. A
90    /// pre-`config_id` daemon omits this field (deserializes to `None`), which
91    /// the client treats as a mismatch and falls back to local dispatch — this
92    /// closes the upgrade window where a new restricted client could otherwise
93    /// trust a still-warm legacy daemon's broader registry.
94    #[serde(default)]
95    pub served_config_id: Option<String>,
96}
97
98// ── framing ───────────────────────────────────────────────────────────────────
99
100/// Read one length-prefixed frame (4-byte BE u32 length + JSON bytes).
101pub async fn read_frame(stream: &mut UnixStream) -> std::io::Result<Vec<u8>> {
102    let mut len_buf = [0u8; 4];
103    stream.read_exact(&mut len_buf).await?;
104    let len = u32::from_be_bytes(len_buf) as usize;
105    if len > MAX_FRAME_BYTES {
106        return Err(std::io::Error::new(
107            std::io::ErrorKind::InvalidData,
108            format!("daemon frame of {len} bytes exceeds {MAX_FRAME_BYTES} cap"),
109        ));
110    }
111    let mut buf = vec![0u8; len];
112    stream.read_exact(&mut buf).await?;
113    Ok(buf)
114}
115
116/// Write one length-prefixed frame.
117pub async fn write_frame(stream: &mut UnixStream, payload: &[u8]) -> std::io::Result<()> {
118    if payload.len() > MAX_FRAME_BYTES {
119        return Err(std::io::Error::new(
120            std::io::ErrorKind::InvalidData,
121            format!(
122                "daemon frame of {} bytes exceeds {MAX_FRAME_BYTES} cap",
123                payload.len()
124            ),
125        ));
126    }
127    let len = (payload.len() as u32).to_be_bytes();
128    stream.write_all(&len).await?;
129    stream.write_all(payload).await?;
130    stream.flush().await?;
131    Ok(())
132}
133
134// ── dispatch trait ────────────────────────────────────────────────────────────
135
136/// Transport-agnostic dispatch interface for the daemon server.
137///
138/// The MCP crate implements this by wrapping `dispatch_request_local`; any
139/// future transport can do the same.
140#[async_trait]
141pub trait DaemonDispatch: Clone + Send + Sync + 'static {
142    /// Dispatch a verb-DSL request string and return the JSON result.
143    async fn dispatch(
144        &self,
145        ops: String,
146        presentation: Option<String>,
147        presentation_per_op: Option<Vec<Option<String>>>,
148    ) -> Result<String, String>;
149
150    /// Warm every pack's in-memory state (ANN indexes, etc.).
151    async fn warm_all(&self);
152
153    /// The namespace this dispatcher was configured for.
154    fn namespace(&self) -> &str;
155
156    /// Fingerprint of this dispatcher's resolved runtime config (packs, db
157    /// target, embedders). Used to reject forwarded requests from clients whose
158    /// config differs, so a restricted client cannot dispatch through a broader
159    /// daemon.
160    fn config_id(&self) -> &str;
161}
162
163// ── server ────────────────────────────────────────────────────────────────────
164
165async fn handle_conn<D: DaemonDispatch>(mut stream: UnixStream, dispatcher: D) {
166    let raw = match read_frame(&mut stream).await {
167        Ok(r) => r,
168        Err(e) => {
169            tracing::debug!(error = %e, "failed to read daemon request frame");
170            return;
171        }
172    };
173    let frame: DaemonRequestFrame = match serde_json::from_slice(&raw) {
174        Ok(f) => f,
175        Err(e) => {
176            tracing::debug!(error = %e, "failed to decode daemon request frame");
177            return;
178        }
179    };
180
181    let served_config_id = Some(dispatcher.config_id().to_string());
182    let resp = if frame.namespace != dispatcher.namespace() {
183        DaemonResponseFrame {
184            ok: false,
185            result: None,
186            error: None,
187            namespace_mismatch: true,
188            config_mismatch: false,
189            served_config_id,
190        }
191    } else if frame.config_id != dispatcher.config_id() {
192        DaemonResponseFrame {
193            ok: false,
194            result: None,
195            error: None,
196            namespace_mismatch: false,
197            config_mismatch: true,
198            served_config_id,
199        }
200    } else {
201        match dispatcher
202            .dispatch(frame.ops, frame.presentation, frame.presentation_per_op)
203            .await
204        {
205            Ok(result) => DaemonResponseFrame {
206                ok: true,
207                result: Some(result),
208                error: None,
209                namespace_mismatch: false,
210                config_mismatch: false,
211                served_config_id,
212            },
213            Err(e) => DaemonResponseFrame {
214                ok: false,
215                result: None,
216                error: Some(e),
217                namespace_mismatch: false,
218                config_mismatch: false,
219                served_config_id,
220            },
221        }
222    };
223
224    match serde_json::to_vec(&resp) {
225        Ok(payload) => {
226            if let Err(e) = write_frame(&mut stream, &payload).await {
227                tracing::debug!(error = %e, "failed to write daemon response frame");
228            }
229        }
230        Err(e) => tracing::warn!(error = %e, "failed to serialize daemon response frame"),
231    }
232}
233
234/// Run the daemon: bind the socket, warm in the background, serve request
235/// frames until SIGTERM/SIGINT.
236pub async fn run_daemon<D: DaemonDispatch>(dispatcher: D) -> anyhow::Result<()> {
237    let sock = socket_path();
238    let pid_file = pid_path();
239
240    if let Some(parent) = sock.parent() {
241        std::fs::create_dir_all(parent)?;
242        #[cfg(unix)]
243        if let Err(e) = std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700)) {
244            tracing::warn!(error = %e, path = ?parent, "failed to chmod 0700 khive dir");
245        }
246    }
247
248    if !cleanup_stale_daemon(&sock, &pid_file).await {
249        tracing::info!("a responsive khived is already running; exiting");
250        return Ok(());
251    }
252
253    let listener = UnixListener::bind(&sock)?;
254    #[cfg(unix)]
255    if let Err(e) = std::fs::set_permissions(&sock, std::fs::Permissions::from_mode(0o600)) {
256        tracing::warn!(error = %e, path = ?sock, "failed to chmod 0600 socket");
257    }
258
259    write_pid_file(&pid_file)?;
260    tracing::info!(socket = ?sock, pid = std::process::id(), "khived listening");
261
262    {
263        let warm = dispatcher.clone();
264        tokio::spawn(async move {
265            warm.warm_all().await;
266        });
267    }
268
269    let active = Arc::new(std::sync::atomic::AtomicUsize::new(0));
270
271    let shutdown = async {
272        // REASON: signal handler registration can only fail if the global Tokio runtime
273        // is not running or the OS rejects the signal number — both are unrecoverable
274        // at this point in startup, so panic is the correct response.
275        let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
276            .expect("install SIGTERM handler");
277        let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
278            .expect("install SIGINT handler");
279        tokio::select! {
280            _ = sigterm.recv() => tracing::info!("received SIGTERM"),
281            _ = sigint.recv() => tracing::info!("received SIGINT"),
282        }
283    };
284
285    tokio::select! {
286        _ = async {
287            loop {
288                match listener.accept().await {
289                    Ok((stream, _)) => {
290                        let d = dispatcher.clone();
291                        let active = Arc::clone(&active);
292                        tokio::spawn(async move {
293                            active.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
294                            handle_conn(stream, d).await;
295                            active.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
296                        });
297                    }
298                    Err(e) => tracing::error!(error = %e, "accept failed"),
299                }
300            }
301        } => {}
302        _ = shutdown => {}
303    }
304
305    drain(&active).await;
306
307    let _ = std::fs::remove_file(&sock);
308    let _ = std::fs::remove_file(&pid_file);
309    tracing::info!("khived stopped");
310    Ok(())
311}
312
313// ── helpers ───────────────────────────────────────────────────────────────────
314
315fn is_process_running(pid: u32) -> bool {
316    let Ok(pid) = i32::try_from(pid) else {
317        return false;
318    };
319    if pid <= 0 {
320        return false;
321    }
322    // SAFETY: signal 0 is an existence/permission probe with no side effects.
323    unsafe { libc::kill(pid, 0) == 0 }
324}
325
326async fn cleanup_stale_daemon(sock: &std::path::Path, pid_file: &std::path::Path) -> bool {
327    if let Ok(pid_str) = std::fs::read_to_string(pid_file) {
328        if let Ok(pid) = pid_str.trim().parse::<u32>() {
329            if pid != std::process::id()
330                && is_process_running(pid)
331                && sock.exists()
332                && UnixStream::connect(sock).await.is_ok()
333            {
334                return false;
335            }
336        }
337    }
338    if sock.exists() {
339        if let Err(e) = std::fs::remove_file(sock) {
340            tracing::warn!(error = %e, path = ?sock, "failed to remove stale socket");
341        }
342    }
343    if pid_file.exists() {
344        if let Err(e) = std::fs::remove_file(pid_file) {
345            tracing::warn!(error = %e, path = ?pid_file, "failed to remove stale PID file");
346        }
347    }
348    true
349}
350
/// Write the current process ID, as decimal text, to `pid_file`.
///
/// The file is created if missing and truncated if present. On Unix a newly
/// created file is requested with mode `0o600`.
///
/// # Errors
///
/// Propagates any error from opening or writing the file.
fn write_pid_file(pid_file: &std::path::Path) -> std::io::Result<()> {
    let mut options = std::fs::OpenOptions::new();
    options.write(true).create(true).truncate(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        options.mode(0o600);
    }

    let pid_text = std::process::id().to_string();
    let mut file = options.open(pid_file)?;
    file.write_all(pid_text.as_bytes())
}
363
364async fn drain(active: &std::sync::atomic::AtomicUsize) {
365    use std::sync::atomic::Ordering;
366    if active.load(Ordering::Relaxed) == 0 {
367        return;
368    }
369    let deadline = tokio::time::Instant::now() + drain_timeout();
370    while active.load(Ordering::Relaxed) > 0 {
371        if tokio::time::Instant::now() >= deadline {
372            tracing::warn!(
373                remaining = active.load(Ordering::Relaxed),
374                "drain timeout reached; forcing shutdown"
375            );
376            break;
377        }
378        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
379    }
380}
381
382fn drain_timeout() -> std::time::Duration {
383    let secs = std::env::var("KHIVE_DRAIN_TIMEOUT_SECS")
384        .ok()
385        .and_then(|v| v.parse::<u64>().ok())
386        .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS);
387    std::time::Duration::from_secs(secs)
388}
389
/// Whether the env var `key` holds a "truthy" value.
///
/// An unset (or unreadable) variable is `false`. Otherwise the value is
/// trimmed and counts as `true` unless it is empty, `"0"`, or `"false"`
/// (compared ASCII case-insensitively).
pub fn env_truthy(key: &str) -> bool {
    let Ok(raw) = std::env::var(key) else {
        return false;
    };
    let value = raw.trim();
    !(value.is_empty() || value == "0" || value.eq_ignore_ascii_case("false"))
}
399
#[cfg(test)]
mod tests {
    use super::*;

    // Regression tests for `is_process_running`: the guards in front of the
    // unsafe `kill(pid, 0)` probe must reject PIDs that are not a positive
    // `i32`, and the probe itself must recognise a live process.

    #[test]
    fn current_process_is_running() {
        // Our own PID necessarily refers to a live process.
        let pid = std::process::id();
        let running = is_process_running(pid);
        assert!(
            running,
            "current process {pid} should be detected as running"
        );
    }

    #[test]
    fn pid_zero_is_not_running() {
        // kill(0, 0) would address the caller's whole process group, so the
        // `pid <= 0` guard has to turn PID 0 away before the call is made.
        let running = is_process_running(0);
        assert!(
            !running,
            "pid 0 must be rejected by the guard before the unsafe call"
        );
    }

    #[test]
    fn very_large_pid_is_not_running() {
        // u32::MAX does not fit in an i32, so the conversion guard fires.
        let running = is_process_running(u32::MAX);
        assert!(
            !running,
            "u32::MAX should fail i32 conversion and return false"
        );
    }

    #[test]
    fn env_truthy_recognises_set_values() {
        assert!(!env_truthy("__KHIVE_TEST_ABSENT_VAR_XYZ__"));

        // Mutating the process environment is not parallel-safe without
        // serial_test; the variable name is unique to this test and the
        // variable is removed again at the end to avoid cross-test pollution.
        let key = "__KHIVE_TEST_TRUTHY_ABC__";
        for (value, expected) in [("1", true), ("false", false), ("0", false)] {
            std::env::set_var(key, value);
            assert_eq!(env_truthy(key), expected);
        }
        std::env::remove_var(key);
    }
}