use serde::{Deserialize, Serialize};
/// Trade-off between recall and precision, expressed as a minimum score threshold.
///
/// Serialized in serde's adjacently tagged form: the variant name is stored
/// under `"mode"` and the carried threshold (if any) under `"threshold"`,
/// e.g. `{"mode":"HighRecall","threshold":0.6}` or `{"mode":"Balanced"}`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(tag = "mode", content = "threshold")]
pub enum RecallPrecisionMode {
    /// Favors recall; the carried value is the minimum accepted score.
    HighRecall(f64),
    /// Uses the fixed threshold `RecallPrecisionMode::BALANCED_DEFAULT`.
    Balanced,
    /// Favors precision; the carried value is the minimum accepted score.
    HighPrecision(f64),
    /// Caller-chosen threshold that favors neither recall nor precision.
    Custom(f64),
}
impl Default for RecallPrecisionMode {
#[inline]
fn default() -> Self {
Self::Balanced
}
}
impl RecallPrecisionMode {
pub const HIGH_RECALL_DEFAULT: f64 = 0.6;
pub const BALANCED_DEFAULT: f64 = 0.8;
pub const HIGH_PRECISION_DEFAULT: f64 = 0.9;
#[inline]
pub const fn high_recall(threshold: f64) -> Self {
Self::HighRecall(threshold)
}
#[inline]
pub const fn high_precision(threshold: f64) -> Self {
Self::HighPrecision(threshold)
}
#[inline]
pub const fn custom(threshold: f64) -> Self {
Self::Custom(threshold)
}
#[inline]
pub fn threshold(&self) -> f64 {
match self {
Self::HighRecall(t) => *t,
Self::Balanced => Self::BALANCED_DEFAULT,
Self::HighPrecision(t) => *t,
Self::Custom(t) => *t,
}
}
#[inline]
pub fn permissiveness(&self) -> f64 {
match self {
Self::HighRecall(_) => 0.8,
Self::Balanced => 0.5,
Self::HighPrecision(_) => 0.2,
Self::Custom(t) => 1.0 - t,
}
}
#[inline]
pub fn favors_recall(&self) -> bool {
matches!(self, Self::HighRecall(_))
}
#[inline]
pub fn favors_precision(&self) -> bool {
matches!(self, Self::HighPrecision(_))
}
#[inline]
pub fn accepts(&self, score: f64) -> bool {
score >= self.threshold()
}
pub fn from_threshold(threshold: f64) -> Self {
if threshold < 0.7 {
Self::HighRecall(threshold)
} else if threshold < 0.85 {
if (threshold - Self::BALANCED_DEFAULT).abs() < 0.01 {
Self::Balanced
} else {
Self::Custom(threshold)
}
} else {
Self::HighPrecision(threshold)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floats agree to within 1e-3.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.001);
    }

    #[test]
    fn test_default_is_balanced() {
        let fallback = RecallPrecisionMode::default();
        assert_eq!(fallback, RecallPrecisionMode::Balanced);
    }

    #[test]
    fn test_threshold_values() {
        let cases = [
            (RecallPrecisionMode::Balanced, 0.8),
            (RecallPrecisionMode::high_recall(0.5), 0.5),
            (RecallPrecisionMode::high_precision(0.95), 0.95),
            (RecallPrecisionMode::custom(0.75), 0.75),
        ];
        for &(mode, expected) in cases.iter() {
            assert_close(mode.threshold(), expected);
        }
    }

    #[test]
    fn test_favors_recall() {
        let cases = [
            (RecallPrecisionMode::high_recall(0.5), true),
            (RecallPrecisionMode::Balanced, false),
            (RecallPrecisionMode::high_precision(0.9), false),
        ];
        for &(mode, expected) in cases.iter() {
            assert_eq!(mode.favors_recall(), expected);
        }
    }

    #[test]
    fn test_favors_precision() {
        let cases = [
            (RecallPrecisionMode::high_recall(0.5), false),
            (RecallPrecisionMode::Balanced, false),
            (RecallPrecisionMode::high_precision(0.9), true),
        ];
        for &(mode, expected) in cases.iter() {
            assert_eq!(mode.favors_precision(), expected);
        }
    }

    #[test]
    fn test_accepts() {
        let lenient = RecallPrecisionMode::high_recall(0.6);
        for &(score, expected) in [(0.6, true), (0.8, true), (0.5, false)].iter() {
            assert_eq!(lenient.accepts(score), expected);
        }

        let strict = RecallPrecisionMode::high_precision(0.9);
        for &(score, expected) in [(0.9, true), (1.0, true), (0.89, false)].iter() {
            assert_eq!(strict.accepts(score), expected);
        }
    }

    #[test]
    fn test_from_threshold() {
        let low = RecallPrecisionMode::from_threshold(0.5);
        let mid = RecallPrecisionMode::from_threshold(0.8);
        let high = RecallPrecisionMode::from_threshold(0.9);
        assert!(matches!(low, RecallPrecisionMode::HighRecall(_)));
        assert!(matches!(mid, RecallPrecisionMode::Balanced));
        assert!(matches!(high, RecallPrecisionMode::HighPrecision(_)));
    }

    #[test]
    fn test_permissiveness() {
        let recall = RecallPrecisionMode::high_recall(0.5).permissiveness();
        let balanced = RecallPrecisionMode::Balanced.permissiveness();
        let precision = RecallPrecisionMode::high_precision(0.9).permissiveness();
        assert!(recall > 0.5);
        assert_close(balanced, 0.5);
        assert!(precision < 0.5);
    }

    #[test]
    fn test_serde_roundtrip() {
        let samples = [
            RecallPrecisionMode::HighRecall(0.6),
            RecallPrecisionMode::Balanced,
            RecallPrecisionMode::HighPrecision(0.95),
            RecallPrecisionMode::Custom(0.75),
        ];
        for original in samples.iter() {
            let encoded = serde_json::to_string(original).unwrap();
            let decoded: RecallPrecisionMode = serde_json::from_str(&encoded).unwrap();
            assert_eq!(*original, decoded);
        }
    }
}