use crate::causal::CausalGraph;
use crate::error::Result;
use std::collections::HashMap;
/// A belief about a scalar quantity, summarised by its mean and variance.
///
/// The helpers treat the belief as (approximately) Gaussian: intervals are
/// `mean ± z·σ`, and fusion uses inverse-variance (precision) weighting.
#[derive(Debug, Clone, PartialEq)]
pub struct BeliefDistribution {
    /// Central estimate.
    pub mean: f64,
    /// Spread of the estimate. [`BeliefDistribution::new`] clamps it to be
    /// non-negative; `0.0` means the value is treated as known exactly.
    pub variance: f64,
}

impl BeliefDistribution {
    /// Creates a belief, clamping a negative (or NaN) `variance` to `0.0`.
    pub fn new(mean: f64, variance: f64) -> Self {
        Self {
            mean,
            // `f64::max` returns the non-NaN operand, so a NaN variance also becomes 0.0.
            variance: variance.max(0.0),
        }
    }

    /// Creates a belief centred on `mean` with a huge variance (`f64::MAX / 2`).
    /// As a result it carries almost no weight when fused with other beliefs.
    pub fn uninformative(mean: f64) -> Self {
        Self::new(mean, f64::MAX / 2.0)
    }

    /// Standard deviation, `sqrt(variance)`.
    pub fn std_dev(&self) -> f64 {
        self.variance.sqrt()
    }

    /// Returns the symmetric interval `(mean - z·σ, mean + z·σ)`.
    pub fn confidence_interval(&self, z: f64) -> (f64, f64) {
        let half_width = z * self.std_dev();
        (self.mean - half_width, self.mean + half_width)
    }

    /// Returns `true` if `value` lies inside [`Self::confidence_interval`]`(z)`.
    /// Both bounds are inclusive.
    pub fn contains(&self, value: f64, z: f64) -> bool {
        let (lo, hi) = self.confidence_interval(z);
        (lo..=hi).contains(&value)
    }

    /// Fuses two independent beliefs by precision (inverse-variance) weighting.
    ///
    /// Edge cases are handled explicitly instead of overflowing to `inf`/`NaN`:
    /// * A belief with zero variance (or a variance so small that its precision
    ///   overflows) is exact. It overrides any non-exact belief, and two exact
    ///   beliefs fuse to their midpoint with variance `0.0`.
    /// * Two beliefs with infinite variance fuse to their midpoint with
    ///   infinite variance, because neither carries information.
    pub fn fuse(&self, other: &BeliefDistribution) -> BeliefDistribution {
        let prec_self = Self::precision(self.variance);
        let prec_other = Self::precision(other.variance);

        match (prec_self.is_infinite(), prec_other.is_infinite()) {
            // Equal (infinite) weights: the weighted mean tends to the midpoint.
            // Halving each term first avoids overflow for huge means.
            (true, true) => Self::new(0.5 * self.mean + 0.5 * other.mean, 0.0),
            // An exact belief overrides any finite-precision one.
            (true, false) => Self::new(self.mean, 0.0),
            (false, true) => Self::new(other.mean, 0.0),
            (false, false) => {
                let scale = prec_self.max(prec_other);
                if scale == 0.0 {
                    // Both variances are infinite: there is nothing to weight by.
                    return Self::new(0.5 * self.mean + 0.5 * other.mean, f64::INFINITY);
                }
                // Normalising by the larger precision keeps both weights in [0, 1].
                // The weighted sum therefore cannot overflow, even for very
                // large precisions.
                let w_self = prec_self / scale;
                let w_other = prec_other / scale;
                let fused_mean =
                    (w_self * self.mean + w_other * other.mean) / (w_self + w_other);
                // If the precision sum overflows to infinity, the variance becomes
                // 0.0. That is the correct limit for near-exact inputs.
                Self::new(fused_mean, 1.0 / (prec_self + prec_other))
            }
        }
    }

    /// Precision (`1 / variance`). A non-positive or NaN variance maps to
    /// infinity. So does a positive variance whose reciprocal overflows.
    fn precision(variance: f64) -> f64 {
        if variance > 0.0 {
            1.0 / variance
        } else {
            f64::INFINITY
        }
    }
}

impl Default for BeliefDistribution {
    /// Standard belief: mean `0.0`, variance `1.0`.
    fn default() -> Self {
        Self::new(0.0, 1.0)
    }
}
/// Pushes [`BeliefDistribution`]s through a causal graph along its edges.
pub struct UncertaintyPropagator;

