oxillama_runtime/speculative_async.rs
1//! Drafter-async speculative decoding.
2//!
3//! # Overview
4//!
//! This module provides an *async* speculative decoding loop where the draft
//! model runs ahead of the target model in a separate `tokio` task. While the
//! target is verifying one batch of `K` candidate tokens, the drafter is
//! already generating the next batch, giving real wall-clock overlap.
9//!
10//! ## Architecture
11//!
12//! ```text
13//! ┌──────────────────┐ ┌──────────────────┐
14//! │ DraftTask │ ───► │ TargetTask │
15//! │ generate N tok │ │ verify N tok │
16//! │ (async, ahead) │ ◄─── │ (accept/reject) │
17//! └──────────────────┘ └──────────────────┘
18//! │ │
19//! └─── CancellationToken ──────┘
20//! ```
21//!
//! On divergence the target calls `rewind(n)` to truncate its KV cache to the
//! divergence point, then resumes from there. SSM-based targets cannot rewind:
//! [`Rewindable::rewind`] returns [`RewindError::NotSupported`], which the
//! decoder records in [`SpecStats::n1_fallbacks`]. Such targets should be
//! verified a single token at a time (N=1); see [`AsyncSpecConfig::force_n1`].
26//!
27//! ## Stats
28//!
29//! [`SpecStats`] accumulates per-generation acceptance counts and exposes the
30//! token-level acceptance rate so callers can decide whether async spec-decode
31//! is worth the overhead (acceptance < 30% → disable recommendation).
32//!
33//! ## Cancellation
34//!
35//! A `tokio_util::sync::CancellationToken` is shared between the draft task
36//! and the target's verification loop. When the target detects EOS or max
37//! tokens it cancels the token; the drafter shuts down cleanly within one
38//! iteration.
39//!
40//! ## Note on `InferenceEngine` thread safety
41//!
//! Forward passes on an `InferenceEngine` are synchronous and compute-bound,
//! so the async drafter task runs the draft engine inside
//! `tokio::task::spawn_blocking` and communicates results back via an `mpsc`
//! channel. The engine is shared with the blocking worker as
//! `Arc<Mutex<InferenceEngine>>`, which requires `InferenceEngine: Send`
//! (including its `Box<dyn ForwardPass>` for every architecture).
46//!
47//! ## Relation to `speculative.rs`
48//!
49//! The existing [`speculative`](crate::speculative) module contains the
50//! synchronous `SpeculativeEngine` and associated tests. This module is
51//! additive — it does **not** modify or replace that code. Callers may use
52//! either API; the async variant provides higher throughput at the cost of
53//! more complex cancellation and state management.
54
55use std::sync::{Arc, Mutex};
56use std::time::{Duration, Instant};
57use thiserror::Error;
58use tokio::sync::mpsc;
59use tokio_util::sync::CancellationToken;
60
61use crate::engine::InferenceEngine;
62use crate::error::{RuntimeError, RuntimeResult};
63use crate::sampling::{Sampler, SamplerConfig};
64
65// ─── Rewindable trait ─────────────────────────────────────────────────────────
66
/// An error returned when a rewind operation is not possible.
///
/// Produced by [`Rewindable::rewind`]. The `Runtime` variant wraps a
/// [`RuntimeError`] and can be created from one via `From`/`?`.
#[derive(Debug, Error)]
pub enum RewindError {
    /// The backend does not support rewinding (e.g. SSM recurrent states).
    ///
    /// The caller should fall back to N=1 verification when this is returned.
    #[error("rewind not supported for this model type (SSM/recurrent state)")]
    NotSupported,
    /// The requested position is beyond the current sequence length.
    ///
    /// `target` is the position that was requested and `current` is the
    /// sequence length at the time of the call.
    #[error("rewind target position {target} exceeds current length {current}")]
    PositionBeyondEnd { target: usize, current: usize },
    /// An I/O or runtime error prevented the rewind.
    #[error("rewind runtime error: {0}")]
    Runtime(#[from] RuntimeError),
}
82
/// Capability for truncating a sequence to an earlier position.
///
/// Implemented by KV-cache-backed engines: `rewind(n)` truncates the cache to
/// `n` tokens. SSM-based engines return [`RewindError::NotSupported`],
/// causing the speculative decoder to fall back to N=1 verification mode.
pub trait Rewindable {
    /// Truncate the model state so that the next token generated is at
    /// position `n` (0-indexed).
    ///
    /// After a successful rewind the engine behaves as if only `n` tokens have
    /// been processed: the KV cache has `n` entries, the position counter is
    /// `n`, etc.
    ///
    /// # Errors
    ///
    /// - [`RewindError::NotSupported`] for SSM/recurrent models.
    /// - [`RewindError::PositionBeyondEnd`] if `n` > current sequence length.
    /// - [`RewindError::Runtime`] if the underlying engine reports a failure
    ///   while truncating.
    fn rewind(&mut self, n: usize) -> Result<(), RewindError>;

