datasynth_audit_optimizer/
calibration.rs1use rand::SeedableRng;
7use rand_chacha::ChaCha8Rng;
8use serde::Serialize;
9
10use datasynth_audit_fsm::{
11 context::EngagementContext,
12 engine::AuditFsmEngine,
13 error::AuditFsmError,
14 loader::{default_overlay, BlueprintWithPreconditions},
15 schema::GenerationOverlay,
16};
17
/// Goal and stopping criteria for an anomaly-rate calibration run.
#[derive(Debug, Clone)]
pub struct CalibrationTarget {
    /// Anomaly rate the calibration should reach, expressed as the fraction
    /// of events in an engagement's event log that are flagged as anomalies.
    pub target_anomaly_rate: f64,
    /// Largest absolute difference between the measured rate and
    /// `target_anomaly_rate` that still counts as converged.
    pub tolerance: f64,
    /// Upper bound on the number of measure-and-rescale iterations.
    pub max_iterations: usize,
}
33
34#[derive(Debug, Clone, Serialize)]
36pub struct CalibratedOverlay {
37 pub overlay: GenerationOverlay,
39 pub achieved_rate: f64,
41 pub iterations: usize,
43 pub converged: bool,
45}
46
47pub fn calibrate_anomaly_rates(
69 bwp: &BlueprintWithPreconditions,
70 target: &CalibrationTarget,
71 base_seed: u64,
72 context: &EngagementContext,
73) -> Result<CalibratedOverlay, AuditFsmError> {
74 const SAMPLES_PER_ITER: usize = 3;
75 const PROB_MIN: f64 = 0.001;
76 const PROB_MAX: f64 = 0.5;
77
78 let mut overlay = default_overlay();
79
80 if target.target_anomaly_rate <= 0.0 {
82 overlay.anomalies.skipped_approval = 0.0;
83 overlay.anomalies.late_posting = 0.0;
84 overlay.anomalies.missing_evidence = 0.0;
85 overlay.anomalies.out_of_sequence = 0.0;
86 for rule in &mut overlay.anomalies.rules {
87 rule.probability = 0.0;
88 }
89 return Ok(CalibratedOverlay {
90 overlay,
91 achieved_rate: 0.0,
92 iterations: 1,
93 converged: true,
94 });
95 }
96
97 let mut best_overlay = overlay.clone();
98 let mut best_achieved = f64::MAX;
99 let mut best_distance = f64::MAX;
100
101 for iter in 0..target.max_iterations {
102 let achieved = mean_anomaly_rate(
103 bwp,
104 &overlay,
105 SAMPLES_PER_ITER,
106 base_seed,
107 iter as u64,
108 context,
109 );
110
111 let distance = (achieved - target.target_anomaly_rate).abs();
112 if distance < best_distance {
113 best_distance = distance;
114 best_achieved = achieved;
115 best_overlay = overlay.clone();
116 }
117
118 if distance <= target.tolerance {
119 return Ok(CalibratedOverlay {
120 overlay: best_overlay,
121 achieved_rate: best_achieved,
122 iterations: iter + 1,
123 converged: true,
124 });
125 }
126
127 let scale = if achieved > 1e-9 {
129 (target.target_anomaly_rate / achieved).clamp(0.1, 10.0)
130 } else {
131 2.0
133 };
134
135 scale_anomaly_probs(&mut overlay, scale, PROB_MIN, PROB_MAX);
136 }
137
138 Ok(CalibratedOverlay {
139 overlay: best_overlay,
140 achieved_rate: best_achieved,
141 iterations: target.max_iterations,
142 converged: best_distance <= target.tolerance,
143 })
144}
145
146fn mean_anomaly_rate(
152 bwp: &BlueprintWithPreconditions,
153 overlay: &GenerationOverlay,
154 samples: usize,
155 base_seed: u64,
156 seed_offset: u64,
157 context: &EngagementContext,
158) -> f64 {
159 let mut total_anomaly_rate = 0.0;
160 let mut successful = 0usize;
161
162 for i in 0..samples {
163 let iter_seed = base_seed.wrapping_add(seed_offset).wrapping_add(i as u64);
164 let rng = ChaCha8Rng::seed_from_u64(iter_seed);
165 let mut engine = AuditFsmEngine::new(bwp.clone(), overlay.clone(), rng);
166
167 let result = match engine.run_engagement(context) {
168 Ok(r) => r,
169 Err(_) => continue,
170 };
171
172 let event_count = result.event_log.len();
173 let anomaly_count = result.event_log.iter().filter(|e| e.is_anomaly).count();
174 total_anomaly_rate += if event_count > 0 {
175 anomaly_count as f64 / event_count as f64
176 } else {
177 0.0
178 };
179 successful += 1;
180 }
181
182 if successful == 0 {
183 return 0.0;
184 }
185 total_anomaly_rate / successful as f64
186}
187
188fn scale_anomaly_probs(overlay: &mut GenerationOverlay, scale: f64, min: f64, max: f64) {
190 let a = &mut overlay.anomalies;
191 a.skipped_approval = (a.skipped_approval * scale).clamp(min, max);
192 a.late_posting = (a.late_posting * scale).clamp(min, max);
193 a.missing_evidence = (a.missing_evidence * scale).clamp(min, max);
194 a.out_of_sequence = (a.out_of_sequence * scale).clamp(min, max);
195 for rule in &mut a.rules {
196 rule.probability = (rule.probability * scale).clamp(min, max);
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the built-in blueprint used by every test in this module.
    fn default_bwp() -> BlueprintWithPreconditions {
        BlueprintWithPreconditions::load_builtin_fsa().expect("builtin FSA must load")
    }

    /// Calibrates against the built-in blueprint and the demo engagement
    /// context, panicking if calibration returns an error.
    fn calibrate_demo(target: &CalibrationTarget, seed: u64) -> CalibratedOverlay {
        let blueprint = default_bwp();
        calibrate_anomaly_rates(&blueprint, target, seed, &EngagementContext::demo()).unwrap()
    }

    /// The achieved rate must land within 0.15 of the 0.15 target. Note that
    /// this bound is wider than the 0.05 tolerance requested from the
    /// calibrator.
    #[test]
    fn test_calibrate_to_target_rate() {
        let spec = CalibrationTarget {
            target_anomaly_rate: 0.15,
            tolerance: 0.05,
            max_iterations: 10,
        };
        let outcome = calibrate_demo(&spec, 42);
        let diff = (outcome.achieved_rate - 0.15).abs();
        assert!(
            diff <= 0.15,
            "achieved_rate={:.4} too far from 0.15 (diff={:.4})",
            outcome.achieved_rate,
            diff,
        );
    }

    /// A zero target converges and zeroes the four built-in probabilities.
    #[test]
    fn test_calibrate_zero_rate() {
        let spec = CalibrationTarget {
            target_anomaly_rate: 0.0,
            tolerance: 0.001,
            max_iterations: 10,
        };
        let outcome = calibrate_demo(&spec, 7);
        assert!(
            outcome.converged,
            "should converge immediately for zero target"
        );

        let anomalies = &outcome.overlay.anomalies;
        assert_eq!(anomalies.skipped_approval, 0.0);
        assert_eq!(anomalies.late_posting, 0.0);
        assert_eq!(anomalies.missing_evidence, 0.0);
        assert_eq!(anomalies.out_of_sequence, 0.0);
    }

    /// With a tolerance as large as the target itself, calibration converges.
    #[test]
    fn test_calibrate_converges() {
        let spec = CalibrationTarget {
            target_anomaly_rate: 0.10,
            tolerance: 0.10,
            max_iterations: 10,
        };
        let outcome = calibrate_demo(&spec, 99);
        assert!(
            outcome.converged,
            "expected convergence with loose tolerance 0.10, achieved_rate={}",
            outcome.achieved_rate
        );
    }

    /// The calibration result serializes to JSON with its diagnostic fields.
    #[test]
    fn test_calibrated_overlay_serializes() {
        let spec = CalibrationTarget {
            target_anomaly_rate: 0.05,
            tolerance: 0.10,
            max_iterations: 3,
        };
        let outcome = calibrate_demo(&spec, 1);
        let json = serde_json::to_string(&outcome).expect("CalibratedOverlay must serialize");
        assert!(!json.is_empty());
        assert!(json.contains("achieved_rate"));
        assert!(json.contains("converged"));
    }
}