use crate::{CompiledGraph, Device, Session};
use rlx_ir::DimBinding;
use rlx_ir::Graph;
use rlx_ir::hir::HirModule;
use rlx_opt::CompileResult;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::ops::Range;
/// One named input for [`BucketedCompileCache::run_padded_mixed`].
///
/// All fields are borrowed, so the type is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct CacheRunInput<'a> {
    /// Name of the graph input this data is bound to.
    pub name: &'a str,
    /// Flat f32 data; when `row_inner` is set it is read as consecutive rows of
    /// `row_inner` elements each.
    pub data: &'a [f32],
    /// `Some(inner)`: zero-pad `data` up to the bucket's row bound (`inner` elements per
    /// row). `None`: pass `data` to the graph unchanged.
    pub row_inner: Option<usize>,
}
/// Keyed cache of compiled graphs for one device.
///
/// Holds at most `capacity` graphs. When full, compiling a new key evicts the key that was
/// inserted earliest. Eviction is FIFO: a cache hit does not refresh an entry's position.
pub struct CompileCache {
    device: Device,
    capacity: usize,
    policy: Option<rlx_opt::PrecisionPolicy>,
    /// Cached graphs, searched linearly by key.
    entries: Vec<(u64, CompiledGraph)>,
    /// Cached keys in insertion order; the front is the next eviction candidate.
    order: VecDeque<u64>,
}

impl CompileCache {
    /// Creates an empty cache for `device` holding up to `capacity` graphs, with no
    /// precision policy.
    ///
    /// # Panics
    /// If `capacity` is 0.
    pub fn new(device: Device, capacity: usize) -> Self {
        Self::with_policy(device, capacity, None)
    }

    /// Creates an empty cache whose compile sessions use `policy` when it is `Some`.
    ///
    /// # Panics
    /// If `capacity` is 0.
    pub fn with_policy(
        device: Device,
        capacity: usize,
        policy: Option<rlx_opt::PrecisionPolicy>,
    ) -> Self {
        assert!(capacity > 0, "CompileCache capacity must be ≥ 1");
        let entries = Vec::with_capacity(capacity);
        let order = VecDeque::with_capacity(capacity);
        Self {
            device,
            capacity,
            policy,
            entries,
            order,
        }
    }

    /// Returns the graph cached under `key`, compiling `build()` with default options on a
    /// miss.
    pub fn get_or_compile<F: FnOnce() -> Graph>(
        &mut self,
        key: u64,
        build: F,
    ) -> &mut CompiledGraph {
        self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
    }

    /// Returns the graph cached under `key`, compiling `build()` with `options` on a miss.
    ///
    /// `build` is only invoked on a miss, and `options` only matter for that compile.
    pub fn get_or_compile_with_options<F: FnOnce() -> Graph>(
        &mut self,
        key: u64,
        build: F,
        options: &crate::CompileOptions,
    ) -> &mut CompiledGraph {
        let slot = match self.entries.iter().position(|(k, _)| *k == key) {
            Some(hit) => hit,
            None => {
                let compiled = match &self.policy {
                    Some(p) => Session::new(self.device).with_policy(p.clone()),
                    None => Session::new(self.device),
                }
                .compile_with(build(), options);
                self.evict_oldest_if_full();
                self.entries.push((key, compiled));
                self.order.push_back(key);
                self.entries.len() - 1
            }
        };
        &mut self.entries[slot].1
    }

    /// Drops the earliest-inserted entry when the cache has reached `capacity`.
    fn evict_oldest_if_full(&mut self) {
        if self.entries.len() < self.capacity {
            return;
        }
        if let Some(oldest) = self.order.pop_front() {
            self.entries.retain(|(k, _)| *k != oldest);
        }
    }

