use anyhow::{Context, Result};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
/// Fallback reply timeout, used when `TRUSTY_EMBED_POOL_REPLY_TIMEOUT_SECS` is
/// unset, unparseable, or zero.
const DEFAULT_REPLY_TIMEOUT_SECS: u64 = 60;

/// Converts a raw `TRUSTY_EMBED_POOL_REPLY_TIMEOUT_SECS` value into a timeout.
///
/// Surrounding whitespace is ignored. Missing, unparseable, or zero values fall
/// back to [`DEFAULT_REPLY_TIMEOUT_SECS`]: a zero timeout would make every pool
/// request fail immediately, so it is never honoured.
fn parse_reply_timeout(raw: Option<&str>) -> std::time::Duration {
    let secs = raw
        .and_then(|value| value.trim().parse::<u64>().ok())
        .filter(|&secs| secs > 0)
        .unwrap_or(DEFAULT_REPLY_TIMEOUT_SECS);
    std::time::Duration::from_secs(secs)
}

/// Timeout used by `EmbedPool::embed` while waiting on the worker pool.
///
/// The environment variable is read once per process and cached; later changes
/// to it have no effect.
fn reply_timeout() -> std::time::Duration {
    static CACHED: std::sync::OnceLock<std::time::Duration> = std::sync::OnceLock::new();
    *CACHED.get_or_init(|| {
        parse_reply_timeout(
            std::env::var("TRUSTY_EMBED_POOL_REPLY_TIMEOUT_SECS")
                .ok()
                .as_deref(),
        )
    })
}
use crate::core::embed::Embedder;
/// Selects which of the pool's two request lanes an embedding job is queued on.
///
/// Workers poll the interactive lane first, so interactive requests are picked
/// up ahead of background ones whenever both lanes have work waiting.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum RequestPriority {
    /// Queued on the interactive lane, which workers check first.
    Interactive,
    /// Queued on the background lane, served when the interactive lane is idle.
    Background,
}
/// Bound on each request lane; `send().await` waits for space once a lane holds
/// this many queued requests.
const LANE_CAPACITY: usize = 64;

