Skip to main content

rlx_runtime/
backend.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backend trait — abstraction over CPU/GPU/CUDA execution.
17//!
18//! Each backend implements `Backend::compile(graph, &CompileOptions)` and
19//! returns an `ExecutableGraph`. New compile knobs go in `CompileOptions`
20//! rather than as new trait methods.
21
22use crate::CompileOptions;
23use rlx_ir::Graph;
24use rlx_ir::hir::HirModule;
25use rlx_ir::lir::LirModule;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use crate::cpu_low_precision;
30
31// ── Typed I/O helpers (shared across f32-arena backends) ────────────────
32
33/// Widen a typed byte buffer to `Vec<f32>`. Used by `set_param_typed` /
34/// `run_typed` overrides on backends whose internal arena is f32-uniform
35/// (CPU, Metal, wgpu) so callers can hand in F16/BF16 without doing the
36/// host-side cast themselves. Panics on dtypes the f32 arena can't carry.
37#[allow(dead_code)]
38pub(crate) fn widen_bytes_to_f32(data: &[u8], dtype: rlx_ir::DType) -> Vec<f32> {
39    use rlx_ir::DType;
40    match dtype {
41        DType::F32 => {
42            let n = data.len() / 4;
43            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
44            s.to_vec()
45        }
46        DType::F16 => {
47            let n = data.len() / 2;
48            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
49            s.iter().map(|h| h.to_f32()).collect()
50        }
51        DType::BF16 => {
52            let n = data.len() / 2;
53            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n) };
54            s.iter().map(|h| h.to_f32()).collect()
55        }
56        other => panic!(
57            "widen_bytes_to_f32: dtype {other:?} unsupported on f32-arena backends \
58             (only F32/F16/BF16 are accepted on the host I/O surface)"
59        ),
60    }
61}
62
63/// Narrow a `&[f32]` buffer down to the declared output dtype, returning
64/// the corresponding little-endian byte stream. Mirrors the bytes a
65/// backend that stores the native dtype would emit. Used by `run_typed`
66/// to keep the byte-level output contract identical across backends.
67#[allow(dead_code)]
68pub(crate) fn narrow_f32_to_bytes(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
69    use rlx_ir::DType;
70    match dt {
71        DType::F32 => {
72            let mut bytes = Vec::with_capacity(v.len() * 4);
73            for &x in v {
74                bytes.extend_from_slice(&x.to_le_bytes());
75            }
76            bytes
77        }
78        DType::F16 => {
79            let mut bytes = Vec::with_capacity(v.len() * 2);
80            for &x in v {
81                bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
82            }
83            bytes
84        }
85        DType::BF16 => {
86            let mut bytes = Vec::with_capacity(v.len() * 2);
87            for &x in v {
88                bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
89            }
90            bytes
91        }
92        DType::F64 => {
93            let mut bytes = Vec::with_capacity(v.len() * 8);
94            for &x in v {
95                bytes.extend_from_slice(&(x as f64).to_le_bytes());
96            }
97            bytes
98        }
99        DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
100        DType::U8 => v.iter().map(|&x| x as u8).collect(),
101        DType::I16 => {
102            let mut bytes = Vec::with_capacity(v.len() * 2);
103            for &x in v {
104                bytes.extend_from_slice(&(x as i16).to_le_bytes());
105            }
106            bytes
107        }
108        DType::I32 => {
109            let mut bytes = Vec::with_capacity(v.len() * 4);
110            for &x in v {
111                bytes.extend_from_slice(&(x as i32).to_le_bytes());
112            }
113            bytes
114        }
115        DType::U32 => {
116            let mut bytes = Vec::with_capacity(v.len() * 4);
117            for &x in v {
118                bytes.extend_from_slice(&(x as u32).to_le_bytes());
119            }
120            bytes
121        }
122        DType::I64 => {
123            let mut bytes = Vec::with_capacity(v.len() * 8);
124            for &x in v {
125                bytes.extend_from_slice(&(x as i64).to_le_bytes());
126            }
127            bytes
128        }
129        DType::Bool => v
130            .iter()
131            .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
132            .collect(),
133        DType::C64 => {
134            // Complex narrow path: real part = the f32 value, imaginary
135            // part = 0. Mirrors how the backend stores narrowed f32
136            // operands when promoted to a complex op input.
137            let mut bytes = Vec::with_capacity(v.len() * 8);
138            for &x in v {
139                bytes.extend_from_slice(&x.to_le_bytes());
140                bytes.extend_from_slice(&0.0_f32.to_le_bytes());
141            }
142            bytes
143        }
144    }
145}
146
147/// A compiled, ready-to-execute graph on a specific backend.
148pub trait ExecutableGraph: Send {
149    /// Set a named parameter (weight) buffer.
150    fn set_param(&mut self, name: &str, data: &[f32]);
151
152    /// Deep-clone this executable into a fresh `Box`. Lets
153    /// `CompiledGraph` implement `Clone` so callers (e.g. eda-mna's
154    /// `SensitivityContext`) can spin up N independent executor
155    /// copies for thread-parallel dispatch without paying the full
156    /// graph-compile cost N times. Default implementation panics;
157    /// backends that support cloning override.
158    fn clone_box(&self) -> Box<dyn ExecutableGraph> {
159        panic!("clone_box not implemented for this backend");
160    }
161
162    /// Execute the graph with named inputs. Returns output data (copies from arena).
163    fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>>;
164
165    /// Like [`Self::run`] but only read back outputs at `read_indices`.
166    /// GPU handle feeds still update for every output. Default: all outputs.
167    fn run_read_outputs(
168        &mut self,
169        inputs: &[(&str, &[f32])],
170        read_indices: Option<&[usize]>,
171    ) -> Vec<Vec<f32>> {
172        match read_indices {
173            None => self.run(inputs),
174            Some(ix) => {
175                // Backends without a native partial-read path still run the full
176                // graph; only clone the requested outputs on the host.
177                let all = self.run(inputs);
178                ix.iter().filter_map(|&i| all.get(i).cloned()).collect()
179            }
180        }
181    }
182
183    /// Execute and return raw pointers to output data in arena (zero-copy).
184    fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
185        let vecs = self.run(inputs);
186        vecs.iter().map(|v| (v.as_ptr(), v.len())).collect()
187    }
188
189    /// Fastest: inputs by slot index, returns output (offset, len) pairs.
190    /// Read output from arena via `arena_ptr().add(offset)`.
191    fn run_slots(&mut self, _inputs: &[&[f32]]) -> &[(usize, usize)] {
192        &[] // default: not supported
193    }
194
195    /// Get the raw arena buffer pointer for reading outputs after run_slots.
196    fn arena_ptr(&self) -> *const u8 {
197        std::ptr::null()
198    }
199
200    /// Hint the executor that subsequent `run` calls should process
201    /// only the first `actual` rows along the bucket axis (out of
202    /// `upper`, the extent the graph was compiled at). Backends that
203    /// support per-kernel active-extent dispatch honor this; others
204    /// ignore it and process the full compiled extent.
205    ///
206    /// Pass `None` to clear the hint. The hint is sticky — set it
207    /// before each `run` and clear it after, or maintain it across
208    /// runs at your discretion.
209    ///
210    /// Even when honored, callers must not rely on the contents of the
211    /// output past `actual` rows — that region may contain stale data
212    /// from earlier runs (kernels skip it).
213    ///
214    /// Default: no-op. See `BucketedCompileCache::run_padded` for the
215    /// canonical caller; backends opt in by overriding this method.
216    fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
217        let _ = extent;
218    }
219
220    /// TIDE merged placement mask (union across MoE layers). CPU: stats + host path.
221    fn set_moe_resident_experts(&mut self, _mask: &[bool]) {}
222
223    /// Per MoE layer placement (`masks[layer][expert]`). Preferred over merged on CPU.
224    fn set_moe_resident_experts_per_layer(&mut self, _masks: &[&[bool]]) {}
225
226    /// Capture MoE router TopK indices on the next CPU forward (TIDE refresh).
227    fn enable_moe_topk_capture(&mut self, _num_experts: usize) -> bool {
228        false
229    }
230
231    /// Take captured per-layer expert indices (one vec per MoE TopK in order).
232    fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
233        None
234    }
235
236    /// MoE GroupedMatMul residency accounting from the last forward (CPU).
237    fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
238        None
239    }
240
241    /// Bind a persistent buffer handle (KV-cache, training state, etc.).
242    /// The buffer lives across run() calls and is not in the arena.
243    /// Returns true if the backend supports persistent handles.
244    fn bind_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
245        false
246    }
247
248    /// Read a persistent buffer's current contents.
249    fn read_handle(&self, _name: &str) -> Option<Vec<f32>> {
250        None
251    }
252
253    /// GPU-resident input (MLX): upload once, reuse across runs.
254    fn bind_gpu_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
255        false
256    }
257
258    fn has_gpu_handle(&self, _name: &str) -> bool {
259        false
260    }
261
262    fn set_gpu_handle_feed(&mut self, _handle_name: &str, _output_index: usize) -> bool {
263        false
264    }
265
266    fn read_gpu_handle(&self, _name: &str) -> Option<Vec<f32>> {
267        None
268    }
269
270    /// Read one row from a row-major graph output after `run` / `run_read_outputs`.
271    /// Metal reads a single row from the arena; default returns `None` (caller falls back).
272    fn read_output_row(&self, _out_idx: usize, _row: usize, _row_inner: usize) -> Option<Vec<f32>> {
273        None
274    }
275
276    /// Run and refresh a GPU handle from `output_index`; returns that output on host.
277    fn run_feed_gpu_handle(
278        &mut self,
279        inputs: &[(&str, &[f32])],
280        _handle_name: &str,
281        _output_index: usize,
282    ) -> Option<Vec<f32>> {
283        let _ = inputs;
284        None
285    }
286
287    // ── Pipelined / async execution (Phase C) ─────────────────────────
288    //
289    // These allow callers to amortize per-run sync latency on backends
290    // where it matters (Metal: ~150 µs `wait_until_completed` per commit).
291    // CPU has no such cost, so the default impls just call `run` serially.
292
293    /// Encode + commit a forward pass without waiting for completion.
294    ///
295    /// Outputs of intermediate calls are stomped — use `run_pipelined` if
296    /// you need outputs from each individual commit. Pair with
297    /// `sync_pending` to drain.
298    ///
299    /// Default: synchronous fallback (calls `run`, discards output). CPU
300    /// uses this default since BLAS is synchronous anyway.
301    fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
302        let _ = self.run(inputs);
303    }
304
305    /// Wait for every command queued by `commit_no_wait`.
306    /// Default: no-op (synchronous backends have nothing pending).
307    fn sync_pending(&mut self) {}
308
309    /// Issue a batch of forward passes pipelined, returning per-run outputs.
310    ///
311    /// The Metal impl encodes a per-commit blit so each in-flight run's
312    /// outputs survive subsequent commits stomping the shared arena. The
313    /// CPU default is just sequential `run`s — equally correct, no perf
314    /// penalty (CPU has no GPU sync cost to amortize).
315    ///
316    /// Returns `out[run_idx][output_idx][element_idx]`.
317    fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
318        input_sets.iter().map(|inputs| self.run(inputs)).collect()
319    }
320
321    // ── Typed (non-F32) host I/O ──────────────────────────────────
322    //
323    // `set_param` and `run` are F32 by contract. The typed entry
324    // points let callers pass and receive raw bytes in any rlx-ir
325    // dtype, avoiding the f32 widen/narrow round-trip that's
326    // wasteful for F16/BF16 weights and activations.
327    //
328    // The default impls only handle F32 — any other dtype panics.
329    // Backends that support typed I/O natively (e.g. MLX via
330    // Array::from_bytes/to_bytes) override these.
331
332    /// Set a named parameter from raw bytes in the given dtype.
333    fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
334        if dtype != rlx_ir::DType::F32 {
335            panic!(
336                "backend's default set_param_typed only handles F32; \
337                    got {dtype:?}. Override on the backend for typed support."
338            );
339        }
340        if !data.len().is_multiple_of(4) {
341            panic!(
342                "set_param_typed F32: data length {} not a multiple of 4",
343                data.len()
344            );
345        }
346        // SAFETY: F32 bytes are 4-aligned by source convention; we
347        // only widen access (read &[f32] from owned &[u8]). Failure
348        // mode if a caller hands us mis-aligned bytes is undefined,
349        // hence the % 4 length check.
350        let n = data.len() / 4;
351        let f32_slice = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
352        self.set_param(name, f32_slice);
353    }
354
355    /// Run with typed inputs and typed outputs. Returns
356    /// `(bytes, dtype)` per output; the dtype is whatever the
357    /// graph's output node was declared as.
358    fn run_typed(
359        &mut self,
360        inputs: &[(&str, &[u8], rlx_ir::DType)],
361    ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
362        // Default impl: convert each typed input to f32 (F32-only),
363        // run, then re-emit outputs as F32 bytes.
364        let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
365        for (name, data, dt) in inputs {
366            if *dt != rlx_ir::DType::F32 {
367                panic!(
368                    "backend's default run_typed only handles F32 inputs; \
369                        got {dt:?} for input '{name}'"
370                );
371            }
372            if data.len() % 4 != 0 {
373                panic!(
374                    "run_typed F32 input '{name}': len {} not multiple of 4",
375                    data.len()
376                );
377            }
378            let n = data.len() / 4;
379            let v: Vec<f32> =
380                unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }.to_vec();
381            owned.push((name.to_string(), v));
382        }
383        let refs: Vec<(&str, &[f32])> = owned
384            .iter()
385            .map(|(n, d)| (n.as_str(), d.as_slice()))
386            .collect();
387        let outs = self.run(&refs);
388        outs.into_iter()
389            .map(|v| {
390                let bytes =
391                    unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 4) }
392                        .to_vec();
393                (bytes, rlx_ir::DType::F32)
394            })
395            .collect()
396    }
397}
398
399/// Backend implementation trait.
400///
401/// Single compile entry point. New compile-time knobs are added to
402/// `CompileOptions`, not as new trait methods.
403///
404/// `Send + Sync` because backends are stateless factories — multiple
405/// threads can call `compile` concurrently. The returned
406/// `Box<dyn ExecutableGraph>` is `Send` (moveable to a worker thread)
407/// but **not** `Sync` (`run`/`run_slots` take `&mut self`).
408pub trait Backend: Send + Sync {
409    /// Compile a graph for this backend with the given options.
410    fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph>;
411
412    /// Compile pre-optimized LIR (HIR → MIR → LIR pipeline output).
413    /// Default re-enters [`Self::compile`] — backends should override
414    /// when they can reuse the embedded buffer plan.
415    fn compile_lir(&self, lir: LirModule, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
416        self.compile(lir.into_graph(), options)
417    }
418
419    /// HIR-first compile: lower blocks, run fusion pipeline, emit executable.
420    fn compile_hir(
421        &self,
422        hir: HirModule,
423        device: rlx_driver::Device,
424        options: &CompileOptions,
425    ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
426        let result = crate::stages::compile_hir_stages(device, hir, options)?;
427        crate::stages::maybe_log_fusion(&result.fusion);
428        Ok(self.compile_lir(result.lir, options))
429    }
430
431    /// [`GraphModule`] compile — unified HIR/MIR/LIR entry.
432    fn compile_module(
433        &self,
434        module: rlx_ir::GraphModule,
435        device: rlx_driver::Device,
436        options: &CompileOptions,
437    ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
438        let result = crate::stages::compile_module_stages(device, module, options)?;
439        crate::stages::maybe_log_fusion(&result.fusion);
440        Ok(self.compile_lir(result.lir, options))
441    }
442
443    /// PLAN L4: declare which `OpKind`s this backend can lower.
444    /// Default: empty slice = "no claim made — accept everything"
445    /// (preserves existing behavior; backends opt in by overriding).
446    /// When non-empty, the `LegalizeForBackend` pass will refuse to
447    /// compile a graph that contains an op outside this set, instead
448    /// of silently falling through to slower / wrong dispatch.
449    fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
450        &[]
451    }
452}
453
454/// Prepare a fused MIR graph from LIR for backend executable construction.
455/// Skips the fusion pipeline — LIR must come from `compile_*_stages`.
456#[allow(dead_code)]
457fn prepare_fused_graph(
458    graph: Graph,
459    options: &CompileOptions,
460    supported_ops: &[rlx_ir::OpKind],
461    backend_name: &str,
462) -> Graph {
463    let (mut graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
464        graph,
465        backend_name,
466        supported_ops,
467        options.kernel_dispatch,
468    );
469    rlx_opt::maybe_log_dispatch_report(&report);
470    if !report.compile_ready {
471        panic!(
472            "{}\n{}",
473            rlx_opt::format_legalize_error(backend_name, &report.still_unsupported),
474            rlx_opt::format_dispatch_report(&report)
475        );
476    }
477    graph = crate::precompile::post_fusion_cleanup(graph, options);
478    if let Some(p) = options.policy.clone() {
479        use rlx_opt::pass::Pass as _;
480        graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
481    }
482    graph
483}
484
485#[allow(dead_code)]
486fn declared_output_dtypes(
487    manifest: &cpu_low_precision::IoDtypeManifest,
488    exec_dtypes: Vec<rlx_ir::DType>,
489) -> Vec<rlx_ir::DType> {
490    exec_dtypes
491        .into_iter()
492        .enumerate()
493        .map(|(i, exec)| manifest.output_dtype(i, exec))
494        .collect()
495}
496
497// ── Convenience helpers preserved from older API ──────────────────────
498//
499// These let existing call sites keep working unchanged while the new
500// trait is the canonical one. We provide free functions rather than
501// trait methods so adding them doesn't grow the trait surface.
502
503/// Compile at default options (F32, no policy).
504pub fn compile(backend: &dyn Backend, graph: Graph) -> Box<dyn ExecutableGraph> {
505    backend.compile(graph, &CompileOptions::default())
506}
507
508/// Compile HIR through the fusion-first pipeline.
509pub fn compile_hir(
510    backend: &dyn Backend,
511    hir: HirModule,
512    device: rlx_driver::Device,
513    options: &CompileOptions,
514) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
515    backend.compile_hir(hir, device, options)
516}
517
518/// Compile a [`GraphModule`] through the fusion-first pipeline.
519pub fn compile_module(
520    backend: &dyn Backend,
521    module: rlx_ir::GraphModule,
522    device: rlx_driver::Device,
523    options: &CompileOptions,
524) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
525    backend.compile_module(module, device, options)
526}
527
528/// Compile at a specific precision (default policy = none).
529pub fn compile_with_precision(
530    backend: &dyn Backend,
531    graph: Graph,
532    precision: crate::Precision,
533) -> Box<dyn ExecutableGraph> {
534    backend.compile(graph, &CompileOptions::new().precision(precision))
535}
536
537/// Helper retained for backward compatibility — applies the precision
538/// rewrite at the runtime layer if backends don't override their
539/// pipeline placement. Modern code: pass the policy via CompileOptions
540/// and let the backend handle ordering.
541fn _legacy_apply_policy(graph: Graph, policy: Option<rlx_opt::PrecisionPolicy>) -> Graph {
542    use rlx_opt::pass::Pass as _;
543    match policy {
544        Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
545        None => graph,
546    }
547}
548
549// ── CPU Backend ─────────────────────────────────────────────────────────
550
551#[cfg(feature = "cpu")]
552pub mod cpu_backend {
553    use super::*;
554    use rlx_cpu::{arena::Arena, thunk};
555    use rlx_ir::{DType, NodeId, Op};
556    use rlx_opt::memory::{self, MemoryPlan};
557    // Arena typed read/write helpers live in `crate::arena` so every
558    // backend (CPU, Metal, future CUDA/wgpu/WASM) shares one implementation.
559    use rlx_driver::arena::{read_typed_to_f32, write_typed_from_f32};
560
    /// CPU compile target. A field-less unit struct: it carries no state
    /// of its own and exists to hang the `Backend` implementation on.
    pub struct CpuBackend;
562
    /// PLAN L4: ops the CPU backend can lower today. Includes
    /// DotGeneral (lowered via `LowerDotGeneral` pass) and
    /// ElementwiseRegion (lowered natively per L2). Excludes
    /// FusedTransformerLayer / If / While — those have IR variants
    /// but no CPU lowering yet (see `compile_thunks` arm absence +
    /// `subgraph.rs` "If/While executor wiring is pending" note).
    ///
    /// This slice is what `CpuBackend::supported_ops` returns and what
    /// `CpuBackend::compile` passes to `legalize_for_backend`, so an op
    /// kind missing here makes `compile` panic for graphs that use it.
    const CPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            Scan,
            ScanBackward,
            ScanBackwardXs,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            BatchNormInference,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Im2Col,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            // Backward ops emitted by `rlx_opt::autodiff::grad_with_loss`.
            // Their thunks live in rlx-cpu/src/thunk.rs alongside the
            // forward kernels; without these entries the legalize step
            // in `CpuBackend::compile` would reject any compiled
            // gradient graph.
            ReluBackward,
            ActivationBackward,
            FakeQuantize,
            FakeQuantizeBackward,
            MaxPool2dBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            BatchNormInferenceBackwardInput,
            BatchNormInferenceBackwardGamma,
            BatchNormInferenceBackwardBeta,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            // 3D Gaussian splat CPU reference render/backward (requires `rlx-cpu/splat`).
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            // User-registered custom ops dispatched through
            // `rlx_cpu::op_registry`. Lowering panics with a clear
            // message if the named CPU kernel isn't registered.
            Custom,
            // User-defined sub-graph with optional override AD rules
            // (JAX-shaped custom_vjp / custom_jvp). Body is a regular
            // Graph compiled recursively in compile_thunks.
            CustomFn,
            // FFT primitive (1D last-axis, 2N real-block layout, f64
            // power-of-2 sizes). Other backends panic at lowering;
            // pin FFT-containing graphs to Device::Cpu for now.
            Fft,
            FftButterflyStage,
            LogMel,
            LogMelBackward,
            // C64 Wirtinger AD surface. ComplexNormSq is the canonical
            // real-valued loss for complex inputs; Conjugate is emitted
            // by the new Wirtinger VJP rules for BinaryOp::Mul/Div on
            // C64. Both have CPU thunks in rlx-cpu.
            ComplexNormSq,
            ComplexNormSqBackward,
            Conjugate,
        ]
    };
