use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{ChildStdin, ChildStdout};
use tokio::sync::{Mutex, Semaphore, oneshot};
use tokio::time::Duration;
use super::{EmbedderClient, EmbedderError};
/// Fallback call timeout, in whole seconds, used when
/// `TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS` is unset or is not a valid `u64`.
const EMBED_CALL_TIMEOUT_DEFAULT_SECS: u64 = 120;

/// Returns how long to wait on the sidecar for a response.
///
/// The value is read once from `TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS` (whole
/// seconds) and cached for the rest of the process; a missing or unparseable
/// value falls back to [`EMBED_CALL_TIMEOUT_DEFAULT_SECS`].
fn embed_call_timeout() -> Duration {
    static CACHED: std::sync::OnceLock<Duration> = std::sync::OnceLock::new();
    *CACHED.get_or_init(|| {
        let secs = match std::env::var("TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS") {
            Ok(raw) => raw
                .parse::<u64>()
                .unwrap_or(EMBED_CALL_TIMEOUT_DEFAULT_SECS),
            Err(_) => EMBED_CALL_TIMEOUT_DEFAULT_SECS,
        };
        Duration::from_secs(secs)
    })
}
/// Returns the maximum number of embed calls allowed in flight at once.
///
/// Read once from `TRUSTY_EMBED_INFLIGHT` and cached for the rest of the
/// process. Parsed values are clamped to `1..=4`; a missing or unparseable
/// value yields 2.
fn embed_inflight() -> usize {
    static CACHED: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *CACHED.get_or_init(|| {
        match std::env::var("TRUSTY_EMBED_INFLIGHT").map(|raw| raw.parse::<usize>()) {
            Ok(Ok(requested)) => requested.clamp(1, 4),
            _ => 2,
        }
    })
}
/// JSON-RPC method name used for embedding requests.
const METHOD_EMBED: &str = "embed";
/// JSON-RPC protocol version sent in every request.
const JSONRPC_VERSION: &str = "2.0";
/// One JSON-RPC request, serialised as a single line on the sidecar's stdin.
#[derive(Debug, serde::Serialize)]
struct RpcRequest<'a> {
    jsonrpc: &'a str,
    method: &'a str,
    params: EmbedParams<'a>,
    // Request id. The reader routes responses by arrival order, not by id.
    id: u64,
}
/// `params` object of an `embed` request: the caller's texts, borrowed so the
/// request can be serialised without copying them.
#[derive(Debug, serde::Serialize)]
struct EmbedParams<'a> {
    texts: &'a [String],
}
/// One JSON-RPC response line read from the sidecar's stdout.
///
/// `result` and `error` both default to `None` when absent from the JSON.
#[derive(Debug, serde::Deserialize)]
struct RpcResponse {
    #[serde(default)]
    result: Option<EmbedResult>,
    #[serde(default)]
    error: Option<RpcError>,
    // Parsed but never read: responses are matched to requests by arrival
    // order, so the echoed id only shows up in Debug output.
    #[allow(dead_code)]
    #[serde(default)]
    id: Option<serde_json::Value>,
}
/// Successful `result` payload: the returned embedding vectors, expected to
/// be one per input text.
#[derive(Debug, serde::Deserialize)]
struct EmbedResult {
    embeddings: Vec<Vec<f32>>,
}
/// JSON-RPC `error` object reported by the sidecar.
#[derive(Debug, serde::Deserialize)]
struct RpcError {
    code: i32,
    message: String,
}
/// A queued embed call that is waiting for its response from the sidecar.
struct PendingRequest {
    /// Number of texts in the request; the response must contain this many
    /// embeddings.
    sent: usize,
    /// Channel on which the reader task delivers the decoded result.
    reply: oneshot::Sender<Result<Vec<Vec<f32>>, EmbedderError>>,
}
/// Outstanding requests, oldest first. Shared between callers, which push to
/// the back, and the reader task, which pops from the front.
type PendingQueue = Arc<Mutex<VecDeque<PendingRequest>>>;
/// [`EmbedderClient`] that speaks newline-delimited JSON-RPC 2.0 to an
/// embedding sidecar over the child's stdin/stdout pipes.
///
/// Requests are written to the child's stdin. A background reader task,
/// spawned by [`StdioEmbedderClient::new`], reads replies from stdout and
/// hands them to waiting callers in FIFO order.
pub struct StdioEmbedderClient {
    /// Child stdin. The lock is held for a whole request line so writes from
    /// concurrent callers never interleave.
    stdin: Arc<Mutex<ChildStdin>>,
    /// Outstanding requests, shared with the reader task.
    pending: PendingQueue,
    /// Caps the number of concurrent in-flight calls (`TRUSTY_EMBED_INFLIGHT`).
    inflight: Arc<Semaphore>,
    /// Monotonic source of JSON-RPC request ids, starting at 1.
    next_id: Arc<AtomicU64>,
}
impl StdioEmbedderClient {
    /// Wraps the sidecar's pipes and spawns the background reader task.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it calls
    /// [`tokio::spawn`].
    pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
        let inflight = Arc::new(Semaphore::new(embed_inflight()));
        let pending: PendingQueue = Arc::new(Mutex::new(VecDeque::new()));
        tokio::spawn(reader_task(BufReader::new(stdout), Arc::clone(&pending)));
        Self {
            stdin: Arc::new(Mutex::new(stdin)),
            pending,
            inflight,
            next_id: Arc::new(AtomicU64::new(1)),
        }
    }
}
/// Background task: reads newline-delimited JSON-RPC responses from the
/// sidecar's stdout and hands each one to the oldest pending request.
///
/// Responses are matched to requests purely by arrival order (FIFO); the
/// response `id` is not consulted.
///
/// The stall timeout ([`embed_call_timeout`]) only runs while at least one
/// request is waiting, so an idle sidecar never trips it. If the sidecar
/// stalls with requests outstanding, hits a read error, or closes stdout,
/// every pending request is failed and the task exits.
async fn reader_task(mut reader: BufReader<ChildStdout>, pending: PendingQueue) {
    let timeout = embed_call_timeout();
    // How often to wake up and re-check the queue while no line has arrived.
    // The lower bound keeps a zero timeout from turning this into a busy spin.
    let tick = timeout.clamp(Duration::from_millis(10), Duration::from_secs(1));
    // `read_until`, unlike `read_line`, keeps any partially read bytes in
    // `buf` when the tick timeout drops its future. A response that is still
    // being written is therefore not lost between polls.
    let mut buf: Vec<u8> = Vec::new();
    // When an outstanding request was first observed without progress.
    let mut waiting_since: Option<tokio::time::Instant> = None;
    loop {
        let read_result = tokio::time::timeout(tick, reader.read_until(b'\n', &mut buf)).await;
        match read_result {
            Err(_elapsed) => {
                let idle = pending.lock().await.is_empty();
                if idle {
                    // Nobody is waiting on the sidecar; silence is expected.
                    waiting_since = None;
                    continue;
                }
                let since = *waiting_since.get_or_insert_with(tokio::time::Instant::now);
                if since.elapsed() < timeout {
                    continue;
                }
                tracing::warn!(
                    timeout_secs = timeout.as_secs(),
                    "StdioEmbedderClient reader: timed out waiting for response \
                     (sidecar may be stalled) — draining pending requests"
                );
                drain_pending_with_error(
                    &pending,
                    EmbedderError::Stdio(format!(
                        "embed call timed out after {}s — sidecar may be stalled \
                         (set TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS to adjust)",
                        timeout.as_secs()
                    )),
                )
                .await;
                return;
            }
            Ok(Err(e)) => {
                tracing::warn!(
                    "StdioEmbedderClient reader: IO error reading from sidecar stdout: {e}"
                );
                drain_pending_with_error(
                    &pending,
                    EmbedderError::Stdio(format!("read response from child stdout: {e}")),
                )
                .await;
                return;
            }
            Ok(Ok(0)) => {
                tracing::info!(
                    "StdioEmbedderClient reader: stdout EOF \
                     (sidecar exited) — draining pending requests"
                );
                drain_pending_with_error(
                    &pending,
                    EmbedderError::Stdio(
                        "child closed stdout before responding (process exited)".to_owned(),
                    ),
                )
                .await;
                return;
            }
            Ok(Ok(_)) => {}
        }
        // A line arrived, so the sidecar is making progress. Restart the stall
        // clock for whatever is still outstanding.
        waiting_since = None;
        dispatch_response(&pending, &buf).await;
        buf.clear();
    }
}

