Skip to main content

datasynth_core/causal/
scm.rs

1use std::collections::HashMap;
2
3use rand::Rng;
4use rand::SeedableRng;
5use rand_chacha::ChaCha8Rng;
6use rand_distr::{Distribution, LogNormal, Normal};
7
8use super::graph::{CausalGraph, CausalVarType, CausalVariable};
9
10/// Structural Causal Model for generating data from a causal graph.
11pub struct StructuralCausalModel {
12    graph: CausalGraph,
13}
14
15impl StructuralCausalModel {
16    pub fn new(graph: CausalGraph) -> Result<Self, String> {
17        graph.validate()?;
18        Ok(Self { graph })
19    }
20
21    /// Get reference to the underlying graph.
22    pub fn graph(&self) -> &CausalGraph {
23        &self.graph
24    }
25
26    /// Generate samples from the causal model.
27    pub fn generate(
28        &self,
29        n_samples: usize,
30        seed: u64,
31    ) -> Result<Vec<HashMap<String, f64>>, String> {
32        let order = self.graph.topological_order()?;
33        let mut rng = ChaCha8Rng::seed_from_u64(seed);
34        let mut samples = Vec::with_capacity(n_samples);
35
36        for _ in 0..n_samples {
37            let mut record: HashMap<String, f64> = HashMap::new();
38
39            for var_name in &order {
40                let var = self
41                    .graph
42                    .get_variable(var_name)
43                    .ok_or_else(|| format!("Variable '{}' not found", var_name))?;
44
45                // Sample exogenous noise
46                let noise = self.sample_exogenous(var, &mut rng);
47
48                // Compute contribution from parents
49                let parent_edges = self.graph.parent_edges(var_name);
50                let parent_contribution: f64 = parent_edges
51                    .iter()
52                    .map(|edge| {
53                        let parent_val = record.get(&edge.from).copied().unwrap_or(0.0);
54                        edge.mechanism.apply(parent_val) * edge.strength
55                    })
56                    .sum();
57
58                // Combine: noise + parent contributions
59                let value = match var.var_type {
60                    CausalVarType::Binary => {
61                        let prob = (noise + parent_contribution).clamp(0.0, 1.0);
62                        if rng.gen::<f64>() < prob {
63                            1.0
64                        } else {
65                            0.0
66                        }
67                    }
68                    CausalVarType::Count => (noise + parent_contribution).max(0.0).round(),
69                    _ => noise + parent_contribution,
70                };
71
72                record.insert(var_name.clone(), value);
73            }
74
75            samples.push(record);
76        }
77
78        Ok(samples)
79    }
80
81    /// Sample exogenous noise for a variable based on its distribution specification.
82    fn sample_exogenous(&self, var: &CausalVariable, rng: &mut ChaCha8Rng) -> f64 {
83        let dist = var.distribution.as_deref().unwrap_or("normal");
84        match dist {
85            "lognormal" => {
86                let mu = var.params.get("mu").copied().unwrap_or(0.0);
87                let sigma = var.params.get("sigma").copied().unwrap_or(1.0);
88                if let Ok(d) = LogNormal::new(mu, sigma) {
89                    d.sample(rng)
90                } else {
91                    0.0
92                }
93            }
94            "beta" => {
95                // Simple beta approximation using normal
96                let alpha = var.params.get("alpha").copied().unwrap_or(2.0);
97                let beta_param = var.params.get("beta_param").copied().unwrap_or(2.0);
98                let mean = alpha / (alpha + beta_param);
99                let var_val = (alpha * beta_param)
100                    / ((alpha + beta_param).powi(2) * (alpha + beta_param + 1.0));
101                if let Ok(d) = Normal::new(mean, var_val.sqrt()) {
102                    d.sample(rng).clamp(0.0, 1.0) // Clamp for beta-like behavior
103                } else {
104                    mean
105                }
106            }
107            "uniform" => {
108                let low = var.params.get("low").copied().unwrap_or(0.0);
109                let high = var.params.get("high").copied().unwrap_or(1.0);
110                rng.gen::<f64>() * (high - low) + low
111            }
112            _ => {
113                // Default to normal distribution
114                let mean = var.params.get("mean").copied().unwrap_or(0.0);
115                let std = var.params.get("std").copied().unwrap_or(1.0);
116                if let Ok(d) = Normal::new(mean, std) {
117                    d.sample(rng)
118                } else {
119                    mean
120                }
121            }
122        }
123    }
124
125    /// Create an intervened SCM where a variable is set to a fixed value.
126    /// This implements the do-calculus do(X=x) operation.
127    pub fn intervene(&self, variable: &str, value: f64) -> Result<IntervenedScm<'_>, String> {
128        // Verify variable exists
129        if self.graph.get_variable(variable).is_none() {
130            return Err(format!(
131                "Variable '{}' not found for intervention",
132                variable
133            ));
134        }
135        Ok(IntervenedScm {
136            base: self,
137            interventions: vec![(variable.to_string(), value)],
138        })
139    }
140}
141
142/// An SCM with active interventions (do-calculus).
143pub struct IntervenedScm<'a> {
144    base: &'a StructuralCausalModel,
145    interventions: Vec<(String, f64)>,
146}
147
148impl<'a> IntervenedScm<'a> {
149    /// Add another intervention.
150    pub fn and_intervene(mut self, variable: &str, value: f64) -> Self {
151        self.interventions.push((variable.to_string(), value));
152        self
153    }
154
155    /// Generate samples under intervention.
156    pub fn generate(
157        &self,
158        n_samples: usize,
159        seed: u64,
160    ) -> Result<Vec<HashMap<String, f64>>, String> {
161        let order = self.base.graph.topological_order()?;
162        let mut rng = ChaCha8Rng::seed_from_u64(seed);
163        let intervention_map: HashMap<&str, f64> = self
164            .interventions
165            .iter()
166            .map(|(k, v)| (k.as_str(), *v))
167            .collect();
168        let mut samples = Vec::with_capacity(n_samples);
169
170        for _ in 0..n_samples {
171            let mut record: HashMap<String, f64> = HashMap::new();
172
173            for var_name in &order {
174                // If this variable is intervened on, use fixed value
175                if let Some(&fixed_val) = intervention_map.get(var_name.as_str()) {
176                    record.insert(var_name.clone(), fixed_val);
177                    continue;
178                }
179
180                let var = self
181                    .base
182                    .graph
183                    .get_variable(var_name)
184                    .ok_or_else(|| format!("Variable '{}' not found", var_name))?;
185
186                let noise = self.base.sample_exogenous(var, &mut rng);
187                let parent_edges = self.base.graph.parent_edges(var_name);
188                let parent_contribution: f64 = parent_edges
189                    .iter()
190                    .map(|edge| {
191                        let parent_val = record.get(&edge.from).copied().unwrap_or(0.0);
192                        edge.mechanism.apply(parent_val) * edge.strength
193                    })
194                    .sum();
195
196                let value = match var.var_type {
197                    CausalVarType::Binary => {
198                        let prob = (noise + parent_contribution).clamp(0.0, 1.0);
199                        if rng.gen::<f64>() < prob {
200                            1.0
201                        } else {
202                            0.0
203                        }
204                    }
205                    CausalVarType::Count => (noise + parent_contribution).max(0.0).round(),
206                    _ => noise + parent_contribution,
207                };
208
209                record.insert(var_name.clone(), value);
210            }
211
212            samples.push(record);
213        }
214
215        Ok(samples)
216    }
217}
218
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::super::graph::CausalGraph;
    use super::*;

    /// Build an SCM over the fraud-detection template graph, which most tests share.
    fn fraud_scm() -> StructuralCausalModel {
        StructuralCausalModel::new(CausalGraph::fraud_detection_template()).unwrap()
    }

    /// Return the mean of `is_fraud` across `samples`, treating a missing value as 0.0.
    fn fraud_rate(samples: &[HashMap<String, f64>]) -> f64 {
        let total: f64 = samples
            .iter()
            .map(|s| s.get("is_fraud").copied().unwrap_or(0.0))
            .sum();
        total / samples.len() as f64
    }

    /// Assert that two sample sets match exactly: same length, same keys and
    /// bit-identical values. Comparing bit patterns means a NaN produced
    /// identically on both sides still counts as a match.
    fn assert_same_samples(left: &[HashMap<String, f64>], right: &[HashMap<String, f64>]) {
        assert_eq!(left.len(), right.len(), "sample counts differ");
        for (i, (l, r)) in left.iter().zip(right).enumerate() {
            assert_eq!(
                l.len(),
                r.len(),
                "record {} has a different number of variables",
                i
            );
            for (name, lv) in l {
                let rv = r
                    .get(name)
                    .unwrap_or_else(|| panic!("record {} is missing '{}'", i, name));
                assert_eq!(
                    lv.to_bits(),
                    rv.to_bits(),
                    "record {} differs on '{}': {} vs {}",
                    i,
                    name,
                    lv,
                    rv
                );
            }
        }
    }

    #[test]
    fn test_scm_generates_correct_count() {
        let samples = fraud_scm().generate(100, 42).unwrap();
        assert_eq!(samples.len(), 100);
    }

    #[test]
    fn test_scm_generates_nothing_for_zero_samples() {
        assert!(fraud_scm().generate(0, 42).unwrap().is_empty());
    }

    #[test]
    fn test_scm_deterministic() {
        let scm = fraud_scm();
        let s1 = scm.generate(50, 42).unwrap();
        let s2 = scm.generate(50, 42).unwrap();
        // Compare every variable of every record, not just a single column.
        assert_same_samples(&s1, &s2);
    }

    #[test]
    fn test_scm_all_variables_present() {
        let graph = CausalGraph::fraud_detection_template();
        let var_names: Vec<String> = graph.variables.iter().map(|v| v.name.clone()).collect();
        let scm = StructuralCausalModel::new(graph).unwrap();
        let samples = scm.generate(10, 42).unwrap();
        for sample in &samples {
            for name in &var_names {
                assert!(
                    sample.contains_key(name),
                    "Sample missing variable '{}'",
                    name
                );
            }
        }
    }

    #[test]
    fn test_scm_is_fraud_binary() {
        let samples = fraud_scm().generate(100, 42).unwrap();
        for sample in &samples {
            let val = sample.get("is_fraud").copied().unwrap_or(-1.0);
            assert!(
                val == 0.0 || val == 1.0,
                "is_fraud should be binary, got {}",
                val
            );
        }
    }

    #[test]
    fn test_intervention_sets_value() {
        let scm = fraud_scm();
        let intervened = scm.intervene("transaction_amount", 10000.0).unwrap();
        let samples = intervened.generate(50, 42).unwrap();
        for sample in &samples {
            assert_eq!(sample.get("transaction_amount").copied(), Some(10000.0));
        }
    }

    #[test]
    fn test_chained_interventions_fix_every_target() {
        let scm = fraud_scm();
        let samples = scm
            .intervene("transaction_amount", 500.0)
            .unwrap()
            .and_intervene("is_fraud", 1.0)
            .generate(20, 7)
            .unwrap();
        for sample in &samples {
            assert_eq!(sample.get("transaction_amount").copied(), Some(500.0));
            assert_eq!(sample.get("is_fraud").copied(), Some(1.0));
        }
    }

    #[test]
    fn test_repeated_intervention_keeps_latest_value() {
        let scm = fraud_scm();
        let samples = scm
            .intervene("transaction_amount", 1.0)
            .unwrap()
            .and_intervene("transaction_amount", 2.0)
            .generate(20, 7)
            .unwrap();
        for sample in &samples {
            assert_eq!(sample.get("transaction_amount").copied(), Some(2.0));
        }
    }

    #[test]
    fn test_observational_matches_empty_intervention() {
        let scm = fraud_scm();
        let unintervened = IntervenedScm {
            base: &scm,
            interventions: Vec::new(),
        };
        assert_same_samples(
            &scm.generate(30, 9).unwrap(),
            &unintervened.generate(30, 9).unwrap(),
        );
    }

    #[test]
    fn test_intervention_affects_downstream() {
        let scm = fraud_scm();

        // A very high transaction amount should raise the fraud probability.
        let high_intervened = scm.intervene("transaction_amount", 100000.0).unwrap();
        let high_fraud_rate = fraud_rate(&high_intervened.generate(200, 42).unwrap());

        // Compare against a very low transaction amount.
        let low_intervened = scm.intervene("transaction_amount", 1.0).unwrap();
        let low_fraud_rate = fraud_rate(&low_intervened.generate(200, 42).unwrap());

        // A high amount should generally lead to a higher fraud rate.
        assert!(
            high_fraud_rate >= low_fraud_rate,
            "High transaction amount ({}) should increase fraud rate ({} vs {})",
            100000.0,
            high_fraud_rate,
            low_fraud_rate
        );
    }

    #[test]
    fn test_intervention_unknown_variable() {
        assert!(fraud_scm().intervene("nonexistent", 0.0).is_err());
    }

    #[test]
    fn test_cyclic_graph_rejected_by_scm() {
        use super::super::graph::{CausalEdge, CausalMechanism, CausalVarType, CausalVariable};
        let mut graph = CausalGraph::new();
        graph.add_variable(CausalVariable::new("a", CausalVarType::Continuous));
        graph.add_variable(CausalVariable::new("b", CausalVarType::Continuous));
        graph.add_edge(CausalEdge {
            from: "a".into(),
            to: "b".into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        });
        graph.add_edge(CausalEdge {
            from: "b".into(),
            to: "a".into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        });
        assert!(StructuralCausalModel::new(graph).is_err());
    }
}