use crate::error::{PgmError, Result};
use crate::factor::Factor;
use crate::graph::FactorGraph;
use scirs2_core::ndarray::ArrayD;
use std::collections::{HashMap, HashSet, VecDeque};
/// A node of the junction tree: a set of variable names with an optional potential.
#[derive(Debug, Clone)]
pub struct Clique {
    /// Identifier supplied when the clique is constructed.
    pub id: usize,
    /// Names of the variables that belong to this clique.
    pub variables: HashSet<String>,
    /// Potential attached to the clique; `None` until one is assigned.
    pub potential: Option<Factor>,
}

impl Clique {
    /// Creates a clique with the given id and variables and no potential.
    pub fn new(id: usize, variables: HashSet<String>) -> Self {
        Clique {
            potential: None,
            variables,
            id,
        }
    }

    /// Returns `true` when every name in `vars` is a member of this clique.
    pub fn contains_all(&self, vars: &HashSet<String>) -> bool {
        vars.iter().all(|name| self.variables.contains(name))
    }

    /// Returns the variable names that appear in both `self` and `other`.
    pub fn intersection(&self, other: &Clique) -> HashSet<String> {
        self.variables
            .iter()
            .filter(|name| other.variables.contains(*name))
            .cloned()
            .collect()
    }
}
/// The variables shared by two adjacent cliques, with an optional potential.
#[derive(Debug, Clone)]
pub struct Separator {
    /// Names of the variables common to both cliques.
    pub variables: HashSet<String>,
    /// Potential attached to the separator; `None` until one is assigned.
    pub potential: Option<Factor>,
}

impl Separator {
    /// Builds the separator between `c1` and `c2` from the intersection of
    /// their variable sets; the potential starts out unset.
    pub fn from_cliques(c1: &Clique, c2: &Clique) -> Self {
        let shared = c1.intersection(c2);
        Separator {
            variables: shared,
            potential: None,
        }
    }
}
/// An undirected edge between two cliques, holding one message slot per direction.
#[derive(Debug, Clone)]
pub struct JunctionTreeEdge {
    /// Index of the first endpoint.
    pub clique1: usize,
    /// Index of the second endpoint.
    pub clique2: usize,
    /// Separator associated with this edge.
    pub separator: Separator,
    /// Message sent from `clique1` to `clique2`; `None` until it is computed.
    pub message_1_to_2: Option<Factor>,
    /// Message sent from `clique2` to `clique1`; `None` until it is computed.
    pub message_2_to_1: Option<Factor>,
}

