Skip to main content

agentic_workflow/engine/batch.rs

1use std::collections::HashMap;
2
3use chrono::Utc;
4use uuid::Uuid;
5
6use crate::types::{
7    BatchItem, BatchItemStatus, BatchJob, BatchProgress, BatchReport, BatchStatus,
8    WorkflowError, WorkflowResult,
9};
10
11/// Batch processing engine with controlled parallelism.
12pub struct BatchEngine {
13    jobs: HashMap<String, BatchJob>,
14}
15
16impl BatchEngine {
17    pub fn new() -> Self {
18        Self {
19            jobs: HashMap::new(),
20        }
21    }
22
23    /// Create a batch job from a list of items.
24    pub fn create_batch(
25        &mut self,
26        workflow_id: &str,
27        items: Vec<serde_json::Value>,
28        concurrency: usize,
29        checkpoint_every: usize,
30    ) -> WorkflowResult<String> {
31        let id = Uuid::new_v4().to_string();
32        let batch_items: Vec<BatchItem> = items
33            .into_iter()
34            .enumerate()
35            .map(|(i, input)| BatchItem {
36                index: i,
37                input,
38                status: BatchItemStatus::Pending,
39                output: None,
40                error: None,
41                duration_ms: None,
42            })
43            .collect();
44
45        let job = BatchJob {
46            id: id.clone(),
47            workflow_id: workflow_id.to_string(),
48            items: batch_items,
49            concurrency: concurrency.max(1),
50            checkpoint_every: checkpoint_every.max(1),
51            status: BatchStatus::Pending,
52            created_at: Utc::now(),
53            started_at: None,
54            completed_at: None,
55        };
56
57        self.jobs.insert(id.clone(), job);
58        Ok(id)
59    }
60
61    /// Get batch progress.
62    pub fn get_progress(&self, batch_id: &str) -> WorkflowResult<BatchProgress> {
63        let job = self
64            .jobs
65            .get(batch_id)
66            .ok_or_else(|| WorkflowError::BatchError(format!("Not found: {}", batch_id)))?;
67
68        let total = job.items.len();
69        let completed = job.items.iter().filter(|i| i.status == BatchItemStatus::Success).count();
70        let failed = job.items.iter().filter(|i| i.status == BatchItemStatus::Failed).count();
71        let skipped = job.items.iter().filter(|i| i.status == BatchItemStatus::Skipped).count();
72        let running = job.items.iter().filter(|i| i.status == BatchItemStatus::Running).count();
73        let pending = job.items.iter().filter(|i| i.status == BatchItemStatus::Pending).count();
74
75        let percent = if total > 0 {
76            (completed as f64 / total as f64) * 100.0
77        } else {
78            0.0
79        };
80
81        // Find last checkpoint
82        let last_checkpoint = job
83            .items
84            .iter()
85            .filter(|i| i.status == BatchItemStatus::Success)
86            .map(|i| i.index)
87            .max();
88
89        Ok(BatchProgress {
90            batch_id: batch_id.to_string(),
91            total_items: total,
92            completed,
93            failed,
94            skipped,
95            running,
96            pending,
97            percent_complete: percent,
98            estimated_remaining_ms: None,
99            last_checkpoint_index: last_checkpoint,
100        })
101    }
102
103    /// Generate batch completion report.
104    pub fn get_report(&self, batch_id: &str) -> WorkflowResult<BatchReport> {
105        let job = self
106            .jobs
107            .get(batch_id)
108            .ok_or_else(|| WorkflowError::BatchError(format!("Not found: {}", batch_id)))?;
109
110        let success_count = job.items.iter().filter(|i| i.status == BatchItemStatus::Success).count();
111        let fail_count = job.items.iter().filter(|i| i.status == BatchItemStatus::Failed).count();
112        let skip_count = job.items.iter().filter(|i| i.status == BatchItemStatus::Skipped).count();
113
114        let total_duration: u64 = job
115            .items
116            .iter()
117            .filter_map(|i| i.duration_ms)
118            .sum();
119
120        let processed = success_count + fail_count;
121        let avg = if processed > 0 {
122            total_duration as f64 / processed as f64
123        } else {
124            0.0
125        };
126
127        // Group errors by pattern
128        let mut error_groups: HashMap<String, Vec<usize>> = HashMap::new();
129        for item in &job.items {
130            if let Some(err) = &item.error {
131                error_groups
132                    .entry(err.clone())
133                    .or_default()
134                    .push(item.index);
135            }
136        }
137
138        let error_summary = error_groups
139            .into_iter()
140            .map(|(pattern, indices)| crate::types::batch::BatchErrorGroup {
141                error_pattern: pattern,
142                count: indices.len(),
143                sample_indices: indices.into_iter().take(5).collect(),
144            })
145            .collect();
146
147        Ok(BatchReport {
148            batch_id: batch_id.to_string(),
149            total_items: job.items.len(),
150            success_count,
151            fail_count,
152            skip_count,
153            total_duration_ms: total_duration,
154            avg_item_duration_ms: avg,
155            error_summary,
156        })
157    }
158
159    /// Get a batch job.
160    pub fn get_job(&self, batch_id: &str) -> WorkflowResult<&BatchJob> {
161        self.jobs
162            .get(batch_id)
163            .ok_or_else(|| WorkflowError::BatchError(format!("Not found: {}", batch_id)))
164    }
165}
166
167impl Default for BatchEngine {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Three small JSON inputs shared by several tests.
    fn three_items() -> Vec<serde_json::Value> {
        vec![
            serde_json::json!({"id": 1}),
            serde_json::json!({"id": 2}),
            serde_json::json!({"id": 3}),
        ]
    }

    #[test]
    fn test_batch_creation() {
        let mut engine = BatchEngine::new();
        let bid = engine.create_batch("wf-1", three_items(), 2, 10).unwrap();
        let progress = engine.get_progress(&bid).unwrap();
        assert_eq!(progress.total_items, 3);
        assert_eq!(progress.pending, 3);
        assert_eq!(progress.percent_complete, 0.0);
    }

    #[test]
    fn create_batch_clamps_zero_settings_and_starts_pending() {
        let mut engine = BatchEngine::new();
        let bid = engine.create_batch("wf-1", three_items(), 0, 0).unwrap();
        let job = engine.get_job(&bid).unwrap();
        assert_eq!(job.concurrency, 1);
        assert_eq!(job.checkpoint_every, 1);
        assert_eq!(job.workflow_id, "wf-1");
        assert!(matches!(job.status, BatchStatus::Pending));
        assert!(job.started_at.is_none());
        assert!(job.completed_at.is_none());
        let indices: Vec<usize> = job.items.iter().map(|i| i.index).collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    #[test]
    fn each_batch_gets_a_distinct_id() {
        let mut engine = BatchEngine::new();
        let a = engine.create_batch("wf-1", three_items(), 1, 1).unwrap();
        let b = engine.create_batch("wf-1", three_items(), 1, 1).unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn unknown_batch_id_is_an_error() {
        let engine = BatchEngine::default();
        assert!(engine.get_progress("missing").is_err());
        assert!(engine.get_report("missing").is_err());
        assert!(engine.get_job("missing").is_err());
    }

    #[test]
    fn empty_batch_reports_zero() {
        let mut engine = BatchEngine::new();
        let bid = engine.create_batch("wf-1", Vec::new(), 1, 1).unwrap();

        let progress = engine.get_progress(&bid).unwrap();
        assert_eq!(progress.total_items, 0);
        assert_eq!(progress.percent_complete, 0.0);
        assert_eq!(progress.last_checkpoint_index, None);

        let report = engine.get_report(&bid).unwrap();
        assert_eq!(report.total_items, 0);
        assert_eq!(report.total_duration_ms, 0);
        assert_eq!(report.avg_item_duration_ms, 0.0);
    }

    #[test]
    fn progress_and_report_reflect_item_outcomes() {
        let mut engine = BatchEngine::new();
        let bid = engine.create_batch("wf-1", three_items(), 2, 10).unwrap();
        {
            // The engine has no mutation API here, so drive item state directly.
            let job = engine.jobs.get_mut(&bid).unwrap();
            job.items[0].status = BatchItemStatus::Success;
            job.items[0].duration_ms = Some(100);
            job.items[1].status = BatchItemStatus::Failed;
            job.items[1].duration_ms = Some(300);
            job.items[1].error = Some("boom".to_string());
        }

        let progress = engine.get_progress(&bid).unwrap();
        assert_eq!(progress.completed, 1);
        assert_eq!(progress.failed, 1);
        assert_eq!(progress.skipped, 0);
        assert_eq!(progress.running, 0);
        assert_eq!(progress.pending, 1);
        assert_eq!(progress.last_checkpoint_index, Some(0));

        let report = engine.get_report(&bid).unwrap();
        assert_eq!(report.success_count, 1);
        assert_eq!(report.fail_count, 1);
        assert_eq!(report.skip_count, 0);
        assert_eq!(report.total_duration_ms, 400);
        assert_eq!(report.avg_item_duration_ms, 200.0);
    }
}