1use std::collections::{HashMap, HashSet};
8
9use datasynth_audit_fsm::context::EngagementContext;
10use datasynth_audit_fsm::engine::AuditFsmEngine;
11use datasynth_audit_fsm::event::AuditEvent;
12use datasynth_audit_fsm::loader::BlueprintWithPreconditions;
13use datasynth_audit_fsm::schema::{AuditBlueprint, GenerationOverlay};
14use rand::SeedableRng;
15use rand_chacha::ChaCha8Rng;
16use serde::Serialize;
17
18#[derive(Debug, Clone, Serialize)]
24pub struct ConformanceReport {
25 pub fitness: f64,
27 pub precision: f64,
29 #[serde(skip_serializing_if = "Option::is_none")]
33 pub generalization: Option<f64>,
34 pub anomaly_stats: AnomalyStats,
36 pub per_procedure: Vec<ProcedureConformance>,
38}
39
40#[derive(Debug, Clone, Serialize)]
42pub struct AnomalyDetectionMetrics {
43 pub true_positives: usize,
45 pub false_positives: usize,
47 pub false_negatives: usize,
49 pub true_negatives: usize,
51 pub precision: f64,
53 pub recall: f64,
55 pub f1: f64,
57}
58
59#[derive(Debug, Clone, Serialize)]
61pub struct AnomalyStats {
62 pub total_events: usize,
64 pub anomaly_events: usize,
66 pub anomaly_rate: f64,
68 pub by_type: HashMap<String, usize>,
70}
71
72#[derive(Debug, Clone, Serialize)]
74pub struct ProcedureConformance {
75 pub procedure_id: String,
77 pub fitness: f64,
79 pub transitions_observed: usize,
81 pub transitions_defined: usize,
83}
84
85pub fn analyze_conformance(events: &[AuditEvent], blueprint: &AuditBlueprint) -> ConformanceReport {
104 let mut defined_transitions: HashMap<String, HashSet<(String, String)>> = HashMap::new();
106 let mut total_defined = 0usize;
107
108 for phase in &blueprint.phases {
109 for proc in &phase.procedures {
110 let pairs: HashSet<(String, String)> = proc
111 .aggregate
112 .transitions
113 .iter()
114 .map(|t| (t.from_state.clone(), t.to_state.clone()))
115 .collect();
116 total_defined += pairs.len();
117 defined_transitions.insert(proc.id.clone(), pairs);
118 }
119 }
120
121 let mut global_valid = 0usize;
123 let mut global_total = 0usize;
124 let mut observed_triples: HashSet<(String, String, String)> = HashSet::new();
125
126 let mut proc_accum: HashMap<String, (usize, usize)> = HashMap::new();
128
129 let mut anomaly_events = 0usize;
131 let mut anomaly_by_type: HashMap<String, usize> = HashMap::new();
132
133 for event in events {
134 if event.is_anomaly {
136 anomaly_events += 1;
137 let type_str = event
138 .anomaly_type
139 .as_ref()
140 .map(|t| t.to_string())
141 .unwrap_or_else(|| "unknown".to_string());
142 *anomaly_by_type.entry(type_str).or_default() += 1;
143 }
144
145 if let (Some(ref from), Some(ref to)) = (&event.from_state, &event.to_state) {
147 global_total += 1;
148 let entry = proc_accum.entry(event.procedure_id.clone()).or_default();
149 entry.1 += 1;
150
151 let is_valid = defined_transitions
152 .get(&event.procedure_id)
153 .map(|pairs| pairs.contains(&(from.clone(), to.clone())))
154 .unwrap_or(false);
155
156 if is_valid {
157 global_valid += 1;
158 entry.0 += 1;
159 }
160
161 observed_triples.insert((event.procedure_id.clone(), from.clone(), to.clone()));
163 }
164 }
165
166 let fitness = if global_total > 0 {
167 global_valid as f64 / global_total as f64
168 } else {
169 1.0
170 };
171
172 let precision = if total_defined > 0 {
173 observed_triples.len() as f64 / total_defined as f64
174 } else {
175 0.0
176 };
177
178 let anomaly_rate = if events.is_empty() {
179 0.0
180 } else {
181 anomaly_events as f64 / events.len() as f64
182 };
183
184 let anomaly_stats = AnomalyStats {
185 total_events: events.len(),
186 anomaly_events,
187 anomaly_rate,
188 by_type: anomaly_by_type,
189 };
190
191 let mut per_procedure: Vec<ProcedureConformance> = Vec::new();
193 for phase in &blueprint.phases {
195 for proc in &phase.procedures {
196 let (valid, total) = proc_accum.get(&proc.id).copied().unwrap_or((0, 0));
197 let proc_fitness = if total > 0 {
198 valid as f64 / total as f64
199 } else {
200 1.0
201 };
202 let transitions_defined = defined_transitions
203 .get(&proc.id)
204 .map(|s| s.len())
205 .unwrap_or(0);
206 per_procedure.push(ProcedureConformance {
207 procedure_id: proc.id.clone(),
208 fitness: proc_fitness,
209 transitions_observed: total,
210 transitions_defined,
211 });
212 }
213 }
214
215 ConformanceReport {
216 fitness,
217 precision,
218 generalization: None,
219 anomaly_stats,
220 per_procedure,
221 }
222}
223
224pub fn compute_generalization(
234 bwp: &BlueprintWithPreconditions,
235 overlay: &GenerationOverlay,
236 blueprint: &AuditBlueprint,
237 base_seed: u64,
238 context: &EngagementContext,
239) -> f64 {
240 let seeds = [
241 base_seed,
242 base_seed.wrapping_add(1000),
243 base_seed.wrapping_add(2000),
244 ];
245 let mut fitness_values = Vec::new();
246
247 for seed in &seeds {
248 let rng = ChaCha8Rng::seed_from_u64(*seed);
249 let mut engine = AuditFsmEngine::new(bwp.clone(), overlay.clone(), rng);
250 if let Ok(result) = engine.run_engagement(context) {
251 let report = analyze_conformance(&result.event_log, blueprint);
252 fitness_values.push(report.fitness);
253 }
254 }
255
256 if fitness_values.len() < 2 {
257 return 1.0; }
259
260 let n = fitness_values.len() as f64;
261 let mean = fitness_values.iter().sum::<f64>() / n;
262 let variance = fitness_values
263 .iter()
264 .map(|f| (f - mean).powi(2))
265 .sum::<f64>()
266 / n;
267 let std_dev = variance.sqrt();
268
269 (1.0 - std_dev).clamp(0.0, 1.0)
270}
271
272pub fn evaluate_detector(
286 events: &[AuditEvent],
287 predictions: &[bool],
288) -> Result<AnomalyDetectionMetrics, String> {
289 if events.len() != predictions.len() {
290 return Err(format!(
291 "events and predictions must have the same length ({} vs {})",
292 events.len(),
293 predictions.len()
294 ));
295 }
296
297 let mut tp = 0usize;
298 let mut fp = 0usize;
299 let mut fn_ = 0usize;
300 let mut tn = 0usize;
301
302 for (event, &predicted) in events.iter().zip(predictions.iter()) {
303 match (event.is_anomaly, predicted) {
304 (true, true) => tp += 1,
305 (false, true) => fp += 1,
306 (true, false) => fn_ += 1,
307 (false, false) => tn += 1,
308 }
309 }
310
311 let precision = if tp + fp > 0 {
312 tp as f64 / (tp + fp) as f64
313 } else {
314 0.0
315 };
316 let recall = if tp + fn_ > 0 {
317 tp as f64 / (tp + fn_) as f64
318 } else {
319 0.0
320 };
321 let f1 = if precision + recall > 0.0 {
322 2.0 * precision * recall / (precision + recall)
323 } else {
324 0.0
325 };
326
327 Ok(AnomalyDetectionMetrics {
328 true_positives: tp,
329 false_positives: fp,
330 false_negatives: fn_,
331 true_negatives: tn,
332 precision,
333 recall,
334 f1,
335 })
336}
337
#[cfg(test)]
mod tests {
    //! Integration-style tests that run the built-in FSA blueprint through the
    //! engine and check the resulting conformance and detector metrics.

