use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{ChildStdin, ChildStdout};
use tokio::sync::Mutex;
use tokio::time::Duration;
/// Fallback per-call timeout, in seconds, used when the
/// `TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS` override is absent or unusable.
const EMBED_CALL_TIMEOUT_DEFAULT_SECS: u64 = 120;

/// Returns the deadline applied to a single embed call.
///
/// The value is resolved once per process and then cached: the first call
/// reads `TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS`, every later call returns the
/// same `Duration` even if the environment variable changes afterwards.
///
/// The override is ignored (and [`EMBED_CALL_TIMEOUT_DEFAULT_SECS`] is used)
/// when the variable is unset, is not valid Unicode, does not parse as a
/// `u64` after trimming surrounding whitespace, or is `0`.
fn embed_call_timeout() -> Duration {
    static CACHED: std::sync::OnceLock<Duration> = std::sync::OnceLock::new();
    *CACHED.get_or_init(|| {
        let secs = std::env::var("TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS")
            .ok()
            // Tolerate stray whitespace/newlines from shell or env-file quoting.
            .and_then(|raw| raw.trim().parse::<u64>().ok())
            // A zero-second deadline would expire before the child could
            // realistically answer, turning every call into a timeout error;
            // treat it as a misconfiguration rather than honouring it.
            .filter(|&secs| secs > 0)
            .unwrap_or(EMBED_CALL_TIMEOUT_DEFAULT_SECS);
        Duration::from_secs(secs)
    })
}
use super::{EmbedderClient, EmbedderError};
/// JSON-RPC method name placed in the `method` field of the embed request.
const METHOD_EMBED: &str = "embed";
/// Protocol version string placed in the `jsonrpc` field of the embed request.
const JSONRPC_VERSION: &str = "2.0";
/// Outgoing JSON-RPC request envelope, serialised as a single JSON object.
///
/// Every field borrows for `'a`, so constructing a request does not copy the
/// method name or the texts being embedded.
#[derive(Debug, serde::Serialize)]
struct RpcRequest<'a> {
/// Protocol version tag.
jsonrpc: &'a str,
/// Name of the remote method to invoke.
method: &'a str,
/// Arguments for the method.
params: EmbedParams<'a>,
/// Request identifier.
id: u64,
}
/// `params` payload of an embed request.
#[derive(Debug, serde::Serialize)]
struct EmbedParams<'a> {
/// Texts to embed, borrowed from the caller and serialised as a JSON array
/// of strings.
texts: &'a [String],
}
/// Incoming JSON-RPC response envelope.
///
/// Only `result` and `error` are captured; other keys present in the response
/// object (for example `jsonrpc` and `id`) are not stored.
///
/// NOTE(review): because `id` is not captured, a response cannot be matched
/// against the request that produced it.
#[derive(Debug, serde::Deserialize)]
struct RpcResponse {
/// Success payload; `None` when the key is absent or `null`.
#[serde(default)]
result: Option<EmbedResult>,
/// Failure payload; `None` when the key is absent or `null`.
#[serde(default)]
error: Option<RpcError>,
}
/// `result` payload of a successful embed response.
#[derive(Debug, serde::Deserialize)]
struct EmbedResult {
/// Embedding vectors returned by the child. `embed_batch` requires the
/// number of vectors to equal the number of texts it sent.
embeddings: Vec<Vec<f32>>,
}
/// `error` payload of a failed response.
///
/// `embed_batch` formats both fields into the message of the
/// `EmbedderError::ModelError` it returns.
#[derive(Debug, serde::Deserialize)]
struct RpcError {
/// Numeric error code reported by the child.
code: i32,
/// Human-readable description reported by the child.
message: String,
}
/// Embedder client that exchanges newline-delimited JSON-RPC messages with a
/// child process over its stdin/stdout pipes.
///
/// Each pipe sits behind its own async mutex, so the client can be driven
/// through `&self`.
pub struct StdioEmbedderClient {
/// Write half: requests are written to the child's stdin.
stdin: Mutex<ChildStdin>,
/// Read half: responses are read line by line from the child's stdout.
stdout: Mutex<BufReader<ChildStdout>>,
}
impl StdioEmbedderClient {
    /// Builds a client from the pipes of an already-spawned child process.
    ///
    /// `stdin` is stored as-is; `stdout` is wrapped in a `BufReader` so that
    /// responses can be read one line at a time. Each handle is placed behind
    /// its own mutex.
    pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
        let stdout = Mutex::new(BufReader::new(stdout));
        let stdin = Mutex::new(stdin);
        Self { stdin, stdout }
    }
}
#[async_trait::async_trait]
impl EmbedderClient for StdioEmbedderClient {
    /// Embeds `texts` by sending one newline-terminated JSON-RPC `embed`
    /// request to the child and reading exactly one response line back.
    ///
    /// An empty `texts` returns `Ok(vec![])` without touching the child.
    /// Otherwise the returned vector holds the child's embeddings and has the
    /// same length as `texts`.
    ///
    /// The whole exchange (write, flush, and read) runs under a single
    /// deadline taken from `embed_call_timeout()`, so a child that stops
    /// draining its stdin is detected just like one that never answers.
    ///
    /// # Errors
    ///
    /// * `EmbedderError::Stdio`: the request could not be serialised, written
    ///   or flushed; the exchange timed out; the child closed stdout; the
    ///   response line is not valid JSON for the expected shape; or the
    ///   response carries neither `result` nor `error`.
    /// * `EmbedderError::ModelError`: the child answered with an `error`
    ///   object.
    /// * `EmbedderError::DimensionMismatch`: the number of embeddings returned
    ///   differs from the number of texts sent.
    ///
    /// NOTE(review): the request id is always `1` and the response id is never
    /// inspected. After a timeout, a late reply (or a partially written
    /// request) is still in the pipe and would be consumed by the next call;
    /// callers should consider the child unusable after a timeout error.
    async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbedderError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let sent = texts.len();
        tracing::debug!(n = sent, "StdioEmbedderClient: sending batch");

