use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::error::{PgmError, Result};
use crate::Factor;
/// Thread-safe recycling pool for `Vec<f64>` buffers, bucketed by exact length.
///
/// Buffers handed back through [`FactorPool::return_array`] are zeroed and retained (at most
/// `max_pool_size` per length) so that a later [`FactorPool::allocate`] of the same length
/// can reuse them instead of allocating a new buffer.
#[derive(Debug)]
pub struct FactorPool {
    /// Idle buffers, keyed by their length.
    pools: Mutex<HashMap<usize, Vec<Vec<f64>>>>,
    /// Running usage counters.
    stats: Mutex<PoolStats>,
    /// Maximum number of idle buffers kept for any single length.
    max_pool_size: usize,
}

/// Usage counters reported by [`FactorPool::stats`].
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// `allocate` calls served from an idle buffer.
    pub hits: usize,
    /// `allocate` calls that had to create a new buffer.
    pub misses: usize,
    /// Buffers accepted back into the pool.
    pub returns: usize,
    /// Largest value `current_bytes` has reached.
    pub peak_bytes: usize,
    /// Bytes (counted as `len * size_of::<f64>()`) currently held by idle buffers.
    pub current_bytes: usize,
}

impl Default for FactorPool {
    /// A pool that keeps at most 100 idle buffers per length.
    fn default() -> Self {
        Self::new(100)
    }
}

impl FactorPool {
    /// Creates an empty pool keeping at most `max_pool_size` idle buffers per length.
    pub fn new(max_pool_size: usize) -> Self {
        FactorPool {
            max_pool_size,
            stats: Mutex::new(PoolStats::default()),
            pools: Mutex::new(HashMap::new()),
        }
    }

    /// Returns a zero-filled buffer of exactly `size` elements, reusing an idle one when
    /// possible (counted as a hit) and allocating otherwise (counted as a miss).
    pub fn allocate(&self, size: usize) -> Vec<f64> {
        // Always lock `pools` before `stats` so every method agrees on the order.
        let mut idle = self.pools.lock().expect("lock should not be poisoned");
        let mut counters = self.stats.lock().expect("lock should not be poisoned");
        let reused = idle.get_mut(&size).and_then(|bucket| bucket.pop());
        match reused {
            Some(buffer) => {
                counters.hits += 1;
                counters.current_bytes -= Self::byte_len(size);
                buffer
            }
            None => {
                counters.misses += 1;
                vec![0.0; size]
            }
        }
    }

    /// Hands `array` back to the pool.
    ///
    /// The buffer is zeroed and retained only while fewer than `max_pool_size` idle buffers
    /// of its length are held; otherwise it is simply dropped.
    pub fn return_array(&self, mut array: Vec<f64>) {
        let len = array.len();
        let mut idle = self.pools.lock().expect("lock should not be poisoned");
        let mut counters = self.stats.lock().expect("lock should not be poisoned");
        let bucket = idle.entry(len).or_default();
        if bucket.len() >= self.max_pool_size {
            return;
        }
        // Zero now so `allocate` can hand the buffer out without touching it again.
        array.fill(0.0);
        bucket.push(array);
        counters.returns += 1;
        counters.current_bytes += Self::byte_len(len);
        counters.peak_bytes = counters.peak_bytes.max(counters.current_bytes);
    }

    /// Returns a snapshot of the usage counters.
    pub fn stats(&self) -> PoolStats {
        let guard = self.stats.lock().expect("lock should not be poisoned");
        PoolStats::clone(&guard)
    }

    /// Drops every idle buffer and resets `current_bytes`.
    ///
    /// The hit/miss/return/peak counters are left untouched.
    pub fn clear(&self) {
        let mut idle = self.pools.lock().expect("lock should not be poisoned");
        let mut counters = self.stats.lock().expect("lock should not be poisoned");
        idle.clear();
        counters.current_bytes = 0;
    }

