Skip to main content

rlx_runtime/
backend.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backend trait — abstraction over CPU/GPU/CUDA execution.
17//!
18//! Each backend implements `Backend::compile(graph, &CompileOptions)` and
19//! returns an `ExecutableGraph`. New compile knobs go in `CompileOptions`
20//! rather than as new trait methods.
21
22use crate::CompileOptions;
23use rlx_ir::Graph;
24use rlx_ir::hir::HirModule;
25use rlx_ir::lir::LirModule;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29// ── Typed I/O helpers (shared across f32-arena backends) ────────────────
30
31/// Widen a typed byte buffer to `Vec<f32>`. Used by `set_param_typed` /
32/// `run_typed` overrides on backends whose internal arena is f32-uniform
33/// (CPU, Metal, wgpu) so callers can hand in F16/BF16 without doing the
34/// host-side cast themselves. Panics on dtypes the f32 arena can't carry.
35#[allow(dead_code)]
36pub(crate) fn widen_bytes_to_f32(data: &[u8], dtype: rlx_ir::DType) -> Vec<f32> {
37    use rlx_ir::DType;
38    match dtype {
39        DType::F32 => {
40            let n = data.len() / 4;
41            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
42            s.to_vec()
43        }
44        DType::F16 => {
45            let n = data.len() / 2;
46            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
47            s.iter().map(|h| h.to_f32()).collect()
48        }
49        DType::BF16 => {
50            let n = data.len() / 2;
51            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n) };
52            s.iter().map(|h| h.to_f32()).collect()
53        }
54        other => panic!(
55            "widen_bytes_to_f32: dtype {other:?} unsupported on f32-arena backends \
56             (only F32/F16/BF16 are accepted on the host I/O surface)"
57        ),
58    }
59}
60
61/// Narrow a `&[f32]` buffer down to the declared output dtype, returning
62/// the corresponding little-endian byte stream. Mirrors the bytes a
63/// backend that stores the native dtype would emit. Used by `run_typed`
64/// to keep the byte-level output contract identical across backends.
65#[allow(dead_code)]
66pub(crate) fn narrow_f32_to_bytes(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
67    use rlx_ir::DType;
68    match dt {
69        DType::F32 => {
70            let mut bytes = Vec::with_capacity(v.len() * 4);
71            for &x in v {
72                bytes.extend_from_slice(&x.to_le_bytes());
73            }
74            bytes
75        }
76        DType::F16 => {
77            let mut bytes = Vec::with_capacity(v.len() * 2);
78            for &x in v {
79                bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
80            }
81            bytes
82        }
83        DType::BF16 => {
84            let mut bytes = Vec::with_capacity(v.len() * 2);
85            for &x in v {
86                bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
87            }
88            bytes
89        }
90        DType::F64 => {
91            let mut bytes = Vec::with_capacity(v.len() * 8);
92            for &x in v {
93                bytes.extend_from_slice(&(x as f64).to_le_bytes());
94            }
95            bytes
96        }
97        DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
98        DType::U8 => v.iter().map(|&x| x as u8).collect(),
99        DType::I16 => {
100            let mut bytes = Vec::with_capacity(v.len() * 2);
101            for &x in v {
102                bytes.extend_from_slice(&(x as i16).to_le_bytes());
103            }
104            bytes
105        }
106        DType::I32 => {
107            let mut bytes = Vec::with_capacity(v.len() * 4);
108            for &x in v {
109                bytes.extend_from_slice(&(x as i32).to_le_bytes());
110            }
111            bytes
112        }
113        DType::U32 => {
114            let mut bytes = Vec::with_capacity(v.len() * 4);
115            for &x in v {
116                bytes.extend_from_slice(&(x as u32).to_le_bytes());
117            }
118            bytes
119        }
120        DType::I64 => {
121            let mut bytes = Vec::with_capacity(v.len() * 8);
122            for &x in v {
123                bytes.extend_from_slice(&(x as i64).to_le_bytes());
124            }
125            bytes
126        }
127        DType::Bool => v
128            .iter()
129            .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
130            .collect(),
131        DType::C64 => {
132            // Complex narrow path: real part = the f32 value, imaginary
133            // part = 0. Mirrors how the backend stores narrowed f32
134            // operands when promoted to a complex op input.
135            let mut bytes = Vec::with_capacity(v.len() * 8);
136            for &x in v {
137                bytes.extend_from_slice(&x.to_le_bytes());
138                bytes.extend_from_slice(&0.0_f32.to_le_bytes());
139            }
140            bytes
141        }
142    }
143}
144
145/// A compiled, ready-to-execute graph on a specific backend.
146pub trait ExecutableGraph: Send {
147    /// Set a named parameter (weight) buffer.
148    fn set_param(&mut self, name: &str, data: &[f32]);
149
150    /// Deep-clone this executable into a fresh `Box`. Lets
151    /// `CompiledGraph` implement `Clone` so callers (e.g. eda-mna's
152    /// `SensitivityContext`) can spin up N independent executor
153    /// copies for thread-parallel dispatch without paying the full
154    /// graph-compile cost N times. Default implementation panics;
155    /// backends that support cloning override.
156    fn clone_box(&self) -> Box<dyn ExecutableGraph> {
157        panic!("clone_box not implemented for this backend");
158    }
159
160    /// Execute the graph with named inputs. Returns output data (copies from arena).
161    fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>>;
162
163    /// Execute and return raw pointers to output data in arena (zero-copy).
164    fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
165        let vecs = self.run(inputs);
166        vecs.iter().map(|v| (v.as_ptr(), v.len())).collect()
167    }
168
169    /// Fastest: inputs by slot index, returns output (offset, len) pairs.
170    /// Read output from arena via `arena_ptr().add(offset)`.
171    fn run_slots(&mut self, _inputs: &[&[f32]]) -> &[(usize, usize)] {
172        &[] // default: not supported
173    }
174
175    /// Get the raw arena buffer pointer for reading outputs after run_slots.
176    fn arena_ptr(&self) -> *const u8 {
177        std::ptr::null()
178    }
179
180    /// Hint the executor that subsequent `run` calls should process
181    /// only the first `actual` rows along the bucket axis (out of
182    /// `upper`, the extent the graph was compiled at). Backends that
183    /// support per-kernel active-extent dispatch honor this; others
184    /// ignore it and process the full compiled extent.
185    ///
186    /// Pass `None` to clear the hint. The hint is sticky — set it
187    /// before each `run` and clear it after, or maintain it across
188    /// runs at your discretion.
189    ///
190    /// Even when honored, callers must not rely on the contents of the
191    /// output past `actual` rows — that region may contain stale data
192    /// from earlier runs (kernels skip it).
193    ///
194    /// Default: no-op. See `BucketedCompileCache::run_padded` for the
195    /// canonical caller; backends opt in by overriding this method.
196    fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
197        let _ = extent;
198    }
199
200    /// TIDE merged placement mask (union across MoE layers). CPU: stats + host path.
201    fn set_moe_resident_experts(&mut self, _mask: &[bool]) {}
202
203    /// Per MoE layer placement (`masks[layer][expert]`). Preferred over merged on CPU.
204    fn set_moe_resident_experts_per_layer(&mut self, _masks: &[&[bool]]) {}
205
206    /// Capture MoE router TopK indices on the next CPU forward (TIDE refresh).
207    fn enable_moe_topk_capture(&mut self, _num_experts: usize) -> bool {
208        false
209    }
210
211    /// Take captured per-layer expert indices (one vec per MoE TopK in order).
212    fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
213        None
214    }
215
216    /// MoE GroupedMatMul residency accounting from the last forward (CPU).
217    fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
218        None
219    }
220
221    /// Bind a persistent buffer handle (KV-cache, training state, etc.).
222    /// The buffer lives across run() calls and is not in the arena.
223    /// Returns true if the backend supports persistent handles.
224    fn bind_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
225        false
226    }
227
228    /// Read a persistent buffer's current contents.
229    fn read_handle(&self, _name: &str) -> Option<Vec<f32>> {
230        None
231    }
232
233    /// GPU-resident input (MLX): upload once, reuse across runs.
234    fn bind_gpu_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
235        false
236    }
237
238    fn has_gpu_handle(&self, _name: &str) -> bool {
239        false
240    }
241
242    fn set_gpu_handle_feed(&mut self, _handle_name: &str, _output_index: usize) -> bool {
243        false
244    }
245
246    fn read_gpu_handle(&self, _name: &str) -> Option<Vec<f32>> {
247        None
248    }
249
250    /// Run and refresh a GPU handle from `output_index`; returns that output on host.
251    fn run_feed_gpu_handle(
252        &mut self,
253        inputs: &[(&str, &[f32])],
254        _handle_name: &str,
255        _output_index: usize,
256    ) -> Option<Vec<f32>> {
257        let _ = inputs;
258        None
259    }
260
261    // ── Pipelined / async execution (Phase C) ─────────────────────────
262    //
263    // These allow callers to amortize per-run sync latency on backends
264    // where it matters (Metal: ~150 µs `wait_until_completed` per commit).
265    // CPU has no such cost, so the default impls just call `run` serially.
266
267    /// Encode + commit a forward pass without waiting for completion.
268    ///
269    /// Outputs of intermediate calls are stomped — use `run_pipelined` if
270    /// you need outputs from each individual commit. Pair with
271    /// `sync_pending` to drain.
272    ///
273    /// Default: synchronous fallback (calls `run`, discards output). CPU
274    /// uses this default since BLAS is synchronous anyway.
275    fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
276        let _ = self.run(inputs);
277    }
278
279    /// Wait for every command queued by `commit_no_wait`.
280    /// Default: no-op (synchronous backends have nothing pending).
281    fn sync_pending(&mut self) {}
282
283    /// Issue a batch of forward passes pipelined, returning per-run outputs.
284    ///
285    /// The Metal impl encodes a per-commit blit so each in-flight run's
286    /// outputs survive subsequent commits stomping the shared arena. The
287    /// CPU default is just sequential `run`s — equally correct, no perf
288    /// penalty (CPU has no GPU sync cost to amortize).
289    ///
290    /// Returns `out[run_idx][output_idx][element_idx]`.
291    fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
292        input_sets.iter().map(|inputs| self.run(inputs)).collect()
293    }
294
295    // ── Typed (non-F32) host I/O ──────────────────────────────────
296    //
297    // `set_param` and `run` are F32 by contract. The typed entry
298    // points let callers pass and receive raw bytes in any rlx-ir
299    // dtype, avoiding the f32 widen/narrow round-trip that's
300    // wasteful for F16/BF16 weights and activations.
301    //
302    // The default impls only handle F32 — any other dtype panics.
303    // Backends that support typed I/O natively (e.g. MLX via
304    // Array::from_bytes/to_bytes) override these.
305
306    /// Set a named parameter from raw bytes in the given dtype.
307    fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
308        if dtype != rlx_ir::DType::F32 {
309            panic!(
310                "backend's default set_param_typed only handles F32; \
311                    got {dtype:?}. Override on the backend for typed support."
312            );
313        }
314        if !data.len().is_multiple_of(4) {
315            panic!(
316                "set_param_typed F32: data length {} not a multiple of 4",
317                data.len()
318            );
319        }
320        // SAFETY: F32 bytes are 4-aligned by source convention; we
321        // only widen access (read &[f32] from owned &[u8]). Failure
322        // mode if a caller hands us mis-aligned bytes is undefined,
323        // hence the % 4 length check.
324        let n = data.len() / 4;
325        let f32_slice = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
326        self.set_param(name, f32_slice);
327    }
328
329    /// Run with typed inputs and typed outputs. Returns
330    /// `(bytes, dtype)` per output; the dtype is whatever the
331    /// graph's output node was declared as.
332    fn run_typed(
333        &mut self,
334        inputs: &[(&str, &[u8], rlx_ir::DType)],
335    ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
336        // Default impl: convert each typed input to f32 (F32-only),
337        // run, then re-emit outputs as F32 bytes.
338        let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
339        for (name, data, dt) in inputs {
340            if *dt != rlx_ir::DType::F32 {
341                panic!(
342                    "backend's default run_typed only handles F32 inputs; \
343                        got {dt:?} for input '{name}'"
344                );
345            }
346            if data.len() % 4 != 0 {
347                panic!(
348                    "run_typed F32 input '{name}': len {} not multiple of 4",
349                    data.len()
350                );
351            }
352            let n = data.len() / 4;
353            let v: Vec<f32> =
354                unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }.to_vec();
355            owned.push((name.to_string(), v));
356        }
357        let refs: Vec<(&str, &[f32])> = owned
358            .iter()
359            .map(|(n, d)| (n.as_str(), d.as_slice()))
360            .collect();
361        let outs = self.run(&refs);
362        outs.into_iter()
363            .map(|v| {
364                let bytes =
365                    unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 4) }
366                        .to_vec();
367                (bytes, rlx_ir::DType::F32)
368            })
369            .collect()
370    }
371}
372
373/// Backend implementation trait.
374///
375/// Single compile entry point. New compile-time knobs are added to
376/// `CompileOptions`, not as new trait methods.
377///
378/// `Send + Sync` because backends are stateless factories — multiple
379/// threads can call `compile` concurrently. The returned
380/// `Box<dyn ExecutableGraph>` is `Send` (moveable to a worker thread)
381/// but **not** `Sync` (`run`/`run_slots` take `&mut self`).
382pub trait Backend: Send + Sync {
383    /// Compile a graph for this backend with the given options.
384    fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph>;
385
386    /// Compile pre-optimized LIR (HIR → MIR → LIR pipeline output).
387    /// Default re-enters [`Self::compile`] — backends should override
388    /// when they can reuse the embedded buffer plan.
389    fn compile_lir(&self, lir: LirModule, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
390        self.compile(lir.into_graph(), options)
391    }
392
393    /// HIR-first compile: lower blocks, run fusion pipeline, emit executable.
394    fn compile_hir(
395        &self,
396        hir: HirModule,
397        device: rlx_driver::Device,
398        options: &CompileOptions,
399    ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
400        let result = crate::stages::compile_hir_stages(device, hir, options)?;
401        crate::stages::maybe_log_fusion(&result.fusion);
402        Ok(self.compile_lir(result.lir, options))
403    }
404
405    /// [`GraphModule`] compile — unified HIR/MIR/LIR entry.
406    fn compile_module(
407        &self,
408        module: rlx_ir::GraphModule,
409        device: rlx_driver::Device,
410        options: &CompileOptions,
411    ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
412        let result = crate::stages::compile_module_stages(device, module, options)?;
413        crate::stages::maybe_log_fusion(&result.fusion);
414        Ok(self.compile_lir(result.lir, options))
415    }
416
417    /// PLAN L4: declare which `OpKind`s this backend can lower.
418    /// Default: empty slice = "no claim made — accept everything"
419    /// (preserves existing behavior; backends opt in by overriding).
420    /// When non-empty, the `LegalizeForBackend` pass will refuse to
421    /// compile a graph that contains an op outside this set, instead
422    /// of silently falling through to slower / wrong dispatch.
423    fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
424        &[]
425    }
426}
427
428/// Prepare a fused MIR graph from LIR for backend executable construction.
429/// Skips the fusion pipeline — LIR must come from `compile_*_stages`.
430#[allow(dead_code)]
431fn prepare_fused_graph(
432    graph: Graph,
433    options: &CompileOptions,
434    supported_ops: &[rlx_ir::OpKind],
435    backend_name: &str,
436) -> Graph {
437    let (mut graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
438        graph,
439        backend_name,
440        supported_ops,
441        options.kernel_dispatch,
442    );
443    rlx_opt::maybe_log_dispatch_report(&report);
444    if !report.compile_ready {
445        panic!(
446            "{}\n{}",
447            rlx_opt::format_legalize_error(backend_name, &report.still_unsupported),
448            rlx_opt::format_dispatch_report(&report)
449        );
450    }
451    use rlx_opt::pass::Pass as _;
452    if options.dce {
453        graph = rlx_opt::DeadCodeElimination.run(graph);
454    }
455    if options.constant_folding {
456        graph = rlx_opt::ConstantFolding.run(graph);
457    }
458    if let Some(p) = options.policy.clone() {
459        graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
460    }
461    graph
462}
463
464// ── Convenience helpers preserved from older API ──────────────────────
465//
466// These let existing call sites keep working unchanged while the new
467// trait is the canonical one. We provide free functions rather than
468// trait methods so adding them doesn't grow the trait surface.
469
470/// Compile at default options (F32, no policy).
471pub fn compile(backend: &dyn Backend, graph: Graph) -> Box<dyn ExecutableGraph> {
472    backend.compile(graph, &CompileOptions::default())
473}
474
475/// Compile HIR through the fusion-first pipeline.
476pub fn compile_hir(
477    backend: &dyn Backend,
478    hir: HirModule,
479    device: rlx_driver::Device,
480    options: &CompileOptions,
481) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
482    backend.compile_hir(hir, device, options)
483}
484
485/// Compile a [`GraphModule`] through the fusion-first pipeline.
486pub fn compile_module(
487    backend: &dyn Backend,
488    module: rlx_ir::GraphModule,
489    device: rlx_driver::Device,
490    options: &CompileOptions,
491) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
492    backend.compile_module(module, device, options)
493}
494
495/// Compile at a specific precision (default policy = none).
496pub fn compile_with_precision(
497    backend: &dyn Backend,
498    graph: Graph,
499    precision: crate::Precision,
500) -> Box<dyn ExecutableGraph> {
501    backend.compile(graph, &CompileOptions::new().precision(precision))
502}
503
504/// Helper retained for backward compatibility — applies the precision
505/// rewrite at the runtime layer if backends don't override their
506/// pipeline placement. Modern code: pass the policy via CompileOptions
507/// and let the backend handle ordering.
508fn _legacy_apply_policy(graph: Graph, policy: Option<rlx_opt::PrecisionPolicy>) -> Graph {
509    use rlx_opt::pass::Pass as _;
510    match policy {
511        Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
512        None => graph,
513    }
514}
515
516// ── CPU Backend ─────────────────────────────────────────────────────────
517
518#[cfg(feature = "cpu")]
519pub mod cpu_backend {
520    use super::*;
521    use rlx_cpu::{arena::Arena, thunk};
522    use rlx_ir::{DType, NodeId, Op};
523    use rlx_opt::memory::{self, MemoryPlan};
    // Arena typed read/write helpers live in `rlx_driver::arena` so every
    // backend (CPU, Metal, future CUDA/wgpu/WASM) shares one implementation.
526    use rlx_driver::arena::{read_typed_to_f32, write_typed_from_f32};
527
528    pub struct CpuBackend;
529
530    /// PLAN L4: ops the CPU backend can lower today. Includes
531    /// DotGeneral (lowered via `LowerDotGeneral` pass) and
532    /// ElementwiseRegion (lowered natively per L2). Excludes
533    /// FusedTransformerLayer / If / While — those have IR variants
534    /// but no CPU lowering yet (see `compile_thunks` arm absence +
535    /// `subgraph.rs` "If/While executor wiring is pending" note).
536    const CPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
537        use rlx_ir::OpKind::*;
538        &[
539            Input,
540            Param,
541            Constant,
542            Activation,
543            Cast,
544            Binary,
545            Compare,
546            Where,
547            ElementwiseRegion,
548            MatMul,
549            DotGeneral,
550            DenseSolve,
551            BatchedDenseSolve,
552            Scan,
553            ScanBackward,
554            ScanBackwardXs,
555            LayerNorm,
556            LayerNorm2d,
557            GroupNorm,
558            RmsNorm,
559            ResizeNearest2x,
560            AxialRope2d,
561            Attention,
562            Rope,
563            Reshape,
564            Transpose,
565            Narrow,
566            Concat,
567            Expand,
568            Gather,
569            Reduce,
570            Softmax,
571            Cumsum,
572            TopK,
573            Sample,
574            Conv,
575            ConvTranspose2d,
576            Pool,
577            GroupedMatMul,
578            DequantGroupedMatMul,
579            DequantMoEWeights,
580            ScatterAdd,
581            LoraMatMul,
582            DequantMatMul,
583            SelectiveScan,
584            GatedDeltaNet,
585            FusedSwiGLU,
586            FusedMatMulBiasAct,
587            FusedResidualLN,
588            FusedResidualRmsNorm,
589            FusedAttentionBlock,
590            // Backward ops emitted by `rlx_opt::autodiff::grad_with_loss`.
591            // Their thunks live in rlx-cpu/src/thunk.rs alongside the
592            // forward kernels; without these entries the legalize step
593            // below would reject any compiled gradient graph.
594            ReluBackward,
595            ActivationBackward,
596            FakeQuantize,
597            FakeQuantizeBackward,
598            MaxPool2dBackward,
599            Conv2dBackwardInput,
600            Conv2dBackwardWeight,
601            SoftmaxCrossEntropyWithLogits,
602            SoftmaxCrossEntropyBackward,
603            AttentionBackward,
604            LayerNormBackwardInput,
605            LayerNormBackwardGamma,
606            RmsNormBackwardInput,
607            RmsNormBackwardGamma,
608            RmsNormBackwardBeta,
609            RopeBackward,
610            CumsumBackward,
611            GatherBackward,
612            // 3D Gaussian splat CPU reference render/backward (requires `rlx-cpu/splat`).
613            GaussianSplatRender,
614            GaussianSplatRenderBackward,
615            GaussianSplatPrepare,
616            GaussianSplatRasterize,
617            // User-registered custom ops dispatched through
618            // `rlx_cpu::op_registry`. Lowering panics with a clear
619            // message if the named CPU kernel isn't registered.
620            Custom,
621            // User-defined sub-graph with optional override AD rules
622            // (JAX-shaped custom_vjp / custom_jvp). Body is a regular
623            // Graph compiled recursively in compile_thunks.
624            CustomFn,
625            // FFT primitive (1D last-axis, 2N real-block layout, f64
626            // power-of-2 sizes). Other backends panic at lowering;
627            // pin FFT-containing graphs to Device::Cpu for now.
628            Fft,
629            // C64 Wirtinger AD surface. ComplexNormSq is the canonical
630            // real-valued loss for complex inputs; Conjugate is emitted
631            // by the new Wirtinger VJP rules for BinaryOp::Mul/Div on
632            // C64. Both have CPU thunks in rlx-cpu.
633            ComplexNormSq,
634            ComplexNormSqBackward,
635            Conjugate,
636        ]
637    };
638
639    impl Backend for CpuBackend {
640        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
641            CPU_SUPPORTED_OPS
642        }
643
644        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
645            use rlx_opt::pass::Pass as _;
646            // Lower Op::If / Op::While to primitives BEFORE legalize
647            // so the supported-op check doesn't reject them — the CPU
648            // backend has no native sub-graph executor; this rewrite
649            // makes If/While invisible to the rest of the pipeline.
650            // No-op when neither op is in the graph.
651            let graph = rlx_opt::LowerControlFlow.run(graph);
652            // PLAN L4: legalize against the backend's claimed op set
653            // BEFORE running fusion (so the diagnostic points at the
654            // user's IR, not at a fused-away node).
655            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CPU_SUPPORTED_OPS) {
656                panic!("{}", rlx_opt::format_legalize_error("cpu", &errors));
657            }
658            let policy = options.policy.clone();
659            let _precision = options.precision;
660            let cfg = rlx_cpu::config::RuntimeConfig::global();
661
662            // Optional preliminary cleanup passes
663            let graph = if options.dce {
664                rlx_opt::DeadCodeElimination.run(graph)
665            } else {
666                graph
667            };
668            let graph = if options.constant_folding {
669                rlx_opt::ConstantFolding.run(graph)
670            } else {
671                graph
672            };
673
674            // Run fusion pipeline (HIR/MIR/LIR ideology — fusion is first-class).
675            let mut compile_opts = options.clone();
676            compile_opts.arena_alignment = cfg.arena_alignment;
677            let compile_result = crate::stages::compile_graph_stages_for_backend(
678                rlx_driver::Device::Cpu,
679                graph,
680                &compile_opts,
681                CPU_SUPPORTED_OPS,
682            );
683            crate::stages::maybe_log_fusion(&compile_result.fusion);
684            let fused = compile_result.lir.into_graph();
685
686            // Apply precision policy AFTER fusion — Cast nodes don't disrupt
687            // the now-flattened fused ops.
688            let fused = match policy {
689                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(fused),
690                None => fused,
691            };
692
693            // Re-plan after precision rewrites (may change dtypes / sizes).
694            let plan = memory::plan_memory_aligned(&fused, cfg.arena_alignment);
695            if cfg.verbose >= 1 {
696                eprintln!(
697                    "[rlx] arena: {} bytes, {} buffers, alignment: {}",
698                    plan.arena_size,
699                    plan.assignments.len(),
700                    cfg.arena_alignment
701                );
702            }
703            Box::new(build_cpu_executable(fused, plan))
704        }
705
706        fn compile_lir(
707            &self,
708            lir: LirModule,
709            options: &CompileOptions,
710        ) -> Box<dyn ExecutableGraph> {
711            let alignment = lir.buffers.alignment.max(options.arena_alignment);
712            let embedded: MemoryPlan = (&lir.buffers).into();
713            let mut graph = lir.into_graph();
714            if let Some(p) = options.policy.clone() {
715                use rlx_opt::pass::Pass;
716                graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
717            }
718            let plan = if options.policy.is_some() {
719                memory::plan_memory_aligned(&graph, alignment)
720            } else {
721                embedded
722            };
723            let cfg = rlx_cpu::config::RuntimeConfig::global();
724            if cfg.verbose >= 1 {
725                eprintln!(
726                    "[rlx] compile_lir: arena {} bytes ({} buffers, alignment {})",
727                    plan.arena_size,
728                    plan.assignments.len(),
729                    alignment,
730                );
731            }
732            Box::new(build_cpu_executable(graph, plan))
733        }
734    }
735
736    fn build_cpu_executable(graph: Graph, plan: MemoryPlan) -> CpuExecutable {
737        let mut arena = Arena::from_plan(plan);
738        let mut input_ids = HashMap::new();
739        let mut param_ids = HashMap::new();
740        let mut node_dtypes: HashMap<NodeId, DType> = HashMap::new();
741        for node in graph.nodes() {
742            node_dtypes.insert(node.id, node.shape.dtype());
743            match &node.op {
744                Op::Input { name } => {
745                    input_ids.insert(name.clone(), node.id);
746                }
747                Op::Param { name } => {
748                    param_ids.insert(name.clone(), node.id);
749                }
750                _ => {}
751            }
752        }
753
754        let schedule = thunk::compile_thunks(&graph, &arena);
755
756        let mut input_slots = Vec::new();
757        for node in graph.nodes() {
758            if let Op::Input { name } = &node.op {
759                let off = arena.byte_offset(node.id);
760                let len = node.shape.num_elements().unwrap_or(0);
761                input_slots.push((name.clone(), off, len, node.shape.dtype()));
762            }
763        }
764
765        let output_slots: Vec<(usize, usize)> = graph
766            .outputs
767            .iter()
768            .map(|&id| {
769                let off = arena.byte_offset(id);
770                let len = graph.node(id).shape.num_elements().unwrap_or(0);
771                (off, len)
772            })
773            .collect();
774
775        for node in graph.nodes() {
776            if let Op::Constant { data } = &node.op
777                && arena.has_buffer(node.id)
778                && !data.is_empty()
779            {
780                match node.shape.dtype() {
781                    DType::F64 => {
782                        let off = arena.byte_offset(node.id);
783                        let buf = arena.raw_buf_mut();
784                        let n = buf.len().saturating_sub(off).min(data.len());
785                        buf[off..off + n].copy_from_slice(&data[..n]);
786                    }
787                    _ => {
788                        let buf = arena.slice_mut(node.id);
789                        let n_floats = data.len() / 4;
790                        let n = buf.len().min(n_floats);
791                        for i in 0..n {
792                            let bytes = [
793                                data[i * 4],
794                                data[i * 4 + 1],
795                                data[i * 4 + 2],
796                                data[i * 4 + 3],
797                            ];
798                            buf[i] = f32::from_le_bytes(bytes);
799                        }
800                    }
801                }
802            }
803        }
804
805        CpuExecutable {
806            graph,
807            arena,
808            params: HashMap::new(),
809            input_ids,
810            param_ids,
811            node_dtypes,
812            schedule,
813            input_slots,
814            output_slots,
815            handles: HashMap::new(),
816            active_extent: None,
817            moe_resident: None,
818            moe_resident_layers: None,
819            moe_topk_capture: None,
820        }
821    }
822
823    #[derive(Clone)]
824    struct CpuExecutable {
825        graph: Graph,
826        arena: Arena,
827        params: HashMap<String, Vec<f32>>,
828        input_ids: HashMap<String, NodeId>,
829        param_ids: HashMap<String, NodeId>,
830        /// Per-node arena dtype. Lets set_param/run cast f32 ↔ F16/BF16
831        /// when AutoMixedPrecision has rewritten the graph.
832        node_dtypes: HashMap<NodeId, DType>,
833        schedule: thunk::ThunkSchedule,
834        // Pre-resolved: ordered list of (input_name, arena_byte_offset, max_elems, dtype)
835        input_slots: Vec<(String, usize, usize, DType)>,
836        /// Output (byte_offset, num_elements). dtype is in node_dtypes.
837        output_slots: Vec<(usize, usize)>,
838        /// Persistent buffer handles (KV-cache, optimizer state, etc.).
839        /// Lives outside the arena and survives across run() calls.
840        /// On run(): if a handle's name matches a graph input, the
841        /// handle's data is used as the input.
842        handles: HashMap<String, Vec<f32>>,
843        /// Active-extent hint (`Some((actual, upper))`) for L1 bucketed
844        /// dispatch. When set AND every thunk in the schedule is in
845        /// `Thunk::safe_for_active_extent`, the executor processes only
846        /// `actual / upper` of each kernel's work. Otherwise (or when
847        /// `None`) runs at the full compiled extent. See PLAN L1.
848        active_extent: Option<(usize, usize)>,
849        moe_resident: Option<std::sync::Arc<[bool]>>,
850        moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
851        moe_topk_capture: Option<std::sync::Arc<rlx_cpu::moe_topk_capture::MoeTopkCapture>>,
852    }
853
854    unsafe impl Send for CpuExecutable {}
855
856    impl CpuExecutable {
857        /// Write a f32 input slice into the arena, casting to the node's dtype.
858        fn write_input(&mut self, id: NodeId, data: &[f32]) {
859            let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
860            let off = self.arena.byte_offset(id);
861            let buf = self.arena.raw_buf_mut();
862            let elem_size = dtype.size_bytes();
863            let max_elems = (buf.len() - off) / elem_size;
864            unsafe {
865                write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
866            }
867        }
868
869        /// Read a node's arena bytes back as Vec<f32>, casting from its dtype.
870        fn read_output(&self, id: NodeId) -> Vec<f32> {
871            let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
872            let off = self.arena.byte_offset(id);
873            let buf = self.arena.raw_buf();
874            let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
875            unsafe { read_typed_to_f32(buf.as_ptr().add(off), dtype, n_elems) }
876        }
877    }
878
879    impl ExecutableGraph for CpuExecutable {
880        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
881            Box::new(self.clone())
882        }
883        fn set_param(&mut self, name: &str, data: &[f32]) {
884            // Write directly into the arena — zero per-call lookup for params.
885            // Cast f32 → arena dtype when the param has been rewritten to F16/BF16.
886            if let Some(&id) = self.param_ids.get(name)
887                && self.arena.has_buffer(id)
888            {
889                let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
890                let off = self.arena.byte_offset(id);
891                let buf = self.arena.raw_buf_mut();
892                let elem_size = dtype.size_bytes();
893                let max_elems = (buf.len() - off) / elem_size;
894                unsafe {
895                    write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
896                }
897                return;
898            }
899            // Fallback: store in HashMap if no arena slot
900            self.params.insert(name.to_string(), data.to_vec());
901        }
902
903        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
904            // 1. Apply persistent handles first — they act like default inputs.
905            //    Explicit `inputs` passed to run() override matching handle names.
906            let handle_names: Vec<String> = self.handles.keys().cloned().collect();
907            for name in &handle_names {
908                if let Some(&id) = self.input_ids.get(name)
909                    && self.arena.has_buffer(id)
910                {
911                    let data = self.handles.get(name).cloned().unwrap_or_default();
912                    self.write_input(id, &data);
913                }
914            }
915            // 2. Explicit per-call inputs override handles.
916            for &(name, data) in inputs {
917                if let Some(&id) = self.input_ids.get(name)
918                    && self.arena.has_buffer(id)
919                {
920                    self.write_input(id, data);
921                }
922            }
923
924            // Active-extent fast-path (PLAN L1): if hinted AND every thunk
925            // in the schedule supports it, run scaled. Otherwise fall back
926            // to full-extent dispatch — preserves correctness when the
927            // schedule contains a thunk that hasn't yet been wired in.
928            let active_used = if let Some((actual, upper)) = self.active_extent {
929                thunk::execute_thunks_active(
930                    &self.schedule,
931                    self.arena.raw_buf_mut(),
932                    actual,
933                    upper,
934                )
935            } else {
936                false
937            };
938            if !active_used {
939                // Execute via pre-compiled thunks (zero per-node dispatch overhead)
940                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
941            }
942
943            // 3. Sync any handle whose name matches a graph OUTPUT —
944            //    KV-cache pattern: outputs flow back into the same-named
945            //    handle for the next iteration.
946            for (idx, &out_id) in self.graph.outputs.iter().enumerate() {
947                let name = format!("out{idx}");
948                if self.handles.contains_key(&name) {
949                    let v = self.read_output(out_id);
950                    self.handles.insert(name, v);
951                }
952            }
953
954            self.graph
955                .outputs
956                .iter()
957                .map(|&out_id| self.read_output(out_id))
958                .collect()
959        }
960
961        fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
962            // Copy inputs by name (HashMap lookup), casting to arena dtype.
963            for &(name, data) in inputs {
964                if let Some(&id) = self.input_ids.get(name)
965                    && self.arena.has_buffer(id)
966                {
967                    self.write_input(id, data);
968                }
969            }
970            thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
971            // Note: pointers are raw arena bytes — for F16 outputs, callers
972            // must read 2 bytes/elem, not 4. run() is the safe path for
973            // mixed precision; run_raw() is only meaningful for F32.
974            self.graph
975                .outputs
976                .iter()
977                .map(|&out_id| {
978                    let (ptr, len) = self.arena.raw_ptr(out_id);
979                    (ptr as *const f32, len)
980                })
981                .collect()
982        }
983
984        /// Fastest path: inputs by index (matching input_slots order), zero-copy output.
985        /// No HashMap, no name matching, no Vec allocation. Casts f32 input
986        /// to F16/BF16 if the input slot's dtype was rewritten.
987        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
988            let buf = self.arena.raw_buf_mut();
989            for (i, &data) in inputs.iter().enumerate() {
990                if i < self.input_slots.len() {
991                    let (_, off, max_len, dtype) = &self.input_slots[i];
992                    unsafe {
993                        write_typed_from_f32(buf.as_mut_ptr().add(*off), *dtype, data, *max_len);
994                    }
995                }
996            }
997            thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
998            &self.output_slots
999        }
1000
1001        fn arena_ptr(&self) -> *const u8 {
1002            self.arena.raw_buf_mut_ptr()
1003        }
1004
1005        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1006            // Persistent buffer: stored separately from arena, survives run().
1007            // If the name matches a graph input, run() will use this data
1008            // as the input. If the graph also writes back to this name (via
1009            // an output binding pattern), read_handle returns the latest.
1010            self.handles.insert(name.to_string(), data.to_vec());
1011            true
1012        }
1013
1014        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1015            self.handles.get(name).cloned()
1016        }
1017
1018        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1019            self.active_extent = extent;
1020        }
1021
1022        fn set_moe_resident_experts(&mut self, mask: &[bool]) {
1023            self.moe_resident_layers = None;
1024            self.schedule.moe_resident_layers = None;
1025            self.moe_resident = Some(Arc::from(mask));
1026            self.schedule.moe_resident = self.moe_resident.clone();
1027        }
1028
1029        fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
1030            self.moe_resident = None;
1031            self.schedule.moe_resident = None;
1032            let layers: Vec<Arc<[bool]>> = masks.iter().map(|m| Arc::from(*m)).collect();
1033            let arc = Arc::new(layers);
1034            self.moe_resident_layers = Some(arc.clone());
1035            self.schedule.moe_resident_layers = Some(arc);
1036        }
1037
1038        fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
1039            let cap = rlx_cpu::moe_topk_capture::MoeTopkCapture::new(num_experts);
1040            self.moe_topk_capture = Some(cap.clone());
1041            self.schedule.moe_topk_capture = Some(cap);
1042            true
1043        }
1044
1045        fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
1046            let cap = self.moe_topk_capture.as_ref()?;
1047            let layers = cap.take_layers();
1048            if layers.is_empty() {
1049                None
1050            } else {
1051                Some(layers)
1052            }
1053        }
1054
1055        fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
1056            rlx_cpu::moe_residency::take_last_forward_stats()
1057        }
1058
1059        /// Typed param upload. F32 / F16 / BF16 go through the existing
1060        /// widen-to-f32 path (the CPU arena is historically f32 with
1061        /// optional half-precision rewrite). F64 (and any future
1062        /// non-widenable dtype) lands directly in the arena as bytes —
1063        /// the f32 path would lose precision.
1064        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1065            if dtype == DType::F64 {
1066                self.set_param_bytes(name, data, dtype);
1067                return;
1068            }
1069            // U8 / I8 raw byte tensors: opaque storage for the GGUF
1070            // K-quant `Op::DequantMatMul` path (weights stay packed
1071            // in the arena). One arena byte = one element.
1072            if matches!(dtype, DType::U8 | DType::I8) {
1073                self.set_param_bytes(name, data, dtype);
1074                return;
1075            }
1076            if dtype == DType::F32 {
1077                let n = data.len() / 4;
1078                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1079                self.set_param(name, s);
1080            } else {
1081                let f32_buf = super::widen_bytes_to_f32(data, dtype);
1082                self.set_param(name, &f32_buf);
1083            }
1084        }
1085
1086        /// Typed run with mixed-dtype inputs/outputs.
1087        ///
1088        /// For each input: if its declared graph dtype matches the
1089        /// caller's bytes, we write directly into the arena (zero
1090        /// precision loss — F64 stays F64). For F32 with a half-precision
1091        /// arena rewrite, we widen as before. F16/BF16 callers go
1092        /// through the existing widen path.
1093        ///
1094        /// Outputs are read straight from the arena in the graph node's
1095        /// declared dtype — F64 outputs come back as 8 bytes/element,
1096        /// F32 as 4, etc.
1097        fn run_typed(
1098            &mut self,
1099            inputs: &[(&str, &[u8], rlx_ir::DType)],
1100        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1101            // Decide: are *all* inputs F64? If so, use the direct-byte
1102            // path for everything and skip the f32 widening machinery
1103            // entirely. Mixed dtype graphs (F32 + F64) take the
1104            // per-input dispatch route below.
1105            let all_f64 = !inputs.is_empty() && inputs.iter().all(|(_, _, dt)| *dt == DType::F64);
1106
1107            if all_f64 {
1108                for (name, data, _) in inputs {
1109                    if let Some(&id) = self.input_ids.get(*name) {
1110                        if !self.arena.has_buffer(id) {
1111                            continue;
1112                        }
1113                        let off = self.arena.byte_offset(id);
1114                        let buf = self.arena.raw_buf_mut();
1115                        let n = data.len();
1116                        debug_assert!(
1117                            off + n <= buf.len(),
1118                            "run_typed: input '{name}' overflows arena slot"
1119                        );
1120                        buf[off..off + n].copy_from_slice(data);
1121                    }
1122                }
1123                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1124            } else {
1125                // Mixed-dtype path: dtypes that survive untouched
1126                // through the f32-aliased arena (F64, I32, I64, U32)
1127                // go in as bytes; F32 and the half-precision family
1128                // route through widen-to-f32 + run.
1129                let mut f32_owned: Vec<(String, Vec<f32>)> = Vec::new();
1130                for (name, data, dt) in inputs {
1131                    let direct = matches!(*dt, DType::F64 | DType::I32 | DType::I64 | DType::U32,);
1132                    if direct {
1133                        if let Some(&id) = self.input_ids.get(*name) {
1134                            if !self.arena.has_buffer(id) {
1135                                continue;
1136                            }
1137                            let off = self.arena.byte_offset(id);
1138                            let buf = self.arena.raw_buf_mut();
1139                            buf[off..off + data.len()].copy_from_slice(data);
1140                        }
1141                    } else {
1142                        let v = super::widen_bytes_to_f32(data, *dt);
1143                        f32_owned.push((name.to_string(), v));
1144                    }
1145                }
1146                let refs: Vec<(&str, &[f32])> = f32_owned
1147                    .iter()
1148                    .map(|(n, d)| (n.as_str(), d.as_slice()))
1149                    .collect();
1150                let _ = self.run(&refs);
1151            }
1152
1153            // Read each output's bytes from the arena in its declared dtype.
1154            self.graph
1155                .outputs
1156                .iter()
1157                .map(|&id| {
1158                    let dtype = self.graph.node(id).shape.dtype();
1159                    let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
1160                    let n_bytes = n_elems * dtype.size_bytes();
1161                    let off = self.arena.byte_offset(id);
1162                    let bytes = self.arena.raw_buf()[off..off + n_bytes].to_vec();
1163                    (bytes, dtype)
1164                })
1165                .collect()
1166        }
1167    }
1168
1169    impl CpuExecutable {
1170        /// Direct-byte param upload — copies caller's bytes into the
1171        /// arena slot for the named param without any dtype conversion.
1172        /// Used by `set_param_typed` for dtypes that f32-widening would
1173        /// corrupt (F64). Caller is responsible for matching the param's
1174        /// declared graph dtype.
1175        fn set_param_bytes(&mut self, name: &str, data: &[u8], _dtype: rlx_ir::DType) {
1176            if let Some(&id) = self.param_ids.get(name)
1177                && self.arena.has_buffer(id)
1178            {
1179                let off = self.arena.byte_offset(id);
1180                let buf = self.arena.raw_buf_mut();
1181                debug_assert!(
1182                    off + data.len() <= buf.len(),
1183                    "set_param_bytes: '{name}' would overflow arena slot"
1184                );
1185                buf[off..off + data.len()].copy_from_slice(data);
1186            }
1187        }
1188    }
1189}
1190
1191// ── Metal Backend ───────────────────────────────────────────────────────
1192
1193// ── wgpu Backend ────────────────────────────────────────────────────────
1194
#[cfg(feature = "gpu")]
pub mod wgpu_backend {
    use super::*;
    use rlx_ir::OpKind;
    use rlx_wgpu::backend::WgpuExecutable;
    use std::borrow::Cow;

