// rlx_runtime/spec_decode.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Speculative decoding scheduling pattern (plan #34).
//!
//! Borrowed from MAX's serving scheduler structure
//! (`one_shot_scheduler.py`, decode/prefill split). Implements the
//! classic algorithm from Leviathan et al., "Fast Inference from
//! Transformers via Speculative Decoding": a small draft model proposes
//! `n` tokens; the larger target model verifies all `n` in one forward
//! pass; tokens are accepted up to the first rejection, and on a
//! rejection one extra "corrected" token is sampled from the residual
//! distribution. (When all `n` are accepted no extra token is sampled,
//! because the target only returns distributions for the `n` proposed
//! positions.)
//!
//! Expected speedup on decode-heavy workloads: 2-3×.
//!
//! Layout:
//!   - [`Speculator`] — trait an autoregressive model implements.
//!     Two required methods, `propose` (draft) and `verify` (target),
//!     plus an optional `commit` hook.
//!   - [`DraftProposal`] / [`VerifyResult`] / [`AcceptDecision`]
//!     — wire-format data shapes.
//!   - [`speculative_accept`] — pure function that runs the
//!     acceptance algorithm. Testable without a real model.
//!   - [`SpecDecoder`] — orchestrator that calls a draft + target
//!     and returns the tokens produced by one round.

39use rlx_ir::Philox4x32;
40
/// The draft model's output for a single speculative round.
///
/// [`speculative_accept`] requires `tokens.len() == probs.len()`.
#[derive(Clone, Debug)]
pub struct DraftProposal {
    /// The `n` candidate tokens the draft emitted, in order (chosen
    /// greedily or stochastically — up to the draft implementation).
    pub tokens: Vec<u32>,
    /// One probability row per proposed position (`n` rows, each
    /// `vocab` entries long). `probs[i][tokens[i] as usize]` is the
    /// probability the draft gave to the token it actually proposed.
    pub probs: Vec<Vec<f32>>,
}

/// The target model's scoring of a [`DraftProposal`].
#[derive(Clone, Debug)]
pub struct VerifyResult {
    /// One probability row per proposed position (`n` rows of `vocab`
    /// entries). Row `i` is the target's distribution given the prefix
    /// followed by the first `i` draft tokens.
    pub probs: Vec<Vec<f32>>,
}

/// Outcome of one speculative-decoding round, as produced by
/// [`speculative_accept`].
///
/// A decision has exactly one of two shapes:
/// - **a draft token was rejected** at position `i`: `accepted` holds the
///   `i` tokens before it and `corrected` is `Some(token)` drawn from the
///   residual distribution, for `accepted.len() + 1` tokens in total;
/// - **every draft token was accepted**: `accepted` holds all `n` tokens
///   and `corrected` is `None`, for exactly `n` tokens in total. No bonus
///   token is sampled in this case, because [`VerifyResult`] only carries
///   target distributions for the `n` proposed positions.
#[derive(Debug, Clone)]
pub struct AcceptDecision {
    /// Draft tokens that passed verification, in order (`0..=n` of them).
    pub accepted: Vec<u32>,
    /// Replacement token sampled after the first rejection; `None`
    /// exactly when all `n` draft tokens were accepted.
    pub corrected: Option<u32>,
}

impl AcceptDecision {
    /// Number of tokens this round contributes to the sequence: the
    /// accepted draft tokens plus the corrected token, if there is one.
    #[must_use]
    pub fn total_tokens(&self) -> usize {
        self.accepted.len() + usize::from(self.corrected.is_some())
    }
}

