use crate::array::Array;
use crate::error::NumRs2Error;
use std::fmt;
pub mod bayesian;
pub mod distributions;
pub mod graphical;
pub mod inference;
#[cfg(test)]
mod tests;
pub use bayesian::*;
pub use distributions::*;
pub use graphical::*;
pub use inference::*;
/// Result alias used throughout the probabilistic module; the error type is
/// [`ProbabilisticError`].
pub type Result<T> = std::result::Result<T, ProbabilisticError>;
/// Error type for the probabilistic programming module.
///
/// Variants carry structured context (parameter names, shapes, algorithm names,
/// iteration counts) that the `Display` implementation renders into a one-line message.
#[derive(Debug, Clone)]
pub enum ProbabilisticError {
    /// A parameter failed validation. `parameter` is the caller-supplied name of the
    /// offending argument and `message` describes the violated constraint
    /// (see `validate_probability`, `validate_positive`, `validate_non_negative`).
    InvalidParameter { parameter: String, message: String },
    /// Two shapes that were required to be equal differ (see `validate_shape`).
    DimensionMismatch {
        /// Shape the operation required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        actual: Vec<usize>,
        /// Name of the operation that detected the mismatch.
        operation: String,
    },
    /// Numerical failure; rendered as "Numerical error: {message}".
    NumericalError { message: String },
    /// An iterative algorithm failed to converge.
    ConvergenceError {
        /// Name of the algorithm that failed.
        algorithm: String,
        /// Iteration count reported in the message ("after {iterations} iterations").
        iterations: usize,
        /// Details about the failure.
        message: String,
    },
    /// A distribution was rejected as invalid.
    InvalidDistribution {
        /// Name of the distribution.
        distribution: String,
        /// Why it was rejected.
        reason: String,
    },
    /// A sampler reported a failure.
    SamplingError {
        /// Name of the sampler.
        sampler: String,
        /// Iteration reported in the message ("at iteration {iteration}").
        iteration: usize,
        /// Details about the failure.
        message: String,
    },
    /// Failure reported by variational inference code.
    VariationalInferenceError { message: String },
    /// Failure reported by a graphical model; `model_type` names the kind of model.
    GraphicalModelError { model_type: String, message: String },
    /// Wraps an error from the core NumRS2 library. Built by `From<NumRs2Error>` and
    /// exposed through `std::error::Error::source`.
    NumRs2IntegrationError { source: Box<NumRs2Error> },
    /// Free-form error; rendered as "Probabilistic error: {message}".
    Other { message: String },
}
impl fmt::Display for ProbabilisticError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ProbabilisticError::InvalidParameter { parameter, message } => {
write!(f, "Invalid parameter '{}': {}", parameter, message)
}
ProbabilisticError::DimensionMismatch {
expected,
actual,
operation,
} => {
write!(
f,
"Dimension mismatch in {}: expected {:?}, got {:?}",
operation, expected, actual
)
}
ProbabilisticError::NumericalError { message } => {
write!(f, "Numerical error: {}", message)
}
ProbabilisticError::ConvergenceError {
algorithm,
iterations,
message,
} => {
write!(
f,
"Convergence failure in {} after {} iterations: {}",
algorithm, iterations, message
)
}
ProbabilisticError::InvalidDistribution {
distribution,
reason,
} => {
write!(f, "Invalid distribution '{}': {}", distribution, reason)
}
ProbabilisticError::SamplingError {
sampler,
iteration,
message,
} => {
write!(
f,
"Sampling error in {} at iteration {}: {}",
sampler, iteration, message
)
}
ProbabilisticError::VariationalInferenceError { message } => {
write!(f, "Variational inference error: {}", message)
}
ProbabilisticError::GraphicalModelError {
model_type,
message,
} => {
write!(f, "Graphical model error in {}: {}", model_type, message)
}
ProbabilisticError::NumRs2IntegrationError { source } => {
write!(f, "NumRS2 integration error: {}", source)
}
ProbabilisticError::Other { message } => {
write!(f, "Probabilistic error: {}", message)
}
}
}
}
impl std::error::Error for ProbabilisticError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
ProbabilisticError::NumRs2IntegrationError { source } => Some(source),
_ => None,
}
}
}
impl From<NumRs2Error> for ProbabilisticError {
fn from(error: NumRs2Error) -> Self {
ProbabilisticError::NumRs2IntegrationError {
source: Box::new(error),
}
}
}
/// Checks that `p` is a usable probability: a finite value in the closed interval `[0, 1]`.
///
/// `name` identifies the argument in the resulting error.
///
/// # Errors
///
/// Returns [`ProbabilisticError::InvalidParameter`] when `p` is NaN or infinite,
/// or when it lies outside `[0, 1]`.
pub fn validate_probability(p: f64, name: &str) -> Result<()> {
    // Finiteness is checked first, so NaN and ±inf get the "finite" message
    // rather than the range message.
    let message = if !p.is_finite() {
        format!("probability must be finite, got {}", p)
    } else if !(0.0..=1.0).contains(&p) {
        format!("probability must be in [0, 1], got {}", p)
    } else {
        return Ok(());
    };

    Err(ProbabilisticError::InvalidParameter {
        parameter: name.to_string(),
        message,
    })
}
/// Checks that `value` is finite and strictly greater than zero.
///
/// `name` identifies the argument in the resulting error.
///
/// # Errors
///
/// Returns [`ProbabilisticError::InvalidParameter`] when `value` is NaN or infinite,
/// or when it is `<= 0.0`.
pub fn validate_positive(value: f64, name: &str) -> Result<()> {
    // Builds the error for `name` from a ready-made message.
    let reject = |message: String| -> Result<()> {
        Err(ProbabilisticError::InvalidParameter {
            parameter: name.to_string(),
            message,
        })
    };

    if !value.is_finite() {
        reject(format!("value must be finite, got {}", value))
    } else if value <= 0.0 {
        reject(format!("value must be positive, got {}", value))
    } else {
        Ok(())
    }
}
/// Checks that `value` is finite and not less than zero.
///
/// `name` identifies the argument in the resulting error.
///
/// # Errors
///
/// Returns [`ProbabilisticError::InvalidParameter`] when `value` is NaN or infinite,
/// or when it is `< 0.0`.
pub fn validate_non_negative(value: f64, name: &str) -> Result<()> {
    let message = match value {
        v if !v.is_finite() => format!("value must be finite, got {}", v),
        v if v < 0.0 => format!("value must be non-negative, got {}", v),
        _ => return Ok(()),
    };

    Err(ProbabilisticError::InvalidParameter {
        parameter: name.to_string(),
        message,
    })
}
/// Checks that `actual` has exactly the same dimensions as `expected`.
///
/// Both the rank and every extent must match. `operation` names the caller and
/// appears in the error message.
///
/// # Errors
///
/// Returns [`ProbabilisticError::DimensionMismatch`] carrying owned copies of both
/// shapes when they differ.
pub fn validate_shape(expected: &[usize], actual: &[usize], operation: &str) -> Result<()> {
    let shapes_agree = expected == actual;
    shapes_agree
        .then_some(())
        .ok_or_else(|| ProbabilisticError::DimensionMismatch {
            expected: Vec::from(expected),
            actual: Vec::from(actual),
            operation: String::from(operation),
        })
}
#[cfg(test)]
mod module_tests {
    use super::*;
    // Brings `Error::source` into scope for the error-chain assertions.
    use std::error::Error as _;

