Skip to main content

rlx_runtime/
backend.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backend trait — abstraction over CPU/GPU/CUDA execution.
17//!
18//! Each backend implements `Backend::compile(graph, &CompileOptions)` and
19//! returns an `ExecutableGraph`. New compile knobs go in `CompileOptions`
20//! rather than as new trait methods.
21
22use crate::CompileOptions;
23use rlx_ir::Graph;
24use rlx_ir::hir::HirModule;
25use rlx_ir::lir::LirModule;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29// ── Typed I/O helpers (shared across f32-arena backends) ────────────────
30
31/// Widen a typed byte buffer to `Vec<f32>`. Used by `set_param_typed` /
32/// `run_typed` overrides on backends whose internal arena is f32-uniform
33/// (CPU, Metal, wgpu) so callers can hand in F16/BF16 without doing the
34/// host-side cast themselves. Panics on dtypes the f32 arena can't carry.
35#[allow(dead_code)]
36pub(crate) fn widen_bytes_to_f32(data: &[u8], dtype: rlx_ir::DType) -> Vec<f32> {
37    use rlx_ir::DType;
38    match dtype {
39        DType::F32 => {
40            let n = data.len() / 4;
41            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
42            s.to_vec()
43        }
44        DType::F16 => {
45            let n = data.len() / 2;
46            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
47            s.iter().map(|h| h.to_f32()).collect()
48        }
49        DType::BF16 => {
50            let n = data.len() / 2;
51            let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n) };
52            s.iter().map(|h| h.to_f32()).collect()
53        }
54        other => panic!(
55            "widen_bytes_to_f32: dtype {other:?} unsupported on f32-arena backends \
56             (only F32/F16/BF16 are accepted on the host I/O surface)"
57        ),
58    }
59}
60
61/// Narrow a `&[f32]` buffer down to the declared output dtype, returning
62/// the corresponding little-endian byte stream. Mirrors the bytes a
63/// backend that stores the native dtype would emit. Used by `run_typed`
64/// to keep the byte-level output contract identical across backends.
65#[allow(dead_code)]
66pub(crate) fn narrow_f32_to_bytes(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
67    use rlx_ir::DType;
68    match dt {
69        DType::F32 => {
70            let mut bytes = Vec::with_capacity(v.len() * 4);
71            for &x in v {
72                bytes.extend_from_slice(&x.to_le_bytes());
73            }
74            bytes
75        }
76        DType::F16 => {
77            let mut bytes = Vec::with_capacity(v.len() * 2);
78            for &x in v {
79                bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
80            }
81            bytes
82        }
83        DType::BF16 => {
84            let mut bytes = Vec::with_capacity(v.len() * 2);
85            for &x in v {
86                bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
87            }
88            bytes
89        }
90        DType::F64 => {
91            let mut bytes = Vec::with_capacity(v.len() * 8);
92            for &x in v {
93                bytes.extend_from_slice(&(x as f64).to_le_bytes());
94            }
95            bytes
96        }
97        DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
98        DType::U8 => v.iter().map(|&x| x as u8).collect(),
99        DType::I16 => {
100            let mut bytes = Vec::with_capacity(v.len() * 2);
101            for &x in v {
102                bytes.extend_from_slice(&(x as i16).to_le_bytes());
103            }
104            bytes
105        }
106        DType::I32 => {
107            let mut bytes = Vec::with_capacity(v.len() * 4);
108            for &x in v {
109                bytes.extend_from_slice(&(x as i32).to_le_bytes());
110            }
111            bytes
112        }
113        DType::U32 => {
114            let mut bytes = Vec::with_capacity(v.len() * 4);
115            for &x in v {
116                bytes.extend_from_slice(&(x as u32).to_le_bytes());
117            }
118            bytes
119        }
120        DType::I64 => {
121            let mut bytes = Vec::with_capacity(v.len() * 8);
122            for &x in v {
123                bytes.extend_from_slice(&(x as i64).to_le_bytes());
124            }
125            bytes
126        }
127        DType::Bool => v
128            .iter()
129            .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
130            .collect(),
131        DType::C64 => {
132            // Complex narrow path: real part = the f32 value, imaginary
133            // part = 0. Mirrors how the backend stores narrowed f32
134            // operands when promoted to a complex op input.
135            let mut bytes = Vec::with_capacity(v.len() * 8);
136            for &x in v {
137                bytes.extend_from_slice(&x.to_le_bytes());
138                bytes.extend_from_slice(&0.0_f32.to_le_bytes());
139            }
140            bytes
141        }
142    }
143}
144
145/// A compiled, ready-to-execute graph on a specific backend.
146pub trait ExecutableGraph: Send {
147    /// Set a named parameter (weight) buffer.
148    fn set_param(&mut self, name: &str, data: &[f32]);
149
150    /// Deep-clone this executable into a fresh `Box`. Lets
151    /// `CompiledGraph` implement `Clone` so callers (e.g. eda-mna's
152    /// `SensitivityContext`) can spin up N independent executor
153    /// copies for thread-parallel dispatch without paying the full
154    /// graph-compile cost N times. Default implementation panics;
155    /// backends that support cloning override.
156    fn clone_box(&self) -> Box<dyn ExecutableGraph> {
157        panic!("clone_box not implemented for this backend");
158    }
159
160    /// Execute the graph with named inputs. Returns output data (copies from arena).
161    fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>>;
162
163    /// Execute and return raw pointers to output data in arena (zero-copy).
164    fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
165        let vecs = self.run(inputs);
166        vecs.iter().map(|v| (v.as_ptr(), v.len())).collect()
167    }
168
169    /// Fastest: inputs by slot index, returns output (offset, len) pairs.
170    /// Read output from arena via `arena_ptr().add(offset)`.
171    fn run_slots(&mut self, _inputs: &[&[f32]]) -> &[(usize, usize)] {
172        &[] // default: not supported
173    }
174
175    /// Get the raw arena buffer pointer for reading outputs after run_slots.
176    fn arena_ptr(&self) -> *const u8 {
177        std::ptr::null()
178    }
179
180    /// Hint the executor that subsequent `run` calls should process
181    /// only the first `actual` rows along the bucket axis (out of
182    /// `upper`, the extent the graph was compiled at). Backends that
183    /// support per-kernel active-extent dispatch honor this; others
184    /// ignore it and process the full compiled extent.
185    ///
186    /// Pass `None` to clear the hint. The hint is sticky — set it
187    /// before each `run` and clear it after, or maintain it across
188    /// runs at your discretion.
189    ///
190    /// Even when honored, callers must not rely on the contents of the
191    /// output past `actual` rows — that region may contain stale data
192    /// from earlier runs (kernels skip it).
193    ///
194    /// Default: no-op. See `BucketedCompileCache::run_padded` for the
195    /// canonical caller; backends opt in by overriding this method.
196    fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
197        let _ = extent;
198    }
199
200    /// TIDE merged placement mask (union across MoE layers). CPU: stats + host path.
201    fn set_moe_resident_experts(&mut self, _mask: &[bool]) {}
202
203    /// Per MoE layer placement (`masks[layer][expert]`). Preferred over merged on CPU.
204    fn set_moe_resident_experts_per_layer(&mut self, _masks: &[&[bool]]) {}
205
206    /// Capture MoE router TopK indices on the next CPU forward (TIDE refresh).
207    fn enable_moe_topk_capture(&mut self, _num_experts: usize) -> bool {
208        false
209    }
210
211    /// Take captured per-layer expert indices (one vec per MoE TopK in order).
212    fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
213        None
214    }
215
216    /// MoE GroupedMatMul residency accounting from the last forward (CPU).
217    fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
218        None
219    }
220
221    /// Bind a persistent buffer handle (KV-cache, training state, etc.).
222    /// The buffer lives across run() calls and is not in the arena.
223    /// Returns true if the backend supports persistent handles.
224    fn bind_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
225        false
226    }
227
228    /// Read a persistent buffer's current contents.
229    fn read_handle(&self, _name: &str) -> Option<Vec<f32>> {
230        None
231    }
232
233    /// GPU-resident input (MLX): upload once, reuse across runs.
234    fn bind_gpu_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
235        false
236    }
237
238    fn has_gpu_handle(&self, _name: &str) -> bool {
239        false
240    }
241
242    fn set_gpu_handle_feed(&mut self, _handle_name: &str, _output_index: usize) -> bool {
243        false
244    }
245
246    fn read_gpu_handle(&self, _name: &str) -> Option<Vec<f32>> {
247        None
248    }
249
250    /// Run and refresh a GPU handle from `output_index`; returns that output on host.
251    fn run_feed_gpu_handle(
252        &mut self,
253        inputs: &[(&str, &[f32])],
254        _handle_name: &str,
255        _output_index: usize,
256    ) -> Option<Vec<f32>> {
257        let _ = inputs;
258        None
259    }
260
261    // ── Pipelined / async execution (Phase C) ─────────────────────────
262    //
263    // These allow callers to amortize per-run sync latency on backends
264    // where it matters (Metal: ~150 µs `wait_until_completed` per commit).
265    // CPU has no such cost, so the default impls just call `run` serially.
266
267    /// Encode + commit a forward pass without waiting for completion.
268    ///
269    /// Outputs of intermediate calls are stomped — use `run_pipelined` if
270    /// you need outputs from each individual commit. Pair with
271    /// `sync_pending` to drain.
272    ///
273    /// Default: synchronous fallback (calls `run`, discards output). CPU
274    /// uses this default since BLAS is synchronous anyway.
275    fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
276        let _ = self.run(inputs);
277    }
278
279    /// Wait for every command queued by `commit_no_wait`.
280    /// Default: no-op (synchronous backends have nothing pending).
281    fn sync_pending(&mut self) {}
282
283    /// Issue a batch of forward passes pipelined, returning per-run outputs.
284    ///
285    /// The Metal impl encodes a per-commit blit so each in-flight run's
286    /// outputs survive subsequent commits stomping the shared arena. The
287    /// CPU default is just sequential `run`s — equally correct, no perf
288    /// penalty (CPU has no GPU sync cost to amortize).
289    ///
290    /// Returns `out[run_idx][output_idx][element_idx]`.
291    fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
292        input_sets.iter().map(|inputs| self.run(inputs)).collect()
293    }
294
295    // ── Typed (non-F32) host I/O ──────────────────────────────────
296    //
297    // `set_param` and `run` are F32 by contract. The typed entry
298    // points let callers pass and receive raw bytes in any rlx-ir
299    // dtype, avoiding the f32 widen/narrow round-trip that's
300    // wasteful for F16/BF16 weights and activations.
301    //
302    // The default impls only handle F32 — any other dtype panics.
303    // Backends that support typed I/O natively (e.g. MLX via
304    // Array::from_bytes/to_bytes) override these.
305
306    /// Set a named parameter from raw bytes in the given dtype.
307    fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
308        if dtype != rlx_ir::DType::F32 {
309            panic!(
310                "backend's default set_param_typed only handles F32; \
311                    got {dtype:?}. Override on the backend for typed support."
312            );
313        }
314        if !data.len().is_multiple_of(4) {
315            panic!(
316                "set_param_typed F32: data length {} not a multiple of 4",
317                data.len()
318            );
319        }
320        // SAFETY: F32 bytes are 4-aligned by source convention; we
321        // only widen access (read &[f32] from owned &[u8]). Failure
322        // mode if a caller hands us mis-aligned bytes is undefined,
323        // hence the % 4 length check.
324        let n = data.len() / 4;
325        let f32_slice = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
326        self.set_param(name, f32_slice);
327    }
328
329    /// Run with typed inputs and typed outputs. Returns
330    /// `(bytes, dtype)` per output; the dtype is whatever the
331    /// graph's output node was declared as.
332    fn run_typed(
333        &mut self,
334        inputs: &[(&str, &[u8], rlx_ir::DType)],
335    ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
336        // Default impl: convert each typed input to f32 (F32-only),
337        // run, then re-emit outputs as F32 bytes.
338        let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
339        for (name, data, dt) in inputs {
340            if *dt != rlx_ir::DType::F32 {
341                panic!(
342                    "backend's default run_typed only handles F32 inputs; \
343                        got {dt:?} for input '{name}'"
344                );
345            }
346            if data.len() % 4 != 0 {
347                panic!(
348                    "run_typed F32 input '{name}': len {} not multiple of 4",
349                    data.len()
350                );
351            }
352            let n = data.len() / 4;
353            let v: Vec<f32> =
354                unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }.to_vec();
355            owned.push((name.to_string(), v));
356        }
357        let refs: Vec<(&str, &[f32])> = owned
358            .iter()
359            .map(|(n, d)| (n.as_str(), d.as_slice()))
360            .collect();
361        let outs = self.run(&refs);
362        outs.into_iter()
363            .map(|v| {
364                let bytes =
365                    unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 4) }
366                        .to_vec();
367                (bytes, rlx_ir::DType::F32)
368            })
369            .collect()
370    }
371}
372
373/// Backend implementation trait.
374///
375/// Single compile entry point. New compile-time knobs are added to
376/// `CompileOptions`, not as new trait methods.
377///
378/// `Send + Sync` because backends are stateless factories — multiple
379/// threads can call `compile` concurrently. The returned
380/// `Box<dyn ExecutableGraph>` is `Send` (moveable to a worker thread)
381/// but **not** `Sync` (`run`/`run_slots` take `&mut self`).
382pub trait Backend: Send + Sync {
383    /// Compile a graph for this backend with the given options.
384    fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph>;
385
386    /// Compile pre-optimized LIR (HIR → MIR → LIR pipeline output).
387    /// Default re-enters [`Self::compile`] — backends should override
388    /// when they can reuse the embedded buffer plan.
389    fn compile_lir(&self, lir: LirModule, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
390        self.compile(lir.into_graph(), options)
391    }
392
393    /// HIR-first compile: lower blocks, run fusion pipeline, emit executable.
394    fn compile_hir(
395        &self,
396        hir: HirModule,
397        device: rlx_driver::Device,
398        options: &CompileOptions,
399    ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
400        let result = crate::stages::compile_hir_stages(device, hir, options)?;
401        crate::stages::maybe_log_fusion(&result.fusion);
402        Ok(self.compile_lir(result.lir, options))
403    }
404
405    /// [`GraphModule`] compile — unified HIR/MIR/LIR entry.
406    fn compile_module(
407        &self,
408        module: rlx_ir::GraphModule,
409        device: rlx_driver::Device,
410        options: &CompileOptions,
411    ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
412        let result = crate::stages::compile_module_stages(device, module, options)?;
413        crate::stages::maybe_log_fusion(&result.fusion);
414        Ok(self.compile_lir(result.lir, options))
415    }
416
417    /// PLAN L4: declare which `OpKind`s this backend can lower.
418    /// Default: empty slice = "no claim made — accept everything"
419    /// (preserves existing behavior; backends opt in by overriding).
420    /// When non-empty, the `LegalizeForBackend` pass will refuse to
421    /// compile a graph that contains an op outside this set, instead
422    /// of silently falling through to slower / wrong dispatch.
423    fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
424        &[]
425    }
426}
427
428/// Prepare a fused MIR graph from LIR for backend executable construction.
429/// Skips the fusion pipeline — LIR must come from `compile_*_stages`.
430#[allow(dead_code)]
431fn prepare_fused_graph(
432    graph: Graph,
433    options: &CompileOptions,
434    supported_ops: &[rlx_ir::OpKind],
435    backend_name: &str,
436) -> Graph {
437    let (mut graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
438        graph,
439        backend_name,
440        supported_ops,
441        options.kernel_dispatch,
442    );
443    rlx_opt::maybe_log_dispatch_report(&report);
444    if !report.compile_ready {
445        panic!(
446            "{}\n{}",
447            rlx_opt::format_legalize_error(backend_name, &report.still_unsupported),
448            rlx_opt::format_dispatch_report(&report)
449        );
450    }
451    use rlx_opt::pass::Pass as _;
452    if options.dce {
453        graph = rlx_opt::DeadCodeElimination.run(graph);
454    }
455    if options.constant_folding {
456        graph = rlx_opt::ConstantFolding.run(graph);
457    }
458    if let Some(p) = options.policy.clone() {
459        graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
460    }
461    graph
462}
463
464// ── Convenience helpers preserved from older API ──────────────────────
465//
466// These let existing call sites keep working unchanged while the new
467// trait is the canonical one. We provide free functions rather than
468// trait methods so adding them doesn't grow the trait surface.
469
470/// Compile at default options (F32, no policy).
471pub fn compile(backend: &dyn Backend, graph: Graph) -> Box<dyn ExecutableGraph> {
472    backend.compile(graph, &CompileOptions::default())
473}
474
475/// Compile HIR through the fusion-first pipeline.
476pub fn compile_hir(
477    backend: &dyn Backend,
478    hir: HirModule,
479    device: rlx_driver::Device,
480    options: &CompileOptions,
481) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
482    backend.compile_hir(hir, device, options)
483}
484
485/// Compile a [`GraphModule`] through the fusion-first pipeline.
486pub fn compile_module(
487    backend: &dyn Backend,
488    module: rlx_ir::GraphModule,
489    device: rlx_driver::Device,
490    options: &CompileOptions,
491) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
492    backend.compile_module(module, device, options)
493}
494
495/// Compile at a specific precision (default policy = none).
496pub fn compile_with_precision(
497    backend: &dyn Backend,
498    graph: Graph,
499    precision: crate::Precision,
500) -> Box<dyn ExecutableGraph> {
501    backend.compile(graph, &CompileOptions::new().precision(precision))
502}
503
504/// Helper retained for backward compatibility — applies the precision
505/// rewrite at the runtime layer if backends don't override their
506/// pipeline placement. Modern code: pass the policy via CompileOptions
507/// and let the backend handle ordering.
508fn _legacy_apply_policy(graph: Graph, policy: Option<rlx_opt::PrecisionPolicy>) -> Graph {
509    use rlx_opt::pass::Pass as _;
510    match policy {
511        Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
512        None => graph,
513    }
514}
515
516// ── CPU Backend ─────────────────────────────────────────────────────────
517
518#[cfg(feature = "cpu")]
519pub mod cpu_backend {
520    use super::*;
521    use rlx_cpu::{arena::Arena, thunk};
522    use rlx_ir::{DType, NodeId, Op};
523    use rlx_opt::memory::{self, MemoryPlan};
524    // Arena typed read/write helpers live in `crate::arena` so every
525    // backend (CPU, Metal, future CUDA/wgpu/WASM) shares one implementation.
526    use rlx_driver::arena::{read_typed_to_f32, write_typed_from_f32};
527
    /// CPU implementation of the `Backend` trait.
    ///
    /// A field-less unit struct: it carries no state of its own, and all
    /// behaviour lives in its `Backend` impl in this module.
    pub struct CpuBackend;
529
    /// PLAN L4: the `OpKind`s the CPU backend claims it can lower today.
    ///
    /// Returned from `CpuBackend::supported_ops` and passed to
    /// `legalize_for_backend` / `compile_graph_stages_for_backend` inside
    /// `CpuBackend::compile`.
    ///
    /// Includes DotGeneral (lowered via `LowerDotGeneral` pass) and
    /// ElementwiseRegion (lowered natively per L2). Excludes
    /// FusedTransformerLayer / If / While — those have IR variants
    /// but no CPU lowering yet (see `compile_thunks` arm absence +
    /// `subgraph.rs` "If/While executor wiring is pending" note).
    /// `CpuBackend::compile` runs `LowerControlFlow` before legalizing, so
    /// If/While are rewritten away before this list is consulted.
    const CPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            Scan,
            ScanBackward,
            ScanBackwardXs,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            // Backward ops emitted by `rlx_opt::autodiff::grad_with_loss`.
            // Their thunks live in rlx-cpu/src/thunk.rs alongside the
            // forward kernels; without these entries the legalize step
            // in `CpuBackend::compile` would reject any compiled gradient
            // graph.
            ReluBackward,
            ActivationBackward,
            FakeQuantize,
            FakeQuantizeBackward,
            MaxPool2dBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            // 3D Gaussian splat CPU reference render/backward (requires `rlx-cpu/splat`).
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            // User-registered custom ops dispatched through
            // `rlx_cpu::op_registry`. Lowering panics with a clear
            // message if the named CPU kernel isn't registered.
            Custom,
            // User-defined sub-graph with optional override AD rules
            // (JAX-shaped custom_vjp / custom_jvp). Body is a regular
            // Graph compiled recursively in compile_thunks.
            CustomFn,
            // FFT primitive (1D last-axis, 2N real-block layout, f64
            // power-of-2 sizes). Other backends panic at lowering;
            // pin FFT-containing graphs to Device::Cpu for now.
            Fft,
            // C64 Wirtinger AD surface. ComplexNormSq is the canonical
            // real-valued loss for complex inputs; Conjugate is emitted
            // by the new Wirtinger VJP rules for BinaryOp::Mul/Div on
            // C64. Both have CPU thunks in rlx-cpu.
            ComplexNormSq,
            ComplexNormSqBackward,
            Conjugate,
        ]
    };