80/// Streaming speculator interface — one method to draft, one to
81/// verify. Real implementations bind to a `CompiledGraph` per
82/// model; testable implementations can return canned probability
83/// tables.
84pub trait Speculator {
85    /// Propose `n` tokens given the current `context`. Returns the
86    /// proposed tokens + the draft's probability tables.
87    fn propose(&mut self, context: &[u32], n: usize) -> DraftProposal;
88
89    /// Verify a batch of `proposed` tokens in one forward pass:
90    /// for each position `i ∈ 0..n`, return the *target* model's
91    /// probability distribution conditioned on
92    /// `context ++ proposed[..i]`.
93    fn verify(&mut self, context: &[u32], proposed: &[u32]) -> VerifyResult;
94
95    /// Commit `accepted` tokens into persistent decode state after a
96    /// speculative round. Default no-op; MTP draft overrides so GDN
97    /// recurrent state only advances for accepted tokens.
98    fn commit(&mut self, context: &[u32], accepted: &[u32]) {
99        let _ = (context, accepted);
100    }
101}
102
103/// Pure speculative-acceptance algorithm. Given the draft's
104/// proposal and the target's verification, runs the
105/// per-position accept/reject test and returns the final
106/// decision. No model state, no I/O — easy to unit-test against
107/// hand-built distributions.
108///
109/// Algorithm (Leviathan et al. 2022, Algorithm 1):
110///   for i in 0..n:
111///     r ~ Uniform(0,1)
112///     if r < min(1, q_target(x_i) / p_draft(x_i)):
113///       accept x_i
114///     else:
115///       sample x' from norm(max(0, q - p))
116///       return (accepted[..i], Some(x'))
117///   return (all n accepted, None)
118pub fn speculative_accept(
119    proposal: &DraftProposal,
120    verify: &VerifyResult,
121    rng: &mut Philox4x32,
122) -> AcceptDecision {
123    assert_eq!(
124        proposal.tokens.len(),
125        proposal.probs.len(),
126        "DraftProposal: tokens and probs must agree"
127    );
128    assert_eq!(
129        proposal.probs.len(),
130        verify.probs.len(),
131        "DraftProposal and VerifyResult must propose the same n"
132    );
133    let n = proposal.tokens.len();
134    let mut accepted: Vec<u32> = Vec::with_capacity(n);
135    for i in 0..n {
136        let token = proposal.tokens[i];
137        let p = proposal.probs[i][token as usize].max(f32::MIN_POSITIVE);
138        let q = verify.probs[i][token as usize];
139        let accept_ratio = (q / p).min(1.0);
140        let r = rng.next_f32();
141        if r < accept_ratio {
142            accepted.push(token);
143        } else {
144            let corrected = sample_corrected_residual(&proposal.probs[i], &verify.probs[i], rng);
145            return AcceptDecision {
146                accepted,
147                corrected: Some(corrected),
148            };
149        }
150    }
151    AcceptDecision {
152        accepted,
153        corrected: None,
154    }
155}
156
157/// Sample from the *residual* distribution `norm(max(0, q - p))`.
158/// This is the "what the target prefers but the draft missed"
159/// distribution, used after a rejection so the round still emits
160/// a valid sample from the target.
161fn sample_corrected_residual(p: &[f32], q: &[f32], rng: &mut Philox4x32) -> u32 {
162    let mut adj: Vec<f32> = q.iter().zip(p).map(|(qi, pi)| (qi - pi).max(0.0)).collect();
163    let sum: f32 = adj.iter().sum();
164    if sum <= f32::MIN_POSITIVE {
165        // q ≤ p elementwise (extreme edge case): fall back to
166        // sampling from q directly.
167        return sample_from(q, rng);
168    }
169    let inv = 1.0 / sum;
170    for v in adj.iter_mut() {
171        *v *= inv;
172    }
173    sample_from(&adj, rng)
174}
175
176fn sample_from(probs: &[f32], rng: &mut Philox4x32) -> u32 {
177    let r = rng.next_f32();
178    let mut acc = 0f32;
179    for (i, &p) in probs.iter().enumerate() {
180        acc += p;
181        if r <= acc {
182            return i as u32;
183        }
184    }
185    (probs.len() - 1) as u32
186}
187
188/// Top-level orchestrator. Holds a draft + target speculator and
189/// the lookahead window `n`. `step()` runs one full round and
190/// returns the tokens to append to the running context.
191pub struct SpecDecoder<D: Speculator, T: Speculator> {
192    pub draft: D,
193    pub target: T,
194    pub n: usize,
195    rng: Philox4x32,
196}
197
198impl<D: Speculator, T: Speculator> SpecDecoder<D, T> {
199    pub fn new(draft: D, target: T, n: usize, seed: u64) -> Self {
200        Self {
201            draft,
202            target,
203            n,
204            rng: Philox4x32::new(seed),
205        }
206    }
207
208    /// One speculative-decoding round. Returns the tokens that
209    /// should be appended to `context`.
210    pub fn step(&mut self, context: &[u32]) -> Vec<u32> {
211        let proposal = self.draft.propose(context, self.n);
212        let verify = self.target.verify(context, &proposal.tokens);
213        let decision = speculative_accept(&proposal, &verify, &mut self.rng);
214        let mut out = decision.accepted;
215        if let Some(c) = decision.corrected {
216            out.push(c);
217        }
218        self.draft.commit(context, &out);
219        self.target.commit(context, &out);
220        out
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// When draft and target agree perfectly (same probs), every
    /// proposed token must be accepted (accept_ratio = 1.0).
    #[test]
    fn identical_distributions_accept_all() {
        let n = 4;
        let vocab = 8;
        // Draft proposed token = argmax of a peaked distribution.
        // Target's distribution is identical → q/p = 1.0 → always
        // accept.
        let mut probs = Vec::with_capacity(n);
        let mut tokens = Vec::with_capacity(n);
        for i in 0..n {
            let mut row = vec![0.01f32; vocab];
            let pick = (i * 2) % vocab;
            row[pick] = 1.0 - 0.01 * (vocab - 1) as f32;
            probs.push(row);
            tokens.push(pick as u32);
        }
        let proposal = DraftProposal {
            tokens: tokens.clone(),
            probs: probs.clone(),
        };
        let verify = VerifyResult { probs };

        // 100 trials with different seeds; all should accept all 4.
        for seed in 0..100u64 {
            let mut rng = Philox4x32::new(seed + 1);
            let d = speculative_accept(&proposal, &verify, &mut rng);
            assert_eq!(d.accepted, tokens, "seed {seed}: should accept all");
            assert!(d.corrected.is_none());
        }
    }

    /// When the draft places mass on tokens the target rejects
    /// (q ≪ p on those tokens), rejections dominate, and every
    /// rejection is followed by a correction drawn from the residual.
    #[test]
    fn divergent_distributions_reject_sometimes() {
        let n = 4;
        // Draft ALWAYS picks token 0; target wants token 3.
        let draft_row = vec![0.97f32, 0.01, 0.01, 0.01];
        let target_row = vec![0.01f32, 0.01, 0.01, 0.97];
        let proposal = DraftProposal {
            tokens: vec![0u32; n],
            probs: vec![draft_row; n],
        };
        let verify = VerifyResult {
            probs: vec![target_row; n],
        };

        let mut total_accepted = 0usize;
        let trials = 200;
        for seed in 0..trials {
            let mut rng = Philox4x32::new(seed + 1);
            let d = speculative_accept(&proposal, &verify, &mut rng);
            total_accepted += d.accepted.len();
            if d.accepted.len() < n {
                // max(0, q - p) = [0, 0, 0, 0.96], so after normalisation
                // the residual puts all of its mass on token 3.
                assert_eq!(
                    d.corrected,
                    Some(3),
                    "rejection at seed {seed} should yield corrected token 3"
                );
            } else {
                assert!(d.corrected.is_none());
            }
        }
        // Per-position acceptance probability is a = 0.01/0.97 ≈ 0.0103.
        // Each round stops at its first rejection, so the expected number
        // of accepted tokens per round is about a / (1 - a) ≈ 0.0104, or
        // roughly 2 across all 200 rounds (at most 800 are possible).
        assert!(
            total_accepted < 80,
            "divergent distributions should accept rarely; got {total_accepted}/800"
        );
    }

    /// Mock speculators for end-to-end SpecDecoder basic test.
    /// Both return canned probability tables.
    struct CannedSpeculator {
        next_token: u32,
        peaked_prob: f32,
    }

    impl Speculator for CannedSpeculator {
        fn propose(&mut self, _ctx: &[u32], n: usize) -> DraftProposal {
            let vocab = 8;
            let mut probs = Vec::with_capacity(n);
            for _ in 0..n {
                let mut row = vec![(1.0 - self.peaked_prob) / (vocab - 1) as f32; vocab];
                row[self.next_token as usize] = self.peaked_prob;
                probs.push(row);
            }
            DraftProposal {
                tokens: vec![self.next_token; n],
                probs,
            }
        }
        fn verify(&mut self, _ctx: &[u32], proposed: &[u32]) -> VerifyResult {
            // Canned target: identical distribution to its own
            // "next_token" choice.
            let n = proposed.len();
            let vocab = 8;
            let mut probs = Vec::with_capacity(n);
            for _ in 0..n {
                let mut row = vec![(1.0 - self.peaked_prob) / (vocab - 1) as f32; vocab];
                row[self.next_token as usize] = self.peaked_prob;
                probs.push(row);
            }
            VerifyResult { probs }
        }
    }

    #[test]
    fn spec_decoder_step_emits_n_tokens_when_aligned() {
        let draft = CannedSpeculator {
            next_token: 5,
            peaked_prob: 0.95,
        };
        let target = CannedSpeculator {
            next_token: 5,
            peaked_prob: 0.95,
        };
        let mut dec = SpecDecoder::new(draft, target, 4, 1);
        let context = vec![0u32, 1, 2];
        let out = dec.step(&context);
        // Aligned distributions → all 4 accepted. No bonus token is
        // sampled (verify only returns n rows), so exactly n come back.
        assert_eq!(
            out.len(),
            4,
            "aligned step should emit n tokens (no rejection)"
        );
        assert!(out.iter().all(|&t| t == 5));
    }
}