use candle_core::{DType, Device, Result, Tensor};
use mistralrs_kv_cache::{AttendConfig, CompressedKVCache, DecodeOutput, DequantResult};
use parking_lot::Mutex;
use super::cache_err;
use super::common::{
dequantize_full_impl, flatten_kv, make_quant_config, quantize_kv_pair,
validate_and_make_metadata,
};
use super::config::{CacheConfig, BITS_PER_BYTE, DEFAULT_QJL_SEED};
use super::precomputed::GpuPrecomputed;
use super::quantize_tensor::polar_dequantize;
use super::storage::{LayerStorage, QuantizedKV, StorageMetadata};
use super::{ensure_gpu_precomputed, PrecomputedState};
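/// Minimum number of extra rows added when growing the QJL sign/norm buffers.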
const MIN_QJL_GROW: usize = 128;
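/// Per-layer state: block-quantized KV storage plus the QJL sketch of the
/// key quantization residual (packed sign bits and residual L2 norms).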
#[derive(Default)]
struct TqLayer {
storage: LayerStorage,
qjl_signs: Option<Tensor>,
qjl_norms: Option<Tensor>,
}
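/// A compressed KV cache that stores keys and values block-quantized and, in
/// addition, keeps a QJL (quantized Johnson-Lindenstrauss) sketch of each
/// key's quantization residual; `compute_logit_bias` turns that sketch into
/// a per-position correction added to the attention logits.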
pub struct TqCache {
config: CacheConfig,
metadata: StorageMetadata,
precomputed: PrecomputedState,
layers: Vec<Mutex<TqLayer>>,
}
impl TqCache {
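    /// Validates `config` and creates one empty, lazily allocated slot per layer.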
    pub fn new(config: CacheConfig) -> Result<Self> {
let metadata = validate_and_make_metadata(&config)?;
let layers = (0..config.num_layers)
.map(|_| Mutex::new(TqLayer::default()))
.collect();
Ok(Self {
config,
metadata,
precomputed: PrecomputedState::default(),
layers,
})
}
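    /// Grows this layer's QJL sign/norm buffers to hold at least `needed`
    /// rows, copying any previously stored rows into the new allocation.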
fn ensure_qjl_capacity(
&self,
layer_slot: &mut TqLayer,
needed: usize,
device: &Device,
) -> Result<()> {
let signs_per_head = self.config.head_dim / BITS_PER_BYTE;
let heads = self.config.num_kv_heads;
let current_cap = layer_slot.qjl_signs.as_ref().map_or(0, |t| t.dims()[1]);
if current_cap >= needed {
return Ok(());
}
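        // Over-allocate by 25% headroom (at least MIN_QJL_GROW rows) so that
        // repeated decode steps do not reallocate on every token.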
let grow = (needed / 4).max(MIN_QJL_GROW);
let new_cap = needed + grow;
let old_seq = layer_slot.storage.seq_len();
let new_signs = Tensor::zeros((heads, new_cap, signs_per_head), DType::U8, device)?;
let new_norms = Tensor::zeros((heads, new_cap), DType::F16, device)?;
if old_seq > 0 {
if let Some(ref old) = layer_slot.qjl_signs {
new_signs.slice_set(&old.narrow(1, 0, old_seq)?, 1, 0)?;
}
if let Some(ref old) = layer_slot.qjl_norms {
new_norms.slice_set(&old.narrow(1, 0, old_seq)?, 1, 0)?;
}
}
layer_slot.qjl_signs = Some(new_signs);
layer_slot.qjl_norms = Some(new_norms);
Ok(())
}
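    /// Quantizes a new K/V chunk, appends it to storage, and records the QJL
    /// sketch of the key residual. Returns `(old_seq_len, total_seq_len)`.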
fn quantize_and_store(
&self,
layer_slot: &mut TqLayer,
k: &Tensor,
v: &Tensor,
pre: &GpuPrecomputed,
) -> Result<(usize, usize)> {
let device = k.device().clone();
let new_seq_len = k.dims()[2];
let old_seq_len = layer_slot.storage.seq_len();
let total_seq_len = old_seq_len + new_seq_len;
layer_slot
.storage
.ensure_capacity(total_seq_len, &self.metadata, &device)?;
self.ensure_qjl_capacity(layer_slot, total_seq_len, &device)?;
let (k_flat, v_flat) = flatten_kv(k, v, self.config.num_kv_heads, self.config.head_dim)?;
let qc = make_quant_config(pre, &self.config)?;
let packed_dim = qc.packed_dim();
let num_blocks = qc.num_blocks();
let (k_idx, k_sc, v_idx, v_sc) =
quantize_kv_pair(&k_flat, &v_flat, self.config.norm_mode, &qc)?;
let heads = self.config.num_kv_heads;
let k_idx_r = k_idx.reshape((heads, new_seq_len, packed_dim))?;
let v_idx_r = v_idx.reshape((heads, new_seq_len, packed_dim))?;
let k_sc_r = k_sc.reshape((heads, new_seq_len, num_blocks))?;
let v_sc_r = v_sc.reshape((heads, new_seq_len, num_blocks))?;
let kv = QuantizedKV {
k_indices: &k_idx_r,
k_scales: &k_sc_r,
v_indices: &v_idx_r,
v_scales: &v_sc_r,
};
layer_slot.storage.append(old_seq_len, &kv, new_seq_len)?;
self.compute_and_store_qjl(layer_slot, &k_flat, &k_idx, &k_sc, &qc)?;
Ok((old_seq_len, total_seq_len))
}
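    /// Dequantizes the keys that were just stored and sketches the residual
    /// `k_flat - dequant(k_flat)` as packed sign bits plus per-vector norms.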
fn compute_and_store_qjl(
&self,
layer_slot: &mut TqLayer,
k_flat: &Tensor,
k_idx: &Tensor,
k_sc: &Tensor,
qc: &super::quantize_tensor::QuantConfig<'_>,
) -> Result<()> {
let head_dim = self.config.head_dim;
let num_kv_heads = self.config.num_kv_heads;
let packed_dim = qc.packed_dim();
let num_blocks = qc.num_blocks();
let n_vecs = k_flat.dims()[0];
let new_seq_len = n_vecs / num_kv_heads;
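        // `storage.seq_len()` already includes the rows appended by the caller.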
let old_seq_len = layer_slot.storage.seq_len() - new_seq_len;
let k_idx_flat = k_idx.reshape((n_vecs, packed_dim))?;
let k_sc_flat = k_sc.reshape((n_vecs, num_blocks))?;
let k_dequant = polar_dequantize(&k_idx_flat, &k_sc_flat, qc)?;
let signs_per_head = head_dim / BITS_PER_BYTE;
let (signs_tensor, norms_tensor) =
compute_qjl_signs_and_norms(k_flat, &k_dequant, n_vecs, head_dim, signs_per_head)?;
let signs_r = signs_tensor.reshape((num_kv_heads, new_seq_len, signs_per_head))?;
let norms_r = norms_tensor.reshape((num_kv_heads, new_seq_len))?;
layer_slot
.qjl_signs
.as_ref()
.ok_or_else(|| cache_err("qjl_signs not initialized"))?
.slice_set(&signs_r, 1, old_seq_len)?;
layer_slot
.qjl_norms
.as_ref()
.ok_or_else(|| cache_err("qjl_norms not initialized"))?
.slice_set(&norms_r, 1, old_seq_len)?;
Ok(())
}
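    /// Estimates, per attention head and cached position, the logit error
    /// `q . (k - dequant(k))` introduced by quantization. Assuming the sign
    /// sketch uses the same seeded Rademacher projection `S` as
    /// `pre.qjl_rademacher`, the QJL estimator computed here is
    /// `sqrt(pi/2) / sqrt(head_dim) * ||r|| * <sign(S r), S q>` with `r` the
    /// key residual; the result is added to logits taken on dequantized keys.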
fn compute_logit_bias(
&self,
layer_slot: &TqLayer,
pre: &GpuPrecomputed,
q: &Tensor,
) -> Result<Tensor> {
let head_dim = self.config.head_dim;
let total_seq = layer_slot.storage.seq_len();
let q_dims = q.dims4()?;
let num_attn_heads = q_dims.1;
let rademacher = pre
.qjl_rademacher
.as_ref()
.ok_or_else(|| cache_err("QJL Rademacher matrix not precomputed"))?;
let rademacher_t = rademacher.t()?;
let mut head_corrections = Vec::with_capacity(self.config.num_kv_heads);
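        // With grouped-query attention, `n_kv_groups` consecutive query heads
        // share one KV head (and thus one sign/norm sketch).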
let n_kv_groups = num_attn_heads / self.config.num_kv_heads;
let qjl_signs = layer_slot
.qjl_signs
.as_ref()
.ok_or_else(|| cache_err("qjl_signs not initialized"))?;
let qjl_norms = layer_slot
.qjl_norms
.as_ref()
.ok_or_else(|| cache_err("qjl_norms not initialized"))?;
let bit_masks =
Tensor::from_vec(BYTE_BIT_MASKS.to_vec(), (1, 1, BITS_PER_BYTE), q.device())?;
let sqrt_pi_over_2 = std::f64::consts::FRAC_PI_2.sqrt() as f32;
let scale_factor = sqrt_pi_over_2 / (head_dim as f32).sqrt();
for kv_head in 0..self.config.num_kv_heads {
let (signs_float_t, c_row) = unpack_qjl_signs(
qjl_signs,
qjl_norms,
kv_head,
total_seq,
head_dim,
&bit_masks,
scale_factor,
)?;
for qh in 0..n_kv_groups {
let attn_head = kv_head * n_kv_groups + qh;
let q_head = q
.narrow(1, attn_head, 1)?
.squeeze(0)?
.squeeze(0)?
.to_dtype(DType::F32)?;
let r_q = q_head.matmul(&rademacher_t)?;
let raw = r_q.matmul(&signs_float_t)?;
let corr = raw.broadcast_mul(&c_row)?;
                head_corrections.push(corr.unsqueeze(0)?);
            }
        }
let refs: Vec<&Tensor> = head_corrections.iter().collect();
let combined = Tensor::cat(&refs, 0)?;
combined.unsqueeze(0)?.to_dtype(q.dtype())
}
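    /// Dequantizes the entire stored cache for this layer back to `orig_dtype`.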
fn dequantize_full(
&self,
layer_slot: &TqLayer,
pre: &GpuPrecomputed,
orig_dtype: DType,
) -> Result<(Tensor, Tensor)> {
let qc = make_quant_config(pre, &self.config)?;
dequantize_full_impl(&layer_slot.storage, &self.metadata, &qc, orig_dtype)
}
fn layer_mutex(&self, layer: usize) -> Result<&Mutex<TqLayer>> {
self.layers.get(layer).ok_or_else(|| {
cache_err(format!(
"layer index {layer} out of range (cache has {} layers)",
self.layers.len()
))
})
}
}
impl CompressedKVCache for TqCache {
fn prefill(&self, layer: usize, k: &Tensor, v: &Tensor, q: &Tensor) -> Result<DequantResult> {
let orig_dtype = k.dtype();
let pre = ensure_gpu_precomputed(&self.precomputed, &self.config, k.device())?;
let mut guard = self.layer_mutex(layer)?.lock();
let (old_seq_len, _total) = self.quantize_and_store(&mut guard, k, v, pre)?;
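        // On the first prefill nothing was stored before this chunk, so the
        // original tensors can be returned verbatim; otherwise the full cache
        // (including this chunk) is dequantized.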
let (full_k, full_v) = if old_seq_len == 0 {
(k.clone(), v.clone())
} else {
self.dequantize_full(&guard, pre, orig_dtype)?
};
let logit_bias = self.compute_logit_bias(&guard, pre, q)?;
Ok(DequantResult {
k: full_k,
v: full_v,
logit_bias: Some(logit_bias),
})
}
fn decode(
&self,
layer: usize,
k: &Tensor,
v: &Tensor,
q: &Tensor,
_config: &AttendConfig,
) -> Result<DecodeOutput> {
let orig_dtype = k.dtype();
let pre = ensure_gpu_precomputed(&self.precomputed, &self.config, k.device())?;
let mut guard = self.layer_mutex(layer)?.lock();
self.quantize_and_store(&mut guard, k, v, pre)?;
let (full_k, full_v) = self.dequantize_full(&guard, pre, orig_dtype)?;
let logit_bias = self.compute_logit_bias(&guard, pre, q)?;
Ok(DecodeOutput::Dequantized(DequantResult {
k: full_k,
v: full_v,
logit_bias: Some(logit_bias),
}))
}
fn seq_len(&self, layer: usize) -> usize {
self.layers
.get(layer)
.map(|m| m.lock().storage.seq_len())
.unwrap_or(0)
}
fn reset(&self) -> Result<()> {
self.layers
.iter()
.for_each(|m| *m.lock() = TqLayer::default());
Ok(())
}
fn memory_usage(&self) -> usize {
self.layers
.iter()
.map(|m| {
let g = m.lock();
let storage_bytes = g.storage.memory_usage(&self.metadata);
let qjl_bytes: usize = [g.qjl_signs.as_ref(), g.qjl_norms.as_ref()]
.iter()
.flatten()
.map(|t| t.elem_count() * t.dtype().size_in_bytes())
.sum();
storage_bytes + qjl_bytes
})
.sum()
}
}
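/// Bit masks for unpacking packed sign bytes, least-significant bit first.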
const BYTE_BIT_MASKS: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128];
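/// Unpacks one KV head's packed sign bytes into a `(head_dim, total_seq)`
/// matrix of +/-1 floats, returned together with the `(1, total_seq)` row of
/// residual norms pre-multiplied by `scale_factor`.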
fn unpack_qjl_signs(
qjl_signs: &Tensor,
qjl_norms: &Tensor,
kv_head: usize,
total_seq: usize,
head_dim: usize,
bit_masks: &Tensor,
scale_factor: f32,
) -> Result<(Tensor, Tensor)> {
let head_signs = qjl_signs
.narrow(0, kv_head, 1)?
.narrow(1, 0, total_seq)?
.squeeze(0)?;
let head_norms = qjl_norms
.narrow(0, kv_head, 1)?
.narrow(1, 0, total_seq)?
.squeeze(0)?
.to_dtype(DType::F32)?;
let signs_u8 = head_signs.unsqueeze(2)?;
let bytes_f = signs_u8.to_dtype(DType::F32)?;
let masks_f = bit_masks.to_dtype(DType::F32)?;
    let divided = bytes_f.broadcast_div(&masks_f)?.floor()?;
    // floor(byte / 2^i) leaves bit i in the lowest position; its parity,
    // |floor(x / 2) * 2 - x|, is the bit value itself.
    let bit_set = (((&divided / 2.0)?.floor()? * 2.0)? - &divided)?.abs()?;
    // Map bits {0, 1} to signs {-1, +1}.
    let signs_float = ((bit_set * 2.0)? - 1.0)?.reshape((total_seq, head_dim))?;
let signs_float_t = signs_float.t()?;
let c = (head_norms * scale_factor as f64)?;
let c_row = c.unsqueeze(0)?;
Ok((signs_float_t, c_row))
}
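/// Sketches the residual `original - dequantized`: per-vector L2 norms
/// (stored as f16) and sign bytes computed row by row on the CPU with the
/// fixed `DEFAULT_QJL_SEED`.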
fn compute_qjl_signs_and_norms(
original: &Tensor,
dequantized: &Tensor,
n_vecs: usize,
head_dim: usize,
signs_per_head: usize,
) -> Result<(Tensor, Tensor)> {
let device = original.device().clone();
let residual = (original - dequantized)?;
let norms = residual
.sqr()?
.sum_keepdim(1)?
.sqrt()?
.squeeze(1)?
.to_dtype(DType::F16)?;
let residual_cpu = residual.to_device(&Device::Cpu)?;
let all_residual: Vec<f32> = residual_cpu.flatten_all()?.to_vec1()?;
let mut all_signs = vec![0u8; n_vecs * signs_per_head];
for vec_idx in 0..n_vecs {
let row_data = &all_residual[vec_idx * head_dim..(vec_idx + 1) * head_dim];
let signs = crate::compute_qjl_signs(row_data, head_dim, DEFAULT_QJL_SEED)
            .map_err(cache_err)?;
let start = vec_idx * signs_per_head;
all_signs[start..start + signs_per_head].copy_from_slice(&signs);
}
let signs =
Tensor::from_vec(all_signs, n_vecs * signs_per_head, &Device::Cpu)?.to_device(&device)?;
Ok((signs, norms))
}
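// A minimal usage sketch of the per-layer call pattern. The helper below is
// hypothetical and for illustration only: building a real `CacheConfig` and
// `AttendConfig` is elided because their fields are defined elsewhere, so the
// caller supplies them together with `(batch, heads, seq, head_dim)` tensors.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[allow(dead_code)]
    fn step_layer(
        cache: &TqCache,
        layer: usize,
        k: &Tensor,
        v: &Tensor,
        q: &Tensor,
        attend_cfg: &AttendConfig,
    ) -> Result<()> {
        // Prefill quantizes and stores the chunk; the result carries the
        // dequantized K/V plus the QJL logit-bias correction.
        let prefill = cache.prefill(layer, k, v, q)?;
        assert!(prefill.logit_bias.is_some());
        // A decode step appends `k.dims()[2]` more positions and dequantizes
        // the whole cache again.
        let before = cache.seq_len(layer);
        if let DecodeOutput::Dequantized(out) = cache.decode(layer, k, v, q, attend_cfg)? {
            assert!(out.logit_bias.is_some());
        }
        assert_eq!(cache.seq_len(layer), before + k.dims()[2]);
        Ok(())
    }
}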