1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet, VecDeque};
3
4#[derive(Debug, Clone, Default, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum CausalVarType {
8 #[default]
9 Continuous,
10 Categorical,
11 Count,
12 Binary,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CausalVariable {
18 pub name: String,
19 #[serde(default)]
20 pub var_type: CausalVarType,
21 #[serde(default)]
23 pub distribution: Option<String>,
24 #[serde(default)]
26 pub params: HashMap<String, f64>,
27}
28
29impl CausalVariable {
30 pub fn new(name: impl Into<String>, var_type: CausalVarType) -> Self {
31 Self {
32 name: name.into(),
33 var_type,
34 distribution: None,
35 params: HashMap::new(),
36 }
37 }
38
39 pub fn with_distribution(mut self, dist: impl Into<String>) -> Self {
40 self.distribution = Some(dist.into());
41 self
42 }
43
44 pub fn with_param(mut self, key: impl Into<String>, value: f64) -> Self {
45 self.params.insert(key.into(), value);
46 self
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum CausalMechanism {
54 Linear { coefficient: f64 },
56 Threshold { cutoff: f64 },
58 Polynomial { coefficients: Vec<f64> },
60 Logistic { scale: f64, midpoint: f64 },
62}
63
64impl CausalMechanism {
65 pub fn apply(&self, parent_value: f64) -> f64 {
67 match self {
68 CausalMechanism::Linear { coefficient } => coefficient * parent_value,
69 CausalMechanism::Threshold { cutoff } => {
70 if parent_value > *cutoff {
71 1.0
72 } else {
73 0.0
74 }
75 }
76 CausalMechanism::Polynomial { coefficients } => coefficients
77 .iter()
78 .enumerate()
79 .map(|(i, c)| c * parent_value.powi(i as i32))
80 .sum(),
81 CausalMechanism::Logistic { scale, midpoint } => {
82 1.0 / (1.0 + (-scale * (parent_value - midpoint)).exp())
83 }
84 }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct CausalEdge {
91 pub from: String,
92 pub to: String,
93 pub mechanism: CausalMechanism,
94 #[serde(default = "default_strength")]
95 pub strength: f64,
96}
97
/// Serde fallback for [`CausalEdge::strength`]. Edges that do not specify a
/// strength get unit weight.
fn default_strength() -> f64 {
    const UNIT_STRENGTH: f64 = 1.0;
    UNIT_STRENGTH
}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct CausalGraph {
105 pub variables: Vec<CausalVariable>,
106 pub edges: Vec<CausalEdge>,
107}
108
109impl CausalGraph {
110 pub fn new() -> Self {
111 Self {
112 variables: Vec::new(),
113 edges: Vec::new(),
114 }
115 }
116
117 pub fn add_variable(&mut self, var: CausalVariable) {
118 self.variables.push(var);
119 }
120
121 pub fn add_edge(&mut self, edge: CausalEdge) {
122 self.edges.push(edge);
123 }
124
125 pub fn variable_names(&self) -> Vec<&str> {
127 self.variables.iter().map(|v| v.name.as_str()).collect()
128 }
129
130 pub fn get_variable(&self, name: &str) -> Option<&CausalVariable> {
132 self.variables.iter().find(|v| v.name == name)
133 }
134
135 pub fn parent_edges(&self, variable: &str) -> Vec<&CausalEdge> {
137 self.edges.iter().filter(|e| e.to == variable).collect()
138 }
139
140 pub fn validate(&self) -> Result<(), String> {
142 let var_names: HashSet<&str> = self.variables.iter().map(|v| v.name.as_str()).collect();
143
144 for edge in &self.edges {
146 if edge.from == edge.to {
147 return Err(format!("Self-loop detected on variable '{}'", edge.from));
148 }
149 }
150
151 for edge in &self.edges {
153 if !var_names.contains(edge.from.as_str()) {
154 return Err(format!("Edge references unknown variable '{}'", edge.from));
155 }
156 if !var_names.contains(edge.to.as_str()) {
157 return Err(format!("Edge references unknown variable '{}'", edge.to));
158 }
159 }
160
161 self.topological_order().map(|_| ())
163 }
164
165 pub fn topological_order(&self) -> Result<Vec<String>, String> {
167 let var_names: Vec<String> = self.variables.iter().map(|v| v.name.clone()).collect();
168 let n = var_names.len();
169 let name_to_idx: HashMap<&str, usize> = var_names
170 .iter()
171 .enumerate()
172 .map(|(i, n)| (n.as_str(), i))
173 .collect();
174
175 let mut in_degree = vec![0usize; n];
177 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
178
179 for edge in &self.edges {
180 if let (Some(&from_idx), Some(&to_idx)) = (
181 name_to_idx.get(edge.from.as_str()),
182 name_to_idx.get(edge.to.as_str()),
183 ) {
184 adj[from_idx].push(to_idx);
185 in_degree[to_idx] += 1;
186 }
187 }
188
189 let mut queue: VecDeque<usize> = VecDeque::new();
191 for (i, °) in in_degree.iter().enumerate() {
192 if deg == 0 {
193 queue.push_back(i);
194 }
195 }
196
197 let mut order = Vec::with_capacity(n);
198 while let Some(node) = queue.pop_front() {
199 order.push(var_names[node].clone());
200 for &neighbor in &adj[node] {
201 in_degree[neighbor] -= 1;
202 if in_degree[neighbor] == 0 {
203 queue.push_back(neighbor);
204 }
205 }
206 }
207
208 if order.len() != n {
209 Err("Causal graph contains a cycle".to_string())
210 } else {
211 Ok(order)
212 }
213 }
214
215 pub fn fraud_detection_template() -> Self {
217 let mut graph = Self::new();
218 graph.add_variable(
219 CausalVariable::new("transaction_amount", CausalVarType::Continuous)
220 .with_distribution("lognormal")
221 .with_param("mu", 6.0)
222 .with_param("sigma", 1.5),
223 );
224 graph.add_variable(
225 CausalVariable::new("merchant_risk", CausalVarType::Continuous)
226 .with_distribution("beta")
227 .with_param("alpha", 2.0)
228 .with_param("beta_param", 5.0),
229 );
230 graph.add_variable(
231 CausalVariable::new("transaction_frequency", CausalVarType::Count)
232 .with_distribution("normal")
233 .with_param("mean", 10.0)
234 .with_param("std", 3.0),
235 );
236 graph.add_variable(CausalVariable::new(
237 "fraud_probability",
238 CausalVarType::Continuous,
239 ));
240 graph.add_variable(CausalVariable::new("is_fraud", CausalVarType::Binary));
241
242 graph.add_edge(CausalEdge {
243 from: "transaction_amount".into(),
244 to: "fraud_probability".into(),
245 mechanism: CausalMechanism::Linear { coefficient: 0.3 },
246 strength: 1.0,
247 });
248 graph.add_edge(CausalEdge {
249 from: "merchant_risk".into(),
250 to: "fraud_probability".into(),
251 mechanism: CausalMechanism::Linear { coefficient: 0.5 },
252 strength: 1.0,
253 });
254 graph.add_edge(CausalEdge {
255 from: "transaction_frequency".into(),
256 to: "fraud_probability".into(),
257 mechanism: CausalMechanism::Linear { coefficient: 0.2 },
258 strength: 1.0,
259 });
260 graph.add_edge(CausalEdge {
261 from: "fraud_probability".into(),
262 to: "is_fraud".into(),
263 mechanism: CausalMechanism::Threshold { cutoff: 0.7 },
264 strength: 1.0,
265 });
266
267 graph
268 }
269
270 pub fn revenue_cycle_template() -> Self {
272 let mut graph = Self::new();
273 graph.add_variable(
274 CausalVariable::new("order_volume", CausalVarType::Continuous)
275 .with_distribution("normal")
276 .with_param("mean", 100.0)
277 .with_param("std", 30.0),
278 );
279 graph.add_variable(
280 CausalVariable::new("shipment_rate", CausalVarType::Continuous)
281 .with_distribution("beta")
282 .with_param("alpha", 8.0)
283 .with_param("beta_param", 2.0),
284 );
285 graph.add_variable(CausalVariable::new(
286 "invoice_amount",
287 CausalVarType::Continuous,
288 ));
289 graph.add_variable(CausalVariable::new(
290 "collection_rate",
291 CausalVarType::Continuous,
292 ));
293
294 graph.add_edge(CausalEdge {
295 from: "order_volume".into(),
296 to: "shipment_rate".into(),
297 mechanism: CausalMechanism::Logistic {
298 scale: 0.05,
299 midpoint: 50.0,
300 },
301 strength: 1.0,
302 });
303 graph.add_edge(CausalEdge {
304 from: "order_volume".into(),
305 to: "invoice_amount".into(),
306 mechanism: CausalMechanism::Linear { coefficient: 100.0 },
307 strength: 1.0,
308 });
309 graph.add_edge(CausalEdge {
310 from: "shipment_rate".into(),
311 to: "invoice_amount".into(),
312 mechanism: CausalMechanism::Linear { coefficient: 0.5 },
313 strength: 1.0,
314 });
315 graph.add_edge(CausalEdge {
316 from: "invoice_amount".into(),
317 to: "collection_rate".into(),
318 mechanism: CausalMechanism::Logistic {
319 scale: -0.0001,
320 midpoint: 5000.0,
321 },
322 strength: 1.0,
323 });
324
325 graph
326 }
327}
328
329impl Default for CausalGraph {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unit-strength linear edge for the hand-built graphs below.
    fn linear_edge(from: &str, to: &str) -> CausalEdge {
        CausalEdge {
            from: from.into(),
            to: to.into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        }
    }

    #[test]
    fn test_acyclic_graph_validates() {
        let graph = CausalGraph::fraud_detection_template();
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_cyclic_graph_rejected() {
        let mut graph = CausalGraph::new();
        graph.add_variable(CausalVariable::new("a", CausalVarType::Continuous));
        graph.add_variable(CausalVariable::new("b", CausalVarType::Continuous));
        graph.add_edge(linear_edge("a", "b"));
        graph.add_edge(linear_edge("b", "a"));
        let result = graph.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("cycle"));
    }

    #[test]
    fn test_self_loop_rejected() {
        let mut graph = CausalGraph::new();
        graph.add_variable(CausalVariable::new("a", CausalVarType::Continuous));
        graph.add_edge(linear_edge("a", "a"));
        let result = graph.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Self-loop"));
    }

    #[test]
    fn test_topological_order() {
        let graph = CausalGraph::fraud_detection_template();
        let order = graph.topological_order().unwrap();
        let amount_pos = order
            .iter()
            .position(|n| n == "transaction_amount")
            .unwrap();
        let fraud_prob_pos = order.iter().position(|n| n == "fraud_probability").unwrap();
        let is_fraud_pos = order.iter().position(|n| n == "is_fraud").unwrap();
        assert!(amount_pos < fraud_prob_pos);
        assert!(fraud_prob_pos < is_fraud_pos);
    }

    #[test]
    fn test_revenue_cycle_topological_order() {
        let graph = CausalGraph::revenue_cycle_template();
        let order = graph.topological_order().unwrap();
        assert_eq!(order.len(), graph.variables.len());
        let pos = |name: &str| order.iter().position(|n| n == name).unwrap();
        assert!(pos("order_volume") < pos("shipment_rate"));
        assert!(pos("shipment_rate") < pos("invoice_amount"));
        assert!(pos("invoice_amount") < pos("collection_rate"));
    }

    #[test]
    fn test_unknown_variable_rejected() {
        let mut graph = CausalGraph::new();
        graph.add_variable(CausalVariable::new("a", CausalVarType::Continuous));
        graph.add_edge(linear_edge("a", "nonexistent"));
        let result = graph.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("nonexistent"));
    }

    #[test]
    fn test_parent_edges() {
        let graph = CausalGraph::fraud_detection_template();
        assert_eq!(graph.parent_edges("fraud_probability").len(), 3);
        assert_eq!(graph.parent_edges("is_fraud").len(), 1);
        assert!(graph.parent_edges("transaction_amount").is_empty());
    }

    #[test]
    fn test_get_variable() {
        let graph = CausalGraph::fraud_detection_template();
        let risk = graph.get_variable("merchant_risk").unwrap();
        assert_eq!(risk.distribution.as_deref(), Some("beta"));
        assert_eq!(risk.params.get("alpha"), Some(&2.0));
        assert!(graph.get_variable("missing").is_none());
    }

    #[test]
    fn test_default_graph_is_empty() {
        let graph = CausalGraph::default();
        assert!(graph.variables.is_empty());
        assert!(graph.edges.is_empty());
        assert!(graph.variable_names().is_empty());
    }

    #[test]
    fn test_mechanism_linear() {
        let m = CausalMechanism::Linear { coefficient: 2.0 };
        assert!((m.apply(3.0) - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_mechanism_threshold() {
        let m = CausalMechanism::Threshold { cutoff: 0.5 };
        assert!((m.apply(0.3) - 0.0).abs() < 1e-10);
        assert!((m.apply(0.7) - 1.0).abs() < 1e-10);
        // The comparison is strict, so a value exactly at the cutoff does not fire.
        assert!((m.apply(0.5) - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_mechanism_logistic() {
        let m = CausalMechanism::Logistic {
            scale: 1.0,
            midpoint: 0.0,
        };
        assert!((m.apply(0.0) - 0.5).abs() < 1e-10);
        assert!(m.apply(10.0) > 0.99);
        assert!(m.apply(-10.0) < 0.01);
    }

    #[test]
    fn test_mechanism_polynomial() {
        let m = CausalMechanism::Polynomial {
            coefficients: vec![1.0, 2.0, 3.0],
        };
        // 1 + 2*2 + 3*2^2 = 17
        assert!((m.apply(2.0) - 17.0).abs() < 1e-10);

        let empty = CausalMechanism::Polynomial {
            coefficients: Vec::new(),
        };
        assert!(empty.apply(5.0).abs() < 1e-10);
    }

    #[test]
    fn test_revenue_cycle_validates() {
        let graph = CausalGraph::revenue_cycle_template();
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_graph_serde_roundtrip() {
        let graph = CausalGraph::fraud_detection_template();
        let json = serde_json::to_string(&graph).unwrap();
        let deserialized: CausalGraph = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.variables.len(), graph.variables.len());
        assert_eq!(deserialized.edges.len(), graph.edges.len());
    }

    #[test]
    fn test_serde_field_defaults() {
        let var: CausalVariable = serde_json::from_str(r#"{"name":"x"}"#).unwrap();
        assert_eq!(var.name, "x");
        assert!(matches!(var.var_type, CausalVarType::Continuous));
        assert!(var.distribution.is_none());
        assert!(var.params.is_empty());

        let edge: CausalEdge = serde_json::from_str(
            r#"{"from":"a","to":"b","mechanism":{"type":"linear","coefficient":2.0}}"#,
        )
        .unwrap();
        assert!((edge.strength - 1.0).abs() < 1e-10);
        assert!((edge.mechanism.apply(3.0) - 6.0).abs() < 1e-10);
    }
}