// batuta/serve/banco/batch.rs
1//! Batch inference — process multiple prompts in a single request.
2//!
3//! Accepts a list of prompt items, processes each through the chat pipeline,
4//! and returns all results. Useful for dataset evaluation, bulk classification,
5//! and generating training data.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
/// A single item in a batch request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchItem {
    // Client-assigned identifier; `BatchItemResult` carries an `id` field,
    // presumably echoing this one — set by the caller-supplied processing fn.
    pub id: String,
    // Conversation to run through the chat pipeline for this item.
    pub messages: Vec<crate::serve::templates::ChatMessage>,
    // Generation cap; defaults to 256 when the field is omitted from the JSON.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}
19
/// Serde default for [`BatchItem::max_tokens`].
fn default_max_tokens() -> u32 {
    // 256 generated tokens per item unless the request says otherwise.
    const DEFAULT_MAX_TOKENS: u32 = 256;
    DEFAULT_MAX_TOKENS
}
23
/// Result for a single batch item.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchItemResult {
    // Identifier correlating this result with its `BatchItem`.
    pub id: String,
    // Generated text for the item.
    pub content: String,
    // Why generation stopped. Values are produced by the caller-supplied
    // processing function and are not constrained here.
    pub finish_reason: String,
    // Token count reported by the processing function.
    pub tokens: u32,
}
32
/// A batch job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchJob {
    // Unique id of the form `batch-{epoch_secs}-{seq}` (see `BatchStore::run`).
    pub batch_id: String,
    // Current lifecycle state of the job.
    pub status: BatchStatus,
    // Number of items submitted in the request.
    pub total_items: usize,
    // Number of items processed; `BatchStore::run` sets this to `results.len()`.
    pub completed_items: usize,
    // Per-item outputs, in the order the items were submitted.
    pub results: Vec<BatchItemResult>,
}
42
/// Batch job status.
///
/// Serialized in snake_case: `"processing"`, `"complete"`, `"failed"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum BatchStatus {
    Processing,
    Complete,
    // NOTE(review): nothing in this file constructs `Failed`; presumably set
    // by an async path elsewhere — confirm before removing.
    Failed,
}
51
/// Batch store — tracks batch jobs.
pub struct BatchStore {
    // Jobs keyed by `batch_id`; RwLock because reads (`get`/`list`) are
    // expected to dominate writes.
    jobs: RwLock<HashMap<String, BatchJob>>,
    // Monotonic sequence appended to ids so batches created within the same
    // second still get distinct `batch_id`s.
    counter: std::sync::atomic::AtomicU64,
}
57
58impl BatchStore {
59    #[must_use]
60    pub fn new() -> Arc<Self> {
61        Arc::new(Self {
62            jobs: RwLock::new(HashMap::new()),
63            counter: std::sync::atomic::AtomicU64::new(0),
64        })
65    }
66
67    /// Create and immediately run a batch job (synchronous for Phase 3).
68    pub fn run(
69        &self,
70        items: Vec<BatchItem>,
71        process_fn: impl Fn(&BatchItem) -> BatchItemResult,
72    ) -> BatchJob {
73        let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
74        let batch_id = format!("batch-{}-{seq}", epoch_secs());
75        let total = items.len();
76
77        let results: Vec<BatchItemResult> = items.iter().map(&process_fn).collect();
78
79        let job = BatchJob {
80            batch_id: batch_id.clone(),
81            status: BatchStatus::Complete,
82            total_items: total,
83            completed_items: results.len(),
84            results,
85        };
86
87        if let Ok(mut store) = self.jobs.write() {
88            store.insert(batch_id, job.clone());
89        }
90
91        job
92    }
93
94    /// Get a batch job by ID.
95    #[must_use]
96    pub fn get(&self, id: &str) -> Option<BatchJob> {
97        self.jobs.read().unwrap_or_else(|e| e.into_inner()).get(id).cloned()
98    }
99
100    /// List all batch jobs.
101    #[must_use]
102    pub fn list(&self) -> Vec<BatchJob> {
103        let store = self.jobs.read().unwrap_or_else(|e| e.into_inner());
104        store.values().cloned().collect()
105    }
106}
107
/// Seconds since the Unix epoch; 0 if the system clock is before the epoch.
fn epoch_secs() -> u64 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        // Clock set before 1970 — treat as the epoch itself.
        Err(_) => 0,
    }
}