    /// Return the current sequence length (= number of tokens in the KV
    /// cache or SSM state).
    ///
    /// This is the upper bound accepted by [`Rewindable::rewind`].
    fn current_length(&self) -> usize;
}
106
107/// [`Rewindable`] implementation for [`InferenceEngine`].
108///
109/// Delegates to the engine's internal KV cache [`truncate`](crate::kv_cache::KvCache::truncate)
110/// method. If the engine has no loaded model (and thus no KV cache) the
111/// method returns `RuntimeError::ModelNotLoaded` wrapped in `RewindError::Runtime`.
112impl Rewindable for InferenceEngine {
113 fn rewind(&mut self, n: usize) -> Result<(), RewindError> {
114 let current = self.current_length();
115 if n > current {
116 return Err(RewindError::PositionBeyondEnd { target: n, current });
117 }
118 // Delegate to the KV cache truncate method.
119 self.truncate_kv_cache(n).map_err(RewindError::Runtime)
120 }
121
122 fn current_length(&self) -> usize {
123 self.kv_seq_len()
124 }
125}
126
127// ─── SpecStats ────────────────────────────────────────────────────────────────
128
/// Acceptance statistics gathered by the async speculative decoder.
///
/// The verification loop bumps these counters as draft tokens are accepted
/// or rejected; the helper methods derive summary figures from them.
#[derive(Debug, Default, Clone)]
pub struct SpecStats {
    /// Draft tokens the target accepted.
    pub accepted: u64,
    /// Draft tokens the target rejected.
    pub rejected: u64,
    /// Bonus tokens sampled directly from the target (one per
    /// full-acceptance batch).
    pub bonus_tokens: u64,
    /// Total wall-clock time spent in the async decoder.
    pub total_elapsed: Duration,
    /// Number of times the decoder fell back to N=1 mode (SSM target).
    pub n1_fallbacks: u64,
}

impl SpecStats {
    /// Fraction of evaluated draft tokens that were accepted, in [0.0, 1.0].
    ///
    /// Yields 0.0 when nothing has been evaluated yet.
    pub fn acceptance_rate(&self) -> f32 {
        match self.accepted + self.rejected {
            0 => 0.0,
            evaluated => self.accepted as f32 / evaluated as f32,
        }
    }

