// cbtop/continuous_batcher/speculative.rs
use std::fmt;
4
5use super::request::Token;
6
/// Exponentially weighted running average over a stream of `f64` samples.
///
/// The first sample after construction (or after [`reset`](Self::reset)) seeds
/// the average directly; every later sample is blended in with weight `alpha`.
#[derive(Debug, Clone)]
pub struct ExponentialMovingAverage {
    /// Current smoothed value; `0.0` until the first sample arrives.
    value: f64,
    /// Weight given to each new sample, always within `[0.0, 1.0]`.
    alpha: f64,
    /// Number of samples folded in since construction or the last reset.
    count: u64,
}

impl ExponentialMovingAverage {
    /// Smoothing factor used by [`Default`] and as the fallback for a NaN `alpha`.
    const DEFAULT_ALPHA: f64 = 0.1;

    /// Creates an empty average with smoothing factor `alpha`.
    ///
    /// `alpha` is clamped into `[0.0, 1.0]`. A NaN `alpha` is replaced by the
    /// default factor (`0.1`) instead of being stored.
    pub fn new(alpha: f64) -> Self {
        // `f64::clamp` returns NaN when called on NaN, and a NaN weight would turn
        // every value after the seed sample into NaN, so it is rejected up front.
        let alpha = if alpha.is_nan() {
            Self::DEFAULT_ALPHA
        } else {
            alpha.clamp(0.0, 1.0)
        };

        Self {
            value: 0.0,
            alpha,
            count: 0,
        }
    }

    /// Folds `sample` into the average.
    ///
    /// The very first sample becomes the average as-is; later samples are
    /// combined as `alpha * sample + (1 - alpha) * previous`.
    pub fn update(&mut self, sample: f64) {
        self.value = if self.count == 0 {
            // Seed with the first observation rather than blending against 0.0.
            sample
        } else {
            self.alpha * sample + (1.0 - self.alpha) * self.value
        };
        self.count += 1;
    }

    /// Returns the current smoothed value (`0.0` if no sample has been seen).
    pub fn value(&self) -> f64 {
        self.value
    }

    /// Clears the value and the sample count; the smoothing factor is kept.
    pub fn reset(&mut self) {
        self.value = 0.0;
        self.count = 0;
    }
}

impl Default for ExponentialMovingAverage {
    /// Builds an average with the default smoothing factor of `0.1`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_ALPHA)
    }
}
55
56#[derive(Debug, Clone)]
58pub struct SpeculativeOutput {
59 pub accepted: Vec<Token>,
61 pub rejection_idx: Option<usize>,
63 pub target_token: Token,
65 pub draft_count: usize,
67}
68
69impl SpeculativeOutput {
70 pub fn acceptance_rate(&self) -> f64 {
72 if self.draft_count == 0 {
73 return 0.0;
74 }
75 self.accepted.len() as f64 / self.draft_count as f64
76 }
77
78 pub fn total_tokens(&self) -> usize {
80 self.accepted.len() + 1
81 }
82}
83
84#[derive(Debug)]
89pub struct SpeculativeDecoder {
90 k: usize,
92 acceptance_rate: ExponentialMovingAverage,
94 total_steps: u64,
96 total_accepted: u64,
98 total_draft: u64,
100}
101
102impl SpeculativeDecoder {
103 pub fn new(k: usize) -> Self {
105 Self {
106 k,
107 acceptance_rate: ExponentialMovingAverage::new(0.1),
108 total_steps: 0,
109 total_accepted: 0,
110 total_draft: 0,
111 }
112 }
113
114 pub fn k(&self) -> usize {
116 self.k
117 }
118
119 pub fn set_k(&mut self, k: usize) {
121 self.k = k;
122 }
123
124 pub fn simulate_step(
131 &mut self,
132 draft_tokens: &[Token],
133 target_probs: &[(Token, f64)],
134 ) -> SpeculativeOutput {
135 let draft_count = draft_tokens.len().min(self.k);
136 let mut accepted = Vec::new();
137 let mut rejection_idx = None;
138
139 for (i, &draft_token) in draft_tokens.iter().take(draft_count).enumerate() {
141 if let Some((target_token, _)) = target_probs.get(i) {
142 if *target_token == draft_token {
143 accepted.push(draft_token);
144 } else {
145 rejection_idx = Some(i);
146 break;
147 }
148 } else {
149 rejection_idx = Some(i);
150 break;
151 }
152 }
153
154 let target_token = if let Some(idx) = rejection_idx {
156 target_probs.get(idx).map(|(t, _)| *t).unwrap_or(0)
157 } else {
158 target_probs.get(draft_count).map(|(t, _)| *t).unwrap_or(0)
159 };
160
161 let output = SpeculativeOutput {
162 accepted: accepted.clone(),
163 rejection_idx,
164 target_token,
165 draft_count,
166 };
167
168 self.total_steps += 1;
170 self.total_accepted += accepted.len() as u64;
171 self.total_draft += draft_count as u64;
172 self.acceptance_rate.update(output.acceptance_rate());
173
174 output
175 }
176
177 pub fn acceptance_rate(&self) -> f64 {
179 self.acceptance_rate.value()
180 }
181
182 pub fn overall_acceptance_rate(&self) -> f64 {
184 if self.total_draft == 0 {
185 return 0.0;
186 }
187 self.total_accepted as f64 / self.total_draft as f64
188 }
189
190 pub fn speedup(&self) -> f64 {
195 let rate = self.acceptance_rate();
196 1.0 + (self.k as f64) * rate
200 }
201
202 pub fn stats(&self) -> (u64, u64, u64) {
204 (self.total_steps, self.total_accepted, self.total_draft)
205 }
206}
207
208impl fmt::Display for SpeculativeDecoder {
209 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 write!(
211 f,
212 "SpeculativeDecoder(k={}, acceptance={:.1}%, speedup={:.2}x)",
213 self.k,
214 self.acceptance_rate() * 100.0,
215 self.speedup()
216 )
217 }
218}