Skip to main content

rlx_models_core/
gpu_kv.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! GPU-resident KV cache via `bind_gpu_handle` (MLX device arrays; Metal/CUDA/WGPU host mirrors).
17
18use crate::autoregressive::{KvCacheState, compact_bucketed_kv_buffer, past_kv_input_names};
19use anyhow::{Context, Result, ensure};
20use rlx_ir::{Graph, hir::HirModule};
21use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, pad_rows};
22use rlx_runtime::kv_cache::LayerKvCache;
23use rlx_runtime::{CompileOptions, CompiledGraph, Device};
24use std::collections::HashMap;
25
26/// Backends that support persistent K/V handles + selective logits readback.
27pub fn device_supports_gpu_kv(device: Device) -> bool {
28    matches!(
29        device,
30        Device::Mlx | Device::Metal | Device::Cuda | Device::Rocm | Device::Gpu | Device::Vulkan
31    )
32}
33
/// Remembers the bucket upper bound that the GPU K/V handles were allocated for.
///
/// The default value carries `upper == 0`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct GpuKvBinding {
    /// Bucket upper bound recorded when the handles were last installed.
    pub upper: u64,
}
39
40/// Per compile-cache GPU binding state.
41#[derive(Debug, Default)]
42pub struct GpuKvCacheSet {
43    pub causal: GpuKvBinding,
44    pub decode_mtp: GpuKvBinding,
45    pub mtp: GpuKvBinding,
46}
47
48impl GpuKvCacheSet {
49    pub fn reset(&mut self) {
50        *self = Self::default();
51    }
52
53    /// Drop decode bindings after an MTP block advanced `past_len`.
54    pub fn reset_decode_after_mtp(&mut self) {
55        self.causal = GpuKvBinding::default();
56        self.decode_mtp = GpuKvBinding::default();
57        self.mtp = GpuKvBinding::default();
58    }
59}
60
61/// True when `cross_k_0` is already bound on this graph.
62pub fn cross_attn_gpu_handles_ready(compiled: &CompiledGraph) -> bool {
63    compiled.has_gpu_handle("cross_k_0")
64}
65
66/// Upload fixed cross-attention K/V (`cross_k_*` / `cross_v_*`) for Whisper-style decoders.
67pub fn install_cross_attn_gpu_handles(
68    compiled: &mut CompiledGraph,
69    cross: &LayerKvCache,
70    enc_seq: usize,
71    kv_dim: usize,
72    num_layers: usize,
73) -> Result<()> {
74    let upper = enc_seq as u64;
75    for i in 0..num_layers {
76        let k_name = format!("cross_k_{i}");
77        let v_name = format!("cross_v_{i}");
78        let k_pad = pad_rows(cross.layers_k[i].as_slice(), kv_dim, upper);
79        let v_pad = pad_rows(cross.layers_v[i].as_slice(), kv_dim, upper);
80        ensure!(
81            compiled.bind_gpu_handle(k_name.as_str(), &k_pad),
82            "bind_gpu_handle failed for {k_name}"
83        );
84        ensure!(
85            compiled.bind_gpu_handle(v_name.as_str(), &v_pad),
86            "bind_gpu_handle failed for {v_name}"
87        );
88    }
89    Ok(())
90}
91
92/// Upload `prefix_rows` of `kv` into `past_k_*` / `past_v_*` GPU handles and wire output feeds.
93pub fn install_gpu_kv_handles(
94    compiled: &mut CompiledGraph,
95    kv: &KvCacheState,
96    prefix_rows: usize,
97    upper: u64,
98    kv_dim: usize,
99    num_layers: usize,
100) -> Result<()> {
101    let names = past_kv_input_names(num_layers);
102    for layer in 0..num_layers {
103        let k_name = names[2 * layer].as_str();
104        let v_name = names[2 * layer + 1].as_str();
105        let n = prefix_rows * kv_dim;
106        let k_slice = &kv.layers_k[layer][..n.min(kv.layers_k[layer].len())];
107        let v_slice = &kv.layers_v[layer][..n.min(kv.layers_v[layer].len())];
108        let k_pad = pad_rows(k_slice, kv_dim, upper);
109        let v_pad = pad_rows(v_slice, kv_dim, upper);
110        ensure!(
111            compiled.bind_gpu_handle(k_name, &k_pad),
112            "bind_gpu_handle failed for {k_name}"
113        );
114        compiled.set_gpu_handle_feed(k_name, 1 + 2 * layer);
115        ensure!(
116            compiled.bind_gpu_handle(v_name, &v_pad),
117            "bind_gpu_handle failed for {v_name}"
118        );
119        compiled.set_gpu_handle_feed(v_name, 2 + 2 * layer);
120    }
121    Ok(())
122}
123
124fn layer_host_rows(
125    compiled: &CompiledGraph,
126    name: &str,
127    host: &[f32],
128    past_len: usize,
129    kv_dim: usize,
130) -> Vec<f32> {
131    if compiled.has_gpu_handle(name) {
132        if let Some(buf) = compiled.read_gpu_handle(name) {
133            return compact_bucketed_kv_buffer(&buf, past_len, kv_dim, 1);
134        }
135    }
136    let take = (past_len * kv_dim).min(host.len());
137    host[..take].to_vec()
138}
139
140/// Rebind handles after a bucket change (read back prior GPU K/V, pad to new `upper`).
141pub fn reinstall_gpu_kv_handles(
142    compiled: &mut CompiledGraph,
143    kv: &KvCacheState,
144    _old_upper: u64,
145    new_upper: u64,
146    kv_dim: usize,
147    num_layers: usize,
148) -> Result<()> {
149    let names = past_kv_input_names(num_layers);
150    let mut tmp = KvCacheState {
151        past_len: kv.past_len,
152        layers_k: Vec::with_capacity(num_layers),
153        layers_v: Vec::with_capacity(num_layers),
154    };
155    for layer in 0..num_layers {
156        tmp.layers_k.push(layer_host_rows(
157            compiled,
158            &names[2 * layer],
159            &kv.layers_k[layer],
160            kv.past_len,
161            kv_dim,
162        ));
163        tmp.layers_v.push(layer_host_rows(
164            compiled,
165            &names[2 * layer + 1],
166            &kv.layers_v[layer],
167            kv.past_len,
168            kv_dim,
169        ));
170    }
171    install_gpu_kv_handles(compiled, &tmp, tmp.past_len, new_upper, kv_dim, num_layers)
172}
173
174/// Pull GPU K/V back to host `kv` (for MTP truncate / prefill cache).
175pub fn sync_gpu_kv_to_host(
176    compiled: &CompiledGraph,
177    kv: &mut KvCacheState,
178    kv_dim: usize,
179    num_layers: usize,
180) -> Result<()> {
181    let names = past_kv_input_names(num_layers);
182    let n = kv.past_len * kv_dim;
183    for layer in 0..num_layers {
184        kv.layers_k[layer] = layer_host_rows(
185            compiled,
186            &names[2 * layer],
187            &kv.layers_k[layer],
188            kv.past_len,
189            kv_dim,
190        );
191        kv.layers_v[layer] = layer_host_rows(
192            compiled,
193            &names[2 * layer + 1],
194            &kv.layers_v[layer],
195            kv.past_len,
196            kv_dim,
197        );
198        if kv.layers_k[layer].len() > n {
199            kv.layers_k[layer].truncate(n);
200        }
201        if kv.layers_v[layer].len() > n {
202            kv.layers_v[layer].truncate(n);
203        }
204    }
205    Ok(())
206}
207
208fn ensure_gpu_kv_bindings(
209    compiled: &mut CompiledGraph,
210    kv: &KvCacheState,
211    binding: &mut GpuKvBinding,
212    upper: u64,
213    kv_dim: usize,
214    num_layers: usize,
215    refresh_kv: bool,
216) -> Result<()> {
217    let names = past_kv_input_names(num_layers);
218    let handles_live = compiled.has_gpu_handle(names[0].as_str());
219    if refresh_kv || !handles_live || binding.upper != upper {
220        install_gpu_kv_handles(compiled, kv, kv.past_len, upper, kv_dim, num_layers)?;
221        binding.upper = upper;
222    }
223    Ok(())
224}
225
226/// One bucketed decode step with GPU-resident K/V (output 0 readback — logits or hidden_states).
227///
228/// `cache_key` indexes the compile bucket (for batch>1 use `(batch << 32) | past_seq`).
229pub fn run_bucketed_kv_decode_gpu<F>(
230    cache: &mut BucketedCompileCache,
231    cache_key: u64,
232    past_seq: usize,
233    kv: &mut KvCacheState,
234    binding: &mut GpuKvBinding,
235    kv_dim: usize,
236    num_layers: usize,
237    fixed_inputs: &[CacheRunInput<'_>],
238    build: F,
239    options: &CompileOptions,
240    refresh_kv: bool,
241) -> Result<Vec<f32>>
242where
243    F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
244{
245    let (upper, compiled) = cache
246        .ensure_graph_with_params(cache_key, build, options)
247        .ok_or_else(|| anyhow::anyhow!("cache_key {cache_key} outside decode buckets"))?;
248
249    ensure_gpu_kv_bindings(compiled, kv, binding, upper, kv_dim, num_layers, refresh_kv)?;
250
251    let mut pairs: Vec<(&str, &[f32])> = Vec::with_capacity(fixed_inputs.len());
252    for inp in fixed_inputs {
253        pairs.push((inp.name, inp.data));
254    }
255
256    // Metal: skip active extent (CPU ignores it; Metal SDPA scaling breaks bucketed mask).
257    if compiled.device() != Device::Metal {
258        compiled.set_active_extent(Some((upper as usize + 1, upper as usize + 1)));
259    }
260    let outs = compiled.run_read_outputs(&pairs, Some(&[0]));
261    compiled.set_active_extent(None);
262
263    let logits = outs
264        .into_iter()
265        .next()
266        .context("gpu kv decode: missing logits output")?;
267    kv.past_len = past_seq + 1;
268    Ok(logits)
269}
270
271/// HIR variant of [`run_bucketed_kv_decode_gpu`] (stable Metal bucketed decode).
272pub fn run_bucketed_kv_decode_gpu_hir<F>(
273    cache: &mut BucketedCompileCache,
274    cache_key: u64,
275    past_seq: usize,
276    kv: &mut KvCacheState,
277    binding: &mut GpuKvBinding,
278    kv_dim: usize,
279    num_layers: usize,
280    fixed_inputs: &[CacheRunInput<'_>],
281    build: F,
282    options: &CompileOptions,
283    refresh_kv: bool,
284) -> Result<Vec<f32>>
285where
286    F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
287{
288    let (upper, compiled) = cache
289        .ensure_hir_with_params(cache_key, build, options)
290        .ok_or_else(|| anyhow::anyhow!("cache_key {cache_key} outside decode buckets"))?;
291
292    ensure_gpu_kv_bindings(compiled, kv, binding, upper, kv_dim, num_layers, refresh_kv)?;
293
294    let mut pairs: Vec<(&str, &[f32])> = Vec::with_capacity(fixed_inputs.len());
295    for inp in fixed_inputs {
296        pairs.push((inp.name, inp.data));
297    }
298
299    if compiled.device() != Device::Metal {
300        compiled.set_active_extent(Some((upper as usize + 1, upper as usize + 1)));
301    }
302    let outs = compiled.run_read_outputs(&pairs, Some(&[0]));
303    compiled.set_active_extent(None);
304
305    let logits = outs
306        .into_iter()
307        .next()
308        .context("gpu kv decode: missing logits output")?;
309    kv.past_len = past_seq + 1;
310    Ok(logits)
311}
312
313/// MTP query block with GPU-resident prefix K/V (logits slab readback only).
314pub fn run_bucketed_kv_mtp_gpu<F>(
315    cache: &mut BucketedCompileCache,
316    past_len: usize,
317    q_len: usize,
318    kv: &KvCacheState,
319    binding: &mut GpuKvBinding,
320    kv_dim: usize,
321    num_layers: usize,
322    fixed_inputs: &[CacheRunInput<'_>],
323    build: F,
324    options: &CompileOptions,
325) -> Result<Vec<f32>>
326where
327    F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
328{
329    let key = past_len as u64;
330    let (upper, compiled) = cache
331        .ensure_graph_with_params(key, build, options)
332        .ok_or_else(|| anyhow::anyhow!("past_len {past_len} outside MTP buckets"))?;
333
334    ensure_gpu_kv_bindings(compiled, kv, binding, upper, kv_dim, num_layers, false)?;
335    let actual_kv = past_len + q_len;
336    let upper_kv = upper as usize + q_len;
337    let mut pairs: Vec<(&str, &[f32])> = Vec::with_capacity(fixed_inputs.len());
338    for inp in fixed_inputs {
339        pairs.push((inp.name, inp.data));
340    }
341    compiled.set_active_extent(Some((actual_kv, upper_kv)));
342    let outs = compiled.run_read_outputs(&pairs, Some(&[0]));
343    compiled.set_active_extent(None);
344
345    outs.into_iter()
346        .next()
347        .context("gpu kv mtp: missing logits output")
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autoregressive::compact_bucketed_kv_buffer;
    use rlx_runtime::Device;

    /// Every backend listed in `device_supports_gpu_kv` is accepted; the CPU backend is not.
    #[test]
    fn gpu_kv_supported_backends() {
        assert!(device_supports_gpu_kv(Device::Mlx));
        assert!(device_supports_gpu_kv(Device::Metal));
        assert!(device_supports_gpu_kv(Device::Cuda));
        assert!(device_supports_gpu_kv(Device::Gpu));
        assert!(device_supports_gpu_kv(Device::Rocm));
        // Vulkan is in the supported list but was previously not covered here.
        assert!(device_supports_gpu_kv(Device::Vulkan));
        assert!(!device_supports_gpu_kv(Device::Cpu));
    }

    /// `reset` returns every binding to its default value.
    #[test]
    fn cache_set_reset_clears_every_binding() {
        let mut set = GpuKvCacheSet {
            causal: GpuKvBinding { upper: 8 },
            decode_mtp: GpuKvBinding { upper: 16 },
            mtp: GpuKvBinding { upper: 32 },
        };
        set.reset();
        assert_eq!(set.causal, GpuKvBinding::default());
        assert_eq!(set.decode_mtp, GpuKvBinding::default());
        assert_eq!(set.mtp, GpuKvBinding::default());
    }

    #[test]
    fn compact_bucketed_kv_skips_middle_padding() {
        let kv_dim = 2;
        // past_len=3: rows 0,1 real; row 2 padding; row 3 (upper) new token.
        let buf = vec![
            1.0, 1.1, //
            2.0, 2.1, //
            0.0, 0.0, // padding
            9.0, 9.1, // new K at upper
        ];
        let out = compact_bucketed_kv_buffer(&buf, 3, kv_dim, 1);
        assert_eq!(out, vec![1.0, 1.1, 2.0, 2.1, 9.0, 9.1]);
    }
}
379}