hanzo_engine/paged_attention/
cache_engine.rs1use std::{
2 str::FromStr,
3 sync::{Arc, Mutex, MutexGuard},
4};
5
6use hanzo_ml::{DType, Device, Result, Tensor};
7use serde::{Deserialize, Serialize};
8
9use super::config::{KvCacheLayout, ModelConfigLike};
10
11#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
12#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
13pub enum PagedCacheType {
14 #[default]
15 Auto,
16 F8E4M3,
17}
18
19impl PagedCacheType {
20 pub fn to_dtype(&self, act_dtype: DType) -> DType {
21 match self {
22 PagedCacheType::F8E4M3 => DType::F8E4M3,
23 PagedCacheType::Auto => act_dtype,
24 }
25 }
26}
27
28impl FromStr for PagedCacheType {
29 type Err = String;
30 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
31 match s {
32 "auto" => Ok(Self::Auto),
33 "f8e4m3" => Ok(Self::F8E4M3),
34 other => Err(format!(
35 "Unexpected `PagedCacheType`, got `{other}` but expected `auto` and `f8e4m3`."
36 )),
37 }
38 }
39}
40
41#[derive(Clone, Debug)]
42pub struct CacheConfig {
43 pub block_size: usize,
44 pub num_gpu_blocks: usize,
45 pub cache_type: PagedCacheType,
46}
47
48pub type KVCache = (Tensor, Tensor);
49
50pub struct CacheEngine {
51 gpu_cache: Arc<Mutex<Vec<KVCache>>>,
52}
53
54impl CacheEngine {
55 pub fn new(
56 model_config: &dyn ModelConfigLike,
57 cache_config: &CacheConfig,
58 dtype: DType,
59 device: &Device,
60 layer_devices: Vec<Option<Device>>,
61 ) -> Result<Self> {
62 let dtype = cache_config.cache_type.to_dtype(dtype);
63 Ok(Self {
64 gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache(
65 model_config,
66 cache_config,
67 dtype,
68 device,
69 layer_devices,
70 )?)),
71 })
72 }
73
74 pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec<KVCache>> {
75 self.gpu_cache.lock().expect("KV cache mutex was poisoned")
78 }
79
80 fn allocate_gpu_cache(
81 model_config: &dyn ModelConfigLike,
82 cache_config: &CacheConfig,
83 dtype: DType,
84 device: &Device,
85 layer_devices: Vec<Option<Device>>,
86 ) -> Result<Vec<KVCache>> {
87 let kv_cache_layout = model_config.kv_cache_layout();
88 let mut gpu_cache = Vec::new();
89
90 for (layer_idx, device) in layer_devices
91 .iter()
92 .take(model_config.num_layers())
93 .map(|x| x.as_ref().unwrap_or(device))
94 .enumerate()
95 {
96 let (key_blocks, value_blocks) = match kv_cache_layout {
97 KvCacheLayout::Standard => {
98 let key_block_shape = Self::calculate_key_block_shape(
99 model_config,
100 dtype,
101 cache_config.block_size,
102 layer_idx,
103 );
104 let value_block_shape = Self::calculate_value_block_shape(
105 model_config,
106 cache_config.block_size,
107 layer_idx,
108 );
109 #[allow(unused)]
110 let key_blocks = if let Device::Metal(dev) = &device {
111 #[cfg(feature = "metal")]
112 {
113 use hanzo_ml::{MetalStorage, Shape, Storage};
114
115 let elem_count = cache_config.num_gpu_blocks
116 * key_block_shape.0
117 * key_block_shape.1
118 * key_block_shape.2
119 * key_block_shape.3;
120 let buffer = dev.new_private_buffer(elem_count, dtype, "k_cache")?;
121 let storage = Storage::Metal(MetalStorage::new(
122 buffer,
123 dev.clone(),
124 elem_count,
125 dtype,
126 ));
127 Tensor::from((
128 storage,
129 Shape::from_dims(&[
130 cache_config.num_gpu_blocks,
131 key_block_shape.0,
132 key_block_shape.1,
133 key_block_shape.2,
134 key_block_shape.3,
135 ]),
136 ))
137 }
138
139 #[cfg(not(feature = "metal"))]
140 {
141 unreachable!()
142 }
143 } else {
144 unsafe {
145 Tensor::empty(
146 (
147 cache_config.num_gpu_blocks,
148 key_block_shape.0,
149 key_block_shape.1,
150 key_block_shape.2,
151 key_block_shape.3,
152 ),
153 dtype,
154 device,
155 )?
156 }
157 };
158 #[allow(unused)]
159 let value_blocks = if let Device::Metal(dev) = &device {
160 #[cfg(feature = "metal")]
161 {
162 use hanzo_ml::{MetalStorage, Shape, Storage};
163
164 let elem_count = cache_config.num_gpu_blocks
165 * value_block_shape.0
166 * value_block_shape.1
167 * value_block_shape.2;
168 let buffer = dev.new_private_buffer(elem_count, dtype, "v_cache")?;
169 let storage = Storage::Metal(MetalStorage::new(
170 buffer,
171 dev.clone(),
172 elem_count,
173 dtype,
174 ));
175 Tensor::from((
176 storage,
177 Shape::from_dims(&[
178 cache_config.num_gpu_blocks,
179 value_block_shape.0,
180 value_block_shape.1,
181 value_block_shape.2,
182 ]),
183 ))
184 }
185
186 #[cfg(not(feature = "metal"))]
187 {
188 unreachable!()
189 }
190 } else {
191 unsafe {
192 Tensor::empty(
193 (
194 cache_config.num_gpu_blocks,
195 value_block_shape.0,
196 value_block_shape.1,
197 value_block_shape.2,
198 ),
199 dtype,
200 device,
201 )?
202 }
203 };
204 (key_blocks, value_blocks)
205 }
206 KvCacheLayout::Mla {
207 kv_lora_rank,
208 kpe_head_dim,
209 } => {
210 #[allow(unused)]
211 let key_blocks = if let Device::Metal(dev) = &device {
212 #[cfg(feature = "metal")]
213 {
214 use hanzo_ml::{MetalStorage, Shape, Storage};
215
216 let elem_count = cache_config.num_gpu_blocks
217 * cache_config.block_size
218 * kv_lora_rank;
219 let buffer = dev.new_private_buffer(elem_count, dtype, "k_cache")?;
220 let storage = Storage::Metal(MetalStorage::new(
221 buffer,
222 dev.clone(),
223 elem_count,
224 dtype,
225 ));
226 Tensor::from((
227 storage,
228 Shape::from_dims(&[
229 cache_config.num_gpu_blocks,
230 cache_config.block_size,
231 kv_lora_rank,
232 ]),
233 ))
234 }
235
236 #[cfg(not(feature = "metal"))]
237 {
238 unreachable!()
239 }
240 } else {
241 unsafe {
242 Tensor::empty(
243 (
244 cache_config.num_gpu_blocks,
245 cache_config.block_size,
246 kv_lora_rank,
247 ),
248 dtype,
249 device,
250 )?
251 }
252 };
253 #[allow(unused)]
254 let value_blocks = if let Device::Metal(dev) = &device {
255 #[cfg(feature = "metal")]
256 {
257 use hanzo_ml::{MetalStorage, Shape, Storage};
258
259 let elem_count = cache_config.num_gpu_blocks
260 * cache_config.block_size
261 * kpe_head_dim;
262 let buffer = dev.new_private_buffer(elem_count, dtype, "v_cache")?;
263 let storage = Storage::Metal(MetalStorage::new(
264 buffer,
265 dev.clone(),
266 elem_count,
267 dtype,
268 ));
269 Tensor::from((
270 storage,
271 Shape::from_dims(&[
272 cache_config.num_gpu_blocks,
273 cache_config.block_size,
274 kpe_head_dim,
275 ]),
276 ))
277 }
278
279 #[cfg(not(feature = "metal"))]
280 {
281 unreachable!()
282 }
283 } else {
284 unsafe {
285 Tensor::empty(
286 (
287 cache_config.num_gpu_blocks,
288 cache_config.block_size,
289 kpe_head_dim,
290 ),
291 dtype,
292 device,
293 )?
294 }
295 };
296 (key_blocks, value_blocks)
297 }
298 };
299 gpu_cache.push((key_blocks, value_blocks));
300 }
301 Ok(gpu_cache)
302 }
303
304 fn calculate_key_block_shape(
305 model_config: &dyn ModelConfigLike,
306 dtype: DType,
307 block_size: usize,
308 layer_idx: usize,
309 ) -> (usize, usize, usize, usize) {
310 let element_size = dtype.size_in_bytes();
311 let x = 16 / element_size;
312 (
313 model_config.num_kv_heads_for_layer(layer_idx),
314 model_config.k_head_dim_for_layer(layer_idx) / x,
315 block_size,
316 x,
317 )
318 }
319
320 fn calculate_value_block_shape(
321 model_config: &dyn ModelConfigLike,
322 block_size: usize,
323 layer_idx: usize,
324 ) -> (usize, usize, usize) {
325 (
326 model_config.num_kv_heads_for_layer(layer_idx),
327 model_config.v_head_dim_for_layer(layer_idx),
328 block_size,
329 )
330 }
331}