1use crate::{CompiledGraph, Device, Session};
40use rlx_ir::DimBinding;
41use rlx_ir::Graph;
42use rlx_ir::hir::HirModule;
43use rlx_opt::CompileResult;
44use std::collections::HashMap;
45use std::collections::VecDeque;
46use std::ops::Range;
47
/// One named input handed to [`BucketedCompileCache::run_padded_mixed`].
///
/// `row_inner` decides how `data` reaches the compiled graph:
/// * `Some(inner)` — `data` is treated as rows of `inner` elements and is
///   zero-padded (via [`pad_rows`]) up to the bucket's upper bound.
/// * `None` — `data` is forwarded unchanged.
#[derive(Debug, Clone, Copy)]
pub struct CacheRunInput<'a> {
    /// Name of the graph input this data is bound to.
    pub name: &'a str,
    /// Flat `f32` payload for the input.
    pub data: &'a [f32],
    /// Row width for padded inputs, or `None` for pass-through inputs.
    pub row_inner: Option<usize>,
}
55
56pub struct CompileCache {
57 device: Device,
58 capacity: usize,
59 policy: Option<rlx_opt::PrecisionPolicy>,
62 entries: Vec<(u64, CompiledGraph)>,
66 order: VecDeque<u64>,
68}
69
70impl CompileCache {
71 pub fn new(device: Device, capacity: usize) -> Self {
72 Self::with_policy(device, capacity, None)
73 }
74
75 pub fn with_policy(
79 device: Device,
80 capacity: usize,
81 policy: Option<rlx_opt::PrecisionPolicy>,
82 ) -> Self {
83 assert!(capacity > 0, "CompileCache capacity must be ≥ 1");
84 Self {
85 device,
86 capacity,
87 policy,
88 entries: Vec::with_capacity(capacity),
89 order: VecDeque::with_capacity(capacity),
90 }
91 }
92
93 pub fn get_or_compile<F: FnOnce() -> Graph>(
97 &mut self,
98 key: u64,
99 build: F,
100 ) -> &mut CompiledGraph {
101 self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
102 }
103
104 pub fn get_or_compile_with_options<F: FnOnce() -> Graph>(
106 &mut self,
107 key: u64,
108 build: F,
109 options: &crate::CompileOptions,
110 ) -> &mut CompiledGraph {
111 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
112 return &mut self.entries[idx].1;
113 }
114 let mut session = Session::new(self.device);
115 if let Some(p) = &self.policy {
116 session = session.with_policy(p.clone());
117 }
118 let compiled = session.compile_with(build(), options);
119
120 if self.entries.len() >= self.capacity
122 && let Some(evict_key) = self.order.pop_front()
123 {
124 sync_evicted_entry(&mut self.entries, evict_key);
125 self.entries.retain(|(k, _)| *k != evict_key);
126 }
127 self.entries.push((key, compiled));
128 self.order.push_back(key);
129 &mut self.entries.last_mut().unwrap().1
130 }
131
132 pub fn len(&self) -> usize {
134 self.entries.len()
135 }
136 pub fn is_empty(&self) -> bool {
137 self.entries.is_empty()
138 }
139 pub fn contains(&self, key: u64) -> bool {
141 self.entries.iter().any(|(k, _)| *k == key)
142 }
143
144 pub fn sync_all(&mut self) {
146 for (_, compiled) in &mut self.entries {
147 compiled.sync_pending();
148 }
149 }
150}
151
152fn sync_evicted_entry(entries: &mut [(u64, CompiledGraph)], evict_key: u64) {
153 if let Some((_, compiled)) = entries.iter_mut().find(|(k, _)| *k == evict_key) {
154 compiled.sync_pending();
155 }
156}
157
158pub struct BucketedCompileCache {
205 device: Device,
206 policy: Option<rlx_opt::PrecisionPolicy>,
207 buckets: Vec<Bucket>,
208}
209
210struct Bucket {
211 range: Range<u64>,
212 compiled: Option<CompiledGraph>,
213}
214
215impl BucketedCompileCache {
216 pub fn new(device: Device, buckets: Vec<Range<u64>>) -> Self {
217 Self::with_policy(device, buckets, None)
218 }
219
220 pub fn power_of_two_ladder(device: Device, min: u64, max: u64) -> Self {
236 Self::power_of_two_ladder_with_policy(device, min, max, None)
237 }
238
239 pub fn power_of_two_ladder_with_policy(
240 device: Device,
241 min: u64,
242 max: u64,
243 policy: Option<rlx_opt::PrecisionPolicy>,
244 ) -> Self {
245 assert!(min >= 1, "power_of_two_ladder: min must be ≥ 1, got {min}");
246 assert!(
247 max >= min,
248 "power_of_two_ladder: max ({max}) must be ≥ min ({min})"
249 );
250 let mut buckets: Vec<Range<u64>> = Vec::new();
251 let mut start = 1u64;
252 let mut extent = min.next_power_of_two();
253 loop {
254 buckets.push(start..(extent + 1));
255 if extent >= max {
256 break;
257 }
258 start = extent + 1;
259 extent = extent
260 .checked_mul(2)
261 .expect("power_of_two_ladder: extent overflow");
262 }
263 Self::with_policy(device, buckets, policy)
264 }
265
266 pub fn with_policy(
267 device: Device,
268 buckets: Vec<Range<u64>>,
269 policy: Option<rlx_opt::PrecisionPolicy>,
270 ) -> Self {
271 assert!(!buckets.is_empty(), "BucketedCompileCache needs ≥1 bucket");
272 for (i, b) in buckets.iter().enumerate() {
273 assert!(b.start < b.end, "bucket {i} ({b:?}) is empty");
274 if i + 1 < buckets.len() {
275 assert!(
276 b.end <= buckets[i + 1].start,
277 "buckets {i} ({b:?}) and {} ({:?}) overlap",
278 i + 1,
279 buckets[i + 1],
280 );
281 }
282 }
283 let buckets = buckets
284 .into_iter()
285 .map(|range| Bucket {
286 range,
287 compiled: None,
288 })
289 .collect();
290 Self {
291 device,
292 policy,
293 buckets,
294 }
295 }
296
297 pub fn get_or_compile<F: FnOnce(u64) -> Graph>(
306 &mut self,
307 key: u64,
308 build: F,
309 ) -> Option<(u64, &mut CompiledGraph)> {
310 self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
311 }
312
313 pub fn get_or_compile_with_options<F: FnOnce(u64) -> Graph>(
315 &mut self,
316 key: u64,
317 build: F,
318 options: &crate::CompileOptions,
319 ) -> Option<(u64, &mut CompiledGraph)> {
320 let idx = self.bucket_for(key)?;
321 let upper = self.buckets[idx].range.end - 1;
322 if self.buckets[idx].compiled.is_none() {
323 let mut session = Session::new(self.device);
324 if let Some(p) = &self.policy {
325 session = session.with_policy(p.clone());
326 }
327 self.buckets[idx].compiled = Some(session.compile_with(build(upper), options));
328 }
329 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
330 }
331
332 pub fn get_or_compile_hir<F: FnOnce(u64) -> HirModule>(
335 &mut self,
336 key: u64,
337 build: F,
338 ) -> Option<(u64, &mut CompiledGraph)> {
339 self.get_or_compile_hir_with_options(key, build, &crate::CompileOptions::new())
340 }
341
342 pub fn get_or_compile_hir_with_options<F: FnOnce(u64) -> HirModule>(
344 &mut self,
345 key: u64,
346 build: F,
347 options: &crate::CompileOptions,
348 ) -> Option<(u64, &mut CompiledGraph)> {
349 let idx = self.bucket_for(key)?;
350 let upper = self.buckets[idx].range.end - 1;
351 if self.buckets[idx].compiled.is_none() {
352 let mut session = Session::new(self.device);
353 if let Some(p) = &self.policy {
354 session = session.with_policy(p.clone());
355 }
356 let compiled = session
357 .compile_hir_with(build(upper), options)
358 .expect("HIR lower/compile in bucketed cache");
359 self.buckets[idx].compiled = Some(compiled);
360 }
361 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
362 }
363
364 pub fn bucket_for(&self, key: u64) -> Option<usize> {
367 self.buckets.iter().position(|b| b.range.contains(&key))
368 }
369
370 pub fn bucket_upper_for_key(&self, key: u64) -> Option<u64> {
372 let idx = self.bucket_for(key)?;
373 Some(self.buckets[idx].range.end - 1)
374 }
375
376 pub fn buckets(&self) -> impl Iterator<Item = &Range<u64>> {
377 self.buckets.iter().map(|b| &b.range)
378 }
379
380 pub fn compiled_count(&self) -> usize {
382 self.buckets.iter().filter(|b| b.compiled.is_some()).count()
383 }
384
385 pub fn compiled_for_key_mut(&mut self, key: u64) -> Option<&mut CompiledGraph> {
387 let idx = self.bucket_for(key)?;
388 self.buckets[idx].compiled.as_mut()
389 }
390
391 pub fn total_buckets(&self) -> usize {
392 self.buckets.len()
393 }
394
395 pub fn run_padded<F: FnOnce(u64) -> Graph>(
421 &mut self,
422 key: u64,
423 actual_rows: usize,
424 build: F,
425 inputs: &[(&str, &[f32], usize)],
426 output_inners: &[usize],
427 ) -> Option<(u64, Vec<Vec<f32>>)> {
428 let (upper, compiled) = self.get_or_compile(key, build)?;
429
430 let padded: Vec<(&str, Vec<f32>)> = inputs
432 .iter()
433 .map(|(name, data, inner)| (*name, pad_rows(data, *inner, upper)))
434 .collect();
435 let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
436
437 compiled.set_active_extent(Some((actual_rows, upper as usize)));
443 let raw_outputs = compiled.run(&pairs);
444 compiled.set_active_extent(None);
445
446 let outs = raw_outputs
447 .into_iter()
448 .enumerate()
449 .map(|(i, out)| match output_inners.get(i).copied() {
450 Some(0) | None => out,
451 Some(inner) => slice_rows(&out, inner, actual_rows),
452 })
453 .collect();
454
455 Some((upper, outs))
456 }
457
458 pub fn ensure_graph_with_params<F>(
460 &mut self,
461 key: u64,
462 build: F,
463 options: &crate::CompileOptions,
464 ) -> Option<(u64, &mut CompiledGraph)>
465 where
466 F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
467 {
468 let idx = self.bucket_for(key)?;
469 let upper = self.buckets[idx].range.end - 1;
470 if self.buckets[idx].compiled.is_none() {
471 let (graph, params) = build(upper);
472 let mut session = Session::new(self.device);
473 if let Some(p) = &self.policy {
474 session = session.with_policy(p.clone());
475 }
476 let mut compiled = session.compile_with(graph, options);
477 for (name, data) in params {
478 compiled.set_param(&name, &data);
479 }
480 self.buckets[idx].compiled = Some(compiled);
481 }
482 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
483 }
484
485 pub fn ensure_hir_with_params<F>(
487 &mut self,
488 key: u64,
489 build: F,
490 options: &crate::CompileOptions,
491 ) -> Option<(u64, &mut CompiledGraph)>
492 where
493 F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
494 {
495 let idx = self.bucket_for(key)?;
496 let upper = self.buckets[idx].range.end - 1;
497 if self.buckets[idx].compiled.is_none() {
498 let (hir, params) = build(upper);
499 let mut session = Session::new(self.device);
500 if let Some(p) = &self.policy {
501 session = session.with_policy(p.clone());
502 }
503 let mut compiled = session
504 .compile_hir_with(hir, options)
505 .expect("HIR lower/compile in ensure_hir_with_params");
506 for (name, data) in params {
507 compiled.set_param(&name, &data);
508 }
509 self.buckets[idx].compiled = Some(compiled);
510 }
511 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
512 }
513
514 pub fn run_padded_mixed<F>(
516 &mut self,
517 key: u64,
518 actual_rows: usize,
519 build: F,
520 inputs: &[CacheRunInput<'_>],
521 output_inners: &[usize],
522 ) -> Option<(u64, Vec<Vec<f32>>)>
523 where
524 F: FnOnce(u64) -> Graph,
525 {
526 let (upper, compiled) = self.get_or_compile(key, build)?;
527
528 let padded: Vec<(&str, Vec<f32>)> = inputs
529 .iter()
530 .map(|inp| match inp.row_inner {
531 Some(inner) => (inp.name, pad_rows(inp.data, inner, upper)),
532 None => (inp.name, inp.data.to_vec()),
533 })
534 .collect();
535 let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
536
537 compiled.set_active_extent(Some((actual_rows, upper as usize)));
538 let raw_outputs = compiled.run(&pairs);
539 compiled.set_active_extent(None);
540
541 let outs = raw_outputs
542 .into_iter()
543 .enumerate()
544 .map(|(i, out)| match output_inners.get(i).copied() {
545 Some(0) | None => out,
546 Some(inner) => slice_rows(&out, inner, actual_rows),
547 })
548 .collect();
549
550 Some((upper, outs))
551 }
552
553 pub fn sync_all(&mut self) {
555 for bucket in &mut self.buckets {
556 if let Some(compiled) = &mut bucket.compiled {
557 compiled.sync_pending();
558 }
559 }
560 }
561}
562
563pub struct DynamicDimCompileCache {
571 device: Device,
572 policy: Option<rlx_opt::PrecisionPolicy>,
573 capacity: usize,
574 template: Option<CompileResult>,
575 entries: Vec<(u64, CompiledGraph)>,
576 order: VecDeque<u64>,
577}
578
579impl DynamicDimCompileCache {
580 pub fn new(device: Device, capacity: usize) -> Self {
581 Self::with_policy(device, capacity, None)
582 }
583
584 pub fn with_policy(
585 device: Device,
586 capacity: usize,
587 policy: Option<rlx_opt::PrecisionPolicy>,
588 ) -> Self {
589 assert!(capacity > 0, "DynamicDimCompileCache capacity must be ≥ 1");
590 Self {
591 device,
592 policy,
593 capacity,
594 template: None,
595 entries: Vec::with_capacity(capacity),
596 order: VecDeque::with_capacity(capacity),
597 }
598 }
599
600 pub fn compile_device(&self) -> Device {
601 self.device
602 }
603
604 pub fn get_or_specialize<F: FnOnce() -> HirModule>(
607 &mut self,
608 key: u64,
609 binding: &DimBinding,
610 build_hir: F,
611 options: &crate::CompileOptions,
612 ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError> {
613 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
614 return Ok(&mut self.entries[idx].1);
615 }
616 if self.template.is_none() {
617 let mut template_opts = options.clone();
618 template_opts.dim_binding = None;
619 let pipe = crate::stages::pipeline_for(self.device, &template_opts);
620 self.template = Some(pipe.compile_hir(build_hir())?);
621 }
622 let template = self.template.as_ref().expect("template just set");
623 let mut spec_opts = options.clone();
624 spec_opts.dim_binding = None;
625 let pipe = crate::stages::pipeline_for(self.device, &spec_opts);
626 let specialized = template.specialize(&pipe, binding);
627 let backend = crate::registry::backend_for(self.device).expect("backend registered");
628 let mut compile_opts = options.clone();
629 compile_opts.dim_binding = None;
630 if compile_opts.policy.is_none() {
631 if let Some(p) = &self.policy {
632 compile_opts = compile_opts.policy(p.clone());
633 }
634 }
635 let executable = backend.compile_lir(specialized.lir, &compile_opts);
636 let compiled = CompiledGraph::new(executable, self.device);
637
638 if self.entries.len() >= self.capacity
639 && let Some(evict_key) = self.order.pop_front()
640 {
641 sync_evicted_entry(&mut self.entries, evict_key);
642 self.entries.retain(|(k, _)| *k != evict_key);
643 }
644 self.entries.push((key, compiled));
645 self.order.push_back(key);
646 Ok(&mut self.entries.last_mut().unwrap().1)
647 }
648
649 pub fn len(&self) -> usize {
650 self.entries.len()
651 }
652
653 pub fn is_empty(&self) -> bool {
654 self.entries.is_empty()
655 }
656
657 pub fn contains(&self, key: u64) -> bool {
658 self.entries.iter().any(|(k, _)| *k == key)
659 }
660
661 pub fn has_template(&self) -> bool {
662 self.template.is_some()
663 }
664
665 pub fn sync_all(&mut self) {
667 for (_, compiled) in &mut self.entries {
668 compiled.sync_pending();
669 }
670 }
671
672 pub fn ensure_template<F: FnOnce() -> HirModule>(
674 &mut self,
675 build_hir: F,
676 options: &crate::CompileOptions,
677 ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
678 if self.template.is_none() {
679 let mut opts = options.clone();
680 opts.dim_binding = None;
681 let pipe = crate::stages::pipeline_for(self.device, &opts);
682 self.template = Some(pipe.compile_hir(build_hir())?);
683 }
684 Ok(self.template.as_ref().expect("template set"))
685 }
686
687 pub fn template_result(&self) -> Option<&CompileResult> {
688 self.template.as_ref()
689 }
690
691 pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
694 &mut self,
695 aot: &crate::AotCache,
696 disk_base: &str,
697 key: u64,
698 binding: &rlx_ir::DimBinding,
699 build_hir: F,
700 options: &crate::CompileOptions,
701 ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
702 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
703 return Ok(&mut self.entries[idx].1);
704 }
705 let device = self.device;
706 let template = self.ensure_template(build_hir, options)?;
707 let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
708 if self.entries.len() >= self.capacity
709 && let Some(evict_key) = self.order.pop_front()
710 {
711 sync_evicted_entry(&mut self.entries, evict_key);
712 self.entries.retain(|(k, _)| *k != evict_key);
713 }
714 self.entries.push((key, compiled));
715 self.order.push_back(key);
716 Ok(&mut self.entries.last_mut().unwrap().1)
717 }
718}
719
/// Copies `data` (rows of `inner` elements) into a new buffer of `upper`
/// rows, zero-filling the rows past the end of `data`.
///
/// # Panics
/// Panics if `inner` is 0, if `data.len()` is not a multiple of `inner`, if
/// `data` holds more than `upper` rows, or if `upper` / `upper * inner` does
/// not fit in a `usize`.
pub fn pad_rows(data: &[f32], inner: usize, upper: u64) -> Vec<f32> {
    assert!(inner > 0, "pad_rows: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    // Checked conversion: `as usize` would silently truncate on targets
    // where `usize` is narrower than 64 bits.
    let upper = usize::try_from(upper).expect("pad_rows: upper bound does not fit in usize");
    let actual = data.len() / inner;
    assert!(
        actual <= upper,
        "pad_rows: actual rows {actual} exceed upper bound {upper}",
    );
    // Checked product: a wrapped length would under-allocate in release.
    let padded_len = upper
        .checked_mul(inner)
        .expect("pad_rows: padded length overflows usize");
    let mut out = vec![0.0_f32; padded_len];
    out[..data.len()].copy_from_slice(data);
    out
}
745
/// In-place variant of [`pad_rows`]: writes `data` to the start of `out` and
/// zeroes the remainder. The row capacity is `out.len() / inner`.
///
/// # Panics
/// Panics if `inner` is 0, if `data.len()` or `out.len()` is not a multiple
/// of `inner`, or if `data` holds more rows than `out` can.
pub fn pad_rows_into(out: &mut [f32], data: &[f32], inner: usize) {
    assert!(inner > 0, "pad_rows_into: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows_into: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    assert_eq!(
        out.len() % inner,
        0,
        "pad_rows_into: out len {} not a multiple of inner {inner}",
        out.len(),
    );
    let upper = out.len() / inner;
    let actual = data.len() / inner;
    assert!(
        actual <= upper,
        "pad_rows_into: actual rows {actual} exceed upper bound {upper}",
    );
    // Write each element once: copy the payload, then zero only the tail
    // (the split is in bounds thanks to the row check above).
    let (payload, padding) = out.split_at_mut(data.len());
    payload.copy_from_slice(data);
    padding.fill(0.0);
}
770
/// Returns a copy of the first `actual` rows of `data`, where each row is
/// `inner` elements wide.
///
/// # Panics
/// Panics if `inner` is 0, if `data.len()` is not a multiple of `inner`, or
/// if `actual` exceeds the number of rows in `data`.
pub fn slice_rows(data: &[f32], inner: usize, actual: usize) -> Vec<f32> {
    assert!(inner > 0, "slice_rows: inner stride must be ≥ 1");
    let len = data.len();
    assert_eq!(
        len % inner,
        0,
        "slice_rows: data len {} not a multiple of inner {inner}",
        len,
    );
    let upper = len / inner;
    assert!(
        actual <= upper,
        "slice_rows: actual rows {actual} exceed upper {upper}",
    );
    let keep = actual * inner;
    Vec::from(&data[..keep])
}
791
792#[cfg(test)]
793mod tests {
794 use super::*;
795 use rlx_ir::infer::GraphExt;
796 use rlx_ir::*;
797 use std::cell::Cell;
798
799 fn tiny_graph(n: usize) -> Graph {
800 let mut g = Graph::new("t");
801 let f = DType::F32;
802 let x = g.input("x", Shape::new(&[n], f));
803 let y = g.activation(rlx_ir::op::Activation::Relu, x, Shape::new(&[n], f));
804 g.set_outputs(vec![y]);
805 g
806 }
807
/// Asking for the same key repeatedly must invoke the builder only once.
#[test]
fn cache_hits_avoid_recompile() {
    let mut cache = CompileCache::new(Device::Cpu, 4);
    let builds = Cell::new(0);

    for _ in 0..3 {
        let _ = cache.get_or_compile(1, || {
            builds.set(builds.get() + 1);
            tiny_graph(8)
        });
    }

    assert_eq!(builds.get(), 1);
    assert_eq!(cache.len(), 1);
}
829
/// With capacity 2, inserting a third key drops the first-inserted one.
#[test]
fn fifo_evicts_oldest_at_capacity() {
    let mut cache = CompileCache::new(Device::Cpu, 2);
    for (key, width) in [(1, 4), (2, 8)] {
        let _ = cache.get_or_compile(key, || tiny_graph(width));
    }
    assert!(cache.contains(1) && cache.contains(2));

    let _ = cache.get_or_compile(3, || tiny_graph(16));
    assert!(!cache.contains(1));
    assert!(cache.contains(2) && cache.contains(3));
}
841
/// Two distinct keys compile twice; revisiting the first key is a hit.
#[test]
fn different_keys_keep_separate_compiles() {
    let mut cache = CompileCache::new(Device::Cpu, 4);
    let builds = Cell::new(0);

    for (key, width) in [(1, 8), (2, 16), (1, 8)] {
        let _ = cache.get_or_compile(key, || {
            builds.set(builds.get() + 1);
            tiny_graph(width)
        });
    }

    assert_eq!(builds.get(), 2);
    assert_eq!(cache.len(), 2);
}
862
/// Keys 2 and 3 share bucket `1..4`: one build, called with upper bound 3.
#[test]
fn bucket_amortizes_keys_within_range() {
    let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
    let builds = Cell::new(0);
    let first_build_upper = Cell::new(0u64);

    let (u1, _) = cache
        .get_or_compile(2, |upper| {
            builds.set(builds.get() + 1);
            first_build_upper.set(upper);
            tiny_graph(upper as usize)
        })
        .expect("key 2 in range");
    let (u2, _) = cache
        .get_or_compile(3, |upper| {
            builds.set(builds.get() + 1);
            tiny_graph(upper as usize)
        })
        .expect("key 3 in range");

    assert_eq!(builds.get(), 1);
    assert_eq!(u1, 3);
    assert_eq!(u2, 3);
    assert_eq!(first_build_upper.get(), 3);
    assert_eq!(cache.compiled_count(), 1);
    assert_eq!(cache.total_buckets(), 2);
}
895
/// Keys outside every bucket resolve to `None` and never trigger a build.
#[test]
fn bucket_lookup_returns_none_outside_range() {
    let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);

    for outside in [0, 16, 100] {
        assert!(cache.bucket_for(outside).is_none());
    }
    assert_eq!(cache.bucket_for(3), Some(0));
    assert_eq!(cache.bucket_for(4), Some(1));
    assert_eq!(cache.bucket_upper_for_key(3), Some(3));
    assert_eq!(cache.bucket_upper_for_key(4), Some(15));
    assert!(cache.bucket_upper_for_key(0).is_none());

    let builds = Cell::new(0);
    let miss = cache.get_or_compile(100, |upper| {
        builds.set(builds.get() + 1);
        tiny_graph(upper as usize)
    });
    assert!(miss.is_none());
    assert_eq!(builds.get(), 0);
    assert_eq!(cache.compiled_count(), 0);
}
917
/// Only buckets that are actually requested get compiled.
#[test]
fn bucket_compiles_lazily_per_bucket() {
    let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16, 16..64]);
    let builds = Cell::new(0);

    for key in [2, 8] {
        let _ = cache.get_or_compile(key, |upper| {
            builds.set(builds.get() + 1);
            tiny_graph(upper as usize)
        });
    }

    assert_eq!(builds.get(), 2);
    assert_eq!(cache.compiled_count(), 2);
    assert_eq!(cache.total_buckets(), 3);
}
936
/// Overlapping ranges are refused at construction time.
#[test]
#[should_panic(expected = "overlap")]
fn bucket_overlap_rejected() {
    let overlapping = vec![1..8, 4..16];
    let _ = BucketedCompileCache::new(Device::Cpu, overlapping);
}
942
/// A cache without any bucket is refused at construction time.
#[test]
#[should_panic(expected = "≥1 bucket")]
fn empty_bucket_list_rejected() {
    let no_buckets = Vec::new();
    let _ = BucketedCompileCache::new(Device::Cpu, no_buckets);
}
948
/// `pad_rows` zero-fills missing rows and is a no-op copy when already full.
#[test]
fn pad_rows_appends_zeros() {
    // 3 rows of width 1 padded to 5 rows.
    let scalar_rows = pad_rows(&[1.0, 2.0, 3.0], 1, 5);
    assert_eq!(scalar_rows, vec![1.0, 2.0, 3.0, 0.0, 0.0]);

    // 2 rows of width 3 padded to 4 rows.
    let wide_rows = pad_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 4);
    assert_eq!(
        wide_rows,
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    );

    // Already at the upper bound: nothing to add.
    let already_full = pad_rows(&[7.0, 8.0], 1, 2);
    assert_eq!(already_full, vec![7.0, 8.0]);
}
968
/// `slice_rows` keeps only the leading `actual` rows.
#[test]
fn slice_rows_truncates_trailing() {
    let scalar_rows = slice_rows(&[1.0, 2.0, 3.0, 0.0, 0.0], 1, 3);
    assert_eq!(scalar_rows, vec![1.0, 2.0, 3.0]);

    let wide_rows = slice_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0], 3, 2);
    assert_eq!(wide_rows, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
977
/// Four rows cannot be padded into an upper bound of three.
#[test]
#[should_panic(expected = "exceed upper")]
fn pad_rows_rejects_too_long_input() {
    let four_rows = [1.0, 2.0, 3.0, 4.0];
    let _ = pad_rows(&four_rows, 1, 3);
}
983
/// Asking for five rows out of three must panic.
#[test]
#[should_panic(expected = "exceed upper")]
fn slice_rows_rejects_too_large_actual() {
    let three_rows = [1.0, 2.0, 3.0];
    let _ = slice_rows(&three_rows, 1, 5);
}
989
/// A 10-row input runs on the 15-row bucket graph and comes back as 10 rows
/// of ReLU output.
#[test]
fn run_padded_pads_input_and_slices_output() {
    let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
    let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0];

    let (upper, outs) = cache
        .run_padded(
            10,
            10,
            |max| tiny_graph(max as usize),
            &[("x", input.as_slice(), 1)],
            &[1],
        )
        .expect("key 10 in [1..16)");

    assert_eq!(upper, 15);
    assert_eq!(outs.len(), 1);
    let out = &outs[0];
    assert_eq!(out.len(), 10, "output sliced back to actual_rows");
    let expected: Vec<f32> = input.iter().map(|x| x.max(0.0)).collect();
    assert_eq!(out, &expected);
}
1016
/// Two runs with different row counts in the same bucket share one compile.
#[test]
fn run_padded_reuses_bucket_across_actuals() {
    let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
    let builds = Cell::new(0);

    let ten_rows: &[f32] = &[1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0];
    let (u1, o1) = cache
        .run_padded(
            10,
            10,
            |max| {
                builds.set(builds.get() + 1);
                tiny_graph(max as usize)
            },
            &[("x", ten_rows, 1)],
            &[1],
        )
        .unwrap();
    assert_eq!(o1.len(), 1);
    assert_eq!(o1[0].len(), 10);
    assert_eq!(u1, 15);

    let five_rows: &[f32] = &[-1.0, 2.0, -3.0, 4.0, -5.0];
    let (u2, o2) = cache
        .run_padded(
            5,
            5,
            |max| {
                builds.set(builds.get() + 1);
                tiny_graph(max as usize)
            },
            &[("x", five_rows, 1)],
            &[1],
        )
        .unwrap();
    assert_eq!(o2.len(), 1);
    assert_eq!(o2[0].len(), 5);
    assert_eq!(u2, 15);
    assert_eq!(o2[0], vec![0.0, 2.0, 0.0, 4.0, 0.0]);

    assert_eq!(builds.get(), 1, "bucket cached across actuals");
    assert_eq!(cache.compiled_count(), 1);
}
1063
/// A key outside every bucket yields `None` and never builds a graph.
#[test]
fn run_padded_returns_none_out_of_range() {
    let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
    let builds = Cell::new(0);
    let rows: &[f32] = &[1.0, 2.0, 3.0, 4.0, 5.0];

    let result = cache.run_padded(
        100,
        5,
        |upper| {
            builds.set(builds.get() + 1);
            tiny_graph(upper as usize)
        },
        &[("x", rows, 1)],
        &[1],
    );

    assert!(result.is_none());
    assert_eq!(builds.get(), 0);
    assert_eq!(cache.compiled_count(), 0);
}
1082
/// `min = 8, max = 64` produces four buckets ending at 8, 16, 32 and 64.
#[test]
fn power_of_two_ladder_generates_log_buckets() {
    let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
    let ranges = cache.buckets().cloned().collect::<Vec<_>>();
    assert_eq!(ranges, vec![1..9, 9..17, 17..33, 33..65]);
    assert_eq!(cache.total_buckets(), 4);
}
1093
/// Each key resolves to the smallest ladder extent that can hold it, and the
/// builder sees exactly that extent.
#[test]
fn power_of_two_ladder_picks_smallest_extent_for_actual() {
    let mut cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
    let captured_uppers: std::cell::RefCell<Vec<u64>> = Default::default();

    // Resolves `key` and reports the upper bound returned by the cache.
    let mut upper_for = |key: u64| {
        let (upper, _) = cache
            .get_or_compile(key, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        upper
    };
    let u17 = upper_for(17);
    let u9 = upper_for(9);
    let u3 = upper_for(3);
    let u64_ = upper_for(64);

    assert_eq!(u17, 32, "key=17 → smallest extent ≥ 17 is 32");
    assert_eq!(u9, 16, "key=9 → smallest extent ≥ 9 is 16");
    assert_eq!(u3, 8, "key=3 → smallest extent ≥ 3 is 8");
    assert_eq!(u64_, 64, "key=64 → exact match at 64");
    assert_eq!(*captured_uppers.borrow(), vec![32, 16, 8, 64]);
    assert_eq!(cache.compiled_count(), 4);
}
1133
/// The first bucket starts at key 1 even when `min` is larger.
#[test]
fn power_of_two_ladder_min_above_one_starts_at_one() {
    let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 32);
    let ranges = cache.buckets().cloned().collect::<Vec<_>>();
    assert_eq!(ranges, vec![1..17, 17..33]);
}
1143
/// A non-power-of-two `min` (10) is rounded up to 16 for the first bucket.
#[test]
fn power_of_two_ladder_non_pow2_min_rounds_up() {
    let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 10, 64);
    let ranges = cache.buckets().cloned().collect::<Vec<_>>();
    assert_eq!(ranges, vec![1..17, 17..33, 33..65]);
}
1151
/// A `max` of 20 is covered by extending the last bucket up to 32.
#[test]
fn power_of_two_ladder_max_below_pow2_extends_up() {
    let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 20);
    let ranges = cache.buckets().cloned().collect::<Vec<_>>();
    assert_eq!(ranges, vec![1..9, 9..17, 17..33]);
}
1159
/// `min == max` collapses the ladder to a single bucket.
#[test]
fn power_of_two_ladder_min_equals_max() {
    let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 16);
    let ranges = cache.buckets().cloned().collect::<Vec<_>>();
    assert_eq!(ranges, vec![1..17]);
}
1166
/// `min = 0` is refused.
#[test]
#[should_panic(expected = "min must be ≥ 1")]
fn power_of_two_ladder_zero_min_rejected() {
    let (min, max) = (0, 16);
    let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, min, max);
}
1172
/// `max < min` is refused.
#[test]
#[should_panic(expected = "max")]
fn power_of_two_ladder_max_below_min_rejected() {
    let (min, max) = (32, 8);
    let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, min, max);
}
1178
/// With an active extent of 5 out of 15, only the first 5 outputs should be
/// recomputed; the tail is expected to keep the warm-up values. Currently
/// ignored (see the `#[ignore]` reason).
#[test]
#[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
fn active_extent_skips_compute_on_cpu_activation() {
    let mut compiled = Session::new(Device::Cpu).compile(tiny_graph(15));

    // Warm-up run leaves 1.0 in every output slot.
    let ones: Vec<f32> = vec![1.0; 15];
    let warm_outs = compiled.run(&[("x", &ones)]);
    assert_eq!(warm_outs[0], vec![1.0; 15], "warm-up sanity");

    // Partial run on all-negative input.
    let negatives: Vec<f32> = vec![-1.0; 15];
    compiled.set_active_extent(Some((5, 15)));
    let partial = compiled.run(&[("x", &negatives)]);
    let out = &partial[0];

    assert_eq!(out.len(), 15);
    assert_eq!(
        out[..5],
        [0.0; 5],
        "first 5 elements processed (relu of -1)"
    );
    assert_eq!(
        out[5..],
        [1.0; 10],
        "tail untouched — proves Copy + Activation skipped indices 5..15"
    );

    // Clearing the extent restores full-width execution.
    compiled.set_active_extent(None);
    let full = compiled.run(&[("x", &negatives)]);
    assert_eq!(
        full[0],
        vec![0.0; 15],
        "full-extent path must clip every negative"
    );
}
1241
/// Same idea as the activation test, for an elementwise add: with an active
/// extent of 2 out of 4, the last two outputs should keep the warm-up sum.
/// Currently ignored (see the `#[ignore]` reason).
#[test]
#[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
fn active_extent_skips_compute_on_binary_full() {
    let dtype = DType::F32;
    let mut graph = Graph::new("add");
    let lhs = graph.input("a", Shape::new(&[4], dtype));
    let rhs = graph.input("b", Shape::new(&[4], dtype));
    let sum = graph.add(lhs, rhs);
    graph.set_outputs(vec![sum]);
    let mut compiled = Session::new(Device::Cpu).compile(graph);

    // Warm-up: 1 + 1 everywhere.
    let warm = compiled.run(&[("a", &[1.0f32; 4]), ("b", &[1.0f32; 4])]);
    assert_eq!(warm[0], vec![2.0; 4]);

    // Partial run: only the first two lanes should see 10 + 10.
    compiled.set_active_extent(Some((2, 4)));
    let partial = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
    let out = &partial[0];
    assert_eq!(out[..2], [20.0, 20.0], "first 2 = active sum");
    assert_eq!(
        out[2..],
        [2.0, 2.0],
        "tail untouched — proves BinaryFull skipped indices 2..4"
    );

    // Full-width run after clearing the extent.
    compiled.set_active_extent(None);
    let full = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
    assert_eq!(full[0], vec![20.0; 4]);
}
1277
/// Runs a tiny graph with `RLX_TRACE_PERFETTO` pointing at a temp file and
/// checks that the flushed trace names at least one thunk kind.
#[test]
#[ignore = "process-wide STATE; runs only in isolation via `cargo test perfetto -- --ignored`"]
fn perfetto_trace_emits_per_thunk_events() {
    use std::env;
    use std::fs;

    let trace_path =
        env::temp_dir().join(format!("rlx-perfetto-e2e-{}.json", std::process::id()));
    // Best-effort removal of a stale trace; a missing file is fine.
    let _ = fs::remove_file(&trace_path);
    // SAFETY: NOTE(review) — relies on this test running in isolation (see
    // the `#[ignore]` reason) so that no other thread reads or writes the
    // process environment concurrently; confirm.
    unsafe {
        env::set_var("RLX_TRACE_PERFETTO", &trace_path);
    }

    let dtype = DType::F32;
    let mut graph = Graph::new("perf");
    let lhs = graph.input("a", Shape::new(&[4], dtype));
    let rhs = graph.input("b", Shape::new(&[4], dtype));
    let sum = graph.add(lhs, rhs);
    let activated = graph.relu(sum);
    graph.set_outputs(vec![activated]);
    let mut compiled = Session::new(Device::Cpu).compile(graph);
    let _ = compiled.run(&[("a", &[1.0; 4]), ("b", &[1.0; 4])]);

    crate::perfetto::flush_and_finalize();

    let contents = fs::read_to_string(&trace_path).expect("trace file");
    let names_a_thunk = ["\"binary\"", "\"activation\"", "\"elementwise_region\""]
        .iter()
        .any(|needle| contents.contains(needle));
    assert!(
        names_a_thunk,
        "expected at least one thunk-name event in perfetto trace; got: {contents}"
    );
    assert!(contents.trim_start().starts_with('['));
    let _ = fs::remove_file(&trace_path);
}
1323
/// Compiles `relu((a + b) * c)` on CPU and compares every output element
/// with a scalar reference computed in the test.
///
/// NOTE(review): despite the name, nothing here forces a fused or an unfused
/// schedule; the comparison is against the scalar reference only.
#[test]
fn elementwise_region_fused_matches_unfused() {
    let f = DType::F32;
    let mut g = Graph::new("ew_e2e");
    let a = g.input("a", Shape::new(&[8], f));
    let b = g.input("b", Shape::new(&[8], f));
    let c = g.input("c", Shape::new(&[8], f));
    let add = g.add(a, b);
    let mul = g.mul(add, c);
    let relu = g.relu(mul);
    g.set_outputs(vec![relu]);

    let mut compiled = Session::new(Device::Cpu).compile(g);
    let av: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0, 0.5, -0.5, 1.5, -1.5];
    let bv: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0, 0.5, 0.5, 0.5, 0.5];
    let cv: Vec<f32> = vec![1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 0.5, 4.0];
    let outs = compiled.run(&[("a", &av), ("b", &bv), ("c", &cv)]);
    let out = &outs[0];

    let expected: Vec<f32> = (0..8)
        .map(|i| {
            let v = (av[i] + bv[i]) * cv[i];
            v.max(0.0)
        })
        .collect();
    // `zip` stops at the shorter side, so a truncated output would otherwise
    // pass the element-wise loop below without being noticed.
    assert_eq!(out.len(), expected.len(), "output length mismatch");
    for (i, (got, exp)) in out.iter().zip(&expected).enumerate() {
        assert!(
            (got - exp).abs() < 1e-6,
            "mismatch at {i}: got {got}, expected {exp}"
        );
    }
}
1362
#[test]
#[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
fn active_extent_skips_compute_on_attention() {
    use rlx_ir::op::MaskKind;

    // Each tensor has shape [1, 4, 8], i.e. 32 elements; SPLIT marks the halfway
    // point, which the assertion messages below describe as "the first 2 positions".
    const ELEMS: usize = 32;
    const SPLIT: usize = 16;

    let dtype = DType::F32;
    let qkv_shape = || Shape::new(&[1, 4, 8], dtype);

    let mut graph = Graph::new("attn");
    let q = graph.input("q", qkv_shape());
    let k = graph.input("k", qkv_shape());
    let v = graph.input("v", qkv_shape());
    let attn = graph.attention_kind(q, k, v, 2, 4, MaskKind::None, qkv_shape());
    graph.set_outputs(vec![attn]);
    let mut compiled = Session::new(Device::Cpu).compile(graph);

    // First run with all-ones inputs; its output is kept as the baseline.
    let baseline = {
        let warm = compiled.run(&[
            ("q", &[1.0f32; ELEMS]),
            ("k", &[1.0f32; ELEMS]),
            ("v", &[1.0f32; ELEMS]),
        ]);
        warm[0].clone()
    };
    assert_eq!(baseline.len(), ELEMS);

    // Second run with different inputs after setting the active extent.
    // NOTE(review): the tuple is presumably (active, upper) — confirm against
    // `set_active_extent`.
    compiled.set_active_extent(Some((2, 4)));
    let outs = compiled.run(&[
        ("q", &[3.0f32; ELEMS]),
        ("k", &[3.0f32; ELEMS]),
        ("v", &[3.0f32; ELEMS]),
    ]);
    let restricted = &outs[0];
    assert_eq!(restricted.len(), ELEMS);

    // The second half must still hold the baseline values...
    assert_eq!(
        &restricted[SPLIT..],
        &baseline[SPLIT..],
        "tail (positions 2,3) must be untouched — proves Attention skipped"
    );
    // ...while the first half must have changed with the new inputs.
    assert_ne!(
        &restricted[..SPLIT],
        &baseline[..SPLIT],
        "first 2 positions should reflect new input"
    );
}
1410
#[test]
fn active_extent_falls_back_when_unsupported_thunk_in_schedule() {
    // NOTE(review): this test body contains no statements, so the test passes
    // unconditionally and does not exercise the fallback its name describes.
    // Either add assertions for that behaviour or mark the test `#[ignore]` with a reason.
}
1428
#[test]
fn run_padded_uses_active_extent_on_cpu() {
    // Ten values are declared but only the first five are handed to `run_padded`;
    // the trailing negatives are never passed in.
    let input = vec![
        1.0f32, -1.0, 2.0, -2.0, 3.0, -10.0, -20.0, -30.0, -40.0, -50.0,
    ];
    let active = &input[..5];

    // A single bucket 1..16; the assertion below expects `upper` to come back as 15.
    let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
    let result = cache.run_padded(
        5,
        5,
        |max| tiny_graph(max as usize),
        &[("x", active, 1)],
        &[1],
    );
    let (upper, outs) = result.unwrap();

    assert_eq!(upper, 15);
    // The output is cut down to the five requested rows; negatives come back as 0.0.
    assert_eq!(outs[0].len(), 5);
    assert_eq!(outs[0], vec![1.0, 0.0, 2.0, 0.0, 3.0]);
}
1461
#[test]
fn run_padded_inner_zero_returns_output_unsliced() {
    let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0];

    // Same single bucket 1..16 as the sliced variant, but the last argument is
    // `&[0]`, and the assertions below expect the output at the full upper extent.
    let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
    let result = cache.run_padded(
        5,
        5,
        |max| tiny_graph(max as usize),
        &[("x", &input, 1)],
        &[0],
    );
    let (upper, outs) = result.unwrap();

    assert_eq!(upper, 15);
    assert_eq!(
        outs[0].len(),
        15,
        "unsliced output preserves full upper extent"
    );

    // First five entries carry the real results; everything after them is zero.
    let (head, tail) = outs[0].split_at(5);
    assert_eq!(head, &[1.0, 0.0, 2.0, 0.0, 3.0]);
    assert!(tail.iter().all(|&v| v == 0.0));
}
1489
#[test]
fn dynamic_dim_cache_specializes_per_key() {
    use rlx_ir::hir::HirModule;
    use rlx_ir::{DType, Shape, sym};

    let options = crate::CompileOptions::new();
    let mut cache = DynamicDimCompileCache::new(Device::Cpu, 4);

    // Builder for the symbolic module: a single linear layer over a
    // (BATCH, SEQ, 4) input producing a (BATCH, SEQ, 2) output.
    let build_template = || {
        let mut module = HirModule::new("dyn_cache");
        let input = module.input_batch_seq("x", sym::BATCH, sym::SEQ, 4, DType::F32);
        let weight = module.param("w", Shape::new(&[4, 2], DType::F32));
        let out_shape = Shape::batch_seq(sym::BATCH, sym::SEQ, 2, DType::F32);
        let projected = module.linear(input, weight, None, None, out_shape);
        module.set_outputs(vec![projected]);
        module
    };

    // First request (key 8): the builder is supplied and one entry is created.
    let short_binding = rlx_ir::DimBinding::batch_seq(1, 8);
    {
        let _short = cache
            .get_or_specialize(8, &short_binding, build_template, &options)
            .expect("specialize short");
    }
    assert!(cache.has_template());
    assert_eq!(cache.len(), 1);

    // Second request (key 128): the builder panics if invoked, so success here
    // shows the cache did not call it again.
    let long_binding = rlx_ir::DimBinding::batch_seq(1, 128);
    cache
        .get_or_specialize(
            128,
            &long_binding,
            || panic!("HIR builder must not run twice"),
            &options,
        )
        .expect("specialize long");
    assert_eq!(cache.len(), 2);
}
1534}