use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::engine::InferenceEngine;
use crate::error::{RuntimeError, RuntimeResult};
use crate::sampling::{Sampler, SamplerConfig};
/// Failure modes of a [`Rewindable::rewind`] call.
#[derive(Debug, Error)]
pub enum RewindError {
/// The model cannot be rewound; per the message this covers SSM/recurrent
/// state models.
// NOTE(review): nothing in this file constructs this variant (it is only
// matched in `SpeculativeDecoder::generate`) — presumably produced by other
// `Rewindable` implementations; confirm.
#[error("rewind not supported for this model type (SSM/recurrent state)")]
NotSupported,
/// The requested position `target` is greater than the `current` length.
#[error("rewind target position {target} exceeds current length {current}")]
PositionBeyondEnd { target: usize, current: usize },
/// An underlying engine error; `#[from]` provides the
/// `From<RuntimeError>` conversion.
#[error("rewind runtime error: {0}")]
Runtime(#[from] RuntimeError),
}
/// A sequence state that can be rolled back to an earlier length.
pub trait Rewindable {
/// Rewinds the state using `n`.
///
/// The `InferenceEngine` implementation in this file treats `n` as the
/// length to keep (it rejects `n` greater than [`Self::current_length`] and
/// otherwise truncates to `n`); other implementors presumably follow the
/// same convention — confirm.
fn rewind(&mut self, n: usize) -> Result<(), RewindError>;
/// Returns the current sequence length.
fn current_length(&self) -> usize;
}
impl Rewindable for InferenceEngine {
    /// Truncates the KV cache so that it holds exactly `n` positions.
    ///
    /// # Errors
    ///
    /// * [`RewindError::PositionBeyondEnd`] when `n` is greater than the
    ///   current KV sequence length.
    /// * [`RewindError::Runtime`] when the underlying truncation fails.
    fn rewind(&mut self, n: usize) -> Result<(), RewindError> {
        let current = self.current_length();
        if n <= current {
            // `?` converts the `RuntimeError` through the `#[from]` impl on
            // `RewindError::Runtime`.
            self.truncate_kv_cache(n)?;
            Ok(())
        } else {
            Err(RewindError::PositionBeyondEnd { target: n, current })
        }
    }

    /// Reports the KV cache sequence length.
    fn current_length(&self) -> usize {
        self.kv_seq_len()
    }
}
/// Counters accumulated by a speculative decoding run.
#[derive(Debug, Default, Clone)]
pub struct SpecStats {
    /// Draft tokens that passed verification.
    pub accepted: u64,
    /// Draft tokens that failed verification.
    pub rejected: u64,
    /// Extra tokens sampled after a fully accepted proposal.
    pub bonus_tokens: u64,
    /// Accumulated wall-clock time.
    pub total_elapsed: Duration,
    /// Number of rewinds that reported "not supported".
    pub n1_fallbacks: u64,
}

impl SpecStats {
    /// Fraction of verified draft tokens that were accepted, in `[0, 1]`.
    ///
    /// Returns `0.0` when nothing has been verified yet.
    pub fn acceptance_rate(&self) -> f32 {
        match self.accepted + self.rejected {
            0 => 0.0,
            verified => self.accepted as f32 / verified as f32,
        }
    }

