Skip to main content

rlx_runtime/
backend.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backend trait — abstraction over CPU/GPU/CUDA execution.
17//!
18//! Each backend implements `Backend::compile(graph, &CompileOptions)` and
19//! returns an `ExecutableGraph`. New compile knobs go in `CompileOptions`
20//! rather than as new trait methods.
21
22use crate::CompileOptions;
23use rlx_ir::Graph;
24use rlx_ir::hir::HirModule;
25use rlx_ir::lir::LirModule;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use crate::cpu_low_precision;
30
31// ── Typed I/O helpers (shared across f32-arena backends) ────────────────
32
33/// Widen a typed byte buffer to `Vec<f32>`. Used by `set_param_typed` /
34/// `run_typed` overrides on backends whose internal arena is f32-uniform
35/// (CPU, Metal, wgpu) so callers can hand in F16/BF16 without doing the
36/// host-side cast themselves. Panics on dtypes the f32 arena can't carry.
37#[allow(dead_code)]
38pub(crate) fn widen_bytes_to_f32(data: &[u8], dtype: rlx_ir::DType) -> Vec<f32> {
39    use rlx_ir::DType;
40    match dtype {
41        DType::F32 => {
42            let n = data.len() / 4;
43            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
44            s.to_vec()
45        }
46        DType::F16 => {
47            let n = data.len() / 2;
48            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
49            s.iter().map(|h| h.to_f32()).collect()
50        }
51        DType::BF16 => {
52            let n = data.len() / 2;
53            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n) };
54            s.iter().map(|h| h.to_f32()).collect()
55        }
56        other => panic!(
57            "widen_bytes_to_f32: dtype {other:?} unsupported on f32-arena backends \
58             (only F32/F16/BF16 are accepted on the host I/O surface)"
59        ),
60    }
61}
62
63/// Narrow a `&[f32]` buffer down to the declared output dtype, returning
64/// the corresponding little-endian byte stream. Mirrors the bytes a
65/// backend that stores the native dtype would emit. Used by `run_typed`
66/// to keep the byte-level output contract identical across backends.
67#[allow(dead_code)]
68pub(crate) fn narrow_f32_to_bytes(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
69    use rlx_ir::DType;
70    match dt {
71        DType::F32 => {
72            let mut bytes = Vec::with_capacity(v.len() * 4);
73            for &x in v {
74                bytes.extend_from_slice(&x.to_le_bytes());
75            }
76            bytes
77        }
78        DType::F16 => {
79            let mut bytes = Vec::with_capacity(v.len() * 2);
80            for &x in v {
81                bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
82            }
83            bytes
84        }
85        DType::BF16 => {
86            let mut bytes = Vec::with_capacity(v.len() * 2);
87            for &x in v {
88                bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
89            }
90            bytes
91        }
92        DType::F64 => {
93            let mut bytes = Vec::with_capacity(v.len() * 8);
94            for &x in v {
95                bytes.extend_from_slice(&(x as f64).to_le_bytes());
96            }
97            bytes
98        }
99        DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
100        DType::U8 => v.iter().map(|&x| x as u8).collect(),
101        DType::I16 => {
102            let mut bytes = Vec::with_capacity(v.len() * 2);
103            for &x in v {
104                bytes.extend_from_slice(&(x as i16).to_le_bytes());
105            }
106            bytes
107        }
108        DType::I32 => {
109            let mut bytes = Vec::with_capacity(v.len() * 4);
110            for &x in v {
111                bytes.extend_from_slice(&(x as i32).to_le_bytes());
112            }
113            bytes
114        }
115        DType::U32 => {
116            let mut bytes = Vec::with_capacity(v.len() * 4);
117            for &x in v {
118                bytes.extend_from_slice(&(x as u32).to_le_bytes());
119            }
120            bytes
121        }
122        DType::I64 => {
123            let mut bytes = Vec::with_capacity(v.len() * 8);
124            for &x in v {
125                bytes.extend_from_slice(&(x as i64).to_le_bytes());
126            }
127            bytes
128        }
129        DType::Bool => v
130            .iter()
131            .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
132            .collect(),
133        DType::C64 => {
134            // Complex narrow path: real part = the f32 value, imaginary
135            // part = 0. Mirrors how the backend stores narrowed f32
136            // operands when promoted to a complex op input.
137            let mut bytes = Vec::with_capacity(v.len() * 8);
138            for &x in v {
139                bytes.extend_from_slice(&x.to_le_bytes());
140                bytes.extend_from_slice(&0.0_f32.to_le_bytes());
141            }
142            bytes
143        }
144    }
145}
146
147/// A compiled, ready-to-execute graph on a specific backend.
148pub trait ExecutableGraph: Send {
149    /// Set a named parameter (weight) buffer.
150    fn set_param(&mut self, name: &str, data: &[f32]);
151
152    /// Deep-clone this executable into a fresh `Box`. Lets
153    /// `CompiledGraph` implement `Clone` so callers (e.g. eda-mna's
154    /// `SensitivityContext`) can spin up N independent executor
155    /// copies for thread-parallel dispatch without paying the full
156    /// graph-compile cost N times. Default implementation panics;
157    /// backends that support cloning override.
158    fn clone_box(&self) -> Box<dyn ExecutableGraph> {
159        panic!("clone_box not implemented for this backend");
160    }
161
162    /// Execute the graph with named inputs. Returns output data (copies from arena).
163    fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>>;
164
165    /// Like [`Self::run`] but only read back outputs at `read_indices`.
166    /// GPU handle feeds still update for every output. Default: all outputs.
167    fn run_read_outputs(
168        &mut self,
169        inputs: &[(&str, &[f32])],
170        read_indices: Option<&[usize]>,
171    ) -> Vec<Vec<f32>> {
172        match read_indices {
173            None => self.run(inputs),
174            Some(ix) => {
175                // Backends without a native partial-read path still run the full
176                // graph; only clone the requested outputs on the host.
177                let all = self.run(inputs);
178                ix.iter().filter_map(|&i| all.get(i).cloned()).collect()
179            }
180        }
181    }
182
183    /// Execute and return raw pointers to output data in arena (zero-copy).
184    fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
185        let vecs = self.run(inputs);
186        vecs.iter().map(|v| (v.as_ptr(), v.len())).collect()
187    }
188
189    /// Fastest: inputs by slot index, returns output (offset, len) pairs.
190    /// Read output from arena via `arena_ptr().add(offset)`.
191    fn run_slots(&mut self, _inputs: &[&[f32]]) -> &[(usize, usize)] {
192        &[] // default: not supported
193    }
194
195    /// Get the raw arena buffer pointer for reading outputs after run_slots.
196    fn arena_ptr(&self) -> *const u8 {
197        std::ptr::null()
198    }
199
200    /// Hint the executor that subsequent `run` calls should process
201    /// only the first `actual` rows along the bucket axis (out of
202    /// `upper`, the extent the graph was compiled at). Backends that
203    /// support per-kernel active-extent dispatch honor this; others
204    /// ignore it and process the full compiled extent.
205    ///
206    /// Pass `None` to clear the hint. The hint is sticky — set it
207    /// before each `run` and clear it after, or maintain it across
208    /// runs at your discretion.
209    ///
210    /// Even when honored, callers must not rely on the contents of the
211    /// output past `actual` rows — that region may contain stale data
212    /// from earlier runs (kernels skip it).
213    ///
214    /// Default: no-op. See `BucketedCompileCache::run_padded` for the
215    /// canonical caller; backends opt in by overriding this method.
216    fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
217        let _ = extent;
218    }
219
220    /// TIDE merged placement mask (union across MoE layers). CPU: stats + host path.
221    fn set_moe_resident_experts(&mut self, _mask: &[bool]) {}
222
223    /// Per MoE layer placement (`masks[layer][expert]`). Preferred over merged on CPU.
224    fn set_moe_resident_experts_per_layer(&mut self, _masks: &[&[bool]]) {}
225
226    /// Capture MoE router TopK indices on the next CPU forward (TIDE refresh).
227    fn enable_moe_topk_capture(&mut self, _num_experts: usize) -> bool {
228        false
229    }
230
231    /// Take captured per-layer expert indices (one vec per MoE TopK in order).
232    fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
233        None
234    }
235
236    /// MoE GroupedMatMul residency accounting from the last forward (CPU).
237    fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
238        None
239    }
240
241    /// Bind a persistent buffer handle (KV-cache, training state, etc.).
242    /// The buffer lives across run() calls and is not in the arena.
243    /// Returns true if the backend supports persistent handles.
244    fn bind_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
245        false
246    }
247
248    /// Read a persistent buffer's current contents.
249    fn read_handle(&self, _name: &str) -> Option<Vec<f32>> {
250        None
251    }
252
253    /// GPU-resident input (MLX): upload once, reuse across runs.
254    fn bind_gpu_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
255        false
256    }
257
258    fn has_gpu_handle(&self, _name: &str) -> bool {
259        false
260    }
261
262    fn set_gpu_handle_feed(&mut self, _handle_name: &str, _output_index: usize) -> bool {
263        false
264    }
265
266    fn read_gpu_handle(&self, _name: &str) -> Option<Vec<f32>> {
267        None
268    }
269
270    /// Read one row from a row-major graph output after `run` / `run_read_outputs`.
271    /// Metal reads a single row from the arena; default returns `None` (caller falls back).
272    fn read_output_row(&self, _out_idx: usize, _row: usize, _row_inner: usize) -> Option<Vec<f32>> {
273        None
274    }
275
276    /// Run and refresh a GPU handle from `output_index`; returns that output on host.
277    fn run_feed_gpu_handle(
278        &mut self,
279        inputs: &[(&str, &[f32])],
280        _handle_name: &str,
281        _output_index: usize,
282    ) -> Option<Vec<f32>> {
283        let _ = inputs;
284        None
285    }
286
287    // ── Pipelined / async execution (Phase C) ─────────────────────────
288    //
289    // These allow callers to amortize per-run sync latency on backends
290    // where it matters (Metal: ~150 µs `wait_until_completed` per commit).
291    // CPU has no such cost, so the default impls just call `run` serially.
292
293    /// Encode + commit a forward pass without waiting for completion.
294    ///
295    /// Outputs of intermediate calls are stomped — use `run_pipelined` if
296    /// you need outputs from each individual commit. Pair with
297    /// `sync_pending` to drain.
298    ///
299    /// Default: synchronous fallback (calls `run`, discards output). CPU
300    /// uses this default since BLAS is synchronous anyway.
301    fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
302        let _ = self.run(inputs);
303    }
304
305    /// Wait for every command queued by `commit_no_wait`.
306    /// Default: no-op (synchronous backends have nothing pending).
307    fn sync_pending(&mut self) {}
308
309    /// Issue a batch of forward passes pipelined, returning per-run outputs.
310    ///
311    /// The Metal impl encodes a per-commit blit so each in-flight run's
312    /// outputs survive subsequent commits stomping the shared arena. The
313    /// CPU default is just sequential `run`s — equally correct, no perf
314    /// penalty (CPU has no GPU sync cost to amortize).
315    ///
316    /// Returns `out[run_idx][output_idx][element_idx]`.
317    fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
318        input_sets.iter().map(|inputs| self.run(inputs)).collect()
319    }
320
321    // ── Typed (non-F32) host I/O ──────────────────────────────────
322    //
323    // `set_param` and `run` are F32 by contract. The typed entry
324    // points let callers pass and receive raw bytes in any rlx-ir
325    // dtype, avoiding the f32 widen/narrow round-trip that's
326    // wasteful for F16/BF16 weights and activations.
327    //
328    // The default impls only handle F32 — any other dtype panics.
329    // Backends that support typed I/O natively (e.g. MLX via
330    // Array::from_bytes/to_bytes) override these.
331
332    /// Set a named parameter from raw bytes in the given dtype.
333    fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
334        if dtype != rlx_ir::DType::F32 {
335            panic!(
336                "backend's default set_param_typed only handles F32; \
337                    got {dtype:?}. Override on the backend for typed support."
338            );
339        }
340        if !data.len().is_multiple_of(4) {
341            panic!(
342                "set_param_typed F32: data length {} not a multiple of 4",
343                data.len()
344            );
345        }
346        // SAFETY: F32 bytes are 4-aligned by source convention; we
347        // only widen access (read &[f32] from owned &[u8]). Failure
348        // mode if a caller hands us mis-aligned bytes is undefined,
349        // hence the % 4 length check.
350        let n = data.len() / 4;
351        let f32_slice = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
352        self.set_param(name, f32_slice);
353    }
354
355    /// Run with typed inputs and typed outputs. Returns
356    /// `(bytes, dtype)` per output; the dtype is whatever the
357    /// graph's output node was declared as.
358    fn run_typed(
359        &mut self,
360        inputs: &[(&str, &[u8], rlx_ir::DType)],
361    ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
362        // Default impl: convert each typed input to f32 (F32-only),
363        // run, then re-emit outputs as F32 bytes.
364        let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
365        for (name, data, dt) in inputs {
366            if *dt != rlx_ir::DType::F32 {
367                panic!(
368                    "backend's default run_typed only handles F32 inputs; \
369                        got {dt:?} for input '{name}'"
370                );
371            }
372            if data.len() % 4 != 0 {
373                panic!(
374                    "run_typed F32 input '{name}': len {} not multiple of 4",
375                    data.len()
376                );
377            }
378            let n = data.len() / 4;
379            let v: Vec<f32> =
380                unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }.to_vec();
381            owned.push((name.to_string(), v));
382        }
383        let refs: Vec<(&str, &[f32])> = owned
384            .iter()
385            .map(|(n, d)| (n.as_str(), d.as_slice()))
386            .collect();
387        let outs = self.run(&refs);
388        outs.into_iter()
389            .map(|v| {
390                let bytes =
391                    unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 4) }
392                        .to_vec();
393                (bytes, rlx_ir::DType::F32)
394            })
395            .collect()
396    }
397}
398
399/// Backend implementation trait.
400///
401/// Single compile entry point. New compile-time knobs are added to
402/// `CompileOptions`, not as new trait methods.
403///
404/// `Send + Sync` because backends are stateless factories — multiple
405/// threads can call `compile` concurrently. The returned
406/// `Box<dyn ExecutableGraph>` is `Send` (moveable to a worker thread)
407/// but **not** `Sync` (`run`/`run_slots` take `&mut self`).
408pub trait Backend: Send + Sync {
409    /// Compile a graph for this backend with the given options.
410    fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph>;
411
412    /// Compile pre-optimized LIR (HIR → MIR → LIR pipeline output).
413    /// Default re-enters [`Self::compile`] — backends should override
414    /// when they can reuse the embedded buffer plan.
415    fn compile_lir(&self, lir: LirModule, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
416        self.compile(lir.into_graph(), options)
417    }
418
419    /// HIR-first compile: lower blocks, run fusion pipeline, emit executable.
420    fn compile_hir(
421        &self,
422        hir: HirModule,
423        device: rlx_driver::Device,
424        options: &CompileOptions,
425    ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
426        let result = crate::stages::compile_hir_stages(device, hir, options)?;
427        crate::stages::maybe_log_fusion(&result.fusion);
428        Ok(self.compile_lir(result.lir, options))
429    }
430
431    /// [`GraphModule`] compile — unified HIR/MIR/LIR entry.
432    fn compile_module(
433        &self,
434        module: rlx_ir::GraphModule,
435        device: rlx_driver::Device,
436        options: &CompileOptions,
437    ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
438        let result = crate::stages::compile_module_stages(device, module, options)?;
439        crate::stages::maybe_log_fusion(&result.fusion);
440        Ok(self.compile_lir(result.lir, options))
441    }
442
443    /// PLAN L4: declare which `OpKind`s this backend can lower.
444    /// Default: empty slice = "no claim made — accept everything"
445    /// (preserves existing behavior; backends opt in by overriding).
446    /// When non-empty, the `LegalizeForBackend` pass will refuse to
447    /// compile a graph that contains an op outside this set, instead
448    /// of silently falling through to slower / wrong dispatch.
449    fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
450        &[]
451    }
452}
453
454/// Prepare a fused MIR graph from LIR for backend executable construction.
455/// Skips the fusion pipeline — LIR must come from `compile_*_stages`.
456#[allow(dead_code)]
457fn prepare_fused_graph(
458    graph: Graph,
459    options: &CompileOptions,
460    supported_ops: &[rlx_ir::OpKind],
461    backend_name: &str,
462) -> Graph {
463    let (mut graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
464        graph,
465        backend_name,
466        supported_ops,
467        options.kernel_dispatch,
468    );
469    rlx_opt::maybe_log_dispatch_report(&report);
470    if !report.compile_ready {
471        panic!(
472            "{}\n{}",
473            rlx_opt::format_legalize_error(backend_name, &report.still_unsupported),
474            rlx_opt::format_dispatch_report(&report)
475        );
476    }
477    graph = crate::precompile::post_fusion_cleanup(graph, options);
478    if let Some(p) = options.policy.clone() {
479        use rlx_opt::pass::Pass as _;
480        graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
481    }
482    graph
483}
484
485#[allow(dead_code)]
486fn declared_output_dtypes(
487    manifest: &cpu_low_precision::IoDtypeManifest,
488    exec_dtypes: Vec<rlx_ir::DType>,
489) -> Vec<rlx_ir::DType> {
490    exec_dtypes
491        .into_iter()
492        .enumerate()
493        .map(|(i, exec)| manifest.output_dtype(i, exec))
494        .collect()
495}
496
497// ── Convenience helpers preserved from older API ──────────────────────
498//
499// These let existing call sites keep working unchanged while the new
500// trait is the canonical one. We provide free functions rather than
501// trait methods so adding them doesn't grow the trait surface.
502
503/// Compile at default options (F32, no policy).
504pub fn compile(backend: &dyn Backend, graph: Graph) -> Box<dyn ExecutableGraph> {
505    backend.compile(graph, &CompileOptions::default())
506}
507
508/// Compile HIR through the fusion-first pipeline.
509pub fn compile_hir(
510    backend: &dyn Backend,
511    hir: HirModule,
512    device: rlx_driver::Device,
513    options: &CompileOptions,
514) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
515    backend.compile_hir(hir, device, options)
516}
517
518/// Compile a [`GraphModule`] through the fusion-first pipeline.
519pub fn compile_module(
520    backend: &dyn Backend,
521    module: rlx_ir::GraphModule,
522    device: rlx_driver::Device,
523    options: &CompileOptions,
524) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
525    backend.compile_module(module, device, options)
526}
527
528/// Compile at a specific precision (default policy = none).
529pub fn compile_with_precision(
530    backend: &dyn Backend,
531    graph: Graph,
532    precision: crate::Precision,
533) -> Box<dyn ExecutableGraph> {
534    backend.compile(graph, &CompileOptions::new().precision(precision))
535}
536
537/// Helper retained for backward compatibility — applies the precision
538/// rewrite at the runtime layer if backends don't override their
539/// pipeline placement. Modern code: pass the policy via CompileOptions
540/// and let the backend handle ordering.
541fn _legacy_apply_policy(graph: Graph, policy: Option<rlx_opt::PrecisionPolicy>) -> Graph {
542    use rlx_opt::pass::Pass as _;
543    match policy {
544        Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
545        None => graph,
546    }
547}
548
549// ── CPU Backend ─────────────────────────────────────────────────────────
550
551#[cfg(feature = "cpu")]
552pub mod cpu_backend {
553    use super::*;
554    use rlx_cpu::{arena::Arena, thunk};
555    use rlx_ir::{DType, NodeId, Op};
556    use rlx_opt::memory::{self, MemoryPlan};
557    // Arena typed read/write helpers live in `crate::arena` so every
558    // backend (CPU, Metal, future CUDA/wgpu/WASM) shares one implementation.
559    use rlx_driver::arena::{read_typed_to_f32, write_typed_from_f32};
560
    /// CPU backend. A stateless unit struct whose executables run
    /// `rlx_cpu` thunks over an arena built from a memory plan
    /// (see `build_cpu_executable`).
    pub struct CpuBackend;
562
    /// PLAN L4: ops the CPU backend can lower today. This list is returned
    /// from `CpuBackend::supported_ops` and passed to the fusion stages. It
    /// is also the legalize whitelist in `CpuBackend::compile`. It includes
    /// DotGeneral (lowered via the `LowerDotGeneral` pass) and
    /// ElementwiseRegion (lowered natively per L2), plus the groups
    /// annotated inline below.
    ///
    /// Excludes FusedTransformerLayer / If / While — those have IR variants
    /// but no native CPU lowering (see `compile_thunks` arm absence +
    /// `subgraph.rs` "If/While executor wiring is pending" note).
    /// `CpuBackend::compile` rewrites If/While into primitives with
    /// `rlx_opt::LowerControlFlow` *before* legalizing, so graphs that use
    /// them still compile through that entry point.
    ///
    /// NOTE(review): the order of entries is visible to `supported_ops()`
    /// callers; append rather than reorder unless no caller depends on it.
    const CPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            Scan,
            ScanBackward,
            ScanBackwardXs,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            BatchNormInference,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Im2Col,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            // Backward ops emitted by `rlx_opt::autodiff::grad_with_loss`.
            // Their thunks live in rlx-cpu/src/thunk.rs alongside the
            // forward kernels; without these entries the legalize step
            // below would reject any compiled gradient graph.
            ReluBackward,
            ActivationBackward,
            FakeQuantize,
            FakeQuantizeBackward,
            MaxPool2dBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            BatchNormInferenceBackwardInput,
            BatchNormInferenceBackwardGamma,
            BatchNormInferenceBackwardBeta,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            // 3D Gaussian splat CPU reference render/backward (requires `rlx-cpu/splat`).
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            // User-registered custom ops dispatched through
            // `rlx_cpu::op_registry`. Lowering panics with a clear
            // message if the named CPU kernel isn't registered.
            Custom,
            // User-defined sub-graph with optional override AD rules
            // (JAX-shaped custom_vjp / custom_jvp). Body is a regular
            // Graph compiled recursively in compile_thunks.
            CustomFn,
            // FFT primitive (1D last-axis, 2N real-block layout, f64
            // power-of-2 sizes). Other backends panic at lowering;
            // pin FFT-containing graphs to Device::Cpu for now.
            Fft,
            FftButterflyStage,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            // C64 Wirtinger AD surface. ComplexNormSq is the canonical
            // real-valued loss for complex inputs; Conjugate is emitted
            // by the new Wirtinger VJP rules for BinaryOp::Mul/Div on
            // C64. Both have CPU thunks in rlx-cpu.
            ComplexNormSq,
            ComplexNormSqBackward,
            Conjugate,
        ]
    };