    pub struct WgpuBackend;

    /// PLAN L4: ops the wgpu backend can lower today. The fused
    /// macro-kernels (FAB, FTL, FusedSwiGLU) get decomposed by
    /// `crate::unfuse::unfuse` upstream — they're listed here too so
    /// graphs that already contain them legalize cleanly. Conv1d/3d
    /// and Pool1d/3d are deferred (Conv2d only).
    const WGPU_SUPPORTED_OPS: &[OpKind] = &[
        OpKind::Input,
        OpKind::Param,
        OpKind::Constant,
        OpKind::Activation,
        OpKind::Cast,
        OpKind::Binary,
        OpKind::Compare,
        OpKind::Where,
        OpKind::ElementwiseRegion,
        OpKind::MatMul,
        OpKind::DotGeneral,
        OpKind::LayerNorm,
        OpKind::RmsNorm,
        OpKind::Attention,
        OpKind::AttentionBackward,
        OpKind::RmsNormBackwardInput,
        OpKind::RmsNormBackwardGamma,
        OpKind::RmsNormBackwardBeta,
        OpKind::RopeBackward,
        OpKind::CumsumBackward,
        OpKind::GatherBackward,
        OpKind::Rope,
        OpKind::Reshape,
        OpKind::Transpose,
        OpKind::Narrow,
        OpKind::Concat,
        OpKind::Expand,
        OpKind::Gather,
        OpKind::Reduce,
        OpKind::Softmax,
        OpKind::Cumsum,
        OpKind::TopK,
        OpKind::Sample,
        OpKind::Conv,
        OpKind::Pool,
        OpKind::GroupedMatMul,
        OpKind::DequantGroupedMatMul,
        OpKind::DequantMoEWeights,
        OpKind::ScatterAdd,
        OpKind::SelectiveScan,
        OpKind::DequantMatMul,
        OpKind::FusedMatMulBiasAct,
        OpKind::FusedResidualLN,
        OpKind::FusedResidualRmsNorm,
        OpKind::FusedSwiGLU,
        OpKind::FusedAttentionBlock,
        OpKind::FusedTransformerLayer,
        // Native FFT (WGSL radix-2): f32 only, power-of-2 N ≤ 1024.
        // Anything outside that envelope panics at lowering with a
        // "pin to Device::Cpu" hint. No host fallback — WGPU has no
        // unified memory, so silent CPU round-trip would be a hidden
        // performance cliff.
        OpKind::Fft,
        // 3D Gaussian splat: native Metal / CPU reference per backend.
        OpKind::GaussianSplatRender,
        OpKind::GaussianSplatRenderBackward,
        OpKind::GaussianSplatPrepare,
        OpKind::GaussianSplatRasterize,
        OpKind::Custom,
        // LoRA, If, While: not yet wired in wgpu — fail loudly.
    ];

