use scirs2_core::ndarray::ArrayD;
use std::collections::HashMap;
use crate::error::Result;
use crate::graph::FactorGraph;
use crate::message_passing::MessagePassingAlgorithm;
/// Request for the marginal distribution of a single variable.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MarginalizationQuery {
    /// Name of the variable whose marginal is requested.
    pub variable: String,
}
/// Request for the marginals of several variables, together with observed evidence.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConditionalQuery {
    /// Names of the variables whose marginals are requested.
    pub query_vars: Vec<String>,
    /// Observed variables mapped to their observed value (presumably a state
    /// index into the variable's domain — TODO confirm).
    pub evidence: HashMap<String, usize>,
}
/// Answers marginal, conditional and joint queries over a [`FactorGraph`]
/// using a pluggable message-passing algorithm.
pub struct InferenceEngine {
/// The graph every query is evaluated against.
graph: FactorGraph,
/// Algorithm run over `graph` to produce the per-variable marginals used by
/// `marginalize` and `conditional`.
algorithm: Box<dyn MessagePassingAlgorithm>,
}
impl InferenceEngine {
pub fn new(graph: FactorGraph, algorithm: Box<dyn MessagePassingAlgorithm>) -> Self {
Self { graph, algorithm }
}
pub fn marginalize(&self, query: &MarginalizationQuery) -> Result<ArrayD<f64>> {
let marginals = self.algorithm.run(&self.graph)?;
marginals
.get(&query.variable)
.cloned()
.ok_or_else(|| crate::error::PgmError::VariableNotFound(query.variable.clone()))
}
pub fn conditional(&self, query: &ConditionalQuery) -> Result<HashMap<String, ArrayD<f64>>> {
let marginals = self.algorithm.run(&self.graph)?;
let mut result = HashMap::new();
for var in &query.query_vars {
if let Some(marginal) = marginals.get(var) {
result.insert(var.clone(), marginal.clone());
}
}
Ok(result)
}
pub fn joint(&self) -> Result<ArrayD<f64>> {
use crate::factor::Factor;
let all_vars: Vec<String> = self.graph.variable_names().cloned().collect();
if all_vars.is_empty() {
return Err(crate::error::PgmError::InvalidGraph(
"No variables in graph".to_string(),
));
}
let mut joint_factor: Option<Factor> = None;
for factor_id in self.graph.factor_ids() {
if let Some(factor) = self.graph.get_factor(factor_id) {
joint_factor = if let Some(existing) = joint_factor {
Some(existing.product(factor)?)
} else {
Some(factor.clone())
};
}
}
if let Some(mut joint) = joint_factor {
joint.normalize();
Ok(joint.values)
} else {
let shape: Vec<usize> = all_vars
.iter()
.filter_map(|v| self.graph.get_variable(v))
.map(|n| n.cardinality)
.collect();
let size: usize = shape.iter().product();
Ok(ArrayD::from_elem(shape, 1.0 / size as f64))
}
}
pub fn graph(&self) -> &FactorGraph {
&self.graph
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message_passing::SumProductAlgorithm;

    /// Builds a factor-free graph from `(name, domain)` pairs and wraps it in an
    /// engine driven by the default sum-product algorithm.
    fn engine_with(vars: &[(&str, &str)]) -> InferenceEngine {
        let mut graph = FactorGraph::new();
        for &(name, domain) in vars {
            graph.add_variable(name.to_string(), domain.to_string());
        }
        InferenceEngine::new(graph, Box::new(SumProductAlgorithm::default()))
    }

    /// Asserts that `joint()` succeeds and its entries sum to one.
    fn assert_joint_is_normalized(engine: &InferenceEngine) {
        let outcome = engine.joint();
        assert!(outcome.is_ok());
        let total: f64 = outcome.expect("unwrap").iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_inference_engine() {
        let engine = engine_with(&[("x", "D1")]);
        let request = MarginalizationQuery {
            variable: "x".to_string(),
        };
        assert!(engine.marginalize(&request).is_ok());
    }

    #[test]
    fn test_conditional_query() {
        let engine = engine_with(&[("x", "D1"), ("y", "D2")]);
        let request = ConditionalQuery {
            query_vars: vec!["x".to_string()],
            evidence: HashMap::new(),
        };
        assert!(engine.conditional(&request).is_ok());
    }

    #[test]
    fn test_joint_probability() {
        assert_joint_is_normalized(&engine_with(&[("var_0", "D1")]));
    }

    #[test]
    fn test_joint_with_multiple_variables() {
        assert_joint_is_normalized(&engine_with(&[("var_0", "D1"), ("var_1", "D2")]));
    }
}