// rlx_runtime/backend.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Backend trait — abstraction over CPU/GPU/CUDA execution.
//!
//! Each backend implements `Backend::compile(graph, &CompileOptions)` and
//! returns a boxed `ExecutableGraph`. New compile knobs go in
//! `CompileOptions` rather than becoming new trait methods.
21
22use crate::CompileOptions;
23use rlx_ir::Graph;
24use rlx_ir::hir::HirModule;
25use rlx_ir::lir::LirModule;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use crate::cpu_low_precision;
30
// ── Typed I/O helpers (shared across f32-arena backends) ────────────────
32
33/// Widen a typed byte buffer to `Vec<f32>`. Used by `set_param_typed` /
34/// `run_typed` overrides on backends whose internal arena is f32-uniform
35/// (CPU, Metal, wgpu) so callers can hand in F16/BF16 without doing the
36/// host-side cast themselves. Panics on dtypes the f32 arena can't carry.
37#[allow(dead_code)]
38pub(crate) fn widen_bytes_to_f32(data: &[u8], dtype: rlx_ir::DType) -> Vec<f32> {
39    use rlx_ir::DType;
40    match dtype {
41        DType::F32 => {
42            let n = data.len() / 4;
43            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
44            s.to_vec()
45        }
46        DType::F16 => {
47            let n = data.len() / 2;
48            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
49            s.iter().map(|h| h.to_f32()).collect()
50        }
51        DType::BF16 => {
52            let n = data.len() / 2;
53            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n) };
54            s.iter().map(|h| h.to_f32()).collect()
55        }
56        other => panic!(
57            "widen_bytes_to_f32: dtype {other:?} unsupported on f32-arena backends \
58             (only F32/F16/BF16 are accepted on the host I/O surface)"
59        ),
60    }
61}
62
63/// Narrow a `&[f32]` buffer down to the declared output dtype, returning
64/// the corresponding little-endian byte stream. Mirrors the bytes a
65/// backend that stores the native dtype would emit. Used by `run_typed`
66/// to keep the byte-level output contract identical across backends.
67#[allow(dead_code)]
68pub(crate) fn narrow_f32_to_bytes(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
69    use rlx_ir::DType;
70    match dt {
71        DType::F32 => {
72            let mut bytes = Vec::with_capacity(v.len() * 4);
73            for &x in v {
74                bytes.extend_from_slice(&x.to_le_bytes());
75            }
76            bytes
77        }
78        DType::F16 => {
79            let mut bytes = Vec::with_capacity(v.len() * 2);
80            for &x in v {
81                bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
82            }
83            bytes
84        }
85        DType::BF16 => {
86            let mut bytes = Vec::with_capacity(v.len() * 2);
87            for &x in v {
88                bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
89            }
90            bytes
91        }
92        DType::F64 => {
93            let mut bytes = Vec::with_capacity(v.len() * 8);
94            for &x in v {
95                bytes.extend_from_slice(&(x as f64).to_le_bytes());
96            }
97            bytes
98        }
99        DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
100        DType::U8 => v.iter().map(|&x| x as u8).collect(),
101        DType::I16 => {
102            let mut bytes = Vec::with_capacity(v.len() * 2);
103            for &x in v {
104                bytes.extend_from_slice(&(x as i16).to_le_bytes());
105            }
106            bytes
107        }
108        DType::I32 => {
109            let mut bytes = Vec::with_capacity(v.len() * 4);
110            for &x in v {
111                bytes.extend_from_slice(&(x as i32).to_le_bytes());
112            }
113            bytes
114        }
115        DType::U32 => {
116            let mut bytes = Vec::with_capacity(v.len() * 4);
117            for &x in v {
118                bytes.extend_from_slice(&(x as u32).to_le_bytes());
119            }
120            bytes
121        }
122        DType::I64 => {
123            let mut bytes = Vec::with_capacity(v.len() * 8);
124            for &x in v {
125                bytes.extend_from_slice(&(x as i64).to_le_bytes());
126            }
127            bytes
128        }
129        DType::Bool => v
130            .iter()
131            .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
132            .collect(),
133        DType::C64 => {
134            // Complex narrow path: real part = the f32 value, imaginary
135            // part = 0. Mirrors how the backend stores narrowed f32
136            // operands when promoted to a complex op input.
137            let mut bytes = Vec::with_capacity(v.len() * 8);
138            for &x in v {
139                bytes.extend_from_slice(&x.to_le_bytes());
140                bytes.extend_from_slice(&0.0_f32.to_le_bytes());
141            }
142            bytes
143        }
144    }
145}
146
147/// A compiled, ready-to-execute graph on a specific backend.
148pub trait ExecutableGraph: Send {
149    /// Set a named parameter (weight) buffer.
150    fn set_param(&mut self, name: &str, data: &[f32]);
151
152    /// Called after all params are uploaded (`set_param` / `set_param_typed`).
153    /// Backends may warm caches (e.g. Metal QMatMul weight dequant).
154    fn finalize_params(&mut self) {}
155
156    /// Deep-clone this executable into a fresh `Box`. Lets
157    /// `CompiledGraph` implement `Clone` so callers (e.g. eda-mna's
158    /// `SensitivityContext`) can spin up N independent executor
159    /// copies for thread-parallel dispatch without paying the full
160    /// graph-compile cost N times. Default implementation panics;
161    /// backends that support cloning override.
162    fn clone_box(&self) -> Box<dyn ExecutableGraph> {
163        panic!("clone_box not implemented for this backend");
164    }
165
166    /// Execute the graph with named inputs. Returns output data (copies from arena).
167    fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>>;
168
169    /// Like [`Self::run`] but only read back outputs at `read_indices`.
170    /// GPU handle feeds still update for every output. Default: all outputs.
171    fn run_read_outputs(
172        &mut self,
173        inputs: &[(&str, &[f32])],
174        read_indices: Option<&[usize]>,
175    ) -> Vec<Vec<f32>> {
176        match read_indices {
177            None => self.run(inputs),
178            Some(ix) => {
179                // Backends without a native partial-read path still run the full
180                // graph; only clone the requested outputs on the host.
181                let all = self.run(inputs);
182                ix.iter().filter_map(|&i| all.get(i).cloned()).collect()
183            }
184        }
185    }
186
187    /// Execute and return raw pointers to output data in arena (zero-copy).
188    fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
189        let vecs = self.run(inputs);
190        vecs.iter().map(|v| (v.as_ptr(), v.len())).collect()
191    }
192
193    /// Fastest: inputs by slot index, returns output (offset, len) pairs.
194    /// Read output from arena via `arena_ptr().add(offset)`.
195    fn run_slots(&mut self, _inputs: &[&[f32]]) -> &[(usize, usize)] {
196        &[] // default: not supported
197    }
198
199    /// Get the raw arena buffer pointer for reading outputs after run_slots.
200    fn arena_ptr(&self) -> *const u8 {
201        std::ptr::null()
202    }
203
204    /// Hint the executor that subsequent `run` calls should process
205    /// only the first `actual` rows along the bucket axis (out of
206    /// `upper`, the extent the graph was compiled at). Backends that
207    /// support per-kernel active-extent dispatch honor this; others
208    /// ignore it and process the full compiled extent.
209    ///
210    /// Pass `None` to clear the hint. The hint is sticky — set it
211    /// before each `run` and clear it after, or maintain it across
212    /// runs at your discretion.
213    ///
214    /// Even when honored, callers must not rely on the contents of the
215    /// output past `actual` rows — that region may contain stale data
216    /// from earlier runs (kernels skip it).
217    ///
218    /// Default: no-op. See `BucketedCompileCache::run_padded` for the
219    /// canonical caller; backends opt in by overriding this method.
220    fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
221        let _ = extent;
222    }
223
224    /// Override RNG policy for in-graph random ops without recompiling.
225    fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
226        let _ = rng;
227    }
228
229    /// Current RNG policy (default when the backend does not override).
230    fn rng(&self) -> rlx_ir::RngOptions {
231        rlx_ir::RngOptions::default()
232    }
233
234    /// TIDE merged placement mask (union across MoE layers). CPU: stats + host path.
235    fn set_moe_resident_experts(&mut self, _mask: &[bool]) {}
236
237    /// Per MoE layer placement (`masks[layer][expert]`). Preferred over merged on CPU.
238    fn set_moe_resident_experts_per_layer(&mut self, _masks: &[&[bool]]) {}
239
240    /// Capture MoE router TopK indices on the next CPU forward (TIDE refresh).
241    fn enable_moe_topk_capture(&mut self, _num_experts: usize) -> bool {
242        false
243    }
244
245    /// Take captured per-layer expert indices (one vec per MoE TopK in order).
246    fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
247        None
248    }
249
250    /// MoE GroupedMatMul residency accounting from the last forward (CPU).
251    fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
252        None
253    }
254
255    /// Bind a persistent buffer handle (KV-cache, training state, etc.).
256    /// The buffer lives across run() calls and is not in the arena.
257    /// Returns true if the backend supports persistent handles.
258    fn bind_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
259        false
260    }
261
262    /// Read a persistent buffer's current contents.
263    fn read_handle(&self, _name: &str) -> Option<Vec<f32>> {
264        None
265    }
266
267    /// GPU-resident input (MLX): upload once, reuse across runs.
268    fn bind_gpu_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
269        false
270    }
271
272    fn has_gpu_handle(&self, _name: &str) -> bool {
273        false
274    }
275
276    fn set_gpu_handle_feed(&mut self, _handle_name: &str, _output_index: usize) -> bool {
277        false
278    }
279
280    fn read_gpu_handle(&self, _name: &str) -> Option<Vec<f32>> {
281        None
282    }
283
284    /// Read one row from a row-major graph output after `run` / `run_read_outputs`.
285    /// Metal reads a single row from the arena; default returns `None` (caller falls back).
286    fn read_output_row(&self, _out_idx: usize, _row: usize, _row_inner: usize) -> Option<Vec<f32>> {
287        None
288    }
289
290    /// Run and refresh a GPU handle from `output_index`; returns that output on host.
291    fn run_feed_gpu_handle(
292        &mut self,
293        inputs: &[(&str, &[f32])],
294        _handle_name: &str,
295        _output_index: usize,
296    ) -> Option<Vec<f32>> {
297        let _ = inputs;
298        None
299    }
300
301    // ── Pipelined / async execution (Phase C) ─────────────────────────
302    //
303    // These allow callers to amortize per-run sync latency on backends
304    // where it matters (Metal: ~150 µs `wait_until_completed` per commit).
305    // CPU has no such cost, so the default impls just call `run` serially.
306
307    /// Encode + commit a forward pass without waiting for completion.
308    ///
309    /// Outputs of intermediate calls are stomped — use `run_pipelined` if
310    /// you need outputs from each individual commit. Pair with
311    /// `sync_pending` to drain.
312    ///
313    /// Default: synchronous fallback (calls `run`, discards output). CPU
314    /// uses this default since BLAS is synchronous anyway.
315    fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
316        let _ = self.run(inputs);
317    }
318
319    /// Wait for every command queued by `commit_no_wait`.
320    /// Default: no-op (synchronous backends have nothing pending).
321    fn sync_pending(&mut self) {}
322
323    /// Issue a batch of forward passes pipelined, returning per-run outputs.
324    ///
325    /// The Metal impl encodes a per-commit blit so each in-flight run's
326    /// outputs survive subsequent commits stomping the shared arena. The
327    /// CPU default is just sequential `run`s — equally correct, no perf
328    /// penalty (CPU has no GPU sync cost to amortize).
329    ///
330    /// Returns `out[run_idx][output_idx][element_idx]`.
331    fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
332        input_sets.iter().map(|inputs| self.run(inputs)).collect()
333    }
334
335    // ── Typed (non-F32) host I/O ──────────────────────────────────
336    //
337    // `set_param` and `run` are F32 by contract. The typed entry
338    // points let callers pass and receive raw bytes in any rlx-ir
339    // dtype, avoiding the f32 widen/narrow round-trip that's
340    // wasteful for F16/BF16 weights and activations.
341    //
342    // The default impls only handle F32 — any other dtype panics.
343    // Backends that support typed I/O natively (e.g. MLX via
344    // Array::from_bytes/to_bytes) override these.
345
346    /// Set a named parameter from raw bytes in the given dtype.
347    fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
348        if dtype != rlx_ir::DType::F32 {
349            panic!(
350                "backend's default set_param_typed only handles F32; \
351                    got {dtype:?}. Override on the backend for typed support."
352            );
353        }
354        if !data.len().is_multiple_of(4) {
355            panic!(
356                "set_param_typed F32: data length {} not a multiple of 4",
357                data.len()
358            );
359        }
360        // SAFETY: F32 bytes are 4-aligned by source convention; we
361        // only widen access (read &[f32] from owned &[u8]). Failure
362        // mode if a caller hands us mis-aligned bytes is undefined,
363        // hence the % 4 length check.
364        let n = data.len() / 4;
365        let f32_slice = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
366        self.set_param(name, f32_slice);
367    }
368
369    /// Run with typed inputs and typed outputs. Returns
370    /// `(bytes, dtype)` per output; the dtype is whatever the
371    /// graph's output node was declared as.
372    fn run_typed(
373        &mut self,
374        inputs: &[(&str, &[u8], rlx_ir::DType)],
375    ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
376        // Default impl: convert each typed input to f32 (F32-only),
377        // run, then re-emit outputs as F32 bytes.
378        let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
379        for (name, data, dt) in inputs {
380            if *dt != rlx_ir::DType::F32 {
381                panic!(
382                    "backend's default run_typed only handles F32 inputs; \
383                        got {dt:?} for input '{name}'"
384                );
385            }
386            if data.len() % 4 != 0 {
387                panic!(
388                    "run_typed F32 input '{name}': len {} not multiple of 4",
389                    data.len()
390                );
391            }
392            let n = data.len() / 4;
393            let v: Vec<f32> =
394                unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }.to_vec();
395            owned.push((name.to_string(), v));
396        }
397        let refs: Vec<(&str, &[f32])> = owned
398            .iter()
399            .map(|(n, d)| (n.as_str(), d.as_slice()))
400            .collect();
401        let outs = self.run(&refs);
402        outs.into_iter()
403            .map(|v| {
404                let bytes =
405                    unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 4) }
406                        .to_vec();
407                (bytes, rlx_ir::DType::F32)
408            })
409            .collect()
410    }
411}
412
413/// Backend implementation trait.
414///
415/// Single compile entry point. New compile-time knobs are added to
416/// `CompileOptions`, not as new trait methods.
417///
418/// `Send + Sync` because backends are stateless factories — multiple
419/// threads can call `compile` concurrently. The returned
420/// `Box<dyn ExecutableGraph>` is `Send` (moveable to a worker thread)
421/// but **not** `Sync` (`run`/`run_slots` take `&mut self`).
422pub trait Backend: Send + Sync {
423    /// Compile a graph for this backend with the given options.
424    fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph>;
425
426    /// Compile pre-optimized LIR (HIR → MIR → LIR pipeline output).
427    /// Default re-enters [`Self::compile`] — backends should override
428    /// when they can reuse the embedded buffer plan.
429    fn compile_lir(&self, lir: LirModule, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
430        self.compile(lir.into_graph(), options)
431    }
432
433    /// HIR-first compile: lower blocks, run fusion pipeline, emit executable.
434    fn compile_hir(
435        &self,
436        hir: HirModule,
437        device: rlx_driver::Device,
438        options: &CompileOptions,
439    ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
440        let result = crate::stages::compile_hir_stages(device, hir, options)?;
441        crate::stages::maybe_log_fusion(&result.fusion);
442        Ok(self.compile_lir(result.lir, options))
443    }
444
445    /// [`GraphModule`] compile — unified HIR/MIR/LIR entry.
446    fn compile_module(
447        &self,
448        module: rlx_ir::GraphModule,
449        device: rlx_driver::Device,
450        options: &CompileOptions,
451    ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
452        let result = crate::stages::compile_module_stages(device, module, options)?;
453        crate::stages::maybe_log_fusion(&result.fusion);
454        Ok(self.compile_lir(result.lir, options))
455    }
456
457    /// PLAN L4: declare which `OpKind`s this backend can lower.
458    /// Default: empty slice = "no claim made — accept everything"
459    /// (preserves existing behavior; backends opt in by overriding).
460    /// When non-empty, the `LegalizeForBackend` pass will refuse to
461    /// compile a graph that contains an op outside this set, instead
462    /// of silently falling through to slower / wrong dispatch.
463    fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
464        &[]
465    }
466}
467
468/// Prepare a fused MIR graph from LIR for backend executable construction.
469/// Skips the fusion pipeline — LIR must come from `compile_*_stages`.
470#[allow(dead_code)]
471fn prepare_fused_graph(
472    graph: Graph,
473    options: &CompileOptions,
474    supported_ops: &[rlx_ir::OpKind],
475    backend_name: &str,
476) -> Graph {
477    let (mut graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
478        graph,
479        backend_name,
480        supported_ops,
481        options.kernel_dispatch,
482    );
483    rlx_opt::maybe_log_dispatch_report(&report);
484    if !report.compile_ready {
485        panic!(
486            "{}\n{}",
487            rlx_opt::format_legalize_error(backend_name, &report.still_unsupported),
488            rlx_opt::format_dispatch_report(&report)
489        );
490    }
491    graph = crate::precompile::post_fusion_cleanup(graph, options);
492    if let Some(p) = options.policy.clone() {
493        use rlx_opt::pass::Pass as _;
494        graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
495    }
496    graph
497}
498
499#[allow(dead_code)]
500fn declared_output_dtypes(
501    manifest: &cpu_low_precision::IoDtypeManifest,
502    exec_dtypes: Vec<rlx_ir::DType>,
503) -> Vec<rlx_ir::DType> {
504    exec_dtypes
505        .into_iter()
506        .enumerate()
507        .map(|(i, exec)| manifest.output_dtype(i, exec))
508        .collect()
509}
510
// ── Convenience helpers preserved from the older API ──────────────────
//
// These keep existing call sites working unchanged now that the
// `Backend` trait is the canonical API. They are free functions rather
// than trait methods so that adding them does not grow the trait surface.
516
517/// Compile at default options (F32, no policy).
518pub fn compile(backend: &dyn Backend, graph: Graph) -> Box<dyn ExecutableGraph> {
519    backend.compile(graph, &CompileOptions::default())
520}
521
522/// Compile HIR through the fusion-first pipeline.
523pub fn compile_hir(
524    backend: &dyn Backend,
525    hir: HirModule,
526    device: rlx_driver::Device,
527    options: &CompileOptions,
528) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
529    backend.compile_hir(hir, device, options)
530}
531
532/// Compile a [`GraphModule`] through the fusion-first pipeline.
533pub fn compile_module(
534    backend: &dyn Backend,
535    module: rlx_ir::GraphModule,
536    device: rlx_driver::Device,
537    options: &CompileOptions,
538) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
539    backend.compile_module(module, device, options)
540}
541
542/// Compile at a specific precision (default policy = none).
543pub fn compile_with_precision(
544    backend: &dyn Backend,
545    graph: Graph,
546    precision: crate::Precision,
547) -> Box<dyn ExecutableGraph> {
548    backend.compile(graph, &CompileOptions::new().precision(precision))
549}
550
551/// Helper retained for backward compatibility — applies the precision
552/// rewrite at the runtime layer if backends don't override their
553/// pipeline placement. Modern code: pass the policy via CompileOptions
554/// and let the backend handle ordering.
555fn _legacy_apply_policy(graph: Graph, policy: Option<rlx_opt::PrecisionPolicy>) -> Graph {
556    use rlx_opt::pass::Pass as _;
557    match policy {
558        Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
559        None => graph,
560    }
561}
562
// ── CPU Backend ─────────────────────────────────────────────────────────
564
565#[cfg(feature = "cpu")]
566pub mod cpu_backend {
567    use super::*;
568    use rlx_cpu::{arena::Arena, thunk};
569    use rlx_ir::{DType, NodeId, Op};
570    use rlx_opt::memory::{self, MemoryPlan};
571    // Arena typed read/write helpers live in `crate::arena` so every
572    // backend (CPU, Metal, future CUDA/wgpu/WASM) shares one implementation.
573    use rlx_driver::arena::{read_typed_to_f32, write_typed_from_f32};
574
575    pub struct CpuBackend;
576
577    /// PLAN L4: ops the CPU backend can lower today. Includes
578    /// DotGeneral (lowered via `LowerDotGeneral` pass) and
579    /// ElementwiseRegion (lowered natively per L2). Excludes
580    /// FusedTransformerLayer / If / While — those have IR variants
581    /// but no CPU lowering yet (see `compile_thunks` arm absence +
582    /// `subgraph.rs` "If/While executor wiring is pending" note).
583    const CPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
584        use rlx_ir::OpKind::*;
585        &[
586            Input,
587            Param,
588            Constant,
589            Activation,
590            Cast,
591            StopGradient,
592            Binary,
593            Compare,
594            Where,
595            ElementwiseRegion,
596            MatMul,
597            DotGeneral,
598            DenseSolve,
599            BatchedDenseSolve,
600            Scan,
601            ScanBackward,
602            ScanBackwardXs,
603            LayerNorm,
604            LayerNorm2d,
605            GroupNorm,
606            BatchNormInference,
607            RmsNorm,
608            ResizeNearest2x,
609            AxialRope2d,
610            Attention,
611            Rope,
612            Reshape,
613            Transpose,
614            Narrow,
615            Concat,
616            Expand,
617            Gather,
618            Reduce,
619            Softmax,
620            Cumsum,
621            ArgMax,
622            ArgMin,
623            TopK,
624            Sample,
625            RngNormal,
626            RngUniform,
627            Conv,
628            Im2Col,
629            ConvTranspose2d,
630            Pool,
631            GroupedMatMul,
632            DequantGroupedMatMul,
633            DequantMoEWeights,
634            ScatterAdd,
635            LoraMatMul,
636            DequantMatMul,
637            SelectiveScan,
638            GatedDeltaNet,
639            Lstm,
640            FusedSwiGLU,
641            FusedMatMulBiasAct,
642            FusedResidualLN,
643            FusedResidualRmsNorm,
644            FusedAttentionBlock,
645            // Backward ops emitted by `rlx_opt::autodiff::grad_with_loss`.
646            // Their thunks live in rlx-cpu/src/thunk.rs alongside the
647            // forward kernels; without these entries the legalize step
648            // below would reject any compiled gradient graph.
649            ReluBackward,
650            ActivationBackward,
651            FakeQuantize,
652            FakeQuantizeBackward,
653            MaxPool2dBackward,
654            Conv2dBackwardInput,
655            Conv2dBackwardWeight,
656            SoftmaxCrossEntropyWithLogits,
657            SoftmaxCrossEntropyBackward,
658            AttentionBackward,
659            LayerNormBackwardInput,
660            LayerNormBackwardGamma,
661            BatchNormInferenceBackwardInput,
662            BatchNormInferenceBackwardGamma,
663            BatchNormInferenceBackwardBeta,
664            RmsNormBackwardInput,
665            RmsNormBackwardGamma,
666            RmsNormBackwardBeta,
667            RopeBackward,
668            CumsumBackward,
669            GatherBackward,
670            // 3D Gaussian splat CPU reference render/backward (requires `rlx-cpu/splat`).
671            GaussianSplatRender,
672            GaussianSplatRenderBackward,
673            GaussianSplatPrepare,
674            GaussianSplatRasterize,
675            // User-registered custom ops dispatched through
676            // `rlx_cpu::op_registry`. Lowering panics with a clear
677            // message if the named CPU kernel isn't registered.
678            Custom,
679            // User-defined sub-graph with optional override AD rules
680            // (JAX-shaped custom_vjp / custom_jvp). Body is a regular
681            // Graph compiled recursively in compile_thunks.
682            CustomFn,
683            // FFT primitive (1D last-axis, 2N real-block layout, f64
684            // power-of-2 sizes). Other backends panic at lowering;
685            // pin FFT-containing graphs to Device::Cpu for now.
686            Fft,
687            FftButterflyStage,
688            LogMel,
689            LogMelBackward,
690            WelchPeaks,
691            // C64 Wirtinger AD surface. ComplexNormSq is the canonical
692            // real-valued loss for complex inputs; Conjugate is emitted
693            // by the new Wirtinger VJP rules for BinaryOp::Mul/Div on
694            // C64. Both have CPU thunks in rlx-cpu.
695            ComplexNormSq,
696            ComplexNormSqBackward,
697            Conjugate,
698        ]
699    };
700
701    impl Backend for CpuBackend {
702        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
703            CPU_SUPPORTED_OPS
704        }
705
706        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
707            use rlx_opt::pass::Pass as _;
708            static ONNX_KERNELS: std::sync::Once = std::sync::Once::new();
709            ONNX_KERNELS.call_once(rlx_cpu::onnx_ref::register_onnx_reference_kernels);
710            // Lower Op::If / Op::While to primitives BEFORE legalize
711            // so the supported-op check doesn't reject them — the CPU
712            // backend has no native sub-graph executor; this rewrite
713            // makes If/While invisible to the rest of the pipeline.
714            // No-op when neither op is in the graph.
715            let graph = rlx_opt::LowerControlFlow.run(graph);
716            // PLAN L4: legalize against the backend's claimed op set
717            // BEFORE running fusion (so the diagnostic points at the
718            // user's IR, not at a fused-away node).
719            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CPU_SUPPORTED_OPS) {
720                panic!("{}", rlx_opt::format_legalize_error("cpu", &errors));
721            }
722            let policy = options.policy.clone();
723            let _precision = options.precision;
724            let cfg = rlx_cpu::config::RuntimeConfig::global();
725
726            let graph = crate::precompile::precompile_cleanup(graph, options);
727
728            // Run fusion pipeline (HIR/MIR/LIR ideology — fusion is first-class).
729            let mut compile_opts = options.clone();
730            compile_opts.arena_alignment = cfg.arena_alignment;
731            let compile_result = crate::stages::compile_graph_stages_for_backend(
732                rlx_driver::Device::Cpu,
733                graph,
734                &compile_opts,
735                CPU_SUPPORTED_OPS,
736            );
737            crate::stages::maybe_log_fusion(&compile_result.fusion);
738            let fused = compile_result.lir.into_graph();
739
740            // Apply precision policy AFTER fusion — Cast nodes don't disrupt
741            // the now-flattened fused ops.
742            let fused = match policy {
743                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(fused),
744                None => fused,
745            };
746
747            let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&fused);
748            let exec_graph = if cpu_low_precision::needs_f32_exec(&fused) {
749                cpu_low_precision::promote_to_f32(fused)
750            } else {
751                fused
752            };
753
754            // Re-plan after precision rewrites (may change dtypes / sizes).
755            let plan = memory::plan_memory_aligned(&exec_graph, cfg.arena_alignment);
756            if cfg.verbose >= 1 {
757                eprintln!(
758                    "[rlx] arena: {} bytes, {} buffers, alignment: {}",
759                    plan.arena_size,
760                    plan.assignments.len(),
761                    cfg.arena_alignment
762                );
763            }
764            Box::new(build_cpu_executable(
765                exec_graph,
766                plan,
767                io_manifest,
768                options.rng,
769            ))
770        }
771
772        fn compile_lir(
773            &self,
774            lir: LirModule,
775            options: &CompileOptions,
776        ) -> Box<dyn ExecutableGraph> {
777            let alignment = lir.buffers.alignment.max(options.arena_alignment);
778            let mut graph = lir.into_graph();
779            {
780                use rlx_opt::pass::Pass as _;
781                graph = rlx_opt::LegalizeBroadcast.run(graph);
782            }
783            if let Some(p) = options.policy.clone() {
784                use rlx_opt::pass::Pass;
785                graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
786            }
787            let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&graph);
788            let promote = cpu_low_precision::needs_f32_exec(&graph);
789            let exec_graph = if promote {
790                cpu_low_precision::promote_to_f32(graph)
791            } else {
792                graph
793            };
794            // LegalizeBroadcast may insert Expand nodes — must replan; the
795            // embedded LIR buffer map is from before legalization.
796            let plan = memory::plan_memory_aligned(&exec_graph, alignment);
797            let cfg = rlx_cpu::config::RuntimeConfig::global();
798            if cfg.verbose >= 1 {
799                eprintln!(
800                    "[rlx] compile_lir: arena {} bytes ({} buffers, alignment {})",
801                    plan.arena_size,
802                    plan.assignments.len(),
803                    alignment,
804                );
805            }
806            Box::new(build_cpu_executable(
807                exec_graph,
808                plan,
809                io_manifest,
810                options.rng,
811            ))
812        }
813    }
814
815    fn build_cpu_executable(
816        graph: Graph,
817        plan: MemoryPlan,
818        io_manifest: cpu_low_precision::IoDtypeManifest,
819        rng: rlx_ir::RngOptions,
820    ) -> CpuExecutable {
821        let mut arena = Arena::from_plan(plan);
822        let mut input_ids = HashMap::new();
823        let mut param_ids = HashMap::new();
824        let mut node_dtypes: HashMap<NodeId, DType> = HashMap::new();
825        for node in graph.nodes() {
826            node_dtypes.insert(node.id, node.shape.dtype());
827            match &node.op {
828                Op::Input { name } => {
829                    input_ids.insert(name.clone(), node.id);
830                }
831                Op::Param { name } => {
832                    param_ids.insert(name.clone(), node.id);
833                }
834                _ => {}
835            }
836        }
837
838        let schedule = thunk::compile_thunks_with_rng(&graph, &arena, rng);
839
840        let mut input_slots = Vec::new();
841        for node in graph.nodes() {
842            if let Op::Input { name } = &node.op {
843                let off = arena.byte_offset(node.id);
844                let len = node.shape.num_elements().unwrap_or(0);
845                input_slots.push((name.clone(), off, len, node.shape.dtype()));
846            }
847        }
848
849        let output_slots: Vec<(usize, usize)> = graph
850            .outputs
851            .iter()
852            .map(|&id| {
853                let off = arena.byte_offset(id);
854                let len = graph.node(id).shape.num_elements().unwrap_or(0);
855                (off, len)
856            })
857            .collect();
858
859        for node in graph.nodes() {
860            if let Op::Constant { data } = &node.op
861                && arena.has_buffer(node.id)
862                && !data.is_empty()
863            {
864                match node.shape.dtype() {
865                    // True-width dtypes (their arena slot is sized to the real
866                    // element width, not f32): copy the raw bytes. I64/I32/U32
867                    // constants were previously caught by the f32-reinterpret
868                    // branch below, which read them in 4-byte chunks as f32 —
869                    // corrupting e.g. the VITS sequence-mask `arange` constant
870                    // (i64 [0..T-1]) so the downstream i64 `Compare` read garbage.
871                    DType::F64
872                    | DType::F16
873                    | DType::BF16
874                    | DType::I64
875                    | DType::I32
876                    | DType::U32 => {
877                        let off = arena.byte_offset(node.id);
878                        let buf = arena.raw_buf_mut();
879                        let n = buf.len().saturating_sub(off).min(data.len());
880                        buf[off..off + n].copy_from_slice(&data[..n]);
881                    }
882                    _ => {
883                        let buf = arena.slice_mut(node.id);
884                        let n_floats = data.len() / 4;
885                        let n = buf.len().min(n_floats);
886                        for i in 0..n {
887                            let bytes = [
888                                data[i * 4],
889                                data[i * 4 + 1],
890                                data[i * 4 + 2],
891                                data[i * 4 + 3],
892                            ];
893                            buf[i] = f32::from_le_bytes(bytes);
894                        }
895                    }
896                }
897            }
898        }
899
900        CpuExecutable {
901            graph,
902            arena,
903            params: HashMap::new(),
904            typed_params: HashMap::new(),
905            input_ids,
906            param_ids,
907            node_dtypes,
908            io_manifest,
909            schedule,
910            input_slots,
911            output_slots,
912            handles: HashMap::new(),
913            active_extent: None,
914            moe_resident: None,
915            moe_resident_layers: None,
916            moe_topk_capture: None,
917        }
918    }
919
    /// Compiled CPU graph plus all state needed to execute it repeatedly:
    /// the arena, the thunk schedule, cached params, persistent handles and
    /// MoE hooks. Built by `build_cpu_executable`.
    #[derive(Clone)]
    struct CpuExecutable {
        /// Execution graph (f32-promoted when the compile path required it).
        /// The order of `graph.outputs` defines the order of run results.
        graph: Graph,
        /// Backing storage for every planned buffer: inputs, params,
        /// constants, intermediates and outputs, addressed by byte offset.
        arena: Arena,
        /// f32 params set via `set_param`; re-applied to the arena by
        /// `restore_arena_baseline` before execution.
        params: HashMap<String, Vec<f32>>,
        /// Byte-backed params (`set_param_typed` / `set_param_bytes`).
        typed_params: HashMap<String, (Vec<u8>, DType)>,
        /// `Input` node name → node id.
        input_ids: HashMap<String, NodeId>,
        /// `Param` node name → node id.
        param_ids: HashMap<String, NodeId>,
        /// Per-node arena dtype. Lets set_param/run cast f32 ↔ F16/BF16
        /// when AutoMixedPrecision has rewritten the graph.
        node_dtypes: HashMap<NodeId, DType>,
        /// User-facing boundary dtypes (before f32 promotion for CPU exec).
        io_manifest: cpu_low_precision::IoDtypeManifest,
        /// Pre-compiled thunks. The schedule also carries the RNG state and
        /// the MoE residency / top-k capture hooks mirrored from the fields
        /// below.
        schedule: thunk::ThunkSchedule,
        /// Pre-resolved: ordered list of (input_name, arena_byte_offset, max_elems, dtype)
        input_slots: Vec<(String, usize, usize, DType)>,
        /// Output (byte_offset, num_elements). dtype is in node_dtypes.
        output_slots: Vec<(usize, usize)>,
        /// Persistent buffer handles (KV-cache, optimizer state, etc.).
        /// Lives outside the arena and survives across run() calls.
        /// On run(): if a handle's name matches a graph input, the
        /// handle's data is used as the input. A handle named `out{idx}`
        /// is refreshed from graph output `idx` after run().
        handles: HashMap<String, Vec<f32>>,
        /// Active-extent hint (`Some((actual, upper))`) for L1 bucketed
        /// dispatch. When set AND every thunk in the schedule is in
        /// `Thunk::safe_for_active_extent`, the executor processes only
        /// `actual / upper` of each kernel's work. Otherwise (or when
        /// `None`) runs at the full compiled extent. See PLAN L1.
        active_extent: Option<(usize, usize)>,
        /// Single expert-residency mask (`set_moe_resident_experts`).
        /// Cleared when per-layer masks are set, and vice versa.
        moe_resident: Option<std::sync::Arc<[bool]>>,
        /// Per-layer expert-residency masks
        /// (`set_moe_resident_experts_per_layer`).
        moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
        /// Top-k routing capture installed by `enable_moe_topk_capture`.
        moe_topk_capture: Option<std::sync::Arc<rlx_cpu::moe_topk_capture::MoeTopkCapture>>,
    }