638
639    impl Backend for CpuBackend {
640        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
641            CPU_SUPPORTED_OPS
642        }
643
644        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
645            use rlx_opt::pass::Pass as _;
646            // Lower Op::If / Op::While to primitives BEFORE legalize
647            // so the supported-op check doesn't reject them — the CPU
648            // backend has no native sub-graph executor; this rewrite
649            // makes If/While invisible to the rest of the pipeline.
650            // No-op when neither op is in the graph.
651            let graph = rlx_opt::LowerControlFlow.run(graph);
652            // PLAN L4: legalize against the backend's claimed op set
653            // BEFORE running fusion (so the diagnostic points at the
654            // user's IR, not at a fused-away node).
655            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CPU_SUPPORTED_OPS) {
656                panic!("{}", rlx_opt::format_legalize_error("cpu", &errors));
657            }
658            let policy = options.policy.clone();
659            let _precision = options.precision;
660            let cfg = rlx_cpu::config::RuntimeConfig::global();
661
662            // Optional preliminary cleanup passes
663            let graph = if options.dce {
664                rlx_opt::DeadCodeElimination.run(graph)
665            } else {
666                graph
667            };
668            let graph = if options.constant_folding {
669                rlx_opt::ConstantFolding.run(graph)
670            } else {
671                graph
672            };
673
674            // Run fusion pipeline (HIR/MIR/LIR ideology — fusion is first-class).
675            let mut compile_opts = options.clone();
676            compile_opts.arena_alignment = cfg.arena_alignment;
677            let compile_result = crate::stages::compile_graph_stages_for_backend(
678                rlx_driver::Device::Cpu,
679                graph,
680                &compile_opts,
681                CPU_SUPPORTED_OPS,
682            );
683            crate::stages::maybe_log_fusion(&compile_result.fusion);
684            let fused = compile_result.lir.into_graph();
685
686            // Apply precision policy AFTER fusion — Cast nodes don't disrupt
687            // the now-flattened fused ops.
688            let fused = match policy {
689                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(fused),
690                None => fused,
691            };
692
693            // Re-plan after precision rewrites (may change dtypes / sizes).
694            let plan = memory::plan_memory_aligned(&fused, cfg.arena_alignment);
695            if cfg.verbose >= 1 {
696                eprintln!(
697                    "[rlx] arena: {} bytes, {} buffers, alignment: {}",
698                    plan.arena_size,
699                    plan.assignments.len(),
700                    cfg.arena_alignment
701                );
702            }
703            Box::new(build_cpu_executable(fused, plan))
704        }
705
706        fn compile_lir(
707            &self,
708            lir: LirModule,
709            options: &CompileOptions,
710        ) -> Box<dyn ExecutableGraph> {
711            let alignment = lir.buffers.alignment.max(options.arena_alignment);
712            let embedded: MemoryPlan = (&lir.buffers).into();
713            let mut graph = lir.into_graph();
714            if let Some(p) = options.policy.clone() {
715                use rlx_opt::pass::Pass;
716                graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
717            }
718            let plan = if options.policy.is_some() {
719                memory::plan_memory_aligned(&graph, alignment)
720            } else {
721                embedded
722            };
723            let cfg = rlx_cpu::config::RuntimeConfig::global();
724            if cfg.verbose >= 1 {
725                eprintln!(
726                    "[rlx] compile_lir: arena {} bytes ({} buffers, alignment {})",
727                    plan.arena_size,
728                    plan.assignments.len(),
729                    alignment,
730                );
731            }
732            Box::new(build_cpu_executable(graph, plan))
733        }
734    }
735
736    fn build_cpu_executable(graph: Graph, plan: MemoryPlan) -> CpuExecutable {
737        let mut arena = Arena::from_plan(plan);
738        let mut input_ids = HashMap::new();
739        let mut param_ids = HashMap::new();
740        let mut node_dtypes: HashMap<NodeId, DType> = HashMap::new();
741        for node in graph.nodes() {
742            node_dtypes.insert(node.id, node.shape.dtype());
743            match &node.op {
744                Op::Input { name } => {
745                    input_ids.insert(name.clone(), node.id);
746                }
747                Op::Param { name } => {
748                    param_ids.insert(name.clone(), node.id);
749                }
750                _ => {}
751            }
752        }
753
754        let schedule = thunk::compile_thunks(&graph, &arena);
755
756        let mut input_slots = Vec::new();
757        for node in graph.nodes() {
758            if let Op::Input { name } = &node.op {
759                let off = arena.byte_offset(node.id);
760                let len = node.shape.num_elements().unwrap_or(0);
761                input_slots.push((name.clone(), off, len, node.shape.dtype()));
762            }
763        }
764
765        let output_slots: Vec<(usize, usize)> = graph
766            .outputs
767            .iter()
768            .map(|&id| {
769                let off = arena.byte_offset(id);
770                let len = graph.node(id).shape.num_elements().unwrap_or(0);
771                (off, len)
772            })
773            .collect();
774
775        for node in graph.nodes() {
776            if let Op::Constant { data } = &node.op
777                && arena.has_buffer(node.id)
778                && !data.is_empty()
779            {
780                match node.shape.dtype() {
781                    DType::F64 => {
782                        let off = arena.byte_offset(node.id);
783                        let buf = arena.raw_buf_mut();
784                        let n = buf.len().saturating_sub(off).min(data.len());
785                        buf[off..off + n].copy_from_slice(&data[..n]);
786                    }
787                    _ => {
788                        let buf = arena.slice_mut(node.id);
789                        let n_floats = data.len() / 4;
790                        let n = buf.len().min(n_floats);
791                        for i in 0..n {
792                            let bytes = [
793                                data[i * 4],
794                                data[i * 4 + 1],
795                                data[i * 4 + 2],
796                                data[i * 4 + 3],
797                            ];
798                            buf[i] = f32::from_le_bytes(bytes);
799                        }
800                    }
801                }
802            }
803        }
804
805        CpuExecutable {
806            graph,
807            arena,
808            params: HashMap::new(),
809            input_ids,
810            param_ids,
811            node_dtypes,
812            schedule,
813            input_slots,
814            output_slots,
815            handles: HashMap::new(),
816            active_extent: None,
817            moe_resident: None,
818            moe_resident_layers: None,
819            moe_topk_capture: None,
820        }
821    }
822
    /// CPU executable: the graph, the arena holding its buffers, the
    /// compiled thunk schedule, and the per-instance state used by the
    /// `ExecutableGraph` methods.
    #[derive(Clone)]
    struct CpuExecutable {
        /// Graph this executable was built from; consulted for output ids
        /// and node shapes when reading results back.
        graph: Graph,
        /// Arena holding the node buffers (created via `Arena::from_plan`).
        arena: Arena,
        /// Fallback parameter storage: `set_param` stores data here when
        /// the named param has no arena buffer.
        params: HashMap<String, Vec<f32>>,
        /// `Op::Input` name → node id.
        input_ids: HashMap<String, NodeId>,
        /// `Op::Param` name → node id.
        param_ids: HashMap<String, NodeId>,
        /// Per-node arena dtype. Lets set_param/run cast f32 ↔ F16/BF16
        /// when AutoMixedPrecision has rewritten the graph.
        node_dtypes: HashMap<NodeId, DType>,
        /// Thunk schedule produced by `thunk::compile_thunks`.
        schedule: thunk::ThunkSchedule,
        /// Pre-resolved: ordered list of
        /// (input_name, arena_byte_offset, max_elems, dtype).
        input_slots: Vec<(String, usize, usize, DType)>,
        /// Output (byte_offset, num_elements). dtype is in node_dtypes.
        output_slots: Vec<(usize, usize)>,
        /// Persistent buffer handles (KV-cache, optimizer state, etc.).
        /// Lives outside the arena and survives across run() calls.
        /// On run(): if a handle's name matches a graph input, the
        /// handle's data is used as the input.
        handles: HashMap<String, Vec<f32>>,
        /// Active-extent hint (`Some((actual, upper))`) for L1 bucketed
        /// dispatch. When set AND every thunk in the schedule is in
        /// `Thunk::safe_for_active_extent`, the executor processes only
        /// `actual / upper` of each kernel's work. Otherwise (or when
        /// `None`) runs at the full compiled extent. See PLAN L1.
        active_extent: Option<(usize, usize)>,
        /// Single MoE residency mask set by `set_moe_resident_experts`;
        /// a copy of the same `Arc` is stored on `schedule`.
        moe_resident: Option<std::sync::Arc<[bool]>>,
        /// Per-layer MoE residency masks set by
        /// `set_moe_resident_experts_per_layer`; mirrored on `schedule`.
        /// The two setters clear each other's field.
        moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
        /// Top-k capture installed by `enable_moe_topk_capture` and
        /// drained by `take_moe_topk_capture`; mirrored on `schedule`.
        moe_topk_capture: Option<std::sync::Arc<rlx_cpu::moe_topk_capture::MoeTopkCapture>>,
    }
