Skip to main content

datasynth_audit_optimizer/
resource_optimizer.rs

1//! Resource-constrained plan optimization.
2//!
3//! Given a blueprint, overlay, preconditions, and resource constraints, select
4//! the best-value set of procedures that fits within a budget.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7
8use serde::{Deserialize, Serialize};
9
10use datasynth_audit_fsm::schema::{AuditBlueprint, BlueprintProcedure, GenerationOverlay};
11
12// ---------------------------------------------------------------------------
13// Types
14// ---------------------------------------------------------------------------
15
16/// Constraints that limit which procedures can be included in a plan.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ResourceConstraints {
19    /// Maximum total hours available.
20    pub total_budget_hours: f64,
21    /// Per-role hour availability (unused roles are unconstrained).
22    pub role_availability: HashMap<String, f64>,
23    /// Procedure ids that must be included regardless of budget.
24    pub must_include: Vec<String>,
25    /// Procedure ids that must be excluded.
26    pub must_exclude: Vec<String>,
27}
28
29/// An optimized audit plan produced by [`optimize_plan`].
30#[derive(Debug, Clone, Serialize)]
31pub struct OptimizedPlan {
32    /// Procedure ids included in the plan.
33    pub included_procedures: Vec<String>,
34    /// Procedure ids excluded from the plan.
35    pub excluded_procedures: Vec<String>,
36    /// Total effective hours of included procedures.
37    pub total_hours: f64,
38    /// Total monetary cost of included procedures.
39    pub total_cost: f64,
40    /// Fraction of distinct discriminator values covered (0.0 to 1.0).
41    pub risk_coverage: f64,
42    /// Fraction of distinct standards covered (0.0 to 1.0).
43    pub standards_coverage: f64,
44    /// Hours along the longest precondition chain.
45    pub critical_path_hours: f64,
46    /// Hours per primary role across included procedures.
47    pub role_hours: HashMap<String, f64>,
48}
49
50// ---------------------------------------------------------------------------
51// Public API
52// ---------------------------------------------------------------------------
53
54/// Select an optimal subset of procedures under resource constraints.
55///
56/// # Algorithm
57///
58/// 1. Collect all procedures from the blueprint.
59/// 2. Build a lookup by procedure id.
60/// 3. Expand `must_include` with transitive preconditions (BFS).
61/// 4. Remove `must_exclude` (warn if it is a dependency of a must-include).
62/// 5. If the mandatory set already exceeds the budget, return it as-is.
63/// 6. Score remaining procedures by discriminator coverage per hour and greedily
64///    add while budget allows.
65/// 7. Compute coverage, critical path, and role hours.
66pub fn optimize_plan(
67    blueprint: &AuditBlueprint,
68    overlay: &GenerationOverlay,
69    preconditions: &HashMap<String, Vec<String>>,
70    constraints: &ResourceConstraints,
71) -> OptimizedPlan {
72    let costs = &overlay.resource_costs;
73
74    // ------------------------------------------------------------------
75    // 1. Collect all procedures from all phases.
76    // ------------------------------------------------------------------
77    let all_procs: Vec<&BlueprintProcedure> = blueprint
78        .phases
79        .iter()
80        .flat_map(|phase| phase.procedures.iter())
81        .collect();
82
83    // ------------------------------------------------------------------
84    // 2. Build procedure lookup.
85    // ------------------------------------------------------------------
86    let proc_map: HashMap<&str, &BlueprintProcedure> =
87        all_procs.iter().map(|p| (p.id.as_str(), *p)).collect();
88
89    let all_ids: HashSet<&str> = proc_map.keys().copied().collect();
90
91    // ------------------------------------------------------------------
92    // 3. Expand must_include with transitive preconditions (BFS).
93    // ------------------------------------------------------------------
94    let mut mandatory: HashSet<String> = HashSet::new();
95    let mut queue: VecDeque<String> = VecDeque::new();
96
97    for id in &constraints.must_include {
98        if mandatory.insert(id.clone()) {
99            queue.push_back(id.clone());
100        }
101    }
102
103    while let Some(proc_id) = queue.pop_front() {
104        if let Some(deps) = preconditions.get(&proc_id) {
105            for dep in deps {
106                if mandatory.insert(dep.clone()) {
107                    queue.push_back(dep.clone());
108                }
109            }
110        }
111    }
112
113    // ------------------------------------------------------------------
114    // 4. Remove must_exclude.
115    // ------------------------------------------------------------------
116    let exclude_set: HashSet<&str> = constraints
117        .must_exclude
118        .iter()
119        .map(|s| s.as_str())
120        .collect();
121
122    // Remove excluded from mandatory (with warning — we just proceed).
123    mandatory.retain(|id| !exclude_set.contains(id.as_str()));
124
125    // ------------------------------------------------------------------
126    // 5. Compute hours for mandatory set.
127    // ------------------------------------------------------------------
128    let mandatory_hours: f64 = mandatory
129        .iter()
130        .filter_map(|id| proc_map.get(id.as_str()))
131        .map(|p| costs.effective_hours(p))
132        .sum();
133
134    let mut included: HashSet<String> = mandatory.clone();
135
136    // ------------------------------------------------------------------
137    // 6. If budget not exhausted, score and greedily add remaining.
138    // ------------------------------------------------------------------
139    if mandatory_hours < constraints.total_budget_hours {
140        let mut remaining: Vec<&BlueprintProcedure> = all_procs
141            .iter()
142            .filter(|p| !included.contains(&p.id) && !exclude_set.contains(p.id.as_str()))
143            .copied()
144            .collect();
145
146        // Score: discriminator values / effective hours.
147        remaining.sort_by(|a, b| {
148            let score_a = discriminator_score(a) / costs.effective_hours(a);
149            let score_b = discriminator_score(b) / costs.effective_hours(b);
150            score_b
151                .partial_cmp(&score_a)
152                .unwrap_or(std::cmp::Ordering::Equal)
153        });
154
155        let mut budget_remaining = constraints.total_budget_hours - mandatory_hours;
156        for proc in remaining {
157            let h = costs.effective_hours(proc);
158            if h <= budget_remaining {
159                included.insert(proc.id.clone());
160                budget_remaining -= h;
161            }
162        }
163    }
164
165    // ------------------------------------------------------------------
166    // 7. Compute metrics.
167    // ------------------------------------------------------------------
168    let total_hours: f64 = included
169        .iter()
170        .filter_map(|id| proc_map.get(id.as_str()))
171        .map(|p| costs.effective_hours(p))
172        .sum();
173
174    let total_cost: f64 = included
175        .iter()
176        .filter_map(|id| proc_map.get(id.as_str()))
177        .map(|p| costs.procedure_cost(p))
178        .sum();
179
180    // Standards coverage.
181    let (included_standards, total_standards) = compute_standards_sets(blueprint, &included);
182    let standards_coverage = if total_standards.is_empty() {
183        1.0
184    } else {
185        included_standards.len() as f64 / total_standards.len() as f64
186    };
187
188    // Risk (discriminator) coverage.
189    let (included_disc_values, total_disc_values) =
190        compute_discriminator_sets(blueprint, &included);
191    let risk_coverage = if total_disc_values.is_empty() {
192        1.0
193    } else {
194        included_disc_values.len() as f64 / total_disc_values.len() as f64
195    };
196
197    // Critical path hours.
198    let critical_path_hours =
199        compute_critical_path_hours(&included, &proc_map, preconditions, costs, overlay);
200
201    // Role hours.
202    let mut role_hours: HashMap<String, f64> = HashMap::new();
203    for id in &included {
204        if let Some(proc) = proc_map.get(id.as_str()) {
205            let h = costs.effective_hours(proc);
206            let role = proc
207                .required_roles
208                .first()
209                .cloned()
210                .unwrap_or_else(|| "audit_staff".to_string());
211            *role_hours.entry(role).or_insert(0.0) += h;
212        }
213    }
214
215    // Excluded procedures.
216    let excluded: Vec<String> = all_ids
217        .iter()
218        .filter(|id| !included.contains(**id))
219        .map(|id| id.to_string())
220        .collect();
221
222    let mut included_sorted: Vec<String> = included.into_iter().collect();
223    included_sorted.sort();
224    let mut excluded_sorted = excluded;
225    excluded_sorted.sort();
226
227    OptimizedPlan {
228        included_procedures: included_sorted,
229        excluded_procedures: excluded_sorted,
230        total_hours,
231        total_cost,
232        risk_coverage,
233        standards_coverage,
234        critical_path_hours,
235        role_hours,
236    }
237}
238
239// ---------------------------------------------------------------------------
240// Helpers
241// ---------------------------------------------------------------------------
242
243/// Discriminator score: total number of discriminator values across all categories.
244fn discriminator_score(proc: &BlueprintProcedure) -> f64 {
245    let count: usize = proc.discriminators.values().map(|v| v.len()).sum();
246    // Ensure we never return 0 to avoid NaN in division.
247    (count.max(1)) as f64
248}
249
250/// Collect the set of unique standard ref_ids for included procedures and for
251/// all procedures in the blueprint.
252fn compute_standards_sets(
253    blueprint: &AuditBlueprint,
254    included: &HashSet<String>,
255) -> (HashSet<String>, HashSet<String>) {
256    let mut total = HashSet::new();
257    let mut inc = HashSet::new();
258
259    for phase in &blueprint.phases {
260        for proc in &phase.procedures {
261            for step in &proc.steps {
262                for std_ref in &step.standards {
263                    total.insert(std_ref.ref_id.clone());
264                    if included.contains(&proc.id) {
265                        inc.insert(std_ref.ref_id.clone());
266                    }
267                }
268            }
269        }
270    }
271
272    (inc, total)
273}
274
275/// A set of `(category, value)` discriminator pairs.
276type DiscriminatorSet = HashSet<(String, String)>;
277
278/// Collect the set of unique discriminator `(category, value)` pairs for
279/// included procedures vs all procedures.
280fn compute_discriminator_sets(
281    blueprint: &AuditBlueprint,
282    included: &HashSet<String>,
283) -> (DiscriminatorSet, DiscriminatorSet) {
284    let mut total = HashSet::new();
285    let mut inc = HashSet::new();
286
287    for phase in &blueprint.phases {
288        for proc in &phase.procedures {
289            for (cat, vals) in &proc.discriminators {
290                for v in vals {
291                    total.insert((cat.clone(), v.clone()));
292                    if included.contains(&proc.id) {
293                        inc.insert((cat.clone(), v.clone()));
294                    }
295                }
296            }
297        }
298    }
299
300    (inc, total)
301}
302
303/// Compute the critical path hours: the longest chain of preconditions
304/// measured in effective hours.
305fn compute_critical_path_hours(
306    included: &HashSet<String>,
307    proc_map: &HashMap<&str, &BlueprintProcedure>,
308    preconditions: &HashMap<String, Vec<String>>,
309    costs: &datasynth_audit_fsm::schema::ResourceCosts,
310    _overlay: &GenerationOverlay,
311) -> f64 {
312    // Memoised DFS: for each procedure, compute the maximum total hours from
313    // root of the precondition chain to that procedure.
314    let mut memo: HashMap<String, f64> = HashMap::new();
315
316    fn dfs(
317        id: &str,
318        included: &HashSet<String>,
319        proc_map: &HashMap<&str, &BlueprintProcedure>,
320        preconditions: &HashMap<String, Vec<String>>,
321        costs: &datasynth_audit_fsm::schema::ResourceCosts,
322        memo: &mut HashMap<String, f64>,
323    ) -> f64 {
324        if let Some(&cached) = memo.get(id) {
325            return cached;
326        }
327        let self_hours = proc_map
328            .get(id)
329            .map(|p| costs.effective_hours(p))
330            .unwrap_or(0.0);
331
332        let max_pred = preconditions
333            .get(id)
334            .map(|deps| {
335                deps.iter()
336                    .filter(|d| included.contains(d.as_str()))
337                    .map(|d| dfs(d, included, proc_map, preconditions, costs, memo))
338                    .fold(0.0_f64, f64::max)
339            })
340            .unwrap_or(0.0);
341
342        let total = self_hours + max_pred;
343        memo.insert(id.to_string(), total);
344        total
345    }
346
347    included
348        .iter()
349        .map(|id| dfs(id, included, proc_map, preconditions, costs, &mut memo))
350        .fold(0.0_f64, f64::max)
351}
352
353// ---------------------------------------------------------------------------
354// Tests
355// ---------------------------------------------------------------------------
356
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_audit_fsm::loader::BlueprintWithPreconditions;

    /// Load the built-in FSA blueprint together with its precondition map.
    fn load_fsa() -> BlueprintWithPreconditions {
        BlueprintWithPreconditions::load_builtin_fsa().expect("builtin FSA blueprint should load")
    }

    /// Build constraints with the given budget and id lists and no role limits.
    fn constraints_with(
        total_budget_hours: f64,
        must_include: &[&str],
        must_exclude: &[&str],
    ) -> ResourceConstraints {
        fn owned(ids: &[&str]) -> Vec<String> {
            ids.iter().map(|id| id.to_string()).collect()
        }

        ResourceConstraints {
            total_budget_hours,
            role_availability: HashMap::new(),
            must_include: owned(must_include),
            must_exclude: owned(must_exclude),
        }
    }

    /// Run the optimizer on `bwp` with a default overlay.
    fn plan_for(bwp: &BlueprintWithPreconditions, constraints: &ResourceConstraints) -> OptimizedPlan {
        let overlay = GenerationOverlay::default();
        optimize_plan(&bwp.blueprint, &overlay, &bwp.preconditions, constraints)
    }

    /// Whether `ids` contains `wanted`.
    fn has_id(ids: &[String], wanted: &str) -> bool {
        ids.iter().any(|id| id == wanted)
    }

    #[test]
    fn test_must_include_always_present() {
        let bwp = load_fsa();
        let plan = plan_for(&bwp, &constraints_with(1000.0, &["form_opinion"], &[]));

        assert!(
            has_id(&plan.included_procedures, "form_opinion"),
            "form_opinion must be included"
        );
        // The transitive preconditions have to come along as well.
        assert!(
            has_id(&plan.included_procedures, "going_concern"),
            "going_concern (transitive dep) must be included"
        );
        assert!(
            has_id(&plan.included_procedures, "subsequent_events"),
            "subsequent_events (transitive dep) must be included"
        );
    }

    #[test]
    fn test_budget_constrains_selection() {
        let bwp = load_fsa();

        // A five-hour budget is far too small for the whole blueprint.
        let plan = plan_for(&bwp, &constraints_with(5.0, &[], &[]));

        assert!(
            plan.total_hours <= 5.0,
            "total hours {} should not exceed budget 5.0",
            plan.total_hours
        );

        // Not every procedure of the blueprint can fit into that budget.
        let total_proc_count = bwp
            .blueprint
            .phases
            .iter()
            .fold(0_usize, |count, phase| count + phase.procedures.len());
        assert!(
            plan.included_procedures.len() < total_proc_count,
            "tight budget should exclude some procedures"
        );
    }

    #[test]
    fn test_must_exclude_removed() {
        let bwp = load_fsa();
        let plan = plan_for(
            &bwp,
            &constraints_with(1000.0, &[], &["analytical_procedures"]),
        );

        assert!(
            !has_id(&plan.included_procedures, "analytical_procedures"),
            "analytical_procedures must be excluded"
        );
        assert!(
            has_id(&plan.excluded_procedures, "analytical_procedures"),
            "analytical_procedures must appear in excluded list"
        );
    }

    #[test]
    fn test_critical_path_computed() {
        let bwp = load_fsa();
        let plan = plan_for(&bwp, &constraints_with(1000.0, &[], &[]));

        assert!(
            plan.critical_path_hours > 0.0,
            "critical path must be > 0 when procedures are included"
        );
        assert!(
            plan.critical_path_hours <= plan.total_hours,
            "critical path {} should not exceed total hours {}",
            plan.critical_path_hours,
            plan.total_hours
        );
    }

    #[test]
    fn test_optimized_plan_serializes() {
        let bwp = load_fsa();
        let plan = plan_for(&bwp, &constraints_with(1000.0, &["form_opinion"], &[]));

        let json = serde_json::to_string(&plan).expect("should serialize to JSON");
        for field in [
            "included_procedures",
            "total_hours",
            "risk_coverage",
            "standards_coverage",
            "critical_path_hours",
            "role_hours",
        ] {
            assert!(json.contains(field));
        }
    }
}