Skip to main content

inferd_client/
client.rs

1//! Inference-socket client. NDJSON over UDS / pipe / loopback TCP.
2
3use inferd_proto::{Request, Response};
4use std::io;
5#[cfg(unix)]
6use std::path::Path;
7use std::pin::Pin;
8use std::sync::Arc;
9use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
10use tokio::net::TcpStream;
11use tokio::sync::Mutex;
12use tokio_stream::Stream;
13
14/// Errors produced by the inference client.
15#[derive(Debug, thiserror::Error)]
16pub enum ClientError {
17    /// Connection / I/O error against the daemon.
18    #[error("io: {0}")]
19    Io(#[from] io::Error),
20    /// JSON encode/decode of a wire frame failed.
21    #[error("decode: {0}")]
22    Decode(#[from] serde_json::Error),
23    /// Connection closed before a terminal `done` / `error` frame.
24    /// Per `docs/protocol-v1.md`, callers treat this as an error
25    /// equivalent to `code: backend_unavailable` and apply their own
26    /// retry policy.
27    #[error("daemon closed connection before terminal frame")]
28    UnexpectedEof,
29}
30
31/// Stream of `Response` frames yielded by `Client::generate`.
32pub type FrameStream = Pin<Box<dyn Stream<Item = Result<Response, ClientError>> + Send>>;
33
34/// Inference-socket client.
35///
36/// Construct via `dial_tcp` / `dial_uds` (Unix) / `dial_pipe`
37/// (Windows). Wrap with [`crate::dial_and_wait_ready`] to retry
38/// connect during daemon bring-up.
39pub struct Client {
40    inner: Arc<Mutex<Inner>>,
41}
42
43impl std::fmt::Debug for Client {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        // Inner BufReader / write half don't impl Debug usefully;
46        // give callers something to print without leaking the
47        // underlying transport state.
48        f.debug_struct("Client").finish_non_exhaustive()
49    }
50}
51
52struct Inner {
53    write: Box<dyn AsyncWrite + Send + Unpin>,
54    read: BufReader<Box<dyn AsyncRead + Send + Unpin>>,
55}
56
57impl Client {
58    /// Open a TCP connection to `addr` (e.g. `"127.0.0.1:47321"`).
59    pub async fn dial_tcp(addr: &str) -> Result<Self, ClientError> {
60        let stream = TcpStream::connect(addr).await?;
61        let (read, write) = stream.into_split();
62        Ok(Self::wrap(Box::new(read), Box::new(write)))
63    }
64
65    /// Open a Unix domain socket connection (Unix only).
66    #[cfg(unix)]
67    pub async fn dial_uds(path: &Path) -> Result<Self, ClientError> {
68        let stream = tokio::net::UnixStream::connect(path).await?;
69        let (read, write) = stream.into_split();
70        Ok(Self::wrap(Box::new(read), Box::new(write)))
71    }
72
73    /// Open a Windows named pipe connection (Windows only).
74    #[cfg(windows)]
75    pub async fn dial_pipe(path: &str) -> Result<Self, ClientError> {
76        use tokio::net::windows::named_pipe::ClientOptions;
77        let pipe = ClientOptions::new().open(path)?;
78        let (read, write) = tokio::io::split(pipe);
79        Ok(Self::wrap(Box::new(read), Box::new(write)))
80    }
81
82    fn wrap(
83        read: Box<dyn AsyncRead + Send + Unpin>,
84        write: Box<dyn AsyncWrite + Send + Unpin>,
85    ) -> Self {
86        Self {
87            inner: Arc::new(Mutex::new(Inner {
88                write,
89                read: BufReader::with_capacity(64 * 1024, read),
90            })),
91        }
92    }
93
94    /// Test-only constructor: build a `Client` from arbitrary
95    /// `AsyncRead`/`AsyncWrite` halves. Lets sibling-module tests
96    /// stub the transport with `tokio::io::duplex`. Not part of the
97    /// public API.
98    #[doc(hidden)]
99    pub fn wrap_for_test(
100        read: Box<dyn AsyncRead + Send + Unpin>,
101        write: Box<dyn AsyncWrite + Send + Unpin>,
102    ) -> Self {
103        Self::wrap(read, write)
104    }
105
106    /// Send a `Request` and return a stream of response frames.
107    /// The stream completes after a terminal `done` or `error` frame
108    /// (or yields `Err(ClientError::UnexpectedEof)` if the daemon
109    /// closes the connection mid-stream).
110    pub async fn generate(&mut self, req: Request) -> Result<FrameStream, ClientError> {
111        let mut buf = Vec::with_capacity(512);
112        serde_json::to_writer(&mut buf, &req)?;
113        buf.push(b'\n');
114
115        {
116            let mut g = self.inner.lock().await;
117            g.write.write_all(&buf).await?;
118            g.write.flush().await?;
119        }
120
121        let inner = Arc::clone(&self.inner);
122        let stream = async_stream::stream! {
123            loop {
124                let mut g = inner.lock().await;
125                let mut line = Vec::with_capacity(512);
126                let n = match g.read.read_until(b'\n', &mut line).await {
127                    Ok(n) => n,
128                    Err(e) => { yield Err(ClientError::Io(e)); return; }
129                };
130                if n == 0 {
131                    yield Err(ClientError::UnexpectedEof);
132                    return;
133                }
134                drop(g);
135
136                match serde_json::from_slice::<Response>(&line) {
137                    Ok(resp) => {
138                        let terminal = matches!(
139                            &resp,
140                            Response::Done { .. } | Response::Error { .. }
141                        );
142                        yield Ok(resp);
143                        if terminal {
144                            return;
145                        }
146                    }
147                    Err(e) => {
148                        yield Err(ClientError::Decode(e));
149                        return;
150                    }
151                }
152            }
153        };
154        Ok(Box::pin(stream))
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_proto::{ErrorCode, Message, Role, StopReason, Usage};
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
    use tokio_stream::StreamExt;

    /// Minimal single-message request used by every scenario.
    fn sample_request() -> Request {
        let greeting = Message {
            role: Role::User,
            content: "hello".into(),
        };
        Request {
            id: "test".into(),
            messages: vec![greeting],
            temperature: None,
            top_p: None,
            top_k: None,
            max_tokens: None,
            stream: None,
            image_token_budget: None,
            grammar: String::new(),
        }
    }

    /// Build a `Client` on one end of an in-memory duplex pipe and
    /// return the other end for the fake daemon. No real socket is
    /// involved, so the tests stay platform-agnostic.
    fn client_over_duplex() -> (Client, tokio::io::DuplexStream) {
        let (daemon_end, client_end) = tokio::io::duplex(4096);
        let (rd, wr) = tokio::io::split(client_end);
        (Client::wrap(Box::new(rd), Box::new(wr)), daemon_end)
    }

    /// Write `payload` followed by a newline terminator.
    async fn write_line<W>(sink: &mut W, payload: &[u8])
    where
        W: AsyncWrite + Unpin,
    {
        sink.write_all(payload).await.unwrap();
        sink.write_all(b"\n").await.unwrap();
    }

    /// Happy path: the fake daemon answers one request with a Token
    /// frame and then a Done frame; the client stream must yield
    /// exactly those two frames and then finish.
    #[tokio::test]
    async fn generate_streams_token_then_done() {
        let (mut client, daemon_end) = client_over_duplex();

        // Fake daemon: consume one request line, reply Token + Done.
        let daemon = tokio::spawn(async move {
            let (from_client, mut to_client) = tokio::io::split(daemon_end);
            let mut reader = tokio::io::BufReader::new(from_client);
            let mut request_line = Vec::new();
            reader.read_until(b'\n', &mut request_line).await.unwrap();

            let token_frame = serde_json::to_vec(&Response::Token {
                id: "test".into(),
                content: "hi".into(),
            })
            .unwrap();
            write_line(&mut to_client, &token_frame).await;

            let done_frame = serde_json::to_vec(&Response::Done {
                id: "test".into(),
                content: "hi".into(),
                usage: Usage {
                    prompt_tokens: 1,
                    completion_tokens: 1,
                },
                stop_reason: StopReason::End,
                backend: "mock".into(),
            })
            .unwrap();
            write_line(&mut to_client, &done_frame).await;
        });

        let stream = client.generate(sample_request()).await.unwrap();
        let frames: Vec<_> = stream.collect().await;
        daemon.await.unwrap();

        assert_eq!(frames.len(), 2);

        let first = frames[0].as_ref().unwrap();
        match first {
            Response::Token { content, .. } => assert_eq!(content, "hi"),
            other => panic!("frame[0]: {other:?}"),
        }

        let second = frames[1].as_ref().unwrap();
        match second {
            Response::Done { backend, .. } => assert_eq!(backend, "mock"),
            other => panic!("frame[1]: {other:?}"),
        }
    }

    /// If the daemon reads the request and then hangs up without
    /// replying, the first stream item must be `UnexpectedEof`.
    #[tokio::test]
    async fn unexpected_eof_yields_clienterror() {
        let (mut client, daemon_end) = client_over_duplex();

        // Fake daemon: consume the request, then drop both halves
        // without writing anything, which the client sees as EOF.
        let daemon = tokio::spawn(async move {
            let (from_client, _to_client) = tokio::io::split(daemon_end);
            let mut reader = tokio::io::BufReader::new(from_client);
            let mut request_line = Vec::new();
            reader.read_until(b'\n', &mut request_line).await.unwrap();
        });

        let mut stream = client.generate(sample_request()).await.unwrap();
        let first = stream.next().await.unwrap();
        daemon.await.unwrap();

        match first {
            Err(ClientError::UnexpectedEof) => {}
            other => panic!("expected UnexpectedEof, got {other:?}"),
        }
    }

    /// Sanity check that an `Error` frame built with the
    /// `inferd_proto` error code serializes with the snake_case code
    /// string expected on the wire.
    #[test]
    fn error_code_round_trips() {
        let error_frame = Response::Error {
            id: "x".into(),
            code: ErrorCode::QueueFull,
            message: "queue full".into(),
        };
        let encoded = serde_json::to_string(&error_frame).unwrap();
        assert!(encoded.contains(r#""code":"queue_full""#));
    }
}