use crate::client::ClientError;
use inferd_proto::embed::{EmbedRequest, EmbedResponse};
#[cfg(unix)]
use std::path::Path;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
/// Client for the embed endpoint of an `inferd` daemon.
///
/// Each call to `embed` writes one newline-terminated JSON request and reads
/// one newline-terminated JSON response over the same connection.
pub struct EmbedClient {
    /// Connection state; the lock is held across a whole request/response
    /// exchange so two exchanges are never interleaved on the wire.
    inner: Arc<Mutex<Inner>>,
}

impl std::fmt::Debug for EmbedClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The boxed transport halves are trait objects without a `Debug`
        // impl, so only the type name is shown, marked as non-exhaustive.
        let mut builder = f.debug_struct("EmbedClient");
        builder.finish_non_exhaustive()
    }
}
/// Connection state behind an [`EmbedClient`].
struct Inner {
    /// Outbound half of the transport; request frames are written here.
    write: Box<dyn AsyncWrite + Send + Unpin>,
    /// Buffered inbound half; response frames are read up to each `\n`.
    read: BufReader<Box<dyn AsyncRead + Send + Unpin>>,
    /// `true` from the moment a request starts being written until its
    /// response has been read through the `\n` delimiter (or EOF). If it is
    /// still `true` when the next call begins, an earlier call was cancelled
    /// or hit an I/O error mid-exchange, and the bytes on the wire no longer
    /// line up with request boundaries.
    desynced: bool,
}

impl EmbedClient {
    /// Connects to the embed endpoint over TCP.
    ///
    /// `addr` is passed straight to [`TcpStream::connect`], so `"host:port"`
    /// forms are accepted.
    ///
    /// # Errors
    /// Returns the connect failure converted into a [`ClientError`].
    pub async fn dial_tcp(addr: &str) -> Result<Self, ClientError> {
        let stream = TcpStream::connect(addr).await?;
        let (read, write) = stream.into_split();
        Ok(Self::wrap(Box::new(read), Box::new(write)))
    }

    /// Connects to the embed endpoint over the Unix-domain socket at `path`
    /// (see [`default_embed_addr`] for the conventional location).
    ///
    /// # Errors
    /// Returns the connect failure converted into a [`ClientError`].
    #[cfg(unix)]
    pub async fn dial_uds(path: &Path) -> Result<Self, ClientError> {
        let stream = tokio::net::UnixStream::connect(path).await?;
        let (read, write) = stream.into_split();
        Ok(Self::wrap(Box::new(read), Box::new(write)))
    }

    /// Connects to the embed endpoint over a Windows named pipe such as
    /// `\\.\pipe\inferd-infer-embed`.
    ///
    /// The open is attempted exactly once, with no retry (for example, when
    /// every pipe instance is currently busy).
    ///
    /// # Errors
    /// Returns the open failure converted into a [`ClientError`].
    #[cfg(windows)]
    pub async fn dial_pipe(path: &str) -> Result<Self, ClientError> {
        use tokio::net::windows::named_pipe::ClientOptions;
        let pipe = ClientOptions::new().open(path)?;
        let (read, write) = tokio::io::split(pipe);
        Ok(Self::wrap(Box::new(read), Box::new(write)))
    }

    /// Builds a client from already-connected transport halves. It uses a
    /// 64 KiB read buffer, and the connection starts in sync.
    fn wrap(
        read: Box<dyn AsyncRead + Send + Unpin>,
        write: Box<dyn AsyncWrite + Send + Unpin>,
    ) -> Self {
        Self {
            inner: Arc::new(Mutex::new(Inner {
                write,
                read: BufReader::with_capacity(64 * 1024, read),
                desynced: false,
            })),
        }
    }

    /// Test-only constructor over arbitrary transport halves; hidden from docs.
    #[doc(hidden)]
    pub fn wrap_for_test(
        read: Box<dyn AsyncRead + Send + Unpin>,
        write: Box<dyn AsyncWrite + Send + Unpin>,
    ) -> Self {
        Self::wrap(read, write)
    }