954
    // SAFETY: NOTE(review): `Send` is asserted by hand, presumably because
    // some field (likely inside `Arena` or `thunk::ThunkSchedule`) is not
    // auto-`Send` — e.g. raw pointers into the arena. Confirm those pointers
    // only reference memory owned by this value, and that the derived
    // `Clone` does not leave two executables sharing such pointers.
    unsafe impl Send for CpuExecutable {}
956
957    impl CpuExecutable {
958        /// Write a f32 input slice into the arena, casting to the node's dtype.
959        fn write_input(&mut self, id: NodeId, data: &[f32]) {
960            let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
961            let off = self.arena.byte_offset(id);
962            let buf = self.arena.raw_buf_mut();
963            let elem_size = dtype.size_bytes();
964            let max_elems = (buf.len() - off) / elem_size;
965            unsafe {
966                write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
967            }
968        }
969
970        /// Read a node's arena bytes back as Vec<f32>, casting from its dtype.
971        fn read_output(&self, id: NodeId) -> Vec<f32> {
972            let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
973            let off = self.arena.byte_offset(id);
974            let buf = self.arena.raw_buf();
975            let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
976            unsafe { read_typed_to_f32(buf.as_ptr().add(off), dtype, n_elems) }
977        }
978    }
979
980    impl ExecutableGraph for CpuExecutable {
981        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
982            Box::new(self.clone())
983        }
984        fn set_param(&mut self, name: &str, data: &[f32]) {
985            self.params.insert(name.to_string(), data.to_vec());
986            self.typed_params.remove(name);
987            // Write directly into the arena — zero per-call lookup for params.
988            // Cast f32 → arena dtype when the param has been rewritten to F16/BF16.
989            if let Some(&id) = self.param_ids.get(name)
990                && self.arena.has_buffer(id)
991            {
992                let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
993                let off = self.arena.byte_offset(id);
994                let buf = self.arena.raw_buf_mut();
995                let elem_size = dtype.size_bytes();
996                let max_elems = (buf.len() - off) / elem_size;
997                unsafe {
998                    write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
999                }
1000            }
1001        }
1002
1003        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1004            self.restore_arena_baseline();
1005            // 1. Apply persistent handles first — they act like default inputs.
1006            //    Explicit `inputs` passed to run() override matching handle names.
1007            let handle_names: Vec<String> = self.handles.keys().cloned().collect();
1008            for name in &handle_names {
1009                if let Some(&id) = self.input_ids.get(name)
1010                    && self.arena.has_buffer(id)
1011                {
1012                    let data = self.handles.get(name).cloned().unwrap_or_default();
1013                    self.write_input(id, &data);
1014                }
1015            }
1016            // 2. Explicit per-call inputs override handles.
1017            for &(name, data) in inputs {
1018                if let Some(&id) = self.input_ids.get(name)
1019                    && self.arena.has_buffer(id)
1020                {
1021                    self.write_input(id, data);
1022                }
1023            }
1024
1025            // Active-extent fast-path (PLAN L1): if hinted AND every thunk
1026            // in the schedule supports it, run scaled. Otherwise fall back
1027            // to full-extent dispatch — preserves correctness when the
1028            // schedule contains a thunk that hasn't yet been wired in.
1029            let active_used = if let Some((actual, upper)) = self.active_extent {
1030                thunk::execute_thunks_active(
1031                    &self.schedule,
1032                    self.arena.raw_buf_mut(),
1033                    actual,
1034                    upper,
1035                )
1036            } else {
1037                false
1038            };
1039            if !active_used {
1040                // Execute via pre-compiled thunks (zero per-node dispatch overhead)
1041                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1042            }
1043
1044            // 3. Sync any handle whose name matches a graph OUTPUT —
1045            //    KV-cache pattern: outputs flow back into the same-named
1046            //    handle for the next iteration.
1047            for (idx, &out_id) in self.graph.outputs.iter().enumerate() {
1048                let name = format!("out{idx}");
1049                if self.handles.contains_key(&name) {
1050                    let v = self.read_output(out_id);
1051                    self.handles.insert(name, v);
1052                }
1053            }
1054
1055            self.graph
1056                .outputs
1057                .iter()
1058                .map(|&out_id| self.read_output(out_id))
1059                .collect()
1060        }
1061
1062        fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
1063            self.restore_arena_baseline();
1064            // Copy inputs by name (HashMap lookup), casting to arena dtype.
1065            for &(name, data) in inputs {
1066                if let Some(&id) = self.input_ids.get(name)
1067                    && self.arena.has_buffer(id)
1068                {
1069                    self.write_input(id, data);
1070                }
1071            }
1072            thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1073            // Note: pointers are raw arena bytes — for F16 outputs, callers
1074            // must read 2 bytes/elem, not 4. run() is the safe path for
1075            // mixed precision; run_raw() is only meaningful for F32.
1076            self.graph
1077                .outputs
1078                .iter()
1079                .map(|&out_id| {
1080                    let (ptr, len) = self.arena.raw_ptr(out_id);
1081                    (ptr as *const f32, len)
1082                })
1083                .collect()
1084        }
1085
1086        /// Fastest path: inputs by index (matching input_slots order), zero-copy output.
1087        /// No HashMap, no name matching, no Vec allocation. Casts f32 input
1088        /// to F16/BF16 if the input slot's dtype was rewritten.
1089        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1090            self.restore_arena_baseline();
1091            let buf = self.arena.raw_buf_mut();
1092            for (i, &data) in inputs.iter().enumerate() {
1093                if i < self.input_slots.len() {
1094                    let (_, off, max_len, dtype) = &self.input_slots[i];
1095                    unsafe {
1096                        write_typed_from_f32(buf.as_mut_ptr().add(*off), *dtype, data, *max_len);
1097                    }
1098                }
1099            }
1100            thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1101            &self.output_slots
1102        }
1103
1104        fn arena_ptr(&self) -> *const u8 {
1105            self.arena.raw_buf_mut_ptr()
1106        }
1107
1108        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1109            // Persistent buffer: stored separately from arena, survives run().
1110            // If the name matches a graph input, run() will use this data
1111            // as the input. If the graph also writes back to this name (via
1112            // an output binding pattern), read_handle returns the latest.
1113            self.handles.insert(name.to_string(), data.to_vec());
1114            true
1115        }
1116
1117        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1118            self.handles.get(name).cloned()
1119        }
1120
1121        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1122            self.active_extent = extent;
1123        }
1124
1125        fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
1126            *self.schedule.rng.write().unwrap() = rng;
1127        }
1128
1129        fn rng(&self) -> rlx_ir::RngOptions {
1130            *self.schedule.rng.read().unwrap()
1131        }
1132
1133        fn set_moe_resident_experts(&mut self, mask: &[bool]) {
1134            self.moe_resident_layers = None;
1135            self.schedule.moe_resident_layers = None;
1136            self.moe_resident = Some(Arc::from(mask));
1137            self.schedule.moe_resident = self.moe_resident.clone();
1138        }
1139
1140        fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
1141            self.moe_resident = None;
1142            self.schedule.moe_resident = None;
1143            let layers: Vec<Arc<[bool]>> = masks.iter().map(|m| Arc::from(*m)).collect();
1144            let arc = Arc::new(layers);
1145            self.moe_resident_layers = Some(arc.clone());
1146            self.schedule.moe_resident_layers = Some(arc);
1147        }
1148
1149        fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
1150            let cap = rlx_cpu::moe_topk_capture::MoeTopkCapture::new(num_experts);
1151            self.moe_topk_capture = Some(cap.clone());
1152            self.schedule.moe_topk_capture = Some(cap);
1153            true
1154        }
1155
1156        fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
1157            let cap = self.moe_topk_capture.as_ref()?;
1158            let layers = cap.take_layers();
1159            if layers.is_empty() {
1160                None
1161            } else {
1162                Some(layers)
1163            }
1164        }
1165
1166        fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
1167            rlx_cpu::moe_residency::take_last_forward_stats()
1168        }
1169
1170        /// Typed param upload. F32 / F16 / BF16 go through the existing
1171        /// widen-to-f32 path (the CPU arena is historically f32 with
1172        /// optional half-precision rewrite). F64 (and any future
1173        /// non-widenable dtype) lands directly in the arena as bytes —
1174        /// the f32 path would lose precision.
1175        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1176            if matches!(dtype, DType::F64 | DType::I64 | DType::I32 | DType::U32) {
1177                self.set_param_bytes(name, data, dtype);
1178                return;
1179            }
1180            // U8 / I8 raw byte tensors: opaque storage for the GGUF
1181            // K-quant `Op::DequantMatMul` path (weights stay packed
1182            // in the arena). One arena byte = one element.
1183            if matches!(dtype, DType::U8 | DType::I8) {
1184                self.set_param_bytes(name, data, dtype);
1185                return;
1186            }
1187            if dtype == DType::F32 {
1188                let n = data.len() / 4;
1189                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1190                self.set_param(name, s);
1191            } else {
1192                let f32_buf = super::widen_bytes_to_f32(data, dtype);
1193                self.set_param(name, &f32_buf);
1194            }
1195        }
1196
1197        /// Typed run with mixed-dtype inputs/outputs.
1198        ///
1199        /// For each input: if its declared graph dtype matches the
1200        /// caller's bytes, we write directly into the arena (zero
1201        /// precision loss — F64 stays F64). For F32 with a half-precision
1202        /// arena rewrite, we widen as before. F16/BF16 callers go
1203        /// through the existing widen path.
1204        ///
1205        /// Outputs are read straight from the arena in the graph node's
1206        /// declared dtype — F64 outputs come back as 8 bytes/element,
1207        /// F32 as 4, etc.
1208        fn run_typed(
1209            &mut self,
1210            inputs: &[(&str, &[u8], rlx_ir::DType)],
1211        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1212            // Decide: are *all* inputs F64? If so, use the direct-byte
1213            // path for everything and skip the f32 widening machinery
1214            // entirely. Mixed dtype graphs (F32 + F64) take the
1215            // per-input dispatch route below.
1216            let all_f64 = !inputs.is_empty() && inputs.iter().all(|(_, _, dt)| *dt == DType::F64);
1217
1218            if all_f64 {
1219                for (name, data, _) in inputs {
1220                    if let Some(&id) = self.input_ids.get(*name) {
1221                        if !self.arena.has_buffer(id) {
1222                            continue;
1223                        }
1224                        let off = self.arena.byte_offset(id);
1225                        let buf = self.arena.raw_buf_mut();
1226                        let n = data.len();
1227                        debug_assert!(
1228                            off + n <= buf.len(),
1229                            "run_typed: input '{name}' overflows arena slot"
1230                        );
1231                        buf[off..off + n].copy_from_slice(data);
1232                    }
1233                }
1234                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1235            } else {
1236                // Mixed-dtype path: dtypes that survive untouched
1237                // through the f32-aliased arena (F64, I32, I64, U32)
1238                // go in as bytes; F32 and the half-precision family
1239                // route through widen-to-f32 + run.
1240                let mut f32_owned: Vec<(String, Vec<f32>)> = Vec::new();
1241                for (name, data, dt) in inputs {
1242                    let direct = matches!(
1243                        *dt,
1244                        DType::F64 | DType::I32 | DType::I64 | DType::U32 | DType::C64
1245                    );
1246                    if direct {
1247                        if let Some(&id) = self.input_ids.get(*name) {
1248                            if !self.arena.has_buffer(id) {
1249                                continue;
1250                            }
1251                            let off = self.arena.byte_offset(id);
1252                            let buf = self.arena.raw_buf_mut();
1253                            buf[off..off + data.len()].copy_from_slice(data);
1254                        }
1255                    } else {
1256                        let v = super::widen_bytes_to_f32(data, *dt);
1257                        f32_owned.push((name.to_string(), v));
1258                    }
1259                }
1260                for (name, data) in &f32_owned {
1261                    if let Some(&id) = self.input_ids.get(name.as_str()) {
1262                        if self.arena.has_buffer(id) {
1263                            self.write_input(id, data);
1264                        }
1265                    }
1266                }
1267                let active_used = if let Some((actual, upper)) = self.active_extent {
1268                    thunk::execute_thunks_active(
1269                        &self.schedule,
1270                        self.arena.raw_buf_mut(),
1271                        actual,
1272                        upper,
1273                    )
1274                } else {
1275                    false
1276                };
1277                if !active_used {
1278                    thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1279                }
1280            }
1281
1282            // Read outputs in declared boundary dtypes.
1283            self.graph
1284                .outputs
1285                .iter()
1286                .enumerate()
1287                .map(|(idx, &id)| {
1288                    let exec_dtype = self.graph.node(id).shape.dtype();
1289                    let declared = self.io_manifest.output_dtype(idx, exec_dtype);
1290                    if matches!(
1291                        exec_dtype,
1292                        DType::F64
1293                            | DType::F16
1294                            | DType::BF16
1295                            | DType::I32
1296                            | DType::I64
1297                            | DType::U32
1298                            | DType::C64
1299                    ) {
1300                        let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
1301                        let n_bytes = n_elems * exec_dtype.size_bytes();
1302                        let off = self.arena.byte_offset(id);
1303                        let bytes = self.arena.raw_buf()[off..off + n_bytes].to_vec();
1304                        return (bytes, declared);
1305                    }
1306                    let f32_vals = self.read_output(id);
1307                    if declared != exec_dtype {
1308                        return (super::narrow_f32_to_bytes(&f32_vals, declared), declared);
1309                    }
1310                    let bytes = f32_vals.iter().flat_map(|v| v.to_le_bytes()).collect();
1311                    (bytes, declared)
1312                })
1313                .collect()
1314        }
1315    }
1316
1317    impl CpuExecutable {
1318        /// Clear ephemeral arena slots, then restore compile-time constants
1319        /// and cached params. Intermediate buffers are reused across `run()`
1320        /// calls; without this reset, a second execution can read stale data
1321        /// from the previous pass.
1322        fn restore_arena_baseline(&mut self) {
1323            self.arena.raw_buf_mut().fill(0);
1324            let constants: Vec<(NodeId, DType, Vec<u8>)> = self
1325                .graph
1326                .nodes()
1327                .iter()
1328                .filter_map(|node| {
1329                    if let Op::Constant { data } = &node.op
1330                        && self.arena.has_buffer(node.id)
1331                        && !data.is_empty()
1332                    {
1333                        Some((node.id, node.shape.dtype(), data.clone()))
1334                    } else {
1335                        None
1336                    }
1337                })
1338                .collect();
1339            for (id, dtype, data) in constants {
1340                self.write_constant_to_arena(id, dtype, &data);
1341            }
1342            let params = self.params.clone();
1343            for (name, data) in params {
1344                if let Some(&id) = self.param_ids.get(&name)
1345                    && self.arena.has_buffer(id)
1346                {
1347                    let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
1348                    let off = self.arena.byte_offset(id);
1349                    let buf = self.arena.raw_buf_mut();
1350                    let elem_size = dtype.size_bytes();
1351                    let max_elems = (buf.len() - off) / elem_size;
1352                    unsafe {
1353                        write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, &data, max_elems);
1354                    }
1355                }
1356            }
1357            let typed = self.typed_params.clone();
1358            for (name, (data, dtype)) in typed {
1359                self.write_param_bytes_to_arena(&name, &data);
1360                let _ = dtype;
1361            }
1362        }
1363
1364        fn write_constant_to_arena(&mut self, id: NodeId, dtype: DType, data: &[u8]) {
1365            match dtype {
1366                DType::F64 | DType::F16 | DType::BF16 | DType::U8 | DType::I8 => {
1367                    let off = self.arena.byte_offset(id);
1368                    let buf = self.arena.raw_buf_mut();
1369                    let n = buf.len().saturating_sub(off).min(data.len());
1370                    buf[off..off + n].copy_from_slice(&data[..n]);
1371                }
1372                _ => {
1373                    let buf = self.arena.slice_mut(id);
1374                    let n_floats = data.len() / 4;
1375                    let n = buf.len().min(n_floats);
1376                    for i in 0..n {
1377                        let bytes = [
1378                            data[i * 4],
1379                            data[i * 4 + 1],
1380                            data[i * 4 + 2],
1381                            data[i * 4 + 3],
1382                        ];
1383                        buf[i] = f32::from_le_bytes(bytes);
1384                    }
1385                }
1386            }
1387        }
1388
1389        /// Direct-byte param upload — copies caller's bytes into the
1390        /// arena slot for the named param without any dtype conversion.
1391        /// Used by `set_param_typed` for dtypes that f32-widening would
1392        /// corrupt (F64). Caller is responsible for matching the param's
1393        /// declared graph dtype.
1394        fn set_param_bytes(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1395            self.typed_params
1396                .insert(name.to_string(), (data.to_vec(), dtype));
1397            self.params.remove(name);
1398            self.write_param_bytes_to_arena(name, data);
1399        }
1400
1401        fn write_param_bytes_to_arena(&mut self, name: &str, data: &[u8]) {
1402            if let Some(&id) = self.param_ids.get(name)
1403                && self.arena.has_buffer(id)
1404            {
1405                let off = self.arena.byte_offset(id);
1406                let buf = self.arena.raw_buf_mut();
1407                debug_assert!(
1408                    off + data.len() <= buf.len(),
1409                    "set_param_bytes: '{name}' would overflow arena slot"
1410                );
1411                buf[off..off + data.len()].copy_from_slice(data);
1412            }
1413        }
1414    }
1415}
1416
1417// ── Metal Backend ───────────────────────────────────────────────────────
1418
1419// ── wgpu Backend ────────────────────────────────────────────────────────
1420
1421#[cfg(feature = "gpu")]
1422pub mod wgpu_backend {
1423    use super::*;
1424    use rlx_ir::OpKind;
1425    use rlx_wgpu::backend::WgpuExecutable;
1426
    /// Stateless marker type that implements `Backend` for the wgpu device.
    /// The set of ops it lowers is `WGPU_SUPPORTED_OPS`.
    pub struct WgpuBackend;
1428
    /// PLAN L4: ops the wgpu backend can lower today. The fused
    /// macro-kernels (FAB, FTL, FusedSwiGLU) get decomposed by
    /// `crate::unfuse::unfuse` upstream — they're listed here too so
    /// graphs that already contain them legalize cleanly. Conv1d/3d
    /// and Pool1d/3d are deferred (Conv2d only).
    ///
    /// Returned verbatim by `WgpuBackend::supported_ops`.
    const WGPU_SUPPORTED_OPS: &[OpKind] = &[
        OpKind::Input,
        OpKind::Param,
        OpKind::Constant,
        OpKind::Activation,
        OpKind::Cast,
        OpKind::StopGradient,
        OpKind::Binary,
        OpKind::Compare,
        OpKind::Where,
        OpKind::ElementwiseRegion,
        OpKind::TransformRegion,
        OpKind::BatchElementwiseRegion,
        OpKind::MatMul,
        OpKind::DotGeneral,
        OpKind::LayerNorm,
        OpKind::RmsNorm,
        OpKind::Attention,
        OpKind::AttentionBackward,
        OpKind::RmsNormBackwardInput,
        OpKind::RmsNormBackwardGamma,
        OpKind::RmsNormBackwardBeta,
        // LayerNorm backward family:
        //   * Input  — single workgroup-per-row fused kernel.
        //   * Gamma  — two-dispatch (partial + reduce) that uses a tail
        //              scratch zone in the arena to hold per-chunk
        //              partial sums; the reduce kernel sums them.
        // Both beat the autodiff-decomposed primitive chain.
        OpKind::LayerNormBackwardInput,
        OpKind::LayerNormBackwardGamma,
        OpKind::RopeBackward,
        OpKind::CumsumBackward,
        OpKind::GatherBackward,
        OpKind::Rope,
        OpKind::Reshape,
        OpKind::Transpose,
        OpKind::Narrow,
        OpKind::Concat,
        OpKind::Expand,
        OpKind::Gather,
        OpKind::Reduce,
        OpKind::Softmax,
        OpKind::Cumsum,
        OpKind::TopK,
        OpKind::Sample,
        OpKind::Conv,
        OpKind::Im2Col,
        OpKind::Pool,
        OpKind::GroupedMatMul,
        OpKind::DequantGroupedMatMul,
        OpKind::DequantMoEWeights,
        OpKind::ScatterAdd,
        OpKind::SelectiveScan,
        OpKind::Lstm,
        OpKind::DequantMatMul,
        OpKind::FusedMatMulBiasAct,
        OpKind::FusedResidualLN,
        OpKind::FusedResidualRmsNorm,
        OpKind::FusedSwiGLU,
        OpKind::FusedAttentionBlock,
        OpKind::FusedTransformerLayer,
        // Native FFT (WGSL radix-2): f32 only, power-of-2 N ≤ 1024.
        // Anything outside that envelope panics at lowering with a
        // "pin to Device::Cpu" hint. No host fallback — WGPU has no
        // unified memory, so silent CPU round-trip would be a hidden
        // performance cliff.
        OpKind::Fft,
        OpKind::LogMel,
        OpKind::LogMelBackward,
        OpKind::WelchPeaks,
        // 3D Gaussian splat ops. NOTE(review): the earlier note here read
        // "native Metal / CPU reference per backend", which names Metal and
        // CPU rather than wgpu — confirm how the wgpu lowering handles these.
        OpKind::GaussianSplatRender,
        OpKind::GaussianSplatRenderBackward,
        OpKind::GaussianSplatPrepare,
        OpKind::GaussianSplatRasterize,
        OpKind::Custom,
        OpKind::RngNormal,
        OpKind::RngUniform,
        // LoRA, If, While: not yet wired in wgpu — fail loudly.
    ];
1509        OpKind::Custom,
1510        OpKind::RngNormal,
1511        OpKind::RngUniform,
1512        // LoRA, If, While: not yet wired in wgpu — fail loudly.
1513    ];
1514
1515    impl Backend for WgpuBackend {
1516        fn supported_ops(&self) -> &'static [OpKind] {
1517            WGPU_SUPPORTED_OPS
1518        }
1519
1520        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1521            use rlx_opt::pass::Pass as _;
1522            let graph = rlx_opt::LowerControlFlow.run(graph);
1523            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, WGPU_SUPPORTED_OPS)
1524                .unwrap_or_else(|errors| {
1525                    panic!("{}", rlx_opt::format_legalize_error("wgpu", &errors));
1526                });
1527            let graph = crate::precompile::precompile_cleanup(graph, options);
1528            // Materialize mid-axis broadcasts before MarkElementwiseRegions:
1529            // wgpu Binary/region kernels only handle trailing/scalar broadcast
1530            // via modulus; EEG patch embed uses [1,C,1,D] + [1,C,P,D].
1531            let graph = rlx_opt::LegalizeBroadcast.run(graph);
1532            // ORDER MATTERS: targeted-pattern fusions run BEFORE the
1533            // catch-all `MarkElementwiseRegions`. Otherwise the region
1534            // pass swallows the Add / Activation nodes into chains and
1535            // FuseMatMulBiasAct / FuseResidualLN fail to match the
1536            // narrower patterns they look for. (Metal pipeline at line
1537            // ~377 already orders these correctly; wgpu was inverted
1538            // and silently shipped 13 unfused LayerNorms per BERT
1539            // forward where 12 should have been FusedResidualLN.)
1540            let compile_result = crate::stages::compile_graph_stages_for_backend(
1541                rlx_driver::Device::Gpu,
1542                graph,
1543                options,
1544                WGPU_SUPPORTED_OPS,
1545            );
1546            crate::stages::maybe_log_fusion(&compile_result.fusion);
1547            let graph = compile_result.lir.into_graph();
1548            let graph = match options.policy.clone() {
1549                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
1550                None => graph,
1551            };
1552            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1553            Box::new(WgpuExecutableWrapper {
1554                inner: WgpuExecutable::compile_rng(graph, options.rng),
1555                io_manifest,
1556            })
1557        }
1558
1559        fn compile_lir(
1560            &self,
1561            lir: LirModule,
1562            options: &CompileOptions,
1563        ) -> Box<dyn ExecutableGraph> {
1564            use rlx_opt::pass::Pass as _;
1565            // LIR may already contain fused ElementwiseRegions; legalize
1566            // broadcasts on the unfused graph shape before backend prep.
1567            let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
1568            let graph = prepare_fused_graph(graph, options, WGPU_SUPPORTED_OPS, "wgpu");
1569            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1570            Box::new(WgpuExecutableWrapper {
1571                inner: WgpuExecutable::compile_rng(graph, options.rng),
1572                io_manifest,
1573            })
1574        }
1575    }
1576
1577    struct WgpuExecutableWrapper {
1578        inner: WgpuExecutable,
1579        io_manifest: cpu_low_precision::IoDtypeManifest,
1580    }
1581
1582    unsafe impl Send for WgpuExecutableWrapper {}
1583
1584    impl ExecutableGraph for WgpuExecutableWrapper {
1585        fn set_param(&mut self, name: &str, data: &[f32]) {
1586            self.inner.set_param(name, data);
1587        }
1588        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1589            self.inner.run(inputs)
1590        }
1591        fn run_read_outputs(
1592            &mut self,
1593            inputs: &[(&str, &[f32])],
1594            read_indices: Option<&[usize]>,
1595        ) -> Vec<Vec<f32>> {
1596            self.inner.run_read_outputs(inputs, read_indices)
1597        }
1598        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
1599            self.inner.bind_gpu_handle(name, data)
1600        }
1601        fn has_gpu_handle(&self, name: &str) -> bool {
1602            self.inner.has_gpu_handle(name)
1603        }
1604        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
1605            self.inner.set_gpu_handle_feed(handle_name, output_index);
1606            true
1607        }
1608        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
1609            self.inner.read_gpu_handle(name)
1610        }
1611        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1612            self.inner.set_active_extent(extent);
1613        }
1614
1615        fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
1616            self.inner.set_rng(rng);
1617        }
1618
1619        fn rng(&self) -> rlx_ir::RngOptions {
1620            self.inner.rng()
1621        }
1622
1623        /// Typed param upload: widens F16/BF16 to F32 at the host boundary,
1624        /// since the wgpu arena is f32-uniform.
1625        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1626            match dtype {
1627                rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
1628                    self.inner.set_param_bytes(name, data);
1629                }
1630                rlx_ir::DType::F32 => {
1631                    let n = data.len() / 4;
1632                    let f32_slice =
1633                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1634                    self.inner.set_param(name, f32_slice);
1635                }
1636                rlx_ir::DType::F16 => {
1637                    let n = data.len() / 2;
1638                    let f16_slice =
1639                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
1640                    let f32: Vec<f32> = f16_slice.iter().map(|h| h.to_f32()).collect();
1641                    self.inner.set_param(name, &f32);
1642                }
1643                rlx_ir::DType::BF16 => {
1644                    let n = data.len() / 2;
1645                    let bf16_slice = unsafe {
1646                        std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1647                    };
1648                    let f32: Vec<f32> = bf16_slice.iter().map(|h| h.to_f32()).collect();
1649                    self.inner.set_param(name, &f32);
1650                }
1651                other => panic!(
1652                    "rlx-wgpu set_param_typed: dtype {other:?} unsupported \
1653                                 (F32, F16, BF16 only — wgpu arena is f32-uniform)"
1654                ),
1655            }
1656        }
1657
1658        /// Typed run: widen each typed input to F32, run, then narrow each
1659        /// output back to its declared dtype.
1660        fn run_typed(
1661            &mut self,
1662            inputs: &[(&str, &[u8], rlx_ir::DType)],
1663        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1664            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1665            for (name, data, dt) in inputs {
1666                let v: Vec<f32> = match *dt {
1667                    rlx_ir::DType::F32 => {
1668                        let n = data.len() / 4;
1669                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }
1670                            .to_vec()
1671                    }
1672                    rlx_ir::DType::F16 => {
1673                        let n = data.len() / 2;
1674                        let s = unsafe {
1675                            std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n)
1676                        };
1677                        s.iter().map(|h| h.to_f32()).collect()
1678                    }
1679                    rlx_ir::DType::BF16 => {
1680                        let n = data.len() / 2;
1681                        let s = unsafe {
1682                            std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1683                        };
1684                        s.iter().map(|h| h.to_f32()).collect()
1685                    }
1686                    // Integer/bool inputs (e.g. embedding indices `phone_ids`) are
1687                    // widened to f32, matching the f32-arena convention shared with
1688                    // the CPU backend (Gather etc. operate on f32-encoded indices).
1689                    rlx_ir::DType::I64 => {
1690                        let n = data.len() / 8;
1691                        let s =
1692                            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const i64, n) };
1693                        s.iter().map(|&x| x as f32).collect()
1694                    }
1695                    rlx_ir::DType::I32 => {
1696                        let n = data.len() / 4;
1697                        let s =
1698                            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const i32, n) };
1699                        s.iter().map(|&x| x as f32).collect()
1700                    }
1701                    rlx_ir::DType::U8 | rlx_ir::DType::I8 | rlx_ir::DType::Bool => {
1702                        data.iter().map(|&b| b as f32).collect()
1703                    }
1704                    other => {
1705                        panic!("rlx-wgpu run_typed: input '{name}' dtype {other:?} unsupported")
1706                    }
1707                };
1708                owned.push((name.to_string(), v));
1709            }
1710            let refs: Vec<(&str, &[f32])> = owned
1711                .iter()
1712                .map(|(n, d)| (n.as_str(), d.as_slice()))
1713                .collect();
1714            let dtypes =
1715                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
1716            let outs = self.inner.run(&refs);
1717            outs.into_iter()
1718                .zip(
1719                    dtypes
1720                        .into_iter()
1721                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
1722                )
1723                .map(|(v, dt)| (narrow_to_dtype(&v, dt), dt))
1724                .collect()
1725        }
1726
1727        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
1728            Box::new(WgpuExecutableWrapper {
1729                inner: self.inner.clone_for_cache(),
1730                io_manifest: self.io_manifest.clone(),
1731            })
1732        }
1733    }
1734
1735    /// Cast every element of a wgpu f32 output buffer down to the
1736    /// declared output dtype, returning the corresponding byte stream.
1737    /// The arena keeps every value as f32; declared output dtypes
1738    /// (Bool, I8, I32, F16, ...) require an exit-time narrowing to be
1739    /// byte-identical with backends that store the native dtype.
1740    fn narrow_to_dtype(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
1741        use rlx_ir::DType;
1742        match dt {
1743            DType::F32 => {
1744                let mut bytes = Vec::with_capacity(v.len() * 4);
1745                for &x in v {
1746                    bytes.extend_from_slice(&x.to_le_bytes());
1747                }
1748                bytes
1749            }
1750            DType::F16 => {
1751                let mut bytes = Vec::with_capacity(v.len() * 2);
1752                for &x in v {
1753                    bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
1754                }
1755                bytes
1756            }
1757            DType::BF16 => {
1758                let mut bytes = Vec::with_capacity(v.len() * 2);
1759                for &x in v {
1760                    bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
1761                }
1762                bytes
1763            }
1764            DType::F64 => {
1765                let mut bytes = Vec::with_capacity(v.len() * 8);
1766                for &x in v {
1767                    bytes.extend_from_slice(&(x as f64).to_le_bytes());
1768                }
1769                bytes
1770            }
1771            DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
1772            DType::U8 => v.iter().map(|&x| x as u8).collect(),
1773            DType::I16 => {
1774                let mut bytes = Vec::with_capacity(v.len() * 2);
1775                for &x in v {
1776                    bytes.extend_from_slice(&(x as i16).to_le_bytes());
1777                }
1778                bytes
1779            }
1780            DType::I32 => {
1781                let mut bytes = Vec::with_capacity(v.len() * 4);
1782                for &x in v {
1783                    bytes.extend_from_slice(&(x as i32).to_le_bytes());
1784                }
1785                bytes
1786            }
1787            DType::U32 => {
1788                let mut bytes = Vec::with_capacity(v.len() * 4);
1789                for &x in v {
1790                    bytes.extend_from_slice(&(x as u32).to_le_bytes());
1791                }
1792                bytes
1793            }
1794            DType::I64 => {
1795                let mut bytes = Vec::with_capacity(v.len() * 8);
1796                for &x in v {
1797                    bytes.extend_from_slice(&(x as i64).to_le_bytes());
1798                }
1799                bytes
1800            }
1801            DType::Bool => v
1802                .iter()
1803                .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
1804                .collect(),
1805            // C64 (complex f32 pair) — the wgpu backend's f32 arena
1806            // doesn't synthesize complex outputs today; this branch
1807            // only fires if a graph somehow asks for a C64 output and
1808            // the backend lowered it as 2N real floats. We pass the
1809            // raw f32 stream straight through; downstream code that
1810            // wants complex semantics is responsible for re-pairing.
1811            DType::C64 => {
1812                let mut bytes = Vec::with_capacity(v.len() * 4);
1813                for &x in v {
1814                    bytes.extend_from_slice(&x.to_le_bytes());
1815                }
1816                bytes
1817            }
1818        }
1819    }
1820}
1821
1822// ── MLX Backend ─────────────────────────────────────────────────────────
1823
1824#[cfg(all(feature = "mlx", rlx_mlx_host))]
1825pub mod mlx_backend {
1826    use super::*;
1827    use rlx_mlx::MlxExecutable;
1828
1829    pub struct MlxBackend;
1830
1831    /// PLAN L4: ops the MLX backend can lower today. MLX has the
1832    /// widest IR coverage of any GPU backend — handles everything
1833    /// including If/While via topo unrolling, and lowers
1834    /// ElementwiseRegion natively via the per-step composition in
1835    /// rlx-mlx/src/lower.rs (PLAN L2).
1836    ///
1837    /// `GroupNorm` / `BatchNormInference` are intentionally omitted — lowered
1838    /// to primitives via [`LowerGroupNorm`] / [`LowerBatchNormInference`]
1839    /// before MLX lowering (no native MLX kernel).
1840    const MLX_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
1841        use rlx_ir::OpKind::*;
1842        &[
1843            Input,
1844            Param,
1845            Constant,
1846            Activation,
1847            Cast,
1848            StopGradient,
1849            Binary,
1850            Compare,
1851            Where,
1852            ElementwiseRegion,
1853            TransformRegion,
1854            BatchElementwiseRegion,
1855            MatMul,
1856            DotGeneral,
1857            DenseSolve,
1858            BatchedDenseSolve,
1859            LayerNorm,
1860            LayerNorm2d,
1861            ResizeNearest2x,
1862            RmsNorm,
1863            Attention,
1864            Rope,
1865            Reshape,
1866            Transpose,
1867            Narrow,
1868            Concat,
1869            Expand,
1870            Gather,
1871            Reduce,
1872            Softmax,
1873            Cumsum,
1874            TopK,
1875            RngNormal,
1876            RngUniform,
1877            Sample,
1878            Conv,
1879            ConvTranspose2d,
1880            Pool,
1881            GroupedMatMul,
1882            DequantGroupedMatMul,
1883            DequantMoEWeights,
1884            ScatterAdd,
1885            LoraMatMul,
1886            DequantMatMul,
1887            SelectiveScan,
1888            GatedDeltaNet,
1889            FusedSwiGLU,
1890            FusedMatMulBiasAct,
1891            FusedResidualLN,
1892            FusedResidualRmsNorm,
1893            FusedAttentionBlock,
1894            FusedTransformerLayer,
1895            If,
1896            While,
1897            // Loop-unrolled scan (Op::Scan body is statically unrolled
1898            // `length` times into MLX ops; mirror of Op::While's
1899            // bounded-unroll lowering). ScanBackward is the AD
1900            // companion — handled the same way.
1901            Scan,
1902            ScanBackward,
1903            ScanBackwardXs,
1904            // Tier 1 autodiff backward ops — lowered as primitive
1905            // compositions in `rlx-mlx/src/lower.rs`.
1906            ReluBackward,
1907            ActivationBackward,
1908            SoftmaxCrossEntropyWithLogits,
1909            SoftmaxCrossEntropyBackward,
1910            AttentionBackward,
1911            LayerNormBackwardInput,
1912            LayerNormBackwardGamma,
1913            // Tier 2 — conv backward via `mc::conv_general` with the
1914            // same parameter-mapping MLX uses inside its built-in vjp.
1915            // Currently groups=1 only; grouped conv backward will
1916            // surface as a clear error from `lower.rs`.
1917            Conv2dBackwardInput,
1918            Conv2dBackwardWeight,
1919            // Tier 3 — max-pool backward via slice-strided argmax over
1920            // pool windows + a per-kernel-slot scatter-add, matching
1921            // the CPU thunk's "first-hit-wins" tiebreaking.
1922            MaxPool2dBackward,
1923            // QAT — `FakeQuantize` (PerBatch + Fixed scale modes;
1924            // EMA returns a clear error from `lower.rs`) and the
1925            // `FakeQuantizeBackward` family covering all 4 STE
1926            // variants. Closes the last gap vs `CPU_SUPPORTED_OPS`.
1927            FakeQuantize,
1928            FakeQuantizeBackward,
1929            // User-registered custom ops dispatched through
1930            // `rlx_mlx::op_registry`. Lowering looks up the
1931            // registered `MlxKernel` and calls its `execute` method
1932            // to produce the lazy MLX `Array` for this node.
1933            Custom,
1934            Fft,
1935            LogMel,
1936            LogMelBackward,
1937            WelchPeaks,
1938            GaussianSplatRender,
1939            GaussianSplatRenderBackward,
1940            // Op::Fft on MLX: native `mlx::fft::fft` via rlx_mlx_op_fft shim.
1941            // 2N real-block f32/f64 and complex64 inputs supported.
1942        ]
1943    };
1944
1945    impl Backend for MlxBackend {
1946        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
1947            MLX_SUPPORTED_OPS
1948        }
1949
1950        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1951            let compile_result = crate::stages::compile_graph_stages_for_backend(
1952                rlx_driver::Device::Mlx,
1953                graph,
1954                options,
1955                MLX_SUPPORTED_OPS,
1956            );
1957            crate::stages::maybe_log_fusion(&compile_result.fusion);
1958            self.compile_lir(compile_result.lir, options)
1959        }
1960
1961        fn compile_lir(
1962            &self,
1963            lir: LirModule,
1964            options: &CompileOptions,
1965        ) -> Box<dyn ExecutableGraph> {
1966            use rlx_opt::pass::Pass as _;
1967            let mut graph = lir.into_graph();
1968            graph = rlx_opt::LowerControlFlow.run(graph);
1969            let graph = prepare_fused_graph(graph, options, MLX_SUPPORTED_OPS, "mlx");
1970            Box::new(build_mlx_executable(graph, options.rng))
1971        }
1972    }
1973
1974    fn build_mlx_executable(graph: Graph, rng: rlx_ir::RngOptions) -> MlxExecutableWrapper {
1975        let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1976        let mode = mlx_mode_from_env();
1977        let mut exe = MlxExecutable::compile_from_fused_with_rng(graph, mode, rng);
1978        if mode == rlx_mlx::lower::MlxMode::Compiled {
1979            if let Err(e) = exe.warm_compile() {
1980                eprintln!(
1981                    "[rlx-runtime] MLX warm_compile failed ({e}); first run will pay the trace cost"
1982                );
1983            }
1984        }
1985        MlxExecutableWrapper {
1986            inner: exe,
1987            io_manifest,
1988        }
1989    }
1990
1991    fn mlx_mode_from_env() -> rlx_mlx::lower::MlxMode {
1992        match rlx_ir::env::var("RLX_MLX_MODE").as_deref() {
1993            Some(s) if s.eq_ignore_ascii_case("eager") => rlx_mlx::lower::MlxMode::Eager,
1994            Some(s) if s.eq_ignore_ascii_case("lazy") => rlx_mlx::lower::MlxMode::Lazy,
1995            Some(s) if s.eq_ignore_ascii_case("compiled") => rlx_mlx::lower::MlxMode::Compiled,
1996            _ => rlx_mlx::lower::MlxMode::Compiled,
1997        }
1998    }
1999
2000    struct MlxExecutableWrapper {
2001        inner: MlxExecutable,
2002        io_manifest: cpu_low_precision::IoDtypeManifest,
2003    }
2004
2005    unsafe impl Send for MlxExecutableWrapper {}
2006
2007    impl ExecutableGraph for MlxExecutableWrapper {
2008        fn set_param(&mut self, name: &str, data: &[f32]) {
2009            self.inner.set_param(name, data);
2010        }
2011        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2012            self.inner.run(inputs)
2013        }
2014        fn run_read_outputs(
2015            &mut self,
2016            inputs: &[(&str, &[f32])],
2017            read_indices: Option<&[usize]>,
2018        ) -> Vec<Vec<f32>> {
2019            self.inner
2020                .run_read_outputs(inputs, read_indices)
2021                .unwrap_or_else(|e| panic!("MLX run_read_outputs failed: {e}"))
2022        }
2023        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2024            self.inner.run_slots(inputs)
2025        }
2026        fn arena_ptr(&self) -> *const u8 {
2027            self.inner.arena_ptr()
2028        }
2029        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
2030            self.inner.commit_no_wait(inputs);
2031        }
2032        fn sync_pending(&mut self) {
2033            self.inner.sync_pending();
2034        }
2035        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
2036            self.inner.run_pipelined(input_sets)
2037        }
2038        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
2039            self.inner.bind_handle(name, data)
2040        }
2041        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
2042            self.inner.read_handle(name)
2043        }
2044        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2045            self.inner.bind_gpu_handle(name, data).is_ok()
2046        }
2047        fn has_gpu_handle(&self, name: &str) -> bool {
2048            self.inner.has_gpu_handle(name)
2049        }
2050        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2051            self.inner.set_gpu_handle_feed(handle_name, output_index);
2052            true
2053        }
2054        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2055            self.inner.read_gpu_handle(name).ok()
2056        }
2057        fn run_feed_gpu_handle(
2058            &mut self,
2059            inputs: &[(&str, &[f32])],
2060            handle_name: &str,
2061            output_index: usize,
2062        ) -> Option<Vec<f32>> {
2063            self.inner
2064                .run_feed_gpu(inputs, handle_name, output_index)
2065                .ok()
2066        }
2067        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2068            self.inner.set_param_typed(name, data, dtype);
2069        }
2070        fn run_typed(
2071            &mut self,
2072            inputs: &[(&str, &[u8], rlx_ir::DType)],
2073        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2074            self.inner.run_typed(inputs)
2075        }
2076        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2077            self.inner.set_active_extent(extent);
2078        }
2079
2080        fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
2081            self.inner.set_rng(rng);
2082        }
2083
2084        fn rng(&self) -> rlx_ir::RngOptions {
2085            self.inner.rng()
2086        }
2087
2088        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
2089            Box::new(MlxExecutableWrapper {
2090                inner: self.inner.clone_for_cache(),
2091                io_manifest: self.io_manifest.clone(),
2092            })
2093        }
2094    }
2095}
2096
2097/// Ops with a MIL lowering today (see `rlx_coreml::mil`).
2098///
2099/// Single source of truth, shared by `coreml_backend::CoremlBackend::
2100/// supported_ops` and `device_ext::supports(Device::Ane, ..)` so the two
2101/// never drift. Ungated so the support probe compiles on every target
2102/// (it's only *consulted* when the `coreml` backend is available).
2103pub(crate) const COREML_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
2104    use rlx_ir::OpKind::*;
2105    &[
2106        Input,
2107        Param,
2108        Constant,
2109        Activation,
2110        Cast,
2111        Binary,
2112        MatMul,
2113        LayerNorm,
2114        RmsNorm,
2115        Reduce,
2116        Softmax,
2117        Reshape,
2118        Transpose,
2119        Narrow,
2120        Concat,
2121        Gather,
2122        Rope,
2123        Attention,
2124        Compare,
2125        Where,
2126        Expand,
2127        Cumsum,
2128        ScatterAdd,
2129        BatchNormInference,
2130        GroupNorm,
2131        LayerNorm2d,
2132        LoraMatMul,
2133        Conv,
2134        ConvTranspose2d,
2135        Pool,
2136        TopK,
2137        AxialRope2d,
2138        ResizeNearest2x,
2139        StopGradient,
2140        GroupedMatMul,
2141        DequantMatMul,
2142        DequantMoEWeights,
2143        DequantGroupedMatMul,
2144        Quantize,
2145        Dequantize,
2146        SelectiveScan,
2147        GatedDeltaNet,
2148    ]
2149};
2150
2151/// Apple CoreML / Neural Engine (ANE) backend wiring.
2152///
2153/// Unlike the GPU backends, CoreML compiles a *static* model with weights
2154/// baked in, so we hand the raw graph (Param nodes intact) straight to
2155/// `rlx_coreml` and skip the fusion/LIR arena pipeline. Weights arrive via
2156/// `set_param`; the `.mlpackage` is built + loaded on `finalize_params`
2157/// (or lazily on first `run`).
2158#[cfg(all(feature = "coreml", any(target_os = "macos", target_os = "ios")))]
2159pub mod coreml_backend {
2160    use super::*;
2161    use rlx_coreml::CoremlExecutable;
2162
2163    pub struct CoremlBackend;
2164
2165    impl Backend for CoremlBackend {
2166        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2167            super::COREML_SUPPORTED_OPS
2168        }
2169
2170        fn compile(&self, graph: Graph, _options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2171            // Pass the graph as-is: CoreML bakes Params at build time, so we
2172            // must not let fusion strip them into an arena plan.
2173            Box::new(CoremlExecutableWrapper {
2174                inner: CoremlExecutable::compile(graph),
2175            })
2176        }
2177
2178        fn compile_lir(
2179            &self,
2180            lir: LirModule,
2181            options: &CompileOptions,
2182        ) -> Box<dyn ExecutableGraph> {
2183            // No LIR arena path for CoreML; reconstruct the graph and go
2184            // through the normal compile.
2185            self.compile(lir.into_graph(), options)
2186        }
2187    }
2188
2189    struct CoremlExecutableWrapper {
2190        inner: CoremlExecutable,
2191    }
2192
2193    // The loaded MLModel is owned exclusively and accessed via &mut self.
2194    unsafe impl Send for CoremlExecutableWrapper {}
2195
2196    impl ExecutableGraph for CoremlExecutableWrapper {
2197        fn set_param(&mut self, name: &str, data: &[f32]) {
2198            self.inner.set_param(name, data);
2199        }
2200
2201        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2202            // GGUF-quantized weights arrive as raw bytes; the lowering
2203            // host-dequantizes them when baking the model.
2204            self.inner.set_param_typed(name, data, dtype);
2205        }
2206
2207        fn finalize_params(&mut self) {
2208            self.inner
2209                .finalize()
2210                .unwrap_or_else(|e| panic!("CoreML finalize failed: {e}"));
2211        }
2212
2213        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2214            self.inner
2215                .run(inputs)
2216                .unwrap_or_else(|e| panic!("CoreML run failed: {e}"))
2217        }
2218
2219        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
2220            Box::new(CoremlExecutableWrapper {
2221                inner: self.inner.clone_for_cache(),
2222            })
2223        }
2224
2225        fn run_typed(
2226            &mut self,
2227            inputs: &[(&str, &[u8], rlx_ir::DType)],
2228        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2229            use rlx_ir::DType;
2230            // The CoreML model is promoted to an f32 flow (`promote_int_to_f32`), so
2231            // widen integer/bool inputs (e.g. `phone_ids`) to f32 on the host surface.
2232            let owned: Vec<(String, Vec<f32>)> = inputs
2233                .iter()
2234                .map(|(name, data, dt)| {
2235                    let v: Vec<f32> = match dt {
2236                        DType::I64 => data
2237                            .chunks_exact(8)
2238                            .map(|c| i64::from_le_bytes(c.try_into().unwrap()) as f32)
2239                            .collect(),
2240                        DType::I32 => data
2241                            .chunks_exact(4)
2242                            .map(|c| i32::from_le_bytes(c.try_into().unwrap()) as f32)
2243                            .collect(),
2244                        DType::U8 | DType::Bool => data.iter().map(|&b| b as f32).collect(),
2245                        _ => super::widen_bytes_to_f32(data, *dt),
2246                    };
2247                    (name.to_string(), v)
2248                })
2249                .collect();
2250            let refs: Vec<(&str, &[f32])> = owned
2251                .iter()
2252                .map(|(n, d)| (n.as_str(), d.as_slice()))
2253                .collect();
2254            self.run(&refs)
2255                .into_iter()
2256                .map(|v| {
2257                    let bytes: Vec<u8> = v.iter().flat_map(|f| f.to_le_bytes()).collect();
2258                    (bytes, DType::F32)
2259                })
2260                .collect()
2261        }
2262    }
2263}
2264
2265#[cfg(all(feature = "metal", target_os = "macos"))]
2266pub mod metal_backend {
2267    use super::*;
2268    use rlx_metal::backend::MetalExecutable;
2269
    /// Metal GPU backend. Only compiled with the `metal` feature on macOS;
    /// produces executables backed by `rlx_metal`'s `MetalExecutable`.
    pub struct MetalBackend;
2271
    /// PLAN L4: ops the Metal backend can lower today.
    ///
    /// Returned by `supported_ops`, used as the target set for backend
    /// legalization during compilation, and also passed to
    /// `MetalExecutable` itself.
    ///
    /// Includes DotGeneral (LowerDotGeneral pass) and ElementwiseRegion
    /// (decomposed by UnfuseElementwiseRegions). Excludes SelectiveScan,
    /// LoraMatMul, Sample, FusedAttentionBlock, FusedTransformerLayer,
    /// If and While, which are not yet wired in
    /// `rlx-metal/src/thunk.rs`'s compile_thunks (If/While are rewritten
    /// to primitives by `LowerControlFlow` before legalization).
    /// DequantMatMul (GGUF K-quants) lowers to a GPU dequant kernel
    /// followed by an MPS matmul; legacy Int8 schemes remain CPU-only.
    const METAL_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            Cumsum,
            CumsumBackward,
            GatherBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            TopK,
            RngNormal,
            RngUniform,
            Conv,
            Im2Col,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            GatedDeltaNet,
            Lstm,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            // User-registered custom ops dispatched through
            // `rlx_metal::op_registry`. Lowering panics with a clear
            // message if the named MetalKernel isn't registered;
            // the executor inserts a sync point and runs the host kernel
            // against the unified-memory arena.
            Custom,
            // Op::Fft is supported via the same host-fallback pattern
            // as Custom: sync the GPU, run rlx-cpu's FFT against the
            // unified-memory arena, restart cmd_buf. A native Metal
            // compute kernel will replace this when a workload makes
            // the sync the bottleneck.
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            // Host-fallback splat (unified-memory arena + rlx-cpu/splat).
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
        ]
    };