    impl Backend for WgpuBackend {
        fn supported_ops(&self) -> &'static [OpKind] {
            WGPU_SUPPORTED_OPS
        }

        /// Graph → wgpu executable: lower control flow, legalize against
        /// `WGPU_SUPPORTED_OPS` (panicking with a formatted report on
        /// failure), optional DCE / constant folding, the staged compile
        /// pipeline, then optional AutoMixedPrecision.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            use rlx_opt::pass::Pass as _;
            let graph = rlx_opt::LowerControlFlow.run(graph);
            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, WGPU_SUPPORTED_OPS)
                .unwrap_or_else(|errors| {
                    panic!("{}", rlx_opt::format_legalize_error("wgpu", &errors));
                });
            // Cleanup passes upstream of wgpu's pipeline.
            let graph = if options.dce {
                rlx_opt::DeadCodeElimination.run(graph)
            } else {
                graph
            };
            let graph = if options.constant_folding {
                rlx_opt::ConstantFolding.run(graph)
            } else {
                graph
            };
            // ORDER MATTERS: targeted-pattern fusions (FuseMatMulBiasAct,
            // FuseResidualLN, ...) must run BEFORE the catch-all
            // `MarkElementwiseRegions`; otherwise the region pass swallows
            // the Add / Activation nodes into chains and the narrower
            // patterns never match. An earlier wgpu pipeline had this
            // inverted and silently shipped 13 unfused LayerNorms per BERT
            // forward where 12 should have been FusedResidualLN.
            // NOTE(review): the ordering is now owned by
            // `compile_graph_stages_for_backend`; confirm it still holds there.
            let compile_result = crate::stages::compile_graph_stages_for_backend(
                rlx_driver::Device::Gpu,
                graph,
                options,
                WGPU_SUPPORTED_OPS,
            );
            crate::stages::maybe_log_fusion(&compile_result.fusion);
            let graph = compile_result.lir.into_graph();
            let graph = match options.policy.clone() {
                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
                None => graph,
            };
            Box::new(WgpuExecutableWrapper {
                inner: WgpuExecutable::compile(graph),
            })
        }

        fn compile_lir(
            &self,
            lir: LirModule,
            options: &CompileOptions,
        ) -> Box<dyn ExecutableGraph> {
            let graph = prepare_fused_graph(lir.into_graph(), options, WGPU_SUPPORTED_OPS, "wgpu");
            Box::new(WgpuExecutableWrapper {
                inner: WgpuExecutable::compile(graph),
            })
        }
    }

    /// Adapter exposing a `WgpuExecutable` through `ExecutableGraph`.
    struct WgpuExecutableWrapper {
        inner: WgpuExecutable,
    }

    // SAFETY: NOTE(review) — `Send` is asserted by hand, presumably because
    // `WgpuExecutable` is not auto-`Send`; confirm it holds no thread-affine
    // GPU state before moving it across threads.
    unsafe impl Send for WgpuExecutableWrapper {}

    /// View `data` as native-endian `f32`s without assuming alignment.
    ///
    /// `&[u8]` only guarantees 1-byte alignment, and forming a `&[f32]` over
    /// a misaligned pointer is undefined behaviour. Aligned input is borrowed
    /// zero-copy; misaligned input is decoded into a fresh buffer. Trailing
    /// bytes that do not fill a whole element are ignored.
    fn decode_f32_bytes(data: &[u8]) -> Cow<'_, [f32]> {
        let n = data.len() / 4;
        if (data.as_ptr() as usize) % std::mem::align_of::<f32>() == 0 {
            // SAFETY: the pointer is f32-aligned (checked above), the first
            // `n * 4 <= data.len()` bytes are in bounds and initialised, and
            // every 32-bit pattern is a valid f32.
            Cow::Borrowed(unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f32>(), n) })
        } else {
            Cow::Owned(
                data.chunks_exact(4)
                    .map(|b| f32::from_ne_bytes([b[0], b[1], b[2], b[3]]))
                    .collect(),
            )
        }
    }

    /// Widen host bytes of `dtype` to f32 for the f32-uniform wgpu arena.
    ///
    /// Handles F32 / F16 / BF16 (native-endian, any alignment) and returns
    /// `None` for every other dtype.
    fn widen_host_bytes(data: &[u8], dtype: rlx_ir::DType) -> Option<Cow<'_, [f32]>> {
        use rlx_ir::DType;
        let widened = match dtype {
            DType::F32 => decode_f32_bytes(data),
            DType::F16 => Cow::Owned(
                data.chunks_exact(2)
                    .map(|b| half::f16::from_bits(u16::from_ne_bytes([b[0], b[1]])).to_f32())
                    .collect(),
            ),
            DType::BF16 => Cow::Owned(
                data.chunks_exact(2)
                    .map(|b| half::bf16::from_bits(u16::from_ne_bytes([b[0], b[1]])).to_f32())
                    .collect(),
            ),
            _ => return None,
        };
        Some(widened)
    }

    impl ExecutableGraph for WgpuExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }

        /// Typed param upload. U8 / I8 are passed through as raw bytes;
        /// F32 / F16 / BF16 are widened to f32 at the host boundary, since
        /// the wgpu arena is f32-uniform.
        ///
        /// # Panics
        /// On any other dtype.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
                self.inner.set_param_bytes(name, data);
                return;
            }
            let Some(values) = widen_host_bytes(data, dtype) else {
                panic!(
                    "rlx-wgpu set_param_typed: dtype {dtype:?} unsupported \
                                 (F32, F16, BF16 only — wgpu arena is f32-uniform)"
                )
            };
            self.inner.set_param(name, &*values);
        }

        /// Typed run: widen each typed input to f32 (aligned F32 input is
        /// borrowed, not copied), run, then narrow each output back to its
        /// declared dtype. Outputs beyond the reported dtype list are treated
        /// as F32.
        ///
        /// # Panics
        /// When an input is not F32 / F16 / BF16.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let widened: Vec<(&str, Cow<'_, [f32]>)> = inputs
                .iter()
                .map(|&(name, data, dt)| {
                    let values = widen_host_bytes(data, dt).unwrap_or_else(|| {
                        panic!("rlx-wgpu run_typed: input '{name}' dtype {dt:?} unsupported")
                    });
                    (name, values)
                })
                .collect();
            let refs: Vec<(&str, &[f32])> = widened
                .iter()
                .map(|(name, values)| (*name, &**values))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (narrow_to_dtype(&v, dt), dt))
                .collect()
        }
    }

    /// Cast every element of a wgpu f32 output buffer down to the
    /// declared output dtype, returning the corresponding little-endian
    /// byte stream. The arena keeps every value as f32; declared output
    /// dtypes (Bool, I8, I32, F16, ...) require an exit-time narrowing to
    /// be byte-identical with backends that store the native dtype.
    fn narrow_to_dtype(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
        use rlx_ir::DType;

        /// Encode each element with `enc` and concatenate the results.
        fn encode<const N: usize>(v: &[f32], enc: impl Fn(f32) -> [u8; N]) -> Vec<u8> {
            let mut bytes = Vec::with_capacity(v.len() * N);
            for &x in v {
                bytes.extend_from_slice(&enc(x));
            }
            bytes
        }

        match dt {
            // C64 (complex f32 pair) — the wgpu backend's f32 arena
            // doesn't synthesize complex outputs today; this arm only
            // fires if a graph asks for a C64 output and the backend
            // lowered it as 2N real floats. The raw f32 stream is passed
            // straight through; downstream code that wants complex
            // semantics is responsible for re-pairing.
            DType::F32 | DType::C64 => encode(v, f32::to_le_bytes),
            DType::F16 => encode(v, |x| half::f16::from_f32(x).to_le_bytes()),
            DType::BF16 => encode(v, |x| half::bf16::from_f32(x).to_le_bytes()),
            DType::F64 => encode(v, |x| (x as f64).to_le_bytes()),
            DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
            DType::U8 => v.iter().map(|&x| x as u8).collect(),
            DType::I16 => encode(v, |x| (x as i16).to_le_bytes()),
            DType::I32 => encode(v, |x| (x as i32).to_le_bytes()),
            DType::U32 => encode(v, |x| (x as u32).to_le_bytes()),
            DType::I64 => encode(v, |x| (x as i64).to_le_bytes()),
            DType::Bool => v.iter().map(|&x| u8::from(x != 0.0)).collect(),
        }
    }
}
1520
1521// ── MLX Backend ─────────────────────────────────────────────────────────
1522
#[cfg(all(feature = "mlx", target_os = "macos"))]
pub mod mlx_backend {
    use super::*;
    use rlx_mlx::MlxExecutable;