    /// Number of cached graphs.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no graph is cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// `true` when a graph is cached under `key`.
    pub fn contains(&self, key: u64) -> bool {
        self.entries.iter().any(|&(k, _)| k == key)
    }
}
pub struct BucketedCompileCache {
device: Device,
policy: Option<rlx_opt::PrecisionPolicy>,
buckets: Vec<Bucket>,
}
struct Bucket {
range: Range<u64>,
compiled: Option<CompiledGraph>,
}
impl BucketedCompileCache {
pub fn new(device: Device, buckets: Vec<Range<u64>>) -> Self {
Self::with_policy(device, buckets, None)
}
pub fn power_of_two_ladder(device: Device, min: u64, max: u64) -> Self {
Self::power_of_two_ladder_with_policy(device, min, max, None)
}
pub fn power_of_two_ladder_with_policy(
device: Device,
min: u64,
max: u64,
policy: Option<rlx_opt::PrecisionPolicy>,
) -> Self {
assert!(min >= 1, "power_of_two_ladder: min must be ≥ 1, got {min}");
assert!(
max >= min,
"power_of_two_ladder: max ({max}) must be ≥ min ({min})"
);
let mut buckets: Vec<Range<u64>> = Vec::new();
let mut start = 1u64;
let mut extent = min.next_power_of_two();
loop {
buckets.push(start..(extent + 1));
if extent >= max {
break;
}
start = extent + 1;
extent = extent
.checked_mul(2)
.expect("power_of_two_ladder: extent overflow");
}
Self::with_policy(device, buckets, policy)
}
pub fn with_policy(
device: Device,
buckets: Vec<Range<u64>>,
policy: Option<rlx_opt::PrecisionPolicy>,
) -> Self {
assert!(!buckets.is_empty(), "BucketedCompileCache needs ≥1 bucket");
for (i, b) in buckets.iter().enumerate() {
assert!(b.start < b.end, "bucket {i} ({b:?}) is empty");
if i + 1 < buckets.len() {
assert!(
b.end <= buckets[i + 1].start,
"buckets {i} ({b:?}) and {} ({:?}) overlap",
i + 1,
buckets[i + 1],
);
}
}
let buckets = buckets
.into_iter()
.map(|range| Bucket {
range,
compiled: None,
})
.collect();
Self {
device,
policy,
buckets,
}
}
pub fn get_or_compile<F: FnOnce(u64) -> Graph>(
&mut self,
key: u64,
build: F,
) -> Option<(u64, &mut CompiledGraph)> {
self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
}
pub fn get_or_compile_with_options<F: FnOnce(u64) -> Graph>(
&mut self,
key: u64,
build: F,
options: &crate::CompileOptions,
) -> Option<(u64, &mut CompiledGraph)> {
let idx = self.bucket_for(key)?;
let upper = self.buckets[idx].range.end - 1;
if self.buckets[idx].compiled.is_none() {
let mut session = Session::new(self.device);
if let Some(p) = &self.policy {
session = session.with_policy(p.clone());
}
self.buckets[idx].compiled = Some(session.compile_with(build(upper), options));
}
Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
}
pub fn get_or_compile_hir<F: FnOnce(u64) -> HirModule>(
&mut self,
key: u64,
build: F,
) -> Option<(u64, &mut CompiledGraph)> {
self.get_or_compile_hir_with_options(key, build, &crate::CompileOptions::new())
}
pub fn get_or_compile_hir_with_options<F: FnOnce(u64) -> HirModule>(
&mut self,
key: u64,
build: F,
options: &crate::CompileOptions,
) -> Option<(u64, &mut CompiledGraph)> {
let idx = self.bucket_for(key)?;
let upper = self.buckets[idx].range.end - 1;
if self.buckets[idx].compiled.is_none() {
let mut session = Session::new(self.device);
if let Some(p) = &self.policy {
session = session.with_policy(p.clone());
}
let compiled = session
.compile_hir_with(build(upper), options)
.expect("HIR lower/compile in bucketed cache");
self.buckets[idx].compiled = Some(compiled);
}
Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
}
pub fn bucket_for(&self, key: u64) -> Option<usize> {
self.buckets.iter().position(|b| b.range.contains(&key))
}
pub fn buckets(&self) -> impl Iterator<Item = &Range<u64>> {
self.buckets.iter().map(|b| &b.range)
}
pub fn compiled_count(&self) -> usize {
self.buckets.iter().filter(|b| b.compiled.is_some()).count()
}
pub fn total_buckets(&self) -> usize {
self.buckets.len()
}
pub fn run_padded<F: FnOnce(u64) -> Graph>(
&mut self,
key: u64,
actual_rows: usize,
build: F,
inputs: &[(&str, &[f32], usize)],
output_inners: &[usize],
) -> Option<(u64, Vec<Vec<f32>>)> {
let (upper, compiled) = self.get_or_compile(key, build)?;
let padded: Vec<(&str, Vec<f32>)> = inputs
.iter()
.map(|(name, data, inner)| (*name, pad_rows(data, *inner, upper)))
.collect();
let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
compiled.set_active_extent(Some((actual_rows, upper as usize)));
let raw_outputs = compiled.run(&pairs);
compiled.set_active_extent(None);
let outs = raw_outputs
.into_iter()
.enumerate()
.map(|(i, out)| match output_inners.get(i).copied() {
Some(0) | None => out,
Some(inner) => slice_rows(&out, inner, actual_rows),
})
.collect();
Some((upper, outs))
}
pub fn ensure_graph_with_params<F>(
&mut self,
key: u64,
build: F,
options: &crate::CompileOptions,
) -> Option<(u64, &mut CompiledGraph)>
where
F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
{
let idx = self.bucket_for(key)?;
let upper = self.buckets[idx].range.end - 1;
if self.buckets[idx].compiled.is_none() {
let (graph, params) = build(upper);
let mut session = Session::new(self.device);
if let Some(p) = &self.policy {
session = session.with_policy(p.clone());
}
let mut compiled = session.compile_with(graph, options);
for (name, data) in params {
compiled.set_param(&name, &data);
}
self.buckets[idx].compiled = Some(compiled);
}
Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
}
pub fn ensure_hir_with_params<F>(
&mut self,
key: u64,
build: F,
options: &crate::CompileOptions,
) -> Option<(u64, &mut CompiledGraph)>
where
F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
{
let idx = self.bucket_for(key)?;
let upper = self.buckets[idx].range.end - 1;
if self.buckets[idx].compiled.is_none() {
let (hir, params) = build(upper);
let mut session = Session::new(self.device);
if let Some(p) = &self.policy {
session = session.with_policy(p.clone());
}
let mut compiled = session
.compile_hir_with(hir, options)
.expect("HIR lower/compile in ensure_hir_with_params");
for (name, data) in params {
compiled.set_param(&name, &data);
}
self.buckets[idx].compiled = Some(compiled);
}
Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
}
pub fn run_padded_mixed<F>(
&mut self,
key: u64,
actual_rows: usize,
build: F,
inputs: &[CacheRunInput<'_>],
output_inners: &[usize],
) -> Option<(u64, Vec<Vec<f32>>)>
where
F: FnOnce(u64) -> Graph,
{
let (upper, compiled) = self.get_or_compile(key, build)?;
let padded: Vec<(&str, Vec<f32>)> = inputs
.iter()
.map(|inp| match inp.row_inner {
Some(inner) => (inp.name, pad_rows(inp.data, inner, upper)),
None => (inp.name, inp.data.to_vec()),
})
.collect();
let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
compiled.set_active_extent(Some((actual_rows, upper as usize)));
let raw_outputs = compiled.run(&pairs);
compiled.set_active_extent(None);
let outs = raw_outputs
.into_iter()
.enumerate()
.map(|(i, out)| match output_inners.get(i).copied() {
Some(0) | None => out,
Some(inner) => slice_rows(&out, inner, actual_rows),
})
.collect();
Some((upper, outs))
}
}
/// Compile cache for HIR graphs with symbolic dimensions.
///
/// The HIR is compiled once, with no dim binding, into a template ([`CompileResult`]).
/// Each distinct `key` then gets its own specialization of that template for a concrete
/// [`DimBinding`], compiled for `device`. Specializations live in a FIFO of at most
/// `capacity` entries; cache hits do not refresh an entry's age.
pub struct DynamicDimCompileCache {
    device: Device,
    /// Fallback policy for backend compiles whose options leave `policy` unset.
    policy: Option<rlx_opt::PrecisionPolicy>,
    capacity: usize,
    /// Built on the first miss and reused for every later specialization.
    template: Option<CompileResult>,
    entries: Vec<(u64, CompiledGraph)>,
    /// Keys of `entries` in insertion order; the front is evicted first.
    order: VecDeque<u64>,
}

impl DynamicDimCompileCache {
    /// Creates an empty cache holding up to `capacity` specializations, with no fallback
    /// policy.
    ///
    /// # Panics
    /// If `capacity` is 0.
    pub fn new(device: Device, capacity: usize) -> Self {
        Self::with_policy(device, capacity, None)
    }

