Skip to main content

mistralrs_kv_cache/
lib.rs

//! Trait interface for compressed KV-cache implementations.
//!
//! This crate defines [`CompressedKVCache`] — the contract between inference
//! engines (like mistral.rs) and KV-cache compression libraries (like turboquant).
//!
//! Implementing this trait is the only requirement for integrating a new
//! cache compression method. The inference engine calls `prefill()` during
//! multi-token processing and `decode()` during single-token generation.
//! All compression decisions (fused vs. dequantized, lazy vs. immediate,
//! CPU vs. GPU) are made internally by the implementation.
11
12pub use candle_core::{DType, Device, Result, Tensor};
13
14/// Result of decompressing cached KV data.
15///
16/// Contains the full (decompressed) key and value tensors for all cached
17/// tokens, plus an optional bias to be added to attention logits before
18/// softmax (e.g. QJL correction in TurboQuant).
19pub struct DequantResult {
20    /// Decompressed keys. Shape: `[batch, num_kv_heads, total_seq_len, head_dim]`
21    pub k: Tensor,
22    /// Decompressed values. Shape: `[batch, num_kv_heads, total_seq_len, head_dim]`
23    pub v: Tensor,
24    /// Optional bias added to attention logits before softmax.
25    /// Shape: `[batch, num_heads, q_len, kv_len]` or `None`.
26    /// Used by TurboQuant's QJL correction; other methods return `None`.
27    pub logit_bias: Option<Tensor>,
28}
29
30/// Result of a decode step.
31///
32/// The cache implementation decides internally whether to compute attention
33/// via a fused kernel or return decompressed data for the caller's SDPA.
34pub enum DecodeOutput {
35    /// The implementation computed attention internally (e.g. fused CUDA kernel).
36    /// Shape: `[batch, num_attention_heads, 1, head_dim]`
37    Fused(Tensor),
38
39    /// The implementation decompressed the cache — caller runs SDPA.
40    /// Used on CPU, Metal, or when fused attention is not available.
41    Dequantized(DequantResult),
42}
43
/// Configuration for attention computation during decode.
///
/// Passed to [`CompressedKVCache::decode`]. Bundling these parameters in a
/// struct keeps the `decode` signature stable as new parameters are needed.
///
/// Note: because every field is public, adding a field later is still a
/// breaking change for code that builds this struct with a struct literal
/// (`AttendConfig { .. }`). Construct it with [`AttendConfig::new`] to be
/// insulated from such additions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AttendConfig {
    /// Softmax scaling factor, typically `1 / sqrt(head_dim)`.
    pub softmax_scale: f32,
    /// GQA group count: `num_attention_heads / num_kv_heads`.
    /// Set to 1 for MHA (no grouping).
    pub n_kv_groups: usize,
}

impl AttendConfig {
    /// Creates a decode attention configuration.
    ///
    /// * `softmax_scale` — scaling applied to attention logits, typically `1 / sqrt(head_dim)`.
    /// * `n_kv_groups` — `num_attention_heads / num_kv_heads`; `1` for MHA.
    ///
    /// Any fields added in the future will receive defaults here, so callers
    /// using this constructor keep compiling.
    #[must_use]
    pub fn new(softmax_scale: f32, n_kv_groups: usize) -> Self {
        Self {
            softmax_scale,
            n_kv_groups,
        }
    }
}
55
56/// Trait for compressed KV-cache implementations.
57///
58/// Two methods for the two phases of LLM inference:
59/// - [`prefill`](Self::prefill): Store multiple tokens, return decompressed KV for Flash Attention.
60/// - [`decode`](Self::decode): Store single token, compute or prepare attention.
61///
62/// The implementation makes **all** internal decisions:
63/// - Fused kernel vs full decompression (based on device, kernel availability)
64/// - Immediate vs deferred compression during prefill
65/// - QJL bias computation (only when needed)
66///
67/// # Synchronization
68///
69/// All methods take `&self`. Implementations are responsible for interior
70/// synchronization (e.g. per-layer locks) so that calls for different layers
71/// may proceed in parallel. This enables use cases like speculative decoding
72/// where draft and target models run concurrently.
73///
74/// # Adding a new compression method
75///
76/// 1. Implement this trait for your cache struct
77/// 2. Add a match arm in the inference engine's cache factory
78/// 3. Done — no model code changes needed
79pub trait CompressedKVCache: Send + Sync {
80    /// Prefill: store multiple new KV tokens and return decompressed data.
81    ///
82    /// The implementation decides internally whether to compress immediately
83    /// or defer compression (lazy). Returns all cached tokens (old + new)
84    /// for the caller's SDPA / Flash Attention.
85    ///
86    /// `q` is provided for implementations that need it for bias computation
87    /// (e.g. QJL). Implementations that don't need it may ignore this parameter.
88    fn prefill(&self, layer: usize, k: &Tensor, v: &Tensor, q: &Tensor) -> Result<DequantResult>;
89
90    /// Decode: store a single new KV token and compute or prepare attention.
91    ///
92    /// Returns [`DecodeOutput::Fused`] if the implementation computed attention
93    /// internally (e.g. CUDA fused kernel), or [`DecodeOutput::Dequantized`] if
94    /// the caller should run standard SDPA on the returned data.
95    ///
96    /// The implementation decides which path to take based on device type,
97    /// kernel availability, and internal state.
98    fn decode(
99        &self,
100        layer: usize,
101        k: &Tensor,
102        v: &Tensor,
103        q: &Tensor,
104        config: &AttendConfig,
105    ) -> Result<DecodeOutput>;
106
107    /// Number of tokens currently stored for a given layer.
108    fn seq_len(&self, layer: usize) -> usize;
109
110    /// Reset all layers to empty state.
111    fn reset(&self) -> Result<()>;
112
113    /// Total persistent memory usage in bytes (compressed cache, not temporary buffers).
114    fn memory_usage(&self) -> usize;
115}