    pub struct MlxBackend;

    /// PLAN L4: ops the MLX backend can lower today.
    ///
    /// Beyond the common elementwise / matmul / norm / shape core, this list
    /// claims the following:
    /// - control flow (If, While, and statically unrolled Scan);
    /// - the fused-block ops (FusedAttentionBlock, FusedTransformerLayer);
    /// - the Tier 1–3 autodiff backward ops, QAT, custom ops and FFT.
    ///
    /// ElementwiseRegion is lowered natively via the per-step composition in
    /// rlx-mlx/src/lower.rs (PLAN L2). Note that `compile_lir` still runs
    /// `LowerControlFlow` on the graph before fusion prep. This is not an
    /// exhaustive "everything" set: ops such as GroupNorm, ResizeNearest2x,
    /// AxialRope2d and the RmsNorm backward family are not claimed here and
    /// must be rewritten by legalization or rejected.
    const MLX_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            LayerNorm,
            LayerNorm2d,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
            If,
            While,
            // Loop-unrolled scan (Op::Scan body is statically unrolled
            // `length` times into MLX ops; mirror of Op::While's
            // bounded-unroll lowering). ScanBackward is the AD
            // companion — handled the same way.
            Scan,
            ScanBackward,
            ScanBackwardXs,
            // Tier 1 autodiff backward ops — lowered as primitive
            // compositions in `rlx-mlx/src/lower.rs`.
            ReluBackward,
            ActivationBackward,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            // Tier 2 — conv backward via `mc::conv_general` with the
            // same parameter-mapping MLX uses inside its built-in vjp.
            // Currently groups=1 only; grouped conv backward will
            // surface as a clear error from `lower.rs`.
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            // Tier 3 — max-pool backward via slice-strided argmax over
            // pool windows + a per-kernel-slot scatter-add, matching
            // the CPU thunk's "first-hit-wins" tiebreaking.
            MaxPool2dBackward,
            // QAT — `FakeQuantize` (PerBatch + Fixed scale modes;
            // EMA returns a clear error from `lower.rs`) and the
            // `FakeQuantizeBackward` family covering all 4 STE
            // variants.
            FakeQuantize,
            FakeQuantizeBackward,
            // User-registered custom ops dispatched through
            // `rlx_mlx::op_registry`. Lowering looks up the
            // registered `MlxKernel` and calls its `execute` method
            // to produce the lazy MLX `Array` for this node.
            Custom,
            // Op::Fft on MLX: native `mlx::fft::fft` via rlx_mlx_op_fft shim.
            // 2N real-block f32/f64 and complex64 inputs supported.
            Fft,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
        ]
    };

    impl Backend for MlxBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            MLX_SUPPORTED_OPS
        }

        /// Run the shared staged pipeline (legalization + backend-aware
        /// fusion against `MLX_SUPPORTED_OPS`), then finish via
        /// `compile_lir`.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            let compile_result = crate::stages::compile_graph_stages_for_backend(
                rlx_driver::Device::Mlx,
                graph,
                options,
                MLX_SUPPORTED_OPS,
            );
            crate::stages::maybe_log_fusion(&compile_result.fusion);
            self.compile_lir(compile_result.lir, options)
        }

        /// Lower an already-staged LIR module: rewrite If/While into
        /// primitives, run the shared fused-graph preparation for MLX, and
        /// build the executable.
        fn compile_lir(
            &self,
            lir: LirModule,
            options: &CompileOptions,
        ) -> Box<dyn ExecutableGraph> {
            use rlx_opt::pass::Pass as _;
            let graph = rlx_opt::LowerControlFlow.run(lir.into_graph());
            let graph = prepare_fused_graph(graph, options, MLX_SUPPORTED_OPS, "mlx");
            Box::new(build_mlx_executable(graph))
        }
    }

    /// Build the MLX executable in the mode selected by `RLX_MLX_MODE`.
    ///
    /// In `Compiled` mode the trace is warmed eagerly. A warm-up failure is
    /// non-fatal: it is logged, and the first `run` pays the trace cost
    /// instead.
    fn build_mlx_executable(graph: Graph) -> MlxExecutableWrapper {
        let mode = mlx_mode_from_env();
        let mut exe = MlxExecutable::compile_from_fused(graph, mode);
        if mode == rlx_mlx::lower::MlxMode::Compiled {
            if let Err(e) = exe.warm_compile() {
                eprintln!(
                    "[rlx-runtime] MLX warm_compile failed ({e}); first run will pay the trace cost"
                );
            }
        }
        MlxExecutableWrapper { inner: exe }
    }

    /// Parse `RLX_MLX_MODE` (case-insensitive: `eager`, `lazy`, `compiled`).
    ///
    /// An unset variable selects `Compiled`. An unrecognized value also
    /// falls back to `Compiled`, but now emits a warning so typos do not
    /// silently change the execution mode.
    fn mlx_mode_from_env() -> rlx_mlx::lower::MlxMode {
        use rlx_mlx::lower::MlxMode;
        match rlx_ir::env::var("RLX_MLX_MODE").as_deref() {
            None => MlxMode::Compiled,
            Some(s) if s.eq_ignore_ascii_case("eager") => MlxMode::Eager,
            Some(s) if s.eq_ignore_ascii_case("lazy") => MlxMode::Lazy,
            Some(s) if s.eq_ignore_ascii_case("compiled") => MlxMode::Compiled,
            Some(other) => {
                eprintln!(
                    "[rlx-runtime] unrecognized RLX_MLX_MODE={other:?} (expected eager|lazy|compiled); using compiled"
                );
                MlxMode::Compiled
            }
        }
    }

    /// Adapter exposing `MlxExecutable` through the `ExecutableGraph` trait.
    struct MlxExecutableWrapper {
        inner: MlxExecutable,
    }

    // SAFETY (review): the `ExecutableGraph` contract requires `Send`, and
    // `MlxExecutable` presumably is not auto-`Send` because it holds MLX
    // handles. All access goes through `&self`/`&mut self` on this wrapper,
    // so it is used from one thread at a time. TODO confirm MLX arrays and
    // streams may be moved between threads.
    unsafe impl Send for MlxExecutableWrapper {}

    /// Pure delegation to `MlxExecutable`. The only adaptations are:
    /// - `Result` → `bool`/`Option` for the GPU-handle APIs;
    /// - `set_gpu_handle_feed` always reports success.
    impl ExecutableGraph for MlxExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
            self.inner.run_slots(inputs)
        }
        fn arena_ptr(&self) -> *const u8 {
            self.inner.arena_ptr()
        }
        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
            self.inner.commit_no_wait(inputs);
        }
        fn sync_pending(&mut self) {
            self.inner.sync_pending();
        }
        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
            self.inner.run_pipelined(input_sets)
        }
        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
            self.inner.bind_handle(name, data)
        }
        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_handle(name)
        }
        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
            self.inner.bind_gpu_handle(name, data).is_ok()
        }
        fn has_gpu_handle(&self, name: &str) -> bool {
            self.inner.has_gpu_handle(name)
        }
        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
            self.inner.set_gpu_handle_feed(handle_name, output_index);
            true
        }
        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_gpu_handle(name).ok()
        }
        fn run_feed_gpu_handle(
            &mut self,
            inputs: &[(&str, &[f32])],
            handle_name: &str,
            output_index: usize,
        ) -> Option<Vec<f32>> {
            self.inner
                .run_feed_gpu(inputs, handle_name, output_index)
                .ok()
        }
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            self.inner.set_param_typed(name, data, dtype);
        }
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            self.inner.run_typed(inputs)
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }
    }
}
1753
#[cfg(all(feature = "metal", target_os = "macos"))]
pub mod metal_backend {
    use super::*;
    use rlx_metal::backend::MetalExecutable;

