1use std::collections::HashMap;
2
3use rand::Rng;
4use rand::SeedableRng;
5use rand_chacha::ChaCha8Rng;
6use rand_distr::{Distribution, LogNormal, Normal};
7
8use super::graph::{CausalGraph, CausalVarType, CausalVariable};
9
10pub struct StructuralCausalModel {
12 graph: CausalGraph,
13}
14
15impl StructuralCausalModel {
16 pub fn new(graph: CausalGraph) -> Result<Self, String> {
17 graph.validate()?;
18 Ok(Self { graph })
19 }
20
21 pub fn graph(&self) -> &CausalGraph {
23 &self.graph
24 }
25
26 pub fn generate(
28 &self,
29 n_samples: usize,
30 seed: u64,
31 ) -> Result<Vec<HashMap<String, f64>>, String> {
32 let order = self.graph.topological_order()?;
33 let mut rng = ChaCha8Rng::seed_from_u64(seed);
34 let mut samples = Vec::with_capacity(n_samples);
35
36 for _ in 0..n_samples {
37 let mut record: HashMap<String, f64> = HashMap::new();
38
39 for var_name in &order {
40 let var = self
41 .graph
42 .get_variable(var_name)
43 .ok_or_else(|| format!("Variable '{}' not found", var_name))?;
44
45 let noise = self.sample_exogenous(var, &mut rng);
47
48 let parent_edges = self.graph.parent_edges(var_name);
50 let parent_contribution: f64 = parent_edges
51 .iter()
52 .map(|edge| {
53 let parent_val = record.get(&edge.from).copied().unwrap_or(0.0);
54 edge.mechanism.apply(parent_val) * edge.strength
55 })
56 .sum();
57
58 let value = match var.var_type {
60 CausalVarType::Binary => {
61 let prob = (noise + parent_contribution).clamp(0.0, 1.0);
62 if rng.gen::<f64>() < prob {
63 1.0
64 } else {
65 0.0
66 }
67 }
68 CausalVarType::Count => (noise + parent_contribution).max(0.0).round(),
69 _ => noise + parent_contribution,
70 };
71
72 record.insert(var_name.clone(), value);
73 }
74
75 samples.push(record);
76 }
77
78 Ok(samples)
79 }
80
81 fn sample_exogenous(&self, var: &CausalVariable, rng: &mut ChaCha8Rng) -> f64 {
83 let dist = var.distribution.as_deref().unwrap_or("normal");
84 match dist {
85 "lognormal" => {
86 let mu = var.params.get("mu").copied().unwrap_or(0.0);
87 let sigma = var.params.get("sigma").copied().unwrap_or(1.0);
88 if let Ok(d) = LogNormal::new(mu, sigma) {
89 d.sample(rng)
90 } else {
91 0.0
92 }
93 }
94 "beta" => {
95 let alpha = var.params.get("alpha").copied().unwrap_or(2.0);
97 let beta_param = var.params.get("beta_param").copied().unwrap_or(2.0);
98 let mean = alpha / (alpha + beta_param);
99 let var_val = (alpha * beta_param)
100 / ((alpha + beta_param).powi(2) * (alpha + beta_param + 1.0));
101 if let Ok(d) = Normal::new(mean, var_val.sqrt()) {
102 d.sample(rng).clamp(0.0, 1.0) } else {
104 mean
105 }
106 }
107 "uniform" => {
108 let low = var.params.get("low").copied().unwrap_or(0.0);
109 let high = var.params.get("high").copied().unwrap_or(1.0);
110 rng.gen::<f64>() * (high - low) + low
111 }
112 _ => {
113 let mean = var.params.get("mean").copied().unwrap_or(0.0);
115 let std = var.params.get("std").copied().unwrap_or(1.0);
116 if let Ok(d) = Normal::new(mean, std) {
117 d.sample(rng)
118 } else {
119 mean
120 }
121 }
122 }
123 }
124
125 pub fn intervene(&self, variable: &str, value: f64) -> Result<IntervenedScm<'_>, String> {
128 if self.graph.get_variable(variable).is_none() {
130 return Err(format!(
131 "Variable '{}' not found for intervention",
132 variable
133 ));
134 }
135 Ok(IntervenedScm {
136 base: self,
137 interventions: vec![(variable.to_string(), value)],
138 })
139 }
140}
141
142pub struct IntervenedScm<'a> {
144 base: &'a StructuralCausalModel,
145 interventions: Vec<(String, f64)>,
146}
147
148impl<'a> IntervenedScm<'a> {
149 pub fn and_intervene(mut self, variable: &str, value: f64) -> Self {
151 self.interventions.push((variable.to_string(), value));
152 self
153 }
154
155 pub fn generate(
157 &self,
158 n_samples: usize,
159 seed: u64,
160 ) -> Result<Vec<HashMap<String, f64>>, String> {
161 let order = self.base.graph.topological_order()?;
162 let mut rng = ChaCha8Rng::seed_from_u64(seed);
163 let intervention_map: HashMap<&str, f64> = self
164 .interventions
165 .iter()
166 .map(|(k, v)| (k.as_str(), *v))
167 .collect();
168 let mut samples = Vec::with_capacity(n_samples);
169
170 for _ in 0..n_samples {
171 let mut record: HashMap<String, f64> = HashMap::new();
172
173 for var_name in &order {
174 if let Some(&fixed_val) = intervention_map.get(var_name.as_str()) {
176 record.insert(var_name.clone(), fixed_val);
177 continue;
178 }
179
180 let var = self
181 .base
182 .graph
183 .get_variable(var_name)
184 .ok_or_else(|| format!("Variable '{}' not found", var_name))?;
185
186 let noise = self.base.sample_exogenous(var, &mut rng);
187 let parent_edges = self.base.graph.parent_edges(var_name);
188 let parent_contribution: f64 = parent_edges
189 .iter()
190 .map(|edge| {
191 let parent_val = record.get(&edge.from).copied().unwrap_or(0.0);
192 edge.mechanism.apply(parent_val) * edge.strength
193 })
194 .sum();
195
196 let value = match var.var_type {
197 CausalVarType::Binary => {
198 let prob = (noise + parent_contribution).clamp(0.0, 1.0);
199 if rng.gen::<f64>() < prob {
200 1.0
201 } else {
202 0.0
203 }
204 }
205 CausalVarType::Count => (noise + parent_contribution).max(0.0).round(),
206 _ => noise + parent_contribution,
207 };
208
209 record.insert(var_name.clone(), value);
210 }
211
212 samples.push(record);
213 }
214
215 Ok(samples)
216 }
217}
218
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::super::graph::CausalGraph;
    use super::*;

    /// Builds the model used by most tests from the fraud-detection template.
    fn fraud_scm() -> StructuralCausalModel {
        StructuralCausalModel::new(CausalGraph::fraud_detection_template()).unwrap()
    }

    /// Sums `is_fraud` over `samples`, counting a missing entry as 0.0.
    fn is_fraud_total(samples: &[std::collections::HashMap<String, f64>]) -> f64 {
        samples
            .iter()
            .map(|s| s.get("is_fraud").copied().unwrap_or(0.0))
            .sum::<f64>()
    }

    #[test]
    fn test_scm_generates_correct_count() {
        let samples = fraud_scm().generate(100, 42).unwrap();
        assert_eq!(samples.len(), 100);
    }

    #[test]
    fn test_scm_deterministic() {
        let scm = fraud_scm();
        let first = scm.generate(50, 42).unwrap();
        let second = scm.generate(50, 42).unwrap();
        // The same seed must reproduce the same values record by record.
        for (a, b) in first.iter().zip(second.iter()) {
            assert_eq!(a.get("transaction_amount"), b.get("transaction_amount"));
        }
    }

    #[test]
    fn test_scm_all_variables_present() {
        let graph = CausalGraph::fraud_detection_template();
        // Capture the names before the graph is moved into the model.
        let var_names: Vec<String> = graph.variables.iter().map(|v| v.name.clone()).collect();
        let scm = StructuralCausalModel::new(graph).unwrap();

        for sample in &scm.generate(10, 42).unwrap() {
            for name in &var_names {
                assert!(
                    sample.contains_key(name),
                    "Sample missing variable '{}'",
                    name
                );
            }
        }
    }

    #[test]
    fn test_scm_is_fraud_binary() {
        for sample in &fraud_scm().generate(100, 42).unwrap() {
            // -1.0 stands in for a missing value so that it fails the check.
            let val = sample.get("is_fraud").copied().unwrap_or(-1.0);
            assert!(
                val == 0.0 || val == 1.0,
                "is_fraud should be binary, got {}",
                val
            );
        }
    }

    #[test]
    fn test_intervention_sets_value() {
        let scm = fraud_scm();
        let samples = scm
            .intervene("transaction_amount", 10000.0)
            .unwrap()
            .generate(50, 42)
            .unwrap();
        for sample in &samples {
            assert_eq!(sample.get("transaction_amount").copied(), Some(10000.0));
        }
    }

    #[test]
    fn test_intervention_affects_downstream() {
        let scm = fraud_scm();

        let high_samples = scm
            .intervene("transaction_amount", 100000.0)
            .unwrap()
            .generate(200, 42)
            .unwrap();
        let high_fraud_rate: f64 = is_fraud_total(&high_samples) / 200.0;

        let low_samples = scm
            .intervene("transaction_amount", 1.0)
            .unwrap()
            .generate(200, 42)
            .unwrap();
        let low_fraud_rate: f64 = is_fraud_total(&low_samples) / 200.0;

        assert!(
            high_fraud_rate >= low_fraud_rate,
            "High transaction amount ({}) should increase fraud rate ({} vs {})",
            100000.0,
            high_fraud_rate,
            low_fraud_rate
        );
    }

    #[test]
    fn test_intervention_unknown_variable() {
        assert!(fraud_scm().intervene("nonexistent", 0.0).is_err());
    }

    #[test]
    fn test_cyclic_graph_rejected_by_scm() {
        use super::super::graph::{CausalEdge, CausalMechanism, CausalVarType, CausalVariable};

        /// Unit-strength linear edge with coefficient 1.0 from `from` to `to`.
        fn unit_linear_edge(from: &str, to: &str) -> CausalEdge {
            CausalEdge {
                from: from.into(),
                to: to.into(),
                mechanism: CausalMechanism::Linear { coefficient: 1.0 },
                strength: 1.0,
            }
        }

        let mut graph = CausalGraph::new();
        graph.add_variable(CausalVariable::new("a", CausalVarType::Continuous));
        graph.add_variable(CausalVariable::new("b", CausalVarType::Continuous));
        // a -> b and b -> a form a two-node cycle.
        graph.add_edge(unit_linear_edge("a", "b"));
        graph.add_edge(unit_linear_edge("b", "a"));

        assert!(StructuralCausalModel::new(graph).is_err());
    }
}