// inferd_client/v2_client.rs

//! v2 inference-socket client. NDJSON over UDS / pipe / loopback TCP.
//!
//! Spec: ADR 0015. v2 lives on a *separate* socket from v1 — clients
//! pick a transport variant per platform with `dial_uds`,
//! `dial_pipe`, or `dial_tcp`, mirroring `Client`'s shape but using
//! the v2 wire types (`RequestV2` / `ResponseV2`). The framing
//! contract is identical to v1's: 64 MiB cap, `\n`-delimited NDJSON,
//! and a terminal `done` or `error` frame ends the stream.

10use crate::client::ClientError;
11use inferd_proto::v2::{RequestV2, ResponseV2};
12#[cfg(unix)]
13use std::path::Path;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
17use tokio::net::TcpStream;
18use tokio::sync::Mutex;
19use tokio_stream::Stream;
20
21/// Stream of `ResponseV2` frames yielded by `ClientV2::generate`.
22pub type FrameStreamV2 = Pin<Box<dyn Stream<Item = Result<ResponseV2, ClientError>> + Send>>;
23
24/// v2 inference-socket client.
25///
26/// Construct via `dial_tcp` / `dial_uds` (Unix) / `dial_pipe`
27/// (Windows). Wrap with [`crate::dial_and_wait_ready`] to retry
28/// connect during daemon bring-up — the retry helper is generic over
29/// the client type so the same wait logic serves v1 and v2.
30pub struct ClientV2 {
31    inner: Arc<Mutex<Inner>>,
32}
33
34impl std::fmt::Debug for ClientV2 {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("ClientV2").finish_non_exhaustive()
37    }
38}
39
40struct Inner {
41    write: Box<dyn AsyncWrite + Send + Unpin>,
42    read: BufReader<Box<dyn AsyncRead + Send + Unpin>>,
43}
44
45impl ClientV2 {
46    /// Open a TCP connection to `addr` (e.g. `"127.0.0.1:47322"`).
47    /// v2 TCP transport is opt-in (config-flagged off by default);
48    /// see ADR 0015 §Endpoints for the configured port convention.
49    pub async fn dial_tcp(addr: &str) -> Result<Self, ClientError> {
50        let stream = TcpStream::connect(addr).await?;
51        let (read, write) = stream.into_split();
52        Ok(Self::wrap(Box::new(read), Box::new(write)))
53    }
54
55    /// Open a Unix domain socket connection (Unix only). Default v2
56    /// path: `${XDG_RUNTIME_DIR}/inferd/infer.v2.sock` on Linux,
57    /// `${TMPDIR}/inferd/infer.v2.sock` on macOS.
58    #[cfg(unix)]
59    pub async fn dial_uds(path: &Path) -> Result<Self, ClientError> {
60        let stream = tokio::net::UnixStream::connect(path).await?;
61        let (read, write) = stream.into_split();
62        Ok(Self::wrap(Box::new(read), Box::new(write)))
63    }
64
65    /// Open a Windows named pipe connection (Windows only). Default
66    /// v2 path: `\\.\pipe\inferd-infer-v2`.
67    #[cfg(windows)]
68    pub async fn dial_pipe(path: &str) -> Result<Self, ClientError> {
69        use tokio::net::windows::named_pipe::ClientOptions;
70        let pipe = ClientOptions::new().open(path)?;
71        let (read, write) = tokio::io::split(pipe);
72        Ok(Self::wrap(Box::new(read), Box::new(write)))
73    }
74
75    fn wrap(
76        read: Box<dyn AsyncRead + Send + Unpin>,
77        write: Box<dyn AsyncWrite + Send + Unpin>,
78    ) -> Self {
79        Self {
80            inner: Arc::new(Mutex::new(Inner {
81                write,
82                read: BufReader::with_capacity(64 * 1024, read),
83            })),
84        }
85    }
86
87    /// Test-only constructor: build a `ClientV2` from arbitrary
88    /// `AsyncRead`/`AsyncWrite` halves. Lets sibling-module tests
89    /// stub the transport with `tokio::io::duplex`. Not part of the
90    /// public API.
91    #[doc(hidden)]
92    pub fn wrap_for_test(
93        read: Box<dyn AsyncRead + Send + Unpin>,
94        write: Box<dyn AsyncWrite + Send + Unpin>,
95    ) -> Self {
96        Self::wrap(read, write)
97    }
98
99    /// Send a `RequestV2` and return a stream of `ResponseV2` frames.
100    /// The stream completes after a terminal `done` or `error` frame
101    /// (or yields `Err(ClientError::UnexpectedEof)` if the daemon
102    /// closes the connection mid-stream).
103    pub async fn generate(&mut self, req: RequestV2) -> Result<FrameStreamV2, ClientError> {
104        let mut buf = Vec::with_capacity(512);
105        serde_json::to_writer(&mut buf, &req)?;
106        buf.push(b'\n');
107
108        {
109            let mut g = self.inner.lock().await;
110            g.write.write_all(&buf).await?;
111            g.write.flush().await?;
112        }
113
114        let inner = Arc::clone(&self.inner);
115        let stream = async_stream::stream! {
116            loop {
117                let mut g = inner.lock().await;
118                let mut line = Vec::with_capacity(512);
119                let n = match g.read.read_until(b'\n', &mut line).await {
120                    Ok(n) => n,
121                    Err(e) => { yield Err(ClientError::Io(e)); return; }
122                };
123                if n == 0 {
124                    yield Err(ClientError::UnexpectedEof);
125                    return;
126                }
127                drop(g);
128
129                match serde_json::from_slice::<ResponseV2>(&line) {
130                    Ok(resp) => {
131                        let terminal = resp.is_terminal();
132                        yield Ok(resp);
133                        if terminal {
134                            return;
135                        }
136                    }
137                    Err(e) => {
138                        yield Err(ClientError::Decode(e));
139                        return;
140                    }
141                }
142            }
143        };
144        Ok(Box::pin(stream))
145    }
146}
147
/// Default path of the daemon's v2 inference endpoint, mirroring the
/// daemon's `endpoint::default_v2_addr`.
///
/// The return type is [`std::path::PathBuf`] on every platform:
/// - on Unix it is a Unix-domain-socket path for `ClientV2::dial_uds`;
/// - on Windows it holds the named-pipe path `\\.\pipe\inferd-infer-v2`,
///   which must be converted to `&str` (e.g. with `to_str()`) for
///   `ClientV2::dial_pipe`.
///
/// Linux fallback chain (same as v1's admin chain). An unset *or empty*
/// variable falls through to the next entry:
/// 1. `${XDG_RUNTIME_DIR}/inferd/infer.v2.sock`
/// 2. `${HOME}/.inferd/run/infer.v2.sock`
/// 3. `/tmp/inferd/infer.v2.sock`
///
/// Other platforms:
/// - macOS uses `std::env::temp_dir()` joined with
///   `inferd/infer.v2.sock`;
/// - every remaining non-Windows target uses `/tmp/inferd/infer.v2.sock`.
pub fn default_v2_addr() -> std::path::PathBuf {
    #[cfg(target_os = "linux")]
    {
        linux_v2_addr(std::env::var_os("XDG_RUNTIME_DIR"), std::env::var_os("HOME"))
    }
    #[cfg(target_os = "macos")]
    {
        std::env::temp_dir().join("inferd").join("infer.v2.sock")
    }
    #[cfg(windows)]
    {
        std::path::PathBuf::from(r"\\.\pipe\inferd-infer-v2")
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        std::path::PathBuf::from("/tmp/inferd/infer.v2.sock")
    }
}

