use std::fmt;
/// Bundle of tunable settings for speculative decoding.
///
/// The type is plain storage: nothing here validates or interprets the
/// values, so the field notes below state defaults and hedge on meaning.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// Draft length setting; defaults to `4`.
    /// Presumably the number of tokens drafted per step — confirm with the decoder.
    pub draft_length: usize,
    /// Temperature setting; defaults to `1.0`.
    pub temperature: f64,
    /// Top-k setting; defaults to `50`.
    pub top_k: usize,
    /// Token budget setting; defaults to `512`.
    /// Presumably an upper bound on generated tokens — confirm with the decoder.
    pub max_tokens: usize,
    /// Adaptive-draft switch; defaults to `false`.
    pub adaptive_draft: bool,
}

impl Default for SpeculativeConfig {
    /// Returns the baseline configuration: draft length 4, temperature 1.0,
    /// top-k 50, 512 max tokens, adaptive drafting off.
    fn default() -> Self {
        SpeculativeConfig {
            adaptive_draft: false,
            max_tokens: 512,
            top_k: 50,
            temperature: 1.0,
            draft_length: 4,
        }
    }
}

impl fmt::Display for SpeculativeConfig {
    /// Writes a single-line summary of every field; the temperature is
    /// printed with two decimal places.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to the struct will fail to
        // compile here until the summary is updated as well.
        let SpeculativeConfig {
            draft_length,
            temperature,
            top_k,
            max_tokens,
            adaptive_draft,
        } = self;
        write!(
            f,
            "SpeculativeConfig(draft_length={}, temperature={:.2}, top_k={}, max_tokens={}, adaptive={})",
            draft_length, temperature, top_k, max_tokens, adaptive_draft
        )
    }
}
/// Record describing the outcome of one verification pass.
///
/// All three fields are supplied by the caller and stored unchanged; this
/// type does not check that they are mutually consistent.
#[derive(Debug, Clone)]
pub struct VerificationResult {
    /// Accepted tokens (presumably token ids — confirm with the producer).
    pub accepted_tokens: Vec<usize>,
    /// Where a rejection happened, or `None` when nothing was rejected.
    /// NOTE(review): the unit of the index (draft position vs. token id) is
    /// not visible here.
    pub rejected_at: Option<usize>,
    /// Acceptance rate as reported by the caller.
    pub acceptance_rate: f64,
}

impl VerificationResult {
    /// Packs the three values into a result without validating them.
    pub fn new(
        accepted_tokens: Vec<usize>,
        rejected_at: Option<usize>,
        acceptance_rate: f64,
    ) -> Self {
        VerificationResult {
            accepted_tokens,
            rejected_at,
            acceptance_rate,
        }
    }

    /// Number of entries in `accepted_tokens`.
    pub fn num_accepted(&self) -> usize {
        self.accepted_tokens.len()
    }

    /// `true` when no rejection position was recorded.
    pub fn all_accepted(&self) -> bool {
        self.rejected_at.is_none()
    }
}

impl fmt::Display for VerificationResult {
    /// Writes the accepted count, the rejection position in `Debug` form
    /// (`Some(i)` / `None`) and the rate with two decimal places.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let accepted = self.num_accepted();
        write!(
            f,
            "VerificationResult(accepted={}, rejected_at={:?}, rate={:.2})",
            accepted, self.rejected_at, self.acceptance_rate
        )
    }
}
/// Running counters for a decoding session.
///
/// The struct only stores numbers; callers are responsible for updating
/// them. The derived `Default` (and [`DecodingStats::new`]) start every
/// counter at zero.
#[derive(Debug, Clone, Default)]
pub struct DecodingStats {
    /// Total token counter; numerator of [`DecodingStats::throughput`].
    pub total_tokens: usize,
    /// Drafted token counter; denominator of [`DecodingStats::acceptance_rate`].
    pub draft_tokens: usize,
    /// Accepted token counter; numerator of [`DecodingStats::acceptance_rate`].
    pub accepted_tokens: usize,
    /// Elapsed wall time (milliseconds, going by the field name and the
    /// `ms` suffix used when displaying it).
    pub wall_time_ms: f64,
    /// Tokens-per-step figure; stored and displayed, never computed here.
    pub tokens_per_step: f64,
}

