#[cfg(feature = "alloc")]
use alloc::string::String;
/// Reasons a binary buffer was rejected: bad magic, unsupported version,
/// truncation, misalignment, or out-of-range node/feature/tree references.
///
/// Not gated behind the `alloc` feature, so it is usable in every build
/// configuration; it is `Copy` and carries no heap data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum FormatError {
    /// The buffer does not start with the magic bytes `"IRIT"`.
    BadMagic,
    /// The buffer declares a format version that is not supported.
    UnsupportedVersion,
    /// The buffer is truncated (shorter than required).
    Truncated,
    /// The buffer is not 4-byte aligned.
    Unaligned,
    /// A node's child index is out of bounds.
    InvalidNodeIndex,
    /// A feature index exceeds `n_features`.
    InvalidFeatureIndex,
    /// A tree offset is not aligned to the node size.
    MisalignedTreeOffset,
}

impl core::fmt::Display for FormatError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Every message is a fixed literal, so write it directly instead of
        // going through the `format_args!` machinery.
        let msg = match self {
            FormatError::BadMagic => "bad magic: expected \"IRIT\"",
            FormatError::UnsupportedVersion => "unsupported format version",
            FormatError::Truncated => "buffer truncated",
            FormatError::Unaligned => "buffer not 4-byte aligned",
            FormatError::InvalidNodeIndex => "node child index out of bounds",
            FormatError::InvalidFeatureIndex => "feature index exceeds n_features",
            FormatError::MisalignedTreeOffset => "tree offset not aligned to node size",
        };
        f.write_str(msg)
    }
}

/// `FormatError` has no underlying cause, so the default `source()` (`None`)
/// applies. Gated on `std` because `std::error::Error` is unavailable in
/// `no_std` builds.
#[cfg(feature = "std")]
impl std::error::Error for FormatError {}
/// A configuration parameter that was rejected.
///
/// Both variants record the offending parameter's name; the human-readable
/// message is produced by this type's `Display` implementation.
///
/// Derives `PartialEq`/`Eq` (like [`FormatError`]) so errors can be compared
/// directly, e.g. in tests.
#[cfg(feature = "alloc")]
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ConfigError {
    /// The parameter's value violates the stated `constraint`.
    OutOfRange {
        /// Name of the offending parameter.
        param: &'static str,
        /// Human-readable constraint, e.g. `"must be > 0"`.
        constraint: &'static str,
        /// The rejected value, already rendered to text.
        value: String,
    },
    /// The parameter is invalid for a free-form `reason`.
    Invalid {
        /// Name of the offending parameter.
        param: &'static str,
        /// Human-readable explanation of why the value was rejected.
        reason: String,
    },
}
#[cfg(feature = "alloc")]
impl ConfigError {
    /// Builds a [`ConfigError::OutOfRange`] for `param`.
    ///
    /// `value` is rendered through its `Display` implementation right away, so
    /// the resulting error owns its text and stays non-generic.
    pub fn out_of_range(
        param: &'static str,
        constraint: &'static str,
        value: impl core::fmt::Display,
    ) -> Self {
        let rendered = alloc::format!("{}", value);
        Self::OutOfRange {
            param,
            constraint,
            value: rendered,
        }
    }

    /// Builds a [`ConfigError::Invalid`] for `param` from any reason that
    /// converts into a `String`.
    pub fn invalid(param: &'static str, reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        Self::Invalid { param, reason }
    }
}
#[cfg(feature = "alloc")]
impl core::fmt::Display for ConfigError {
    // Renders `"<param> <constraint> (got <value>)"` for out-of-range values
    // and `"<param> <reason>"` for free-form rejections.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::Invalid { param, ref reason } => write!(f, "{} {}", param, reason),
            Self::OutOfRange {
                param,
                constraint,
                ref value,
            } => write!(f, "{} {} (got {})", param, constraint, value),
        }
    }
}

