1use std::collections::HashMap;
2
3use rand::Rng;
4use rand::SeedableRng;
5use rand_chacha::ChaCha8Rng;
6use rand_distr::{Beta, Distribution, LogNormal, Normal};
7
8use super::graph::{CausalGraph, CausalVarType, CausalVariable};
9
10pub struct StructuralCausalModel {
12 graph: CausalGraph,
13}
14
15impl StructuralCausalModel {
16 pub fn new(graph: CausalGraph) -> Result<Self, String> {
17 graph.validate()?;
18 Ok(Self { graph })
19 }
20
21 pub fn graph(&self) -> &CausalGraph {
23 &self.graph
24 }
25
26 pub fn generate(
28 &self,
29 n_samples: usize,
30 seed: u64,
31 ) -> Result<Vec<HashMap<String, f64>>, String> {
32 let order = self.graph.topological_order()?;
33 let mut rng = ChaCha8Rng::seed_from_u64(seed);
34 let mut samples = Vec::with_capacity(n_samples);
35
36 for _ in 0..n_samples {
37 let mut record: HashMap<String, f64> = HashMap::new();
38
39 for var_name in &order {
40 let var = self
41 .graph
42 .get_variable(var_name)
43 .ok_or_else(|| format!("Variable '{}' not found", var_name))?;
44
45 let noise = self.sample_exogenous(var, &mut rng);
47
48 let parent_edges = self.graph.parent_edges(var_name);
50 let parent_contribution: f64 = parent_edges
51 .iter()
52 .map(|edge| {
53 let parent_val = record.get(&edge.from).copied().unwrap_or(0.0);
54 edge.mechanism.apply(parent_val) * edge.strength
55 })
56 .sum();
57
58 let value = match var.var_type {
60 CausalVarType::Binary => {
61 let prob = (noise + parent_contribution).clamp(0.0, 1.0);
62 if rng.random::<f64>() < prob {
63 1.0
64 } else {
65 0.0
66 }
67 }
68 CausalVarType::Count => (noise + parent_contribution).max(0.0).round(),
69 _ => noise + parent_contribution,
70 };
71
72 record.insert(var_name.clone(), value);
73 }
74
75 samples.push(record);
76 }
77
78 Ok(samples)
79 }
80
81 fn sample_exogenous(&self, var: &CausalVariable, rng: &mut ChaCha8Rng) -> f64 {
83 let dist = var.distribution.as_deref().unwrap_or("normal");
84 match dist {
85 "lognormal" => {
86 let mu = var.params.get("mu").copied().unwrap_or(0.0);
87 let sigma = var.params.get("sigma").copied().unwrap_or(1.0);
88 if let Ok(d) = LogNormal::new(mu, sigma) {
89 d.sample(rng)
90 } else {
91 0.0
92 }
93 }
94 "beta" => {
95 let alpha = var.params.get("alpha").copied().unwrap_or(2.0);
96 let beta_param = var.params.get("beta_param").copied().unwrap_or(2.0);
97 if let Ok(d) = Beta::new(alpha, beta_param) {
98 d.sample(rng)
99 } else {
100 let sum = alpha + beta_param;
102 if sum > 0.0 {
103 alpha / sum
104 } else {
105 0.5
106 }
107 }
108 }
109 "uniform" => {
110 let low = var.params.get("low").copied().unwrap_or(0.0);
111 let high = var.params.get("high").copied().unwrap_or(1.0);
112 rng.random::<f64>() * (high - low) + low
113 }
114 _ => {
115 let mean = var.params.get("mean").copied().unwrap_or(0.0);
117 let std = var.params.get("std").copied().unwrap_or(1.0);
118 if let Ok(d) = Normal::new(mean, std) {
119 d.sample(rng)
120 } else {
121 mean
122 }
123 }
124 }
125 }
126
127 pub fn intervene(&self, variable: &str, value: f64) -> Result<IntervenedScm<'_>, String> {
130 if self.graph.get_variable(variable).is_none() {
132 return Err(format!(
133 "Variable '{}' not found for intervention",
134 variable
135 ));
136 }
137 Ok(IntervenedScm {
138 base: self,
139 interventions: vec![(variable.to_string(), value)],
140 })
141 }
142}
143
144pub struct IntervenedScm<'a> {
146 base: &'a StructuralCausalModel,
147 interventions: Vec<(String, f64)>,
148}
149
150impl<'a> IntervenedScm<'a> {
151 pub fn and_intervene(mut self, variable: &str, value: f64) -> Self {
153 self.interventions.push((variable.to_string(), value));
154 self
155 }
156
157 pub fn generate(
159 &self,
160 n_samples: usize,
161 seed: u64,
162 ) -> Result<Vec<HashMap<String, f64>>, String> {
163 let order = self.base.graph.topological_order()?;
164 let mut rng = ChaCha8Rng::seed_from_u64(seed);
165 let intervention_map: HashMap<&str, f64> = self
166 .interventions
167 .iter()
168 .map(|(k, v)| (k.as_str(), *v))
169 .collect();
170 let mut samples = Vec::with_capacity(n_samples);
171
172 for _ in 0..n_samples {
173 let mut record: HashMap<String, f64> = HashMap::new();
174
175 for var_name in &order {
176 if let Some(&fixed_val) = intervention_map.get(var_name.as_str()) {
178 record.insert(var_name.clone(), fixed_val);
179 continue;
180 }
181
182 let var = self
183 .base
184 .graph
185 .get_variable(var_name)
186 .ok_or_else(|| format!("Variable '{}' not found", var_name))?;
187
188 let noise = self.base.sample_exogenous(var, &mut rng);
189 let parent_edges = self.base.graph.parent_edges(var_name);
190 let parent_contribution: f64 = parent_edges
191 .iter()
192 .map(|edge| {
193 let parent_val = record.get(&edge.from).copied().unwrap_or(0.0);
194 edge.mechanism.apply(parent_val) * edge.strength
195 })
196 .sum();
197
198 let value = match var.var_type {
199 CausalVarType::Binary => {
200 let prob = (noise + parent_contribution).clamp(0.0, 1.0);
201 if rng.random::<f64>() < prob {
202 1.0
203 } else {
204 0.0
205 }
206 }
207 CausalVarType::Count => (noise + parent_contribution).max(0.0).round(),
208 _ => noise + parent_contribution,
209 };
210
211 record.insert(var_name.clone(), value);
212 }
213
214 samples.push(record);
215 }
216
217 Ok(samples)
218 }
219}
220
#[cfg(test)]
// Test failures are reported by panicking, so `unwrap` on setup steps is intentional.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::super::graph::CausalGraph;
    use super::*;

    /// Builds an SCM over the fraud-detection template graph.
    fn fraud_scm() -> StructuralCausalModel {
        StructuralCausalModel::new(CausalGraph::fraud_detection_template()).unwrap()
    }

    /// Mean of `is_fraud` across `samples`, treating a missing value as 0.0.
    fn fraud_rate(samples: &[HashMap<String, f64>]) -> f64 {
        let total: f64 = samples
            .iter()
            .map(|s| s.get("is_fraud").copied().unwrap_or(0.0))
            .sum();
        total / samples.len() as f64
    }

    #[test]
    fn test_scm_generates_correct_count() {
        let samples = fraud_scm().generate(100, 42).unwrap();
        assert_eq!(samples.len(), 100);
    }

    #[test]
    fn test_scm_deterministic() {
        let scm = fraud_scm();
        let s1 = scm.generate(50, 42).unwrap();
        let s2 = scm.generate(50, 42).unwrap();
        assert_eq!(s1.len(), s2.len());
        for (a, b) in s1.iter().zip(&s2) {
            assert_eq!(a.len(), b.len(), "records have different variable sets");
            for (name, value) in a {
                // Compare bit patterns so that a reproducible NaN still counts as equal.
                assert_eq!(
                    b.get(name).map(|v| v.to_bits()),
                    Some(value.to_bits()),
                    "variable '{}' differs between runs with the same seed",
                    name
                );
            }
        }
    }

    #[test]
    fn test_scm_all_variables_present() {
        let scm = fraud_scm();
        let samples = scm.generate(10, 42).unwrap();
        for sample in &samples {
            for var in scm.graph().variables.iter() {
                assert!(
                    sample.contains_key(&var.name),
                    "Sample missing variable '{}'",
                    var.name
                );
            }
        }
    }

    #[test]
    fn test_scm_is_fraud_binary() {
        let samples = fraud_scm().generate(100, 42).unwrap();
        for sample in &samples {
            let val = sample.get("is_fraud").copied().unwrap_or(-1.0);
            assert!(
                val == 0.0 || val == 1.0,
                "is_fraud should be binary, got {}",
                val
            );
        }
    }

    #[test]
    fn test_intervention_sets_value() {
        let scm = fraud_scm();
        let intervened = scm.intervene("transaction_amount", 10000.0).unwrap();
        let samples = intervened.generate(50, 42).unwrap();
        for sample in &samples {
            assert_eq!(sample.get("transaction_amount").copied(), Some(10000.0));
        }
    }

    #[test]
    fn test_intervention_affects_downstream() {
        let scm = fraud_scm();

        let high_samples = scm
            .intervene("transaction_amount", 100000.0)
            .unwrap()
            .generate(200, 42)
            .unwrap();
        let low_samples = scm
            .intervene("transaction_amount", 1.0)
            .unwrap()
            .generate(200, 42)
            .unwrap();

        let high_fraud_rate = fraud_rate(&high_samples);
        let low_fraud_rate = fraud_rate(&low_samples);

        assert!(
            high_fraud_rate >= low_fraud_rate,
            "High transaction amount ({}) should increase fraud rate ({} vs {})",
            100000.0,
            high_fraud_rate,
            low_fraud_rate
        );
    }

    #[test]
    fn test_intervention_unknown_variable() {
        assert!(fraud_scm().intervene("nonexistent", 0.0).is_err());
    }

    #[test]
    fn test_cyclic_graph_rejected_by_scm() {
        use super::super::graph::{CausalEdge, CausalMechanism, CausalVarType, CausalVariable};
        let mut graph = CausalGraph::new();
        graph.add_variable(CausalVariable::new("a", CausalVarType::Continuous));
        graph.add_variable(CausalVariable::new("b", CausalVarType::Continuous));
        graph.add_edge(CausalEdge {
            from: "a".into(),
            to: "b".into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        });
        graph.add_edge(CausalEdge {
            from: "b".into(),
            to: "a".into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        });
        assert!(StructuralCausalModel::new(graph).is_err());
    }
}