    /// Creates an empty cache whose `policy` is used for backend compiles that do not set
    /// one in their `CompileOptions`.
    ///
    /// # Panics
    /// If `capacity` is 0.
    pub fn with_policy(
        device: Device,
        capacity: usize,
        policy: Option<rlx_opt::PrecisionPolicy>,
    ) -> Self {
        assert!(capacity > 0, "DynamicDimCompileCache capacity must be ≥ 1");
        Self {
            device,
            policy,
            capacity,
            template: None,
            entries: Vec::with_capacity(capacity),
            order: VecDeque::with_capacity(capacity),
        }
    }

    /// Device that templates and specializations are compiled for.
    pub fn compile_device(&self) -> Device {
        self.device
    }

    /// Index of `key` in `entries`, if it is cached.
    fn index_of(&self, key: u64) -> Option<usize> {
        self.entries.iter().position(|(k, _)| *k == key)
    }

    /// Stores `compiled` under `key` and returns the stored graph. When the cache is full,
    /// the oldest inserted entry is evicted first.
    fn insert_evicting(&mut self, key: u64, compiled: CompiledGraph) -> &mut CompiledGraph {
        if self.entries.len() >= self.capacity
            && let Some(evict_key) = self.order.pop_front()
        {
            self.entries.retain(|(k, _)| *k != evict_key);
        }
        self.entries.push((key, compiled));
        self.order.push_back(key);
        &mut self.entries.last_mut().expect("entry was just pushed").1
    }

    /// `opts` with this cache's policy filled in when the caller left `policy` unset.
    fn with_default_policy(&self, opts: crate::CompileOptions) -> crate::CompileOptions {
        match &self.policy {
            Some(p) if opts.policy.is_none() => opts.policy(p.clone()),
            _ => opts,
        }
    }