681
682    impl Backend for CpuBackend {
683        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
684            CPU_SUPPORTED_OPS
685        }
686
687        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
688            use rlx_opt::pass::Pass as _;
689            // Lower Op::If / Op::While to primitives BEFORE legalize
690            // so the supported-op check doesn't reject them — the CPU
691            // backend has no native sub-graph executor; this rewrite
692            // makes If/While invisible to the rest of the pipeline.
693            // No-op when neither op is in the graph.
694            let graph = rlx_opt::LowerControlFlow.run(graph);
695            // PLAN L4: legalize against the backend's claimed op set
696            // BEFORE running fusion (so the diagnostic points at the
697            // user's IR, not at a fused-away node).
698            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CPU_SUPPORTED_OPS) {
699                panic!("{}", rlx_opt::format_legalize_error("cpu", &errors));
700            }
701            let policy = options.policy.clone();
702            let _precision = options.precision;
703            let cfg = rlx_cpu::config::RuntimeConfig::global();
704
705            let graph = crate::precompile::precompile_cleanup(graph, options);
706
707            // Run fusion pipeline (HIR/MIR/LIR ideology — fusion is first-class).
708            let mut compile_opts = options.clone();
709            compile_opts.arena_alignment = cfg.arena_alignment;
710            let compile_result = crate::stages::compile_graph_stages_for_backend(
711                rlx_driver::Device::Cpu,
712                graph,
713                &compile_opts,
714                CPU_SUPPORTED_OPS,
715            );
716            crate::stages::maybe_log_fusion(&compile_result.fusion);
717            let fused = compile_result.lir.into_graph();
718
719            // Apply precision policy AFTER fusion — Cast nodes don't disrupt
720            // the now-flattened fused ops.
721            let fused = match policy {
722                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(fused),
723                None => fused,
724            };
725
726            let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&fused);
727            let exec_graph = if cpu_low_precision::needs_f32_exec(&fused) {
728                cpu_low_precision::promote_to_f32(fused)
729            } else {
730                fused
731            };
732
733            // Re-plan after precision rewrites (may change dtypes / sizes).
734            let plan = memory::plan_memory_aligned(&exec_graph, cfg.arena_alignment);
735            if cfg.verbose >= 1 {
736                eprintln!(
737                    "[rlx] arena: {} bytes, {} buffers, alignment: {}",
738                    plan.arena_size,
739                    plan.assignments.len(),
740                    cfg.arena_alignment
741                );
742            }
743            Box::new(build_cpu_executable(exec_graph, plan, io_manifest))
744        }
745
746        fn compile_lir(
747            &self,
748            lir: LirModule,
749            options: &CompileOptions,
750        ) -> Box<dyn ExecutableGraph> {
751            let alignment = lir.buffers.alignment.max(options.arena_alignment);
752            let mut graph = lir.into_graph();
753            {
754                use rlx_opt::pass::Pass as _;
755                graph = rlx_opt::LegalizeBroadcast.run(graph);
756            }
757            if let Some(p) = options.policy.clone() {
758                use rlx_opt::pass::Pass;
759                graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
760            }
761            let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&graph);
762            let promote = cpu_low_precision::needs_f32_exec(&graph);
763            let exec_graph = if promote {
764                cpu_low_precision::promote_to_f32(graph)
765            } else {
766                graph
767            };
768            // LegalizeBroadcast may insert Expand nodes — must replan; the
769            // embedded LIR buffer map is from before legalization.
770            let plan = memory::plan_memory_aligned(&exec_graph, alignment);
771            let cfg = rlx_cpu::config::RuntimeConfig::global();
772            if cfg.verbose >= 1 {
773                eprintln!(
774                    "[rlx] compile_lir: arena {} bytes ({} buffers, alignment {})",
775                    plan.arena_size,
776                    plan.assignments.len(),
777                    alignment,
778                );
779            }
780            Box::new(build_cpu_executable(exec_graph, plan, io_manifest))
781        }
782    }
783
784    fn build_cpu_executable(
785        graph: Graph,
786        plan: MemoryPlan,
787        io_manifest: cpu_low_precision::IoDtypeManifest,
788    ) -> CpuExecutable {
789        let mut arena = Arena::from_plan(plan);
790        let mut input_ids = HashMap::new();
791        let mut param_ids = HashMap::new();
792        let mut node_dtypes: HashMap<NodeId, DType> = HashMap::new();
793        for node in graph.nodes() {
794            node_dtypes.insert(node.id, node.shape.dtype());
795            match &node.op {
796                Op::Input { name } => {
797                    input_ids.insert(name.clone(), node.id);
798                }
799                Op::Param { name } => {
800                    param_ids.insert(name.clone(), node.id);
801                }
802                _ => {}
803            }
804        }
805
806        let schedule = thunk::compile_thunks(&graph, &arena);
807
808        let mut input_slots = Vec::new();
809        for node in graph.nodes() {
810            if let Op::Input { name } = &node.op {
811                let off = arena.byte_offset(node.id);
812                let len = node.shape.num_elements().unwrap_or(0);
813                input_slots.push((name.clone(), off, len, node.shape.dtype()));
814            }
815        }
816
817        let output_slots: Vec<(usize, usize)> = graph
818            .outputs
819            .iter()
820            .map(|&id| {
821                let off = arena.byte_offset(id);
822                let len = graph.node(id).shape.num_elements().unwrap_or(0);
823                (off, len)
824            })
825            .collect();
826
827        for node in graph.nodes() {
828            if let Op::Constant { data } = &node.op
829                && arena.has_buffer(node.id)
830                && !data.is_empty()
831            {
832                match node.shape.dtype() {
833                    DType::F64 | DType::F16 | DType::BF16 => {
834                        let off = arena.byte_offset(node.id);
835                        let buf = arena.raw_buf_mut();
836                        let n = buf.len().saturating_sub(off).min(data.len());
837                        buf[off..off + n].copy_from_slice(&data[..n]);
838                    }
839                    _ => {
840                        let buf = arena.slice_mut(node.id);
841                        let n_floats = data.len() / 4;
842                        let n = buf.len().min(n_floats);
843                        for i in 0..n {
844                            let bytes = [
845                                data[i * 4],
846                                data[i * 4 + 1],
847                                data[i * 4 + 2],
848                                data[i * 4 + 3],
849                            ];
850                            buf[i] = f32::from_le_bytes(bytes);
851                        }
852                    }
853                }
854            }
855        }
856
857        CpuExecutable {
858            graph,
859            arena,
860            params: HashMap::new(),
861            typed_params: HashMap::new(),
862            input_ids,
863            param_ids,
864            node_dtypes,
865            io_manifest,
866            schedule,
867            input_slots,
868            output_slots,
869            handles: HashMap::new(),
870            active_extent: None,
871            moe_resident: None,
872            moe_resident_layers: None,
873            moe_topk_capture: None,
874        }
875    }
876
877    #[derive(Clone)]
878    struct CpuExecutable {
879        graph: Graph,
880        arena: Arena,
881        params: HashMap<String, Vec<f32>>,
882        /// Byte-backed params (`set_param_typed` / `set_param_bytes`).
883        typed_params: HashMap<String, (Vec<u8>, DType)>,
884        input_ids: HashMap<String, NodeId>,
885        param_ids: HashMap<String, NodeId>,
886        /// Per-node arena dtype. Lets set_param/run cast f32 ↔ F16/BF16
887        /// when AutoMixedPrecision has rewritten the graph.
888        node_dtypes: HashMap<NodeId, DType>,
889        /// User-facing boundary dtypes (before f32 promotion for CPU exec).
890        io_manifest: cpu_low_precision::IoDtypeManifest,
891        schedule: thunk::ThunkSchedule,
892        // Pre-resolved: ordered list of (input_name, arena_byte_offset, max_elems, dtype)
893        input_slots: Vec<(String, usize, usize, DType)>,
894        /// Output (byte_offset, num_elements). dtype is in node_dtypes.
895        output_slots: Vec<(usize, usize)>,
896        /// Persistent buffer handles (KV-cache, optimizer state, etc.).
897        /// Lives outside the arena and survives across run() calls.
898        /// On run(): if a handle's name matches a graph input, the
899        /// handle's data is used as the input.
900        handles: HashMap<String, Vec<f32>>,
901        /// Active-extent hint (`Some((actual, upper))`) for L1 bucketed
902        /// dispatch. When set AND every thunk in the schedule is in
903        /// `Thunk::safe_for_active_extent`, the executor processes only
904        /// `actual / upper` of each kernel's work. Otherwise (or when
905        /// `None`) runs at the full compiled extent. See PLAN L1.
906        active_extent: Option<(usize, usize)>,
907        moe_resident: Option<std::sync::Arc<[bool]>>,
908        moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
909        moe_topk_capture: Option<std::sync::Arc<rlx_cpu::moe_topk_capture::MoeTopkCapture>>,
910    }
911
912    unsafe impl Send for CpuExecutable {}
913
914    impl CpuExecutable {
915        /// Write a f32 input slice into the arena, casting to the node's dtype.
916        fn write_input(&mut self, id: NodeId, data: &[f32]) {
917            let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
918            let off = self.arena.byte_offset(id);
919            let buf = self.arena.raw_buf_mut();
920            let elem_size = dtype.size_bytes();
921            let max_elems = (buf.len() - off) / elem_size;
922            unsafe {
923                write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
924            }
925        }
926
927        /// Read a node's arena bytes back as Vec<f32>, casting from its dtype.
928        fn read_output(&self, id: NodeId) -> Vec<f32> {
929            let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
930            let off = self.arena.byte_offset(id);
931            let buf = self.arena.raw_buf();
932            let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
933            unsafe { read_typed_to_f32(buf.as_ptr().add(off), dtype, n_elems) }
934        }
935    }
936
937    impl ExecutableGraph for CpuExecutable {
938        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
939            Box::new(self.clone())
940        }
941        fn set_param(&mut self, name: &str, data: &[f32]) {
942            self.params.insert(name.to_string(), data.to_vec());
943            self.typed_params.remove(name);
944            // Write directly into the arena — zero per-call lookup for params.
945            // Cast f32 → arena dtype when the param has been rewritten to F16/BF16.
946            if let Some(&id) = self.param_ids.get(name)
947                && self.arena.has_buffer(id)
948            {
949                let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
950                let off = self.arena.byte_offset(id);
951                let buf = self.arena.raw_buf_mut();
952                let elem_size = dtype.size_bytes();
953                let max_elems = (buf.len() - off) / elem_size;
954                unsafe {
955                    write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
956                }
957            }
958        }
959
960        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
961            self.restore_arena_baseline();
962            // 1. Apply persistent handles first — they act like default inputs.
963            //    Explicit `inputs` passed to run() override matching handle names.
964            let handle_names: Vec<String> = self.handles.keys().cloned().collect();
965            for name in &handle_names {
966                if let Some(&id) = self.input_ids.get(name)
967                    && self.arena.has_buffer(id)
968                {
969                    let data = self.handles.get(name).cloned().unwrap_or_default();
970                    self.write_input(id, &data);
971                }
972            }
973            // 2. Explicit per-call inputs override handles.
974            for &(name, data) in inputs {
975                if let Some(&id) = self.input_ids.get(name)
976                    && self.arena.has_buffer(id)
977                {
978                    self.write_input(id, data);
979                }
980            }
981
982            // Active-extent fast-path (PLAN L1): if hinted AND every thunk
983            // in the schedule supports it, run scaled. Otherwise fall back
984            // to full-extent dispatch — preserves correctness when the
985            // schedule contains a thunk that hasn't yet been wired in.
986            let active_used = if let Some((actual, upper)) = self.active_extent {
987                thunk::execute_thunks_active(
988                    &self.schedule,
989                    self.arena.raw_buf_mut(),
990                    actual,
991                    upper,
992                )
993            } else {
994                false
995            };
996            if !active_used {
997                // Execute via pre-compiled thunks (zero per-node dispatch overhead)
998                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
999            }
1000
1001            // 3. Sync any handle whose name matches a graph OUTPUT —
1002            //    KV-cache pattern: outputs flow back into the same-named
1003            //    handle for the next iteration.
1004            for (idx, &out_id) in self.graph.outputs.iter().enumerate() {
1005                let name = format!("out{idx}");
1006                if self.handles.contains_key(&name) {
1007                    let v = self.read_output(out_id);
1008                    self.handles.insert(name, v);
1009                }
1010            }
1011
1012            self.graph
1013                .outputs
1014                .iter()
1015                .map(|&out_id| self.read_output(out_id))
1016                .collect()
1017        }
1018
1019        fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
1020            self.restore_arena_baseline();
1021            // Copy inputs by name (HashMap lookup), casting to arena dtype.
1022            for &(name, data) in inputs {
1023                if let Some(&id) = self.input_ids.get(name)
1024                    && self.arena.has_buffer(id)
1025                {
1026                    self.write_input(id, data);
1027                }
1028            }
1029            thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1030            // Note: pointers are raw arena bytes — for F16 outputs, callers
1031            // must read 2 bytes/elem, not 4. run() is the safe path for
1032            // mixed precision; run_raw() is only meaningful for F32.
1033            self.graph
1034                .outputs
1035                .iter()
1036                .map(|&out_id| {
1037                    let (ptr, len) = self.arena.raw_ptr(out_id);
1038                    (ptr as *const f32, len)
1039                })
1040                .collect()
1041        }
1042
1043        /// Fastest path: inputs by index (matching input_slots order), zero-copy output.
1044        /// No HashMap, no name matching, no Vec allocation. Casts f32 input
1045        /// to F16/BF16 if the input slot's dtype was rewritten.
1046        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1047            self.restore_arena_baseline();
1048            let buf = self.arena.raw_buf_mut();
1049            for (i, &data) in inputs.iter().enumerate() {
1050                if i < self.input_slots.len() {
1051                    let (_, off, max_len, dtype) = &self.input_slots[i];
1052                    unsafe {
1053                        write_typed_from_f32(buf.as_mut_ptr().add(*off), *dtype, data, *max_len);
1054                    }
1055                }
1056            }
1057            thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1058            &self.output_slots
1059        }
1060
1061        fn arena_ptr(&self) -> *const u8 {
1062            self.arena.raw_buf_mut_ptr()
1063        }
1064
1065        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1066            // Persistent buffer: stored separately from arena, survives run().
1067            // If the name matches a graph input, run() will use this data
1068            // as the input. If the graph also writes back to this name (via
1069            // an output binding pattern), read_handle returns the latest.
1070            self.handles.insert(name.to_string(), data.to_vec());
1071            true
1072        }
1073
1074        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1075            self.handles.get(name).cloned()
1076        }
1077
1078        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1079            self.active_extent = extent;
1080        }
1081
1082        fn set_moe_resident_experts(&mut self, mask: &[bool]) {
1083            self.moe_resident_layers = None;
1084            self.schedule.moe_resident_layers = None;
1085            self.moe_resident = Some(Arc::from(mask));
1086            self.schedule.moe_resident = self.moe_resident.clone();
1087        }
1088
1089        fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
1090            self.moe_resident = None;
1091            self.schedule.moe_resident = None;
1092            let layers: Vec<Arc<[bool]>> = masks.iter().map(|m| Arc::from(*m)).collect();
1093            let arc = Arc::new(layers);
1094            self.moe_resident_layers = Some(arc.clone());
1095            self.schedule.moe_resident_layers = Some(arc);
1096        }
1097
1098        fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
1099            let cap = rlx_cpu::moe_topk_capture::MoeTopkCapture::new(num_experts);
1100            self.moe_topk_capture = Some(cap.clone());
1101            self.schedule.moe_topk_capture = Some(cap);
1102            true
1103        }
1104
1105        fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
1106            let cap = self.moe_topk_capture.as_ref()?;
1107            let layers = cap.take_layers();
1108            if layers.is_empty() {
1109                None
1110            } else {
1111                Some(layers)
1112            }
1113        }
1114
1115        fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
1116            rlx_cpu::moe_residency::take_last_forward_stats()
1117        }
1118
1119        /// Typed param upload. F32 / F16 / BF16 go through the existing
1120        /// widen-to-f32 path (the CPU arena is historically f32 with
1121        /// optional half-precision rewrite). F64 (and any future
1122        /// non-widenable dtype) lands directly in the arena as bytes —
1123        /// the f32 path would lose precision.
1124        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1125            if matches!(dtype, DType::F64 | DType::I64 | DType::I32 | DType::U32) {
1126                self.set_param_bytes(name, data, dtype);
1127                return;
1128            }
1129            // U8 / I8 raw byte tensors: opaque storage for the GGUF
1130            // K-quant `Op::DequantMatMul` path (weights stay packed
1131            // in the arena). One arena byte = one element.
1132            if matches!(dtype, DType::U8 | DType::I8) {
1133                self.set_param_bytes(name, data, dtype);
1134                return;
1135            }
1136            if dtype == DType::F32 {
1137                let n = data.len() / 4;
1138                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1139                self.set_param(name, s);
1140            } else {
1141                let f32_buf = super::widen_bytes_to_f32(data, dtype);
1142                self.set_param(name, &f32_buf);
1143            }
1144        }
1145
1146        /// Typed run with mixed-dtype inputs/outputs.
1147        ///
1148        /// For each input: if its declared graph dtype matches the
1149        /// caller's bytes, we write directly into the arena (zero
1150        /// precision loss — F64 stays F64). For F32 with a half-precision
1151        /// arena rewrite, we widen as before. F16/BF16 callers go
1152        /// through the existing widen path.
1153        ///
1154        /// Outputs are read straight from the arena in the graph node's
1155        /// declared dtype — F64 outputs come back as 8 bytes/element,
1156        /// F32 as 4, etc.
1157        fn run_typed(
1158            &mut self,
1159            inputs: &[(&str, &[u8], rlx_ir::DType)],
1160        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1161            // Decide: are *all* inputs F64? If so, use the direct-byte
1162            // path for everything and skip the f32 widening machinery
1163            // entirely. Mixed dtype graphs (F32 + F64) take the
1164            // per-input dispatch route below.
1165            let all_f64 = !inputs.is_empty() && inputs.iter().all(|(_, _, dt)| *dt == DType::F64);
1166
1167            if all_f64 {
1168                for (name, data, _) in inputs {
1169                    if let Some(&id) = self.input_ids.get(*name) {
1170                        if !self.arena.has_buffer(id) {
1171                            continue;
1172                        }
1173                        let off = self.arena.byte_offset(id);
1174                        let buf = self.arena.raw_buf_mut();
1175                        let n = data.len();
1176                        debug_assert!(
1177                            off + n <= buf.len(),
1178                            "run_typed: input '{name}' overflows arena slot"
1179                        );
1180                        buf[off..off + n].copy_from_slice(data);
1181                    }
1182                }
1183                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1184            } else {
1185                // Mixed-dtype path: dtypes that survive untouched
1186                // through the f32-aliased arena (F64, I32, I64, U32)
1187                // go in as bytes; F32 and the half-precision family
1188                // route through widen-to-f32 + run.
1189                let mut f32_owned: Vec<(String, Vec<f32>)> = Vec::new();
1190                for (name, data, dt) in inputs {
1191                    let direct = matches!(
1192                        *dt,
1193                        DType::F64 | DType::I32 | DType::I64 | DType::U32 | DType::C64
1194                    );
1195                    if direct {
1196                        if let Some(&id) = self.input_ids.get(*name) {
1197                            if !self.arena.has_buffer(id) {
1198                                continue;
1199                            }
1200                            let off = self.arena.byte_offset(id);
1201                            let buf = self.arena.raw_buf_mut();
1202                            buf[off..off + data.len()].copy_from_slice(data);
1203                        }
1204                    } else {
1205                        let v = super::widen_bytes_to_f32(data, *dt);
1206                        f32_owned.push((name.to_string(), v));
1207                    }
1208                }
1209                for (name, data) in &f32_owned {
1210                    if let Some(&id) = self.input_ids.get(name.as_str()) {
1211                        if self.arena.has_buffer(id) {
1212                            self.write_input(id, data);
1213                        }
1214                    }
1215                }
1216                let active_used = if let Some((actual, upper)) = self.active_extent {
1217                    thunk::execute_thunks_active(
1218                        &self.schedule,
1219                        self.arena.raw_buf_mut(),
1220                        actual,
1221                        upper,
1222                    )
1223                } else {
1224                    false
1225                };
1226                if !active_used {
1227                    thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1228                }
1229            }
1230
1231            // Read outputs in declared boundary dtypes.
1232            self.graph
1233                .outputs
1234                .iter()
1235                .enumerate()
1236                .map(|(idx, &id)| {
1237                    let exec_dtype = self.graph.node(id).shape.dtype();
1238                    let declared = self.io_manifest.output_dtype(idx, exec_dtype);
1239                    if matches!(
1240                        exec_dtype,
1241                        DType::F64
1242                            | DType::F16
1243                            | DType::BF16
1244                            | DType::I32
1245                            | DType::I64
1246                            | DType::U32
1247                            | DType::C64
1248                    ) {
1249                        let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
1250                        let n_bytes = n_elems * exec_dtype.size_bytes();
1251                        let off = self.arena.byte_offset(id);
1252                        let bytes = self.arena.raw_buf()[off..off + n_bytes].to_vec();
1253                        return (bytes, declared);
1254                    }
1255                    let f32_vals = self.read_output(id);
1256                    if declared != exec_dtype {
1257                        return (super::narrow_f32_to_bytes(&f32_vals, declared), declared);
1258                    }
1259                    let bytes = f32_vals.iter().flat_map(|v| v.to_le_bytes()).collect();
1260                    (bytes, declared)
1261                })
1262                .collect()
1263        }
1264    }
1265
1266    impl CpuExecutable {
1267        /// Clear ephemeral arena slots, then restore compile-time constants
1268        /// and cached params. Intermediate buffers are reused across `run()`
1269        /// calls; without this reset, a second execution can read stale data
1270        /// from the previous pass.
1271        fn restore_arena_baseline(&mut self) {
1272            self.arena.raw_buf_mut().fill(0);
1273            let constants: Vec<(NodeId, DType, Vec<u8>)> = self
1274                .graph
1275                .nodes()
1276                .iter()
1277                .filter_map(|node| {
1278                    if let Op::Constant { data } = &node.op
1279                        && self.arena.has_buffer(node.id)
1280                        && !data.is_empty()
1281                    {
1282                        Some((node.id, node.shape.dtype(), data.clone()))
1283                    } else {
1284                        None
1285                    }
1286                })
1287                .collect();
1288            for (id, dtype, data) in constants {
1289                self.write_constant_to_arena(id, dtype, &data);
1290            }
1291            let params = self.params.clone();
1292            for (name, data) in params {
1293                if let Some(&id) = self.param_ids.get(&name)
1294                    && self.arena.has_buffer(id)
1295                {
1296                    let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
1297                    let off = self.arena.byte_offset(id);
1298                    let buf = self.arena.raw_buf_mut();
1299                    let elem_size = dtype.size_bytes();
1300                    let max_elems = (buf.len() - off) / elem_size;
1301                    unsafe {
1302                        write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, &data, max_elems);
1303                    }
1304                }
1305            }
1306            let typed = self.typed_params.clone();
1307            for (name, (data, dtype)) in typed {
1308                self.write_param_bytes_to_arena(&name, &data);
1309                let _ = dtype;
1310            }
1311        }
1312
1313        fn write_constant_to_arena(&mut self, id: NodeId, dtype: DType, data: &[u8]) {
1314            match dtype {
1315                DType::F64 | DType::F16 | DType::BF16 | DType::U8 | DType::I8 => {
1316                    let off = self.arena.byte_offset(id);
1317                    let buf = self.arena.raw_buf_mut();
1318                    let n = buf.len().saturating_sub(off).min(data.len());
1319                    buf[off..off + n].copy_from_slice(&data[..n]);
1320                }
1321                _ => {
1322                    let buf = self.arena.slice_mut(id);
1323                    let n_floats = data.len() / 4;
1324                    let n = buf.len().min(n_floats);
1325                    for i in 0..n {
1326                        let bytes = [
1327                            data[i * 4],
1328                            data[i * 4 + 1],
1329                            data[i * 4 + 2],
1330                            data[i * 4 + 3],
1331                        ];
1332                        buf[i] = f32::from_le_bytes(bytes);
1333                    }
1334                }
1335            }
1336        }
1337
1338        /// Direct-byte param upload — copies caller's bytes into the
1339        /// arena slot for the named param without any dtype conversion.
1340        /// Used by `set_param_typed` for dtypes that f32-widening would
1341        /// corrupt (F64). Caller is responsible for matching the param's
1342        /// declared graph dtype.
1343        fn set_param_bytes(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1344            self.typed_params
1345                .insert(name.to_string(), (data.to_vec(), dtype));
1346            self.params.remove(name);
1347            self.write_param_bytes_to_arena(name, data);
1348        }
1349
1350        fn write_param_bytes_to_arena(&mut self, name: &str, data: &[u8]) {
1351            if let Some(&id) = self.param_ids.get(name)
1352                && self.arena.has_buffer(id)
1353            {
1354                let off = self.arena.byte_offset(id);
1355                let buf = self.arena.raw_buf_mut();
1356                debug_assert!(
1357                    off + data.len() <= buf.len(),
1358                    "set_param_bytes: '{name}' would overflow arena slot"
1359                );
1360                buf[off..off + data.len()].copy_from_slice(data);
1361            }
1362        }
1363    }
1364}
1365
1366// ── Metal Backend ───────────────────────────────────────────────────────
1367
1368// ── wgpu Backend ────────────────────────────────────────────────────────
1369
1370#[cfg(feature = "gpu")]
1371pub mod wgpu_backend {
1372    use super::*;
1373    use rlx_ir::OpKind;
1374    use rlx_wgpu::backend::WgpuExecutable;
1375
1376    pub struct WgpuBackend;
1377
1378    /// PLAN L4: ops the wgpu backend can lower today. The fused
1379    /// macro-kernels (FAB, FTL, FusedSwiGLU) get decomposed by
1380    /// `crate::unfuse::unfuse` upstream — they're listed here too so
1381    /// graphs that already contain them legalize cleanly. Conv1d/3d
1382    /// and Pool1d/3d are deferred (Conv2d only).
1383    const WGPU_SUPPORTED_OPS: &[OpKind] = &[
1384        OpKind::Input,
1385        OpKind::Param,
1386        OpKind::Constant,
1387        OpKind::Activation,
1388        OpKind::Cast,
1389        OpKind::StopGradient,
1390        OpKind::Binary,
1391        OpKind::Compare,
1392        OpKind::Where,
1393        OpKind::ElementwiseRegion,
1394        OpKind::TransformRegion,
1395        OpKind::BatchElementwiseRegion,
1396        OpKind::MatMul,
1397        OpKind::DotGeneral,
1398        OpKind::LayerNorm,
1399        OpKind::RmsNorm,
1400        OpKind::Attention,
1401        OpKind::AttentionBackward,
1402        OpKind::RmsNormBackwardInput,
1403        OpKind::RmsNormBackwardGamma,
1404        OpKind::RmsNormBackwardBeta,
1405        // LayerNorm backward family:
1406        //   * Input  — single workgroup-per-row fused kernel.
1407        //   * Gamma  — two-dispatch (partial + reduce) that uses a tail
1408        //              scratch zone in the arena to hold per-chunk
1409        //              partial sums; the reduce kernel sums them.
1410        // Both beat the autodiff-decomposed primitive chain.
1411        OpKind::LayerNormBackwardInput,
1412        OpKind::LayerNormBackwardGamma,
1413        OpKind::RopeBackward,
1414        OpKind::CumsumBackward,
1415        OpKind::GatherBackward,
1416        OpKind::Rope,
1417        OpKind::Reshape,
1418        OpKind::Transpose,
1419        OpKind::Narrow,
1420        OpKind::Concat,
1421        OpKind::Expand,
1422        OpKind::Gather,
1423        OpKind::Reduce,
1424        OpKind::Softmax,
1425        OpKind::Cumsum,
1426        OpKind::TopK,
1427        OpKind::Sample,
1428        OpKind::Conv,
1429        OpKind::Im2Col,
1430        OpKind::Pool,
1431        OpKind::GroupedMatMul,
1432        OpKind::DequantGroupedMatMul,
1433        OpKind::DequantMoEWeights,
1434        OpKind::ScatterAdd,
1435        OpKind::SelectiveScan,
1436        OpKind::DequantMatMul,
1437        OpKind::FusedMatMulBiasAct,
1438        OpKind::FusedResidualLN,
1439        OpKind::FusedResidualRmsNorm,
1440        OpKind::FusedSwiGLU,
1441        OpKind::FusedAttentionBlock,
1442        OpKind::FusedTransformerLayer,
1443        // Native FFT (WGSL radix-2): f32 only, power-of-2 N ≤ 1024.
1444        // Anything outside that envelope panics at lowering with a
1445        // "pin to Device::Cpu" hint. No host fallback — WGPU has no
1446        // unified memory, so silent CPU round-trip would be a hidden
1447        // performance cliff.
1448        OpKind::Fft,
1449        OpKind::LogMel,
1450        OpKind::LogMelBackward,
1451        OpKind::WelchPeaks,
1452        // 3D Gaussian splat: native Metal / CPU reference per backend.
1453        OpKind::GaussianSplatRender,
1454        OpKind::GaussianSplatRenderBackward,
1455        OpKind::GaussianSplatPrepare,
1456        OpKind::GaussianSplatRasterize,
1457        OpKind::Custom,
1458        // LoRA, If, While: not yet wired in wgpu — fail loudly.
1459    ];
1460
1461    impl Backend for WgpuBackend {
1462        fn supported_ops(&self) -> &'static [OpKind] {
1463            WGPU_SUPPORTED_OPS
1464        }
1465
1466        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1467            use rlx_opt::pass::Pass as _;
1468            let graph = rlx_opt::LowerControlFlow.run(graph);
1469            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, WGPU_SUPPORTED_OPS)
1470                .unwrap_or_else(|errors| {
1471                    panic!("{}", rlx_opt::format_legalize_error("wgpu", &errors));
1472                });
1473            let graph = crate::precompile::precompile_cleanup(graph, options);
1474            // Materialize mid-axis broadcasts before MarkElementwiseRegions:
1475            // wgpu Binary/region kernels only handle trailing/scalar broadcast
1476            // via modulus; EEG patch embed uses [1,C,1,D] + [1,C,P,D].
1477            let graph = rlx_opt::LegalizeBroadcast.run(graph);
1478            // ORDER MATTERS: targeted-pattern fusions run BEFORE the
1479            // catch-all `MarkElementwiseRegions`. Otherwise the region
1480            // pass swallows the Add / Activation nodes into chains and
1481            // FuseMatMulBiasAct / FuseResidualLN fail to match the
1482            // narrower patterns they look for. (Metal pipeline at line
1483            // ~377 already orders these correctly; wgpu was inverted
1484            // and silently shipped 13 unfused LayerNorms per BERT
1485            // forward where 12 should have been FusedResidualLN.)
1486            let compile_result = crate::stages::compile_graph_stages_for_backend(
1487                rlx_driver::Device::Gpu,
1488                graph,
1489                options,
1490                WGPU_SUPPORTED_OPS,
1491            );
1492            crate::stages::maybe_log_fusion(&compile_result.fusion);
1493            let graph = compile_result.lir.into_graph();
1494            let graph = match options.policy.clone() {
1495                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
1496                None => graph,
1497            };
1498            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1499            Box::new(WgpuExecutableWrapper {
1500                inner: WgpuExecutable::compile(graph),
1501                io_manifest,
1502            })
1503        }
1504
1505        fn compile_lir(
1506            &self,
1507            lir: LirModule,
1508            options: &CompileOptions,
1509        ) -> Box<dyn ExecutableGraph> {
1510            use rlx_opt::pass::Pass as _;
1511            // LIR may already contain fused ElementwiseRegions; legalize
1512            // broadcasts on the unfused graph shape before backend prep.
1513            let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
1514            let graph = prepare_fused_graph(graph, options, WGPU_SUPPORTED_OPS, "wgpu");
1515            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1516            Box::new(WgpuExecutableWrapper {
1517                inner: WgpuExecutable::compile(graph),
1518                io_manifest,
1519            })
1520        }
1521    }
1522
    /// Type-erased executable handed out by `WgpuBackend`: the compiled
    /// `WgpuExecutable` plus the I/O dtype manifest that
    /// `prepare_f32_exec_graph` returned for its graph.
    struct WgpuExecutableWrapper {
        /// Compiled wgpu program; its arena is f32-uniform.
        inner: WgpuExecutable,
        /// Consulted by `run_typed` (via `declared_output_dtypes`) to choose
        /// each output's dtype when narrowing f32 results.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // NOTE(review): no SAFETY justification is recorded for this impl. It
    // asserts that moving the wrapper (and every wgpu object held inside
    // `WgpuExecutable`) to another thread is sound — confirm against
    // `WgpuExecutable`'s fields before relying on it.
    unsafe impl Send for WgpuExecutableWrapper {}
1529
1530    impl ExecutableGraph for WgpuExecutableWrapper {
1531        fn set_param(&mut self, name: &str, data: &[f32]) {
1532            self.inner.set_param(name, data);
1533        }
1534        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1535            self.inner.run(inputs)
1536        }
1537        fn run_read_outputs(
1538            &mut self,
1539            inputs: &[(&str, &[f32])],
1540            read_indices: Option<&[usize]>,
1541        ) -> Vec<Vec<f32>> {
1542            self.inner.run_read_outputs(inputs, read_indices)
1543        }
1544        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
1545            self.inner.bind_gpu_handle(name, data)
1546        }
1547        fn has_gpu_handle(&self, name: &str) -> bool {
1548            self.inner.has_gpu_handle(name)
1549        }
1550        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
1551            self.inner.set_gpu_handle_feed(handle_name, output_index);
1552            true
1553        }
1554        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
1555            self.inner.read_gpu_handle(name)
1556        }
1557        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1558            self.inner.set_active_extent(extent);
1559        }
1560
1561        /// Typed param upload: widens F16/BF16 to F32 at the host boundary,
1562        /// since the wgpu arena is f32-uniform.
1563        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1564            match dtype {
1565                rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
1566                    self.inner.set_param_bytes(name, data);
1567                }
1568                rlx_ir::DType::F32 => {
1569                    let n = data.len() / 4;
1570                    let f32_slice =
1571                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1572                    self.inner.set_param(name, f32_slice);
1573                }
1574                rlx_ir::DType::F16 => {
1575                    let n = data.len() / 2;
1576                    let f16_slice =
1577                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
1578                    let f32: Vec<f32> = f16_slice.iter().map(|h| h.to_f32()).collect();
1579                    self.inner.set_param(name, &f32);
1580                }
1581                rlx_ir::DType::BF16 => {
1582                    let n = data.len() / 2;
1583                    let bf16_slice = unsafe {
1584                        std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1585                    };
1586                    let f32: Vec<f32> = bf16_slice.iter().map(|h| h.to_f32()).collect();
1587                    self.inner.set_param(name, &f32);
1588                }
1589                other => panic!(
1590                    "rlx-wgpu set_param_typed: dtype {other:?} unsupported \
1591                                 (F32, F16, BF16 only — wgpu arena is f32-uniform)"
1592                ),
1593            }
1594        }
1595
1596        /// Typed run: widen each typed input to F32, run, then narrow each
1597        /// output back to its declared dtype.
1598        fn run_typed(
1599            &mut self,
1600            inputs: &[(&str, &[u8], rlx_ir::DType)],
1601        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1602            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1603            for (name, data, dt) in inputs {
1604                let v: Vec<f32> = match *dt {
1605                    rlx_ir::DType::F32 => {
1606                        let n = data.len() / 4;
1607                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }
1608                            .to_vec()
1609                    }
1610                    rlx_ir::DType::F16 => {
1611                        let n = data.len() / 2;
1612                        let s = unsafe {
1613                            std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n)
1614                        };
1615                        s.iter().map(|h| h.to_f32()).collect()
1616                    }
1617                    rlx_ir::DType::BF16 => {
1618                        let n = data.len() / 2;
1619                        let s = unsafe {
1620                            std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1621                        };
1622                        s.iter().map(|h| h.to_f32()).collect()
1623                    }
1624                    other => {
1625                        panic!("rlx-wgpu run_typed: input '{name}' dtype {other:?} unsupported")
1626                    }
1627                };
1628                owned.push((name.to_string(), v));
1629            }
1630            let refs: Vec<(&str, &[f32])> = owned
1631                .iter()
1632                .map(|(n, d)| (n.as_str(), d.as_slice()))
1633                .collect();
1634            let dtypes =
1635                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
1636            let outs = self.inner.run(&refs);
1637            outs.into_iter()
1638                .zip(
1639                    dtypes
1640                        .into_iter()
1641                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
1642                )
1643                .map(|(v, dt)| (narrow_to_dtype(&v, dt), dt))
1644                .collect()
1645        }
1646    }
1647
1648    /// Cast every element of a wgpu f32 output buffer down to the
1649    /// declared output dtype, returning the corresponding byte stream.
1650    /// The arena keeps every value as f32; declared output dtypes
1651    /// (Bool, I8, I32, F16, ...) require an exit-time narrowing to be
1652    /// byte-identical with backends that store the native dtype.
1653    fn narrow_to_dtype(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
1654        use rlx_ir::DType;
1655        match dt {
1656            DType::F32 => {
1657                let mut bytes = Vec::with_capacity(v.len() * 4);
1658                for &x in v {
1659                    bytes.extend_from_slice(&x.to_le_bytes());
1660                }
1661                bytes
1662            }
1663            DType::F16 => {
1664                let mut bytes = Vec::with_capacity(v.len() * 2);
1665                for &x in v {
1666                    bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
1667                }
1668                bytes
1669            }
1670            DType::BF16 => {
1671                let mut bytes = Vec::with_capacity(v.len() * 2);
1672                for &x in v {
1673                    bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
1674                }
1675                bytes
1676            }
1677            DType::F64 => {
1678                let mut bytes = Vec::with_capacity(v.len() * 8);
1679                for &x in v {
1680                    bytes.extend_from_slice(&(x as f64).to_le_bytes());
1681                }
1682                bytes
1683            }
1684            DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
1685            DType::U8 => v.iter().map(|&x| x as u8).collect(),
1686            DType::I16 => {
1687                let mut bytes = Vec::with_capacity(v.len() * 2);
1688                for &x in v {
1689                    bytes.extend_from_slice(&(x as i16).to_le_bytes());
1690                }
1691                bytes
1692            }
1693            DType::I32 => {
1694                let mut bytes = Vec::with_capacity(v.len() * 4);
1695                for &x in v {
1696                    bytes.extend_from_slice(&(x as i32).to_le_bytes());
1697                }
1698                bytes
1699            }
1700            DType::U32 => {
1701                let mut bytes = Vec::with_capacity(v.len() * 4);
1702                for &x in v {
1703                    bytes.extend_from_slice(&(x as u32).to_le_bytes());
1704                }
1705                bytes
1706            }
1707            DType::I64 => {
1708                let mut bytes = Vec::with_capacity(v.len() * 8);
1709                for &x in v {
1710                    bytes.extend_from_slice(&(x as i64).to_le_bytes());
1711                }
1712                bytes
1713            }
1714            DType::Bool => v
1715                .iter()
1716                .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
1717                .collect(),
1718            // C64 (complex f32 pair) — the wgpu backend's f32 arena
1719            // doesn't synthesize complex outputs today; this branch
1720            // only fires if a graph somehow asks for a C64 output and
1721            // the backend lowered it as 2N real floats. We pass the
1722            // raw f32 stream straight through; downstream code that
1723            // wants complex semantics is responsible for re-pairing.
1724            DType::C64 => {
1725                let mut bytes = Vec::with_capacity(v.len() * 4);
1726                for &x in v {
1727                    bytes.extend_from_slice(&x.to_le_bytes());
1728                }
1729                bytes
1730            }
1731        }
1732    }
1733}
1734
1735// ── MLX Backend ─────────────────────────────────────────────────────────
1736
1737#[cfg(all(feature = "mlx", rlx_mlx_host))]
1738pub mod mlx_backend {
1739    use super::*;
1740    use rlx_mlx::MlxExecutable;
1741
    /// Backend that compiles graphs into `rlx_mlx::MlxExecutable`s.
    pub struct MlxBackend;
1743
    /// PLAN L4: ops the MLX backend can lower today.
    ///
    /// MLX has the widest IR coverage of any GPU backend. `If`/`While` are
    /// accepted (handled via topo unrolling); note that
    /// `MlxBackend::compile_lir` also runs `LowerControlFlow` before
    /// `prepare_fused_graph`. `ElementwiseRegion` is lowered natively via the
    /// per-step composition in `rlx-mlx/src/lower.rs` (PLAN L2).
    ///
    /// `GroupNorm` / `BatchNormInference` are intentionally omitted: the
    /// `LowerGroupNorm` / `LowerBatchNormInference` passes rewrite them to
    /// primitives before MLX lowering (there is no native MLX kernel).
    const MLX_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            LayerNorm,
            LayerNorm2d,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
            If,
            While,
            // Loop-unrolled scan (Op::Scan body is statically unrolled
            // `length` times into MLX ops; mirror of Op::While's
            // bounded-unroll lowering). ScanBackward is the AD
            // companion — handled the same way.
            Scan,
            ScanBackward,
            ScanBackwardXs,
            // Tier 1 autodiff backward ops — lowered as primitive
            // compositions in `rlx-mlx/src/lower.rs`.
            ReluBackward,
            ActivationBackward,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            // Tier 2 — conv backward via `mc::conv_general` with the
            // same parameter-mapping MLX uses inside its built-in vjp.
            // Currently groups=1 only; grouped conv backward will
            // surface as a clear error from `lower.rs`.
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            // Tier 3 — max-pool backward via slice-strided argmax over
            // pool windows + a per-kernel-slot scatter-add, matching
            // the CPU thunk's "first-hit-wins" tiebreaking.
            MaxPool2dBackward,
            // QAT — `FakeQuantize` (PerBatch + Fixed scale modes;
            // EMA returns a clear error from `lower.rs`) and the
            // `FakeQuantizeBackward` family covering all 4 STE
            // variants. Closes the last gap vs `CPU_SUPPORTED_OPS`.
            FakeQuantize,
            FakeQuantizeBackward,
            // User-registered custom ops dispatched through
            // `rlx_mlx::op_registry`. Lowering looks up the
            // registered `MlxKernel` and calls its `execute` method
            // to produce the lazy MLX `Array` for this node.
            Custom,
            // Op::Fft on MLX: native `mlx::fft::fft` via the
            // rlx_mlx_op_fft shim. 2N real-block f32/f64 and complex64
            // inputs supported.
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
        ]
    };
