mod cache;
pub mod convergence;
pub mod dbn;
mod elimination_ordering;
mod error;
mod expectation_propagation;
mod factor;
pub mod factor_graph_viz;
mod graph;
mod inference;
pub mod influence;
mod junction_tree;
mod linear_chain_crf;
pub mod loopy_bp;
pub mod memory;
mod message_passing;
mod models;
mod parallel_message_passing;
pub mod parameter_learning;
pub mod quantrs_hooks;
pub mod quantum_circuit;
pub mod quantum_simulation;
mod sampling;
pub mod tensor_network_bridge;
mod variable_elimination;
mod variational;
pub mod vmp;
pub use cache::{CacheStats, CachedFactor, FactorCache};
pub use convergence::{
ConvergenceConfig, ConvergenceError, ConvergenceMonitor, ConvergenceState, DampingSchedule,
InferenceStats,
};
pub use dbn::{CoupledDBN, CouplingFactor, DBNBuilder, DynamicBayesianNetwork, TemporalVar};
pub use elimination_ordering::{EliminationOrdering, EliminationStrategy};
pub use error::{PgmError, Result};
pub use expectation_propagation::{ExpectationPropagation, GaussianEP, GaussianSite, Site};
pub use factor::{Factor, FactorOp};
pub use factor_graph_viz::{
render_ascii, render_dot, FactorGraphModel, FactorGraphStats, VizFactorNode, VizVariableNode,
};
pub use graph::{FactorGraph, FactorNode, VariableNode};
pub use inference::{ConditionalQuery, InferenceEngine, MarginalizationQuery};
pub use influence::{
InfluenceDiagram, InfluenceDiagramBuilder, InfluenceNode, MultiAttributeUtility, NodeType,
};
pub use junction_tree::{Clique, JunctionTree, JunctionTreeEdge, Separator};
pub use linear_chain_crf::{
EmissionFeature, FeatureFunction, IdentityFeature, LinearChainCRF, TransitionFeature,
};
pub use loopy_bp::{
bethe_free_energy, BetheFreeEnergy, CycleAnalysis, CycleDetector, LbpConvergenceMonitor,
LbpDampingPolicy, LbpIterStats, LogMessage, LoopyBeliefPropagation, LoopyBpConfig,
LoopyBpResult, UpdateSchedule,
};
pub use memory::{
BlockSparseFactor, CompressedFactor, FactorPool, LazyFactor, MemoryEstimate, PoolStats,
SparseFactor, StreamingFactorGraph,
};
pub use message_passing::{
ConvergenceStats, MaxProductAlgorithm, MessagePassingAlgorithm, SumProductAlgorithm,
};
pub use models::{BayesianNetwork, ConditionalRandomField, HiddenMarkovModel, MarkovRandomField};
pub use parallel_message_passing::{ParallelMaxProduct, ParallelSumProduct};
pub use parameter_learning::{
BaumWelchLearner, BayesianEstimator, MaximumLikelihoodEstimator, SimpleHMM,
};
pub use quantrs_hooks::{
AnnealingConfig, DistributionExport, DistributionMetadata, ModelExport, ModelStatistics,
QuantRSAssignment, QuantRSDistribution, QuantRSInferenceQuery, QuantRSModelExport,
QuantRSParameterLearning, QuantRSSamplingHook, QuantumAnnealing, QuantumInference,
QuantumSolution, QuantumSolutionMetadata,
};
pub use quantum_circuit::{
tlexpr_to_qaoa_circuit, IsingModel, QAOAConfig, QAOAResult, QUBOProblem, QuantumCircuitBuilder,
};
pub use quantum_simulation::{
run_qaoa, QuantumSimulationBackend, SimulatedState, SimulationConfig,
};
pub use sampling::{
Assignment, GibbsSampler, ImportanceSampler, LikelihoodWeighting, Particle, ParticleFilter,
ProposalDistribution, WeightedSample,
};
pub use tensor_network_bridge::{
factor_graph_to_tensor_network, linear_chain_to_mps, MatrixProductState, Tensor, TensorNetwork,
TensorNetworkStats,
};
pub use variable_elimination::VariableElimination;
pub use variational::{BetheApproximation, MeanFieldInference, TreeReweightedBP};
pub use vmp::{
categorical_kl, dirichlet_kl, gaussian_kl, gaussian_kl_fixed_precision,
BetaBernoulliObservation, BetaNP, CategoricalNP, DirichletNP, ExponentialFamily, Family,
GammaNP, GammaPoissonObservation, GaussianNP, MessageDirection, VariationalMessagePassing,
VariationalState, VmpConfig, VmpFactor, VmpMessage, VmpResult,
};
use scirs2_core::ndarray::ArrayD;
use std::collections::HashMap;
use tensorlogic_ir::TLExpr;
/// Builds a [`FactorGraph`] from a TensorLogic expression.
///
/// The graph starts empty and is populated by recursively walking `expr`, collecting
/// variables and predicate factors along the way.
///
/// # Errors
///
/// Propagates any error raised while adding a predicate factor to the graph.
pub fn expr_to_factor_graph(expr: &TLExpr) -> Result<FactorGraph> {
    let mut factor_graph = FactorGraph::new();
    extract_factors(expr, &mut factor_graph).map(|()| factor_graph)
}
/// Recursively walks `expr`, registering variables and predicate factors in `graph`.
///
/// Handled expression forms:
/// - `Pred`: each `Term::Var` argument is registered as a variable (domain `"default"`
///   unless the graph already knows it). If at least one variable argument exists, a
///   factor is added via `add_factor_from_predicate`. Non-variable arguments are ignored.
/// - `And` / `Imply`: the left side (or premise) is visited, then the right side
///   (or conclusion).
/// - `Exists` / `ForAll`: the bound variable is added with its declared domain, then the
///   body is visited.
/// - `Not`: the inner expression is visited.
///
/// Every other expression form contributes nothing to the graph.
///
/// # Errors
///
/// Propagates any error from `add_factor_from_predicate`.
fn extract_factors(expr: &TLExpr, graph: &mut FactorGraph) -> Result<()> {
    match expr {
        TLExpr::Pred { name, args } => {
            let var_names: Vec<String> = args
                .iter()
                .filter_map(|term| {
                    if let tensorlogic_ir::Term::Var(v) = term {
                        Some(v.clone())
                    } else {
                        None
                    }
                })
                .collect();
            if var_names.is_empty() {
                // A predicate over constants only registers no variables and no factor.
                return Ok(());
            }
            for var_name in &var_names {
                if graph.get_variable(var_name).is_none() {
                    graph.add_variable(var_name.clone(), "default".to_string());
                }
            }
            graph.add_factor_from_predicate(name, &var_names)?;
            Ok(())
        }
        TLExpr::And(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
            extract_factors(lhs, graph)?;
            extract_factors(rhs, graph)
        }
        TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
            graph.add_variable(var.clone(), domain.clone());
            extract_factors(body, graph)
        }
        TLExpr::Not(inner) => extract_factors(inner, graph),
        // NOTE(review): any other connective is currently skipped entirely, including
        // its sub-expressions — confirm this is intended for every remaining variant.
        _ => Ok(()),
    }
}
/// Runs `algorithm` over `graph` and returns the map it produces.
///
/// This is a convenience wrapper around [`MessagePassingAlgorithm::run`]. The keys are
/// whatever names the algorithm reports (presumably variable names — confirm against
/// the algorithm implementations).
///
/// # Errors
///
/// Returns any error reported by `algorithm.run`.
pub fn message_passing_reduce(
    graph: &FactorGraph,
    algorithm: &dyn MessagePassingAlgorithm,
) -> Result<HashMap<String, ArrayD<f64>>> {
    let beliefs = algorithm.run(graph)?;
    Ok(beliefs)
}
/// Sums a joint distribution over a set of axes, producing a marginal.
///
/// Every axis listed in `axes_to_sum` is summed out except `variable_idx`. That axis is
/// always kept, even if it appears in the list. Surviving axes keep their original
/// relative order.
///
/// `axes_to_sum` refers to axes of the *input* array. It may be given in any order and
/// may contain duplicates; each distinct axis is summed exactly once.
///
/// # Panics
///
/// Panics if an axis in `axes_to_sum` (other than `variable_idx`) is not a valid axis of
/// `joint_distribution`.
pub fn marginalize(
    joint_distribution: &ArrayD<f64>,
    variable_idx: usize,
    axes_to_sum: &[usize],
) -> Result<ArrayD<f64>> {
    use scirs2_core::ndarray::Axis;

    let ndim = joint_distribution.ndim();
    let mut axes: Vec<usize> = axes_to_sum
        .iter()
        .copied()
        .filter(|&axis| axis != variable_idx)
        .collect();
    // Summing out an axis shifts every higher axis down by one. Axes must therefore be
    // removed from highest to lowest, and each only once, for the indices to keep
    // referring to the input array. Sorting and deduplicating guarantees this for any
    // caller order.
    axes.sort_unstable();
    axes.dedup();
    if let Some(&max_axis) = axes.last() {
        assert!(
            max_axis < ndim,
            "marginalize: axis {max_axis} out of bounds for array with {ndim} dimensions"
        );
    }

    let mut result = joint_distribution.clone();
    for &axis in axes.iter().rev() {
        result = result.sum_axis(Axis(axis));
    }
    Ok(result)
}
/// Conditions a joint distribution on observed values and renormalizes the result.
///
/// Each `(axis, value)` entry in `evidence` fixes the given axis of the *input* array
/// to index `value`, removing that axis from the output. The remaining axes keep their
/// relative order. If the selected slice has positive total mass, it is divided by that
/// mass so it sums to 1. Otherwise it is returned unnormalized.
///
/// # Panics
///
/// Panics if an evidence axis is not an axis of `joint_distribution`, or if an observed
/// value is out of range for its axis.
pub fn condition(
    joint_distribution: &ArrayD<f64>,
    evidence: &HashMap<usize, usize>,
) -> Result<ArrayD<f64>> {
    use scirs2_core::ndarray::Axis;

    // Slicing out an axis shifts every higher axis down by one. Evidence is therefore
    // applied from the highest axis to the lowest, so each key still names an axis of
    // the input array. HashMap iteration order is arbitrary, which is why the order is
    // made explicit here.
    let mut observed: Vec<(usize, usize)> = evidence
        .iter()
        .map(|(&axis, &value)| (axis, value))
        .collect();
    observed.sort_unstable_by_key(|&(axis, _)| std::cmp::Reverse(axis));

    // Slice through a view so that only the final conditioned slice is copied, not the
    // whole joint.
    let mut slice = joint_distribution.view();
    for (axis, value) in observed {
        slice = slice.index_axis_move(Axis(axis), value);
    }

    let mut result = slice.to_owned();
    let total: f64 = result.iter().sum();
    if total > 0.0 {
        result /= total;
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;
    use tensorlogic_ir::Term;

    #[test]
    fn test_expr_to_factor_graph() {
        // A single predicate over one variable must yield a non-empty graph.
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let graph = expr_to_factor_graph(&expr).expect("unwrap");
        assert!(!graph.is_empty());
    }

    #[test]
    fn test_marginalize_simple() {
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.25, 0.25, 0.25, 0.25])
            .expect("unwrap")
            .into_dyn();
        // Axis 0 is kept (even though listed), so only axis 1 is summed out.
        let marginal = marginalize(&joint, 0, &[0, 1]).expect("unwrap");
        assert_eq!(marginal.ndim(), 1);
        assert_eq!(marginal.shape(), &[2]);
        assert_abs_diff_eq!(marginal.sum(), 1.0, epsilon = 1e-10);
        for &p in marginal.iter() {
            assert_abs_diff_eq!(p, 0.5, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_condition_simple() {
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .expect("unwrap")
            .into_dyn();
        let mut evidence = HashMap::new();
        evidence.insert(1, 1);
        let conditional = condition(&joint, &evidence).expect("unwrap");
        assert_eq!(conditional.ndim(), 1);
        assert_abs_diff_eq!(conditional.sum(), 1.0, epsilon = 1e-10);
        // Observing axis 1 = 1 selects column [0.2, 0.4], normalized to [1/3, 2/3].
        let values: Vec<f64> = conditional.iter().copied().collect();
        assert_eq!(values.len(), 2);
        assert_abs_diff_eq!(values[0], 1.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(values[1], 2.0 / 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_condition_zero_mass_slice_is_left_unnormalized() {
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.0, 0.5, 0.0, 0.5])
            .expect("unwrap")
            .into_dyn();
        let mut evidence = HashMap::new();
        evidence.insert(1, 0);
        let conditional = condition(&joint, &evidence).expect("unwrap");
        assert_eq!(conditional.ndim(), 1);
        for &p in conditional.iter() {
            assert_abs_diff_eq!(p, 0.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_factor_graph_construction() {
        let mut graph = FactorGraph::new();
        graph.add_variable("x".to_string(), "Domain1".to_string());
        graph.add_variable("y".to_string(), "Domain2".to_string());
        assert_eq!(graph.num_variables(), 2);
    }
}