1use crate::auth::{AuthFrame, key_matches};
19use crate::endpoint::Connection;
20use crate::peercred::PeerIdentity;
21use crate::queue::{Admission, SubmitError};
22use crate::router::{Router, RouterError};
23use inferd_engine::{GenerateError, TokenEvent};
24use inferd_proto::{ErrorCode, ProtoError, Request, Response, write_frame};
25use std::io;
26use std::sync::Arc;
27use std::time::{Duration, Instant};
28use tokio::io::{AsyncWrite, AsyncWriteExt, BufReader};
29use tokio::sync::Mutex;
30use tokio_stream::StreamExt;
31use tracing::{debug, info, warn};
32
33pub async fn wait_for_ready(router: &Router, timeout: Duration) -> Result<Duration, ReadyTimeout> {
38 let started = Instant::now();
39 let poll = Duration::from_millis(50);
40 loop {
41 if router.all_ready() {
42 return Ok(started.elapsed());
43 }
44 if started.elapsed() >= timeout {
45 return Err(ReadyTimeout(timeout));
46 }
47 tokio::time::sleep(poll).await;
48 }
49}
50
51#[derive(Debug, thiserror::Error)]
54#[error("backend not ready within {0:?}")]
55pub struct ReadyTimeout(pub Duration);
56
57#[derive(Clone, Default)]
65pub struct AcceptContext {
66 pub expected_api_key: Option<String>,
71 pub admission: Option<Admission>,
75}
76
77impl std::fmt::Debug for AcceptContext {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("AcceptContext")
80 .field("expected_api_key", &self.expected_api_key.is_some())
81 .field(
82 "admission_capacity",
83 &self.admission.as_ref().map(|a| a.capacity()),
84 )
85 .finish()
86 }
87}
88
89pub async fn handle_connection<C: Connection + 'static>(
103 mut conn: C,
104 router: Arc<Router>,
105 peer: PeerIdentity,
106 ctx: AcceptContext,
107) -> Result<(), io::Error> {
108 let transport = conn.transport();
109 info!(
110 target: "inferd_daemon::activity",
111 transport = transport,
112 peer = %peer,
113 peer_uid = peer.uid,
114 peer_pid = peer.pid,
115 peer_sid = peer.sid.as_deref(),
116 "connection_accepted"
117 );
118
119 let (read_half, write_half) = tokio::io::split(&mut conn);
124 let mut reader = BufReader::with_capacity(64 * 1024, read_half);
125 let writer = Arc::new(Mutex::new(write_half));
126
127 if transport == "tcp"
131 && let Some(expected) = ctx.expected_api_key.as_deref()
132 {
133 match read_auth_frame(&mut reader).await {
134 Some(frame) if key_matches(&frame.key, expected) => {
135 debug!(transport, "tcp auth ok");
136 }
137 _ => {
138 warn!(
139 target: "inferd_daemon::activity",
140 peer = %peer,
141 "tcp_auth_rejected"
142 );
143 return Ok(());
144 }
145 }
146 }
147
148 loop {
149 let request: Request = match read_frame_async(&mut reader).await {
153 Ok(Some(r)) => r,
154 Ok(None) => return Ok(()), Err(ProtoError::Io(e)) => return Err(e),
156 Err(e) => {
157 let resp = Response::Error {
158 id: String::new(),
159 code: e.to_error_code(),
160 message: e.to_string(),
161 };
162 write_response(&writer, &resp).await?;
163 return Ok(());
164 }
165 };
166
167 let id = request.id.clone();
169 let resolved = match request.resolve() {
170 Ok(r) => r,
171 Err(e) => {
172 let resp = Response::Error {
173 id,
174 code: ErrorCode::InvalidRequest,
175 message: e.to_string(),
176 };
177 write_response(&writer, &resp).await?;
178 continue;
179 }
180 };
181
182 let _admit_permit = match ctx.admission.as_ref().map(|a| a.try_admit()) {
188 None => None,
189 Some(Ok(p)) => Some(p),
190 Some(Err(SubmitError::QueueFull)) => {
191 let resp = Response::Error {
192 id: resolved.id.clone(),
193 code: ErrorCode::QueueFull,
194 message: "queue full".into(),
195 };
196 write_response(&writer, &resp).await?;
197 continue;
198 }
199 Some(Err(SubmitError::Closed)) => {
200 let resp = Response::Error {
204 id: resolved.id.clone(),
205 code: ErrorCode::BackendUnavailable,
206 message: "admission closed".into(),
207 };
208 write_response(&writer, &resp).await?;
209 return Ok(());
210 }
211 };
212
213 let dispatch = match router.dispatch() {
215 Ok(d) => d,
216 Err(RouterError::NoBackends) | Err(RouterError::NoneAvailable) => {
217 let resp = Response::Error {
218 id: resolved.id.clone(),
219 code: ErrorCode::BackendUnavailable,
220 message: "no backend available".into(),
221 };
222 write_response(&writer, &resp).await?;
223 continue;
224 }
225 };
226 let backend_name = dispatch.name.clone();
227 let backend = dispatch.backend;
228 let req_id = resolved.id.clone();
229
230 let mut stream = match backend.generate(resolved).await {
234 Ok(s) => s,
235 Err(e) => {
236 let (code, message, is_backend_failure) = match e {
237 GenerateError::InvalidRequest(m) => (ErrorCode::InvalidRequest, m, false),
238 GenerateError::NotReady => (
239 ErrorCode::BackendUnavailable,
240 "backend not ready".into(),
241 true,
242 ),
243 GenerateError::Unavailable(m) => (ErrorCode::BackendUnavailable, m, true),
244 GenerateError::Internal(m) => (ErrorCode::Internal, m, true),
245 };
246 if is_backend_failure {
247 router.record_failure(&backend_name);
248 }
249 let resp = Response::Error {
250 id: req_id,
251 code,
252 message,
253 };
254 write_response(&writer, &resp).await?;
255 continue;
256 }
257 };
258
259 let mut full = String::new();
262 let mut terminal_emitted = false;
263 while let Some(ev) = stream.next().await {
264 match ev {
265 TokenEvent::Token(text) => {
266 let frame = Response::Token {
267 id: req_id.clone(),
268 content: text.clone(),
269 };
270 write_response(&writer, &frame).await?;
271 full.push_str(&text);
272 }
273 TokenEvent::Done { stop_reason, usage } => {
274 let frame = Response::Done {
275 id: req_id.clone(),
276 content: std::mem::take(&mut full),
277 usage,
278 stop_reason,
279 backend: backend_name.clone(),
280 };
281 write_response(&writer, &frame).await?;
282 info!(
283 target: "inferd_daemon::activity",
284 req_id = %req_id,
285 backend = %backend_name,
286 stop_reason = ?stop_reason,
287 prompt_tokens = usage.prompt_tokens,
288 completion_tokens = usage.completion_tokens,
289 "request_done"
290 );
291 router.record_success(&backend_name);
292 terminal_emitted = true;
293 break;
294 }
295 }
296 }
297
298 if !terminal_emitted {
299 warn!(
303 target: "inferd_daemon::activity",
304 req_id = %req_id,
305 backend = %backend_name,
306 "request_error_mid_stream"
307 );
308 router.record_failure(&backend_name);
309 let frame = Response::Error {
310 id: req_id,
311 code: ErrorCode::BackendUnavailable,
312 message: "backend ended stream without terminal frame".into(),
313 };
314 write_response(&writer, &frame).await?;
315 }
316 }
317}
318
319async fn read_auth_frame<R>(reader: &mut R) -> Option<AuthFrame>
327where
328 R: tokio::io::AsyncBufRead + Unpin,
329{
330 use tokio::io::AsyncBufReadExt;
331 let mut line = Vec::with_capacity(256);
332 let limit = inferd_proto::MAX_FRAME_BYTES;
333 loop {
334 let buf = reader.fill_buf().await.ok()?;
335 if buf.is_empty() {
336 return None;
337 }
338 if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
339 if line.len() + idx > limit {
340 return None;
341 }
342 line.extend_from_slice(&buf[..idx]);
343 reader.consume(idx + 1);
344 return AuthFrame::from_json(&line);
345 }
346 if line.len() + buf.len() > limit {
347 return None;
348 }
349 line.extend_from_slice(buf);
350 let n = buf.len();
351 reader.consume(n);
352 }
353}
354
355async fn read_frame_async<R>(reader: &mut R) -> Result<Option<Request>, ProtoError>
363where
364 R: tokio::io::AsyncBufRead + Unpin,
365{
366 use tokio::io::AsyncBufReadExt;
367 let mut line = Vec::with_capacity(512);
368 let limit = inferd_proto::MAX_FRAME_BYTES;
369 loop {
370 let buf = reader.fill_buf().await?;
371 if buf.is_empty() {
372 if line.is_empty() {
373 return Ok(None);
374 }
375 return inferd_proto::read_frame::<&[u8], Request>(&mut &line[..]);
379 }
380 if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
381 if line.len() + idx > limit {
382 return Err(ProtoError::FrameTooLarge);
383 }
384 line.extend_from_slice(&buf[..=idx]);
385 reader.consume(idx + 1);
386 return inferd_proto::read_frame::<&[u8], Request>(&mut &line[..]);
387 }
388 if line.len() + buf.len() > limit {
389 return Err(ProtoError::FrameTooLarge);
390 }
391 line.extend_from_slice(buf);
392 let n = buf.len();
393 reader.consume(n);
394 }
395}
396
397async fn write_response<W: AsyncWrite + Unpin>(
398 writer: &Mutex<W>,
399 resp: &Response,
400) -> io::Result<()> {
401 let mut buf = Vec::with_capacity(512);
402 write_frame(&mut buf, resp)
403 .map_err(|e| io::Error::other(format!("serialise response: {e}")))?;
404 let mut guard = writer.lock().await;
405 guard.write_all(&buf).await?;
406 guard.flush().await?;
407 Ok(())
408}
409
410pub async fn serve_tcp(
416 listener: tokio::net::TcpListener,
417 router: Arc<Router>,
418 ctx: AcceptContext,
419 mut shutdown: tokio::sync::oneshot::Receiver<()>,
420) -> io::Result<()> {
421 info!(addr = ?listener.local_addr()?, "tcp listener accepting");
422 loop {
423 tokio::select! {
424 _ = &mut shutdown => {
425 info!("shutdown signalled");
426 return Ok(());
427 }
428 accept = listener.accept() => {
429 let (stream, peer_addr) = accept?;
430 let r = Arc::clone(&router);
431 let peer = PeerIdentity::from_tcp(peer_addr);
432 let ctx = ctx.clone();
433 debug!(?peer_addr, "tcp accept");
434 tokio::spawn(async move {
435 if let Err(e) = handle_connection(stream, r, peer, ctx).await {
436 warn!(error = ?e, "connection terminated with error");
437 }
438 });
439 }
440 }
441 }
442}
443
444#[cfg(unix)]
446pub async fn serve_uds(
447 listener: tokio::net::UnixListener,
448 router: Arc<Router>,
449 ctx: AcceptContext,
450 mut shutdown: tokio::sync::oneshot::Receiver<()>,
451) -> io::Result<()> {
452 info!("uds listener accepting");
453 loop {
454 tokio::select! {
455 _ = &mut shutdown => {
456 info!("shutdown signalled");
457 return Ok(());
458 }
459 accept = listener.accept() => {
460 let (stream, _) = accept?;
461 let r = Arc::clone(&router);
462 let peer = crate::peercred::unix::from_stream(&stream)
467 .unwrap_or_else(|e| {
468 warn!(error = %e, "SO_PEERCRED failed; recording empty unix identity");
469 crate::peercred::PeerIdentity {
470 uid: None, gid: None, pid: None,
471 sid: None, remote_addr: None,
472 transport: "unix",
473 }
474 });
475 let ctx = ctx.clone();
476 debug!(?peer, "uds accept");
477 tokio::spawn(async move {
478 if let Err(e) = handle_connection(stream, r, peer, ctx).await {
479 warn!(error = ?e, "connection terminated with error");
480 }
481 });
482 }
483 }
484 }
485}
486
/// Serves clients on the Windows named pipe at `path` until `shutdown` fires.
///
/// `first_instance` is the already-created listening instance. Each time a
/// client connects, a fresh instance is created before the connected one is
/// handed to a [`handle_connection`] task. This leaves an instance listening
/// for the next client.
///
/// A failed connect is logged, and after a short backoff the instance is
/// replaced. The listener is not torn down. An example is a client that
/// disconnects before the server finishes connecting.
///
/// # Errors
///
/// Returns an error if a replacement pipe instance cannot be created.
#[cfg(windows)]
pub async fn serve_named_pipe(
    path: &str,
    first_instance: tokio::net::windows::named_pipe::NamedPipeServer,
    router: Arc<Router>,
    ctx: AcceptContext,
    mut shutdown: tokio::sync::oneshot::Receiver<()>,
) -> io::Result<()> {
    use crate::endpoint::bind_named_pipe;

    // Pause before replacing an instance whose connect failed, so a
    // persistent failure does not become a busy loop.
    const CONNECT_ERROR_BACKOFF: Duration = Duration::from_millis(100);

    info!(path = %path, "named pipe listener accepting");
    let mut server = first_instance;
    loop {
        tokio::select! {
            _ = &mut shutdown => {
                info!("shutdown signalled");
                return Ok(());
            }
            connect_result = server.connect() => {
                if let Err(e) = connect_result {
                    warn!(error = %e, "named pipe connect failed; replacing instance");
                    tokio::time::sleep(CONNECT_ERROR_BACKOFF).await;
                    // The replacement is created before the failed instance
                    // is dropped (the right-hand side is evaluated first).
                    server = bind_named_pipe(path, false)?;
                    continue;
                }
                let connected = server;
                server = bind_named_pipe(path, false)?;

                let peer = crate::peercred::windows::from_stream(&connected)
                    .unwrap_or_else(|e| {
                        warn!(error = %e, "GetNamedPipeClientProcessId failed; empty pipe identity");
                        crate::peercred::PeerIdentity {
                            uid: None, gid: None, pid: None,
                            sid: None, remote_addr: None,
                            transport: "pipe",
                        }
                    });
                let r = Arc::clone(&router);
                let ctx = ctx.clone();
                debug!(?peer, "named pipe accept");
                tokio::spawn(async move {
                    if let Err(e) = handle_connection(connected, r, peer, ctx).await {
                        warn!(error = ?e, "connection terminated with error");
                    }
                });
            }
        }
    }
}
554
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_engine::mock::Mock;

    /// A backend that is ready from the start should be reported ready almost
    /// immediately.
    #[tokio::test]
    async fn wait_for_ready_returns_when_already_ready() {
        let backend = Arc::new(Mock::new());
        let router = Router::new(vec![backend]);

        let waited = wait_for_ready(&router, Duration::from_secs(1))
            .await
            .expect("mock backend starts ready");

        assert!(waited < Duration::from_millis(100));
    }

    /// A backend that never becomes ready must yield the timeout error.
    #[tokio::test]
    async fn wait_for_ready_times_out_when_not_ready() {
        let backend = Arc::new(Mock::new());
        backend.set_ready(false);
        let router = Router::new(vec![backend]);

        let err = wait_for_ready(&router, Duration::from_millis(100))
            .await
            .expect_err("backend never becomes ready");

        assert!(err.to_string().contains("not ready"));
    }

    /// Readiness flipped from another task is picked up by the polling loop.
    #[tokio::test]
    async fn wait_for_ready_succeeds_after_delayed_ready() {
        let backend = Arc::new(Mock::new());
        backend.set_ready(false);
        let router = Router::new(vec![backend.clone()]);

        let flipper = Arc::clone(&backend);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(150)).await;
            flipper.set_ready(true);
        });

        let waited = wait_for_ready(&router, Duration::from_secs(1))
            .await
            .expect("backend becomes ready before the timeout");

        assert!(waited >= Duration::from_millis(100));
    }
}