    /// Sends one embedding request and waits for its response frame.
    ///
    /// The request is serialized as a single JSON line terminated by `\n`.
    /// The reply is read up to the next `\n` and decoded as an
    /// [`EmbedResponse`]. A server-side `EmbedResponse::Error` is returned as
    /// `Ok`, not as a [`ClientError`].
    ///
    /// # Cancellation
    /// Suppose this future is dropped (e.g. by a timeout), or an I/O error
    /// occurs, after the request started going out but before the whole
    /// response line was read. The leftover bytes would then be mistaken for
    /// the reply to the *next* request. In that situation the connection is
    /// marked out of sync, and every later call fails fast; dial a new client
    /// to recover.
    ///
    /// # Errors
    /// - [`ClientError::UnexpectedEof`] if the peer closes before replying.
    /// - An I/O-derived [`ClientError`] on transport failure. This includes a
    ///   call made after an earlier call was interrupted mid-exchange.
    /// - A JSON-derived [`ClientError`] if the request cannot be encoded or
    ///   the response line cannot be decoded. A malformed response still
    ///   consumes exactly one line, so the connection remains usable.
    pub async fn embed(&mut self, req: EmbedRequest) -> Result<EmbedResponse, ClientError> {
        // Encode before touching the connection so a serialization failure
        // leaves the stream untouched.
        let mut buf = Vec::with_capacity(512);
        serde_json::to_writer(&mut buf, &req)?;
        buf.push(b'\n');

        let mut g = self.inner.lock().await;
        if g.desynced {
            return Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "embed connection is out of sync after an interrupted request; reconnect",
            )
            .into());
        }
        // Assume the worst until the response has been consumed: any early
        // `?` return or cancellation below leaves the flag set.
        g.desynced = true;
        g.write.write_all(&buf).await?;
        g.write.flush().await?;

        let mut line = Vec::with_capacity(512);
        let n = g.read.read_until(b'\n', &mut line).await?;
        // `read_until` only returns `Ok` after consuming through the
        // delimiter or reaching EOF, so nothing from this exchange remains
        // unread and framing is intact again.
        g.desynced = false;
        if n == 0 {
            return Err(ClientError::UnexpectedEof);
        }
        let resp: EmbedResponse = serde_json::from_slice(&line)?;
        Ok(resp)
    }
}
/// Returns the default local endpoint of the embed service on this platform.
///
/// - Linux: `$XDG_RUNTIME_DIR/inferd/infer.embed.sock`; otherwise
///   `$HOME/.inferd/run/infer.embed.sock`; otherwise
///   `/tmp/inferd/infer.embed.sock`. A variable that is set but empty is
///   treated the same as an unset one.
/// - macOS: `inferd/infer.embed.sock` under [`std::env::temp_dir`].
/// - Windows: the named pipe `\\.\pipe\inferd-infer-embed`.
/// - Any other target: `/tmp/inferd/infer.embed.sock`.
///
/// Only the path is computed; nothing is created or checked on disk.
pub fn default_embed_addr() -> std::path::PathBuf {
    #[cfg(target_os = "linux")]
    {
        // Reads `key` from the environment as a directory, skipping empty values.
        let dir_from_env = |key: &str| {
            std::env::var_os(key)
                .map(std::path::PathBuf::from)
                .filter(|dir| !dir.as_os_str().is_empty())
        };
        if let Some(runtime_dir) = dir_from_env("XDG_RUNTIME_DIR") {
            return runtime_dir.join("inferd").join("infer.embed.sock");
        }
        match dir_from_env("HOME") {
            Some(home) => home.join(".inferd").join("run").join("infer.embed.sock"),
            None => std::path::PathBuf::from("/tmp/inferd/infer.embed.sock"),
        }
    }
    #[cfg(target_os = "macos")]
    {
        std::env::temp_dir().join("inferd").join("infer.embed.sock")
    }
    #[cfg(windows)]
    {
        std::path::PathBuf::from(r"\\.\pipe\inferd-infer-embed")
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        std::path::PathBuf::from("/tmp/inferd/infer.embed.sock")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_proto::embed::{EmbedErrorCode, EmbedTask, EmbedUsage};
    use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, DuplexStream};

    /// Request used by the single-exchange tests.
    fn sample_request() -> EmbedRequest {
        EmbedRequest {
            id: "embed-test".into(),
            input: vec!["hello".into(), "world".into()],
            dimensions: Some(128),
            task: Some(EmbedTask::RetrievalDocument),
        }
    }

    /// Wires a client to one end of an in-memory duplex pipe and returns it
    /// along with the other end, which each test drives as a fake daemon.
    fn client_and_daemon_end() -> (EmbedClient, DuplexStream) {
        let (daemon_end, client_end) = tokio::io::duplex(4096);
        let (rd, wr) = tokio::io::split(client_end);
        (EmbedClient::wrap(Box::new(rd), Box::new(wr)), daemon_end)
    }

    /// Writes `resp` as a JSON object followed by the `\n` frame delimiter.
    async fn send_frame<W: AsyncWrite + Unpin>(tx: &mut W, resp: EmbedResponse) {
        let payload = serde_json::to_vec(&resp).unwrap();
        tx.write_all(&payload).await.unwrap();
        tx.write_all(b"\n").await.unwrap();
    }

    #[tokio::test]
    async fn embed_round_trips_a_success_frame() {
        let (mut client, daemon_end) = client_and_daemon_end();
        let daemon = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(daemon_end);
            let mut reader = tokio::io::BufReader::new(rx);
            let mut request = Vec::new();
            reader.read_until(b'\n', &mut request).await.unwrap();
            send_frame(
                &mut tx,
                EmbedResponse::Embeddings {
                    id: "embed-test".into(),
                    embeddings: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
                    dimensions: 128,
                    model: "embeddinggemma-300m".into(),
                    usage: EmbedUsage { input_tokens: 4 },
                    backend: "llamacpp".into(),
                },
            )
            .await;
        });

        let resp = client.embed(sample_request()).await.unwrap();
        daemon.await.unwrap();

        match resp {
            EmbedResponse::Embeddings {
                embeddings,
                dimensions,
                backend,
                ..
            } => {
                assert_eq!(embeddings.len(), 2);
                assert_eq!(dimensions, 128);
                assert_eq!(backend, "llamacpp");
            }
            other => panic!("expected Embeddings, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn embed_round_trips_an_error_frame() {
        let (mut client, daemon_end) = client_and_daemon_end();
        let daemon = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(daemon_end);
            let mut reader = tokio::io::BufReader::new(rx);
            let mut request = Vec::new();
            reader.read_until(b'\n', &mut request).await.unwrap();
            send_frame(
                &mut tx,
                EmbedResponse::Error {
                    id: "embed-test".into(),
                    code: EmbedErrorCode::InvalidRequest,
                    message: "dimensions=999 not supported".into(),
                },
            )
            .await;
        });

        let resp = client.embed(sample_request()).await.unwrap();
        daemon.await.unwrap();

        match resp {
            EmbedResponse::Error { code, .. } => {
                assert_eq!(code, EmbedErrorCode::InvalidRequest);
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn unexpected_eof_yields_clienterror() {
        let (mut client, daemon_end) = client_and_daemon_end();
        let daemon = tokio::spawn(async move {
            // Consume the request, then drop the whole stream without replying.
            let mut reader = tokio::io::BufReader::new(daemon_end);
            let mut request = Vec::new();
            reader.read_until(b'\n', &mut request).await.unwrap();
        });

        let result = client.embed(sample_request()).await;
        daemon.await.unwrap();

        assert!(
            matches!(result, Err(ClientError::UnexpectedEof)),
            "expected UnexpectedEof, got {result:?}"
        );
    }

    #[tokio::test]
    async fn connection_stays_open_for_a_second_request() {
        let (mut client, daemon_end) = client_and_daemon_end();
        let daemon = tokio::spawn(async move {
            let (rx, mut tx) = tokio::io::split(daemon_end);
            let mut reader = tokio::io::BufReader::new(rx);
            let mut request = Vec::new();
            for i in 0..2 {
                request.clear();
                reader.read_until(b'\n', &mut request).await.unwrap();
                send_frame(
                    &mut tx,
                    EmbedResponse::Embeddings {
                        id: format!("r{i}"),
                        embeddings: vec![vec![0.0]],
                        dimensions: 1,
                        model: "m".into(),
                        usage: EmbedUsage { input_tokens: 1 },
                        backend: "mock".into(),
                    },
                )
                .await;
            }
        });

        for i in 0..2 {
            let expected_id = format!("r{i}");
            let req = EmbedRequest {
                id: expected_id.clone(),
                input: vec!["x".into()],
                ..Default::default()
            };
            let resp = client.embed(req).await.unwrap();
            assert_eq!(resp.id(), expected_id);
        }
        daemon.await.unwrap();
    }
}