1855
1856    impl Backend for MlxBackend {
1857        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
1858            MLX_SUPPORTED_OPS
1859        }
1860
1861        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1862            let compile_result = crate::stages::compile_graph_stages_for_backend(
1863                rlx_driver::Device::Mlx,
1864                graph,
1865                options,
1866                MLX_SUPPORTED_OPS,
1867            );
1868            crate::stages::maybe_log_fusion(&compile_result.fusion);
1869            self.compile_lir(compile_result.lir, options)
1870        }
1871
1872        fn compile_lir(
1873            &self,
1874            lir: LirModule,
1875            options: &CompileOptions,
1876        ) -> Box<dyn ExecutableGraph> {
1877            use rlx_opt::pass::Pass as _;
1878            let mut graph = lir.into_graph();
1879            graph = rlx_opt::LowerControlFlow.run(graph);
1880            let graph = prepare_fused_graph(graph, options, MLX_SUPPORTED_OPS, "mlx");
1881            Box::new(build_mlx_executable(graph))
1882        }
1883    }
1884
1885    fn build_mlx_executable(graph: Graph) -> MlxExecutableWrapper {
1886        let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1887        let mode = mlx_mode_from_env();
1888        let mut exe = MlxExecutable::compile_from_fused(graph, mode);
1889        if mode == rlx_mlx::lower::MlxMode::Compiled {
1890            if let Err(e) = exe.warm_compile() {
1891                eprintln!(
1892                    "[rlx-runtime] MLX warm_compile failed ({e}); first run will pay the trace cost"
1893                );
1894            }
1895        }
1896        MlxExecutableWrapper {
1897            inner: exe,
1898            io_manifest,
1899        }
1900    }
1901
1902    fn mlx_mode_from_env() -> rlx_mlx::lower::MlxMode {
1903        match rlx_ir::env::var("RLX_MLX_MODE").as_deref() {
1904            Some(s) if s.eq_ignore_ascii_case("eager") => rlx_mlx::lower::MlxMode::Eager,
1905            Some(s) if s.eq_ignore_ascii_case("lazy") => rlx_mlx::lower::MlxMode::Lazy,
1906            Some(s) if s.eq_ignore_ascii_case("compiled") => rlx_mlx::lower::MlxMode::Compiled,
1907            _ => rlx_mlx::lower::MlxMode::Compiled,
1908        }
1909    }
1910
    /// Type-erased executable handed out by `MlxBackend` (built by
    /// `build_mlx_executable`): the MLX program plus the I/O dtype manifest
    /// returned by `prepare_f32_exec_graph`.
    struct MlxExecutableWrapper {
        /// Compiled MLX program (mode chosen from `RLX_MLX_MODE`).
        inner: MlxExecutable,
        /// Consulted by `run_typed` (via `declared_output_dtypes`) to choose
        /// each output's dtype when narrowing f32 results.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // NOTE(review): no SAFETY justification is recorded for this impl. It
    // asserts that moving the wrapper (and the MLX handles inside
    // `MlxExecutable`) to another thread is sound — confirm against
    // `MlxExecutable`'s fields before relying on it.
    unsafe impl Send for MlxExecutableWrapper {}
1917
1918    impl ExecutableGraph for MlxExecutableWrapper {
1919        fn set_param(&mut self, name: &str, data: &[f32]) {
1920            self.inner.set_param(name, data);
1921        }
1922        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1923            self.inner.run(inputs)
1924        }
1925        fn run_read_outputs(
1926            &mut self,
1927            inputs: &[(&str, &[f32])],
1928            read_indices: Option<&[usize]>,
1929        ) -> Vec<Vec<f32>> {
1930            self.inner
1931                .run_read_outputs(inputs, read_indices)
1932                .unwrap_or_else(|e| panic!("MLX run_read_outputs failed: {e}"))
1933        }
1934        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1935            self.inner.run_slots(inputs)
1936        }
1937        fn arena_ptr(&self) -> *const u8 {
1938            self.inner.arena_ptr()
1939        }
1940        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
1941            self.inner.commit_no_wait(inputs);
1942        }
1943        fn sync_pending(&mut self) {
1944            self.inner.sync_pending();
1945        }
1946        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
1947            self.inner.run_pipelined(input_sets)
1948        }
1949        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1950            self.inner.bind_handle(name, data)
1951        }
1952        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1953            self.inner.read_handle(name)
1954        }
1955        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
1956            self.inner.bind_gpu_handle(name, data).is_ok()
1957        }
1958        fn has_gpu_handle(&self, name: &str) -> bool {
1959            self.inner.has_gpu_handle(name)
1960        }
1961        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
1962            self.inner.set_gpu_handle_feed(handle_name, output_index);
1963            true
1964        }
1965        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
1966            self.inner.read_gpu_handle(name).ok()
1967        }
1968        fn run_feed_gpu_handle(
1969            &mut self,
1970            inputs: &[(&str, &[f32])],
1971            handle_name: &str,
1972            output_index: usize,
1973        ) -> Option<Vec<f32>> {
1974            self.inner
1975                .run_feed_gpu(inputs, handle_name, output_index)
1976                .ok()
1977        }
1978        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1979            self.inner.set_param_typed(name, data, dtype);
1980        }
1981        fn run_typed(
1982            &mut self,
1983            inputs: &[(&str, &[u8], rlx_ir::DType)],
1984        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1985            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1986            for (name, data, dt) in inputs {
1987                let v = super::widen_bytes_to_f32(data, *dt);
1988                owned.push((name.to_string(), v));
1989            }
1990            let refs: Vec<(&str, &[f32])> = owned
1991                .iter()
1992                .map(|(n, d)| (n.as_str(), d.as_slice()))
1993                .collect();
1994            let f32_outs = self.inner.run(&refs);
1995            let declared = super::declared_output_dtypes(
1996                &self.io_manifest,
1997                (0..f32_outs.len()).map(|_| rlx_ir::DType::F32).collect(),
1998            );
1999            f32_outs
2000                .into_iter()
2001                .zip(
2002                    declared
2003                        .into_iter()
2004                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2005                )
2006                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2007                .collect()
2008        }
2009        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2010            self.inner.set_active_extent(extent);
2011        }
2012    }
2013}
2014
2015#[cfg(all(feature = "metal", target_os = "macos"))]
2016pub mod metal_backend {
2017    use super::*;
2018    use rlx_metal::backend::MetalExecutable;
2019
    /// Backend that compiles graphs into `rlx_metal::backend::MetalExecutable`s
    /// (macOS only, `metal` feature).
    pub struct MetalBackend;
2021
    /// PLAN L4: ops the Metal backend can lower today.
    ///
    /// * Includes `DotGeneral` (via the `LowerDotGeneral` pass) and
    ///   `ElementwiseRegion` (decomposed by `UnfuseElementwiseRegions`).
    /// * `DequantMatMul` (GGUF K-quants) lowers to a GPU dequant kernel plus
    ///   an MPS matmul; legacy Int8 schemes remain CPU-only.
    /// * Excludes `Cumsum`, `SelectiveScan`, `LoraMatMul`, `Sample`,
    ///   `FusedAttentionBlock`, `FusedTransformerLayer`, `If`, and `While`,
    ///   which are not yet wired in `rlx-metal/src/thunk.rs`'s
    ///   `compile_thunks`. (`If`/`While` are rewritten by
    ///   `LowerControlFlow` before legalization in both `MetalBackend`
    ///   compile paths.)
    const METAL_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            TopK,
            Conv,
            Im2Col,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            // User-registered custom ops dispatched through
            // `rlx_metal::op_registry`. Lowering panics with a clear
            // message if the named MetalKernel isn't registered;
            // executor inserts a sync point + runs the host kernel
            // against the unified-memory arena.
            Custom,
            // Op::Fft is supported via the same host-fallback pattern
            // as Custom: sync the GPU, run rlx-cpu's FFT against the
            // unified-memory arena, restart cmd_buf. A native Metal
            // compute kernel will replace this when a workload makes
            // the sync the bottleneck.
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            // Host-fallback splat (unified-memory arena + rlx-cpu/splat).
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
        ]
    };
