1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//! Per-layer K/V cache for autoregressive decode (Whisper, Qwen, Gemma, …).
use crate::compile_cache::pad_rows;
/// Layer-wise past K/V tensors in row-major `[past_len * kv_dim]` layout per layer.
///
/// `layers_k[i]` and `layers_v[i]` are the flat key and value buffers of layer `i`.
///
/// NOTE(review): `from_layer_outputs` accepts tensors of `batch * past_seq * kv_dim`
/// elements, so for `batch > 1` the stored buffers presumably carry a batch axis as
/// well; confirm the intended layout against the callers.
#[derive(Debug, Clone)]
pub struct LayerKvCache {
/// Number of sequence positions currently cached; set from `past_seq` by
/// `from_layer_outputs` and incremented by one on every successful
/// `advance_from_decode_outputs`.
pub past_len: usize,
/// Flat key buffer for each layer, indexed by layer.
pub layers_k: Vec<Vec<f32>>,
/// Flat value buffer for each layer, indexed by layer.
pub layers_v: Vec<Vec<f32>>,
}
impl LayerKvCache {
    /// Build a cache from outputs laid out as `[k0, v0, k1, v1, …]`.
    ///
    /// Every tensor must hold exactly `batch * past_seq * kv_dim` elements. Each
    /// tensor is copied into the cache and `past_len` is set to `past_seq`.
    ///
    /// # Errors
    ///
    /// Returns `Err` when
    /// * `outputs.len()` is not `2 * num_layers`,
    /// * any K or V tensor does not have the expected element count, or
    /// * the expected counts do not fit in `usize`.
    pub fn from_layer_outputs(
        num_layers: usize,
        batch: usize,
        past_seq: usize,
        kv_dim: usize,
        outputs: &[Vec<f32>],
    ) -> Result<Self, String> {
        // Checked arithmetic: a wrapped product would let mismatched shapes pass
        // validation in release builds and panic in debug builds.
        let expected_tensors = num_layers.checked_mul(2).ok_or_else(|| {
            format!("from_layer_outputs: 2 * num_layers ({num_layers}) overflows usize")
        })?;
        if outputs.len() != expected_tensors {
            return Err(format!(
                "from_layer_outputs: expected {} K/V tensors, got {}",
                expected_tensors,
                outputs.len()
            ));
        }
        let expected = batch
            .checked_mul(past_seq)
            .and_then(|rows| rows.checked_mul(kv_dim))
            .ok_or_else(|| {
                format!(
                    "from_layer_outputs: batch {batch} * past_seq {past_seq} * kv_dim {kv_dim} overflows usize"
                )
            })?;

        let mut layers_k = Vec::with_capacity(num_layers);
        let mut layers_v = Vec::with_capacity(num_layers);
        // The length check above guarantees exactly `num_layers` complete (K, V) pairs.
        for (layer, pair) in outputs.chunks_exact(2).enumerate() {
            let (k, v) = (&pair[0], &pair[1]);
            if k.len() != expected || v.len() != expected {
                return Err(format!(
                    "layer {layer}: k.len={} v.len={} expected {expected}",
                    k.len(),
                    v.len()
                ));
            }
            layers_k.push(k.clone());
            layers_v.push(v.clone());
        }
        Ok(Self {
            past_len: past_seq,
            layers_k,
            layers_v,
        })
    }

    /// Pad each layer's K/V to `upper` rows along the sequence axis (`kv_dim` inner).
    ///
    /// Each buffer is passed to `pad_rows(buffer, kv_dim, upper)`; the exact padding
    /// behaviour is defined by that helper. The cache itself is not modified.
    ///
    /// Returns `(padded_k, padded_v)`, each in layer order.
    pub fn pad_layers_to_upper(&self, upper: u64, kv_dim: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
        let pad_all = |layers: &[Vec<f32>]| -> Vec<Vec<f32>> {
            layers
                .iter()
                .map(|tensor| pad_rows(tensor, kv_dim, upper))
                .collect()
        };
        (pad_all(&self.layers_k), pad_all(&self.layers_v))
    }

    /// Update cache from decode outputs: `[logits, k0, v0, k1, v1, …]` (bucket-padded).
    ///
    /// The logits tensor is discarded. Each K/V tensor is cut down to its first
    /// `(past_len + 1) * kv_dim` elements and replaces the corresponding layer
    /// buffer; `past_len` is then incremented by one. `_batch` is currently unused.
    ///
    /// The update is all-or-nothing: on error the cache is left untouched.
    ///
    /// # Errors
    ///
    /// Returns `Err` when
    /// * `layers_k` and `layers_v` have different lengths,
    /// * `outputs.len()` is not `1 + 2 * layers`,
    /// * any K or V tensor holds fewer than `(past_len + 1) * kv_dim` elements, or
    /// * the new length does not fit in `usize`.
    pub fn advance_from_decode_outputs(
        &mut self,
        outputs: Vec<Vec<f32>>,
        _batch: usize,
        kv_dim: usize,
    ) -> Result<(), String> {
        let n = self.layers_k.len();
        // The fields are public, so the two lists can drift apart; reject rather
        // than panic on an out-of-bounds index.
        if self.layers_v.len() != n {
            return Err(format!(
                "advance_from_decode_outputs: cache has {} K layers but {} V layers",
                n,
                self.layers_v.len()
            ));
        }
        if outputs.len() != 1 + 2 * n {
            return Err(format!(
                "advance_from_decode_outputs: expected {} outputs, got {}",
                1 + 2 * n,
                outputs.len()
            ));
        }
        let new_len = self
            .past_len
            .checked_add(1)
            .ok_or("advance_from_decode_outputs: past_len overflows usize")?;
        let real_len = new_len.checked_mul(kv_dim).ok_or_else(|| {
            format!("advance_from_decode_outputs: {new_len} * kv_dim {kv_dim} overflows usize")
        })?;

        let mut tensors = outputs.into_iter();
        let _logits = tensors.next().ok_or("missing logits")?;

        // Stage the new buffers first so a failure part-way through cannot leave
        // the cache half-updated.
        let mut next_k = Vec::with_capacity(n);
        let mut next_v = Vec::with_capacity(n);
        for layer in 0..n {
            let mut k = tensors.next().ok_or("missing k")?;
            let mut v = tensors.next().ok_or("missing v")?;
            // A tensor shorter than the real length cannot hold the new position;
            // storing it would desynchronise the buffers from `past_len`.
            if k.len() < real_len || v.len() < real_len {
                return Err(format!(
                    "advance_from_decode_outputs: layer {layer}: k.len={} v.len={} need at least {real_len}",
                    k.len(),
                    v.len()
                ));
            }
            // The tensors are owned, so drop the bucket padding in place instead of
            // copying the whole cache each step (the padded capacity is retained).
            k.truncate(real_len);
            v.truncate(real_len);
            next_k.push(k);
            next_v.push(v);
        }

        self.layers_k = next_k;
        self.layers_v = next_v;
        self.past_len = new_len;
        Ok(())
    }
}