nexus-stats-core 3.0.1

Core types and utilities shared across nexus-stats subcrates
Documentation
use crate::smoothing::EmaF64;

/// Fraction of directionally correct predictions with exponential weighting.
///
/// Tracks both a cumulative hit rate (hits / total) and an exponentially
/// weighted hit rate for recency. A "hit" is when `predicted_direction`
/// and `realized_direction` have the same sign.
///
/// # Examples
///
/// ```
/// use nexus_stats_core::statistics::HitRateF64;
///
/// let mut hr = HitRateF64::builder()
///     .halflife(50.0)
///     .build()
///     .unwrap();
///
/// // Correct prediction: both positive
/// hr.update(1.0, 0.5).unwrap();
/// assert!((hr.hit_rate() - 1.0).abs() < f64::EPSILON);
/// ```
#[derive(Debug, Clone)]
pub struct HitRateF64 {
    hits: u64,
    total: u64,
    ew_hits: EmaF64,
}

/// Builder for [`HitRateF64`].
#[derive(Debug, Clone)]
pub struct HitRateF64Builder {
    halflife: Option<f64>,
    alpha: Option<f64>,
}

impl HitRateF64 {
    /// Creates a builder.
    #[inline]
    #[must_use]
    pub fn builder() -> HitRateF64Builder {
        HitRateF64Builder {
            halflife: None,
            alpha: None,
        }
    }

    /// Feed a prediction and realized direction.
    ///
    /// A "hit" is recorded when both values have the same sign (both
    /// positive or both negative). Zero is treated as no direction —
    /// if either is zero, it counts as a miss.
    ///
    /// # Errors
    ///
    /// Returns `DataError` if either value is NaN or infinite.
    #[inline]
    pub fn update(
        &mut self,
        predicted_direction: f64,
        realized_direction: f64,
    ) -> Result<(), crate::DataError> {
        check_finite!(predicted_direction);
        check_finite!(realized_direction);

        self.total += 1;
        let is_hit = (predicted_direction > 0.0 && realized_direction > 0.0)
            || (predicted_direction < 0.0 && realized_direction < 0.0);

        if is_hit {
            self.hits += 1;
        }

        let hit_val = if is_hit { 1.0 } else { 0.0 };
        // EmaF64::update returns Result<Option<f64>, DataError>, but 0.0/1.0
        // are always finite, so this cannot fail.
        let _ = self.ew_hits.update(hit_val);

        Ok(())
    }

    /// Cumulative hit rate (hits / total). Returns 0.0 if no observations.
    #[inline]
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        if self.total == 0 {
            return 0.0;
        }
        self.hits as f64 / self.total as f64
    }

    /// Exponentially weighted hit rate for recency.
    #[inline]
    #[must_use]
    pub fn ew_hit_rate(&self) -> f64 {
        self.ew_hits.value().unwrap_or(0.0)
    }

    /// Total number of observations.
    #[inline]
    #[must_use]
    pub fn count(&self) -> u64 {
        self.total
    }

    /// Whether enough observations have been fed.
    #[inline]
    #[must_use]
    pub fn is_primed(&self) -> bool {
        self.total > 0
    }

    /// Reset all state.
    pub fn reset(&mut self) {
        self.hits = 0;
        self.total = 0;
        self.ew_hits.reset();
    }
}

impl HitRateF64Builder {
    /// EW halflife in observations. One of `halflife` or `alpha` is required.
    #[inline]
    #[must_use]
    #[cfg(any(feature = "std", feature = "libm"))]
    pub fn halflife(mut self, halflife: f64) -> Self {
        let ln2 = core::f64::consts::LN_2;
        self.alpha = Some(1.0 - crate::math::exp(-ln2 / halflife));
        self.halflife = Some(halflife);
        self
    }

    /// Direct smoothing factor for the EW hit rate. Must be in (0, 1).
    #[inline]
    #[must_use]
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = Some(alpha);
        self
    }

    /// Build the hit rate tracker.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError` if no smoothing factor was set.
    pub fn build(self) -> Result<HitRateF64, crate::ConfigError> {
        let alpha = self
            .alpha
            .ok_or(crate::ConfigError::Missing("alpha or halflife"))?;

        let ew_hits = EmaF64::builder().alpha(alpha).build()?;

        Ok(HitRateF64 {
            hits: 0,
            total: 0,
            ew_hits,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn all_correct() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        for _ in 0..100 {
            hr.update(1.0, 1.0).unwrap();
        }
        assert!((hr.hit_rate() - 1.0).abs() < f64::EPSILON);
        assert!(hr.ew_hit_rate() > 0.99);
    }

    #[test]
    fn all_wrong() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        for _ in 0..100 {
            hr.update(1.0, -1.0).unwrap();
        }
        assert!(hr.hit_rate().abs() < f64::EPSILON);
        assert!(hr.ew_hit_rate() < 0.01);
    }

    #[test]
    fn mixed_directions() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        // 5 hits, 5 misses
        for i in 0..10 {
            if i < 5 {
                hr.update(1.0, 1.0).unwrap();
            } else {
                hr.update(1.0, -1.0).unwrap();
            }
        }
        assert!((hr.hit_rate() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn zero_treated_as_miss() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        hr.update(0.0, 1.0).unwrap();
        assert!(hr.hit_rate().abs() < f64::EPSILON);
        hr.update(1.0, 0.0).unwrap();
        assert!(hr.hit_rate().abs() < f64::EPSILON);
    }

    #[test]
    fn priming() {
        let hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        assert!(!hr.is_primed());
        assert_eq!(hr.count(), 0);
    }

    #[test]
    fn reset_clears_state() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        hr.update(1.0, 1.0).unwrap();
        hr.reset();
        assert_eq!(hr.count(), 0);
        assert!(!hr.is_primed());
        assert!(hr.hit_rate().abs() < f64::EPSILON);
    }

    #[test]
    fn nan_rejected() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        assert!(hr.update(f64::NAN, 1.0).is_err());
        assert!(hr.update(1.0, f64::NAN).is_err());
    }

    #[test]
    fn inf_rejected() {
        let mut hr = HitRateF64::builder().alpha(0.1).build().unwrap();
        assert!(hr.update(f64::INFINITY, 1.0).is_err());
    }
}