oxify_model/batching.rs

//! Intelligent node execution batching
//!
//! This module provides analysis and strategies for batching node executions
//! to improve workflow performance through parallelization.
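//!
//! # Example
//!
//! A minimal usage sketch (hypothetical two-node workflow; the builder API
//! mirrors the tests at the bottom of this file, and the import paths are
//! assumed from the `oxify_model/` directory name):
//!
//! ```ignore
//! use oxify_model::{BatchAnalyzer, WorkflowBuilder};
//!
//! let workflow = WorkflowBuilder::new("Demo").start("Start").end("End").build();
//! let plan = BatchAnalyzer::analyze(&workflow);
//! println!("{}", plan.format_summary());
//! ```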

use crate::{Node, NodeId, NodeKind, Workflow};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

/// Batch execution plan for a workflow
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchPlan {
    /// Execution batches in topological order
    pub batches: Vec<ExecutionBatch>,

    /// Total number of nodes
    pub total_nodes: usize,

    /// Maximum parallelism (largest batch size)
    pub max_parallelism: usize,

    /// Estimated speedup factor
    pub speedup_factor: f64,

    /// Batch statistics
    pub stats: BatchStats,
}

/// A batch of nodes that can be executed in parallel
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionBatch {
    /// Batch level (0-indexed)
    pub level: usize,

    /// Nodes in this batch
    pub nodes: Vec<NodeId>,

    /// Estimated execution time for this batch (ms)
    pub estimated_time_ms: u64,

    /// Whether this batch can benefit from parallel execution
    pub parallelizable: bool,
}

/// Statistics about the batching strategy
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchStats {
    /// Number of batches
    pub batch_count: usize,

    /// Average batch size
    pub avg_batch_size: f64,

    /// Number of sequential-only batches
    pub sequential_batches: usize,

    /// Number of parallel batches
    pub parallel_batches: usize,

    /// Parallelization efficiency (0.0 to 1.0)
    pub efficiency: f64,
}

/// Node batching analyzer
pub struct BatchAnalyzer;

impl BatchAnalyzer {
    /// Analyze a workflow and generate a batch execution plan
    pub fn analyze(workflow: &Workflow) -> BatchPlan {
        // Build dependency graph
        let dependencies = Self::build_dependency_graph(workflow);

        // Compute in-degrees for topological sorting
        let in_degrees = Self::compute_in_degrees(workflow, &dependencies);

        // Generate batches using level-based topological sort
        let batches = Self::generate_batches(workflow, &dependencies, in_degrees);

        // Calculate statistics
        let stats = Self::calculate_stats(&batches);

        // Calculate speedup factor
        let speedup_factor = Self::calculate_speedup(&batches, workflow);

        // Find max parallelism
        let max_parallelism = batches.iter().map(|b| b.nodes.len()).max().unwrap_or(0);

        BatchPlan {
            total_nodes: workflow.nodes.len(),
            max_parallelism,
            speedup_factor,
            batches,
            stats,
        }
    }

    /// Build dependency graph (node -> list of nodes that depend on it)
    fn build_dependency_graph(workflow: &Workflow) -> HashMap<NodeId, Vec<NodeId>> {
        let mut graph: HashMap<NodeId, Vec<NodeId>> = HashMap::new();

        // Initialize with all nodes
        for node in &workflow.nodes {
            graph.entry(node.id).or_default();
        }

        // Add edges
        for edge in &workflow.edges {
            graph.entry(edge.from).or_default().push(edge.to);
        }

        graph
    }

    /// Compute in-degrees for each node
    fn compute_in_degrees(
        workflow: &Workflow,
        dependencies: &HashMap<NodeId, Vec<NodeId>>,
    ) -> HashMap<NodeId, usize> {
        let mut in_degrees: HashMap<NodeId, usize> = HashMap::new();

        // Initialize all nodes with 0 in-degree
        for node in &workflow.nodes {
            in_degrees.insert(node.id, 0);
        }

        // Count incoming edges
        for children in dependencies.values() {
            for &child_id in children {
                *in_degrees.entry(child_id).or_insert(0) += 1;
            }
        }

        in_degrees
    }

    /// Generate execution batches using level-based topological sort
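    ///
    /// Each batch collects every node whose dependencies are already satisfied.
    /// For a diamond-shaped graph (a hypothetical example matching the parallel
    /// test below):
    ///
    /// ```text
    /// Start -> A,  Start -> B,  A -> End,  B -> End
    /// ```
    ///
    /// the batches are `[Start]`, `[A, B]`, `[End]`, so `A` and `B` can run in
    /// parallel.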
    fn generate_batches(
        workflow: &Workflow,
        dependencies: &HashMap<NodeId, Vec<NodeId>>,
        mut in_degrees: HashMap<NodeId, usize>,
    ) -> Vec<ExecutionBatch> {
        let mut batches = Vec::new();
        let mut processed = HashSet::new();
        let mut current_level = 0;

        // Create a map for quick node lookup
        let node_map: HashMap<NodeId, &Node> = workflow.nodes.iter().map(|n| (n.id, n)).collect();

        while processed.len() < workflow.nodes.len() {
            // Find all nodes with in-degree 0 (ready to execute). Note that
            // HashMap iteration order is unspecified, so node order within a
            // batch is not deterministic.
            let ready_nodes: Vec<NodeId> = in_degrees
                .iter()
                .filter(|(&id, &degree)| degree == 0 && !processed.contains(&id))
                .map(|(&id, _)| id)
                .collect();

            if ready_nodes.is_empty() {
                // The remaining nodes all have unmet dependencies, which means
                // the graph contains a cycle; this cannot happen in a valid DAG
                break;
            }

            // Estimate the time for this batch: the max of its node times, since
            // a parallel batch finishes when its slowest node does
            let estimated_time_ms = ready_nodes
                .iter()
                .filter_map(|id| node_map.get(id))
                .map(|node| Self::estimate_node_time(node))
                .max()
                .unwrap_or(100);

            // Check if the batch is parallelizable
            let parallelizable = ready_nodes.len() > 1
                && ready_nodes.iter().all(|id| {
                    if let Some(node) = node_map.get(id) {
                        Self::is_parallelizable(node)
                    } else {
                        false
                    }
                });

            batches.push(ExecutionBatch {
                level: current_level,
                nodes: ready_nodes.clone(),
                estimated_time_ms,
                parallelizable,
            });

            // Mark nodes as processed and update in-degrees
            for &node_id in &ready_nodes {
                processed.insert(node_id);
                in_degrees.remove(&node_id);

                // Reduce the in-degree of each child
                if let Some(children) = dependencies.get(&node_id) {
                    for &child_id in children {
                        if let Some(degree) = in_degrees.get_mut(&child_id) {
                            *degree = degree.saturating_sub(1);
                        }
                    }
                }
            }

            current_level += 1;
        }

        batches
    }

    /// Estimate the execution time of a node in milliseconds (a rough heuristic)
    fn estimate_node_time(node: &Node) -> u64 {
        match &node.kind {
            NodeKind::Start | NodeKind::End => 10,
            NodeKind::LLM(_) => 3000,
            NodeKind::Retriever(_) => 500,
            NodeKind::Code(_) => 1000,
            NodeKind::Tool(_) => 2000,
            NodeKind::IfElse(_) | NodeKind::Switch(_) => 50,
            NodeKind::Loop(_) => 100,
            NodeKind::TryCatch(_) => 100,
            NodeKind::SubWorkflow(_) => 5000,
            NodeKind::Parallel(_) => 200,
            NodeKind::Approval(_) => 60000,
            NodeKind::Form(_) => 120000,
            NodeKind::Vision(_) => 3000,
        }
    }

    /// Check if a node can be safely parallelized
    fn is_parallelizable(node: &Node) -> bool {
        // Most nodes can run in parallel once their data dependencies are met.
        // Approval and Form nodes are excluded because they wait on human
        // input, so sharing a batch with them gains nothing.
        !matches!(node.kind, NodeKind::Approval(_) | NodeKind::Form(_))
    }

    /// Calculate batch statistics
    fn calculate_stats(batches: &[ExecutionBatch]) -> BatchStats {
        let batch_count = batches.len();

        let total_nodes: usize = batches.iter().map(|b| b.nodes.len()).sum();
        let avg_batch_size = if batch_count > 0 {
            total_nodes as f64 / batch_count as f64
        } else {
            0.0
        };

        let sequential_batches = batches.iter().filter(|b| !b.parallelizable).count();
        let parallel_batches = batches.iter().filter(|b| b.parallelizable).count();

        // Efficiency: ratio of nodes that can run in parallel
        let parallel_nodes: usize = batches
            .iter()
            .filter(|b| b.parallelizable)
            .map(|b| b.nodes.len())
            .sum();

        let efficiency = if total_nodes > 0 {
            parallel_nodes as f64 / total_nodes as f64
        } else {
            0.0
        };

        BatchStats {
            batch_count,
            avg_batch_size,
            sequential_batches,
            parallel_batches,
            efficiency,
        }
    }

    /// Calculate the estimated speedup factor from batching
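    ///
    /// As a worked example: a batch of four LLM nodes (~3000 ms each) is
    /// estimated at ~3000 ms in parallel versus ~12000 ms sequentially,
    /// contributing roughly a 4x speedup for that level.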
    fn calculate_speedup(batches: &[ExecutionBatch], workflow: &Workflow) -> f64 {
        if workflow.nodes.is_empty() {
            return 1.0;
        }

        // Sequential time: the sum of every node's estimated time
        let sequential_time: u64 = workflow.nodes.iter().map(Self::estimate_node_time).sum();

        // Parallel time: the sum of batch times (the max within each batch).
        // The sum of per-batch maxima never exceeds the sum of all node times,
        // so the reported speedup is at least 1.0
        let parallel_time: u64 = batches.iter().map(|b| b.estimated_time_ms).sum();

        if parallel_time > 0 {
            sequential_time as f64 / parallel_time as f64
        } else {
            1.0
        }
    }

    /// Find groups of nodes that can be batched together
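    ///
    /// A minimal usage sketch (hypothetical `workflow`):
    ///
    /// ```ignore
    /// let opportunities = BatchAnalyzer::find_batch_opportunities(&workflow);
    /// for opp in &opportunities {
    ///     println!("{}", opp.description);
    /// }
    /// ```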
    pub fn find_batch_opportunities(workflow: &Workflow) -> Vec<BatchOpportunity> {
        let plan = Self::analyze(workflow);
        let node_map: HashMap<NodeId, &Node> = workflow.nodes.iter().map(|n| (n.id, n)).collect();

        let mut opportunities = Vec::new();

        for batch in &plan.batches {
            if batch.parallelizable && batch.nodes.len() > 1 {
                let node_names: Vec<String> = batch
                    .nodes
                    .iter()
                    .filter_map(|id| node_map.get(id).map(|n| n.name.clone()))
                    .collect();

                opportunities.push(BatchOpportunity {
                    level: batch.level,
                    node_count: batch.nodes.len(),
                    node_names,
                    estimated_speedup: batch.nodes.len() as f64 * 0.8, // Conservative estimate
                    description: format!(
                        "Level {} has {} nodes that can run in parallel",
                        batch.level,
                        batch.nodes.len()
                    ),
                });
            }
        }

        opportunities
    }
}

/// A batching opportunity identified in the workflow
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchOpportunity {
    /// Execution level
    pub level: usize,

    /// Number of nodes in batch
    pub node_count: usize,

    /// Names of nodes in batch
    pub node_names: Vec<String>,

    /// Estimated speedup from batching
    pub estimated_speedup: f64,

    /// Human-readable description
    pub description: String,
}

impl BatchPlan {
    /// Format the batch plan as a human-readable string
    pub fn format_summary(&self) -> String {
        format!(
            "Batch Execution Plan:\n\
             Total Nodes: {} | Batches: {} | Max Parallelism: {}\n\
             Speedup Factor: {:.2}x | Efficiency: {:.0}%\n\
             Parallel Batches: {} | Sequential Batches: {}\n\
             Average Batch Size: {:.1}",
            self.total_nodes,
            self.stats.batch_count,
            self.max_parallelism,
            self.speedup_factor,
            self.stats.efficiency * 100.0,
            self.stats.parallel_batches,
            self.stats.sequential_batches,
            self.stats.avg_batch_size
        )
    }

    /// Get the critical path. Under a level-synchronous schedule every batch
    /// lies on the critical path, so this returns all batches in execution
    /// order.
    pub fn critical_path(&self) -> Vec<&ExecutionBatch> {
        self.batches.iter().collect()
    }

    /// Get all parallelizable batches
    pub fn parallel_batches(&self) -> Vec<&ExecutionBatch> {
        self.batches.iter().filter(|b| b.parallelizable).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Edge, LlmConfig, WorkflowBuilder};

    /// Build a minimal LlmConfig for tests; only the prompt template varies.
    fn test_llm_config(prompt: &str) -> LlmConfig {
        LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: prompt.to_string(),
            temperature: None,
            max_tokens: Some(100),
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        }
    }

    /// Build an LLM node with the given name and prompt template.
    fn test_llm_node(name: &str, prompt: &str) -> Node {
        Node::new(name.to_string(), NodeKind::LLM(test_llm_config(prompt)))
    }

    #[test]
    fn test_linear_workflow_batching() {
        let workflow = WorkflowBuilder::new("Linear")
            .start("Start")
            .llm("LLM1", test_llm_config("test1"))
            .llm("LLM2", test_llm_config("test2"))
            .end("End")
            .build();

        let plan = BatchAnalyzer::analyze(&workflow);

        // A linear workflow should have 4 batches (Start, LLM1, LLM2, End)
        assert_eq!(plan.batches.len(), 4);
        assert_eq!(plan.total_nodes, 4);
        assert_eq!(plan.max_parallelism, 1); // All sequential
    }

    #[test]
    fn test_parallel_workflow_batching() {
        let mut workflow = WorkflowBuilder::new("Parallel").start("Start").build();

        let start_id = workflow.nodes[0].id;

        // Add two parallel LLM nodes
        let llm1 = test_llm_node("LLM1", "test1");
        let llm2 = test_llm_node("LLM2", "test2");
        let end = Node::new("End".to_string(), NodeKind::End);

        workflow.add_edge(Edge::new(start_id, llm1.id));
        workflow.add_edge(Edge::new(start_id, llm2.id));
        workflow.add_edge(Edge::new(llm1.id, end.id));
        workflow.add_edge(Edge::new(llm2.id, end.id));

        workflow.nodes.push(llm1);
        workflow.nodes.push(llm2);
        workflow.nodes.push(end);

        let plan = BatchAnalyzer::analyze(&workflow);

        // Should have 3 batches: [Start], [LLM1, LLM2], [End]
        assert_eq!(plan.batches.len(), 3);
        assert_eq!(plan.max_parallelism, 2); // LLM1 and LLM2 in parallel

        // The second batch should be parallelizable
        assert!(plan.batches[1].parallelizable);
        assert_eq!(plan.batches[1].nodes.len(), 2);
    }

    #[test]
    fn test_batch_opportunities() {
        let mut workflow = WorkflowBuilder::new("Parallel").start("Start").build();

        let start_id = workflow.nodes[0].id;

        // Add three parallel LLM nodes
        let llm1 = test_llm_node("LLM1", "test1");
        let llm2 = test_llm_node("LLM2", "test2");
        let llm3 = test_llm_node("LLM3", "test3");
        let end = Node::new("End".to_string(), NodeKind::End);

        workflow.add_edge(Edge::new(start_id, llm1.id));
        workflow.add_edge(Edge::new(start_id, llm2.id));
        workflow.add_edge(Edge::new(start_id, llm3.id));
        workflow.add_edge(Edge::new(llm1.id, end.id));
        workflow.add_edge(Edge::new(llm2.id, end.id));
        workflow.add_edge(Edge::new(llm3.id, end.id));

        workflow.nodes.push(llm1);
        workflow.nodes.push(llm2);
        workflow.nodes.push(llm3);
        workflow.nodes.push(end);

        let opportunities = BatchAnalyzer::find_batch_opportunities(&workflow);

        // Should find one opportunity with 3 nodes
        assert!(!opportunities.is_empty());
        assert_eq!(opportunities[0].node_count, 3);
    }

    #[test]
    fn test_batch_plan_summary() {
        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .end("End")
            .build();

        let plan = BatchAnalyzer::analyze(&workflow);
        let summary = plan.format_summary();

        assert!(summary.contains("Batch Execution Plan"));
        assert!(summary.contains("Total Nodes: 2"));
    }

    #[test]
    fn test_speedup_calculation() {
        let mut workflow = WorkflowBuilder::new("Parallel").start("Start").build();

        let start_id = workflow.nodes[0].id;

        // Add 4 parallel LLM nodes
        for i in 0..4 {
            let llm = test_llm_node(&format!("LLM{}", i), &format!("test{}", i));
            workflow.add_edge(Edge::new(start_id, llm.id));
            workflow.nodes.push(llm);
        }

        let end = Node::new("End".to_string(), NodeKind::End);
        for i in 1..=4 {
            workflow.add_edge(Edge::new(workflow.nodes[i].id, end.id));
        }
        workflow.nodes.push(end);

        let plan = BatchAnalyzer::analyze(&workflow);

        // Speedup should be > 1 due to parallelization
        assert!(plan.speedup_factor > 1.0);
    }

    #[test]
    fn test_parallel_batches_filter() {
        let mut workflow = WorkflowBuilder::new("Mixed").start("Start").build();

        let start_id = workflow.nodes[0].id;

        let llm1 = test_llm_node("LLM1", "test1");
        let llm2 = test_llm_node("LLM2", "test2");
        let end = Node::new("End".to_string(), NodeKind::End);

        workflow.add_edge(Edge::new(start_id, llm1.id));
        workflow.add_edge(Edge::new(start_id, llm2.id));
        workflow.add_edge(Edge::new(llm1.id, end.id));
        workflow.add_edge(Edge::new(llm2.id, end.id));

        workflow.nodes.push(llm1);
        workflow.nodes.push(llm2);
        workflow.nodes.push(end);

        let plan = BatchAnalyzer::analyze(&workflow);
        let parallel = plan.parallel_batches();

        // Should have at least one parallel batch
        assert!(!parallel.is_empty());
    }
}