853
    // SAFETY: NOTE(review) — this asserts that a `CpuExecutable` may be moved
    // to another thread. Which field blocks the auto-derived `Send` is not
    // visible from here (presumably raw pointers inside `Arena` or the thunk
    // schedule); confirm that none of them alias thread-local state.
    unsafe impl Send for CpuExecutable {}
855
856    impl CpuExecutable {
857        /// Write a f32 input slice into the arena, casting to the node's dtype.
858        fn write_input(&mut self, id: NodeId, data: &[f32]) {
859            let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
860            let off = self.arena.byte_offset(id);
861            let buf = self.arena.raw_buf_mut();
862            let elem_size = dtype.size_bytes();
863            let max_elems = (buf.len() - off) / elem_size;
864            unsafe {
865                write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
866            }
867        }
868
869        /// Read a node's arena bytes back as Vec<f32>, casting from its dtype.
870        fn read_output(&self, id: NodeId) -> Vec<f32> {
871            let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
872            let off = self.arena.byte_offset(id);
873            let buf = self.arena.raw_buf();
874            let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
875            unsafe { read_typed_to_f32(buf.as_ptr().add(off), dtype, n_elems) }
876        }
877    }
878
879    impl ExecutableGraph for CpuExecutable {
        /// Box a clone of this executable (`CpuExecutable` derives `Clone`,
        /// so the copy owns its own arena, schedule and handle map).
        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
            Box::new(self.clone())
        }
