Skip to main content

inferd_client/
embed_client.rs

1//! Embed-socket client. NDJSON over UDS / pipe / loopback TCP.
2//!
3//! Spec: ADR 0017. The embed socket is the *third* inferd surface
4//! (v1, v2, embed), each on its own path. Construct an `EmbedClient`
5//! with `dial_tcp` / `dial_uds` / `dial_pipe`, then call `embed`
6//! per request. The connection is long-lived: send a request, receive
7//! one terminal frame, send the next request — there is no streaming
8//! since an embedding is a complete vector.
9
10use crate::client::ClientError;
11use inferd_proto::embed::{EmbedRequest, EmbedResponse};
12#[cfg(unix)]
13use std::path::Path;
14use std::sync::Arc;
15use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
16use tokio::net::TcpStream;
17use tokio::sync::Mutex;
18
19/// Embed-socket client.
20///
21/// Construct via `dial_tcp` / `dial_uds` (Unix) / `dial_pipe`
22/// (Windows). Wrap with [`crate::dial_and_wait_ready`] to retry
23/// connect during daemon bring-up — the retry helper is generic over
24/// the client type so the same wait logic serves v1 / v2 / embed.
25pub struct EmbedClient {
26    inner: Arc<Mutex<Inner>>,
27}
28
29impl std::fmt::Debug for EmbedClient {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("EmbedClient").finish_non_exhaustive()
32    }
33}
34
35struct Inner {
36    write: Box<dyn AsyncWrite + Send + Unpin>,
37    read: BufReader<Box<dyn AsyncRead + Send + Unpin>>,
38}
39
40impl EmbedClient {
41    /// Open a TCP connection to `addr` (e.g. `"127.0.0.1:47323"`).
42    /// Embed TCP transport is opt-in (config-flagged off by default);
43    /// see ADR 0017 §Endpoints for the configured port convention.
44    pub async fn dial_tcp(addr: &str) -> Result<Self, ClientError> {
45        let stream = TcpStream::connect(addr).await?;
46        let (read, write) = stream.into_split();
47        Ok(Self::wrap(Box::new(read), Box::new(write)))
48    }
49
50    /// Open a Unix domain socket connection (Unix only). Default embed
51    /// path: `${XDG_RUNTIME_DIR}/inferd/infer.embed.sock` on Linux,
52    /// `${TMPDIR}/inferd/infer.embed.sock` on macOS.
53    #[cfg(unix)]
54    pub async fn dial_uds(path: &Path) -> Result<Self, ClientError> {
55        let stream = tokio::net::UnixStream::connect(path).await?;
56        let (read, write) = stream.into_split();
57        Ok(Self::wrap(Box::new(read), Box::new(write)))
58    }
59
60    /// Open a Windows named pipe connection (Windows only). Default
61    /// embed path: `\\.\pipe\inferd-infer-embed`.
62    #[cfg(windows)]
63    pub async fn dial_pipe(path: &str) -> Result<Self, ClientError> {
64        use tokio::net::windows::named_pipe::ClientOptions;
65        let pipe = ClientOptions::new().open(path)?;
66        let (read, write) = tokio::io::split(pipe);
67        Ok(Self::wrap(Box::new(read), Box::new(write)))
68    }
69
70    fn wrap(
71        read: Box<dyn AsyncRead + Send + Unpin>,
72        write: Box<dyn AsyncWrite + Send + Unpin>,
73    ) -> Self {
74        Self {
75            inner: Arc::new(Mutex::new(Inner {
76                write,
77                read: BufReader::with_capacity(64 * 1024, read),
78            })),
79        }
80    }
81
82    /// Test-only constructor: build an `EmbedClient` from arbitrary
83    /// `AsyncRead` / `AsyncWrite` halves. Lets sibling-module tests
84    /// stub the transport with `tokio::io::duplex`. Not part of the
85    /// public API.
86    #[doc(hidden)]
87    pub fn wrap_for_test(
88        read: Box<dyn AsyncRead + Send + Unpin>,
89        write: Box<dyn AsyncWrite + Send + Unpin>,
90    ) -> Self {
91        Self::wrap(read, write)
92    }
93
94    /// Send an `EmbedRequest` and read back the single terminal
95    /// `EmbedResponse` frame (`Embeddings` or `Error`). The connection
96    /// stays open for the next call.
97    ///
98    /// Yields `Err(ClientError::UnexpectedEof)` if the daemon closes
99    /// the connection without writing a response (e.g. crashed mid-
100    /// request). Per `docs/protocol-v1.md`, callers treat this as
101    /// equivalent to a `backend_unavailable` error and apply their own
102    /// retry policy.
103    pub async fn embed(&mut self, req: EmbedRequest) -> Result<EmbedResponse, ClientError> {
104        let mut buf = Vec::with_capacity(512);
105        serde_json::to_writer(&mut buf, &req)?;
106        buf.push(b'\n');
107
108        let mut g = self.inner.lock().await;
109        g.write.write_all(&buf).await?;
110        g.write.flush().await?;
111
112        let mut line = Vec::with_capacity(512);
113        let n = g.read.read_until(b'\n', &mut line).await?;
114        if n == 0 {
115            return Err(ClientError::UnexpectedEof);
116        }
117        let resp: EmbedResponse = serde_json::from_slice(&line)?;
118        Ok(resp)
119    }
120}
121
/// Default path of the embed inference endpoint, mirroring the daemon's
/// `endpoint::default_embed_addr`.
///
/// The return type is always a `PathBuf`:
/// - On Unix it is a filesystem socket path, suitable for
///   `EmbedClient::dial_uds`.
/// - On Windows it holds the named-pipe path `\\.\pipe\inferd-infer-embed`;
///   convert it to a string for `EmbedClient::dial_pipe`.
///
/// Callers pick the dial function by `cfg`.
///
/// Linux resolution order (same as v2 / admin). An environment variable
/// that is set but empty is skipped:
/// 1. `${XDG_RUNTIME_DIR}/inferd/infer.embed.sock`
/// 2. `${HOME}/.inferd/run/infer.embed.sock`
/// 3. `/tmp/inferd/infer.embed.sock`
///
/// macOS joins `inferd/infer.embed.sock` onto `std::env::temp_dir()`
/// (normally `${TMPDIR}`). Any other platform uses
/// `/tmp/inferd/infer.embed.sock`.
pub fn default_embed_addr() -> std::path::PathBuf {
    use std::path::PathBuf;

    #[cfg(target_os = "linux")]
    {
        // An unset variable and an empty one are treated alike.
        let non_empty = |name: &str| std::env::var_os(name).filter(|v| !v.is_empty());

        if let Some(runtime_dir) = non_empty("XDG_RUNTIME_DIR") {
            return PathBuf::from(runtime_dir)
                .join("inferd")
                .join("infer.embed.sock");
        }
        if let Some(home) = non_empty("HOME") {
            return PathBuf::from(home)
                .join(".inferd")
                .join("run")
                .join("infer.embed.sock");
        }
        PathBuf::from("/tmp/inferd/infer.embed.sock")
    }
    #[cfg(target_os = "macos")]
    {
        let tmp: PathBuf = std::env::temp_dir();
        tmp.join("inferd").join("infer.embed.sock")
    }
    #[cfg(windows)]
    {
        PathBuf::from(r"\\.\pipe\inferd-infer-embed")
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        PathBuf::from("/tmp/inferd/infer.embed.sock")
    }
}
168
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_proto::embed::{EmbedErrorCode, EmbedTask, EmbedUsage};
    use tokio::io::DuplexStream;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
    use tokio::task::JoinHandle;

    /// Request used by the single-shot tests.
    fn sample_request() -> EmbedRequest {
        EmbedRequest {
            id: "embed-test".into(),
            input: vec!["hello".into(), "world".into()],
            dimensions: Some(128),
            task: Some(EmbedTask::RetrievalDocument),
        }
    }

    /// Client wired to one end of an in-memory pipe, plus the other end
    /// for a mock daemon to drive.
    fn client_with_daemon_end() -> (EmbedClient, DuplexStream) {
        let (daemon_end, client_end) = tokio::io::duplex(4096);
        let (rd, wr) = tokio::io::split(client_end);
        (EmbedClient::wrap(Box::new(rd), Box::new(wr)), daemon_end)
    }

    /// Encode one response as an NDJSON frame (JSON plus trailing `\n`).
    fn ndjson(resp: &EmbedResponse) -> Vec<u8> {
        let mut bytes = serde_json::to_vec(resp).unwrap();
        bytes.push(b'\n');
        bytes
    }

    /// Mock daemon: wait for one request line, then write `frame` back.
    fn reply_once(daemon_end: DuplexStream, frame: Vec<u8>) -> JoinHandle<()> {
        tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(daemon_end);
            let mut reader = tokio::io::BufReader::new(rx);
            let mut request = Vec::new();
            reader.read_until(b'\n', &mut request).await.unwrap();
            tx.write_all(&frame).await.unwrap();
        })
    }

    #[tokio::test]
    async fn embed_round_trips_a_success_frame() {
        let (mut client, daemon_end) = client_with_daemon_end();
        let daemon = reply_once(
            daemon_end,
            ndjson(&EmbedResponse::Embeddings {
                id: "embed-test".into(),
                embeddings: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
                dimensions: 128,
                model: "embeddinggemma-300m".into(),
                usage: EmbedUsage { input_tokens: 4 },
                backend: "llamacpp".into(),
            }),
        );

        let resp = client.embed(sample_request()).await.unwrap();
        daemon.await.unwrap();

        match resp {
            EmbedResponse::Embeddings {
                embeddings,
                dimensions,
                backend,
                ..
            } => {
                assert_eq!(embeddings.len(), 2);
                assert_eq!(dimensions, 128);
                assert_eq!(backend, "llamacpp");
            }
            other => panic!("expected Embeddings, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn embed_round_trips_an_error_frame() {
        let (mut client, daemon_end) = client_with_daemon_end();
        let daemon = reply_once(
            daemon_end,
            ndjson(&EmbedResponse::Error {
                id: "embed-test".into(),
                code: EmbedErrorCode::InvalidRequest,
                message: "dimensions=999 not supported".into(),
            }),
        );

        let resp = client.embed(sample_request()).await.unwrap();
        daemon.await.unwrap();

        match resp {
            EmbedResponse::Error { code, .. } => {
                assert_eq!(code, EmbedErrorCode::InvalidRequest);
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn unexpected_eof_yields_clienterror() {
        let (mut client, daemon_end) = client_with_daemon_end();
        let daemon = tokio::spawn(async move {
            let mut reader = tokio::io::BufReader::new(daemon_end);
            let mut request = Vec::new();
            reader.read_until(b'\n', &mut request).await.unwrap();
            // The daemon end is dropped here, so the client reads EOF.
        });

        let result = client.embed(sample_request()).await;
        daemon.await.unwrap();
        assert!(
            matches!(result, Err(ClientError::UnexpectedEof)),
            "expected UnexpectedEof, got {result:?}"
        );
    }

    #[tokio::test]
    async fn connection_stays_open_for_a_second_request() {
        let (mut client, daemon_end) = client_with_daemon_end();
        let daemon = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(daemon_end);
            let mut reader = tokio::io::BufReader::new(rx);
            for i in 0..2 {
                let mut request = Vec::new();
                reader.read_until(b'\n', &mut request).await.unwrap();
                let frame = ndjson(&EmbedResponse::Embeddings {
                    id: format!("r{i}"),
                    embeddings: vec![vec![0.0]],
                    dimensions: 1,
                    model: "m".into(),
                    usage: EmbedUsage { input_tokens: 1 },
                    backend: "mock".into(),
                });
                tx.write_all(&frame).await.unwrap();
            }
        });

        for i in 0..2 {
            let req = EmbedRequest {
                id: format!("r{i}"),
                input: vec!["x".into()],
                ..Default::default()
            };
            let resp = client.embed(req).await.unwrap();
            assert_eq!(resp.id(), format!("r{i}"));
        }
        daemon.await.unwrap();
    }
}