2111
2112    impl Backend for MetalBackend {
2113        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2114            METAL_SUPPORTED_OPS
2115        }
2116
2117        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2118            use rlx_opt::pass::Pass as _;
2119            // Same If/While → primitive rewrite as the CPU pipeline
2120            // (Metal also has no native sub-graph executor wired
2121            // through its thunk schedule).
2122            let graph = rlx_opt::LowerControlFlow.run(graph);
2123            let dispatch = options.kernel_dispatch;
2124            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2125                graph,
2126                METAL_SUPPORTED_OPS,
2127                dispatch,
2128            )
2129            .unwrap_or_else(|errors| {
2130                panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2131            });
2132            let graph = crate::precompile::precompile_cleanup(graph, options);
2133
2134            // Hand the policy to MetalExecutable so the rewrite runs AFTER
2135            // its internal fusion passes (avoids breaking pattern matchers).
2136            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2137            Box::new(MetalExecutableWrapper {
2138                inner: MetalExecutable::compile_with_policy(
2139                    graph,
2140                    options.policy.clone(),
2141                    Some(METAL_SUPPORTED_OPS),
2142                ),
2143                io_manifest,
2144            })
2145        }
2146
2147        fn compile_lir(
2148            &self,
2149            lir: LirModule,
2150            options: &CompileOptions,
2151        ) -> Box<dyn ExecutableGraph> {
2152            use rlx_opt::pass::Pass as _;
2153            let mut graph = lir.into_graph();
2154            graph = rlx_opt::LowerControlFlow.run(graph);
2155            let dispatch = options.kernel_dispatch;
2156            let mut graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2157                graph,
2158                METAL_SUPPORTED_OPS,
2159                dispatch,
2160            )
2161            .unwrap_or_else(|errors| {
2162                panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2163            });
2164            graph = crate::precompile::precompile_cleanup(graph, options);
2165            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2166            Box::new(MetalExecutableWrapper {
2167                inner: MetalExecutable::compile_from_fused(
2168                    graph,
2169                    options.policy.clone(),
2170                    Some(METAL_SUPPORTED_OPS),
2171                ),
2172                io_manifest,
2173            })
2174        }
2175    }
2176
    /// Type-erased executable handed out by `MetalBackend`: the compiled
    /// `MetalExecutable` plus the I/O dtype manifest that
    /// `prepare_f32_exec_graph` returned for its graph.
    struct MetalExecutableWrapper {
        /// Compiled Metal program.
        inner: MetalExecutable,
        /// I/O dtype manifest recorded by `prepare_f32_exec_graph`.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // NOTE(review): no SAFETY justification is recorded for this impl. It
    // asserts that moving the wrapper (and the Metal objects inside
    // `MetalExecutable`) to another thread is sound — confirm against
    // `MetalExecutable`'s fields before relying on it.
    unsafe impl Send for MetalExecutableWrapper {}
2183
2184    impl ExecutableGraph for MetalExecutableWrapper {
        // Thin delegations to `MetalExecutable`; each forwards its arguments
        // unchanged unless noted.

        /// Upload an f32 parameter.
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        /// Run with f32 inputs; outputs come back as f32 vectors.
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        /// Run, reading back only the outputs selected by `read_indices`
        /// (see `MetalExecutable::run_read_outputs` for the `None` case).
        fn run_read_outputs(
            &mut self,
            inputs: &[(&str, &[f32])],
            read_indices: Option<&[usize]>,
        ) -> Vec<Vec<f32>> {
            self.inner.run_read_outputs(inputs, read_indices)
        }
        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
            self.inner.bind_gpu_handle(name, data)
        }
        fn has_gpu_handle(&self, name: &str) -> bool {
            self.inner.has_gpu_handle(name)
        }
        /// Always reports success; whatever the inner call returns is not consulted.
        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
            self.inner.set_gpu_handle_feed(handle_name, output_index);
            true
        }
        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_gpu_handle(name)
        }
        /// Always `Some`: wraps `read_graph_output_row(out_idx, row, row_inner)`.
        /// NOTE(review): out-of-range indices are handled (or not) by that
        /// method, not here — confirm before relying on a `None` result.
        fn read_output_row(
            &self,
            out_idx: usize,
            row: usize,
            row_inner: usize,
        ) -> Option<Vec<f32>> {
            Some(self.inner.read_graph_output_row(out_idx, row, row_inner))
        }
        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
            self.inner.run_slots(inputs)
        }
        /// Raw pointer to the executable's arena; its validity is governed by
        /// `MetalExecutable`.
        fn arena_ptr(&self) -> *const u8 {
            self.inner.arena_ptr()
        }
        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
            self.inner.commit_no_wait(inputs);
        }
        fn sync_pending(&mut self) {
            self.inner.sync_pending();
        }
        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
            self.inner.run_pipelined(input_sets)
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }
2237
2238        /// Typed param upload — accepts F16/BF16 host bytes by widening
2239        /// to F32 first, then routing through `set_param`. The Metal
2240        /// arena's `write_from_f32` honors per-node F16 storage when
2241        /// AutoMixedPrecision rewrote the param. U8/I8 packed weights
2242        /// copy directly into the arena for `Op::DequantMatMul`.
2243        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2244            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2245                self.inner.set_param_bytes(name, data);
2246                return;
2247            }
2248            if dtype == rlx_ir::DType::F32 {
2249                let n = data.len() / 4;
2250                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2251                self.inner.set_param(name, s);
2252            } else {
2253                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2254                self.inner.set_param(name, &f32_buf);
2255            }
2256        }
2257
2258        /// Typed run. Inputs widen to F32 (existing path; F64 host
2259        /// inputs through `run_typed` is a separate Metal extension).
2260        /// Outputs: F64 outputs go through the byte-direct
2261        /// `output_bytes_per_node` path (no precision loss in the
2262        /// f32 round-trip); other dtypes keep the f32-narrow path
2263        /// for backward compatibility with existing AutoMixedPrecision
2264        /// rewrites.
2265        fn run_typed(
2266            &mut self,
2267            inputs: &[(&str, &[u8], rlx_ir::DType)],
2268        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2269            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2270            for (name, data, dt) in inputs {
2271                let v = super::widen_bytes_to_f32(data, *dt);
2272                owned.push((name.to_string(), v));
2273            }
2274            let refs: Vec<(&str, &[f32])> = owned
2275                .iter()
2276                .map(|(n, d)| (n.as_str(), d.as_slice()))
2277                .collect();
2278            let dtypes =
2279                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2280            let f32_outs = self.inner.run(&refs);
2281            let byte_outs = self.inner.output_bytes_per_node();
2282            f32_outs
2283                .into_iter()
2284                .zip(byte_outs.into_iter())
2285                .zip(
2286                    dtypes
2287                        .into_iter()
2288                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2289                )
2290                .map(|((f32_v, byte_v), dt)| match dt {
2291                    rlx_ir::DType::F64 => (byte_v, dt),
2292                    _ => (super::narrow_f32_to_bytes(&f32_v, dt), dt),
2293                })
2294                .collect()
2295        }
2296    }
2297}
2298
2299// ── CUDA Backend ────────────────────────────────────────────────────────
2300
2301#[cfg(feature = "cuda")]
2302pub mod cuda_backend {
2303    use super::*;
2304    use rlx_cuda::backend::CudaExecutable;
2305
    /// Zero-sized handle for the CUDA backend. `compile` / `compile_lir`
    /// (see `impl Backend for CudaBackend`) produce CUDA-backed executables.
    pub struct CudaBackend;

    /// PLAN L4: ops the CUDA backend can lower today. Excludes
    /// FusedSwiGLU, LoraMatMul, FusedAttentionBlock,
    /// FusedTransformerLayer (no kernel) + If, While (no executor
    /// wiring). DotGeneral via LowerDotGeneral; ElementwiseRegion
    /// lowered natively by an NVRTC interpreted-chain kernel.
    ///
    /// Returned verbatim by `supported_ops` and passed to legalization
    /// and the fusion stages in `compile`. Unlike `ROCM_SUPPORTED_OPS`,
    /// this list also claims Conv2dBackwardInput, Conv2dBackwardWeight
    /// and MaxPool2dBackward.
    const CUDA_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            Im2Col,
        ]
    };
