use serde::Serialize;
/// Streaming perplexity accumulator.
///
/// Holds a running sum of log-probabilities and the number of values that
/// were added; `current()` reports `exp(-sum_log_p / n)` as an `f32`.
#[derive(Debug, Clone, Default, Serialize)]
pub struct OnlinePerplexity {
/// Sum of every pushed log-probability, widened to `f64` before adding.
/// NOTE(review): presumably natural-log values, since `current()` applies
/// `exp` to the mean — confirm against callers.
sum_log_p: f64,
/// Number of log-probabilities pushed since construction or the last `reset()`.
n: usize,
}
impl OnlinePerplexity {
    /// Builds an empty accumulator (zero sum, zero tokens).
    pub fn new() -> Self {
        Default::default()
    }

    /// Adds a single log-probability to the running sum and counts one token.
    ///
    /// The value is widened to `f64` before being accumulated.
    pub fn push(&mut self, log_p: f32) {
        self.n += 1;
        self.sum_log_p += f64::from(log_p);
    }

    /// Adds every log-probability in `log_ps`, in slice order.
    ///
    /// Equivalent to calling [`push`](Self::push) once per element; an empty
    /// slice leaves the accumulator untouched.
    pub fn push_chunk(&mut self, log_ps: &[f32]) {
        log_ps.iter().copied().for_each(|value| self.push(value));
    }

    /// Clears the running sum and the token count back to zero.
    pub fn reset(&mut self) {
        *self = Self {
            sum_log_p: 0.0,
            n: 0,
        };
    }

    /// Number of log-probabilities accumulated so far.
    pub fn tokens(&self) -> usize {
        self.n
    }

    /// Perplexity over everything pushed so far: `exp(-sum_log_p / n)`.
    ///
    /// Returns `f32::INFINITY` when nothing has been pushed. The mean and the
    /// exponential are evaluated in `f64` and narrowed to `f32` at the end.
    pub fn current(&self) -> f32 {
        match self.n {
            0 => f32::INFINITY,
            count => {
                let avg_neg_log_p = -self.sum_log_p / count as f64;
                avg_neg_log_p.exp() as f32
            }
        }
    }
}
/// Streaming accuracy accumulator.
///
/// Counts how many recorded outcomes were correct out of the total number
/// recorded; `current()` reports the ratio, or `0.0` when nothing was recorded.
#[derive(Debug, Clone, Default, Serialize)]
pub struct OnlineAccuracy {
/// Number of outcomes pushed with `is_correct == true`.
correct: usize,
/// Number of outcomes pushed in total since construction or the last `reset()`.
total: usize,
}
impl OnlineAccuracy {
    /// Creates an accumulator with no recorded outcomes.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one outcome; `is_correct == true` also bumps the correct count.
    pub fn push(&mut self, is_correct: bool) {
        self.correct += usize::from(is_correct);
        self.total += 1;
    }

    /// Records every outcome in `outcomes`.
    ///
    /// Equivalent to calling [`push`](Self::push) once per element; an empty
    /// slice leaves the accumulator untouched.
    pub fn push_many(&mut self, outcomes: &[bool]) {
        self.correct += outcomes.iter().filter(|&&ok| ok).count();
        self.total += outcomes.len();
    }

    /// Fraction of recorded outcomes that were correct, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when nothing has been recorded.
    pub fn current(&self) -> f32 {
        if self.total == 0 {
            return 0.0;
        }
        // Divide in f64 and narrow once. Casting each counter to f32 first
        // rounds any count above 2^24, so the quotient would be computed from
        // already-rounded operands. For counts below 2^24 the result is
        // identical to a plain f32 division.
        (self.correct as f64 / self.total as f64) as f32
    }

    /// Returns the raw counters as `(correct, total)`.
    pub fn counts(&self) -> (usize, usize) {
        (self.correct, self.total)
    }

    /// Clears both counters back to zero.
    pub fn reset(&mut self) {
        self.correct = 0;
        self.total = 0;
    }
}