hanzo_engine/paged_attention/
mod.rs1pub mod block_hash;
3pub mod block_pool;
5mod cache_engine;
9mod config;
10pub mod encoder_cache;
12pub mod kv_cache_manager;
14mod layers;
15mod scheduler;
/// Sentinel slot id (`-1`). NOTE(review): presumably marks padding positions in a slot
/// mapping — confirm against its users.
pub const _PAD_SLOT_ID: i64 = -1_i64;
17
18pub use cache_engine::{CacheConfig, CacheEngine, PagedCacheType};
19pub use config::{KvCacheLayout, ModelConfigLike, ModelConfigMetadata};
20use hanzo_ml::{DType, Device};
21pub use kv_cache_manager::KVCacheManager;
22pub use layers::PagedAttention;
23pub use scheduler::{
24 PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput,
25};
26
27use crate::MemoryUsage;
28use tracing::info;
29
/// Block size (tokens per KV-cache block) that `calculate_cache_config` falls back to when the
/// caller does not specify one.
pub const DEFAULT_PAGED_ATTENTION_BLOCK_SIZE: usize = 32_usize;
31
32#[derive(Clone, Copy)]
34pub struct PagedAttentionConfig {
35 pub(crate) block_size: Option<usize>,
36 pub(crate) mem_gpu: MemoryGpuConfig,
37 pub(crate) cache_type: PagedCacheType,
38}
39
40impl PagedAttentionConfig {
41 pub fn new(
42 block_size: Option<usize>,
43 mem_gpu: MemoryGpuConfig,
44 cache_type: PagedCacheType,
45 ) -> anyhow::Result<Self> {
46 Ok(Self {
47 block_size,
48 mem_gpu,
49 cache_type,
50 })
51 }
52}
53
/// Which attention implementation to use.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` because the enum is fieldless, so equality
/// is total and the value can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AttentionImplementation {
    /// Non-paged attention.
    Eager,
    /// Attention backed by the paged KV cache from this module.
    PagedAttention,
}
59
/// How much GPU memory to reserve for the PagedAttention KV cache.
///
/// `calculate_cache_config` turns each variant into a per-device budget in MB.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
pub enum MemoryGpuConfig {
    /// A fixed budget, in MB.
    MbAmount(usize),
    /// Fraction of each device's total memory (e.g. `0.9`). Either the per-device share of the
    /// model weights (when their size is known) or the memory already in use is subtracted.
    Utilization(f32),
    /// Enough memory to hold this many tokens of context.
    ContextSize(usize),
}
67
/// Block sizes (tokens per block) accepted by `calculate_cache_config`.
const SUPPORTED_BLOCK_SIZE: &[usize] = &[8, 16, 32];

/// Number of bytes in one MB (2^20), used to convert between bytes and MB budgets.
const SIZE_IN_MB: usize = 1 << 20;
72
/// Convert a KV-cache budget in **bytes** into a whole number of cache blocks.
///
/// One block stores `block_size` tokens for every layer, and each token needs
/// `kv_cache_elements_per_token()` elements of `dtype_size` bytes. The chained integer divisions
/// floor the result, which matches dividing by the product while avoiding overflow of that product.
macro_rules! mb_to_blocks {
    ($size_in_bytes:expr, $dtype_size:expr, $block_size:expr, $config:expr) => {
        $size_in_bytes
            / $dtype_size
            / $block_size
            / $config.num_layers()
            / $config.kv_cache_elements_per_token()
    };
}

