/// Bundle of settings whose name suggests it configures speculative decoding.
///
/// The per-field descriptions are inferred from the field names only; the code
/// that consumes these values is not part of this definition, so confirm the
/// exact semantics against the caller.
///
/// Because the type is `#[non_exhaustive]`, code outside the defining crate
/// cannot build it with a struct literal. Start from
/// [`SpeculativeConfig::default`] and assign the public fields instead.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// Presumably the number of draft steps per speculation round (TODO confirm).
    /// Default: `4`.
    pub draft_steps: usize,
    /// Presumably the sampling temperature (TODO confirm). Default: `1.0`.
    pub temperature: f64,
    /// Presumably the nucleus-sampling cutoff (TODO confirm). Default: `0.9`.
    pub top_p: f64,
    /// Presumably an upper bound on generated tokens (TODO confirm).
    /// Default: `256`.
    pub max_tokens: usize,
    /// Presumably a minimum score for accepting a drafted token; how it is
    /// applied is not visible here (TODO confirm). Default: `0.0`.
    pub acceptance_threshold: f64,
}

impl Default for SpeculativeConfig {
    /// Builds a configuration populated with the stock values
    /// (`4`, `1.0`, `0.9`, `256`, `0.0`, in field order).
    fn default() -> Self {
        const DRAFT_STEPS: usize = 4;
        const TEMPERATURE: f64 = 1.0;
        const TOP_P: f64 = 0.9;
        const MAX_TOKENS: usize = 256;
        const ACCEPTANCE_THRESHOLD: f64 = 0.0;

        SpeculativeConfig {
            draft_steps: DRAFT_STEPS,
            temperature: TEMPERATURE,
            top_p: TOP_P,
            max_tokens: MAX_TOKENS,
            acceptance_threshold: ACCEPTANCE_THRESHOLD,
        }
    }
}
/// Outcome record whose name suggests it summarises a speculative-decoding run.
///
/// All fields are public. Values built through `SpeculativeResult::new` have
/// `acceptance_rate` derived from the two counters; values built with a struct
/// literal carry whatever the caller supplied.
#[derive(Debug, Clone)]
pub struct SpeculativeResult {
    /// Token ids stored verbatim from the constructor argument.
    pub accepted_tokens: Vec<u32>,
    /// Count of drafted tokens, as reported by the caller.
    pub n_draft_tokens: usize,
    /// Count of accepted tokens, as reported by the caller.
    pub n_accepted_tokens: usize,
    /// `n_accepted_tokens / n_draft_tokens` when built via `new`, or `0.0`
    /// when `n_draft_tokens` is zero.
    pub acceptance_rate: f64,
    /// Count of verification calls, as reported by the caller.
    pub n_verification_calls: usize,
}

impl SpeculativeResult {
    /// Assembles a result and derives `acceptance_rate` from the counters.
    ///
    /// The rate is `n_accepted_tokens as f64 / n_draft_tokens as f64`, except
    /// that a zero draft count yields `0.0` so the division never produces
    /// `NaN` or infinity.
    ///
    /// No cross-checks are performed: `n_accepted_tokens` is not compared with
    /// `n_draft_tokens` or with `accepted_tokens.len()`, so inconsistent inputs
    /// are stored as given and the rate may exceed `1.0`.
    pub(crate) fn new(
        accepted_tokens: Vec<u32>,
        n_draft_tokens: usize,
        n_accepted_tokens: usize,
        n_verification_calls: usize,
    ) -> Self {
        let acceptance_rate = match n_draft_tokens {
            // Nothing was drafted: report zero instead of dividing by zero.
            0 => 0.0,
            drafted => n_accepted_tokens as f64 / drafted as f64,
        };

        SpeculativeResult {
            accepted_tokens,
            n_draft_tokens,
            n_accepted_tokens,
            acceptance_rate,
            n_verification_calls,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOLERANCE: f64 = 1e-12;

    /// Returns `true` when `actual` lies strictly within `TOLERANCE` of `expected`.
    fn is_close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// The `Default` impl must yield the documented stock values.
    #[test]
    fn speculative_config_default() {
        let defaults = SpeculativeConfig::default();
        assert_eq!(defaults.draft_steps, 4);
        assert!(is_close(defaults.temperature, 1.0));
        assert!(is_close(defaults.top_p, 0.9));
        assert_eq!(defaults.max_tokens, 256);
        assert!(is_close(defaults.acceptance_threshold, 0.0));
    }

    /// A zero draft count must report a rate of zero rather than dividing by zero.
    #[test]
    fn speculative_result_acceptance_rate_zero_drafts() {
        let result = SpeculativeResult::new(vec![], 0, 0, 0);
        assert!(is_close(result.acceptance_rate, 0.0));
    }

    /// Three accepted out of four drafted gives a rate of 0.75.
    #[test]
    fn speculative_result_acceptance_rate_computed() {
        let result = SpeculativeResult::new(vec![1, 2, 3], 4, 3, 1);
        assert!(is_close(result.acceptance_rate, 0.75));
    }

    /// With accepted <= drafted, the rate never exceeds one (within tolerance).
    #[test]
    fn speculative_result_leq_one() {
        let cases = [(10, 8), (5, 5), (3, 0)];
        for (n_drafted, n_accepted) in cases {
            let result = SpeculativeResult::new(vec![], n_drafted, n_accepted, 0);
            assert!(result.acceptance_rate <= 1.0 + TOLERANCE);
        }
    }
}