Skip to main content

hanzo_engine/speculative/
verifier.rs

1use std::sync::Arc;
2
3use hanzo_ml::{DType, Result, Tensor};
4use rand::Rng;
5use rand_isaac::Isaac64Rng;
6
7use crate::pipeline::sampling::{finish_or_add_toks_to_seq, sample_sequence};
8use crate::pipeline::Pipeline;
9use crate::prefix_cacher::PrefixCacheManagerV2;
10use crate::sampler::Logprobs;
11use crate::sequence::{Sequence, SequenceRecognizer, SequenceState};
12
/// Result of committing one speculative verification step to a sequence.
///
/// Produced by [`finish_verified_step`] (and its stochastic counterpart) after
/// the draft tokens have been compared against the verifier's logits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerificationOutcome {
    /// Number of draft tokens that were accepted before verification stopped.
    pub accepted_drafts: usize,
    /// Total number of draft tokens that were proposed (`proposal.len()`).
    pub proposed_drafts: usize,
    /// Length to keep, computed by the verifier as `base_len + 1 + accepted`
    /// (or `base_len + 1` when the sequence finishes right after the anchor).
    pub keep_len: usize,
    /// The last token sampled by the verifier when the sequence is still
    /// running; `None` when the sequence reached a `Done` state.
    pub continuation_token: Option<u32>,
}
19
20#[allow(clippy::too_many_arguments)]
21pub async fn finish_verified_step<P: Pipeline>(
22    pipeline: &P,
23    seq: &mut Sequence,
24    verify_logits: Tensor,
25    proposal: Vec<u32>,
26    proposal_logits: Option<Tensor>,
27    base_len: usize,
28    prefix_cacher: &mut PrefixCacheManagerV2,
29    disable_eos_stop: bool,
30    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
31    anchor_to_emit: Option<Logprobs>,
32) -> Result<VerificationOutcome> {
33    let general_metadata = pipeline.get_metadata();
34    let eos_tok = if disable_eos_stop {
35        None
36    } else {
37        Some(&general_metadata.eos_tok[..])
38    };
39    let return_logprobs = seq.return_logprobs();
40
41    if let Some(anchor) = anchor_to_emit {
42        finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, anchor, eos_tok, true).await?;
43        if matches!(seq.getstate(), SequenceState::Done(_)) {
44            let keep_len = base_len + 1;
45            seq.clear_staged_speculative_tokens();
46            return Ok(VerificationOutcome {
47                accepted_drafts: 0,
48                proposed_drafts: proposal.len(),
49                keep_len,
50                continuation_token: None,
51            });
52        }
53    }
54
55    if let Some(proposal_logits) = proposal_logits {
56        if !seq.sampler().is_argmax() && matches!(seq.recognizer, SequenceRecognizer::None) {
57            return finish_verified_step_stochastic(
58                pipeline,
59                seq,
60                verify_logits,
61                proposal,
62                proposal_logits,
63                base_len,
64                prefix_cacher,
65                eos_tok,
66                return_logprobs,
67                rng,
68            )
69            .await;
70        }
71    }
72
73    let mut accepted = 0usize;
74    for (idx, draft) in proposal.iter().copied().enumerate() {
75        let row = logit_row(&verify_logits, idx)?;
76        let sampled = sample_sequence(
77            row.clone(),
78            seq,
79            return_logprobs,
80            rng.clone(),
81            false,
82            false,
83            false,
84        )
85        .await?;
86        let sampled_token = sampled.token;
87        if sampled_token == draft {
88            accepted += 1;
89            finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, sampled, eos_tok, true).await?;
90            if matches!(seq.getstate(), SequenceState::Done(_)) {
91                let keep_len = base_len + 1 + accepted;
92                seq.clear_staged_speculative_tokens();
93                return Ok(VerificationOutcome {
94                    accepted_drafts: accepted,
95                    proposed_drafts: proposal.len(),
96                    keep_len,
97                    continuation_token: None,
98                });
99            }
100        } else {
101            let keep_len = base_len + 1 + accepted;
102            finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, sampled, eos_tok, true).await?;
103            if matches!(seq.getstate(), SequenceState::Done(_)) {
104                seq.clear_staged_speculative_tokens();
105                return Ok(VerificationOutcome {
106                    accepted_drafts: accepted,
107                    proposed_drafts: proposal.len(),
108                    keep_len,
109                    continuation_token: None,
110                });
111            }
112            return Ok(VerificationOutcome {
113                accepted_drafts: accepted,
114                proposed_drafts: proposal.len(),
115                keep_len,
116                continuation_token: Some(sampled_token),
117            });
118        }
119    }
120
121    let row = logit_row(&verify_logits, accepted)?;
122    let continuation = sample_sequence(
123        row.clone(),
124        seq,
125        return_logprobs,
126        rng.clone(),
127        false,
128        false,
129        false,
130    )
131    .await?;
132    let continuation_token = continuation.token;
133    finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, continuation, eos_tok, true).await?;
134
135    let keep_len = base_len + 1 + accepted;
136    let continuation_token = if matches!(seq.getstate(), SequenceState::Done(_)) {
137        seq.clear_staged_speculative_tokens();
138        None
139    } else {
140        Some(continuation_token)
141    };
142
143    Ok(VerificationOutcome {
144        accepted_drafts: accepted,
145        proposed_drafts: proposal.len(),
146        keep_len,
147        continuation_token,
148    })
149}
150
151#[allow(clippy::too_many_arguments)]
152async fn finish_verified_step_stochastic<P: Pipeline>(
153    pipeline: &P,
154    seq: &mut Sequence,
155    verify_logits: Tensor,
156    proposal: Vec<u32>,
157    proposal_logits: Tensor,
158    base_len: usize,
159    prefix_cacher: &mut PrefixCacheManagerV2,
160    eos_tok: Option<&[u32]>,
161    return_logprobs: bool,
162    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
163) -> Result<VerificationOutcome> {
164    let mut accepted = 0usize;
165    for (idx, draft) in proposal.iter().copied().enumerate() {
166        let target_row = logit_row(&verify_logits, idx)?;
167        let candidate_row = logit_row(&proposal_logits, idx)?;
168        let sampler = seq.sampler();
169        let target_probs =
170            sampler.speculative_target_probs(flat_logits(target_row.clone())?, seq.get_toks())?;
171        let candidate_probs =
172            sampler.speculative_candidate_probs(flat_logits(candidate_row)?, seq.get_toks())?;
173        if target_probs.len() != candidate_probs.len() {
174            hanzo_ml::bail!(
175                "speculative target/candidate vocab mismatch: target={}, candidate={}",
176                target_probs.len(),
177                candidate_probs.len()
178            );
179        }
180        let draft_idx = draft as usize;
181        let p_i = target_probs.get(draft_idx).copied().unwrap_or(0.0);
182        let q_i = candidate_probs.get(draft_idx).copied().unwrap_or(0.0);
183        let accept_prob = if q_i <= 0.0 {
184            if p_i > 0.0 {
185                1.0
186            } else {
187                0.0
188            }
189        } else {
190            (p_i / q_i).min(1.0)
191        };
192        let draw = {
193            let mut rng = rng.lock().expect("could not lock rng mutex");
194            rng.random::<f32>()
195        };
196
197        if draw <= accept_prob {
198            accepted += 1;
199            let sampled = sampler.logprobs_from_probs(draft, &target_probs, return_logprobs)?;
200            finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, sampled, eos_tok, true).await?;
201            if matches!(seq.getstate(), SequenceState::Done(_)) {
202                let keep_len = base_len + 1 + accepted;
203                seq.clear_staged_speculative_tokens();
204                return Ok(VerificationOutcome {
205                    accepted_drafts: accepted,
206                    proposed_drafts: proposal.len(),
207                    keep_len,
208                    continuation_token: None,
209                });
210            }
211            continue;
212        }
213
214        let mut adjusted_probs = target_probs
215            .iter()
216            .zip(candidate_probs.iter())
217            .map(|(p, q)| (p - q).max(0.0))
218            .collect::<Vec<_>>();
219        if normalize_probs(&mut adjusted_probs).is_err() {
220            adjusted_probs = target_probs;
221        }
222        let sampled = sampler.sample_from_probs(&adjusted_probs, return_logprobs, rng.clone())?;
223        let sampled_token = sampled.token;
224        let keep_len = base_len + 1 + accepted;
225        finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, sampled, eos_tok, true).await?;
226        if matches!(seq.getstate(), SequenceState::Done(_)) {
227            seq.clear_staged_speculative_tokens();
228            return Ok(VerificationOutcome {
229                accepted_drafts: accepted,
230                proposed_drafts: proposal.len(),
231                keep_len,
232                continuation_token: None,
233            });
234        }
235        return Ok(VerificationOutcome {
236            accepted_drafts: accepted,
237            proposed_drafts: proposal.len(),
238            keep_len,
239            continuation_token: Some(sampled_token),
240        });
241    }
242
243    let row = logit_row(&verify_logits, accepted)?;
244    let sampler = seq.sampler();
245    let target_probs =
246        sampler.speculative_target_probs(flat_logits(row.clone())?, seq.get_toks())?;
247    let continuation = sampler.sample_from_probs(&target_probs, return_logprobs, rng)?;
248    let continuation_token = continuation.token;
249    finish_or_add_toks_to_seq(pipeline, prefix_cacher, seq, continuation, eos_tok, true).await?;
250
251    let keep_len = base_len + 1 + accepted;
252    let continuation_token = if matches!(seq.getstate(), SequenceState::Done(_)) {
253        seq.clear_staged_speculative_tokens();
254        None
255    } else {
256        Some(continuation_token)
257    };
258
259    Ok(VerificationOutcome {
260        accepted_drafts: accepted,
261        proposed_drafts: proposal.len(),
262        keep_len,
263        continuation_token,
264    })
265}
266
267fn logit_row(logits: &Tensor, row: usize) -> Result<Tensor> {
268    match logits.dims() {
269        [_, rows, _] => {
270            if row >= *rows {
271                hanzo_ml::bail!("speculative logit row {row} is out of range for {rows} rows");
272            }
273            logits.narrow(1, row, 1)
274        }
275        [rows, _] => {
276            if row >= *rows {
277                hanzo_ml::bail!("speculative logit row {row} is out of range for {rows} rows");
278            }
279            logits.narrow(0, row, 1)
280        }
281        shape => hanzo_ml::bail!("speculative logits have unsupported shape {shape:?}"),
282    }
283}
284
285fn flat_logits(logits: Tensor) -> Result<Tensor> {
286    match logits.dims() {
287        [1, 1, _] => logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32),
288        [1, _] => logits.squeeze(0)?.to_dtype(DType::F32),
289        [_] => logits.to_dtype(DType::F32),
290        dims => hanzo_ml::bail!("speculative logit row must flatten to rank 1, got {dims:?}"),
291    }
292}
293
294fn normalize_probs(probs: &mut [f32]) -> Result<()> {
295    let sum: f32 = probs
296        .iter()
297        .copied()
298        .filter(|prob| prob.is_finite() && *prob > 0.0)
299        .sum();
300    if sum <= 0.0 {
301        hanzo_ml::bail!("all probabilities are zero in speculative adjusted distribution");
302    }
303    for prob in probs.iter_mut() {
304        if prob.is_finite() && *prob > 0.0 {
305            *prob /= sum;
306        } else {
307            *prob = 0.0;
308        }
309    }
310    Ok(())
311}