Skip to main content

agentic_workflow/intelligence/
prediction.rs

1use std::collections::HashMap;
2
3use crate::types::{
4    CostPrediction, DurationPrediction, ExecutionFingerprint, ResourcePrediction,
5    RiskFactor, StepDurationPrediction, StepLifecycle, SuccessPrediction,
6    WorkflowResult,
7};
8
9/// Predictive execution engine — estimates from historical data.
10pub struct PredictionEngine {
11    fingerprints: Vec<ExecutionFingerprint>,
12}
13
14impl PredictionEngine {
15    pub fn new() -> Self {
16        Self {
17            fingerprints: Vec::new(),
18        }
19    }
20
21    /// Add historical execution data.
22    pub fn ingest_fingerprint(&mut self, fp: ExecutionFingerprint) {
23        self.fingerprints.push(fp);
24    }
25
26    /// Predict execution duration.
27    pub fn predict_duration(&self, workflow_id: &str) -> WorkflowResult<DurationPrediction> {
28        let fps: Vec<&ExecutionFingerprint> = self
29            .fingerprints
30            .iter()
31            .filter(|f| f.workflow_id == workflow_id)
32            .collect();
33
34        if fps.is_empty() {
35            return Ok(DurationPrediction {
36                workflow_id: workflow_id.to_string(),
37                predicted_ms: 0,
38                confidence: 0.0,
39                min_ms: 0,
40                max_ms: 0,
41                based_on_executions: 0,
42                step_predictions: Vec::new(),
43            });
44        }
45
46        let durations: Vec<u64> = fps.iter().map(|f| f.total_duration_ms).collect();
47        let avg = durations.iter().sum::<u64>() / durations.len() as u64;
48        let min = *durations.iter().min().unwrap();
49        let max = *durations.iter().max().unwrap();
50
51        let confidence = (fps.len() as f64 / 10.0).min(1.0);
52
53        // Step-level predictions
54        let mut step_totals: HashMap<&str, (u64, usize)> = HashMap::new();
55        for fp in &fps {
56            for (sid, dur) in &fp.step_durations {
57                let e = step_totals.entry(sid.as_str()).or_insert((0, 0));
58                e.0 += dur;
59                e.1 += 1;
60            }
61        }
62
63        let step_predictions: Vec<StepDurationPrediction> = step_totals
64            .into_iter()
65            .map(|(sid, (total, count))| StepDurationPrediction {
66                step_id: sid.to_string(),
67                predicted_ms: total / count as u64,
68                confidence,
69            })
70            .collect();
71
72        Ok(DurationPrediction {
73            workflow_id: workflow_id.to_string(),
74            predicted_ms: avg,
75            confidence,
76            min_ms: min,
77            max_ms: max,
78            based_on_executions: fps.len(),
79            step_predictions,
80        })
81    }
82
83    /// Predict success probability.
84    pub fn predict_success(&self, workflow_id: &str) -> WorkflowResult<SuccessPrediction> {
85        let fps: Vec<&ExecutionFingerprint> = self
86            .fingerprints
87            .iter()
88            .filter(|f| f.workflow_id == workflow_id)
89            .collect();
90
91        if fps.is_empty() {
92            return Ok(SuccessPrediction {
93                workflow_id: workflow_id.to_string(),
94                success_probability: 0.5,
95                risk_factors: Vec::new(),
96                based_on_executions: 0,
97            });
98        }
99
100        // Count step failure rates
101        let mut step_failures: HashMap<&str, (usize, usize)> = HashMap::new();
102        for fp in &fps {
103            for (sid, outcome) in &fp.step_outcomes {
104                let e = step_failures.entry(sid.as_str()).or_insert((0, 0));
105                e.1 += 1;
106                if *outcome == StepLifecycle::Failed {
107                    e.0 += 1;
108                }
109            }
110        }
111
112        let risk_factors: Vec<RiskFactor> = step_failures
113            .into_iter()
114            .filter(|(_, (fails, total))| *fails > 0 && *total > 0)
115            .map(|(sid, (fails, total))| RiskFactor {
116                step_id: sid.to_string(),
117                risk: fails as f64 / total as f64,
118                reason: format!("{}/{} executions failed", fails, total),
119            })
120            .collect();
121
122        let total_success = fps
123            .iter()
124            .filter(|f| f.retry_count == 0 && f.step_outcomes.values().all(|o| *o == StepLifecycle::Success))
125            .count();
126
127        let probability = total_success as f64 / fps.len() as f64;
128
129        Ok(SuccessPrediction {
130            workflow_id: workflow_id.to_string(),
131            success_probability: probability,
132            risk_factors,
133            based_on_executions: fps.len(),
134        })
135    }
136
137    /// Predict resource consumption.
138    pub fn predict_resources(&self, workflow_id: &str) -> WorkflowResult<ResourcePrediction> {
139        let fps: Vec<&ExecutionFingerprint> = self
140            .fingerprints
141            .iter()
142            .filter(|f| f.workflow_id == workflow_id)
143            .collect();
144
145        let avg_steps = if !fps.is_empty() {
146            fps.iter().map(|f| f.step_durations.len() as u64).sum::<u64>() / fps.len() as u64
147        } else {
148            0
149        };
150
151        Ok(ResourcePrediction {
152            workflow_id: workflow_id.to_string(),
153            estimated_api_calls: avg_steps,
154            estimated_compute_seconds: avg_steps as f64 * 0.5,
155            estimated_storage_bytes: avg_steps * 1024,
156        })
157    }
158
159    /// Predict monetary cost.
160    pub fn predict_cost(&self, workflow_id: &str) -> WorkflowResult<CostPrediction> {
161        let resources = self.predict_resources(workflow_id)?;
162
163        Ok(CostPrediction {
164            workflow_id: workflow_id.to_string(),
165            estimated_cost_usd: resources.estimated_api_calls as f64 * 0.001,
166            breakdown: vec![crate::types::prediction::CostBreakdown {
167                component: "API calls".to_string(),
168                cost_usd: resources.estimated_api_calls as f64 * 0.001,
169                quantity: resources.estimated_api_calls as f64,
170                unit: "calls".to_string(),
171            }],
172            confidence: if self.fingerprints.is_empty() { 0.0 } else { 0.7 },
173        })
174    }
175}
176
177impl Default for PredictionEngine {
178    fn default() -> Self {
179        Self::new()
180    }
181}