rlx_runtime/spec_decode.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Speculative decoding scheduling pattern (plan #34).
17//!
//! Borrowed from MAX's serving scheduler structure
//! (`one_shot_scheduler.py`, decode/prefill split). This implements the
//! classic speculative-decoding algorithm of Leviathan et al., "Fast
//! Inference from Transformers via Speculative Decoding": a small draft
//! model proposes `n` tokens; the larger target model verifies all `n` in
//! one forward pass; tokens are accepted up to the first rejection, and at
//! the rejected position one "corrected" token is sampled from the
//! residual distribution. When all `n` drafts are accepted, this module
//! emits no extra bonus token.
26//!
27//! Expected speedup on decode-heavy workloads: 2-3×.
28//!
29//! Layout:
//! - [`Speculator`] — trait an autoregressive model implements.
//!   Two required methods, `propose` (draft) and `verify` (target),
//!   plus an optional `commit` hook (default no-op).
32//! - [`DraftProposal`] / [`VerifyResult`] / [`AcceptDecision`]
33//! — wire-format data shapes.
34//! - [`speculative_accept`] — pure function that runs the
35//! acceptance algorithm. Testable without a real model.
36//! - [`SpecDecoder`] — orchestrator that calls a draft + target
37//! and returns the next batch of accepted tokens.
38
39use rlx_ir::Philox4x32;
40
/// One round of draft proposals.
#[derive(Debug, Clone, PartialEq)]
pub struct DraftProposal {
    /// `n` proposed tokens (the draft may sample greedily or
    /// stochastically).
    pub tokens: Vec<u32>,
    /// `n` probability rows, one per proposed token, each `vocab` wide.
    /// `probs[i]` is the draft's distribution at position `i`, so
    /// `probs[i][tokens[i] as usize]` is the probability the draft
    /// assigned to its own choice.
    pub probs: Vec<Vec<f32>>,
}
51
/// Target model's verification of the draft's proposals.
#[derive(Debug, Clone, PartialEq)]
pub struct VerifyResult {
    /// `n` probability rows, one per proposed token. `probs[i]` is the
    /// target's distribution at position `i`, conditioned on the prefix
    /// and all preceding draft tokens (`proposed[..i]`).
    pub probs: Vec<Vec<f32>>,
}
60
/// Outcome of one speculative-decoding round.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AcceptDecision {
    /// Draft tokens accepted, in proposal order. Length is `0..=n`.
    pub accepted: Vec<u32>,
    /// Token sampled after the first rejected position (normally from the
    /// residual distribution `norm(max(0, q - p))`), or `None` when all
    /// `n` drafts were accepted.
    ///
    /// No bonus token is sampled after a fully accepted round. A round
    /// therefore yields `accepted.len() + 1` tokens after a rejection and
    /// exactly `accepted.len()` tokens otherwise (see
    /// [`AcceptDecision::total_tokens`]).
    pub corrected: Option<u32>,
}