2365
2366    impl Backend for MetalBackend {
2367        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2368            METAL_SUPPORTED_OPS
2369        }
2370
2371        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2372            use rlx_opt::pass::Pass as _;
2373            // Same If/While → primitive rewrite as the CPU pipeline
2374            // (Metal also has no native sub-graph executor wired
2375            // through its thunk schedule).
2376            let graph = rlx_opt::LowerControlFlow.run(graph);
2377            let dispatch = options.kernel_dispatch;
2378            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2379                graph,
2380                METAL_SUPPORTED_OPS,
2381                dispatch,
2382            )
2383            .unwrap_or_else(|errors| {
2384                panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2385            });
2386            let graph = crate::precompile::precompile_cleanup(graph, options);
2387
2388            // Hand the policy to MetalExecutable so the rewrite runs AFTER
2389            // its internal fusion passes (avoids breaking pattern matchers).
2390            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2391            Box::new(MetalExecutableWrapper {
2392                inner: MetalExecutable::compile_with_policy(
2393                    graph,
2394                    options.policy.clone(),
2395                    Some(METAL_SUPPORTED_OPS),
2396                    options.rng,
2397                ),
2398                io_manifest,
2399            })
2400        }
2401
2402        fn compile_lir(
2403            &self,
2404            lir: LirModule,
2405            options: &CompileOptions,
2406        ) -> Box<dyn ExecutableGraph> {
2407            use rlx_opt::pass::Pass as _;
2408            let mut graph = lir.into_graph();
2409            graph = rlx_opt::LowerControlFlow.run(graph);
2410            let dispatch = options.kernel_dispatch;
2411            let mut graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2412                graph,
2413                METAL_SUPPORTED_OPS,
2414                dispatch,
2415            )
2416            .unwrap_or_else(|errors| {
2417                panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2418            });
2419            graph = crate::precompile::precompile_cleanup(graph, options);
2420            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2421            Box::new(MetalExecutableWrapper {
2422                inner: MetalExecutable::compile_from_fused(
2423                    graph,
2424                    options.policy.clone(),
2425                    Some(METAL_SUPPORTED_OPS),
2426                    options.rng,
2427                ),
2428                io_manifest,
2429            })
2430        }
2431    }
2432
2433    struct MetalExecutableWrapper {
2434        inner: MetalExecutable,
2435        io_manifest: cpu_low_precision::IoDtypeManifest,
2436    }
2437
2438    unsafe impl Send for MetalExecutableWrapper {}
2439
2440    impl ExecutableGraph for MetalExecutableWrapper {
2441        fn set_param(&mut self, name: &str, data: &[f32]) {
2442            self.inner.set_param(name, data);
2443        }
2444
2445        fn finalize_params(&mut self) {
2446            self.inner.preload_qmatmul_weights();
2447        }
2448
2449        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2450            self.inner.run(inputs)
2451        }
2452        fn run_read_outputs(
2453            &mut self,
2454            inputs: &[(&str, &[f32])],
2455            read_indices: Option<&[usize]>,
2456        ) -> Vec<Vec<f32>> {
2457            self.inner.run_read_outputs(inputs, read_indices)
2458        }
2459        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2460            self.inner.bind_gpu_handle(name, data)
2461        }
2462        fn has_gpu_handle(&self, name: &str) -> bool {
2463            self.inner.has_gpu_handle(name)
2464        }
2465        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2466            self.inner.set_gpu_handle_feed(handle_name, output_index);
2467            true
2468        }
2469        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2470            self.inner.read_gpu_handle(name)
2471        }
2472        fn read_output_row(
2473            &self,
2474            out_idx: usize,
2475            row: usize,
2476            row_inner: usize,
2477        ) -> Option<Vec<f32>> {
2478            Some(self.inner.read_graph_output_row(out_idx, row, row_inner))
2479        }
2480        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2481            self.inner.run_slots(inputs)
2482        }
2483        fn arena_ptr(&self) -> *const u8 {
2484            self.inner.arena_ptr()
2485        }
2486        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
2487            self.inner.commit_no_wait(inputs);
2488        }
2489        fn sync_pending(&mut self) {
2490            self.inner.sync_pending();
2491        }
2492        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
2493            self.inner.run_pipelined(input_sets)
2494        }
2495        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2496            self.inner.set_active_extent(extent);
2497        }
2498
2499        fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
2500            self.inner.set_rng(rng);
2501        }
2502
2503        fn rng(&self) -> rlx_ir::RngOptions {
2504            self.inner.rng()
2505        }
2506
2507        /// Typed param upload — accepts F16/BF16 host bytes by widening
2508        /// to F32 first, then routing through `set_param`. The Metal
2509        /// arena's `write_from_f32` honors per-node F16 storage when
2510        /// AutoMixedPrecision rewrote the param. U8/I8 packed weights
2511        /// copy directly into the arena for `Op::DequantMatMul`.
2512        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2513            if matches!(
2514                dtype,
2515                rlx_ir::DType::U8
2516                    | rlx_ir::DType::I8
2517                    | rlx_ir::DType::I32
2518                    | rlx_ir::DType::I64
2519                    | rlx_ir::DType::U32
2520                    | rlx_ir::DType::F64
2521            ) {
2522                self.inner.set_param_bytes(name, data);
2523                return;
2524            }
2525            if dtype == rlx_ir::DType::F32 {
2526                let n = data.len() / 4;
2527                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2528                self.inner.set_param(name, s);
2529            } else {
2530                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2531                self.inner.set_param(name, &f32_buf);
2532            }
2533        }
2534
2535        /// Typed run. Integer inputs (I64 token ids, etc.) are copied
2536        /// directly into the unified-memory arena; F32/F16/BF16 widen
2537        /// through the existing host path. Outputs use native arena bytes.
2538        fn run_typed(
2539            &mut self,
2540            inputs: &[(&str, &[u8], rlx_ir::DType)],
2541        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2542            self.inner.run_typed(inputs)
2543        }
2544
2545        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
2546            Box::new(MetalExecutableWrapper {
2547                inner: self.inner.clone_for_cache(),
2548                io_manifest: self.io_manifest.clone(),
2549            })
2550        }
2551    }
2552}
2553
2554// ── CUDA Backend ────────────────────────────────────────────────────────
2555
2556#[cfg(feature = "cuda")]
2557pub mod cuda_backend {
2558    use super::*;
2559    use rlx_cuda::backend::CudaExecutable;
2560
    /// CUDA backend. Only compiled with the `cuda` feature; produces
    /// executables backed by `rlx_cuda`'s `CudaExecutable`.
    pub struct CudaBackend;
2562
    /// PLAN L4: ops the CUDA backend can lower today.
    ///
    /// Returned by `supported_ops` and used as the target set for both
    /// legalization and backend-aware stage compilation.
    ///
    /// Excludes FusedSwiGLU, LoraMatMul, FusedAttentionBlock and
    /// FusedTransformerLayer (no kernel; such composites are decomposed by
    /// `rlx_cuda::unfuse::unfuse` before legalization), plus If and While
    /// (no executor wiring). DotGeneral goes through LowerDotGeneral;
    /// ElementwiseRegion is lowered natively by an NVRTC interpreted-chain
    /// kernel.
    const CUDA_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            Lstm,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            Im2Col,
            RngNormal,
            RngUniform,
        ]
    };
