1use crate::adaptive_lookahead::{AdaptiveLookahead, AdaptiveLookaheadConfig};
34use crate::engine::InferenceEngine;
35use crate::sampling::SamplingParams;
36
/// Tunable knobs for speculative decoding.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// How many tokens the draft engine is asked to propose per step.
    pub lookahead: usize,
    /// Amount subtracted from the `target / draft` acceptance ratio before it
    /// is compared with the random sample; `0.0` leaves the ratio unchanged.
    pub acceptance_threshold: f32,
}

impl Default for SpeculativeConfig {
    /// Five drafted tokens per step and no acceptance penalty.
    fn default() -> Self {
        SpeculativeConfig {
            lookahead: 5,
            acceptance_threshold: 0.0,
        }
    }
}
61
/// Outcome of a single draft-then-verify round.
#[derive(Debug, Clone)]
pub struct SpeculativeStep {
    /// Tokens proposed by the draft engine during this round.
    pub draft_tokens: Vec<u32>,
    /// Leading run of `draft_tokens` that passed verification.
    pub accepted_tokens: Vec<u32>,
    /// `accepted / drafted` for this round (`0.0` when nothing was drafted).
    pub acceptance_rate: f32,
}
76
/// Minimal xorshift64 pseudo-random generator (shift triple 13 / 7 / 17).
///
/// Not cryptographically secure; it only supplies cheap, reproducible samples.
struct Xorshift64 {
    /// Current generator state; the constructor never stores zero here.
    state: u64,
}

impl Xorshift64 {
    /// Creates a generator from `seed`.
    ///
    /// A zero seed is replaced by a fixed non-zero constant, because an
    /// all-zero xorshift state would only ever produce zeros.
    fn new(seed: u64) -> Self {
        match seed {
            0 => Self {
                state: 0xdeadbeef_cafebabe,
            },
            nonzero => Self { state: nonzero },
        }
    }

