Skip to main content

rlx_runtime/
compile_cache.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shape-bucketed compile cache.
17//!
18//! Lets variable-shape callers (e.g., embedding-model wrappers that vary
19//! batch + seq per request) amortize the per-(shape) compile cost. Cache
20//! keys are caller-provided `u64`s — the caller decides what counts as a
21//! shape bucket. Typical recipe: `(batch as u64) << 32 | seq as u64`.
22//!
23//! The cache stores one `CompiledGraph` per key. Params loaded onto a
24//! cached entry persist for that entry — re-fetching from cache does
25//! **not** require re-running `set_param`. Eviction is FIFO, capped at
26//! `capacity` entries (good enough for the current "a handful of common
27//! shapes" usage pattern; switch to LRU if a real workload shows churn).
28//!
29//! # Example
30//!
31//! ```rust,ignore
32//! let mut cache = CompileCache::new(Device::Metal, 8);
33//! let key = ((batch as u64) << 32) | seq as u64;
34//! let mut compiled = cache.get_or_compile(key, || build_my_graph(batch, seq));
35//! // First call for `key`: compiles. Subsequent calls: cache hit.
36//! compiled.run(&[("x", &input_data)]);
37//! ```
38
39use crate::{CompiledGraph, Device, Session};
40use rlx_ir::DimBinding;
41use rlx_ir::Graph;
42use rlx_ir::hir::HirModule;
43use rlx_opt::CompileResult;
44use std::collections::VecDeque;
45use std::ops::Range;
46
/// Exact-key compile cache: at most `capacity` compiled graphs, one per
/// caller-chosen `u64` key, evicted in insertion (FIFO) order.
pub struct CompileCache {
    // Device passed to `Session::new` for every compile this cache runs.
    device: Device,
    // Maximum number of cached entries; asserted to be ≥ 1 at construction.
    capacity: usize,
    // Per-cache precision policy. None → default (F32). Set once at
    // construction; applies to every compile this cache performs.
    policy: Option<rlx_opt::PrecisionPolicy>,
    // (key, compiled) pairs, pushed in insertion order. Lookups are a
    // linear scan: with the small entry counts expected here (cap ~8)
    // that is cheaper than a HashMap + separate eviction list.
    entries: Vec<(u64, CompiledGraph)>,
    // Keys in insertion order; the front is the next eviction victim.
    order: VecDeque<u64>,
}
60
61impl CompileCache {
62    pub fn new(device: Device, capacity: usize) -> Self {
63        Self::with_policy(device, capacity, None)
64    }
65
66    /// Cache that compiles every entry with the given precision policy.
67    /// Use this when the cached entries should differ from CPU-default
68    /// F32 — e.g., `PrecisionPolicy::AutoMixed` for f16 compute on Metal.
69    pub fn with_policy(
70        device: Device,
71        capacity: usize,
72        policy: Option<rlx_opt::PrecisionPolicy>,
73    ) -> Self {
74        assert!(capacity > 0, "CompileCache capacity must be ≥ 1");
75        Self {
76            device,
77            capacity,
78            policy,
79            entries: Vec::with_capacity(capacity),
80            order: VecDeque::with_capacity(capacity),
81        }
82    }
83
84    /// Compile if not present, then return a mutable reference. The borrow
85    /// lifetime is tied to `&mut self` so callers naturally serialize their
86    /// use of any one entry — the cache is single-owner today.
87    pub fn get_or_compile<F: FnOnce() -> Graph>(
88        &mut self,
89        key: u64,
90        build: F,
91    ) -> &mut CompiledGraph {
92        self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
93    }
94
95    /// Like [`Self::get_or_compile`] with explicit [`CompileOptions`].
96    pub fn get_or_compile_with_options<F: FnOnce() -> Graph>(
97        &mut self,
98        key: u64,
99        build: F,
100        options: &crate::CompileOptions,
101    ) -> &mut CompiledGraph {
102        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
103            return &mut self.entries[idx].1;
104        }
105        let mut session = Session::new(self.device);
106        if let Some(p) = &self.policy {
107            session = session.with_policy(p.clone());
108        }
109        let compiled = session.compile_with(build(), options);
110
111        // Evict FIFO if at capacity.
112        if self.entries.len() >= self.capacity
113            && let Some(evict_key) = self.order.pop_front()
114        {
115            self.entries.retain(|(k, _)| *k != evict_key);
116        }
117        self.entries.push((key, compiled));
118        self.order.push_back(key);
119        &mut self.entries.last_mut().unwrap().1
120    }
121
122    /// Number of entries currently cached. Useful for tests + diagnostics.
123    pub fn len(&self) -> usize {
124        self.entries.len()
125    }
126    pub fn is_empty(&self) -> bool {
127        self.entries.is_empty()
128    }
129    /// Was this key already compiled? Doesn't change recency.
130    pub fn contains(&self, key: u64) -> bool {
131        self.entries.iter().any(|(k, _)| *k == key)
132    }
133}
134
135// ── Bucketed cache (PLAN L1) ──────────────────────────────────────────
136//
137// Variant of `CompileCache` that compiles one `CompiledGraph` per shape
138// *range* instead of per exact key. The caller declares buckets up front
139// (e.g. `1..16`, `16..64`, `64..256`); each bucket is compiled lazily at
140// its upper bound the first time a key in that bucket arrives.
141//
142// Trade vs `CompileCache`: unique keys → unique compiles becomes unique
143// buckets → unique compiles. The compiled graph is specialized for each
144// bucket's upper-bound dim. Two ways to use it:
145//
146// **Manual padding** — caller drives the pad/slice cycle:
147// ```rust,ignore
148// let buckets = vec![1..16, 16..64, 64..256];
149// let mut cache = BucketedCompileCache::new(Device::Metal, buckets);
150// let (upper, compiled) = cache
151//     .get_or_compile(seq as u64, |max_seq| build_graph(max_seq as usize))
152//     .expect("seq within buckets");
153// // pad input to `upper as usize` elements before run
154// compiled.run(&[("x", &padded)]);
155// ```
156//
157// **`run_padded` shortcut** — cache pads and slices for you:
158// ```rust,ignore
159// let (upper, outputs) = cache.run_padded(
160//     seq as u64,
161//     seq,                                    // actual rows
162//     |max_seq| build_graph(max_seq as usize),
163//     &[("x", &raw_input, hidden)],           // (name, data, inner stride)
164//     &[hidden],                              // per-output inner stride
165// ).expect("in range");
166// ```
167//
168// **How "skip compute" actually works here**: each bucket compiles at
169// its own upper bound, so kernels run at *that* extent, not at some
170// global maximum. Smaller buckets ⇒ less padded compute. The
171// `power_of_two_ladder` constructor builds a logarithmic schedule that
172// guarantees ≤2× padding waste in exchange for `O(log max)` compiled
173// artifacts. For finer control, hand-construct the bucket list.
174//
175// True per-kernel active-extent dispatch (one big compile, runtime
176// extent override that short-circuits each kernel's inner loop) is a
177// per-backend change across `rlx-cuda`, `rlx-rocm`,
178// `rlx-cpu/src/thunk.rs`, `rlx-metal/src/thunk.rs`, `rlx-mlx`,
179// `rlx-wgpu` — multi-day project, not in this layer.
180
/// Compile cache keyed by shape *range*: each declared bucket owns at
/// most one lazily compiled graph, shared by every key inside the range.
pub struct BucketedCompileCache {
    // Device passed to `Session::new` for every bucket compile.
    device: Device,
    // Optional precision policy applied to each compile session.
    policy: Option<rlx_opt::PrecisionPolicy>,
    // Buckets in caller order; `with_policy` asserts they are non-empty
    // and that each one ends at or before the start of the next.
    buckets: Vec<Bucket>,
}
186
/// One shape bucket: a half-open key range plus its compiled graph.
struct Bucket {
    // Keys covered by this bucket; `range.end - 1` is the extent the
    // bucket's graph is built for.
    range: Range<u64>,
    // `None` until the first key in `range` triggers a compile.
    compiled: Option<CompiledGraph>,
}
191
192impl BucketedCompileCache {
193    pub fn new(device: Device, buckets: Vec<Range<u64>>) -> Self {
194        Self::with_policy(device, buckets, None)
195    }
196
197    /// Power-of-two ladder over `[1, max]`, with extents
198    /// `[min_pow2, 2·min_pow2, 4·min_pow2, …, max_pow2]` where
199    /// `min_pow2 = min.next_power_of_two()` and `max_pow2` is the smallest
200    /// power of two ≥ `max`. Each bucket compiles at its upper-bound
201    /// extent, so an `actual` value in bucket `(prev_extent .. ext]` runs
202    /// kernels at extent `ext` (not at the worst case of the whole range).
203    /// Guarantees compute waste from padding ≤2× — `actual > ext / 2`
204    /// for every bucket except possibly the smallest.
205    ///
206    /// Example: `power_of_two_ladder(Device::Cpu, 8, 256)` yields buckets
207    /// `1..9, 9..17, 17..33, 33..65, 65..129, 129..257` with compile
208    /// extents `8, 16, 32, 64, 128, 256`. An `actual = 17` runs at extent
209    /// 32 instead of the 255 a single wide `1..256` bucket would compile
210    /// at — that's the "skip compute" win, paid for with `O(log max)`
211    /// compiled artifacts instead of one.
212    pub fn power_of_two_ladder(device: Device, min: u64, max: u64) -> Self {
213        Self::power_of_two_ladder_with_policy(device, min, max, None)
214    }
215
216    pub fn power_of_two_ladder_with_policy(
217        device: Device,
218        min: u64,
219        max: u64,
220        policy: Option<rlx_opt::PrecisionPolicy>,
221    ) -> Self {
222        assert!(min >= 1, "power_of_two_ladder: min must be ≥ 1, got {min}");
223        assert!(
224            max >= min,
225            "power_of_two_ladder: max ({max}) must be ≥ min ({min})"
226        );
227        let mut buckets: Vec<Range<u64>> = Vec::new();
228        let mut start = 1u64;
229        let mut extent = min.next_power_of_two();
230        loop {
231            buckets.push(start..(extent + 1));
232            if extent >= max {
233                break;
234            }
235            start = extent + 1;
236            extent = extent
237                .checked_mul(2)
238                .expect("power_of_two_ladder: extent overflow");
239        }
240        Self::with_policy(device, buckets, policy)
241    }
242
243    pub fn with_policy(
244        device: Device,
245        buckets: Vec<Range<u64>>,
246        policy: Option<rlx_opt::PrecisionPolicy>,
247    ) -> Self {
248        assert!(!buckets.is_empty(), "BucketedCompileCache needs ≥1 bucket");
249        for (i, b) in buckets.iter().enumerate() {
250            assert!(b.start < b.end, "bucket {i} ({b:?}) is empty");
251            if i + 1 < buckets.len() {
252                assert!(
253                    b.end <= buckets[i + 1].start,
254                    "buckets {i} ({b:?}) and {} ({:?}) overlap",
255                    i + 1,
256                    buckets[i + 1],
257                );
258            }
259        }
260        let buckets = buckets
261            .into_iter()
262            .map(|range| Bucket {
263                range,
264                compiled: None,
265            })
266            .collect();
267        Self {
268            device,
269            policy,
270            buckets,
271        }
272    }
273
274    /// Find the bucket containing `key`, compile if needed, return
275    /// `(upper, &mut CompiledGraph)` where `upper = range.end - 1` is the
276    /// extent the graph was compiled for. Caller pads inputs to `upper`
277    /// before calling `run`. Returns `None` if `key` is outside every
278    /// bucket — caller decides whether to fall back to a one-off compile.
279    ///
280    /// `build` receives `upper` and must return a `Graph` specialized for
281    /// that extent.
282    pub fn get_or_compile<F: FnOnce(u64) -> Graph>(
283        &mut self,
284        key: u64,
285        build: F,
286    ) -> Option<(u64, &mut CompiledGraph)> {
287        self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
288    }
289
290    /// Like [`Self::get_or_compile`] with explicit [`CompileOptions`].
291    pub fn get_or_compile_with_options<F: FnOnce(u64) -> Graph>(
292        &mut self,
293        key: u64,
294        build: F,
295        options: &crate::CompileOptions,
296    ) -> Option<(u64, &mut CompiledGraph)> {
297        let idx = self.bucket_for(key)?;
298        let upper = self.buckets[idx].range.end - 1;
299        if self.buckets[idx].compiled.is_none() {
300            let mut session = Session::new(self.device);
301            if let Some(p) = &self.policy {
302                session = session.with_policy(p.clone());
303            }
304            self.buckets[idx].compiled = Some(session.compile_with(build(upper), options));
305        }
306        Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
307    }
308
309    /// Like [`Self::get_or_compile`] but builds and compiles HIR directly
310    /// through the fusion-first pipeline (`Session::compile_hir`).
311    pub fn get_or_compile_hir<F: FnOnce(u64) -> HirModule>(
312        &mut self,
313        key: u64,
314        build: F,
315    ) -> Option<(u64, &mut CompiledGraph)> {
316        self.get_or_compile_hir_with_options(key, build, &crate::CompileOptions::new())
317    }
318
319    /// Like [`Self::get_or_compile_hir`] with explicit [`CompileOptions`] (tier-1 profile, fusion target, …).
320    pub fn get_or_compile_hir_with_options<F: FnOnce(u64) -> HirModule>(
321        &mut self,
322        key: u64,
323        build: F,
324        options: &crate::CompileOptions,
325    ) -> Option<(u64, &mut CompiledGraph)> {
326        let idx = self.bucket_for(key)?;
327        let upper = self.buckets[idx].range.end - 1;
328        if self.buckets[idx].compiled.is_none() {
329            let mut session = Session::new(self.device);
330            if let Some(p) = &self.policy {
331                session = session.with_policy(p.clone());
332            }
333            let compiled = session
334                .compile_hir_with(build(upper), options)
335                .expect("HIR lower/compile in bucketed cache");
336            self.buckets[idx].compiled = Some(compiled);
337        }
338        Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
339    }
340
341    /// Index of the bucket containing `key`, or `None` if out of range.
342    /// Linear scan — bucket counts are small in practice.
343    pub fn bucket_for(&self, key: u64) -> Option<usize> {
344        self.buckets.iter().position(|b| b.range.contains(&key))
345    }
346
347    pub fn buckets(&self) -> impl Iterator<Item = &Range<u64>> {
348        self.buckets.iter().map(|b| &b.range)
349    }
350
351    /// Number of buckets that have been compiled so far (≤ total buckets).
352    pub fn compiled_count(&self) -> usize {
353        self.buckets.iter().filter(|b| b.compiled.is_some()).count()
354    }
355
356    pub fn total_buckets(&self) -> usize {
357        self.buckets.len()
358    }
359
360    /// "Compile at max, run at less" convenience for inputs and outputs
361    /// whose outer dimension is the bucket key:
362    ///
363    /// 1. Find or compile the bucket containing `key`.
364    /// 2. For each input, pad to `upper` rows along the outer dim using
365    ///    `pad_rows` (caller passes the inner-dim stride per input;
366    ///    `inner = 1` for purely 1D inputs).
367    /// 3. Run the compiled graph at full extent.
368    /// 4. Slice each output back to `actual_rows` along its outer dim.
369    ///    Outputs flagged with `inner = 0` in `output_inners` are
370    ///    returned unsliced (use this for extent-independent outputs
371    ///    like a pooled `[hidden]` embedding). Missing entries past
372    ///    the end of `output_inners` are also returned unsliced.
373    ///
374    /// Returns `(upper, outputs)`. Returns `None` if `key` falls outside
375    /// every bucket.
376    ///
377    /// **Compute scope:** kernels execute at the bucket's compile
378    /// extent (`upper`), not at `actual_rows`. This means smaller
379    /// buckets directly translate to less padded compute. With
380    /// [`power_of_two_ladder`](Self::power_of_two_ladder) the worst-
381    /// case waste is bounded at 2×; with hand-tuned buckets it can be
382    /// arbitrarily tight. True active-extent dispatch — one big
383    /// compile, kernels short-circuit at runtime — is a separate
384    /// per-backend change.
385    pub fn run_padded<F: FnOnce(u64) -> Graph>(
386        &mut self,
387        key: u64,
388        actual_rows: usize,
389        build: F,
390        inputs: &[(&str, &[f32], usize)],
391        output_inners: &[usize],
392    ) -> Option<(u64, Vec<Vec<f32>>)> {
393        let (upper, compiled) = self.get_or_compile(key, build)?;
394
395        // Own the padded buffers so they outlive the borrow handed to `run`.
396        let padded: Vec<(&str, Vec<f32>)> = inputs
397            .iter()
398            .map(|(name, data, inner)| (*name, pad_rows(data, *inner, upper)))
399            .collect();
400        let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
401
402        // Hint active-extent: backends that support per-kernel skip-
403        // compute (today: CPU's Activation thunk family) honor it; the
404        // default trait impl is a no-op, so other backends just process
405        // full extent and the slice_rows below still gives the user
406        // correct outputs.
407        compiled.set_active_extent(Some((actual_rows, upper as usize)));
408        let raw_outputs = compiled.run(&pairs);
409        compiled.set_active_extent(None);
410
411        let outs = raw_outputs
412            .into_iter()
413            .enumerate()
414            .map(|(i, out)| match output_inners.get(i).copied() {
415                Some(0) | None => out,
416                Some(inner) => slice_rows(&out, inner, actual_rows),
417            })
418            .collect();
419
420        Some((upper, outs))
421    }
422}
423
424// ── Dynamic-dim cache (plan #54) ──────────────────────────────────────
425//
426// Compile HIR once through the fusion pipeline (graph may contain
427// `Dim::Dynamic` symbols), then specialize to concrete shapes per cache
428// key and backend-compile the resulting LIR.
429
/// Compile-once / specialize-at-runtime cache for symbolic HIR modules.
///
/// Holds one lazily built template plus up to `capacity` specialized
/// graphs keyed by caller-chosen `u64`s, evicted in insertion order.
pub struct DynamicDimCompileCache {
    // Device used for pipeline construction and backend lookup.
    device: Device,
    // Cache-level precision policy; `get_or_specialize` applies it to the
    // backend compile when the caller's options carry no policy.
    policy: Option<rlx_opt::PrecisionPolicy>,
    // Maximum number of specialized entries; asserted ≥ 1 at construction.
    capacity: usize,
    // Result of compiling the HIR once; `None` until first needed.
    template: Option<CompileResult>,
    // (key, specialized graph) pairs, pushed in insertion order.
    entries: Vec<(u64, CompiledGraph)>,
    // Keys in insertion order; the front is evicted first.
    order: VecDeque<u64>,
}
439
440impl DynamicDimCompileCache {
441    pub fn new(device: Device, capacity: usize) -> Self {
442        Self::with_policy(device, capacity, None)
443    }
444
445    pub fn with_policy(
446        device: Device,
447        capacity: usize,
448        policy: Option<rlx_opt::PrecisionPolicy>,
449    ) -> Self {
450        assert!(capacity > 0, "DynamicDimCompileCache capacity must be ≥ 1");
451        Self {
452            device,
453            policy,
454            capacity,
455            template: None,
456            entries: Vec::with_capacity(capacity),
457            order: VecDeque::with_capacity(capacity),
458        }
459    }
460
461    pub fn compile_device(&self) -> Device {
462        self.device
463    }
464
465    /// Return a backend-compiled graph specialized for `binding`.
466    /// `build_hir` runs at most once to populate the dynamic template.
467    pub fn get_or_specialize<F: FnOnce() -> HirModule>(
468        &mut self,
469        key: u64,
470        binding: &DimBinding,
471        build_hir: F,
472        options: &crate::CompileOptions,
473    ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError> {
474        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
475            return Ok(&mut self.entries[idx].1);
476        }
477        if self.template.is_none() {
478            let mut template_opts = options.clone();
479            template_opts.dim_binding = None;
480            let pipe = crate::stages::pipeline_for(self.device, &template_opts);
481            self.template = Some(pipe.compile_hir(build_hir())?);
482        }
483        let template = self.template.as_ref().expect("template just set");
484        let mut spec_opts = options.clone();
485        spec_opts.dim_binding = None;
486        let pipe = crate::stages::pipeline_for(self.device, &spec_opts);
487        let specialized = template.specialize(&pipe, binding);
488        let backend = crate::registry::backend_for(self.device).expect("backend registered");
489        let mut compile_opts = options.clone();
490        compile_opts.dim_binding = None;
491        if compile_opts.policy.is_none() {
492            if let Some(p) = &self.policy {
493                compile_opts = compile_opts.policy(p.clone());
494            }
495        }
496        let executable = backend.compile_lir(specialized.lir, &compile_opts);
497        let compiled = CompiledGraph::new(executable, self.device);
498
499        if self.entries.len() >= self.capacity
500            && let Some(evict_key) = self.order.pop_front()
501        {
502            self.entries.retain(|(k, _)| *k != evict_key);
503        }
504        self.entries.push((key, compiled));
505        self.order.push_back(key);
506        Ok(&mut self.entries.last_mut().unwrap().1)
507    }
508
509    pub fn len(&self) -> usize {
510        self.entries.len()
511    }
512
513    pub fn is_empty(&self) -> bool {
514        self.entries.is_empty()
515    }
516
517    pub fn contains(&self, key: u64) -> bool {
518        self.entries.iter().any(|(k, _)| *k == key)
519    }
520
521    pub fn has_template(&self) -> bool {
522        self.template.is_some()
523    }
524
525    /// Build the symbolic template once (no specialization).
526    pub fn ensure_template<F: FnOnce() -> HirModule>(
527        &mut self,
528        build_hir: F,
529        options: &crate::CompileOptions,
530    ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
531        if self.template.is_none() {
532            let mut opts = options.clone();
533            opts.dim_binding = None;
534            let pipe = crate::stages::pipeline_for(self.device, &opts);
535            self.template = Some(pipe.compile_hir(build_hir())?);
536        }
537        Ok(self.template.as_ref().expect("template set"))
538    }
539
540    pub fn template_result(&self) -> Option<&CompileResult> {
541        self.template.as_ref()
542    }
543
544    /// Specialize via on-disk LIR cache ([`CompilationMode::Aot`]).
545    /// Disk-backed specialize ([`rlx_ir::CompilationMode::Aot`]).
546    pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
547        &mut self,
548        aot: &crate::AotCache,
549        disk_base: &str,
550        key: u64,
551        binding: &rlx_ir::DimBinding,
552        build_hir: F,
553        options: &crate::CompileOptions,
554    ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
555        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
556            return Ok(&mut self.entries[idx].1);
557        }
558        let device = self.device;
559        let template = self.ensure_template(build_hir, options)?;
560        let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
561        if self.entries.len() >= self.capacity
562            && let Some(evict_key) = self.order.pop_front()
563        {
564            self.entries.retain(|(k, _)| *k != evict_key);
565        }
566        self.entries.push((key, compiled));
567        self.order.push_back(key);
568        Ok(&mut self.entries.last_mut().unwrap().1)
569    }
570}
571
/// Pad `data` (interpreted as `[actual, inner]` row-major) up to `upper`
/// rows by appending zeros. Returns a `Vec<f32>` of length
/// `upper * inner`. Companion of [`slice_rows`] for the
/// "compile at max, run at less" workflow with [`BucketedCompileCache`].
///
/// # Panics
///
/// Panics if `inner == 0`, if `data.len()` is not a multiple of `inner`,
/// if `data.len() / inner > upper`, if `upper` does not fit in `usize`,
/// or if `upper * inner` overflows `usize`.
pub fn pad_rows(data: &[f32], inner: usize, upper: u64) -> Vec<f32> {
    assert!(inner > 0, "pad_rows: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    // Checked conversion: an `as` cast would silently truncate on targets
    // where `usize` is narrower than `u64`.
    let upper = usize::try_from(upper)
        .unwrap_or_else(|_| panic!("pad_rows: upper bound {upper} does not fit in usize"));
    let actual = data.len() / inner;
    assert!(
        actual <= upper,
        "pad_rows: actual rows {actual} exceed upper bound {upper}",
    );
    // Checked multiply: in release builds a wrapped product would yield a
    // buffer of the wrong size instead of failing.
    let padded_len = upper
        .checked_mul(inner)
        .unwrap_or_else(|| panic!("pad_rows: padded length {upper} * {inner} overflows usize"));
    let mut out = vec![0.0_f32; padded_len];
    // `actual * inner == data.len()` by the divisibility check above.
    out[..data.len()].copy_from_slice(data);
    out
}
597
/// Keep the first `actual` rows of `data`, which is read as a row-major
/// `[upper, inner]` buffer, and return them as a new `Vec<f32>` of length
/// `actual * inner`. Inverse of [`pad_rows`].
///
/// # Panics
///
/// Panics if `inner == 0`, if `data.len()` is not a multiple of `inner`,
/// or if `actual` exceeds the number of rows in `data`.
pub fn slice_rows(data: &[f32], inner: usize, actual: usize) -> Vec<f32> {
    assert!(inner > 0, "slice_rows: inner stride must be ≥ 1");
    let total = data.len();
    assert_eq!(
        total % inner,
        0,
        "slice_rows: data len {} not a multiple of inner {inner}",
        total,
    );
    let upper = total / inner;
    assert!(
        actual <= upper,
        "slice_rows: actual rows {actual} exceed upper {upper}",
    );
    // `actual ≤ upper` bounds the product by `data.len()`, so it cannot overflow.
    let keep = actual * inner;
    data.iter().take(keep).copied().collect()
}
618
619#[cfg(test)]
620mod tests {
621    use super::*;
622    use rlx_ir::infer::GraphExt;
623    use rlx_ir::*;
624    use std::cell::Cell;
625
626    fn tiny_graph(n: usize) -> Graph {
627        let mut g = Graph::new("t");
628        let f = DType::F32;
629        let x = g.input("x", Shape::new(&[n], f));
630        let y = g.activation(rlx_ir::op::Activation::Relu, x, Shape::new(&[n], f));
631        g.set_outputs(vec![y]);
632        g
633    }
634
635    #[test]
636    fn cache_hits_avoid_recompile() {
637        let mut cache = CompileCache::new(Device::Cpu, 4);
638        let calls = Cell::new(0);
639
640        let _ = cache.get_or_compile(1, || {
641            calls.set(calls.get() + 1);
642            tiny_graph(8)
643        });
644        let _ = cache.get_or_compile(1, || {
645            calls.set(calls.get() + 1);
646            tiny_graph(8)
647        });
648        let _ = cache.get_or_compile(1, || {
649            calls.set(calls.get() + 1);
650            tiny_graph(8)
651        });
652        // Same key three times: build closure runs once.
653        assert_eq!(calls.get(), 1);
654        assert_eq!(cache.len(), 1);
655    }
656
657    #[test]
658    fn fifo_evicts_oldest_at_capacity() {
659        let mut cache = CompileCache::new(Device::Cpu, 2);
660        let _ = cache.get_or_compile(1, || tiny_graph(4));
661        let _ = cache.get_or_compile(2, || tiny_graph(8));
662        assert!(cache.contains(1) && cache.contains(2));
663        // Third entry evicts key 1 (oldest).
664        let _ = cache.get_or_compile(3, || tiny_graph(16));
665        assert!(!cache.contains(1));
666        assert!(cache.contains(2) && cache.contains(3));
667    }
668
669    #[test]
670    fn different_keys_keep_separate_compiles() {
671        let mut cache = CompileCache::new(Device::Cpu, 4);
672        let calls = Cell::new(0);
673        let _ = cache.get_or_compile(1, || {
674            calls.set(calls.get() + 1);
675            tiny_graph(8)
676        });
677        let _ = cache.get_or_compile(2, || {
678            calls.set(calls.get() + 1);
679            tiny_graph(16)
680        });
681        let _ = cache.get_or_compile(1, || {
682            calls.set(calls.get() + 1);
683            tiny_graph(8)
684        });
685        // Two unique keys → two compiles.
686        assert_eq!(calls.get(), 2);
687        assert_eq!(cache.len(), 2);
688    }
689
690    // ── BucketedCompileCache ──────────────────────────────────────────
691
692    #[test]
693    fn bucket_amortizes_keys_within_range() {
694        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
695        let calls = Cell::new(0);
696        let uppers = Cell::new((0u64, 0u64));
697
698        // Two distinct keys (2 and 3) both fall inside bucket 0 (1..4).
699        let (u1, _) = cache
700            .get_or_compile(2, |upper| {
701                calls.set(calls.get() + 1);
702                uppers.set((upper, uppers.get().1));
703                tiny_graph(upper as usize)
704            })
705            .expect("key 2 in range");
706        let (u2, _) = cache
707            .get_or_compile(3, |upper| {
708                calls.set(calls.get() + 1);
709                uppers.set((uppers.get().0, upper));
710                tiny_graph(upper as usize)
711            })
712            .expect("key 3 in range");
713
714        // One compile, both calls saw the same upper = range.end - 1 = 3.
715        assert_eq!(calls.get(), 1);
716        assert_eq!(u1, 3);
717        assert_eq!(u2, 3);
718        assert_eq!(uppers.get().0, 3);
719        assert_eq!(cache.compiled_count(), 1);
720        assert_eq!(cache.total_buckets(), 2);
721    }
722
723    #[test]
724    fn bucket_lookup_returns_none_outside_range() {
725        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
726        assert!(cache.bucket_for(0).is_none());
727        assert!(cache.bucket_for(16).is_none());
728        assert!(cache.bucket_for(100).is_none());
729        assert_eq!(cache.bucket_for(3), Some(0));
730        assert_eq!(cache.bucket_for(4), Some(1));
731
732        let calls = Cell::new(0);
733        let result = cache.get_or_compile(100, |u| {
734            calls.set(calls.get() + 1);
735            tiny_graph(u as usize)
736        });
737        assert!(result.is_none());
738        assert_eq!(calls.get(), 0); // build closure must not run for OOR keys
739        assert_eq!(cache.compiled_count(), 0);
740    }
741
742    #[test]
743    fn bucket_compiles_lazily_per_bucket() {
744        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16, 16..64]);
745        let calls = Cell::new(0);
746
747        let _ = cache.get_or_compile(2, |u| {
748            calls.set(calls.get() + 1);
749            tiny_graph(u as usize)
750        });
751        let _ = cache.get_or_compile(8, |u| {
752            calls.set(calls.get() + 1);
753            tiny_graph(u as usize)
754        });
755        // Two distinct buckets hit → two compiles. Third bucket untouched.
756        assert_eq!(calls.get(), 2);
757        assert_eq!(cache.compiled_count(), 2);
758        assert_eq!(cache.total_buckets(), 3);
759    }
760
761    #[test]
762    #[should_panic(expected = "overlap")]
763    fn bucket_overlap_rejected() {
764        let _ = BucketedCompileCache::new(Device::Cpu, vec![1..8, 4..16]);
765    }
766
767    #[test]
768    #[should_panic(expected = "≥1 bucket")]
769    fn empty_bucket_list_rejected() {
770        let _ = BucketedCompileCache::new(Device::Cpu, vec![]);
771    }
772
773    // ── pad_rows / slice_rows ─────────────────────────────────────────
774
775    #[test]
776    fn pad_rows_appends_zeros() {
777        // 1D: actual=3 → upper=5, inner=1.
778        let p = pad_rows(&[1.0, 2.0, 3.0], 1, 5);
779        assert_eq!(p, vec![1.0, 2.0, 3.0, 0.0, 0.0]);
780
781        // 2D row-major [actual=2, inner=3] → [upper=4, inner=3].
782        let p = pad_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 4);
783        assert_eq!(
784            p,
785            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
786        );
787
788        // actual == upper: no-op pad.
789        let p = pad_rows(&[7.0, 8.0], 1, 2);
790        assert_eq!(p, vec![7.0, 8.0]);
791    }
792
793    #[test]
794    fn slice_rows_truncates_trailing() {
795        let s = slice_rows(&[1.0, 2.0, 3.0, 0.0, 0.0], 1, 3);
796        assert_eq!(s, vec![1.0, 2.0, 3.0]);
797
798        let s = slice_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0], 3, 2);
799        assert_eq!(s, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
800    }
801
802    #[test]
803    #[should_panic(expected = "exceed upper")]
804    fn pad_rows_rejects_too_long_input() {
805        let _ = pad_rows(&[1.0, 2.0, 3.0, 4.0], 1, 3);
806    }
807
808    #[test]
809    #[should_panic(expected = "exceed upper")]
810    fn slice_rows_rejects_too_large_actual() {
811        let _ = slice_rows(&[1.0, 2.0, 3.0], 1, 5);
812    }
813
814    // ── BucketedCompileCache::run_padded ──────────────────────────────
815
816    #[test]
817    fn run_padded_pads_input_and_slices_output() {
818        // tiny_graph is 1D [n] → relu → [n].
819        // Compile bucket [1..16) at upper=15, run with actual_rows=10.
820        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
821        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0];
822
823        let (upper, outs) = cache
824            .run_padded(
825                10, // key
826                10, // actual rows
827                |max| tiny_graph(max as usize),
828                &[("x", &input, 1)], // 1D, inner stride 1
829                &[1],                // slice the one output to actual rows
830            )
831            .expect("key 10 in [1..16)");
832
833        assert_eq!(upper, 15);
834        assert_eq!(outs.len(), 1);
835        let out = &outs[0];
836        assert_eq!(out.len(), 10, "output sliced back to actual_rows");
837        let expected: Vec<f32> = input.iter().map(|x| x.max(0.0)).collect();
838        assert_eq!(out, &expected);
839    }
840
841    #[test]
842    fn run_padded_reuses_bucket_across_actuals() {
843        // Same bucket, two different actuals — only one compile.
844        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
845        let calls = Cell::new(0);
846
847        let (u1, o1) = cache
848            .run_padded(
849                10,
850                10,
851                |max| {
852                    calls.set(calls.get() + 1);
853                    tiny_graph(max as usize)
854                },
855                &[(
856                    "x",
857                    &[1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0],
858                    1,
859                )],
860                &[1],
861            )
862            .unwrap();
863        assert_eq!(o1.len(), 1);
864        assert_eq!(o1[0].len(), 10);
865        assert_eq!(u1, 15);
866
867        let (u2, o2) = cache
868            .run_padded(
869                5,
870                5,
871                |max| {
872                    calls.set(calls.get() + 1);
873                    tiny_graph(max as usize)
874                },
875                &[("x", &[-1.0, 2.0, -3.0, 4.0, -5.0], 1)],
876                &[1],
877            )
878            .unwrap();
879        assert_eq!(o2.len(), 1);
880        assert_eq!(o2[0].len(), 5);
881        assert_eq!(u2, 15);
882        assert_eq!(o2[0], vec![0.0, 2.0, 0.0, 4.0, 0.0]);
883
884        assert_eq!(calls.get(), 1, "bucket cached across actuals");
885        assert_eq!(cache.compiled_count(), 1);
886    }
887
888    #[test]
889    fn run_padded_returns_none_out_of_range() {
890        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
891        let calls = Cell::new(0);
892        let result = cache.run_padded(
893            100,
894            5,
895            |u| {
896                calls.set(calls.get() + 1);
897                tiny_graph(u as usize)
898            },
899            &[("x", &[1.0, 2.0, 3.0, 4.0, 5.0], 1)],
900            &[1],
901        );
902        assert!(result.is_none());
903        assert_eq!(calls.get(), 0);
904        assert_eq!(cache.compiled_count(), 0);
905    }
906
907    // ── power_of_two_ladder ───────────────────────────────────────────
908
909    #[test]
910    fn power_of_two_ladder_generates_log_buckets() {
911        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
912        // Expect buckets covering keys 1..=64 with extents 8, 16, 32, 64.
913        let ranges: Vec<_> = cache.buckets().cloned().collect();
914        assert_eq!(ranges, vec![1..9, 9..17, 17..33, 33..65]);
915        assert_eq!(cache.total_buckets(), 4);
916    }
917
918    #[test]
919    fn power_of_two_ladder_picks_smallest_extent_for_actual() {
920        // Ladder: extents 8, 16, 32, 64. actual=17 lands in the 32-extent
921        // bucket, NOT the 64-extent one — that's the compute saving.
922        let mut cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
923        let captured_uppers: std::cell::RefCell<Vec<u64>> = Default::default();
924
925        let (u17, _) = cache
926            .get_or_compile(17, |upper| {
927                captured_uppers.borrow_mut().push(upper);
928                tiny_graph(upper as usize)
929            })
930            .unwrap();
931        let (u9, _) = cache
932            .get_or_compile(9, |upper| {
933                captured_uppers.borrow_mut().push(upper);
934                tiny_graph(upper as usize)
935            })
936            .unwrap();
937        let (u3, _) = cache
938            .get_or_compile(3, |upper| {
939                captured_uppers.borrow_mut().push(upper);
940                tiny_graph(upper as usize)
941            })
942            .unwrap();
943        let (u64_, _) = cache
944            .get_or_compile(64, |upper| {
945                captured_uppers.borrow_mut().push(upper);
946                tiny_graph(upper as usize)
947            })
948            .unwrap();
949
950        assert_eq!(u17, 32, "key=17 → smallest extent ≥ 17 is 32");
951        assert_eq!(u9, 16, "key=9  → smallest extent ≥ 9  is 16");
952        assert_eq!(u3, 8, "key=3  → smallest extent ≥ 3  is 8");
953        assert_eq!(u64_, 64, "key=64 → exact match at 64");
954        assert_eq!(*captured_uppers.borrow(), vec![32, 16, 8, 64]);
955        assert_eq!(cache.compiled_count(), 4);
956    }
957
958    #[test]
959    fn power_of_two_ladder_min_above_one_starts_at_one() {
960        // First bucket always covers from key 1, even when min > 1.
961        // (`min` controls the ladder's first extent, not the lower edge.)
962        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 32);
963        let ranges: Vec<_> = cache.buckets().cloned().collect();
964        // min=16 → first extent 16, second 32. Buckets: 1..17, 17..33.
965        assert_eq!(ranges, vec![1..17, 17..33]);
966    }
967
968    #[test]
969    fn power_of_two_ladder_non_pow2_min_rounds_up() {
970        // min=10 → next_power_of_two = 16.
971        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 10, 64);
972        let ranges: Vec<_> = cache.buckets().cloned().collect();
973        assert_eq!(ranges, vec![1..17, 17..33, 33..65]);
974    }
975
976    #[test]
977    fn power_of_two_ladder_max_below_pow2_extends_up() {
978        // max=20 needs to be covered → ladder extends to 32.
979        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 20);
980        let ranges: Vec<_> = cache.buckets().cloned().collect();
981        assert_eq!(ranges, vec![1..9, 9..17, 17..33]);
982    }
983
984    #[test]
985    fn power_of_two_ladder_min_equals_max() {
986        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 16);
987        let ranges: Vec<_> = cache.buckets().cloned().collect();
988        assert_eq!(ranges, vec![1..17]);
989    }
990
991    #[test]
992    #[should_panic(expected = "min must be ≥ 1")]
993    fn power_of_two_ladder_zero_min_rejected() {
994        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 0, 16);
995    }
996
997    #[test]
998    #[should_panic(expected = "max")]
999    fn power_of_two_ladder_max_below_min_rejected() {
1000        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 32, 8);
1001    }
1002
1003    // ── Active-extent dispatch (true per-kernel skip-compute) ─────────
1004    //
    // The three `active_extent_skips_compute_on_*` tests in this section
    // assert per-thunk active-extent scaling on the CPU backend. Today
    // `rlx_cpu::thunk::execute_thunks_active` is documented as a stub that
    // returns false (rlx-cpu/src/thunk.rs:2100-2110), so the runtime falls
    // back to full-extent dispatch, which overwrites the tail and makes the
    // tail-preservation assertions fail. The tests are kept here (marked
    // `#[ignore]`) as the test-driven contract that the future
    // active-extent implementation must satisfy. Drop the `#[ignore]`
    // when the per-thunk scaling lands for Copy / ActivationInPlace /
    // BinaryFull / Attention.