    pub struct MetalBackend;

    /// PLAN L4: ops the Metal backend can lower today.
    ///
    /// Includes DotGeneral (LowerDotGeneral pass) and ElementwiseRegion
    /// (decomposed by UnfuseElementwiseRegions).
    ///
    /// Excludes Cumsum, SelectiveScan, LoraMatMul, Sample,
    /// FusedAttentionBlock, FusedTransformerLayer, If and While. These are
    /// not yet wired in `rlx-metal/src/thunk.rs`'s compile_thunks. If/While
    /// still work end-to-end because both entry points rewrite them into
    /// primitives with `LowerControlFlow` before legalization.
    ///
    /// DequantMatMul (GGUF K-quants) lowers to a GPU dequant kernel plus an
    /// MPS matmul; legacy Int8 schemes remain CPU-only.
    const METAL_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            TopK,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            // User-registered custom ops dispatched through
            // `rlx_metal::op_registry`. Lowering panics with a clear
            // message if the named MetalKernel isn't registered;
            // executor inserts a sync point + runs the host kernel
            // against the unified-memory arena.
            Custom,
            // Op::Fft is supported via the same host-fallback pattern
            // as Custom: sync the GPU, run rlx-cpu's FFT against the
            // unified-memory arena, restart cmd_buf. A native Metal
            // compute kernel will replace this when a workload makes
            // the sync the bottleneck.
            Fft,
            // Host-fallback splat (unified-memory arena + rlx-cpu/splat).
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
        ]
    };

    /// Front half shared by `compile` and `compile_lir`. It performs four
    /// steps in order:
    /// 1. Rewrite If/While into primitives (Metal has no native sub-graph
    ///    executor wired through its thunk schedule — same rewrite as the
    ///    CPU pipeline).
    /// 2. Legalize against `METAL_SUPPORTED_OPS` with the requested kernel
    ///    dispatch config, panicking with a formatted report on failure.
    /// 3. Apply dead-code elimination if `options.dce` is set.
    /// 4. Apply constant folding if `options.constant_folding` is set.
    fn lower_and_legalize(graph: Graph, options: &CompileOptions) -> Graph {
        use rlx_opt::pass::Pass as _;
        let mut graph = rlx_opt::LowerControlFlow.run(graph);
        graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
            graph,
            METAL_SUPPORTED_OPS,
            options.kernel_dispatch,
        )
        .unwrap_or_else(|errors| {
            panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
        });
        if options.dce {
            graph = rlx_opt::DeadCodeElimination.run(graph);
        }
        if options.constant_folding {
            graph = rlx_opt::ConstantFolding.run(graph);
        }
        graph
    }

    /// View a host byte buffer as native-endian `f32`s without undefined
    /// behavior.
    ///
    /// A `&[u8]` carries no alignment guarantee, and building a `&[f32]`
    /// from a misaligned pointer is UB. The function therefore borrows in
    /// place only when `data` is 4-byte aligned; otherwise it copies through
    /// `f32::from_ne_bytes`. Trailing bytes that do not form a whole `f32`
    /// are ignored, matching the previous `data.len() / 4` truncation.
    fn f32_view_of_bytes(data: &[u8]) -> std::borrow::Cow<'_, [f32]> {
        let n = data.len() / std::mem::size_of::<f32>();
        if (data.as_ptr() as usize) % std::mem::align_of::<f32>() == 0 {
            // SAFETY:
            // - the pointer is aligned for `f32` (checked above);
            // - `n * 4 <= data.len()`, so the view stays inside the borrowed
            //   buffer;
            // - every bit pattern is a valid `f32`;
            // - the returned borrow shares `data`'s lifetime.
            let s = unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f32>(), n) };
            std::borrow::Cow::Borrowed(s)
        } else {
            std::borrow::Cow::Owned(
                data.chunks_exact(4)
                    .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                    .collect(),
            )
        }
    }

    impl Backend for MetalBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            METAL_SUPPORTED_OPS
        }

        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            let graph = lower_and_legalize(graph, options);
            // Hand the policy to MetalExecutable so the rewrite runs AFTER
            // its internal fusion passes (avoids breaking pattern matchers).
            Box::new(MetalExecutableWrapper {
                inner: MetalExecutable::compile_with_policy(
                    graph,
                    options.policy.clone(),
                    Some(METAL_SUPPORTED_OPS),
                ),
            })
        }

        fn compile_lir(
            &self,
            lir: LirModule,
            options: &CompileOptions,
        ) -> Box<dyn ExecutableGraph> {
            let graph = lower_and_legalize(lir.into_graph(), options);
            // The LIR has already been through fusion; skip Metal's
            // internal fusion and compile the fused graph directly.
            Box::new(MetalExecutableWrapper {
                inner: MetalExecutable::compile_from_fused(
                    graph,
                    options.policy.clone(),
                    Some(METAL_SUPPORTED_OPS),
                ),
            })
        }
    }

    /// Adapter exposing `MetalExecutable` through the `ExecutableGraph` trait.
    struct MetalExecutableWrapper {
        inner: MetalExecutable,
    }

    // SAFETY (review): `ExecutableGraph` requires `Send`. `MetalExecutable`
    // presumably holds Metal object handles that are not auto-`Send`; the
    // wrapper is only driven through `&self`/`&mut self`, so it is used from
    // one thread at a time. TODO confirm the handles may move between
    // threads.
    unsafe impl Send for MetalExecutableWrapper {}

    impl ExecutableGraph for MetalExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
            self.inner.run_slots(inputs)
        }
        fn arena_ptr(&self) -> *const u8 {
            self.inner.arena_ptr()
        }
        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
            self.inner.commit_no_wait(inputs);
        }
        fn sync_pending(&mut self) {
            self.inner.sync_pending();
        }
        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
            self.inner.run_pipelined(input_sets)
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }

        /// Typed param upload, with the routing chosen by `dtype`:
        /// - U8/I8: packed weight bytes copy straight into the arena (for
        ///   `Op::DequantMatMul`).
        /// - F32: read without assuming the byte buffer is 4-byte aligned.
        /// - Anything else (e.g. F16/BF16): widened to F32 first, then
        ///   routed through `set_param`.
        ///
        /// The Metal arena's `write_from_f32` honors per-node F16 storage
        /// when AutoMixedPrecision rewrote the param.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            match dtype {
                rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
                    self.inner.set_param_bytes(name, data);
                }
                rlx_ir::DType::F32 => {
                    let view = f32_view_of_bytes(data);
                    self.inner.set_param(name, &*view);
                }
                _ => {
                    let f32_buf = super::widen_bytes_to_f32(data, dtype);
                    self.inner.set_param(name, &f32_buf);
                }
            }
        }

        /// Typed run.
        ///
        /// Inputs are widened to F32; F32 inputs are decoded
        /// alignment-safely. F64 host inputs through `run_typed` are a
        /// separate Metal extension.
        ///
        /// Outputs:
        /// - F64 outputs take the byte-direct `output_bytes_per_node` path,
        ///   so there is no precision loss from an f32 round-trip.
        /// - All other dtypes keep the f32-narrow path for backward
        ///   compatibility with existing AutoMixedPrecision rewrites.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let owned: Vec<(String, Vec<f32>)> = inputs
                .iter()
                .map(|(name, data, dt)| {
                    let v = if *dt == rlx_ir::DType::F32 {
                        f32_view_of_bytes(data).into_owned()
                    } else {
                        super::widen_bytes_to_f32(data, *dt)
                    };
                    (name.to_string(), v)
                })
                .collect();
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let f32_outs = self.inner.run(&refs);
            let byte_outs = self.inner.output_bytes_per_node();
            // NOTE(review): `zip` stops at the shorter side; this assumes
            // `output_bytes_per_node` yields one entry per output — confirm.
            // Missing dtypes default to F32.
            f32_outs
                .into_iter()
                .zip(byte_outs.into_iter())
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|((f32_v, byte_v), dt)| match dt {
                    rlx_ir::DType::F64 => (byte_v, dt),
                    _ => (super::narrow_f32_to_bytes(&f32_v, dt), dt),
                })
                .collect()
        }
    }
}
2007
2008// ── CUDA Backend ────────────────────────────────────────────────────────
2009
#[cfg(feature = "cuda")]
pub mod cuda_backend {
    use super::*;
    use rlx_cuda::backend::CudaExecutable;