impl UncertaintyPropagator {
    /// Propagates beliefs through `graph` in topological order.
    ///
    /// Seeding: every node of `graph` starts with its entry from `initial`
    /// when one exists. Otherwise it starts with mean `node.value` (or `0.0`)
    /// and variance `1.0`. Entries in `initial` whose names are not graph
    /// nodes do not appear in the result.
    ///
    /// Propagation: each node with incoming edges gets a new belief that is a
    /// linear combination of its parents' beliefs, with weights taken from the
    /// edge coefficients (a missing coefficient counts as `1.0`):
    /// `mean = Σ c·μ_parent` and `variance = Σ c²·σ²_parent`. Covariance
    /// between parents is not modelled. This replaces any seeded or `initial`
    /// belief the node had. A node whose parent list mentions a name that is
    /// not a graph node keeps its seeded belief.
    ///
    /// # Errors
    /// Returns whatever error `graph.topological_order()` reports.
    pub fn propagate(
        graph: &CausalGraph,
        initial: &HashMap<String, BeliefDistribution>,
    ) -> Result<HashMap<String, BeliefDistribution>> {
        let order = graph.topological_order()?;

        // Seed every graph node with a caller-supplied or default belief.
        let mut beliefs: HashMap<String, BeliefDistribution> = HashMap::new();
        for node in graph.nodes.iter() {
            let seed = match initial.get(&node.name) {
                Some(given) => given.clone(),
                None => BeliefDistribution::new(node.value.unwrap_or(0.0), 1.0),
            };
            beliefs.insert(node.name.clone(), seed);
        }

        // Invert the edge list: child name -> [(parent name, coefficient)].
        let mut parents_of: HashMap<String, Vec<(String, f64)>> = HashMap::new();
        for edge in &graph.edges {
            let weight = edge.coefficient.unwrap_or(1.0);
            parents_of
                .entry(edge.to.clone())
                .or_default()
                .push((edge.from.clone(), weight));
        }

        for name in &order {
            let parents = match parents_of.get(name) {
                Some(list) => list,
                // Root nodes keep their seeded belief.
                None => continue,
            };
            // Skip nodes that reference a parent outside the graph's node set.
            if parents.iter().any(|(p, _)| !beliefs.contains_key(p)) {
                continue;
            }
            let mean: f64 = parents
                .iter()
                .map(|(p, coeff)| beliefs[p].mean * coeff)
                .sum();
            let variance: f64 = parents
                .iter()
                .map(|(p, coeff)| coeff * coeff * beliefs[p].variance)
                .sum();
            beliefs.insert(name.clone(), BeliefDistribution::new(mean, variance));
        }

        Ok(beliefs)
    }
}
/// One scored prediction: a predicted interval `[lower, upper]` and the value
/// that was later realized.
#[derive(Debug, Clone, PartialEq)]
pub struct CalibrationObservation {
    /// Inclusive lower end of the predicted interval.
    pub lower: f64,
    /// Inclusive upper end of the predicted interval.
    pub upper: f64,
    /// The outcome that was actually observed.
    pub realized: f64,
}

impl CalibrationObservation {
    /// `true` when `realized` falls inside `[lower, upper]`, with both ends inclusive.
    /// A NaN in any field, or `lower > upper`, yields `false`.
    pub fn is_hit(&self) -> bool {
        (self.lower..=self.upper).contains(&self.realized)
    }
}

/// Accumulates [`CalibrationObservation`]s and scores how well the predicted
/// intervals covered the realized values.
#[derive(Debug, Clone, Default)]
pub struct CalibrationRecord {
    observations: Vec<CalibrationObservation>,
}

impl CalibrationRecord {
    /// Creates an empty record.
    pub fn new() -> Self {
        Self {
            observations: Vec::new(),
        }
    }

    /// Appends one observation to the record.
    pub fn observe(&mut self, obs: CalibrationObservation) {
        self.observations.push(obs);
    }

    /// Fraction of observations that were hits; `0.0` for an empty record.
    pub fn hit_rate(&self) -> f64 {
        match self.observations.len() {
            0 => 0.0,
            total => {
                let hits = self.observations.iter().filter(|o| o.is_hit()).count();
                hits as f64 / total as f64
            }
        }
    }

    /// Mean over all observations of the squared miss distance divided by the
    /// interval width. Hits contribute `0.0`. Returns `0.0` for an empty record.
    pub fn calibration_loss(&self) -> f64 {
        match self.observations.len() {
            0 => 0.0,
            total => {
                let summed: f64 = self.observations.iter().map(Self::miss_penalty).sum();
                summed / total as f64
            }
        }
    }

    /// Penalty for a single observation: `0.0` for a hit.
    ///
    /// Otherwise it is `(d / width)²`. Here `d` is how far `realized` lies
    /// below `lower` (or, failing that, above `upper`), and `width` is
    /// `|upper - lower|`.
    fn miss_penalty(obs: &CalibrationObservation) -> f64 {
        if obs.is_hit() {
            return 0.0;
        }
        // Floor the width so zero-width (point) predictions don't divide by zero.
        let width = (obs.upper - obs.lower).abs().max(1e-12);
        let miss = if obs.realized < obs.lower {
            obs.lower - obs.realized
        } else {
            obs.realized - obs.upper
        };
        (miss / width).powi(2)
    }

    /// Number of recorded observations.
    pub fn len(&self) -> usize {
        self.observations.len()
    }

    /// `true` if no observations have been recorded.
    pub fn is_empty(&self) -> bool {
        self.observations.is_empty()
    }
}