// datasynth_audit_optimizer/conformance.rs
//! Conformance metrics for audit event trails against blueprints.
//!
//! Computes fitness (fraction of observed transitions that are valid per the
//! blueprint), precision (fraction of defined transitions that were observed),
//! and anomaly statistics.

7use std::collections::{HashMap, HashSet};
8
9use datasynth_audit_fsm::context::EngagementContext;
10use datasynth_audit_fsm::engine::AuditFsmEngine;
11use datasynth_audit_fsm::event::AuditEvent;
12use datasynth_audit_fsm::loader::BlueprintWithPreconditions;
13use datasynth_audit_fsm::schema::{AuditBlueprint, GenerationOverlay};
14use rand::SeedableRng;
15use rand_chacha::ChaCha8Rng;
16use serde::Serialize;
17
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
21
/// Full conformance report for an event trail against a blueprint.
///
/// Produced by [`analyze_conformance`]. The `generalization` field is left as
/// `None` there; its value can be obtained separately from
/// [`compute_generalization`].
#[derive(Debug, Clone, Serialize)]
pub struct ConformanceReport {
    /// Fraction of observed transition events that match a defined transition.
    pub fitness: f64,
    /// Fraction of defined transitions that were observed in the event trail.
    pub precision: f64,
    /// Generalization score in `[0, 1]`. High values indicate the blueprint
    /// produces consistent fitness across different seeds (low variance).
    /// `None` if not computed (requires `compute_generalization`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generalization: Option<f64>,
    /// Anomaly statistics.
    pub anomaly_stats: AnomalyStats,
    /// Per-procedure conformance breakdown.
    pub per_procedure: Vec<ProcedureConformance>,
}
39
/// Metrics for evaluating an external anomaly detector against ground-truth labels.
///
/// Produced by [`evaluate_detector`]. When a ratio's denominator is zero,
/// the corresponding metric is reported as `0.0` rather than NaN.
#[derive(Debug, Clone, Serialize)]
pub struct AnomalyDetectionMetrics {
    /// Events correctly identified as anomalies.
    pub true_positives: usize,
    /// Events incorrectly identified as anomalies.
    pub false_positives: usize,
    /// Anomaly events missed by the detector.
    pub false_negatives: usize,
    /// True negatives: correctly identified normal events.
    pub true_negatives: usize,
    /// Precision = TP / (TP + FP). `0.0` if TP + FP == 0.
    pub precision: f64,
    /// Recall = TP / (TP + FN). `0.0` if TP + FN == 0.
    pub recall: f64,
    /// F1 = 2 * precision * recall / (precision + recall). `0.0` if both are 0.
    pub f1: f64,
}
58
/// Summary statistics about anomalies in the event trail.
#[derive(Debug, Clone, Serialize)]
pub struct AnomalyStats {
    /// Total events in the trail.
    pub total_events: usize,
    /// Number of events flagged as anomalies.
    pub anomaly_events: usize,
    /// Anomaly rate (anomaly_events / total_events); `0.0` for an empty trail.
    pub anomaly_rate: f64,
    /// Anomaly counts by type. Events flagged anomalous but lacking an
    /// `anomaly_type` are counted under the key `"unknown"`.
    pub by_type: HashMap<String, usize>,
}
71
/// Conformance metrics for a single procedure.
#[derive(Debug, Clone, Serialize)]
pub struct ProcedureConformance {
    /// Procedure identifier.
    pub procedure_id: String,
    /// Fraction of this procedure's observed transitions that are valid.
    /// Reported as `1.0` when no transitions were observed (vacuously fit).
    pub fitness: f64,
    /// Number of transition events observed for this procedure.
    pub transitions_observed: usize,
    /// Number of transitions defined for this procedure in the blueprint.
    pub transitions_defined: usize,
}
84
// ---------------------------------------------------------------------------
// Analysis
// ---------------------------------------------------------------------------
88
89/// Analyze conformance of an event trail against a blueprint.
90///
91/// - **Fitness**: For each event that has both `from_state` and `to_state`,
92///   checks whether `(from_state, to_state)` exists in the corresponding
93///   procedure's aggregate transitions. `fitness = valid / total`.
94///
95/// - **Precision**: Counts the unique `(procedure_id, from_state, to_state)`
96///   triples observed in the event trail, divided by the total number of
97///   transitions defined across all procedures in the blueprint.
98///
99/// - **Anomaly stats**: Counts events with `is_anomaly == true`, grouped by
100///   `anomaly_type`.
101///
102/// - **Per-procedure**: Computes fitness for each procedure independently.
103pub fn analyze_conformance(events: &[AuditEvent], blueprint: &AuditBlueprint) -> ConformanceReport {
104    // Build a lookup: procedure_id -> set of (from_state, to_state).
105    let mut defined_transitions: HashMap<String, HashSet<(String, String)>> = HashMap::new();
106    let mut total_defined = 0usize;
107
108    for phase in &blueprint.phases {
109        for proc in &phase.procedures {
110            let pairs: HashSet<(String, String)> = proc
111                .aggregate
112                .transitions
113                .iter()
114                .map(|t| (t.from_state.clone(), t.to_state.clone()))
115                .collect();
116            total_defined += pairs.len();
117            defined_transitions.insert(proc.id.clone(), pairs);
118        }
119    }
120
121    // Traverse events, computing fitness and precision.
122    let mut global_valid = 0usize;
123    let mut global_total = 0usize;
124    let mut observed_triples: HashSet<(String, String, String)> = HashSet::new();
125
126    // Per-procedure accumulators: (valid, total).
127    let mut proc_accum: HashMap<String, (usize, usize)> = HashMap::new();
128
129    // Anomaly tracking.
130    let mut anomaly_events = 0usize;
131    let mut anomaly_by_type: HashMap<String, usize> = HashMap::new();
132
133    for event in events {
134        // Anomaly stats.
135        if event.is_anomaly {
136            anomaly_events += 1;
137            let type_str = event
138                .anomaly_type
139                .as_ref()
140                .map(|t| t.to_string())
141                .unwrap_or_else(|| "unknown".to_string());
142            *anomaly_by_type.entry(type_str).or_default() += 1;
143        }
144
145        // Fitness: only consider events with both from_state and to_state.
146        if let (Some(ref from), Some(ref to)) = (&event.from_state, &event.to_state) {
147            global_total += 1;
148            let entry = proc_accum.entry(event.procedure_id.clone()).or_default();
149            entry.1 += 1;
150
151            let is_valid = defined_transitions
152                .get(&event.procedure_id)
153                .map(|pairs| pairs.contains(&(from.clone(), to.clone())))
154                .unwrap_or(false);
155
156            if is_valid {
157                global_valid += 1;
158                entry.0 += 1;
159            }
160
161            // Track observed triple for precision.
162            observed_triples.insert((event.procedure_id.clone(), from.clone(), to.clone()));
163        }
164    }
165
166    let fitness = if global_total > 0 {
167        global_valid as f64 / global_total as f64
168    } else {
169        1.0
170    };
171
172    let precision = if total_defined > 0 {
173        observed_triples.len() as f64 / total_defined as f64
174    } else {
175        0.0
176    };
177
178    let anomaly_rate = if events.is_empty() {
179        0.0
180    } else {
181        anomaly_events as f64 / events.len() as f64
182    };
183
184    let anomaly_stats = AnomalyStats {
185        total_events: events.len(),
186        anomaly_events,
187        anomaly_rate,
188        by_type: anomaly_by_type,
189    };
190
191    // Build per-procedure conformance.
192    let mut per_procedure: Vec<ProcedureConformance> = Vec::new();
193    // Include all procedures from the blueprint, even if they had no events.
194    for phase in &blueprint.phases {
195        for proc in &phase.procedures {
196            let (valid, total) = proc_accum.get(&proc.id).copied().unwrap_or((0, 0));
197            let proc_fitness = if total > 0 {
198                valid as f64 / total as f64
199            } else {
200                1.0
201            };
202            let transitions_defined = defined_transitions
203                .get(&proc.id)
204                .map(|s| s.len())
205                .unwrap_or(0);
206            per_procedure.push(ProcedureConformance {
207                procedure_id: proc.id.clone(),
208                fitness: proc_fitness,
209                transitions_observed: total,
210                transitions_defined,
211            });
212        }
213    }
214
215    ConformanceReport {
216        fitness,
217        precision,
218        generalization: None,
219        anomaly_stats,
220        per_procedure,
221    }
222}
223
// ---------------------------------------------------------------------------
// Generalization
// ---------------------------------------------------------------------------
227
228/// Compute generalization: run the blueprint with 3 different seeds, measure
229/// fitness variance. Low variance = high generalization (score near 1.0).
230///
231/// Generalization = 1.0 - std_dev(fitness values across seeds).
232/// The result is clamped to [0, 1].
233pub fn compute_generalization(
234    bwp: &BlueprintWithPreconditions,
235    overlay: &GenerationOverlay,
236    blueprint: &AuditBlueprint,
237    base_seed: u64,
238    context: &EngagementContext,
239) -> f64 {
240    let seeds = [
241        base_seed,
242        base_seed.wrapping_add(1000),
243        base_seed.wrapping_add(2000),
244    ];
245    let mut fitness_values = Vec::new();
246
247    for seed in &seeds {
248        let rng = ChaCha8Rng::seed_from_u64(*seed);
249        let mut engine = AuditFsmEngine::new(bwp.clone(), overlay.clone(), rng);
250        if let Ok(result) = engine.run_engagement(context) {
251            let report = analyze_conformance(&result.event_log, blueprint);
252            fitness_values.push(report.fitness);
253        }
254    }
255
256    if fitness_values.len() < 2 {
257        return 1.0; // Not enough data; assume perfect generalization.
258    }
259
260    let n = fitness_values.len() as f64;
261    let mean = fitness_values.iter().sum::<f64>() / n;
262    let variance = fitness_values
263        .iter()
264        .map(|f| (f - mean).powi(2))
265        .sum::<f64>()
266        / n;
267    let std_dev = variance.sqrt();
268
269    (1.0 - std_dev).clamp(0.0, 1.0)
270}
271
// ---------------------------------------------------------------------------
// Anomaly Detection Evaluation
// ---------------------------------------------------------------------------
275
276/// Evaluate an external anomaly detector's predictions against ground-truth
277/// labels from the audit event trail.
278///
279/// `events` — the audit event trail with `is_anomaly` ground-truth labels.
280/// `predictions` — one boolean per event: `true` = "detector thinks anomaly".
281///
282/// # Errors
283///
284/// Returns an error if `events.len() != predictions.len()`.
285pub fn evaluate_detector(
286    events: &[AuditEvent],
287    predictions: &[bool],
288) -> Result<AnomalyDetectionMetrics, String> {
289    if events.len() != predictions.len() {
290        return Err(format!(
291            "events and predictions must have the same length ({} vs {})",
292            events.len(),
293            predictions.len()
294        ));
295    }
296
297    let mut tp = 0usize;
298    let mut fp = 0usize;
299    let mut fn_ = 0usize;
300    let mut tn = 0usize;
301
302    for (event, &predicted) in events.iter().zip(predictions.iter()) {
303        match (event.is_anomaly, predicted) {
304            (true, true) => tp += 1,
305            (false, true) => fp += 1,
306            (true, false) => fn_ += 1,
307            (false, false) => tn += 1,
308        }
309    }
310
311    let precision = if tp + fp > 0 {
312        tp as f64 / (tp + fp) as f64
313    } else {
314        0.0
315    };
316    let recall = if tp + fn_ > 0 {
317        tp as f64 / (tp + fn_) as f64
318    } else {
319        0.0
320    };
321    let f1 = if precision + recall > 0.0 {
322        2.0 * precision * recall / (precision + recall)
323    } else {
324        0.0
325    };
326
327    Ok(AnomalyDetectionMetrics {
328        true_positives: tp,
329        false_positives: fp,
330        false_negatives: fn_,
331        true_negatives: tn,
332        precision,
333        recall,
334        f1,
335    })
336}
337
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
341
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_audit_fsm::context::EngagementContext;
    use datasynth_audit_fsm::engine::AuditFsmEngine;
    use datasynth_audit_fsm::loader::{
        default_overlay, load_overlay, BlueprintWithPreconditions, BuiltinOverlay, OverlaySource,
    };
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    // Run a full FSA engagement with the given overlay and seed, returning
    // the event log together with the blueprint used to generate it.
    fn run_fsa_engagement(
        overlay_type: BuiltinOverlay,
        seed: u64,
    ) -> (Vec<AuditEvent>, AuditBlueprint) {
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let overlay = load_overlay(&OverlaySource::Builtin(overlay_type)).unwrap();
        // Keep a copy of the blueprint before the engine consumes `bwp`.
        let bp = bwp.blueprint.clone();
        let rng = ChaCha8Rng::seed_from_u64(seed);
        let mut engine = AuditFsmEngine::new(bwp, overlay, rng);
        let ctx = EngagementContext::demo();
        let result = engine.run_engagement(&ctx).unwrap();
        (result.event_log, bp)
    }

    #[test]
    fn test_conformance_perfect_log() {
        // FSA with zeroed anomalies: all transitions should be valid.
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let bp = bwp.blueprint.clone();
        let mut overlay = default_overlay();
        // Zero out every anomaly channel so the log is guaranteed clean.
        overlay.anomalies.skipped_approval = 0.0;
        overlay.anomalies.late_posting = 0.0;
        overlay.anomalies.missing_evidence = 0.0;
        overlay.anomalies.out_of_sequence = 0.0;
        overlay.anomalies.rules.clear();
        let rng = ChaCha8Rng::seed_from_u64(42);
        let mut engine = AuditFsmEngine::new(bwp, overlay, rng);
        let ctx = EngagementContext::demo();
        let result = engine.run_engagement(&ctx).unwrap();

        let report = analyze_conformance(&result.event_log, &bp);
        assert!(
            (report.fitness - 1.0).abs() < f64::EPSILON,
            "Fitness should be 1.0 for a perfect log, got {}",
            report.fitness
        );
        assert_eq!(report.anomaly_stats.anomaly_events, 0);
    }

    #[test]
    fn test_conformance_with_anomalies() {
        // Rushed overlay has elevated anomaly rates.
        let (events, bp) = run_fsa_engagement(BuiltinOverlay::Rushed, 42);
        let report = analyze_conformance(&events, &bp);

        // Fitness should still be high (anomalies don't create invalid transitions).
        assert!(
            report.fitness > 0.0,
            "Fitness should be > 0, got {}",
            report.fitness
        );
        // With rushed overlay, the anomaly_rate should be captured.
        // (We check the stats are computed, not the exact value.)
        assert!(report.anomaly_stats.total_events > 0, "Should have events");
    }

    #[test]
    fn test_precision_computed() {
        let (events, bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        let report = analyze_conformance(&events, &bp);

        // Precision must be a meaningful fraction of defined transitions.
        assert!(
            report.precision > 0.0,
            "Precision should be > 0, got {}",
            report.precision
        );
        assert!(
            report.precision <= 1.0,
            "Precision should be <= 1.0, got {}",
            report.precision
        );
    }

    #[test]
    fn test_per_procedure_conformance() {
        let (events, bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        let report = analyze_conformance(&events, &bp);

        // Should have a conformance entry for each procedure in the blueprint.
        let total_procedures: usize = bp.phases.iter().map(|p| p.procedures.len()).sum();
        assert_eq!(
            report.per_procedure.len(),
            total_procedures,
            "Expected {} per-procedure entries, got {}",
            total_procedures,
            report.per_procedure.len()
        );

        // Each entry should have reasonable values.
        for pc in &report.per_procedure {
            assert!(
                pc.fitness >= 0.0 && pc.fitness <= 1.0,
                "Procedure '{}' fitness out of range: {}",
                pc.procedure_id,
                pc.fitness
            );
        }
    }

    #[test]
    fn test_conformance_report_serializes() {
        let (events, bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        let report = analyze_conformance(&events, &bp);

        // JSON roundtrip: the report must serialize and expose its top-level keys.
        let json = serde_json::to_string_pretty(&report).unwrap();
        assert!(!json.is_empty());
        let deserialized: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(deserialized.get("fitness").is_some());
        assert!(deserialized.get("precision").is_some());
        assert!(deserialized.get("anomaly_stats").is_some());
        assert!(deserialized.get("per_procedure").is_some());
    }

    #[test]
    fn test_generalization_score() {
        let bwp = BlueprintWithPreconditions::load_builtin_fsa().unwrap();
        let bp = bwp.blueprint.clone();
        let overlay = default_overlay();
        let ctx = EngagementContext::demo();
        let gen = compute_generalization(&bwp, &overlay, &bp, 42, &ctx);

        assert!(
            gen >= 0.0 && gen <= 1.0,
            "Generalization should be in [0, 1], got {}",
            gen
        );
        // With deterministic FSM, fitness should be very consistent across seeds.
        assert!(
            gen > 0.8,
            "Generalization should be > 0.8 for consistent FSM, got {}",
            gen
        );
    }

    #[test]
    fn test_evaluate_detector_perfect() {
        let (events, _bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        // Perfect detector: predictions match ground truth exactly.
        let predictions: Vec<bool> = events.iter().map(|e| e.is_anomaly).collect();
        let metrics = evaluate_detector(&events, &predictions).unwrap();

        // F1 == 1.0 unless the trail happened to contain no anomalies at all.
        assert!(
            (metrics.f1 - 1.0).abs() < f64::EPSILON || metrics.true_positives == 0,
            "Perfect detector should have F1=1.0 or no anomalies to detect"
        );
        assert_eq!(metrics.false_positives, 0);
        assert_eq!(metrics.false_negatives, 0);
    }

    #[test]
    fn test_evaluate_detector_all_positive() {
        let (events, _bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        // Naive detector: predicts everything as anomaly.
        let predictions = vec![true; events.len()];
        let metrics = evaluate_detector(&events, &predictions).unwrap();

        // All actual anomalies found (FN=0) but many false positives.
        assert_eq!(metrics.false_negatives, 0);
        assert!(metrics.recall == 1.0 || metrics.true_positives == 0);
    }

    #[test]
    fn test_evaluate_detector_serializes() {
        let (events, _bp) = run_fsa_engagement(BuiltinOverlay::Default, 42);
        let predictions: Vec<bool> = events.iter().map(|e| e.is_anomaly).collect();
        let metrics = evaluate_detector(&events, &predictions).unwrap();

        // Smoke-check that all three derived metrics appear in the JSON output.
        let json = serde_json::to_string(&metrics).unwrap();
        assert!(json.contains("f1"));
        assert!(json.contains("precision"));
        assert!(json.contains("recall"));
    }
}
526}