//! # Probabilistic Programming Module
//!
//! This module provides comprehensive probabilistic programming infrastructure for NumRS2,
//! including advanced inference algorithms, Bayesian utilities, and probabilistic graphical models.
//!
//! ## Overview
//!
//! The probabilistic module offers production-ready implementations of:
//!
//! - **MCMC Inference**: Metropolis-Hastings, Gibbs sampling, Hamiltonian Monte Carlo (HMC),
//!   No-U-Turn Sampler (NUTS), and Parallel Tempering
//! - **Variational Inference**: Mean-field VI, Automatic Differentiation VI (ADVI), ELBO optimization
//! - **Bayesian Utilities**: Conjugate priors, posterior computation, model comparison (BIC, DIC, WAIC),
//!   credible intervals, hypothesis testing
//! - **Graphical Models**: Bayesian networks, Markov Random Fields, Hidden Markov Models,
//!   Gaussian Processes
//! - **Extended Distributions**: Beta, Gamma, Dirichlet, Student's t, Wishart, Inverse-Wishart,
//!   Von Mises, and more
//!
//! ## Mathematical Background
//!
//! ### Bayesian Inference
//!
//! Bayesian inference provides a principled framework for updating beliefs about parameters θ
//! given observed data D using Bayes' theorem:
//!
//! ```text
//! p(θ|D) = p(D|θ)p(θ) / p(D)
//! ```
//!
//! where:
//! - p(θ|D) is the posterior distribution
//! - p(D|θ) is the likelihood
//! - p(θ) is the prior distribution
//! - p(D) is the marginal likelihood (evidence)
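//!
//! As a concrete illustration of the theorem, the sketch below evaluates the posterior on a
//! discrete grid for a coin-flip model: a uniform prior over θ and a binomial likelihood for
//! 7 heads in 10 tosses. The helper `grid_posterior` and the data are assumptions made for this
//! example and are not part of this module's API.
//!
//! ```rust
//! /// Posterior over a grid of θ values for `heads` successes in `n` Bernoulli trials,
//! /// under a uniform prior. Hypothetical helper, for illustration only.
//! fn grid_posterior(grid: &[f64], heads: u32, n: u32) -> Vec<f64> {
//!     // Unnormalized posterior: likelihood p(D|θ) times prior p(θ) (constant here).
//!     let unnorm: Vec<f64> = grid
//!         .iter()
//!         .map(|&t| t.powi(heads as i32) * (1.0 - t).powi((n - heads) as i32))
//!         .collect();
//!     // Evidence p(D), approximated by the sum over grid points.
//!     let evidence: f64 = unnorm.iter().sum();
//!     unnorm.iter().map(|u| u / evidence).collect()
//! }
//!
//! fn main() {
//!     let grid: Vec<f64> = (0..=100).map(|i| i as f64 / 100.0).collect();
//!     let post = grid_posterior(&grid, 7, 10);
//!     let mean: f64 = grid.iter().zip(&post).map(|(t, p)| t * p).sum();
//!     println!("posterior mean = {mean:.3}");
//! }
//! ```
//!
//! With a uniform prior the exact posterior is Beta(8, 4), whose mean is 8/12, so the grid
//! estimate should land close to 0.667.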
//!
//! ### Markov Chain Monte Carlo (MCMC)
//!
//! When the posterior distribution cannot be computed analytically, MCMC methods construct
//! a Markov chain whose stationary distribution is the target posterior. Common algorithms include:
//!
//! - **Metropolis-Hastings**: General-purpose MCMC that accepts or rejects draws from a proposal
//!   distribution
//! - **Gibbs Sampling**: Samples each variable in turn from its full conditional distribution
//! - **Hamiltonian Monte Carlo**: Uses gradient information for efficient exploration
//! - **NUTS**: Adaptive HMC that chooses the trajectory length automatically; the step size is
//!   typically tuned during warm-up
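//!
//! To make the accept/reject step concrete, here is a minimal random-walk Metropolis-Hastings
//! sketch for a one-dimensional target. It is not this module's `MetropolisHastings` type: the
//! names `XorShift64` and `random_walk_mh` are hypothetical, and the toy generator exists only
//! to keep the example dependency-free (code in this crate should draw random numbers through
//! `scirs2_core::random`).
//!
//! ```rust
//! /// Tiny xorshift64 generator, used only to keep this sketch dependency-free.
//! /// The seed must be non-zero.
//! struct XorShift64(u64);
//!
//! impl XorShift64 {
//!     fn next_u64(&mut self) -> u64 {
//!         let mut x = self.0;
//!         x ^= x << 13;
//!         x ^= x >> 7;
//!         x ^= x << 17;
//!         self.0 = x;
//!         x
//!     }
//!
//!     /// Uniform draw in the open interval (0, 1).
//!     fn uniform(&mut self) -> f64 {
//!         ((self.next_u64() >> 12) as f64 + 0.5) / (1u64 << 52) as f64
//!     }
//!
//!     /// Standard normal draw via the Box-Muller transform.
//!     fn normal(&mut self) -> f64 {
//!         let (u1, u2) = (self.uniform(), self.uniform());
//!         (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
//!     }
//! }
//!
//! /// Random-walk Metropolis-Hastings with a symmetric Gaussian proposal.
//! /// Returns the chain and the fraction of accepted proposals.
//! fn random_walk_mh(
//!     log_target: impl Fn(f64) -> f64,
//!     init: f64,
//!     step: f64,
//!     n_samples: usize,
//!     rng: &mut XorShift64,
//! ) -> (Vec<f64>, f64) {
//!     let mut x = init;
//!     let mut log_p = log_target(x);
//!     let mut accepted = 0usize;
//!     let mut chain = Vec::with_capacity(n_samples);
//!     for _ in 0..n_samples {
//!         let candidate = x + step * rng.normal();
//!         let log_p_candidate = log_target(candidate);
//!         // Symmetric proposal: the Hastings correction cancels, leaving the density ratio.
//!         if rng.uniform().ln() < log_p_candidate - log_p {
//!             x = candidate;
//!             log_p = log_p_candidate;
//!             accepted += 1;
//!         }
//!         // A rejected proposal repeats the current state in the chain.
//!         chain.push(x);
//!     }
//!     (chain, accepted as f64 / n_samples as f64)
//! }
//!
//! fn main() {
//!     let mut rng = XorShift64(0x9E37_79B9_7F4A_7C15);
//!     // Target: standard normal, log-density up to an additive constant.
//!     let (chain, rate) = random_walk_mh(|x| -0.5 * x * x, 0.0, 1.0, 20_000, &mut rng);
//!     let mean = chain.iter().sum::<f64>() / chain.len() as f64;
//!     println!("acceptance rate = {rate:.2}, sample mean = {mean:.3}");
//! }
//! ```
//!
//! Working with log-densities means the target only needs to be known up to a normalizing
//! constant, which is why the evidence p(D) never has to be computed.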
//!
//! ### Variational Inference
//!
//! Variational inference approximates the posterior p(θ|D) with a simpler distribution q(θ)
//! by minimizing the Kullback-Leibler divergence:
//!
//! ```text
//! KL(q||p) = ∫ q(θ) log(q(θ)/p(θ|D)) dθ
//! ```
//!
//! Because log p(D) = ELBO + KL(q||p) and log p(D) does not depend on q, minimizing the
//! KL divergence is equivalent to maximizing the Evidence Lower BOund (ELBO):
//!
//! ```text
//! ELBO = E_q[log p(D,θ)] - E_q[log q(θ)]
//! ```
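//!
//! The relationship between the ELBO, the KL divergence, and the log evidence can be checked
//! numerically when θ takes only a few discrete values. The sketch below uses made-up prior and
//! likelihood values; the helpers `elbo` and `kl` are hypothetical and only mirror the two
//! formulas above with sums in place of integrals.
//!
//! ```rust
//! /// ELBO = E_q[log p(D,θ)] - E_q[log q(θ)] for a discrete θ.
//! fn elbo(q: &[f64], joint: &[f64]) -> f64 {
//!     q.iter().zip(joint).map(|(&qi, &ji)| qi * (ji.ln() - qi.ln())).sum()
//! }
//!
//! /// KL(q||p) = Σ q(θ) log(q(θ)/p(θ|D)) for a discrete θ.
//! fn kl(q: &[f64], p: &[f64]) -> f64 {
//!     q.iter().zip(p).map(|(&qi, &pi)| qi * (qi / pi).ln()).sum()
//! }
//!
//! fn main() {
//!     let prior = [0.5, 0.3, 0.2];
//!     let likelihood = [0.10, 0.40, 0.70]; // p(D|θ) for the observed D
//!     let joint: Vec<f64> = prior.iter().zip(&likelihood).map(|(p, l)| p * l).collect();
//!     let evidence: f64 = joint.iter().sum();
//!     let posterior: Vec<f64> = joint.iter().map(|j| j / evidence).collect();
//!
//!     let q = [0.2, 0.5, 0.3]; // an arbitrary approximating distribution
//!     let gap = evidence.ln() - (elbo(&q, &joint) + kl(&q, &posterior));
//!     println!("log p(D) - (ELBO + KL) = {gap:.2e}");
//!
//!     // With q equal to the exact posterior, KL = 0 and the ELBO equals log p(D).
//!     let best = elbo(&posterior, &joint);
//!     println!("ELBO at q = posterior: {:.6}, log p(D): {:.6}", best, evidence.ln());
//! }
//! ```
//!
//! Since the KL divergence is non-negative, the ELBO never exceeds log p(D), which is where
//! the name "lower bound" comes from.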
//!
//! ## SCIRS2 Policy Compliance
//!
//! This module strictly follows SCIRS2 ecosystem policies:
//!
//! - **Random Number Generation**: ALWAYS use `scirs2_core::random` (NEVER direct rand/rand_distr)
//! - **Array Operations**: ALWAYS use `scirs2_core::ndarray` (NEVER direct ndarray)
//! - **Parallel Processing**: ALWAYS use `scirs2_core::parallel_ops` (NEVER direct rayon)
//! - **Statistical Functions**: Use `scirs2_stats` for statistical computations
//! - **Linear Algebra**: Use `scirs2_linalg` for matrix operations (Pure Rust via OxiBLAS)
//!
//! ## Usage Examples
//!
//! ### Example 1: Metropolis-Hastings Sampling
//!
//! ```rust,ignore
//! use numrs2::new_modules::probabilistic::{MetropolisHastings, GaussianProposal};
//! use scirs2_core::random::default_rng;
//!
//! // Define log-posterior function
//! let log_posterior = |theta: &[f64]| -> f64 {
//!     // Log-likelihood + log-prior
//!     -0.5 * theta[0].powi(2) // Standard normal prior
//! };
//!
//! // Create sampler with Gaussian proposal
//! let mut rng = default_rng();
//! let proposal = GaussianProposal::new(0.5)?; // Proposal std dev
//! let mut sampler = MetropolisHastings::new(log_posterior, proposal);
//!
//! // Run MCMC for 10,000 iterations
//! let initial_state = vec![0.0];
//! let samples = sampler.sample(&initial_state, 10000, 1000, &mut rng)?;
//! ```
//!
//! ### Example 2: Bayesian Linear Regression
//!
//! ```rust,ignore
//! use numrs2::new_modules::probabilistic::{BayesianLinearRegression, NormalInverseGammaPrior};
//! use numrs2::prelude::*;
//!
//! // Data: y = 2*x + 1 + noise
//! let x = linspace(0.0, 10.0, 100).reshape(&[100, 1]);
//! let y = x.multiply_scalar(2.0).add_scalar(1.0).add(&randn(&[100]));
//!
//! // Set up conjugate prior
//! let prior = NormalInverseGammaPrior::default();
//!
//! // Compute posterior
//! let posterior = prior.update(&x, &y)?;
//!
//! // Sample from posterior predictive
//! let x_new = linspace(10.0, 15.0, 50).reshape(&[50, 1]);
//! let y_pred = posterior.predict(&x_new)?;
//! ```
//!
//! ### Example 3: Model Comparison with WAIC
//!
//! ```rust,ignore
//! use numrs2::new_modules::probabilistic::{ModelComparison, waic};
//!
//! // Compute WAIC for model selection
//! let log_likelihood_samples = /* MCMC samples of log-likelihood */;
//! let waic_score = waic(&log_likelihood_samples)?;
//! println!("WAIC: {}", waic_score.waic);
//! println!("Effective parameters: {}", waic_score.p_waic);
//! ```
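//!
//! For readers who want to see what such a score is made of, the following sketch computes WAIC
//! on the deviance scale from a matrix of pointwise log-likelihoods, using the usual definition
//! WAIC = -2 (lppd - p_waic). It is a standalone illustration, not the `waic` function exported
//! by this module: `WaicSketch`, `waic_sketch`, and the input layout (one `Vec` per posterior
//! draw) are assumptions.
//!
//! ```rust
//! /// Hypothetical result type mirroring the `waic` / `p_waic` fields used above.
//! struct WaicSketch {
//!     waic: f64,
//!     p_waic: f64,
//! }
//!
//! /// WAIC on the deviance scale from `log_lik[s][i]`, the log-likelihood of
//! /// observation `i` under posterior draw `s`.
//! /// Assumes at least two draws, each with the same number of observations.
//! fn waic_sketch(log_lik: &[Vec<f64>]) -> WaicSketch {
//!     let s = log_lik.len() as f64;
//!     let n_obs = log_lik[0].len();
//!     let (mut lppd, mut p_waic) = (0.0, 0.0);
//!     for i in 0..n_obs {
//!         let col: Vec<f64> = log_lik.iter().map(|draw| draw[i]).collect();
//!         // Log of the mean likelihood across draws, via log-sum-exp for stability.
//!         let max = col.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
//!         let sum_exp: f64 = col.iter().map(|v| (v - max).exp()).sum();
//!         lppd += max + sum_exp.ln() - s.ln();
//!         // Sample variance of the log-likelihood across draws.
//!         let mean = col.iter().sum::<f64>() / s;
//!         p_waic += col.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (s - 1.0);
//!     }
//!     WaicSketch { waic: -2.0 * (lppd - p_waic), p_waic }
//! }
//!
//! fn main() {
//!     // Three posterior draws, two observations (made-up numbers).
//!     let log_lik = vec![vec![-1.0, -2.0], vec![-1.2, -1.8], vec![-0.8, -2.2]];
//!     let score = waic_sketch(&log_lik);
//!     println!("WAIC: {:.4}, effective parameters: {:.4}", score.waic, score.p_waic);
//! }
//! ```
//!
//! Lower WAIC indicates better estimated out-of-sample predictive accuracy; `p_waic` acts as a
//! penalty that grows with the posterior variability of the pointwise log-likelihood.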
//!
//! ## Performance Considerations
//!
//! - **SIMD Optimization**: Distribution operations use SIMD when applicable
//! - **Parallel MCMC**: Multiple chains can run in parallel using `scirs2_core::parallel_ops`
//! - **Memory Efficiency**: Streaming algorithms for large-scale inference
//! - **Numerical Stability**: Log-space computations to prevent underflow
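//!
//! The log-space point can be illustrated with the log-sum-exp trick, which is the standard way
//! to add probabilities that are stored as logarithms. The function name `log_sum_exp` below is
//! an assumption made for this sketch; it is not claimed to be the helper this module uses
//! internally.
//!
//! ```rust
//! /// Computes log(Σ exp(x_i)) by factoring out the maximum, so the largest
//! /// exponentiated term is exp(0) = 1 and cannot underflow or overflow.
//! fn log_sum_exp(xs: &[f64]) -> f64 {
//!     let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
//!     if max == f64::NEG_INFINITY {
//!         // Empty input, or every term is exp(-inf) = 0.
//!         return f64::NEG_INFINITY;
//!     }
//!     max + xs.iter().map(|x| (x - max).exp()).sum::<f64>().ln()
//! }
//!
//! fn main() {
//!     let log_weights = [-1000.0, -1001.0, -1002.0];
//!     // Naive evaluation underflows: every exp(..) is 0.0, so the log is -inf.
//!     let naive = log_weights.iter().map(|x: &f64| x.exp()).sum::<f64>().ln();
//!     let stable = log_sum_exp(&log_weights);
//!     println!("naive = {naive}, stable = {stable:.4}");
//! }
//! ```
//!
//! The naive sum is lost to underflow, while the stable version recovers
//! -1000 + log(1 + e^-1 + e^-2).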
//!
//! ## References
//!
//! - Gelman, A., et al. (2013). *Bayesian Data Analysis* (3rd ed.). Chapman and Hall/CRC.
//! - Neal, R. M. (2011). MCMC using Hamiltonian dynamics. *Handbook of Markov Chain Monte Carlo*.
//! - Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn Sampler. *Journal of Machine Learning Research*.
//! - Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for
//!   Statisticians. *Journal of the American Statistical Association*.
//! - Bishop, C. M. (2006). *Pattern Recognition and Machine Learning*. Springer.

use crate::array::Array;
use crate::error::NumRs2Error;
use std::fmt;

// Module declarations
pub mod bayesian;
pub mod distributions;
pub mod graphical;
pub mod inference;

#[cfg(test)]
mod tests;

// Re-exports from submodules
pub use bayesian::*;
pub use distributions::*;
pub use graphical::*;
pub use inference::*;

/// Result type for probabilistic operations
pub type Result<T> = std::result::Result<T, ProbabilisticError>;

/// Comprehensive error type for probabilistic programming operations
#[derive(Debug, Clone)]
pub enum ProbabilisticError {
    /// Invalid parameter value
    InvalidParameter { parameter: String, message: String },

    /// Dimension mismatch in array operations
    DimensionMismatch {
        expected: Vec<usize>,
        actual: Vec<usize>,
        operation: String,
    },

    /// Numerical error (overflow, underflow, NaN, etc.)
    NumericalError { message: String },

    /// Convergence failure in iterative algorithms
    ConvergenceError {
        algorithm: String,
        iterations: usize,
        message: String,
    },

    /// Invalid probability distribution
    InvalidDistribution {
        distribution: String,
        reason: String,
    },

    /// MCMC sampling error
    SamplingError {
        sampler: String,
        iteration: usize,
        message: String,
    },

    /// Variational inference error
    VariationalInferenceError { message: String },

    /// Graphical model error
    GraphicalModelError { model_type: String, message: String },

    /// Integration error with NumRS2
    NumRs2IntegrationError { source: Box<NumRs2Error> },

    /// Generic error with custom message
    Other { message: String },
}

impl fmt::Display for ProbabilisticError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProbabilisticError::InvalidParameter { parameter, message } => {
                write!(f, "Invalid parameter '{}': {}", parameter, message)
            }
            ProbabilisticError::DimensionMismatch {
                expected,
                actual,
                operation,
            } => {
                write!(
                    f,
                    "Dimension mismatch in {}: expected {:?}, got {:?}",
                    operation, expected, actual
                )
            }
            ProbabilisticError::NumericalError { message } => {
                write!(f, "Numerical error: {}", message)
            }
            ProbabilisticError::ConvergenceError {
                algorithm,
                iterations,
                message,
            } => {
                write!(
                    f,
                    "Convergence failure in {} after {} iterations: {}",
                    algorithm, iterations, message
                )
            }
            ProbabilisticError::InvalidDistribution {
                distribution,
                reason,
            } => {
                write!(f, "Invalid distribution '{}': {}", distribution, reason)
            }
            ProbabilisticError::SamplingError {
                sampler,
                iteration,
                message,
            } => {
                write!(
                    f,
                    "Sampling error in {} at iteration {}: {}",
                    sampler, iteration, message
                )
            }
            ProbabilisticError::VariationalInferenceError { message } => {
                write!(f, "Variational inference error: {}", message)
            }
            ProbabilisticError::GraphicalModelError {
                model_type,
                message,
            } => {
                write!(f, "Graphical model error in {}: {}", model_type, message)
            }
            ProbabilisticError::NumRs2IntegrationError { source } => {
                write!(f, "NumRS2 integration error: {}", source)
            }
            ProbabilisticError::Other { message } => {
                write!(f, "Probabilistic error: {}", message)
            }
        }
    }
}

impl std::error::Error for ProbabilisticError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ProbabilisticError::NumRs2IntegrationError { source } => Some(source),
            _ => None,
        }
    }
}

impl From<NumRs2Error> for ProbabilisticError {
    fn from(error: NumRs2Error) -> Self {
        ProbabilisticError::NumRs2IntegrationError {
            source: Box::new(error),
        }
    }
}

/// Helper function to validate probability values
///
/// Ensures that a probability is finite and lies in the valid range [0, 1].
///
/// # Arguments
///
/// * `p` - Probability value to validate
/// * `name` - Parameter name for error messages
///
/// # Returns
///
/// * `Ok(())` if the probability is valid
/// * `Err(ProbabilisticError)` if the probability is invalid
///
/// # Examples
///
/// ```rust,ignore
/// assert!(validate_probability(0.5, "p").is_ok());
/// assert!(validate_probability(-0.1, "p").is_err()); // below 0
/// assert!(validate_probability(1.5, "p").is_err()); // above 1
/// assert!(validate_probability(f64::NAN, "p").is_err()); // not finite
/// ```
pub fn validate_probability(p: f64, name: &str) -> Result<()> {
    if !p.is_finite() {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("probability must be finite, got {}", p),
        });
    }
    if !(0.0..=1.0).contains(&p) {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("probability must be in [0, 1], got {}", p),
        });
    }
    Ok(())
}

/// Helper function to validate positive parameter values
///
/// Ensures that a parameter is finite and strictly positive.
///
/// # Arguments
///
/// * `value` - Parameter value to validate
/// * `name` - Parameter name for error messages
///
/// # Returns
///
/// * `Ok(())` if the value is valid
/// * `Err(ProbabilisticError)` if the value is invalid
pub fn validate_positive(value: f64, name: &str) -> Result<()> {
    if !value.is_finite() {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("value must be finite, got {}", value),
        });
    }
    if value <= 0.0 {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("value must be positive, got {}", value),
        });
    }
    Ok(())
}

/// Helper function to validate non-negative parameter values
///
/// Ensures that a parameter is finite and non-negative.
///
/// # Arguments
///
/// * `value` - Parameter value to validate
/// * `name` - Parameter name for error messages
///
/// # Returns
///
/// * `Ok(())` if the value is valid
/// * `Err(ProbabilisticError)` if the value is invalid
pub fn validate_non_negative(value: f64, name: &str) -> Result<()> {
    if !value.is_finite() {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("value must be finite, got {}", value),
        });
    }
    if value < 0.0 {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message: format!("value must be non-negative, got {}", value),
        });
    }
    Ok(())
}

/// Helper function to validate array shapes match
///
/// # Arguments
///
/// * `expected` - Expected shape
/// * `actual` - Actual shape
/// * `operation` - Operation name for error messages
///
/// # Returns
///
/// * `Ok(())` if shapes match
/// * `Err(ProbabilisticError)` if shapes don't match
pub fn validate_shape(expected: &[usize], actual: &[usize], operation: &str) -> Result<()> {
    if expected != actual {
        return Err(ProbabilisticError::DimensionMismatch {
            expected: expected.to_vec(),
            actual: actual.to_vec(),
            operation: operation.to_string(),
        });
    }
    Ok(())
}

#[cfg(test)]
mod module_tests {
    use super::*;

    #[test]
    fn test_validate_probability() {
        assert!(validate_probability(0.0, "p").is_ok());
        assert!(validate_probability(0.5, "p").is_ok());
        assert!(validate_probability(1.0, "p").is_ok());
        assert!(validate_probability(-0.1, "p").is_err());
        assert!(validate_probability(1.1, "p").is_err());
        assert!(validate_probability(f64::NAN, "p").is_err());
        assert!(validate_probability(f64::INFINITY, "p").is_err());
    }

    #[test]
    fn test_validate_positive() {
        assert!(validate_positive(0.1, "x").is_ok());
        assert!(validate_positive(1.0, "x").is_ok());
        assert!(validate_positive(100.0, "x").is_ok());
        assert!(validate_positive(0.0, "x").is_err());
        assert!(validate_positive(-1.0, "x").is_err());
        assert!(validate_positive(f64::NAN, "x").is_err());
    }

    #[test]
    fn test_validate_non_negative() {
        assert!(validate_non_negative(0.0, "x").is_ok());
        assert!(validate_non_negative(0.1, "x").is_ok());
        assert!(validate_non_negative(1.0, "x").is_ok());
        assert!(validate_non_negative(-0.1, "x").is_err());
        assert!(validate_non_negative(f64::NAN, "x").is_err());
    }

    #[test]
    fn test_validate_shape() {
        assert!(validate_shape(&[2, 3], &[2, 3], "test").is_ok());
        assert!(validate_shape(&[2], &[2], "test").is_ok());
        assert!(validate_shape(&[2, 3], &[3, 2], "test").is_err());
        assert!(validate_shape(&[2, 3], &[2], "test").is_err());
    }

    #[test]
    fn test_error_display() {
        let err = ProbabilisticError::InvalidParameter {
            parameter: "alpha".to_string(),
            message: "must be positive".to_string(),
        };
        let display = format!("{}", err);
        assert!(display.contains("alpha"));
        assert!(display.contains("positive"));
    }

    #[test]
    fn test_error_from_numrs2() {
        let numrs2_err = NumRs2Error::DimensionMismatch("expected 2x3, got 3x2".to_string());
        let prob_err: ProbabilisticError = numrs2_err.into();

        match prob_err {
            ProbabilisticError::NumRs2IntegrationError { .. } => {}
            _ => panic!("Expected NumRs2IntegrationError"),
        }
    }
}