use std::sync::Arc;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;
use crate::inference::{InferenceRequest, InferenceResponse};
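/// Lifecycle state of a submitted batch.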
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum BatchStatus {
Running,
Completed,
Cancelled,
}
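/// Outcome of a single request within a batch: at most one of `response` and
/// `error` is set; both are `None` while the item is pending.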
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchItemResult {
pub index: usize,
pub response: Option<InferenceResponse>,
pub error: Option<String>,
}
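/// Point-in-time snapshot of a batch. `completed` counts every processed
/// item, successes and failures alike, so `failed <= completed <= total`.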
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchProgress {
pub id: String,
pub total: usize,
pub completed: usize,
pub failed: usize,
pub status: BatchStatus,
pub results: Vec<BatchItemResult>,
}
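/// Internal per-batch bookkeeping. `progress` is an async-aware tokio `Mutex`
/// so tasks can wait for it without blocking the runtime; `completed_at` is
/// only touched in short synchronous sections, so a std `Mutex` suffices.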
struct BatchState {
progress: tokio::sync::Mutex<BatchProgress>,
cancel: CancellationToken,
completed_at: std::sync::Mutex<Option<std::time::Instant>>,
}
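/// Tracks in-flight and recently finished inference batches. A single shared
/// semaphore caps the number of concurrent requests across all batches.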
pub struct BatchManager {
semaphore: Arc<Semaphore>,
batches: DashMap<String, Arc<BatchState>>,
}
impl BatchManager {
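    /// Creates a manager that runs at most `max_concurrent` requests at once
    /// across all batches.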
#[must_use]
pub fn new(max_concurrent: usize) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(max_concurrent)),
batches: DashMap::new(),
}
}
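    /// Submits a batch for background execution and returns its id.
    ///
    /// Every request is spawned as its own task, gated by the shared
    /// semaphore; `infer_fn` is called once per request. Items that have not
    /// started when the batch is cancelled are skipped, while items already
    /// running are allowed to finish. If `batch_id` is already present, its
    /// entry is replaced.
    ///
    /// # Example
    ///
    /// An illustrative sketch, not compiled as a doctest; it assumes an async
    /// context and a hypothetical `infer` helper that is not part of this
    /// module:
    ///
    /// ```ignore
    /// let manager = BatchManager::new(8);
    /// let id = manager.submit("batch-1".into(), requests, |req| async move {
    ///     infer(req).await // hypothetical inference call
    /// });
    /// let progress = manager.get_progress(&id).await;
    /// ```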
pub fn submit<F, Fut>(
&self,
batch_id: String,
requests: Vec<InferenceRequest>,
infer_fn: F,
) -> String
where
F: Fn(InferenceRequest) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = anyhow::Result<InferenceResponse>> + Send + 'static,
{
let total = requests.len();
let progress = BatchProgress {
id: batch_id.clone(),
total,
completed: 0,
failed: 0,
status: BatchStatus::Running,
results: (0..total)
.map(|i| BatchItemResult {
index: i,
response: None,
error: None,
})
.collect(),
};
let state = Arc::new(BatchState {
progress: tokio::sync::Mutex::new(progress),
cancel: CancellationToken::new(),
completed_at: std::sync::Mutex::new(None),
});
self.batches.insert(batch_id.clone(), state.clone());
tracing::info!(batch_id = %batch_id, total, "batch submitted");
let semaphore = self.semaphore.clone();
let infer_fn = Arc::new(infer_fn);
tokio::spawn(async move {
let mut handles = Vec::with_capacity(total);
for (index, request) in requests.into_iter().enumerate() {
let sem = semaphore.clone();
let st = state.clone();
let f = infer_fn.clone();
let handle = tokio::spawn(async move {
                    // Skip entirely if the batch was cancelled while this task was queued.
                    if st.cancel.is_cancelled() {
                        return;
                    }
                    let _permit = match sem.acquire().await {
                        Ok(p) => p,
                        // The semaphore is never closed here; bail out defensively.
                        Err(_) => return,
                    };
                    // Re-check: the wait for a permit may have been long.
                    if st.cancel.is_cancelled() {
                        return;
                    }
let result = f(request).await;
let mut prog = st.progress.lock().await;
match result {
Ok(response) => {
prog.results[index].response = Some(response);
}
Err(e) => {
prog.results[index].error = Some(e.to_string());
prog.failed += 1;
}
}
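                    // `completed` counts every processed item, success or failure alike.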
prog.completed += 1;
});
handles.push(handle);
}
for handle in handles {
let _ = handle.await;
}
let mut prog = state.progress.lock().await;
if state.cancel.is_cancelled() {
prog.status = BatchStatus::Cancelled;
} else {
prog.status = BatchStatus::Completed;
}
drop(prog);
if let Ok(mut ts) = state.completed_at.lock() {
*ts = Some(std::time::Instant::now());
}
});
batch_id
}
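    /// Returns a snapshot of a batch's progress, or `None` for an unknown id.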
pub async fn get_progress(&self, batch_id: &str) -> Option<BatchProgress> {
let state = self.batches.get(batch_id)?;
let prog = state.progress.lock().await;
Some(prog.clone())
}
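    /// Requests cancellation of a batch; returns `false` for an unknown id.
    /// Requests already in flight are not interrupted; only items that have
    /// not yet started are skipped.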
pub fn cancel(&self, batch_id: &str) -> bool {
if let Some(state) = self.batches.get(batch_id) {
state.cancel.cancel();
true
} else {
false
}
}
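    /// Removes a batch regardless of its status; returns `false` for an
    /// unknown id.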
pub fn remove(&self, batch_id: &str) -> bool {
self.batches.remove(batch_id).is_some()
}
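    /// Removes batches that finished (completed or cancelled) more than
    /// `max_age` ago and returns how many were evicted. Running batches are
    /// never evicted, since their `completed_at` timestamp is still unset.
    ///
    /// A sketch of periodic use (not compiled; assumes a surrounding task
    /// holding a `manager`):
    ///
    /// ```ignore
    /// loop {
    ///     tokio::time::sleep(std::time::Duration::from_secs(60)).await;
    ///     manager.evict_completed(std::time::Duration::from_secs(600));
    /// }
    /// ```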
pub fn evict_completed(&self, max_age: std::time::Duration) -> usize {
        let now = std::time::Instant::now();
let keys: Vec<String> = self
.batches
.iter()
.filter_map(|entry| {
let state = entry.value();
if let Ok(ts) = state.completed_at.lock()
&& let Some(completed) = *ts
&& now.duration_since(completed) > max_age
{
return Some(entry.key().clone());
}
None
})
.collect();
        for key in &keys {
            self.batches.remove(key);
        }
        keys.len()
}
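    /// Number of batches currently tracked, including finished batches that
    /// have not yet been removed or evicted.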
#[must_use]
pub fn active_count(&self) -> usize {
self.batches.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::inference::TokenUsage;
fn make_request(model: &str) -> InferenceRequest {
InferenceRequest {
model: model.into(),
prompt: "test".into(),
..Default::default()
}
}
fn make_response(model: &str) -> InferenceResponse {
InferenceResponse {
text: "response".into(),
model: model.into(),
usage: TokenUsage::default(),
tool_calls: Vec::new(),
provider: "test".into(),
latency_ms: 1,
}
}
#[test]
fn batch_manager_creation() {
let mgr = BatchManager::new(10);
assert_eq!(mgr.active_count(), 0);
}
#[tokio::test]
async fn batch_submit_and_complete() {
let mgr = BatchManager::new(4);
let requests = vec![make_request("model1"), make_request("model2")];
let batch_id = mgr.submit("batch-1".into(), requests, |req| async move {
Ok(make_response(&req.model))
});
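        // Crude synchronization: give the spawned batch tasks time to finish.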
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let progress = mgr.get_progress(&batch_id).await.unwrap();
assert_eq!(progress.total, 2);
assert_eq!(progress.completed, 2);
assert_eq!(progress.failed, 0);
assert_eq!(progress.status, BatchStatus::Completed);
}
#[tokio::test]
async fn batch_with_failures() {
let mgr = BatchManager::new(4);
let requests = vec![make_request("ok"), make_request("fail")];
mgr.submit("batch-2".into(), requests, |req| async move {
if req.model == "fail" {
Err(anyhow::anyhow!("simulated failure"))
} else {
Ok(make_response(&req.model))
}
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let progress = mgr.get_progress("batch-2").await.unwrap();
assert_eq!(progress.completed, 2);
assert_eq!(progress.failed, 1);
}
#[tokio::test]
async fn batch_cancel() {
        let mgr = BatchManager::new(1);
        let requests = vec![make_request("a"), make_request("b"), make_request("c")];
mgr.submit("batch-3".into(), requests, |_req| async {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
Ok(make_response("x"))
});
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
assert!(mgr.cancel("batch-3"));
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
let progress = mgr.get_progress("batch-3").await.unwrap();
assert_eq!(progress.status, BatchStatus::Cancelled);
}
#[test]
fn batch_remove() {
let mgr = BatchManager::new(4);
assert!(!mgr.remove("nonexistent"));
}
#[tokio::test]
async fn batch_nonexistent_progress() {
let mgr = BatchManager::new(4);
assert!(mgr.get_progress("nope").await.is_none());
}
#[test]
fn batch_cancel_nonexistent() {
let mgr = BatchManager::new(4);
assert!(!mgr.cancel("nonexistent"));
}
#[tokio::test]
async fn batch_remove_completed() {
let mgr = BatchManager::new(4);
let requests = vec![make_request("model1")];
mgr.submit("batch-rm".into(), requests, |req| async move {
Ok(make_response(&req.model))
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert_eq!(mgr.active_count(), 1);
assert!(mgr.remove("batch-rm"));
assert_eq!(mgr.active_count(), 0);
}
#[tokio::test]
async fn batch_evict_completed_none_old_enough() {
let mgr = BatchManager::new(4);
let requests = vec![make_request("model1")];
mgr.submit("batch-ev".into(), requests, |req| async move {
Ok(make_response(&req.model))
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let evicted = mgr.evict_completed(std::time::Duration::from_secs(3600));
assert_eq!(evicted, 0);
assert_eq!(mgr.active_count(), 1);
}
#[tokio::test]
async fn batch_evict_completed_old_entries() {
let mgr = BatchManager::new(4);
let requests = vec![make_request("model1")];
mgr.submit("batch-old".into(), requests, |req| async move {
Ok(make_response(&req.model))
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let evicted = mgr.evict_completed(std::time::Duration::ZERO);
assert_eq!(evicted, 1);
assert_eq!(mgr.active_count(), 0);
}
#[tokio::test]
async fn batch_evict_running_not_evicted() {
let mgr = BatchManager::new(1);
let requests = vec![make_request("slow")];
mgr.submit("batch-running".into(), requests, |_req| async {
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
Ok(make_response("slow"))
});
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
let evicted = mgr.evict_completed(std::time::Duration::ZERO);
assert_eq!(evicted, 0);
mgr.cancel("batch-running");
}
#[tokio::test]
async fn batch_progress_shows_individual_results() {
let mgr = BatchManager::new(4);
let requests = vec![make_request("ok"), make_request("fail")];
mgr.submit("batch-results".into(), requests, |req| async move {
if req.model == "fail" {
Err(anyhow::anyhow!("oops"))
} else {
Ok(make_response(&req.model))
}
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let progress = mgr.get_progress("batch-results").await.unwrap();
assert_eq!(progress.results.len(), 2);
let has_response = progress.results.iter().any(|r| r.response.is_some());
let has_error = progress.results.iter().any(|r| r.error.is_some());
assert!(has_response);
assert!(has_error);
}
#[test]
fn batch_status_serde_roundtrip() {
let statuses = [
BatchStatus::Running,
BatchStatus::Completed,
BatchStatus::Cancelled,
];
for status in &statuses {
let json = serde_json::to_string(status).unwrap();
let back: BatchStatus = serde_json::from_str(&json).unwrap();
assert_eq!(*status, back);
}
}
}