use std::fmt::Write;
use serde::{Deserialize, Serialize};
use crate::graph::FactorGraph;
/// Flat, index-based description of a factor graph used for statistics and
/// rendering.
///
/// Factors refer to variables by their position in `variables`.
#[derive(Debug, Clone, Default)]
pub struct FactorGraphModel {
/// Variable nodes; a variable's position in this vector is its index.
pub variables: Vec<VizVariableNode>,
/// Factor nodes, in insertion order.
pub factors: Vec<VizFactorNode>,
}
/// A variable node of a [`FactorGraphModel`].
#[derive(Debug, Clone)]
pub struct VizVariableNode {
/// Display name of the variable.
pub name: String,
/// Size of the variable's domain; `from_factor_graph` fills this from the
/// source variable's `cardinality`.
pub domain_size: usize,
}
/// A factor node of a [`FactorGraphModel`].
#[derive(Debug, Clone)]
pub struct VizFactorNode {
/// Display name of the factor.
pub name: String,
/// Positions (into `FactorGraphModel::variables`) of the variables this
/// factor is attached to. Nothing in this type enforces that the indices
/// are in range.
pub variable_indices: Vec<usize>,
}
impl FactorGraphModel {
    /// Creates a model with no variables and no factors.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a model from `fg`.
    ///
    /// Variables are added in lexicographic name order, so the index assigned
    /// to each variable does not depend on the iteration order of `fg`. A
    /// variable whose lookup in `fg` yields nothing gets a domain size of 2.
    /// Factor scopes are translated to indices; scope entries naming a
    /// variable that is not in the index table are silently left out.
    pub fn from_factor_graph(fg: &FactorGraph) -> Self {
        let mut sorted_names: Vec<String> = fg.variable_names().cloned().collect();
        sorted_names.sort();

        let mut model = Self::new();
        let mut index_by_name: std::collections::HashMap<String, usize> =
            std::collections::HashMap::new();

        for name in sorted_names {
            let domain_size = fg.get_variable(&name).map_or(2, |var| var.cardinality);
            let position = model.add_variable(name.clone(), domain_size);
            index_by_name.insert(name, position);
        }

        for factor in fg.factors() {
            let mut scope: Vec<usize> = Vec::new();
            for member in factor.variables.iter() {
                if let Some(&position) = index_by_name.get(member) {
                    scope.push(position);
                }
            }
            model.add_factor(factor.name.clone(), scope);
        }

        model
    }

    /// Appends a variable and returns the index it was stored at.
    pub fn add_variable(&mut self, name: impl Into<String>, domain_size: usize) -> usize {
        let node = VizVariableNode {
            name: name.into(),
            domain_size,
        };
        self.variables.push(node);
        self.variables.len() - 1
    }

    /// Appends a factor attached to the variables at `variable_indices`.
    ///
    /// The indices are stored as given; they are not validated.
    pub fn add_factor(&mut self, name: impl Into<String>, variable_indices: Vec<usize>) {
        let node = VizFactorNode {
            name: name.into(),
            variable_indices,
        };
        self.factors.push(node);
    }

    /// Number of variable nodes.
    pub fn variable_count(&self) -> usize {
        self.variables.len()
    }

    /// Number of factor nodes.
    pub fn factor_count(&self) -> usize {
        self.factors.len()
    }

