use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{ChildStdin, ChildStdout};
use tokio::sync::{Mutex, Semaphore, oneshot};
use tokio::time::Duration;
use super::{EmbedderClient, EmbedderError};
use crate::embedder::{ExecutionProvider, resolve_expected_provider};
/// Timeout, in seconds, used when `TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS` is
/// unset, unparseable, or zero.
const EMBED_CALL_TIMEOUT_DEFAULT_SECS: u64 = 30;

/// Returns the per-read timeout applied while waiting for sidecar output.
///
/// The value is read from `TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS` on the first
/// call and cached for the lifetime of the process; later changes to the
/// environment variable are not observed.
///
/// Surrounding whitespace is ignored. A value of `0` is rejected and replaced
/// by the default: a zero-length timeout elapses immediately, which would make
/// every pending request fail before the sidecar had a chance to answer.
fn embed_call_timeout() -> Duration {
    static CACHED: std::sync::OnceLock<Duration> = std::sync::OnceLock::new();
    *CACHED.get_or_init(|| {
        let secs = std::env::var("TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS")
            .ok()
            .and_then(|raw| raw.trim().parse::<u64>().ok())
            // Zero would turn the timeout into an instant failure; treat it as invalid.
            .filter(|&secs| secs > 0)
            .unwrap_or(EMBED_CALL_TIMEOUT_DEFAULT_SECS);
        Duration::from_secs(secs)
    })
}
/// Returns the maximum number of embed requests allowed in flight at once.
///
/// Read once from `TRUSTY_EMBED_INFLIGHT` and cached for the rest of the
/// process. A parseable value is clamped into `1..=4`; a missing or
/// unparseable value yields `2`.
fn embed_inflight() -> usize {
    static LIMIT: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *LIMIT.get_or_init(|| {
        let configured = match std::env::var("TRUSTY_EMBED_INFLIGHT") {
            Ok(raw) => raw.parse::<usize>().ok(),
            Err(_) => None,
        };
        match configured {
            Some(requested) => requested.clamp(1, 4),
            None => 2,
        }
    })
}
/// JSON-RPC method name placed in the `method` field of every request built by this client.
const METHOD_EMBED: &str = "embed";
/// Value placed in the `jsonrpc` field of every request built by this client.
const JSONRPC_VERSION: &str = "2.0";
/// Request envelope serialised to a single JSON line and written to the
/// sidecar's stdin. All string data is borrowed from the caller.
#[derive(Debug, serde::Serialize)]
struct RpcRequest<'a> {
/// Protocol version string (`JSONRPC_VERSION`).
jsonrpc: &'a str,
/// Method to invoke (`METHOD_EMBED`).
method: &'a str,
/// Method arguments: the batch of texts to embed.
params: EmbedParams<'a>,
/// Request id; the reader task uses the same id to route the response
/// back to the waiting caller.
id: u64,
}
/// `params` payload of an embed request.
#[derive(Debug, serde::Serialize)]
struct EmbedParams<'a> {
/// Texts to embed, borrowed from the caller's batch.
texts: &'a [String],
}
/// Response envelope decoded from one line of sidecar stdout.
///
/// Both fields default to `None` when absent from the JSON; `decode_response`
/// checks `error` first and then requires `result`.
#[derive(Debug, serde::Deserialize)]
struct RpcResponse {
/// Successful payload, if any.
#[serde(default)]
result: Option<EmbedResult>,
/// Error payload, if any; takes precedence over `result` when decoding.
#[serde(default)]
error: Option<RpcError>,
}
/// `result` payload of a successful embed response.
#[derive(Debug, serde::Deserialize)]
struct EmbedResult {
/// Returned vectors; `decode_response` requires the outer length to equal
/// the number of texts that were sent.
embeddings: Vec<Vec<f32>>,
}
/// `error` payload of a failed response; surfaced to callers as
/// `EmbedderError::ModelError` with both fields in the message.
#[derive(Debug, serde::Deserialize)]
struct RpcError {
/// Numeric error code reported by the sidecar.
code: i32,
/// Human-readable error description reported by the sidecar.
message: String,
}
/// Bookkeeping for one request that has been registered and is awaiting its response.
struct PendingRequest {
/// Number of texts in the request; used to validate the response's embedding count.
sent: usize,
/// One-shot channel on which the reader task delivers the outcome to the waiting caller.
reply: oneshot::Sender<Result<Vec<Vec<f32>>, EmbedderError>>,
}
/// Table of in-flight requests keyed by request id, shared between the
/// client (which inserts) and the reader task (which removes).
type PendingMap = Arc<Mutex<HashMap<u64, PendingRequest>>>;
/// Embedder client that talks newline-delimited JSON-RPC to a child process
/// over its stdin/stdout pipes.
pub struct StdioEmbedderClient {
/// Child stdin; the mutex lets each request's write and flush happen under one lock.
stdin: Arc<Mutex<ChildStdin>>,
/// In-flight requests awaiting a response from the reader task.
pending: PendingMap,
/// Limits how many requests may be outstanding at once (sized by `embed_inflight()`).
inflight: Arc<Semaphore>,
/// Source of request ids; starts at 1 and is incremented per request.
next_id: Arc<AtomicU64>,
}
impl StdioEmbedderClient {
    /// Builds a client around the child's pipes and starts the background
    /// reader that routes responses from `stdout` to waiting callers.
    ///
    /// The reader is given a handle to the same pending-request table the
    /// client uses, plus the read timeout from `embed_call_timeout()`.
    ///
    /// NOTE(review): this calls `tokio::spawn`, so it presumably has to be
    /// invoked from within a running Tokio runtime — confirm at call sites.
    pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));

        // The reader owns stdout for its whole lifetime and shares only the table.
        let reader = reader_task(
            BufReader::new(stdout),
            Arc::clone(&pending),
            embed_call_timeout(),
        );
        tokio::spawn(reader);

        Self {
            stdin: Arc::new(Mutex::new(stdin)),
            pending,
            inflight: Arc::new(Semaphore::new(embed_inflight())),
            next_id: Arc::new(AtomicU64::new(1)),
        }
    }
}
/// Maps the execution provider to a short guess at why the sidecar may have
/// stalled; the text is embedded in the reader's timeout warning.
fn timeout_stall_hint(provider: ExecutionProvider) -> &'static str {
    // Both CoreML flavours share one hint.
    const COREML_HINT: &str = "CoreML/ANE session-init or oversized-batch stall?";

    let hint = match provider {
        ExecutionProvider::Cpu => "embedder sidecar stall?",
        ExecutionProvider::Cuda => "CUDA OOM/BFCArena stall?",
        ExecutionProvider::CoreML => COREML_HINT,
        ExecutionProvider::CoreMLAne => COREML_HINT,
    };
    hint
}
async fn reader_task<R: AsyncBufRead + Unpin>(
mut reader: R,
pending: PendingMap,
timeout: Duration,
) {
let mut line = String::new();
loop {
line.clear();
let oldest_id: Option<u64> = {
let guard = pending.lock().await;
if guard.is_empty() {
None
} else {
guard.keys().copied().min()
}
};
let read_result = tokio::time::timeout(timeout, reader.read_line(&mut line)).await;
match read_result {
Err(_elapsed) => {
let stall_hint = timeout_stall_hint(resolve_expected_provider());
if let Some(id) = oldest_id {
tracing::warn!(
timeout_secs = timeout.as_secs(),
timed_out_id = id,
"StdioEmbedderClient reader: timed out waiting for response \
({}s — {}) — removing stalled entry, \
re-arming; task STAYS ALIVE",
timeout.as_secs(),
stall_hint,
);
} else {
tracing::debug!(
timeout_secs = timeout.as_secs(),
timed_out_id = ?oldest_id,
"StdioEmbedderClient reader: timeout fired with no in-flight \
request (idle re-arm, {}s — embedder healthy) — re-arming; \
task STAYS ALIVE",
timeout.as_secs(),
);
}
if let Some(id) = oldest_id {
let req = {
let mut guard = pending.lock().await;
guard.remove(&id)
};
if let Some(r) = req {
let _ = r.reply.send(Err(EmbedderError::Stdio(format!(
"embed call timed out after {}s (id={id}) — sidecar \
stalled (set TRUSTY_EMBEDDERD_CALL_TIMEOUT_SECS to adjust)",
timeout.as_secs()
))));
}
}
line.clear();
continue;
}
Ok(Err(e)) => {
tracing::warn!(
"StdioEmbedderClient reader: IO error reading from sidecar stdout: {e}"
);
drain_pending_with_error(
&pending,
EmbedderError::Stdio(format!("read response from child stdout: {e}")),
)
.await;
return;
}
Ok(Ok(0)) => {
tracing::info!(
"StdioEmbedderClient reader: stdout EOF \
(sidecar exited) — draining pending requests"
);
drain_pending_with_error(
&pending,
EmbedderError::Stdio(
"child closed stdout before responding (process exited)".to_owned(),
),
)
.await;
return;
}
Ok(Ok(_)) => {
}
}
let resp_id: Option<u64> = extract_response_id(line.trim());
let Some(response_id) = resp_id else {
tracing::warn!(
raw = %line.trim(),
"StdioEmbedderClient reader: received response with no parseable id — \
discarding (malformed sidecar frame)"
);
continue;
};
let req = {
let mut guard = pending.lock().await;
guard.remove(&response_id)
};
let Some(pending_req) = req else {
tracing::warn!(
response_id,
"StdioEmbedderClient reader: received response for id={} but \
no pending entry found — discarding stale/orphaned frame \
(likely a late reply for a previously timed-out request)",
response_id
);
continue;
};
let result = decode_response(line.trim(), pending_req.sent);
let _ = pending_req.reply.send(result);
}
}
/// Pulls the numeric `id` out of a raw response line without decoding the
/// rest of the payload.
///
/// Returns `None` when the line is not valid JSON for the envelope, when `id`
/// is missing or `null`, or when it is not a number representable as `u64`.
fn extract_response_id(line: &str) -> Option<u64> {
    /// Minimal view of a response: only the `id` field is inspected.
    #[derive(serde::Deserialize)]
    struct Envelope {
        #[serde(default)]
        id: Option<serde_json::Value>,
    }

    serde_json::from_str::<Envelope>(line)
        .ok()
        .and_then(|envelope| envelope.id)
        .and_then(|raw_id| raw_id.as_u64())
}
fn decode_response(line: &str, sent: usize) -> Result<Vec<Vec<f32>>, EmbedderError> {
let resp: RpcResponse = serde_json::from_str(line)
.map_err(|e| EmbedderError::Stdio(format!("decode response (raw={line:?}): {e}")))?;
if let Some(err) = resp.error {
return Err(EmbedderError::ModelError(format!(
"daemon RPC error {}: {}",
err.code, err.message
)));
}
let result = resp.result.ok_or_else(|| {
EmbedderError::Stdio("response missing both result and error fields".to_owned())
})?;
if result.embeddings.len() != sent {
return Err(EmbedderError::DimensionMismatch {
sent,
got: result.embeddings.len(),
});
}
Ok(result.embeddings)
}
/// Empties the pending table, failing every waiting caller with an
/// `EmbedderError::Stdio` that carries `error`'s message.
///
/// Whatever variant `error` is, each caller receives a `Stdio` error; the
/// text is taken from the variant's payload (or its `Display` output for
/// variants not listed explicitly).
async fn drain_pending_with_error(pending: &PendingMap, error: EmbedderError) {
    // Every entry gets the same text, so render it once up front.
    let message = match &error {
        EmbedderError::Stdio(text) => text.clone(),
        EmbedderError::ModelError(text) => text.clone(),
        EmbedderError::DimensionMismatch { sent, got } => {
            format!("dimension mismatch: sent={sent}, got={got}")
        }
        other => other.to_string(),
    };

    let mut table = pending.lock().await;
    table.drain().for_each(|(_id, waiter)| {
        // A caller that stopped waiting has dropped its receiver; ignore that.
        let _ = waiter.reply.send(Err(EmbedderError::Stdio(message.clone())));
    });
}
#[async_trait::async_trait]
impl EmbedderClient for StdioEmbedderClient {
    /// Sends `texts` to the sidecar as one JSON-RPC `embed` request and waits
    /// for the reader task to deliver the matching response.
    ///
    /// An empty batch returns `Ok(vec![])` without contacting the sidecar.
    /// Otherwise the call holds one in-flight permit for its whole duration.
    ///
    /// # Errors
    /// * `EmbedderError::Stdio` — the semaphore is closed, the request cannot
    ///   be serialised or written, or the reader task dropped the reply
    ///   channel.
    /// * Any error produced by the reader task for this request (timeout,
    ///   sidecar exit, decode failure, RPC error, embedding-count mismatch).
    async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbedderError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let sent = texts.len();
        let _permit = self
            .inflight
            .acquire()
            .await
            .map_err(|_| EmbedderError::Stdio("inflight semaphore closed".to_owned()))?;
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        tracing::debug!(n = sent, id, "StdioEmbedderClient: sending batch");

