use crate::tile::GateThresholds;
/// Hyper-parameters controlling how the adaptive threshold learner updates its gates.
#[derive(Clone, Debug)]
pub struct LearningConfig {
    /// Blend weight used when moving thresholds toward their statistical targets;
    /// outcome feedback applies one tenth of this value.
    pub learning_rate: f64,
    /// Size of the history window.
    /// NOTE(review): not read by any code in this module — confirm intended use.
    pub history_window: usize,
    /// Number of metric samples to collect before automatic adjustment kicks in.
    pub warmup_samples: usize,
    /// Lower clamp for the learned structural min-cut threshold.
    pub min_structural_threshold: f64,
    /// Upper clamp for the learned structural min-cut threshold.
    pub max_structural_threshold: f64,
    /// Decay factor of the false-positive / false-negative EMAs (closer to 1.0 = slower).
    pub ema_decay: f64,
    /// Whether thresholds are adjusted automatically as samples and outcomes arrive.
    pub auto_adjust: bool,
}

impl Default for LearningConfig {
    /// Balanced settings: 1% learning rate, 1k-sample warm-up, structural range 1..=20.
    fn default() -> Self {
        LearningConfig {
            auto_adjust: true,
            ema_decay: 0.99,
            max_structural_threshold: 20.0,
            min_structural_threshold: 1.0,
            warmup_samples: 1_000,
            history_window: 10_000,
            learning_rate: 0.01,
        }
    }
}

impl LearningConfig {
    /// Slow, stable adaptation: tiny learning rate, long warm-up, slow EMAs.
    /// Structural clamp range is inherited from [`Default`].
    pub fn conservative() -> Self {
        let defaults = Self::default();
        Self {
            learning_rate: 0.001,
            history_window: 50_000,
            warmup_samples: 5_000,
            ema_decay: 0.999,
            auto_adjust: true,
            ..defaults
        }
    }

    /// Fast adaptation: large learning rate, short warm-up, fast EMAs.
    /// Structural clamp range is inherited from [`Default`].
    pub fn aggressive() -> Self {
        let defaults = Self::default();
        Self {
            learning_rate: 0.1,
            history_window: 1_000,
            warmup_samples: 100,
            ema_decay: 0.95,
            auto_adjust: true,
            ..defaults
        }
    }
}
/// Online mean / variance / min / max accumulator (Welford's algorithm).
#[derive(Clone, Debug)]
struct RunningStats {
    /// Number of samples folded in.
    count: u64,
    /// Running mean of the samples.
    mean: f64,
    /// Sum of squared deviations from the running mean (Welford's M2).
    m2: f64,
    /// Smallest sample seen; `f64::MAX` while empty.
    min: f64,
    /// Largest sample seen; `f64::MIN` while empty.
    max: f64,
}

impl Default for RunningStats {
    /// Identical to [`RunningStats::new`].
    ///
    /// A derived `Default` would start `min`/`max` at 0.0, so e.g. an all-negative
    /// stream would report `max == 0.0`; the sentinels must match `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl RunningStats {
    /// Creates an empty accumulator.
    fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
            min: f64::MAX,
            max: f64::MIN,
        }
    }

    /// Folds one sample into the running statistics.
    fn update(&mut self, value: f64) {
        self.count += 1;
        // Welford update: delta uses the old mean, delta2 the new one.
        let delta = value - self.mean;
        self.mean += delta / self.count as f64;
        let delta2 = value - self.mean;
        self.m2 += delta * delta2;
        if value < self.min {
            self.min = value;
        }
        if value > self.max {
            self.max = value;
        }
    }

    /// Unbiased sample variance (divides by `count - 1`); 0.0 with fewer than two samples.
    fn variance(&self) -> f64 {
        if self.count < 2 {
            return 0.0;
        }
        self.m2 / (self.count - 1) as f64
    }

    /// Sample standard deviation, the square root of [`Self::variance`].
    fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }
}
/// Exponential moving average that adopts its first sample verbatim.
#[derive(Clone, Debug)]
struct EMA {
    /// Current smoothed value (0.0 until the first sample arrives).
    value: f64,
    /// Weight kept by the previous value on each update; `1 - decay` goes to the sample.
    decay: f64,
    /// Whether at least one sample has been folded in.
    initialized: bool,
}

impl EMA {
    /// Creates an empty average with the given decay factor.
    fn new(decay: f64) -> Self {
        EMA {
            decay,
            value: 0.0,
            initialized: false,
        }
    }