    /// Sum of accepted draft tokens and bonus tokens.
    pub fn total_output_tokens(&self) -> u64 {
        let Self {
            accepted,
            bonus_tokens,
            ..
        } = self;
        accepted + bonus_tokens
    }
}
165
166// ─── DraftProposal ────────────────────────────────────────────────────────────
167
/// A batch of `K` candidate tokens produced by the draft model.
///
/// Sent from the drafter task to the verification loop over an `mpsc`
/// channel. `tokens` and `probs` are filled in lock-step, one entry per
/// drafted token; a batch may hold fewer than `K` entries when the draft
/// engine reaches its maximum context length.
#[derive(Debug)]
struct DraftProposal {
    /// The candidate token IDs in generation order.
    tokens: Vec<u32>,
    /// Draft model's softmax probability of each chosen token, position by
    /// position (used by the accept/reject rule).
    probs: Vec<f32>,
    /// The draft engine's KV-cache length read just before the first token of
    /// this proposal was generated.
    start_pos: usize,
}
178
179// ─── SpeculativeDecoder ───────────────────────────────────────────────────────
180
/// Async speculative decoder.
///
/// Wraps a draft engine (generating `spec_k` candidates per step) and a target
/// engine (verifying the candidates with one `forward_one` call per
/// candidate). The two engines run with overlap via `tokio`.
///
/// # Limitations
///
/// - Both engines must use the same tokenizer and vocabulary.
/// - The draft engine must be strictly smaller/faster than the target.
/// - The target engine must implement [`Rewindable`] (KV-cache based). For
///   SSM targets use [`SpeculativeDecoder::new_n1`] which forces N=1 mode.
///
/// # Example
///
/// ```ignore
/// let mut decoder = SpeculativeDecoder::new(
///     draft_engine,
///     target_engine,
///     AsyncSpecConfig { max_tokens: 128, ..AsyncSpecConfig::default() },
/// );
/// let text = decoder.generate("hello", |tok| print!("{tok}")).await?;
/// let stats = decoder.stats();
/// ```
pub struct SpeculativeDecoder {
    /// Draft engine wrapped in `Arc<Mutex>` so it can be moved to a
    /// `spawn_blocking` worker.
    draft: Arc<Mutex<InferenceEngine>>,
    /// Target engine owned directly (verification runs on the caller's task).
    target: InferenceEngine,
    /// Speculative decoding configuration.
    config: AsyncSpecConfig,
    /// Root cancellation token; `generate` derives a child token from it for
    /// the draft task, and `cancellation_token` hands out clones of it.
    cancel: CancellationToken,
    /// Statistics accumulated across `generate` calls until `reset_stats`.
    stats: SpecStats,
}
217
/// Configuration for the async speculative decoder.
#[derive(Debug, Clone)]
pub struct AsyncSpecConfig {
    /// Number of draft tokens to generate per speculation step (K).
    ///
    /// Higher values increase potential throughput but also increase the cost
    /// of verification and rollback on divergence. A value of 4–8 is typical.
    /// Ignored (treated as 1) when `force_n1` is set.
    pub spec_k: usize,
    /// Sampler configuration applied by the draft engine.
    pub draft_sampler: SamplerConfig,
    /// Sampler configuration applied by the target engine for verification
    /// and residual sampling.
    pub target_sampler: SamplerConfig,
    /// Force N=1 verification mode regardless of target model type.
    ///
    /// Set this to `true` when the target model is SSM-based (cannot rewind).
    pub force_n1: bool,
    /// Maximum number of tokens to generate.
    ///
    /// The generation loop counts only emitted (output) tokens against this
    /// limit; prompt tokens are not included.
    pub max_tokens: usize,
}
238
239impl Default for AsyncSpecConfig {
240 fn default() -> Self {
241 Self {
242 spec_k: 4,
243 draft_sampler: SamplerConfig::greedy(),
244 target_sampler: SamplerConfig::default(),
245 force_n1: false,
246 max_tokens: 512,
247 }
248 }
249}
250
251impl SpeculativeDecoder {
252 /// Construct a new async speculative decoder.
253 ///
254 /// Both engines must be loaded (i.e. `is_loaded()` is true) before
255 /// `generate` is called.
256 pub fn new(draft: InferenceEngine, target: InferenceEngine, config: AsyncSpecConfig) -> Self {
257 Self {
258 draft: Arc::new(Mutex::new(draft)),
259 target,
260 config,
261 cancel: CancellationToken::new(),
262 stats: SpecStats::default(),
263 }
264 }
265
266 /// Construct a decoder that always uses N=1 mode (for SSM targets).
267 pub fn new_n1(
268 draft: InferenceEngine,
269 target: InferenceEngine,
270 config: AsyncSpecConfig,
271 ) -> Self {
272 let cfg = AsyncSpecConfig {
273 force_n1: true,
274 ..config
275 };
276 Self::new(draft, target, cfg)
277 }
278
279 /// Return the accumulated statistics from all `generate` calls.
280 pub fn stats(&self) -> &SpecStats {
281 &self.stats
282 }
283
284 /// Reset statistics counters.
285 pub fn reset_stats(&mut self) {
286 self.stats = SpecStats::default();
287 }
288
289 /// Return a reference to the cancellation token for external cancellation.
290 pub fn cancellation_token(&self) -> CancellationToken {
291 self.cancel.clone()
292 }
293
294 /// Run async speculative generation for `prompt`, calling `on_token` for
295 /// each decoded token.
296 ///
297 /// Returns the full generated text and updates `self.stats`.
298 ///
299 /// # SSM fallback
300 ///
301 /// If the target engine's `rewind()` returns `RewindError::NotSupported`
302 /// on the first call, the decoder automatically falls back to N=1 mode for
303 /// the rest of the generation. `SpecStats::n1_fallbacks` is incremented.
304 ///
305 /// # Cancellation
306 ///
307 /// The generation loop checks `self.cancel` after each speculation step.
308 /// Callers can cancel by calling `cancel.cancel()` from another task.
309 ///
310 /// # Errors
311 ///
312 /// Returns `RuntimeError::ModelNotLoaded` if either engine is not loaded.
313 /// Returns `RuntimeError::Cancelled` if the cancellation token is
314 /// triggered before the first token is produced.
315 pub async fn generate<F>(&mut self, prompt: &str, mut on_token: F) -> RuntimeResult<String>
316 where
317 F: FnMut(&str) + Send + 'static,
318 {
319 let started_at = Instant::now();
320
321 // ── Validate both engines are loaded ──────────────────────────────────
322 if !self.target.is_loaded() {
323 return Err(RuntimeError::ModelNotLoaded);
324 }
325 {
326 let draft_guard = self
327 .draft
328 .lock()
329 .map_err(|_| RuntimeError::ModelLoadError {
330 message: "draft engine mutex poisoned".to_string(),
331 })?;
332 if !draft_guard.is_loaded() {
333 return Err(RuntimeError::ModelNotLoaded);
334 }
335 }
336
337 let use_n1 = self.config.force_n1;
338 let spec_k = if use_n1 { 1 } else { self.config.spec_k };
339 let max_tokens = self.config.max_tokens;
340
341 // ── Tokenize the prompt ───────────────────────────────────────────────
342 let prompt_tokens = self.target.tokenize(prompt)?;
343 if prompt_tokens.is_empty() {
344 return Ok(String::new());
345 }
346
347 // ── Prefill both engines ──────────────────────────────────────────────
348 // Target prefill (inline).
349 self.target.prefill(&prompt_tokens)?;
350
351 // Draft prefill (in blocking task to avoid blocking the async runtime).
352 {
353 let draft = Arc::clone(&self.draft);
354 let pt = prompt_tokens.clone();
355 tokio::task::spawn_blocking(move || {
356 let mut d = draft.lock().map_err(|_| RuntimeError::ModelLoadError {
357 message: "draft mutex poisoned during prefill".to_string(),
358 })?;
359 d.prefill(&pt)
360 })
361 .await
362 .map_err(|e| RuntimeError::ModelLoadError {
363 message: format!("draft prefill task panicked: {e}"),
364 })??;
365 }
366
367 // ── Generation loop ───────────────────────────────────────────────────
368 let mut output_text = String::new();
369 let mut generated = 0usize;
370 let mut target_sampler = Sampler::new(self.config.target_sampler.clone());
371 let mut recent_tokens = prompt_tokens.clone();
372
373 // Channel for draft proposals: draft task → main loop.
374 let (proposal_tx, mut proposal_rx) = mpsc::channel::<DraftProposal>(2);
375 let cancel_child = self.cancel.child_token();
376
377 // Spawn the draft task. It will produce proposals until cancelled.
378 let draft_arc = Arc::clone(&self.draft);
379 let draft_sampler_cfg = self.config.draft_sampler.clone();
380 let cancel_draft = cancel_child.clone();
381
382 // Use a `Mutex<bool>` to communicate the "still running" flag to the
383 // draft task so it stops when the target is done.
384 let stop_flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
385 let stop_flag_draft = Arc::clone(&stop_flag);
386
387 tokio::task::spawn(async move {
388 let _draft_sampler = Sampler::new(draft_sampler_cfg);
389 let draft_recent: Vec<u32> = Vec::new();
390
391 loop {
392 if cancel_draft.is_cancelled()
393 || stop_flag_draft.load(std::sync::atomic::Ordering::Relaxed)
394 {
395 break;
396 }
397
398 // Generate spec_k candidate tokens from the draft engine.
399 let draft_arc2 = Arc::clone(&draft_arc);
400 let spec_k_local = spec_k;
401 let recent_clone = draft_recent.clone();
402
403 let proposal = tokio::task::spawn_blocking(move || {
404 let mut d = draft_arc2
405 .lock()
406 .map_err(|_| RuntimeError::ModelLoadError {
407 message: "draft mutex poisoned in draft task".to_string(),
408 })?;
409 let start_pos = d.kv_seq_len();
410 let mut tokens = Vec::with_capacity(spec_k_local);
411 let mut probs = Vec::with_capacity(spec_k_local);
412 let mut recent = recent_clone;
413
414 for _ in 0..spec_k_local {
415 if d.kv_seq_len() >= d.max_ctx_len() {
416 break;
417 }
418 let last = tokens
419 .last()
420 .copied()
421 .or_else(|| recent.last().copied())
422 .unwrap_or(0);
423 let logits = d.forward_one(last)?;
424 let tok = Sampler::new(SamplerConfig::greedy()).sample(&logits, &recent);
425 let prob = softmax_prob(&logits, tok);
426 tokens.push(tok);
427 probs.push(prob);
428 recent.push(tok);
429 }
430 Ok::<DraftProposal, RuntimeError>(DraftProposal {
431 tokens,
432 probs,
433 start_pos,
434 })
435 })
436 .await;
437
438 match proposal {
439 Ok(Ok(p)) if !p.tokens.is_empty() => {
440 if proposal_tx.send(p).await.is_err() {
441 break;
442 }
443 }
444 _ => break,
445 }
446 }
447 });
448
449 'outer: loop {
450 if self.cancel.is_cancelled() {
451 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
452 if generated == 0 {
453 return Err(RuntimeError::Cancelled);
454 }
455 break;
456 }
457
458 if generated >= max_tokens {
459 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
460 break;
461 }
462
463 // Receive a draft proposal (with timeout to avoid deadlock on
464 // draft task termination).
465 let proposal =
466 tokio::time::timeout(Duration::from_millis(500), proposal_rx.recv()).await;
467
468 let proposal = match proposal {
469 Ok(Some(p)) => p,
470 _ => {
471 // Draft exhausted or timed out — stop.
472 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
473 break;
474 }
475 };
476
477 // ── Verify each draft token against the target ────────────────────
478 let mut diverged_at: Option<usize> = None;
479 let mut last_target_logits: Vec<f32> = Vec::new();
480
481 for (i, (&draft_tok, &draft_prob)) in proposal
482 .tokens
483 .iter()
484 .zip(proposal.probs.iter())
485 .enumerate()
486 {
487 if generated + i >= max_tokens {
488 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
489 break 'outer;
490 }
491
492 // Target forward pass for one token.
493 let tgt_logits = match self.target.forward_one(draft_tok) {
494 Ok(l) => l,
495 Err(e) => {
496 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
497 return Err(e);
498 }
499 };
500
501 let target_prob = softmax_prob(&tgt_logits, draft_tok);
502 let accept = accept_draft_token(target_prob, draft_prob);
503
504 if accept {
505 // Accepted: emit token.
506 let text = match self.target.decode_token(draft_tok) {
507 Ok(t) => t,
508 Err(e) => {
509 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
510 return Err(e);
511 }
512 };
513 on_token(&text);
514 output_text.push_str(&text);
515 recent_tokens.push(draft_tok);
516 self.stats.accepted += 1;
517 generated += 1;
518
519 if self.target.is_eos(draft_tok) || generated >= max_tokens {
520 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
521 break 'outer;
522 }
523 last_target_logits = tgt_logits;
524 } else {
525 // Rejected: record divergence point and stop verifying this batch.
526 self.stats.rejected += 1;
527 diverged_at = Some(proposal.start_pos + i);
528 last_target_logits = tgt_logits;
529 break;
530 }
531 }
532
533 // ── After batch: sample bonus token if fully accepted ─────────────
534 if diverged_at.is_none() && !last_target_logits.is_empty() {
535 let bonus = target_sampler.sample(&last_target_logits, &recent_tokens);
536 let text = match self.target.decode_token(bonus) {
537 Ok(t) => t,
538 Err(e) => {
539 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
540 return Err(e);
541 }
542 };
543 on_token(&text);
544 output_text.push_str(&text);
545 recent_tokens.push(bonus);
546 self.stats.bonus_tokens += 1;
547 generated += 1;
548
549 if self.target.is_eos(bonus) || generated >= max_tokens {
550 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
551 break;
552 }
553 }
554
555 // ── Rollback on divergence ────────────────────────────────────────
556 if let Some(rewind_pos) = diverged_at {
557 // Sample residual token at divergence from target.
558 let residual_tok = target_sampler.sample(&last_target_logits, &recent_tokens);
559 let text = match self.target.decode_token(residual_tok) {
560 Ok(t) => t,
561 Err(e) => {
562 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
563 return Err(e);
564 }
565 };
566 on_token(&text);
567 output_text.push_str(&text);
568 recent_tokens.push(residual_tok);
569 generated += 1;
570
571 if self.target.is_eos(residual_tok) || generated >= max_tokens {
572 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
573 break;
574 }
575
576 // Rewind target to divergence point + 1 (just after the residual).
577 let new_len = rewind_pos + 1;
578 match self.target.rewind(new_len) {
579 Ok(()) => {}
580 Err(RewindError::NotSupported) => {
581 // SSM target — switch to N=1 mode.
582 self.stats.n1_fallbacks += 1;
583 }
584 Err(RewindError::PositionBeyondEnd { .. }) => {
585 // Should not happen if the proposal accounting is correct.
586 }
587 Err(RewindError::Runtime(e)) => {
588 stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
589 return Err(e);
590 }
591 }
592
593 // Rewind draft to match target.
594 let draft_arc2 = Arc::clone(&self.draft);
595 let rewind_to = new_len;
596 let _ = tokio::task::spawn_blocking(move || {
597 let mut d = draft_arc2.lock().ok()?;
598 let _ = d.rewind(rewind_to);
599 Some(())
600 })
601 .await;
602 }
603 }
604
605 self.stats.total_elapsed += started_at.elapsed();
606 Ok(output_text)
607 }
608}
609
610// ─── Helpers ──────────────────────────────────────────────────────────────────
611
/// Compute the softmax probability of `token_id` from `logits`.
///
/// Uses a numerically stable max-subtraction trick and makes two passes over
/// `logits` without allocating.
///
/// Returns 0.0 when `token_id` is out of range, when `logits` is empty, when
/// the normalizer underflows (< 1e-9), or when the logits are degenerate
/// (NaN or infinite values that make the normalizer non-finite), so callers
/// never see a NaN probability.
fn softmax_prob(logits: &[f32], token_id: u32) -> f32 {
    let logit = match logits.get(token_id as usize) {
        Some(&value) => value,
        None => return 0.0,
    };

    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum: f32 = logits.iter().map(|&l| (l - max).exp()).sum();
    if !sum.is_finite() || sum < 1e-9 {
        return 0.0;
    }
    (logit - max).exp() / sum
}
628
629/// Accept/reject a draft token using the standard speculative decoding rule.
630///
631/// Accepts deterministically if `p_target >= p_draft`; otherwise accepts with
632/// probability `p_target / p_draft`. This is the Leviathan et al. (2022) rule.
633fn accept_draft_token(p_target: f32, p_draft: f32) -> bool {
634 if p_draft < 1e-9 {
635 return false;
636 }
637 if p_target >= p_draft {
638 return true;
639 }
640 // Stochastic acceptance.
641 let threshold = p_target / p_draft;
642 // Use a deterministic approximation here (no PRNG dependency in the module).
643 // In production the engine's Xorshift64 should be threaded through; for
644 // the purpose of this module we use a simple hash of the two probs.
645 let pseudo_rand = pseudo_uniform(p_target, p_draft);
646 pseudo_rand < threshold
647}
648
/// Map two `f32` seeds to a deterministic pseudo-uniform value in [0.0, 1.0].
///
/// The bit patterns of the seeds are multiplied by fixed odd constants,
/// combined with wrapping arithmetic, and scaled by `u32::MAX`. This is a
/// cheap hash, not a random number generator and not cryptographically
/// strong; it backs the stochastic branch of [`accept_draft_token`].
fn pseudo_uniform(a: f32, b: f32) -> f32 {
    const A_MULTIPLIER: u32 = 2654435761;
    const B_MULTIPLIER: u32 = 40503;

    let hashed_a = a.to_bits().wrapping_mul(A_MULTIPLIER);
    let hashed_b = b.to_bits().wrapping_mul(B_MULTIPLIER);
    let mixed = hashed_a.wrapping_add(hashed_b);

    let scale = u32::MAX as f32;
    (mixed as f32) / scale
}
659
660// ─── Engine extension helpers ─────────────────────────────────────────────────
661
/// Helper methods added to `InferenceEngine` to support async spec-decode.
///
/// This is a module-private extension trait: the methods are callable on
/// `InferenceEngine` only where the trait is in scope, i.e. inside this
/// module.
trait InferenceEngineExt {
    /// Rewind (truncate) the KV cache to `n` tokens.
    fn truncate_kv_cache(&mut self, n: usize) -> RuntimeResult<()>;
    /// Current KV cache sequence length.
    fn kv_seq_len(&self) -> usize;
    /// Maximum context length for this engine.
    fn max_ctx_len(&self) -> usize;
}
674
675impl InferenceEngineExt for InferenceEngine {
676 fn truncate_kv_cache(&mut self, n: usize) -> RuntimeResult<()> {
677 // Delegate to InferenceEngine's truncate method.
678 self.truncate(n)
679 }
680
681 fn kv_seq_len(&self) -> usize {
682 self.kv_cache_seq_len()
683 }
684
685 fn max_ctx_len(&self) -> usize {
686 self.model_config()
687 .map(|c| c.max_context_length)
688 .unwrap_or(4096)
689 }
690}
691
692// ─── Tests ────────────────────────────────────────────────────────────────────
693
#[cfg(test)]
mod tests {
    use super::*;

