Skip to main content

inferd_daemon/
lifecycle_embed.rs

1//! Embed connection lifecycle — Phase 6B-7 part 4.
2//!
3//! Per ADR 0017, embeddings live on a *separate* socket from v1 and
4//! v2. This module mirrors `lifecycle_v2.rs` but for the embed wire
5//! types (`inferd_proto::embed::EmbedRequest` / `EmbedResponse`).
6//!
7//! Single-frame request / single-frame response — embeddings are not
8//! streamed (the result is a complete vector, there is nothing to
9//! stream). The connection stays open for the next request.
10//!
11//! Per request:
12//!   1. Read one NDJSON frame, parse as `EmbedRequest`.
13//!   2. `EmbedRequest::resolve()` — structural validation.
14//!   3. Admission gate (same Admission shared with v1 / v2; one slot
15//!      is one slot regardless of wire surface).
//!   4. Dispatch through the router via `dispatch_embed`, which only
//!      considers backends advertising `capabilities().embed`; when no
//!      such backend is available the reply is
//!      `Error{BackendUnavailable, ...}`.
//!   5. `backend.embed(resolved)` — errors map to embed error codes
//!      (a backend reporting `Unsupported` yields
//!      `Error{EmbedUnsupported, ...}`).
20//!   6. Emit a single `EmbedResponse::Embeddings` or
21//!      `EmbedResponse::Error` frame, then loop for the next request.
22
23use crate::auth::{AuthFrame, key_matches};
24use crate::endpoint::Connection;
25use crate::peercred::PeerIdentity;
26use crate::queue::SubmitError;
27use crate::router::{Router, RouterError};
28use inferd_engine::EmbedError;
29use inferd_proto::ProtoError;
30use inferd_proto::embed::{EmbedErrorCode, EmbedRequest, EmbedResponse};
31use inferd_proto::write_frame;
32use std::io;
33use std::sync::Arc;
34use tokio::io::{AsyncWrite, AsyncWriteExt, BufReader};
35use tokio::sync::Mutex;
36use tracing::{debug, info, warn};
37
38/// Per-accept context for embed connections. Reuses v1's
39/// `AcceptContext` shape — same TCP API key, same admission gate.
40pub use crate::lifecycle::AcceptContext;
41
42/// Handle one accepted embed client connection.
43pub async fn handle_embed_connection<C: Connection + 'static>(
44    mut conn: C,
45    router: Arc<Router>,
46    peer: PeerIdentity,
47    ctx: AcceptContext,
48) -> Result<(), io::Error> {
49    let transport = conn.transport();
50    info!(
51        target: "inferd_daemon::activity",
52        transport = transport,
53        wire_version = "embed",
54        peer = %peer,
55        peer_uid = peer.uid,
56        peer_pid = peer.pid,
57        peer_sid = peer.sid.as_deref(),
58        "embed_connection_accepted"
59    );
60
61    let (read_half, write_half) = tokio::io::split(&mut conn);
62    let mut reader = BufReader::with_capacity(64 * 1024, read_half);
63    let writer = Arc::new(Mutex::new(write_half));
64
65    // F-8 first-frame auth on TCP, identical to v1 / v2.
66    if transport == "tcp"
67        && let Some(expected) = ctx.expected_api_key.as_deref()
68    {
69        match read_auth_frame(&mut reader).await {
70            Some(frame) if key_matches(&frame.key, expected) => {
71                debug!(transport, "embed tcp auth ok");
72            }
73            _ => {
74                warn!(
75                    target: "inferd_daemon::activity",
76                    peer = %peer,
77                    wire_version = "embed",
78                    "embed_tcp_auth_rejected"
79                );
80                return Ok(());
81            }
82        }
83    }
84
85    loop {
86        let request: EmbedRequest = match read_request_embed(&mut reader).await {
87            Ok(Some(r)) => r,
88            Ok(None) => return Ok(()),
89            Err(ProtoError::Io(e)) => return Err(e),
90            Err(e) => {
91                let resp = EmbedResponse::Error {
92                    id: String::new(),
93                    code: error_code_for(&e),
94                    message: e.to_string(),
95                };
96                write_response_embed(&writer, &resp).await?;
97                return Ok(());
98            }
99        };
100
101        let id = request.id.clone();
102        let resolved = match request.resolve() {
103            Ok(r) => r,
104            Err(e) => {
105                let resp = EmbedResponse::Error {
106                    id,
107                    code: EmbedErrorCode::InvalidRequest,
108                    message: e.to_string(),
109                };
110                write_response_embed(&writer, &resp).await?;
111                continue;
112            }
113        };
114
115        // Admission gate. Embed shares the same admission instance as
116        // v1 / v2 — one slot is one slot.
117        let _admit_permit = match ctx.admission.as_ref().map(|a| a.try_admit()) {
118            None => None,
119            Some(Ok(p)) => Some(p),
120            Some(Err(SubmitError::QueueFull)) => {
121                let resp = EmbedResponse::Error {
122                    id: resolved.id.clone(),
123                    code: EmbedErrorCode::QueueFull,
124                    message: "queue full".into(),
125                };
126                write_response_embed(&writer, &resp).await?;
127                continue;
128            }
129            Some(Err(SubmitError::Closed)) => {
130                let resp = EmbedResponse::Error {
131                    id: resolved.id.clone(),
132                    code: EmbedErrorCode::BackendUnavailable,
133                    message: "admission closed".into(),
134                };
135                write_response_embed(&writer, &resp).await?;
136                return Ok(());
137            }
138        };
139
140        // Dispatch through the router. `dispatch_embed` filters out
141        // slots whose backend doesn't advertise `capabilities().embed`
142        // so multi-backend configs that put a generate-only backend
143        // ahead of an embed-capable one route embed requests correctly.
144        let dispatch = match router.dispatch_embed() {
145            Ok(d) => d,
146            Err(RouterError::NoBackends) | Err(RouterError::NoneAvailable) => {
147                let resp = EmbedResponse::Error {
148                    id: resolved.id.clone(),
149                    code: EmbedErrorCode::BackendUnavailable,
150                    message: "no embed-capable backend available".into(),
151                };
152                write_response_embed(&writer, &resp).await?;
153                continue;
154            }
155        };
156        let backend_name = dispatch.name.clone();
157        let backend = dispatch.backend;
158
159        let req_id = resolved.id.clone();
160        let n_inputs = resolved.input.len();
161
162        let result = backend.embed(resolved).await;
163        match result {
164            Ok(out) => {
165                let usage = out.usage;
166                let dimensions = out.dimensions;
167                let frame = EmbedResponse::Embeddings {
168                    id: req_id.clone(),
169                    embeddings: out.embeddings,
170                    dimensions,
171                    model: out.model,
172                    usage,
173                    backend: backend_name.clone(),
174                };
175                write_response_embed(&writer, &frame).await?;
176                router.record_success(&backend_name);
177                info!(
178                    target: "inferd_daemon::activity",
179                    req_id = %req_id,
180                    backend = %backend_name,
181                    wire_version = "embed",
182                    n_inputs = n_inputs,
183                    input_tokens = usage.input_tokens,
184                    dimensions = dimensions,
185                    "embed_request_done"
186                );
187            }
188            Err(e) => {
189                let (code, message, is_backend_failure) = match e {
190                    EmbedError::InvalidRequest(m) => (EmbedErrorCode::InvalidRequest, m, false),
191                    EmbedError::NotReady => (
192                        EmbedErrorCode::BackendUnavailable,
193                        "backend not ready".into(),
194                        true,
195                    ),
196                    EmbedError::Unavailable(m) => (EmbedErrorCode::BackendUnavailable, m, true),
197                    EmbedError::Unsupported => (
198                        EmbedErrorCode::EmbedUnsupported,
199                        "embed not supported by this backend".into(),
200                        false,
201                    ),
202                    EmbedError::Internal(m) => (EmbedErrorCode::Internal, m, true),
203                };
204                if is_backend_failure {
205                    router.record_failure(&backend_name);
206                }
207                let frame = EmbedResponse::Error {
208                    id: req_id,
209                    code,
210                    message,
211                };
212                write_response_embed(&writer, &frame).await?;
213            }
214        }
215    }
216}
217
218fn error_code_for(e: &ProtoError) -> EmbedErrorCode {
219    match e {
220        ProtoError::FrameTooLarge => EmbedErrorCode::FrameTooLarge,
221        ProtoError::Decode(_) | ProtoError::InvalidRequest(_) => EmbedErrorCode::InvalidRequest,
222        ProtoError::Io(_) => EmbedErrorCode::Internal,
223    }
224}
225
226async fn read_auth_frame<R>(reader: &mut R) -> Option<AuthFrame>
227where
228    R: tokio::io::AsyncBufRead + Unpin,
229{
230    use tokio::io::AsyncBufReadExt;
231    let mut line = Vec::with_capacity(256);
232    let limit = inferd_proto::MAX_FRAME_BYTES;
233    loop {
234        let buf = reader.fill_buf().await.ok()?;
235        if buf.is_empty() {
236            return None;
237        }
238        if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
239            if line.len() + idx > limit {
240                return None;
241            }
242            line.extend_from_slice(&buf[..idx]);
243            reader.consume(idx + 1);
244            return AuthFrame::from_json(&line);
245        }
246        if line.len() + buf.len() > limit {
247            return None;
248        }
249        line.extend_from_slice(buf);
250        let n = buf.len();
251        reader.consume(n);
252    }
253}
254
255async fn read_request_embed<R>(reader: &mut R) -> Result<Option<EmbedRequest>, ProtoError>
256where
257    R: tokio::io::AsyncBufRead + Unpin,
258{
259    use tokio::io::AsyncBufReadExt;
260    let mut line = Vec::with_capacity(512);
261    let limit = inferd_proto::MAX_FRAME_BYTES;
262    loop {
263        let buf = reader.fill_buf().await?;
264        if buf.is_empty() {
265            if line.is_empty() {
266                return Ok(None);
267            }
268            return inferd_proto::read_frame::<&[u8], EmbedRequest>(&mut &line[..]);
269        }
270        if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
271            if line.len() + idx > limit {
272                return Err(ProtoError::FrameTooLarge);
273            }
274            line.extend_from_slice(&buf[..=idx]);
275            reader.consume(idx + 1);
276            return inferd_proto::read_frame::<&[u8], EmbedRequest>(&mut &line[..]);
277        }
278        if line.len() + buf.len() > limit {
279            return Err(ProtoError::FrameTooLarge);
280        }
281        line.extend_from_slice(buf);
282        let n = buf.len();
283        reader.consume(n);
284    }
285}
286
287async fn write_response_embed<W: AsyncWrite + Unpin>(
288    writer: &Mutex<W>,
289    resp: &EmbedResponse,
290) -> io::Result<()> {
291    let mut buf = Vec::with_capacity(512);
292    write_frame(&mut buf, resp)
293        .map_err(|e| io::Error::other(format!("serialise embed response: {e}")))?;
294    let mut guard = writer.lock().await;
295    guard.write_all(&buf).await?;
296    guard.flush().await?;
297    Ok(())
298}
299
300/// Serve an embed TCP listener.
301pub async fn serve_tcp_embed(
302    listener: tokio::net::TcpListener,
303    router: Arc<Router>,
304    ctx: AcceptContext,
305    mut shutdown: tokio::sync::oneshot::Receiver<()>,
306) -> io::Result<()> {
307    info!(addr = ?listener.local_addr()?, "embed tcp listener accepting");
308    loop {
309        tokio::select! {
310            _ = &mut shutdown => {
311                info!("embed tcp shutdown signalled");
312                return Ok(());
313            }
314            accept = listener.accept() => {
315                let (stream, peer_addr) = accept?;
316                let peer = PeerIdentity::from_tcp(peer_addr);
317                let r = Arc::clone(&router);
318                let ctx = ctx.clone();
319                debug!(?peer_addr, "embed tcp accept");
320                tokio::spawn(async move {
321                    if let Err(e) = handle_embed_connection(stream, r, peer, ctx).await {
322                        warn!(error = ?e, "embed connection terminated with error");
323                    }
324                });
325            }
326        }
327    }
328}
329
330/// Serve an embed Unix domain socket listener.
331#[cfg(unix)]
332pub async fn serve_uds_embed(
333    listener: tokio::net::UnixListener,
334    router: Arc<Router>,
335    ctx: AcceptContext,
336    mut shutdown: tokio::sync::oneshot::Receiver<()>,
337) -> io::Result<()> {
338    info!("embed uds listener accepting");
339    loop {
340        tokio::select! {
341            _ = &mut shutdown => {
342                info!("embed uds shutdown signalled");
343                return Ok(());
344            }
345            accept = listener.accept() => {
346                let (stream, _) = accept?;
347                let r = Arc::clone(&router);
348                let peer = crate::peercred::unix::from_stream(&stream)
349                    .unwrap_or_else(|e| {
350                        warn!(error = %e, "embed SO_PEERCRED failed; recording empty unix identity");
351                        crate::peercred::PeerIdentity {
352                            uid: None, gid: None, pid: None,
353                            sid: None, remote_addr: None,
354                            transport: "unix",
355                        }
356                    });
357                let ctx = ctx.clone();
358                debug!(?peer, "embed uds accept");
359                tokio::spawn(async move {
360                    if let Err(e) = handle_embed_connection(stream, r, peer, ctx).await {
361                        warn!(error = ?e, "embed connection terminated with error");
362                    }
363                });
364            }
365        }
366    }
367}
368
/// Accept loop for the embed Windows named pipe listener.
///
/// `first_instance` is the already-created pipe instance to wait on first.
/// Each time a client connects, a fresh instance is bound at `path` (via
/// `bind_named_pipe(path, false)`) before the connected one is handed off,
/// and the loop then waits on the fresh instance.
///
/// The peer identity is looked up with
/// `crate::peercred::windows::from_stream`; if that lookup fails, a warning
/// is logged and an identity with every field unset (transport `"pipe"`) is
/// used instead. The connected instance is served by
/// [`handle_embed_connection`] on its own spawned task. Errors from an
/// individual connection are logged and do not stop the listener.
///
/// Returns `Ok(())` as soon as the `shutdown` future completes.
///
/// # Errors
///
/// Returns the [`io::Error`] from a failed `connect` or from binding the
/// next pipe instance.
#[cfg(windows)]
pub async fn serve_named_pipe_embed(
    path: &str,
    first_instance: tokio::net::windows::named_pipe::NamedPipeServer,
    router: Arc<Router>,
    ctx: AcceptContext,
    mut shutdown: tokio::sync::oneshot::Receiver<()>,
) -> io::Result<()> {
    use crate::endpoint::bind_named_pipe;

    info!(path = %path, "embed named pipe listener accepting");
    let mut server = first_instance;
    loop {
        tokio::select! {
            _ = &mut shutdown => {
                info!("embed named pipe shutdown signalled");
                return Ok(());
            }
            connect_result = server.connect() => {
                connect_result?;
                // Swap in the next listening instance and keep the one that
                // just connected for this client.
                let connected = std::mem::replace(&mut server, bind_named_pipe(path, false)?);

                let peer = match crate::peercred::windows::from_stream(&connected) {
                    Ok(identity) => identity,
                    Err(e) => {
                        warn!(error = %e, "embed GetNamedPipeClientProcessId failed; empty pipe identity");
                        crate::peercred::PeerIdentity {
                            uid: None,
                            gid: None,
                            pid: None,
                            sid: None,
                            remote_addr: None,
                            transport: "pipe",
                        }
                    }
                };
                let conn_router = Arc::clone(&router);
                let conn_ctx = ctx.clone();
                debug!(?peer, "embed named pipe accept");
                tokio::spawn(async move {
                    let outcome =
                        handle_embed_connection(connected, conn_router, peer, conn_ctx).await;
                    if let Err(e) = outcome {
                        warn!(error = ?e, "embed connection terminated with error");
                    }
                });
            }
        }
    }
}