/// Routes one raw response line from the sidecar to the oldest pending request.
///
/// Blank lines are skipped rather than consuming a request, since that would
/// shift every later response onto the wrong caller. A line that is not valid
/// UTF-8 still counts as that request's response and fails it with
/// [`EmbedderError::Stdio`].
async fn dispatch_response(pending: &PendingQueue, raw: &[u8]) {
    let text = std::str::from_utf8(raw).map(str::trim);
    if matches!(text, Ok("")) {
        tracing::debug!("StdioEmbedderClient reader: skipping blank line from sidecar");
        return;
    }
    let popped = pending.lock().await.pop_front();
    let Some(pending_req) = popped else {
        tracing::warn!(
            "StdioEmbedderClient reader: received response but pending queue is empty \
             (spurious frame from sidecar?) — ignoring"
        );
        return;
    };
    let result = match text {
        Ok(line) => decode_response(line, pending_req.sent),
        Err(e) => Err(EmbedderError::Stdio(format!(
            "response from child stdout is not valid UTF-8: {e}"
        ))),
    };
    // The caller may already have given up (e.g. on its own timeout); a closed
    // receiver is not an error for the reader.
    let _ = pending_req.reply.send(result);
}
fn decode_response(line: &str, sent: usize) -> Result<Vec<Vec<f32>>, EmbedderError> {
let resp: RpcResponse = serde_json::from_str(line)
.map_err(|e| EmbedderError::Stdio(format!("decode response (raw={line:?}): {e}")))?;
if let Some(err) = resp.error {
return Err(EmbedderError::ModelError(format!(
"daemon RPC error {}: {}",
err.code, err.message
)));
}
let result = resp.result.ok_or_else(|| {
EmbedderError::Stdio("response missing both result and error fields".to_owned())
})?;
if result.embeddings.len() != sent {
return Err(EmbedderError::DimensionMismatch {
sent,
got: result.embeddings.len(),
});
}
Ok(result.embeddings)
}
/// Fails every pending request and leaves the queue empty.
///
/// `error` is rendered to a message once. Each waiter then receives its own
/// [`EmbedderError::Stdio`] carrying that message, so non-`Stdio` variants are
/// flattened into `Stdio`. Waiters whose receivers were already dropped are
/// skipped silently.
async fn drain_pending_with_error(pending: &PendingQueue, error: EmbedderError) {
    let message = match error {
        EmbedderError::Stdio(msg) | EmbedderError::ModelError(msg) => msg,
        EmbedderError::DimensionMismatch { sent, got } => {
            format!("dimension mismatch: sent={sent}, got={got}")
        }
        other => other.to_string(),
    };
    let mut queue = pending.lock().await;
    for waiter in queue.drain(..) {
        let _ = waiter
            .reply
            .send(Err(EmbedderError::Stdio(message.clone())));
    }
}
#[async_trait::async_trait]
impl EmbedderClient for StdioEmbedderClient {
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbedderError> {
if texts.is_empty() {
return Ok(vec![]);
}
let sent = texts.len();
let _permit = self
.inflight
.acquire()
.await
.map_err(|_| EmbedderError::Stdio("inflight semaphore closed".to_owned()))?;
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
tracing::debug!(n = sent, id, "StdioEmbedderClient: sending batch");
let (reply_tx, reply_rx) = oneshot::channel();
{
let mut guard = self.pending.lock().await;
guard.push_back(PendingRequest {
sent,
reply: reply_tx,
});
}
let req = RpcRequest {
jsonrpc: JSONRPC_VERSION,
method: METHOD_EMBED,
params: EmbedParams { texts: &texts },
id,
};
let mut payload = serde_json::to_vec(&req)
.map_err(|e| EmbedderError::Stdio(format!("serialise JSON-RPC request: {e}")))?;
payload.push(b'\n');
{
let mut stdin_guard = self.stdin.lock().await;
stdin_guard
.write_all(&payload)
.await
.map_err(|e| EmbedderError::Stdio(format!("write request to child stdin: {e}")))?;
stdin_guard
.flush()
.await
.map_err(|e| EmbedderError::Stdio(format!("flush child stdin: {e}")))?;
}
let result = reply_rx.await.map_err(|_| {
EmbedderError::Stdio(
"reader task dropped reply channel (sidecar crashed or was restarted)".to_owned(),
)
})?;
tracing::debug!(n = sent, id, "StdioEmbedderClient: batch complete");
result
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_serialises_correctly() {
        let texts = vec!["hello".to_string(), "world".to_string()];
        let req = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: METHOD_EMBED,
            params: EmbedParams { texts: &texts },
            id: 1,
        };
        let s = serde_json::to_string(&req).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""), "must have jsonrpc 2.0");
        assert!(s.contains("\"method\":\"embed\""), "must have embed method");
        assert!(
            s.contains("\"texts\":[\"hello\",\"world\"]"),
            "must include texts"
        );
        assert!(s.contains("\"id\":1"), "must have id");
    }

    #[test]
    fn error_response_maps_to_model_error() {
        let json = r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"ort failed"},"id":1}"#;
        let result = decode_response(json, 1);
        assert!(
            matches!(result, Err(EmbedderError::ModelError(_))),
            "got: {result:?}"
        );
    }

    #[test]
    fn error_takes_precedence_over_result() {
        let json = r#"{"jsonrpc":"2.0","result":{"embeddings":[[0.1]]},"error":{"code":1,"message":"x"},"id":1}"#;
        let result = decode_response(json, 1);
        assert!(
            matches!(result, Err(EmbedderError::ModelError(_))),
            "got: {result:?}"
        );
    }

    #[test]
    fn success_response_decoded() {
        let json = r#"{"jsonrpc":"2.0","result":{"embeddings":[[0.1,0.2],[0.3,0.4]]},"id":1}"#;
        let result = decode_response(json, 2).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0][0], 0.1_f32);
    }

    #[test]
    fn count_mismatch_returns_dimension_error() {
        let json = r#"{"jsonrpc":"2.0","result":{"embeddings":[[0.1],[0.2]]},"id":1}"#;
        let result = decode_response(json, 3);
        assert!(
            matches!(
                result,
                Err(EmbedderError::DimensionMismatch { sent: 3, got: 2 })
            ),
            "got: {result:?}"
        );
    }

    #[test]
    fn malformed_json_maps_to_stdio_error() {
        let result = decode_response("not json", 1);
        assert!(
            matches!(result, Err(EmbedderError::Stdio(_))),
            "got: {result:?}"
        );
    }

    #[test]
    fn response_without_result_or_error_is_rejected() {
        let result = decode_response(r#"{"jsonrpc":"2.0","id":1}"#, 1);
        assert!(
            matches!(result, Err(EmbedderError::Stdio(_))),
            "got: {result:?}"
        );
    }

    #[tokio::test]
    async fn drain_pending_fails_every_waiter_and_empties_queue() {
        let pending: PendingQueue = Arc::new(Mutex::new(VecDeque::new()));
        let mut receivers = Vec::new();
        {
            let mut guard = pending.lock().await;
            for sent in 1..=3 {
                let (reply, rx) = oneshot::channel();
                guard.push_back(PendingRequest { sent, reply });
                receivers.push(rx);
            }
        }

        drain_pending_with_error(&pending, EmbedderError::Stdio("boom".to_owned())).await;

        assert!(pending.lock().await.is_empty(), "queue must be drained");
        for rx in receivers {
            match rx.await {
                Ok(Err(EmbedderError::Stdio(msg))) => assert_eq!(msg, "boom"),
                other => panic!("unexpected reply: {other:?}"),
            }
        }
    }
}