/// Cumulative counters and the ratios derived from them, as maintained by
/// `SpeculativeMetrics::record_round`.
///
/// All fields are public; the derived ratios are only refreshed when
/// `record_round` is called.
#[derive(Debug, Clone, PartialEq)]
pub struct SpeculativeMetrics {
    /// `total_accepted / total_drafted`. Only refreshed while `total_drafted`
    /// is non-zero, so it keeps its previous value until something is drafted.
    pub accept_rate: f32,

    /// `total_committed / rounds`, i.e. the mean `committed` value per
    /// recorded round.
    pub tokens_per_step_avg: f32,

    /// `tokens_per_step_avg / (1 + k * cost_ratio)`, where `k` is the value
    /// passed to the most recent `record_round` call.
    pub speedup_estimate: f32,

    /// Number of `record_round` calls so far (saturating).
    pub rounds: u64,

    /// Sum of every `drafted` argument recorded so far (saturating).
    pub total_drafted: u64,

    /// Sum of every `accepted` argument recorded so far (saturating).
    pub total_accepted: u64,

    /// Sum of every `committed` argument recorded so far (saturating).
    pub total_committed: u64,

    /// Weight applied to `k` when computing a round's cost.
    /// NOTE(review): presumably the cost of one draft step relative to one
    /// target-model step — confirm with the caller.
    pub cost_ratio: f32,
}
impl Default for SpeculativeMetrics {
fn default() -> Self {
Self {
accept_rate: 0.0,
tokens_per_step_avg: 0.0,
speedup_estimate: 1.0,
rounds: 0,
total_drafted: 0,
total_accepted: 0,
total_committed: 0,
cost_ratio: 0.125,
}
}
}
impl SpeculativeMetrics {
    /// Creates a tracker in its default state (see the `Default` impl).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter for `cost_ratio`.
    ///
    /// Negative values are clamped to `0.0`. A NaN argument also ends up as
    /// `0.0`, because `f32::max` returns the non-NaN operand.
    #[must_use]
    pub fn with_cost_ratio(mut self, cost_ratio: f32) -> Self {
        self.cost_ratio = cost_ratio.max(0.0);
        self
    }

    /// Folds one round into the running totals and refreshes the derived
    /// metrics.
    ///
    /// * `drafted`, `accepted`, `committed` are added (saturating) to the
    ///   matching `total_*` counters.
    /// * `k` weights `cost_ratio` in this round's cost, `1 + k * cost_ratio`.
    ///
    /// `accept_rate` is left untouched while nothing has been drafted.
    /// `speedup_estimate` falls back to `1.0` when the round cost is not
    /// positive (only possible if `cost_ratio` was set to a negative or NaN
    /// value directly through the public field).
    ///
    /// Note that `speedup_estimate` divides the *cumulative* average by a cost
    /// built from *this call's* `k` only.
    pub fn record_round(&mut self, drafted: u32, accepted: u32, committed: u32, k: u32) {
        self.rounds = self.rounds.saturating_add(1);
        self.total_drafted = self.total_drafted.saturating_add(u64::from(drafted));
        self.total_accepted = self.total_accepted.saturating_add(u64::from(accepted));
        self.total_committed = self.total_committed.saturating_add(u64::from(committed));

        // Ratios are computed in f64 and rounded to f32 once. Casting each u64
        // counter to f32 first would round it to 24 bits before the division
        // and skew the result once the counters exceed 2^24.
        if self.total_drafted > 0 {
            self.accept_rate = (self.total_accepted as f64 / self.total_drafted as f64) as f32;
        }

        // `rounds` was just incremented with saturation, so it is at least 1
        // and the division is always defined.
        let tokens_per_step = self.total_committed as f64 / self.rounds as f64;
        self.tokens_per_step_avg = tokens_per_step as f32;

        let round_cost = 1.0 + f64::from(k) * f64::from(self.cost_ratio);
        self.speedup_estimate = if round_cost > 0.0 {
            (tokens_per_step / round_cost) as f32
        } else {
            1.0
        };
    }

    /// Restores the default state while preserving the configured
    /// `cost_ratio`.
    pub fn reset(&mut self) {
        let cost_ratio = self.cost_ratio;
        *self = Self {
            cost_ratio,
            ..Self::default()
        };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `actual` lies strictly within `tol` of `expected`.
    fn near(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    /// A freshly built tracker has zeroed metrics and a speedup of 1.0.
    #[test]
    fn fresh_metrics_default() {
        let metrics = SpeculativeMetrics::new();

        assert_eq!(metrics.accept_rate, 0.0);
        assert_eq!(metrics.tokens_per_step_avg, 0.0);
        assert!(near(metrics.speedup_estimate, 1.0, 1e-6));
        assert_eq!(metrics.rounds, 0);
    }

    /// One recorded round updates the counters and all three derived values.
    #[test]
    fn record_round_updates_everything() {
        let mut metrics = SpeculativeMetrics::new();
        metrics.record_round(4, 3, 4, 4);

        // Raw counters.
        assert_eq!(metrics.rounds, 1);
        assert_eq!(metrics.total_drafted, 4);
        assert_eq!(metrics.total_accepted, 3);
        assert_eq!(metrics.total_committed, 4);

        // Derived ratios; round cost is 1 + 4 * 0.125 = 1.5.
        assert!(near(metrics.accept_rate, 0.75, 1e-6));
        assert!(near(metrics.tokens_per_step_avg, 4.0, 1e-6));
        assert!(near(metrics.speedup_estimate, 4.0 / 1.5, 1e-4));
    }

    /// Ratios are computed over the cumulative totals, not the last round.
    #[test]
    fn rolling_averages_accumulate() {
        let mut metrics = SpeculativeMetrics::new();
        metrics.record_round(4, 4, 5, 4);
        metrics.record_round(4, 0, 1, 4);

        assert_eq!(metrics.rounds, 2);
        assert_eq!(metrics.total_accepted, 4);
        assert!(near(metrics.accept_rate, 0.5, 1e-6));
        assert!(near(metrics.tokens_per_step_avg, 3.0, 1e-6));
    }

    /// `reset` clears the counters but leaves the configured cost ratio alone.
    #[test]
    fn reset_keeps_cost_ratio() {
        let mut metrics = SpeculativeMetrics::new().with_cost_ratio(0.05);
        metrics.record_round(4, 3, 4, 4);
        metrics.reset();

        assert_eq!(metrics.rounds, 0);
        assert!(near(metrics.cost_ratio, 0.05, 1e-6));
    }

    /// Negative cost ratios are clamped to zero by the builder.
    #[test]
    fn with_cost_ratio_clamps_negative() {
        let metrics = SpeculativeMetrics::new().with_cost_ratio(-0.5);

        assert_eq!(metrics.cost_ratio, 0.0);
    }
}