/// `ConfigError` wraps no underlying cause, so the default `source()` (`None`)
/// is kept.
#[cfg(feature = "std")]
impl std::error::Error for ConfigError {}
/// Error type used by this module's `Result` alias.
///
/// A [`ConfigError`] converts into [`IrithyllError::InvalidConfig`] through the
/// `From<ConfigError>` impl, so `?` lifts configuration errors automatically.
#[cfg(feature = "alloc")]
#[derive(Debug)]
#[non_exhaustive]
pub enum IrithyllError {
    /// A configuration parameter was rejected; wraps the underlying error.
    InvalidConfig(ConfigError),
    /// Not enough data; the message is shown after `"insufficient data: "`.
    InsufficientData(String),
    /// An input's dimension did not match the expected one.
    DimensionMismatch {
        /// The dimension that was expected.
        expected: usize,
        /// The dimension that was actually supplied.
        got: usize,
    },
    /// The model has not been trained yet.
    NotTrained,
}
#[cfg(feature = "alloc")]
impl core::fmt::Display for IrithyllError {
    // The wrapped `ConfigError` message is embedded directly in the output,
    // e.g. "invalid configuration: n_steps must be > 0 (got 0)".
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            IrithyllError::InvalidConfig(e) => write!(f, "invalid configuration: {}", e),
            IrithyllError::InsufficientData(msg) => write!(f, "insufficient data: {}", msg),
            IrithyllError::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {}, got {}", expected, got)
            }
            IrithyllError::NotTrained => write!(f, "model not trained"),
        }
    }
}

/// Makes the crate's top-level error usable with `Box<dyn Error>`, `?` in
/// `main`, and other `std` error-handling tooling.
///
/// `source()` is intentionally left at its default (`None`). `Display` already
/// includes the inner `ConfigError` message, and also returning it as the
/// source would make chain-printing reporters show that message twice.
#[cfg(feature = "std")]
impl std::error::Error for IrithyllError {}
/// Lets `?` turn a [`ConfigError`] into [`IrithyllError::InvalidConfig`].
#[cfg(feature = "alloc")]
impl From<ConfigError> for IrithyllError {
    fn from(err: ConfigError) -> Self {
        Self::InvalidConfig(err)
    }
}
/// Result type whose error defaults to [`IrithyllError`].
///
/// The error parameter has a default, so `Result<T>` keeps its original
/// meaning. `Result<T, E>` also still works in modules that glob-import this
/// alias in place of `core::result::Result`.
#[cfg(feature = "alloc")]
pub type Result<T, E = IrithyllError> = core::result::Result<T, E>;
#[cfg(test)]
mod tests {
    use super::*;
    // `to_string` needs an allocator, so this import and every test that
    // calls it are gated on `alloc`, keeping no-alloc test builds compiling.
    #[cfg(feature = "alloc")]
    use alloc::string::ToString;

    #[cfg(feature = "alloc")]
    #[test]
    fn format_error_display() {
        // Cover every variant so a changed message cannot slip through.
        let cases = [
            (FormatError::BadMagic, "bad magic: expected \"IRIT\""),
            (FormatError::UnsupportedVersion, "unsupported format version"),
            (FormatError::Truncated, "buffer truncated"),
            (FormatError::Unaligned, "buffer not 4-byte aligned"),
            (FormatError::InvalidNodeIndex, "node child index out of bounds"),
            (FormatError::InvalidFeatureIndex, "feature index exceeds n_features"),
            (
                FormatError::MisalignedTreeOffset,
                "tree offset not aligned to node size",
            ),
        ];
        for (err, expected) in cases.iter() {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn config_error_out_of_range_display() {
        let e = ConfigError::out_of_range("n_steps", "must be > 0", 0);
        assert_eq!(e.to_string(), "n_steps must be > 0 (got 0)");
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn config_error_invalid_display() {
        let e = ConfigError::invalid(
            "split_reeval_interval",
            "must be >= grace_period (200), got 50",
        );
        assert_eq!(
            e.to_string(),
            "split_reeval_interval must be >= grace_period (200), got 50"
        );
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn irithyll_error_from_config_error() {
        let ce = ConfigError::out_of_range("learning_rate", "must be in (0, 1]", 1.5);
        let ie: IrithyllError = ce.into();
        assert_eq!(
            ie.to_string(),
            "invalid configuration: learning_rate must be in (0, 1] (got 1.5)"
        );
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn irithyll_error_insufficient_data() {
        let e = IrithyllError::InsufficientData("need at least 2 samples".into());
        assert_eq!(e.to_string(), "insufficient data: need at least 2 samples");
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn irithyll_error_dimension_mismatch() {
        let e = IrithyllError::DimensionMismatch {
            expected: 10,
            got: 5,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 10, got 5");
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn irithyll_error_not_trained() {
        assert_eq!(IrithyllError::NotTrained.to_string(), "model not trained");
    }
}