use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
use super::{EmbedderClient, EmbedderError};
/// JSON-RPC method name used for the batch-embedding call.
const METHOD_EMBED: &str = "embed";
/// Value placed in the `jsonrpc` field of every request frame.
const JSONRPC_VERSION: &str = "2.0";
/// A JSON-RPC 2.0 request frame sent to the embedder daemon.
///
/// The client serialises it with `serde_json::to_vec` and appends a single `\n` before
/// writing it to the socket. Fields serialise in declaration order.
#[derive(Debug, Serialize)]
struct RpcRequest<'a> {
/// Protocol version; the client always passes `JSONRPC_VERSION`.
jsonrpc: &'a str,
/// RPC method to invoke; the client always passes `METHOD_EMBED`.
method: &'a str,
/// Method parameters (the texts to embed).
params: EmbedParams<'a>,
/// Request id. The client sends the fixed value `1` because every call opens its own
/// connection and issues exactly one request on it.
id: u64,
}
/// `params` object of an `embed` request.
#[derive(Debug, Serialize)]
struct EmbedParams<'a> {
/// The batch of input strings, borrowed from the caller's `Vec` so no copy is made.
texts: &'a [String],
}
/// A JSON-RPC 2.0 response frame read back from the daemon.
///
/// Both fields are `#[serde(default)]`, so a frame that omits either one still decodes, with
/// the missing field as `None`. Other members such as `jsonrpc` and `id` are not declared and
/// are ignored on decode. The response `id` is therefore never compared with the request's.
#[derive(Debug, Deserialize)]
struct RpcResponse {
/// Present on success.
#[serde(default)]
result: Option<EmbedResult>,
/// Present on failure; the client checks this before `result`.
#[serde(default)]
error: Option<RpcError>,
}
/// Payload of a successful `embed` call.
#[derive(Debug, Deserialize)]
struct EmbedResult {
/// Embedding vectors, expected to be one per input text. The client only checks the count.
/// NOTE(review): it assumes the daemon returns them in input order — confirm on the daemon side.
embeddings: Vec<Vec<f32>>,
}
/// The JSON-RPC `error` object returned when the daemon fails a request.
///
/// The client forwards `code` and `message` to callers inside `EmbedderError::ModelError`.
/// Any other members (e.g. `data`) are not declared and are dropped on decode.
#[derive(Debug, Deserialize)]
struct RpcError {
/// Numeric JSON-RPC error code.
code: i32,
/// Human-readable error description from the daemon.
message: String,
}
/// [`EmbedderClient`] that speaks JSON-RPC to the embedding daemon over a Unix domain socket.
///
/// It stores only the socket path and keeps no open connection; every `embed_batch` call
/// connects anew. Cloning therefore just copies the path.
#[derive(Debug, Clone)]
pub struct UdsEmbedderClient {
socket_path: PathBuf,
}
impl UdsEmbedderClient {
    /// Creates a client that will talk to the daemon listening on `socket_path`.
    ///
    /// Nothing is opened here; the socket is dialled lazily by each `embed_batch` call.
    pub fn new(socket_path: impl Into<PathBuf>) -> Self {
        let socket_path: PathBuf = socket_path.into();
        Self { socket_path }
    }

    /// Conventional socket location: `$TMPDIR/` followed by [`SOCKET_FILENAME`].
    ///
    /// It falls back to `/tmp` when `TMPDIR` is unset, not valid Unicode, or blank
    /// (whitespace only). A non-blank value is used verbatim; trimming only decides blankness.
    pub fn default_path() -> PathBuf {
        let base_dir = std::env::var("TMPDIR")
            .ok()
            .filter(|value| !value.trim().is_empty())
            .map_or_else(|| PathBuf::from("/tmp"), PathBuf::from);
        base_dir.join(SOCKET_FILENAME)
    }

    /// The socket path this client connects to.
    pub fn socket_path(&self) -> &Path {
        self.socket_path.as_path()
    }
}
/// File name of the daemon's Unix socket. `UdsEmbedderClient::default_path` joins it onto
/// `$TMPDIR`, or `/tmp` when that is unusable.
pub const SOCKET_FILENAME: &str = "trusty-embedderd.sock";
/// Maximum number of characters of an undecodable response echoed into an error message.
const RAW_PREVIEW_CHARS: usize = 256;

/// Renders `raw` for inclusion in an error message.
///
/// Short input is rendered `{:?}`-quoted in full. Longer input is cut to the first
/// `RAW_PREVIEW_CHARS` characters, followed by the total byte length. A response frame carries
/// every embedding as JSON floats, so echoing it whole could push megabytes into a log line.
fn raw_preview(raw: &str) -> String {
    match raw.char_indices().nth(RAW_PREVIEW_CHARS) {
        // `cut` comes from `char_indices`, so it is a char boundary and the slice cannot panic.
        Some((cut, _)) => format!("{:?}... [{} bytes total]", &raw[..cut], raw.len()),
        None => format!("{raw:?}"),
    }
}