2381
2382    impl Backend for CudaBackend {
2383        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2384            CUDA_SUPPORTED_OPS
2385        }
2386
2387        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2388            use rlx_opt::pass::Pass as _;
2389            // Decompose FusedSwiGLU / FAB / etc. before legalization (CudaExecutable
2390            // unfuses again; this pass is idempotent).
2391            let graph = rlx_cuda::unfuse::unfuse(graph);
2392            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, CUDA_SUPPORTED_OPS)
2393                .unwrap_or_else(|errors| {
2394                    panic!("{}", rlx_opt::format_legalize_error("cuda", &errors));
2395                });
2396            let graph = crate::precompile::precompile_cleanup(graph, options);
2397            // Mid-axis broadcasts (EEG patch embed) before elementwise fusion.
2398            let graph = rlx_opt::LegalizeBroadcast.run(graph);
2399            // Backend-aware fusion via the shared compile pipeline.
2400            let compile_result = crate::stages::compile_graph_stages_for_backend(
2401                rlx_driver::Device::Cuda,
2402                graph,
2403                options,
2404                CUDA_SUPPORTED_OPS,
2405            );
2406            crate::stages::maybe_log_fusion(&compile_result.fusion);
2407            let graph = compile_result.lir.into_graph();
2408            let graph = match options.policy.clone() {
2409                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2410                None => graph,
2411            };
2412            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2413            Box::new(CudaExecutableWrapper {
2414                inner: CudaExecutable::compile(graph),
2415                io_manifest,
2416            })
2417        }
2418
2419        fn compile_lir(
2420            &self,
2421            lir: LirModule,
2422            options: &CompileOptions,
2423        ) -> Box<dyn ExecutableGraph> {
2424            use rlx_opt::pass::Pass as _;
2425            let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
2426            let (graph, io_manifest) =
2427                cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2428                    rlx_cuda::unfuse::unfuse(graph),
2429                    options,
2430                    CUDA_SUPPORTED_OPS,
2431                    "cuda",
2432                ));
2433            Box::new(CudaExecutableWrapper {
2434                inner: CudaExecutable::compile(graph),
2435                io_manifest,
2436            })
2437        }
2438    }
2439
    /// Adapter exposing a `CudaExecutable` through `ExecutableGraph`.
    struct CudaExecutableWrapper {
        /// Compiled CUDA executable every trait call forwards to.
        inner: CudaExecutable,
        /// Manifest returned by `prepare_f32_exec_graph` during compile;
        /// `run_typed` hands it to `declared_output_dtypes` to choose each
        /// output's reported dtype.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // SAFETY: CudaExecutable owns CudaContext + CudaSlice handles; cudarc
    // claims they're Send (CudaContext is Arc-wrapped, CudaSlice is
    // logically a device pointer + length). The Backend trait requires
    // Send for the executable; we honor that here.
    unsafe impl Send for CudaExecutableWrapper {}
2450
2451    impl ExecutableGraph for CudaExecutableWrapper {
2452        fn set_param(&mut self, name: &str, data: &[f32]) {
2453            self.inner.set_param(name, data);
2454        }
2455        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2456            self.inner.run(inputs)
2457        }
2458        fn run_read_outputs(
2459            &mut self,
2460            inputs: &[(&str, &[f32])],
2461            read_indices: Option<&[usize]>,
2462        ) -> Vec<Vec<f32>> {
2463            self.inner.run_read_outputs(inputs, read_indices)
2464        }
2465        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2466            self.inner.bind_gpu_handle(name, data)
2467        }
2468        fn has_gpu_handle(&self, name: &str) -> bool {
2469            self.inner.has_gpu_handle(name)
2470        }
2471        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2472            self.inner.set_gpu_handle_feed(handle_name, output_index);
2473            true
2474        }
2475        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2476            self.inner.read_gpu_handle(name)
2477        }
2478        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2479            self.inner.set_active_extent(extent);
2480        }
2481
2482        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2483            self.inner.run_slots(inputs)
2484        }
2485
2486        fn arena_ptr(&self) -> *const u8 {
2487            self.inner.arena_ptr()
2488        }
2489
2490        /// Typed param upload — widens F16/BF16 host bytes to f32
2491        /// before routing through `set_param`. CUDA's arena is
2492        /// f32-uniform; the half-precision matmul tier opts in via
2493        /// the separate `set_param_half` API.
2494        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2495            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2496                self.inner.set_param_bytes(name, data);
2497                return;
2498            }
2499            if dtype == rlx_ir::DType::F32 {
2500                let n = data.len() / 4;
2501                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2502                self.inner.set_param(name, s);
2503            } else {
2504                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2505                self.inner.set_param(name, &f32_buf);
2506            }
2507        }
2508
2509        /// Typed run — widen each typed input to F32, run, then narrow
2510        /// each output back to its declared graph dtype.
2511        fn run_typed(
2512            &mut self,
2513            inputs: &[(&str, &[u8], rlx_ir::DType)],
2514        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2515            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2516            for (name, data, dt) in inputs {
2517                let v = super::widen_bytes_to_f32(data, *dt);
2518                owned.push((name.to_string(), v));
2519            }
2520            let refs: Vec<(&str, &[f32])> = owned
2521                .iter()
2522                .map(|(n, d)| (n.as_str(), d.as_slice()))
2523                .collect();
2524            let dtypes =
2525                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2526            let outs = self.inner.run(&refs);
2527            outs.into_iter()
2528                .zip(
2529                    dtypes
2530                        .into_iter()
2531                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2532                )
2533                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2534                .collect()
2535        }
2536    }
2537}
2538
2539// ── ROCm Backend ────────────────────────────────────────────────────────
2540
2541#[cfg(feature = "rocm")]
2542pub mod rocm_backend {
2543    use super::*;
2544    use rlx_rocm::backend::RocmExecutable;
2545
    /// Zero-sized handle for the ROCm backend. `compile` / `compile_lir`
    /// (see `impl Backend for RocmBackend`) produce ROCm-backed executables.
    pub struct RocmBackend;

    /// PLAN L4: ROCm is the sister crate of CUDA with the same Step
    /// enum + dispatch shape. Its claimed op set mirrors
    /// `CUDA_SUPPORTED_OPS` except that Conv2dBackwardInput,
    /// Conv2dBackwardWeight and MaxPool2dBackward are NOT listed here.
    /// Returned verbatim by `supported_ops` and used for legalization
    /// and fusion in `compile`.
    const ROCM_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            Im2Col,
        ]
    };
