use candle_core::{DType, Device, Result, Tensor};
use super::cache_err;
use super::config::{BITS_PER_BYTE, QUANT_BLOCK_SIZE};
/// Borrowed quantized K/V tensors for one append into `CompressedStorage`.
///
/// These are written into the storage buffers along dim 1 via `slice_set`,
/// so each tensor must match the destination buffer on every other dim.
/// NOTE(review): destination layouts are (heads, capacity, packed_dim) u8 for
/// codes and (heads, capacity, num_blocks) f16 for scales — the sources here
/// presumably share that layout with a seq-length dim 1; confirm against the
/// quantizer that produces them.
pub struct QuantizedKV<'a> {
/// Packed quantized key codes.
pub k_indices: &'a Tensor,
/// Per-block key dequantization scales.
pub k_scales: &'a Tensor,
/// Packed quantized value codes.
pub v_indices: &'a Tensor,
/// Per-block value dequantization scales.
pub v_scales: &'a Tensor,
}
/// Per-layer GPU-resident storage for a bit-packed, block-quantized KV cache.
///
/// Each layer lazily owns four tensors (key/value codes plus their per-block
/// scales). Dim 1 of every tensor is the sequence/capacity axis.
pub struct CompressedStorage {
/// Number of key/value attention heads.
pub(crate) num_kv_heads: usize,
/// Elements per head vector (pre-quantization).
pub(crate) head_dim: usize,
/// Quantization width in bits per element.
pub(crate) bits: u8,
/// Layer count; every Vec field below has this length.
num_layers: usize,
/// Occupied sequence positions per layer (not the allocated capacity).
buf_seq_len: Vec<usize>,
/// Packed key codes per layer: (heads, capacity, packed_dim) u8; None until allocated.
gpu_k_indices: Vec<Option<Tensor>>,
/// Packed value codes per layer; same layout as `gpu_k_indices`.
gpu_v_indices: Vec<Option<Tensor>>,
/// Key scales per layer: (heads, capacity, num_blocks) f16; None until allocated.
gpu_k_scales: Vec<Option<Tensor>>,
/// Value scales per layer; same layout as `gpu_k_scales`.
gpu_v_scales: Vec<Option<Tensor>>,
/// Whether the GPU path has ever appended into this layer.
gpu_path_active: Vec<bool>,
}
impl CompressedStorage {
    /// Creates storage for `num_layers` layers with every per-layer buffer
    /// unallocated; buffers are created lazily by `ensure_capacity`.
    pub fn new(num_kv_heads: usize, head_dim: usize, bits: u8, num_layers: usize) -> Self {
        Self {
            num_kv_heads,
            head_dim,
            bits,
            num_layers,
            buf_seq_len: vec![0; num_layers],
            gpu_k_indices: vec![None; num_layers],
            gpu_v_indices: vec![None; num_layers],
            gpu_k_scales: vec![None; num_layers],
            gpu_v_scales: vec![None; num_layers],
            gpu_path_active: vec![false; num_layers],
        }
    }

    /// Number of sequence positions currently stored for `layer`.
    pub fn seq_len(&self, layer: usize) -> usize {
        self.buf_seq_len[layer]
    }

    /// True once `layer` has received at least one `append` and still holds data.
    pub fn is_active(&self, layer: usize) -> bool {
        self.gpu_path_active[layer] && self.buf_seq_len[layer] > 0
    }

    /// Bytes per packed head vector: `head_dim` values at `bits` bits each.
    pub fn packed_dim(&self) -> usize {
        self.head_dim * self.bits as usize / BITS_PER_BYTE
    }

    /// Quantization blocks per head vector (one scale is stored per block).
    pub fn num_blocks(&self) -> usize {
        self.head_dim / QUANT_BLOCK_SIZE
    }

    /// Packed key codes for `layer`, if allocated.
    pub fn k_indices(&self, layer: usize) -> Option<&Tensor> {
        self.gpu_k_indices[layer].as_ref()
    }

    /// Key scales for `layer`, if allocated.
    pub fn k_scales(&self, layer: usize) -> Option<&Tensor> {
        self.gpu_k_scales[layer].as_ref()
    }

    /// Packed value codes for `layer`, if allocated.
    pub fn v_indices(&self, layer: usize) -> Option<&Tensor> {
        self.gpu_v_indices[layer].as_ref()
    }

    /// Value scales for `layer`, if allocated.
    pub fn v_scales(&self, layer: usize) -> Option<&Tensor> {
        self.gpu_v_scales[layer].as_ref()
    }

    /// Allocated capacity in sequence positions for `layer` (0 if unallocated).
    /// Dim 1 of every buffer is the sequence/capacity axis.
    pub fn capacity(&self, layer: usize) -> usize {
        self.gpu_k_indices[layer]
            .as_ref()
            .map_or(0, |t| t.dims()[1])
    }

