mistralrs_core/paged_attention/mod.rs

/// The higher-level manager of the allocated blocks. Operations performed by the block engine
/// do not directly change memory.
mod block_engine;
mod block_engine_sequence;
/// The lower-level manager of the cache. It handles swapping and copying blocks and actually
/// allocates the KV cache for the CPU and GPU. The LLMEngine uses it to execute operations
/// issued by the scheduler.
mod cache_engine;
mod config;
mod layers;
/// Prefix caching for KV cache reuse across requests with shared prefixes.
mod prefix_cacher;
mod scheduler;
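/// Sentinel value used to pad slot mappings.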
pub const _PAD_SLOT_ID: i64 = -1;

pub use block_engine::{BlockEngine, BlockRef, BlockTables, LogicalTokenBlock};
pub use block_engine_sequence::BlockEngineSequence;
pub use cache_engine::{CacheConfig, CacheEngine, PagedCacheType};
use candle_core::{DType, Device};
pub use config::{KvCacheLayout, ModelConfigLike, ModelConfigMetadata};
pub use layers::PagedAttention;
pub use scheduler::{
    PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput,
};

use crate::MemoryUsage;
use tracing::{info, warn};

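/// Default number of tokens stored per KV-cache block.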
pub const DEFAULT_PAGED_ATTENTION_BLOCK_SIZE: usize = 32;

/// All memory amounts are in MB. The default block size is 32.
#[derive(Clone, Copy)]
pub struct PagedAttentionConfig {
    pub(crate) block_size: Option<usize>,
    pub(crate) mem_gpu: MemoryGpuConfig,
    pub(crate) cache_type: PagedCacheType,
}

impl PagedAttentionConfig {
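    /// A hedged construction sketch: dedicate 90% of the device's total memory to the
    /// cache with the default block size (`cache_type` stands for any `PagedCacheType`
    /// value the caller already has):
    /// ```ignore
    /// let cfg = PagedAttentionConfig::new(None, MemoryGpuConfig::Utilization(0.9), cache_type)?;
    /// ```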
    pub fn new(
        block_size: Option<usize>,
        mem_gpu: MemoryGpuConfig,
        cache_type: PagedCacheType,
    ) -> anyhow::Result<Self> {
        Ok(Self {
            block_size,
            mem_gpu,
            cache_type,
        })
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AttentionImplementation {
    Eager,
    PagedAttention,
}

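/// How much device memory to dedicate to the PagedAttention KV cache: an absolute amount
/// in MB, a fraction of the device's total memory, or enough to hold a given context
/// length in tokens.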
#[derive(Clone, Copy)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
pub enum MemoryGpuConfig {
    MbAmount(usize),
    Utilization(f32),
    ContextSize(usize),
}

// See `pagedattention.cu` CALL_V1_LAUNCHER_BLOCK_SIZE
const SUPPORTED_BLOCK_SIZE: &[usize] = &[8, 16, 32];

const SIZE_IN_MB: usize = 1024 * 1024;

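// Converts a memory budget in bytes (the call site passes MB * SIZE_IN_MB) into a number of
// KV-cache blocks: divide out the bytes per element, the tokens per block, the layer count,
// and the KV-cache elements per token per layer.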
macro_rules! mb_to_blocks {
    ($mb_size:expr, $dtype_size:expr, $block_size:expr, $config:expr) => {
        $mb_size
            / $dtype_size
            / $block_size
            / $config.num_layers()
            / $config.kv_cache_elements_per_token()
    };
}

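// Despite its name, this expands to the KV-cache size in *bytes* for `$context_len` tokens
// (`$block_size` is accepted but unused); the call site divides by `SIZE_IN_MB` to get MB.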
macro_rules! ctxt_to_blocks {
    ($context_len:expr, $dtype_size:expr, $block_size:expr, $config:expr) => {
        $context_len * $dtype_size * $config.num_layers() * $config.kv_cache_elements_per_token()
    };
}
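
// Worked example with assumed numbers (not taken from any real model config): for 32 layers,
// 2048 KV-cache elements per token per layer (e.g. 8 KV heads of head-dim 128, keys plus
// values), an f16 cache (2 bytes/element), and block_size 32, one block costs
// 2 * 32 * 32 * 2048 bytes = 4 MiB, so a 4096 MiB budget yields
// mb_to_blocks!(4096 * SIZE_IN_MB, 2, 32, cfg) = 1024 blocks, i.e. room for 32768 tokens.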

/// Memory values are in MB, or a fraction of total memory in [0, 1]. Specify a block size,
/// or the default of 32 is used.
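///
/// A hedged call sketch; `cfg` is any `ModelConfigLike` and `cache_type` any `PagedCacheType`
/// value (the single-device `layer_devices` layout below is an assumption):
/// ```ignore
/// let device = Device::new_cuda(0)?;
/// let cache_cfg = calculate_cache_config(
///     MemoryGpuConfig::Utilization(0.9), // target 90% of total device memory
///     None,                              // use the default block size of 32
///     DType::BF16,
///     cache_type,
///     &cfg,
///     &device,
///     &[None], // one entry per layer; `None` falls back to `device`
///     false,   // not silent: log the chosen configuration
/// )?;
/// ```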
#[allow(clippy::too_many_arguments)]
pub fn calculate_cache_config(
    mem_gpu: MemoryGpuConfig,
    block_size: Option<usize>,
    dtype: DType,
    cache_type: PagedCacheType,
    config: &dyn ModelConfigLike,
    device: &Device,
    layer_devices: &[Option<Device>],
    silent: bool,
) -> anyhow::Result<CacheConfig> {
    let block_size = block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE);
    if !SUPPORTED_BLOCK_SIZE.contains(&block_size) {
        anyhow::bail!("Block size must be in {SUPPORTED_BLOCK_SIZE:?}, got {block_size}");
    }
    let dtype = cache_type.to_dtype(dtype);
    let dtype_size = dtype.size_in_bytes();

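    // Resolve the requested budget on each layer's device and keep the minimum across
    // devices, so the chosen allocation fits on every device that holds cache layers.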
    let mut min_mem_gpu = usize::MAX;
    for dev in layer_devices {
        let device = dev.as_ref().unwrap_or(device);

        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        let mem_gpu = match mem_gpu {
            MemoryGpuConfig::MbAmount(v) => v,
            MemoryGpuConfig::Utilization(f) => {
                // Allocate up to the fraction `f` of *total* memory, net of what is already used.
                let free = MemoryUsage.get_memory_available(device)? as f32 / SIZE_IN_MB as f32;
                let total = MemoryUsage.get_total_memory(device)? as f32 / SIZE_IN_MB as f32;
                let used = total - free;
                (total * f - used) as usize
            }
            MemoryGpuConfig::ContextSize(toks) => {
                // Bytes needed to cache `toks` tokens, converted to MB.
                ctxt_to_blocks!(toks, dtype_size, block_size, config) / SIZE_IN_MB
            }
        };
        min_mem_gpu = min_mem_gpu.min(mem_gpu);
    }

    // // Cap at kv cache for max seq len
    // let mem_for_toks =
    //     ctxt_to_blocks!(config.max_seq_len(), dtype_size, block_size, config) / SIZE_IN_MB;
    // let mem_gpu = min_mem_gpu.min(mem_for_toks);

    // Cap Metal GPU memory to the wired (non-paged) allocation limit reported by the kernel
    // (`iogpu.wired_limit_mb`). Users can raise this limit with
    // `sudo sysctl -w iogpu.wired_limit_mb=<desired_mb>`.
    let mem_gpu = if matches!(device, Device::Metal(_)) {
        let metal_cap_mb = MemoryUsage.get_total_memory(device)? / SIZE_IN_MB;

        info!("Metal GPU wired limit is {metal_cap_mb} MB.");

        if min_mem_gpu > metal_cap_mb {
            if !silent {
                warn!(
                    "Capping Metal GPU memory allocation from {} MB to {} MB (limited by iogpu.wired_limit_mb). \
To raise this cap run: `sudo sysctl -w iogpu.wired_limit_mb=<desired_mb>`.",
                    min_mem_gpu,
                    metal_cap_mb
                );
            }
            metal_cap_mb
        } else {
            min_mem_gpu
        }
    } else {
        min_mem_gpu
    };

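    // Integer division: the MB budget is converted into whole KV-cache blocks, and any
    // remainder that cannot fill a complete block is left unused.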
    let num_gpu_blocks = mb_to_blocks!(mem_gpu * SIZE_IN_MB, dtype_size, block_size, config);
    if num_gpu_blocks == 0 {
        anyhow::bail!("Num GPU blocks is 0. This means there is not enough memory. Either reduce the memory amount/utilization/context size or disable PagedAttention.");
    }

    if !silent {
        info!("Allocating {mem_gpu} MB for PagedAttention KV cache per GPU");
        info!("PagedAttention KV cache type is {dtype:?}");
        info!(
            "Using PagedAttention with block size {block_size} and {num_gpu_blocks} GPU blocks: available context length is {} tokens",
            num_gpu_blocks * block_size
        );
    }
    Ok(CacheConfig {
        block_size,
        num_gpu_blocks,
        cache_type,
    })
}