1use std::sync::Arc;
2
3use hanzo_ml::{DType, Result, Tensor};
4use rand::Rng;
5use rand_isaac::Isaac64Rng;
6
7use crate::pipeline::sampling::{finish_or_add_toks_to_seq, sample_sequence};
8use crate::pipeline::Pipeline;
9use crate::prefix_cacher::PrefixCacheManagerV2;
10use crate::sampler::Logprobs;
11use crate::sequence::{Sequence, SequenceRecognizer, SequenceState};
12
/// Result of verifying one batch of speculative draft tokens against the
/// target model's logits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerificationOutcome {
    /// Number of draft tokens that were accepted before verification stopped.
    pub accepted_drafts: usize,
    /// Total number of draft tokens that were proposed (`proposal.len()`).
    pub proposed_drafts: usize,
    /// `base_len + 1 + accepted_drafts` (or `base_len + 1` when the anchor
    /// token alone finished the sequence). Presumably the sequence/cache
    /// length the caller should retain — confirm against the caller.
    pub keep_len: usize,
    /// The token emitted after the accepted drafts (the replacement for a
    /// rejected draft, or the extra token sampled when every draft was
    /// accepted). `None` when the sequence reached a `Done` state.
    pub continuation_token: Option<u32>,
}
19
20#[allow(clippy::too_many_arguments)]
21pub async fn finish_verified_step<P: Pipeline>(
22 pipeline: &P,
23 seq: &mut Sequence,
24 verify_logits: Tensor,
25 proposal: Vec<u32>,
26 proposal_logits: Option<Tensor>,
27 base_len: usize,
28 prefix_cacher: &mut PrefixCacheManagerV2,
29 disable_eos_stop: bool,
30 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
31 anchor_to_emit: Option<Logprobs>,
32) -> Result<VerificationOutcome> {
33 let general_metadata = pipeline.get_metadata();
34 let eos_tok = if disable_eos_stop {
35 None
36 } else {
37 Some(&general_metadata.eos_tok[..])
38 };
39 let return_logprobs = seq.return_logprobs();
40
41 if let Some(anchor) = anchor_to_emit {
42 finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, anchor, eos_tok, true).await?;
43 if matches!(seq.getstate(), SequenceState::Done(_)) {
44 let keep_len = base_len + 1;
45 seq.clear_staged_speculative_tokens();
46 return Ok(VerificationOutcome {
47 accepted_drafts: 0,
48 proposed_drafts: proposal.len(),
49 keep_len,
50 continuation_token: None,
51 });
52 }
53 }
54
55 if let Some(proposal_logits) = proposal_logits {
56 if !seq.sampler().is_argmax() && matches!(seq.recognizer, SequenceRecognizer::None) {
57 return finish_verified_step_stochastic(
58 pipeline,
59 seq,
60 verify_logits,
61 proposal,
62 proposal_logits,
63 base_len,
64 prefix_cacher,
65 eos_tok,
66 return_logprobs,
67 rng,
68 )
69 .await;
70 }
71 }
72
73 let mut accepted = 0usize;
74 for (idx, draft) in proposal.iter().copied().enumerate() {
75 let row = logit_row(&verify_logits, idx)?;
76 let sampled = sample_sequence(
77 row.clone(),
78 seq,
79 return_logprobs,
80 rng.clone(),
81 false,
82 false,
83 false,
84 )
85 .await?;
86 let sampled_token = sampled.token;
87 if sampled_token == draft {
88 accepted += 1;
89 finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, sampled, eos_tok, true).await?;
90 if matches!(seq.getstate(), SequenceState::Done(_)) {
91 let keep_len = base_len + 1 + accepted;
92 seq.clear_staged_speculative_tokens();
93 return Ok(VerificationOutcome {
94 accepted_drafts: accepted,
95 proposed_drafts: proposal.len(),
96 keep_len,
97 continuation_token: None,
98 });
99 }
100 } else {
101 let keep_len = base_len + 1 + accepted;
102 finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, sampled, eos_tok, true).await?;
103 if matches!(seq.getstate(), SequenceState::Done(_)) {
104 seq.clear_staged_speculative_tokens();
105 return Ok(VerificationOutcome {
106 accepted_drafts: accepted,
107 proposed_drafts: proposal.len(),
108 keep_len,
109 continuation_token: None,
110 });
111 }
112 return Ok(VerificationOutcome {
113 accepted_drafts: accepted,
114 proposed_drafts: proposal.len(),
115 keep_len,
116 continuation_token: Some(sampled_token),
117 });
118 }
119 }
120
121 let row = logit_row(&verify_logits, accepted)?;
122 let continuation = sample_sequence(
123 row.clone(),
124 seq,
125 return_logprobs,
126 rng.clone(),
127 false,
128 false,
129 false,
130 )
131 .await?;
132 let continuation_token = continuation.token;
133 finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, continuation, eos_tok, true).await?;
134
135 let keep_len = base_len + 1 + accepted;
136 let continuation_token = if matches!(seq.getstate(), SequenceState::Done(_)) {
137 seq.clear_staged_speculative_tokens();
138 None
139 } else {
140 Some(continuation_token)
141 };
142
143 Ok(VerificationOutcome {
144 accepted_drafts: accepted,
145 proposed_drafts: proposal.len(),
146 keep_len,
147 continuation_token,
148 })
149}
150
151#[allow(clippy::too_many_arguments)]
152async fn finish_verified_step_stochastic<P: Pipeline>(
153 pipeline: &P,
154 seq: &mut Sequence,
155 verify_logits: Tensor,
156 proposal: Vec<u32>,
157 proposal_logits: Tensor,
158 base_len: usize,
159 prefix_cacher: &mut PrefixCacheManagerV2,
160 eos_tok: Option<&[u32]>,
161 return_logprobs: bool,
162 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
163) -> Result<VerificationOutcome> {
164 let mut accepted = 0usize;
165 for (idx, draft) in proposal.iter().copied().enumerate() {
166 let target_row = logit_row(&verify_logits, idx)?;
167 let candidate_row = logit_row(&proposal_logits, idx)?;
168 let sampler = seq.sampler();
169 let target_probs =
170 sampler.speculative_target_probs(flat_logits(target_row.clone())?, seq.get_toks())?;
171 let candidate_probs =
172 sampler.speculative_candidate_probs(flat_logits(candidate_row)?, seq.get_toks())?;
173 if target_probs.len() != candidate_probs.len() {
174 hanzo_ml::bail!(
175 "speculative target/candidate vocab mismatch: target={}, candidate={}",
176 target_probs.len(),
177 candidate_probs.len()
178 );
179 }
180 let draft_idx = draft as usize;
181 let p_i = target_probs.get(draft_idx).copied().unwrap_or(0.0);
182 let q_i = candidate_probs.get(draft_idx).copied().unwrap_or(0.0);
183 let accept_prob = if q_i <= 0.0 {
184 if p_i > 0.0 {
185 1.0
186 } else {
187 0.0
188 }
189 } else {
190 (p_i / q_i).min(1.0)
191 };
192 let draw = {
193 let mut rng = rng.lock().expect("could not lock rng mutex");
194 rng.random::<f32>()
195 };
196
197 if draw <= accept_prob {
198 accepted += 1;
199 let sampled = sampler.logprobs_from_probs(draft, &target_probs, return_logprobs)?;
200 finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, sampled, eos_tok, true).await?;
201 if matches!(seq.getstate(), SequenceState::Done(_)) {
202 let keep_len = base_len + 1 + accepted;
203 seq.clear_staged_speculative_tokens();
204 return Ok(VerificationOutcome {
205 accepted_drafts: accepted,
206 proposed_drafts: proposal.len(),
207 keep_len,
208 continuation_token: None,
209 });
210 }
211 continue;
212 }
213
214 let mut adjusted_probs = target_probs
215 .iter()
216 .zip(candidate_probs.iter())
217 .map(|(p, q)| (p - q).max(0.0))
218 .collect::<Vec<_>>();
219 if normalize_probs(&mut adjusted_probs).is_err() {
220 adjusted_probs = target_probs;
221 }
222 let sampled = sampler.sample_from_probs(&adjusted_probs, return_logprobs, rng.clone())?;
223 let sampled_token = sampled.token;
224 let keep_len = base_len + 1 + accepted;
225 finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, sampled, eos_tok, true).await?;
226 if matches!(seq.getstate(), SequenceState::Done(_)) {
227 seq.clear_staged_speculative_tokens();
228 return Ok(VerificationOutcome {
229 accepted_drafts: accepted,
230 proposed_drafts: proposal.len(),
231 keep_len,
232 continuation_token: None,
233 });
234 }
235 return Ok(VerificationOutcome {
236 accepted_drafts: accepted,
237 proposed_drafts: proposal.len(),
238 keep_len,
239 continuation_token: Some(sampled_token),
240 });
241 }
242
243 let row = logit_row(&verify_logits, accepted)?;
244 let sampler = seq.sampler();
245 let target_probs =
246 sampler.speculative_target_probs(flat_logits(row.clone())?, seq.get_toks())?;
247 let continuation = sampler.sample_from_probs(&target_probs, return_logprobs, rng)?;
248 let continuation_token = continuation.token;
249 finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, continuation, eos_tok, true).await?;
250
251 let keep_len = base_len + 1 + accepted;
252 let continuation_token = if matches!(seq.getstate(), SequenceState::Done(_)) {
253 seq.clear_staged_speculative_tokens();
254 None
255 } else {
256 Some(continuation_token)
257 };
258
259 Ok(VerificationOutcome {
260 accepted_drafts: accepted,
261 proposed_drafts: proposal.len(),
262 keep_len,
263 continuation_token,
264 })
265}
266
267fn logit_row(logits: &Tensor, row: usize) -> Result<Tensor> {
268 match logits.dims() {
269 [_, rows, _] => {
270 if row >= *rows {
271 hanzo_ml::bail!("speculative logit row {row} is out of range for {rows} rows");
272 }
273 logits.narrow(1, row, 1)
274 }
275 [rows, _] => {
276 if row >= *rows {
277 hanzo_ml::bail!("speculative logit row {row} is out of range for {rows} rows");
278 }
279 logits.narrow(0, row, 1)
280 }
281 shape => hanzo_ml::bail!("speculative logits have unsupported shape {shape:?}"),
282 }
283}
284
285fn flat_logits(logits: Tensor) -> Result<Tensor> {
286 match logits.dims() {
287 [1, 1, _] => logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32),
288 [1, _] => logits.squeeze(0)?.to_dtype(DType::F32),
289 [_] => logits.to_dtype(DType::F32),
290 dims => hanzo_ml::bail!("speculative logit row must flatten to rank 1, got {dims:?}"),
291 }
292}
293
294fn normalize_probs(probs: &mut [f32]) -> Result<()> {
295 let sum: f32 = probs
296 .iter()
297 .copied()
298 .filter(|prob| prob.is_finite() && *prob > 0.0)
299 .sum();
300 if sum <= 0.0 {
301 hanzo_ml::bail!("all probabilities are zero in speculative adjusted distribution");
302 }
303 for prob in probs.iter_mut() {
304 if prob.is_finite() && *prob > 0.0 {
305 *prob /= sum;
306 } else {
307 *prob = 0.0;
308 }
309 }
310 Ok(())
311}