883        fn set_param(&mut self, name: &str, data: &[f32]) {
884            // Write directly into the arena — zero per-call lookup for params.
885            // Cast f32 → arena dtype when the param has been rewritten to F16/BF16.
886            if let Some(&id) = self.param_ids.get(name)
887                && self.arena.has_buffer(id)
888            {
889                let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
890                let off = self.arena.byte_offset(id);
891                let buf = self.arena.raw_buf_mut();
892                let elem_size = dtype.size_bytes();
893                let max_elems = (buf.len() - off) / elem_size;
894                unsafe {
895                    write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
896                }
897                return;
898            }
899            // Fallback: store in HashMap if no arena slot
900            self.params.insert(name.to_string(), data.to_vec());
901        }
902
903        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
904            // 1. Apply persistent handles first — they act like default inputs.
905            //    Explicit `inputs` passed to run() override matching handle names.
906            let handle_names: Vec<String> = self.handles.keys().cloned().collect();
907            for name in &handle_names {
908                if let Some(&id) = self.input_ids.get(name)
909                    && self.arena.has_buffer(id)
910                {
911                    let data = self.handles.get(name).cloned().unwrap_or_default();
912                    self.write_input(id, &data);
913                }
914            }
915            // 2. Explicit per-call inputs override handles.
916            for &(name, data) in inputs {
917                if let Some(&id) = self.input_ids.get(name)
918                    && self.arena.has_buffer(id)
919                {
920                    self.write_input(id, data);
921                }
922            }
923
924            // Active-extent fast-path (PLAN L1): if hinted AND every thunk
925            // in the schedule supports it, run scaled. Otherwise fall back
926            // to full-extent dispatch — preserves correctness when the
927            // schedule contains a thunk that hasn't yet been wired in.
928            let active_used = if let Some((actual, upper)) = self.active_extent {
929                thunk::execute_thunks_active(
930                    &self.schedule,
931                    self.arena.raw_buf_mut(),
932                    actual,
933                    upper,
934                )
935            } else {
936                false
937            };
938            if !active_used {
939                // Execute via pre-compiled thunks (zero per-node dispatch overhead)
940                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
941            }
942
943            // 3. Sync any handle whose name matches a graph OUTPUT —
944            //    KV-cache pattern: outputs flow back into the same-named
945            //    handle for the next iteration.
946            for (idx, &out_id) in self.graph.outputs.iter().enumerate() {
947                let name = format!("out{idx}");
948                if self.handles.contains_key(&name) {
949                    let v = self.read_output(out_id);
950                    self.handles.insert(name, v);
951                }
952            }
953
954            self.graph
955                .outputs
956                .iter()
957                .map(|&out_id| self.read_output(out_id))
958                .collect()
959        }
960
961        fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
962            // Copy inputs by name (HashMap lookup), casting to arena dtype.
963            for &(name, data) in inputs {
964                if let Some(&id) = self.input_ids.get(name)
965                    && self.arena.has_buffer(id)
966                {
967                    self.write_input(id, data);
968                }
969            }
970            thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
971            // Note: pointers are raw arena bytes — for F16 outputs, callers
972            // must read 2 bytes/elem, not 4. run() is the safe path for
973            // mixed precision; run_raw() is only meaningful for F32.
974            self.graph
975                .outputs
976                .iter()
977                .map(|&out_id| {
978                    let (ptr, len) = self.arena.raw_ptr(out_id);
979                    (ptr as *const f32, len)
980                })
981                .collect()
982        }
983
984        /// Fastest path: inputs by index (matching input_slots order), zero-copy output.
985        /// No HashMap, no name matching, no Vec allocation. Casts f32 input
986        /// to F16/BF16 if the input slot's dtype was rewritten.
987        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
988            let buf = self.arena.raw_buf_mut();
989            for (i, &data) in inputs.iter().enumerate() {
990                if i < self.input_slots.len() {
991                    let (_, off, max_len, dtype) = &self.input_slots[i];
992                    unsafe {
993                        write_typed_from_f32(buf.as_mut_ptr().add(*off), *dtype, data, *max_len);
994                    }
995                }
996            }
997            thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
998            &self.output_slots
999        }
1000
        /// Pointer to the arena's backing buffer, as reported by
        /// `Arena::raw_buf_mut_ptr`, exposed as `*const u8`.
        fn arena_ptr(&self) -> *const u8 {
            self.arena.raw_buf_mut_ptr()
        }
1004
        /// Store a copy of `data` as the persistent handle `name`, replacing
        /// any previous handle with that name. Always returns `true`.
        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
            // Persistent buffer: stored separately from arena, survives run().
            // If the name matches a graph input, run() will use this data
            // as the input. run() also refreshes handles named `out{idx}`
            // from output `idx`, so read_handle then returns the latest value.
            self.handles.insert(name.to_string(), data.to_vec());
            true
        }
1013
        /// Return a copy of the handle stored under `name`, or `None` when
        /// no such handle has been bound.
        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.handles.get(name).cloned()
        }
1017
        /// Set or clear (`None`) the `(actual, upper)` active-extent hint
        /// that `run` consults before dispatching the schedule.
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.active_extent = extent;
        }
1021
1022        fn set_moe_resident_experts(&mut self, mask: &[bool]) {
1023            self.moe_resident_layers = None;
1024            self.schedule.moe_resident_layers = None;
1025            self.moe_resident = Some(Arc::from(mask));
1026            self.schedule.moe_resident = self.moe_resident.clone();
1027        }
1028
1029        fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
1030            self.moe_resident = None;
1031            self.schedule.moe_resident = None;
1032            let layers: Vec<Arc<[bool]>> = masks.iter().map(|m| Arc::from(*m)).collect();
1033            let arc = Arc::new(layers);
1034            self.moe_resident_layers = Some(arc.clone());
1035            self.schedule.moe_resident_layers = Some(arc);
1036        }
1037
1038        fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
1039            let cap = rlx_cpu::moe_topk_capture::MoeTopkCapture::new(num_experts);
1040            self.moe_topk_capture = Some(cap.clone());
1041            self.schedule.moe_topk_capture = Some(cap);
1042            true
1043        }
1044
1045        fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
1046            let cap = self.moe_topk_capture.as_ref()?;
1047            let layers = cap.take_layers();
1048            if layers.is_empty() {
1049                None
1050            } else {
1051                Some(layers)
1052            }
1053        }
1054
        /// Forward to `rlx_cpu::moe_residency::take_last_forward_stats`.
        /// No field of `self` is read; the stats come from that module.
        fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
            rlx_cpu::moe_residency::take_last_forward_stats()
        }
1058
1059        /// Typed param upload. F32 / F16 / BF16 go through the existing
1060        /// widen-to-f32 path (the CPU arena is historically f32 with
1061        /// optional half-precision rewrite). F64 (and any future
1062        /// non-widenable dtype) lands directly in the arena as bytes —
1063        /// the f32 path would lose precision.
1064        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1065            if dtype == DType::F64 {
1066                self.set_param_bytes(name, data, dtype);
1067                return;
1068            }
1069            // U8 / I8 raw byte tensors: opaque storage for the GGUF
1070            // K-quant `Op::DequantMatMul` path (weights stay packed
1071            // in the arena). One arena byte = one element.
1072            if matches!(dtype, DType::U8 | DType::I8) {
1073                self.set_param_bytes(name, data, dtype);
1074                return;
1075            }
1076            if dtype == DType::F32 {
1077                let n = data.len() / 4;
1078                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1079                self.set_param(name, s);
1080            } else {
1081                let f32_buf = super::widen_bytes_to_f32(data, dtype);
1082                self.set_param(name, &f32_buf);
1083            }
1084        }
1085
1086        /// Typed run with mixed-dtype inputs/outputs.
1087        ///
1088        /// For each input: if its declared graph dtype matches the
1089        /// caller's bytes, we write directly into the arena (zero
1090        /// precision loss — F64 stays F64). For F32 with a half-precision
1091        /// arena rewrite, we widen as before. F16/BF16 callers go
1092        /// through the existing widen path.
1093        ///
1094        /// Outputs are read straight from the arena in the graph node's
1095        /// declared dtype — F64 outputs come back as 8 bytes/element,
1096        /// F32 as 4, etc.
1097        fn run_typed(
1098            &mut self,
1099            inputs: &[(&str, &[u8], rlx_ir::DType)],
1100        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1101            // Decide: are *all* inputs F64? If so, use the direct-byte
1102            // path for everything and skip the f32 widening machinery
1103            // entirely. Mixed dtype graphs (F32 + F64) take the
1104            // per-input dispatch route below.
1105            let all_f64 = !inputs.is_empty() && inputs.iter().all(|(_, _, dt)| *dt == DType::F64);
1106
1107            if all_f64 {
1108                for (name, data, _) in inputs {
1109                    if let Some(&id) = self.input_ids.get(*name) {
1110                        if !self.arena.has_buffer(id) {
1111                            continue;
1112                        }
1113                        let off = self.arena.byte_offset(id);
1114                        let buf = self.arena.raw_buf_mut();
1115                        let n = data.len();
1116                        debug_assert!(
1117                            off + n <= buf.len(),
1118                            "run_typed: input '{name}' overflows arena slot"
1119                        );
1120                        buf[off..off + n].copy_from_slice(data);
1121                    }
1122                }
1123                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1124            } else {
1125                // Mixed-dtype path: dtypes that survive untouched
1126                // through the f32-aliased arena (F64, I32, I64, U32)
1127                // go in as bytes; F32 and the half-precision family
1128                // route through widen-to-f32 + run.
1129                let mut f32_owned: Vec<(String, Vec<f32>)> = Vec::new();
1130                for (name, data, dt) in inputs {
1131                    let direct = matches!(*dt, DType::F64 | DType::I32 | DType::I64 | DType::U32,);
1132                    if direct {
1133                        if let Some(&id) = self.input_ids.get(*name) {
1134                            if !self.arena.has_buffer(id) {
1135                                continue;
1136                            }
1137                            let off = self.arena.byte_offset(id);
1138                            let buf = self.arena.raw_buf_mut();
1139                            buf[off..off + data.len()].copy_from_slice(data);
1140                        }
1141                    } else {
1142                        let v = super::widen_bytes_to_f32(data, *dt);
1143                        f32_owned.push((name.to_string(), v));
1144                    }
1145                }
1146                let refs: Vec<(&str, &[f32])> = f32_owned
1147                    .iter()
1148                    .map(|(n, d)| (n.as_str(), d.as_slice()))
1149                    .collect();
1150                let _ = self.run(&refs);
1151            }
1152
1153            // Read each output's bytes from the arena in its declared dtype.
1154            self.graph
1155                .outputs
1156                .iter()
1157                .map(|&id| {
1158                    let dtype = self.graph.node(id).shape.dtype();
1159                    let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
1160                    let n_bytes = n_elems * dtype.size_bytes();
1161                    let off = self.arena.byte_offset(id);
1162                    let bytes = self.arena.raw_buf()[off..off + n_bytes].to_vec();
1163                    (bytes, dtype)
1164                })
1165                .collect()
1166        }
1167    }
1168
1169    impl CpuExecutable {
1170        /// Direct-byte param upload — copies caller's bytes into the
1171        /// arena slot for the named param without any dtype conversion.
1172        /// Used by `set_param_typed` for dtypes that f32-widening would
1173        /// corrupt (F64). Caller is responsible for matching the param's
1174        /// declared graph dtype.
1175        fn set_param_bytes(&mut self, name: &str, data: &[u8], _dtype: rlx_ir::DType) {
1176            if let Some(&id) = self.param_ids.get(name)
1177                && self.arena.has_buffer(id)
1178            {
1179                let off = self.arena.byte_offset(id);
1180                let buf = self.arena.raw_buf_mut();
1181                debug_assert!(
1182                    off + data.len() <= buf.len(),
1183                    "set_param_bytes: '{name}' would overflow arena slot"
1184                );
1185                buf[off..off + data.len()].copy_from_slice(data);
1186            }
1187        }
1188    }
1189}
1190
1191// ── Metal Backend ───────────────────────────────────────────────────────
1192
1193// ── wgpu Backend ────────────────────────────────────────────────────────
1194
1195#[cfg(feature = "gpu")]
1196pub mod wgpu_backend {
1197    use super::*;
1198    use rlx_ir::OpKind;
1199    use rlx_wgpu::backend::WgpuExecutable;
1200
    /// wgpu backend entry point: a unit struct whose `Backend` impl
    /// compiles graphs into `WgpuExecutable`s.
    pub struct WgpuBackend;
1202
    /// PLAN L4: ops the wgpu backend can lower today. The fused
    /// macro-kernels (FAB, FTL, FusedSwiGLU — presumably
    /// `FusedAttentionBlock` / `FusedTransformerLayer` for the first two)
    /// get decomposed by `crate::unfuse::unfuse` upstream — they're listed
    /// here too so graphs that already contain them legalize cleanly.
    /// Conv1d/3d and Pool1d/3d are deferred (Conv2d only).
    ///
    /// `WgpuBackend::compile` passes this list to
    /// `rlx_opt::legalize_for_backend` and panics with the formatted
    /// diagnostic when the graph contains an op outside it.
    const WGPU_SUPPORTED_OPS: &[OpKind] = &[
        OpKind::Input,
        OpKind::Param,
        OpKind::Constant,
        OpKind::Activation,
        OpKind::Cast,
        OpKind::Binary,
        OpKind::Compare,
        OpKind::Where,
        OpKind::ElementwiseRegion,
        OpKind::MatMul,
        OpKind::DotGeneral,
        OpKind::LayerNorm,
        OpKind::RmsNorm,
        OpKind::Attention,
        OpKind::AttentionBackward,
        OpKind::RmsNormBackwardInput,
        OpKind::RmsNormBackwardGamma,
        OpKind::RmsNormBackwardBeta,
        OpKind::RopeBackward,
        OpKind::CumsumBackward,
        OpKind::GatherBackward,
        OpKind::Rope,
        OpKind::Reshape,
        OpKind::Transpose,
        OpKind::Narrow,
        OpKind::Concat,
        OpKind::Expand,
        OpKind::Gather,
        OpKind::Reduce,
        OpKind::Softmax,
        OpKind::Cumsum,
        OpKind::TopK,
        OpKind::Sample,
        OpKind::Conv,
        OpKind::Pool,
        OpKind::GroupedMatMul,
        OpKind::DequantGroupedMatMul,
        OpKind::DequantMoEWeights,
        OpKind::ScatterAdd,
        OpKind::SelectiveScan,
        OpKind::DequantMatMul,
        OpKind::FusedMatMulBiasAct,
        OpKind::FusedResidualLN,
        OpKind::FusedResidualRmsNorm,
        OpKind::FusedSwiGLU,
        OpKind::FusedAttentionBlock,
        OpKind::FusedTransformerLayer,
        // Native FFT (WGSL radix-2): f32 only, power-of-2 N ≤ 1024.
        // Anything outside that envelope panics at lowering with a
        // "pin to Device::Cpu" hint. No host fallback — WGPU has no
        // unified memory, so silent CPU round-trip would be a hidden
        // performance cliff.
        OpKind::Fft,
        // 3D Gaussian splat: native Metal / CPU reference per backend.
        // NOTE(review): the wording mentions Metal/CPU although this is the
        // wgpu list — confirm how wgpu lowers these ops.
        OpKind::GaussianSplatRender,
        OpKind::GaussianSplatRenderBackward,
        OpKind::GaussianSplatPrepare,
        OpKind::GaussianSplatRasterize,
        OpKind::Custom,
        // LoRA, If, While: not yet wired in wgpu — fail loudly.
    ];
1270
1271    impl Backend for WgpuBackend {
1272        fn supported_ops(&self) -> &'static [OpKind] {
1273            WGPU_SUPPORTED_OPS
1274        }
1275
1276        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1277            // PLAN L4: legalize against the backend's claimed op set
1278            // BEFORE running fusion passes (so the diagnostic points
1279            // at the user's IR, not at a fused-away node).
1280            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, WGPU_SUPPORTED_OPS) {
1281                panic!("{}", rlx_opt::format_legalize_error("wgpu", &errors));
1282            }
1283            use rlx_opt::pass::Pass as _;
1284            // Cleanup passes upstream of wgpu's pipeline.
1285            let graph = if options.dce {
1286                rlx_opt::DeadCodeElimination.run(graph)
1287            } else {
1288                graph
1289            };
1290            let graph = if options.constant_folding {
1291                rlx_opt::ConstantFolding.run(graph)
1292            } else {
1293                graph
1294            };
1295            // ORDER MATTERS: targeted-pattern fusions run BEFORE the
1296            // catch-all `MarkElementwiseRegions`. Otherwise the region
1297            // pass swallows the Add / Activation nodes into chains and
1298            // FuseMatMulBiasAct / FuseResidualLN fail to match the
1299            // narrower patterns they look for. (Metal pipeline at line
1300            // ~377 already orders these correctly; wgpu was inverted
1301            // and silently shipped 13 unfused LayerNorms per BERT
1302            // forward where 12 should have been FusedResidualLN.)
1303            let compile_result = crate::stages::compile_graph_stages_for_backend(
1304                rlx_driver::Device::Gpu,
1305                graph,
1306                options,
1307                WGPU_SUPPORTED_OPS,
1308            );
1309            crate::stages::maybe_log_fusion(&compile_result.fusion);
1310            let graph = compile_result.lir.into_graph();
1311            let graph = match options.policy.clone() {
1312                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
1313                None => graph,
1314            };
1315            Box::new(WgpuExecutableWrapper {
1316                inner: WgpuExecutable::compile(graph),
1317            })
1318        }
1319
1320        fn compile_lir(
1321            &self,
1322            lir: LirModule,
1323            options: &CompileOptions,
1324        ) -> Box<dyn ExecutableGraph> {
1325            let graph = prepare_fused_graph(lir.into_graph(), options, WGPU_SUPPORTED_OPS, "wgpu");
1326            Box::new(WgpuExecutableWrapper {
1327                inner: WgpuExecutable::compile(graph),
1328            })
1329        }
1330    }
1331
    /// Adapter that exposes an `rlx_wgpu` `WgpuExecutable` through this
    /// crate's `ExecutableGraph` trait.
    struct WgpuExecutableWrapper {
        /// The wrapped wgpu executable; every trait method delegates to it.
        inner: WgpuExecutable,
    }