impl DecodingStats {
    /// Creates a statistics record with every counter set to zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Ratio `accepted_tokens / draft_tokens`, or `0.0` when nothing has
    /// been drafted yet (avoids a division by zero).
    pub fn acceptance_rate(&self) -> f64 {
        match self.draft_tokens {
            0 => 0.0,
            drafted => self.accepted_tokens as f64 / drafted as f64,
        }
    }

    /// Ratio `total_tokens / wall_time_ms` — tokens per millisecond if
    /// `wall_time_ms` holds milliseconds as its name suggests.
    ///
    /// Returns `0.0` when the recorded wall time is zero or negative.
    pub fn throughput(&self) -> f64 {
        let elapsed = self.wall_time_ms;
        if elapsed <= 0.0 {
            return 0.0;
        }
        self.total_tokens as f64 / elapsed
    }
}

impl fmt::Display for DecodingStats {
    /// Writes a single-line summary: the three counters, the derived
    /// acceptance rate and tokens-per-step (two decimals each) and the wall
    /// time (one decimal).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rate = self.acceptance_rate();
        write!(
            f,
            "DecodingStats(total={}, drafted={}, accepted={}, rate={:.2}, tok/step={:.2}, time={:.1}ms)",
            self.total_tokens,
            self.draft_tokens,
            self.accepted_tokens,
            rate,
            self.tokens_per_step,
            self.wall_time_ms,
        )
    }
}
/// A normalized categorical distribution over token ids `0..vocab_size`.
///
/// Invariant: every constructor either returns `None` or produces a
/// non-empty vector of finite, non-negative probabilities that sum to 1 up
/// to floating-point rounding. The methods below rely on this (in
/// particular, on the absence of NaN).
#[derive(Debug, Clone)]
pub struct TokenDistribution {
    probs: Vec<f64>,
}

impl TokenDistribution {
    /// Builds a distribution from non-negative weights, normalizing them to
    /// sum to 1.
    ///
    /// Returns `None` when `probs` is empty, when any weight is negative,
    /// NaN or infinite, or when the total is zero or overflows to infinity.
    pub fn from_probs(probs: Vec<f64>) -> Option<Self> {
        if probs.is_empty() {
            return None;
        }
        // NaN compares false against everything, so it has to be rejected
        // explicitly; otherwise it would poison every normalized entry.
        if probs.iter().any(|&p| !p.is_finite() || p < 0.0) {
            return None;
        }
        let sum: f64 = probs.iter().sum();
        // Finite weights can still overflow when added up.
        if !sum.is_finite() || sum <= 0.0 {
            return None;
        }
        let normalized: Vec<f64> = probs.into_iter().map(|p| p / sum).collect();
        Some(Self { probs: normalized })
    }

    /// Builds the uniform distribution over `vocab_size` tokens, or `None`
    /// when `vocab_size` is zero.
    pub fn uniform(vocab_size: usize) -> Option<Self> {
        if vocab_size == 0 {
            return None;
        }
        let p = 1.0 / vocab_size as f64;
        Some(Self {
            probs: vec![p; vocab_size],
        })
    }

    /// Builds a distribution from (possibly unnormalized) log-probabilities
    /// using a max-shifted softmax.
    ///
    /// `f64::NEG_INFINITY` entries become probability zero. Returns `None`
    /// when the slice is empty, contains NaN or `+inf`, or has no finite
    /// entry at all.
    pub fn from_log_probs(log_probs: &[f64]) -> Option<Self> {
        // `f64::max` skips NaN operands, so the fold yields the largest
        // non-NaN entry, or -inf for an empty / all-NaN / all -inf slice.
        let max_lp = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        if !max_lp.is_finite() {
            return None;
        }
        // Shifting by the maximum keeps every exponent <= 0, so `exp`
        // cannot overflow and the largest entry contributes exactly 1.
        let exps: Vec<f64> = log_probs.iter().map(|&lp| (lp - max_lp).exp()).collect();
        let sum: f64 = exps.iter().sum();
        // A NaN entry anywhere in the input surfaces as a NaN sum.
        if !sum.is_finite() || sum <= 0.0 {
            return None;
        }
        let probs: Vec<f64> = exps.into_iter().map(|e| e / sum).collect();
        Some(Self { probs })
    }

