use anyhow::{Context, Result};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use crate::core::embed::Embedder;
/// Scheduling class of an embedding request.
///
/// [`EmbedPool::embed`] routes each request to the queue ("lane") matching its
/// priority, and workers poll the interactive lane before the background lane.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RequestPriority {
/// Routed to the interactive lane, which workers check first.
Interactive,
/// Routed to the background lane, which workers check after the interactive lane.
Background,
}
/// Buffer size of each priority lane's bounded `mpsc` channel.
///
/// The same bound is used for the interactive and the background lane; when a
/// lane is full, sending into it waits until a worker takes a request out.
const LANE_CAPACITY: usize = 64;
/// One unit of work travelling from [`EmbedPool::embed`] to a worker task.
struct EmbedRequest {
/// Texts to embed as a single batch.
texts: Vec<String>,
/// One-shot channel on which the worker sends back the batch result.
reply: oneshot::Sender<Result<Vec<Vec<f32>>>>,
/// Priority the request was submitted with; the worker includes it in its trace log.
priority: RequestPriority,
}
/// Handle to a set of worker tasks that run embedding batches on a shared embedder.
///
/// Requests are submitted through two bounded lanes, one per [`RequestPriority`].
/// Dropping the pool drops both senders, which closes the lanes and lets the
/// workers exit.
pub struct EmbedPool {
/// Sending side of the interactive lane.
interactive_tx: mpsc::Sender<EmbedRequest>,
/// Sending side of the background lane.
background_tx: mpsc::Sender<EmbedRequest>,
/// Number of worker tasks spawned for this pool (always at least one).
workers: usize,
/// Number of `embed` calls currently counted as in progress; its value is
/// published as the `trusty_embed_pool_utilisation` gauge.
in_flight: Arc<AtomicUsize>,
}
impl EmbedPool {
    /// Builds a pool of `workers` worker tasks (clamped to at least one) that all
    /// share `embedder`.
    ///
    /// Two bounded lanes of [`LANE_CAPACITY`] requests are created, one per
    /// [`RequestPriority`], and the effective worker count is recorded in the
    /// `trusty_embed_pool_workers` gauge.
    ///
    /// # Panics
    ///
    /// The workers are started with `tokio::spawn`, which panics when called
    /// outside a Tokio runtime.
    pub fn new(workers: usize, embedder: Arc<dyn Embedder>) -> Self {
        let workers = workers.max(1);
        let (interactive_tx, interactive_rx) = mpsc::channel::<EmbedRequest>(LANE_CAPACITY);
        let (background_tx, background_rx) = mpsc::channel::<EmbedRequest>(LANE_CAPACITY);
        // An mpsc receiver has a single consumer, so the workers share each one
        // behind an async mutex.
        let interactive_rx = Arc::new(tokio::sync::Mutex::new(interactive_rx));
        let background_rx = Arc::new(tokio::sync::Mutex::new(background_rx));

        metrics::gauge!("trusty_embed_pool_workers").set(workers as f64);
        for worker_id in 0..workers {
            // Join handles are intentionally not kept: a worker returns on its own
            // once the lanes are closed.
            tokio::spawn(worker_loop(
                worker_id,
                Arc::clone(&interactive_rx),
                Arc::clone(&background_rx),
                Arc::clone(&embedder),
            ));
        }

        Self {
            interactive_tx,
            background_tx,
            workers,
            in_flight: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Builds a pool whose worker count comes from [`autotune_workers`], logging
    /// the chosen count at `info` level.
    pub fn with_autotune(embedder: Arc<dyn Embedder>) -> Self {
        let workers = autotune_workers();
        tracing::info!("embed pool: {} workers", workers);
        Self::new(workers, embedder)
    }

    /// Embeds `texts` as one batch on a worker and returns one vector per text.
    ///
    /// An empty `texts` returns `Ok(vec![])` immediately without touching the
    /// queues. Otherwise the request is queued on the lane selected by `priority`;
    /// if that lane is full this call waits for a free slot.
    ///
    /// While a request is queued or being processed it is counted in the
    /// `trusty_embed_pool_utilisation` gauge. The count is released by a drop
    /// guard, so it stays accurate even when the returned future is dropped
    /// before completion (timeout, `select!`, task abort).
    ///
    /// # Errors
    ///
    /// * `"embed pool closed"` if the request could not be queued because the
    ///   lane's receiver is gone.
    /// * `"embed pool worker dropped reply"` if the worker went away without
    ///   answering.
    /// * Whatever error the worker reports for the batch.
    pub async fn embed(
        &self,
        texts: Vec<String>,
        priority: RequestPriority,
    ) -> Result<Vec<Vec<f32>>> {
        /// Counts one request in on creation and out again on drop, refreshing the
        /// utilisation gauge both times.
        struct InFlightGuard<'a>(&'a AtomicUsize);

        impl<'a> InFlightGuard<'a> {
            fn enter(counter: &'a AtomicUsize) -> Self {
                // Derive the new value from the atomic op itself rather than a
                // separate `load`, which could observe another caller's update.
                let now = counter.fetch_add(1, Ordering::Relaxed).saturating_add(1);
                metrics::gauge!("trusty_embed_pool_utilisation").set(now as f64);
                Self(counter)
            }
        }

        impl Drop for InFlightGuard<'_> {
            fn drop(&mut self) {
                let now = self.0.fetch_sub(1, Ordering::Relaxed).saturating_sub(1);
                metrics::gauge!("trusty_embed_pool_utilisation").set(now as f64);
            }
        }

        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let (reply_tx, reply_rx) = oneshot::channel();
        let req = EmbedRequest {
            texts,
            reply: reply_tx,
            priority,
        };
        let tx = match priority {
            RequestPriority::Interactive => &self.interactive_tx,
            RequestPriority::Background => &self.background_tx,
        };

        // Held across both awaits below; dropped on every exit path, including
        // cancellation of this future.
        let _in_flight = InFlightGuard::enter(&self.in_flight);

        tx.send(req).await.context("embed pool closed")?;
        match reply_rx.await {
            Ok(outcome) => outcome,
            Err(_) => Err(anyhow::anyhow!("embed pool worker dropped reply")),
        }
    }