1335
    // SAFETY: NOTE(review) — asserts that the wrapped `WgpuExecutable` may be
    // moved across threads. Its fields are defined in `rlx_wgpu` and are not
    // visible here; confirm that nothing inside is tied to the creating thread.
    unsafe impl Send for WgpuExecutableWrapper {}
1337
1338    impl ExecutableGraph for WgpuExecutableWrapper {
1339        fn set_param(&mut self, name: &str, data: &[f32]) {
1340            self.inner.set_param(name, data);
1341        }
1342        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1343            self.inner.run(inputs)
1344        }
1345        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1346            self.inner.set_active_extent(extent);
1347        }
1348
1349        /// Typed param upload: widens F16/BF16 to F32 at the host boundary,
1350        /// since the wgpu arena is f32-uniform.
1351        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1352            match dtype {
1353                rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
1354                    self.inner.set_param_bytes(name, data);
1355                }
1356                rlx_ir::DType::F32 => {
1357                    let n = data.len() / 4;
1358                    let f32_slice =
1359                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1360                    self.inner.set_param(name, f32_slice);
1361                }
1362                rlx_ir::DType::F16 => {
1363                    let n = data.len() / 2;
1364                    let f16_slice =
1365                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
1366                    let f32: Vec<f32> = f16_slice.iter().map(|h| h.to_f32()).collect();
1367                    self.inner.set_param(name, &f32);
1368                }
1369                rlx_ir::DType::BF16 => {
1370                    let n = data.len() / 2;
1371                    let bf16_slice = unsafe {
1372                        std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1373                    };
1374                    let f32: Vec<f32> = bf16_slice.iter().map(|h| h.to_f32()).collect();
1375                    self.inner.set_param(name, &f32);
1376                }
1377                other => panic!(
1378                    "rlx-wgpu set_param_typed: dtype {other:?} unsupported \
1379                                 (F32, F16, BF16 only — wgpu arena is f32-uniform)"
1380                ),
1381            }
1382        }
1383
1384        /// Typed run: widen each typed input to F32, run, then narrow each
1385        /// output back to its declared dtype.
1386        fn run_typed(
1387            &mut self,
1388            inputs: &[(&str, &[u8], rlx_ir::DType)],
1389        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1390            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1391            for (name, data, dt) in inputs {
1392                let v: Vec<f32> = match *dt {
1393                    rlx_ir::DType::F32 => {
1394                        let n = data.len() / 4;
1395                        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }
1396                            .to_vec()
1397                    }
1398                    rlx_ir::DType::F16 => {
1399                        let n = data.len() / 2;
1400                        let s = unsafe {
1401                            std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n)
1402                        };
1403                        s.iter().map(|h| h.to_f32()).collect()
1404                    }
1405                    rlx_ir::DType::BF16 => {
1406                        let n = data.len() / 2;
1407                        let s = unsafe {
1408                            std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1409                        };
1410                        s.iter().map(|h| h.to_f32()).collect()
1411                    }
1412                    other => {
1413                        panic!("rlx-wgpu run_typed: input '{name}' dtype {other:?} unsupported")
1414                    }
1415                };
1416                owned.push((name.to_string(), v));
1417            }
1418            let refs: Vec<(&str, &[f32])> = owned
1419                .iter()
1420                .map(|(n, d)| (n.as_str(), d.as_slice()))
1421                .collect();
1422            let dtypes = self.inner.output_dtypes();
1423            let outs = self.inner.run(&refs);
1424            outs.into_iter()
1425                .zip(
1426                    dtypes
1427                        .into_iter()
1428                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
1429                )
1430                .map(|(v, dt)| (narrow_to_dtype(&v, dt), dt))
1431                .collect()
1432        }
1433    }
1434
1435    /// Cast every element of a wgpu f32 output buffer down to the
1436    /// declared output dtype, returning the corresponding byte stream.
1437    /// The arena keeps every value as f32; declared output dtypes
1438    /// (Bool, I8, I32, F16, ...) require an exit-time narrowing to be
1439    /// byte-identical with backends that store the native dtype.
1440    fn narrow_to_dtype(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
1441        use rlx_ir::DType;
1442        match dt {
1443            DType::F32 => {
1444                let mut bytes = Vec::with_capacity(v.len() * 4);
1445                for &x in v {
1446                    bytes.extend_from_slice(&x.to_le_bytes());
1447                }
1448                bytes
1449            }
1450            DType::F16 => {
1451                let mut bytes = Vec::with_capacity(v.len() * 2);
1452                for &x in v {
1453                    bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
1454                }
1455                bytes
1456            }
1457            DType::BF16 => {
1458                let mut bytes = Vec::with_capacity(v.len() * 2);
1459                for &x in v {
1460                    bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
1461                }
1462                bytes
1463            }
1464            DType::F64 => {
1465                let mut bytes = Vec::with_capacity(v.len() * 8);
1466                for &x in v {
1467                    bytes.extend_from_slice(&(x as f64).to_le_bytes());
1468                }
1469                bytes
1470            }
1471            DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
1472            DType::U8 => v.iter().map(|&x| x as u8).collect(),
1473            DType::I16 => {
1474                let mut bytes = Vec::with_capacity(v.len() * 2);
1475                for &x in v {
1476                    bytes.extend_from_slice(&(x as i16).to_le_bytes());
1477                }
1478                bytes
1479            }
1480            DType::I32 => {
1481                let mut bytes = Vec::with_capacity(v.len() * 4);
1482                for &x in v {
1483                    bytes.extend_from_slice(&(x as i32).to_le_bytes());
1484                }
1485                bytes
1486            }
1487            DType::U32 => {
1488                let mut bytes = Vec::with_capacity(v.len() * 4);
1489                for &x in v {
1490                    bytes.extend_from_slice(&(x as u32).to_le_bytes());
1491                }
1492                bytes
1493            }
1494            DType::I64 => {
1495                let mut bytes = Vec::with_capacity(v.len() * 8);
1496                for &x in v {
1497                    bytes.extend_from_slice(&(x as i64).to_le_bytes());
1498                }
1499                bytes
1500            }
1501            DType::Bool => v
1502                .iter()
1503                .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
1504                .collect(),
1505            // C64 (complex f32 pair) — the wgpu backend's f32 arena
1506            // doesn't synthesize complex outputs today; this branch
1507            // only fires if a graph somehow asks for a C64 output and
1508            // the backend lowered it as 2N real floats. We pass the
1509            // raw f32 stream straight through; downstream code that
1510            // wants complex semantics is responsible for re-pairing.
1511            DType::C64 => {
1512                let mut bytes = Vec::with_capacity(v.len() * 4);
1513                for &x in v {
1514                    bytes.extend_from_slice(&x.to_le_bytes());
1515                }
1516                bytes
1517            }
1518        }
1519    }
1520}
1521
1522// ── MLX Backend ─────────────────────────────────────────────────────────
1523
1524#[cfg(all(feature = "mlx", target_os = "macos"))]
1525pub mod mlx_backend {
1526    use super::*;
1527    use rlx_mlx::MlxExecutable;
1528
    /// Stateless handle for the MLX device; all behavior lives in its
    /// [`Backend`] implementation.
    pub struct MlxBackend;
1530
    /// PLAN L4: ops the MLX backend can lower today. MLX has the
    /// widest IR coverage of any GPU backend — it claims `If`/`While`
    /// (via topo unrolling) and lowers `ElementwiseRegion` natively via
    /// the per-step composition in rlx-mlx/src/lower.rs (PLAN L2).
    ///
    /// Not claimed: `Fft` — see the note at the end of the list.
    const MLX_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            LayerNorm,
            LayerNorm2d,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
            If,
            While,
            // Loop-unrolled scan (Op::Scan body is statically unrolled
            // `length` times into MLX ops; mirror of Op::While's
            // bounded-unroll lowering). ScanBackward is the AD
            // companion — handled the same way.
            Scan,
            ScanBackward,
            ScanBackwardXs,
            // Tier 1 autodiff backward ops — lowered as primitive
            // compositions in `rlx-mlx/src/lower.rs`.
            ReluBackward,
            ActivationBackward,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            // Tier 2 — conv backward via `mc::conv_general` with the
            // same parameter-mapping MLX uses inside its built-in vjp.
            // Currently groups=1 only; grouped conv backward will
            // surface as a clear error from `lower.rs`.
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            // Tier 3 — max-pool backward via slice-strided argmax over
            // pool windows + a per-kernel-slot scatter-add, matching
            // the CPU thunk's "first-hit-wins" tiebreaking.
            MaxPool2dBackward,
            // QAT — `FakeQuantize` (PerBatch + Fixed scale modes;
            // EMA returns a clear error from `lower.rs`) and the
            // `FakeQuantizeBackward` family covering all 4 STE
            // variants. Closes the last gap vs `CPU_SUPPORTED_OPS`.
            FakeQuantize,
            FakeQuantizeBackward,
            // User-registered custom ops dispatched through
            // `rlx_mlx::op_registry`. Lowering looks up the
            // registered `MlxKernel` and calls its `execute` method
            // to produce the lazy MLX `Array` for this node.
            Custom,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            // Op::Fft on MLX: NOT supported (deliberately absent from
            // this list). Host-fallback was tried and rejected — MLX's
            // compile callback forbids `eval`, and `Array::to_bytes`
            // requires eval, so we can't materialize/transform/
            // rematerialize inside the lower pass. Pin FFT subgraphs to
            // Device::Cpu (or Device::Metal, which has a working
            // unified-memory host-fallback). Real MLX support needs a
            // native `mlx::fft::fft` FFI shim; tracked in PLAN.md.
        ]
    };
