mockforge_chaos/
ml_assertion_generator.rs

//! ML-based assertion generation from historical data
//!
//! Analyzes historical orchestration execution data to automatically generate
//! assertions from patterns observed in it: per-step duration and metric
//! thresholds, success rates, and error rates.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Historical execution data point
///
/// One recorded execution of a single step within an orchestration. Samples are
/// grouped by (`orchestration_id`, `step_id`) when assertions are generated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionDataPoint {
    /// Timestamp of the execution (UTC).
    pub timestamp: DateTime<Utc>,
    /// Orchestration this execution belongs to.
    pub orchestration_id: String,
    /// Step within the orchestration.
    pub step_id: String,
    /// Named numeric metrics recorded for this execution.
    pub metrics: HashMap<String, f64>,
    /// Whether the execution succeeded; drives success-rate and error-rate statistics.
    pub success: bool,
    /// Execution duration in milliseconds.
    pub duration_ms: u64,
    /// Optional error message (presumably set when `success` is false — not
    /// enforced by this type).
    pub error_message: Option<String>,
}
21
/// Statistical summary of metric
///
/// Built from a set of `f64` samples. For an empty sample set, every field is
/// zero.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricStats {
    /// Arithmetic mean of the samples.
    pub mean: f64,
    /// Middle sample after sorting; for an even count this is the upper of the
    /// two middle samples (not their average).
    pub median: f64,
    /// Population standard deviation (variance divided by the sample count).
    pub std_dev: f64,
    /// Smallest sample.
    pub min: f64,
    /// Largest sample.
    pub max: f64,
    /// 95th percentile of the samples.
    pub p95: f64,
    /// 99th percentile of the samples.
    pub p99: f64,
    /// Number of samples summarized.
    pub sample_count: usize,
}
34
/// Generated assertion
///
/// A single check of the form `<value at path> <operator> value`, derived from
/// historical data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedAssertion {
    /// Identifier built from the assertion kind, orchestration id and step id
    /// (plus the metric name for metric assertions).
    pub id: String,
    /// Kind of quantity being asserted on.
    pub assertion_type: AssertionType,
    /// Dotted path of the checked value, e.g. `<orchestration>.<step>.duration`.
    pub path: String,
    /// Comparison applied between the observed value and `value`.
    pub operator: AssertionOperator,
    /// Threshold the observed value is compared against.
    ///
    /// NOTE(review): a single `f64` cannot carry both bounds needed by
    /// `InRange` / `NotInRange`; confirm how consumers interpret it for those.
    pub value: f64,
    /// Confidence in the assertion; the generator produces values in `[0.0, 1.0]`.
    pub confidence: f64,
    /// Human-readable explanation of how the threshold was derived.
    pub rationale: String,
    /// Number of historical samples the assertion was computed from.
    pub based_on_samples: usize,
    /// Time at which the assertion was generated.
    pub created_at: DateTime<Utc>,
}
48
/// Type of assertion
///
/// Serialized in `snake_case` (e.g. `MetricThreshold` -> `metric_threshold`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AssertionType {
    /// Threshold on a named step metric.
    MetricThreshold,
    /// Assertion on a step's success rate.
    SuccessRate,
    /// Assertion on a step's execution duration.
    Duration,
    /// Assertion on a step's error rate.
    ErrorRate,
    /// User-defined assertion; not produced by `AssertionGenerator` in this module.
    Custom,
}
59
/// Assertion operator
///
/// Comparison between an observed value and an assertion's threshold.
/// Serialized in `snake_case` (e.g. `LessThanOrEqual` -> `less_than_or_equal`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AssertionOperator {
    /// Observed value must be strictly below the threshold.
    LessThan,
    /// Observed value must be at or below the threshold.
    LessThanOrEqual,
    /// Observed value must be strictly above the threshold.
    GreaterThan,
    /// Observed value must be at or above the threshold.
    GreaterThanOrEqual,
    /// Observed value must lie inside a range (not produced by `AssertionGenerator`
    /// in this module; range representation is not defined here — TODO confirm).
    InRange,
    /// Observed value must lie outside a range (not produced by `AssertionGenerator`
    /// in this module; range representation is not defined here — TODO confirm).
    NotInRange,
}
71
/// Assertion generation configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssertionGeneratorConfig {
    /// Minimum number of samples required, both overall and per
    /// orchestration step / metric, before assertions are generated for it
    pub min_samples: usize,
    /// Confidence threshold (0.0 - 1.0) used to discard low-confidence assertions
    pub min_confidence: f64,
    /// Number of standard deviations above the mean for std-dev based thresholds
    pub std_dev_multiplier: f64,
    /// Use percentiles for threshold calculation
    pub use_percentiles: bool,
    /// Percentile to use for upper bounds, on a 0-100 scale (e.g. `95.0`)
    pub upper_percentile: f64,
    /// Percentile to use for lower bounds, on a 0-100 scale (e.g. `5.0`)
    ///
    /// NOTE(review): `AssertionGenerator` does not currently read this field;
    /// no lower-bound assertions are generated.
    pub lower_percentile: f64,
}
88
89impl Default for AssertionGeneratorConfig {
90    fn default() -> Self {
91        Self {
92            min_samples: 10,
93            min_confidence: 0.7,
94            std_dev_multiplier: 2.0,
95            use_percentiles: true,
96            upper_percentile: 95.0,
97            lower_percentile: 5.0,
98        }
99    }
100}
101
102/// ML-based assertion generator
103pub struct AssertionGenerator {
104    config: AssertionGeneratorConfig,
105    historical_data: Vec<ExecutionDataPoint>,
106}
107
108impl AssertionGenerator {
109    /// Create a new assertion generator
110    pub fn new(config: AssertionGeneratorConfig) -> Self {
111        Self {
112            config,
113            historical_data: Vec::new(),
114        }
115    }
116
117    /// Add historical data
118    pub fn add_data(&mut self, data: ExecutionDataPoint) {
119        self.historical_data.push(data);
120    }
121
122    /// Add multiple data points
123    pub fn add_bulk_data(&mut self, data: Vec<ExecutionDataPoint>) {
124        self.historical_data.extend(data);
125    }
126
127    /// Generate assertions based on historical data
128    pub fn generate_assertions(&self) -> Result<Vec<GeneratedAssertion>, String> {
129        if self.historical_data.len() < self.config.min_samples {
130            return Err(format!(
131                "Insufficient data: need at least {} samples, have {}",
132                self.config.min_samples,
133                self.historical_data.len()
134            ));
135        }
136
137        let mut assertions = Vec::new();
138
139        // Group data by orchestration and step
140        let grouped_data = self.group_data_by_step();
141
142        for ((orch_id, step_id), data_points) in grouped_data {
143            if data_points.len() < self.config.min_samples {
144                continue;
145            }
146
147            // Generate duration assertions
148            assertions.extend(self.generate_duration_assertions(
149                &orch_id,
150                &step_id,
151                &data_points,
152            )?);
153
154            // Generate success rate assertions
155            assertions.extend(self.generate_success_rate_assertions(
156                &orch_id,
157                &step_id,
158                &data_points,
159            )?);
160
161            // Generate metric assertions
162            assertions.extend(self.generate_metric_assertions(&orch_id, &step_id, &data_points)?);
163
164            // Generate error rate assertions
165            assertions.extend(self.generate_error_rate_assertions(
166                &orch_id,
167                &step_id,
168                &data_points,
169            )?);
170        }
171
172        Ok(assertions)
173    }
174
175    /// Group data by step
176    fn group_data_by_step(&self) -> HashMap<(String, String), Vec<ExecutionDataPoint>> {
177        let mut grouped: HashMap<(String, String), Vec<ExecutionDataPoint>> = HashMap::new();
178
179        for data_point in &self.historical_data {
180            let key = (data_point.orchestration_id.clone(), data_point.step_id.clone());
181            grouped.entry(key).or_default().push(data_point.clone());
182        }
183
184        grouped
185    }
186
187    /// Generate duration assertions
188    fn generate_duration_assertions(
189        &self,
190        orch_id: &str,
191        step_id: &str,
192        data: &[ExecutionDataPoint],
193    ) -> Result<Vec<GeneratedAssertion>, String> {
194        let durations: Vec<f64> = data.iter().map(|d| d.duration_ms as f64).collect();
195        let stats = Self::calculate_stats(&durations);
196
197        let mut assertions = Vec::new();
198
199        // Generate P95 duration assertion
200        if self.config.use_percentiles {
201            let threshold = stats.p95;
202            let confidence = self.calculate_confidence(&durations, threshold);
203
204            if confidence >= self.config.min_confidence {
205                assertions.push(GeneratedAssertion {
206                    id: format!("duration_{}_{}", orch_id, step_id),
207                    assertion_type: AssertionType::Duration,
208                    path: format!("{}.{}.duration", orch_id, step_id),
209                    operator: AssertionOperator::LessThanOrEqual,
210                    value: threshold,
211                    confidence,
212                    rationale: format!(
213                        "Based on P95 of historical data: {:.2}ms (mean: {:.2}ms, std: {:.2}ms)",
214                        threshold, stats.mean, stats.std_dev
215                    ),
216                    based_on_samples: data.len(),
217                    created_at: Utc::now(),
218                });
219            }
220        }
221
222        Ok(assertions)
223    }
224
225    /// Generate success rate assertions
226    fn generate_success_rate_assertions(
227        &self,
228        orch_id: &str,
229        step_id: &str,
230        data: &[ExecutionDataPoint],
231    ) -> Result<Vec<GeneratedAssertion>, String> {
232        let success_count = data.iter().filter(|d| d.success).count();
233        let total_count = data.len();
234        let success_rate = success_count as f64 / total_count as f64;
235
236        let mut assertions = Vec::new();
237
238        // Only generate if success rate is consistently high
239        if success_rate >= 0.9 {
240            let confidence = success_rate;
241
242            assertions.push(GeneratedAssertion {
243                id: format!("success_rate_{}_{}", orch_id, step_id),
244                assertion_type: AssertionType::SuccessRate,
245                path: format!("{}.{}.success_rate", orch_id, step_id),
246                operator: AssertionOperator::GreaterThanOrEqual,
247                value: success_rate * 0.95, // Allow 5% deviation
248                confidence,
249                rationale: format!(
250                    "Based on historical success rate: {:.2}% ({}/{} successful executions)",
251                    success_rate * 100.0,
252                    success_count,
253                    total_count
254                ),
255                based_on_samples: total_count,
256                created_at: Utc::now(),
257            });
258        }
259
260        Ok(assertions)
261    }
262
263    /// Generate metric assertions
264    fn generate_metric_assertions(
265        &self,
266        orch_id: &str,
267        step_id: &str,
268        data: &[ExecutionDataPoint],
269    ) -> Result<Vec<GeneratedAssertion>, String> {
270        let mut assertions = Vec::new();
271
272        // Collect all metric names
273        let mut all_metrics: HashMap<String, Vec<f64>> = HashMap::new();
274        for data_point in data {
275            for (metric_name, value) in &data_point.metrics {
276                all_metrics.entry(metric_name.clone()).or_default().push(*value);
277            }
278        }
279
280        // Generate assertions for each metric
281        for (metric_name, values) in all_metrics {
282            if values.len() < self.config.min_samples {
283                continue;
284            }
285
286            let stats = Self::calculate_stats(&values);
287
288            if self.config.use_percentiles {
289                // Upper bound assertion (P95)
290                let upper_threshold = stats.p95;
291                let confidence = self.calculate_confidence(&values, upper_threshold);
292
293                if confidence >= self.config.min_confidence {
294                    assertions.push(GeneratedAssertion {
295                        id: format!("metric_{}_{}_{}_upper", orch_id, step_id, metric_name),
296                        assertion_type: AssertionType::MetricThreshold,
297                        path: format!("{}.{}.metrics.{}", orch_id, step_id, metric_name),
298                        operator: AssertionOperator::LessThanOrEqual,
299                        value: upper_threshold,
300                        confidence,
301                        rationale: format!(
302                            "Metric '{}' typically below {:.2} (P95: {:.2}, mean: {:.2}, std: {:.2})",
303                            metric_name, upper_threshold, stats.p95, stats.mean, stats.std_dev
304                        ),
305                        based_on_samples: values.len(),
306                        created_at: Utc::now(),
307                    });
308                }
309            }
310        }
311
312        Ok(assertions)
313    }
314
315    /// Generate error rate assertions
316    fn generate_error_rate_assertions(
317        &self,
318        orch_id: &str,
319        step_id: &str,
320        data: &[ExecutionDataPoint],
321    ) -> Result<Vec<GeneratedAssertion>, String> {
322        let error_count = data.iter().filter(|d| !d.success).count();
323        let total_count = data.len();
324        let error_rate = error_count as f64 / total_count as f64;
325
326        let mut assertions = Vec::new();
327
328        // Generate assertion if error rate is consistently low
329        if error_rate <= 0.1 {
330            assertions.push(GeneratedAssertion {
331                id: format!("error_rate_{}_{}", orch_id, step_id),
332                assertion_type: AssertionType::ErrorRate,
333                path: format!("{}.{}.error_rate", orch_id, step_id),
334                operator: AssertionOperator::LessThanOrEqual,
335                value: (error_rate * 1.5).min(0.2), // Allow 50% increase, max 20%
336                confidence: 1.0 - error_rate,
337                rationale: format!(
338                    "Based on historical error rate: {:.2}% ({}/{} failures)",
339                    error_rate * 100.0,
340                    error_count,
341                    total_count
342                ),
343                based_on_samples: total_count,
344                created_at: Utc::now(),
345            });
346        }
347
348        Ok(assertions)
349    }
350
351    /// Calculate statistics for a set of values
352    fn calculate_stats(values: &[f64]) -> MetricStats {
353        if values.is_empty() {
354            return MetricStats {
355                mean: 0.0,
356                median: 0.0,
357                std_dev: 0.0,
358                min: 0.0,
359                max: 0.0,
360                p95: 0.0,
361                p99: 0.0,
362                sample_count: 0,
363            };
364        }
365
366        let mut sorted = values.to_vec();
367        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
368
369        let mean = sorted.iter().sum::<f64>() / sorted.len() as f64;
370        let median = sorted[sorted.len() / 2];
371        let min = sorted[0];
372        let max = sorted[sorted.len() - 1];
373
374        let variance = sorted.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / sorted.len() as f64;
375        let std_dev = variance.sqrt();
376
377        let p95_idx = ((sorted.len() as f64) * 0.95) as usize;
378        let p99_idx = ((sorted.len() as f64) * 0.99) as usize;
379        let p95 = sorted[p95_idx.min(sorted.len() - 1)];
380        let p99 = sorted[p99_idx.min(sorted.len() - 1)];
381
382        MetricStats {
383            mean,
384            median,
385            std_dev,
386            min,
387            max,
388            p95,
389            p99,
390            sample_count: sorted.len(),
391        }
392    }
393
394    /// Calculate confidence for a threshold
395    fn calculate_confidence(&self, values: &[f64], threshold: f64) -> f64 {
396        let within_threshold = values.iter().filter(|&&v| v <= threshold).count();
397        within_threshold as f64 / values.len() as f64
398    }
399
400    /// Get data count
401    pub fn data_count(&self) -> usize {
402        self.historical_data.len()
403    }
404
405    /// Clear historical data
406    pub fn clear_data(&mut self) {
407        self.historical_data.clear();
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `count` deterministic data points for `orch_id` / `step_id`.
    ///
    /// Every 10th point (`i % 10 == 0`) is a failure with an error message, so a
    /// count that is a multiple of 10 yields exactly a 90% success rate.
    fn create_step_data(orch_id: &str, step_id: &str, count: usize) -> Vec<ExecutionDataPoint> {
        (0..count)
            .map(|i| {
                let mut metrics = HashMap::new();
                metrics.insert("latency_ms".to_string(), 100.0 + (i % 20) as f64);
                metrics.insert("error_rate".to_string(), 0.01 + (i % 5) as f64 * 0.001);

                ExecutionDataPoint {
                    timestamp: Utc::now(),
                    orchestration_id: orch_id.to_string(),
                    step_id: step_id.to_string(),
                    metrics,
                    success: i % 10 != 0, // 90% success rate
                    duration_ms: 100 + (i % 50) as u64,
                    error_message: if i % 10 == 0 {
                        Some("Test error".to_string())
                    } else {
                        None
                    },
                }
            })
            .collect()
    }

    /// Sample data for the `orch-1` / `step-1` step.
    fn create_sample_data(count: usize) -> Vec<ExecutionDataPoint> {
        create_step_data("orch-1", "step-1", count)
    }

    /// Generator with the default config, preloaded with `data`.
    fn generator_with(data: Vec<ExecutionDataPoint>) -> AssertionGenerator {
        let mut generator = AssertionGenerator::new(AssertionGeneratorConfig::default());
        generator.add_bulk_data(data);
        generator
    }

    #[test]
    fn test_generator_creation() {
        let config = AssertionGeneratorConfig::default();
        let generator = AssertionGenerator::new(config);
        assert_eq!(generator.data_count(), 0);
    }

    #[test]
    fn test_add_data() {
        let config = AssertionGeneratorConfig::default();
        let mut generator = AssertionGenerator::new(config);

        let data = create_sample_data(1);
        generator.add_data(data[0].clone());

        assert_eq!(generator.data_count(), 1);
    }

    #[test]
    fn test_generate_assertions() {
        let assertions = generator_with(create_sample_data(50)).generate_assertions().unwrap();
        assert!(!assertions.is_empty());

        // Should have duration, success rate, metric, and error rate assertions
        for expected in [
            AssertionType::Duration,
            AssertionType::SuccessRate,
            AssertionType::MetricThreshold,
            AssertionType::ErrorRate,
        ] {
            assert!(
                assertions.iter().any(|a| a.assertion_type == expected),
                "missing {:?} assertion",
                expected
            );
        }
    }

    #[test]
    fn test_generated_ids_paths_and_values() {
        let assertions = generator_with(create_sample_data(50)).generate_assertions().unwrap();
        let find = |kind: AssertionType| {
            assertions
                .iter()
                .find(|a| a.assertion_type == kind)
                .expect("assertion of requested type")
        };

        let duration = find(AssertionType::Duration);
        assert_eq!(duration.id, "duration_orch-1_step-1");
        assert_eq!(duration.path, "orch-1.step-1.duration");
        assert_eq!(duration.operator, AssertionOperator::LessThanOrEqual);
        assert_eq!(duration.based_on_samples, 50);
        assert!((100.0..=149.0).contains(&duration.value));
        assert!(duration.confidence >= 0.7);

        let success = find(AssertionType::SuccessRate);
        assert_eq!(success.path, "orch-1.step-1.success_rate");
        assert_eq!(success.operator, AssertionOperator::GreaterThanOrEqual);
        assert!((success.value - 0.9 * 0.95).abs() < 1e-12);

        let errors = find(AssertionType::ErrorRate);
        assert_eq!(errors.path, "orch-1.step-1.error_rate");
        assert!((errors.value - 0.15).abs() < 1e-12);
    }

    #[test]
    fn test_steps_below_min_samples_are_skipped() {
        let mut data = create_step_data("orch-1", "step-1", 20);
        data.extend(create_step_data("orch-1", "step-2", 5));

        let assertions = generator_with(data).generate_assertions().unwrap();

        assert!(!assertions.is_empty());
        assert!(assertions.iter().all(|a| a.path.starts_with("orch-1.step-1.")));
    }

    #[test]
    fn test_insufficient_data() {
        let config = AssertionGeneratorConfig::default();
        let mut generator = AssertionGenerator::new(config);

        let data = create_sample_data(5);
        generator.add_bulk_data(data);

        let result = generator.generate_assertions();
        assert!(result.is_err());
    }

    #[test]
    fn test_clear_data() {
        let mut generator = generator_with(create_sample_data(3));
        assert_eq!(generator.data_count(), 3);

        generator.clear_data();

        assert_eq!(generator.data_count(), 0);
        assert!(generator.generate_assertions().is_err());
    }

    #[test]
    fn test_stats_calculation() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let stats = AssertionGenerator::calculate_stats(&values);

        assert_eq!(stats.mean, 5.5);
        assert_eq!(stats.median, 6.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 10.0);
        assert_eq!(stats.sample_count, 10);
    }

    #[test]
    fn test_stats_empty_and_single() {
        let empty = AssertionGenerator::calculate_stats(&[]);
        assert_eq!(empty.sample_count, 0);
        assert_eq!(empty.mean, 0.0);
        assert_eq!(empty.p95, 0.0);

        let single = AssertionGenerator::calculate_stats(&[42.0]);
        assert_eq!(single.sample_count, 1);
        assert_eq!(single.mean, 42.0);
        assert_eq!(single.median, 42.0);
        assert_eq!(single.std_dev, 0.0);
        assert_eq!(single.p95, 42.0);
        assert_eq!(single.p99, 42.0);
    }
}