1use crate::auth::{AuthFrame, key_matches};
24use crate::endpoint::Connection;
25use crate::peercred::PeerIdentity;
26use crate::queue::SubmitError;
27use crate::router::{Router, RouterError};
28use inferd_engine::EmbedError;
29use inferd_proto::ProtoError;
30use inferd_proto::embed::{EmbedErrorCode, EmbedRequest, EmbedResponse};
31use inferd_proto::write_frame;
32use std::io;
33use std::sync::Arc;
34use tokio::io::{AsyncWrite, AsyncWriteExt, BufReader};
35use tokio::sync::Mutex;
36use tracing::{debug, info, warn};
37
38pub use crate::lifecycle::AcceptContext;
41
42pub async fn handle_embed_connection<C: Connection + 'static>(
44 mut conn: C,
45 router: Arc<Router>,
46 peer: PeerIdentity,
47 ctx: AcceptContext,
48) -> Result<(), io::Error> {
49 let transport = conn.transport();
50 info!(
51 target: "inferd_daemon::activity",
52 transport = transport,
53 wire_version = "embed",
54 peer = %peer,
55 peer_uid = peer.uid,
56 peer_pid = peer.pid,
57 peer_sid = peer.sid.as_deref(),
58 "embed_connection_accepted"
59 );
60
61 let (read_half, write_half) = tokio::io::split(&mut conn);
62 let mut reader = BufReader::with_capacity(64 * 1024, read_half);
63 let writer = Arc::new(Mutex::new(write_half));
64
65 if transport == "tcp"
67 && let Some(expected) = ctx.expected_api_key.as_deref()
68 {
69 match read_auth_frame(&mut reader).await {
70 Some(frame) if key_matches(&frame.key, expected) => {
71 debug!(transport, "embed tcp auth ok");
72 }
73 _ => {
74 warn!(
75 target: "inferd_daemon::activity",
76 peer = %peer,
77 wire_version = "embed",
78 "embed_tcp_auth_rejected"
79 );
80 return Ok(());
81 }
82 }
83 }
84
85 loop {
86 let request: EmbedRequest = match read_request_embed(&mut reader).await {
87 Ok(Some(r)) => r,
88 Ok(None) => return Ok(()),
89 Err(ProtoError::Io(e)) => return Err(e),
90 Err(e) => {
91 let resp = EmbedResponse::Error {
92 id: String::new(),
93 code: error_code_for(&e),
94 message: e.to_string(),
95 };
96 write_response_embed(&writer, &resp).await?;
97 return Ok(());
98 }
99 };
100
101 let id = request.id.clone();
102 let resolved = match request.resolve() {
103 Ok(r) => r,
104 Err(e) => {
105 let resp = EmbedResponse::Error {
106 id,
107 code: EmbedErrorCode::InvalidRequest,
108 message: e.to_string(),
109 };
110 write_response_embed(&writer, &resp).await?;
111 continue;
112 }
113 };
114
115 let _admit_permit = match ctx.admission.as_ref().map(|a| a.try_admit()) {
118 None => None,
119 Some(Ok(p)) => Some(p),
120 Some(Err(SubmitError::QueueFull)) => {
121 let resp = EmbedResponse::Error {
122 id: resolved.id.clone(),
123 code: EmbedErrorCode::QueueFull,
124 message: "queue full".into(),
125 };
126 write_response_embed(&writer, &resp).await?;
127 continue;
128 }
129 Some(Err(SubmitError::Closed)) => {
130 let resp = EmbedResponse::Error {
131 id: resolved.id.clone(),
132 code: EmbedErrorCode::BackendUnavailable,
133 message: "admission closed".into(),
134 };
135 write_response_embed(&writer, &resp).await?;
136 return Ok(());
137 }
138 };
139
140 let dispatch = match router.dispatch_embed() {
145 Ok(d) => d,
146 Err(RouterError::NoBackends) | Err(RouterError::NoneAvailable) => {
147 let resp = EmbedResponse::Error {
148 id: resolved.id.clone(),
149 code: EmbedErrorCode::BackendUnavailable,
150 message: "no embed-capable backend available".into(),
151 };
152 write_response_embed(&writer, &resp).await?;
153 continue;
154 }
155 };
156 let backend_name = dispatch.name.clone();
157 let backend = dispatch.backend;
158
159 let req_id = resolved.id.clone();
160 let n_inputs = resolved.input.len();
161
162 let result = backend.embed(resolved).await;
163 match result {
164 Ok(out) => {
165 let usage = out.usage;
166 let dimensions = out.dimensions;
167 let frame = EmbedResponse::Embeddings {
168 id: req_id.clone(),
169 embeddings: out.embeddings,
170 dimensions,
171 model: out.model,
172 usage,
173 backend: backend_name.clone(),
174 };
175 write_response_embed(&writer, &frame).await?;
176 router.record_success(&backend_name);
177 info!(
178 target: "inferd_daemon::activity",
179 req_id = %req_id,
180 backend = %backend_name,
181 wire_version = "embed",
182 n_inputs = n_inputs,
183 input_tokens = usage.input_tokens,
184 dimensions = dimensions,
185 "embed_request_done"
186 );
187 }
188 Err(e) => {
189 let (code, message, is_backend_failure) = match e {
190 EmbedError::InvalidRequest(m) => (EmbedErrorCode::InvalidRequest, m, false),
191 EmbedError::NotReady => (
192 EmbedErrorCode::BackendUnavailable,
193 "backend not ready".into(),
194 true,
195 ),
196 EmbedError::Unavailable(m) => (EmbedErrorCode::BackendUnavailable, m, true),
197 EmbedError::Unsupported => (
198 EmbedErrorCode::EmbedUnsupported,
199 "embed not supported by this backend".into(),
200 false,
201 ),
202 EmbedError::Internal(m) => (EmbedErrorCode::Internal, m, true),
203 };
204 if is_backend_failure {
205 router.record_failure(&backend_name);
206 }
207 let frame = EmbedResponse::Error {
208 id: req_id,
209 code,
210 message,
211 };
212 write_response_embed(&writer, &frame).await?;
213 }
214 }
215 }
216}
217
218fn error_code_for(e: &ProtoError) -> EmbedErrorCode {
219 match e {
220 ProtoError::FrameTooLarge => EmbedErrorCode::FrameTooLarge,
221 ProtoError::Decode(_) | ProtoError::InvalidRequest(_) => EmbedErrorCode::InvalidRequest,
222 ProtoError::Io(_) => EmbedErrorCode::Internal,
223 }
224}
225
226async fn read_auth_frame<R>(reader: &mut R) -> Option<AuthFrame>
227where
228 R: tokio::io::AsyncBufRead + Unpin,
229{
230 use tokio::io::AsyncBufReadExt;
231 let mut line = Vec::with_capacity(256);
232 let limit = inferd_proto::MAX_FRAME_BYTES;
233 loop {
234 let buf = reader.fill_buf().await.ok()?;
235 if buf.is_empty() {
236 return None;
237 }
238 if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
239 if line.len() + idx > limit {
240 return None;
241 }
242 line.extend_from_slice(&buf[..idx]);
243 reader.consume(idx + 1);
244 return AuthFrame::from_json(&line);
245 }
246 if line.len() + buf.len() > limit {
247 return None;
248 }
249 line.extend_from_slice(buf);
250 let n = buf.len();
251 reader.consume(n);
252 }
253}
254
255async fn read_request_embed<R>(reader: &mut R) -> Result<Option<EmbedRequest>, ProtoError>
256where
257 R: tokio::io::AsyncBufRead + Unpin,
258{
259 use tokio::io::AsyncBufReadExt;
260 let mut line = Vec::with_capacity(512);
261 let limit = inferd_proto::MAX_FRAME_BYTES;
262 loop {
263 let buf = reader.fill_buf().await?;
264 if buf.is_empty() {
265 if line.is_empty() {
266 return Ok(None);
267 }
268 return inferd_proto::read_frame::<&[u8], EmbedRequest>(&mut &line[..]);
269 }
270 if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
271 if line.len() + idx > limit {
272 return Err(ProtoError::FrameTooLarge);
273 }
274 line.extend_from_slice(&buf[..=idx]);
275 reader.consume(idx + 1);
276 return inferd_proto::read_frame::<&[u8], EmbedRequest>(&mut &line[..]);
277 }
278 if line.len() + buf.len() > limit {
279 return Err(ProtoError::FrameTooLarge);
280 }
281 line.extend_from_slice(buf);
282 let n = buf.len();
283 reader.consume(n);
284 }
285}
286
287async fn write_response_embed<W: AsyncWrite + Unpin>(
288 writer: &Mutex<W>,
289 resp: &EmbedResponse,
290) -> io::Result<()> {
291 let mut buf = Vec::with_capacity(512);
292 write_frame(&mut buf, resp)
293 .map_err(|e| io::Error::other(format!("serialise embed response: {e}")))?;
294 let mut guard = writer.lock().await;
295 guard.write_all(&buf).await?;
296 guard.flush().await?;
297 Ok(())
298}
299
300pub async fn serve_tcp_embed(
302 listener: tokio::net::TcpListener,
303 router: Arc<Router>,
304 ctx: AcceptContext,
305 mut shutdown: tokio::sync::oneshot::Receiver<()>,
306) -> io::Result<()> {
307 info!(addr = ?listener.local_addr()?, "embed tcp listener accepting");
308 loop {
309 tokio::select! {
310 _ = &mut shutdown => {
311 info!("embed tcp shutdown signalled");
312 return Ok(());
313 }
314 accept = listener.accept() => {
315 let (stream, peer_addr) = accept?;
316 let peer = PeerIdentity::from_tcp(peer_addr);
317 let r = Arc::clone(&router);
318 let ctx = ctx.clone();
319 debug!(?peer_addr, "embed tcp accept");
320 tokio::spawn(async move {
321 if let Err(e) = handle_embed_connection(stream, r, peer, ctx).await {
322 warn!(error = ?e, "embed connection terminated with error");
323 }
324 });
325 }
326 }
327 }
328}
329
330#[cfg(unix)]
332pub async fn serve_uds_embed(
333 listener: tokio::net::UnixListener,
334 router: Arc<Router>,
335 ctx: AcceptContext,
336 mut shutdown: tokio::sync::oneshot::Receiver<()>,
337) -> io::Result<()> {
338 info!("embed uds listener accepting");
339 loop {
340 tokio::select! {
341 _ = &mut shutdown => {
342 info!("embed uds shutdown signalled");
343 return Ok(());
344 }
345 accept = listener.accept() => {
346 let (stream, _) = accept?;
347 let r = Arc::clone(&router);
348 let peer = crate::peercred::unix::from_stream(&stream)
349 .unwrap_or_else(|e| {
350 warn!(error = %e, "embed SO_PEERCRED failed; recording empty unix identity");
351 crate::peercred::PeerIdentity {
352 uid: None, gid: None, pid: None,
353 sid: None, remote_addr: None,
354 transport: "unix",
355 }
356 });
357 let ctx = ctx.clone();
358 debug!(?peer, "embed uds accept");
359 tokio::spawn(async move {
360 if let Err(e) = handle_embed_connection(stream, r, peer, ctx).await {
361 warn!(error = ?e, "embed connection terminated with error");
362 }
363 });
364 }
365 }
366 }
367}
368
/// Accept loop for the embed protocol over a Windows named pipe.
///
/// Starts with `first_instance` as the listening pipe. Each time a client
/// connects, the connected instance is handed to a spawned
/// [`handle_embed_connection`] task and a fresh instance is bound at `path`
/// (via `bind_named_pipe(path, false)`) to wait for the next client. If the
/// peer identity lookup fails, a warning is logged and an identity with every
/// field unset (transport `"pipe"`) is used. The loop ends with `Ok(())` once
/// `shutdown` resolves.
///
/// # Errors
///
/// Returns any error from `connect()` or from binding the next pipe instance.
#[cfg(windows)]
pub async fn serve_named_pipe_embed(
    path: &str,
    first_instance: tokio::net::windows::named_pipe::NamedPipeServer,
    router: Arc<Router>,
    ctx: AcceptContext,
    mut shutdown: tokio::sync::oneshot::Receiver<()>,
) -> io::Result<()> {
    use crate::endpoint::bind_named_pipe;

    info!(path = %path, "embed named pipe listener accepting");
    let mut listening = first_instance;
    loop {
        // Wait for either a client on the current instance or shutdown.
        tokio::select! {
            _ = &mut shutdown => {
                info!("embed named pipe shutdown signalled");
                return Ok(());
            }
            outcome = listening.connect() => outcome?,
        }

        // The connected instance now belongs to this client; bind a new
        // instance so the next client has something to connect to.
        let connected = listening;
        listening = bind_named_pipe(path, false)?;

        let peer = match crate::peercred::windows::from_stream(&connected) {
            Ok(identity) => identity,
            Err(e) => {
                warn!(error = %e, "embed GetNamedPipeClientProcessId failed; empty pipe identity");
                crate::peercred::PeerIdentity {
                    uid: None,
                    gid: None,
                    pid: None,
                    sid: None,
                    remote_addr: None,
                    transport: "pipe",
                }
            }
        };
        let conn_router = Arc::clone(&router);
        let conn_ctx = ctx.clone();
        debug!(?peer, "embed named pipe accept");

        // One detached task per connection; failures are only logged.
        tokio::spawn(async move {
            let served = handle_embed_connection(connected, conn_router, peer, conn_ctx).await;
            if let Err(e) = served {
                warn!(error = ?e, "embed connection terminated with error");
            }
        });
    }
}