Skip to main content

khive_runtime/
daemon.rs

//! khived daemon server — persistent warm runtime over a Unix socket (ADR-049).
//!
//! By default the daemon binds `~/.khive/khived.sock` (overridable via the
//! `KHIVE_SOCKET` env var), accepts length-prefixed request frames, dispatches
//! them through a [`DaemonDispatch`] implementor, and serves the results back.
//! It is transport-agnostic: the MCP crate provides the dispatch
//! implementation, but any future client (CLI, HTTP gateway) can reuse this
//! server.
//!
//! The client side (forwarding, auto-spawn) lives in the transport crate
//! (e.g. `khive-mcp`), not here.

use std::io::Write as _;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;

#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{UnixListener, UnixStream};

/// Maximum frame size accepted in either direction.
///
/// Frames carry a 4-byte big-endian `u32` length prefix, so this limit must
/// stay within `u32::MAX`; the compile-time assertion below enforces that and
/// keeps the `usize -> u32` cast in `write_frame` lossless.
pub const MAX_FRAME_BYTES: usize = 8 * 1024 * 1024;

// Compile-time guard: the wire length prefix is a `u32`.
const _: () = assert!(MAX_FRAME_BYTES as u64 <= u32::MAX as u64);

/// Seconds `drain` waits for in-flight connections at shutdown when
/// `KHIVE_DRAIN_TIMEOUT_SECS` is unset or unparseable.
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 10;

// ── paths ─────────────────────────────────────────────────────────────────────

/// Base directory for khived state: `$HOME/.khive`, or `./.khive` when `HOME`
/// is unset or not valid Unicode.
fn khive_dir() -> PathBuf {
    let mut dir = match std::env::var("HOME") {
        Ok(home) => PathBuf::from(home),
        Err(_) => PathBuf::from("."),
    };
    dir.push(".khive");
    dir
}

