agentic_workflow/engine/
batch.rs1use std::collections::HashMap;
2
3use chrono::Utc;
4use uuid::Uuid;
5
6use crate::types::{
7 BatchItem, BatchItemStatus, BatchJob, BatchProgress, BatchReport, BatchStatus,
8 WorkflowError, WorkflowResult,
9};
10
11pub struct BatchEngine {
13 jobs: HashMap<String, BatchJob>,
14}
15
16impl BatchEngine {
17 pub fn new() -> Self {
18 Self {
19 jobs: HashMap::new(),
20 }
21 }
22
23 pub fn create_batch(
25 &mut self,
26 workflow_id: &str,
27 items: Vec<serde_json::Value>,
28 concurrency: usize,
29 checkpoint_every: usize,
30 ) -> WorkflowResult<String> {
31 let id = Uuid::new_v4().to_string();
32 let batch_items: Vec<BatchItem> = items
33 .into_iter()
34 .enumerate()
35 .map(|(i, input)| BatchItem {
36 index: i,
37 input,
38 status: BatchItemStatus::Pending,
39 output: None,
40 error: None,
41 duration_ms: None,
42 })
43 .collect();
44
45 let job = BatchJob {
46 id: id.clone(),
47 workflow_id: workflow_id.to_string(),
48 items: batch_items,
49 concurrency: concurrency.max(1),
50 checkpoint_every: checkpoint_every.max(1),
51 status: BatchStatus::Pending,
52 created_at: Utc::now(),
53 started_at: None,
54 completed_at: None,
55 };
56
57 self.jobs.insert(id.clone(), job);
58 Ok(id)
59 }
60
61 pub fn get_progress(&self, batch_id: &str) -> WorkflowResult<BatchProgress> {
63 let job = self
64 .jobs
65 .get(batch_id)
66 .ok_or_else(|| WorkflowError::BatchError(format!("Not found: {}", batch_id)))?;
67
68 let total = job.items.len();
69 let completed = job.items.iter().filter(|i| i.status == BatchItemStatus::Success).count();
70 let failed = job.items.iter().filter(|i| i.status == BatchItemStatus::Failed).count();
71 let skipped = job.items.iter().filter(|i| i.status == BatchItemStatus::Skipped).count();
72 let running = job.items.iter().filter(|i| i.status == BatchItemStatus::Running).count();
73 let pending = job.items.iter().filter(|i| i.status == BatchItemStatus::Pending).count();
74
75 let percent = if total > 0 {
76 (completed as f64 / total as f64) * 100.0
77 } else {
78 0.0
79 };
80
81 let last_checkpoint = job
83 .items
84 .iter()
85 .filter(|i| i.status == BatchItemStatus::Success)
86 .map(|i| i.index)
87 .max();
88
89 Ok(BatchProgress {
90 batch_id: batch_id.to_string(),
91 total_items: total,
92 completed,
93 failed,
94 skipped,
95 running,
96 pending,
97 percent_complete: percent,
98 estimated_remaining_ms: None,
99 last_checkpoint_index: last_checkpoint,
100 })
101 }
102
103 pub fn get_report(&self, batch_id: &str) -> WorkflowResult<BatchReport> {
105 let job = self
106 .jobs
107 .get(batch_id)
108 .ok_or_else(|| WorkflowError::BatchError(format!("Not found: {}", batch_id)))?;
109
110 let success_count = job.items.iter().filter(|i| i.status == BatchItemStatus::Success).count();
111 let fail_count = job.items.iter().filter(|i| i.status == BatchItemStatus::Failed).count();
112 let skip_count = job.items.iter().filter(|i| i.status == BatchItemStatus::Skipped).count();
113
114 let total_duration: u64 = job
115 .items
116 .iter()
117 .filter_map(|i| i.duration_ms)
118 .sum();
119
120 let processed = success_count + fail_count;
121 let avg = if processed > 0 {
122 total_duration as f64 / processed as f64
123 } else {
124 0.0
125 };
126
127 let mut error_groups: HashMap<String, Vec<usize>> = HashMap::new();
129 for item in &job.items {
130 if let Some(err) = &item.error {
131 error_groups
132 .entry(err.clone())
133 .or_default()
134 .push(item.index);
135 }
136 }
137
138 let error_summary = error_groups
139 .into_iter()
140 .map(|(pattern, indices)| crate::types::batch::BatchErrorGroup {
141 error_pattern: pattern,
142 count: indices.len(),
143 sample_indices: indices.into_iter().take(5).collect(),
144 })
145 .collect();
146
147 Ok(BatchReport {
148 batch_id: batch_id.to_string(),
149 total_items: job.items.len(),
150 success_count,
151 fail_count,
152 skip_count,
153 total_duration_ms: total_duration,
154 avg_item_duration_ms: avg,
155 error_summary,
156 })
157 }
158
159 pub fn get_job(&self, batch_id: &str) -> WorkflowResult<&BatchJob> {
161 self.jobs
162 .get(batch_id)
163 .ok_or_else(|| WorkflowError::BatchError(format!("Not found: {}", batch_id)))
164 }
165}
166
167impl Default for BatchEngine {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created batch reports every item as pending and 0% complete.
    #[test]
    fn test_batch_creation() {
        let mut engine = BatchEngine::new();
        let items = vec![
            serde_json::json!({"id": 1}),
            serde_json::json!({"id": 2}),
            serde_json::json!({"id": 3}),
        ];

        let bid = engine.create_batch("wf-1", items, 2, 10).unwrap();
        let progress = engine.get_progress(&bid).unwrap();
        assert_eq!(progress.total_items, 3);
        assert_eq!(progress.pending, 3);
        assert_eq!(progress.percent_complete, 0.0);
    }

    /// Every lookup method rejects an unknown batch id with `BatchError`.
    #[test]
    fn test_unknown_batch_is_an_error() {
        let engine = BatchEngine::new();
        assert!(matches!(
            engine.get_progress("missing"),
            Err(WorkflowError::BatchError(_))
        ));
        assert!(matches!(
            engine.get_report("missing"),
            Err(WorkflowError::BatchError(_))
        ));
        assert!(matches!(
            engine.get_job("missing"),
            Err(WorkflowError::BatchError(_))
        ));
    }

    /// Progress and report counters reflect item statuses, durations and errors.
    #[test]
    fn test_progress_and_report_counts() {
        let mut engine = BatchEngine::new();
        let items = (0..4).map(|i| serde_json::json!({"id": i})).collect();
        let bid = engine.create_batch("wf-1", items, 2, 10).unwrap();

        {
            // The test module can reach the private job map to simulate execution.
            let job = engine.jobs.get_mut(&bid).expect("batch was just created");
            job.items[0].status = BatchItemStatus::Success;
            job.items[0].duration_ms = Some(10);
            job.items[1].status = BatchItemStatus::Failed;
            job.items[1].duration_ms = Some(30);
            job.items[1].error = Some("boom".to_string());
            job.items[2].status = BatchItemStatus::Skipped;
        }

        let progress = engine.get_progress(&bid).unwrap();
        assert_eq!(progress.total_items, 4);
        assert_eq!(progress.completed, 1);
        assert_eq!(progress.failed, 1);
        assert_eq!(progress.skipped, 1);
        assert_eq!(progress.running, 0);
        assert_eq!(progress.pending, 1);
        assert_eq!(progress.percent_complete, 25.0);
        assert_eq!(progress.last_checkpoint_index, Some(0));

        let report = engine.get_report(&bid).unwrap();
        assert_eq!(report.total_items, 4);
        assert_eq!(report.success_count, 1);
        assert_eq!(report.fail_count, 1);
        assert_eq!(report.skip_count, 1);
        assert_eq!(report.total_duration_ms, 40);
        assert_eq!(report.avg_item_duration_ms, 20.0);
        assert_eq!(report.error_summary.len(), 1);
        assert_eq!(report.error_summary[0].error_pattern, "boom");
        assert_eq!(report.error_summary[0].count, 1);
    }
}