Skip to main content

datasynth_runtime/
config_mutator.rs

1//! Config mutation engine for counterfactual simulation.
2//!
3//! Applies propagated causal effects to a GeneratorConfig by
4//! navigating dot-paths and setting values.
5
6use crate::causal_engine::PropagatedInterventions;
7use datasynth_config::GeneratorConfig;
8use datasynth_core::ScenarioConstraints;
9use thiserror::Error;
10
11/// Errors during config mutation.
12#[derive(Debug, Error)]
13pub enum MutationError {
14    #[error("path not found: {0}")]
15    PathNotFound(String),
16    #[error("type mismatch at path '{path}': expected {expected}, got {actual}")]
17    TypeMismatch {
18        path: String,
19        expected: String,
20        actual: String,
21    },
22    #[error("constraint violation: {0}")]
23    ConstraintViolation(String),
24    #[error("serialization error: {0}")]
25    SerializationError(String),
26}
27
28/// Applies interventions to a config, producing a new config.
29pub struct ConfigMutator;
30
31impl ConfigMutator {
32    /// Create a mutated config by applying propagated interventions.
33    pub fn apply(
34        base: &GeneratorConfig,
35        propagated: &PropagatedInterventions,
36        constraints: &ScenarioConstraints,
37    ) -> Result<GeneratorConfig, MutationError> {
38        // Serialize config to JSON Value for dot-path navigation
39        let mut json = serde_json::to_value(base)
40            .map_err(|e| MutationError::SerializationError(e.to_string()))?;
41
42        // Collect all changes, using the latest value for each path
43        let mut latest_changes: std::collections::HashMap<String, serde_json::Value> =
44            std::collections::HashMap::new();
45
46        for changes in propagated.changes_by_month.values() {
47            for change in changes {
48                latest_changes.insert(change.path.clone(), change.value.clone());
49            }
50        }
51
52        // Apply changes
53        for (path, value) in &latest_changes {
54            Self::apply_at_path(&mut json, path, value)?;
55        }
56
57        // Strip null values before deserializing back.
58        // GeneratorConfig has `f64` fields with `#[serde(default)]` that work when
59        // the key is absent (YAML) but fail when the key is present as `null` (JSON Value).
60        Self::strip_nulls(&mut json);
61
62        // Deserialize back
63        let mutated: GeneratorConfig = serde_json::from_value(json)
64            .map_err(|e| MutationError::SerializationError(e.to_string()))?;
65
66        // Validate constraints
67        Self::validate_constraints(&mutated, constraints)?;
68
69        Ok(mutated)
70    }
71
72    /// Apply a single value at a dot-path, supporting array indexing.
73    ///
74    /// Examples:
75    /// - `"global.seed"` → navigates to `json["global"]["seed"]`
76    /// - `"distributions.amounts.components[0].mu"` → navigates to `json["distributions"]["amounts"]["components"][0]["mu"]`
77    pub fn apply_at_path(
78        value: &mut serde_json::Value,
79        path: &str,
80        new_value: &serde_json::Value,
81    ) -> Result<(), MutationError> {
82        let segments = Self::parse_path(path);
83        let mut current = value;
84
85        for (i, segment) in segments.iter().enumerate() {
86            let is_last = i == segments.len() - 1;
87
88            match segment {
89                PathSegment::Key(key) => {
90                    if is_last {
91                        if let Some(obj) = current.as_object_mut() {
92                            obj.insert(key.clone(), new_value.clone());
93                            return Ok(());
94                        }
95                        return Err(MutationError::PathNotFound(path.to_string()));
96                    }
97                    current = current
98                        .get_mut(key.as_str())
99                        .ok_or_else(|| MutationError::PathNotFound(path.to_string()))?;
100                }
101                PathSegment::Index(idx) => {
102                    if is_last {
103                        if let Some(arr) = current.as_array_mut() {
104                            if *idx < arr.len() {
105                                arr[*idx] = new_value.clone();
106                                return Ok(());
107                            }
108                        }
109                        return Err(MutationError::PathNotFound(path.to_string()));
110                    }
111                    current = current
112                        .get_mut(*idx)
113                        .ok_or_else(|| MutationError::PathNotFound(path.to_string()))?;
114                }
115            }
116        }
117
118        Err(MutationError::PathNotFound(path.to_string()))
119    }
120
121    /// Parse a dot-path with optional array indices.
122    fn parse_path(path: &str) -> Vec<PathSegment> {
123        let mut segments = Vec::new();
124        for part in path.split('.') {
125            if let Some(bracket_pos) = part.find('[') {
126                // Key with array index: "components[0]"
127                let key = &part[..bracket_pos];
128                if !key.is_empty() {
129                    segments.push(PathSegment::Key(key.to_string()));
130                }
131                // Parse index
132                let idx_str = &part[bracket_pos + 1..part.len() - 1];
133                if let Ok(idx) = idx_str.parse::<usize>() {
134                    segments.push(PathSegment::Index(idx));
135                }
136            } else {
137                segments.push(PathSegment::Key(part.to_string()));
138            }
139        }
140        segments
141    }
142
143    /// Recursively remove null values from a JSON object tree.
144    /// This allows `#[serde(default)]` fields to use their defaults instead of
145    /// failing on `null` during deserialization.
146    fn strip_nulls(value: &mut serde_json::Value) {
147        match value {
148            serde_json::Value::Object(map) => {
149                map.retain(|_, v| !v.is_null());
150                for v in map.values_mut() {
151                    Self::strip_nulls(v);
152                }
153            }
154            serde_json::Value::Array(arr) => {
155                for v in arr.iter_mut() {
156                    Self::strip_nulls(v);
157                }
158            }
159            _ => {}
160        }
161    }
162
163    /// Validate that constraints are satisfied by the mutated config.
164    fn validate_constraints(
165        config: &GeneratorConfig,
166        constraints: &ScenarioConstraints,
167    ) -> Result<(), MutationError> {
168        // Validate built-in preserve_* constraints
169        if constraints.preserve_document_chains
170            && !config.document_flows.generate_document_references
171        {
172            return Err(MutationError::ConstraintViolation(
173                "preserve_document_chains requires document_flows.generate_document_references=true"
174                    .into(),
175            ));
176        }
177
178        if constraints.preserve_balance_coherence && !config.balance.validate_balance_equation {
179            return Err(MutationError::ConstraintViolation(
180                "preserve_balance_coherence requires balance.validate_balance_equation=true".into(),
181            ));
182        }
183
184        if constraints.preserve_balance_coherence && !config.balance.generate_trial_balances {
185            return Err(MutationError::ConstraintViolation(
186                "preserve_balance_coherence requires balance.generate_trial_balances=true".into(),
187            ));
188        }
189
190        // Check custom constraints
191        for constraint in &constraints.custom {
192            // Custom constraints reference config paths with min/max bounds
193            // These are validated against the config values
194            let config_json = serde_json::to_value(config)
195                .map_err(|e| MutationError::SerializationError(e.to_string()))?;
196
197            let segments = Self::parse_path(&constraint.config_path);
198            let mut current = &config_json;
199            let mut found = true;
200
201            for segment in &segments {
202                match segment {
203                    PathSegment::Key(key) => {
204                        if let Some(next) = current.get(key.as_str()) {
205                            current = next;
206                        } else {
207                            found = false;
208                            break;
209                        }
210                    }
211                    PathSegment::Index(idx) => {
212                        if let Some(next) = current.get(*idx) {
213                            current = next;
214                        } else {
215                            found = false;
216                            break;
217                        }
218                    }
219                }
220            }
221
222            if found {
223                if let Some(val) = current.as_f64() {
224                    if let Some(min) = &constraint.min {
225                        use rust_decimal::prelude::ToPrimitive;
226                        if let Some(min_f64) = min.to_f64() {
227                            if val < min_f64 {
228                                return Err(MutationError::ConstraintViolation(format!(
229                                    "{}: value {} below minimum {}",
230                                    constraint.config_path, val, min
231                                )));
232                            }
233                        }
234                    }
235                    if let Some(max) = &constraint.max {
236                        use rust_decimal::prelude::ToPrimitive;
237                        if let Some(max_f64) = max.to_f64() {
238                            if val > max_f64 {
239                                return Err(MutationError::ConstraintViolation(format!(
240                                    "{}: value {} above maximum {}",
241                                    constraint.config_path, val, max
242                                )));
243                            }
244                        }
245                    }
246                }
247            }
248        }
249
250        Ok(())
251    }
252}
253
/// One step of a parsed dot-path: an object key or an array index.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PathSegment {
    /// Object member lookup, e.g. `seed` in `global.seed`.
    Key(String),
    /// Array element lookup, e.g. `0` in `components[0]`.
    Index(usize),
}
259
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    /// Propagation result with no changes in any month.
    fn no_changes() -> PropagatedInterventions {
        PropagatedInterventions {
            changes_by_month: BTreeMap::new(),
        }
    }

    #[test]
    fn test_apply_simple_dot_path() {
        let mut json = serde_json::json!({
            "global": {
                "seed": 42
            }
        });

        ConfigMutator::apply_at_path(&mut json, "global.seed", &serde_json::json!(99))
            .expect("should succeed");

        assert_eq!(json["global"]["seed"], 99);
    }

    #[test]
    fn test_apply_nested_dot_path() {
        let mut json = serde_json::json!({
            "distributions": {
                "amounts": {
                    "components": [
                        {"mu": 6.0, "sigma": 1.5},
                        {"mu": 8.5, "sigma": 1.0}
                    ]
                }
            }
        });

        ConfigMutator::apply_at_path(
            &mut json,
            "distributions.amounts.components[0].mu",
            &serde_json::json!(5.5),
        )
        .expect("should succeed");

        assert_eq!(json["distributions"]["amounts"]["components"][0]["mu"], 5.5);
        // Other fields unchanged
        assert_eq!(
            json["distributions"]["amounts"]["components"][0]["sigma"],
            1.5
        );
        assert_eq!(json["distributions"]["amounts"]["components"][1]["mu"], 8.5);
    }

    #[test]
    fn test_apply_preserves_other_fields() {
        let mut json = serde_json::json!({
            "global": {
                "seed": 42,
                "industry": "retail"
            }
        });

        ConfigMutator::apply_at_path(&mut json, "global.seed", &serde_json::json!(99))
            .expect("should succeed");

        assert_eq!(json["global"]["seed"], 99);
        assert_eq!(json["global"]["industry"], "retail");
    }

    #[test]
    fn test_apply_invalid_path_returns_error() {
        let mut json = serde_json::json!({
            "global": { "seed": 42 }
        });

        let result = ConfigMutator::apply_at_path(
            &mut json,
            "nonexistent.path.here",
            &serde_json::json!(99),
        );

        assert!(matches!(result, Err(MutationError::PathNotFound(_))));
    }

    #[test]
    fn test_apply_index_out_of_bounds_returns_error() {
        let mut json = serde_json::json!({
            "components": [{"mu": 1.0}]
        });

        // Out-of-bounds index as an intermediate segment.
        let result =
            ConfigMutator::apply_at_path(&mut json, "components[3].mu", &serde_json::json!(2.0));
        assert!(matches!(result, Err(MutationError::PathNotFound(_))));

        // Out-of-bounds index as the final segment.
        let result =
            ConfigMutator::apply_at_path(&mut json, "components[1]", &serde_json::json!({}));
        assert!(matches!(result, Err(MutationError::PathNotFound(_))));

        // Nothing was modified.
        assert_eq!(json["components"][0]["mu"], 1.0);
        assert_eq!(json["components"].as_array().map(Vec::len), Some(1));
    }

    #[test]
    fn test_apply_through_scalar_returns_error() {
        let mut json = serde_json::json!({
            "global": { "seed": 42 }
        });

        let result =
            ConfigMutator::apply_at_path(&mut json, "global.seed.inner", &serde_json::json!(1));

        assert!(matches!(result, Err(MutationError::PathNotFound(_))));
        assert_eq!(json["global"]["seed"], 42);
    }

    #[test]
    fn test_apply_inserts_missing_leaf_key() {
        // Only the final key may be missing; it is created in its parent object.
        let mut json = serde_json::json!({
            "global": { "seed": 42 }
        });

        ConfigMutator::apply_at_path(&mut json, "global.new_key", &serde_json::json!("x"))
            .expect("should succeed");

        assert_eq!(json["global"]["new_key"], "x");
        assert_eq!(json["global"]["seed"], 42);
    }

    #[test]
    fn test_roundtrip_config_mutation() {
        // Test the dot-path mutation on raw JSON (avoids GeneratorConfig roundtrip issues)
        let mut json = serde_json::json!({
            "global": {
                "seed": 42,
                "period_months": 12,
                "start_date": "2024-01-01",
                "industry": "manufacturing"
            },
            "distributions": {
                "amounts": {
                    "components": [
                        {"mu": 6.0, "sigma": 1.5}
                    ]
                }
            }
        });

        // Mutate period_months
        ConfigMutator::apply_at_path(&mut json, "global.period_months", &serde_json::json!(6))
            .expect("should succeed");

        assert_eq!(json["global"]["period_months"], 6);
        // Other fields preserved
        assert_eq!(json["global"]["start_date"], "2024-01-01");
        assert_eq!(json["global"]["seed"], 42);

        // Mutate nested array element
        ConfigMutator::apply_at_path(
            &mut json,
            "distributions.amounts.components[0].mu",
            &serde_json::json!(5.5),
        )
        .expect("should succeed");

        assert_eq!(json["distributions"]["amounts"]["components"][0]["mu"], 5.5);
        assert_eq!(
            json["distributions"]["amounts"]["components"][0]["sigma"],
            1.5
        );
    }

    #[test]
    fn test_constraint_validation_passes() {
        // Default constraints carry no custom bounds, so only the built-in
        // preserve_* checks can reject a config.
        let constraints = ScenarioConstraints::default();
        assert!(constraints.custom.is_empty());
    }

    #[test]
    fn test_constraint_preserves_document_chains() {
        use datasynth_test_utils::fixtures::minimal_config;

        let mut config = minimal_config();
        config.document_flows.generate_document_references = false;

        let constraints = ScenarioConstraints {
            preserve_document_chains: true,
            ..Default::default()
        };

        let result = ConfigMutator::apply(&config, &no_changes(), &constraints);
        assert!(matches!(result, Err(MutationError::ConstraintViolation(_))));
        if let Err(MutationError::ConstraintViolation(msg)) = result {
            assert!(msg.contains("document_flows"));
        }
    }

    #[test]
    fn test_constraint_preserves_balance() {
        use datasynth_test_utils::fixtures::minimal_config;

        let mut config = minimal_config();
        config.balance.validate_balance_equation = false;

        let constraints = ScenarioConstraints {
            preserve_balance_coherence: true,
            ..Default::default()
        };

        let result = ConfigMutator::apply(&config, &no_changes(), &constraints);
        assert!(matches!(result, Err(MutationError::ConstraintViolation(_))));
    }

    #[test]
    fn test_constraint_requires_trial_balances() {
        use datasynth_test_utils::fixtures::minimal_config;

        let mut config = minimal_config();
        // Satisfy the balance-equation check so the trial-balance check is reached.
        config.balance.validate_balance_equation = true;
        config.balance.generate_trial_balances = false;

        // Every flag spelled out so no other preserve_* default can fire first.
        let constraints = ScenarioConstraints {
            preserve_document_chains: false,
            preserve_balance_coherence: true,
            preserve_period_close: false,
            preserve_accounting_identity: false,
            custom: vec![],
        };

        let result = ConfigMutator::apply(&config, &no_changes(), &constraints);
        assert!(matches!(result, Err(MutationError::ConstraintViolation(_))));
        if let Err(MutationError::ConstraintViolation(msg)) = result {
            assert!(msg.contains("generate_trial_balances"));
        }
    }

    #[test]
    fn test_constraint_allows_when_not_preserved() {
        use datasynth_test_utils::fixtures::minimal_config;

        let mut config = minimal_config();
        config.document_flows.generate_document_references = false;
        config.balance.validate_balance_equation = false;

        // All preserve flags off — should succeed
        let constraints = ScenarioConstraints {
            preserve_document_chains: false,
            preserve_balance_coherence: false,
            preserve_period_close: false,
            preserve_accounting_identity: false,
            custom: vec![],
        };

        let result = ConfigMutator::apply(&config, &no_changes(), &constraints);
        assert!(result.is_ok());
    }
}