#[async_trait::async_trait]
impl EmbedderClient for UdsEmbedderClient {
    /// Embeds `texts` with a single JSON-RPC round trip to the daemon.
    ///
    /// The call proceeds in four steps:
    /// 1. Connect to `socket_path`. A fresh connection is opened for every call.
    /// 2. Write one newline-terminated `embed` request.
    /// 3. Half-close the write side so the daemon sees EOF.
    /// 4. Read exactly one newline-terminated response line.
    ///
    /// An empty `texts` returns `Ok(vec![])` without touching the socket.
    ///
    /// # Errors
    ///
    /// - [`EmbedderError::Uds`]: returned for any of the following:
    ///   - connect, write or read failures;
    ///   - the daemon closing the connection without answering;
    ///   - an undecodable response (only a bounded preview of it is included);
    ///   - a response with neither `result` nor `error`.
    /// - [`EmbedderError::ModelError`]: the daemon answered with a JSON-RPC `error` object.
    /// - [`EmbedderError::DimensionMismatch`]: the number of returned vectors differs from
    ///   the number of texts sent.
    ///
    /// NOTE(review): no read timeout is applied. A daemon that accepts the connection but
    /// never answers blocks this call indefinitely.
    async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbedderError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let sent = texts.len();
        tracing::debug!(
            socket = %self.socket_path.display(),
            n = sent,
            "UdsEmbedderClient: sending batch"
        );

        let stream = UnixStream::connect(&self.socket_path).await.map_err(|e| {
            EmbedderError::Uds(format!(
                "connect to {} failed: {e}",
                self.socket_path.display()
            ))
        })?;
        let (read_half, mut write_half) = stream.into_split();

