Skip to main content

rlx_runtime/
compile_cache.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shape-bucketed compile cache.
17//!
18//! Lets variable-shape callers (e.g., embedding-model wrappers that vary
19//! batch + seq per request) amortize the per-(shape) compile cost. Cache
20//! keys are caller-provided `u64`s — the caller decides what counts as a
21//! shape bucket. Typical recipe: `(batch as u64) << 32 | seq as u64`.
22//!
23//! The cache stores one `CompiledGraph` per key. Params loaded onto a
24//! cached entry persist for that entry — re-fetching from cache does
25//! **not** require re-running `set_param`. Eviction is FIFO, capped at
26//! `capacity` entries (good enough for the current "a handful of common
27//! shapes" usage pattern; switch to LRU if a real workload shows churn).
28//!
29//! # Example
30//!
31//! ```rust,ignore
32//! let mut cache = CompileCache::new(Device::Metal, 8);
33//! let key = ((batch as u64) << 32) | seq as u64;
34//! let mut compiled = cache.get_or_compile(key, || build_my_graph(batch, seq));
35//! // First call for `key`: compiles. Subsequent calls: cache hit.
36//! compiled.run(&[("x", &input_data)]);
37//! ```
38
39use crate::{CompiledGraph, Device, Session};
40use rlx_ir::DimBinding;
41use rlx_ir::Graph;
42use rlx_ir::hir::HirModule;
43use rlx_opt::CompileResult;
44use std::collections::HashMap;
45use std::collections::VecDeque;
46use std::ops::Range;
47
/// Named runtime input for [`BucketedCompileCache::run_padded_mixed`].
///
/// Holds only borrowed data plus a small option, so it is `Copy`; `Debug`
/// is derived for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct CacheRunInput<'a> {
    /// Input name, passed to the compiled graph alongside the data.
    pub name: &'a str,
    /// Caller-provided `f32` contents for this input.
    pub data: &'a [f32],
    /// Row inner stride for [`pad_rows`]; `None` = use data as-is (no padding).
    pub row_inner: Option<usize>,
}
55
/// Exact-key compile cache: one `CompiledGraph` per caller-provided `u64`
/// key, evicted first-in-first-out once `capacity` entries are held.
pub struct CompileCache {
    // Device handed to every `Session` this cache creates.
    device: Device,
    // Maximum number of entries kept; reaching it triggers eviction of the
    // oldest key before a new entry is stored.
    capacity: usize,
    // Per-cache precision policy. None → default (F32). Set once at
    // construction; applies to every compile this cache performs.
    policy: Option<rlx_opt::PrecisionPolicy>,
    // (key, compiled). Lookups are a linear scan over this Vec; the
    // expected hit-rate at our cap (~8) makes the linear scan cheaper
    // than a HashMap.
    entries: Vec<(u64, CompiledGraph)>,
    // Keys in insertion order; the front is the next key to evict.
    order: VecDeque<u64>,
}
69
70impl CompileCache {
    /// Create a cache for `device` holding at most `capacity` compiled
    /// graphs, with no precision policy override.
    ///
    /// Delegates to [`Self::with_policy`], which panics when `capacity` is 0.
    pub fn new(device: Device, capacity: usize) -> Self {
        Self::with_policy(device, capacity, None)
    }
74
75    /// Cache that compiles every entry with the given precision policy.
76    /// Use this when the cached entries should differ from CPU-default
77    /// F32 — e.g., `PrecisionPolicy::AutoMixed` for f16 compute on Metal.
78    pub fn with_policy(
79        device: Device,
80        capacity: usize,
81        policy: Option<rlx_opt::PrecisionPolicy>,
82    ) -> Self {
83        assert!(capacity > 0, "CompileCache capacity must be ≥ 1");
84        Self {
85            device,
86            capacity,
87            policy,
88            entries: Vec::with_capacity(capacity),
89            order: VecDeque::with_capacity(capacity),
90        }
91    }
92
    /// Return the compiled graph cached under `key`, compiling it first if
    /// the key is not present.
    ///
    /// Compilation uses a freshly constructed `CompileOptions::new()`; see
    /// [`Self::get_or_compile_with_options`] for the full behavior. The borrow
    /// lifetime is tied to `&mut self` so callers naturally serialize their
    /// use of any one entry — the cache is single-owner today.
    pub fn get_or_compile<F: FnOnce() -> Graph>(
        &mut self,
        key: u64,
        build: F,
    ) -> &mut CompiledGraph {
        self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
    }
103
104    /// Like [`Self::get_or_compile`] with explicit [`CompileOptions`].
105    pub fn get_or_compile_with_options<F: FnOnce() -> Graph>(
106        &mut self,
107        key: u64,
108        build: F,
109        options: &crate::CompileOptions,
110    ) -> &mut CompiledGraph {
111        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
112            return &mut self.entries[idx].1;
113        }
114        let mut session = Session::new(self.device);
115        if let Some(p) = &self.policy {
116            session = session.with_policy(p.clone());
117        }
118        let compiled = session.compile_with(build(), options);
119
120        // Evict FIFO if at capacity.
121        if self.entries.len() >= self.capacity
122            && let Some(evict_key) = self.order.pop_front()
123        {
124            sync_evicted_entry(&mut self.entries, evict_key);
125            self.entries.retain(|(k, _)| *k != evict_key);
126        }
127        self.entries.push((key, compiled));
128        self.order.push_back(key);
129        &mut self.entries.last_mut().unwrap().1
130    }
131
132    /// Like [`Self::get_or_compile_with_options`] but builds and compiles HIR directly.
133    pub fn get_or_compile_hir_with_options<F: FnOnce() -> rlx_ir::hir::HirModule>(
134        &mut self,
135        key: u64,
136        build: F,
137        options: &crate::CompileOptions,
138    ) -> &mut CompiledGraph {
139        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
140            return &mut self.entries[idx].1;
141        }
142        let mut session = Session::new(self.device);
143        if let Some(p) = &self.policy {
144            session = session.with_policy(p.clone());
145        }
146        let compiled = session
147            .compile_hir_with(build(), options)
148            .expect("HIR lower/compile in compile cache");
149
150        if self.entries.len() >= self.capacity
151            && let Some(evict_key) = self.order.pop_front()
152        {
153            sync_evicted_entry(&mut self.entries, evict_key);
154            self.entries.retain(|(k, _)| *k != evict_key);
155        }
156        self.entries.push((key, compiled));
157        self.order.push_back(key);
158        &mut self.entries.last_mut().unwrap().1
159    }
160
    /// Number of entries currently cached. Useful for tests + diagnostics.
    /// Never exceeds the capacity given at construction.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
    /// `true` when no compiled graph is currently cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