    /// Sum of accepted draft tokens and bonus tokens.
    pub fn total_output_tokens(&self) -> u64 {
        let from_draft = self.accepted;
        let from_bonus = self.bonus_tokens;
        from_draft + from_bonus
    }
}
/// One batch of draft tokens sent from the draft task to the verifier loop.
#[derive(Debug)]
struct DraftProposal {
// Token ids proposed by the draft engine, in generation order.
tokens: Vec<u32>,
// Softmax probability of each proposed token under the draft logits it was
// sampled from; parallel to `tokens`.
probs: Vec<f32>,
// Draft KV cache length recorded before the first token of this proposal
// was generated.
start_pos: usize,
}
/// Drives generation with a draft engine that proposes tokens and a target
/// engine that verifies them.
pub struct SpeculativeDecoder {
// Draft engine; shared behind a mutex because it is used from spawned
// blocking tasks as well as from `generate` itself.
draft: Arc<Mutex<InferenceEngine>>,
// Target (verifier) engine, owned exclusively by the decoder.
target: InferenceEngine,
// Decoding parameters fixed at construction.
config: AsyncSpecConfig,
// Root cancellation token; clones are handed out by `cancellation_token`.
cancel: CancellationToken,
// Counters updated by `generate`.
stats: SpecStats,
}
/// Configuration for [`SpeculativeDecoder`].
#[derive(Debug, Clone)]
pub struct AsyncSpecConfig {
/// Maximum number of draft tokens per proposal (ignored and treated as 1
/// when `force_n1` is set).
pub spec_k: usize,
/// Sampler configuration intended for the draft engine.
// NOTE(review): `generate` builds a sampler from this value but its drafting
// loop samples with a hard-coded greedy sampler — confirm whether this
// setting is meant to take effect.
pub draft_sampler: SamplerConfig,
/// Sampler configuration used when sampling from target logits.
pub target_sampler: SamplerConfig,
/// When `true`, `generate` drafts a single token per proposal.
pub force_n1: bool,
/// Upper bound on the number of tokens `generate` emits.
pub max_tokens: usize,
}
impl Default for AsyncSpecConfig {
    /// Defaults: four draft tokens per proposal, a greedy draft sampler, the
    /// default target sampler, single-token mode off, and a 512-token cap.
    fn default() -> Self {
        let draft_sampler = SamplerConfig::greedy();
        let target_sampler = SamplerConfig::default();
        Self {
            spec_k: 4,
            draft_sampler,
            target_sampler,
            force_n1: false,
            max_tokens: 512,
        }
    }
}
impl SpeculativeDecoder {
/// Creates a decoder from a draft engine, a target engine and a config.
///
/// The draft engine is wrapped in `Arc<Mutex<_>>`; the cancellation token and
/// statistics start fresh.
pub fn new(draft: InferenceEngine, target: InferenceEngine, config: AsyncSpecConfig) -> Self {
Self {
draft: Arc::new(Mutex::new(draft)),
target,
config,
cancel: CancellationToken::new(),
stats: SpecStats::default(),
}
}
/// Like [`Self::new`], but overrides `config.force_n1` to `true` regardless
/// of the value passed in; all other config fields are kept.
pub fn new_n1(
draft: InferenceEngine,
target: InferenceEngine,
config: AsyncSpecConfig,
) -> Self {
let cfg = AsyncSpecConfig {
force_n1: true,
..config
};
Self::new(draft, target, cfg)
}
/// Returns the statistics accumulated so far.
pub fn stats(&self) -> &SpecStats {
&self.stats
}
/// Resets all statistics to their default (zero) values.
pub fn reset_stats(&mut self) {
self.stats = SpecStats::default();
}
/// Returns a clone of the decoder's cancellation token.
pub fn cancellation_token(&self) -> CancellationToken {
self.cancel.clone()
}
/// Generates text for `prompt`, calling `on_token` with each decoded piece
/// as it is emitted, and returns the concatenation of all pieces.
///
/// Outline: both engines are prefilled with the prompt tokens, a background
/// task produces draft proposals over a bounded channel, and the loop below
/// verifies each proposed token against the target engine.
///
/// Returns `Ok(String::new())` without invoking `on_token` when the prompt
/// tokenizes to nothing.
///
/// # Errors
///
/// * `RuntimeError::ModelNotLoaded` if either engine reports it is not loaded.
/// * `RuntimeError::ModelLoadError` if the draft mutex is poisoned or the
///   draft prefill task fails to join.
/// * `RuntimeError::Cancelled` if cancellation is observed before any token
///   has been emitted (after that, the partial text is returned as `Ok`).
/// * Any error propagated from tokenization, prefill, `forward_one`,
///   `decode_token`, or a `RewindError::Runtime` from rewinding the target.
pub async fn generate<F>(&mut self, prompt: &str, mut on_token: F) -> RuntimeResult<String>
where
F: FnMut(&str) + Send + 'static,
{
let started_at = Instant::now();
if !self.target.is_loaded() {
return Err(RuntimeError::ModelNotLoaded);
}
// Scope the guard so the draft lock is released before any `.await`.
{
let draft_guard = self
.draft
.lock()
.map_err(|_| RuntimeError::ModelLoadError {
message: "draft engine mutex poisoned".to_string(),
})?;
if !draft_guard.is_loaded() {
return Err(RuntimeError::ModelNotLoaded);
}
}
// Single-token mode overrides the configured proposal length.
let use_n1 = self.config.force_n1;
let spec_k = if use_n1 { 1 } else { self.config.spec_k };
let max_tokens = self.config.max_tokens;
let prompt_tokens = self.target.tokenize(prompt)?;
if prompt_tokens.is_empty() {
return Ok(String::new());
}
self.target.prefill(&prompt_tokens)?;
// Prefill the draft engine via `spawn_blocking`, so the std mutex is
// locked inside the blocking closure rather than in this async fn.
{
let draft = Arc::clone(&self.draft);
let pt = prompt_tokens.clone();
tokio::task::spawn_blocking(move || {
let mut d = draft.lock().map_err(|_| RuntimeError::ModelLoadError {
message: "draft mutex poisoned during prefill".to_string(),
})?;
d.prefill(&pt)
})
.await
.map_err(|e| RuntimeError::ModelLoadError {
message: format!("draft prefill task panicked: {e}"),
})??;
}
let mut output_text = String::new();
let mut generated = 0usize;
let mut target_sampler = Sampler::new(self.config.target_sampler.clone());
let mut recent_tokens = prompt_tokens.clone();
// Bounded channel: at most two proposals can be queued ahead of the verifier.
let (proposal_tx, mut proposal_rx) = mpsc::channel::<DraftProposal>(2);
// The draft task observes a child of `self.cancel` plus a local stop flag
// that the verifier sets on every exit path of the loop below.
let cancel_child = self.cancel.child_token();
let draft_arc = Arc::clone(&self.draft);
let draft_sampler_cfg = self.config.draft_sampler.clone();
let cancel_draft = cancel_child.clone();
let stop_flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
let stop_flag_draft = Arc::clone(&stop_flag);
// Background draft task. Its JoinHandle is dropped, so it is never awaited.
tokio::task::spawn(async move {
// NOTE(review): `_draft_sampler` is built from the configured draft sampler
// but never used; the sampling call further down constructs a greedy
// sampler instead — confirm whether `draft_sampler` should be honoured.
let _draft_sampler = Sampler::new(draft_sampler_cfg);
// NOTE(review): `draft_recent` is never pushed to, so every proposal starts
// with an empty history and the first `last` below falls back to token 0.
// Looks unintended — confirm against `forward_one`'s contract.
let draft_recent: Vec<u32> = Vec::new();
loop {
if cancel_draft.is_cancelled()
|| stop_flag_draft.load(std::sync::atomic::Ordering::Relaxed)
{
break;
}
let draft_arc2 = Arc::clone(&draft_arc);
let spec_k_local = spec_k;
let recent_clone = draft_recent.clone();
// Draft up to `spec_k` tokens inside a blocking closure while holding
// the draft lock.
let proposal = tokio::task::spawn_blocking(move || {
let mut d = draft_arc2
.lock()
.map_err(|_| RuntimeError::ModelLoadError {
message: "draft mutex poisoned in draft task".to_string(),
})?;
let start_pos = d.kv_seq_len();
let mut tokens = Vec::with_capacity(spec_k_local);
let mut probs = Vec::with_capacity(spec_k_local);
let mut recent = recent_clone;
for _ in 0..spec_k_local {
// Stop drafting once the draft context window is full.
if d.kv_seq_len() >= d.max_ctx_len() {
break;
}
// Feed the previously drafted token, else the last history token, else 0.
let last = tokens
.last()
.copied()
.or_else(|| recent.last().copied())
.unwrap_or(0);
let logits = d.forward_one(last)?;
let tok = Sampler::new(SamplerConfig::greedy()).sample(&logits, &recent);
let prob = softmax_prob(&logits, tok);
tokens.push(tok);
probs.push(prob);
recent.push(tok);
}
Ok::<DraftProposal, RuntimeError>(DraftProposal {
tokens,
probs,
start_pos,
})
})
.await;
// A join failure, an engine error, or an empty proposal all end the draft
// task; the error value itself is discarded here.
match proposal {
Ok(Ok(p)) if !p.tokens.is_empty() => {
if proposal_tx.send(p).await.is_err() {
break;
}
}
_ => break,
}
}
});
// Verifier loop: consume proposals until cancelled, capped, or starved.
'outer: loop {
if self.cancel.is_cancelled() {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
if generated == 0 {
return Err(RuntimeError::Cancelled);
}
break;
}
if generated >= max_tokens {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
break;
}
// Wait at most 500 ms for the next proposal. A timeout or a closed channel
// ends generation with whatever text exists so far; it is not an error.
let proposal =
tokio::time::timeout(Duration::from_millis(500), proposal_rx.recv()).await;
let proposal = match proposal {
Ok(Some(p)) => p,
_ => {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
break;
}
};
let mut diverged_at: Option<usize> = None;
let mut last_target_logits: Vec<f32> = Vec::new();
for (i, (&draft_tok, &draft_prob)) in proposal
.tokens
.iter()
.zip(proposal.probs.iter())
.enumerate()
{
// NOTE(review): `generated` already grows by one per accepted token inside
// this loop, so `generated + i` counts those tokens twice and can stop
// before `max_tokens` is actually reached — confirm intended.
if generated + i >= max_tokens {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
break 'outer;
}
let tgt_logits = match self.target.forward_one(draft_tok) {
Ok(l) => l,
Err(e) => {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
return Err(e);
}
};
// NOTE(review): the probability of `draft_tok` is read from the logits
// returned by feeding `draft_tok` itself; whether those are the logits
// that predict `draft_tok` depends on `forward_one` — confirm.
let target_prob = softmax_prob(&tgt_logits, draft_tok);
let accept = accept_draft_token(target_prob, draft_prob);
if accept {
let text = match self.target.decode_token(draft_tok) {
Ok(t) => t,
Err(e) => {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
return Err(e);
}
};
on_token(&text);
output_text.push_str(&text);
recent_tokens.push(draft_tok);
self.stats.accepted += 1;
generated += 1;
if self.target.is_eos(draft_tok) || generated >= max_tokens {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
break 'outer;
}
last_target_logits = tgt_logits;
} else {
// First rejection: record the position and stop checking this proposal.
self.stats.rejected += 1;
diverged_at = Some(proposal.start_pos + i);
last_target_logits = tgt_logits;
break;
}
}
// Whole proposal accepted: sample one extra ("bonus") token from the last
// target logits.
if diverged_at.is_none() && !last_target_logits.is_empty() {
let bonus = target_sampler.sample(&last_target_logits, &recent_tokens);
let text = match self.target.decode_token(bonus) {
Ok(t) => t,
Err(e) => {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
return Err(e);
}
};
on_token(&text);
output_text.push_str(&text);
recent_tokens.push(bonus);
self.stats.bonus_tokens += 1;
generated += 1;
if self.target.is_eos(bonus) || generated >= max_tokens {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
break;
}
}
// A draft token was rejected: emit a replacement sampled from the target
// logits, then rewind both engines.
if let Some(rewind_pos) = diverged_at {
let residual_tok = target_sampler.sample(&last_target_logits, &recent_tokens);
let text = match self.target.decode_token(residual_tok) {
Ok(t) => t,
Err(e) => {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
return Err(e);
}
};
on_token(&text);
output_text.push_str(&text);
recent_tokens.push(residual_tok);
// NOTE(review): this token increments `generated` but no `SpecStats`
// counter, so `total_output_tokens()` does not include it — confirm.
generated += 1;
if self.target.is_eos(residual_tok) || generated >= max_tokens {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
break;
}
// NOTE(review): `rewind_pos` is derived from the draft's `start_pos`, and
// `+ 1` keeps the position of the rejected token; verify this is the
// intended cache length for both engines.
let new_len = rewind_pos + 1;
match self.target.rewind(new_len) {
Ok(()) => {}
Err(RewindError::NotSupported) => {
self.stats.n1_fallbacks += 1;
}
Err(RewindError::PositionBeyondEnd { .. }) => {
// Ignored: the target is left at its current length.
}
Err(RewindError::Runtime(e)) => {
stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
return Err(e);
}
}
// Best-effort draft rewind in a blocking closure; a poisoned lock, a
// rewind error, or a join failure are all ignored.
let draft_arc2 = Arc::clone(&self.draft);
let rewind_to = new_len;
let _ = tokio::task::spawn_blocking(move || {
let mut d = draft_arc2.lock().ok()?;
let _ = d.rewind(rewind_to);
Some(())
})
.await;
}
}
// Only this normal-exit path records elapsed time; the early `return`s above
// skip it.
self.stats.total_elapsed += started_at.elapsed();
Ok(output_text)
}
}
/// Returns the softmax probability that `logits` assigns to `token_id`.
///
/// The maximum logit is subtracted before exponentiating so large logits do
/// not overflow.
///
/// Returns `0.0` instead of a meaningless value when:
/// * `token_id` is outside `logits` (this includes empty `logits`),
/// * the normaliser is smaller than `1e-9`, or
/// * the normaliser is not finite, which happens when every logit is `-inf`
///   or when any logit is `+inf` or NaN. Previously these inputs produced
///   NaN, because `NaN < 1e-9` is false and the guard was skipped.
fn softmax_prob(logits: &[f32], token_id: u32) -> f32 {
    let logit = match logits.get(token_id as usize) {
        Some(&value) => value,
        None => return 0.0,
    };
    // `f32::max` ignores NaN operands, so `max` is NaN-free; NaN logits are
    // caught through the sum below instead.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Single pass, no temporary buffer: only the sum and one numerator are needed.
    let sum: f32 = logits.iter().map(|&l| (l - max).exp()).sum();
    if !sum.is_finite() || sum < 1e-9 {
        return 0.0;
    }
    (logit - max).exp() / sum
}
/// Decides whether a draft token is accepted, given the probability the
/// target (`p_target`) and the draft (`p_draft`) assigned to it.
///
/// * A draft probability below `1e-9` always rejects.
/// * `p_target >= p_draft` always accepts.
/// * Otherwise the token is accepted when a deterministic pseudo-uniform
///   value derived from the two probabilities falls below
///   `p_target / p_draft`.
fn accept_draft_token(p_target: f32, p_draft: f32) -> bool {
    let draft_is_negligible = p_draft < 1e-9;
    if draft_is_negligible {
        false
    } else if p_target >= p_draft {
        true
    } else {
        let acceptance_ratio = p_target / p_draft;
        pseudo_uniform(p_target, p_draft) < acceptance_ratio
    }
}
/// Deterministically maps two floats to a value in `[0, 1]`.
///
/// The bit patterns of `a` and `b` are each multiplied by a fixed constant
/// (wrapping), added (wrapping), and the resulting `u32` is scaled by
/// `u32::MAX`. The same inputs always produce the same output; this is a
/// hash, not a random number generator.
fn pseudo_uniform(a: f32, b: f32) -> f32 {
    const A_MULTIPLIER: u32 = 2_654_435_761;
    const B_MULTIPLIER: u32 = 40_503;

    let mixed_a = a.to_bits().wrapping_mul(A_MULTIPLIER);
    let mixed_b = b.to_bits().wrapping_mul(B_MULTIPLIER);
    let combined = mixed_a.wrapping_add(mixed_b);

    (combined as f32) / (u32::MAX as f32)
}
/// Private adapter trait giving this module uniform names for the engine
/// operations it needs.
trait InferenceEngineExt {
/// Truncates the KV cache using `n`; the implementation in this file
/// forwards `n` unchanged to the engine's `truncate`.
fn truncate_kv_cache(&mut self, n: usize) -> RuntimeResult<()>;
/// Returns the current KV cache sequence length.
fn kv_seq_len(&self) -> usize;
/// Returns the maximum context length.
fn max_ctx_len(&self) -> usize;
}
impl InferenceEngineExt for InferenceEngine {
    /// Forwards to the engine's `truncate`.
    fn truncate_kv_cache(&mut self, n: usize) -> RuntimeResult<()> {
        self.truncate(n)
    }

