Skip to main content

rlx_ir/
op.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Operation types — every tensor op in the RLX IR.
17//!
18//! Designed for pattern-matching fusion: ops are grouped by category so
19//! fusion passes can reason about them structurally.
20
21use crate::DType;
22
/// Unary element-wise activation functions.
///
/// Used by `Op::Activation` (one input, same-shape output) and by
/// `ChainStep::Activation` inside fused element-wise regions.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
    Gelu,
    GeluApprox,
    Silu, // SwiGLU gate activation
    Relu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
    Rsqrt,
    Neg,
    Abs,
    /// `sin(x)`. Backward: `dx = upstream · cos(x)`.
    Sin,
    /// `cos(x)`. Backward: `dx = -upstream · sin(x)`.
    Cos,
    /// `tan(x)`. Backward: `dx = upstream · sec²(x) = upstream · (1 + tan²(x))`.
    Tan,
    /// `atan(x)`. Backward: `dx = upstream · (1 / (1 + x²))`.
    Atan,
    /// Round to the nearest integer, in f32.
    ///
    /// Forward: `x.round()`. Note that Rust's `f32::round` breaks ties
    /// *away from zero* (`2.5 → 3.0`, `-2.5 → -3.0`); half-to-even would
    /// be `f32::round_ties_even`. NOTE(review): this variant used to be
    /// described as "half-to-even", which contradicts `x.round()` —
    /// confirm which tie rule each backend implements and keep this
    /// comment in sync.
    ///
    /// Backward: STE — treated as identity, so the gradient passes
    /// through unchanged. Useful as a primitive for composing custom
    /// quantization schemes (Mul-by-recip-scale → Round → Clamp →
    /// Mul-by-scale = a hand-rolled FakeQuantize that the
    /// elementwise-region pass can fuse into a single kernel).
    Round,
}
55
/// Scale-tracking strategy for `Op::FakeQuantize`. Determines how
/// the per-channel `s[c]` is computed each forward pass.
///
/// * `PerBatch` — recompute `s[c] = max(|x|) / q_max` from the
///   current data on every call. Simple, no extra inputs, but
///   noisy for activations (max-abs jumps batch-to-batch).
///
/// * `EMA { decay }` — keep a running `s[c]` in a state tensor
///   (passed as a second op input). On each call, blend the
///   current per-batch max-abs into the state via
///   `state' = decay·state + (1-decay)·max_abs`. Smooth scale
///   over training, makes activation-QAT actually trainable.
///   Typical `decay = 0.99`.
///
/// * `Fixed` — never recompute. The state tensor's value is
///   used as-is each call (set once at construction or by the
///   caller). Useful when scales are pre-calibrated.
///
/// # Equality and hashing
///
/// `PartialEq` is derived, so `EMA` decays compare with IEEE-754 float
/// equality: `0.0 == -0.0`, and a `NaN` decay is not equal to itself.
/// The manual `Hash` impl canonicalises `-0.0` to `+0.0` before hashing
/// so that values which compare equal always hash equally, as the
/// `Hash`/`Eq` contract requires. `Eq` is asserted manually and is only
/// sound while `decay` is never `NaN`.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Default)]
pub enum ScaleMode {
    #[default]
    PerBatch,
    EMA {
        decay: f32,
    },
    Fixed,
}

// `decay` is an `f32`, so `Eq` cannot be derived. This manual impl relies
// on `decay` never being NaN (NaN != NaN would break reflexivity).
impl Eq for ScaleMode {}

impl std::hash::Hash for ScaleMode {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            ScaleMode::PerBatch => state.write_u8(0),
            ScaleMode::EMA { decay } => {
                state.write_u8(1);
                // The derived `PartialEq` says `0.0 == -0.0`, but their bit
                // patterns differ. Canonicalise so equal values hash equally.
                let canonical = if *decay == 0.0 { 0.0_f32 } else { *decay };
                state.write_u32(canonical.to_bits());
            }
            ScaleMode::Fixed => state.write_u8(2),
        }
    }
}
97
/// Gradient surrogate used by the backward pass of `Op::FakeQuantize`.
///
/// The forward computation never changes: it is always the discrete
/// `clamp(round(x/s)) * s`. Only the gradient with respect to `x` during
/// training depends on which variant is selected.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum SteKind {
    /// `dx = upstream` — the classic straight-through estimator. It
    /// treats the rounding step as the identity on the way back.
    /// Simplest option; adequate for moderate bit widths (i4 / i8).
    #[default]
    Identity,
    /// `dx = upstream * (|x| ≤ q_max·s)` — zeroes the gradient wherever
    /// the input fell outside the quantization range (the clamp was
    /// active). This keeps the optimizer from pushing weights deeper
    /// into saturation.
    ClippedIdentity,
    /// `dx = upstream * (1 - tanh²(x/s))` — a smooth surrogate for the
    /// rounding step. It gradually attenuates the gradient as `|x|`
    /// approaches `q_max·s`. Often the best choice at tight bit widths (i2).
    Tanh,
    /// `dx = upstream * (1 - |x/(q_max·s)|).max(0)` — a piecewise-linear
    /// relative of `Tanh`; cheaper to evaluate and nearly as effective.
    HardTanh,
}
128
/// Element-wise binary operations. `Op::Binary` applies them with
/// broadcasting, and the output shape is the broadcast of both inputs.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BinaryOp {
    // ── arithmetic ──
    /// Element-wise addition.
    Add,
    /// Element-wise subtraction.
    Sub,
    /// Element-wise multiplication.
    Mul,
    /// Element-wise division.
    Div,
    // ── selection / power ──
    /// Element-wise maximum.
    Max,
    /// Element-wise minimum.
    Min,
    /// Element-wise exponentiation.
    Pow,
}
141
/// Element-wise comparisons. `Op::Compare` evaluates them over two
/// inputs and produces a Bool tensor.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CmpOp {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
153
/// Selects which attention mask the kernel applies.
///
/// This follows the pattern of MAX's `nn/attention/mha_mask.mojo` (#20 in
/// PLAN.md). A single attention kernel branches on the mask kind instead
/// of requiring every caller to materialize a mask tensor. That pays off
/// in two places:
///
/// 1. **`None`** — a single unpadded sequence. There is no mask load and
///    no per-key compare in the inner loop.
/// 2. **`Causal`** — autoregressive decode. The kernel derives the
///    upper-triangular fill directly from `(qi, ki)`, so a `seq²` mask
///    tensor never exists.
///
/// `Custom` keeps the pre-existing behaviour of reading mask values from
/// the 4th input.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MaskKind {
    /// Unmasked: every position attends to every position.
    None,
    /// Autoregressive: query position `qi` attends only to keys `ki <= qi`.
    Causal,
    /// Sliding window parameterised by `w`: query position `qi` attends to
    /// keys `ki ∈ [qi - w, qi]`.
    SlidingWindow(usize),
    /// Mask values come from the input tensor, shaped `[batch, key_len]`.
    /// `1.0` marks a valid key and values `< 0.5` mark ignored keys. This
    /// is described as the default path and matches BERT's padding-mask
    /// behaviour. It is not a Rust `Default` impl; this enum does not
    /// derive `Default`.
    Custom,
    /// Additive per-head, per-query bias tensor
    /// `[batch, num_heads, query_len, key_len]` added to the
    /// `QK^T · scale` scores before softmax. This lets DETR-style boxRPB
    /// and other learned position biases stay on the fast `Op::Attention`
    /// path. Without it they would decompose into matmul + add + softmax +
    /// matmul.
    Bias,
}
187
/// Selects the forward attention input whose gradient an
/// [`Op::AttentionBackward`] node produces.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AttentionBwdWrt {
    /// Differentiate with respect to the query input.
    Query,
    /// Differentiate with respect to the key input.
    Key,
    /// Differentiate with respect to the value input.
    Value,
}
196
/// Reduction kinds. `Op::Reduce` applies one of these along a set of axes.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReduceOp {
    /// Sum of the reduced elements.
    Sum,
    /// Arithmetic mean of the reduced elements.
    Mean,
    /// Largest reduced element.
    Max,
    /// Smallest reduced element.
    Min,
    /// Product of the reduced elements.
    Prod,
}
207
/// Fieldless discriminant for each [`Op`] variant (PLAN L4).
///
/// [`Op::kind`] maps an op to its `OpKind`. Backends declare which kinds
/// they can lower through the `Backend::supported_ops` trait method. The
/// `LegalizeForBackend` pass in `rlx-opt` checks the graph against that
/// set and fails the compile when an unsupported op is present, rather
/// than silently falling back.
///
/// NOTE: append new kinds instead of reordering existing ones, so
/// variant indices stay stable for index-based serializers.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum OpKind {
    // ── graph inputs ──
    Input,
    Param,
    Constant,
    // ── element-wise & quantization ──
    Activation,
    Cast,
    Quantize,
    Dequantize,
    FakeQuantize,
    FakeQuantizeLSQ,
    FakeQuantizeLSQBackwardX,
    FakeQuantizeLSQBackwardScale,
    Binary,
    Compare,
    Where,
    ElementwiseRegion,
    // ── linear algebra ──
    MatMul,
    DotGeneral,
    DenseSolve,
    BatchedDenseSolve,
    // ── normalization / resize ──
    LayerNorm,
    LayerNorm2d,
    GroupNorm,
    RmsNorm,
    ResizeNearest2x,
    // ── attention & positional encodings ──
    Attention,
    Rope,
    AxialRope2d,
    // ── shape manipulation ──
    Reshape,
    Transpose,
    Narrow,
    Concat,
    Expand,
    Gather,
    // ── reduction, scan & sampling ──
    Reduce,
    Softmax,
    Cumsum,
    TopK,
    Sample,
    // ── convolution & pooling ──
    Conv,
    ConvTranspose2d,
    Pool,
    // ── backward passes and complex-number helpers ──
    ReluBackward,
    ActivationBackward,
    FakeQuantizeBackward,
    ComplexNormSq,
    ComplexNormSqBackward,
    Conjugate,
    MaxPool2dBackward,
    Conv2dBackwardInput,
    Conv2dBackwardWeight,
    SoftmaxCrossEntropyWithLogits,
    SoftmaxCrossEntropyBackward,
    AttentionBackward,
    LayerNormBackwardInput,
    LayerNormBackwardGamma,
    RmsNormBackwardInput,
    RmsNormBackwardGamma,
    RmsNormBackwardBeta,
    RopeBackward,
    GroupNormBackwardInput,
    GroupNormBackwardGamma,
    GroupNormBackwardBeta,
    CumsumBackward,
    GatherBackward,
    // ── matmul variants & quantized kernels ──
    GroupedMatMul,
    DequantGroupedMatMul,
    DequantMoEWeights,
    ScatterAdd,
    LoraMatMul,
    DequantMatMul,
    QMatMul,
    QConv2d,
    // ── recurrences ──
    SelectiveScan,
    GatedDeltaNet,
    // ── fused kernels emitted by optimization passes ──
    FusedSwiGLU,
    FusedMatMulBiasAct,
    FusedResidualLN,
    FusedResidualRmsNorm,
    FusedAttentionBlock,
    FusedTransformerLayer,
    // ── control flow ──
    If,
    While,
    Scan,
    ScanBackward,
    ScanBackwardXs,
    // ── Gaussian splatting ──
    /// CPU reference 3D Gaussian splat raster (project → bin → sort → raster).
    /// See [`Op::GaussianSplatRender`].
    GaussianSplatRender,
    /// Backward of [`Op::GaussianSplatRender`] — packed scene parameter gradients.
    GaussianSplatRenderBackward,
    /// Project + tile bin + sort + ray grid (strict IR splat stage 1).
    GaussianSplatPrepare,
    /// Per-pixel raster from prepared buffers (strict IR splat stage 2).
    GaussianSplatRasterize,
    // ── extension points ──
    /// User-registered op dispatched through `op_registry`. All custom ops
    /// (Sparse-LU, FFT, eigensolve, ...) share this kind; the per-op
    /// identity lives in `Op::Custom::name`.
    Custom,
    /// User-defined sub-graph with optional override AD rules. See
    /// [`Op::CustomFn`] / [`crate::Graph::custom_fn`].
    CustomFn,
    /// 1D FFT primitive (forward or inverse) — see [`Op::Fft`].
    Fft,
}
320
/// A value consumed by a [`ChainStep`].
///
/// * `Input(i)` — the `i`-th graph-level input of the enclosing
///   [`Op::ElementwiseRegion`] (indices `0..num_inputs`).
/// * `Step(j)` — the result of an earlier step in the same chain
///   (indices `0..step_position`).
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ChainOperand {
    /// Region input, by index.
    Input(u32),
    /// Prior chain step, by index.
    Step(u32),
}
330
331/// One step in a fused element-wise chain. Each step produces exactly
332/// one scalar result (per element); later steps can refer to it via
333/// [`ChainOperand::Step`]. The whole chain runs per element in registers.
334#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
335#[derive(Debug, Clone, PartialEq)]
336pub enum ChainStep {
337    Activation(Activation, ChainOperand),
338    Cast(DType, ChainOperand),
339    Binary(BinaryOp, ChainOperand, ChainOperand),
340    Compare(CmpOp, ChainOperand, ChainOperand),
341    /// 3-input element-wise select: `cond ? on_true : on_false`. Mirrors
342    /// `Op::Where` inside a chain. `cond` is treated as truthy iff
343    /// non-zero. Lets the optimizer fold attention masks / clamp-style
344    /// patterns into a single region kernel instead of breaking the
345    /// chain at the first `Op::Where`.
346    Where(ChainOperand, ChainOperand, ChainOperand),
347}
348
349/// An operation in the RLX IR graph.
350///
351/// Operations are categorized for fusion analysis:
352/// - Element-wise ops fuse with anything reading their output
353/// - Matmul/Conv are BLAS-dispatched and form fusion boundaries
354/// - Reductions are fusion roots (drive the loop iteration)
355#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
356#[derive(Debug, Clone, PartialEq)]
357pub enum Op {
358    // ── Graph inputs ────────────────────────────────────────────
359    /// Model input with a name (shape on the Node).
360    Input {
361        name: String,
362    },
363
364    /// Model parameter (weight/bias) with a name.
365    Param {
366        name: String,
367    },
368
369    /// Constant tensor embedded in the graph.
370    Constant {
371        data: Vec<u8>,
372    },
373
374    // ── Element-wise unary ──────────────────────────────────────
375    /// Unary activation: one input, same shape output.
376    Activation(Activation),
377
378    /// Cast to a different dtype.
379    Cast {
380        to: DType,
381    },
382
383    /// INT8 quantization. Input f32; output i8 same shape.
384    ///   `q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])`
385    /// where `c` selects the per-channel scale/zp when `axis = Some(d)`
386    /// (`c = idx[d]`), or always uses index 0 when `axis = None`
387    /// (per-tensor). The `scales` / `zero_points` payload length must
388    /// match `1` for per-tensor and `input.dim(d)` for per-channel.
389    /// Static — typically produced at calibration time and baked
390    /// into the loaded model. Use `Op::Dequantize` for the inverse.
391    Quantize {
392        axis: Option<usize>,
393        scales: Vec<f32>,
394        zero_points: Vec<i32>,
395    },
396
397    /// INT8 dequantization (inverse of `Op::Quantize`). Input i8;
398    /// output f32 same shape.
399    ///   `x[i] = (q[i] - zero_point[c]) · scale[c]`
400    /// where `c` is selected by `axis` exactly as in `Op::Quantize`.
401    Dequantize {
402        axis: Option<usize>,
403        scales: Vec<f32>,
404        zero_points: Vec<i32>,
405    },
406
407    /// "Fake-quantize" op for **quantization-aware training** (QAT).
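    /// Input f32; output f32 same shape. With the default
    /// `scale_mode = ScaleMode::PerBatch`, forward computes a per-axis
    /// (or per-tensor when `axis = None`) max-abs scale on the fly.
    /// `ScaleMode::EMA` and `ScaleMode::Fixed` instead use a running scale
    /// held in a state tensor passed as the second input; see `ScaleMode`.
    /// The per-batch scale is: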
410    ///   `s[c] = max(|x[..., c, ...]|) / q_max(bits)`
411    /// then quantizes-then-dequantizes:
412    ///   `out[i] = clamp(round(x[i] / s[c]), -q_max, q_max) * s[c]`
413    /// where `q_max` is `127` for `bits=8`, `7` for `bits=4`, `1` for
414    /// `bits=2` (ternary). Symmetric only — zero-point is always 0.
415    ///
416    /// The point of this op is to make the SGD optimizer "see" the
417    /// deployment-time rounding during training. Backward is the
418    /// **straight-through estimator** (STE): the gradient passes
419    /// through (variant chosen by `ste`), ignoring the discontinuity
420    /// at the round. Without STE the rounding would have zero
421    /// gradient almost everywhere and learning would stop.
422    ///
423    /// Inserted by the trainer on conv / FC weight tensors when
424    /// `--qat` is on; the existing `Op::Quantize` / packing path at
425    /// the end of training still handles the deployment-side
426    /// conversion to `i8`/`i4`/`i2` codes.
427    FakeQuantize {
428        bits: u8,
429        axis: Option<usize>,
430        ste: SteKind,
431        scale_mode: ScaleMode,
432    },
433
434    /// Learned Step Size Quantization (LSQ; Esser et al. 2020,
435    /// `arXiv:1902.08153`). Like `FakeQuantize` but the per-channel
436    /// `scale` is a *learned parameter*, passed as the second input.
437    /// Forward is identical to `FakeQuantize` with a fixed scale:
438    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
439    /// Backward computes both `dx` (STE) and `dscale[c]` via the
440    /// closed-form gradient:
441    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
442    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
443    /// `sign(z) · q_max`. Routinely beats per-batch and EMA at
444    /// tight bit widths (i2 / i3).
445    ///
446    /// Inputs: `[x, scale]`. `scale` is `[chan_dim]` f32 (matches
447    /// `axis`); for `axis = None` it's `[1]`.
448    FakeQuantizeLSQ {
449        bits: u8,
450        axis: Option<usize>,
451    },
452
    /// Backward pass for `Op::FakeQuantizeLSQ` with respect to `x`
    /// (straight-through estimator). Only `dx` is produced here. The
    /// closed-form per-channel `scale` gradient comes from the companion
    /// op `Op::FakeQuantizeLSQBackwardScale`.
    /// Inputs: `[x, scale, dy]`. Output: `dx`, same shape as `x`.
458    FakeQuantizeLSQBackwardX {
459        bits: u8,
460        axis: Option<usize>,
461    },
462
463    /// Companion to `FakeQuantizeLSQBackwardX`: computes the
464    /// `[chan_dim]` per-channel scale gradient. Inputs `[x, scale, dy]`.
465    /// Output shape matches `scale`.
466    FakeQuantizeLSQBackwardScale {
467        bits: u8,
468        axis: Option<usize>,
469    },
470
471    // ── Element-wise binary ─────────────────────────────────────
472    /// Binary op with broadcasting: two inputs, output shape is broadcast result.
473    Binary(BinaryOp),
474
475    // ── Comparison ──────────────────────────────────────────────
476    /// Element-wise comparison: two inputs, Bool output.
477    Compare(CmpOp),
478
479    /// Select elements: cond (Bool), on_true, on_false → output.
480    Where,
481
482    /// Fused element-wise region (PLAN L2). Holds an N-step chain of
483    /// element-wise operations. Inputs are referenced by index 0..num_inputs;
484    /// each step's result can be referenced by later steps via
485    /// `ChainOperand::Step(idx)`. The output is the last step's result.
486    /// Emitted by `MarkElementwiseRegions` in `rlx-opt` from chains of
487    /// Activation/Cast/Binary/Compare/Where ops with single-consumer
488    /// intermediates and broadcast-compatible shapes. Backends that
489    /// don't have a region kernel can decompose back to the original
490    /// chain via `unfuse::unfuse_elementwise_regions`.
491    ///
492    /// `scalar_input_mask` is a per-input bitfield (bit `i` set ⇒
493    /// input `i` is a scalar broadcast — has shape `[1]`). Kept as a
494    /// fast-path indicator that lets kernels skip the modulo entirely
495    /// when they detect a scalar.
496    ///
497    /// `input_modulus[i]` is the per-input element count, used by
498    /// kernels to compute `arena[input_offs[i] + (gid % input_modulus[i])]`
499    /// — the trailing-shape broadcast pattern. `0` means "no broadcast"
500    /// (input matches the output element count; kernel reads `gid`
501    /// directly). `1` means scalar; any other value means the input
502    /// has fewer elements than the output and they tile by modulo.
503    /// The encoder only allows broadcasts where `out_elems % in_elems
504    /// == 0` so the modulo divides cleanly. Lets chains include bias /
505    /// scale / eps / mask factors that previously broke the chain at
506    /// a Binary op with mismatched shapes.
507    ElementwiseRegion {
508        chain: Vec<ChainStep>,
509        num_inputs: u32,
510        scalar_input_mask: u32,
511        input_modulus: [u32; 16],
512    },
513
514    // ── Linear algebra ──────────────────────────────────────────
515    /// Matrix multiply. Inputs: [.., M, K] × [.., K, N] → [.., M, N].
516    /// Batch dimensions are broadcast.
517    MatMul,
518
519    /// Matrix multiply with explicit dimension specification.
520    /// Like XLA's DotGeneral — handles arbitrary batch/contracting dims.
521    DotGeneral {
522        lhs_contracting: Vec<usize>,
523        rhs_contracting: Vec<usize>,
524        lhs_batch: Vec<usize>,
525        rhs_batch: Vec<usize>,
526    },
527
528    /// Batched dense linear solve. Inputs: `A [B, N, N]`,
529    /// `b [B, N]` or `b [B, N, K]`. Output: same shape as `b`.
530    ///
531    /// Per-batch independent solve — each `A[i]` and `b[i]` are
532    /// solved as a separate `Op::DenseSolve`. Emitted by vmap of
533    /// `Op::DenseSolve`. The CPU lowering loops over the batch
534    /// dimension calling `dgesv` per slice (LAPACK doesn't expose a
535    /// batched solve on Accelerate; cuSOLVER does on NVIDIA).
536    BatchedDenseSolve,
537
538    /// Dense linear solve `x = A⁻¹ · b` via LU factorization.
539    /// Inputs: `A [N, N]`, `b [N]` (or `b [N, K]` for multi-RHS).
540    /// Output: same shape as `b`.
541    ///
542    /// VJP via the implicit-function theorem:
543    ///   `dx = solve(Aᵀ, upstream)`
544    ///   `dA = -outer(dx, x)`   (x is the forward output)
545    ///   `db = dx`
546    /// The rule is dtype-agnostic; lowering is per-backend (Accelerate
547    /// `dgesv` / `sgesv`, cuSOLVER, etc.).
548    DenseSolve,
549
550    // ── Normalization ───────────────────────────────────────────
551    /// Layer normalization: input, gamma, beta → normalized output.
552    /// `axis` is the feature dimension (usually -1).
553    LayerNorm {
554        axis: i32,
555        eps: f32,
556    },
557
558    /// Group normalization on NCHW tensors: `input`, `gamma`, `beta` → same shape.
559    /// Normalizes over `(C/num_groups) × H × W` per group.
560    GroupNorm {
561        num_groups: usize,
562        eps: f32,
563    },
564
565    /// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
566    /// position (candle / SAM `LayerNorm2d` semantics — not PyTorch's H×W norm).
567    LayerNorm2d {
568        eps: f32,
569    },
570
571    /// Nearest-neighbor 2× upsample on NCHW (doubles spatial dims 2 and 3).
572    ResizeNearest2x,
573
574    /// RMS normalization: input, gamma → normalized output.
575    RmsNorm {
576        axis: i32,
577        eps: f32,
578    },
579
580    // ── Attention ───────────────────────────────────────────────
581    /// Scaled dot-product attention: Q, K, V, \[mask\] → output.
582    /// The compiler can lower this to fused SDPA or flash attention.
    /// `mask_kind` controls how masking is applied. `Custom` reads a
    /// padding mask and `Bias` reads an additive score bias from the 4th
    /// input tensor. `None` / `Causal` / `SlidingWindow` skip the mask load
    /// and apply the mask directly in the inner loop. See `MaskKind` for
    /// the rationale.
587    Attention {
588        num_heads: usize,
589        head_dim: usize,
590        mask_kind: MaskKind,
591    },
592
593    /// Rotary position embedding applied to one tensor: x, cos, sin → x_rotated.
594    /// Apply separately to Q and K. `head_dim` is the per-head width; `n_rot`
595    /// is how many leading dims get NeoX RoPE (pair offset `n_rot/2`). When
596    /// `n_rot < head_dim`, trailing dims are copied unchanged (Qwen3.5 MRoPE).
597    Rope {
598        head_dim: usize,
599        n_rot: usize,
600    },
601
602    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
603    AxialRope2d {
604        end_x: usize,
605        end_y: usize,
606        head_dim: usize,
607        num_heads: usize,
608        theta: f32,
609        repeat_factor: usize,
610    },
611
612    // ── Shape manipulation ──────────────────────────────────────
613    Reshape {
614        new_shape: Vec<i64>,
615    },
616    Transpose {
617        perm: Vec<usize>,
618    },
619    /// Select a contiguous slice along an axis.
620    Narrow {
621        axis: usize,
622        start: usize,
623        len: usize,
624    },
625    /// Concatenate along an axis.
626    Concat {
627        axis: usize,
628    },
629    /// Expand (broadcast) to a target shape.
630    Expand {
631        target_shape: Vec<i64>,
632    },
633    /// Gather elements by index along an axis (embedding lookup).
634    Gather {
635        axis: usize,
636    },
637
638    // ── Reduction ───────────────────────────────────────────────
639    /// Reduce along specified axes.
640    Reduce {
641        op: ReduceOp,
642        axes: Vec<usize>,
643        keep_dim: bool,
644    },
645
646    /// Selective scan (plan #15) — Mamba-style state-space model
647    /// step. The recurrence:
648    ///   `h[t] = exp(Δ[t] * A) * h[t-1] + Δ[t] * B[t] * x[t]`
649    ///   `y[t] = C[t] * h[t]`
650    /// where state `h` has dimension `state_size` and the input has
651    /// `(batch, seq, hidden)`.
652    ///
653    /// Inputs (in order):
654    ///   `x [b, s, h]`      f32 input
655    ///   `delta [b, s, h]`  f32 step size (per-position, per-channel)
656    ///   `a [h, n]`         f32 transition matrix (one per channel)
657    ///   `b [b, s, n]`      f32 input projection
658    ///   `c [b, s, n]`      f32 output projection
659    /// Output: `[b, s, h]` f32. State `h` is implicit; the kernel
660    /// scans through the seq dimension carrying it.
661    ///
662    /// `state_size` = `n` is exposed for the cost model.
663    SelectiveScan {
664        state_size: usize,
665    },
666
667    /// Gated DeltaNet linear-attention recurrence — the per-layer
668    /// kernel used by Qwen3.5/3.6 trunk "linear attention" blocks
669    /// (and Qwen3-Next, Kimi-Linear). Mirrors
670    /// `llama.cpp / src/models/delta-net-base.cpp` autoregressive
671    /// path; chunked + fused variants ride the same op identity.
672    ///
673    /// **Math (per token `t`, head `h`, state size `n`):**
674    /// state matrix `S[h, i, j]` is implicit (reset per batch).
675    /// ```text
676    ///   S[h]     *= exp(g[t,h])                     # scalar gate
677    ///   sk[h,j]   = Σ_i S[h,i,j] * k[t,h,i]
678    ///   d[h,j]    = (v[t,h,j] - sk[h,j]) * b[t,h]   # b = beta
679    ///   S[h,i,j] += k[t,h,i] * d[h,j]               # outer-prod
680    ///   o[t,h,j]  = Σ_i S[h,i,j] * (q[t,h,i] / √n)
681    /// ```
682    ///
683    /// Inputs:
684    ///   `q     [b, s, h_v, n]`  f32 queries (L2-normed by caller)
685    ///   `k     [b, s, h_v, n]`  f32 keys    (L2-normed by caller;
686    ///                            GQA-repeated to match `h_v`)
687    ///   `v     [b, s, h_v, n]`  f32 values
688    ///   `g     [b, s, h_v]`     f32 log-gate (exp'd inside kernel)
689    ///   `beta  [b, s, h_v]`     f32 delta-rule mixing factor
690    ///
691    /// Output: `[b, s, h_v, n]` f32.
692    ///
693    /// When `carry_state` is true, a sixth input `state [b, h_v, n, n]`
694    /// provides the initial SSM matrix per head; the kernel updates it
695    /// in place across the sequence and leaves the final state in the
696    /// same buffer (same layout as the internal scan state:
697    /// `state[h, i, j]` row-major over `(n, n)` per head).
698    GatedDeltaNet {
699        state_size: usize,
700        carry_state: bool,
701    },
702
703    /// Fused dequant + matmul (plan #5). The biggest LLM-bandwidth
704    /// win on Apple Silicon: dequantizes weights inside the matmul
705    /// inner loop, never materializing f32 weights.
706    ///
707    /// **BREAKING CHANGE in 0.2.0:** `num_inputs()` is now
708    /// scheme-dependent — **4** for legacy Int8 schemes, **2** for
709    /// the new GGUF K-quant schemes (their scales/mins live inside
710    /// the packed bytes, so no side-channel `scale` / `zp` tensors
711    /// are fed in). Callers that assumed a fixed 4-input contract
712    /// must inspect `scheme.is_gguf()` before reading inputs.
713    ///
714    /// Inputs (Int8 schemes — `scheme.is_gguf() == false`):
715    ///   `x [m, k]`             f32 activations
716    ///   `w_q [k, n]` packed    quantized weight bytes (i8 per
717    ///                          element for Int8 schemes; 4-bit
718    ///                          packed two-per-byte for Int4)
719    ///   `scale [k/block, n]`   per-block f32 dequant scale
720    ///   `zp    [k/block, n]`   per-block f32 zero-point
721    ///                          (zero-tensor if symmetric)
722    ///
723    /// Inputs (`Nvfp4Block` — fixed group size 16 along K):
724    ///   `x [m, k]`             f32 activations
725    ///   `w_q [k,n/2]` packed   FP4 E2M1 codes (unsigned nibble 0..15)
726    ///   `scale [k/16, n]` u8   FP8 E4M3 block scales (one byte / group)
727    ///   `global_scale [1]` f32 per-tensor scale (pass `[1.0]` if unused)
728    ///
729    /// Inputs (GGUF schemes — `scheme.is_gguf() == true`):
730    ///   `x [m, k]`             f32 activations
731    ///   `packed_w [bytes]`     raw GGUF super-block bytes; the
732    ///                          dequantizer reads the per-sub-block
733    ///                          scales / mins / quants directly out
734    ///                          of the buffer per the K-quant block
735    ///                          layout (no side tensors).
736    ///
737    /// Output: `[m, n]` f32.
738    ///
739    /// `block_size` (on the Int8 schemes only) is the number of
740    /// consecutive elements that share one (scale, zero_point) pair.
741    /// The Op carries enough metadata that the kernel doesn't need
742    /// a separate `QuantMap` lookup at run time.
743    DequantMatMul {
744        scheme: crate::quant::QuantScheme,
745    },
746
747    /// Real INT8-arithmetic matrix multiply with i32 accumulation.
748    /// Inputs (in order):
749    ///   `x      [M, K]`  i8 activations (zero-point = `x_zp`)
750    ///   `w      [K, N]`  i8 weights     (zero-point = `w_zp`)
751    ///   `bias   [N]`     i32 (in accumulator scale = `x_scale·w_scale`),
752    ///                    pass a zeros tensor for "no bias"
753    /// Output:  `[M, N]`  i8 (zero-point = `out_zp`)
754    ///
755    /// Per-element compute:
756    ///   `out[m,n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
757    /// where `mult = x_scale · w_scale / out_scale`.
758    ///
759    /// This is the same kernel shape `rlx-cortexm/src/dense.rs`
760    /// uses for on-device int8 inference, lifted into the IR so the
761    /// rlx-cpu backend can run a quantized graph directly (instead
762    /// of round-tripping through fake-quant Dequantize → MatMul →
763    /// Quantize). 2-D only — generalizing to batched comes when a
764    /// real workload demands it.
765    QMatMul {
766        x_zp: i32,
767        w_zp: i32,
768        out_zp: i32,
769        mult: f32,
770    },
771
772    /// Real INT8-arithmetic 2-D convolution with i32 accumulation.
773    /// Inputs:
774    ///   `x      [N, C_in, H, W]`              i8 (zero-point = `x_zp`)
775    ///   `w      [C_out, C_in/groups, kH, kW]` i8 (zero-point = `w_zp`)
776    ///   `bias   [C_out]`                      i32 in accumulator scale
777    /// Output: `[N, C_out, H_out, W_out]` i8 (zero-point = `out_zp`).
778    /// Same NCHW geometry contract as `Op::Conv`; same requantize
779    /// math as `Op::QMatMul` (per-element `acc·mult` rounded to i8).
780    QConv2d {
781        kernel_size: Vec<usize>,
782        stride: Vec<usize>,
783        padding: Vec<usize>,
784        dilation: Vec<usize>,
785        groups: usize,
786        x_zp: i32,
787        w_zp: i32,
788        out_zp: i32,
789        mult: f32,
790    },
791
792    /// Fused LoRA matmul: `out = x·W + scale * x·A·B`.
793    /// Inputs (in order): `x [m, k]`, `w [k, n]`, `a [k, r]`, `b [r, n]`.
794    /// `r` is the LoRA rank (typically 4-64). `scale` is the
795    /// per-adapter `alpha / rank` knob.
796    /// Plan #9: lifts LoRA from "three matmuls + an add" into one
797    /// kernel that keeps the rank-r intermediate in registers.
798    LoraMatMul {
799        scale: f32,
800    },
801
802    /// Fused sampling kernel: logits → optional top-k filter →
803    /// optional top-p truncation → softmax → multinomial sample.
804    /// One f32-encoded sampled token id per batch row (output
805    /// shape `[batch]`).
806    ///
807    /// `temperature == 1.0` matches a plain argmax-of-softmax;
808    /// lower → sharper, higher → flatter. `top_k == 0` disables.
809    /// `top_p == 1.0` disables. `seed` is the Philox seed; pass 0
810    /// for "use process-global counter" (still deterministic
811    /// given the call order).
812    /// Borrowed from MAX's nn/sampling.mojo (#42 in PLAN.md).
813    /// Latency-critical: never materializes the full softmax
814    /// distribution on the host.
815    Sample {
816        top_k: usize,     // 0 = disabled
817        top_p: f32,       // 1.0 = disabled
818        temperature: f32, // 1.0 = neutral
819        seed: u64,        // 0 = use thread-local counter
820    },
821
822    /// Inclusive cumulative sum along an axis. Same shape in/out.
823    /// Underpins ragged-tensor offsets, sampling (top-p prefix sum),
824    /// and sequence-position math (#44 in PLAN.md).
825    /// `exclusive=true` shifts the result so output\[0\] = 0 (useful
826    /// for offset arrays where the first segment starts at 0).
827    Cumsum {
828        axis: i32,
829        exclusive: bool,
830    },
831
832    /// Softmax along an axis (reduction + element-wise).
833    Softmax {
834        axis: i32,
835    },
836
837    /// Top-K **indices** along the last axis. Output shape `[..., k]`,
838    /// f32-encoded indices (rlx is f32-only at the I/O boundary).
839    /// To recover the values, follow with a `Gather` against the
840    /// original tensor — works because Gather already supports any axis.
841    /// Ties broken by smaller index (matches NumPy / PyTorch
842    /// `torch.topk(..., largest=True, sorted=True)`).
843    /// Used by MoE gating; also useful for beam search.
844    TopK {
845        k: usize,
846    },
847
848    /// Indexed batched matmul. The MoE GEMM primitive.
849    /// Inputs: `[input, weight, expert_idx]`
850    ///   input       : [M, K]                — per-token activations
851    ///   weight      : [num_experts, K, N]   — stacked expert weights
852    ///   expert_idx  : \[M\]                   — f32-encoded expert id per token
853    /// Output         : [M, N]                — output\[i\] = input\[i\] @ weight[expert_idx\[i\]]
854    /// Naive impl on both backends; future work can replace with a
855    /// segmented/grouped GEMM when there's a real workload.
856    GroupedMatMul,
857
858    /// Fused GGUF K-quant dequant + [`Op::GroupedMatMul`]. Same three
859    /// inputs as `GroupedMatMul`, but `weight` is a U8 tensor holding
860    /// `num_experts` contiguous packed expert slabs (GGML layout, expert
861    /// dimension outermost). Scales live inside the packed bytes.
862    DequantGroupedMatMul {
863        scheme: crate::quant::QuantScheme,
864    },
865
866    /// Dequant a packed MoE expert stack to F32 `[num_experts, K, N]` in
867    /// GroupedMatMul layout. Input: U8 packed bytes; output shape is
868    /// declared on the node (`[E, K, N]`).
869    DequantMoEWeights {
870        scheme: crate::quant::QuantScheme,
871    },
872
873    /// Scatter-add into a destination tensor. The "unpermute" half of
874    /// MoE routing (also useful for embedding gradient updates).
875    /// Inputs: `[updates, indices]`
876    ///   updates : [num_updates, trailing]   — values to add
877    ///   indices : \[num_updates\]             — f32-encoded destination row
878    /// Output    : [out_dim, trailing]       — output[indices\[i\]] += updates\[i\]
879    /// `out_dim` is taken from the node's declared output shape.
880    /// Initial output is zero; multiple updates to the same row
881    /// accumulate (sequentially on CPU; with atomic-add on Metal).
882    ScatterAdd,
883
884    // ── Convolution ─────────────────────────────────────────────
885    /// 2D convolution on NCHW tensors. Also exposed as [`OpKind::Conv`] / `conv2d`.
886    /// Weight layout: `[C_out, C_in / groups, kH, kW]`.
887    Conv {
888        kernel_size: Vec<usize>,
889        stride: Vec<usize>,
890        padding: Vec<usize>,
891        dilation: Vec<usize>,
892        groups: usize,
893    },
894
895    /// 2D transposed convolution on NCHW. Weight layout (PyTorch):
896    /// `[C_in, C_out / groups, kH, kW]`.
897    ConvTranspose2d {
898        kernel_size: Vec<usize>,
899        stride: Vec<usize>,
900        padding: Vec<usize>,
901        dilation: Vec<usize>,
902        output_padding: Vec<usize>,
903        groups: usize,
904    },
905
906    // ── Pooling ─────────────────────────────────────────────────
907    Pool {
908        kind: ReduceOp,
909        kernel_size: Vec<usize>,
910        stride: Vec<usize>,
911        padding: Vec<usize>,
912    },
913
914    // ── Backward / training ops ─────────────────────────────────
915    //
916    // Closed-form gradient nodes emitted by `rlx-opt::autodiff`.
917    // Pairing a forward op with a dedicated backward op (instead of
918    // composing it from primitives) keeps the gradient kernel simple
919    // and lets the backend recompute argmax / masks / softmax inline.
920    /// ReLU backward: `dx = dy where x > 0 else 0`.
921    /// Inputs: `[x, dy]` — both same shape; output matches.
922    ReluBackward,
923
924    /// Element-wise complex squared-magnitude: `|z|² = z.re² + z.im²`.
925    /// Input: 1 tensor with `DType::C64`. Output: same shape but
926    /// `DType::F32`. The natural real-valued loss surface for
927    /// Wirtinger reverse-mode AD on complex graphs — pair with
928    /// [`Op::ComplexNormSqBackward`].
929    ComplexNormSq,
930
931    /// Element-wise complex conjugate: `z̄ = z.re - i·z.im` per element.
932    /// Input: 1 tensor with `DType::C64`. Output: same shape, same dtype.
933    /// Used by Wirtinger VJP rules on `Op::Binary` over C64 (the rule
934    /// for `y = a·b` is `dL/dā = upstream · conj(b)`, etc.).
935    Conjugate,
936
937    /// Backward for [`Op::ComplexNormSq`] under Wirtinger calculus.
938    /// `f(z) = |z|² = z·z̄`, so `∂f/∂z̄ = z`. Given upstream real
939    /// cotangent `g` (same shape as the forward output), the C64
940    /// gradient with respect to `z` is `g · z` element-wise, returned
941    /// in C64 storage `[re_g·re_z, re_g·im_z]` per element.
942    ///
943    /// Inputs: `[z (C64), g (F32)]` — both same logical shape; output
944    /// matches `z` (C64).
945    ComplexNormSqBackward,
946
947    /// LayerNorm backward w.r.t. the input. Computes
948    ///   `d_x[..., d] = inv_std · (dy·γ - mean(dy·γ) - x̂·mean(dy·γ·x̂))`
949    /// over the feature axis, where `x̂ = (x - mean)/std` is recomputed
950    /// inline from `x`. Inputs: `[x, gamma, dy]`; output shape = `x.shape`.
951    /// Currently lowers axis=-1 only (matches the forward thunk).
952    LayerNormBackwardInput {
953        axis: i32,
954        eps: f32,
955    },
956
957    /// LayerNorm backward w.r.t. gamma. Computes
958    ///   `d_gamma[d] = Σ_{batch} dy[..., d] · x̂[..., d]`
959    /// — sums the per-element product of upstream and the (recomputed)
960    /// normalized input over the leading axes. Inputs: `[x, dy]`;
961    /// output shape = `gamma.shape` (= 1-D feature axis).
962    LayerNormBackwardGamma {
963        axis: i32,
964        eps: f32,
965    },
966
967    /// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`; output = `x.shape`.
968    RmsNormBackwardInput {
969        axis: i32,
970        eps: f32,
971    },
972
973    /// RMSNorm backward w.r.t. gamma. Inputs `[x, gamma, beta, dy]`; output = `gamma.shape`.
974    RmsNormBackwardGamma {
975        axis: i32,
976        eps: f32,
977    },
978
979    /// RMSNorm backward w.r.t. beta. Inputs `[x, gamma, beta, dy]`; output = `beta.shape`.
980    RmsNormBackwardBeta {
981        axis: i32,
982        eps: f32,
983    },
984
985    /// RoPE backward w.r.t. `x`. Inputs `[dy, cos, sin]`; output = `dy.shape`.
986    RopeBackward {
987        head_dim: usize,
988        n_rot: usize,
989    },
990
991    /// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
992    GroupNormBackwardInput {
993        num_groups: usize,
994        eps: f32,
995    },
996
997    /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`; output = `gamma.shape`.
998    GroupNormBackwardGamma {
999        num_groups: usize,
1000        eps: f32,
1001    },
1002
1003    /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`; output = `beta.shape`.
1004    GroupNormBackwardBeta {
1005        num_groups: usize,
1006        eps: f32,
1007    },
1008
1009    /// Cumsum backward along `axis`. Inputs `[dy]`; output matches forward input shape.
1010    CumsumBackward {
1011        axis: i32,
1012        exclusive: bool,
1013    },
1014
1015    /// Gather backward (scatter-add into table). Inputs `[dy, indices]`; output = table shape.
1016    /// `axis` matches forward [`Op::Gather`].
1017    GatherBackward {
1018        axis: i32,
1019    },
1020
1021    /// Generic element-wise activation backward. `kind` selects the
1022    /// closed-form derivative `d/dx act(x)`. Inputs: `[x, dy]`; output
1023    /// shape matches `x`. The kernel computes `d/dx · dy` per element.
1024    ///
1025    /// Closed forms (all element-wise):
1026    /// * `Gelu`     — exact derivative of erf-based GELU.
1027    /// * `GeluApprox` — derivative of the tanh approximation
1028    ///   `0.5 x (1 + tanh(√(2/π)(x + 0.044715 x³)))`.
1029    /// * `Silu`     — `σ(x)·(1 + x·(1 - σ(x)))`.
1030    /// * `Sigmoid`  — `σ(x)·(1 - σ(x))`.
1031    /// * `Tanh`     — `1 - tanh(x)²`.
1032    /// * `Exp`      — `exp(x)`.
1033    /// * `Log`      — `1 / x`.
1034    /// * `Sqrt`     — `0.5 / sqrt(x)`.
1035    /// * `Rsqrt`    — `-0.5 · x^(-3/2)`.
1036    /// * `Neg`      — `-1`.
1037    /// * `Abs`      — `sign(x)` (zero at x=0).
1038    /// * `Sin`      — `cos(x)`.
1039    /// * `Cos`      — `-sin(x)`.
1040    /// * `Tan`      — `1 + tan²(x)`.
1041    /// * `Atan`     — `1 / (1 + x²)`.
1042    /// * `Relu`     — kept here for completeness; the dedicated
1043    ///   `ReluBackward` op is preferred for relu and is what the
1044    ///   autodiff pass actually emits.
1045    ActivationBackward {
1046        kind: Activation,
1047    },
1048
1049    /// Backward for `Op::FakeQuantize` under a non-default STE.
1050    /// Inputs `[x, dy]`: the forward input and the upstream
1051    /// gradient. Output `dx` same shape. The `bits`/`axis`/`ste`
1052    /// fields must match the forward op so the kernel computes the
1053    /// same per-channel scale and applies the right STE attenuation.
1054    /// For `SteKind::Identity` this op is unnecessary — autodiff
1055    /// just routes `upstream` through unchanged.
1056    FakeQuantizeBackward {
1057        bits: u8,
1058        axis: Option<usize>,
1059        ste: SteKind,
1060    },
1061
1062    /// 2D max-pool backward. Routes each element of `dy` back into the
1063    /// position in `x`'s window where the forward max was taken.
1064    /// Inputs: `[x, dy]` with `x [N, C, H, W]` and
1065    /// `dy [N, C, H_out, W_out]`. Output: same shape as `x`.
1066    /// Carries the forward pool's geometry so the kernel can recompute
1067    /// the argmax position per window without a saved-indices tensor.
1068    MaxPool2dBackward {
1069        kernel_size: Vec<usize>,
1070        stride: Vec<usize>,
1071        padding: Vec<usize>,
1072    },
1073
1074    /// 2D conv backward w.r.t. input. Computes `dx = conv_transpose(dy, w)`.
1075    /// Inputs: `[dy, w]` with `dy [N, C_out, H_out, W_out]` and
1076    /// `w [C_out, C_in/groups, kH, kW]`. Output: `[N, C_in, H, W]`
1077    /// (declared on the node — caller knows the original input shape).
1078    /// Geometry is the forward conv's parameters, not the transposed
1079    /// conv's.
1080    Conv2dBackwardInput {
1081        kernel_size: Vec<usize>,
1082        stride: Vec<usize>,
1083        padding: Vec<usize>,
1084        dilation: Vec<usize>,
1085        groups: usize,
1086    },
1087
1088    /// 2D conv backward w.r.t. weight. Computes
1089    /// `dw[c_out, c_in, kh, kw] = sum_{n,h_out,w_out} x[n,c_in,...] * dy[n,c_out,h_out,w_out]`.
1090    /// Inputs: `[x, dy]`. Output: `[C_out, C_in/groups, kH, kW]`.
1091    Conv2dBackwardWeight {
1092        kernel_size: Vec<usize>,
1093        stride: Vec<usize>,
1094        padding: Vec<usize>,
1095        dilation: Vec<usize>,
1096        groups: usize,
1097    },
1098
1099    /// Fused softmax + cross-entropy loss with integer (f32-encoded)
1100    /// targets — the standard classification loss. Per-row output:
1101    /// `loss[n] = -log(softmax(logits[n])[labels[n]])`.
1102    /// Inputs: `[logits, labels]` with `logits [N, C]` and
1103    /// `labels [N]` (f32-encoded class indices). Output: `[N]`.
1104    /// Caller does the `Reduce::Mean` if they want a scalar.
1105    SoftmaxCrossEntropyWithLogits,
1106
1107    /// Backward of the fused loss above. Emits
1108    /// `dlogits[n,c] = (softmax(logits[n])[c] - one_hot(labels)[n,c]) * d_loss[n]`.
1109    /// Inputs: `[logits, labels, d_loss]`. Output: `[N, C]` (same shape
1110    /// as `logits`). Recomputes the softmax inline rather than threading
1111    /// it through from the forward node.
1112    SoftmaxCrossEntropyBackward,
1113
1114    /// Backward of [`Op::Attention`]. Recomputes scaled `QK^T`, applies
1115    /// the same `mask_kind` as the forward op, softmaxes scores, then
1116    /// emits **one** of `dQ`, `dK`, or `dV` selected by [`AttentionBwdWrt`].
1117    /// Autodiff emits three nodes (one per `wrt`) so each output shape
1118    /// stays a normal single-output MIR node.
1119    ///
1120    /// Inputs: `[q, k, v, dy]` plus optional mask when `mask_kind` is
1121    /// [`MaskKind::Custom`] or [`MaskKind::Bias`] (same convention as
1122    /// forward). Output shape matches `q`, `k`, or `v` respectively.
1123    AttentionBackward {
1124        num_heads: usize,
1125        head_dim: usize,
1126        mask_kind: MaskKind,
1127        wrt: AttentionBwdWrt,
1128    },
1129
1130    // ── Fused operations (created by optimization passes) ──────
1131    /// Fused matmul + bias + activation. Created from MatMul → Add → Activation.
1132    FusedMatMulBiasAct {
1133        activation: Option<Activation>,
1134    },
1135
1136    /// Fused residual + optional bias + layer norm.
1137    /// Created from Add(x, residual) → [Add(bias)] → LayerNorm.
1138    FusedResidualLN {
1139        has_bias: bool,
1140        eps: f32,
1141    },
1142
1143    /// Fused residual + optional bias + RMS norm.
1144    /// Created from Add(x, residual) → [Add(bias)] → RmsNorm.
1145    FusedResidualRmsNorm {
1146        has_bias: bool,
1147        eps: f32,
1148    },
1149
1150    /// Fused SwiGLU: split input into up/gate halves, silu(gate) * up.
1151    /// Created from Split → Silu → Mul when fed by a concatenated matmul.
1152    ///
1153    /// `cast_to`: optional output dtype — when `Some(dt)` the kernel casts
1154    /// its result from the input dtype to `dt` in-register, saving a
1155    /// separate cast pass. Reserved for future fp8/fp4 quantization paths;
1156    /// for f32→f16 mixed precision the AutoMixedPrecision pass already
1157    /// inserts a Cast node so this stays `None` in current pipelines.
1158    FusedSwiGLU {
1159        cast_to: Option<DType>,
1160        /// When `true`, the concatenated input stores gate in the low half
1161        /// `[..., 0..N)` and up in the high half `[..., N..2N)` — the layout
1162        /// produced when gate projection is emitted before up in the builder.
1163        /// Default `false`: up @ low, gate @ high (canonical concat order).
1164        gate_first: bool,
1165    },
1166
1167    /// Fused full transformer layer: attention block + residual+LN + FFN + residual+LN.
1168    /// All intermediates resident in registers/threadgroup memory; one kernel
1169    /// per layer instead of ~30 (the CPU's batch=1 win, lifted to IR so any
1170    /// backend can implement it as a monolithic kernel).
1171    ///
1172    /// Inputs: hidden, qkv_w, qkv_b, out_w, out_b,
1173    ///         ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask
1174    /// Output: same shape as hidden.
1175    ///
1176    /// **Backend status:** same as FusedAttentionBlock. CPU implements
1177    /// the L1-cache-resident merge at the thunk level. Metal deferred —
1178    /// requires a single MSL kernel for the whole layer to actually
1179    /// beat the unfused path. Multi-day work; revisit when there's a
1180    /// model whose Metal inference is bottlenecked here rather than on
1181    /// the wait latency floor.
1182    FusedTransformerLayer {
1183        num_heads: usize,
1184        head_dim: usize,
1185        intermediate_size: usize,
1186        eps1: f32,
1187        eps2: f32,
1188        activation: Activation,
1189        has_bias: bool,
1190    },
1191
1192    /// Fused attention block: QKV projection → split → \[RoPE\] → SDPA → output projection.
1193    /// Created by FuseAttentionBlock pass when batch*seq is small.
1194    /// All intermediates stay in L1 cache — no arena writes between ops.
1195    ///
1196    /// Inputs (in order):
1197    ///   hidden, qkv_w, out_w, mask,
1198    ///   [qkv_b, out_b]      if has_bias,
1199    ///   [rope_cos, rope_sin] if has_rope
1200    ///
1201    /// **Backend status (Phase C finalize):**
1202    ///   CPU  — implemented at the *thunk* level: the CPU schedule
1203    ///          recognizes the multi-thunk pattern and merges into
1204    ///          a single FusedAttnBlock that keeps Q/K/V in stack
1205    ///          buffers across stages (the L1-cache win).
1206    ///   Metal — **deferred**. A dispatch-wrapper version (chaining
1207    ///          existing kernels) buys nothing the unfused Metal path
1208    ///          doesn't already get, since per-run cost is dominated
1209    ///          by `wait_until_completed` (~150 µs), not encode. The
1210    ///          real win is a single MSL kernel keeping Q/K/V in
1211    ///          threadgroup memory across stages — multi-day work.
1212    ///          Until then, Metal runs the unfused chain (one matmul,
1213    ///          three narrows, two ropes, attention, one matmul) — all
1214    ///          covered in op_coverage and parity_harness.
1215    FusedAttentionBlock {
1216        num_heads: usize,
1217        head_dim: usize,
1218        has_bias: bool,
1219        has_rope: bool,
1220    },
1221
1222    // ── Control flow (subgraphs as op payloads) ─────────────────
1223    //
1224    // Status: IR is defined; helper `run_if` / `run_while` exist in
1225    // rlx-runtime/src/subgraph.rs; **executor wiring is not yet
1226    // implemented** (both CPU thunk and Metal thunk fall through to
1227    // `Thunk::Nop` for these ops). Wiring requires:
1228    //   1. Recursive subgraph compile at parent-compile time.
1229    //   2. Per-subgraph input/output binding through the arena.
1230    //   3. Schedule-level dispatch when the predicate / loop cond is
1231    //      resolved at runtime.
1232    // Estimate: 4–6 hours of focused work + parity tests. Deferred
1233    // because no current in-tree model uses these ops;
1234    // surface area without a validation target invites silent bugs.
1235    /// Conditional: pick between two subgraphs based on a boolean predicate.
1236    /// Inputs: [predicate, ...captures (used inside both branches)].
1237    /// `then_branch` and `else_branch` are sub-graphs that share the
1238    /// captured inputs and must produce identically-shaped outputs.
1239    /// Used for: shape-dependent execution, batched inference of
1240    /// dynamic-length sequences with padding masks.
1241    If {
1242        then_branch: Box<crate::Graph>,
1243        else_branch: Box<crate::Graph>,
1244    },
1245
1246    /// Loop: iterate `body` while `cond` evaluates true.
1247    /// Inputs: [...initial loop-carried values].
1248    /// `cond`'s single output is a Bool scalar.
1249    /// `body`'s outputs become the next iteration's loop-carried inputs.
1250    /// Outputs of While are the values after the final iteration.
1251    /// Used for: KV-cache-driven autoregressive generation, beam search.
1252    While {
1253        cond: Box<crate::Graph>,
1254        body: Box<crate::Graph>,
1255        max_iterations: Option<usize>,
1256    },
1257
1258    /// Bounded-length loop with a fixed-shape carry, optional per-step
1259    /// inputs, and optional stacked output. Mirrors JAX's `lax.scan`.
1260    ///
1261    /// Body signature: `(carry, x_t_0, ..., x_t_{num_xs-1}) → carry_next`
1262    /// — `1 + num_xs` Op::Inputs in NodeId construction order (first
1263    /// declared is the carry; the remaining `num_xs` are per-step
1264    /// slices). Single output (the next carry).
1265    ///
1266    /// Outer Op::Scan inputs (in order):
1267    ///   `[init_carry, xs_0, xs_1, ..., xs_{num_xs-1}]`
1268    /// Each `xs_i` has shape `[length, *per_step_shape_i]`; the body
1269    /// sees `xs_i[t]` (a `per_step_shape_i` slice) on iteration `t`.
1270    ///
1271    /// Outer Op::Scan output:
1272    ///   * `save_trajectory == false` — final carry, shape `*carry`.
1273    ///   * `save_trajectory == true`  — stacked trajectory of carries,
1274    ///     shape `[length, *carry]`. Row `t` is the carry after step
1275    ///     `t+1`, so row `length-1` matches the no-trajectory case.
1276    ///
1277    /// Mirrors JAX's `lax.scan`. Common uses include time-stepping
1278    /// integrators with time-varying drives, Mamba-style SSM scans
1279    /// reading per-step inputs, and RNN-style sequence processing.
1280    Scan {
1281        body: Box<crate::Graph>,
1282        length: u32,
1283        save_trajectory: bool,
1284        /// Number of "broadcast" inputs — values that are constant
1285        /// across iterations. Outer scan inputs in order:
1286        ///   `[init, bcast_0..bcast_{B-1}, xs_0..xs_{X-1}]`
1287        /// Body Op::Inputs in NodeId order:
1288        ///   `[carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]`
1289        /// CPU executor fills bcast slots ONCE before the iteration
1290        /// loop (xs slots are filled per-step). The reverse-mode AD
1291        /// pre-pass materialises each bcast into an xs of shape
1292        /// `[length, *bcast]` via broadcast `Mul` so the rest of the
1293        /// VJP / executor pipeline can stay unchanged. `0` (default)
1294        /// keeps the original carry+xs scan shape.
1295        num_bcast: u32,
1296        /// Number of per-step `xs` inputs. Total outer Op::Scan
1297        /// inputs is `1 + num_bcast + num_xs`.
1298        num_xs: u32,
1299        /// Number of trajectory checkpoints when `save_trajectory ==
1300        /// true`. `0` means "save all `length` rows" (default). A
1301        /// positive value `K` means save only `K` evenly-spaced rows
1302        /// at indices `floor(t * length / K)` for `t in 0..K`. Used
1303        /// by recursive checkpointed AD: store O(√T) carries during
1304        /// forward, recompute the rest in the backward pass.
1305        ///
1306        /// When `0` (or `K == length`), the saved trajectory has
1307        /// shape `[length, *carry]` — same as the original behavior.
1308        /// When `0 < K < length`, the saved trajectory has shape
1309        /// `[K, *carry]`.
1310        num_checkpoints: u32,
1311    },
1312
1313    /// Reverse-mode AD companion to `Op::Scan` — extracts the carry
1314    /// gradient `dinit`. Walks `t = length-1 .. 0`, applying `body_vjp`
1315    /// to thread `dcarry` back through the time loop.
1316    ///
1317    /// Inputs (in order):
1318    ///   `[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]`
1319    /// Output: `dinit`, shape = carry shape.
1320    ///
1321    /// `body_vjp` is the result of
1322    /// `autodiff::grad(body, [carry_id, xs_0_id, ..., xs_{num_xs-1}_id])`
1323    /// — a graph with `1 + num_xs + 1` Inputs (carry + x_t_i for each
1324    /// xs + `"d_output"`) and `1 + num_xs` outputs
1325    /// (dcarry + dx_t_i for each xs). This op reads `outputs[0]` =
1326    /// dcarry; the sibling [`Self::ScanBackwardXs`] reads the
1327    /// `outputs[1 + xs_idx]` slot for each xs gradient.
1328    ScanBackward {
1329        body_vjp: Box<crate::Graph>,
1330        length: u32,
1331        save_trajectory: bool,
1332        num_xs: u32,
1333        /// When `0` or equal to `length`, the trajectory input has
1334        /// shape `[length, *carry]` — every step's carry is cached
1335        /// (`CheckpointStrategy::All`). When `0 < K < length`, the
1336        /// trajectory input has shape `[K, *carry]` and the executor
1337        /// recomputes intermediate carries via `forward_body` between
1338        /// checkpoints. `forward_body` must be `Some` whenever this
1339        /// is < length.
1340        num_checkpoints: u32,
1341        /// Forward body (the same `body` from the forward Op::Scan).
1342        /// Required when `num_checkpoints > 0 && < length` so the
1343        /// executor can recompute carries between saved checkpoints.
1344        /// `None` for the All strategy (no recompute needed).
1345        forward_body: Option<Box<crate::Graph>>,
1346    },
1347
1348    /// Companion to [`Self::ScanBackward`] that extracts one stacked
1349    /// per-step `dxs_i` (shape `[length, *per_step_xs_i]`). Same inputs
1350    /// and same `body_vjp` graph as ScanBackward — `xs_idx` selects
1351    /// which body_vjp output to stack into the result.
1352    ///
1353    /// Note: each ScanBackwardXs runs its own backward loop. A future
1354    /// optimization can fuse them into a single multi-output backward
1355    /// kernel; for now it's `1 + num_xs` independent sweeps.
1356    ScanBackwardXs {
1357        body_vjp: Box<crate::Graph>,
1358        length: u32,
1359        save_trajectory: bool,
1360        num_xs: u32,
1361        xs_idx: u32,
1362        num_checkpoints: u32,
1363        forward_body: Option<Box<crate::Graph>>,
1364    },
1365
1366    /// CPU reference 3D Gaussian splat forward render.
1367    ///
1368    /// Seven flat F32 inputs (scene buffers + camera/render meta):
1369    ///   0. positions `[N*3]`
1370    ///   1. scales `[N*3]` (log-space)
1371    ///   2. rotations `[N*4]` (xyzw)
1372    ///   3. opacities `[N]` (logit)
1373    ///   4. colors `[N*3]` (linear RGB)
1374    ///   5. sh_coeffs `[N * sh_coeff_count * 3]`
1375    ///   6. meta `[23]` — camera position/target/up/fov/near/far, background RGB,
1376    ///      then width/height/tile_size/radius_scale/alpha_cutoff/max_splat_steps/
1377    ///      transmittance_threshold/max_list_entries as f32 bit-patterns.
1378    ///
1379    /// Output: `[height * width * 4]` linear RGBA (display gamma baked in).
1380    /// Build via [`crate::Graph::gaussian_splat_render`].
1381    ///
1382    /// Differentiable backward is not implemented in v1; autodiff treats this
1383    /// op as non-differentiable (same as [`Op::Sample`]).
1384    GaussianSplatRender {
1385        width: u32,
1386        height: u32,
1387        tile_size: u32,
1388        radius_scale: f32,
1389        alpha_cutoff: f32,
1390        max_splat_steps: u32,
1391        transmittance_threshold: f32,
1392        max_list_entries: u32,
1393    },
1394
1395    /// Backward pass for [`Self::GaussianSplatRender`].
1396    ///
1397    /// Eight inputs: the same seven as forward plus `d_loss_rgba` `[W*H*4]`
1398    /// (only RGB channels are used). Re-runs the training forward internally.
1399    ///
1400    /// Output: packed gradients
1401    /// `[positions(3N) | scales(3N) | rotations(4N) | opacities(N) | colors(3N) | sh(N*sh*3)]`.
1402    /// Unpack via [`crate::ops::splat::unpack_gaussian_splat_packed_grads`].
1403    GaussianSplatRenderBackward {
1404        width: u32,
1405        height: u32,
1406        tile_size: u32,
1407        radius_scale: f32,
1408        alpha_cutoff: f32,
1409        max_splat_steps: u32,
1410        transmittance_threshold: f32,
1411        max_list_entries: u32,
1412        loss_grad_clip: f32,
1413        sh_band: u32,
1414        max_anisotropy: f32,
1415    },
1416
1417    /// Strict IR stage 1: project, bin, sort, build per-pixel rays.
1418    ///
1419    /// Seven inputs (same scene + meta as [`Self::GaussianSplatRender`]). Output: packed
1420    /// prepare buffer (see `rlx_splat::prep_layout::prep_packed_len`).
1421    GaussianSplatPrepare {
1422        width: u32,
1423        height: u32,
1424        tile_size: u32,
1425        radius_scale: f32,
1426        alpha_cutoff: f32,
1427        max_splat_steps: u32,
1428        transmittance_threshold: f32,
1429        max_list_entries: u32,
1430    },
1431
1432    /// Strict IR stage 2: tile raster from [`Self::GaussianSplatPrepare`] output.
1433    ///
1434    /// Inputs: `prep` packed buffer, `meta` `[23]`. Output: `[width * height * 4]` RGBA.
1435    GaussianSplatRasterize {
1436        width: u32,
1437        height: u32,
1438        tile_size: u32,
1439        alpha_cutoff: f32,
1440        max_splat_steps: u32,
1441        transmittance_threshold: f32,
1442        max_list_entries: u32,
1443    },
1444
1445    /// User-registered custom op. `name` keys into the
1446    /// [`crate::op_registry`] for shape inference, autodiff, and
1447    /// per-backend execution. `attrs` is an opaque blob passed
1448    /// through to those callbacks (FFT direction, SparseLU
1449    /// reordering strategy, etc.). `num_inputs` is captured at
1450    /// construction time so [`Op::num_inputs`] stays infallible
1451    /// without a registry lookup. Build via [`crate::Graph::custom_op`].
1452    Custom {
1453        name: String,
1454        num_inputs: u32,
1455        attrs: Vec<u8>,
1456    },
1457
1458    /// 1D Fast Fourier Transform along the last axis.
1459    ///
1460    /// Convention: complex tensors are represented as 2N real-block
1461    /// — the input shape is `[..., 2N]` along the last axis, with
1462    /// the first N elements the real part and the second N the
1463    /// imaginary part. Output shape matches the input. Last axis
1464    /// length must be even (and a power of 2 for the v1 radix-2
1465    /// kernel; other sizes will eventually go through mixed-radix).
1466    ///
1467    /// Both forward and inverse are **unnormalized** (no 1/N scale):
1468    ///   `fft(x)[k] = Σ x[n]·exp(-2πi·nk/N)`
1469    ///   `ifft(y)[n] = Σ y[k]·exp(+2πi·nk/N)`
1470    /// so `ifft(fft(x)) = N·x`. Users dividing by N for round-trip
1471    /// identity matches numpy's `fft.fft` / `fft.ifft·N` convention.
1472    ///
1473    /// The unnormalized choice keeps both AD rules free of scaling:
1474    ///   * reverse-mode VJP: `VJP(fft) = ifft`, `VJP(ifft) = fft`
1475    ///     (transpose of the DFT matrix over the 2N-real-block view
1476    ///     equals the unnormalized inverse).
1477    ///   * forward-mode JVP: same op, same direction — FFT is linear,
1478    ///     so the JVP is the linear map itself, not its transpose.
1479    ///
1480    /// CPU paths exist for both `DType::F32` and `DType::F64` on the
1481    /// 2N-real-block layout. Native `DType::C64` and non-power-of-two
1482    /// sizes (Bluestein / mixed-radix) are not implemented; ND FFT
1483    /// and non-CPU backend lowerings are deferred.
1484    Fft {
1485        inverse: bool,
1486    },
1487
1488    /// User-defined sub-graph with optional override AD rules.
1489    /// Mirrors JAX's `custom_vjp` / `custom_jvp` decorators: the
1490    /// caller wraps a forward computation and supplies its own
1491    /// reverse- and/or forward-mode AD bodies. Useful when:
1492    ///   * The forward is iterative (Newton, fixed-point) and
1493    ///     differentiating through the loop is wasteful — the
1494    ///     vjp_body computes the implicit-function gradient at the
1495    ///     converged point in one shot.
1496    ///   * The math has a closed-form gradient that's much cheaper
1497    ///     than autodiff.
1498    ///   * The forward op is non-differentiable by tracing
1499    ///     (sampling, argmax) and the user wants a smooth surrogate.
1500    ///
    /// **fwd_body**: `num_inputs` `Op::Input`s in NodeId construction
    /// order and exactly one graph output (the primal y). Forward
    /// execution inlines this body once.
1504    ///
1505    /// **vjp_body** (optional): Op::Inputs are `num_inputs` primal
1506    /// inputs in NodeId order, plus two special-named Inputs —
1507    /// `"primal_output"` (the y from forward) and `"d_output"` (the
1508    /// upstream gradient). Outputs: `num_inputs` tensors in
1509    /// `set_outputs` order, matching the gradients of each primal
1510    /// input. When `None`, reverse-mode AD recurses into fwd_body
1511    /// — same as if the op were inlined.
1512    ///
1513    /// **jvp_body** (optional): Op::Inputs are `num_inputs` primal
1514    /// inputs in NodeId order, `num_inputs` special-named Inputs
1515    /// `"tangent_0"..="tangent_{num_inputs-1}"` carrying each input's
1516    /// tangent, and an optional special-named `"primal_output"` Input
1517    /// (the y from forward, useful when the JVP must be evaluated at
1518    /// a converged / nonlinear point — e.g. IFT-style forward-mode AD
1519    /// of an iterative solver). Output: 1 tensor (the tangent of y).
1520    /// When `None`, forward-mode AD recurses into fwd_body.
1521    ///
1522    /// `num_inputs` is captured so [`Op::num_inputs`] stays
1523    /// infallible. Build via [`crate::Graph::custom_fn`].
1524    CustomFn {
1525        fwd_body: Box<crate::Graph>,
1526        vjp_body: Option<Box<crate::Graph>>,
1527        jvp_body: Option<Box<crate::Graph>>,
1528        num_inputs: u32,
1529    },
1530}
1531
1532impl Op {
1533    /// PLAN L4: discriminant for backend-supported-set checks.
1534    /// Stable, parameter-free identity per variant — `Op::Activation(_)`
1535    /// and `Op::Activation(Relu)` share the same `OpKind::Activation`.
1536    pub fn kind(&self) -> OpKind {
1537        match self {
1538            Op::Input { .. } => OpKind::Input,
1539            Op::Param { .. } => OpKind::Param,
1540            Op::Constant { .. } => OpKind::Constant,
1541            Op::Activation(_) => OpKind::Activation,
1542            Op::Cast { .. } => OpKind::Cast,
1543            Op::Quantize { .. } => OpKind::Quantize,
1544            Op::Dequantize { .. } => OpKind::Dequantize,
1545            Op::FakeQuantize { .. } => OpKind::FakeQuantize,
1546            Op::FakeQuantizeLSQ { .. } => OpKind::FakeQuantizeLSQ,
1547            Op::FakeQuantizeLSQBackwardX { .. } => OpKind::FakeQuantizeLSQBackwardX,
1548            Op::FakeQuantizeLSQBackwardScale { .. } => OpKind::FakeQuantizeLSQBackwardScale,
1549            Op::Binary(_) => OpKind::Binary,
1550            Op::Compare(_) => OpKind::Compare,
1551            Op::Where => OpKind::Where,
1552            Op::ElementwiseRegion { .. } => OpKind::ElementwiseRegion,
1553            Op::MatMul => OpKind::MatMul,
1554            Op::DotGeneral { .. } => OpKind::DotGeneral,
1555            Op::DenseSolve => OpKind::DenseSolve,
1556            Op::BatchedDenseSolve => OpKind::BatchedDenseSolve,
1557            Op::LayerNorm { .. } => OpKind::LayerNorm,
1558            Op::LayerNorm2d { .. } => OpKind::LayerNorm2d,
1559            Op::GroupNorm { .. } => OpKind::GroupNorm,
1560            Op::RmsNorm { .. } => OpKind::RmsNorm,
1561            Op::ResizeNearest2x => OpKind::ResizeNearest2x,
1562            Op::Attention { .. } => OpKind::Attention,
1563            Op::Rope { .. } => OpKind::Rope,
1564            Op::AxialRope2d { .. } => OpKind::AxialRope2d,
1565            Op::Reshape { .. } => OpKind::Reshape,
1566            Op::Transpose { .. } => OpKind::Transpose,
1567            Op::Narrow { .. } => OpKind::Narrow,
1568            Op::Concat { .. } => OpKind::Concat,
1569            Op::Expand { .. } => OpKind::Expand,
1570            Op::Gather { .. } => OpKind::Gather,
1571            Op::Reduce { .. } => OpKind::Reduce,
1572            Op::Softmax { .. } => OpKind::Softmax,
1573            Op::Cumsum { .. } => OpKind::Cumsum,
1574            Op::TopK { .. } => OpKind::TopK,
1575            Op::Sample { .. } => OpKind::Sample,
1576            Op::Conv { .. } => OpKind::Conv,
1577            Op::ConvTranspose2d { .. } => OpKind::ConvTranspose2d,
1578            Op::Pool { .. } => OpKind::Pool,
1579            Op::ReluBackward => OpKind::ReluBackward,
1580            Op::ActivationBackward { .. } => OpKind::ActivationBackward,
1581            Op::FakeQuantizeBackward { .. } => OpKind::FakeQuantizeBackward,
1582            Op::ComplexNormSq => OpKind::ComplexNormSq,
1583            Op::ComplexNormSqBackward => OpKind::ComplexNormSqBackward,
1584            Op::Conjugate => OpKind::Conjugate,
1585            Op::LayerNormBackwardInput { .. } => OpKind::LayerNormBackwardInput,
1586            Op::LayerNormBackwardGamma { .. } => OpKind::LayerNormBackwardGamma,
1587            Op::RmsNormBackwardInput { .. } => OpKind::RmsNormBackwardInput,
1588            Op::RmsNormBackwardGamma { .. } => OpKind::RmsNormBackwardGamma,
1589            Op::RmsNormBackwardBeta { .. } => OpKind::RmsNormBackwardBeta,
1590            Op::RopeBackward { .. } => OpKind::RopeBackward,
1591            Op::GroupNormBackwardInput { .. } => OpKind::GroupNormBackwardInput,
1592            Op::GroupNormBackwardGamma { .. } => OpKind::GroupNormBackwardGamma,
1593            Op::GroupNormBackwardBeta { .. } => OpKind::GroupNormBackwardBeta,
1594            Op::CumsumBackward { .. } => OpKind::CumsumBackward,
1595            Op::GatherBackward { .. } => OpKind::GatherBackward,
1596            Op::MaxPool2dBackward { .. } => OpKind::MaxPool2dBackward,
1597            Op::Conv2dBackwardInput { .. } => OpKind::Conv2dBackwardInput,
1598            Op::Conv2dBackwardWeight { .. } => OpKind::Conv2dBackwardWeight,
1599            Op::SoftmaxCrossEntropyWithLogits => OpKind::SoftmaxCrossEntropyWithLogits,
1600            Op::SoftmaxCrossEntropyBackward => OpKind::SoftmaxCrossEntropyBackward,
1601            Op::AttentionBackward { .. } => OpKind::AttentionBackward,
1602            Op::GroupedMatMul => OpKind::GroupedMatMul,
1603            Op::DequantGroupedMatMul { .. } => OpKind::DequantGroupedMatMul,
1604            Op::DequantMoEWeights { .. } => OpKind::DequantMoEWeights,
1605            Op::ScatterAdd => OpKind::ScatterAdd,
1606            Op::LoraMatMul { .. } => OpKind::LoraMatMul,
1607            Op::DequantMatMul { .. } => OpKind::DequantMatMul,
1608            Op::QMatMul { .. } => OpKind::QMatMul,
1609            Op::QConv2d { .. } => OpKind::QConv2d,
1610            Op::SelectiveScan { .. } => OpKind::SelectiveScan,
1611            Op::GatedDeltaNet { .. } => OpKind::GatedDeltaNet,
1612            Op::FusedSwiGLU { .. } => OpKind::FusedSwiGLU,
1613            Op::FusedMatMulBiasAct { .. } => OpKind::FusedMatMulBiasAct,
1614            Op::FusedResidualLN { .. } => OpKind::FusedResidualLN,
1615            Op::FusedResidualRmsNorm { .. } => OpKind::FusedResidualRmsNorm,
1616            Op::FusedAttentionBlock { .. } => OpKind::FusedAttentionBlock,
1617            Op::FusedTransformerLayer { .. } => OpKind::FusedTransformerLayer,
1618            Op::If { .. } => OpKind::If,
1619            Op::While { .. } => OpKind::While,
1620            Op::Scan { .. } => OpKind::Scan,
1621            Op::ScanBackward { .. } => OpKind::ScanBackward,
1622            Op::ScanBackwardXs { .. } => OpKind::ScanBackwardXs,
1623            Op::GaussianSplatRender { .. } => OpKind::GaussianSplatRender,
1624            Op::GaussianSplatRenderBackward { .. } => OpKind::GaussianSplatRenderBackward,
1625            Op::GaussianSplatPrepare { .. } => OpKind::GaussianSplatPrepare,
1626            Op::GaussianSplatRasterize { .. } => OpKind::GaussianSplatRasterize,
1627            Op::Custom { .. } => OpKind::Custom,
1628            Op::CustomFn { .. } => OpKind::CustomFn,
1629            Op::Fft { .. } => OpKind::Fft,
1630        }
1631    }
1632
1633    /// True if this op is element-wise (same shape in, same shape out).
1634    /// Element-wise ops are prime fusion candidates.
1635    pub fn is_elementwise(&self) -> bool {
1636        matches!(
1637            self,
1638            Op::Activation(_)
1639                | Op::Cast { .. }
1640                | Op::Binary(_)
1641                | Op::Compare(_)
1642                | Op::Where
1643                | Op::ElementwiseRegion { .. }
1644        )
1645    }
1646
1647    /// True if this op is a BLAS/compute-intensive op that forms a fusion boundary.
1648    pub fn is_blas(&self) -> bool {
1649        matches!(
1650            self,
1651            Op::MatMul
1652                | Op::DotGeneral { .. }
1653                | Op::DenseSolve
1654                | Op::BatchedDenseSolve
1655                | Op::Conv { .. }
1656                | Op::ConvTranspose2d { .. }
1657                | Op::FusedMatMulBiasAct { .. }
1658                | Op::GroupedMatMul
1659                | Op::DequantGroupedMatMul { .. }
1660                | Op::DequantMoEWeights { .. }
1661                | Op::LoraMatMul { .. }
1662                | Op::DequantMatMul { .. }
1663                | Op::QMatMul { .. }
1664                | Op::QConv2d { .. }
1665        )
1666    }
1667
1668    /// True if element-wise fusion must not span across this op.
1669    pub fn is_fusion_boundary(&self) -> bool {
1670        self.is_blas()
1671            || matches!(
1672                self,
1673                Op::GaussianSplatRender { .. }
1674                    | Op::GaussianSplatRenderBackward { .. }
1675                    | Op::GaussianSplatPrepare { .. }
1676                    | Op::GaussianSplatRasterize { .. }
1677            )
1678    }
1679
1680    /// True if this op is a reduction (drives loop iteration in fused kernels).
1681    pub fn is_reduction(&self) -> bool {
1682        matches!(
1683            self,
1684            Op::Reduce { .. } | Op::Softmax { .. } | Op::TopK { .. }
1685        )
1686    }
1687
1688    /// Number of tensor inputs this op expects.
1689    pub fn num_inputs(&self) -> usize {
1690        match self {
1691            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0,
1692            Op::Activation(_)
1693            | Op::Cast { .. }
1694            | Op::Reshape { .. }
1695            | Op::Quantize { .. }
1696            | Op::Dequantize { .. }
1697            | Op::Transpose { .. }
1698            | Op::Narrow { .. }
1699            | Op::Expand { .. }
1700            | Op::Reduce { .. }
1701            | Op::Softmax { .. }
1702            | Op::FusedSwiGLU { .. }
1703            | Op::TopK { .. }
1704            | Op::Cumsum { .. }
1705            | Op::Sample { .. }
1706            | Op::ResizeNearest2x => 1,
1707            // EMA / Fixed scale modes carry a state tensor as a 2nd input;
1708            // PerBatch (default) doesn't need one.
1709            Op::FakeQuantize { scale_mode, .. } => match scale_mode {
1710                ScaleMode::PerBatch => 1,
1711                ScaleMode::EMA { .. } | ScaleMode::Fixed => 2,
1712            },
1713            Op::FakeQuantizeLSQ { .. } => 2, // x, scale (learned param)
1714            Op::FakeQuantizeLSQBackwardX { .. } | Op::FakeQuantizeLSQBackwardScale { .. } => 3, // x, scale, dy
1715            Op::Binary(_) | Op::Compare(_) | Op::Gather { .. } | Op::MatMul | Op::ScatterAdd => 2,
1716            Op::GroupedMatMul => 3,               // input, weight, expert_idx
1717            Op::DequantGroupedMatMul { .. } => 3, // input, packed_w, expert_idx
1718            Op::DequantMoEWeights { .. } => 1,    // packed_w
1719            Op::LoraMatMul { .. } => 4,           // x, w, a, b
1720            // x, w_q, scale, zp — or x, packed_w_bytes for GGUF
1721            // schemes (their scales/mins live inside the packed bytes,
1722            // see `QuantScheme::is_gguf`).
1723            Op::DequantMatMul { scheme } => {
1724                if scheme.is_gguf() {
1725                    2
1726                } else {
1727                    4
1728                }
1729            }
1730            Op::QMatMul { .. } => 3,       // x, w, bias
1731            Op::QConv2d { .. } => 3,       // x, w, bias
1732            Op::SelectiveScan { .. } => 5, // x, delta, a, b, c
1733            Op::GatedDeltaNet { carry_state, .. } if *carry_state => 6, // + state in/out
1734            Op::GatedDeltaNet { .. } => 5, // q, k, v, g, beta
1735            Op::Where => 3,                // cond, on_true, on_false
1736            Op::Attention { mask_kind, .. } => match mask_kind {
1737                MaskKind::Custom | MaskKind::Bias => 4, // Q, K, V, mask
1738                _ => 3,                                 // Q, K, V (mask synthesized in-kernel)
1739            },
1740            Op::AttentionBackward { mask_kind, .. } => match mask_kind {
1741                MaskKind::Custom | MaskKind::Bias => 5, // q, k, v, dy, mask
1742                _ => 4,                                 // q, k, v, dy
1743            },
1744            Op::Rope { .. } => 3, // x, cos, sin
1745            Op::AxialRope2d { .. } => 1,
1746            Op::LayerNorm { .. }
1747            | Op::LayerNorm2d { .. }
1748            | Op::GroupNorm { .. }
1749            | Op::RmsNorm { .. } => 3, // input, gamma, beta
1750            Op::FusedMatMulBiasAct { .. } => 3, // input, weight, bias
1751            Op::FusedResidualLN { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
1752            Op::FusedResidualLN {
1753                has_bias: false, ..
1754            } => 4, // x, residual, gamma, beta
1755            Op::FusedResidualRmsNorm { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
1756            Op::FusedResidualRmsNorm {
1757                has_bias: false, ..
1758            } => 4, // x, residual, gamma, beta
1759            Op::Conv { .. } | Op::ConvTranspose2d { .. } => 2, // input, weight (bias via Add)
1760            Op::Pool { .. } => 1,
1761            Op::ReluBackward => 2,                  // x, dy
1762            Op::ActivationBackward { .. } => 2,     // x, dy
1763            Op::FakeQuantizeBackward { .. } => 2,   // x, dy
1764            Op::ComplexNormSq => 1,                 // z (C64)
1765            Op::ComplexNormSqBackward => 2,         // z, g
1766            Op::Conjugate => 1,                     // z (C64)
1767            Op::LayerNormBackwardInput { .. } => 3, // x, gamma, dy
1768            Op::LayerNormBackwardGamma { .. } => 2, // x, dy
1769            Op::RmsNormBackwardInput { .. } => 4,   // x, gamma, beta, dy
1770            Op::RmsNormBackwardGamma { .. } => 4,
1771            Op::RmsNormBackwardBeta { .. } => 4,
1772            Op::RopeBackward { .. } => 3,           // dy, cos, sin
1773            Op::GroupNormBackwardInput { .. } => 4, // x, gamma, beta, dy
1774            Op::GroupNormBackwardGamma { .. } => 2, // x, dy
1775            Op::GroupNormBackwardBeta { .. } => 2,
1776            Op::CumsumBackward { .. } => 1,         // dy
1777            Op::GatherBackward { .. } => 2,         // dy, indices
1778            Op::MaxPool2dBackward { .. } => 2,      // x, dy
1779            Op::Conv2dBackwardInput { .. } => 2,    // dy, w
1780            Op::Conv2dBackwardWeight { .. } => 2,   // x, dy
1781            Op::SoftmaxCrossEntropyWithLogits => 2, // logits, labels
1782            Op::SoftmaxCrossEntropyBackward => 3,   // logits, labels, d_loss
1783            Op::Concat { .. } => 0,                 // variadic — checked at graph level
1784            Op::DotGeneral { .. } => 2,
1785            Op::DenseSolve => 2,        // A, b
1786            Op::BatchedDenseSolve => 2, // A [B,N,N], b [B,N] or [B,N,K]
1787            Op::FusedAttentionBlock {
1788                has_bias, has_rope, ..
1789            } => 4 + if *has_bias { 2 } else { 0 } + if *has_rope { 2 } else { 0 },
1790            Op::If { .. } => 1,    // predicate (captures handled separately)
1791            Op::While { .. } => 0, // variadic loop-carried; checked at graph level
1792            Op::Scan {
1793                num_bcast, num_xs, ..
1794            } => 1 + *num_bcast as usize + *num_xs as usize,
1795            Op::ScanBackward { num_xs, .. } => 3 + *num_xs as usize, // init, trajectory, upstream, xs_0..
1796            Op::ScanBackwardXs { num_xs, .. } => 3 + *num_xs as usize, // same as ScanBackward
1797            Op::GaussianSplatRender { .. } => 7,
1798            Op::GaussianSplatRenderBackward { .. } => 8,
1799            Op::GaussianSplatPrepare { .. } => 7,
1800            Op::GaussianSplatRasterize { .. } => 2,
1801            Op::FusedTransformerLayer { has_bias, .. } => {
1802                // hidden + qkv_w + out_w + ln1_g + ln1_b + fc1_w + fc2_w + ln2_g + ln2_b + mask = 10
1803                // bias variant adds: qkv_b + out_b + fc1_b + fc2_b = 4 more
1804                10 + if *has_bias { 4 } else { 0 }
1805            }
1806            Op::ElementwiseRegion { num_inputs, .. } => *num_inputs as usize,
1807            Op::Custom { num_inputs, .. } => *num_inputs as usize,
1808            Op::CustomFn { num_inputs, .. } => *num_inputs as usize,
1809            Op::Fft { .. } => 1,
1810        }
1811    }
1812}
1813
1814impl std::fmt::Display for Op {
1815    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1816        match self {
1817            Op::Input { name } => write!(f, "input(\"{name}\")"),
1818            Op::Param { name } => write!(f, "param(\"{name}\")"),
1819            Op::Constant { data } => write!(f, "const({}B)", data.len()),
1820            Op::Activation(a) => write!(f, "{a:?}"),
1821            Op::Quantize { axis, scales, .. } => match axis {
1822                None => write!(f, "quantize(s={})", scales[0]),
1823                Some(d) => write!(f, "quantize(axis={d},nch={})", scales.len()),
1824            },
1825            Op::Dequantize { axis, scales, .. } => match axis {
1826                None => write!(f, "dequantize(s={})", scales[0]),
1827                Some(d) => write!(f, "dequantize(axis={d},nch={})", scales.len()),
1828            },
1829            Op::FakeQuantize {
1830                bits,
1831                axis,
1832                ste,
1833                scale_mode,
1834            } => match axis {
1835                None => write!(
1836                    f,
1837                    "fake_quant(bits={bits},ste={ste:?},scale={scale_mode:?})"
1838                ),
1839                Some(d) => write!(
1840                    f,
1841                    "fake_quant(bits={bits},axis={d},ste={ste:?},scale={scale_mode:?})"
1842                ),
1843            },
1844            Op::FakeQuantizeLSQ { bits, axis } => match axis {
1845                None => write!(f, "fake_quant_lsq(bits={bits})"),
1846                Some(d) => write!(f, "fake_quant_lsq(bits={bits},axis={d})"),
1847            },
1848            Op::FakeQuantizeLSQBackwardX { bits, .. } => {
1849                write!(f, "fake_quant_lsq_bwd_x(bits={bits})")
1850            }
1851            Op::FakeQuantizeLSQBackwardScale { bits, .. } => {
1852                write!(f, "fake_quant_lsq_bwd_s(bits={bits})")
1853            }
1854            Op::Cast { to } => write!(f, "cast({to})"),
1855            Op::Binary(op) => write!(f, "{op:?}"),
1856            Op::Compare(op) => write!(f, "{op:?}"),
1857            Op::Where => write!(f, "where"),
1858            Op::MatMul => write!(f, "matmul"),
1859            Op::DotGeneral { .. } => write!(f, "dot_general"),
1860            Op::DenseSolve => write!(f, "dense_solve"),
1861            Op::BatchedDenseSolve => write!(f, "batched_dense_solve"),
1862            Op::LayerNorm { eps, .. } => write!(f, "layer_norm(eps={eps})"),
1863            Op::GroupNorm { num_groups, eps } => {
1864                write!(f, "group_norm(groups={num_groups},eps={eps})")
1865            }
1866            Op::ResizeNearest2x => write!(f, "resize_nearest_2x"),
1867            Op::RmsNorm { eps, .. } => write!(f, "rms_norm(eps={eps})"),
1868            Op::Attention {
1869                num_heads,
1870                head_dim,
1871                mask_kind,
1872            } => match mask_kind {
1873                MaskKind::Custom => write!(f, "attention(h={num_heads},d={head_dim})"),
1874                MaskKind::None => write!(f, "attention(h={num_heads},d={head_dim},nomask)"),
1875                MaskKind::Causal => write!(f, "attention(h={num_heads},d={head_dim},causal)"),
1876                MaskKind::SlidingWindow(w) => {
1877                    write!(f, "attention(h={num_heads},d={head_dim},sw={w})")
1878                }
1879                MaskKind::Bias => write!(f, "attention(h={num_heads},d={head_dim},bias)"),
1880            },
1881            Op::Rope { head_dim, n_rot } => write!(f, "rope(d={head_dim}, n_rot={n_rot})"),
1882            Op::AxialRope2d {
1883                end_x,
1884                end_y,
1885                head_dim,
1886                num_heads,
1887                theta,
1888                repeat_factor,
1889            } => write!(
1890                f,
1891                "axial_rope2d({end_x}x{end_y},h={num_heads},d={head_dim},θ={theta},r={repeat_factor})"
1892            ),
1893            Op::Reshape { new_shape } => write!(f, "reshape({new_shape:?})"),
1894            Op::Transpose { perm } => write!(f, "transpose({perm:?})"),
1895            Op::Narrow { axis, start, len } => write!(f, "narrow({axis},{start},{len})"),
1896            Op::Concat { axis } => write!(f, "concat(axis={axis})"),
1897            Op::Expand { .. } => write!(f, "expand"),
1898            Op::Gather { axis } => write!(f, "gather(axis={axis})"),
1899            Op::Reduce { op, axes, .. } => write!(f, "reduce_{op:?}({axes:?})"),
1900            Op::Softmax { axis } => write!(f, "softmax(axis={axis})"),
1901            Op::Cumsum { axis, exclusive } => {
1902                if *exclusive {
1903                    write!(f, "cumsum(axis={axis},excl)")
1904                } else {
1905                    write!(f, "cumsum(axis={axis})")
1906                }
1907            }
1908            Op::Sample {
1909                top_k,
1910                top_p,
1911                temperature,
1912                ..
1913            } => {
1914                write!(f, "sample(t={temperature}")?;
1915                if *top_k > 0 {
1916                    write!(f, ",k={top_k}")?;
1917                }
1918                if *top_p < 1.0 {
1919                    write!(f, ",p={top_p}")?;
1920                }
1921                write!(f, ")")
1922            }
1923            Op::TopK { k } => write!(f, "topk(k={k})"),
1924            Op::GroupedMatMul => write!(f, "grouped_matmul"),
1925            Op::DequantGroupedMatMul { scheme } => {
1926                write!(f, "dequant_grouped_matmul({scheme})")
1927            }
1928            Op::DequantMoEWeights { scheme } => write!(f, "dequant_moe_weights({scheme})"),
1929            Op::LoraMatMul { scale } => write!(f, "lora_matmul(scale={scale})"),
1930            Op::DequantMatMul { scheme } => write!(f, "dequant_matmul({scheme})"),
1931            Op::QMatMul {
1932                x_zp,
1933                w_zp,
1934                out_zp,
1935                mult,
1936            } => write!(
1937                f,
1938                "q_matmul(x_zp={x_zp},w_zp={w_zp},out_zp={out_zp},mult={mult})"
1939            ),
1940            Op::QConv2d { kernel_size, .. } => write!(f, "q_conv2d({kernel_size:?})"),
1941            Op::SelectiveScan { state_size } => write!(f, "ssm_scan(n={state_size})"),
1942            Op::GatedDeltaNet {
1943                state_size,
1944                carry_state,
1945            } => {
1946                if *carry_state {
1947                    write!(f, "gated_delta_net(n={state_size},carry)")
1948                } else {
1949                    write!(f, "gated_delta_net(n={state_size})")
1950                }
1951            }
1952            Op::ScatterAdd => write!(f, "scatter_add"),
1953            Op::Conv { kernel_size, .. } => write!(f, "conv2d({kernel_size:?})"),
1954            Op::ConvTranspose2d { kernel_size, .. } => {
1955                write!(f, "conv_transpose2d({kernel_size:?})")
1956            }
1957            Op::LayerNorm2d { eps } => write!(f, "layer_norm2d(eps={eps})"),
1958            Op::Pool {
1959                kind, kernel_size, ..
1960            } => write!(f, "pool_{kind:?}({kernel_size:?})"),
1961            Op::ReluBackward => write!(f, "relu_backward"),
1962            Op::ActivationBackward { kind } => write!(f, "{kind:?}_backward"),
1963            Op::ComplexNormSq => write!(f, "complex_norm_sq"),
1964            Op::ComplexNormSqBackward => write!(f, "complex_norm_sq_backward"),
1965            Op::Conjugate => write!(f, "conjugate"),
1966            Op::FakeQuantizeBackward { bits, ste, .. } => {
1967                write!(f, "fake_quant_backward(bits={bits},ste={ste:?})")
1968            }
1969            Op::MaxPool2dBackward { kernel_size, .. } => {
1970                write!(f, "maxpool2d_backward({kernel_size:?})")
1971            }
1972            Op::Conv2dBackwardInput { kernel_size, .. } => {
1973                write!(f, "conv2d_backward_input({kernel_size:?})")
1974            }
1975            Op::Conv2dBackwardWeight { kernel_size, .. } => {
1976                write!(f, "conv2d_backward_weight({kernel_size:?})")
1977            }
1978            Op::SoftmaxCrossEntropyWithLogits => write!(f, "sce_with_logits"),
1979            Op::SoftmaxCrossEntropyBackward => write!(f, "sce_backward"),
1980            Op::AttentionBackward {
1981                num_heads,
1982                head_dim,
1983                mask_kind,
1984                wrt,
1985            } => match mask_kind {
1986                MaskKind::None => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},nomask)"),
1987                MaskKind::Causal => {
1988                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},causal)")
1989                }
1990                MaskKind::SlidingWindow(w) => {
1991                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},sw={w})")
1992                }
1993                MaskKind::Custom => {
1994                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},custom)")
1995                }
1996                MaskKind::Bias => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},bias)"),
1997            },
1998            Op::FusedMatMulBiasAct { activation } => {
1999                write!(f, "fused_mm_bias")?;
2000                if let Some(a) = activation {
2001                    write!(f, "_{a:?}")?;
2002                }
2003                Ok(())
2004            }
2005            Op::FusedResidualLN { has_bias, eps } => {
2006                write!(f, "fused_residual")?;
2007                if *has_bias {
2008                    write!(f, "_bias")?;
2009                }
2010                write!(f, "_ln(eps={eps})")
2011            }
2012            Op::FusedResidualRmsNorm { has_bias, eps } => {
2013                write!(f, "fused_residual")?;
2014                if *has_bias {
2015                    write!(f, "_bias")?;
2016                }
2017                write!(f, "_rms(eps={eps})")
2018            }
2019            Op::FusedSwiGLU {
2020                cast_to,
2021                gate_first,
2022            } => {
2023                let mut s = match cast_to {
2024                    Some(dt) => format!("fused_swiglu(cast={dt}"),
2025                    None => "fused_swiglu(".to_string(),
2026                };
2027                if *gate_first {
2028                    s.push_str(",gate_first");
2029                }
2030                s.push(')');
2031                write!(f, "{s}")
2032            }
2033            Op::FusedAttentionBlock {
2034                num_heads,
2035                head_dim,
2036                has_bias,
2037                has_rope,
2038            } => {
2039                write!(f, "fused_attn(h={num_heads},d={head_dim}")?;
2040                if *has_bias {
2041                    write!(f, ",bias")?;
2042                }
2043                if *has_rope {
2044                    write!(f, ",rope")?;
2045                }
2046                write!(f, ")")
2047            }
2048            Op::If { .. } => write!(f, "if(...)"),
2049            Op::While { max_iterations, .. } => match max_iterations {
2050                Some(n) => write!(f, "while(...max={n})"),
2051                None => write!(f, "while(...)"),
2052            },
2053            Op::Scan {
2054                length,
2055                save_trajectory,
2056                num_xs,
2057                ..
2058            } => {
2059                let traj = if *save_trajectory { ",traj" } else { "" };
2060                let xs = if *num_xs > 0 {
2061                    format!(",xs={}", num_xs)
2062                } else {
2063                    String::new()
2064                };
2065                write!(f, "scan(len={length}{xs}{traj})")
2066            }
2067            Op::ScanBackward {
2068                length,
2069                save_trajectory,
2070                num_xs,
2071                ..
2072            } => {
2073                let traj = if *save_trajectory { ",traj" } else { "" };
2074                let xs = if *num_xs > 0 {
2075                    format!(",xs={}", num_xs)
2076                } else {
2077                    String::new()
2078                };
2079                write!(f, "scan_bwd(len={length}{xs}{traj})")
2080            }
2081            Op::ScanBackwardXs {
2082                length,
2083                save_trajectory,
2084                num_xs,
2085                xs_idx,
2086                ..
2087            } => {
2088                let traj = if *save_trajectory { ",traj" } else { "" };
2089                write!(
2090                    f,
2091                    "scan_bwd_xs(len={length},xs={num_xs},idx={xs_idx}{traj})"
2092                )
2093            }
2094            Op::FusedTransformerLayer {
2095                num_heads,
2096                head_dim,
2097                intermediate_size,
2098                has_bias,
2099                ..
2100            } => {
2101                write!(
2102                    f,
2103                    "fused_layer(h={num_heads},d={head_dim},int={intermediate_size}"
2104                )?;
2105                if *has_bias {
2106                    write!(f, ",bias")?;
2107                }
2108                write!(f, ")")
2109            }
2110            Op::ElementwiseRegion {
2111                chain,
2112                num_inputs,
2113                scalar_input_mask,
2114                input_modulus: _,
2115            } => {
2116                if *scalar_input_mask != 0 {
2117                    write!(
2118                        f,
2119                        "ew_region(in={num_inputs},steps={},scalar_mask=0x{:x})",
2120                        chain.len(),
2121                        scalar_input_mask
2122                    )
2123                } else {
2124                    write!(f, "ew_region(in={num_inputs},steps={})", chain.len())
2125                }
2126            }
2127            Op::LayerNormBackwardInput { eps, .. } => {
2128                write!(f, "layer_norm_backward_input(eps={eps})")
2129            }
2130            Op::LayerNormBackwardGamma { eps, .. } => {
2131                write!(f, "layer_norm_backward_gamma(eps={eps})")
2132            }
2133            Op::RmsNormBackwardInput { eps, .. } => write!(f, "rms_norm_backward_input(eps={eps})"),
2134            Op::RmsNormBackwardGamma { eps, .. } => write!(f, "rms_norm_backward_gamma(eps={eps})"),
2135            Op::RmsNormBackwardBeta { eps, .. } => write!(f, "rms_norm_backward_beta(eps={eps})"),
2136            Op::RopeBackward { head_dim, n_rot } => {
2137                write!(f, "rope_backward(d={head_dim},n_rot={n_rot})")
2138            }
2139            Op::GroupNormBackwardInput { num_groups, eps } => {
2140                write!(f, "group_norm_backward_input(g={num_groups},eps={eps})")
2141            }
2142            Op::GroupNormBackwardGamma { num_groups, eps } => {
2143                write!(f, "group_norm_backward_gamma(g={num_groups},eps={eps})")
2144            }
2145            Op::GroupNormBackwardBeta { num_groups, eps } => {
2146                write!(f, "group_norm_backward_beta(g={num_groups},eps={eps})")
2147            }
2148            Op::CumsumBackward { axis, exclusive } => {
2149                write!(f, "cumsum_backward(axis={axis},exclusive={exclusive})")
2150            }
2151            Op::GatherBackward { axis } => write!(f, "gather_backward(axis={axis})"),
2152            Op::GaussianSplatRender {
2153                width,
2154                height,
2155                tile_size,
2156                radius_scale,
2157                alpha_cutoff,
2158                max_splat_steps,
2159                transmittance_threshold,
2160                max_list_entries,
2161            } => write!(
2162                f,
2163                "gaussian_splat_render({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2164            ),
2165            Op::GaussianSplatRenderBackward {
2166                width,
2167                height,
2168                loss_grad_clip,
2169                sh_band,
2170                ..
2171            } => write!(
2172                f,
2173                "gaussian_splat_render_bwd({width}x{height},clip={loss_grad_clip},sh={sh_band})"
2174            ),
2175            Op::GaussianSplatPrepare {
2176                width,
2177                height,
2178                tile_size,
2179                radius_scale,
2180                alpha_cutoff,
2181                max_splat_steps,
2182                transmittance_threshold,
2183                max_list_entries,
2184                ..
2185            } => write!(
2186                f,
2187                "gaussian_splat_prepare({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2188            ),
2189            Op::GaussianSplatRasterize {
2190                width,
2191                height,
2192                tile_size,
2193                alpha_cutoff,
2194                max_splat_steps,
2195                transmittance_threshold,
2196                max_list_entries,
2197                ..
2198            } => write!(
2199                f,
2200                "gaussian_splat_rasterize({width}x{height},tile={tile_size},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2201            ),
2202            Op::Custom {
2203                name,
2204                num_inputs,
2205                attrs,
2206            } => write!(f, "custom({name},in={num_inputs},attrs={}B)", attrs.len()),
2207            Op::CustomFn {
2208                num_inputs,
2209                vjp_body,
2210                jvp_body,
2211                ..
2212            } => {
2213                let v = if vjp_body.is_some() { ",vjp" } else { "" };
2214                let j = if jvp_body.is_some() { ",jvp" } else { "" };
2215                write!(f, "custom_fn(in={num_inputs}{v}{j})")
2216            }
2217            Op::Fft { inverse } => write!(f, "fft(inverse={inverse})"),
2218        }
2219    }
2220}