Skip to main content

cbtop/continuous_batcher/
speculative.rs

1//! Speculative decoding and helper types for continuous batching.
2
3use std::fmt;
4
5use super::request::Token;
6
/// Exponential moving average (EMA) used to smooth a stream of metric samples.
///
/// The first sample after construction or [`reset`](Self::reset) is taken
/// verbatim; each later sample `x` moves the average to
/// `alpha * x + (1 - alpha) * previous`.
#[derive(Debug, Clone)]
pub struct ExponentialMovingAverage {
    /// Smoothed value so far (0.0 before any sample)
    value: f64,
    /// Weight of each new sample, clamped into `[0.0, 1.0]`
    alpha: f64,
    /// Samples folded in since construction or the last reset
    count: u64,
}

impl ExponentialMovingAverage {
    /// Creates an empty average with smoothing factor `alpha`.
    ///
    /// `alpha` is clamped into `[0.0, 1.0]`: values near 1.0 follow new
    /// samples closely, values near 0.0 change slowly.
    pub fn new(alpha: f64) -> Self {
        let alpha = alpha.clamp(0.0, 1.0);
        Self {
            alpha,
            value: 0.0,
            count: 0,
        }
    }

    /// Folds `sample` into the average.
    pub fn update(&mut self, sample: f64) {
        let seeded = self.count > 0;
        self.value = if seeded {
            let keep = 1.0 - self.alpha;
            sample * self.alpha + self.value * keep
        } else {
            // Nothing to blend with yet: seed directly instead of decaying from 0.0.
            sample
        };
        self.count += 1;
    }

    /// Returns the current smoothed value (0.0 if no sample has been seen).
    pub fn value(&self) -> f64 {
        self.value
    }

    /// Discards all history while keeping the smoothing factor.
    pub fn reset(&mut self) {
        *self = Self {
            value: 0.0,
            count: 0,
            ..*self
        };
    }
}

impl Default for ExponentialMovingAverage {
    /// Equivalent to `ExponentialMovingAverage::new(0.1)`.
    fn default() -> Self {
        ExponentialMovingAverage::new(0.1)
    }
}
55
56/// Output from speculative decoding step.
57#[derive(Debug, Clone)]
58pub struct SpeculativeOutput {
59    /// Accepted tokens from draft
60    pub accepted: Vec<Token>,
61    /// Rejection index (first rejected draft token)
62    pub rejection_idx: Option<usize>,
63    /// Token from target model (after rejection or all accepted)
64    pub target_token: Token,
65    /// Total draft tokens proposed
66    pub draft_count: usize,
67}
68
69impl SpeculativeOutput {
70    /// Calculate acceptance rate for this step.
71    pub fn acceptance_rate(&self) -> f64 {
72        if self.draft_count == 0 {
73            return 0.0;
74        }
75        self.accepted.len() as f64 / self.draft_count as f64
76    }
77
78    /// Number of tokens produced (accepted + 1 target).
79    pub fn total_tokens(&self) -> usize {
80        self.accepted.len() + 1
81    }
82}
83
84/// Speculative decoding coordinator.
85///
86/// Coordinates draft and target models for speculative decoding.
87/// The draft model proposes tokens, target model verifies.
88#[derive(Debug)]
89pub struct SpeculativeDecoder {
90    /// Speculation depth (draft tokens per step)
91    k: usize,
92    /// Acceptance rate tracker
93    acceptance_rate: ExponentialMovingAverage,
94    /// Total steps
95    total_steps: u64,
96    /// Total accepted tokens
97    total_accepted: u64,
98    /// Total draft tokens
99    total_draft: u64,
100}
101
102impl SpeculativeDecoder {
103    /// Create new speculative decoder.
104    pub fn new(k: usize) -> Self {
105        Self {
106            k,
107            acceptance_rate: ExponentialMovingAverage::new(0.1),
108            total_steps: 0,
109            total_accepted: 0,
110            total_draft: 0,
111        }
112    }
113
114    /// Get speculation depth.
115    pub fn k(&self) -> usize {
116        self.k
117    }
118
119    /// Set speculation depth.
120    pub fn set_k(&mut self, k: usize) {
121        self.k = k;
122    }
123
124    /// Simulate speculative decoding step.
125    ///
126    /// In a real implementation, this would:
127    /// 1. Run draft model k times to get draft tokens
128    /// 2. Run target model once on all draft positions
129    /// 3. Compare and accept/reject
130    pub fn simulate_step(
131        &mut self,
132        draft_tokens: &[Token],
133        target_probs: &[(Token, f64)],
134    ) -> SpeculativeOutput {
135        let draft_count = draft_tokens.len().min(self.k);
136        let mut accepted = Vec::new();
137        let mut rejection_idx = None;
138
139        // Simulate acceptance (simplified: accept if target agrees)
140        for (i, &draft_token) in draft_tokens.iter().take(draft_count).enumerate() {
141            if let Some((target_token, _)) = target_probs.get(i) {
142                if *target_token == draft_token {
143                    accepted.push(draft_token);
144                } else {
145                    rejection_idx = Some(i);
146                    break;
147                }
148            } else {
149                rejection_idx = Some(i);
150                break;
151            }
152        }
153
154        // Get target token (either after rejection or as the k+1 token)
155        let target_token = if let Some(idx) = rejection_idx {
156            target_probs.get(idx).map(|(t, _)| *t).unwrap_or(0)
157        } else {
158            target_probs.get(draft_count).map(|(t, _)| *t).unwrap_or(0)
159        };
160
161        let output = SpeculativeOutput {
162            accepted: accepted.clone(),
163            rejection_idx,
164            target_token,
165            draft_count,
166        };
167
168        // Update statistics
169        self.total_steps += 1;
170        self.total_accepted += accepted.len() as u64;
171        self.total_draft += draft_count as u64;
172        self.acceptance_rate.update(output.acceptance_rate());
173
174        output
175    }
176
177    /// Current acceptance rate (EMA).
178    pub fn acceptance_rate(&self) -> f64 {
179        self.acceptance_rate.value()
180    }
181
182    /// Overall acceptance rate.
183    pub fn overall_acceptance_rate(&self) -> f64 {
184        if self.total_draft == 0 {
185            return 0.0;
186        }
187        self.total_accepted as f64 / self.total_draft as f64
188    }
189
190    /// Effective speedup vs naive decoding.
191    ///
192    /// Speedup = (accepted + 1) / (1 + verification_cost_ratio)
193    /// Simplified: assume verification cost = 1/k of naive
194    pub fn speedup(&self) -> f64 {
195        let rate = self.acceptance_rate();
196        // Expected tokens per step = 1 + k * acceptance_rate
197        // Cost = 1 (target call) + k * draft_cost (assume draft_cost << 1)
198        // Simplified model: speedup ≈ 1 + k * acceptance_rate
199        1.0 + (self.k as f64) * rate
200    }
201
202    /// Get statistics.
203    pub fn stats(&self) -> (u64, u64, u64) {
204        (self.total_steps, self.total_accepted, self.total_draft)
205    }
206}
207
208impl fmt::Display for SpeculativeDecoder {
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        write!(
211            f,
212            "SpeculativeDecoder(k={}, acceptance={:.1}%, speedup={:.2}x)",
213            self.k,
214            self.acceptance_rate() * 100.0,
215            self.speedup()
216        )
217    }
218}