        let req = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: METHOD_EMBED,
            params: EmbedParams { texts: &texts },
            id: 1,
        };
        let mut payload = serde_json::to_vec(&req)
            .map_err(|e| EmbedderError::Stdio(format!("serialise JSON-RPC request: {e}")))?;
        // The protocol is line-delimited: one JSON object per line.
        payload.push(b'\n');

        let timeout = embed_call_timeout();

        // Take both pipes (always stdin first, then stdout) so a concurrent
        // caller cannot interleave its request or steal this call's response.
        let mut stdin_guard = self.stdin.lock().await;
        let mut stdout_guard = self.stdout.lock().await;

        // Write, flush and read form one unit of work under one deadline. If
        // only the read were bounded, a stalled child with a full stdin pipe
        // would block `write_all` forever while both locks are held.
        let exchange = async {
            stdin_guard
                .write_all(&payload)
                .await
                .map_err(|e| EmbedderError::Stdio(format!("write request to child stdin: {e}")))?;
            stdin_guard
                .flush()
                .await
                .map_err(|e| EmbedderError::Stdio(format!("flush child stdin: {e}")))?;
            let mut line = String::new();
            let n = stdout_guard.read_line(&mut line).await.map_err(|e| {
                EmbedderError::Stdio(format!("read response from child stdout: {e}"))
            })?;
            Ok::<_, EmbedderError>((n, line))
        };
        let outcome = tokio::time::timeout(timeout, exchange).await;

        // Decoding and validation need no pipe access; release the locks so
        // other callers are not held up by JSON parsing.
        drop(stdout_guard);
        drop(stdin_guard);

        let (n, line) = match outcome {
            Ok(exchanged) => exchanged?,
            Err(_elapsed) => {
                tracing::warn!(
                    timeout_secs = timeout.as_secs(),
                    "StdioEmbedderClient: embed call timed out — sidecar may be stalled"
                );
                return Err(EmbedderError::Stdio(format!(
                    "embed call timed out after {}s — sidecar may be stalled \
                     (set TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS to adjust)",
                    timeout.as_secs()
                )));
            }
        };

        // `read_line` returning 0 bytes means EOF on the child's stdout.
        if n == 0 {
            return Err(EmbedderError::Stdio(
                "child closed stdout before responding (process crashed?)".to_owned(),
            ));
        }

        let resp: RpcResponse = serde_json::from_str(line.trim()).map_err(|e| {
            EmbedderError::Stdio(format!("decode response (raw={:?}): {e}", line.trim()))
        })?;

        // An `error` object takes precedence over any `result`.
        if let Some(err) = resp.error {
            return Err(EmbedderError::ModelError(format!(
                "daemon RPC error {}: {}",
                err.code, err.message
            )));
        }
        let result = resp.result.ok_or_else(|| {
            EmbedderError::Stdio("response missing both result and error fields".to_owned())
        })?;

        // One embedding per input text is required.
        if result.embeddings.len() != sent {
            return Err(EmbedderError::DimensionMismatch {
                sent,
                got: result.embeddings.len(),
            });
        }
        tracing::debug!(n = sent, "StdioEmbedderClient: batch complete");
        Ok(result.embeddings)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The serialised request must carry the version tag, method name, the
    /// texts array and the id.
    #[test]
    fn request_serialises_correctly() {
        let texts: Vec<String> = ["hello", "world"].iter().map(|t| t.to_string()).collect();
        let s = serde_json::to_string(&RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: METHOD_EMBED,
            params: EmbedParams { texts: &texts },
            id: 1,
        })
        .unwrap();

        assert!(s.contains("\"jsonrpc\":\"2.0\""), "must have jsonrpc 2.0");
        assert!(s.contains("\"method\":\"embed\""), "must have embed method");
        assert!(
            s.contains("\"texts\":[\"hello\",\"world\"]"),
            "must include texts"
        );
        assert!(s.contains("\"id\":1"), "must have id");
    }

    /// A response carrying only an `error` object decodes with `result`
    /// absent and the code/message preserved.
    #[test]
    fn error_response_maps_to_model_error() {
        let raw = r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"ort failed"},"id":1}"#;
        let decoded: RpcResponse = serde_json::from_str(raw).unwrap();

        assert!(decoded.error.is_some());
        assert!(decoded.result.is_none());

        let rpc_error = decoded.error.unwrap();
        assert_eq!(rpc_error.code, -32603);
        assert!(rpc_error.message.contains("ort failed"));
    }

    /// A response carrying only a `result` object decodes into the expected
    /// embedding vectors.
    #[test]
    fn success_response_decoded() {
        let raw = r#"{"jsonrpc":"2.0","result":{"embeddings":[[0.1,0.2],[0.3,0.4]]},"id":1}"#;
        let decoded: RpcResponse = serde_json::from_str(raw).unwrap();

        assert!(decoded.error.is_none());

        let vectors = decoded.result.unwrap().embeddings;
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0][0], 0.1_f32);
    }

    /// `read_line` on a reader whose peer never writes must be cut off by the
    /// surrounding timeout rather than completing.
    #[tokio::test]
    async fn embed_call_stalled_reader_times_out() {
        // The write half is bound to a named variable so it stays alive for
        // the whole test; nothing is ever written to it.
        let (_idle_writer, read_half) = tokio::io::duplex(1024);
        let mut stalled = tokio::io::BufReader::new(read_half);
        let mut scratch = String::new();
        let deadline = Duration::from_secs(1);

        let result = tokio::time::timeout(deadline, stalled.read_line(&mut scratch)).await;

        assert!(
            result.is_err(),
            "a read_line on a never-writing reader must time out under a 1 s deadline; \
             got: {result:?}"
        );
    }
}