Skip to main content

datasynth_audit_optimizer/
overlay_fitting.rs

//! Overlay parameter fitting from target engagement metrics.
//!
//! Given an [`EngagementProfile`] describing desired engagement characteristics
//! (duration, event count, finding count, revision rate, anomaly rate, completion
//! rate), this module iteratively adjusts [`GenerationOverlay`] parameters until
//! the mean metrics of Monte Carlo simulations come within a fixed 5% normalized
//! residual of the targets, or the iteration budget is exhausted. Duration,
//! revision rate and anomaly rate are steered directly; the remaining targets
//! only contribute to the residual.
7
8use rand::SeedableRng;
9use rand_chacha::ChaCha8Rng;
10use serde::Serialize;
11
12use datasynth_audit_fsm::{
13    context::EngagementContext,
14    engine::AuditFsmEngine,
15    loader::{default_overlay, BlueprintWithPreconditions},
16    schema::GenerationOverlay,
17};
18
19// ---------------------------------------------------------------------------
20// Target profile
21// ---------------------------------------------------------------------------
22
/// Desired engagement characteristics that the fitting algorithm targets.
///
/// All six targets are scored by the fitting residual (mean relative error),
/// but only duration, revision rate and anomaly rate drive an overlay
/// adjustment. Event count, finding count and completion rate are matched
/// only insofar as those adjustments happen to move them.
///
/// A target of exactly zero is divided by a `1e-6` floor when scoring, so any
/// nonzero achieved value for that metric will dominate the residual.
#[derive(Debug, Clone, PartialEq)]
pub struct EngagementProfile {
    /// Target average engagement duration in hours.
    pub target_duration_hours: f64,
    /// Target average event count per engagement.
    pub target_event_count: usize,
    /// Target average finding count per engagement.
    pub target_finding_count: usize,
    /// Target revision rate: `under_review -> in_progress` transitions as a
    /// fraction of all events in an engagement's event log.
    pub target_revision_rate: f64,
    /// Target anomaly rate (fraction of events flagged as anomalies).
    pub target_anomaly_rate: f64,
    /// Target completion rate (fraction of procedures whose final state is
    /// `completed` or `closed`).
    pub target_completion_rate: f64,
}
39
40// ---------------------------------------------------------------------------
41// Achieved metrics
42// ---------------------------------------------------------------------------
43
44/// Mean metrics observed from Monte Carlo simulation runs.
45#[derive(Debug, Clone, Serialize)]
46pub struct EngagementMetrics {
47    /// Average engagement duration in hours.
48    pub avg_duration_hours: f64,
49    /// Average event count per engagement.
50    pub avg_event_count: f64,
51    /// Average finding count per engagement.
52    pub avg_finding_count: f64,
53    /// Average revision rate (revision transitions / total events).
54    pub avg_revision_rate: f64,
55    /// Average anomaly rate (anomaly events / total events).
56    pub avg_anomaly_rate: f64,
57    /// Average completion rate (completed procedures / total procedures).
58    pub avg_completion_rate: f64,
59}
60
61// ---------------------------------------------------------------------------
62// Fitted result
63// ---------------------------------------------------------------------------
64
65/// The result of an overlay fitting run.
66#[derive(Debug, Clone, Serialize)]
67pub struct FittedOverlay {
68    /// The adjusted overlay parameters.
69    pub overlay: GenerationOverlay,
70    /// Metrics achieved with the fitted overlay.
71    pub achieved_metrics: EngagementMetrics,
72    /// Number of fitting iterations executed.
73    pub iterations: usize,
74    /// Whether the algorithm converged (residual < threshold).
75    pub converged: bool,
76    /// Final normalized residual distance to target.
77    pub residual: f64,
78}
79
80// ---------------------------------------------------------------------------
81// Core fitting algorithm
82// ---------------------------------------------------------------------------
83
84/// Iteratively adjust overlay parameters until Monte Carlo simulations match
85/// the target engagement profile.
86///
87/// # Algorithm
88///
89/// 1. Start with [`default_overlay()`].
90/// 2. For each iteration:
91///    a. Run `samples_per_iteration` engagements with deterministic seeds.
92///    b. Compute mean metrics across samples.
93///    c. Compute normalized residual distance to target.
94///    d. If residual < 0.05 (5%), stop (converged).
95///    e. Adjust overlay parameters proportionally toward targets with clamping.
96/// 3. Return the fitted overlay with final metrics.
97///
98/// # Arguments
99///
100/// * `bwp` - Validated blueprint with preconditions.
101/// * `profile` - Target engagement profile to fit toward.
102/// * `max_iterations` - Maximum fitting iterations (recommended: 10-20).
103/// * `samples_per_iteration` - Monte Carlo runs per evaluation (recommended: 3-5).
104/// * `base_seed` - Base RNG seed for reproducibility.
105pub fn fit_overlay(
106    bwp: &BlueprintWithPreconditions,
107    profile: &EngagementProfile,
108    max_iterations: usize,
109    samples_per_iteration: usize,
110    base_seed: u64,
111    context: &EngagementContext,
112) -> FittedOverlay {
113    assert!(max_iterations >= 1, "max_iterations must be >= 1");
114    assert!(
115        samples_per_iteration >= 1,
116        "samples_per_iteration must be >= 1"
117    );
118
119    let mut overlay = default_overlay();
120    let mut best_residual = f64::MAX;
121    let mut best_overlay = overlay.clone();
122    let mut best_metrics =
123        compute_metrics(bwp, &overlay, samples_per_iteration, base_seed, 0, context);
124    let mut iterations_used = 0;
125
126    for iter in 0..max_iterations {
127        iterations_used = iter + 1;
128
129        let metrics = compute_metrics(
130            bwp,
131            &overlay,
132            samples_per_iteration,
133            base_seed,
134            iter as u64 * samples_per_iteration as u64,
135            context,
136        );
137        let residual = compute_residual(&metrics, profile);
138
139        if residual < best_residual {
140            best_residual = residual;
141            best_overlay = overlay.clone();
142            best_metrics = metrics.clone();
143        }
144
145        // Converged — within 5% of target.
146        if residual < 0.05 {
147            return FittedOverlay {
148                overlay: best_overlay,
149                achieved_metrics: best_metrics,
150                iterations: iterations_used,
151                converged: true,
152                residual: best_residual,
153            };
154        }
155
156        // Adjust overlay parameters proportionally.
157        adjust_overlay(&mut overlay, &metrics, profile);
158    }
159
160    FittedOverlay {
161        overlay: best_overlay,
162        achieved_metrics: best_metrics,
163        iterations: iterations_used,
164        converged: best_residual < 0.05,
165        residual: best_residual,
166    }
167}
168
169// ---------------------------------------------------------------------------
170// Metrics computation
171// ---------------------------------------------------------------------------
172
173/// Run `samples` engagements and compute mean metrics.
174fn compute_metrics(
175    bwp: &BlueprintWithPreconditions,
176    overlay: &GenerationOverlay,
177    samples: usize,
178    base_seed: u64,
179    seed_offset: u64,
180    context: &EngagementContext,
181) -> EngagementMetrics {
182    let mut total_duration = 0.0;
183    let mut total_events = 0.0;
184    let mut total_findings = 0.0;
185    let mut total_revision_rate = 0.0;
186    let mut total_anomaly_rate = 0.0;
187    let mut total_completion_rate = 0.0;
188    let mut successful_runs = 0usize;
189
190    for i in 0..samples {
191        let iter_seed = base_seed.wrapping_add(seed_offset).wrapping_add(i as u64);
192        let rng = ChaCha8Rng::seed_from_u64(iter_seed);
193        let mut engine = AuditFsmEngine::new(bwp.clone(), overlay.clone(), rng);
194
195        let result = match engine.run_engagement(context) {
196            Ok(r) => r,
197            Err(_) => continue,
198        };
199
200        successful_runs += 1;
201
202        let event_count = result.event_log.len();
203        total_duration += result.total_duration_hours;
204        total_events += event_count as f64;
205        total_findings += result.artifacts.findings.len() as f64;
206
207        // Revision rate: events where under_review -> in_progress divided by total events.
208        let revision_count = result
209            .event_log
210            .iter()
211            .filter(|e| {
212                e.from_state.as_deref() == Some("under_review")
213                    && e.to_state.as_deref() == Some("in_progress")
214            })
215            .count();
216        total_revision_rate += if event_count > 0 {
217            revision_count as f64 / event_count as f64
218        } else {
219            0.0
220        };
221
222        // Anomaly rate: events with is_anomaly == true divided by total events.
223        let anomaly_count = result.event_log.iter().filter(|e| e.is_anomaly).count();
224        total_anomaly_rate += if event_count > 0 {
225            anomaly_count as f64 / event_count as f64
226        } else {
227            0.0
228        };
229
230        // Completion rate: procedures in completed or closed / total procedures.
231        let total_procs = result.procedure_states.len();
232        let completed_procs = result
233            .procedure_states
234            .values()
235            .filter(|s| s.as_str() == "completed" || s.as_str() == "closed")
236            .count();
237        total_completion_rate += if total_procs > 0 {
238            completed_procs as f64 / total_procs as f64
239        } else {
240            0.0
241        };
242    }
243
244    let n = successful_runs.max(1) as f64;
245
246    EngagementMetrics {
247        avg_duration_hours: total_duration / n,
248        avg_event_count: total_events / n,
249        avg_finding_count: total_findings / n,
250        avg_revision_rate: total_revision_rate / n,
251        avg_anomaly_rate: total_anomaly_rate / n,
252        avg_completion_rate: total_completion_rate / n,
253    }
254}
255
256// ---------------------------------------------------------------------------
257// Residual computation
258// ---------------------------------------------------------------------------
259
260/// Compute normalized distance between achieved metrics and target profile.
261///
262/// Each metric contributes equally (1/6 weight) via its relative error:
263/// `|achieved - target| / max(target, epsilon)`.
264fn compute_residual(metrics: &EngagementMetrics, profile: &EngagementProfile) -> f64 {
265    let eps = 1e-6;
266    let n_metrics = 6.0;
267
268    let dur_err = (metrics.avg_duration_hours - profile.target_duration_hours).abs()
269        / profile.target_duration_hours.max(eps);
270    let evt_err = (metrics.avg_event_count - profile.target_event_count as f64).abs()
271        / (profile.target_event_count as f64).max(eps);
272    let find_err = (metrics.avg_finding_count - profile.target_finding_count as f64).abs()
273        / (profile.target_finding_count as f64).max(eps);
274    let rev_err = (metrics.avg_revision_rate - profile.target_revision_rate).abs()
275        / profile.target_revision_rate.max(eps);
276    let anom_err = (metrics.avg_anomaly_rate - profile.target_anomaly_rate).abs()
277        / profile.target_anomaly_rate.max(eps);
278    let comp_err = (metrics.avg_completion_rate - profile.target_completion_rate).abs()
279        / profile.target_completion_rate.max(eps);
280
281    (dur_err + evt_err + find_err + rev_err + anom_err + comp_err) / n_metrics
282}
283
284// ---------------------------------------------------------------------------
285// Overlay adjustment
286// ---------------------------------------------------------------------------
287
288/// Adjust overlay parameters proportionally toward the target profile.
289///
290/// Each parameter is scaled by `target / achieved` (clamped to [0.5x, 2.0x]
291/// per step to prevent oscillation) and then clamped to sane absolute ranges.
292fn adjust_overlay(
293    overlay: &mut GenerationOverlay,
294    metrics: &EngagementMetrics,
295    profile: &EngagementProfile,
296) {
297    let eps = 1e-6;
298
299    // --- Duration: adjust timing.mu_hours ---
300    let duration_ratio =
301        clamp_ratio(profile.target_duration_hours / metrics.avg_duration_hours.max(eps));
302    overlay.transitions.defaults.timing.mu_hours *= duration_ratio;
303    // Also adjust sigma proportionally to keep the distribution shape.
304    overlay.transitions.defaults.timing.sigma_hours *= duration_ratio;
305    // Clamp mu_hours to [0.5, 5000.0] hours.
306    overlay.transitions.defaults.timing.mu_hours = overlay
307        .transitions
308        .defaults
309        .timing
310        .mu_hours
311        .clamp(0.5, 5000.0);
312    overlay.transitions.defaults.timing.sigma_hours = overlay
313        .transitions
314        .defaults
315        .timing
316        .sigma_hours
317        .clamp(0.1, 2000.0);
318
319    // --- Revision rate: adjust revision_probability ---
320    let revision_ratio =
321        clamp_ratio(profile.target_revision_rate / metrics.avg_revision_rate.max(eps));
322    overlay.transitions.defaults.revision_probability *= revision_ratio;
323    overlay.transitions.defaults.revision_probability = overlay
324        .transitions
325        .defaults
326        .revision_probability
327        .clamp(0.01, 0.5);
328
329    // --- Anomaly rates: scale all anomaly probabilities ---
330    let anomaly_ratio =
331        clamp_ratio(profile.target_anomaly_rate / metrics.avg_anomaly_rate.max(eps));
332    overlay.anomalies.skipped_approval =
333        (overlay.anomalies.skipped_approval * anomaly_ratio).clamp(0.0, 0.5);
334    overlay.anomalies.late_posting =
335        (overlay.anomalies.late_posting * anomaly_ratio).clamp(0.0, 0.5);
336    overlay.anomalies.missing_evidence =
337        (overlay.anomalies.missing_evidence * anomaly_ratio).clamp(0.0, 0.5);
338    overlay.anomalies.out_of_sequence =
339        (overlay.anomalies.out_of_sequence * anomaly_ratio).clamp(0.0, 0.5);
340}
341
/// Clamp a multiplicative correction factor to `[0.5, 2.0]`.
///
/// Limiting each fitting step to at most halving or doubling a parameter keeps
/// the iteration from oscillating wildly.
///
/// A `NaN` ratio (e.g. produced by a `NaN` target) carries no usable direction,
/// so it maps to `1.0` and leaves the parameter unchanged. Without this guard,
/// `f64::clamp` would return `NaN` and poison the overlay.
fn clamp_ratio(ratio: f64) -> f64 {
    const MIN_STEP: f64 = 0.5;
    const MAX_STEP: f64 = 2.0;

    if ratio.is_nan() {
        1.0
    } else {
        ratio.clamp(MIN_STEP, MAX_STEP)
    }
}
346
347// ---------------------------------------------------------------------------
348// Tests
349// ---------------------------------------------------------------------------
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Load the built-in FSA blueprint shared by every test.
    fn load_fsa() -> BlueprintWithPreconditions {
        BlueprintWithPreconditions::load_builtin_fsa().expect("builtin FSA blueprint should load")
    }

    /// Reference profile; each test overrides only the targets it probes.
    fn baseline_profile() -> EngagementProfile {
        EngagementProfile {
            target_duration_hours: 800.0,
            target_event_count: 50,
            target_finding_count: 5,
            target_revision_rate: 0.15,
            target_anomaly_rate: 0.05,
            target_completion_rate: 1.0,
        }
    }

    /// Fit against the demo context with the fixed seed 42 used throughout.
    fn fit(
        bwp: &BlueprintWithPreconditions,
        profile: &EngagementProfile,
        max_iterations: usize,
        samples_per_iteration: usize,
    ) -> FittedOverlay {
        fit_overlay(
            bwp,
            profile,
            max_iterations,
            samples_per_iteration,
            42,
            &EngagementContext::demo(),
        )
    }

    #[test]
    fn test_fit_overlay_converges_to_target_duration() {
        // A long target duration (2000h) should push mu_hours upward.
        let bwp = load_fsa();
        let profile = EngagementProfile {
            target_duration_hours: 2000.0,
            ..baseline_profile()
        };
        let fitted = fit(&bwp, &profile, 15, 3);
        // The default overlay lands around 800h, so >1000h shows real movement.
        let achieved = fitted.achieved_metrics.avg_duration_hours;
        assert!(
            achieved > 1000.0,
            "Fitted duration {:.0} should approach target 2000",
            achieved
        );
    }

    #[test]
    fn test_fit_overlay_adjusts_anomaly_rate() {
        // A high anomaly target should scale the anomaly probabilities up.
        let bwp = load_fsa();
        let profile = EngagementProfile {
            target_anomaly_rate: 0.20,
            ..baseline_profile()
        };
        let fitted = fit(&bwp, &profile, 15, 3);
        let achieved = fitted.achieved_metrics.avg_anomaly_rate;
        assert!(
            achieved > 0.05,
            "Anomaly rate {:.3} should increase toward target 0.20",
            achieved
        );
    }

    #[test]
    fn test_fit_overlay_returns_valid_overlay() {
        let bwp = load_fsa();
        let profile = EngagementProfile {
            target_finding_count: 3,
            target_revision_rate: 0.10,
            ..baseline_profile()
        };
        let fitted = fit(&bwp, &profile, 10, 3);
        // Every fitted parameter must sit inside its sane range.
        assert!(fitted.overlay.transitions.defaults.revision_probability >= 0.0);
        assert!(fitted.overlay.transitions.defaults.revision_probability <= 0.5);
        assert!(fitted.overlay.transitions.defaults.timing.mu_hours > 0.0);
    }

    #[test]
    fn test_fit_overlay_serializes() {
        let bwp = load_fsa();
        let profile = EngagementProfile {
            target_finding_count: 3,
            target_revision_rate: 0.10,
            ..baseline_profile()
        };
        let json = serde_json::to_string(&fit(&bwp, &profile, 5, 2)).unwrap();
        assert!(json.contains("converged"));
        assert!(json.contains("residual"));
    }

    #[test]
    fn test_fit_overlay_deterministic() {
        // Same blueprint, profile and seed must reproduce the same fit.
        let bwp = load_fsa();
        let profile = EngagementProfile {
            target_duration_hours: 1200.0,
            ..baseline_profile()
        };
        let (f1, f2) = (fit(&bwp, &profile, 5, 2), fit(&bwp, &profile, 5, 2));
        assert_eq!(f1.iterations, f2.iterations);
        assert!(
            (f1.residual - f2.residual).abs() < 0.001,
            "Residuals should match: {} vs {}",
            f1.residual,
            f2.residual
        );
    }
}