    /// Advances the state by one xorshift round and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Returns a sample in `[0, 1)` built from the top 24 bits of the next
    /// 64-bit output (24 bits fit an `f32` mantissa exactly).
    fn next_f32(&mut self) -> f32 {
        const SAMPLE_BITS: u32 = 24;
        let top_bits = self.next_u64() >> (64 - SAMPLE_BITS);
        top_bits as f32 / (1u64 << SAMPLE_BITS) as f32
    }
}
106
107pub struct SpeculativeDecoder<'a> {
114 pub draft_engine: InferenceEngine<'a>,
116 pub config: SpeculativeConfig,
118 pub total_steps: u64,
120 pub total_draft_tokens: u64,
122 pub total_accepted_tokens: u64,
124 #[allow(dead_code)]
126 rng: Xorshift64,
127 adaptive: Option<AdaptiveLookahead>,
130}
131
132impl<'a> SpeculativeDecoder<'a> {
133 pub fn new(draft_engine: InferenceEngine<'a>, config: SpeculativeConfig) -> Self {
135 Self {
136 draft_engine,
137 config,
138 total_steps: 0,
139 total_draft_tokens: 0,
140 total_accepted_tokens: 0,
141 rng: Xorshift64::new(0xfeed1234_5678abcd),
142 adaptive: None,
143 }
144 }
145
146 pub fn with_adaptive(
150 draft_engine: InferenceEngine<'a>,
151 config: SpeculativeConfig,
152 adaptive_config: AdaptiveLookaheadConfig,
153 ) -> Result<Self, crate::adaptive_lookahead::AdaptiveLookaheadError> {
154 let adaptive = AdaptiveLookahead::try_new(adaptive_config)?;
155 let mut config = config;
156 config.lookahead = adaptive.lookahead();
157 Ok(Self {
158 draft_engine,
159 config,
160 total_steps: 0,
161 total_draft_tokens: 0,
162 total_accepted_tokens: 0,
163 rng: Xorshift64::new(0xfeed1234_5678abcd),
164 adaptive: Some(adaptive),
165 })
166 }
167
168 pub fn adaptive(&self) -> Option<&AdaptiveLookahead> {
170 self.adaptive.as_ref()
171 }
172
173 pub fn adaptive_mut(&mut self) -> Option<&mut AdaptiveLookahead> {
175 self.adaptive.as_mut()
176 }
177
178 pub fn draft(&mut self, context: &[u32], _params: &SamplingParams) -> Vec<u32> {
184 let k = self.config.lookahead;
185 let mut draft_tokens = Vec::with_capacity(k);
186
187 let mut current_context: Vec<u32> = context.to_vec();
189
190 for _ in 0..k {
191 match self.draft_engine.generate(¤t_context, 1) {
193 Ok(generated) if !generated.is_empty() => {
194 let token = generated[0];
195 draft_tokens.push(token);
196 current_context.push(token);
197 }
198 _ => {
199 break;
201 }
202 }
203 }
204
205 draft_tokens
206 }
207
208 pub fn verify(
221 &self,
222 draft_tokens: &[u32],
223 target_logits: &[Vec<f32>],
224 _params: &SamplingParams,
225 ) -> Vec<u32> {
226 let mut accepted = Vec::with_capacity(draft_tokens.len());
227
228 let mut local_rng = Xorshift64::new(
230 self.total_steps
231 .wrapping_mul(6364136223846793005)
232 .wrapping_add(0xabcdef01),
233 );
234
235 for (i, &token) in draft_tokens.iter().enumerate() {
236 let logits = match target_logits.get(i) {
237 Some(l) => l,
238 None => break,
239 };
240
241 if logits.is_empty() {
242 break;
243 }
244
245 let target_probs = softmax(logits);
247
248 let target_prob = if (token as usize) < target_probs.len() {
250 target_probs[token as usize]
251 } else {
252 0.0
253 };
254
255 let vocab_size = logits.len() as f32;
259 let draft_prob = (1.0 / vocab_size).max(1e-9);
260
261 let rng_sample = local_rng.next_f32();
262 let threshold = self.config.acceptance_threshold;
263
264 if Self::should_accept(draft_prob, target_prob, threshold, rng_sample) {
265 accepted.push(token);
266 } else {
267 break;
269 }
270 }
271
272 accepted
273 }
274
275 pub fn step(
280 &mut self,
281 context: &[u32],
282 target_logits: &[Vec<f32>],
283 params: &SamplingParams,
284 ) -> SpeculativeStep {
285 let draft_tokens = self.draft(context, params);
287 let n_drafted = draft_tokens.len();
288
289 let accepted_tokens = self.verify(&draft_tokens, target_logits, params);
291 let n_accepted = accepted_tokens.len();
292
293 self.total_steps += 1;
295 self.total_draft_tokens += n_drafted as u64;
296 self.total_accepted_tokens += n_accepted as u64;
297
298 if let Some(adaptive) = self.adaptive.as_mut() {
300 adaptive.observe_step(n_drafted, n_accepted);
301 self.config.lookahead = adaptive.lookahead();
304 }
305
306 let acceptance_rate = if n_drafted > 0 {
307 n_accepted as f32 / n_drafted as f32
308 } else {
309 0.0
310 };
311
312 SpeculativeStep {
313 draft_tokens,
314 accepted_tokens,
315 acceptance_rate,
316 }
317 }
318
319 pub fn generate_speculative(
329 &mut self,
330 prompt_tokens: &[u32],
331 max_tokens: usize,
332 params: &SamplingParams,
333 ) -> Vec<u32> {
334 let mut output: Vec<u32> = Vec::with_capacity(max_tokens);
335 let mut context: Vec<u32> = prompt_tokens.to_vec();
336
337 while output.len() < max_tokens {
338 let remaining = max_tokens - output.len();
339 let effective_lookahead = self.config.lookahead.min(remaining);
340
341 let vocab_size = 32000usize; let target_logits: Vec<Vec<f32>> = (0..effective_lookahead)
346 .map(|step_idx| {
347 let peak_token =
349 (context.last().copied().unwrap_or(0) as usize + step_idx + 1) % vocab_size;
350 let mut logits = vec![0.0f32; vocab_size];
351 logits[peak_token] = 10.0;
353 for (i, l) in logits.iter_mut().enumerate() {
354 if i != peak_token {
355 *l = -2.0;
356 }
357 }
358 logits
359 })
360 .collect();
361
362 let step_result = self.step(&context, &target_logits, params);
363
364 if step_result.accepted_tokens.is_empty() {
365 match self.draft_engine.generate(&context, 1) {
367 Ok(t) if !t.is_empty() => {
368 let token = t[0];
369 output.push(token);
370 context.push(token);
371 }
372 _ => break,
373 }
374 } else {
375 let to_take = step_result.accepted_tokens.len().min(remaining);
376 for &tok in step_result.accepted_tokens[..to_take].iter() {
377 output.push(tok);
378 context.push(tok);
379 if output.len() >= max_tokens {
380 break;
381 }
382 }
383 }
384
385 if context.len() > prompt_tokens.len() + max_tokens + self.config.lookahead {
387 break;
388 }
389 }
390
391 output
392 }
393
394 pub fn acceptance_rate(&self) -> f32 {
398 if self.total_draft_tokens == 0 {
399 return 0.0;
400 }
401 self.total_accepted_tokens as f32 / self.total_draft_tokens as f32
402 }
403
404 pub fn speedup_estimate(&self) -> f32 {
413 if self.total_steps == 0 {
414 return 1.0;
415 }
416 let avg_accepted = self.total_accepted_tokens as f32 / self.total_steps as f32;
417 avg_accepted.max(1.0)
419 }
420
421 pub fn reset_stats(&mut self) {
424 self.total_steps = 0;
425 self.total_draft_tokens = 0;
426 self.total_accepted_tokens = 0;
427 if let Some(adaptive) = self.adaptive.as_mut() {
428 adaptive.reset();
429 self.config.lookahead = adaptive.lookahead();
430 }
431 }
432
433 fn should_accept(draft_prob: f32, target_prob: f32, threshold: f32, rng_sample: f32) -> bool {
442 if target_prob >= draft_prob {
443 true
445 } else {
446 let accept_prob = (target_prob / draft_prob).max(0.0);
448 let effective_threshold = accept_prob - threshold;
449 rng_sample < effective_threshold
450 }
451 }
452}
453
/// Converts `logits` into a probability distribution with a numerically
/// stable softmax (the maximum logit is subtracted before exponentiation).
///
/// Returns an empty vector for empty input. When the normaliser is unusable,
/// either smaller than `1e-30` or not finite, a uniform distribution is
/// returned. The non-finite case arises when every logit is `-inf` or when
/// the input contains `NaN` or `+inf`; dividing by such a sum would otherwise
/// yield `NaN` "probabilities".
fn softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }

    // `f32::max` ignores NaN operands, so this is the largest non-NaN logit.
    let max_val = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max_val).exp()).collect();
    let sum: f32 = exps.iter().sum();

    // For finite input the maximum element contributes exp(0) = 1, so the sum
    // is at least 1; the degenerate inputs show up as a NaN sum instead.
    if !sum.is_finite() || sum < 1e-30 {
        let n = logits.len() as f32;
        return vec![1.0 / n; logits.len()];
    }

    exps.into_iter().map(|e| e / sum).collect()
}
473
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;

    /// Builds a draft engine from the tiny test model configuration.
    fn make_engine() -> InferenceEngine<'static> {
        let config = Qwen3Config::tiny_test();
        let params = SamplingParams::default();
        InferenceEngine::new(config, params, 42)
    }

    /// Builds a non-adaptive decoder with the given lookahead and a zero
    /// acceptance threshold.
    fn make_decoder(lookahead: usize) -> SpeculativeDecoder<'static> {
        let spec_config = SpeculativeConfig {
            lookahead,
            acceptance_threshold: 0.0,
        };
        SpeculativeDecoder::new(make_engine(), spec_config)
    }

    /// Returns `n_positions` identical logit rows of width `vocab_size`,
    /// peaked at `peak_token` (no peak if it is outside the vocabulary).
    fn make_peaked_logits(
        vocab_size: usize,
        peak_token: usize,
        n_positions: usize,
    ) -> Vec<Vec<f32>> {
        (0..n_positions)
            .map(|_| {
                let mut logits = vec![-5.0f32; vocab_size];
                if peak_token < vocab_size {
                    logits[peak_token] = 10.0;
                }
                logits
            })
            .collect()
    }

    #[test]
    fn test_speculative_config_defaults() {
        let cfg = SpeculativeConfig::default();
        assert_eq!(cfg.lookahead, 5, "default lookahead should be 5");
        assert!(
            (cfg.acceptance_threshold - 0.0).abs() < f32::EPSILON,
            "default threshold should be 0.0"
        );
    }

    #[test]
    fn test_draft_generates_lookahead_tokens() {
        let mut decoder = make_decoder(3);
        let context = vec![1u32, 2, 3];
        let params = SamplingParams::default();
        let draft = decoder.draft(&context, &params);
        assert!(
            draft.len() <= 3,
            "draft should not exceed lookahead=3, got {}",
            draft.len()
        );
    }

    #[test]
    fn test_verify_accepts_high_probability_tokens() {
        let decoder = make_decoder(5);
        let params = SamplingParams::default();
        let vocab_size = 100;

        let draft_tokens = vec![42u32];
        let target_logits = make_peaked_logits(vocab_size, 42, 1);

        let accepted = decoder.verify(&draft_tokens, &target_logits, &params);
        assert_eq!(
            accepted.len(),
            1,
            "high-probability token should be accepted"
        );
        assert_eq!(accepted[0], 42);
    }

    #[test]
    fn test_verify_rejects_low_probability_tokens() {
        let mut decoder = make_decoder(5);
        let params = SamplingParams::default();
        let vocab_size = 1000;

        let draft_tokens = vec![500u32];
        let mut logits = vec![-10.0f32; vocab_size];
        logits[0] = 20.0;
        let target_logits = vec![logits];

        let mut rejections = 0;
        for step in 0..20u64 {
            // `verify` seeds its generator from `total_steps`; vary it so each
            // iteration draws a different sample instead of repeating one call.
            decoder.total_steps = step;
            let accepted = decoder.verify(&draft_tokens, &target_logits, &params);
            if accepted.is_empty() {
                rejections += 1;
            }
        }
        assert!(
            rejections > 0,
            "low-probability token should be rejected at least sometimes"
        );
    }

    #[test]
    fn test_acceptance_rate_zero_at_start() {
        let decoder = make_decoder(5);
        assert!(
            (decoder.acceptance_rate() - 0.0).abs() < f32::EPSILON,
            "acceptance rate must be 0.0 before any steps"
        );
        assert_eq!(decoder.total_steps, 0);
        assert_eq!(decoder.total_draft_tokens, 0);
        assert_eq!(decoder.total_accepted_tokens, 0);
    }

    #[test]
    fn test_acceptance_rate_updates_after_step() {
        let mut decoder = make_decoder(4);
        let params = SamplingParams::default();
        let context = vec![1u32, 2, 3];

        let vocab_size = 32usize;
        let target_logits = make_peaked_logits(vocab_size, 5, 4);

        let step = decoder.step(&context, &target_logits, &params);

        assert_eq!(decoder.total_steps, 1, "one step should have been recorded");
        assert_eq!(
            decoder.total_draft_tokens,
            step.draft_tokens.len() as u64,
            "draft token count should match"
        );
        assert!(
            decoder.total_accepted_tokens <= decoder.total_draft_tokens,
            "accepted cannot exceed drafted"
        );
    }

    #[test]
    fn test_generate_speculative_returns_tokens() {
        let mut decoder = make_decoder(3);
        let params = SamplingParams::default();
        let prompt = vec![1u32, 2, 3];

        let output = decoder.generate_speculative(&prompt, 5, &params);
        assert!(
            output.len() <= 5,
            "output should not exceed max_tokens=5, got {}",
            output.len()
        );
    }

    #[test]
    fn test_should_accept_target_above_draft() {
        assert!(
            SpeculativeDecoder::should_accept(0.1, 0.9, 0.0, 0.99),
            "target > draft: must accept even with rng_sample near 1.0"
        );
        assert!(
            SpeculativeDecoder::should_accept(0.05, 0.5, 0.0, 0.0),
            "target > draft: must accept with rng_sample=0.0"
        );
    }

    #[test]
    fn test_should_accept_target_below_draft_probabilistic() {
        assert!(
            SpeculativeDecoder::should_accept(1.0, 0.1, 0.0, 0.05),
            "rng_sample=0.05 < accept_prob=0.1, should accept"
        );
        assert!(
            !SpeculativeDecoder::should_accept(1.0, 0.1, 0.0, 0.5),
            "rng_sample=0.5 >= accept_prob=0.1, should reject"
        );
    }

    #[test]
    fn test_speedup_estimate_below_lookahead() {
        let mut decoder = make_decoder(5);
        assert!(
            (decoder.speedup_estimate() - 1.0).abs() < f32::EPSILON,
            "initial speedup should be 1.0"
        );

        decoder.total_steps = 10;
        decoder.total_draft_tokens = 30;
        decoder.total_accepted_tokens = 15;

        let speedup = decoder.speedup_estimate();
        assert!(
            (speedup - 1.5).abs() < 1e-4,
            "speedup should be 1.5 (avg accepted per step), got {speedup}"
        );
        assert!(
            speedup <= decoder.config.lookahead as f32 + 1.0,
            "speedup cannot exceed lookahead+1"
        );
    }

    #[test]
    fn test_with_adaptive_starts_with_initial_lookahead() {
        let spec_cfg = SpeculativeConfig {
            lookahead: 99,
            acceptance_threshold: 0.0,
        };
        let adapt_cfg = AdaptiveLookaheadConfig {
            initial: 4,
            min: 2,
            max: 10,
            alpha: 0.5,
            cooldown_steps: 1,
        };
        let decoder =
            SpeculativeDecoder::with_adaptive(make_engine(), spec_cfg, adapt_cfg).expect("valid");
        // The adaptive controller's initial value wins over the caller's 99.
        assert_eq!(decoder.config.lookahead, 4);
        assert!(decoder.adaptive().is_some());
    }

    #[test]
    fn test_adaptive_decreases_lookahead_on_low_acceptance() {
        let spec_cfg = SpeculativeConfig {
            lookahead: 8,
            acceptance_threshold: 0.0,
        };
        let adapt_cfg = AdaptiveLookaheadConfig {
            initial: 8,
            min: 2,
            max: 12,
            alpha: 0.7,
            cooldown_steps: 1,
        };
        let mut decoder =
            SpeculativeDecoder::with_adaptive(make_engine(), spec_cfg, adapt_cfg).expect("valid");
        let context = vec![1u32, 2, 3];
        let params = SamplingParams::default();
        let vocab = 100usize;
        let logits: Vec<Vec<f32>> = (0..decoder.config.lookahead)
            .map(|_| {
                let mut l = vec![10.0f32; vocab];
                l[0] = -50.0;
                l
            })
            .collect();
        for _ in 0..30 {
            decoder.step(&context, &logits, &params);
        }
        let final_la = decoder.config.lookahead;
        assert!(
            final_la <= 8,
            "lookahead should not increase, got {final_la}"
        );
    }

    #[test]
    fn test_reset_stats_resets_adaptive() {
        let spec_cfg = SpeculativeConfig {
            lookahead: 5,
            acceptance_threshold: 0.0,
        };
        let adapt_cfg = AdaptiveLookaheadConfig {
            initial: 5,
            min: 2,
            max: 12,
            alpha: 0.5,
            cooldown_steps: 1,
        };
        let mut decoder =
            SpeculativeDecoder::with_adaptive(make_engine(), spec_cfg, adapt_cfg).expect("valid");
        for _ in 0..30 {
            // Rebuild the logits each round because the lookahead may change.
            let logits = make_peaked_logits(64, 5, decoder.config.lookahead);
            decoder.step(&[1, 2, 3], &logits, &SamplingParams::default());
        }
        decoder.reset_stats();
        assert_eq!(decoder.total_steps, 0);
        assert_eq!(decoder.config.lookahead, 5);
        assert_eq!(
            decoder.adaptive().expect("adaptive present").observations(),
            0
        );
    }
}