impl JunctionTreeEdge {
    /// Creates an edge between the two clique indices with both message slots empty.
    pub fn new(clique1: usize, clique2: usize, separator: Separator) -> Self {
        JunctionTreeEdge {
            message_2_to_1: None,
            message_1_to_2: None,
            separator,
            clique2,
            clique1,
        }
    }
}
/// A junction tree: cliques, the edges connecting them, and a lookup from
/// variable name to the cliques that contain it.
#[derive(Debug, Clone)]
pub struct JunctionTree {
/// Cliques of the tree; edges refer to them by index into this vector.
pub cliques: Vec<Clique>,
/// Edges between cliques, each carrying a separator and two message slots.
pub edges: Vec<JunctionTreeEdge>,
/// For every variable name, the indices of the cliques containing it.
pub var_to_cliques: HashMap<String, Vec<usize>>,
/// Set to `true` by `calibrate`; the query methods refuse to run while it is `false`.
pub calibrated: bool,
}
impl JunctionTree {
pub fn new() -> Self {
Self {
cliques: Vec::new(),
edges: Vec::new(),
var_to_cliques: HashMap::new(),
calibrated: false,
}
}
/// Builds a junction tree for `graph`.
///
/// The pipeline is: interaction graph -> triangulation -> maximal cliques ->
/// spanning tree over the cliques -> clique potentials initialised from the
/// graph's factors.
///
/// # Errors
///
/// Propagates the first error reported by any of those steps.
pub fn from_factor_graph(graph: &FactorGraph) -> Result<Self> {
    let chordal = Self::build_interaction_graph(graph)
        .and_then(|adjacency| Self::triangulate(&adjacency))?;
    let clique_sets = Self::find_maximal_cliques(&chordal)?;
    Self::build_tree_from_cliques(clique_sets).and_then(|mut tree| {
        tree.initialize_potentials(graph)?;
        Ok(tree)
    })
}
/// Builds the undirected interaction graph of `graph`.
///
/// Every variable name gets an entry (possibly with no neighbours), and any
/// two variables listed together in one factor become neighbours of each other.
fn build_interaction_graph(graph: &FactorGraph) -> Result<HashMap<String, HashSet<String>>> {
    let mut neighbours: HashMap<String, HashSet<String>> = HashMap::new();
    for name in graph.variable_names() {
        neighbours.insert(name.clone(), HashSet::new());
    }
    for factor in graph.factors() {
        let scope = &factor.variables;
        // Link every unordered pair of variables in the factor's scope.
        for (position, left) in scope.iter().enumerate() {
            for right in scope.iter().skip(position + 1) {
                neighbours
                    .entry(left.clone())
                    .or_default()
                    .insert(right.clone());
                neighbours
                    .entry(right.clone())
                    .or_default()
                    .insert(left.clone());
            }
        }
    }
    Ok(neighbours)
}
/// Triangulates `graph` by repeatedly eliminating the variable chosen by
/// `find_min_fill_variable` and connecting its not-yet-eliminated neighbours
/// to each other.
///
/// Returns a copy of `graph` with the fill-in edges added.
///
/// # Errors
///
/// Returns `PgmError::InvalidGraph` when the chosen variable has no adjacency
/// entry, and propagates errors from `find_min_fill_variable`.
fn triangulate(
    graph: &HashMap<String, HashSet<String>>,
) -> Result<HashMap<String, HashSet<String>>> {
    let mut chordal = graph.clone();
    let mut pending: HashSet<String> = graph.keys().cloned().collect();
    while !pending.is_empty() {
        let eliminated = Self::find_min_fill_variable(&chordal, &pending)?;
        let adjacent = match chordal.get(&eliminated) {
            Some(set) => set,
            None => return Err(PgmError::InvalidGraph("Variable not found".to_string())),
        };
        // Only neighbours that are still waiting to be eliminated take part.
        let live: Vec<String> = adjacent
            .iter()
            .filter(|name| pending.contains(*name))
            .cloned()
            .collect();
        for (position, first) in live.iter().enumerate() {
            for second in &live[position + 1..] {
                chordal
                    .entry(first.clone())
                    .or_default()
                    .insert(second.clone());
                chordal
                    .entry(second.clone())
                    .or_default()
                    .insert(first.clone());
            }
        }
        pending.remove(&eliminated);
    }
    Ok(chordal)
}
/// Picks the variable in `remaining` whose elimination would add the fewest
/// fill-in edges.
///
/// The fill count of a variable is the number of pairs of its neighbours
/// (restricted to `remaining`) that are not already adjacent in `graph`.
/// Ties are broken by choosing the lexicographically smallest name, so the
/// result does not depend on `HashSet` iteration order.
///
/// # Errors
///
/// Returns `PgmError::InvalidGraph` when a variable that has to be inspected
/// has no adjacency entry in `graph`, or when `remaining` is empty.
fn find_min_fill_variable(
    graph: &HashMap<String, HashSet<String>>,
    remaining: &HashSet<String>,
) -> Result<String> {
    let mut best: Option<(usize, &String)> = None;
    for var in remaining {
        let neighbors: Vec<&String> = graph
            .get(var)
            .ok_or_else(|| PgmError::InvalidGraph("Variable not found".to_string()))?
            .iter()
            .filter(|name| remaining.contains(*name))
            .collect();
        let mut fill_count = 0usize;
        for (position, first) in neighbors.iter().enumerate() {
            let later = &neighbors[position + 1..];
            if later.is_empty() {
                // No pair starts at this neighbour, so its adjacency is not needed.
                continue;
            }
            // Report a missing adjacency entry as an error instead of panicking.
            let adjacent = graph
                .get(*first)
                .ok_or_else(|| PgmError::InvalidGraph("Variable not found".to_string()))?;
            fill_count += later
                .iter()
                .filter(|second| !adjacent.contains(**second))
                .count();
        }
        let improves = match best {
            None => true,
            Some((best_fill, best_name)) => {
                fill_count < best_fill || (fill_count == best_fill && var < best_name)
            }
        };
        if improves {
            best = Some((fill_count, var));
        }
    }
    best.map(|(_, name)| name.clone())
        .ok_or_else(|| PgmError::InvalidGraph("No variable found".to_string()))
}
/// Enumerates every maximal clique of the undirected graph `graph`.
///
/// The adjacency is treated as symmetric (an edge listed in either direction
/// counts) and self-loops are ignored. A name that only appears as a neighbour
/// is treated as a vertex too. Cliques are found with the Bron–Kerbosch
/// algorithm with pivoting, so no maximal clique is missed, and the returned
/// list is in a deterministic order (sorted by member names).
///
/// Returns an empty vector for an empty graph; an isolated vertex yields a
/// singleton clique.
fn find_maximal_cliques(
    graph: &HashMap<String, HashSet<String>>,
) -> Result<Vec<HashSet<String>>> {
    /// Bron–Kerbosch step: `current` is the clique being grown, `candidates`
    /// can still extend it, `excluded` were already tried. A clique is
    /// maximal exactly when both pools are empty.
    fn expand(
        adjacency: &[HashSet<usize>],
        current: &mut Vec<usize>,
        mut candidates: Vec<usize>,
        mut excluded: Vec<usize>,
        found: &mut Vec<Vec<usize>>,
    ) {
        if candidates.is_empty() {
            if excluded.is_empty() {
                found.push(current.clone());
            }
            return;
        }
        // Pivot on the vertex with the most neighbours among the candidates;
        // only candidates that are NOT adjacent to the pivot need branching.
        let pivot = candidates
            .iter()
            .chain(excluded.iter())
            .copied()
            .max_by_key(|&u| {
                candidates
                    .iter()
                    .filter(|&&c| adjacency[u].contains(&c))
                    .count()
            });
        let branch: Vec<usize> = match pivot {
            Some(p) => candidates
                .iter()
                .copied()
                .filter(|v| !adjacency[p].contains(v))
                .collect(),
            None => candidates.clone(),
        };
        for v in branch {
            let next_candidates: Vec<usize> = candidates
                .iter()
                .copied()
                .filter(|u| adjacency[v].contains(u))
                .collect();
            let next_excluded: Vec<usize> = excluded
                .iter()
                .copied()
                .filter(|u| adjacency[v].contains(u))
                .collect();
            current.push(v);
            expand(adjacency, current, next_candidates, next_excluded, found);
            current.pop();
            candidates.retain(|&u| u != v);
            excluded.push(v);
        }
    }

    // Sorted, de-duplicated vertex names give every vertex a stable index.
    let mut names: Vec<&str> = graph
        .keys()
        .chain(graph.values().flatten())
        .map(String::as_str)
        .collect();
    names.sort_unstable();
    names.dedup();
    if names.is_empty() {
        return Ok(Vec::new());
    }
    let index_of: HashMap<&str, usize> = names
        .iter()
        .enumerate()
        .map(|(index, name)| (*name, index))
        .collect();

    // Symmetric, loop-free adjacency over indices.
    let mut adjacency: Vec<HashSet<usize>> = vec![HashSet::new(); names.len()];
    for (name, neighbours) in graph {
        // Every name was inserted into `index_of` above, so indexing cannot fail.
        let a = index_of[name.as_str()];
        for neighbour in neighbours {
            let b = index_of[neighbour.as_str()];
            if a != b {
                adjacency[a].insert(b);
                adjacency[b].insert(a);
            }
        }
    }

    let mut found: Vec<Vec<usize>> = Vec::new();
    expand(
        &adjacency,
        &mut Vec::new(),
        (0..names.len()).collect(),
        Vec::new(),
        &mut found,
    );
    for members in &mut found {
        members.sort_unstable();
    }
    found.sort();

    Ok(found
        .into_iter()
        .map(|members| {
            members
                .into_iter()
                .map(|index| names[index].to_string())
                .collect()
        })
        .collect())
}
/// Turns a list of variable sets into a tree.
///
/// Each set becomes a clique whose id is its position in `clique_sets`, the
/// variable-to-clique index is filled in, and the cliques are connected by
/// `build_maximum_spanning_tree` when there are at least two of them.
fn build_tree_from_cliques(clique_sets: Vec<HashSet<String>>) -> Result<Self> {
    let mut tree = JunctionTree::new();
    for (index, members) in clique_sets.into_iter().enumerate() {
        for name in members.iter() {
            tree.var_to_cliques
                .entry(name.clone())
                .or_default()
                .push(index);
        }
        tree.cliques.push(Clique::new(index, members));
    }
    if tree.cliques.len() >= 2 {
        tree.build_maximum_spanning_tree()?;
    }
    Ok(tree)
}
fn build_maximum_spanning_tree(&mut self) -> Result<()> {
let n = self.cliques.len();
if n == 0 {
return Ok(());
}
let mut in_tree = vec![false; n];
let mut edges_to_add: Vec<(usize, usize, usize)> = Vec::new();
in_tree[0] = true;
let mut tree_size = 1;
while tree_size < n {
let mut best_edge = None;
let mut best_weight = 0;
for i in 0..n {
if !in_tree[i] {
continue;
}
for (j, &is_in_tree) in in_tree.iter().enumerate().take(n) {
if is_in_tree {
continue;
}
let separator = self.cliques[i].intersection(&self.cliques[j]);
let weight = separator.len();
if weight > best_weight {
best_weight = weight;
best_edge = Some((i, j, weight));
}
}
}
if let Some((i, j, _)) = best_edge {
edges_to_add.push((i, j, best_weight));
in_tree[j] = true;
tree_size += 1;
} else {
break;
}
}
for (c1, c2, _) in edges_to_add {
let separator = Separator::from_cliques(&self.cliques[c1], &self.cliques[c2]);
let edge = JunctionTreeEdge::new(c1, c2, separator);
self.edges.push(edge);
}
Ok(())
}
/// Assigns every factor of `graph` to a clique and fills in the remaining cliques.
///
/// Each factor is multiplied into the potential of the first clique that
/// contains all of its variables (or becomes that clique's potential when it
/// has none yet). Cliques left without a potential afterwards receive the one
/// built by `create_uniform_potential`.
///
/// # Errors
///
/// Returns `PgmError::InvalidGraph` when no clique contains all variables of
/// a factor, and propagates errors from the factor product and from
/// `create_uniform_potential`.
fn initialize_potentials(&mut self, graph: &FactorGraph) -> Result<()> {
    for factor in graph.factors() {
        let scope: HashSet<String> = factor.variables.iter().cloned().collect();
        let host = match self
            .cliques
            .iter_mut()
            .find(|candidate| candidate.contains_all(&scope))
        {
            Some(clique) => clique,
            None => {
                return Err(PgmError::InvalidGraph(format!(
                    "No clique contains all variables for factor: {:?}",
                    factor.name
                )))
            }
        };
        let combined = match host.potential.as_ref() {
            Some(existing) => existing.product(factor)?,
            None => factor.clone(),
        };
        host.potential = Some(combined);
    }
    for clique in self.cliques.iter_mut() {
        if clique.potential.is_some() {
            continue;
        }
        let uniform = Self::create_uniform_potential(&clique.variables, graph)?;
        clique.potential = Some(uniform);
    }
    Ok(())
}
/// Builds a factor named `"uniform"` over `variables` whose entries are all `1.0`.
///
/// The axis lengths are the cardinalities that `graph` reports for the
/// variables, in the order the set happens to iterate.
///
/// # Errors
///
/// Returns `PgmError::InvalidGraph` when a variable is unknown to `graph` or
/// the array cannot be created, and propagates errors from `Factor::new`.
fn create_uniform_potential(
    variables: &HashSet<String>,
    graph: &FactorGraph,
) -> Result<Factor> {
    let names: Vec<String> = variables.iter().cloned().collect();
    let dims = names
        .iter()
        .map(|name| {
            graph
                .get_variable(name)
                .map(|variable| variable.cardinality)
                .ok_or_else(|| PgmError::InvalidGraph(format!("Variable {} not found", name)))
        })
        .collect::<Result<Vec<_>>>()?;
    let cell_count: usize = dims.iter().product();
    let table = ArrayD::from_shape_vec(dims, vec![1.0; cell_count])
        .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))?;
    Factor::new("uniform".to_string(), names, table)
}
/// Runs the collect and distribute passes over every connected component of
/// the clique structure and marks the tree as calibrated.
///
/// `build_maximum_spanning_tree` never links cliques that share no variable,
/// so the structure can be a forest. Each component is therefore rooted at
/// its lowest-indexed clique and processed separately, rather than only the
/// component containing clique 0.
///
/// # Errors
///
/// Propagates the first error from message passing. In that case
/// `calibrated` is left `false`, so a failed re-calibration cannot leave a
/// stale `true` flag behind.
pub fn calibrate(&mut self) -> Result<()> {
    self.calibrated = false;
    let mut seen = vec![false; self.cliques.len()];
    for root in 0..seen.len() {
        if seen[root] {
            continue;
        }
        // Mark everything reachable from `root` so it is not used as a root again.
        seen[root] = true;
        let mut frontier = vec![root];
        while let Some(current) = frontier.pop() {
            for neighbor in self.get_neighbors(current, None) {
                if let Some(flag) = seen.get_mut(neighbor) {
                    if !*flag {
                        *flag = true;
                        frontier.push(neighbor);
                    }
                }
            }
        }
        self.collect_evidence(root, None)?;
        self.distribute_evidence(root, None)?;
    }
    self.calibrated = true;
    Ok(())
}
/// Upward pass: recurses into every neighbour of `current` except `parent`,
/// then sends a message from `current` to `parent` (when there is one).
fn collect_evidence(&mut self, current: usize, parent: Option<usize>) -> Result<()> {
    for child in self.get_neighbors(current, parent) {
        self.collect_evidence(child, Some(current))?;
    }
    match parent {
        Some(parent_idx) => self.send_message(current, parent_idx),
        None => Ok(()),
    }
}
/// Downward pass: for every neighbour of `current` except `parent`, sends a
/// message from `current` to that neighbour and then recurses into it.
fn distribute_evidence(&mut self, current: usize, parent: Option<usize>) -> Result<()> {
    self.get_neighbors(current, parent)
        .into_iter()
        .try_for_each(|child| {
            self.send_message(current, child)?;
            self.distribute_evidence(child, Some(current))
        })
}
/// Returns the cliques joined to `clique` by an edge, in edge order, leaving
/// out `parent` when one is given.
fn get_neighbors(&self, clique: usize, parent: Option<usize>) -> Vec<usize> {
    self.edges
        .iter()
        .filter_map(|edge| {
            let other = if edge.clique1 == clique {
                edge.clique2
            } else if edge.clique2 == clique {
                edge.clique1
            } else {
                return None;
            };
            if parent == Some(other) {
                None
            } else {
                Some(other)
            }
        })
        .collect()
}
/// Computes the message from clique `from` to clique `to` and stores it on
/// the edge between them.
///
/// The message is the product of the sender's potential and every message
/// already stored on the sender's *other* incident edges (those not leading
/// to `to`), with every variable outside the edge's separator marginalized
/// out. Including the other incoming messages is what lets information from
/// beyond the sender reach the recipient; message slots that are still
/// `None` are skipped.
///
/// # Errors
///
/// Returns `PgmError::InvalidGraph` when there is no edge between the two
/// cliques or the sender has no potential, and propagates errors from the
/// factor product and marginalization.
fn send_message(&mut self, from: usize, to: usize) -> Result<()> {
    let edge_idx = self
        .edges
        .iter()
        .position(|e| {
            (e.clique1 == from && e.clique2 == to) || (e.clique1 == to && e.clique2 == from)
        })
        .ok_or_else(|| PgmError::InvalidGraph("Edge not found".to_string()))?;
    let mut message = self.cliques[from].potential.clone().ok_or_else(|| {
        PgmError::InvalidGraph("Clique potential not initialized".to_string())
    })?;

    // Fold in what the sender has heard from all neighbours except the recipient.
    for (idx, edge) in self.edges.iter().enumerate() {
        if idx == edge_idx {
            continue;
        }
        let incoming = if edge.clique1 == from {
            edge.message_2_to_1.as_ref()
        } else if edge.clique2 == from {
            edge.message_1_to_2.as_ref()
        } else {
            None
        };
        if let Some(incoming) = incoming {
            message = message.product(incoming)?;
        }
    }

    // Sum out everything that is not part of the separator.
    let separator_vars = &self.edges[edge_idx].separator.variables;
    let vars_to_eliminate: Vec<String> = message
        .variables
        .iter()
        .filter(|name| !separator_vars.contains(*name))
        .cloned()
        .collect();
    for var in vars_to_eliminate {
        message = message.marginalize_out(&var)?;
    }

    let edge = &mut self.edges[edge_idx];
    if edge.clique1 == from {
        edge.message_1_to_2 = Some(message);
    } else {
        edge.message_2_to_1 = Some(message);
    }
    Ok(())
}
/// Returns the marginal of `variable` as an array.
///
/// The belief is formed in the first clique recorded for the variable: the
/// clique's potential multiplied by every message stored on its incident
/// edges in the direction of that clique. All other variables are then
/// marginalized out and `normalize()` is called on the result.
///
/// # Errors
///
/// Returns `PgmError::InvalidGraph` when the tree is not calibrated, the
/// variable is unknown, the recorded clique index is out of range, the
/// clique has no potential, or the resulting belief does not mention the
/// variable. Errors from the factor product and marginalization are
/// propagated.
pub fn query_marginal(&self, variable: &str) -> Result<ArrayD<f64>> {
    if !self.calibrated {
        return Err(PgmError::InvalidGraph(
            "Tree must be calibrated before querying".to_string(),
        ));
    }
    let clique_indices = self
        .var_to_cliques
        .get(variable)
        .ok_or_else(|| PgmError::InvalidGraph(format!("Variable {} not found", variable)))?;
    let clique_idx = *clique_indices.first().ok_or_else(|| {
        PgmError::InvalidGraph(format!("No clique contains variable {}", variable))
    })?;
    // `var_to_cliques` is a public field, so do not trust the index blindly.
    let clique = self.cliques.get(clique_idx).ok_or_else(|| {
        PgmError::InvalidGraph(format!("Clique index {} out of range", clique_idx))
    })?;
    let mut belief = clique.potential.clone().ok_or_else(|| {
        PgmError::InvalidGraph("Clique potential not initialized".to_string())
    })?;

    // Combine the local potential with the messages sent towards this clique.
    for edge in &self.edges {
        let incoming = if edge.clique1 == clique_idx {
            edge.message_2_to_1.as_ref()
        } else if edge.clique2 == clique_idx {
            edge.message_1_to_2.as_ref()
        } else {
            None
        };
        if let Some(message) = incoming {
            belief = belief.product(message)?;
        }
    }

    // Without this check every variable would be summed out and the result
    // would not be a distribution over `variable` at all.
    if !belief.variables.iter().any(|name| name.as_str() == variable) {
        return Err(PgmError::InvalidGraph(format!(
            "Belief of clique {} does not contain variable {}",
            clique_idx, variable
        )));
    }

    let vars_to_eliminate: Vec<String> = belief
        .variables
        .iter()
        .filter(|name| name.as_str() != variable)
        .cloned()
        .collect();
    for var in vars_to_eliminate {
        belief = belief.marginalize_out(&var)?;
    }
    belief.normalize();
    Ok(belief.values)
}
/// Returns the joint marginal over `variables` as an array.
///
/// The belief is formed in the first clique that contains all requested
/// variables: the clique's potential multiplied by every message stored on
/// its incident edges in the direction of that clique. All other variables
/// are then marginalized out and `normalize()` is called on the result.
///
/// NOTE(review): the axis order of the returned array is whatever order the
/// resulting factor keeps its variables in, which is not necessarily the
/// order of `variables` — confirm against `Factor`.
///
/// # Errors
///
/// Returns `PgmError::InvalidGraph` when the tree is not calibrated, no
/// clique contains all requested variables, the clique has no potential, or
/// the resulting belief is missing one of the requested variables. Errors
/// from the factor product and marginalization are propagated.
pub fn query_joint_marginal(&self, variables: &[String]) -> Result<ArrayD<f64>> {
    if !self.calibrated {
        return Err(PgmError::InvalidGraph(
            "Tree must be calibrated before querying".to_string(),
        ));
    }
    let var_set: HashSet<String> = variables.iter().cloned().collect();
    let clique_idx = self
        .cliques
        .iter()
        .position(|c| c.contains_all(&var_set))
        .ok_or_else(|| {
            PgmError::InvalidGraph(format!("No clique contains all variables: {:?}", variables))
        })?;
    let mut belief = self.cliques[clique_idx].potential.clone().ok_or_else(|| {
        PgmError::InvalidGraph("Clique potential not initialized".to_string())
    })?;

    // Combine the local potential with the messages sent towards this clique.
    for edge in &self.edges {
        let incoming = if edge.clique1 == clique_idx {
            edge.message_2_to_1.as_ref()
        } else if edge.clique2 == clique_idx {
            edge.message_1_to_2.as_ref()
        } else {
            None
        };
        if let Some(message) = incoming {
            belief = belief.product(message)?;
        }
    }

    // Every requested variable must survive in the belief; otherwise the
    // result would silently be a distribution over fewer variables.
    if let Some(missing) = var_set
        .iter()
        .find(|wanted| !belief.variables.iter().any(|name| name == *wanted))
    {
        return Err(PgmError::InvalidGraph(format!(
            "Belief of clique {} does not contain variable {}",
            clique_idx, missing
        )));
    }

    let vars_to_eliminate: Vec<String> = belief
        .variables
        .iter()
        .filter(|name| !var_set.contains(*name))
        .cloned()
        .collect();
    for var in vars_to_eliminate {
        belief = belief.marginalize_out(&var)?;
    }
    belief.normalize();
    Ok(belief.values)
}
/// Returns the size of the largest clique minus one, or `0` when there are
/// no cliques.
pub fn treewidth(&self) -> usize {
    let largest = self
        .cliques
        .iter()
        .fold(0usize, |widest, clique| widest.max(clique.variables.len()));
    largest.saturating_sub(1)
}
/// Checks the running intersection property: for every variable, the cliques
/// recorded as containing it must form a connected subgraph of the tree.
/// Variables recorded in at most one clique satisfy it trivially.
pub fn verify_running_intersection_property(&self) -> bool {
    self.var_to_cliques
        .values()
        .all(|holders| holders.len() <= 1 || self.is_connected_subgraph(holders))
}
/// Returns `true` when the given clique indices are connected to each other
/// using only edges whose endpoints are both in the set.
///
/// A breadth-first search starts at the first index and only steps onto
/// members of the set. The set counts as connected when the search reaches
/// every *distinct* index, so repeated indices in `cliques` do not cause a
/// false negative. An empty slice is considered connected.
fn is_connected_subgraph(&self, cliques: &[usize]) -> bool {
    let start = match cliques.first() {
        Some(&first) => first,
        None => return true,
    };
    let clique_set: HashSet<usize> = cliques.iter().copied().collect();
    let mut visited: HashSet<usize> = HashSet::new();
    visited.insert(start);
    let mut queue = VecDeque::new();
    queue.push_back(start);
    while let Some(current) = queue.pop_front() {
        for edge in &self.edges {
            let neighbor = if edge.clique1 == current {
                edge.clique2
            } else if edge.clique2 == current {
                edge.clique1
            } else {
                continue;
            };
            // `insert` returns `true` only the first time a clique is reached.
            if clique_set.contains(&neighbor) && visited.insert(neighbor) {
                queue.push_back(neighbor);
            }
        }
    }
    visited.len() == clique_set.len()
}
}
impl Default for JunctionTree {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Builds a factor graph in which every listed variable has cardinality 2.
    fn binary_graph(names: &[&str]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        for name in names {
            graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
        }
        graph
    }

