Skip to main content

mockforge_intelligence/behavioral_cloning/
edge_amplifier.rs

1//! Rare edge case amplification
2//!
3//! This module provides functionality to identify and amplify rare
4//! error patterns for testing purposes.
5
6use crate::behavioral_cloning::types::{
7    EdgeAmplificationConfig, EndpointProbabilityModel, ErrorPattern,
8};
9use mockforge_core::Result;
10
/// Edge amplifier for increasing the frequency of rare error patterns.
///
/// This is a stateless unit struct: the amplification logic lives in
/// associated functions that take the probability model (and config) as
/// explicit arguments.
#[derive(Debug, Clone, Copy)]
pub struct EdgeAmplifier;
13
14impl EdgeAmplifier {
15    /// Create a new edge amplifier
16    pub fn new() -> Self {
17        Self
18    }
19
20    /// Identify rare edge patterns in a probability model
21    ///
22    /// Finds patterns with probability below the threshold
23    /// (default 1%).
24    pub fn identify_rare_edges(
25        model: &EndpointProbabilityModel,
26        threshold: f64,
27    ) -> Vec<&ErrorPattern> {
28        model
29            .error_patterns
30            .iter()
31            .filter(|pattern| pattern.probability < threshold)
32            .collect()
33    }
34
35    /// Apply amplification to a probability model
36    ///
37    /// Increases the probability of rare patterns and normalizes
38    /// the remaining probabilities to sum to 1.0.
39    ///
40    /// Stores original probabilities before amplification for later restoration.
41    pub fn apply_amplification(
42        model: &mut EndpointProbabilityModel,
43        config: &EdgeAmplificationConfig,
44    ) -> Result<()> {
45        if !config.enabled {
46            return Ok(());
47        }
48
49        // Store original probabilities if not already stored
50        if model.original_error_probabilities.is_none() {
51            let mut original = std::collections::HashMap::new();
52            for pattern in &model.error_patterns {
53                original.insert(pattern.error_type.clone(), pattern.probability);
54            }
55            model.original_error_probabilities = Some(original);
56        }
57
58        // Identify rare patterns
59        let rare_patterns: Vec<usize> = model
60            .error_patterns
61            .iter()
62            .enumerate()
63            .filter(|(_, pattern)| pattern.probability < config.rare_threshold)
64            .map(|(idx, _)| idx)
65            .collect();
66
67        if rare_patterns.is_empty() {
68            return Ok(());
69        }
70
71        // Calculate total probability of rare patterns
72        let rare_total: f64 =
73            rare_patterns.iter().map(|&idx| model.error_patterns[idx].probability).sum();
74
75        // Calculate total probability of non-rare patterns
76        let non_rare_total: f64 = model
77            .error_patterns
78            .iter()
79            .enumerate()
80            .filter(|(idx, _)| !rare_patterns.contains(idx))
81            .map(|(_, pattern)| pattern.probability)
82            .sum();
83
84        // Set amplified probability for rare patterns
85        let amplified_total = config.amplification_factor;
86
87        // Normalize rare patterns to sum to amplified_total
88        if rare_total > 0.0 {
89            let scale_factor = amplified_total / rare_total;
90            for &idx in &rare_patterns {
91                model.error_patterns[idx].probability *= scale_factor;
92            }
93        } else {
94            // If no rare patterns existed, distribute amplified_total evenly
95            let per_pattern = amplified_total / rare_patterns.len() as f64;
96            for &idx in &rare_patterns {
97                model.error_patterns[idx].probability = per_pattern;
98            }
99        }
100
101        // Normalize non-rare patterns to sum to (1.0 - amplified_total)
102        if non_rare_total > 0.0 {
103            let scale_factor = (1.0 - amplified_total) / non_rare_total;
104            for (idx, pattern) in model.error_patterns.iter_mut().enumerate() {
105                if !rare_patterns.contains(&idx) {
106                    pattern.probability *= scale_factor;
107                }
108            }
109        }
110
111        // Ensure probabilities sum to 1.0 (with small tolerance for floating point)
112        let total: f64 = model.error_patterns.iter().map(|p| p.probability).sum();
113        if total > 0.0 && (total - 1.0).abs() > 0.001 {
114            let scale = 1.0 / total;
115            for pattern in &mut model.error_patterns {
116                pattern.probability *= scale;
117            }
118        }
119
120        Ok(())
121    }
122
123    /// Restore original probabilities (before amplification)
124    ///
125    /// Restores the error pattern probabilities to their values before
126    /// amplification was applied. Requires that original probabilities
127    /// were stored during amplification.
128    pub fn restore_original(model: &mut EndpointProbabilityModel) -> Result<()> {
129        let original_probs = match &model.original_error_probabilities {
130            Some(probs) => probs,
131            None => {
132                // No original probabilities stored - nothing to restore
133                return Ok(());
134            }
135        };
136
137        // Restore each pattern's probability from the stored original
138        for pattern in &mut model.error_patterns {
139            if let Some(&original_prob) = original_probs.get(&pattern.error_type) {
140                pattern.probability = original_prob;
141            }
142        }
143
144        // Clear the stored original probabilities after restoration
145        model.original_error_probabilities = None;
146
147        Ok(())
148    }
149}
150
151impl Default for EdgeAmplifier {
152    fn default() -> Self {
153        Self::new()
154    }
155}