1use crate::{CompiledGraph, Device, Session};
40use rlx_ir::DimBinding;
41use rlx_ir::Graph;
42use rlx_ir::hir::HirModule;
43use rlx_opt::CompileResult;
44use std::collections::HashMap;
45use std::collections::VecDeque;
46use std::ops::Range;
47
/// One named input handed to [`BucketedCompileCache::run_padded_mixed`].
///
/// Unlike the tuple form accepted by [`BucketedCompileCache::run_padded`],
/// each input decides individually whether it is row-padded.
pub struct CacheRunInput<'a> {
    /// Input name; forwarded together with the data to `CompiledGraph::run`.
    pub name: &'a str,
    /// Flat `f32` payload for this input.
    pub data: &'a [f32],
    /// `Some(inner)`: the payload is zero-padded via `pad_rows(data, inner, upper)`
    /// up to the bucket's upper bound. `None`: the payload is copied through
    /// unchanged.
    pub row_inner: Option<usize>,
}
55
/// Bounded cache of compiled graphs keyed by a caller-chosen `u64`.
///
/// When the cache is full, the entry that was inserted first is evicted;
/// cache hits do not change the eviction order.
pub struct CompileCache {
    /// Device passed to every `Session` created for a cache miss.
    device: Device,
    /// Maximum number of entries kept; asserted to be at least 1 on construction.
    capacity: usize,
    /// Optional precision policy applied to each new `Session`.
    policy: Option<rlx_opt::PrecisionPolicy>,
    /// Cached `(key, compiled graph)` pairs, searched linearly.
    entries: Vec<(u64, CompiledGraph)>,
    /// Keys in insertion order; the front is the next eviction candidate.
    order: VecDeque<u64>,
}
69
70impl CompileCache {
71 pub fn new(device: Device, capacity: usize) -> Self {
72 Self::with_policy(device, capacity, None)
73 }
74
75 pub fn with_policy(
79 device: Device,
80 capacity: usize,
81 policy: Option<rlx_opt::PrecisionPolicy>,
82 ) -> Self {
83 assert!(capacity > 0, "CompileCache capacity must be ≥ 1");
84 Self {
85 device,
86 capacity,
87 policy,
88 entries: Vec::with_capacity(capacity),
89 order: VecDeque::with_capacity(capacity),
90 }
91 }
92
93 pub fn get_or_compile<F: FnOnce() -> Graph>(
97 &mut self,
98 key: u64,
99 build: F,
100 ) -> &mut CompiledGraph {
101 self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
102 }
103
104 pub fn get_or_compile_with_options<F: FnOnce() -> Graph>(
106 &mut self,
107 key: u64,
108 build: F,
109 options: &crate::CompileOptions,
110 ) -> &mut CompiledGraph {
111 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
112 return &mut self.entries[idx].1;
113 }
114 let mut session = Session::new(self.device);
115 if let Some(p) = &self.policy {
116 session = session.with_policy(p.clone());
117 }
118 let compiled = session.compile_with(build(), options);
119
120 if self.entries.len() >= self.capacity
122 && let Some(evict_key) = self.order.pop_front()
123 {
124 self.entries.retain(|(k, _)| *k != evict_key);
125 }
126 self.entries.push((key, compiled));
127 self.order.push_back(key);
128 &mut self.entries.last_mut().unwrap().1
129 }
130
131 pub fn len(&self) -> usize {
133 self.entries.len()
134 }
135 pub fn is_empty(&self) -> bool {
136 self.entries.is_empty()
137 }
138 pub fn contains(&self, key: u64) -> bool {
140 self.entries.iter().any(|(k, _)| *k == key)
141 }
142}
143
/// Cache that maps ranges of keys ("buckets") to one compiled graph each.
///
/// Every key inside a bucket shares the graph compiled for that bucket's
/// largest key, so a whole range of keys costs a single compile.
pub struct BucketedCompileCache {
    /// Device passed to every `Session` created when a bucket is compiled.
    device: Device,
    /// Optional precision policy applied to each new `Session`.
    policy: Option<rlx_opt::PrecisionPolicy>,
    /// Buckets in the order supplied at construction (validated as
    /// non-empty and non-overlapping in that order).
    buckets: Vec<Bucket>,
}
195
/// One key range of a [`BucketedCompileCache`] plus its lazily compiled graph.
struct Bucket {
    /// Half-open key range `start..end`; `end - 1` is the value handed to the
    /// graph builder.
    range: Range<u64>,
    /// `None` until the first lookup that lands in this bucket compiles it.
    compiled: Option<CompiledGraph>,
}
200
201impl BucketedCompileCache {
202 pub fn new(device: Device, buckets: Vec<Range<u64>>) -> Self {
203 Self::with_policy(device, buckets, None)
204 }
205
206 pub fn power_of_two_ladder(device: Device, min: u64, max: u64) -> Self {
222 Self::power_of_two_ladder_with_policy(device, min, max, None)
223 }
224
225 pub fn power_of_two_ladder_with_policy(
226 device: Device,
227 min: u64,
228 max: u64,
229 policy: Option<rlx_opt::PrecisionPolicy>,
230 ) -> Self {
231 assert!(min >= 1, "power_of_two_ladder: min must be ≥ 1, got {min}");
232 assert!(
233 max >= min,
234 "power_of_two_ladder: max ({max}) must be ≥ min ({min})"
235 );
236 let mut buckets: Vec<Range<u64>> = Vec::new();
237 let mut start = 1u64;
238 let mut extent = min.next_power_of_two();
239 loop {
240 buckets.push(start..(extent + 1));
241 if extent >= max {
242 break;
243 }
244 start = extent + 1;
245 extent = extent
246 .checked_mul(2)
247 .expect("power_of_two_ladder: extent overflow");
248 }
249 Self::with_policy(device, buckets, policy)
250 }
251
252 pub fn with_policy(
253 device: Device,
254 buckets: Vec<Range<u64>>,
255 policy: Option<rlx_opt::PrecisionPolicy>,
256 ) -> Self {
257 assert!(!buckets.is_empty(), "BucketedCompileCache needs ≥1 bucket");
258 for (i, b) in buckets.iter().enumerate() {
259 assert!(b.start < b.end, "bucket {i} ({b:?}) is empty");
260 if i + 1 < buckets.len() {
261 assert!(
262 b.end <= buckets[i + 1].start,
263 "buckets {i} ({b:?}) and {} ({:?}) overlap",
264 i + 1,
265 buckets[i + 1],
266 );
267 }
268 }
269 let buckets = buckets
270 .into_iter()
271 .map(|range| Bucket {
272 range,
273 compiled: None,
274 })
275 .collect();
276 Self {
277 device,
278 policy,
279 buckets,
280 }
281 }
282
283 pub fn get_or_compile<F: FnOnce(u64) -> Graph>(
292 &mut self,
293 key: u64,
294 build: F,
295 ) -> Option<(u64, &mut CompiledGraph)> {
296 self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
297 }
298
299 pub fn get_or_compile_with_options<F: FnOnce(u64) -> Graph>(
301 &mut self,
302 key: u64,
303 build: F,
304 options: &crate::CompileOptions,
305 ) -> Option<(u64, &mut CompiledGraph)> {
306 let idx = self.bucket_for(key)?;
307 let upper = self.buckets[idx].range.end - 1;
308 if self.buckets[idx].compiled.is_none() {
309 let mut session = Session::new(self.device);
310 if let Some(p) = &self.policy {
311 session = session.with_policy(p.clone());
312 }
313 self.buckets[idx].compiled = Some(session.compile_with(build(upper), options));
314 }
315 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
316 }
317
318 pub fn get_or_compile_hir<F: FnOnce(u64) -> HirModule>(
321 &mut self,
322 key: u64,
323 build: F,
324 ) -> Option<(u64, &mut CompiledGraph)> {
325 self.get_or_compile_hir_with_options(key, build, &crate::CompileOptions::new())
326 }
327
328 pub fn get_or_compile_hir_with_options<F: FnOnce(u64) -> HirModule>(
330 &mut self,
331 key: u64,
332 build: F,
333 options: &crate::CompileOptions,
334 ) -> Option<(u64, &mut CompiledGraph)> {
335 let idx = self.bucket_for(key)?;
336 let upper = self.buckets[idx].range.end - 1;
337 if self.buckets[idx].compiled.is_none() {
338 let mut session = Session::new(self.device);
339 if let Some(p) = &self.policy {
340 session = session.with_policy(p.clone());
341 }
342 let compiled = session
343 .compile_hir_with(build(upper), options)
344 .expect("HIR lower/compile in bucketed cache");
345 self.buckets[idx].compiled = Some(compiled);
346 }
347 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
348 }
349
350 pub fn bucket_for(&self, key: u64) -> Option<usize> {
353 self.buckets.iter().position(|b| b.range.contains(&key))
354 }
355
356 pub fn buckets(&self) -> impl Iterator<Item = &Range<u64>> {
357 self.buckets.iter().map(|b| &b.range)
358 }
359
360 pub fn compiled_count(&self) -> usize {
362 self.buckets.iter().filter(|b| b.compiled.is_some()).count()
363 }
364
365 pub fn total_buckets(&self) -> usize {
366 self.buckets.len()
367 }
368
369 pub fn run_padded<F: FnOnce(u64) -> Graph>(
395 &mut self,
396 key: u64,
397 actual_rows: usize,
398 build: F,
399 inputs: &[(&str, &[f32], usize)],
400 output_inners: &[usize],
401 ) -> Option<(u64, Vec<Vec<f32>>)> {
402 let (upper, compiled) = self.get_or_compile(key, build)?;
403
404 let padded: Vec<(&str, Vec<f32>)> = inputs
406 .iter()
407 .map(|(name, data, inner)| (*name, pad_rows(data, *inner, upper)))
408 .collect();
409 let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
410
411 compiled.set_active_extent(Some((actual_rows, upper as usize)));
417 let raw_outputs = compiled.run(&pairs);
418 compiled.set_active_extent(None);
419
420 let outs = raw_outputs
421 .into_iter()
422 .enumerate()
423 .map(|(i, out)| match output_inners.get(i).copied() {
424 Some(0) | None => out,
425 Some(inner) => slice_rows(&out, inner, actual_rows),
426 })
427 .collect();
428
429 Some((upper, outs))
430 }
431
432 pub fn ensure_graph_with_params<F>(
434 &mut self,
435 key: u64,
436 build: F,
437 options: &crate::CompileOptions,
438 ) -> Option<(u64, &mut CompiledGraph)>
439 where
440 F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
441 {
442 let idx = self.bucket_for(key)?;
443 let upper = self.buckets[idx].range.end - 1;
444 if self.buckets[idx].compiled.is_none() {
445 let (graph, params) = build(upper);
446 let mut session = Session::new(self.device);
447 if let Some(p) = &self.policy {
448 session = session.with_policy(p.clone());
449 }
450 let mut compiled = session.compile_with(graph, options);
451 for (name, data) in params {
452 compiled.set_param(&name, &data);
453 }
454 self.buckets[idx].compiled = Some(compiled);
455 }
456 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
457 }
458
459 pub fn ensure_hir_with_params<F>(
461 &mut self,
462 key: u64,
463 build: F,
464 options: &crate::CompileOptions,
465 ) -> Option<(u64, &mut CompiledGraph)>
466 where
467 F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
468 {
469 let idx = self.bucket_for(key)?;
470 let upper = self.buckets[idx].range.end - 1;
471 if self.buckets[idx].compiled.is_none() {
472 let (hir, params) = build(upper);
473 let mut session = Session::new(self.device);
474 if let Some(p) = &self.policy {
475 session = session.with_policy(p.clone());
476 }
477 let mut compiled = session
478 .compile_hir_with(hir, options)
479 .expect("HIR lower/compile in ensure_hir_with_params");
480 for (name, data) in params {
481 compiled.set_param(&name, &data);
482 }
483 self.buckets[idx].compiled = Some(compiled);
484 }
485 Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
486 }
487
488 pub fn run_padded_mixed<F>(
490 &mut self,
491 key: u64,
492 actual_rows: usize,
493 build: F,
494 inputs: &[CacheRunInput<'_>],
495 output_inners: &[usize],
496 ) -> Option<(u64, Vec<Vec<f32>>)>
497 where
498 F: FnOnce(u64) -> Graph,
499 {
500 let (upper, compiled) = self.get_or_compile(key, build)?;
501
502 let padded: Vec<(&str, Vec<f32>)> = inputs
503 .iter()
504 .map(|inp| match inp.row_inner {
505 Some(inner) => (inp.name, pad_rows(inp.data, inner, upper)),
506 None => (inp.name, inp.data.to_vec()),
507 })
508 .collect();
509 let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
510
511 compiled.set_active_extent(Some((actual_rows, upper as usize)));
512 let raw_outputs = compiled.run(&pairs);
513 compiled.set_active_extent(None);
514
515 let outs = raw_outputs
516 .into_iter()
517 .enumerate()
518 .map(|(i, out)| match output_inners.get(i).copied() {
519 Some(0) | None => out,
520 Some(inner) => slice_rows(&out, inner, actual_rows),
521 })
522 .collect();
523
524 Some((upper, outs))
525 }
526}
527
/// Bounded cache of graphs specialized from one shared template.
///
/// The template is compiled once from HIR with `dim_binding` cleared; each
/// key then stores a `CompiledGraph` specialized for a specific `DimBinding`.
/// When full, the entry inserted first is evicted.
pub struct DynamicDimCompileCache {
    /// Device used for pipelines, backend lookup and the compiled graphs.
    device: Device,
    /// Fallback precision policy, used when the caller's options carry none.
    policy: Option<rlx_opt::PrecisionPolicy>,
    /// Maximum number of specialized entries; asserted to be at least 1.
    capacity: usize,
    /// Template compile result; `None` until first needed.
    template: Option<CompileResult>,
    /// Cached `(key, specialized graph)` pairs, searched linearly.
    entries: Vec<(u64, CompiledGraph)>,
    /// Keys in insertion order; the front is the next eviction candidate.
    order: VecDeque<u64>,
}
543
544impl DynamicDimCompileCache {
545 pub fn new(device: Device, capacity: usize) -> Self {
546 Self::with_policy(device, capacity, None)
547 }
548
549 pub fn with_policy(
550 device: Device,
551 capacity: usize,
552 policy: Option<rlx_opt::PrecisionPolicy>,
553 ) -> Self {
554 assert!(capacity > 0, "DynamicDimCompileCache capacity must be ≥ 1");
555 Self {
556 device,
557 policy,
558 capacity,
559 template: None,
560 entries: Vec::with_capacity(capacity),
561 order: VecDeque::with_capacity(capacity),
562 }
563 }
564
565 pub fn compile_device(&self) -> Device {
566 self.device
567 }
568
569 pub fn get_or_specialize<F: FnOnce() -> HirModule>(
572 &mut self,
573 key: u64,
574 binding: &DimBinding,
575 build_hir: F,
576 options: &crate::CompileOptions,
577 ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError> {
578 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
579 return Ok(&mut self.entries[idx].1);
580 }
581 if self.template.is_none() {
582 let mut template_opts = options.clone();
583 template_opts.dim_binding = None;
584 let pipe = crate::stages::pipeline_for(self.device, &template_opts);
585 self.template = Some(pipe.compile_hir(build_hir())?);
586 }
587 let template = self.template.as_ref().expect("template just set");
588 let mut spec_opts = options.clone();
589 spec_opts.dim_binding = None;
590 let pipe = crate::stages::pipeline_for(self.device, &spec_opts);
591 let specialized = template.specialize(&pipe, binding);
592 let backend = crate::registry::backend_for(self.device).expect("backend registered");
593 let mut compile_opts = options.clone();
594 compile_opts.dim_binding = None;
595 if compile_opts.policy.is_none() {
596 if let Some(p) = &self.policy {
597 compile_opts = compile_opts.policy(p.clone());
598 }
599 }
600 let executable = backend.compile_lir(specialized.lir, &compile_opts);
601 let compiled = CompiledGraph::new(executable, self.device);
602
603 if self.entries.len() >= self.capacity
604 && let Some(evict_key) = self.order.pop_front()
605 {
606 self.entries.retain(|(k, _)| *k != evict_key);
607 }
608 self.entries.push((key, compiled));
609 self.order.push_back(key);
610 Ok(&mut self.entries.last_mut().unwrap().1)
611 }
612
613 pub fn len(&self) -> usize {
614 self.entries.len()
615 }
616
617 pub fn is_empty(&self) -> bool {
618 self.entries.is_empty()
619 }
620
621 pub fn contains(&self, key: u64) -> bool {
622 self.entries.iter().any(|(k, _)| *k == key)
623 }
624
625 pub fn has_template(&self) -> bool {
626 self.template.is_some()
627 }
628
629 pub fn ensure_template<F: FnOnce() -> HirModule>(
631 &mut self,
632 build_hir: F,
633 options: &crate::CompileOptions,
634 ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
635 if self.template.is_none() {
636 let mut opts = options.clone();
637 opts.dim_binding = None;
638 let pipe = crate::stages::pipeline_for(self.device, &opts);
639 self.template = Some(pipe.compile_hir(build_hir())?);
640 }
641 Ok(self.template.as_ref().expect("template set"))
642 }
643
644 pub fn template_result(&self) -> Option<&CompileResult> {
645 self.template.as_ref()
646 }
647
648 pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
651 &mut self,
652 aot: &crate::AotCache,
653 disk_base: &str,
654 key: u64,
655 binding: &rlx_ir::DimBinding,
656 build_hir: F,
657 options: &crate::CompileOptions,
658 ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
659 if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
660 return Ok(&mut self.entries[idx].1);
661 }
662 let device = self.device;
663 let template = self.ensure_template(build_hir, options)?;
664 let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
665 if self.entries.len() >= self.capacity
666 && let Some(evict_key) = self.order.pop_front()
667 {
668 self.entries.retain(|(k, _)| *k != evict_key);
669 }
670 self.entries.push((key, compiled));
671 self.order.push_back(key);
672 Ok(&mut self.entries.last_mut().unwrap().1)
673 }
674}
675
/// Zero-pads row-major `data` so that it holds exactly `upper` rows.
///
/// `data` is interpreted as rows of `inner` elements each. The returned
/// vector has `upper * inner` elements: the original data followed by zeros.
///
/// # Panics
/// Panics if `inner` is 0, if `data.len()` is not a multiple of `inner`, if
/// `data` already holds more than `upper` rows, or if `upper` or
/// `upper * inner` does not fit in `usize`.
pub fn pad_rows(data: &[f32], inner: usize, upper: u64) -> Vec<f32> {
    assert!(inner > 0, "pad_rows: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    // A plain `as` cast would silently truncate on targets where usize < 64 bits.
    let upper = usize::try_from(upper).expect("pad_rows: upper bound does not fit in usize");
    let actual = data.len() / inner;
    assert!(
        actual <= upper,
        "pad_rows: actual rows {actual} exceed upper bound {upper}",
    );
    // Checked so a huge bound fails with a clear message instead of wrapping
    // (release) or tripping the generic arithmetic-overflow panic (debug).
    let padded_len = upper
        .checked_mul(inner)
        .expect("pad_rows: upper * inner overflows usize");
    let mut out = vec![0.0_f32; padded_len];
    out[..data.len()].copy_from_slice(data);
    out
}
701
/// Returns a copy of the first `actual` rows of row-major `data`.
///
/// `data` is interpreted as rows of `inner` elements each; the result holds
/// `actual * inner` elements.
///
/// # Panics
/// Panics if `inner` is 0, if `data.len()` is not a multiple of `inner`, or
/// if `actual` exceeds the number of rows in `data`.
pub fn slice_rows(data: &[f32], inner: usize, actual: usize) -> Vec<f32> {
    assert!(inner > 0, "slice_rows: inner stride must be ≥ 1");
    let total = data.len();
    assert_eq!(
        total % inner,
        0,
        "slice_rows: data len {} not a multiple of inner {inner}",
        total,
    );
    let upper = total / inner;
    assert!(
        actual <= upper,
        "slice_rows: actual rows {actual} exceed upper {upper}",
    );
    // `actual <= upper` guarantees `actual * inner <= total`.
    let keep = actual * inner;
    data.iter().take(keep).copied().collect()
}
722
// Unit tests for the three caches and the row padding/slicing helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::*;
    use std::cell::Cell;

    // Builds a graph with one input "x" of length `n` followed by a ReLU,
    // whose result is the only output.
    fn tiny_graph(n: usize) -> Graph {
        let mut g = Graph::new("t");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[n], f));
        let y = g.activation(rlx_ir::op::Activation::Relu, x, Shape::new(&[n], f));
        g.set_outputs(vec![y]);
        g
    }

    // Three lookups of the same key must run the builder exactly once.
    #[test]
    fn cache_hits_avoid_recompile() {
        let mut cache = CompileCache::new(Device::Cpu, 4);
        let calls = Cell::new(0);

        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        assert_eq!(calls.get(), 1);
        assert_eq!(cache.len(), 1);
    }

    // With capacity 2, inserting a third key evicts the first-inserted key.
    #[test]
    fn fifo_evicts_oldest_at_capacity() {
        let mut cache = CompileCache::new(Device::Cpu, 2);
        let _ = cache.get_or_compile(1, || tiny_graph(4));
        let _ = cache.get_or_compile(2, || tiny_graph(8));
        assert!(cache.contains(1) && cache.contains(2));
        let _ = cache.get_or_compile(3, || tiny_graph(16));
        assert!(!cache.contains(1));
        assert!(cache.contains(2) && cache.contains(3));
    }

    // Two distinct keys compile separately; revisiting the first is a hit.
    #[test]
    fn different_keys_keep_separate_compiles() {
        let mut cache = CompileCache::new(Device::Cpu, 4);
        let calls = Cell::new(0);
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(2, || {
            calls.set(calls.get() + 1);
            tiny_graph(16)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        assert_eq!(calls.get(), 2);
        assert_eq!(cache.len(), 2);
    }

    // Keys 2 and 3 share bucket 1..4: one compile, both report upper bound 3.
    #[test]
    fn bucket_amortizes_keys_within_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
        let calls = Cell::new(0);
        let uppers = Cell::new((0u64, 0u64));

        let (u1, _) = cache
            .get_or_compile(2, |upper| {
                calls.set(calls.get() + 1);
                uppers.set((upper, uppers.get().1));
                tiny_graph(upper as usize)
            })
            .expect("key 2 in range");
        let (u2, _) = cache
            .get_or_compile(3, |upper| {
                calls.set(calls.get() + 1);
                uppers.set((uppers.get().0, upper));
                tiny_graph(upper as usize)
            })
            .expect("key 3 in range");

        assert_eq!(calls.get(), 1);
        assert_eq!(u1, 3);
        assert_eq!(u2, 3);
        assert_eq!(uppers.get().0, 3);
        assert_eq!(cache.compiled_count(), 1);
        assert_eq!(cache.total_buckets(), 2);
    }

    // Keys outside every range map to no bucket and never run the builder.
    #[test]
    fn bucket_lookup_returns_none_outside_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
        assert!(cache.bucket_for(0).is_none());
        assert!(cache.bucket_for(16).is_none());
        assert!(cache.bucket_for(100).is_none());
        assert_eq!(cache.bucket_for(3), Some(0));
        assert_eq!(cache.bucket_for(4), Some(1));

        let calls = Cell::new(0);
        let result = cache.get_or_compile(100, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        assert!(result.is_none());
        assert_eq!(calls.get(), 0); assert_eq!(cache.compiled_count(), 0);
    }

    // Only buckets that were actually looked up get compiled.
    #[test]
    fn bucket_compiles_lazily_per_bucket() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16, 16..64]);
        let calls = Cell::new(0);

        let _ = cache.get_or_compile(2, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        let _ = cache.get_or_compile(8, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        assert_eq!(calls.get(), 2);
        assert_eq!(cache.compiled_count(), 2);
        assert_eq!(cache.total_buckets(), 3);
    }

    // Overlapping ranges are rejected at construction.
    #[test]
    #[should_panic(expected = "overlap")]
    fn bucket_overlap_rejected() {
        let _ = BucketedCompileCache::new(Device::Cpu, vec![1..8, 4..16]);
    }

    // An empty bucket list is rejected at construction.
    #[test]
    #[should_panic(expected = "≥1 bucket")]
    fn empty_bucket_list_rejected() {
        let _ = BucketedCompileCache::new(Device::Cpu, vec![]);
    }

    // pad_rows: stride 1, stride 3, and the no-padding-needed case.
    #[test]
    fn pad_rows_appends_zeros() {
        let p = pad_rows(&[1.0, 2.0, 3.0], 1, 5);
        assert_eq!(p, vec![1.0, 2.0, 3.0, 0.0, 0.0]);

        let p = pad_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 4);
        assert_eq!(
            p,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        );

        let p = pad_rows(&[7.0, 8.0], 1, 2);
        assert_eq!(p, vec![7.0, 8.0]);
    }

    // slice_rows keeps the leading rows for stride 1 and stride 3.
    #[test]
    fn slice_rows_truncates_trailing() {
        let s = slice_rows(&[1.0, 2.0, 3.0, 0.0, 0.0], 1, 3);
        assert_eq!(s, vec![1.0, 2.0, 3.0]);

        let s = slice_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0], 3, 2);
        assert_eq!(s, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // Four rows cannot be padded "up" to three.
    #[test]
    #[should_panic(expected = "exceed upper")]
    fn pad_rows_rejects_too_long_input() {
        let _ = pad_rows(&[1.0, 2.0, 3.0, 4.0], 1, 3);
    }

    // Asking for more rows than the data holds panics.
    #[test]
    #[should_panic(expected = "exceed upper")]
    fn slice_rows_rejects_too_large_actual() {
        let _ = slice_rows(&[1.0, 2.0, 3.0], 1, 5);
    }

    // 10 rows in bucket 1..16: padded to 15 for the run, sliced back to 10.
    #[test]
    fn run_padded_pads_input_and_slices_output() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0];

        let (upper, outs) = cache
            .run_padded(
                10, 10, |max| tiny_graph(max as usize),
                &[("x", &input, 1)], &[1], )
            .expect("key 10 in [1..16)");

        assert_eq!(upper, 15);
        assert_eq!(outs.len(), 1);
        let out = &outs[0];
        assert_eq!(out.len(), 10, "output sliced back to actual_rows");
        let expected: Vec<f32> = input.iter().map(|x| x.max(0.0)).collect();
        assert_eq!(out, &expected);
    }

    // Two runs with different row counts in one bucket share a single compile.
    #[test]
    fn run_padded_reuses_bucket_across_actuals() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let calls = Cell::new(0);

        let (u1, o1) = cache
            .run_padded(
                10,
                10,
                |max| {
                    calls.set(calls.get() + 1);
                    tiny_graph(max as usize)
                },
                &[(
                    "x",
                    &[1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0],
                    1,
                )],
                &[1],
            )
            .unwrap();
        assert_eq!(o1.len(), 1);
        assert_eq!(o1[0].len(), 10);
        assert_eq!(u1, 15);

        let (u2, o2) = cache
            .run_padded(
                5,
                5,
                |max| {
                    calls.set(calls.get() + 1);
                    tiny_graph(max as usize)
                },
                &[("x", &[-1.0, 2.0, -3.0, 4.0, -5.0], 1)],
                &[1],
            )
            .unwrap();
        assert_eq!(o2.len(), 1);
        assert_eq!(o2[0].len(), 5);
        assert_eq!(u2, 15);
        assert_eq!(o2[0], vec![0.0, 2.0, 0.0, 4.0, 0.0]);

        assert_eq!(calls.get(), 1, "bucket cached across actuals");
        assert_eq!(cache.compiled_count(), 1);
    }

    // An out-of-range key yields None and neither builds nor compiles.
    #[test]
    fn run_padded_returns_none_out_of_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let calls = Cell::new(0);
        let result = cache.run_padded(
            100,
            5,
            |u| {
                calls.set(calls.get() + 1);
                tiny_graph(u as usize)
            },
            &[("x", &[1.0, 2.0, 3.0, 4.0, 5.0], 1)],
            &[1],
        );
        assert!(result.is_none());
        assert_eq!(calls.get(), 0);
        assert_eq!(cache.compiled_count(), 0);
    }

    // min=8, max=64 produces buckets with upper bounds 8, 16, 32, 64.
    #[test]
    fn power_of_two_ladder_generates_log_buckets() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..9, 9..17, 17..33, 33..65]);
        assert_eq!(cache.total_buckets(), 4);
    }

    // Each key resolves to the smallest ladder extent that can hold it.
    #[test]
    fn power_of_two_ladder_picks_smallest_extent_for_actual() {
        let mut cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
        let captured_uppers: std::cell::RefCell<Vec<u64>> = Default::default();

        let (u17, _) = cache
            .get_or_compile(17, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u9, _) = cache
            .get_or_compile(9, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u3, _) = cache
            .get_or_compile(3, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u64_, _) = cache
            .get_or_compile(64, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();

        assert_eq!(u17, 32, "key=17 → smallest extent ≥ 17 is 32");
        assert_eq!(u9, 16, "key=9 → smallest extent ≥ 9 is 16");
        assert_eq!(u3, 8, "key=3 → smallest extent ≥ 3 is 8");
        assert_eq!(u64_, 64, "key=64 → exact match at 64");
        assert_eq!(*captured_uppers.borrow(), vec![32, 16, 8, 64]);
        assert_eq!(cache.compiled_count(), 4);
    }

    // The first bucket always starts at key 1, even when min is larger.
    #[test]
    fn power_of_two_ladder_min_above_one_starts_at_one() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 32);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17, 17..33]);
    }

    // A non-power-of-two min (10) is rounded up to 16.
    #[test]
    fn power_of_two_ladder_non_pow2_min_rounds_up() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 10, 64);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17, 17..33, 33..65]);
    }

    // A max of 20 is covered by extending the ladder to 32.
    #[test]
    fn power_of_two_ladder_max_below_pow2_extends_up() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 20);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..9, 9..17, 17..33]);
    }

    // min == max yields a single bucket.
    #[test]
    fn power_of_two_ladder_min_equals_max() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 16);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17]);
    }

    // min = 0 is rejected.
    #[test]
    #[should_panic(expected = "min must be ≥ 1")]
    fn power_of_two_ladder_zero_min_rejected() {
        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 0, 16);
    }

    // max < min is rejected.
    #[test]
    #[should_panic(expected = "max")]
    fn power_of_two_ladder_max_below_min_rejected() {
        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 32, 8);
    }

    // Ignored (see attribute). Warm up with all-ones so the output holds 1.0,
    // then run negatives with active extent (5, 15): the assertions expect only
    // the first 5 outputs to be recomputed and the tail to keep the warm-up
    // values.
    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_cpu_activation() {
        let graph = tiny_graph(15);
        let mut compiled = Session::new(Device::Cpu).compile(graph);

        let warm_input: Vec<f32> = vec![1.0; 15];
        let warm_outs = compiled.run(&[("x", &warm_input)]);
        assert_eq!(warm_outs[0], vec![1.0; 15], "warm-up sanity");

        let neg_input: Vec<f32> = vec![-1.0; 15];
        compiled.set_active_extent(Some((5, 15)));
        let outs = compiled.run(&[("x", &neg_input)]);
        let out = &outs[0];

        assert_eq!(out.len(), 15);
        assert_eq!(
            out[..5],
            [0.0; 5],
            "first 5 elements processed (relu of -1)"
        );
        assert_eq!(
            out[5..],
            [1.0; 10],
            "tail untouched — proves Copy + Activation skipped indices 5..15"
        );

        // With the extent cleared, every element must be recomputed.
        compiled.set_active_extent(None);
        let outs = compiled.run(&[("x", &neg_input)]);
        assert_eq!(
            outs[0],
            vec![0.0; 15],
            "full-extent path must clip every negative"
        );
    }

    // Ignored (see attribute). Same warm-up/active-extent pattern as above,
    // applied to an element-wise add of two length-4 inputs.
    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_binary_full() {
        let mut g = Graph::new("add");
        let f = DType::F32;
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let c = g.add(a, b);
        g.set_outputs(vec![c]);
        let mut compiled = Session::new(Device::Cpu).compile(g);

        let warm = compiled.run(&[("a", &[1.0f32; 4]), ("b", &[1.0f32; 4])]);
        assert_eq!(warm[0], vec![2.0; 4]);

        compiled.set_active_extent(Some((2, 4)));
        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
        let out = &outs[0];
        assert_eq!(out[..2], [20.0, 20.0], "first 2 = active sum");
        assert_eq!(
            out[2..],
            [2.0, 2.0],
            "tail untouched — proves BinaryFull skipped indices 2..4"
        );

        compiled.set_active_extent(None);
        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
        assert_eq!(outs[0], vec![20.0; 4]);
    }

    // Ignored (see attribute). Points RLX_TRACE_PERFETTO at a temp file, runs
    // a small add+relu graph, flushes the tracer and checks the file contents.
    #[test]
    #[ignore = "process-wide STATE; runs only in isolation via `cargo test perfetto -- --ignored`"]
    fn perfetto_trace_emits_per_thunk_events() {
        use std::env;
        use std::fs;
        let path = env::temp_dir().join(format!("rlx-perfetto-e2e-{}.json", std::process::id()));
        if path.exists() {
            let _ = fs::remove_file(&path);
        }
        // SAFETY (review): mutating the process environment is only sound
        // while no other thread reads or writes it. The ignore reason above
        // says this test is run in isolation — confirm that holds.
        unsafe {
            env::set_var("RLX_TRACE_PERFETTO", &path);
        }

        let f = DType::F32;
        let mut g = Graph::new("perf");
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let s = g.add(a, b);
        let r = g.relu(s);
        g.set_outputs(vec![r]);
        let mut compiled = Session::new(Device::Cpu).compile(g);
        let _ = compiled.run(&[("a", &[1.0; 4]), ("b", &[1.0; 4])]);

        crate::perfetto::flush_and_finalize();

        let contents = fs::read_to_string(&path).expect("trace file");
        assert!(
            contents.contains("\"binary\"")
                || contents.contains("\"activation\"")
                || contents.contains("\"elementwise_region\""),
            "expected at least one thunk-name event in perfetto trace; got: {contents}"
        );
        // The trace is expected to begin with an opening '['.
        assert!(contents.trim_start().starts_with('['));
        let _ = fs::remove_file(&path);
    }

    // Compiles relu((a + b) * c) and compares against a scalar reference.
    #[test]
    fn elementwise_region_fused_matches_unfused() {
        let f = DType::F32;
        let mut g = Graph::new("ew_e2e");
        let a = g.input("a", Shape::new(&[8], f));
        let b = g.input("b", Shape::new(&[8], f));
        let c = g.input("c", Shape::new(&[8], f));
        // NOTE(review): `s` is built and then only discarded below; it does
        // not feed into the graph.
        let s = Shape::new(&[8], f);
        let add = g.add(a, b);
        let mul = g.mul(add, c);
        let relu = g.relu(mul);
        let _ = s;
        g.set_outputs(vec![relu]);

        let mut compiled = Session::new(Device::Cpu).compile(g);
        let av: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0, 0.5, -0.5, 1.5, -1.5];
        let bv: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0, 0.5, 0.5, 0.5, 0.5];
        let cv: Vec<f32> = vec![1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 0.5, 4.0];
        let outs = compiled.run(&[("a", &av), ("b", &bv), ("c", &cv)]);
        let out = &outs[0];

        let expected: Vec<f32> = (0..8)
            .map(|i| {
                let v = (av[i] + bv[i]) * cv[i];
                v.max(0.0)
            })
            .collect();
        for (i, (got, exp)) in out.iter().zip(&expected).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "mismatch at {i}: got {got}, expected {exp}"
            );
        }
    }

    // Ignored (see attribute). Warm up attention on [1, 4, 8] inputs, then
    // rerun with active extent (2, 4): the second half of the output must
    // equal the warm-up output and the first half must differ.
    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_attention() {
        use rlx_ir::op::MaskKind;
        let f = DType::F32;
        let mut g = Graph::new("attn");
        let q = g.input("q", Shape::new(&[1, 4, 8], f));
        let k = g.input("k", Shape::new(&[1, 4, 8], f));
        let v = g.input("v", Shape::new(&[1, 4, 8], f));
        let out = g.attention_kind(q, k, v, 2, 4, MaskKind::None, Shape::new(&[1, 4, 8], f));
        g.set_outputs(vec![out]);
        let mut compiled = Session::new(Device::Cpu).compile(g);

        let warm = compiled.run(&[
            ("q", &[1.0f32; 32]),
            ("k", &[1.0f32; 32]),
            ("v", &[1.0f32; 32]),
        ]);
        let warm_out = warm[0].clone();
        assert_eq!(warm_out.len(), 32);

        compiled.set_active_extent(Some((2, 4)));
        let outs = compiled.run(&[
            ("q", &[3.0f32; 32]),
            ("k", &[3.0f32; 32]),
            ("v", &[3.0f32; 32]),
        ]);
        let out = &outs[0];
        assert_eq!(out.len(), 32);
        assert_eq!(
            &out[16..],
            &warm_out[16..],
            "tail (positions 2,3) must be untouched — proves Attention skipped"
        );
        assert_ne!(
            &out[..16],
            &warm_out[..16],
            "first 2 positions should reflect new input"
        );
    }

    // NOTE(review): this body is empty, so the test passes vacuously and the
    // scenario in its name is not exercised here.
    #[test]
    fn active_extent_falls_back_when_unsupported_thunk_in_schedule() {
    }

    // Only the first 5 values of `input` are passed in; the result must be
    // their ReLU, sliced back to 5 rows.
    #[test]
    fn run_padded_uses_active_extent_on_cpu() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![
            1.0, -1.0, 2.0, -2.0, 3.0, -10.0, -20.0, -30.0, -40.0, -50.0, ];
        let (upper, outs) = cache
            .run_padded(
                5,
                5,
                |max| tiny_graph(max as usize),
                &[("x", &input[..5], 1)],
                &[1],
            )
            .unwrap();
        assert_eq!(upper, 15);
        assert_eq!(outs[0].len(), 5);
        assert_eq!(outs[0], vec![1.0, 0.0, 2.0, 0.0, 3.0]);
    }

    // An output stride of 0 disables slicing: all 15 rows come back, with the
    // padded tail equal to zero.
    #[test]
    fn run_padded_inner_zero_returns_output_unsliced() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0];

        let (upper, outs) = cache
            .run_padded(
                5,
                5,
                |max| tiny_graph(max as usize),
                &[("x", &input, 1)],
                &[0], )
            .unwrap();

        assert_eq!(upper, 15);
        assert_eq!(
            outs[0].len(),
            15,
            "unsliced output preserves full upper extent"
        );
        assert_eq!(&outs[0][..5], &[1.0, 0.0, 2.0, 0.0, 3.0]);
        assert!(outs[0][5..].iter().all(|&v| v == 0.0));
    }

    // The first specialization builds the template from HIR; the second key
    // must reuse it (its builder panics if called) and add a second entry.
    #[test]
    fn dynamic_dim_cache_specializes_per_key() {
        use rlx_ir::DType;
        use rlx_ir::Shape;
        use rlx_ir::hir::HirModule;
        use rlx_ir::sym;

        let mut cache = DynamicDimCompileCache::new(Device::Cpu, 4);
        let opts = crate::CompileOptions::new();
        {
            let _short = cache
                .get_or_specialize(
                    8,
                    &rlx_ir::DimBinding::batch_seq(1, 8),
                    || {
                        let mut hir = HirModule::new("dyn_cache");
                        let x = hir.input_batch_seq("x", sym::BATCH, sym::SEQ, 4, DType::F32);
                        let w = hir.param("w", Shape::new(&[4, 2], DType::F32));
                        let y = hir.linear(
                            x,
                            w,
                            None,
                            None,
                            Shape::batch_seq(sym::BATCH, sym::SEQ, 2, DType::F32),
                        );
                        hir.set_outputs(vec![y]);
                        hir
                    },
                    &opts,
                )
                .expect("specialize short");
        }
        assert!(cache.has_template());
        assert_eq!(cache.len(), 1);
        cache
            .get_or_specialize(
                128,
                &rlx_ir::DimBinding::batch_seq(1, 128),
                || panic!("HIR builder must not run twice"),
                &opts,
            )
            .expect("specialize long");
        assert_eq!(cache.len(), 2);
    }
}