    /// Folds `sample` into the average; the very first sample seeds the value directly.
    fn update(&mut self, sample: f64) {
        self.value = if self.initialized {
            self.decay * self.value + (1.0 - self.decay) * sample
        } else {
            sample
        };
        self.initialized = true;
    }

    /// Current smoothed value.
    fn get(&self) -> f64 {
        self.value
    }
}
/// Online learner that adapts a `GateThresholds` set from observed metrics and
/// outcome feedback.
pub struct AdaptiveThresholds {
/// Learning hyper-parameters.
config: LearningConfig,
/// Thresholds currently in effect; adjusted in place by the learner.
current: GateThresholds,
/// Running statistics of recorded cut values.
cut_stats: RunningStats,
/// Running statistics of recorded shift values.
shift_stats: RunningStats,
/// Running statistics of recorded e-values.
evidence_stats: RunningStats,
/// EMA of the per-outcome false-positive indicator (1.0 = denied but not bad).
false_positive_ema: EMA,
/// EMA of the per-outcome false-negative indicator (1.0 = allowed but bad).
false_negative_ema: EMA,
/// Number of metric samples recorded; drives warm-up.
samples: u64,
/// Cumulative confusion-matrix counts from recorded outcomes.
outcomes: OutcomeTracker,
}
/// Confusion-matrix counters for gate decisions versus ground truth.
///
/// "Positive" means the gate predicted deny; the truth axis is whether the
/// outcome was actually bad.
#[derive(Clone, Debug, Default)]
struct OutcomeTracker {
    true_positives: u64,
    true_negatives: u64,
    false_positives: u64,
    false_negatives: u64,
}

impl OutcomeTracker {
    /// Increments the counter that matches the (prediction, truth) pair.
    fn record(&mut self, predicted_deny: bool, actual_bad: bool) {
        let slot = match (predicted_deny, actual_bad) {
            (true, true) => &mut self.true_positives,
            (true, false) => &mut self.false_positives,
            (false, true) => &mut self.false_negatives,
            (false, false) => &mut self.true_negatives,
        };
        *slot += 1;
    }

    /// `num / den` as a float, or `if_empty` when the denominator is zero.
    fn ratio_or(num: u64, den: u64, if_empty: f64) -> f64 {
        if den == 0 {
            if_empty
        } else {
            num as f64 / den as f64
        }
    }

    /// TP / (TP + FP); 1.0 when nothing has been denied yet.
    fn precision(&self) -> f64 {
        Self::ratio_or(
            self.true_positives,
            self.true_positives + self.false_positives,
            1.0,
        )
    }

    /// TP / (TP + FN); 1.0 when no bad outcome has been recorded yet.
    fn recall(&self) -> f64 {
        Self::ratio_or(
            self.true_positives,
            self.true_positives + self.false_negatives,
            1.0,
        )
    }

    /// Harmonic mean of precision and recall; 0.0 when both are zero.
    fn f1_score(&self) -> f64 {
        let (p, r) = (self.precision(), self.recall());
        let total = p + r;
        if total == 0.0 {
            0.0
        } else {
            2.0 * p * r / total
        }
    }

    /// FP / (FP + TN); 0.0 when no good outcome has been recorded yet.
    fn false_positive_rate(&self) -> f64 {
        Self::ratio_or(
            self.false_positives,
            self.false_positives + self.true_negatives,
            0.0,
        )
    }

    /// FN / (FN + TP); 0.0 when no bad outcome has been recorded yet.
    fn false_negative_rate(&self) -> f64 {
        Self::ratio_or(
            self.false_negatives,
            self.false_negatives + self.true_positives,
            0.0,
        )
    }
}
impl AdaptiveThresholds {
    /// A metric's running statistics must hold more than this many observations
    /// before `adjust_thresholds` steers the threshold derived from that metric.
    const MIN_STATS_FOR_ADJUSTMENT: u64 = 100;

    /// Creates a learner that starts from `GateThresholds::default()` with empty
    /// statistics and outcome counters.
    pub fn new(config: LearningConfig) -> Self {
        let current = GateThresholds::default();
        Self {
            false_positive_ema: EMA::new(config.ema_decay),
            false_negative_ema: EMA::new(config.ema_decay),
            config,
            current,
            cut_stats: RunningStats::new(),
            shift_stats: RunningStats::new(),
            evidence_stats: RunningStats::new(),
            samples: 0,
            outcomes: OutcomeTracker::default(),
        }
    }

