Skip to main content

rlx_ir/
op.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Operation types — every tensor op in the RLX IR.
17//!
18//! Designed for pattern-matching fusion: ops are grouped by category so
19//! fusion passes can reason about them structurally.
20
21use crate::DType;
22
/// Element-wise unary functions: one input, applied independently per element.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Activation {
    Gelu,
    GeluApprox,
    /// Gate activation used by SwiGLU.
    Silu,
    Relu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
    Rsqrt,
    Neg,
    Abs,
    /// Sine. Gradient: `dx = upstream · cos(x)`.
    Sin,
    /// Cosine. Gradient: `dx = -upstream · sin(x)`.
    Cos,
    /// Tangent. Gradient: `dx = upstream · sec²(x)`, equivalently
    /// `upstream · (1 + tan²(x))`.
    Tan,
    /// Arctangent. Gradient: `dx = upstream · (1 / (1 + x²))`.
    Atan,
    /// Rounding to the nearest integer in f32, documented as
    /// half-to-even with forward `x.round()`.
    ///
    /// NOTE(review): Rust's `f32::round` breaks ties *away from zero*;
    /// `f32::round_ties_even` is the half-to-even variant. Confirm which
    /// one the kernels implement and align this description.
    ///
    /// The backward pass is a straight-through estimator: the op is
    /// treated as identity, so the upstream gradient is forwarded
    /// untouched. This makes `Round` a building block for hand-written
    /// quantization schemes — Mul-by-recip-scale → Round → Clamp →
    /// Mul-by-scale forms a custom FakeQuantize that the
    /// elementwise-region pass can fuse into a single kernel.
    Round,
}
55
/// Scale-tracking strategy for `Op::FakeQuantize`. Determines how
/// the per-channel `s[c]` is computed each forward pass.
///
/// * `PerBatch` — recompute `s[c] = max(|x|) / q_max` from the
///   current data on every call. Simple, no extra inputs, but
///   noisy for activations (max-abs jumps batch-to-batch).
///
/// * `EMA { decay }` — keep a running `s[c]` in a state tensor
///   (passed as a second op input). On each call, blend the
///   current per-batch max-abs into the state via
///   `state' = decay·state + (1-decay)·max_abs`. Smooth scale
///   over training, makes activation-QAT actually trainable.
///   Typical `decay = 0.99`.
///
/// * `Fixed` — never recompute. The state tensor's value is
///   used as-is each call (set once at construction or by the
///   caller). Useful when scales are pre-calibrated.
///
/// # Equality and hashing
///
/// `PartialEq` is derived, so `decay` is compared with ordinary `f32`
/// semantics: `0.0 == -0.0`, and `NaN` is never equal to anything
/// (including itself). `Eq` is implemented by hand so the type can be
/// used inside hashed IR keys; a `NaN` decay therefore violates `Eq`'s
/// reflexivity and must not be constructed. `Hash` is kept consistent
/// with `PartialEq` by canonicalizing signed zero before hashing.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Default)]
pub enum ScaleMode {
    #[default]
    PerBatch,
    EMA {
        decay: f32,
    },
    Fixed,
}

// `f32` is not `Eq`; this impl is sound only while `decay` is never NaN
// (see the type-level docs).
impl Eq for ScaleMode {}