    /// Fraction of `allocate` calls served from the pool, or `0.0` before any allocation.
    pub fn hit_rate(&self) -> f64 {
        let counters = self.stats.lock().expect("lock should not be poisoned");
        match counters.hits + counters.misses {
            0 => 0.0,
            total => counters.hits as f64 / total as f64,
        }
    }

    /// Byte count attributed to a buffer of `len` `f64`s.
    fn byte_len(len: usize) -> usize {
        len * std::mem::size_of::<f64>()
    }
}
#[derive(Debug, Clone)]
pub struct SparseFactor {
pub variables: Vec<String>,
pub cardinalities: Vec<usize>,
pub entries: Vec<(Vec<usize>, f64)>,
pub default_value: f64,
}
impl SparseFactor {
pub fn new(variables: Vec<String>, cardinalities: Vec<usize>) -> Self {
Self {
variables,
cardinalities,
entries: Vec::new(),
default_value: 0.0,
}
}
pub fn from_dense(factor: &Factor, threshold: f64) -> Self {
let shape: Vec<usize> = factor.values.shape().to_vec();
let mut sparse = Self::new(factor.variables.clone(), shape.clone());
sparse.default_value = 0.0;
let total_size: usize = shape.iter().product();
for i in 0..total_size {
let indices = Self::flat_to_indices(i, &shape);
let value = factor.values[indices.as_slice()];
if value.abs() > threshold {
sparse.entries.push((indices, value));
}
}
sparse
}
pub fn to_dense(&self) -> Result<Factor> {
let total_size: usize = self.cardinalities.iter().product();
let mut values = vec![self.default_value; total_size];
for (indices, value) in &self.entries {
let flat_idx = Self::indices_to_flat(indices, &self.cardinalities);
values[flat_idx] = *value;
}
let array = ArrayD::from_shape_vec(IxDyn(&self.cardinalities), values)?;
Factor::new("sparse".to_string(), self.variables.clone(), array)
}
pub fn get(&self, indices: &[usize]) -> f64 {
for (entry_indices, value) in &self.entries {
if entry_indices == indices {
return *value;
}
}
self.default_value
}
pub fn set(&mut self, indices: Vec<usize>, value: f64) {
for (entry_indices, entry_value) in &mut self.entries {
if *entry_indices == indices {
*entry_value = value;
return;
}
}
if (value - self.default_value).abs() > 1e-10 {
self.entries.push((indices, value));
}
}
pub fn sparsity(&self) -> f64 {
let total_size: usize = self.cardinalities.iter().product();
if total_size > 0 {
1.0 - (self.entries.len() as f64 / total_size as f64)
} else {
1.0
}
}
pub fn memory_savings(&self) -> f64 {
let dense_bytes = self.cardinalities.iter().product::<usize>() * std::mem::size_of::<f64>();
let sparse_bytes = self.entries.len()
* (self.variables.len() * std::mem::size_of::<usize>() + std::mem::size_of::<f64>());
if dense_bytes > 0 {
1.0 - (sparse_bytes as f64 / dense_bytes as f64)
} else {
0.0
}
}
fn flat_to_indices(flat: usize, shape: &[usize]) -> Vec<usize> {
let mut indices = vec![0; shape.len()];
let mut remaining = flat;
for i in (0..shape.len()).rev() {
indices[i] = remaining % shape[i];
remaining /= shape[i];
}
indices
}
fn indices_to_flat(indices: &[usize], shape: &[usize]) -> usize {
let mut flat = 0;
let mut stride = 1;
for i in (0..shape.len()).rev() {
flat += indices[i] * stride;
stride *= shape[i];
}
flat
}
}
/// A factor whose construction is deferred until the first [`LazyFactor::evaluate`] and
/// then cached.
///
/// Both the computation and the cache sit behind `Arc`, so clones share the same cache.
#[derive(Clone)]
pub struct LazyFactor {
    computation: Arc<dyn Fn() -> Result<Factor> + Send + Sync>,
    cached: Arc<Mutex<Option<Factor>>>,
}

