Skip to main content

hanzo_engine/paged_attention/
cache_engine.rs

1use std::{
2    str::FromStr,
3    sync::{Arc, Mutex, MutexGuard},
4};
5
6use hanzo_ml::{DType, Device, Result, Tensor};
7use serde::{Deserialize, Serialize};
8
9use super::config::{KvCacheLayout, ModelConfigLike};
10
11#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
12#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
13pub enum PagedCacheType {
14    #[default]
15    Auto,
16    F8E4M3,
17}
18
19impl PagedCacheType {
20    pub fn to_dtype(&self, act_dtype: DType) -> DType {
21        match self {
22            PagedCacheType::F8E4M3 => DType::F8E4M3,
23            PagedCacheType::Auto => act_dtype,
24        }
25    }
26}
27
28impl FromStr for PagedCacheType {
29    type Err = String;
30    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
31        match s {
32            "auto" => Ok(Self::Auto),
33            "f8e4m3" => Ok(Self::F8E4M3),
34            other => Err(format!(
35                "Unexpected `PagedCacheType`, got `{other}` but expected `auto` and `f8e4m3`."
36            )),
37        }
38    }
39}
40
41#[derive(Clone, Debug)]
42pub struct CacheConfig {
43    pub block_size: usize,
44    pub num_gpu_blocks: usize,
45    pub cache_type: PagedCacheType,
46}
47
48pub type KVCache = (Tensor, Tensor);
49
50pub struct CacheEngine {
51    gpu_cache: Arc<Mutex<Vec<KVCache>>>,
52}
53
54impl CacheEngine {
55    pub fn new(
56        model_config: &dyn ModelConfigLike,
57        cache_config: &CacheConfig,
58        dtype: DType,
59        device: &Device,
60        layer_devices: Vec<Option<Device>>,
61    ) -> Result<Self> {
62        let dtype = cache_config.cache_type.to_dtype(dtype);
63        Ok(Self {
64            gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache(
65                model_config,
66                cache_config,
67                dtype,
68                device,
69                layer_devices,
70            )?)),
71        })
72    }
73
74    pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec<KVCache>> {
75        // Use blocking lock instead of busy-wait spin loop to avoid CPU waste
76        // and potential thread starvation issues
77        self.gpu_cache.lock().expect("KV cache mutex was poisoned")
78    }
79
80    fn allocate_gpu_cache(
81        model_config: &dyn ModelConfigLike,
82        cache_config: &CacheConfig,
83        dtype: DType,
84        device: &Device,
85        layer_devices: Vec<Option<Device>>,
86    ) -> Result<Vec<KVCache>> {
87        let kv_cache_layout = model_config.kv_cache_layout();
88        let mut gpu_cache = Vec::new();
89
90        for (layer_idx, device) in layer_devices
91            .iter()
92            .take(model_config.num_layers())
93            .map(|x| x.as_ref().unwrap_or(device))
94            .enumerate()
95        {
96            let (key_blocks, value_blocks) = match kv_cache_layout {
97                KvCacheLayout::Standard => {
98                    let key_block_shape = Self::calculate_key_block_shape(
99                        model_config,
100                        dtype,
101                        cache_config.block_size,
102                        layer_idx,
103                    );
104                    let value_block_shape = Self::calculate_value_block_shape(
105                        model_config,
106                        cache_config.block_size,
107                        layer_idx,
108                    );
109                    #[allow(unused)]
110                    let key_blocks = if let Device::Metal(dev) = &device {
111                        #[cfg(feature = "metal")]
112                        {
113                            use hanzo_ml::{MetalStorage, Shape, Storage};
114
115                            let elem_count = cache_config.num_gpu_blocks
116                                * key_block_shape.0
117                                * key_block_shape.1
118                                * key_block_shape.2
119                                * key_block_shape.3;
120                            let buffer = dev.new_private_buffer(elem_count, dtype, "k_cache")?;
121                            let storage = Storage::Metal(MetalStorage::new(
122                                buffer,
123                                dev.clone(),
124                                elem_count,
125                                dtype,
126                            ));
127                            Tensor::from((
128                                storage,
129                                Shape::from_dims(&[
130                                    cache_config.num_gpu_blocks,
131                                    key_block_shape.0,
132                                    key_block_shape.1,
133                                    key_block_shape.2,
134                                    key_block_shape.3,
135                                ]),
136                            ))
137                        }
138
139                        #[cfg(not(feature = "metal"))]
140                        {
141                            unreachable!()
142                        }
143                    } else {
144                        unsafe {
145                            Tensor::empty(
146                                (
147                                    cache_config.num_gpu_blocks,
148                                    key_block_shape.0,
149                                    key_block_shape.1,
150                                    key_block_shape.2,
151                                    key_block_shape.3,
152                                ),
153                                dtype,
154                                device,
155                            )?
156                        }
157                    };
158                    #[allow(unused)]
159                    let value_blocks = if let Device::Metal(dev) = &device {
160                        #[cfg(feature = "metal")]
161                        {
162                            use hanzo_ml::{MetalStorage, Shape, Storage};
163
164                            let elem_count = cache_config.num_gpu_blocks
165                                * value_block_shape.0
166                                * value_block_shape.1
167                                * value_block_shape.2;
168                            let buffer = dev.new_private_buffer(elem_count, dtype, "v_cache")?;
169                            let storage = Storage::Metal(MetalStorage::new(
170                                buffer,
171                                dev.clone(),
172                                elem_count,
173                                dtype,
174                            ));
175                            Tensor::from((
176                                storage,
177                                Shape::from_dims(&[
178                                    cache_config.num_gpu_blocks,
179                                    value_block_shape.0,
180                                    value_block_shape.1,
181                                    value_block_shape.2,
182                                ]),
183                            ))
184                        }
185
186                        #[cfg(not(feature = "metal"))]
187                        {
188                            unreachable!()
189                        }
190                    } else {
191                        unsafe {
192                            Tensor::empty(
193                                (
194                                    cache_config.num_gpu_blocks,
195                                    value_block_shape.0,
196                                    value_block_shape.1,
197                                    value_block_shape.2,
198                                ),
199                                dtype,
200                                device,
201                            )?
202                        }
203                    };
204                    (key_blocks, value_blocks)
205                }
206                KvCacheLayout::Mla {
207                    kv_lora_rank,
208                    kpe_head_dim,
209                } => {
210                    #[allow(unused)]
211                    let key_blocks = if let Device::Metal(dev) = &device {
212                        #[cfg(feature = "metal")]
213                        {
214                            use hanzo_ml::{MetalStorage, Shape, Storage};
215
216                            let elem_count = cache_config.num_gpu_blocks
217                                * cache_config.block_size
218                                * kv_lora_rank;
219                            let buffer = dev.new_private_buffer(elem_count, dtype, "k_cache")?;
220                            let storage = Storage::Metal(MetalStorage::new(
221                                buffer,
222                                dev.clone(),
223                                elem_count,
224                                dtype,
225                            ));
226                            Tensor::from((
227                                storage,
228                                Shape::from_dims(&[
229                                    cache_config.num_gpu_blocks,
230                                    cache_config.block_size,
231                                    kv_lora_rank,
232                                ]),
233                            ))
234                        }
235
236                        #[cfg(not(feature = "metal"))]
237                        {
238                            unreachable!()
239                        }
240                    } else {
241                        unsafe {
242                            Tensor::empty(
243                                (
244                                    cache_config.num_gpu_blocks,
245                                    cache_config.block_size,
246                                    kv_lora_rank,
247                                ),
248                                dtype,
249                                device,
250                            )?
251                        }
252                    };
253                    #[allow(unused)]
254                    let value_blocks = if let Device::Metal(dev) = &device {
255                        #[cfg(feature = "metal")]
256                        {
257                            use hanzo_ml::{MetalStorage, Shape, Storage};
258
259                            let elem_count = cache_config.num_gpu_blocks
260                                * cache_config.block_size
261                                * kpe_head_dim;
262                            let buffer = dev.new_private_buffer(elem_count, dtype, "v_cache")?;
263                            let storage = Storage::Metal(MetalStorage::new(
264                                buffer,
265                                dev.clone(),
266                                elem_count,
267                                dtype,
268                            ));
269                            Tensor::from((
270                                storage,
271                                Shape::from_dims(&[
272                                    cache_config.num_gpu_blocks,
273                                    cache_config.block_size,
274                                    kpe_head_dim,
275                                ]),
276                            ))
277                        }
278
279                        #[cfg(not(feature = "metal"))]
280                        {
281                            unreachable!()
282                        }
283                    } else {
284                        unsafe {
285                            Tensor::empty(
286                                (
287                                    cache_config.num_gpu_blocks,
288                                    cache_config.block_size,
289                                    kpe_head_dim,
290                                ),
291                                dtype,
292                                device,
293                            )?
294                        }
295                    };
296                    (key_blocks, value_blocks)
297                }
298            };
299            gpu_cache.push((key_blocks, value_blocks));
300        }
301        Ok(gpu_cache)
302    }
303
304    fn calculate_key_block_shape(
305        model_config: &dyn ModelConfigLike,
306        dtype: DType,
307        block_size: usize,
308        layer_idx: usize,
309    ) -> (usize, usize, usize, usize) {
310        let element_size = dtype.size_in_bytes();
311        let x = 16 / element_size;
312        (
313            model_config.num_kv_heads_for_layer(layer_idx),
314            model_config.k_head_dim_for_layer(layer_idx) / x,
315            block_size,
316            x,
317        )
318    }
319
320    fn calculate_value_block_shape(
321        model_config: &dyn ModelConfigLike,
322        block_size: usize,
323        layer_idx: usize,
324    ) -> (usize, usize, usize) {
325        (
326            model_config.num_kv_heads_for_layer(layer_idx),
327            model_config.v_head_dim_for_layer(layer_idx),
328            block_size,
329        )
330    }
331}