Skip to main content

datasynth_audit_optimizer/
calibration.rs

//! Adaptive anomaly rate calibration.
//!
//! Given a target anomaly rate, iteratively adjusts overlay anomaly parameters
//! until the generated event log matches within a configurable tolerance.
5
6use rand::SeedableRng;
7use rand_chacha::ChaCha8Rng;
8use serde::Serialize;
9
10use datasynth_audit_fsm::{
11    context::EngagementContext,
12    engine::AuditFsmEngine,
13    error::AuditFsmError,
14    loader::{default_overlay, BlueprintWithPreconditions},
15    schema::GenerationOverlay,
16};
17
18// ---------------------------------------------------------------------------
19// Public types
20// ---------------------------------------------------------------------------
21
/// Parameters controlling the calibration loop.
///
/// The struct is plain data (two `f64`s and a `usize`), so it derives
/// `PartialEq` to let callers and tests compare targets directly.
#[derive(Debug, Clone, PartialEq)]
pub struct CalibrationTarget {
    /// The desired fraction of events that are anomalies (e.g. 0.10 = 10%).
    ///
    /// A value `<= 0.0` makes the calibration return an overlay with every
    /// anomaly probability set to zero without running any engagement.
    pub target_anomaly_rate: f64,
    /// How close the achieved rate must be to the target before we consider
    /// the calibration converged (e.g. 0.02 means ±2 percentage points).
    pub tolerance: f64,
    /// Upper bound on the number of calibration iterations.
    pub max_iterations: usize,
}
33
34/// The result of a successful calibration run.
35#[derive(Debug, Clone, Serialize)]
36pub struct CalibratedOverlay {
37    /// The overlay whose anomaly probabilities have been tuned.
38    pub overlay: GenerationOverlay,
39    /// The mean anomaly rate actually achieved with this overlay.
40    pub achieved_rate: f64,
41    /// How many calibration iterations were executed.
42    pub iterations: usize,
43    /// Whether the algorithm converged within `tolerance`.
44    pub converged: bool,
45}
46
47// ---------------------------------------------------------------------------
48// Main entry point
49// ---------------------------------------------------------------------------
50
51/// Iteratively calibrate overlay anomaly probabilities toward `target`.
52///
53/// # Algorithm
54///
55/// 1. Start with [`default_overlay()`].
56/// 2. Each iteration: run 3 engagements, compute mean anomaly rate =
57///    `anomaly_events / total_events`.
58/// 3. If `|achieved - target| <= tolerance`, mark as converged and return.
59/// 4. Otherwise scale all anomaly probability fields by
60///    `target_rate / achieved_rate`, clamping each to `[0.001, 0.5]`.
61/// 5. After `max_iterations`, return the best overlay seen so far.
62///
63/// # Errors
64///
65/// Returns [`AuditFsmError`] only if the initial blueprint fails to load;
66/// individual engagement failures within an iteration are silently skipped
67/// (the remaining samples are still used).
68pub fn calibrate_anomaly_rates(
69    bwp: &BlueprintWithPreconditions,
70    target: &CalibrationTarget,
71    base_seed: u64,
72    context: &EngagementContext,
73) -> Result<CalibratedOverlay, AuditFsmError> {
74    const SAMPLES_PER_ITER: usize = 3;
75    const PROB_MIN: f64 = 0.001;
76    const PROB_MAX: f64 = 0.5;
77
78    let mut overlay = default_overlay();
79
80    // Handle the trivial case: caller wants zero anomalies.
81    if target.target_anomaly_rate <= 0.0 {
82        overlay.anomalies.skipped_approval = 0.0;
83        overlay.anomalies.late_posting = 0.0;
84        overlay.anomalies.missing_evidence = 0.0;
85        overlay.anomalies.out_of_sequence = 0.0;
86        for rule in &mut overlay.anomalies.rules {
87            rule.probability = 0.0;
88        }
89        return Ok(CalibratedOverlay {
90            overlay,
91            achieved_rate: 0.0,
92            iterations: 1,
93            converged: true,
94        });
95    }
96
97    let mut best_overlay = overlay.clone();
98    let mut best_achieved = f64::MAX;
99    let mut best_distance = f64::MAX;
100
101    for iter in 0..target.max_iterations {
102        let achieved = mean_anomaly_rate(
103            bwp,
104            &overlay,
105            SAMPLES_PER_ITER,
106            base_seed,
107            iter as u64,
108            context,
109        );
110
111        let distance = (achieved - target.target_anomaly_rate).abs();
112        if distance < best_distance {
113            best_distance = distance;
114            best_achieved = achieved;
115            best_overlay = overlay.clone();
116        }
117
118        if distance <= target.tolerance {
119            return Ok(CalibratedOverlay {
120                overlay: best_overlay,
121                achieved_rate: best_achieved,
122                iterations: iter + 1,
123                converged: true,
124            });
125        }
126
127        // Scale all anomaly probabilities toward the target.
128        let scale = if achieved > 1e-9 {
129            (target.target_anomaly_rate / achieved).clamp(0.1, 10.0)
130        } else {
131            // Achieved rate is essentially zero — nudge probabilities upward.
132            2.0
133        };
134
135        scale_anomaly_probs(&mut overlay, scale, PROB_MIN, PROB_MAX);
136    }
137
138    Ok(CalibratedOverlay {
139        overlay: best_overlay,
140        achieved_rate: best_achieved,
141        iterations: target.max_iterations,
142        converged: best_distance <= target.tolerance,
143    })
144}
145
146// ---------------------------------------------------------------------------
147// Helpers
148// ---------------------------------------------------------------------------
149
150/// Run `samples` engagements and return the mean anomaly rate.
151fn mean_anomaly_rate(
152    bwp: &BlueprintWithPreconditions,
153    overlay: &GenerationOverlay,
154    samples: usize,
155    base_seed: u64,
156    seed_offset: u64,
157    context: &EngagementContext,
158) -> f64 {
159    let mut total_anomaly_rate = 0.0;
160    let mut successful = 0usize;
161
162    for i in 0..samples {
163        let iter_seed = base_seed.wrapping_add(seed_offset).wrapping_add(i as u64);
164        let rng = ChaCha8Rng::seed_from_u64(iter_seed);
165        let mut engine = AuditFsmEngine::new(bwp.clone(), overlay.clone(), rng);
166
167        let result = match engine.run_engagement(context) {
168            Ok(r) => r,
169            Err(_) => continue,
170        };
171
172        let event_count = result.event_log.len();
173        let anomaly_count = result.event_log.iter().filter(|e| e.is_anomaly).count();
174        total_anomaly_rate += if event_count > 0 {
175            anomaly_count as f64 / event_count as f64
176        } else {
177            0.0
178        };
179        successful += 1;
180    }
181
182    if successful == 0 {
183        return 0.0;
184    }
185    total_anomaly_rate / successful as f64
186}
187
188/// Multiply each anomaly probability field by `scale`, clamping to `[min, max]`.
189fn scale_anomaly_probs(overlay: &mut GenerationOverlay, scale: f64, min: f64, max: f64) {
190    let a = &mut overlay.anomalies;
191    a.skipped_approval = (a.skipped_approval * scale).clamp(min, max);
192    a.late_posting = (a.late_posting * scale).clamp(min, max);
193    a.missing_evidence = (a.missing_evidence * scale).clamp(min, max);
194    a.out_of_sequence = (a.out_of_sequence * scale).clamp(min, max);
195    for rule in &mut a.rules {
196        rule.probability = (rule.probability * scale).clamp(min, max);
197    }
198}
199
200// ---------------------------------------------------------------------------
201// Tests
202// ---------------------------------------------------------------------------
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the built-in FSA blueprint used by every test in this module.
    fn default_bwp() -> BlueprintWithPreconditions {
        BlueprintWithPreconditions::load_builtin_fsa().expect("builtin FSA must load")
    }

    /// Target 0.15 with tolerance 0.05 — the reported rate must stay within
    /// 0.15 of the target.
    ///
    /// NOTE(review): the asserted bound (0.15) is looser than the requested
    /// tolerance (0.05); confirm this slack is intentional.
    #[test]
    fn test_calibrate_to_target_rate() {
        let bwp = default_bwp();
        let target = CalibrationTarget {
            target_anomaly_rate: 0.15,
            tolerance: 0.05,
            max_iterations: 10,
        };
        let result =
            calibrate_anomaly_rates(&bwp, &target, 42, &EngagementContext::demo()).unwrap();
        let diff = (result.achieved_rate - 0.15).abs();
        assert!(
            diff <= 0.15,
            "achieved_rate={:.4} too far from 0.15 (diff={:.4})",
            result.achieved_rate,
            diff,
        );
    }

    /// Target 0.0 — every anomaly probability, including the per-rule ones,
    /// must become 0 and the run must report immediate convergence.
    #[test]
    fn test_calibrate_zero_rate() {
        let bwp = default_bwp();
        let target = CalibrationTarget {
            target_anomaly_rate: 0.0,
            tolerance: 0.001,
            max_iterations: 10,
        };
        let result = calibrate_anomaly_rates(&bwp, &target, 7, &EngagementContext::demo()).unwrap();
        assert!(
            result.converged,
            "should converge immediately for zero target"
        );
        assert_eq!(result.achieved_rate, 0.0);
        assert_eq!(result.overlay.anomalies.skipped_approval, 0.0);
        assert_eq!(result.overlay.anomalies.late_posting, 0.0);
        assert_eq!(result.overlay.anomalies.missing_evidence, 0.0);
        assert_eq!(result.overlay.anomalies.out_of_sequence, 0.0);
        // The zero-target branch also zeroes every rule-level probability.
        for rule in &result.overlay.anomalies.rules {
            assert_eq!(rule.probability, 0.0);
        }
    }

    /// With a reasonable target the algorithm should converge.
    #[test]
    fn test_calibrate_converges() {
        let bwp = default_bwp();
        let target = CalibrationTarget {
            target_anomaly_rate: 0.10,
            tolerance: 0.10,
            max_iterations: 10,
        };
        let result =
            calibrate_anomaly_rates(&bwp, &target, 99, &EngagementContext::demo()).unwrap();
        assert!(
            result.converged,
            "expected convergence with loose tolerance 0.10, achieved_rate={}",
            result.achieved_rate
        );
    }

    /// The `CalibratedOverlay` must be JSON-serializable.
    #[test]
    fn test_calibrated_overlay_serializes() {
        let bwp = default_bwp();
        let target = CalibrationTarget {
            target_anomaly_rate: 0.05,
            tolerance: 0.10,
            max_iterations: 3,
        };
        let result = calibrate_anomaly_rates(&bwp, &target, 1, &EngagementContext::demo()).unwrap();
        let json = serde_json::to_string(&result).expect("CalibratedOverlay must serialize");
        assert!(!json.is_empty());
        assert!(json.contains("achieved_rate"));
        assert!(json.contains("converged"));
    }
}