impl std::fmt::Debug for LazyFactor {
    /// Reports whether a cached value is present.
    ///
    /// Uses `try_lock` so formatting never blocks, deadlocks or panics. It prints `<locked>`
    /// while the cache lock is held (for example during `evaluate`), and it still reads the
    /// cache if the lock was poisoned by a panicking computation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LazyFactor");
        match self.cached.try_lock() {
            Ok(guard) => out.field("cached", &guard.is_some()),
            Err(std::sync::TryLockError::Poisoned(poisoned)) => {
                out.field("cached", &poisoned.into_inner().is_some())
            }
            Err(std::sync::TryLockError::WouldBlock) => {
                out.field("cached", &format_args!("<locked>"))
            }
        };
        out.finish()
    }
}

impl LazyFactor {
    /// Wraps `computation` without running it. It runs on the first [`LazyFactor::evaluate`]
    /// and again after each [`LazyFactor::clear_cache`].
    pub fn new<F>(computation: F) -> Self
    where
        F: Fn() -> Result<Factor> + Send + Sync + 'static,
    {
        Self {
            computation: Arc::new(computation),
            cached: Arc::new(Mutex::new(None)),
        }
    }

    /// Wraps an already-computed factor; the cache starts populated.
    ///
    /// The computation re-yields a clone of `factor`, so the value is still available after
    /// [`LazyFactor::clear_cache`] instead of every later `evaluate` failing.
    pub fn from_factor(factor: Factor) -> Self {
        let retained = factor.clone();
        Self {
            computation: Arc::new(move || -> Result<Factor> { Ok(retained.clone()) }),
            cached: Arc::new(Mutex::new(Some(factor))),
        }
    }

    /// Returns a clone of the cached factor, computing and caching it first if needed.
    ///
    /// The cache lock is held while the computation runs. Concurrent callers therefore wait
    /// for the first one instead of recomputing. A computation that (directly or indirectly)
    /// evaluates this same `LazyFactor` will deadlock or panic on the re-entrant lock.
    ///
    /// # Errors
    /// Propagates the computation's error; nothing is cached in that case.
    pub fn evaluate(&self) -> Result<Factor> {
        let mut cached = self.cached.lock().expect("lock should not be poisoned");
        if let Some(ref factor) = *cached {
            return Ok(factor.clone());
        }
        let result = (self.computation)()?;
        *cached = Some(result.clone());
        Ok(result)
    }

    /// Returns `true` if a value is currently cached.
    pub fn is_computed(&self) -> bool {
        self.cached
            .lock()
            .expect("lock should not be poisoned")
            .is_some()
    }

    /// Discards the cached value so the next `evaluate` re-runs the computation.
    pub fn clear_cache(&self) {
        let mut cached = self.cached.lock().expect("lock should not be poisoned");
        *cached = None;
    }

    /// Lazily computes `a.product(b)`; neither input is evaluated until the result is.
    pub fn lazy_product(a: LazyFactor, b: LazyFactor) -> LazyFactor {
        LazyFactor::new(move || {
            let factor_a = a.evaluate()?;
            let factor_b = b.evaluate()?;
            factor_a.product(&factor_b)
        })
    }

    /// Lazily marginalizes `var` out of `factor`.
    pub fn lazy_marginalize(factor: LazyFactor, var: String) -> LazyFactor {
        LazyFactor::new(move || {
            let f = factor.evaluate()?;
            f.marginalize_out(&var)
        })
    }
}
/// A factor graph that stores factor *generators* instead of materialized factors, so each
/// factor is built only when it is consumed.
pub struct StreamingFactorGraph {
    variables: HashMap<String, VariableInfo>,
    factor_generators: Vec<Box<dyn Fn() -> Result<Factor> + Send + Sync>>,
    // Not read by any method shown here; kept so callers can supply a shared pool.
    #[allow(dead_code)]
    pool: Arc<FactorPool>,
}

