1use crate::{CompiledGraph, Device, Session};
40use rlx_ir::DimBinding;
41use rlx_ir::Graph;
42use rlx_ir::hir::HirModule;
43use rlx_opt::CompileResult;
44use std::collections::HashMap;
45use std::collections::VecDeque;
46use std::ops::Range;
47
/// One named input handed to [`BucketedCompileCache::run_padded_mixed`].
///
/// The struct only borrows its name and payload, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct CacheRunInput<'a> {
    /// Name under which the data is passed to the compiled graph's `run`.
    pub name: &'a str,
    /// Flat `f32` payload for this input.
    pub data: &'a [f32],
    /// `Some(inner)`: the payload is treated as rows of `inner` elements and
    /// zero-padded up to the bucket's upper bound via [`pad_rows`].
    /// `None`: the payload is forwarded unchanged.
    pub row_inner: Option<usize>,
}
55
56pub struct CompileCache {
57 device: Device,
58 capacity: usize,
59 policy: Option<rlx_opt::PrecisionPolicy>,
62 entries: Vec<(u64, CompiledGraph)>,
66 order: VecDeque<u64>,
68}
69
70impl CompileCache {
71 pub fn new(device: Device, capacity: usize) -> Self {
72 Self::with_policy(device, capacity, None)
73 }
74
75 pub fn with_policy(
79 device: Device,
80 capacity: usize,
81 policy: Option<rlx_opt::PrecisionPolicy>,
82 ) -> Self {
83 assert!(capacity > 0, "CompileCache capacity must be ≥ 1");
84 Self {
85 device,
86 capacity,
87 policy,
88 entries: Vec::with_capacity(capacity),
89 order: VecDeque::with_capacity(capacity),
90 }
91 }
92
93 pub fn get_or_compile<F: FnOnce() -> Graph>(
97 &mut self,
98 key: u64,
99 build: F,
100 ) -> &mut CompiledGraph {
101 self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
102 }
103
104 pub fn get_or_compile_with_options<F: FnOnce() -> Graph>(
106 &mut self,
107 key: u64,
108 build: F,
109 options: &crate::CompileOptions,
110 ) -> &mut CompiledGraph {
111 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
112 return &mut self.entries[idx].1;
113 }
114 let mut session = Session::new(self.device);
115 if let Some(p) = &self.policy {
116 session = session.with_policy(p.clone());
117 }
118 let compiled = session.compile_with(build(), options);
119
120 if self.entries.len() >= self.capacity
122 && let Some(evict_key) = self.order.pop_front()
123 {
124 sync_evicted_entry(&mut self.entries, evict_key);
125 self.entries.retain(|(k, _)| *k != evict_key);
126 }
127 self.entries.push((key, compiled));
128 self.order.push_back(key);
129 &mut self.entries.last_mut().unwrap().1
130 }
131
132 pub fn get_or_compile_hir_with_options<F: FnOnce() -> rlx_ir::hir::HirModule>(
134 &mut self,
135 key: u64,
136 build: F,
137 options: &crate::CompileOptions,
138 ) -> &mut CompiledGraph {
139 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
140 return &mut self.entries[idx].1;
141 }
142 let mut session = Session::new(self.device);
143 if let Some(p) = &self.policy {
144 session = session.with_policy(p.clone());
145 }
146 let compiled = session
147 .compile_hir_with(build(), options)
148 .expect("HIR lower/compile in compile cache");
149
150 if self.entries.len() >= self.capacity
151 && let Some(evict_key) = self.order.pop_front()
152 {
153 sync_evicted_entry(&mut self.entries, evict_key);
154 self.entries.retain(|(k, _)| *k != evict_key);
155 }
156 self.entries.push((key, compiled));
157 self.order.push_back(key);
158 &mut self.entries.last_mut().unwrap().1
159 }
160
161 pub fn len(&self) -> usize {
163 self.entries.len()
164 }
165 pub fn is_empty(&self) -> bool {
166 self.entries.is_empty()
167 }
168 pub fn contains(&self, key: u64) -> bool {
170 self.entries.iter().any(|(k, _)| *k == key)
171 }
172
173 pub fn clear(&mut self) {
175 self.sync_all();
176 self.entries.clear();
177 self.order.clear();
178 }
179
180 pub fn sync_all(&mut self) {
182 for (_, compiled) in &mut self.entries {
183 compiled.sync_pending();
184 }
185 }
186}
187
188fn sync_evicted_entry(entries: &mut [(u64, CompiledGraph)], evict_key: u64) {
189 if let Some((_, compiled)) = entries.iter_mut().find(|(k, _)| *k == evict_key) {
190 compiled.sync_pending();
191 }
192}
193
194pub struct BucketedCompileCache {
241 device: Device,
242 policy: Option<rlx_opt::PrecisionPolicy>,
243 buckets: Vec<Bucket>,
244}
245
246struct Bucket {
247 range: Range<u64>,
248 compiled: Option<CompiledGraph>,
249}
250
251impl BucketedCompileCache {
252 pub fn new(device: Device, buckets: Vec<Range<u64>>) -> Self {
253 Self::with_policy(device, buckets, None)
254 }
255
256 pub fn power_of_two_ladder(device: Device, min: u64, max: u64) -> Self {
272 Self::power_of_two_ladder_with_policy(device, min, max, None)
273 }
274
275 pub fn power_of_two_ladder_with_policy(
276 device: Device,
277 min: u64,
278 max: u64,
279 policy: Option<rlx_opt::PrecisionPolicy>,
280 ) -> Self {
281 assert!(min >= 1, "power_of_two_ladder: min must be ≥ 1, got {min}");
282 assert!(
283 max >= min,
284 "power_of_two_ladder: max ({max}) must be ≥ min ({min})"
285 );
286 let mut buckets: Vec<Range<u64>> = Vec::new();
287 let mut start = 1u64;
288 let mut extent = min.next_power_of_two();
289 loop {
290 buckets.push(start..(extent + 1));
291 if extent >= max {
292 break;
293 }
294 start = extent + 1;
295 extent = extent
296 .checked_mul(2)
297 .expect("power_of_two_ladder: extent overflow");
298 }
299 Self::with_policy(device, buckets, policy)
300 }
301
302 pub fn with_policy(
303 device: Device,
304 buckets: Vec<Range<u64>>,
305 policy: Option<rlx_opt::PrecisionPolicy>,
306 ) -> Self {
307 assert!(!buckets.is_empty(), "BucketedCompileCache needs ≥1 bucket");
308 for (i, b) in buckets.iter().enumerate() {
309 assert!(b.start < b.end, "bucket {i} ({b:?}) is empty");
310 if i + 1 < buckets.len() {
311 assert!(
312 b.end <= buckets[i + 1].start,
313 "buckets {i} ({b:?}) and {} ({:?}) overlap",
314 i + 1,
315 buckets[i + 1],
316 );
317 }
318 }
319 let buckets = buckets
320 .into_iter()
321 .map(|range| Bucket {
322 range,
323 compiled: None,
324 })
325 .collect();
326 Self {
327 device,
328 policy,
329 buckets,
330 }
331 }
332
333 pub fn get_or_compile<F: FnOnce(u64) -> Graph>(
342 &mut self,
343 key: u64,
344 build: F,
345 ) -> Option<(u64, &mut CompiledGraph)> {
346 self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
347 }
348
349 pub fn get_or_compile_with_options<F: FnOnce(u64) -> Graph>(
351 &mut self,
352 key: u64,
353 build: F,
354 options: &crate::CompileOptions,
355 ) -> Option<(u64, &mut CompiledGraph)> {
356 let idx = self.bucket_for(key)?;
357 let upper = self.buckets[idx].range.end - 1;
358 if self.buckets[idx].compiled.is_none() {
359 let mut session = Session::new(self.device);
360 if let Some(p) = &self.policy {
361 session = session.with_policy(p.clone());
362 }
363 self.buckets[idx].compiled = Some(session.compile_with(build(upper), options));
364 }
365 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
366 }
367
368 pub fn get_or_compile_hir<F: FnOnce(u64) -> HirModule>(
371 &mut self,
372 key: u64,
373 build: F,
374 ) -> Option<(u64, &mut CompiledGraph)> {
375 self.get_or_compile_hir_with_options(key, build, &crate::CompileOptions::new())
376 }
377
378 pub fn get_or_compile_hir_with_options<F: FnOnce(u64) -> HirModule>(
380 &mut self,
381 key: u64,
382 build: F,
383 options: &crate::CompileOptions,
384 ) -> Option<(u64, &mut CompiledGraph)> {
385 let idx = self.bucket_for(key)?;
386 let upper = self.buckets[idx].range.end - 1;
387 if self.buckets[idx].compiled.is_none() {
388 let mut session = Session::new(self.device);
389 if let Some(p) = &self.policy {
390 session = session.with_policy(p.clone());
391 }
392 let compiled = session
393 .compile_hir_with(build(upper), options)
394 .expect("HIR lower/compile in bucketed cache");
395 self.buckets[idx].compiled = Some(compiled);
396 }
397 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
398 }
399
400 pub fn bucket_for(&self, key: u64) -> Option<usize> {
403 self.buckets.iter().position(|b| b.range.contains(&key))
404 }
405
406 pub fn bucket_upper_for_key(&self, key: u64) -> Option<u64> {
408 let idx = self.bucket_for(key)?;
409 Some(self.buckets[idx].range.end - 1)
410 }
411
412 pub fn buckets(&self) -> impl Iterator<Item = &Range<u64>> {
413 self.buckets.iter().map(|b| &b.range)
414 }
415
416 pub fn compiled_count(&self) -> usize {
418 self.buckets.iter().filter(|b| b.compiled.is_some()).count()
419 }
420
421 pub fn compiled_for_key_mut(&mut self, key: u64) -> Option<&mut CompiledGraph> {
423 let idx = self.bucket_for(key)?;
424 self.buckets[idx].compiled.as_mut()
425 }
426
427 pub fn total_buckets(&self) -> usize {
428 self.buckets.len()
429 }
430
431 pub fn evict_except(&mut self, keep: usize) {
433 for (i, bucket) in self.buckets.iter_mut().enumerate() {
434 if i != keep {
435 bucket.compiled = None;
436 }
437 }
438 }
439
440 pub fn clear_compiled(&mut self) {
442 for bucket in &mut self.buckets {
443 bucket.compiled = None;
444 }
445 }
446
447 pub fn run_padded<F: FnOnce(u64) -> Graph>(
473 &mut self,
474 key: u64,
475 actual_rows: usize,
476 build: F,
477 inputs: &[(&str, &[f32], usize)],
478 output_inners: &[usize],
479 ) -> Option<(u64, Vec<Vec<f32>>)> {
480 let (upper, compiled) = self.get_or_compile(key, build)?;
481
482 let padded: Vec<(&str, Vec<f32>)> = inputs
484 .iter()
485 .map(|(name, data, inner)| (*name, pad_rows(data, *inner, upper)))
486 .collect();
487 let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
488
489 compiled.set_active_extent(Some((actual_rows, upper as usize)));
495 let raw_outputs = compiled.run(&pairs);
496 compiled.set_active_extent(None);
497 #[cfg(feature = "cpu")]
498 crate::onnx_active::set_active_token_count(None);
499
500 let outs = raw_outputs
501 .into_iter()
502 .enumerate()
503 .map(|(i, out)| match output_inners.get(i).copied() {
504 Some(0) | None => out,
505 Some(inner) => slice_rows(&out, inner, actual_rows),
506 })
507 .collect();
508
509 Some((upper, outs))
510 }
511
512 pub fn ensure_graph_with_params<F>(
514 &mut self,
515 key: u64,
516 build: F,
517 options: &crate::CompileOptions,
518 ) -> Option<(u64, &mut CompiledGraph)>
519 where
520 F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
521 {
522 let idx = self.bucket_for(key)?;
523 let upper = self.buckets[idx].range.end - 1;
524 if self.buckets[idx].compiled.is_none() {
525 let (graph, params) = build(upper);
526 let mut session = Session::new(self.device);
527 if let Some(p) = &self.policy {
528 session = session.with_policy(p.clone());
529 }
530 let mut compiled = session.compile_with(graph, options);
531 for (name, data) in params {
532 compiled.set_param(&name, &data);
533 }
534 self.buckets[idx].compiled = Some(compiled);
535 }
536 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
537 }
538
539 pub fn ensure_hir_with_params<F>(
541 &mut self,
542 key: u64,
543 build: F,
544 options: &crate::CompileOptions,
545 ) -> Option<(u64, &mut CompiledGraph)>
546 where
547 F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
548 {
549 let idx = self.bucket_for(key)?;
550 let upper = self.buckets[idx].range.end - 1;
551 if self.buckets[idx].compiled.is_none() {
552 let (hir, params) = build(upper);
553 let mut session = Session::new(self.device);
554 if let Some(p) = &self.policy {
555 session = session.with_policy(p.clone());
556 }
557 let mut compiled = session
558 .compile_hir_with(hir, options)
559 .expect("HIR lower/compile in ensure_hir_with_params");
560 for (name, data) in params {
561 compiled.set_param(&name, &data);
562 }
563 self.buckets[idx].compiled = Some(compiled);
564 }
565 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
566 }
567
568 pub fn run_padded_mixed<F>(
570 &mut self,
571 key: u64,
572 actual_rows: usize,
573 build: F,
574 inputs: &[CacheRunInput<'_>],
575 output_inners: &[usize],
576 ) -> Option<(u64, Vec<Vec<f32>>)>
577 where
578 F: FnOnce(u64) -> Graph,
579 {
580 let (upper, compiled) = self.get_or_compile(key, build)?;
581
582 let padded: Vec<(&str, Vec<f32>)> = inputs
583 .iter()
584 .map(|inp| match inp.row_inner {
585 Some(inner) => (inp.name, pad_rows(inp.data, inner, upper)),
586 None => (inp.name, inp.data.to_vec()),
587 })
588 .collect();
589 let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
590
591 compiled.set_active_extent(Some((actual_rows, upper as usize)));
592 let raw_outputs = compiled.run(&pairs);
593 compiled.set_active_extent(None);
594 #[cfg(feature = "cpu")]
595 crate::onnx_active::set_active_token_count(None);
596
597 let outs = raw_outputs
598 .into_iter()
599 .enumerate()
600 .map(|(i, out)| match output_inners.get(i).copied() {
601 Some(0) | None => out,
602 Some(inner) => slice_rows(&out, inner, actual_rows),
603 })
604 .collect();
605
606 Some((upper, outs))
607 }
608
609 pub fn sync_all(&mut self) {
611 for bucket in &mut self.buckets {
612 if let Some(compiled) = &mut bucket.compiled {
613 compiled.sync_pending();
614 }
615 }
616 }
617}
618
619pub struct DynamicDimCompileCache {
627 device: Device,
628 policy: Option<rlx_opt::PrecisionPolicy>,
629 capacity: usize,
630 template: Option<CompileResult>,
631 entries: Vec<(u64, CompiledGraph)>,
632 order: VecDeque<u64>,
633}
634
635impl DynamicDimCompileCache {
636 pub fn new(device: Device, capacity: usize) -> Self {
637 Self::with_policy(device, capacity, None)
638 }
639
640 pub fn with_policy(
641 device: Device,
642 capacity: usize,
643 policy: Option<rlx_opt::PrecisionPolicy>,
644 ) -> Self {
645 assert!(capacity > 0, "DynamicDimCompileCache capacity must be ≥ 1");
646 Self {
647 device,
648 policy,
649 capacity,
650 template: None,
651 entries: Vec::with_capacity(capacity),
652 order: VecDeque::with_capacity(capacity),
653 }
654 }
655
656 pub fn compile_device(&self) -> Device {
657 self.device
658 }
659
660 pub fn get_or_specialize<F: FnOnce() -> HirModule>(
663 &mut self,
664 key: u64,
665 binding: &DimBinding,
666 build_hir: F,
667 options: &crate::CompileOptions,
668 ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError> {
669 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
670 return Ok(&mut self.entries[idx].1);
671 }
672 if self.template.is_none() {
673 let mut template_opts = options.clone();
674 template_opts.dim_binding = None;
675 let pipe = crate::stages::pipeline_for(self.device, &template_opts);
676 self.template = Some(pipe.compile_hir(build_hir())?);
677 }
678 let template = self.template.as_ref().expect("template just set");
679 let mut spec_opts = options.clone();
680 spec_opts.dim_binding = None;
681 let pipe = crate::stages::pipeline_for(self.device, &spec_opts);
682 let specialized = template.specialize(&pipe, binding);
683 let backend = crate::registry::backend_for(self.device).expect("backend registered");
684 let mut compile_opts = options.clone();
685 compile_opts.dim_binding = None;
686 if compile_opts.policy.is_none() {
687 if let Some(p) = &self.policy {
688 compile_opts = compile_opts.policy(p.clone());
689 }
690 }
691 let executable = backend.compile_lir(specialized.lir, &compile_opts);
692 let compiled = CompiledGraph::new(executable, self.device);
693
694 if self.entries.len() >= self.capacity
695 && let Some(evict_key) = self.order.pop_front()
696 {
697 sync_evicted_entry(&mut self.entries, evict_key);
698 self.entries.retain(|(k, _)| *k != evict_key);
699 }
700 self.entries.push((key, compiled));
701 self.order.push_back(key);
702 Ok(&mut self.entries.last_mut().unwrap().1)
703 }
704
705 pub fn len(&self) -> usize {
706 self.entries.len()
707 }
708
709 pub fn is_empty(&self) -> bool {
710 self.entries.is_empty()
711 }
712
713 pub fn contains(&self, key: u64) -> bool {
714 self.entries.iter().any(|(k, _)| *k == key)
715 }
716
717 pub fn clear(&mut self) {
718 self.sync_all();
719 self.template = None;
720 self.entries.clear();
721 self.order.clear();
722 }
723
724 pub fn has_template(&self) -> bool {
725 self.template.is_some()
726 }
727
728 pub fn sync_all(&mut self) {
730 for (_, compiled) in &mut self.entries {
731 compiled.sync_pending();
732 }
733 }
734
735 pub fn ensure_template<F: FnOnce() -> HirModule>(
737 &mut self,
738 build_hir: F,
739 options: &crate::CompileOptions,
740 ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
741 if self.template.is_none() {
742 let mut opts = options.clone();
743 opts.dim_binding = None;
744 let pipe = crate::stages::pipeline_for(self.device, &opts);
745 self.template = Some(pipe.compile_hir(build_hir())?);
746 }
747 Ok(self.template.as_ref().expect("template set"))
748 }
749
750 pub fn template_result(&self) -> Option<&CompileResult> {
751 self.template.as_ref()
752 }
753
754 pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
757 &mut self,
758 aot: &crate::AotCache,
759 disk_base: &str,
760 key: u64,
761 binding: &rlx_ir::DimBinding,
762 build_hir: F,
763 options: &crate::CompileOptions,
764 ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
765 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
766 return Ok(&mut self.entries[idx].1);
767 }
768 let device = self.device;
769 let template = self.ensure_template(build_hir, options)?;
770 let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
771 if self.entries.len() >= self.capacity
772 && let Some(evict_key) = self.order.pop_front()
773 {
774 sync_evicted_entry(&mut self.entries, evict_key);
775 self.entries.retain(|(k, _)| *k != evict_key);
776 }
777 self.entries.push((key, compiled));
778 self.order.push_back(key);
779 Ok(&mut self.entries.last_mut().unwrap().1)
780 }
781}
782
/// Returns a copy of `data` zero-padded at the end to `upper` rows of `inner`
/// elements each.
///
/// `data` is interpreted as `data.len() / inner` rows; those rows are kept
/// in place and the remaining rows are filled with `0.0`.
///
/// # Panics
/// Panics if `inner` is 0, if `data.len()` is not a multiple of `inner`, if
/// `data` holds more than `upper` rows, if `upper` does not fit in `usize`,
/// or if `upper * inner` overflows `usize`.
pub fn pad_rows(data: &[f32], inner: usize, upper: u64) -> Vec<f32> {
    assert!(inner > 0, "pad_rows: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    // A bare `as` cast would silently truncate on targets where usize is
    // narrower than u64.
    let upper = usize::try_from(upper)
        .unwrap_or_else(|_| panic!("pad_rows: upper bound {} does not fit in usize", upper));
    let actual = data.len() / inner;
    assert!(
        actual <= upper,
        "pad_rows: actual rows {actual} exceed upper bound {upper}",
    );
    let total = upper.checked_mul(inner).unwrap_or_else(|| {
        panic!(
            "pad_rows: padded length {} × {} overflows usize",
            upper, inner
        )
    });
    let mut out = Vec::with_capacity(total);
    out.extend_from_slice(data);
    out.resize(total, 0.0_f32);
    out
}
808
/// Writes `data` into the front of `out` and zeroes the rest of `out`.
///
/// Both slices are interpreted as rows of `inner` elements; `out` determines
/// the padded row count.
///
/// # Panics
/// Panics if `inner` is 0, if either slice length is not a multiple of
/// `inner`, or if `data` holds more rows than `out`.
pub fn pad_rows_into(out: &mut [f32], data: &[f32], inner: usize) {
    assert!(inner > 0, "pad_rows_into: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows_into: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    assert_eq!(
        out.len() % inner,
        0,
        "pad_rows_into: out len {} not a multiple of inner {inner}",
        out.len(),
    );
    let capacity_rows = out.len() / inner;
    let source_rows = data.len() / inner;
    assert!(
        source_rows <= capacity_rows,
        "pad_rows_into: actual rows {source_rows} exceed upper bound {capacity_rows}",
    );
    // The row check above guarantees data.len() <= out.len().
    let (head, tail) = out.split_at_mut(data.len());
    head.copy_from_slice(data);
    tail.fill(0.0);
}
833
/// Returns a copy of the first `actual` rows of `data`, where each row is
/// `inner` elements long.
///
/// # Panics
/// Panics if `inner` is 0, if `data.len()` is not a multiple of `inner`, or
/// if `actual` exceeds the number of rows in `data`.
pub fn slice_rows(data: &[f32], inner: usize, actual: usize) -> Vec<f32> {
    assert!(inner > 0, "slice_rows: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "slice_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    let upper = data.len() / inner;
    assert!(
        actual <= upper,
        "slice_rows: actual rows {actual} exceed upper {upper}",
    );
    let (kept, _padding) = data.split_at(actual * inner);
    kept.to_vec()
}
854
855#[cfg(test)]
856mod tests {
857 use super::*;
858 use rlx_ir::infer::GraphExt;
859 use rlx_ir::*;
860 use std::cell::Cell;
861
862 fn tiny_graph(n: usize) -> Graph {
863 let mut g = Graph::new("t");
864 let f = DType::F32;
865 let x = g.input("x", Shape::new(&[n], f));
866 let y = g.activation(rlx_ir::op::Activation::Relu, x, Shape::new(&[n], f));
867 g.set_outputs(vec![y]);
868 g
869 }
870
871 #[test]
872 fn cache_hits_avoid_recompile() {
873 let mut cache = CompileCache::new(Device::Cpu, 4);
874 let calls = Cell::new(0);
875
876 let _ = cache.get_or_compile(1, || {
877 calls.set(calls.get() + 1);
878 tiny_graph(8)
879 });
880 let _ = cache.get_or_compile(1, || {
881 calls.set(calls.get() + 1);
882 tiny_graph(8)
883 });
884 let _ = cache.get_or_compile(1, || {
885 calls.set(calls.get() + 1);
886 tiny_graph(8)
887 });
888 assert_eq!(calls.get(), 1);
890 assert_eq!(cache.len(), 1);
891 }
892
893 #[test]
894 fn fifo_evicts_oldest_at_capacity() {
895 let mut cache = CompileCache::new(Device::Cpu, 2);
896 let _ = cache.get_or_compile(1, || tiny_graph(4));
897 let _ = cache.get_or_compile(2, || tiny_graph(8));
898 assert!(cache.contains(1) && cache.contains(2));
899 let _ = cache.get_or_compile(3, || tiny_graph(16));
901 assert!(!cache.contains(1));
902 assert!(cache.contains(2) && cache.contains(3));
903 }
904
905 #[test]
906 fn different_keys_keep_separate_compiles() {
907 let mut cache = CompileCache::new(Device::Cpu, 4);
908 let calls = Cell::new(0);
909 let _ = cache.get_or_compile(1, || {
910 calls.set(calls.get() + 1);
911 tiny_graph(8)
912 });
913 let _ = cache.get_or_compile(2, || {
914 calls.set(calls.get() + 1);
915 tiny_graph(16)
916 });
917 let _ = cache.get_or_compile(1, || {
918 calls.set(calls.get() + 1);
919 tiny_graph(8)
920 });
921 assert_eq!(calls.get(), 2);
923 assert_eq!(cache.len(), 2);
924 }
925
926 #[test]
929 fn bucket_amortizes_keys_within_range() {
930 let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
931 let calls = Cell::new(0);
932 let uppers = Cell::new((0u64, 0u64));
933
934 let (u1, _) = cache
936 .get_or_compile(2, |upper| {
937 calls.set(calls.get() + 1);
938 uppers.set((upper, uppers.get().1));
939 tiny_graph(upper as usize)
940 })
941 .expect("key 2 in range");
942 let (u2, _) = cache
943 .get_or_compile(3, |upper| {
944 calls.set(calls.get() + 1);
945 uppers.set((uppers.get().0, upper));
946 tiny_graph(upper as usize)
947 })
948 .expect("key 3 in range");
949
950 assert_eq!(calls.get(), 1);
952 assert_eq!(u1, 3);
953 assert_eq!(u2, 3);
954 assert_eq!(uppers.get().0, 3);
955 assert_eq!(cache.compiled_count(), 1);
956 assert_eq!(cache.total_buckets(), 2);
957 }
958
959 #[test]
960 fn bucket_lookup_returns_none_outside_range() {
961 let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
962 assert!(cache.bucket_for(0).is_none());
963 assert!(cache.bucket_for(16).is_none());
964 assert!(cache.bucket_for(100).is_none());
965 assert_eq!(cache.bucket_for(3), Some(0));
966 assert_eq!(cache.bucket_for(4), Some(1));
967 assert_eq!(cache.bucket_upper_for_key(3), Some(3));
968 assert_eq!(cache.bucket_upper_for_key(4), Some(15));
969 assert!(cache.bucket_upper_for_key(0).is_none());
970
971 let calls = Cell::new(0);
972 let result = cache.get_or_compile(100, |u| {
973 calls.set(calls.get() + 1);
974 tiny_graph(u as usize)
975 });
976 assert!(result.is_none());
977 assert_eq!(calls.get(), 0); assert_eq!(cache.compiled_count(), 0);
979 }
980
981 #[test]
982 fn bucket_compiles_lazily_per_bucket() {
983 let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16, 16..64]);
984 let calls = Cell::new(0);
985
986 let _ = cache.get_or_compile(2, |u| {
987 calls.set(calls.get() + 1);
988 tiny_graph(u as usize)
989 });
990 let _ = cache.get_or_compile(8, |u| {
991 calls.set(calls.get() + 1);
992 tiny_graph(u as usize)
993 });
994 assert_eq!(calls.get(), 2);
996 assert_eq!(cache.compiled_count(), 2);
997 assert_eq!(cache.total_buckets(), 3);
998 }
999
1000 #[test]
1001 #[should_panic(expected = "overlap")]
1002 fn bucket_overlap_rejected() {
1003 let _ = BucketedCompileCache::new(Device::Cpu, vec![1..8, 4..16]);
1004 }
1005
1006 #[test]
1007 #[should_panic(expected = "≥1 bucket")]
1008 fn empty_bucket_list_rejected() {
1009 let _ = BucketedCompileCache::new(Device::Cpu, vec![]);
1010 }
1011
1012 #[test]
1015 fn pad_rows_appends_zeros() {
1016 let p = pad_rows(&[1.0, 2.0, 3.0], 1, 5);
1018 assert_eq!(p, vec![1.0, 2.0, 3.0, 0.0, 0.0]);
1019
1020 let p = pad_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 4);
1022 assert_eq!(
1023 p,
1024 vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1025 );
1026
1027 let p = pad_rows(&[7.0, 8.0], 1, 2);
1029 assert_eq!(p, vec![7.0, 8.0]);
1030 }
1031
1032 #[test]
1033 fn slice_rows_truncates_trailing() {
1034 let s = slice_rows(&[1.0, 2.0, 3.0, 0.0, 0.0], 1, 3);
1035 assert_eq!(s, vec![1.0, 2.0, 3.0]);
1036
1037 let s = slice_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0], 3, 2);
1038 assert_eq!(s, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1039 }
1040
1041 #[test]
1042 #[should_panic(expected = "exceed upper")]
1043 fn pad_rows_rejects_too_long_input() {
1044 let _ = pad_rows(&[1.0, 2.0, 3.0, 4.0], 1, 3);
1045 }
1046
1047 #[test]
1048 #[should_panic(expected = "exceed upper")]
1049 fn slice_rows_rejects_too_large_actual() {
1050 let _ = slice_rows(&[1.0, 2.0, 3.0], 1, 5);
1051 }
1052
1053 #[test]
1056 fn run_padded_pads_input_and_slices_output() {
1057 let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
1060 let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0];
1061
1062 let (upper, outs) = cache
1063 .run_padded(
1064 10, 10, |max| tiny_graph(max as usize),
1067 &[("x", &input, 1)], &[1], )
1070 .expect("key 10 in [1..16)");
1071
1072 assert_eq!(upper, 15);
1073 assert_eq!(outs.len(), 1);
1074 let out = &outs[0];
1075 assert_eq!(out.len(), 10, "output sliced back to actual_rows");
1076 let expected: Vec<f32> = input.iter().map(|x| x.max(0.0)).collect();
1077 assert_eq!(out, &expected);
1078 }
1079
1080 #[test]
1081 fn run_padded_reuses_bucket_across_actuals() {
1082 let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
1084 let calls = Cell::new(0);
1085
1086 let (u1, o1) = cache
1087 .run_padded(
1088 10,
1089 10,
1090 |max| {
1091 calls.set(calls.get() + 1);
1092 tiny_graph(max as usize)
1093 },
1094 &[(
1095 "x",
1096 &[1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0],
1097 1,
1098 )],
1099 &[1],
1100 )
1101 .unwrap();
1102 assert_eq!(o1.len(), 1);
1103 assert_eq!(o1[0].len(), 10);
1104 assert_eq!(u1, 15);
1105
1106 let (u2, o2) = cache
1107 .run_padded(
1108 5,
1109 5,
1110 |max| {
1111 calls.set(calls.get() + 1);
1112 tiny_graph(max as usize)
1113 },
1114 &[("x", &[-1.0, 2.0, -3.0, 4.0, -5.0], 1)],
1115 &[1],
1116 )
1117 .unwrap();
1118 assert_eq!(o2.len(), 1);
1119 assert_eq!(o2[0].len(), 5);
1120 assert_eq!(u2, 15);
1121 assert_eq!(o2[0], vec![0.0, 2.0, 0.0, 4.0, 0.0]);
1122
1123 assert_eq!(calls.get(), 1, "bucket cached across actuals");
1124 assert_eq!(cache.compiled_count(), 1);
1125 }
1126
1127 #[test]
1128 fn run_padded_returns_none_out_of_range() {
1129 let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
1130 let calls = Cell::new(0);
1131 let result = cache.run_padded(
1132 100,
1133 5,
1134 |u| {
1135 calls.set(calls.get() + 1);
1136 tiny_graph(u as usize)
1137 },
1138 &[("x", &[1.0, 2.0, 3.0, 4.0, 5.0], 1)],
1139 &[1],
1140 );
1141 assert!(result.is_none());
1142 assert_eq!(calls.get(), 0);
1143 assert_eq!(cache.compiled_count(), 0);
1144 }
1145
1146 #[test]
1149 fn power_of_two_ladder_generates_log_buckets() {
1150 let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
1151 let ranges: Vec<_> = cache.buckets().cloned().collect();
1153 assert_eq!(ranges, vec![1..9, 9..17, 17..33, 33..65]);
1154 assert_eq!(cache.total_buckets(), 4);
1155 }
1156
1157 #[test]
1158 fn power_of_two_ladder_picks_smallest_extent_for_actual() {
1159 let mut cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
1162 let captured_uppers: std::cell::RefCell<Vec<u64>> = Default::default();
1163
1164 let (u17, _) = cache
1165 .get_or_compile(17, |upper| {
1166 captured_uppers.borrow_mut().push(upper);
1167 tiny_graph(upper as usize)
1168 })
1169 .unwrap();
1170 let (u9, _) = cache
1171 .get_or_compile(9, |upper| {
1172 captured_uppers.borrow_mut().push(upper);
1173 tiny_graph(upper as usize)
1174 })
1175 .unwrap();
1176 let (u3, _) = cache
1177 .get_or_compile(3, |upper| {
1178 captured_uppers.borrow_mut().push(upper);
1179 tiny_graph(upper as usize)
1180 })
1181 .unwrap();
1182 let (u64_, _) = cache
1183 .get_or_compile(64, |upper| {
1184 captured_uppers.borrow_mut().push(upper);
1185 tiny_graph(upper as usize)
1186 })
1187 .unwrap();
1188
1189 assert_eq!(u17, 32, "key=17 → smallest extent ≥ 17 is 32");
1190 assert_eq!(u9, 16, "key=9 → smallest extent ≥ 9 is 16");
1191 assert_eq!(u3, 8, "key=3 → smallest extent ≥ 3 is 8");
1192 assert_eq!(u64_, 64, "key=64 → exact match at 64");
1193 assert_eq!(*captured_uppers.borrow(), vec![32, 16, 8, 64]);
1194 assert_eq!(cache.compiled_count(), 4);
1195 }
1196
1197 #[test]
1198 fn power_of_two_ladder_min_above_one_starts_at_one() {
1199 let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 32);
1202 let ranges: Vec<_> = cache.buckets().cloned().collect();
1203 assert_eq!(ranges, vec![1..17, 17..33]);
1205 }
1206
1207 #[test]
1208 fn power_of_two_ladder_non_pow2_min_rounds_up() {
1209 let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 10, 64);
1211 let ranges: Vec<_> = cache.buckets().cloned().collect();
1212 assert_eq!(ranges, vec![1..17, 17..33, 33..65]);
1213 }
1214
1215 #[test]
1216 fn power_of_two_ladder_max_below_pow2_extends_up() {
1217 let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 20);
1219 let ranges: Vec<_> = cache.buckets().cloned().collect();
1220 assert_eq!(ranges, vec![1..9, 9..17, 17..33]);
1221 }
1222
1223 #[test]
1224 fn power_of_two_ladder_min_equals_max() {
1225 let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 16);
1226 let ranges: Vec<_> = cache.buckets().cloned().collect();
1227 assert_eq!(ranges, vec![1..17]);
1228 }
1229
1230 #[test]
1231 #[should_panic(expected = "min must be ≥ 1")]
1232 fn power_of_two_ladder_zero_min_rejected() {
1233 let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 0, 16);
1234 }
1235
1236 #[test]
1237 #[should_panic(expected = "max")]
1238 fn power_of_two_ladder_max_below_min_rejected() {
1239 let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 32, 8);
1240 }
1241
1242 #[test]
1255 #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
1256 fn active_extent_skips_compute_on_cpu_activation() {
1257 let graph = tiny_graph(15);
1268 let mut compiled = Session::new(Device::Cpu).compile(graph);
1269
1270 let warm_input: Vec<f32> = vec![1.0; 15];
1272 let warm_outs = compiled.run(&[("x", &warm_input)]);
1273 assert_eq!(warm_outs[0], vec![1.0; 15], "warm-up sanity");
1274
1275 let neg_input: Vec<f32> = vec![-1.0; 15];
1278 compiled.set_active_extent(Some((5, 15)));
1279 let outs = compiled.run(&[("x", &neg_input)]);
1280 let out = &outs[0];
1281
1282 assert_eq!(out.len(), 15);
1283 assert_eq!(
1284 out[..5],
1285 [0.0; 5],
1286 "first 5 elements processed (relu of -1)"
1287 );
1288 assert_eq!(
1289 out[5..],
1290 [1.0; 10],
1291 "tail untouched — proves Copy + Activation skipped indices 5..15"
1292 );
1293
1294 compiled.set_active_extent(None);
1297 let outs = compiled.run(&[("x", &neg_input)]);
1298 assert_eq!(
1299 outs[0],
1300 vec![0.0; 15],
1301 "full-extent path must clip every negative"
1302 );
1303 }
1304
1305 #[test]
1306 #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
1307 fn active_extent_skips_compute_on_binary_full() {
1308 let mut g = Graph::new("add");
1312 let f = DType::F32;
1313 let a = g.input("a", Shape::new(&[4], f));
1314 let b = g.input("b", Shape::new(&[4], f));
1315 let c = g.add(a, b);
1316 g.set_outputs(vec![c]);
1317 let mut compiled = Session::new(Device::Cpu).compile(g);
1318
1319 let warm = compiled.run(&[("a", &[1.0f32; 4]), ("b", &[1.0f32; 4])]);
1321 assert_eq!(warm[0], vec![2.0; 4]);
1322
1323 compiled.set_active_extent(Some((2, 4)));
1326 let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
1327 let out = &outs[0];
1328 assert_eq!(out[..2], [20.0, 20.0], "first 2 = active sum");
1329 assert_eq!(
1330 out[2..],
1331 [2.0, 2.0],
1332 "tail untouched — proves BinaryFull skipped indices 2..4"
1333 );
1334
1335 compiled.set_active_extent(None);
1337 let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
1338 assert_eq!(outs[0], vec![20.0; 4]);
1339 }
1340
/// End-to-end check of the perfetto trace path: with `RLX_TRACE_PERFETTO`
/// pointing at a temp file, compiling and running a small add + relu graph and
/// then finalizing the tracer must leave a file that starts with `[` and
/// mentions at least one of the known thunk names.
#[test]
#[ignore = "process-wide STATE; runs only in isolation via `cargo test perfetto -- --ignored`"]
fn perfetto_trace_emits_per_thunk_events() {
    use std::env;
    use std::fs;

    // Per-process file name so concurrent test binaries do not collide.
    let trace_path = env::temp_dir().join(format!("rlx-perfetto-e2e-{}.json", std::process::id()));
    // Best-effort removal of a stale trace from an earlier run; a missing file is fine.
    let _ = fs::remove_file(&trace_path);

    // SAFETY: mutating the process environment is only sound while no other
    // thread reads or writes it. NOTE(review): this relies on the test being
    // run in isolation, as the `#[ignore]` reason requests.
    unsafe {
        env::set_var("RLX_TRACE_PERFETTO", &trace_path);
    }

    let dtype = DType::F32;
    let mut graph = Graph::new("perf");
    let lhs = graph.input("a", Shape::new(&[4], dtype));
    let rhs = graph.input("b", Shape::new(&[4], dtype));
    let sum = graph.add(lhs, rhs);
    let activated = graph.relu(sum);
    graph.set_outputs(vec![activated]);

    let mut compiled = Session::new(Device::Cpu).compile(graph);
    let _ = compiled.run(&[("a", &[1.0; 4]), ("b", &[1.0; 4])]);

    // Flush so the trace is on disk before it is read back.
    crate::perfetto::flush_and_finalize();

    let contents = fs::read_to_string(&trace_path).expect("trace file");
    // Any one of these quoted thunk names is enough to show events were emitted.
    let thunk_names = ["\"binary\"", "\"activation\"", "\"elementwise_region\""];
    let has_thunk_event = thunk_names.iter().any(|name| contents.contains(*name));
    assert!(
        has_thunk_event,
        "expected at least one thunk-name event in perfetto trace; got: {contents}"
    );
    assert!(contents.trim_start().starts_with('['));

    // Clean up the temp file; failure to delete is not a test failure.
    let _ = fs::remove_file(&trace_path);
}
1386
/// Runs `relu((a + b) * c)` over 8-element inputs through the compiled CPU
/// graph and compares every output element against a scalar reference
/// computed directly in the test.
#[test]
fn elementwise_region_fused_matches_unfused() {
    let f = DType::F32;
    let mut g = Graph::new("ew_e2e");
    let a = g.input("a", Shape::new(&[8], f));
    let b = g.input("b", Shape::new(&[8], f));
    let c = g.input("c", Shape::new(&[8], f));
    let add = g.add(a, b);
    let mul = g.mul(add, c);
    let relu = g.relu(mul);
    g.set_outputs(vec![relu]);

    let mut compiled = Session::new(Device::Cpu).compile(g);
    let av: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0, 0.5, -0.5, 1.5, -1.5];
    let bv: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0, 0.5, 0.5, 0.5, 0.5];
    let cv: Vec<f32> = vec![1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 0.5, 4.0];
    let outs = compiled.run(&[("a", &av), ("b", &bv), ("c", &cv)]);
    let out = &outs[0];

    // Scalar reference: relu((a + b) * c), element by element.
    let expected: Vec<f32> = av
        .iter()
        .zip(&bv)
        .zip(&cv)
        .map(|((x, y), z)| ((x + y) * z).max(0.0))
        .collect();

    // `zip` stops at the shorter side, so without this check a truncated or
    // empty output would make the comparison loop below pass vacuously.
    assert_eq!(
        out.len(),
        expected.len(),
        "output length must match the reference length"
    );
    for (i, (got, exp)) in out.iter().zip(&expected).enumerate() {
        assert!(
            (got - exp).abs() < 1e-6,
            "mismatch at {i}: got {got}, expected {exp}"
        );
    }
}
1425
/// Verifies that an attention op honours an active extent of `(2, 4)`: the
/// second half of the 32-element output (positions 2 and 3) must keep the
/// warm-up values, while the first half must change with the new inputs.
#[test]
#[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
fn active_extent_skips_compute_on_attention() {
    use rlx_ir::op::MaskKind;

    let dtype = DType::F32;
    // q, k, v and the output all share the same [1, 4, 8] shape.
    let qkv_shape = || Shape::new(&[1, 4, 8], dtype);

    let mut graph = Graph::new("attn");
    let q = graph.input("q", qkv_shape());
    let k = graph.input("k", qkv_shape());
    let v = graph.input("v", qkv_shape());
    let attended = graph.attention_kind(q, k, v, 2, 4, MaskKind::None, qkv_shape());
    graph.set_outputs(vec![attended]);
    let mut compiled = Session::new(Device::Cpu).compile(graph);

    let ones = [1.0f32; 32];
    let threes = [3.0f32; 32];

    // Full-extent warm-up; its output is the baseline for the untouched tail.
    let warm = compiled.run(&[("q", &ones), ("k", &ones), ("v", &ones)]);
    let warm_out = warm[0].clone();
    assert_eq!(warm_out.len(), 32);

    // Limit execution to the first 2 of 4 positions and feed different inputs.
    compiled.set_active_extent(Some((2, 4)));
    let partial = compiled.run(&[("q", &threes), ("k", &threes), ("v", &threes)]);
    let partial_out = &partial[0];
    assert_eq!(partial_out.len(), 32);

    // 16 = 2 positions * 8 values: everything from there on is the inactive tail.
    assert_eq!(
        &partial_out[16..],
        &warm_out[16..],
        "tail (positions 2,3) must be untouched — proves Attention skipped"
    );
    assert_ne!(
        &partial_out[..16],
        &warm_out[..16],
        "first 2 positions should reflect new input"
    );
}
1473
// NOTE(review): this test has an empty body, so it always passes and asserts
// nothing. Judging by the name, it is presumably meant to check that setting
// an active extent falls back to full-extent execution when the schedule
// contains a thunk without active-extent support — TODO confirm the intended
// scenario and either implement it or mark the test `#[ignore]`.
#[test]
fn active_extent_falls_back_when_unsupported_thunk_in_schedule() {
}
1491
/// Runs a 5-element request through a bucketed cache whose only bucket is
/// `1..16` and checks that the reported upper bound is 15 and that the output
/// is sliced back to the 5 requested elements.
#[test]
fn run_padded_uses_active_extent_on_cpu() {
    // Ten values in total, but only the first five are handed to `run_padded`.
    let input: Vec<f32> = vec![
        1.0, -1.0, 2.0, -2.0, 3.0, -10.0, -20.0, -30.0, -40.0, -50.0,
    ];
    let live = &input[..5];

    let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
    let result = cache.run_padded(
        5,
        5,
        |max| tiny_graph(max as usize),
        &[("x", live, 1)],
        &[1],
    );
    let (upper, outs) = result.unwrap();

    assert_eq!(upper, 15);
    let first = &outs[0];
    assert_eq!(first.len(), 5);
    // Negatives are clamped to zero (presumably `tiny_graph` builds a ReLU — verify).
    assert_eq!(outs[0], vec![1.0, 0.0, 2.0, 0.0, 3.0]);
}
1524
/// With an output inner size of `0`, `run_padded` must hand back the output
/// at the full bucket extent (15 elements) instead of slicing it down to the
/// 5 requested ones; the padded tail is expected to be all zeros.
#[test]
fn run_padded_inner_zero_returns_output_unsliced() {
    let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0];
    let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);

    let result = cache.run_padded(
        5,
        5,
        |max| tiny_graph(max as usize),
        &[("x", input.as_slice(), 1)],
        // `0` here is what disables slicing of the output.
        &[0],
    );
    let (upper, outs) = result.unwrap();

    assert_eq!(upper, 15);
    assert_eq!(
        outs[0].len(),
        15,
        "unsliced output preserves full upper extent"
    );

    // The first five slots carry the real results; the rest is padding.
    assert_eq!(&outs[0][..5], &[1.0, 0.0, 2.0, 0.0, 3.0]);
    let padding = &outs[0][5..];
    assert!(padding.iter().all(|&pad| pad == 0.0));
}
1552
/// Checks that `DynamicDimCompileCache` builds its HIR template only once:
/// the first `get_or_specialize` call runs the builder and stores one entry,
/// the second call (different key and binding) must add a second entry
/// without invoking the builder again.
#[test]
fn dynamic_dim_cache_specializes_per_key() {
    use rlx_ir::{DType, Shape, hir::HirModule, sym};

    let opts = crate::CompileOptions::new();
    let mut cache = DynamicDimCompileCache::new(Device::Cpu, 4);

    // Template builder: a single linear layer over a (batch, seq, 4) input
    // with a [4, 2] weight.
    let build_template = || {
        let mut hir = HirModule::new("dyn_cache");
        let x = hir.input_batch_seq("x", sym::BATCH, sym::SEQ, 4, DType::F32);
        let w = hir.param("w", Shape::new(&[4, 2], DType::F32));
        let out_shape = Shape::batch_seq(sym::BATCH, sym::SEQ, 2, DType::F32);
        let y = hir.linear(x, w, None, None, out_shape);
        hir.set_outputs(vec![y]);
        hir
    };

    // First specialization (key 8): the builder runs and the template is cached.
    let short_binding = rlx_ir::DimBinding::batch_seq(1, 8);
    cache
        .get_or_specialize(8, &short_binding, build_template, &opts)
        .expect("specialize short");
    assert!(cache.has_template());
    assert_eq!(cache.len(), 1);

    // Second specialization (key 128): must reuse the template, so the builder
    // panics if it is ever called.
    let long_binding = rlx_ir::DimBinding::batch_seq(1, 128);
    cache
        .get_or_specialize(
            128,
            &long_binding,
            || panic!("HIR builder must not run twice"),
            &opts,
        )
        .expect("specialize long");
    assert_eq!(cache.len(), 2);
}
1597}