batuta/serve/banco/
batch.rs1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct BatchItem {
14 pub id: String,
15 pub messages: Vec<crate::serve::templates::ChatMessage>,
16 #[serde(default = "default_max_tokens")]
17 pub max_tokens: u32,
18}
19
/// Serde fallback for `BatchItem::max_tokens` when the field is omitted.
fn default_max_tokens() -> u32 {
    const DEFAULT_MAX_TOKENS: u32 = 256;
    DEFAULT_MAX_TOKENS
}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct BatchItemResult {
27 pub id: String,
28 pub content: String,
29 pub finish_reason: String,
30 pub tokens: u32,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct BatchJob {
36 pub batch_id: String,
37 pub status: BatchStatus,
38 pub total_items: usize,
39 pub completed_items: usize,
40 pub results: Vec<BatchItemResult>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
45#[serde(rename_all = "snake_case")]
46pub enum BatchStatus {
47 Processing,
48 Complete,
49 Failed,
50}
51
52pub struct BatchStore {
54 jobs: RwLock<HashMap<String, BatchJob>>,
55 counter: std::sync::atomic::AtomicU64,
56}
57
58impl BatchStore {
59 #[must_use]
60 pub fn new() -> Arc<Self> {
61 Arc::new(Self {
62 jobs: RwLock::new(HashMap::new()),
63 counter: std::sync::atomic::AtomicU64::new(0),
64 })
65 }
66
67 pub fn run(
69 &self,
70 items: Vec<BatchItem>,
71 process_fn: impl Fn(&BatchItem) -> BatchItemResult,
72 ) -> BatchJob {
73 let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
74 let batch_id = format!("batch-{}-{seq}", epoch_secs());
75 let total = items.len();
76
77 let results: Vec<BatchItemResult> = items.iter().map(&process_fn).collect();
78
79 let job = BatchJob {
80 batch_id: batch_id.clone(),
81 status: BatchStatus::Complete,
82 total_items: total,
83 completed_items: results.len(),
84 results,
85 };
86
87 if let Ok(mut store) = self.jobs.write() {
88 store.insert(batch_id, job.clone());
89 }
90
91 job
92 }
93
94 #[must_use]
96 pub fn get(&self, id: &str) -> Option<BatchJob> {
97 self.jobs.read().unwrap_or_else(|e| e.into_inner()).get(id).cloned()
98 }
99
100 #[must_use]
102 pub fn list(&self) -> Vec<BatchJob> {
103 let store = self.jobs.read().unwrap_or_else(|e| e.into_inner());
104 store.values().cloned().collect()
105 }
106}
107
/// Whole seconds elapsed since the Unix epoch according to the system clock.
///
/// Yields `0` if the clock reports a time earlier than the epoch.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}