use std::fmt;
use thiserror::Error;
/// `u16::MAX` as an `f32`: the factor a unit-interval float is multiplied by in
/// `Confidence::try_from_f32`, and the divisor used by `Confidence::as_f32`.
const SCALE: f32 = u16::MAX as f32;
/// Reasons a floating-point input is rejected by `Confidence::try_from_f32`.
#[derive(Debug, Error, PartialEq)]
pub enum ConfidenceError {
/// The input was below `0.0` or above `1.0`; carries the rejected value.
#[error("confidence {0} outside [0.0, 1.0]")]
OutOfRange(f32),
/// The input was NaN.
#[error("confidence NaN is not a valid value")]
NotANumber,
}
/// A confidence level stored as a raw `u16`.
///
/// Raw `0` is exposed as `Confidence::ZERO` and raw `u16::MAX` as
/// `Confidence::ONE`. Equality, ordering and hashing are derived from the raw
/// value.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Confidence(u16);
impl Confidence {
pub const ZERO: Self = Self(0);
pub const ONE: Self = Self(u16::MAX);
pub fn try_from_f32(value: f32) -> Result<Self, ConfidenceError> {
if value.is_nan() {
return Err(ConfidenceError::NotANumber);
}
if !(0.0..=1.0).contains(&value) {
return Err(ConfidenceError::OutOfRange(value));
}
let scaled = (f64::from(value) * f64::from(SCALE)).round_ties_even();
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let stored = scaled as u16;
Ok(Self(stored))
}
#[must_use]
pub const fn from_u16(raw: u16) -> Self {
Self(raw)
}
#[must_use]
pub const fn as_u16(self) -> u16 {
self.0
}
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn as_f32(self) -> f32 {
(f64::from(self.0) / f64::from(SCALE)) as f32
}
}
/// Renders the confidence as its float value with exactly four decimals.
impl fmt::Display for Confidence {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let fraction = self.as_f32();
        // The precision is fixed here; width/alignment flags given by the
        // caller are not applied to the output.
        f.write_fmt(format_args!("{fraction:.4}"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both ends of the valid interval land on the named constants.
    #[test]
    fn boundary_values() {
        for (input, expected) in [(0.0_f32, Confidence::ZERO), (1.0_f32, Confidence::ONE)] {
            let parsed = Confidence::try_from_f32(input).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// Values just outside either end of the interval are refused.
    #[test]
    fn out_of_range_rejected() {
        for sample in [-0.01_f32, 1.01_f32] {
            let outcome = Confidence::try_from_f32(sample);
            assert!(matches!(outcome, Err(ConfidenceError::OutOfRange(_))));
        }
    }

    /// NaN gets its own dedicated error variant.
    #[test]
    fn nan_rejected() {
        let outcome = Confidence::try_from_f32(f32::NAN);
        assert_eq!(outcome, Err(ConfidenceError::NotANumber));
    }

    /// Going raw -> f32 -> raw must not drift by more than one raw step.
    #[test]
    fn roundtrip_precision_within_one_step() {
        // Only reported in the failure message, to give the drift a scale.
        let step = 1.0 / f32::from(u16::MAX);
        let samples = [0_u16, 1, 32_768, 65_534, 65_535];
        for raw in samples {
            let original = Confidence::from_u16(raw);
            let as_float = original.as_f32();
            let rebuilt = Confidence::try_from_f32(as_float).unwrap();
            let before = i64::from(original.as_u16());
            let after = i64::from(rebuilt.as_u16());
            let delta = before - after;
            assert!(delta.abs() <= 1, "raw={raw} delta={delta} step={step}");
        }
    }
}