Skip to main content

rlx_runtime/
compile_cache.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shape-bucketed compile cache.
17//!
18//! Lets variable-shape callers (e.g., embedding-model wrappers that vary
19//! batch + seq per request) amortize the per-(shape) compile cost. Cache
20//! keys are caller-provided `u64`s — the caller decides what counts as a
21//! shape bucket. Typical recipe: `(batch as u64) << 32 | seq as u64`.
22//!
23//! The cache stores one `CompiledGraph` per key. Params loaded onto a
24//! cached entry persist for that entry — re-fetching from cache does
25//! **not** require re-running `set_param`. Eviction is FIFO, capped at
26//! `capacity` entries (good enough for the current "a handful of common
27//! shapes" usage pattern; switch to LRU if a real workload shows churn).
28//!
29//! # Example
30//!
31//! ```rust,ignore
32//! let mut cache = CompileCache::new(Device::Metal, 8);
33//! let key = ((batch as u64) << 32) | seq as u64;
34//! let mut compiled = cache.get_or_compile(key, || build_my_graph(batch, seq));
35//! // First call for `key`: compiles. Subsequent calls: cache hit.
36//! compiled.run(&[("x", &input_data)]);
37//! ```
38
39use crate::{CompiledGraph, Device, Session};
40use rlx_ir::DimBinding;
41use rlx_ir::Graph;
42use rlx_ir::hir::HirModule;
43use rlx_opt::CompileResult;
44use std::collections::HashMap;
45use std::collections::VecDeque;
46use std::ops::Range;
47
/// Named runtime input for [`BucketedCompileCache::run_padded_mixed`].
///
/// The struct only borrows its contents, so it is cheap to copy and can be
/// printed for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct CacheRunInput<'a> {
    /// Input name handed to the compiled graph together with `data`.
    pub name: &'a str,
    /// Raw `f32` values for this input.
    pub data: &'a [f32],
    /// Row inner stride for [`pad_rows`]; `None` = use data as-is (no padding).
    pub row_inner: Option<usize>,
}
55
56pub struct CompileCache {
57    device: Device,
58    capacity: usize,
59    // Per-cache precision policy. None → default (F32). Set once at
60    // construction; applies to every compile this cache performs.
61    policy: Option<rlx_opt::PrecisionPolicy>,
62    // (key, compiled). Vec keeps insertion order for FIFO eviction; the
63    // expected hit-rate at our cap (~8) makes the linear scan cheaper
64    // than a HashMap + separate eviction list.
65    entries: Vec<(u64, CompiledGraph)>,
66    // Insertion order for eviction.
67    order: VecDeque<u64>,
68}
69
70impl CompileCache {
71    pub fn new(device: Device, capacity: usize) -> Self {
72        Self::with_policy(device, capacity, None)
73    }
74
75    /// Cache that compiles every entry with the given precision policy.
76    /// Use this when the cached entries should differ from CPU-default
77    /// F32 — e.g., `PrecisionPolicy::AutoMixed` for f16 compute on Metal.
78    pub fn with_policy(
79        device: Device,
80        capacity: usize,
81        policy: Option<rlx_opt::PrecisionPolicy>,
82    ) -> Self {
83        assert!(capacity > 0, "CompileCache capacity must be ≥ 1");
84        Self {
85            device,
86            capacity,
87            policy,
88            entries: Vec::with_capacity(capacity),
89            order: VecDeque::with_capacity(capacity),
90        }
91    }
92
93    /// Compile if not present, then return a mutable reference. The borrow
94    /// lifetime is tied to `&mut self` so callers naturally serialize their
95    /// use of any one entry — the cache is single-owner today.
96    pub fn get_or_compile<F: FnOnce() -> Graph>(
97        &mut self,
98        key: u64,
99        build: F,
100    ) -> &mut CompiledGraph {
101        self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
102    }
103
104    /// Like [`Self::get_or_compile`] with explicit [`CompileOptions`].
105    pub fn get_or_compile_with_options<F: FnOnce() -> Graph>(
106        &mut self,
107        key: u64,
108        build: F,
109        options: &crate::CompileOptions,
110    ) -> &mut CompiledGraph {
111        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
112            return &mut self.entries[idx].1;
113        }
114        let mut session = Session::new(self.device);
115        if let Some(p) = &self.policy {
116            session = session.with_policy(p.clone());
117        }
118        let compiled = session.compile_with(build(), options);
119
120        // Evict FIFO if at capacity.
121        if self.entries.len() >= self.capacity
122            && let Some(evict_key) = self.order.pop_front()
123        {
124            self.entries.retain(|(k, _)| *k != evict_key);
125        }
126        self.entries.push((key, compiled));
127        self.order.push_back(key);
128        &mut self.entries.last_mut().unwrap().1
129    }
130
131    /// Number of entries currently cached. Useful for tests + diagnostics.
132    pub fn len(&self) -> usize {
133        self.entries.len()
134    }
135    pub fn is_empty(&self) -> bool {
136        self.entries.is_empty()
137    }
138    /// Was this key already compiled? Doesn't change recency.
139    pub fn contains(&self, key: u64) -> bool {
140        self.entries.iter().any(|(k, _)| *k == key)
141    }
142}
143
144// ── Bucketed cache (PLAN L1) ──────────────────────────────────────────
145//
146// Variant of `CompileCache` that compiles one `CompiledGraph` per shape
147// *range* instead of per exact key. The caller declares buckets up front
148// (e.g. `1..16`, `16..64`, `64..256`); each bucket is compiled lazily at
149// its upper bound the first time a key in that bucket arrives.
150//
// Trade-off vs `CompileCache`: instead of one compile per unique key,
// this cache performs one compile per bucket. Each compiled graph is
// specialized for its bucket's upper-bound dim. Two ways to use it:
154//
155// **Manual padding** — caller drives the pad/slice cycle:
156// ```rust,ignore
157// let buckets = vec![1..16, 16..64, 64..256];
158// let mut cache = BucketedCompileCache::new(Device::Metal, buckets);
159// let (upper, compiled) = cache
160//     .get_or_compile(seq as u64, |max_seq| build_graph(max_seq as usize))
161//     .expect("seq within buckets");
162// // pad input to `upper as usize` elements before run
163// compiled.run(&[("x", &padded)]);
164// ```
165//
166// **`run_padded` shortcut** — cache pads and slices for you:
167// ```rust,ignore
168// let (upper, outputs) = cache.run_padded(
169//     seq as u64,
170//     seq,                                    // actual rows
171//     |max_seq| build_graph(max_seq as usize),
172//     &[("x", &raw_input, hidden)],           // (name, data, inner stride)
173//     &[hidden],                              // per-output inner stride
174// ).expect("in range");
175// ```
176//
177// **How "skip compute" actually works here**: each bucket compiles at
178// its own upper bound, so kernels run at *that* extent, not at some
179// global maximum. Smaller buckets ⇒ less padded compute. The
180// `power_of_two_ladder` constructor builds a logarithmic schedule that
181// guarantees ≤2× padding waste in exchange for `O(log max)` compiled
182// artifacts. For finer control, hand-construct the bucket list.
183//
184// True per-kernel active-extent dispatch (one big compile, runtime
185// extent override that short-circuits each kernel's inner loop) is a
186// per-backend change across `rlx-cuda`, `rlx-rocm`,
187// `rlx-cpu/src/thunk.rs`, `rlx-metal/src/thunk.rs`, `rlx-mlx`,
188// `rlx-wgpu` — multi-day project, not in this layer.
189
190pub struct BucketedCompileCache {
191    device: Device,
192    policy: Option<rlx_opt::PrecisionPolicy>,
193    buckets: Vec<Bucket>,
194}
195
196struct Bucket {
197    range: Range<u64>,
198    compiled: Option<CompiledGraph>,
199}
200
201impl BucketedCompileCache {
202    pub fn new(device: Device, buckets: Vec<Range<u64>>) -> Self {
203        Self::with_policy(device, buckets, None)
204    }
205
206    /// Power-of-two ladder over `[1, max]`, with extents
207    /// `[min_pow2, 2·min_pow2, 4·min_pow2, …, max_pow2]` where
208    /// `min_pow2 = min.next_power_of_two()` and `max_pow2` is the smallest
209    /// power of two ≥ `max`. Each bucket compiles at its upper-bound
210    /// extent, so an `actual` value in bucket `(prev_extent .. ext]` runs
211    /// kernels at extent `ext` (not at the worst case of the whole range).
212    /// Guarantees compute waste from padding ≤2× — `actual > ext / 2`
213    /// for every bucket except possibly the smallest.
214    ///
215    /// Example: `power_of_two_ladder(Device::Cpu, 8, 256)` yields buckets
216    /// `1..9, 9..17, 17..33, 33..65, 65..129, 129..257` with compile
217    /// extents `8, 16, 32, 64, 128, 256`. An `actual = 17` runs at extent
218    /// 32 instead of the 255 a single wide `1..256` bucket would compile
219    /// at — that's the "skip compute" win, paid for with `O(log max)`
220    /// compiled artifacts instead of one.
221    pub fn power_of_two_ladder(device: Device, min: u64, max: u64) -> Self {
222        Self::power_of_two_ladder_with_policy(device, min, max, None)
223    }
224
225    pub fn power_of_two_ladder_with_policy(
226        device: Device,
227        min: u64,
228        max: u64,
229        policy: Option<rlx_opt::PrecisionPolicy>,
230    ) -> Self {
231        assert!(min >= 1, "power_of_two_ladder: min must be ≥ 1, got {min}");
232        assert!(
233            max >= min,
234            "power_of_two_ladder: max ({max}) must be ≥ min ({min})"
235        );
236        let mut buckets: Vec<Range<u64>> = Vec::new();
237        let mut start = 1u64;
238        let mut extent = min.next_power_of_two();
239        loop {
240            buckets.push(start..(extent + 1));
241            if extent >= max {
242                break;
243            }
244            start = extent + 1;
245            extent = extent
246                .checked_mul(2)
247                .expect("power_of_two_ladder: extent overflow");
248        }
249        Self::with_policy(device, buckets, policy)
250    }
251
252    pub fn with_policy(
253        device: Device,
254        buckets: Vec<Range<u64>>,
255        policy: Option<rlx_opt::PrecisionPolicy>,
256    ) -> Self {
257        assert!(!buckets.is_empty(), "BucketedCompileCache needs ≥1 bucket");
258        for (i, b) in buckets.iter().enumerate() {
259            assert!(b.start < b.end, "bucket {i} ({b:?}) is empty");
260            if i + 1 < buckets.len() {
261                assert!(
262                    b.end <= buckets[i + 1].start,
263                    "buckets {i} ({b:?}) and {} ({:?}) overlap",
264                    i + 1,
265                    buckets[i + 1],
266                );
267            }
268        }
269        let buckets = buckets
270            .into_iter()
271            .map(|range| Bucket {
272                range,
273                compiled: None,
274            })
275            .collect();
276        Self {
277            device,
278            policy,
279            buckets,
280        }
281    }
282
283    /// Find the bucket containing `key`, compile if needed, return
284    /// `(upper, &mut CompiledGraph)` where `upper = range.end - 1` is the
285    /// extent the graph was compiled for. Caller pads inputs to `upper`
286    /// before calling `run`. Returns `None` if `key` is outside every
287    /// bucket — caller decides whether to fall back to a one-off compile.
288    ///
289    /// `build` receives `upper` and must return a `Graph` specialized for
290    /// that extent.
291    pub fn get_or_compile<F: FnOnce(u64) -> Graph>(
292        &mut self,
293        key: u64,
294        build: F,
295    ) -> Option<(u64, &mut CompiledGraph)> {
296        self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
297    }
298
299    /// Like [`Self::get_or_compile`] with explicit [`CompileOptions`].
300    pub fn get_or_compile_with_options<F: FnOnce(u64) -> Graph>(
301        &mut self,
302        key: u64,
303        build: F,
304        options: &crate::CompileOptions,
305    ) -> Option<(u64, &mut CompiledGraph)> {
306        let idx = self.bucket_for(key)?;
307        let upper = self.buckets[idx].range.end - 1;
308        if self.buckets[idx].compiled.is_none() {
309            let mut session = Session::new(self.device);
310            if let Some(p) = &self.policy {
311                session = session.with_policy(p.clone());
312            }
313            self.buckets[idx].compiled = Some(session.compile_with(build(upper), options));
314        }
315        Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
316    }
317
318    /// Like [`Self::get_or_compile`] but builds and compiles HIR directly
319    /// through the fusion-first pipeline (`Session::compile_hir`).
320    pub fn get_or_compile_hir<F: FnOnce(u64) -> HirModule>(
321        &mut self,
322        key: u64,
323        build: F,
324    ) -> Option<(u64, &mut CompiledGraph)> {
325        self.get_or_compile_hir_with_options(key, build, &crate::CompileOptions::new())
326    }
327
328    /// Like [`Self::get_or_compile_hir`] with explicit [`CompileOptions`] (tier-1 profile, fusion target, …).
329    pub fn get_or_compile_hir_with_options<F: FnOnce(u64) -> HirModule>(
330        &mut self,
331        key: u64,
332        build: F,
333        options: &crate::CompileOptions,
334    ) -> Option<(u64, &mut CompiledGraph)> {
335        let idx = self.bucket_for(key)?;
336        let upper = self.buckets[idx].range.end - 1;
337        if self.buckets[idx].compiled.is_none() {
338            let mut session = Session::new(self.device);
339            if let Some(p) = &self.policy {
340                session = session.with_policy(p.clone());
341            }
342            let compiled = session
343                .compile_hir_with(build(upper), options)
344                .expect("HIR lower/compile in bucketed cache");
345            self.buckets[idx].compiled = Some(compiled);
346        }
347        Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
348    }
349
350    /// Index of the bucket containing `key`, or `None` if out of range.
351    /// Linear scan — bucket counts are small in practice.
352    pub fn bucket_for(&self, key: u64) -> Option<usize> {
353        self.buckets.iter().position(|b| b.range.contains(&key))
354    }
355
356    pub fn buckets(&self) -> impl Iterator<Item = &Range<u64>> {
357        self.buckets.iter().map(|b| &b.range)
358    }
359
360    /// Number of buckets that have been compiled so far (≤ total buckets).
361    pub fn compiled_count(&self) -> usize {
362        self.buckets.iter().filter(|b| b.compiled.is_some()).count()
363    }
364
365    pub fn total_buckets(&self) -> usize {
366        self.buckets.len()
367    }
368
369    /// "Compile at max, run at less" convenience for inputs and outputs
370    /// whose outer dimension is the bucket key:
371    ///
372    /// 1. Find or compile the bucket containing `key`.
373    /// 2. For each input, pad to `upper` rows along the outer dim using
374    ///    `pad_rows` (caller passes the inner-dim stride per input;
375    ///    `inner = 1` for purely 1D inputs).
376    /// 3. Run the compiled graph at full extent.
377    /// 4. Slice each output back to `actual_rows` along its outer dim.
378    ///    Outputs flagged with `inner = 0` in `output_inners` are
379    ///    returned unsliced (use this for extent-independent outputs
380    ///    like a pooled `[hidden]` embedding). Missing entries past
381    ///    the end of `output_inners` are also returned unsliced.
382    ///
383    /// Returns `(upper, outputs)`. Returns `None` if `key` falls outside
384    /// every bucket.
385    ///
386    /// **Compute scope:** kernels execute at the bucket's compile
387    /// extent (`upper`), not at `actual_rows`. This means smaller
388    /// buckets directly translate to less padded compute. With
389    /// [`power_of_two_ladder`](Self::power_of_two_ladder) the worst-
390    /// case waste is bounded at 2×; with hand-tuned buckets it can be
391    /// arbitrarily tight. True active-extent dispatch — one big
392    /// compile, kernels short-circuit at runtime — is a separate
393    /// per-backend change.
394    pub fn run_padded<F: FnOnce(u64) -> Graph>(
395        &mut self,
396        key: u64,
397        actual_rows: usize,
398        build: F,
399        inputs: &[(&str, &[f32], usize)],
400        output_inners: &[usize],
401    ) -> Option<(u64, Vec<Vec<f32>>)> {
402        let (upper, compiled) = self.get_or_compile(key, build)?;
403
404        // Own the padded buffers so they outlive the borrow handed to `run`.
405        let padded: Vec<(&str, Vec<f32>)> = inputs
406            .iter()
407            .map(|(name, data, inner)| (*name, pad_rows(data, *inner, upper)))
408            .collect();
409        let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
410
411        // Hint active-extent: backends that support per-kernel skip-
412        // compute (today: CPU's Activation thunk family) honor it; the
413        // default trait impl is a no-op, so other backends just process
414        // full extent and the slice_rows below still gives the user
415        // correct outputs.
416        compiled.set_active_extent(Some((actual_rows, upper as usize)));
417        let raw_outputs = compiled.run(&pairs);
418        compiled.set_active_extent(None);
419
420        let outs = raw_outputs
421            .into_iter()
422            .enumerate()
423            .map(|(i, out)| match output_inners.get(i).copied() {
424                Some(0) | None => out,
425                Some(inner) => slice_rows(&out, inner, actual_rows),
426            })
427            .collect();
428
429        Some((upper, outs))
430    }
431
432    /// Like [`Self::get_or_compile_with_options`] but also uploads `params` on first compile.
433    pub fn ensure_graph_with_params<F>(
434        &mut self,
435        key: u64,
436        build: F,
437        options: &crate::CompileOptions,
438    ) -> Option<(u64, &mut CompiledGraph)>
439    where
440        F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
441    {
442        let idx = self.bucket_for(key)?;
443        let upper = self.buckets[idx].range.end - 1;
444        if self.buckets[idx].compiled.is_none() {
445            let (graph, params) = build(upper);
446            let mut session = Session::new(self.device);
447            if let Some(p) = &self.policy {
448                session = session.with_policy(p.clone());
449            }
450            let mut compiled = session.compile_with(graph, options);
451            for (name, data) in params {
452                compiled.set_param(&name, &data);
453            }
454            self.buckets[idx].compiled = Some(compiled);
455        }
456        Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
457    }
458
459    /// HIR variant of [`Self::ensure_graph_with_params`].
460    pub fn ensure_hir_with_params<F>(
461        &mut self,
462        key: u64,
463        build: F,
464        options: &crate::CompileOptions,
465    ) -> Option<(u64, &mut CompiledGraph)>
466    where
467        F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
468    {
469        let idx = self.bucket_for(key)?;
470        let upper = self.buckets[idx].range.end - 1;
471        if self.buckets[idx].compiled.is_none() {
472            let (hir, params) = build(upper);
473            let mut session = Session::new(self.device);
474            if let Some(p) = &self.policy {
475                session = session.with_policy(p.clone());
476            }
477            let mut compiled = session
478                .compile_hir_with(hir, options)
479                .expect("HIR lower/compile in ensure_hir_with_params");
480            for (name, data) in params {
481                compiled.set_param(&name, &data);
482            }
483            self.buckets[idx].compiled = Some(compiled);
484        }
485        Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
486    }
487
488    /// [`Self::run_padded`] with per-input optional row padding (`CacheRunInput`).
489    pub fn run_padded_mixed<F>(
490        &mut self,
491        key: u64,
492        actual_rows: usize,
493        build: F,
494        inputs: &[CacheRunInput<'_>],
495        output_inners: &[usize],
496    ) -> Option<(u64, Vec<Vec<f32>>)>
497    where
498        F: FnOnce(u64) -> Graph,
499    {
500        let (upper, compiled) = self.get_or_compile(key, build)?;
501
502        let padded: Vec<(&str, Vec<f32>)> = inputs
503            .iter()
504            .map(|inp| match inp.row_inner {
505                Some(inner) => (inp.name, pad_rows(inp.data, inner, upper)),
506                None => (inp.name, inp.data.to_vec()),
507            })
508            .collect();
509        let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
510
511        compiled.set_active_extent(Some((actual_rows, upper as usize)));
512        let raw_outputs = compiled.run(&pairs);
513        compiled.set_active_extent(None);
514
515        let outs = raw_outputs
516            .into_iter()
517            .enumerate()
518            .map(|(i, out)| match output_inners.get(i).copied() {
519                Some(0) | None => out,
520                Some(inner) => slice_rows(&out, inner, actual_rows),
521            })
522            .collect();
523
524        Some((upper, outs))
525    }
526}
527
528// ── Dynamic-dim cache (plan #54) ──────────────────────────────────────
529//
530// Compile HIR once through the fusion pipeline (graph may contain
531// `Dim::Dynamic` symbols), then specialize to concrete shapes per cache
532// key and backend-compile the resulting LIR.
533
534/// Compile-once / specialize-at-runtime cache for symbolic HIR modules.
535pub struct DynamicDimCompileCache {
536    device: Device,
537    policy: Option<rlx_opt::PrecisionPolicy>,
538    capacity: usize,
539    template: Option<CompileResult>,
540    entries: Vec<(u64, CompiledGraph)>,
541    order: VecDeque<u64>,
542}
543
544impl DynamicDimCompileCache {
545    pub fn new(device: Device, capacity: usize) -> Self {
546        Self::with_policy(device, capacity, None)
547    }
548
549    pub fn with_policy(
550        device: Device,
551        capacity: usize,
552        policy: Option<rlx_opt::PrecisionPolicy>,
553    ) -> Self {
554        assert!(capacity > 0, "DynamicDimCompileCache capacity must be ≥ 1");
555        Self {
556            device,
557            policy,
558            capacity,
559            template: None,
560            entries: Vec::with_capacity(capacity),
561            order: VecDeque::with_capacity(capacity),
562        }
563    }
564
565    pub fn compile_device(&self) -> Device {
566        self.device
567    }
568
569    /// Return a backend-compiled graph specialized for `binding`.
570    /// `build_hir` runs at most once to populate the dynamic template.
571    pub fn get_or_specialize<F: FnOnce() -> HirModule>(
572        &mut self,
573        key: u64,
574        binding: &DimBinding,
575        build_hir: F,
576        options: &crate::CompileOptions,
577    ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError> {
578        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
579            return Ok(&mut self.entries[idx].1);
580        }
581        if self.template.is_none() {
582            let mut template_opts = options.clone();
583            template_opts.dim_binding = None;
584            let pipe = crate::stages::pipeline_for(self.device, &template_opts);
585            self.template = Some(pipe.compile_hir(build_hir())?);
586        }
587        let template = self.template.as_ref().expect("template just set");
588        let mut spec_opts = options.clone();
589        spec_opts.dim_binding = None;
590        let pipe = crate::stages::pipeline_for(self.device, &spec_opts);
591        let specialized = template.specialize(&pipe, binding);
592        let backend = crate::registry::backend_for(self.device).expect("backend registered");
593        let mut compile_opts = options.clone();
594        compile_opts.dim_binding = None;
595        if compile_opts.policy.is_none() {
596            if let Some(p) = &self.policy {
597                compile_opts = compile_opts.policy(p.clone());
598            }
599        }
600        let executable = backend.compile_lir(specialized.lir, &compile_opts);
601        let compiled = CompiledGraph::new(executable, self.device);
602
603        if self.entries.len() >= self.capacity
604            && let Some(evict_key) = self.order.pop_front()
605        {
606            self.entries.retain(|(k, _)| *k != evict_key);
607        }
608        self.entries.push((key, compiled));
609        self.order.push_back(key);
610        Ok(&mut self.entries.last_mut().unwrap().1)
611    }
612
613    pub fn len(&self) -> usize {
614        self.entries.len()
615    }
616
617    pub fn is_empty(&self) -> bool {
618        self.entries.is_empty()
619    }
620
621    pub fn contains(&self, key: u64) -> bool {
622        self.entries.iter().any(|(k, _)| *k == key)
623    }
624
625    pub fn has_template(&self) -> bool {
626        self.template.is_some()
627    }
628
629    /// Build the symbolic template once (no specialization).
630    pub fn ensure_template<F: FnOnce() -> HirModule>(
631        &mut self,
632        build_hir: F,
633        options: &crate::CompileOptions,
634    ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
635        if self.template.is_none() {
636            let mut opts = options.clone();
637            opts.dim_binding = None;
638            let pipe = crate::stages::pipeline_for(self.device, &opts);
639            self.template = Some(pipe.compile_hir(build_hir())?);
640        }
641        Ok(self.template.as_ref().expect("template set"))
642    }
643
644    pub fn template_result(&self) -> Option<&CompileResult> {
645        self.template.as_ref()
646    }
647
648    /// Specialize via on-disk LIR cache ([`CompilationMode::Aot`]).
649    /// Disk-backed specialize ([`rlx_ir::CompilationMode::Aot`]).
650    pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
651        &mut self,
652        aot: &crate::AotCache,
653        disk_base: &str,
654        key: u64,
655        binding: &rlx_ir::DimBinding,
656        build_hir: F,
657        options: &crate::CompileOptions,
658    ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
659        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
660            return Ok(&mut self.entries[idx].1);
661        }
662        let device = self.device;
663        let template = self.ensure_template(build_hir, options)?;
664        let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
665        if self.entries.len() >= self.capacity
666            && let Some(evict_key) = self.order.pop_front()
667        {
668            self.entries.retain(|(k, _)| *k != evict_key);
669        }
670        self.entries.push((key, compiled));
671        self.order.push_back(key);
672        Ok(&mut self.entries.last_mut().unwrap().1)
673    }
674}
675
/// Pad `data` (interpreted as `[actual, inner]` row-major) up to `upper`
/// rows by appending zeros. Returns a `Vec<f32>` of length
/// `upper * inner`. Companion of [`slice_rows`] for the
/// "compile at max, run at less" workflow with [`BucketedCompileCache`].
///
/// # Panics
///
/// Panics if `inner == 0`, if `data.len()` is not a multiple of `inner`,
/// if `data.len() / inner > upper`, or if `upper` or `upper * inner` does
/// not fit in a `usize`.
pub fn pad_rows(data: &[f32], inner: usize, upper: u64) -> Vec<f32> {
    assert!(inner > 0, "pad_rows: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    // A plain `as usize` cast would silently truncate on narrow targets.
    let upper = usize::try_from(upper)
        .unwrap_or_else(|_| panic!("pad_rows: upper bound {upper} overflows usize"));
    let actual = data.len() / inner;
    assert!(
        actual <= upper,
        "pad_rows: actual rows {actual} exceed upper bound {upper}",
    );
    // Unchecked multiplication wraps in release builds and could yield a
    // buffer of the wrong length; fail loudly instead.
    let padded_len = upper.checked_mul(inner).unwrap_or_else(|| {
        panic!("pad_rows: padded length {upper} * {inner} overflows usize")
    });
    let mut out = vec![0.0_f32; padded_len];
    out[..data.len()].copy_from_slice(data);
    out
}
701
/// Keep only the first `actual` rows of `data`, where `data` is read as
/// an `[upper, inner]` row-major matrix. Inverse companion of
/// [`pad_rows`].
///
/// # Panics
///
/// Panics if `inner == 0`, if `data.len()` is not a multiple of `inner`,
/// or if `actual` is larger than the number of rows in `data`.
pub fn slice_rows(data: &[f32], inner: usize, actual: usize) -> Vec<f32> {
    assert!(inner > 0, "slice_rows: inner stride must be ≥ 1");
    let remainder = data.len() % inner;
    assert_eq!(
        remainder,
        0,
        "slice_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    let upper = data.len() / inner;
    assert!(
        actual <= upper,
        "slice_rows: actual rows {actual} exceed upper {upper}",
    );
    // `actual <= upper` guarantees the prefix is in bounds.
    let kept = &data[..actual * inner];
    Vec::from(kept)
}
722
723#[cfg(test)]
724mod tests {
725    use super::*;
726    use rlx_ir::infer::GraphExt;
727    use rlx_ir::*;
728    use std::cell::Cell;
729
730    fn tiny_graph(n: usize) -> Graph {
731        let mut g = Graph::new("t");
732        let f = DType::F32;
733        let x = g.input("x", Shape::new(&[n], f));
734        let y = g.activation(rlx_ir::op::Activation::Relu, x, Shape::new(&[n], f));
735        g.set_outputs(vec![y]);
736        g
737    }
738
739    #[test]
740    fn cache_hits_avoid_recompile() {
741        let mut cache = CompileCache::new(Device::Cpu, 4);
742        let calls = Cell::new(0);
743
744        let _ = cache.get_or_compile(1, || {
745            calls.set(calls.get() + 1);
746            tiny_graph(8)
747        });
748        let _ = cache.get_or_compile(1, || {
749            calls.set(calls.get() + 1);
750            tiny_graph(8)
751        });
752        let _ = cache.get_or_compile(1, || {
753            calls.set(calls.get() + 1);
754            tiny_graph(8)
755        });
756        // Same key three times: build closure runs once.
757        assert_eq!(calls.get(), 1);
758        assert_eq!(cache.len(), 1);
759    }
760
761    #[test]
762    fn fifo_evicts_oldest_at_capacity() {
763        let mut cache = CompileCache::new(Device::Cpu, 2);
764        let _ = cache.get_or_compile(1, || tiny_graph(4));
765        let _ = cache.get_or_compile(2, || tiny_graph(8));
766        assert!(cache.contains(1) && cache.contains(2));
767        // Third entry evicts key 1 (oldest).
768        let _ = cache.get_or_compile(3, || tiny_graph(16));
769        assert!(!cache.contains(1));
770        assert!(cache.contains(2) && cache.contains(3));
771    }
772
773    #[test]
774    fn different_keys_keep_separate_compiles() {
775        let mut cache = CompileCache::new(Device::Cpu, 4);
776        let calls = Cell::new(0);
777        let _ = cache.get_or_compile(1, || {
778            calls.set(calls.get() + 1);
779            tiny_graph(8)
780        });
781        let _ = cache.get_or_compile(2, || {
782            calls.set(calls.get() + 1);
783            tiny_graph(16)
784        });
785        let _ = cache.get_or_compile(1, || {
786            calls.set(calls.get() + 1);
787            tiny_graph(8)
788        });
789        // Two unique keys → two compiles.
790        assert_eq!(calls.get(), 2);
791        assert_eq!(cache.len(), 2);
792    }
793
794    // ── BucketedCompileCache ──────────────────────────────────────────
795
796    #[test]
797    fn bucket_amortizes_keys_within_range() {
798        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
799        let calls = Cell::new(0);
800        let uppers = Cell::new((0u64, 0u64));
801
802        // Two distinct keys (2 and 3) both fall inside bucket 0 (1..4).
803        let (u1, _) = cache
804            .get_or_compile(2, |upper| {
805                calls.set(calls.get() + 1);
806                uppers.set((upper, uppers.get().1));
807                tiny_graph(upper as usize)
808            })
809            .expect("key 2 in range");
810        let (u2, _) = cache
811            .get_or_compile(3, |upper| {
812                calls.set(calls.get() + 1);
813                uppers.set((uppers.get().0, upper));
814                tiny_graph(upper as usize)
815            })
816            .expect("key 3 in range");
817
818        // One compile, both calls saw the same upper = range.end - 1 = 3.
819        assert_eq!(calls.get(), 1);
820        assert_eq!(u1, 3);
821        assert_eq!(u2, 3);
822        assert_eq!(uppers.get().0, 3);
823        assert_eq!(cache.compiled_count(), 1);
824        assert_eq!(cache.total_buckets(), 2);
825    }
826
827    #[test]
828    fn bucket_lookup_returns_none_outside_range() {
829        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
830        assert!(cache.bucket_for(0).is_none());
831        assert!(cache.bucket_for(16).is_none());
832        assert!(cache.bucket_for(100).is_none());
833        assert_eq!(cache.bucket_for(3), Some(0));
834        assert_eq!(cache.bucket_for(4), Some(1));
835
836        let calls = Cell::new(0);
837        let result = cache.get_or_compile(100, |u| {
838            calls.set(calls.get() + 1);
839            tiny_graph(u as usize)
840        });
841        assert!(result.is_none());
842        assert_eq!(calls.get(), 0); // build closure must not run for OOR keys
843        assert_eq!(cache.compiled_count(), 0);
844    }
845
846    #[test]
847    fn bucket_compiles_lazily_per_bucket() {
848        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16, 16..64]);
849        let calls = Cell::new(0);
850
851        let _ = cache.get_or_compile(2, |u| {
852            calls.set(calls.get() + 1);
853            tiny_graph(u as usize)
854        });
855        let _ = cache.get_or_compile(8, |u| {
856            calls.set(calls.get() + 1);
857            tiny_graph(u as usize)
858        });
859        // Two distinct buckets hit → two compiles. Third bucket untouched.
860        assert_eq!(calls.get(), 2);
861        assert_eq!(cache.compiled_count(), 2);
862        assert_eq!(cache.total_buckets(), 3);
863    }
864
865    #[test]
866    #[should_panic(expected = "overlap")]
867    fn bucket_overlap_rejected() {
868        let _ = BucketedCompileCache::new(Device::Cpu, vec![1..8, 4..16]);
869    }
870
871    #[test]
872    #[should_panic(expected = "≥1 bucket")]
873    fn empty_bucket_list_rejected() {
874        let _ = BucketedCompileCache::new(Device::Cpu, vec![]);
875    }
876
877    // ── pad_rows / slice_rows ─────────────────────────────────────────
878
879    #[test]
880    fn pad_rows_appends_zeros() {
881        // 1D: actual=3 → upper=5, inner=1.
882        let p = pad_rows(&[1.0, 2.0, 3.0], 1, 5);
883        assert_eq!(p, vec![1.0, 2.0, 3.0, 0.0, 0.0]);
884
885        // 2D row-major [actual=2, inner=3] → [upper=4, inner=3].
886        let p = pad_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 4);
887        assert_eq!(
888            p,
889            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
890        );
891
892        // actual == upper: no-op pad.
893        let p = pad_rows(&[7.0, 8.0], 1, 2);
894        assert_eq!(p, vec![7.0, 8.0]);
895    }
896
897    #[test]
898    fn slice_rows_truncates_trailing() {
899        let s = slice_rows(&[1.0, 2.0, 3.0, 0.0, 0.0], 1, 3);
900        assert_eq!(s, vec![1.0, 2.0, 3.0]);
901
902        let s = slice_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0], 3, 2);
903        assert_eq!(s, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
904    }
905
906    #[test]
907    #[should_panic(expected = "exceed upper")]
908    fn pad_rows_rejects_too_long_input() {
909        let _ = pad_rows(&[1.0, 2.0, 3.0, 4.0], 1, 3);
910    }
911
912    #[test]
913    #[should_panic(expected = "exceed upper")]
914    fn slice_rows_rejects_too_large_actual() {
915        let _ = slice_rows(&[1.0, 2.0, 3.0], 1, 5);
916    }
917
918    // ── BucketedCompileCache::run_padded ──────────────────────────────
919
920    #[test]
921    fn run_padded_pads_input_and_slices_output() {
922        // tiny_graph is 1D [n] → relu → [n].
923        // Compile bucket [1..16) at upper=15, run with actual_rows=10.
924        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
925        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0];
926
927        let (upper, outs) = cache
928            .run_padded(
929                10, // key
930                10, // actual rows
931                |max| tiny_graph(max as usize),
932                &[("x", &input, 1)], // 1D, inner stride 1
933                &[1],                // slice the one output to actual rows
934            )
935            .expect("key 10 in [1..16)");
936
937        assert_eq!(upper, 15);
938        assert_eq!(outs.len(), 1);
939        let out = &outs[0];
940        assert_eq!(out.len(), 10, "output sliced back to actual_rows");
941        let expected: Vec<f32> = input.iter().map(|x| x.max(0.0)).collect();
942        assert_eq!(out, &expected);
943    }
944
945    #[test]
946    fn run_padded_reuses_bucket_across_actuals() {
947        // Same bucket, two different actuals — only one compile.
948        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
949        let calls = Cell::new(0);
950
951        let (u1, o1) = cache
952            .run_padded(
953                10,
954                10,
955                |max| {
956                    calls.set(calls.get() + 1);
957                    tiny_graph(max as usize)
958                },
959                &[(
960                    "x",
961                    &[1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0],
962                    1,
963                )],
964                &[1],
965            )
966            .unwrap();
967        assert_eq!(o1.len(), 1);
968        assert_eq!(o1[0].len(), 10);
969        assert_eq!(u1, 15);
970
971        let (u2, o2) = cache
972            .run_padded(
973                5,
974                5,
975                |max| {
976                    calls.set(calls.get() + 1);
977                    tiny_graph(max as usize)
978                },
979                &[("x", &[-1.0, 2.0, -3.0, 4.0, -5.0], 1)],
980                &[1],
981            )
982            .unwrap();
983        assert_eq!(o2.len(), 1);
984        assert_eq!(o2[0].len(), 5);
985        assert_eq!(u2, 15);
986        assert_eq!(o2[0], vec![0.0, 2.0, 0.0, 4.0, 0.0]);
987
988        assert_eq!(calls.get(), 1, "bucket cached across actuals");
989        assert_eq!(cache.compiled_count(), 1);
990    }
991
992    #[test]
993    fn run_padded_returns_none_out_of_range() {
994        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
995        let calls = Cell::new(0);
996        let result = cache.run_padded(
997            100,
998            5,
999            |u| {
1000                calls.set(calls.get() + 1);
1001                tiny_graph(u as usize)
1002            },
1003            &[("x", &[1.0, 2.0, 3.0, 4.0, 5.0], 1)],
1004            &[1],
1005        );
1006        assert!(result.is_none());
1007        assert_eq!(calls.get(), 0);
1008        assert_eq!(cache.compiled_count(), 0);
1009    }
1010
1011    // ── power_of_two_ladder ───────────────────────────────────────────
1012
1013    #[test]
1014    fn power_of_two_ladder_generates_log_buckets() {
1015        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
1016        // Expect buckets covering keys 1..=64 with extents 8, 16, 32, 64.
1017        let ranges: Vec<_> = cache.buckets().cloned().collect();
1018        assert_eq!(ranges, vec![1..9, 9..17, 17..33, 33..65]);
1019        assert_eq!(cache.total_buckets(), 4);
1020    }
1021
1022    #[test]
1023    fn power_of_two_ladder_picks_smallest_extent_for_actual() {
1024        // Ladder: extents 8, 16, 32, 64. actual=17 lands in the 32-extent
1025        // bucket, NOT the 64-extent one — that's the compute saving.
1026        let mut cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
1027        let captured_uppers: std::cell::RefCell<Vec<u64>> = Default::default();
1028
1029        let (u17, _) = cache
1030            .get_or_compile(17, |upper| {
1031                captured_uppers.borrow_mut().push(upper);
1032                tiny_graph(upper as usize)
1033            })
1034            .unwrap();
1035        let (u9, _) = cache
1036            .get_or_compile(9, |upper| {
1037                captured_uppers.borrow_mut().push(upper);
1038                tiny_graph(upper as usize)
1039            })
1040            .unwrap();
1041        let (u3, _) = cache
1042            .get_or_compile(3, |upper| {
1043                captured_uppers.borrow_mut().push(upper);
1044                tiny_graph(upper as usize)
1045            })
1046            .unwrap();
1047        let (u64_, _) = cache
1048            .get_or_compile(64, |upper| {
1049                captured_uppers.borrow_mut().push(upper);
1050                tiny_graph(upper as usize)
1051            })
1052            .unwrap();
1053
1054        assert_eq!(u17, 32, "key=17 → smallest extent ≥ 17 is 32");
1055        assert_eq!(u9, 16, "key=9  → smallest extent ≥ 9  is 16");
1056        assert_eq!(u3, 8, "key=3  → smallest extent ≥ 3  is 8");
1057        assert_eq!(u64_, 64, "key=64 → exact match at 64");
1058        assert_eq!(*captured_uppers.borrow(), vec![32, 16, 8, 64]);
1059        assert_eq!(cache.compiled_count(), 4);
1060    }
1061
1062    #[test]
1063    fn power_of_two_ladder_min_above_one_starts_at_one() {
1064        // First bucket always covers from key 1, even when min > 1.
1065        // (`min` controls the ladder's first extent, not the lower edge.)
1066        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 32);
1067        let ranges: Vec<_> = cache.buckets().cloned().collect();
1068        // min=16 → first extent 16, second 32. Buckets: 1..17, 17..33.
1069        assert_eq!(ranges, vec![1..17, 17..33]);
1070    }
1071
1072    #[test]
1073    fn power_of_two_ladder_non_pow2_min_rounds_up() {
1074        // min=10 → next_power_of_two = 16.
1075        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 10, 64);
1076        let ranges: Vec<_> = cache.buckets().cloned().collect();
1077        assert_eq!(ranges, vec![1..17, 17..33, 33..65]);
1078    }
1079
1080    #[test]
1081    fn power_of_two_ladder_max_below_pow2_extends_up() {
1082        // max=20 needs to be covered → ladder extends to 32.
1083        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 20);
1084        let ranges: Vec<_> = cache.buckets().cloned().collect();
1085        assert_eq!(ranges, vec![1..9, 9..17, 17..33]);
1086    }
1087
1088    #[test]
1089    fn power_of_two_ladder_min_equals_max() {
1090        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 16);
1091        let ranges: Vec<_> = cache.buckets().cloned().collect();
1092        assert_eq!(ranges, vec![1..17]);
1093    }
1094
1095    #[test]
1096    #[should_panic(expected = "min must be ≥ 1")]
1097    fn power_of_two_ladder_zero_min_rejected() {
1098        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 0, 16);
1099    }
1100
1101    #[test]
1102    #[should_panic(expected = "max")]
1103    fn power_of_two_ladder_max_below_min_rejected() {
1104        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 32, 8);
1105    }
1106
    // ── Active-extent dispatch (true per-kernel skip-compute) ─────────
    //
    // The three `active_extent_skips_compute_on_*` tests in this module
    // assert per-thunk active-extent scaling on the CPU backend. Today
    // `rlx_cpu::thunk::execute_thunks_active` is documented as a stub
    // that returns false (rlx-cpu/src/thunk.rs:2100-2110), so the runtime
    // falls back to full-extent dispatch, which overwrites the tail and
    // makes the tail-preservation assertions fail. The tests are kept
    // (marked `#[ignore]`) as the test-driven contract that the future
    // active-extent implementation must satisfy. Drop the `#[ignore]`
    // once per-thunk scaling lands for Copy / ActivationInPlace /
    // BinaryFull / Attention.