    /// Records one observation of the gate metrics.
    ///
    /// When `auto_adjust` is enabled and the learner is warmed up, every call also
    /// nudges the thresholds toward targets derived from the running statistics.
    pub fn record_metrics(&mut self, cut: f64, shift: f64, e_value: f64) {
        self.cut_stats.update(cut);
        self.shift_stats.update(shift);
        self.evidence_stats.update(e_value);
        self.samples += 1;
        // Gate on the exact predicate `is_warmed_up` reports so the two never disagree.
        if self.config.auto_adjust && self.is_warmed_up() {
            self.adjust_thresholds();
        }
    }

    /// Records ground truth for a past gate decision.
    ///
    /// `was_deny` is the decision that was taken; `was_actually_bad` whether the
    /// action turned out to be bad. Updates the cumulative counters and the
    /// false-positive/negative EMAs. When `auto_adjust` is enabled and the learner
    /// is warmed up, it also applies outcome feedback.
    pub fn record_outcome(&mut self, was_deny: bool, was_actually_bad: bool) {
        self.outcomes.record(was_deny, was_actually_bad);
        let fp = if was_deny && !was_actually_bad { 1.0 } else { 0.0 };
        let fn_rate = if !was_deny && was_actually_bad { 1.0 } else { 0.0 };
        self.false_positive_ema.update(fp);
        self.false_negative_ema.update(fn_rate);
        if self.config.auto_adjust && self.is_warmed_up() {
            self.adjust_from_outcome(was_deny, was_actually_bad);
        }
    }

    /// Thresholds currently in effect.
    pub fn current_thresholds(&self) -> &GateThresholds {
        &self.current
    }

    /// Mutable access for manual overrides; later learning continues from the edited values.
    pub fn current_thresholds_mut(&mut self) -> &mut GateThresholds {
        &mut self.current
    }

    /// Whether at least `warmup_samples` metric samples have been recorded.
    pub fn is_warmed_up(&self) -> bool {
        self.samples >= self.config.warmup_samples as u64
    }

    /// Snapshot of the running statistics and outcome metrics.
    ///
    /// `false_positive_rate` / `false_negative_rate` are EMAs of per-outcome
    /// indicators. They are smoothed fractions of *all* recorded outcomes, not
    /// FP/(FP+TN) or FN/(FN+TP).
    pub fn stats(&self) -> AdaptiveStats {
        AdaptiveStats {
            samples: self.samples,
            cut_mean: self.cut_stats.mean,
            cut_std: self.cut_stats.std_dev(),
            shift_mean: self.shift_stats.mean,
            shift_std: self.shift_stats.std_dev(),
            evidence_mean: self.evidence_stats.mean,
            precision: self.outcomes.precision(),
            recall: self.outcomes.recall(),
            f1_score: self.outcomes.f1_score(),
            false_positive_rate: self.false_positive_ema.get(),
            false_negative_rate: self.false_negative_ema.get(),
        }
    }

    /// Clears the statistics, sample count, outcome counters and EMAs.
    ///
    /// The current thresholds and the configuration are left untouched. After a new
    /// warm-up, learning resumes from wherever the thresholds currently are.
    pub fn reset(&mut self) {
        self.cut_stats = RunningStats::new();
        self.shift_stats = RunningStats::new();
        self.evidence_stats = RunningStats::new();
        self.false_positive_ema = EMA::new(self.config.ema_decay);
        self.false_negative_ema = EMA::new(self.config.ema_decay);
        self.samples = 0;
        self.outcomes = OutcomeTracker::default();
    }

    /// Blends each threshold toward a target derived from its metric's distribution,
    /// using `learning_rate` as the blend weight:
    /// - `structural_min_cut` → mean − 2σ of cuts, clamped to the configured range;
    /// - `shift_max` → mean + 2σ of shifts, capped at 1.0;
    /// - `tau_deny` → mean − 2σ of e-values, floored at 0.001;
    /// - `tau_permit` → mean + 0.5σ of e-values.
    ///
    /// σ is the sample standard deviation. A metric is only used once it has more
    /// than `MIN_STATS_FOR_ADJUSTMENT` observations.
    fn adjust_thresholds(&mut self) {
        let lr = self.config.learning_rate;
        let blend = |current: f64, target: f64| current * (1.0 - lr) + target * lr;

        if self.cut_stats.count > Self::MIN_STATS_FOR_ADJUSTMENT {
            let target = (self.cut_stats.mean - 2.0 * self.cut_stats.std_dev()).clamp(
                self.config.min_structural_threshold,
                self.config.max_structural_threshold,
            );
            self.current.structural_min_cut = blend(self.current.structural_min_cut, target);
        }

        if self.shift_stats.count > Self::MIN_STATS_FOR_ADJUSTMENT {
            let target = (self.shift_stats.mean + 2.0 * self.shift_stats.std_dev()).min(1.0);
            self.current.shift_max = blend(self.current.shift_max, target);
        }

        if self.evidence_stats.count > Self::MIN_STATS_FOR_ADJUSTMENT {
            let mean = self.evidence_stats.mean;
            let std = self.evidence_stats.std_dev();
            let tau_deny_target = (mean - 2.0 * std).max(0.001);
            self.current.tau_deny = blend(self.current.tau_deny, tau_deny_target);
            let tau_permit_target = mean + 0.5 * std;
            self.current.tau_permit = blend(self.current.tau_permit, tau_permit_target);
        }
    }