        // Serialise before registering the request: a failure here must not
        // leave an entry in the pending table that nobody is waiting on.
        let req = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: METHOD_EMBED,
            params: EmbedParams { texts: &texts },
            id,
        };
        let mut payload = serde_json::to_vec(&req)
            .map_err(|e| EmbedderError::Stdio(format!("serialise JSON-RPC request: {e}")))?;
        payload.push(b'\n');

        // Register before writing so the reader can route a fast response.
        let (reply_tx, reply_rx) = oneshot::channel();
        {
            let mut guard = self.pending.lock().await;
            guard.insert(
                id,
                PendingRequest {
                    sent,
                    reply: reply_tx,
                },
            );
        }

        // Write and flush under a single stdin lock so concurrent requests
        // cannot interleave their bytes.
        let write_outcome = {
            let mut stdin_guard = self.stdin.lock().await;
            match stdin_guard.write_all(&payload).await {
                Ok(()) => stdin_guard
                    .flush()
                    .await
                    .map_err(|e| EmbedderError::Stdio(format!("flush child stdin: {e}"))),
                Err(e) => Err(EmbedderError::Stdio(format!(
                    "write request to child stdin: {e}"
                ))),
            }
        };
        if let Err(write_err) = write_outcome {
            // The request never (fully) reached the sidecar. Unregister it so
            // the reader does not later time it out in place of a request
            // that is genuinely stalled.
            self.pending.lock().await.remove(&id);
            return Err(write_err);
        }

        let result = reply_rx.await.map_err(|_| {
            EmbedderError::Stdio(
                "reader task dropped reply channel (sidecar crashed or was restarted)".to_owned(),
            )
        })?;
        tracing::debug!(n = sent, id, "StdioEmbedderClient: batch complete");
        result
    }
}
// Unit tests live in the sibling file `stdio_tests.rs` and are compiled only for test builds.
#[cfg(test)]
#[path = "stdio_tests.rs"]
mod tests;