// Source: rlx_ir/op.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Operation types — every tensor op in the RLX IR.
//!
//! Designed for pattern-matching fusion: ops are grouped by category so
//! fusion passes can reason about them structurally.
20
use crate::DType;
22
/// Unary element-wise activation functions.
///
/// Each variant names a scalar function applied independently to every
/// element of a single input. It is the payload of `Op::Activation` and of
/// `ChainStep::Activation` inside fused element-wise chains.
///
/// Declaration order determines the discriminant values, so do not reorder
/// existing variants casually.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
    /// Gaussian Error Linear Unit.
    /// NOTE(review): presumably the exact (erf-based) form, given the
    /// separate `GeluApprox` variant — confirm against the kernels.
    Gelu,
    /// Approximate GELU.
    /// NOTE(review): presumably the tanh approximation — confirm.
    GeluApprox,
    /// SiLU / swish, `x · sigmoid(x)`. SwiGLU gate activation.
    Silu,
    /// `max(x, 0)`.
    Relu,
    /// Logistic sigmoid, `1 / (1 + e^-x)`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// `e^x`.
    Exp,
    /// Logarithm of `x` (presumably natural log — confirm).
    Log,
    /// Square root.
    Sqrt,
    /// Reciprocal square root, `1 / sqrt(x)`.
    Rsqrt,
    /// Negation, `-x`.
    Neg,
    /// Absolute value, `|x|`.
    Abs,
    /// `sin(x)`. Backward: `dx = upstream · cos(x)`.
    Sin,
    /// `cos(x)`. Backward: `dx = -upstream · sin(x)`.
    Cos,
    /// `tan(x)`. Backward: `dx = upstream · sec²(x) = upstream · (1 + tan²(x))`.
    Tan,
    /// `atan(x)`. Backward: `dx = upstream · (1 / (1 + x²))`.
    Atan,
    /// Round to nearest integer (half-to-even), in f32.
    /// Forward: `x.round()`. Backward: STE — treats as identity, so
    /// the gradient passes through unchanged. Useful as a primitive
    /// for composing custom quantization schemes (Mul-by-recip-scale
    /// → Round → Clamp → Mul-by-scale = a hand-rolled FakeQuantize
    /// that the elementwise-region pass can fuse into a single kernel).
    ///
    /// NOTE(review): the two statements above disagree. Rust's
    /// `f32::round` rounds ties *away from zero*; half-to-even is
    /// `f32::round_ties_even`. Confirm which rule the backends
    /// implement and make this doc match.
    Round,
}
55
/// Scale-tracking strategy for `Op::FakeQuantize`. Determines how
/// the per-channel `s[c]` is computed each forward pass.
///
/// * `PerBatch` — recompute `s[c] = max(|x|) / q_max` from the
///   current data on every call. Simple, no extra inputs, but
///   noisy for activations (max-abs jumps batch-to-batch).
///
/// * `EMA { decay }` — keep a running `s[c]` in a state tensor
///   (passed as a second op input). On each call, blend the
///   current per-batch max-abs into the state via
///   `state' = decay·state + (1-decay)·max_abs`. Smooth scale
///   over training, makes activation-QAT actually trainable.
///   Typical `decay = 0.99`.
///
/// * `Fixed` — never recompute. The state tensor's value is
///   used as-is each call (set once at construction or by the
///   caller). Useful when scales are pre-calibrated.
///
/// # Equality and hashing
///
/// `PartialEq` / `PartialOrd` are derived, so `decay` is compared with
/// ordinary `f32` semantics (`0.0 == -0.0`, `NaN != NaN`). `Eq` and `Hash`
/// are implemented by hand because `f32` provides neither. The `Hash`
/// impl canonicalises signed zero so that values which compare equal
/// always hash equal. A NaN `decay` still breaks the reflexivity that
/// `Eq` promises; callers must not construct `EMA` with a NaN decay.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Default)]
pub enum ScaleMode {
    /// Recompute the scale from the current batch on every call.
    #[default]
    PerBatch,
    /// Exponential moving average of the per-batch max-abs.
    EMA {
        /// Weight given to the previous state in the blend
        /// `state' = decay·state + (1-decay)·max_abs`.
        decay: f32,
    },
    /// Use the state tensor's value unchanged.
    Fixed,
}

// Manual marker impl: `f32` is not `Eq`, so the derive is unavailable.
// Sound only while `decay` is not NaN (see the type-level docs).
impl Eq for ScaleMode {}