    /// Applies a small corrective step (a tenth of `learning_rate`) after a
    /// misclassification:
    /// - a false positive lowers `structural_min_cut` and raises `shift_max` (capped at 1.0);
    /// - a false negative raises `structural_min_cut` and lowers `shift_max` (floored at 0.1).
    ///
    /// Correct decisions change nothing beyond re-clamping `structural_min_cut`
    /// into the configured range.
    fn adjust_from_outcome(&mut self, was_deny: bool, was_actually_bad: bool) {
        let lr = self.config.learning_rate * 0.1;
        match (was_deny, was_actually_bad) {
            (true, false) => {
                self.current.structural_min_cut *= 1.0 - lr;
                self.current.shift_max = (self.current.shift_max + lr).min(1.0);
            }
            (false, true) => {
                self.current.structural_min_cut *= 1.0 + lr;
                self.current.shift_max = (self.current.shift_max - lr).max(0.1);
            }
            (true, true) | (false, false) => {}
        }
        self.current.structural_min_cut = self.current.structural_min_cut.clamp(
            self.config.min_structural_threshold,
            self.config.max_structural_threshold,
        );
    }
}
/// Classification of how a monitored signal has moved relative to its baseline.
#[derive(Clone, Debug, PartialEq)]
pub enum DriftProfile {
    /// No significant drift detected.
    Stable,
    /// Gradual trend in the window mean.
    Linear {
        /// Smoothed mean delta between `detect` calls, divided by the window size.
        slope: f64,
        direction: DriftDirection,
    },
    /// Abrupt shift of the window mean away from the baseline.
    StepChange {
        /// Absolute shift of the window mean from the baseline mean, in signal units.
        magnitude: f64,
        direction: DriftDirection,
    },
    /// Mild variance growth without a mean shift, with a detectable mean-crossing rhythm.
    Oscillating {
        /// Current window standard deviation minus the baseline standard deviation.
        amplitude: f64,
        /// Estimated period in samples, derived from mean crossings.
        period_samples: usize,
    },
    /// Window variance outside the configured band around the baseline variance.
    /// Despite the name, contraction (`ratio < 1.0`) is reported here as well.
    VarianceExpansion {
        /// Current window variance divided by the baseline variance.
        ratio: f64,
    },
}

/// Direction of a mean drift.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DriftDirection {
    Increasing,
    Decreasing,
}

/// Tuning parameters for [`DriftDetector`].
#[derive(Clone, Debug)]
pub struct DriftConfig {
    /// Number of most recent samples kept in the sliding window (must be > 0).
    pub window_size: usize,
    /// Samples required before the baseline is captured and detection runs.
    pub min_samples: usize,
    /// Mean shift, in baseline standard deviations, that counts as a step change.
    pub mean_shift_threshold: f64,
    /// Variance ratio above this (or below its reciprocal) is reported as variance drift.
    pub variance_threshold: f64,
    /// Smoothed mean delta, as a fraction of the baseline std dev, that counts as a trend.
    pub trend_sensitivity: f64,
}

impl Default for DriftConfig {
    fn default() -> Self {
        Self {
            window_size: 100,
            min_samples: 50,
            mean_shift_threshold: 2.0,
            variance_threshold: 1.5,
            trend_sensitivity: 0.1,
        }
    }
}

