gllm_kernels/ops/speculative_decoding.rs

1//! Speculative decoding utilities for draft/target verification.
2
3use std::marker::PhantomData;
4
5use burn::tensor::backend::Backend;
6use burn::tensor::{Tensor, TensorData};
7
/// Speculative decoder for draft/target model verification.
///
/// Drives two phases: `speculate` builds draft token trees from the last
/// hidden state of each batch item, and `verify` accepts or rejects those
/// drafts against target-model logits.
pub struct SpeculativeDecoder<B: Backend> {
    /// Prediction head configuration used to turn hidden states into draft logits.
    prediction_config: PredictionConfig,
    /// Tree structure configuration (branch factor, depth, verification strategy).
    tree_config: TreeConfig,
    /// Maximum speculation length (hard cap applied on top of `tree_config.depth`).
    max_speculation_length: usize,
    /// Binds the decoder to a backend type without storing any backend state.
    _marker: PhantomData<B>,
}
18
/// Configuration for the prediction head.
#[derive(Debug, Clone, Copy)]
pub struct PredictionConfig {
    /// Hidden dimension of incoming states (treated as the vocab dimension,
    /// since the last hidden vector is reused directly as logits).
    pub hidden_dim: usize,
    /// Prediction head type.
    pub head_type: PredictionHeadType,
}
27
/// Prediction head choices.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PredictionHeadType {
    /// EAGLE-style lightweight MLP head; each layer applies a gated
    /// activation pass over the logits.
    Eagle {
        /// Number of gated activation passes (must be > 0).
        num_layers: usize,
    },
    /// Early-exit head: speculation collapses to a single step once the
    /// top draft probability reaches the threshold.
    EarlyExit {
        /// Probability threshold in (0, 1].
        exit_threshold: f32,
    },
}
36
/// Configuration for speculative tree expansion.
#[derive(Debug, Clone, Copy)]
pub struct TreeConfig {
    /// Branch factor per node: top-k draft candidates per parent (must be > 0).
    pub branch_factor: usize,
    /// Tree depth (must be > 0; may be capped by the decoder's maximum
    /// speculation length).
    pub depth: usize,
    /// Verification strategy for acceptance.
    pub verification: VerificationStrategy,
}
47
/// Verification strategy for speculative decoding.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum VerificationStrategy {
    /// Greedy verification (deterministic).
    Greedy,
    /// Sampling-based verification; target logits are divided by the
    /// temperature before comparison.
    Sampling {
        /// Softmax temperature (must be > 0).
        temperature: f32,
    },
}
56
/// Token entry in the speculative tree.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpeculativeToken {
    /// Token id (vocab index).
    pub id: usize,
    /// Draft log-probability.
    pub log_prob: f32,
}
65
/// Speculative token tree arranged by depth levels.
#[derive(Debug, Clone)]
pub struct SpeculativeTree {
    /// Tokens per depth level (BFS order, repeated per parent): level `d`
    /// holds one run of the top-k root tokens for every parent node at
    /// level `d - 1`.
    pub levels: Vec<Vec<SpeculativeToken>>,
}
72
/// Batch of speculative trees.
#[derive(Debug, Clone)]
pub struct SpeculativeCandidates {
    /// Trees for each batch item (one entry per batch index).
    pub trees: Vec<SpeculativeTree>,
    /// Branch factor used to build trees.
    pub branch_factor: usize,
    /// Maximum depth across the batch (individual trees may be shallower,
    /// e.g. when an early-exit head fires).
    pub max_depth: usize,
    /// Vocabulary size inferred from logits (equals the hidden dimension).
    pub vocab_size: usize,
}
85
/// Verification output with accepted tokens and updated cache.
#[derive(Debug)]
pub struct SpeculativeVerification<B: Backend> {
    /// Accepted tokens per batch item (an item's list may be empty).
    pub accepted_tokens: Vec<Vec<usize>>,
    /// Updated cache tokens, shape [batch, cache_len + max_accepted];
    /// rows with fewer accepted tokens are right-padded with -1.0.
    pub updated_cache: Tensor<B, 2>,
}
94
95impl<B: Backend> SpeculativeDecoder<B> {
96    /// Create a new speculative decoder.
97    pub fn new(
98        prediction_config: PredictionConfig,
99        tree_config: TreeConfig,
100        max_speculation_length: usize,
101    ) -> Self {
102        Self {
103            prediction_config,
104            tree_config,
105            max_speculation_length,
106            _marker: PhantomData,
107        }
108    }
109
110    /// Access prediction head configuration.
111    pub fn prediction_config(&self) -> &PredictionConfig {
112        &self.prediction_config
113    }
114
115    /// Access tree configuration.
116    pub fn tree_config(&self) -> &TreeConfig {
117        &self.tree_config
118    }
119
120    /// Maximum speculation length.
121    pub fn max_speculation_length(&self) -> usize {
122        self.max_speculation_length
123    }
124
125    /// Generate speculative token trees from hidden states.
126    ///
127    /// # Shapes
128    /// * `hidden`: [batch, seq_len, hidden_dim]
129    pub fn speculate(&self, hidden: Tensor<B, 3>) -> Result<SpeculativeCandidates, &'static str> {
130        self.prediction_config.validate()?;
131        self.tree_config.validate()?;
132        if self.max_speculation_length == 0 {
133            return Err("max speculation length must be > 0");
134        }
135
136        let [batch, seq_len, hidden_dim] = hidden.dims();
137        if seq_len == 0 {
138            return Err("sequence length must be > 0");
139        }
140        if hidden_dim != self.prediction_config.hidden_dim {
141            return Err("hidden dimension mismatch");
142        }
143
144        let hidden_data = hidden
145            .into_data()
146            .into_vec::<f32>()
147            .map_err(|_| "hidden data conversion failed")?;
148
149        let depth_cap = self
150            .tree_config
151            .depth
152            .min(self.max_speculation_length);
153        if depth_cap == 0 {
154            return Err("speculation depth must be > 0");
155        }
156
157        let mut trees = Vec::with_capacity(batch);
158        let mut max_depth = 0;
159
160        for batch_idx in 0..batch {
161            let base = batch_idx * seq_len * hidden_dim;
162            let offset = base + (seq_len - 1) * hidden_dim;
163            let mut logits = hidden_data[offset..offset + hidden_dim].to_vec();
164            apply_prediction_head(self.prediction_config.head_type, &mut logits)?;
165            let log_probs = log_softmax(&logits);
166
167            let mut effective_depth = depth_cap;
168            if let PredictionHeadType::EarlyExit { exit_threshold } =
169                self.prediction_config.head_type
170            {
171                let max_log_prob = log_probs
172                    .iter()
173                    .cloned()
174                    .fold(f32::NEG_INFINITY, f32::max);
175                let max_prob = max_log_prob.exp();
176                if max_prob >= exit_threshold {
177                    effective_depth = 1;
178                }
179            }
180
181            let mut levels = Vec::with_capacity(effective_depth);
182            let top_k = top_k_indices(&log_probs, self.tree_config.branch_factor);
183            let mut parents = 1usize;
184            for _depth in 0..effective_depth {
185                let mut level = Vec::with_capacity(parents * top_k.len());
186                for _ in 0..parents {
187                    for &token in &top_k {
188                        level.push(SpeculativeToken {
189                            id: token,
190                            log_prob: log_probs[token],
191                        });
192                    }
193                }
194                parents = parents.saturating_mul(top_k.len().max(1));
195                levels.push(level);
196            }
197
198            max_depth = max_depth.max(effective_depth);
199            trees.push(SpeculativeTree { levels });
200        }
201
202        Ok(SpeculativeCandidates {
203            trees,
204            branch_factor: self.tree_config.branch_factor,
205            max_depth,
206            vocab_size: hidden_dim,
207        })
208    }
209
210    /// Verify candidates against target logits with rejection-style acceptance.
211    ///
212    /// # Shapes
213    /// * `target_logits`: [batch, depth, vocab]
214    /// * `cache_tokens`: [batch, cache_len]
215    pub fn verify(
216        &self,
217        candidates: &SpeculativeCandidates,
218        target_logits: Tensor<B, 3>,
219        cache_tokens: Tensor<B, 2>,
220    ) -> Result<SpeculativeVerification<B>, &'static str> {
221        let [batch, depth, vocab] = target_logits.dims();
222        if batch != candidates.trees.len() {
223            return Err("target batch mismatch");
224        }
225        if vocab != candidates.vocab_size {
226            return Err("target vocab mismatch");
227        }
228        if depth < candidates.max_depth {
229            return Err("target logits depth too small");
230        }
231
232        let [cache_batch, cache_len] = cache_tokens.dims();
233        if cache_batch != batch {
234            return Err("cache batch mismatch");
235        }
236
237        let target_data = target_logits
238            .into_data()
239            .into_vec::<f32>()
240            .map_err(|_| "target logits conversion failed")?;
241        let cache_device = cache_tokens.device();
242        let cache_data = cache_tokens
243            .into_data()
244            .into_vec::<f32>()
245            .map_err(|_| "cache conversion failed")?;
246
247        let mut accepted_tokens = Vec::with_capacity(batch);
248        for (batch_idx, tree) in candidates.trees.iter().enumerate() {
249            let mut accepted = Vec::new();
250            for (depth_idx, level) in tree.levels.iter().enumerate() {
251                let offset = (batch_idx * depth + depth_idx) * vocab;
252                let mut logits = target_data[offset..offset + vocab].to_vec();
253                if let VerificationStrategy::Sampling { temperature } = self.tree_config.verification
254                {
255                    if temperature <= 0.0 {
256                        return Err("temperature must be > 0");
257                    }
258                    for value in logits.iter_mut() {
259                        *value /= temperature;
260                    }
261                }
262                let log_probs = log_softmax(&logits);
263
264                let mut best_token = None;
265                let mut best_prob = f32::NEG_INFINITY;
266                let mut best_draft_prob = 0.0f32;
267                for token in level {
268                    let target_prob = log_probs[token.id].exp();
269                    if target_prob > best_prob {
270                        best_prob = target_prob;
271                        best_token = Some(token.id);
272                        best_draft_prob = token.log_prob.exp();
273                    }
274                }
275
276                let token_id = match best_token {
277                    Some(id) => id,
278                    None => break,
279                };
280
281                if best_prob >= best_draft_prob {
282                    accepted.push(token_id);
283                } else {
284                    break;
285                }
286            }
287            accepted_tokens.push(accepted);
288        }
289
290        let max_accept = accepted_tokens
291            .iter()
292            .map(|tokens| tokens.len())
293            .max()
294            .unwrap_or(0);
295        let new_len = cache_len + max_accept;
296        let mut updated = vec![-1.0f32; batch * new_len];
297        for batch_idx in 0..batch {
298            let src_offset = batch_idx * cache_len;
299            let dst_offset = batch_idx * new_len;
300            updated[dst_offset..dst_offset + cache_len]
301                .copy_from_slice(&cache_data[src_offset..src_offset + cache_len]);
302            for (idx, token) in accepted_tokens[batch_idx].iter().enumerate() {
303                updated[dst_offset + cache_len + idx] = *token as f32;
304            }
305        }
306
307        let updated_cache =
308            Tensor::from_data(TensorData::new(updated, [batch, new_len]), &cache_device);
309
310        Ok(SpeculativeVerification {
311            accepted_tokens,
312            updated_cache,
313        })
314    }
315}
316
317impl PredictionConfig {
318    /// Validate prediction head configuration.
319    pub fn validate(&self) -> Result<(), &'static str> {
320        if self.hidden_dim == 0 {
321            return Err("hidden_dim must be > 0");
322        }
323        self.head_type.validate()
324    }
325}
326
327impl PredictionHeadType {
328    fn validate(&self) -> Result<(), &'static str> {
329        match *self {
330            PredictionHeadType::Eagle { num_layers } => {
331                if num_layers == 0 {
332                    return Err("num_layers must be > 0");
333                }
334            }
335            PredictionHeadType::EarlyExit { exit_threshold } => {
336                if exit_threshold <= 0.0 || exit_threshold > 1.0 {
337                    return Err("exit_threshold must be in (0, 1]");
338                }
339            }
340        }
341        Ok(())
342    }
343}
344
345impl TreeConfig {
346    /// Validate tree configuration.
347    pub fn validate(&self) -> Result<(), &'static str> {
348        if self.branch_factor == 0 {
349            return Err("branch_factor must be > 0");
350        }
351        if self.depth == 0 {
352            return Err("depth must be > 0");
353        }
354        if let VerificationStrategy::Sampling { temperature } = self.verification {
355            if temperature <= 0.0 {
356                return Err("temperature must be > 0");
357            }
358        }
359        Ok(())
360    }
361}
362
363fn apply_prediction_head(
364    head_type: PredictionHeadType,
365    logits: &mut [f32],
366) -> Result<(), &'static str> {
367    match head_type {
368        PredictionHeadType::Eagle { num_layers } => {
369            for _ in 0..num_layers {
370                for value in logits.iter_mut() {
371                    let gate = 1.0 / (1.0 + (-*value).exp());
372                    *value = value.tanh() * gate;
373                }
374            }
375        }
376        PredictionHeadType::EarlyExit { .. } => {}
377    }
378    Ok(())
379}
380
/// Numerically stable log-softmax over a logit slice.
///
/// Subtracts the maximum before exponentiating; an empty slice yields an
/// empty vector, and a non-finite maximum short-circuits to a constant
/// vector of that maximum.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if !max.is_finite() {
        return vec![max; logits.len()];
    }
    let sum: f32 = logits.iter().map(|&value| (value - max).exp()).sum();
    let log_sum = max + sum.ln();
    logits.iter().map(|&value| value - log_sum).collect()
}
399
/// Indices of the `k` largest scores, highest first.
///
/// NaN scores are demoted to negative infinity before ranking; the sort
/// is stable, so ties keep their original index order. `k` larger than
/// the slice simply returns every index.
fn top_k_indices(scores: &[f32], k: usize) -> Vec<usize> {
    if k == 0 || scores.is_empty() {
        return Vec::new();
    }
    let mut ranked: Vec<(usize, f32)> = scores
        .iter()
        .map(|&score| if score.is_nan() { f32::NEG_INFINITY } else { score })
        .enumerate()
        .collect();
    ranked.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked.into_iter().take(k).map(|(idx, _)| idx).collect()
}
420
// Unit tests run on the ndarray CPU backend, hence the "cpu" feature gate.
#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;

    #[test]
    fn test_speculate_tree_depth() {
        let config = PredictionConfig {
            hidden_dim: 4,
            head_type: PredictionHeadType::Eagle { num_layers: 2 },
        };
        let tree_config = TreeConfig {
            branch_factor: 2,
            depth: 3,
            verification: VerificationStrategy::Greedy,
        };
        // max_speculation_length = 2 caps the configured depth of 3,
        // so only two tree levels should be produced.
        let decoder = SpeculativeDecoder::<NdArray<f32>>::new(config, tree_config, 2);
        let device = <NdArray<f32> as Backend>::Device::default();
        // [1, 3, 4] hidden states; only the last position (0.2, 0.1, 0.0, -0.1)
        // feeds the draft logits.
        let data = vec![
            0.1, 0.2, 0.3, 0.4, 0.5, 0.4, 0.3, 0.2, 0.2, 0.1, 0.0, -0.1,
        ];
        let hidden = Tensor::from_data(TensorData::new(data, [1, 3, 4]), &device);

        let candidates = decoder.speculate(hidden).expect("speculate");
        assert_eq!(candidates.trees.len(), 1);
        assert_eq!(candidates.max_depth, 2);
        assert_eq!(candidates.trees[0].levels.len(), 2);
        // Level sizes follow branch_factor^(level + 1): 2, then 4.
        assert_eq!(candidates.trees[0].levels[0].len(), 2);
        assert_eq!(candidates.trees[0].levels[1].len(), 4);
    }

    #[test]
    fn test_verify_rejects_on_low_target_prob() {
        let config = PredictionConfig {
            hidden_dim: 3,
            head_type: PredictionHeadType::EarlyExit { exit_threshold: 0.5 },
        };
        let tree_config = TreeConfig {
            branch_factor: 2,
            depth: 2,
            verification: VerificationStrategy::Greedy,
        };
        let decoder = SpeculativeDecoder::<NdArray<f32>>::new(config, tree_config, 2);
        let device = <NdArray<f32> as Backend>::Device::default();
        // softmax([0.2, 0.1, 0.0]) peaks at ~0.37 < 0.5 threshold, so the
        // early-exit head does NOT fire and both levels are built; the
        // top-2 draft candidates are tokens 0 and 1.
        let hidden = Tensor::from_data(
            TensorData::new(vec![0.2, 0.1, 0.0], [1, 1, 3]),
            &device,
        );
        let candidates = decoder.speculate(hidden).expect("speculate");

        // Depth 0 strongly favors token 1 (logit 2.0) -> accepted.
        // Depth 1 puts nearly all mass on token 2, which is not among the
        // draft candidates {0, 1}, so their target probs fall below the
        // draft probs and the level is rejected.
        let target_logits = Tensor::from_data(
            TensorData::new(vec![0.0, 2.0, 0.0, -2.0, -2.0, 5.0], [1, 2, 3]),
            &device,
        );
        let cache_tokens =
            Tensor::from_data(TensorData::new(vec![1.0, 2.0], [1, 2]), &device);

        let result = decoder
            .verify(&candidates, target_logits, cache_tokens)
            .expect("verify");
        assert_eq!(result.accepted_tokens.len(), 1);
        // Exactly one token accepted (depth 0 only).
        assert_eq!(result.accepted_tokens[0].len(), 1);
        let updated = result
            .updated_cache
            .into_data()
            .into_vec::<f32>()
            .expect("cache data");
        // Cache grows from 2 tokens to 3 (one accepted token appended).
        assert_eq!(updated.len(), 3);
    }
}