680
681    impl Backend for CpuBackend {
682        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
683            CPU_SUPPORTED_OPS
684        }
685
686        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
687            use rlx_opt::pass::Pass as _;
688            // Lower Op::If / Op::While to primitives BEFORE legalize
689            // so the supported-op check doesn't reject them — the CPU
690            // backend has no native sub-graph executor; this rewrite
691            // makes If/While invisible to the rest of the pipeline.
692            // No-op when neither op is in the graph.
693            let graph = rlx_opt::LowerControlFlow.run(graph);
694            // PLAN L4: legalize against the backend's claimed op set
695            // BEFORE running fusion (so the diagnostic points at the
696            // user's IR, not at a fused-away node).
697            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CPU_SUPPORTED_OPS) {
698                panic!("{}", rlx_opt::format_legalize_error("cpu", &errors));
699            }
700            let policy = options.policy.clone();
701            let _precision = options.precision;
702            let cfg = rlx_cpu::config::RuntimeConfig::global();
703
704            let graph = crate::precompile::precompile_cleanup(graph, options);
705
706            // Run fusion pipeline (HIR/MIR/LIR ideology — fusion is first-class).
707            let mut compile_opts = options.clone();
708            compile_opts.arena_alignment = cfg.arena_alignment;
709            let compile_result = crate::stages::compile_graph_stages_for_backend(
710                rlx_driver::Device::Cpu,
711                graph,
712                &compile_opts,
713                CPU_SUPPORTED_OPS,
714            );
715            crate::stages::maybe_log_fusion(&compile_result.fusion);
716            let fused = compile_result.lir.into_graph();
717
718            // Apply precision policy AFTER fusion — Cast nodes don't disrupt
719            // the now-flattened fused ops.
720            let fused = match policy {
721                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(fused),
722                None => fused,
723            };
724
725            let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&fused);
726            let exec_graph = if cpu_low_precision::needs_f32_exec(&fused) {
727                cpu_low_precision::promote_to_f32(fused)
728            } else {
729                fused
730            };
731
732            // Re-plan after precision rewrites (may change dtypes / sizes).
733            let plan = memory::plan_memory_aligned(&exec_graph, cfg.arena_alignment);
734            if cfg.verbose >= 1 {
735                eprintln!(
736                    "[rlx] arena: {} bytes, {} buffers, alignment: {}",
737                    plan.arena_size,
738                    plan.assignments.len(),
739                    cfg.arena_alignment
740                );
741            }
742            Box::new(build_cpu_executable(exec_graph, plan, io_manifest))
743        }
744
745        fn compile_lir(
746            &self,
747            lir: LirModule,
748            options: &CompileOptions,
749        ) -> Box<dyn ExecutableGraph> {
750            let alignment = lir.buffers.alignment.max(options.arena_alignment);
751            let mut graph = lir.into_graph();
752            {
753                use rlx_opt::pass::Pass as _;
754                graph = rlx_opt::LegalizeBroadcast.run(graph);
755            }
756            if let Some(p) = options.policy.clone() {
757                use rlx_opt::pass::Pass;
758                graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
759            }
760            let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&graph);
761            let promote = cpu_low_precision::needs_f32_exec(&graph);
762            let exec_graph = if promote {
763                cpu_low_precision::promote_to_f32(graph)
764            } else {
765                graph
766            };
767            // LegalizeBroadcast may insert Expand nodes — must replan; the
768            // embedded LIR buffer map is from before legalization.
769            let plan = memory::plan_memory_aligned(&exec_graph, alignment);
770            let cfg = rlx_cpu::config::RuntimeConfig::global();
771            if cfg.verbose >= 1 {
772                eprintln!(
773                    "[rlx] compile_lir: arena {} bytes ({} buffers, alignment {})",
774                    plan.arena_size,
775                    plan.assignments.len(),
776                    alignment,
777                );
778            }
779            Box::new(build_cpu_executable(exec_graph, plan, io_manifest))
780        }
781    }
782
783    fn build_cpu_executable(
784        graph: Graph,
785        plan: MemoryPlan,
786        io_manifest: cpu_low_precision::IoDtypeManifest,
787    ) -> CpuExecutable {
788        let mut arena = Arena::from_plan(plan);
789        let mut input_ids = HashMap::new();
790        let mut param_ids = HashMap::new();
791        let mut node_dtypes: HashMap<NodeId, DType> = HashMap::new();
792        for node in graph.nodes() {
793            node_dtypes.insert(node.id, node.shape.dtype());
794            match &node.op {
795                Op::Input { name } => {
796                    input_ids.insert(name.clone(), node.id);
797                }
798                Op::Param { name } => {
799                    param_ids.insert(name.clone(), node.id);
800                }
801                _ => {}
802            }
803        }
804
805        let schedule = thunk::compile_thunks(&graph, &arena);
806
807        let mut input_slots = Vec::new();
808        for node in graph.nodes() {
809            if let Op::Input { name } = &node.op {
810                let off = arena.byte_offset(node.id);
811                let len = node.shape.num_elements().unwrap_or(0);
812                input_slots.push((name.clone(), off, len, node.shape.dtype()));
813            }
814        }
815
816        let output_slots: Vec<(usize, usize)> = graph
817            .outputs
818            .iter()
819            .map(|&id| {
820                let off = arena.byte_offset(id);
821                let len = graph.node(id).shape.num_elements().unwrap_or(0);
822                (off, len)
823            })
824            .collect();
825
826        for node in graph.nodes() {
827            if let Op::Constant { data } = &node.op
828                && arena.has_buffer(node.id)
829                && !data.is_empty()
830            {
831                match node.shape.dtype() {
832                    DType::F64 | DType::F16 | DType::BF16 => {
833                        let off = arena.byte_offset(node.id);
834                        let buf = arena.raw_buf_mut();
835                        let n = buf.len().saturating_sub(off).min(data.len());
836                        buf[off..off + n].copy_from_slice(&data[..n]);
837                    }
838                    _ => {
839                        let buf = arena.slice_mut(node.id);
840                        let n_floats = data.len() / 4;
841                        let n = buf.len().min(n_floats);
842                        for i in 0..n {
843                            let bytes = [
844                                data[i * 4],
845                                data[i * 4 + 1],
846                                data[i * 4 + 2],
847                                data[i * 4 + 3],
848                            ];
849                            buf[i] = f32::from_le_bytes(bytes);
850                        }
851                    }
852                }
853            }
854        }
855
856        CpuExecutable {
857            graph,
858            arena,
859            params: HashMap::new(),
860            typed_params: HashMap::new(),
861            input_ids,
862            param_ids,
863            node_dtypes,
864            io_manifest,
865            schedule,
866            input_slots,
867            output_slots,
868            handles: HashMap::new(),
869            active_extent: None,
870            moe_resident: None,
871            moe_resident_layers: None,
872            moe_topk_capture: None,
873        }
874    }
875
876    #[derive(Clone)]
877    struct CpuExecutable {
878        graph: Graph,
879        arena: Arena,
880        params: HashMap<String, Vec<f32>>,
881        /// Byte-backed params (`set_param_typed` / `set_param_bytes`).
882        typed_params: HashMap<String, (Vec<u8>, DType)>,
883        input_ids: HashMap<String, NodeId>,
884        param_ids: HashMap<String, NodeId>,
885        /// Per-node arena dtype. Lets set_param/run cast f32 ↔ F16/BF16
886        /// when AutoMixedPrecision has rewritten the graph.
887        node_dtypes: HashMap<NodeId, DType>,
888        /// User-facing boundary dtypes (before f32 promotion for CPU exec).
889        io_manifest: cpu_low_precision::IoDtypeManifest,
890        schedule: thunk::ThunkSchedule,
891        // Pre-resolved: ordered list of (input_name, arena_byte_offset, max_elems, dtype)
892        input_slots: Vec<(String, usize, usize, DType)>,
893        /// Output (byte_offset, num_elements). dtype is in node_dtypes.
894        output_slots: Vec<(usize, usize)>,
895        /// Persistent buffer handles (KV-cache, optimizer state, etc.).
896        /// Lives outside the arena and survives across run() calls.
897        /// On run(): if a handle's name matches a graph input, the
898        /// handle's data is used as the input.
899        handles: HashMap<String, Vec<f32>>,
900        /// Active-extent hint (`Some((actual, upper))`) for L1 bucketed
901        /// dispatch. When set AND every thunk in the schedule is in
902        /// `Thunk::safe_for_active_extent`, the executor processes only
903        /// `actual / upper` of each kernel's work. Otherwise (or when
904        /// `None`) runs at the full compiled extent. See PLAN L1.
905        active_extent: Option<(usize, usize)>,
906        moe_resident: Option<std::sync::Arc<[bool]>>,
907        moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
908        moe_topk_capture: Option<std::sync::Arc<rlx_cpu::moe_topk_capture::MoeTopkCapture>>,
909    }
910
911    unsafe impl Send for CpuExecutable {}
912
913    impl CpuExecutable {
914        /// Write a f32 input slice into the arena, casting to the node's dtype.
915        fn write_input(&mut self, id: NodeId, data: &[f32]) {
916            let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
917            let off = self.arena.byte_offset(id);
918            let buf = self.arena.raw_buf_mut();
919            let elem_size = dtype.size_bytes();
920            let max_elems = (buf.len() - off) / elem_size;
921            unsafe {
922                write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
923            }
924        }
925
926        /// Read a node's arena bytes back as Vec<f32>, casting from its dtype.
927        fn read_output(&self, id: NodeId) -> Vec<f32> {
928            let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
929            let off = self.arena.byte_offset(id);
930            let buf = self.arena.raw_buf();
931            let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
932            unsafe { read_typed_to_f32(buf.as_ptr().add(off), dtype, n_elems) }
933        }
934    }
935
936    impl ExecutableGraph for CpuExecutable {
937        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
938            Box::new(self.clone())
939        }
940        fn set_param(&mut self, name: &str, data: &[f32]) {
941            self.params.insert(name.to_string(), data.to_vec());
942            self.typed_params.remove(name);
943            // Write directly into the arena — zero per-call lookup for params.
944            // Cast f32 → arena dtype when the param has been rewritten to F16/BF16.
945            if let Some(&id) = self.param_ids.get(name)
946                && self.arena.has_buffer(id)
947            {
948                let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
949                let off = self.arena.byte_offset(id);
950                let buf = self.arena.raw_buf_mut();
951                let elem_size = dtype.size_bytes();
952                let max_elems = (buf.len() - off) / elem_size;
953                unsafe {
954                    write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
955                }
956            }
957        }
958
959        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
960            self.restore_arena_baseline();
961            // 1. Apply persistent handles first — they act like default inputs.
962            //    Explicit `inputs` passed to run() override matching handle names.
963            let handle_names: Vec<String> = self.handles.keys().cloned().collect();
964            for name in &handle_names {
965                if let Some(&id) = self.input_ids.get(name)
966                    && self.arena.has_buffer(id)
967                {
968                    let data = self.handles.get(name).cloned().unwrap_or_default();
969                    self.write_input(id, &data);
970                }
971            }
972            // 2. Explicit per-call inputs override handles.
973            for &(name, data) in inputs {
974                if let Some(&id) = self.input_ids.get(name)
975                    && self.arena.has_buffer(id)
976                {
977                    self.write_input(id, data);
978                }
979            }
980
981            // Active-extent fast-path (PLAN L1): if hinted AND every thunk
982            // in the schedule supports it, run scaled. Otherwise fall back
983            // to full-extent dispatch — preserves correctness when the
984            // schedule contains a thunk that hasn't yet been wired in.
985            let active_used = if let Some((actual, upper)) = self.active_extent {
986                thunk::execute_thunks_active(
987                    &self.schedule,
988                    self.arena.raw_buf_mut(),
989                    actual,
990                    upper,
991                )
992            } else {
993                false
994            };
995            if !active_used {
996                // Execute via pre-compiled thunks (zero per-node dispatch overhead)
997                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
998            }
999
1000            // 3. Sync any handle whose name matches a graph OUTPUT —
1001            //    KV-cache pattern: outputs flow back into the same-named
1002            //    handle for the next iteration.
1003            for (idx, &out_id) in self.graph.outputs.iter().enumerate() {
1004                let name = format!("out{idx}");
1005                if self.handles.contains_key(&name) {
1006                    let v = self.read_output(out_id);
1007                    self.handles.insert(name, v);
1008                }
1009            }
1010
1011            self.graph
1012                .outputs
1013                .iter()
1014                .map(|&out_id| self.read_output(out_id))
1015                .collect()
1016        }
1017
1018        fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
1019            self.restore_arena_baseline();
1020            // Copy inputs by name (HashMap lookup), casting to arena dtype.
1021            for &(name, data) in inputs {
1022                if let Some(&id) = self.input_ids.get(name)
1023                    && self.arena.has_buffer(id)
1024                {
1025                    self.write_input(id, data);
1026                }
1027            }
1028            thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1029            // Note: pointers are raw arena bytes — for F16 outputs, callers
1030            // must read 2 bytes/elem, not 4. run() is the safe path for
1031            // mixed precision; run_raw() is only meaningful for F32.
1032            self.graph
1033                .outputs
1034                .iter()
1035                .map(|&out_id| {
1036                    let (ptr, len) = self.arena.raw_ptr(out_id);
1037                    (ptr as *const f32, len)
1038                })
1039                .collect()
1040        }
1041
1042        /// Fastest path: inputs by index (matching input_slots order), zero-copy output.
1043        /// No HashMap, no name matching, no Vec allocation. Casts f32 input
1044        /// to F16/BF16 if the input slot's dtype was rewritten.
1045        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1046            self.restore_arena_baseline();
1047            let buf = self.arena.raw_buf_mut();
1048            for (i, &data) in inputs.iter().enumerate() {
1049                if i < self.input_slots.len() {
1050                    let (_, off, max_len, dtype) = &self.input_slots[i];
1051                    unsafe {
1052                        write_typed_from_f32(buf.as_mut_ptr().add(*off), *dtype, data, *max_len);
1053                    }
1054                }
1055            }
1056            thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1057            &self.output_slots
1058        }
1059
1060        fn arena_ptr(&self) -> *const u8 {
1061            self.arena.raw_buf_mut_ptr()
1062        }
1063
1064        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1065            // Persistent buffer: stored separately from arena, survives run().
1066            // If the name matches a graph input, run() will use this data
1067            // as the input. If the graph also writes back to this name (via
1068            // an output binding pattern), read_handle returns the latest.
1069            self.handles.insert(name.to_string(), data.to_vec());
1070            true
1071        }
1072
1073        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1074            self.handles.get(name).cloned()
1075        }
1076
1077        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1078            self.active_extent = extent;
1079        }
1080
1081        fn set_moe_resident_experts(&mut self, mask: &[bool]) {
1082            self.moe_resident_layers = None;
1083            self.schedule.moe_resident_layers = None;
1084            self.moe_resident = Some(Arc::from(mask));
1085            self.schedule.moe_resident = self.moe_resident.clone();
1086        }
1087
1088        fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
1089            self.moe_resident = None;
1090            self.schedule.moe_resident = None;
1091            let layers: Vec<Arc<[bool]>> = masks.iter().map(|m| Arc::from(*m)).collect();
1092            let arc = Arc::new(layers);
1093            self.moe_resident_layers = Some(arc.clone());
1094            self.schedule.moe_resident_layers = Some(arc);
1095        }
1096
1097        fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
1098            let cap = rlx_cpu::moe_topk_capture::MoeTopkCapture::new(num_experts);
1099            self.moe_topk_capture = Some(cap.clone());
1100            self.schedule.moe_topk_capture = Some(cap);
1101            true
1102        }
1103
1104        fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
1105            let cap = self.moe_topk_capture.as_ref()?;
1106            let layers = cap.take_layers();
1107            if layers.is_empty() {
1108                None
1109            } else {
1110                Some(layers)
1111            }
1112        }
1113
1114        fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
1115            rlx_cpu::moe_residency::take_last_forward_stats()
1116        }
1117
1118        /// Typed param upload. F32 / F16 / BF16 go through the existing
1119        /// widen-to-f32 path (the CPU arena is historically f32 with
1120        /// optional half-precision rewrite). F64 (and any future
1121        /// non-widenable dtype) lands directly in the arena as bytes —
1122        /// the f32 path would lose precision.
1123        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1124            if matches!(dtype, DType::F64 | DType::I64 | DType::I32 | DType::U32) {
1125                self.set_param_bytes(name, data, dtype);
1126                return;
1127            }
1128            // U8 / I8 raw byte tensors: opaque storage for the GGUF
1129            // K-quant `Op::DequantMatMul` path (weights stay packed
1130            // in the arena). One arena byte = one element.
1131            if matches!(dtype, DType::U8 | DType::I8) {
1132                self.set_param_bytes(name, data, dtype);
1133                return;
1134            }
1135            if dtype == DType::F32 {
1136                let n = data.len() / 4;
1137                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1138                self.set_param(name, s);
1139            } else {
1140                let f32_buf = super::widen_bytes_to_f32(data, dtype);
1141                self.set_param(name, &f32_buf);
1142            }
1143        }
1144
1145        /// Typed run with mixed-dtype inputs/outputs.
1146        ///
1147        /// For each input: if its declared graph dtype matches the
1148        /// caller's bytes, we write directly into the arena (zero
1149        /// precision loss — F64 stays F64). For F32 with a half-precision
1150        /// arena rewrite, we widen as before. F16/BF16 callers go
1151        /// through the existing widen path.
1152        ///
1153        /// Outputs are read straight from the arena in the graph node's
1154        /// declared dtype — F64 outputs come back as 8 bytes/element,
1155        /// F32 as 4, etc.
1156        fn run_typed(
1157            &mut self,
1158            inputs: &[(&str, &[u8], rlx_ir::DType)],
1159        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1160            // Decide: are *all* inputs F64? If so, use the direct-byte
1161            // path for everything and skip the f32 widening machinery
1162            // entirely. Mixed dtype graphs (F32 + F64) take the
1163            // per-input dispatch route below.
1164            let all_f64 = !inputs.is_empty() && inputs.iter().all(|(_, _, dt)| *dt == DType::F64);
1165
1166            if all_f64 {
1167                for (name, data, _) in inputs {
1168                    if let Some(&id) = self.input_ids.get(*name) {
1169                        if !self.arena.has_buffer(id) {
1170                            continue;
1171                        }
1172                        let off = self.arena.byte_offset(id);
1173                        let buf = self.arena.raw_buf_mut();
1174                        let n = data.len();
1175                        debug_assert!(
1176                            off + n <= buf.len(),
1177                            "run_typed: input '{name}' overflows arena slot"
1178                        );
1179                        buf[off..off + n].copy_from_slice(data);
1180                    }
1181                }
1182                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1183            } else {
1184                // Mixed-dtype path: dtypes that survive untouched
1185                // through the f32-aliased arena (F64, I32, I64, U32)
1186                // go in as bytes; F32 and the half-precision family
1187                // route through widen-to-f32 + run.
1188                let mut f32_owned: Vec<(String, Vec<f32>)> = Vec::new();
1189                for (name, data, dt) in inputs {
1190                    let direct = matches!(
1191                        *dt,
1192                        DType::F64 | DType::I32 | DType::I64 | DType::U32 | DType::C64
1193                    );
1194                    if direct {
1195                        if let Some(&id) = self.input_ids.get(*name) {
1196                            if !self.arena.has_buffer(id) {
1197                                continue;
1198                            }
1199                            let off = self.arena.byte_offset(id);
1200                            let buf = self.arena.raw_buf_mut();
1201                            buf[off..off + data.len()].copy_from_slice(data);
1202                        }
1203                    } else {
1204                        let v = super::widen_bytes_to_f32(data, *dt);
1205                        f32_owned.push((name.to_string(), v));
1206                    }
1207                }
1208                for (name, data) in &f32_owned {
1209                    if let Some(&id) = self.input_ids.get(name.as_str()) {
1210                        if self.arena.has_buffer(id) {
1211                            self.write_input(id, data);
1212                        }
1213                    }
1214                }
1215                let active_used = if let Some((actual, upper)) = self.active_extent {
1216                    thunk::execute_thunks_active(
1217                        &self.schedule,
1218                        self.arena.raw_buf_mut(),
1219                        actual,
1220                        upper,
1221                    )
1222                } else {
1223                    false
1224                };
1225                if !active_used {
1226                    thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1227                }
1228            }
1229
1230            // Read outputs in declared boundary dtypes.
1231            self.graph
1232                .outputs
1233                .iter()
1234                .enumerate()
1235                .map(|(idx, &id)| {
1236                    let exec_dtype = self.graph.node(id).shape.dtype();
1237                    let declared = self.io_manifest.output_dtype(idx, exec_dtype);
1238                    if matches!(
1239                        exec_dtype,
1240                        DType::F64
1241                            | DType::F16
1242                            | DType::BF16
1243                            | DType::I32
1244                            | DType::I64
1245                            | DType::U32
1246                            | DType::C64
1247                    ) {
1248                        let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
1249                        let n_bytes = n_elems * exec_dtype.size_bytes();
1250                        let off = self.arena.byte_offset(id);
1251                        let bytes = self.arena.raw_buf()[off..off + n_bytes].to_vec();
1252                        return (bytes, declared);
1253                    }
1254                    let f32_vals = self.read_output(id);
1255                    if declared != exec_dtype {
1256                        return (super::narrow_f32_to_bytes(&f32_vals, declared), declared);
1257                    }
1258                    let bytes = f32_vals.iter().flat_map(|v| v.to_le_bytes()).collect();
1259                    (bytes, declared)
1260                })
1261                .collect()
1262        }
1263    }
1264
1265    impl CpuExecutable {
1266        /// Clear ephemeral arena slots, then restore compile-time constants
1267        /// and cached params. Intermediate buffers are reused across `run()`
1268        /// calls; without this reset, a second execution can read stale data
1269        /// from the previous pass.
1270        fn restore_arena_baseline(&mut self) {
1271            self.arena.raw_buf_mut().fill(0);
1272            let constants: Vec<(NodeId, DType, Vec<u8>)> = self
1273                .graph
1274                .nodes()
1275                .iter()
1276                .filter_map(|node| {
1277                    if let Op::Constant { data } = &node.op
1278                        && self.arena.has_buffer(node.id)
1279                        && !data.is_empty()
1280                    {
1281                        Some((node.id, node.shape.dtype(), data.clone()))
1282                    } else {
1283                        None
1284                    }
1285                })
1286                .collect();
1287            for (id, dtype, data) in constants {
1288                self.write_constant_to_arena(id, dtype, &data);
1289            }
1290            let params = self.params.clone();
1291            for (name, data) in params {
1292                if let Some(&id) = self.param_ids.get(&name)
1293                    && self.arena.has_buffer(id)
1294                {
1295                    let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
1296                    let off = self.arena.byte_offset(id);
1297                    let buf = self.arena.raw_buf_mut();
1298                    let elem_size = dtype.size_bytes();
1299                    let max_elems = (buf.len() - off) / elem_size;
1300                    unsafe {
1301                        write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, &data, max_elems);
1302                    }
1303                }
1304            }
1305            let typed = self.typed_params.clone();
1306            for (name, (data, dtype)) in typed {
1307                self.write_param_bytes_to_arena(&name, &data);
1308                let _ = dtype;
1309            }
1310        }
1311
1312        fn write_constant_to_arena(&mut self, id: NodeId, dtype: DType, data: &[u8]) {
1313            match dtype {
1314                DType::F64 | DType::F16 | DType::BF16 | DType::U8 | DType::I8 => {
1315                    let off = self.arena.byte_offset(id);
1316                    let buf = self.arena.raw_buf_mut();
1317                    let n = buf.len().saturating_sub(off).min(data.len());
1318                    buf[off..off + n].copy_from_slice(&data[..n]);
1319                }
1320                _ => {
1321                    let buf = self.arena.slice_mut(id);
1322                    let n_floats = data.len() / 4;
1323                    let n = buf.len().min(n_floats);
1324                    for i in 0..n {
1325                        let bytes = [
1326                            data[i * 4],
1327                            data[i * 4 + 1],
1328                            data[i * 4 + 2],
1329                            data[i * 4 + 3],
1330                        ];
1331                        buf[i] = f32::from_le_bytes(bytes);
1332                    }
1333                }
1334            }
1335        }
1336
1337        /// Direct-byte param upload — copies caller's bytes into the
1338        /// arena slot for the named param without any dtype conversion.
1339        /// Used by `set_param_typed` for dtypes that f32-widening would
1340        /// corrupt (F64). Caller is responsible for matching the param's
1341        /// declared graph dtype.
1342        fn set_param_bytes(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1343            self.typed_params
1344                .insert(name.to_string(), (data.to_vec(), dtype));
1345            self.params.remove(name);
1346            self.write_param_bytes_to_arena(name, data);
1347        }
1348
1349        fn write_param_bytes_to_arena(&mut self, name: &str, data: &[u8]) {
1350            if let Some(&id) = self.param_ids.get(name)
1351                && self.arena.has_buffer(id)
1352            {
1353                let off = self.arena.byte_offset(id);
1354                let buf = self.arena.raw_buf_mut();
1355                debug_assert!(
1356                    off + data.len() <= buf.len(),
1357                    "set_param_bytes: '{name}' would overflow arena slot"
1358                );
1359                buf[off..off + data.len()].copy_from_slice(data);
1360            }
1361        }
1362    }
1363}
1364
1365// ── Metal Backend ───────────────────────────────────────────────────────
1366
1367// ── wgpu Backend ────────────────────────────────────────────────────────
1368
1369#[cfg(feature = "gpu")]
1370pub mod wgpu_backend {
1371    use super::*;
1372    use rlx_ir::OpKind;
1373    use rlx_wgpu::backend::WgpuExecutable;
1374
1375    pub struct WgpuBackend;
1376
1377    /// PLAN L4: ops the wgpu backend can lower today. The fused
1378    /// macro-kernels (FAB, FTL, FusedSwiGLU) get decomposed by
1379    /// `crate::unfuse::unfuse` upstream — they're listed here too so
1380    /// graphs that already contain them legalize cleanly. Conv1d/3d
1381    /// and Pool1d/3d are deferred (Conv2d only).
1382    const WGPU_SUPPORTED_OPS: &[OpKind] = &[
1383        OpKind::Input,
1384        OpKind::Param,
1385        OpKind::Constant,
1386        OpKind::Activation,
1387        OpKind::Cast,
1388        OpKind::StopGradient,
1389        OpKind::Binary,
1390        OpKind::Compare,
1391        OpKind::Where,
1392        OpKind::ElementwiseRegion,
1393        OpKind::TransformRegion,
1394        OpKind::BatchElementwiseRegion,
1395        OpKind::MatMul,
1396        OpKind::DotGeneral,
1397        OpKind::LayerNorm,
1398        OpKind::RmsNorm,
1399        OpKind::Attention,
1400        OpKind::AttentionBackward,
1401        OpKind::RmsNormBackwardInput,
1402        OpKind::RmsNormBackwardGamma,
1403        OpKind::RmsNormBackwardBeta,
1404        // LayerNorm backward family:
1405        //   * Input  — single workgroup-per-row fused kernel.
1406        //   * Gamma  — two-dispatch (partial + reduce) that uses a tail
1407        //              scratch zone in the arena to hold per-chunk
1408        //              partial sums; the reduce kernel sums them.
1409        // Both beat the autodiff-decomposed primitive chain.
1410        OpKind::LayerNormBackwardInput,
1411        OpKind::LayerNormBackwardGamma,
1412        OpKind::RopeBackward,
1413        OpKind::CumsumBackward,
1414        OpKind::GatherBackward,
1415        OpKind::Rope,
1416        OpKind::Reshape,
1417        OpKind::Transpose,
1418        OpKind::Narrow,
1419        OpKind::Concat,
1420        OpKind::Expand,
1421        OpKind::Gather,
1422        OpKind::Reduce,
1423        OpKind::Softmax,
1424        OpKind::Cumsum,
1425        OpKind::TopK,
1426        OpKind::Sample,
1427        OpKind::Conv,
1428        OpKind::Im2Col,
1429        OpKind::Pool,
1430        OpKind::GroupedMatMul,
1431        OpKind::DequantGroupedMatMul,
1432        OpKind::DequantMoEWeights,
1433        OpKind::ScatterAdd,
1434        OpKind::SelectiveScan,
1435        OpKind::DequantMatMul,
1436        OpKind::FusedMatMulBiasAct,
1437        OpKind::FusedResidualLN,
1438        OpKind::FusedResidualRmsNorm,
1439        OpKind::FusedSwiGLU,
1440        OpKind::FusedAttentionBlock,
1441        OpKind::FusedTransformerLayer,
1442        // Native FFT (WGSL radix-2): f32 only, power-of-2 N ≤ 1024.
1443        // Anything outside that envelope panics at lowering with a
1444        // "pin to Device::Cpu" hint. No host fallback — WGPU has no
1445        // unified memory, so silent CPU round-trip would be a hidden
1446        // performance cliff.
1447        OpKind::Fft,
1448        OpKind::LogMel,
1449        OpKind::LogMelBackward,
1450        // 3D Gaussian splat: native Metal / CPU reference per backend.
1451        OpKind::GaussianSplatRender,
1452        OpKind::GaussianSplatRenderBackward,
1453        OpKind::GaussianSplatPrepare,
1454        OpKind::GaussianSplatRasterize,
1455        OpKind::Custom,
1456        // LoRA, If, While: not yet wired in wgpu — fail loudly.
1457    ];
1458
1459    impl Backend for WgpuBackend {
1460        fn supported_ops(&self) -> &'static [OpKind] {
1461            WGPU_SUPPORTED_OPS
1462        }
1463
1464        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1465            use rlx_opt::pass::Pass as _;
1466            let graph = rlx_opt::LowerControlFlow.run(graph);
1467            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, WGPU_SUPPORTED_OPS)
1468                .unwrap_or_else(|errors| {
1469                    panic!("{}", rlx_opt::format_legalize_error("wgpu", &errors));
1470                });
1471            let graph = crate::precompile::precompile_cleanup(graph, options);
1472            // Materialize mid-axis broadcasts before MarkElementwiseRegions:
1473            // wgpu Binary/region kernels only handle trailing/scalar broadcast
1474            // via modulus; EEG patch embed uses [1,C,1,D] + [1,C,P,D].
1475            let graph = rlx_opt::LegalizeBroadcast.run(graph);
1476            // ORDER MATTERS: targeted-pattern fusions run BEFORE the
1477            // catch-all `MarkElementwiseRegions`. Otherwise the region
1478            // pass swallows the Add / Activation nodes into chains and
1479            // FuseMatMulBiasAct / FuseResidualLN fail to match the
1480            // narrower patterns they look for. (Metal pipeline at line
1481            // ~377 already orders these correctly; wgpu was inverted
1482            // and silently shipped 13 unfused LayerNorms per BERT
1483            // forward where 12 should have been FusedResidualLN.)
1484            let compile_result = crate::stages::compile_graph_stages_for_backend(
1485                rlx_driver::Device::Gpu,
1486                graph,
1487                options,
1488                WGPU_SUPPORTED_OPS,
1489            );
1490            crate::stages::maybe_log_fusion(&compile_result.fusion);
1491            let graph = compile_result.lir.into_graph();
1492            let graph = match options.policy.clone() {
1493                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
1494                None => graph,
1495            };
1496            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1497            Box::new(WgpuExecutableWrapper {
1498                inner: WgpuExecutable::compile(graph),
1499                io_manifest,
1500            })
1501        }
1502
1503        fn compile_lir(
1504            &self,
1505            lir: LirModule,
1506            options: &CompileOptions,
1507        ) -> Box<dyn ExecutableGraph> {
1508            use rlx_opt::pass::Pass as _;
1509            // LIR may already contain fused ElementwiseRegions; legalize
1510            // broadcasts on the unfused graph shape before backend prep.
1511            let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
1512            let graph = prepare_fused_graph(graph, options, WGPU_SUPPORTED_OPS, "wgpu");
1513            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1514            Box::new(WgpuExecutableWrapper {
1515                inner: WgpuExecutable::compile(graph),
1516                io_manifest,
1517            })
1518        }
1519    }
1520
    /// Adapter that exposes a compiled `WgpuExecutable` through the
    /// `ExecutableGraph` trait.
    struct WgpuExecutableWrapper {
        /// The compiled wgpu executable that every trait call is forwarded to.
        inner: WgpuExecutable,
        /// I/O dtype manifest returned by `prepare_f32_exec_graph` at compile
        /// time; `run_typed` passes it to `declared_output_dtypes` to pick the
        /// dtype each output is narrowed to.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }
1525
    // SAFETY(review): this asserts that the wrapper — and therefore the
    // `WgpuExecutable` and manifest it owns — may be moved to another thread.
    // The justification is not visible in this part of the file; TODO confirm
    // against the thread-affinity of `WgpuExecutable`'s internals.
    unsafe impl Send for WgpuExecutableWrapper {}
1527
1528    impl ExecutableGraph for WgpuExecutableWrapper {
1529        fn set_param(&mut self, name: &str, data: &[f32]) {
1530            self.inner.set_param(name, data);
1531        }
1532        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1533            self.inner.run(inputs)
1534        }
1535        fn run_read_outputs(
1536            &mut self,
1537            inputs: &[(&str, &[f32])],
1538            read_indices: Option<&[usize]>,
1539        ) -> Vec<Vec<f32>> {
1540            self.inner.run_read_outputs(inputs, read_indices)
1541        }
1542        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
1543            self.inner.bind_gpu_handle(name, data)
1544        }
1545        fn has_gpu_handle(&self, name: &str) -> bool {
1546            self.inner.has_gpu_handle(name)
1547        }
1548        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
1549            self.inner.set_gpu_handle_feed(handle_name, output_index);
1550            true
1551        }
1552        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
1553            self.inner.read_gpu_handle(name)
1554        }
1555        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1556            self.inner.set_active_extent(extent);
1557        }
1558
1559        /// Typed param upload: widens F16/BF16 to F32 at the host boundary,
1560        /// since the wgpu arena is f32-uniform.
1561        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1562            match dtype {
1563                rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
1564                    self.inner.set_param_bytes(name, data);
1565                }
1566                rlx_ir::DType::F32 => {
1567                    let n = data.len() / 4;
1568                    let f32_slice =
1569                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1570                    self.inner.set_param(name, f32_slice);
1571                }
1572                rlx_ir::DType::F16 => {
1573                    let n = data.len() / 2;
1574                    let f16_slice =
1575                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
1576                    let f32: Vec<f32> = f16_slice.iter().map(|h| h.to_f32()).collect();
1577                    self.inner.set_param(name, &f32);
1578                }
1579                rlx_ir::DType::BF16 => {
1580                    let n = data.len() / 2;
1581                    let bf16_slice = unsafe {
1582                        std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1583                    };
1584                    let f32: Vec<f32> = bf16_slice.iter().map(|h| h.to_f32()).collect();
1585                    self.inner.set_param(name, &f32);
1586                }
1587                other => panic!(
1588                    "rlx-wgpu set_param_typed: dtype {other:?} unsupported \
1589                                 (F32, F16, BF16 only — wgpu arena is f32-uniform)"
1590                ),
1591            }
1592        }
1593
1594        /// Typed run: widen each typed input to F32, run, then narrow each
1595        /// output back to its declared dtype.
1596        fn run_typed(
1597            &mut self,
1598            inputs: &[(&str, &[u8], rlx_ir::DType)],
1599        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1600            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1601            for (name, data, dt) in inputs {
1602                let v: Vec<f32> = match *dt {
1603                    rlx_ir::DType::F32 => {
1604                        let n = data.len() / 4;
1605                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }
1606                            .to_vec()
1607                    }
1608                    rlx_ir::DType::F16 => {
1609                        let n = data.len() / 2;
1610                        let s = unsafe {
1611                            std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n)
1612                        };
1613                        s.iter().map(|h| h.to_f32()).collect()
1614                    }
1615                    rlx_ir::DType::BF16 => {
1616                        let n = data.len() / 2;
1617                        let s = unsafe {
1618                            std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1619                        };
1620                        s.iter().map(|h| h.to_f32()).collect()
1621                    }
1622                    other => {
1623                        panic!("rlx-wgpu run_typed: input '{name}' dtype {other:?} unsupported")
1624                    }
1625                };
1626                owned.push((name.to_string(), v));
1627            }
1628            let refs: Vec<(&str, &[f32])> = owned
1629                .iter()
1630                .map(|(n, d)| (n.as_str(), d.as_slice()))
1631                .collect();
1632            let dtypes =
1633                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
1634            let outs = self.inner.run(&refs);
1635            outs.into_iter()
1636                .zip(
1637                    dtypes
1638                        .into_iter()
1639                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
1640                )
1641                .map(|(v, dt)| (narrow_to_dtype(&v, dt), dt))
1642                .collect()
1643        }
1644    }
1645
1646    /// Cast every element of a wgpu f32 output buffer down to the
1647    /// declared output dtype, returning the corresponding byte stream.
1648    /// The arena keeps every value as f32; declared output dtypes
1649    /// (Bool, I8, I32, F16, ...) require an exit-time narrowing to be
1650    /// byte-identical with backends that store the native dtype.
1651    fn narrow_to_dtype(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
1652        use rlx_ir::DType;
1653        match dt {
1654            DType::F32 => {
1655                let mut bytes = Vec::with_capacity(v.len() * 4);
1656                for &x in v {
1657                    bytes.extend_from_slice(&x.to_le_bytes());
1658                }
1659                bytes
1660            }
1661            DType::F16 => {
1662                let mut bytes = Vec::with_capacity(v.len() * 2);
1663                for &x in v {
1664                    bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
1665                }
1666                bytes
1667            }
1668            DType::BF16 => {
1669                let mut bytes = Vec::with_capacity(v.len() * 2);
1670                for &x in v {
1671                    bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
1672                }
1673                bytes
1674            }
1675            DType::F64 => {
1676                let mut bytes = Vec::with_capacity(v.len() * 8);
1677                for &x in v {
1678                    bytes.extend_from_slice(&(x as f64).to_le_bytes());
1679                }
1680                bytes
1681            }
1682            DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
1683            DType::U8 => v.iter().map(|&x| x as u8).collect(),
1684            DType::I16 => {
1685                let mut bytes = Vec::with_capacity(v.len() * 2);
1686                for &x in v {
1687                    bytes.extend_from_slice(&(x as i16).to_le_bytes());
1688                }
1689                bytes
1690            }
1691            DType::I32 => {
1692                let mut bytes = Vec::with_capacity(v.len() * 4);
1693                for &x in v {
1694                    bytes.extend_from_slice(&(x as i32).to_le_bytes());
1695                }
1696                bytes
1697            }
1698            DType::U32 => {
1699                let mut bytes = Vec::with_capacity(v.len() * 4);
1700                for &x in v {
1701                    bytes.extend_from_slice(&(x as u32).to_le_bytes());
1702                }
1703                bytes
1704            }
1705            DType::I64 => {
1706                let mut bytes = Vec::with_capacity(v.len() * 8);
1707                for &x in v {
1708                    bytes.extend_from_slice(&(x as i64).to_le_bytes());
1709                }
1710                bytes
1711            }
1712            DType::Bool => v
1713                .iter()
1714                .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
1715                .collect(),
1716            // C64 (complex f32 pair) — the wgpu backend's f32 arena
1717            // doesn't synthesize complex outputs today; this branch
1718            // only fires if a graph somehow asks for a C64 output and
1719            // the backend lowered it as 2N real floats. We pass the
1720            // raw f32 stream straight through; downstream code that
1721            // wants complex semantics is responsible for re-pairing.
1722            DType::C64 => {
1723                let mut bytes = Vec::with_capacity(v.len() * 4);
1724                for &x in v {
1725                    bytes.extend_from_slice(&x.to_le_bytes());
1726                }
1727                bytes
1728            }
1729        }
1730    }
1731}
1732
1733// ── MLX Backend ─────────────────────────────────────────────────────────
1734
1735#[cfg(all(feature = "mlx", rlx_mlx_host))]
1736pub mod mlx_backend {
1737    use super::*;
1738    use rlx_mlx::MlxExecutable;
1739
    /// Marker type for the MLX backend. It carries no state; its `Backend`
    /// impl compiles graphs into `MlxExecutable`-backed executables.
    pub struct MlxBackend;
1741
    /// PLAN L4: ops the MLX backend can lower today. MLX has the
    /// widest IR coverage of any GPU backend — handles everything
    /// including If/While via topo unrolling, and lowers
    /// ElementwiseRegion natively via the per-step composition in
    /// rlx-mlx/src/lower.rs (PLAN L2).
    ///
    /// `GroupNorm` / `BatchNormInference` are intentionally omitted — lowered
    /// to primitives via [`LowerGroupNorm`] / [`LowerBatchNormInference`]
    /// before MLX lowering (no native MLX kernel).
    ///
    /// NOTE(review): unlike the wgpu and Metal tables in this file, `Im2Col`,
    /// `GaussianSplatPrepare` and `GaussianSplatRasterize` are not listed
    /// here — confirm that is intentional.
    const MLX_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            LayerNorm,
            LayerNorm2d,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
            If,
            While,
            // Loop-unrolled scan (Op::Scan body is statically unrolled
            // `length` times into MLX ops; mirror of Op::While's
            // bounded-unroll lowering). ScanBackward is the AD
            // companion — handled the same way.
            Scan,
            ScanBackward,
            ScanBackwardXs,
            // Tier 1 autodiff backward ops — lowered as primitive
            // compositions in `rlx-mlx/src/lower.rs`.
            ReluBackward,
            ActivationBackward,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            // Tier 2 — conv backward via `mc::conv_general` with the
            // same parameter-mapping MLX uses inside its built-in vjp.
            // Currently groups=1 only; grouped conv backward will
            // surface as a clear error from `lower.rs`.
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            // Tier 3 — max-pool backward via slice-strided argmax over
            // pool windows + a per-kernel-slot scatter-add, matching
            // the CPU thunk's "first-hit-wins" tiebreaking.
            MaxPool2dBackward,
            // QAT — `FakeQuantize` (PerBatch + Fixed scale modes;
            // EMA returns a clear error from `lower.rs`) and the
            // `FakeQuantizeBackward` family covering all 4 STE
            // variants. Closes the last gap vs `CPU_SUPPORTED_OPS`.
            FakeQuantize,
            FakeQuantizeBackward,
            // User-registered custom ops dispatched through
            // `rlx_mlx::op_registry`. Lowering looks up the
            // registered `MlxKernel` and calls its `execute` method
            // to produce the lazy MLX `Array` for this node.
            Custom,
            // Op::Fft on MLX: native `mlx::fft::fft` via the rlx_mlx_op_fft
            // shim. 2N real-block f32/f64 and complex64 inputs supported.
            Fft,
            LogMel,
            LogMelBackward,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
        ]
    };