2639
2640    impl Backend for CudaBackend {
2641        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2642            CUDA_SUPPORTED_OPS
2643        }
2644
2645        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2646            use rlx_opt::pass::Pass as _;
2647            // Decompose FusedSwiGLU / FAB / etc. before legalization (CudaExecutable
2648            // unfuses again; this pass is idempotent).
2649            let graph = rlx_cuda::unfuse::unfuse(graph);
2650            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, CUDA_SUPPORTED_OPS)
2651                .unwrap_or_else(|errors| {
2652                    panic!("{}", rlx_opt::format_legalize_error("cuda", &errors));
2653                });
2654            let graph = crate::precompile::precompile_cleanup(graph, options);
2655            // Mid-axis broadcasts (EEG patch embed) before elementwise fusion.
2656            let graph = rlx_opt::LegalizeBroadcast.run(graph);
2657            // Backend-aware fusion via the shared compile pipeline.
2658            let compile_result = crate::stages::compile_graph_stages_for_backend(
2659                rlx_driver::Device::Cuda,
2660                graph,
2661                options,
2662                CUDA_SUPPORTED_OPS,
2663            );
2664            crate::stages::maybe_log_fusion(&compile_result.fusion);
2665            let graph = compile_result.lir.into_graph();
2666            let graph = match options.policy.clone() {
2667                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2668                None => graph,
2669            };
2670            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2671            Box::new(CudaExecutableWrapper {
2672                inner: CudaExecutable::compile_rng(graph, options.rng),
2673                io_manifest,
2674            })
2675        }
2676
2677        fn compile_lir(
2678            &self,
2679            lir: LirModule,
2680            options: &CompileOptions,
2681        ) -> Box<dyn ExecutableGraph> {
2682            use rlx_opt::pass::Pass as _;
2683            let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
2684            let (graph, io_manifest) =
2685                cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2686                    rlx_cuda::unfuse::unfuse(graph),
2687                    options,
2688                    CUDA_SUPPORTED_OPS,
2689                    "cuda",
2690                ));
2691            Box::new(CudaExecutableWrapper {
2692                inner: CudaExecutable::compile_rng(graph, options.rng),
2693                io_manifest,
2694            })
2695        }
2696    }
2697
2698    struct CudaExecutableWrapper {
2699        inner: CudaExecutable,
2700        io_manifest: cpu_low_precision::IoDtypeManifest,
2701    }
2702
2703    // CudaExecutable owns CudaContext + CudaSlice handles; cudarc claims
2704    // they're Send (CudaContext is Arc-wrapped, CudaSlice is logically
2705    // a device pointer + length). The Backend trait requires Send for
2706    // the executable; we honor that here.
2707    unsafe impl Send for CudaExecutableWrapper {}
2708
2709    impl ExecutableGraph for CudaExecutableWrapper {
2710        fn set_param(&mut self, name: &str, data: &[f32]) {
2711            self.inner.set_param(name, data);
2712        }
2713        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2714            self.inner.run(inputs)
2715        }
2716        fn run_read_outputs(
2717            &mut self,
2718            inputs: &[(&str, &[f32])],
2719            read_indices: Option<&[usize]>,
2720        ) -> Vec<Vec<f32>> {
2721            self.inner.run_read_outputs(inputs, read_indices)
2722        }
2723        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2724            self.inner.bind_gpu_handle(name, data)
2725        }
2726        fn has_gpu_handle(&self, name: &str) -> bool {
2727            self.inner.has_gpu_handle(name)
2728        }
2729        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2730            self.inner.set_gpu_handle_feed(handle_name, output_index);
2731            true
2732        }
2733        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2734            self.inner.read_gpu_handle(name)
2735        }
2736        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2737            self.inner.set_active_extent(extent);
2738        }
2739
2740        fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
2741            self.inner.set_rng(rng);
2742        }
2743
2744        fn rng(&self) -> rlx_ir::RngOptions {
2745            self.inner.rng()
2746        }
2747
2748        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2749            self.inner.run_slots(inputs)
2750        }
2751
2752        fn arena_ptr(&self) -> *const u8 {
2753            self.inner.arena_ptr()
2754        }
2755
2756        /// Typed param upload — widens F16/BF16 host bytes to f32
2757        /// before routing through `set_param`. CUDA's arena is
2758        /// f32-uniform; the half-precision matmul tier opts in via
2759        /// the separate `set_param_half` API.
2760        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2761            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2762                self.inner.set_param_bytes(name, data);
2763                return;
2764            }
2765            if dtype == rlx_ir::DType::F32 {
2766                let n = data.len() / 4;
2767                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2768                self.inner.set_param(name, s);
2769            } else {
2770                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2771                self.inner.set_param(name, &f32_buf);
2772            }
2773        }
2774
2775        /// Typed run — widen each typed input to F32, run, then narrow
2776        /// each output back to its declared graph dtype.
2777        fn run_typed(
2778            &mut self,
2779            inputs: &[(&str, &[u8], rlx_ir::DType)],
2780        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2781            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2782            for (name, data, dt) in inputs {
2783                let v = super::widen_bytes_to_f32(data, *dt);
2784                owned.push((name.to_string(), v));
2785            }
2786            let refs: Vec<(&str, &[f32])> = owned
2787                .iter()
2788                .map(|(n, d)| (n.as_str(), d.as_slice()))
2789                .collect();
2790            let dtypes =
2791                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2792            let outs = self.inner.run(&refs);
2793            outs.into_iter()
2794                .zip(
2795                    dtypes
2796                        .into_iter()
2797                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2798                )
2799                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2800                .collect()
2801        }
2802
2803        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
2804            Box::new(CudaExecutableWrapper {
2805                inner: self.inner.clone_for_cache(),
2806                io_manifest: self.io_manifest.clone(),
2807            })
2808        }
2809    }
2810}
2811
2812// ── ROCm Backend ────────────────────────────────────────────────────────
2813
2814#[cfg(feature = "rocm")]
2815pub mod rocm_backend {
2816    use super::*;
2817    use rlx_rocm::backend::RocmExecutable;
2818
    /// ROCm backend. Only compiled with the `rocm` feature; produces
    /// executables backed by `rlx_rocm`'s `RocmExecutable`.
    pub struct RocmBackend;
2820
    /// PLAN L4: ops the ROCm backend can lower today.
    ///
    /// ROCm is the sister crate of CUDA (same Step enum and dispatch
    /// shape), so this list mirrors `CUDA_SUPPORTED_OPS`, with one
    /// exception: Conv2dBackwardInput, Conv2dBackwardWeight and
    /// MaxPool2dBackward are NOT claimed here. NOTE(review): confirm
    /// whether that gap is intentional (e.g. missing HIP kernels) or an
    /// oversight.
    const ROCM_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            Lstm,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            Im2Col,
            RngNormal,
            RngUniform,
        ]
    };
