use crate::smoothing::EmaF64;
/// Directional hit-rate tracker.
///
/// Counts how often a predicted direction has the same strict sign as the realized
/// direction, both cumulatively (`hits / total`) and as an exponentially weighted
/// average of a per-update hit indicator.
#[derive(Debug, Clone)]
pub struct HitRateF64 {
    /// Number of accepted updates where both directions were strictly positive or both
    /// strictly negative.
    hits: u64,
    /// Number of accepted (finite-input) updates.
    total: u64,
    /// EMA fed with `1.0` for a hit and `0.0` for a miss on every accepted update.
    ew_hits: EmaF64,
}
/// Builder for [`HitRateF64`].
///
/// The EMA smoothing factor must be supplied, either directly via `alpha` or derived
/// via `halflife`; whichever setter is called last determines the alpha used by `build`.
#[derive(Debug, Clone)]
pub struct HitRateF64Builder {
    /// Half-life most recently passed to `halflife`, if any. It is not cleared by
    /// `alpha` and is not read by `build`; only the derived `alpha` is used.
    halflife: Option<f64>,
    /// Smoothing factor handed to the EMA builder; `build` fails if this is `None`.
    alpha: Option<f64>,
}
impl HitRateF64 {
    /// Starts a [`HitRateF64Builder`] with neither `alpha` nor `halflife` configured.
    #[inline]
    #[must_use]
    pub fn builder() -> HitRateF64Builder {
        HitRateF64Builder {
            alpha: None,
            halflife: None,
        }
    }

    /// Records one prediction/outcome pair.
    ///
    /// The pair is a hit only when both values are strictly positive or both are
    /// strictly negative; a zero on either side (including `-0.0`) is a miss.
    ///
    /// # Errors
    ///
    /// Returns a [`crate::DataError`] if either argument is not finite. The checks run
    /// before any field is touched, so a rejected pair leaves the tracker unchanged.
    #[inline]
    pub fn update(
        &mut self,
        predicted_direction: f64,
        realized_direction: f64,
    ) -> Result<(), crate::DataError> {
        check_finite!(predicted_direction);
        check_finite!(realized_direction);

        let both_up = predicted_direction > 0.0 && realized_direction > 0.0;
        let both_down = predicted_direction < 0.0 && realized_direction < 0.0;
        let same_sign = both_up || both_down;

        self.total += 1;
        let indicator = if same_sign {
            self.hits += 1;
            1.0
        } else {
            0.0
        };
        // The EMA's return value is ignored; its input here is always 0.0 or 1.0.
        let _ = self.ew_hits.update(indicator);
        Ok(())
    }

    /// Cumulative fraction of recorded pairs that were hits.
    ///
    /// Returns `0.0` before any pair has been recorded.
    #[inline]
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        match self.total {
            0 => 0.0,
            n => self.hits as f64 / n as f64,
        }
    }

    /// Exponentially weighted hit rate.
    ///
    /// Returns `0.0` when the underlying EMA reports no value.
    #[inline]
    #[must_use]
    pub fn ew_hit_rate(&self) -> f64 {
        let smoothed = self.ew_hits.value();
        smoothed.unwrap_or(0.0)
    }

    /// Number of pairs accepted by [`HitRateF64::update`] since construction or the last reset.
    #[inline]
    #[must_use]
    pub fn count(&self) -> u64 {
        self.total
    }

    /// `true` once at least one pair has been accepted.
    #[inline]
    #[must_use]
    pub fn is_primed(&self) -> bool {
        self.total != 0
    }

    /// Clears the cumulative counters and resets the exponentially weighted state.
    pub fn reset(&mut self) {
        self.ew_hits.reset();
        self.total = 0;
        self.hits = 0;
    }
}
impl HitRateF64Builder {
    /// Derives the smoothing factor from a half-life.
    ///
    /// Sets `alpha = 1 - exp(-ln(2) / halflife)`, i.e. `(1 - alpha)^halflife == 0.5`.
    /// This overwrites any previously set `alpha`. The half-life itself is not validated
    /// here. For example, a half-life of `0.0` yields `alpha = 1.0`, assuming
    /// `crate::math::exp` is the usual exponential. Only compiled with the `std` or
    /// `libm` feature.
    #[inline]
    #[must_use]
    #[cfg(any(feature = "std", feature = "libm"))]
    pub fn halflife(mut self, halflife: f64) -> Self {
        let decay = crate::math::exp(-core::f64::consts::LN_2 / halflife);
        self.halflife = Some(halflife);
        self.alpha = Some(1.0 - decay);
        self
    }