1636
1637    impl Backend for MlxBackend {
1638        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
1639            MLX_SUPPORTED_OPS
1640        }
1641
1642        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1643            let compile_result = crate::stages::compile_graph_stages_for_backend(
1644                rlx_driver::Device::Mlx,
1645                graph,
1646                options,
1647                MLX_SUPPORTED_OPS,
1648            );
1649            crate::stages::maybe_log_fusion(&compile_result.fusion);
1650            self.compile_lir(compile_result.lir, options)
1651        }
1652
1653        fn compile_lir(
1654            &self,
1655            lir: LirModule,
1656            options: &CompileOptions,
1657        ) -> Box<dyn ExecutableGraph> {
1658            use rlx_opt::pass::Pass as _;
1659            let mut graph = lir.into_graph();
1660            graph = rlx_opt::LowerControlFlow.run(graph);
1661            let graph = prepare_fused_graph(graph, options, MLX_SUPPORTED_OPS, "mlx");
1662            Box::new(build_mlx_executable(graph))
1663        }
1664    }
1665
1666    fn build_mlx_executable(graph: Graph) -> MlxExecutableWrapper {
1667        let mode = mlx_mode_from_env();
1668        let mut exe = MlxExecutable::compile_from_fused(graph, mode);
1669        if mode == rlx_mlx::lower::MlxMode::Compiled {
1670            if let Err(e) = exe.warm_compile() {
1671                eprintln!(
1672                    "[rlx-runtime] MLX warm_compile failed ({e}); first run will pay the trace cost"
1673                );
1674            }
1675        }
1676        MlxExecutableWrapper { inner: exe }
1677    }
1678
1679    fn mlx_mode_from_env() -> rlx_mlx::lower::MlxMode {
1680        match rlx_ir::env::var("RLX_MLX_MODE").as_deref() {
1681            Some(s) if s.eq_ignore_ascii_case("eager") => rlx_mlx::lower::MlxMode::Eager,
1682            Some(s) if s.eq_ignore_ascii_case("lazy") => rlx_mlx::lower::MlxMode::Lazy,
1683            Some(s) if s.eq_ignore_ascii_case("compiled") => rlx_mlx::lower::MlxMode::Compiled,
1684            _ => rlx_mlx::lower::MlxMode::Compiled,
1685        }
1686    }
1687
    /// Adapter that exposes an `MlxExecutable` through the
    /// [`ExecutableGraph`] trait.
    struct MlxExecutableWrapper {
        /// The wrapped MLX executable every trait method forwards to.
        inner: MlxExecutable,
    }

    // SAFETY: NOTE(review): this asserts that `MlxExecutable` may be moved
    // to another thread. The invariant that makes that sound is not
    // visible in this file — confirm against `rlx_mlx` and record it here.
    unsafe impl Send for MlxExecutableWrapper {}
1693
    /// Pure delegation layer: every method forwards to the wrapped
    /// `MlxExecutable`, adapting only the return type where the trait
    /// signature differs from the inner call.
    impl ExecutableGraph for MlxExecutableWrapper {
        /// Forward an f32 parameter upload to the inner executable.
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        /// Run with named f32 inputs and return the inner executable's outputs.
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        /// Slot-based run; returns the `(usize, usize)` pairs produced by
        /// the inner executable unchanged.
        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
            self.inner.run_slots(inputs)
        }
        /// Raw arena pointer as reported by the inner executable.
        fn arena_ptr(&self) -> *const u8 {
            self.inner.arena_ptr()
        }
        /// Forward `commit_no_wait` to the inner executable.
        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
            self.inner.commit_no_wait(inputs);
        }
        /// Forward `sync_pending` to the inner executable.
        fn sync_pending(&mut self) {
            self.inner.sync_pending();
        }
        /// Forward a batch of input sets; one output list per input set
        /// is returned by the inner executable.
        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
            self.inner.run_pipelined(input_sets)
        }
        /// Forward `bind_handle`; the inner call's `bool` is returned as-is.
        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
            self.inner.bind_handle(name, data)
        }
        /// Forward `read_handle`; the inner call's `Option` is returned as-is.
        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_handle(name)
        }
        /// Forward `bind_gpu_handle`; the inner `Result` is reduced to
        /// `true` on success, so the error detail is discarded.
        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
            self.inner.bind_gpu_handle(name, data).is_ok()
        }
        /// Forward `has_gpu_handle` to the inner executable.
        fn has_gpu_handle(&self, name: &str) -> bool {
            self.inner.has_gpu_handle(name)
        }
        /// Forward `set_gpu_handle_feed`. Always reports `true`: whatever
        /// the inner call returns is not inspected.
        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
            self.inner.set_gpu_handle_feed(handle_name, output_index);
            true
        }
        /// Forward `read_gpu_handle`; an inner error becomes `None`.
        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_gpu_handle(name).ok()
        }
        /// Forward to the inner `run_feed_gpu`; an inner error becomes `None`.
        fn run_feed_gpu_handle(
            &mut self,
            inputs: &[(&str, &[f32])],
            handle_name: &str,
            output_index: usize,
        ) -> Option<Vec<f32>> {
            self.inner
                .run_feed_gpu(inputs, handle_name, output_index)
                .ok()
        }
        /// Typed parameter upload; bytes and dtype are handed to the
        /// inner executable without any conversion in this wrapper.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            self.inner.set_param_typed(name, data, dtype);
        }
        /// Typed run; inputs and outputs pass through the inner
        /// executable's own `run_typed` unchanged.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            self.inner.run_typed(inputs)
        }
        /// Forward `set_active_extent` to the inner executable.
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }
    }
1758}
1759
1760#[cfg(all(feature = "metal", target_os = "macos"))]
1761pub mod metal_backend {
1762    use super::*;
1763    use rlx_metal::backend::MetalExecutable;
1764
    /// Stateless handle for the Metal device; all behavior lives in its
    /// [`Backend`] implementation.
    pub struct MetalBackend;
1766
    /// PLAN L4: ops the Metal backend can lower today. Includes
    /// DotGeneral (LowerDotGeneral pass) and ElementwiseRegion
    /// (decomposed by UnfuseElementwiseRegions). Excludes Cumsum,
    /// SelectiveScan, LoraMatMul, Sample,
    /// FusedAttentionBlock, FusedTransformerLayer, If, While —
    /// not yet wired in `rlx-metal/src/thunk.rs`'s compile_thunks.
    /// DequantMatMul (GGUF K-quants) lowers to a GPU dequant kernel
    /// followed by an MPS matmul; legacy Int8 schemes remain CPU-only.
    const METAL_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            TopK,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            // User-registered custom ops dispatched through
            // `rlx_metal::op_registry`. Lowering panics with a clear
            // message if the named MetalKernel isn't registered;
            // executor inserts a sync point + runs the host kernel
            // against the unified-memory arena.
            Custom,
            // Op::Fft is supported via the same host-fallback pattern
            // as Custom: sync the GPU, run rlx-cpu's FFT against the
            // unified-memory arena, restart cmd_buf. A native Metal
            // compute kernel will replace this when a workload makes
            // the sync the bottleneck.
            Fft,
            // Host-fallback splat (unified-memory arena + rlx-cpu/splat).
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
        ]
    };
