mistralrs-kv-cache 0.2.0

Trait interface for compressed KV-cache implementations in mistral.rs
Documentation
//! Trait interface for compressed KV-cache implementations.
//!
//! This crate defines [`CompressedKVCache`] — the contract between inference
//! engines (like mistral.rs) and KV-cache compression libraries (like turboquant).
//!
//! Implementing this trait is the only requirement for integrating a new
//! cache compression method. The inference engine calls `prefill()` during
//! multi-token processing and `decode()` during single-token generation.
//! All compression decisions (fused vs dequantized, lazy vs immediate,
//! CPU vs GPU) are made internally by the implementation.

pub use candle_core::{DType, Device, Result, Tensor};

/// Result of decompressing cached KV data.
///
/// Contains the full (decompressed) key and value tensors for all cached
/// tokens, plus an optional bias to be added to attention logits before
/// softmax (e.g. QJL correction in TurboQuant).
///
/// `Debug` and `Clone` are derived so results can be logged and handed to
/// more than one consumer. candle tensors are reference-counted handles, so
/// cloning a `DequantResult` copies handles rather than tensor data.
#[derive(Debug, Clone)]
pub struct DequantResult {
    /// Decompressed keys. Shape: `[batch, num_kv_heads, total_seq_len, head_dim]`
    pub k: Tensor,
    /// Decompressed values. Shape: `[batch, num_kv_heads, total_seq_len, head_dim]`
    pub v: Tensor,
    /// Optional bias added to attention logits before softmax.
    /// Shape: `[batch, num_heads, q_len, kv_len]` or `None`.
    /// Used by TurboQuant's QJL correction; other methods return `None`.
    pub logit_bias: Option<Tensor>,
}

/// Result of a decode step.
///
/// The cache implementation decides internally whether to compute attention
/// via a fused kernel or return decompressed data for the caller's SDPA.
/// Callers are expected to `match` on the variant: `Fused` already holds the
/// attention output, while `Dequantized` still needs an attention pass.
#[derive(Debug, Clone)]
pub enum DecodeOutput {
    /// The implementation computed attention internally (e.g. fused CUDA kernel).
    /// Shape: `[batch, num_attention_heads, 1, head_dim]`
    Fused(Tensor),

    /// The implementation decompressed the cache — caller runs SDPA.
    /// Used on CPU, Metal, or when fused attention is not available.
    Dequantized(DequantResult),
}

/// Configuration for attention computation during decode.
///
/// Passed to [`CompressedKVCache::decode`]. Bundling the parameters in a
/// struct means new fields can be added here without changing the trait
/// method signature.
///
/// Note that the struct is not `#[non_exhaustive]`: callers that build it
/// with a struct literal still have to be updated when a field is added.
///
/// `Copy` is intentionally not derived so that a future field is free to be
/// a non-`Copy` type; use `.clone()` when an owned duplicate is needed.
#[derive(Debug, Clone, PartialEq)]
pub struct AttendConfig {
    /// Softmax scaling factor, typically `1 / sqrt(head_dim)`.
    pub softmax_scale: f32,
    /// GQA group count: `num_attention_heads / num_kv_heads`.
    /// Set to 1 for MHA (no grouping).
    pub n_kv_groups: usize,
}

/// Trait for compressed KV-cache implementations.
///
/// The two core methods map to the two phases of LLM inference:
/// - [`prefill`](Self::prefill): Store multiple tokens, return decompressed KV for Flash Attention.
/// - [`decode`](Self::decode): Store single token, compute or prepare attention.
///
/// The remaining methods ([`seq_len`](Self::seq_len), [`reset`](Self::reset),
/// [`memory_usage`](Self::memory_usage)) are bookkeeping helpers for the engine.
///
/// The implementation makes **all** internal decisions:
/// - Fused kernel vs full decompression (based on device, kernel availability)
/// - Immediate vs deferred compression during prefill
/// - QJL bias computation (only when needed)
///
/// The `Send + Sync` supertraits require every implementation to be movable
/// to and shareable between threads; all mutating methods still take
/// `&mut self`, so exclusive access is needed to modify the cache.
///
/// # Adding a new compression method
///
/// 1. Implement this trait for your cache struct
/// 2. Add a match arm in the inference engine's cache factory
/// 3. Done — no model code changes needed
pub trait CompressedKVCache: Send + Sync {
    /// Prefill: store multiple new KV tokens and return decompressed data.
    ///
    /// The implementation decides internally whether to compress immediately
    /// or defer compression (lazy). Returns all cached tokens (old + new)
    /// for the caller's SDPA / Flash Attention.
    ///
    /// # Parameters
    ///
    /// - `layer`: which layer's cache to append to.
    /// - `k`, `v`: the new key and value tensors to store.
    /// - `q`: provided for implementations that need it for bias computation
    ///   (e.g. QJL). Implementations that don't need it may ignore this parameter.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation fails to store or decompress
    /// the data; the exact failure conditions are implementation-defined.
    // NOTE(review): input shapes are not stated here; presumably `k`/`v` follow
    // the `[batch, num_kv_heads, seq_len, head_dim]` layout documented on
    // `DequantResult` — confirm against implementations.
    // NOTE(review): presumably `layer` is a zero-based layer index and
    // out-of-range values are an error — confirm.
    fn prefill(
        &mut self,
        layer: usize,
        k: &Tensor,
        v: &Tensor,
        q: &Tensor,
    ) -> Result<DequantResult>;

    /// Decode: store a single new KV token and compute or prepare attention.
    ///
    /// Returns [`DecodeOutput::Fused`] if the implementation computed attention
    /// internally (e.g. CUDA fused kernel), or [`DecodeOutput::Dequantized`] if
    /// the caller should run standard SDPA on the returned data.
    ///
    /// The implementation decides which path to take based on device type,
    /// kernel availability, and internal state.
    ///
    /// # Parameters
    ///
    /// - `layer`: which layer's cache to append to.
    /// - `k`, `v`: key and value tensors for the single new token.
    /// - `q`: query tensor for the new token.
    /// - `config`: softmax scale and GQA grouping, see [`AttendConfig`].
    ///
    /// # Errors
    ///
    /// Returns an error if storing the token or producing the output fails;
    /// the exact failure conditions are implementation-defined.
    // NOTE(review): whether `q` and `config` are consulted on the
    // `Dequantized` path is up to the implementation — callers should not
    // assume the softmax scale has been applied in that case; confirm.
    fn decode(
        &mut self,
        layer: usize,
        k: &Tensor,
        v: &Tensor,
        q: &Tensor,
        config: &AttendConfig,
    ) -> Result<DecodeOutput>;

    /// Number of tokens currently stored for a given layer.
    // NOTE(review): the signature is infallible, so behavior for an unknown
    // `layer` (panic vs. returning 0) is implementation-defined — confirm.
    fn seq_len(&self, layer: usize) -> usize;

    /// Reset all layers to empty state.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation cannot clear its state; the
    /// exact failure conditions are implementation-defined.
    fn reset(&mut self) -> Result<()>;

    /// Total persistent memory usage in bytes (compressed cache, not temporary buffers).
    fn memory_usage(&self) -> usize;
}