use crate::behavioral_cloning::types::{
EdgeAmplificationConfig, EndpointProbabilityModel, ErrorPattern,
};
use crate::Result;
/// Re-weights the error-pattern distribution of an [`EndpointProbabilityModel`] so that
/// rare error patterns receive more probability mass, and can undo that change.
///
/// The type carries no state; every operation is an associated function acting on the
/// model passed in.
pub struct EdgeAmplifier;

impl EdgeAmplifier {
    /// Creates an `EdgeAmplifier`. Equivalent to [`EdgeAmplifier::default`].
    pub fn new() -> Self {
        Self
    }

    /// Returns the error patterns whose probability is strictly below `threshold`,
    /// in the order they appear in `model.error_patterns`.
    pub fn identify_rare_edges(
        model: &EndpointProbabilityModel,
        threshold: f64,
    ) -> Vec<&ErrorPattern> {
        model
            .error_patterns
            .iter()
            .filter(|pattern| pattern.probability < threshold)
            .collect()
    }

    /// Shifts error-pattern probability mass toward rare patterns.
    ///
    /// Does nothing when `config.enabled` is false. Otherwise:
    ///
    /// 1. If no snapshot exists yet, the current probabilities are saved in
    ///    `model.original_error_probabilities` (keyed by `error_type`) so that
    ///    [`EdgeAmplifier::restore_original`] can undo the change.
    /// 2. Patterns with `probability < config.rare_threshold` are "rare". If there are
    ///    none, the model is left unchanged apart from the snapshot.
    /// 3. Rare patterns are rescaled, keeping their relative proportions, to a combined
    ///    mass of `config.amplification_factor` clamped to `[0, 1]`. If their current
    ///    mass is not positive, that target mass is split evenly among them.
    /// 4. Non-rare patterns are rescaled to share the remaining `1 - mass`. This step is
    ///    skipped if their current mass is not positive.
    /// 5. If the resulting total is positive and differs from 1 by more than 0.001, all
    ///    probabilities are renormalized to sum to 1.
    ///
    /// Rarity and scaling are computed from the model's *current* probabilities. Calling
    /// this repeatedly without [`EdgeAmplifier::restore_original`] therefore compounds
    /// the effect, while the saved snapshot keeps the very first baseline.
    ///
    /// # Errors
    ///
    /// Never returns an error at present; every path yields `Ok(())`.
    pub fn apply_amplification(
        model: &mut EndpointProbabilityModel,
        config: &EdgeAmplificationConfig,
    ) -> Result<()> {
        if !config.enabled {
            return Ok(());
        }

        Self::snapshot_originals(model);

        // Rarity mask aligned with `model.error_patterns`. Computing it once keeps the
        // later passes O(n) instead of re-scanning a list of rare indices per pattern.
        // It uses the same strict `<` test as `identify_rare_edges`.
        let is_rare: Vec<bool> = model
            .error_patterns
            .iter()
            .map(|pattern| pattern.probability < config.rare_threshold)
            .collect();
        let rare_count = is_rare.iter().filter(|&&rare| rare).count();
        if rare_count == 0 {
            return Ok(());
        }

        let (rare_total, non_rare_total) = model
            .error_patterns
            .iter()
            .zip(&is_rare)
            .fold((0.0_f64, 0.0_f64), |(rare, non_rare), (pattern, &flag)| {
                if flag {
                    (rare + pattern.probability, non_rare)
                } else {
                    (rare, non_rare + pattern.probability)
                }
            });

        // `amplification_factor` is used as the combined mass the rare patterns end up
        // with. A value above 1 would give the non-rare patterns a negative scale factor
        // (negative probabilities). A value below 0 would make the rare patterns negative.
        // Keep it inside [0, 1] so the result is always a valid distribution.
        let amplified_total = config.amplification_factor.clamp(0.0, 1.0);

        if rare_total > 0.0 {
            // Preserve the rare patterns' relative proportions.
            Self::scale_matching(model, &is_rare, true, amplified_total / rare_total);
        } else {
            // Nothing to scale proportionally (would divide by zero): split evenly.
            let per_pattern = amplified_total / rare_count as f64;
            for (pattern, &flag) in model.error_patterns.iter_mut().zip(&is_rare) {
                if flag {
                    pattern.probability = per_pattern;
                }
            }
        }

        // Non-rare patterns share whatever mass is left. If they currently have none,
        // they stay untouched and the final renormalization absorbs the difference.
        if non_rare_total > 0.0 {
            Self::scale_matching(
                model,
                &is_rare,
                false,
                (1.0 - amplified_total) / non_rare_total,
            );
        }

        Self::renormalize(model);
        Ok(())
    }

    /// Restores every pattern's probability from the snapshot taken by
    /// [`EdgeAmplifier::apply_amplification`], then clears the snapshot.
    ///
    /// Patterns whose `error_type` is absent from the snapshot keep their current
    /// probability. If no snapshot exists, the model is left untouched.
    ///
    /// # Errors
    ///
    /// Never returns an error at present; every path yields `Ok(())`.
    pub fn restore_original(model: &mut EndpointProbabilityModel) -> Result<()> {
        // `take` clears the snapshot and hands us ownership of it in one step.
        if let Some(original_probs) = model.original_error_probabilities.take() {
            for pattern in &mut model.error_patterns {
                if let Some(&original_prob) = original_probs.get(&pattern.error_type) {
                    pattern.probability = original_prob;
                }
            }
        }
        Ok(())
    }

    /// Records each pattern's current probability, keyed by `error_type`, unless a
    /// snapshot already exists. Later duplicates of an `error_type` overwrite earlier ones.
    fn snapshot_originals(model: &mut EndpointProbabilityModel) {
        if model.original_error_probabilities.is_none() {
            let original = model
                .error_patterns
                .iter()
                .map(|pattern| (pattern.error_type.clone(), pattern.probability))
                .collect();
            model.original_error_probabilities = Some(original);
        }
    }

    /// Multiplies by `factor` the probability of every pattern whose entry in `is_rare`
    /// equals `rare`.
    fn scale_matching(
        model: &mut EndpointProbabilityModel,
        is_rare: &[bool],
        rare: bool,
        factor: f64,
    ) {
        for (pattern, &flag) in model.error_patterns.iter_mut().zip(is_rare) {
            if flag == rare {
                pattern.probability *= factor;
            }
        }
    }

    /// Rescales all probabilities to sum to 1. It does nothing when the total is not
    /// positive or is already within 0.001 of 1, so small rounding drift is tolerated.
    fn renormalize(model: &mut EndpointProbabilityModel) {
        let total: f64 = model.error_patterns.iter().map(|p| p.probability).sum();
        if total > 0.0 && (total - 1.0).abs() > 0.001 {
            let scale = 1.0 / total;
            for pattern in &mut model.error_patterns {
                pattern.probability *= scale;
            }
        }
    }
}

impl Default for EdgeAmplifier {
    fn default() -> Self {
        Self::new()
    }
}