/// Number of **bytes** of KV cache needed to hold `context_len` tokens.
///
/// Despite the name this yields bytes, not blocks; callers convert the result to MB. The context
/// is first rounded up to whole `block_size`-token blocks, because the cache is allocated in
/// blocks. Without the rounding, converting the result back with `mb_to_blocks!` could yield
/// fewer blocks than the context needs. `block_size` must be non-zero.
macro_rules! ctxt_to_blocks {
    ($context_len:expr, $dtype_size:expr, $block_size:expr, $config:expr) => {{
        let block_size: usize = $block_size;
        let context_len: usize = $context_len;
        context_len.div_ceil(block_size)
            * block_size
            * $dtype_size
            * $config.num_layers()
            * $config.kv_cache_elements_per_token()
    }};
}
88
89#[allow(clippy::too_many_arguments)]
103pub fn calculate_cache_config(
104 mem_gpu: MemoryGpuConfig,
105 block_size: Option<usize>,
106 dtype: DType,
107 cache_type: PagedCacheType,
108 config: &dyn ModelConfigLike,
109 device: &Device,
110 layer_devices: &[Option<Device>],
111 silent: bool,
112 model_weight_size_in_bytes: Option<usize>,
113 max_num_tokens: Option<usize>,
114) -> anyhow::Result<CacheConfig> {
115 let block_size = block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE);
116 if !SUPPORTED_BLOCK_SIZE.contains(&block_size) {
117 anyhow::bail!("Block size must be in {SUPPORTED_BLOCK_SIZE:?}, got {block_size}");
118 }
119 let dtype = cache_type.to_dtype(dtype);
120 let dtype_size = dtype.size_in_bytes();
121
122 let num_devices = layer_devices.len().max(1);
124 let model_weight_per_device_mb =
125 model_weight_size_in_bytes.unwrap_or(0) / num_devices / SIZE_IN_MB;
126
127 let mut min_mem_gpu = usize::MAX;
128 for dev in layer_devices {
129 let device = dev.as_ref().unwrap_or(device);
130
131 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
132 let mem_gpu = match mem_gpu {
133 MemoryGpuConfig::MbAmount(v) => v,
134 MemoryGpuConfig::Utilization(f) => {
135 let mem = MemoryUsage.query(device)?;
136 let total = mem.total() as f32 / SIZE_IN_MB as f32;
137 if model_weight_size_in_bytes.is_some() {
138 (total * f - model_weight_per_device_mb as f32).max(0.0) as usize
140 } else {
141 let used = (mem.total() - mem.available()) as f32 / SIZE_IN_MB as f32;
142 (total * f - used).max(0.0) as usize
143 }
144 }
145 MemoryGpuConfig::ContextSize(toks) => {
146 ctxt_to_blocks!(toks, dtype_size, block_size, config) / SIZE_IN_MB
148 }
149 };
150 min_mem_gpu = min_mem_gpu.min(mem_gpu);
151 }
152
153 #[allow(unused_mut, unused_variables)]
158 let mut mem_gpu = min_mem_gpu;
159 if device.is_metal() {
160 let max_tokens = max_num_tokens.unwrap_or(config.max_seq_len());
161 let mem_for_tokens =
162 ctxt_to_blocks!(max_tokens, dtype_size, block_size, config) / SIZE_IN_MB;
163 if mem_for_tokens < mem_gpu {
164 if !silent {
165 info!(
166 "Metal: capping KV cache from {} MB to {} MB ({} tokens).",
167 mem_gpu, mem_for_tokens, max_tokens
168 );
169 }
170 mem_gpu = mem_for_tokens;
171 }
172 }
173
174 let num_gpu_blocks = mb_to_blocks!(mem_gpu * SIZE_IN_MB, dtype_size, block_size, config);
175 if num_gpu_blocks == 0 {
176 anyhow::bail!("Num GPU blocks is 0. This means there is not enough memory. Either reduce the memory amount/utilization/context size or disable PagedAttention.");
177 }
178
179 if !silent {
180 info!("Allocating {mem_gpu} MB for PagedAttention KV cache per GPU");
181 info!("PagedAttention KV cache type is {dtype:?}");
182 info!("Using PagedAttention with block size {block_size} and {num_gpu_blocks} GPU blocks: available context length is {} tokens", num_gpu_blocks*block_size);
183 }
184 Ok(CacheConfig {
185 block_size,
186 num_gpu_blocks,
187 cache_type,
188 })
189}