    /// Number of factor-variable edges: the summed length of every factor's
    /// index list.
    pub fn edge_count(&self) -> usize {
        let mut total = 0;
        for factor in &self.factors {
            total += factor.variable_indices.len();
        }
        total
    }
}
/// Structural statistics of a [`FactorGraphModel`], as produced by
/// `FactorGraphStats::compute`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FactorGraphStats {
/// Number of variable nodes.
pub variable_count: usize,
/// Number of factor nodes.
pub factor_count: usize,
/// Number of factor-variable edges (sum of all factor arities).
pub edge_count: usize,
/// Largest number of variable indices held by any single factor.
pub max_factor_arity: usize,
/// `edge_count / factor_count`, or `0.0` when there are no factors.
pub avg_factor_arity: f64,
/// Largest number of factor attachments of any single variable.
pub max_variable_degree: usize,
/// Mean number of factor attachments per variable, or `0.0` when there are
/// no variables.
pub avg_variable_degree: f64,
/// Whether `compute` classified the graph as a tree; `summary` appends
/// " (tree)" when this is set.
pub is_tree: bool,
/// Treewidth bound reported by `compute`; `summary` prints it after "≤".
pub treewidth_upper_bound: usize,
}
impl FactorGraphStats {
    /// Computes structural statistics for `model`.
    ///
    /// Variable indices that are out of range for `model.variables` still
    /// count towards `edge_count` and factor arity (both are taken from the
    /// raw index lists), but are ignored for degrees, connectivity and the
    /// elimination-based treewidth bound.
    ///
    /// * `is_tree` is set only when the bipartite variable/factor graph is
    ///   non-empty, has exactly `nodes - 1` edges **and** is connected. The
    ///   edge count alone is not enough: a cycle plus an isolated node has
    ///   the same count as a tree.
    /// * `treewidth_upper_bound` bounds the treewidth of the graph obtained
    ///   by connecting every pair of variables that share a factor. For an
    ///   acyclic factor graph this is `max_factor_arity - 1`. For a graph
    ///   with cycles that value is only a lower bound, so the width of a
    ///   greedy min-degree elimination ordering is used instead (never less
    ///   than `max_factor_arity - 1`).
    ///
    /// The elimination pass only runs for cyclic graphs; it is quadratic in
    /// the number of variables plus the cost of the fill-in edges it creates.
    pub fn compute(model: &FactorGraphModel) -> Self {
        let variable_count = model.variable_count();
        let factor_count = model.factor_count();
        let edge_count = model.edge_count();

        let max_factor_arity = model
            .factors
            .iter()
            .map(|f| f.variable_indices.len())
            .max()
            .unwrap_or(0);
        let avg_factor_arity = if factor_count > 0 {
            edge_count as f64 / factor_count as f64
        } else {
            0.0
        };

        let mut var_degrees = vec![0usize; variable_count];
        for factor in &model.factors {
            for &vi in &factor.variable_indices {
                // `get_mut` skips dangling indices instead of panicking.
                if let Some(degree) = var_degrees.get_mut(vi) {
                    *degree += 1;
                }
            }
        }
        let max_variable_degree = var_degrees.iter().copied().max().unwrap_or(0);
        let avg_variable_degree = if variable_count > 0 {
            var_degrees.iter().sum::<usize>() as f64 / variable_count as f64
        } else {
            0.0
        };

        let (acyclic, components) = Self::scan_components(model);
        let total_nodes = variable_count + factor_count;
        // Connected with exactly nodes-1 edges <=> tree.
        let is_tree = total_nodes > 0 && edge_count + 1 == total_nodes && components == 1;

        let arity_bound = max_factor_arity.saturating_sub(1);
        let treewidth_upper_bound = if acyclic {
            // Factor scopes of an acyclic factor graph already form a valid
            // tree decomposition, so the largest scope determines the width.
            arity_bound
        } else {
            arity_bound.max(Self::min_degree_width(model))
        };

        Self {
            variable_count,
            factor_count,
            edge_count,
            max_factor_arity,
            avg_factor_arity,
            max_variable_degree,
            avg_variable_degree,
            is_tree,
            treewidth_upper_bound,
        }
    }

    /// Formats the headline numbers as a one-line, human-readable summary.
    pub fn summary(&self) -> String {
        format!(
            "{} vars, {} factors, {} edges, treewidth\u{2264}{}{}",
            self.variable_count,
            self.factor_count,
            self.edge_count,
            self.treewidth_upper_bound,
            if self.is_tree { " (tree)" } else { "" }
        )
    }