/// Sliding-window drift detector comparing recent samples against a captured baseline.
pub struct DriftDetector {
    config: DriftConfig,
    /// Ring buffer holding the most recent `config.window_size` samples.
    buffer: Vec<f64>,
    /// Slot the next sample is written to; once the ring has wrapped this is also
    /// the slot of the oldest sample.
    write_pos: usize,
    /// Total samples pushed since construction.
    sample_count: u64,
    baseline_mean: f64,
    baseline_var: f64,
    /// Window mean at the last trend update; reference for the next mean delta.
    prev_window_mean: f64,
    /// Window variance recorded alongside `prev_window_mean` (not read by detection).
    prev_window_var: f64,
    /// Exponentially smoothed (0.9 / 0.1) mean delta between `detect` calls.
    trend_accumulator: f64,
}

impl DriftDetector {
    /// Creates a detector with the given window size and otherwise default thresholds.
    ///
    /// # Panics
    /// Panics if `window_size` is zero.
    pub fn new(window_size: usize) -> Self {
        Self::with_config(DriftConfig {
            window_size,
            ..Default::default()
        })
    }

    /// Creates a detector from an explicit configuration.
    ///
    /// # Panics
    /// Panics if `config.window_size` is zero; the ring buffer needs at least one slot.
    pub fn with_config(config: DriftConfig) -> Self {
        assert!(
            config.window_size > 0,
            "DriftDetector window_size must be greater than zero"
        );
        Self {
            buffer: vec![0.0; config.window_size],
            write_pos: 0,
            sample_count: 0,
            baseline_mean: 0.0,
            baseline_var: 0.0,
            prev_window_mean: 0.0,
            prev_window_var: 0.0,
            trend_accumulator: 0.0,
            config,
        }
    }

    /// Appends a sample, overwriting the oldest once the window is full.
    ///
    /// When exactly `min_samples` samples have been seen, the current window is
    /// captured as both the baseline and the trend reference.
    pub fn push(&mut self, value: f64) {
        self.buffer[self.write_pos] = value;
        self.write_pos = (self.write_pos + 1) % self.config.window_size;
        self.sample_count += 1;
        if self.sample_count == self.config.min_samples as u64 {
            let (mean, var) = self.compute_window_stats();
            self.baseline_mean = mean;
            self.baseline_var = var;
            self.prev_window_mean = mean;
            self.prev_window_var = var;
        }
    }

    /// Classifies the current window against the baseline.
    ///
    /// Returns `None` before `min_samples` samples. Otherwise the checks run in
    /// priority order: step change, variance band, linear trend, oscillation.
    /// If none fires, the result is `Stable`. Each call that reaches the trend
    /// stage updates the smoothed trend state.
    pub fn detect(&mut self) -> Option<DriftProfile> {
        if self.sample_count < self.config.min_samples as u64 {
            return None;
        }
        let (current_mean, current_var) = self.compute_window_stats();
        let baseline_std = self.baseline_var.sqrt().max(1e-10);

        // 1. Abrupt mean shift, measured in baseline standard deviations.
        let mean_shift = (current_mean - self.baseline_mean).abs() / baseline_std;
        if mean_shift > self.config.mean_shift_threshold {
            let direction = if current_mean > self.baseline_mean {
                DriftDirection::Increasing
            } else {
                DriftDirection::Decreasing
            };
            return Some(DriftProfile::StepChange {
                magnitude: mean_shift * baseline_std,
                direction,
            });
        }

        // 2. Variance outside the configured band (expansion or contraction).
        let var_ratio = current_var / self.baseline_var.max(1e-10);
        if var_ratio > self.config.variance_threshold
            || var_ratio < 1.0 / self.config.variance_threshold
        {
            return Some(DriftProfile::VarianceExpansion { ratio: var_ratio });
        }

        // 3. Gradual trend: smooth the change in window mean since the last update.
        let mean_delta = current_mean - self.prev_window_mean;
        self.trend_accumulator = 0.9 * self.trend_accumulator + 0.1 * mean_delta;
        // The delta is now folded into the accumulator, so advance the reference
        // whatever is reported below; otherwise the next call would count it again.
        self.prev_window_mean = current_mean;
        self.prev_window_var = current_var;
        if self.trend_accumulator.abs() > self.config.trend_sensitivity * baseline_std {
            let direction = if self.trend_accumulator > 0.0 {
                DriftDirection::Increasing
            } else {
                DriftDirection::Decreasing
            };
            let slope = self.trend_accumulator / (self.config.window_size as f64);
            return Some(DriftProfile::Linear { slope, direction });
        }

        // 4. Mild variance growth with no mean shift: look for a crossing rhythm.
        if var_ratio > 1.2 && mean_shift < 0.5 {
            let period = self.estimate_oscillation_period();
            if period > 2 {
                return Some(DriftProfile::Oscillating {
                    amplitude: current_var.sqrt() - baseline_std,
                    period_samples: period,
                });
            }
        }

        Some(DriftProfile::Stable)
    }