2615
2616    impl Backend for RocmBackend {
2617        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2618            ROCM_SUPPORTED_OPS
2619        }
2620
2621        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2622            use rlx_opt::pass::Pass as _;
2623            let graph = rlx_rocm::unfuse::unfuse(graph);
2624            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, ROCM_SUPPORTED_OPS)
2625                .unwrap_or_else(|errors| {
2626                    panic!("{}", rlx_opt::format_legalize_error("rocm", &errors));
2627                });
2628            let graph = crate::precompile::precompile_cleanup(graph, options);
2629            let graph = rlx_opt::LegalizeBroadcast.run(graph);
2630            let compile_result = crate::stages::compile_graph_stages_for_backend(
2631                rlx_driver::Device::Rocm,
2632                graph,
2633                options,
2634                ROCM_SUPPORTED_OPS,
2635            );
2636            crate::stages::maybe_log_fusion(&compile_result.fusion);
2637            let graph = compile_result.lir.into_graph();
2638            let graph = match options.policy.clone() {
2639                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2640                None => graph,
2641            };
2642            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2643            Box::new(RocmExecutableWrapper {
2644                inner: RocmExecutable::compile(graph),
2645                io_manifest,
2646            })
2647        }
2648
2649        fn compile_lir(
2650            &self,
2651            lir: LirModule,
2652            options: &CompileOptions,
2653        ) -> Box<dyn ExecutableGraph> {
2654            let (graph, io_manifest) =
2655                cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2656                    rlx_rocm::unfuse::unfuse(lir.into_graph()),
2657                    options,
2658                    ROCM_SUPPORTED_OPS,
2659                    "rocm",
2660                ));
2661            Box::new(RocmExecutableWrapper {
2662                inner: RocmExecutable::compile(graph),
2663                io_manifest,
2664            })
2665        }
2666    }
2667
    /// Adapter exposing a `RocmExecutable` through `ExecutableGraph`.
    struct RocmExecutableWrapper {
        /// Compiled ROCm executable every trait call forwards to.
        inner: RocmExecutable,
        /// Manifest returned by `prepare_f32_exec_graph` during compile;
        /// `run_typed` hands it to `declared_output_dtypes` to choose each
        /// output's reported dtype.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // SAFETY: same Send-claim shape as CudaExecutableWrapper.
    // RocmExecutable owns Arc<RocmContext> + HipBuffer handles; the
    // HipRuntime bundle is internally thread-safe per AMD's documentation.
    unsafe impl Send for RocmExecutableWrapper {}
2677
2678    impl ExecutableGraph for RocmExecutableWrapper {
2679        fn set_param(&mut self, name: &str, data: &[f32]) {
2680            self.inner.set_param(name, data);
2681        }
2682        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2683            self.inner.run(inputs)
2684        }
2685        fn run_read_outputs(
2686            &mut self,
2687            inputs: &[(&str, &[f32])],
2688            read_indices: Option<&[usize]>,
2689        ) -> Vec<Vec<f32>> {
2690            self.inner.run_read_outputs(inputs, read_indices)
2691        }
2692        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2693            self.inner.bind_gpu_handle(name, data)
2694        }
2695        fn has_gpu_handle(&self, name: &str) -> bool {
2696            self.inner.has_gpu_handle(name)
2697        }
2698        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2699            self.inner.set_gpu_handle_feed(handle_name, output_index);
2700            true
2701        }
2702        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2703            self.inner.read_gpu_handle(name)
2704        }
2705        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2706            self.inner.run_slots(inputs)
2707        }
2708        fn arena_ptr(&self) -> *const u8 {
2709            self.inner.arena_ptr()
2710        }
2711        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2712            self.inner.set_active_extent(extent);
2713        }
2714
2715        /// Typed param upload — widens F16/BF16 host bytes to f32
2716        /// before routing through `set_param`. ROCm's arena is
2717        /// f32-uniform; the half-precision matmul tier opts in via
2718        /// the separate `set_param_half` API.
2719        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2720            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2721                self.inner.set_param_bytes(name, data);
2722                return;
2723            }
2724            if dtype == rlx_ir::DType::F32 {
2725                let n = data.len() / 4;
2726                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2727                self.inner.set_param(name, s);
2728            } else {
2729                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2730                self.inner.set_param(name, &f32_buf);
2731            }
2732        }
2733
2734        /// Typed run — widen each typed input to F32, run, then narrow
2735        /// each output back to its declared graph dtype.
2736        fn run_typed(
2737            &mut self,
2738            inputs: &[(&str, &[u8], rlx_ir::DType)],
2739        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2740            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2741            for (name, data, dt) in inputs {
2742                let v = super::widen_bytes_to_f32(data, *dt);
2743                owned.push((name.to_string(), v));
2744            }
2745            let refs: Vec<(&str, &[f32])> = owned
2746                .iter()
2747                .map(|(n, d)| (n.as_str(), d.as_slice()))
2748                .collect();
2749            let dtypes =
2750                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2751            let outs = self.inner.run(&refs);
2752            outs.into_iter()
2753                .zip(
2754                    dtypes
2755                        .into_iter()
2756                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2757                )
2758                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2759                .collect()
2760        }
2761    }
2762}
2763
2764// ── TPU Backend ─────────────────────────────────────────────────────────
2765
2766#[cfg(feature = "tpu")]
2767pub mod tpu_backend {
2768    use super::*;
2769    use rlx_tpu::TpuExecutable;
2770
    /// Zero-sized handle for the TPU backend. `compile` (see
    /// `impl Backend for TpuBackend`) produces a PJRT-backed executable.
    pub struct TpuBackend;

    /// Ops the TPU backend lowers to HLO. Composite ops (FusedSwiGLU /
    /// FusedAttentionBlock / FusedTransformerLayer / LoraMatMul / If /
    /// While) are unfused inside `rlx_tpu::unfuse::unfuse` ahead of
    /// HLO emission, so they don't appear here.
    ///
    /// NOTE(review): this list does not match `CUDA_SUPPORTED_OPS` /
    /// `ROCM_SUPPORTED_OPS`. It omits e.g. LayerNorm2d, GroupNorm,
    /// ResizeNearest2x, ConvTranspose2d, the GaussianSplat* ops, Custom
    /// and Im2Col, and it adds the QMatMul / QConv2d / Quantize /
    /// Dequantize family. Any claim of full inference parity with
    /// rlx-cuda / rlx-rocm presumably relies on legalization rewriting
    /// the omitted ops — confirm.
    const TPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            // Real-INT8 path + fake-quant.
            QMatMul,
            QConv2d,
            Quantize,
            Dequantize,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            // Splat: no on-chip kernel — lowered to common primitive MIR via logical_kernel.
        ]
    };