    /// Number of worker tasks this pool was created with.
    pub fn workers(&self) -> usize {
        self.workers
    }
}
/// Body of one pool worker: repeatedly takes the next request, runs it through
/// `embedder`, and sends the outcome back on the request's reply channel.
///
/// The worker locks both lane receivers (interactive first, then background) and
/// waits on them with a `biased` select, so a ready interactive request is always
/// taken before a ready background one. Both locks are released before the batch
/// is embedded, which lets another worker start waiting for the next request.
///
/// The loop ends as soon as either lane yields `None`.
async fn worker_loop(
    worker_id: usize,
    interactive_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<EmbedRequest>>>,
    background_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<EmbedRequest>>>,
    embedder: Arc<dyn Embedder>,
) {
    loop {
        let next = {
            // Every worker acquires the locks in the same order, so two workers
            // can never hold one lock each while waiting for the other.
            let mut interactive_guard = interactive_rx.lock().await;
            let mut background_guard = background_rx.lock().await;
            tokio::select! {
                biased;
                msg = interactive_guard.recv() => msg,
                msg = background_guard.recv() => msg,
            }
        };

        let Some(EmbedRequest {
            texts,
            reply,
            priority,
        }) = next
        else {
            tracing::debug!(worker_id, "embed pool worker exiting (channels closed)");
            return;
        };

        // The requester may have given up while the request sat in the queue
        // (its future was dropped). Nobody can receive the result any more, so
        // skip the embedding work instead of computing and discarding it.
        if reply.is_closed() {
            tracing::trace!(
                worker_id,
                priority = ?priority,
                batch_size = texts.len(),
                "embed pool skipped abandoned batch"
            );
            continue;
        }

        let started = std::time::Instant::now();
        let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
        let result = embedder
            .embed_batch(&text_refs)
            .await
            .context("embed pool worker: embed_batch failed");
        let elapsed_ms = started.elapsed().as_millis() as u64;
        tracing::trace!(
            worker_id,
            priority = ?priority,
            batch_size = texts.len(),
            elapsed_ms,
            "embed pool dispatched batch"
        );

        // A send error only means the requester went away in the meantime.
        let _ = reply.send(result);
    }
}
pub fn autotune_workers() -> usize {
if let Ok(raw) = std::env::var("TRUSTY_EMBED_WORKERS") {
if let Ok(n) = raw.parse::<usize>() {
return n.max(1);
}
}
let ram_mb = crate::core::memory_policy::detect_total_ram_mb().unwrap_or(8 * 1024);
let ram_gb = ram_mb / 1024;
if ram_gb <= 16 {
1
} else if ram_gb <= 32 {
2
} else {
4
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::embed::MockEmbedder;
    use std::time::Duration;

    /// Builds a pool of `workers` workers on top of a `MockEmbedder::new(384)`.
    fn make_pool(workers: usize) -> EmbedPool {
        EmbedPool::new(workers, Arc::new(MockEmbedder::new(384)))
    }

    /// Two input texts produce two vectors, the first having 384 components.
    #[tokio::test]
    async fn embed_returns_vector_per_text() {
        let pool = make_pool(2);
        let texts = vec![String::from("hello"), String::from("world")];
        let vectors = pool
            .embed(texts, RequestPriority::Interactive)
            .await
            .expect("embed succeeds");
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0].len(), 384);
    }

    /// An empty batch succeeds and yields no vectors.
    #[tokio::test]
    async fn embed_handles_empty_input() {
        let pool = make_pool(1);
        let vectors = pool
            .embed(Vec::new(), RequestPriority::Background)
            .await
            .expect("empty embed is a no-op");
        assert!(vectors.is_empty());
    }

    /// The pool reports the worker count it was asked to create.
    #[tokio::test]
    async fn pool_creates_n_workers() {
        assert_eq!(make_pool(3).workers(), 3);
    }

    /// Without the env override, autotune yields one of the table values.
    #[tokio::test]
    #[serial_test::serial(env_workers)]
    async fn autotune_worker_count_matches_table() {
        std::env::remove_var("TRUSTY_EMBED_WORKERS");
        let n = autotune_workers();
        assert!(
            matches!(n, 1 | 2 | 4),
            "autotune returned unexpected count: {n}"
        );
    }

    /// A numeric `TRUSTY_EMBED_WORKERS` value takes precedence over autotune.
    #[tokio::test]
    #[serial_test::serial(env_workers)]
    async fn pool_autotune_respects_env_override() {
        std::env::set_var("TRUSTY_EMBED_WORKERS", "7");
        let chosen = autotune_workers();
        // Clean up before asserting so a failure cannot leak the override into
        // other tests.
        std::env::remove_var("TRUSTY_EMBED_WORKERS");
        assert_eq!(chosen, 7);
    }

    /// Both lanes serve a request on a single-worker pool.
    ///
    /// NOTE(review): the two requests are submitted one after the other, so this
    /// checks that each lane works rather than observing which lane drains first.
    #[tokio::test]
    async fn priority_ordering_interactive_drains_first() {
        let pool = make_pool(1);

        let from_interactive = pool
            .embed(vec![String::from("i")], RequestPriority::Interactive)
            .await
            .expect("interactive embed succeeds");
        let from_background = pool
            .embed(vec![String::from("b")], RequestPriority::Background)
            .await
            .expect("background embed succeeds");

        assert_eq!(from_interactive.len(), 1);
        assert_eq!(from_background.len(), 1);
    }

    /// Dropping the pool, then yielding for 50 ms, completes without panicking.
    #[tokio::test]
    async fn dropping_pool_shuts_workers_down() {
        drop(make_pool(1));
        tokio::time::sleep(Duration::from_millis(50)).await;
    }
}