168    /// Was this key already compiled? Doesn't change recency.
169    pub fn contains(&self, key: u64) -> bool {
170        self.entries.iter().any(|(k, _)| *k == key)
171    }
172
173    /// Drop all cached compiled graphs (free weight params).
174    pub fn clear(&mut self) {
175        self.sync_all();
176        self.entries.clear();
177        self.order.clear();
178    }
179
180    /// Drain in-flight GPU work on every cached entry (Metal `commit_no_wait` paths).
181    pub fn sync_all(&mut self) {
182        for (_, compiled) in &mut self.entries {
183            compiled.sync_pending();
184        }
185    }
186}
187
188fn sync_evicted_entry(entries: &mut [(u64, CompiledGraph)], evict_key: u64) {
189    if let Some((_, compiled)) = entries.iter_mut().find(|(k, _)| *k == evict_key) {
190        compiled.sync_pending();
191    }
192}
193
194// ── Bucketed cache (PLAN L1) ──────────────────────────────────────────
195//
196// Variant of `CompileCache` that compiles one `CompiledGraph` per shape
197// *range* instead of per exact key. The caller declares buckets up front
198// (e.g. `1..16`, `16..64`, `64..256`); each bucket is compiled lazily at
199// its upper bound the first time a key in that bucket arrives.
200//
// Trade-off vs `CompileCache`: instead of one compile per unique key,
// there is one compile per bucket. The compiled graph is specialized for
// each bucket's upper-bound dim. Two ways to use it:
204//
205// **Manual padding** — caller drives the pad/slice cycle:
206// ```rust,ignore
207// let buckets = vec![1..16, 16..64, 64..256];
208// let mut cache = BucketedCompileCache::new(Device::Metal, buckets);
209// let (upper, compiled) = cache
210//     .get_or_compile(seq as u64, |max_seq| build_graph(max_seq as usize))
211//     .expect("seq within buckets");
// // pad input to `upper as usize` rows before run
213// compiled.run(&[("x", &padded)]);
214// ```
215//
216// **`run_padded` shortcut** — cache pads and slices for you:
217// ```rust,ignore
218// let (upper, outputs) = cache.run_padded(
219//     seq as u64,
220//     seq,                                    // actual rows
221//     |max_seq| build_graph(max_seq as usize),
222//     &[("x", &raw_input, hidden)],           // (name, data, inner stride)
223//     &[hidden],                              // per-output inner stride
224// ).expect("in range");
225// ```
226//
227// **How "skip compute" actually works here**: each bucket compiles at
228// its own upper bound, so kernels run at *that* extent, not at some
229// global maximum. Smaller buckets ⇒ less padded compute. The
230// `power_of_two_ladder` constructor builds a logarithmic schedule that
231// guarantees ≤2× padding waste in exchange for `O(log max)` compiled
232// artifacts. For finer control, hand-construct the bucket list.
233//
234// True per-kernel active-extent dispatch (one big compile, runtime
235// extent override that short-circuits each kernel's inner loop) is a
236// per-backend change across `rlx-cuda`, `rlx-rocm`,
237// `rlx-cpu/src/thunk.rs`, `rlx-metal/src/thunk.rs`, `rlx-mlx`,
238// `rlx-wgpu` — multi-day project, not in this layer.
239
/// Compile cache keyed by shape *range*: each declared bucket lazily holds
/// one `CompiledGraph`, built for the bucket's upper bound (`range.end - 1`).
pub struct BucketedCompileCache {
    // Device handed to every `Session` this cache creates.
    device: Device,
    // Optional precision policy applied to each session before compiling.
    policy: Option<rlx_opt::PrecisionPolicy>,
    // Buckets in the order the caller declared them (validated as
    // non-empty and non-overlapping at construction).
    buckets: Vec<Bucket>,
}
245
// One key range plus its lazily compiled graph.
struct Bucket {
    // Half-open range of keys served by this bucket.
    range: Range<u64>,
    // `None` until a key in `range` triggers a compile, and again after
    // the compiled graph is dropped by an eviction/clear call.
    compiled: Option<CompiledGraph>,
}
250
251impl BucketedCompileCache {
    /// Create a bucketed cache over `buckets` with no precision policy.
    ///
    /// Delegates to [`Self::with_policy`], which validates the bucket list
    /// and panics if it is empty, contains an empty range, or has adjacent
    /// ranges that overlap or are out of order.
    pub fn new(device: Device, buckets: Vec<Range<u64>>) -> Self {
        Self::with_policy(device, buckets, None)
    }
255
    /// Power-of-two ladder over `[1, max]`, with extents
    /// `[min_pow2, 2·min_pow2, 4·min_pow2, …, max_pow2]` where
    /// `min_pow2 = min.next_power_of_two()` and `max_pow2` is the smallest
    /// power of two ≥ `max`. Each bucket compiles at its upper-bound
    /// extent, so an `actual` value in bucket `(prev_extent .. ext]` runs
    /// kernels at extent `ext` (not at the worst case of the whole range).
    /// Guarantees compute waste from padding ≤2× — `actual > ext / 2`
    /// for every bucket except possibly the smallest.
    ///
    /// Example: `power_of_two_ladder(Device::Cpu, 8, 256)` yields buckets
    /// `1..9, 9..17, 17..33, 33..65, 65..129, 129..257` with compile
    /// extents `8, 16, 32, 64, 128, 256`. An `actual = 17` runs at extent
    /// 32 instead of the 255 a single wide `1..256` bucket would compile
    /// at — that's the "skip compute" win, paid for with `O(log max)`
    /// compiled artifacts instead of one.
    ///
    /// # Panics
    ///
    /// Panics (via [`Self::power_of_two_ladder_with_policy`]) if `min` is 0,
    /// if `max < min`, or if doubling the extent overflows `u64`.
    pub fn power_of_two_ladder(device: Device, min: u64, max: u64) -> Self {
        Self::power_of_two_ladder_with_policy(device, min, max, None)
    }
274
275    pub fn power_of_two_ladder_with_policy(
276        device: Device,
277        min: u64,
278        max: u64,
279        policy: Option<rlx_opt::PrecisionPolicy>,
280    ) -> Self {
281        assert!(min >= 1, "power_of_two_ladder: min must be ≥ 1, got {min}");
282        assert!(
283            max >= min,
284            "power_of_two_ladder: max ({max}) must be ≥ min ({min})"
285        );
286        let mut buckets: Vec<Range<u64>> = Vec::new();
287        let mut start = 1u64;
288        let mut extent = min.next_power_of_two();
289        loop {
290            buckets.push(start..(extent + 1));
291            if extent >= max {
292                break;
293            }
294            start = extent + 1;
295            extent = extent
296                .checked_mul(2)
297                .expect("power_of_two_ladder: extent overflow");
298        }
299        Self::with_policy(device, buckets, policy)
300    }
301
302    pub fn with_policy(
303        device: Device,
304        buckets: Vec<Range<u64>>,
305        policy: Option<rlx_opt::PrecisionPolicy>,
306    ) -> Self {
307        assert!(!buckets.is_empty(), "BucketedCompileCache needs ≥1 bucket");
308        for (i, b) in buckets.iter().enumerate() {
309            assert!(b.start < b.end, "bucket {i} ({b:?}) is empty");
310            if i + 1 < buckets.len() {
311                assert!(
312                    b.end <= buckets[i + 1].start,
313                    "buckets {i} ({b:?}) and {} ({:?}) overlap",
314                    i + 1,
315                    buckets[i + 1],
316                );
317            }
318        }
319        let buckets = buckets
320            .into_iter()
321            .map(|range| Bucket {
322                range,
323                compiled: None,
324            })
325            .collect();
326        Self {
327            device,
328            policy,
329            buckets,
330        }
331    }
332
    /// Find the bucket containing `key`, compile if needed, return
    /// `(upper, &mut CompiledGraph)` where `upper = range.end - 1` is the
    /// extent the graph was compiled for. Caller pads inputs to `upper`
    /// before calling `run`. Returns `None` if `key` is outside every
    /// bucket — caller decides whether to fall back to a one-off compile.
    ///
    /// `build` receives `upper` and must return a `Graph` specialized for
    /// that extent. It is only called when the bucket has no compiled graph
    /// yet. Compilation uses a freshly constructed `CompileOptions::new()`.
    pub fn get_or_compile<F: FnOnce(u64) -> Graph>(
        &mut self,
        key: u64,
        build: F,
    ) -> Option<(u64, &mut CompiledGraph)> {
        self.get_or_compile_with_options(key, build, &crate::CompileOptions::new())
    }
348
349    /// Like [`Self::get_or_compile`] with explicit [`CompileOptions`].
350    pub fn get_or_compile_with_options<F: FnOnce(u64) -> Graph>(
351        &mut self,
352        key: u64,
353        build: F,
354        options: &crate::CompileOptions,
355    ) -> Option<(u64, &mut CompiledGraph)> {
356        let idx = self.bucket_for(key)?;
357        let upper = self.buckets[idx].range.end - 1;
358        if self.buckets[idx].compiled.is_none() {
359            let mut session = Session::new(self.device);
360            if let Some(p) = &self.policy {
361                session = session.with_policy(p.clone());
362            }
363            self.buckets[idx].compiled = Some(session.compile_with(build(upper), options));
364        }
365        Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
366    }
367
    /// Like [`Self::get_or_compile`] but builds and compiles HIR directly
    /// through the fusion-first pipeline (`Session::compile_hir`).
    ///
    /// Delegates to [`Self::get_or_compile_hir_with_options`] with a freshly
    /// constructed `CompileOptions::new()`; that method panics if the HIR
    /// compile returns an error.
    pub fn get_or_compile_hir<F: FnOnce(u64) -> HirModule>(
        &mut self,
        key: u64,
        build: F,
    ) -> Option<(u64, &mut CompiledGraph)> {
        self.get_or_compile_hir_with_options(key, build, &crate::CompileOptions::new())
    }
377
378    /// Like [`Self::get_or_compile_hir`] with explicit [`CompileOptions`] (tier-1 profile, fusion target, …).
379    pub fn get_or_compile_hir_with_options<F: FnOnce(u64) -> HirModule>(
380        &mut self,
381        key: u64,
382        build: F,
383        options: &crate::CompileOptions,
384    ) -> Option<(u64, &mut CompiledGraph)> {
385        let idx = self.bucket_for(key)?;
386        let upper = self.buckets[idx].range.end - 1;
387        if self.buckets[idx].compiled.is_none() {
388            let mut session = Session::new(self.device);
389            if let Some(p) = &self.policy {
390                session = session.with_policy(p.clone());
391            }
392            let compiled = session
393                .compile_hir_with(build(upper), options)
394                .expect("HIR lower/compile in bucketed cache");
395            self.buckets[idx].compiled = Some(compiled);
396        }
397        Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
398    }
399
400    /// Index of the bucket containing `key`, or `None` if out of range.
401    /// Linear scan — bucket counts are small in practice.
402    pub fn bucket_for(&self, key: u64) -> Option<usize> {
403        self.buckets.iter().position(|b| b.range.contains(&key))
404    }
405
406    /// Upper compile extent for `key`'s bucket (`range.end - 1`), without compiling.
407    pub fn bucket_upper_for_key(&self, key: u64) -> Option<u64> {
408        let idx = self.bucket_for(key)?;
409        Some(self.buckets[idx].range.end - 1)
410    }
411
    /// Iterate over the bucket ranges in the order they are stored.
    pub fn buckets(&self) -> impl Iterator<Item = &Range<u64>> {
        self.buckets.iter().map(|b| &b.range)
    }
415
416    /// Number of buckets that have been compiled so far (≤ total buckets).
417    pub fn compiled_count(&self) -> usize {
418        self.buckets.iter().filter(|b| b.compiled.is_some()).count()
419    }
420
421    /// Mutable compiled graph for `key`'s bucket, if already compiled.
422    pub fn compiled_for_key_mut(&mut self, key: u64) -> Option<&mut CompiledGraph> {
423        let idx = self.bucket_for(key)?;
424        self.buckets[idx].compiled.as_mut()
425    }
426
    /// Total number of declared buckets, compiled or not.
    pub fn total_buckets(&self) -> usize {
        self.buckets.len()
    }
430
431    /// Drop compiled graphs for all buckets except `keep`, freeing weight params.
432    pub fn evict_except(&mut self, keep: usize) {
433        for (i, bucket) in self.buckets.iter_mut().enumerate() {
434            if i != keep {
435                bucket.compiled = None;
436            }
437        }
438    }
439
440    /// Drop all compiled graphs (free param storage).
441    pub fn clear_compiled(&mut self) {
442        for bucket in &mut self.buckets {
443            bucket.compiled = None;
444        }
445    }
446
447    /// "Compile at max, run at less" convenience for inputs and outputs
448    /// whose outer dimension is the bucket key:
449    ///
450    /// 1. Find or compile the bucket containing `key`.
451    /// 2. For each input, pad to `upper` rows along the outer dim using
452    ///    `pad_rows` (caller passes the inner-dim stride per input;
453    ///    `inner = 1` for purely 1D inputs).
454    /// 3. Run the compiled graph at full extent.
455    /// 4. Slice each output back to `actual_rows` along its outer dim.
456    ///    Outputs flagged with `inner = 0` in `output_inners` are
457    ///    returned unsliced (use this for extent-independent outputs
458    ///    like a pooled `[hidden]` embedding). Missing entries past
459    ///    the end of `output_inners` are also returned unsliced.
460    ///
461    /// Returns `(upper, outputs)`. Returns `None` if `key` falls outside
462    /// every bucket.
463    ///
464    /// **Compute scope:** kernels execute at the bucket's compile
465    /// extent (`upper`), not at `actual_rows`. This means smaller
466    /// buckets directly translate to less padded compute. With
467    /// [`power_of_two_ladder`](Self::power_of_two_ladder) the worst-
468    /// case waste is bounded at 2×; with hand-tuned buckets it can be
469    /// arbitrarily tight. True active-extent dispatch — one big
470    /// compile, kernels short-circuit at runtime — is a separate
471    /// per-backend change.
472    pub fn run_padded<F: FnOnce(u64) -> Graph>(
473        &mut self,
474        key: u64,
475        actual_rows: usize,
476        build: F,
477        inputs: &[(&str, &[f32], usize)],
478        output_inners: &[usize],
479    ) -> Option<(u64, Vec<Vec<f32>>)> {
480        let (upper, compiled) = self.get_or_compile(key, build)?;
481
482        // Own the padded buffers so they outlive the borrow handed to `run`.
483        let padded: Vec<(&str, Vec<f32>)> = inputs
484            .iter()
485            .map(|(name, data, inner)| (*name, pad_rows(data, *inner, upper)))
486            .collect();
487        let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
488
489        // Hint active-extent: backends that support per-kernel skip-
490        // compute (today: CPU's Activation thunk family) honor it; the
491        // default trait impl is a no-op, so other backends just process
492        // full extent and the slice_rows below still gives the user
493        // correct outputs.
494        compiled.set_active_extent(Some((actual_rows, upper as usize)));
495        let raw_outputs = compiled.run(&pairs);
496        compiled.set_active_extent(None);
497        #[cfg(feature = "cpu")]
498        crate::onnx_active::set_active_token_count(None);
499
500        let outs = raw_outputs
501            .into_iter()
502            .enumerate()
503            .map(|(i, out)| match output_inners.get(i).copied() {
504                Some(0) | None => out,
505                Some(inner) => slice_rows(&out, inner, actual_rows),
506            })
507            .collect();
508
509        Some((upper, outs))
510    }
511
512    /// Like [`Self::get_or_compile_with_options`] but also uploads `params` on first compile.
513    pub fn ensure_graph_with_params<F>(
514        &mut self,
515        key: u64,
516        build: F,
517        options: &crate::CompileOptions,
518    ) -> Option<(u64, &mut CompiledGraph)>
519    where
520        F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
521    {
522        let idx = self.bucket_for(key)?;
523        let upper = self.buckets[idx].range.end - 1;
524        if self.buckets[idx].compiled.is_none() {
525            let (graph, params) = build(upper);
526            let mut session = Session::new(self.device);
527            if let Some(p) = &self.policy {
528                session = session.with_policy(p.clone());
529            }
530            let mut compiled = session.compile_with(graph, options);
531            for (name, data) in params {
532                compiled.set_param(&name, &data);
533            }
534            self.buckets[idx].compiled = Some(compiled);
535        }
536        Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
537    }
538
539    /// HIR variant of [`Self::ensure_graph_with_params`].
540    pub fn ensure_hir_with_params<F>(
541        &mut self,
542        key: u64,
543        build: F,
544        options: &crate::CompileOptions,
545    ) -> Option<(u64, &mut CompiledGraph)>
546    where
547        F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
548    {
549        let idx = self.bucket_for(key)?;
550        let upper = self.buckets[idx].range.end - 1;
551        if self.buckets[idx].compiled.is_none() {
552            let (hir, params) = build(upper);
553            let mut session = Session::new(self.device);
554            if let Some(p) = &self.policy {
555                session = session.with_policy(p.clone());
556            }
557            let mut compiled = session
558                .compile_hir_with(hir, options)
559                .expect("HIR lower/compile in ensure_hir_with_params");
560            for (name, data) in params {
561                compiled.set_param(&name, &data);
562            }
563            self.buckets[idx].compiled = Some(compiled);
564        }
565        Some((upper, self.buckets[idx].compiled.as_mut().unwrap()))
566    }
567
568    /// [`Self::run_padded`] with per-input optional row padding (`CacheRunInput`).
569    pub fn run_padded_mixed<F>(
570        &mut self,
571        key: u64,
572        actual_rows: usize,
573        build: F,
574        inputs: &[CacheRunInput<'_>],
575        output_inners: &[usize],
576    ) -> Option<(u64, Vec<Vec<f32>>)>
577    where
578        F: FnOnce(u64) -> Graph,
579    {
580        let (upper, compiled) = self.get_or_compile(key, build)?;
581
582        let padded: Vec<(&str, Vec<f32>)> = inputs
583            .iter()
584            .map(|inp| match inp.row_inner {
585                Some(inner) => (inp.name, pad_rows(inp.data, inner, upper)),
586                None => (inp.name, inp.data.to_vec()),
587            })
588            .collect();
589        let pairs: Vec<(&str, &[f32])> = padded.iter().map(|(n, d)| (*n, d.as_slice())).collect();
590
591        compiled.set_active_extent(Some((actual_rows, upper as usize)));
592        let raw_outputs = compiled.run(&pairs);
593        compiled.set_active_extent(None);
594        #[cfg(feature = "cpu")]
595        crate::onnx_active::set_active_token_count(None);
596
597        let outs = raw_outputs
598            .into_iter()
599            .enumerate()
600            .map(|(i, out)| match output_inners.get(i).copied() {
601                Some(0) | None => out,
602                Some(inner) => slice_rows(&out, inner, actual_rows),
603            })
604            .collect();
605
606        Some((upper, outs))
607    }
608
609    /// Drain in-flight GPU work on every compiled bucket.
610    pub fn sync_all(&mut self) {
611        for bucket in &mut self.buckets {
612            if let Some(compiled) = &mut bucket.compiled {
613                compiled.sync_pending();
614            }
615        }
616    }
617}
618
619// ── Dynamic-dim cache (plan #54) ──────────────────────────────────────
620//
621// Compile HIR once through the fusion pipeline (graph may contain
622// `Dim::Dynamic` symbols), then specialize to concrete shapes per cache
623// key and backend-compile the resulting LIR.
624
/// Compile-once / specialize-at-runtime cache for symbolic HIR modules.
///
/// Holds one optional template (the compiled symbolic HIR) plus up to
/// `capacity` backend-compiled specializations, evicted first-in-first-out.
pub struct DynamicDimCompileCache {
    // Device used for pipelines, backend lookup and compiled graphs.
    device: Device,
    // Cache-level precision policy; applied to the backend compile only
    // when the caller's options carry no policy of their own.
    policy: Option<rlx_opt::PrecisionPolicy>,
    // Maximum number of specialized entries kept before evicting.
    capacity: usize,
    // Result of compiling the symbolic HIR; `None` until first needed.
    template: Option<CompileResult>,
    // (key, specialized compiled graph), scanned linearly on lookup.
    entries: Vec<(u64, CompiledGraph)>,
    // Keys in insertion order; the front is the next key to evict.
    order: VecDeque<u64>,
}
634
635impl DynamicDimCompileCache {
    /// Create a cache for `device` holding at most `capacity` specialized
    /// graphs, with no cache-level precision policy.
    ///
    /// Delegates to [`Self::with_policy`], which panics when `capacity` is 0.
    pub fn new(device: Device, capacity: usize) -> Self {
        Self::with_policy(device, capacity, None)
    }
639
640    pub fn with_policy(
641        device: Device,
642        capacity: usize,
643        policy: Option<rlx_opt::PrecisionPolicy>,
644    ) -> Self {
645        assert!(capacity > 0, "DynamicDimCompileCache capacity must be ≥ 1");
646        Self {
647            device,
648            policy,
649            capacity,
650            template: None,
651            entries: Vec::with_capacity(capacity),
652            order: VecDeque::with_capacity(capacity),
653        }
654    }
655
    /// Device this cache was constructed with.
    pub fn compile_device(&self) -> Device {
        self.device
    }
659
660    /// Return a backend-compiled graph specialized for `binding`.
661    /// `build_hir` runs at most once to populate the dynamic template.
662    pub fn get_or_specialize<F: FnOnce() -> HirModule>(
663        &mut self,
664        key: u64,
665        binding: &DimBinding,
666        build_hir: F,
667        options: &crate::CompileOptions,
668    ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError> {
669        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
670            return Ok(&mut self.entries[idx].1);
671        }
672        if self.template.is_none() {
673            let mut template_opts = options.clone();
674            template_opts.dim_binding = None;
675            let pipe = crate::stages::pipeline_for(self.device, &template_opts);
676            self.template = Some(pipe.compile_hir(build_hir())?);
677        }
678        let template = self.template.as_ref().expect("template just set");
679        let mut spec_opts = options.clone();
680        spec_opts.dim_binding = None;
681        let pipe = crate::stages::pipeline_for(self.device, &spec_opts);
682        let specialized = template.specialize(&pipe, binding);
683        let backend = crate::registry::backend_for(self.device).expect("backend registered");
684        let mut compile_opts = options.clone();
685        compile_opts.dim_binding = None;
686        if compile_opts.policy.is_none() {
687            if let Some(p) = &self.policy {
688                compile_opts = compile_opts.policy(p.clone());
689            }
690        }
691        let executable = backend.compile_lir(specialized.lir, &compile_opts);
692        let compiled = CompiledGraph::new(executable, self.device);
693
694        if self.entries.len() >= self.capacity
695            && let Some(evict_key) = self.order.pop_front()
696        {
697            sync_evicted_entry(&mut self.entries, evict_key);
698            self.entries.retain(|(k, _)| *k != evict_key);
699        }
700        self.entries.push((key, compiled));
701        self.order.push_back(key);
702        Ok(&mut self.entries.last_mut().unwrap().1)
703    }
704
    /// Number of specialized entries currently cached (the template is not
    /// counted).
    pub fn len(&self) -> usize {
        self.entries.len()
    }
708
    /// `true` when no specialized entry is cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
712
713    pub fn contains(&self, key: u64) -> bool {
714        self.entries.iter().any(|(k, _)| *k == key)
715    }
716
717    pub fn clear(&mut self) {
718        self.sync_all();
719        self.template = None;
720        self.entries.clear();
721        self.order.clear();
722    }
723
    /// `true` once the symbolic template has been built (and not cleared).
    pub fn has_template(&self) -> bool {
        self.template.is_some()
    }
727
728    /// Drain in-flight GPU work on every specialized entry.
729    pub fn sync_all(&mut self) {
730        for (_, compiled) in &mut self.entries {
731            compiled.sync_pending();
732        }
733    }
734
735    /// Build the symbolic template once (no specialization).
736    pub fn ensure_template<F: FnOnce() -> HirModule>(
737        &mut self,
738        build_hir: F,
739        options: &crate::CompileOptions,
740    ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
741        if self.template.is_none() {
742            let mut opts = options.clone();
743            opts.dim_binding = None;
744            let pipe = crate::stages::pipeline_for(self.device, &opts);
745            self.template = Some(pipe.compile_hir(build_hir())?);
746        }
747        Ok(self.template.as_ref().expect("template set"))
748    }
749
    /// The compiled symbolic template, if it has been built. Never builds it.
    pub fn template_result(&self) -> Option<&CompileResult> {
        self.template.as_ref()
    }
753
754    /// Specialize via on-disk LIR cache ([`CompilationMode::Aot`]).
755    /// Disk-backed specialize ([`rlx_ir::CompilationMode::Aot`]).
756    pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
757        &mut self,
758        aot: &crate::AotCache,
759        disk_base: &str,
760        key: u64,
761        binding: &rlx_ir::DimBinding,
762        build_hir: F,
763        options: &crate::CompileOptions,
764    ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
765        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
766            return Ok(&mut self.entries[idx].1);
767        }
768        let device = self.device;
769        let template = self.ensure_template(build_hir, options)?;
770        let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
771        if self.entries.len() >= self.capacity
772            && let Some(evict_key) = self.order.pop_front()
773        {
774            sync_evicted_entry(&mut self.entries, evict_key);
775            self.entries.retain(|(k, _)| *k != evict_key);
776        }
777        self.entries.push((key, compiled));
778        self.order.push_back(key);
779        Ok(&mut self.entries.last_mut().unwrap().1)
780    }
781}
782
/// Pad `data` (interpreted as `[actual, inner]` row-major) up to `upper`
/// rows by appending zeros. Returns a `Vec<f32>` of length
/// `upper * inner`. Companion of [`slice_rows`] for the
/// "compile at max, run at less" workflow with [`BucketedCompileCache`].
///
/// # Panics
///
/// Panics if `inner == 0`, if `data.len()` is not a multiple of `inner`,
/// if `data.len() / inner > upper`, if `upper` does not fit in `usize`, or
/// if `upper * inner` overflows `usize`.
pub fn pad_rows(data: &[f32], inner: usize, upper: u64) -> Vec<f32> {
    assert!(inner > 0, "pad_rows: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    // A plain `as` cast would silently truncate on targets where `usize`
    // is narrower than `u64`.
    let upper = usize::try_from(upper).expect("pad_rows: upper bound does not fit in usize");
    let actual = data.len() / inner;
    assert!(
        actual <= upper,
        "pad_rows: actual rows {actual} exceed upper bound {upper}",
    );
    // Fail with a clear message instead of wrapping in release builds.
    let total = upper
        .checked_mul(inner)
        .expect("pad_rows: padded length overflows usize");
    let mut out = vec![0.0_f32; total];
    // `data.len() == actual * inner` because of the multiple-of check above.
    out[..data.len()].copy_from_slice(data);
    out
}
808
/// Write `data` (`[actual, inner]` row-major) into the front of the
/// preallocated `out` buffer (`[upper, inner]`) and zero every element
/// after it. In-place counterpart of [`pad_rows`].
///
/// # Panics
///
/// Panics if `inner == 0`, if either `data.len()` or `out.len()` is not a
/// multiple of `inner`, or if `data` holds more rows than `out`.
pub fn pad_rows_into(out: &mut [f32], data: &[f32], inner: usize) {
    assert!(inner > 0, "pad_rows_into: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "pad_rows_into: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    assert_eq!(
        out.len() % inner,
        0,
        "pad_rows_into: out len {} not a multiple of inner {inner}",
        out.len(),
    );
    let (actual, upper) = (data.len() / inner, out.len() / inner);
    assert!(
        actual <= upper,
        "pad_rows_into: actual rows {actual} exceed upper bound {upper}",
    );
    // The row check above guarantees `data.len() <= out.len()`, so the
    // split point is always in bounds.
    let (head, tail) = out.split_at_mut(data.len());
    head.copy_from_slice(data);
    tail.fill(0.0);
}
833
/// Keep only the first `actual` rows of `data`, which is read as an
/// `[upper, inner]` row-major buffer, and return them as a new `Vec<f32>`.
/// Companion of [`pad_rows`].
///
/// # Panics
///
/// Panics if `inner == 0`, if `data.len()` is not a multiple of `inner`,
/// or if `actual` exceeds the number of rows in `data`.
pub fn slice_rows(data: &[f32], inner: usize, actual: usize) -> Vec<f32> {
    assert!(inner > 0, "slice_rows: inner stride must be ≥ 1");
    assert_eq!(
        data.len() % inner,
        0,
        "slice_rows: data len {} not a multiple of inner {inner}",
        data.len(),
    );
    let upper = data.len() / inner;
    assert!(
        actual <= upper,
        "slice_rows: actual rows {actual} exceed upper {upper}",
    );
    // `actual <= upper` bounds the split point by `data.len()`.
    let (kept, _padding) = data.split_at(actual * inner);
    Vec::from(kept)
}
854
// Unit tests for the compile caches and the `pad_rows` / `slice_rows`
// helpers. Every graph here is compiled for `Device::Cpu`.
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::*;
    use std::cell::Cell;

    /// Build the 1-D `[n]` graph `x → Relu → output` shared by most tests.
    fn tiny_graph(n: usize) -> Graph {
        let mut g = Graph::new("t");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[n], f));
        let y = g.activation(rlx_ir::op::Activation::Relu, x, Shape::new(&[n], f));
        g.set_outputs(vec![y]);
        g
    }

    #[test]
    fn cache_hits_avoid_recompile() {
        let mut cache = CompileCache::new(Device::Cpu, 4);
        let calls = Cell::new(0);

        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        // Same key three times: build closure runs once.
        assert_eq!(calls.get(), 1);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn fifo_evicts_oldest_at_capacity() {
        let mut cache = CompileCache::new(Device::Cpu, 2);
        let _ = cache.get_or_compile(1, || tiny_graph(4));
        let _ = cache.get_or_compile(2, || tiny_graph(8));
        assert!(cache.contains(1) && cache.contains(2));
        // Third entry evicts key 1 (oldest).
        let _ = cache.get_or_compile(3, || tiny_graph(16));
        assert!(!cache.contains(1));
        assert!(cache.contains(2) && cache.contains(3));
    }

    #[test]
    fn different_keys_keep_separate_compiles() {
        let mut cache = CompileCache::new(Device::Cpu, 4);
        let calls = Cell::new(0);
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        let _ = cache.get_or_compile(2, || {
            calls.set(calls.get() + 1);
            tiny_graph(16)
        });
        let _ = cache.get_or_compile(1, || {
            calls.set(calls.get() + 1);
            tiny_graph(8)
        });
        // Two unique keys → two compiles.
        assert_eq!(calls.get(), 2);
        assert_eq!(cache.len(), 2);
    }

    // ── BucketedCompileCache ──────────────────────────────────────────

    #[test]
    fn bucket_amortizes_keys_within_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
        let calls = Cell::new(0);
        let uppers = Cell::new((0u64, 0u64));

        // Two distinct keys (2 and 3) both fall inside bucket 0 (1..4).
        let (u1, _) = cache
            .get_or_compile(2, |upper| {
                calls.set(calls.get() + 1);
                uppers.set((upper, uppers.get().1));
                tiny_graph(upper as usize)
            })
            .expect("key 2 in range");
        let (u2, _) = cache
            .get_or_compile(3, |upper| {
                calls.set(calls.get() + 1);
                uppers.set((uppers.get().0, upper));
                tiny_graph(upper as usize)
            })
            .expect("key 3 in range");

        // One compile, both calls saw the same upper = range.end - 1 = 3.
        assert_eq!(calls.get(), 1);
        assert_eq!(u1, 3);
        assert_eq!(u2, 3);
        // Only the first closure ran, so the second slot keeps its initial 0.
        assert_eq!(uppers.get(), (3, 0));
        assert_eq!(cache.compiled_count(), 1);
        assert_eq!(cache.total_buckets(), 2);
    }

    #[test]
    fn bucket_lookup_returns_none_outside_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16]);
        assert!(cache.bucket_for(0).is_none());
        assert!(cache.bucket_for(16).is_none());
        assert!(cache.bucket_for(100).is_none());
        assert_eq!(cache.bucket_for(3), Some(0));
        assert_eq!(cache.bucket_for(4), Some(1));
        assert_eq!(cache.bucket_upper_for_key(3), Some(3));
        assert_eq!(cache.bucket_upper_for_key(4), Some(15));
        assert!(cache.bucket_upper_for_key(0).is_none());

        let calls = Cell::new(0);
        let result = cache.get_or_compile(100, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        assert!(result.is_none());
        assert_eq!(calls.get(), 0); // build closure must not run for OOR keys
        assert_eq!(cache.compiled_count(), 0);
    }

    #[test]
    fn bucket_compiles_lazily_per_bucket() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..4, 4..16, 16..64]);
        let calls = Cell::new(0);

        let _ = cache.get_or_compile(2, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        let _ = cache.get_or_compile(8, |u| {
            calls.set(calls.get() + 1);
            tiny_graph(u as usize)
        });
        // Two distinct buckets hit → two compiles. Third bucket untouched.
        assert_eq!(calls.get(), 2);
        assert_eq!(cache.compiled_count(), 2);
        assert_eq!(cache.total_buckets(), 3);
    }

    #[test]
    #[should_panic(expected = "overlap")]
    fn bucket_overlap_rejected() {
        let _ = BucketedCompileCache::new(Device::Cpu, vec![1..8, 4..16]);
    }

    #[test]
    #[should_panic(expected = "≥1 bucket")]
    fn empty_bucket_list_rejected() {
        let _ = BucketedCompileCache::new(Device::Cpu, vec![]);
    }

    // ── pad_rows / slice_rows ─────────────────────────────────────────

    #[test]
    fn pad_rows_appends_zeros() {
        // 1D: actual=3 → upper=5, inner=1.
        let p = pad_rows(&[1.0, 2.0, 3.0], 1, 5);
        assert_eq!(p, vec![1.0, 2.0, 3.0, 0.0, 0.0]);

        // 2D row-major [actual=2, inner=3] → [upper=4, inner=3].
        let p = pad_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 4);
        assert_eq!(
            p,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        );

        // actual == upper: no-op pad.
        let p = pad_rows(&[7.0, 8.0], 1, 2);
        assert_eq!(p, vec![7.0, 8.0]);
    }

    #[test]
    fn slice_rows_truncates_trailing() {
        let s = slice_rows(&[1.0, 2.0, 3.0, 0.0, 0.0], 1, 3);
        assert_eq!(s, vec![1.0, 2.0, 3.0]);

        let s = slice_rows(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0], 3, 2);
        assert_eq!(s, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "exceed upper")]
    fn pad_rows_rejects_too_long_input() {
        let _ = pad_rows(&[1.0, 2.0, 3.0, 4.0], 1, 3);
    }

    #[test]
    #[should_panic(expected = "exceed upper")]
    fn slice_rows_rejects_too_large_actual() {
        let _ = slice_rows(&[1.0, 2.0, 3.0], 1, 5);
    }

    // ── BucketedCompileCache::run_padded ──────────────────────────────

    #[test]
    fn run_padded_pads_input_and_slices_output() {
        // tiny_graph is 1D [n] → relu → [n].
        // Compile bucket [1..16) at upper=15, run with actual_rows=10.
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0];

        let (upper, outs) = cache
            .run_padded(
                10, // key
                10, // actual rows
                |max| tiny_graph(max as usize),
                &[("x", &input, 1)], // 1D, inner stride 1
                &[1],                // slice the one output to actual rows
            )
            .expect("key 10 in [1..16)");

        assert_eq!(upper, 15);
        assert_eq!(outs.len(), 1);
        let out = &outs[0];
        assert_eq!(out.len(), 10, "output sliced back to actual_rows");
        let expected: Vec<f32> = input.iter().map(|x| x.max(0.0)).collect();
        assert_eq!(out, &expected);
    }

    #[test]
    fn run_padded_reuses_bucket_across_actuals() {
        // Same bucket, two different actuals — only one compile.
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let calls = Cell::new(0);

        let (u1, o1) = cache
            .run_padded(
                10,
                10,
                |max| {
                    calls.set(calls.get() + 1);
                    tiny_graph(max as usize)
                },
                &[(
                    "x",
                    &[1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0],
                    1,
                )],
                &[1],
            )
            .unwrap();
        assert_eq!(o1.len(), 1);
        assert_eq!(o1[0].len(), 10);
        assert_eq!(u1, 15);

        let (u2, o2) = cache
            .run_padded(
                5,
                5,
                |max| {
                    calls.set(calls.get() + 1);
                    tiny_graph(max as usize)
                },
                &[("x", &[-1.0, 2.0, -3.0, 4.0, -5.0], 1)],
                &[1],
            )
            .unwrap();
        assert_eq!(o2.len(), 1);
        assert_eq!(o2[0].len(), 5);
        assert_eq!(u2, 15);
        assert_eq!(o2[0], vec![0.0, 2.0, 0.0, 4.0, 0.0]);

        assert_eq!(calls.get(), 1, "bucket cached across actuals");
        assert_eq!(cache.compiled_count(), 1);
    }

    #[test]
    fn run_padded_returns_none_out_of_range() {
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let calls = Cell::new(0);
        let result = cache.run_padded(
            100,
            5,
            |u| {
                calls.set(calls.get() + 1);
                tiny_graph(u as usize)
            },
            &[("x", &[1.0, 2.0, 3.0, 4.0, 5.0], 1)],
            &[1],
        );
        assert!(result.is_none());
        assert_eq!(calls.get(), 0);
        assert_eq!(cache.compiled_count(), 0);
    }

    // ── power_of_two_ladder ───────────────────────────────────────────

    #[test]
    fn power_of_two_ladder_generates_log_buckets() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
        // Expect buckets covering keys 1..=64 with extents 8, 16, 32, 64.
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..9, 9..17, 17..33, 33..65]);
        assert_eq!(cache.total_buckets(), 4);
    }

    #[test]
    fn power_of_two_ladder_picks_smallest_extent_for_actual() {
        // Ladder: extents 8, 16, 32, 64. actual=17 lands in the 32-extent
        // bucket, NOT the 64-extent one — that's the compute saving.
        let mut cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 64);
        let captured_uppers: std::cell::RefCell<Vec<u64>> = Default::default();

        let (u17, _) = cache
            .get_or_compile(17, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u9, _) = cache
            .get_or_compile(9, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u3, _) = cache
            .get_or_compile(3, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();
        let (u64_, _) = cache
            .get_or_compile(64, |upper| {
                captured_uppers.borrow_mut().push(upper);
                tiny_graph(upper as usize)
            })
            .unwrap();

        assert_eq!(u17, 32, "key=17 → smallest extent ≥ 17 is 32");
        assert_eq!(u9, 16, "key=9  → smallest extent ≥ 9  is 16");
        assert_eq!(u3, 8, "key=3  → smallest extent ≥ 3  is 8");
        assert_eq!(u64_, 64, "key=64 → exact match at 64");
        assert_eq!(*captured_uppers.borrow(), vec![32, 16, 8, 64]);
        assert_eq!(cache.compiled_count(), 4);
    }

    #[test]
    fn power_of_two_ladder_min_above_one_starts_at_one() {
        // First bucket always covers from key 1, even when min > 1.
        // (`min` controls the ladder's first extent, not the lower edge.)
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 32);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        // min=16 → first extent 16, second 32. Buckets: 1..17, 17..33.
        assert_eq!(ranges, vec![1..17, 17..33]);
    }

    #[test]
    fn power_of_two_ladder_non_pow2_min_rounds_up() {
        // min=10 → next_power_of_two = 16.
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 10, 64);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17, 17..33, 33..65]);
    }

    #[test]
    fn power_of_two_ladder_max_below_pow2_extends_up() {
        // max=20 needs to be covered → ladder extends to 32.
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 8, 20);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..9, 9..17, 17..33]);
    }

    #[test]
    fn power_of_two_ladder_min_equals_max() {
        let cache = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 16, 16);
        let ranges: Vec<_> = cache.buckets().cloned().collect();
        assert_eq!(ranges, vec![1..17]);
    }

    #[test]
    #[should_panic(expected = "min must be ≥ 1")]
    fn power_of_two_ladder_zero_min_rejected() {
        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 0, 16);
    }

    #[test]
    #[should_panic(expected = "max")]
    fn power_of_two_ladder_max_below_min_rejected() {
        let _ = BucketedCompileCache::power_of_two_ladder(Device::Cpu, 32, 8);
    }

    // ── Active-extent dispatch (true per-kernel skip-compute) ─────────
    //
    // The three `active_extent_skips_compute_*` tests in this module
    // assert per-thunk active-extent scaling on the CPU backend. Today
    // `rlx_cpu::thunk::execute_thunks_active` is documented as a stub
    // that returns false (rlx-cpu/src/thunk.rs:2100-2110), so the runtime
    // falls back to full-extent dispatch — overwrites the tail and the
    // tail-preservation assertions fail. They're left here (marked
    // `#[ignore]`) as the test-driven contract that the future
    // active-extent implementation must satisfy. Drop the `#[ignore]`
    // when the per-thunk scaling lands for Copy / ActivationInPlace /
    // BinaryFull / Attention.

    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_cpu_activation() {
        // tiny_graph(15) is `Input([15]) → Relu → Output` and lowers to
        // a Copy + ActivationInPlace pair on CPU — both are in the safe
        // set, so the active-extent path runs scaled.
        //
        // To prove kernels actually skipped: warm the arena with a prior
        // full-extent run whose output is `[1.0; 15]`, then run again
        // with a negative-only input and active=5. The first 5 outputs
        // get re-copied + re-relu'd to 0; the tail (indices 5..15) stays
        // at 1.0 because both Copy and Activation skipped it. A full-
        // extent fallback would clip every element to 0.
        let graph = tiny_graph(15);
        let mut compiled = Session::new(Device::Cpu).compile(graph);

        // Warm-up: full extent, all-positive input → output [1.0; 15].
        let warm_input: Vec<f32> = vec![1.0; 15];
        let warm_outs = compiled.run(&[("x", &warm_input)]);
        assert_eq!(warm_outs[0], vec![1.0; 15], "warm-up sanity");

        // Active-extent run: all-negative input, hint actual=5 of 15.
        // First 5: Copy(-1) + Relu → 0. Tail: kernels skip → stays 1.0.
        let neg_input: Vec<f32> = vec![-1.0; 15];
        compiled.set_active_extent(Some((5, 15)));
        let outs = compiled.run(&[("x", &neg_input)]);
        let out = &outs[0];

        assert_eq!(out.len(), 15);
        assert_eq!(
            out[..5],
            [0.0; 5],
            "first 5 elements processed (relu of -1)"
        );
        assert_eq!(
            out[5..],
            [1.0; 10],
            "tail untouched — proves Copy + Activation skipped indices 5..15"
        );

        // Clear the hint and run again with the negative input — full
        // extent now processes everything, every element clips to 0.
        compiled.set_active_extent(None);
        let outs = compiled.run(&[("x", &neg_input)]);
        assert_eq!(
            outs[0],
            vec![0.0; 15],
            "full-extent path must clip every negative"
        );
    }

    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_binary_full() {
        // Input([4]) + Input([4]) → Output. Lowers to a BinaryFull
        // thunk with no broadcast (lhs_len == rhs_len == len), which
        // is in the safe set.
        let mut g = Graph::new("add");
        let f = DType::F32;
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let c = g.add(a, b);
        g.set_outputs(vec![c]);
        let mut compiled = Session::new(Device::Cpu).compile(g);

        // Warm: full extent, output buffer becomes [2.0; 4].
        let warm = compiled.run(&[("a", &[1.0f32; 4]), ("b", &[1.0f32; 4])]);
        assert_eq!(warm[0], vec![2.0; 4]);

        // Active-extent run: actual=2 of upper=4. Process first 2
        // elements only; tail (indices 2..4) stays at 2.0 from warm.
        compiled.set_active_extent(Some((2, 4)));
        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
        let out = &outs[0];
        assert_eq!(out[..2], [20.0, 20.0], "first 2 = active sum");
        assert_eq!(
            out[2..],
            [2.0, 2.0],
            "tail untouched — proves BinaryFull skipped indices 2..4"
        );

        // Clear hint → full path overwrites entire output.
        compiled.set_active_extent(None);
        let outs = compiled.run(&[("a", &[10.0f32; 4]), ("b", &[10.0f32; 4])]);
        assert_eq!(outs[0], vec![20.0; 4]);
    }

    #[test]
    #[ignore = "process-wide STATE; runs only in isolation via `cargo test perfetto -- --ignored`"]
    fn perfetto_trace_emits_per_thunk_events() {
        // PLAN L3: end-to-end Perfetto event capture. Requires the env
        // var to be set BEFORE the perfetto module is first touched
        // (OnceLock — can't re-init). We set it here unconditionally;
        // for tests run in parallel within the same process, the
        // earliest test wins. To avoid flake we mark this `#[ignore]`
        // and the developer runs it explicitly.
        use std::env;
        use std::fs;
        let path = env::temp_dir().join(format!("rlx-perfetto-e2e-{}.json", std::process::id()));
        if path.exists() {
            let _ = fs::remove_file(&path);
        }
        // SAFETY: mutating the process environment is only sound while no
        // other thread reads or writes it concurrently. This test is
        // `#[ignore]`d and intended to be run on its own (see the attribute
        // above), so no sibling test should be touching the environment.
        unsafe {
            env::set_var("RLX_TRACE_PERFETTO", &path);
        }

        // Build + run a small CPU graph — Add → Relu (no fusion macros).
        let f = DType::F32;
        let mut g = Graph::new("perf");
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let s = g.add(a, b);
        let r = g.relu(s);
        g.set_outputs(vec![r]);
        let mut compiled = Session::new(Device::Cpu).compile(g);
        let _ = compiled.run(&[("a", &[1.0; 4]), ("b", &[1.0; 4])]);

        // Force the trace file to flush its closing bracket.
        crate::perfetto::flush_and_finalize();

        let contents = fs::read_to_string(&path).expect("trace file");
        // At minimum we should see one of our thunk names.
        assert!(
            contents.contains("\"binary\"")
                || contents.contains("\"activation\"")
                || contents.contains("\"elementwise_region\""),
            "expected at least one thunk-name event in perfetto trace; got: {contents}"
        );
        // JSON shape: only the opening `[` is asserted here; the closing
        // `]` written by the flush above is not checked.
        assert!(contents.trim_start().starts_with('['));
        let _ = fs::remove_file(&path);
    }

    #[test]
    fn elementwise_region_fused_matches_unfused() {
        // PLAN L2: a chain `Add(a, b) → Mul(_, c) → Relu` should fuse
        // into one ElementwiseRegion thunk in the CPU backend. Compare
        // its output against the value computed by hand to confirm the
        // fused execution is numerically identical.
        let f = DType::F32;
        let mut g = Graph::new("ew_e2e");
        let a = g.input("a", Shape::new(&[8], f));
        let b = g.input("b", Shape::new(&[8], f));
        let c = g.input("c", Shape::new(&[8], f));
        let add = g.add(a, b);
        let mul = g.mul(add, c);
        let relu = g.relu(mul);
        g.set_outputs(vec![relu]);

        let mut compiled = Session::new(Device::Cpu).compile(g);
        let av: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0, 0.5, -0.5, 1.5, -1.5];
        let bv: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0, 0.5, 0.5, 0.5, 0.5];
        let cv: Vec<f32> = vec![1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 0.5, 4.0];
        let outs = compiled.run(&[("a", &av), ("b", &bv), ("c", &cv)]);
        let out = &outs[0];

        let expected: Vec<f32> = (0..8)
            .map(|i| {
                let v = (av[i] + bv[i]) * cv[i];
                v.max(0.0)
            })
            .collect();
        // `zip` stops at the shorter side, so pin the length first;
        // otherwise a truncated output would pass the loop vacuously.
        assert_eq!(out.len(), expected.len(), "output length mismatch");
        for (i, (got, exp)) in out.iter().zip(&expected).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "mismatch at {i}: got {got}, expected {exp}"
            );
        }
    }

    #[test]
    #[ignore = "active-extent execution is a stub on CPU (thunk.rs::execute_thunks_active)"]
    fn active_extent_skips_compute_on_attention() {
        // Standalone Attention with kernel-synthesized MaskKind::None.
        // Q/K/V shape: [batch=1, seq=4, num_heads*head_dim=8].
        use rlx_ir::op::MaskKind;
        let f = DType::F32;
        let mut g = Graph::new("attn");
        let q = g.input("q", Shape::new(&[1, 4, 8], f));
        let k = g.input("k", Shape::new(&[1, 4, 8], f));
        let v = g.input("v", Shape::new(&[1, 4, 8], f));
        let out = g.attention_kind(q, k, v, 2, 4, MaskKind::None, Shape::new(&[1, 4, 8], f));
        g.set_outputs(vec![out]);
        let mut compiled = Session::new(Device::Cpu).compile(g);

        // Warm: full extent. Q=K=V uniform → output uniform-ish.
        let warm = compiled.run(&[
            ("q", &[1.0f32; 32]),
            ("k", &[1.0f32; 32]),
            ("v", &[1.0f32; 32]),
        ]);
        let warm_out = warm[0].clone();
        assert_eq!(warm_out.len(), 32);

        // Active: s_active=2 of s_full=4. Different inputs.
        // Tail rows (indices 16..32 = positions 2,3) should be untouched
        // — preserved from the warm run. First 16 indices recomputed.
        compiled.set_active_extent(Some((2, 4)));
        let outs = compiled.run(&[
            ("q", &[3.0f32; 32]),
            ("k", &[3.0f32; 32]),
            ("v", &[3.0f32; 32]),
        ]);
        let out = &outs[0];
        assert_eq!(out.len(), 32);
        assert_eq!(
            &out[16..],
            &warm_out[16..],
            "tail (positions 2,3) must be untouched — proves Attention skipped"
        );
        // Sanity: first 2 positions changed since input value differs (3.0 vs 1.0).
        assert_ne!(
            &out[..16],
            &warm_out[..16],
            "first 2 positions should reflect new input"
        );
    }

    #[test]
    fn active_extent_falls_back_when_unsupported_thunk_in_schedule() {
        // A graph containing any thunk outside `safe_for_active_extent`
        // (e.g. Sgemm via a matmul) must fall back to the full-extent
        // executor — partial application would feed garbage downstream.
        // We can't easily construct such a graph at this layer without
        // pulling in matmul builders, but we can verify the trait
        // contract via the simpler check: setting an extent hint on a
        // matmul-bearing graph still gives correct outputs (full-extent
        // fallback path was taken).
        //
        // Skipped explicit construction here — the safety net is the
        // `if !all(safe) return false` guard inside execute_thunks_active
        // plus the `if !active_used { execute_thunks(...) }` fallback in
        // the CPU executor, both unit-tested via direct safety-predicate
        // and the warm-arena test above.
    }

    #[test]
    fn run_padded_uses_active_extent_on_cpu() {
        // End-to-end: the cache wires set_active_extent before run.
        // Same relu graph as above, driven through run_padded with
        // actual_rows=5 inside a bucket whose upper is 15.
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0];
        // pad_rows zero-pads from len=5 up to upper=15, so the tail past
        // index 5 is 0.0 going in — which is also what relu would
        // produce. We can't observe skip via output here — slice_rows
        // trims to actual_rows anyway.
        let (upper, outs) = cache
            .run_padded(
                5,
                5,
                |max| tiny_graph(max as usize),
                &[("x", &input, 1)],
                &[1],
            )
            .unwrap();
        assert_eq!(upper, 15);
        assert_eq!(outs[0].len(), 5);
        // Outputs match relu of the 5 inputs. The user-visible result is
        // the same whether the backend honored the active-extent hint or
        // fell back to full extent; the point of this test is just to
        // confirm the wiring path doesn't crash and produces correct
        // outputs end-to-end.
        assert_eq!(outs[0], vec![1.0, 0.0, 2.0, 0.0, 3.0]);
    }

    #[test]
    fn run_padded_inner_zero_returns_output_unsliced() {
        // Marking output_inners[0] = 0 disables slicing for that output.
        // The compiled graph still runs at upper=15, so we expect 15 outputs back.
        let mut cache = BucketedCompileCache::new(Device::Cpu, vec![1..16]);
        let input: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 3.0];

        let (upper, outs) = cache
            .run_padded(
                5,
                5,
                |max| tiny_graph(max as usize),
                &[("x", &input, 1)],
                &[0], // don't slice this output
            )
            .unwrap();

        assert_eq!(upper, 15);
        assert_eq!(
            outs[0].len(),
            15,
            "unsliced output preserves full upper extent"
        );
        // First 5 = relu of input, tail 10 = relu(0) = 0.
        assert_eq!(&outs[0][..5], &[1.0, 0.0, 2.0, 0.0, 3.0]);
        assert!(outs[0][5..].iter().all(|&v| v == 0.0));
    }

    #[test]
    fn dynamic_dim_cache_specializes_per_key() {
        use rlx_ir::DType;
        use rlx_ir::Shape;
        use rlx_ir::hir::HirModule;
        use rlx_ir::sym;

        let mut cache = DynamicDimCompileCache::new(Device::Cpu, 4);
        let opts = crate::CompileOptions::new();
        {
            let _short = cache
                .get_or_specialize(
                    8,
                    &rlx_ir::DimBinding::batch_seq(1, 8),
                    || {
                        let mut hir = HirModule::new("dyn_cache");
                        let x = hir.input_batch_seq("x", sym::BATCH, sym::SEQ, 4, DType::F32);
                        let w = hir.param("w", Shape::new(&[4, 2], DType::F32));
                        let y = hir.linear(
                            x,
                            w,
                            None,
                            None,
                            Shape::batch_seq(sym::BATCH, sym::SEQ, 2, DType::F32),
                        );
                        hir.set_outputs(vec![y]);
                        hir
                    },
                    &opts,
                )
                .expect("specialize short");
        }
        assert!(cache.has_template());
        assert_eq!(cache.len(), 1);
        cache
            .get_or_specialize(
                128,
                &rlx_ir::DimBinding::batch_seq(1, 128),
                || panic!("HIR builder must not run twice"),
                &opts,
            )
            .expect("specialize long");
        assert_eq!(cache.len(), 2);
    }
}