/// Metadata recorded for each registered variable.
#[derive(Debug, Clone)]
// Fields are stored for callers/future use but not read by any method shown here.
#[allow(dead_code)]
struct VariableInfo {
    domain: String,
    cardinality: usize,
}

impl StreamingFactorGraph {
    /// Creates an empty graph backed by a fresh default [`FactorPool`].
    pub fn new() -> Self {
        Self::with_pool(Arc::new(FactorPool::default()))
    }

    /// Creates an empty graph that shares `pool`.
    pub fn with_pool(pool: Arc<FactorPool>) -> Self {
        StreamingFactorGraph {
            pool,
            factor_generators: Vec::new(),
            variables: HashMap::new(),
        }
    }

    /// Registers variable `name`, replacing any earlier registration under the same name.
    pub fn add_variable(&mut self, name: String, domain: String, cardinality: usize) {
        let info = VariableInfo {
            domain,
            cardinality,
        };
        self.variables.insert(name, info);
    }

    /// Appends a generator that produces one factor each time it is called.
    pub fn add_factor<F>(&mut self, generator: F)
    where
        F: Fn() -> Result<Factor> + Send + Sync + 'static,
    {
        let boxed: Box<dyn Fn() -> Result<Factor> + Send + Sync> = Box::new(generator);
        self.factor_generators.push(boxed);
    }

    /// Lazily yields each factor in insertion order, calling its generator on demand.
    pub fn stream_factors(&self) -> impl Iterator<Item = Result<Factor>> + '_ {
        self.factor_generators
            .iter()
            .map(|generate| generate())
    }

    /// Multiplies all factors together, generating them one at a time in insertion order.
    ///
    /// # Errors
    /// Stops at the first generator or product error, and returns
    /// `PgmError::InvalidDistribution` when the graph has no factors.
    pub fn streaming_product(&self) -> Result<Factor> {
        let product = self.factor_generators.iter().try_fold(
            None,
            |acc: Option<Factor>, generate| -> Result<Option<Factor>> {
                let next = generate()?;
                let merged = match acc {
                    None => next,
                    Some(prev) => prev.product(&next)?,
                };
                Ok(Some(merged))
            },
        )?;
        product.ok_or_else(|| PgmError::InvalidDistribution("No factors in graph".to_string()))
    }

    /// Number of distinct variable names registered.
    pub fn num_variables(&self) -> usize {
        self.variables.len()
    }

    /// Number of factor generators registered.
    pub fn num_factors(&self) -> usize {
        self.factor_generators.len()
    }
}

impl Default for StreamingFactorGraph {
    /// Same as [`StreamingFactorGraph::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Lossy factor representation that linearly quantizes every value to 16 bits over the
/// factor's `[min, max]` range.
#[derive(Debug, Clone)]
pub struct CompressedFactor {
    /// Variable names, one per axis.
    pub variables: Vec<String>,
    /// Dense shape of the original factor.
    pub cardinalities: Vec<usize>,
    /// Quantized values in the original factor's iteration order.
    quantized: Vec<u16>,
    /// Value that quantizes to `0`.
    min_value: f64,
    /// Value-space width of one quantization step.
    scale: f64,
}

impl CompressedFactor {
    /// Quantizes `factor`. Each value maps to `round((v - min) / scale)` with
    /// `scale = (max - min) / 65535`, or a step of `1.0` when all values are equal.
    pub fn from_factor(factor: &Factor) -> Self {
        let (min_value, max_value) = factor.values.iter().fold(
            (f64::INFINITY, f64::NEG_INFINITY),
            |(lo, hi), &v| (lo.min(v), hi.max(v)),
        );
        let span = max_value - min_value;
        let scale = if span > 0.0 {
            span / f64::from(u16::MAX)
        } else {
            1.0
        };
        let quantized = factor
            .values
            .iter()
            .map(|&v| ((v - min_value) / scale).round() as u16)
            .collect::<Vec<u16>>();
        CompressedFactor {
            cardinalities: factor.values.shape().to_vec(),
            variables: factor.variables.clone(),
            quantized,
            min_value,
            scale,
        }
    }

