1use std::marker::PhantomData;
4
5use burn::tensor::backend::Backend;
6use burn::tensor::{Tensor, TensorData};
7
/// Drives speculative decoding on backend `B`: drafting candidate token
/// trees (`speculate`) and verifying them against target-model logits
/// (`verify`).
pub struct SpeculativeDecoder<B: Backend> {
    /// Configuration of the draft prediction head.
    prediction_config: PredictionConfig,
    /// Shape of the speculation tree and the verification strategy.
    tree_config: TreeConfig,
    /// Hard cap on speculation depth, applied on top of `tree_config.depth`.
    max_speculation_length: usize,
    /// The decoder stores no tensors itself; `B` appears only in method
    /// signatures.
    _marker: PhantomData<B>,
}
18
/// Configuration of the draft prediction head.
#[derive(Debug, Clone, Copy)]
pub struct PredictionConfig {
    /// Expected hidden dimension of the incoming hidden states. `speculate`
    /// treats the last hidden vector as a logit vector, so this also acts
    /// as the vocabulary size of the drafted candidates.
    pub hidden_dim: usize,
    /// Which draft head transform to apply before drafting.
    pub head_type: PredictionHeadType,
}
27
/// Variants of the draft prediction head.
#[derive(Debug, Clone, Copy)]
pub enum PredictionHeadType {
    /// EAGLE-style head: applies `num_layers` rounds of a tanh-times-sigmoid
    /// gating transform to the logits (see `apply_prediction_head`).
    Eagle { num_layers: usize },
    /// Early-exit head: when the draft's top probability reaches
    /// `exit_threshold`, speculation depth collapses to 1 for that row.
    EarlyExit { exit_threshold: f32 },
}
36
/// Shape of the speculation tree and how drafted tokens are verified.
#[derive(Debug, Clone, Copy)]
pub struct TreeConfig {
    /// Children drafted per node (top-k width at every level).
    pub branch_factor: usize,
    /// Nominal tree depth; capped by the decoder's `max_speculation_length`.
    pub depth: usize,
    /// Strategy `verify` uses when scoring target logits.
    pub verification: VerificationStrategy,
}
47
/// How target logits are scored during verification.
#[derive(Debug, Clone, Copy)]
pub enum VerificationStrategy {
    /// Use the target logits as-is.
    Greedy,
    /// Divide target logits by `temperature` (must be > 0) before softmax.
    Sampling { temperature: f32 },
}
56
/// A single drafted token together with the draft model's score for it.
#[derive(Debug, Clone, Copy)]
pub struct SpeculativeToken {
    /// Vocabulary index of the drafted token.
    pub id: usize,
    /// Draft log-probability assigned to this token.
    pub log_prob: f32,
}
65
/// Draft tree for one batch element, stored level by level:
/// `levels[d]` holds every node at depth `d`, flattened in parent order.
#[derive(Debug, Clone)]
pub struct SpeculativeTree {
    pub levels: Vec<Vec<SpeculativeToken>>,
}
72
/// Output of `speculate`: one draft tree per batch element plus the
/// parameters `verify` needs to sanity-check the target logits.
#[derive(Debug, Clone)]
pub struct SpeculativeCandidates {
    /// One tree per batch element, in batch order.
    pub trees: Vec<SpeculativeTree>,
    /// Top-k width the trees were built with.
    pub branch_factor: usize,
    /// Deepest tree across the batch (early exit can shorten single rows).
    pub max_depth: usize,
    /// Vocabulary size the trees were drafted over.
    pub vocab_size: usize,
}
85
/// Output of `verify`: the accepted token prefix per batch element and the
/// token cache extended with those tokens.
#[derive(Debug)]
pub struct SpeculativeVerification<B: Backend> {
    /// Accepted token ids per batch element (a row may be empty).
    pub accepted_tokens: Vec<Vec<usize>>,
    /// `[batch, old_len + max_accept]` cache; rows that accepted fewer
    /// tokens than the longest row are padded with -1.
    pub updated_cache: Tensor<B, 2>,
}
94
impl<B: Backend> SpeculativeDecoder<B> {
    /// Builds a decoder from its three configuration pieces.
    ///
    /// No validation happens here; `speculate` validates both configs on
    /// every call before doing any work.
    pub fn new(
        prediction_config: PredictionConfig,
        tree_config: TreeConfig,
        max_speculation_length: usize,
    ) -> Self {
        Self {
            prediction_config,
            tree_config,
            max_speculation_length,
            _marker: PhantomData,
        }
    }

