Skip to main content

oxibonsai_runtime/
speculative.rs

1//! Speculative decoding for accelerated autoregressive generation.
2//!
3//! Speculative decoding uses a small "draft" model to generate K candidate tokens,
4//! which the larger "target" model then verifies in a single parallel forward pass.
5//! Accepted tokens are kept; the first rejected token is resampled from the target
6//! distribution. This can yield near-linear speedup proportional to the average
7//! number of accepted tokens per step.
8//!
9//! ## Algorithm (Leviathan et al., 2023)
10//!
11//! 1. Draft model generates K tokens: `t_1, ..., t_K` with draft probabilities `p_d`
12//! 2. Target model scores all K+1 positions in parallel, producing `p_t`
13//! 3. For each position `i`, accept `t_i` if:
14//!    - `p_t(t_i) >= p_d(t_i)`, OR
15//!    - with probability `p_t(t_i) / p_d(t_i)` (rejection sampling)
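//! 4. If rejected at position `i`, resample from the adjusted distribution
//! 5. After full acceptance, append one bonus token sampled from the target
//!
//! Note: this module currently implements a simplified form of steps 1–3 (the
//! draft probability used in verification is mocked) and does not yet perform
//! the resampling of step 4 or the bonus token of step 5 —
//! `SpeculativeDecoder::verify` returns only the accepted prefix.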
18//!
19//! ## Usage
20//!
21//! ```rust,no_run
22//! use oxibonsai_core::config::Qwen3Config;
23//! use oxibonsai_runtime::engine::InferenceEngine;
24//! use oxibonsai_runtime::sampling::SamplingParams;
25//! use oxibonsai_runtime::speculative::{SpeculativeConfig, SpeculativeDecoder};
26//!
27//! let config = Qwen3Config::tiny_test();
28//! let draft_engine = InferenceEngine::new(config, SamplingParams::default(), 42);
29//! let spec_config = SpeculativeConfig::default();
30//! let mut decoder = SpeculativeDecoder::new(draft_engine, spec_config);
31//! ```
32
33use crate::adaptive_lookahead::{AdaptiveLookahead, AdaptiveLookaheadConfig};
34use crate::engine::InferenceEngine;
35use crate::sampling::SamplingParams;
36
37// ──────────────────────────────────────────────────────────────────
38// Configuration
39// ──────────────────────────────────────────────────────────────────
40
/// Tuning knobs for speculative decoding.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// How many draft tokens to propose per step (the lookahead `K`; 4–8 is typical).
    pub lookahead: usize,
    /// Extra margin subtracted from the rejection-sampling acceptance probability.
    ///
    /// `0.0` leaves the plain speculative-sampling criterion untouched. Values
    /// above `0.0` make acceptance stricter on the probabilistic branch: fewer
    /// draft tokens survive each step, but output stays closer to the target
    /// distribution.
    pub acceptance_threshold: f32,
}

impl Default for SpeculativeConfig {
    /// Lookahead of 5 draft tokens and no extra acceptance margin.
    fn default() -> Self {
        let lookahead = 5;
        let acceptance_threshold = 0.0;
        Self {
            lookahead,
            acceptance_threshold,
        }
    }
}
61
62// ──────────────────────────────────────────────────────────────────
63// Step result
64// ──────────────────────────────────────────────────────────────────
65
/// Outcome of a single draft-then-verify round.
#[derive(Debug, Clone)]
pub struct SpeculativeStep {
    /// Every token the draft model proposed this round, in order.
    pub draft_tokens: Vec<u32>,
    /// The tokens that survived verification against the target.
    pub accepted_tokens: Vec<u32>,
    /// Share of this round's proposals that were kept: `accepted / proposed`.
    pub acceptance_rate: f32,
}
76
77// ──────────────────────────────────────────────────────────────────
78// Internal mini-PRNG (xorshift64, no external rand crate)
79// ──────────────────────────────────────────────────────────────────
80
/// Tiny xorshift64 pseudo-random generator, kept in-crate so no `rand`
/// dependency is needed.
struct Xorshift64 {
    /// Current generator state; never zero after construction.
    state: u64,
}

impl Xorshift64 {
    /// Replacement seed used when the caller passes `0`: an all-zero xorshift
    /// state is a fixed point and would only ever yield zeros.
    const ZERO_SEED_REPLACEMENT: u64 = 0xdeadbeef_cafebabe;

    /// Build a generator from `seed`, substituting a fixed non-zero seed for `0`.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => Self::ZERO_SEED_REPLACEMENT,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advance the state with the (13, 7, 17) xorshift triple and return it.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Draw a float uniformly from `[0.0, 1.0)`.
    ///
    /// Only the top 24 bits are used, which is exactly the precision an `f32`
    /// mantissa can represent in that range.
    fn next_f32(&mut self) -> f32 {
        const SCALE: f32 = (1u64 << 24) as f32;
        let top_bits = self.next_u64() >> 40;
        top_bits as f32 / SCALE
    }
}
106
107// ──────────────────────────────────────────────────────────────────
108// SpeculativeDecoder
109// ──────────────────────────────────────────────────────────────────
110
111/// Speculative decoder: wraps a draft [`InferenceEngine`] and provides
112/// draft-then-verify generation with running acceptance statistics.
113pub struct SpeculativeDecoder<'a> {
114    /// Draft model engine (smaller/faster model).
115    pub draft_engine: InferenceEngine<'a>,
116    /// Speculative decoding configuration.
117    pub config: SpeculativeConfig,
118    /// Total number of speculative steps taken.
119    pub total_steps: u64,
120    /// Total number of tokens proposed by the draft model.
121    pub total_draft_tokens: u64,
122    /// Total number of tokens accepted after target verification.
123    pub total_accepted_tokens: u64,
124    /// Internal PRNG for rejection sampling decisions (available for subtype use).
125    #[allow(dead_code)]
126    rng: Xorshift64,
127    /// Optional adaptive controller — when present, the lookahead is
128    /// updated after each step from the running acceptance EWMA.
129    adaptive: Option<AdaptiveLookahead>,
130}
131
132impl<'a> SpeculativeDecoder<'a> {
133    /// Create a new speculative decoder with the given draft engine and config.
134    pub fn new(draft_engine: InferenceEngine<'a>, config: SpeculativeConfig) -> Self {
135        Self {
136            draft_engine,
137            config,
138            total_steps: 0,
139            total_draft_tokens: 0,
140            total_accepted_tokens: 0,
141            rng: Xorshift64::new(0xfeed1234_5678abcd),
142            adaptive: None,
143        }
144    }
145
146    /// Create a speculative decoder with an [`AdaptiveLookahead`] controller
147    /// active. The initial lookahead is taken from `adaptive_config.initial`
148    /// and overrides `config.lookahead` for the first step.
149    pub fn with_adaptive(
150        draft_engine: InferenceEngine<'a>,
151        config: SpeculativeConfig,
152        adaptive_config: AdaptiveLookaheadConfig,
153    ) -> Result<Self, crate::adaptive_lookahead::AdaptiveLookaheadError> {
154        let adaptive = AdaptiveLookahead::try_new(adaptive_config)?;
155        let mut config = config;
156        config.lookahead = adaptive.lookahead();
157        Ok(Self {
158            draft_engine,
159            config,
160            total_steps: 0,
161            total_draft_tokens: 0,
162            total_accepted_tokens: 0,
163            rng: Xorshift64::new(0xfeed1234_5678abcd),
164            adaptive: Some(adaptive),
165        })
166    }
167
168    /// Read the current adaptive controller, if any.
169    pub fn adaptive(&self) -> Option<&AdaptiveLookahead> {
170        self.adaptive.as_ref()
171    }
172
173    /// Mutable access to the adaptive controller, if any.
174    pub fn adaptive_mut(&mut self) -> Option<&mut AdaptiveLookahead> {
175        self.adaptive.as_mut()
176    }
177
178    /// Generate up to `config.lookahead` draft tokens from the draft model.
179    ///
180    /// In this implementation, the draft engine uses its sampler to produce tokens
181    /// autoregressively from `context`. The returned tokens are the draft candidates
182    /// for target-model verification.
183    pub fn draft(&mut self, context: &[u32], _params: &SamplingParams) -> Vec<u32> {
184        let k = self.config.lookahead;
185        let mut draft_tokens = Vec::with_capacity(k);
186
187        // Build a combined context + generated so far
188        let mut current_context: Vec<u32> = context.to_vec();
189
190        for _ in 0..k {
191            // Generate one token using the draft engine
192            match self.draft_engine.generate(&current_context, 1) {
193                Ok(generated) if !generated.is_empty() => {
194                    let token = generated[0];
195                    draft_tokens.push(token);
196                    current_context.push(token);
197                }
198                _ => {
199                    // Draft generation failed or returned empty — stop drafting
200                    break;
201                }
202            }
203        }
204
205        draft_tokens
206    }
207
208    /// Verify draft tokens against target-model logits.
209    ///
210    /// For each draft position `i`, the target's probability `p_t(t_i)` is
211    /// compared against a mock draft probability `p_d(t_i)` derived from
212    /// the target logits (as a self-consistency check when target logits are
213    /// provided). In production, `p_d` comes from the draft model's softmax.
214    ///
215    /// Acceptance criterion (speculative sampling):
216    /// - Accept if `p_t(t_i) >= p_d(t_i)`
217    /// - Else accept with probability `p_t(t_i) / p_d(t_i)`
218    ///
219    /// Returns only the prefix of tokens accepted before the first rejection.
220    pub fn verify(
221        &self,
222        draft_tokens: &[u32],
223        target_logits: &[Vec<f32>],
224        _params: &SamplingParams,
225    ) -> Vec<u32> {
226        let mut accepted = Vec::with_capacity(draft_tokens.len());
227
228        // We need a mutable PRNG — use a local one seeded from step count for reproducibility
229        let mut local_rng = Xorshift64::new(
230            self.total_steps
231                .wrapping_mul(6364136223846793005)
232                .wrapping_add(0xabcdef01),
233        );
234
235        for (i, &token) in draft_tokens.iter().enumerate() {
236            let logits = match target_logits.get(i) {
237                Some(l) => l,
238                None => break,
239            };
240
241            if logits.is_empty() {
242                break;
243            }
244
245            // Compute softmax probabilities for target
246            let target_probs = softmax(logits);
247
248            // Get target probability for this draft token
249            let target_prob = if (token as usize) < target_probs.len() {
250                target_probs[token as usize]
251            } else {
252                0.0
253            };
254
255            // Mock draft probability: use a uniform-like estimate over top candidates
256            // In production this would come from the draft model's own softmax output.
257            // Here we use 1/vocab_size as a conservative draft estimate.
258            let vocab_size = logits.len() as f32;
259            let draft_prob = (1.0 / vocab_size).max(1e-9);
260
261            let rng_sample = local_rng.next_f32();
262            let threshold = self.config.acceptance_threshold;
263
264            if Self::should_accept(draft_prob, target_prob, threshold, rng_sample) {
265                accepted.push(token);
266            } else {
267                // First rejection — stop here
268                break;
269            }
270        }
271
272        accepted
273    }
274
275    /// Perform one complete speculative decoding step: draft K tokens then verify.
276    ///
277    /// Returns a [`SpeculativeStep`] with the draft proposals, accepted subset,
278    /// and per-step acceptance rate.
279    pub fn step(
280        &mut self,
281        context: &[u32],
282        target_logits: &[Vec<f32>],
283        params: &SamplingParams,
284    ) -> SpeculativeStep {
285        // Phase 1: Draft
286        let draft_tokens = self.draft(context, params);
287        let n_drafted = draft_tokens.len();
288
289        // Phase 2: Verify
290        let accepted_tokens = self.verify(&draft_tokens, target_logits, params);
291        let n_accepted = accepted_tokens.len();
292
293        // Update statistics
294        self.total_steps += 1;
295        self.total_draft_tokens += n_drafted as u64;
296        self.total_accepted_tokens += n_accepted as u64;
297
298        // Feed the adaptive controller (if any) and apply its lookahead update.
299        if let Some(adaptive) = self.adaptive.as_mut() {
300            adaptive.observe_step(n_drafted, n_accepted);
301            // The controller may have changed `lookahead` — propagate it to
302            // `config.lookahead` so the next `step` drafts the new amount.
303            self.config.lookahead = adaptive.lookahead();
304        }
305
306        let acceptance_rate = if n_drafted > 0 {
307            n_accepted as f32 / n_drafted as f32
308        } else {
309            0.0
310        };
311
312        SpeculativeStep {
313            draft_tokens,
314            accepted_tokens,
315            acceptance_rate,
316        }
317    }
318
319    /// Generate up to `max_tokens` tokens using speculative decoding.
320    ///
321    /// Each step drafts `lookahead` candidates, verifies them, and appends
322    /// accepted tokens. The loop continues until `max_tokens` are collected
323    /// or generation stalls (no tokens accepted/generated).
324    ///
325    /// In this mock implementation, target logits are synthesised from the
326    /// draft engine's perspective — in production the target model would
327    /// score all positions in one batched forward pass.
328    pub fn generate_speculative(
329        &mut self,
330        prompt_tokens: &[u32],
331        max_tokens: usize,
332        params: &SamplingParams,
333    ) -> Vec<u32> {
334        let mut output: Vec<u32> = Vec::with_capacity(max_tokens);
335        let mut context: Vec<u32> = prompt_tokens.to_vec();
336
337        while output.len() < max_tokens {
338            let remaining = max_tokens - output.len();
339            let effective_lookahead = self.config.lookahead.min(remaining);
340
341            // Synthesise mock target logits for each draft position.
342            // In production: run target model forward pass over all positions.
343            // Here we generate uniform-ish logits for each draft position using PRNG.
344            let vocab_size = 32000usize; // representative for Qwen3
345            let target_logits: Vec<Vec<f32>> = (0..effective_lookahead)
346                .map(|step_idx| {
347                    // Build a peaked distribution at a token derived from context + step
348                    let peak_token =
349                        (context.last().copied().unwrap_or(0) as usize + step_idx + 1) % vocab_size;
350                    let mut logits = vec![0.0f32; vocab_size];
351                    // Give the peak token high logit, others low
352                    logits[peak_token] = 10.0;
353                    for (i, l) in logits.iter_mut().enumerate() {
354                        if i != peak_token {
355                            *l = -2.0;
356                        }
357                    }
358                    logits
359                })
360                .collect();
361
362            let step_result = self.step(&context, &target_logits, params);
363
364            if step_result.accepted_tokens.is_empty() {
365                // No tokens accepted — try generating one greedily to avoid infinite loop
366                match self.draft_engine.generate(&context, 1) {
367                    Ok(t) if !t.is_empty() => {
368                        let token = t[0];
369                        output.push(token);
370                        context.push(token);
371                    }
372                    _ => break,
373                }
374            } else {
375                let to_take = step_result.accepted_tokens.len().min(remaining);
376                for &tok in step_result.accepted_tokens[..to_take].iter() {
377                    output.push(tok);
378                    context.push(tok);
379                    if output.len() >= max_tokens {
380                        break;
381                    }
382                }
383            }
384
385            // Safety: break if context grows unexpectedly large
386            if context.len() > prompt_tokens.len() + max_tokens + self.config.lookahead {
387                break;
388            }
389        }
390
391        output
392    }
393
394    /// Overall acceptance rate: accepted tokens / draft tokens, across all steps.
395    ///
396    /// Returns 0.0 if no drafts have been generated yet.
397    pub fn acceptance_rate(&self) -> f32 {
398        if self.total_draft_tokens == 0 {
399            return 0.0;
400        }
401        self.total_accepted_tokens as f32 / self.total_draft_tokens as f32
402    }
403
404    /// Theoretical speedup estimate from speculative decoding.
405    ///
406    /// Speedup ≈ accepted tokens per step (capped at lookahead).
407    /// Returns the mean accepted tokens per step, which indicates how many
408    /// target forward passes were "skipped" relative to autoregressive decoding.
409    ///
410    /// A return of 1.0 means no speedup (equivalent to autoregressive); higher
411    /// values indicate benefit from speculative parallelism.
412    pub fn speedup_estimate(&self) -> f32 {
413        if self.total_steps == 0 {
414            return 1.0;
415        }
416        let avg_accepted = self.total_accepted_tokens as f32 / self.total_steps as f32;
417        // Speedup is bounded by lookahead + 1 (the bonus token)
418        avg_accepted.max(1.0)
419    }
420
421    /// Reset all accumulated statistics (steps, tokens, acceptance counts).
422    /// If an adaptive controller is attached, its EWMA is also reset.
423    pub fn reset_stats(&mut self) {
424        self.total_steps = 0;
425        self.total_draft_tokens = 0;
426        self.total_accepted_tokens = 0;
427        if let Some(adaptive) = self.adaptive.as_mut() {
428            adaptive.reset();
429            self.config.lookahead = adaptive.lookahead();
430        }
431    }
432
433    /// Determine whether a draft token should be accepted.
434    ///
435    /// Implements the speculative sampling acceptance criterion:
436    /// - If `target_prob >= draft_prob`: always accept
437    /// - Otherwise: accept with probability `target_prob / draft_prob`
438    ///
439    /// The `threshold` parameter can optionally raise the bar for acceptance.
440    /// `rng_sample` must be in `[0.0, 1.0)`.
441    fn should_accept(draft_prob: f32, target_prob: f32, threshold: f32, rng_sample: f32) -> bool {
442        if target_prob >= draft_prob {
443            // Target assigns higher probability — always accept
444            true
445        } else {
446            // Rejection sampling: accept with prob target/draft
447            let accept_prob = (target_prob / draft_prob).max(0.0);
448            let effective_threshold = accept_prob - threshold;
449            rng_sample < effective_threshold
450        }
451    }
452}
453
454// ──────────────────────────────────────────────────────────────────
455// Utility: softmax over f32 slice
456// ──────────────────────────────────────────────────────────────────
457
/// Compute a numerically stable softmax over a logit slice.
///
/// The maximum logit is subtracted before exponentiating, to avoid overflow.
/// Returns an empty vector for empty input.
///
/// If the normaliser is not a usable positive finite number, a uniform
/// distribution is returned instead of `NaN`s. This happens when every logit
/// is `-inf`, or when any logit is `+inf` or `NaN`, because each of these makes
/// the exponent sum `NaN`.
fn softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    // `f32::max` skips NaN operands, so `max_val` is NaN-free.
    let max_val = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut exps: Vec<f32> = logits.iter().map(|&l| (l - max_val).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // `NaN < 1e-30` is false, so non-finite sums must be checked explicitly.
    if !sum.is_finite() || sum < 1e-30 {
        let uniform = 1.0 / logits.len() as f32;
        return vec![uniform; logits.len()];
    }
    for e in &mut exps {
        *e /= sum;
    }
    exps
}
473
474// ──────────────────────────────────────────────────────────────────
475// Tests
476// ──────────────────────────────────────────────────────────────────
477
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;

    /// Build a decoder around the tiny test model with the given lookahead and
    /// no extra acceptance margin.
    fn make_decoder(lookahead: usize) -> SpeculativeDecoder<'static> {
        let config = Qwen3Config::tiny_test();
        let params = SamplingParams::default();
        let engine = InferenceEngine::new(config, params, 42);
        let spec_config = SpeculativeConfig {
            lookahead,
            acceptance_threshold: 0.0,
        };
        SpeculativeDecoder::new(engine, spec_config)
    }

    /// Build a decoder around the tiny test model with an adaptive controller.
    fn make_adaptive_decoder(
        lookahead: usize,
        adapt_cfg: AdaptiveLookaheadConfig,
    ) -> SpeculativeDecoder<'static> {
        let config = Qwen3Config::tiny_test();
        let params = SamplingParams::default();
        let engine = InferenceEngine::new(config, params, 42);
        let spec_cfg = SpeculativeConfig {
            lookahead,
            acceptance_threshold: 0.0,
        };
        SpeculativeDecoder::with_adaptive(engine, spec_cfg, adapt_cfg).expect("valid")
    }

    /// `n_positions` identical logit rows with a strong peak at `peak_token`.
    fn make_peaked_logits(
        vocab_size: usize,
        peak_token: usize,
        n_positions: usize,
    ) -> Vec<Vec<f32>> {
        (0..n_positions)
            .map(|_| {
                let mut logits = vec![-5.0f32; vocab_size];
                if peak_token < vocab_size {
                    logits[peak_token] = 10.0;
                }
                logits
            })
            .collect()
    }

    #[test]
    fn test_speculative_config_defaults() {
        let cfg = SpeculativeConfig::default();
        assert_eq!(cfg.lookahead, 5, "default lookahead should be 5");
        assert!(
            (cfg.acceptance_threshold - 0.0).abs() < f32::EPSILON,
            "default threshold should be 0.0"
        );
    }

    #[test]
    fn test_draft_generates_lookahead_tokens() {
        let mut decoder = make_decoder(3);
        let context = vec![1u32, 2, 3];
        let params = SamplingParams::default();
        let draft = decoder.draft(&context, &params);
        // Drafting may stop early (engine error / empty output), never overshoot.
        assert!(
            draft.len() <= 3,
            "draft should not exceed lookahead=3, got {}",
            draft.len()
        );
    }

    #[test]
    fn test_verify_accepts_high_probability_tokens() {
        let decoder = make_decoder(5);
        let params = SamplingParams::default();
        let vocab_size = 100;

        // Token 42 is the draft token and carries almost all target mass.
        let draft_tokens = vec![42u32];
        let target_logits = make_peaked_logits(vocab_size, 42, 1);

        let accepted = decoder.verify(&draft_tokens, &target_logits, &params);
        assert_eq!(
            accepted.len(),
            1,
            "high-probability token should be accepted"
        );
        assert_eq!(accepted[0], 42);
    }

    #[test]
    fn test_verify_rejects_low_probability_tokens() {
        let decoder = make_decoder(5);
        let params = SamplingParams::default();
        let vocab_size = 1000;

        // Token 500 sits 30 logits below the peak at token 0, so its target
        // probability is vanishingly small compared to the mock draft
        // probability (1 / vocab_size), and the acceptance probability is ~0.
        let draft_tokens = vec![500u32];
        let mut logits = vec![-10.0f32; vocab_size];
        logits[0] = 20.0;
        let target_logits = vec![logits];

        // `verify` seeds its PRNG from `total_steps`, so repeated calls on an
        // unchanged decoder are identical; a single call is the whole test.
        let accepted = decoder.verify(&draft_tokens, &target_logits, &params);
        assert!(
            accepted.is_empty(),
            "low-probability token should be rejected, got {accepted:?}"
        );
    }

    #[test]
    fn test_acceptance_rate_zero_at_start() {
        let decoder = make_decoder(5);
        assert!(
            (decoder.acceptance_rate() - 0.0).abs() < f32::EPSILON,
            "acceptance rate must be 0.0 before any steps"
        );
        assert_eq!(decoder.total_steps, 0);
        assert_eq!(decoder.total_draft_tokens, 0);
        assert_eq!(decoder.total_accepted_tokens, 0);
    }

    #[test]
    fn test_acceptance_rate_updates_after_step() {
        let mut decoder = make_decoder(4);
        let params = SamplingParams::default();
        let context = vec![1u32, 2, 3];

        let vocab_size = 32usize;
        let target_logits = make_peaked_logits(vocab_size, 5, 4);

        let step = decoder.step(&context, &target_logits, &params);

        assert_eq!(decoder.total_steps, 1, "one step should have been recorded");
        assert_eq!(
            decoder.total_draft_tokens,
            step.draft_tokens.len() as u64,
            "draft token count should match"
        );
        assert!(
            decoder.total_accepted_tokens <= decoder.total_draft_tokens,
            "accepted cannot exceed drafted"
        );
        // Accepted tokens are always a prefix of the draft.
        let n_accepted = step.accepted_tokens.len();
        assert!(n_accepted <= step.draft_tokens.len());
        assert_eq!(
            &step.draft_tokens[..n_accepted],
            &step.accepted_tokens[..],
            "accepted tokens must be a prefix of the draft"
        );
    }

    #[test]
    fn test_generate_speculative_returns_tokens() {
        let mut decoder = make_decoder(3);
        let params = SamplingParams::default();
        let prompt = vec![1u32, 2, 3];

        let output = decoder.generate_speculative(&prompt, 5, &params);
        assert!(
            output.len() <= 5,
            "output should not exceed max_tokens=5, got {}",
            output.len()
        );
    }

    #[test]
    fn test_should_accept_target_above_draft() {
        // target_prob > draft_prob: always accept regardless of rng_sample.
        assert!(
            SpeculativeDecoder::should_accept(0.1, 0.9, 0.0, 0.99),
            "target > draft: must accept even with rng_sample near 1.0"
        );
        assert!(
            SpeculativeDecoder::should_accept(0.05, 0.5, 0.0, 0.0),
            "target > draft: must accept with rng_sample=0.0"
        );
    }

    #[test]
    fn test_should_accept_target_below_draft_probabilistic() {
        // target=0.1, draft=1.0 → accept_prob = 0.1.
        assert!(
            SpeculativeDecoder::should_accept(1.0, 0.1, 0.0, 0.05),
            "rng_sample=0.05 < accept_prob=0.1, should accept"
        );
        assert!(
            !SpeculativeDecoder::should_accept(1.0, 0.1, 0.0, 0.5),
            "rng_sample=0.5 >= accept_prob=0.1, should reject"
        );
    }

    #[test]
    fn test_speedup_estimate_below_lookahead() {
        let mut decoder = make_decoder(5);
        assert!(
            (decoder.speedup_estimate() - 1.0).abs() < f32::EPSILON,
            "initial speedup should be 1.0"
        );

        // Simulated stats: 10 steps, 30 drafted, 15 accepted.
        decoder.total_steps = 10;
        decoder.total_draft_tokens = 30;
        decoder.total_accepted_tokens = 15;

        let speedup = decoder.speedup_estimate();
        // avg_accepted = 15 / 10 = 1.5; max(1.5, 1.0) = 1.5
        assert!(
            (speedup - 1.5).abs() < 1e-4,
            "speedup should be 1.5 (avg accepted per step), got {speedup}"
        );
        assert!(
            speedup <= decoder.config.lookahead as f32 + 1.0,
            "speedup cannot exceed lookahead+1"
        );
    }

    #[test]
    fn test_with_adaptive_starts_with_initial_lookahead() {
        let adapt_cfg = AdaptiveLookaheadConfig {
            initial: 4,
            min: 2,
            max: 10,
            alpha: 0.5,
            cooldown_steps: 1,
        };
        let decoder = make_adaptive_decoder(99, adapt_cfg);
        // The controller's initial value overrides the spec config's lookahead.
        assert_eq!(decoder.config.lookahead, 4);
        assert!(decoder.adaptive().is_some());
    }

    #[test]
    fn test_adaptive_decreases_lookahead_on_low_acceptance() {
        let adapt_cfg = AdaptiveLookaheadConfig {
            initial: 8,
            min: 2,
            max: 12,
            alpha: 0.7,
            cooldown_steps: 1,
        };
        let mut decoder = make_adaptive_decoder(8, adapt_cfg);
        let context = vec![1u32, 2, 3];
        let params = SamplingParams::default();
        // Flat target rows (token 0 pushed far down) — expected to cause mostly rejections.
        let vocab = 100usize;
        let logits: Vec<Vec<f32>> = (0..decoder.config.lookahead)
            .map(|_| {
                let mut l = vec![10.0f32; vocab];
                l[0] = -50.0;
                l
            })
            .collect();
        for _ in 0..30 {
            decoder.step(&context, &logits, &params);
        }
        // With low acceptance the lookahead must at least not have grown.
        let final_la = decoder.config.lookahead;
        assert!(
            final_la <= 8,
            "lookahead should not increase, got {final_la}"
        );
    }

    #[test]
    fn test_reset_stats_resets_adaptive() {
        let adapt_cfg = AdaptiveLookaheadConfig {
            initial: 5,
            min: 2,
            max: 12,
            alpha: 0.5,
            cooldown_steps: 1,
        };
        let mut decoder = make_adaptive_decoder(5, adapt_cfg);
        // Drive the adaptive controller into a different state.
        for _ in 0..30 {
            let logits = make_peaked_logits(64, 5, decoder.config.lookahead);
            decoder.step(&[1, 2, 3], &logits, &SamplingParams::default());
        }
        decoder.reset_stats();
        assert_eq!(decoder.total_steps, 0);
        assert_eq!(decoder.config.lookahead, 5);
        assert_eq!(
            decoder.adaptive().expect("adaptive present").observations(),
            0
        );
    }
}