    // ── SpecStats ─────────────────────────────────────────────────────────────

    #[test]
    fn spec_stats_acceptance_rate_empty() {
        let s = SpecStats::default();
        assert!(
            (s.acceptance_rate() - 0.0).abs() < 1e-6,
            "empty stats must return 0.0 acceptance rate"
        );
    }

    #[test]
    fn spec_stats_acceptance_rate_all_accepted() {
        let s = SpecStats {
            accepted: 10,
            rejected: 0,
            ..SpecStats::default()
        };
        assert!(
            (s.acceptance_rate() - 1.0).abs() < 1e-6,
            "all-accepted must return 1.0"
        );
    }

    #[test]
    fn spec_stats_acceptance_rate_half() {
        let s = SpecStats {
            accepted: 5,
            rejected: 5,
            ..SpecStats::default()
        };
        assert!(
            (s.acceptance_rate() - 0.5).abs() < 1e-6,
            "half accepted must return 0.5"
        );
    }

    #[test]
    fn spec_stats_total_output_tokens() {
        let s = SpecStats {
            accepted: 8,
            bonus_tokens: 2,
            ..SpecStats::default()
        };
        assert_eq!(s.total_output_tokens(), 10);
    }

    // ── softmax_prob ──────────────────────────────────────────────────────────