    /// Read-only access to the draft-head configuration.
    pub fn prediction_config(&self) -> &PredictionConfig {
        &self.prediction_config
    }

    /// Read-only access to the speculation-tree configuration.
    pub fn tree_config(&self) -> &TreeConfig {
        &self.tree_config
    }

    /// Hard cap on how deep a single speculation step may draft.
    pub fn max_speculation_length(&self) -> usize {
        self.max_speculation_length
    }

    /// Drafts a candidate token tree for every batch element.
    ///
    /// `hidden` is expected as `[batch, seq_len, hidden_dim]`; only the
    /// hidden vector at the LAST sequence position is used, and it is
    /// interpreted directly as a logit vector over a vocabulary of size
    /// `hidden_dim`.
    ///
    /// # Errors
    /// Returns a static message when either config is invalid, the
    /// speculation length is zero, the sequence is empty, the hidden
    /// dimension mismatches the config, or the tensor data cannot be
    /// converted to `f32`.
    pub fn speculate(&self, hidden: Tensor<B, 3>) -> Result<SpeculativeCandidates, &'static str> {
        self.prediction_config.validate()?;
        self.tree_config.validate()?;
        if self.max_speculation_length == 0 {
            return Err("max speculation length must be > 0");
        }

        let [batch, seq_len, hidden_dim] = hidden.dims();
        if seq_len == 0 {
            return Err("sequence length must be > 0");
        }
        if hidden_dim != self.prediction_config.hidden_dim {
            return Err("hidden dimension mismatch");
        }

