mockforge_intelligence/behavioral_cloning/
edge_amplifier.rs1use crate::behavioral_cloning::types::{
7 EdgeAmplificationConfig, EndpointProbabilityModel, ErrorPattern,
8};
9use mockforge_core::Result;
10
/// Stateless helper that redistributes error-pattern probabilities in an
/// [`EndpointProbabilityModel`] so that rare error patterns are sampled more often.
///
/// The type has no fields; every operation is an associated function that takes
/// the model (and configuration) explicitly.
#[derive(Debug, Clone, Copy)]
pub struct EdgeAmplifier;

14impl EdgeAmplifier {
15 pub fn new() -> Self {
17 Self
18 }
19
20 pub fn identify_rare_edges(
25 model: &EndpointProbabilityModel,
26 threshold: f64,
27 ) -> Vec<&ErrorPattern> {
28 model
29 .error_patterns
30 .iter()
31 .filter(|pattern| pattern.probability < threshold)
32 .collect()
33 }
34
35 pub fn apply_amplification(
42 model: &mut EndpointProbabilityModel,
43 config: &EdgeAmplificationConfig,
44 ) -> Result<()> {
45 if !config.enabled {
46 return Ok(());
47 }
48
49 if model.original_error_probabilities.is_none() {
51 let mut original = std::collections::HashMap::new();
52 for pattern in &model.error_patterns {
53 original.insert(pattern.error_type.clone(), pattern.probability);
54 }
55 model.original_error_probabilities = Some(original);
56 }
57
58 let rare_patterns: Vec<usize> = model
60 .error_patterns
61 .iter()
62 .enumerate()
63 .filter(|(_, pattern)| pattern.probability < config.rare_threshold)
64 .map(|(idx, _)| idx)
65 .collect();
66
67 if rare_patterns.is_empty() {
68 return Ok(());
69 }
70
71 let rare_total: f64 =
73 rare_patterns.iter().map(|&idx| model.error_patterns[idx].probability).sum();
74
75 let non_rare_total: f64 = model
77 .error_patterns
78 .iter()
79 .enumerate()
80 .filter(|(idx, _)| !rare_patterns.contains(idx))
81 .map(|(_, pattern)| pattern.probability)
82 .sum();
83
84 let amplified_total = config.amplification_factor;
86
87 if rare_total > 0.0 {
89 let scale_factor = amplified_total / rare_total;
90 for &idx in &rare_patterns {
91 model.error_patterns[idx].probability *= scale_factor;
92 }
93 } else {
94 let per_pattern = amplified_total / rare_patterns.len() as f64;
96 for &idx in &rare_patterns {
97 model.error_patterns[idx].probability = per_pattern;
98 }
99 }
100
101 if non_rare_total > 0.0 {
103 let scale_factor = (1.0 - amplified_total) / non_rare_total;
104 for (idx, pattern) in model.error_patterns.iter_mut().enumerate() {
105 if !rare_patterns.contains(&idx) {
106 pattern.probability *= scale_factor;
107 }
108 }
109 }
110
111 let total: f64 = model.error_patterns.iter().map(|p| p.probability).sum();
113 if total > 0.0 && (total - 1.0).abs() > 0.001 {
114 let scale = 1.0 / total;
115 for pattern in &mut model.error_patterns {
116 pattern.probability *= scale;
117 }
118 }
119
120 Ok(())
121 }
122
123 pub fn restore_original(model: &mut EndpointProbabilityModel) -> Result<()> {
129 let original_probs = match &model.original_error_probabilities {
130 Some(probs) => probs,
131 None => {
132 return Ok(());
134 }
135 };
136
137 for pattern in &mut model.error_patterns {
139 if let Some(&original_prob) = original_probs.get(&pattern.error_type) {
140 pattern.probability = original_prob;
141 }
142 }
143
144 model.original_error_probabilities = None;
146
147 Ok(())
148 }
149}
150
151impl Default for EdgeAmplifier {
152 fn default() -> Self {
153 Self::new()
154 }
155}