1use crate::client::ClientError;
11use inferd_proto::embed::{EmbedRequest, EmbedResponse};
12#[cfg(unix)]
13use std::path::Path;
14use std::sync::Arc;
15use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
16use tokio::net::TcpStream;
17use tokio::sync::Mutex;
18
19pub struct EmbedClient {
26 inner: Arc<Mutex<Inner>>,
27}
28
29impl std::fmt::Debug for EmbedClient {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("EmbedClient").finish_non_exhaustive()
32 }
33}
34
35struct Inner {
36 write: Box<dyn AsyncWrite + Send + Unpin>,
37 read: BufReader<Box<dyn AsyncRead + Send + Unpin>>,
38}
39
40impl EmbedClient {
41 pub async fn dial_tcp(addr: &str) -> Result<Self, ClientError> {
45 let stream = TcpStream::connect(addr).await?;
46 let (read, write) = stream.into_split();
47 Ok(Self::wrap(Box::new(read), Box::new(write)))
48 }
49
50 #[cfg(unix)]
54 pub async fn dial_uds(path: &Path) -> Result<Self, ClientError> {
55 let stream = tokio::net::UnixStream::connect(path).await?;
56 let (read, write) = stream.into_split();
57 Ok(Self::wrap(Box::new(read), Box::new(write)))
58 }
59
60 #[cfg(windows)]
63 pub async fn dial_pipe(path: &str) -> Result<Self, ClientError> {
64 use tokio::net::windows::named_pipe::ClientOptions;
65 let pipe = ClientOptions::new().open(path)?;
66 let (read, write) = tokio::io::split(pipe);
67 Ok(Self::wrap(Box::new(read), Box::new(write)))
68 }
69
70 fn wrap(
71 read: Box<dyn AsyncRead + Send + Unpin>,
72 write: Box<dyn AsyncWrite + Send + Unpin>,
73 ) -> Self {
74 Self {
75 inner: Arc::new(Mutex::new(Inner {
76 write,
77 read: BufReader::with_capacity(64 * 1024, read),
78 })),
79 }
80 }
81
82 #[doc(hidden)]
87 pub fn wrap_for_test(
88 read: Box<dyn AsyncRead + Send + Unpin>,
89 write: Box<dyn AsyncWrite + Send + Unpin>,
90 ) -> Self {
91 Self::wrap(read, write)
92 }
93
94 pub async fn embed(&mut self, req: EmbedRequest) -> Result<EmbedResponse, ClientError> {
104 let mut buf = Vec::with_capacity(512);
105 serde_json::to_writer(&mut buf, &req)?;
106 buf.push(b'\n');
107
108 let mut g = self.inner.lock().await;
109 g.write.write_all(&buf).await?;
110 g.write.flush().await?;
111
112 let mut line = Vec::with_capacity(512);
113 let n = g.read.read_until(b'\n', &mut line).await?;
114 if n == 0 {
115 return Err(ClientError::UnexpectedEof);
116 }
117 let resp: EmbedResponse = serde_json::from_slice(&line)?;
118 Ok(resp)
119 }
120}
121
/// Returns the conventional address of the embed endpoint for this platform.
///
/// * Linux: `$XDG_RUNTIME_DIR/inferd/infer.embed.sock` when `XDG_RUNTIME_DIR` is
///   set and non-empty. Otherwise `$HOME/.inferd/run/infer.embed.sock` when
///   `HOME` is set and non-empty. Otherwise `/tmp/inferd/infer.embed.sock`.
/// * macOS: `inferd/infer.embed.sock` under `std::env::temp_dir()`.
/// * Windows: the named pipe `\\.\pipe\inferd-infer-embed`, returned as a
///   `PathBuf`. Pass it as a string to `EmbedClient::dial_pipe`.
/// * Any other target: `/tmp/inferd/infer.embed.sock`.
pub fn default_embed_addr() -> std::path::PathBuf {
    #[cfg(target_os = "linux")]
    {
        // Unset and empty variables are treated alike: fall through to the next candidate.
        let non_empty_var =
            |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());

        if let Some(runtime_dir) = non_empty_var("XDG_RUNTIME_DIR") {
            return std::path::PathBuf::from(runtime_dir)
                .join("inferd")
                .join("infer.embed.sock");
        }
        match non_empty_var("HOME") {
            Some(home) => std::path::PathBuf::from(home)
                .join(".inferd")
                .join("run")
                .join("infer.embed.sock"),
            None => std::path::PathBuf::from("/tmp/inferd/infer.embed.sock"),
        }
    }
    #[cfg(target_os = "macos")]
    {
        std::env::temp_dir().join("inferd").join("infer.embed.sock")
    }
    #[cfg(windows)]
    {
        std::path::PathBuf::from(r"\\.\pipe\inferd-infer-embed")
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        std::path::PathBuf::from("/tmp/inferd/infer.embed.sock")
    }
}
168
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_proto::embed::{EmbedErrorCode, EmbedTask, EmbedUsage};
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, DuplexStream};

    /// Request sent by the single-exchange tests.
    fn sample_request() -> EmbedRequest {
        EmbedRequest {
            id: "embed-test".into(),
            input: vec!["hello".into(), "world".into()],
            dimensions: Some(128),
            task: Some(EmbedTask::RetrievalDocument),
        }
    }

    /// Builds a client on one end of a 4096-byte in-memory duplex pipe and
    /// returns the other end for the fake server to drive.
    fn client_and_server_end() -> (EmbedClient, DuplexStream) {
        let (server_end, client_end) = tokio::io::duplex(4096);
        let (client_rx, client_tx) = tokio::io::split(client_end);
        let client = EmbedClient::wrap(Box::new(client_rx), Box::new(client_tx));
        (client, server_end)
    }

    /// Fake-server side: consumes one `\n`-terminated request frame.
    async fn recv_request<R>(reader: &mut BufReader<R>)
    where
        R: AsyncRead + Unpin,
    {
        let mut frame = Vec::new();
        reader.read_until(b'\n', &mut frame).await.unwrap();
    }

    /// Fake-server side: writes an already-serialized frame plus its `\n` terminator.
    async fn send_line<W>(writer: &mut W, frame: &[u8])
    where
        W: AsyncWrite + Unpin,
    {
        writer.write_all(frame).await.unwrap();
        writer.write_all(b"\n").await.unwrap();
    }

    #[tokio::test]
    async fn embed_round_trips_a_success_frame() {
        let (mut client, server_end) = client_and_server_end();

        let server = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(server_end);
            let mut reader = BufReader::new(rx);
            recv_request(&mut reader).await;

            let frame = serde_json::to_vec(&EmbedResponse::Embeddings {
                id: "embed-test".into(),
                embeddings: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
                dimensions: 128,
                model: "embeddinggemma-300m".into(),
                usage: EmbedUsage { input_tokens: 4 },
                backend: "llamacpp".into(),
            })
            .unwrap();
            send_line(&mut tx, &frame).await;
        });

        let resp = client.embed(sample_request()).await.unwrap();
        server.await.unwrap();

        if let EmbedResponse::Embeddings {
            embeddings,
            dimensions,
            backend,
            ..
        } = resp
        {
            assert_eq!(embeddings.len(), 2);
            assert_eq!(dimensions, 128);
            assert_eq!(backend, "llamacpp");
        } else {
            panic!("expected Embeddings, got {resp:?}");
        }
    }

    #[tokio::test]
    async fn embed_round_trips_an_error_frame() {
        let (mut client, server_end) = client_and_server_end();

        let server = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(server_end);
            let mut reader = BufReader::new(rx);
            recv_request(&mut reader).await;

            let frame = serde_json::to_vec(&EmbedResponse::Error {
                id: "embed-test".into(),
                code: EmbedErrorCode::InvalidRequest,
                message: "dimensions=999 not supported".into(),
            })
            .unwrap();
            send_line(&mut tx, &frame).await;
        });

        let resp = client.embed(sample_request()).await.unwrap();
        server.await.unwrap();

        if let EmbedResponse::Error { code, .. } = resp {
            assert_eq!(code, EmbedErrorCode::InvalidRequest);
        } else {
            panic!("expected Error, got {resp:?}");
        }
    }

    #[tokio::test]
    async fn unexpected_eof_yields_clienterror() {
        let (mut client, server_end) = client_and_server_end();

        let server = tokio::spawn(async move {
            let (rx, _tx) = tokio::io::split(server_end);
            let mut reader = BufReader::new(rx);
            recv_request(&mut reader).await;
            // Returning drops both halves, closing the pipe without a reply.
        });

        let outcome = client.embed(sample_request()).await;
        server.await.unwrap();
        assert!(
            matches!(&outcome, Err(ClientError::UnexpectedEof)),
            "expected UnexpectedEof, got {outcome:?}"
        );
    }

    #[tokio::test]
    async fn connection_stays_open_for_a_second_request() {
        let (mut client, server_end) = client_and_server_end();

        let server = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(server_end);
            let mut reader = BufReader::new(rx);
            for round in 0..2 {
                recv_request(&mut reader).await;
                let frame = serde_json::to_vec(&EmbedResponse::Embeddings {
                    id: format!("r{round}"),
                    embeddings: vec![vec![0.0]],
                    dimensions: 1,
                    model: "m".into(),
                    usage: EmbedUsage { input_tokens: 1 },
                    backend: "mock".into(),
                })
                .unwrap();
                send_line(&mut tx, &frame).await;
            }
        });

        for round in 0..2 {
            let expected_id = format!("r{round}");
            let req = EmbedRequest {
                id: expected_id.clone(),
                input: vec!["x".into()],
                ..Default::default()
            };
            let resp = client.embed(req).await.unwrap();
            assert_eq!(resp.id(), expected_id);
        }
        server.await.unwrap();
    }
}