    /// Grows `layer`'s buffers so they hold at least `needed` positions,
    /// preserving the first `seq_len(layer)` positions of existing data.
    ///
    /// Growth adds 25% headroom beyond `needed` (at least `MIN_HEADROOM`) to
    /// amortize reallocation; fresh buffers are zero-initialized on `device`.
    ///
    /// # Errors
    /// Propagates tensor allocation / copy errors.
    pub fn ensure_capacity(&mut self, layer: usize, needed: usize, device: &Device) -> Result<()> {
        if self.capacity(layer) >= needed {
            return Ok(());
        }
        const MIN_HEADROOM: usize = 128;
        let new_cap = needed + (needed / 4).max(MIN_HEADROOM);
        let old_seq = self.buf_seq_len[layer];
        let heads = self.num_kv_heads;
        let packed_dim = self.packed_dim();
        let num_blocks = self.num_blocks();
        // Codes are bit-packed into u8; scales are one f16 per quantization block.
        let new_ki = Tensor::zeros((heads, new_cap, packed_dim), DType::U8, device)?;
        let new_vi = Tensor::zeros((heads, new_cap, packed_dim), DType::U8, device)?;
        let new_ks = Tensor::zeros((heads, new_cap, num_blocks), DType::F16, device)?;
        let new_vs = Tensor::zeros((heads, new_cap, num_blocks), DType::F16, device)?;
        if old_seq > 0 {
            copy_old_data(&self.gpu_k_indices[layer], &new_ki, old_seq)?;
            copy_old_data(&self.gpu_v_indices[layer], &new_vi, old_seq)?;
            copy_old_data(&self.gpu_k_scales[layer], &new_ks, old_seq)?;
            copy_old_data(&self.gpu_v_scales[layer], &new_vs, old_seq)?;
        }
        self.gpu_k_indices[layer] = Some(new_ki);
        self.gpu_v_indices[layer] = Some(new_vi);
        self.gpu_k_scales[layer] = Some(new_ks);
        self.gpu_v_scales[layer] = Some(new_vs);
        Ok(())
    }

    /// Unwraps an optional buffer, mapping `None` to a cache error carrying
    /// `missing`. Shared by the four near-identical lookups in `append`.
    fn buffer<'t>(buf: &'t Option<Tensor>, missing: &'static str) -> Result<&'t Tensor> {
        buf.as_ref().ok_or_else(|| cache_err(missing))
    }

    /// Writes `kv` into `layer`'s buffers at sequence offset `offset`
    /// (along dim 1), records `offset + new_seq_len` as the layer's length,
    /// and marks the layer active.
    ///
    /// The caller must have reserved room first via
    /// `ensure_capacity(layer, offset + new_seq_len, …)`.
    ///
    /// # Errors
    /// Fails if a buffer is unallocated or `slice_set` rejects the write.
    /// On failure the layer may be partially written, but its recorded
    /// length and active flag are left unchanged.
    pub fn append(
        &mut self,
        layer: usize,
        offset: usize,
        kv: &QuantizedKV<'_>,
        new_seq_len: usize,
    ) -> Result<()> {
        Self::buffer(&self.gpu_k_indices[layer], "k_indices buffer not allocated")?
            .slice_set(kv.k_indices, 1, offset)?;
        Self::buffer(&self.gpu_v_indices[layer], "v_indices buffer not allocated")?
            .slice_set(kv.v_indices, 1, offset)?;
        Self::buffer(&self.gpu_k_scales[layer], "k_scales buffer not allocated")?
            .slice_set(kv.k_scales, 1, offset)?;
        Self::buffer(&self.gpu_v_scales[layer], "v_scales buffer not allocated")?
            .slice_set(kv.v_scales, 1, offset)?;
        self.buf_seq_len[layer] = offset + new_seq_len;
        self.gpu_path_active[layer] = true;
        Ok(())
    }

    /// Drops every buffer and resets all layers to the empty, inactive state.
    pub fn reset(&mut self) {
        self.gpu_k_indices.fill(None);
        self.gpu_v_indices.fill(None);
        self.gpu_k_scales.fill(None);
        self.gpu_v_scales.fill(None);
        self.gpu_path_active.fill(false);
        self.buf_seq_len.fill(0);
    }

    /// Approximate bytes of quantized data currently occupied (allocated but
    /// unused headroom is not counted).
    pub fn memory_usage(&self) -> usize {
        // Per-position byte counts are layer-invariant; hoist them out of the loop.
        // K + V codes: one byte per packed u8 element.
        let index_bytes = 2 * self.num_kv_heads * self.packed_dim();
        // K + V scales: one f16 (2 bytes) per quantization block.
        let scale_bytes = 2 * self.num_kv_heads * self.num_blocks() * 2;
        (0..self.num_layers)
            .map(|layer| self.buf_seq_len[layer] * (index_bytes + scale_bytes))
            .sum()
    }
}
/// Copies the first `old_seq` sequence positions (dim 1) from `old` into the
/// start of `new`; a no-op when no old buffer exists.
fn copy_old_data(old: &Option<Tensor>, new: &Tensor, old_seq: usize) -> Result<()> {
    // Guard clause; the original `Some(ref old_tensor)` binding's `ref` is
    // redundant — match ergonomics already borrow through `&Option<Tensor>`.
    let Some(old_tensor) = old else {
        return Ok(());
    };
    let head = old_tensor.narrow(1, 0, old_seq)?;
    new.slice_set(&head, 1, 0)
}