Skip to main content

oxibonsai_eval/
streaming.rs

1//! Streaming / online evaluation state machines.
2//!
3//! These mirror the batch evaluation paths but maintain running state so
4//! callers can report a partial score mid-stream. The finalised value is
5//! mathematically equivalent to the batch path (property-tested).
6
7use serde::Serialize;
8
9// ──────────────────────────────────────────────────────────────────────────────
10// OnlinePerplexity
11// ──────────────────────────────────────────────────────────────────────────────
12
/// Running perplexity estimator: accumulates `Σ log_p` and a token count, and
/// reports the perplexity `exp(-Σ log_p / n)` on demand (i.e. `exp` of the mean
/// negative log-probability).
///
/// Feeding tokens one at a time should match the batch
/// [`crate::perplexity::PerplexityEvaluator::compute`] call at the end, up to
/// floating-point rounding. Inputs arrive as `f32`, but the running sum is
/// kept in `f64`, so the summation order and precision can differ slightly
/// from the batch path.
///
/// Derives `Serialize`; the field names below form the serialized keys.
#[derive(Debug, Clone, Default, Serialize)]
pub struct OnlinePerplexity {
    /// Sum of log-probabilities seen so far (natural log), accumulated in `f64`.
    sum_log_p: f64,
    /// Number of tokens observed.
    n: usize,
}
26
27impl OnlinePerplexity {
28    /// Construct a fresh estimator.
29    pub fn new() -> Self {
30        Self::default()
31    }
32
33    /// Feed a single log-probability (natural log).
34    pub fn push(&mut self, log_p: f32) {
35        self.sum_log_p += log_p as f64;
36        self.n += 1;
37    }
38
39    /// Feed a chunk of log-probabilities.
40    pub fn push_chunk(&mut self, log_ps: &[f32]) {
41        for &l in log_ps {
42            self.push(l);
43        }
44    }
45
46    /// Reset the state (zero tokens, zero sum).
47    pub fn reset(&mut self) {
48        self.sum_log_p = 0.0;
49        self.n = 0;
50    }
51
52    /// Number of tokens seen.
53    pub fn tokens(&self) -> usize {
54        self.n
55    }
56
57    /// Current perplexity estimate; `f32::INFINITY` if no tokens have been seen.
58    pub fn current(&self) -> f32 {
59        if self.n == 0 {
60            f32::INFINITY
61        } else {
62            let mean_neg_log = -self.sum_log_p / self.n as f64;
63            mean_neg_log.exp() as f32
64        }
65    }
66}
67
68// ──────────────────────────────────────────────────────────────────────────────
69// OnlineAccuracy
70// ──────────────────────────────────────────────────────────────────────────────
71
/// Running accuracy counter: tallies how many recorded predictions were
/// correct against the total number of predictions recorded.
///
/// Derives `Serialize`; the field names below form the serialized keys.
#[derive(Debug, Clone, Default, Serialize)]
pub struct OnlineAccuracy {
    /// Number of outcomes recorded as correct.
    correct: usize,
    /// Total number of outcomes recorded (correct and incorrect).
    total: usize,
}
80
81impl OnlineAccuracy {
82    /// Construct a fresh counter.
83    pub fn new() -> Self {
84        Self::default()
85    }
86
87    /// Record a single prediction outcome.
88    pub fn push(&mut self, is_correct: bool) {
89        if is_correct {
90            self.correct += 1;
91        }
92        self.total += 1;
93    }
94
95    /// Record a batch of outcomes.
96    pub fn push_many(&mut self, outcomes: &[bool]) {
97        for &b in outcomes {
98            self.push(b);
99        }
100    }
101
102    /// Return the current accuracy in `[0, 1]`. Returns 0.0 for empty state.
103    pub fn current(&self) -> f32 {
104        if self.total == 0 {
105            0.0
106        } else {
107            self.correct as f32 / self.total as f32
108        }
109    }
110
111    /// Return `(correct, total)` so callers can build an
112    /// [`crate::accuracy::AccuracyResult`] later.
113    pub fn counts(&self) -> (usize, usize) {
114        (self.correct, self.total)
115    }
116
117    /// Reset the counter.
118    pub fn reset(&mut self) {
119        self.correct = 0;
120        self.total = 0;
121    }
122}