2891
2892    impl Backend for RocmBackend {
2893        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2894            ROCM_SUPPORTED_OPS
2895        }
2896
2897        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2898            use rlx_opt::pass::Pass as _;
2899            let graph = rlx_rocm::unfuse::unfuse(graph);
2900            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, ROCM_SUPPORTED_OPS)
2901                .unwrap_or_else(|errors| {
2902                    panic!("{}", rlx_opt::format_legalize_error("rocm", &errors));
2903                });
2904            let graph = crate::precompile::precompile_cleanup(graph, options);
2905            let graph = rlx_opt::LegalizeBroadcast.run(graph);
2906            let compile_result = crate::stages::compile_graph_stages_for_backend(
2907                rlx_driver::Device::Rocm,
2908                graph,
2909                options,
2910                ROCM_SUPPORTED_OPS,
2911            );
2912            crate::stages::maybe_log_fusion(&compile_result.fusion);
2913            let graph = compile_result.lir.into_graph();
2914            let graph = match options.policy.clone() {
2915                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2916                None => graph,
2917            };
2918            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2919            Box::new(RocmExecutableWrapper {
2920                inner: RocmExecutable::compile_rng(graph, options.rng),
2921                io_manifest,
2922            })
2923        }
2924
2925        fn compile_lir(
2926            &self,
2927            lir: LirModule,
2928            options: &CompileOptions,
2929        ) -> Box<dyn ExecutableGraph> {
2930            let (graph, io_manifest) =
2931                cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2932                    rlx_rocm::unfuse::unfuse(lir.into_graph()),
2933                    options,
2934                    ROCM_SUPPORTED_OPS,
2935                    "rocm",
2936                ));
2937            Box::new(RocmExecutableWrapper {
2938                inner: RocmExecutable::compile_rng(graph, options.rng),
2939                io_manifest,
2940            })
2941        }
2942    }
2943
2944    struct RocmExecutableWrapper {
2945        inner: RocmExecutable,
2946        io_manifest: cpu_low_precision::IoDtypeManifest,
2947    }
2948
2949    // Same Send-claim shape as CudaExecutableWrapper. RocmExecutable
2950    // owns Arc<RocmContext> + HipBuffer handles; the HipRuntime bundle
2951    // is internally thread-safe per AMD's documentation.
2952    unsafe impl Send for RocmExecutableWrapper {}
2953
2954    impl ExecutableGraph for RocmExecutableWrapper {
2955        fn set_param(&mut self, name: &str, data: &[f32]) {
2956            self.inner.set_param(name, data);
2957        }
2958        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2959            self.inner.run(inputs)
2960        }
2961        fn run_read_outputs(
2962            &mut self,
2963            inputs: &[(&str, &[f32])],
2964            read_indices: Option<&[usize]>,
2965        ) -> Vec<Vec<f32>> {
2966            self.inner.run_read_outputs(inputs, read_indices)
2967        }
2968        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2969            self.inner.bind_gpu_handle(name, data)
2970        }
2971        fn has_gpu_handle(&self, name: &str) -> bool {
2972            self.inner.has_gpu_handle(name)
2973        }
2974        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2975            self.inner.set_gpu_handle_feed(handle_name, output_index);
2976            true
2977        }
2978        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2979            self.inner.read_gpu_handle(name)
2980        }
2981        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2982            self.inner.run_slots(inputs)
2983        }
2984        fn arena_ptr(&self) -> *const u8 {
2985            self.inner.arena_ptr()
2986        }
2987        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2988            self.inner.set_active_extent(extent);
2989        }
2990
2991        fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
2992            self.inner.set_rng(rng);
2993        }
2994
2995        fn rng(&self) -> rlx_ir::RngOptions {
2996            self.inner.rng()
2997        }
2998
2999        /// Typed param upload — widens F16/BF16 host bytes to f32
3000        /// before routing through `set_param`. ROCm's arena is
3001        /// f32-uniform; the half-precision matmul tier opts in via
3002        /// the separate `set_param_half` API.
3003        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
3004            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
3005                self.inner.set_param_bytes(name, data);
3006                return;
3007            }
3008            if dtype == rlx_ir::DType::F32 {
3009                let n = data.len() / 4;
3010                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
3011                self.inner.set_param(name, s);
3012            } else {
3013                let f32_buf = super::widen_bytes_to_f32(data, dtype);
3014                self.inner.set_param(name, &f32_buf);
3015            }
3016        }
3017
3018        /// Typed run — widen each typed input to F32, run, then narrow
3019        /// each output back to its declared graph dtype.
3020        fn run_typed(
3021            &mut self,
3022            inputs: &[(&str, &[u8], rlx_ir::DType)],
3023        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
3024            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
3025            for (name, data, dt) in inputs {
3026                let v = super::widen_bytes_to_f32(data, *dt);
3027                owned.push((name.to_string(), v));
3028            }
3029            let refs: Vec<(&str, &[f32])> = owned
3030                .iter()
3031                .map(|(n, d)| (n.as_str(), d.as_slice()))
3032                .collect();
3033            let dtypes =
3034                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
3035            let outs = self.inner.run(&refs);
3036            outs.into_iter()
3037                .zip(
3038                    dtypes
3039                        .into_iter()
3040                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
3041                )
3042                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
3043                .collect()
3044        }
3045
3046        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
3047            Box::new(RocmExecutableWrapper {
3048                inner: self.inner.clone_for_cache(),
3049                io_manifest: self.io_manifest.clone(),
3050            })
3051        }
3052    }
3053}
3054
3055// ── TPU Backend ─────────────────────────────────────────────────────────
3056
#[cfg(feature = "tpu")]
pub mod tpu_backend {
    use super::*;
    use rlx_tpu::TpuExecutable;

