use scirs2_core::ndarray::ArrayD;
use serde::{Deserialize, Serialize};
use crate::error::{PgmError, Result};
/// A factor over named variables, stored as a dense N-dimensional table of `f64` values.
///
/// [`Factor::new`] checks that `values` has exactly one axis per entry in `variables`.
/// The fields are public, so code that builds or mutates a `Factor` directly is
/// responsible for keeping that pairing intact.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Factor {
/// Variable names; the methods in this file treat entry `i` as the label of axis `i` of `values`.
pub variables: Vec<String>,
/// N-dimensional table holding the factor's values.
pub values: ArrayD<f64>,
/// Label for this factor; derived factors get a name built from it (e.g. a `_marg` suffix).
pub name: String,
}
impl Factor {
/// Builds a factor from a name, its variable names and its value table.
///
/// # Errors
///
/// Returns [`PgmError::DimensionMismatch`] when the number of axes of `values`
/// differs from the number of entries in `variables`.
pub fn new(name: String, variables: Vec<String>, values: ArrayD<f64>) -> Result<Self> {
    let axis_count = values.ndim();
    let name_count = variables.len();
    if axis_count == name_count {
        Ok(Self {
            name,
            variables,
            values,
        })
    } else {
        Err(PgmError::DimensionMismatch {
            expected: vec![name_count],
            got: vec![axis_count],
        })
    }
}
/// Builds a factor in which every variable has `card` states and every table
/// entry holds the same value, `1 / card^variables.len()`.
///
/// With no variables the table is a single entry equal to `1.0`.
pub fn uniform(name: String, variables: Vec<String>, card: usize) -> Self {
    let shape = vec![card; variables.len()];
    // Count the entries in floating point: this avoids the truncating
    // `usize -> u32` cast and the integer `pow` overflow of the exponent form.
    let entry_count: f64 = shape.iter().map(|&dim| dim as f64).product();
    let values = ArrayD::from_elem(shape, 1.0 / entry_count);
    Self {
        name,
        variables,
        values,
    }
}
/// Rescales the table in place so that its entries sum to one.
///
/// The table is left untouched when the current sum is not strictly positive
/// (zero, negative, or NaN).
pub fn normalize(&mut self) {
    let total = self.values.iter().sum::<f64>();
    if total <= 0.0 || total.is_nan() {
        return;
    }
    self.values /= total;
}
/// Returns the length of the table axis labelled `var`, or `None` when this
/// factor has no variable with that name.
pub fn get_cardinality(&self, var: &str) -> Option<usize> {
    let axis = self.variables.iter().position(|name| name == var)?;
    Some(self.values.shape()[axis])
}
}
/// Names the kinds of operation that can be applied to factors.
///
/// NOTE(review): nothing in this file references this enum; the variant names
/// presumably correspond to `Factor::product`, the `Factor::marginalize_out*`
/// family and `Factor::divide` — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FactorOp {
    /// Factor multiplication.
    Product,
    /// Summing a variable out of a factor.
    Marginalize,
    /// Factor division.
    Divide,
}
impl Factor {
pub fn product(&self, other: &Factor) -> Result<Factor> {
let mut all_vars = self.variables.clone();
for v in &other.variables {
if !all_vars.contains(v) {
all_vars.push(v.clone());
}
}
let mut shape = Vec::new();
let mut self_mapping = Vec::new(); let mut other_mapping = Vec::new();
for var in &all_vars {
let self_idx_opt = self.variables.iter().position(|v| v == var);
let other_idx_opt = other.variables.iter().position(|v| v == var);
let cardinality = if let Some(self_idx) = self_idx_opt {
self_mapping.push(Some(self_idx));
self.values.shape()[self_idx]
} else if let Some(other_idx) = other_idx_opt {
self_mapping.push(None);
other.values.shape()[other_idx]
} else {
unreachable!("Variable must be in at least one factor");
};
if let Some(other_idx) = other_idx_opt {
other_mapping.push(Some(other_idx));
} else {
other_mapping.push(None);
}
shape.push(cardinality);
}
let mut result_values = ArrayD::zeros(shape.clone());
let total_size: usize = shape.iter().product();
for linear_idx in 0..total_size {
let mut assignment = Vec::new();
let mut temp_idx = linear_idx;
for &dim in shape.iter().rev() {
assignment.push(temp_idx % dim);
temp_idx /= dim;
}
assignment.reverse();
let self_idx: Vec<usize> = self_mapping
.iter()
.enumerate()
.filter_map(|(i, &opt)| opt.map(|_| assignment[i]))
.collect();
let other_idx: Vec<usize> = other_mapping
.iter()
.enumerate()
.filter_map(|(i, &opt)| opt.map(|_| assignment[i]))
.collect();
let self_val = if self_idx.len() == self.variables.len() {
self.values[self_idx.as_slice()]
} else {
1.0
};
let other_val = if other_idx.len() == other.variables.len() {
other.values[other_idx.as_slice()]
} else {
1.0
};
result_values[assignment.as_slice()] = self_val * other_val;
}
Ok(Factor {
name: format!("{}*{}", self.name, other.name),
variables: all_vars,
values: result_values,
})
}
/// Sums the variable `var` out of this factor.
///
/// The result keeps the remaining variables in their original order and is
/// named `"{self.name}_marg"`.
///
/// # Errors
///
/// Returns [`PgmError::VariableNotFound`] when `var` is not one of this
/// factor's variables.
pub fn marginalize_out(&self, var: &str) -> Result<Factor> {
    use scirs2_core::ndarray::Axis;

    let axis = match self.variables.iter().position(|name| name == var) {
        Some(axis) => axis,
        None => return Err(PgmError::VariableNotFound(var.to_string())),
    };
    let summed = self.values.sum_axis(Axis(axis));

    let mut remaining = self.variables.clone();
    remaining.retain(|name| name != var);

    Ok(Factor {
        name: format!("{}_marg", self.name),
        variables: remaining,
        values: summed,
    })
}
/// Sums out each variable in `vars`, one after another, starting from a copy of `self`.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `marginalize_out`.
pub fn marginalize_out_vars(&self, vars: &[String]) -> Result<Factor> {
    vars.iter()
        .try_fold(self.clone(), |current, var| current.marginalize_out(var))
}
/// Sums out every variable of this factor that is not listed in `keep_vars`.
///
/// # Errors
///
/// Propagates any error from `marginalize_out_vars`.
pub fn marginalize_out_all_except(&self, keep_vars: &[String]) -> Result<Factor> {
    let mut to_drop: Vec<String> = Vec::new();
    for name in &self.variables {
        if !keep_vars.contains(name) {
            to_drop.push(name.clone());
        }
    }
    self.marginalize_out_vars(&to_drop)
}
/// Eliminates `var` by taking, for every assignment of the other variables,
/// the maximum over `var`'s axis.
///
/// The maximum is folded from `f64::NEG_INFINITY` using `f64::max`. The result
/// keeps the remaining variables in order and is named `"{self.name}_max"`.
///
/// # Errors
///
/// Returns [`PgmError::VariableNotFound`] when `var` is not one of this
/// factor's variables.
pub fn maximize_out(&self, var: &str) -> Result<Factor> {
    use scirs2_core::ndarray::Axis;

    let axis = match self.variables.iter().position(|name| name == var) {
        Some(axis) => axis,
        None => return Err(PgmError::VariableNotFound(var.to_string())),
    };
    let maxima = self.values.map_axis(Axis(axis), |lane| {
        lane.iter().copied().fold(f64::NEG_INFINITY, f64::max)
    });

    let mut remaining = self.variables.clone();
    remaining.retain(|name| name != var);

    Ok(Factor {
        name: format!("{}_max", self.name),
        variables: remaining,
        values: maxima,
    })
}
/// Maximizes out each variable in `vars`, one after another, starting from a copy of `self`.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `maximize_out`.
pub fn maximize_out_vars(&self, vars: &[String]) -> Result<Factor> {
    vars.iter()
        .try_fold(self.clone(), |current, var| current.maximize_out(var))
}
/// Divides this factor element-wise by `other`.
///
/// Both factors must list the same variables in the same order and have tables of
/// the same shape. Any denominator entry whose absolute value is below `1e-10` is
/// replaced by `1e-10` before dividing, so a `0 / 0` entry evaluates to `0`. The
/// result is named `"{self.name}/{other.name}"`.
///
/// # Errors
///
/// Returns [`PgmError::InvalidDistribution`] when the variable lists differ, and
/// [`PgmError::DimensionMismatch`] when the variable lists match but the table
/// shapes do not.
pub fn divide(&self, other: &Factor) -> Result<Factor> {
    // Smallest magnitude a denominator entry is allowed to have.
    const MIN_DENOMINATOR: f64 = 1e-10;

    if self.variables != other.variables {
        return Err(PgmError::InvalidDistribution(
            "Cannot divide factors with different variables".to_string(),
        ));
    }
    // Matching names do not guarantee matching cardinalities; without this check a
    // shape mismatch would panic or be broadcast inside the array division.
    if self.values.shape() != other.values.shape() {
        return Err(PgmError::DimensionMismatch {
            expected: self.values.shape().to_vec(),
            got: other.values.shape().to_vec(),
        });
    }

    let safe_denominator = other.values.mapv(|x| {
        if x.abs() < MIN_DENOMINATOR {
            MIN_DENOMINATOR
        } else {
            x
        }
    });
    Ok(Factor {
        name: format!("{}/{}", self.name, other.name),
        variables: self.variables.clone(),
        values: &self.values / &safe_denominator,
    })
}
/// Fixes `var` to the state index `value` and returns the slice of the table
/// over the remaining variables.
///
/// The result is named `"{self.name}_reduced"`.
///
/// # Errors
///
/// Returns [`PgmError::VariableNotFound`] when `var` is not one of this factor's
/// variables, and [`PgmError::InvalidDistribution`] when `value` is not smaller
/// than the length of `var`'s axis.
pub fn reduce(&self, var: &str, value: usize) -> Result<Factor> {
    use scirs2_core::ndarray::Axis;

    let axis = match self.variables.iter().position(|name| name == var) {
        Some(axis) => axis,
        None => return Err(PgmError::VariableNotFound(var.to_string())),
    };

    let cardinality = self.values.shape()[axis];
    if value >= cardinality {
        let message = format!(
            "Value {} out of bounds for variable {} with cardinality {}",
            value, var, cardinality
        );
        return Err(PgmError::InvalidDistribution(message));
    }

    let slice = self.values.index_axis(Axis(axis), value).to_owned();
    let mut remaining = self.variables.clone();
    remaining.retain(|name| name != var);

    Ok(Factor {
        name: format!("{}_reduced", self.name),
        variables: remaining,
        values: slice,
    })
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array;
#[test]
fn test_factor_creation() {
    // A 2x2 table labelled with two names satisfies the rank check in `Factor::new`.
    let names = vec!["x".to_string(), "y".to_string()];
    let table = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
        .expect("unwrap")
        .into_dyn();
    let built = Factor::new("f1".to_string(), names, table).expect("unwrap");
    assert_eq!(built.variables.len(), 2);
    assert_eq!(built.values.ndim(), 2);
}
#[test]
fn test_factor_normalize() {
    // Entries sum to 10 before normalisation and must sum to 1 afterwards.
    let names = vec!["x".to_string(), "y".to_string()];
    let table = Array::from_shape_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0])
        .expect("unwrap")
        .into_dyn();
    let mut subject = Factor::new("f1".to_string(), names, table).expect("unwrap");
    subject.normalize();
    let total = subject.values.iter().sum::<f64>();
    assert!((total - 1.0).abs() < 1e-10);
}
#[test]
fn test_uniform_factor() {
    // Single variable: three entries of 1/3 each.
    let factor = Factor::uniform("f1".to_string(), vec!["x".to_string()], 3);
    assert_eq!(factor.values.len(), 3);
    assert_eq!(factor.values.shape(), &[3]);
    assert!(factor
        .values
        .iter()
        .all(|&p| (p - 1.0 / 3.0).abs() < 1e-10));
    let sum: f64 = factor.values.iter().sum();
    assert!((sum - 1.0).abs() < 1e-10);

    // Two variables: nine entries of 1/9 each, exercising the multi-axis case.
    let joint = Factor::uniform(
        "f2".to_string(),
        vec!["x".to_string(), "y".to_string()],
        3,
    );
    assert_eq!(joint.values.shape(), &[3, 3]);
    assert!(joint
        .values
        .iter()
        .all(|&p| (p - 1.0 / 9.0).abs() < 1e-10));
    let joint_sum: f64 = joint.values.iter().sum();
    assert!((joint_sum - 1.0).abs() < 1e-10);
}
#[test]
fn test_factor_product() {
    let f1_values = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
        .expect("unwrap")
        .into_dyn();
    let f1 = Factor::new("f1".to_string(), vec!["x".to_string()], f1_values).expect("unwrap");
    let f2_values = Array::from_shape_vec(vec![2], vec![0.7, 0.3])
        .expect("unwrap")
        .into_dyn();
    let f2 = Factor::new("f2".to_string(), vec!["y".to_string()], f2_values).expect("unwrap");
    let product = f1.product(&f2).expect("unwrap");
    assert_eq!(product.variables.len(), 2);
    assert_eq!(product.values.shape(), &[2, 2]);
    let expected = 0.6 * 0.7 + 0.6 * 0.3 + 0.4 * 0.7 + 0.4 * 0.3;
    let actual: f64 = product.values.iter().sum();
    assert!((actual - expected).abs() < 1e-10);

    // The sum alone cannot detect misplaced entries; pin each one.
    // Axis 0 is "x" (from f1) and axis 1 is "y" (from f2).
    let by_entry = [[0.6 * 0.7, 0.6 * 0.3], [0.4 * 0.7, 0.4 * 0.3]];
    for x in 0..2 {
        for y in 0..2 {
            assert!((product.values[[x, y]] - by_entry[x][y]).abs() < 1e-10);
        }
    }
}
#[test]
fn test_factor_marginalize() {
    // Summing "y" out of [[0.1, 0.2], [0.3, 0.4]] leaves the row sums over "x".
    let names = vec!["x".to_string(), "y".to_string()];
    let table = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
        .expect("unwrap")
        .into_dyn();
    let joint = Factor::new("f1".to_string(), names, table).expect("unwrap");
    let over_x = joint.marginalize_out("y").expect("unwrap");
    assert_eq!(over_x.variables.len(), 1);
    assert_eq!(over_x.variables[0], "x");
    assert_eq!(over_x.values.shape(), &[2]);
    for (i, want) in [0.3, 0.7].iter().enumerate() {
        assert!((over_x.values[[i]] - *want).abs() < 1e-10);
    }
}
#[test]
fn test_factor_divide() {
    // Both operands are over the single variable "x"; every ratio is 2.
    let numerator = Factor::new(
        "f1".to_string(),
        vec!["x".to_string()],
        Array::from_shape_vec(vec![2], vec![0.6, 0.4])
            .expect("unwrap")
            .into_dyn(),
    )
    .expect("unwrap");
    let denominator = Factor::new(
        "f2".to_string(),
        vec!["x".to_string()],
        Array::from_shape_vec(vec![2], vec![0.3, 0.2])
            .expect("unwrap")
            .into_dyn(),
    )
    .expect("unwrap");
    let quotient = numerator.divide(&denominator).expect("unwrap");
    assert_eq!(quotient.variables.len(), 1);
    for i in 0..2 {
        assert!((quotient.values[[i]] - 2.0).abs() < 1e-10);
    }
}
#[test]
fn test_factor_reduce() {
    // Fixing "y" to state 1 selects the second column of [[0.1, 0.2], [0.3, 0.4]].
    let names = vec!["x".to_string(), "y".to_string()];
    let table = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
        .expect("unwrap")
        .into_dyn();
    let joint = Factor::new("f1".to_string(), names, table).expect("unwrap");
    let sliced = joint.reduce("y", 1).expect("unwrap");
    assert_eq!(sliced.variables.len(), 1);
    assert_eq!(sliced.variables[0], "x");
    for (i, want) in [0.2, 0.4].iter().enumerate() {
        assert!((sliced.values[[i]] - *want).abs() < 1e-10);
    }
}
#[test]
fn test_factor_product_with_shared_vars() {
    let f1_values = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
        .expect("unwrap")
        .into_dyn();
    let f1 = Factor::new(
        "f1".to_string(),
        vec!["x".to_string(), "y".to_string()],
        f1_values,
    )
    .expect("unwrap");
    let f2_values = Array::from_shape_vec(vec![2, 2], vec![0.5, 0.5, 0.5, 0.5])
        .expect("unwrap")
        .into_dyn();
    let f2 = Factor::new(
        "f2".to_string(),
        vec!["y".to_string(), "z".to_string()],
        f2_values,
    )
    .expect("unwrap");
    let product = f1.product(&f2).expect("unwrap");
    assert_eq!(product.variables.len(), 3);
    assert!(product.variables.contains(&"x".to_string()));
    assert!(product.variables.contains(&"y".to_string()));
    assert!(product.variables.contains(&"z".to_string()));

    // The shared variable "y" must appear once: f1's variables come first, then
    // the variable that only f2 has.
    assert_eq!(
        product.variables,
        vec!["x".to_string(), "y".to_string(), "z".to_string()]
    );
    assert_eq!(product.values.shape(), &[2, 2, 2]);

    // Every f2 entry is 0.5, so product[x, y, z] == f1[x, y] * 0.5.
    let f1_table = [[0.1, 0.2], [0.3, 0.4]];
    for x in 0..2 {
        for y in 0..2 {
            for z in 0..2 {
                assert!((product.values[[x, y, z]] - f1_table[x][y] * 0.5).abs() < 1e-10);
            }
        }
    }
}
#[test]
fn test_factor_maximize() {
    // Maximising over "y" keeps the larger entry of each row of [[0.1, 0.2], [0.3, 0.4]].
    let names = vec!["x".to_string(), "y".to_string()];
    let table = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
        .expect("unwrap")
        .into_dyn();
    let joint = Factor::new("f1".to_string(), names, table).expect("unwrap");
    let row_max = joint.maximize_out("y").expect("unwrap");
    assert_eq!(row_max.variables.len(), 1);
    assert_eq!(row_max.variables[0], "x");
    assert_eq!(row_max.values.shape(), &[2]);
    for (i, want) in [0.2, 0.4].iter().enumerate() {
        assert!((row_max.values[[i]] - *want).abs() < 1e-10);
    }
}
#[test]
fn test_factor_maximize_multiple() {
    let values =
        Array::from_shape_vec(vec![2, 2, 2], vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
            .expect("unwrap")
            .into_dyn();
    let factor = Factor::new(
        "f1".to_string(),
        vec!["x".to_string(), "y".to_string(), "z".to_string()],
        values,
    )
    .expect("unwrap");
    let maximized = factor
        .maximize_out_vars(&["y".to_string(), "z".to_string()])
        .expect("unwrap");
    assert_eq!(maximized.variables.len(), 1);
    assert_eq!(maximized.variables[0], "x");

    // For x = 0 the entries are 0.1..=0.4 and for x = 1 they are 0.5..=0.8, so the
    // maxima over both "y" and "z" are 0.4 and 0.8.
    assert_eq!(maximized.values.shape(), &[2]);
    assert!((maximized.values[[0]] - 0.4).abs() < 1e-10);
    assert!((maximized.values[[1]] - 0.8).abs() < 1e-10);
}
}