use std::fmt;
/// Error type for this module; the error half of the [`Result`] alias.
///
/// Every variant has a human-readable `Display` message, and [`error_to_json`]
/// wraps that message in a `{"error": ...}` JSON object.
#[derive(Debug, Clone, PartialEq)]
pub enum Error {
/// A matrix was singular; the `Display` message attributes this to perfect
/// multicollinearity and suggests removing redundant variables.
SingularMatrix,
/// Fewer observations were available than the operation needs.
InsufficientData {
/// Minimum number of observations needed.
required: usize,
/// Number of observations actually supplied.
available: usize,
},
/// An input failed validation; the string describes why.
InvalidInput(String),
/// Dimensions did not agree; the string describes the mismatch.
DimensionMismatch(String),
/// A computation could not be completed; the string describes the failure.
ComputationFailed(String),
/// Something could not be parsed; the string describes the problem.
ParseError(String),
/// A domain check did not pass; the string describes the failure.
DomainCheck(String),
/// An I/O operation failed, described by a message string.
IoError(String),
/// Serialization failed; the string describes the failure.
SerializationError(String),
/// Deserialization failed; the string describes the failure.
DeserializationError(String),
/// A file's format version differs from the supported one.
IncompatibleFormatVersion {
/// Version recorded in the file.
file_version: String,
/// Version that is supported.
supported: String,
},
/// The model type encountered differs from the one expected.
ModelTypeMismatch {
/// Model type that was expected.
expected: String,
/// Model type actually encountered.
found: String,
},
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::SingularMatrix => {
write!(
f,
"Matrix is singular (perfect multicollinearity). Remove redundant variables."
)
},
Error::InsufficientData {
required,
available,
} => {
write!(
f,
"Insufficient data: need at least {} observations, have {}",
required, available
)
},
Error::InvalidInput(msg) => {
write!(f, "Invalid input: {}", msg)
},
Error::DimensionMismatch(msg) => {
write!(f, "Dimension mismatch: {}", msg)
},
Error::ComputationFailed(msg) => {
write!(f, "Computation failed: {}", msg)
},
Error::ParseError(msg) => {
write!(f, "Parse error: {}", msg)
},
Error::DomainCheck(msg) => {
write!(f, "Domain check failed: {}", msg)
},
Error::IoError(msg) => {
write!(f, "I/O error: {}", msg)
},
Error::SerializationError(msg) => {
write!(f, "Serialization error: {}", msg)
},
Error::DeserializationError(msg) => {
write!(f, "Deserialization error: {}", msg)
},
Error::IncompatibleFormatVersion { file_version, supported } => {
write!(
f,
"Incompatible format version: file has version {}, supported version is {}",
file_version, supported
)
},
Error::ModelTypeMismatch { expected, found } => {
write!(
f,
"Model type mismatch: expected {}, found {}",
expected, found
)
},
}
}
}
/// Marks [`Error`] as a standard error so it can be boxed as
/// `Box<dyn std::error::Error>` and interoperate with generic error handling.
/// No variant wraps another error value (all payloads are strings or counts),
/// so the default `source()` implementation is kept.
impl std::error::Error for Error {}
/// Shorthand for `std::result::Result<T, Error>`, with this module's [`Error`]
/// as the error type.
pub type Result<T> = std::result::Result<T, Error>;
pub fn error_json(msg: &str) -> String {
serde_json::json!({ "error": msg }).to_string()
}
/// Serializes `err` as a `{"error": ...}` JSON object whose value is the
/// error's `Display` text (see [`error_json`]).
pub fn error_to_json(err: &Error) -> String {
    let message = err.to_string();
    error_json(&message)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses JSON produced by [`error_json`] / [`error_to_json`] and returns the
    /// string stored under its `"error"` key.
    ///
    /// Panics (failing the calling test) unless the input is a JSON object whose
    /// only key is `"error"` and whose value is a string.
    fn error_field(json: &str) -> String {
        let value: serde_json::Value =
            serde_json::from_str(json).expect("error output must be valid JSON");
        let object = value
            .as_object()
            .expect("error output must be a JSON object");
        assert_eq!(object.len(), 1, "error object must contain only the \"error\" key");
        object
            .get("error")
            .and_then(serde_json::Value::as_str)
            .expect("error object must have a string \"error\" field")
            .to_string()
    }

    /// One instance of every variant, so whole-enum properties are checked for
    /// all of them.
    fn sample_errors() -> Vec<Error> {
        vec![
            Error::SingularMatrix,
            Error::InsufficientData {
                required: 10,
                available: 5,
            },
            Error::InvalidInput("negative value".to_string()),
            Error::DimensionMismatch("matrix 3x3 cannot multiply with 2x2".to_string()),
            Error::ComputationFailed("QR decomposition failed".to_string()),
            Error::ParseError("invalid JSON syntax".to_string()),
            Error::DomainCheck("unauthorized domain".to_string()),
            Error::IoError("Failed to open file".to_string()),
            Error::SerializationError("Failed to serialize model".to_string()),
            Error::DeserializationError("Invalid JSON".to_string()),
            Error::IncompatibleFormatVersion {
                file_version: "2.0".to_string(),
                supported: "1.0".to_string(),
            },
            Error::ModelTypeMismatch {
                expected: "OLS".to_string(),
                found: "Ridge".to_string(),
            },
        ]
    }

    #[test]
    fn test_singular_matrix_display() {
        assert_eq!(
            Error::SingularMatrix.to_string(),
            "Matrix is singular (perfect multicollinearity). Remove redundant variables."
        );
    }

    #[test]
    fn test_insufficient_data_display() {
        let err = Error::InsufficientData {
            required: 10,
            available: 5,
        };
        // Exact match: substring checks for "10" and "5" would still pass if
        // the two counts were swapped in the message.
        assert_eq!(
            err.to_string(),
            "Insufficient data: need at least 10 observations, have 5"
        );
    }

    #[test]
    fn test_invalid_input_display() {
        let err = Error::InvalidInput("negative value".to_string());
        assert_eq!(err.to_string(), "Invalid input: negative value");
    }

    #[test]
    fn test_dimension_mismatch_display() {
        let err = Error::DimensionMismatch("matrix 3x3 cannot multiply with 2x2".to_string());
        assert_eq!(
            err.to_string(),
            "Dimension mismatch: matrix 3x3 cannot multiply with 2x2"
        );
    }

    #[test]
    fn test_computation_failed_display() {
        let err = Error::ComputationFailed("QR decomposition failed".to_string());
        assert_eq!(err.to_string(), "Computation failed: QR decomposition failed");
    }

    #[test]
    fn test_parse_error_display() {
        let err = Error::ParseError("invalid JSON syntax".to_string());
        assert_eq!(err.to_string(), "Parse error: invalid JSON syntax");
    }

    #[test]
    fn test_domain_check_display() {
        let err = Error::DomainCheck("unauthorized domain".to_string());
        assert_eq!(err.to_string(), "Domain check failed: unauthorized domain");
    }

    #[test]
    fn test_io_error_display() {
        let err = Error::IoError("Failed to open file".to_string());
        assert_eq!(err.to_string(), "I/O error: Failed to open file");
    }

    #[test]
    fn test_serialization_error_display() {
        let err = Error::SerializationError("Failed to serialize model".to_string());
        assert_eq!(err.to_string(), "Serialization error: Failed to serialize model");
    }

    #[test]
    fn test_deserialization_error_display() {
        let err = Error::DeserializationError("Invalid JSON".to_string());
        assert_eq!(err.to_string(), "Deserialization error: Invalid JSON");
    }

    #[test]
    fn test_incompatible_format_version_display() {
        let err = Error::IncompatibleFormatVersion {
            file_version: "2.0".to_string(),
            supported: "1.0".to_string(),
        };
        // Exact match also pins which version is reported as the file's.
        assert_eq!(
            err.to_string(),
            "Incompatible format version: file has version 2.0, supported version is 1.0"
        );
    }

    #[test]
    fn test_model_type_mismatch_display() {
        let err = Error::ModelTypeMismatch {
            expected: "OLS".to_string(),
            found: "Ridge".to_string(),
        };
        assert_eq!(err.to_string(), "Model type mismatch: expected OLS, found Ridge");
    }

    #[test]
    fn test_error_json() {
        assert_eq!(error_json("test error"), r#"{"error":"test error"}"#);
    }

    #[test]
    fn test_error_json_empty_message() {
        assert_eq!(error_json(""), r#"{"error":""}"#);
    }

    #[test]
    fn test_error_json_escapes_special_characters() {
        // Quotes, backslashes and control characters must not break the JSON.
        let msg = "bad \"quote\", back\\slash,\nnewline and\ttab";
        let json = error_json(msg);
        assert!(!json.contains('\n'), "raw newline leaked into JSON: {}", json);
        assert_eq!(error_field(&json), msg);
    }

    #[test]
    fn test_error_to_json_singular_matrix() {
        let json = error_to_json(&Error::SingularMatrix);
        assert_eq!(
            error_field(&json),
            "Matrix is singular (perfect multicollinearity). Remove redundant variables."
        );
    }

    #[test]
    fn test_error_to_json_dimension_mismatch() {
        let err = Error::DimensionMismatch("incompatible dimensions".to_string());
        assert_eq!(
            error_field(&error_to_json(&err)),
            "Dimension mismatch: incompatible dimensions"
        );
    }

    #[test]
    fn test_error_to_json_computation_failed() {
        let err = Error::ComputationFailed("convergence failure".to_string());
        assert_eq!(
            error_field(&error_to_json(&err)),
            "Computation failed: convergence failure"
        );
    }

    #[test]
    fn test_error_to_json_serialization() {
        let err = Error::SerializationError("test".to_string());
        assert_eq!(error_field(&error_to_json(&err)), "Serialization error: test");
    }

    #[test]
    fn test_error_to_json_deserialization() {
        let err = Error::DeserializationError("test".to_string());
        assert_eq!(error_field(&error_to_json(&err)), "Deserialization error: test");
    }

    #[test]
    fn test_error_to_json_matches_display_for_every_variant() {
        for err in sample_errors() {
            assert_eq!(error_field(&error_to_json(&err)), err.to_string(), "{:?}", err);
        }
    }

    #[test]
    fn test_error_partial_eq() {
        assert_eq!(Error::SingularMatrix, Error::SingularMatrix);
        assert_ne!(Error::SingularMatrix, Error::InvalidInput("test".to_string()));
        // Payloads participate in equality, not just the variant.
        assert_ne!(
            Error::InvalidInput("a".to_string()),
            Error::InvalidInput("b".to_string())
        );
        assert_ne!(
            Error::InsufficientData {
                required: 3,
                available: 1,
            },
            Error::InsufficientData {
                required: 3,
                available: 2,
            }
        );
    }

    #[test]
    fn test_error_clone() {
        for err in sample_errors() {
            assert_eq!(err.clone(), err);
        }
    }

    #[test]
    fn test_error_debug() {
        let err = Error::ComputationFailed("test failure".to_string());
        let debug_str = format!("{:?}", err);
        assert!(debug_str.contains("ComputationFailed"));
        assert!(debug_str.contains("test failure"));
    }

    #[test]
    fn test_error_is_std_error() {
        let boxed: Box<dyn std::error::Error> = Error::InvalidInput("x".to_string()).into();
        assert_eq!(boxed.to_string(), "Invalid input: x");
        // No variant wraps another error, so there is no source chain.
        assert!(std::error::Error::source(&*boxed).is_none());
    }

    #[test]
    fn test_result_type_alias() {
        fn returns_ok() -> Result<f64> {
            Ok(42.0)
        }
        fn returns_err() -> Result<f64> {
            Err(Error::InvalidInput("test".to_string()))
        }
        assert_eq!(returns_ok(), Ok(42.0));
        assert_eq!(returns_err(), Err(Error::InvalidInput("test".to_string())));
    }
}