1845
1846    impl Backend for MetalBackend {
1847        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
1848            METAL_SUPPORTED_OPS
1849        }
1850
1851        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1852            use rlx_opt::pass::Pass as _;
1853            // Same If/While → primitive rewrite as the CPU pipeline
1854            // (Metal also has no native sub-graph executor wired
1855            // through its thunk schedule).
1856            let graph = rlx_opt::LowerControlFlow.run(graph);
1857            let mut dispatch = options.kernel_dispatch;
1858            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
1859                graph,
1860                METAL_SUPPORTED_OPS,
1861                dispatch,
1862            )
1863            .unwrap_or_else(|errors| {
1864                panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
1865            });
1866            // Optional cleanup passes upstream of Metal's pipeline
1867            let graph = if options.dce {
1868                rlx_opt::DeadCodeElimination.run(graph)
1869            } else {
1870                graph
1871            };
1872            let graph = if options.constant_folding {
1873                rlx_opt::ConstantFolding.run(graph)
1874            } else {
1875                graph
1876            };
1877
1878            // Hand the policy to MetalExecutable so the rewrite runs AFTER
1879            // its internal fusion passes (avoids breaking pattern matchers).
1880            Box::new(MetalExecutableWrapper {
1881                inner: MetalExecutable::compile_with_policy(
1882                    graph,
1883                    options.policy.clone(),
1884                    Some(METAL_SUPPORTED_OPS),
1885                ),
1886            })
1887        }
1888
1889        fn compile_lir(
1890            &self,
1891            lir: LirModule,
1892            options: &CompileOptions,
1893        ) -> Box<dyn ExecutableGraph> {
1894            use rlx_opt::pass::Pass as _;
1895            let mut graph = lir.into_graph();
1896            graph = rlx_opt::LowerControlFlow.run(graph);
1897            let mut dispatch = options.kernel_dispatch;
1898            let mut graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
1899                graph,
1900                METAL_SUPPORTED_OPS,
1901                dispatch,
1902            )
1903            .unwrap_or_else(|errors| {
1904                panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
1905            });
1906            if options.dce {
1907                graph = rlx_opt::DeadCodeElimination.run(graph);
1908            }
1909            if options.constant_folding {
1910                graph = rlx_opt::ConstantFolding.run(graph);
1911            }
1912            Box::new(MetalExecutableWrapper {
1913                inner: MetalExecutable::compile_from_fused(
1914                    graph,
1915                    options.policy.clone(),
1916                    Some(METAL_SUPPORTED_OPS),
1917                ),
1918            })
1919        }
1920    }
1921
    /// Adapter that exposes a `MetalExecutable` through the
    /// [`ExecutableGraph`] trait.
    struct MetalExecutableWrapper {
        /// The wrapped Metal executable every trait method forwards to.
        inner: MetalExecutable,
    }

    // SAFETY: NOTE(review): this asserts that `MetalExecutable` may be
    // moved to another thread. The invariant that makes that sound is not
    // visible in this file — confirm against `rlx_metal` and record it here.
    unsafe impl Send for MetalExecutableWrapper {}
1927
1928    impl ExecutableGraph for MetalExecutableWrapper {
1929        fn set_param(&mut self, name: &str, data: &[f32]) {
1930            self.inner.set_param(name, data);
1931        }
1932        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1933            self.inner.run(inputs)
1934        }
1935        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1936            self.inner.run_slots(inputs)
1937        }
1938        fn arena_ptr(&self) -> *const u8 {
1939            self.inner.arena_ptr()
1940        }
1941        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
1942            self.inner.commit_no_wait(inputs);
1943        }
1944        fn sync_pending(&mut self) {
1945            self.inner.sync_pending();
1946        }
1947        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
1948            self.inner.run_pipelined(input_sets)
1949        }
1950        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1951            self.inner.set_active_extent(extent);
1952        }
1953
1954        /// Typed param upload — accepts F16/BF16 host bytes by widening
1955        /// to F32 first, then routing through `set_param`. The Metal
1956        /// arena's `write_from_f32` honors per-node F16 storage when
1957        /// AutoMixedPrecision rewrote the param. U8/I8 packed weights
1958        /// copy directly into the arena for `Op::DequantMatMul`.
1959        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1960            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
1961                self.inner.set_param_bytes(name, data);
1962                return;
1963            }
1964            if dtype == rlx_ir::DType::F32 {
1965                let n = data.len() / 4;
1966                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1967                self.inner.set_param(name, s);
1968            } else {
1969                let f32_buf = super::widen_bytes_to_f32(data, dtype);
1970                self.inner.set_param(name, &f32_buf);
1971            }
1972        }
1973
1974        /// Typed run. Inputs widen to F32 (existing path; F64 host
1975        /// inputs through `run_typed` is a separate Metal extension).
1976        /// Outputs: F64 outputs go through the byte-direct
1977        /// `output_bytes_per_node` path (no precision loss in the
1978        /// f32 round-trip); other dtypes keep the f32-narrow path
1979        /// for backward compatibility with existing AutoMixedPrecision
1980        /// rewrites.
1981        fn run_typed(
1982            &mut self,
1983            inputs: &[(&str, &[u8], rlx_ir::DType)],
1984        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1985            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1986            for (name, data, dt) in inputs {
1987                let v = super::widen_bytes_to_f32(data, *dt);
1988                owned.push((name.to_string(), v));
1989            }
1990            let refs: Vec<(&str, &[f32])> = owned
1991                .iter()
1992                .map(|(n, d)| (n.as_str(), d.as_slice()))
1993                .collect();
1994            let dtypes = self.inner.output_dtypes();
1995            let f32_outs = self.inner.run(&refs);
1996            let byte_outs = self.inner.output_bytes_per_node();
1997            f32_outs
1998                .into_iter()
1999                .zip(byte_outs.into_iter())
2000                .zip(
2001                    dtypes
2002                        .into_iter()
2003                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2004                )
2005                .map(|((f32_v, byte_v), dt)| match dt {
2006                    rlx_ir::DType::F64 => (byte_v, dt),
2007                    _ => (super::narrow_f32_to_bytes(&f32_v, dt), dt),
2008                })
2009                .collect()
2010        }
2011    }
2012}
2013
2014// ── CUDA Backend ────────────────────────────────────────────────────────
2015
2016#[cfg(feature = "cuda")]
2017pub mod cuda_backend {
2018    use super::*;
2019    use rlx_cuda::backend::CudaExecutable;
2020
    /// Stateless handle for the CUDA device; all behavior lives in its
    /// [`Backend`] implementation.
    pub struct CudaBackend;
2022
    /// PLAN L4: ops the CUDA backend can lower today.
    ///
    /// Excluded: FusedSwiGLU, LoraMatMul, FusedAttentionBlock and
    /// FusedTransformerLayer (no kernel), plus If and While (no executor
    /// wiring). DotGeneral is handled via LowerDotGeneral;
    /// ElementwiseRegion is lowered natively by an NVRTC
    /// interpreted-chain kernel.
    const CUDA_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
        ]
    };
2084
2085    impl Backend for CudaBackend {
2086        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2087            CUDA_SUPPORTED_OPS
2088        }
2089
2090        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2091            // Decompose FusedSwiGLU / FAB / etc. before legalization (CudaExecutable
2092            // unfuses again; this pass is idempotent).
2093            let graph = rlx_cuda::unfuse::unfuse(graph);
2094            let graph = rlx_opt::rewrite_for_backend(graph, CUDA_SUPPORTED_OPS);
2095            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CUDA_SUPPORTED_OPS) {
2096                panic!("{}", rlx_opt::format_legalize_error("cuda", &errors));
2097            }
2098            use rlx_opt::pass::Pass as _;
2099            let graph = if options.dce {
2100                rlx_opt::DeadCodeElimination.run(graph)
2101            } else {
2102                graph
2103            };
2104            let graph = if options.constant_folding {
2105                rlx_opt::ConstantFolding.run(graph)
2106            } else {
2107                graph
2108            };
2109            // Backend-aware fusion via the shared compile pipeline.
2110            let compile_result = crate::stages::compile_graph_stages_for_backend(
2111                rlx_driver::Device::Cuda,
2112                graph,
2113                options,
2114                CUDA_SUPPORTED_OPS,
2115            );
2116            crate::stages::maybe_log_fusion(&compile_result.fusion);
2117            let graph = compile_result.lir.into_graph();
2118            let graph = match options.policy.clone() {
2119                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2120                None => graph,
2121            };
2122            Box::new(CudaExecutableWrapper {
2123                inner: CudaExecutable::compile(graph),
2124            })
2125        }
2126
2127        fn compile_lir(
2128            &self,
2129            lir: LirModule,
2130            options: &CompileOptions,
2131        ) -> Box<dyn ExecutableGraph> {
2132            let graph = prepare_fused_graph(
2133                rlx_cuda::unfuse::unfuse(lir.into_graph()),
2134                options,
2135                CUDA_SUPPORTED_OPS,
2136                "cuda",
2137            );
2138            Box::new(CudaExecutableWrapper {
2139                inner: CudaExecutable::compile(graph),
2140            })
2141        }
2142    }
2143
    /// Adapter that exposes a `CudaExecutable` through the
    /// [`ExecutableGraph`] trait.
    struct CudaExecutableWrapper {
        /// The wrapped CUDA executable every trait method forwards to.
        inner: CudaExecutable,
    }

    // SAFETY: CudaExecutable owns CudaContext + CudaSlice handles; cudarc
    // claims they're Send (CudaContext is Arc-wrapped, CudaSlice is
    // logically a device pointer + length). The Backend trait requires
    // Send for the executable; we honor that here.
    unsafe impl Send for CudaExecutableWrapper {}
2153
2154    impl ExecutableGraph for CudaExecutableWrapper {
2155        fn set_param(&mut self, name: &str, data: &[f32]) {
2156            self.inner.set_param(name, data);
2157        }
2158        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2159            self.inner.run(inputs)
2160        }
2161        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2162            self.inner.set_active_extent(extent);
2163        }
2164
2165        /// Typed param upload — widens F16/BF16 host bytes to f32
2166        /// before routing through `set_param`. CUDA's arena is
2167        /// f32-uniform; the half-precision matmul tier opts in via
2168        /// the separate `set_param_half` API.
2169        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2170            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2171                self.inner.set_param_bytes(name, data);
2172                return;
2173            }
2174            if dtype == rlx_ir::DType::F32 {
2175                let n = data.len() / 4;
2176                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2177                self.inner.set_param(name, s);
2178            } else {
2179                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2180                self.inner.set_param(name, &f32_buf);
2181            }
2182        }
2183
2184        /// Typed run — widen each typed input to F32, run, then narrow
2185        /// each output back to its declared graph dtype.
2186        fn run_typed(
2187            &mut self,
2188            inputs: &[(&str, &[u8], rlx_ir::DType)],
2189        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2190            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2191            for (name, data, dt) in inputs {
2192                let v = super::widen_bytes_to_f32(data, *dt);
2193                owned.push((name.to_string(), v));
2194            }
2195            let refs: Vec<(&str, &[f32])> = owned
2196                .iter()
2197                .map(|(n, d)| (n.as_str(), d.as_slice()))
2198                .collect();
2199            let dtypes = self.inner.output_dtypes();
2200            let outs = self.inner.run(&refs);
2201            outs.into_iter()
2202                .zip(
2203                    dtypes
2204                        .into_iter()
2205                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2206                )
2207                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2208                .collect()
2209        }
2210    }
2211}
2212
2213// ── ROCm Backend ────────────────────────────────────────────────────────
2214
2215#[cfg(feature = "rocm")]
2216pub mod rocm_backend {
2217    use super::*;
2218    use rlx_rocm::backend::RocmExecutable;
2219
    /// Stateless handle for the ROCm device; all behavior lives in its
    /// [`Backend`] implementation.
    pub struct RocmBackend;
2221
    /// PLAN L4: ROCm is the sister crate of CUDA, with an identical Step
    /// enum + dispatch shape. The claimed op set is nevertheless a
    /// subset of `CUDA_SUPPORTED_OPS`: LayerNorm2d, ConvTranspose2d,
    /// RmsNormBackwardInput, RmsNormBackwardGamma, RmsNormBackwardBeta,
    /// RopeBackward, CumsumBackward and GatherBackward are claimed by
    /// CUDA but not listed here.
    /// NOTE(review): confirm whether that gap is intentional or this
    /// list has fallen behind the CUDA one.
    const ROCM_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            AttentionBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
        ]
    };
