1use crate::autoregressive::{KvCacheState, compact_bucketed_kv_buffer, past_kv_input_names};
19use anyhow::{Context, Result, ensure};
20use rlx_ir::{Graph, hir::HirModule};
21use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, pad_rows};
22use rlx_runtime::kv_cache::LayerKvCache;
23use rlx_runtime::{CompileOptions, CompiledGraph, Device};
24use std::collections::HashMap;
25
26pub fn device_supports_gpu_kv(device: Device) -> bool {
28 matches!(
29 device,
30 Device::Mlx | Device::Metal | Device::Cuda | Device::Rocm | Device::Gpu | Device::Vulkan
31 )
32}
33
/// Bookkeeping for one set of GPU-resident KV handles.
///
/// The default value has `upper == 0`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct GpuKvBinding {
    /// Bucket upper bound recorded after the handles were last installed
    /// (written by `ensure_gpu_kv_bindings` in this file).
    pub upper: u64,
}
39
40#[derive(Debug, Default)]
42pub struct GpuKvCacheSet {
43 pub causal: GpuKvBinding,
44 pub decode_mtp: GpuKvBinding,
45 pub mtp: GpuKvBinding,
46}
47
48impl GpuKvCacheSet {
49 pub fn reset(&mut self) {
50 *self = Self::default();
51 }
52
53 pub fn reset_decode_after_mtp(&mut self) {
55 self.causal = GpuKvBinding::default();
56 self.decode_mtp = GpuKvBinding::default();
57 self.mtp = GpuKvBinding::default();
58 }
59}
60
61pub fn cross_attn_gpu_handles_ready(compiled: &CompiledGraph) -> bool {
63 compiled.has_gpu_handle("cross_k_0")
64}
65
66pub fn install_cross_attn_gpu_handles(
68 compiled: &mut CompiledGraph,
69 cross: &LayerKvCache,
70 enc_seq: usize,
71 kv_dim: usize,
72 num_layers: usize,
73) -> Result<()> {
74 let upper = enc_seq as u64;
75 for i in 0..num_layers {
76 let k_name = format!("cross_k_{i}");
77 let v_name = format!("cross_v_{i}");
78 let k_pad = pad_rows(cross.layers_k[i].as_slice(), kv_dim, upper);
79 let v_pad = pad_rows(cross.layers_v[i].as_slice(), kv_dim, upper);
80 ensure!(
81 compiled.bind_gpu_handle(k_name.as_str(), &k_pad),
82 "bind_gpu_handle failed for {k_name}"
83 );
84 ensure!(
85 compiled.bind_gpu_handle(v_name.as_str(), &v_pad),
86 "bind_gpu_handle failed for {v_name}"
87 );
88 }
89 Ok(())
90}
91
92pub fn install_gpu_kv_handles(
94 compiled: &mut CompiledGraph,
95 kv: &KvCacheState,
96 prefix_rows: usize,
97 upper: u64,
98 kv_dim: usize,
99 num_layers: usize,
100) -> Result<()> {
101 let names = past_kv_input_names(num_layers);
102 for layer in 0..num_layers {
103 let k_name = names[2 * layer].as_str();
104 let v_name = names[2 * layer + 1].as_str();
105 let n = prefix_rows * kv_dim;
106 let k_slice = &kv.layers_k[layer][..n.min(kv.layers_k[layer].len())];
107 let v_slice = &kv.layers_v[layer][..n.min(kv.layers_v[layer].len())];
108 let k_pad = pad_rows(k_slice, kv_dim, upper);
109 let v_pad = pad_rows(v_slice, kv_dim, upper);
110 ensure!(
111 compiled.bind_gpu_handle(k_name, &k_pad),
112 "bind_gpu_handle failed for {k_name}"
113 );
114 compiled.set_gpu_handle_feed(k_name, 1 + 2 * layer);
115 ensure!(
116 compiled.bind_gpu_handle(v_name, &v_pad),
117 "bind_gpu_handle failed for {v_name}"
118 );
119 compiled.set_gpu_handle_feed(v_name, 2 + 2 * layer);
120 }
121 Ok(())
122}
123
124fn layer_host_rows(
125 compiled: &CompiledGraph,
126 name: &str,
127 host: &[f32],
128 past_len: usize,
129 kv_dim: usize,
130) -> Vec<f32> {
131 if compiled.has_gpu_handle(name) {
132 if let Some(buf) = compiled.read_gpu_handle(name) {
133 return compact_bucketed_kv_buffer(&buf, past_len, kv_dim, 1);
134 }
135 }
136 let take = (past_len * kv_dim).min(host.len());
137 host[..take].to_vec()
138}
139
140pub fn reinstall_gpu_kv_handles(
142 compiled: &mut CompiledGraph,
143 kv: &KvCacheState,
144 _old_upper: u64,
145 new_upper: u64,
146 kv_dim: usize,
147 num_layers: usize,
148) -> Result<()> {
149 let names = past_kv_input_names(num_layers);
150 let mut tmp = KvCacheState {
151 past_len: kv.past_len,
152 layers_k: Vec::with_capacity(num_layers),
153 layers_v: Vec::with_capacity(num_layers),
154 };
155 for layer in 0..num_layers {
156 tmp.layers_k.push(layer_host_rows(
157 compiled,
158 &names[2 * layer],
159 &kv.layers_k[layer],
160 kv.past_len,
161 kv_dim,
162 ));
163 tmp.layers_v.push(layer_host_rows(
164 compiled,
165 &names[2 * layer + 1],
166 &kv.layers_v[layer],
167 kv.past_len,
168 kv_dim,
169 ));
170 }
171 install_gpu_kv_handles(compiled, &tmp, tmp.past_len, new_upper, kv_dim, num_layers)
172}
173
174pub fn sync_gpu_kv_to_host(
176 compiled: &CompiledGraph,
177 kv: &mut KvCacheState,
178 kv_dim: usize,
179 num_layers: usize,
180) -> Result<()> {
181 let names = past_kv_input_names(num_layers);
182 let n = kv.past_len * kv_dim;
183 for layer in 0..num_layers {
184 kv.layers_k[layer] = layer_host_rows(
185 compiled,
186 &names[2 * layer],
187 &kv.layers_k[layer],
188 kv.past_len,
189 kv_dim,
190 );
191 kv.layers_v[layer] = layer_host_rows(
192 compiled,
193 &names[2 * layer + 1],
194 &kv.layers_v[layer],
195 kv.past_len,
196 kv_dim,
197 );
198 if kv.layers_k[layer].len() > n {
199 kv.layers_k[layer].truncate(n);
200 }
201 if kv.layers_v[layer].len() > n {
202 kv.layers_v[layer].truncate(n);
203 }
204 }
205 Ok(())
206}
207
208fn ensure_gpu_kv_bindings(
209 compiled: &mut CompiledGraph,
210 kv: &KvCacheState,
211 binding: &mut GpuKvBinding,
212 upper: u64,
213 kv_dim: usize,
214 num_layers: usize,
215 refresh_kv: bool,
216) -> Result<()> {
217 let names = past_kv_input_names(num_layers);
218 let handles_live = compiled.has_gpu_handle(names[0].as_str());
219 if refresh_kv || !handles_live || binding.upper != upper {
220 install_gpu_kv_handles(compiled, kv, kv.past_len, upper, kv_dim, num_layers)?;
221 binding.upper = upper;
222 }
223 Ok(())
224}
225
226pub fn run_bucketed_kv_decode_gpu<F>(
230 cache: &mut BucketedCompileCache,
231 cache_key: u64,
232 past_seq: usize,
233 kv: &mut KvCacheState,
234 binding: &mut GpuKvBinding,
235 kv_dim: usize,
236 num_layers: usize,
237 fixed_inputs: &[CacheRunInput<'_>],
238 build: F,
239 options: &CompileOptions,
240 refresh_kv: bool,
241) -> Result<Vec<f32>>
242where
243 F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
244{
245 let (upper, compiled) = cache
246 .ensure_graph_with_params(cache_key, build, options)
247 .ok_or_else(|| anyhow::anyhow!("cache_key {cache_key} outside decode buckets"))?;
248
249 ensure_gpu_kv_bindings(compiled, kv, binding, upper, kv_dim, num_layers, refresh_kv)?;
250
251 let mut pairs: Vec<(&str, &[f32])> = Vec::with_capacity(fixed_inputs.len());
252 for inp in fixed_inputs {
253 pairs.push((inp.name, inp.data));
254 }
255
256 if compiled.device() != Device::Metal {
258 compiled.set_active_extent(Some((upper as usize + 1, upper as usize + 1)));
259 }
260 let outs = compiled.run_read_outputs(&pairs, Some(&[0]));
261 compiled.set_active_extent(None);
262
263 let logits = outs
264 .into_iter()
265 .next()
266 .context("gpu kv decode: missing logits output")?;
267 kv.past_len = past_seq + 1;
268 Ok(logits)
269}
270
271pub fn run_bucketed_kv_decode_gpu_hir<F>(
273 cache: &mut BucketedCompileCache,
274 cache_key: u64,
275 past_seq: usize,
276 kv: &mut KvCacheState,
277 binding: &mut GpuKvBinding,
278 kv_dim: usize,
279 num_layers: usize,
280 fixed_inputs: &[CacheRunInput<'_>],
281 build: F,
282 options: &CompileOptions,
283 refresh_kv: bool,
284) -> Result<Vec<f32>>
285where
286 F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
287{
288 let (upper, compiled) = cache
289 .ensure_hir_with_params(cache_key, build, options)
290 .ok_or_else(|| anyhow::anyhow!("cache_key {cache_key} outside decode buckets"))?;
291
292 ensure_gpu_kv_bindings(compiled, kv, binding, upper, kv_dim, num_layers, refresh_kv)?;
293
294 let mut pairs: Vec<(&str, &[f32])> = Vec::with_capacity(fixed_inputs.len());
295 for inp in fixed_inputs {
296 pairs.push((inp.name, inp.data));
297 }
298
299 if compiled.device() != Device::Metal {
300 compiled.set_active_extent(Some((upper as usize + 1, upper as usize + 1)));
301 }
302 let outs = compiled.run_read_outputs(&pairs, Some(&[0]));
303 compiled.set_active_extent(None);
304
305 let logits = outs
306 .into_iter()
307 .next()
308 .context("gpu kv decode: missing logits output")?;
309 kv.past_len = past_seq + 1;
310 Ok(logits)
311}
312
313pub fn run_bucketed_kv_mtp_gpu<F>(
315 cache: &mut BucketedCompileCache,
316 past_len: usize,
317 q_len: usize,
318 kv: &KvCacheState,
319 binding: &mut GpuKvBinding,
320 kv_dim: usize,
321 num_layers: usize,
322 fixed_inputs: &[CacheRunInput<'_>],
323 build: F,
324 options: &CompileOptions,
325) -> Result<Vec<f32>>
326where
327 F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
328{
329 let key = past_len as u64;
330 let (upper, compiled) = cache
331 .ensure_graph_with_params(key, build, options)
332 .ok_or_else(|| anyhow::anyhow!("past_len {past_len} outside MTP buckets"))?;
333
334 ensure_gpu_kv_bindings(compiled, kv, binding, upper, kv_dim, num_layers, false)?;
335 let actual_kv = past_len + q_len;
336 let upper_kv = upper as usize + q_len;
337 let mut pairs: Vec<(&str, &[f32])> = Vec::with_capacity(fixed_inputs.len());
338 for inp in fixed_inputs {
339 pairs.push((inp.name, inp.data));
340 }
341 compiled.set_active_extent(Some((actual_kv, upper_kv)));
342 let outs = compiled.run_read_outputs(&pairs, Some(&[0]));
343 compiled.set_active_extent(None);
344
345 outs.into_iter()
346 .next()
347 .context("gpu kv mtp: missing logits output")
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autoregressive::compact_bucketed_kv_buffer;
    use rlx_runtime::Device;

    /// The listed GPU backends are accepted and the CPU is rejected.
    #[test]
    fn gpu_kv_supported_backends() {
        let gpu_backends = [
            Device::Mlx,
            Device::Metal,
            Device::Cuda,
            Device::Gpu,
            Device::Rocm,
        ];
        for device in gpu_backends {
            assert!(device_supports_gpu_kv(device));
        }
        assert!(!device_supports_gpu_kv(Device::Cpu));
    }

    /// Compacting a four-row buffer with `past_len = 3` keeps the first two
    /// rows and the last row, dropping the zero row in between.
    #[test]
    fn compact_bucketed_kv_skips_middle_padding() {
        let kv_dim = 2;
        let rows: [[f32; 2]; 4] = [[1.0, 1.1], [2.0, 2.1], [0.0, 0.0], [9.0, 9.1]];
        let buf: Vec<f32> = rows.iter().flatten().copied().collect();

        let out = compact_bucketed_kv_buffer(&buf, 3, kv_dim, 1);

        let expected: Vec<f32> = vec![1.0, 1.1, 2.0, 2.1, 9.0, 9.1];
        assert_eq!(out, expected);
    }
}