use crate::auth::{AuthFrame, key_matches};
use crate::endpoint::Connection;
use crate::peercred::PeerIdentity;
use crate::queue::{Admission, SubmitError};
use crate::router::{Router, RouterError};
use inferd_engine::{GenerateError, TokenEvent};
use inferd_proto::{ErrorCode, ProtoError, Request, Response, write_frame};
use std::io;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncWrite, AsyncWriteExt, BufReader};
use tokio::sync::Mutex;
use tokio_stream::StreamExt;
use tracing::{debug, info, warn};
/// Polls `router` until every backend reports ready or `timeout` elapses.
///
/// Readiness is sampled via `Router::all_ready` at most every 50 ms. The
/// sleep between samples is clamped to the time remaining, so the function
/// never waits past the deadline even when `timeout` is shorter than (or not
/// a multiple of) the poll interval.
///
/// # Returns
/// The time spent waiting, measured from the call, once the router is ready.
///
/// # Errors
/// [`ReadyTimeout`] carrying the requested `timeout` if the router was still
/// not ready when the deadline passed.
pub async fn wait_for_ready(router: &Router, timeout: Duration) -> Result<Duration, ReadyTimeout> {
    const POLL_INTERVAL: Duration = Duration::from_millis(50);

    let started = Instant::now();
    loop {
        if router.all_ready() {
            return Ok(started.elapsed());
        }
        let elapsed = started.elapsed();
        if elapsed >= timeout {
            return Err(ReadyTimeout(timeout));
        }
        // `elapsed < timeout` here, so the subtraction cannot underflow.
        // Clamping keeps the last sleep from overshooting the deadline.
        tokio::time::sleep(POLL_INTERVAL.min(timeout - elapsed)).await;
    }
}
/// Error returned by [`wait_for_ready`] when the router did not become ready
/// in time. The wrapped value is the timeout that was requested, and it is
/// what the error message prints.
#[derive(Debug, thiserror::Error)]
#[error("backend not ready within {0:?}")]
pub struct ReadyTimeout(pub Duration);
/// Per-listener settings handed (by clone) to every accepted connection.
///
/// `Debug` is implemented by hand for this type so that the API key itself is
/// never printed; only whether one is configured.
#[derive(Clone, Default)]
pub struct AcceptContext {
/// API key a client must present in its first frame. It is only consulted
/// by `handle_connection` when the connection's transport is `"tcp"`;
/// `None` disables the check.
pub expected_api_key: Option<String>,
/// Admission gate asked (`try_admit`) once per request before dispatch;
/// `None` means requests are dispatched without an admission step.
pub admission: Option<Admission>,
}
/// Redacting `Debug`: reports only *whether* an API key is configured and the
/// admission capacity (if any), never the key material itself.
impl std::fmt::Debug for AcceptContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_configured = self.expected_api_key.is_some();
        let capacity = self.admission.as_ref().map(|gate| gate.capacity());

        let mut out = f.debug_struct("AcceptContext");
        out.field("expected_api_key", &key_configured);
        out.field("admission_capacity", &capacity);
        out.finish()
    }
}
pub async fn handle_connection<C: Connection + 'static>(
mut conn: C,
router: Arc<Router>,
peer: PeerIdentity,
ctx: AcceptContext,
) -> Result<(), io::Error> {
let transport = conn.transport();
info!(
target: "inferd_daemon::activity",
transport = transport,
peer = %peer,
peer_uid = peer.uid,
peer_pid = peer.pid,
peer_sid = peer.sid.as_deref(),
"connection_accepted"
);
let (read_half, write_half) = tokio::io::split(&mut conn);
let mut reader = BufReader::with_capacity(64 * 1024, read_half);
let writer = Arc::new(Mutex::new(write_half));
if transport == "tcp"
&& let Some(expected) = ctx.expected_api_key.as_deref()
{
match read_auth_frame(&mut reader).await {
Some(frame) if key_matches(&frame.key, expected) => {
debug!(transport, "tcp auth ok");
}
_ => {
warn!(
target: "inferd_daemon::activity",
peer = %peer,
"tcp_auth_rejected"
);
return Ok(());
}
}
}
loop {
let request: Request = match read_frame_async(&mut reader).await {
Ok(Some(r)) => r,
Ok(None) => return Ok(()), Err(ProtoError::Io(e)) => return Err(e),
Err(e) => {
let resp = Response::Error {
id: String::new(),
code: e.to_error_code(),
message: e.to_string(),
};
write_response(&writer, &resp).await?;
return Ok(());
}
};
let id = request.id.clone();
let resolved = match request.resolve() {
Ok(r) => r,
Err(e) => {
let resp = Response::Error {
id,
code: ErrorCode::InvalidRequest,
message: e.to_string(),
};
write_response(&writer, &resp).await?;
continue;
}
};
let _admit_permit = match ctx.admission.as_ref().map(|a| a.try_admit()) {
None => None,
Some(Ok(p)) => Some(p),
Some(Err(SubmitError::QueueFull)) => {
let resp = Response::Error {
id: resolved.id.clone(),
code: ErrorCode::QueueFull,
message: "queue full".into(),
};
write_response(&writer, &resp).await?;
continue;
}
Some(Err(SubmitError::Closed)) => {
let resp = Response::Error {
id: resolved.id.clone(),
code: ErrorCode::BackendUnavailable,
message: "admission closed".into(),
};
write_response(&writer, &resp).await?;
return Ok(());
}
};
let dispatch = match router.dispatch() {
Ok(d) => d,
Err(RouterError::NoBackends) | Err(RouterError::NoneAvailable) => {
let resp = Response::Error {
id: resolved.id.clone(),
code: ErrorCode::BackendUnavailable,
message: "no backend available".into(),
};
write_response(&writer, &resp).await?;
continue;
}
};
let backend_name = dispatch.name.clone();
let backend = dispatch.backend;
let req_id = resolved.id.clone();
let mut stream = match backend.generate(resolved).await {
Ok(s) => s,
Err(e) => {
let (code, message, is_backend_failure) = match e {
GenerateError::InvalidRequest(m) => (ErrorCode::InvalidRequest, m, false),
GenerateError::NotReady => (
ErrorCode::BackendUnavailable,
"backend not ready".into(),
true,
),
GenerateError::Unavailable(m) => (ErrorCode::BackendUnavailable, m, true),
GenerateError::Internal(m) => (ErrorCode::Internal, m, true),
};
if is_backend_failure {
router.record_failure(&backend_name);
}
let resp = Response::Error {
id: req_id,
code,
message,
};
write_response(&writer, &resp).await?;
continue;
}
};
let mut full = String::new();
let mut terminal_emitted = false;
while let Some(ev) = stream.next().await {
match ev {
TokenEvent::Token(text) => {
let frame = Response::Token {
id: req_id.clone(),
content: text.clone(),
};
write_response(&writer, &frame).await?;
full.push_str(&text);
}
TokenEvent::Done { stop_reason, usage } => {
let frame = Response::Done {
id: req_id.clone(),
content: std::mem::take(&mut full),
usage,
stop_reason,
backend: backend_name.clone(),
};
write_response(&writer, &frame).await?;
info!(
target: "inferd_daemon::activity",
req_id = %req_id,
backend = %backend_name,
stop_reason = ?stop_reason,
prompt_tokens = usage.prompt_tokens,
completion_tokens = usage.completion_tokens,
"request_done"
);
router.record_success(&backend_name);
terminal_emitted = true;
break;
}
}
}
if !terminal_emitted {
warn!(
target: "inferd_daemon::activity",
req_id = %req_id,
backend = %backend_name,
"request_error_mid_stream"
);
router.record_failure(&backend_name);
let frame = Response::Error {
id: req_id,
code: ErrorCode::BackendUnavailable,
message: "backend ended stream without terminal frame".into(),
};
write_response(&writer, &frame).await?;
}
}
}
/// Reads one newline-terminated line from `reader` and parses it as an
/// [`AuthFrame`].
///
/// Returns `None` when the stream ends before a newline is seen, when a read
/// fails, when the line would exceed `inferd_proto::MAX_FRAME_BYTES`, or when
/// `AuthFrame::from_json` rejects the bytes. The terminating `\n` is consumed
/// but not passed to the parser.
async fn read_auth_frame<R>(reader: &mut R) -> Option<AuthFrame>
where
    R: tokio::io::AsyncBufRead + Unpin,
{
    use tokio::io::AsyncBufReadExt;

    let max_len = inferd_proto::MAX_FRAME_BYTES;
    let mut acc = Vec::with_capacity(256);
    loop {
        let chunk = reader.fill_buf().await.ok()?;
        if chunk.is_empty() {
            // EOF before the line was terminated.
            return None;
        }

        let newline_at = chunk.iter().position(|&byte| byte == b'\n');
        // Payload bytes in this chunk: up to the newline, or all of it.
        let take = newline_at.unwrap_or(chunk.len());
        if acc.len() + take > max_len {
            return None;
        }
        acc.extend_from_slice(&chunk[..take]);

        match newline_at {
            Some(at) => {
                // Consume the newline too so the next read starts after it.
                reader.consume(at + 1);
                return AuthFrame::from_json(&acc);
            }
            None => reader.consume(take),
        }
    }
}
/// Reads one newline-delimited request frame from `reader`.
///
/// Bytes are accumulated up to and including the `\n`, then decoded by
/// `inferd_proto::read_frame`.
///
/// # Returns
/// * `Ok(None)` when the stream is at EOF with nothing buffered.
/// * Otherwise whatever `inferd_proto::read_frame` yields for the buffered
///   line. At EOF with a partial, unterminated line, that partial data is
///   still handed to the decoder.
///
/// # Errors
/// * `ProtoError::FrameTooLarge` when the line (excluding its newline) would
///   exceed `inferd_proto::MAX_FRAME_BYTES`.
/// * Any read error from `reader`, converted into `ProtoError`.
async fn read_frame_async<R>(reader: &mut R) -> Result<Option<Request>, ProtoError>
where
    R: tokio::io::AsyncBufRead + Unpin,
{
    use tokio::io::AsyncBufReadExt;

    let max_len = inferd_proto::MAX_FRAME_BYTES;
    let mut acc = Vec::with_capacity(512);
    loop {
        let chunk = reader.fill_buf().await?;
        if chunk.is_empty() {
            if acc.is_empty() {
                return Ok(None);
            }
            return inferd_proto::read_frame::<&[u8], Request>(&mut &acc[..]);
        }

        // `payload` counts bytes against the size limit (newline excluded);
        // `consumed` is how many bytes leave the reader (newline included).
        let (payload, consumed, complete) = match chunk.iter().position(|&byte| byte == b'\n') {
            Some(at) => (at, at + 1, true),
            None => (chunk.len(), chunk.len(), false),
        };
        if acc.len() + payload > max_len {
            return Err(ProtoError::FrameTooLarge);
        }
        acc.extend_from_slice(&chunk[..consumed]);
        reader.consume(consumed);

        if complete {
            return inferd_proto::read_frame::<&[u8], Request>(&mut &acc[..]);
        }
    }
}
/// Serialises `resp` into a frame and writes it to the shared writer.
///
/// The frame is fully encoded into a local buffer before the lock is taken,
/// so the mutex is held only for the write and the flush.
///
/// # Errors
/// An `io::Error` if serialisation fails (message prefixed with
/// `serialise response:`), or the error from the underlying write or flush.
async fn write_response<W: AsyncWrite + Unpin>(
    writer: &Mutex<W>,
    resp: &Response,
) -> io::Result<()> {
    let mut encoded = Vec::with_capacity(512);
    if let Err(err) = write_frame(&mut encoded, resp) {
        return Err(io::Error::other(format!("serialise response: {err}")));
    }

    let mut sink = writer.lock().await;
    sink.write_all(&encoded).await?;
    sink.flush().await
}
/// Accepts TCP connections until `shutdown` resolves, handling each one on
/// its own spawned task via [`handle_connection`].
///
/// Spawned tasks are detached (their join handles are dropped), so this
/// function returns on shutdown without waiting for in-flight connections.
///
/// A failure that concerns only the connection being accepted
/// (`ConnectionAborted`, `ConnectionReset`, `ConnectionRefused`) is logged
/// and the listener keeps accepting. One client resetting during the
/// handshake must not take the whole listener down.
///
/// # Errors
/// Returns the error from `listener.local_addr()`, or any other
/// `listener.accept()` error.
pub async fn serve_tcp(
    listener: tokio::net::TcpListener,
    router: Arc<Router>,
    ctx: AcceptContext,
    mut shutdown: tokio::sync::oneshot::Receiver<()>,
) -> io::Result<()> {
    info!(addr = ?listener.local_addr()?, "tcp listener accepting");
    loop {
        tokio::select! {
            _ = &mut shutdown => {
                info!("shutdown signalled");
                return Ok(());
            }
            accept = listener.accept() => {
                let (stream, peer_addr) = match accept {
                    Ok(accepted) => accepted,
                    // Errors already pending on the new socket surface from
                    // accept(); they belong to that one peer, not the listener.
                    Err(e) if matches!(
                        e.kind(),
                        io::ErrorKind::ConnectionAborted
                            | io::ErrorKind::ConnectionReset
                            | io::ErrorKind::ConnectionRefused
                    ) => {
                        warn!(error = %e, "tcp accept failed for a single connection; continuing");
                        continue;
                    }
                    Err(e) => return Err(e),
                };
                let r = Arc::clone(&router);
                let peer = PeerIdentity::from_tcp(peer_addr);
                let ctx = ctx.clone();
                debug!(?peer_addr, "tcp accept");
                tokio::spawn(async move {
                    if let Err(e) = handle_connection(stream, r, peer, ctx).await {
                        warn!(error = ?e, "connection terminated with error");
                    }
                });
            }
        }
    }
}
/// Accepts Unix-domain-socket connections until `shutdown` resolves, handling
/// each one on its own spawned task via [`handle_connection`].
///
/// The peer identity comes from `crate::peercred::unix::from_stream`. If that
/// lookup fails, the failure is logged and an identity with every field
/// `None` and transport `"unix"` is used instead, so the connection is still
/// served.
///
/// # Errors
/// Returns the first error produced by `listener.accept()`.
#[cfg(unix)]
pub async fn serve_uds(
    listener: tokio::net::UnixListener,
    router: Arc<Router>,
    ctx: AcceptContext,
    mut shutdown: tokio::sync::oneshot::Receiver<()>,
) -> io::Result<()> {
    info!("uds listener accepting");
    loop {
        tokio::select! {
            _ = &mut shutdown => {
                info!("shutdown signalled");
                return Ok(());
            }
            accepted = listener.accept() => {
                let (stream, _) = accepted?;
                let conn_router = Arc::clone(&router);
                let peer = match crate::peercred::unix::from_stream(&stream) {
                    Ok(identity) => identity,
                    Err(err) => {
                        warn!(error = %err, "SO_PEERCRED failed; recording empty unix identity");
                        PeerIdentity {
                            uid: None,
                            gid: None,
                            pid: None,
                            sid: None,
                            remote_addr: None,
                            transport: "unix",
                        }
                    }
                };
                let conn_ctx = ctx.clone();
                debug!(?peer, "uds accept");
                tokio::spawn(async move {
                    let outcome = handle_connection(stream, conn_router, peer, conn_ctx).await;
                    if let Err(err) = outcome {
                        warn!(error = ?err, "connection terminated with error");
                    }
                });
            }
        }
    }
}
/// Accepts Windows named-pipe clients on `path` until `shutdown` resolves,
/// handling each one on its own spawned task via [`handle_connection`].
///
/// `first_instance` is the already-created pipe instance to wait on first.
/// Each time a client connects, the connected instance is handed to the
/// connection task and a fresh instance is bound with
/// `bind_named_pipe(path, false)` to wait for the next client.
///
/// The peer identity comes from `crate::peercred::windows::from_stream`. If
/// that lookup fails, the failure is logged and an identity with every field
/// `None` and transport `"pipe"` is used instead.
///
/// # Errors
/// Returns the first error from waiting for a client (`connect`) or from
/// binding the replacement pipe instance.
#[cfg(windows)]
pub async fn serve_named_pipe(
    path: &str,
    first_instance: tokio::net::windows::named_pipe::NamedPipeServer,
    router: Arc<Router>,
    ctx: AcceptContext,
    mut shutdown: tokio::sync::oneshot::Receiver<()>,
) -> io::Result<()> {
    use crate::endpoint::bind_named_pipe;

    info!(path = %path, "named pipe listener accepting");
    let mut server = first_instance;
    loop {
        tokio::select! {
            _ = &mut shutdown => {
                info!("shutdown signalled");
                return Ok(());
            }
            connect_outcome = server.connect() => {
                connect_outcome?;
                // Hand off the connected instance and immediately stand up
                // the next one so further clients have something to connect to.
                let connected = server;
                server = bind_named_pipe(path, false)?;
                let peer = match crate::peercred::windows::from_stream(&connected) {
                    Ok(identity) => identity,
                    Err(err) => {
                        warn!(error = %err, "GetNamedPipeClientProcessId failed; empty pipe identity");
                        PeerIdentity {
                            uid: None,
                            gid: None,
                            pid: None,
                            sid: None,
                            remote_addr: None,
                            transport: "pipe",
                        }
                    }
                };
                let conn_router = Arc::clone(&router);
                let conn_ctx = ctx.clone();
                debug!(?peer, "named pipe accept");
                tokio::spawn(async move {
                    let outcome = handle_connection(connected, conn_router, peer, conn_ctx).await;
                    if let Err(err) = outcome {
                        warn!(error = ?err, "connection terminated with error");
                    }
                });
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_engine::mock::Mock;

    /// With a backend that is ready from the start, the wait returns well
    /// inside the first poll window.
    #[tokio::test]
    async fn wait_for_ready_returns_when_already_ready() {
        let backend = Arc::new(Mock::new());
        let router = Router::new(vec![backend]);

        let waited = wait_for_ready(&router, Duration::from_secs(1)).await;

        assert!(waited.unwrap() < Duration::from_millis(100));
    }

    /// A backend that never becomes ready produces a `ReadyTimeout` whose
    /// message mentions "not ready".
    #[tokio::test]
    async fn wait_for_ready_times_out_when_not_ready() {
        let backend = Arc::new(Mock::new());
        backend.set_ready(false);
        let router = Router::new(vec![backend]);

        let outcome = wait_for_ready(&router, Duration::from_millis(100)).await;

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not ready"));
    }

    /// A backend flipped to ready after 150 ms is picked up by a later poll,
    /// and the reported wait reflects that delay.
    #[tokio::test]
    async fn wait_for_ready_succeeds_after_delayed_ready() {
        let backend = Arc::new(Mock::new());
        backend.set_ready(false);
        let flipper = Arc::clone(&backend);
        let router = Router::new(vec![backend]);

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(150)).await;
            flipper.set_ready(true);
        });

        let waited = wait_for_ready(&router, Duration::from_secs(1)).await;

        assert!(waited.unwrap() >= Duration::from_millis(100));
    }
}