use anyhow::{anyhow, Result};
use std::collections::{HashMap, HashSet};
/// A basic probability assignment ("mass function") over subsets of hypotheses.
///
/// Each subset is stored under a canonical key — its members sorted and
/// de-duplicated — so the same set always maps to the same entry no matter how
/// the caller ordered or repeated its members.
#[derive(Debug, Clone)]
pub struct MassFunction {
    /// Mass assigned to each subset, keyed by the subset's canonical form.
    masses: HashMap<Vec<String>, f64>,
}

impl MassFunction {
    /// Creates a mass function with no focal elements (total mass 0).
    pub fn new() -> Self {
        Self {
            masses: HashMap::new(),
        }
    }

    /// Adds `mass` to the subset `hypotheses`, accumulating with any mass the
    /// subset already carries.
    ///
    /// Member order and repetition are ignored: `["B", "A", "A"]` and
    /// `["A", "B"]` denote the same subset.
    ///
    /// Only the individual `mass` is range-checked. Neither the accumulated
    /// mass of the subset nor the overall total is bounded here.
    ///
    /// # Errors
    /// Returns an error if `mass` is outside `[0, 1]` (NaN included) or if
    /// `hypotheses` is empty.
    pub fn assign_mass(&mut self, mut hypotheses: Vec<String>, mass: f64) -> Result<()> {
        if !(0.0..=1.0).contains(&mass) {
            return Err(anyhow!("Mass must be between 0 and 1, got {}", mass));
        }
        Self::canonicalize(&mut hypotheses);
        if hypotheses.is_empty() {
            return Err(anyhow!("Cannot assign mass to empty set"));
        }
        *self.masses.entry(hypotheses).or_insert(0.0) += mass;
        Ok(())
    }

    /// Returns the mass assigned to exactly the subset `hypotheses`, or `0.0`
    /// if it has none.
    ///
    /// As in [`MassFunction::assign_mass`], member order and repetition are ignored.
    pub fn get_mass(&self, hypotheses: &[String]) -> f64 {
        let mut key = hypotheses.to_vec();
        Self::canonicalize(&mut key);
        self.masses.get(&key).copied().unwrap_or(0.0)
    }

    /// Returns the subsets whose mass exceeds `1e-10`. Smaller masses are
    /// treated as numerical noise.
    ///
    /// The order of the returned subsets follows the internal map's iteration
    /// order and is unspecified.
    pub fn focal_elements(&self) -> Vec<&Vec<String>> {
        self.masses
            .iter()
            .filter(|(_, &mass)| mass > 1e-10)
            .map(|(elem, _)| elem)
            .collect()
    }

    /// Returns the sum of every assigned mass, including masses too small to
    /// count as focal.
    pub fn total_mass(&self) -> f64 {
        self.masses.values().sum()
    }

    /// Rescales all masses so that they sum to 1.
    ///
    /// # Errors
    /// Returns an error, leaving the masses untouched, if the total mass is
    /// below `1e-10`.
    pub fn normalize(&mut self) -> Result<()> {
        let total = self.total_mass();
        if total < 1e-10 {
            return Err(anyhow!(
                "Cannot normalize mass function with total mass near zero"
            ));
        }
        for mass in self.masses.values_mut() {
            *mass /= total;
        }
        Ok(())
    }

    /// Puts a subset into canonical form: members sorted, duplicates removed.
    fn canonicalize(hypotheses: &mut Vec<String>) {
        hypotheses.sort_unstable();
        hypotheses.dedup();
    }
}