    pub struct CudaBackend;

    /// PLAN L4: ops the CUDA backend can lower today.
    ///
    /// Excludes, among others:
    /// - FusedSwiGLU, LoraMatMul, FusedAttentionBlock and
    ///   FusedTransformerLayer (no kernel);
    /// - If and While (no executor wiring).
    ///
    /// DotGeneral is lowered via LowerDotGeneral. ElementwiseRegion is
    /// lowered natively by an NVRTC interpreted-chain kernel.
    const CUDA_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
        ]
    };

    /// View a host byte buffer as native-endian `f32`s without undefined
    /// behavior.
    ///
    /// A `&[u8]` carries no alignment guarantee, and building a `&[f32]`
    /// from a misaligned pointer is UB. The function therefore borrows in
    /// place only when `data` is 4-byte aligned; otherwise it copies through
    /// `f32::from_ne_bytes`. Trailing bytes that do not form a whole `f32`
    /// are ignored, matching the previous `data.len() / 4` truncation.
    fn f32_view_of_bytes(data: &[u8]) -> std::borrow::Cow<'_, [f32]> {
        let n = data.len() / std::mem::size_of::<f32>();
        if (data.as_ptr() as usize) % std::mem::align_of::<f32>() == 0 {
            // SAFETY:
            // - the pointer is aligned for `f32` (checked above);
            // - `n * 4 <= data.len()`, so the view stays inside the borrowed
            //   buffer;
            // - every bit pattern is a valid `f32`;
            // - the returned borrow shares `data`'s lifetime.
            let s = unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f32>(), n) };
            std::borrow::Cow::Borrowed(s)
        } else {
            std::borrow::Cow::Owned(
                data.chunks_exact(4)
                    .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                    .collect(),
            )
        }
    }

    impl Backend for CudaBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            CUDA_SUPPORTED_OPS
        }

        /// The full CUDA pipeline runs these steps in order:
        /// 1. unfuse;
        /// 2. legalize (panicking with a formatted report on failure);
        /// 3. optional DCE / constant folding;
        /// 4. shared backend-aware fusion;
        /// 5. optional AutoMixedPrecision.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            use rlx_opt::pass::Pass as _;
            // Decompose FusedSwiGLU / FAB / etc. before legalization (CudaExecutable
            // unfuses again; this pass is idempotent).
            let graph = rlx_cuda::unfuse::unfuse(graph);
            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, CUDA_SUPPORTED_OPS)
                .unwrap_or_else(|errors| {
                    panic!("{}", rlx_opt::format_legalize_error("cuda", &errors));
                });
            let graph = if options.dce {
                rlx_opt::DeadCodeElimination.run(graph)
            } else {
                graph
            };
            let graph = if options.constant_folding {
                rlx_opt::ConstantFolding.run(graph)
            } else {
                graph
            };
            // Backend-aware fusion via the shared compile pipeline.
            let compile_result = crate::stages::compile_graph_stages_for_backend(
                rlx_driver::Device::Cuda,
                graph,
                options,
                CUDA_SUPPORTED_OPS,
            );
            crate::stages::maybe_log_fusion(&compile_result.fusion);
            let graph = compile_result.lir.into_graph();
            let graph = match options.policy.clone() {
                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
                None => graph,
            };
            Box::new(CudaExecutableWrapper {
                inner: CudaExecutable::compile(graph),
            })
        }

        /// Lower an already-staged LIR module: unfuse, then run the shared
        /// fused-graph preparation for CUDA before building the executable.
        fn compile_lir(
            &self,
            lir: LirModule,
            options: &CompileOptions,
        ) -> Box<dyn ExecutableGraph> {
            let graph = prepare_fused_graph(
                rlx_cuda::unfuse::unfuse(lir.into_graph()),
                options,
                CUDA_SUPPORTED_OPS,
                "cuda",
            );
            Box::new(CudaExecutableWrapper {
                inner: CudaExecutable::compile(graph),
            })
        }
    }

    /// Adapter exposing `CudaExecutable` through the `ExecutableGraph` trait.
    struct CudaExecutableWrapper {
        inner: CudaExecutable,
    }

    // CudaExecutable owns CudaContext + CudaSlice handles; cudarc claims
    // they're Send (CudaContext is Arc-wrapped, CudaSlice is logically
    // a device pointer + length). The Backend trait requires Send for
    // the executable; we honor that here.
    unsafe impl Send for CudaExecutableWrapper {}

    impl ExecutableGraph for CudaExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }

        /// Typed param upload, with the routing chosen by `dtype`:
        /// - U8/I8: packed bytes go straight to `set_param_bytes`.
        /// - F32: read without assuming the byte buffer is 4-byte aligned.
        /// - Anything else (e.g. F16/BF16): widened to f32 before
        ///   `set_param`.
        ///
        /// CUDA's arena is f32-uniform; the half-precision matmul tier opts
        /// in via the separate `set_param_half` API.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            match dtype {
                rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
                    self.inner.set_param_bytes(name, data);
                }
                rlx_ir::DType::F32 => {
                    let view = f32_view_of_bytes(data);
                    self.inner.set_param(name, &*view);
                }
                _ => {
                    let f32_buf = super::widen_bytes_to_f32(data, dtype);
                    self.inner.set_param(name, &f32_buf);
                }
            }
        }

        /// Typed run.
        ///
        /// Each typed input is widened to F32 (F32 inputs are decoded
        /// alignment-safely). After the run, each output is narrowed back to
        /// its declared graph dtype. Outputs beyond the declared dtype list
        /// are treated as F32.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let owned: Vec<(String, Vec<f32>)> = inputs
                .iter()
                .map(|(name, data, dt)| {
                    let v = if *dt == rlx_ir::DType::F32 {
                        f32_view_of_bytes(data).into_owned()
                    } else {
                        super::widen_bytes_to_f32(data, *dt)
                    };
                    (name.to_string(), v)
                })
                .collect();
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
                .collect()
        }
    }
}
2207
2208// ── ROCm Backend ────────────────────────────────────────────────────────
2209
2210#[cfg(feature = "rocm")]
2211pub mod rocm_backend {
2212    use super::*;
2213    use rlx_rocm::backend::RocmExecutable;
2214
2215    pub struct RocmBackend;
2216
2217    /// PLAN L4: ROCm is the sister crate of CUDA; identical Step
2218    /// enum + dispatch shape → identical claimed op set.
2219    const ROCM_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
2220        use rlx_ir::OpKind::*;
2221        &[
2222            Input,
2223            Param,
2224            Constant,
2225            Activation,
2226            Cast,
2227            Binary,
2228            Compare,
2229            Where,
2230            ElementwiseRegion,
2231            MatMul,
2232            DotGeneral,
2233            LayerNorm,
2234            RmsNorm,
2235            Attention,
2236            AttentionBackward,
2237            Rope,
2238            Reshape,
2239            Transpose,
2240            Narrow,
2241            Concat,
2242            Expand,
2243            Gather,
2244            Reduce,
2245            Softmax,
2246            Cumsum,
2247            TopK,
2248            Sample,
2249            Conv,
2250            Pool,
2251            GroupedMatMul,
2252            DequantGroupedMatMul,
2253            DequantMoEWeights,
2254            ScatterAdd,
2255            DequantMatMul,
2256            SelectiveScan,
2257            FusedMatMulBiasAct,
2258            FusedResidualLN,
2259            FusedResidualRmsNorm,
2260            GaussianSplatRender,
2261            GaussianSplatRenderBackward,
2262            GaussianSplatPrepare,
2263            GaussianSplatRasterize,
2264            Custom,
2265            Fft,
2266        ]
2267    };
2268
2269    impl Backend for RocmBackend {
2270        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2271            ROCM_SUPPORTED_OPS
2272        }
2273
2274        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2275            let graph = rlx_opt::rewrite_for_backend(graph, ROCM_SUPPORTED_OPS);
2276            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, ROCM_SUPPORTED_OPS) {
2277                panic!("{}", rlx_opt::format_legalize_error("rocm", &errors));
2278            }
2279            use rlx_opt::pass::Pass as _;
2280            let graph = if options.dce {
2281                rlx_opt::DeadCodeElimination.run(graph)
2282            } else {
2283                graph
2284            };
2285            let graph = if options.constant_folding {
2286                rlx_opt::ConstantFolding.run(graph)
2287            } else {
2288                graph
2289            };
2290            let compile_result = crate::stages::compile_graph_stages_for_backend(
2291                rlx_driver::Device::Rocm,
2292                graph,
2293                options,
2294                ROCM_SUPPORTED_OPS,
2295            );
2296            crate::stages::maybe_log_fusion(&compile_result.fusion);
2297            let graph = compile_result.lir.into_graph();
2298            let graph = match options.policy.clone() {
2299                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2300                None => graph,
2301            };
2302            Box::new(RocmExecutableWrapper {
2303                inner: RocmExecutable::compile(graph),
2304            })
2305        }
2306
2307        fn compile_lir(
2308            &self,
2309            lir: LirModule,
2310            options: &CompileOptions,
2311        ) -> Box<dyn ExecutableGraph> {
2312            let graph = prepare_fused_graph(lir.into_graph(), options, ROCM_SUPPORTED_OPS, "rocm");
2313            Box::new(RocmExecutableWrapper {
2314                inner: RocmExecutable::compile(graph),
2315            })
2316        }
2317    }
2318
    /// Adapts a `RocmExecutable` to the backend-agnostic `ExecutableGraph`
    /// trait; trait calls are delegated to `inner` (typed variants convert
    /// host bytes first).
    struct RocmExecutableWrapper {
        /// The compiled ROCm program this wrapper delegates to.
        inner: RocmExecutable,
    }

    // SAFETY: Same Send-claim shape as CudaExecutableWrapper. RocmExecutable
    // owns Arc<RocmContext> + HipBuffer handles; the HipRuntime bundle
    // is internally thread-safe per AMD's documentation.
    // NOTE(review): `Send` only requires that ownership may move to another
    // thread, so this claim additionally assumes nothing inside
    // `RocmExecutable` is tied to the thread that created it (e.g. a
    // per-thread current device/stream) — confirm against rlx_rocm.
    unsafe impl Send for RocmExecutableWrapper {}
