Skip to main content

rlx_ir/
op.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Operation types — every tensor op in the RLX IR.
17//!
18//! Designed for pattern-matching fusion: ops are grouped by category so
19//! fusion passes can reason about them structurally.
20
21use crate::DType;
22
/// Unary element-wise activation functions.
///
/// Selected by [`Op::Activation`] (one input, same-shape output) and by
/// [`ChainStep::Activation`] inside fused element-wise regions.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
    Gelu,
    GeluApprox,
    /// SiLU — used as the SwiGLU gate activation.
    Silu,
    Relu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
    Rsqrt,
    Neg,
    Abs,
    /// `sin(x)`. Backward: `dx = upstream · cos(x)`.
    Sin,
    /// `cos(x)`. Backward: `dx = -upstream · sin(x)`.
    Cos,
    /// `tan(x)`. Backward: `dx = upstream · sec²(x) = upstream · (1 + tan²(x))`.
    Tan,
    /// `atan(x)`. Backward: `dx = upstream · (1 / (1 + x²))`.
    Atan,
    /// Round to nearest integer, in f32.
    /// Forward: `x.round()`. Backward: STE — treats as identity, so
    /// the gradient passes through unchanged. Useful as a primitive
    /// for composing custom quantization schemes (Mul-by-recip-scale
    /// → Round → Clamp → Mul-by-scale = a hand-rolled FakeQuantize
    /// that the elementwise-region pass can fuse into a single kernel).
    ///
    /// NOTE(review): the tie-breaking rule is ambiguous. Rust's
    /// `f32::round` (named above) rounds ties *away from zero*, while
    /// half-to-even rounding is `f32::round_ties_even`. Confirm which rule
    /// each backend implements and make them agree.
    Round,
}
55
/// Scale-tracking strategy for `Op::FakeQuantize`. Determines how
/// the per-channel `s[c]` is computed each forward pass.
///
/// * `PerBatch` — recompute `s[c] = max(|x|) / q_max` from the
///   current data on every call. Simple, no extra inputs, but
///   noisy for activations (max-abs jumps batch-to-batch).
///
/// * `EMA { decay }` — keep a running `s[c]` in a state tensor
///   (passed as a second op input). On each call, blend the
///   current per-batch max-abs into the state via
///   `state' = decay·state + (1-decay)·max_abs`. Smooth scale
///   over training, makes activation-QAT actually trainable.
///   Typical `decay = 0.99`.
///   NOTE(review): the state is described as holding `s[c]`, but the
///   update blends raw `max_abs`, which differs from `s` by a factor of
///   `q_max`. Confirm whether the kernel divides by `q_max` before or
///   after blending.
///
/// * `Fixed` — never recompute. The state tensor's value is
///   used as-is each call (set once at construction or by the
///   caller). Useful when scales are pre-calibrated.
///
/// # Equality and hashing
///
/// `PartialEq` is derived, so `EMA { decay }` values compare `decay` with
/// IEEE-754 `==`. The manual `Hash` impl folds `-0.0` onto `+0.0` before
/// hashing the bit pattern. Without that, `EMA { decay: 0.0 }` and
/// `EMA { decay: -0.0 }` would compare equal yet hash differently.
///
/// `Eq` is asserted manually because `f32` is not `Eq`. The assertion
/// does not hold for a NaN `decay`, since NaN ≠ NaN. Do not use such
/// values as hash-map keys.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Default)]
pub enum ScaleMode {
    #[default]
    PerBatch,
    EMA {
        decay: f32,
    },
    Fixed,
}

// `f32` is not `Eq`; see the type-level docs for the NaN caveat.
impl Eq for ScaleMode {}

impl std::hash::Hash for ScaleMode {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            ScaleMode::PerBatch => state.write_u8(0),
            ScaleMode::EMA { decay } => {
                state.write_u8(1);
                // The derived `PartialEq` treats `0.0 == -0.0`, but their bit
                // patterns differ. Canonicalise to `+0.0` so that equal values
                // always produce equal hashes.
                let canonical = if *decay == 0.0 { 0.0_f32 } else { *decay };
                state.write_u32(canonical.to_bits());
            }
            ScaleMode::Fixed => state.write_u8(2),
        }
    }
}
97
/// Straight-through estimator variants for `Op::FakeQuantize`'s
/// backward. The forward is the same regardless: discrete
/// `clamp(round(x/s)) * s`. The choice here affects only the
/// gradient w.r.t. `x` during training.
///
/// * `Identity` — `dx = upstream`. The original STE; treats the
///   round as identity in the backward direction. Simplest, fine
///   for moderate bit widths (i4 / i8).
///
/// * `ClippedIdentity` — `dx = upstream * (|x| ≤ q_max·s)`. Zero
///   the gradient when the input was outside the quantization
///   range (i.e. the clamp activated). Stops the optimizer from
///   pushing weights further into saturation.
///
/// * `Tanh` — `dx = upstream * (1 - tanh²(x/s))`. Smooth surrogate
///   for the round step; attenuates the gradient as `|x|` grows.
///   Often best on tight bit widths (i2).
///   NOTE(review): the argument is `x/s`, not `x/(q_max·s)`. The
///   gradient therefore decays within a few quantization steps of
///   zero (e.g. `1 - tanh²(3) ≈ 0.01`). It lines up with the clamp
///   edge `q_max·s` only when `q_max = 1` (i2).
///
/// * `HardTanh` — `dx = upstream * (1 - |x/(q_max·s)|).max(0)`.
///   Piecewise-linear cousin of `Tanh`; cheaper to compute and
///   nearly as effective.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SteKind {
    #[default]
    Identity,
    ClippedIdentity,
    Tanh,
    HardTanh,
}
128
/// Binary element-wise operations.
///
/// Selected by [`Op::Binary`] (two inputs with broadcasting; the output
/// shape is the broadcast result) and by [`ChainStep::Binary`] inside
/// fused element-wise regions.
// NOTE(review): operand order for the non-commutative ops (`Sub`, `Div`,
// `Pow`) is presumably `input0 op input1` — confirm against the lowering.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Max,
    Min,
    Pow,
}
141
/// Comparison operations (return Bool tensor).
///
/// Selected by [`Op::Compare`] (two inputs, Bool output) and by
/// [`ChainStep::Compare`] inside fused element-wise regions. Inside a
/// chain, a downstream [`ChainStep::Where`] treats any non-zero value as
/// true.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,
}
153
/// What kind of attention mask the kernel should apply.
///
/// Borrowed from MAX's `nn/attention/mha_mask.mojo` pattern (#20 in
/// PLAN.md): one attention kernel handles all variants by branching on
/// the mask kind, instead of forcing every caller to materialize a mask
/// tensor. The win is two-fold:
///   1. **`None`** — single unpadded sequence: no mask load, no per-key
///      compare in the inner loop.
///   2. **`Causal`** — autoregressive decode: the kernel derives the
///      upper-triangular (future-key) exclusion from `(qi, ki)` directly;
///      no `seq²` mask tensor ever exists.
///
/// `Custom` is the pre-existing path: it reads mask values from the 4th
/// input. `Bias` also consumes a tensor, but as an additive score bias
/// rather than a valid/ignored mask.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MaskKind {
    /// No masking — every position attends to every position.
    None,
    /// Causal (autoregressive) — position `qi` attends only to `ki <= qi`.
    Causal,
    /// Sliding window — position `qi` attends to `ki ∈ [qi - w, qi]`,
    /// where `w` is the payload. That is at most `w + 1` keys, including
    /// `qi` itself.
    SlidingWindow(usize),
    /// Read mask values from the input tensor. This is the original
    /// behavior and matches BERT padding masks. Tensor shape is
    /// `[batch, key_len]`, with `1.0` = valid and `<0.5` = ignored.
    Custom,
    /// Additive per-head, per-query bias tensor
    /// `[batch, num_heads, query_len, key_len]` added to the
    /// `QK^T · scale` scores before softmax. Lets DETR-style boxRPB
    /// and other learned position biases reuse the fast `Op::Attention`
    /// path instead of decomposing into matmul + add + softmax + matmul.
    Bias,
}
187
/// Which forward input an [`Op::AttentionBackward`] node differentiates.
///
/// The three choices correspond to the Q, K and V inputs of the forward
/// [`Op::Attention`].
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AttentionBwdWrt {
    /// Gradient w.r.t. the query input (Q).
    Query,
    /// Gradient w.r.t. the key input (K).
    Key,
    /// Gradient w.r.t. the value input (V).
    Value,
}
196
/// Reduction operations along specified axes.
///
/// Selected by [`Op::Reduce`], which also carries the `axes` to reduce
/// and a `keep_dim` flag.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    Sum,
    Mean,
    Max,
    Min,
    Prod,
}
207
/// PLAN L4: discriminant for each [`Op`] variant. Used by
/// [`Op::kind`] + the `Backend::supported_ops` trait method to declare
/// which ops a backend can lower; the `LegalizeForBackend` pass in
/// `rlx-opt` checks the graph against this set and fails the compile
/// when an unsupported op is present (instead of silent fallback).
///
/// The enum is fieldless, so parameterised ops collapse onto one kind.
/// For example, every user-registered op is `Custom`; its identity lives
/// on the [`Op`] itself.
///
/// Variant order fixes the implicit discriminants that the derived `Hash`
/// feeds to the hasher. Those discriminants are also what index-based
/// serde formats encode. Append new kinds at the end to keep existing
/// values stable.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    Input,
    Param,
    Constant,
    Activation,
    Cast,
    StopGradient,
    Quantize,
    Dequantize,
    FakeQuantize,
    FakeQuantizeLSQ,
    FakeQuantizeLSQBackwardX,
    FakeQuantizeLSQBackwardScale,
    Binary,
    Compare,
    Where,
    ElementwiseRegion,
    /// Fused sampling / geometry chain (FKL-style transform region).
    TransformRegion,
    /// Same element-wise chain over multiple batch planes (horizontal fusion).
    BatchElementwiseRegion,
    MatMul,
    DotGeneral,
    DenseSolve,
    BatchedDenseSolve,
    LayerNorm,
    LayerNorm2d,
    GroupNorm,
    BatchNormInference,
    RmsNorm,
    ResizeNearest2x,
    Attention,
    Rope,
    AxialRope2d,
    Reshape,
    Transpose,
    Narrow,
    Concat,
    Expand,
    Gather,
    Reduce,
    Softmax,
    Cumsum,
    ArgMax,
    ArgMin,
    TopK,
    Sample,
    /// ONNX `RandomNormalLike` — shape from input 0, output filled at runtime.
    RngNormal,
    /// ONNX `RandomUniformLike`.
    RngUniform,
    Conv,
    Im2Col,
    ConvTranspose2d,
    Pool,
    ReluBackward,
    ActivationBackward,
    FakeQuantizeBackward,
    ComplexNormSq,
    ComplexNormSqBackward,
    Conjugate,
    MaxPool2dBackward,
    Conv2dBackwardInput,
    Conv2dBackwardWeight,
    SoftmaxCrossEntropyWithLogits,
    SoftmaxCrossEntropyBackward,
    AttentionBackward,
    LayerNormBackwardInput,
    LayerNormBackwardGamma,
    RmsNormBackwardInput,
    RmsNormBackwardGamma,
    // NOTE(review): `Op::RmsNorm` is documented as taking only input + gamma;
    // confirm which op produces this kind.
    RmsNormBackwardBeta,
    RopeBackward,
    GroupNormBackwardInput,
    GroupNormBackwardGamma,
    GroupNormBackwardBeta,
    BatchNormInferenceBackwardInput,
    BatchNormInferenceBackwardGamma,
    BatchNormInferenceBackwardBeta,
    CumsumBackward,
    GatherBackward,
    GroupedMatMul,
    DequantGroupedMatMul,
    DequantMoEWeights,
    ScatterAdd,
    LoraMatMul,
    DequantMatMul,
    QMatMul,
    QConv2d,
    SelectiveScan,
    GatedDeltaNet,
    Lstm,
    Gru,
    Rnn,
    Mamba2,
    FusedSwiGLU,
    FusedMatMulBiasAct,
    FusedResidualLN,
    FusedResidualRmsNorm,
    FusedAttentionBlock,
    FusedTransformerLayer,
    If,
    While,
    Scan,
    ScanBackward,
    ScanBackwardXs,
    /// CPU reference 3D Gaussian splat raster (project → bin → sort → raster).
    /// See [`Op::GaussianSplatRender`].
    GaussianSplatRender,
    /// Backward of [`Op::GaussianSplatRender`] — packed scene parameter gradients.
    GaussianSplatRenderBackward,
    /// Project + tile bin + sort + ray grid (strict IR splat stage 1).
    GaussianSplatPrepare,
    /// Per-pixel raster from prepared buffers (strict IR splat stage 2).
    GaussianSplatRasterize,
    /// User-registered op dispatched through `op_registry`. All
    /// custom ops (Sparse-LU, FFT, eigensolve, ...) share this kind;
    /// the per-op identity lives in `Op::Custom::name`.
    Custom,
    /// User-defined sub-graph with optional override AD rules. See
    /// [`Op::CustomFn`] / [`crate::Graph::custom_fn`].
    CustomFn,
    /// 1D FFT primitive (forward or inverse) — see [`Op::Fft`].
    Fft,
    /// Ternary pruned radix-2 butterfly stage — see [`Op::FftButterflyStage`].
    FftButterflyStage,
    /// Whisper-style log-mel from block-layout FFT spectrum — see [`Op::LogMel`].
    LogMel,
    /// Backward of [`Op::LogMel`] w.r.t. block-layout spectrum input 0.
    LogMelBackward,
    /// Welch PSD top-K spikes from block-layout FFT spectra — see [`Op::WelchPeaks`].
    WelchPeaks,
}
348
/// An operand inside a fused [`ChainStep`] — either a graph-level input
/// to the [`Op::ElementwiseRegion`] (by index 0..num_inputs) or the
/// result of a previous step in the chain (by index 0..step_position).
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChainOperand {
    /// External input of the region, indexed `0..num_inputs`.
    Input(u32),
    /// Result of an earlier step, indexed `0..step_position` (a step may
    /// only reference steps that precede it).
    Step(u32),
}
358
/// One step in a fused element-wise chain. Each step produces exactly
/// one scalar result (per element); later steps can refer to it via
/// [`ChainOperand::Step`]. The whole chain runs per element in registers.
///
/// Each variant mirrors the standalone op of the same name
/// ([`Op::Activation`], [`Op::Cast`], [`Op::Binary`], [`Op::Compare`],
/// [`Op::Where`]). These are the ops `MarkElementwiseRegions` folds into
/// an [`Op::ElementwiseRegion`].
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub enum ChainStep {
    /// Unary activation applied to one operand.
    Activation(Activation, ChainOperand),
    /// Cast of one operand to the given dtype.
    Cast(DType, ChainOperand),
    /// Binary op on two operands.
    Binary(BinaryOp, ChainOperand, ChainOperand),
    /// Comparison of two operands.
    Compare(CmpOp, ChainOperand, ChainOperand),
    /// 3-input element-wise select: `cond ? on_true : on_false`. Mirrors
    /// `Op::Where` inside a chain. `cond` is treated as truthy iff
    /// non-zero. Lets the optimizer fold attention masks / clamp-style
    /// patterns into a single region kernel instead of breaking the
    /// chain at the first `Op::Where`.
    Where(ChainOperand, ChainOperand, ChainOperand),
}
376
/// Pre-region memory transform fused into [`Op::ElementwiseRegion`].
///
/// The transform runs before the element-wise chain, on the external
/// input selected by the region's `prologue_input` field.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum RegionPrologue {
    /// No prologue (the default): the chain reads its inputs directly.
    #[default]
    None,
    /// Input is half-resolution NCHW; output shape is 2× H×W (nearest 2×).
    ResizeNearest2x,
}
386
/// One sampling step in [`Op::TransformRegion`].
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub enum TransformStep {
    /// Nearest-neighbor 2× upsample of the referenced operand.
    /// Presumably this has the same semantics as [`Op::ResizeNearest2x`]
    /// (NCHW, spatial dims doubled) — confirm against the lowering.
    ResizeNearest2x(ChainOperand),
}
393
394/// An operation in the RLX IR graph.
395///
396/// Operations are categorized for fusion analysis:
397/// - Element-wise ops fuse with anything reading their output
398/// - Matmul/Conv are BLAS-dispatched and form fusion boundaries
399/// - Reductions are fusion roots (drive the loop iteration)
400#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
401#[derive(Debug, Clone, PartialEq)]
402pub enum Op {
403    // ── Graph inputs ────────────────────────────────────────────
404    /// Model input with a name (shape on the Node).
405    Input {
406        name: String,
407    },
408
409    /// Model parameter (weight/bias) with a name.
410    Param {
411        name: String,
412    },
413
414    /// Constant tensor embedded in the graph.
415    Constant {
416        data: Vec<u8>,
417    },
418
419    // ── Element-wise unary ──────────────────────────────────────
420    /// Unary activation: one input, same shape output.
421    Activation(Activation),
422
423    /// Cast to a different dtype.
424    Cast {
425        to: DType,
426    },
427
428    /// Stop-gradient (a.k.a. `detach` / `lax.stop_gradient`). Forward is
429    /// identity; the reverse-mode autodiff rule returns **no** gradient
430    /// contribution for the input. Single input, same shape & dtype
431    /// output. Used to build a Gradient-Reverse-Layer with identity
432    /// forward semantics in user code (see maet-rs `dat_loss`).
433    StopGradient,
434
435    /// INT8 quantization. Input f32; output i8 same shape.
436    ///   `q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])`
437    /// where `c` selects the per-channel scale/zp when `axis = Some(d)`
438    /// (`c = idx[d]`), or always uses index 0 when `axis = None`
439    /// (per-tensor). The `scales` / `zero_points` payload length must
440    /// match `1` for per-tensor and `input.dim(d)` for per-channel.
441    /// Static — typically produced at calibration time and baked
442    /// into the loaded model. Use `Op::Dequantize` for the inverse.
443    Quantize {
444        axis: Option<usize>,
445        scales: Vec<f32>,
446        zero_points: Vec<i32>,
447    },
448
449    /// INT8 dequantization (inverse of `Op::Quantize`). Input i8;
450    /// output f32 same shape.
451    ///   `x[i] = (q[i] - zero_point[c]) · scale[c]`
452    /// where `c` is selected by `axis` exactly as in `Op::Quantize`.
453    Dequantize {
454        axis: Option<usize>,
455        scales: Vec<f32>,
456        zero_points: Vec<i32>,
457    },
458
459    /// "Fake-quantize" op for **quantization-aware training** (QAT).
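    /// Input f32; output f32 same shape. With the default
    /// `scale_mode = ScaleMode::PerBatch`, forward computes a per-axis
    /// (or per-tensor when `axis = None`) max-abs scale on the fly
    /// (see [`ScaleMode`] for the EMA / fixed-scale alternatives):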
462    ///   `s[c] = max(|x[..., c, ...]|) / q_max(bits)`
463    /// then quantizes-then-dequantizes:
464    ///   `out[i] = clamp(round(x[i] / s[c]), -q_max, q_max) * s[c]`
465    /// where `q_max` is `127` for `bits=8`, `7` for `bits=4`, `1` for
466    /// `bits=2` (ternary). Symmetric only — zero-point is always 0.
467    ///
468    /// The point of this op is to make the SGD optimizer "see" the
469    /// deployment-time rounding during training. Backward is the
470    /// **straight-through estimator** (STE): the gradient passes
471    /// through (variant chosen by `ste`), ignoring the discontinuity
472    /// at the round. Without STE the rounding would have zero
473    /// gradient almost everywhere and learning would stop.
474    ///
475    /// Inserted by the trainer on conv / FC weight tensors when
476    /// `--qat` is on; the existing `Op::Quantize` / packing path at
477    /// the end of training still handles the deployment-side
478    /// conversion to `i8`/`i4`/`i2` codes.
479    FakeQuantize {
480        bits: u8,
481        axis: Option<usize>,
482        ste: SteKind,
483        scale_mode: ScaleMode,
484    },
485
486    /// Learned Step Size Quantization (LSQ; Esser et al. 2020,
487    /// `arXiv:1902.08153`). Like `FakeQuantize` but the per-channel
488    /// `scale` is a *learned parameter*, passed as the second input.
489    /// Forward is identical to `FakeQuantize` with a fixed scale:
490    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
491    /// Backward computes both `dx` (STE) and `dscale[c]` via the
492    /// closed-form gradient:
493    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
494    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
495    /// `sign(z) · q_max`. Routinely beats per-batch and EMA at
496    /// tight bit widths (i2 / i3).
497    ///
498    /// Inputs: `[x, scale]`. `scale` is `[chan_dim]` f32 (matches
499    /// `axis`); for `axis = None` it's `[1]`.
500    FakeQuantizeLSQ {
501        bits: u8,
502        axis: Option<usize>,
503    },
504
    /// Backward pass for `Op::FakeQuantizeLSQ` w.r.t. `x` (STE).
    /// The per-channel gradient w.r.t. `scale` is computed by the
    /// companion op `Op::FakeQuantizeLSQBackwardScale`, which takes
    /// the same inputs.
    /// Inputs: `[x, scale, dy]`. Output: `dx`, same shape as `x`.
510    FakeQuantizeLSQBackwardX {
511        bits: u8,
512        axis: Option<usize>,
513    },
514
515    /// Companion to `FakeQuantizeLSQBackwardX`: computes the
516    /// `[chan_dim]` per-channel scale gradient. Inputs `[x, scale, dy]`.
517    /// Output shape matches `scale`.
518    FakeQuantizeLSQBackwardScale {
519        bits: u8,
520        axis: Option<usize>,
521    },
522
523    // ── Element-wise binary ─────────────────────────────────────
524    /// Binary op with broadcasting: two inputs, output shape is broadcast result.
525    Binary(BinaryOp),
526
527    // ── Comparison ──────────────────────────────────────────────
528    /// Element-wise comparison: two inputs, Bool output.
529    Compare(CmpOp),
530
531    /// Select elements: cond (Bool), on_true, on_false → output.
532    Where,
533
534    /// Fused element-wise region (PLAN L2). Holds an N-step chain of
535    /// element-wise operations. Inputs are referenced by index 0..num_inputs;
536    /// each step's result can be referenced by later steps via
537    /// `ChainOperand::Step(idx)`. The output is the last step's result.
538    /// Emitted by `MarkElementwiseRegions` in `rlx-opt` from chains of
539    /// Activation/Cast/Binary/Compare/Where ops with single-consumer
540    /// intermediates and broadcast-compatible shapes. Backends that
541    /// don't have a region kernel can decompose back to the original
542    /// chain via `unfuse::unfuse_elementwise_regions`.
543    ///
544    /// `scalar_input_mask` is a per-input bitfield (bit `i` set ⇒
545    /// input `i` is a scalar broadcast — has shape `[1]`). Kept as a
546    /// fast-path indicator that lets kernels skip the modulo entirely
547    /// when they detect a scalar.
548    ///
549    /// `input_modulus[i]` is the per-input element count, used by
550    /// kernels to compute `arena[input_offs[i] + (gid % input_modulus[i])]`
551    /// — the trailing-shape broadcast pattern. `0` means "no broadcast"
552    /// (input matches the output element count; kernel reads `gid`
553    /// directly). `1` means scalar; any other value means the input
554    /// has fewer elements than the output and they tile by modulo.
555    /// The encoder only allows broadcasts where `out_elems % in_elems
556    /// == 0` so the modulo divides cleanly. Lets chains include bias /
557    /// scale / eps / mask factors that previously broke the chain at
558    /// a Binary op with mismatched shapes.
559    ElementwiseRegion {
560        chain: Vec<ChainStep>,
561        num_inputs: u32,
562        scalar_input_mask: u32,
563        input_modulus: [u32; 16],
564        /// FKL-style closed fusion: apply before the element-wise chain.
565        prologue: RegionPrologue,
566        /// External input index that supplies the prologue transform source (default 0).
567        prologue_input: u32,
568    },
569
570    /// Fused transform chain (resize, future crop/color). Decompose via
571    /// [`rlx_fusion::DecomposeFusionRegions`] when no native kernel exists.
572    TransformRegion {
573        steps: Vec<TransformStep>,
574        num_inputs: u32,
575    },
576
577    /// Identical [`Op::ElementwiseRegion`] chain over `num_batch_inputs` tensors
578    /// (horizontal / z-plane fusion). Inputs are separate batch slices.
579    BatchElementwiseRegion {
580        chain: Vec<ChainStep>,
581        num_batch_inputs: u32,
582        scalar_input_mask: u32,
583        input_modulus: [u32; 16],
584        prologue: RegionPrologue,
585        prologue_input: u32,
586    },
587
588    // ── Linear algebra ──────────────────────────────────────────
589    /// Matrix multiply. Inputs: [.., M, K] × [.., K, N] → [.., M, N].
590    /// Batch dimensions are broadcast.
591    MatMul,
592
593    /// Matrix multiply with explicit dimension specification.
594    /// Like XLA's DotGeneral — handles arbitrary batch/contracting dims.
595    DotGeneral {
596        lhs_contracting: Vec<usize>,
597        rhs_contracting: Vec<usize>,
598        lhs_batch: Vec<usize>,
599        rhs_batch: Vec<usize>,
600    },
601
602    /// Batched dense linear solve. Inputs: `A [B, N, N]`,
603    /// `b [B, N]` or `b [B, N, K]`. Output: same shape as `b`.
604    ///
605    /// Per-batch independent solve — each `A[i]` and `b[i]` are
606    /// solved as a separate `Op::DenseSolve`. Emitted by vmap of
607    /// `Op::DenseSolve`. The CPU lowering loops over the batch
608    /// dimension calling `dgesv` per slice (LAPACK doesn't expose a
609    /// batched solve on Accelerate; cuSOLVER does on NVIDIA).
610    BatchedDenseSolve,
611
612    /// Dense linear solve `x = A⁻¹ · b` via LU factorization.
613    /// Inputs: `A [N, N]`, `b [N]` (or `b [N, K]` for multi-RHS).
614    /// Output: same shape as `b`.
615    ///
616    /// VJP via the implicit-function theorem:
617    ///   `dx = solve(Aᵀ, upstream)`
    ///   `dA = -outer(dx, x)`   (x is the forward output; for
    ///                           multi-RHS `b [N, K]` this is `-dx · xᵀ`)
619    ///   `db = dx`
620    /// The rule is dtype-agnostic; lowering is per-backend (Accelerate
621    /// `dgesv` / `sgesv`, cuSOLVER, etc.).
622    DenseSolve,
623
624    // ── Normalization ───────────────────────────────────────────
625    /// Layer normalization: input, gamma, beta → normalized output.
626    /// `axis` is the feature dimension (usually -1).
627    LayerNorm {
628        axis: i32,
629        eps: f32,
630    },
631
632    /// Group normalization on NCHW tensors: `input`, `gamma`, `beta` → same shape.
633    /// Normalizes over `(C/num_groups) × H × W` per group.
634    GroupNorm {
635        num_groups: usize,
636        eps: f32,
637    },
638
639    /// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
640    /// position (candle / SAM `LayerNorm2d` semantics — not PyTorch's H×W norm).
641    LayerNorm2d {
642        eps: f32,
643    },
644
645    /// Nearest-neighbor 2× upsample on NCHW (doubles spatial dims 2 and 3).
646    ResizeNearest2x,
647
648    /// RMS normalization: input, gamma → normalized output.
649    RmsNorm {
650        axis: i32,
651        eps: f32,
652    },
653
654    /// BatchNorm inference with frozen running statistics.
655    /// Inputs: `x`, `gamma`, `beta`, `running_mean`, `running_var`.
656    /// Feature dimension is the last axis of `x`; stats are 1-D `[C]`.
657    BatchNormInference {
658        eps: f32,
659    },
660
661    // ── Attention ───────────────────────────────────────────────
662    /// Scaled dot-product attention: Q, K, V, \[mask\] → output.
663    /// The compiler can lower this to fused SDPA or flash attention.
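    /// `mask_kind` controls how masking is applied — `Custom` reads from
    /// the 4th input tensor; `Bias` adds a `[batch, num_heads, query_len,
    /// key_len]` tensor to the scaled scores before softmax (presumably
    /// also supplied as the 4th input); `None` / `Causal` / `SlidingWindow`
    /// skip the mask load and apply the mask directly in the inner loop.
    /// See `MaskKind` for the rationale.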
668    ///
669    /// `score_scale`: when `Some(s)`, dot-product scores are multiplied by
670    /// `s` instead of the default `1/sqrt(head_dim)` (Gemma uses `head_dim^-0.5`
671    /// explicitly in config). `attn_logit_softcap`: when `Some(c)`, applies
672    /// `c * tanh(s/c)` to scores before softmax (Gemma 2).
673    Attention {
674        num_heads: usize,
675        head_dim: usize,
676        mask_kind: MaskKind,
677        score_scale: Option<f32>,
678        attn_logit_softcap: Option<f32>,
679    },
680
681    /// Rotary position embedding applied to one tensor: x, cos, sin → x_rotated.
682    /// Apply separately to Q and K. `head_dim` is the per-head width; `n_rot`
683    /// is how many leading dims get NeoX RoPE (pair offset `n_rot/2`). When
684    /// `n_rot < head_dim`, trailing dims are copied unchanged (Qwen3.5 MRoPE).
685    Rope {
686        head_dim: usize,
687        n_rot: usize,
688    },
689
690    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
691    AxialRope2d {
692        end_x: usize,
693        end_y: usize,
694        head_dim: usize,
695        num_heads: usize,
696        theta: f32,
697        repeat_factor: usize,
698    },
699
700    // ── Shape manipulation ──────────────────────────────────────
701    Reshape {
702        new_shape: Vec<i64>,
703    },
704    Transpose {
705        perm: Vec<usize>,
706    },
707    /// Select a contiguous slice along an axis.
708    Narrow {
709        axis: usize,
710        start: usize,
711        len: usize,
712    },
713    /// Concatenate along an axis.
714    Concat {
715        axis: usize,
716    },
717    /// Expand (broadcast) to a target shape.
718    Expand {
719        target_shape: Vec<i64>,
720    },
721    /// Gather elements by index along an axis (embedding lookup).
722    Gather {
723        axis: usize,
724    },
725
726    // ── Reduction ───────────────────────────────────────────────
727    /// Reduce along specified axes.
728    Reduce {
729        op: ReduceOp,
730        axes: Vec<usize>,
731        keep_dim: bool,
732    },
733
734    /// Selective scan (plan #15) — Mamba-style state-space model
735    /// step. The recurrence:
736    ///   `h[t] = exp(Δ[t] * A) * h[t-1] + Δ[t] * B[t] * x[t]`
737    ///   `y[t] = C[t] * h[t]`
738    /// where state `h` has dimension `state_size` and the input has
739    /// `(batch, seq, hidden)`.
740    ///
741    /// Inputs (in order):
742    ///   `x [b, s, h]`      f32 input
743    ///   `delta [b, s, h]`  f32 step size (per-position, per-channel)
744    ///   `a [h, n]`         f32 transition matrix (one per channel)
745    ///   `b [b, s, n]`      f32 input projection
746    ///   `c [b, s, n]`      f32 output projection
747    /// Output: `[b, s, h]` f32. State `h` is implicit; the kernel
748    /// scans through the seq dimension carrying it.
749    ///
750    /// `state_size` = `n` is exposed for the cost model.
751    SelectiveScan {
752        state_size: usize,
753    },
754
755    /// Gated DeltaNet linear-attention recurrence — the per-layer
756    /// kernel used by Qwen3.5/3.6 trunk "linear attention" blocks
757    /// (and Qwen3-Next, Kimi-Linear). Mirrors
758    /// `llama.cpp / src/models/delta-net-base.cpp` autoregressive
759    /// path; chunked + fused variants ride the same op identity.
760    ///
761    /// **Math (per token `t`, head `h`, state size `n`):**
762    /// state matrix `S[h, i, j]` is implicit (reset per batch).
763    /// ```text
764    ///   S[h]     *= exp(g[t,h])                     # scalar gate
765    ///   sk[h,j]   = Σ_i S[h,i,j] * k[t,h,i]
766    ///   d[h,j]    = (v[t,h,j] - sk[h,j]) * b[t,h]   # b = beta
767    ///   S[h,i,j] += k[t,h,i] * d[h,j]               # outer-prod
768    ///   o[t,h,j]  = Σ_i S[h,i,j] * (q[t,h,i] / √n)
769    /// ```
770    ///
771    /// Inputs:
772    ///   `q     [b, s, h_v, n]`  f32 queries (L2-normed by caller)
773    ///   `k     [b, s, h_v, n]`  f32 keys    (L2-normed by caller;
774    ///                            GQA-repeated to match `h_v`)
775    ///   `v     [b, s, h_v, n]`  f32 values
776    ///   `g     [b, s, h_v]`     f32 log-gate (exp'd inside kernel)
777    ///   `beta  [b, s, h_v]`     f32 delta-rule mixing factor
778    ///
779    /// Output: `[b, s, h_v, n]` f32.
780    ///
781    /// When `carry_state` is true, a sixth input `state [b, h_v, n, n]`
782    /// provides the initial SSM matrix per head; the kernel updates it
783    /// in place across the sequence and leaves the final state in the
784    /// same buffer (same layout as the internal scan state:
785    /// `state[h, i, j]` row-major over `(n, n)` per head).
786    GatedDeltaNet {
787        state_size: usize,
788        carry_state: bool,
789    },
790
791    /// Multi-layer (optionally bidirectional) LSTM over a
792    /// `[batch, seq, input]` sequence. Gate order i, f, g, o (PyTorch);
793    /// recurrence per step:
794    /// ```text
795    ///   z      = x_t · w_ihᵀ + h_{t-1} · w_hhᵀ + bias
796    ///   i,f,o  = σ(z_i), σ(z_f), σ(z_o);   g = tanh(z_g)
797    ///   c_t    = f · c_{t-1} + i · g;      h_t = o · tanh(c_t)
798    /// ```
799    /// `D = 2 if bidirectional else 1`. Inputs `[x, w_ih, w_hh, bias]`
800    /// (`+ [h0, c0]` when `carry`):
801    ///   * `x`:    `[batch, seq, input]`
802    ///   * `w_ih`: packed, all `(layer, direction)` blocks concatenated in
803    ///     `layer`-major then `direction` order. Block `(l,d)` is
804    ///     `[4*hidden, in_l]` with `in_l = input` for `l=0` else `D*hidden`.
805    ///   * `w_hh`: packed `L*D` blocks of `[4*hidden, hidden]`.
806    ///   * `bias`: packed `L*D` blocks of `[4*hidden]` (combined `b_ih+b_hh`).
807    ///   * `h0`, `c0` (carry only): `[L*D, batch, hidden]`; the final
808    ///     `hn`/`cn` are written back **in place** (decode threading).
809    ///
810    /// Output `y`: `[batch, seq, D*hidden]` — last layer's hidden states
811    /// (forward ‖ reverse concatenated on the feature axis). With
812    /// `num_layers = 1, bidirectional = false, carry = false` this is the
813    /// plain single-layer LSTM and the weight shapes reduce to
814    /// `[4*hidden, input]`, `[4*hidden, hidden]`, `[4*hidden]`.
815    Lstm {
816        hidden_size: usize,
817        num_layers: usize,
818        bidirectional: bool,
819        carry: bool,
820    },
821
822    /// Gated Recurrent Unit (PyTorch). `D = 2 if bidirectional else 1`;
823    /// gate order r, z, n. Recurrence:
824    /// ```text
825    ///   r = σ(x·W_irᵀ + b_ir + h·W_hrᵀ + b_hr)
826    ///   z = σ(x·W_izᵀ + b_iz + h·W_hzᵀ + b_hz)
827    ///   n = tanh(x·W_inᵀ + b_in + r ⊙ (h·W_hnᵀ + b_hn))
828    ///   h' = (1 - z) ⊙ n + z ⊙ h
829    /// ```
830    /// The new-gate reset is applied to the hidden term *after* its bias,
831    /// so `b_ih`/`b_hh` cannot be merged. Inputs `[x, w_ih, w_hh, b_ih, b_hh]`
832    /// (`+ [h0]` when carry); packing matches [`Op::Lstm`] with `3*hidden`
833    /// gate rows. `h0` `[L*D, batch, hidden]`. Output `[batch, seq, D*hidden]`.
834    Gru {
835        hidden_size: usize,
836        num_layers: usize,
837        bidirectional: bool,
838        carry: bool,
839    },
840
841    /// Elman RNN (PyTorch): `h' = act(x·w_ihᵀ + h·w_hhᵀ + bias)` with
842    /// `act = ReLU` when `relu` else `tanh`. Packed `w_ih` `[hidden, in_l]`,
843    /// `w_hh` `[hidden, hidden]`, merged `bias` `[hidden]` (per layer ×
844    /// direction). Inputs `[x, w_ih, w_hh, bias]` (`+ [h0]` when carry).
    /// Output `[batch, seq, D*hidden]`, where `D = 2 if bidirectional else 1`.
846    Rnn {
847        hidden_size: usize,
848        num_layers: usize,
849        bidirectional: bool,
850        carry: bool,
851        relu: bool,
852    },
853
854    /// Mamba-2 / SSD (structured state-space duality) scan — the
855    /// scalar-decay SSM at the core of Mamba-2, sibling of
856    /// [`Op::SelectiveScan`] / [`Op::GatedDeltaNet`]. Inputs
857    /// `[x, dt, a, b, c]` (all `f32`):
858    ///   * `x`:  `[batch, seq, heads, head_dim]`
859    ///   * `dt`: `[batch, seq, heads]` — discretization step (already ≥ 0,
860    ///     e.g. softplus'd by the caller)
861    ///   * `a`:  `[heads]` — per-head decay rate (`dA = exp(dt·a)`)
862    ///   * `b`:  `[batch, seq, heads, state_size]`
863    ///   * `c`:  `[batch, seq, heads, state_size]`
864    ///
865    /// State `S [batch, heads, head_dim, state_size]` is zero-initialized
866    /// per sequence. Recurrence per timestep `t`:
867    /// ```text
868    ///   dA_t = exp(dt_t · a)
869    ///   S_t  = dA_t · S_{t-1} + (dt_t · x_t) ⊗ b_t
870    ///   y_t  = Σ_n S_t[:, n] · c_t[n]
871    /// ```
872    /// Output `y`: `[batch, seq, heads, head_dim]` (same shape as `x`).
873    Mamba2 {
874        head_dim: usize,
875        state_size: usize,
876    },
877
878    /// Fused dequant + matmul (plan #5). The biggest LLM-bandwidth
879    /// win on Apple Silicon: dequantizes weights inside the matmul
880    /// inner loop, never materializing f32 weights.
881    ///
882    /// **BREAKING CHANGE in 0.2.0:** `num_inputs()` is now
883    /// scheme-dependent — **4** for legacy Int8 schemes, **2** for
884    /// the new GGUF K-quant schemes (their scales/mins live inside
885    /// the packed bytes, so no side-channel `scale` / `zp` tensors
886    /// are fed in). Callers that assumed a fixed 4-input contract
887    /// must inspect `scheme.is_gguf()` before reading inputs.
888    ///
889    /// Inputs (Int8 schemes — `scheme.is_gguf() == false`):
890    ///   `x [m, k]`             f32 activations
891    ///   `w_q [k, n]` packed    quantized weight bytes (i8 per
892    ///                          element for Int8 schemes; 4-bit
893    ///                          packed two-per-byte for Int4)
894    ///   `scale [k/block, n]`   per-block f32 dequant scale
895    ///   `zp    [k/block, n]`   per-block f32 zero-point
896    ///                          (zero-tensor if symmetric)
897    ///
898    /// Inputs (`Nvfp4Block` — fixed group size 16 along K):
899    ///   `x [m, k]`             f32 activations
900    ///   `w_q [k,n/2]` packed   FP4 E2M1 codes (unsigned nibble 0..15)
901    ///   `scale [k/16, n]` u8   FP8 E4M3 block scales (one byte / group)
902    ///   `global_scale [1]` f32 per-tensor scale (pass `[1.0]` if unused)
903    ///
904    /// Inputs (GGUF schemes — `scheme.is_gguf() == true`):
905    ///   `x [m, k]`             f32 activations
906    ///   `packed_w [bytes]`     raw GGUF super-block bytes; the
907    ///                          dequantizer reads the per-sub-block
908    ///                          scales / mins / quants directly out
909    ///                          of the buffer per the K-quant block
910    ///                          layout (no side tensors).
911    ///
912    /// Output: `[m, n]` f32.
913    ///
914    /// `block_size` (on the Int8 schemes only) is the number of
915    /// consecutive elements that share one (scale, zero_point) pair.
916    /// The Op carries enough metadata that the kernel doesn't need
917    /// a separate `QuantMap` lookup at run time.
918    DequantMatMul {
919        scheme: crate::quant::QuantScheme,
920    },
921
922    /// Real INT8-arithmetic matrix multiply with i32 accumulation.
923    /// Inputs (in order):
924    ///   `x      [M, K]`  i8 activations (zero-point = `x_zp`)
925    ///   `w      [K, N]`  i8 weights     (zero-point = `w_zp`)
926    ///   `bias   [N]`     i32 (in accumulator scale = `x_scale·w_scale`),
927    ///                    pass a zeros tensor for "no bias"
928    /// Output:  `[M, N]`  i8 (zero-point = `out_zp`)
929    ///
930    /// Per-element compute:
931    ///   `out[m,n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
932    /// where `mult = x_scale · w_scale / out_scale`.
933    ///
934    /// This is the same kernel shape `rlx-cortexm/src/dense.rs`
935    /// uses for on-device int8 inference, lifted into the IR so the
936    /// rlx-cpu backend can run a quantized graph directly (instead
937    /// of round-tripping through fake-quant Dequantize → MatMul →
938    /// Quantize). 2-D only — generalizing to batched comes when a
939    /// real workload demands it.
940    QMatMul {
941        x_zp: i32,
942        w_zp: i32,
943        out_zp: i32,
944        mult: f32,
945    },
946
947    /// Real INT8-arithmetic 2-D convolution with i32 accumulation.
948    /// Inputs:
949    ///   `x      [N, C_in, H, W]`              i8 (zero-point = `x_zp`)
950    ///   `w      [C_out, C_in/groups, kH, kW]` i8 (zero-point = `w_zp`)
951    ///   `bias   [C_out]`                      i32 in accumulator scale
952    /// Output: `[N, C_out, H_out, W_out]` i8 (zero-point = `out_zp`).
953    /// Same NCHW geometry contract as `Op::Conv`; same requantize
954    /// math as `Op::QMatMul` (per-element `acc·mult` rounded to i8).
955    QConv2d {
956        kernel_size: Vec<usize>,
957        stride: Vec<usize>,
958        padding: Vec<usize>,
959        dilation: Vec<usize>,
960        groups: usize,
961        x_zp: i32,
962        w_zp: i32,
963        out_zp: i32,
964        mult: f32,
965    },
966
967    /// Fused LoRA matmul: `out = x·W + scale * x·A·B`.
968    /// Inputs (in order): `x [m, k]`, `w [k, n]`, `a [k, r]`, `b [r, n]`.
969    /// `r` is the LoRA rank (typically 4-64). `scale` is the
970    /// per-adapter `alpha / rank` knob.
971    /// Plan #9: lifts LoRA from "three matmuls + an add" into one
972    /// kernel that keeps the rank-r intermediate in registers.
973    LoraMatMul {
974        scale: f32,
975    },
976
977    /// Fused sampling kernel: logits → optional top-k filter →
978    /// optional top-p truncation → softmax → multinomial sample.
979    /// One f32-encoded sampled token id per batch row (output
980    /// shape `[batch]`).
981    ///
    /// `temperature == 1.0` is neutral (samples from the unscaled
    /// softmax of the logits); lower → sharper (approaching argmax),
    /// higher → flatter. `top_k == 0` disables.
    /// `top_p == 1.0` disables. `seed` is the Philox seed; pass 0
    /// to fall back to an internal counter (still deterministic
    /// given the call order).
987    /// Borrowed from MAX's nn/sampling.mojo (#42 in PLAN.md).
988    /// Latency-critical: never materializes the full softmax
989    /// distribution on the host.
990    Sample {
991        top_k: usize,     // 0 = disabled
992        top_p: f32,       // 1.0 = disabled
993        temperature: f32, // 1.0 = neutral
994        seed: u64,        // 0 = use thread-local counter
995    },
996
997    /// ONNX `RandomNormalLike` / `RandomNormal`: zero or one shape-template
998    /// input (Like uses the template's shape; `RandomNormal` with a `shape`
    /// attribute needs no input). Output shape is fixed on the node at
    /// compile/execute time. Optional ONNX `seed` attribute (f32) overrides
1001    /// the mixed seed on the Ort backend.
1002    RngNormal {
1003        mean: f32,
1004        scale: f32,
1005        key: u64,
1006        op_seed: Option<f32>,
1007    },
1008
1009    /// ONNX `RandomUniformLike`.
1010    RngUniform {
1011        low: f32,
1012        high: f32,
1013        key: u64,
1014        op_seed: Option<f32>,
1015    },
1016
    /// Cumulative sum along an axis — inclusive by default (see
    /// `exclusive` below). Same shape in/out.
1018    /// Underpins ragged-tensor offsets, sampling (top-p prefix sum),
1019    /// and sequence-position math (#44 in PLAN.md).
1020    /// `exclusive=true` shifts the result so output\[0\] = 0 (useful
1021    /// for offset arrays where the first segment starts at 0).
1022    Cumsum {
1023        axis: i32,
1024        exclusive: bool,
1025    },
1026
1027    /// Index of the maximum along `axis`. Output drops that axis (or keeps it
1028    /// as size 1 when `keep_dim`). Indices are **f32-encoded** (rlx is f32 at
1029    /// the I/O boundary, matching [`Op::TopK`]).
1030    ArgMax {
1031        axis: usize,
1032        keep_dim: bool,
1033    },
1034
1035    /// Index of the minimum along `axis`; see [`Op::ArgMax`].
1036    ArgMin {
1037        axis: usize,
1038        keep_dim: bool,
1039    },
1040
1041    /// Softmax along an axis (reduction + element-wise).
1042    Softmax {
1043        axis: i32,
1044    },
1045
1046    /// Top-K **indices** along the last axis. Output shape `[..., k]`,
1047    /// f32-encoded indices (rlx is f32-only at the I/O boundary).
1048    /// To recover the values, follow with a `Gather` against the
1049    /// original tensor — works because Gather already supports any axis.
1050    /// Ties broken by smaller index (matches NumPy / PyTorch
1051    /// `torch.topk(..., largest=True, sorted=True)`).
1052    /// Used by MoE gating; also useful for beam search.
1053    TopK {
1054        k: usize,
1055    },
1056
1057    /// Indexed batched matmul. The MoE GEMM primitive.
1058    /// Inputs: `[input, weight, expert_idx]`
1059    ///   input       : [M, K]                — per-token activations
1060    ///   weight      : [num_experts, K, N]   — stacked expert weights
1061    ///   expert_idx  : \[M\]                   — f32-encoded expert id per token
1062    /// Output         : [M, N]                — output\[i\] = input\[i\] @ weight[expert_idx\[i\]]
1063    /// Naive impl on both backends; future work can replace with a
1064    /// segmented/grouped GEMM when there's a real workload.
1065    GroupedMatMul,
1066
1067    /// Fused GGUF K-quant dequant + [`Op::GroupedMatMul`]. Same three
1068    /// inputs as `GroupedMatMul`, but `weight` is a U8 tensor holding
1069    /// `num_experts` contiguous packed expert slabs (GGML layout, expert
1070    /// dimension outermost). Scales live inside the packed bytes.
1071    DequantGroupedMatMul {
1072        scheme: crate::quant::QuantScheme,
1073    },
1074
1075    /// Dequant a packed MoE expert stack to F32 `[num_experts, K, N]` in
1076    /// GroupedMatMul layout. Input: U8 packed bytes; output shape is
1077    /// declared on the node (`[E, K, N]`).
1078    DequantMoEWeights {
1079        scheme: crate::quant::QuantScheme,
1080    },
1081
1082    /// Scatter-add into a destination tensor. The "unpermute" half of
1083    /// MoE routing (also useful for embedding gradient updates).
1084    /// Inputs: `[updates, indices]`
1085    ///   updates : [num_updates, trailing]   — values to add
1086    ///   indices : \[num_updates\]             — f32-encoded destination row
1087    /// Output    : [out_dim, trailing]       — output[indices\[i\]] += updates\[i\]
1088    /// `out_dim` is taken from the node's declared output shape.
1089    /// Initial output is zero; multiple updates to the same row
1090    /// accumulate (sequentially on CPU; with atomic-add on Metal).
1091    ScatterAdd,
1092
1093    // ── Convolution ─────────────────────────────────────────────
1094    /// 2D convolution on NCHW tensors. Also exposed as [`OpKind::Conv`] / `conv2d`.
1095    /// Weight layout: `[C_out, C_in / groups, kH, kW]`.
1096    Conv {
1097        kernel_size: Vec<usize>,
1098        stride: Vec<usize>,
1099        padding: Vec<usize>,
1100        dilation: Vec<usize>,
1101        groups: usize,
1102    },
1103
1104    /// NCHW im2col for conv backward-weight style matmul.
1105    /// Input `[N, C, H, W]`. Output `[M, C·kH·kW]` with
1106    /// `M = N · H_out · W_out`. When batch is [`dynamic::sym::BATCH`],
1107    /// output rows use [`dynamic::sym::ROWS`] (bind `N · H_out · W_out`).
1108    Im2Col {
1109        kernel_size: Vec<usize>,
1110        stride: Vec<usize>,
1111        padding: Vec<usize>,
1112        dilation: Vec<usize>,
1113    },
1114
1115    /// 2D transposed convolution on NCHW. Weight layout (PyTorch):
1116    /// `[C_in, C_out / groups, kH, kW]`.
1117    ConvTranspose2d {
1118        kernel_size: Vec<usize>,
1119        stride: Vec<usize>,
1120        padding: Vec<usize>,
1121        dilation: Vec<usize>,
1122        output_padding: Vec<usize>,
1123        groups: usize,
1124    },
1125
1126    // ── Pooling ─────────────────────────────────────────────────
1127    Pool {
1128        kind: ReduceOp,
1129        kernel_size: Vec<usize>,
1130        stride: Vec<usize>,
1131        padding: Vec<usize>,
1132    },
1133
1134    // ── Backward / training ops ─────────────────────────────────
1135    //
1136    // Closed-form gradient nodes emitted by `rlx-opt::autodiff`.
1137    // Pairing a forward op with a dedicated backward op (instead of
1138    // composing it from primitives) keeps the gradient kernel simple
1139    // and lets the backend recompute argmax / masks / softmax inline.
1140    /// ReLU backward: `dx = dy where x > 0 else 0`.
1141    /// Inputs: `[x, dy]` — both same shape; output matches.
1142    ReluBackward,
1143
1144    /// Element-wise complex squared-magnitude: `|z|² = z.re² + z.im²`.
1145    /// Input: 1 tensor with `DType::C64`. Output: same shape but
1146    /// `DType::F32`. The natural real-valued loss surface for
1147    /// Wirtinger reverse-mode AD on complex graphs — pair with
1148    /// [`Op::ComplexNormSqBackward`].
1149    ComplexNormSq,
1150
1151    /// Element-wise complex conjugate: `z̄ = z.re - i·z.im` per element.
1152    /// Input: 1 tensor with `DType::C64`. Output: same shape, same dtype.
1153    /// Used by Wirtinger VJP rules on `Op::Binary` over C64 (the rule
1154    /// for `y = a·b` is `dL/dā = upstream · conj(b)`, etc.).
1155    Conjugate,
1156
1157    /// Backward for [`Op::ComplexNormSq`] under Wirtinger calculus.
1158    /// `f(z) = |z|² = z·z̄`, so `∂f/∂z̄ = z`. Given upstream real
1159    /// cotangent `g` (same shape as the forward output), the C64
1160    /// gradient with respect to `z` is `g · z` element-wise, returned
1161    /// in C64 storage `[re_g·re_z, re_g·im_z]` per element.
1162    ///
1163    /// Inputs: `[z (C64), g (F32)]` — both same logical shape; output
1164    /// matches `z` (C64).
1165    ComplexNormSqBackward,
1166
1167    /// LayerNorm backward w.r.t. the input. Computes
1168    ///   `d_x[..., d] = inv_std · (dy·γ - mean(dy·γ) - x̂·mean(dy·γ·x̂))`
1169    /// over the feature axis, where `x̂ = (x - mean)/std` is recomputed
1170    /// inline from `x`. Inputs: `[x, gamma, dy]`; output shape = `x.shape`.
1171    /// Currently lowers axis=-1 only (matches the forward thunk).
1172    LayerNormBackwardInput {
1173        axis: i32,
1174        eps: f32,
1175    },
1176
1177    /// LayerNorm backward w.r.t. gamma. Computes
1178    ///   `d_gamma[d] = Σ_{batch} dy[..., d] · x̂[..., d]`
1179    /// — sums the per-element product of upstream and the (recomputed)
1180    /// normalized input over the leading axes. Inputs: `[x, dy]`;
1181    /// output shape = `gamma.shape` (= 1-D feature axis).
1182    LayerNormBackwardGamma {
1183        axis: i32,
1184        eps: f32,
1185    },
1186
1187    /// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`; output = `x.shape`.
1188    RmsNormBackwardInput {
1189        axis: i32,
1190        eps: f32,
1191    },
1192
1193    /// RMSNorm backward w.r.t. gamma. Inputs `[x, gamma, beta, dy]`; output = `gamma.shape`.
1194    RmsNormBackwardGamma {
1195        axis: i32,
1196        eps: f32,
1197    },
1198
1199    /// RMSNorm backward w.r.t. beta. Inputs `[x, gamma, beta, dy]`; output = `beta.shape`.
1200    RmsNormBackwardBeta {
1201        axis: i32,
1202        eps: f32,
1203    },
1204
1205    /// RoPE backward w.r.t. `x`. Inputs `[dy, cos, sin]`; output = `dy.shape`.
1206    RopeBackward {
1207        head_dim: usize,
1208        n_rot: usize,
1209    },
1210
1211    /// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
1212    GroupNormBackwardInput {
1213        num_groups: usize,
1214        eps: f32,
1215    },
1216
1217    /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`; output = `gamma.shape`.
1218    GroupNormBackwardGamma {
1219        num_groups: usize,
1220        eps: f32,
1221    },
1222
1223    /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`; output = `beta.shape`.
1224    GroupNormBackwardBeta {
1225        num_groups: usize,
1226        eps: f32,
1227    },
1228
1229    /// BatchNorm inference backward w.r.t. `x`. Inputs `[x, gamma, mean, var, dy]`.
1230    BatchNormInferenceBackwardInput {
1231        eps: f32,
1232    },
1233
1234    /// BatchNorm inference backward w.r.t. `gamma`. Inputs `[x, mean, var, dy]`.
1235    BatchNormInferenceBackwardGamma {
1236        eps: f32,
1237    },
1238
1239    /// BatchNorm inference backward w.r.t. `beta`. Inputs `[dy]`; output = `beta.shape`.
1240    BatchNormInferenceBackwardBeta,
1241
1242    /// Cumsum backward along `axis`. Inputs `[dy]`; output matches forward input shape.
1243    CumsumBackward {
1244        axis: i32,
1245        exclusive: bool,
1246    },
1247
1248    /// Gather backward (scatter-add into table). Inputs `[dy, indices]`; output = table shape.
1249    /// `axis` matches forward [`Op::Gather`].
1250    GatherBackward {
1251        axis: i32,
1252    },
1253
1254    /// Generic element-wise activation backward. `kind` selects the
1255    /// closed-form derivative `d/dx act(x)`. Inputs: `[x, dy]`; output
1256    /// shape matches `x`. The kernel computes `d/dx · dy` per element.
1257    ///
1258    /// Closed forms (all element-wise):
1259    /// * `Gelu`     — exact derivative of erf-based GELU.
1260    /// * `GeluApprox` — derivative of the tanh approximation
1261    ///   `0.5 x (1 + tanh(√(2/π)(x + 0.044715 x³)))`.
1262    /// * `Silu`     — `σ(x)·(1 + x·(1 - σ(x)))`.
1263    /// * `Sigmoid`  — `σ(x)·(1 - σ(x))`.
1264    /// * `Tanh`     — `1 - tanh(x)²`.
1265    /// * `Exp`      — `exp(x)`.
1266    /// * `Log`      — `1 / x`.
1267    /// * `Sqrt`     — `0.5 / sqrt(x)`.
1268    /// * `Rsqrt`    — `-0.5 · x^(-3/2)`.
1269    /// * `Neg`      — `-1`.
1270    /// * `Abs`      — `sign(x)` (zero at x=0).
1271    /// * `Sin`      — `cos(x)`.
1272    /// * `Cos`      — `-sin(x)`.
1273    /// * `Tan`      — `1 + tan²(x)`.
1274    /// * `Atan`     — `1 / (1 + x²)`.
1275    /// * `Relu`     — kept here for completeness; the dedicated
1276    ///   `ReluBackward` op is preferred for relu and is what the
1277    ///   autodiff pass actually emits.
1278    ActivationBackward {
1279        kind: Activation,
1280    },
1281
1282    /// Backward for `Op::FakeQuantize` under a non-default STE.
1283    /// Inputs `[x, dy]`: the forward input and the upstream
1284    /// gradient. Output `dx` same shape. The `bits`/`axis`/`ste`
1285    /// fields must match the forward op so the kernel computes the
1286    /// same per-channel scale and applies the right STE attenuation.
1287    /// For `SteKind::Identity` this op is unnecessary — autodiff
1288    /// just routes `upstream` through unchanged.
1289    FakeQuantizeBackward {
1290        bits: u8,
1291        axis: Option<usize>,
1292        ste: SteKind,
1293    },
1294
1295    /// 2D max-pool backward. Routes each element of `dy` back into the
1296    /// position in `x`'s window where the forward max was taken.
1297    /// Inputs: `[x, dy]` with `x [N, C, H, W]` and
1298    /// `dy [N, C, H_out, W_out]`. Output: same shape as `x`.
1299    /// Carries the forward pool's geometry so the kernel can recompute
1300    /// the argmax position per window without a saved-indices tensor.
1301    MaxPool2dBackward {
1302        kernel_size: Vec<usize>,
1303        stride: Vec<usize>,
1304        padding: Vec<usize>,
1305    },
1306
1307    /// 2D conv backward w.r.t. input. Computes `dx = conv_transpose(dy, w)`.
1308    /// Inputs: `[dy, w]` with `dy [N, C_out, H_out, W_out]` and
1309    /// `w [C_out, C_in/groups, kH, kW]`. Output: `[N, C_in, H, W]`
1310    /// (declared on the node — caller knows the original input shape).
1311    /// Geometry is the forward conv's parameters, not the transposed
1312    /// conv's.
1313    Conv2dBackwardInput {
1314        kernel_size: Vec<usize>,
1315        stride: Vec<usize>,
1316        padding: Vec<usize>,
1317        dilation: Vec<usize>,
1318        groups: usize,
1319    },
1320
1321    /// 2D conv backward w.r.t. weight. Computes
1322    /// `dw[c_out, c_in, kh, kw] = sum_{n,h_out,w_out} x[n,c_in,...] * dy[n,c_out,h_out,w_out]`.
1323    /// Inputs: `[x, dy]`. Output: `[C_out, C_in/groups, kH, kW]`.
1324    Conv2dBackwardWeight {
1325        kernel_size: Vec<usize>,
1326        stride: Vec<usize>,
1327        padding: Vec<usize>,
1328        dilation: Vec<usize>,
1329        groups: usize,
1330    },
1331
1332    /// Fused softmax + cross-entropy loss with integer (f32-encoded)
1333    /// targets — the standard classification loss. Per-row output:
1334    /// `loss[n] = -log(softmax(logits[n])[labels[n]])`.
1335    /// Inputs: `[logits, labels]` with `logits [N, C]` and
1336    /// `labels [N]` (f32-encoded class indices). Output: `[N]`.
1337    /// Caller does the `Reduce::Mean` if they want a scalar.
1338    SoftmaxCrossEntropyWithLogits,
1339
1340    /// Backward of the fused loss above. Emits
1341    /// `dlogits[n,c] = (softmax(logits[n])[c] - one_hot(labels)[n,c]) * d_loss[n]`.
1342    /// Inputs: `[logits, labels, d_loss]`. Output: `[N, C]` (same shape
1343    /// as `logits`). Recomputes the softmax inline rather than threading
1344    /// it through from the forward node.
1345    SoftmaxCrossEntropyBackward,
1346
1347    /// Backward of [`Op::Attention`]. Recomputes scaled `QK^T`, applies
1348    /// the same `mask_kind` as the forward op, softmaxes scores, then
1349    /// emits **one** of `dQ`, `dK`, or `dV` selected by [`AttentionBwdWrt`].
1350    /// Autodiff emits three nodes (one per `wrt`) so each output shape
1351    /// stays a normal single-output MIR node.
1352    ///
1353    /// Inputs: `[q, k, v, dy]` plus optional mask when `mask_kind` is
1354    /// [`MaskKind::Custom`] or [`MaskKind::Bias`] (same convention as
1355    /// forward). Output shape matches `q`, `k`, or `v` respectively.
1356    AttentionBackward {
1357        num_heads: usize,
1358        head_dim: usize,
1359        mask_kind: MaskKind,
1360        wrt: AttentionBwdWrt,
1361    },
1362
1363    // ── Fused operations (created by optimization passes) ──────
1364    /// Fused matmul + bias + activation. Created from MatMul → Add → Activation.
1365    FusedMatMulBiasAct {
1366        activation: Option<Activation>,
1367    },
1368
1369    /// Fused residual + optional bias + layer norm.
1370    /// Created from Add(x, residual) → [Add(bias)] → LayerNorm.
1371    FusedResidualLN {
1372        has_bias: bool,
1373        eps: f32,
1374    },
1375
1376    /// Fused residual + optional bias + RMS norm.
1377    /// Created from Add(x, residual) → [Add(bias)] → RmsNorm.
1378    FusedResidualRmsNorm {
1379        has_bias: bool,
1380        eps: f32,
1381    },
1382
1383    /// Fused SwiGLU: split input into up/gate halves, silu(gate) * up.
1384    /// Created from Split → Silu → Mul when fed by a concatenated matmul.
1385    ///
1386    /// `cast_to`: optional output dtype — when `Some(dt)` the kernel casts
1387    /// its result from the input dtype to `dt` in-register, saving a
1388    /// separate cast pass. Reserved for future fp8/fp4 quantization paths;
1389    /// for f32→f16 mixed precision the AutoMixedPrecision pass already
1390    /// inserts a Cast node so this stays `None` in current pipelines.
1391    FusedSwiGLU {
1392        cast_to: Option<DType>,
1393        /// When `true`, the concatenated input stores gate in the low half
1394        /// `[..., 0..N)` and up in the high half `[..., N..2N)` — the layout
1395        /// produced when gate projection is emitted before up in the builder.
1396        /// Default `false`: up @ low, gate @ high (canonical concat order).
1397        gate_first: bool,
1398    },
1399
1400    /// Fused full transformer layer: attention block + residual+LN + FFN + residual+LN.
1401    /// All intermediates resident in registers/threadgroup memory; one kernel
1402    /// per layer instead of ~30 (the CPU's batch=1 win, lifted to IR so any
1403    /// backend can implement it as a monolithic kernel).
1404    ///
1405    /// Inputs: hidden, qkv_w, qkv_b, out_w, out_b,
1406    ///         ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask
1407    /// Output: same shape as hidden.
1408    ///
1409    /// **Backend status:** same as FusedAttentionBlock. CPU implements
1410    /// the L1-cache-resident merge at the thunk level. Metal deferred —
1411    /// requires a single MSL kernel for the whole layer to actually
1412    /// beat the unfused path. Multi-day work; revisit when there's a
1413    /// model whose Metal inference is bottlenecked here rather than on
1414    /// the wait latency floor.
1415    FusedTransformerLayer {
1416        num_heads: usize,
1417        head_dim: usize,
1418        intermediate_size: usize,
1419        eps1: f32,
1420        eps2: f32,
1421        activation: Activation,
1422        has_bias: bool,
1423    },
1424
1425    /// Fused attention block: QKV projection → split → \[RoPE\] → SDPA → output projection.
1426    /// Created by FuseAttentionBlock pass when batch*seq is small.
1427    /// All intermediates stay in L1 cache — no arena writes between ops.
1428    ///
1429    /// Inputs (in order):
1430    ///   hidden, qkv_w, out_w, mask,
1431    ///   [qkv_b, out_b]      if has_bias,
1432    ///   [rope_cos, rope_sin] if has_rope
1433    ///
1434    /// **Backend status (Phase C finalize):**
1435    ///   CPU  — implemented at the *thunk* level: the CPU schedule
1436    ///          recognizes the multi-thunk pattern and merges into
1437    ///          a single FusedAttnBlock that keeps Q/K/V in stack
1438    ///          buffers across stages (the L1-cache win).
1439    ///   Metal — **deferred**. A dispatch-wrapper version (chaining
1440    ///          existing kernels) buys nothing the unfused Metal path
1441    ///          doesn't already get, since per-run cost is dominated
1442    ///          by `wait_until_completed` (~150 µs), not encode. The
1443    ///          real win is a single MSL kernel keeping Q/K/V in
1444    ///          threadgroup memory across stages — multi-day work.
1445    ///          Until then, Metal runs the unfused chain (one matmul,
1446    ///          three narrows, two ropes, attention, one matmul) — all
1447    ///          covered in op_coverage and parity_harness.
1448    FusedAttentionBlock {
1449        num_heads: usize,
1450        head_dim: usize,
1451        has_bias: bool,
1452        has_rope: bool,
1453    },
1454
1455    // ── Control flow (subgraphs as op payloads) ─────────────────
1456    //
1457    // Status: IR is defined; helper `run_if` / `run_while` exist in
1458    // rlx-runtime/src/subgraph.rs; **executor wiring is not yet
1459    // implemented** (both CPU thunk and Metal thunk fall through to
1460    // `Thunk::Nop` for these ops). Wiring requires:
1461    //   1. Recursive subgraph compile at parent-compile time.
1462    //   2. Per-subgraph input/output binding through the arena.
1463    //   3. Schedule-level dispatch when the predicate / loop cond is
1464    //      resolved at runtime.
1465    // Estimate: 4–6 hours of focused work + parity tests. Deferred
1466    // because no current in-tree model uses these ops;
1467    // surface area without a validation target invites silent bugs.
1468    /// Conditional: pick between two subgraphs based on a boolean predicate.
1469    /// Inputs: [predicate, ...captures (used inside both branches)].
1470    /// `then_branch` and `else_branch` are sub-graphs that share the
1471    /// captured inputs and must produce identically-shaped outputs.
1472    /// Used for: shape-dependent execution, batched inference of
1473    /// dynamic-length sequences with padding masks.
1474    If {
1475        then_branch: Box<crate::Graph>,
1476        else_branch: Box<crate::Graph>,
1477    },
1478
1479    /// Loop: iterate `body` while `cond` evaluates true.
1480    /// Inputs: [...initial loop-carried values].
1481    /// `cond`'s single output is a Bool scalar.
1482    /// `body`'s outputs become the next iteration's loop-carried inputs.
1483    /// Outputs of While are the values after the final iteration.
1484    /// Used for: KV-cache-driven autoregressive generation, beam search.
1485    While {
1486        cond: Box<crate::Graph>,
1487        body: Box<crate::Graph>,
1488        max_iterations: Option<usize>,
1489    },
1490
1491    /// Bounded-length loop with a fixed-shape carry, optional per-step
1492    /// inputs, and optional stacked output. Mirrors JAX's `lax.scan`.
1493    ///
1494    /// Body signature: `(carry, x_t_0, ..., x_t_{num_xs-1}) → carry_next`
1495    /// — `1 + num_xs` Op::Inputs in NodeId construction order (first
1496    /// declared is the carry; the remaining `num_xs` are per-step
1497    /// slices). Single output (the next carry).
1498    ///
1499    /// Outer Op::Scan inputs (in order):
1500    ///   `[init_carry, xs_0, xs_1, ..., xs_{num_xs-1}]`
1501    /// Each `xs_i` has shape `[length, *per_step_shape_i]`; the body
1502    /// sees `xs_i[t]` (a `per_step_shape_i` slice) on iteration `t`.
1503    ///
1504    /// Outer Op::Scan output:
1505    ///   * `save_trajectory == false` — final carry, shape `*carry`.
1506    ///   * `save_trajectory == true`  — stacked trajectory of carries,
1507    ///     shape `[length, *carry]`. Row `t` is the carry after step
1508    ///     `t+1`, so row `length-1` matches the no-trajectory case.
1509    ///
1510    /// Mirrors JAX's `lax.scan`. Common uses include time-stepping
1511    /// integrators with time-varying drives, Mamba-style SSM scans
1512    /// reading per-step inputs, and RNN-style sequence processing.
1513    Scan {
1514        body: Box<crate::Graph>,
1515        length: u32,
1516        save_trajectory: bool,
1517        /// Number of "broadcast" inputs — values that are constant
1518        /// across iterations. Outer scan inputs in order:
1519        ///   `[init, bcast_0..bcast_{B-1}, xs_0..xs_{X-1}]`
1520        /// Body Op::Inputs in NodeId order:
1521        ///   `[carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]`
1522        /// CPU executor fills bcast slots ONCE before the iteration
1523        /// loop (xs slots are filled per-step). The reverse-mode AD
1524        /// pre-pass materialises each bcast into an xs of shape
1525        /// `[length, *bcast]` via broadcast `Mul` so the rest of the
1526        /// VJP / executor pipeline can stay unchanged. `0` (default)
1527        /// keeps the original carry+xs scan shape.
1528        num_bcast: u32,
1529        /// Number of per-step `xs` inputs. Total outer Op::Scan
1530        /// inputs is `1 + num_bcast + num_xs`.
1531        num_xs: u32,
1532        /// Number of trajectory checkpoints when `save_trajectory ==
1533        /// true`. `0` means "save all `length` rows" (default). A
1534        /// positive value `K` means save only `K` evenly-spaced rows
1535        /// at indices `floor(t * length / K)` for `t in 0..K`. Used
1536        /// by recursive checkpointed AD: store O(√T) carries during
1537        /// forward, recompute the rest in the backward pass.
1538        ///
1539        /// When `0` (or `K == length`), the saved trajectory has
1540        /// shape `[length, *carry]` — same as the original behavior.
1541        /// When `0 < K < length`, the saved trajectory has shape
1542        /// `[K, *carry]`.
1543        num_checkpoints: u32,
1544    },
1545
1546    /// Reverse-mode AD companion to `Op::Scan` — extracts the carry
1547    /// gradient `dinit`. Walks `t = length-1 .. 0`, applying `body_vjp`
1548    /// to thread `dcarry` back through the time loop.
1549    ///
1550    /// Inputs (in order):
1551    ///   `[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]`
1552    /// Output: `dinit`, shape = carry shape.
1553    ///
1554    /// `body_vjp` is the result of
1555    /// `autodiff::grad(body, [carry_id, xs_0_id, ..., xs_{num_xs-1}_id])`
1556    /// — a graph with `1 + num_xs + 1` Inputs (carry + x_t_i for each
1557    /// xs + `"d_output"`) and `1 + num_xs` outputs
1558    /// (dcarry + dx_t_i for each xs). This op reads `outputs[0]` =
1559    /// dcarry; the sibling [`Self::ScanBackwardXs`] reads the
1560    /// `outputs[1 + xs_idx]` slot for each xs gradient.
1561    ScanBackward {
1562        body_vjp: Box<crate::Graph>,
1563        length: u32,
1564        save_trajectory: bool,
1565        num_xs: u32,
1566        /// When `0` or equal to `length`, the trajectory input has
1567        /// shape `[length, *carry]` — every step's carry is cached
1568        /// (`CheckpointStrategy::All`). When `0 < K < length`, the
1569        /// trajectory input has shape `[K, *carry]` and the executor
1570        /// recomputes intermediate carries via `forward_body` between
1571        /// checkpoints. `forward_body` must be `Some` whenever this
1572        /// is < length.
1573        num_checkpoints: u32,
1574        /// Forward body (the same `body` from the forward Op::Scan).
1575        /// Required when `num_checkpoints > 0 && < length` so the
1576        /// executor can recompute carries between saved checkpoints.
1577        /// `None` for the All strategy (no recompute needed).
1578        forward_body: Option<Box<crate::Graph>>,
1579    },
1580
1581    /// Companion to [`Self::ScanBackward`] that extracts one stacked
1582    /// per-step `dxs_i` (shape `[length, *per_step_xs_i]`). Same inputs
1583    /// and same `body_vjp` graph as ScanBackward — `xs_idx` selects
1584    /// which body_vjp output to stack into the result.
1585    ///
1586    /// Note: each ScanBackwardXs runs its own backward loop. A future
1587    /// optimization can fuse them into a single multi-output backward
1588    /// kernel; for now it's `1 + num_xs` independent sweeps.
1589    ScanBackwardXs {
1590        body_vjp: Box<crate::Graph>,
1591        length: u32,
1592        save_trajectory: bool,
1593        num_xs: u32,
1594        xs_idx: u32,
1595        num_checkpoints: u32,
1596        forward_body: Option<Box<crate::Graph>>,
1597    },
1598
1599    /// CPU reference 3D Gaussian splat forward render.
1600    ///
1601    /// Seven flat F32 inputs (scene buffers + camera/render meta):
1602    ///   0. positions `[N*3]`
1603    ///   1. scales `[N*3]` (log-space)
1604    ///   2. rotations `[N*4]` (xyzw)
1605    ///   3. opacities `[N]` (logit)
1606    ///   4. colors `[N*3]` (linear RGB)
1607    ///   5. sh_coeffs `[N * sh_coeff_count * 3]`
1608    ///   6. meta `[23]` — camera position/target/up/fov/near/far, background RGB,
1609    ///      then width/height/tile_size/radius_scale/alpha_cutoff/max_splat_steps/
1610    ///      transmittance_threshold/max_list_entries as f32 bit-patterns.
1611    ///
1612    /// Output: `[height * width * 4]` linear RGBA (display gamma baked in).
1613    /// Build via [`crate::Graph::gaussian_splat_render`].
1614    ///
1615    /// Differentiable backward is not implemented in v1; autodiff treats this
1616    /// op as non-differentiable (same as [`Op::Sample`]).
1617    GaussianSplatRender {
1618        width: u32,
1619        height: u32,
1620        tile_size: u32,
1621        radius_scale: f32,
1622        alpha_cutoff: f32,
1623        max_splat_steps: u32,
1624        transmittance_threshold: f32,
1625        max_list_entries: u32,
1626    },
1627
1628    /// Backward pass for [`Self::GaussianSplatRender`].
1629    ///
1630    /// Eight inputs: the same seven as forward plus `d_loss_rgba` `[W*H*4]`
1631    /// (only RGB channels are used). Re-runs the training forward internally.
1632    ///
1633    /// Output: packed gradients
1634    /// `[positions(3N) | scales(3N) | rotations(4N) | opacities(N) | colors(3N) | sh(N*sh*3)]`.
1635    /// Unpack via [`crate::ops::splat::unpack_gaussian_splat_packed_grads`].
1636    GaussianSplatRenderBackward {
1637        width: u32,
1638        height: u32,
1639        tile_size: u32,
1640        radius_scale: f32,
1641        alpha_cutoff: f32,
1642        max_splat_steps: u32,
1643        transmittance_threshold: f32,
1644        max_list_entries: u32,
1645        loss_grad_clip: f32,
1646        sh_band: u32,
1647        max_anisotropy: f32,
1648    },
1649
1650    /// Strict IR stage 1: project, bin, sort, build per-pixel rays.
1651    ///
1652    /// Seven inputs (same scene + meta as [`Self::GaussianSplatRender`]). Output: packed
1653    /// prepare buffer (see `rlx_splat::prep_layout::prep_packed_len`).
1654    GaussianSplatPrepare {
1655        width: u32,
1656        height: u32,
1657        tile_size: u32,
1658        radius_scale: f32,
1659        alpha_cutoff: f32,
1660        max_splat_steps: u32,
1661        transmittance_threshold: f32,
1662        max_list_entries: u32,
1663    },
1664
1665    /// Strict IR stage 2: tile raster from [`Self::GaussianSplatPrepare`] output.
1666    ///
1667    /// Inputs: `prep` packed buffer, `meta` `[23]`. Output: `[width * height * 4]` RGBA.
1668    GaussianSplatRasterize {
1669        width: u32,
1670        height: u32,
1671        tile_size: u32,
1672        alpha_cutoff: f32,
1673        max_splat_steps: u32,
1674        transmittance_threshold: f32,
1675        max_list_entries: u32,
1676    },
1677
1678    /// User-registered custom op. `name` keys into the
1679    /// [`crate::op_registry`] for shape inference, autodiff, and
1680    /// per-backend execution. `attrs` is an opaque blob passed
1681    /// through to those callbacks (FFT direction, SparseLU
1682    /// reordering strategy, etc.). `num_inputs` is captured at
1683    /// construction time so [`Op::num_inputs`] stays infallible
1684    /// without a registry lookup. Build via [`crate::Graph::custom_op`].
1685    Custom {
1686        name: String,
1687        num_inputs: u32,
1688        attrs: Vec<u8>,
1689    },
1690
1691    /// 1D Fast Fourier Transform along the last axis.
1692    ///
1693    /// **Layouts**
1694    /// - `F32` / `F64`: 2N real-block — last axis is `[re₀…re_{N-1}, im₀…im_{N-1}]`.
1695    /// - `C64`: interleaved `[re, im]` pairs per complex element along the last axis.
1696    ///
1697    /// **ND transforms** — use `Graph::fftn` / `Graph::ifftn`, which compose
1698    /// `fft_axis` (transpose → Fft → transpose). Multi-axis `fftn` requires
1699    /// `DType::C64`; the 2N-block layout describes a single complex axis.
1700    ///
1701    /// Default (`FftNorm::Backward`) is **unnormalized** on both directions:
1702    ///   `fft(x)[k] = Σ x[n]·exp(-2πi·nk/N)`
1703    ///   `ifft(y)[n] = Σ y[k]·exp(+2πi·nk/N)`
1704    /// so `ifft(fft(x)) = N·x`. Use `FftNorm::Forward` for gpu-fft-style
1705    /// `1/N` scaling on inverse, or `FftNorm::Ortho` for unitary scaling.
1706    ///
1707    /// AD: VJP(`fft`) = `ifft`, VJP(`ifft`) = `fft` when `norm=Backward`;
1708    /// other norms apply the chain rule via output scaling.
1709    Fft {
1710        inverse: bool,
1711        norm: crate::fft::FftNorm,
1712    },
1713
1714    /// Ternary pruned radix-2 butterfly stage on interleaved complex state.
1715    ///
1716    /// Inputs:
1717    ///   0 — state `[batch, n_fft, 2]` (re/im on axis 2)
1718    ///   1 — gate `[half]` — 0 = identity, 1 = run butterfly (`half = n_fft/2`)
1719    ///   2 — rev `[half]` — 0 = forward, 1 = swap outputs when gate=1
1720    ///   3 — tw_re `[half]`
1721    ///   4 — tw_im `[half]`
1722    ///
1723    /// Output: `[batch, n_fft, 2]` same layout. Slots with gate=0 copy inputs
1724    /// without twiddle math.
1725    FftButterflyStage {
1726        stage: u32,
1727        n_fft: u32,
1728    },
1729
1730    /// Log-mel spectrogram from RLX FFT block-layout spectrum.
1731    ///
1732    /// Inputs:
1733    ///   0 — spectrum `[..., 2*n_fft]` (re plane then im plane, same as `Op::Fft` output)
1734    ///   1 — mel filterbank `[n_mels, n_bins]` with `n_bins = n_fft/2 + 1`
1735    ///
1736    /// Output: `[..., n_mels]` with Whisper dynamic-range compression
1737    /// (`log10`, clamp to max−8 dB, `(x+4)/4`).
1738    LogMel,
1739
1740    /// VJP of [`Op::LogMel`] w.r.t. spectrum (input 0).
1741    ///
1742    /// Inputs: spectrum block, mel filters, upstream `dy`.
1743    /// Output: `d_spectrum` (same shape as input 0).
1744    LogMelBackward,
1745
1746    /// Top-K Welch peaks from block-layout segment spectra.
1747    ///
1748    /// Input 0: spectrum `[batch * n_segments, 2*n_fft]` (re ∥ im planes).
1749    /// Output: `[batch, k*2]` interleaved `(bin, power)` per spike.
1750    WelchPeaks {
1751        k: usize,
1752        n_segments: usize,
1753    },
1754
1755    /// User-defined sub-graph with optional override AD rules.
1756    /// Mirrors JAX's `custom_vjp` / `custom_jvp` decorators: the
1757    /// caller wraps a forward computation and supplies its own
1758    /// reverse- and/or forward-mode AD bodies. Useful when:
1759    ///   * The forward is iterative (Newton, fixed-point) and
1760    ///     differentiating through the loop is wasteful — the
1761    ///     vjp_body computes the implicit-function gradient at the
1762    ///     converged point in one shot.
1763    ///   * The math has a closed-form gradient that's much cheaper
1764    ///     than autodiff.
1765    ///   * The forward op is non-differentiable by tracing
1766    ///     (sampling, argmax) and the user wants a smooth surrogate.
1767    ///
1768    /// **fwd_body**: `num_inputs` Op::Inputs in NodeId construction
1769    /// order, one Op::Output (the primal y). Forward execution
1770    /// inlines this body once.
1771    ///
1772    /// **vjp_body** (optional): Op::Inputs are `num_inputs` primal
1773    /// inputs in NodeId order, plus two special-named Inputs —
1774    /// `"primal_output"` (the y from forward) and `"d_output"` (the
1775    /// upstream gradient). Outputs: `num_inputs` tensors in
1776    /// `set_outputs` order, matching the gradients of each primal
1777    /// input. When `None`, reverse-mode AD recurses into fwd_body
1778    /// — same as if the op were inlined.
1779    ///
1780    /// **jvp_body** (optional): Op::Inputs are `num_inputs` primal
1781    /// inputs in NodeId order, `num_inputs` special-named Inputs
1782    /// `"tangent_0"..="tangent_{num_inputs-1}"` carrying each input's
1783    /// tangent, and an optional special-named `"primal_output"` Input
1784    /// (the y from forward, useful when the JVP must be evaluated at
1785    /// a converged / nonlinear point — e.g. IFT-style forward-mode AD
1786    /// of an iterative solver). Output: 1 tensor (the tangent of y).
1787    /// When `None`, forward-mode AD recurses into fwd_body.
1788    ///
1789    /// `num_inputs` is captured so [`Op::num_inputs`] stays
1790    /// infallible. Build via [`crate::Graph::custom_fn`].
1791    CustomFn {
1792        fwd_body: Box<crate::Graph>,
1793        vjp_body: Option<Box<crate::Graph>>,
1794        jvp_body: Option<Box<crate::Graph>>,
1795        num_inputs: u32,
1796    },
1797}
1798
1799impl Op {
1800    /// PLAN L4: discriminant for backend-supported-set checks.
1801    /// Stable, parameter-free identity per variant — `Op::Activation(_)`
1802    /// and `Op::Activation(Relu)` share the same `OpKind::Activation`.
1803    pub fn kind(&self) -> OpKind {
1804        match self {
1805            Op::Input { .. } => OpKind::Input,
1806            Op::Param { .. } => OpKind::Param,
1807            Op::Constant { .. } => OpKind::Constant,
1808            Op::Activation(_) => OpKind::Activation,
1809            Op::Cast { .. } => OpKind::Cast,
1810            Op::StopGradient => OpKind::StopGradient,
1811            Op::Quantize { .. } => OpKind::Quantize,
1812            Op::Dequantize { .. } => OpKind::Dequantize,
1813            Op::FakeQuantize { .. } => OpKind::FakeQuantize,
1814            Op::FakeQuantizeLSQ { .. } => OpKind::FakeQuantizeLSQ,
1815            Op::FakeQuantizeLSQBackwardX { .. } => OpKind::FakeQuantizeLSQBackwardX,
1816            Op::FakeQuantizeLSQBackwardScale { .. } => OpKind::FakeQuantizeLSQBackwardScale,
1817            Op::Binary(_) => OpKind::Binary,
1818            Op::Compare(_) => OpKind::Compare,
1819            Op::Where => OpKind::Where,
1820            Op::ElementwiseRegion { .. } => OpKind::ElementwiseRegion,
1821            Op::TransformRegion { .. } => OpKind::TransformRegion,
1822            Op::BatchElementwiseRegion { .. } => OpKind::BatchElementwiseRegion,
1823            Op::MatMul => OpKind::MatMul,
1824            Op::DotGeneral { .. } => OpKind::DotGeneral,
1825            Op::DenseSolve => OpKind::DenseSolve,
1826            Op::BatchedDenseSolve => OpKind::BatchedDenseSolve,
1827            Op::LayerNorm { .. } => OpKind::LayerNorm,
1828            Op::LayerNorm2d { .. } => OpKind::LayerNorm2d,
1829            Op::GroupNorm { .. } => OpKind::GroupNorm,
1830            Op::BatchNormInference { .. } => OpKind::BatchNormInference,
1831            Op::RmsNorm { .. } => OpKind::RmsNorm,
1832            Op::ResizeNearest2x => OpKind::ResizeNearest2x,
1833            Op::Attention { .. } => OpKind::Attention,
1834            Op::Rope { .. } => OpKind::Rope,
1835            Op::AxialRope2d { .. } => OpKind::AxialRope2d,
1836            Op::Reshape { .. } => OpKind::Reshape,
1837            Op::Transpose { .. } => OpKind::Transpose,
1838            Op::Narrow { .. } => OpKind::Narrow,
1839            Op::Concat { .. } => OpKind::Concat,
1840            Op::Expand { .. } => OpKind::Expand,
1841            Op::Gather { .. } => OpKind::Gather,
1842            Op::Reduce { .. } => OpKind::Reduce,
1843            Op::Softmax { .. } => OpKind::Softmax,
1844            Op::Cumsum { .. } => OpKind::Cumsum,
1845            Op::ArgMax { .. } => OpKind::ArgMax,
1846            Op::ArgMin { .. } => OpKind::ArgMin,
1847            Op::TopK { .. } => OpKind::TopK,
1848            Op::Sample { .. } => OpKind::Sample,
1849            Op::RngNormal { .. } => OpKind::RngNormal,
1850            Op::RngUniform { .. } => OpKind::RngUniform,
1851            Op::Conv { .. } => OpKind::Conv,
1852            Op::Im2Col { .. } => OpKind::Im2Col,
1853            Op::ConvTranspose2d { .. } => OpKind::ConvTranspose2d,
1854            Op::Pool { .. } => OpKind::Pool,
1855            Op::ReluBackward => OpKind::ReluBackward,
1856            Op::ActivationBackward { .. } => OpKind::ActivationBackward,
1857            Op::FakeQuantizeBackward { .. } => OpKind::FakeQuantizeBackward,
1858            Op::ComplexNormSq => OpKind::ComplexNormSq,
1859            Op::ComplexNormSqBackward => OpKind::ComplexNormSqBackward,
1860            Op::Conjugate => OpKind::Conjugate,
1861            Op::LayerNormBackwardInput { .. } => OpKind::LayerNormBackwardInput,
1862            Op::LayerNormBackwardGamma { .. } => OpKind::LayerNormBackwardGamma,
1863            Op::RmsNormBackwardInput { .. } => OpKind::RmsNormBackwardInput,
1864            Op::RmsNormBackwardGamma { .. } => OpKind::RmsNormBackwardGamma,
1865            Op::RmsNormBackwardBeta { .. } => OpKind::RmsNormBackwardBeta,
1866            Op::RopeBackward { .. } => OpKind::RopeBackward,
1867            Op::GroupNormBackwardInput { .. } => OpKind::GroupNormBackwardInput,
1868            Op::GroupNormBackwardGamma { .. } => OpKind::GroupNormBackwardGamma,
1869            Op::GroupNormBackwardBeta { .. } => OpKind::GroupNormBackwardBeta,
1870            Op::BatchNormInferenceBackwardInput { .. } => OpKind::BatchNormInferenceBackwardInput,
1871            Op::BatchNormInferenceBackwardGamma { .. } => OpKind::BatchNormInferenceBackwardGamma,
1872            Op::BatchNormInferenceBackwardBeta => OpKind::BatchNormInferenceBackwardBeta,
1873            Op::CumsumBackward { .. } => OpKind::CumsumBackward,
1874            Op::GatherBackward { .. } => OpKind::GatherBackward,
1875            Op::MaxPool2dBackward { .. } => OpKind::MaxPool2dBackward,
1876            Op::Conv2dBackwardInput { .. } => OpKind::Conv2dBackwardInput,
1877            Op::Conv2dBackwardWeight { .. } => OpKind::Conv2dBackwardWeight,
1878            Op::SoftmaxCrossEntropyWithLogits => OpKind::SoftmaxCrossEntropyWithLogits,
1879            Op::SoftmaxCrossEntropyBackward => OpKind::SoftmaxCrossEntropyBackward,
1880            Op::AttentionBackward { .. } => OpKind::AttentionBackward,
1881            Op::GroupedMatMul => OpKind::GroupedMatMul,
1882            Op::DequantGroupedMatMul { .. } => OpKind::DequantGroupedMatMul,
1883            Op::DequantMoEWeights { .. } => OpKind::DequantMoEWeights,
1884            Op::ScatterAdd => OpKind::ScatterAdd,
1885            Op::LoraMatMul { .. } => OpKind::LoraMatMul,
1886            Op::DequantMatMul { .. } => OpKind::DequantMatMul,
1887            Op::QMatMul { .. } => OpKind::QMatMul,
1888            Op::QConv2d { .. } => OpKind::QConv2d,
1889            Op::SelectiveScan { .. } => OpKind::SelectiveScan,
1890            Op::GatedDeltaNet { .. } => OpKind::GatedDeltaNet,
1891            Op::Lstm { .. } => OpKind::Lstm,
1892            Op::Gru { .. } => OpKind::Gru,
1893            Op::Rnn { .. } => OpKind::Rnn,
1894            Op::Mamba2 { .. } => OpKind::Mamba2,
1895            Op::FusedSwiGLU { .. } => OpKind::FusedSwiGLU,
1896            Op::FusedMatMulBiasAct { .. } => OpKind::FusedMatMulBiasAct,
1897            Op::FusedResidualLN { .. } => OpKind::FusedResidualLN,
1898            Op::FusedResidualRmsNorm { .. } => OpKind::FusedResidualRmsNorm,
1899            Op::FusedAttentionBlock { .. } => OpKind::FusedAttentionBlock,
1900            Op::FusedTransformerLayer { .. } => OpKind::FusedTransformerLayer,
1901            Op::If { .. } => OpKind::If,
1902            Op::While { .. } => OpKind::While,
1903            Op::Scan { .. } => OpKind::Scan,
1904            Op::ScanBackward { .. } => OpKind::ScanBackward,
1905            Op::ScanBackwardXs { .. } => OpKind::ScanBackwardXs,
1906            Op::GaussianSplatRender { .. } => OpKind::GaussianSplatRender,
1907            Op::GaussianSplatRenderBackward { .. } => OpKind::GaussianSplatRenderBackward,
1908            Op::GaussianSplatPrepare { .. } => OpKind::GaussianSplatPrepare,
1909            Op::GaussianSplatRasterize { .. } => OpKind::GaussianSplatRasterize,
1910            Op::Custom { .. } => OpKind::Custom,
1911            Op::CustomFn { .. } => OpKind::CustomFn,
1912            Op::Fft { .. } => OpKind::Fft,
1913            Op::FftButterflyStage { .. } => OpKind::FftButterflyStage,
1914            Op::LogMel => OpKind::LogMel,
1915            Op::LogMelBackward => OpKind::LogMelBackward,
1916            Op::WelchPeaks { .. } => OpKind::WelchPeaks,
1917        }
1918    }
1919
1920    /// True if this op is element-wise (same shape in, same shape out).
1921    /// Element-wise ops are prime fusion candidates.
1922    pub fn is_elementwise(&self) -> bool {
1923        matches!(
1924            self,
1925            Op::Activation(_)
1926                | Op::Cast { .. }
1927                | Op::StopGradient
1928                | Op::Binary(_)
1929                | Op::Compare(_)
1930                | Op::Where
1931                | Op::ElementwiseRegion { .. }
1932                | Op::BatchElementwiseRegion { .. }
1933        )
1934    }
1935
1936    /// True if this op may appear in a [`Op::TransformRegion`] chain.
1937    pub fn is_transform_eligible(&self) -> bool {
1938        matches!(self, Op::ResizeNearest2x)
1939    }
1940
1941    /// True if this op is a BLAS/compute-intensive op that forms a fusion boundary.
1942    pub fn is_blas(&self) -> bool {
1943        matches!(
1944            self,
1945            Op::MatMul
1946                | Op::DotGeneral { .. }
1947                | Op::DenseSolve
1948                | Op::BatchedDenseSolve
1949                | Op::Conv { .. }
1950                | Op::Im2Col { .. }
1951                | Op::ConvTranspose2d { .. }
1952                | Op::FusedMatMulBiasAct { .. }
1953                | Op::GroupedMatMul
1954                | Op::DequantGroupedMatMul { .. }
1955                | Op::DequantMoEWeights { .. }
1956                | Op::LoraMatMul { .. }
1957                | Op::DequantMatMul { .. }
1958                | Op::QMatMul { .. }
1959                | Op::QConv2d { .. }
1960        )
1961    }
1962
1963    /// True if element-wise fusion must not span across this op.
1964    pub fn is_fusion_boundary(&self) -> bool {
1965        self.is_blas()
1966            || matches!(
1967                self,
1968                Op::GaussianSplatRender { .. }
1969                    | Op::GaussianSplatRenderBackward { .. }
1970                    | Op::GaussianSplatPrepare { .. }
1971                    | Op::GaussianSplatRasterize { .. }
1972            )
1973    }
1974
1975    /// True if this op is a reduction (drives loop iteration in fused kernels).
1976    pub fn is_reduction(&self) -> bool {
1977        matches!(
1978            self,
1979            Op::Reduce { .. } | Op::Softmax { .. } | Op::TopK { .. }
1980        )
1981    }
1982
1983    /// Number of tensor inputs this op expects.
1984    pub fn num_inputs(&self) -> usize {
1985        match self {
1986            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0,
1987            Op::Activation(_)
1988            | Op::Cast { .. }
1989            | Op::StopGradient
1990            | Op::Reshape { .. }
1991            | Op::Quantize { .. }
1992            | Op::Dequantize { .. }
1993            | Op::Transpose { .. }
1994            | Op::Narrow { .. }
1995            | Op::Expand { .. }
1996            | Op::Reduce { .. }
1997            | Op::Softmax { .. }
1998            | Op::FusedSwiGLU { .. }
1999            | Op::TopK { .. }
2000            | Op::Cumsum { .. }
2001            | Op::ArgMax { .. }
2002            | Op::ArgMin { .. }
2003            | Op::Sample { .. }
2004            | Op::ResizeNearest2x => 1,
2005            Op::RngNormal { .. } | Op::RngUniform { .. } => 0, // 0 or 1 — see verify
2006            // EMA / Fixed scale modes carry a state tensor as a 2nd input;
2007            // PerBatch (default) doesn't need one.
2008            Op::FakeQuantize { scale_mode, .. } => match scale_mode {
2009                ScaleMode::PerBatch => 1,
2010                ScaleMode::EMA { .. } | ScaleMode::Fixed => 2,
2011            },
2012            Op::FakeQuantizeLSQ { .. } => 2, // x, scale (learned param)
2013            Op::FakeQuantizeLSQBackwardX { .. } | Op::FakeQuantizeLSQBackwardScale { .. } => 3, // x, scale, dy
2014            Op::Binary(_) | Op::Compare(_) | Op::Gather { .. } | Op::MatMul | Op::ScatterAdd => 2,
2015            Op::GroupedMatMul => 3,               // input, weight, expert_idx
2016            Op::DequantGroupedMatMul { .. } => 3, // input, packed_w, expert_idx
2017            Op::DequantMoEWeights { .. } => 1,    // packed_w
2018            Op::LoraMatMul { .. } => 4,           // x, w, a, b
2019            // x, w_q, scale, zp — or x, packed_w_bytes for GGUF
2020            // schemes (their scales/mins live inside the packed bytes,
2021            // see `QuantScheme::is_gguf`).
2022            Op::DequantMatMul { scheme } => {
2023                if scheme.is_gguf() {
2024                    2
2025                } else {
2026                    4
2027                }
2028            }
2029            Op::QMatMul { .. } => 3,       // x, w, bias
2030            Op::QConv2d { .. } => 3,       // x, w, bias
2031            Op::SelectiveScan { .. } => 5, // x, delta, a, b, c
2032            Op::GatedDeltaNet { carry_state, .. } if *carry_state => 6, // + state in/out
2033            Op::GatedDeltaNet { .. } => 5, // q, k, v, g, beta
2034            Op::Lstm { carry, .. } => {
2035                if *carry { 6 } else { 4 } // x, w_ih, w_hh, bias (+ h0, c0)
2036            }
2037            Op::Gru { carry, .. } => {
2038                if *carry { 6 } else { 5 } // x, w_ih, w_hh, b_ih, b_hh (+ h0)
2039            }
2040            Op::Rnn { carry, .. } => {
2041                if *carry { 5 } else { 4 } // x, w_ih, w_hh, bias (+ h0)
2042            }
2043            Op::Mamba2 { .. } => 5, // x, dt, a, b, c
2044            Op::Where => 3,         // cond, on_true, on_false
2045            Op::Attention { mask_kind, .. } => match mask_kind {
2046                MaskKind::Custom | MaskKind::Bias => 4, // Q, K, V, mask
2047                _ => 3,                                 // Q, K, V (mask synthesized in-kernel)
2048            },
2049            Op::AttentionBackward { mask_kind, .. } => match mask_kind {
2050                MaskKind::Custom | MaskKind::Bias => 5, // q, k, v, dy, mask
2051                _ => 4,                                 // q, k, v, dy
2052            },
2053            Op::Rope { .. } => 3, // x, cos, sin
2054            Op::AxialRope2d { .. } => 1,
2055            Op::LayerNorm { .. }
2056            | Op::LayerNorm2d { .. }
2057            | Op::GroupNorm { .. }
2058            | Op::RmsNorm { .. } => 3, // input, gamma, beta
2059            Op::BatchNormInference { .. } => 5, // x, gamma, beta, mean, var
2060            Op::FusedMatMulBiasAct { .. } => 3, // input, weight, bias
2061            Op::FusedResidualLN { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
2062            Op::FusedResidualLN {
2063                has_bias: false, ..
2064            } => 4, // x, residual, gamma, beta
2065            Op::FusedResidualRmsNorm { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
2066            Op::FusedResidualRmsNorm {
2067                has_bias: false, ..
2068            } => 4, // x, residual, gamma, beta
2069            Op::Conv { .. } | Op::ConvTranspose2d { .. } => 2, // input, weight (bias via Add)
2070            Op::Im2Col { .. } => 1,
2071            Op::Pool { .. } => 1,
2072            Op::ReluBackward => 2,                  // x, dy
2073            Op::ActivationBackward { .. } => 2,     // x, dy
2074            Op::FakeQuantizeBackward { .. } => 2,   // x, dy
2075            Op::ComplexNormSq => 1,                 // z (C64)
2076            Op::ComplexNormSqBackward => 2,         // z, g
2077            Op::Conjugate => 1,                     // z (C64)
2078            Op::LayerNormBackwardInput { .. } => 3, // x, gamma, dy
2079            Op::LayerNormBackwardGamma { .. } => 2, // x, dy
2080            Op::RmsNormBackwardInput { .. } => 4,   // x, gamma, beta, dy
2081            Op::RmsNormBackwardGamma { .. } => 4,
2082            Op::RmsNormBackwardBeta { .. } => 4,
2083            Op::RopeBackward { .. } => 3,           // dy, cos, sin
2084            Op::GroupNormBackwardInput { .. } => 4, // x, gamma, beta, dy
2085            Op::GroupNormBackwardGamma { .. } => 2, // x, dy
2086            Op::GroupNormBackwardBeta { .. } => 2,
2087            Op::BatchNormInferenceBackwardInput { .. } => 5, // x, gamma, mean, var, dy
2088            Op::BatchNormInferenceBackwardGamma { .. } => 4, // x, mean, var, dy
2089            Op::BatchNormInferenceBackwardBeta => 1,         // dy
2090            Op::CumsumBackward { .. } => 1,                  // dy
2091            Op::GatherBackward { .. } => 2,                  // dy, indices
2092            Op::MaxPool2dBackward { .. } => 2,               // x, dy
2093            Op::Conv2dBackwardInput { .. } => 2,             // dy, w
2094            Op::Conv2dBackwardWeight { .. } => 2,            // x, dy
2095            Op::SoftmaxCrossEntropyWithLogits => 2,          // logits, labels
2096            Op::SoftmaxCrossEntropyBackward => 3,            // logits, labels, d_loss
2097            Op::Concat { .. } => 0,                          // variadic — checked at graph level
2098            Op::DotGeneral { .. } => 2,
2099            Op::DenseSolve => 2,        // A, b
2100            Op::BatchedDenseSolve => 2, // A [B,N,N], b [B,N] or [B,N,K]
2101            Op::FusedAttentionBlock {
2102                has_bias, has_rope, ..
2103            } => 4 + if *has_bias { 2 } else { 0 } + if *has_rope { 2 } else { 0 },
2104            Op::If { .. } => 1,    // predicate (captures handled separately)
2105            Op::While { .. } => 0, // variadic loop-carried; checked at graph level
2106            Op::Scan {
2107                num_bcast, num_xs, ..
2108            } => 1 + *num_bcast as usize + *num_xs as usize,
2109            Op::ScanBackward { num_xs, .. } => 3 + *num_xs as usize, // init, trajectory, upstream, xs_0..
2110            Op::ScanBackwardXs { num_xs, .. } => 3 + *num_xs as usize, // same as ScanBackward
2111            Op::GaussianSplatRender { .. } => 7,
2112            Op::GaussianSplatRenderBackward { .. } => 8,
2113            Op::GaussianSplatPrepare { .. } => 7,
2114            Op::GaussianSplatRasterize { .. } => 2,
2115            Op::FusedTransformerLayer { has_bias, .. } => {
2116                // hidden + qkv_w + out_w + ln1_g + ln1_b + fc1_w + fc2_w + ln2_g + ln2_b + mask = 10
2117                // bias variant adds: qkv_b + out_b + fc1_b + fc2_b = 4 more
2118                10 + if *has_bias { 4 } else { 0 }
2119            }
2120            Op::ElementwiseRegion { num_inputs, .. } => *num_inputs as usize,
2121            Op::TransformRegion { num_inputs, .. } => *num_inputs as usize,
2122            Op::BatchElementwiseRegion {
2123                num_batch_inputs, ..
2124            } => *num_batch_inputs as usize,
2125            Op::Custom { num_inputs, .. } => *num_inputs as usize,
2126            Op::CustomFn { num_inputs, .. } => *num_inputs as usize,
2127            Op::Fft { .. } => 1,
2128            Op::FftButterflyStage { .. } => 5,
2129            Op::LogMel => 2,
2130            Op::LogMelBackward => 3,
2131            Op::WelchPeaks { .. } => 1,
2132        }
2133    }
2134}
2135
2136impl std::fmt::Display for Op {
2137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2138        match self {
2139            Op::Input { name } => write!(f, "input(\"{name}\")"),
2140            Op::Param { name } => write!(f, "param(\"{name}\")"),
2141            Op::Constant { data } => write!(f, "const({}B)", data.len()),
2142            Op::Activation(a) => write!(f, "{a:?}"),
2143            Op::Quantize { axis, scales, .. } => match axis {
2144                None => write!(f, "quantize(s={})", scales[0]),
2145                Some(d) => write!(f, "quantize(axis={d},nch={})", scales.len()),
2146            },
2147            Op::Dequantize { axis, scales, .. } => match axis {
2148                None => write!(f, "dequantize(s={})", scales[0]),
2149                Some(d) => write!(f, "dequantize(axis={d},nch={})", scales.len()),
2150            },
2151            Op::FakeQuantize {
2152                bits,
2153                axis,
2154                ste,
2155                scale_mode,
2156            } => match axis {
2157                None => write!(
2158                    f,
2159                    "fake_quant(bits={bits},ste={ste:?},scale={scale_mode:?})"
2160                ),
2161                Some(d) => write!(
2162                    f,
2163                    "fake_quant(bits={bits},axis={d},ste={ste:?},scale={scale_mode:?})"
2164                ),
2165            },
2166            Op::FakeQuantizeLSQ { bits, axis } => match axis {
2167                None => write!(f, "fake_quant_lsq(bits={bits})"),
2168                Some(d) => write!(f, "fake_quant_lsq(bits={bits},axis={d})"),
2169            },
2170            Op::FakeQuantizeLSQBackwardX { bits, .. } => {
2171                write!(f, "fake_quant_lsq_bwd_x(bits={bits})")
2172            }
2173            Op::FakeQuantizeLSQBackwardScale { bits, .. } => {
2174                write!(f, "fake_quant_lsq_bwd_s(bits={bits})")
2175            }
2176            Op::Cast { to } => write!(f, "cast({to})"),
2177            Op::StopGradient => write!(f, "stop_gradient"),
2178            Op::Binary(op) => write!(f, "{op:?}"),
2179            Op::Compare(op) => write!(f, "{op:?}"),
2180            Op::Where => write!(f, "where"),
2181            Op::MatMul => write!(f, "matmul"),
2182            Op::DotGeneral { .. } => write!(f, "dot_general"),
2183            Op::DenseSolve => write!(f, "dense_solve"),
2184            Op::BatchedDenseSolve => write!(f, "batched_dense_solve"),
2185            Op::LayerNorm { eps, .. } => write!(f, "layer_norm(eps={eps})"),
2186            Op::GroupNorm { num_groups, eps } => {
2187                write!(f, "group_norm(groups={num_groups},eps={eps})")
2188            }
2189            Op::BatchNormInference { eps } => write!(f, "batch_norm_inference(eps={eps})"),
2190            Op::ResizeNearest2x => write!(f, "resize_nearest_2x"),
2191            Op::RmsNorm { eps, .. } => write!(f, "rms_norm(eps={eps})"),
2192            Op::Attention {
2193                num_heads,
2194                head_dim,
2195                mask_kind,
2196                score_scale,
2197                attn_logit_softcap,
2198            } => {
2199                let mut s = match mask_kind {
2200                    MaskKind::Custom => format!("attention(h={num_heads},d={head_dim})"),
2201                    MaskKind::None => format!("attention(h={num_heads},d={head_dim},nomask)"),
2202                    MaskKind::Causal => format!("attention(h={num_heads},d={head_dim},causal)"),
2203                    MaskKind::SlidingWindow(w) => {
2204                        format!("attention(h={num_heads},d={head_dim},sw={w})")
2205                    }
2206                    MaskKind::Bias => format!("attention(h={num_heads},d={head_dim},bias)"),
2207                };
2208                if let Some(sc) = score_scale {
2209                    s.push_str(&format!(",scale={sc}"));
2210                }
2211                if let Some(cap) = attn_logit_softcap {
2212                    s.push_str(&format!(",softcap={cap}"));
2213                }
2214                write!(f, "{s}")
2215            }
2216            Op::Rope { head_dim, n_rot } => write!(f, "rope(d={head_dim}, n_rot={n_rot})"),
2217            Op::AxialRope2d {
2218                end_x,
2219                end_y,
2220                head_dim,
2221                num_heads,
2222                theta,
2223                repeat_factor,
2224            } => write!(
2225                f,
2226                "axial_rope2d({end_x}x{end_y},h={num_heads},d={head_dim},θ={theta},r={repeat_factor})"
2227            ),
2228            Op::Reshape { new_shape } => write!(f, "reshape({new_shape:?})"),
2229            Op::Transpose { perm } => write!(f, "transpose({perm:?})"),
2230            Op::Narrow { axis, start, len } => write!(f, "narrow({axis},{start},{len})"),
2231            Op::Concat { axis } => write!(f, "concat(axis={axis})"),
2232            Op::Expand { .. } => write!(f, "expand"),
2233            Op::Gather { axis } => write!(f, "gather(axis={axis})"),
2234            Op::Reduce { op, axes, .. } => write!(f, "reduce_{op:?}({axes:?})"),
2235            Op::Softmax { axis } => write!(f, "softmax(axis={axis})"),
2236            Op::Cumsum { axis, exclusive } => {
2237                if *exclusive {
2238                    write!(f, "cumsum(axis={axis},excl)")
2239                } else {
2240                    write!(f, "cumsum(axis={axis})")
2241                }
2242            }
2243            Op::ArgMax { axis, keep_dim } => write!(f, "argmax(axis={axis},keep={keep_dim})"),
2244            Op::ArgMin { axis, keep_dim } => write!(f, "argmin(axis={axis},keep={keep_dim})"),
2245            Op::Sample {
2246                top_k,
2247                top_p,
2248                temperature,
2249                ..
2250            } => {
2251                write!(f, "sample(t={temperature}")?;
2252                if *top_k > 0 {
2253                    write!(f, ",k={top_k}")?;
2254                }
2255                if *top_p < 1.0 {
2256                    write!(f, ",p={top_p}")?;
2257                }
2258                write!(f, ")")
2259            }
2260            Op::RngNormal {
2261                mean,
2262                scale,
2263                key,
2264                op_seed,
2265            } => {
2266                write!(f, "rng_normal({mean},{scale},key={key}")?;
2267                if let Some(s) = op_seed {
2268                    write!(f, ",seed={s}")?;
2269                }
2270                write!(f, ")")
2271            }
2272            Op::RngUniform {
2273                low,
2274                high,
2275                key,
2276                op_seed,
2277            } => {
2278                write!(f, "rng_uniform({low},{high},key={key}")?;
2279                if let Some(s) = op_seed {
2280                    write!(f, ",seed={s}")?;
2281                }
2282                write!(f, ")")
2283            }
2284            Op::TopK { k } => write!(f, "topk(k={k})"),
2285            Op::GroupedMatMul => write!(f, "grouped_matmul"),
2286            Op::DequantGroupedMatMul { scheme } => {
2287                write!(f, "dequant_grouped_matmul({scheme})")
2288            }
2289            Op::DequantMoEWeights { scheme } => write!(f, "dequant_moe_weights({scheme})"),
2290            Op::LoraMatMul { scale } => write!(f, "lora_matmul(scale={scale})"),
2291            Op::DequantMatMul { scheme } => write!(f, "dequant_matmul({scheme})"),
2292            Op::QMatMul {
2293                x_zp,
2294                w_zp,
2295                out_zp,
2296                mult,
2297            } => write!(
2298                f,
2299                "q_matmul(x_zp={x_zp},w_zp={w_zp},out_zp={out_zp},mult={mult})"
2300            ),
2301            Op::QConv2d { kernel_size, .. } => write!(f, "q_conv2d({kernel_size:?})"),
2302            Op::SelectiveScan { state_size } => write!(f, "ssm_scan(n={state_size})"),
2303            Op::GatedDeltaNet {
2304                state_size,
2305                carry_state,
2306            } => {
2307                if *carry_state {
2308                    write!(f, "gated_delta_net(n={state_size},carry)")
2309                } else {
2310                    write!(f, "gated_delta_net(n={state_size})")
2311                }
2312            }
2313            Op::Lstm {
2314                hidden_size,
2315                num_layers,
2316                bidirectional,
2317                carry,
2318            } => {
2319                let dir = if *bidirectional { "bi" } else { "uni" };
2320                let c = if *carry { ",carry" } else { "" };
2321                write!(f, "lstm(h={hidden_size},layers={num_layers},{dir}{c})")
2322            }
2323            Op::Gru {
2324                hidden_size,
2325                num_layers,
2326                bidirectional,
2327                carry,
2328            } => {
2329                let dir = if *bidirectional { "bi" } else { "uni" };
2330                let c = if *carry { ",carry" } else { "" };
2331                write!(f, "gru(h={hidden_size},layers={num_layers},{dir}{c})")
2332            }
2333            Op::Rnn {
2334                hidden_size,
2335                num_layers,
2336                bidirectional,
2337                carry,
2338                relu,
2339            } => {
2340                let dir = if *bidirectional { "bi" } else { "uni" };
2341                let act = if *relu { "relu" } else { "tanh" };
2342                let c = if *carry { ",carry" } else { "" };
2343                write!(f, "rnn(h={hidden_size},layers={num_layers},{dir},{act}{c})")
2344            }
2345            Op::Mamba2 {
2346                head_dim,
2347                state_size,
2348            } => write!(f, "mamba2(p={head_dim},n={state_size})"),
2349            Op::ScatterAdd => write!(f, "scatter_add"),
2350            Op::Conv { kernel_size, .. } => write!(f, "conv2d({kernel_size:?})"),
2351            Op::Im2Col { kernel_size, .. } => write!(f, "im2col({kernel_size:?})"),
2352            Op::ConvTranspose2d { kernel_size, .. } => {
2353                write!(f, "conv_transpose2d({kernel_size:?})")
2354            }
2355            Op::LayerNorm2d { eps } => write!(f, "layer_norm2d(eps={eps})"),
2356            Op::Pool {
2357                kind, kernel_size, ..
2358            } => write!(f, "pool_{kind:?}({kernel_size:?})"),
2359            Op::ReluBackward => write!(f, "relu_backward"),
2360            Op::ActivationBackward { kind } => write!(f, "{kind:?}_backward"),
2361            Op::ComplexNormSq => write!(f, "complex_norm_sq"),
2362            Op::ComplexNormSqBackward => write!(f, "complex_norm_sq_backward"),
2363            Op::Conjugate => write!(f, "conjugate"),
2364            Op::FakeQuantizeBackward { bits, ste, .. } => {
2365                write!(f, "fake_quant_backward(bits={bits},ste={ste:?})")
2366            }
2367            Op::MaxPool2dBackward { kernel_size, .. } => {
2368                write!(f, "maxpool2d_backward({kernel_size:?})")
2369            }
2370            Op::Conv2dBackwardInput { kernel_size, .. } => {
2371                write!(f, "conv2d_backward_input({kernel_size:?})")
2372            }
2373            Op::Conv2dBackwardWeight { kernel_size, .. } => {
2374                write!(f, "conv2d_backward_weight({kernel_size:?})")
2375            }
2376            Op::SoftmaxCrossEntropyWithLogits => write!(f, "sce_with_logits"),
2377            Op::SoftmaxCrossEntropyBackward => write!(f, "sce_backward"),
2378            Op::AttentionBackward {
2379                num_heads,
2380                head_dim,
2381                mask_kind,
2382                wrt,
2383            } => match mask_kind {
2384                MaskKind::None => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},nomask)"),
2385                MaskKind::Causal => {
2386                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},causal)")
2387                }
2388                MaskKind::SlidingWindow(w) => {
2389                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},sw={w})")
2390                }
2391                MaskKind::Custom => {
2392                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},custom)")
2393                }
2394                MaskKind::Bias => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},bias)"),
2395            },
2396            Op::FusedMatMulBiasAct { activation } => {
2397                write!(f, "fused_mm_bias")?;
2398                if let Some(a) = activation {
2399                    write!(f, "_{a:?}")?;
2400                }
2401                Ok(())
2402            }
2403            Op::FusedResidualLN { has_bias, eps } => {
2404                write!(f, "fused_residual")?;
2405                if *has_bias {
2406                    write!(f, "_bias")?;
2407                }
2408                write!(f, "_ln(eps={eps})")
2409            }
2410            Op::FusedResidualRmsNorm { has_bias, eps } => {
2411                write!(f, "fused_residual")?;
2412                if *has_bias {
2413                    write!(f, "_bias")?;
2414                }
2415                write!(f, "_rms(eps={eps})")
2416            }
2417            Op::FusedSwiGLU {
2418                cast_to,
2419                gate_first,
2420            } => {
2421                let mut s = match cast_to {
2422                    Some(dt) => format!("fused_swiglu(cast={dt}"),
2423                    None => "fused_swiglu(".to_string(),
2424                };
2425                if *gate_first {
2426                    s.push_str(",gate_first");
2427                }
2428                s.push(')');
2429                write!(f, "{s}")
2430            }
2431            Op::FusedAttentionBlock {
2432                num_heads,
2433                head_dim,
2434                has_bias,
2435                has_rope,
2436            } => {
2437                write!(f, "fused_attn(h={num_heads},d={head_dim}")?;
2438                if *has_bias {
2439                    write!(f, ",bias")?;
2440                }
2441                if *has_rope {
2442                    write!(f, ",rope")?;
2443                }
2444                write!(f, ")")
2445            }
2446            Op::If { .. } => write!(f, "if(...)"),
2447            Op::While { max_iterations, .. } => match max_iterations {
2448                Some(n) => write!(f, "while(...max={n})"),
2449                None => write!(f, "while(...)"),
2450            },
2451            Op::Scan {
2452                length,
2453                save_trajectory,
2454                num_xs,
2455                ..
2456            } => {
2457                let traj = if *save_trajectory { ",traj" } else { "" };
2458                let xs = if *num_xs > 0 {
2459                    format!(",xs={}", num_xs)
2460                } else {
2461                    String::new()
2462                };
2463                write!(f, "scan(len={length}{xs}{traj})")
2464            }
2465            Op::ScanBackward {
2466                length,
2467                save_trajectory,
2468                num_xs,
2469                ..
2470            } => {
2471                let traj = if *save_trajectory { ",traj" } else { "" };
2472                let xs = if *num_xs > 0 {
2473                    format!(",xs={}", num_xs)
2474                } else {
2475                    String::new()
2476                };
2477                write!(f, "scan_bwd(len={length}{xs}{traj})")
2478            }
2479            Op::ScanBackwardXs {
2480                length,
2481                save_trajectory,
2482                num_xs,
2483                xs_idx,
2484                ..
2485            } => {
2486                let traj = if *save_trajectory { ",traj" } else { "" };
2487                write!(
2488                    f,
2489                    "scan_bwd_xs(len={length},xs={num_xs},idx={xs_idx}{traj})"
2490                )
2491            }
2492            Op::FusedTransformerLayer {
2493                num_heads,
2494                head_dim,
2495                intermediate_size,
2496                has_bias,
2497                ..
2498            } => {
2499                write!(
2500                    f,
2501                    "fused_layer(h={num_heads},d={head_dim},int={intermediate_size}"
2502                )?;
2503                if *has_bias {
2504                    write!(f, ",bias")?;
2505                }
2506                write!(f, ")")
2507            }
2508            Op::ElementwiseRegion {
2509                chain,
2510                num_inputs,
2511                scalar_input_mask,
2512                input_modulus: _,
2513                prologue,
2514                prologue_input: _,
2515            } => {
2516                let pro = match prologue {
2517                    RegionPrologue::None => "",
2518                    RegionPrologue::ResizeNearest2x => ",prologue=resize2x",
2519                };
2520                if *scalar_input_mask != 0 {
2521                    write!(
2522                        f,
2523                        "ew_region(in={num_inputs},steps={},scalar_mask=0x{:x}{pro})",
2524                        chain.len(),
2525                        scalar_input_mask
2526                    )
2527                } else {
2528                    write!(f, "ew_region(in={num_inputs},steps={}{pro})", chain.len())
2529                }
2530            }
2531            Op::TransformRegion { steps, num_inputs } => {
2532                write!(f, "transform_region(in={num_inputs},steps={})", steps.len())
2533            }
2534            Op::BatchElementwiseRegion {
2535                chain,
2536                num_batch_inputs,
2537                scalar_input_mask,
2538                prologue,
2539                ..
2540            } => write!(
2541                f,
2542                "batch_ew_region(batch={num_batch_inputs},steps={},mask=0x{:x},prologue={prologue:?})",
2543                chain.len(),
2544                scalar_input_mask
2545            ),
2546            Op::LayerNormBackwardInput { eps, .. } => {
2547                write!(f, "layer_norm_backward_input(eps={eps})")
2548            }
2549            Op::LayerNormBackwardGamma { eps, .. } => {
2550                write!(f, "layer_norm_backward_gamma(eps={eps})")
2551            }
2552            Op::RmsNormBackwardInput { eps, .. } => write!(f, "rms_norm_backward_input(eps={eps})"),
2553            Op::RmsNormBackwardGamma { eps, .. } => write!(f, "rms_norm_backward_gamma(eps={eps})"),
2554            Op::RmsNormBackwardBeta { eps, .. } => write!(f, "rms_norm_backward_beta(eps={eps})"),
2555            Op::RopeBackward { head_dim, n_rot } => {
2556                write!(f, "rope_backward(d={head_dim},n_rot={n_rot})")
2557            }
2558            Op::GroupNormBackwardInput { num_groups, eps } => {
2559                write!(f, "group_norm_backward_input(g={num_groups},eps={eps})")
2560            }
2561            Op::GroupNormBackwardGamma { num_groups, eps } => {
2562                write!(f, "group_norm_backward_gamma(g={num_groups},eps={eps})")
2563            }
2564            Op::GroupNormBackwardBeta { num_groups, eps } => {
2565                write!(f, "group_norm_backward_beta(g={num_groups},eps={eps})")
2566            }
2567            Op::BatchNormInferenceBackwardInput { eps } => {
2568                write!(f, "batch_norm_inference_backward_input(eps={eps})")
2569            }
2570            Op::BatchNormInferenceBackwardGamma { eps } => {
2571                write!(f, "batch_norm_inference_backward_gamma(eps={eps})")
2572            }
2573            Op::BatchNormInferenceBackwardBeta => {
2574                write!(f, "batch_norm_inference_backward_beta")
2575            }
2576            Op::CumsumBackward { axis, exclusive } => {
2577                write!(f, "cumsum_backward(axis={axis},exclusive={exclusive})")
2578            }
2579            Op::GatherBackward { axis } => write!(f, "gather_backward(axis={axis})"),
2580            Op::GaussianSplatRender {
2581                width,
2582                height,
2583                tile_size,
2584                radius_scale,
2585                alpha_cutoff,
2586                max_splat_steps,
2587                transmittance_threshold,
2588                max_list_entries,
2589            } => write!(
2590                f,
2591                "gaussian_splat_render({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2592            ),
2593            Op::GaussianSplatRenderBackward {
2594                width,
2595                height,
2596                loss_grad_clip,
2597                sh_band,
2598                ..
2599            } => write!(
2600                f,
2601                "gaussian_splat_render_bwd({width}x{height},clip={loss_grad_clip},sh={sh_band})"
2602            ),
2603            Op::GaussianSplatPrepare {
2604                width,
2605                height,
2606                tile_size,
2607                radius_scale,
2608                alpha_cutoff,
2609                max_splat_steps,
2610                transmittance_threshold,
2611                max_list_entries,
2612                ..
2613            } => write!(
2614                f,
2615                "gaussian_splat_prepare({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2616            ),
2617            Op::GaussianSplatRasterize {
2618                width,
2619                height,
2620                tile_size,
2621                alpha_cutoff,
2622                max_splat_steps,
2623                transmittance_threshold,
2624                max_list_entries,
2625                ..
2626            } => write!(
2627                f,
2628                "gaussian_splat_rasterize({width}x{height},tile={tile_size},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2629            ),
2630            Op::Custom {
2631                name,
2632                num_inputs,
2633                attrs,
2634            } => write!(f, "custom({name},in={num_inputs},attrs={}B)", attrs.len()),
2635            Op::CustomFn {
2636                num_inputs,
2637                vjp_body,
2638                jvp_body,
2639                ..
2640            } => {
2641                let v = if vjp_body.is_some() { ",vjp" } else { "" };
2642                let j = if jvp_body.is_some() { ",jvp" } else { "" };
2643                write!(f, "custom_fn(in={num_inputs}{v}{j})")
2644            }
2645            Op::Fft { inverse, norm } => {
2646                write!(f, "fft(inverse={inverse}, norm={norm:?})")
2647            }
2648            Op::FftButterflyStage { stage, n_fft } => {
2649                write!(f, "fft_butterfly_stage(stage={stage}, n_fft={n_fft})")
2650            }
2651            Op::LogMel => write!(f, "log_mel()"),
2652            Op::LogMelBackward => write!(f, "log_mel_backward()"),
2653            Op::WelchPeaks { k, n_segments } => {
2654                write!(f, "welch_peaks(k={k}, n_segments={n_segments})")
2655            }
2656        }
2657    }
2658}