35/// Unix socket path the daemon binds and clients connect to.
36///
37/// Overridable via the `KHIVE_SOCKET` env var (for tests and ops).
38pub fn socket_path() -> PathBuf {
39    if let Ok(p) = std::env::var("KHIVE_SOCKET") {
40        if !p.is_empty() {
41            return PathBuf::from(p);
42        }
43    }
44    khive_dir().join("khived.sock")
45}
46
47/// PID file path written by the daemon.
48///
49/// Overridable via the `KHIVE_PID` env var.
50pub fn pid_path() -> PathBuf {
51    if let Ok(p) = std::env::var("KHIVE_PID") {
52        if !p.is_empty() {
53            return PathBuf::from(p);
54        }
55    }
56    khive_dir().join("khived.pid")
57}
58
59// ── wire types ────────────────────────────────────────────────────────────────
60
61/// Request frame sent from a client to the daemon.
62#[derive(Serialize, Deserialize)]
63pub struct DaemonRequestFrame {
64    pub ops: String,
65    pub presentation: Option<String>,
66    pub presentation_per_op: Option<Vec<Option<String>>>,
67    pub namespace: String,
68}
69
70/// Response frame sent from the daemon back to a client.
71#[derive(Serialize, Deserialize)]
72pub struct DaemonResponseFrame {
73    pub ok: bool,
74    pub result: Option<String>,
75    pub error: Option<String>,
76    pub namespace_mismatch: bool,
77}
78
79// ── framing ───────────────────────────────────────────────────────────────────
80
81/// Read one length-prefixed frame (4-byte BE u32 length + JSON bytes).
82pub async fn read_frame(stream: &mut UnixStream) -> std::io::Result<Vec<u8>> {
83    let mut len_buf = [0u8; 4];
84    stream.read_exact(&mut len_buf).await?;
85    let len = u32::from_be_bytes(len_buf) as usize;
86    if len > MAX_FRAME_BYTES {
87        return Err(std::io::Error::new(
88            std::io::ErrorKind::InvalidData,
89            format!("daemon frame of {len} bytes exceeds {MAX_FRAME_BYTES} cap"),
90        ));
91    }
92    let mut buf = vec![0u8; len];
93    stream.read_exact(&mut buf).await?;
94    Ok(buf)
95}
96
97/// Write one length-prefixed frame.
98pub async fn write_frame(stream: &mut UnixStream, payload: &[u8]) -> std::io::Result<()> {
99    if payload.len() > MAX_FRAME_BYTES {
100        return Err(std::io::Error::new(
101            std::io::ErrorKind::InvalidData,
102            format!(
103                "daemon frame of {} bytes exceeds {MAX_FRAME_BYTES} cap",
104                payload.len()
105            ),
106        ));
107    }
108    let len = (payload.len() as u32).to_be_bytes();
109    stream.write_all(&len).await?;
110    stream.write_all(payload).await?;
111    stream.flush().await?;
112    Ok(())
113}
114
115// ── dispatch trait ────────────────────────────────────────────────────────────
116
117/// Transport-agnostic dispatch interface for the daemon server.
118///
119/// The MCP crate implements this by wrapping `dispatch_request_local`; any
120/// future transport can do the same.
121#[async_trait]
122pub trait DaemonDispatch: Clone + Send + Sync + 'static {
123    /// Dispatch a verb-DSL request string and return the JSON result.
124    async fn dispatch(
125        &self,
126        ops: String,
127        presentation: Option<String>,
128        presentation_per_op: Option<Vec<Option<String>>>,
129    ) -> Result<String, String>;
130
131    /// Warm every pack's in-memory state (ANN indexes, etc.).
132    async fn warm_all(&self);
133
134    /// The namespace this dispatcher was configured for.
135    fn namespace(&self) -> &str;
136}
137
138// ── server ────────────────────────────────────────────────────────────────────
139
140async fn handle_conn<D: DaemonDispatch>(mut stream: UnixStream, dispatcher: D) {
141    let raw = match read_frame(&mut stream).await {
142        Ok(r) => r,
143        Err(e) => {
144            tracing::debug!(error = %e, "failed to read daemon request frame");
145            return;
146        }
147    };
148    let frame: DaemonRequestFrame = match serde_json::from_slice(&raw) {
149        Ok(f) => f,
150        Err(e) => {
151            tracing::debug!(error = %e, "failed to decode daemon request frame");
152            return;
153        }
154    };
155
156    let resp = if frame.namespace != dispatcher.namespace() {
157        DaemonResponseFrame {
158            ok: false,
159            result: None,
160            error: None,
161            namespace_mismatch: true,
162        }
163    } else {
164        match dispatcher
165            .dispatch(frame.ops, frame.presentation, frame.presentation_per_op)
166            .await
167        {
168            Ok(result) => DaemonResponseFrame {
169                ok: true,
170                result: Some(result),
171                error: None,
172                namespace_mismatch: false,
173            },
174            Err(e) => DaemonResponseFrame {
175                ok: false,
176                result: None,
177                error: Some(e),
178                namespace_mismatch: false,
179            },
180        }
181    };
182
183    match serde_json::to_vec(&resp) {
184        Ok(payload) => {
185            if let Err(e) = write_frame(&mut stream, &payload).await {
186                tracing::debug!(error = %e, "failed to write daemon response frame");
187            }
188        }
189        Err(e) => tracing::warn!(error = %e, "failed to serialize daemon response frame"),
190    }
191}
192
193/// Run the daemon: bind the socket, warm in the background, serve request
194/// frames until SIGTERM/SIGINT.
195pub async fn run_daemon<D: DaemonDispatch>(dispatcher: D) -> anyhow::Result<()> {
196    let sock = socket_path();
197    let pid_file = pid_path();
198
199    if let Some(parent) = sock.parent() {
200        std::fs::create_dir_all(parent)?;
201        #[cfg(unix)]
202        if let Err(e) = std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700)) {
203            tracing::warn!(error = %e, path = ?parent, "failed to chmod 0700 khive dir");
204        }
205    }
206
207    if !cleanup_stale_daemon(&sock, &pid_file).await {
208        tracing::info!("a responsive khived is already running; exiting");
209        return Ok(());
210    }
211
212    let listener = UnixListener::bind(&sock)?;
213    #[cfg(unix)]
214    if let Err(e) = std::fs::set_permissions(&sock, std::fs::Permissions::from_mode(0o600)) {
215        tracing::warn!(error = %e, path = ?sock, "failed to chmod 0600 socket");
216    }
217
218    write_pid_file(&pid_file)?;
219    tracing::info!(socket = ?sock, pid = std::process::id(), "khived listening");
220
221    {
222        let warm = dispatcher.clone();
223        tokio::spawn(async move {
224            warm.warm_all().await;
225        });
226    }
227
228    let active = Arc::new(std::sync::atomic::AtomicUsize::new(0));
229
230    let shutdown = async {
231        let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
232            .expect("install SIGTERM handler");
233        let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
234            .expect("install SIGINT handler");
235        tokio::select! {
236            _ = sigterm.recv() => tracing::info!("received SIGTERM"),
237            _ = sigint.recv() => tracing::info!("received SIGINT"),
238        }
239    };
240
241    tokio::select! {
242        _ = async {
243            loop {
244                match listener.accept().await {
245                    Ok((stream, _)) => {
246                        let d = dispatcher.clone();
247                        let active = Arc::clone(&active);
248                        tokio::spawn(async move {
249                            active.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
250                            handle_conn(stream, d).await;
251                            active.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
252                        });
253                    }
254                    Err(e) => tracing::error!(error = %e, "accept failed"),
255                }
256            }
257        } => {}
258        _ = shutdown => {}
259    }
260
261    drain(&active).await;
262
263    let _ = std::fs::remove_file(&sock);
264    let _ = std::fs::remove_file(&pid_file);
265    tracing::info!("khived stopped");
266    Ok(())
267}
268
269// ── helpers ───────────────────────────────────────────────────────────────────
270
271fn is_process_running(pid: u32) -> bool {
272    let Ok(pid) = i32::try_from(pid) else {
273        return false;
274    };
275    if pid <= 0 {
276        return false;
277    }
278    // SAFETY: signal 0 is an existence/permission probe with no side effects.
279    unsafe { libc::kill(pid, 0) == 0 }
280}
281
282async fn cleanup_stale_daemon(sock: &std::path::Path, pid_file: &std::path::Path) -> bool {
283    if let Ok(pid_str) = std::fs::read_to_string(pid_file) {
284        if let Ok(pid) = pid_str.trim().parse::<u32>() {
285            if pid != std::process::id()
286                && is_process_running(pid)
287                && sock.exists()
288                && UnixStream::connect(sock).await.is_ok()
289            {
290                return false;
291            }
292        }
293    }
294    if sock.exists() {
295        if let Err(e) = std::fs::remove_file(sock) {
296            tracing::warn!(error = %e, path = ?sock, "failed to remove stale socket");
297        }
298    }
299    if pid_file.exists() {
300        if let Err(e) = std::fs::remove_file(pid_file) {
301            tracing::warn!(error = %e, path = ?pid_file, "failed to remove stale PID file");
302        }
303    }
304    true
305}
306
/// Write this process's PID (decimal, no trailing newline) to `pid_file`,
/// creating it or truncating any existing contents.
///
/// On Unix a newly created file gets mode 0600.
fn write_pid_file(pid_file: &std::path::Path) -> std::io::Result<()> {
    let mut options = std::fs::OpenOptions::new();
    options.create(true).truncate(true).write(true);
    // Owner-only permissions; `mode` only applies when the file is created.
    #[cfg(unix)]
    std::os::unix::fs::OpenOptionsExt::mode(&mut options, 0o600);
    let contents = std::process::id().to_string();
    options.open(pid_file)?.write_all(contents.as_bytes())
}

