oxify_model/prediction.rs

//! Execution time prediction for workflows
//!
//! This module provides predictions for workflow execution times based on
//! historical data, node complexity, and expected behavior.
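//!
//! # Example
//!
//! A minimal sketch of predicting a workflow's runtime. It mirrors the builder
//! calls used in the tests at the bottom of this file and assumes the crate
//! re-exports these types at the root as `oxify_model`; the output is
//! illustrative only.
//!
//! ```ignore
//! use oxify_model::{TimePredictor, WorkflowBuilder};
//!
//! let workflow = WorkflowBuilder::new("Demo")
//!     .start("Start")
//!     .end("End")
//!     .build();
//!
//! let estimate = TimePredictor::new().predict(&workflow);
//! println!("{}", estimate.format_summary());
//! ```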

use crate::{ExecutionStats, LlmConfig, Node, NodeKind, VectorConfig, Workflow};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;

/// Execution time prediction for a workflow
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeEstimate {
    /// Estimated minimum execution time
    pub min_duration_ms: u64,

    /// Estimated average execution time
    pub avg_duration_ms: u64,

    /// Estimated maximum execution time (including retries, worst case)
    pub max_duration_ms: u64,

    /// Critical path through the workflow
    pub critical_path: Vec<String>,

    /// Time breakdown by node
    pub node_times: HashMap<String, NodeTime>,

    /// Confidence level (0.0 to 1.0)
    pub confidence: f64,
}

/// Time estimate for a single node
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTime {
    /// Node name
    pub node_name: String,

    /// Node type
    pub node_type: String,

    /// Minimum execution time (ms)
    pub min_ms: u64,

    /// Average execution time (ms)
    pub avg_ms: u64,

    /// Maximum execution time (ms)
    pub max_ms: u64,

    /// Number of expected executions
    pub expected_executions: u32,

    /// Whether this node is on the critical path
    pub is_critical: bool,
}

/// Historical execution data for improving predictions
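///
/// # Example
///
/// A small sketch of seeding historical data by hand; the key and latency shown
/// here are illustrative, and in practice they would come from prior runs:
///
/// ```ignore
/// use oxify_model::{HistoricalData, TimePredictor};
///
/// let mut history = HistoricalData::new();
/// history.provider_latencies.insert("openai".to_string(), 2000);
///
/// // Predictions for "openai" LLM nodes now start from the observed 2000 ms.
/// let predictor = TimePredictor::with_historical_data(history);
/// ```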
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistoricalData {
    /// Node type -> average execution time mapping
    pub node_type_averages: HashMap<String, u64>,

    /// Provider -> average API latency mapping
    pub provider_latencies: HashMap<String, u64>,

    /// Specific node ID -> execution times
    pub node_execution_history: HashMap<String, Vec<u64>>,
}

impl HistoricalData {
    /// Create empty historical data
    pub fn new() -> Self {
        Self {
            node_type_averages: HashMap::new(),
            provider_latencies: HashMap::new(),
            node_execution_history: HashMap::new(),
        }
    }

    /// Update historical data from execution stats
    pub fn update_from_stats(&mut self, _stats: &ExecutionStats) {
        // Placeholder: a full implementation would extract per-provider latencies
        // and per-node timings from `stats` and fold them into the maps above.
    }

    /// Get average time for a node type
    pub fn get_node_type_average(&self, node_type: &str) -> Option<u64> {
        self.node_type_averages.get(node_type).copied()
    }

    /// Get average latency for a provider
    pub fn get_provider_latency(&self, provider: &str) -> Option<u64> {
        self.provider_latencies.get(provider).copied()
    }
}

impl Default for HistoricalData {
    fn default() -> Self {
        Self::new()
    }
}

/// Time prediction engine
pub struct TimePredictor {
    /// Historical execution data
    historical_data: HistoricalData,
}

impl TimePredictor {
    /// Create a new time predictor
    pub fn new() -> Self {
        Self {
            historical_data: HistoricalData::new(),
        }
    }

    /// Create predictor with historical data
    pub fn with_historical_data(historical_data: HistoricalData) -> Self {
        Self { historical_data }
    }

    /// Predict execution time for a workflow
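    ///
    /// The estimate sums the per-node predictions and currently treats the
    /// workflow as a linear chain, so branching or parallel workflows will be
    /// over-estimated (see the critical-path note in the body below).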
    pub fn predict(&self, workflow: &Workflow) -> TimeEstimate {
        let mut node_times = HashMap::new();
        let mut total_min = 0u64;
        let mut total_avg = 0u64;
        let mut total_max = 0u64;

        // Predict time for each node
        for node in &workflow.nodes {
            let node_time = self.predict_node_time(node);
            total_min += node_time.min_ms;
            total_avg += node_time.avg_ms;
            total_max += node_time.max_ms;
            node_times.insert(node.id.to_string(), node_time);
        }

        // Find critical path (simplified - assumes linear for now)
        let critical_path = workflow
            .nodes
            .iter()
            .map(|n| n.name.clone())
            .collect::<Vec<_>>();

        // Calculate confidence based on available historical data
        let confidence = self.calculate_confidence(workflow);

        TimeEstimate {
            min_duration_ms: total_min,
            avg_duration_ms: total_avg,
            max_duration_ms: total_max,
            critical_path,
            node_times,
            confidence,
        }
    }

    /// Predict time for a single node
    fn predict_node_time(&self, node: &Node) -> NodeTime {
        let (min_ms, avg_ms, max_ms) = match &node.kind {
            NodeKind::Start | NodeKind::End => (1, 5, 10),

            NodeKind::LLM(config) => self.predict_llm_time(config, node),

            NodeKind::Retriever(config) => self.predict_vector_time(config),

            NodeKind::Code(_) => {
                // Code execution time varies widely
                (100, 500, 5000)
            }

            NodeKind::Tool(_) => {
                // API call time
                (200, 1000, 5000)
            }

            NodeKind::IfElse(_) => {
                // Condition evaluation is fast
                (1, 10, 50)
            }

            NodeKind::Switch(_) => {
                // Switch evaluation
                (1, 10, 50)
            }

            NodeKind::Loop(_) => {
                // Loop overhead (not including body execution)
                (10, 50, 200)
            }

            NodeKind::TryCatch(_) => {
                // Try-catch overhead
                (5, 20, 100)
            }

            NodeKind::SubWorkflow(_) => {
                // Sub-workflow execution (depends on sub-workflow)
                (100, 5000, 30000)
            }

            NodeKind::Parallel(_) => {
                // Parallel execution overhead
                (50, 200, 1000)
            }

            NodeKind::Approval(_) => {
                // Human approval can take very long
                (1000, 60000, 3600000) // 1s to 1 hour
            }

            NodeKind::Form(_) => {
                // Form submission time
                (5000, 120000, 600000) // 5s to 10 minutes
            }

            NodeKind::Vision(_) => {
                // Vision/OCR processing time
                (500, 3000, 15000) // 0.5s to 15s depending on image size
            }
        };

        let expected_executions = Self::estimate_executions(node);

        NodeTime {
            node_name: node.name.clone(),
            node_type: self.get_node_type_string(&node.kind),
            min_ms: min_ms * expected_executions as u64,
            avg_ms: avg_ms * expected_executions as u64,
            max_ms: max_ms * expected_executions as u64,
            expected_executions,
            is_critical: false, // Would be set by critical path analysis
        }
    }

    /// Predict LLM execution time
    fn predict_llm_time(&self, config: &LlmConfig, _node: &Node) -> (u64, u64, u64) {
        // Check historical data first
        if let Some(avg) = self.historical_data.get_provider_latency(&config.provider) {
            return (avg / 2, avg, avg * 2);
        }

        // Estimate based on provider and model
        let base_latency = match config.provider.to_lowercase().as_str() {
            "openai" => {
                if config.model.contains("gpt-4") {
                    (3000, 8000, 20000) // GPT-4 is slower
                } else {
                    (1000, 3000, 10000) // GPT-3.5 is faster
                }
            }
            "anthropic" => {
                if config.model.contains("opus") {
                    (2000, 6000, 15000)
                } else if config.model.contains("sonnet") {
                    (1000, 4000, 12000)
                } else {
                    (500, 2000, 8000) // Haiku is fastest
                }
            }
            "ollama" | "local" => {
                // Local models depend on hardware
                (500, 2000, 10000)
            }
            _ => (2000, 5000, 15000), // Default estimate
        };

        // Adjust for token count
        let max_tokens = config.max_tokens.unwrap_or(1000);
        let token_multiplier = (max_tokens as f64 / 1000.0).max(0.5);
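        // Latency scales roughly linearly with the requested completion size,
        // clamped below at 0.5x: e.g. max_tokens = 100 -> 0.1, clamped to 0.5;
        // max_tokens = 2000 -> 2.0.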

        (
            (base_latency.0 as f64 * token_multiplier) as u64,
            (base_latency.1 as f64 * token_multiplier) as u64,
            (base_latency.2 as f64 * token_multiplier) as u64,
        )
    }

    /// Predict vector search time
    fn predict_vector_time(&self, config: &VectorConfig) -> (u64, u64, u64) {
        match config.db_type.to_lowercase().as_str() {
            "qdrant" => {
                // Qdrant is very fast
                let base = 50 + (config.top_k * 5) as u64;
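                // e.g. top_k = 5 -> base = 75 ms average (37 ms min, 225 ms max).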
                (base / 2, base, base * 3)
            }
            "pgvector" => {
                // pgvector can be slower depending on index
                let base = 100 + (config.top_k * 10) as u64;
                (base / 2, base, base * 5)
            }
            _ => {
                let base = 100 + (config.top_k * 10) as u64;
                (base / 2, base, base * 3)
            }
        }
    }

    /// Estimate number of executions (considering retries)
    fn estimate_executions(node: &Node) -> u32 {
        let mut executions = 1u32;

        if let Some(retry_config) = &node.retry_config {
            // Assume 30% failure rate requiring retries
            let avg_retries = (retry_config.max_retries as f32 * 0.3).ceil() as u32;
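            // e.g. max_retries = 3 -> ceil(0.9) = 1 extra expected execution.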
            executions += avg_retries;
        }

        executions
    }

    /// Calculate confidence level based on available data
    fn calculate_confidence(&self, workflow: &Workflow) -> f64 {
        if workflow.nodes.is_empty() {
            return 0.0;
        }

        let mut total_confidence = 0.0;

        for node in &workflow.nodes {
            let node_confidence = match &node.kind {
                // High confidence for simple nodes
                NodeKind::Start | NodeKind::End | NodeKind::IfElse(_) | NodeKind::Switch(_) => 0.9,

                // Medium confidence for LLM/Vector (depends on historical data)
                NodeKind::LLM(_) | NodeKind::Retriever(_) => {
                    if self
                        .historical_data
                        .node_execution_history
                        .contains_key(&node.id.to_string())
                    {
                        0.8 // Higher confidence with historical data
                    } else {
                        0.5 // Lower confidence without data
                    }
                }

                // Lower confidence for variable-time operations
                NodeKind::Code(_) | NodeKind::Tool(_) | NodeKind::SubWorkflow(_) => 0.4,

                // Very low confidence for human-in-the-loop
                NodeKind::Approval(_) | NodeKind::Form(_) => 0.2,

                // Medium confidence for control flow
                NodeKind::Loop(_) | NodeKind::TryCatch(_) | NodeKind::Parallel(_) => 0.5,

                // Medium confidence for vision/OCR
                NodeKind::Vision(_) => 0.6,
            };

            total_confidence += node_confidence;
        }

        (total_confidence / workflow.nodes.len() as f64).min(1.0)
    }

    /// Get node type as string
    fn get_node_type_string(&self, kind: &NodeKind) -> String {
        match kind {
            NodeKind::Start => "Start".to_string(),
            NodeKind::End => "End".to_string(),
            NodeKind::LLM(_) => "LLM".to_string(),
            NodeKind::Retriever(_) => "Retriever".to_string(),
            NodeKind::Code(_) => "Code".to_string(),
            NodeKind::IfElse(_) => "IfElse".to_string(),
            NodeKind::Tool(_) => "Tool".to_string(),
            NodeKind::Loop(_) => "Loop".to_string(),
            NodeKind::TryCatch(_) => "TryCatch".to_string(),
            NodeKind::SubWorkflow(_) => "SubWorkflow".to_string(),
            NodeKind::Switch(_) => "Switch".to_string(),
            NodeKind::Parallel(_) => "Parallel".to_string(),
            NodeKind::Approval(_) => "Approval".to_string(),
            NodeKind::Form(_) => "Form".to_string(),
            NodeKind::Vision(_) => "Vision".to_string(),
        }
    }
}

impl Default for TimePredictor {
    fn default() -> Self {
        Self::new()
    }
}

impl TimeEstimate {
    /// Format as human-readable string
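    ///
    /// The summary is three lines; the values below are illustrative only:
    ///
    /// ```text
    /// Estimated Time: 502ms - 5.02s (avg: 1.51s)
    /// Critical Path: Start → Generate → End
    /// Confidence: 77%
    /// ```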
    pub fn format_summary(&self) -> String {
        let min_duration = Duration::from_millis(self.min_duration_ms);
        let avg_duration = Duration::from_millis(self.avg_duration_ms);
        let max_duration = Duration::from_millis(self.max_duration_ms);

        format!(
            "Estimated Time: {:?} - {:?} (avg: {:?})\n\
             Critical Path: {}\n\
             Confidence: {:.0}%",
            min_duration,
            max_duration,
            avg_duration,
            self.critical_path.join(" → "),
            self.confidence * 100.0
        )
    }

    /// Get nodes on the critical path
    pub fn critical_path_nodes(&self) -> Vec<&NodeTime> {
        self.node_times
            .values()
            .filter(|nt| nt.is_critical)
            .collect()
    }

    /// Get slowest nodes
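    ///
    /// Returns at most `limit` entries, ordered by descending average time.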
    pub fn slowest_nodes(&self, limit: usize) -> Vec<&NodeTime> {
        let mut times: Vec<&NodeTime> = self.node_times.values().collect();
        times.sort_by(|a, b| b.avg_ms.cmp(&a.avg_ms));
        times.into_iter().take(limit).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::WorkflowBuilder;

    #[test]
    fn test_time_predictor_new() {
        let predictor = TimePredictor::new();
        assert!(predictor.historical_data.node_type_averages.is_empty());
    }

    #[test]
    fn test_predict_simple_workflow() {
        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .llm(
                "Generate",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-3.5-turbo".to_string(),
                    system_prompt: None,
                    prompt_template: "Hello".to_string(),
                    temperature: None,
                    max_tokens: Some(100),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let predictor = TimePredictor::new();
        let estimate = predictor.predict(&workflow);

        assert!(estimate.avg_duration_ms > 0);
        assert!(estimate.min_duration_ms < estimate.avg_duration_ms);
        assert!(estimate.avg_duration_ms < estimate.max_duration_ms);
        assert!(estimate.confidence > 0.0 && estimate.confidence <= 1.0);
    }

    #[test]
    fn test_predict_with_vector_search() {
        let workflow = WorkflowBuilder::new("RAG")
            .start("Start")
            .retriever(
                "Search",
                VectorConfig {
                    db_type: "qdrant".to_string(),
                    collection: "docs".to_string(),
                    query: "test".to_string(),
                    top_k: 5,
                    score_threshold: Some(0.7),
                },
            )
            .end("End")
            .build();

        let predictor = TimePredictor::new();
        let estimate = predictor.predict(&workflow);

        assert!(estimate.avg_duration_ms > 0);
        assert_eq!(estimate.node_times.len(), 3); // Start, Retriever, End
    }

    #[test]
    fn test_estimate_format_summary() {
        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .end("End")
            .build();

        let predictor = TimePredictor::new();
        let estimate = predictor.predict(&workflow);
        let summary = estimate.format_summary();

        assert!(summary.contains("Estimated Time:"));
        assert!(summary.contains("Critical Path:"));
        assert!(summary.contains("Confidence:"));
    }

    #[test]
    fn test_slowest_nodes() {
        let workflow = WorkflowBuilder::new("Multi-LLM")
            .start("Start")
            .llm(
                "GPT4",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-4".to_string(),
                    system_prompt: None,
                    prompt_template: "test".to_string(),
                    temperature: None,
                    max_tokens: Some(2000),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .llm(
                "GPT3.5",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-3.5-turbo".to_string(),
                    system_prompt: None,
                    prompt_template: "test".to_string(),
                    temperature: None,
                    max_tokens: Some(100),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let predictor = TimePredictor::new();
        let estimate = predictor.predict(&workflow);
        let slowest = estimate.slowest_nodes(1);

        assert_eq!(slowest.len(), 1);
        // GPT-4 with more tokens should be slowest
        assert_eq!(slowest[0].node_name, "GPT4");
    }

    #[test]
    fn test_node_with_retry_prediction() {
        let llm_config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: None,
            max_tokens: Some(100),
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let node = Node::new("LLM".to_string(), NodeKind::LLM(llm_config)).with_retry(
            crate::RetryConfig {
                max_retries: 3,
                initial_delay_ms: 1000,
                backoff_multiplier: 2.0,
                max_delay_ms: 30000,
            },
        );

        let predictor = TimePredictor::new();
        let time = predictor.predict_node_time(&node);

        // Should have higher expected executions due to retries
        assert!(time.expected_executions > 1);
    }

    #[test]
    fn test_historical_data() {
        let mut historical_data = HistoricalData::new();
        historical_data
            .node_type_averages
            .insert("LLM".to_string(), 5000);
        historical_data
            .provider_latencies
            .insert("openai".to_string(), 4000);

        assert_eq!(historical_data.get_node_type_average("LLM"), Some(5000));
        assert_eq!(historical_data.get_provider_latency("openai"), Some(4000));
        assert_eq!(historical_data.get_node_type_average("Code"), None);
    }

    #[test]
    fn test_predictor_with_historical_data() {
        let mut historical_data = HistoricalData::new();
        historical_data
            .provider_latencies
            .insert("openai".to_string(), 2000);

        let predictor = TimePredictor::with_historical_data(historical_data);

        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .llm(
                "GPT",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-4".to_string(),
                    system_prompt: None,
                    prompt_template: "test".to_string(),
                    temperature: None,
                    max_tokens: Some(100),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let estimate = predictor.predict(&workflow);
        assert!(estimate.avg_duration_ms > 0);
    }
}