    /// Builds a 2x2 factor over two variables from a row-major table.
    fn pairwise_factor(label: &str, scope: [&str; 2], table: [f64; 4]) -> Factor {
        Factor::new(
            label.to_string(),
            vec![scope[0].to_string(), scope[1].to_string()],
            Array::from_shape_vec(vec![2, 2], table.to_vec())
                .expect("unwrap")
                .into_dyn(),
        )
        .expect("unwrap")
    }

    /// Collects string slices into an owned set of names.
    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_clique_creation() {
        let clique = Clique::new(0, name_set(&["x", "y"]));
        assert_eq!(clique.id, 0);
        assert_eq!(clique.variables.len(), 2);
    }

    #[test]
    fn test_clique_intersection() {
        let c1 = Clique::new(0, name_set(&["x", "y"]));
        let c2 = Clique::new(1, name_set(&["y", "z"]));
        let shared = c1.intersection(&c2);
        assert_eq!(shared.len(), 1);
        assert!(shared.contains("y"));
    }

    #[test]
    fn test_interaction_graph() {
        let mut graph = binary_graph(&["x", "y", "z"]);
        graph
            .add_factor(pairwise_factor("P(x,y)", ["x", "y"], [0.1, 0.2, 0.3, 0.4]))
            .expect("unwrap");
        graph
            .add_factor(pairwise_factor("P(y,z)", ["y", "z"], [0.5, 0.1, 0.2, 0.2]))
            .expect("unwrap");
        let adjacency = JunctionTree::build_interaction_graph(&graph).expect("unwrap");
        // Both factor scopes must show up as edges in both directions.
        for &(from, to) in &[("x", "y"), ("y", "x"), ("y", "z"), ("z", "y")] {
            assert!(adjacency.get(from).expect("unwrap").contains(to));
        }
    }