    /// Backend that legalizes a graph for the TPU op set, applies
    /// `AutoMixedPrecision`, and compiles it with `rlx_tpu::TpuExecutable`.
    pub struct TpuBackend;

    /// Ops the TPU backend lowers to HLO. Full inference parity with
    /// rlx-cuda / rlx-rocm. Composite ops (FusedSwiGLU /
    /// FusedAttentionBlock / FusedTransformerLayer / LoraMatMul / If /
    /// While) are unfused inside `rlx_tpu::unfuse::unfuse` ahead of
    /// HLO emission, so they don't appear here.
    const TPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            // Real-INT8 path + fake-quant.
            QMatMul,
            QConv2d,
            Quantize,
            Dequantize,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            RngNormal,
            RngUniform,
            // Splat: no on-chip kernel — lowered to common primitive MIR via logical_kernel.
        ]
    };

    impl Backend for TpuBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            TPU_SUPPORTED_OPS
        }

        /// Legalizes `graph` against `TPU_SUPPORTED_OPS`, applies
        /// `AutoMixedPrecision`, and compiles the result into a TPU
        /// executable.
        ///
        /// # Panics
        /// Panics with a formatted legalization report when the graph
        /// contains ops that can be neither lowered nor rewritten for TPU.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
                graph,
                TPU_SUPPORTED_OPS,
                options.kernel_dispatch,
            )
            .unwrap_or_else(|errors| {
                panic!("{}", rlx_opt::format_legalize_error("tpu", &errors));
            });
            // The TPU's IR-side pass pipeline (DCE, ConstFold,
            // FuseResidualLN, FuseMatMulBiasAct, LegalizeBroadcast,
            // MarkElementwiseRegions) lives inside
            // `TpuExecutable::compile` so the same passes run whether
            // a caller goes through Session or invokes the executable
            // directly. We only do backend-cross-cutting work here:
            // legalization (must precede the pipeline so we panic
            // early on unsupported ops) and AutoMixedPrecision.
            //
            // Default policy on TPU is `AutoMixedBf16`: BF16 is the
            // native compute dtype on TPU silicon and recent GPUs,
            // and XLA's CPU plugin handles it natively too. Callers
            // can opt out by passing an explicit `PrecisionPolicy`
            // (e.g. `AlwaysF32` for accuracy debugging or
            // `AlwaysF16` to match a CUDA workload's choice).
            use rlx_opt::pass::Pass as _;
            let policy = options
                .policy
                .clone()
                .unwrap_or(rlx_opt::PrecisionPolicy::AutoMixedBf16);
            let graph = rlx_opt::AutoMixedPrecision::new(policy).run(graph);
            // `dce` / `constant_folding` are deliberately not acted on here:
            // the executable's own pipeline (see above) owns those passes.
            // The reads only acknowledge the knobs explicitly.
            let _ = options.dce;
            let _ = options.constant_folding;
            Box::new(TpuExecutableWrapper {
                inner: TpuExecutable::compile_rng(graph, options.rng),
            })
        }
    }

    /// Adapts a compiled `TpuExecutable` to the `ExecutableGraph` trait.
    struct TpuExecutableWrapper {
        inner: TpuExecutable,
    }

    // SAFETY: PJRT clients + buffers are documented as thread-safe per the
    // upstream C API. Same Send-claim shape as CudaExecutableWrapper /
    // RocmExecutableWrapper.
    unsafe impl Send for TpuExecutableWrapper {}

    /// Interprets `data` as native-endian `f32` values.
    ///
    /// The bytes are borrowed in place when the buffer start is aligned for
    /// `f32`, and decoded into an owned buffer otherwise. A `&[u8]` carries no
    /// alignment guarantee, so an unconditional pointer cast would be undefined
    /// behavior. Trailing bytes that do not form a whole `f32` are ignored.
    fn f32_slice_from_ne_bytes(data: &[u8]) -> std::borrow::Cow<'_, [f32]> {
        const WIDTH: usize = std::mem::size_of::<f32>();
        let n = data.len() / WIDTH;
        if (data.as_ptr() as usize) % std::mem::align_of::<f32>() == 0 {
            // SAFETY: the pointer is non-null and aligned for f32 (checked
            // above); `n * WIDTH <= data.len()` bytes are in bounds of `data`;
            // every 4-byte pattern is a valid f32; and the returned slice
            // borrows `data`, so it cannot outlive it.
            let s = unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f32>(), n) };
            std::borrow::Cow::Borrowed(s)
        } else {
            std::borrow::Cow::Owned(
                data.chunks_exact(WIDTH)
                    .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                    .collect(),
            )
        }
    }

    impl ExecutableGraph for TpuExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }

        /// Typed param upload. Today this widens F16/BF16/etc. host bytes to
        /// f32; F32 bytes are passed through (borrowed in place when aligned,
        /// copied otherwise). Once the HLO emitter speaks bf16 natively
        /// (which TPUs prefer over f16), the typed path will hand the
        /// original bytes straight through `Buffer_FromHostBuffer`.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            if dtype == rlx_ir::DType::F32 {
                let values = f32_slice_from_ne_bytes(data);
                self.inner.set_param(name, &*values);
            } else {
                let f32_buf = super::widen_bytes_to_f32(data, dtype);
                self.inner.set_param(name, &f32_buf);
            }
        }

        /// Typed run: widens each input to f32 and runs the graph. Each
        /// output is then narrowed back to the dtype reported by
        /// `TpuExecutable::output_dtypes` for its position. Outputs beyond
        /// that list fall back to `F32`.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
            for (name, data, dt) in inputs {
                let v = super::widen_bytes_to_f32(data, *dt);
                owned.push((name.to_string(), v));
            }
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
                .collect()
        }

        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
            Box::new(TpuExecutableWrapper {
                inner: self.inner.clone_for_cache(),
            })
        }
    }
}