    /// Sets the EMA smoothing factor directly, replacing any value derived from `halflife`.
    #[inline]
    #[must_use]
    pub fn alpha(self, alpha: f64) -> Self {
        Self {
            alpha: Some(alpha),
            ..self
        }
    }

    /// Builds a [`HitRateF64`] with zeroed counters and a fresh EMA.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::Missing("alpha or halflife")` when no smoothing factor was
    /// configured. Also propagates any error from building the underlying [`EmaF64`]
    /// with the chosen alpha.
    pub fn build(self) -> Result<HitRateF64, crate::ConfigError> {
        let alpha = match self.alpha {
            Some(alpha) => alpha,
            None => return Err(crate::ConfigError::Missing("alpha or halflife")),
        };
        Ok(HitRateF64 {
            hits: 0,
            total: 0,
            ew_hits: EmaF64::builder().alpha(alpha).build()?,
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the hit-rate tracker and its builder.
    use super::*;

    #[test]
    fn all_correct() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        for _ in 0..100 {
            hr.update(1.0, 1.0).unwrap();
        }
        assert!((hr.hit_rate() - 1.0).abs() < f64::EPSILON);
        assert!(hr.ew_hit_rate() > 0.99);
    }

    #[test]
    fn all_wrong() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        for _ in 0..100 {
            hr.update(1.0, -1.0).unwrap();
        }
        assert!(hr.hit_rate().abs() < f64::EPSILON);
        assert!(hr.ew_hit_rate() < 0.01);
    }

    #[test]
    fn mixed_directions() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        for i in 0..10 {
            if i < 5 {
                hr.update(1.0, 1.0).unwrap();
            } else {
                hr.update(1.0, -1.0).unwrap();
            }
        }
        assert!((hr.hit_rate() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn both_negative_is_hit() {
        // Exercises the `< 0.0 && < 0.0` branch, which the other tests never reach.
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        hr.update(-1.0, -2.5).unwrap();
        assert!((hr.hit_rate() - 1.0).abs() < f64::EPSILON);
        hr.update(-1.0, 1.0).unwrap();
        assert!((hr.hit_rate() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn zero_treated_as_miss() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        hr.update(0.0, 1.0).unwrap();
        assert!(hr.hit_rate().abs() < f64::EPSILON);
        hr.update(1.0, 0.0).unwrap();
        assert!(hr.hit_rate().abs() < f64::EPSILON);
    }

    #[test]
    fn negative_zero_treated_as_miss() {
        // -0.0 is neither > 0.0 nor < 0.0, so it must also count as a miss.
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        hr.update(-0.0, -1.0).unwrap();
        assert!(hr.hit_rate().abs() < f64::EPSILON);
        assert_eq!(hr.count(), 1);
    }

    #[test]
    fn priming() {
        let hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        assert!(!hr.is_primed());
        assert_eq!(hr.count(), 0);
        assert!(hr.hit_rate().abs() < f64::EPSILON);
    }

    #[test]
    fn reset_clears_state() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        hr.update(1.0, 1.0).unwrap();
        hr.reset();
        assert_eq!(hr.count(), 0);
        assert!(!hr.is_primed());
        assert!(hr.hit_rate().abs() < f64::EPSILON);
    }

    #[test]
    fn nan_rejected() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        assert!(hr.update(f64::NAN, 1.0).is_err());
        assert!(hr.update(1.0, f64::NAN).is_err());
    }

    #[test]
    fn inf_rejected() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        assert!(hr.update(f64::INFINITY, 1.0).is_err());
        assert!(hr.update(1.0, f64::INFINITY).is_err());
    }

    #[test]
    fn rejected_input_leaves_state_untouched() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        hr.update(1.0, 1.0).unwrap();
        assert!(hr.update(f64::NAN, 1.0).is_err());
        assert!(hr.update(1.0, f64::INFINITY).is_err());
        assert_eq!(hr.count(), 1);
        assert!((hr.hit_rate() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn missing_alpha_rejected() {
        assert!(HitRateF64::builder().build().is_err());
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn halflife_builds_and_converges() {
        let mut hr = HitRateF64::builder().halflife(10.0).build().unwrap();
        for _ in 0..200 {
            hr.update(1.0, 1.0).unwrap();
        }
        assert!((hr.hit_rate() - 1.0).abs() < f64::EPSILON);
        assert!(hr.ew_hit_rate() > 0.99);
    }
}