// numrs2/new_modules/probabilistic/mod.rs

//! # Probabilistic Programming Module
//!
//! This module provides comprehensive probabilistic programming infrastructure for NumRS2,
//! including advanced inference algorithms, Bayesian utilities, and probabilistic graphical models.
//!
//! ## Overview
//!
//! The probabilistic module offers production-ready implementations of:
//!
//! - **MCMC Inference**: Metropolis-Hastings, Gibbs sampling, Hamiltonian Monte Carlo (HMC),
//!   No-U-Turn Sampler (NUTS), and Parallel Tempering
//! - **Variational Inference**: Mean-field VI, Automatic Differentiation VI (ADVI), ELBO optimization
//! - **Bayesian Utilities**: Conjugate priors, posterior computation, model comparison (BIC, DIC, WAIC),
//!   credible intervals, hypothesis testing
//! - **Graphical Models**: Bayesian networks, Markov Random Fields, Hidden Markov Models,
//!   Gaussian Processes
//! - **Extended Distributions**: Beta, Gamma, Dirichlet, Student's t, Wishart, Inverse-Wishart,
//!   Von Mises, and more
//!
//! ## Mathematical Background
//!
//! ### Bayesian Inference
//!
//! Bayesian inference provides a principled framework for updating beliefs about parameters θ
//! given observed data D using Bayes' theorem:
//!
//! ```text
//! p(θ|D) = p(D|θ)p(θ) / p(D)
//! ```
//!
//! where:
//! - p(θ|D) is the posterior distribution
//! - p(D|θ) is the likelihood
//! - p(θ) is the prior distribution
//! - p(D) is the marginal likelihood (evidence)
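/*!

As a concrete instance of Bayes' theorem, the sketch below works through a conjugate update:
a Beta(α, β) prior on a success probability combined with k successes in n Bernoulli trials
gives a Beta(α + k, β + n - k) posterior in closed form. `BetaParams` and
`update_beta_binomial` are hypothetical names chosen for illustration and are not claimed
to match the API exported by this module.

```rust
/// Parameters of a Beta(alpha, beta) distribution (illustrative type only).
#[derive(Debug, Clone, Copy, PartialEq)]
struct BetaParams {
    alpha: f64,
    beta: f64,
}

impl BetaParams {
    /// Mean of the Beta distribution: alpha / (alpha + beta).
    fn mean(&self) -> f64 {
        self.alpha / (self.alpha + self.beta)
    }
}

/// Posterior after observing `successes` out of `trials` Bernoulli draws.
/// Conjugacy: Beta(a, b) prior + binomial likelihood -> Beta(a + k, b + n - k).
fn update_beta_binomial(prior: BetaParams, successes: u32, trials: u32) -> BetaParams {
    assert!(successes <= trials, "successes cannot exceed trials");
    BetaParams {
        alpha: prior.alpha + f64::from(successes),
        beta: prior.beta + f64::from(trials - successes),
    }
}
```

For instance, starting from a uniform Beta(1, 1) prior and observing 7 successes in 10 trials
yields Beta(8, 4), whose mean 8/12 ≈ 0.667 lies between the prior mean 0.5 and the observed rate 0.7.
*/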
//!
//! ### Markov Chain Monte Carlo (MCMC)
//!
//! When the posterior distribution cannot be computed analytically, MCMC methods construct
//! a Markov chain whose stationary distribution is the target posterior. Common algorithms include:
//!
//! - **Metropolis-Hastings**: Generic MCMC that accepts or rejects draws from a proposal distribution
//! - **Gibbs Sampling**: Samples each variable in turn from its full conditional distribution
//! - **Hamiltonian Monte Carlo**: Uses gradient information for efficient exploration
//! - **NUTS**: Adaptive HMC that chooses the trajectory length automatically, usually combined
//!   with step size adaptation during warm-up
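/*!

The accept/reject step shared by these samplers can be sketched as a one-dimensional
random-walk Metropolis sampler. This is a minimal illustration under stated assumptions, not
this module's implementation: the proposal is uniform and symmetric, so the Hastings correction
cancels, and `SplitMix64` is a small stand-in generator used only to keep the sketch free of
dependencies (code inside this module is expected to use `scirs2_core::random` instead).
The name `random_walk_metropolis` and its parameters are hypothetical.

```rust
/// Tiny SplitMix64 generator, used only so the sketch has no dependencies.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform draw in [0, 1) built from the top 53 bits.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Random-walk Metropolis for a 1-D target given by its unnormalized log-density.
/// Proposals are uniform on [x - width, x + width]; because that proposal is
/// symmetric, the acceptance probability reduces to min(1, p(x') / p(x)).
fn random_walk_metropolis<F>(
    log_density: F,
    initial: f64,
    width: f64,
    n_samples: usize,
    burn_in: usize,
    rng: &mut SplitMix64,
) -> Vec<f64>
where
    F: Fn(f64) -> f64,
{
    let mut x = initial;
    let mut log_p = log_density(x);
    let mut samples = Vec::with_capacity(n_samples);
    for iteration in 0..(burn_in + n_samples) {
        let proposal = x + width * (2.0 * rng.next_f64() - 1.0);
        let log_p_new = log_density(proposal);
        // Accept when log(u) < log p(x') - log p(x); working in log space avoids overflow.
        // 1 - next_f64() lies in (0, 1], so the logarithm is always finite.
        let log_u = (1.0 - rng.next_f64()).ln();
        if log_u < log_p_new - log_p {
            x = proposal;
            log_p = log_p_new;
        }
        // Discard the first `burn_in` states, keep the rest (including repeats on rejection).
        if iteration >= burn_in {
            samples.push(x);
        }
    }
    samples
}
```

Called with `|x| -0.5 * x * x` as the log-density, the returned draws should have a sample mean
near 0 and a sample variance near 1, since that log-density is a standard normal up to a constant.
*/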
//!
//! ### Variational Inference
//!
//! Variational inference approximates the posterior p(θ|D) with a simpler distribution q(θ)
//! by minimizing the Kullback-Leibler divergence:
//!
//! ```text
//! KL(q||p) = ∫ q(θ) log(q(θ)/p(θ|D)) dθ
//! ```
//!
//! Because log p(D) = ELBO + KL(q||p) and log p(D) does not depend on q, this is equivalent
//! to maximizing the Evidence Lower BOund (ELBO):
//!
//! ```text
//! ELBO = E_q[log p(D,θ)] - E_q[log q(θ)]
//! ```
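/*!

The relationship between the two objectives can be checked numerically on a parameter with
finitely many values, where the integrals become sums. The helper below is a hypothetical
sketch (the name `elbo_and_kl` and its slice-based signature are assumptions, not part of
this module): it evaluates the ELBO and the KL divergence for a given joint p(D, θ) and a
candidate q(θ), and returns log p(D) so the identity log p(D) = ELBO + KL(q||p) can be verified.

```rust
/// Returns (ELBO, KL(q || posterior), log evidence) for a discrete parameter.
/// `joint[i]` is p(D, theta_i) and `q[i]` is the variational probability of theta_i.
fn elbo_and_kl(joint: &[f64], q: &[f64]) -> (f64, f64, f64) {
    assert_eq!(joint.len(), q.len(), "joint and q must have the same length");
    // The evidence p(D) is the joint summed over all parameter values.
    let evidence: f64 = joint.iter().sum();
    let mut elbo = 0.0;
    let mut kl = 0.0;
    for (&p_joint, &q_i) in joint.iter().zip(q) {
        // Terms with q_i = 0 contribute nothing (0 * log 0 is taken as 0).
        if q_i > 0.0 {
            // ELBO = E_q[log p(D, theta)] - E_q[log q(theta)]
            elbo += q_i * (p_joint.ln() - q_i.ln());
            // KL(q || p(theta | D)), with posterior = joint / evidence
            kl += q_i * (q_i.ln() - (p_joint / evidence).ln());
        }
    }
    (elbo, kl, evidence.ln())
}
```

With `joint = [0.12, 0.28]` the evidence is 0.4 and the exact posterior is `[0.3, 0.7]`.
Any other choice of `q`, such as `[0.5, 0.5]`, gives a strictly positive KL term and an ELBO
below log 0.4, while the two values still sum to log 0.4.
*/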
//!
//! ## SCIRS2 Policy Compliance
//!
//! This module strictly follows SCIRS2 ecosystem policies:
//!
//! - **Random Number Generation**: ALWAYS use `scirs2_core::random` (NEVER direct rand/rand_distr)
//! - **Array Operations**: ALWAYS use `scirs2_core::ndarray` (NEVER direct ndarray)
//! - **Parallel Processing**: ALWAYS use `scirs2_core::parallel_ops` (NEVER direct rayon)
//! - **Statistical Functions**: Use `scirs2_stats` for statistical computations
//! - **Linear Algebra**: Use `scirs2_linalg` for matrix operations (Pure Rust via OxiBLAS)
//!
//! ## Usage Examples
//!
//! ### Example 1: Metropolis-Hastings Sampling
//!
//! ```rust,ignore
//! use numrs2::new_modules::probabilistic::{MetropolisHastings, GaussianProposal};
//! use scirs2_core::random::default_rng;
//!
//! // Define log-posterior function
//! let log_posterior = |theta: &[f64]| -> f64 {
//!     // Log-likelihood + log-prior
//!     -0.5 * theta[0].powi(2) // Standard normal prior
//! };
//!
//! // Create sampler with Gaussian proposal
//! let mut rng = default_rng();
//! let proposal = GaussianProposal::new(0.5)?; // Proposal std dev
//! let mut sampler = MetropolisHastings::new(log_posterior, proposal);
//!
//! // Run MCMC for 10,000 iterations
//! let initial_state = vec![0.0];
//! let samples = sampler.sample(&initial_state, 10000, 1000, &mut rng)?;
//! ```
//!
//! ### Example 2: Bayesian Linear Regression
//!
//! ```rust,ignore
//! use numrs2::new_modules::probabilistic::{BayesianLinearRegression, NormalInverseGammaPrior};
//! use numrs2::prelude::*;
//!
//! // Data: y = 2*x + 1 + noise
//! let x = linspace(0.0, 10.0, 100).reshape(&[100, 1]);
//! let y = x.multiply_scalar(2.0).add_scalar(1.0).add(&randn(&[100]));
//!
//! // Set up conjugate prior
//! let prior = NormalInverseGammaPrior::default();
//!
//! // Compute posterior
//! let posterior = prior.update(&x, &y)?;
//!
//! // Sample from posterior predictive
//! let x_new = linspace(10.0, 15.0, 50).reshape(&[50, 1]);
//! let y_pred = posterior.predict(&x_new)?;
//! ```
//!
//! ### Example 3: Model Comparison with WAIC
//!
//! ```rust,ignore
//! use numrs2::new_modules::probabilistic::{ModelComparison, waic};
//!
//! // Compute WAIC for model selection
//! let log_likelihood_samples = /* MCMC samples of log-likelihood */;
//! let waic_score = waic(&log_likelihood_samples)?;
//! println!("WAIC: {}", waic_score.waic);
//! println!("Effective parameters: {}", waic_score.p_waic);
//! ```
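/*!

For readers who want to see what such a score is made of, the following sketch computes WAIC
from a matrix of pointwise log-likelihoods using the standard definition
WAIC = -2 (lppd - p_waic), where lppd sums the log of the posterior-mean likelihood per data
point and p_waic sums the posterior sample variance of the log-likelihood per data point.
It is an assumption-laden illustration, not the module's `waic` function: the input layout
(`log_lik[s][i]` for draw `s` and observation `i`), the `WaicSketch` struct and the
`waic_sketch` name are all hypothetical.

```rust
/// Illustrative result type; the field names follow the example above.
struct WaicSketch {
    waic: f64,
    p_waic: f64,
    lppd: f64,
}

/// Computes WAIC from `log_lik[s][i]`, the log-likelihood of observation `i` under draw `s`.
/// Returns `None` for fewer than two draws or for ragged input.
fn waic_sketch(log_lik: &[Vec<f64>]) -> Option<WaicSketch> {
    let n_draws = log_lik.len();
    if n_draws < 2 {
        return None;
    }
    let n_points = log_lik[0].len();
    if log_lik.iter().any(|row| row.len() != n_points) {
        return None;
    }
    let s = n_draws as f64;
    let mut lppd = 0.0;
    let mut p_waic = 0.0;
    for i in 0..n_points {
        // Log of the mean likelihood over draws, shifted by the maximum for stability.
        let max = log_lik.iter().map(|row| row[i]).fold(f64::NEG_INFINITY, f64::max);
        let sum_exp: f64 = log_lik.iter().map(|row| (row[i] - max).exp()).sum();
        lppd += max + (sum_exp / s).ln();
        // Sample variance (denominator S - 1) of the log-likelihood over draws.
        let mean = log_lik.iter().map(|row| row[i]).sum::<f64>() / s;
        p_waic += log_lik.iter().map(|row| (row[i] - mean).powi(2)).sum::<f64>() / (s - 1.0);
    }
    Some(WaicSketch {
        waic: -2.0 * (lppd - p_waic),
        p_waic,
        lppd,
    })
}
```

When every draw assigns the same log-likelihood to each observation, the variance term vanishes
and WAIC reduces to -2 times the summed log-likelihood; lower WAIC indicates better expected
out-of-sample fit.
*/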
//!
//! ## Performance Considerations
//!
//! - **SIMD Optimization**: Distribution operations use SIMD when applicable
//! - **Parallel MCMC**: Multiple chains can run in parallel using `scirs2_core::parallel_ops`
//! - **Memory Efficiency**: Streaming algorithms for large-scale inference
//! - **Numerical Stability**: Log-space computations to prevent underflow
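/*!

The log-space point deserves a small illustration. Summing probabilities that are stored as
logarithms requires exponentiating them, which underflows to zero for very negative values.
The usual remedy is the log-sum-exp trick: subtract the maximum before exponentiating and add
it back afterwards. The function below is a generic sketch with a hypothetical name; it is not
claimed to be the helper this module uses internally.

```rust
/// Computes log(sum_i exp(values[i])) without underflow or overflow.
/// Returns negative infinity for an empty slice (the log of an empty sum).
fn log_sum_exp(values: &[f64]) -> f64 {
    let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // An infinite maximum means the result is that infinity; this also avoids inf - inf.
    if max.is_infinite() {
        return max;
    }
    // After shifting, the largest term is exp(0) = 1, so the sum cannot underflow to zero.
    let shifted_sum: f64 = values.iter().map(|v| (v - max).exp()).sum();
    max + shifted_sum.ln()
}
```

For two log-probabilities of -1000 each, the naive `((-1000.0f64).exp() * 2.0).ln()` evaluates
to negative infinity because `exp(-1000)` underflows to zero, whereas `log_sum_exp` returns
-1000 + ln 2.
*/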
//!
//! ## References
//!
//! - Gelman, A., et al. (2013). *Bayesian Data Analysis* (3rd ed.). Chapman and Hall/CRC.
//! - Neal, R. M. (2011). MCMC using Hamiltonian dynamics. *Handbook of Markov Chain Monte Carlo*.
//! - Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn Sampler. *Journal of Machine Learning Research*.
//! - Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians.
//!   *Journal of the American Statistical Association*.
//! - Bishop, C. M. (2006). *Pattern Recognition and Machine Learning*. Springer.

use crate::array::Array;
use crate::error::NumRs2Error;
use std::fmt;

// Module declarations
pub mod bayesian;
pub mod distributions;
pub mod graphical;
pub mod inference;

#[cfg(test)]
mod tests;

// Re-exports from submodules
pub use bayesian::*;
pub use distributions::*;
pub use graphical::*;
pub use inference::*;

/// Result type for probabilistic operations
pub type Result<T> = std::result::Result<T, ProbabilisticError>;

/// Comprehensive error type for probabilistic programming operations
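/**

The wrapping variant together with a `From` conversion lets the `?` operator lift a
lower-level error into this type, while `source()` keeps the original cause reachable.
The pattern can be sketched in isolation as follows. `LowLevelError`, `HighLevelError`,
`low_level_op` and `high_level_op` are hypothetical stand-ins invented for this sketch;
they are not the real `NumRs2Error` or `ProbabilisticError`.

```rust
use std::error::Error;
use std::fmt;

/// Stand-in for an error coming from a lower layer.
#[derive(Debug)]
struct LowLevelError(String);

impl fmt::Display for LowLevelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "low-level failure: {}", self.0)
    }
}

impl Error for LowLevelError {}

/// Stand-in for the higher-level error that wraps the lower-level one.
#[derive(Debug)]
enum HighLevelError {
    Integration { source: Box<LowLevelError> },
}

impl fmt::Display for HighLevelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            HighLevelError::Integration { source } => write!(f, "integration error: {}", source),
        }
    }
}

impl Error for HighLevelError {
    // Expose the wrapped error so callers can walk the cause chain.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            HighLevelError::Integration { source } => Some(&**source),
        }
    }
}

impl From<LowLevelError> for HighLevelError {
    fn from(error: LowLevelError) -> Self {
        HighLevelError::Integration { source: Box::new(error) }
    }
}

fn low_level_op(fail: bool) -> Result<u32, LowLevelError> {
    if fail {
        Err(LowLevelError("shape mismatch".to_string()))
    } else {
        Ok(1)
    }
}

fn high_level_op(fail: bool) -> Result<u32, HighLevelError> {
    // `?` converts LowLevelError into HighLevelError through the `From` impl above.
    let value = low_level_op(fail)?;
    Ok(value + 1)
}
```

Boxing the wrapped error, as the enum below also does, keeps the size of the outer error
small regardless of how large the inner error type is.
*/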
#[derive(Debug, Clone)]
pub enum ProbabilisticError {
    /// Invalid parameter value
    InvalidParameter { parameter: String, message: String },

    /// Dimension mismatch in array operations
    DimensionMismatch {
        expected: Vec<usize>,
        actual: Vec<usize>,
        operation: String,
    },

    /// Numerical error (overflow, underflow, NaN, etc.)
    NumericalError { message: String },

    /// Convergence failure in iterative algorithms
    ConvergenceError {
        algorithm: String,
        iterations: usize,
        message: String,
    },

    /// Invalid probability distribution
    InvalidDistribution {
        distribution: String,
        reason: String,
    },

    /// MCMC sampling error
    SamplingError {
        sampler: String,
        iteration: usize,
        message: String,
    },

    /// Variational inference error
    VariationalInferenceError { message: String },

    /// Graphical model error
    GraphicalModelError { model_type: String, message: String },

    /// Integration error with NumRS2
    NumRs2IntegrationError { source: Box<NumRs2Error> },

    /// Generic error with custom message
    Other { message: String },
}

impl fmt::Display for ProbabilisticError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProbabilisticError::InvalidParameter { parameter, message } => {
                write!(f, "Invalid parameter '{}': {}", parameter, message)
            }
            ProbabilisticError::DimensionMismatch {
                expected,
                actual,
                operation,
            } => {
                write!(
                    f,
                    "Dimension mismatch in {}: expected {:?}, got {:?}",
                    operation, expected, actual
                )
            }
            ProbabilisticError::NumericalError { message } => {
                write!(f, "Numerical error: {}", message)
            }
            ProbabilisticError::ConvergenceError {
                algorithm,
                iterations,
                message,
            } => {
                write!(
                    f,
                    "Convergence failure in {} after {} iterations: {}",
                    algorithm, iterations, message
                )
            }
            ProbabilisticError::InvalidDistribution {
                distribution,
                reason,
            } => {
                write!(f, "Invalid distribution '{}': {}", distribution, reason)
            }
            ProbabilisticError::SamplingError {
                sampler,
                iteration,
                message,
            } => {
                write!(
                    f,
                    "Sampling error in {} at iteration {}: {}",
                    sampler, iteration, message
                )
            }
            ProbabilisticError::VariationalInferenceError { message } => {
                write!(f, "Variational inference error: {}", message)
            }
            ProbabilisticError::GraphicalModelError {
                model_type,
                message,
            } => {
                write!(f, "Graphical model error in {}: {}", model_type, message)
            }
            ProbabilisticError::NumRs2IntegrationError { source } => {
                write!(f, "NumRS2 integration error: {}", source)
            }
            ProbabilisticError::Other { message } => {
                write!(f, "Probabilistic error: {}", message)
            }
        }
    }
}

impl std::error::Error for ProbabilisticError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ProbabilisticError::NumRs2IntegrationError { source } => Some(source),
            _ => None,
        }
    }
}

impl From<NumRs2Error> for ProbabilisticError {
    fn from(error: NumRs2Error) -> Self {
        ProbabilisticError::NumRs2IntegrationError {
            source: Box::new(error),
        }
    }
}

/// Helper function to validate probability values
///
/// Ensures that a probability is finite and lies in the valid range [0, 1].
/// NaN and infinite values are rejected before the range check.
///
/// # Arguments
///
/// * `p` - Probability value to validate
/// * `name` - Parameter name for error messages
///
/// # Returns
///
/// * `Ok(())` if the probability is valid
/// * `Err(ProbabilisticError::InvalidParameter)` if the probability is not finite or lies outside [0, 1]
///
/// # Examples
///
/// ```rust,ignore
/// validate_probability(0.5, "p")?; // OK
/// validate_probability(-0.1, "p")?; // Error: probability out of range
/// validate_probability(1.5, "p")?; // Error: probability out of range
/// validate_probability(f64::NAN, "p")?; // Error: probability must be finite
/// ```
pub fn validate_probability(p: f64, name: &str) -> Result<()> {
    if !p.is_finite() {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("probability must be finite, got {}", p),
        });
    }
    if !(0.0..=1.0).contains(&p) {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("probability must be in [0, 1], got {}", p),
        });
    }
    Ok(())
}

/// Helper function to validate positive parameter values
///
/// Ensures that a parameter is finite and strictly positive.
///
/// # Arguments
///
/// * `value` - Parameter value to validate
/// * `name` - Parameter name for error messages
///
/// # Returns
///
/// * `Ok(())` if the value is valid
/// * `Err(ProbabilisticError::InvalidParameter)` if the value is NaN, infinite, zero, or negative
pub fn validate_positive(value: f64, name: &str) -> Result<()> {
    if !value.is_finite() {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("value must be finite, got {}", value),
        });
    }
    if value <= 0.0 {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("value must be positive, got {}", value),
        });
    }
    Ok(())
}

/// Helper function to validate non-negative parameter values
///
/// Ensures that a parameter is finite and non-negative.
///
/// # Arguments
///
/// * `value` - Parameter value to validate
/// * `name` - Parameter name for error messages
///
/// # Returns
///
/// * `Ok(())` if the value is valid
/// * `Err(ProbabilisticError::InvalidParameter)` if the value is NaN, infinite, or negative
pub fn validate_non_negative(value: f64, name: &str) -> Result<()> {
    if !value.is_finite() {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("value must be finite, got {}", value),
        });
    }
    if value < 0.0 {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("value must be non-negative, got {}", value),
        });
    }
    Ok(())
}

/// Helper function to validate array shapes match
///
/// # Arguments
///
/// * `expected` - Expected shape
/// * `actual` - Actual shape
/// * `operation` - Operation name for error messages
///
/// # Returns
///
/// * `Ok(())` if shapes match
/// * `Err(ProbabilisticError)` if shapes don't match
pub fn validate_shape(expected: &[usize], actual: &[usize], operation: &str) -> Result<()> {
    if expected != actual {
        return Err(ProbabilisticError::DimensionMismatch {
            expected: expected.to_vec(),
            actual: actual.to_vec(),
            operation: operation.to_string(),
        });
    }
    Ok(())
}

#[cfg(test)]
mod module_tests {
    use super::*;

    #[test]
    fn test_validate_probability() {
        assert!(validate_probability(0.0, "p").is_ok());
        assert!(validate_probability(0.5, "p").is_ok());
        assert!(validate_probability(1.0, "p").is_ok());
        assert!(validate_probability(-0.1, "p").is_err());
        assert!(validate_probability(1.1, "p").is_err());
        assert!(validate_probability(f64::NAN, "p").is_err());
        assert!(validate_probability(f64::INFINITY, "p").is_err());
    }

    #[test]
    fn test_validate_positive() {
        assert!(validate_positive(0.1, "x").is_ok());
        assert!(validate_positive(1.0, "x").is_ok());
        assert!(validate_positive(100.0, "x").is_ok());
        assert!(validate_positive(0.0, "x").is_err());
        assert!(validate_positive(-1.0, "x").is_err());
        assert!(validate_positive(f64::NAN, "x").is_err());
    }

    #[test]
    fn test_validate_non_negative() {
        assert!(validate_non_negative(0.0, "x").is_ok());
        assert!(validate_non_negative(0.1, "x").is_ok());
        assert!(validate_non_negative(1.0, "x").is_ok());
        assert!(validate_non_negative(-0.1, "x").is_err());
        assert!(validate_non_negative(f64::NAN, "x").is_err());
    }

    #[test]
    fn test_validate_shape() {
        assert!(validate_shape(&[2, 3], &[2, 3], "test").is_ok());
        assert!(validate_shape(&[2], &[2], "test").is_ok());
        assert!(validate_shape(&[2, 3], &[3, 2], "test").is_err());
        assert!(validate_shape(&[2, 3], &[2], "test").is_err());
    }

    #[test]
    fn test_error_display() {
        let err = ProbabilisticError::InvalidParameter {
            parameter: "alpha".to_string(),
            message: "must be positive".to_string(),
        };
        let display = format!("{}", err);
        assert!(display.contains("alpha"));
        assert!(display.contains("positive"));
    }

    #[test]
    fn test_error_from_numrs2() {
        let numrs2_err = NumRs2Error::DimensionMismatch("expected 2x3, got 3x2".to_string());
        let prob_err: ProbabilisticError = numrs2_err.into();

        match prob_err {
            ProbabilisticError::NumRs2IntegrationError { .. } => {}
            _ => panic!("Expected NumRs2IntegrationError"),
        }
    }
}