1852
1853    impl Backend for MlxBackend {
1854        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
1855            MLX_SUPPORTED_OPS
1856        }
1857
1858        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1859            let compile_result = crate::stages::compile_graph_stages_for_backend(
1860                rlx_driver::Device::Mlx,
1861                graph,
1862                options,
1863                MLX_SUPPORTED_OPS,
1864            );
1865            crate::stages::maybe_log_fusion(&compile_result.fusion);
1866            self.compile_lir(compile_result.lir, options)
1867        }
1868
1869        fn compile_lir(
1870            &self,
1871            lir: LirModule,
1872            options: &CompileOptions,
1873        ) -> Box<dyn ExecutableGraph> {
1874            use rlx_opt::pass::Pass as _;
1875            let mut graph = lir.into_graph();
1876            graph = rlx_opt::LowerControlFlow.run(graph);
1877            let graph = prepare_fused_graph(graph, options, MLX_SUPPORTED_OPS, "mlx");
1878            Box::new(build_mlx_executable(graph))
1879        }
1880    }
1881
1882    fn build_mlx_executable(graph: Graph) -> MlxExecutableWrapper {
1883        let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1884        let mode = mlx_mode_from_env();
1885        let mut exe = MlxExecutable::compile_from_fused(graph, mode);
1886        if mode == rlx_mlx::lower::MlxMode::Compiled {
1887            if let Err(e) = exe.warm_compile() {
1888                eprintln!(
1889                    "[rlx-runtime] MLX warm_compile failed ({e}); first run will pay the trace cost"
1890                );
1891            }
1892        }
1893        MlxExecutableWrapper {
1894            inner: exe,
1895            io_manifest,
1896        }
1897    }
1898
1899    fn mlx_mode_from_env() -> rlx_mlx::lower::MlxMode {
1900        match rlx_ir::env::var("RLX_MLX_MODE").as_deref() {
1901            Some(s) if s.eq_ignore_ascii_case("eager") => rlx_mlx::lower::MlxMode::Eager,
1902            Some(s) if s.eq_ignore_ascii_case("lazy") => rlx_mlx::lower::MlxMode::Lazy,
1903            Some(s) if s.eq_ignore_ascii_case("compiled") => rlx_mlx::lower::MlxMode::Compiled,
1904            _ => rlx_mlx::lower::MlxMode::Compiled,
1905        }
1906    }
1907
    /// Adapter that exposes a compiled `MlxExecutable` through the
    /// `ExecutableGraph` trait.
    struct MlxExecutableWrapper {
        /// The compiled MLX executable that every trait call is forwarded to.
        inner: MlxExecutable,
        /// I/O dtype manifest returned by `prepare_f32_exec_graph` in
        /// `build_mlx_executable`; `run_typed` passes it to
        /// `declared_output_dtypes` to pick each output's dtype.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }
1912
    // SAFETY(review): this asserts that the wrapper — and therefore the
    // `MlxExecutable` and manifest it owns — may be moved to another thread.
    // The justification is not visible in this part of the file; TODO confirm
    // against the thread-affinity of `MlxExecutable`'s internals.
    unsafe impl Send for MlxExecutableWrapper {}
1914
1915    impl ExecutableGraph for MlxExecutableWrapper {
1916        fn set_param(&mut self, name: &str, data: &[f32]) {
1917            self.inner.set_param(name, data);
1918        }
1919        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1920            self.inner.run(inputs)
1921        }
1922        fn run_read_outputs(
1923            &mut self,
1924            inputs: &[(&str, &[f32])],
1925            read_indices: Option<&[usize]>,
1926        ) -> Vec<Vec<f32>> {
1927            self.inner
1928                .run_read_outputs(inputs, read_indices)
1929                .unwrap_or_else(|e| panic!("MLX run_read_outputs failed: {e}"))
1930        }
1931        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1932            self.inner.run_slots(inputs)
1933        }
1934        fn arena_ptr(&self) -> *const u8 {
1935            self.inner.arena_ptr()
1936        }
1937        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
1938            self.inner.commit_no_wait(inputs);
1939        }
1940        fn sync_pending(&mut self) {
1941            self.inner.sync_pending();
1942        }
1943        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
1944            self.inner.run_pipelined(input_sets)
1945        }
1946        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1947            self.inner.bind_handle(name, data)
1948        }
1949        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1950            self.inner.read_handle(name)
1951        }
1952        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
1953            self.inner.bind_gpu_handle(name, data).is_ok()
1954        }
1955        fn has_gpu_handle(&self, name: &str) -> bool {
1956            self.inner.has_gpu_handle(name)
1957        }
1958        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
1959            self.inner.set_gpu_handle_feed(handle_name, output_index);
1960            true
1961        }
1962        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
1963            self.inner.read_gpu_handle(name).ok()
1964        }
1965        fn run_feed_gpu_handle(
1966            &mut self,
1967            inputs: &[(&str, &[f32])],
1968            handle_name: &str,
1969            output_index: usize,
1970        ) -> Option<Vec<f32>> {
1971            self.inner
1972                .run_feed_gpu(inputs, handle_name, output_index)
1973                .ok()
1974        }
1975        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1976            self.inner.set_param_typed(name, data, dtype);
1977        }
1978        fn run_typed(
1979            &mut self,
1980            inputs: &[(&str, &[u8], rlx_ir::DType)],
1981        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1982            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1983            for (name, data, dt) in inputs {
1984                let v = super::widen_bytes_to_f32(data, *dt);
1985                owned.push((name.to_string(), v));
1986            }
1987            let refs: Vec<(&str, &[f32])> = owned
1988                .iter()
1989                .map(|(n, d)| (n.as_str(), d.as_slice()))
1990                .collect();
1991            let f32_outs = self.inner.run(&refs);
1992            let declared = super::declared_output_dtypes(
1993                &self.io_manifest,
1994                (0..f32_outs.len()).map(|_| rlx_ir::DType::F32).collect(),
1995            );
1996            f32_outs
1997                .into_iter()
1998                .zip(
1999                    declared
2000                        .into_iter()
2001                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2002                )
2003                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2004                .collect()
2005        }
2006        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2007            self.inner.set_active_extent(extent);
2008        }
2009    }
2010}
2011
2012#[cfg(all(feature = "metal", target_os = "macos"))]
2013pub mod metal_backend {
2014    use super::*;
2015    use rlx_metal::backend::MetalExecutable;
2016
    /// Marker type for the Metal backend. It carries no state; its `Backend`
    /// impl compiles graphs into `MetalExecutable`-backed executables.
    pub struct MetalBackend;
2018
    /// PLAN L4: ops the Metal backend can lower today. Includes
    /// DotGeneral (LowerDotGeneral pass) and ElementwiseRegion
    /// (decomposed by UnfuseElementwiseRegions). Excludes Cumsum,
    /// SelectiveScan, LoraMatMul, Sample,
    /// FusedAttentionBlock, FusedTransformerLayer, If, While —
    /// not yet wired in `rlx-metal/src/thunk.rs`'s compile_thunks.
    /// DequantMatMul (GGUF K-quants) lowers to a GPU dequant kernel
    /// + MPS matmul; legacy Int8 schemes remain CPU-only.
    ///
    /// The same table is used for `supported_ops()`, for legalization, and
    /// is passed to `MetalExecutable` at compile time.
    const METAL_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            TopK,
            Conv,
            Im2Col,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            // User-registered custom ops dispatched through
            // `rlx_metal::op_registry`. Lowering panics with a clear
            // message if the named MetalKernel isn't registered;
            // executor inserts a sync point + runs the host kernel
            // against the unified-memory arena.
            Custom,
            // Op::Fft is supported via the same host-fallback pattern
            // as Custom: sync the GPU, run rlx-cpu's FFT against the
            // unified-memory arena, restart cmd_buf. A native Metal
            // compute kernel will replace this when a workload makes
            // the sync the bottleneck.
            Fft,
            LogMel,
            LogMelBackward,
            // Host-fallback splat (unified-memory arena + rlx-cpu/splat).
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
        ]
    };