    /// Returns the specialization cached under `key`, building it on a miss.
    ///
    /// If no template exists yet, it is compiled from `build_hir()` using `options` with
    /// `dim_binding` cleared. Once the template exists, `build_hir` is not called again.
    /// Each miss specializes the template for `binding` and compiles the result for this
    /// cache's device, with the cache's policy filled in when `options.policy` is unset.
    /// `key` is trusted to identify `binding`: a hit ignores `binding` and `options`.
    ///
    /// # Errors
    /// Propagates the lowering error from compiling the template.
    ///
    /// # Panics
    /// If no backend is registered for the device.
    pub fn get_or_specialize<F: FnOnce() -> HirModule>(
        &mut self,
        key: u64,
        binding: &DimBinding,
        build_hir: F,
        options: &crate::CompileOptions,
    ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError> {
        if let Some(idx) = self.index_of(key) {
            return Ok(&mut self.entries[idx].1);
        }
        // Every stage runs binding-free; `binding` is applied only via `specialize`.
        let mut unbound = options.clone();
        unbound.dim_binding = None;
        if self.template.is_none() {
            let pipe = crate::stages::pipeline_for(self.device, &unbound);
            self.template = Some(pipe.compile_hir(build_hir())?);
        }
        // Scoped so every borrow of `self.template` ends before `insert_evicting`.
        let compiled = {
            let template = self.template.as_ref().expect("template just set");
            let pipe = crate::stages::pipeline_for(self.device, &unbound);
            let specialized = template.specialize(&pipe, binding);
            let backend = crate::registry::backend_for(self.device).expect("backend registered");
            let compile_opts = self.with_default_policy(unbound.clone());
            let executable = backend.compile_lir(specialized.lir, &compile_opts);
            CompiledGraph::new(executable, self.device)
        };
        Ok(self.insert_evicting(key, compiled))
    }

    /// Number of cached specializations.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no specialization is cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// `true` when a specialization is cached under `key`.
    pub fn contains(&self, key: u64) -> bool {
        self.index_of(key).is_some()
    }

    /// `true` once the template has been compiled.
    pub fn has_template(&self) -> bool {
        self.template.is_some()
    }

    /// Returns the template, compiling it from `build_hir()` with `options` (its
    /// `dim_binding` cleared) if it does not exist yet. Once built, later calls ignore
    /// `build_hir` and `options`.
    ///
    /// # Errors
    /// Propagates the lowering error from compiling the template.
    pub fn ensure_template<F: FnOnce() -> HirModule>(
        &mut self,
        build_hir: F,
        options: &crate::CompileOptions,
    ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
        if self.template.is_none() {
            let mut opts = options.clone();
            opts.dim_binding = None;
            let pipe = crate::stages::pipeline_for(self.device, &opts);
            self.template = Some(pipe.compile_hir(build_hir())?);
        }
        Ok(self.template.as_ref().expect("template set"))
    }

    /// The compiled template, if it exists.
    pub fn template_result(&self) -> Option<&CompileResult> {
        self.template.as_ref()
    }

    /// Like [`DynamicDimCompileCache::get_or_specialize`], but each miss is specialized
    /// through `aot` using the on-disk cache at `disk_base`.
    ///
    /// The cache's fallback policy is filled into the specialization options exactly as
    /// on the in-memory path. This way a policy given to `with_policy` also reaches AOT
    /// specializations.
    ///
    /// # Errors
    /// Template lowering errors and AOT cache errors, converted to
    /// [`crate::AotCacheError`].
    pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
        &mut self,
        aot: &crate::AotCache,
        disk_base: &str,
        key: u64,
        binding: &rlx_ir::DimBinding,
        build_hir: F,
        options: &crate::CompileOptions,
    ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
        if let Some(idx) = self.index_of(key) {
            return Ok(&mut self.entries[idx].1);
        }
        let device = self.device;
        let spec_opts = self.with_default_policy(options.clone());
        let template = self.ensure_template(build_hir, options)?;
        let compiled = aot.specialize_cached(disk_base, binding, device, template, &spec_opts)?;
        Ok(self.insert_evicting(key, compiled))
    }
}
/// Zero-pads `data`, read as rows of `inner` elements, to exactly `upper` rows.
///
/// Returns a vector of `upper * inner` elements. Its prefix is `data` and the rest is
/// `0.0`.
///
/// # Panics
/// - If `inner` is 0.
/// - If `data.len()` is not a multiple of `inner`.
/// - If `data` already holds more than `upper` rows.
/// - If `upper * inner` does not fit in `usize`.
pub fn pad_rows(data: &[f32], inner: usize, upper: u64) -> Vec<f32> {
    assert!(inner > 0, "pad_rows: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    let actual = data.len() / inner;
    // Compare in u64: casting `upper` to usize first would truncate on 32-bit targets.
    assert!(
        actual as u64 <= upper,
        "pad_rows: actual rows {actual} exceed upper bound {upper}",
    );
    // Checked so an oversized bound panics clearly instead of wrapping in release builds.
    let total = usize::try_from(upper)
        .ok()
        .and_then(|rows| rows.checked_mul(inner))
        .unwrap_or_else(|| panic!("pad_rows: {upper} rows × {inner} elements overflows usize"));
    let mut out = Vec::with_capacity(total);
    out.extend_from_slice(data);
    out.resize(total, 0.0_f32);
    out
}
/// Keeps only the first `actual` rows of `data`, read as rows of `inner` elements.
///
/// This is the inverse of [`pad_rows`]: it drops the trailing padded rows.
///
/// # Panics
/// - If `inner` is 0.
/// - If `data.len()` is not a multiple of `inner`.
/// - If `actual` exceeds the number of rows in `data`.
pub fn slice_rows(data: &[f32], inner: usize, actual: usize) -> Vec<f32> {
    assert!(inner > 0, "slice_rows: inner stride must be ≥ 1");
    let len = data.len();
    assert_eq!(
        len % inner,
        0,
        "slice_rows: data len {} not a multiple of inner {inner}",
        len,
    );
    let upper = len / inner;
    assert!(
        actual <= upper,
        "slice_rows: actual rows {actual} exceed upper {upper}",
    );
    let kept = actual * inner;
    Vec::from(&data[..kept])
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::*;
    use std::cell::Cell;

    /// Single-input graph computing `relu(x)` over `n` f32 elements.
    fn tiny_graph(n: usize) -> Graph {
        let mut g = Graph::new("t");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[n], f));
        let y = g.activation(rlx_ir::op::Activation::Relu, x, Shape::new(&[n], f));
        g.set_outputs(vec![y]);
        g
    }

