use thiserror::Error;
#[derive(Error, Debug, Clone, PartialEq)]
pub enum StatsError {
#[error("Invalid input: {message}")]
InvalidInput {
message: String,
},
#[error("Conversion error: {message}")]
ConversionError {
message: String,
},
#[error("Empty data: {message}")]
EmptyData {
message: String,
},
#[error("Dimension mismatch: {message}")]
DimensionMismatch {
message: String,
},
#[error("Numerical error: {message}")]
NumericalError {
message: String,
},
#[error("Model not fitted: {message}")]
NotFitted {
message: String,
},
#[error("Invalid parameter: {message}")]
InvalidParameter {
message: String,
},
#[error("Index out of bounds: {message}")]
IndexOutOfBounds {
message: String,
},
#[error("Mathematical error: {message}")]
MathematicalError {
message: String,
},
}
/// Convenience alias for `Result` with [`StatsError`] as the error type.
pub type StatsResult<T> = Result<T, StatsError>;
impl StatsError {
pub fn invalid_input<S: Into<String>>(message: S) -> Self {
StatsError::InvalidInput {
message: message.into(),
}
}
pub fn conversion_error<S: Into<String>>(message: S) -> Self {
StatsError::ConversionError {
message: message.into(),
}
}
pub fn empty_data<S: Into<String>>(message: S) -> Self {
StatsError::EmptyData {
message: message.into(),
}
}
pub fn dimension_mismatch<S: Into<String>>(message: S) -> Self {
StatsError::DimensionMismatch {
message: message.into(),
}
}
pub fn numerical_error<S: Into<String>>(message: S) -> Self {
StatsError::NumericalError {
message: message.into(),
}
}
pub fn not_fitted<S: Into<String>>(message: S) -> Self {
StatsError::NotFitted {
message: message.into(),
}
}
pub fn invalid_parameter<S: Into<String>>(message: S) -> Self {
StatsError::InvalidParameter {
message: message.into(),
}
}
pub fn index_out_of_bounds<S: Into<String>>(message: S) -> Self {
StatsError::IndexOutOfBounds {
message: message.into(),
}
}
pub fn mathematical_error<S: Into<String>>(message: S) -> Self {
StatsError::MathematicalError {
message: message.into(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant renders as its fixed prefix followed by the message.
    #[test]
    fn test_all_variants_display() {
        let expectations = [
            (StatsError::invalid_input("msg"), "Invalid input: msg"),
            (StatsError::conversion_error("msg"), "Conversion error: msg"),
            (StatsError::empty_data("msg"), "Empty data: msg"),
            (StatsError::dimension_mismatch("msg"), "Dimension mismatch: msg"),
            (StatsError::numerical_error("msg"), "Numerical error: msg"),
            (StatsError::not_fitted("msg"), "Model not fitted: msg"),
            (StatsError::invalid_parameter("msg"), "Invalid parameter: msg"),
            (StatsError::index_out_of_bounds("msg"), "Index out of bounds: msg"),
            (StatsError::mathematical_error("msg"), "Mathematical error: msg"),
        ];

        for (error, rendered) in expectations {
            assert_eq!(error.to_string(), rendered, "Display format mismatch");
        }
    }

    /// Equality takes both the variant and the message into account.
    #[test]
    fn test_error_equality() {
        let baseline = StatsError::invalid_input("message");
        let same = StatsError::invalid_input("message");
        let other_message = StatsError::invalid_input("different");
        let other_variant = StatsError::conversion_error("message");

        assert_eq!(baseline, same, "Same variant and message should be equal");
        assert_ne!(baseline, other_message, "Different messages should not be equal");
        assert_ne!(baseline, other_variant, "Different variants should not be equal");
    }

    /// A clone compares equal to the value it was cloned from.
    #[test]
    fn test_error_clone() {
        let original = StatsError::conversion_error("test");
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    /// `StatsResult` behaves like a plain `Result` for both outcomes.
    #[test]
    fn test_stats_result() {
        let success: StatsResult<f64> = Ok(42.0);
        assert_eq!(success.unwrap(), 42.0);

        let failure: StatsResult<f64> = Err(StatsError::invalid_input("test"));
        assert!(failure.is_err());
        assert_eq!(failure.unwrap_err(), StatsError::invalid_input("test"));
    }

    /// Each constructor helper yields the variant it is named after.
    #[test]
    fn test_helper_methods() {
        let built = StatsError::invalid_input("msg");
        assert!(matches!(built, StatsError::InvalidInput { .. }));

        let built = StatsError::conversion_error("msg");
        assert!(matches!(built, StatsError::ConversionError { .. }));

        let built = StatsError::empty_data("msg");
        assert!(matches!(built, StatsError::EmptyData { .. }));

        let built = StatsError::dimension_mismatch("msg");
        assert!(matches!(built, StatsError::DimensionMismatch { .. }));

        let built = StatsError::numerical_error("msg");
        assert!(matches!(built, StatsError::NumericalError { .. }));

        let built = StatsError::not_fitted("msg");
        assert!(matches!(built, StatsError::NotFitted { .. }));

        let built = StatsError::invalid_parameter("msg");
        assert!(matches!(built, StatsError::InvalidParameter { .. }));

        let built = StatsError::index_out_of_bounds("msg");
        assert!(matches!(built, StatsError::IndexOutOfBounds { .. }));

        let built = StatsError::mathematical_error("msg");
        assert!(matches!(built, StatsError::MathematicalError { .. }));
    }

    /// Helpers accept both borrowed and owned strings.
    #[test]
    fn test_into_string_conversion() {
        let from_slice = StatsError::invalid_input("string slice");
        assert_eq!(from_slice.to_string(), "Invalid input: string slice");

        let from_owned = StatsError::invalid_input("owned string".to_string());
        assert_eq!(from_owned.to_string(), "Invalid input: owned string");
    }

    /// The `?` operator forwards a `StatsError` unchanged.
    #[test]
    fn test_error_propagation() {
        fn always_fails() -> StatsResult<f64> {
            Err(StatsError::invalid_input("test"))
        }

        fn forwards_failure() -> StatsResult<f64> {
            always_fails()?;
            Ok(42.0)
        }

        let outcome = forwards_failure();
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    /// The `message` field can be extracted by destructuring.
    #[test]
    fn test_pattern_matching() {
        let failure = StatsError::conversion_error("failed");
        if let StatsError::ConversionError { message } = failure {
            assert_eq!(message, "failed");
        } else {
            panic!("Wrong variant matched");
        }
    }

    /// `Debug` output mentions both the variant name and the message.
    #[test]
    fn test_debug_implementation() {
        let failure = StatsError::invalid_input("test");
        let rendered = format!("{:?}", failure);
        assert!(rendered.contains("InvalidInput"));
        assert!(rendered.contains("test"));
    }
}