impl std::hash::Hash for ScaleMode {
    /// Feeds a one-byte variant tag, followed for `EMA` by the bit pattern
    /// of `decay` with `-0.0` folded into `+0.0`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            ScaleMode::PerBatch => state.write_u8(0),
            ScaleMode::EMA { decay } => {
                state.write_u8(1);
                // `0.0` and `-0.0` compare equal under the derived
                // `PartialEq` but have different bit patterns. Hashing
                // the raw bits would violate `a == b ⇒ hash(a) == hash(b)`.
                let canonical = if *decay == 0.0 { 0.0_f32 } else { *decay };
                state.write_u32(canonical.to_bits());
            }
            ScaleMode::Fixed => state.write_u8(2),
        }
    }
}
97
/// Straight-through estimator variants for `Op::FakeQuantize`'s
/// backward. The forward is the same regardless: discrete
/// `clamp(round(x/s)) * s`. The choice here affects only the
/// gradient w.r.t. `x` during training.
///
/// * `Identity` — `dx = upstream`. The original STE; treats the
///   round as identity in the backward direction. Simplest, fine
///   for moderate bit widths (i4 / i8).
///
/// * `ClippedIdentity` — `dx = upstream * (|x| ≤ q_max·s)`. Zero
///   the gradient when the input was outside the quantization
///   range (i.e. the clamp activated). Stops the optimizer from
///   pushing weights further into saturation.
///
/// * `Tanh` — `dx = upstream * (1 - tanh²(x/s))`. Smooth surrogate
///   for the round step. Slowly attenuates the gradient as `|x|`
///   approaches `q_max·s`. Often best on tight bit widths (i2).
///
/// * `HardTanh` — `dx = upstream * (1 - |x/(q_max·s)|).max(0)`.
///   Piecewise-linear cousin of `Tanh`; cheaper to compute and
///   nearly as effective.
///
/// The default is [`SteKind::Identity`].
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SteKind {
    /// `dx = upstream`.
    #[default]
    Identity,
    /// `dx = upstream * (|x| ≤ q_max·s)`.
    ClippedIdentity,
    /// `dx = upstream * (1 - tanh²(x/s))`.
    Tanh,
    /// `dx = upstream * (1 - |x/(q_max·s)|).max(0)`.
    HardTanh,
}
128
/// Binary element-wise operations.
///
/// Payload of `Op::Binary` (two inputs with broadcasting) and of
/// `ChainStep::Binary` inside fused chains. In the variant docs `a` is the
/// first operand and `b` the second.
/// NOTE(review): that operand order is assumed from convention — confirm
/// against the lowering for the non-commutative variants.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    /// `a + b`.
    Add,
    /// `a - b`.
    Sub,
    /// `a * b`.
    Mul,
    /// `a / b`.
    Div,
    /// Element-wise maximum of `a` and `b`.
    Max,
    /// Element-wise minimum of `a` and `b`.
    Min,
    /// `a` raised to the power `b`.
    Pow,
}
141
/// Comparison operations (return Bool tensor).
///
/// Payload of `Op::Compare` and `ChainStep::Compare`. In the variant docs
/// `a` is the first operand and `b` the second.
/// NOTE(review): operand order is assumed from convention — confirm
/// against the lowering.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// `a == b`.
    Eq,
    /// `a != b`.
    Ne,
    /// `a < b`.
    Lt,
    /// `a <= b`.
    Le,
    /// `a > b`.
    Gt,
    /// `a >= b`.
    Ge,
}
153
/// What kind of attention mask the kernel should apply.
///
/// Borrowed from MAX's `nn/attention/mha_mask.mojo` pattern (#20 in
/// PLAN.md): one attention kernel handles all variants by branching on
/// the mask kind, instead of forcing every caller to materialize a mask
/// tensor. The win is two-fold:
///   1. **`None`** — single unpadded sequence: no mask load, no per-key
///      compare in the inner loop.
///   2. **`Causal`** — autoregressive decode: kernel generates the upper-
///      triangular fill from `(qi, ki)` directly; no `seq²` mask tensor
///      ever exists.
///
/// `Custom` is the existing path — read mask values from the 4th input.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MaskKind {
    /// No masking — every position attends to every position.
    None,
    /// Causal (autoregressive) — position `qi` attends only to `ki <= qi`.
    Causal,
    /// Sliding window — position `qi` attends to `ki ∈ [qi - w, qi]`.
    /// The payload is the window width `w`.
    SlidingWindow(usize),
    /// Read mask values from the input tensor (default; matches BERT
    /// padding-mask behavior). Tensor shape `[batch, key_len]` with
    /// `1.0` = valid, `<0.5` = ignored.
    ///
    /// NOTE(review): this is described as the default, but `MaskKind`
    /// neither derives nor implements `Default` here — confirm where
    /// the default is actually applied (e.g. a graph-builder helper).
    Custom,
    /// Additive per-head, per-query bias tensor
    /// `[batch, num_heads, query_len, key_len]` added to the
    /// `QK^T · scale` scores before softmax. Lets DETR-style boxRPB
    /// and other learned position biases reuse the fast `Op::Attention`
    /// path instead of decomposing into matmul + add + softmax + matmul.
    Bias,
}
187
/// Which forward input an [`Op::AttentionBackward`] node differentiates.
///
/// One backward node yields the gradient for exactly one of the three
/// attention inputs; this selector names which.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AttentionBwdWrt {
    /// Gradient with respect to the query input.
    Query,
    /// Gradient with respect to the key input.
    Key,
    /// Gradient with respect to the value input.
    Value,
}
196
/// Reduction operations along specified axes.
///
/// Payload of `Op::Reduce`, which also carries the axes to collapse and
/// whether reduced dimensions are kept.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Sum of the elements along the reduced axes.
    Sum,
    /// Arithmetic mean of the elements along the reduced axes.
    Mean,
    /// Maximum element along the reduced axes.
    Max,
    /// Minimum element along the reduced axes.
    Min,
    /// Product of the elements along the reduced axes.
    Prod,
}
207
/// PLAN L4: discriminant for each [`Op`] variant. Used by
/// [`Op::kind`] + the `Backend::supported_ops` trait method to declare
/// which ops a backend can lower; the `LegalizeForBackend` pass in
/// `rlx-opt` checks the graph against this set and fails the compile
/// when an unsupported op is present (instead of silent fallback).
///
/// Every variant is a payload-free tag named after the `Op` variant it
/// stands for. Declaration order fixes the discriminant values, so keep
/// existing variants where they are.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    // ── Graph inputs ────────────────────────────────────────────
    Input,
    Param,
    Constant,

    // ── Element-wise unary / quantization ───────────────────────
    Activation,
    Cast,
    StopGradient,
    Quantize,
    Dequantize,
    FakeQuantize,
    FakeQuantizeLSQ,
    FakeQuantizeLSQBackwardX,
    FakeQuantizeLSQBackwardScale,

    // ── Element-wise binary / compare / select ──────────────────
    Binary,
    Compare,
    Where,

    // ── Fused regions ───────────────────────────────────────────
    ElementwiseRegion,
    /// Fused sampling / geometry chain (FKL-style transform region).
    TransformRegion,
    /// Same element-wise chain over multiple batch planes (horizontal fusion).
    BatchElementwiseRegion,

    // ── Linear algebra ──────────────────────────────────────────
    MatMul,
    DotGeneral,
    DenseSolve,
    BatchedDenseSolve,

    // ── Normalization / resize ──────────────────────────────────
    LayerNorm,
    LayerNorm2d,
    GroupNorm,
    BatchNormInference,
    RmsNorm,
    ResizeNearest2x,

    // ── Attention / position embedding ──────────────────────────
    Attention,
    Rope,
    AxialRope2d,

    // ── Shape manipulation ──────────────────────────────────────
    Reshape,
    Transpose,
    Narrow,
    Concat,
    Expand,
    Gather,

    // ── Reduction / scan-like / sampling ────────────────────────
    Reduce,
    Softmax,
    Cumsum,
    TopK,
    Sample,

    // ── Convolution / pooling ───────────────────────────────────
    Conv,
    Im2Col,
    ConvTranspose2d,
    Pool,

    // ── Backward (gradient) ops and complex helpers ─────────────
    ReluBackward,
    ActivationBackward,
    FakeQuantizeBackward,
    ComplexNormSq,
    ComplexNormSqBackward,
    Conjugate,
    MaxPool2dBackward,
    Conv2dBackwardInput,
    Conv2dBackwardWeight,
    SoftmaxCrossEntropyWithLogits,
    SoftmaxCrossEntropyBackward,
    AttentionBackward,
    LayerNormBackwardInput,
    LayerNormBackwardGamma,
    RmsNormBackwardInput,
    RmsNormBackwardGamma,
    // NOTE(review): `Op::RmsNorm` is documented as taking only
    // `input, gamma` (no beta) — confirm this kind is still needed.
    RmsNormBackwardBeta,
    RopeBackward,
    GroupNormBackwardInput,
    GroupNormBackwardGamma,
    GroupNormBackwardBeta,
    BatchNormInferenceBackwardInput,
    BatchNormInferenceBackwardGamma,
    BatchNormInferenceBackwardBeta,
    CumsumBackward,
    GatherBackward,

    // ── Grouped / quantized / adapter matmuls ───────────────────
    GroupedMatMul,
    DequantGroupedMatMul,
    DequantMoEWeights,
    ScatterAdd,
    LoraMatMul,
    DequantMatMul,
    QMatMul,
    QConv2d,

    // ── Sequence-state recurrences ──────────────────────────────
    SelectiveScan,
    GatedDeltaNet,

    // ── Fused composite kernels ─────────────────────────────────
    FusedSwiGLU,
    FusedMatMulBiasAct,
    FusedResidualLN,
    FusedResidualRmsNorm,
    FusedAttentionBlock,
    FusedTransformerLayer,

    // ── Control flow ────────────────────────────────────────────
    If,
    While,
    Scan,
    ScanBackward,
    ScanBackwardXs,

    // ── Gaussian splatting ──────────────────────────────────────
    /// CPU reference 3D Gaussian splat raster (project → bin → sort → raster).
    /// See [`Op::GaussianSplatRender`].
    GaussianSplatRender,
    /// Backward of [`Op::GaussianSplatRender`] — packed scene parameter gradients.
    GaussianSplatRenderBackward,
    /// Project + tile bin + sort + ray grid (strict IR splat stage 1).
    GaussianSplatPrepare,
    /// Per-pixel raster from prepared buffers (strict IR splat stage 2).
    GaussianSplatRasterize,

    // ── User extension points ───────────────────────────────────
    /// User-registered op dispatched through `op_registry`. All
    /// custom ops (Sparse-LU, FFT, eigensolve, ...) share this kind;
    /// the per-op identity lives in `Op::Custom::name`.
    Custom,
    /// User-defined sub-graph with optional override AD rules. See
    /// [`Op::CustomFn`] / [`crate::Graph::custom_fn`].
    CustomFn,

    // ── Signal processing ───────────────────────────────────────
    /// 1D FFT primitive (forward or inverse) — see [`Op::Fft`].
    Fft,
    /// Ternary pruned radix-2 butterfly stage — see [`Op::FftButterflyStage`].
    FftButterflyStage,
    /// Whisper-style log-mel from block-layout FFT spectrum — see [`Op::LogMel`].
    LogMel,
    /// Backward of [`Op::LogMel`] w.r.t. block-layout spectrum input 0.
    LogMelBackward,
    /// Welch PSD top-K spikes from block-layout FFT spectra — see [`Op::WelchPeaks`].
    WelchPeaks,
}
338
/// An operand inside a fused [`ChainStep`] — either a graph-level input
/// to the [`Op::ElementwiseRegion`] (by index 0..num_inputs) or the
/// result of a previous step in the chain (by index 0..step_position).
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChainOperand {
    /// Index of one of the region's external inputs.
    Input(u32),
    /// Index of an earlier step in the same chain, whose result is reused.
    Step(u32),
}
348
349/// One step in a fused element-wise chain. Each step produces exactly
350/// one scalar result (per element); later steps can refer to it via
351/// [`ChainOperand::Step`]. The whole chain runs per element in registers.
352#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
353#[derive(Debug, Clone, PartialEq)]
354pub enum ChainStep {
355    Activation(Activation, ChainOperand),
356    Cast(DType, ChainOperand),
357    Binary(BinaryOp, ChainOperand, ChainOperand),
358    Compare(CmpOp, ChainOperand, ChainOperand),
359    /// 3-input element-wise select: `cond ? on_true : on_false`. Mirrors
360    /// `Op::Where` inside a chain. `cond` is treated as truthy iff
361    /// non-zero. Lets the optimizer fold attention masks / clamp-style
362    /// patterns into a single region kernel instead of breaking the
363    /// chain at the first `Op::Where`.
364    Where(ChainOperand, ChainOperand, ChainOperand),
365}
366
/// Pre-region memory transform fused into [`Op::ElementwiseRegion`].
///
/// Stored in the region's `prologue` field. The default is
/// [`RegionPrologue::None`], meaning no transform.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum RegionPrologue {
    /// No prologue transform.
    #[default]
    None,
    /// Input is half-resolution NCHW; output shape is 2× H×W (nearest 2×).
    ResizeNearest2x,
}
376
377/// One sampling step in [`Op::TransformRegion`].
378#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
379#[derive(Debug, Clone, PartialEq)]
380pub enum TransformStep {
381    ResizeNearest2x(ChainOperand),
382}
383
384/// An operation in the RLX IR graph.
385///
386/// Operations are categorized for fusion analysis:
387/// - Element-wise ops fuse with anything reading their output
388/// - Matmul/Conv are BLAS-dispatched and form fusion boundaries
389/// - Reductions are fusion roots (drive the loop iteration)
390#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
391#[derive(Debug, Clone, PartialEq)]
392pub enum Op {
393    // ── Graph inputs ────────────────────────────────────────────
394    /// Model input with a name (shape on the Node).
395    Input {
396        name: String,
397    },
398
399    /// Model parameter (weight/bias) with a name.
400    Param {
401        name: String,
402    },
403
404    /// Constant tensor embedded in the graph.
405    Constant {
406        data: Vec<u8>,
407    },
408
409    // ── Element-wise unary ──────────────────────────────────────
410    /// Unary activation: one input, same shape output.
411    Activation(Activation),
412
413    /// Cast to a different dtype.
414    Cast {
415        to: DType,
416    },
417
418    /// Stop-gradient (a.k.a. `detach` / `lax.stop_gradient`). Forward is
419    /// identity; the reverse-mode autodiff rule returns **no** gradient
420    /// contribution for the input. Single input, same shape & dtype
421    /// output. Used to build a Gradient-Reverse-Layer with identity
422    /// forward semantics in user code (see maet-rs `dat_loss`).
423    StopGradient,
424
425    /// INT8 quantization. Input f32; output i8 same shape.
426    ///   `q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])`
427    /// where `c` selects the per-channel scale/zp when `axis = Some(d)`
428    /// (`c = idx[d]`), or always uses index 0 when `axis = None`
429    /// (per-tensor). The `scales` / `zero_points` payload length must
430    /// match `1` for per-tensor and `input.dim(d)` for per-channel.
431    /// Static — typically produced at calibration time and baked
432    /// into the loaded model. Use `Op::Dequantize` for the inverse.
433    Quantize {
434        axis: Option<usize>,
435        scales: Vec<f32>,
436        zero_points: Vec<i32>,
437    },
438
439    /// INT8 dequantization (inverse of `Op::Quantize`). Input i8;
440    /// output f32 same shape.
441    ///   `x[i] = (q[i] - zero_point[c]) · scale[c]`
442    /// where `c` is selected by `axis` exactly as in `Op::Quantize`.
443    Dequantize {
444        axis: Option<usize>,
445        scales: Vec<f32>,
446        zero_points: Vec<i32>,
447    },
448
449    /// "Fake-quantize" op for **quantization-aware training** (QAT).
450    /// Input f32; output f32 same shape. Forward computes a per-axis
451    /// (or per-tensor when `axis = None`) max-abs scale on the fly:
452    ///   `s[c] = max(|x[..., c, ...]|) / q_max(bits)`
453    /// then quantizes-then-dequantizes:
454    ///   `out[i] = clamp(round(x[i] / s[c]), -q_max, q_max) * s[c]`
455    /// where `q_max` is `127` for `bits=8`, `7` for `bits=4`, `1` for
456    /// `bits=2` (ternary). Symmetric only — zero-point is always 0.
457    ///
458    /// The point of this op is to make the SGD optimizer "see" the
459    /// deployment-time rounding during training. Backward is the
460    /// **straight-through estimator** (STE): the gradient passes
461    /// through (variant chosen by `ste`), ignoring the discontinuity
462    /// at the round. Without STE the rounding would have zero
463    /// gradient almost everywhere and learning would stop.
464    ///
465    /// Inserted by the trainer on conv / FC weight tensors when
466    /// `--qat` is on; the existing `Op::Quantize` / packing path at
467    /// the end of training still handles the deployment-side
468    /// conversion to `i8`/`i4`/`i2` codes.
469    FakeQuantize {
470        bits: u8,
471        axis: Option<usize>,
472        ste: SteKind,
473        scale_mode: ScaleMode,
474    },
475
476    /// Learned Step Size Quantization (LSQ; Esser et al. 2020,
477    /// `arXiv:1902.08153`). Like `FakeQuantize` but the per-channel
478    /// `scale` is a *learned parameter*, passed as the second input.
479    /// Forward is identical to `FakeQuantize` with a fixed scale:
480    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
481    /// Backward computes both `dx` (STE) and `dscale[c]` via the
482    /// closed-form gradient:
483    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
484    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
485    /// `sign(z) · q_max`. Routinely beats per-batch and EMA at
486    /// tight bit widths (i2 / i3).
487    ///
488    /// Inputs: `[x, scale]`. `scale` is `[chan_dim]` f32 (matches
489    /// `axis`); for `axis = None` it's `[1]`.
490    FakeQuantizeLSQ {
491        bits: u8,
492        axis: Option<usize>,
493    },
494
495    /// Backward pass for `Op::FakeQuantizeLSQ`. Computes BOTH the
496    /// gradient w.r.t. `x` (STE) and the gradient w.r.t. `scale`
497    /// (closed-form). Output shape matches `x`; the `scale` gradient
498    /// is reduced separately by `LsqScaleGradient`.
499    /// Inputs: `[x, scale, dy]`. Output: `dx`, same shape as `x`.
500    FakeQuantizeLSQBackwardX {
501        bits: u8,
502        axis: Option<usize>,
503    },
504
505    /// Companion to `FakeQuantizeLSQBackwardX`: computes the
506    /// `[chan_dim]` per-channel scale gradient. Inputs `[x, scale, dy]`.
507    /// Output shape matches `scale`.
508    FakeQuantizeLSQBackwardScale {
509        bits: u8,
510        axis: Option<usize>,
511    },
512
513    // ── Element-wise binary ─────────────────────────────────────
514    /// Binary op with broadcasting: two inputs, output shape is broadcast result.
515    Binary(BinaryOp),
516
517    // ── Comparison ──────────────────────────────────────────────
518    /// Element-wise comparison: two inputs, Bool output.
519    Compare(CmpOp),
520
521    /// Select elements: cond (Bool), on_true, on_false → output.
522    Where,
523
524    /// Fused element-wise region (PLAN L2). Holds an N-step chain of
525    /// element-wise operations. Inputs are referenced by index 0..num_inputs;
526    /// each step's result can be referenced by later steps via
527    /// `ChainOperand::Step(idx)`. The output is the last step's result.
528    /// Emitted by `MarkElementwiseRegions` in `rlx-opt` from chains of
529    /// Activation/Cast/Binary/Compare/Where ops with single-consumer
530    /// intermediates and broadcast-compatible shapes. Backends that
531    /// don't have a region kernel can decompose back to the original
532    /// chain via `unfuse::unfuse_elementwise_regions`.
533    ///
534    /// `scalar_input_mask` is a per-input bitfield (bit `i` set ⇒
535    /// input `i` is a scalar broadcast — has shape `[1]`). Kept as a
536    /// fast-path indicator that lets kernels skip the modulo entirely
537    /// when they detect a scalar.
538    ///
539    /// `input_modulus[i]` is the per-input element count, used by
540    /// kernels to compute `arena[input_offs[i] + (gid % input_modulus[i])]`
541    /// — the trailing-shape broadcast pattern. `0` means "no broadcast"
542    /// (input matches the output element count; kernel reads `gid`
543    /// directly). `1` means scalar; any other value means the input
544    /// has fewer elements than the output and they tile by modulo.
545    /// The encoder only allows broadcasts where `out_elems % in_elems
546    /// == 0` so the modulo divides cleanly. Lets chains include bias /
547    /// scale / eps / mask factors that previously broke the chain at
548    /// a Binary op with mismatched shapes.
549    ElementwiseRegion {
550        chain: Vec<ChainStep>,
551        num_inputs: u32,
552        scalar_input_mask: u32,
553        input_modulus: [u32; 16],
554        /// FKL-style closed fusion: apply before the element-wise chain.
555        prologue: RegionPrologue,
556        /// External input index that supplies the prologue transform source (default 0).
557        prologue_input: u32,
558    },
559
560    /// Fused transform chain (resize, future crop/color). Decompose via
561    /// [`rlx_fusion::DecomposeFusionRegions`] when no native kernel exists.
562    TransformRegion {
563        steps: Vec<TransformStep>,
564        num_inputs: u32,
565    },
566
567    /// Identical [`Op::ElementwiseRegion`] chain over `num_batch_inputs` tensors
568    /// (horizontal / z-plane fusion). Inputs are separate batch slices.
569    BatchElementwiseRegion {
570        chain: Vec<ChainStep>,
571        num_batch_inputs: u32,
572        scalar_input_mask: u32,
573        input_modulus: [u32; 16],
574        prologue: RegionPrologue,
575        prologue_input: u32,
576    },
577
578    // ── Linear algebra ──────────────────────────────────────────
579    /// Matrix multiply. Inputs: [.., M, K] × [.., K, N] → [.., M, N].
580    /// Batch dimensions are broadcast.
581    MatMul,
582
583    /// Matrix multiply with explicit dimension specification.
584    /// Like XLA's DotGeneral — handles arbitrary batch/contracting dims.
585    DotGeneral {
586        lhs_contracting: Vec<usize>,
587        rhs_contracting: Vec<usize>,
588        lhs_batch: Vec<usize>,
589        rhs_batch: Vec<usize>,
590    },
591
592    /// Batched dense linear solve. Inputs: `A [B, N, N]`,
593    /// `b [B, N]` or `b [B, N, K]`. Output: same shape as `b`.
594    ///
595    /// Per-batch independent solve — each `A[i]` and `b[i]` are
596    /// solved as a separate `Op::DenseSolve`. Emitted by vmap of
597    /// `Op::DenseSolve`. The CPU lowering loops over the batch
598    /// dimension calling `dgesv` per slice (LAPACK doesn't expose a
599    /// batched solve on Accelerate; cuSOLVER does on NVIDIA).
600    BatchedDenseSolve,
601
602    /// Dense linear solve `x = A⁻¹ · b` via LU factorization.
603    /// Inputs: `A [N, N]`, `b [N]` (or `b [N, K]` for multi-RHS).
604    /// Output: same shape as `b`.
605    ///
606    /// VJP via the implicit-function theorem:
607    ///   `dx = solve(Aᵀ, upstream)`
608    ///   `dA = -outer(dx, x)`   (x is the forward output)
609    ///   `db = dx`
610    /// The rule is dtype-agnostic; lowering is per-backend (Accelerate
611    /// `dgesv` / `sgesv`, cuSOLVER, etc.).
612    DenseSolve,
613
614    // ── Normalization ───────────────────────────────────────────
615    /// Layer normalization: input, gamma, beta → normalized output.
616    /// `axis` is the feature dimension (usually -1).
617    LayerNorm {
618        axis: i32,
619        eps: f32,
620    },
621
622    /// Group normalization on NCHW tensors: `input`, `gamma`, `beta` → same shape.
623    /// Normalizes over `(C/num_groups) × H × W` per group.
624    GroupNorm {
625        num_groups: usize,
626        eps: f32,
627    },
628
629    /// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
630    /// position (candle / SAM `LayerNorm2d` semantics — not PyTorch's H×W norm).
631    LayerNorm2d {
632        eps: f32,
633    },
634
635    /// Nearest-neighbor 2× upsample on NCHW (doubles spatial dims 2 and 3).
636    ResizeNearest2x,
637
638    /// RMS normalization: input, gamma → normalized output.
639    RmsNorm {
640        axis: i32,
641        eps: f32,
642    },
643
644    /// BatchNorm inference with frozen running statistics.
645    /// Inputs: `x`, `gamma`, `beta`, `running_mean`, `running_var`.
646    /// Feature dimension is the last axis of `x`; stats are 1-D `[C]`.
647    BatchNormInference {
648        eps: f32,
649    },
650
651    // ── Attention ───────────────────────────────────────────────
652    /// Scaled dot-product attention: Q, K, V, \[mask\] → output.
653    /// The compiler can lower this to fused SDPA or flash attention.
654    /// `mask_kind` controls how masking is applied — `Custom` reads from
655    /// the 4th input tensor; `None` / `Causal` / `SlidingWindow` skip the
656    /// mask load and apply the mask directly in the inner loop. See
657    /// `MaskKind` for the rationale.
658    ///
659    /// `score_scale`: when `Some(s)`, dot-product scores are multiplied by
660    /// `s` instead of the default `1/sqrt(head_dim)` (Gemma uses `head_dim^-0.5`
661    /// explicitly in config). `attn_logit_softcap`: when `Some(c)`, applies
662    /// `c * tanh(s/c)` to scores before softmax (Gemma 2).
663    Attention {
664        num_heads: usize,
665        head_dim: usize,
666        mask_kind: MaskKind,
667        score_scale: Option<f32>,
668        attn_logit_softcap: Option<f32>,
669    },
670
671    /// Rotary position embedding applied to one tensor: x, cos, sin → x_rotated.
672    /// Apply separately to Q and K. `head_dim` is the per-head width; `n_rot`
673    /// is how many leading dims get NeoX RoPE (pair offset `n_rot/2`). When
674    /// `n_rot < head_dim`, trailing dims are copied unchanged (Qwen3.5 MRoPE).
675    Rope {
676        head_dim: usize,
677        n_rot: usize,
678    },
679
680    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
681    AxialRope2d {
682        end_x: usize,
683        end_y: usize,
684        head_dim: usize,
685        num_heads: usize,
686        theta: f32,
687        repeat_factor: usize,
688    },
689
690    // ── Shape manipulation ──────────────────────────────────────
691    Reshape {
692        new_shape: Vec<i64>,
693    },
694    Transpose {
695        perm: Vec<usize>,
696    },
697    /// Select a contiguous slice along an axis.
698    Narrow {
699        axis: usize,
700        start: usize,
701        len: usize,
702    },
703    /// Concatenate along an axis.
704    Concat {
705        axis: usize,
706    },
707    /// Expand (broadcast) to a target shape.
708    Expand {
709        target_shape: Vec<i64>,
710    },
711    /// Gather elements by index along an axis (embedding lookup).
712    Gather {
713        axis: usize,
714    },
715
716    // ── Reduction ───────────────────────────────────────────────
717    /// Reduce along specified axes.
718    Reduce {
719        op: ReduceOp,
720        axes: Vec<usize>,
721        keep_dim: bool,
722    },
723
724    /// Selective scan (plan #15) — Mamba-style state-space model
725    /// step. The recurrence:
726    ///   `h[t] = exp(Δ[t] * A) * h[t-1] + Δ[t] * B[t] * x[t]`
727    ///   `y[t] = C[t] * h[t]`
728    /// where state `h` has dimension `state_size` and the input has
729    /// `(batch, seq, hidden)`.
730    ///
731    /// Inputs (in order):
732    ///   `x [b, s, h]`      f32 input
733    ///   `delta [b, s, h]`  f32 step size (per-position, per-channel)
734    ///   `a [h, n]`         f32 transition matrix (one per channel)
735    ///   `b [b, s, n]`      f32 input projection
736    ///   `c [b, s, n]`      f32 output projection
737    /// Output: `[b, s, h]` f32. State `h` is implicit; the kernel
738    /// scans through the seq dimension carrying it.
739    ///
740    /// `state_size` = `n` is exposed for the cost model.
741    SelectiveScan {
742        state_size: usize,
743    },
744
745    /// Gated DeltaNet linear-attention recurrence — the per-layer
746    /// kernel used by Qwen3.5/3.6 trunk "linear attention" blocks
747    /// (and Qwen3-Next, Kimi-Linear). Mirrors
748    /// `llama.cpp / src/models/delta-net-base.cpp` autoregressive
749    /// path; chunked + fused variants ride the same op identity.
750    ///
751    /// **Math (per token `t`, head `h`, state size `n`):**
752    /// state matrix `S[h, i, j]` is implicit (reset per batch).
753    /// ```text
754    ///   S[h]     *= exp(g[t,h])                     # scalar gate
755    ///   sk[h,j]   = Σ_i S[h,i,j] * k[t,h,i]
756    ///   d[h,j]    = (v[t,h,j] - sk[h,j]) * b[t,h]   # b = beta
757    ///   S[h,i,j] += k[t,h,i] * d[h,j]               # outer-prod
758    ///   o[t,h,j]  = Σ_i S[h,i,j] * (q[t,h,i] / √n)
759    /// ```
760    ///
761    /// Inputs:
762    ///   `q     [b, s, h_v, n]`  f32 queries (L2-normed by caller)
763    ///   `k     [b, s, h_v, n]`  f32 keys    (L2-normed by caller;
764    ///                            GQA-repeated to match `h_v`)
765    ///   `v     [b, s, h_v, n]`  f32 values
766    ///   `g     [b, s, h_v]`     f32 log-gate (exp'd inside kernel)
767    ///   `beta  [b, s, h_v]`     f32 delta-rule mixing factor
768    ///
769    /// Output: `[b, s, h_v, n]` f32.
770    ///
771    /// When `carry_state` is true, a sixth input `state [b, h_v, n, n]`
772    /// provides the initial SSM matrix per head; the kernel updates it
773    /// in place across the sequence and leaves the final state in the
774    /// same buffer (same layout as the internal scan state:
775    /// `state[h, i, j]` row-major over `(n, n)` per head).
776    GatedDeltaNet {
777        state_size: usize,
778        carry_state: bool,
779    },
780
781    /// Fused dequant + matmul (plan #5). The biggest LLM-bandwidth
782    /// win on Apple Silicon: dequantizes weights inside the matmul
783    /// inner loop, never materializing f32 weights.
784    ///
785    /// **BREAKING CHANGE in 0.2.0:** `num_inputs()` is now
786    /// scheme-dependent — **4** for legacy Int8 schemes, **2** for
787    /// the new GGUF K-quant schemes (their scales/mins live inside
788    /// the packed bytes, so no side-channel `scale` / `zp` tensors
789    /// are fed in). Callers that assumed a fixed 4-input contract
790    /// must inspect `scheme.is_gguf()` before reading inputs.
791    ///
792    /// Inputs (Int8 schemes — `scheme.is_gguf() == false`):
793    ///   `x [m, k]`             f32 activations
794    ///   `w_q [k, n]` packed    quantized weight bytes (i8 per
795    ///                          element for Int8 schemes; 4-bit
796    ///                          packed two-per-byte for Int4)
797    ///   `scale [k/block, n]`   per-block f32 dequant scale
798    ///   `zp    [k/block, n]`   per-block f32 zero-point
799    ///                          (zero-tensor if symmetric)
800    ///
801    /// Inputs (`Nvfp4Block` — fixed group size 16 along K):
802    ///   `x [m, k]`             f32 activations
803    ///   `w_q [k,n/2]` packed   FP4 E2M1 codes (unsigned nibble 0..15)
804    ///   `scale [k/16, n]` u8   FP8 E4M3 block scales (one byte / group)
805    ///   `global_scale [1]` f32 per-tensor scale (pass `[1.0]` if unused)
806    ///
807    /// Inputs (GGUF schemes — `scheme.is_gguf() == true`):
808    ///   `x [m, k]`             f32 activations
809    ///   `packed_w [bytes]`     raw GGUF super-block bytes; the
810    ///                          dequantizer reads the per-sub-block
811    ///                          scales / mins / quants directly out
812    ///                          of the buffer per the K-quant block
813    ///                          layout (no side tensors).
814    ///
815    /// Output: `[m, n]` f32.
816    ///
817    /// `block_size` (on the Int8 schemes only) is the number of
818    /// consecutive elements that share one (scale, zero_point) pair.
819    /// The Op carries enough metadata that the kernel doesn't need
820    /// a separate `QuantMap` lookup at run time.
821    DequantMatMul {
822        scheme: crate::quant::QuantScheme,
823    },
824
825    /// Real INT8-arithmetic matrix multiply with i32 accumulation.
826    /// Inputs (in order):
827    ///   `x      [M, K]`  i8 activations (zero-point = `x_zp`)
828    ///   `w      [K, N]`  i8 weights     (zero-point = `w_zp`)
829    ///   `bias   [N]`     i32 (in accumulator scale = `x_scale·w_scale`),
830    ///                    pass a zeros tensor for "no bias"
831    /// Output:  `[M, N]`  i8 (zero-point = `out_zp`)
832    ///
833    /// Per-element compute:
834    ///   `out[m,n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
835    /// where `mult = x_scale · w_scale / out_scale`.
836    ///
837    /// This is the same kernel shape `rlx-cortexm/src/dense.rs`
838    /// uses for on-device int8 inference, lifted into the IR so the
839    /// rlx-cpu backend can run a quantized graph directly (instead
840    /// of round-tripping through fake-quant Dequantize → MatMul →
841    /// Quantize). 2-D only — generalizing to batched comes when a
842    /// real workload demands it.
843    QMatMul {
844        x_zp: i32,
845        w_zp: i32,
846        out_zp: i32,
847        mult: f32,
848    },
849
850    /// Real INT8-arithmetic 2-D convolution with i32 accumulation.
851    /// Inputs:
852    ///   `x      [N, C_in, H, W]`              i8 (zero-point = `x_zp`)
853    ///   `w      [C_out, C_in/groups, kH, kW]` i8 (zero-point = `w_zp`)
854    ///   `bias   [C_out]`                      i32 in accumulator scale
855    /// Output: `[N, C_out, H_out, W_out]` i8 (zero-point = `out_zp`).
856    /// Same NCHW geometry contract as `Op::Conv`; same requantize
857    /// math as `Op::QMatMul` (per-element `acc·mult` rounded to i8).
858    QConv2d {
859        kernel_size: Vec<usize>,
860        stride: Vec<usize>,
861        padding: Vec<usize>,
862        dilation: Vec<usize>,
863        groups: usize,
864        x_zp: i32,
865        w_zp: i32,
866        out_zp: i32,
867        mult: f32,
868    },
869
870    /// Fused LoRA matmul: `out = x·W + scale * x·A·B`.
871    /// Inputs (in order): `x [m, k]`, `w [k, n]`, `a [k, r]`, `b [r, n]`.
872    /// `r` is the LoRA rank (typically 4-64). `scale` is the
873    /// per-adapter `alpha / rank` knob.
874    /// Plan #9: lifts LoRA from "three matmuls + an add" into one
875    /// kernel that keeps the rank-r intermediate in registers.
876    LoraMatMul {
877        scale: f32,
878    },
879
880    /// Fused sampling kernel: logits → optional top-k filter →
881    /// optional top-p truncation → softmax → multinomial sample.
882    /// One f32-encoded sampled token id per batch row (output
883    /// shape `[batch]`).
884    ///
885    /// `temperature == 1.0` is neutral (plain softmax, then sampled);
886    /// lower → sharper, higher → flatter. `top_k == 0` disables.
887    /// `top_p == 1.0` disables. `seed` is the Philox seed; pass 0
888    /// for "use process-global counter" (still deterministic
889    /// given the call order).
890    /// Borrowed from MAX's nn/sampling.mojo (#42 in PLAN.md).
891    /// Latency-critical: never materializes the full softmax
892    /// distribution on the host.
893    Sample {
894        top_k: usize,     // 0 = disabled
895        top_p: f32,       // 1.0 = disabled
896        temperature: f32, // 1.0 = neutral
897        seed: u64,        // 0 = use thread-local counter (NOTE(review): doc above says process-global — confirm)
898    },
899
900    /// Inclusive cumulative sum along an axis. Same shape in/out.
901    /// Underpins ragged-tensor offsets, sampling (top-p prefix sum),
902    /// and sequence-position math (#44 in PLAN.md).
903    /// `exclusive=true` shifts the result so output\[0\] = 0 (useful
904    /// for offset arrays where the first segment starts at 0).
905    Cumsum {
906        axis: i32,
907        exclusive: bool,
908    },
909
910    /// Softmax along an axis (reduction + element-wise).
911    Softmax {
912        axis: i32,
913    },
914
915    /// Top-K **indices** along the last axis. Output shape `[..., k]`,
916    /// f32-encoded indices (rlx is f32-only at the I/O boundary).
917    /// To recover the values, follow with a `Gather` against the
918    /// original tensor — works because Gather already supports any axis.
919    /// Ties broken by smaller index (matches NumPy / PyTorch
920    /// `torch.topk(..., largest=True, sorted=True)`).
921    /// Used by MoE gating; also useful for beam search.
922    TopK {
923        k: usize,
924    },
925
926    /// Indexed batched matmul. The MoE GEMM primitive.
927    /// Inputs: `[input, weight, expert_idx]`
928    ///   input       : [M, K]                — per-token activations
929    ///   weight      : [num_experts, K, N]   — stacked expert weights
930    ///   expert_idx  : \[M\]                   — f32-encoded expert id per token
931    /// Output         : [M, N]                — output\[i\] = input\[i\] @ weight[expert_idx\[i\]]
932    /// Naive impl on both backends; future work can replace with a
933    /// segmented/grouped GEMM when there's a real workload.
934    GroupedMatMul,
935
936    /// Fused GGUF K-quant dequant + [`Op::GroupedMatMul`]. Same three
937    /// inputs as `GroupedMatMul`, but `weight` is a U8 tensor holding
938    /// `num_experts` contiguous packed expert slabs (GGML layout, expert
939    /// dimension outermost). Scales live inside the packed bytes.
940    DequantGroupedMatMul {
941        scheme: crate::quant::QuantScheme,
942    },
943
944    /// Dequant a packed MoE expert stack to F32 `[num_experts, K, N]` in
945    /// GroupedMatMul layout. Input: U8 packed bytes; output shape is
946    /// declared on the node (`[E, K, N]`).
947    DequantMoEWeights {
948        scheme: crate::quant::QuantScheme,
949    },
950
951    /// Scatter-add into a destination tensor. The "unpermute" half of
952    /// MoE routing (also useful for embedding gradient updates).
953    /// Inputs: `[updates, indices]`
954    ///   updates : [num_updates, trailing]   — values to add
955    ///   indices : \[num_updates\]             — f32-encoded destination row
956    /// Output    : [out_dim, trailing]       — output[indices\[i\]] += updates\[i\]
957    /// `out_dim` is taken from the node's declared output shape.
958    /// Initial output is zero; multiple updates to the same row
959    /// accumulate (sequentially on CPU; with atomic-add on Metal).
960    ScatterAdd,
961
962    // ── Convolution ─────────────────────────────────────────────
963    /// 2D convolution on NCHW tensors. Also exposed as [`OpKind::Conv`] / `conv2d`.
964    /// Weight layout: `[C_out, C_in / groups, kH, kW]`.
965    Conv {
966        kernel_size: Vec<usize>,
967        stride: Vec<usize>,
968        padding: Vec<usize>,
969        dilation: Vec<usize>,
970        groups: usize,
971    },
972
973    /// NCHW im2col for conv backward-weight style matmul.
974    /// Input `[N, C, H, W]`. Output `[M, C·kH·kW]` with
975    /// `M = N · H_out · W_out`. When batch is [`dynamic::sym::BATCH`],
976    /// output rows use [`dynamic::sym::ROWS`] (bind `N · H_out · W_out`).
977    Im2Col {
978        kernel_size: Vec<usize>,
979        stride: Vec<usize>,
980        padding: Vec<usize>,
981        dilation: Vec<usize>,
982    },
983
984    /// 2D transposed convolution on NCHW. Weight layout (PyTorch):
985    /// `[C_in, C_out / groups, kH, kW]`.
986    ConvTranspose2d {
987        kernel_size: Vec<usize>,
988        stride: Vec<usize>,
989        padding: Vec<usize>,
990        dilation: Vec<usize>,
991        output_padding: Vec<usize>,
992        groups: usize,
993    },
994
995    // ── Pooling ─────────────────────────────────────────────────
996    Pool {
997        kind: ReduceOp,
998        kernel_size: Vec<usize>,
999        stride: Vec<usize>,
1000        padding: Vec<usize>,
1001    },
1002
1003    // ── Backward / training ops ─────────────────────────────────
1004    //
1005    // Closed-form gradient nodes emitted by `rlx-opt::autodiff`.
1006    // Pairing a forward op with a dedicated backward op (instead of
1007    // composing it from primitives) keeps the gradient kernel simple
1008    // and lets the backend recompute argmax / masks / softmax inline.
1009    /// ReLU backward: `dx = dy where x > 0 else 0`.
1010    /// Inputs: `[x, dy]` — both same shape; output matches.
1011    ReluBackward,
1012
1013    /// Element-wise complex squared-magnitude: `|z|² = z.re² + z.im²`.
1014    /// Input: 1 tensor with `DType::C64`. Output: same shape but
1015    /// `DType::F32`. The natural real-valued loss surface for
1016    /// Wirtinger reverse-mode AD on complex graphs — pair with
1017    /// [`Op::ComplexNormSqBackward`].
1018    ComplexNormSq,
1019
1020    /// Element-wise complex conjugate: `z̄ = z.re - i·z.im` per element.
1021    /// Input: 1 tensor with `DType::C64`. Output: same shape, same dtype.
1022    /// Used by Wirtinger VJP rules on `Op::Binary` over C64 (the rule
1023    /// for `y = a·b` is `dL/dā = upstream · conj(b)`, etc.).
1024    Conjugate,
1025
1026    /// Backward for [`Op::ComplexNormSq`] under Wirtinger calculus.
1027    /// `f(z) = |z|² = z·z̄`, so `∂f/∂z̄ = z`. Given upstream real
1028    /// cotangent `g` (same shape as the forward output), the C64
1029    /// gradient with respect to `z` is `g · z` element-wise, returned
1030    /// in C64 storage `[re_g·re_z, re_g·im_z]` per element.
1031    ///
1032    /// Inputs: `[z (C64), g (F32)]` — both same logical shape; output
1033    /// matches `z` (C64).
1034    ComplexNormSqBackward,
1035
1036    /// LayerNorm backward w.r.t. the input. Computes
1037    ///   `d_x[..., d] = inv_std · (dy·γ - mean(dy·γ) - x̂·mean(dy·γ·x̂))`
1038    /// over the feature axis, where `x̂ = (x - mean)/std` is recomputed
1039    /// inline from `x`. Inputs: `[x, gamma, dy]`; output shape = `x.shape`.
1040    /// Currently lowers axis=-1 only (matches the forward thunk).
1041    LayerNormBackwardInput {
1042        axis: i32,
1043        eps: f32,
1044    },
1045
1046    /// LayerNorm backward w.r.t. gamma. Computes
1047    ///   `d_gamma[d] = Σ_{batch} dy[..., d] · x̂[..., d]`
1048    /// — sums the per-element product of upstream and the (recomputed)
1049    /// normalized input over the leading axes. Inputs: `[x, dy]`;
1050    /// output shape = `gamma.shape` (= 1-D feature axis).
1051    LayerNormBackwardGamma {
1052        axis: i32,
1053        eps: f32,
1054    },
1055
1056    /// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`; output = `x.shape`.
1057    RmsNormBackwardInput {
1058        axis: i32,
1059        eps: f32,
1060    },
1061
1062    /// RMSNorm backward w.r.t. gamma. Inputs `[x, gamma, beta, dy]`; output = `gamma.shape`.
1063    RmsNormBackwardGamma {
1064        axis: i32,
1065        eps: f32,
1066    },
1067
1068    /// RMSNorm backward w.r.t. beta. Inputs `[x, gamma, beta, dy]`; output = `beta.shape`.
1069    RmsNormBackwardBeta {
1070        axis: i32,
1071        eps: f32,
1072    },
1073
1074    /// RoPE backward w.r.t. `x`. Inputs `[dy, cos, sin]`; output = `dy.shape`.
1075    RopeBackward {
1076        head_dim: usize,
1077        n_rot: usize,
1078    },
1079
1080    /// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
1081    GroupNormBackwardInput {
1082        num_groups: usize,
1083        eps: f32,
1084    },
1085
1086    /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`; output = `gamma.shape`.
1087    GroupNormBackwardGamma {
1088        num_groups: usize,
1089        eps: f32,
1090    },
1091
1092    /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`; output = `beta.shape`.
1093    GroupNormBackwardBeta {
1094        num_groups: usize,
1095        eps: f32,
1096    },
1097
1098    /// BatchNorm inference backward w.r.t. `x`. Inputs `[x, gamma, mean, var, dy]`.
1099    BatchNormInferenceBackwardInput {
1100        eps: f32,
1101    },
1102
1103    /// BatchNorm inference backward w.r.t. `gamma`. Inputs `[x, mean, var, dy]`.
1104    BatchNormInferenceBackwardGamma {
1105        eps: f32,
1106    },
1107
1108    /// BatchNorm inference backward w.r.t. `beta`. Inputs `[dy]`; output = `beta.shape`.
1109    BatchNormInferenceBackwardBeta,
1110
1111    /// Cumsum backward along `axis`. Inputs `[dy]`; output matches forward input shape.
1112    CumsumBackward {
1113        axis: i32,
1114        exclusive: bool,
1115    },
1116
1117    /// Gather backward (scatter-add into table). Inputs `[dy, indices]`; output = table shape.
1118    /// `axis` matches forward [`Op::Gather`].
1119    GatherBackward {
1120        axis: i32,
1121    },
1122
1123    /// Generic element-wise activation backward. `kind` selects the
1124    /// closed-form derivative `d/dx act(x)`. Inputs: `[x, dy]`; output
1125    /// shape matches `x`. The kernel computes `act'(x) · dy` per element.
1126    ///
1127    /// Closed forms (all element-wise):
1128    /// * `Gelu`     — exact derivative of erf-based GELU.
1129    /// * `GeluApprox` — derivative of the tanh approximation
1130    ///   `0.5 x (1 + tanh(√(2/π)(x + 0.044715 x³)))`.
1131    /// * `Silu`     — `σ(x)·(1 + x·(1 - σ(x)))`.
1132    /// * `Sigmoid`  — `σ(x)·(1 - σ(x))`.
1133    /// * `Tanh`     — `1 - tanh(x)²`.
1134    /// * `Exp`      — `exp(x)`.
1135    /// * `Log`      — `1 / x`.
1136    /// * `Sqrt`     — `0.5 / sqrt(x)`.
1137    /// * `Rsqrt`    — `-0.5 · x^(-3/2)`.
1138    /// * `Neg`      — `-1`.
1139    /// * `Abs`      — `sign(x)` (zero at x=0).
1140    /// * `Sin`      — `cos(x)`.
1141    /// * `Cos`      — `-sin(x)`.
1142    /// * `Tan`      — `1 + tan²(x)`.
1143    /// * `Atan`     — `1 / (1 + x²)`.
1144    /// * `Relu`     — kept here for completeness; the dedicated
1145    ///   `ReluBackward` op is preferred for relu and is what the
1146    ///   autodiff pass actually emits.
1147    ActivationBackward {
1148        kind: Activation,
1149    },
1150
1151    /// Backward for `Op::FakeQuantize` under a non-default STE.
1152    /// Inputs `[x, dy]`: the forward input and the upstream
1153    /// gradient. Output `dx` same shape. The `bits`/`axis`/`ste`
1154    /// fields must match the forward op so the kernel computes the
1155    /// same per-channel scale and applies the right STE attenuation.
1156    /// For `SteKind::Identity` this op is unnecessary — autodiff
1157    /// just routes `upstream` through unchanged.
1158    FakeQuantizeBackward {
1159        bits: u8,
1160        axis: Option<usize>,
1161        ste: SteKind,
1162    },
1163
1164    /// 2D max-pool backward. Routes each element of `dy` back into the
1165    /// position in `x`'s window where the forward max was taken.
1166    /// Inputs: `[x, dy]` with `x [N, C, H, W]` and
1167    /// `dy [N, C, H_out, W_out]`. Output: same shape as `x`.
1168    /// Carries the forward pool's geometry so the kernel can recompute
1169    /// the argmax position per window without a saved-indices tensor.
1170    MaxPool2dBackward {
1171        kernel_size: Vec<usize>,
1172        stride: Vec<usize>,
1173        padding: Vec<usize>,
1174    },
1175
1176    /// 2D conv backward w.r.t. input. Computes `dx = conv_transpose(dy, w)`.
1177    /// Inputs: `[dy, w]` with `dy [N, C_out, H_out, W_out]` and
1178    /// `w [C_out, C_in/groups, kH, kW]`. Output: `[N, C_in, H, W]`
1179    /// (declared on the node — caller knows the original input shape).
1180    /// Geometry is the forward conv's parameters, not the transposed
1181    /// conv's.
1182    Conv2dBackwardInput {
1183        kernel_size: Vec<usize>,
1184        stride: Vec<usize>,
1185        padding: Vec<usize>,
1186        dilation: Vec<usize>,
1187        groups: usize,
1188    },
1189
1190    /// 2D conv backward w.r.t. weight. Computes
1191    /// `dw[c_out, c_in, kh, kw] = sum_{n,h_out,w_out} x[n,c_in,...] * dy[n,c_out,h_out,w_out]`.
1192    /// Inputs: `[x, dy]`. Output: `[C_out, C_in/groups, kH, kW]`.
1193    Conv2dBackwardWeight {
1194        kernel_size: Vec<usize>,
1195        stride: Vec<usize>,
1196        padding: Vec<usize>,
1197        dilation: Vec<usize>,
1198        groups: usize,
1199    },
1200
1201    /// Fused softmax + cross-entropy loss with integer (f32-encoded)
1202    /// targets — the standard classification loss. Per-row output:
1203    /// `loss[n] = -log(softmax(logits[n])[labels[n]])`.
1204    /// Inputs: `[logits, labels]` with `logits [N, C]` and
1205    /// `labels [N]` (f32-encoded class indices). Output: `[N]`.
1206    /// Caller does the `Reduce::Mean` if they want a scalar.
1207    SoftmaxCrossEntropyWithLogits,
1208
1209    /// Backward of the fused loss above. Emits
1210    /// `dlogits[n,c] = (softmax(logits[n])[c] - one_hot(labels)[n,c]) * d_loss[n]`.
1211    /// Inputs: `[logits, labels, d_loss]`. Output: `[N, C]` (same shape
1212    /// as `logits`). Recomputes the softmax inline rather than threading
1213    /// it through from the forward node.
1214    SoftmaxCrossEntropyBackward,
1215
1216    /// Backward of [`Op::Attention`]. Recomputes scaled `QK^T`, applies
1217    /// the same `mask_kind` as the forward op, softmaxes scores, then
1218    /// emits **one** of `dQ`, `dK`, or `dV` selected by [`AttentionBwdWrt`].
1219    /// Autodiff emits three nodes (one per `wrt`) so each output shape
1220    /// stays a normal single-output MIR node.
1221    ///
1222    /// Inputs: `[q, k, v, dy]` plus optional mask when `mask_kind` is
1223    /// [`MaskKind::Custom`] or [`MaskKind::Bias`] (same convention as
1224    /// forward). Output shape matches `q`, `k`, or `v` respectively.
1225    AttentionBackward {
1226        num_heads: usize,
1227        head_dim: usize,
1228        mask_kind: MaskKind,
1229        wrt: AttentionBwdWrt,
1230    },
1231
1232    // ── Fused operations (created by optimization passes) ──────
1233    /// Fused matmul + bias + activation. Created from MatMul → Add → Activation.
1234    FusedMatMulBiasAct {
1235        activation: Option<Activation>,
1236    },
1237
1238    /// Fused residual + optional bias + layer norm.
1239    /// Created from Add(x, residual) → [Add(bias)] → LayerNorm.
1240    FusedResidualLN {
1241        has_bias: bool,
1242        eps: f32,
1243    },
1244
1245    /// Fused residual + optional bias + RMS norm.
1246    /// Created from Add(x, residual) → [Add(bias)] → RmsNorm.
1247    FusedResidualRmsNorm {
1248        has_bias: bool,
1249        eps: f32,
1250    },
1251
1252    /// Fused SwiGLU: split input into up/gate halves, silu(gate) * up.
1253    /// Created from Split → Silu → Mul when fed by a concatenated matmul.
1254    ///
1255    /// `cast_to`: optional output dtype — when `Some(dt)` the kernel casts
1256    /// its result from the input dtype to `dt` in-register, saving a
1257    /// separate cast pass. Reserved for future fp8/fp4 quantization paths;
1258    /// for f32→f16 mixed precision the AutoMixedPrecision pass already
1259    /// inserts a Cast node so this stays `None` in current pipelines.
1260    FusedSwiGLU {
1261        cast_to: Option<DType>,
1262        /// When `true`, the concatenated input stores gate in the low half
1263        /// `[..., 0..N)` and up in the high half `[..., N..2N)` — the layout
1264        /// produced when gate projection is emitted before up in the builder.
1265        /// Default `false`: up @ low, gate @ high (canonical concat order).
1266        gate_first: bool,
1267    },
1268
1269    /// Fused full transformer layer: attention block + residual+LN + FFN + residual+LN.
1270    /// All intermediates resident in registers/threadgroup memory; one kernel
1271    /// per layer instead of ~30 (the CPU's batch=1 win, lifted to IR so any
1272    /// backend can implement it as a monolithic kernel).
1273    ///
1274    /// Inputs: hidden, qkv_w, qkv_b, out_w, out_b,
1275    ///         ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask
1276    /// Output: same shape as hidden.
1277    ///
1278    /// **Backend status:** same as FusedAttentionBlock. CPU implements
1279    /// the L1-cache-resident merge at the thunk level. Metal deferred —
1280    /// requires a single MSL kernel for the whole layer to actually
1281    /// beat the unfused path. Multi-day work; revisit when there's a
1282    /// model whose Metal inference is bottlenecked here rather than on
1283    /// the wait latency floor.
1284    FusedTransformerLayer {
1285        num_heads: usize,
1286        head_dim: usize,
1287        intermediate_size: usize,
1288        eps1: f32,
1289        eps2: f32,
1290        activation: Activation,
1291        has_bias: bool,
1292    },
1293
1294    /// Fused attention block: QKV projection → split → \[RoPE\] → SDPA → output projection.
1295    /// Created by FuseAttentionBlock pass when batch*seq is small.
1296    /// All intermediates stay in L1 cache — no arena writes between ops.
1297    ///
1298    /// Inputs (in order):
1299    ///   hidden, qkv_w, out_w, mask,
1300    ///   [qkv_b, out_b]      if has_bias,
1301    ///   [rope_cos, rope_sin] if has_rope
1302    ///
1303    /// **Backend status (Phase C finalize):**
1304    ///   CPU  — implemented at the *thunk* level: the CPU schedule
1305    ///          recognizes the multi-thunk pattern and merges into
1306    ///          a single FusedAttnBlock that keeps Q/K/V in stack
1307    ///          buffers across stages (the L1-cache win).
1308    ///   Metal — **deferred**. A dispatch-wrapper version (chaining
1309    ///          existing kernels) buys nothing the unfused Metal path
1310    ///          doesn't already get, since per-run cost is dominated
1311    ///          by `wait_until_completed` (~150 µs), not encode. The
1312    ///          real win is a single MSL kernel keeping Q/K/V in
1313    ///          threadgroup memory across stages — multi-day work.
1314    ///          Until then, Metal runs the unfused chain (one matmul,
1315    ///          three narrows, two ropes, attention, one matmul) — all
1316    ///          covered in op_coverage and parity_harness.
1317    FusedAttentionBlock {
1318        num_heads: usize,
1319        head_dim: usize,
1320        has_bias: bool,
1321        has_rope: bool,
1322    },
1323
1324    // ── Control flow (subgraphs as op payloads) ─────────────────
1325    //
1326    // Status: IR is defined; helper `run_if` / `run_while` exist in
1327    // rlx-runtime/src/subgraph.rs; **executor wiring is not yet
1328    // implemented** (both CPU thunk and Metal thunk fall through to
1329    // `Thunk::Nop` for these ops). Wiring requires:
1330    //   1. Recursive subgraph compile at parent-compile time.
1331    //   2. Per-subgraph input/output binding through the arena.
1332    //   3. Schedule-level dispatch when the predicate / loop cond is
1333    //      resolved at runtime.
1334    // Estimate: 4–6 hours of focused work + parity tests. Deferred
1335    // because no current in-tree model uses these ops;
1336    // surface area without a validation target invites silent bugs.
1337    /// Conditional: pick between two subgraphs based on a boolean predicate.
1338    /// Inputs: [predicate, ...captures (used inside both branches)].
1339    /// `then_branch` and `else_branch` are sub-graphs that share the
1340    /// captured inputs and must produce identically-shaped outputs.
1341    /// Used for: shape-dependent execution, batched inference of
1342    /// dynamic-length sequences with padding masks.
1343    If {
1344        then_branch: Box<crate::Graph>,
1345        else_branch: Box<crate::Graph>,
1346    },
1347
1348    /// Loop: iterate `body` while `cond` evaluates true.
1349    /// Inputs: [...initial loop-carried values].
1350    /// `cond`'s single output is a Bool scalar.
1351    /// `body`'s outputs become the next iteration's loop-carried inputs.
1352    /// Outputs of While are the values after the final iteration.
1353    /// Used for: KV-cache-driven autoregressive generation, beam search.
1354    While {
1355        cond: Box<crate::Graph>,
1356        body: Box<crate::Graph>,
1357        max_iterations: Option<usize>,
1358    },
1359
1360    /// Bounded-length loop with a fixed-shape carry, optional per-step
1361    /// inputs, and optional stacked output. Mirrors JAX's `lax.scan`.
1362    ///
1363    /// Body signature: `(carry, x_t_0, ..., x_t_{num_xs-1}) → carry_next`
1364    /// — `1 + num_xs` Op::Inputs in NodeId construction order (first
1365    /// declared is the carry; the remaining `num_xs` are per-step
1366    /// slices). Single output (the next carry).
1367    ///
1368    /// Outer Op::Scan inputs (in order):
1369    ///   `[init_carry, xs_0, xs_1, ..., xs_{num_xs-1}]`
1370    /// Each `xs_i` has shape `[length, *per_step_shape_i]`; the body
1371    /// sees `xs_i[t]` (a `per_step_shape_i` slice) on iteration `t`.
1372    ///
1373    /// Outer Op::Scan output:
1374    ///   * `save_trajectory == false` — final carry, shape `*carry`.
1375    ///   * `save_trajectory == true`  — stacked trajectory of carries,
1376    ///     shape `[length, *carry]`. Row `t` is the carry after step
1377    ///     `t+1`, so row `length-1` matches the no-trajectory case.
1378    ///
1379    /// The layouts above assume `num_bcast == 0` (see `num_bcast`). Common uses:
1380    /// time-stepping integrators with time-varying drives, Mamba-style SSM
1381    /// scans reading per-step inputs, and RNN-style sequence processing.
1382    Scan {
1383        body: Box<crate::Graph>,
1384        length: u32,
1385        save_trajectory: bool,
1386        /// Number of "broadcast" inputs — values that are constant
1387        /// across iterations. Outer scan inputs in order:
1388        ///   `[init, bcast_0..bcast_{B-1}, xs_0..xs_{X-1}]`
1389        /// Body Op::Inputs in NodeId order:
1390        ///   `[carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]`
1391        /// CPU executor fills bcast slots ONCE before the iteration
1392        /// loop (xs slots are filled per-step). The reverse-mode AD
1393        /// pre-pass materialises each bcast into an xs of shape
1394        /// `[length, *bcast]` via broadcast `Mul` so the rest of the
1395        /// VJP / executor pipeline can stay unchanged. `0` (default)
1396        /// keeps the original carry+xs scan shape.
1397        num_bcast: u32,
1398        /// Number of per-step `xs` inputs. Total outer Op::Scan
1399        /// inputs is `1 + num_bcast + num_xs`.
1400        num_xs: u32,
1401        /// Number of trajectory checkpoints when `save_trajectory ==
1402        /// true`. `0` means "save all `length` rows" (default). A
1403        /// positive value `K` means save only `K` evenly-spaced rows
1404        /// at indices `floor(t * length / K)` for `t in 0..K`. Used
1405        /// by recursive checkpointed AD: store O(√T) carries during
1406        /// forward, recompute the rest in the backward pass.
1407        ///
1408        /// When `0` (or `K == length`), the saved trajectory has
1409        /// shape `[length, *carry]` — same as the original behavior.
1410        /// When `0 < K < length`, the saved trajectory has shape
1411        /// `[K, *carry]`.
1412        num_checkpoints: u32,
1413    },
1414
1415    /// Reverse-mode AD companion to `Op::Scan` — extracts the carry
1416    /// gradient `dinit`. Walks `t = length-1 .. 0`, applying `body_vjp`
1417    /// to thread `dcarry` back through the time loop.
1418    ///
1419    /// Inputs (in order):
1420    ///   `[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]`
1421    /// Output: `dinit`, shape = carry shape.
1422    ///
1423    /// `body_vjp` is the result of
1424    /// `autodiff::grad(body, [carry_id, xs_0_id, ..., xs_{num_xs-1}_id])`
1425    /// — a graph with `1 + num_xs + 1` Inputs (carry + x_t_i for each
1426    /// xs + `"d_output"`) and `1 + num_xs` outputs
1427    /// (dcarry + dx_t_i for each xs). This op reads `outputs[0]` =
1428    /// dcarry; the sibling [`Self::ScanBackwardXs`] reads the
1429    /// `outputs[1 + xs_idx]` slot for each xs gradient.
1430    ScanBackward {
1431        body_vjp: Box<crate::Graph>,
1432        length: u32,
1433        save_trajectory: bool,
1434        num_xs: u32,
1435        /// When `0` or equal to `length`, the trajectory input has
1436        /// shape `[length, *carry]` — every step's carry is cached
1437        /// (`CheckpointStrategy::All`). When `0 < K < length`, the
1438        /// trajectory input has shape `[K, *carry]` and the executor
1439        /// recomputes intermediate carries via `forward_body` between
1440        /// checkpoints. `forward_body` must be `Some` whenever
1441        /// `0 < num_checkpoints < length`.
1442        num_checkpoints: u32,
1443        /// Forward body (the same `body` from the forward Op::Scan).
1444        /// Required when `num_checkpoints > 0 && < length` so the
1445        /// executor can recompute carries between saved checkpoints.
1446        /// `None` for the All strategy (no recompute needed).
1447        forward_body: Option<Box<crate::Graph>>,
1448    },
1449
1450    /// Companion to [`Self::ScanBackward`] that extracts one stacked
1451    /// per-step `dxs_i` (shape `[length, *per_step_xs_i]`). Same inputs
1452    /// and same `body_vjp` graph as ScanBackward — `xs_idx` selects
1453    /// which body_vjp output to stack into the result.
1454    ///
1455    /// Note: each ScanBackwardXs runs its own backward loop. A future
1456    /// optimization can fuse them into a single multi-output backward
1457    /// kernel; for now it's `1 + num_xs` independent sweeps.
1458    ScanBackwardXs {
1459        body_vjp: Box<crate::Graph>,
1460        length: u32,
1461        save_trajectory: bool,
1462        num_xs: u32,
1463        xs_idx: u32,
1464        num_checkpoints: u32,
1465        forward_body: Option<Box<crate::Graph>>,
1466    },
1467
1468    /// CPU reference 3D Gaussian splat forward render.
1469    ///
1470    /// Seven flat F32 inputs (scene buffers + camera/render meta):
1471    ///   0. positions `[N*3]`
1472    ///   1. scales `[N*3]` (log-space)
1473    ///   2. rotations `[N*4]` (xyzw)
1474    ///   3. opacities `[N]` (logit)
1475    ///   4. colors `[N*3]` (linear RGB)
1476    ///   5. sh_coeffs `[N * sh_coeff_count * 3]`
1477    ///   6. meta `[23]` — camera position/target/up/fov/near/far, background RGB,
1478    ///      then width/height/tile_size/radius_scale/alpha_cutoff/max_splat_steps/
1479    ///      transmittance_threshold/max_list_entries as f32 bit-patterns.
1480    ///
1481    /// Output: `[height * width * 4]` linear RGBA (display gamma baked in).
1482    /// Build via [`crate::Graph::gaussian_splat_render`].
1483    ///
1484    /// Differentiable backward is not implemented in v1; autodiff treats this
1485    /// op as non-differentiable (same as [`Op::Sample`]).
1486    GaussianSplatRender {
1487        width: u32,
1488        height: u32,
1489        tile_size: u32,
1490        radius_scale: f32,
1491        alpha_cutoff: f32,
1492        max_splat_steps: u32,
1493        transmittance_threshold: f32,
1494        max_list_entries: u32,
1495    },
1496
1497    /// Backward pass for [`Self::GaussianSplatRender`].
1498    ///
1499    /// Eight inputs: the same seven as forward plus `d_loss_rgba` `[W*H*4]`
1500    /// (only RGB channels are used). Re-runs the training forward internally.
1501    ///
1502    /// Output: packed gradients
1503    /// `[positions(3N) | scales(3N) | rotations(4N) | opacities(N) | colors(3N) | sh(N*sh*3)]`.
1504    /// Unpack via [`crate::ops::splat::unpack_gaussian_splat_packed_grads`].
1505    GaussianSplatRenderBackward {
1506        width: u32,
1507        height: u32,
1508        tile_size: u32,
1509        radius_scale: f32,
1510        alpha_cutoff: f32,
1511        max_splat_steps: u32,
1512        transmittance_threshold: f32,
1513        max_list_entries: u32,
1514        loss_grad_clip: f32,
1515        sh_band: u32,
1516        max_anisotropy: f32,
1517    },
1518
1519    /// Strict IR stage 1: project, bin, sort, build per-pixel rays.
1520    ///
1521    /// Seven inputs (same scene + meta as [`Self::GaussianSplatRender`]). Output: packed
1522    /// prepare buffer (see `rlx_splat::prep_layout::prep_packed_len`).
1523    GaussianSplatPrepare {
1524        width: u32,
1525        height: u32,
1526        tile_size: u32,
1527        radius_scale: f32,
1528        alpha_cutoff: f32,
1529        max_splat_steps: u32,
1530        transmittance_threshold: f32,
1531        max_list_entries: u32,
1532    },
1533
1534    /// Strict IR stage 2: tile raster from [`Self::GaussianSplatPrepare`] output.
1535    ///
1536    /// Inputs: `prep` packed buffer, `meta` `[23]`. Output: `[width * height * 4]` RGBA.
1537    GaussianSplatRasterize {
1538        width: u32,
1539        height: u32,
1540        tile_size: u32,
1541        alpha_cutoff: f32,
1542        max_splat_steps: u32,
1543        transmittance_threshold: f32,
1544        max_list_entries: u32,
1545    },
1546
1547    /// User-registered custom op. `name` keys into the
1548    /// [`crate::op_registry`] for shape inference, autodiff, and
1549    /// per-backend execution. `attrs` is an opaque blob passed
1550    /// through to those callbacks (FFT direction, SparseLU
1551    /// reordering strategy, etc.). `num_inputs` is captured at
1552    /// construction time so [`Op::num_inputs`] stays infallible
1553    /// without a registry lookup. Build via [`crate::Graph::custom_op`].
1554    Custom {
1555        name: String,
1556        num_inputs: u32,
1557        attrs: Vec<u8>,
1558    },
1559
1560    /// 1D Fast Fourier Transform along the last axis.
1561    ///
1562    /// **Layouts**
1563    /// - `F32` / `F64`: 2N real-block — last axis is `[re₀…re_{N-1}, im₀…im_{N-1}]`.
1564    /// - `C64`: interleaved `[re, im]` pairs per complex element along the last axis.
1565    ///
1566    /// **ND transforms** — use `Graph::fftn` / `Graph::ifftn`, which compose
1567    /// `fft_axis` (transpose → Fft → transpose). Multi-axis `fftn` requires
1568    /// `DType::C64`; the 2N-block layout describes a single complex axis.
1569    ///
1570    /// Default (`FftNorm::Backward`) is **unnormalized** on both directions:
1571    ///   `fft(x)[k] = Σ x[n]·exp(-2πi·nk/N)`
1572    ///   `ifft(y)[n] = Σ y[k]·exp(+2πi·nk/N)`
1573    /// so `ifft(fft(x)) = N·x`. Use `FftNorm::Forward` for gpu-fft-style
1574    /// `1/N` scaling on inverse, or `FftNorm::Ortho` for unitary scaling.
1575    ///
1576    /// AD: VJP(`fft`) = `ifft`, VJP(`ifft`) = `fft` when `norm=Backward`;
1577    /// other norms apply the chain rule via output scaling.
1578    Fft {
1579        inverse: bool,
1580        norm: crate::fft::FftNorm,
1581    },
1582
1583    /// Ternary pruned radix-2 butterfly stage on interleaved complex state.
1584    ///
1585    /// Inputs:
1586    ///   0 — state `[batch, n_fft, 2]` (re/im on axis 2)
1587    ///   1 — gate `[half]` — 0 = identity, 1 = run butterfly (`half = n_fft/2`)
1588    ///   2 — rev `[half]` — 0 = forward, 1 = swap outputs when gate=1
1589    ///   3 — tw_re `[half]`
1590    ///   4 — tw_im `[half]`
1591    ///
1592    /// Output: `[batch, n_fft, 2]` same layout. Slots with gate=0 copy inputs
1593    /// without twiddle math.
1594    FftButterflyStage {
1595        stage: u32,
1596        n_fft: u32,
1597    },
1598
1599    /// Log-mel spectrogram from RLX FFT block-layout spectrum.
1600    ///
1601    /// Inputs:
1602    ///   0 — spectrum `[..., 2*n_fft]` (re plane then im plane, same as `Op::Fft` output)
1603    ///   1 — mel filterbank `[n_mels, n_bins]` with `n_bins = n_fft/2 + 1`
1604    ///
1605    /// Output: `[..., n_mels]` with Whisper dynamic-range compression
1606    /// (`log10`, clamp to max−8 dB, `(x+4)/4`).
1607    LogMel,
1608
1609    /// VJP of [`Op::LogMel`] w.r.t. spectrum (input 0).
1610    ///
1611    /// Inputs: spectrum block, mel filters, upstream `dy`.
1612    /// Output: `d_spectrum` (same shape as input 0).
1613    LogMelBackward,
1614
1615    /// Top-K Welch peaks from block-layout segment spectra.
1616    ///
1617    /// Input 0: spectrum `[batch * n_segments, 2*n_fft]` (re ∥ im planes).
1618    /// Output: `[batch, k*2]` interleaved `(bin, power)` per spike.
1619    WelchPeaks {
1620        k: usize,
1621        n_segments: usize,
1622    },
1623
1624    /// User-defined sub-graph with optional override AD rules.
1625    /// Mirrors JAX's `custom_vjp` / `custom_jvp` decorators: the
1626    /// caller wraps a forward computation and supplies its own
1627    /// reverse- and/or forward-mode AD bodies. Useful when:
1628    ///   * The forward is iterative (Newton, fixed-point) and
1629    ///     differentiating through the loop is wasteful — the
1630    ///     vjp_body computes the implicit-function gradient at the
1631    ///     converged point in one shot.
1632    ///   * The math has a closed-form gradient that's much cheaper
1633    ///     than autodiff.
1634    ///   * The forward op is non-differentiable by tracing
1635    ///     (sampling, argmax) and the user wants a smooth surrogate.
1636    ///
1637    /// **fwd_body**: `num_inputs` Op::Inputs in NodeId construction
1638    /// order, one Op::Output (the primal y). Forward execution
1639    /// inlines this body once.
1640    ///
1641    /// **vjp_body** (optional): Op::Inputs are `num_inputs` primal
1642    /// inputs in NodeId order, plus two special-named Inputs —
1643    /// `"primal_output"` (the y from forward) and `"d_output"` (the
1644    /// upstream gradient). Outputs: `num_inputs` tensors in
1645    /// `set_outputs` order, matching the gradients of each primal
1646    /// input. When `None`, reverse-mode AD recurses into fwd_body
1647    /// — same as if the op were inlined.
1648    ///
1649    /// **jvp_body** (optional): Op::Inputs are `num_inputs` primal
1650    /// inputs in NodeId order, `num_inputs` special-named Inputs
1651    /// `"tangent_0"..="tangent_{num_inputs-1}"` carrying each input's
1652    /// tangent, and an optional special-named `"primal_output"` Input
1653    /// (the y from forward, useful when the JVP must be evaluated at
1654    /// a converged / nonlinear point — e.g. IFT-style forward-mode AD
1655    /// of an iterative solver). Output: 1 tensor (the tangent of y).
1656    /// When `None`, forward-mode AD recurses into fwd_body.
1657    ///
1658    /// `num_inputs` is captured so [`Op::num_inputs`] stays
1659    /// infallible. Build via [`crate::Graph::custom_fn`].
1660    CustomFn {
1661        fwd_body: Box<crate::Graph>,
1662        vjp_body: Option<Box<crate::Graph>>,
1663        jvp_body: Option<Box<crate::Graph>>,
1664        num_inputs: u32,
1665    },
1666}
1667
1668impl Op {
1669    /// PLAN L4: discriminant for backend-supported-set checks.
1670    /// Stable, parameter-free identity per variant — `Op::Activation(_)`
1671    /// and `Op::Activation(Relu)` share the same `OpKind::Activation`.
1672    pub fn kind(&self) -> OpKind {
1673        match self {
1674            Op::Input { .. } => OpKind::Input,
1675            Op::Param { .. } => OpKind::Param,
1676            Op::Constant { .. } => OpKind::Constant,
1677            Op::Activation(_) => OpKind::Activation,
1678            Op::Cast { .. } => OpKind::Cast,
1679            Op::StopGradient => OpKind::StopGradient,
1680            Op::Quantize { .. } => OpKind::Quantize,
1681            Op::Dequantize { .. } => OpKind::Dequantize,
1682            Op::FakeQuantize { .. } => OpKind::FakeQuantize,
1683            Op::FakeQuantizeLSQ { .. } => OpKind::FakeQuantizeLSQ,
1684            Op::FakeQuantizeLSQBackwardX { .. } => OpKind::FakeQuantizeLSQBackwardX,
1685            Op::FakeQuantizeLSQBackwardScale { .. } => OpKind::FakeQuantizeLSQBackwardScale,
1686            Op::Binary(_) => OpKind::Binary,
1687            Op::Compare(_) => OpKind::Compare,
1688            Op::Where => OpKind::Where,
1689            Op::ElementwiseRegion { .. } => OpKind::ElementwiseRegion,
1690            Op::TransformRegion { .. } => OpKind::TransformRegion,
1691            Op::BatchElementwiseRegion { .. } => OpKind::BatchElementwiseRegion,
1692            Op::MatMul => OpKind::MatMul,
1693            Op::DotGeneral { .. } => OpKind::DotGeneral,
1694            Op::DenseSolve => OpKind::DenseSolve,
1695            Op::BatchedDenseSolve => OpKind::BatchedDenseSolve,
1696            Op::LayerNorm { .. } => OpKind::LayerNorm,
1697            Op::LayerNorm2d { .. } => OpKind::LayerNorm2d,
1698            Op::GroupNorm { .. } => OpKind::GroupNorm,
1699            Op::BatchNormInference { .. } => OpKind::BatchNormInference,
1700            Op::RmsNorm { .. } => OpKind::RmsNorm,
1701            Op::ResizeNearest2x => OpKind::ResizeNearest2x,
1702            Op::Attention { .. } => OpKind::Attention,
1703            Op::Rope { .. } => OpKind::Rope,
1704            Op::AxialRope2d { .. } => OpKind::AxialRope2d,
1705            Op::Reshape { .. } => OpKind::Reshape,
1706            Op::Transpose { .. } => OpKind::Transpose,
1707            Op::Narrow { .. } => OpKind::Narrow,
1708            Op::Concat { .. } => OpKind::Concat,
1709            Op::Expand { .. } => OpKind::Expand,
1710            Op::Gather { .. } => OpKind::Gather,
1711            Op::Reduce { .. } => OpKind::Reduce,
1712            Op::Softmax { .. } => OpKind::Softmax,
1713            Op::Cumsum { .. } => OpKind::Cumsum,
1714            Op::TopK { .. } => OpKind::TopK,
1715            Op::Sample { .. } => OpKind::Sample,
1716            Op::Conv { .. } => OpKind::Conv,
1717            Op::Im2Col { .. } => OpKind::Im2Col,
1718            Op::ConvTranspose2d { .. } => OpKind::ConvTranspose2d,
1719            Op::Pool { .. } => OpKind::Pool,
1720            Op::ReluBackward => OpKind::ReluBackward,
1721            Op::ActivationBackward { .. } => OpKind::ActivationBackward,
1722            Op::FakeQuantizeBackward { .. } => OpKind::FakeQuantizeBackward,
1723            Op::ComplexNormSq => OpKind::ComplexNormSq,
1724            Op::ComplexNormSqBackward => OpKind::ComplexNormSqBackward,
1725            Op::Conjugate => OpKind::Conjugate,
1726            Op::LayerNormBackwardInput { .. } => OpKind::LayerNormBackwardInput,
1727            Op::LayerNormBackwardGamma { .. } => OpKind::LayerNormBackwardGamma,
1728            Op::RmsNormBackwardInput { .. } => OpKind::RmsNormBackwardInput,
1729            Op::RmsNormBackwardGamma { .. } => OpKind::RmsNormBackwardGamma,
1730            Op::RmsNormBackwardBeta { .. } => OpKind::RmsNormBackwardBeta,
1731            Op::RopeBackward { .. } => OpKind::RopeBackward,
1732            Op::GroupNormBackwardInput { .. } => OpKind::GroupNormBackwardInput,
1733            Op::GroupNormBackwardGamma { .. } => OpKind::GroupNormBackwardGamma,
1734            Op::GroupNormBackwardBeta { .. } => OpKind::GroupNormBackwardBeta,
1735            Op::BatchNormInferenceBackwardInput { .. } => OpKind::BatchNormInferenceBackwardInput,
1736            Op::BatchNormInferenceBackwardGamma { .. } => OpKind::BatchNormInferenceBackwardGamma,
1737            Op::BatchNormInferenceBackwardBeta => OpKind::BatchNormInferenceBackwardBeta,
1738            Op::CumsumBackward { .. } => OpKind::CumsumBackward,
1739            Op::GatherBackward { .. } => OpKind::GatherBackward,
1740            Op::MaxPool2dBackward { .. } => OpKind::MaxPool2dBackward,
1741            Op::Conv2dBackwardInput { .. } => OpKind::Conv2dBackwardInput,
1742            Op::Conv2dBackwardWeight { .. } => OpKind::Conv2dBackwardWeight,
1743            Op::SoftmaxCrossEntropyWithLogits => OpKind::SoftmaxCrossEntropyWithLogits,
1744            Op::SoftmaxCrossEntropyBackward => OpKind::SoftmaxCrossEntropyBackward,
1745            Op::AttentionBackward { .. } => OpKind::AttentionBackward,
1746            Op::GroupedMatMul => OpKind::GroupedMatMul,
1747            Op::DequantGroupedMatMul { .. } => OpKind::DequantGroupedMatMul,
1748            Op::DequantMoEWeights { .. } => OpKind::DequantMoEWeights,
1749            Op::ScatterAdd => OpKind::ScatterAdd,
1750            Op::LoraMatMul { .. } => OpKind::LoraMatMul,
1751            Op::DequantMatMul { .. } => OpKind::DequantMatMul,
1752            Op::QMatMul { .. } => OpKind::QMatMul,
1753            Op::QConv2d { .. } => OpKind::QConv2d,
1754            Op::SelectiveScan { .. } => OpKind::SelectiveScan,
1755            Op::GatedDeltaNet { .. } => OpKind::GatedDeltaNet,
1756            Op::FusedSwiGLU { .. } => OpKind::FusedSwiGLU,
1757            Op::FusedMatMulBiasAct { .. } => OpKind::FusedMatMulBiasAct,
1758            Op::FusedResidualLN { .. } => OpKind::FusedResidualLN,
1759            Op::FusedResidualRmsNorm { .. } => OpKind::FusedResidualRmsNorm,
1760            Op::FusedAttentionBlock { .. } => OpKind::FusedAttentionBlock,
1761            Op::FusedTransformerLayer { .. } => OpKind::FusedTransformerLayer,
1762            Op::If { .. } => OpKind::If,
1763            Op::While { .. } => OpKind::While,
1764            Op::Scan { .. } => OpKind::Scan,
1765            Op::ScanBackward { .. } => OpKind::ScanBackward,
1766            Op::ScanBackwardXs { .. } => OpKind::ScanBackwardXs,
1767            Op::GaussianSplatRender { .. } => OpKind::GaussianSplatRender,
1768            Op::GaussianSplatRenderBackward { .. } => OpKind::GaussianSplatRenderBackward,
1769            Op::GaussianSplatPrepare { .. } => OpKind::GaussianSplatPrepare,
1770            Op::GaussianSplatRasterize { .. } => OpKind::GaussianSplatRasterize,
1771            Op::Custom { .. } => OpKind::Custom,
1772            Op::CustomFn { .. } => OpKind::CustomFn,
1773            Op::Fft { .. } => OpKind::Fft,
1774            Op::FftButterflyStage { .. } => OpKind::FftButterflyStage,
1775            Op::LogMel => OpKind::LogMel,
1776            Op::LogMelBackward => OpKind::LogMelBackward,
1777            Op::WelchPeaks { .. } => OpKind::WelchPeaks,
1778        }
1779    }
1780
1781    /// True if this op is element-wise (same shape in, same shape out).
1782    /// Element-wise ops are prime fusion candidates.
1783    pub fn is_elementwise(&self) -> bool {
1784        matches!(
1785            self,
1786            Op::Activation(_)
1787                | Op::Cast { .. }
1788                | Op::StopGradient
1789                | Op::Binary(_)
1790                | Op::Compare(_)
1791                | Op::Where
1792                | Op::ElementwiseRegion { .. }
1793                | Op::BatchElementwiseRegion { .. }
1794        )
1795    }
1796
1797    /// True if this op may appear in a [`Op::TransformRegion`] chain.
1798    pub fn is_transform_eligible(&self) -> bool {
1799        matches!(self, Op::ResizeNearest2x)
1800    }
1801
1802    /// True if this op is a BLAS/compute-intensive op that forms a fusion boundary.
1803    pub fn is_blas(&self) -> bool {
1804        matches!(
1805            self,
1806            Op::MatMul
1807                | Op::DotGeneral { .. }
1808                | Op::DenseSolve
1809                | Op::BatchedDenseSolve
1810                | Op::Conv { .. }
1811                | Op::Im2Col { .. }
1812                | Op::ConvTranspose2d { .. }
1813                | Op::FusedMatMulBiasAct { .. }
1814                | Op::GroupedMatMul
1815                | Op::DequantGroupedMatMul { .. }
1816                | Op::DequantMoEWeights { .. }
1817                | Op::LoraMatMul { .. }
1818                | Op::DequantMatMul { .. }
1819                | Op::QMatMul { .. }
1820                | Op::QConv2d { .. }
1821        )
1822    }
1823
1824    /// True if element-wise fusion must not span across this op.
1825    pub fn is_fusion_boundary(&self) -> bool {
1826        self.is_blas()
1827            || matches!(
1828                self,
1829                Op::GaussianSplatRender { .. }
1830                    | Op::GaussianSplatRenderBackward { .. }
1831                    | Op::GaussianSplatPrepare { .. }
1832                    | Op::GaussianSplatRasterize { .. }
1833            )
1834    }
1835
1836    /// True if this op is a reduction (drives loop iteration in fused kernels).
1837    pub fn is_reduction(&self) -> bool {
1838        matches!(
1839            self,
1840            Op::Reduce { .. } | Op::Softmax { .. } | Op::TopK { .. }
1841        )
1842    }
1843
1844    /// Number of tensor inputs this op expects.
1845    pub fn num_inputs(&self) -> usize {
1846        match self {
1847            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0,
1848            Op::Activation(_)
1849            | Op::Cast { .. }
1850            | Op::StopGradient
1851            | Op::Reshape { .. }
1852            | Op::Quantize { .. }
1853            | Op::Dequantize { .. }
1854            | Op::Transpose { .. }
1855            | Op::Narrow { .. }
1856            | Op::Expand { .. }
1857            | Op::Reduce { .. }
1858            | Op::Softmax { .. }
1859            | Op::FusedSwiGLU { .. }
1860            | Op::TopK { .. }
1861            | Op::Cumsum { .. }
1862            | Op::Sample { .. }
1863            | Op::ResizeNearest2x => 1,
1864            // EMA / Fixed scale modes carry a state tensor as a 2nd input;
1865            // PerBatch (default) doesn't need one.
1866            Op::FakeQuantize { scale_mode, .. } => match scale_mode {
1867                ScaleMode::PerBatch => 1,
1868                ScaleMode::EMA { .. } | ScaleMode::Fixed => 2,
1869            },
1870            Op::FakeQuantizeLSQ { .. } => 2, // x, scale (learned param)
1871            Op::FakeQuantizeLSQBackwardX { .. } | Op::FakeQuantizeLSQBackwardScale { .. } => 3, // x, scale, dy
1872            Op::Binary(_) | Op::Compare(_) | Op::Gather { .. } | Op::MatMul | Op::ScatterAdd => 2,
1873            Op::GroupedMatMul => 3,               // input, weight, expert_idx
1874            Op::DequantGroupedMatMul { .. } => 3, // input, packed_w, expert_idx
1875            Op::DequantMoEWeights { .. } => 1,    // packed_w
1876            Op::LoraMatMul { .. } => 4,           // x, w, a, b
1877            // x, w_q, scale, zp — or x, packed_w_bytes for GGUF
1878            // schemes (their scales/mins live inside the packed bytes,
1879            // see `QuantScheme::is_gguf`).
1880            Op::DequantMatMul { scheme } => {
1881                if scheme.is_gguf() {
1882                    2
1883                } else {
1884                    4
1885                }
1886            }
1887            Op::QMatMul { .. } => 3,       // x, w, bias
1888            Op::QConv2d { .. } => 3,       // x, w, bias
1889            Op::SelectiveScan { .. } => 5, // x, delta, a, b, c
1890            Op::GatedDeltaNet { carry_state, .. } if *carry_state => 6, // + state in/out
1891            Op::GatedDeltaNet { .. } => 5, // q, k, v, g, beta
1892            Op::Where => 3,                // cond, on_true, on_false
1893            Op::Attention { mask_kind, .. } => match mask_kind {
1894                MaskKind::Custom | MaskKind::Bias => 4, // Q, K, V, mask
1895                _ => 3,                                 // Q, K, V (mask synthesized in-kernel)
1896            },
1897            Op::AttentionBackward { mask_kind, .. } => match mask_kind {
1898                MaskKind::Custom | MaskKind::Bias => 5, // q, k, v, dy, mask
1899                _ => 4,                                 // q, k, v, dy
1900            },
1901            Op::Rope { .. } => 3, // x, cos, sin
1902            Op::AxialRope2d { .. } => 1,
1903            Op::LayerNorm { .. }
1904            | Op::LayerNorm2d { .. }
1905            | Op::GroupNorm { .. }
1906            | Op::RmsNorm { .. } => 3, // input, gamma, beta
1907            Op::BatchNormInference { .. } => 5, // x, gamma, beta, mean, var
1908            Op::FusedMatMulBiasAct { .. } => 3, // input, weight, bias
1909            Op::FusedResidualLN { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
1910            Op::FusedResidualLN {
1911                has_bias: false, ..
1912            } => 4, // x, residual, gamma, beta
1913            Op::FusedResidualRmsNorm { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
1914            Op::FusedResidualRmsNorm {
1915                has_bias: false, ..
1916            } => 4, // x, residual, gamma, beta
1917            Op::Conv { .. } | Op::ConvTranspose2d { .. } => 2, // input, weight (bias via Add)
1918            Op::Im2Col { .. } => 1,
1919            Op::Pool { .. } => 1,
1920            Op::ReluBackward => 2,                  // x, dy
1921            Op::ActivationBackward { .. } => 2,     // x, dy
1922            Op::FakeQuantizeBackward { .. } => 2,   // x, dy
1923            Op::ComplexNormSq => 1,                 // z (C64)
1924            Op::ComplexNormSqBackward => 2,         // z, g
1925            Op::Conjugate => 1,                     // z (C64)
1926            Op::LayerNormBackwardInput { .. } => 3, // x, gamma, dy
1927            Op::LayerNormBackwardGamma { .. } => 2, // x, dy
1928            Op::RmsNormBackwardInput { .. } => 4,   // x, gamma, beta, dy
1929            Op::RmsNormBackwardGamma { .. } => 4,
1930            Op::RmsNormBackwardBeta { .. } => 4,
1931            Op::RopeBackward { .. } => 3,           // dy, cos, sin
1932            Op::GroupNormBackwardInput { .. } => 4, // x, gamma, beta, dy
1933            Op::GroupNormBackwardGamma { .. } => 2, // x, dy
1934            Op::GroupNormBackwardBeta { .. } => 2,
1935            Op::BatchNormInferenceBackwardInput { .. } => 5, // x, gamma, mean, var, dy
1936            Op::BatchNormInferenceBackwardGamma { .. } => 4, // x, mean, var, dy
1937            Op::BatchNormInferenceBackwardBeta => 1,         // dy
1938            Op::CumsumBackward { .. } => 1,                  // dy
1939            Op::GatherBackward { .. } => 2,                  // dy, indices
1940            Op::MaxPool2dBackward { .. } => 2,               // x, dy
1941            Op::Conv2dBackwardInput { .. } => 2,             // dy, w
1942            Op::Conv2dBackwardWeight { .. } => 2,            // x, dy
1943            Op::SoftmaxCrossEntropyWithLogits => 2,          // logits, labels
1944            Op::SoftmaxCrossEntropyBackward => 3,            // logits, labels, d_loss
1945            Op::Concat { .. } => 0,                          // variadic — checked at graph level
1946            Op::DotGeneral { .. } => 2,
1947            Op::DenseSolve => 2,        // A, b
1948            Op::BatchedDenseSolve => 2, // A [B,N,N], b [B,N] or [B,N,K]
1949            Op::FusedAttentionBlock {
1950                has_bias, has_rope, ..
1951            } => 4 + if *has_bias { 2 } else { 0 } + if *has_rope { 2 } else { 0 },
1952            Op::If { .. } => 1,    // predicate (captures handled separately)
1953            Op::While { .. } => 0, // variadic loop-carried; checked at graph level
1954            Op::Scan {
1955                num_bcast, num_xs, ..
1956            } => 1 + *num_bcast as usize + *num_xs as usize,
1957            Op::ScanBackward { num_xs, .. } => 3 + *num_xs as usize, // init, trajectory, upstream, xs_0..
1958            Op::ScanBackwardXs { num_xs, .. } => 3 + *num_xs as usize, // same as ScanBackward
1959            Op::GaussianSplatRender { .. } => 7,
1960            Op::GaussianSplatRenderBackward { .. } => 8,
1961            Op::GaussianSplatPrepare { .. } => 7,
1962            Op::GaussianSplatRasterize { .. } => 2,
1963            Op::FusedTransformerLayer { has_bias, .. } => {
1964                // hidden + qkv_w + out_w + ln1_g + ln1_b + fc1_w + fc2_w + ln2_g + ln2_b + mask = 10
1965                // bias variant adds: qkv_b + out_b + fc1_b + fc2_b = 4 more
1966                10 + if *has_bias { 4 } else { 0 }
1967            }
1968            Op::ElementwiseRegion { num_inputs, .. } => *num_inputs as usize,
1969            Op::TransformRegion { num_inputs, .. } => *num_inputs as usize,
1970            Op::BatchElementwiseRegion {
1971                num_batch_inputs, ..
1972            } => *num_batch_inputs as usize,
1973            Op::Custom { num_inputs, .. } => *num_inputs as usize,
1974            Op::CustomFn { num_inputs, .. } => *num_inputs as usize,
1975            Op::Fft { .. } => 1,
1976            Op::FftButterflyStage { .. } => 5,
1977            Op::LogMel => 2,
1978            Op::LogMelBackward => 3,
1979            Op::WelchPeaks { .. } => 1,
1980        }
1981    }
1982}
1983
1984impl std::fmt::Display for Op {
1985    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1986        match self {
1987            Op::Input { name } => write!(f, "input(\"{name}\")"),
1988            Op::Param { name } => write!(f, "param(\"{name}\")"),
1989            Op::Constant { data } => write!(f, "const({}B)", data.len()),
1990            Op::Activation(a) => write!(f, "{a:?}"),
1991            Op::Quantize { axis, scales, .. } => match axis {
1992                None => write!(f, "quantize(s={})", scales[0]),
1993                Some(d) => write!(f, "quantize(axis={d},nch={})", scales.len()),
1994            },
1995            Op::Dequantize { axis, scales, .. } => match axis {
1996                None => write!(f, "dequantize(s={})", scales[0]),
1997                Some(d) => write!(f, "dequantize(axis={d},nch={})", scales.len()),
1998            },
1999            Op::FakeQuantize {
2000                bits,
2001                axis,
2002                ste,
2003                scale_mode,
2004            } => match axis {
2005                None => write!(
2006                    f,
2007                    "fake_quant(bits={bits},ste={ste:?},scale={scale_mode:?})"
2008                ),
2009                Some(d) => write!(
2010                    f,
2011                    "fake_quant(bits={bits},axis={d},ste={ste:?},scale={scale_mode:?})"
2012                ),
2013            },
2014            Op::FakeQuantizeLSQ { bits, axis } => match axis {
2015                None => write!(f, "fake_quant_lsq(bits={bits})"),
2016                Some(d) => write!(f, "fake_quant_lsq(bits={bits},axis={d})"),
2017            },
2018            Op::FakeQuantizeLSQBackwardX { bits, .. } => {
2019                write!(f, "fake_quant_lsq_bwd_x(bits={bits})")
2020            }
2021            Op::FakeQuantizeLSQBackwardScale { bits, .. } => {
2022                write!(f, "fake_quant_lsq_bwd_s(bits={bits})")
2023            }
2024            Op::Cast { to } => write!(f, "cast({to})"),
2025            Op::StopGradient => write!(f, "stop_gradient"),
2026            Op::Binary(op) => write!(f, "{op:?}"),
2027            Op::Compare(op) => write!(f, "{op:?}"),
2028            Op::Where => write!(f, "where"),
2029            Op::MatMul => write!(f, "matmul"),
2030            Op::DotGeneral { .. } => write!(f, "dot_general"),
2031            Op::DenseSolve => write!(f, "dense_solve"),
2032            Op::BatchedDenseSolve => write!(f, "batched_dense_solve"),
2033            Op::LayerNorm { eps, .. } => write!(f, "layer_norm(eps={eps})"),
2034            Op::GroupNorm { num_groups, eps } => {
2035                write!(f, "group_norm(groups={num_groups},eps={eps})")
2036            }
2037            Op::BatchNormInference { eps } => write!(f, "batch_norm_inference(eps={eps})"),
2038            Op::ResizeNearest2x => write!(f, "resize_nearest_2x"),
2039            Op::RmsNorm { eps, .. } => write!(f, "rms_norm(eps={eps})"),
2040            Op::Attention {
2041                num_heads,
2042                head_dim,
2043                mask_kind,
2044                score_scale,
2045                attn_logit_softcap,
2046            } => {
2047                let mut s = match mask_kind {
2048                    MaskKind::Custom => format!("attention(h={num_heads},d={head_dim})"),
2049                    MaskKind::None => format!("attention(h={num_heads},d={head_dim},nomask)"),
2050                    MaskKind::Causal => format!("attention(h={num_heads},d={head_dim},causal)"),
2051                    MaskKind::SlidingWindow(w) => {
2052                        format!("attention(h={num_heads},d={head_dim},sw={w})")
2053                    }
2054                    MaskKind::Bias => format!("attention(h={num_heads},d={head_dim},bias)"),
2055                };
2056                if let Some(sc) = score_scale {
2057                    s.push_str(&format!(",scale={sc}"));
2058                }
2059                if let Some(cap) = attn_logit_softcap {
2060                    s.push_str(&format!(",softcap={cap}"));
2061                }
2062                write!(f, "{s}")
2063            }
2064            Op::Rope { head_dim, n_rot } => write!(f, "rope(d={head_dim}, n_rot={n_rot})"),
2065            Op::AxialRope2d {
2066                end_x,
2067                end_y,
2068                head_dim,
2069                num_heads,
2070                theta,
2071                repeat_factor,
2072            } => write!(
2073                f,
2074                "axial_rope2d({end_x}x{end_y},h={num_heads},d={head_dim},θ={theta},r={repeat_factor})"
2075            ),
2076            Op::Reshape { new_shape } => write!(f, "reshape({new_shape:?})"),
2077            Op::Transpose { perm } => write!(f, "transpose({perm:?})"),
2078            Op::Narrow { axis, start, len } => write!(f, "narrow({axis},{start},{len})"),
2079            Op::Concat { axis } => write!(f, "concat(axis={axis})"),
2080            Op::Expand { .. } => write!(f, "expand"),
2081            Op::Gather { axis } => write!(f, "gather(axis={axis})"),
2082            Op::Reduce { op, axes, .. } => write!(f, "reduce_{op:?}({axes:?})"),
2083            Op::Softmax { axis } => write!(f, "softmax(axis={axis})"),
2084            Op::Cumsum { axis, exclusive } => {
2085                if *exclusive {
2086                    write!(f, "cumsum(axis={axis},excl)")
2087                } else {
2088                    write!(f, "cumsum(axis={axis})")
2089                }
2090            }
2091            Op::Sample {
2092                top_k,
2093                top_p,
2094                temperature,
2095                ..
2096            } => {
2097                write!(f, "sample(t={temperature}")?;
2098                if *top_k > 0 {
2099                    write!(f, ",k={top_k}")?;
2100                }
2101                if *top_p < 1.0 {
2102                    write!(f, ",p={top_p}")?;
2103                }
2104                write!(f, ")")
2105            }
2106            Op::TopK { k } => write!(f, "topk(k={k})"),
2107            Op::GroupedMatMul => write!(f, "grouped_matmul"),
2108            Op::DequantGroupedMatMul { scheme } => {
2109                write!(f, "dequant_grouped_matmul({scheme})")
2110            }
2111            Op::DequantMoEWeights { scheme } => write!(f, "dequant_moe_weights({scheme})"),
2112            Op::LoraMatMul { scale } => write!(f, "lora_matmul(scale={scale})"),
2113            Op::DequantMatMul { scheme } => write!(f, "dequant_matmul({scheme})"),
2114            Op::QMatMul {
2115                x_zp,
2116                w_zp,
2117                out_zp,
2118                mult,
2119            } => write!(
2120                f,
2121                "q_matmul(x_zp={x_zp},w_zp={w_zp},out_zp={out_zp},mult={mult})"
2122            ),
2123            Op::QConv2d { kernel_size, .. } => write!(f, "q_conv2d({kernel_size:?})"),
2124            Op::SelectiveScan { state_size } => write!(f, "ssm_scan(n={state_size})"),
2125            Op::GatedDeltaNet {
2126                state_size,
2127                carry_state,
2128            } => {
2129                if *carry_state {
2130                    write!(f, "gated_delta_net(n={state_size},carry)")
2131                } else {
2132                    write!(f, "gated_delta_net(n={state_size})")
2133                }
2134            }
2135            Op::ScatterAdd => write!(f, "scatter_add"),
2136            Op::Conv { kernel_size, .. } => write!(f, "conv2d({kernel_size:?})"),
2137            Op::Im2Col { kernel_size, .. } => write!(f, "im2col({kernel_size:?})"),
2138            Op::ConvTranspose2d { kernel_size, .. } => {
2139                write!(f, "conv_transpose2d({kernel_size:?})")
2140            }
2141            Op::LayerNorm2d { eps } => write!(f, "layer_norm2d(eps={eps})"),
2142            Op::Pool {
2143                kind, kernel_size, ..
2144            } => write!(f, "pool_{kind:?}({kernel_size:?})"),
2145            Op::ReluBackward => write!(f, "relu_backward"),
2146            Op::ActivationBackward { kind } => write!(f, "{kind:?}_backward"),
2147            Op::ComplexNormSq => write!(f, "complex_norm_sq"),
2148            Op::ComplexNormSqBackward => write!(f, "complex_norm_sq_backward"),
2149            Op::Conjugate => write!(f, "conjugate"),
2150            Op::FakeQuantizeBackward { bits, ste, .. } => {
2151                write!(f, "fake_quant_backward(bits={bits},ste={ste:?})")
2152            }
2153            Op::MaxPool2dBackward { kernel_size, .. } => {
2154                write!(f, "maxpool2d_backward({kernel_size:?})")
2155            }
2156            Op::Conv2dBackwardInput { kernel_size, .. } => {
2157                write!(f, "conv2d_backward_input({kernel_size:?})")
2158            }
2159            Op::Conv2dBackwardWeight { kernel_size, .. } => {
2160                write!(f, "conv2d_backward_weight({kernel_size:?})")
2161            }
2162            Op::SoftmaxCrossEntropyWithLogits => write!(f, "sce_with_logits"),
2163            Op::SoftmaxCrossEntropyBackward => write!(f, "sce_backward"),
2164            Op::AttentionBackward {
2165                num_heads,
2166                head_dim,
2167                mask_kind,
2168                wrt,
2169            } => match mask_kind {
2170                MaskKind::None => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},nomask)"),
2171                MaskKind::Causal => {
2172                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},causal)")
2173                }
2174                MaskKind::SlidingWindow(w) => {
2175                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},sw={w})")
2176                }
2177                MaskKind::Custom => {
2178                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},custom)")
2179                }
2180                MaskKind::Bias => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},bias)"),
2181            },
2182            Op::FusedMatMulBiasAct { activation } => {
2183                write!(f, "fused_mm_bias")?;
2184                if let Some(a) = activation {
2185                    write!(f, "_{a:?}")?;
2186                }
2187                Ok(())
2188            }
2189            Op::FusedResidualLN { has_bias, eps } => {
2190                write!(f, "fused_residual")?;
2191                if *has_bias {
2192                    write!(f, "_bias")?;
2193                }
2194                write!(f, "_ln(eps={eps})")
2195            }
2196            Op::FusedResidualRmsNorm { has_bias, eps } => {
2197                write!(f, "fused_residual")?;
2198                if *has_bias {
2199                    write!(f, "_bias")?;
2200                }
2201                write!(f, "_rms(eps={eps})")
2202            }
2203            Op::FusedSwiGLU {
2204                cast_to,
2205                gate_first,
2206            } => {
2207                let mut s = match cast_to {
2208                    Some(dt) => format!("fused_swiglu(cast={dt}"),
2209                    None => "fused_swiglu(".to_string(),
2210                };
2211                if *gate_first {
2212                    s.push_str(",gate_first");
2213                }
2214                s.push(')');
2215                write!(f, "{s}")
2216            }
2217            Op::FusedAttentionBlock {
2218                num_heads,
2219                head_dim,
2220                has_bias,
2221                has_rope,
2222            } => {
2223                write!(f, "fused_attn(h={num_heads},d={head_dim}")?;
2224                if *has_bias {
2225                    write!(f, ",bias")?;
2226                }
2227                if *has_rope {
2228                    write!(f, ",rope")?;
2229                }
2230                write!(f, ")")
2231            }
2232            Op::If { .. } => write!(f, "if(...)"),
2233            Op::While { max_iterations, .. } => match max_iterations {
2234                Some(n) => write!(f, "while(...max={n})"),
2235                None => write!(f, "while(...)"),
2236            },
2237            Op::Scan {
2238                length,
2239                save_trajectory,
2240                num_xs,
2241                ..
2242            } => {
2243                let traj = if *save_trajectory { ",traj" } else { "" };
2244                let xs = if *num_xs > 0 {
2245                    format!(",xs={}", num_xs)
2246                } else {
2247                    String::new()
2248                };
2249                write!(f, "scan(len={length}{xs}{traj})")
2250            }
2251            Op::ScanBackward {
2252                length,
2253                save_trajectory,
2254                num_xs,
2255                ..
2256            } => {
2257                let traj = if *save_trajectory { ",traj" } else { "" };
2258                let xs = if *num_xs > 0 {
2259                    format!(",xs={}", num_xs)
2260                } else {
2261                    String::new()
2262                };
2263                write!(f, "scan_bwd(len={length}{xs}{traj})")
2264            }
2265            Op::ScanBackwardXs {
2266                length,
2267                save_trajectory,
2268                num_xs,
2269                xs_idx,
2270                ..
2271            } => {
2272                let traj = if *save_trajectory { ",traj" } else { "" };
2273                write!(
2274                    f,
2275                    "scan_bwd_xs(len={length},xs={num_xs},idx={xs_idx}{traj})"
2276                )
2277            }
2278            Op::FusedTransformerLayer {
2279                num_heads,
2280                head_dim,
2281                intermediate_size,
2282                has_bias,
2283                ..
2284            } => {
2285                write!(
2286                    f,
2287                    "fused_layer(h={num_heads},d={head_dim},int={intermediate_size}"
2288                )?;
2289                if *has_bias {
2290                    write!(f, ",bias")?;
2291                }
2292                write!(f, ")")
2293            }
2294            Op::ElementwiseRegion {
2295                chain,
2296                num_inputs,
2297                scalar_input_mask,
2298                input_modulus: _,
2299                prologue,
2300                prologue_input: _,
2301            } => {
2302                let pro = match prologue {
2303                    RegionPrologue::None => "",
2304                    RegionPrologue::ResizeNearest2x => ",prologue=resize2x",
2305                };
2306                if *scalar_input_mask != 0 {
2307                    write!(
2308                        f,
2309                        "ew_region(in={num_inputs},steps={},scalar_mask=0x{:x}{pro})",
2310                        chain.len(),
2311                        scalar_input_mask
2312                    )
2313                } else {
2314                    write!(f, "ew_region(in={num_inputs},steps={}{pro})", chain.len())
2315                }
2316            }
2317            Op::TransformRegion { steps, num_inputs } => {
2318                write!(f, "transform_region(in={num_inputs},steps={})", steps.len())
2319            }
2320            Op::BatchElementwiseRegion {
2321                chain,
2322                num_batch_inputs,
2323                scalar_input_mask,
2324                prologue,
2325                ..
2326            } => write!(
2327                f,
2328                "batch_ew_region(batch={num_batch_inputs},steps={},mask=0x{:x},prologue={prologue:?})",
2329                chain.len(),
2330                scalar_input_mask
2331            ),
2332            Op::LayerNormBackwardInput { eps, .. } => {
2333                write!(f, "layer_norm_backward_input(eps={eps})")
2334            }
2335            Op::LayerNormBackwardGamma { eps, .. } => {
2336                write!(f, "layer_norm_backward_gamma(eps={eps})")
2337            }
2338            Op::RmsNormBackwardInput { eps, .. } => write!(f, "rms_norm_backward_input(eps={eps})"),
2339            Op::RmsNormBackwardGamma { eps, .. } => write!(f, "rms_norm_backward_gamma(eps={eps})"),
2340            Op::RmsNormBackwardBeta { eps, .. } => write!(f, "rms_norm_backward_beta(eps={eps})"),
2341            Op::RopeBackward { head_dim, n_rot } => {
2342                write!(f, "rope_backward(d={head_dim},n_rot={n_rot})")
2343            }
2344            Op::GroupNormBackwardInput { num_groups, eps } => {
2345                write!(f, "group_norm_backward_input(g={num_groups},eps={eps})")
2346            }
2347            Op::GroupNormBackwardGamma { num_groups, eps } => {
2348                write!(f, "group_norm_backward_gamma(g={num_groups},eps={eps})")
2349            }
2350            Op::GroupNormBackwardBeta { num_groups, eps } => {
2351                write!(f, "group_norm_backward_beta(g={num_groups},eps={eps})")
2352            }
2353            Op::BatchNormInferenceBackwardInput { eps } => {
2354                write!(f, "batch_norm_inference_backward_input(eps={eps})")
2355            }
2356            Op::BatchNormInferenceBackwardGamma { eps } => {
2357                write!(f, "batch_norm_inference_backward_gamma(eps={eps})")
2358            }
2359            Op::BatchNormInferenceBackwardBeta => {
2360                write!(f, "batch_norm_inference_backward_beta")
2361            }
2362            Op::CumsumBackward { axis, exclusive } => {
2363                write!(f, "cumsum_backward(axis={axis},exclusive={exclusive})")
2364            }
2365            Op::GatherBackward { axis } => write!(f, "gather_backward(axis={axis})"),
2366            Op::GaussianSplatRender {
2367                width,
2368                height,
2369                tile_size,
2370                radius_scale,
2371                alpha_cutoff,
2372                max_splat_steps,
2373                transmittance_threshold,
2374                max_list_entries,
2375            } => write!(
2376                f,
2377                "gaussian_splat_render({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2378            ),
2379            Op::GaussianSplatRenderBackward {
2380                width,
2381                height,
2382                loss_grad_clip,
2383                sh_band,
2384                ..
2385            } => write!(
2386                f,
2387                "gaussian_splat_render_bwd({width}x{height},clip={loss_grad_clip},sh={sh_band})"
2388            ),
2389            Op::GaussianSplatPrepare {
2390                width,
2391                height,
2392                tile_size,
2393                radius_scale,
2394                alpha_cutoff,
2395                max_splat_steps,
2396                transmittance_threshold,
2397                max_list_entries,
2398                ..
2399            } => write!(
2400                f,
2401                "gaussian_splat_prepare({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2402            ),
2403            Op::GaussianSplatRasterize {
2404                width,
2405                height,
2406                tile_size,
2407                alpha_cutoff,
2408                max_splat_steps,
2409                transmittance_threshold,
2410                max_list_entries,
2411                ..
2412            } => write!(
2413                f,
2414                "gaussian_splat_rasterize({width}x{height},tile={tile_size},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2415            ),
2416            Op::Custom {
2417                name,
2418                num_inputs,
2419                attrs,
2420            } => write!(f, "custom({name},in={num_inputs},attrs={}B)", attrs.len()),
2421            Op::CustomFn {
2422                num_inputs,
2423                vjp_body,
2424                jvp_body,
2425                ..
2426            } => {
2427                let v = if vjp_body.is_some() { ",vjp" } else { "" };
2428                let j = if jvp_body.is_some() { ",jvp" } else { "" };
2429                write!(f, "custom_fn(in={num_inputs}{v}{j})")
2430            }
2431            Op::Fft { inverse, norm } => {
2432                write!(f, "fft(inverse={inverse}, norm={norm:?})")
2433            }
2434            Op::FftButterflyStage { stage, n_fft } => {
2435                write!(f, "fft_butterfly_stage(stage={stage}, n_fft={n_fft})")
2436            }
2437            Op::LogMel => write!(f, "log_mel()"),
2438            Op::LogMelBackward => write!(f, "log_mel_backward()"),
2439            Op::WelchPeaks { k, n_segments } => {
2440                write!(f, "welch_peaks(k={k}, n_segments={n_segments})")
2441            }
2442        }
2443    }
2444}