2832
2833    impl Backend for TpuBackend {
2834        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2835            TPU_SUPPORTED_OPS
2836        }
2837
2838        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2839            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2840                graph,
2841                TPU_SUPPORTED_OPS,
2842                options.kernel_dispatch,
2843            )
2844            .unwrap_or_else(|errors| {
2845                panic!("{}", rlx_opt::format_legalize_error("tpu", &errors));
2846            });
2847            // The TPU's IR-side pass pipeline (DCE, ConstFold,
2848            // FuseResidualLN, FuseMatMulBiasAct, LegalizeBroadcast,
2849            // MarkElementwiseRegions) lives inside
2850            // `TpuExecutable::compile` so the same passes run whether
2851            // a caller goes through Session or invokes the executable
2852            // directly. We only do backend-cross-cutting work here:
2853            // legalization (must precede the pipeline so we panic
2854            // early on unsupported ops) and AutoMixedPrecision.
2855            //
2856            // Default policy on TPU is `AutoMixedBf16`: BF16 is the
2857            // native compute dtype on TPU silicon and recent GPUs,
2858            // and XLA's CPU plugin handles it natively too. Callers
2859            // can opt out by passing an explicit `PrecisionPolicy`
2860            // (e.g. `AlwaysF32` for accuracy debugging or
2861            // `AlwaysF16` to match a CUDA workload's choice).
2862            use rlx_opt::pass::Pass as _;
2863            let policy = options
2864                .policy
2865                .clone()
2866                .unwrap_or(rlx_opt::PrecisionPolicy::AutoMixedBf16);
2867            let graph = rlx_opt::AutoMixedPrecision::new(policy).run(graph);
2868            let _ = options.dce;
2869            let _ = options.constant_folding;
2870            Box::new(TpuExecutableWrapper {
2871                inner: TpuExecutable::compile(graph),
2872            })
2873        }
2874    }
2875
    /// Adapter exposing a `TpuExecutable` through `ExecutableGraph`.
    /// Unlike the CUDA/ROCm wrappers it keeps no I/O dtype manifest;
    /// `run_typed` takes output dtypes from `TpuExecutable::output_dtypes`.
    struct TpuExecutableWrapper {
        /// Compiled PJRT-backed executable every trait call forwards to.
        inner: TpuExecutable,
    }

    // SAFETY: PJRT clients + buffers are documented as thread-safe per the
    // upstream C API. Same Send-claim shape as CudaExecutableWrapper /
    // RocmExecutableWrapper.
    unsafe impl Send for TpuExecutableWrapper {}