1118
1119    #[test]
1120    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
1121    fn active_extent_skips_compute_on_cpu_activation() {
1122        // tiny_graph(15) is `Input([15]) → Relu → Output` and lowers to
1123        // a Copy + ActivationInPlace pair on CPU — both are in the safe
1124        // set, so the active-extent path runs scaled.
1125        //
1126        // To prove kernels actually skipped: warm the arena with a prior
1127        // full-extent run whose output is `[1.0; 15]`, then run again
1128        // with a negative-only input and active=5. The first 5 outputs
1129        // get re-copied + re-relu'd to 0; the tail (indices 5..15) stays
1130        // at 1.0 because both Copy and Activation skipped it. A full-
1131        // extent fallback would clip every element to 0.
1132        let graph = tiny_graph(15);
1133        let mut compiled = Session::new(Device::Cpu).compile(graph);
1134
1135        // Warm-up: full extent, all-positive input → output [1.0; 15].
1136        let warm_input: Vec<f32> = vec![1.0; 15];
1137        let warm_outs = compiled.run(&[("x", &warm_input)]);
1138        assert_eq!(warm_outs[0], vec![1.0; 15], "warm-up sanity");
1139
1140        // Active-extent run: all-negative input, hint actual=5 of 15.
1141        // First 5: Copy(-1) + Relu → 0. Tail: kernels skip → stays 1.0.
1142        let neg_input: Vec<f32> = vec![-1.0; 15];
1143        compiled.set_active_extent(Some((5, 15)));
1144        let outs = compiled.run(&[("x", &neg_input)]);
1145        let out = &outs[0];
1146
1147        assert_eq!(out.len(), 15);
1148        assert_eq!(
1149            out[..5],
1150            [0.0; 5],
1151            "first 5 elements processed (relu of -1)"
1152        );
1153        assert_eq!(
1154            out[5..],
1155            [1.0; 10],
1156            "tail untouched — proves Copy + Activation skipped indices 5..15"
1157        );
1158
1159        // Clear the hint and run again with the negative input — full
1160        // extent now processes everything, every element clips to 0.
1161        compiled.set_active_extent(None);
1162        let outs = compiled.run(&[("x", &neg_input)]);
1163        assert_eq!(
1164            outs[0],
1165            vec![0.0; 15],
1166            "full-extent path must clip every negative"
1167        );
1168    }
1169
1170    #[test]
1171    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
1172    fn active_extent_skips_compute_on_binary_full() {
1173        // Input([4]) + Input([4]) → Output. Lowers to a BinaryFull
1174        // thunk with no broadcast (lhs_len == rhs_len == len), which
1175        // is in the safe set.
1176        let mut g = Graph::new("add");
1177        let f = DType::F32;
1178        let a = g.input("a", Shape::new(&[4], f));
1179        let b = g.input("b", Shape::new(&[4], f));
1180        let c = g.add(a, b);
1181        g.set_outputs(vec![c]);
1182        let mut compiled = Session::new(Device::Cpu).compile(g);
1183
1184        // Warm: full extent, output buffer becomes [2.0; 4].
1185        let warm = compiled.run(&[("a", &[1.0f32; 4]), ("b", &[1.0f32; 4])]);
1186        assert_eq!(warm[0], vec![2.0; 4]);
1187
1188        // Active-extent run: actual=2 of upper=4. Process first 2
1189        // elements only; tail (indices 2..4) stays at 2.0 from warm.
1190        compiled.set_active_extent(Some((2, 4)));
1191        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
1192        let out = &outs[0];
1193        assert_eq!(out[..2], [20.0, 20.0], "first 2 = active sum");
1194        assert_eq!(
1195            out[2..],
1196            [2.0, 2.0],
1197            "tail untouched — proves BinaryFull skipped indices 2..4"
1198        );
1199
1200        // Clear hint → full path overwrites entire output.
1201        compiled.set_active_extent(None);
1202        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
1203        assert_eq!(outs[0], vec![20.0; 4]);
1204    }
1205
1206    #[test]
1207    #[ignore = "process-wide STATE; runs only in isolation via `cargo test perfetto -- --ignored`"]
1208    fn perfetto_trace_emits_per_thunk_events() {
1209        // PLAN L3: end-to-end Perfetto event capture. Requires the env
1210        // var to be set BEFORE the perfetto module is first touched
1211        // (OnceLock — can't re-init). We set it here unconditionally;
1212        // for tests run in parallel within the same process, the
1213        // earliest test wins. To avoid flake we mark this `#[ignore]`
1214        // and the developer runs it explicitly.
1215        use std::env;
1216        use std::fs;
1217        let path = env::temp_dir().join(format!("rlx-perfetto-e2e-{}.json", std::process::id()));
1218        if path.exists() {
1219            let _ = fs::remove_file(&path);
1220        }
1221        unsafe {
1222            env::set_var("RLX_TRACE_PERFETTO", &path);
1223        }
1224
1225        // Build + run a small CPU graph — Add → Relu (no fusion macros).
1226        let f = DType::F32;
1227        let mut g = Graph::new("perf");
1228        let a = g.input("a", Shape::new(&[4], f));
1229        let b = g.input("b", Shape::new(&[4], f));
1230        let s = g.add(a, b);
1231        let r = g.relu(s);
1232        g.set_outputs(vec![r]);
1233        let mut compiled = Session::new(Device::Cpu).compile(g);
1234        let _ = compiled.run(&[("a", &[1.0; 4]), ("b", &[1.0; 4])]);
1235
1236        // Force the trace file to flush its closing bracket.
1237        crate::perfetto::flush_and_finalize();
1238
1239        let contents = fs::read_to_string(&path).expect("trace file");
1240        // At minimum we should see one of our thunk names.
1241        assert!(
1242            contents.contains("\"binary\"")
1243                || contents.contains("\"activation\"")
1244                || contents.contains("\"elementwise_region\""),
1245            "expected at least one thunk-name event in perfetto trace; got: {contents}"
1246        );
1247        // JSON shape: starts with `[` and (after flush) ends with `]`.
1248        assert!(contents.trim_start().starts_with('['));
1249        let _ = fs::remove_file(&path);
1250    }
1251
1252    #[test]
1253    fn elementwise_region_fused_matches_unfused() {
1254        // PLAN L2: a chain `Add(a, b) → Mul(_, c) → Relu` should fuse
1255        // into one ElementwiseRegion thunk in the CPU backend. Compare
1256        // its output against the value computed by hand to confirm the
1257        // fused execution is numerically identical.
1258        let f = DType::F32;
1259        let mut g = Graph::new("ew_e2e");
1260        let a = g.input("a", Shape::new(&[8], f));
1261        let b = g.input("b", Shape::new(&[8], f));
1262        let c = g.input("c", Shape::new(&[8], f));
1263        let s = Shape::new(&[8], f);
1264        let add = g.add(a, b);
1265        let mul = g.mul(add, c);
1266        let relu = g.relu(mul);
1267        let _ = s;
1268        g.set_outputs(vec![relu]);
1269
1270        let mut compiled = Session::new(Device::Cpu).compile(g);
1271        let av: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0, 0.5, -0.5, 1.5, -1.5];
1272        let bv: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0, 0.5, 0.5, 0.5, 0.5];
1273        let cv: Vec<f32> = vec![1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 0.5, 4.0];
1274        let outs = compiled.run(&[("a", &av), ("b", &bv), ("c", &cv)]);
1275        let out = &outs[0];
1276
1277        let expected: Vec<f32> = (0..8)
1278            .map(|i| {
1279                let v = (av[i] + bv[i]) * cv[i];
1280                v.max(0.0)
1281            })
1282            .collect();
1283        for (i, (got, exp)) in out.iter().zip(&expected).enumerate() {
1284            assert!(
1285                (got - exp).abs() < 1e-6,
1286                "mismatch at {i}: got {got}, expected {exp}"
1287            );
1288        }
1289    }
1290
1291    #[test]
1292    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
1293    fn active_extent_skips_compute_on_attention() {
1294        // Standalone Attention with kernel-synthesized MaskKind::None.
1295        // Q/K/V shape: [batch=1, seq=4, num_heads*head_dim=8].
1296        use rlx_ir::op::MaskKind;
1297        let f = DType::F32;
1298        let mut g = Graph::new("attn");
1299        let q = g.input("q", Shape::new(&[1, 4, 8], f));
1300        let k = g.input("k", Shape::new(&[1, 4, 8], f));
1301        let v = g.input("v", Shape::new(&[1, 4, 8], f));
1302        let out = g.attention_kind(q, k, v, 2, 4, MaskKind::None, Shape::new(&[1, 4, 8], f));
1303        g.set_outputs(vec![out]);
1304        let mut compiled = Session::new(Device::Cpu).compile(g);
1305
1306        // Warm: full extent. Q=K=V uniform → output uniform-ish.
1307        let warm = compiled.run(&[
1308            ("q", &[1.0f32; 32]),
1309            ("k", &[1.0f32; 32]),
1310            ("v", &[1.0f32; 32]),
1311        ]);
1312        let warm_out = warm[0].clone();
1313        assert_eq!(warm_out.len(), 32);
1314
1315        // Active: s_active=2 of s_full=4. Different inputs.
1316        // Tail rows (indices 16..32 = positions 2,3) should be untouched
1317        // — preserved from the warm run. First 16 indices recomputed.
1318        compiled.set_active_extent(Some((2, 4)));
1319        let outs = compiled.run(&[
1320            ("q", &[3.0f32; 32]),
1321            ("k", &[3.0f32; 32]),
1322            ("v", &[3.0f32; 32]),
1323        ]);
1324        let out = &outs[0];
1325        assert_eq!(out.len(), 32);
1326        assert_eq!(
1327            &out[16..],
1328            &warm_out[16..],
1329            "tail (positions 2,3) must be untouched — proves Attention skipped"
1330        );
1331        // Sanity: first 2 positions changed since input value differs (3.0 vs 1.0).
1332        assert_ne!(
1333            &out[..16],
1334            &warm_out[..16],
1335            "first 2 positions should reflect new input"
1336        );
1337    }
1338
    #[test]
    fn active_extent_falls_back_when_unsupported_thunk_in_schedule() {
        // PLACEHOLDER: this test body is intentionally empty and makes no
        // assertions, so it always passes. It only records the intended
        // contract until a real test can be written.
        //
        // Intended contract: a graph containing any thunk outside
        // `safe_for_active_extent` (e.g. Sgemm via a matmul) must fall
        // back to the full-extent executor — partial application would
        // feed garbage downstream. Verifying that here would mean building
        // a matmul-bearing graph, setting an extent hint, and checking the
        // outputs are still correct; that needs matmul builders, which are
        // not pulled in at this layer.
        //
        // NOTE(review): per the original author's note, the safety net is
        // the `if !all(safe) return false` guard inside
        // execute_thunks_active plus the
        // `if !active_used { execute_thunks(...) }` fallback in the CPU
        // executor. Neither is visible from this file — confirm they are
        // covered by tests elsewhere.
    }