    /// Heuristic drift severity in `[0.0, 1.0]`.
    ///
    /// The mean shift is expressed in units of three baseline standard deviations.
    /// The relative variance change is halved. The two are summed and capped at 1.0.
    /// Returns 0.0 before `min_samples` samples, or when both variances are below 1e-6.
    pub fn severity(&self) -> f64 {
        if self.sample_count < self.config.min_samples as u64 {
            return 0.0;
        }
        let (current_mean, current_var) = self.compute_window_stats();
        let baseline_std = self.baseline_var.sqrt().max(1e-10);
        let mean_component = ((current_mean - self.baseline_mean).abs() / baseline_std) / 3.0;
        let var_component = if self.baseline_var < 1e-6 && current_var < 1e-6 {
            0.0
        } else {
            ((current_var / self.baseline_var.max(1e-10)) - 1.0).abs() / 2.0
        };
        (mean_component + var_component).min(1.0)
    }

    /// Re-captures the baseline (and trend reference) from the current window.
    ///
    /// Does nothing before `min_samples` samples have been seen.
    pub fn reset_baseline(&mut self) {
        if self.sample_count >= self.config.min_samples as u64 {
            let (mean, var) = self.compute_window_stats();
            self.baseline_mean = mean;
            self.baseline_var = var;
            // Re-anchor the trend reference too, as the initial capture in `push`
            // does. Otherwise the next `detect` would see the whole pre-reset shift
            // as a trend delta.
            self.prev_window_mean = mean;
            self.prev_window_var = var;
            self.trend_accumulator = 0.0;
        }
    }

    /// `(mean, population variance)` of the current window.
    pub fn current_stats(&self) -> (f64, f64) {
        self.compute_window_stats()
    }

    /// `(mean, population variance)` captured as the baseline.
    pub fn baseline_stats(&self) -> (f64, f64) {
        (self.baseline_mean, self.baseline_var)
    }

    /// Mean and population variance (divides by `n`) of the filled part of the window.
    fn compute_window_stats(&self) -> (f64, f64) {
        let n = self.buffer.len().min(self.sample_count as usize);
        if n == 0 {
            return (0.0, 0.0);
        }
        // Until the ring wraps, samples occupy slots 0..n; afterwards all slots are
        // filled. Either way the first `n` slots are exactly the window contents.
        let sum: f64 = self.buffer.iter().take(n).sum();
        let mean = sum / n as f64;
        let var_sum: f64 = self
            .buffer
            .iter()
            .take(n)
            .map(|x| (x - mean).powi(2))
            .sum();
        let var = var_sum / n as f64;
        (mean, var)
    }

