Skip to main content

rlx_runtime/
kv_cache.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Per-layer K/V cache for autoregressive decode (Whisper, Qwen, Gemma, …).
17
18use crate::compile_cache::pad_rows;
19
20/// Layer-wise past K/V tensors in row-major `[past_len * kv_dim]` layout per layer.
21#[derive(Debug, Clone)]
22pub struct LayerKvCache {
23    pub past_len: usize,
24    pub layers_k: Vec<Vec<f32>>,
25    pub layers_v: Vec<Vec<f32>>,
26}
27
28impl LayerKvCache {
29    pub fn from_layer_outputs(
30        num_layers: usize,
31        batch: usize,
32        past_seq: usize,
33        kv_dim: usize,
34        outputs: &[Vec<f32>],
35    ) -> Result<Self, String> {
36        let dims: Vec<usize> = vec![kv_dim; num_layers];
37        Self::from_layer_outputs_per_layer(num_layers, batch, past_seq, &dims, outputs)
38    }
39
40    /// Like [`Self::from_layer_outputs`] but accepts a per-layer
41    /// `kv_dim` vector. Gemma 4 12B's full-attention layers have
42    /// `kv_dim = 1 * 512 = 512` while sliding layers have `8 * 256 =
43    /// 2048`; this constructor handles that heterogeneity.
44    pub fn from_layer_outputs_per_layer(
45        num_layers: usize,
46        batch: usize,
47        past_seq: usize,
48        kv_dims: &[usize],
49        outputs: &[Vec<f32>],
50    ) -> Result<Self, String> {
51        if outputs.len() != 2 * num_layers {
52            return Err(format!(
53                "from_layer_outputs_per_layer: expected {} K/V tensors, got {}",
54                2 * num_layers,
55                outputs.len()
56            ));
57        }
58        if kv_dims.len() != num_layers {
59            return Err(format!(
60                "from_layer_outputs_per_layer: expected {} kv_dims, got {}",
61                num_layers,
62                kv_dims.len()
63            ));
64        }
65        let mut layers_k = Vec::with_capacity(num_layers);
66        let mut layers_v = Vec::with_capacity(num_layers);
67        for layer in 0..num_layers {
68            let kv_dim = kv_dims[layer];
69            let expected = batch * past_seq * kv_dim;
70            let k = &outputs[2 * layer];
71            let v = &outputs[2 * layer + 1];
72            if k.len() != expected || v.len() != expected {
73                return Err(format!(
74                    "layer {layer}: k.len={} v.len={} expected {expected} (kv_dim={kv_dim})",
75                    k.len(),
76                    v.len()
77                ));
78            }
79            layers_k.push(k.clone());
80            layers_v.push(v.clone());
81        }
82        Ok(Self {
83            past_len: past_seq,
84            layers_k,
85            layers_v,
86        })
87    }
88
89    /// Pad each layer's K/V to `upper` rows along the sequence axis (`kv_dim` inner).
90    pub fn pad_layers_to_upper(&self, upper: u64, kv_dim: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
91        let dims: Vec<usize> = vec![kv_dim; self.layers_k.len()];
92        self.pad_layers_to_upper_per_layer(upper, &dims)
93    }
94
95    /// Like [`Self::pad_layers_to_upper`] but pads each layer to its
96    /// own `kv_dim`. The number of dims must equal the number of
97    /// cached layers.
98    pub fn pad_layers_to_upper_per_layer(
99        &self,
100        upper: u64,
101        kv_dims: &[usize],
102    ) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
103        assert_eq!(
104            kv_dims.len(),
105            self.layers_k.len(),
106            "pad_layers_to_upper_per_layer: kv_dims len {} != layers {}",
107            kv_dims.len(),
108            self.layers_k.len(),
109        );
110        let padded_k = self
111            .layers_k
112            .iter()
113            .zip(kv_dims.iter())
114            .map(|(k, &d)| pad_rows(k, d, upper))
115            .collect();
116        let padded_v = self
117            .layers_v
118            .iter()
119            .zip(kv_dims.iter())
120            .map(|(v, &d)| pad_rows(v, d, upper))
121            .collect();
122        (padded_k, padded_v)
123    }
124
125    /// Update cache from decode outputs: `[logits, k0, v0, k1, v1, …]` (bucket-padded).
126    pub fn advance_from_decode_outputs(
127        &mut self,
128        outputs: Vec<Vec<f32>>,
129        batch: usize,
130        kv_dim: usize,
131    ) -> Result<(), String> {
132        let dims: Vec<usize> = vec![kv_dim; self.layers_k.len()];
133        self.advance_from_decode_outputs_per_layer(outputs, batch, &dims)
134    }
135
136    /// Trim each layer's K/V history to at most `window` rows on
137    /// the sequence axis, keeping the most recent rows. Used by
138    /// Gemma 3/4 sliding-attention layers — long contexts can keep
139    /// only the last `window` (e.g. 1024) tokens per sliding layer
140    /// without affecting attention semantics (those layers mask out
141    /// older positions anyway).
142    ///
143    /// `kv_dims_keep` selects which layers to trim and at what dim:
144    /// `kv_dims_keep[i] = Some((dim, window))` trims layer `i`,
145    /// `None` leaves the layer untouched. Pass-through for layers
146    /// whose attention is full-causal.
147    ///
148    /// Note: `past_len` is unchanged — the per-layer K/V buffers
149    /// just hold fewer real rows now; the decode flow's per-layer
150    /// `past_k_{i}` input shape will see the trimmed length. Caller
151    /// is responsible for ensuring the graph's declared `past_seq`
152    /// matches the trimmed length OR the trimmed layer is bound
153    /// dynamically.
154    pub fn trim_sliding_window_per_layer(
155        &mut self,
156        kv_dims_keep: &[Option<(usize, usize)>],
157    ) -> Result<(), String> {
158        if kv_dims_keep.len() != self.layers_k.len() {
159            return Err(format!(
160                "trim_sliding_window_per_layer: kv_dims_keep len {} != layers {}",
161                kv_dims_keep.len(),
162                self.layers_k.len(),
163            ));
164        }
165        for (i, spec) in kv_dims_keep.iter().enumerate() {
166            let Some((kv_dim, window)) = spec else {
167                continue;
168            };
169            let kv_dim = *kv_dim;
170            let window = *window;
171            if window == 0 || kv_dim == 0 {
172                continue;
173            }
174            let rows = self.layers_k[i].len() / kv_dim;
175            if rows <= window {
176                continue;
177            }
178            let drop_rows = rows - window;
179            let drop_bytes = drop_rows * kv_dim;
180            self.layers_k[i].drain(..drop_bytes);
181            self.layers_v[i].drain(..drop_bytes);
182        }
183        Ok(())
184    }
185
186    /// Per-layer variant of [`Self::advance_from_decode_outputs`].
187    pub fn advance_from_decode_outputs_per_layer(
188        &mut self,
189        outputs: Vec<Vec<f32>>,
190        _batch: usize,
191        kv_dims: &[usize],
192    ) -> Result<(), String> {
193        let n = self.layers_k.len();
194        if outputs.len() != 1 + 2 * n {
195            return Err(format!(
196                "advance_from_decode_outputs_per_layer: expected {} outputs, got {}",
197                1 + 2 * n,
198                outputs.len()
199            ));
200        }
201        if kv_dims.len() != n {
202            return Err(format!(
203                "advance_from_decode_outputs_per_layer: kv_dims len {} != layers {n}",
204                kv_dims.len()
205            ));
206        }
207        let new_len = self.past_len + 1;
208        let mut iter = outputs.into_iter();
209        let _logits = iter.next().ok_or("missing logits")?;
210        for i in 0..n {
211            let k = iter.next().ok_or("missing k")?;
212            let v = iter.next().ok_or("missing v")?;
213            let real_len = new_len * kv_dims[i];
214            self.layers_k[i] = k[..real_len.min(k.len())].to_vec();
215            self.layers_v[i] = v[..real_len.min(v.len())].to_vec();
216        }
217        self.past_len = new_len;
218        Ok(())
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a cache of `layers` identical layers with `rows` rows each.
    /// K counts up from 0 and V counts up from 100, so the tests can tell the
    /// two buffers apart and see exactly which rows survived a trim.
    fn counting_cache(layers: usize, rows: usize, kv_dim: usize) -> LayerKvCache {
        let ramp =
            |base: usize| -> Vec<f32> { (0..rows * kv_dim).map(|x| (base + x) as f32).collect() };
        LayerKvCache {
            past_len: rows,
            layers_k: vec![ramp(0); layers],
            layers_v: vec![ramp(100); layers],
        }
    }

    #[test]
    fn sliding_window_trim_keeps_last_w_rows() {
        // 3 layers, each storing 6 rows of kv_dim=4 = 24 floats.
        let kv_dim = 4;
        let rows = 6;
        let mut cache = counting_cache(3, rows, kv_dim);
        // Trim layer 0 to last 2 rows; layer 1 untouched; layer 2 to last 4.
        let spec = [Some((kv_dim, 2)), None, Some((kv_dim, 4))];
        cache.trim_sliding_window_per_layer(&spec).unwrap();

        // Layer 0 should now hold the LAST 2 rows (rows 4 and 5) of both K and V.
        assert_eq!(
            cache.layers_k[0],
            vec![16., 17., 18., 19., 20., 21., 22., 23.]
        );
        assert_eq!(
            cache.layers_v[0],
            vec![116., 117., 118., 119., 120., 121., 122., 123.]
        );

        assert_eq!(
            cache.layers_k[1].len(),
            rows * kv_dim,
            "untouched layer keeps full K history"
        );
        assert_eq!(
            cache.layers_v[1].len(),
            rows * kv_dim,
            "untouched layer keeps full V history"
        );

        // Layer 2 keeps rows 2..6, i.e. elements 8..24.
        assert_eq!(cache.layers_k[2].len(), 4 * kv_dim);
        assert_eq!(cache.layers_v[2].len(), 4 * kv_dim);
        assert_eq!(cache.layers_k[2][0], 8.0);
        assert_eq!(cache.layers_v[2][0], 108.0);

        // Trimming never rewinds the logical position.
        assert_eq!(cache.past_len, rows);
    }

    #[test]
    fn sliding_window_trim_no_op_when_under_window() {
        let kv_dim = 4;
        let rows = 3;
        let mut cache = counting_cache(1, rows, kv_dim);
        let before = cache.clone();
        cache
            .trim_sliding_window_per_layer(&[Some((kv_dim, 10))])
            .unwrap();
        assert_eq!(cache.layers_k, before.layers_k);
        assert_eq!(cache.layers_v, before.layers_v);
    }

    #[test]
    fn sliding_window_trim_skips_zero_window_and_zero_dim() {
        let kv_dim = 4;
        let rows = 3;
        let mut cache = counting_cache(2, rows, kv_dim);
        cache
            .trim_sliding_window_per_layer(&[Some((kv_dim, 0)), Some((0, 1))])
            .unwrap();
        assert_eq!(cache.layers_k[0].len(), rows * kv_dim);
        assert_eq!(cache.layers_k[1].len(), rows * kv_dim);
    }

    #[test]
    fn sliding_window_trim_rejects_spec_of_wrong_length() {
        let mut cache = counting_cache(2, 3, 4);
        assert!(cache.trim_sliding_window_per_layer(&[None]).is_err());
        // Nothing was trimmed on the error path.
        assert_eq!(cache.layers_k[0].len(), 12);
        assert_eq!(cache.layers_v[1].len(), 12);
    }
}