    /// Reconstructs an approximate dense factor named `"compressed"`.
    ///
    /// # Errors
    /// Propagates shape and `Factor::new` errors.
    pub fn to_factor(&self) -> Result<Factor> {
        let dequantized = self
            .quantized
            .iter()
            .map(|&q| self.min_value + f64::from(q) * self.scale)
            .collect::<Vec<f64>>();
        let shape = IxDyn(&self.cardinalities);
        let array = ArrayD::from_shape_vec(shape, dequantized)?;
        Factor::new(String::from("compressed"), self.variables.clone(), array)
    }

    /// Approximate byte footprint: the quantized payload plus the inline sizes of the
    /// metadata (variable `String` headers, cardinalities, `min_value` and `scale`).
    pub fn memory_size(&self) -> usize {
        let payload = self.quantized.len() * std::mem::size_of::<u16>();
        let names = self.variables.len() * std::mem::size_of::<String>();
        let dims = self.cardinalities.len() * std::mem::size_of::<usize>();
        let params = 2 * std::mem::size_of::<f64>();
        payload + names + dims + params
    }

    /// Ratio of `f64` payload bytes to `u16` payload bytes, or `1.0` for an empty factor.
    pub fn compression_ratio(&self) -> f64 {
        let count = self.quantized.len();
        if count == 0 {
            return 1.0;
        }
        let original = count * std::mem::size_of::<f64>();
        let compressed = count * std::mem::size_of::<u16>();
        original as f64 / compressed as f64
    }
}
/// Factor stored as a grid of dense hypercubic blocks of side `block_size`; only blocks
/// that contain at least one value above the construction threshold are kept.
#[derive(Debug, Clone)]
pub struct BlockSparseFactor {
    /// Variable names, one per axis.
    pub variables: Vec<String>,
    /// Number of states per variable (the dense shape).
    pub cardinalities: Vec<usize>,
    /// Side length of every block along each axis.
    pub block_size: usize,
    /// Stored blocks keyed by block coordinates; each holds `block_size^ndim` values in
    /// row-major local order.
    blocks: HashMap<Vec<usize>, Vec<f64>>,
    /// Value of every cell outside the stored blocks, and of edge-block padding.
    default_value: f64,
}

impl BlockSparseFactor {
    /// Creates an empty block-sparse factor (no stored blocks, default value `0.0`).
    pub fn new(variables: Vec<String>, cardinalities: Vec<usize>, block_size: usize) -> Self {
        Self {
            variables,
            cardinalities,
            block_size,
            blocks: HashMap::new(),
            default_value: 0.0,
        }
    }