    #[test]
    fn test_junction_tree_construction() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise_factor("P(x,y)", ["x", "y"], [0.3, 0.7, 0.4, 0.6]))
            .expect("unwrap");
        let tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");
        assert!(!tree.cliques.is_empty());
        assert!(tree.verify_running_intersection_property());
    }

    #[test]
    fn test_junction_tree_calibration() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise_factor(
                "P(x,y)",
                ["x", "y"],
                [0.25, 0.25, 0.25, 0.25],
            ))
            .expect("unwrap");
        let mut tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");
        tree.calibrate().expect("unwrap");
        assert!(tree.calibrated);
    }

    #[test]
    fn test_marginal_query() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise_factor("P(x,y)", ["x", "y"], [0.1, 0.4, 0.2, 0.3]))
            .expect("unwrap");
        let mut tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");
        tree.calibrate().expect("unwrap");
        let marginal_x = tree.query_marginal("x").expect("unwrap");
        // Rows of the table sum to 0.5 each.
        assert_abs_diff_eq!(marginal_x[[0]], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(marginal_x[[1]], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_treewidth() {
        let mut graph = binary_graph(&["x", "y", "z"]);
        let pxy = pairwise_factor("P(x,y)", ["x", "y"], [0.3, 0.7, 0.4, 0.6]);
        let pyz = pairwise_factor("P(y,z)", ["y", "z"], [0.5, 0.5, 0.6, 0.4]);
        graph.add_factor(pxy).expect("unwrap");
        graph.add_factor(pyz).expect("unwrap");
        let tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");
        assert!(tree.treewidth() <= 2);
    }
}