        // One request per connection, so a constant id is sufficient.
        let req = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: METHOD_EMBED,
            params: EmbedParams { texts: &texts },
            id: 1,
        };
        let mut payload = serde_json::to_vec(&req)
            .map_err(|e| EmbedderError::Uds(format!("serialise JSON-RPC request: {e}")))?;
        // Frames are newline-delimited.
        payload.push(b'\n');
        write_half
            .write_all(&payload)
            .await
            .map_err(|e| EmbedderError::Uds(format!("write request frame: {e}")))?;
        // Half-close so the daemon sees EOF after the single request.
        write_half
            .shutdown()
            .await
            .map_err(|e| EmbedderError::Uds(format!("half-close write side: {e}")))?;

        let mut reader = BufReader::new(read_half);
        let mut line = String::new();
        let n = reader
            .read_line(&mut line)
            .await
            .map_err(|e| EmbedderError::Uds(format!("read response frame: {e}")))?;
        if n == 0 {
            return Err(EmbedderError::Uds(
                "daemon closed connection before responding".to_owned(),
            ));
        }

        let trimmed = line.trim();
        let resp: RpcResponse = serde_json::from_str(trimmed).map_err(|e| {
            EmbedderError::Uds(format!("decode response (raw={}): {e}", raw_preview(trimmed)))
        })?;
        // An explicit RPC error wins even if a (malformed) result is also present.
        if let Some(err) = resp.error {
            return Err(EmbedderError::ModelError(format!(
                "daemon RPC error {}: {}",
                err.code, err.message
            )));
        }
        let result = resp.result.ok_or_else(|| {
            EmbedderError::Uds("response missing both result and error fields".to_owned())
        })?;
        if result.embeddings.len() != sent {
            return Err(EmbedderError::DimensionMismatch {
                sent,
                got: result.embeddings.len(),
            });
        }

        tracing::debug!(
            socket = %self.socket_path.display(),
            n = sent,
            "UdsEmbedderClient: batch complete"
        );
        Ok(result.embeddings)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Returns a socket path unique within this test process (PID plus a per-call counter).
    /// Tests running in parallel therefore never bind the same file.
    fn unique_socket_path() -> PathBuf {
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        std::env::temp_dir().join(format!(
            "uds-embedder-test-{}-{id}.sock",
            std::process::id()
        ))
    }

    /// Starts a one-shot fake daemon and returns a client pointed at it.
    ///
    /// The daemon accepts a single connection and reads one request line. It then writes
    /// `reply` plus `\n` if one is given, or hangs up without answering for `None`. The
    /// returned handle resolves to the raw request line it received.
    ///
    /// Must be called from inside a Tokio runtime, because binding registers with the reactor.
    fn spawn_fake_daemon(
        reply: Option<&'static str>,
    ) -> (UdsEmbedderClient, tokio::task::JoinHandle<String>) {
        let path = unique_socket_path();
        // A leftover file from an earlier run with a recycled PID would make `bind` fail.
        let _ = std::fs::remove_file(&path);
        let listener = tokio::net::UnixListener::bind(&path).expect("bind fake daemon socket");
        let client = UdsEmbedderClient::new(path.clone());
        let daemon = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept client connection");
            let (read_half, mut write_half) = stream.into_split();
            let mut request = String::new();
            BufReader::new(read_half)
                .read_line(&mut request)
                .await
                .expect("read request line");
            if let Some(reply) = reply {
                write_half
                    .write_all(reply.as_bytes())
                    .await
                    .expect("write reply");
                write_half
                    .write_all(b"\n")
                    .await
                    .expect("write reply terminator");
            }
            let _ = std::fs::remove_file(&path);
            request
        });
        (client, daemon)
    }

    #[tokio::test]
    async fn empty_batch_short_circuits() {
        let client = UdsEmbedderClient::new("/nonexistent/socket/path");
        let result = client
            .embed_batch(vec![])
            .await
            .expect("empty batch must short-circuit");
        assert!(result.is_empty());
    }

    #[test]
    fn request_serialises_correctly() {
        let texts = vec!["hello".to_string(), "world".to_string()];
        let req = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: METHOD_EMBED,
            params: EmbedParams { texts: &texts },
            id: 1,
        };
        let s = serde_json::to_string(&req).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""), "must have jsonrpc 2.0");
        assert!(s.contains("\"method\":\"embed\""), "must have embed method");
        assert!(
            s.contains("\"texts\":[\"hello\",\"world\"]"),
            "must include texts"
        );
        assert!(s.contains("\"id\":1"), "must have id");
    }

    #[test]
    fn error_response_deserialises() {
        let json = r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"ort failed"},"id":1}"#;
        let resp: RpcResponse = serde_json::from_str(json).unwrap();
        assert!(resp.error.is_some());
        assert!(resp.result.is_none());
        let err = resp.error.unwrap();
        assert_eq!(err.code, -32603);
        assert!(err.message.contains("ort failed"));
    }

    #[test]
    fn default_socket_path_uses_tmpdir() {
        let p = UdsEmbedderClient::default_path();
        assert_eq!(
            p.file_name().and_then(|s| s.to_str()),
            Some(SOCKET_FILENAME),
            "default path must end with {SOCKET_FILENAME}"
        );
        assert!(p.parent().is_some(), "must have a parent directory");
    }

    #[tokio::test]
    async fn round_trip_returns_embeddings() {
        let (client, daemon) = spawn_fake_daemon(Some(
            r#"{"jsonrpc":"2.0","result":{"embeddings":[[0.5,1.0],[2.0,-1.5]]},"id":1}"#,
        ));
        let got = client
            .embed_batch(vec!["a".to_string(), "b".to_string()])
            .await
            .expect("round trip must succeed");
        assert_eq!(got, vec![vec![0.5_f32, 1.0], vec![2.0, -1.5]]);

        let request = daemon.await.expect("fake daemon task");
        assert!(request.ends_with('\n'), "request must be newline-framed");
        let value: serde_json::Value =
            serde_json::from_str(request.trim()).expect("request must be valid JSON");
        assert_eq!(value["jsonrpc"], "2.0");
        assert_eq!(value["method"], "embed");
        assert_eq!(value["id"], 1);
        assert_eq!(value["params"]["texts"], serde_json::json!(["a", "b"]));
    }

    #[tokio::test]
    async fn error_response_maps_to_model_error() {
        let (client, daemon) = spawn_fake_daemon(Some(
            r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"ort failed"},"id":1}"#,
        ));
        let err = client
            .embed_batch(vec!["x".to_string()])
            .await
            .expect_err("an RPC error response must fail the call");
        match err {
            EmbedderError::ModelError(msg) => {
                assert!(msg.contains("-32603"), "code missing from {msg:?}");
                assert!(msg.contains("ort failed"), "message missing from {msg:?}");
            }
            other => panic!("expected ModelError, got {other:?}"),
        }
        daemon.await.expect("fake daemon task");
    }

    #[tokio::test]
    async fn dimension_mismatch_detected() {
        // Two texts are sent but the daemon returns only one vector.
        let (client, daemon) = spawn_fake_daemon(Some(
            r#"{"jsonrpc":"2.0","result":{"embeddings":[[0.1]]},"id":1}"#,
        ));
        let err = client
            .embed_batch(vec!["a".to_string(), "b".to_string()])
            .await
            .expect_err("a short embeddings list must be rejected");
        let rendered = err.to_string();
        assert!(rendered.contains('2') && rendered.contains('1'), "{rendered}");
        match err {
            EmbedderError::DimensionMismatch { sent, got } => assert_eq!((sent, got), (2, 1)),
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
        daemon.await.expect("fake daemon task");
    }

    #[tokio::test]
    async fn daemon_hangup_is_uds_error() {
        let (client, daemon) = spawn_fake_daemon(None);
        let err = client
            .embed_batch(vec!["x".to_string()])
            .await
            .expect_err("a daemon that never answers must fail the call");
        assert!(matches!(err, EmbedderError::Uds(_)), "got {err:?}");
        daemon.await.expect("fake daemon task");
    }

    #[tokio::test]
    async fn missing_socket_is_uds_error() {
        let client = UdsEmbedderClient::new(unique_socket_path());
        let err = client
            .embed_batch(vec!["x".to_string()])
            .await
            .expect_err("no daemon is listening");
        assert!(matches!(err, EmbedderError::Uds(_)), "got {err:?}");
    }
}