1356
1357    #[test]
1358    fn run_padded_uses_active_extent_on_cpu() {
1359        // End-to-end: the cache wires set_active_extent before run.
1360        // Same setup as above but driven through run_padded.
1361        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
1362        let input: Vec<f32> = vec![
1363            1.0, -1.0, 2.0, -2.0, 3.0, // 5 real values
1364            -10.0, -20.0, -30.0, -40.0, -50.0, // padding zeros from pad_rows
1365        ];
1366        // pad_rows zero-pads from len=5 up to upper=15, so the arena
1367        // tail past index 5 is 0.0 going in. After active-extent run,
1368        // tail stays at 0.0 (untouched, but the value happens to match
1369        // what relu would produce). We can't observe skip via output
1370        // here — slice_rows trims to actual_rows anyway.
1371        let (upper, outs) = cache
1372            .run_padded(
1373                5,
1374                5,
1375                |max| tiny_graph(max as usize),
1376                &[("x", &input[..5], 1)],
1377                &[1],
1378            )
1379            .unwrap();
1380        assert_eq!(upper, 15);
1381        assert_eq!(outs[0].len(), 5);
1382        // Active-extent path (CPU honors): outputs match relu of the
1383        // first 5 inputs. Slicing already handled, so user-visible
1384        // result is the same whether or not the kernel skipped tail
1385        // compute. The point of this test is just to confirm the wiring
1386        // path doesn't crash and produces correct outputs end-to-end.
1387        assert_eq!(outs[0], vec![1.0, 0.0, 2.0, 0.0, 3.0]);
1388    }
1389
1390    #[test]
1391    fn run_padded_inner_zero_returns_output_unsliced() {
1392        // Marking output_inners[0] = 0 disables slicing for that output.
1393        // The compiled graph still runs at upper=15, so we expect 15 outputs back.
1394        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
1395        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0];
1396
1397        let (upper, outs) = cache
1398            .run_padded(
1399                5,
1400                5,
1401                |max| tiny_graph(max as usize),
1402                &[("x", &input, 1)],
1403                &[0], // don't slice this output
1404            )
1405            .unwrap();
1406
1407        assert_eq!(upper, 15);
1408        assert_eq!(
1409            outs[0].len(),
1410            15,
1411            "unsliced output preserves full upper extent"
1412        );
1413        // First 5 = relu of input, tail 10 = relu(0) = 0.
1414        assert_eq!(&outs[0][..5], &[1.0, 0.0, 2.0, 0.0, 3.0]);
1415        assert!(outs[0][5..].iter().all(|&v| v == 0.0));
1416    }
1417
1418    #[test]
1419    fn dynamic_dim_cache_specializes_per_key() {
1420        use rlx_ir::DType;
1421        use rlx_ir::Shape;
1422        use rlx_ir::hir::HirModule;
1423        use rlx_ir::sym;
1424
1425        let mut cache = DynamicDimCompileCache::new(Device::Cpu, 4);
1426        let opts = crate::CompileOptions::new();
1427        {
1428            let _short = cache
1429                .get_or_specialize(
1430                    8,
1431                    &rlx_ir::DimBinding::batch_seq(1, 8),
1432                    || {
1433                        let mut hir = HirModule::new("dyn_cache");
1434                        let x = hir.input_batch_seq("x", sym::BATCH, sym::SEQ, 4, DType::F32);
1435                        let w = hir.param("w", Shape::new(&[4, 2], DType::F32));
1436                        let y = hir.linear(
1437                            x,
1438                            w,
1439                            None,
1440                            None,
1441                            Shape::batch_seq(sym::BATCH, sym::SEQ, 2, DType::F32),
1442                        );
1443                        hir.set_outputs(vec![y]);
1444                        hir
1445                    },
1446                    &opts,
1447                )
1448                .expect("specialize short");
1449        }
1450        assert!(cache.has_template());
1451        assert_eq!(cache.len(), 1);
1452        cache
1453            .get_or_specialize(
1454                128,
1455                &rlx_ir::DimBinding::batch_seq(1, 128),
1456                || panic!("HIR builder must not run twice"),
1457                &opts,
1458            )
1459            .expect("specialize long");
1460        assert_eq!(cache.len(), 2);
1461    }
1462}