    /// Iterates the windowed samples from oldest to newest.
    fn chronological(&self) -> impl Iterator<Item = f64> + '_ {
        let n = self.buffer.len().min(self.sample_count as usize);
        // Before wrapping the oldest sample is at slot 0; afterwards it sits at `write_pos`.
        let start = if n == self.buffer.len() { self.write_pos } else { 0 };
        let (newer, older) = self.buffer.split_at(start);
        older.iter().chain(newer.iter()).copied().take(n)
    }

    /// Estimates an oscillation period from how often the window crosses its mean.
    ///
    /// Samples are walked in chronological order, so the ring-buffer seam never
    /// creates or hides a crossing. Returns `2 * n / crossings`, or 0 when there
    /// are fewer than two crossings.
    fn estimate_oscillation_period(&self) -> usize {
        let (mean, _) = self.compute_window_stats();
        let n = self.buffer.len().min(self.sample_count as usize);
        let mut sides = self.chronological().map(|x| x > mean);
        let mut prev_above = match sides.next() {
            Some(above) => above,
            None => return 0,
        };
        let mut crossings = 0usize;
        for above in sides {
            if above != prev_above {
                crossings += 1;
                prev_above = above;
            }
        }
        if crossings < 2 {
            return 0;
        }
        (2 * n) / crossings
    }
}
impl AdaptiveThresholds {
pub fn apply_drift_compensation(&mut self, profile: &DriftProfile) {
match profile {
DriftProfile::Stable => {
}
DriftProfile::Linear { slope, direction } => {
let adjustment = slope.abs() * 0.5;
match direction {
DriftDirection::Increasing => {
self.current.structural_min_cut += adjustment;
}
DriftDirection::Decreasing => {
self.current.structural_min_cut -= adjustment;
}
}
}
DriftProfile::StepChange { magnitude, direction } => {
let adjustment = magnitude * 0.3;
match direction {
DriftDirection::Increasing => {
self.current.structural_min_cut += adjustment;
}
DriftDirection::Decreasing => {
self.current.structural_min_cut -= adjustment;
}
}
}
DriftProfile::Oscillating { amplitude, .. } => {
self.current.structural_min_cut += amplitude * 0.5;
}
DriftProfile::VarianceExpansion { ratio } => {
if *ratio > 1.0 {
self.current.shift_max = (self.current.shift_max * ratio.sqrt()).min(1.0);
}
}
}
self.current.structural_min_cut = self.current.structural_min_cut.clamp(
self.config.min_structural_threshold,
self.config.max_structural_threshold,
);
}
}
/// Snapshot of the learner's running statistics and outcome metrics, as returned by
/// `AdaptiveThresholds::stats`.
#[derive(Clone, Debug, Default)]
pub struct AdaptiveStats {
/// Number of metric samples recorded via `record_metrics`.
pub samples: u64,
/// Running mean of recorded cut values.
pub cut_mean: f64,
/// Sample standard deviation of recorded cut values (0.0 with fewer than two samples).
pub cut_std: f64,
/// Running mean of recorded shift values.
pub shift_mean: f64,
/// Sample standard deviation of recorded shift values (0.0 with fewer than two samples).
pub shift_std: f64,
/// Running mean of recorded e-values.
pub evidence_mean: f64,
/// Cumulative precision over recorded outcomes (1.0 while nothing has been denied).
pub precision: f64,
/// Cumulative recall over recorded outcomes (1.0 while no bad outcome has been seen).
pub recall: f64,
/// Harmonic mean of `precision` and `recall`.
pub f1_score: f64,
/// Exponentially weighted fraction of all recorded outcomes that were false positives.
pub false_positive_rate: f64,
/// Exponentially weighted fraction of all recorded outcomes that were false negatives.
pub false_negative_rate: f64,
}
// Unit tests for the adaptive-threshold learner and the drift detector.
#[cfg(test)]
mod tests {
use super::*;
// Default learning config exposes the documented learning rate and enables auto-adjust.
#[test]
fn test_learning_config_default() {
let config = LearningConfig::default();
assert_eq!(config.learning_rate, 0.01);
assert!(config.auto_adjust);
}
// Welford accumulator over 1..=100: mean 50.5, non-zero spread, exact min/max.
#[test]
fn test_running_stats() {
let mut stats = RunningStats::new();
for i in 1..=100 {
stats.update(i as f64);
}
assert_eq!(stats.count, 100);
assert!((stats.mean - 50.5).abs() < 0.001);
assert!(stats.std_dev() > 0.0);
assert_eq!(stats.min, 1.0);
assert_eq!(stats.max, 100.0);
}
// The first EMA sample seeds the value; later samples blend in with weight (1 - decay).
#[test]
fn test_ema() {
let mut ema = EMA::new(0.9);
ema.update(100.0);
assert_eq!(ema.get(), 100.0);
ema.update(0.0);
assert!((ema.get() - 90.0).abs() < 0.001);
}
// A fresh learner has seen no samples and is not warmed up.
#[test]
fn test_adaptive_thresholds_creation() {
let config = LearningConfig::default();
let adaptive = AdaptiveThresholds::new(config);
assert!(!adaptive.is_warmed_up());
assert_eq!(adaptive.samples, 0);
}
// Recording more samples than `warmup_samples` marks the learner as warmed up.
#[test]
fn test_adaptive_metrics_recording() {
let config = LearningConfig {
warmup_samples: 10,
..Default::default()
};
let mut adaptive = AdaptiveThresholds::new(config);
for i in 0..20 {
adaptive.record_metrics(10.0 + i as f64 * 0.1, 0.2, 100.0);
}
assert!(adaptive.is_warmed_up());
assert_eq!(adaptive.samples, 20);
}
// 8 true positives and 2 false positives give precision 0.8.
#[test]
fn test_outcome_tracker() {
let mut tracker = OutcomeTracker::default();
for _ in 0..8 {
tracker.record(true, true);
}
for _ in 0..2 {
tracker.record(true, false);
}
assert_eq!(tracker.precision(), 0.8);
}
// stats() reports the sample count and the running mean of the cut metric.
#[test]
fn test_adaptive_stats() {
let config = LearningConfig {
warmup_samples: 5,
..Default::default()
};
let mut adaptive = AdaptiveThresholds::new(config);
for _ in 0..10 {
adaptive.record_metrics(10.0, 0.2, 100.0);
}
let stats = adaptive.stats();
assert_eq!(stats.samples, 10);
assert!((stats.cut_mean - 10.0).abs() < 0.001);
}
// A new detector starts with no samples.
#[test]
fn test_drift_detector_creation() {
let detector = DriftDetector::new(100);
assert_eq!(detector.sample_count, 0);
}
// A low-amplitude sinusoid around 10.0 is classified as stable.
#[test]
fn test_drift_detector_stable() {
let mut detector = DriftDetector::new(50);
for i in 0..100 {
let noise = ((i as f64) * 0.1).sin() * 0.1;
detector.push(10.0 + noise);
}
let profile = detector.detect();
assert!(matches!(profile, Some(DriftProfile::Stable)));
}
// A jump from 10.0 to 20.0 after the baseline is captured is an increasing step change.
#[test]
fn test_drift_detector_step_change() {
let mut detector = DriftDetector::with_config(DriftConfig {
window_size: 50,
min_samples: 30,
mean_shift_threshold: 2.0,
..Default::default()
});
for _ in 0..40 {
detector.push(10.0);
}
for _ in 0..30 {
detector.push(20.0);
}
let profile = detector.detect();
assert!(
matches!(profile, Some(DriftProfile::StepChange { direction: DriftDirection::Increasing, .. })),
"Expected step change increasing, got {:?}",
profile
);
}
// After re-baselining on a quiet signal, a much noisier signal still yields a profile
// (detect() returns Some once min_samples has been reached).
#[test]
fn test_drift_detector_variance_expansion() {
let mut detector = DriftDetector::with_config(DriftConfig {
window_size: 50,
min_samples: 30,
variance_threshold: 1.5,
mean_shift_threshold: 5.0, ..Default::default()
});
for i in 0..40 {
let noise = ((i as f64) * 0.1).sin() * 0.05;
detector.push(10.0 + noise);
}
detector.reset_baseline();
for i in 0..50 {
let noise = ((i as f64) * 0.3).sin() * 2.5; detector.push(10.0 + noise);
}
let profile = detector.detect();
assert!(
profile.is_some(),
"Expected some drift profile, got None"
);
}
// Severity is 0.0 before min_samples and stays modest for a stable signal after re-baselining.
#[test]
fn test_drift_severity() {
let mut detector = DriftDetector::new(50);
for i in 0..10 {
detector.push(10.0 + (i as f64) * 0.001); }
assert_eq!(detector.severity(), 0.0);
for i in 0..100 {
let noise = ((i as f64) * 0.1).sin() * 0.05;
detector.push(10.0 + noise);
}
detector.reset_baseline();
for i in 0..50 {
let noise = ((i as f64 + 100.0) * 0.1).sin() * 0.05;
detector.push(10.0 + noise);
}
let severity = detector.severity();
assert!(severity < 0.6, "Expected reasonable severity for stable signal: {}", severity);
}
// reset_baseline() re-captures the baseline from the current (shifted) window.
#[test]
fn test_drift_baseline_reset() {
let mut detector = DriftDetector::new(50);
for _ in 0..60 {
detector.push(10.0);
}
let (baseline_mean, _) = detector.baseline_stats();
assert!((baseline_mean - 10.0).abs() < 0.1);
for _ in 0..30 {
detector.push(20.0);
}
detector.reset_baseline();
let (new_baseline, _) = detector.baseline_stats();
assert!(new_baseline > 12.0, "Baseline should shift: {}", new_baseline);
}
// An increasing step change raises structural_min_cut.
#[test]
fn test_drift_compensation() {
let config = LearningConfig::default();
let mut adaptive = AdaptiveThresholds::new(config);
let original = adaptive.current.structural_min_cut;
let profile = DriftProfile::StepChange {
magnitude: 2.0,
direction: DriftDirection::Increasing,
};
adaptive.apply_drift_compensation(&profile);
assert!(
adaptive.current.structural_min_cut > original,
"Threshold should increase for increasing drift"
);
}
// DriftConfig defaults match the documented values.
#[test]
fn test_drift_config_default() {
let config = DriftConfig::default();
assert_eq!(config.window_size, 100);
assert_eq!(config.min_samples, 50);
assert_eq!(config.mean_shift_threshold, 2.0);
}
}