2272
2273    impl Backend for RocmBackend {
2274        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2275            ROCM_SUPPORTED_OPS
2276        }
2277
2278        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2279            let graph = rlx_opt::rewrite_for_backend(graph, ROCM_SUPPORTED_OPS);
2280            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, ROCM_SUPPORTED_OPS) {
2281                panic!("{}", rlx_opt::format_legalize_error("rocm", &errors));
2282            }
2283            use rlx_opt::pass::Pass as _;
2284            let graph = if options.dce {
2285                rlx_opt::DeadCodeElimination.run(graph)
2286            } else {
2287                graph
2288            };
2289            let graph = if options.constant_folding {
2290                rlx_opt::ConstantFolding.run(graph)
2291            } else {
2292                graph
2293            };
2294            let compile_result = crate::stages::compile_graph_stages_for_backend(
2295                rlx_driver::Device::Rocm,
2296                graph,
2297                options,
2298                ROCM_SUPPORTED_OPS,
2299            );
2300            crate::stages::maybe_log_fusion(&compile_result.fusion);
2301            let graph = compile_result.lir.into_graph();
2302            let graph = match options.policy.clone() {
2303                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2304                None => graph,
2305            };
2306            Box::new(RocmExecutableWrapper {
2307                inner: RocmExecutable::compile(graph),
2308            })
2309        }
2310
2311        fn compile_lir(
2312            &self,
2313            lir: LirModule,
2314            options: &CompileOptions,
2315        ) -> Box<dyn ExecutableGraph> {
2316            let graph = prepare_fused_graph(lir.into_graph(), options, ROCM_SUPPORTED_OPS, "rocm");
2317            Box::new(RocmExecutableWrapper {
2318                inner: RocmExecutable::compile(graph),
2319            })
2320        }
2321    }
2322
2323    struct RocmExecutableWrapper {
2324        inner: RocmExecutable,
2325    }
2326
2327    // Same Send-claim shape as CudaExecutableWrapper. RocmExecutable
2328    // owns Arc<RocmContext> + HipBuffer handles; the HipRuntime bundle
2329    // is internally thread-safe per AMD's documentation.
2330    unsafe impl Send for RocmExecutableWrapper {}
2331
2332    impl ExecutableGraph for RocmExecutableWrapper {
2333        fn set_param(&mut self, name: &str, data: &[f32]) {
2334            self.inner.set_param(name, data);
2335        }
2336        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2337            self.inner.run(inputs)
2338        }
2339        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2340            self.inner.set_active_extent(extent);
2341        }
2342
2343        /// Typed param upload — widens F16/BF16 host bytes to f32
2344        /// before routing through `set_param`. ROCm's arena is
2345        /// f32-uniform; the half-precision matmul tier opts in via
2346        /// the separate `set_param_half` API.
2347        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2348            if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2349                self.inner.set_param_bytes(name, data);
2350                return;
2351            }
2352            if dtype == rlx_ir::DType::F32 {
2353                let n = data.len() / 4;
2354                let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2355                self.inner.set_param(name, s);
2356            } else {
2357                let f32_buf = super::widen_bytes_to_f32(data, dtype);
2358                self.inner.set_param(name, &f32_buf);
2359            }
2360        }
2361
2362        /// Typed run — widen each typed input to F32, run, then narrow
2363        /// each output back to its declared graph dtype.
2364        fn run_typed(
2365            &mut self,
2366            inputs: &[(&str, &[u8], rlx_ir::DType)],
2367        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2368            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2369            for (name, data, dt) in inputs {
2370                let v = super::widen_bytes_to_f32(data, *dt);
2371                owned.push((name.to_string(), v));
2372            }
2373            let refs: Vec<(&str, &[f32])> = owned
2374                .iter()
2375                .map(|(n, d)| (n.as_str(), d.as_slice()))
2376                .collect();
2377            let dtypes = self.inner.output_dtypes();
2378            let outs = self.inner.run(&refs);
2379            outs.into_iter()
2380                .zip(
2381                    dtypes
2382                        .into_iter()
2383                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
2384                )
2385                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2386                .collect()
2387        }
2388    }
2389}
2390
2391// ── TPU Backend ─────────────────────────────────────────────────────────
2392
#[cfg(feature = "tpu")]
pub mod tpu_backend {
    use super::*;
    use rlx_tpu::TpuExecutable;

    /// Backend handle that compiles graphs into [`TpuExecutable`]s.
    pub struct TpuBackend;

    /// Ops the TPU backend lowers to HLO. Full inference parity with
    /// rlx-cuda / rlx-rocm. Composite ops (FusedSwiGLU /
    /// FusedAttentionBlock / FusedTransformerLayer / LoraMatMul / If /
    /// While) are unfused inside `rlx_tpu::unfuse::unfuse` ahead of
    /// HLO emission, so they don't appear here.
    const TPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            // Real-INT8 path + fake-quant.
            QMatMul,
            QConv2d,
            Quantize,
            Dequantize,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            // Splat: no on-chip kernel — lowered to common primitive MIR via logical_kernel.
        ]
    };

    /// View native-endian F32 bytes as `f32` values without assuming
    /// alignment.
    ///
    /// Borrows `data` in place when its address is suitably aligned for
    /// `f32`; otherwise decodes the bytes into a freshly allocated buffer.
    /// Trailing bytes that do not fill a whole `f32` are ignored in both
    /// cases.
    fn f32_values(data: &[u8]) -> std::borrow::Cow<'_, [f32]> {
        const WIDTH: usize = std::mem::size_of::<f32>();
        // `align_offset` may conservatively report "cannot align"; that only
        // costs us the copy below, never correctness.
        if data.as_ptr().align_offset(std::mem::align_of::<f32>()) == 0 {
            let count = data.len() / WIDTH;
            // SAFETY: the pointer is non-null (it comes from a slice) and was
            // just checked to be f32-aligned; `count * WIDTH <= data.len()`,
            // so the view stays inside `data`; every bit pattern is a valid
            // `f32`; the returned borrow is tied to `data`'s lifetime.
            let view = unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f32>(), count) };
            std::borrow::Cow::Borrowed(view)
        } else {
            std::borrow::Cow::Owned(
                data.chunks_exact(WIDTH)
                    .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                    .collect(),
            )
        }
    }

    impl Backend for TpuBackend {
        /// Op kinds this backend claims; see `TPU_SUPPORTED_OPS`.
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            TPU_SUPPORTED_OPS
        }

        /// Compile `graph` for TPU: legalize/rewrite against the supported
        /// op set, apply mixed precision, then build the executable.
        ///
        /// # Panics
        ///
        /// Panics with the formatted legalization report when
        /// legalization returns errors.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
                graph,
                TPU_SUPPORTED_OPS,
                options.kernel_dispatch,
            )
            .unwrap_or_else(|errors| {
                panic!("{}", rlx_opt::format_legalize_error("tpu", &errors));
            });
            // The TPU's IR-side pass pipeline (DCE, ConstFold,
            // FuseResidualLN, FuseMatMulBiasAct, LegalizeBroadcast,
            // MarkElementwiseRegions) lives inside
            // `TpuExecutable::compile` so the same passes run whether
            // a caller goes through Session or invokes the executable
            // directly. We only do backend-cross-cutting work here:
            // legalization (must precede the pipeline so we panic
            // early on unsupported ops) and AutoMixedPrecision.
            //
            // Default policy on TPU is `AutoMixedBf16`: BF16 is the
            // native compute dtype on TPU silicon and recent GPUs,
            // and XLA's CPU plugin handles it natively too. Callers
            // can opt out by passing an explicit `PrecisionPolicy`
            // (e.g. `AlwaysF32` for accuracy debugging or
            // `AlwaysF16` to match a CUDA workload's choice).
            use rlx_opt::pass::Pass as _;
            let policy = options
                .policy
                .clone()
                .unwrap_or(rlx_opt::PrecisionPolicy::AutoMixedBf16);
            let graph = rlx_opt::AutoMixedPrecision::new(policy).run(graph);
            // These two knobs are read but deliberately not acted on here;
            // per the note above, the corresponding passes live inside
            // `TpuExecutable::compile`.
            let _ = options.dce;
            let _ = options.constant_folding;
            Box::new(TpuExecutableWrapper {
                inner: TpuExecutable::compile(graph),
            })
        }
    }

    /// Adapter exposing a `TpuExecutable` through `ExecutableGraph`.
    struct TpuExecutableWrapper {
        inner: TpuExecutable,
    }

    // PJRT clients + buffers are documented as thread-safe per the
    // upstream C API. Same Send-claim shape as CudaExecutableWrapper /
    // RocmExecutableWrapper.
    unsafe impl Send for TpuExecutableWrapper {}

    impl ExecutableGraph for TpuExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }

        /// Typed param upload — widens F16/BF16/etc. host bytes to
        /// f32 today. Once the HLO emitter speaks bf16 natively
        /// (which TPUs prefer over f16), the typed path will hand
        /// the original bytes straight through `Buffer_FromHostBuffer`.
        ///
        /// F32 payloads are borrowed in place when `data` is f32-aligned
        /// and copied otherwise, so callers may pass byte buffers of any
        /// alignment.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            let values = if dtype == rlx_ir::DType::F32 {
                f32_values(data)
            } else {
                std::borrow::Cow::Owned(super::widen_bytes_to_f32(data, dtype))
            };
            self.inner.set_param(name, &values[..]);
        }

        /// Typed run — widen each typed input to F32, run, then narrow
        /// each output back to the dtype reported by `output_dtypes()`.
        /// Outputs beyond that list are returned as F32.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            // Owned f32 copies of every input; names are borrowed straight
            // from `inputs`, so no per-input `String` is needed.
            let widened: Vec<Vec<f32>> = inputs
                .iter()
                .map(|(_, data, dt)| {
                    if *dt == rlx_ir::DType::F32 {
                        f32_values(data).into_owned()
                    } else {
                        super::widen_bytes_to_f32(data, *dt)
                    }
                })
                .collect();
            let refs: Vec<(&str, &[f32])> = inputs
                .iter()
                .zip(&widened)
                .map(|((name, _, _), values)| (*name, values.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
                .collect()
        }
    }
}