320async fn drain(active: &std::sync::atomic::AtomicUsize) {
321    use std::sync::atomic::Ordering;
322    if active.load(Ordering::Relaxed) == 0 {
323        return;
324    }
325    let deadline = tokio::time::Instant::now() + drain_timeout();
326    while active.load(Ordering::Relaxed) > 0 {
327        if tokio::time::Instant::now() >= deadline {
328            tracing::warn!(
329                remaining = active.load(Ordering::Relaxed),
330                "drain timeout reached; forcing shutdown"
331            );
332            break;
333        }
334        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
335    }
336}
337
338fn drain_timeout() -> std::time::Duration {
339    let secs = std::env::var("KHIVE_DRAIN_TIMEOUT_SECS")
340        .ok()
341        .and_then(|v| v.parse::<u64>().ok())
342        .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS);
343    std::time::Duration::from_secs(secs)
344}
345
/// Report whether env var `key` is set to a "truthy" value.
///
/// After trimming surrounding whitespace, any value other than the empty
/// string, `"0"`, or `"false"` (case-insensitive) is truthy. An unset or
/// non-Unicode variable is falsy.
pub fn env_truthy(key: &str) -> bool {
    match std::env::var(key) {
        Ok(raw) => {
            let value = raw.trim();
            !(value.is_empty() || value == "0" || value.eq_ignore_ascii_case("false"))
        }
        Err(_) => false,
    }
}