1use std::collections::HashMap;
2
3use rand::Rng;
4use rand::SeedableRng;
5use rand_chacha::ChaCha8Rng;
6use rand_distr::{Beta, Distribution, LogNormal, Normal};
7
8use super::graph::{CausalGraph, CausalVarType, CausalVariable};
9
10pub struct StructuralCausalModel {
12 graph: CausalGraph,
13}
14
15impl StructuralCausalModel {
16 pub fn new(graph: CausalGraph) -> Result<Self, String> {
17 graph.validate()?;
18 Ok(Self { graph })
19 }
20
21 pub fn graph(&self) -> &CausalGraph {
23 &self.graph
24 }
25
26 pub fn generate(
28 &self,
29 n_samples: usize,
30 seed: u64,
31 ) -> Result<Vec<HashMap<String, f64>>, String> {
32 let order = self.graph.topological_order()?;
33 let mut rng = ChaCha8Rng::seed_from_u64(seed);
34 let mut samples = Vec::with_capacity(n_samples);
35
36 for _ in 0..n_samples {
37 let mut record: HashMap<String, f64> = HashMap::new();
38
39 for var_name in &order {
40 let var = self
41 .graph
42 .get_variable(var_name)
43 .ok_or_else(|| format!("Variable '{var_name}' not found"))?;
44
45 let noise = self.sample_exogenous(var, &mut rng);
47
48 let parent_edges = self.graph.parent_edges(var_name);
50 let parent_contribution: f64 = parent_edges
51 .iter()
52 .map(|edge| {
53 let parent_val = record.get(&edge.from).copied().unwrap_or(0.0);
54 edge.mechanism.apply(parent_val) * edge.strength
55 })
56 .sum();
57
58 let value = match var.var_type {
60 CausalVarType::Binary => {
61 let prob = (noise + parent_contribution).clamp(0.0, 1.0);
62 if rng.random::<f64>() < prob {
63 1.0
64 } else {
65 0.0
66 }
67 }
68 CausalVarType::Count => (noise + parent_contribution).max(0.0).round(),
69 _ => noise + parent_contribution,
70 };
71
72 record.insert(var_name.clone(), value);
73 }
74
75 samples.push(record);
76 }
77
78 Ok(samples)
79 }
80
81 fn sample_exogenous(&self, var: &CausalVariable, rng: &mut ChaCha8Rng) -> f64 {
83 let dist = var.distribution.as_deref().unwrap_or("normal");
84 match dist {
85 "lognormal" => {
86 let mu = var.params.get("mu").copied().unwrap_or(0.0);
87 let sigma = var.params.get("sigma").copied().unwrap_or(1.0);
88 if let Ok(d) = LogNormal::new(mu, sigma) {
89 d.sample(rng)
90 } else {
91 0.0
92 }
93 }
94 "beta" => {
95 let alpha = var.params.get("alpha").copied().unwrap_or(2.0);
96 let beta_param = var.params.get("beta_param").copied().unwrap_or(2.0);
97 if let Ok(d) = Beta::new(alpha, beta_param) {
98 d.sample(rng)
99 } else {
100 let sum = alpha + beta_param;
102 if sum > 0.0 {
103 alpha / sum
104 } else {
105 0.5
106 }
107 }
108 }
109 "uniform" => {
110 let low = var.params.get("low").copied().unwrap_or(0.0);
111 let high = var.params.get("high").copied().unwrap_or(1.0);
112 rng.random::<f64>() * (high - low) + low
113 }
114 _ => {
115 let mean = var.params.get("mean").copied().unwrap_or(0.0);
117 let std = var.params.get("std").copied().unwrap_or(1.0);
118 if let Ok(d) = Normal::new(mean, std) {
119 d.sample(rng)
120 } else {
121 mean
122 }
123 }
124 }
125 }
126
127 pub fn intervene(&self, variable: &str, value: f64) -> Result<IntervenedScm<'_>, String> {
130 if self.graph.get_variable(variable).is_none() {
132 return Err(format!("Variable '{variable}' not found for intervention"));
133 }
134 Ok(IntervenedScm {
135 base: self,
136 interventions: vec![(variable.to_string(), value)],
137 })
138 }
139}
140
141pub struct IntervenedScm<'a> {
143 base: &'a StructuralCausalModel,
144 interventions: Vec<(String, f64)>,
145}
146
147impl<'a> IntervenedScm<'a> {
148 pub fn and_intervene(mut self, variable: &str, value: f64) -> Self {
150 self.interventions.push((variable.to_string(), value));
151 self
152 }
153
154 pub fn generate(
156 &self,
157 n_samples: usize,
158 seed: u64,
159 ) -> Result<Vec<HashMap<String, f64>>, String> {
160 let order = self.base.graph.topological_order()?;
161 let mut rng = ChaCha8Rng::seed_from_u64(seed);
162 let intervention_map: HashMap<&str, f64> = self
163 .interventions
164 .iter()
165 .map(|(k, v)| (k.as_str(), *v))
166 .collect();
167 let mut samples = Vec::with_capacity(n_samples);
168
169 for _ in 0..n_samples {
170 let mut record: HashMap<String, f64> = HashMap::new();
171
172 for var_name in &order {
173 if let Some(&fixed_val) = intervention_map.get(var_name.as_str()) {
175 record.insert(var_name.clone(), fixed_val);
176 continue;
177 }
178
179 let var = self
180 .base
181 .graph
182 .get_variable(var_name)
183 .ok_or_else(|| format!("Variable '{var_name}' not found"))?;
184
185 let noise = self.base.sample_exogenous(var, &mut rng);
186 let parent_edges = self.base.graph.parent_edges(var_name);
187 let parent_contribution: f64 = parent_edges
188 .iter()
189 .map(|edge| {
190 let parent_val = record.get(&edge.from).copied().unwrap_or(0.0);
191 edge.mechanism.apply(parent_val) * edge.strength
192 })
193 .sum();
194
195 let value = match var.var_type {
196 CausalVarType::Binary => {
197 let prob = (noise + parent_contribution).clamp(0.0, 1.0);
198 if rng.random::<f64>() < prob {
199 1.0
200 } else {
201 0.0
202 }
203 }
204 CausalVarType::Count => (noise + parent_contribution).max(0.0).round(),
205 _ => noise + parent_contribution,
206 };
207
208 record.insert(var_name.clone(), value);
209 }
210
211 samples.push(record);
212 }
213
214 Ok(samples)
215 }
216}
217
#[cfg(test)]
// Unwrapping is the idiomatic way to fail a test on an unexpected `Err`/`None`.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::super::graph::CausalGraph;
    use super::*;

    #[test]
    fn test_scm_generates_correct_count() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph).unwrap();
        let samples = scm.generate(100, 42).unwrap();
        assert_eq!(samples.len(), 100);
    }

    #[test]
    fn test_scm_deterministic() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph).unwrap();
        let s1 = scm.generate(50, 42).unwrap();
        let s2 = scm.generate(50, 42).unwrap();

        // Compare complete records bit-for-bit: a single-key check passes vacuously when the
        // key is missing from both runs, and `zip` alone would hide a length mismatch.
        assert_eq!(s1.len(), s2.len());
        for (a, b) in s1.iter().zip(&s2) {
            assert_eq!(a.len(), b.len());
            for (key, va) in a {
                let vb = b
                    .get(key)
                    .unwrap_or_else(|| panic!("variable '{key}' missing from second run"));
                assert_eq!(va.to_bits(), vb.to_bits(), "variable '{key}' differs between runs");
            }
        }
    }

    #[test]
    fn test_scm_all_variables_present() {
        let graph = CausalGraph::fraud_detection_template();
        let var_names: Vec<String> = graph.variables.iter().map(|v| v.name.clone()).collect();
        let scm = StructuralCausalModel::new(graph).unwrap();
        let samples = scm.generate(10, 42).unwrap();
        for sample in &samples {
            for name in &var_names {
                assert!(
                    sample.contains_key(name),
                    "Sample missing variable '{}'",
                    name
                );
            }
        }
    }

    #[test]
    fn test_scm_is_fraud_binary() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph).unwrap();
        let samples = scm.generate(100, 42).unwrap();
        for sample in &samples {
            let val = sample.get("is_fraud").copied().unwrap_or(-1.0);
            assert!(
                val == 0.0 || val == 1.0,
                "is_fraud should be binary, got {}",
                val
            );
        }
    }

    #[test]
    fn test_intervention_sets_value() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph).unwrap();
        let intervened = scm.intervene("transaction_amount", 10000.0).unwrap();
        let samples = intervened.generate(50, 42).unwrap();
        for sample in &samples {
            assert_eq!(sample.get("transaction_amount").copied(), Some(10000.0));
        }
    }

    #[test]
    fn test_and_intervene_fixes_all_targets() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph).unwrap();
        let intervened = scm
            .intervene("transaction_amount", 5.0)
            .unwrap()
            .and_intervene("is_fraud", 1.0);
        for sample in intervened.generate(20, 7).unwrap() {
            assert_eq!(sample.get("transaction_amount").copied(), Some(5.0));
            assert_eq!(sample.get("is_fraud").copied(), Some(1.0));
        }
    }

    #[test]
    fn test_later_intervention_overrides_earlier() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph).unwrap();
        let intervened = scm
            .intervene("transaction_amount", 5.0)
            .unwrap()
            .and_intervene("transaction_amount", 9.0);
        for sample in intervened.generate(20, 7).unwrap() {
            assert_eq!(sample.get("transaction_amount").copied(), Some(9.0));
        }
    }

    #[test]
    fn test_intervention_affects_downstream() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph).unwrap();

        let high_intervened = scm.intervene("transaction_amount", 100000.0).unwrap();
        let high_samples = high_intervened.generate(200, 42).unwrap();
        let high_fraud_rate: f64 = high_samples
            .iter()
            .map(|s| s.get("is_fraud").copied().unwrap_or(0.0))
            .sum::<f64>()
            / 200.0;

        let low_intervened = scm.intervene("transaction_amount", 1.0).unwrap();
        let low_samples = low_intervened.generate(200, 42).unwrap();
        let low_fraud_rate: f64 = low_samples
            .iter()
            .map(|s| s.get("is_fraud").copied().unwrap_or(0.0))
            .sum::<f64>()
            / 200.0;

        assert!(
            high_fraud_rate >= low_fraud_rate,
            "High transaction amount ({}) should increase fraud rate ({} vs {})",
            100000.0,
            high_fraud_rate,
            low_fraud_rate
        );
    }

    #[test]
    fn test_intervention_unknown_variable() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph).unwrap();
        assert!(scm.intervene("nonexistent", 0.0).is_err());
    }

    #[test]
    fn test_cyclic_graph_rejected_by_scm() {
        use super::super::graph::{CausalEdge, CausalMechanism, CausalVarType, CausalVariable};
        let mut graph = CausalGraph::new();
        graph.add_variable(CausalVariable::new("a", CausalVarType::Continuous));
        graph.add_variable(CausalVariable::new("b", CausalVarType::Continuous));
        graph.add_edge(CausalEdge {
            from: "a".into(),
            to: "b".into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        });
        graph.add_edge(CausalEdge {
            from: "b".into(),
            to: "a".into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        });
        assert!(StructuralCausalModel::new(graph).is_err());
    }
}