1use crate::{CompiledGraph, Device, Session};
40use rlx_ir::DimBinding;
41use rlx_ir::Graph;
42use rlx_ir::hir::HirModule;
43use rlx_opt::CompileResult;
44use std::collections::VecDeque;
45use std::ops::Range;
46
47pub struct CompileCache {
48 device: Device,
49 capacity: usize,
50 policy: Option<rlx_opt::PrecisionPolicy>,
53 entries: Vec<(u64, CompiledGraph)>,
57 order: VecDeque<u64>,
59}
60
61impl CompileCache {
62 pub fn new(device: Device, capacity: usize) -> Self {
63 Self::with_policy(device, capacity, None)
64 }
65
66 pub fn with_policy(
70 device: Device,
71 capacity: usize,
72 policy: Option<rlx_opt::PrecisionPolicy>,
73 ) -> Self {
74 assert!(capacity > 0, "CompileCache capacity must be ≥ 1");
75 Self {
76 device,
77 capacity,
78 policy,
79 entries: Vec::with_capacity(capacity),
80 order: VecDeque::with_capacity(capacity),
81 }
82 }
83
84 pub fn get_or_compile<F: FnOnce() -> Graph>(
88 &mut self,
89 key: u64,
90 build: F,
91 ) -> &mut CompiledGraph {
92 self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
93 }
94
95 pub fn get_or_compile_with_options<F: FnOnce() -> Graph>(
97 &mut self,
98 key: u64,
99 build: F,
100 options: &crate::CompileOptions,
101 ) -> &mut CompiledGraph {
102 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
103 return &mut self.entries[idx].1;
104 }
105 let mut session = Session::new(self.device);
106 if let Some(p) = &self.policy {
107 session = session.with_policy(p.clone());
108 }
109 let compiled = session.compile_with(build(), options);
110
111 if self.entries.len() >= self.capacity
113 && let Some(evict_key) = self.order.pop_front()
114 {
115 self.entries.retain(|(k, _)| *k != evict_key);
116 }
117 self.entries.push((key, compiled));
118 self.order.push_back(key);
119 &mut self.entries.last_mut().unwrap().1
120 }
121
122 pub fn len(&self) -> usize {
124 self.entries.len()
125 }
126 pub fn is_empty(&self) -> bool {
127 self.entries.is_empty()
128 }
129 pub fn contains(&self, key: u64) -> bool {
131 self.entries.iter().any(|(k, _)| *k == key)
132 }
133}
134
135pub struct BucketedCompileCache {
182 device: Device,
183 policy: Option<rlx_opt::PrecisionPolicy>,
184 buckets: Vec<Bucket>,
185}
186
187struct Bucket {
188 range: Range<u64>,
189 compiled: Option<CompiledGraph>,
190}
191
192impl BucketedCompileCache {
193 pub fn new(device: Device, buckets: Vec<Range<u64>>) -> Self {
194 Self::with_policy(device, buckets, None)
195 }
196
197 pub fn power_of_two_ladder(device: Device, min: u64, max: u64) -> Self {
213 Self::power_of_two_ladder_with_policy(device, min, max, None)
214 }
215
216 pub fn power_of_two_ladder_with_policy(
217 device: Device,
218 min: u64,
219 max: u64,
220 policy: Option<rlx_opt::PrecisionPolicy>,
221 ) -> Self {
222 assert!(min >= 1, "power_of_two_ladder: min must be ≥ 1, got {min}");
223 assert!(
224 max >= min,
225 "power_of_two_ladder: max ({max}) must be ≥ min ({min})"
226 );
227 let mut buckets: Vec<Range<u64>> = Vec::new();
228 let mut start = 1u64;
229 let mut extent = min.next_power_of_two();
230 loop {
231 buckets.push(start..(extent + 1));
232 if extent >= max {
233 break;
234 }
235 start = extent + 1;
236 extent = extent
237 .checked_mul(2)
238 .expect("power_of_two_ladder: extent overflow");
239 }
240 Self::with_policy(device, buckets, policy)
241 }
242
243 pub fn with_policy(
244 device: Device,
245 buckets: Vec<Range<u64>>,
246 policy: Option<rlx_opt::PrecisionPolicy>,
247 ) -> Self {
248 assert!(!buckets.is_empty(), "BucketedCompileCache needs ≥1 bucket");
249 for (i, b) in buckets.iter().enumerate() {
250 assert!(b.start < b.end, "bucket {i} ({b:?}) is empty");
251 if i + 1 < buckets.len() {
252 assert!(
253 b.end <= buckets[i + 1].start,
254 "buckets {i} ({b:?}) and {} ({:?}) overlap",
255 i + 1,
256 buckets[i + 1],
257 );
258 }
259 }
260 let buckets = buckets
261 .into_iter()
262 .map(|range| Bucket {
263 range,
264 compiled: None,
265 })
266 .collect();
267 Self {
268 device,
269 policy,
270 buckets,
271 }
272 }
273
274 pub fn get_or_compile<F: FnOnce(u64) -> Graph>(
283 &mut self,
284 key: u64,
285 build: F,
286 ) -> Option<(u64, &mut CompiledGraph)> {
287 self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
288 }
289
290 pub fn get_or_compile_with_options<F: FnOnce(u64) -> Graph>(
292 &mut self,
293 key: u64,
294 build: F,
295 options: &crate::CompileOptions,
296 ) -> Option<(u64, &mut CompiledGraph)> {
297 let idx = self.bucket_for(key)?;
298 let upper = self.buckets[idx].range.end - 1;
299 if self.buckets[idx].compiled.is_none() {
300 let mut session = Session::new(self.device);
301 if let Some(p) = &self.policy {
302 session = session.with_policy(p.clone());
303 }
304 self.buckets[idx].compiled = Some(session.compile_with(build(upper), options));
305 }
306 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
307 }
308
309 pub fn get_or_compile_hir<F: FnOnce(u64) -> HirModule>(
312 &mut self,
313 key: u64,
314 build: F,
315 ) -> Option<(u64, &mut CompiledGraph)> {
316 self.get_or_compile_hir_with_options(key, build, &crate::CompileOptions::new())
317 }
318
319 pub fn get_or_compile_hir_with_options<F: FnOnce(u64) -> HirModule>(
321 &mut self,
322 key: u64,
323 build: F,
324 options: &crate::CompileOptions,
325 ) -> Option<(u64, &mut CompiledGraph)> {
326 let idx = self.bucket_for(key)?;
327 let upper = self.buckets[idx].range.end - 1;
328 if self.buckets[idx].compiled.is_none() {
329 let mut session = Session::new(self.device);
330 if let Some(p) = &self.policy {
331 session = session.with_policy(p.clone());
332 }
333 let compiled = session
334 .compile_hir_with(build(upper), options)
335 .expect("HIR lower/compile in bucketed cache");
336 self.buckets[idx].compiled = Some(compiled);
337 }
338 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
339 }
340
341 pub fn bucket_for(&self, key: u64) -> Option<usize> {
344 self.buckets.iter().position(|b| b.range.contains(&key))
345 }
346
347 pub fn buckets(&self) -> impl Iterator<Item = &Range<u64>> {
348 self.buckets.iter().map(|b| &b.range)
349 }
350
351 pub fn compiled_count(&self) -> usize {
353 self.buckets.iter().filter(|b| b.compiled.is_some()).count()
354 }
355
356 pub fn total_buckets(&self) -> usize {
357 self.buckets.len()
358 }
359
360 pub fn run_padded<F: FnOnce(u64) -> Graph>(
386 &mut self,
387 key: u64,
388 actual_rows: usize,
389 build: F,
390 inputs: &[(&str, &[f32], usize)],
391 output_inners: &[usize],
392 ) -> Option<(u64, Vec<Vec<f32>>)> {
393 let (upper, compiled) = self.get_or_compile(key, build)?;
394
395 let padded: Vec<(&str, Vec<f32>)> = inputs
397 .iter()
398 .map(|(name, data, inner)| (*name, pad_rows(data, *inner, upper)))
399 .collect();
400 let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
401
402 compiled.set_active_extent(Some((actual_rows, upper as usize)));
408 let raw_outputs = compiled.run(&pairs);
409 compiled.set_active_extent(None);
410
411 let outs = raw_outputs
412 .into_iter()
413 .enumerate()
414 .map(|(i, out)| match output_inners.get(i).copied() {
415 Some(0) | None => out,
416 Some(inner) => slice_rows(&out, inner, actual_rows),
417 })
418 .collect();
419
420 Some((upper, outs))
421 }
422}
423
424pub struct DynamicDimCompileCache {
432 device: Device,
433 policy: Option<rlx_opt::PrecisionPolicy>,
434 capacity: usize,
435 template: Option<CompileResult>,
436 entries: Vec<(u64, CompiledGraph)>,
437 order: VecDeque<u64>,
438}
439
440impl DynamicDimCompileCache {
441 pub fn new(device: Device, capacity: usize) -> Self {
442 Self::with_policy(device, capacity, None)
443 }
444
445 pub fn with_policy(
446 device: Device,
447 capacity: usize,
448 policy: Option<rlx_opt::PrecisionPolicy>,
449 ) -> Self {
450 assert!(capacity > 0, "DynamicDimCompileCache capacity must be ≥ 1");
451 Self {
452 device,
453 policy,
454 capacity,
455 template: None,
456 entries: Vec::with_capacity(capacity),
457 order: VecDeque::with_capacity(capacity),
458 }
459 }
460
461 pub fn compile_device(&self) -> Device {
462 self.device
463 }
464
465 pub fn get_or_specialize<F: FnOnce() -> HirModule>(
468 &mut self,
469 key: u64,
470 binding: &DimBinding,
471 build_hir: F,
472 options: &crate::CompileOptions,
473 ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError> {
474 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
475 return Ok(&mut self.entries[idx].1);
476 }
477 if self.template.is_none() {
478 let mut template_opts = options.clone();
479 template_opts.dim_binding = None;
480 let pipe = crate::stages::pipeline_for(self.device, &template_opts);
481 self.template = Some(pipe.compile_hir(build_hir())?);
482 }
483 let template = self.template.as_ref().expect("template just set");
484 let mut spec_opts = options.clone();
485 spec_opts.dim_binding = None;
486 let pipe = crate::stages::pipeline_for(self.device, &spec_opts);
487 let specialized = template.specialize(&pipe, binding);
488 let backend = crate::registry::backend_for(self.device).expect("backend registered");
489 let mut compile_opts = options.clone();
490 compile_opts.dim_binding = None;
491 if compile_opts.policy.is_none() {
492 if let Some(p) = &self.policy {
493 compile_opts = compile_opts.policy(p.clone());
494 }
495 }
496 let executable = backend.compile_lir(specialized.lir, &compile_opts);
497 let compiled = CompiledGraph::new(executable, self.device);
498
499 if self.entries.len() >= self.capacity
500 && let Some(evict_key) = self.order.pop_front()
501 {
502 self.entries.retain(|(k, _)| *k != evict_key);
503 }
504 self.entries.push((key, compiled));
505 self.order.push_back(key);
506 Ok(&mut self.entries.last_mut().unwrap().1)
507 }
508
509 pub fn len(&self) -> usize {
510 self.entries.len()
511 }
512
513 pub fn is_empty(&self) -> bool {
514 self.entries.is_empty()
515 }
516
517 pub fn contains(&self, key: u64) -> bool {
518 self.entries.iter().any(|(k, _)| *k == key)
519 }
520
521 pub fn has_template(&self) -> bool {
522 self.template.is_some()
523 }
524
525 pub fn ensure_template<F: FnOnce() -> HirModule>(
527 &mut self,
528 build_hir: F,
529 options: &crate::CompileOptions,
530 ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
531 if self.template.is_none() {
532 let mut opts = options.clone();
533 opts.dim_binding = None;
534 let pipe = crate::stages::pipeline_for(self.device, &opts);
535 self.template = Some(pipe.compile_hir(build_hir())?);
536 }
537 Ok(self.template.as_ref().expect("template set"))
538 }
539
540 pub fn template_result(&self) -> Option<&CompileResult> {
541 self.template.as_ref()
542 }
543
544 pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
547 &mut self,
548 aot: &crate::AotCache,
549 disk_base: &str,
550 key: u64,
551 binding: &rlx_ir::DimBinding,
552 build_hir: F,
553 options: &crate::CompileOptions,
554 ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
555 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
556 return Ok(&mut self.entries[idx].1);
557 }
558 let device = self.device;
559 let template = self.ensure_template(build_hir, options)?;
560 let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
561 if self.entries.len() >= self.capacity
562 && let Some(evict_key) = self.order.pop_front()
563 {
564 self.entries.retain(|(k, _)| *k != evict_key);
565 }
566 self.entries.push((key, compiled));
567 self.order.push_back(key);
568 Ok(&mut self.entries.last_mut().unwrap().1)
569 }
570}
571
/// Zero-pads a row-major buffer of `inner`-wide rows out to `upper` rows.
///
/// `data` holds `data.len() / inner` rows. The result has exactly
/// `upper * inner` elements: the original rows first, then zero-filled rows.
///
/// # Panics
/// - `inner == 0`
/// - `data.len()` is not a multiple of `inner`
/// - `data` holds more than `upper` rows
/// - `upper` does not fit in `usize`, or `upper * inner` overflows `usize`
pub fn pad_rows(data: &[f32], inner: usize, upper: u64) -> Vec<f32> {
    assert!(inner > 0, "pad_rows: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    // An `as` cast would silently truncate on 32-bit targets; fail loudly instead.
    let upper = usize::try_from(upper)
        .unwrap_or_else(|_| panic!("pad_rows: upper bound {upper} does not fit in usize"));
    let actual = data.len() / inner;
    assert!(
        actual <= upper,
        "pad_rows: actual rows {actual} exceed upper bound {upper}",
    );
    // Unchecked multiplication would wrap in release builds and size the buffer wrongly.
    let total = upper
        .checked_mul(inner)
        .unwrap_or_else(|| panic!("pad_rows: {upper} rows × {inner} overflows usize"));
    let mut out = Vec::with_capacity(total);
    out.extend_from_slice(data);
    out.resize(total, 0.0);
    out
}
597
/// Keeps only the first `actual` rows of a row-major buffer of `inner`-wide rows,
/// discarding trailing (padding) rows.
///
/// # Panics
/// - `inner == 0`
/// - `data.len()` is not a multiple of `inner`
/// - `actual` exceeds the number of rows in `data`
pub fn slice_rows(data: &[f32], inner: usize, actual: usize) -> Vec<f32> {
    assert!(inner > 0, "slice_rows: inner stride must be ≥ 1");
    let len = data.len();
    assert_eq!(
        len % inner,
        0,
        "slice_rows: data len {} not a multiple of inner {inner}",
        len,
    );
    let upper = len / inner;
    assert!(
        actual <= upper,
        "slice_rows: actual rows {actual} exceed upper {upper}",
    );
    data.chunks_exact(inner)
        .take(actual)
        .flatten()
        .copied()
        .collect()
}
618
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::*;
    use std::cell::Cell;

    /// Builds a one-op graph `y = relu(x)` over an `n`-element f32 input named `"x"`.
    fn tiny_graph(n: usize) -> Graph {
        let mut g = Graph::new("t");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[n], f));
        let y = g.activation(rlx_ir::op::Activation::Relu, x, Shape::new(&[n], f));
        g.set_outputs(vec![y]);
        g
    }

    #[test]
    fn cache_hits_avoid_recompile() {
        let mut cache = CompileCache::new(Device::Cpu, 4);
        let calls = Cell::new(0);

        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        assert_eq!(calls.get(), 1);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn fifo_evicts_oldest_at_capacity() {
        let mut cache = CompileCache::new(Device::Cpu, 2);
        let _ = cache.get_or_compile(1, || tiny_graph(4));
        let _ = cache.get_or_compile(2, || tiny_graph(8));
        assert!(cache.contains(1) && cache.contains(2));
        let _ = cache.get_or_compile(3, || tiny_graph(16));
        assert!(!cache.contains(1));
        assert!(cache.contains(2) && cache.contains(3));
    }

    #[test]
    fn different_keys_keep_separate_compiles() {
        let mut cache = CompileCache::new(Device::Cpu, 4);
        let calls = Cell::new(0);
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(2, || {
            calls.set(calls.get() + 1);
            tiny_graph(16)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        assert_eq!(calls.get(), 2);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn bucket_amortizes_keys_within_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
        let calls = Cell::new(0);
        let uppers = Cell::new((0u64, 0u64));

        let (u1, _) = cache
            .get_or_compile(2, |upper| {
                calls.set(calls.get() + 1);
                uppers.set((upper, uppers.get().1));
                tiny_graph(upper as usize)
            })
            .expect("key 2 in range");
        let (u2, _) = cache
            .get_or_compile(3, |upper| {
                calls.set(calls.get() + 1);
                uppers.set((uppers.get().0, upper));
                tiny_graph(upper as usize)
            })
            .expect("key 3 in range");

        assert_eq!(calls.get(), 1);
        assert_eq!(u1, 3);
        assert_eq!(u2, 3);
        assert_eq!(uppers.get().0, 3);
        assert_eq!(uppers.get().1, 0, "second build closure must not run on a hit");
        assert_eq!(cache.compiled_count(), 1);
        assert_eq!(cache.total_buckets(), 2);
    }

    #[test]
    fn bucket_lookup_returns_none_outside_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
        assert!(cache.bucket_for(0).is_none());
        assert!(cache.bucket_for(16).is_none());
        assert!(cache.bucket_for(100).is_none());
        assert_eq!(cache.bucket_for(3), Some(0));
        assert_eq!(cache.bucket_for(4), Some(1));

        let calls = Cell::new(0);
        let result = cache.get_or_compile(100, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        assert!(result.is_none());
        assert_eq!(calls.get(), 0);
        assert_eq!(cache.compiled_count(), 0);
    }

    #[test]
    fn bucket_compiles_lazily_per_bucket() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16, 16..64]);
        let calls = Cell::new(0);

        let _ = cache.get_or_compile(2, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        let _ = cache.get_or_compile(8, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        assert_eq!(calls.get(), 2);
        assert_eq!(cache.compiled_count(), 2);
        assert_eq!(cache.total_buckets(), 3);
    }

    #[test]
    #[should_panic(expected = "overlap")]
    fn bucket_overlap_rejected() {
        let _ = BucketedCompileCache::new(Device::Cpu, vec![1..8, 4..16]);
    }

    #[test]
    #[should_panic(expected = "≥1 bucket")]
    fn empty_bucket_list_rejected() {
        let _ = BucketedCompileCache::new(Device::Cpu, vec![]);
    }

    #[test]
    fn pad_rows_appends_zeros() {
        let p = pad_rows(&[1.0, 2.0, 3.0], 1, 5);
        assert_eq!(p, vec![1.0, 2.0, 3.0, 0.0, 0.0]);

        let p = pad_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 4);
        assert_eq!(
            p,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        );

        let p = pad_rows(&[7.0, 8.0], 1, 2);
        assert_eq!(p, vec![7.0, 8.0]);
    }

    #[test]
    fn slice_rows_truncates_trailing() {
        let s = slice_rows(&[1.0, 2.0, 3.0, 0.0, 0.0], 1, 3);
        assert_eq!(s, vec![1.0, 2.0, 3.0]);

        let s = slice_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0], 3, 2);
        assert_eq!(s, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "exceed upper")]
    fn pad_rows_rejects_too_long_input() {
        let _ = pad_rows(&[1.0, 2.0, 3.0, 4.0], 1, 3);
    }

    #[test]
    #[should_panic(expected = "exceed upper")]
    fn slice_rows_rejects_too_large_actual() {
        let _ = slice_rows(&[1.0, 2.0, 3.0], 1, 5);
    }

    #[test]
    fn run_padded_pads_input_and_slices_output() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0];

        let (upper, outs) = cache
            .run_padded(
                10,
                10,
                |max| tiny_graph(max as usize),
                &[("x", &input, 1)],
                &[1],
            )
            .expect("key 10 in [1..16)");

        assert_eq!(upper, 15);
        assert_eq!(outs.len(), 1);
        let out = &outs[0];
        assert_eq!(out.len(), 10, "output sliced back to actual_rows");
        let expected: Vec<f32> = input.iter().map(|x| x.max(0.0)).collect();
        assert_eq!(out, &expected);
    }

    #[test]
    fn run_padded_reuses_bucket_across_actuals() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let calls = Cell::new(0);

        let (u1, o1) = cache
            .run_padded(
                10,
                10,
                |max| {
                    calls.set(calls.get() + 1);
                    tiny_graph(max as usize)
                },
                &[(
                    "x",
                    &[1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0],
                    1,
                )],
                &[1],
            )
            .unwrap();
        assert_eq!(o1.len(), 1);
        assert_eq!(o1[0].len(), 10);
        assert_eq!(u1, 15);

        let (u2, o2) = cache
            .run_padded(
                5,
                5,
                |max| {
                    calls.set(calls.get() + 1);
                    tiny_graph(max as usize)
                },
                &[("x", &[-1.0, 2.0, -3.0, 4.0, -5.0], 1)],
                &[1],
            )
            .unwrap();
        assert_eq!(o2.len(), 1);
        assert_eq!(o2[0].len(), 5);
        assert_eq!(u2, 15);
        assert_eq!(o2[0], vec![0.0, 2.0, 0.0, 4.0, 0.0]);

        assert_eq!(calls.get(), 1, "bucket cached across actuals");
        assert_eq!(cache.compiled_count(), 1);
    }

    #[test]
    fn run_padded_returns_none_out_of_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let calls = Cell::new(0);
        let result = cache.run_padded(
            100,
            5,
            |u| {
                calls.set(calls.get() + 1);
                tiny_graph(u as usize)
            },
            &[("x", &[1.0, 2.0, 3.0, 4.0, 5.0], 1)],
            &[1],
        );
        assert!(result.is_none());
        assert_eq!(calls.get(), 0);
        assert_eq!(cache.compiled_count(), 0);
    }

    #[test]
    fn power_of_two_ladder_generates_log_buckets() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..9, 9..17, 17..33, 33..65]);
        assert_eq!(cache.total_buckets(), 4);
    }

    #[test]
    fn power_of_two_ladder_picks_smallest_extent_for_actual() {
        let mut cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
        let captured_uppers: std::cell::RefCell<Vec<u64>> = Default::default();

        let (u17, _) = cache
            .get_or_compile(17, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u9, _) = cache
            .get_or_compile(9, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u3, _) = cache
            .get_or_compile(3, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u64_, _) = cache
            .get_or_compile(64, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();

        assert_eq!(u17, 32, "key=17 → smallest extent ≥ 17 is 32");
        assert_eq!(u9, 16, "key=9 → smallest extent ≥ 9 is 16");
        assert_eq!(u3, 8, "key=3 → smallest extent ≥ 3 is 8");
        assert_eq!(u64_, 64, "key=64 → exact match at 64");
        assert_eq!(*captured_uppers.borrow(), vec![32, 16, 8, 64]);
        assert_eq!(cache.compiled_count(), 4);
    }

    #[test]
    fn power_of_two_ladder_min_above_one_starts_at_one() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 32);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17, 17..33]);
    }

    #[test]
    fn power_of_two_ladder_non_pow2_min_rounds_up() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 10, 64);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17, 17..33, 33..65]);
    }

    #[test]
    fn power_of_two_ladder_max_below_pow2_extends_up() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 20);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..9, 9..17, 17..33]);
    }

    #[test]
    fn power_of_two_ladder_min_equals_max() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 16);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17]);
    }

    #[test]
    #[should_panic(expected = "min must be ≥ 1")]
    fn power_of_two_ladder_zero_min_rejected() {
        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 0, 16);
    }

    #[test]
    #[should_panic(expected = "max")]
    fn power_of_two_ladder_max_below_min_rejected() {
        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 32, 8);
    }

    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_cpu_activation() {
        let graph = tiny_graph(15);
        let mut compiled = Session::new(Device::Cpu).compile(graph);

        let warm_input: Vec<f32> = vec![1.0; 15];
        let warm_outs = compiled.run(&[("x", &warm_input)]);
        assert_eq!(warm_outs[0], vec![1.0; 15], "warm-up sanity");

        let neg_input: Vec<f32> = vec![-1.0; 15];
        compiled.set_active_extent(Some((5, 15)));
        let outs = compiled.run(&[("x", &neg_input)]);
        let out = &outs[0];

        assert_eq!(out.len(), 15);
        assert_eq!(
            out[..5],
            [0.0; 5],
            "first 5 elements processed (relu of -1)"
        );
        assert_eq!(
            out[5..],
            [1.0; 10],
            "tail untouched — proves Copy + Activation skipped indices 5..15"
        );

        compiled.set_active_extent(None);
        let outs = compiled.run(&[("x", &neg_input)]);
        assert_eq!(
            outs[0],
            vec![0.0; 15],
            "full-extent path must clip every negative"
        );
    }

    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_binary_full() {
        let mut g = Graph::new("add");
        let f = DType::F32;
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let c = g.add(a, b);
        g.set_outputs(vec![c]);
        let mut compiled = Session::new(Device::Cpu).compile(g);

        let warm = compiled.run(&[("a", &[1.0f32; 4]), ("b", &[1.0f32; 4])]);
        assert_eq!(warm[0], vec![2.0; 4]);

        compiled.set_active_extent(Some((2, 4)));
        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
        let out = &outs[0];
        assert_eq!(out[..2], [20.0, 20.0], "first 2 = active sum");
        assert_eq!(
            out[2..],
            [2.0, 2.0],
            "tail untouched — proves BinaryFull skipped indices 2..4"
        );

        compiled.set_active_extent(None);
        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
        assert_eq!(outs[0], vec![20.0; 4]);
    }

    #[test]
    #[ignore = "process-wide STATE; runs only in isolation via `cargo test perfetto -- --ignored`"]
    fn perfetto_trace_emits_per_thunk_events() {
        use std::env;
        use std::fs;
        let path = env::temp_dir().join(format!("rlx-perfetto-e2e-{}.json", std::process::id()));
        if path.exists() {
            let _ = fs::remove_file(&path);
        }
        // SAFETY: this test is `#[ignore]`d and meant to run in isolation (see the
        // ignore reason), so no other test thread touches the environment concurrently.
        unsafe {
            env::set_var("RLX_TRACE_PERFETTO", &path);
        }

        let f = DType::F32;
        let mut g = Graph::new("perf");
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let s = g.add(a, b);
        let r = g.relu(s);
        g.set_outputs(vec![r]);
        let mut compiled = Session::new(Device::Cpu).compile(g);
        let _ = compiled.run(&[("a", &[1.0; 4]), ("b", &[1.0; 4])]);

        crate::perfetto::flush_and_finalize();

        let contents = fs::read_to_string(&path).expect("trace file");
        assert!(
            contents.contains("\"binary\"")
                || contents.contains("\"activation\"")
                || contents.contains("\"elementwise_region\""),
            "expected at least one thunk-name event in perfetto trace; got: {contents}"
        );
        assert!(contents.trim_start().starts_with('['));
        let _ = fs::remove_file(&path);
        // SAFETY: same single-test isolation as the `set_var` above; undo it so the
        // variable does not leak into anything else running in this process.
        unsafe {
            env::remove_var("RLX_TRACE_PERFETTO");
        }
    }

    #[test]
    fn elementwise_region_fused_matches_unfused() {
        let f = DType::F32;
        let mut g = Graph::new("ew_e2e");
        let a = g.input("a", Shape::new(&[8], f));
        let b = g.input("b", Shape::new(&[8], f));
        let c = g.input("c", Shape::new(&[8], f));
        let add = g.add(a, b);
        let mul = g.mul(add, c);
        let relu = g.relu(mul);
        g.set_outputs(vec![relu]);

        let mut compiled = Session::new(Device::Cpu).compile(g);
        let av: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0, 0.5, -0.5, 1.5, -1.5];
        let bv: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0, 0.5, 0.5, 0.5, 0.5];
        let cv: Vec<f32> = vec![1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 0.5, 4.0];
        let outs = compiled.run(&[("a", &av), ("b", &bv), ("c", &cv)]);
        let out = &outs[0];

        let expected: Vec<f32> = (0..8)
            .map(|i| {
                let v = (av[i] + bv[i]) * cv[i];
                v.max(0.0)
            })
            .collect();
        for (i, (got, exp)) in out.iter().zip(&expected).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "mismatch at {i}: got {got}, expected {exp}"
            );
        }
    }

    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_attention() {
        use rlx_ir::op::MaskKind;
        let f = DType::F32;
        let mut g = Graph::new("attn");
        let q = g.input("q", Shape::new(&[1, 4, 8], f));
        let k = g.input("k", Shape::new(&[1, 4, 8], f));
        let v = g.input("v", Shape::new(&[1, 4, 8], f));
        let out = g.attention_kind(q, k, v, 2, 4, MaskKind::None, Shape::new(&[1, 4, 8], f));
        g.set_outputs(vec![out]);
        let mut compiled = Session::new(Device::Cpu).compile(g);

        let warm = compiled.run(&[
            ("q", &[1.0f32; 32]),
            ("k", &[1.0f32; 32]),
            ("v", &[1.0f32; 32]),
        ]);
        let warm_out = warm[0].clone();
        assert_eq!(warm_out.len(), 32);

        compiled.set_active_extent(Some((2, 4)));
        let outs = compiled.run(&[
            ("q", &[3.0f32; 32]),
            ("k", &[3.0f32; 32]),
            ("v", &[3.0f32; 32]),
        ]);
        let out = &outs[0];
        assert_eq!(out.len(), 32);
        assert_eq!(
            &out[16..],
            &warm_out[16..],
            "tail (positions 2,3) must be untouched — proves Attention skipped"
        );
        assert_ne!(
            &out[..16],
            &warm_out[..16],
            "first 2 positions should reflect new input"
        );
    }

    #[test]
    fn active_extent_falls_back_when_unsupported_thunk_in_schedule() {
        // NOTE(review): this test body is empty, so it currently passes without
        // exercising anything. Implement the fallback check or remove the test.
    }

    #[test]
    fn run_padded_uses_active_extent_on_cpu() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![
            1.0, -1.0, 2.0, -2.0, 3.0, -10.0, -20.0, -30.0, -40.0, -50.0,
        ];
        let (upper, outs) = cache
            .run_padded(
                5,
                5,
                |max| tiny_graph(max as usize),
                &[("x", &input[..5], 1)],
                &[1],
            )
            .unwrap();
        assert_eq!(upper, 15);
        assert_eq!(outs[0].len(), 5);
        assert_eq!(outs[0], vec![1.0, 0.0, 2.0, 0.0, 3.0]);
    }

    #[test]
    fn run_padded_inner_zero_returns_output_unsliced() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0];

        let (upper, outs) = cache
            .run_padded(
                5,
                5,
                |max| tiny_graph(max as usize),
                &[("x", &input, 1)],
                &[0],
            )
            .unwrap();

        assert_eq!(upper, 15);
        assert_eq!(
            outs[0].len(),
            15,
            "unsliced output preserves full upper extent"
        );
        assert_eq!(&outs[0][..5], &[1.0, 0.0, 2.0, 0.0, 3.0]);
        assert!(outs[0][5..].iter().all(|&v| v == 0.0));
    }

    #[test]
    fn dynamic_dim_cache_specializes_per_key() {
        use rlx_ir::DType;
        use rlx_ir::Shape;
        use rlx_ir::hir::HirModule;
        use rlx_ir::sym;

        let mut cache = DynamicDimCompileCache::new(Device::Cpu, 4);
        let opts = crate::CompileOptions::new();
        {
            let _short = cache
                .get_or_specialize(
                    8,
                    &rlx_ir::DimBinding::batch_seq(1, 8),
                    || {
                        let mut hir = HirModule::new("dyn_cache");
                        let x = hir.input_batch_seq("x", sym::BATCH, sym::SEQ, 4, DType::F32);
                        let w = hir.param("w", Shape::new(&[4, 2], DType::F32));
                        let y = hir.linear(
                            x,
                            w,
                            None,
                            None,
                            Shape::batch_seq(sym::BATCH, sym::SEQ, 2, DType::F32),
                        );
                        hir.set_outputs(vec![y]);
                        hir
                    },
                    &opts,
                )
                .expect("specialize short");
        }
        assert!(cache.has_template());
        assert_eq!(cache.len(), 1);
        cache
            .get_or_specialize(
                128,
                &rlx_ir::DimBinding::batch_seq(1, 128),
                || panic!("HIR builder must not run twice"),
                &opts,
            )
            .expect("specialize long");
        assert_eq!(cache.len(), 2);
    }
}