    /// Union-find pass over the bipartite graph (variables are nodes
    /// `0..V`, factor `i` is node `V + i`).
    ///
    /// Returns `(acyclic, component_count)`. A repeated index inside one
    /// factor counts as a cycle; out-of-range indices are skipped.
    fn scan_components(model: &FactorGraphModel) -> (bool, usize) {
        let variable_count = model.variables.len();
        let total_nodes = variable_count + model.factors.len();
        let mut parent: Vec<usize> = (0..total_nodes).collect();
        let mut components = total_nodes;
        let mut acyclic = true;

        for (fi, factor) in model.factors.iter().enumerate() {
            let factor_node = variable_count + fi;
            for &vi in &factor.variable_indices {
                if vi >= variable_count {
                    continue;
                }
                let factor_root = Self::find_root(&mut parent, factor_node);
                let variable_root = Self::find_root(&mut parent, vi);
                if factor_root == variable_root {
                    // Both ends were already connected: this edge closes a cycle.
                    acyclic = false;
                } else {
                    parent[factor_root] = variable_root;
                    components -= 1;
                }
            }
        }

        (acyclic, components)
    }

    /// Finds the representative of `node`, shortening the path on the way.
    fn find_root(parent: &mut [usize], mut node: usize) -> usize {
        while parent[node] != node {
            let grandparent = parent[parent[node]];
            parent[node] = grandparent;
            node = grandparent;
        }
        node
    }

    /// Width of a greedy min-degree elimination ordering on the graph that
    /// links variables sharing a factor. Any elimination ordering yields an
    /// upper bound on the treewidth of that graph.
    fn min_degree_width(model: &FactorGraphModel) -> usize {
        let n = model.variables.len();
        let mut adjacency: Vec<std::collections::BTreeSet<usize>> =
            vec![std::collections::BTreeSet::new(); n];

        for factor in &model.factors {
            let scope: std::collections::BTreeSet<usize> = factor
                .variable_indices
                .iter()
                .copied()
                .filter(|&vi| vi < n)
                .collect();
            for &a in &scope {
                for &b in &scope {
                    if a != b {
                        adjacency[a].insert(b);
                    }
                }
            }
        }

        let mut eliminated = vec![false; n];
        let mut width = 0;
        for _ in 0..n {
            // Ties resolve to the lowest index, keeping the result deterministic.
            let candidate = (0..n)
                .filter(|&v| !eliminated[v])
                .min_by_key(|&v| adjacency[v].len());
            if let Some(next) = candidate {
                let neighbours: Vec<usize> =
                    std::mem::take(&mut adjacency[next]).into_iter().collect();
                width = width.max(neighbours.len());
                // Eliminating `next` turns its remaining neighbours into a clique.
                for (pos, &a) in neighbours.iter().enumerate() {
                    adjacency[a].remove(&next);
                    for &b in &neighbours[pos + 1..] {
                        adjacency[a].insert(b);
                        adjacency[b].insert(a);
                    }
                }
                eliminated[next] = true;
            }
        }
        width
    }
}
/// Renders `model` as a short plain-text report.
///
/// The report has a header line, one line listing every variable as
/// `name(domain_size)`, one line listing every factor as `name(arity)`, and
/// then one connection line per factor naming the variables it touches.
/// Indices that do not refer to an existing variable are omitted from the
/// connection lines.
pub fn render_ascii(model: &FactorGraphModel) -> String {
    let variable_list = model
        .variables
        .iter()
        .map(|var| format!("{}({})", var.name, var.domain_size))
        .collect::<Vec<String>>()
        .join(", ");
    let factor_list = model
        .factors
        .iter()
        .map(|fac| format!("{}({})", fac.name, fac.variable_indices.len()))
        .collect::<Vec<String>>()
        .join(", ");

    let mut text = String::from("Factor Graph:\n");
    // Writing into a `String` cannot fail, so the results are discarded.
    let _ = writeln!(
        text,
        " Variables ({}): {}",
        model.variable_count(),
        variable_list
    );
    let _ = writeln!(text, " Factors ({}): {}", model.factor_count(), factor_list);
    text.push_str(" Connections:\n");

    for fac in &model.factors {
        let mut members: Vec<&str> = Vec::new();
        for &index in &fac.variable_indices {
            if let Some(var) = model.variables.get(index) {
                members.push(var.name.as_str());
            }
        }
        let _ = writeln!(
            text,
            " {} \u{2500}\u{2500} {}",
            fac.name,
            members.join(", ")
        );
    }

    text
}
/// Renders `model` as an undirected Graphviz DOT graph.
///
/// Variables become circular nodes `v<index>`, factors become filled square
/// nodes `f<index>`, and each factor is linked to the variables it touches.
///
/// Node names are escaped before being placed inside the quoted `label`
/// attribute, so names containing `"`, `\` or line breaks still produce
/// valid DOT. Indices that do not refer to an existing variable are skipped,
/// so no edge is emitted to an undeclared node.
pub fn render_dot(model: &FactorGraphModel) -> String {
    let variable_count = model.variables.len();
    let mut dot = String::new();
    // Writing into a `String` cannot fail, so the results are discarded.
    let _ = writeln!(dot, "graph FactorGraph {{");
    let _ = writeln!(dot, " rankdir=LR;");
    for (i, var) in model.variables.iter().enumerate() {
        let _ = writeln!(
            dot,
            " v{} [label=\"{}\", shape=circle];",
            i,
            escape_dot_label(&var.name)
        );
    }
    for (i, factor) in model.factors.iter().enumerate() {
        let _ = writeln!(
            dot,
            " f{} [label=\"{}\", shape=square, style=filled, fillcolor=lightgray];",
            i,
            escape_dot_label(&factor.name)
        );
        for &vi in &factor.variable_indices {
            // An edge to an undeclared `v<vi>` would make Graphviz invent a node.
            if vi < variable_count {
                let _ = writeln!(dot, " f{} -- v{};", i, vi);
            }
        }
    }
    let _ = writeln!(dot, "}}");
    dot
}

