1use crate::{Node, NodeId, NodeKind, Workflow};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9
/// A complete batch-execution plan for a workflow, produced by
/// `BatchAnalyzer::analyze`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchPlan {
    /// Topologically ordered batches; a batch may start only after every
    /// earlier batch has finished.
    pub batches: Vec<ExecutionBatch>,

    /// Number of nodes in the analyzed workflow.
    pub total_nodes: usize,

    /// Size of the largest batch, i.e. the widest parallel fan-out.
    pub max_parallelism: usize,

    /// Estimated speedup of batched execution over a sequential baseline.
    pub speedup_factor: f64,

    /// Aggregate statistics computed over all batches.
    pub stats: BatchStats,
}
28
/// One topological level of the workflow: the set of nodes whose
/// dependencies are all satisfied at that point.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionBatch {
    /// Zero-based topological level of this batch.
    pub level: usize,

    /// IDs of the nodes scheduled in this batch.
    pub nodes: Vec<NodeId>,

    /// Estimated wall-clock time for the batch (max over its nodes).
    pub estimated_time_ms: u64,

    /// True when the batch has more than one node and every member is safe
    /// to run concurrently.
    pub parallelizable: bool,
}
44
/// Aggregate metrics derived from the batches of a `BatchPlan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchStats {
    /// Total number of batches in the plan.
    pub batch_count: usize,

    /// Mean number of nodes per batch (0.0 for an empty plan).
    pub avg_batch_size: f64,

    /// Count of batches whose nodes must run one at a time.
    pub sequential_batches: usize,

    /// Count of batches whose nodes can run concurrently.
    pub parallel_batches: usize,

    /// Fraction of all batched nodes that sit in parallelizable batches
    /// (0.0–1.0).
    pub efficiency: f64,
}
63
/// Stateless analyzer that groups workflow nodes into execution batches
/// using Kahn-style topological leveling over the workflow's edges.
pub struct BatchAnalyzer;
66
67impl BatchAnalyzer {
68 pub fn analyze(workflow: &Workflow) -> BatchPlan {
70 let dependencies = Self::build_dependency_graph(workflow);
72
73 let in_degrees = Self::compute_in_degrees(workflow, &dependencies);
75
76 let batches = Self::generate_batches(workflow, &dependencies, in_degrees);
78
79 let stats = Self::calculate_stats(&batches);
81
82 let speedup_factor = Self::calculate_speedup(&batches, workflow.nodes.len());
84
85 let max_parallelism = batches.iter().map(|b| b.nodes.len()).max().unwrap_or(0);
87
88 BatchPlan {
89 total_nodes: workflow.nodes.len(),
90 max_parallelism,
91 speedup_factor,
92 batches,
93 stats,
94 }
95 }
96
97 fn build_dependency_graph(workflow: &Workflow) -> HashMap<NodeId, Vec<NodeId>> {
99 let mut graph: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
100
101 for node in &workflow.nodes {
103 graph.entry(node.id).or_default();
104 }
105
106 for edge in &workflow.edges {
108 graph.entry(edge.from).or_default().push(edge.to);
109 }
110
111 graph
112 }
113
114 fn compute_in_degrees(
116 workflow: &Workflow,
117 dependencies: &HashMap<NodeId, Vec<NodeId>>,
118 ) -> HashMap<NodeId, usize> {
119 let mut in_degrees: HashMap<NodeId, usize> = HashMap::new();
120
121 for node in &workflow.nodes {
123 in_degrees.insert(node.id, 0);
124 }
125
126 for children in dependencies.values() {
128 for &child_id in children {
129 *in_degrees.entry(child_id).or_insert(0) += 1;
130 }
131 }
132
133 in_degrees
134 }
135
136 fn generate_batches(
138 workflow: &Workflow,
139 dependencies: &HashMap<NodeId, Vec<NodeId>>,
140 mut in_degrees: HashMap<NodeId, usize>,
141 ) -> Vec<ExecutionBatch> {
142 let mut batches = Vec::new();
143 let mut processed = HashSet::new();
144 let mut current_level = 0;
145
146 let node_map: HashMap<NodeId, &Node> = workflow.nodes.iter().map(|n| (n.id, n)).collect();
148
149 while processed.len() < workflow.nodes.len() {
150 let ready_nodes: Vec<NodeId> = in_degrees
152 .iter()
153 .filter(|(&id, °ree)| degree == 0 && !processed.contains(&id))
154 .map(|(&id, _)| id)
155 .collect();
156
157 if ready_nodes.is_empty() {
158 break;
160 }
161
162 let estimated_time_ms = ready_nodes
164 .iter()
165 .filter_map(|id| node_map.get(id))
166 .map(|node| Self::estimate_node_time(node))
167 .max()
168 .unwrap_or(100);
169
170 let parallelizable = ready_nodes.len() > 1
172 && ready_nodes.iter().all(|id| {
173 if let Some(node) = node_map.get(id) {
174 Self::is_parallelizable(node)
175 } else {
176 false
177 }
178 });
179
180 batches.push(ExecutionBatch {
181 level: current_level,
182 nodes: ready_nodes.clone(),
183 estimated_time_ms,
184 parallelizable,
185 });
186
187 for &node_id in &ready_nodes {
189 processed.insert(node_id);
190 in_degrees.remove(&node_id);
191
192 if let Some(children) = dependencies.get(&node_id) {
194 for &child_id in children {
195 if let Some(degree) = in_degrees.get_mut(&child_id) {
196 *degree = degree.saturating_sub(1);
197 }
198 }
199 }
200 }
201
202 current_level += 1;
203 }
204
205 batches
206 }
207
208 fn estimate_node_time(node: &Node) -> u64 {
210 match &node.kind {
211 NodeKind::Start | NodeKind::End => 10,
212 NodeKind::LLM(_) => 3000,
213 NodeKind::Retriever(_) => 500,
214 NodeKind::Code(_) => 1000,
215 NodeKind::Tool(_) => 2000,
216 NodeKind::IfElse(_) | NodeKind::Switch(_) => 50,
217 NodeKind::Loop(_) => 100,
218 NodeKind::TryCatch(_) => 100,
219 NodeKind::SubWorkflow(_) => 5000,
220 NodeKind::Parallel(_) => 200,
221 NodeKind::Approval(_) => 60000,
222 NodeKind::Form(_) => 120000,
223 NodeKind::Vision(_) => 3000,
224 }
225 }
226
227 fn is_parallelizable(node: &Node) -> bool {
229 !matches!(node.kind, NodeKind::Approval(_) | NodeKind::Form(_))
232 }
233
234 fn calculate_stats(batches: &[ExecutionBatch]) -> BatchStats {
236 let batch_count = batches.len();
237
238 let total_nodes: usize = batches.iter().map(|b| b.nodes.len()).sum();
239 let avg_batch_size = if batch_count > 0 {
240 total_nodes as f64 / batch_count as f64
241 } else {
242 0.0
243 };
244
245 let sequential_batches = batches.iter().filter(|b| !b.parallelizable).count();
246 let parallel_batches = batches.iter().filter(|b| b.parallelizable).count();
247
248 let parallel_nodes: usize = batches
250 .iter()
251 .filter(|b| b.parallelizable)
252 .map(|b| b.nodes.len())
253 .sum();
254
255 let efficiency = if total_nodes > 0 {
256 parallel_nodes as f64 / total_nodes as f64
257 } else {
258 0.0
259 };
260
261 BatchStats {
262 batch_count,
263 avg_batch_size,
264 sequential_batches,
265 parallel_batches,
266 efficiency,
267 }
268 }
269
270 fn calculate_speedup(batches: &[ExecutionBatch], total_nodes: usize) -> f64 {
272 if total_nodes == 0 {
273 return 1.0;
274 }
275
276 let sequential_time: u64 =
278 batches.iter().flat_map(|b| b.nodes.iter()).count() as u64 * 1000; let parallel_time: u64 = batches.iter().map(|b| b.estimated_time_ms).sum();
282
283 if parallel_time > 0 {
284 sequential_time as f64 / parallel_time as f64
285 } else {
286 1.0
287 }
288 }
289
290 pub fn find_batch_opportunities(workflow: &Workflow) -> Vec<BatchOpportunity> {
292 let plan = Self::analyze(workflow);
293 let node_map: HashMap<NodeId, &Node> = workflow.nodes.iter().map(|n| (n.id, n)).collect();
294
295 let mut opportunities = Vec::new();
296
297 for batch in &plan.batches {
298 if batch.parallelizable && batch.nodes.len() > 1 {
299 let node_names: Vec<String> = batch
300 .nodes
301 .iter()
302 .filter_map(|id| node_map.get(id).map(|n| n.name.clone()))
303 .collect();
304
305 opportunities.push(BatchOpportunity {
306 level: batch.level,
307 node_count: batch.nodes.len(),
308 node_names,
309 estimated_speedup: batch.nodes.len() as f64 * 0.8, description: format!(
311 "Level {} has {} nodes that can run in parallel",
312 batch.level,
313 batch.nodes.len()
314 ),
315 });
316 }
317 }
318
319 opportunities
320 }
321}
322
/// A single parallelization opportunity surfaced by
/// `BatchAnalyzer::find_batch_opportunities`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchOpportunity {
    /// Topological level of the batch this opportunity refers to.
    pub level: usize,

    /// Number of nodes that could run concurrently at that level.
    pub node_count: usize,

    /// Display names of those nodes (ids without a matching node omitted).
    pub node_names: Vec<String>,

    /// Heuristic speedup estimate (node_count * 0.8).
    pub estimated_speedup: f64,

    /// Human-readable summary of the opportunity.
    pub description: String,
}
341
342impl BatchPlan {
343 pub fn format_summary(&self) -> String {
345 format!(
346 "Batch Execution Plan:\n\
347 Total Nodes: {} | Batches: {} | Max Parallelism: {}\n\
348 Speedup Factor: {:.2}x | Efficiency: {:.0}%\n\
349 Parallel Batches: {} | Sequential Batches: {}\n\
350 Average Batch Size: {:.1}",
351 self.total_nodes,
352 self.stats.batch_count,
353 self.max_parallelism,
354 self.speedup_factor,
355 self.stats.efficiency * 100.0,
356 self.stats.parallel_batches,
357 self.stats.sequential_batches,
358 self.stats.avg_batch_size
359 )
360 }
361
362 pub fn critical_path(&self) -> Vec<&ExecutionBatch> {
364 self.batches.iter().collect()
365 }
366
367 pub fn parallel_batches(&self) -> Vec<&ExecutionBatch> {
369 self.batches.iter().filter(|b| b.parallelizable).collect()
370 }
371}
372
373#[cfg(test)]
374mod tests {
375 use super::*;
376 use crate::{Edge, LlmConfig, WorkflowBuilder};
377
378 #[test]
379 fn test_linear_workflow_batching() {
380 let workflow = WorkflowBuilder::new("Linear")
381 .start("Start")
382 .llm(
383 "LLM1",
384 LlmConfig {
385 provider: "openai".to_string(),
386 model: "gpt-4".to_string(),
387 system_prompt: None,
388 prompt_template: "test1".to_string(),
389 temperature: None,
390 max_tokens: Some(100),
391 tools: vec![],
392 images: vec![],
393 extra_params: serde_json::Value::Null,
394 },
395 )
396 .llm(
397 "LLM2",
398 LlmConfig {
399 provider: "openai".to_string(),
400 model: "gpt-4".to_string(),
401 system_prompt: None,
402 prompt_template: "test2".to_string(),
403 temperature: None,
404 max_tokens: Some(100),
405 tools: vec![],
406 images: vec![],
407 extra_params: serde_json::Value::Null,
408 },
409 )
410 .end("End")
411 .build();
412
413 let plan = BatchAnalyzer::analyze(&workflow);
414
415 assert_eq!(plan.batches.len(), 4);
417 assert_eq!(plan.total_nodes, 4);
418 assert_eq!(plan.max_parallelism, 1); }
420
421 #[test]
422 fn test_parallel_workflow_batching() {
423 let mut workflow = WorkflowBuilder::new("Parallel").start("Start").build();
424
425 let start_id = workflow.nodes[0].id;
426
427 let llm1 = Node::new(
429 "LLM1".to_string(),
430 NodeKind::LLM(LlmConfig {
431 provider: "openai".to_string(),
432 model: "gpt-4".to_string(),
433 system_prompt: None,
434 prompt_template: "test1".to_string(),
435 temperature: None,
436 max_tokens: Some(100),
437 tools: vec![],
438 images: vec![],
439 extra_params: serde_json::Value::Null,
440 }),
441 );
442
443 let llm2 = Node::new(
444 "LLM2".to_string(),
445 NodeKind::LLM(LlmConfig {
446 provider: "openai".to_string(),
447 model: "gpt-4".to_string(),
448 system_prompt: None,
449 prompt_template: "test2".to_string(),
450 temperature: None,
451 max_tokens: Some(100),
452 tools: vec![],
453 images: vec![],
454 extra_params: serde_json::Value::Null,
455 }),
456 );
457
458 let end = Node::new("End".to_string(), NodeKind::End);
459
460 workflow.add_edge(Edge::new(start_id, llm1.id));
461 workflow.add_edge(Edge::new(start_id, llm2.id));
462 workflow.add_edge(Edge::new(llm1.id, end.id));
463 workflow.add_edge(Edge::new(llm2.id, end.id));
464
465 workflow.nodes.push(llm1);
466 workflow.nodes.push(llm2);
467 workflow.nodes.push(end);
468
469 let plan = BatchAnalyzer::analyze(&workflow);
470
471 assert_eq!(plan.batches.len(), 3);
473 assert_eq!(plan.max_parallelism, 2); assert!(plan.batches[1].parallelizable);
477 assert_eq!(plan.batches[1].nodes.len(), 2);
478 }
479
480 #[test]
481 fn test_batch_opportunities() {
482 let mut workflow = WorkflowBuilder::new("Parallel").start("Start").build();
483
484 let start_id = workflow.nodes[0].id;
485
486 let llm1 = Node::new(
488 "LLM1".to_string(),
489 NodeKind::LLM(LlmConfig {
490 provider: "openai".to_string(),
491 model: "gpt-4".to_string(),
492 system_prompt: None,
493 prompt_template: "test1".to_string(),
494 temperature: None,
495 max_tokens: Some(100),
496 tools: vec![],
497 images: vec![],
498 extra_params: serde_json::Value::Null,
499 }),
500 );
501
502 let llm2 = Node::new(
503 "LLM2".to_string(),
504 NodeKind::LLM(LlmConfig {
505 provider: "openai".to_string(),
506 model: "gpt-4".to_string(),
507 system_prompt: None,
508 prompt_template: "test2".to_string(),
509 temperature: None,
510 max_tokens: Some(100),
511 tools: vec![],
512 images: vec![],
513 extra_params: serde_json::Value::Null,
514 }),
515 );
516
517 let llm3 = Node::new(
518 "LLM3".to_string(),
519 NodeKind::LLM(LlmConfig {
520 provider: "openai".to_string(),
521 model: "gpt-4".to_string(),
522 system_prompt: None,
523 prompt_template: "test3".to_string(),
524 temperature: None,
525 max_tokens: Some(100),
526 tools: vec![],
527 images: vec![],
528 extra_params: serde_json::Value::Null,
529 }),
530 );
531
532 let end = Node::new("End".to_string(), NodeKind::End);
533
534 workflow.add_edge(Edge::new(start_id, llm1.id));
535 workflow.add_edge(Edge::new(start_id, llm2.id));
536 workflow.add_edge(Edge::new(start_id, llm3.id));
537 workflow.add_edge(Edge::new(llm1.id, end.id));
538 workflow.add_edge(Edge::new(llm2.id, end.id));
539 workflow.add_edge(Edge::new(llm3.id, end.id));
540
541 workflow.nodes.push(llm1);
542 workflow.nodes.push(llm2);
543 workflow.nodes.push(llm3);
544 workflow.nodes.push(end);
545
546 let opportunities = BatchAnalyzer::find_batch_opportunities(&workflow);
547
548 assert!(!opportunities.is_empty());
550 assert_eq!(opportunities[0].node_count, 3);
551 }
552
553 #[test]
554 fn test_batch_plan_summary() {
555 let workflow = WorkflowBuilder::new("Test")
556 .start("Start")
557 .end("End")
558 .build();
559
560 let plan = BatchAnalyzer::analyze(&workflow);
561 let summary = plan.format_summary();
562
563 assert!(summary.contains("Batch Execution Plan"));
564 assert!(summary.contains("Total Nodes: 2"));
565 }
566
567 #[test]
568 fn test_speedup_calculation() {
569 let mut workflow = WorkflowBuilder::new("Parallel").start("Start").build();
570
571 let start_id = workflow.nodes[0].id;
572
573 for i in 0..4 {
575 let llm = Node::new(
576 format!("LLM{}", i),
577 NodeKind::LLM(LlmConfig {
578 provider: "openai".to_string(),
579 model: "gpt-4".to_string(),
580 system_prompt: None,
581 prompt_template: format!("test{}", i),
582 temperature: None,
583 max_tokens: Some(100),
584 tools: vec![],
585 images: vec![],
586 extra_params: serde_json::Value::Null,
587 }),
588 );
589
590 workflow.add_edge(Edge::new(start_id, llm.id));
591 workflow.nodes.push(llm);
592 }
593
594 let end = Node::new("End".to_string(), NodeKind::End);
595 for i in 1..=4 {
596 workflow.add_edge(Edge::new(workflow.nodes[i].id, end.id));
597 }
598 workflow.nodes.push(end);
599
600 let plan = BatchAnalyzer::analyze(&workflow);
601
602 assert!(plan.speedup_factor > 1.0);
604 }
605
606 #[test]
607 fn test_parallel_batches_filter() {
608 let mut workflow = WorkflowBuilder::new("Mixed").start("Start").build();
609
610 let start_id = workflow.nodes[0].id;
611
612 let llm1 = Node::new(
613 "LLM1".to_string(),
614 NodeKind::LLM(LlmConfig {
615 provider: "openai".to_string(),
616 model: "gpt-4".to_string(),
617 system_prompt: None,
618 prompt_template: "test1".to_string(),
619 temperature: None,
620 max_tokens: Some(100),
621 tools: vec![],
622 images: vec![],
623 extra_params: serde_json::Value::Null,
624 }),
625 );
626
627 let llm2 = Node::new(
628 "LLM2".to_string(),
629 NodeKind::LLM(LlmConfig {
630 provider: "openai".to_string(),
631 model: "gpt-4".to_string(),
632 system_prompt: None,
633 prompt_template: "test2".to_string(),
634 temperature: None,
635 max_tokens: Some(100),
636 tools: vec![],
637 images: vec![],
638 extra_params: serde_json::Value::Null,
639 }),
640 );
641
642 let end = Node::new("End".to_string(), NodeKind::End);
643
644 workflow.add_edge(Edge::new(start_id, llm1.id));
645 workflow.add_edge(Edge::new(start_id, llm2.id));
646 workflow.add_edge(Edge::new(llm1.id, end.id));
647 workflow.add_edge(Edge::new(llm2.id, end.id));
648
649 workflow.nodes.push(llm1);
650 workflow.nodes.push(llm2);
651 workflow.nodes.push(end);
652
653 let plan = BatchAnalyzer::analyze(&workflow);
654 let parallel = plan.parallel_batches();
655
656 assert!(!parallel.is_empty());
658 }
659}