1use crate::causal_engine::ValidatedIntervention;
4use datasynth_config::GeneratorConfig;
5use datasynth_core::{Intervention, InterventionTiming, InterventionType};
6use thiserror::Error;
7
8#[derive(Debug, Error)]
10pub enum InterventionError {
11 #[error("invalid target: {0}")]
12 InvalidTarget(String),
13 #[error(
14 "timing out of range: intervention start_month {start} exceeds period_months {period}"
15 )]
16 TimingOutOfRange { start: u32, period: u32 },
17 #[error("timing invalid: start_month must be >= 1, got {0}")]
18 TimingInvalid(u32),
19 #[error("conflict detected: interventions at priority {0} overlap on path '{1}'")]
20 ConflictDetected(u32, String),
21 #[error("bounds violation: {0}")]
22 BoundsViolation(String),
23}
24
25pub struct InterventionManager;
27
28impl InterventionManager {
29 pub fn validate(
31 interventions: &[Intervention],
32 config: &GeneratorConfig,
33 ) -> Result<Vec<ValidatedIntervention>, InterventionError> {
34 let mut validated = Vec::new();
35
36 for intervention in interventions {
37 Self::validate_timing(&intervention.timing, config)?;
38 Self::validate_bounds(&intervention.intervention_type)?;
39
40 let paths = Self::resolve_config_paths(&intervention.intervention_type);
41
42 validated.push(ValidatedIntervention {
43 intervention: intervention.clone(),
44 affected_config_paths: paths,
45 });
46 }
47
48 Self::check_conflicts(&validated)?;
49 Ok(validated)
50 }
51
52 fn validate_timing(
54 timing: &InterventionTiming,
55 config: &GeneratorConfig,
56 ) -> Result<(), InterventionError> {
57 if timing.start_month < 1 {
58 return Err(InterventionError::TimingInvalid(timing.start_month));
59 }
60 if timing.start_month > config.global.period_months {
61 return Err(InterventionError::TimingOutOfRange {
62 start: timing.start_month,
63 period: config.global.period_months,
64 });
65 }
66 Ok(())
67 }
68
69 fn validate_bounds(intervention_type: &InterventionType) -> Result<(), InterventionError> {
71 match intervention_type {
72 InterventionType::ControlFailure(cf) if !(0.0..=1.0).contains(&cf.severity) => {
73 Err(InterventionError::BoundsViolation(format!(
74 "control failure severity must be between 0.0 and 1.0, got {}",
75 cf.severity
76 )))
77 }
78 InterventionType::MacroShock(ms) if ms.severity < 0.0 => {
79 Err(InterventionError::BoundsViolation(format!(
80 "macro shock severity must be >= 0.0, got {}",
81 ms.severity
82 )))
83 }
84 _ => Ok(()),
85 }
86 }
87
88 fn resolve_config_paths(intervention_type: &InterventionType) -> Vec<String> {
90 match intervention_type {
91 InterventionType::ParameterShift(ps) => vec![ps.target.clone()],
92 InterventionType::ControlFailure(_) => {
93 vec![
94 "internal_controls.exception_rate".to_string(),
95 "internal_controls.sod_violation_rate".to_string(),
96 ]
97 }
98 InterventionType::MacroShock(_) => {
99 vec![
100 "distributions.drift.economic_cycle.amplitude".to_string(),
101 "transactions.volume_multiplier".to_string(),
102 ]
103 }
104 InterventionType::EntityEvent(ee) => {
105 use datasynth_core::InterventionEntityEvent;
106 match ee.subtype {
109 InterventionEntityEvent::VendorDefault => vec![
110 "vendor_network.dependencies.max_single_vendor_concentration".to_string(),
111 ],
112 InterventionEntityEvent::CustomerChurn => {
113 vec!["customer_segmentation.lifecycle.churned_rate".to_string()]
114 }
115 InterventionEntityEvent::EmployeeDeparture
116 | InterventionEntityEvent::KeyPersonRisk => vec![
117 "internal_controls.exception_rate".to_string(),
118 "internal_controls.sod_violation_rate".to_string(),
119 ],
120 InterventionEntityEvent::NewVendorOnboarding => vec![
121 "vendor_network.tiers.tier1.count_max".to_string(),
122 "vendor_network.clusters.standard_operational".to_string(),
123 ],
124 InterventionEntityEvent::MergerAcquisition => vec![
125 "companies".to_string(),
126 "intercompany.relationship_density".to_string(),
127 ],
128 InterventionEntityEvent::VendorCollusion => vec![
129 "fraud.enabled".to_string(),
130 "fraud.fraud_type_distribution.suspense_account_abuse".to_string(),
131 "vendor_network.clusters.problematic".to_string(),
132 ],
133 InterventionEntityEvent::CustomerConsolidation => vec![
134 "customer_segmentation.value_segments.enterprise.customer_share"
135 .to_string(),
136 "customer_segmentation.value_segments.smb.customer_share".to_string(),
137 ],
138 }
139 }
140 InterventionType::ProcessChange(_) => {
141 vec!["approval.thresholds".to_string()]
142 }
143 InterventionType::RegulatoryChange(_) => {
144 vec!["accounting_standards".to_string()]
145 }
146 InterventionType::Custom(ci) => ci.config_overrides.keys().cloned().collect(),
147 InterventionType::Composite(comp) => {
148 let mut paths = Vec::new();
149 for child in &comp.children {
150 paths.extend(Self::resolve_config_paths(child));
151 }
152 paths.sort();
153 paths.dedup();
154 paths
155 }
156 }
157 }
158
159 fn check_conflicts(validated: &[ValidatedIntervention]) -> Result<(), InterventionError> {
161 for i in 0..validated.len() {
162 for j in (i + 1)..validated.len() {
163 let a = &validated[i];
164 let b = &validated[j];
165
166 for path_a in &a.affected_config_paths {
168 for path_b in &b.affected_config_paths {
169 if path_a == path_b
170 && Self::timing_overlaps(&a.intervention.timing, &b.intervention.timing)
171 {
172 if a.intervention.priority == b.intervention.priority {
174 return Err(InterventionError::ConflictDetected(
175 a.intervention.priority,
176 path_a.clone(),
177 ));
178 }
179 }
181 }
182 }
183 }
184 }
185 Ok(())
186 }
187
188 fn timing_overlaps(a: &InterventionTiming, b: &InterventionTiming) -> bool {
190 let a_end = a.start_month + a.duration_months.unwrap_or(u32::MAX - a.start_month);
191 let b_end = b.start_month + b.duration_months.unwrap_or(u32::MAX - b.start_month);
192 a.start_month < b_end && b.start_month < a_end
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_core::{
        ControlFailureIntervention, ControlFailureType, ControlTarget, OnsetType,
        ParameterShiftIntervention,
    };
    use datasynth_test_utils::fixtures::minimal_config;
    use uuid::Uuid;

    /// Wraps `intervention_type` in an open-ended, sudden-onset intervention that
    /// starts at `start_month` with the given `priority`.
    fn make_intervention(
        intervention_type: InterventionType,
        start_month: u32,
        priority: u32,
    ) -> Intervention {
        let timing = InterventionTiming {
            start_month,
            duration_months: None,
            onset: OnsetType::Sudden,
            ramp_months: None,
        };
        Intervention {
            id: Uuid::new_v4(),
            intervention_type,
            timing,
            label: None,
            priority,
        }
    }

    /// Builds a parameter shift that moves `target` to `to`, with no `from`
    /// value and the default interpolation.
    fn param_shift(target: &str, to: serde_json::Value) -> InterventionType {
        InterventionType::ParameterShift(ParameterShiftIntervention {
            target: target.to_string(),
            from: None,
            to,
            interpolation: Default::default(),
        })
    }

    #[test]
    fn test_validate_timing_out_of_range() {
        let config = minimal_config();
        // Month 999 is expected to lie past the fixture's period_months.
        let late = make_intervention(param_shift("test.path", serde_json::json!(100)), 999, 0);

        let outcome = InterventionManager::validate(&[late], &config);

        assert!(matches!(
            outcome,
            Err(InterventionError::TimingOutOfRange { .. })
        ));
    }

    #[test]
    fn test_validate_empty_interventions() {
        let config = minimal_config();

        let validated = InterventionManager::validate(&[], &config).expect("should be ok");

        assert!(validated.is_empty());
    }

    #[test]
    fn test_validate_parameter_shift() {
        let config = minimal_config();
        let shift = make_intervention(
            param_shift("transactions.count", serde_json::json!(2000)),
            1,
            0,
        );

        assert!(InterventionManager::validate(&[shift], &config).is_ok());
    }

    #[test]
    fn test_conflict_detection() {
        let config = minimal_config();
        // Same path, same start, same priority: must be rejected.
        let first = make_intervention(
            param_shift("transactions.count", serde_json::json!(2000)),
            1,
            0,
        );
        let second = make_intervention(
            param_shift("transactions.count", serde_json::json!(3000)),
            1,
            0,
        );

        let outcome = InterventionManager::validate(&[first, second], &config);

        assert!(matches!(
            outcome,
            Err(InterventionError::ConflictDetected(_, _))
        ));
    }

    #[test]
    fn test_conflict_resolution_by_priority() {
        let config = minimal_config();
        // Same path and start, but distinct priorities: allowed.
        let low = make_intervention(
            param_shift("transactions.count", serde_json::json!(2000)),
            1,
            1,
        );
        let high = make_intervention(
            param_shift("transactions.count", serde_json::json!(3000)),
            1,
            2,
        );

        assert!(InterventionManager::validate(&[low, high], &config).is_ok());
    }

    #[test]
    fn test_validate_bounds_control_failure() {
        let config = minimal_config();
        let failure = ControlFailureIntervention {
            subtype: ControlFailureType::EffectivenessReduction,
            control_target: ControlTarget::ById {
                control_id: "C001".to_string(),
            },
            // Outside the permitted [0.0, 1.0] range.
            severity: 1.5,
            detectable: true,
        };
        let intervention = make_intervention(InterventionType::ControlFailure(failure), 1, 0);

        let outcome = InterventionManager::validate(&[intervention], &config);

        assert!(matches!(outcome, Err(InterventionError::BoundsViolation(_))));
    }
}