use std::collections::VecDeque;
use std::sync::Arc;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use tokio::sync::mpsc;
use crate::BatchItem;
use crate::EmbeddingProvider;
use crate::batching::TokenAwareBatcher;
use crate::{Error, Result};
/// Result of embedding one batch: caller metadata paired with the embedding
/// vectors produced for it, index-aligned (`embeddings[i]` belongs to
/// `items[i]`).
pub struct EmbeddingResult<M> {
    /// Per-item metadata, in submission order.
    pub items: Vec<M>,
    /// One embedding vector per item, same order as `items`.
    pub embeddings: Vec<Vec<f32>>,
}
async fn embed_batch<M>(
client: &dyn EmbeddingProvider,
batch: Vec<BatchItem<M>>,
) -> Result<EmbeddingResult<M>> {
if batch.is_empty() {
return Ok(EmbeddingResult {
items: Vec::new(),
embeddings: Vec::new(),
});
}
let mut inputs = Vec::with_capacity(batch.len());
let mut items = Vec::with_capacity(batch.len());
for item in batch {
inputs.push(item.input);
items.push(item.meta);
}
let output = client.embed(&inputs).await?;
if output.embeddings.len() != items.len() {
return Err(Error::EmbeddingCountMismatch {
embeddings: output.embeddings.len(),
inputs: items.len(),
});
}
Ok(EmbeddingResult {
items,
embeddings: output.embeddings,
})
}
/// Front end of the embedding pipeline: accumulates items into token-aware
/// batches and hands completed batches to background worker tasks.
pub struct EmbedderService<M> {
    /// Groups incoming items into batches bounded by token count and size.
    batcher: TokenAwareBatcher<M>,
    /// Intake side of the pipeline; `None` once `flush` has closed intake.
    batch_tx: Option<mpsc::Sender<Vec<BatchItem<M>>>>,
}
impl<M: Send + 'static> EmbedderService<M> {
pub fn new(
embedder: Arc<dyn EmbeddingProvider>,
max_tokens: usize,
max_batch_size: usize,
workers: usize,
) -> Result<(Self, mpsc::Receiver<Result<EmbeddingResult<M>>>)> {
if workers == 0 {
return Err(Error::InvalidConfiguration {
message: "workers must be greater than zero".to_string(),
});
}
let worker_count = workers;
let batcher = TokenAwareBatcher::new(max_tokens, max_batch_size)?;
let (batch_tx, mut prepared_rx) = mpsc::channel::<Vec<BatchItem<M>>>(worker_count * 2);
let (execution_tx, mut execution_rx) = mpsc::channel::<Vec<BatchItem<M>>>(worker_count * 2);
let (result_tx, result_rx) = mpsc::channel::<Result<EmbeddingResult<M>>>(worker_count * 2);
tokio::spawn(async move {
while let Some(batch) = prepared_rx.recv().await {
if execution_tx.send(batch).await.is_err() {
break;
}
}
});
tokio::spawn(async move {
let mut in_flight: FuturesUnordered<_> = FuturesUnordered::new();
let mut pending: VecDeque<Vec<BatchItem<M>>> = VecDeque::new();
let mut execution_closed = false;
let pending_limit = worker_count * 2;
loop {
while in_flight.len() < worker_count {
let Some(batch) = pending.pop_front() else {
break;
};
let embedder = Arc::clone(&embedder);
in_flight.push(async move { embed_batch(embedder.as_ref(), batch).await });
}
if execution_closed && pending.is_empty() && in_flight.is_empty() {
break;
}
tokio::select! {
biased;
Some(result) = in_flight.next(), if !in_flight.is_empty() => {
if result_tx.send(result).await.is_err() {
break;
}
}
batch = execution_rx.recv(), if !execution_closed && pending.len() < pending_limit => {
match batch {
Some(batch) => pending.push_back(batch),
None => execution_closed = true,
}
}
else => {
}
}
}
});
Ok((
Self {
batcher,
batch_tx: Some(batch_tx),
},
result_rx,
))
}
pub async fn enqueue(&mut self, item: BatchItem<M>) -> Result<bool> {
if let Some(batch) = self.batcher.add(item)
&& let Some(tx) = self.batch_tx.as_ref()
{
tx.send(batch)
.await
.map_err(|_| Error::BatchChannelClosed)?;
return Ok(true);
}
Ok(false)
}
pub async fn flush(&mut self) -> Result<bool> {
if self.batch_tx.is_none() {
return Ok(false);
}
let mut sent = false;
if let Some(batch) = self.batcher.flush()
&& let Some(tx) = self.batch_tx.as_ref()
{
tx.send(batch)
.await
.map_err(|_| Error::BatchChannelClosed)?;
sent = true;
}
drop(self.batch_tx.take());
Ok(sent)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::Duration;
    use async_trait::async_trait;
    use tokio::sync::Notify;
    use tokio::time::timeout;
    use crate::{EmbedOutput, PreparedEmbeddingInput};

    /// Provider that answers immediately with a fixed 2-dim zero vector per
    /// input.
    struct NoopProvider;
    #[async_trait]
    impl EmbeddingProvider for NoopProvider {
        async fn embed(&self, input: &[PreparedEmbeddingInput]) -> Result<EmbedOutput> {
            Ok(EmbedOutput {
                embeddings: input.iter().map(|_| vec![0.0; 2]).collect(),
            })
        }
    }

    /// Provider that parks every `embed` call until `gate` is notified,
    /// simulating a stalled backend so backpressure can be observed.
    struct BlockingProvider {
        gate: Arc<Notify>,
    }
    #[async_trait]
    impl EmbeddingProvider for BlockingProvider {
        async fn embed(&self, input: &[PreparedEmbeddingInput]) -> Result<EmbedOutput> {
            // Block here until the test releases the gate.
            self.gate.notified().await;
            Ok(EmbedOutput {
                embeddings: input.iter().map(|_| vec![0.0; 2]).collect(),
            })
        }
    }

    /// Build a prepared input of `count` elements (presumably one token per
    /// element — confirm against `PreparedEmbeddingInput`'s contract).
    fn prepared(count: usize) -> PreparedEmbeddingInput {
        PreparedEmbeddingInput::new(vec![1; count]).unwrap()
    }

    #[test]
    fn service_rejects_zero_workers() {
        let result = EmbedderService::<()>::new(Arc::new(NoopProvider), 64, 8, 0);
        assert!(matches!(result, Err(Error::InvalidConfiguration { .. })));
    }

    #[test]
    fn service_rejects_zero_batch_limits() {
        // max_tokens == 0 is expected to be rejected by TokenAwareBatcher::new.
        let result = EmbedderService::<()>::new(Arc::new(NoopProvider), 0, 8, 1);
        assert!(matches!(result, Err(Error::InvalidConfiguration { .. })));
    }

    #[tokio::test]
    async fn intake_can_progress_when_one_batch_is_in_flight() {
        // Single worker, max_batch_size = 1, and a provider that never
        // completes: intake should keep accepting items until the bounded
        // channels fill, then backpressure.
        let gate = Arc::new(Notify::new());
        let provider = Arc::new(BlockingProvider { gate: gate.clone() });
        let (mut service, _rx) = EmbedderService::new(provider, 10_000, 1, 1).unwrap();
        // First item is only buffered (NOTE(review): the batcher apparently
        // emits a batch on the add that exceeds the limit, not on the add
        // that reaches it — confirm against TokenAwareBatcher).
        let first = service
            .enqueue(BatchItem {
                meta: 1usize,
                input: prepared(1),
            })
            .await
            .unwrap();
        assert!(!first);
        // Items 2..=6 each flush a batch; with capacities of worker_count*2
        // on both channels plus one in-flight slot, these sends must not
        // block even though execution is stalled.
        for n in 2usize..=6usize {
            let flushed = timeout(
                Duration::from_millis(75),
                service.enqueue(BatchItem {
                    meta: n,
                    input: prepared(1),
                }),
            )
            .await
            .expect("enqueue should remain non-blocking while queue capacity remains")
            .unwrap();
            assert!(flushed);
        }
        // Keep enqueuing until a send times out: that timeout is the
        // backpressure signal we are asserting on.
        let mut observed_backpressure = false;
        for n in 7usize..=20usize {
            let pushed = timeout(
                Duration::from_millis(75),
                service.enqueue(BatchItem {
                    meta: n,
                    input: prepared(1),
                }),
            )
            .await;
            if pushed.is_err() {
                observed_backpressure = true;
                break;
            }
        }
        assert!(
            observed_backpressure,
            "enqueue should eventually backpressure when execution is stalled"
        );
        // Release the stalled provider so spawned tasks can finish cleanly.
        gate.notify_waiters();
    }
}