hanzo_engine/speculative/
cache.rs1use std::collections::HashMap;
2
3use hanzo_ml::{Device, Result, Tensor};
4
5use crate::device_map::DeviceMapper;
6use crate::paged_attention::CacheEngine;
7use crate::pipeline::text_models_inputs_processor::{
8 FlashParams, InputMetadata, PagedAttentionInputMetadata, PagedAttentionMeta,
9};
10use crate::sequence::Sequence;
11
12use super::proposer::SpeculativeKvCache;
13
/// Result of verifying one sequence's speculative tokens.
///
/// Consumed by [`SpeculativeCacheAccess::finish_verification_batch`], which
/// forwards both fields to [`SpeculativeCacheAccess::finish_verification`].
#[derive(Debug, Clone, Copy)]
pub struct SpeculativeCacheOutcome {
    /// Length handed to [`SpeculativeCacheGuard::rollback_to`] when not every
    /// speculative token was accepted.
    pub keep_len: usize,
    /// `true` selects [`SpeculativeCacheGuard::commit`]; `false` selects a
    /// rollback to `keep_len`.
    pub accepted_all: bool,
}
19
20pub trait SpeculativeCacheGuard {
21 fn commit(&mut self) -> Result<()>;
22 fn rollback_to(&mut self, keep_len: usize) -> Result<()>;
23}
24
25pub trait SpeculativeCacheAccess {
26 type Guard: SpeculativeCacheGuard;
27
28 fn begin(
31 &self,
32 seq_id: usize,
33 base_len: usize,
34 verify_len: usize,
35 ) -> Result<Option<Self::Guard>>;
36
37 fn guard_for_reserved(&self, seq_id: usize, base_len: usize, verify_len: usize) -> Self::Guard;
38
39 fn make_verify_input_metadata(
40 &self,
41 verify_tokens: &[u32],
42 seq_id: usize,
43 base_len: usize,
44 device: &Device,
45 mapper: &dyn DeviceMapper,
46 ) -> Result<InputMetadata>;
47
48 fn proposer_cache(&self, sequences: &[&Sequence]) -> Result<SpeculativeKvCache<'_>>;
49
50 fn finish_verification(
51 &self,
52 guard: &mut Self::Guard,
53 _seq: &mut Sequence,
54 keep_len: usize,
55 accepted_all: bool,
56 ) -> Result<()> {
57 if accepted_all {
58 guard.commit()
59 } else {
60 guard.rollback_to(keep_len)
61 }
62 }
63
64 fn finish_verification_batch(
65 &self,
66 guards: &mut [Option<Self::Guard>],
67 seqs: &mut [&mut Sequence],
68 outcomes: &[Option<SpeculativeCacheOutcome>],
69 ) -> Result<()> {
70 if guards.len() != seqs.len() || outcomes.len() != seqs.len() {
71 hanzo_ml::bail!(
72 "speculative cache batch shape mismatch: guards={}, seqs={}, outcomes={}",
73 guards.len(),
74 seqs.len(),
75 outcomes.len()
76 );
77 }
78 for ((guard, seq), outcome) in guards.iter_mut().zip(seqs.iter_mut()).zip(outcomes) {
79 let (Some(guard), Some(outcome)) = (guard.as_mut(), outcome) else {
80 continue;
81 };
82 self.finish_verification(guard, seq, outcome.keep_len, outcome.accepted_all)?;
83 }
84 Ok(())
85 }
86
87 fn can_stage_proposal(
88 &self,
89 _sequences: &[&Sequence],
90 _base_lens: &[usize],
91 _proposal_len: usize,
92 ) -> bool {
93 true
94 }
95}
96
97pub struct PagedSpeculativeCacheAccess<'a> {
98 metadata: &'a PagedAttentionMeta,
99 kv_cache: Vec<(Tensor, Tensor)>,
100}
101
102impl<'a> PagedSpeculativeCacheAccess<'a> {
103 pub fn new(metadata: &'a PagedAttentionMeta, cache_engine: &CacheEngine) -> Self {
104 Self {
105 metadata,
106 kv_cache: cache_engine.get_kv_cache().clone(),
107 }
108 }
109}
110
111pub struct PagedSpeculativeCacheGuard<'a> {
112 metadata: &'a PagedAttentionMeta,
113 seq_id: usize,
114 reserved_len: usize,
115}
116
117impl SpeculativeCacheGuard for PagedSpeculativeCacheGuard<'_> {
118 fn commit(&mut self) -> Result<()> {
119 Ok(())
120 }
121
122 fn rollback_to(&mut self, keep_len: usize) -> Result<()> {
123 if keep_len < self.reserved_len {
124 let mut kv_mgr = crate::get_mut_arcmutex!(self.metadata.kv_cache_manager);
125 kv_mgr.trim_request_to_num_tokens(self.seq_id, keep_len);
126 }
127 Ok(())
128 }
129}
130
131impl<'a> SpeculativeCacheAccess for PagedSpeculativeCacheAccess<'a> {
132 type Guard = PagedSpeculativeCacheGuard<'a>;
133
134 fn begin(
135 &self,
136 seq_id: usize,
137 base_len: usize,
138 verify_len: usize,
139 ) -> Result<Option<Self::Guard>> {
140 let reserved_len = base_len + verify_len;
141 let mut kv_mgr = crate::get_mut_arcmutex!(self.metadata.kv_cache_manager);
142 let Some(_) = kv_mgr.allocate_slots(seq_id, reserved_len, &[]) else {
143 return Ok(None);
144 };
145 Ok(Some(PagedSpeculativeCacheGuard {
146 metadata: self.metadata,
147 seq_id,
148 reserved_len,
149 }))
150 }
151
152 fn guard_for_reserved(
153 &self,
154 seq_id: usize,
155 base_len: usize,
156 verify_len: usize,
157 ) -> PagedSpeculativeCacheGuard<'a> {
158 PagedSpeculativeCacheGuard {
159 metadata: self.metadata,
160 seq_id,
161 reserved_len: base_len + verify_len,
162 }
163 }
164
165 fn make_verify_input_metadata(
166 &self,
167 verify_tokens: &[u32],
168 seq_id: usize,
169 base_len: usize,
170 device: &Device,
171 mapper: &dyn DeviceMapper,
172 ) -> Result<InputMetadata> {
173 let verify_len = verify_tokens.len();
174 if verify_len == 0 {
175 hanzo_ml::bail!("speculative verification requires at least one token.");
176 }
177
178 let kv_mgr = crate::get_mut_arcmutex!(self.metadata.kv_cache_manager);
179 let full_table = kv_mgr
180 .get_block_ids(seq_id)
181 .ok_or_else(|| {
182 hanzo_ml::Error::Msg(format!(
183 "speculative sequence {seq_id} has no paged-attention blocks"
184 ))
185 })?
186 .to_vec();
187 drop(kv_mgr);
188
189 let mut slot_mappings = Vec::with_capacity(verify_len);
190 let mut block_tables = Vec::with_capacity(verify_len);
191 let mut context_lens = Vec::with_capacity(verify_len);
192 let mut full_block_tables = Vec::with_capacity(verify_len);
193 let mut full_context_lens = Vec::with_capacity(verify_len);
194
195 for row in 0..verify_len {
196 let token_pos = base_len + row;
197 let full_context_len = token_pos + 1;
198 let block_number = full_table
199 .get(token_pos / self.metadata.block_size)
200 .copied()
201 .ok_or_else(|| {
202 hanzo_ml::Error::Msg(format!(
203 "speculative verification block table is too small: token_pos={token_pos}, block_size={}, table_len={}",
204 self.metadata.block_size,
205 full_table.len()
206 ))
207 })?;
208 let slot = block_number
209 .checked_mul(self.metadata.block_size)
210 .and_then(|v| v.checked_add(token_pos % self.metadata.block_size))
211 .ok_or_else(|| {
212 hanzo_ml::Error::Msg("speculative verification slot overflowed".to_string())
213 })?;
214 slot_mappings.push(slot as i64);
215
216 full_block_tables.push(full_table.clone());
217 full_context_lens.push(usize_to_u32(full_context_len, "full context length")?);
218
219 if let Some(sliding_window) = self.metadata.sliding_window {
220 let window_start = full_context_len.saturating_sub(sliding_window);
221 let slide_idx = window_start / self.metadata.block_size;
222 let block_aligned_start = slide_idx * self.metadata.block_size;
223 let context_len = full_context_len.saturating_sub(block_aligned_start);
224 let needed_blocks = context_len.div_ceil(self.metadata.block_size);
225 let slide_end = (slide_idx + needed_blocks).min(full_table.len());
226 block_tables.push(full_table.get(slide_idx..slide_end).unwrap_or(&[]).to_vec());
227 context_lens.push(usize_to_u32(context_len, "context length")?);
228 } else {
229 block_tables.push(full_table.clone());
230 context_lens.push(usize_to_u32(full_context_len, "context length")?);
231 }
232 }
233
234 let cpu = Device::Cpu;
235 let input = Tensor::from_vec(verify_tokens.to_vec(), (1, verify_len), device)?;
236 let slot_mappings = Tensor::from_vec(slot_mappings, (1, verify_len), &cpu)?;
237
238 let max_block_table_len = block_tables.iter().map(Vec::len).max().unwrap_or(1).max(1);
239 let block_tables = repeated_table_tensor(&block_tables, max_block_table_len, &cpu)?;
240 let context_lens = Tensor::from_vec(context_lens, (verify_len,), &cpu)?;
241
242 let full_max_block_table_len = full_block_tables
243 .iter()
244 .map(Vec::len)
245 .max()
246 .unwrap_or(1)
247 .max(1);
248 let full_block_tables =
249 repeated_table_tensor(&full_block_tables, full_max_block_table_len, &cpu)?;
250 let full_context_lens = Tensor::from_vec(full_context_lens, (verify_len,), &cpu)?;
251
252 let metadata = InputMetadata {
253 input,
254 positions: vec![base_len],
255 context_lens: vec![(0, verify_len)],
256 position_ids: vec![base_len + verify_len],
257 paged_attn_meta: Some(PagedAttentionInputMetadata {
258 block_tables: Some(map_to_devices(&block_tables, device, mapper)?),
259 context_lens: Some(map_to_devices(&context_lens, device, mapper)?),
260 slot_mappings: map_to_devices(&slot_mappings, device, mapper)?,
261 max_context_len: Some(
262 context_lens
263 .to_vec1::<u32>()?
264 .into_iter()
265 .max()
266 .unwrap_or(0) as usize,
267 ),
268 full_block_tables: Some(map_to_devices(&full_block_tables, device, mapper)?),
269 full_context_lens: Some(map_to_devices(&full_context_lens, device, mapper)?),
270 full_max_context_len: Some(base_len + verify_len),
271 is_first_prompt_chunk: false,
272 paged_kv_indptr: None,
273 paged_kv_indices: None,
274 paged_kv_last_page_len: None,
275 paged_kv_request_indices: None,
276 paged_kv_tile_indices: None,
277 paged_kv_o_indptr: None,
278 paged_kv_chunk_size: None,
279 num_cached_tokens: None,
280 query_lens: None,
281 cu_seqlens_q: None,
282 cu_seqlens_kv: None,
283 }),
284 flash_meta: FlashParams::empty(true),
285 };
286 Ok(metadata)
287 }
288
289 fn proposer_cache(&self, _sequences: &[&Sequence]) -> Result<SpeculativeKvCache<'_>> {
290 Ok(SpeculativeKvCache::Paged {
291 metadata: self.metadata,
292 kv_cache: &self.kv_cache,
293 })
294 }
295}
296
297fn repeated_table_tensor(rows: &[Vec<usize>], max_len: usize, device: &Device) -> Result<Tensor> {
298 let mut values = Vec::with_capacity(rows.len() * max_len);
299 for row in rows {
300 for value in row {
301 values.push(usize_to_u32(*value, "block table entry")?);
302 }
303 values.extend(std::iter::repeat_n(0u32, max_len.saturating_sub(row.len())));
304 }
305 Tensor::from_vec(values, (rows.len(), max_len), device)
306}
307
308fn usize_to_u32(value: usize, name: &str) -> Result<u32> {
309 u32::try_from(value)
310 .map_err(|_| hanzo_ml::Error::Msg(format!("{name} exceeds u32::MAX: {value}")))
311}
312
313fn map_to_devices(
314 tensor: &Tensor,
315 device: &Device,
316 mapper: &dyn DeviceMapper,
317) -> Result<HashMap<hanzo_ml::DeviceLocation, Tensor>> {
318 let mut devices = mapper.get_unique_devices();
319 if !devices
320 .iter()
321 .any(|dev| dev.location() == device.location())
322 {
323 devices.push(device.clone());
324 }
325
326 let mut map = HashMap::new();
327 for dev in devices {
328 map.insert(dev.location(), tensor.to_device(&dev)?);
329 }
330 Ok(map)
331}