1use std::collections::HashMap;
2
3use rand::RngExt;
4use rand::SeedableRng;
5use rand_chacha::ChaCha8Rng;
6use rand_distr::{Beta, Distribution, LogNormal, Normal};
7
8use super::graph::{CausalGraph, CausalVarType, CausalVariable};
9
10pub struct StructuralCausalModel {
12 graph: CausalGraph,
13}
14
15impl StructuralCausalModel {
16 pub fn new(graph: CausalGraph) -> Result<Self, String> {
17 graph.validate()?;
18 Ok(Self { graph })
19 }
20
21 pub fn graph(&self) -> &CausalGraph {
23 &self.graph
24 }
25
26 pub fn generate(
28 &self,
29 n_samples: usize,
30 seed: u64,
31 ) -> Result<Vec<HashMap<String, f64>>, String> {
32 let order = self.graph.topological_order()?;
33 let mut rng = ChaCha8Rng::seed_from_u64(seed);
34 let mut samples = Vec::with_capacity(n_samples);
35
36 for _ in 0..n_samples {
37 let mut record: HashMap<String, f64> = HashMap::new();
38
39 for var_name in &order {
40 let var = self
41 .graph
42 .get_variable(var_name)
43 .ok_or_else(|| format!("Variable '{var_name}' not found"))?;
44
45 let noise = self.sample_exogenous(var, &mut rng);
47
48 let parent_edges = self.graph.parent_edges(var_name);
50 let parent_contribution: f64 = parent_edges
51 .iter()
52 .map(|edge| {
53 let parent_val = record.get(&edge.from).copied().unwrap_or(0.0);
54 edge.mechanism.apply(parent_val) * edge.strength
55 })
56 .sum();
57
58 let value = match var.var_type {
60 CausalVarType::Binary => {
61 let prob = (noise + parent_contribution).clamp(0.0, 1.0);
62 if rng.random::<f64>() < prob {
63 1.0
64 } else {
65 0.0
66 }
67 }
68 CausalVarType::Count => (noise + parent_contribution).max(0.0).round(),
69 _ => noise + parent_contribution,
70 };
71
72 record.insert(var_name.clone(), value);
73 }
74
75 samples.push(record);
76 }
77
78 Ok(samples)
79 }
80
81 fn sample_exogenous(&self, var: &CausalVariable, rng: &mut ChaCha8Rng) -> f64 {
83 let dist = var.distribution.as_deref().unwrap_or("normal");
84 match dist {
85 "lognormal" => {
86 let mu = var.params.get("mu").copied().unwrap_or(0.0);
87 let sigma = var.params.get("sigma").copied().unwrap_or(1.0);
88 if let Ok(d) = LogNormal::new(mu, sigma) {
89 d.sample(rng)
90 } else {
91 0.0
92 }
93 }
94 "beta" => {
95 let alpha = var.params.get("alpha").copied().unwrap_or(2.0);
96 let beta_param = var.params.get("beta_param").copied().unwrap_or(2.0);
97 if let Ok(d) = Beta::new(alpha, beta_param) {
98 d.sample(rng)
99 } else {
100 let sum = alpha + beta_param;
102 if sum > 0.0 {
103 alpha / sum
104 } else {
105 0.5
106 }
107 }
108 }
109 "uniform" => {
110 let low = var.params.get("low").copied().unwrap_or(0.0);
111 let high = var.params.get("high").copied().unwrap_or(1.0);
112 rng.random::<f64>() * (high - low) + low
113 }
114 _ => {
115 let mean = var.params.get("mean").copied().unwrap_or(0.0);
117 let std = var.params.get("std").copied().unwrap_or(1.0);
118 if let Ok(d) = Normal::new(mean, std) {
119 d.sample(rng)
120 } else {
121 mean
122 }
123 }
124 }
125 }
126
127 pub fn intervene(&self, variable: &str, value: f64) -> Result<IntervenedScm<'_>, String> {
130 if self.graph.get_variable(variable).is_none() {
132 return Err(format!("Variable '{variable}' not found for intervention"));
133 }
134 Ok(IntervenedScm {
135 base: self,
136 interventions: vec![(variable.to_string(), value)],
137 })
138 }
139}
140
141pub struct IntervenedScm<'a> {
143 base: &'a StructuralCausalModel,
144 interventions: Vec<(String, f64)>,
145}
146
147impl<'a> IntervenedScm<'a> {
148 pub fn and_intervene(mut self, variable: &str, value: f64) -> Self {
150 self.interventions.push((variable.to_string(), value));
151 self
152 }
153
154 pub fn generate(
156 &self,
157 n_samples: usize,
158 seed: u64,
159 ) -> Result<Vec<HashMap<String, f64>>, String> {
160 let order = self.base.graph.topological_order()?;
161 let mut rng = ChaCha8Rng::seed_from_u64(seed);
162 let intervention_map: HashMap<&str, f64> = self
163 .interventions
164 .iter()
165 .map(|(k, v)| (k.as_str(), *v))
166 .collect();
167 let mut samples = Vec::with_capacity(n_samples);
168
169 for _ in 0..n_samples {
170 let mut record: HashMap<String, f64> = HashMap::new();
171
172 for var_name in &order {
173 if let Some(&fixed_val) = intervention_map.get(var_name.as_str()) {
175 record.insert(var_name.clone(), fixed_val);
176 continue;
177 }
178
179 let var = self
180 .base
181 .graph
182 .get_variable(var_name)
183 .ok_or_else(|| format!("Variable '{var_name}' not found"))?;
184
185 let noise = self.base.sample_exogenous(var, &mut rng);
186 let parent_edges = self.base.graph.parent_edges(var_name);
187 let parent_contribution: f64 = parent_edges
188 .iter()
189 .map(|edge| {
190 let parent_val = record.get(&edge.from).copied().unwrap_or(0.0);
191 edge.mechanism.apply(parent_val) * edge.strength
192 })
193 .sum();
194
195 let value = match var.var_type {
196 CausalVarType::Binary => {
197 let prob = (noise + parent_contribution).clamp(0.0, 1.0);
198 if rng.random::<f64>() < prob {
199 1.0
200 } else {
201 0.0
202 }
203 }
204 CausalVarType::Count => (noise + parent_contribution).max(0.0).round(),
205 _ => noise + parent_contribution,
206 };
207
208 record.insert(var_name.clone(), value);
209 }
210
211 samples.push(record);
212 }
213
214 Ok(samples)
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::super::graph::CausalGraph;
    use super::*;

    /// Builds an SCM over the fraud-detection template graph used by most tests.
    fn fraud_scm() -> StructuralCausalModel {
        StructuralCausalModel::new(CausalGraph::fraud_detection_template()).unwrap()
    }

    /// Fraction of `samples` whose `is_fraud` value is set (a missing value counts as 0).
    fn fraud_rate(samples: &[HashMap<String, f64>]) -> f64 {
        samples
            .iter()
            .map(|s| s.get("is_fraud").copied().unwrap_or(0.0))
            .sum::<f64>()
            / samples.len() as f64
    }

    #[test]
    fn test_scm_generates_correct_count() {
        let scm = fraud_scm();
        let samples = scm.generate(100, 42).unwrap();
        assert_eq!(samples.len(), 100);
    }

    #[test]
    fn test_scm_generates_zero_samples() {
        let scm = fraud_scm();
        assert!(scm.generate(0, 42).unwrap().is_empty());
    }

    #[test]
    fn test_scm_deterministic() {
        let scm = fraud_scm();
        let s1 = scm.generate(50, 42).unwrap();
        let s2 = scm.generate(50, 42).unwrap();
        assert_eq!(s1.len(), s2.len());
        for (a, b) in s1.iter().zip(&s2) {
            assert_eq!(a.len(), b.len(), "runs produced different variable sets");
            for (name, value) in a {
                let other = b
                    .get(name)
                    .unwrap_or_else(|| panic!("variable '{name}' missing from second run"));
                // Compare bit patterns so that NaN outputs still count as reproducible.
                assert_eq!(
                    value.to_bits(),
                    other.to_bits(),
                    "variable '{name}' differs between runs"
                );
            }
        }
    }

    #[test]
    fn test_scm_all_variables_present() {
        let scm = fraud_scm();
        let var_names: Vec<String> = scm
            .graph()
            .variables
            .iter()
            .map(|v| v.name.clone())
            .collect();
        let samples = scm.generate(10, 42).unwrap();
        for sample in &samples {
            for name in &var_names {
                assert!(
                    sample.contains_key(name),
                    "Sample missing variable '{}'",
                    name
                );
            }
        }
    }

    #[test]
    fn test_scm_is_fraud_binary() {
        let scm = fraud_scm();
        let samples = scm.generate(100, 42).unwrap();
        for sample in &samples {
            let val = sample.get("is_fraud").copied().unwrap_or(-1.0);
            assert!(
                val == 0.0 || val == 1.0,
                "is_fraud should be binary, got {}",
                val
            );
        }
    }

    #[test]
    fn test_intervention_sets_value() {
        let scm = fraud_scm();
        let intervened = scm.intervene("transaction_amount", 10000.0).unwrap();
        let samples = intervened.generate(50, 42).unwrap();
        for sample in &samples {
            assert_eq!(sample.get("transaction_amount").copied(), Some(10000.0));
        }
    }

    #[test]
    fn test_intervention_affects_downstream() {
        let scm = fraud_scm();

        let high_samples = scm
            .intervene("transaction_amount", 100000.0)
            .unwrap()
            .generate(200, 42)
            .unwrap();
        let low_samples = scm
            .intervene("transaction_amount", 1.0)
            .unwrap()
            .generate(200, 42)
            .unwrap();

        let high_fraud_rate = fraud_rate(&high_samples);
        let low_fraud_rate = fraud_rate(&low_samples);

        assert!(
            high_fraud_rate >= low_fraud_rate,
            "High transaction amount ({}) should increase fraud rate ({} vs {})",
            100000.0,
            high_fraud_rate,
            low_fraud_rate
        );
    }

    #[test]
    fn test_intervention_unknown_variable() {
        let scm = fraud_scm();
        assert!(scm.intervene("nonexistent", 0.0).is_err());
    }

    #[test]
    fn test_cyclic_graph_rejected_by_scm() {
        use super::super::graph::{CausalEdge, CausalMechanism, CausalVarType, CausalVariable};
        let mut graph = CausalGraph::new();
        graph.add_variable(CausalVariable::new("a", CausalVarType::Continuous));
        graph.add_variable(CausalVariable::new("b", CausalVarType::Continuous));
        graph.add_edge(CausalEdge {
            from: "a".into(),
            to: "b".into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        });
        graph.add_edge(CausalEdge {
            from: "b".into(),
            to: "a".into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        });
        assert!(StructuralCausalModel::new(graph).is_err());
    }
}