Skip to main content

inferd_daemon/
lifecycle_embed.rs

//! Embed connection lifecycle — Phase 6B-7 part 4.
//!
//! Per ADR 0017, embeddings live on a *separate* socket from v1 and
//! v2. This module mirrors `lifecycle_v2.rs` but for the embed wire
//! types (`inferd_proto::embed::EmbedRequest` / `EmbedResponse`).
//!
//! Single-frame request / single-frame response — embeddings are not
//! streamed (the result is a complete vector, so there is nothing to
//! stream). The connection stays open for the next request.
//!
//! Per request:
//!   1. Read one NDJSON frame, parse as `EmbedRequest`.
//!   2. `EmbedRequest::resolve()` — structural validation.
//!   3. Admission gate (same Admission shared with v1 / v2; one slot
//!      is one slot regardless of wire surface).
//!   4. Dispatch through the router; check the chosen backend's
//!      `capabilities().embed` flag — backends that don't support
//!      embeddings yield `Error{EmbedUnsupported, ...}`.
//!   5. `backend.embed(resolved)` — errors map to embed error codes.
//!   6. Emit a single `EmbedResponse::Embeddings` or
//!      `EmbedResponse::Error` frame, then loop for the next request.
22
23use crate::auth::{AuthFrame, key_matches};
24use crate::endpoint::Connection;
25use crate::peercred::PeerIdentity;
26use crate::queue::SubmitError;
27use crate::router::{Router, RouterError};
28use inferd_engine::EmbedError;
29use inferd_proto::ProtoError;
30use inferd_proto::embed::{EmbedErrorCode, EmbedRequest, EmbedResponse};
31use inferd_proto::write_frame;
32use std::io;
33use std::sync::Arc;
34use tokio::io::{AsyncWrite, AsyncWriteExt, BufReader};
35use tokio::sync::Mutex;
36use tracing::{debug, info, warn};
37
38/// Per-accept context for embed connections. Reuses v1's
39/// `AcceptContext` shape — same TCP API key, same admission gate.
40pub use crate::lifecycle::AcceptContext;
41
42/// Handle one accepted embed client connection.
43pub async fn handle_embed_connection<C: Connection + 'static>(
44    mut conn: C,
45    router: Arc<Router>,
46    peer: PeerIdentity,
47    ctx: AcceptContext,
48) -> Result<(), io::Error> {
49    let transport = conn.transport();
50    info!(
51        target: "inferd_daemon::activity",
52        transport = transport,
53        wire_version = "embed",
54        peer = %peer,
55        peer_uid = peer.uid,
56        peer_pid = peer.pid,
57        peer_sid = peer.sid.as_deref(),
58        "embed_connection_accepted"
59    );
60
61    let (read_half, write_half) = tokio::io::split(&mut conn);
62    let mut reader = BufReader::with_capacity(64 * 1024, read_half);
63    let writer = Arc::new(Mutex::new(write_half));
64
65    // F-8 first-frame auth on TCP, identical to v1 / v2.
66    if transport == "tcp"
67        && let Some(expected) = ctx.expected_api_key.as_deref()
68    {
69        match read_auth_frame(&mut reader).await {
70            Some(frame) if key_matches(&frame.key, expected) => {
71                debug!(transport, "embed tcp auth ok");
72            }
73            _ => {
74                warn!(
75                    target: "inferd_daemon::activity",
76                    peer = %peer,
77                    wire_version = "embed",
78                    "embed_tcp_auth_rejected"
79                );
80                return Ok(());
81            }
82        }
83    }
84
85    loop {
86        let request: EmbedRequest = match read_request_embed(&mut reader).await {
87            Ok(Some(r)) => r,
88            Ok(None) => return Ok(()),
89            Err(ProtoError::Io(e)) => return Err(e),
90            Err(e) => {
91                let resp = EmbedResponse::Error {
92                    id: String::new(),
93                    code: error_code_for(&e),
94                    message: e.to_string(),
95                };
96                write_response_embed(&writer, &resp).await?;
97                return Ok(());
98            }
99        };
100
101        let id = request.id.clone();
102        let resolved = match request.resolve() {
103            Ok(r) => r,
104            Err(e) => {
105                let resp = EmbedResponse::Error {
106                    id,
107                    code: EmbedErrorCode::InvalidRequest,
108                    message: e.to_string(),
109                };
110                write_response_embed(&writer, &resp).await?;
111                continue;
112            }
113        };
114
115        // Admission gate. Embed shares the same admission instance as
116        // v1 / v2 — one slot is one slot.
117        let _admit_permit = match ctx.admission.as_ref().map(|a| a.try_admit()) {
118            None => None,
119            Some(Ok(p)) => Some(p),
120            Some(Err(SubmitError::QueueFull)) => {
121                let resp = EmbedResponse::Error {
122                    id: resolved.id.clone(),
123                    code: EmbedErrorCode::QueueFull,
124                    message: "queue full".into(),
125                };
126                write_response_embed(&writer, &resp).await?;
127                continue;
128            }
129            Some(Err(SubmitError::Closed)) => {
130                let resp = EmbedResponse::Error {
131                    id: resolved.id.clone(),
132                    code: EmbedErrorCode::BackendUnavailable,
133                    message: "admission closed".into(),
134                };
135                write_response_embed(&writer, &resp).await?;
136                return Ok(());
137            }
138        };
139
140        // Dispatch through the router.
141        let dispatch = match router.dispatch() {
142            Ok(d) => d,
143            Err(RouterError::NoBackends) | Err(RouterError::NoneAvailable) => {
144                let resp = EmbedResponse::Error {
145                    id: resolved.id.clone(),
146                    code: EmbedErrorCode::BackendUnavailable,
147                    message: "no backend available".into(),
148                };
149                write_response_embed(&writer, &resp).await?;
150                continue;
151            }
152        };
153        let backend_name = dispatch.name.clone();
154        let backend = dispatch.backend;
155
156        // Belt-and-braces check: the embed socket is only bound when
157        // the active backend's capability is `true`, but a router
158        // serving multiple backends could land on a non-embed slot.
159        if !backend.capabilities().embed {
160            let resp = EmbedResponse::Error {
161                id: resolved.id.clone(),
162                code: EmbedErrorCode::EmbedUnsupported,
163                message: format!("backend {backend_name:?} does not support embeddings"),
164            };
165            write_response_embed(&writer, &resp).await?;
166            continue;
167        }
168
169        let req_id = resolved.id.clone();
170        let n_inputs = resolved.input.len();
171
172        let result = backend.embed(resolved).await;
173        match result {
174            Ok(out) => {
175                let usage = out.usage;
176                let dimensions = out.dimensions;
177                let frame = EmbedResponse::Embeddings {
178                    id: req_id.clone(),
179                    embeddings: out.embeddings,
180                    dimensions,
181                    model: out.model,
182                    usage,
183                    backend: backend_name.clone(),
184                };
185                write_response_embed(&writer, &frame).await?;
186                router.record_success(&backend_name);
187                info!(
188                    target: "inferd_daemon::activity",
189                    req_id = %req_id,
190                    backend = %backend_name,
191                    wire_version = "embed",
192                    n_inputs = n_inputs,
193                    input_tokens = usage.input_tokens,
194                    dimensions = dimensions,
195                    "embed_request_done"
196                );
197            }
198            Err(e) => {
199                let (code, message, is_backend_failure) = match e {
200                    EmbedError::InvalidRequest(m) => (EmbedErrorCode::InvalidRequest, m, false),
201                    EmbedError::NotReady => (
202                        EmbedErrorCode::BackendUnavailable,
203                        "backend not ready".into(),
204                        true,
205                    ),
206                    EmbedError::Unavailable(m) => (EmbedErrorCode::BackendUnavailable, m, true),
207                    EmbedError::Unsupported => (
208                        EmbedErrorCode::EmbedUnsupported,
209                        "embed not supported by this backend".into(),
210                        false,
211                    ),
212                    EmbedError::Internal(m) => (EmbedErrorCode::Internal, m, true),
213                };
214                if is_backend_failure {
215                    router.record_failure(&backend_name);
216                }
217                let frame = EmbedResponse::Error {
218                    id: req_id,
219                    code,
220                    message,
221                };
222                write_response_embed(&writer, &frame).await?;
223            }
224        }
225    }
226}
227
228fn error_code_for(e: &ProtoError) -> EmbedErrorCode {
229    match e {
230        ProtoError::FrameTooLarge => EmbedErrorCode::FrameTooLarge,
231        ProtoError::Decode(_) | ProtoError::InvalidRequest(_) => EmbedErrorCode::InvalidRequest,
232        ProtoError::Io(_) => EmbedErrorCode::Internal,
233    }
234}
235
236async fn read_auth_frame<R>(reader: &mut R) -> Option<AuthFrame>
237where
238    R: tokio::io::AsyncBufRead + Unpin,
239{
240    use tokio::io::AsyncBufReadExt;
241    let mut line = Vec::with_capacity(256);
242    let limit = inferd_proto::MAX_FRAME_BYTES;
243    loop {
244        let buf = reader.fill_buf().await.ok()?;
245        if buf.is_empty() {
246            return None;
247        }
248        if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
249            if line.len() + idx > limit {
250                return None;
251            }
252            line.extend_from_slice(&buf[..idx]);
253            reader.consume(idx + 1);
254            return AuthFrame::from_json(&line);
255        }
256        if line.len() + buf.len() > limit {
257            return None;
258        }
259        line.extend_from_slice(buf);
260        let n = buf.len();
261        reader.consume(n);
262    }
263}
264
265async fn read_request_embed<R>(reader: &mut R) -> Result<Option<EmbedRequest>, ProtoError>
266where
267    R: tokio::io::AsyncBufRead + Unpin,
268{
269    use tokio::io::AsyncBufReadExt;
270    let mut line = Vec::with_capacity(512);
271    let limit = inferd_proto::MAX_FRAME_BYTES;
272    loop {
273        let buf = reader.fill_buf().await?;
274        if buf.is_empty() {
275            if line.is_empty() {
276                return Ok(None);
277            }
278            return inferd_proto::read_frame::<&[u8], EmbedRequest>(&mut &line[..]);
279        }
280        if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
281            if line.len() + idx > limit {
282                return Err(ProtoError::FrameTooLarge);
283            }
284            line.extend_from_slice(&buf[..=idx]);
285            reader.consume(idx + 1);
286            return inferd_proto::read_frame::<&[u8], EmbedRequest>(&mut &line[..]);
287        }
288        if line.len() + buf.len() > limit {
289            return Err(ProtoError::FrameTooLarge);
290        }
291        line.extend_from_slice(buf);
292        let n = buf.len();
293        reader.consume(n);
294    }
295}
296
297async fn write_response_embed<W: AsyncWrite + Unpin>(
298    writer: &Mutex<W>,
299    resp: &EmbedResponse,
300) -> io::Result<()> {
301    let mut buf = Vec::with_capacity(512);
302    write_frame(&mut buf, resp)
303        .map_err(|e| io::Error::other(format!("serialise embed response: {e}")))?;
304    let mut guard = writer.lock().await;
305    guard.write_all(&buf).await?;
306    guard.flush().await?;
307    Ok(())
308}
309
310/// Serve an embed TCP listener.
311pub async fn serve_tcp_embed(
312    listener: tokio::net::TcpListener,
313    router: Arc<Router>,
314    ctx: AcceptContext,
315    mut shutdown: tokio::sync::oneshot::Receiver<()>,
316) -> io::Result<()> {
317    info!(addr = ?listener.local_addr()?, "embed tcp listener accepting");
318    loop {
319        tokio::select! {
320            _ = &mut shutdown => {
321                info!("embed tcp shutdown signalled");
322                return Ok(());
323            }
324            accept = listener.accept() => {
325                let (stream, peer_addr) = accept?;
326                let peer = PeerIdentity::from_tcp(peer_addr);
327                let r = Arc::clone(&router);
328                let ctx = ctx.clone();
329                debug!(?peer_addr, "embed tcp accept");
330                tokio::spawn(async move {
331                    if let Err(e) = handle_embed_connection(stream, r, peer, ctx).await {
332                        warn!(error = ?e, "embed connection terminated with error");
333                    }
334                });
335            }
336        }
337    }
338}
339
340/// Serve an embed Unix domain socket listener.
341#[cfg(unix)]
342pub async fn serve_uds_embed(
343    listener: tokio::net::UnixListener,
344    router: Arc<Router>,
345    ctx: AcceptContext,
346    mut shutdown: tokio::sync::oneshot::Receiver<()>,
347) -> io::Result<()> {
348    info!("embed uds listener accepting");
349    loop {
350        tokio::select! {
351            _ = &mut shutdown => {
352                info!("embed uds shutdown signalled");
353                return Ok(());
354            }
355            accept = listener.accept() => {
356                let (stream, _) = accept?;
357                let r = Arc::clone(&router);
358                let peer = crate::peercred::unix::from_stream(&stream)
359                    .unwrap_or_else(|e| {
360                        warn!(error = %e, "embed SO_PEERCRED failed; recording empty unix identity");
361                        crate::peercred::PeerIdentity {
362                            uid: None, gid: None, pid: None,
363                            sid: None, remote_addr: None,
364                            transport: "unix",
365                        }
366                    });
367                let ctx = ctx.clone();
368                debug!(?peer, "embed uds accept");
369                tokio::spawn(async move {
370                    if let Err(e) = handle_embed_connection(stream, r, peer, ctx).await {
371                        warn!(error = ?e, "embed connection terminated with error");
372                    }
373                });
374            }
375        }
376    }
377}
378
/// Serve an embed Windows named pipe listener until the `shutdown`
/// receiver resolves.
///
/// Named pipes have no separate listening socket: a server instance waits
/// for one client. Once `connect` completes, that instance is handed to
/// [`handle_embed_connection`] on its own task, and a fresh instance is
/// bound at `path` to wait for the next client.
///
/// If the client's process id cannot be looked up, an empty `"pipe"`
/// identity is used and a warning is logged.
///
/// # Errors
///
/// Returns an error if `connect` fails or a new pipe instance cannot be
/// bound.
#[cfg(windows)]
pub async fn serve_named_pipe_embed(
    path: &str,
    first_instance: tokio::net::windows::named_pipe::NamedPipeServer,
    router: Arc<Router>,
    ctx: AcceptContext,
    mut shutdown: tokio::sync::oneshot::Receiver<()>,
) -> io::Result<()> {
    use crate::endpoint::bind_named_pipe;

    info!(path = %path, "embed named pipe listener accepting");
    let mut waiting = first_instance;
    loop {
        tokio::select! {
            _ = &mut shutdown => {
                info!("embed named pipe shutdown signalled");
                return Ok(());
            }
            outcome = waiting.connect() => {
                outcome?;
                // Take the connected instance for the client, then stand up
                // the next instance before doing anything else.
                let client = waiting;
                waiting = bind_named_pipe(path, false)?;

                let peer = match crate::peercred::windows::from_stream(&client) {
                    Ok(identity) => identity,
                    Err(e) => {
                        warn!(error = %e, "embed GetNamedPipeClientProcessId failed; empty pipe identity");
                        crate::peercred::PeerIdentity {
                            uid: None,
                            gid: None,
                            pid: None,
                            sid: None,
                            remote_addr: None,
                            transport: "pipe",
                        }
                    }
                };
                debug!(?peer, "embed named pipe accept");

                let task_router = Arc::clone(&router);
                let task_ctx = ctx.clone();
                tokio::spawn(async move {
                    let outcome =
                        handle_embed_connection(client, task_router, peer, task_ctx).await;
                    if let Err(e) = outcome {
                        warn!(error = ?e, "embed connection terminated with error");
                    }
                });
            }
        }
    }
}