datasynth_audit_optimizer/
resource_optimizer.rs1use std::collections::{HashMap, HashSet, VecDeque};
7
8use serde::{Deserialize, Serialize};
9
10use datasynth_audit_fsm::schema::{AuditBlueprint, BlueprintProcedure, GenerationOverlay};
11
/// Caller-supplied limits and hard requirements consumed by [`optimize_plan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceConstraints {
    /// Hour budget for the whole plan. `optimize_plan` adds optional
    /// procedures only while their effective hours fit into what is left of
    /// this budget after the mandatory set has been accounted for.
    pub total_budget_hours: f64,
    /// NOTE(review): presumably the hours available per role name — nothing in
    /// this file reads the field, so confirm the intended semantics with the
    /// callers before relying on it.
    pub role_availability: HashMap<String, f64>,
    /// Procedure ids that must be part of the plan. `optimize_plan` also pulls
    /// in their preconditions transitively.
    pub must_include: Vec<String>,
    /// Procedure ids that must be left out of the plan. Exclusion is applied
    /// after the mandatory set is expanded, so it wins over `must_include`
    /// and over ids pulled in as preconditions.
    pub must_exclude: Vec<String>,
}
28
/// Result of [`optimize_plan`]: the chosen procedures plus summary metrics.
#[derive(Debug, Clone, Serialize)]
pub struct OptimizedPlan {
    /// Ids of the selected procedures, sorted ascending.
    pub included_procedures: Vec<String>,
    /// Ids of blueprint procedures that were not selected, sorted ascending.
    pub excluded_procedures: Vec<String>,
    /// Sum of the effective hours of the selected procedures.
    pub total_hours: f64,
    /// Sum of the procedure costs of the selected procedures.
    pub total_cost: f64,
    /// Fraction of distinct `(category, value)` discriminator pairs in the
    /// blueprint that appear on at least one selected procedure; `1.0` when
    /// the blueprint declares no discriminators.
    pub risk_coverage: f64,
    /// Fraction of distinct standard reference ids in the blueprint that are
    /// referenced by a step of at least one selected procedure; `1.0` when the
    /// blueprint references no standards.
    pub standards_coverage: f64,
    /// Largest accumulated effective-hours total along any precondition chain
    /// made of selected procedures.
    pub critical_path_hours: f64,
    /// Effective hours per role. Each procedure's hours are attributed to its
    /// first required role, or to `"audit_staff"` when it lists none.
    pub role_hours: HashMap<String, f64>,
}
49
50pub fn optimize_plan(
67 blueprint: &AuditBlueprint,
68 overlay: &GenerationOverlay,
69 preconditions: &HashMap<String, Vec<String>>,
70 constraints: &ResourceConstraints,
71) -> OptimizedPlan {
72 let costs = &overlay.resource_costs;
73
74 let all_procs: Vec<&BlueprintProcedure> = blueprint
78 .phases
79 .iter()
80 .flat_map(|phase| phase.procedures.iter())
81 .collect();
82
83 let proc_map: HashMap<&str, &BlueprintProcedure> =
87 all_procs.iter().map(|p| (p.id.as_str(), *p)).collect();
88
89 let all_ids: HashSet<&str> = proc_map.keys().copied().collect();
90
91 let mut mandatory: HashSet<String> = HashSet::new();
95 let mut queue: VecDeque<String> = VecDeque::new();
96
97 for id in &constraints.must_include {
98 if mandatory.insert(id.clone()) {
99 queue.push_back(id.clone());
100 }
101 }
102
103 while let Some(proc_id) = queue.pop_front() {
104 if let Some(deps) = preconditions.get(&proc_id) {
105 for dep in deps {
106 if mandatory.insert(dep.clone()) {
107 queue.push_back(dep.clone());
108 }
109 }
110 }
111 }
112
113 let exclude_set: HashSet<&str> = constraints
117 .must_exclude
118 .iter()
119 .map(|s| s.as_str())
120 .collect();
121
122 mandatory.retain(|id| !exclude_set.contains(id.as_str()));
124
125 let mandatory_hours: f64 = mandatory
129 .iter()
130 .filter_map(|id| proc_map.get(id.as_str()))
131 .map(|p| costs.effective_hours(p))
132 .sum();
133
134 let mut included: HashSet<String> = mandatory.clone();
135
136 if mandatory_hours < constraints.total_budget_hours {
140 let mut remaining: Vec<&BlueprintProcedure> = all_procs
141 .iter()
142 .filter(|p| !included.contains(&p.id) && !exclude_set.contains(p.id.as_str()))
143 .copied()
144 .collect();
145
146 remaining.sort_by(|a, b| {
148 let score_a = discriminator_score(a) / costs.effective_hours(a);
149 let score_b = discriminator_score(b) / costs.effective_hours(b);
150 score_b
151 .partial_cmp(&score_a)
152 .unwrap_or(std::cmp::Ordering::Equal)
153 });
154
155 let mut budget_remaining = constraints.total_budget_hours - mandatory_hours;
156 for proc in remaining {
157 let h = costs.effective_hours(proc);
158 if h <= budget_remaining {
159 included.insert(proc.id.clone());
160 budget_remaining -= h;
161 }
162 }
163 }
164
165 let total_hours: f64 = included
169 .iter()
170 .filter_map(|id| proc_map.get(id.as_str()))
171 .map(|p| costs.effective_hours(p))
172 .sum();
173
174 let total_cost: f64 = included
175 .iter()
176 .filter_map(|id| proc_map.get(id.as_str()))
177 .map(|p| costs.procedure_cost(p))
178 .sum();
179
180 let (included_standards, total_standards) = compute_standards_sets(blueprint, &included);
182 let standards_coverage = if total_standards.is_empty() {
183 1.0
184 } else {
185 included_standards.len() as f64 / total_standards.len() as f64
186 };
187
188 let (included_disc_values, total_disc_values) =
190 compute_discriminator_sets(blueprint, &included);
191 let risk_coverage = if total_disc_values.is_empty() {
192 1.0
193 } else {
194 included_disc_values.len() as f64 / total_disc_values.len() as f64
195 };
196
197 let critical_path_hours =
199 compute_critical_path_hours(&included, &proc_map, preconditions, costs, overlay);
200
201 let mut role_hours: HashMap<String, f64> = HashMap::new();
203 for id in &included {
204 if let Some(proc) = proc_map.get(id.as_str()) {
205 let h = costs.effective_hours(proc);
206 let role = proc
207 .required_roles
208 .first()
209 .cloned()
210 .unwrap_or_else(|| "audit_staff".to_string());
211 *role_hours.entry(role).or_insert(0.0) += h;
212 }
213 }
214
215 let excluded: Vec<String> = all_ids
217 .iter()
218 .filter(|id| !included.contains(**id))
219 .map(|id| id.to_string())
220 .collect();
221
222 let mut included_sorted: Vec<String> = included.into_iter().collect();
223 included_sorted.sort();
224 let mut excluded_sorted = excluded;
225 excluded_sorted.sort();
226
227 OptimizedPlan {
228 included_procedures: included_sorted,
229 excluded_procedures: excluded_sorted,
230 total_hours,
231 total_cost,
232 risk_coverage,
233 standards_coverage,
234 critical_path_hours,
235 role_hours,
236 }
237}
238
239fn discriminator_score(proc: &BlueprintProcedure) -> f64 {
245 let count: usize = proc.discriminators.values().map(|v| v.len()).sum();
246 (count.max(1)) as f64
248}
249
250fn compute_standards_sets(
253 blueprint: &AuditBlueprint,
254 included: &HashSet<String>,
255) -> (HashSet<String>, HashSet<String>) {
256 let mut total = HashSet::new();
257 let mut inc = HashSet::new();
258
259 for phase in &blueprint.phases {
260 for proc in &phase.procedures {
261 for step in &proc.steps {
262 for std_ref in &step.standards {
263 total.insert(std_ref.ref_id.clone());
264 if included.contains(&proc.id) {
265 inc.insert(std_ref.ref_id.clone());
266 }
267 }
268 }
269 }
270 }
271
272 (inc, total)
273}
274
/// Set of distinct `(category, value)` discriminator pairs.
type DiscriminatorSet = HashSet<(String, String)>;
277
278fn compute_discriminator_sets(
281 blueprint: &AuditBlueprint,
282 included: &HashSet<String>,
283) -> (DiscriminatorSet, DiscriminatorSet) {
284 let mut total = HashSet::new();
285 let mut inc = HashSet::new();
286
287 for phase in &blueprint.phases {
288 for proc in &phase.procedures {
289 for (cat, vals) in &proc.discriminators {
290 for v in vals {
291 total.insert((cat.clone(), v.clone()));
292 if included.contains(&proc.id) {
293 inc.insert((cat.clone(), v.clone()));
294 }
295 }
296 }
297 }
298 }
299
300 (inc, total)
301}
302
303fn compute_critical_path_hours(
306 included: &HashSet<String>,
307 proc_map: &HashMap<&str, &BlueprintProcedure>,
308 preconditions: &HashMap<String, Vec<String>>,
309 costs: &datasynth_audit_fsm::schema::ResourceCosts,
310 _overlay: &GenerationOverlay,
311) -> f64 {
312 let mut memo: HashMap<String, f64> = HashMap::new();
315
316 fn dfs(
317 id: &str,
318 included: &HashSet<String>,
319 proc_map: &HashMap<&str, &BlueprintProcedure>,
320 preconditions: &HashMap<String, Vec<String>>,
321 costs: &datasynth_audit_fsm::schema::ResourceCosts,
322 memo: &mut HashMap<String, f64>,
323 ) -> f64 {
324 if let Some(&cached) = memo.get(id) {
325 return cached;
326 }
327 let self_hours = proc_map
328 .get(id)
329 .map(|p| costs.effective_hours(p))
330 .unwrap_or(0.0);
331
332 let max_pred = preconditions
333 .get(id)
334 .map(|deps| {
335 deps.iter()
336 .filter(|d| included.contains(d.as_str()))
337 .map(|d| dfs(d, included, proc_map, preconditions, costs, memo))
338 .fold(0.0_f64, f64::max)
339 })
340 .unwrap_or(0.0);
341
342 let total = self_hours + max_pred;
343 memo.insert(id.to_string(), total);
344 total
345 }
346
347 included
348 .iter()
349 .map(|id| dfs(id, included, proc_map, preconditions, costs, &mut memo))
350 .fold(0.0_f64, f64::max)
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_audit_fsm::loader::BlueprintWithPreconditions;

    /// Loads the built-in FSA blueprint together with its precondition map.
    fn load_fsa() -> BlueprintWithPreconditions {
        BlueprintWithPreconditions::load_builtin_fsa().expect("builtin FSA blueprint should load")
    }

    /// Builds constraints with the given budget and include / exclude lists
    /// and no role availability entries.
    fn constraints_with(
        total_budget_hours: f64,
        must_include: &[&str],
        must_exclude: &[&str],
    ) -> ResourceConstraints {
        ResourceConstraints {
            total_budget_hours,
            role_availability: HashMap::new(),
            must_include: must_include.iter().map(|&id| id.to_owned()).collect(),
            must_exclude: must_exclude.iter().map(|&id| id.to_owned()).collect(),
        }
    }

    /// Runs the optimizer on `bwp` with a default overlay.
    fn plan_for(
        bwp: &BlueprintWithPreconditions,
        constraints: &ResourceConstraints,
    ) -> OptimizedPlan {
        let overlay = GenerationOverlay::default();
        optimize_plan(&bwp.blueprint, &overlay, &bwp.preconditions, constraints)
    }

    /// True when `ids` contains `wanted`.
    fn lists(ids: &[String], wanted: &str) -> bool {
        ids.iter().any(|id| id == wanted)
    }

    #[test]
    fn test_must_include_always_present() {
        let bwp = load_fsa();
        let constraints = constraints_with(1000.0, &["form_opinion"], &[]);

        let plan = plan_for(&bwp, &constraints);
        let chosen = &plan.included_procedures;

        assert!(
            lists(chosen, "form_opinion"),
            "form_opinion must be included"
        );
        assert!(
            lists(chosen, "going_concern"),
            "going_concern (transitive dep) must be included"
        );
        assert!(
            lists(chosen, "subsequent_events"),
            "subsequent_events (transitive dep) must be included"
        );
    }

    #[test]
    fn test_budget_constrains_selection() {
        let bwp = load_fsa();
        let constraints = constraints_with(5.0, &[], &[]);

        let plan = plan_for(&bwp, &constraints);

        assert!(
            plan.total_hours <= 5.0,
            "total hours {} should not exceed budget 5.0",
            plan.total_hours
        );

        let total_proc_count = bwp
            .blueprint
            .phases
            .iter()
            .flat_map(|phase| phase.procedures.iter())
            .count();
        assert!(
            plan.included_procedures.len() < total_proc_count,
            "tight budget should exclude some procedures"
        );
    }

    #[test]
    fn test_must_exclude_removed() {
        let bwp = load_fsa();
        let constraints = constraints_with(1000.0, &[], &["analytical_procedures"]);

        let plan = plan_for(&bwp, &constraints);

        assert!(
            !lists(&plan.included_procedures, "analytical_procedures"),
            "analytical_procedures must be excluded"
        );
        assert!(
            lists(&plan.excluded_procedures, "analytical_procedures"),
            "analytical_procedures must appear in excluded list"
        );
    }

    #[test]
    fn test_critical_path_computed() {
        let bwp = load_fsa();
        let constraints = constraints_with(1000.0, &[], &[]);

        let plan = plan_for(&bwp, &constraints);

        assert!(
            plan.critical_path_hours > 0.0,
            "critical path must be > 0 when procedures are included"
        );
        assert!(
            plan.critical_path_hours <= plan.total_hours,
            "critical path {} should not exceed total hours {}",
            plan.critical_path_hours,
            plan.total_hours
        );
    }

    #[test]
    fn test_optimized_plan_serializes() {
        let bwp = load_fsa();
        let constraints = constraints_with(1000.0, &["form_opinion"], &[]);

        let plan = plan_for(&bwp, &constraints);

        let json = serde_json::to_string(&plan).expect("should serialize to JSON");
        assert!(json.contains("included_procedures"));
        assert!(json.contains("total_hours"));
        assert!(json.contains("risk_coverage"));
        assert!(json.contains("standards_coverage"));
        assert!(json.contains("critical_path_hours"));
        assert!(json.contains("role_hours"));
    }
}