1014
1015    #[test]
1016    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
1017    fn active_extent_skips_compute_on_cpu_activation() {
1018        // tiny_graph(15) is `Input([15]) → Relu → Output` and lowers to
1019        // a Copy + ActivationInPlace pair on CPU — both are in the safe
1020        // set, so the active-extent path runs scaled.
1021        //
1022        // To prove kernels actually skipped: warm the arena with a prior
1023        // full-extent run whose output is `[1.0; 15]`, then run again
1024        // with a negative-only input and active=5. The first 5 outputs
1025        // get re-copied + re-relu'd to 0; the tail (indices 5..15) stays
1026        // at 1.0 because both Copy and Activation skipped it. A full-
1027        // extent fallback would clip every element to 0.
1028        let graph = tiny_graph(15);
1029        let mut compiled = Session::new(Device::Cpu).compile(graph);
1030
1031        // Warm-up: full extent, all-positive input → output [1.0; 15].
1032        let warm_input: Vec<f32> = vec![1.0; 15];
1033        let warm_outs = compiled.run(&[("x", &warm_input)]);
1034        assert_eq!(warm_outs[0], vec![1.0; 15], "warm-up sanity");
1035
1036        // Active-extent run: all-negative input, hint actual=5 of 15.
1037        // First 5: Copy(-1) + Relu → 0. Tail: kernels skip → stays 1.0.
1038        let neg_input: Vec<f32> = vec![-1.0; 15];
1039        compiled.set_active_extent(Some((5, 15)));
1040        let outs = compiled.run(&[("x", &neg_input)]);
1041        let out = &outs[0];
1042
1043        assert_eq!(out.len(), 15);
1044        assert_eq!(
1045            out[..5],
1046            [0.0; 5],
1047            "first 5 elements processed (relu of -1)"
1048        );
1049        assert_eq!(
1050            out[5..],
1051            [1.0; 10],
1052            "tail untouched — proves Copy + Activation skipped indices 5..15"
1053        );
1054
1055        // Clear the hint and run again with the negative input — full
1056        // extent now processes everything, every element clips to 0.
1057        compiled.set_active_extent(None);
1058        let outs = compiled.run(&[("x", &neg_input)]);
1059        assert_eq!(
1060            outs[0],
1061            vec![0.0; 15],
1062            "full-extent path must clip every negative"
1063        );
1064    }
1065
1066    #[test]
1067    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
1068    fn active_extent_skips_compute_on_binary_full() {
1069        // Input([4]) + Input([4]) → Output. Lowers to a BinaryFull
1070        // thunk with no broadcast (lhs_len == rhs_len == len), which
1071        // is in the safe set.
1072        let mut g = Graph::new("add");
1073        let f = DType::F32;
1074        let a = g.input("a", Shape::new(&[4], f));
1075        let b = g.input("b", Shape::new(&[4], f));
1076        let c = g.add(a, b);
1077        g.set_outputs(vec![c]);
1078        let mut compiled = Session::new(Device::Cpu).compile(g);
1079
1080        // Warm: full extent, output buffer becomes [2.0; 4].
1081        let warm = compiled.run(&[("a", &[1.0f32; 4]), ("b", &[1.0f32; 4])]);
1082        assert_eq!(warm[0], vec![2.0; 4]);
1083
1084        // Active-extent run: actual=2 of upper=4. Process first 2
1085        // elements only; tail (indices 2..4) stays at 2.0 from warm.
1086        compiled.set_active_extent(Some((2, 4)));
1087        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
1088        let out = &outs[0];
1089        assert_eq!(out[..2], [20.0, 20.0], "first 2 = active sum");
1090        assert_eq!(
1091            out[2..],
1092            [2.0, 2.0],
1093            "tail untouched — proves BinaryFull skipped indices 2..4"
1094        );
1095
1096        // Clear hint → full path overwrites entire output.
1097        compiled.set_active_extent(None);
1098        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
1099        assert_eq!(outs[0], vec![20.0; 4]);
1100    }
1101
1102    #[test]
1103    #[ignore = "process-wide STATE; runs only in isolation via `cargo test perfetto -- --ignored`"]
1104    fn perfetto_trace_emits_per_thunk_events() {
1105        // PLAN L3: end-to-end Perfetto event capture. Requires the env
1106        // var to be set BEFORE the perfetto module is first touched
1107        // (OnceLock — can't re-init). We set it here unconditionally;
1108        // for tests run in parallel within the same process, the
1109        // earliest test wins. To avoid flake we mark this `#[ignore]`
1110        // and the developer runs it explicitly.
1111        use std::env;
1112        use std::fs;
1113        let path = env::temp_dir().join(format!("rlx-perfetto-e2e-{}.json", std::process::id()));
1114        if path.exists() {
1115            let _ = fs::remove_file(&path);
1116        }
1117        unsafe {
1118            env::set_var("RLX_TRACE_PERFETTO", &path);
1119        }
1120
1121        // Build + run a small CPU graph — Add → Relu (no fusion macros).
1122        let f = DType::F32;
1123        let mut g = Graph::new("perf");
1124        let a = g.input("a", Shape::new(&[4], f));
1125        let b = g.input("b", Shape::new(&[4], f));
1126        let s = g.add(a, b);
1127        let r = g.relu(s);
1128        g.set_outputs(vec![r]);
1129        let mut compiled = Session::new(Device::Cpu).compile(g);
1130        let _ = compiled.run(&[("a", &[1.0; 4]), ("b", &[1.0; 4])]);
1131
1132        // Force the trace file to flush its closing bracket.
1133        crate::perfetto::flush_and_finalize();
1134
1135        let contents = fs::read_to_string(&path).expect("trace file");
1136        // At minimum we should see one of our thunk names.
1137        assert!(
1138            contents.contains("\"binary\"")
1139                || contents.contains("\"activation\"")
1140                || contents.contains("\"elementwise_region\""),
1141            "expected at least one thunk-name event in perfetto trace; got: {contents}"
1142        );
1143        // JSON shape: starts with `[` and (after flush) ends with `]`.
1144        assert!(contents.trim_start().starts_with('['));
1145        let _ = fs::remove_file(&path);
1146    }
1147
1148    #[test]
1149    fn elementwise_region_fused_matches_unfused() {
1150        // PLAN L2: a chain `Add(a, b) → Mul(_, c) → Relu` should fuse
1151        // into one ElementwiseRegion thunk in the CPU backend. Compare
1152        // its output against the value computed by hand to confirm the
1153        // fused execution is numerically identical.
1154        let f = DType::F32;
1155        let mut g = Graph::new("ew_e2e");
1156        let a = g.input("a", Shape::new(&[8], f));
1157        let b = g.input("b", Shape::new(&[8], f));
1158        let c = g.input("c", Shape::new(&[8], f));
1159        let s = Shape::new(&[8], f);
1160        let add = g.add(a, b);
1161        let mul = g.mul(add, c);
1162        let relu = g.relu(mul);
1163        let _ = s;
1164        g.set_outputs(vec![relu]);
1165
1166        let mut compiled = Session::new(Device::Cpu).compile(g);
1167        let av: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0, 0.5, -0.5, 1.5, -1.5];
1168        let bv: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0, 0.5, 0.5, 0.5, 0.5];
1169        let cv: Vec<f32> = vec![1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 0.5, 4.0];
1170        let outs = compiled.run(&[("a", &av), ("b", &bv), ("c", &cv)]);
1171        let out = &outs[0];
1172
1173        let expected: Vec<f32> = (0..8)
1174            .map(|i| {
1175                let v = (av[i] + bv[i]) * cv[i];
1176                v.max(0.0)
1177            })
1178            .collect();
1179        for (i, (got, exp)) in out.iter().zip(&expected).enumerate() {
1180            assert!(
1181                (got - exp).abs() < 1e-6,
1182                "mismatch at {i}: got {got}, expected {exp}"
1183            );
1184        }
1185    }
1186
1187    #[test]
1188    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
1189    fn active_extent_skips_compute_on_attention() {
1190        // Standalone Attention with kernel-synthesized MaskKind::None.
1191        // Q/K/V shape: [batch=1, seq=4, num_heads*head_dim=8].
1192        use rlx_ir::op::MaskKind;
1193        let f = DType::F32;
1194        let mut g = Graph::new("attn");
1195        let q = g.input("q", Shape::new(&[1, 4, 8], f));
1196        let k = g.input("k", Shape::new(&[1, 4, 8], f));
1197        let v = g.input("v", Shape::new(&[1, 4, 8], f));
1198        let out = g.attention_kind(q, k, v, 2, 4, MaskKind::None, Shape::new(&[1, 4, 8], f));
1199        g.set_outputs(vec![out]);
1200        let mut compiled = Session::new(Device::Cpu).compile(g);
1201
1202        // Warm: full extent. Q=K=V uniform → output uniform-ish.
1203        let warm = compiled.run(&[
1204            ("q", &[1.0f32; 32]),
1205            ("k", &[1.0f32; 32]),
1206            ("v", &[1.0f32; 32]),
1207        ]);
1208        let warm_out = warm[0].clone();
1209        assert_eq!(warm_out.len(), 32);
1210
1211        // Active: s_active=2 of s_full=4. Different inputs.
1212        // Tail rows (indices 16..32 = positions 2,3) should be untouched
1213        // — preserved from the warm run. First 16 indices recomputed.
1214        compiled.set_active_extent(Some((2, 4)));
1215        let outs = compiled.run(&[
1216            ("q", &[3.0f32; 32]),
1217            ("k", &[3.0f32; 32]),
1218            ("v", &[3.0f32; 32]),
1219        ]);
1220        let out = &outs[0];
1221        assert_eq!(out.len(), 32);
1222        assert_eq!(
1223            &out[16..],
1224            &warm_out[16..],
1225            "tail (positions 2,3) must be untouched — proves Attention skipped"
1226        );
1227        // Sanity: first 2 positions changed since input value differs (3.0 vs 1.0).
1228        assert_ne!(
1229            &out[..16],
1230            &warm_out[..16],
1231            "first 2 positions should reflect new input"
1232        );
1233    }
1234
1235    #[test]
1236    fn active_extent_falls_back_when_unsupported_thunk_in_schedule() {
1237        // A graph containing any thunk outside `safe_for_active_extent`
1238        // (e.g. Sgemm via a matmul) must fall back to the full-extent
1239        // executor — partial application would feed garbage downstream.
1240        // We can't easily construct such a graph at this layer without
1241        // pulling in matmul builders, but we can verify the trait
1242        // contract via the simpler check: setting an extent hint on a
1243        // matmul-bearing graph still gives correct outputs (full-extent
1244        // fallback path was taken).
1245        //
1246        // Skipped explicit construction here — the safety net is the
1247        // `if !all(safe) return false` guard inside execute_thunks_active
1248        // plus the `if !active_used { execute_thunks(...) }` fallback in
1249        // the CPU executor, both unit-tested via direct safety-predicate
1250        // and the warm-arena test above.
1251    }
1252
1253    #[test]
1254    fn run_padded_uses_active_extent_on_cpu() {
1255        // End-to-end: the cache wires set_active_extent before run.
1256        // Same setup as above but driven through run_padded.
1257        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
1258        let input: Vec<f32> = vec![
1259            1.0, -1.0, 2.0, -2.0, 3.0, // 5 real values
1260            -10.0, -20.0, -30.0, -40.0, -50.0, // padding zeros from pad_rows
1261        ];
1262        // pad_rows zero-pads from len=5 up to upper=15, so the arena
1263        // tail past index 5 is 0.0 going in. After active-extent run,
1264        // tail stays at 0.0 (untouched, but the value happens to match
1265        // what relu would produce). We can't observe skip via output
1266        // here — slice_rows trims to actual_rows anyway.
1267        let (upper, outs) = cache
1268            .run_padded(
1269                5,
1270                5,
1271                |max| tiny_graph(max as usize),
1272                &[("x", &input[..5], 1)],
1273                &[1],
1274            )
1275            .unwrap();
1276        assert_eq!(upper, 15);
1277        assert_eq!(outs[0].len(), 5);
1278        // Active-extent path (CPU honors): outputs match relu of the
1279        // first 5 inputs. Slicing already handled, so user-visible
1280        // result is the same whether or not the kernel skipped tail
1281        // compute. The point of this test is just to confirm the wiring
1282        // path doesn't crash and produces correct outputs end-to-end.
1283        assert_eq!(outs[0], vec![1.0, 0.0, 2.0, 0.0, 3.0]);
1284    }
1285
1286    #[test]
1287    fn run_padded_inner_zero_returns_output_unsliced() {
1288        // Marking output_inners[0] = 0 disables slicing for that output.
1289        // The compiled graph still runs at upper=15, so we expect 15 outputs back.
1290        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
1291        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0];
1292
1293        let (upper, outs) = cache
1294            .run_padded(
1295                5,
1296                5,
1297                |max| tiny_graph(max as usize),
1298                &[("x", &input, 1)],
1299                &[0], // don't slice this output
1300            )
1301            .unwrap();
1302
1303        assert_eq!(upper, 15);
1304        assert_eq!(
1305            outs[0].len(),
1306            15,
1307            "unsliced output preserves full upper extent"
1308        );
1309        // First 5 = relu of input, tail 10 = relu(0) = 0.
1310        assert_eq!(&outs[0][..5], &[1.0, 0.0, 2.0, 0.0, 3.0]);
1311        assert!(outs[0][5..].iter().all(|&v| v == 0.0));
1312    }
1313
1314    #[test]
1315    fn dynamic_dim_cache_specializes_per_key() {
1316        use rlx_ir::DType;
1317        use rlx_ir::Shape;
1318        use rlx_ir::hir::HirModule;
1319        use rlx_ir::sym;
1320
1321        let mut cache = DynamicDimCompileCache::new(Device::Cpu, 4);
1322        let opts = crate::CompileOptions::new();
1323        {
1324            let _short = cache
1325                .get_or_specialize(
1326                    8,
1327                    &rlx_ir::DimBinding::batch_seq(1, 8),
1328                    || {
1329                        let mut hir = HirModule::new("dyn_cache");
1330                        let x = hir.input_batch_seq("x", sym::BATCH, sym::SEQ, 4, DType::F32);
1331                        let w = hir.param("w", Shape::new(&[4, 2], DType::F32));
1332                        let y = hir.linear(
1333                            x,
1334                            w,
1335                            None,
1336                            None,
1337                            Shape::batch_seq(sym::BATCH, sym::SEQ, 2, DType::F32),
1338                        );
1339                        hir.set_outputs(vec![y]);
1340                        hir
1341                    },
1342                    &opts,
1343                )
1344                .expect("specialize short");
1345        }
1346        assert!(cache.has_template());
1347        assert_eq!(cache.len(), 1);
1348        cache
1349            .get_or_specialize(
1350                128,
1351                &rlx_ir::DimBinding::batch_seq(1, 128),
1352                || panic!("HIR builder must not run twice"),
1353                &opts,
1354            )
1355            .expect("specialize long");
1356        assert_eq!(cache.len(), 2);
1357    }
1358}