tensorlogic-quantrs-hooks
Probabilistic Graphical Model Integration for TensorLogic
Bridge between logic-based reasoning and probabilistic inference through factor graphs, belief propagation, and variational methods.
Overview
tensorlogic-quantrs-hooks enables probabilistic reasoning over TensorLogic expressions by converting logical rules into factor graphs and applying state-of-the-art inference algorithms. This crate seamlessly integrates with the QuantRS2 ecosystem for probabilistic programming.
Key Features
- TLExpr → Factor Graph Conversion: Automatic translation of logical expressions to PGM representations
- Exact Inference:
- Sum-product and max-product belief propagation for tree-structured graphs
- Junction tree algorithm for exact inference on arbitrary graphs
- Approximate Inference:
- Loopy BP: Message passing for graphs with cycles, with damping and convergence detection
- Variational Inference: Mean-field, Bethe approximation, and tree-reweighted BP
- Expectation Propagation (EP): Moment matching with site approximations for discrete and continuous variables
- MCMC Sampling: Gibbs sampling for approximate posterior computation
- QuantRS2 Integration:
- Distribution and model export to QuantRS format
- JSON serialization for ecosystem interoperability
- Information-theoretic utilities (mutual information, KL divergence)
- Parameter Learning:
- Maximum Likelihood Estimation (MLE) for discrete distributions
- Bayesian estimation with Dirichlet priors
- Baum-Welch algorithm (EM) for Hidden Markov Models
- Sequence Models:
- Linear-chain CRFs for sequence labeling with Viterbi decoding
- Feature functions (transition, emission, custom)
- Forward-backward algorithm for marginal probabilities
- Full SciRS2 Integration: All tensor operations use SciRS2 for performance and consistency
Installation
Add to your Cargo.toml:
[dependencies]
tensorlogic-quantrs-hooks = "0.1.0-alpha.1"
scirs2-core = "0.1.0-rc.2"  # For tensor operations
Quick Start
Basic Factor Graph Creation
use tensorlogic_quantrs_hooks::{Factor, FactorGraph};
use ndarray::Array;

// Create factor graph
let mut graph = FactorGraph::new();

// Add binary variables (probability values below are illustrative)
graph.add_variable_with_card("x", 2);
graph.add_variable_with_card("y", 2);

// Add factor P(x)
let px_values = Array::from_shape_vec(2, vec![0.6, 0.4])
    .unwrap()
    .into_dyn();
let px = Factor::new(vec!["x".to_string()], px_values).unwrap();
graph.add_factor(px).unwrap();

// Add factor P(y|x)
let pyx_values = Array::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap().into_dyn();
let pyx = Factor::new(vec!["x".to_string(), "y".to_string()], pyx_values).unwrap();
graph.add_factor(pyx).unwrap();
Converting TLExpr to Factor Graph
use tensorlogic_core::TLExpr; // TLExpr comes from the TensorLogic IR crate
use tensorlogic_quantrs_hooks::expr_to_factor_graph;

// Define logical expression: P(x) ∧ Q(x)
// (exact TLExpr constructor names depend on the TensorLogic IR API)
let expr = TLExpr::and(TLExpr::pred("P", vec!["x"]), TLExpr::pred("Q", vec!["x"]));

// Convert to factor graph
let graph = expr_to_factor_graph(&expr).unwrap();
println!("Variables: {}", graph.num_variables());
println!("Factors: {}", graph.num_factors());
Core Concepts
Factor Graphs
A factor graph is a bipartite graph with:
- Variable nodes: Represent random variables
- Factor nodes: Represent functions over subsets of variables
Variables:   X₁      X₂      X₃
             |  \   /  \    /
Factors:     φ₁   φ₂     φ₃
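Together, the factors define a joint distribution that factorizes over the graph:

P(X₁, ..., Xₙ) = (1/Z) ∏ₐ φₐ(Xₐ)

where Xₐ denotes the variables connected to factor φₐ and Z is the partition function obtained by summing the product of all factors over all joint assignments.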
Factors
Factors are functions φ(X₁, X₂, ..., Xₖ) → ℝ⁺ representing probabilities or potentials.
use tensorlogic_quantrs_hooks::Factor;
use ndarray::Array;

// Create a binary factor P(X, Y) (values illustrative)
let values = Array::from_shape_vec((2, 2), vec![0.3, 0.2, 0.1, 0.4]).unwrap().into_dyn();
let factor = Factor::new(vec!["X".to_string(), "Y".to_string()], values).unwrap();

// Normalize to sum to 1
let mut normalized = factor.clone();
normalized.normalize();
Factor Operations
Factor Product
Combine factors over different variable sets:
// φ₁(X) = [0.6, 0.4]
let f1_values = Array::from_shape_vec(2, vec![0.6, 0.4]).unwrap().into_dyn();
let f1 = Factor::new(vec!["X".to_string()], f1_values).unwrap();

// φ₂(Y) = [0.7, 0.3]
let f2_values = Array::from_shape_vec(2, vec![0.7, 0.3]).unwrap().into_dyn();
let f2 = Factor::new(vec!["Y".to_string()], f2_values).unwrap();

// φ₁(X) × φ₂(Y) = φ(X, Y)
let product = f1.product(&f2).unwrap();
// The result is a factor over {X, Y} with a 2×2 table
Marginalization
Sum out variables to compute marginals:
// φ(X, Y) → φ(X) = Σ_Y φ(X, Y)
// φ(X, Y) → φ(X) = Σ_Y φ(X, Y)
let values = Array::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap().into_dyn();
let factor = Factor::new(vec!["X".to_string(), "Y".to_string()], values).unwrap();
let marginal = factor.marginalize_out("Y").unwrap();
// Result: [0.1 + 0.2, 0.3 + 0.4] = [0.3, 0.7]
Factor Reduction (Evidence)
Condition on observed values:
// Observe Y = 1, compute P(X | Y=1)
let conditional = factor.reduce("Y", 1).unwrap();
// Result: [0.2, 0.4] (before normalization)
Factor Division
Compute message quotients:
let f1 = Factor::new(vec!["X".to_string()], Array::from_shape_vec(2, vec![0.6, 0.4]).unwrap().into_dyn()).unwrap();
let f2 = Factor::new(vec!["X".to_string()], Array::from_shape_vec(2, vec![0.3, 0.2]).unwrap().into_dyn()).unwrap();
let result = f1.divide(&f2).unwrap();
// Result: [0.6/0.3, 0.4/0.2] = [2.0, 2.0]
Inference Algorithms
1. Sum-Product Belief Propagation
Exact inference for tree-structured graphs, loopy BP for graphs with cycles.
use tensorlogic_quantrs_hooks::{InferenceEngine, MarginalizationQuery, SumProductAlgorithm};

// Create algorithm with custom parameters (max iterations, convergence tolerance)
let algorithm = SumProductAlgorithm::new(100, 1e-6);

// Create inference engine
let engine = InferenceEngine::new(Box::new(algorithm), graph);

// Compute marginal P(x)
let query = MarginalizationQuery { variable: "x".to_string() };
let marginal = engine.marginalize(&query).unwrap();
println!("P(x=0) = {}", marginal[0]);
println!("P(x=1) = {}", marginal[1]);
Loopy BP with Damping
For graphs with cycles, use damping to improve convergence:
let loopy_bp = SumProductAlgorithm::new(200, 1e-6).with_damping(0.5); // damping in [0, 1]
let engine = InferenceEngine::new(Box::new(loopy_bp), graph);
let result = engine.marginalize(&query).unwrap();
2. Max-Product Algorithm (MAP Inference)
Find the most probable assignment:
use tensorlogic_quantrs_hooks::MaxProductAlgorithm;

let max_product = MaxProductAlgorithm::default();
let engine = InferenceEngine::new(Box::new(max_product), graph);

// Compute MAP assignment (max-marginals per variable)
let marginals = engine.run().unwrap();

// Find most probable values (argmax of each max-marginal)
for (var, marginal) in &marginals {
    let map_state = marginal
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(state, _)| state);
    println!("{var}: MAP state = {map_state:?}");
}
3. Variational Inference (Mean-Field)
Scalable approximate inference using mean-field approximation:
use tensorlogic_quantrs_hooks::MeanFieldInference;

// Create mean-field inference engine (max iterations, tolerance)
let mean_field = MeanFieldInference::new(100, 1e-6);

// Run inference
let result = mean_field.infer(&graph).unwrap();

// Access variational parameters
for (var, q) in &result.variational_params {
    println!("q({var}) = {q:?}");
}

// Check ELBO for convergence
println!("ELBO: {}", result.elbo);
println!("Converged: {}", result.converged);
ELBO Monitoring
let mut elbo_history = Vec::new();
let mean_field = MeanFieldInference::default()
    .with_callback(|iteration, elbo| elbo_history.push((iteration, elbo))); // record ELBO per iteration
3.1. Structured Variational Inference
Beyond mean-field, structured variational methods leverage the factor graph structure for improved accuracy:
Bethe Approximation
Uses the graph structure to define a structured approximation (equivalent to loopy BP fixed points):
use tensorlogic_quantrs_hooks::BetheApproximation;

// Create Bethe approximation engine (max iterations, tolerance)
let bethe = BetheApproximation::new(100, 1e-6);

// Run inference
let beliefs = bethe.run(&graph)?;

// Compute factor beliefs from variable beliefs
let factor_beliefs = bethe.compute_factor_beliefs(&beliefs)?;

// Compute Bethe free energy
let free_energy = bethe.compute_free_energy(&beliefs, &factor_beliefs)?;
println!("Bethe free energy: {free_energy}");
Advantages over Mean-Field:
- Respects factor graph structure
- More accurate marginals for loopy graphs
- Similar computational cost to loopy BP
Tree-Reweighted Belief Propagation (TRW-BP)
Provides upper bounds on the log partition function using edge reweighting:
use tensorlogic_quantrs_hooks::TreeReweightedBP;

// Create TRW-BP engine (max iterations, tolerance)
let mut trw = TreeReweightedBP::new(100, 1e-6);

// Optionally set custom edge weights (edge appearance probabilities ρ_e)
trw.set_edge_weight("x", "factor_0", 0.8);

// Or use uniform weights (default)
trw.initialize_uniform_weights();

// Run inference
let beliefs = trw.run(&graph)?;

// Compute upper bound on log Z
let log_z_bound = trw.compute_log_partition_upper_bound()?;
Key Properties:
- Provides upper bounds on the log partition function (see the bound below)
- Guaranteed convergence for convex tree mixtures
- Particularly robust for loopy graphs
- Uses edge appearance probabilities ρ_e ∈ [0,1]
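Concretely, TRW-BP decomposes the model parameters θ as a convex combination θ = Σ_T ρ(T) · θ(T) over spanning trees T, with edge appearance probabilities ρ_e = Σ_{T ∋ e} ρ(T). Because the log partition function is convex in the parameters, Jensen's inequality gives

log Z(θ) ≤ Σ_T ρ(T) · log Z(θ(T))

and each tree term log Z(θ(T)) is computable exactly by ordinary BP (Wainwright & Jordan, 2008).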
Comparison: Mean-Field vs. Bethe vs. TRW-BP
// Mean-Field: Fastest, assumes full independence
let mf = MeanFieldInference::default();
let mf_beliefs = mf.run(&graph)?;
let mf_elbo = mf.compute_elbo(&mf_beliefs)?;

// Bethe: Uses graph structure, more accurate
let bethe = BetheApproximation::default();
let bethe_beliefs = bethe.run(&graph)?;

// TRW-BP: Provides bounds, most robust
let mut trw = TreeReweightedBP::default();
let trw_beliefs = trw.run(&graph)?;
See examples/structured_variational.rs for a complete grid MRF comparison.
4. Gibbs Sampling
MCMC sampling for approximate marginals:
use tensorlogic_quantrs_hooks::GibbsSampler;

// Create sampler (number of samples, burn-in, thinning)
let sampler = GibbsSampler::new(10_000, 1_000, 2);

// Run sampling
let samples = sampler.sample(&graph).unwrap();

// Compute empirical marginals
let marginals = sampler.compute_marginals(&samples).unwrap();
for (var, marginal) in &marginals {
    println!("P({var}) ≈ {marginal:?}");
}
Sample Statistics
// Check acceptance rates and sample counts
let stats = sampler.get_statistics();
println!("{stats:?}"); // samples drawn, burn-in discarded, per-variable acceptance rates
5. Junction Tree Algorithm (Exact Inference)
The junction tree algorithm provides exact inference for any graph structure by constructing a tree of cliques:
use tensorlogic_quantrs_hooks::JunctionTree;

// Build junction tree from factor graph
let mut tree = JunctionTree::from_factor_graph(&graph)?;

// Calibrate the tree (message passing)
tree.calibrate()?;

// Query exact marginals
let p_x = tree.query_marginal("x")?;
println!("P(x=0) = {}", p_x[0]);
println!("P(x=1) = {}", p_x[1]);

// Query joint marginals
let p_xy = tree.query_joint_marginal(&["x", "y"])?;
Junction Tree Properties
// Check treewidth (complexity indicator)
let tw = tree.treewidth();
println!("Treewidth: {tw}");

// Verify running intersection property
assert!(tree.verify_running_intersection());

// Inspect clique structure
for (i, clique) in tree.cliques.iter().enumerate() {
    println!("Clique {i}: {clique:?}");
}
Advantages:
- Exact inference (no approximation error)
- Efficient for low-treewidth graphs
- Handles any query after single calibration
- Guarantees consistency across marginals
Complexity: O(n × d^(w+1)) where w is the treewidth, d is max domain size
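For a sense of scale: with n = 100 binary variables (d = 2) and treewidth w = 3, calibration touches on the order of 100 × 2⁴ = 1,600 clique-table entries, whereas w = 20 already implies roughly 100 × 2²¹ ≈ 2 × 10⁸ entries, so treewidth rather than variable count is the limiting factor.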
6. Expectation Propagation (EP)
EP approximates complex posteriors using moment matching with site approximations:
use tensorlogic_quantrs_hooks::ExpectationPropagation;

// Create EP algorithm (max iterations, tolerance, damping)
let ep = ExpectationPropagation::new(100, 1e-6, 0.5);

// Run EP inference
let marginals = ep.run(&graph)?;

// Access marginals
for (var, marginal) in &marginals {
    println!("P({var}) ≈ {marginal:?}");
}
Gaussian EP for Continuous Variables
For continuous variables, use Gaussian EP with natural parameterization:
use tensorlogic_quantrs_hooks::{GaussianEP, GaussianSite};

// Create Gaussian EP (max iterations, tolerance)
let gep = GaussianEP::new(100, 1e-6);

// Create Gaussian sites in natural parameterization
let site1 = GaussianSite::new(2.0, 4.0); // precision=2, precision_mean=4
let site2 = GaussianSite::new(1.0, 1.0);

// Combine sites
let product = site1.product(&site2);
println!("Combined site: {product:?}");
Key Features:
- Site approximations and cavity distributions
- Moment matching for discrete and continuous variables
- Damping for improved convergence
- Natural parameterization for Gaussians
7. Linear-chain CRFs (Sequence Labeling)
Linear-chain CRFs enable efficient sequence labeling with structured prediction:
use tensorlogic_quantrs_hooks::LinearChainCRF;
use ndarray::Array;

// Create CRF with 3 labels
let mut crf = LinearChainCRF::new(3);

// Set transition weights (3x3 matrix, values illustrative)
let transition_weights = Array::from_shape_vec((3, 3), vec![
    1.0, 0.2, 0.1,
    0.3, 1.0, 0.4,
    0.1, 0.5, 1.0,
]).unwrap();
crf.set_transition_weights(transition_weights)?;

// Viterbi decoding (most likely sequence)
let input_sequence = vec![0, 2, 1, 1, 0];
let (best_path, score) = crf.viterbi(&input_sequence)?;
println!("Best label sequence: {best_path:?} (score {score})");

// Compute marginal probabilities
let marginals = crf.marginals(&input_sequence)?;
for t in 0..input_sequence.len() {
    println!("t={t}: {:?}", marginals[t]);
}
Custom Feature Functions
Define custom features for domain-specific sequence labeling:
use tensorlogic_quantrs_hooks::FeatureFunction;

// Add feature with weight
// (MyCapitalizationFeature is a hypothetical user type implementing FeatureFunction)
let feature: Box<dyn FeatureFunction> = Box::new(MyCapitalizationFeature);
crf.add_feature(feature, 0.75);
Applications:
- Part-of-speech tagging
- Named entity recognition
- Speech recognition
- Bioinformatics (protein sequence analysis)
Algorithms:
- Viterbi: O(T × S²) for most likely sequence
- Forward-backward: O(T × S²) for marginals
- Where T = sequence length, S = number of states
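For a sense of scale: decoding a 40-token sentence with 45 part-of-speech tags costs about T × S² = 40 × 45² ≈ 81,000 transition evaluations per Viterbi or forward-backward pass.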
QuantRS2 Integration
Distribution Export
Convert factors to QuantRS2-compatible distributions for ecosystem integration:
use tensorlogic_quantrs_hooks::QuantRSDistribution;

// Export factor to QuantRS format
let factor = Factor::new(vec!["x".to_string()], values)?;
let dist_export = factor.to_quantrs_distribution()?;
println!("{dist_export:?}"); // distribution name, variables, and probability table
Model Export
Export entire factor graphs for use across the COOLJAPAN ecosystem:
use tensorlogic_quantrs_hooks::QuantRSModelExport;

// Export model to QuantRS2 format
let model_export = graph.to_quantrs_model()?;
println!("{model_export:?}"); // model name, variable list, and factor tables

// Get model statistics
let stats = graph.model_stats();
println!("{stats:?}"); // variable/factor counts, parameter totals
JSON Serialization
Export models as JSON for interoperability:
use tensorlogic_quantrs_hooks::utils::{export_to_json, import_from_json};

// Export to JSON
let json = export_to_json(&model_export)?;
println!("{json}");

// Import from JSON
let model = import_from_json(&json)?;
Information Theory
Compute information-theoretic quantities:
use tensorlogic_quantrs_hooks::utils::{kl_divergence, mutual_information};

// Mutual information I(X; Y) from a joint factor over X and Y
let mi = mutual_information(&joint_factor, "X", "Y")?;
println!("I(X; Y) = {mi}");

// KL divergence D(P || Q) between two factors over the same variables
let kl = kl_divergence(&p, &q)?;
println!("D(P || Q) = {kl}");
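Both quantities reduce to simple sums over tabulated probabilities. The following standalone sketch (independent of the crate's `utils` module) shows the underlying formulas for small discrete tables:

```rust
/// D(P || Q) = Σᵢ pᵢ · ln(pᵢ / qᵢ), in nats, for discrete distributions.
fn kl_divergence(p: &[f64], q: &[f64]) -> f64 {
    p.iter()
        .zip(q.iter())
        .filter(|(p_i, _)| **p_i > 0.0)
        .map(|(p_i, q_i)| p_i * (p_i / q_i).ln())
        .sum()
}

/// I(X; Y) = Σ_{x,y} p(x,y) · ln( p(x,y) / (p(x) · p(y)) ) from a joint table.
fn mutual_information(joint: &[Vec<f64>]) -> f64 {
    // Marginals p(x) and p(y) from the joint table.
    let px: Vec<f64> = joint.iter().map(|row| row.iter().sum()).collect();
    let py: Vec<f64> = (0..joint[0].len())
        .map(|y| joint.iter().map(|row| row[y]).sum())
        .collect();
    let mut mi = 0.0;
    for (x, row) in joint.iter().enumerate() {
        for (y, &pxy) in row.iter().enumerate() {
            if pxy > 0.0 {
                mi += pxy * (pxy / (px[x] * py[y])).ln();
            }
        }
    }
    mi
}
```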
Parameter Learning
Learn model parameters from observed data.
Maximum Likelihood Estimation
Estimate parameters from complete data (all variables observed):
use tensorlogic_quantrs_hooks::MaximumLikelihoodEstimator;
use std::collections::HashMap;

let estimator = MaximumLikelihoodEstimator::new();

// Create training data: 70 sunny (0) and 30 rainy (1) observations
let mut data = Vec::new();
for _ in 0..70 {
    data.push(HashMap::from([("Weather".to_string(), 0usize)]));
}
for _ in 0..30 {
    data.push(HashMap::from([("Weather".to_string(), 1usize)]));
}

// Estimate P(Weather)
let probs = estimator.estimate_marginal("Weather", &data, 2)?;
// Result: [0.7, 0.3]
Bayesian Estimation with Priors
Use Dirichlet priors for robust estimation:
use tensorlogic_quantrs_hooks::BayesianEstimator;

let estimator = BayesianEstimator::new(1.0); // Prior strength (Dirichlet concentration)

// Estimate with prior
let probs = estimator.estimate_marginal("Weather", &data, 2)?;
Baum-Welch Algorithm for HMMs
Learn HMM parameters from observation sequences (even when hidden states are not observed):
use tensorlogic_quantrs_hooks::{BaumWelchLearner, HiddenMarkovModel};

// Create an HMM with random initialization
let mut hmm = HiddenMarkovModel::new_random(2, 3); // 2 states, 3 observations

// Observation sequences (hidden states unknown; values illustrative)
let observation_sequences = vec![
    vec![0, 1, 2, 1, 0],
    vec![2, 2, 1, 0, 0],
];

// Learn parameters (max EM iterations, tolerance; prints progress per iteration)
let learner = BaumWelchLearner::with_verbose(100, 1e-4);
let log_likelihood = learner.learn(&mut hmm, &observation_sequences)?;
println!("Final log-likelihood: {log_likelihood}");
Key Features:
- Expectation-Maximization (EM) algorithm for HMMs
- Forward-backward message passing
- Automatic convergence detection
- Verbose mode for monitoring progress
Advanced Usage
Conditional Queries
Compute P(X | Y=y):
use tensorlogic_quantrs_hooks::ConditionalQuery;
use std::collections::HashMap;

// Evidence: Y = 1
let mut evidence = HashMap::new();
evidence.insert("Y".to_string(), 1);

let query = ConditionalQuery { variable: "X".to_string(), evidence };
let conditional = engine.conditional(&query).unwrap();
Custom Convergence Criteria
Multi-Variable Queries
// Compute joint marginal P(X, Y)
let vars_to_keep = vec!["X".to_string(), "Y".to_string()];
let joint_marginal = compute_joint_marginal(&graph, &vars_to_keep).unwrap();
Integration with TensorLogic
From TLExpr to Probabilities
use tensorlogic_core::TLExpr; // from the TensorLogic IR crate
use tensorlogic_quantrs_hooks::{expr_to_factor_graph, InferenceEngine, SumProductAlgorithm};

// Define logical rule: ∃x. P(x) ∧ Q(x)
// (exact TLExpr constructor names depend on the TensorLogic IR API)
let expr = TLExpr::exists("x", TLExpr::and(TLExpr::pred("P", vec!["x"]), TLExpr::pred("Q", vec!["x"])));

// Convert to factor graph
let graph = expr_to_factor_graph(&expr).unwrap();

// Run probabilistic inference
let algorithm = Box::new(SumProductAlgorithm::default());
let engine = InferenceEngine::new(algorithm, graph);
let marginals = engine.run().unwrap();
Probabilistic Logic Programming
// Weighted rules with confidence scores
// (rule_a and rule_b are TLExprs defined elsewhere; weights are illustrative)
let rules = vec![(rule_a, 0.9), (rule_b, 0.7)];

// Convert to factor graph with weights
let mut graph = FactorGraph::new();
for (rule, weight) in rules {
    // Each weighted rule contributes a factor whose potential is scaled by its confidence
    add_weighted_rule(&mut graph, &rule, weight); // hypothetical helper
}
Performance Considerations
Algorithm Selection Guide
| Graph Type | Recommended Algorithm | Complexity | Notes |
|---|---|---|---|
| Tree | Sum-Product | O(N × D²) | Exact inference |
| Low Treewidth | Junction Tree | O(N × D^(w+1)) | Exact, w = treewidth |
| Small Loopy | Loopy BP with damping | O(I × N × D²) | Approximate |
| Large Loopy | Mean-Field VI | O(I × N × D) | Fast approximate |
| Large Loopy (Structured) | Bethe / TRW-BP | O(I × E × D²) | Better accuracy |
| Complex Posteriors | Expectation Propagation | O(I × F × D²) | Moment matching |
| Sequence Labeling | Linear-chain CRF | O(T × S²) | Viterbi/Forward-backward |
| Any | Gibbs Sampling | O(S × N × D) | MCMC |
Where:
- N = number of variables
- D = max domain size
- I = iterations to converge
- S = number of samples
- E = number of edges
- F = number of factors
- w = treewidth
- T = sequence length
- S = number of states (for CRF)
Optimization Tips
- Use appropriate cardinalities: Smaller domains = faster inference
- Enable damping for loopy graphs: Improves convergence (see the sketch below)
- Tune convergence tolerance: Balance accuracy vs. speed
- Use variational inference for large graphs: O(N × D) per-iteration updates vs. O(N × D²) for BP messages
- Batch factor operations: Leverage SciRS2 vectorization
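Damping itself is a one-line change to the message update. A minimal standalone sketch (not this crate's API), where `damping` is the weight kept on the previous message:

```rust
/// Damped message update: keep a fraction `damping` of the old message and
/// blend in the freshly computed one. damping = 0.0 recovers plain BP.
fn damp_message(old: &[f64], fresh: &[f64], damping: f64) -> Vec<f64> {
    old.iter()
        .zip(fresh.iter())
        .map(|(o, f)| damping * o + (1.0 - damping) * f)
        .collect()
}
```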
Memory Usage
// Estimate memory for a factor graph (rough upper bound on factor-table storage)
let num_vars = graph.num_variables();
let num_factors = graph.num_factors();
let avg_cardinality: usize = 10;
let avg_arity: u32 = 2; // assumed average number of variables per factor
let bytes_per_entry = std::mem::size_of::<f64>();
let memory_mb = num_factors * avg_cardinality.pow(avg_arity) * bytes_per_entry / 1_000_000;
println!("{num_vars} variables, {num_factors} factors, ≈{memory_mb} MB of factor tables");
Examples
Example 1: Bayesian Network
// Classic cancer/smoking example (probabilities illustrative)
let mut graph = FactorGraph::new();

// Variables (all binary)
graph.add_variable_with_card("Smoking", 2);
graph.add_variable_with_card("Cancer", 2);
graph.add_variable_with_card("XRay", 2);

// Prior P(Smoking)
let p_smoking = Factor::new(
    vec!["Smoking".to_string()],
    Array::from_shape_vec(2, vec![0.8, 0.2]).unwrap().into_dyn(),
).unwrap();

// P(Cancer | Smoking)
let p_cancer_given_smoking = Factor::new(
    vec!["Smoking".to_string(), "Cancer".to_string()],
    Array::from_shape_vec((2, 2), vec![0.99, 0.01, 0.90, 0.10]).unwrap().into_dyn(),
).unwrap();

// P(XRay | Cancer)
let p_xray_given_cancer = Factor::new(
    vec!["Cancer".to_string(), "XRay".to_string()],
    Array::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap().into_dyn(),
).unwrap();

graph.add_factor(p_smoking).unwrap();
graph.add_factor(p_cancer_given_smoking).unwrap();
graph.add_factor(p_xray_given_cancer).unwrap();

// Query: P(Cancer | XRay=positive)
let mut evidence = HashMap::new();
evidence.insert("XRay".to_string(), 1);

let query = ConditionalQuery { variable: "Cancer".to_string(), evidence };
let engine = InferenceEngine::new(Box::new(SumProductAlgorithm::default()), graph);
let p_cancer_given_xray = engine.conditional(&query).unwrap();
println!("P(Cancer | XRay=positive) = {p_cancer_given_xray:?}");
Example 2: Markov Random Field
// 2x2 grid MRF for image denoising
let mut graph = FactorGraph::new();

// Pixel variables (binary)
for i in 0..4 {
    graph.add_variable_with_card(&format!("pixel_{i}"), 2);
}

// Pairwise potentials (smoothness): prefer neighbouring pixels to agree
let smoothness = Array::from_shape_vec((2, 2), vec![0.9, 0.1, 0.1, 0.9]).unwrap().into_dyn();

// Add edge factors for the four grid edges
let edges = vec![(0, 1), (0, 2), (1, 3), (2, 3)];
for (a, b) in edges {
    let vars = vec![format!("pixel_{a}"), format!("pixel_{b}")];
    graph.add_factor(Factor::new(vars, smoothness.clone()).unwrap()).unwrap();
}

// Add observation factors (noisy measurements)
// ... (implementation details)
Example 3: Hidden Markov Model
// Simple HMM: weather states predicting umbrella usage (probabilities illustrative)
let mut graph = FactorGraph::new();
let t_max = 5; // Time steps

// Hidden states (weather)
for t in 0..t_max {
    graph.add_variable_with_card(&format!("weather_{t}"), 2);
}

// Observations (umbrella)
for t in 0..t_max {
    graph.add_variable_with_card(&format!("umbrella_{t}"), 2);
}

// Initial state P(weather_0)
let initial_values = Array::from_shape_vec(2, vec![0.6, 0.4]).unwrap().into_dyn();
let initial = Factor::new(vec!["weather_0".to_string()], initial_values).unwrap();
graph.add_factor(initial).unwrap();

// Transition model P(weather_{t+1} | weather_t)
let transition = Array::from_shape_vec((2, 2), vec![0.7, 0.3, 0.3, 0.7]).unwrap().into_dyn();
for t in 0..t_max - 1 {
    let vars = vec![format!("weather_{t}"), format!("weather_{}", t + 1)];
    graph.add_factor(Factor::new(vars, transition.clone()).unwrap()).unwrap();
}

// Observation model P(umbrella_t | weather_t)
let observation = Array::from_shape_vec((2, 2), vec![0.1, 0.9, 0.8, 0.2]).unwrap().into_dyn();
for t in 0..t_max {
    let vars = vec![format!("weather_{t}"), format!("umbrella_{t}")];
    graph.add_factor(Factor::new(vars, observation.clone()).unwrap()).unwrap();
}

// Filtering:  P(weather_t | umbrella_0:t)
// Smoothing:  P(weather_t | umbrella_0:T)
// ... (inference implementation)
Testing
Run all tests:
cargo test
Run specific test suites:
cargo test factor        # Factor operations
cargo test message       # Message passing
cargo test inference     # Inference
cargo test variational   # Variational
Architecture
tensorlogic-quantrs-hooks
├── Factor Operations
│ ├── Product (×)
│ ├── Marginalization (Σ)
│ ├── Division (÷)
│ └── Reduction (evidence)
├── Factor Graphs
│ ├── Variable Nodes
│ ├── Factor Nodes
│ └── Adjacency Lists
├── Message Passing
│ ├── Sum-Product (marginals)
│ ├── Max-Product (MAP)
│ └── Convergence Detection
├── Variational Inference
│ ├── Mean-Field (fully factorized)
│ ├── Bethe Approximation (structured)
│ ├── Tree-Reweighted BP
│ └── ELBO/Free Energy Computation
└── Sampling
├── Gibbs Sampler
├── Burn-in/Thinning
└── Empirical Marginals
Contributing
See CONTRIBUTING.md for guidelines.
References
- Koller & Friedman, "Probabilistic Graphical Models" (2009)
- Wainwright & Jordan, "Graphical Models, Exponential Families, and Variational Inference" (2008)
- Bishop, "Pattern Recognition and Machine Learning" (2006), Chapter 8
License
Apache-2.0
Status: 🎉 Production Ready (v0.1.0-alpha.1)
Last Updated: 2025-11-07
Tests: 109 passing (100%: 96 unit + 13 integration)
Examples: 8 comprehensive examples
Completeness: ~99% (all medium-priority features complete!)
Features:
- Inference: 8 algorithms (Sum-Product, Max-Product, Junction Tree, Mean-Field, Bethe, TRW-BP, EP, Gibbs)
- Models: 5 types (Bayesian Networks, HMMs, MRFs, CRFs, Linear-chain CRFs)
- Learning: Parameter estimation, Baum-Welch EM
- Integration: QuantRS2 hooks, JSON export, information theory utilities
Part of: TensorLogic Ecosystem