use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::{PgmError, Result};
use crate::factor::Factor;
/// A variable registered in a [`FactorGraph`].
///
/// Built by `FactorGraph::add_variable` / `FactorGraph::add_variable_with_card`; the graph
/// stores each node under its `name`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct VariableNode {
/// Variable identifier; also the key used in the graph's variable map.
pub name: String,
/// Domain label supplied by the caller; not interpreted anywhere in this file.
pub domain: String,
/// Cardinality of the variable (presumably the number of values it can take — confirm).
/// `FactorGraph::add_variable` sets this to 2.
pub cardinality: usize,
}
/// Serializable description of a factor: its id and the variables it connects.
///
/// NOTE(review): not referenced by `FactorGraph` in this file, which stores full `Factor`
/// values instead; presumably used for (de)serialization elsewhere — confirm.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FactorNode {
/// Factor identifier.
pub id: String,
/// Names of the variables this factor touches (presumably `VariableNode::name`s).
pub variables: Vec<String>,
}
/// Graph of variables and factors, with adjacency kept in both directions.
///
/// Every map is keyed by name: variables by `VariableNode::name`, factors by `Factor::name`.
#[derive(Clone, Debug)]
pub struct FactorGraph {
/// Registered variables, keyed by variable name.
variables: HashMap<String, VariableNode>,
/// Registered factors, keyed by the factor's `name`.
factors: HashMap<String, Factor>,
/// Variable name -> ids of the factors that reference it.
/// Every registered variable has an entry here, possibly an empty list.
var_to_factors: HashMap<String, Vec<String>>,
/// Factor id -> that factor's variable list, as stored when the factor was added.
factor_to_vars: HashMap<String, Vec<String>>,
}
impl FactorGraph {
    /// Cardinality assigned to variables registered through [`FactorGraph::add_variable`].
    const DEFAULT_CARDINALITY: usize = 2;

    /// Creates an empty factor graph with no variables and no factors.
    pub fn new() -> Self {
        Self {
            variables: HashMap::new(),
            factors: HashMap::new(),
            var_to_factors: HashMap::new(),
            factor_to_vars: HashMap::new(),
        }
    }

    /// Registers (or replaces) variable `name` with the default cardinality of 2.
    ///
    /// Equivalent to `add_variable_with_card(name, domain, 2)`.
    pub fn add_variable(&mut self, name: String, domain: String) {
        self.add_variable_with_card(name, domain, Self::DEFAULT_CARDINALITY);
    }

    /// Registers (or replaces) variable `name` with an explicit `cardinality`.
    ///
    /// Re-registering an existing name replaces its [`VariableNode`] but keeps the factor
    /// adjacency it already had; a new name starts with an empty adjacency list.
    pub fn add_variable_with_card(&mut self, name: String, domain: String, cardinality: usize) {
        let node = VariableNode {
            name: name.clone(),
            domain,
            cardinality,
        };
        self.variables.insert(name.clone(), node);
        self.var_to_factors.entry(name).or_default();
    }

    /// Adds `factor` to the graph under its `name`, linking it to each of its variables.
    ///
    /// If a factor with the same name already exists it is replaced, and its previous
    /// variable links are removed, so the adjacency maps never hold stale or duplicate
    /// entries. A variable listed more than once in `factor.variables` is linked once.
    ///
    /// # Errors
    ///
    /// Returns [`PgmError::VariableNotFound`] carrying the first variable name in
    /// `factor.variables` that has not been registered. The graph is left unchanged.
    pub fn add_factor(&mut self, factor: Factor) -> Result<()> {
        // Validate every variable before touching any map, so a rejected factor
        // leaves the graph exactly as it was.
        if let Some(missing) = factor
            .variables
            .iter()
            .find(|var| !self.variables.contains_key(var.as_str()))
        {
            return Err(PgmError::VariableNotFound(missing.clone()));
        }

        let factor_id = factor.name.clone();

        // Replacing an existing factor: unlink it from its previous variables first.
        // Otherwise those variables keep a stale link, and shared ones get a duplicate.
        if let Some(old_vars) = self.factor_to_vars.remove(&factor_id) {
            for var in &old_vars {
                if let Some(adjacent) = self.var_to_factors.get_mut(var) {
                    adjacent.retain(|id| *id != factor_id);
                }
            }
        }

        for (i, var) in factor.variables.iter().enumerate() {
            // Only the first occurrence of a repeated variable creates a link. This check
            // costs O(arity), independent of how many factors the variable already has.
            if factor.variables[..i].contains(var) {
                continue;
            }
            self.var_to_factors
                .entry(var.clone())
                .or_default()
                .push(factor_id.clone());
        }

        self.factor_to_vars
            .insert(factor_id.clone(), factor.variables.clone());
        self.factors.insert(factor_id, factor);
        Ok(())
    }

    /// Builds `Factor::uniform(name, var_names, 2)` and adds it via [`FactorGraph::add_factor`].
    ///
    /// NOTE(review): the `2` is presumably the per-variable cardinality, matching the
    /// default used by `add_variable` — confirm against `Factor::uniform`.
    ///
    /// # Errors
    ///
    /// Same as [`FactorGraph::add_factor`].
    pub fn add_factor_from_predicate(&mut self, name: &str, var_names: &[String]) -> Result<()> {
        let factor = Factor::uniform(name.to_string(), var_names.to_vec(), 2);
        self.add_factor(factor)
    }