2107
2108    impl Backend for MetalBackend {
2109        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2110            METAL_SUPPORTED_OPS
2111        }
2112
2113        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2114            use rlx_opt::pass::Pass as _;
2115            // Same If/While → primitive rewrite as the CPU pipeline
2116            // (Metal also has no native sub-graph executor wired
2117            // through its thunk schedule).
2118            let graph = rlx_opt::LowerControlFlow.run(graph);
2119            let dispatch = options.kernel_dispatch;
2120            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2121                graph,
2122                METAL_SUPPORTED_OPS,
2123                dispatch,
2124            )
2125            .unwrap_or_else(|errors| {
2126                panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2127            });
2128            let graph = crate::precompile::precompile_cleanup(graph, options);
2129
2130            // Hand the policy to MetalExecutable so the rewrite runs AFTER
2131            // its internal fusion passes (avoids breaking pattern matchers).
2132            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2133            Box::new(MetalExecutableWrapper {
2134                inner: MetalExecutable::compile_with_policy(
2135                    graph,
2136                    options.policy.clone(),
2137                    Some(METAL_SUPPORTED_OPS),
2138                ),
2139                io_manifest,
2140            })
2141        }
2142
2143        fn compile_lir(
2144            &self,
2145            lir: LirModule,
2146            options: &CompileOptions,
2147        ) -> Box<dyn ExecutableGraph> {
2148            use rlx_opt::pass::Pass as _;
2149            let mut graph = lir.into_graph();
2150            graph = rlx_opt::LowerControlFlow.run(graph);
2151            let dispatch = options.kernel_dispatch;
2152            let mut graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2153                graph,
2154                METAL_SUPPORTED_OPS,
2155                dispatch,
2156            )
2157            .unwrap_or_else(|errors| {
2158                panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2159            });
2160            graph = crate::precompile::precompile_cleanup(graph, options);
2161            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2162            Box::new(MetalExecutableWrapper {
2163                inner: MetalExecutable::compile_from_fused(
2164                    graph,
2165                    options.policy.clone(),
2166                    Some(METAL_SUPPORTED_OPS),
2167                ),
2168                io_manifest,
2169            })
2170        }
2171    }
2172
    /// Adapter that exposes a compiled `MetalExecutable` through the
    /// `ExecutableGraph` trait.
    struct MetalExecutableWrapper {
        /// The compiled Metal executable that trait calls are forwarded to.
        inner: MetalExecutable,
        /// I/O dtype manifest returned by `prepare_f32_exec_graph` at compile
        /// time. NOTE(review): its readers are outside this excerpt —
        /// presumably the typed-run path; confirm.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }
2177
    // SAFETY(review): this asserts that the wrapper — and therefore the
    // `MetalExecutable` and manifest it owns — may be moved to another thread.
    // The justification is not visible in this part of the file; TODO confirm
    // against the thread-affinity of `MetalExecutable`'s internals.
    unsafe impl Send for MetalExecutableWrapper {}
2179
2180    impl ExecutableGraph for MetalExecutableWrapper {
        /// Upload an `f32` parameter; forwards to the wrapped `MetalExecutable`.
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        /// Run with named `f32` inputs; forwards to the wrapped
        /// `MetalExecutable` and returns its outputs unchanged.
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        /// Run and read back outputs, passing `read_indices` through to the
        /// wrapped `MetalExecutable` unchanged.
        fn run_read_outputs(
            &mut self,
            inputs: &[(&str, &[f32])],
            read_indices: Option<&[usize]>,
        ) -> Vec<Vec<f32>> {
            self.inner.run_read_outputs(inputs, read_indices)
        }
        /// Forwards to the wrapped executable's `bind_gpu_handle` and returns
        /// its result.
        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
            self.inner.bind_gpu_handle(name, data)
        }
        /// Forwards to the wrapped executable's `has_gpu_handle`.
        fn has_gpu_handle(&self, name: &str) -> bool {
            self.inner.has_gpu_handle(name)
        }
        /// Forwards to the wrapped executable's `set_gpu_handle_feed` and
        /// always reports `true`; anything the inner call returns is discarded.
        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
            self.inner.set_gpu_handle_feed(handle_name, output_index);
            true
        }
        /// Forwards to the wrapped executable's `read_gpu_handle`.
        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_gpu_handle(name)
        }
        /// Read part of a graph output via the wrapped executable's
        /// `read_graph_output_row`. The result is always wrapped in `Some`;
        /// this implementation never returns `None`.
        fn read_output_row(
            &self,
            out_idx: usize,
            row: usize,
            row_inner: usize,
        ) -> Option<Vec<f32>> {
            Some(self.inner.read_graph_output_row(out_idx, row, row_inner))
        }
        /// Forwards to the wrapped executable's `run_slots`; the returned
        /// slice borrows from the wrapped executable.
        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
            self.inner.run_slots(inputs)
        }
        /// Forwards to the wrapped executable's `arena_ptr`.
        fn arena_ptr(&self) -> *const u8 {
            self.inner.arena_ptr()
        }
        /// Forwards to the wrapped executable's `commit_no_wait`.
        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
            self.inner.commit_no_wait(inputs);
        }
