use crate::{ExecutionStats, LlmConfig, Node, NodeKind, VectorConfig, Workflow};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;

#[derive(Debug, Clone, Serialize, Deserialize)]
/// A predicted execution-time range for a workflow, with a per-node breakdown.
pub struct TimeEstimate {
    /// Best-case total duration in milliseconds.
    pub min_duration_ms: u64,

    /// Average (expected) total duration in milliseconds.
    pub avg_duration_ms: u64,

    /// Worst-case total duration in milliseconds.
    pub max_duration_ms: u64,

    /// Names of the nodes on the critical path, in order.
    pub critical_path: Vec<String>,

    /// Per-node estimates, keyed by node id.
    pub node_times: HashMap<String, NodeTime>,

    /// Confidence in the estimate, from 0.0 (low) to 1.0 (high).
    pub confidence: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
/// The estimated execution time of a single node.
pub struct NodeTime {
    /// Display name of the node.
    pub node_name: String,

    /// Node kind as a human-readable label (e.g. "LLM", "Tool").
    pub node_type: String,

    /// Best-case time in milliseconds, across all expected executions.
    pub min_ms: u64,

    /// Average time in milliseconds, across all expected executions.
    pub avg_ms: u64,

    /// Worst-case time in milliseconds, across all expected executions.
    pub max_ms: u64,

    /// Expected number of executions, including predicted retries.
    pub expected_executions: u32,

    /// Whether the node lies on the critical path.
    pub is_critical: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
/// Aggregated timing observations from previous workflow runs.
pub struct HistoricalData {
    /// Average duration in milliseconds per node type (e.g. "LLM").
    pub node_type_averages: HashMap<String, u64>,

    /// Average observed latency in milliseconds per LLM provider.
    pub provider_latencies: HashMap<String, u64>,

    /// Raw duration samples in milliseconds, keyed by node id.
    pub node_execution_history: HashMap<String, Vec<u64>>,
}

impl HistoricalData {
    /// Creates an empty data set.
    pub fn new() -> Self {
        Self {
            node_type_averages: HashMap::new(),
            provider_latencies: HashMap::new(),
            node_execution_history: HashMap::new(),
        }
    }

    /// Folds a completed execution's statistics into the data set.
    ///
    /// Currently a no-op placeholder; per-node durations are not yet recorded.
    pub fn update_from_stats(&mut self, _stats: &ExecutionStats) {
        // Intentionally empty: extracting per-node timings from
        // `ExecutionStats` is not implemented yet.
    }

    /// Returns the recorded average duration for a node type, if any.
    pub fn get_node_type_average(&self, node_type: &str) -> Option<u64> {
        self.node_type_averages.get(node_type).copied()
    }

    /// Returns the recorded average latency for an LLM provider, if any.
    pub fn get_provider_latency(&self, provider: &str) -> Option<u64> {
        self.provider_latencies.get(provider).copied()
    }
}

impl Default for HistoricalData {
    fn default() -> Self {
        Self::new()
    }
}

/// Predicts how long a workflow will take to execute, combining built-in
/// per-node heuristics with any available historical data.
pub struct TimePredictor {
    /// Timing observations from previous runs, used to refine estimates.
    historical_data: HistoricalData,
}

impl TimePredictor {
    /// Creates a predictor with no historical data; only the built-in
    /// heuristics are used.
    pub fn new() -> Self {
        Self {
            historical_data: HistoricalData::new(),
        }
    }

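    /// Creates a predictor seeded with timing data from previous runs.
    /// Recorded provider latencies take precedence over the built-in
    /// per-provider defaults.
    ///
    /// A minimal seeding sketch (mirrors `test_predictor_with_historical_data`
    /// below):
    ///
    /// ```ignore
    /// let mut data = HistoricalData::new();
    /// data.provider_latencies.insert("openai".to_string(), 2000);
    /// let predictor = TimePredictor::with_historical_data(data);
    /// ```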
    pub fn with_historical_data(historical_data: HistoricalData) -> Self {
        Self { historical_data }
    }

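    /// Produces a [`TimeEstimate`] for the given workflow.
    ///
    /// A minimal usage sketch (mirroring the tests below):
    ///
    /// ```ignore
    /// let predictor = TimePredictor::new();
    /// let estimate = predictor.predict(&workflow);
    /// println!("{}", estimate.format_summary());
    /// ```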
    pub fn predict(&self, workflow: &Workflow) -> TimeEstimate {
        let mut node_times = HashMap::new();
        let mut total_min = 0u64;
        let mut total_avg = 0u64;
        let mut total_max = 0u64;

        // Sum every node's estimate, assuming purely sequential execution.
        for node in &workflow.nodes {
            let node_time = self.predict_node_time(node);
            total_min += node_time.min_ms;
            total_avg += node_time.avg_ms;
            total_max += node_time.max_ms;
            node_times.insert(node.id.to_string(), node_time);
        }

        // Simplified critical path: with the sequential model above, every
        // node is on the critical path, in declaration order.
        let critical_path = workflow
            .nodes
            .iter()
            .map(|n| n.name.clone())
            .collect::<Vec<_>>();

        let confidence = self.calculate_confidence(workflow);

        TimeEstimate {
            min_duration_ms: total_min,
            avg_duration_ms: total_avg,
            max_duration_ms: total_max,
            critical_path,
            node_times,
            confidence,
        }
    }

    /// Estimates (min, avg, max) milliseconds for a single node and scales
    /// the result by the expected number of executions.
    fn predict_node_time(&self, node: &Node) -> NodeTime {
        // Baseline (min, avg, max) in milliseconds per node kind. These are
        // heuristic constants, not measurements.
        let (min_ms, avg_ms, max_ms) = match &node.kind {
            // Structural nodes are effectively free.
            NodeKind::Start | NodeKind::End => (1, 5, 10),

            NodeKind::LLM(config) => self.predict_llm_time(config, node),

            NodeKind::Retriever(config) => self.predict_vector_time(config),

            // Arbitrary user code: highly variable.
            NodeKind::Code(_) => (100, 500, 5000),

            // External tool calls, typically network-bound.
            NodeKind::Tool(_) => (200, 1000, 5000),

            // Branching only evaluates a condition.
            NodeKind::IfElse(_) => (1, 10, 50),
            NodeKind::Switch(_) => (1, 10, 50),

            // Loop overhead only; the body is estimated per node.
            NodeKind::Loop(_) => (10, 50, 200),

            NodeKind::TryCatch(_) => (5, 20, 100),

            // Nested workflows can be arbitrarily large.
            NodeKind::SubWorkflow(_) => (100, 5000, 30000),

            NodeKind::Parallel(_) => (50, 200, 1000),

            // Human approval: one second to an hour.
            NodeKind::Approval(_) => (1000, 60000, 3600000),

            // Form filling: five seconds to ten minutes.
            NodeKind::Form(_) => (5000, 120000, 600000),

            NodeKind::Vision(_) => (500, 3000, 15000),
        };

        let expected_executions = Self::estimate_executions(node);

        NodeTime {
            node_name: node.name.clone(),
            node_type: self.get_node_type_string(&node.kind),
            min_ms: min_ms * expected_executions as u64,
            avg_ms: avg_ms * expected_executions as u64,
            max_ms: max_ms * expected_executions as u64,
            expected_executions,
            // Critical-path marking is not computed yet; `predict` treats
            // all nodes as sequential.
            is_critical: false,
        }
    }

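    /// Estimates LLM latency. A recorded provider latency, when available,
    /// wins and is spread as (avg/2, avg, avg*2); otherwise a per-provider
    /// baseline is scaled by `max_tokens` relative to a 1000-token baseline,
    /// floored at 0.5x.
    ///
    /// Worked example with the built-in baselines: `gpt-4` with
    /// `max_tokens = 2000` scales (3000, 8000, 20000) by 2.0, giving
    /// (6000, 16000, 40000) ms.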
    fn predict_llm_time(&self, config: &LlmConfig, _node: &Node) -> (u64, u64, u64) {
        // Prefer observed latency for this provider, when we have it.
        if let Some(avg) = self.historical_data.get_provider_latency(&config.provider) {
            return (avg / 2, avg, avg * 2);
        }

        // Heuristic per-provider baselines in milliseconds.
        let base_latency = match config.provider.to_lowercase().as_str() {
            "openai" => {
                if config.model.contains("gpt-4") {
                    (3000, 8000, 20000)
                } else {
                    (1000, 3000, 10000)
                }
            }
            "anthropic" => {
                if config.model.contains("opus") {
                    (2000, 6000, 15000)
                } else if config.model.contains("sonnet") {
                    (1000, 4000, 12000)
                } else {
                    (500, 2000, 8000)
                }
            }
            // Local models: no network overhead, but slower inference.
            "ollama" | "local" => (500, 2000, 10000),
            _ => (2000, 5000, 15000),
        };

        // Scale by requested output length, relative to a 1000-token baseline.
        let max_tokens = config.max_tokens.unwrap_or(1000);
        let token_multiplier = (max_tokens as f64 / 1000.0).max(0.5);

        (
            (base_latency.0 as f64 * token_multiplier) as u64,
            (base_latency.1 as f64 * token_multiplier) as u64,
            (base_latency.2 as f64 * token_multiplier) as u64,
        )
    }

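    /// Estimates vector-search latency: a per-store base cost plus a small
    /// per-result cost scaled by `top_k`, spread into a (min, avg, max) range.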
    fn predict_vector_time(&self, config: &VectorConfig) -> (u64, u64, u64) {
        match config.db_type.to_lowercase().as_str() {
            // Qdrant: fast approximate search, narrow spread.
            "qdrant" => {
                let base = 50 + (config.top_k * 5) as u64;
                (base / 2, base, base * 3)
            }
            // pgvector: slower and more variable under load.
            "pgvector" => {
                let base = 100 + (config.top_k * 10) as u64;
                (base / 2, base, base * 5)
            }
            _ => {
                let base = 100 + (config.top_k * 10) as u64;
                (base / 2, base, base * 3)
            }
        }
    }

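    /// Expected number of executions for a node, counting likely retries.
    ///
    /// Assumes roughly 30% of the configured `max_retries` are used on
    /// average; e.g. `max_retries = 3` yields `ceil(3 * 0.3) = 1` extra
    /// execution, for 2 expected executions in total.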
    fn estimate_executions(node: &Node) -> u32 {
        let mut executions = 1u32;

        if let Some(retry_config) = &node.retry_config {
            // Heuristic: about 30% of allowed retries actually happen.
            let avg_retries = (retry_config.max_retries as f32 * 0.3).ceil() as u32;
            executions += avg_retries;
        }

        executions
    }

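    /// Scores how trustworthy the estimate is by averaging a per-kind
    /// confidence over all nodes: deterministic nodes score high,
    /// human-in-the-loop nodes low, and LLM/retriever nodes score higher
    /// when execution history exists for them.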
    fn calculate_confidence(&self, workflow: &Workflow) -> f64 {
        if workflow.nodes.is_empty() {
            return 0.0;
        }

        let mut total_confidence = 0.0;

        for node in &workflow.nodes {
            let node_confidence = match &node.kind {
                // Deterministic control flow is very predictable.
                NodeKind::Start | NodeKind::End | NodeKind::IfElse(_) | NodeKind::Switch(_) => 0.9,

                // LLM and retrieval latency is predictable mainly when we
                // have history for the specific node.
                NodeKind::LLM(_) | NodeKind::Retriever(_) => {
                    if self
                        .historical_data
                        .node_execution_history
                        .contains_key(&node.id.to_string())
                    {
                        0.8
                    } else {
                        0.5
                    }
                }

                // Arbitrary code, tools, and nested workflows vary widely.
                NodeKind::Code(_) | NodeKind::Tool(_) | NodeKind::SubWorkflow(_) => 0.4,

                // Human response times are essentially unbounded.
                NodeKind::Approval(_) | NodeKind::Form(_) => 0.2,

                NodeKind::Loop(_) | NodeKind::TryCatch(_) | NodeKind::Parallel(_) => 0.5,

                NodeKind::Vision(_) => 0.6,
            };

            total_confidence += node_confidence;
        }

        (total_confidence / workflow.nodes.len() as f64).min(1.0)
    }

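    /// Returns a human-readable label for a node kind, used as the
    /// `node_type` field of [`NodeTime`].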
    fn get_node_type_string(&self, kind: &NodeKind) -> String {
        match kind {
            NodeKind::Start => "Start".to_string(),
            NodeKind::End => "End".to_string(),
            NodeKind::LLM(_) => "LLM".to_string(),
            NodeKind::Retriever(_) => "Retriever".to_string(),
            NodeKind::Code(_) => "Code".to_string(),
            NodeKind::IfElse(_) => "IfElse".to_string(),
            NodeKind::Tool(_) => "Tool".to_string(),
            NodeKind::Loop(_) => "Loop".to_string(),
            NodeKind::TryCatch(_) => "TryCatch".to_string(),
            NodeKind::SubWorkflow(_) => "SubWorkflow".to_string(),
            NodeKind::Switch(_) => "Switch".to_string(),
            NodeKind::Parallel(_) => "Parallel".to_string(),
            NodeKind::Approval(_) => "Approval".to_string(),
            NodeKind::Form(_) => "Form".to_string(),
            NodeKind::Vision(_) => "Vision".to_string(),
        }
    }
}

impl Default for TimePredictor {
    fn default() -> Self {
        Self::new()
    }
}

impl TimeEstimate {
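    /// Renders a short, human-readable summary of the estimate.
    ///
    /// Illustrative output shape (the values shown are made up):
    ///
    /// ```text
    /// Estimated Time: 1.001s - 10.01s (avg: 3.005s)
    /// Critical Path: Start → Generate → End
    /// Confidence: 77%
    /// ```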
    pub fn format_summary(&self) -> String {
        let min_duration = Duration::from_millis(self.min_duration_ms);
        let avg_duration = Duration::from_millis(self.avg_duration_ms);
        let max_duration = Duration::from_millis(self.max_duration_ms);

        format!(
            "Estimated Time: {:?} - {:?} (avg: {:?})\n\
             Critical Path: {}\n\
             Confidence: {:.0}%",
            min_duration,
            max_duration,
            avg_duration,
            self.critical_path.join(" → "),
            self.confidence * 100.0
        )
    }

    /// Returns the nodes flagged as critical. Note that
    /// `TimePredictor::predict` does not yet mark any node as critical,
    /// so this is empty for fresh estimates.
    pub fn critical_path_nodes(&self) -> Vec<&NodeTime> {
        self.node_times
            .values()
            .filter(|nt| nt.is_critical)
            .collect()
    }

    /// Returns up to `limit` nodes, sorted by descending average time.
    pub fn slowest_nodes(&self, limit: usize) -> Vec<&NodeTime> {
        let mut times: Vec<&NodeTime> = self.node_times.values().collect();
        times.sort_by(|a, b| b.avg_ms.cmp(&a.avg_ms));
        times.into_iter().take(limit).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::WorkflowBuilder;

    #[test]
    fn test_time_predictor_new() {
        let predictor = TimePredictor::new();
        assert!(predictor.historical_data.node_type_averages.is_empty());
    }

    #[test]
    fn test_predict_simple_workflow() {
        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .llm(
                "Generate",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-3.5-turbo".to_string(),
                    system_prompt: None,
                    prompt_template: "Hello".to_string(),
                    temperature: None,
                    max_tokens: Some(100),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let predictor = TimePredictor::new();
        let estimate = predictor.predict(&workflow);

        assert!(estimate.avg_duration_ms > 0);
        assert!(estimate.min_duration_ms < estimate.avg_duration_ms);
        assert!(estimate.avg_duration_ms < estimate.max_duration_ms);
        assert!(estimate.confidence > 0.0 && estimate.confidence <= 1.0);
    }

    #[test]
    fn test_predict_with_vector_search() {
        let workflow = WorkflowBuilder::new("RAG")
            .start("Start")
            .retriever(
                "Search",
                VectorConfig {
                    db_type: "qdrant".to_string(),
                    collection: "docs".to_string(),
                    query: "test".to_string(),
                    top_k: 5,
                    score_threshold: Some(0.7),
                },
            )
            .end("End")
            .build();

        let predictor = TimePredictor::new();
        let estimate = predictor.predict(&workflow);

        assert!(estimate.avg_duration_ms > 0);
        // Start + Search + End
        assert_eq!(estimate.node_times.len(), 3);
    }

    #[test]
    fn test_estimate_format_summary() {
        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .end("End")
            .build();

        let predictor = TimePredictor::new();
        let estimate = predictor.predict(&workflow);
        let summary = estimate.format_summary();

        assert!(summary.contains("Estimated Time:"));
        assert!(summary.contains("Critical Path:"));
        assert!(summary.contains("Confidence:"));
    }

    #[test]
    fn test_slowest_nodes() {
        let workflow = WorkflowBuilder::new("Multi-LLM")
            .start("Start")
            .llm(
                "GPT4",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-4".to_string(),
                    system_prompt: None,
                    prompt_template: "test".to_string(),
                    temperature: None,
                    max_tokens: Some(2000),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .llm(
                "GPT3.5",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-3.5-turbo".to_string(),
                    system_prompt: None,
                    prompt_template: "test".to_string(),
                    temperature: None,
                    max_tokens: Some(100),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let predictor = TimePredictor::new();
        let estimate = predictor.predict(&workflow);
        let slowest = estimate.slowest_nodes(1);

        assert_eq!(slowest.len(), 1);
        // GPT-4 with the larger max_tokens budget should dominate.
        assert_eq!(slowest[0].node_name, "GPT4");
    }

    #[test]
    fn test_node_with_retry_prediction() {
        let llm_config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: None,
            max_tokens: Some(100),
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let node = Node::new("LLM".to_string(), NodeKind::LLM(llm_config)).with_retry(
            crate::RetryConfig {
                max_retries: 3,
                initial_delay_ms: 1000,
                backoff_multiplier: 2.0,
                max_delay_ms: 30000,
            },
        );

        let predictor = TimePredictor::new();
        let time = predictor.predict_node_time(&node);

        // ceil(3 * 0.3) = 1 extra expected execution.
        assert!(time.expected_executions > 1);
    }

    #[test]
    fn test_historical_data() {
        let mut historical_data = HistoricalData::new();
        historical_data
            .node_type_averages
            .insert("LLM".to_string(), 5000);
        historical_data
            .provider_latencies
            .insert("openai".to_string(), 4000);

        assert_eq!(historical_data.get_node_type_average("LLM"), Some(5000));
        assert_eq!(historical_data.get_provider_latency("openai"), Some(4000));
        assert_eq!(historical_data.get_node_type_average("Code"), None);
    }

    #[test]
    fn test_predictor_with_historical_data() {
        let mut historical_data = HistoricalData::new();
        historical_data
            .provider_latencies
            .insert("openai".to_string(), 2000);

        let predictor = TimePredictor::with_historical_data(historical_data);

        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .llm(
                "GPT",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-4".to_string(),
                    system_prompt: None,
                    prompt_template: "test".to_string(),
                    temperature: None,
                    max_tokens: Some(100),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let estimate = predictor.predict(&workflow);
        assert!(estimate.avg_duration_ms > 0);
    }
}