    /// Returns the variable registered under `name`, if any.
    pub fn get_variable(&self, name: &str) -> Option<&VariableNode> {
        self.variables.get(name)
    }

    /// Returns the factor stored under `id`, if any.
    pub fn get_factor(&self, id: &str) -> Option<&Factor> {
        self.factors.get(id)
    }

    /// Looks up a factor by its `name`.
    ///
    /// [`FactorGraph::add_factor`] is the only insertion path and stores each factor keyed
    /// by its own `name`. This is therefore a single hash lookup rather than a scan over
    /// every factor.
    pub fn get_factor_by_name(&self, name: &str) -> Option<&Factor> {
        self.factors.get(name)
    }

    /// Returns the ids of the factors adjacent to variable `var`.
    ///
    /// Returns `None` if `var` was never registered.
    pub fn get_adjacent_factors(&self, var: &str) -> Option<&Vec<String>> {
        self.var_to_factors.get(var)
    }

    /// Returns the variable list of factor `factor_id`, or `None` if no such factor exists.
    pub fn get_adjacent_variables(&self, factor_id: &str) -> Option<&Vec<String>> {
        self.factor_to_vars.get(factor_id)
    }

    /// Number of registered variables.
    pub fn num_variables(&self) -> usize {
        self.variables.len()
    }

    /// Number of registered factors.
    pub fn num_factors(&self) -> usize {
        self.factors.len()
    }

    /// `true` when the graph has neither variables nor factors.
    pub fn is_empty(&self) -> bool {
        self.variables.is_empty() && self.factors.is_empty()
    }

    /// Iterates over variable names in arbitrary (hash map) order.
    pub fn variable_names(&self) -> impl Iterator<Item = &String> {
        self.variables.keys()
    }

    /// Iterates over factor ids in arbitrary (hash map) order.
    pub fn factor_ids(&self) -> impl Iterator<Item = &String> {
        self.factors.keys()
    }

    /// Iterates over `(name, node)` pairs of all variables in arbitrary order.
    pub fn variables(&self) -> impl Iterator<Item = (&String, &VariableNode)> {
        self.variables.iter()
    }

    /// Iterates over all factors in arbitrary order.
    pub fn factors(&self) -> impl Iterator<Item = &Factor> {
        self.factors.values()
    }

    /// Collects references to all factors, in arbitrary order.
    pub fn get_all_factors(&self) -> Vec<&Factor> {
        self.factors().collect()
    }
}
impl Default for FactorGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph holding variables `x` (domain `D1`) and `y` (domain `D2`).
    fn graph_with_xy() -> FactorGraph {
        let mut graph = FactorGraph::new();
        graph.add_variable("x".to_string(), "D1".to_string());
        graph.add_variable("y".to_string(), "D2".to_string());
        graph
    }

    /// Variable list `["x", "y"]` used by the predicate factors below.
    fn xy() -> Vec<String> {
        vec!["x".to_string(), "y".to_string()]
    }

    #[test]
    fn test_graph_creation() {
        let graph = FactorGraph::new();
        assert!(graph.is_empty());
        assert_eq!(graph.num_variables(), 0);
        assert_eq!(graph.num_factors(), 0);
    }

    #[test]
    fn test_add_variables() {
        let graph = graph_with_xy();
        assert!(!graph.is_empty());
        assert_eq!(graph.num_variables(), 2);
        let x = graph.get_variable("x").expect("x was just added");
        assert_eq!(x.domain, "D1");
        assert_eq!(x.cardinality, 2);
        assert!(graph.get_variable("z").is_none());
    }

    #[test]
    fn test_add_variable_with_card() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("c".to_string(), "Colors".to_string(), 3);
        assert_eq!(graph.get_variable("c").map(|v| v.cardinality), Some(3));
        assert_eq!(graph.get_adjacent_factors("c").map(Vec::len), Some(0));
    }

    #[test]
    fn test_add_factor() {
        let mut graph = graph_with_xy();
        let result = graph.add_factor_from_predicate("P", &xy());
        assert!(result.is_ok());
        assert_eq!(graph.num_factors(), 1);
        assert!(graph.get_factor("P").is_some());
        assert!(graph.get_factor_by_name("P").is_some());
        assert!(graph.get_factor_by_name("Q").is_none());
    }

    #[test]
    fn test_add_factor_unknown_variable() {
        let mut graph = graph_with_xy();
        let result = graph.add_factor_from_predicate("Q", &["x".to_string(), "z".to_string()]);
        assert!(matches!(&result, Err(PgmError::VariableNotFound(v)) if v == "z"));
        // A rejected factor must not leave any trace in the graph.
        assert_eq!(graph.num_factors(), 0);
        assert_eq!(graph.get_adjacent_factors("x").map(Vec::len), Some(0));
    }

    #[test]
    fn test_adjacency() {
        let mut graph = graph_with_xy();
        graph
            .add_factor_from_predicate("P", &xy())
            .expect("x and y are registered");
        let adjacent = graph
            .get_adjacent_factors("x")
            .expect("registered variables always have an adjacency list");
        assert_eq!(adjacent.len(), 1);
        assert_eq!(adjacent[0], "P");
        let vars = graph.get_adjacent_variables("P").expect("P was just added");
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&"x".to_string()));
        assert!(vars.contains(&"y".to_string()));
    }
}