    use super::*;
    use datasynth_audit_fsm::context::EngagementContext;
    use datasynth_audit_fsm::engine::AuditFsmEngine;
    use datasynth_audit_fsm::loader::{
        default_overlay, load_overlay, BlueprintWithPreconditions, BuiltinOverlay, OverlaySource,
    };
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Runs the built-in FSA blueprint with the given built-in overlay and
    /// seed, returning the event log together with the blueprint it was
    /// generated from.
    fn run_fsa_engagement(
        overlay_type: BuiltinOverlay,
        seed: u64,
    ) -> (Vec<AuditEvent>, AuditBlueprint) {
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let overlay = load_overlay(&OverlaySource::Builtin(overlay_type)).unwrap();
        let bp = bwp.blueprint.clone();
        let rng = ChaCha8Rng::seed_from_u64(seed);
        let mut engine = AuditFsmEngine::new(bwp, overlay, rng);
        let ctx = EngagementContext::demo();
        let result = engine.run_engagement(&ctx).unwrap();
        (result.event_log, bp)
    }

    /// With every anomaly source switched off, the log must fit the blueprint
    /// perfectly and contain no anomalous events.
    #[test]
    fn test_conformance_perfect_log() {
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let bp = bwp.blueprint.clone();
        let mut overlay = default_overlay();
        overlay.anomalies.skipped_approval = 0.0;
        overlay.anomalies.late_posting = 0.0;
        overlay.anomalies.missing_evidence = 0.0;
        overlay.anomalies.out_of_sequence = 0.0;
        overlay.anomalies.rules.clear();
        let rng = ChaCha8Rng::seed_from_u64(42);
        let mut engine = AuditFsmEngine::new(bwp, overlay, rng);
        let ctx = EngagementContext::demo();
        let result = engine.run_engagement(&ctx).unwrap();

        let report = analyze_conformance(&result.event_log, &bp);
        assert!(
            (report.fitness - 1.0).abs() < f64::EPSILON,
            "Fitness should be 1.0 for a perfect log, got {}",
            report.fitness
        );
        assert_eq!(report.anomaly_stats.anomaly_events, 0);
    }