    /// Tiles `factor` into blocks of side `block_size` and keeps every block containing at
    /// least one value with `|v| > threshold`.
    ///
    /// A `block_size` of `0` is treated as `1`, because a zero size cannot tile the factor
    /// and would divide by zero. Cells of edge blocks that fall outside the factor are
    /// padded with the default value and never count towards the threshold test.
    pub fn from_factor(factor: &Factor, block_size: usize, threshold: f64) -> Self {
        let block_size = block_size.max(1);
        let shape: Vec<usize> = factor.values.shape().to_vec();
        let ndim = shape.len();
        let mut sparse = Self::new(factor.variables.clone(), shape.clone(), block_size);

        let block_dims: Vec<usize> = shape.iter().map(|&d| d.div_ceil(block_size)).collect();
        let total_blocks: usize = block_dims.iter().product();
        // Every block has the same local extent, so compute it once rather than per cell.
        let local_shape = vec![block_size; ndim];
        let block_total: usize = local_shape.iter().product();
        // Scratch buffer for the global multi-index, reused for every cell.
        let mut global = vec![0usize; ndim];

        for block_flat in 0..total_blocks {
            let block_indices = SparseFactor::flat_to_indices(block_flat, &block_dims);
            let mut block_values = Vec::with_capacity(block_total);
            let mut has_significant = false;
            for local_flat in 0..block_total {
                let local_indices = SparseFactor::flat_to_indices(local_flat, &local_shape);
                let mut in_bounds = true;
                for (((slot, &bi), &li), &extent) in global
                    .iter_mut()
                    .zip(&block_indices)
                    .zip(&local_indices)
                    .zip(&shape)
                {
                    *slot = bi * block_size + li;
                    in_bounds &= *slot < extent;
                }
                let value = if in_bounds {
                    let v = factor.values[global.as_slice()];
                    has_significant |= v.abs() > threshold;
                    v
                } else {
                    // Overhang of an edge block: pad instead of duplicating border cells.
                    sparse.default_value
                };
                block_values.push(value);
            }
            if has_significant {
                sparse.blocks.insert(block_indices, block_values);
            }
        }
        sparse
    }

    /// Number of stored (non-empty) blocks.
    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    /// Fraction of grid blocks that are not stored (`1.0` for an empty shape).
    pub fn block_sparsity(&self) -> f64 {
        // Same zero-size guard as `from_factor`, so a factor built via `new(.., 0)` can't panic.
        let side = self.block_size.max(1);
        let total_blocks: usize = self
            .cardinalities
            .iter()
            .map(|&d| d.div_ceil(side))
            .product();
        if total_blocks > 0 {
            1.0 - (self.blocks.len() as f64 / total_blocks as f64)
        } else {
            1.0
        }
    }
}
/// Rough estimate of the memory used by a factor graph's factors, messages and marginals,
/// assuming every entry is an `f64`.
///
/// * factors: `num_factors` dense tables of `avg_cardinality ^ avg_scope_size` entries;
/// * messages: two vectors of length `avg_cardinality` per factor–variable edge (presumably
///   one per direction), with `num_factors * avg_scope_size` edges;
/// * marginals: one vector of length `avg_cardinality` per variable.
///
/// All arithmetic saturates at `usize::MAX` instead of overflowing, because large scopes
/// make `avg_cardinality ^ avg_scope_size` explode quickly.
pub fn estimate_memory_usage(
    num_variables: usize,
    avg_cardinality: usize,
    num_factors: usize,
    avg_scope_size: usize,
) -> MemoryEstimate {
    let bytes_per_entry = std::mem::size_of::<f64>();
    // Clamp instead of `as u32` truncation; any exponent this large saturates anyway.
    let exponent = avg_scope_size.min(u32::MAX as usize) as u32;
    let avg_factor_size = avg_cardinality.saturating_pow(exponent);
    let factor_bytes = num_factors
        .saturating_mul(avg_factor_size)
        .saturating_mul(bytes_per_entry);
    let edges = num_factors.saturating_mul(avg_scope_size);
    let message_bytes = 2usize
        .saturating_mul(edges)
        .saturating_mul(avg_cardinality)
        .saturating_mul(bytes_per_entry);
    let marginal_bytes = num_variables
        .saturating_mul(avg_cardinality)
        .saturating_mul(bytes_per_entry);
    MemoryEstimate {
        factor_bytes,
        message_bytes,
        marginal_bytes,
        total_bytes: factor_bytes
            .saturating_add(message_bytes)
            .saturating_add(marginal_bytes),
    }
}

/// Byte breakdown returned by [`estimate_memory_usage`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryEstimate {
    /// Bytes for the factor tables.
    pub factor_bytes: usize,
    /// Bytes for the messages.
    pub message_bytes: usize,
    /// Bytes for the marginals.
    pub marginal_bytes: usize,
    /// Saturating sum of the three components.
    pub total_bytes: usize,
}