2224        fn sync_pending(&mut self) {
2225            self.inner.sync_pending();
2226        }
2227        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
2228            self.inner.run_pipelined(input_sets)
2229        }
2230        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2231            self.inner.set_active_extent(extent);
2232        }
2233
2234        /// Typed param upload — accepts F16/BF16 host bytes by widening
2235        /// to F32 first, then routing through `set_param`. The Metal
2236        /// arena's `write_from_f32` honors per-node F16 storage when
2237        /// AutoMixedPrecision rewrote the param. U8/I8 packed weights
2238        /// copy directly into the arena for `Op::DequantMatMul`.
2239        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2240            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2241                self.inner.set_param_bytes(name, data);
2242                return;
2243            }
2244            if dtype == rlx_ir::DType::F32 {
2245                let n = data.len() / 4;
2246                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2247                self.inner.set_param(name, s);
2248            } else {
2249                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2250                self.inner.set_param(name, &f32_buf);
2251            }
2252        }
2253
2254        /// Typed run. Inputs widen to F32 (existing path; F64 host
2255        /// inputs through `run_typed` is a separate Metal extension).
2256        /// Outputs: F64 outputs go through the byte-direct
2257        /// `output_bytes_per_node` path (no precision loss in the
2258        /// f32 round-trip); other dtypes keep the f32-narrow path
2259        /// for backward compatibility with existing AutoMixedPrecision
2260        /// rewrites.
2261        fn run_typed(
2262            &mut self,
2263            inputs: &[(&str, &[u8], rlx_ir::DType)],
2264        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2265            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2266            for (name, data, dt) in inputs {
2267                let v = super::widen_bytes_to_f32(data, *dt);
2268                owned.push((name.to_string(), v));
2269            }
2270            let refs: Vec<(&str, &[f32])> = owned
2271                .iter()
2272                .map(|(n, d)| (n.as_str(), d.as_slice()))
2273                .collect();
2274            let dtypes =
2275                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2276            let f32_outs = self.inner.run(&refs);
2277            let byte_outs = self.inner.output_bytes_per_node();
2278            f32_outs
2279                .into_iter()
2280                .zip(byte_outs.into_iter())
2281                .zip(
2282                    dtypes
2283                        .into_iter()
2284                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2285                )
2286                .map(|((f32_v, byte_v), dt)| match dt {
2287                    rlx_ir::DType::F64 => (byte_v, dt),
2288                    _ => (super::narrow_f32_to_bytes(&f32_v, dt), dt),
2289                })
2290                .collect()
2291        }
2292    }
2293}
2294
2295// ── CUDA Backend ────────────────────────────────────────────────────────
2296
2297#[cfg(feature = "cuda")]
2298pub mod cuda_backend {
2299    use super::*;
2300    use rlx_cuda::backend::CudaExecutable;
2301
    /// Unit type selecting the CUDA backend; its `Backend` impl compiles
    /// graphs into a wrapped `CudaExecutable`.
    pub struct CudaBackend;
2303
    /// PLAN L4: ops the CUDA backend can lower today. Excludes
    /// FusedSwiGLU, LoraMatMul, FusedAttentionBlock,
    /// FusedTransformerLayer (no kernel) + If, While (no executor
    /// wiring). DotGeneral via LowerDotGeneral; ElementwiseRegion
    /// lowered natively by an NVRTC interpreted-chain kernel.
    ///
    /// Returned by `CudaBackend::supported_ops` and passed to the
    /// legalization and stage-compilation calls in `compile` /
    /// `compile_lir`.
    const CUDA_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            Im2Col,
        ]
    };