    #[test]
    fn softmax_prob_uniform_logits() {
        let logits = vec![1.0f32; 4];
        let p = softmax_prob(&logits, 0);
        assert!(
            (p - 0.25).abs() < 1e-5,
            "uniform logits must produce p=0.25 for any token, got {p}"
        );
    }

    #[test]
    fn softmax_prob_out_of_range_returns_zero() {
        let logits = vec![1.0f32; 4];
        let p = softmax_prob(&logits, 99);
        assert_eq!(p, 0.0, "out-of-range token must return 0.0");
    }

    #[test]
    fn softmax_prob_large_positive_logit() {
        // One logit much larger than the rest → near-certain probability.
        let mut logits = vec![0.0f32; 8];
        logits[3] = 100.0;
        let p = softmax_prob(&logits, 3);
        assert!(
            p > 0.99,
            "dominant logit must produce near-1 probability, got {p}"
        );
    }

    // ── accept_draft_token ────────────────────────────────────────────────────

    /// When target probability >= draft probability, always accept.
    #[test]
    fn accept_draft_token_always_accepts_when_target_ge_draft() {
        assert!(
            accept_draft_token(0.9, 0.5),
            "p_target=0.9 >= p_draft=0.5 must always accept"
        );
        assert!(
            accept_draft_token(0.5, 0.5),
            "p_target==p_draft must always accept"
        );
    }

