use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use crate::RerankResult;
use crate::engine::Scorer;
pub struct BatchRequest {
pub query: String,
pub documents: Vec<String>,
pub top_k: Option<usize>,
pub response_tx: oneshot::Sender<crate::Result<Vec<RerankResult>>>,
}
pub struct DynamicBatcher {
max_batch_size: usize,
max_wait: Duration,
request_tx: mpsc::Sender<BatchRequest>,
}
impl DynamicBatcher {
pub fn new(max_batch_size: usize, max_wait: Duration, scorer: Box<dyn Scorer>) -> Self {
let (request_tx, request_rx) = mpsc::channel::<BatchRequest>(1024);
let (gpu_tx, gpu_rx) = std::sync::mpsc::channel::<Vec<BatchRequest>>();
std::thread::spawn(move || {
gpu_thread_loop(scorer, gpu_rx);
});
tokio::spawn(Self::batcher_loop(
request_rx,
gpu_tx,
max_batch_size,
max_wait,
));
Self {
max_batch_size,
max_wait,
request_tx,
}
}
pub fn with_gpu_sender(
max_batch_size: usize,
max_wait: Duration,
gpu_tx: std::sync::mpsc::Sender<Vec<BatchRequest>>,
) -> Self {
let (request_tx, request_rx) = mpsc::channel::<BatchRequest>(1024);
tokio::spawn(Self::batcher_loop(
request_rx,
gpu_tx,
max_batch_size,
max_wait,
));
Self {
max_batch_size,
max_wait,
request_tx,
}
}
pub async fn submit(
&self,
query: String,
documents: Vec<String>,
top_k: Option<usize>,
) -> crate::Result<Vec<RerankResult>> {
let (response_tx, response_rx) = oneshot::channel();
let req = BatchRequest {
query,
documents,
top_k,
response_tx,
};
self.request_tx
.send(req)
.await
.map_err(|_| crate::Error::Inference("Batcher channel closed".into()))?;
response_rx
.await
.map_err(|_| crate::Error::Inference("Response channel closed".into()))?
}
pub fn max_batch_size(&self) -> usize {
self.max_batch_size
}
pub fn max_wait(&self) -> Duration {
self.max_wait
}
async fn batcher_loop(
mut request_rx: mpsc::Receiver<BatchRequest>,
gpu_tx: std::sync::mpsc::Sender<Vec<BatchRequest>>,
max_batch_size: usize,
max_wait: Duration,
) {
let mut batch: Vec<BatchRequest> = Vec::with_capacity(max_batch_size);
loop {
match request_rx.recv().await {
Some(req) => batch.push(req),
None => break, }
let deadline = tokio::time::Instant::now() + max_wait;
while batch.len() < max_batch_size {
tokio::select! {
biased;
req = request_rx.recv() => {
match req {
Some(r) => batch.push(r),
None => break,
}
}
_ = tokio::time::sleep_until(deadline) => break,
}
}
let ready_batch: Vec<BatchRequest> = batch.drain(..).collect();
if gpu_tx.send(ready_batch).is_err() {
tracing::error!("GPU thread channel closed, batcher shutting down");
break;
}
}
tracing::info!("Batcher loop exiting");
}
}
fn gpu_thread_loop(
scorer: Box<dyn Scorer>,
batch_rx: std::sync::mpsc::Receiver<Vec<BatchRequest>>,
) {
tracing::info!("GPU thread started");
while let Ok(batch) = batch_rx.recv() {
let batch_size = batch.len();
tracing::debug!(batch_size, "Processing batch");
for req in batch {
let result = scorer.score(&req.query, &req.documents).map(|mut results| {
if let Some(top_k) = req.top_k {
results.truncate(top_k);
}
results
});
let _ = req.response_tx.send(result);
}
}
tracing::info!("GPU thread exiting");
}
#[cfg(test)]
mod tests {
use super::*;
struct MockScorer;
impl Scorer for MockScorer {
fn score(&self, _query: &str, documents: &[String]) -> crate::Result<Vec<RerankResult>> {
let mut results: Vec<RerankResult> = documents
.iter()
.enumerate()
.map(|(i, _)| RerankResult {
index: i,
score: 1.0 / (i as f32 + 1.0),
document: None,
})
.collect();
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(results)
}
}
#[tokio::test]
async fn test_batcher_single_request() {
let batcher = DynamicBatcher::new(8, Duration::from_millis(50), Box::new(MockScorer));
let results = batcher
.submit(
"test query".to_string(),
vec!["doc1".to_string(), "doc2".to_string()],
None,
)
.await
.unwrap();
assert_eq!(results.len(), 2);
assert!(results[0].score >= results[1].score);
}
#[tokio::test]
async fn test_batcher_concurrent_requests() {
let batcher = std::sync::Arc::new(DynamicBatcher::new(
16,
Duration::from_millis(20),
Box::new(MockScorer),
));
let mut handles = Vec::new();
for i in 0..20 {
let b = batcher.clone();
handles.push(tokio::spawn(async move {
b.submit(
format!("query {i}"),
vec![format!("doc_{i}_a"), format!("doc_{i}_b")],
None,
)
.await
}));
}
for handle in handles {
let result = handle.await.unwrap().unwrap();
assert_eq!(result.len(), 2);
}
}
#[tokio::test]
async fn test_batcher_top_k() {
let batcher = DynamicBatcher::new(8, Duration::from_millis(50), Box::new(MockScorer));
let results = batcher
.submit(
"test".to_string(),
vec!["a".to_string(), "b".to_string(), "c".to_string()],
Some(1),
)
.await
.unwrap();
assert_eq!(results.len(), 1);
}
#[test]
fn test_batcher_accessors() {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let batcher = DynamicBatcher::new(16, Duration::from_millis(100), Box::new(MockScorer));
assert_eq!(batcher.max_batch_size(), 16);
assert_eq!(batcher.max_wait(), Duration::from_millis(100));
});
}
#[tokio::test]
async fn test_batcher_timeout_dispatches_partial_batch() {
let batcher = std::sync::Arc::new(DynamicBatcher::new(
100,
Duration::from_millis(20),
Box::new(MockScorer),
));
let b = batcher.clone();
let h1 = tokio::spawn(async move {
b.submit("q1".to_string(), vec!["d1".to_string()], None)
.await
});
let b = batcher.clone();
let h2 = tokio::spawn(async move {
b.submit("q2".to_string(), vec!["d2".to_string()], None)
.await
});
let r1 = h1.await.unwrap().unwrap();
let r2 = h2.await.unwrap().unwrap();
assert_eq!(r1.len(), 1);
assert_eq!(r2.len(), 1);
}
#[tokio::test]
async fn test_batcher_full_dispatch() {
let batcher = std::sync::Arc::new(DynamicBatcher::new(
2,
Duration::from_secs(60), Box::new(MockScorer),
));
let b = batcher.clone();
let h1 =
tokio::spawn(
async move { b.submit("q".to_string(), vec!["a".to_string()], None).await },
);
let b = batcher.clone();
let h2 =
tokio::spawn(
async move { b.submit("q".to_string(), vec!["b".to_string()], None).await },
);
let (r1, r2) = tokio::join!(h1, h2);
assert!(r1.unwrap().is_ok());
assert!(r2.unwrap().is_ok());
}
#[tokio::test]
async fn test_batcher_top_k_zero() {
let batcher = DynamicBatcher::new(8, Duration::from_millis(50), Box::new(MockScorer));
let results = batcher
.submit(
"test".to_string(),
vec!["a".to_string(), "b".to_string()],
Some(0),
)
.await
.unwrap();
assert!(results.is_empty(), "top_k=0 should return empty results");
}
#[tokio::test]
async fn test_batcher_empty_documents() {
let batcher = DynamicBatcher::new(8, Duration::from_millis(50), Box::new(MockScorer));
let results = batcher
.submit("query".to_string(), vec![], None)
.await
.unwrap();
assert!(
results.is_empty(),
"Empty documents should return empty results"
);
}
#[tokio::test]
async fn test_batcher_dropped_sender() {
let (gpu_tx, gpu_rx) = std::sync::mpsc::channel::<Vec<BatchRequest>>();
let batcher = DynamicBatcher::with_gpu_sender(8, Duration::from_millis(20), gpu_tx);
drop(gpu_rx);
tokio::time::sleep(Duration::from_millis(50)).await;
let result = batcher
.submit("q".to_string(), vec!["d".to_string()], None)
.await;
let _ = result; }
}