impl AcceptDecision {
    /// Total real tokens this round produced: the accepted drafts plus
    /// the corrected token, if there is one.
    pub fn total_tokens(&self) -> usize {
        self.accepted.len() + usize::from(self.corrected.is_some())
    }
}
79
80/// Streaming speculator interface — one method to draft, one to
81/// verify. Real implementations bind to a `CompiledGraph` per
82/// model; testable implementations can return canned probability
83/// tables.
84pub trait Speculator {
85 /// Propose `n` tokens given the current `context`. Returns the
86 /// proposed tokens + the draft's probability tables.
87 fn propose(&mut self, context: &[u32], n: usize) -> DraftProposal;
88
89 /// Verify a batch of `proposed` tokens in one forward pass:
90 /// for each position `i ∈ 0..n`, return the *target* model's
91 /// probability distribution conditioned on
92 /// `context ++ proposed[..i]`.
93 fn verify(&mut self, context: &[u32], proposed: &[u32]) -> VerifyResult;
94
95 /// Commit `accepted` tokens into persistent decode state after a
96 /// speculative round. Default no-op; MTP draft overrides so GDN
97 /// recurrent state only advances for accepted tokens.
98 fn commit(&mut self, context: &[u32], accepted: &[u32]) {
99 let _ = (context, accepted);
100 }
101}
102
103/// Pure speculative-acceptance algorithm. Given the draft's
104/// proposal and the target's verification, runs the
105/// per-position accept/reject test and returns the final
106/// decision. No model state, no I/O — easy to unit-test against
107/// hand-built distributions.
108///
109/// Algorithm (Leviathan et al. 2022, Algorithm 1):
110/// for i in 0..n:
111/// r ~ Uniform(0,1)
112/// if r < min(1, q_target(x_i) / p_draft(x_i)):
113/// accept x_i
114/// else:
115/// sample x' from norm(max(0, q - p))
116/// return (accepted[..i], Some(x'))
117/// return (all n accepted, None)
118pub fn speculative_accept(
119 proposal: &DraftProposal,
120 verify: &VerifyResult,
121 rng: &mut Philox4x32,
122) -> AcceptDecision {
123 assert_eq!(
124 proposal.tokens.len(),
125 proposal.probs.len(),
126 "DraftProposal: tokens and probs must agree"
127 );
128 assert_eq!(
129 proposal.probs.len(),
130 verify.probs.len(),
131 "DraftProposal and VerifyResult must propose the same n"
132 );
133 let n = proposal.tokens.len();
134 let mut accepted: Vec<u32> = Vec::with_capacity(n);
135 for i in 0..n {
136 let token = proposal.tokens[i];
137 let p = proposal.probs[i][token as usize].max(f32::MIN_POSITIVE);
138 let q = verify.probs[i][token as usize];
139 let accept_ratio = (q / p).min(1.0);
140 let r = rng.next_f32();
141 if r < accept_ratio {
142 accepted.push(token);
143 } else {
144 let corrected = sample_corrected_residual(&proposal.probs[i], &verify.probs[i], rng);
145 return AcceptDecision {
146 accepted,
147 corrected: Some(corrected),
148 };
149 }
150 }
151 AcceptDecision {
152 accepted,
153 corrected: None,
154 }
155}
156
157/// Sample from the *residual* distribution `norm(max(0, q - p))`.
158/// This is the "what the target prefers but the draft missed"
159/// distribution, used after a rejection so the round still emits
160/// a valid sample from the target.
161fn sample_corrected_residual(p: &[f32], q: &[f32], rng: &mut Philox4x32) -> u32 {
162 let mut adj: Vec<f32> = q.iter().zip(p).map(|(qi, pi)| (qi - pi).max(0.0)).collect();
163 let sum: f32 = adj.iter().sum();
164 if sum <= f32::MIN_POSITIVE {
165 // q ≤ p elementwise (extreme edge case): fall back to
166 // sampling from q directly.
167 return sample_from(q, rng);
168 }
169 let inv = 1.0 / sum;
170 for v in adj.iter_mut() {
171 *v *= inv;
172 }
173 sample_from(&adj, rng)
174}
175
176fn sample_from(probs: &[f32], rng: &mut Philox4x32) -> u32 {
177 let r = rng.next_f32();
178 let mut acc = 0f32;
179 for (i, &p) in probs.iter().enumerate() {
180 acc += p;
181 if r <= acc {
182 return i as u32;
183 }
184 }
185 (probs.len() - 1) as u32
186}
187
188/// Top-level orchestrator. Holds a draft + target speculator and
189/// the lookahead window `n`. `step()` runs one full round and
190/// returns the tokens to append to the running context.
191pub struct SpecDecoder<D: Speculator, T: Speculator> {
192 pub draft: D,
193 pub target: T,
194 pub n: usize,
195 rng: Philox4x32,
196}
197
198impl<D: Speculator, T: Speculator> SpecDecoder<D, T> {
199 pub fn new(draft: D, target: T, n: usize, seed: u64) -> Self {
200 Self {
201 draft,
202 target,
203 n,
204 rng: Philox4x32::new(seed),
205 }
206 }
207
208 /// One speculative-decoding round. Returns the tokens that
209 /// should be appended to `context`.
210 pub fn step(&mut self, context: &[u32]) -> Vec<u32> {
211 let proposal = self.draft.propose(context, self.n);
212 let verify = self.target.verify(context, &proposal.tokens);
213 let decision = speculative_accept(&proposal, &verify, &mut self.rng);
214 let mut out = decision.accepted;
215 if let Some(c) = decision.corrected {
216 out.push(c);
217 }
218 self.draft.commit(context, &out);
219 self.target.commit(context, &out);
220 out
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// When draft and target agree perfectly (same probs), every
    /// proposed token must be accepted (accept_ratio = 1.0).
    #[test]
    fn identical_distributions_accept_all() {
        let n = 4;
        let vocab = 8;
        // Draft proposed token = argmax of a peaked distribution.
        // Target's distribution is identical → q/p = 1.0 → always
        // accept.
        let mut probs = Vec::with_capacity(n);
        let mut tokens = Vec::with_capacity(n);
        for i in 0..n {
            let mut row = vec![0.01f32; vocab];
            let pick = (i * 2) % vocab;
            row[pick] = 1.0 - 0.01 * (vocab - 1) as f32;
            probs.push(row);
            tokens.push(pick as u32);
        }
        let proposal = DraftProposal {
            tokens: tokens.clone(),
            probs: probs.clone(),
        };
        let verify = VerifyResult { probs };

        // 100 trials with different seeds; all should accept all 4.
        for seed in 0..100u64 {
            let mut rng = Philox4x32::new(seed + 1);
            let d = speculative_accept(&proposal, &verify, &mut rng);
            assert_eq!(d.accepted, tokens, "seed {seed}: should accept all");
            assert!(d.corrected.is_none());
        }
    }

    /// When the draft places mass on tokens the target rejects
    /// (q ≪ p on those tokens), rejections dominate, and every rejection
    /// is followed by a corrected token drawn from the residual.
    #[test]
    fn divergent_distributions_reject_sometimes() {
        let n = 4;
        // Draft ALWAYS picks token 0; target wants token 3.
        let draft_row = vec![0.97f32, 0.01, 0.01, 0.01];
        let target_row = vec![0.01f32, 0.01, 0.01, 0.97];
        let proposal = DraftProposal {
            tokens: vec![0u32; n],
            probs: vec![draft_row; n],
        };
        let verify = VerifyResult {
            probs: vec![target_row; n],
        };

        let mut total_accepted = 0usize;
        let trials = 200;
        for seed in 0..trials {
            let mut rng = Philox4x32::new(seed + 1);
            let d = speculative_accept(&proposal, &verify, &mut rng);
            total_accepted += d.accepted.len();
            if d.accepted.len() < n {
                // Residual norm(max(0, q - p)) = [0, 0, 0, 1]: token 3 is
                // the only token with positive residual mass.
                assert_eq!(
                    d.corrected,
                    Some(3),
                    "rejection at seed {seed} should yield corrected token 3"
                );
            } else {
                assert!(
                    d.corrected.is_none(),
                    "seed {seed}: all accepted, so no corrected token"
                );
            }
        }
        // q/p = 0.01/0.97 ≈ 0.0103 per token → expected acceptance
        // length per round is geometric, mean ≈ 0.01. Across 200
        // trials × 4 positions = 800 chances, accept rate ~1%.
        assert!(
            total_accepted < 80,
            "divergent distributions should accept rarely; got {total_accepted}/800"
        );
    }

    /// A draft token to which the target assigns zero probability must
    /// always be rejected, and the correction must come from the
    /// residual (here: all residual mass sits on token 0).
    #[test]
    fn zero_target_mass_always_rejects() {
        let proposal = DraftProposal {
            tokens: vec![1],
            probs: vec![vec![0.5f32, 0.5]],
        };
        let verify = VerifyResult {
            probs: vec![vec![1.0f32, 0.0]],
        };
        for seed in 0..50u64 {
            let mut rng = Philox4x32::new(seed + 1);
            let d = speculative_accept(&proposal, &verify, &mut rng);
            assert!(d.accepted.is_empty(), "seed {seed}: must reject");
            assert_eq!(d.corrected, Some(0), "seed {seed}: residual is [1, 0]");
        }
    }

    /// `total_tokens` counts the corrected token only when present.
    #[test]
    fn total_tokens_counts_corrected_token() {
        let all = AcceptDecision {
            accepted: vec![5, 5, 5, 5],
            corrected: None,
        };
        assert_eq!(all.total_tokens(), 4);
        let rejected = AcceptDecision {
            accepted: vec![5],
            corrected: Some(3),
        };
        assert_eq!(rejected.total_tokens(), 2);
    }

    /// Mock speculators for end-to-end SpecDecoder basic test.
    /// Both return canned probability tables.
    struct CannedSpeculator {
        next_token: u32,
        peaked_prob: f32,
    }

    impl Speculator for CannedSpeculator {
        fn propose(&mut self, _ctx: &[u32], n: usize) -> DraftProposal {
            let vocab = 8;
            let mut probs = Vec::with_capacity(n);
            for _ in 0..n {
                let mut row = vec![(1.0 - self.peaked_prob) / (vocab - 1) as f32; vocab];
                row[self.next_token as usize] = self.peaked_prob;
                probs.push(row);
            }
            DraftProposal {
                tokens: vec![self.next_token; n],
                probs,
            }
        }
        fn verify(&mut self, _ctx: &[u32], proposed: &[u32]) -> VerifyResult {
            // Canned target: identical distribution to its own
            // "next_token" choice.
            let n = proposed.len();
            let vocab = 8;
            let mut probs = Vec::with_capacity(n);
            for _ in 0..n {
                let mut row = vec![(1.0 - self.peaked_prob) / (vocab - 1) as f32; vocab];
                row[self.next_token as usize] = self.peaked_prob;
                probs.push(row);
            }
            VerifyResult { probs }
        }
    }

    /// With aligned draft and target, every proposal is accepted. No bonus
    /// token is sampled after a fully accepted round, so `step` emits
    /// exactly `n` tokens.
    #[test]
    fn spec_decoder_step_emits_n_tokens_when_aligned() {
        let draft = CannedSpeculator {
            next_token: 5,
            peaked_prob: 0.95,
        };
        let target = CannedSpeculator {
            next_token: 5,
            peaked_prob: 0.95,
        };
        let mut dec = SpecDecoder::new(draft, target, 4, 1);
        let context = vec![0u32, 1, 2];
        let out = dec.step(&context);
        // Aligned distributions → all 4 accepted, no corrected; total = 4.
        assert_eq!(
            out.len(),
            4,
            "aligned step should emit n tokens (no rejection)"
        );
        assert!(out.iter().all(|&t| t == 5));
    }
}