use super::contains::Contains;
use super::f32_bw0and1::F32Bw0and1;
use super::ord_pair::OrdPair;
use crate::Error;
use serde::{Deserialize, Serialize};
use std::{fmt, str::FromStr as _};
/// A pass/fail rule applied to `u8` values (evaluated by the `Contains<u8>` impl).
///
/// The `Display` impl renders each `u8` bound as a probability via
/// `F32Bw0and1::from`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ThresholdState {
    /// Passes values greater than or equal to the given bound.
    GtEq(u8),
    /// Passes values that the pair does *not* contain, i.e. values below its
    /// low bound or above its high bound (per this module's tests, both
    /// endpoints themselves are rejected).
    InvertGtEqLtEq(OrdPair<u8>),
    /// Passes values that satisfy both the `GtEq` bound (first element) and
    /// the `InvertGtEqLtEq` pair (second element).
    Both((u8, OrdPair<u8>)),
}
impl Default for ThresholdState {
fn default() -> Self {
ThresholdState::GtEq(0)
}
}
impl fmt::Display for ThresholdState {
    /// Renders the threshold as a human-readable description of the
    /// probabilities it admits.
    ///
    /// Each `u8` bound is converted with `F32Bw0and1::from` and printed to
    /// four decimal places. `Both` is rendered as its `GtEq` description,
    /// the word `and`, and its `InvertGtEqLtEq` description in parentheses.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Write directly into the formatter instead of building an
        // intermediate `String` and copying it; the output is identical.
        match *self {
            ThresholdState::GtEq(v) => {
                write!(f, "probabilities >= {:.4}", F32Bw0and1::from(v))
            }
            ThresholdState::InvertGtEqLtEq(v) => write!(
                f,
                "probabilities < {:.4} or > {:.4}",
                F32Bw0and1::from(v.low()),
                F32Bw0and1::from(v.high())
            ),
            // Plain `{}` for the nested values: this impl never reads
            // `f.precision()`, so a precision specifier here has no effect
            // and only suggests a truncation that does not happen.
            ThresholdState::Both((a, b)) => write!(
                f,
                "{} and ({})",
                ThresholdState::GtEq(a),
                ThresholdState::InvertGtEqLtEq(b)
            ),
        }
    }
}
impl Contains<u8> for ThresholdState {
    /// Returns `true` when `val` passes this threshold.
    ///
    /// * `GtEq(min)` accepts `val >= min`.
    /// * `InvertGtEqLtEq(pair)` accepts values that `pair` does not contain.
    /// * `Both((min, pair))` accepts values satisfying both rules above.
    fn contains(&self, val: &u8) -> bool {
        let candidate = *val;
        match *self {
            Self::GtEq(min) => candidate >= min,
            Self::InvertGtEqLtEq(excluded) => !excluded.contains(val),
            Self::Both((min, excluded)) => candidate >= min && !excluded.contains(val),
        }
    }
}
impl From<OrdPair<F32Bw0and1>> for ThresholdState {
    /// Builds an `InvertGtEqLtEq` threshold from a pair of fractions.
    ///
    /// Each endpoint is converted to `u8` with `F32Bw0and1`'s `Into<u8>`
    /// conversion, and the results are re-wrapped in an `OrdPair<u8>`.
    ///
    /// # Panics
    ///
    /// Panics if `OrdPair::<u8>::new` rejects the converted endpoints.
    /// NOTE(review): this relies on the `F32Bw0and1 -> u8` conversion
    /// preserving the pair's ordering. If two distinct fractions can
    /// quantise to the same `u8` and `OrdPair` rejects equal endpoints, this
    /// could panic on user-supplied input. Confirm, and consider a `TryFrom`
    /// conversion instead.
    fn from(value: OrdPair<F32Bw0and1>) -> Self {
        let low: u8 = value.low().into();
        let high: u8 = value.high().into();
        let quantised = OrdPair::<u8>::new(low, high).expect(
            "u8 endpoints converted from an OrdPair<F32Bw0and1> should form a valid OrdPair<u8>",
        );
        ThresholdState::InvertGtEqLtEq(quantised)
    }
}
impl ThresholdState {
    /// Parses a threshold from a comma-separated pair of fractions such as
    /// `"0.4,0.6"`.
    ///
    /// The parsed pair becomes an `InvertGtEqLtEq` threshold via the
    /// `From<OrdPair<F32Bw0and1>>` conversion. An empty string yields
    /// `ThresholdState::GtEq(0)`, which accepts every value.
    ///
    /// # Errors
    ///
    /// Propagates any error from `OrdPair::<F32Bw0and1>::from_str`. This
    /// module's tests exercise four such inputs:
    /// * non-numeric text,
    /// * a single value,
    /// * a reversed pair,
    /// * values above 1.
    pub fn from_str_ordpair_fraction(value: &str) -> Result<ThresholdState, Error> {
        if value.is_empty() {
            return Ok(Self::GtEq(0));
        }
        let fractions = OrdPair::<F32Bw0and1>::from_str(value)?;
        Ok(Self::from(fractions))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr as _;

    #[test]
    fn threshold_state_gt_eq() {
        let threshold = ThresholdState::GtEq(100);
        assert!(threshold.contains(&101));
        assert!(threshold.contains(&100));
        assert!(!threshold.contains(&99));
        let display_str = format!("{threshold}");
        assert!(display_str.contains("probabilities >= 0.3922"));
    }

    #[test]
    fn threshold_state_invert_gt_eq_lt_eq() {
        let pair = OrdPair::new(200, 220).expect("should create");
        let threshold = ThresholdState::InvertGtEqLtEq(pair);
        // Values outside [200, 220] pass; the endpoints themselves are rejected.
        assert!(threshold.contains(&0));
        assert!(threshold.contains(&100));
        assert!(!threshold.contains(&200));
        assert!(!threshold.contains(&210));
        assert!(!threshold.contains(&220));
        assert!(threshold.contains(&250));
        let display_str = format!("{threshold}");
        assert!(display_str.contains("probabilities <"));
        assert!(display_str.contains("or >"));
    }

    #[test]
    fn threshold_state_both() {
        let pair = OrdPair::new(200, 220).expect("should create");
        let threshold = ThresholdState::Both((100, pair));
        assert!(!threshold.contains(&0));
        assert!(!threshold.contains(&99));
        assert!(threshold.contains(&100));
        assert!(threshold.contains(&101));
        assert!(!threshold.contains(&200));
        assert!(!threshold.contains(&210));
        assert!(!threshold.contains(&220));
        assert!(threshold.contains(&250));
        let display_str = format!("{threshold}");
        assert!(display_str.contains("and"));
        assert!(display_str.contains("probabilities >="));
        // Both halves are rendered in full, with the inverted range parenthesised.
        assert!(display_str.starts_with("probabilities >= 0.3922 and (probabilities < "));
        assert!(display_str.contains(" or > "));
        assert!(display_str.ends_with(')'));
    }

    #[test]
    fn threshold_state_default() {
        let default_threshold = ThresholdState::default();
        assert!(matches!(default_threshold, ThresholdState::GtEq(0)));
        for val in 0..=255u8 {
            assert!(default_threshold.contains(&val));
        }
    }

    #[test]
    fn threshold_state_display_consistency() {
        let thresholds = vec![
            ThresholdState::GtEq(128),
            ThresholdState::InvertGtEqLtEq(OrdPair::new(100, 150).expect("should create")),
            ThresholdState::Both((50, OrdPair::new(120, 140).expect("should create"))),
        ];
        for threshold in thresholds {
            let display_str = format!("{threshold}");
            assert!(display_str.contains("probabilities"));
            assert!(!display_str.is_empty());
        }
    }

    #[test]
    fn threshold_state_edge_cases() {
        let threshold_255 = ThresholdState::GtEq(255);
        assert!(threshold_255.contains(&255));
        assert!(!threshold_255.contains(&254));

        let threshold_0 = ThresholdState::GtEq(0);
        assert!(threshold_0.contains(&0));
        assert!(threshold_0.contains(&255));

        // Narrowest excluded band: two adjacent values, both rejected.
        let adjacent_pair = OrdPair::new(128, 129).expect("should create");
        let threshold_adjacent = ThresholdState::InvertGtEqLtEq(adjacent_pair);
        assert!(threshold_adjacent.contains(&127));
        assert!(!threshold_adjacent.contains(&128));
        assert!(!threshold_adjacent.contains(&129));
        assert!(threshold_adjacent.contains(&130));
    }

    #[test]
    fn threshold_state_from_ordpair_f32bw0and1() {
        let b: ThresholdState = OrdPair::<F32Bw0and1>::from_str("0.4,0.6")
            .expect("should parse")
            .into();
        assert_eq!(
            b,
            ThresholdState::InvertGtEqLtEq(
                OrdPair::<u8>::new(102u8, 153u8).expect("should create")
            )
        );
    }

    #[test]
    fn threshold_state_from_str_ordpair_fraction_simple() {
        let a = ThresholdState::from_str_ordpair_fraction("0.4,0.6").expect("should parse");
        assert_eq!(
            a,
            ThresholdState::InvertGtEqLtEq((102u8, 153u8).try_into().expect("should create"))
        );
    }

    #[test]
    fn threshold_state_from_str_ordpair_fraction_empty_string() {
        let a = ThresholdState::from_str_ordpair_fraction("").expect("should parse");
        assert_eq!(a, ThresholdState::GtEq(0));
    }

    #[test]
    fn threshold_state_from_str_ordpair_fraction_error_cases() {
        let _: Error = ThresholdState::from_str_ordpair_fraction("invalid").unwrap_err();
        let _: Error = ThresholdState::from_str_ordpair_fraction("0.5").unwrap_err();
        let _: Error = ThresholdState::from_str_ordpair_fraction("0.6,0.4").unwrap_err();
        let _: Error = ThresholdState::from_str_ordpair_fraction("1.5,2.0").unwrap_err();
    }

    #[test]
    fn threshold_state_from_ordpair_f32bw0and1_conversion() {
        let pair1 = OrdPair::<F32Bw0and1>::from_str("0.4,0.6").expect("should parse");
        let threshold1: ThresholdState = pair1.into();
        assert_eq!(
            threshold1,
            ThresholdState::InvertGtEqLtEq(
                OrdPair::<u8>::new(102u8, 153u8).expect("should create")
            )
        );

        let pair2 = OrdPair::<F32Bw0and1>::from_str("0.0,1.0").expect("should parse");
        let threshold2: ThresholdState = pair2.into();
        assert_eq!(
            threshold2,
            ThresholdState::InvertGtEqLtEq(OrdPair::<u8>::new(0u8, 255u8).expect("should create"))
        );

        let pair3 = OrdPair::<F32Bw0and1>::from_str("0.5,0.7").expect("should parse");
        let threshold3: ThresholdState = pair3.into();
        assert!(
            matches!(threshold3, ThresholdState::InvertGtEqLtEq(_)),
            "Expected InvertGtEqLtEq variant"
        );
        // 0.5 and 0.7 fall between u8 steps, so allow either rounding direction.
        if let ThresholdState::InvertGtEqLtEq(ord_pair) = threshold3 {
            assert!(ord_pair.low() >= 127 && ord_pair.low() <= 128);
            assert!(ord_pair.high() >= 178 && ord_pair.high() <= 179);
        }
    }
}