/// A single embedding job handed from `EmbedPool::embed` to a worker thread.
pub(crate) struct EmbedRequest {
    /// Lane the job was submitted on; the worker reports it in its trace output.
    pub(crate) priority: RequestPriority,
    /// Input texts passed to the embedder as one batch.
    pub(crate) texts: Vec<String>,
    /// One-shot channel on which the worker sends the embedding result or error.
    pub(crate) reply: oneshot::Sender<Result<Vec<Vec<f32>>>>,
}
/// Pool of embedding workers, each running on its own dedicated OS thread.
///
/// Requests flow through two bounded lanes (interactive and background); workers
/// poll the interactive lane first.
pub struct EmbedPool {
    /// Sender side of the interactive lane.
    pub(crate) interactive_tx: mpsc::Sender<EmbedRequest>,
    /// Sender side of the background lane.
    pub(crate) background_tx: mpsc::Sender<EmbedRequest>,
    /// Number of worker threads spawned (`new` clamps this to at least 1).
    workers: usize,
    /// Incremented/decremented around each non-empty `embed` call; mirrored to
    /// the `trusty_embed_pool_utilisation` gauge.
    in_flight: Arc<AtomicUsize>,
    /// Worker thread handles. They are never joined: dropping a `JoinHandle`
    /// detaches its thread. Workers stop when their lane receive reports the
    /// channel closed, i.e. after the senders (and any crate-internal clones of
    /// them) are dropped.
    _worker_threads: Vec<std::thread::JoinHandle<()>>,
    /// Optional tracker told about the outcome of each non-empty `embed` call.
    stall_tracker: Option<Arc<crate::service::stall_tracker::EmbedderStallTracker>>,
}
impl EmbedPool {
pub fn new(workers: usize, embedder: Arc<dyn Embedder>) -> Self {
let workers = workers.max(1);
let (interactive_tx, interactive_rx) = mpsc::channel::<EmbedRequest>(LANE_CAPACITY);
let (background_tx, background_rx) = mpsc::channel::<EmbedRequest>(LANE_CAPACITY);
let interactive_rx = Arc::new(tokio::sync::Mutex::new(interactive_rx));
let background_rx = Arc::new(tokio::sync::Mutex::new(background_rx));
let in_flight = Arc::new(AtomicUsize::new(0));
metrics::gauge!("trusty_embed_pool_workers").set(workers as f64);
let mut worker_threads = Vec::with_capacity(workers);
for worker_id in 0..workers {
let interactive_rx = Arc::clone(&interactive_rx);
let background_rx = Arc::clone(&background_rx);
let embedder = Arc::clone(&embedder);
let handle = std::thread::Builder::new()
.name(format!("trusty-embed-{worker_id}"))
.spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.thread_name(format!("trusty-embed-io-{worker_id}"))
.build()
.expect("embed worker: failed to build tokio runtime");
rt.block_on(worker_loop(
worker_id,
interactive_rx,
background_rx,
embedder,
));
})
.expect("embed worker: failed to spawn OS thread");
worker_threads.push(handle);
}
Self {
interactive_tx,
background_tx,
workers,
in_flight,
_worker_threads: worker_threads,
stall_tracker: None,
}
}
pub fn with_stall_tracker(
mut self,
tracker: Arc<crate::service::stall_tracker::EmbedderStallTracker>,
) -> Self {
self.stall_tracker = Some(tracker);
self
}
pub fn with_autotune(embedder: Arc<dyn Embedder>) -> Self {
let workers = autotune_workers();
tracing::info!("embed pool: {} workers (isolated OS threads)", workers);
Self::new(workers, embedder)
}
pub async fn embed(
&self,
texts: Vec<String>,
priority: RequestPriority,
) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let (reply_tx, reply_rx) = oneshot::channel();
let req = EmbedRequest {
texts,
reply: reply_tx,
priority,
};
let tx = match priority {
RequestPriority::Interactive => &self.interactive_tx,
RequestPriority::Background => &self.background_tx,
};
self.in_flight.fetch_add(1, Ordering::Relaxed);
metrics::gauge!("trusty_embed_pool_utilisation")
.set(self.in_flight.load(Ordering::Relaxed) as f64);
let send_result = tx.send(req).await.context("embed pool closed");
let deadline = reply_timeout();
let result = match send_result {
Ok(()) => match tokio::time::timeout(deadline, reply_rx).await {
Ok(Ok(r)) => r,
Ok(Err(_)) => Err(anyhow::anyhow!("embed pool worker dropped reply")),
Err(_elapsed) => Err(anyhow::anyhow!(
"embed pool reply timed out after {}s — worker may have panicked \
(set TRUSTY_EMBED_POOL_REPLY_TIMEOUT_SECS to adjust)",
deadline.as_secs()
)),
},
Err(e) => Err(e),
};
self.in_flight.fetch_sub(1, Ordering::Relaxed);
metrics::gauge!("trusty_embed_pool_utilisation")
.set(self.in_flight.load(Ordering::Relaxed) as f64);
if let Some(tracker) = &self.stall_tracker {
if result.is_ok() {
tracker.record_success();
} else {
tracker.record_timeout();
}
}
result
}
pub fn workers(&self) -> usize {
self.workers
}
}
/// Body of one embed worker.
///
/// Repeatedly takes the next request (interactive lane first), runs it through
/// `embedder.embed_batch`, and sends the result on the request's reply channel.
/// Returns once both lanes are closed and drained.
async fn worker_loop(
    worker_id: usize,
    interactive_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<EmbedRequest>>>,
    background_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<EmbedRequest>>>,
    embedder: Arc<dyn Embedder>,
) {
    loop {
        let req = {
            // Only one worker at a time waits on the lanes. Both guards are
            // released at the end of this block, before the (slow) embedding
            // runs, so other workers can dequeue concurrently. Every worker locks
            // in the same order (interactive, then background), so the two locks
            // cannot deadlock.
            let mut interactive_guard = interactive_rx.lock().await;
            let mut background_guard = background_rx.lock().await;
            // `biased` polls the interactive lane first. When a lane is closed,
            // its `recv()` yields `None`; the `Some(..)` pattern then disables
            // that branch instead of ending the select. Requests still queued on
            // the other lane are therefore drained, and `else` fires only once
            // both lanes are closed.
            tokio::select! {
                biased;
                Some(msg) = interactive_guard.recv() => Some(msg),
                Some(msg) = background_guard.recv() => Some(msg),
                else => None,
            }
        };
        let Some(EmbedRequest {
            texts,
            reply,
            priority,
        }) = req
        else {
            tracing::debug!(worker_id, "embed pool worker exiting (channels closed)");
            return;
        };
        // The caller dropped its receiver while this request was queued (for
        // example after its reply timeout elapsed). The result would be
        // discarded, so skip the work and help a backed-up pool recover faster.
        if reply.is_closed() {
            tracing::debug!(
                worker_id,
                priority = ?priority,
                batch_size = texts.len(),
                "embed pool skipping abandoned request"
            );
            continue;
        }
        let started = std::time::Instant::now();
        let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
        let result = embedder
            .embed_batch(&text_refs)
            .await
            .context("embed pool worker: embed_batch failed");
        let elapsed_ms = started.elapsed().as_millis() as u64;
        tracing::trace!(
            worker_id,
            priority = ?priority,
            batch_size = texts.len(),
            elapsed_ms,
            "embed pool dispatched batch"
        );
        // The caller may have given up during the embedding; an undeliverable
        // result is simply dropped.
        let _ = reply.send(result);
    }
}
pub fn autotune_workers() -> usize {
if let Ok(raw) = std::env::var("TRUSTY_EMBED_WORKERS") {
if let Ok(n) = raw.parse::<usize>() {
return n.max(1);
}
}
let ram_mb = crate::core::memory_policy::detect_total_ram_mb().unwrap_or(8 * 1024);
let ram_gb = ram_mb / 1024;
if ram_gb <= 16 {
1
} else if ram_gb <= 32 {
2
} else {
4
}
}
// Unit tests for this module are kept in `embed_pool_tests.rs` (loaded via
// `#[path]`) and compiled only for test builds.
#[cfg(test)]
#[path = "embed_pool_tests.rs"]
mod tests;