impl std::hash::Hash for ScaleMode {
    /// Feeds a one-byte variant tag and, for `EMA`, the bit pattern of
    /// `decay` into `state`.
    ///
    /// `+0.0` and `-0.0` compare equal under the derived `PartialEq` but
    /// have different bit patterns, so the value is normalized to `+0.0`
    /// first; otherwise two equal `ScaleMode`s could hash differently and
    /// break `HashMap` / `HashSet` lookups.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            ScaleMode::PerBatch => state.write_u8(0),
            ScaleMode::EMA { decay } => {
                state.write_u8(1);
                // `-0.0 == 0.0` is true, so this folds both zeros together
                // and leaves every other value untouched.
                let canonical = if *decay == 0.0 { 0.0_f32 } else { *decay };
                state.write_u32(canonical.to_bits());
            }
            ScaleMode::Fixed => state.write_u8(2),
        }
    }
}
97
/// Which straight-through estimator `Op::FakeQuantize` uses in its
/// backward pass. Every variant shares the same discrete forward,
/// `clamp(round(x/s)) * s`; only the training-time gradient with
/// respect to `x` differs.
///
/// * `Identity` — `dx = upstream`. The classic STE: the round is
///   treated as identity when differentiating. The simplest choice and
///   adequate for moderate bit widths (i4 / i8).
///
/// * `ClippedIdentity` — `dx = upstream * (|x| ≤ q_max·s)`. The
///   gradient is zeroed wherever the input fell outside the
///   quantization range (i.e. the clamp was active), which keeps the
///   optimizer from driving weights deeper into saturation.
///
/// * `Tanh` — `dx = upstream * (1 - tanh²(x/s))`. A smooth stand-in
///   for the round step that gradually damps the gradient as `|x|`
///   nears `q_max·s`. Often the best pick at tight bit widths (i2).
///
/// * `HardTanh` — `dx = upstream * (1 - |x/(q_max·s)|).max(0)`.
///   A piecewise-linear relative of `Tanh`: cheaper to evaluate and
///   almost as effective.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum SteKind {
    #[default]
    Identity,
    ClippedIdentity,
    Tanh,
    HardTanh,
}
128
/// Two-operand element-wise operations.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Max,
    Min,
    Pow,
}
141
/// Element-wise comparisons; the result is a Bool tensor.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CmpOp {
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,
}
153
/// Selects the masking behavior of the attention kernel.
///
/// The pattern follows MAX's `nn/attention/mha_mask.mojo` (#20 in
/// PLAN.md): a single attention kernel branches on the mask kind rather
/// than requiring every caller to materialize a mask tensor. Two cases
/// benefit most:
///   1. **`None`** — one unpadded sequence: the inner loop performs no
///      mask load and no per-key compare.
///   2. **`Causal`** — autoregressive decode: the kernel derives the
///      upper-triangular fill straight from `(qi, ki)`, so a `seq²`
///      mask tensor is never allocated.
///
/// `Custom` is the pre-existing path, where mask values are read from
/// the 4th input.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MaskKind {
    /// Unmasked: every position attends to every other position.
    None,
    /// Autoregressive masking: position `qi` sees only keys with `ki <= qi`.
    Causal,
    /// Windowed masking: position `qi` sees keys with `ki ∈ [qi - w, qi]`,
    /// where `w` is the payload.
    SlidingWindow(usize),
    /// Mask values come from the input tensor, matching BERT
    /// padding-mask behavior. The tensor is `[batch, key_len]` with
    /// `1.0` = valid and `<0.5` = ignored. Described as the default
    /// behavior; note that this enum itself derives no `Default`.
    Custom,
    /// An additive per-head, per-query bias tensor of shape
    /// `[batch, num_heads, query_len, key_len]`, added to the
    /// `QK^T · scale` scores ahead of softmax. This lets DETR-style
    /// boxRPB and other learned position biases stay on the fast
    /// `Op::Attention` path rather than being decomposed into
    /// matmul + add + softmax + matmul.
    Bias,
}
187
/// Identifies the forward input that an [`Op::AttentionBackward`] node
/// takes the gradient with respect to.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AttentionBwdWrt {
    Query,
    Key,
    Value,
}
196
/// The combining function applied when reducing along the chosen axes.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReduceOp {
    Sum,
    Mean,
    Max,
    Min,
    Prod,
}
207
/// PLAN L4: payload-free discriminant, one per [`Op`] variant.
///
/// [`Op::kind`] and the `Backend::supported_ops` trait method use it to
/// state which ops a backend is able to lower. The `LegalizeForBackend`
/// pass in `rlx-opt` compares the graph against that set and aborts the
/// compile when it finds an unsupported op, rather than falling back
/// silently.
///
/// The section comments below mirror the headings used in [`Op`] where
/// those are known; kinds after `Reduce` are listed in declaration
/// order without any grouping implied.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum OpKind {
    // ── Graph inputs ────────────────────────────────────────────
    Input,
    Param,
    Constant,

    // ── Element-wise unary ──────────────────────────────────────
    Activation,
    Cast,
    StopGradient,
    Quantize,
    Dequantize,
    FakeQuantize,
    FakeQuantizeLSQ,
    FakeQuantizeLSQBackwardX,
    FakeQuantizeLSQBackwardScale,

    // ── Element-wise binary ─────────────────────────────────────
    Binary,

    // ── Comparison, select and fused regions ────────────────────
    Compare,
    Where,
    ElementwiseRegion,
    /// A fused sampling / geometry chain (FKL-style transform region).
    TransformRegion,
    /// One element-wise chain applied across several batch planes
    /// (horizontal fusion).
    BatchElementwiseRegion,

    // ── Linear algebra ──────────────────────────────────────────
    MatMul,
    DotGeneral,
    DenseSolve,
    BatchedDenseSolve,

    // ── Normalization ───────────────────────────────────────────
    LayerNorm,
    LayerNorm2d,
    GroupNorm,
    BatchNormInference,
    RmsNorm,
    ResizeNearest2x,

    // ── Attention ───────────────────────────────────────────────
    Attention,
    Rope,
    AxialRope2d,

    // ── Shape manipulation ──────────────────────────────────────
    Reshape,
    Transpose,
    Narrow,
    Concat,
    Expand,
    Gather,

    // ── Reduction ───────────────────────────────────────────────
    Reduce,

    // ── Further kinds, in declaration order ─────────────────────
    Softmax,
    Cumsum,
    TopK,
    Sample,
    Conv,
    Im2Col,
    ConvTranspose2d,
    Pool,
    ReluBackward,
    ActivationBackward,
    FakeQuantizeBackward,
    ComplexNormSq,
    ComplexNormSqBackward,
    Conjugate,
    MaxPool2dBackward,
    Conv2dBackwardInput,
    Conv2dBackwardWeight,
    SoftmaxCrossEntropyWithLogits,
    SoftmaxCrossEntropyBackward,
    AttentionBackward,
    LayerNormBackwardInput,
    LayerNormBackwardGamma,
    RmsNormBackwardInput,
    RmsNormBackwardGamma,
    RmsNormBackwardBeta,
    RopeBackward,
    GroupNormBackwardInput,
    GroupNormBackwardGamma,
    GroupNormBackwardBeta,
    BatchNormInferenceBackwardInput,
    BatchNormInferenceBackwardGamma,
    BatchNormInferenceBackwardBeta,
    CumsumBackward,
    GatherBackward,
    GroupedMatMul,
    DequantGroupedMatMul,
    DequantMoEWeights,
    ScatterAdd,
    LoraMatMul,
    DequantMatMul,
    QMatMul,
    QConv2d,
    SelectiveScan,
    GatedDeltaNet,
    FusedSwiGLU,
    FusedMatMulBiasAct,
    FusedResidualLN,
    FusedResidualRmsNorm,
    FusedAttentionBlock,
    FusedTransformerLayer,
    If,
    While,
    Scan,
    ScanBackward,
    ScanBackwardXs,
    /// CPU reference raster for 3D Gaussian splats
    /// (project → bin → sort → raster). See [`Op::GaussianSplatRender`].
    GaussianSplatRender,
    /// Backward of [`Op::GaussianSplatRender`], yielding packed scene
    /// parameter gradients.
    GaussianSplatRenderBackward,
    /// Strict IR splat stage 1: project + tile bin + sort + ray grid.
    GaussianSplatPrepare,
    /// Strict IR splat stage 2: per-pixel raster from the prepared buffers.
    GaussianSplatRasterize,
    /// A user-registered op dispatched through `op_registry`. Every
    /// custom op (Sparse-LU, FFT, eigensolve, ...) maps to this one
    /// kind; what distinguishes them is `Op::Custom::name`.
    Custom,
    /// A user-defined sub-graph with optional override AD rules. See
    /// [`Op::CustomFn`] / [`crate::Graph::custom_fn`].
    CustomFn,
    /// The 1D FFT primitive, forward or inverse — see [`Op::Fft`].
    Fft,
    /// A ternary pruned radix-2 butterfly stage — see
    /// [`Op::FftButterflyStage`].
    FftButterflyStage,
    /// Whisper-style log-mel computed from a block-layout FFT spectrum —
    /// see [`Op::LogMel`].
    LogMel,
    /// Backward of [`Op::LogMel`] with respect to the block-layout
    /// spectrum (input 0).
    LogMelBackward,
}
336
/// A value consumed by a fused [`ChainStep`].
///
/// * `Input(i)` — a graph-level input of the enclosing
///   [`Op::ElementwiseRegion`], addressed by index `0..num_inputs`.
/// * `Step(i)` — the result of an earlier step in the same chain,
///   addressed by index `0..step_position`.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ChainOperand {
    Input(u32),
    Step(u32),
}
346
347/// One step in a fused element-wise chain. Each step produces exactly
348/// one scalar result (per element); later steps can refer to it via
349/// [`ChainOperand::Step`]. The whole chain runs per element in registers.
350#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
351#[derive(Debug, Clone, PartialEq)]
352pub enum ChainStep {
353    Activation(Activation, ChainOperand),
354    Cast(DType, ChainOperand),
355    Binary(BinaryOp, ChainOperand, ChainOperand),
356    Compare(CmpOp, ChainOperand, ChainOperand),
357    /// 3-input element-wise select: `cond ? on_true : on_false`. Mirrors
358    /// `Op::Where` inside a chain. `cond` is treated as truthy iff
359    /// non-zero. Lets the optimizer fold attention masks / clamp-style
360    /// patterns into a single region kernel instead of breaking the
361    /// chain at the first `Op::Where`.
362    Where(ChainOperand, ChainOperand, ChainOperand),
363}
364
/// A memory transform applied ahead of the chain and fused into
/// [`Op::ElementwiseRegion`].
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum RegionPrologue {
    #[default]
    None,
    /// The input is a half-resolution NCHW tensor; the output shape is
    /// 2× in H×W (nearest-neighbor 2×).
    ResizeNearest2x,
}
374
375/// One sampling step in [`Op::TransformRegion`].
376#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
377#[derive(Debug, Clone, PartialEq)]
378pub enum TransformStep {
379    ResizeNearest2x(ChainOperand),
380}
381
382/// An operation in the RLX IR graph.
383///
384/// Operations are categorized for fusion analysis:
385/// - Element-wise ops fuse with anything reading their output
386/// - Matmul/Conv are BLAS-dispatched and form fusion boundaries
387/// - Reductions are fusion roots (drive the loop iteration)
388#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
389#[derive(Debug, Clone, PartialEq)]
390pub enum Op {
391    // ── Graph inputs ────────────────────────────────────────────
392    /// Model input with a name (shape on the Node).
393    Input {
394        name: String,
395    },
396
397    /// Model parameter (weight/bias) with a name.
398    Param {
399        name: String,
400    },
401
402    /// Constant tensor embedded in the graph.
403    Constant {
404        data: Vec<u8>,
405    },
406
407    // ── Element-wise unary ──────────────────────────────────────
408    /// Unary activation: one input, same shape output.
409    Activation(Activation),
410
411    /// Cast to a different dtype.
412    Cast {
413        to: DType,
414    },
415
416    /// Stop-gradient (a.k.a. `detach` / `lax.stop_gradient`). Forward is
417    /// identity; the reverse-mode autodiff rule returns **no** gradient
418    /// contribution for the input. Single input, same shape & dtype
419    /// output. Used to build a Gradient-Reverse-Layer with identity
420    /// forward semantics in user code (see maet-rs `dat_loss`).
421    StopGradient,
422
423    /// INT8 quantization. Input f32; output i8 same shape.
424    ///   `q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])`
425    /// where `c` selects the per-channel scale/zp when `axis = Some(d)`
426    /// (`c = idx[d]`), or always uses index 0 when `axis = None`
427    /// (per-tensor). The `scales` / `zero_points` payload length must
428    /// match `1` for per-tensor and `input.dim(d)` for per-channel.
429    /// Static — typically produced at calibration time and baked
430    /// into the loaded model. Use `Op::Dequantize` for the inverse.
431    Quantize {
432        axis: Option<usize>,
433        scales: Vec<f32>,
434        zero_points: Vec<i32>,
435    },
436
437    /// INT8 dequantization (inverse of `Op::Quantize`). Input i8;
438    /// output f32 same shape.
439    ///   `x[i] = (q[i] - zero_point[c]) · scale[c]`
440    /// where `c` is selected by `axis` exactly as in `Op::Quantize`.
441    Dequantize {
442        axis: Option<usize>,
443        scales: Vec<f32>,
444        zero_points: Vec<i32>,
445    },
446
447    /// "Fake-quantize" op for **quantization-aware training** (QAT).
448    /// Input f32; output f32 same shape. Forward computes a per-axis
449    /// (or per-tensor when `axis = None`) max-abs scale on the fly:
450    ///   `s[c] = max(|x[..., c, ...]|) / q_max(bits)`
451    /// then quantizes-then-dequantizes:
452    ///   `out[i] = clamp(round(x[i] / s[c]), -q_max, q_max) * s[c]`
453    /// where `q_max` is `127` for `bits=8`, `7` for `bits=4`, `1` for
454    /// `bits=2` (ternary). Symmetric only — zero-point is always 0.
455    ///
456    /// The point of this op is to make the SGD optimizer "see" the
457    /// deployment-time rounding during training. Backward is the
458    /// **straight-through estimator** (STE): the gradient passes
459    /// through (variant chosen by `ste`), ignoring the discontinuity
460    /// at the round. Without STE the rounding would have zero
461    /// gradient almost everywhere and learning would stop.
462    ///
463    /// Inserted by the trainer on conv / FC weight tensors when
464    /// `--qat` is on; the existing `Op::Quantize` / packing path at
465    /// the end of training still handles the deployment-side
466    /// conversion to `i8`/`i4`/`i2` codes.
467    FakeQuantize {
468        bits: u8,
469        axis: Option<usize>,
470        ste: SteKind,
471        scale_mode: ScaleMode,
472    },
473
474    /// Learned Step Size Quantization (LSQ; Esser et al. 2020,
475    /// `arXiv:1902.08153`). Like `FakeQuantize` but the per-channel
476    /// `scale` is a *learned parameter*, passed as the second input.
477    /// Forward is identical to `FakeQuantize` with a fixed scale:
478    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
479    /// Backward computes both `dx` (STE) and `dscale[c]` via the
480    /// closed-form gradient:
481    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
482    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
483    /// `sign(z) · q_max`. Routinely beats per-batch and EMA at
484    /// tight bit widths (i2 / i3).
485    ///
486    /// Inputs: `[x, scale]`. `scale` is `[chan_dim]` f32 (matches
487    /// `axis`); for `axis = None` it's `[1]`.
488    FakeQuantizeLSQ {
489        bits: u8,
490        axis: Option<usize>,
491    },
492
493    /// Backward pass for `Op::FakeQuantizeLSQ`. Computes BOTH the
494    /// gradient w.r.t. `x` (STE) and the gradient w.r.t. `scale`
495    /// (closed-form). Output shape matches `x`; the `scale` gradient
496    /// is reduced separately by `LsqScaleGradient`.
497    /// Inputs: `[x, scale, dy]`. Output: `dx`, same shape as `x`.
498    FakeQuantizeLSQBackwardX {
499        bits: u8,
500        axis: Option<usize>,
501    },
502
503    /// Companion to `FakeQuantizeLSQBackwardX`: computes the
504    /// `[chan_dim]` per-channel scale gradient. Inputs `[x, scale, dy]`.
505    /// Output shape matches `scale`.
506    FakeQuantizeLSQBackwardScale {
507        bits: u8,
508        axis: Option<usize>,
509    },
510
511    // ── Element-wise binary ─────────────────────────────────────
512    /// Binary op with broadcasting: two inputs, output shape is broadcast result.
513    Binary(BinaryOp),
514
515    // ── Comparison ──────────────────────────────────────────────
516    /// Element-wise comparison: two inputs, Bool output.
517    Compare(CmpOp),
518
519    /// Select elements: cond (Bool), on_true, on_false → output.
520    Where,
521
522    /// Fused element-wise region (PLAN L2). Holds an N-step chain of
523    /// element-wise operations. Inputs are referenced by index 0..num_inputs;
524    /// each step's result can be referenced by later steps via
525    /// `ChainOperand::Step(idx)`. The output is the last step's result.
526    /// Emitted by `MarkElementwiseRegions` in `rlx-opt` from chains of
527    /// Activation/Cast/Binary/Compare/Where ops with single-consumer
528    /// intermediates and broadcast-compatible shapes. Backends that
529    /// don't have a region kernel can decompose back to the original
530    /// chain via `unfuse::unfuse_elementwise_regions`.
531    ///
532    /// `scalar_input_mask` is a per-input bitfield (bit `i` set ⇒
533    /// input `i` is a scalar broadcast — has shape `[1]`). Kept as a
534    /// fast-path indicator that lets kernels skip the modulo entirely
535    /// when they detect a scalar.
536    ///
537    /// `input_modulus[i]` is the per-input element count, used by
538    /// kernels to compute `arena[input_offs[i] + (gid % input_modulus[i])]`
539    /// — the trailing-shape broadcast pattern. `0` means "no broadcast"
540    /// (input matches the output element count; kernel reads `gid`
541    /// directly). `1` means scalar; any other value means the input
542    /// has fewer elements than the output and they tile by modulo.
543    /// The encoder only allows broadcasts where `out_elems % in_elems
544    /// == 0` so the modulo divides cleanly. Lets chains include bias /
545    /// scale / eps / mask factors that previously broke the chain at
546    /// a Binary op with mismatched shapes.
547    ElementwiseRegion {
548        chain: Vec<ChainStep>,
549        num_inputs: u32,
550        scalar_input_mask: u32,
551        input_modulus: [u32; 16],
552        /// FKL-style closed fusion: apply before the element-wise chain.
553        prologue: RegionPrologue,
554        /// External input index that supplies the prologue transform source (default 0).
555        prologue_input: u32,
556    },
557
558    /// Fused transform chain (resize, future crop/color). Decompose via
559    /// [`rlx_fusion::DecomposeFusionRegions`] when no native kernel exists.
560    TransformRegion {
561        steps: Vec<TransformStep>,
562        num_inputs: u32,
563    },
564
565    /// Identical [`Op::ElementwiseRegion`] chain over `num_batch_inputs` tensors
566    /// (horizontal / z-plane fusion). Inputs are separate batch slices.
567    BatchElementwiseRegion {
568        chain: Vec<ChainStep>,
569        num_batch_inputs: u32,
570        scalar_input_mask: u32,
571        input_modulus: [u32; 16],
572        prologue: RegionPrologue,
573        prologue_input: u32,
574    },
575
576    // ── Linear algebra ──────────────────────────────────────────
577    /// Matrix multiply. Inputs: [.., M, K] × [.., K, N] → [.., M, N].
578    /// Batch dimensions are broadcast.
579    MatMul,
580
581    /// Matrix multiply with explicit dimension specification.
582    /// Like XLA's DotGeneral — handles arbitrary batch/contracting dims.
583    DotGeneral {
584        lhs_contracting: Vec<usize>,
585        rhs_contracting: Vec<usize>,
586        lhs_batch: Vec<usize>,
587        rhs_batch: Vec<usize>,
588    },
589
590    /// Batched dense linear solve. Inputs: `A [B, N, N]`,
591    /// `b [B, N]` or `b [B, N, K]`. Output: same shape as `b`.
592    ///
593    /// Per-batch independent solve — each `A[i]` and `b[i]` are
594    /// solved as a separate `Op::DenseSolve`. Emitted by vmap of
595    /// `Op::DenseSolve`. The CPU lowering loops over the batch
596    /// dimension calling `dgesv` per slice (LAPACK doesn't expose a
597    /// batched solve on Accelerate; cuSOLVER does on NVIDIA).
598    BatchedDenseSolve,
599
600    /// Dense linear solve `x = A⁻¹ · b` via LU factorization.
601    /// Inputs: `A [N, N]`, `b [N]` (or `b [N, K]` for multi-RHS).
602    /// Output: same shape as `b`.
603    ///
604    /// VJP via the implicit-function theorem:
605    ///   `dx = solve(Aᵀ, upstream)`
606    ///   `dA = -outer(dx, x)`   (x is the forward output)
607    ///   `db = dx`
608    /// The rule is dtype-agnostic; lowering is per-backend (Accelerate
609    /// `dgesv` / `sgesv`, cuSOLVER, etc.).
610    DenseSolve,
611
612    // ── Normalization ───────────────────────────────────────────
613    /// Layer normalization: input, gamma, beta → normalized output.
614    /// `axis` is the feature dimension (usually -1).
615    LayerNorm {
616        axis: i32,
617        eps: f32,
618    },
619
620    /// Group normalization on NCHW tensors: `input`, `gamma`, `beta` → same shape.
621    /// Normalizes over `(C/num_groups) × H × W` per group.
622    GroupNorm {
623        num_groups: usize,
624        eps: f32,
625    },
626
627    /// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
628    /// position (candle / SAM `LayerNorm2d` semantics — not PyTorch's H×W norm).
629    LayerNorm2d {
630        eps: f32,
631    },
632
633    /// Nearest-neighbor 2× upsample on NCHW (doubles spatial dims 2 and 3).
634    ResizeNearest2x,
635
636    /// RMS normalization: input, gamma → normalized output.
637    RmsNorm {
638        axis: i32,
639        eps: f32,
640    },
641
642    /// BatchNorm inference with frozen running statistics.
643    /// Inputs: `x`, `gamma`, `beta`, `running_mean`, `running_var`.
644    /// Feature dimension is the last axis of `x`; stats are 1-D `[C]`.
645    BatchNormInference {
646        eps: f32,
647    },
648
649    // ── Attention ───────────────────────────────────────────────
650    /// Scaled dot-product attention: Q, K, V, \[mask\] → output.
651    /// The compiler can lower this to fused SDPA or flash attention.
652    /// `mask_kind` controls how masking is applied — `Custom` reads from
653    /// the 4th input tensor; `None` / `Causal` / `SlidingWindow` skip the
654    /// mask load and apply the mask directly in the inner loop. See
655    /// `MaskKind` for the rationale.
656    ///
657    /// `score_scale`: when `Some(s)`, dot-product scores are multiplied by
658    /// `s` instead of the default `1/sqrt(head_dim)` (Gemma uses `head_dim^-0.5`
659    /// explicitly in config). `attn_logit_softcap`: when `Some(c)`, applies
660    /// `c * tanh(s/c)` to scores before softmax (Gemma 2).
661    Attention {
662        num_heads: usize,
663        head_dim: usize,
664        mask_kind: MaskKind,
665        score_scale: Option<f32>,
666        attn_logit_softcap: Option<f32>,
667    },
668
669    /// Rotary position embedding applied to one tensor: x, cos, sin → x_rotated.
670    /// Apply separately to Q and K. `head_dim` is the per-head width; `n_rot`
671    /// is how many leading dims get NeoX RoPE (pair offset `n_rot/2`). When
672    /// `n_rot < head_dim`, trailing dims are copied unchanged (Qwen3.5 MRoPE).
673    Rope {
674        head_dim: usize,
675        n_rot: usize,
676    },
677
678    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
679    AxialRope2d {
680        end_x: usize,
681        end_y: usize,
682        head_dim: usize,
683        num_heads: usize,
684        theta: f32,
685        repeat_factor: usize,
686    },
687
688    // ── Shape manipulation ──────────────────────────────────────
689    Reshape {
690        new_shape: Vec<i64>,
691    },
692    Transpose {
693        perm: Vec<usize>,
694    },
695    /// Select a contiguous slice along an axis.
696    Narrow {
697        axis: usize,
698        start: usize,
699        len: usize,
700    },
701    /// Concatenate along an axis.
702    Concat {
703        axis: usize,
704    },
705    /// Expand (broadcast) to a target shape.
706    Expand {
707        target_shape: Vec<i64>,
708    },
709    /// Gather elements by index along an axis (embedding lookup).
710    Gather {
711        axis: usize,
712    },
713
714    // ── Reduction ───────────────────────────────────────────────
715    /// Reduce along specified axes.
716    Reduce {
717        op: ReduceOp,
718        axes: Vec<usize>,
719        keep_dim: bool,
720    },
721
722    /// Selective scan (plan #15) — Mamba-style state-space model
723    /// step. The recurrence:
724    ///   `h[t] = exp(Δ[t] * A) * h[t-1] + Δ[t] * B[t] * x[t]`
725    ///   `y[t] = C[t] * h[t]`
726    /// where state `h` has dimension `state_size` and the input has
727    /// `(batch, seq, hidden)`.
728    ///
729    /// Inputs (in order):
730    ///   `x [b, s, h]`      f32 input
731    ///   `delta [b, s, h]`  f32 step size (per-position, per-channel)
732    ///   `a [h, n]`         f32 transition matrix (one per channel)
733    ///   `b [b, s, n]`      f32 input projection
734    ///   `c [b, s, n]`      f32 output projection
735    /// Output: `[b, s, h]` f32. State `h` is implicit; the kernel
736    /// scans through the seq dimension carrying it.
737    ///
738    /// `state_size` = `n` is exposed for the cost model.
739    SelectiveScan {
740        state_size: usize,
741    },
742
743    /// Gated DeltaNet linear-attention recurrence — the per-layer
744    /// kernel used by Qwen3.5/3.6 trunk "linear attention" blocks
745    /// (and Qwen3-Next, Kimi-Linear). Mirrors
746    /// `llama.cpp / src/models/delta-net-base.cpp` autoregressive
747    /// path; chunked + fused variants ride the same op identity.
748    ///
749    /// **Math (per token `t`, head `h`, state size `n`):**
750    /// state matrix `S[h, i, j]` is implicit (reset per batch).
751    /// ```text
752    ///   S[h]     *= exp(g[t,h])                     # scalar gate
753    ///   sk[h,j]   = Σ_i S[h,i,j] * k[t,h,i]
754    ///   d[h,j]    = (v[t,h,j] - sk[h,j]) * b[t,h]   # b = beta
755    ///   S[h,i,j] += k[t,h,i] * d[h,j]               # outer-prod
756    ///   o[t,h,j]  = Σ_i S[h,i,j] * (q[t,h,i] / √n)
757    /// ```
758    ///
759    /// Inputs:
760    ///   `q     [b, s, h_v, n]`  f32 queries (L2-normed by caller)
761    ///   `k     [b, s, h_v, n]`  f32 keys    (L2-normed by caller;
762    ///                            GQA-repeated to match `h_v`)
763    ///   `v     [b, s, h_v, n]`  f32 values
764    ///   `g     [b, s, h_v]`     f32 log-gate (exp'd inside kernel)
765    ///   `beta  [b, s, h_v]`     f32 delta-rule mixing factor
766    ///
767    /// Output: `[b, s, h_v, n]` f32.
768    ///
769    /// When `carry_state` is true, a sixth input `state [b, h_v, n, n]`
770    /// provides the initial SSM matrix per head; the kernel updates it
771    /// in place across the sequence and leaves the final state in the
772    /// same buffer (same layout as the internal scan state:
773    /// `state[h, i, j]` row-major over `(n, n)` per head).
774    GatedDeltaNet {
775        state_size: usize,
776        carry_state: bool,
777    },
778
779    /// Fused dequant + matmul (plan #5). The biggest LLM-bandwidth
780    /// win on Apple Silicon: dequantizes weights inside the matmul
781    /// inner loop, never materializing f32 weights.
782    ///
783    /// **BREAKING CHANGE in 0.2.0:** `num_inputs()` is now
784    /// scheme-dependent — **4** for legacy Int8 schemes, **2** for
785    /// the new GGUF K-quant schemes (their scales/mins live inside
786    /// the packed bytes, so no side-channel `scale` / `zp` tensors
787    /// are fed in). Callers that assumed a fixed 4-input contract
788    /// must inspect `scheme.is_gguf()` before reading inputs.
789    ///
790    /// Inputs (Int8 schemes — `scheme.is_gguf() == false`):
791    ///   `x [m, k]`             f32 activations
792    ///   `w_q [k, n]` packed    quantized weight bytes (i8 per
793    ///                          element for Int8 schemes; 4-bit
794    ///                          packed two-per-byte for Int4)
795    ///   `scale [k/block, n]`   per-block f32 dequant scale
796    ///   `zp    [k/block, n]`   per-block f32 zero-point
797    ///                          (zero-tensor if symmetric)
798    ///
799    /// Inputs (`Nvfp4Block` — fixed group size 16 along K):
800    ///   `x [m, k]`             f32 activations
801    ///   `w_q [k,n/2]` packed   FP4 E2M1 codes (unsigned nibble 0..15)
802    ///   `scale [k/16, n]` u8   FP8 E4M3 block scales (one byte / group)
803    ///   `global_scale [1]` f32 per-tensor scale (pass `[1.0]` if unused)
804    ///
805    /// Inputs (GGUF schemes — `scheme.is_gguf() == true`):
806    ///   `x [m, k]`             f32 activations
807    ///   `packed_w [bytes]`     raw GGUF super-block bytes; the
808    ///                          dequantizer reads the per-sub-block
809    ///                          scales / mins / quants directly out
810    ///                          of the buffer per the K-quant block
811    ///                          layout (no side tensors).
812    ///
813    /// Output: `[m, n]` f32.
814    ///
815    /// `block_size` (on the Int8 schemes only) is the number of
816    /// consecutive elements that share one (scale, zero_point) pair.
817    /// The Op carries enough metadata that the kernel doesn't need
818    /// a separate `QuantMap` lookup at run time.
819    DequantMatMul {
820        scheme: crate::quant::QuantScheme,
821    },
822
823    /// Real INT8-arithmetic matrix multiply with i32 accumulation.
824    /// Inputs (in order):
825    ///   `x      [M, K]`  i8 activations (zero-point = `x_zp`)
826    ///   `w      [K, N]`  i8 weights     (zero-point = `w_zp`)
827    ///   `bias   [N]`     i32 (in accumulator scale = `x_scale·w_scale`),
828    ///                    pass a zeros tensor for "no bias"
829    /// Output:  `[M, N]`  i8 (zero-point = `out_zp`)
830    ///
831    /// Per-element compute:
832    ///   `out[m,n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
833    /// where `mult = x_scale · w_scale / out_scale`.
834    ///
835    /// This is the same kernel shape `rlx-cortexm/src/dense.rs`
836    /// uses for on-device int8 inference, lifted into the IR so the
837    /// rlx-cpu backend can run a quantized graph directly (instead
838    /// of round-tripping through fake-quant Dequantize → MatMul →
839    /// Quantize). 2-D only — generalizing to batched comes when a
840    /// real workload demands it.
841    QMatMul {
842        x_zp: i32,
843        w_zp: i32,
844        out_zp: i32,
845        mult: f32,
846    },
847
848    /// Real INT8-arithmetic 2-D convolution with i32 accumulation.
849    /// Inputs:
850    ///   `x      [N, C_in, H, W]`              i8 (zero-point = `x_zp`)
851    ///   `w      [C_out, C_in/groups, kH, kW]` i8 (zero-point = `w_zp`)
852    ///   `bias   [C_out]`                      i32 in accumulator scale
853    /// Output: `[N, C_out, H_out, W_out]` i8 (zero-point = `out_zp`).
854    /// Same NCHW geometry contract as `Op::Conv`; same requantize
855    /// math as `Op::QMatMul` (per-element `acc·mult` rounded to i8).
856    QConv2d {
857        kernel_size: Vec<usize>,
858        stride: Vec<usize>,
859        padding: Vec<usize>,
860        dilation: Vec<usize>,
861        groups: usize,
862        x_zp: i32,
863        w_zp: i32,
864        out_zp: i32,
865        mult: f32,
866    },
867
868    /// Fused LoRA matmul: `out = x·W + scale * x·A·B`.
869    /// Inputs (in order): `x [m, k]`, `w [k, n]`, `a [k, r]`, `b [r, n]`.
870    /// `r` is the LoRA rank (typically 4-64). `scale` is the
871    /// per-adapter `alpha / rank` knob.
872    /// Plan #9: lifts LoRA from "three matmuls + an add" into one
873    /// kernel that keeps the rank-r intermediate in registers.
874    LoraMatMul {
875        scale: f32,
876    },
877
878    /// Fused sampling kernel: logits → optional top-k filter →
879    /// optional top-p truncation → softmax → multinomial sample.
880    /// One f32-encoded sampled token id per batch row (output
881    /// shape `[batch]`).
882    ///
883    /// `temperature == 1.0` is neutral (plain softmax sampling);
884    /// lower → sharper, higher → flatter. `top_k == 0` disables.
885    /// `top_p == 1.0` disables. `seed` is the Philox seed; pass 0
886    /// for "use process-global counter" (still deterministic
887    /// given the call order).
888    /// Borrowed from MAX's nn/sampling.mojo (#42 in PLAN.md).
889    /// Latency-critical: never materializes the full softmax
890    /// distribution on the host.
891    Sample {
892        top_k: usize,     // 0 = disabled
893        top_p: f32,       // 1.0 = disabled
894        temperature: f32, // 1.0 = neutral
895        seed: u64,        // 0 = use thread-local counter
896    },
897
898    /// Inclusive cumulative sum along an axis. Same shape in/out.
899    /// Underpins ragged-tensor offsets, sampling (top-p prefix sum),
900    /// and sequence-position math (#44 in PLAN.md).
901    /// `exclusive=true` shifts the result so output\[0\] = 0 (useful
902    /// for offset arrays where the first segment starts at 0).
903    Cumsum {
904        axis: i32,
905        exclusive: bool,
906    },
907
908    /// Softmax along an axis (reduction + element-wise).
909    Softmax {
910        axis: i32,
911    },
912
913    /// Top-K **indices** along the last axis. Output shape `[..., k]`,
914    /// f32-encoded indices (rlx is f32-only at the I/O boundary).
915    /// To recover the values, follow with a `Gather` against the
916    /// original tensor — works because Gather already supports any axis.
917    /// Ties broken by smaller index (matches NumPy / PyTorch
918    /// `torch.topk(..., largest=True, sorted=True)`).
919    /// Used by MoE gating; also useful for beam search.
920    TopK {
921        k: usize,
922    },
923
924    /// Indexed batched matmul. The MoE GEMM primitive.
925    /// Inputs: `[input, weight, expert_idx]`
926    ///   input       : [M, K]                — per-token activations
927    ///   weight      : [num_experts, K, N]   — stacked expert weights
928    ///   expert_idx  : \[M\]                   — f32-encoded expert id per token
929    /// Output         : [M, N]                — output\[i\] = input\[i\] @ weight[expert_idx\[i\]]
930    /// Naive impl on both backends; future work can replace with a
931    /// segmented/grouped GEMM when there's a real workload.
932    GroupedMatMul,
933
934    /// Fused GGUF K-quant dequant + [`Op::GroupedMatMul`]. Same three
935    /// inputs as `GroupedMatMul`, but `weight` is a U8 tensor holding
936    /// `num_experts` contiguous packed expert slabs (GGML layout, expert
937    /// dimension outermost). Scales live inside the packed bytes.
938    DequantGroupedMatMul {
939        scheme: crate::quant::QuantScheme,
940    },
941
942    /// Dequant a packed MoE expert stack to F32 `[num_experts, K, N]` in
943    /// GroupedMatMul layout. Input: U8 packed bytes; output shape is
944    /// declared on the node (`[E, K, N]`).
945    DequantMoEWeights {
946        scheme: crate::quant::QuantScheme,
947    },
948
949    /// Scatter-add into a destination tensor. The "unpermute" half of
950    /// MoE routing (also useful for embedding gradient updates).
951    /// Inputs: `[updates, indices]`
952    ///   updates : [num_updates, trailing]   — values to add
953    ///   indices : \[num_updates\]             — f32-encoded destination row
954    /// Output    : [out_dim, trailing]       — output[indices\[i\]] += updates\[i\]
955    /// `out_dim` is taken from the node's declared output shape.
956    /// Initial output is zero; multiple updates to the same row
957    /// accumulate (sequentially on CPU; with atomic-add on Metal).
958    ScatterAdd,
959
960    // ── Convolution ─────────────────────────────────────────────
961    /// 2D convolution on NCHW tensors. Also exposed as [`OpKind::Conv`] / `conv2d`.
962    /// Weight layout: `[C_out, C_in / groups, kH, kW]`.
963    Conv {
964        kernel_size: Vec<usize>,
965        stride: Vec<usize>,
966        padding: Vec<usize>,
967        dilation: Vec<usize>,
968        groups: usize,
969    },
970
971    /// NCHW im2col for conv backward-weight style matmul.
972    /// Input `[N, C, H, W]`. Output `[M, C·kH·kW]` with
973    /// `M = N · H_out · W_out`. When batch is [`dynamic::sym::BATCH`],
974    /// output rows use [`dynamic::sym::ROWS`] (bind `N · H_out · W_out`).
975    Im2Col {
976        kernel_size: Vec<usize>,
977        stride: Vec<usize>,
978        padding: Vec<usize>,
979        dilation: Vec<usize>,
980    },
981
982    /// 2D transposed convolution on NCHW. Weight layout (PyTorch):
983    /// `[C_in, C_out / groups, kH, kW]`.
984    ConvTranspose2d {
985        kernel_size: Vec<usize>,
986        stride: Vec<usize>,
987        padding: Vec<usize>,
988        dilation: Vec<usize>,
989        output_padding: Vec<usize>,
990        groups: usize,
991    },
992
993    // ── Pooling ─────────────────────────────────────────────────
994    Pool {
995        kind: ReduceOp,
996        kernel_size: Vec<usize>,
997        stride: Vec<usize>,
998        padding: Vec<usize>,
999    },
1000
1001    // ── Backward / training ops ─────────────────────────────────
1002    //
1003    // Closed-form gradient nodes emitted by `rlx-opt::autodiff`.
1004    // Pairing a forward op with a dedicated backward op (instead of
1005    // composing it from primitives) keeps the gradient kernel simple
1006    // and lets the backend recompute argmax / masks / softmax inline.
1007    /// ReLU backward: `dx = dy where x > 0 else 0`.
1008    /// Inputs: `[x, dy]` — both same shape; output matches.
1009    ReluBackward,
1010
1011    /// Element-wise complex squared-magnitude: `|z|² = z.re² + z.im²`.
1012    /// Input: 1 tensor with `DType::C64`. Output: same shape but
1013    /// `DType::F32`. The natural real-valued loss surface for
1014    /// Wirtinger reverse-mode AD on complex graphs — pair with
1015    /// [`Op::ComplexNormSqBackward`].
1016    ComplexNormSq,
1017
1018    /// Element-wise complex conjugate: `z̄ = z.re - i·z.im` per element.
1019    /// Input: 1 tensor with `DType::C64`. Output: same shape, same dtype.
1020    /// Used by Wirtinger VJP rules on `Op::Binary` over C64 (the rule
1021    /// for `y = a·b` is `dL/dā = upstream · conj(b)`, etc.).
1022    Conjugate,
1023
1024    /// Backward for [`Op::ComplexNormSq`] under Wirtinger calculus.
1025    /// `f(z) = |z|² = z·z̄`, so `∂f/∂z̄ = z`. Given upstream real
1026    /// cotangent `g` (same shape as the forward output), the C64
1027    /// gradient with respect to `z` is `g · z` element-wise, returned
1028    /// in C64 storage `[re_g·re_z, re_g·im_z]` per element.
1029    ///
1030    /// Inputs: `[z (C64), g (F32)]` — both same logical shape; output
1031    /// matches `z` (C64).
1032    ComplexNormSqBackward,
1033
1034    /// LayerNorm backward w.r.t. the input. Computes
1035    ///   `d_x[..., d] = inv_std · (dy·γ - mean(dy·γ) - x̂·mean(dy·γ·x̂))`
1036    /// over the feature axis, where `x̂ = (x - mean)/std` is recomputed
1037    /// inline from `x`. Inputs: `[x, gamma, dy]`; output shape = `x.shape`.
1038    /// Currently lowers axis=-1 only (matches the forward thunk).
1039    LayerNormBackwardInput {
1040        axis: i32,
1041        eps: f32,
1042    },
1043
1044    /// LayerNorm backward w.r.t. gamma. Computes
1045    ///   `d_gamma[d] = Σ_{batch} dy[..., d] · x̂[..., d]`
1046    /// — sums the per-element product of upstream and the (recomputed)
1047    /// normalized input over the leading axes. Inputs: `[x, dy]`;
1048    /// output shape = `gamma.shape` (= 1-D feature axis).
1049    LayerNormBackwardGamma {
1050        axis: i32,
1051        eps: f32,
1052    },
1053
1054    /// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`; output = `x.shape`.
1055    RmsNormBackwardInput {
1056        axis: i32,
1057        eps: f32,
1058    },
1059
1060    /// RMSNorm backward w.r.t. gamma. Inputs `[x, gamma, beta, dy]`; output = `gamma.shape`.
1061    RmsNormBackwardGamma {
1062        axis: i32,
1063        eps: f32,
1064    },
1065
1066    /// RMSNorm backward w.r.t. beta. Inputs `[x, gamma, beta, dy]`; output = `beta.shape`.
1067    RmsNormBackwardBeta {
1068        axis: i32,
1069        eps: f32,
1070    },
1071
1072    /// RoPE backward w.r.t. `x`. Inputs `[dy, cos, sin]`; output = `dy.shape`.
1073    RopeBackward {
1074        head_dim: usize,
1075        n_rot: usize,
1076    },
1077
1078    /// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
1079    GroupNormBackwardInput {
1080        num_groups: usize,
1081        eps: f32,
1082    },
1083
1084    /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`; output = `gamma.shape`.
1085    GroupNormBackwardGamma {
1086        num_groups: usize,
1087        eps: f32,
1088    },
1089
1090    /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`; output = `beta.shape`.
1091    GroupNormBackwardBeta {
1092        num_groups: usize,
1093        eps: f32,
1094    },
1095
1096    /// BatchNorm inference backward w.r.t. `x`. Inputs `[x, gamma, mean, var, dy]`.
1097    BatchNormInferenceBackwardInput {
1098        eps: f32,
1099    },
1100
1101    /// BatchNorm inference backward w.r.t. `gamma`. Inputs `[x, mean, var, dy]`.
1102    BatchNormInferenceBackwardGamma {
1103        eps: f32,
1104    },
1105
1106    /// BatchNorm inference backward w.r.t. `beta`. Inputs `[dy]`; output = `beta.shape`.
1107    BatchNormInferenceBackwardBeta,
1108
1109    /// Cumsum backward along `axis`. Inputs `[dy]`; output matches forward input shape.
1110    CumsumBackward {
1111        axis: i32,
1112        exclusive: bool,
1113    },
1114
1115    /// Gather backward (scatter-add into table). Inputs `[dy, indices]`; output = table shape.
1116    /// `axis` matches forward [`Op::Gather`].
1117    GatherBackward {
1118        axis: i32,
1119    },
1120
1121    /// Generic element-wise activation backward. `kind` selects the
1122    /// closed-form derivative `d/dx act(x)`. Inputs: `[x, dy]`; output
1123    /// shape matches `x`. The kernel computes `(d/dx act(x)) · dy` per element.
1124    ///
1125    /// Closed forms (all element-wise):
1126    /// * `Gelu`     — exact derivative of erf-based GELU.
1127    /// * `GeluApprox` — derivative of the tanh approximation
1128    ///   `0.5 x (1 + tanh(√(2/π)(x + 0.044715 x³)))`.
1129    /// * `Silu`     — `σ(x)·(1 + x·(1 - σ(x)))`.
1130    /// * `Sigmoid`  — `σ(x)·(1 - σ(x))`.
1131    /// * `Tanh`     — `1 - tanh(x)²`.
1132    /// * `Exp`      — `exp(x)`.
1133    /// * `Log`      — `1 / x`.
1134    /// * `Sqrt`     — `0.5 / sqrt(x)`.
1135    /// * `Rsqrt`    — `-0.5 · x^(-3/2)`.
1136    /// * `Neg`      — `-1`.
1137    /// * `Abs`      — `sign(x)` (zero at x=0).
1138    /// * `Sin`      — `cos(x)`.
1139    /// * `Cos`      — `-sin(x)`.
1140    /// * `Tan`      — `1 + tan²(x)`.
1141    /// * `Atan`     — `1 / (1 + x²)`.
1142    /// * `Relu`     — kept here for completeness; the dedicated
1143    ///   `ReluBackward` op is preferred for relu and is what the
1144    ///   autodiff pass actually emits.
1145    ActivationBackward {
1146        kind: Activation,
1147    },
1148
1149    /// Backward for `Op::FakeQuantize` under a non-default STE.
1150    /// Inputs `[x, dy]`: the forward input and the upstream
1151    /// gradient. Output `dx` same shape. The `bits`/`axis`/`ste`
1152    /// fields must match the forward op so the kernel computes the
1153    /// same per-channel scale and applies the right STE attenuation.
1154    /// For `SteKind::Identity` this op is unnecessary — autodiff
1155    /// just routes `upstream` through unchanged.
1156    FakeQuantizeBackward {
1157        bits: u8,
1158        axis: Option<usize>,
1159        ste: SteKind,
1160    },
1161
1162    /// 2D max-pool backward. Routes each element of `dy` back into the
1163    /// position in `x`'s window where the forward max was taken.
1164    /// Inputs: `[x, dy]` with `x [N, C, H, W]` and
1165    /// `dy [N, C, H_out, W_out]`. Output: same shape as `x`.
1166    /// Carries the forward pool's geometry so the kernel can recompute
1167    /// the argmax position per window without a saved-indices tensor.
1168    MaxPool2dBackward {
1169        kernel_size: Vec<usize>,
1170        stride: Vec<usize>,
1171        padding: Vec<usize>,
1172    },
1173
1174    /// 2D conv backward w.r.t. input. Computes `dx = conv_transpose(dy, w)`.
1175    /// Inputs: `[dy, w]` with `dy [N, C_out, H_out, W_out]` and
1176    /// `w [C_out, C_in/groups, kH, kW]`. Output: `[N, C_in, H, W]`
1177    /// (declared on the node — caller knows the original input shape).
1178    /// Geometry is the forward conv's parameters, not the transposed
1179    /// conv's.
1180    Conv2dBackwardInput {
1181        kernel_size: Vec<usize>,
1182        stride: Vec<usize>,
1183        padding: Vec<usize>,
1184        dilation: Vec<usize>,
1185        groups: usize,
1186    },
1187
1188    /// 2D conv backward w.r.t. weight. Computes
1189    /// `dw[c_out, c_in, kh, kw] = sum_{n,h_out,w_out} x[n,c_in,...] * dy[n,c_out,h_out,w_out]`.
1190    /// Inputs: `[x, dy]`. Output: `[C_out, C_in/groups, kH, kW]`.
1191    Conv2dBackwardWeight {
1192        kernel_size: Vec<usize>,
1193        stride: Vec<usize>,
1194        padding: Vec<usize>,
1195        dilation: Vec<usize>,
1196        groups: usize,
1197    },
1198
1199    /// Fused softmax + cross-entropy loss with integer (f32-encoded)
1200    /// targets — the standard classification loss. Per-row output:
1201    /// `loss[n] = -log(softmax(logits[n])[labels[n]])`.
1202    /// Inputs: `[logits, labels]` with `logits [N, C]` and
1203    /// `labels [N]` (f32-encoded class indices). Output: `[N]`.
1204    /// Caller does the `Reduce::Mean` if they want a scalar.
1205    SoftmaxCrossEntropyWithLogits,
1206
1207    /// Backward of the fused loss above. Emits
1208    /// `dlogits[n,c] = (softmax(logits[n])[c] - one_hot(labels)[n,c]) * d_loss[n]`.
1209    /// Inputs: `[logits, labels, d_loss]`. Output: `[N, C]` (same shape
1210    /// as `logits`). Recomputes the softmax inline rather than threading
1211    /// it through from the forward node.
1212    SoftmaxCrossEntropyBackward,
1213
1214    /// Backward of [`Op::Attention`]. Recomputes scaled `QK^T`, applies
1215    /// the same `mask_kind` as the forward op, softmaxes scores, then
1216    /// emits **one** of `dQ`, `dK`, or `dV` selected by [`AttentionBwdWrt`].
1217    /// Autodiff emits three nodes (one per `wrt`) so each output shape
1218    /// stays a normal single-output MIR node.
1219    ///
1220    /// Inputs: `[q, k, v, dy]` plus optional mask when `mask_kind` is
1221    /// [`MaskKind::Custom`] or [`MaskKind::Bias`] (same convention as
1222    /// forward). Output shape matches `q`, `k`, or `v` respectively.
1223    AttentionBackward {
1224        num_heads: usize,
1225        head_dim: usize,
1226        mask_kind: MaskKind,
1227        wrt: AttentionBwdWrt,
1228    },
1229
1230    // ── Fused operations (created by optimization passes) ──────
1231    /// Fused matmul + bias + activation. Created from MatMul → Add → Activation.
1232    FusedMatMulBiasAct {
1233        activation: Option<Activation>,
1234    },
1235
1236    /// Fused residual + optional bias + layer norm.
1237    /// Created from Add(x, residual) → [Add(bias)] → LayerNorm.
1238    FusedResidualLN {
1239        has_bias: bool,
1240        eps: f32,
1241    },
1242
1243    /// Fused residual + optional bias + RMS norm.
1244    /// Created from Add(x, residual) → [Add(bias)] → RmsNorm.
1245    FusedResidualRmsNorm {
1246        has_bias: bool,
1247        eps: f32,
1248    },
1249
1250    /// Fused SwiGLU: split input into up/gate halves, silu(gate) * up.
1251    /// Created from Split → Silu → Mul when fed by a concatenated matmul.
1252    ///
1253    /// `cast_to`: optional output dtype — when `Some(dt)` the kernel casts
1254    /// its result from the input dtype to `dt` in-register, saving a
1255    /// separate cast pass. Reserved for future fp8/fp4 quantization paths;
1256    /// for f32→f16 mixed precision the AutoMixedPrecision pass already
1257    /// inserts a Cast node so this stays `None` in current pipelines.
1258    FusedSwiGLU {
1259        cast_to: Option<DType>,
1260        /// When `true`, the concatenated input stores gate in the low half
1261        /// `[..., 0..N)` and up in the high half `[..., N..2N)` — the layout
1262        /// produced when gate projection is emitted before up in the builder.
1263        /// Default `false`: up @ low, gate @ high (canonical concat order).
1264        gate_first: bool,
1265    },
1266
1267    /// Fused full transformer layer: attention block + residual+LN + FFN + residual+LN.
1268    /// All intermediates resident in registers/threadgroup memory; one kernel
1269    /// per layer instead of ~30 (the CPU's batch=1 win, lifted to IR so any
1270    /// backend can implement it as a monolithic kernel).
1271    ///
1272    /// Inputs: hidden, qkv_w, qkv_b, out_w, out_b,
1273    ///         ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask
1274    /// Output: same shape as hidden.
1275    ///
1276    /// **Backend status:** same as FusedAttentionBlock. CPU implements
1277    /// the L1-cache-resident merge at the thunk level. Metal deferred —
1278    /// requires a single MSL kernel for the whole layer to actually
1279    /// beat the unfused path. Multi-day work; revisit when there's a
1280    /// model whose Metal inference is bottlenecked here rather than on
1281    /// the wait latency floor.
1282    FusedTransformerLayer {
1283        num_heads: usize,
1284        head_dim: usize,
1285        intermediate_size: usize,
1286        eps1: f32,
1287        eps2: f32,
1288        activation: Activation,
1289        has_bias: bool,
1290    },
1291
1292    /// Fused attention block: QKV projection → split → \[RoPE\] → SDPA → output projection.
1293    /// Created by FuseAttentionBlock pass when batch*seq is small.
1294    /// All intermediates stay in L1 cache — no arena writes between ops.
1295    ///
1296    /// Inputs (in order):
1297    ///   hidden, qkv_w, out_w, mask,
1298    ///   [qkv_b, out_b]      if has_bias,
1299    ///   [rope_cos, rope_sin] if has_rope
1300    ///
1301    /// **Backend status (Phase C finalize):**
1302    ///   CPU  — implemented at the *thunk* level: the CPU schedule
1303    ///          recognizes the multi-thunk pattern and merges into
1304    ///          a single FusedAttnBlock that keeps Q/K/V in stack
1305    ///          buffers across stages (the L1-cache win).
1306    ///   Metal — **deferred**. A dispatch-wrapper version (chaining
1307    ///          existing kernels) buys nothing the unfused Metal path
1308    ///          doesn't already get, since per-run cost is dominated
1309    ///          by `wait_until_completed` (~150 µs), not encode. The
1310    ///          real win is a single MSL kernel keeping Q/K/V in
1311    ///          threadgroup memory across stages — multi-day work.
1312    ///          Until then, Metal runs the unfused chain (one matmul,
1313    ///          three narrows, two ropes, attention, one matmul) — all
1314    ///          covered in op_coverage and parity_harness.
1315    FusedAttentionBlock {
1316        num_heads: usize,
1317        head_dim: usize,
1318        has_bias: bool,
1319        has_rope: bool,
1320    },
1321
1322    // ── Control flow (subgraphs as op payloads) ─────────────────
1323    //
1324    // Status: IR is defined; helper `run_if` / `run_while` exist in
1325    // rlx-runtime/src/subgraph.rs; **executor wiring is not yet
1326    // implemented** (both CPU thunk and Metal thunk fall through to
1327    // `Thunk::Nop` for these ops). Wiring requires:
1328    //   1. Recursive subgraph compile at parent-compile time.
1329    //   2. Per-subgraph input/output binding through the arena.
1330    //   3. Schedule-level dispatch when the predicate / loop cond is
1331    //      resolved at runtime.
1332    // Estimate: 4–6 hours of focused work + parity tests. Deferred
1333    // because no current in-tree model uses these ops;
1334    // surface area without a validation target invites silent bugs.
1335    /// Conditional: pick between two subgraphs based on a boolean predicate.
1336    /// Inputs: [predicate, ...captures (used inside both branches)].
1337    /// `then_branch` and `else_branch` are sub-graphs that share the
1338    /// captured inputs and must produce identically-shaped outputs.
1339    /// Used for: shape-dependent execution, batched inference of
1340    /// dynamic-length sequences with padding masks.
1341    If {
1342        then_branch: Box<crate::Graph>,
1343        else_branch: Box<crate::Graph>,
1344    },
1345
1346    /// Loop: iterate `body` while `cond` evaluates true.
1347    /// Inputs: [...initial loop-carried values].
1348    /// `cond`'s single output is a Bool scalar.
1349    /// `body`'s outputs become the next iteration's loop-carried inputs.
1350    /// Outputs of While are the values after the final iteration.
1351    /// Used for: KV-cache-driven autoregressive generation, beam search.
1352    While {
1353        cond: Box<crate::Graph>,
1354        body: Box<crate::Graph>,
1355        max_iterations: Option<usize>,
1356    },
1357
1358    /// Bounded-length loop with a fixed-shape carry, optional per-step
1359    /// inputs, and optional stacked output. Mirrors JAX's `lax.scan`.
1360    ///
1361    /// Body signature: `(carry, bcast.., x_t_0, ..., x_t_{num_xs-1}) → carry_next`
1362    /// — `1 + num_bcast + num_xs` Op::Inputs in NodeId construction order
1363    /// (carry first, then the `num_bcast` broadcast values, then the
1364    /// `num_xs` per-step slices). Single output (the next carry).
1365    ///
1366    /// Outer Op::Scan inputs (in order):
1367    ///   `[init_carry, bcast_0..bcast_{num_bcast-1}, xs_0, ..., xs_{num_xs-1}]`
1368    /// Each `xs_i` has shape `[length, *per_step_shape_i]`; the body
1369    /// sees `xs_i[t]` (a `per_step_shape_i` slice) on iteration `t`.
1370    ///
1371    /// Outer Op::Scan output:
1372    ///   * `save_trajectory == false` — final carry, shape `*carry`.
1373    ///   * `save_trajectory == true`  — stacked trajectory of carries,
1374    ///     shape `[length, *carry]`. Row `t` is the carry after step
1375    ///     `t+1`, so row `length-1` matches the no-trajectory case.
1376    ///
1377    /// Common uses include time-stepping integrators with
1378    /// time-varying drives, Mamba-style SSM scans reading per-step
1379    /// inputs, and RNN-style sequence processing.
1380    Scan {
1381        body: Box<crate::Graph>,
1382        length: u32,
1383        save_trajectory: bool,
1384        /// Number of "broadcast" inputs — values that are constant
1385        /// across iterations. Outer scan inputs in order:
1386        ///   `[init, bcast_0..bcast_{B-1}, xs_0..xs_{X-1}]`
1387        /// Body Op::Inputs in NodeId order:
1388        ///   `[carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]`
1389        /// CPU executor fills bcast slots ONCE before the iteration
1390        /// loop (xs slots are filled per-step). The reverse-mode AD
1391        /// pre-pass materialises each bcast into an xs of shape
1392        /// `[length, *bcast]` via broadcast `Mul` so the rest of the
1393        /// VJP / executor pipeline can stay unchanged. `0` (default)
1394        /// keeps the original carry+xs scan shape.
1395        num_bcast: u32,
1396        /// Number of per-step `xs` inputs. Total outer Op::Scan
1397        /// inputs is `1 + num_bcast + num_xs`.
1398        num_xs: u32,
1399        /// Number of trajectory checkpoints when `save_trajectory ==
1400        /// true`. `0` means "save all `length` rows" (default). A
1401        /// positive value `K` means save only `K` evenly-spaced rows
1402        /// at indices `floor(t * length / K)` for `t in 0..K`. Used
1403        /// by recursive checkpointed AD: store O(√T) carries during
1404        /// forward, recompute the rest in the backward pass.
1405        ///
1406        /// When `0` (or `K == length`), the saved trajectory has
1407        /// shape `[length, *carry]` — same as the original behavior.
1408        /// When `0 < K < length`, the saved trajectory has shape
1409        /// `[K, *carry]`.
1410        num_checkpoints: u32,
1411    },
1412
1413    /// Reverse-mode AD companion to `Op::Scan` — extracts the carry
1414    /// gradient `dinit`. Walks `t = length-1 .. 0`, applying `body_vjp`
1415    /// to thread `dcarry` back through the time loop.
1416    ///
1417    /// Inputs (in order):
1418    ///   `[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]`
1419    /// Output: `dinit`, shape = carry shape.
1420    ///
1421    /// `body_vjp` is the result of
1422    /// `autodiff::grad(body, [carry_id, xs_0_id, ..., xs_{num_xs-1}_id])`
1423    /// — a graph with `1 + num_xs + 1` Inputs (carry + x_t_i for each
1424    /// xs + `"d_output"`) and `1 + num_xs` outputs
1425    /// (dcarry + dx_t_i for each xs). This op reads `outputs[0]` =
1426    /// dcarry; the sibling [`Self::ScanBackwardXs`] reads the
1427    /// `outputs[1 + xs_idx]` slot for each xs gradient.
1428    ScanBackward {
1429        body_vjp: Box<crate::Graph>,
1430        length: u32,
1431        save_trajectory: bool,
1432        num_xs: u32,
1433        /// When `0` or equal to `length`, the trajectory input has
1434        /// shape `[length, *carry]` — every step's carry is cached
1435        /// (`CheckpointStrategy::All`). When `0 < K < length`, the
1436        /// trajectory input has shape `[K, *carry]` and the executor
1437        /// recomputes intermediate carries via `forward_body` between
1438        /// checkpoints. `forward_body` must be `Some` whenever
1439        /// `0 < num_checkpoints < length`.
1440        num_checkpoints: u32,
1441        /// Forward body (the same `body` from the forward Op::Scan).
1442        /// Required when `0 < num_checkpoints < length` so the
1443        /// executor can recompute carries between saved checkpoints.
1444        /// `None` for the All strategy (no recompute needed).
1445        forward_body: Option<Box<crate::Graph>>,
1446    },
1447
1448    /// Companion to [`Self::ScanBackward`] that extracts one stacked
1449    /// per-step `dxs_i` (shape `[length, *per_step_xs_i]`). Same inputs
1450    /// and same `body_vjp` graph as ScanBackward — `xs_idx` selects
1451    /// which body_vjp output to stack into the result.
1452    ///
1453    /// Note: each ScanBackwardXs runs its own backward loop. A future
1454    /// optimization can fuse them into a single multi-output backward
1455    /// kernel; for now it's `1 + num_xs` independent sweeps.
1456    ScanBackwardXs {
1457        body_vjp: Box<crate::Graph>,
1458        length: u32,
1459        save_trajectory: bool,
1460        num_xs: u32,
1461        xs_idx: u32,
1462        num_checkpoints: u32,
1463        forward_body: Option<Box<crate::Graph>>,
1464    },
1465
1466    /// CPU reference 3D Gaussian splat forward render.
1467    ///
1468    /// Seven flat F32 inputs (scene buffers + camera/render meta):
1469    ///   0. positions `[N*3]`
1470    ///   1. scales `[N*3]` (log-space)
1471    ///   2. rotations `[N*4]` (xyzw)
1472    ///   3. opacities `[N]` (logit)
1473    ///   4. colors `[N*3]` (linear RGB)
1474    ///   5. sh_coeffs `[N * sh_coeff_count * 3]`
1475    ///   6. meta `[23]` — camera position/target/up/fov/near/far, background RGB,
1476    ///      then width/height/tile_size/radius_scale/alpha_cutoff/max_splat_steps/
1477    ///      transmittance_threshold/max_list_entries as f32 bit-patterns.
1478    ///
1479    /// Output: `[height * width * 4]` linear RGBA (display gamma baked in).
1480    /// Build via [`crate::Graph::gaussian_splat_render`].
1481    ///
1482    /// Differentiable backward is not implemented in v1; autodiff treats this
1483    /// op as non-differentiable (same as [`Op::Sample`]).
1484    GaussianSplatRender {
1485        width: u32,
1486        height: u32,
1487        tile_size: u32,
1488        radius_scale: f32,
1489        alpha_cutoff: f32,
1490        max_splat_steps: u32,
1491        transmittance_threshold: f32,
1492        max_list_entries: u32,
1493    },
1494
1495    /// Backward pass for [`Self::GaussianSplatRender`].
1496    ///
1497    /// Eight inputs: the same seven as forward plus `d_loss_rgba` `[W*H*4]`
1498    /// (only RGB channels are used). Re-runs the training forward internally.
1499    ///
1500    /// Output: packed gradients
1501    /// `[positions(3N) | scales(3N) | rotations(4N) | opacities(N) | colors(3N) | sh(N*sh*3)]`.
1502    /// Unpack via [`crate::ops::splat::unpack_gaussian_splat_packed_grads`].
1503    GaussianSplatRenderBackward {
1504        width: u32,
1505        height: u32,
1506        tile_size: u32,
1507        radius_scale: f32,
1508        alpha_cutoff: f32,
1509        max_splat_steps: u32,
1510        transmittance_threshold: f32,
1511        max_list_entries: u32,
1512        loss_grad_clip: f32,
1513        sh_band: u32,
1514        max_anisotropy: f32,
1515    },
1516
1517    /// Strict IR stage 1: project, bin, sort, build per-pixel rays.
1518    ///
1519    /// Seven inputs (same scene + meta as [`Self::GaussianSplatRender`]). Output: packed
1520    /// prepare buffer (see `rlx_splat::prep_layout::prep_packed_len`).
1521    GaussianSplatPrepare {
1522        width: u32,
1523        height: u32,
1524        tile_size: u32,
1525        radius_scale: f32,
1526        alpha_cutoff: f32,
1527        max_splat_steps: u32,
1528        transmittance_threshold: f32,
1529        max_list_entries: u32,
1530    },
1531
1532    /// Strict IR stage 2: tile raster from [`Self::GaussianSplatPrepare`] output.
1533    ///
1534    /// Inputs: `prep` packed buffer, `meta` `[23]`. Output: `[width * height * 4]` RGBA.
1535    GaussianSplatRasterize {
1536        width: u32,
1537        height: u32,
1538        tile_size: u32,
1539        alpha_cutoff: f32,
1540        max_splat_steps: u32,
1541        transmittance_threshold: f32,
1542        max_list_entries: u32,
1543    },
1544
1545    /// User-registered custom op. `name` keys into the
1546    /// [`crate::op_registry`] for shape inference, autodiff, and
1547    /// per-backend execution. `attrs` is an opaque blob passed
1548    /// through to those callbacks (FFT direction, SparseLU
1549    /// reordering strategy, etc.). `num_inputs` is captured at
1550    /// construction time so [`Op::num_inputs`] stays infallible
1551    /// without a registry lookup. Build via [`crate::Graph::custom_op`].
1552    Custom {
1553        name: String,
1554        num_inputs: u32,
1555        attrs: Vec<u8>,
1556    },
1557
1558    /// 1D Fast Fourier Transform along the last axis.
1559    ///
1560    /// **Layouts**
1561    /// - `F32` / `F64`: 2N real-block — last axis is `[re₀…re_{N-1}, im₀…im_{N-1}]`.
1562    /// - `C64`: interleaved `[re, im]` pairs per complex element along the last axis.
1563    ///
1564    /// **ND transforms** — use `Graph::fftn` / `Graph::ifftn`, which compose
1565    /// `fft_axis` (transpose → Fft → transpose). Multi-axis `fftn` requires
1566    /// `DType::C64`; the 2N-block layout describes a single complex axis.
1567    ///
1568    /// Default (`FftNorm::Backward`) is **unnormalized** on both directions:
1569    ///   `fft(x)[k] = Σ x[n]·exp(-2πi·nk/N)`
1570    ///   `ifft(y)[n] = Σ y[k]·exp(+2πi·nk/N)`
1571    /// so `ifft(fft(x)) = N·x`. Use `FftNorm::Forward` for gpu-fft-style
1572    /// `1/N` scaling on inverse, or `FftNorm::Ortho` for unitary scaling.
1573    ///
1574    /// AD: VJP(`fft`) = `ifft`, VJP(`ifft`) = `fft` when `norm=Backward`;
1575    /// other norms apply the chain rule via output scaling.
1576    Fft {
1577        inverse: bool,
1578        norm: crate::fft::FftNorm,
1579    },
1580
1581    /// Ternary pruned radix-2 butterfly stage on interleaved complex state.
1582    ///
1583    /// Inputs:
1584    ///   0 — state `[batch, n_fft, 2]` (re/im on axis 2)
1585    ///   1 — gate `[half]` — 0 = identity, 1 = run butterfly (`half = n_fft/2`)
1586    ///   2 — rev `[half]` — 0 = forward, 1 = swap outputs when gate=1
1587    ///   3 — tw_re `[half]`
1588    ///   4 — tw_im `[half]`
1589    ///
1590    /// Output: `[batch, n_fft, 2]` same layout. Slots with gate=0 copy inputs
1591    /// without twiddle math.
1592    FftButterflyStage {
1593        stage: u32,
1594        n_fft: u32,
1595    },
1596
1597    /// Log-mel spectrogram from RLX FFT block-layout spectrum.
1598    ///
1599    /// Inputs:
1600    ///   0 — spectrum `[..., 2*n_fft]` (re plane then im plane, same as `Op::Fft` output)
1601    ///   1 — mel filterbank `[n_mels, n_bins]` with `n_bins = n_fft/2 + 1`
1602    ///
1603    /// Output: `[..., n_mels]` with Whisper dynamic-range compression
1604    /// (`log10`, clamp to max−8 dB, `(x+4)/4`).
1605    LogMel,
1606
1607    /// VJP of [`Op::LogMel`] w.r.t. spectrum (input 0).
1608    ///
1609    /// Inputs: spectrum block, mel filters, upstream `dy`.
1610    /// Output: `d_spectrum` (same shape as input 0).
1611    LogMelBackward,
1612
1613    /// User-defined sub-graph with optional override AD rules.
1614    /// Mirrors JAX's `custom_vjp` / `custom_jvp` decorators: the
1615    /// caller wraps a forward computation and supplies its own
1616    /// reverse- and/or forward-mode AD bodies. Useful when:
1617    ///   * The forward is iterative (Newton, fixed-point) and
1618    ///     differentiating through the loop is wasteful — the
1619    ///     vjp_body computes the implicit-function gradient at the
1620    ///     converged point in one shot.
1621    ///   * The math has a closed-form gradient that's much cheaper
1622    ///     than autodiff.
1623    ///   * The forward op is non-differentiable by tracing
1624    ///     (sampling, argmax) and the user wants a smooth surrogate.
1625    ///
1626    /// **fwd_body**: `num_inputs` Op::Inputs in NodeId construction
1627    /// order, one Op::Output (the primal y). Forward execution
1628    /// inlines this body once.
1629    ///
1630    /// **vjp_body** (optional): Op::Inputs are `num_inputs` primal
1631    /// inputs in NodeId order, plus two special-named Inputs —
1632    /// `"primal_output"` (the y from forward) and `"d_output"` (the
1633    /// upstream gradient). Outputs: `num_inputs` tensors in
1634    /// `set_outputs` order, matching the gradients of each primal
1635    /// input. When `None`, reverse-mode AD recurses into fwd_body
1636    /// — same as if the op were inlined.
1637    ///
1638    /// **jvp_body** (optional): Op::Inputs are `num_inputs` primal
1639    /// inputs in NodeId order, `num_inputs` special-named Inputs
1640    /// `"tangent_0"..="tangent_{num_inputs-1}"` carrying each input's
1641    /// tangent, and an optional special-named `"primal_output"` Input
1642    /// (the y from forward, useful when the JVP must be evaluated at
1643    /// a converged / nonlinear point — e.g. IFT-style forward-mode AD
1644    /// of an iterative solver). Output: 1 tensor (the tangent of y).
1645    /// When `None`, forward-mode AD recurses into fwd_body.
1646    ///
1647    /// `num_inputs` is captured so [`Op::num_inputs`] stays
1648    /// infallible. Build via [`crate::Graph::custom_fn`].
1649    CustomFn {
1650        fwd_body: Box<crate::Graph>,
1651        vjp_body: Option<Box<crate::Graph>>,
1652        jvp_body: Option<Box<crate::Graph>>,
1653        num_inputs: u32,
1654    },
1655}
1656
1657impl Op {
    /// PLAN L4: discriminant for backend-supported-set checks.
    ///
    /// Maps every [`Op`] variant to the [`OpKind`] of the same name. The
    /// variant's payload is ignored (all arms use `(_)` / `{ .. }` patterns),
    /// so the result is a parameter-free identity: every `Op::Activation(_)`,
    /// whichever activation it carries, reports `OpKind::Activation`.
    ///
    /// The match has no wildcard arm, so adding a new `Op` variant without a
    /// corresponding arm here fails to compile.
    pub fn kind(&self) -> OpKind {
        match self {
            Op::Input { .. } => OpKind::Input,
            Op::Param { .. } => OpKind::Param,
            Op::Constant { .. } => OpKind::Constant,
            Op::Activation(_) => OpKind::Activation,
            Op::Cast { .. } => OpKind::Cast,
            Op::StopGradient => OpKind::StopGradient,
            Op::Quantize { .. } => OpKind::Quantize,
            Op::Dequantize { .. } => OpKind::Dequantize,
            Op::FakeQuantize { .. } => OpKind::FakeQuantize,
            Op::FakeQuantizeLSQ { .. } => OpKind::FakeQuantizeLSQ,
            Op::FakeQuantizeLSQBackwardX { .. } => OpKind::FakeQuantizeLSQBackwardX,
            Op::FakeQuantizeLSQBackwardScale { .. } => OpKind::FakeQuantizeLSQBackwardScale,
            Op::Binary(_) => OpKind::Binary,
            Op::Compare(_) => OpKind::Compare,
            Op::Where => OpKind::Where,
            Op::ElementwiseRegion { .. } => OpKind::ElementwiseRegion,
            Op::TransformRegion { .. } => OpKind::TransformRegion,
            Op::BatchElementwiseRegion { .. } => OpKind::BatchElementwiseRegion,
            Op::MatMul => OpKind::MatMul,
            Op::DotGeneral { .. } => OpKind::DotGeneral,
            Op::DenseSolve => OpKind::DenseSolve,
            Op::BatchedDenseSolve => OpKind::BatchedDenseSolve,
            Op::LayerNorm { .. } => OpKind::LayerNorm,
            Op::LayerNorm2d { .. } => OpKind::LayerNorm2d,
            Op::GroupNorm { .. } => OpKind::GroupNorm,
            Op::BatchNormInference { .. } => OpKind::BatchNormInference,
            Op::RmsNorm { .. } => OpKind::RmsNorm,
            Op::ResizeNearest2x => OpKind::ResizeNearest2x,
            Op::Attention { .. } => OpKind::Attention,
            Op::Rope { .. } => OpKind::Rope,
            Op::AxialRope2d { .. } => OpKind::AxialRope2d,
            Op::Reshape { .. } => OpKind::Reshape,
            Op::Transpose { .. } => OpKind::Transpose,
            Op::Narrow { .. } => OpKind::Narrow,
            Op::Concat { .. } => OpKind::Concat,
            Op::Expand { .. } => OpKind::Expand,
            Op::Gather { .. } => OpKind::Gather,
            Op::Reduce { .. } => OpKind::Reduce,
            Op::Softmax { .. } => OpKind::Softmax,
            Op::Cumsum { .. } => OpKind::Cumsum,
            Op::TopK { .. } => OpKind::TopK,
            Op::Sample { .. } => OpKind::Sample,
            Op::Conv { .. } => OpKind::Conv,
            Op::Im2Col { .. } => OpKind::Im2Col,
            Op::ConvTranspose2d { .. } => OpKind::ConvTranspose2d,
            Op::Pool { .. } => OpKind::Pool,
            Op::ReluBackward => OpKind::ReluBackward,
            Op::ActivationBackward { .. } => OpKind::ActivationBackward,
            Op::FakeQuantizeBackward { .. } => OpKind::FakeQuantizeBackward,
            Op::ComplexNormSq => OpKind::ComplexNormSq,
            Op::ComplexNormSqBackward => OpKind::ComplexNormSqBackward,
            Op::Conjugate => OpKind::Conjugate,
            Op::LayerNormBackwardInput { .. } => OpKind::LayerNormBackwardInput,
            Op::LayerNormBackwardGamma { .. } => OpKind::LayerNormBackwardGamma,
            Op::RmsNormBackwardInput { .. } => OpKind::RmsNormBackwardInput,
            Op::RmsNormBackwardGamma { .. } => OpKind::RmsNormBackwardGamma,
            Op::RmsNormBackwardBeta { .. } => OpKind::RmsNormBackwardBeta,
            Op::RopeBackward { .. } => OpKind::RopeBackward,
            Op::GroupNormBackwardInput { .. } => OpKind::GroupNormBackwardInput,
            Op::GroupNormBackwardGamma { .. } => OpKind::GroupNormBackwardGamma,
            Op::GroupNormBackwardBeta { .. } => OpKind::GroupNormBackwardBeta,
            Op::BatchNormInferenceBackwardInput { .. } => OpKind::BatchNormInferenceBackwardInput,
            Op::BatchNormInferenceBackwardGamma { .. } => OpKind::BatchNormInferenceBackwardGamma,
            Op::BatchNormInferenceBackwardBeta => OpKind::BatchNormInferenceBackwardBeta,
            Op::CumsumBackward { .. } => OpKind::CumsumBackward,
            Op::GatherBackward { .. } => OpKind::GatherBackward,
            Op::MaxPool2dBackward { .. } => OpKind::MaxPool2dBackward,
            Op::Conv2dBackwardInput { .. } => OpKind::Conv2dBackwardInput,
            Op::Conv2dBackwardWeight { .. } => OpKind::Conv2dBackwardWeight,
            Op::SoftmaxCrossEntropyWithLogits => OpKind::SoftmaxCrossEntropyWithLogits,
            Op::SoftmaxCrossEntropyBackward => OpKind::SoftmaxCrossEntropyBackward,
            Op::AttentionBackward { .. } => OpKind::AttentionBackward,
            Op::GroupedMatMul => OpKind::GroupedMatMul,
            Op::DequantGroupedMatMul { .. } => OpKind::DequantGroupedMatMul,
            Op::DequantMoEWeights { .. } => OpKind::DequantMoEWeights,
            Op::ScatterAdd => OpKind::ScatterAdd,
            Op::LoraMatMul { .. } => OpKind::LoraMatMul,
            Op::DequantMatMul { .. } => OpKind::DequantMatMul,
            Op::QMatMul { .. } => OpKind::QMatMul,
            Op::QConv2d { .. } => OpKind::QConv2d,
            Op::SelectiveScan { .. } => OpKind::SelectiveScan,
            Op::GatedDeltaNet { .. } => OpKind::GatedDeltaNet,
            Op::FusedSwiGLU { .. } => OpKind::FusedSwiGLU,
            Op::FusedMatMulBiasAct { .. } => OpKind::FusedMatMulBiasAct,
            Op::FusedResidualLN { .. } => OpKind::FusedResidualLN,
            Op::FusedResidualRmsNorm { .. } => OpKind::FusedResidualRmsNorm,
            Op::FusedAttentionBlock { .. } => OpKind::FusedAttentionBlock,
            Op::FusedTransformerLayer { .. } => OpKind::FusedTransformerLayer,
            Op::If { .. } => OpKind::If,
            Op::While { .. } => OpKind::While,
            Op::Scan { .. } => OpKind::Scan,
            Op::ScanBackward { .. } => OpKind::ScanBackward,
            Op::ScanBackwardXs { .. } => OpKind::ScanBackwardXs,
            Op::GaussianSplatRender { .. } => OpKind::GaussianSplatRender,
            Op::GaussianSplatRenderBackward { .. } => OpKind::GaussianSplatRenderBackward,
            Op::GaussianSplatPrepare { .. } => OpKind::GaussianSplatPrepare,
            Op::GaussianSplatRasterize { .. } => OpKind::GaussianSplatRasterize,
            Op::Custom { .. } => OpKind::Custom,
            Op::CustomFn { .. } => OpKind::CustomFn,
            Op::Fft { .. } => OpKind::Fft,
            Op::FftButterflyStage { .. } => OpKind::FftButterflyStage,
            Op::LogMel => OpKind::LogMel,
            Op::LogMelBackward => OpKind::LogMelBackward,
        }
    }
1768
1769    /// True if this op is element-wise (same shape in, same shape out).
1770    /// Element-wise ops are prime fusion candidates.
1771    pub fn is_elementwise(&self) -> bool {
1772        matches!(
1773            self,
1774            Op::Activation(_)
1775                | Op::Cast { .. }
1776                | Op::StopGradient
1777                | Op::Binary(_)
1778                | Op::Compare(_)
1779                | Op::Where
1780                | Op::ElementwiseRegion { .. }
1781                | Op::BatchElementwiseRegion { .. }
1782        )
1783    }
1784
1785    /// True if this op may appear in a [`Op::TransformRegion`] chain.
1786    pub fn is_transform_eligible(&self) -> bool {
1787        matches!(self, Op::ResizeNearest2x)
1788    }
1789
1790    /// True if this op is a BLAS/compute-intensive op that forms a fusion boundary.
1791    pub fn is_blas(&self) -> bool {
1792        matches!(
1793            self,
1794            Op::MatMul
1795                | Op::DotGeneral { .. }
1796                | Op::DenseSolve
1797                | Op::BatchedDenseSolve
1798                | Op::Conv { .. }
1799                | Op::Im2Col { .. }
1800                | Op::ConvTranspose2d { .. }
1801                | Op::FusedMatMulBiasAct { .. }
1802                | Op::GroupedMatMul
1803                | Op::DequantGroupedMatMul { .. }
1804                | Op::DequantMoEWeights { .. }
1805                | Op::LoraMatMul { .. }
1806                | Op::DequantMatMul { .. }
1807                | Op::QMatMul { .. }
1808                | Op::QConv2d { .. }
1809        )
1810    }
1811
1812    /// True if element-wise fusion must not span across this op.
1813    pub fn is_fusion_boundary(&self) -> bool {
1814        self.is_blas()
1815            || matches!(
1816                self,
1817                Op::GaussianSplatRender { .. }
1818                    | Op::GaussianSplatRenderBackward { .. }
1819                    | Op::GaussianSplatPrepare { .. }
1820                    | Op::GaussianSplatRasterize { .. }
1821            )
1822    }
1823
1824    /// True if this op is a reduction (drives loop iteration in fused kernels).
1825    pub fn is_reduction(&self) -> bool {
1826        matches!(
1827            self,
1828            Op::Reduce { .. } | Op::Softmax { .. } | Op::TopK { .. }
1829        )
1830    }
1831
    /// Number of tensor inputs this op expects.
    ///
    /// For most variants the count is a fixed constant. For some it depends on
    /// the variant's fields: the `scale_mode` of `FakeQuantize`, the quant
    /// `scheme` of `DequantMatMul`, `carry_state` of `GatedDeltaNet`, the
    /// `mask_kind` of `Attention` / `AttentionBackward`, the `has_bias` /
    /// `has_rope` flags of the fused ops, and the explicit count fields carried
    /// by `Scan*`, the region ops, `Custom`, and `CustomFn`.
    ///
    /// A return value of `0` is ambiguous. `Input`, `Param`, and `Constant`
    /// take no inputs, but the variadic `Concat` and `While` also report `0`;
    /// per the inline comments their real arity is checked at graph level.
    ///
    /// The match has no wildcard arm, so a new `Op` variant must be given an
    /// arity here before the crate compiles.
    pub fn num_inputs(&self) -> usize {
        match self {
            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0,
            Op::Activation(_)
            | Op::Cast { .. }
            | Op::StopGradient
            | Op::Reshape { .. }
            | Op::Quantize { .. }
            | Op::Dequantize { .. }
            | Op::Transpose { .. }
            | Op::Narrow { .. }
            | Op::Expand { .. }
            | Op::Reduce { .. }
            | Op::Softmax { .. }
            | Op::FusedSwiGLU { .. }
            | Op::TopK { .. }
            | Op::Cumsum { .. }
            | Op::Sample { .. }
            | Op::ResizeNearest2x => 1,
            // EMA / Fixed scale modes carry a state tensor as a 2nd input;
            // PerBatch (default) doesn't need one.
            Op::FakeQuantize { scale_mode, .. } => match scale_mode {
                ScaleMode::PerBatch => 1,
                ScaleMode::EMA { .. } | ScaleMode::Fixed => 2,
            },
            Op::FakeQuantizeLSQ { .. } => 2, // x, scale (learned param)
            Op::FakeQuantizeLSQBackwardX { .. } | Op::FakeQuantizeLSQBackwardScale { .. } => 3, // x, scale, dy
            Op::Binary(_) | Op::Compare(_) | Op::Gather { .. } | Op::MatMul | Op::ScatterAdd => 2,
            Op::GroupedMatMul => 3,               // input, weight, expert_idx
            Op::DequantGroupedMatMul { .. } => 3, // input, packed_w, expert_idx
            Op::DequantMoEWeights { .. } => 1,    // packed_w
            Op::LoraMatMul { .. } => 4,           // x, w, a, b
            // x, w_q, scale, zp — or x, packed_w_bytes for GGUF
            // schemes (their scales/mins live inside the packed bytes,
            // see `QuantScheme::is_gguf`).
            Op::DequantMatMul { scheme } => {
                if scheme.is_gguf() {
                    2
                } else {
                    4
                }
            }
            Op::QMatMul { .. } => 3,       // x, w, bias
            Op::QConv2d { .. } => 3,       // x, w, bias
            Op::SelectiveScan { .. } => 5, // x, delta, a, b, c
            // Arm order matters: the guarded `carry_state == true` arm must
            // come before the unguarded catch-all for the same variant.
            Op::GatedDeltaNet { carry_state, .. } if *carry_state => 6, // + state in/out
            Op::GatedDeltaNet { .. } => 5, // q, k, v, g, beta
            Op::Where => 3,                // cond, on_true, on_false
            Op::Attention { mask_kind, .. } => match mask_kind {
                MaskKind::Custom | MaskKind::Bias => 4, // Q, K, V, mask
                _ => 3,                                 // Q, K, V (mask synthesized in-kernel)
            },
            Op::AttentionBackward { mask_kind, .. } => match mask_kind {
                MaskKind::Custom | MaskKind::Bias => 5, // q, k, v, dy, mask
                _ => 4,                                 // q, k, v, dy
            },
            Op::Rope { .. } => 3, // x, cos, sin
            Op::AxialRope2d { .. } => 1,
            Op::LayerNorm { .. }
            | Op::LayerNorm2d { .. }
            | Op::GroupNorm { .. }
            | Op::RmsNorm { .. } => 3, // input, gamma, beta
            Op::BatchNormInference { .. } => 5, // x, gamma, beta, mean, var
            Op::FusedMatMulBiasAct { .. } => 3, // input, weight, bias
            Op::FusedResidualLN { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
            Op::FusedResidualLN {
                has_bias: false, ..
            } => 4, // x, residual, gamma, beta
            Op::FusedResidualRmsNorm { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
            Op::FusedResidualRmsNorm {
                has_bias: false, ..
            } => 4, // x, residual, gamma, beta
            Op::Conv { .. } | Op::ConvTranspose2d { .. } => 2, // input, weight (bias via Add)
            Op::Im2Col { .. } => 1,
            Op::Pool { .. } => 1,
            Op::ReluBackward => 2,                  // x, dy
            Op::ActivationBackward { .. } => 2,     // x, dy
            Op::FakeQuantizeBackward { .. } => 2,   // x, dy
            Op::ComplexNormSq => 1,                 // z (C64)
            Op::ComplexNormSqBackward => 2,         // z, g
            Op::Conjugate => 1,                     // z (C64)
            Op::LayerNormBackwardInput { .. } => 3, // x, gamma, dy
            Op::LayerNormBackwardGamma { .. } => 2, // x, dy
            Op::RmsNormBackwardInput { .. } => 4,   // x, gamma, beta, dy
            // NOTE(review): presumably the same four inputs as
            // `RmsNormBackwardInput` (x, gamma, beta, dy) — confirm.
            Op::RmsNormBackwardGamma { .. } => 4,
            Op::RmsNormBackwardBeta { .. } => 4,
            Op::RopeBackward { .. } => 3,           // dy, cos, sin
            Op::GroupNormBackwardInput { .. } => 4, // x, gamma, beta, dy
            Op::GroupNormBackwardGamma { .. } => 2, // x, dy
            // NOTE(review): presumably x, dy like `GroupNormBackwardGamma` — confirm.
            Op::GroupNormBackwardBeta { .. } => 2,
            Op::BatchNormInferenceBackwardInput { .. } => 5, // x, gamma, mean, var, dy
            Op::BatchNormInferenceBackwardGamma { .. } => 4, // x, mean, var, dy
            Op::BatchNormInferenceBackwardBeta => 1,         // dy
            Op::CumsumBackward { .. } => 1,                  // dy
            Op::GatherBackward { .. } => 2,                  // dy, indices
            Op::MaxPool2dBackward { .. } => 2,               // x, dy
            Op::Conv2dBackwardInput { .. } => 2,             // dy, w
            Op::Conv2dBackwardWeight { .. } => 2,            // x, dy
            Op::SoftmaxCrossEntropyWithLogits => 2,          // logits, labels
            Op::SoftmaxCrossEntropyBackward => 3,            // logits, labels, d_loss
            Op::Concat { .. } => 0,                          // variadic — checked at graph level
            Op::DotGeneral { .. } => 2,
            Op::DenseSolve => 2,        // A, b
            Op::BatchedDenseSolve => 2, // A [B,N,N], b [B,N] or [B,N,K]
            // Four base inputs, plus two when `has_bias` and two when `has_rope`.
            Op::FusedAttentionBlock {
                has_bias, has_rope, ..
            } => 4 + if *has_bias { 2 } else { 0 } + if *has_rope { 2 } else { 0 },
            Op::If { .. } => 1,    // predicate (captures handled separately)
            Op::While { .. } => 0, // variadic loop-carried; checked at graph level
            // One fixed input plus `num_bcast` and `num_xs` further inputs.
            Op::Scan {
                num_bcast, num_xs, ..
            } => 1 + *num_bcast as usize + *num_xs as usize,
            Op::ScanBackward { num_xs, .. } => 3 + *num_xs as usize, // init, trajectory, upstream, xs_0..
            Op::ScanBackwardXs { num_xs, .. } => 3 + *num_xs as usize, // same as ScanBackward
            Op::GaussianSplatRender { .. } => 7,
            Op::GaussianSplatRenderBackward { .. } => 8,
            Op::GaussianSplatPrepare { .. } => 7,
            Op::GaussianSplatRasterize { .. } => 2,
            Op::FusedTransformerLayer { has_bias, .. } => {
                // hidden + qkv_w + out_w + ln1_g + ln1_b + fc1_w + fc2_w + ln2_g + ln2_b + mask = 10
                // bias variant adds: qkv_b + out_b + fc1_b + fc2_b = 4 more
                10 + if *has_bias { 4 } else { 0 }
            }
            // The remaining variable-arity ops carry their input count as a field.
            Op::ElementwiseRegion { num_inputs, .. } => *num_inputs as usize,
            Op::TransformRegion { num_inputs, .. } => *num_inputs as usize,
            Op::BatchElementwiseRegion {
                num_batch_inputs, ..
            } => *num_batch_inputs as usize,
            Op::Custom { num_inputs, .. } => *num_inputs as usize,
            Op::CustomFn { num_inputs, .. } => *num_inputs as usize,
            Op::Fft { .. } => 1,
            Op::FftButterflyStage { .. } => 5,
            Op::LogMel => 2,
            Op::LogMelBackward => 3,
        }
    }
1969}
1970
1971impl std::fmt::Display for Op {
1972    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1973        match self {
1974            Op::Input { name } => write!(f, "input(\"{name}\")"),
1975            Op::Param { name } => write!(f, "param(\"{name}\")"),
1976            Op::Constant { data } => write!(f, "const({}B)", data.len()),
1977            Op::Activation(a) => write!(f, "{a:?}"),
1978            Op::Quantize { axis, scales, .. } => match axis {
1979                None => write!(f, "quantize(s={})", scales[0]),
1980                Some(d) => write!(f, "quantize(axis={d},nch={})", scales.len()),
1981            },
1982            Op::Dequantize { axis, scales, .. } => match axis {
1983                None => write!(f, "dequantize(s={})", scales[0]),
1984                Some(d) => write!(f, "dequantize(axis={d},nch={})", scales.len()),
1985            },
1986            Op::FakeQuantize {
1987                bits,
1988                axis,
1989                ste,
1990                scale_mode,
1991            } => match axis {
1992                None => write!(
1993                    f,
1994                    "fake_quant(bits={bits},ste={ste:?},scale={scale_mode:?})"
1995                ),
1996                Some(d) => write!(
1997                    f,
1998                    "fake_quant(bits={bits},axis={d},ste={ste:?},scale={scale_mode:?})"
1999                ),
2000            },
2001            Op::FakeQuantizeLSQ { bits, axis } => match axis {
2002                None => write!(f, "fake_quant_lsq(bits={bits})"),
2003                Some(d) => write!(f, "fake_quant_lsq(bits={bits},axis={d})"),
2004            },
2005            Op::FakeQuantizeLSQBackwardX { bits, .. } => {
2006                write!(f, "fake_quant_lsq_bwd_x(bits={bits})")
2007            }
2008            Op::FakeQuantizeLSQBackwardScale { bits, .. } => {
2009                write!(f, "fake_quant_lsq_bwd_s(bits={bits})")
2010            }
2011            Op::Cast { to } => write!(f, "cast({to})"),
2012            Op::StopGradient => write!(f, "stop_gradient"),
2013            Op::Binary(op) => write!(f, "{op:?}"),
2014            Op::Compare(op) => write!(f, "{op:?}"),
2015            Op::Where => write!(f, "where"),
2016            Op::MatMul => write!(f, "matmul"),
2017            Op::DotGeneral { .. } => write!(f, "dot_general"),
2018            Op::DenseSolve => write!(f, "dense_solve"),
2019            Op::BatchedDenseSolve => write!(f, "batched_dense_solve"),
2020            Op::LayerNorm { eps, .. } => write!(f, "layer_norm(eps={eps})"),
2021            Op::GroupNorm { num_groups, eps } => {
2022                write!(f, "group_norm(groups={num_groups},eps={eps})")
2023            }
2024            Op::BatchNormInference { eps } => write!(f, "batch_norm_inference(eps={eps})"),
2025            Op::ResizeNearest2x => write!(f, "resize_nearest_2x"),
2026            Op::RmsNorm { eps, .. } => write!(f, "rms_norm(eps={eps})"),
2027            Op::Attention {
2028                num_heads,
2029                head_dim,
2030                mask_kind,
2031                score_scale,
2032                attn_logit_softcap,
2033            } => {
2034                let mut s = match mask_kind {
2035                    MaskKind::Custom => format!("attention(h={num_heads},d={head_dim})"),
2036                    MaskKind::None => format!("attention(h={num_heads},d={head_dim},nomask)"),
2037                    MaskKind::Causal => format!("attention(h={num_heads},d={head_dim},causal)"),
2038                    MaskKind::SlidingWindow(w) => {
2039                        format!("attention(h={num_heads},d={head_dim},sw={w})")
2040                    }
2041                    MaskKind::Bias => format!("attention(h={num_heads},d={head_dim},bias)"),
2042                };
2043                if let Some(sc) = score_scale {
2044                    s.push_str(&format!(",scale={sc}"));
2045                }
2046                if let Some(cap) = attn_logit_softcap {
2047                    s.push_str(&format!(",softcap={cap}"));
2048                }
2049                write!(f, "{s}")
2050            }
2051            Op::Rope { head_dim, n_rot } => write!(f, "rope(d={head_dim}, n_rot={n_rot})"),
2052            Op::AxialRope2d {
2053                end_x,
2054                end_y,
2055                head_dim,
2056                num_heads,
2057                theta,
2058                repeat_factor,
2059            } => write!(
2060                f,
2061                "axial_rope2d({end_x}x{end_y},h={num_heads},d={head_dim},θ={theta},r={repeat_factor})"
2062            ),
2063            Op::Reshape { new_shape } => write!(f, "reshape({new_shape:?})"),
2064            Op::Transpose { perm } => write!(f, "transpose({perm:?})"),
2065            Op::Narrow { axis, start, len } => write!(f, "narrow({axis},{start},{len})"),
2066            Op::Concat { axis } => write!(f, "concat(axis={axis})"),
2067            Op::Expand { .. } => write!(f, "expand"),
2068            Op::Gather { axis } => write!(f, "gather(axis={axis})"),
2069            Op::Reduce { op, axes, .. } => write!(f, "reduce_{op:?}({axes:?})"),
2070            Op::Softmax { axis } => write!(f, "softmax(axis={axis})"),
2071            Op::Cumsum { axis, exclusive } => {
2072                if *exclusive {
2073                    write!(f, "cumsum(axis={axis},excl)")
2074                } else {
2075                    write!(f, "cumsum(axis={axis})")
2076                }
2077            }
2078            Op::Sample {
2079                top_k,
2080                top_p,
2081                temperature,
2082                ..
2083            } => {
2084                write!(f, "sample(t={temperature}")?;
2085                if *top_k > 0 {
2086                    write!(f, ",k={top_k}")?;
2087                }
2088                if *top_p < 1.0 {
2089                    write!(f, ",p={top_p}")?;
2090                }
2091                write!(f, ")")
2092            }
2093            Op::TopK { k } => write!(f, "topk(k={k})"),
2094            Op::GroupedMatMul => write!(f, "grouped_matmul"),
2095            Op::DequantGroupedMatMul { scheme } => {
2096                write!(f, "dequant_grouped_matmul({scheme})")
2097            }
2098            Op::DequantMoEWeights { scheme } => write!(f, "dequant_moe_weights({scheme})"),
2099            Op::LoraMatMul { scale } => write!(f, "lora_matmul(scale={scale})"),
2100            Op::DequantMatMul { scheme } => write!(f, "dequant_matmul({scheme})"),
2101            Op::QMatMul {
2102                x_zp,
2103                w_zp,
2104                out_zp,
2105                mult,
2106            } => write!(
2107                f,
2108                "q_matmul(x_zp={x_zp},w_zp={w_zp},out_zp={out_zp},mult={mult})"
2109            ),
2110            Op::QConv2d { kernel_size, .. } => write!(f, "q_conv2d({kernel_size:?})"),
2111            Op::SelectiveScan { state_size } => write!(f, "ssm_scan(n={state_size})"),
2112            Op::GatedDeltaNet {
2113                state_size,
2114                carry_state,
2115            } => {
2116                if *carry_state {
2117                    write!(f, "gated_delta_net(n={state_size},carry)")
2118                } else {
2119                    write!(f, "gated_delta_net(n={state_size})")
2120                }
2121            }
2122            Op::ScatterAdd => write!(f, "scatter_add"),
2123            Op::Conv { kernel_size, .. } => write!(f, "conv2d({kernel_size:?})"),
2124            Op::Im2Col { kernel_size, .. } => write!(f, "im2col({kernel_size:?})"),
2125            Op::ConvTranspose2d { kernel_size, .. } => {
2126                write!(f, "conv_transpose2d({kernel_size:?})")
2127            }
2128            Op::LayerNorm2d { eps } => write!(f, "layer_norm2d(eps={eps})"),
2129            Op::Pool {
2130                kind, kernel_size, ..
2131            } => write!(f, "pool_{kind:?}({kernel_size:?})"),
2132            Op::ReluBackward => write!(f, "relu_backward"),
2133            Op::ActivationBackward { kind } => write!(f, "{kind:?}_backward"),
2134            Op::ComplexNormSq => write!(f, "complex_norm_sq"),
2135            Op::ComplexNormSqBackward => write!(f, "complex_norm_sq_backward"),
2136            Op::Conjugate => write!(f, "conjugate"),
2137            Op::FakeQuantizeBackward { bits, ste, .. } => {
2138                write!(f, "fake_quant_backward(bits={bits},ste={ste:?})")
2139            }
2140            Op::MaxPool2dBackward { kernel_size, .. } => {
2141                write!(f, "maxpool2d_backward({kernel_size:?})")
2142            }
2143            Op::Conv2dBackwardInput { kernel_size, .. } => {
2144                write!(f, "conv2d_backward_input({kernel_size:?})")
2145            }
2146            Op::Conv2dBackwardWeight { kernel_size, .. } => {
2147                write!(f, "conv2d_backward_weight({kernel_size:?})")
2148            }
2149            Op::SoftmaxCrossEntropyWithLogits => write!(f, "sce_with_logits"),
2150            Op::SoftmaxCrossEntropyBackward => write!(f, "sce_backward"),
2151            Op::AttentionBackward {
2152                num_heads,
2153                head_dim,
2154                mask_kind,
2155                wrt,
2156            } => match mask_kind {
2157                MaskKind::None => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},nomask)"),
2158                MaskKind::Causal => {
2159                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},causal)")
2160                }
2161                MaskKind::SlidingWindow(w) => {
2162                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},sw={w})")
2163                }
2164                MaskKind::Custom => {
2165                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},custom)")
2166                }
2167                MaskKind::Bias => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},bias)"),
2168            },
2169            Op::FusedMatMulBiasAct { activation } => {
2170                write!(f, "fused_mm_bias")?;
2171                if let Some(a) = activation {
2172                    write!(f, "_{a:?}")?;
2173                }
2174                Ok(())
2175            }
2176            Op::FusedResidualLN { has_bias, eps } => {
2177                write!(f, "fused_residual")?;
2178                if *has_bias {
2179                    write!(f, "_bias")?;
2180                }
2181                write!(f, "_ln(eps={eps})")
2182            }
2183            Op::FusedResidualRmsNorm { has_bias, eps } => {
2184                write!(f, "fused_residual")?;
2185                if *has_bias {
2186                    write!(f, "_bias")?;
2187                }
2188                write!(f, "_rms(eps={eps})")
2189            }
2190            Op::FusedSwiGLU {
2191                cast_to,
2192                gate_first,
2193            } => {
2194                let mut s = match cast_to {
2195                    Some(dt) => format!("fused_swiglu(cast={dt}"),
2196                    None => "fused_swiglu(".to_string(),
2197                };
2198                if *gate_first {
2199                    s.push_str(",gate_first");
2200                }
2201                s.push(')');
2202                write!(f, "{s}")
2203            }
2204            Op::FusedAttentionBlock {
2205                num_heads,
2206                head_dim,
2207                has_bias,
2208                has_rope,
2209            } => {
2210                write!(f, "fused_attn(h={num_heads},d={head_dim}")?;
2211                if *has_bias {
2212                    write!(f, ",bias")?;
2213                }
2214                if *has_rope {
2215                    write!(f, ",rope")?;
2216                }
2217                write!(f, ")")
2218            }
2219            Op::If { .. } => write!(f, "if(...)"),
2220            Op::While { max_iterations, .. } => match max_iterations {
2221                Some(n) => write!(f, "while(...max={n})"),
2222                None => write!(f, "while(...)"),
2223            },
2224            Op::Scan {
2225                length,
2226                save_trajectory,
2227                num_xs,
2228                ..
2229            } => {
2230                let traj = if *save_trajectory { ",traj" } else { "" };
2231                let xs = if *num_xs > 0 {
2232                    format!(",xs={}", num_xs)
2233                } else {
2234                    String::new()
2235                };
2236                write!(f, "scan(len={length}{xs}{traj})")
2237            }
2238            Op::ScanBackward {
2239                length,
2240                save_trajectory,
2241                num_xs,
2242                ..
2243            } => {
2244                let traj = if *save_trajectory { ",traj" } else { "" };
2245                let xs = if *num_xs > 0 {
2246                    format!(",xs={}", num_xs)
2247                } else {
2248                    String::new()
2249                };
2250                write!(f, "scan_bwd(len={length}{xs}{traj})")
2251            }
2252            Op::ScanBackwardXs {
2253                length,
2254                save_trajectory,
2255                num_xs,
2256                xs_idx,
2257                ..
2258            } => {
2259                let traj = if *save_trajectory { ",traj" } else { "" };
2260                write!(
2261                    f,
2262                    "scan_bwd_xs(len={length},xs={num_xs},idx={xs_idx}{traj})"
2263                )
2264            }
2265            Op::FusedTransformerLayer {
2266                num_heads,
2267                head_dim,
2268                intermediate_size,
2269                has_bias,
2270                ..
2271            } => {
2272                write!(
2273                    f,
2274                    "fused_layer(h={num_heads},d={head_dim},int={intermediate_size}"
2275                )?;
2276                if *has_bias {
2277                    write!(f, ",bias")?;
2278                }
2279                write!(f, ")")
2280            }
2281            Op::ElementwiseRegion {
2282                chain,
2283                num_inputs,
2284                scalar_input_mask,
2285                input_modulus: _,
2286                prologue,
2287                prologue_input: _,
2288            } => {
2289                let pro = match prologue {
2290                    RegionPrologue::None => "",
2291                    RegionPrologue::ResizeNearest2x => ",prologue=resize2x",
2292                };
2293                if *scalar_input_mask != 0 {
2294                    write!(
2295                        f,
2296                        "ew_region(in={num_inputs},steps={},scalar_mask=0x{:x}{pro})",
2297                        chain.len(),
2298                        scalar_input_mask
2299                    )
2300                } else {
2301                    write!(f, "ew_region(in={num_inputs},steps={}{pro})", chain.len())
2302                }
2303            }
2304            Op::TransformRegion { steps, num_inputs } => {
2305                write!(f, "transform_region(in={num_inputs},steps={})", steps.len())
2306            }
2307            Op::BatchElementwiseRegion {
2308                chain,
2309                num_batch_inputs,
2310                scalar_input_mask,
2311                prologue,
2312                ..
2313            } => write!(
2314                f,
2315                "batch_ew_region(batch={num_batch_inputs},steps={},mask=0x{:x},prologue={prologue:?})",
2316                chain.len(),
2317                scalar_input_mask
2318            ),
2319            Op::LayerNormBackwardInput { eps, .. } => {
2320                write!(f, "layer_norm_backward_input(eps={eps})")
2321            }
2322            Op::LayerNormBackwardGamma { eps, .. } => {
2323                write!(f, "layer_norm_backward_gamma(eps={eps})")
2324            }
2325            Op::RmsNormBackwardInput { eps, .. } => write!(f, "rms_norm_backward_input(eps={eps})"),
2326            Op::RmsNormBackwardGamma { eps, .. } => write!(f, "rms_norm_backward_gamma(eps={eps})"),
2327            Op::RmsNormBackwardBeta { eps, .. } => write!(f, "rms_norm_backward_beta(eps={eps})"),
2328            Op::RopeBackward { head_dim, n_rot } => {
2329                write!(f, "rope_backward(d={head_dim},n_rot={n_rot})")
2330            }
2331            Op::GroupNormBackwardInput { num_groups, eps } => {
2332                write!(f, "group_norm_backward_input(g={num_groups},eps={eps})")
2333            }
2334            Op::GroupNormBackwardGamma { num_groups, eps } => {
2335                write!(f, "group_norm_backward_gamma(g={num_groups},eps={eps})")
2336            }
2337            Op::GroupNormBackwardBeta { num_groups, eps } => {
2338                write!(f, "group_norm_backward_beta(g={num_groups},eps={eps})")
2339            }
2340            Op::BatchNormInferenceBackwardInput { eps } => {
2341                write!(f, "batch_norm_inference_backward_input(eps={eps})")
2342            }
2343            Op::BatchNormInferenceBackwardGamma { eps } => {
2344                write!(f, "batch_norm_inference_backward_gamma(eps={eps})")
2345            }
2346            Op::BatchNormInferenceBackwardBeta => {
2347                write!(f, "batch_norm_inference_backward_beta")
2348            }
2349            Op::CumsumBackward { axis, exclusive } => {
2350                write!(f, "cumsum_backward(axis={axis},exclusive={exclusive})")
2351            }
2352            Op::GatherBackward { axis } => write!(f, "gather_backward(axis={axis})"),
2353            Op::GaussianSplatRender {
2354                width,
2355                height,
2356                tile_size,
2357                radius_scale,
2358                alpha_cutoff,
2359                max_splat_steps,
2360                transmittance_threshold,
2361                max_list_entries,
2362            } => write!(
2363                f,
2364                "gaussian_splat_render({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2365            ),
2366            Op::GaussianSplatRenderBackward {
2367                width,
2368                height,
2369                loss_grad_clip,
2370                sh_band,
2371                ..
2372            } => write!(
2373                f,
2374                "gaussian_splat_render_bwd({width}x{height},clip={loss_grad_clip},sh={sh_band})"
2375            ),
2376            Op::GaussianSplatPrepare {
2377                width,
2378                height,
2379                tile_size,
2380                radius_scale,
2381                alpha_cutoff,
2382                max_splat_steps,
2383                transmittance_threshold,
2384                max_list_entries,
2385                ..
2386            } => write!(
2387                f,
2388                "gaussian_splat_prepare({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2389            ),
2390            Op::GaussianSplatRasterize {
2391                width,
2392                height,
2393                tile_size,
2394                alpha_cutoff,
2395                max_splat_steps,
2396                transmittance_threshold,
2397                max_list_entries,
2398                ..
2399            } => write!(
2400                f,
2401                "gaussian_splat_rasterize({width}x{height},tile={tile_size},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2402            ),
2403            Op::Custom {
2404                name,
2405                num_inputs,
2406                attrs,
2407            } => write!(f, "custom({name},in={num_inputs},attrs={}B)", attrs.len()),
2408            Op::CustomFn {
2409                num_inputs,
2410                vjp_body,
2411                jvp_body,
2412                ..
2413            } => {
2414                let v = if vjp_body.is_some() { ",vjp" } else { "" };
2415                let j = if jvp_body.is_some() { ",jvp" } else { "" };
2416                write!(f, "custom_fn(in={num_inputs}{v}{j})")
2417            }
2418            Op::Fft { inverse, norm } => {
2419                write!(f, "fft(inverse={inverse}, norm={norm:?})")
2420            }
2421            Op::FftButterflyStage { stage, n_fft } => {
2422                write!(f, "fft_butterfly_stage(stage={stage}, n_fft={n_fft})")
2423            }
2424            Op::LogMel => write!(f, "log_mel()"),
2425            Op::LogMelBackward => write!(f, "log_mel_backward()"),
2426        }
2427    }
2428}