2327
2328    impl ExecutableGraph for RocmExecutableWrapper {
2329        fn set_param(&mut self, name: &str, data: &[f32]) {
2330            self.inner.set_param(name, data);
2331        }
2332        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2333            self.inner.run(inputs)
2334        }
2335        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2336            self.inner.set_active_extent(extent);
2337        }
2338
2339        /// Typed param upload — widens F16/BF16 host bytes to f32
2340        /// before routing through `set_param`. ROCm's arena is
2341        /// f32-uniform; the half-precision matmul tier opts in via
2342        /// the separate `set_param_half` API.
2343        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2344            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2345                self.inner.set_param_bytes(name, data);
2346                return;
2347            }
2348            if dtype == rlx_ir::DType::F32 {
2349                let n = data.len() / 4;
2350                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2351                self.inner.set_param(name, s);
2352            } else {
2353                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2354                self.inner.set_param(name, &f32_buf);
2355            }
2356        }
2357
2358        /// Typed run — widen each typed input to F32, run, then narrow
2359        /// each output back to its declared graph dtype.
2360        fn run_typed(
2361            &mut self,
2362            inputs: &[(&str, &[u8], rlx_ir::DType)],
2363        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2364            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2365            for (name, data, dt) in inputs {
2366                let v = super::widen_bytes_to_f32(data, *dt);
2367                owned.push((name.to_string(), v));
2368            }
2369            let refs: Vec<(&str, &[f32])> = owned
2370                .iter()
2371                .map(|(n, d)| (n.as_str(), d.as_slice()))
2372                .collect();
2373            let dtypes = self.inner.output_dtypes();
2374            let outs = self.inner.run(&refs);
2375            outs.into_iter()
2376                .zip(
2377                    dtypes
2378                        .into_iter()
2379                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2380                )
2381                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2382                .collect()
2383        }
2384    }
2385}
2386
2387// ── TPU Backend ─────────────────────────────────────────────────────────
2388
#[cfg(feature = "tpu")]
pub mod tpu_backend {
    use super::*;
    use rlx_tpu::TpuExecutable;