    /// Zero draft probability must never accept.
    #[test]
    fn accept_draft_token_never_accepts_zero_draft_prob() {
        assert!(
            !accept_draft_token(0.5, 0.0),
            "zero draft prob must always reject"
        );
    }

    // ── AsyncSpecConfig ───────────────────────────────────────────────────────

    #[test]
    fn async_spec_config_defaults() {
        let cfg = AsyncSpecConfig::default();
        assert_eq!(cfg.spec_k, 4, "default spec_k must be 4");
        assert!(!cfg.force_n1, "force_n1 must be false by default");
        assert_eq!(cfg.max_tokens, 512);
    }

    // ── RewindError ───────────────────────────────────────────────────────────

    #[test]
    fn rewind_error_not_supported_display() {
        let e = RewindError::NotSupported;
        let s = e.to_string();
        assert!(
            s.contains("not supported"),
            "NotSupported display must contain 'not supported', got: {s}"
        );
    }

    #[test]
    fn rewind_error_position_beyond_end_display() {
        let e = RewindError::PositionBeyondEnd {
            target: 10,
            current: 5,
        };
        let s = e.to_string();
        assert!(
            s.contains("10") && s.contains("5"),
            "display must include positions, got: {s}"
        );
    }

    // ── SpeculativeDecoder construction ───────────────────────────────────────

