1use crate::auth::{AuthFrame, key_matches};
24use crate::endpoint::Connection;
25use crate::peercred::PeerIdentity;
26use crate::queue::SubmitError;
27use crate::router::{Router, RouterError};
28use inferd_engine::EmbedError;
29use inferd_proto::ProtoError;
30use inferd_proto::embed::{EmbedErrorCode, EmbedRequest, EmbedResponse};
31use inferd_proto::write_frame;
32use std::io;
33use std::sync::Arc;
34use tokio::io::{AsyncWrite, AsyncWriteExt, BufReader};
35use tokio::sync::Mutex;
36use tracing::{debug, info, warn};
37
38pub use crate::lifecycle::AcceptContext;
41
42pub async fn handle_embed_connection<C: Connection + 'static>(
44 mut conn: C,
45 router: Arc<Router>,
46 peer: PeerIdentity,
47 ctx: AcceptContext,
48) -> Result<(), io::Error> {
49 let transport = conn.transport();
50 info!(
51 target: "inferd_daemon::activity",
52 transport = transport,
53 wire_version = "embed",
54 peer = %peer,
55 peer_uid = peer.uid,
56 peer_pid = peer.pid,
57 peer_sid = peer.sid.as_deref(),
58 "embed_connection_accepted"
59 );
60
61 let (read_half, write_half) = tokio::io::split(&mut conn);
62 let mut reader = BufReader::with_capacity(64 * 1024, read_half);
63 let writer = Arc::new(Mutex::new(write_half));
64
65 if transport == "tcp"
67 && let Some(expected) = ctx.expected_api_key.as_deref()
68 {
69 match read_auth_frame(&mut reader).await {
70 Some(frame) if key_matches(&frame.key, expected) => {
71 debug!(transport, "embed tcp auth ok");
72 }
73 _ => {
74 warn!(
75 target: "inferd_daemon::activity",
76 peer = %peer,
77 wire_version = "embed",
78 "embed_tcp_auth_rejected"
79 );
80 return Ok(());
81 }
82 }
83 }
84
85 loop {
86 let request: EmbedRequest = match read_request_embed(&mut reader).await {
87 Ok(Some(r)) => r,
88 Ok(None) => return Ok(()),
89 Err(ProtoError::Io(e)) => return Err(e),
90 Err(e) => {
91 let resp = EmbedResponse::Error {
92 id: String::new(),
93 code: error_code_for(&e),
94 message: e.to_string(),
95 };
96 write_response_embed(&writer, &resp).await?;
97 return Ok(());
98 }
99 };
100
101 let id = request.id.clone();
102 let resolved = match request.resolve() {
103 Ok(r) => r,
104 Err(e) => {
105 let resp = EmbedResponse::Error {
106 id,
107 code: EmbedErrorCode::InvalidRequest,
108 message: e.to_string(),
109 };
110 write_response_embed(&writer, &resp).await?;
111 continue;
112 }
113 };
114
115 let _admit_permit = match ctx.admission.as_ref().map(|a| a.try_admit()) {
118 None => None,
119 Some(Ok(p)) => Some(p),
120 Some(Err(SubmitError::QueueFull)) => {
121 let resp = EmbedResponse::Error {
122 id: resolved.id.clone(),
123 code: EmbedErrorCode::QueueFull,
124 message: "queue full".into(),
125 };
126 write_response_embed(&writer, &resp).await?;
127 continue;
128 }
129 Some(Err(SubmitError::Closed)) => {
130 let resp = EmbedResponse::Error {
131 id: resolved.id.clone(),
132 code: EmbedErrorCode::BackendUnavailable,
133 message: "admission closed".into(),
134 };
135 write_response_embed(&writer, &resp).await?;
136 return Ok(());
137 }
138 };
139
140 let dispatch = match router.dispatch() {
142 Ok(d) => d,
143 Err(RouterError::NoBackends) | Err(RouterError::NoneAvailable) => {
144 let resp = EmbedResponse::Error {
145 id: resolved.id.clone(),
146 code: EmbedErrorCode::BackendUnavailable,
147 message: "no backend available".into(),
148 };
149 write_response_embed(&writer, &resp).await?;
150 continue;
151 }
152 };
153 let backend_name = dispatch.name.clone();
154 let backend = dispatch.backend;
155
156 if !backend.capabilities().embed {
160 let resp = EmbedResponse::Error {
161 id: resolved.id.clone(),
162 code: EmbedErrorCode::EmbedUnsupported,
163 message: format!("backend {backend_name:?} does not support embeddings"),
164 };
165 write_response_embed(&writer, &resp).await?;
166 continue;
167 }
168
169 let req_id = resolved.id.clone();
170 let n_inputs = resolved.input.len();
171
172 let result = backend.embed(resolved).await;
173 match result {
174 Ok(out) => {
175 let usage = out.usage;
176 let dimensions = out.dimensions;
177 let frame = EmbedResponse::Embeddings {
178 id: req_id.clone(),
179 embeddings: out.embeddings,
180 dimensions,
181 model: out.model,
182 usage,
183 backend: backend_name.clone(),
184 };
185 write_response_embed(&writer, &frame).await?;
186 router.record_success(&backend_name);
187 info!(
188 target: "inferd_daemon::activity",
189 req_id = %req_id,
190 backend = %backend_name,
191 wire_version = "embed",
192 n_inputs = n_inputs,
193 input_tokens = usage.input_tokens,
194 dimensions = dimensions,
195 "embed_request_done"
196 );
197 }
198 Err(e) => {
199 let (code, message, is_backend_failure) = match e {
200 EmbedError::InvalidRequest(m) => (EmbedErrorCode::InvalidRequest, m, false),
201 EmbedError::NotReady => (
202 EmbedErrorCode::BackendUnavailable,
203 "backend not ready".into(),
204 true,
205 ),
206 EmbedError::Unavailable(m) => (EmbedErrorCode::BackendUnavailable, m, true),
207 EmbedError::Unsupported => (
208 EmbedErrorCode::EmbedUnsupported,
209 "embed not supported by this backend".into(),
210 false,
211 ),
212 EmbedError::Internal(m) => (EmbedErrorCode::Internal, m, true),
213 };
214 if is_backend_failure {
215 router.record_failure(&backend_name);
216 }
217 let frame = EmbedResponse::Error {
218 id: req_id,
219 code,
220 message,
221 };
222 write_response_embed(&writer, &frame).await?;
223 }
224 }
225 }
226}
227
228fn error_code_for(e: &ProtoError) -> EmbedErrorCode {
229 match e {
230 ProtoError::FrameTooLarge => EmbedErrorCode::FrameTooLarge,
231 ProtoError::Decode(_) | ProtoError::InvalidRequest(_) => EmbedErrorCode::InvalidRequest,
232 ProtoError::Io(_) => EmbedErrorCode::Internal,
233 }
234}
235
236async fn read_auth_frame<R>(reader: &mut R) -> Option<AuthFrame>
237where
238 R: tokio::io::AsyncBufRead + Unpin,
239{
240 use tokio::io::AsyncBufReadExt;
241 let mut line = Vec::with_capacity(256);
242 let limit = inferd_proto::MAX_FRAME_BYTES;
243 loop {
244 let buf = reader.fill_buf().await.ok()?;
245 if buf.is_empty() {
246 return None;
247 }
248 if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
249 if line.len() + idx > limit {
250 return None;
251 }
252 line.extend_from_slice(&buf[..idx]);
253 reader.consume(idx + 1);
254 return AuthFrame::from_json(&line);
255 }
256 if line.len() + buf.len() > limit {
257 return None;
258 }
259 line.extend_from_slice(buf);
260 let n = buf.len();
261 reader.consume(n);
262 }
263}
264
265async fn read_request_embed<R>(reader: &mut R) -> Result<Option<EmbedRequest>, ProtoError>
266where
267 R: tokio::io::AsyncBufRead + Unpin,
268{
269 use tokio::io::AsyncBufReadExt;
270 let mut line = Vec::with_capacity(512);
271 let limit = inferd_proto::MAX_FRAME_BYTES;
272 loop {
273 let buf = reader.fill_buf().await?;
274 if buf.is_empty() {
275 if line.is_empty() {
276 return Ok(None);
277 }
278 return inferd_proto::read_frame::<&[u8], EmbedRequest>(&mut &line[..]);
279 }
280 if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
281 if line.len() + idx > limit {
282 return Err(ProtoError::FrameTooLarge);
283 }
284 line.extend_from_slice(&buf[..=idx]);
285 reader.consume(idx + 1);
286 return inferd_proto::read_frame::<&[u8], EmbedRequest>(&mut &line[..]);
287 }
288 if line.len() + buf.len() > limit {
289 return Err(ProtoError::FrameTooLarge);
290 }
291 line.extend_from_slice(buf);
292 let n = buf.len();
293 reader.consume(n);
294 }
295}
296
297async fn write_response_embed<W: AsyncWrite + Unpin>(
298 writer: &Mutex<W>,
299 resp: &EmbedResponse,
300) -> io::Result<()> {
301 let mut buf = Vec::with_capacity(512);
302 write_frame(&mut buf, resp)
303 .map_err(|e| io::Error::other(format!("serialise embed response: {e}")))?;
304 let mut guard = writer.lock().await;
305 guard.write_all(&buf).await?;
306 guard.flush().await?;
307 Ok(())
308}
309
310pub async fn serve_tcp_embed(
312 listener: tokio::net::TcpListener,
313 router: Arc<Router>,
314 ctx: AcceptContext,
315 mut shutdown: tokio::sync::oneshot::Receiver<()>,
316) -> io::Result<()> {
317 info!(addr = ?listener.local_addr()?, "embed tcp listener accepting");
318 loop {
319 tokio::select! {
320 _ = &mut shutdown => {
321 info!("embed tcp shutdown signalled");
322 return Ok(());
323 }
324 accept = listener.accept() => {
325 let (stream, peer_addr) = accept?;
326 let peer = PeerIdentity::from_tcp(peer_addr);
327 let r = Arc::clone(&router);
328 let ctx = ctx.clone();
329 debug!(?peer_addr, "embed tcp accept");
330 tokio::spawn(async move {
331 if let Err(e) = handle_embed_connection(stream, r, peer, ctx).await {
332 warn!(error = ?e, "embed connection terminated with error");
333 }
334 });
335 }
336 }
337 }
338}
339
340#[cfg(unix)]
342pub async fn serve_uds_embed(
343 listener: tokio::net::UnixListener,
344 router: Arc<Router>,
345 ctx: AcceptContext,
346 mut shutdown: tokio::sync::oneshot::Receiver<()>,
347) -> io::Result<()> {
348 info!("embed uds listener accepting");
349 loop {
350 tokio::select! {
351 _ = &mut shutdown => {
352 info!("embed uds shutdown signalled");
353 return Ok(());
354 }
355 accept = listener.accept() => {
356 let (stream, _) = accept?;
357 let r = Arc::clone(&router);
358 let peer = crate::peercred::unix::from_stream(&stream)
359 .unwrap_or_else(|e| {
360 warn!(error = %e, "embed SO_PEERCRED failed; recording empty unix identity");
361 crate::peercred::PeerIdentity {
362 uid: None, gid: None, pid: None,
363 sid: None, remote_addr: None,
364 transport: "unix",
365 }
366 });
367 let ctx = ctx.clone();
368 debug!(?peer, "embed uds accept");
369 tokio::spawn(async move {
370 if let Err(e) = handle_embed_connection(stream, r, peer, ctx).await {
371 warn!(error = ?e, "embed connection terminated with error");
372 }
373 });
374 }
375 }
376 }
377}
378
/// Accept embed-protocol connections on a Windows named pipe until `shutdown` resolves.
///
/// **Instance rotation.** `first_instance` is the already-created server end of
/// the pipe at `path`. After each client connects:
/// 1. The connected instance is moved out.
/// 2. A fresh server instance is bound with `bind_named_pipe(path, false)`.
/// 3. The connected instance is handed to [`handle_embed_connection`] on its
///    own spawned task.
///
/// NOTE(review): the meaning of the `false` flag is defined in
/// `crate::endpoint`; presumably it means "not the first instance". Confirm.
///
/// **Peer identity.** Identity comes from `crate::peercred::windows::from_stream`.
/// If that lookup fails, a warning is logged and an identity with every field
/// unset (transport `"pipe"`) is used instead.
///
/// # Errors
///
/// Returns the error from `server.connect()` or from binding the next pipe
/// instance; either ends the accept loop. Returns `Ok(())` once `shutdown`
/// resolves, either because a value is sent or because the sender is dropped.
#[cfg(windows)]
pub async fn serve_named_pipe_embed(
    path: &str,
    first_instance: tokio::net::windows::named_pipe::NamedPipeServer,
    router: Arc<Router>,
    ctx: AcceptContext,
    mut shutdown: tokio::sync::oneshot::Receiver<()>,
) -> io::Result<()> {
    use crate::endpoint::bind_named_pipe;

    info!(path = %path, "embed named pipe listener accepting");
    // The instance currently waiting for a client; replaced after every connect.
    let mut server = first_instance;
    loop {
        tokio::select! {
            _ = &mut shutdown => {
                info!("embed named pipe shutdown signalled");
                return Ok(());
            }
            connect_result = server.connect() => {
                // NOTE(review): any connect error stops the whole listener,
                // including errors that may concern only one client. Consider
                // whether some should be skipped instead.
                connect_result?;
                let connected = server;
                // Bind the next instance immediately, before handling this
                // client, so a listening instance exists for the next connect.
                // Keep this ordering.
                server = bind_named_pipe(path, false)?;

                let peer = crate::peercred::windows::from_stream(&connected)
                    .unwrap_or_else(|e| {
                        warn!(error = %e, "embed GetNamedPipeClientProcessId failed; empty pipe identity");
                        crate::peercred::PeerIdentity {
                            uid: None, gid: None, pid: None,
                            sid: None, remote_addr: None,
                            transport: "pipe",
                        }
                    });
                let r = Arc::clone(&router);
                let ctx = ctx.clone();
                debug!(?peer, "embed named pipe accept");
                tokio::spawn(async move {
                    if let Err(e) = handle_embed_connection(connected, r, peer, ctx).await {
                        warn!(error = ?e, "embed connection terminated with error");
                    }
                });
            }
        }
    }
}