    /// Extracts the `(parameter, message)` payload of an `InvalidParameter` error,
    /// failing the test on any other outcome.
    fn invalid_parameter(result: Result<()>) -> (String, String) {
        match result {
            Err(ProbabilisticError::InvalidParameter { parameter, message }) => {
                (parameter, message)
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_validate_probability() {
        assert!(validate_probability(0.0, "p").is_ok());
        assert!(validate_probability(0.5, "p").is_ok());
        assert!(validate_probability(1.0, "p").is_ok());
        assert!(validate_probability(-0.1, "p").is_err());
        assert!(validate_probability(1.1, "p").is_err());
        assert!(validate_probability(f64::NAN, "p").is_err());
        assert!(validate_probability(f64::INFINITY, "p").is_err());
        assert!(validate_probability(f64::NEG_INFINITY, "p").is_err());

        let (parameter, message) = invalid_parameter(validate_probability(1.1, "p"));
        assert_eq!(parameter, "p");
        assert_eq!(message, "probability must be in [0, 1], got 1.1");
        // Non-finite input reports finiteness, not the range.
        let (_, message) = invalid_parameter(validate_probability(f64::NAN, "p"));
        assert_eq!(message, "probability must be finite, got NaN");
    }

    #[test]
    fn test_validate_positive() {
        assert!(validate_positive(0.1, "x").is_ok());
        assert!(validate_positive(1.0, "x").is_ok());
        assert!(validate_positive(100.0, "x").is_ok());
        assert!(validate_positive(0.0, "x").is_err());
        assert!(validate_positive(-1.0, "x").is_err());
        assert!(validate_positive(f64::NAN, "x").is_err());
        assert!(validate_positive(f64::INFINITY, "x").is_err());
        assert!(validate_positive(f64::NEG_INFINITY, "x").is_err());

        let (parameter, message) = invalid_parameter(validate_positive(0.0, "x"));
        assert_eq!(parameter, "x");
        assert_eq!(message, "value must be positive, got 0");
    }

    #[test]
    fn test_validate_non_negative() {
        assert!(validate_non_negative(0.0, "x").is_ok());
        assert!(validate_non_negative(0.1, "x").is_ok());
        assert!(validate_non_negative(1.0, "x").is_ok());
        assert!(validate_non_negative(-0.1, "x").is_err());
        assert!(validate_non_negative(f64::NAN, "x").is_err());
        assert!(validate_non_negative(f64::INFINITY, "x").is_err());

        let (_, message) = invalid_parameter(validate_non_negative(-0.1, "x"));
        assert_eq!(message, "value must be non-negative, got -0.1");
    }

    #[test]
    fn test_validate_shape() {
        assert!(validate_shape(&[2, 3], &[2, 3], "test").is_ok());
        assert!(validate_shape(&[2], &[2], "test").is_ok());
        assert!(validate_shape(&[], &[], "scalar").is_ok());
        assert!(validate_shape(&[2, 3], &[3, 2], "test").is_err());
        assert!(validate_shape(&[2, 3], &[2], "test").is_err());

        match validate_shape(&[2, 3], &[3, 2], "matmul") {
            Err(ProbabilisticError::DimensionMismatch {
                expected,
                actual,
                operation,
            }) => {
                assert_eq!(expected, vec![2, 3]);
                assert_eq!(actual, vec![3, 2]);
                assert_eq!(operation, "matmul");
            }
            other => panic!("expected DimensionMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_error_display() {
        let err = ProbabilisticError::InvalidParameter {
            parameter: "alpha".to_string(),
            message: "must be positive".to_string(),
        };
        let display = format!("{}", err);
        assert!(display.contains("alpha"));
        assert!(display.contains("positive"));
        assert_eq!(display, "Invalid parameter 'alpha': must be positive");

        let err = ProbabilisticError::DimensionMismatch {
            expected: vec![2, 3],
            actual: vec![3, 2],
            operation: "matmul".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Dimension mismatch in matmul: expected [2, 3], got [3, 2]"
        );

        let err = ProbabilisticError::ConvergenceError {
            algorithm: "EM".to_string(),
            iterations: 100,
            message: "stalled".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Convergence failure in EM after 100 iterations: stalled"
        );

        let err = ProbabilisticError::SamplingError {
            sampler: "NUTS".to_string(),
            iteration: 7,
            message: "divergence".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Sampling error in NUTS at iteration 7: divergence"
        );
    }

    #[test]
    fn test_error_from_numrs2() {
        let numrs2_err = NumRs2Error::DimensionMismatch("expected 2x3, got 3x2".to_string());
        let prob_err: ProbabilisticError = numrs2_err.into();
        // The wrapped error must stay reachable through the standard error chain.
        assert!(prob_err.source().is_some());
        assert!(prob_err
            .to_string()
            .starts_with("NumRS2 integration error: "));
        assert!(matches!(
            prob_err,
            ProbabilisticError::NumRs2IntegrationError { .. }
        ));
    }
}