use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{ChildStdin, ChildStdout};
use tokio::sync::Mutex;
use super::{EmbedderClient, EmbedderError};
/// JSON-RPC method name placed in the `method` field of the requests built in this file.
const METHOD_EMBED: &str = "embed";
/// Protocol version string placed in the `jsonrpc` field of the requests built in this file.
const JSONRPC_VERSION: &str = "2.0";
/// Request envelope that is serialised to JSON and written to the child's stdin.
///
/// Every field borrows for `'a`, so building a request does not copy the
/// version string, the method name, or the texts.
#[derive(Debug, serde::Serialize)]
struct RpcRequest<'a> {
/// Protocol version tag; this file always passes `JSONRPC_VERSION`.
jsonrpc: &'a str,
/// Name of the remote method; this file always passes `METHOD_EMBED`.
method: &'a str,
/// Method arguments (the texts to embed).
params: EmbedParams<'a>,
/// Request identifier. Every construction site in this file uses `1`.
id: u64,
}
/// Parameters of the `embed` call, serialised as the `params` object of the request.
#[derive(Debug, serde::Serialize)]
struct EmbedParams<'a> {
/// Borrowed slice of the texts to embed; serialised as a JSON array of strings.
texts: &'a [String],
}
/// Response envelope decoded from one line of the child's stdout.
///
/// Only `result` and `error` are captured; other keys present in the JSON
/// (the tests below feed `jsonrpc` and `id`) are not stored.
#[derive(Debug, serde::Deserialize)]
struct RpcResponse {
/// Success payload; `None` when the key is absent from the response.
#[serde(default)]
result: Option<EmbedResult>,
/// Failure payload; `None` when the key is absent from the response.
#[serde(default)]
error: Option<RpcError>,
}
/// Success payload of an `embed` call.
#[derive(Debug, serde::Deserialize)]
struct EmbedResult {
/// Returned vectors. `embed_batch` requires the outer length to equal the
/// number of texts sent (presumably one vector per text, in input order —
/// TODO confirm against the daemon).
embeddings: Vec<Vec<f32>>,
}
/// Failure payload of a response; `embed_batch` formats both fields into
/// `EmbedderError::ModelError`.
#[derive(Debug, serde::Deserialize)]
struct RpcError {
/// Numeric error code reported by the daemon.
code: i32,
/// Human-readable error description reported by the daemon.
message: String,
}
/// Embedder client that talks to a child process over its stdin/stdout pipes,
/// exchanging one newline-terminated JSON message per request and per response.
///
/// Each pipe sits behind its own async mutex so the client can be used through
/// `&self`; `embed_batch` takes both locks for the duration of a call.
#[derive(Debug)]
pub struct StdioEmbedderClient {
    /// Write half: requests are written here, one JSON document per line.
    stdin: Mutex<ChildStdin>,
    /// Read half, buffered so that a response can be read line by line.
    stdout: Mutex<BufReader<ChildStdout>>,
}
impl StdioEmbedderClient {
    /// Builds a client from the two pipe ends of an already-spawned child.
    ///
    /// `stdout` is wrapped in a [`BufReader`] so responses can be read a line
    /// at a time; both handles are then placed behind their own mutex.
    pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
        let buffered_stdout = BufReader::new(stdout);
        let stdin = Mutex::new(stdin);
        let stdout = Mutex::new(buffered_stdout);
        Self { stdin, stdout }
    }
}
#[async_trait::async_trait]
impl EmbedderClient for StdioEmbedderClient {
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbedderError> {
if texts.is_empty() {
return Ok(vec![]);
}
let sent = texts.len();
tracing::debug!(n = sent, "StdioEmbedderClient: sending batch");
let req = RpcRequest {
jsonrpc: JSONRPC_VERSION,
method: METHOD_EMBED,
params: EmbedParams { texts: &texts },
id: 1,
};
let mut payload = serde_json::to_vec(&req)
.map_err(|e| EmbedderError::Stdio(format!("serialise JSON-RPC request: {e}")))?;
payload.push(b'\n');
let mut stdin_guard = self.stdin.lock().await;
let mut stdout_guard = self.stdout.lock().await;
stdin_guard
.write_all(&payload)
.await
.map_err(|e| EmbedderError::Stdio(format!("write request to child stdin: {e}")))?;
stdin_guard
.flush()
.await
.map_err(|e| EmbedderError::Stdio(format!("flush child stdin: {e}")))?;
let mut line = String::new();
let n = stdout_guard
.read_line(&mut line)
.await
.map_err(|e| EmbedderError::Stdio(format!("read response from child stdout: {e}")))?;
if n == 0 {
return Err(EmbedderError::Stdio(
"child closed stdout before responding (process crashed?)".to_owned(),
));
}
let resp: RpcResponse = serde_json::from_str(line.trim()).map_err(|e| {
EmbedderError::Stdio(format!("decode response (raw={:?}): {e}", line.trim()))
})?;
if let Some(err) = resp.error {
return Err(EmbedderError::ModelError(format!(
"daemon RPC error {}: {}",
err.code, err.message
)));
}
let result = resp.result.ok_or_else(|| {
EmbedderError::Stdio("response missing both result and error fields".to_owned())
})?;
if result.embeddings.len() != sent {
return Err(EmbedderError::DimensionMismatch {
sent,
got: result.embeddings.len(),
});
}
tracing::debug!(n = sent, "StdioEmbedderClient: batch complete");
Ok(result.embeddings)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The serialised request must contain the version tag, the method name,
    /// the texts array, and the id.
    #[test]
    fn request_serialises_correctly() {
        let texts: Vec<String> = ["hello", "world"].iter().map(|s| s.to_string()).collect();
        let encoded = serde_json::to_string(&RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: METHOD_EMBED,
            params: EmbedParams { texts: &texts },
            id: 1,
        })
        .unwrap();

        // (expected JSON fragment, failure message)
        let expectations = [
            ("\"jsonrpc\":\"2.0\"", "must have jsonrpc 2.0"),
            ("\"method\":\"embed\"", "must have embed method"),
            ("\"texts\":[\"hello\",\"world\"]", "must include texts"),
            ("\"id\":1", "must have id"),
        ];
        for (fragment, why) in expectations {
            assert!(encoded.contains(fragment), "{why}");
        }
    }

    /// An error-only response decodes with `error` populated and `result` absent.
    #[test]
    fn error_response_maps_to_model_error() {
        let raw = r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"ort failed"},"id":1}"#;
        let RpcResponse { result, error } = serde_json::from_str::<RpcResponse>(raw).unwrap();
        assert!(error.is_some());
        assert!(result.is_none());

        let rpc_err = error.unwrap();
        assert_eq!(rpc_err.code, -32603);
        assert!(rpc_err.message.contains("ort failed"));
    }

    /// A success-only response decodes with the embeddings intact.
    #[test]
    fn success_response_decoded() {
        let raw = r#"{"jsonrpc":"2.0","result":{"embeddings":[[0.1,0.2],[0.3,0.4]]},"id":1}"#;
        let RpcResponse { result, error } = serde_json::from_str::<RpcResponse>(raw).unwrap();
        assert!(error.is_none());

        let payload = result.unwrap();
        assert_eq!(payload.embeddings.len(), 2);
        assert_eq!(payload.embeddings[0][0], 0.1_f32);
    }
}