use crate::error::{PgmError, Result};
use crate::graph::FactorGraph;
use std::collections::{HashMap, HashSet};
/// Heuristic used by [`EliminationOrdering`] to choose the order in which variables are
/// eliminated.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum EliminationStrategy {
    /// Greedily eliminate the variable with the fewest neighbours (the default strategy).
    #[default]
    MinDegree,
    /// Greedily eliminate the variable whose elimination adds the fewest fill-in edges.
    MinFill,
    /// Min-fill where each candidate's fill count is scaled by a cardinality-based weight.
    WeightedMinFill,
    /// Min-width heuristic.
    MinWidth,
    /// Order derived from a maximum cardinality search over the interaction graph.
    MaxCardinalitySearch,
}
pub struct EliminationOrdering {
strategy: EliminationStrategy,
}
impl Default for EliminationOrdering {
fn default() -> Self {
Self::new(EliminationStrategy::default())
}
}
impl EliminationOrdering {
    /// Creates an orderer that uses `strategy` to choose elimination orders.
    pub fn new(strategy: EliminationStrategy) -> Self {
        Self { strategy }
    }

    /// Computes the order in which `vars` should be eliminated from `graph`.
    ///
    /// The interaction graph contains only `vars`: two of them are adjacent when some factor
    /// of `graph` mentions both. Duplicate names in `vars` are collapsed (first occurrence
    /// kept), and ties between equally scored candidates go to the variable listed first in
    /// `vars`, so the result is deterministic for a given input.
    ///
    /// # Errors
    ///
    /// Propagates errors from building the interaction graph or the variable weights; the
    /// helpers in this impl currently always succeed.
    pub fn compute_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
        match self.strategy {
            EliminationStrategy::MinDegree => self.min_degree_order(graph, vars),
            EliminationStrategy::MinFill => self.min_fill_order(graph, vars),
            EliminationStrategy::WeightedMinFill => self.weighted_min_fill_order(graph, vars),
            EliminationStrategy::MinWidth => self.min_width_order(graph, vars),
            EliminationStrategy::MaxCardinalitySearch => self.max_cardinality_search(graph, vars),
        }
    }

    /// Greedy min-degree: repeatedly eliminates the variable with the fewest neighbours in the
    /// current (fill-in augmented) interaction graph.
    fn min_degree_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
        self.greedy_order(graph, vars, |adjacency, var| Self::degree(adjacency, var))
    }

    /// Greedy min-fill: repeatedly eliminates the variable whose elimination would add the
    /// fewest edges between its (currently non-adjacent) neighbours.
    fn min_fill_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
        self.greedy_order(graph, vars, |adjacency, var| self.compute_fill(adjacency, var))
    }

    /// Weighted min-fill: each candidate scores `fill * weight`, where `weight` comes from
    /// `compute_variable_weights` and is computed once from the original graph.
    fn weighted_min_fill_order(
        &self,
        graph: &FactorGraph,
        vars: &[String],
    ) -> Result<Vec<String>> {
        let weights = self.compute_variable_weights(graph, vars)?;
        self.greedy_order(graph, vars, |adjacency, var| {
            let weight = weights.get(var).copied().unwrap_or(1);
            // Weights are products of cardinalities and can be enormous; saturate instead of
            // overflowing (which panics in debug builds).
            self.compute_fill(adjacency, var).saturating_mul(weight)
        })
    }

    /// Min-width ordering.
    ///
    /// NOTE(review): this scores candidates exactly like `min_degree_order` (neighbour count
    /// in the fill-augmented graph), so `MinWidth` and `MinDegree` yield the same order. Kept
    /// as-is for compatibility — confirm whether a distinct criterion was intended.
    fn min_width_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
        self.greedy_order(graph, vars, |adjacency, var| Self::degree(adjacency, var))
    }

    /// Maximum cardinality search.
    ///
    /// Visits variables by repeatedly picking the unvisited one with the most already-visited
    /// neighbours in the original interaction graph (no fill-in is added). The elimination
    /// order is the *reverse* of that visit order; for a chordal interaction graph the reverse
    /// is a perfect elimination ordering (eliminating in it adds no fill edges). Ties go to the
    /// variable listed first in `vars`.
    fn max_cardinality_search(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
        let mut remaining = Self::unique_in_order(vars);
        let var_set: HashSet<String> = remaining.iter().cloned().collect();
        let adjacency = self.build_adjacency_graph(graph, &var_set)?;

        // Number of already-visited neighbours for every variable not yet visited.
        let mut visited_neighbors: HashMap<String, usize> =
            remaining.iter().map(|var| (var.clone(), 0)).collect();
        let mut visit_order = Vec::with_capacity(remaining.len());

        while !remaining.is_empty() {
            // `min_by_key` on `Reverse` selects the FIRST maximum; `max_by_key` would return
            // the last one and break ties against input order.
            let best = remaining
                .iter()
                .enumerate()
                .min_by_key(|(_, var)| {
                    std::cmp::Reverse(visited_neighbors.get(var.as_str()).copied().unwrap_or(0))
                })
                .map(|(idx, _)| idx)
                .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?;
            let var = remaining.remove(best);
            visited_neighbors.remove(&var);
            if let Some(neighbors) = adjacency.get(&var) {
                for neighbor in neighbors {
                    // Only still-unvisited variables have an entry left to bump.
                    if let Some(count) = visited_neighbors.get_mut(neighbor) {
                        *count += 1;
                    }
                }
            }
            visit_order.push(var);
        }

        // The first variable visited must be eliminated last.
        visit_order.reverse();
        Ok(visit_order)
    }

    /// Shared greedy elimination loop.
    ///
    /// At every step the remaining variable with the lowest `cost` is eliminated. Its
    /// neighbours are then connected to each other (fill-in) before the next step is scored.
    /// Ties go to the variable listed first in `vars`.
    fn greedy_order<F>(&self, graph: &FactorGraph, vars: &[String], cost: F) -> Result<Vec<String>>
    where
        F: Fn(&HashMap<String, HashSet<String>>, &str) -> usize,
    {
        let mut remaining = Self::unique_in_order(vars);
        let var_set: HashSet<String> = remaining.iter().cloned().collect();
        let mut adjacency = self.build_adjacency_graph(graph, &var_set)?;
        let mut order = Vec::with_capacity(remaining.len());

        while !remaining.is_empty() {
            // `min_by_key` returns the first minimum, giving input-order tie-breaking.
            let best = remaining
                .iter()
                .enumerate()
                .min_by_key(|(_, var)| cost(&adjacency, var.as_str()))
                .map(|(idx, _)| idx)
                .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?;
            let var = remaining.remove(best);
            self.update_adjacency_after_elimination(&mut adjacency, &var);
            order.push(var);
        }
        Ok(order)
    }

    /// Returns `vars` with duplicates removed, keeping the first occurrence of each name.
    fn unique_in_order(vars: &[String]) -> Vec<String> {
        let mut seen: HashSet<&str> = HashSet::with_capacity(vars.len());
        let mut unique = Vec::with_capacity(vars.len());
        for var in vars {
            if seen.insert(var.as_str()) {
                unique.push(var.clone());
            }
        }
        unique
    }

    /// Number of neighbours `var` currently has (0 if it is not in `adjacency`).
    fn degree(adjacency: &HashMap<String, HashSet<String>>, var: &str) -> usize {
        adjacency.get(var).map(|neighbors| neighbors.len()).unwrap_or(0)
    }

    /// Builds the undirected interaction graph over `vars`.
    ///
    /// Every variable in `vars` gets an entry (possibly empty). Two distinct variables are
    /// adjacent when some factor of `graph` mentions both; variables outside `vars` are
    /// ignored.
    fn build_adjacency_graph(
        &self,
        graph: &FactorGraph,
        vars: &HashSet<String>,
    ) -> Result<HashMap<String, HashSet<String>>> {
        let mut adjacency: HashMap<String, HashSet<String>> =
            vars.iter().map(|var| (var.clone(), HashSet::new())).collect();
        for factor_id in graph.factor_ids() {
            if let Some(factor) = graph.get_factor(factor_id) {
                let scope: Vec<&String> = factor
                    .variables
                    .iter()
                    .filter(|v| vars.contains(*v))
                    .collect();
                for (i, &a) in scope.iter().enumerate() {
                    for &b in &scope[i + 1..] {
                        // A variable repeated in one factor must not become its own neighbour.
                        if a == b {
                            continue;
                        }
                        adjacency.entry(a.clone()).or_default().insert(b.clone());
                        adjacency.entry(b.clone()).or_default().insert(a.clone());
                    }
                }
            }
        }
        Ok(adjacency)
    }

    /// Counts the fill-in edges that eliminating `var` would add: pairs of its neighbours that
    /// are not yet adjacent to each other. Returns 0 if `var` is unknown or has no neighbours.
    fn compute_fill(&self, adjacency: &HashMap<String, HashSet<String>>, var: &str) -> usize {
        let neighbors: Vec<&String> = match adjacency.get(var) {
            Some(set) => set.iter().collect(),
            None => return 0,
        };
        let mut fill = 0;
        for (i, a) in neighbors.iter().enumerate() {
            if let Some(adj_a) = adjacency.get(*a) {
                fill += neighbors[i + 1..]
                    .iter()
                    .filter(|b| !adj_a.contains(**b))
                    .count();
            }
        }
        fill
    }

    /// Removes `var` from the interaction graph and connects all of its former neighbours
    /// pairwise (the fill-in edges created by eliminating it). No-op if `var` is unknown.
    fn update_adjacency_after_elimination(
        &self,
        adjacency: &mut HashMap<String, HashSet<String>>,
        var: &str,
    ) {
        let neighbors: Vec<String> = match adjacency.remove(var) {
            Some(set) => set.into_iter().collect(),
            None => return,
        };
        for neighbor in &neighbors {
            if let Some(adj) = adjacency.get_mut(neighbor) {
                adj.remove(var);
                adj.extend(neighbors.iter().filter(|other| *other != neighbor).cloned());
            }
        }
    }

    /// Computes the per-variable weight used by weighted min-fill.
    ///
    /// A variable's weight is the product of the cardinalities of every variable of every
    /// factor returned by `graph.get_adjacent_factors(var)` (with repetition), starting from 1.
    /// The product saturates at `usize::MAX` instead of overflowing.
    fn compute_variable_weights(
        &self,
        graph: &FactorGraph,
        vars: &[String],
    ) -> Result<HashMap<String, usize>> {
        let mut weights = HashMap::with_capacity(vars.len());
        for var in vars {
            let mut weight: usize = 1;
            if let Some(factors) = graph.get_adjacent_factors(var) {
                for factor_id in factors {
                    if let Some(factor) = graph.get_factor(factor_id) {
                        for factor_var in &factor.variables {
                            if let Some(var_node) = graph.get_variable(factor_var) {
                                // Grows exponentially with the number of adjacent factors.
                                weight = weight.saturating_mul(var_node.cardinality);
                            }
                        }
                    }
                }
            }
            weights.insert(var.clone(), weight);
        }
        Ok(weights)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Factor;
    use scirs2_core::ndarray::Array;

    /// Builds the chain X - Y - Z: three binary variables joined by `f_xy` and `f_yz`.
    fn create_test_graph() -> FactorGraph {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("X".to_string(), "Domain".to_string(), 2);
        graph.add_variable_with_card("Y".to_string(), "Domain".to_string(), 2);
        graph.add_variable_with_card("Z".to_string(), "Domain".to_string(), 2);
        let f_xy = Factor::new(
            "f_xy".to_string(),
            vec!["X".to_string(), "Y".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
                .expect("unwrap")
                .into_dyn(),
        )
        .expect("unwrap");
        let f_yz = Factor::new(
            "f_yz".to_string(),
            vec!["Y".to_string(), "Z".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.5, 0.6, 0.7, 0.8])
                .expect("unwrap")
                .into_dyn(),
        )
        .expect("unwrap");
        graph.add_factor(f_xy).expect("unwrap");
        graph.add_factor(f_yz).expect("unwrap");
        graph
    }

    /// The variables of the chain graph, in declaration order.
    fn chain_vars() -> Vec<String> {
        vec!["X".to_string(), "Y".to_string(), "Z".to_string()]
    }

    /// Runs `strategy` on the chain graph over all three variables.
    fn order_with(strategy: EliminationStrategy) -> Vec<String> {
        let graph = create_test_graph();
        EliminationOrdering::new(strategy)
            .compute_order(&graph, &chain_vars())
            .expect("ordering should succeed")
    }

    /// Asserts that `order` contains each of `vars` exactly once.
    fn assert_is_permutation(order: &[String], vars: &[String]) {
        let mut got = order.to_vec();
        got.sort();
        let mut expected = vars.to_vec();
        expected.sort();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_min_degree_ordering() {
        let order = order_with(EliminationStrategy::MinDegree);
        assert_is_permutation(&order, &chain_vars());
        // X and Z have one neighbour each; Y has two.
        assert!(order[0] == "X" || order[0] == "Z");
    }

    #[test]
    fn test_min_fill_ordering() {
        let order = order_with(EliminationStrategy::MinFill);
        assert_is_permutation(&order, &chain_vars());
        // Eliminating Y first would add the fill edge X - Z; the chain ends add none.
        assert!(order[0] == "X" || order[0] == "Z");
    }

    #[test]
    fn test_weighted_min_fill_ordering() {
        let order = order_with(EliminationStrategy::WeightedMinFill);
        assert_is_permutation(&order, &chain_vars());
        // Zero fill gives zero weighted cost, so a chain end is chosen first.
        assert!(order[0] == "X" || order[0] == "Z");
    }

    #[test]
    fn test_min_width_ordering() {
        let order = order_with(EliminationStrategy::MinWidth);
        assert_is_permutation(&order, &chain_vars());
    }

    #[test]
    fn test_max_cardinality_search() {
        let order = order_with(EliminationStrategy::MaxCardinalitySearch);
        assert_is_permutation(&order, &chain_vars());
    }

    #[test]
    fn test_empty_variable_list_yields_empty_order() {
        let graph = create_test_graph();
        let strategies = [
            EliminationStrategy::MinDegree,
            EliminationStrategy::MinFill,
            EliminationStrategy::WeightedMinFill,
            EliminationStrategy::MinWidth,
            EliminationStrategy::MaxCardinalitySearch,
        ];
        for strategy in strategies {
            let order = EliminationOrdering::new(strategy)
                .compute_order(&graph, &[])
                .expect("empty input should succeed");
            assert!(order.is_empty(), "{:?} returned {:?}", strategy, order);
        }
    }

    #[test]
    fn test_default_strategy_is_min_degree() {
        assert_eq!(EliminationStrategy::default(), EliminationStrategy::MinDegree);
    }
}