    #[test]
    fn cache_hits_avoid_recompile() {
        let mut cache = CompileCache::new(Device::Cpu, 4);
        let calls = Cell::new(0);
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        assert_eq!(calls.get(), 1);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn fifo_evicts_oldest_at_capacity() {
        let mut cache = CompileCache::new(Device::Cpu, 2);
        let _ = cache.get_or_compile(1, || tiny_graph(4));
        let _ = cache.get_or_compile(2, || tiny_graph(8));
        assert!(cache.contains(1) && cache.contains(2));
        let _ = cache.get_or_compile(3, || tiny_graph(16));
        assert!(!cache.contains(1));
        assert!(cache.contains(2) && cache.contains(3));
    }

    #[test]
    fn different_keys_keep_separate_compiles() {
        let mut cache = CompileCache::new(Device::Cpu, 4);
        let calls = Cell::new(0);
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(2, || {
            calls.set(calls.get() + 1);
            tiny_graph(16)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        assert_eq!(calls.get(), 2);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn bucket_amortizes_keys_within_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
        let calls = Cell::new(0);
        let uppers = Cell::new((0u64, 0u64));
        let (u1, _) = cache
            .get_or_compile(2, |upper| {
                calls.set(calls.get() + 1);
                uppers.set((upper, uppers.get().1));
                tiny_graph(upper as usize)
            })
            .expect("key 2 in range");
        let (u2, _) = cache
            .get_or_compile(3, |upper| {
                calls.set(calls.get() + 1);
                uppers.set((uppers.get().0, upper));
                tiny_graph(upper as usize)
            })
            .expect("key 3 in range");
        assert_eq!(calls.get(), 1);
        assert_eq!(u1, 3);
        assert_eq!(u2, 3);
        assert_eq!(uppers.get().0, 3);
        // The second builder must never run: key 3 hits the bucket compiled for key 2.
        assert_eq!(uppers.get().1, 0, "second build must not run");
        assert_eq!(cache.compiled_count(), 1);
        assert_eq!(cache.total_buckets(), 2);
    }

    #[test]
    fn bucket_lookup_returns_none_outside_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
        assert!(cache.bucket_for(0).is_none());
        assert!(cache.bucket_for(16).is_none());
        assert!(cache.bucket_for(100).is_none());
        assert_eq!(cache.bucket_for(3), Some(0));
        assert_eq!(cache.bucket_for(4), Some(1));
        let calls = Cell::new(0);
        let result = cache.get_or_compile(100, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        assert!(result.is_none());
        assert_eq!(calls.get(), 0);
        assert_eq!(cache.compiled_count(), 0);
    }

    #[test]
    fn bucket_compiles_lazily_per_bucket() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16, 16..64]);
        let calls = Cell::new(0);
        let _ = cache.get_or_compile(2, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        let _ = cache.get_or_compile(8, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        assert_eq!(calls.get(), 2);
        assert_eq!(cache.compiled_count(), 2);
        assert_eq!(cache.total_buckets(), 3);
    }

    #[test]
    #[should_panic(expected = "overlap")]
    fn bucket_overlap_rejected() {
        let _ = BucketedCompileCache::new(Device::Cpu, vec![1..8, 4..16]);
    }

    #[test]
    #[should_panic(expected = "≥1 bucket")]
    fn empty_bucket_list_rejected() {
        let _ = BucketedCompileCache::new(Device::Cpu, vec![]);
    }

    #[test]
    fn pad_rows_appends_zeros() {
        let p = pad_rows(&[1.0, 2.0, 3.0], 1, 5);
        assert_eq!(p, vec![1.0, 2.0, 3.0, 0.0, 0.0]);
        let p = pad_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 4);
        assert_eq!(
            p,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        );
        let p = pad_rows(&[7.0, 8.0], 1, 2);
        assert_eq!(p, vec![7.0, 8.0]);
    }

    #[test]
    fn slice_rows_truncates_trailing() {
        let s = slice_rows(&[1.0, 2.0, 3.0, 0.0, 0.0], 1, 3);
        assert_eq!(s, vec![1.0, 2.0, 3.0]);
        let s = slice_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0], 3, 2);
        assert_eq!(s, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "exceed upper")]
    fn pad_rows_rejects_too_long_input() {
        let _ = pad_rows(&[1.0, 2.0, 3.0, 4.0], 1, 3);
    }

    #[test]
    #[should_panic(expected = "exceed upper")]
    fn slice_rows_rejects_too_large_actual() {
        let _ = slice_rows(&[1.0, 2.0, 3.0], 1, 5);
    }

    #[test]
    fn run_padded_pads_input_and_slices_output() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0];
        let (upper, outs) = cache
            .run_padded(
                10,
                10,
                |max| tiny_graph(max as usize),
                &[("x", &input, 1)],
                &[1],
            )
            .expect("key 10 in [1..16)");
        assert_eq!(upper, 15);
        assert_eq!(outs.len(), 1);
        let out = &outs[0];
        assert_eq!(out.len(), 10, "output sliced back to actual_rows");
        let expected: Vec<f32> = input.iter().map(|x| x.max(0.0)).collect();
        assert_eq!(out, &expected);
    }

    #[test]
    fn run_padded_reuses_bucket_across_actuals() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let calls = Cell::new(0);
        let (u1, o1) = cache
            .run_padded(
                10,
                10,
                |max| {
                    calls.set(calls.get() + 1);
                    tiny_graph(max as usize)
                },
                &[(
                    "x",
                    &[1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0],
                    1,
                )],
                &[1],
            )
            .unwrap();
        assert_eq!(o1.len(), 1);
        assert_eq!(o1[0].len(), 10);
        assert_eq!(u1, 15);
        let (u2, o2) = cache
            .run_padded(
                5,
                5,
                |max| {
                    calls.set(calls.get() + 1);
                    tiny_graph(max as usize)
                },
                &[("x", &[-1.0, 2.0, -3.0, 4.0, -5.0], 1)],
                &[1],
            )
            .unwrap();
        assert_eq!(o2.len(), 1);
        assert_eq!(o2[0].len(), 5);
        assert_eq!(u2, 15);
        assert_eq!(o2[0], vec![0.0, 2.0, 0.0, 4.0, 0.0]);
        assert_eq!(calls.get(), 1, "bucket cached across actuals");
        assert_eq!(cache.compiled_count(), 1);
    }

    #[test]
    fn run_padded_returns_none_out_of_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let calls = Cell::new(0);
        let result = cache.run_padded(
            100,
            5,
            |u| {
                calls.set(calls.get() + 1);
                tiny_graph(u as usize)
            },
            &[("x", &[1.0, 2.0, 3.0, 4.0, 5.0], 1)],
            &[1],
        );
        assert!(result.is_none());
        assert_eq!(calls.get(), 0);
        assert_eq!(cache.compiled_count(), 0);
    }

    #[test]
    fn power_of_two_ladder_generates_log_buckets() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..9, 9..17, 17..33, 33..65]);
        assert_eq!(cache.total_buckets(), 4);
    }

    #[test]
    fn power_of_two_ladder_picks_smallest_extent_for_actual() {
        let mut cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
        let captured_uppers: std::cell::RefCell<Vec<u64>> = Default::default();
        let (u17, _) = cache
            .get_or_compile(17, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u9, _) = cache
            .get_or_compile(9, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u3, _) = cache
            .get_or_compile(3, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u64_, _) = cache
            .get_or_compile(64, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        assert_eq!(u17, 32, "key=17 → smallest extent ≥ 17 is 32");
        assert_eq!(u9, 16, "key=9 → smallest extent ≥ 9 is 16");
        assert_eq!(u3, 8, "key=3 → smallest extent ≥ 3 is 8");
        assert_eq!(u64_, 64, "key=64 → exact match at 64");
        assert_eq!(*captured_uppers.borrow(), vec![32, 16, 8, 64]);
        assert_eq!(cache.compiled_count(), 4);
    }

    #[test]
    fn power_of_two_ladder_min_above_one_starts_at_one() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 32);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17, 17..33]);
    }

    #[test]
    fn power_of_two_ladder_non_pow2_min_rounds_up() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 10, 64);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17, 17..33, 33..65]);
    }

    #[test]
    fn power_of_two_ladder_max_below_pow2_extends_up() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 20);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..9, 9..17, 17..33]);
    }

    #[test]
    fn power_of_two_ladder_min_equals_max() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 16);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17]);
    }

    #[test]
    #[should_panic(expected = "min must be ≥ 1")]
    fn power_of_two_ladder_zero_min_rejected() {
        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 0, 16);
    }

    #[test]
    #[should_panic(expected = "max")]
    fn power_of_two_ladder_max_below_min_rejected() {
        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 32, 8);
    }

    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_cpu_activation() {
        let graph = tiny_graph(15);
        let mut compiled = Session::new(Device::Cpu).compile(graph);
        let warm_input: Vec<f32> = vec![1.0; 15];
        let warm_outs = compiled.run(&[("x", &warm_input)]);
        assert_eq!(warm_outs[0], vec![1.0; 15], "warm-up sanity");
        let neg_input: Vec<f32> = vec![-1.0; 15];
        compiled.set_active_extent(Some((5, 15)));
        let outs = compiled.run(&[("x", &neg_input)]);
        let out = &outs[0];
        assert_eq!(out.len(), 15);
        assert_eq!(
            out[..5],
            [0.0; 5],
            "first 5 elements processed (relu of -1)"
        );
        assert_eq!(
            out[5..],
            [1.0; 10],
            "tail untouched — proves Copy + Activation skipped indices 5..15"
        );
        compiled.set_active_extent(None);
        let outs = compiled.run(&[("x", &neg_input)]);
        assert_eq!(
            outs[0],
            vec![0.0; 15],
            "full-extent path must clip every negative"
        );
    }

    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_binary_full() {
        let mut g = Graph::new("add");
        let f = DType::F32;
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let c = g.add(a, b);
        g.set_outputs(vec![c]);
        let mut compiled = Session::new(Device::Cpu).compile(g);
        let warm = compiled.run(&[("a", &[1.0f32; 4]), ("b", &[1.0f32; 4])]);
        assert_eq!(warm[0], vec![2.0; 4]);
        compiled.set_active_extent(Some((2, 4)));
        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
        let out = &outs[0];
        assert_eq!(out[..2], [20.0, 20.0], "first 2 = active sum");
        assert_eq!(
            out[2..],
            [2.0, 2.0],
            "tail untouched — proves BinaryFull skipped indices 2..4"
        );
        compiled.set_active_extent(None);
        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
        assert_eq!(outs[0], vec![20.0; 4]);
    }

    #[test]
    #[ignore = "process-wide STATE; runs only in isolation via `cargo test perfetto -- --ignored`"]
    fn perfetto_trace_emits_per_thunk_events() {
        use std::env;
        use std::fs;
        let path = env::temp_dir().join(format!("rlx-perfetto-e2e-{}.json", std::process::id()));
        if path.exists() {
            let _ = fs::remove_file(&path);
        }
        // SAFETY: this test is ignored by default and meant to run on its own (see the
        // ignore reason above), so no other test thread reads or writes the environment
        // concurrently.
        unsafe {
            env::set_var("RLX_TRACE_PERFETTO", &path);
        }
        let f = DType::F32;
        let mut g = Graph::new("perf");
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let s = g.add(a, b);
        let r = g.relu(s);
        g.set_outputs(vec![r]);
        let mut compiled = Session::new(Device::Cpu).compile(g);
        let _ = compiled.run(&[("a", &[1.0; 4]), ("b", &[1.0; 4])]);
        crate::perfetto::flush_and_finalize();
        let contents = fs::read_to_string(&path).expect("trace file");
        assert!(
            contents.contains("\"binary\"")
                || contents.contains("\"activation\"")
                || contents.contains("\"elementwise_region\""),
            "expected at least one thunk-name event in perfetto trace; got: {contents}"
        );
        assert!(contents.trim_start().starts_with('['));
        let _ = fs::remove_file(&path);
    }

    #[test]
    fn elementwise_region_fused_matches_unfused() {
        let f = DType::F32;
        let mut g = Graph::new("ew_e2e");
        let a = g.input("a", Shape::new(&[8], f));
        let b = g.input("b", Shape::new(&[8], f));
        let c = g.input("c", Shape::new(&[8], f));
        let add = g.add(a, b);
        let mul = g.mul(add, c);
        let relu = g.relu(mul);
        g.set_outputs(vec![relu]);
        let mut compiled = Session::new(Device::Cpu).compile(g);
        let av: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0, 0.5, -0.5, 1.5, -1.5];
        let bv: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0, 0.5, 0.5, 0.5, 0.5];
        let cv: Vec<f32> = vec![1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 0.5, 4.0];
        let outs = compiled.run(&[("a", &av), ("b", &bv), ("c", &cv)]);
        let out = &outs[0];
        let expected: Vec<f32> = (0..8)
            .map(|i| {
                let v = (av[i] + bv[i]) * cv[i];
                v.max(0.0)
            })
            .collect();
        // `zip` below stops at the shorter side, so pin the length explicitly.
        assert_eq!(out.len(), expected.len(), "output length");
        for (i, (got, exp)) in out.iter().zip(&expected).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "mismatch at {i}: got {got}, expected {exp}"
            );
        }
    }

    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_attention() {
        use rlx_ir::op::MaskKind;
        let f = DType::F32;
        let mut g = Graph::new("attn");
        let q = g.input("q", Shape::new(&[1, 4, 8], f));
        let k = g.input("k", Shape::new(&[1, 4, 8], f));
        let v = g.input("v", Shape::new(&[1, 4, 8], f));
        let out = g.attention_kind(q, k, v, 2, 4, MaskKind::None, Shape::new(&[1, 4, 8], f));
        g.set_outputs(vec![out]);
        let mut compiled = Session::new(Device::Cpu).compile(g);
        let warm = compiled.run(&[
            ("q", &[1.0f32; 32]),
            ("k", &[1.0f32; 32]),
            ("v", &[1.0f32; 32]),
        ]);
        let warm_out = warm[0].clone();
        assert_eq!(warm_out.len(), 32);
        compiled.set_active_extent(Some((2, 4)));
        let outs = compiled.run(&[
            ("q", &[3.0f32; 32]),
            ("k", &[3.0f32; 32]),
            ("v", &[3.0f32; 32]),
        ]);
        let out = &outs[0];
        assert_eq!(out.len(), 32);
        assert_eq!(
            &out[16..],
            &warm_out[16..],
            "tail (positions 2,3) must be untouched — proves Attention skipped"
        );
        assert_ne!(
            &out[..16],
            &warm_out[..16],
            "first 2 positions should reflect new input"
        );
    }

    #[test]
    #[ignore = "TODO: not implemented yet — needs a schedule containing a thunk without active-extent support"]
    fn active_extent_falls_back_when_unsupported_thunk_in_schedule() {
        // Intentionally empty until the fallback path can be exercised. It is ignored so it
        // does not report a vacuous pass.
    }

    #[test]
    fn run_padded_uses_active_extent_on_cpu() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        // Only the first five values are passed to `run_padded`. The large negatives after
        // them never reach the graph; rows 5.. are zero padding instead.
        let input: Vec<f32> = vec![
            1.0, -1.0, 2.0, -2.0, 3.0, -10.0, -20.0, -30.0, -40.0, -50.0,
        ];
        let (upper, outs) = cache
            .run_padded(
                5,
                5,
                |max| tiny_graph(max as usize),
                &[("x", &input[..5], 1)],
                &[1],
            )
            .unwrap();
        assert_eq!(upper, 15);
        assert_eq!(outs[0].len(), 5);
        assert_eq!(outs[0], vec![1.0, 0.0, 2.0, 0.0, 3.0]);
    }

    #[test]
    fn run_padded_inner_zero_returns_output_unsliced() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0];
        let (upper, outs) = cache
            .run_padded(
                5,
                5,
                |max| tiny_graph(max as usize),
                &[("x", &input, 1)],
                // Inner 0 means "do not trim this output".
                &[0],
            )
            .unwrap();
        assert_eq!(upper, 15);
        assert_eq!(
            outs[0].len(),
            15,
            "unsliced output preserves full upper extent"
        );
        assert_eq!(&outs[0][..5], &[1.0, 0.0, 2.0, 0.0, 3.0]);
        assert!(outs[0][5..].iter().all(|&v| v == 0.0));
    }

    #[test]
    fn dynamic_dim_cache_specializes_per_key() {
        use rlx_ir::DType;
        use rlx_ir::Shape;
        use rlx_ir::hir::HirModule;
        use rlx_ir::sym;
        let mut cache = DynamicDimCompileCache::new(Device::Cpu, 4);
        let opts = crate::CompileOptions::new();
        {
            let _short = cache
                .get_or_specialize(
                    8,
                    &rlx_ir::DimBinding::batch_seq(1, 8),
                    || {
                        let mut hir = HirModule::new("dyn_cache");
                        let x = hir.input_batch_seq("x", sym::BATCH, sym::SEQ, 4, DType::F32);
                        let w = hir.param("w", Shape::new(&[4, 2], DType::F32));
                        let y = hir.linear(
                            x,
                            w,
                            None,
                            None,
                            Shape::batch_seq(sym::BATCH, sym::SEQ, 2, DType::F32),
                        );
                        hir.set_outputs(vec![y]);
                        hir
                    },
                    &opts,
                )
                .expect("specialize short");
        }
        assert!(cache.has_template());
        assert_eq!(cache.len(), 1);
        cache
            .get_or_specialize(
                128,
                &rlx_ir::DimBinding::batch_seq(1, 128),
                || panic!("HIR builder must not run twice"),
                &opts,
            )
            .expect("specialize long");
        assert_eq!(cache.len(), 2);
    }
}