    /// Constructing SpeculativeDecoder with two unloaded engines must succeed
    /// (construction never fails); `generate` will return ModelNotLoaded.
    #[test]
    fn spec_decode_construction_with_unloaded_engines() {
        use crate::engine::EngineConfig;
        let draft = InferenceEngine::new(EngineConfig::default());
        let target = InferenceEngine::new(EngineConfig::default());
        let decoder = SpeculativeDecoder::new(draft, target, AsyncSpecConfig::default());
        // Stats should be zero.
        assert_eq!(decoder.stats().accepted, 0);
        assert_eq!(decoder.stats().rejected, 0);
    }

    /// `spec_decode_correctness_stub`: constructing with unloaded engines and
    /// calling generate must return ModelNotLoaded — the stub validates that
    /// the error path is reachable.
    #[tokio::test]
    async fn spec_decode_correctness_stub() {
        use crate::engine::EngineConfig;
        let draft = InferenceEngine::new(EngineConfig::default());
        let target = InferenceEngine::new(EngineConfig::default());
        let mut decoder = SpeculativeDecoder::new(draft, target, AsyncSpecConfig::default());
        let result = decoder.generate("hello", |_| {}).await;
        assert!(
            matches!(result, Err(RuntimeError::ModelNotLoaded)),
            "expected ModelNotLoaded for unloaded decoder, got {result:?}"
        );
    }