impl Default for MassFunction {
    fn default() -> Self {
        Self::new()
    }
}
#[derive(Debug, Clone)]
pub struct DempsterShaferSystem {
frame: Vec<String>,
combined_mass: MassFunction,
}
impl DempsterShaferSystem {
pub fn new(frame: Vec<String>) -> Self {
let mut combined_mass = MassFunction::new();
let _ = combined_mass.assign_mass(frame.clone(), 1.0);
Self {
frame,
combined_mass,
}
}
pub fn add_evidence(&mut self, evidence: MassFunction) -> Result<()> {
let total_mass = evidence.total_mass();
if (total_mass - 1.0).abs() > 1e-6 {
return Err(anyhow!("Evidence mass must sum to 1.0, got {}", total_mass));
}
self.combined_mass = self.dempster_combine(&self.combined_mass, &evidence)?;
Ok(())
}
fn dempster_combine(&self, m1: &MassFunction, m2: &MassFunction) -> Result<MassFunction> {
let mut combined = MassFunction::new();
let mut conflict = 0.0;
for focal1 in m1.focal_elements() {
for focal2 in m2.focal_elements() {
let mass1 = m1.get_mass(focal1);
let mass2 = m2.get_mass(focal2);
let intersection = self.intersect(focal1, focal2);
if intersection.is_empty() {
conflict += mass1 * mass2;
} else {
combined
.assign_mass(intersection, mass1 * mass2)
.map_err(|e| anyhow!("Failed to combine masses: {}", e))?;
}
}
}
if (conflict - 1.0).abs() < 1e-10 {
return Err(anyhow!(
"Total conflict: evidence is completely contradictory"
));
}
for mass in combined.masses.values_mut() {
*mass /= 1.0 - conflict;
}
Ok(combined)
}
fn intersect(&self, set1: &[String], set2: &[String]) -> Vec<String> {
let s1: HashSet<_> = set1.iter().collect();
let s2: HashSet<_> = set2.iter().collect();
let mut intersection: Vec<_> = s1.intersection(&s2).map(|&s| s.clone()).collect();
intersection.sort();
intersection
}
pub fn belief(&self, hypotheses: &[String]) -> Result<f64> {
self.validate_hypotheses(hypotheses)?;
let target_set: HashSet<_> = hypotheses.iter().collect();
let mut belief = 0.0;
for focal in self.combined_mass.focal_elements() {
let focal_set: HashSet<_> = focal.iter().collect();
if focal_set.is_subset(&target_set) {
belief += self.combined_mass.get_mass(focal);
}
}
Ok(belief)
}
pub fn plausibility(&self, hypotheses: &[String]) -> Result<f64> {
self.validate_hypotheses(hypotheses)?;
let target_set: HashSet<_> = hypotheses.iter().collect();
let mut plausibility = 0.0;
for focal in self.combined_mass.focal_elements() {
let focal_set: HashSet<_> = focal.iter().collect();
if !focal_set.is_disjoint(&target_set) {
plausibility += self.combined_mass.get_mass(focal);
}
}
Ok(plausibility)
}
pub fn uncertainty_interval(&self, hypotheses: &[String]) -> Result<(f64, f64)> {
let belief = self.belief(hypotheses)?;
let plausibility = self.plausibility(hypotheses)?;
Ok((belief, plausibility))
}
pub fn pignistic_probability(&self, hypothesis: &str) -> Result<f64> {
if !self.frame.contains(&hypothesis.to_string()) {
return Err(anyhow!("Hypothesis '{}' not in frame", hypothesis));
}
let mut prob = 0.0;
for focal in self.combined_mass.focal_elements() {
if focal.contains(&hypothesis.to_string()) {
let mass = self.combined_mass.get_mass(focal);
let cardinality = focal.len() as f64;
prob += mass / cardinality;
}
}
Ok(prob)
}
pub fn pignistic_distribution(&self) -> Result<HashMap<String, f64>> {
let mut distribution = HashMap::new();
for hypothesis in &self.frame {
let prob = self.pignistic_probability(hypothesis)?;
distribution.insert(hypothesis.clone(), prob);
}
Ok(distribution)
}
pub fn get_combined_mass(&self) -> &MassFunction {
&self.combined_mass
}
pub fn get_frame(&self) -> &[String] {
&self.frame
}
fn validate_hypotheses(&self, hypotheses: &[String]) -> Result<()> {
for h in hypotheses {
if !self.frame.contains(h) {
return Err(anyhow!("Hypothesis '{}' not in frame of discernment", h));
}
}
Ok(())
}
pub fn compute_conflict(&self, evidence1: &MassFunction, evidence2: &MassFunction) -> f64 {
let mut conflict = 0.0;
for focal1 in evidence1.focal_elements() {
for focal2 in evidence2.focal_elements() {
let intersection = self.intersect(focal1, focal2);
if intersection.is_empty() {
conflict += evidence1.get_mass(focal1) * evidence2.get_mass(focal2);
}
}
}
conflict
}
}
/// High-level front end over [`DempsterShaferSystem`] that also records each
/// piece of evidence under a caller-supplied name.
#[derive(Debug, Clone)]
pub struct DempsterShaferReasoner {
    /// The underlying evidence-combination system.
    system: DempsterShaferSystem,
    /// Evidence that was successfully combined, keyed by name.
    evidence_sources: HashMap<String, MassFunction>,
}

impl DempsterShaferReasoner {
    /// Creates a reasoner whose frame of discernment is `hypotheses`.
    pub fn new(hypotheses: Vec<String>) -> Self {
        Self {
            system: DempsterShaferSystem::new(hypotheses),
            evidence_sources: HashMap::new(),
        }
    }

