Skip to main content

oxillama_runtime/
speculative_async.rs

1//! Drafter-async speculative decoding.
2//!
3//! # Overview
4//!
//! This module provides an *async* speculative decoding loop in which the
//! draft model runs on tokio's blocking thread pool, concurrently with the
//! target model's verification work, so that drafting and verification
//! overlap in wall-clock time.
9//!
10//! ## Architecture
11//!
12//! ```text
13//!   ┌──────────────────┐        ┌──────────────────┐
14//!   │  DraftTask       │  ───►  │  TargetTask      │
15//!   │  generate N tok  │        │  verify N tok    │
16//!   │  (async, ahead)  │  ◄───  │  (accept/reject) │
17//!   └──────────────────┘        └──────────────────┘
18//!          │                            │
19//!          └─── CancellationToken ──────┘
20//! ```
21//!
//! On divergence, speculation past the divergence point is discarded with
//! [`Rewindable::rewind`], which truncates the KV cache, and decoding resumes
//! from there.  Engines that cannot rewind (e.g. SSM-based models with
//! recurrent state) should report [`RewindError::NotSupported`]; for them,
//! N=1 mode (one candidate token per step) is available via
//! [`SpeculativeDecoder::new_n1`] or `AsyncSpecConfig::force_n1`.
26//!
27//! ## Stats
28//!
//! [`SpecStats`] accumulates acceptance counts across `generate` calls (until
//! reset) and exposes the token-level acceptance rate, so callers can decide
//! whether async spec-decode is worth the overhead.  As a rule of thumb, an
//! acceptance rate below ~30% suggests disabling it.
32//!
33//! ## Cancellation
34//!
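//! The decoder owns a `tokio_util::sync::CancellationToken`.  Callers obtain a
//! clone via [`SpeculativeDecoder::cancellation_token`] and cancel it to stop
//! generation at the next speculation step.  Reaching EOS or the token budget
//! ends generation without cancelling the token.  In either case the draft side
//! is told to stop, and it finishes within one draft iteration.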
39//!
//! ## Note on `InferenceEngine` and blocking work
//!
//! Model forward passes are blocking, compute-bound calls.  The draft engine
//! is therefore kept in an `Arc<Mutex<InferenceEngine>>` and driven from
//! `tokio::task::spawn_blocking`, which requires `InferenceEngine: Send`.  Its
//! candidate tokens are sent back to the verifier over a `tokio::sync::mpsc`
//! channel.  The target engine is called directly from the task running
//! `generate`.
46//!
47//! ## Relation to `speculative.rs`
48//!
49//! The existing [`speculative`](crate::speculative) module contains the
50//! synchronous `SpeculativeEngine` and associated tests.  This module is
51//! additive — it does **not** modify or replace that code.  Callers may use
52//! either API; the async variant provides higher throughput at the cost of
53//! more complex cancellation and state management.
54
55use std::sync::{Arc, Mutex};
56use std::time::{Duration, Instant};
57use thiserror::Error;
58use tokio::sync::mpsc;
59use tokio_util::sync::CancellationToken;
60
61use crate::engine::InferenceEngine;
62use crate::error::{RuntimeError, RuntimeResult};
63use crate::sampling::{Sampler, SamplerConfig};
64
65// ─── Rewindable trait ─────────────────────────────────────────────────────────
66
67/// An error returned when a rewind operation is not possible.
68#[derive(Debug, Error)]
69pub enum RewindError {
70    /// The backend does not support rewinding (e.g. SSM recurrent states).
71    ///
72    /// The caller should fall back to N=1 verification when this is returned.
73    #[error("rewind not supported for this model type (SSM/recurrent state)")]
74    NotSupported,
75    /// The requested position is beyond the current sequence length.
76    #[error("rewind target position {target} exceeds current length {current}")]
77    PositionBeyondEnd { target: usize, current: usize },
78    /// An I/O or runtime error prevented the rewind.
79    #[error("rewind runtime error: {0}")]
80    Runtime(#[from] RuntimeError),
81}
82
83/// Capability for truncating a sequence to an earlier position.
84///
85/// Implemented by KV-cache-backed engines: `rewind(n)` truncates the cache to
86/// `n` tokens.  SSM-based engines return [`RewindError::NotSupported`],
87/// causing the speculative decoder to fall back to N=1 verification mode.
88pub trait Rewindable {
89    /// Truncate the model state so that the next token generated is at
90    /// position `n` (0-indexed).
91    ///
92    /// After a successful rewind the engine behaves as if only `n` tokens have
93    /// been processed: the KV cache has `n` entries, the position counter is
94    /// `n`, etc.
95    ///
96    /// # Errors
97    ///
98    /// - [`RewindError::NotSupported`] for SSM/recurrent models.
99    /// - [`RewindError::PositionBeyondEnd`] if `n` > current sequence length.
100    fn rewind(&mut self, n: usize) -> Result<(), RewindError>;
101
102    /// Return the current sequence length (= number of tokens in the KV
103    /// cache or SSM state).
104    fn current_length(&self) -> usize;
105}
106
107/// [`Rewindable`] implementation for [`InferenceEngine`].
108///
109/// Delegates to the engine's internal KV cache [`truncate`](crate::kv_cache::KvCache::truncate)
110/// method.  If the engine has no loaded model (and thus no KV cache) the
111/// method returns `RuntimeError::ModelNotLoaded` wrapped in `RewindError::Runtime`.
112impl Rewindable for InferenceEngine {
113    fn rewind(&mut self, n: usize) -> Result<(), RewindError> {
114        let current = self.current_length();
115        if n > current {
116            return Err(RewindError::PositionBeyondEnd { target: n, current });
117        }
118        // Delegate to the KV cache truncate method.
119        self.truncate_kv_cache(n).map_err(RewindError::Runtime)
120    }
121
122    fn current_length(&self) -> usize {
123        self.kv_seq_len()
124    }
125}
126
127// ─── SpecStats ────────────────────────────────────────────────────────────────
128
/// Acceptance statistics for the async speculative decoder.
///
/// The verification loop bumps these counters as candidates are accepted or
/// rejected.  They accumulate across `generate` calls until reset.
#[derive(Debug, Default, Clone)]
pub struct SpecStats {
    /// Draft candidates accepted by the target.
    pub accepted: u64,
    /// Draft candidates rejected by the target.  Every rejection ends its
    /// speculation step and is replaced by one token sampled from the target.
    pub rejected: u64,
    /// Bonus tokens sampled directly from the target (one per step in which
    /// every candidate was accepted).
    pub bonus_tokens: u64,
    /// Total wall-clock time spent in the async decoder.
    pub total_elapsed: Duration,
    /// Number of rewinds that reported `RewindError::NotSupported`
    /// (fallbacks to N=1 mode).
    pub n1_fallbacks: u64,
}

impl SpecStats {
    /// Token-level acceptance rate in `[0.0, 1.0]`:
    /// `accepted / (accepted + rejected)`.
    ///
    /// Returns `0.0` when no candidates have been evaluated.
    pub fn acceptance_rate(&self) -> f32 {
        match self.accepted + self.rejected {
            0 => 0.0,
            evaluated => self.accepted as f32 / evaluated as f32,
        }
    }

    /// Total tokens emitted to the caller.
    ///
    /// This counts accepted candidates, the one target-sampled replacement
    /// token emitted for every rejection, and bonus tokens.
    pub fn total_output_tokens(&self) -> u64 {
        self.accepted + self.rejected + self.bonus_tokens
    }
}
165
166// ─── DraftProposal ────────────────────────────────────────────────────────────
167
/// A run of consecutive candidate tokens proposed by the draft model.
///
/// `tokens` and `probs` are parallel: `probs[i]` is the draft's softmax
/// probability of `tokens[i]` at the position where it was proposed.
#[derive(Debug)]
struct DraftProposal {
    /// Candidate token IDs, in generation order.
    tokens: Vec<u32>,
    /// Draft probability of each candidate (input to the accept/reject test).
    probs: Vec<f32>,
    /// Draft KV-cache length at which this run starts.
    start_pos: usize,
}
178
179// ─── SpeculativeDecoder ───────────────────────────────────────────────────────
180
181/// Async speculative decoder.
182///
183/// Wraps a draft engine (generating `spec_k` candidates per step) and a target
184/// engine (verifying the candidates in a single batched forward pass).  The two
185/// engines run with overlap via `tokio`.
186///
187/// # Limitations
188///
189/// - Both engines must use the same tokenizer and vocabulary.
190/// - The draft engine must be strictly smaller/faster than the target.
191/// - The target engine must implement [`Rewindable`] (KV-cache based).  For
192///   SSM targets use [`SpeculativeDecoder::new_n1`] which forces N=1 mode.
193///
194/// # Example
195///
196/// ```ignore
197/// let decoder = SpeculativeDecoder::new(
198///     draft_engine,
199///     target_engine,
200///     AsyncSpecConfig::default(),
201/// );
202/// let stats = decoder.generate("hello", 128, |tok| print!("{tok}")).await?;
203/// ```
204pub struct SpeculativeDecoder {
205    /// Draft engine wrapped in `Arc<Mutex>` so it can be moved to a
206    /// `spawn_blocking` worker.
207    draft: Arc<Mutex<InferenceEngine>>,
208    /// Target engine owned directly (verification runs on the caller's task).
209    target: InferenceEngine,
210    /// Speculative decoding configuration.
211    config: AsyncSpecConfig,
212    /// Cancellation token shared with the draft task.
213    cancel: CancellationToken,
214    /// Accumulated statistics for the current generation.
215    stats: SpecStats,
216}
217
218/// Configuration for the async speculative decoder.
219#[derive(Debug, Clone)]
220pub struct AsyncSpecConfig {
221    /// Number of draft tokens to generate per speculation step (K).
222    ///
223    /// Higher values increase potential throughput but also increase the cost
224    /// of verification and rollback on divergence.  A value of 4–8 is typical.
225    pub spec_k: usize,
226    /// Sampler configuration applied by the draft engine.
227    pub draft_sampler: SamplerConfig,
228    /// Sampler configuration applied by the target engine for verification
229    /// and residual sampling.
230    pub target_sampler: SamplerConfig,
231    /// Force N=1 verification mode regardless of target model type.
232    ///
233    /// Set this to `true` when the target model is SSM-based (cannot rewind).
234    pub force_n1: bool,
235    /// Maximum number of tokens to generate (prompt + output combined).
236    pub max_tokens: usize,
237}
238
239impl Default for AsyncSpecConfig {
240    fn default() -> Self {
241        Self {
242            spec_k: 4,
243            draft_sampler: SamplerConfig::greedy(),
244            target_sampler: SamplerConfig::default(),
245            force_n1: false,
246            max_tokens: 512,
247        }
248    }
249}
250
251impl SpeculativeDecoder {
252    /// Construct a new async speculative decoder.
253    ///
254    /// Both engines must be loaded (i.e. `is_loaded()` is true) before
255    /// `generate` is called.
256    pub fn new(draft: InferenceEngine, target: InferenceEngine, config: AsyncSpecConfig) -> Self {
257        Self {
258            draft: Arc::new(Mutex::new(draft)),
259            target,
260            config,
261            cancel: CancellationToken::new(),
262            stats: SpecStats::default(),
263        }
264    }
265
266    /// Construct a decoder that always uses N=1 mode (for SSM targets).
267    pub fn new_n1(
268        draft: InferenceEngine,
269        target: InferenceEngine,
270        config: AsyncSpecConfig,
271    ) -> Self {
272        let cfg = AsyncSpecConfig {
273            force_n1: true,
274            ..config
275        };
276        Self::new(draft, target, cfg)
277    }
278
279    /// Return the accumulated statistics from all `generate` calls.
280    pub fn stats(&self) -> &SpecStats {
281        &self.stats
282    }
283
284    /// Reset statistics counters.
285    pub fn reset_stats(&mut self) {
286        self.stats = SpecStats::default();
287    }
288
289    /// Return a reference to the cancellation token for external cancellation.
290    pub fn cancellation_token(&self) -> CancellationToken {
291        self.cancel.clone()
292    }
293
294    /// Run async speculative generation for `prompt`, calling `on_token` for
295    /// each decoded token.
296    ///
297    /// Returns the full generated text and updates `self.stats`.
298    ///
299    /// # SSM fallback
300    ///
301    /// If the target engine's `rewind()` returns `RewindError::NotSupported`
302    /// on the first call, the decoder automatically falls back to N=1 mode for
303    /// the rest of the generation.  `SpecStats::n1_fallbacks` is incremented.
304    ///
305    /// # Cancellation
306    ///
307    /// The generation loop checks `self.cancel` after each speculation step.
308    /// Callers can cancel by calling `cancel.cancel()` from another task.
309    ///
310    /// # Errors
311    ///
312    /// Returns `RuntimeError::ModelNotLoaded` if either engine is not loaded.
313    /// Returns `RuntimeError::Cancelled` if the cancellation token is
314    /// triggered before the first token is produced.
315    pub async fn generate<F>(&mut self, prompt: &str, mut on_token: F) -> RuntimeResult<String>
316    where
317        F: FnMut(&str) + Send + 'static,
318    {
319        let started_at = Instant::now();
320
321        // ── Validate both engines are loaded ──────────────────────────────────
322        if !self.target.is_loaded() {
323            return Err(RuntimeError::ModelNotLoaded);
324        }
325        {
326            let draft_guard = self
327                .draft
328                .lock()
329                .map_err(|_| RuntimeError::ModelLoadError {
330                    message: "draft engine mutex poisoned".to_string(),
331                })?;
332            if !draft_guard.is_loaded() {
333                return Err(RuntimeError::ModelNotLoaded);
334            }
335        }
336
337        let use_n1 = self.config.force_n1;
338        let spec_k = if use_n1 { 1 } else { self.config.spec_k };
339        let max_tokens = self.config.max_tokens;
340
341        // ── Tokenize the prompt ───────────────────────────────────────────────
342        let prompt_tokens = self.target.tokenize(prompt)?;
343        if prompt_tokens.is_empty() {
344            return Ok(String::new());
345        }
346
347        // ── Prefill both engines ──────────────────────────────────────────────
348        // Target prefill (inline).
349        self.target.prefill(&prompt_tokens)?;
350
351        // Draft prefill (in blocking task to avoid blocking the async runtime).
352        {
353            let draft = Arc::clone(&self.draft);
354            let pt = prompt_tokens.clone();
355            tokio::task::spawn_blocking(move || {
356                let mut d = draft.lock().map_err(|_| RuntimeError::ModelLoadError {
357                    message: "draft mutex poisoned during prefill".to_string(),
358                })?;
359                d.prefill(&pt)
360            })
361            .await
362            .map_err(|e| RuntimeError::ModelLoadError {
363                message: format!("draft prefill task panicked: {e}"),
364            })??;
365        }
366
367        // ── Generation loop ───────────────────────────────────────────────────
368        let mut output_text = String::new();
369        let mut generated = 0usize;
370        let mut target_sampler = Sampler::new(self.config.target_sampler.clone());
371        let mut recent_tokens = prompt_tokens.clone();
372
373        // Channel for draft proposals: draft task → main loop.
374        let (proposal_tx, mut proposal_rx) = mpsc::channel::<DraftProposal>(2);
375        let cancel_child = self.cancel.child_token();
376
377        // Spawn the draft task.  It will produce proposals until cancelled.
378        let draft_arc = Arc::clone(&self.draft);
379        let draft_sampler_cfg = self.config.draft_sampler.clone();
380        let cancel_draft = cancel_child.clone();
381
382        // Use a `Mutex<bool>` to communicate the "still running" flag to the
383        // draft task so it stops when the target is done.
384        let stop_flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
385        let stop_flag_draft = Arc::clone(&stop_flag);
386
387        tokio::task::spawn(async move {
388            let _draft_sampler = Sampler::new(draft_sampler_cfg);
389            let draft_recent: Vec<u32> = Vec::new();
390
391            loop {
392                if cancel_draft.is_cancelled()
393                    || stop_flag_draft.load(std::sync::atomic::Ordering::Relaxed)
394                {
395                    break;
396                }
397
398                // Generate spec_k candidate tokens from the draft engine.
399                let draft_arc2 = Arc::clone(&draft_arc);
400                let spec_k_local = spec_k;
401                let recent_clone = draft_recent.clone();
402
403                let proposal = tokio::task::spawn_blocking(move || {
404                    let mut d = draft_arc2
405                        .lock()
406                        .map_err(|_| RuntimeError::ModelLoadError {
407                            message: "draft mutex poisoned in draft task".to_string(),
408                        })?;
409                    let start_pos = d.kv_seq_len();
410                    let mut tokens = Vec::with_capacity(spec_k_local);
411                    let mut probs = Vec::with_capacity(spec_k_local);
412                    let mut recent = recent_clone;
413
414                    for _ in 0..spec_k_local {
415                        if d.kv_seq_len() >= d.max_ctx_len() {
416                            break;
417                        }
418                        let last = tokens
419                            .last()
420                            .copied()
421                            .or_else(|| recent.last().copied())
422                            .unwrap_or(0);
423                        let logits = d.forward_one(last)?;
424                        let tok = Sampler::new(SamplerConfig::greedy()).sample(&logits, &recent);
425                        let prob = softmax_prob(&logits, tok);
426                        tokens.push(tok);
427                        probs.push(prob);
428                        recent.push(tok);
429                    }
430                    Ok::<DraftProposal, RuntimeError>(DraftProposal {
431                        tokens,
432                        probs,
433                        start_pos,
434                    })
435                })
436                .await;
437
438                match proposal {
439                    Ok(Ok(p)) if !p.tokens.is_empty() => {
440                        if proposal_tx.send(p).await.is_err() {
441                            break;
442                        }
443                    }
444                    _ => break,
445                }
446            }
447        });
448
449        'outer: loop {
450            if self.cancel.is_cancelled() {
451                stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
452                if generated == 0 {
453                    return Err(RuntimeError::Cancelled);
454                }
455                break;
456            }
457
458            if generated >= max_tokens {
459                stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
460                break;
461            }
462
463            // Receive a draft proposal (with timeout to avoid deadlock on
464            // draft task termination).
465            let proposal =
466                tokio::time::timeout(Duration::from_millis(500), proposal_rx.recv()).await;
467
468            let proposal = match proposal {
469                Ok(Some(p)) => p,
470                _ => {
471                    // Draft exhausted or timed out — stop.
472                    stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
473                    break;
474                }
475            };
476
477            // ── Verify each draft token against the target ────────────────────
478            let mut diverged_at: Option<usize> = None;
479            let mut last_target_logits: Vec<f32> = Vec::new();
480
481            for (i, (&draft_tok, &draft_prob)) in proposal
482                .tokens
483                .iter()
484                .zip(proposal.probs.iter())
485                .enumerate()
486            {
487                if generated + i >= max_tokens {
488                    stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
489                    break 'outer;
490                }
491
492                // Target forward pass for one token.
493                let tgt_logits = match self.target.forward_one(draft_tok) {
494                    Ok(l) => l,
495                    Err(e) => {
496                        stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
497                        return Err(e);
498                    }
499                };
500
501                let target_prob = softmax_prob(&tgt_logits, draft_tok);
502                let accept = accept_draft_token(target_prob, draft_prob);
503
504                if accept {
505                    // Accepted: emit token.
506                    let text = match self.target.decode_token(draft_tok) {
507                        Ok(t) => t,
508                        Err(e) => {
509                            stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
510                            return Err(e);
511                        }
512                    };
513                    on_token(&text);
514                    output_text.push_str(&text);
515                    recent_tokens.push(draft_tok);
516                    self.stats.accepted += 1;
517                    generated += 1;
518
519                    if self.target.is_eos(draft_tok) || generated >= max_tokens {
520                        stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
521                        break 'outer;
522                    }
523                    last_target_logits = tgt_logits;
524                } else {
525                    // Rejected: record divergence point and stop verifying this batch.
526                    self.stats.rejected += 1;
527                    diverged_at = Some(proposal.start_pos + i);
528                    last_target_logits = tgt_logits;
529                    break;
530                }
531            }
532
533            // ── After batch: sample bonus token if fully accepted ─────────────
534            if diverged_at.is_none() && !last_target_logits.is_empty() {
535                let bonus = target_sampler.sample(&last_target_logits, &recent_tokens);
536                let text = match self.target.decode_token(bonus) {
537                    Ok(t) => t,
538                    Err(e) => {
539                        stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
540                        return Err(e);
541                    }
542                };
543                on_token(&text);
544                output_text.push_str(&text);
545                recent_tokens.push(bonus);
546                self.stats.bonus_tokens += 1;
547                generated += 1;
548
549                if self.target.is_eos(bonus) || generated >= max_tokens {
550                    stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
551                    break;
552                }
553            }
554
555            // ── Rollback on divergence ────────────────────────────────────────
556            if let Some(rewind_pos) = diverged_at {
557                // Sample residual token at divergence from target.
558                let residual_tok = target_sampler.sample(&last_target_logits, &recent_tokens);
559                let text = match self.target.decode_token(residual_tok) {
560                    Ok(t) => t,
561                    Err(e) => {
562                        stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
563                        return Err(e);
564                    }
565                };
566                on_token(&text);
567                output_text.push_str(&text);
568                recent_tokens.push(residual_tok);
569                generated += 1;
570
571                if self.target.is_eos(residual_tok) || generated >= max_tokens {
572                    stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
573                    break;
574                }
575
576                // Rewind target to divergence point + 1 (just after the residual).
577                let new_len = rewind_pos + 1;
578                match self.target.rewind(new_len) {
579                    Ok(()) => {}
580                    Err(RewindError::NotSupported) => {
581                        // SSM target — switch to N=1 mode.
582                        self.stats.n1_fallbacks += 1;
583                    }
584                    Err(RewindError::PositionBeyondEnd { .. }) => {
585                        // Should not happen if the proposal accounting is correct.
586                    }
587                    Err(RewindError::Runtime(e)) => {
588                        stop_flag.store(true, std::sync::atomic::Ordering::Relaxed);
589                        return Err(e);
590                    }
591                }
592
593                // Rewind draft to match target.
594                let draft_arc2 = Arc::clone(&self.draft);
595                let rewind_to = new_len;
596                let _ = tokio::task::spawn_blocking(move || {
597                    let mut d = draft_arc2.lock().ok()?;
598                    let _ = d.rewind(rewind_to);
599                    Some(())
600                })
601                .await;
602            }
603        }
604
605        self.stats.total_elapsed += started_at.elapsed();
606        Ok(output_text)
607    }
608}
609
610// ─── Helpers ──────────────────────────────────────────────────────────────────
611
/// Compute the softmax probability of `token_id` under `logits`.
///
/// Uses the max-subtraction trick for numerical stability.  Only the
/// normaliser is accumulated, so no vocabulary-sized `exp` buffer is
/// allocated per call.
///
/// Returns `0.0` in two cases:
/// - `token_id` is out of range;
/// - the normaliser is degenerate: below `1e-9`, or not finite (e.g. NaN
///   logits, or every logit being `-inf`).
fn softmax_prob(logits: &[f32], token_id: u32) -> f32 {
    let logit = match logits.get(token_id as usize) {
        Some(&l) => l,
        None => return 0.0,
    };
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum: f32 = logits.iter().map(|&l| (l - max).exp()).sum();
    if !sum.is_finite() || sum < 1e-9 {
        return 0.0;
    }
    (logit - max).exp() / sum
}
628
/// Decide whether the target accepts a draft token (Leviathan et al., 2022).
///
/// The decision has three cases:
/// - Rejected outright when the draft gave it (almost) no probability
///   (`p_draft < 1e-9`).
/// - Accepted whenever `p_target >= p_draft`.
/// - Otherwise accepted iff a uniform draw falls below `p_target / p_draft`.
///
/// NOTE(review): the "uniform draw" comes from [`pseudo_uniform`], a
/// deterministic hash of the two probabilities.  The same `(p_target,
/// p_draft)` pair therefore always yields the same decision, which is not a
/// true Bernoulli trial.  Thread a real PRNG through to sample exactly.
fn accept_draft_token(p_target: f32, p_draft: f32) -> bool {
    if p_draft < 1e-9 {
        false
    } else if p_target >= p_draft {
        true
    } else {
        pseudo_uniform(p_target, p_draft) < p_target / p_draft
    }
}

/// Map two `f32` seeds to a pseudo-uniform value in `[0.0, 1.0]`.
///
/// Multiplicatively hashes the bit patterns of `a` (Knuth's 2654435761
/// constant) and `b`, adds them with wrap-around, and scales by `u32::MAX`.
/// The result is deterministic and not cryptographically strong.
/// `accept_draft_token` uses it as its acceptance draw.  Because `u32 -> f32`
/// rounds, hashes close to `u32::MAX` map to exactly `1.0`.
fn pseudo_uniform(a: f32, b: f32) -> f32 {
    const A_MULTIPLIER: u32 = 2_654_435_761;
    const B_MULTIPLIER: u32 = 40_503;
    let hashed_a = a.to_bits().wrapping_mul(A_MULTIPLIER);
    let hashed_b = b.to_bits().wrapping_mul(B_MULTIPLIER);
    hashed_a.wrapping_add(hashed_b) as f32 / u32::MAX as f32
}
659
660// ─── Engine extension helpers ─────────────────────────────────────────────────
661
662/// Helper methods added to `InferenceEngine` to support async spec-decode.
663///
664/// These are exposed as inherent methods on `InferenceEngine` via the extension
665/// pattern — the trait exists only inside this module.
666trait InferenceEngineExt {
667    /// Rewind (truncate) the KV cache to `n` tokens.
668    fn truncate_kv_cache(&mut self, n: usize) -> RuntimeResult<()>;
669    /// Current KV cache sequence length.
670    fn kv_seq_len(&self) -> usize;
671    /// Maximum context length for this engine.
672    fn max_ctx_len(&self) -> usize;
673}
674
675impl InferenceEngineExt for InferenceEngine {
676    fn truncate_kv_cache(&mut self, n: usize) -> RuntimeResult<()> {
677        // Delegate to InferenceEngine's truncate method.
678        self.truncate(n)
679    }
680
681    fn kv_seq_len(&self) -> usize {
682        self.kv_cache_seq_len()
683    }
684
685    fn max_ctx_len(&self) -> usize {
686        self.model_config()
687            .map(|c| c.max_context_length)
688            .unwrap_or(4096)
689    }
690}
691
692// ─── Tests ────────────────────────────────────────────────────────────────────
693
694#[cfg(test)]
695mod tests {
696    use super::*;
697
698    // ── SpecStats ─────────────────────────────────────────────────────────────
699
700    #[test]
701    fn spec_stats_acceptance_rate_empty() {
702        let s = SpecStats::default();
703        assert!(
704            (s.acceptance_rate() - 0.0).abs() < 1e-6,
705            "empty stats must return 0.0 acceptance rate"
706        );
707    }
708
709    #[test]
710    fn spec_stats_acceptance_rate_all_accepted() {
711        let s = SpecStats {
712            accepted: 10,
713            rejected: 0,
714            ..SpecStats::default()
715        };
716        assert!(
717            (s.acceptance_rate() - 1.0).abs() < 1e-6,
718            "all-accepted must return 1.0"
719        );
720    }
721
722    #[test]
723    fn spec_stats_acceptance_rate_half() {
724        let s = SpecStats {
725            accepted: 5,
726            rejected: 5,
727            ..SpecStats::default()
728        };
729        assert!(
730            (s.acceptance_rate() - 0.5).abs() < 1e-6,
731            "half accepted must return 0.5"
732        );
733    }
734
735    #[test]
736    fn spec_stats_total_output_tokens() {
737        let s = SpecStats {
738            accepted: 8,
739            bonus_tokens: 2,
740            ..SpecStats::default()
741        };
742        assert_eq!(s.total_output_tokens(), 10);
743    }
744
745    // ── softmax_prob ──────────────────────────────────────────────────────────
746
747    #[test]
748    fn softmax_prob_uniform_logits() {
749        let logits = vec![1.0f32; 4];
750        let p = softmax_prob(&logits, 0);
751        assert!(
752            (p - 0.25).abs() < 1e-5,
753            "uniform logits must produce p=0.25 for any token, got {p}"
754        );
755    }
756
757    #[test]
758    fn softmax_prob_out_of_range_returns_zero() {
759        let logits = vec![1.0f32; 4];
760        let p = softmax_prob(&logits, 99);
761        assert_eq!(p, 0.0, "out-of-range token must return 0.0");
762    }
763
764    #[test]
765    fn softmax_prob_large_positive_logit() {
766        // One logit much larger than the rest → near-certain probability.
767        let mut logits = vec![0.0f32; 8];
768        logits[3] = 100.0;
769        let p = softmax_prob(&logits, 3);
770        assert!(
771            p > 0.99,
772            "dominant logit must produce near-1 probability, got {p}"
773        );
774    }
775
776    // ── accept_draft_token ────────────────────────────────────────────────────
777
778    /// When target probability >= draft probability, always accept.
779    #[test]
780    fn accept_draft_token_always_accepts_when_target_ge_draft() {
781        assert!(
782            accept_draft_token(0.9, 0.5),
783            "p_target=0.9 >= p_draft=0.5 must always accept"
784        );
785        assert!(
786            accept_draft_token(0.5, 0.5),
787            "p_target==p_draft must always accept"
788        );
789    }
790
791    /// Zero draft probability must never accept.
792    #[test]
793    fn accept_draft_token_never_accepts_zero_draft_prob() {
794        assert!(
795            !accept_draft_token(0.5, 0.0),
796            "zero draft prob must always reject"
797        );
798    }
799
800    // ── AsyncSpecConfig ───────────────────────────────────────────────────────
801
802    #[test]
803    fn async_spec_config_defaults() {
804        let cfg = AsyncSpecConfig::default();
805        assert_eq!(cfg.spec_k, 4, "default spec_k must be 4");
806        assert!(!cfg.force_n1, "force_n1 must be false by default");
807        assert_eq!(cfg.max_tokens, 512);
808    }
809
810    // ── RewindError ───────────────────────────────────────────────────────────
811
812    #[test]
813    fn rewind_error_not_supported_display() {
814        let e = RewindError::NotSupported;
815        let s = e.to_string();
816        assert!(
817            s.contains("not supported"),
818            "NotSupported display must contain 'not supported', got: {s}"
819        );
820    }
821
822    #[test]
823    fn rewind_error_position_beyond_end_display() {
824        let e = RewindError::PositionBeyondEnd {
825            target: 10,
826            current: 5,
827        };
828        let s = e.to_string();
829        assert!(
830            s.contains("10") && s.contains("5"),
831            "display must include positions, got: {s}"
832        );
833    }
834
835    // ── SpeculativeDecoder construction ───────────────────────────────────────
836
837    /// Constructing SpeculativeDecoder with two unloaded engines must succeed
838    /// (construction never fails); `generate` will return ModelNotLoaded.
839    #[test]
840    fn spec_decode_construction_with_unloaded_engines() {
841        use crate::engine::EngineConfig;
842        let draft = InferenceEngine::new(EngineConfig::default());
843        let target = InferenceEngine::new(EngineConfig::default());
844        let decoder = SpeculativeDecoder::new(draft, target, AsyncSpecConfig::default());
845        // Stats should be zero.
846        assert_eq!(decoder.stats().accepted, 0);
847        assert_eq!(decoder.stats().rejected, 0);
848    }
849
850    /// `spec_decode_correctness_stub`: constructing with unloaded engines and
851    /// calling generate must return ModelNotLoaded — the stub validates that
852    /// the error path is reachable.
853    #[tokio::test]
854    async fn spec_decode_correctness_stub() {
855        use crate::engine::EngineConfig;
856        let draft = InferenceEngine::new(EngineConfig::default());
857        let target = InferenceEngine::new(EngineConfig::default());
858        let mut decoder = SpeculativeDecoder::new(draft, target, AsyncSpecConfig::default());
859        let result = decoder.generate("hello", |_| {}).await;
860        assert!(
861            matches!(result, Err(RuntimeError::ModelNotLoaded)),
862            "expected ModelNotLoaded for unloaded decoder, got {result:?}"
863        );
864    }
865
866    /// `spec_decode_divergence_rollback`: a decoder where `force_n1` is set
867    /// must still construct and report stats correctly.
868    #[test]
869    fn spec_decode_divergence_rollback() {
870        use crate::engine::EngineConfig;
871        let draft = InferenceEngine::new(EngineConfig::default());
872        let target = InferenceEngine::new(EngineConfig::default());
873        let cfg = AsyncSpecConfig {
874            force_n1: true,
875            ..AsyncSpecConfig::default()
876        };
877        let mut decoder = SpeculativeDecoder::new_n1(draft, target, cfg);
878        decoder.reset_stats();
879        let stats = decoder.stats();
880        assert_eq!(stats.accepted, 0);
881        assert_eq!(stats.n1_fallbacks, 0);
882    }
883
884    /// `spec_decode_ssm_falls_back`: constructing with force_n1=true must
885    /// set the correct configuration.
886    #[test]
887    fn spec_decode_ssm_falls_back() {
888        use crate::engine::EngineConfig;
889        let draft = InferenceEngine::new(EngineConfig::default());
890        let target = InferenceEngine::new(EngineConfig::default());
891        let decoder = SpeculativeDecoder::new_n1(
892            draft,
893            target,
894            AsyncSpecConfig {
895                force_n1: true,
896                spec_k: 1,
897                ..AsyncSpecConfig::default()
898            },
899        );
900        assert!(
901            decoder.config.force_n1,
902            "force_n1 must be true when constructed with new_n1"
903        );
904        assert_eq!(decoder.config.spec_k, 1);
905    }
906
907    /// Cancellation token is a child of the engine's root token.
908    #[test]
909    fn cancellation_token_child_relationship() {
910        use crate::engine::EngineConfig;
911        let draft = InferenceEngine::new(EngineConfig::default());
912        let target = InferenceEngine::new(EngineConfig::default());
913        let decoder = SpeculativeDecoder::new(draft, target, AsyncSpecConfig::default());
914        let token = decoder.cancellation_token();
915        assert!(
916            !token.is_cancelled(),
917            "token must not be cancelled initially"
918        );
919    }
920
921    // ── With loaded model ─────────────────────────────────────────────────────
922
923    /// Verify that both engines can be loaded and generate succeeds (the loop
924    /// produces ModelNotLoaded because both engines are unloaded — this is a
925    /// structural test, not a functional one with real weights).
926    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
927    #[tokio::test]
928    async fn spec_decode_loaded_engines_produce_output() {
929        use crate::engine::EngineConfig;
930
931        let model_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
932        let tok_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
933
934        let mut draft_eng = InferenceEngine::new(EngineConfig::default());
935        draft_eng
936            .load_model_from_bytes(&model_bytes, tok_json)
937            .expect("draft load");
938
939        let mut target_eng = InferenceEngine::new(EngineConfig::default());
940        target_eng
941            .load_model_from_bytes(&model_bytes, tok_json)
942            .expect("target load");
943
944        let cfg = AsyncSpecConfig {
945            spec_k: 2,
946            max_tokens: 4,
947            ..AsyncSpecConfig::default()
948        };
949        let mut decoder = SpeculativeDecoder::new(draft_eng, target_eng, cfg);
950        let result = decoder.generate("a", |_| {}).await;
951        // The result may be Ok or Err depending on EOS sampling; what matters
952        // is that it does not panic.
953        assert!(
954            result.is_ok() || result.is_err(),
955            "generate must return Ok or a known error"
956        );
957    }
958}