inferd_client/
v2_client.rs1use crate::client::ClientError;
11use inferd_proto::v2::{RequestV2, ResponseV2};
12#[cfg(unix)]
13use std::path::Path;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
17use tokio::net::TcpStream;
18use tokio::sync::Mutex;
19use tokio_stream::Stream;
20
21pub type FrameStreamV2 = Pin<Box<dyn Stream<Item = Result<ResponseV2, ClientError>> + Send>>;
23
24pub struct ClientV2 {
31 inner: Arc<Mutex<Inner>>,
32}
33
34impl std::fmt::Debug for ClientV2 {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("ClientV2").finish_non_exhaustive()
37 }
38}
39
40struct Inner {
41 write: Box<dyn AsyncWrite + Send + Unpin>,
42 read: BufReader<Box<dyn AsyncRead + Send + Unpin>>,
43}
44
45impl ClientV2 {
46 pub async fn dial_tcp(addr: &str) -> Result<Self, ClientError> {
50 let stream = TcpStream::connect(addr).await?;
51 let (read, write) = stream.into_split();
52 Ok(Self::wrap(Box::new(read), Box::new(write)))
53 }
54
55 #[cfg(unix)]
59 pub async fn dial_uds(path: &Path) -> Result<Self, ClientError> {
60 let stream = tokio::net::UnixStream::connect(path).await?;
61 let (read, write) = stream.into_split();
62 Ok(Self::wrap(Box::new(read), Box::new(write)))
63 }
64
65 #[cfg(windows)]
68 pub async fn dial_pipe(path: &str) -> Result<Self, ClientError> {
69 use tokio::net::windows::named_pipe::ClientOptions;
70 let pipe = ClientOptions::new().open(path)?;
71 let (read, write) = tokio::io::split(pipe);
72 Ok(Self::wrap(Box::new(read), Box::new(write)))
73 }
74
75 fn wrap(
76 read: Box<dyn AsyncRead + Send + Unpin>,
77 write: Box<dyn AsyncWrite + Send + Unpin>,
78 ) -> Self {
79 Self {
80 inner: Arc::new(Mutex::new(Inner {
81 write,
82 read: BufReader::with_capacity(64 * 1024, read),
83 })),
84 }
85 }
86
87 #[doc(hidden)]
92 pub fn wrap_for_test(
93 read: Box<dyn AsyncRead + Send + Unpin>,
94 write: Box<dyn AsyncWrite + Send + Unpin>,
95 ) -> Self {
96 Self::wrap(read, write)
97 }
98
99 pub async fn generate(&mut self, req: RequestV2) -> Result<FrameStreamV2, ClientError> {
104 let mut buf = Vec::with_capacity(512);
105 serde_json::to_writer(&mut buf, &req)?;
106 buf.push(b'\n');
107
108 {
109 let mut g = self.inner.lock().await;
110 g.write.write_all(&buf).await?;
111 g.write.flush().await?;
112 }
113
114 let inner = Arc::clone(&self.inner);
115 let stream = async_stream::stream! {
116 loop {
117 let mut g = inner.lock().await;
118 let mut line = Vec::with_capacity(512);
119 let n = match g.read.read_until(b'\n', &mut line).await {
120 Ok(n) => n,
121 Err(e) => { yield Err(ClientError::Io(e)); return; }
122 };
123 if n == 0 {
124 yield Err(ClientError::UnexpectedEof);
125 return;
126 }
127 drop(g);
128
129 match serde_json::from_slice::<ResponseV2>(&line) {
130 Ok(resp) => {
131 let terminal = resp.is_terminal();
132 yield Ok(resp);
133 if terminal {
134 return;
135 }
136 }
137 Err(e) => {
138 yield Err(ClientError::Decode(e));
139 return;
140 }
141 }
142 }
143 };
144 Ok(Box::pin(stream))
145 }
146}
147
/// Returns the platform-specific default endpoint path.
///
/// * Linux: `$XDG_RUNTIME_DIR/inferd/infer.v2.sock` when that variable is set
///   and non-empty, otherwise `$HOME/.inferd/run/infer.v2.sock` when `HOME` is
///   set and non-empty, otherwise `/tmp/inferd/infer.v2.sock`.
/// * macOS: `inferd/infer.v2.sock` under `std::env::temp_dir()`.
/// * Windows: the named pipe `\\.\pipe\inferd-infer-v2`.
/// * Any other target: `/tmp/inferd/infer.v2.sock`.
pub fn default_v2_addr() -> std::path::PathBuf {
    #[cfg(target_os = "linux")]
    {
        // A variable that is set but empty is treated the same as unset.
        let non_empty = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

        if let Some(runtime_dir) = non_empty("XDG_RUNTIME_DIR") {
            return std::path::Path::new(&runtime_dir)
                .join("inferd")
                .join("infer.v2.sock");
        }
        if let Some(home_dir) = non_empty("HOME") {
            return std::path::Path::new(&home_dir)
                .join(".inferd")
                .join("run")
                .join("infer.v2.sock");
        }
        std::path::PathBuf::from("/tmp/inferd/infer.v2.sock")
    }
    #[cfg(target_os = "macos")]
    {
        std::env::temp_dir().join("inferd").join("infer.v2.sock")
    }
    #[cfg(windows)]
    {
        std::path::PathBuf::from(r"\\.\pipe\inferd-infer-v2")
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        std::path::PathBuf::from("/tmp/inferd/infer.v2.sock")
    }
}
194
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_proto::v2::{
        ContentBlock, ErrorCodeV2, MessageV2, ResponseBlock, RoleV2, StopReasonV2, UsageV2,
    };
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
    use tokio_stream::StreamExt;

    /// Minimal single-message user request shared by the tests.
    fn sample_request() -> RequestV2 {
        let greeting = ContentBlock::Text {
            text: "hello".into(),
        };
        let message = MessageV2 {
            role: RoleV2::User,
            content: vec![greeting],
        };
        RequestV2 {
            id: "v2-test".into(),
            messages: vec![message],
            ..Default::default()
        }
    }

    /// Creates an in-memory duplex pipe and returns a client wrapped around
    /// one end together with the raw peer ("server") end.
    fn connected_pair() -> (ClientV2, tokio::io::DuplexStream) {
        let (server_side, client_side) = tokio::io::duplex(4096);
        let (read, write) = tokio::io::split(client_side);
        (ClientV2::wrap(Box::new(read), Box::new(write)), server_side)
    }

    /// Writes `frame` as JSON followed by a newline terminator.
    async fn send_frame<W>(tx: &mut W, frame: &ResponseV2)
    where
        W: tokio::io::AsyncWrite + Unpin,
    {
        let encoded = serde_json::to_vec(frame).unwrap();
        tx.write_all(&encoded).await.unwrap();
        tx.write_all(b"\n").await.unwrap();
    }

    #[tokio::test]
    async fn generate_streams_frame_then_done() {
        let (mut client, server_side) = connected_pair();

        let server = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(server_side);
            let mut incoming = tokio::io::BufReader::new(rx);
            let mut request_line = Vec::new();
            incoming.read_until(b'\n', &mut request_line).await.unwrap();

            let text_frame = ResponseV2::Frame {
                id: "v2-test".into(),
                block: ResponseBlock::Text { delta: "hi".into() },
            };
            send_frame(&mut tx, &text_frame).await;

            let done_frame = ResponseV2::Done {
                id: "v2-test".into(),
                usage: UsageV2 {
                    input_tokens: 1,
                    output_tokens: 1,
                },
                stop_reason: StopReasonV2::EndTurn,
                backend: "mock".into(),
            };
            send_frame(&mut tx, &done_frame).await;
        });

        let stream = client.generate(sample_request()).await.unwrap();
        let frames: Vec<_> = stream.collect().await;
        server.await.unwrap();

        assert_eq!(frames.len(), 2);

        let first = frames[0].as_ref().unwrap();
        match first {
            ResponseV2::Frame {
                block: ResponseBlock::Text { delta },
                ..
            } => assert_eq!(delta, "hi"),
            other => panic!("frame[0]: {other:?}"),
        }

        let second = frames[1].as_ref().unwrap();
        match second {
            ResponseV2::Done {
                backend,
                stop_reason,
                ..
            } => {
                assert_eq!(backend, "mock");
                assert_eq!(*stop_reason, StopReasonV2::EndTurn);
            }
            other => panic!("frame[1]: {other:?}"),
        }
    }

    #[tokio::test]
    async fn unexpected_eof_yields_clienterror() {
        let (mut client, server_side) = connected_pair();

        // The server consumes the request line and then finishes, dropping
        // both halves without sending any frame.
        let server = tokio::spawn(async move {
            let (rx, _tx) = tokio::io::split(server_side);
            let mut incoming = tokio::io::BufReader::new(rx);
            let mut request_line = Vec::new();
            incoming.read_until(b'\n', &mut request_line).await.unwrap();
        });

        let mut stream = client.generate(sample_request()).await.unwrap();
        let first = stream.next().await.unwrap();
        server.await.unwrap();

        match first {
            Err(ClientError::UnexpectedEof) => {}
            other => panic!("expected UnexpectedEof, got {other:?}"),
        }
    }

    #[test]
    fn error_v2_round_trips() {
        let error_frame = ResponseV2::Error {
            id: "x".into(),
            code: ErrorCodeV2::AttachmentUnsupported,
            message: "no audio".into(),
        };
        let json = serde_json::to_string(&error_frame).unwrap();
        assert!(json.contains(r#""code":"attachment_unsupported""#));
    }
}