// datasynth_core/causal/graph.rs

1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet, VecDeque};
3
4/// Type of a causal variable.
5#[derive(Debug, Clone, Default, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum CausalVarType {
8    #[default]
9    Continuous,
10    Categorical,
11    Count,
12    Binary,
13}
14
15/// A variable in the causal graph.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CausalVariable {
18    pub name: String,
19    #[serde(default)]
20    pub var_type: CausalVarType,
21    /// Base distribution for exogenous noise (e.g., "normal", "lognormal", "beta").
22    #[serde(default)]
23    pub distribution: Option<String>,
24    /// Distribution parameters.
25    #[serde(default)]
26    pub params: HashMap<String, f64>,
27}
28
29impl CausalVariable {
30    pub fn new(name: impl Into<String>, var_type: CausalVarType) -> Self {
31        Self {
32            name: name.into(),
33            var_type,
34            distribution: None,
35            params: HashMap::new(),
36        }
37    }
38
39    pub fn with_distribution(mut self, dist: impl Into<String>) -> Self {
40        self.distribution = Some(dist.into());
41        self
42    }
43
44    pub fn with_param(mut self, key: impl Into<String>, value: f64) -> Self {
45        self.params.insert(key.into(), value);
46        self
47    }
48}
49
50/// Causal mechanism defining how a parent influences a child.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum CausalMechanism {
54    /// Linear: child += coefficient * parent
55    Linear { coefficient: f64 },
56    /// Threshold: child = 1 if parent > cutoff else 0
57    Threshold { cutoff: f64 },
58    /// Polynomial: child += sum(coeff[i] * parent^i)
59    Polynomial { coefficients: Vec<f64> },
60    /// Logistic: child += 1 / (1 + exp(-scale * (parent - midpoint)))
61    Logistic { scale: f64, midpoint: f64 },
62}
63
64impl CausalMechanism {
65    /// Apply this mechanism to compute the contribution from a parent value.
66    pub fn apply(&self, parent_value: f64) -> f64 {
67        match self {
68            CausalMechanism::Linear { coefficient } => coefficient * parent_value,
69            CausalMechanism::Threshold { cutoff } => {
70                if parent_value > *cutoff {
71                    1.0
72                } else {
73                    0.0
74                }
75            }
76            CausalMechanism::Polynomial { coefficients } => coefficients
77                .iter()
78                .enumerate()
79                .map(|(i, c)| c * parent_value.powi(i as i32))
80                .sum(),
81            CausalMechanism::Logistic { scale, midpoint } => {
82                1.0 / (1.0 + (-scale * (parent_value - midpoint)).exp())
83            }
84        }
85    }
86}
87
88/// A directed edge in the causal graph.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct CausalEdge {
91    pub from: String,
92    pub to: String,
93    pub mechanism: CausalMechanism,
94    #[serde(default = "default_strength")]
95    pub strength: f64,
96}
97
/// Serde default for `CausalEdge::strength`; always returns `1.0`.
fn default_strength() -> f64 {
    1_f64
}
101
102/// A causal directed acyclic graph (DAG).
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct CausalGraph {
105    pub variables: Vec<CausalVariable>,
106    pub edges: Vec<CausalEdge>,
107}
108
109impl CausalGraph {
110    pub fn new() -> Self {
111        Self {
112            variables: Vec::new(),
113            edges: Vec::new(),
114        }
115    }
116
117    pub fn add_variable(&mut self, var: CausalVariable) {
118        self.variables.push(var);
119    }
120
121    pub fn add_edge(&mut self, edge: CausalEdge) {
122        self.edges.push(edge);
123    }
124
125    /// Get variable names.
126    pub fn variable_names(&self) -> Vec<&str> {
127        self.variables.iter().map(|v| v.name.as_str()).collect()
128    }
129
130    /// Get variable by name.
131    pub fn get_variable(&self, name: &str) -> Option<&CausalVariable> {
132        self.variables.iter().find(|v| v.name == name)
133    }
134
135    /// Get all edges pointing TO a given variable (its parents).
136    pub fn parent_edges(&self, variable: &str) -> Vec<&CausalEdge> {
137        self.edges.iter().filter(|e| e.to == variable).collect()
138    }
139
140    /// Validate the graph: check acyclicity, no self-loops, all referenced vars exist.
141    pub fn validate(&self) -> Result<(), String> {
142        let var_names: HashSet<&str> = self.variables.iter().map(|v| v.name.as_str()).collect();
143
144        // Check for self-loops
145        for edge in &self.edges {
146            if edge.from == edge.to {
147                return Err(format!("Self-loop detected on variable '{}'", edge.from));
148            }
149        }
150
151        // Check all referenced variables exist
152        for edge in &self.edges {
153            if !var_names.contains(edge.from.as_str()) {
154                return Err(format!("Edge references unknown variable '{}'", edge.from));
155            }
156            if !var_names.contains(edge.to.as_str()) {
157                return Err(format!("Edge references unknown variable '{}'", edge.to));
158            }
159        }
160
161        // Check acyclicity via topological sort
162        self.topological_order().map(|_| ())
163    }
164
165    /// Compute topological ordering of variables. Returns error if cyclic.
166    pub fn topological_order(&self) -> Result<Vec<String>, String> {
167        let var_names: Vec<String> = self.variables.iter().map(|v| v.name.clone()).collect();
168        let n = var_names.len();
169        let name_to_idx: HashMap<&str, usize> = var_names
170            .iter()
171            .enumerate()
172            .map(|(i, n)| (n.as_str(), i))
173            .collect();
174
175        // Build adjacency and in-degree
176        let mut in_degree = vec![0usize; n];
177        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
178
179        for edge in &self.edges {
180            if let (Some(&from_idx), Some(&to_idx)) = (
181                name_to_idx.get(edge.from.as_str()),
182                name_to_idx.get(edge.to.as_str()),
183            ) {
184                adj[from_idx].push(to_idx);
185                in_degree[to_idx] += 1;
186            }
187        }
188
189        // Kahn's algorithm
190        let mut queue: VecDeque<usize> = VecDeque::new();
191        for (i, &deg) in in_degree.iter().enumerate() {
192            if deg == 0 {
193                queue.push_back(i);
194            }
195        }
196
197        let mut order = Vec::with_capacity(n);
198        while let Some(node) = queue.pop_front() {
199            order.push(var_names[node].clone());
200            for &neighbor in &adj[node] {
201                in_degree[neighbor] -= 1;
202                if in_degree[neighbor] == 0 {
203                    queue.push_back(neighbor);
204                }
205            }
206        }
207
208        if order.len() != n {
209            Err("Causal graph contains a cycle".to_string())
210        } else {
211            Ok(order)
212        }
213    }
214
215    /// Built-in fraud detection SCM template.
216    pub fn fraud_detection_template() -> Self {
217        let mut graph = Self::new();
218        graph.add_variable(
219            CausalVariable::new("transaction_amount", CausalVarType::Continuous)
220                .with_distribution("lognormal")
221                .with_param("mu", 6.0)
222                .with_param("sigma", 1.5),
223        );
224        graph.add_variable(
225            CausalVariable::new("merchant_risk", CausalVarType::Continuous)
226                .with_distribution("beta")
227                .with_param("alpha", 2.0)
228                .with_param("beta_param", 5.0),
229        );
230        graph.add_variable(
231            CausalVariable::new("transaction_frequency", CausalVarType::Count)
232                .with_distribution("normal")
233                .with_param("mean", 10.0)
234                .with_param("std", 3.0),
235        );
236        graph.add_variable(CausalVariable::new(
237            "fraud_probability",
238            CausalVarType::Continuous,
239        ));
240        graph.add_variable(CausalVariable::new("is_fraud", CausalVarType::Binary));
241
242        graph.add_edge(CausalEdge {
243            from: "transaction_amount".into(),
244            to: "fraud_probability".into(),
245            mechanism: CausalMechanism::Linear { coefficient: 0.3 },
246            strength: 1.0,
247        });
248        graph.add_edge(CausalEdge {
249            from: "merchant_risk".into(),
250            to: "fraud_probability".into(),
251            mechanism: CausalMechanism::Linear { coefficient: 0.5 },
252            strength: 1.0,
253        });
254        graph.add_edge(CausalEdge {
255            from: "transaction_frequency".into(),
256            to: "fraud_probability".into(),
257            mechanism: CausalMechanism::Linear { coefficient: 0.2 },
258            strength: 1.0,
259        });
260        graph.add_edge(CausalEdge {
261            from: "fraud_probability".into(),
262            to: "is_fraud".into(),
263            mechanism: CausalMechanism::Threshold { cutoff: 0.7 },
264            strength: 1.0,
265        });
266
267        graph
268    }
269
270    /// Built-in revenue cycle SCM template.
271    pub fn revenue_cycle_template() -> Self {
272        let mut graph = Self::new();
273        graph.add_variable(
274            CausalVariable::new("order_volume", CausalVarType::Continuous)
275                .with_distribution("normal")
276                .with_param("mean", 100.0)
277                .with_param("std", 30.0),
278        );
279        graph.add_variable(
280            CausalVariable::new("shipment_rate", CausalVarType::Continuous)
281                .with_distribution("beta")
282                .with_param("alpha", 8.0)
283                .with_param("beta_param", 2.0),
284        );
285        graph.add_variable(CausalVariable::new(
286            "invoice_amount",
287            CausalVarType::Continuous,
288        ));
289        graph.add_variable(CausalVariable::new(
290            "collection_rate",
291            CausalVarType::Continuous,
292        ));
293
294        graph.add_edge(CausalEdge {
295            from: "order_volume".into(),
296            to: "shipment_rate".into(),
297            mechanism: CausalMechanism::Logistic {
298                scale: 0.05,
299                midpoint: 50.0,
300            },
301            strength: 1.0,
302        });
303        graph.add_edge(CausalEdge {
304            from: "order_volume".into(),
305            to: "invoice_amount".into(),
306            mechanism: CausalMechanism::Linear { coefficient: 100.0 },
307            strength: 1.0,
308        });
309        graph.add_edge(CausalEdge {
310            from: "shipment_rate".into(),
311            to: "invoice_amount".into(),
312            mechanism: CausalMechanism::Linear { coefficient: 0.5 },
313            strength: 1.0,
314        });
315        graph.add_edge(CausalEdge {
316            from: "invoice_amount".into(),
317            to: "collection_rate".into(),
318            mechanism: CausalMechanism::Logistic {
319                scale: -0.0001,
320                midpoint: 5000.0,
321            },
322            strength: 1.0,
323        });
324
325        graph
326    }
327}
328
329impl Default for CausalGraph {
330    fn default() -> Self {
331        Self::new()
332    }
333}
334
#[cfg(test)]
// `unwrap` is acceptable in tests: a panic is exactly how a failing test reports.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Linear edge with unit coefficient and strength, for small ad-hoc graphs.
    fn linear_edge(from: &str, to: &str) -> CausalEdge {
        CausalEdge {
            from: from.into(),
            to: to.into(),
            mechanism: CausalMechanism::Linear { coefficient: 1.0 },
            strength: 1.0,
        }
    }

    /// Graph containing the given continuous variables and no edges.
    fn graph_with(names: &[&str]) -> CausalGraph {
        let mut graph = CausalGraph::new();
        for &name in names {
            graph.add_variable(CausalVariable::new(name, CausalVarType::Continuous));
        }
        graph
    }

    /// Index of `name` in `order`; panics if it is absent.
    fn position(order: &[String], name: &str) -> usize {
        order.iter().position(|n| n == name).unwrap()
    }

    #[test]
    fn test_acyclic_graph_validates() {
        let graph = CausalGraph::fraud_detection_template();
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_empty_graph_validates() {
        let graph = CausalGraph::default();
        assert!(graph.validate().is_ok());
        assert!(graph.topological_order().unwrap().is_empty());
    }

    #[test]
    fn test_cyclic_graph_rejected() {
        let mut graph = graph_with(&["a", "b"]);
        graph.add_edge(linear_edge("a", "b"));
        graph.add_edge(linear_edge("b", "a"));
        let err = graph.validate().unwrap_err();
        assert!(err.contains("cycle"));
    }

    #[test]
    fn test_self_loop_rejected() {
        let mut graph = graph_with(&["a"]);
        graph.add_edge(linear_edge("a", "a"));
        let result = graph.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Self-loop"));
    }

    #[test]
    fn test_topological_order() {
        let graph = CausalGraph::fraud_detection_template();
        let order = graph.topological_order().unwrap();
        assert_eq!(order.len(), graph.variables.len());
        // Every root must precede fraud_probability, which precedes is_fraud.
        let fraud_prob_pos = position(&order, "fraud_probability");
        for root in ["transaction_amount", "merchant_risk", "transaction_frequency"] {
            assert!(position(&order, root) < fraud_prob_pos);
        }
        assert!(fraud_prob_pos < position(&order, "is_fraud"));
    }

    #[test]
    fn test_unknown_variable_rejected() {
        let mut graph = graph_with(&["a"]);
        graph.add_edge(linear_edge("a", "nonexistent"));
        assert!(graph.validate().unwrap_err().contains("nonexistent"));

        // The source endpoint is checked as well as the target.
        let mut graph = graph_with(&["a"]);
        graph.add_edge(linear_edge("ghost", "a"));
        assert!(graph.validate().unwrap_err().contains("ghost"));
    }

    #[test]
    fn test_parent_edges_and_lookup() {
        let graph = CausalGraph::fraud_detection_template();
        let parents: Vec<&str> = graph
            .parent_edges("fraud_probability")
            .into_iter()
            .map(|e| e.from.as_str())
            .collect();
        assert_eq!(
            parents,
            ["transaction_amount", "merchant_risk", "transaction_frequency"]
        );
        assert!(graph.parent_edges("transaction_amount").is_empty());
        assert!(matches!(
            graph.get_variable("is_fraud").unwrap().var_type,
            CausalVarType::Binary
        ));
        assert!(graph.get_variable("missing").is_none());
    }

    #[test]
    fn test_mechanism_linear() {
        let m = CausalMechanism::Linear { coefficient: 2.0 };
        assert!((m.apply(3.0) - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_mechanism_threshold() {
        let m = CausalMechanism::Threshold { cutoff: 0.5 };
        assert!((m.apply(0.3) - 0.0).abs() < 1e-10);
        assert!((m.apply(0.7) - 1.0).abs() < 1e-10);
        // The comparison is strict: a value exactly at the cutoff maps to 0.
        assert!(m.apply(0.5).abs() < 1e-10);
    }

    #[test]
    fn test_mechanism_logistic() {
        let m = CausalMechanism::Logistic {
            scale: 1.0,
            midpoint: 0.0,
        };
        assert!((m.apply(0.0) - 0.5).abs() < 1e-10);
        assert!(m.apply(10.0) > 0.99);
        assert!(m.apply(-10.0) < 0.01);
    }

    #[test]
    fn test_mechanism_polynomial() {
        let m = CausalMechanism::Polynomial {
            coefficients: vec![1.0, 2.0, 3.0],
        };
        // 1 + 2*x + 3*x^2 at x=2 = 1 + 4 + 12 = 17
        assert!((m.apply(2.0) - 17.0).abs() < 1e-10);
    }

    #[test]
    fn test_mechanism_polynomial_empty_is_zero() {
        let m = CausalMechanism::Polynomial {
            coefficients: Vec::new(),
        };
        assert!(m.apply(3.0).abs() < 1e-10);
    }

    #[test]
    fn test_revenue_cycle_validates() {
        let graph = CausalGraph::revenue_cycle_template();
        assert!(graph.validate().is_ok());
        let order = graph.topological_order().unwrap();
        let chain = [
            "order_volume",
            "shipment_rate",
            "invoice_amount",
            "collection_rate",
        ];
        for pair in chain.windows(2) {
            assert!(position(&order, pair[0]) < position(&order, pair[1]));
        }
    }

    #[test]
    fn test_graph_serde_roundtrip() {
        let graph = CausalGraph::fraud_detection_template();
        let json = serde_json::to_string(&graph).unwrap();
        let deserialized: CausalGraph = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.variable_names(), graph.variable_names());
        assert_eq!(deserialized.edges.len(), graph.edges.len());
        assert_eq!(
            deserialized.topological_order().unwrap(),
            graph.topological_order().unwrap()
        );
        for (after, before) in deserialized.edges.iter().zip(&graph.edges) {
            assert_eq!(after.from, before.from);
            assert_eq!(after.to, before.to);
            assert!((after.strength - before.strength).abs() < 1e-12);
            // Mechanisms survived the roundtrip if they evaluate identically.
            for x in [-1.0, 0.0, 0.5, 2.0] {
                assert!((after.mechanism.apply(x) - before.mechanism.apply(x)).abs() < 1e-9);
            }
        }
        let amount = deserialized.get_variable("transaction_amount").unwrap();
        assert_eq!(amount.distribution.as_deref(), Some("lognormal"));
        assert!((amount.params.get("mu").copied().unwrap() - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_minimal_json_uses_defaults() {
        let var: CausalVariable = serde_json::from_str(r#"{"name":"x"}"#).unwrap();
        assert_eq!(var.name, "x");
        assert!(matches!(var.var_type, CausalVarType::Continuous));
        assert!(var.distribution.is_none());
        assert!(var.params.is_empty());

        let edge: CausalEdge = serde_json::from_str(
            r#"{"from":"x","to":"y","mechanism":{"type":"linear","coefficient":2.0}}"#,
        )
        .unwrap();
        assert!((edge.strength - 1.0).abs() < 1e-12);
        assert!((edge.mechanism.apply(3.0) - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_serde_names_are_snake_case() {
        assert_eq!(
            serde_json::to_string(&CausalVarType::Count).unwrap(),
            r#""count""#
        );
        let binary: CausalVarType = serde_json::from_str(r#""binary""#).unwrap();
        assert!(matches!(binary, CausalVarType::Binary));

        let m: CausalMechanism =
            serde_json::from_str(r#"{"type":"logistic","scale":1.0,"midpoint":0.0}"#).unwrap();
        assert!((m.apply(0.0) - 0.5).abs() < 1e-10);
    }
}