    /// A run with the `Rushed` overlay still produces events and a positive
    /// fitness.
    #[test]
    fn test_conformance_with_anomalies() {
        let (events, bp) = run_fsa_engagement(BuiltinOverlay::Rushed, 42);
        let report = analyze_conformance(&events, &bp);

        assert!(
            report.fitness > 0.0,
            "Fitness should be > 0, got {}",
            report.fitness
        );
        assert!(report.anomaly_stats.total_events > 0, "Should have events");
    }

    /// Precision must be a non-zero fraction.
    #[test]
    fn test_precision_computed() {
        let (events, bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        let report = analyze_conformance(&events, &bp);

        assert!(
            report.precision > 0.0,
            "Precision should be > 0, got {}",
            report.precision
        );
        assert!(
            report.precision <= 1.0,
            "Precision should be <= 1.0, got {}",
            report.precision
        );
    }

    /// The report lists every blueprint procedure, each with a fitness in
    /// `[0, 1]`.
    #[test]
    fn test_per_procedure_conformance() {
        let (events, bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        let report = analyze_conformance(&events, &bp);

        let total_procedures: usize = bp.phases.iter().map(|p| p.procedures.len()).sum();
        assert_eq!(
            report.per_procedure.len(),
            total_procedures,
            "Expected {} per-procedure entries, got {}",
            total_procedures,
            report.per_procedure.len()
        );

        for pc in &report.per_procedure {
            assert!(
                (0.0..=1.0).contains(&pc.fitness),
                "Procedure '{}' fitness out of range: {}",
                pc.procedure_id,
                pc.fitness
            );
        }
    }

    /// The report serializes to JSON containing its top-level fields.
    #[test]
    fn test_conformance_report_serializes() {
        let (events, bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        let report = analyze_conformance(&events, &bp);

        let json = serde_json::to_string_pretty(&report).unwrap();
        assert!(!json.is_empty());
        let deserialized: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(deserialized.get("fitness").is_some());
        assert!(deserialized.get("precision").is_some());
        assert!(deserialized.get("anomaly_stats").is_some());
        assert!(deserialized.get("per_procedure").is_some());
    }

    /// The generalization score is a fraction and high for the built-in FSM.
    #[test]
    fn test_generalization_score() {
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let bp = bwp.blueprint.clone();
        let overlay = default_overlay();
        let ctx = EngagementContext::demo();
        // Not named `gen`: that identifier is a reserved keyword in Rust 2024.
        let score = compute_generalization(&bwp, &overlay, &bp, 42, &ctx);

        assert!(
            (0.0..=1.0).contains(&score),
            "Generalization should be in [0, 1], got {}",
            score
        );
        assert!(
            score > 0.8,
            "Generalization should be > 0.8 for consistent FSM, got {}",
            score
        );
    }

    /// Feeding the ground-truth labels back as predictions yields no errors.
    #[test]
    fn test_evaluate_detector_perfect() {
        let (events, _bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        let predictions: Vec<bool> = events.iter().map(|e| e.is_anomaly).collect();
        let metrics = evaluate_detector(&events, &predictions).unwrap();

        assert!(
            (metrics.f1 - 1.0).abs() < f64::EPSILON || metrics.true_positives == 0,
            "Perfect detector should have F1=1.0 or no anomalies to detect"
        );
        assert_eq!(metrics.false_positives, 0);
        assert_eq!(metrics.false_negatives, 0);
    }

    /// Flagging every event can never miss an anomaly.
    #[test]
    fn test_evaluate_detector_all_positive() {
        let (events, _bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        let predictions = vec![true; events.len()];
        let metrics = evaluate_detector(&events, &predictions).unwrap();

        assert_eq!(metrics.false_negatives, 0);
        assert!(metrics.recall == 1.0 || metrics.true_positives == 0);
    }

    /// Detector metrics serialize to JSON containing the derived scores.
    #[test]
    fn test_evaluate_detector_serializes() {
        let (events, _bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        let predictions: Vec<bool> = events.iter().map(|e| e.is_anomaly).collect();
        let metrics = evaluate_detector(&events, &predictions).unwrap();

        let json = serde_json::to_string(&metrics).unwrap();
        assert!(json.contains("f1"));
        assert!(json.contains("precision"));
        assert!(json.contains("recall"));
    }
}