2376
2377    impl Backend for CudaBackend {
2378        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2379            CUDA_SUPPORTED_OPS
2380        }
2381
2382        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2383            // Decompose FusedSwiGLU / FAB / etc. before legalization (CudaExecutable
2384            // unfuses again; this pass is idempotent).
2385            let graph = rlx_cuda::unfuse::unfuse(graph);
2386            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, CUDA_SUPPORTED_OPS)
2387                .unwrap_or_else(|errors| {
2388                    panic!("{}", rlx_opt::format_legalize_error("cuda", &errors));
2389                });
2390            let graph = crate::precompile::precompile_cleanup(graph, options);
2391            // Mid-axis broadcasts (EEG patch embed) before elementwise fusion.
2392            let graph = rlx_opt::LegalizeBroadcast.run(graph);
2393            // Backend-aware fusion via the shared compile pipeline.
2394            let compile_result = crate::stages::compile_graph_stages_for_backend(
2395                rlx_driver::Device::Cuda,
2396                graph,
2397                options,
2398                CUDA_SUPPORTED_OPS,
2399            );
2400            crate::stages::maybe_log_fusion(&compile_result.fusion);
2401            let graph = compile_result.lir.into_graph();
2402            let graph = match options.policy.clone() {
2403                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2404                None => graph,
2405            };
2406            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2407            Box::new(CudaExecutableWrapper {
2408                inner: CudaExecutable::compile(graph),
2409                io_manifest,
2410            })
2411        }
2412
2413        fn compile_lir(
2414            &self,
2415            lir: LirModule,
2416            options: &CompileOptions,
2417        ) -> Box<dyn ExecutableGraph> {
2418            use rlx_opt::pass::Pass as _;
2419            let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
2420            let (graph, io_manifest) =
2421                cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2422                    rlx_cuda::unfuse::unfuse(graph),
2423                    options,
2424                    CUDA_SUPPORTED_OPS,
2425                    "cuda",
2426                ));
2427            Box::new(CudaExecutableWrapper {
2428                inner: CudaExecutable::compile(graph),
2429                io_manifest,
2430            })
2431        }
2432    }
2433
    /// Adapter that exposes a compiled `CudaExecutable` through the
    /// runtime's `ExecutableGraph` trait.
    struct CudaExecutableWrapper {
        /// Backend executable that the trait methods delegate to.
        inner: CudaExecutable,
        /// I/O dtype manifest; `run_typed` passes it to
        /// `declared_output_dtypes` to pick each output's dtype.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }
2438
    // SAFETY (author's rationale): CudaExecutable owns CudaContext +
    // CudaSlice handles; cudarc claims they're Send (CudaContext is
    // Arc-wrapped, CudaSlice is logically a device pointer + length).
    // The Backend trait requires Send for the executable; we honor that
    // here.
    // NOTE(review): soundness depends on rlx-cuda internals that are not
    // visible from this file — TODO confirm.
    unsafe impl Send for CudaExecutableWrapper {}
2444
2445    impl ExecutableGraph for CudaExecutableWrapper {
2446        fn set_param(&mut self, name: &str, data: &[f32]) {
2447            self.inner.set_param(name, data);
2448        }
2449        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2450            self.inner.run(inputs)
2451        }
2452        fn run_read_outputs(
2453            &mut self,
2454            inputs: &[(&str, &[f32])],
2455            read_indices: Option<&[usize]>,
2456        ) -> Vec<Vec<f32>> {
2457            self.inner.run_read_outputs(inputs, read_indices)
2458        }
2459        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2460            self.inner.bind_gpu_handle(name, data)
2461        }
2462        fn has_gpu_handle(&self, name: &str) -> bool {
2463            self.inner.has_gpu_handle(name)
2464        }
2465        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2466            self.inner.set_gpu_handle_feed(handle_name, output_index);
2467            true
2468        }
2469        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2470            self.inner.read_gpu_handle(name)
2471        }
2472        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2473            self.inner.set_active_extent(extent);
2474        }
2475
2476        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2477            self.inner.run_slots(inputs)
2478        }
2479
2480        fn arena_ptr(&self) -> *const u8 {
2481            self.inner.arena_ptr()
2482        }
2483
2484        /// Typed param upload — widens F16/BF16 host bytes to f32
2485        /// before routing through `set_param`. CUDA's arena is
2486        /// f32-uniform; the half-precision matmul tier opts in via
2487        /// the separate `set_param_half` API.
2488        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2489            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2490                self.inner.set_param_bytes(name, data);
2491                return;
2492            }
2493            if dtype == rlx_ir::DType::F32 {
2494                let n = data.len() / 4;
2495                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2496                self.inner.set_param(name, s);
2497            } else {
2498                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2499                self.inner.set_param(name, &f32_buf);
2500            }
2501        }
2502
2503        /// Typed run — widen each typed input to F32, run, then narrow
2504        /// each output back to its declared graph dtype.
2505        fn run_typed(
2506            &mut self,
2507            inputs: &[(&str, &[u8], rlx_ir::DType)],
2508        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2509            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2510            for (name, data, dt) in inputs {
2511                let v = super::widen_bytes_to_f32(data, *dt);
2512                owned.push((name.to_string(), v));
2513            }
2514            let refs: Vec<(&str, &[f32])> = owned
2515                .iter()
2516                .map(|(n, d)| (n.as_str(), d.as_slice()))
2517                .collect();
2518            let dtypes =
2519                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2520            let outs = self.inner.run(&refs);
2521            outs.into_iter()
2522                .zip(
2523                    dtypes
2524                        .into_iter()
2525                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2526                )
2527                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2528                .collect()
2529        }
2530    }
2531}
2532
2533// ── ROCm Backend ────────────────────────────────────────────────────────
2534
2535#[cfg(feature = "rocm")]
2536pub mod rocm_backend {
2537    use super::*;
2538    use rlx_rocm::backend::RocmExecutable;
2539
    /// Unit type selecting the ROCm backend; its `Backend` impl compiles
    /// graphs into a wrapped `RocmExecutable`.
    pub struct RocmBackend;
2541
    /// PLAN L4: ROCm is the sister crate of CUDA, with the same Step
    /// enum + dispatch shape, so the claimed op set mirrors
    /// `CUDA_SUPPORTED_OPS`.
    ///
    /// NOTE(review): the lists are not actually identical — this one omits
    /// `Conv2dBackwardInput`, `Conv2dBackwardWeight` and
    /// `MaxPool2dBackward`, which the CUDA list includes. Confirm whether
    /// that is intentional (missing kernels) or an oversight.
    const ROCM_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            Im2Col,
        ]
    };
