Skip to main content

rlx_runtime/
kv_cache.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Per-layer K/V cache for autoregressive decode (Whisper, Qwen, Gemma, …).

18use crate::compile_cache::pad_rows;
19
/// Layer-wise past K/V tensors for autoregressive decode.
///
/// `layers_k[i]` and `layers_v[i]` are the flat, row-major key and value buffers of
/// layer `i`, with the sequence axis outer and `kv_dim` values per row. With batch size 1
/// each buffer holds `past_len * kv_dim` values.
///
/// NOTE(review): [`LayerKvCache::from_layer_outputs`] accepts buffers of
/// `batch * past_len * kv_dim` values, but [`LayerKvCache::advance_from_decode_outputs`]
/// ignores its batch argument and keeps `(past_len + 1) * kv_dim` values. Batch sizes
/// above 1 are therefore not handled consistently; confirm callers only use batch == 1.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerKvCache {
    /// Number of cached sequence positions (rows) per layer.
    pub past_len: usize,
    /// Key buffer of each layer, indexed by layer.
    pub layers_k: Vec<Vec<f32>>,
    /// Value buffer of each layer; expected to have as many layers as `layers_k`.
    pub layers_v: Vec<Vec<f32>>,
}

28impl LayerKvCache {
29    pub fn from_layer_outputs(
30        num_layers: usize,
31        batch: usize,
32        past_seq: usize,
33        kv_dim: usize,
34        outputs: &[Vec<f32>],
35    ) -> Result<Self, String> {
36        if outputs.len() != 2 * num_layers {
37            return Err(format!(
38                "from_layer_outputs: expected {} K/V tensors, got {}",
39                2 * num_layers,
40                outputs.len()
41            ));
42        }
43        let expected = batch * past_seq * kv_dim;
44        let mut layers_k = Vec::with_capacity(num_layers);
45        let mut layers_v = Vec::with_capacity(num_layers);
46        for layer in 0..num_layers {
47            let k = &outputs[2 * layer];
48            let v = &outputs[2 * layer + 1];
49            if k.len() != expected || v.len() != expected {
50                return Err(format!(
51                    "layer {layer}: k.len={} v.len={} expected {expected}",
52                    k.len(),
53                    v.len()
54                ));
55            }
56            layers_k.push(k.clone());
57            layers_v.push(v.clone());
58        }
59        Ok(Self {
60            past_len: past_seq,
61            layers_k,
62            layers_v,
63        })
64    }
65
66    /// Pad each layer's K/V to `upper` rows along the sequence axis (`kv_dim` inner).
67    pub fn pad_layers_to_upper(&self, upper: u64, kv_dim: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
68        let padded_k = self
69            .layers_k
70            .iter()
71            .map(|k| pad_rows(k, kv_dim, upper))
72            .collect();
73        let padded_v = self
74            .layers_v
75            .iter()
76            .map(|v| pad_rows(v, kv_dim, upper))
77            .collect();
78        (padded_k, padded_v)
79    }
80
81    /// Update cache from decode outputs: `[logits, k0, v0, k1, v1, …]` (bucket-padded).
82    pub fn advance_from_decode_outputs(
83        &mut self,
84        outputs: Vec<Vec<f32>>,
85        _batch: usize,
86        kv_dim: usize,
87    ) -> Result<(), String> {
88        let n = self.layers_k.len();
89        if outputs.len() != 1 + 2 * n {
90            return Err(format!(
91                "advance_from_decode_outputs: expected {} outputs, got {}",
92                1 + 2 * n,
93                outputs.len()
94            ));
95        }
96        let new_len = self.past_len + 1;
97        let real_len = new_len * kv_dim;
98        let mut iter = outputs.into_iter();
99        let _logits = iter.next().ok_or("missing logits")?;
100        for i in 0..n {
101            let k = iter.next().ok_or("missing k")?;
102            let v = iter.next().ok_or("missing v")?;
103            self.layers_k[i] = k[..real_len.min(k.len())].to_vec();
104            self.layers_v[i] = v[..real_len.min(v.len())].to_vec();
105        }
106        self.past_len = new_len;
107        Ok(())
108    }
109}