use serde::{Deserialize, Serialize};
/// Error type returned by the fallible operations in this module (see the
/// [`Result`] alias).
///
/// Marked `#[non_exhaustive]` so variants can be added later without a
/// breaking change; code in other crates that matches on it needs a wildcard
/// arm. Derives `Serialize`/`Deserialize`, so values of this type can be
/// passed through serde.
#[derive(Debug, Clone, thiserror::Error, Serialize, Deserialize)]
#[non_exhaustive]
pub enum AvataraError {
    /// A parameter was rejected; the payload is interpolated into the
    /// message. NOTE(review): not constructed in this file — confirm callers.
    #[error("invalid parameter: {0}")]
    InvalidParameter(String),
    /// An archetype was not recognised; the payload is interpolated into the
    /// message (presumably the archetype name/key — confirm against callers).
    #[error("unknown archetype: {0}")]
    UnknownArchetype(String),
    /// Archetypes could not be combined; the payload is interpolated into the
    /// message (presumably describing the conflict — confirm against callers).
    #[error("incompatible archetypes: {0}")]
    Incompatible(String),
    /// A value fell outside the closed unit interval `[0.0, 1.0]` (or was NaN).
    /// Constructed by `require_unit_range`: `context` is the caller-supplied
    /// label and `value` is the rejected number.
    #[error("trait weight out of range in {context}: {value}")]
    OutOfRange { context: String, value: f64 },
}
/// Shorthand for a `Result` whose error side is always [`AvataraError`].
pub type Result<T> = core::result::Result<T, AvataraError>;
/// Checks that `value` lies in the closed unit interval `[0.0, 1.0]`.
///
/// On success the value is handed back unchanged, so the check can be used
/// inline, e.g. `let w = require_unit_range(w, "weight")?;`.
///
/// # Errors
///
/// Returns [`AvataraError::OutOfRange`] carrying `context` and the rejected
/// value when `value` is below `0.0`, above `1.0`, infinite, or NaN (NaN is
/// never contained in a range, so it is always rejected).
#[inline]
pub fn require_unit_range(value: f64, context: &str) -> Result<f64> {
    // Guard clause: bail out early on anything outside the inclusive range.
    if !(0.0..=1.0).contains(&value) {
        return Err(AvataraError::OutOfRange {
            context: context.to_owned(),
            value,
        });
    }
    Ok(value)
}
/// Applies [`require_unit_range`] to every element of `values`.
///
/// Elements are checked in order and validation stops at the first failure;
/// an empty slice is accepted.
///
/// # Errors
///
/// Returns the [`AvataraError::OutOfRange`] produced for the first
/// out-of-range element, labelled with the shared `context`.
#[inline]
pub fn require_all_unit_range(values: &[f64], context: &str) -> Result<()> {
    // `try_for_each` short-circuits on the first `Err`, like a loop with `?`.
    values
        .iter()
        .try_for_each(|&weight| require_unit_range(weight, context).map(|_| ()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Boundary and interior values are accepted and returned unchanged.
    #[test]
    fn valid_unit_range() {
        for v in [0.0, 0.5, 1.0, -0.0] {
            assert_eq!(require_unit_range(v, "test").ok(), Some(v), "{v} should be accepted");
        }
    }

    /// Finite out-of-range values, infinities and NaN are all rejected.
    #[test]
    fn invalid_unit_range() {
        for v in [-0.1, 1.1, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(require_unit_range(v, "test").is_err(), "{v} should be rejected");
        }
    }

    /// The error carries the caller's context and the rejected value, and
    /// renders the documented message.
    #[test]
    fn out_of_range_error_payload_and_message() {
        let err = require_unit_range(1.5, "openness").unwrap_err();
        assert_eq!(err.to_string(), "trait weight out of range in openness: 1.5");
        match err {
            AvataraError::OutOfRange { context, value } => {
                assert_eq!(context, "openness");
                assert_eq!(value, 1.5);
            }
            other => panic!("expected OutOfRange, got {other:?}"),
        }
    }

    /// NaN is preserved in the error payload rather than being replaced.
    #[test]
    fn nan_error_keeps_nan_value() {
        match require_unit_range(f64::NAN, "test") {
            Err(AvataraError::OutOfRange { value, .. }) => assert!(value.is_nan()),
            other => panic!("expected OutOfRange, got {other:?}"),
        }
    }

    #[test]
    fn all_unit_range() {
        assert!(require_all_unit_range(&[0.0, 0.5, 1.0], "test").is_ok());
        assert!(require_all_unit_range(&[0.5, 1.5], "test").is_err());
        // An empty slice has nothing to reject.
        assert!(require_all_unit_range(&[], "test").is_ok());
    }

    /// Validation short-circuits, so the first offending element is reported.
    #[test]
    fn all_unit_range_reports_first_offender() {
        match require_all_unit_range(&[0.2, -1.0, 2.0], "weights") {
            Err(AvataraError::OutOfRange { context, value }) => {
                assert_eq!(context, "weights");
                assert_eq!(value, -1.0);
            }
            other => panic!("expected OutOfRange, got {other:?}"),
        }
    }
}