2608
2609    impl Backend for RocmBackend {
2610        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2611            ROCM_SUPPORTED_OPS
2612        }
2613
2614        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2615            let graph = rlx_rocm::unfuse::unfuse(graph);
2616            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, ROCM_SUPPORTED_OPS)
2617                .unwrap_or_else(|errors| {
2618                    panic!("{}", rlx_opt::format_legalize_error("rocm", &errors));
2619                });
2620            let graph = crate::precompile::precompile_cleanup(graph, options);
2621            let graph = rlx_opt::LegalizeBroadcast.run(graph);
2622            let compile_result = crate::stages::compile_graph_stages_for_backend(
2623                rlx_driver::Device::Rocm,
2624                graph,
2625                options,
2626                ROCM_SUPPORTED_OPS,
2627            );
2628            crate::stages::maybe_log_fusion(&compile_result.fusion);
2629            let graph = compile_result.lir.into_graph();
2630            let graph = match options.policy.clone() {
2631                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2632                None => graph,
2633            };
2634            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2635            Box::new(RocmExecutableWrapper {
2636                inner: RocmExecutable::compile(graph),
2637                io_manifest,
2638            })
2639        }
2640
2641        fn compile_lir(
2642            &self,
2643            lir: LirModule,
2644            options: &CompileOptions,
2645        ) -> Box<dyn ExecutableGraph> {
2646            let (graph, io_manifest) =
2647                cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2648                    rlx_rocm::unfuse::unfuse(lir.into_graph()),
2649                    options,
2650                    ROCM_SUPPORTED_OPS,
2651                    "rocm",
2652                ));
2653            Box::new(RocmExecutableWrapper {
2654                inner: RocmExecutable::compile(graph),
2655                io_manifest,
2656            })
2657        }
2658    }
2659
    /// Adapter that exposes a compiled `RocmExecutable` through the
    /// runtime's `ExecutableGraph` trait.
    struct RocmExecutableWrapper {
        /// Backend executable that the trait methods delegate to.
        inner: RocmExecutable,
        /// I/O dtype manifest; `run_typed` passes it to
        /// `declared_output_dtypes` to pick each output's dtype.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }
2664
    // SAFETY (author's rationale): same Send-claim shape as
    // CudaExecutableWrapper. RocmExecutable owns Arc<RocmContext> +
    // HipBuffer handles; the HipRuntime bundle is internally thread-safe
    // per AMD's documentation.
    // NOTE(review): soundness depends on rlx-rocm internals that are not
    // visible from this file — TODO confirm.
    unsafe impl Send for RocmExecutableWrapper {}
2669
2670    impl ExecutableGraph for RocmExecutableWrapper {
2671        fn set_param(&mut self, name: &str, data: &[f32]) {
2672            self.inner.set_param(name, data);
2673        }
2674        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2675            self.inner.run(inputs)
2676        }
2677        fn run_read_outputs(
2678            &mut self,
2679            inputs: &[(&str, &[f32])],
2680            read_indices: Option<&[usize]>,
2681        ) -> Vec<Vec<f32>> {
2682            self.inner.run_read_outputs(inputs, read_indices)
2683        }
2684        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2685            self.inner.bind_gpu_handle(name, data)
2686        }
2687        fn has_gpu_handle(&self, name: &str) -> bool {
2688            self.inner.has_gpu_handle(name)
2689        }
2690        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2691            self.inner.set_gpu_handle_feed(handle_name, output_index);
2692            true
2693        }
2694        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2695            self.inner.read_gpu_handle(name)
2696        }
2697        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2698            self.inner.run_slots(inputs)
2699        }
2700        fn arena_ptr(&self) -> *const u8 {
2701            self.inner.arena_ptr()
2702        }
2703        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2704            self.inner.set_active_extent(extent);
2705        }
2706
2707        /// Typed param upload — widens F16/BF16 host bytes to f32
2708        /// before routing through `set_param`. ROCm's arena is
2709        /// f32-uniform; the half-precision matmul tier opts in via
2710        /// the separate `set_param_half` API.
2711        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2712            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2713                self.inner.set_param_bytes(name, data);
2714                return;
2715            }
2716            if dtype == rlx_ir::DType::F32 {
2717                let n = data.len() / 4;
2718                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2719                self.inner.set_param(name, s);
2720            } else {
2721                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2722                self.inner.set_param(name, &f32_buf);
2723            }
2724        }
2725
2726        /// Typed run — widen each typed input to F32, run, then narrow
2727        /// each output back to its declared graph dtype.
2728        fn run_typed(
2729            &mut self,
2730            inputs: &[(&str, &[u8], rlx_ir::DType)],
2731        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2732            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2733            for (name, data, dt) in inputs {
2734                let v = super::widen_bytes_to_f32(data, *dt);
2735                owned.push((name.to_string(), v));
2736            }
2737            let refs: Vec<(&str, &[f32])> = owned
2738                .iter()
2739                .map(|(n, d)| (n.as_str(), d.as_slice()))
2740                .collect();
2741            let dtypes =
2742                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2743            let outs = self.inner.run(&refs);
2744            outs.into_iter()
2745                .zip(
2746                    dtypes
2747                        .into_iter()
2748                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2749                )
2750                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2751                .collect()
2752        }
2753    }
2754}
2755
2756// ── TPU Backend ─────────────────────────────────────────────────────────
2757
2758#[cfg(feature = "tpu")]
2759pub mod tpu_backend {
2760    use super::*;
2761    use rlx_tpu::TpuExecutable;
2762
    /// Unit type selecting the TPU backend; its `Backend` impl compiles
    /// graphs into a wrapped `TpuExecutable`.
    pub struct TpuBackend;
2764
    /// Ops the TPU backend lowers to HLO. Full inference parity with
    /// rlx-cuda / rlx-rocm. Composite ops (FusedSwiGLU /
    /// FusedAttentionBlock / FusedTransformerLayer / LoraMatMul / If /
    /// While) are unfused inside `rlx_tpu::unfuse::unfuse` ahead of
    /// HLO emission, so they don't appear here.
    ///
    /// NOTE(review): the "full parity" claim does not match the tables in
    /// this file — several ops present in the CUDA/ROCm lists (e.g.
    /// `LayerNorm2d`, `GroupNorm`, `ResizeNearest2x`, `ConvTranspose2d`,
    /// `Custom`, `Im2Col`, the backward ops) are absent here. Presumably
    /// legalization rewrites them; confirm.
    const TPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            // Real-INT8 path + fake-quant.
            QMatMul,
            QConv2d,
            Quantize,
            Dequantize,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            Fft,
            LogMel,
            LogMelBackward,
            // Splat: no on-chip kernel — lowered to common primitive MIR via logical_kernel.
        ]
    };
