Skip to main content

inferd_daemon/
lifecycle.rs

1//! Daemon lifecycle: boot → wait-for-ready → bind listener → accept →
2//! dispatch → shutdown.
3//!
4//! The M1 lifecycle wires:
5//! - `lock` — single-instance lock at startup (THREAT_MODEL F-2).
6//! - `router` — backend selection (no-op v0.1 — picks the only one).
7//! - `endpoint` — listener bound only after `router.all_ready()`
8//!   (THREAT_MODEL F-13).
9//! - `queue` — admission gate (`SubmitError::QueueFull` → wire
10//!   `code: queue_full`).
11//! - `inferd-proto` — frame parsing and serialisation.
12//!
13//! Cancellation: dropping a connection drops the in-flight `TokenStream`,
14//! which closes the engine's `tx` and stops the spawned generation task.
15//! Per ADR 0007 the daemon emits no terminal frame on cancel — the EOF
16//! is the signal.
17
18use crate::auth::{AuthFrame, key_matches};
19use crate::endpoint::Connection;
20use crate::peercred::PeerIdentity;
21use crate::queue::{Admission, SubmitError};
22use crate::router::{Router, RouterError};
23use inferd_engine::{GenerateError, TokenEvent};
24use inferd_proto::{ErrorCode, ProtoError, Request, Response, write_frame};
25use std::io;
26use std::sync::Arc;
27use std::time::{Duration, Instant};
28use tokio::io::{AsyncWrite, AsyncWriteExt, BufReader};
29use tokio::sync::Mutex;
30use tokio_stream::StreamExt;
31use tracing::{debug, info, warn};
32
33/// Wait until every backend in `router` reports ready, polling at 50ms
34/// intervals up to `timeout`. Returns the duration spent waiting.
35///
36/// THREAT_MODEL F-13: nothing else creates listeners until this returns.
37pub async fn wait_for_ready(router: &Router, timeout: Duration) -> Result<Duration, ReadyTimeout> {
38    let started = Instant::now();
39    let poll = Duration::from_millis(50);
40    loop {
41        if router.all_ready() {
42            return Ok(started.elapsed());
43        }
44        if started.elapsed() >= timeout {
45            return Err(ReadyTimeout(timeout));
46        }
47        tokio::time::sleep(poll).await;
48    }
49}
50
51/// Returned when `wait_for_ready` exhausts its budget without seeing
52/// readiness across every backend.
53#[derive(Debug, thiserror::Error)]
54#[error("backend not ready within {0:?}")]
55pub struct ReadyTimeout(pub Duration);
56
57/// Per-accept context that the lifecycle hands to every spawned
58/// connection task.
59///
60/// Today it carries the optional TCP API key (THREAT_MODEL F-8) and
61/// the shared admission gate (queue_full enforcement). New per-
62/// connection policy (rate limits, per-caller quotas) extends this
63/// struct rather than each `serve_*` signature.
64#[derive(Clone, Default)]
65pub struct AcceptContext {
66    /// When `Some` and the connection is TCP, the daemon requires an
67    /// auth frame as the first NDJSON line on the wire and constant-
68    /// time-compares the key against this value. UDS / pipe ignore
69    /// this field — F-7 covers them.
70    pub expected_api_key: Option<String>,
71    /// Shared admission gate. `None` for tests / dev paths that
72    /// don't care about queue depth — those treat every request
73    /// as admitted. Production lifecycle always passes `Some`.
74    pub admission: Option<Admission>,
75}
76
77impl std::fmt::Debug for AcceptContext {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("AcceptContext")
80            .field("expected_api_key", &self.expected_api_key.is_some())
81            .field(
82                "admission_capacity",
83                &self.admission.as_ref().map(|a| a.capacity()),
84            )
85            .finish()
86    }
87}
88
89/// Handle one accepted client connection: read framed `Request`s and write
90/// framed `Response`s until EOF or fatal error.
91///
92/// Per request:
93/// 1. Read one frame (`read_frame`).
94/// 2. `Request::resolve()` — apply defaults, validate. Failures → `error`
95///    frame with `code: invalid_request`.
96/// 3. `router.dispatch()` — pick a backend.
97/// 4. `backend.generate()` — pre-stream errors → `error` frame with
98///    `code: backend_unavailable`.
99/// 5. Stream `TokenEvent`s, translating each to `Response::Token` /
100///    `Response::Done`. If the engine drops the stream without `Done`,
101///    emit `error` with `code: backend_unavailable`.
102pub async fn handle_connection<C: Connection + 'static>(
103    mut conn: C,
104    router: Arc<Router>,
105    peer: PeerIdentity,
106    ctx: AcceptContext,
107) -> Result<(), io::Error> {
108    let transport = conn.transport();
109    info!(
110        target: "inferd_daemon::activity",
111        transport = transport,
112        peer = %peer,
113        peer_uid = peer.uid,
114        peer_pid = peer.pid,
115        peer_sid = peer.sid.as_deref(),
116        "connection_accepted"
117    );
118
119    // Split read and write halves so the generation task can write tokens
120    // while we keep reading the next request. We don't actually pipeline
121    // requests in M1 (admission queue is 1-active anyway), but the split
122    // is needed because tokio AsyncWrite is consumed by `write_all`.
123    let (read_half, write_half) = tokio::io::split(&mut conn);
124    let mut reader = BufReader::with_capacity(64 * 1024, read_half);
125    let writer = Arc::new(Mutex::new(write_half));
126
127    // F-8: TCP first-frame auth. UDS / pipe rely on F-7 peer creds and
128    // skip this. Anonymous probers see the connection close with no
129    // protocol error frame — we don't confirm endpoint existence.
130    if transport == "tcp"
131        && let Some(expected) = ctx.expected_api_key.as_deref()
132    {
133        match read_auth_frame(&mut reader).await {
134            Some(frame) if key_matches(&frame.key, expected) => {
135                debug!(transport, "tcp auth ok");
136            }
137            _ => {
138                warn!(
139                    target: "inferd_daemon::activity",
140                    peer = %peer,
141                    "tcp_auth_rejected"
142                );
143                return Ok(());
144            }
145        }
146    }
147
148    loop {
149        // Read one request frame. `read_frame` is sync over a sync BufRead;
150        // we have an async reader, so do a small async-to-sync bridge by
151        // first reading into a vec and then parsing.
152        let request: Request = match read_frame_async(&mut reader).await {
153            Ok(Some(r)) => r,
154            Ok(None) => return Ok(()), // peer closed cleanly
155            Err(ProtoError::Io(e)) => return Err(e),
156            Err(e) => {
157                let resp = Response::Error {
158                    id: String::new(),
159                    code: e.to_error_code(),
160                    message: e.to_string(),
161                };
162                write_response(&writer, &resp).await?;
163                return Ok(());
164            }
165        };
166
167        // Resolve: defaults + validation.
168        let id = request.id.clone();
169        let resolved = match request.resolve() {
170            Ok(r) => r,
171            Err(e) => {
172                let resp = Response::Error {
173                    id,
174                    code: ErrorCode::InvalidRequest,
175                    message: e.to_string(),
176                };
177                write_response(&writer, &resp).await?;
178                continue;
179            }
180        };
181
182        // Admission gate (queue_full enforcement). Held for the full
183        // generation; dropping the permit (after the Done frame, on
184        // mid-stream error, or on connection drop) returns the slot
185        // to the pool. `None` admission = tests / dev paths that
186        // don't care about queue depth.
187        let _admit_permit = match ctx.admission.as_ref().map(|a| a.try_admit()) {
188            None => None,
189            Some(Ok(p)) => Some(p),
190            Some(Err(SubmitError::QueueFull)) => {
191                let resp = Response::Error {
192                    id: resolved.id.clone(),
193                    code: ErrorCode::QueueFull,
194                    message: "queue full".into(),
195                };
196                write_response(&writer, &resp).await?;
197                continue;
198            }
199            Some(Err(SubmitError::Closed)) => {
200                // Admission closed = daemon shutting down. Tell the
201                // caller, then drop the connection — there's no point
202                // reading another request that we'll also reject.
203                let resp = Response::Error {
204                    id: resolved.id.clone(),
205                    code: ErrorCode::BackendUnavailable,
206                    message: "admission closed".into(),
207                };
208                write_response(&writer, &resp).await?;
209                return Ok(());
210            }
211        };
212
213        // Dispatch through the router.
214        let dispatch = match router.dispatch() {
215            Ok(d) => d,
216            Err(RouterError::NoBackends) | Err(RouterError::NoneAvailable) => {
217                let resp = Response::Error {
218                    id: resolved.id.clone(),
219                    code: ErrorCode::BackendUnavailable,
220                    message: "no backend available".into(),
221                };
222                write_response(&writer, &resp).await?;
223                continue;
224            }
225        };
226        let backend_name = dispatch.name.clone();
227        let backend = dispatch.backend;
228        let req_id = resolved.id.clone();
229
230        // Generate. Pre-stream errors count toward the breaker per
231        // ADR 0007 — InvalidRequest does not (it's a caller bug, not
232        // a backend health signal).
233        let mut stream = match backend.generate(resolved).await {
234            Ok(s) => s,
235            Err(e) => {
236                let (code, message, is_backend_failure) = match e {
237                    GenerateError::InvalidRequest(m) => (ErrorCode::InvalidRequest, m, false),
238                    GenerateError::NotReady => (
239                        ErrorCode::BackendUnavailable,
240                        "backend not ready".into(),
241                        true,
242                    ),
243                    GenerateError::Unavailable(m) => (ErrorCode::BackendUnavailable, m, true),
244                    GenerateError::Internal(m) => (ErrorCode::Internal, m, true),
245                };
246                if is_backend_failure {
247                    router.record_failure(&backend_name);
248                }
249                let resp = Response::Error {
250                    id: req_id,
251                    code,
252                    message,
253                };
254                write_response(&writer, &resp).await?;
255                continue;
256            }
257        };
258
259        // Stream tokens. Build the full content for Response::Done in one
260        // pass; the engine reports usage so we don't have to count.
261        let mut full = String::new();
262        let mut terminal_emitted = false;
263        while let Some(ev) = stream.next().await {
264            match ev {
265                TokenEvent::Token(text) => {
266                    let frame = Response::Token {
267                        id: req_id.clone(),
268                        content: text.clone(),
269                    };
270                    write_response(&writer, &frame).await?;
271                    full.push_str(&text);
272                }
273                TokenEvent::Done { stop_reason, usage } => {
274                    let frame = Response::Done {
275                        id: req_id.clone(),
276                        content: std::mem::take(&mut full),
277                        usage,
278                        stop_reason,
279                        backend: backend_name.clone(),
280                    };
281                    write_response(&writer, &frame).await?;
282                    info!(
283                        target: "inferd_daemon::activity",
284                        req_id = %req_id,
285                        backend = %backend_name,
286                        stop_reason = ?stop_reason,
287                        prompt_tokens = usage.prompt_tokens,
288                        completion_tokens = usage.completion_tokens,
289                        "request_done"
290                    );
291                    router.record_success(&backend_name);
292                    terminal_emitted = true;
293                    break;
294                }
295            }
296        }
297
298        if !terminal_emitted {
299            // Mid-stream backend failure (no Done event). Report and move
300            // to next request on the same connection. Counts toward the
301            // breaker per ADR 0007.
302            warn!(
303                target: "inferd_daemon::activity",
304                req_id = %req_id,
305                backend = %backend_name,
306                "request_error_mid_stream"
307            );
308            router.record_failure(&backend_name);
309            let frame = Response::Error {
310                id: req_id,
311                code: ErrorCode::BackendUnavailable,
312                message: "backend ended stream without terminal frame".into(),
313            };
314            write_response(&writer, &frame).await?;
315        }
316    }
317}
318
319/// Read one NDJSON line from a tokio `BufRead` and parse it as an
320/// `AuthFrame`. Returns `None` on any failure (truncation, garbage,
321/// wrong type) so the caller can close the connection silently.
322///
323/// Takes the *existing* BufReader the connection handler already
324/// owns — wrapping it in a second BufReader would buffer bytes
325/// past the auth line that then get lost when the local wrapper drops.
326async fn read_auth_frame<R>(reader: &mut R) -> Option<AuthFrame>
327where
328    R: tokio::io::AsyncBufRead + Unpin,
329{
330    use tokio::io::AsyncBufReadExt;
331    let mut line = Vec::with_capacity(256);
332    let limit = inferd_proto::MAX_FRAME_BYTES;
333    loop {
334        let buf = reader.fill_buf().await.ok()?;
335        if buf.is_empty() {
336            return None;
337        }
338        if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
339            if line.len() + idx > limit {
340                return None;
341            }
342            line.extend_from_slice(&buf[..idx]);
343            reader.consume(idx + 1);
344            return AuthFrame::from_json(&line);
345        }
346        if line.len() + buf.len() > limit {
347            return None;
348        }
349        line.extend_from_slice(buf);
350        let n = buf.len();
351        reader.consume(n);
352    }
353}
354
355/// Async wrapper around `inferd_proto::read_frame` for tokio readers.
356///
357/// Consumes from an existing `AsyncBufRead` (typically the per-connection
358/// `BufReader` that the lifecycle holds) so any bytes prefetched past the
359/// current line stay in the caller's buffer for the next read. Wrapping
360/// the input in a *second* BufReader here would lose those bytes when
361/// the local wrapper dropped.
362async fn read_frame_async<R>(reader: &mut R) -> Result<Option<Request>, ProtoError>
363where
364    R: tokio::io::AsyncBufRead + Unpin,
365{
366    use tokio::io::AsyncBufReadExt;
367    let mut line = Vec::with_capacity(512);
368    let limit = inferd_proto::MAX_FRAME_BYTES;
369    loop {
370        let buf = reader.fill_buf().await?;
371        if buf.is_empty() {
372            if line.is_empty() {
373                return Ok(None);
374            }
375            // Trailing line without newline. Defer to the proto crate's
376            // sync reader, which handles trailing-line-without-newline as
377            // a final frame.
378            return inferd_proto::read_frame::<&[u8], Request>(&mut &line[..]);
379        }
380        if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
381            if line.len() + idx > limit {
382                return Err(ProtoError::FrameTooLarge);
383            }
384            line.extend_from_slice(&buf[..=idx]);
385            reader.consume(idx + 1);
386            return inferd_proto::read_frame::<&[u8], Request>(&mut &line[..]);
387        }
388        if line.len() + buf.len() > limit {
389            return Err(ProtoError::FrameTooLarge);
390        }
391        line.extend_from_slice(buf);
392        let n = buf.len();
393        reader.consume(n);
394    }
395}
396
397async fn write_response<W: AsyncWrite + Unpin>(
398    writer: &Mutex<W>,
399    resp: &Response,
400) -> io::Result<()> {
401    let mut buf = Vec::with_capacity(512);
402    write_frame(&mut buf, resp)
403        .map_err(|e| io::Error::other(format!("serialise response: {e}")))?;
404    let mut guard = writer.lock().await;
405    guard.write_all(&buf).await?;
406    guard.flush().await?;
407    Ok(())
408}
409
410/// Serve a TCP listener: accept loop, spawn one task per connection.
411///
412/// Returns when `shutdown` resolves (e.g. a Ctrl-C signal). All in-flight
413/// connections are dropped at that point — clients see EOF and treat it as
414/// a non-terminal-frame error per `docs/protocol-v1.md`.
415pub async fn serve_tcp(
416    listener: tokio::net::TcpListener,
417    router: Arc<Router>,
418    ctx: AcceptContext,
419    mut shutdown: tokio::sync::oneshot::Receiver<()>,
420) -> io::Result<()> {
421    info!(addr = ?listener.local_addr()?, "tcp listener accepting");
422    loop {
423        tokio::select! {
424            _ = &mut shutdown => {
425                info!("shutdown signalled");
426                return Ok(());
427            }
428            accept = listener.accept() => {
429                let (stream, peer_addr) = accept?;
430                let r = Arc::clone(&router);
431                let peer = PeerIdentity::from_tcp(peer_addr);
432                let ctx = ctx.clone();
433                debug!(?peer_addr, "tcp accept");
434                tokio::spawn(async move {
435                    if let Err(e) = handle_connection(stream, r, peer, ctx).await {
436                        warn!(error = ?e, "connection terminated with error");
437                    }
438                });
439            }
440        }
441    }
442}
443
444/// Serve a Unix domain socket listener (Unix only).
445#[cfg(unix)]
446pub async fn serve_uds(
447    listener: tokio::net::UnixListener,
448    router: Arc<Router>,
449    ctx: AcceptContext,
450    mut shutdown: tokio::sync::oneshot::Receiver<()>,
451) -> io::Result<()> {
452    info!("uds listener accepting");
453    loop {
454        tokio::select! {
455            _ = &mut shutdown => {
456                info!("shutdown signalled");
457                return Ok(());
458            }
459            accept = listener.accept() => {
460                let (stream, _) = accept?;
461                let r = Arc::clone(&router);
462                // Best-effort SO_PEERCRED. If the OS refuses (very rare on
463                // a connected UDS), record an empty identity but still
464                // serve the request — the socket ACL was the primary
465                // perimeter; this is defence in depth.
466                let peer = crate::peercred::unix::from_stream(&stream)
467                    .unwrap_or_else(|e| {
468                        warn!(error = %e, "SO_PEERCRED failed; recording empty unix identity");
469                        crate::peercred::PeerIdentity {
470                            uid: None, gid: None, pid: None,
471                            sid: None, remote_addr: None,
472                            transport: "unix",
473                        }
474                    });
475                let ctx = ctx.clone();
476                debug!(?peer, "uds accept");
477                tokio::spawn(async move {
478                    if let Err(e) = handle_connection(stream, r, peer, ctx).await {
479                        warn!(error = ?e, "connection terminated with error");
480                    }
481                });
482            }
483        }
484    }
485}
486
/// Accept loop for a Windows named pipe (Windows only).
///
/// The caller binds the first pipe instance itself via
/// [`crate::endpoint::bind_named_pipe(path, true)`] and passes it in as
/// `first_instance`. Because the instance exists before this function is
/// even spawned, a client (or test harness) that already knows the path
/// cannot race the first bind.
///
/// Each iteration:
/// 1. Awaits `server.connect()` — the accept point.
/// 2. Binds the next server instance (`first = false`) so another client
///    can connect immediately.
/// 3. Hands the connected instance to a per-connection task.
///
/// Ends with `Ok(())` once `shutdown` resolves; a `connect` or bind error
/// ends the loop with that error.
#[cfg(windows)]
pub async fn serve_named_pipe(
    path: &str,
    first_instance: tokio::net::windows::named_pipe::NamedPipeServer,
    router: Arc<Router>,
    ctx: AcceptContext,
    mut shutdown: tokio::sync::oneshot::Receiver<()>,
) -> io::Result<()> {
    use crate::endpoint::bind_named_pipe;

    info!(path = %path, "named pipe listener accepting");
    let mut server = first_instance;
    loop {
        tokio::select! {
            _ = &mut shutdown => {
                info!("shutdown signalled");
                return Ok(());
            }
            connect_result = server.connect() => {
                connect_result?;

                // Swap in a fresh listening instance and keep the one
                // that just connected; the replacement is bound before
                // the handler is spawned so a second client never waits.
                let connected = std::mem::replace(&mut server, bind_named_pipe(path, false)?);

                // Peer identity is best-effort. If the lookup fails
                // (caller process exited between accept and probe),
                // serve with an empty identity; the named-pipe DACL is
                // the primary perimeter, this is defence in depth.
                let peer = match crate::peercred::windows::from_stream(&connected) {
                    Ok(identity) => identity,
                    Err(e) => {
                        warn!(error = %e, "GetNamedPipeClientProcessId failed; empty pipe identity");
                        crate::peercred::PeerIdentity {
                            uid: None,
                            gid: None,
                            pid: None,
                            sid: None,
                            remote_addr: None,
                            transport: "pipe",
                        }
                    }
                };
                debug!(?peer, "named pipe accept");

                let conn_router = Arc::clone(&router);
                let conn_ctx = ctx.clone();
                tokio::spawn(async move {
                    if let Err(e) = handle_connection(connected, conn_router, peer, conn_ctx).await {
                        warn!(error = ?e, "connection terminated with error");
                    }
                });
            }
        }
    }
}
554
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_engine::mock::Mock;

    /// A router over a freshly constructed mock backend should be seen as
    /// ready on the first poll, well inside the one-second budget.
    #[tokio::test]
    async fn wait_for_ready_returns_when_already_ready() {
        let backend = Arc::new(Mock::new());
        let router = Router::new(vec![backend]);

        let waited = wait_for_ready(&router, Duration::from_secs(1)).await;

        assert!(waited.unwrap() < Duration::from_millis(100));
    }

    /// A backend that never becomes ready must surface `ReadyTimeout`,
    /// whose message mentions "not ready".
    #[tokio::test]
    async fn wait_for_ready_times_out_when_not_ready() {
        let backend = Arc::new(Mock::new());
        backend.set_ready(false);
        let router = Router::new(vec![backend]);

        let outcome = wait_for_ready(&router, Duration::from_millis(100)).await;

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not ready"));
    }

    /// Readiness that arrives 150ms in is picked up by a later poll, so
    /// the reported wait is at least 100ms.
    #[tokio::test]
    async fn wait_for_ready_succeeds_after_delayed_ready() {
        let backend = Arc::new(Mock::new());
        backend.set_ready(false);
        let router = Router::new(vec![backend.clone()]);

        // Flip the backend to ready from a background task.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(150)).await;
            backend.set_ready(true);
        });

        let waited = wait_for_ready(&router, Duration::from_secs(1)).await;

        assert!(waited.unwrap() >= Duration::from_millis(100));
    }
}
597}