Skip to main content

hanzo_engine/speculative/
cache.rs

1use std::collections::HashMap;
2
3use hanzo_ml::{Device, Result, Tensor};
4
5use crate::device_map::DeviceMapper;
6use crate::paged_attention::CacheEngine;
7use crate::pipeline::text_models_inputs_processor::{
8    FlashParams, InputMetadata, PagedAttentionInputMetadata, PagedAttentionMeta,
9};
10use crate::sequence::Sequence;
11
12use super::proposer::SpeculativeKvCache;
13
/// Per-sequence result of a speculative verification step.
///
/// [`SpeculativeCacheAccess::finish_verification_batch`] forwards both fields
/// to [`SpeculativeCacheAccess::finish_verification`], whose default
/// implementation commits the guard when `accepted_all` is set and otherwise
/// rolls it back to `keep_len`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SpeculativeCacheOutcome {
    /// Length handed to [`SpeculativeCacheGuard::rollback_to`] when the
    /// proposal was not accepted in full.
    pub keep_len: usize,
    /// `true` when the whole proposal was accepted, in which case the default
    /// resolution commits instead of rolling back.
    pub accepted_all: bool,
}
19
/// Resolution hooks for one speculative verification step.
///
/// The default [`SpeculativeCacheAccess::finish_verification`] calls exactly
/// one of these per guard: [`commit`](Self::commit) when the whole proposal
/// was accepted, [`rollback_to`](Self::rollback_to) otherwise.
pub trait SpeculativeCacheGuard {
    /// Invoked when every speculated token was accepted.
    fn commit(&mut self) -> Result<()>;
    /// Invoked with the length to keep when the proposal was not accepted in
    /// full.
    fn rollback_to(&mut self, keep_len: usize) -> Result<()>;
}
24
25pub trait SpeculativeCacheAccess {
26    type Guard: SpeculativeCacheGuard;
27
28    /// Returns `Ok(None)` when the cache cannot reserve speculative slots and
29    /// the caller should fall back to normal decoding for this step.
30    fn begin(
31        &self,
32        seq_id: usize,
33        base_len: usize,
34        verify_len: usize,
35    ) -> Result<Option<Self::Guard>>;
36
37    fn guard_for_reserved(&self, seq_id: usize, base_len: usize, verify_len: usize) -> Self::Guard;
38
39    fn make_verify_input_metadata(
40        &self,
41        verify_tokens: &[u32],
42        seq_id: usize,
43        base_len: usize,
44        device: &Device,
45        mapper: &dyn DeviceMapper,
46    ) -> Result<InputMetadata>;
47
48    fn proposer_cache(&self, sequences: &[&Sequence]) -> Result<SpeculativeKvCache<'_>>;
49
50    fn finish_verification(
51        &self,
52        guard: &mut Self::Guard,
53        _seq: &mut Sequence,
54        keep_len: usize,
55        accepted_all: bool,
56    ) -> Result<()> {
57        if accepted_all {
58            guard.commit()
59        } else {
60            guard.rollback_to(keep_len)
61        }
62    }
63
64    fn finish_verification_batch(
65        &self,
66        guards: &mut [Option<Self::Guard>],
67        seqs: &mut [&mut Sequence],
68        outcomes: &[Option<SpeculativeCacheOutcome>],
69    ) -> Result<()> {
70        if guards.len() != seqs.len() || outcomes.len() != seqs.len() {
71            hanzo_ml::bail!(
72                "speculative cache batch shape mismatch: guards={}, seqs={}, outcomes={}",
73                guards.len(),
74                seqs.len(),
75                outcomes.len()
76            );
77        }
78        for ((guard, seq), outcome) in guards.iter_mut().zip(seqs.iter_mut()).zip(outcomes) {
79            let (Some(guard), Some(outcome)) = (guard.as_mut(), outcome) else {
80                continue;
81            };
82            self.finish_verification(guard, seq, outcome.keep_len, outcome.accepted_all)?;
83        }
84        Ok(())
85    }
86
87    fn can_stage_proposal(
88        &self,
89        _sequences: &[&Sequence],
90        _base_lens: &[usize],
91        _proposal_len: usize,
92    ) -> bool {
93        true
94    }
95}
96
/// [`SpeculativeCacheAccess`] implementation backed by paged attention.
pub struct PagedSpeculativeCacheAccess<'a> {
    /// Paged-attention metadata; this type reads its KV cache manager, block
    /// size and sliding window.
    metadata: &'a PagedAttentionMeta,
    /// Clone of the cache engine's KV cache taken in `new` and lent out by
    /// `proposer_cache`.
    // NOTE(review): presumably one (key, value) tensor pair per layer — confirm.
    kv_cache: Vec<(Tensor, Tensor)>,
}
101
102impl<'a> PagedSpeculativeCacheAccess<'a> {
103    pub fn new(metadata: &'a PagedAttentionMeta, cache_engine: &CacheEngine) -> Self {
104        Self {
105            metadata,
106            kv_cache: cache_engine.get_kv_cache().clone(),
107        }
108    }
109}
110
/// Guard over a paged-attention reservation made for one speculative step.
pub struct PagedSpeculativeCacheGuard<'a> {
    /// Source of the KV cache manager used when rolling back.
    metadata: &'a PagedAttentionMeta,
    /// Sequence id passed to the KV cache manager on rollback.
    seq_id: usize,
    /// `base_len + verify_len` at creation time; rollback only trims when the
    /// requested length is below this value.
    reserved_len: usize,
}
116
117impl SpeculativeCacheGuard for PagedSpeculativeCacheGuard<'_> {
118    fn commit(&mut self) -> Result<()> {
119        Ok(())
120    }
121
122    fn rollback_to(&mut self, keep_len: usize) -> Result<()> {
123        if keep_len < self.reserved_len {
124            let mut kv_mgr = crate::get_mut_arcmutex!(self.metadata.kv_cache_manager);
125            kv_mgr.trim_request_to_num_tokens(self.seq_id, keep_len);
126        }
127        Ok(())
128    }
129}
130
131impl<'a> SpeculativeCacheAccess for PagedSpeculativeCacheAccess<'a> {
132    type Guard = PagedSpeculativeCacheGuard<'a>;
133
134    fn begin(
135        &self,
136        seq_id: usize,
137        base_len: usize,
138        verify_len: usize,
139    ) -> Result<Option<Self::Guard>> {
140        let reserved_len = base_len + verify_len;
141        let mut kv_mgr = crate::get_mut_arcmutex!(self.metadata.kv_cache_manager);
142        let Some(_) = kv_mgr.allocate_slots(seq_id, reserved_len, &[]) else {
143            return Ok(None);
144        };
145        Ok(Some(PagedSpeculativeCacheGuard {
146            metadata: self.metadata,
147            seq_id,
148            reserved_len,
149        }))
150    }
151
152    fn guard_for_reserved(
153        &self,
154        seq_id: usize,
155        base_len: usize,
156        verify_len: usize,
157    ) -> PagedSpeculativeCacheGuard<'a> {
158        PagedSpeculativeCacheGuard {
159            metadata: self.metadata,
160            seq_id,
161            reserved_len: base_len + verify_len,
162        }
163    }
164
165    fn make_verify_input_metadata(
166        &self,
167        verify_tokens: &[u32],
168        seq_id: usize,
169        base_len: usize,
170        device: &Device,
171        mapper: &dyn DeviceMapper,
172    ) -> Result<InputMetadata> {
173        let verify_len = verify_tokens.len();
174        if verify_len == 0 {
175            hanzo_ml::bail!("speculative verification requires at least one token.");
176        }
177
178        let kv_mgr = crate::get_mut_arcmutex!(self.metadata.kv_cache_manager);
179        let full_table = kv_mgr
180            .get_block_ids(seq_id)
181            .ok_or_else(|| {
182                hanzo_ml::Error::Msg(format!(
183                    "speculative sequence {seq_id} has no paged-attention blocks"
184                ))
185            })?
186            .to_vec();
187        drop(kv_mgr);
188
189        let mut slot_mappings = Vec::with_capacity(verify_len);
190        let mut block_tables = Vec::with_capacity(verify_len);
191        let mut context_lens = Vec::with_capacity(verify_len);
192        let mut full_block_tables = Vec::with_capacity(verify_len);
193        let mut full_context_lens = Vec::with_capacity(verify_len);
194
195        for row in 0..verify_len {
196            let token_pos = base_len + row;
197            let full_context_len = token_pos + 1;
198            let block_number = full_table
199                .get(token_pos / self.metadata.block_size)
200                .copied()
201                .ok_or_else(|| {
202                    hanzo_ml::Error::Msg(format!(
203                        "speculative verification block table is too small: token_pos={token_pos}, block_size={}, table_len={}",
204                        self.metadata.block_size,
205                        full_table.len()
206                    ))
207                })?;
208            let slot = block_number
209                .checked_mul(self.metadata.block_size)
210                .and_then(|v| v.checked_add(token_pos % self.metadata.block_size))
211                .ok_or_else(|| {
212                    hanzo_ml::Error::Msg("speculative verification slot overflowed".to_string())
213                })?;
214            slot_mappings.push(slot as i64);
215
216            full_block_tables.push(full_table.clone());
217            full_context_lens.push(usize_to_u32(full_context_len, "full context length")?);
218
219            if let Some(sliding_window) = self.metadata.sliding_window {
220                let window_start = full_context_len.saturating_sub(sliding_window);
221                let slide_idx = window_start / self.metadata.block_size;
222                let block_aligned_start = slide_idx * self.metadata.block_size;
223                let context_len = full_context_len.saturating_sub(block_aligned_start);
224                let needed_blocks = context_len.div_ceil(self.metadata.block_size);
225                let slide_end = (slide_idx + needed_blocks).min(full_table.len());
226                block_tables.push(full_table.get(slide_idx..slide_end).unwrap_or(&[]).to_vec());
227                context_lens.push(usize_to_u32(context_len, "context length")?);
228            } else {
229                block_tables.push(full_table.clone());
230                context_lens.push(usize_to_u32(full_context_len, "context length")?);
231            }
232        }
233
234        let cpu = Device::Cpu;
235        let input = Tensor::from_vec(verify_tokens.to_vec(), (1, verify_len), device)?;
236        let slot_mappings = Tensor::from_vec(slot_mappings, (1, verify_len), &cpu)?;
237
238        let max_block_table_len = block_tables.iter().map(Vec::len).max().unwrap_or(1).max(1);
239        let block_tables = repeated_table_tensor(&block_tables, max_block_table_len, &cpu)?;
240        let context_lens = Tensor::from_vec(context_lens, (verify_len,), &cpu)?;
241
242        let full_max_block_table_len = full_block_tables
243            .iter()
244            .map(Vec::len)
245            .max()
246            .unwrap_or(1)
247            .max(1);
248        let full_block_tables =
249            repeated_table_tensor(&full_block_tables, full_max_block_table_len, &cpu)?;
250        let full_context_lens = Tensor::from_vec(full_context_lens, (verify_len,), &cpu)?;
251
252        let metadata = InputMetadata {
253            input,
254            positions: vec![base_len],
255            context_lens: vec![(0, verify_len)],
256            position_ids: vec![base_len + verify_len],
257            paged_attn_meta: Some(PagedAttentionInputMetadata {
258                block_tables: Some(map_to_devices(&block_tables, device, mapper)?),
259                context_lens: Some(map_to_devices(&context_lens, device, mapper)?),
260                slot_mappings: map_to_devices(&slot_mappings, device, mapper)?,
261                max_context_len: Some(
262                    context_lens
263                        .to_vec1::<u32>()?
264                        .into_iter()
265                        .max()
266                        .unwrap_or(0) as usize,
267                ),
268                full_block_tables: Some(map_to_devices(&full_block_tables, device, mapper)?),
269                full_context_lens: Some(map_to_devices(&full_context_lens, device, mapper)?),
270                full_max_context_len: Some(base_len + verify_len),
271                is_first_prompt_chunk: false,
272                paged_kv_indptr: None,
273                paged_kv_indices: None,
274                paged_kv_last_page_len: None,
275                paged_kv_request_indices: None,
276                paged_kv_tile_indices: None,
277                paged_kv_o_indptr: None,
278                paged_kv_chunk_size: None,
279                num_cached_tokens: None,
280                query_lens: None,
281                cu_seqlens_q: None,
282                cu_seqlens_kv: None,
283            }),
284            flash_meta: FlashParams::empty(true),
285        };
286        Ok(metadata)
287    }
288
289    fn proposer_cache(&self, _sequences: &[&Sequence]) -> Result<SpeculativeKvCache<'_>> {
290        Ok(SpeculativeKvCache::Paged {
291            metadata: self.metadata,
292            kv_cache: &self.kv_cache,
293        })
294    }
295}
296
297fn repeated_table_tensor(rows: &[Vec<usize>], max_len: usize, device: &Device) -> Result<Tensor> {
298    let mut values = Vec::with_capacity(rows.len() * max_len);
299    for row in rows {
300        for value in row {
301            values.push(usize_to_u32(*value, "block table entry")?);
302        }
303        values.extend(std::iter::repeat_n(0u32, max_len.saturating_sub(row.len())));
304    }
305    Tensor::from_vec(values, (rows.len(), max_len), device)
306}
307
308fn usize_to_u32(value: usize, name: &str) -> Result<u32> {
309    u32::try_from(value)
310        .map_err(|_| hanzo_ml::Error::Msg(format!("{name} exceeds u32::MAX: {value}")))
311}
312
313fn map_to_devices(
314    tensor: &Tensor,
315    device: &Device,
316    mapper: &dyn DeviceMapper,
317) -> Result<HashMap<hanzo_ml::DeviceLocation, Tensor>> {
318    let mut devices = mapper.get_unique_devices();
319    if !devices
320        .iter()
321        .any(|dev| dev.location() == device.location())
322    {
323        devices.push(device.clone());
324    }
325
326    let mut map = HashMap::new();
327    for dev in devices {
328        map.insert(dev.location(), tensor.to_device(&dev)?);
329    }
330    Ok(map)
331}