    /// `spec_decode_divergence_rollback`: a decoder built with `new_n1` must
    /// construct, and its statistics must read zero after `reset_stats`.
    ///
    /// No rollback is exercised here (the engines are unloaded); this is a
    /// construction/stats smoke test only.
    #[test]
    fn spec_decode_divergence_rollback() {
        use crate::engine::EngineConfig;
        let draft = InferenceEngine::new(EngineConfig::default());
        let target = InferenceEngine::new(EngineConfig::default());
        let cfg = AsyncSpecConfig {
            force_n1: true,
            ..AsyncSpecConfig::default()
        };
        let mut decoder = SpeculativeDecoder::new_n1(draft, target, cfg);
        decoder.reset_stats();
        let stats = decoder.stats();
        assert_eq!(stats.accepted, 0);
        assert_eq!(stats.n1_fallbacks, 0);
    }

    /// `spec_decode_ssm_falls_back`: constructing with force_n1=true must
    /// set the correct configuration.
    #[test]
    fn spec_decode_ssm_falls_back() {
        use crate::engine::EngineConfig;
        let draft = InferenceEngine::new(EngineConfig::default());
        let target = InferenceEngine::new(EngineConfig::default());
        let decoder = SpeculativeDecoder::new_n1(
            draft,
            target,
            AsyncSpecConfig {
                force_n1: true,
                spec_k: 1,
                ..AsyncSpecConfig::default()
            },
        );
        assert!(
            decoder.config.force_n1,
            "force_n1 must be true when constructed with new_n1"
        );
        assert_eq!(decoder.config.spec_k, 1);
    }

    /// The token handed out by `cancellation_token` starts un-cancelled and
    /// shares its state with the decoder's own token.
    #[test]
    fn cancellation_token_child_relationship() {
        use crate::engine::EngineConfig;
        let draft = InferenceEngine::new(EngineConfig::default());
        let target = InferenceEngine::new(EngineConfig::default());
        let decoder = SpeculativeDecoder::new(draft, target, AsyncSpecConfig::default());
        let token = decoder.cancellation_token();
        assert!(
            !token.is_cancelled(),
            "token must not be cancelled initially"
        );

        token.cancel();
        assert!(
            decoder.cancellation_token().is_cancelled(),
            "cancelling the handed-out token must be visible through the decoder"
        );
    }

    // ── With loaded model ─────────────────────────────────────────────────────

    /// Load the same minimal model into both engines and run `generate`.
    ///
    /// This is a structural test, not a functional one with real weights:
    /// the call may return `Ok` or `Err`, but it must not panic and the
    /// counted output must respect the configured token limit.
    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[tokio::test]
    async fn spec_decode_loaded_engines_produce_output() {
        use crate::engine::EngineConfig;

        let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
        let tok_json = oxillama_gguf::test_utils::minimal_tokenizer_json();

        let mut draft_eng = InferenceEngine::new(EngineConfig::default());
        draft_eng
            .load_model_from_bytes(&model_bytes, tok_json)
            .expect("draft load");

        let mut target_eng = InferenceEngine::new(EngineConfig::default());
        target_eng
            .load_model_from_bytes(&model_bytes, tok_json)
            .expect("target load");

        let cfg = AsyncSpecConfig {
            spec_k: 2,
            max_tokens: 4,
            ..AsyncSpecConfig::default()
        };
        let mut decoder = SpeculativeDecoder::new(draft_eng, target_eng, cfg);
        // The result may be Ok or Err depending on EOS sampling; reaching the
        // assertion below already proves the call did not panic.
        let _result = decoder.generate("a", |_| {}).await;
        assert!(
            decoder.stats().total_output_tokens() <= 4,
            "counted output tokens must not exceed max_tokens, got {}",
            decoder.stats().total_output_tokens()
        );
    }
}