impl std::fmt::Display for MemoryEstimate {
    /// Formats every figure in MiB (2^20 bytes) with two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let to_mb = |bytes: usize| bytes as f64 / 1_048_576.0;
        write!(
            f,
            "Memory Estimate: {:.2} MB total (factors: {:.2} MB, messages: {:.2} MB, marginals: {:.2} MB)",
            to_mb(self.total_bytes),
            to_mb(self.factor_bytes),
            to_mb(self.message_bytes),
            to_mb(self.marginal_bytes)
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_factor_pool_allocation() {
        let pool = FactorPool::new(10);
        let arr1 = pool.allocate(100);
        assert_eq!(arr1.len(), 100);
        assert_eq!(pool.stats().misses, 1);
        pool.return_array(arr1);
        assert_eq!(pool.stats().returns, 1);
        assert_eq!(pool.stats().current_bytes, 100 * std::mem::size_of::<f64>());
        let arr2 = pool.allocate(100);
        assert_eq!(arr2.len(), 100);
        // Reused buffers must come back zero-filled.
        assert!(arr2.iter().all(|&x| x == 0.0));
        assert_eq!(pool.stats().hits, 1);
        assert_eq!(pool.stats().current_bytes, 0);
    }

    #[test]
    fn test_factor_pool_hit_rate() {
        let pool = FactorPool::new(10);
        let arr = pool.allocate(50);
        pool.return_array(arr);
        let _ = pool.allocate(50);
        // One miss followed by one hit.
        assert_abs_diff_eq!(pool.hit_rate(), 0.5, epsilon = 1e-12);
    }

    #[test]
    fn test_sparse_factor_creation() {
        let mut sparse = SparseFactor::new(vec!["x".to_string()], vec![4]);
        sparse.set(vec![0], 1.0);
        sparse.set(vec![2], 0.5);
        assert_abs_diff_eq!(sparse.get(&[0]), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(&[1]), 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(&[2]), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_factor_from_dense() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string()],
            Array::from_vec(vec![0.0, 1.0, 0.0, 0.5]).into_dyn(),
        )
        .expect("unwrap");
        let sparse = SparseFactor::from_dense(&factor, 0.1);
        assert_eq!(sparse.entries.len(), 2);
        let dense = sparse.to_dense().expect("unwrap");
        assert_abs_diff_eq!(dense.values[[0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(dense.values[[1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(dense.values[[3]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_factor_sparsity() {
        let mut sparse = SparseFactor::new(vec!["x".to_string()], vec![100]);
        sparse.set(vec![50], 1.0);
        // One explicit entry out of 100 cells.
        assert_abs_diff_eq!(sparse.sparsity(), 0.99, epsilon = 1e-12);
    }

    #[test]
    fn test_lazy_factor_deferred() {
        let counter = Arc::new(Mutex::new(0));
        let counter_clone = counter.clone();
        let lazy = LazyFactor::new(move || {
            let mut count = counter_clone.lock().expect("unwrap");
            *count += 1;
            Factor::new(
                "test".to_string(),
                vec!["x".to_string()],
                Array::from_vec(vec![0.5, 0.5]).into_dyn(),
            )
        });
        assert!(!lazy.is_computed());
        assert_eq!(*counter.lock().expect("unwrap"), 0);
        let _ = lazy.evaluate().expect("unwrap");
        assert!(lazy.is_computed());
        assert_eq!(*counter.lock().expect("unwrap"), 1);
        // A second evaluation is served from the cache.
        let _ = lazy.evaluate().expect("unwrap");
        assert_eq!(*counter.lock().expect("unwrap"), 1);
    }

    #[test]
    fn test_lazy_factor_from_factor() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string()],
            Array::from_vec(vec![0.3, 0.7]).into_dyn(),
        )
        .expect("unwrap");
        let lazy = LazyFactor::from_factor(factor);
        assert!(lazy.is_computed());
        let result = lazy.evaluate().expect("unwrap");
        assert_abs_diff_eq!(result.values[[0]], 0.3, epsilon = 1e-10);
    }

    #[test]
    fn test_compressed_factor() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string()],
            Array::from_vec(vec![0.1, 0.2, 0.3, 0.4]).into_dyn(),
        )
        .expect("unwrap");
        let compressed = CompressedFactor::from_factor(&factor);
        let decompressed = compressed.to_factor().expect("unwrap");
        // 16-bit steps over a 0.3-wide range give an error far below 1e-4.
        for i in 0..4 {
            assert_abs_diff_eq!(factor.values[[i]], decompressed.values[[i]], epsilon = 1e-4);
        }
    }

    #[test]
    fn test_compressed_factor_ratio() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            ArrayD::from_elem(IxDyn(&[10, 10]), 0.5),
        )
        .expect("unwrap");
        let compressed = CompressedFactor::from_factor(&factor);
        // f64 (8 bytes) vs u16 (2 bytes) payload.
        assert_abs_diff_eq!(compressed.compression_ratio(), 4.0, epsilon = 1e-12);
    }

    #[test]
    fn test_streaming_factor_graph() {
        let mut graph = StreamingFactorGraph::new();
        graph.add_variable("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable("y".to_string(), "Binary".to_string(), 2);
        graph.add_factor(|| {
            Factor::new(
                "factor_x".to_string(),
                vec!["x".to_string()],
                Array::from_vec(vec![0.3, 0.7]).into_dyn(),
            )
        });
        graph.add_factor(|| {
            Factor::new(
                "factor_y".to_string(),
                vec!["y".to_string()],
                Array::from_vec(vec![0.4, 0.6]).into_dyn(),
            )
        });
        assert_eq!(graph.num_variables(), 2);
        assert_eq!(graph.num_factors(), 2);
        assert_eq!(graph.stream_factors().filter(|f| f.is_ok()).count(), 2);
    }

    #[test]
    fn test_memory_estimate() {
        let estimate = estimate_memory_usage(10, 3, 20, 3);
        // 20 factors * 3^3 entries * 8 bytes.
        assert_eq!(estimate.factor_bytes, 4320);
        // 2 * (20 * 3 edges) * 3 states * 8 bytes.
        assert_eq!(estimate.message_bytes, 2880);
        // 10 variables * 3 states * 8 bytes.
        assert_eq!(estimate.marginal_bytes, 240);
        assert_eq!(estimate.total_bytes, 7440);
    }

    #[test]
    fn test_block_sparse_factor() {
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            ArrayD::from_elem(IxDyn(&[8, 8]), 0.0),
        )
        .expect("unwrap");
        let block_sparse = BlockSparseFactor::from_factor(&factor, 4, 0.001);
        assert_eq!(block_sparse.num_blocks(), 0);
        assert_abs_diff_eq!(block_sparse.block_sparsity(), 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_block_sparse_factor_single_block() {
        let mut values = ArrayD::from_elem(IxDyn(&[8, 8]), 0.0);
        values[[5, 6]] = 1.0;
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            values,
        )
        .expect("unwrap");
        let block_sparse = BlockSparseFactor::from_factor(&factor, 4, 0.001);
        assert_eq!(block_sparse.num_blocks(), 1);
        assert_abs_diff_eq!(block_sparse.block_sparsity(), 0.75, epsilon = 1e-12);

        // 5x5 with block size 4: the only significant cell sits in the partial edge block.
        let mut edge_values = ArrayD::from_elem(IxDyn(&[5, 5]), 0.0);
        edge_values[[4, 4]] = 1.0;
        let edge_factor = Factor::new(
            "edge".to_string(),
            vec!["x".to_string(), "y".to_string()],
            edge_values,
        )
        .expect("unwrap");
        let edge_sparse = BlockSparseFactor::from_factor(&edge_factor, 4, 0.001);
        assert_eq!(edge_sparse.num_blocks(), 1);
    }
}