    /// Combines `evidence` into the system and records it under `name`.
    ///
    /// The evidence is recorded only if the combination succeeds. Re-using a
    /// `name` replaces the recorded entry, but the earlier evidence remains
    /// part of the combined result.
    ///
    /// # Errors
    /// Propagates any error from [`DempsterShaferSystem::add_evidence`].
    pub fn add_named_evidence(&mut self, name: String, evidence: MassFunction) -> Result<()> {
        self.system.add_evidence(evidence.clone())?;
        self.evidence_sources.insert(name, evidence);
        Ok(())
    }

    /// Belief in the set `hypotheses` under the combined evidence.
    pub fn query_belief(&self, hypotheses: Vec<String>) -> Result<f64> {
        self.system.belief(&hypotheses)
    }

    /// Plausibility of the set `hypotheses` under the combined evidence.
    pub fn query_plausibility(&self, hypotheses: Vec<String>) -> Result<f64> {
        self.system.plausibility(&hypotheses)
    }

    /// Pignistic probability of every hypothesis in the frame.
    pub fn get_decision_probabilities(&self) -> Result<HashMap<String, f64>> {
        self.system.pignistic_distribution()
    }

    /// `(belief, plausibility)` of each individual hypothesis in the frame.
    pub fn get_all_uncertainty_intervals(&self) -> Result<HashMap<String, (f64, f64)>> {
        let mut intervals = HashMap::new();
        for hypothesis in self.system.get_frame() {
            let interval = self
                .system
                .uncertainty_interval(std::slice::from_ref(hypothesis))?;
            intervals.insert(hypothesis.clone(), interval);
        }
        Ok(intervals)
    }