    /// Backend that compiles RLX graphs through `rlx_tpu` (HLO emission,
    /// executed through PJRT).
    pub struct TpuBackend;

    /// Ops the TPU backend lowers to HLO. Full inference parity with
    /// rlx-cuda / rlx-rocm. Composite ops (FusedSwiGLU /
    /// FusedAttentionBlock / FusedTransformerLayer / LoraMatMul / If /
    /// While) are unfused inside `rlx_tpu::unfuse::unfuse` ahead of
    /// HLO emission, so they don't appear here.
    const TPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            // Real-INT8 path + fake-quant.
            QMatMul,
            QConv2d,
            Quantize,
            Dequantize,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            Fft,
            // Splat: no on-chip kernel — lowered to common primitive MIR via logical_kernel.
        ]
    };

    impl Backend for TpuBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            TPU_SUPPORTED_OPS
        }

        /// Compile `graph` for TPU.
        ///
        /// # Panics
        /// Panics with a formatted report if the graph cannot be legalized
        /// or rewritten into `TPU_SUPPORTED_OPS`.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
                graph,
                TPU_SUPPORTED_OPS,
                options.kernel_dispatch,
            )
            .unwrap_or_else(|errors| {
                panic!("{}", rlx_opt::format_legalize_error("tpu", &errors));
            });
            // The TPU's IR-side pass pipeline (DCE, ConstFold,
            // FuseResidualLN, FuseMatMulBiasAct, LegalizeBroadcast,
            // MarkElementwiseRegions) lives inside
            // `TpuExecutable::compile` so the same passes run whether
            // a caller goes through Session or invokes the executable
            // directly. We only do backend-cross-cutting work here:
            // legalization (must precede the pipeline so we panic
            // early on unsupported ops) and AutoMixedPrecision.
            //
            // Default policy on TPU is `AutoMixedBf16`: BF16 is the
            // native compute dtype on TPU silicon and recent GPUs,
            // and XLA's CPU plugin handles it natively too. Callers
            // can opt out by passing an explicit `PrecisionPolicy`
            // (e.g. `AlwaysF32` for accuracy debugging or
            // `AlwaysF16` to match a CUDA workload's choice).
            use rlx_opt::pass::Pass as _;
            let policy = options
                .policy
                .clone()
                .unwrap_or(rlx_opt::PrecisionPolicy::AutoMixedBf16);
            let graph = rlx_opt::AutoMixedPrecision::new(policy).run(graph);
            // Deliberately ignored: DCE and constant folding always run in
            // the pipeline inside `TpuExecutable::compile` (see above), so
            // these flags have no effect on this backend.
            let _ = options.dce;
            let _ = options.constant_folding;
            Box::new(TpuExecutableWrapper {
                inner: TpuExecutable::compile(graph),
            })
        }
    }

    /// View native-endian `f32` bytes as `[f32]` without assuming alignment.
    ///
    /// A `&[u8]` is only guaranteed 1-byte alignment, so casting its pointer
    /// to `*const f32` and calling `slice::from_raw_parts` is undefined
    /// behaviour whenever the buffer does not start on a 4-byte boundary.
    /// This borrows zero-copy when the buffer happens to be aligned and
    /// falls back to a copying decode otherwise. Trailing bytes that do not
    /// form a whole `f32` are ignored (the same `len / 4` truncation as
    /// before).
    fn f32_view_of_bytes(data: &[u8]) -> std::borrow::Cow<'_, [f32]> {
        let whole = &data[..data.len() - data.len() % 4];
        // SAFETY: every 4-byte bit pattern is a valid `f32`, and `align_to`
        // only places correctly aligned, complete elements in the middle.
        let (prefix, middle, suffix) = unsafe { whole.align_to::<f32>() };
        if prefix.is_empty() && suffix.is_empty() {
            std::borrow::Cow::Borrowed(middle)
        } else {
            std::borrow::Cow::Owned(
                whole
                    .chunks_exact(4)
                    .map(|b| f32::from_ne_bytes([b[0], b[1], b[2], b[3]]))
                    .collect(),
            )
        }
    }

    /// Adapts a `TpuExecutable` to the backend-agnostic `ExecutableGraph`
    /// trait; trait calls are delegated to `inner`.
    struct TpuExecutableWrapper {
        inner: TpuExecutable,
    }

    // SAFETY: PJRT clients + buffers are documented as thread-safe per the
    // upstream C API. Same Send-claim shape as CudaExecutableWrapper /
    // RocmExecutableWrapper.
    unsafe impl Send for TpuExecutableWrapper {}

    impl ExecutableGraph for TpuExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }

        /// Typed param upload. F32 bytes are reinterpreted (alignment-safe);
        /// F16/BF16/etc. host bytes are widened to f32 today. Once the HLO
        /// emitter speaks bf16 natively (which TPUs prefer over f16), the
        /// typed path will hand the original bytes straight through
        /// `Buffer_FromHostBuffer`.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            if dtype == rlx_ir::DType::F32 {
                let values = f32_view_of_bytes(data);
                self.inner.set_param(name, &*values);
            } else {
                let f32_buf = super::widen_bytes_to_f32(data, dtype);
                self.inner.set_param(name, &f32_buf);
            }
        }

        /// Typed run — widen each typed input to F32, run, then narrow
        /// each output back to its declared graph dtype. Outputs beyond
        /// the reported dtype list are returned as F32.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let owned: Vec<(String, Vec<f32>)> = inputs
                .iter()
                .map(|&(name, data, dt)| {
                    // F32 inputs use the alignment-safe decode; other dtypes
                    // go through the shared widening helper.
                    let values = if dt == rlx_ir::DType::F32 {
                        f32_view_of_bytes(data).into_owned()
                    } else {
                        super::widen_bytes_to_f32(data, dt)
                    };
                    (name.to_string(), values)
                })
                .collect();
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
                .collect()
        }
    }
}