    /// Forwards to the engine's `kv_cache_seq_len`.
    fn kv_seq_len(&self) -> usize {
        self.kv_cache_seq_len()
    }

    /// Reads `max_context_length` from the model config, falling back to
    /// 4096 when no config is available.
    fn max_ctx_len(&self) -> usize {
        const FALLBACK_CONTEXT_LEN: usize = 4096;
        let configured = self.model_config().map(|cfg| cfg.max_context_length);
        configured.unwrap_or(FALLBACK_CONTEXT_LEN)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the statistics helpers, the probability helpers, the
    //! error display strings, and decoder construction. The tests that need a
    //! real model are gated behind the tokenizer features.

    use super::*;

    #[test]
    fn spec_stats_acceptance_rate_empty() {
        let s = SpecStats::default();
        assert!(
            (s.acceptance_rate() - 0.0).abs() < 1e-6,
            "empty stats must return 0.0 acceptance rate"
        );
    }

    #[test]
    fn spec_stats_acceptance_rate_all_accepted() {
        let s = SpecStats {
            accepted: 10,
            rejected: 0,
            ..SpecStats::default()
        };
        assert!(
            (s.acceptance_rate() - 1.0).abs() < 1e-6,
            "all-accepted must return 1.0"
        );
    }

    #[test]
    fn spec_stats_acceptance_rate_half() {
        let s = SpecStats {
            accepted: 5,
            rejected: 5,
            ..SpecStats::default()
        };
        assert!(
            (s.acceptance_rate() - 0.5).abs() < 1e-6,
            "half accepted must return 0.5"
        );
    }

    #[test]
    fn spec_stats_total_output_tokens() {
        let s = SpecStats {
            accepted: 8,
            bonus_tokens: 2,
            ..SpecStats::default()
        };
        assert_eq!(s.total_output_tokens(), 10);
    }

    #[test]
    fn softmax_prob_uniform_logits() {
        let logits = vec![1.0f32; 4];
        let p = softmax_prob(&logits, 0);
        assert!(
            (p - 0.25).abs() < 1e-5,
            "uniform logits must produce p=0.25 for any token, got {p}"
        );
    }

    #[test]
    fn softmax_prob_out_of_range_returns_zero() {
        let logits = vec![1.0f32; 4];
        let p = softmax_prob(&logits, 99);
        assert_eq!(p, 0.0, "out-of-range token must return 0.0");
    }

    #[test]
    fn softmax_prob_large_positive_logit() {
        let mut logits = vec![0.0f32; 8];
        logits[3] = 100.0;
        let p = softmax_prob(&logits, 3);
        assert!(
            p > 0.99,
            "dominant logit must produce near-1 probability, got {p}"
        );
    }

    #[test]
    fn accept_draft_token_always_accepts_when_target_ge_draft() {
        assert!(
            accept_draft_token(0.9, 0.5),
            "p_target=0.9 >= p_draft=0.5 must always accept"
        );
        assert!(
            accept_draft_token(0.5, 0.5),
            "p_target==p_draft must always accept"
        );
    }

    #[test]
    fn accept_draft_token_never_accepts_zero_draft_prob() {
        assert!(
            !accept_draft_token(0.5, 0.0),
            "zero draft prob must always reject"
        );
    }

    #[test]
    fn async_spec_config_defaults() {
        let cfg = AsyncSpecConfig::default();
        assert_eq!(cfg.spec_k, 4, "default spec_k must be 4");
        assert!(!cfg.force_n1, "force_n1 must be false by default");
        assert_eq!(cfg.max_tokens, 512);
    }

    #[test]
    fn rewind_error_not_supported_display() {
        let e = RewindError::NotSupported;
        let s = e.to_string();
        assert!(
            s.contains("not supported"),
            "NotSupported display must contain 'not supported', got: {s}"
        );
    }

    #[test]
    fn rewind_error_position_beyond_end_display() {
        let e = RewindError::PositionBeyondEnd {
            target: 10,
            current: 5,
        };
        let s = e.to_string();
        assert!(
            s.contains("10") && s.contains("5"),
            "display must include positions, got: {s}"
        );
    }

    #[test]
    fn spec_decode_construction_with_unloaded_engines() {
        use crate::engine::EngineConfig;
        let draft = InferenceEngine::new(EngineConfig::default());
        let target = InferenceEngine::new(EngineConfig::default());
        let decoder = SpeculativeDecoder::new(draft, target, AsyncSpecConfig::default());
        assert_eq!(decoder.stats().accepted, 0);
        assert_eq!(decoder.stats().rejected, 0);
    }

    #[tokio::test]
    async fn spec_decode_correctness_stub() {
        use crate::engine::EngineConfig;
        let draft = InferenceEngine::new(EngineConfig::default());
        let target = InferenceEngine::new(EngineConfig::default());
        let mut decoder = SpeculativeDecoder::new(draft, target, AsyncSpecConfig::default());
        let result = decoder.generate("hello", |_| {}).await;
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded for unloaded decoder, got {result:?}"
        );
    }

