use crate::error::{LmmError, Result};
use crate::traits::Causal;
use std::collections::{HashMap, HashSet, VecDeque};
/// A single variable in a causal graph.
///
/// `PartialEq` is derived so nodes can be compared directly (for example in
/// `assert_eq!`); `Eq` is not available because the value is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct CausalNode {
    /// Name identifying the node; edges refer to nodes by this name.
    pub name: String,
    /// Current value of the variable, or `None` when no value is known yet.
    pub value: Option<f64>,
}
/// A directed edge `from -> to` between two named nodes of a causal graph.
///
/// `PartialEq` is derived so edges can be compared directly (for example in
/// `assert_eq!`); `Eq` is not available because the coefficient is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct CausalEdge {
    /// Name of the parent (cause) node.
    pub from: String,
    /// Name of the child (effect) node.
    pub to: String,
    /// Linear weight applied to the parent's value during propagation.
    /// `None` is treated as `1.0` by `CausalGraph::forward_pass`.
    pub coefficient: Option<f64>,
}
/// A directed graph of named variables connected by weighted edges.
///
/// The `Default` value is an empty graph. Both fields are public, so callers
/// can bypass the checks done by `add_node` / `add_edge` (name uniqueness and
/// endpoint existence) by mutating them directly.
#[derive(Debug, Clone, Default)]
pub struct CausalGraph {
/// Nodes, in the order they were added.
pub nodes: Vec<CausalNode>,
/// Directed edges, in the order they were added.
pub edges: Vec<CausalEdge>,
}
impl CausalGraph {
    /// Creates an empty graph with no nodes and no edges.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a node called `name` holding the optional `value`.
    ///
    /// If a node with this name already exists the call does nothing and the
    /// existing node keeps its current value.
    pub fn add_node(&mut self, name: &str, value: Option<f64>) {
        let already_present = self.nodes.iter().any(|n| n.name == name);
        if !already_present {
            self.nodes.push(CausalNode {
                name: name.to_string(),
                value,
            });
        }
    }

    /// Adds a directed edge `from -> to` with an optional linear `coefficient`.
    ///
    /// No de-duplication is performed: adding the same edge twice stores it
    /// twice.
    ///
    /// # Errors
    ///
    /// Returns [`LmmError::CausalError`] if either endpoint has not been added
    /// as a node. `from` is checked before `to`.
    pub fn add_edge(&mut self, from: &str, to: &str, coefficient: Option<f64>) -> Result<()> {
        for endpoint in &[from, to] {
            if !self.nodes.iter().any(|n| n.name == *endpoint) {
                return Err(LmmError::CausalError(format!(
                    "Node '{}' not found",
                    endpoint
                )));
            }
        }
        self.edges.push(CausalEdge {
            from: from.to_string(),
            to: to.to_string(),
            coefficient,
        });
        Ok(())
    }

    /// Returns the value of the node called `name`.
    ///
    /// Yields `None` both when the node does not exist and when it exists but
    /// currently has no value.
    pub fn get_value(&self, name: &str) -> Option<f64> {
        self.nodes
            .iter()
            .find(|n| n.name == name)
            .and_then(|n| n.value)
    }