/// Escapes `raw` for use inside a double-quoted DOT label.
fn escape_dot_label(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            other => escaped.push(other),
        }
    }
    escaped
}
// Unit tests for the model builder, the statistics and both renderers.
#[cfg(test)]
mod tests {
use super::*;
// Fixture: three binary variables linked in a chain, A - f1 - B - f2 - C.
fn chain_model() -> FactorGraphModel {
let mut m = FactorGraphModel::new();
let a = m.add_variable("A", 2);
let b = m.add_variable("B", 2);
let c = m.add_variable("C", 2);
m.add_factor("f1", vec![a, b]);
m.add_factor("f2", vec![b, c]);
m
}
// Fixture: the chain plus a third pairwise factor f3 over (A, C), which
// closes a cycle.
fn loopy_model() -> FactorGraphModel {
let mut m = FactorGraphModel::new();
let a = m.add_variable("A", 2);
let b = m.add_variable("B", 2);
let c = m.add_variable("C", 2);
m.add_factor("f1", vec![a, b]);
m.add_factor("f2", vec![b, c]);
m.add_factor("f3", vec![a, c]);
m
}
// A freshly created model has no variables, factors or edges.
#[test]
fn test_model_new_empty() {
let m = FactorGraphModel::new();
assert_eq!(m.variable_count(), 0);
assert_eq!(m.factor_count(), 0);
assert_eq!(m.edge_count(), 0);
}
// The first variable gets index 0 and keeps its domain size.
#[test]
fn test_model_add_variable() {
let mut m = FactorGraphModel::new();
let idx = m.add_variable("X", 4);
assert_eq!(idx, 0);
assert_eq!(m.variable_count(), 1);
assert_eq!(m.variables[0].domain_size, 4);
}
// Adding a factor increases the factor count.
#[test]
fn test_model_add_factor() {
let mut m = FactorGraphModel::new();
let a = m.add_variable("A", 2);
m.add_factor("f1", vec![a]);
assert_eq!(m.factor_count(), 1);
}
// Two binary factors contribute two edges each.
#[test]
fn test_model_counts() {
let m = chain_model();
assert_eq!(m.variable_count(), 3);
assert_eq!(m.factor_count(), 2);
assert_eq!(m.edge_count(), 4);
}
// Statistics of an empty model are all zero.
#[test]
fn test_stats_empty() {
let m = FactorGraphModel::new();
let s = FactorGraphStats::compute(&m);
assert_eq!(s.variable_count, 0);
assert_eq!(s.factor_count, 0);
assert_eq!(s.edge_count, 0);
assert_eq!(s.max_factor_arity, 0);
assert!((s.avg_factor_arity - 0.0).abs() < f64::EPSILON);
}
// The stats mirror the model's own counts.
#[test]
fn test_stats_simple_chain() {
let s = FactorGraphStats::compute(&chain_model());
assert_eq!(s.variable_count, 3);
assert_eq!(s.factor_count, 2);
assert_eq!(s.edge_count, 4);
}
// A single ternary factor gives a maximum arity of 3.
#[test]
fn test_stats_max_factor_arity() {
let mut m = FactorGraphModel::new();
let a = m.add_variable("A", 2);
let b = m.add_variable("B", 2);
let c = m.add_variable("C", 2);
m.add_factor("big", vec![a, b, c]);
let s = FactorGraphStats::compute(&m);
assert_eq!(s.max_factor_arity, 3);
}
// 4 edges over 2 factors average to an arity of 2.
#[test]
fn test_stats_avg_factor_arity() {
let s = FactorGraphStats::compute(&chain_model());
assert!((s.avg_factor_arity - 2.0).abs() < f64::EPSILON);
}
// B sits in both factors of the chain, so the maximum degree is 2.
#[test]
fn test_stats_variable_degree() {
let s = FactorGraphStats::compute(&chain_model());
assert_eq!(s.max_variable_degree, 2);
}
// The chain (5 nodes, 4 edges) is classified as a tree.
#[test]
fn test_stats_is_tree_true() {
let s = FactorGraphStats::compute(&chain_model());
assert!(s.is_tree);
}
// The loopy model (6 nodes, 6 edges) is not a tree.
#[test]
fn test_stats_is_tree_false() {
let s = FactorGraphStats::compute(&loopy_model());
assert!(!s.is_tree);
}
// With only pairwise factors the chain's treewidth bound is 1.
#[test]
fn test_stats_treewidth() {
let s = FactorGraphStats::compute(&chain_model());
assert_eq!(s.treewidth_upper_bound, 1);
}
// The summary line mentions variables and factors.
#[test]
fn test_stats_summary() {
let s = FactorGraphStats::compute(&chain_model());
let summary = s.summary();
assert!(summary.contains("vars"));
assert!(summary.contains("factors"));
}
// The ASCII rendering contains the header line.
#[test]
fn test_render_ascii_header() {
let out = render_ascii(&chain_model());
assert!(out.contains("Factor Graph:"));
}
// Each variable is listed as name(domain_size).
#[test]
fn test_render_ascii_variables() {
let out = render_ascii(&chain_model());
assert!(out.contains("A(2)"));
assert!(out.contains("B(2)"));
assert!(out.contains("C(2)"));
}
// Factor and variable names both appear in the ASCII output.
#[test]
fn test_render_ascii_connections() {
let out = render_ascii(&chain_model());
assert!(out.contains("f1"));
assert!(out.contains("A"));
assert!(out.contains("B"));
}
// The DOT output is an undirected `graph`, not a `digraph`.
#[test]
fn test_render_dot_undirected() {
let dot = render_dot(&chain_model());
assert!(dot.starts_with("graph "));
assert!(!dot.contains("digraph"));
}
// Variables render as circles and factors as squares.
#[test]
fn test_render_dot_nodes() {
let dot = render_dot(&chain_model());
assert!(dot.contains("v0"));
assert!(dot.contains("shape=circle"));
assert!(dot.contains("f0"));
assert!(dot.contains("shape=square"));
}
}