1use inferd_proto::{Request, Response};
4use std::io;
5#[cfg(unix)]
6use std::path::Path;
7use std::pin::Pin;
8use std::sync::Arc;
9use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
10use tokio::net::TcpStream;
11use tokio::sync::Mutex;
12use tokio_stream::Stream;
13
14#[derive(Debug, thiserror::Error)]
16pub enum ClientError {
17 #[error("io: {0}")]
19 Io(#[from] io::Error),
20 #[error("decode: {0}")]
22 Decode(#[from] serde_json::Error),
23 #[error("daemon closed connection before terminal frame")]
28 UnexpectedEof,
29}
30
31pub type FrameStream = Pin<Box<dyn Stream<Item = Result<Response, ClientError>> + Send>>;
33
34pub struct Client {
40 inner: Arc<Mutex<Inner>>,
41}
42
43impl std::fmt::Debug for Client {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("Client").finish_non_exhaustive()
49 }
50}
51
52struct Inner {
53 write: Box<dyn AsyncWrite + Send + Unpin>,
54 read: BufReader<Box<dyn AsyncRead + Send + Unpin>>,
55}
56
57impl Client {
58 pub async fn dial_tcp(addr: &str) -> Result<Self, ClientError> {
60 let stream = TcpStream::connect(addr).await?;
61 let (read, write) = stream.into_split();
62 Ok(Self::wrap(Box::new(read), Box::new(write)))
63 }
64
65 #[cfg(unix)]
67 pub async fn dial_uds(path: &Path) -> Result<Self, ClientError> {
68 let stream = tokio::net::UnixStream::connect(path).await?;
69 let (read, write) = stream.into_split();
70 Ok(Self::wrap(Box::new(read), Box::new(write)))
71 }
72
73 #[cfg(windows)]
75 pub async fn dial_pipe(path: &str) -> Result<Self, ClientError> {
76 use tokio::net::windows::named_pipe::ClientOptions;
77 let pipe = ClientOptions::new().open(path)?;
78 let (read, write) = tokio::io::split(pipe);
79 Ok(Self::wrap(Box::new(read), Box::new(write)))
80 }
81
82 fn wrap(
83 read: Box<dyn AsyncRead + Send + Unpin>,
84 write: Box<dyn AsyncWrite + Send + Unpin>,
85 ) -> Self {
86 Self {
87 inner: Arc::new(Mutex::new(Inner {
88 write,
89 read: BufReader::with_capacity(64 * 1024, read),
90 })),
91 }
92 }
93
94 #[doc(hidden)]
99 pub fn wrap_for_test(
100 read: Box<dyn AsyncRead + Send + Unpin>,
101 write: Box<dyn AsyncWrite + Send + Unpin>,
102 ) -> Self {
103 Self::wrap(read, write)
104 }
105
106 pub async fn generate(&mut self, req: Request) -> Result<FrameStream, ClientError> {
111 let mut buf = Vec::with_capacity(512);
112 serde_json::to_writer(&mut buf, &req)?;
113 buf.push(b'\n');
114
115 {
116 let mut g = self.inner.lock().await;
117 g.write.write_all(&buf).await?;
118 g.write.flush().await?;
119 }
120
121 let inner = Arc::clone(&self.inner);
122 let stream = async_stream::stream! {
123 loop {
124 let mut g = inner.lock().await;
125 let mut line = Vec::with_capacity(512);
126 let n = match g.read.read_until(b'\n', &mut line).await {
127 Ok(n) => n,
128 Err(e) => { yield Err(ClientError::Io(e)); return; }
129 };
130 if n == 0 {
131 yield Err(ClientError::UnexpectedEof);
132 return;
133 }
134 drop(g);
135
136 match serde_json::from_slice::<Response>(&line) {
137 Ok(resp) => {
138 let terminal = matches!(
139 &resp,
140 Response::Done { .. } | Response::Error { .. }
141 );
142 yield Ok(resp);
143 if terminal {
144 return;
145 }
146 }
147 Err(e) => {
148 yield Err(ClientError::Decode(e));
149 return;
150 }
151 }
152 }
153 };
154 Ok(Box::pin(stream))
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_proto::{ErrorCode, Message, Role, StopReason, Usage};
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
    use tokio_stream::StreamExt;

    /// Minimal single-message request shared by the round-trip tests.
    fn sample_request() -> Request {
        let messages = vec![Message {
            role: Role::User,
            content: "hello".into(),
        }];
        Request {
            id: "test".into(),
            messages,
            temperature: None,
            top_p: None,
            top_k: None,
            max_tokens: None,
            stream: None,
            image_token_budget: None,
            grammar: String::new(),
        }
    }

    /// Builds a `Client` over one end of an in-memory duplex pipe.
    fn client_on(end: tokio::io::DuplexStream) -> Client {
        let (rd, wr) = tokio::io::split(end);
        Client::wrap(Box::new(rd), Box::new(wr))
    }

    /// Encodes `frame` as one newline-terminated JSON line, the way the daemon sends it.
    fn wire_line(frame: &Response) -> Vec<u8> {
        let mut bytes = serde_json::to_vec(frame).unwrap();
        bytes.push(b'\n');
        bytes
    }

    #[tokio::test]
    async fn generate_streams_token_then_done() {
        let (server_end, client_end) = tokio::io::duplex(4096);
        let mut client = client_on(client_end);

        // Fake daemon: consume the request line, then reply with a token and a Done frame.
        let daemon = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(server_end);
            let mut br = tokio::io::BufReader::new(rx);
            let mut request = Vec::new();
            br.read_until(b'\n', &mut request).await.unwrap();

            let token = Response::Token {
                id: "test".into(),
                content: "hi".into(),
            };
            let done = Response::Done {
                id: "test".into(),
                content: "hi".into(),
                usage: Usage {
                    prompt_tokens: 1,
                    completion_tokens: 1,
                },
                stop_reason: StopReason::End,
                backend: "mock".into(),
            };
            for frame in &[token, done] {
                tx.write_all(&wire_line(frame)).await.unwrap();
            }
        });

        let frames: Vec<_> = client
            .generate(sample_request())
            .await
            .unwrap()
            .collect()
            .await;
        daemon.await.unwrap();

        assert_eq!(frames.len(), 2);
        match (frames[0].as_ref().unwrap(), frames[1].as_ref().unwrap()) {
            (Response::Token { content, .. }, Response::Done { backend, .. }) => {
                assert_eq!(content, "hi");
                assert_eq!(backend, "mock");
            }
            other => panic!("unexpected frames: {other:?}"),
        }
    }

    #[tokio::test]
    async fn unexpected_eof_yields_clienterror() {
        let (server_end, client_end) = tokio::io::duplex(4096);
        let mut client = client_on(client_end);

        // Fake daemon: read the request line, then hang up without answering.
        let daemon = tokio::spawn(async move {
            let (rx, _tx) = tokio::io::split(server_end);
            let mut br = tokio::io::BufReader::new(rx);
            let mut request = Vec::new();
            br.read_until(b'\n', &mut request).await.unwrap();
        });

        let mut stream = client.generate(sample_request()).await.unwrap();
        let first = stream.next().await.unwrap();
        daemon.await.unwrap();

        assert!(
            matches!(first, Err(ClientError::UnexpectedEof)),
            "expected UnexpectedEof, got {first:?}"
        );
    }

    #[test]
    fn error_code_round_trips() {
        let json = serde_json::to_string(&Response::Error {
            id: "x".into(),
            code: ErrorCode::QueueFull,
            message: "queue full".into(),
        })
        .unwrap();
        assert!(json.contains(r#""code":"queue_full""#));
    }
}