Skip to main content

rlx_locateanything/
session_cache.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Compile caches for LM prefill / MTP / decode.
17
18use crate::compile_support::{
19    lm_active_extent_enabled, lm_decode_compile_options, lm_gpu_kv_enabled, lm_host_device,
20    metal_lm_compile_guard,
21};
22use crate::config::LocateAnythingConfig;
23use crate::kv_buckets::locateanything_kv_bucket_ranges_for_device;
24use crate::lm_flow::{
25    build_locateanything_decode_built_ext, build_locateanything_mtp_kv_built,
26    build_locateanything_prefill_built,
27};
28use crate::load::LocateAnythingWeightStore;
29use crate::mask::{attn_bias_for_incremental_padded, mtp_decode_mask_padded};
30use crate::weights::CheckpointLmWeightLoader;
31use anyhow::Result;
32use rlx_core::flow_util::{
33    bucket_cache_ensure_built, compile_cache_ensure_built, graph_from_built,
34};
35use rlx_core::{
36    GpuKvCacheSet, KvCacheState, prefill_cache_key, run_bucketed_kv_decode,
37    run_bucketed_kv_decode_gpu, run_bucketed_kv_mtp_gpu, sync_gpu_kv_to_host,
38};
39use std::sync::Arc;
40
41use rlx_runtime::Device;
42use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
43
44fn session_use_gpu_kv(device: Device) -> bool {
45    lm_gpu_kv_enabled(device)
46}
47
48fn session_set_active_extent(
49    compiled: &mut rlx_runtime::CompiledGraph,
50    device: Device,
51    extent: (usize, usize),
52) {
53    if lm_active_extent_enabled(device) {
54        compiled.set_active_extent(Some(extent));
55    }
56}
57
58fn session_clear_active_extent(compiled: &mut rlx_runtime::CompiledGraph, device: Device) {
59    if lm_active_extent_enabled(device) {
60        compiled.set_active_extent(None);
61    }
62}
63
64/// Cached LM graphs + weight snapshot for one runner device.
65pub struct LmSessionCaches {
66    lm_store: Option<Arc<LocateAnythingWeightStore>>,
67    pub projector: std::collections::HashMap<usize, rlx_runtime::CompiledGraph>,
68    _device: Device,
69    prefill: CompileCache,
70    decode_causal: BucketedCompileCache,
71    decode_mtp: BucketedCompileCache,
72    mtp: BucketedCompileCache,
73    #[allow(dead_code)]
74    max_past: usize,
75    compile_opts_decode: rlx_runtime::CompileOptions,
76    device: Device,
77    gpu_kv: GpuKvCacheSet,
78}
79
80impl LmSessionCaches {
81    pub fn new(device: Device, max_past: usize) -> Self {
82        let max_past = max_past.max(1);
83        let host = lm_host_device(device);
84        let bucket_ranges = locateanything_kv_bucket_ranges_for_device(host, max_past);
85        Self {
86            lm_store: None,
87            projector: std::collections::HashMap::new(),
88            _device: device,
89            device,
90            prefill: CompileCache::new(host, 8),
91            decode_causal: BucketedCompileCache::new(host, bucket_ranges.clone()),
92            decode_mtp: BucketedCompileCache::new(host, bucket_ranges.clone()),
93            mtp: BucketedCompileCache::new(host, bucket_ranges),
94            max_past,
95            compile_opts_decode: lm_decode_compile_options(host),
96            gpu_kv: GpuKvCacheSet::default(),
97        }
98    }
99
100    /// Invalidate GPU K/V handle bindings (after MTP block or host KV rewrite).
101    pub fn reset_gpu_kv(&mut self) {
102        self.gpu_kv.reset();
103    }
104
105    /// Drop decode/MTP GPU bindings after an MTP block advanced `past_len`.
106    pub fn reset_decode_after_mtp(&mut self) {
107        self.gpu_kv.reset_decode_after_mtp();
108    }
109
110    /// Copy GPU-resident K/V into `kv` before a host-path MTP forward.
111    pub fn sync_kv_from_gpu(
112        &mut self,
113        cfg: &LocateAnythingConfig,
114        past_len: usize,
115        kv: &mut KvCacheState,
116    ) -> Result<()> {
117        if !session_use_gpu_kv(self.device) {
118            return Ok(());
119        }
120        let layers = cfg.text_config.num_hidden_layers;
121        let kv_dim = cfg.text_config.num_key_value_heads * cfg.text_config.head_dim();
122        let keys = if past_len > 0 {
123            [past_len as u64 - 1, past_len as u64]
124        } else {
125            [0, 0]
126        };
127        for &key in &keys {
128            let compiled = if self.gpu_kv.causal.upper != 0 {
129                self.decode_causal.compiled_for_key_mut(key)
130            } else if self.gpu_kv.mtp.upper != 0 {
131                self.mtp.compiled_for_key_mut(key)
132            } else {
133                return Ok(());
134            };
135            if let Some(compiled) = compiled {
136                return sync_gpu_kv_to_host(compiled, kv, kv_dim, layers);
137            }
138        }
139        Ok(())
140    }
141
142    /// Pin mmap-backed LM weights for compile caches (no full f32 snapshot in RAM).
143    pub fn ensure_lm_store(
144        &mut self,
145        store: Arc<LocateAnythingWeightStore>,
146    ) -> Arc<LocateAnythingWeightStore> {
147        if self.lm_store.is_none() {
148            self.lm_store = Some(store);
149        }
150        Arc::clone(self.lm_store.as_ref().expect("lm store"))
151    }
152
153    fn lm_loader(store: &Arc<LocateAnythingWeightStore>) -> CheckpointLmWeightLoader {
154        CheckpointLmWeightLoader::new(Arc::clone(store))
155    }
156
157    pub fn projector_graph(
158        &mut self,
159        n_tokens: usize,
160        build: impl FnOnce() -> Result<rlx_runtime::CompiledGraph>,
161    ) -> Result<&mut rlx_runtime::CompiledGraph> {
162        if let std::collections::hash_map::Entry::Vacant(e) = self.projector.entry(n_tokens) {
163            e.insert(build()?);
164        }
165        Ok(self.projector.get_mut(&n_tokens).expect("projector"))
166    }
167
168    pub fn prefill_with_kv(
169        &mut self,
170        cfg: &LocateAnythingConfig,
171        seq: usize,
172        inputs_embeds: &[f32],
173    ) -> Result<(Vec<f32>, Vec<Vec<f32>>)> {
174        let key = prefill_cache_key(1, seq);
175        let cfg = cfg.clone();
176        let store = Arc::clone(
177            self.lm_store
178                .as_ref()
179                .ok_or_else(|| anyhow::anyhow!("lm store missing"))?,
180        );
181        let mut loader = Self::lm_loader(&store);
182        let built = build_locateanything_prefill_built(&cfg, &mut loader, 1, seq, true, true)?;
183        let compiled = metal_lm_compile_guard(self.device, || {
184            compile_cache_ensure_built(&mut self.prefill, key, built)
185        })?;
186        let outs = metal_lm_compile_guard(self.device, || {
187            compiled.run(&[("inputs_embeds", inputs_embeds)])
188        });
189        let kv_start = 1usize;
190        self.reset_gpu_kv();
191        Ok((outs[0].clone(), outs[kv_start..].to_vec()))
192    }
193
194    pub fn mtp_logits(
195        &mut self,
196        cfg: &LocateAnythingConfig,
197        past_len: usize,
198        q_len: usize,
199        inputs_embeds: &[f32],
200        full_mask_2d: &[f32],
201        full_seq: usize,
202        rope_cos: &[f32],
203        rope_sin: &[f32],
204        kv: &mut KvCacheState,
205    ) -> Result<(Vec<f32>, KvCacheState)> {
206        let layers = cfg.text_config.num_hidden_layers;
207        let nh = cfg.text_config.num_attention_heads;
208        let kv_dim = cfg.text_config.num_key_value_heads * cfg.text_config.head_dim();
209        let key = past_len as u64;
210        let cfg = cfg.clone();
211        let store = Arc::clone(
212            self.lm_store
213                .as_ref()
214                .ok_or_else(|| anyhow::anyhow!("lm store missing"))?,
215        );
216
217        if session_use_gpu_kv(self.device) {
218            let upper = self
219                .mtp
220                .bucket_upper_for_key(key)
221                .ok_or_else(|| anyhow::anyhow!("past_len {past_len} outside MTP buckets"))?;
222
223            let attn_bias = attn_bias_for_incremental_padded(
224                1,
225                nh,
226                past_len,
227                upper as usize,
228                q_len,
229                full_mask_2d,
230                full_seq,
231            );
232            let fixed = [
233                CacheRunInput {
234                    name: "inputs_embeds",
235                    data: inputs_embeds,
236                    row_inner: None,
237                },
238                CacheRunInput {
239                    name: "attn_bias",
240                    data: &attn_bias,
241                    row_inner: None,
242                },
243                CacheRunInput {
244                    name: "rope_cos",
245                    data: rope_cos,
246                    row_inner: None,
247                },
248                CacheRunInput {
249                    name: "rope_sin",
250                    data: rope_sin,
251                    row_inner: None,
252                },
253            ];
254            let logits = metal_lm_compile_guard(self.device, || {
255                run_bucketed_kv_mtp_gpu(
256                    &mut self.mtp,
257                    past_len,
258                    q_len,
259                    kv,
260                    &mut self.gpu_kv.mtp,
261                    kv_dim,
262                    layers,
263                    &fixed,
264                    |upper| {
265                        let mut loader = Self::lm_loader(&store);
266                        let built = build_locateanything_mtp_kv_built(
267                            &cfg,
268                            &mut loader,
269                            1,
270                            upper as usize,
271                            q_len,
272                        )
273                        .expect("mtp kv graph");
274                        graph_from_built(built).expect("mtp kv graph from built")
275                    },
276                    &self.compile_opts_decode,
277                )
278            })?;
279            let past_after = past_len + q_len;
280            if let Some(compiled) = self.mtp.compiled_for_key_mut(key) {
281                kv.past_len = past_after;
282                sync_gpu_kv_to_host(compiled, kv, kv_dim, layers)?;
283            } else {
284                anyhow::bail!("mtp gpu: compiled graph missing for past_len {past_len}");
285            }
286            self.gpu_kv.reset_decode_after_mtp();
287            return Ok((logits, kv.clone()));
288        }
289
290        let (upper, compiled) = metal_lm_compile_guard(self.device, || {
291            bucket_cache_ensure_built(
292                &mut self.mtp,
293                key,
294                |upper| {
295                    let mut loader = Self::lm_loader(&store);
296                    build_locateanything_mtp_kv_built(&cfg, &mut loader, 1, upper as usize, q_len)
297                },
298                &self.compile_opts_decode,
299            )
300        })
301        .ok_or_else(|| anyhow::anyhow!("past_len {past_len} outside MTP buckets"))?;
302
303        let attn_bias = attn_bias_for_incremental_padded(
304            1,
305            nh,
306            past_len,
307            upper as usize,
308            q_len,
309            full_mask_2d,
310            full_seq,
311        );
312        let (padded_k, padded_v) = kv.pad_layers_to_upper(upper, kv_dim);
313        let key_past = rlx_core::past_kv_input_names(layers);
314        let mut run_in: Vec<(&str, &[f32])> = vec![
315            ("inputs_embeds", inputs_embeds),
316            ("attn_bias", &attn_bias),
317            ("rope_cos", rope_cos),
318            ("rope_sin", rope_sin),
319        ];
320        for i in 0..layers {
321            run_in.push((key_past[2 * i].as_str(), padded_k[i].as_slice()));
322            run_in.push((key_past[2 * i + 1].as_str(), padded_v[i].as_slice()));
323        }
324        // Bucket axis: we append `q_len` KV rows (not a single decode step).
325        let actual_kv = past_len + q_len;
326        let upper_kv = upper as usize + q_len;
327        session_set_active_extent(compiled, self.device, (actual_kv, upper_kv));
328        let outs = compiled.run(&run_in);
329        session_clear_active_extent(compiled, self.device);
330        let past_after = past_len + q_len;
331        let kv = kv_state_from_runner(past_after, &outs[1..], layers, kv_dim)?;
332        self.gpu_kv.reset_decode_after_mtp();
333        Ok((outs[0].clone(), kv))
334    }
335
336    /// Single-token (or MTP-mask) decode; updates `kv` in place and returns logits only.
337    pub fn decode_step_in_place(
338        &mut self,
339        cfg: &LocateAnythingConfig,
340        past_len: usize,
341        token: u32,
342        rope_cos: &[f32],
343        rope_sin: &[f32],
344        mtp_window: Option<(usize, usize)>,
345        kv: &mut KvCacheState,
346    ) -> Result<Vec<f32>> {
347        self.decode_step(cfg, past_len, token, rope_cos, rope_sin, mtp_window, kv)
348    }
349
350    fn decode_step(
351        &mut self,
352        cfg: &LocateAnythingConfig,
353        past_len: usize,
354        token: u32,
355        rope_cos: &[f32],
356        rope_sin: &[f32],
357        mtp_window: Option<(usize, usize)>,
358        kv: &mut KvCacheState,
359    ) -> Result<Vec<f32>> {
360        let layers = cfg.text_config.num_hidden_layers;
361        let kv_dim = cfg.text_config.num_key_value_heads * cfg.text_config.head_dim();
362        let token_f = [token as f32];
363        let cfg_c = cfg.clone();
364        let store = Arc::clone(
365            self.lm_store
366                .as_ref()
367                .ok_or_else(|| anyhow::anyhow!("lm store missing"))?,
368        );
369
370        let mut fixed = vec![
371            CacheRunInput {
372                name: "input_ids",
373                data: &token_f,
374                row_inner: None,
375            },
376            CacheRunInput {
377                name: "rope_cos",
378                data: rope_cos,
379                row_inner: None,
380            },
381            CacheRunInput {
382                name: "rope_sin",
383                data: rope_sin,
384                row_inner: None,
385            },
386        ];
387
388        if session_use_gpu_kv(self.device) {
389            let binding = if mtp_window.is_some() {
390                &mut self.gpu_kv.decode_mtp
391            } else {
392                &mut self.gpu_kv.causal
393            };
394            if let Some((block_size, past)) = mtp_window {
395                let key = past_len as u64;
396                let upper = self
397                    .decode_mtp
398                    .ensure_graph_with_params(
399                        key,
400                        |upper| {
401                            let mut loader = Self::lm_loader(&store);
402                            let built = build_locateanything_decode_built_ext(
403                                &cfg_c,
404                                &mut loader,
405                                1,
406                                upper as usize,
407                                true,
408                                false,
409                            )
410                            .expect("mtp decode graph");
411                            graph_from_built(built).expect("mtp decode graph from built")
412                        },
413                        &self.compile_opts_decode,
414                    )
415                    .ok_or_else(|| {
416                        anyhow::anyhow!("past_len {past_len} outside MTP decode buckets")
417                    })?
418                    .0;
419                let mask = mtp_decode_mask_padded(block_size, past, upper as usize + 1);
420                fixed.push(CacheRunInput {
421                    name: "mask",
422                    data: &mask,
423                    row_inner: None,
424                });
425                return metal_lm_compile_guard(self.device, || {
426                    run_bucketed_kv_decode_gpu(
427                        &mut self.decode_mtp,
428                        key,
429                        past_len,
430                        kv,
431                        binding,
432                        kv_dim,
433                        layers,
434                        &fixed,
435                        |upper| {
436                            let mut loader = Self::lm_loader(&store);
437                            let built = build_locateanything_decode_built_ext(
438                                &cfg_c,
439                                &mut loader,
440                                1,
441                                upper as usize,
442                                true,
443                                false,
444                            )
445                            .expect("mtp decode graph");
446                            graph_from_built(built).expect("mtp decode graph from built")
447                        },
448                        &self.compile_opts_decode,
449                        false,
450                    )
451                });
452            }
453            return metal_lm_compile_guard(self.device, || {
454                run_bucketed_kv_decode_gpu(
455                    &mut self.decode_causal,
456                    past_len as u64,
457                    past_len,
458                    kv,
459                    binding,
460                    kv_dim,
461                    layers,
462                    &fixed,
463                    |upper| {
464                        let mut loader = Self::lm_loader(&store);
465                        let built = build_locateanything_decode_built_ext(
466                            &cfg_c,
467                            &mut loader,
468                            1,
469                            upper as usize,
470                            false,
471                            false,
472                        )
473                        .expect("causal decode graph");
474                        graph_from_built(built).expect("causal decode graph from built")
475                    },
476                    &self.compile_opts_decode,
477                    false,
478                )
479            });
480        }
481
482        if let Some((block_size, past)) = mtp_window {
483            let key = past_len as u64;
484            let (upper, _) = self
485                .decode_mtp
486                .ensure_graph_with_params(
487                    key,
488                    |upper| {
489                        let mut loader = Self::lm_loader(&store);
490                        let built = build_locateanything_decode_built_ext(
491                            &cfg_c,
492                            &mut loader,
493                            1,
494                            upper as usize,
495                            true,
496                            false,
497                        )
498                        .expect("mtp decode graph");
499                        graph_from_built(built).expect("mtp decode graph from built")
500                    },
501                    &self.compile_opts_decode,
502                )
503                .ok_or_else(|| anyhow::anyhow!("past_len {past_len} outside MTP decode buckets"))?;
504            let mask = mtp_decode_mask_padded(block_size, past, upper as usize + 1);
505            fixed.push(CacheRunInput {
506                name: "mask",
507                data: &mask,
508                row_inner: None,
509            });
510            let (logits, new_k, new_v) = metal_lm_compile_guard(self.device, || {
511                run_bucketed_kv_decode(
512                    &mut self.decode_mtp,
513                    past_len,
514                    kv,
515                    kv_dim,
516                    layers,
517                    &fixed,
518                    |upper| {
519                        let mut loader = Self::lm_loader(&store);
520                        let built = build_locateanything_decode_built_ext(
521                            &cfg_c,
522                            &mut loader,
523                            1,
524                            upper as usize,
525                            true,
526                            false,
527                        )
528                        .expect("mtp decode graph");
529                        graph_from_built(built).expect("mtp decode graph from built")
530                    },
531                    &self.compile_opts_decode,
532                )
533            })?;
534            kv.past_len = past_len + 1;
535            let n = kv.past_len * kv_dim;
536            for i in 0..layers {
537                kv.layers_k[i] = take_kv_rows(&new_k[i], n);
538                kv.layers_v[i] = take_kv_rows(&new_v[i], n);
539            }
540            return Ok(logits);
541        }
542
543        let (logits, new_k, new_v) = metal_lm_compile_guard(self.device, || {
544            run_bucketed_kv_decode(
545                &mut self.decode_causal,
546                past_len,
547                kv,
548                kv_dim,
549                layers,
550                &fixed,
551                |upper| {
552                    let mut loader = Self::lm_loader(&store);
553                    let built = build_locateanything_decode_built_ext(
554                        &cfg_c,
555                        &mut loader,
556                        1,
557                        upper as usize,
558                        false,
559                        false,
560                    )
561                    .expect("causal decode graph");
562                    graph_from_built(built).expect("causal decode graph from built")
563                },
564                &self.compile_opts_decode,
565            )
566        })?;
567        kv.past_len = past_len + 1;
568        let n = kv.past_len * kv_dim;
569        for i in 0..layers {
570            kv.layers_k[i] = take_kv_rows(&new_k[i], n);
571            kv.layers_v[i] = take_kv_rows(&new_v[i], n);
572        }
573        Ok(logits)
574    }
575}
576
577/// Build KV state from prefill outputs (`[k0, v0, k1, v1, …]`).
578///
579/// When backend outputs are bucket-padded, only the first `past_len * kv_dim` floats
580/// per layer are copied (avoids retaining multi‑MiB tail padding on the host).
581pub fn kv_state_from_runner(
582    past_len: usize,
583    kv_flat: &[Vec<f32>],
584    layers: usize,
585    kv_dim: usize,
586) -> Result<KvCacheState> {
587    anyhow::ensure!(
588        kv_flat.len() == 2 * layers,
589        "expected {} kv tensors, got {}",
590        2 * layers,
591        kv_flat.len()
592    );
593    let n = past_len * kv_dim;
594    let mut layers_k = Vec::with_capacity(layers);
595    let mut layers_v = Vec::with_capacity(layers);
596    for i in 0..layers {
597        layers_k.push(take_kv_rows(&kv_flat[2 * i], n));
598        layers_v.push(take_kv_rows(&kv_flat[2 * i + 1], n));
599    }
600    Ok(KvCacheState {
601        past_len,
602        layers_k,
603        layers_v,
604    })
605}
606
/// Copy the first `n` floats of `buf`, or all of `buf` when it holds fewer than `n`.
fn take_kv_rows(buf: &[f32], n: usize) -> Vec<f32> {
    let end = n.min(buf.len());
    buf[..end].to_vec()
}
614
615/// Trim MTP KV to `prefix_past + committed` tokens (MTP may compute a full block).
616pub fn truncate_kv_state(
617    kv: KvCacheState,
618    prefix_past: usize,
619    committed: usize,
620    kv_dim: usize,
621) -> Result<KvCacheState> {
622    let want = prefix_past + committed;
623    if want >= kv.past_len {
624        return Ok(kv);
625    }
626    let n = want * kv_dim;
627    let mut layers_k = Vec::with_capacity(kv.layers_k.len());
628    let mut layers_v = Vec::with_capacity(kv.layers_v.len());
629    for (k, v) in kv.layers_k.iter().zip(kv.layers_v.iter()) {
630        anyhow::ensure!(
631            k.len() >= n && v.len() >= n,
632            "truncate_kv_state: layer buffer too short (k={} v={} need {n})",
633            k.len(),
634            v.len()
635        );
636        layers_k.push(k[..n].to_vec());
637        layers_v.push(v[..n].to_vec());
638    }
639    Ok(KvCacheState {
640        past_len: want,
641        layers_k,
642        layers_v,
643    })
644}