    /// Returns the hypothesis with the highest pignistic probability, together
    /// with that probability.
    ///
    /// The frame is scanned in order and only a strictly larger probability
    /// replaces the current best. Ties therefore go to the hypothesis listed
    /// first in the frame, making the result deterministic; the previous
    /// `HashMap`-based maximum broke ties arbitrarily.
    ///
    /// # Errors
    /// Returns an error if the frame is empty.
    pub fn get_most_plausible(&self) -> Result<(String, f64)> {
        let mut best: Option<(&String, f64)> = None;
        for hypothesis in self.system.get_frame() {
            let prob = self.system.pignistic_probability(hypothesis)?;
            let is_better = match best {
                None => true,
                Some((_, best_prob)) => prob > best_prob,
            };
            if is_better {
                best = Some((hypothesis, prob));
            }
        }
        best.map(|(hypothesis, prob)| (hypothesis.clone(), prob))
            .ok_or_else(|| anyhow!("No hypotheses in system"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mass_function_basic() -> Result<(), Box<dyn std::error::Error>> {
        let mut mf = MassFunction::new();
        mf.assign_mass(vec!["A".to_string()], 0.6)?;
        mf.assign_mass(vec!["B".to_string()], 0.4)?;
        assert!((mf.get_mass(&["A".to_string()]) - 0.6).abs() < 1e-10);
        assert!((mf.total_mass() - 1.0).abs() < 1e-10);
        Ok(())
    }

    #[test]
    fn test_mass_function_normalization() -> Result<(), Box<dyn std::error::Error>> {
        let mut mf = MassFunction::new();
        mf.assign_mass(vec!["A".to_string()], 0.3)?;
        mf.assign_mass(vec!["B".to_string()], 0.2)?;
        mf.normalize()?;
        assert!((mf.total_mass() - 1.0).abs() < 1e-10);
        Ok(())
    }

    #[test]
    fn test_normalize_zero_total_fails() {
        assert!(MassFunction::new().normalize().is_err());
    }

    #[test]
    fn test_assign_mass_rejects_invalid_input() {
        let mut mf = MassFunction::new();
        assert!(mf.assign_mass(vec!["A".to_string()], 1.5).is_err());
        assert!(mf.assign_mass(vec!["A".to_string()], -0.1).is_err());
        assert!(mf.assign_mass(vec!["A".to_string()], f64::NAN).is_err());
        assert!(mf.assign_mass(Vec::new(), 0.5).is_err());
        // Rejected assignments must not leave any mass behind.
        assert!(mf.total_mass().abs() < 1e-12);
    }

    #[test]
    fn test_get_mass_ignores_member_order() -> Result<(), Box<dyn std::error::Error>> {
        let mut mf = MassFunction::new();
        mf.assign_mass(vec!["B".to_string(), "A".to_string()], 0.5)?;
        assert!((mf.get_mass(&["A".to_string(), "B".to_string()]) - 0.5).abs() < 1e-12);
        Ok(())
    }

    #[test]
    fn test_vacuous_system() -> Result<(), Box<dyn std::error::Error>> {
        let ds = DempsterShaferSystem::new(vec!["A".to_string(), "B".to_string()]);
        // With no evidence all mass sits on {A, B}: Bel(A) = 0, Pl(A) = 1, BetP(A) = 1/2.
        assert!(ds.belief(&["A".to_string()])?.abs() < 1e-10);
        assert!((ds.plausibility(&["A".to_string()])? - 1.0).abs() < 1e-10);
        assert!((ds.pignistic_probability("A")? - 0.5).abs() < 1e-10);
        Ok(())
    }

    #[test]
    fn test_unknown_hypothesis_queries_fail() {
        let ds = DempsterShaferSystem::new(vec!["A".to_string(), "B".to_string()]);
        assert!(ds.belief(&["Z".to_string()]).is_err());
        assert!(ds.plausibility(&["Z".to_string()]).is_err());
        assert!(ds.pignistic_probability("Z").is_err());
    }

    #[test]
    fn test_ds_system_belief() -> Result<(), Box<dyn std::error::Error>> {
        let frame = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        let mut ds = DempsterShaferSystem::new(frame);
        let mut evidence = MassFunction::new();
        evidence.assign_mass(vec!["A".to_string()], 0.6)?;
        evidence.assign_mass(vec!["A".to_string(), "B".to_string()], 0.3)?;
        evidence.assign_mass(vec!["A".to_string(), "B".to_string(), "C".to_string()], 0.1)?;
        ds.add_evidence(evidence)?;
        let belief_a = ds.belief(&["A".to_string()])?;
        assert!((belief_a - 0.6).abs() < 1e-10);
        let belief_ab = ds.belief(&["A".to_string(), "B".to_string()])?;
        assert!((belief_ab - 0.9).abs() < 1e-10);
        Ok(())
    }

    #[test]
    fn test_ds_system_plausibility() -> Result<(), Box<dyn std::error::Error>> {
        let frame = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        let mut ds = DempsterShaferSystem::new(frame);
        let mut evidence = MassFunction::new();
        evidence.assign_mass(vec!["A".to_string()], 0.6)?;
        evidence.assign_mass(vec!["B".to_string()], 0.3)?;
        evidence.assign_mass(vec!["C".to_string()], 0.1)?;
        ds.add_evidence(evidence)?;
        let pl_a = ds.plausibility(&["A".to_string()])?;
        assert!((pl_a - 0.6).abs() < 1e-10);
        let pl_ab = ds.plausibility(&["A".to_string(), "B".to_string()])?;
        assert!((pl_ab - 0.9).abs() < 1e-10);
        Ok(())
    }

    #[test]
    fn test_dempster_combination() -> Result<(), Box<dyn std::error::Error>> {
        let frame = vec!["A".to_string(), "B".to_string()];
        let mut ds = DempsterShaferSystem::new(frame);
        let mut ev1 = MassFunction::new();
        ev1.assign_mass(vec!["A".to_string()], 0.7)?;
        ev1.assign_mass(vec!["B".to_string()], 0.2)?;
        ev1.assign_mass(vec!["A".to_string(), "B".to_string()], 0.1)?;
        let mut ev2 = MassFunction::new();
        ev2.assign_mass(vec!["A".to_string()], 0.6)?;
        ev2.assign_mass(vec!["B".to_string()], 0.3)?;
        ev2.assign_mass(vec!["A".to_string(), "B".to_string()], 0.1)?;
        ds.add_evidence(ev1)?;
        ds.add_evidence(ev2)?;
        // Hand-derived: K = 0.7*0.3 + 0.2*0.6 = 0.33, so the normaliser is 0.67.
        // m(A) = (0.42 + 0.07 + 0.06) / 0.67 and m(A,B) = 0.01 / 0.67.
        let belief_a = ds.belief(&["A".to_string()])?;
        assert!((belief_a - 0.55 / 0.67).abs() < 1e-9);
        let pl_a = ds.plausibility(&["A".to_string()])?;
        assert!((pl_a - 0.56 / 0.67).abs() < 1e-9);
        Ok(())
    }

    #[test]
    fn test_total_conflict_is_rejected() -> Result<(), Box<dyn std::error::Error>> {
        let mut ds = DempsterShaferSystem::new(vec!["A".to_string(), "B".to_string()]);
        let mut ev1 = MassFunction::new();
        ev1.assign_mass(vec!["A".to_string()], 1.0)?;
        let mut ev2 = MassFunction::new();
        ev2.assign_mass(vec!["B".to_string()], 1.0)?;
        ds.add_evidence(ev1)?;
        assert!(ds.add_evidence(ev2).is_err());
        // A failed combination must leave the previous state intact.
        assert!((ds.belief(&["A".to_string()])? - 1.0).abs() < 1e-10);
        Ok(())
    }

    #[test]
    fn test_evidence_with_bad_total_is_rejected() -> Result<(), Box<dyn std::error::Error>> {
        let mut ds = DempsterShaferSystem::new(vec!["A".to_string(), "B".to_string()]);
        let mut ev = MassFunction::new();
        ev.assign_mass(vec!["A".to_string()], 0.5)?;
        assert!(ds.add_evidence(ev).is_err());
        Ok(())
    }

    #[test]
    fn test_compute_conflict() -> Result<(), Box<dyn std::error::Error>> {
        let ds = DempsterShaferSystem::new(vec!["A".to_string(), "B".to_string()]);
        let mut ev1 = MassFunction::new();
        ev1.assign_mass(vec!["A".to_string()], 0.7)?;
        ev1.assign_mass(vec!["B".to_string()], 0.3)?;
        let mut ev2 = MassFunction::new();
        ev2.assign_mass(vec!["A".to_string()], 0.4)?;
        ev2.assign_mass(vec!["A".to_string(), "B".to_string()], 0.6)?;
        // Only {B} x {A} is disjoint: 0.3 * 0.4.
        assert!((ds.compute_conflict(&ev1, &ev2) - 0.12).abs() < 1e-9);
        Ok(())
    }

    #[test]
    fn test_pignistic_probability() -> Result<(), Box<dyn std::error::Error>> {
        let frame = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        let mut ds = DempsterShaferSystem::new(frame);
        let mut evidence = MassFunction::new();
        evidence.assign_mass(vec!["A".to_string()], 0.6)?;
        evidence.assign_mass(vec!["A".to_string(), "B".to_string()], 0.4)?;
        ds.add_evidence(evidence)?;
        let prob_a = ds.pignistic_probability("A")?;
        assert!((prob_a - 0.8).abs() < 1e-10);
        let prob_b = ds.pignistic_probability("B")?;
        assert!((prob_b - 0.2).abs() < 1e-10);
        Ok(())
    }

    #[test]
    fn test_ds_reasoner() -> Result<(), Box<dyn std::error::Error>> {
        let hypotheses = vec!["Rain".to_string(), "NoRain".to_string()];
        let mut reasoner = DempsterShaferReasoner::new(hypotheses);
        let mut forecast = MassFunction::new();
        forecast.assign_mass(vec!["Rain".to_string()], 0.7)?;
        forecast.assign_mass(vec!["NoRain".to_string()], 0.3)?;
        reasoner.add_named_evidence("forecast".to_string(), forecast)?;
        let mut sensor = MassFunction::new();
        sensor.assign_mass(vec!["Rain".to_string()], 0.8)?;
        sensor.assign_mass(vec!["NoRain".to_string()], 0.2)?;
        reasoner.add_named_evidence("sensor".to_string(), sensor)?;
        let (most_plausible, prob) = reasoner.get_most_plausible()?;
        assert_eq!(most_plausible, "Rain");
        // K = 0.7*0.2 + 0.3*0.8 = 0.38, so m(Rain) = 0.56 / 0.62.
        assert!((prob - 0.56 / 0.62).abs() < 1e-9);
        Ok(())
    }

    #[test]
    fn test_uncertainty_intervals() -> Result<(), Box<dyn std::error::Error>> {
        let frame = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        let mut ds = DempsterShaferSystem::new(frame);
        let mut evidence = MassFunction::new();
        evidence.assign_mass(vec!["A".to_string()], 0.4)?;
        evidence.assign_mass(vec!["A".to_string(), "B".to_string()], 0.3)?;
        evidence.assign_mass(vec!["A".to_string(), "B".to_string(), "C".to_string()], 0.3)?;
        ds.add_evidence(evidence)?;
        let (bel, pl) = ds.uncertainty_interval(&["A".to_string()])?;
        assert!((bel - 0.4).abs() < 1e-10);
        assert!((pl - 1.0).abs() < 1e-10);
        assert!((pl - bel - 0.6).abs() < 1e-10);
        Ok(())
    }
}