    #[test]
    fn spec_decode_divergence_rollback() {
        use crate::engine::EngineConfig;
        let draft = InferenceEngine::new(EngineConfig::default());
        let target = InferenceEngine::new(EngineConfig::default());
        let cfg = AsyncSpecConfig {
            force_n1: true,
            ..AsyncSpecConfig::default()
        };
        let mut decoder = SpeculativeDecoder::new_n1(draft, target, cfg);
        decoder.reset_stats();
        let stats = decoder.stats();
        assert_eq!(stats.accepted, 0);
        assert_eq!(stats.n1_fallbacks, 0);
    }

    #[test]
    fn spec_decode_ssm_falls_back() {
        use crate::engine::EngineConfig;
        let draft = InferenceEngine::new(EngineConfig::default());
        let target = InferenceEngine::new(EngineConfig::default());
        let decoder = SpeculativeDecoder::new_n1(
            draft,
            target,
            AsyncSpecConfig {
                force_n1: true,
                spec_k: 1,
                ..AsyncSpecConfig::default()
            },
        );
        assert!(
            decoder.config.force_n1,
            "force_n1 must be true when constructed with new_n1"
        );
        assert_eq!(decoder.config.spec_k, 1);
    }

    #[test]
    fn cancellation_token_child_relationship() {
        use crate::engine::EngineConfig;
        let draft = InferenceEngine::new(EngineConfig::default());
        let target = InferenceEngine::new(EngineConfig::default());
        let decoder = SpeculativeDecoder::new(draft, target, AsyncSpecConfig::default());
        let token = decoder.cancellation_token();
        assert!(
            !token.is_cancelled(),
            "token must not be cancelled initially"
        );
        // The handle is a clone of the decoder's own token, so cancelling it
        // must be visible through the decoder as well.
        token.cancel();
        assert!(
            decoder.cancellation_token().is_cancelled(),
            "cancelling the returned handle must cancel the decoder's token"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[tokio::test]
    async fn spec_decode_loaded_engines_produce_output() {
        use crate::engine::EngineConfig;
        let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
        let tok_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
        let mut draft_eng = InferenceEngine::new(EngineConfig::default());
        draft_eng
            .load_model_from_bytes(&model_bytes, tok_json)
            .expect("draft load");
        let mut target_eng = InferenceEngine::new(EngineConfig::default());
        target_eng
            .load_model_from_bytes(&model_bytes, tok_json)
            .expect("target load");
        let max_tokens = 4usize;
        let cfg = AsyncSpecConfig {
            spec_k: 2,
            max_tokens,
            ..AsyncSpecConfig::default()
        };
        let mut decoder = SpeculativeDecoder::new(draft_eng, target_eng, cfg);

        // Record every streamed piece so the callback contract can be checked
        // against the returned text.
        let streamed = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
        let sink = std::sync::Arc::clone(&streamed);
        let result = decoder
            .generate("a", move |piece| {
                if let Ok(mut pieces) = sink.lock() {
                    pieces.push(piece.to_owned());
                }
            })
            .await;

        // A runtime error from the minimal model is tolerated (as before), but
        // a successful run must satisfy the invariants of `generate`.
        if let Ok(text) = &result {
            let pieces = streamed.lock().expect("callback sink mutex poisoned");
            assert!(
                pieces.len() <= max_tokens,
                "generate emitted {} tokens, more than max_tokens={max_tokens}",
                pieces.len()
            );
            assert_eq!(
                pieces.concat(),
                *text,
                "returned text must equal the concatenation of streamed pieces"
            );
        }
    }
}