        // Flatten the tensor to a host-side f32 buffer in row-major
        // [batch, seq, hidden] order for manual indexing below.
        let hidden_data = hidden
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "hidden data conversion failed")?;

        // Effective depth is the configured tree depth clamped by the
        // decoder-level speculation cap.
        let depth_cap = self
            .tree_config
            .depth
            .min(self.max_speculation_length);
        if depth_cap == 0 {
            // Unreachable after the validations above (both operands are
            // > 0), kept as a defensive check.
            return Err("speculation depth must be > 0");
        }

        let mut trees = Vec::with_capacity(batch);
        let mut max_depth = 0;

        for batch_idx in 0..batch {
            // Offset of the LAST sequence position of this batch row.
            let base = batch_idx * seq_len * hidden_dim;
            let offset = base + (seq_len - 1) * hidden_dim;
            let mut logits = hidden_data[offset..offset + hidden_dim].to_vec();
            // Apply the draft head in place (no-op for EarlyExit).
            apply_prediction_head(self.prediction_config.head_type, &mut logits)?;
            let log_probs = log_softmax(&logits);

            let mut effective_depth = depth_cap;
            // Early exit: if the draft's most likely token already clears
            // the confidence threshold, draft only a single level.
            if let PredictionHeadType::EarlyExit { exit_threshold } =
                self.prediction_config.head_type
            {
                let max_log_prob = log_probs
                    .iter()
                    .cloned()
                    .fold(f32::NEG_INFINITY, f32::max);
                let max_prob = max_log_prob.exp();
                if max_prob >= exit_threshold {
                    effective_depth = 1;
                }
            }

            // Every level reuses the same top-k draft tokens; level `d`
            // holds parents * k entries, so sizes grow as k^(d+1).
            let mut levels = Vec::with_capacity(effective_depth);
            let top_k = top_k_indices(&log_probs, self.tree_config.branch_factor);
            let mut parents = 1usize;
            for _depth in 0..effective_depth {
                let mut level = Vec::with_capacity(parents * top_k.len());
                for _ in 0..parents {
                    for &token in &top_k {
                        level.push(SpeculativeToken {
                            id: token,
                            log_prob: log_probs[token],
                        });
                    }
                }
                // saturating_mul guards against overflow for deep trees;
                // max(1) keeps `parents` non-zero if top_k were empty.
                parents = parents.saturating_mul(top_k.len().max(1));
                levels.push(level);
            }

            max_depth = max_depth.max(effective_depth);
            trees.push(SpeculativeTree { levels });
        }

        Ok(SpeculativeCandidates {
            trees,
            branch_factor: self.tree_config.branch_factor,
            max_depth,
            // The draft logits came straight from the hidden vector, so
            // the vocabulary size equals the hidden dimension.
            vocab_size: hidden_dim,
        })
    }

    /// Verifies drafted candidates against target-model logits and appends
    /// the accepted tokens to the token cache.
    ///
    /// `target_logits` is `[batch, depth, vocab]` (depth must cover
    /// `candidates.max_depth`); `cache_tokens` is `[batch, cache_len]`.
    /// Per level, the candidate with the highest target probability is
    /// accepted if that probability is at least the draft's probability
    /// for the same token; the first rejection stops that row.
    ///
    /// # Errors
    /// Returns a static message on batch/vocab/depth mismatches, a
    /// non-positive sampling temperature, or tensor conversion failure.
    pub fn verify(
        &self,
        candidates: &SpeculativeCandidates,
        target_logits: Tensor<B, 3>,
        cache_tokens: Tensor<B, 2>,
    ) -> Result<SpeculativeVerification<B>, &'static str> {
        let [batch, depth, vocab] = target_logits.dims();
        if batch != candidates.trees.len() {
            return Err("target batch mismatch");
        }
        if vocab != candidates.vocab_size {
            return Err("target vocab mismatch");
        }
        if depth < candidates.max_depth {
            return Err("target logits depth too small");
        }

        let [cache_batch, cache_len] = cache_tokens.dims();
        if cache_batch != batch {
            return Err("cache batch mismatch");
        }

        let target_data = target_logits
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "target logits conversion failed")?;
        // Device captured before consuming the tensor so the updated cache
        // can be rebuilt on the same device.
        let cache_device = cache_tokens.device();
        let cache_data = cache_tokens
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "cache conversion failed")?;

        let mut accepted_tokens = Vec::with_capacity(batch);
        for (batch_idx, tree) in candidates.trees.iter().enumerate() {
            let mut accepted = Vec::new();
            for (depth_idx, level) in tree.levels.iter().enumerate() {
                // Row-major offset into [batch, depth, vocab].
                let offset = (batch_idx * depth + depth_idx) * vocab;
                let mut logits = target_data[offset..offset + vocab].to_vec();
                if let VerificationStrategy::Sampling { temperature } = self.tree_config.verification
                {
                    if temperature <= 0.0 {
                        return Err("temperature must be > 0");
                    }
                    // Temperature scaling before the softmax below.
                    for value in logits.iter_mut() {
                        *value /= temperature;
                    }
                }
                let log_probs = log_softmax(&logits);

                // Pick the level candidate the TARGET model likes best,
                // remembering the draft's probability for that token.
                let mut best_token = None;
                let mut best_prob = f32::NEG_INFINITY;
                let mut best_draft_prob = 0.0f32;
                for token in level {
                    let target_prob = log_probs[token.id].exp();
                    if target_prob > best_prob {
                        best_prob = target_prob;
                        best_token = Some(token.id);
                        best_draft_prob = token.log_prob.exp();
                    }
                }

                let token_id = match best_token {
                    Some(id) => id,
                    // Empty level: nothing to verify at this depth.
                    None => break,
                };

                // Accept only while the target is at least as confident as
                // the draft; the first rejection ends this row's prefix.
                if best_prob >= best_draft_prob {
                    accepted.push(token_id);
                } else {
                    break;
                }
            }
            accepted_tokens.push(accepted);
        }

        // The new cache width is sized to the longest accepted prefix;
        // shorter rows are left padded with the -1 sentinel below.
        let max_accept = accepted_tokens
            .iter()
            .map(|tokens| tokens.len())
            .max()
            .unwrap_or(0);
        let new_len = cache_len + max_accept;
        let mut updated = vec![-1.0f32; batch * new_len];
        for batch_idx in 0..batch {
            let src_offset = batch_idx * cache_len;
            let dst_offset = batch_idx * new_len;
            // Copy the existing cache row, then append accepted tokens.
            updated[dst_offset..dst_offset + cache_len]
                .copy_from_slice(&cache_data[src_offset..src_offset + cache_len]);
            for (idx, token) in accepted_tokens[batch_idx].iter().enumerate() {
                updated[dst_offset + cache_len + idx] = *token as f32;
            }
        }

        let updated_cache =
            Tensor::from_data(TensorData::new(updated, [batch, new_len]), &cache_device);

        Ok(SpeculativeVerification {
            accepted_tokens,
            updated_cache,
        })
    }
}
316
317impl PredictionConfig {
318 pub fn validate(&self) -> Result<(), &'static str> {
320 if self.hidden_dim == 0 {
321 return Err("hidden_dim must be > 0");
322 }
323 self.head_type.validate()
324 }
325}
326
327impl PredictionHeadType {
328 fn validate(&self) -> Result<(), &'static str> {
329 match *self {
330 PredictionHeadType::Eagle { num_layers } => {
331 if num_layers == 0 {
332 return Err("num_layers must be > 0");
333 }
334 }
335 PredictionHeadType::EarlyExit { exit_threshold } => {
336 if exit_threshold <= 0.0 || exit_threshold > 1.0 {
337 return Err("exit_threshold must be in (0, 1]");
338 }
339 }
340 }
341 Ok(())
342 }
343}
344
345impl TreeConfig {
346 pub fn validate(&self) -> Result<(), &'static str> {
348 if self.branch_factor == 0 {
349 return Err("branch_factor must be > 0");
350 }
351 if self.depth == 0 {
352 return Err("depth must be > 0");
353 }
354 if let VerificationStrategy::Sampling { temperature } = self.verification {
355 if temperature <= 0.0 {
356 return Err("temperature must be > 0");
357 }
358 }
359 Ok(())
360 }
361}
362
363fn apply_prediction_head(
364 head_type: PredictionHeadType,
365 logits: &mut [f32],
366) -> Result<(), &'static str> {
367 match head_type {
368 PredictionHeadType::Eagle { num_layers } => {
369 for _ in 0..num_layers {
370 for value in logits.iter_mut() {
371 let gate = 1.0 / (1.0 + (-*value).exp());
372 *value = value.tanh() * gate;
373 }
374 }
375 }
376 PredictionHeadType::EarlyExit { .. } => {}
377 }
378 Ok(())
379}
380
/// Numerically stable log-softmax over `logits`.
///
/// Subtracts the maximum before exponentiating so large logits cannot
/// overflow. An empty slice maps to an empty vector; when the maximum is
/// not finite (all -inf, or any +inf entry), every output equals that
/// maximum.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if !peak.is_finite() {
        return vec![peak; logits.len()];
    }
    // Log-sum-exp, accumulated in the same left-to-right order as a
    // manual loop so float results match exactly.
    let total: f32 = logits.iter().map(|&x| (x - peak).exp()).sum();
    let normalizer = peak + total.ln();
    logits.iter().map(|&x| x - normalizer).collect()
}
399
/// Indices of the `k` largest scores, highest first.
///
/// NaN scores are treated as negative infinity so they sort last; ties
/// keep their original index order. Returns fewer than `k` indices when
/// the input is shorter than `k`, and an empty vector when `k == 0` or
/// the input is empty.
fn top_k_indices(scores: &[f32], k: usize) -> Vec<usize> {
    if k == 0 || scores.is_empty() {
        return Vec::new();
    }
    let mut ranked: Vec<(usize, f32)> = scores
        .iter()
        .enumerate()
        .map(|(idx, &raw)| (idx, if raw.is_nan() { f32::NEG_INFINITY } else { raw }))
        .collect();
    // Descending by score with ascending index as tiebreak — a total
    // order, so an unstable sort reproduces exactly what a stable
    // descending sort would.
    ranked.sort_unstable_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a.0.cmp(&b.0))
    });
    ranked.truncate(k.min(ranked.len()));
    ranked.into_iter().map(|(idx, _)| idx).collect()
}
420
// Unit tests run on the ndarray CPU backend; gated behind the "cpu"
// feature because they need the burn_ndarray dev-dependency.
#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;

    // The configured depth (3) must be capped by max_speculation_length
    // (2), and each level must widen by the branch factor (2 -> 4).
    #[test]
    fn test_speculate_tree_depth() {
        let config = PredictionConfig {
            hidden_dim: 4,
            head_type: PredictionHeadType::Eagle { num_layers: 2 },
        };
        let tree_config = TreeConfig {
            branch_factor: 2,
            depth: 3,
            verification: VerificationStrategy::Greedy,
        };
        let decoder = SpeculativeDecoder::<NdArray<f32>>::new(config, tree_config, 2);
        let device = <NdArray<f32> as Backend>::Device::default();
        // [1, 3, 4] hidden states; only the last position is drafted from.
        let data = vec![
            0.1, 0.2, 0.3, 0.4, 0.5, 0.4, 0.3, 0.2, 0.2, 0.1, 0.0, -0.1,
        ];
        let hidden = Tensor::from_data(TensorData::new(data, [1, 3, 4]), &device);

        let candidates = decoder.speculate(hidden).expect("speculate");
        assert_eq!(candidates.trees.len(), 1);
        assert_eq!(candidates.max_depth, 2);
        assert_eq!(candidates.trees[0].levels.len(), 2);
        assert_eq!(candidates.trees[0].levels[0].len(), 2);
        assert_eq!(candidates.trees[0].levels[1].len(), 4);
    }

    // The second level's target logits put nearly all mass on a token the
    // draft did not favor, so only the first drafted token is accepted.
    #[test]
    fn test_verify_rejects_on_low_target_prob() {
        let config = PredictionConfig {
            hidden_dim: 3,
            head_type: PredictionHeadType::EarlyExit { exit_threshold: 0.5 },
        };
        let tree_config = TreeConfig {
            branch_factor: 2,
            depth: 2,
            verification: VerificationStrategy::Greedy,
        };
        let decoder = SpeculativeDecoder::<NdArray<f32>>::new(config, tree_config, 2);
        let device = <NdArray<f32> as Backend>::Device::default();
        let hidden = Tensor::from_data(
            TensorData::new(vec![0.2, 0.1, 0.0], [1, 1, 3]),
            &device,
        );
        let candidates = decoder.speculate(hidden).expect("speculate");

        // [1, 2, 3] target logits: permissive at depth 0, rejecting at 1.
        let target_logits = Tensor::from_data(
            TensorData::new(vec![0.0, 2.0, 0.0, -2.0, -2.0, 5.0], [1, 2, 3]),
            &device,
        );
        let cache_tokens =
            Tensor::from_data(TensorData::new(vec![1.0, 2.0], [1, 2]), &device);

        let result = decoder
            .verify(&candidates, target_logits, cache_tokens)
            .expect("verify");
        assert_eq!(result.accepted_tokens.len(), 1);
        assert_eq!(result.accepted_tokens[0].len(), 1);
        // Cache grows from 2 to 3 entries (one accepted token appended).
        let updated = result
            .updated_cache
            .into_data()
            .into_vec::<f32>()
            .expect("cache data");
        assert_eq!(updated.len(), 3);
    }
}