    /// Number of tokens the distribution covers.
    pub fn vocab_size(&self) -> usize {
        self.probs.len()
    }

    /// Probability of `token_id`, or `0.0` when the id is out of range.
    pub fn prob(&self, token_id: usize) -> f64 {
        self.probs.get(token_id).copied().unwrap_or(0.0)
    }

    /// The normalized probabilities, indexed by token id.
    pub fn probs(&self) -> &[f64] {
        &self.probs
    }

    /// Returns the distribution re-scaled by `temperature`, i.e.
    /// proportional to `p^(1 / temperature)`.
    ///
    /// Returns `None` when `temperature` is NaN, zero or negative. A
    /// temperature within `1e-12` of `1.0` returns an unchanged clone.
    /// Tokens with zero probability stay at zero.
    pub fn with_temperature(&self, temperature: f64) -> Option<Self> {
        if temperature.is_nan() || temperature <= 0.0 {
            return None;
        }
        if (temperature - 1.0).abs() < 1e-12 {
            return Some(self.clone());
        }
        // Shift by the largest log-probability *before* dividing. The top
        // token then maps to exactly 0, so a very small temperature can no
        // longer push every entry to -inf and make the softmax fail.
        let ln_max = self.probs.iter().copied().fold(0.0, f64::max).ln();
        let log_probs: Vec<f64> = self
            .probs
            .iter()
            .map(|&p| {
                if p > 0.0 {
                    (p.ln() - ln_max) / temperature
                } else {
                    f64::NEG_INFINITY
                }
            })
            .collect();
        Self::from_log_probs(&log_probs)
    }

    /// Returns the distribution restricted to the `k` most probable tokens
    /// and renormalized.
    ///
    /// Tokens whose probability ties with the `k`-th largest are all kept,
    /// so more than `k` tokens may survive. Returns `None` when `k` is
    /// zero; returns an unchanged clone when `k >= vocab_size`.
    pub fn with_top_k(&self, k: usize) -> Option<Self> {
        if k == 0 {
            return None;
        }
        if k >= self.probs.len() {
            return Some(self.clone());
        }
        // Only the k-th largest value is needed, so a linear-time selection
        // replaces a full sort. `k - 1 < len` is guaranteed by the guard
        // above, and the type invariant rules out NaN in the comparator.
        let mut scratch: Vec<f64> = self.probs.clone();
        let (_, kth_largest, _) = scratch.select_nth_unstable_by(k - 1, |a, b| {
            b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)
        });
        let threshold = *kth_largest;
        let filtered: Vec<f64> = self
            .probs
            .iter()
            .map(|&p| if p >= threshold { p } else { 0.0 })
            .collect();
        Self::from_probs(filtered)
    }

    /// Draws a token by inverse-CDF sampling from a uniform variate `u`.
    ///
    /// `u` is clamped into `[0, 1 - f64::EPSILON]`. If rounding leaves the
    /// cumulative sum below `u` (or `u` is NaN), the last token with
    /// non-zero probability is returned, so a zero-probability token is
    /// never sampled.
    pub fn sample_with_uniform(&self, u: f64) -> usize {
        let u = u.clamp(0.0, 1.0 - f64::EPSILON);
        let mut cumulative = 0.0;
        let mut last_supported: Option<usize> = None;
        for (i, &p) in self.probs.iter().enumerate() {
            // Zero-mass tokens never move the cumulative sum, so they can
            // never be the crossing point; skip them.
            if p > 0.0 {
                last_supported = Some(i);
                cumulative += p;
                if u < cumulative {
                    return i;
                }
            }
        }
        last_supported.unwrap_or_else(|| self.probs.len().saturating_sub(1))
    }

    /// Index of the most probable token.
    ///
    /// When several tokens share the maximum, the highest index wins
    /// (`Iterator::max_by` returns the last of equal maxima).
    pub fn argmax(&self) -> usize {
        self.probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }
}

impl fmt::Display for TokenDistribution {
    /// Writes the vocabulary size, the argmax token and its probability
    /// (four decimal places).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let top = self.argmax();
        write!(
            f,
            "TokenDistribution(vocab={}, top_token={}, top_prob={:.4})",
            self.vocab_size(),
            top,
            self.prob(top),
        )
    }
}