use scirs2_core::ndarray::{Array1, ArrayD};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::cycle::CycleAnalysis;
use super::energy::BetheFreeEnergy;
use super::types::{LbpConvergenceMonitor, LbpDampingPolicy, UpdateSchedule};
/// Settings for a loopy belief-propagation run.
///
/// Build one with [`LoopyBpConfig::new`] / [`Default::default`] and adjust it
/// with the `with_*` builder methods. The `Default` impl sets
/// `max_iterations = 200`, `tolerance = 1e-6`,
/// `damping = LbpDampingPolicy::Uniform(0.5)`,
/// `schedule = UpdateSchedule::Synchronous`, `compute_bethe = true` and
/// `seed = 42`.
///
/// NOTE(review): field order matters for order-sensitive serde formats
/// (e.g. bincode); do not reorder fields without considering persisted data.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopyBpConfig {
    /// Upper bound on the number of iterations the solver may run.
    pub max_iterations: usize,
    /// Convergence threshold; presumably compared against the change in
    /// messages/beliefs between iterations — TODO confirm the exact metric.
    pub tolerance: f64,
    /// Damping policy applied to message updates.
    pub damping: LbpDampingPolicy,
    /// Order in which messages are updated (e.g. `Synchronous`).
    pub schedule: UpdateSchedule,
    /// Whether to compute the Bethe free energy; presumably controls whether
    /// `LoopyBpResult::bethe` is `Some` — verify against the solver.
    pub compute_bethe: bool,
    /// Seed value; where it is consumed (e.g. randomized schedules) is not
    /// visible here — TODO confirm.
    pub seed: u64,
}
impl Default for LoopyBpConfig {
fn default() -> Self {
Self {
max_iterations: 200,
tolerance: 1e-6,
damping: LbpDampingPolicy::Uniform(0.5),
schedule: UpdateSchedule::Synchronous,
compute_bethe: true,
seed: 42,
}
}
}
impl LoopyBpConfig {
    /// Creates a configuration populated with the [`Default`] values.
    ///
    /// Equivalent to `LoopyBpConfig::default()`; provided as a conventional
    /// starting point for the `with_*` builder chain.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with `max_iterations` set to `n`.
    #[must_use]
    pub fn with_max_iterations(mut self, n: usize) -> Self {
        self.max_iterations = n;
        self
    }

    /// Returns `self` with the convergence `tolerance` set to `tol`.
    #[must_use]
    pub fn with_tolerance(mut self, tol: f64) -> Self {
        self.tolerance = tol;
        self
    }

    /// Returns `self` with the message `damping` policy replaced by `d`.
    #[must_use]
    pub fn with_damping(mut self, d: LbpDampingPolicy) -> Self {
        self.damping = d;
        self
    }

    /// Returns `self` with the update `schedule` replaced by `s`.
    #[must_use]
    pub fn with_schedule(mut self, s: UpdateSchedule) -> Self {
        self.schedule = s;
        self
    }

    /// Returns `self` with `compute_bethe` set to `compute`.
    ///
    /// Added so every public field can be configured fluently, matching the
    /// other `with_*` setters.
    #[must_use]
    pub fn with_compute_bethe(mut self, compute: bool) -> Self {
        self.compute_bethe = compute;
        self
    }

    /// Returns `self` with `seed` set to `seed`.
    ///
    /// Added so every public field can be configured fluently, matching the
    /// other `with_*` setters.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }
}
/// Output of a loopy belief-propagation run.
#[derive(Clone, Debug)]
pub struct LoopyBpResult {
    /// One 1-D belief vector per key; presumably keyed by variable name with
    /// one entry per variable state — TODO confirm against the solver.
    pub beliefs: HashMap<String, Array1<f64>>,
    /// One N-dimensional belief array per key; presumably keyed by factor
    /// name with one axis per factor scope variable — TODO confirm.
    pub factor_beliefs: HashMap<String, ArrayD<f64>>,
    /// Convergence tracking state recorded during the run.
    pub convergence: LbpConvergenceMonitor,
    /// Bethe free energy, if computed; presumably `None` when
    /// `LoopyBpConfig::compute_bethe` is `false` — verify.
    pub bethe: Option<BetheFreeEnergy>,
    /// Cycle analysis associated with the run's graph.
    pub cycle_analysis: CycleAnalysis,
}