/// Resolve the Linux v2 socket path from explicit `XDG_RUNTIME_DIR` and
/// `HOME` values. It is split out of `default_v2_addr` so the fallback
/// chain can be tested without mutating the process environment.
#[cfg(target_os = "linux")]
fn linux_v2_addr(
    xdg_runtime_dir: Option<std::ffi::OsString>,
    home: Option<std::ffi::OsString>,
) -> std::path::PathBuf {
    // An empty variable is treated exactly like an unset one.
    let non_empty = |value: Option<std::ffi::OsString>| value.filter(|v| !v.is_empty());

    if let Some(dir) = non_empty(xdg_runtime_dir) {
        return std::path::Path::new(&dir).join("inferd").join("infer.v2.sock");
    }
    if let Some(home) = non_empty(home) {
        return std::path::Path::new(&home).join(".inferd").join("run").join("infer.v2.sock");
    }
    std::path::PathBuf::from("/tmp/inferd/infer.v2.sock")
}
194
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_proto::v2::{
        ContentBlock, ErrorCodeV2, MessageV2, ResponseBlock, RoleV2, StopReasonV2, UsageV2,
    };
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

    fn sample_request() -> RequestV2 {
        RequestV2 {
            id: "v2-test".into(),
            messages: vec![MessageV2 {
                role: RoleV2::User,
                content: vec![ContentBlock::Text {
                    text: "hello".into(),
                }],
            }],
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn generate_streams_frame_then_done() {
        let (server_side, client_side) = tokio::io::duplex(4096);
        let (read, write) = tokio::io::split(client_side);
        let mut client = ClientV2::wrap(Box::new(read), Box::new(write));

        let server = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(server_side);
            let mut br = tokio::io::BufReader::new(rx);
            let mut req_line = Vec::new();
            br.read_until(b'\n', &mut req_line).await.unwrap();

            let frame = serde_json::to_vec(&ResponseV2::Frame {
                id: "v2-test".into(),
                block: ResponseBlock::Text { delta: "hi".into() },
            })
            .unwrap();
            tx.write_all(&frame).await.unwrap();
            tx.write_all(b"\n").await.unwrap();

            let done = serde_json::to_vec(&ResponseV2::Done {
                id: "v2-test".into(),
                usage: UsageV2 {
                    input_tokens: 1,
                    output_tokens: 1,
                },
                stop_reason: StopReasonV2::EndTurn,
                backend: "mock".into(),
            })
            .unwrap();
            tx.write_all(&done).await.unwrap();
            tx.write_all(b"\n").await.unwrap();
        });

        let stream = client.generate(sample_request()).await.unwrap();
        use tokio_stream::StreamExt;
        let frames: Vec<_> = stream.collect().await;
        server.await.unwrap();

        assert_eq!(frames.len(), 2);
        match frames[0].as_ref().unwrap() {
            ResponseV2::Frame {
                block: ResponseBlock::Text { delta },
                ..
            } => assert_eq!(delta, "hi"),
            other => panic!("frame[0]: {other:?}"),
        }
        match frames[1].as_ref().unwrap() {
            ResponseV2::Done {
                backend,
                stop_reason,
                ..
            } => {
                assert_eq!(backend, "mock");
                assert_eq!(*stop_reason, StopReasonV2::EndTurn);
            }
            other => panic!("frame[1]: {other:?}"),
        }
    }

    #[tokio::test]
    async fn unexpected_eof_yields_clienterror() {
        let (server_side, client_side) = tokio::io::duplex(4096);
        let (read, write) = tokio::io::split(client_side);
        let mut client = ClientV2::wrap(Box::new(read), Box::new(write));

        let server = tokio::spawn(async move {
            let (rx, _tx) = tokio::io::split(server_side);
            let mut br = tokio::io::BufReader::new(rx);
            let mut req_line = Vec::new();
            br.read_until(b'\n', &mut req_line).await.unwrap();
            // Both halves of server_side drop here -> EOF on the client.
        });

        let mut stream = client.generate(sample_request()).await.unwrap();
        use tokio_stream::StreamExt;
        let first = stream.next().await.unwrap();
        server.await.unwrap();
        match first {
            Err(ClientError::UnexpectedEof) => {}
            other => panic!("expected UnexpectedEof, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn malformed_frame_yields_decode_error_and_ends_stream() {
        let (server_side, client_side) = tokio::io::duplex(4096);
        let (read, write) = tokio::io::split(client_side);
        let mut client = ClientV2::wrap(Box::new(read), Box::new(write));

        let server = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(server_side);
            let mut br = tokio::io::BufReader::new(rx);
            let mut req_line = Vec::new();
            br.read_until(b'\n', &mut req_line).await.unwrap();
            tx.write_all(b"this is not json\n").await.unwrap();
        });

        let stream = client.generate(sample_request()).await.unwrap();
        use tokio_stream::StreamExt;
        let frames: Vec<_> = stream.collect().await;
        server.await.unwrap();

        assert_eq!(frames.len(), 1, "stream must end after a decode error");
        match &frames[0] {
            Err(ClientError::Decode(_)) => {}
            other => panic!("expected Decode, got {other:?}"),
        }
    }

    #[test]
    fn error_v2_round_trips() {
        let frame = ResponseV2::Error {
            id: "x".into(),
            code: ErrorCodeV2::AttachmentUnsupported,
            message: "no audio".into(),
        };
        let s = serde_json::to_string(&frame).unwrap();
        assert!(s.contains(r#""code":"attachment_unsupported""#));

        // Decode what was just encoded so the test lives up to its name.
        match serde_json::from_str::<ResponseV2>(&s).unwrap() {
            ResponseV2::Error {
                id,
                code: ErrorCodeV2::AttachmentUnsupported,
                message,
            } => {
                assert_eq!(id, "x");
                assert_eq!(message, "no audio");
            }
            other => panic!("decoded to {other:?}"),
        }
    }
}