    /// Computes a topological ordering of the node names (parents before
    /// children) using Kahn's algorithm.
    ///
    /// Root nodes are visited in insertion order, and children are released in
    /// the order their edges were added.
    ///
    /// # Errors
    ///
    /// Returns [`LmmError::CausalError`] when the number of ordered names does
    /// not match the number of nodes. That happens when the graph contains a
    /// cycle; edges whose endpoints are not registered nodes (only possible by
    /// editing the public fields directly) can trip the same check.
    pub fn topological_order(&self) -> Result<Vec<String>> {
        let mut in_degree: HashMap<&str, usize> = HashMap::with_capacity(self.nodes.len());
        let mut adj: HashMap<&str, Vec<&str>> = HashMap::with_capacity(self.nodes.len());
        for n in &self.nodes {
            in_degree.insert(n.name.as_str(), 0);
            adj.insert(n.name.as_str(), Vec::new());
        }
        for edge in &self.edges {
            *in_degree.entry(edge.to.as_str()).or_insert(0) += 1;
            adj.entry(edge.from.as_str())
                .or_default()
                .push(edge.to.as_str());
        }

        // Seed the work queue with every node that has no incoming edge.
        let mut queue: VecDeque<&str> = self
            .nodes
            .iter()
            .map(|n| n.name.as_str())
            .filter(|name| in_degree.get(name).copied().unwrap_or(0) == 0)
            .collect();

        let mut order: Vec<String> = Vec::with_capacity(self.nodes.len());
        while let Some(node) = queue.pop_front() {
            order.push(node.to_string());
            if let Some(children) = adj.get(node) {
                for &child in children {
                    let deg = in_degree.entry(child).or_insert(0);
                    *deg = deg.saturating_sub(1);
                    // A child becomes ready once its last parent is emitted.
                    if *deg == 0 {
                        queue.push_back(child);
                    }
                }
            }
        }

        if order.len() != self.nodes.len() {
            return Err(LmmError::CausalError(
                "Cycle detected in causal graph; topological ordering impossible".into(),
            ));
        }
        Ok(order)
    }

    /// Propagates values through the graph in topological order.
    ///
    /// A node that has at least one incoming edge and whose parents all have a
    /// value is overwritten with the sum of `parent_value * coefficient` over
    /// its incoming edges (a missing coefficient counts as `1.0`). Nodes
    /// without parents, or with any parent lacking a value, keep whatever
    /// value they already had.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`CausalGraph::topological_order`].
    pub fn forward_pass(&mut self) -> Result<()> {
        let order = self.topological_order()?;

        // All lookups borrow names from `self`; the scope ends those borrows
        // before the nodes are mutated below.
        let updated: Vec<Option<f64>> = {
            // child name -> [(parent name, weight)] in edge insertion order
            let mut parent_map: HashMap<&str, Vec<(&str, f64)>> = HashMap::new();
            for edge in &self.edges {
                parent_map
                    .entry(edge.to.as_str())
                    .or_default()
                    .push((edge.from.as_str(), edge.coefficient.unwrap_or(1.0)));
            }

            let mut values: HashMap<&str, Option<f64>> = self
                .nodes
                .iter()
                .map(|n| (n.name.as_str(), n.value))
                .collect();

            for name in &order {
                if let Some(parents) = parent_map.get(name.as_str()) {
                    // Summing `Option<f64>` yields `None` as soon as one
                    // parent has no value, so the node is left untouched.
                    let weighted: Option<f64> = parents
                        .iter()
                        .map(|&(parent, coeff)| {
                            values.get(parent).and_then(|v| *v).map(|v| v * coeff)
                        })
                        .sum();
                    if let Some(total) = weighted {
                        values.insert(name.as_str(), Some(total));
                    }
                }
            }

            self.nodes
                .iter()
                .map(|n| values.get(n.name.as_str()).and_then(|v| *v))
                .collect()
        };

        for (node, value) in self.nodes.iter_mut().zip(updated) {
            node.value = value;
        }
        Ok(())
    }

    /// Answers "what would `query_node` be if `intervention_node` were forced
    /// to `intervention_value`?" without modifying this graph.
    ///
    /// Works on a clone: applies `intervene` to the clone, runs
    /// [`CausalGraph::forward_pass`], then reads the query node.
    ///
    /// # Errors
    ///
    /// Returns [`LmmError::CausalError`] if the intervention fails, if the
    /// graph cannot be ordered, or if the query node has no value afterwards
    /// (which includes the case where it does not exist).
    pub fn counterfactual(
        &self,
        intervention_node: &str,
        intervention_value: f64,
        query_node: &str,
    ) -> Result<f64> {
        let mut world = self.clone();
        world.intervene(intervention_node, intervention_value)?;
        world.forward_pass()?;
        match world.get_value(query_node) {
            Some(value) => Ok(value),
            None => Err(LmmError::CausalError(format!(
                "Query node '{}' has no value after propagation",
                query_node
            ))),
        }
    }

