mistralrs_kv_cache/lib.rs
//! Trait interface for compressed KV-cache implementations.
//!
//! This crate defines [`CompressedKVCache`] — the contract between inference
//! engines (like mistral.rs) and KV-cache compression libraries (like turboquant).
//!
//! Implementing this trait is the only requirement for integrating a new
//! cache compression method. The inference engine calls `prefill()` during
//! multi-token processing and `decode()` during single-token generation.
//! All compression decisions (fused vs dequantized, lazy vs immediate,
//! CPU vs GPU) are made internally by the implementation.
11
12pub use candle_core::{DType, Device, Result, Tensor};
13
14/// Result of decompressing cached KV data.
15///
16/// Contains the full (decompressed) key and value tensors for all cached
17/// tokens, plus an optional bias to be added to attention logits before
18/// softmax (e.g. QJL correction in TurboQuant).
19pub struct DequantResult {
20 /// Decompressed keys. Shape: `[batch, num_kv_heads, total_seq_len, head_dim]`
21 pub k: Tensor,
22 /// Decompressed values. Shape: `[batch, num_kv_heads, total_seq_len, head_dim]`
23 pub v: Tensor,
24 /// Optional bias added to attention logits before softmax.
25 /// Shape: `[batch, num_heads, q_len, kv_len]` or `None`.
26 /// Used by TurboQuant's QJL correction; other methods return `None`.
27 pub logit_bias: Option<Tensor>,
28}
29
30/// Result of a decode step.
31///
32/// The cache implementation decides internally whether to compute attention
33/// via a fused kernel or return decompressed data for the caller's SDPA.
34pub enum DecodeOutput {
35 /// The implementation computed attention internally (e.g. fused CUDA kernel).
36 /// Shape: `[batch, num_attention_heads, 1, head_dim]`
37 Fused(Tensor),
38
39 /// The implementation decompressed the cache — caller runs SDPA.
40 /// Used on CPU, Metal, or when fused attention is not available.
41 Dequantized(DequantResult),
42}
43
/// Configuration for attention computation during decode.
///
/// Passed to [`CompressedKVCache::decode`]. Extensible without breaking
/// the trait signature — new fields can be added here.
///
/// Derives `Debug`, `Clone` and `PartialEq` so configurations can be logged,
/// duplicated and compared. `Copy` is intentionally not derived: it would
/// have to be removed (a breaking change) if a non-`Copy` field were ever
/// added.
#[derive(Debug, Clone, PartialEq)]
pub struct AttendConfig {
    /// Softmax scaling factor, typically `1 / sqrt(head_dim)`.
    pub softmax_scale: f32,
    /// GQA group count: `num_attention_heads / num_kv_heads`.
    /// Set to 1 for MHA (no grouping).
    pub n_kv_groups: usize,
}
55
56/// Trait for compressed KV-cache implementations.
57///
58/// Two methods for the two phases of LLM inference:
59/// - [`prefill`](Self::prefill): Store multiple tokens, return decompressed KV for Flash Attention.
60/// - [`decode`](Self::decode): Store single token, compute or prepare attention.
61///
62/// The implementation makes **all** internal decisions:
63/// - Fused kernel vs full decompression (based on device, kernel availability)
64/// - Immediate vs deferred compression during prefill
65/// - QJL bias computation (only when needed)
66///
67/// # Synchronization
68///
69/// All methods take `&self`. Implementations are responsible for interior
70/// synchronization (e.g. per-layer locks) so that calls for different layers
71/// may proceed in parallel. This enables use cases like speculative decoding
72/// where draft and target models run concurrently.
73///
74/// # Adding a new compression method
75///
76/// 1. Implement this trait for your cache struct
77/// 2. Add a match arm in the inference engine's cache factory
78/// 3. Done — no model code changes needed
79pub trait CompressedKVCache: Send + Sync {
80 /// Prefill: store multiple new KV tokens and return decompressed data.
81 ///
82 /// The implementation decides internally whether to compress immediately
83 /// or defer compression (lazy). Returns all cached tokens (old + new)
84 /// for the caller's SDPA / Flash Attention.
85 ///
86 /// `q` is provided for implementations that need it for bias computation
87 /// (e.g. QJL). Implementations that don't need it may ignore this parameter.
88 fn prefill(&self, layer: usize, k: &Tensor, v: &Tensor, q: &Tensor) -> Result<DequantResult>;
89
90 /// Decode: store a single new KV token and compute or prepare attention.
91 ///
92 /// Returns [`DecodeOutput::Fused`] if the implementation computed attention
93 /// internally (e.g. CUDA fused kernel), or [`DecodeOutput::Dequantized`] if
94 /// the caller should run standard SDPA on the returned data.
95 ///
96 /// The implementation decides which path to take based on device type,
97 /// kernel availability, and internal state.
98 fn decode(
99 &self,
100 layer: usize,
101 k: &Tensor,
102 v: &Tensor,
103 q: &Tensor,
104 config: &AttendConfig,
105 ) -> Result<DecodeOutput>;
106
107 /// Number of tokens currently stored for a given layer.
108 fn seq_len(&self, layer: usize) -> usize;
109
110 /// Reset all layers to empty state.
111 fn reset(&self) -> Result<()>;
112
113 /// Total persistent memory usage in bytes (compressed cache, not temporary buffers).
114 fn memory_usage(&self) -> usize;
115}