1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet, VecDeque};
3
4#[derive(Debug, Clone, Default, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum CausalVarType {
8 #[default]
9 Continuous,
10 Categorical,
11 Count,
12 Binary,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CausalVariable {
18 pub name: String,
19 #[serde(default)]
20 pub var_type: CausalVarType,
21 #[serde(default)]
23 pub distribution: Option<String>,
24 #[serde(default)]
26 pub params: HashMap<String, f64>,
27}
28
29impl CausalVariable {
30 pub fn new(name: impl Into<String>, var_type: CausalVarType) -> Self {
31 Self {
32 name: name.into(),
33 var_type,
34 distribution: None,
35 params: HashMap::new(),
36 }
37 }
38
39 pub fn with_distribution(mut self, dist: impl Into<String>) -> Self {
40 self.distribution = Some(dist.into());
41 self
42 }
43
44 pub fn with_param(mut self, key: impl Into<String>, value: f64) -> Self {
45 self.params.insert(key.into(), value);
46 self
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum CausalMechanism {
54 Linear { coefficient: f64 },
56 Threshold { cutoff: f64 },
58 Polynomial { coefficients: Vec<f64> },
60 Logistic { scale: f64, midpoint: f64 },
62}
63
64impl CausalMechanism {
65 pub fn apply(&self, parent_value: f64) -> f64 {
67 match self {
68 CausalMechanism::Linear { coefficient } => coefficient * parent_value,
69 CausalMechanism::Threshold { cutoff } => {
70 if parent_value > *cutoff {
71 1.0
72 } else {
73 0.0
74 }
75 }
76 CausalMechanism::Polynomial { coefficients } => coefficients
77 .iter()
78 .enumerate()
79 .map(|(i, c)| c * parent_value.powi(i as i32))
80 .sum(),
81 CausalMechanism::Logistic { scale, midpoint } => {
82 1.0 / (1.0 + (-scale * (parent_value - midpoint)).exp())
83 }
84 }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct CausalEdge {
91 pub from: String,
92 pub to: String,
93 pub mechanism: CausalMechanism,
94 #[serde(default = "default_strength")]
95 pub strength: f64,
96}
97
/// Serde fallback for an edge's `strength` when the field is absent.
fn default_strength() -> f64 {
    1.0_f64
}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct CausalGraph {
105 pub variables: Vec<CausalVariable>,
106 pub edges: Vec<CausalEdge>,
107}
108
109impl CausalGraph {
110 pub fn new() -> Self {
111 Self {
112 variables: Vec::new(),
113 edges: Vec::new(),
114 }
115 }
116
117 pub fn add_variable(&mut self, var: CausalVariable) {
118 self.variables.push(var);
119 }
120
121 pub fn add_edge(&mut self, edge: CausalEdge) {
122 self.edges.push(edge);
123 }
124
125 pub fn variable_names(&self) -> Vec<&str> {
127 self.variables.iter().map(|v| v.name.as_str()).collect()
128 }
129
130 pub fn get_variable(&self, name: &str) -> Option<&CausalVariable> {
132 self.variables.iter().find(|v| v.name == name)
133 }
134
135 pub fn parent_edges(&self, variable: &str) -> Vec<&CausalEdge> {
137 self.edges.iter().filter(|e| e.to == variable).collect()
138 }
139
140 pub fn validate(&self) -> Result<(), String> {
142 let var_names: HashSet<&str> = self.variables.iter().map(|v| v.name.as_str()).collect();
143
144 for edge in &self.edges {
146 if edge.from == edge.to {
147 return Err(format!("Self-loop detected on variable '{}'", edge.from));
148 }
149 }
150
151 for edge in &self.edges {
153 if !var_names.contains(edge.from.as_str()) {
154 return Err(format!("Edge references unknown variable '{}'", edge.from));
155 }
156 if !var_names.contains(edge.to.as_str()) {
157 return Err(format!("Edge references unknown variable '{}'", edge.to));
158 }
159 }
160
161 self.topological_order().map(|_| ())
163 }
164
165 pub fn topological_order(&self) -> Result<Vec<String>, String> {
167 let var_names: Vec<String> = self.variables.iter().map(|v| v.name.clone()).collect();
168 let n = var_names.len();
169 let name_to_idx: HashMap<&str, usize> = var_names
170 .iter()
171 .enumerate()
172 .map(|(i, n)| (n.as_str(), i))
173 .collect();
174
175 let mut in_degree = vec![0usize; n];
177 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
178
179 for edge in &self.edges {
180 if let (Some(&from_idx), Some(&to_idx)) = (
181 name_to_idx.get(edge.from.as_str()),
182 name_to_idx.get(edge.to.as_str()),
183 ) {
184 adj[from_idx].push(to_idx);
185 in_degree[to_idx] += 1;
186 }
187 }
188
189 let mut queue: VecDeque<usize> = VecDeque::new();
191 for (i, °) in in_degree.iter().enumerate() {
192 if deg == 0 {
193 queue.push_back(i);
194 }
195 }
196
197 let mut order = Vec::with_capacity(n);
198 while let Some(node) = queue.pop_front() {
199 order.push(var_names[node].clone());
200 for &neighbor in &adj[node] {
201 in_degree[neighbor] -= 1;
202 if in_degree[neighbor] == 0 {
203 queue.push_back(neighbor);
204 }
205 }
206 }
207
208 if order.len() != n {
209 Err("Causal graph contains a cycle".to_string())
210 } else {
211 Ok(order)
212 }
213 }
214
215 pub fn fraud_detection_template() -> Self {
217 let mut graph = Self::new();
218 graph.add_variable(
219 CausalVariable::new("transaction_amount", CausalVarType::Continuous)
220 .with_distribution("lognormal")
221 .with_param("mu", 6.0)
222 .with_param("sigma", 1.5),
223 );
224 graph.add_variable(
225 CausalVariable::new("merchant_risk", CausalVarType::Continuous)
226 .with_distribution("beta")
227 .with_param("alpha", 2.0)
228 .with_param("beta_param", 5.0),
229 );
230 graph.add_variable(
231 CausalVariable::new("transaction_frequency", CausalVarType::Count)
232 .with_distribution("normal")
233 .with_param("mean", 10.0)
234 .with_param("std", 3.0),
235 );
236 graph.add_variable(CausalVariable::new(
237 "fraud_probability",
238 CausalVarType::Continuous,
239 ));
240 graph.add_variable(CausalVariable::new("is_fraud", CausalVarType::Binary));
241
242 graph.add_edge(CausalEdge {
243 from: "transaction_amount".into(),
244 to: "fraud_probability".into(),
245 mechanism: CausalMechanism::Linear { coefficient: 0.3 },
246 strength: 1.0,
247 });
248 graph.add_edge(CausalEdge {
249 from: "merchant_risk".into(),
250 to: "fraud_probability".into(),
251 mechanism: CausalMechanism::Linear { coefficient: 0.5 },
252 strength: 1.0,
253 });
254 graph.add_edge(CausalEdge {
255 from: "transaction_frequency".into(),
256 to: "fraud_probability".into(),
257 mechanism: CausalMechanism::Linear { coefficient: 0.2 },
258 strength: 1.0,
259 });
260 graph.add_edge(CausalEdge {
261 from: "fraud_probability".into(),
262 to: "is_fraud".into(),
263 mechanism: CausalMechanism::Threshold { cutoff: 0.7 },
264 strength: 1.0,
265 });
266
267 graph
268 }
269
270 pub fn revenue_cycle_template() -> Self {
272 let mut graph = Self::new();
273 graph.add_variable(
274 CausalVariable::new("order_volume", CausalVarType::Continuous)
275 .with_distribution("normal")
276 .with_param("mean", 100.0)
277 .with_param("std", 30.0),
278 );
279 graph.add_variable(
280 CausalVariable::new("shipment_rate", CausalVarType::Continuous)
281 .with_distribution("beta")
282 .with_param("alpha", 8.0)
283 .with_param("beta_param", 2.0),
284 );
285 graph.add_variable(CausalVariable::new(
286 "invoice_amount",
287 CausalVarType::Continuous,
288 ));
289 graph.add_variable(CausalVariable::new(
290 "collection_rate",
291 CausalVarType::Continuous,
292 ));
293
294 graph.add_edge(CausalEdge {
295 from: "order_volume".into(),
296 to: "shipment_rate".into(),
297 mechanism: CausalMechanism::Logistic {
298 scale: 0.05,
299 midpoint: 50.0,
300 },
301 strength: 1.0,
302 });
303 graph.add_edge(CausalEdge {
304 from: "order_volume".into(),
305 to: "invoice_amount".into(),
306 mechanism: CausalMechanism::Linear { coefficient: 100.0 },
307 strength: 1.0,
308 });
309 graph.add_edge(CausalEdge {
310 from: "shipment_rate".into(),
311 to: "invoice_amount".into(),
312 mechanism: CausalMechanism::Linear { coefficient: 0.5 },
313 strength: 1.0,
314 });
315 graph.add_edge(CausalEdge {
316 from: "invoice_amount".into(),
317 to: "collection_rate".into(),
318 mechanism: CausalMechanism::Logistic {
319 scale: -0.0001,
320 midpoint: 5000.0,
321 },
322 strength: 1.0,
323 });
324
325 graph
326 }
327}
328
329impl Default for CausalGraph {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
#[cfg(test)]
// `unwrap` is fine in tests: a panic simply fails the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a continuous variable with the given name.
    fn continuous(name: &str) -> CausalVariable {
        CausalVariable::new(name, CausalVarType::Continuous)
    }

    /// Builds a unit-strength linear edge with coefficient `1.0`.
    fn identity_edge(from: &str, to: &str) -> CausalEdge {
        CausalEdge {
            from: from.into(),
            to: to.into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        }
    }

    /// Builds a graph of continuous variables joined by identity edges.
    fn graph_of(names: &[&str], edges: &[(&str, &str)]) -> CausalGraph {
        let mut graph = CausalGraph::new();
        for name in names {
            graph.add_variable(continuous(name));
        }
        for (from, to) in edges {
            graph.add_edge(identity_edge(from, to));
        }
        graph
    }

    /// Returns the position of `name` within `order`, panicking if absent.
    fn index_of(order: &[String], name: &str) -> usize {
        order.iter().position(|n| n == name).unwrap()
    }

    /// Asserts that two floats agree to within `1e-10`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    #[test]
    fn test_acyclic_graph_validates() {
        let graph = CausalGraph::fraud_detection_template();
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_cyclic_graph_rejected() {
        let graph = graph_of(&["a", "b"], &[("a", "b"), ("b", "a")]);
        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_self_loop_rejected() {
        let graph = graph_of(&["a"], &[("a", "a")]);
        let message = graph.validate().unwrap_err();
        assert!(message.contains("Self-loop"));
    }

    #[test]
    fn test_topological_order() {
        let graph = CausalGraph::fraud_detection_template();
        let order = graph.topological_order().unwrap();

        // Every parent must come before its child.
        let amount_pos = index_of(&order, "transaction_amount");
        let fraud_prob_pos = index_of(&order, "fraud_probability");
        let is_fraud_pos = index_of(&order, "is_fraud");
        assert!(amount_pos < fraud_prob_pos);
        assert!(fraud_prob_pos < is_fraud_pos);
    }

    #[test]
    fn test_unknown_variable_rejected() {
        let graph = graph_of(&["a"], &[("a", "nonexistent")]);
        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_mechanism_linear() {
        let m = CausalMechanism::Linear { coefficient: 2.0 };
        assert_close(m.apply(3.0), 6.0);
    }

    #[test]
    fn test_mechanism_threshold() {
        let m = CausalMechanism::Threshold { cutoff: 0.5 };
        assert_close(m.apply(0.3), 0.0);
        assert_close(m.apply(0.7), 1.0);
    }

    #[test]
    fn test_mechanism_logistic() {
        let m = CausalMechanism::Logistic {
            scale: 1.0,
            midpoint: 0.0,
        };
        assert_close(m.apply(0.0), 0.5);
        assert!(m.apply(10.0) > 0.99);
        assert!(m.apply(-10.0) < 0.01);
    }

    #[test]
    fn test_mechanism_polynomial() {
        let m = CausalMechanism::Polynomial {
            coefficients: vec![1.0, 2.0, 3.0],
        };
        // 1 + 2*2 + 3*2^2 = 17
        assert_close(m.apply(2.0), 17.0);
    }

    #[test]
    fn test_revenue_cycle_validates() {
        let graph = CausalGraph::revenue_cycle_template();
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_graph_serde_roundtrip() {
        let graph = CausalGraph::fraud_detection_template();
        let json = serde_json::to_string(&graph).unwrap();
        let deserialized: CausalGraph = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.variables.len(), graph.variables.len());
        assert_eq!(deserialized.edges.len(), graph.edges.len());
    }
}