1use crate::compile_support::{
19 lm_active_extent_enabled, lm_decode_compile_options, lm_gpu_kv_enabled, lm_host_device,
20 metal_lm_compile_guard,
21};
22use crate::config::LocateAnythingConfig;
23use crate::kv_buckets::locateanything_kv_bucket_ranges_for_device;
24use crate::lm_flow::{
25 build_locateanything_decode_built_ext, build_locateanything_mtp_kv_built,
26 build_locateanything_prefill_built,
27};
28use crate::load::LocateAnythingWeightStore;
29use crate::mask::{attn_bias_for_incremental_padded, mtp_decode_mask_padded};
30use crate::weights::CheckpointLmWeightLoader;
31use anyhow::Result;
32use rlx_core::flow_util::{
33 bucket_cache_ensure_built, compile_cache_ensure_built, graph_from_built,
34};
35use rlx_core::{
36 GpuKvCacheSet, KvCacheState, prefill_cache_key, run_bucketed_kv_decode,
37 run_bucketed_kv_decode_gpu, run_bucketed_kv_mtp_gpu, sync_gpu_kv_to_host,
38};
39use std::sync::Arc;
40
41use rlx_runtime::Device;
42use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
43
44fn session_use_gpu_kv(device: Device) -> bool {
45 lm_gpu_kv_enabled(device)
46}
47
48fn session_set_active_extent(
49 compiled: &mut rlx_runtime::CompiledGraph,
50 device: Device,
51 extent: (usize, usize),
52) {
53 if lm_active_extent_enabled(device) {
54 compiled.set_active_extent(Some(extent));
55 }
56}
57
58fn session_clear_active_extent(compiled: &mut rlx_runtime::CompiledGraph, device: Device) {
59 if lm_active_extent_enabled(device) {
60 compiled.set_active_extent(None);
61 }
62}
63
/// Per-session compiled-graph caches and GPU KV bindings for the LocateAnything language model.
///
/// Holds:
/// - the prefill cache;
/// - three bucketed incremental caches (causal decode, MTP-masked decode, multi-token MTP);
/// - the compiled projector graphs;
/// - the GPU-resident KV bindings used when GPU KV is enabled.
// NOTE(review): field order also fixes drop order; keep it stable unless GPU resource
// teardown order has been checked.
pub struct LmSessionCaches {
    /// Weight store shared by every graph builder; installed on the first `ensure_lm_store`.
    lm_store: Option<Arc<LocateAnythingWeightStore>>,
    /// Compiled projector graphs keyed by token count (populated by `projector_graph`).
    pub projector: std::collections::HashMap<usize, rlx_runtime::CompiledGraph>,
    /// Set to the same value as `device` in `new`.
    // NOTE(review): redundant with `device` and never read here — candidate for removal.
    _device: Device,
    /// Prefill graphs keyed by `prefill_cache_key(batch, seq)`.
    prefill: CompileCache,
    /// Causal single-token decode graphs, bucketed by past length.
    decode_causal: BucketedCompileCache,
    /// Single-token decode graphs that take an explicit `mask` input (MTP window decode).
    decode_mtp: BucketedCompileCache,
    /// Multi-token (`q_len` tokens per call) MTP graphs, bucketed by past length.
    mtp: BucketedCompileCache,
    /// Past-length capacity passed to `new` (clamped to >= 1).
    // Only stored for now; the bucket ranges are derived from it in `new`.
    #[allow(dead_code)]
    max_past: usize,
    /// Compile options shared by all bucketed decode / MTP graphs.
    compile_opts_decode: rlx_runtime::CompileOptions,
    /// Requested device; gates GPU-KV usage and active-extent handling.
    device: Device,
    /// GPU-resident KV bindings for the causal, MTP-decode and MTP paths.
    gpu_kv: GpuKvCacheSet,
}
79
80impl LmSessionCaches {
81 pub fn new(device: Device, max_past: usize) -> Self {
82 let max_past = max_past.max(1);
83 let host = lm_host_device(device);
84 let bucket_ranges = locateanything_kv_bucket_ranges_for_device(host, max_past);
85 Self {
86 lm_store: None,
87 projector: std::collections::HashMap::new(),
88 _device: device,
89 device,
90 prefill: CompileCache::new(host, 8),
91 decode_causal: BucketedCompileCache::new(host, bucket_ranges.clone()),
92 decode_mtp: BucketedCompileCache::new(host, bucket_ranges.clone()),
93 mtp: BucketedCompileCache::new(host, bucket_ranges),
94 max_past,
95 compile_opts_decode: lm_decode_compile_options(host),
96 gpu_kv: GpuKvCacheSet::default(),
97 }
98 }
99
100 pub fn reset_gpu_kv(&mut self) {
102 self.gpu_kv.reset();
103 }
104
105 pub fn reset_decode_after_mtp(&mut self) {
107 self.gpu_kv.reset_decode_after_mtp();
108 }
109
110 pub fn sync_kv_from_gpu(
112 &mut self,
113 cfg: &LocateAnythingConfig,
114 past_len: usize,
115 kv: &mut KvCacheState,
116 ) -> Result<()> {
117 if !session_use_gpu_kv(self.device) {
118 return Ok(());
119 }
120 let layers = cfg.text_config.num_hidden_layers;
121 let kv_dim = cfg.text_config.num_key_value_heads * cfg.text_config.head_dim();
122 let keys = if past_len > 0 {
123 [past_len as u64 - 1, past_len as u64]
124 } else {
125 [0, 0]
126 };
127 for &key in &keys {
128 let compiled = if self.gpu_kv.causal.upper != 0 {
129 self.decode_causal.compiled_for_key_mut(key)
130 } else if self.gpu_kv.mtp.upper != 0 {
131 self.mtp.compiled_for_key_mut(key)
132 } else {
133 return Ok(());
134 };
135 if let Some(compiled) = compiled {
136 return sync_gpu_kv_to_host(compiled, kv, kv_dim, layers);
137 }
138 }
139 Ok(())
140 }
141
142 pub fn ensure_lm_store(
144 &mut self,
145 store: Arc<LocateAnythingWeightStore>,
146 ) -> Arc<LocateAnythingWeightStore> {
147 if self.lm_store.is_none() {
148 self.lm_store = Some(store);
149 }
150 Arc::clone(self.lm_store.as_ref().expect("lm store"))
151 }
152
153 fn lm_loader(store: &Arc<LocateAnythingWeightStore>) -> CheckpointLmWeightLoader {
154 CheckpointLmWeightLoader::new(Arc::clone(store))
155 }
156
157 pub fn projector_graph(
158 &mut self,
159 n_tokens: usize,
160 build: impl FnOnce() -> Result<rlx_runtime::CompiledGraph>,
161 ) -> Result<&mut rlx_runtime::CompiledGraph> {
162 if let std::collections::hash_map::Entry::Vacant(e) = self.projector.entry(n_tokens) {
163 e.insert(build()?);
164 }
165 Ok(self.projector.get_mut(&n_tokens).expect("projector"))
166 }
167
168 pub fn prefill_with_kv(
169 &mut self,
170 cfg: &LocateAnythingConfig,
171 seq: usize,
172 inputs_embeds: &[f32],
173 ) -> Result<(Vec<f32>, Vec<Vec<f32>>)> {
174 let key = prefill_cache_key(1, seq);
175 let cfg = cfg.clone();
176 let store = Arc::clone(
177 self.lm_store
178 .as_ref()
179 .ok_or_else(|| anyhow::anyhow!("lm store missing"))?,
180 );
181 let mut loader = Self::lm_loader(&store);
182 let built = build_locateanything_prefill_built(&cfg, &mut loader, 1, seq, true, true)?;
183 let compiled = metal_lm_compile_guard(self.device, || {
184 compile_cache_ensure_built(&mut self.prefill, key, built)
185 })?;
186 let outs = metal_lm_compile_guard(self.device, || {
187 compiled.run(&[("inputs_embeds", inputs_embeds)])
188 });
189 let kv_start = 1usize;
190 self.reset_gpu_kv();
191 Ok((outs[0].clone(), outs[kv_start..].to_vec()))
192 }
193
194 pub fn mtp_logits(
195 &mut self,
196 cfg: &LocateAnythingConfig,
197 past_len: usize,
198 q_len: usize,
199 inputs_embeds: &[f32],
200 full_mask_2d: &[f32],
201 full_seq: usize,
202 rope_cos: &[f32],
203 rope_sin: &[f32],
204 kv: &mut KvCacheState,
205 ) -> Result<(Vec<f32>, KvCacheState)> {
206 let layers = cfg.text_config.num_hidden_layers;
207 let nh = cfg.text_config.num_attention_heads;
208 let kv_dim = cfg.text_config.num_key_value_heads * cfg.text_config.head_dim();
209 let key = past_len as u64;
210 let cfg = cfg.clone();
211 let store = Arc::clone(
212 self.lm_store
213 .as_ref()
214 .ok_or_else(|| anyhow::anyhow!("lm store missing"))?,
215 );
216
217 if session_use_gpu_kv(self.device) {
218 let upper = self
219 .mtp
220 .bucket_upper_for_key(key)
221 .ok_or_else(|| anyhow::anyhow!("past_len {past_len} outside MTP buckets"))?;
222
223 let attn_bias = attn_bias_for_incremental_padded(
224 1,
225 nh,
226 past_len,
227 upper as usize,
228 q_len,
229 full_mask_2d,
230 full_seq,
231 );
232 let fixed = [
233 CacheRunInput {
234 name: "inputs_embeds",
235 data: inputs_embeds,
236 row_inner: None,
237 },
238 CacheRunInput {
239 name: "attn_bias",
240 data: &attn_bias,
241 row_inner: None,
242 },
243 CacheRunInput {
244 name: "rope_cos",
245 data: rope_cos,
246 row_inner: None,
247 },
248 CacheRunInput {
249 name: "rope_sin",
250 data: rope_sin,
251 row_inner: None,
252 },
253 ];
254 let logits = metal_lm_compile_guard(self.device, || {
255 run_bucketed_kv_mtp_gpu(
256 &mut self.mtp,
257 past_len,
258 q_len,
259 kv,
260 &mut self.gpu_kv.mtp,
261 kv_dim,
262 layers,
263 &fixed,
264 |upper| {
265 let mut loader = Self::lm_loader(&store);
266 let built = build_locateanything_mtp_kv_built(
267 &cfg,
268 &mut loader,
269 1,
270 upper as usize,
271 q_len,
272 )
273 .expect("mtp kv graph");
274 graph_from_built(built).expect("mtp kv graph from built")
275 },
276 &self.compile_opts_decode,
277 )
278 })?;
279 let past_after = past_len + q_len;
280 if let Some(compiled) = self.mtp.compiled_for_key_mut(key) {
281 kv.past_len = past_after;
282 sync_gpu_kv_to_host(compiled, kv, kv_dim, layers)?;
283 } else {
284 anyhow::bail!("mtp gpu: compiled graph missing for past_len {past_len}");
285 }
286 self.gpu_kv.reset_decode_after_mtp();
287 return Ok((logits, kv.clone()));
288 }
289
290 let (upper, compiled) = metal_lm_compile_guard(self.device, || {
291 bucket_cache_ensure_built(
292 &mut self.mtp,
293 key,
294 |upper| {
295 let mut loader = Self::lm_loader(&store);
296 build_locateanything_mtp_kv_built(&cfg, &mut loader, 1, upper as usize, q_len)
297 },
298 &self.compile_opts_decode,
299 )
300 })
301 .ok_or_else(|| anyhow::anyhow!("past_len {past_len} outside MTP buckets"))?;
302
303 let attn_bias = attn_bias_for_incremental_padded(
304 1,
305 nh,
306 past_len,
307 upper as usize,
308 q_len,
309 full_mask_2d,
310 full_seq,
311 );
312 let (padded_k, padded_v) = kv.pad_layers_to_upper(upper, kv_dim);
313 let key_past = rlx_core::past_kv_input_names(layers);
314 let mut run_in: Vec<(&str, &[f32])> = vec![
315 ("inputs_embeds", inputs_embeds),
316 ("attn_bias", &attn_bias),
317 ("rope_cos", rope_cos),
318 ("rope_sin", rope_sin),
319 ];
320 for i in 0..layers {
321 run_in.push((key_past[2 * i].as_str(), padded_k[i].as_slice()));
322 run_in.push((key_past[2 * i + 1].as_str(), padded_v[i].as_slice()));
323 }
324 let actual_kv = past_len + q_len;
326 let upper_kv = upper as usize + q_len;
327 session_set_active_extent(compiled, self.device, (actual_kv, upper_kv));
328 let outs = compiled.run(&run_in);
329 session_clear_active_extent(compiled, self.device);
330 let past_after = past_len + q_len;
331 let kv = kv_state_from_runner(past_after, &outs[1..], layers, kv_dim)?;
332 self.gpu_kv.reset_decode_after_mtp();
333 Ok((outs[0].clone(), kv))
334 }
335
336 pub fn decode_step_in_place(
338 &mut self,
339 cfg: &LocateAnythingConfig,
340 past_len: usize,
341 token: u32,
342 rope_cos: &[f32],
343 rope_sin: &[f32],
344 mtp_window: Option<(usize, usize)>,
345 kv: &mut KvCacheState,
346 ) -> Result<Vec<f32>> {
347 self.decode_step(cfg, past_len, token, rope_cos, rope_sin, mtp_window, kv)
348 }
349
350 fn decode_step(
351 &mut self,
352 cfg: &LocateAnythingConfig,
353 past_len: usize,
354 token: u32,
355 rope_cos: &[f32],
356 rope_sin: &[f32],
357 mtp_window: Option<(usize, usize)>,
358 kv: &mut KvCacheState,
359 ) -> Result<Vec<f32>> {
360 let layers = cfg.text_config.num_hidden_layers;
361 let kv_dim = cfg.text_config.num_key_value_heads * cfg.text_config.head_dim();
362 let token_f = [token as f32];
363 let cfg_c = cfg.clone();
364 let store = Arc::clone(
365 self.lm_store
366 .as_ref()
367 .ok_or_else(|| anyhow::anyhow!("lm store missing"))?,
368 );
369
370 let mut fixed = vec![
371 CacheRunInput {
372 name: "input_ids",
373 data: &token_f,
374 row_inner: None,
375 },
376 CacheRunInput {
377 name: "rope_cos",
378 data: rope_cos,
379 row_inner: None,
380 },
381 CacheRunInput {
382 name: "rope_sin",
383 data: rope_sin,
384 row_inner: None,
385 },
386 ];
387
388 if session_use_gpu_kv(self.device) {
389 let binding = if mtp_window.is_some() {
390 &mut self.gpu_kv.decode_mtp
391 } else {
392 &mut self.gpu_kv.causal
393 };
394 if let Some((block_size, past)) = mtp_window {
395 let key = past_len as u64;
396 let upper = self
397 .decode_mtp
398 .ensure_graph_with_params(
399 key,
400 |upper| {
401 let mut loader = Self::lm_loader(&store);
402 let built = build_locateanything_decode_built_ext(
403 &cfg_c,
404 &mut loader,
405 1,
406 upper as usize,
407 true,
408 false,
409 )
410 .expect("mtp decode graph");
411 graph_from_built(built).expect("mtp decode graph from built")
412 },
413 &self.compile_opts_decode,
414 )
415 .ok_or_else(|| {
416 anyhow::anyhow!("past_len {past_len} outside MTP decode buckets")
417 })?
418 .0;
419 let mask = mtp_decode_mask_padded(block_size, past, upper as usize + 1);
420 fixed.push(CacheRunInput {
421 name: "mask",
422 data: &mask,
423 row_inner: None,
424 });
425 return metal_lm_compile_guard(self.device, || {
426 run_bucketed_kv_decode_gpu(
427 &mut self.decode_mtp,
428 key,
429 past_len,
430 kv,
431 binding,
432 kv_dim,
433 layers,
434 &fixed,
435 |upper| {
436 let mut loader = Self::lm_loader(&store);
437 let built = build_locateanything_decode_built_ext(
438 &cfg_c,
439 &mut loader,
440 1,
441 upper as usize,
442 true,
443 false,
444 )
445 .expect("mtp decode graph");
446 graph_from_built(built).expect("mtp decode graph from built")
447 },
448 &self.compile_opts_decode,
449 false,
450 )
451 });
452 }
453 return metal_lm_compile_guard(self.device, || {
454 run_bucketed_kv_decode_gpu(
455 &mut self.decode_causal,
456 past_len as u64,
457 past_len,
458 kv,
459 binding,
460 kv_dim,
461 layers,
462 &fixed,
463 |upper| {
464 let mut loader = Self::lm_loader(&store);
465 let built = build_locateanything_decode_built_ext(
466 &cfg_c,
467 &mut loader,
468 1,
469 upper as usize,
470 false,
471 false,
472 )
473 .expect("causal decode graph");
474 graph_from_built(built).expect("causal decode graph from built")
475 },
476 &self.compile_opts_decode,
477 false,
478 )
479 });
480 }
481
482 if let Some((block_size, past)) = mtp_window {
483 let key = past_len as u64;
484 let (upper, _) = self
485 .decode_mtp
486 .ensure_graph_with_params(
487 key,
488 |upper| {
489 let mut loader = Self::lm_loader(&store);
490 let built = build_locateanything_decode_built_ext(
491 &cfg_c,
492 &mut loader,
493 1,
494 upper as usize,
495 true,
496 false,
497 )
498 .expect("mtp decode graph");
499 graph_from_built(built).expect("mtp decode graph from built")
500 },
501 &self.compile_opts_decode,
502 )
503 .ok_or_else(|| anyhow::anyhow!("past_len {past_len} outside MTP decode buckets"))?;
504 let mask = mtp_decode_mask_padded(block_size, past, upper as usize + 1);
505 fixed.push(CacheRunInput {
506 name: "mask",
507 data: &mask,
508 row_inner: None,
509 });
510 let (logits, new_k, new_v) = metal_lm_compile_guard(self.device, || {
511 run_bucketed_kv_decode(
512 &mut self.decode_mtp,
513 past_len,
514 kv,
515 kv_dim,
516 layers,
517 &fixed,
518 |upper| {
519 let mut loader = Self::lm_loader(&store);
520 let built = build_locateanything_decode_built_ext(
521 &cfg_c,
522 &mut loader,
523 1,
524 upper as usize,
525 true,
526 false,
527 )
528 .expect("mtp decode graph");
529 graph_from_built(built).expect("mtp decode graph from built")
530 },
531 &self.compile_opts_decode,
532 )
533 })?;
534 kv.past_len = past_len + 1;
535 let n = kv.past_len * kv_dim;
536 for i in 0..layers {
537 kv.layers_k[i] = take_kv_rows(&new_k[i], n);
538 kv.layers_v[i] = take_kv_rows(&new_v[i], n);
539 }
540 return Ok(logits);
541 }
542
543 let (logits, new_k, new_v) = metal_lm_compile_guard(self.device, || {
544 run_bucketed_kv_decode(
545 &mut self.decode_causal,
546 past_len,
547 kv,
548 kv_dim,
549 layers,
550 &fixed,
551 |upper| {
552 let mut loader = Self::lm_loader(&store);
553 let built = build_locateanything_decode_built_ext(
554 &cfg_c,
555 &mut loader,
556 1,
557 upper as usize,
558 false,
559 false,
560 )
561 .expect("causal decode graph");
562 graph_from_built(built).expect("causal decode graph from built")
563 },
564 &self.compile_opts_decode,
565 )
566 })?;
567 kv.past_len = past_len + 1;
568 let n = kv.past_len * kv_dim;
569 for i in 0..layers {
570 kv.layers_k[i] = take_kv_rows(&new_k[i], n);
571 kv.layers_v[i] = take_kv_rows(&new_v[i], n);
572 }
573 Ok(logits)
574 }
575}
576
577pub fn kv_state_from_runner(
582 past_len: usize,
583 kv_flat: &[Vec<f32>],
584 layers: usize,
585 kv_dim: usize,
586) -> Result<KvCacheState> {
587 anyhow::ensure!(
588 kv_flat.len() == 2 * layers,
589 "expected {} kv tensors, got {}",
590 2 * layers,
591 kv_flat.len()
592 );
593 let n = past_len * kv_dim;
594 let mut layers_k = Vec::with_capacity(layers);
595 let mut layers_v = Vec::with_capacity(layers);
596 for i in 0..layers {
597 layers_k.push(take_kv_rows(&kv_flat[2 * i], n));
598 layers_v.push(take_kv_rows(&kv_flat[2 * i + 1], n));
599 }
600 Ok(KvCacheState {
601 past_len,
602 layers_k,
603 layers_v,
604 })
605}
606
/// Copies at most the first `n` values of `buf` into a new vector.
fn take_kv_rows(buf: &[f32], n: usize) -> Vec<f32> {
    let end = n.min(buf.len());
    buf[..end].to_vec()
}
614
615pub fn truncate_kv_state(
617 kv: KvCacheState,
618 prefix_past: usize,
619 committed: usize,
620 kv_dim: usize,
621) -> Result<KvCacheState> {
622 let want = prefix_past + committed;
623 if want >= kv.past_len {
624 return Ok(kv);
625 }
626 let n = want * kv_dim;
627 let mut layers_k = Vec::with_capacity(kv.layers_k.len());
628 let mut layers_v = Vec::with_capacity(kv.layers_v.len());
629 for (k, v) in kv.layers_k.iter().zip(kv.layers_v.iter()) {
630 anyhow::ensure!(
631 k.len() >= n && v.len() >= n,
632 "truncate_kv_state: layer buffer too short (k={} v={} need {n})",
633 k.len(),
634 v.len()
635 );
636 layers_k.push(k[..n].to_vec());
637 layers_v.push(v[..n].to_vec());
638 }
639 Ok(KvCacheState {
640 past_len: want,
641 layers_k,
642 layers_v,
643 })
644}