    /// Returns the names of all direct parents of `node`, one entry per
    /// incoming edge, in edge insertion order.
    pub fn parents(&self, node: &str) -> Vec<String> {
        let mut names = Vec::new();
        for edge in &self.edges {
            if edge.to == node {
                names.push(edge.from.clone());
            }
        }
        names
    }

    /// Returns the names of all direct children of `node`, one entry per
    /// outgoing edge, in edge insertion order.
    pub fn children(&self, node: &str) -> Vec<String> {
        let mut names = Vec::new();
        for edge in &self.edges {
            if edge.from == node {
                names.push(edge.to.clone());
            }
        }
        names
    }

    /// Returns `true` when [`CausalGraph::topological_order`] fails.
    pub fn has_cycle(&self) -> bool {
        self.topological_order().is_err()
    }

    /// Returns the Markov blanket of `node`: its parents, its children, and
    /// the other parents of its children.
    ///
    /// The node itself is never part of the result, even when the graph
    /// contains a self-loop edge `node -> node`.
    pub fn markov_blanket(&self, node: &str) -> HashSet<String> {
        let children: HashSet<&str> = self
            .edges
            .iter()
            .filter(|e| e.from == node)
            .map(|e| e.to.as_str())
            .collect();

        let mut blanket: HashSet<String> = HashSet::new();
        for edge in &self.edges {
            if edge.to == node {
                // `edge.from` is a parent of `node`.
                blanket.insert(edge.from.clone());
            }
            if children.contains(edge.to.as_str()) {
                // `edge.to` is a child; `edge.from` is `node` or a co-parent.
                blanket.insert(edge.to.clone());
                blanket.insert(edge.from.clone());
            }
        }
        // Drops `node` whether it came in as a child's parent or via a self-loop.
        blanket.remove(node);
        blanket
    }
}
impl Causal for CausalGraph {
    /// Forces the node `var` to `value` and detaches it from its causes.
    ///
    /// The first node named `var` gets `Some(value)`, and every edge pointing
    /// into `var` is removed so later propagation cannot overwrite it.
    ///
    /// # Errors
    ///
    /// Returns [`LmmError::CausalError`] if no node is named `var`; the graph
    /// is left unchanged in that case.
    fn intervene(&mut self, var: &str, value: f64) -> Result<()> {
        match self.nodes.iter_mut().find(|candidate| candidate.name == var) {
            None => Err(LmmError::CausalError(format!(
                "Intervention target '{}' not found in graph",
                var
            ))),
            Some(target) => {
                target.value = Some(value);
                // Keep only edges that do not end at the intervened node.
                self.edges.retain(|edge| edge.to != var);
                Ok(())
            }
        }
    }
}
/// Builds a linear chain `x0 -> x1 -> ... -> x{n-1}` of `n` nodes.
///
/// Only `x0` starts with a value (`1.0`); every other node starts empty.
/// Each of the `n - 1` edges carries the same `coefficient`. With `n == 0`
/// the result is an empty graph, and with `n == 1` a single node and no edges.
pub fn build_chain(n: usize, coefficient: f64) -> CausalGraph {
    let labels: Vec<String> = (0..n).map(|idx| format!("x{}", idx)).collect();

    let mut chain = CausalGraph::new();
    for (idx, label) in labels.iter().enumerate() {
        let seed = match idx {
            0 => Some(1.0),
            _ => None,
        };
        chain.add_node(label, seed);
    }
    for pair in labels.windows(2) {
        // Both endpoints were registered above, so the result is ignored.
        let _ = chain.add_edge(&pair[0], &pair[1], Some(coefficient));
    }
    chain
}
pub fn ancestors(g: &CausalGraph, node: &str) -> HashSet<String> {
let mut result = HashSet::new();
let mut queue: VecDeque<&str> = VecDeque::new();
queue.push_back(node);
while let Some(current) = queue.pop_front() {
for edge in &g.edges {
if edge.to == current && !result.contains(&edge.from) {
result.insert(edge.from.clone());
queue.push_back(&edge.from);
}
}
}
result
}