2823
2824    impl Backend for TpuBackend {
2825        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2826            TPU_SUPPORTED_OPS
2827        }
2828
2829        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2830            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2831                graph,
2832                TPU_SUPPORTED_OPS,
2833                options.kernel_dispatch,
2834            )
2835            .unwrap_or_else(|errors| {
2836                panic!("{}", rlx_opt::format_legalize_error("tpu", &errors));
2837            });
2838            // The TPU's IR-side pass pipeline (DCE, ConstFold,
2839            // FuseResidualLN, FuseMatMulBiasAct, LegalizeBroadcast,
2840            // MarkElementwiseRegions) lives inside
2841            // `TpuExecutable::compile` so the same passes run whether
2842            // a caller goes through Session or invokes the executable
2843            // directly. We only do backend-cross-cutting work here:
2844            // legalization (must precede the pipeline so we panic
2845            // early on unsupported ops) and AutoMixedPrecision.
2846            //
2847            // Default policy on TPU is `AutoMixedBf16`: BF16 is the
2848            // native compute dtype on TPU silicon and recent GPUs,
2849            // and XLA's CPU plugin handles it natively too. Callers
2850            // can opt out by passing an explicit `PrecisionPolicy`
2851            // (e.g. `AlwaysF32` for accuracy debugging or
2852            // `AlwaysF16` to match a CUDA workload's choice).
2853            use rlx_opt::pass::Pass as _;
2854            let policy = options
2855                .policy
2856                .clone()
2857                .unwrap_or(rlx_opt::PrecisionPolicy::AutoMixedBf16);
2858            let graph = rlx_opt::AutoMixedPrecision::new(policy).run(graph);
2859            let _ = options.dce;
2860            let _ = options.constant_folding;
2861            Box::new(TpuExecutableWrapper {
2862                inner: TpuExecutable::compile(graph),
2863            })
2864        }
2865    }
2866
    /// Adapter that exposes a compiled `TpuExecutable` through the
    /// runtime's `ExecutableGraph` trait.
    struct TpuExecutableWrapper {
        /// Backend executable that the trait methods delegate to.
        inner: TpuExecutable,
    }
2870
    // SAFETY (author's rationale): PJRT clients + buffers are documented
    // as thread-safe per the upstream C API. Same Send-claim shape as
    // CudaExecutableWrapper / RocmExecutableWrapper.
    // NOTE(review): soundness depends on rlx-tpu internals that are not
    // visible from this file — TODO confirm.
    unsafe impl Send for TpuExecutableWrapper {}
2875
2876    impl ExecutableGraph for TpuExecutableWrapper {
2877        fn set_param(&mut self, name: &str, data: &[f32]) {
2878            self.inner.set_param(name, data);
2879        }
2880        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2881            self.inner.run(inputs)
2882        }
2883
2884        /// Typed param upload — widens F16/BF16/etc. host bytes to
2885        /// f32 today. Once the HLO emitter speaks bf16 natively
2886        /// (which TPUs prefer over f16), the typed path will hand
2887        /// the original bytes straight through `Buffer_FromHostBuffer`.
2888        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2889            if dtype == rlx_ir::DType::F32 {
2890                let n = data.len() / 4;
2891                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2892                self.inner.set_param(name, s);
2893            } else {
2894                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2895                self.inner.set_param(name, &f32_buf);
2896            }
2897        }
2898
2899        fn run_typed(
2900            &mut self,
2901            inputs: &[(&str, &[u8], rlx_ir::DType)],
2902        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2903            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2904            for (name, data, dt) in inputs {
2905                let v = super::widen_bytes_to_f32(data, *dt);
2906                owned.push((name.to_string(), v));
2907            }
2908            let refs: Vec<(&str, &[f32])> = owned
2909                .iter()
2910                .map(|(n, d)| (n.as_str(), d.as_slice()))
2911                .collect();
2912            let dtypes = self.inner.output_dtypes();
2913            let outs = self.inner.run(&refs);
2914            outs.into_iter()
2915                .zip(
2916                    dtypes
2917                        .into_iter()
2918                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2919                )
2920                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2921                .collect()
2922        }
2923    }
2924}