pub mod block_hash;
pub mod block_pool;
mod cache_engine;
mod config;
pub mod encoder_cache;
pub mod kv_cache_manager;
mod layers;
mod scheduler;
/// Sentinel slot index used to pad slot mappings; attention kernels skip it.
pub const _PAD_SLOT_ID: i64 = -1;
pub use cache_engine::{CacheConfig, CacheEngine, PagedCacheType};
use candle_core::{DType, Device};
pub use config::{KvCacheLayout, ModelConfigLike, ModelConfigMetadata};
pub use kv_cache_manager::KVCacheManager;
pub use layers::PagedAttention;
pub use scheduler::{
PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput,
};
use crate::MemoryUsage;
use tracing::info;
/// Default number of tokens stored per PagedAttention block.
pub const DEFAULT_PAGED_ATTENTION_BLOCK_SIZE: usize = 32;
/// User-facing configuration for the PagedAttention KV cache.
#[derive(Clone, Copy)]
pub struct PagedAttentionConfig {
    /// Tokens per cache block; `None` selects [`DEFAULT_PAGED_ATTENTION_BLOCK_SIZE`].
    pub(crate) block_size: Option<usize>,
    /// GPU memory budget for the cache.
    pub(crate) mem_gpu: MemoryGpuConfig,
    /// Element type used to store the cached K/V tensors.
    pub(crate) cache_type: PagedCacheType,
}
impl PagedAttentionConfig {
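    /// Build a `PagedAttentionConfig`. Currently infallible; the block size is
    /// validated later, in [`calculate_cache_config`].
    ///
    /// A minimal sketch (the 90% utilization is illustrative, and
    /// `PagedCacheType::Auto` is an assumed variant name):
    ///
    /// ```rust,ignore
    /// let cfg = PagedAttentionConfig::new(
    ///     None,                              // default block size
    ///     MemoryGpuConfig::Utilization(0.9), // target 90% of total GPU memory
    ///     PagedCacheType::Auto,
    /// )?;
    /// ```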
pub fn new(
block_size: Option<usize>,
mem_gpu: MemoryGpuConfig,
cache_type: PagedCacheType,
) -> anyhow::Result<Self> {
Ok(Self {
block_size,
mem_gpu,
cache_type,
})
}
}
/// Which attention backend the engine should use.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AttentionImplementation {
    /// Standard (non-paged) attention.
    Eager,
    /// Block-paged KV caching via PagedAttention.
    PagedAttention,
}
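/// How to size the GPU memory budget for the PagedAttention KV cache.
///
/// A sketch of the three ways to express the budget (all values illustrative):
///
/// ```rust,ignore
/// let by_size = MemoryGpuConfig::MbAmount(4096);      // a fixed 4096 MB budget
/// let by_frac = MemoryGpuConfig::Utilization(0.9);    // 90% of total GPU memory
/// let by_ctxt = MemoryGpuConfig::ContextSize(32_768); // room for 32k cached tokens
/// ```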
#[derive(Clone, Copy)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
pub enum MemoryGpuConfig {
    /// An absolute budget, in MB.
    MbAmount(usize),
    /// A fraction of total GPU memory to target (0.0..=1.0).
    Utilization(f32),
    /// Enough memory to hold this many tokens of KV cache.
    ContextSize(usize),
}
/// Block sizes accepted by `calculate_cache_config`.
const SUPPORTED_BLOCK_SIZE: &[usize] = &[8, 16, 32];
/// Bytes per MB.
const SIZE_IN_MB: usize = 1024 * 1024;
/// Converts a byte budget into a whole number of KV-cache blocks. Despite the
/// parameter name, `$mb_size` is passed in *bytes* at the call site (the MB
/// value is multiplied by `SIZE_IN_MB` first).
macro_rules! mb_to_blocks {
    ($mb_size:expr, $dtype_size:expr, $block_size:expr, $config:expr) => {
        $mb_size
            / $dtype_size
            / $block_size
            / $config.num_layers()
            / $config.kv_cache_elements_per_token()
    };
}
/// Computes the number of *bytes* (not blocks, despite the name) needed to
/// cache `$context_len` tokens. `$block_size` is accepted for symmetry with
/// `mb_to_blocks!` but unused: the byte count is independent of how tokens
/// are grouped into blocks.
macro_rules! ctxt_to_blocks {
    ($context_len:expr, $dtype_size:expr, $block_size:expr, $config:expr) => {
        $context_len * $dtype_size * $config.num_layers() * $config.kv_cache_elements_per_token()
    };
}
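/// Compute the PagedAttention [`CacheConfig`] for the given memory budget.
/// Fails if the block size is not 8, 16, or 32, or if the budget cannot fit
/// even a single block.
///
/// Back-of-envelope block math, with hypothetical model numbers (assuming
/// `kv_cache_elements_per_token()` counts both K and V): 32 layers,
/// 2 * 8 KV heads * 128 head dim = 2048 elements/token, an f16 cache (2 bytes)
/// and block size 32 give 2 * 32 * 32 * 2048 bytes = 4 MB per block, so a
/// 4096 MB budget yields 1024 blocks, i.e. 32_768 cacheable tokens.
///
/// A minimal call sketch (`model_cfg` is any [`ModelConfigLike`] impl;
/// `PagedCacheType::Auto` is an assumed variant name):
///
/// ```rust,ignore
/// let cache_cfg = calculate_cache_config(
///     MemoryGpuConfig::Utilization(0.9),
///     None,                 // default block size
///     DType::F16,
///     PagedCacheType::Auto,
///     &model_cfg,
///     &device,
///     &vec![None; model_cfg.num_layers()], // one entry per layer; None => `device`
///     false,                // print the allocation summary
///     None,                 // model weight size unknown
///     None,                 // cap at the model's max_seq_len on Metal
/// )?;
/// ```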
#[allow(clippy::too_many_arguments)]
pub fn calculate_cache_config(
mem_gpu: MemoryGpuConfig,
block_size: Option<usize>,
dtype: DType,
cache_type: PagedCacheType,
config: &dyn ModelConfigLike,
device: &Device,
layer_devices: &[Option<Device>],
silent: bool,
model_weight_size_in_bytes: Option<usize>,
max_num_tokens: Option<usize>,
) -> anyhow::Result<CacheConfig> {
let block_size = block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE);
if !SUPPORTED_BLOCK_SIZE.contains(&block_size) {
anyhow::bail!("Block size must be in {SUPPORTED_BLOCK_SIZE:?}, got {block_size}");
}
    // The configured cache type may override the activation dtype (e.g. FP8).
    let dtype = cache_type.to_dtype(dtype);
    let dtype_size = dtype.size_in_bytes();
    // Approximate each device's share of the model weights, in MB.
    let num_devices = layer_devices.len().max(1);
    let model_weight_per_device_mb =
        model_weight_size_in_bytes.unwrap_or(0) / num_devices / SIZE_IN_MB;
    let mut min_mem_gpu = usize::MAX;
    for dev in layer_devices {
        // Each layer may be mapped to its own device; fall back to the primary.
        let device = dev.as_ref().unwrap_or(device);
        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        let mem_gpu = match mem_gpu {
            // An explicit budget, in MB, is used as-is.
            MemoryGpuConfig::MbAmount(v) => v,
            MemoryGpuConfig::Utilization(f) => {
                let total = MemoryUsage.get_total_memory(device)? as f32 / SIZE_IN_MB as f32;
                if model_weight_size_in_bytes.is_some() {
                    // Budget a fraction `f` of total memory, minus this
                    // device's share of the model weights.
                    (total * f - model_weight_per_device_mb as f32).max(0.0) as usize
                } else {
                    // No weight size was given: subtract whatever the device
                    // currently has in use instead.
                    let free = MemoryUsage.get_memory_available(device)? as f32 / SIZE_IN_MB as f32;
                    #[allow(unused_mut)]
                    let mut used = total - free;
                    #[cfg(feature = "metal")]
                    if let Device::Metal(dev) = device {
                        // Metal's allocator reports its usage directly.
                        used = dev.current_allocated_size() as f32 / SIZE_IN_MB as f32;
                    }
                    (total * f - used).max(0.0) as usize
                }
            }
            // Bytes needed to cache `toks` tokens, converted to MB.
            MemoryGpuConfig::ContextSize(toks) => {
                ctxt_to_blocks!(toks, dtype_size, block_size, config) / SIZE_IN_MB
            }
        };
        // Keep the smallest (most conservative) budget across devices.
        min_mem_gpu = min_mem_gpu.min(mem_gpu);
    }
    #[allow(unused_mut, unused_variables)]
    let mut mem_gpu = min_mem_gpu;
    // On Metal, never reserve more cache than the most tokens we could ever
    // schedule (`max_num_tokens`, defaulting to the model's max sequence length).
    if device.is_metal() {
        let max_tokens = max_num_tokens.unwrap_or(config.max_seq_len());
let mem_for_tokens =
ctxt_to_blocks!(max_tokens, dtype_size, block_size, config) / SIZE_IN_MB;
if mem_for_tokens < mem_gpu {
if !silent {
info!(
"Metal: capping KV cache from {} MB to {} MB ({} tokens).",
mem_gpu, mem_for_tokens, max_tokens
);
}
mem_gpu = mem_for_tokens;
}
}
    // Convert the final MB budget into whole KV-cache blocks.
    let num_gpu_blocks = mb_to_blocks!(mem_gpu * SIZE_IN_MB, dtype_size, block_size, config);
    if num_gpu_blocks == 0 {
        anyhow::bail!(
            "Num GPU blocks is 0. This means there is not enough memory. \
             Either increase the memory amount/utilization/context size or disable PagedAttention."
        );
    }
if !silent {
info!("Allocating {mem_gpu} MB for PagedAttention KV cache per GPU");
info!("PagedAttention KV cache type is {dtype:?}");
info!("Using PagedAttention with block size {block_size} and {num_gpu_blocks} GPU blocks: available context length is {} tokens", num_gpu_blocks*block_size);
}
Ok(CacheConfig {
block_size,
num_gpu_blocks,
cache_type,
})
}
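
#[cfg(test)]
mod tests {
    //! A sanity check of the MB-to-blocks arithmetic, mirroring
    //! `mb_to_blocks!` with the hypothetical numbers from the
    //! `calculate_cache_config` docs; no `ModelConfigLike` impl needed.

    #[test]
    fn mb_to_blocks_arithmetic() {
        let dtype_size = 2usize; // f16
        let block_size = 32usize;
        let num_layers = 32usize;
        let elements_per_token = 2 * 8 * 128; // K and V: 8 KV heads, head dim 128
        let budget_bytes = 4096 * super::SIZE_IN_MB; // 4096 MB budget
        let blocks = budget_bytes / dtype_size / block_size / num_layers / elements_per_token;
        assert_eq!(blocks, 1024);
        assert_eq!(blocks * block_size, 32_768); // cacheable tokens
    }
}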