2884
2885    impl ExecutableGraph for TpuExecutableWrapper {
2886        fn set_param(&mut self, name: &str, data: &[f32]) {
2887            self.inner.set_param(name, data);
2888        }
2889        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2890            self.inner.run(inputs)
2891        }
2892
2893        /// Typed param upload — widens F16/BF16/etc. host bytes to
2894        /// f32 today. Once the HLO emitter speaks bf16 natively
2895        /// (which TPUs prefer over f16), the typed path will hand
2896        /// the original bytes straight through `Buffer_FromHostBuffer`.
2897        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2898            if dtype == rlx_ir::DType::F32 {
2899                let n = data.len() / 4;
2900                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2901                self.inner.set_param(name, s);
2902            } else {
2903                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2904                self.inner.set_param(name, &f32_buf);
2905            }
2906        }
2907
2908        fn run_typed(
2909            &mut self,
2910            inputs: &[(&str, &[u8], rlx_ir::DType)],
2911        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2912            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2913            for (name, data, dt) in inputs {
2914                let v = super::widen_bytes_to_f32(data, *dt);
2915                owned.push((name.to_string(), v));
2916            }
2917            let refs: Vec<(&str, &[f32])> = owned
2918                .iter()
2919                .map(|(n, d)| (n.as_str(), d.as_slice()))
2920                .collect();
2921            let dtypes = self.inner.output_dtypes();
2922            let outs = self.inner.run(&refs);
2923            outs.into_iter()
2924                .zip(
2925                    dtypes
2926                        .into_iter()
2927                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2928                )
2929                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2930                .collect()
2931        }
2932    }
2933}