Skip to main content

rlx_ir/
op.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Operation types — every tensor op in the RLX IR.
//!
//! Designed for pattern-matching fusion: ops are grouped by category so
//! fusion passes can reason about them structurally.
20
21use crate::DType;
22
/// Unary element-wise activation functions.
///
/// Each variant maps one input element to one output element. Where a
/// backward rule is given, `upstream` is the incoming gradient of the output.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
    Gelu,
    GeluApprox,
    /// Also serves as the gate activation in SwiGLU.
    Silu,
    Relu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
    Rsqrt,
    Neg,
    Abs,
    /// `sin(x)`. Backward: `dx = upstream · cos(x)`.
    Sin,
    /// `cos(x)`. Backward: `dx = -upstream · sin(x)`.
    Cos,
    /// `tan(x)`. Backward: `dx = upstream · sec²(x) = upstream · (1 + tan²(x))`.
    Tan,
    /// `atan(x)`. Backward: `dx = upstream · (1 / (1 + x²))`.
    Atan,
    /// Round to the nearest integer, in f32.
    ///
    /// Forward: `x.round()`. Rust's `f32::round` breaks halfway cases *away
    /// from zero* (`2.5 → 3.0`, `-2.5 → -3.0`); it is **not** half-to-even
    /// (that is `f32::round_ties_even`). NOTE(review): confirm that every
    /// backend kernel implements the same tie rule as `f32::round`, so that
    /// results agree across backends.
    ///
    /// Backward: straight-through estimator (STE). The op is treated as the
    /// identity, so the gradient passes through unchanged.
    ///
    /// Useful as a primitive for composing custom quantization schemes.
    /// Mul-by-recip-scale → Round → Clamp → Mul-by-scale gives a hand-rolled
    /// FakeQuantize that the elementwise-region pass can fuse into a single
    /// kernel.
    Round,
}
55
/// Scale-tracking strategy for `Op::FakeQuantize`. Determines how
/// the per-channel `s[c]` is computed each forward pass.
///
/// * `PerBatch` — recompute `s[c] = max(|x|) / q_max` from the
///   current data on every call. Simple, no extra inputs, but
///   noisy for activations (max-abs jumps batch-to-batch).
///
/// * `EMA { decay }` — keep a running `s[c]` in a state tensor
///   (passed as a second op input). On each call, blend the
///   current per-batch max-abs into the state via
///   `state' = decay·state + (1-decay)·max_abs`. This smooths the scale
///   over training and makes activation-QAT actually trainable.
///   Typical `decay = 0.99`.
///
/// * `Fixed` — never recompute. The state tensor's value is
///   used as-is each call (set once at construction or by the
///   caller). Useful when scales are pre-calibrated.
///
/// # Equality and hashing
///
/// `PartialEq` is derived, so `decay` is compared with IEEE `==`. As a
/// result `EMA { decay: 0.0 }` equals `EMA { decay: -0.0 }`. The manual
/// `Hash` impl folds both zeros onto one bit pattern, so equal values
/// always hash equally.
///
/// `Eq` is asserted manually even though `f32` is only `PartialEq`. A NaN
/// `decay` would break reflexivity, so callers should keep `decay` finite.
/// Given the blend formula, it is presumably in `[0, 1]`; that range is not
/// enforced here.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Default)]
pub enum ScaleMode {
    /// Recompute the scale from the current batch (the default).
    #[default]
    PerBatch,
    /// Exponential moving average of the per-batch max-abs.
    EMA {
        /// Weight on the previous state in `decay·state + (1-decay)·max_abs`.
        decay: f32,
    },
    /// Use the state tensor's value unchanged.
    Fixed,
}

// `f32` is not `Eq`; asserting it here is sound only for non-NaN `decay`
// (see the type-level docs).
impl Eq for ScaleMode {}

impl std::hash::Hash for ScaleMode {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            ScaleMode::PerBatch => state.write_u8(0),
            ScaleMode::EMA { decay } => {
                state.write_u8(1);
                // Under the derived `PartialEq`, `0.0 == -0.0`, but the two
                // zeros have different bit patterns. Canonicalize them so that
                // `a == b` implies `hash(a) == hash(b)`, as `Hash` requires.
                let canonical = if *decay == 0.0 { 0.0_f32 } else { *decay };
                state.write_u32(canonical.to_bits());
            }
            ScaleMode::Fixed => state.write_u8(2),
        }
    }
}
97
/// Straight-through estimator (STE) variants for the backward pass of
/// `Op::FakeQuantize`.
///
/// The forward pass does not depend on this choice: it is always the
/// discrete `clamp(round(x/s)) * s`. Only the gradient with respect to `x`
/// during training changes.
///
/// * `Identity` — `dx = upstream`. The original STE: the round is treated
///   as the identity in the backward direction. Simplest; fine for
///   moderate bit widths (i4 / i8).
///
/// * `ClippedIdentity` — `dx = upstream * (|x| ≤ q_max·s)`. Zeroes the
///   gradient wherever the input fell outside the quantization range
///   (i.e. where the clamp was active). This stops the optimizer from
///   pushing weights further into saturation.
///
/// * `Tanh` — `dx = upstream * (1 - tanh²(x/s))`, a smooth surrogate for
///   the round step.
///   - As written, the factor depends on `x/s` alone, independent of
///     `q_max`: it is `1` at `x = 0` and already below `0.1` at `|x/s| = 2`.
///   - At tight bit widths (i2, `q_max = 1`) that roll-off is on the same
///     scale as the clamp boundary `|x/s| = 1`. That fits the observation
///     that it often works best there.
///   - At wide bit widths it suppresses gradients well inside the
///     representable range.
///
///   NOTE(review): confirm against the backward kernel whether the
///   argument is meant to be normalized by `q_max·s` rather than `s`.
///
/// * `HardTanh` — `dx = upstream * (1 - |x/(q_max·s)|).max(0)`. A
///   piecewise-linear relative of `Tanh`; cheaper to compute and nearly as
///   effective.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SteKind {
    /// Pass the gradient through unchanged (the default).
    #[default]
    Identity,
    /// Pass the gradient through only where the clamp was inactive.
    ClippedIdentity,
    /// Smooth `1 - tanh²(x/s)` surrogate.
    Tanh,
    /// Piecewise-linear `(1 - |x/(q_max·s)|).max(0)` surrogate.
    HardTanh,
}
128
/// Element-wise binary operators.
///
/// Used by `Op::Binary`, which broadcasts its two inputs against each other
/// and produces the broadcast shape, and by `ChainStep::Binary` inside fused
/// element-wise regions.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Element-wise maximum.
    Max,
    /// Element-wise minimum.
    Min,
    /// Exponentiation.
    Pow,
}
141
/// Element-wise comparison operators; each produces a Bool tensor.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// Equal (`==`).
    Eq,
    /// Not equal (`!=`).
    Ne,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    Le,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    Ge,
}
153
/// Selects the mask an attention kernel applies.
///
/// The design follows MAX's `nn/attention/mha_mask.mojo` (#20 in PLAN.md):
/// one attention kernel covers every variant by branching on the mask kind,
/// so callers do not have to materialize a mask tensor. This pays off in
/// two ways:
///
/// 1. **`None`** — for a single unpadded sequence there is no mask load and
///    no per-key compare in the inner loop.
/// 2. **`Causal`** — for autoregressive decode the kernel derives the
///    upper-triangular fill directly from `(qi, ki)`, so no `seq²` mask
///    tensor ever exists.
///
/// `Custom` keeps the pre-existing path: mask values are read from the 4th
/// input.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MaskKind {
    /// No masking: every query position attends to every key position.
    None,
    /// Causal (autoregressive): query `qi` attends only to keys `ki <= qi`.
    Causal,
    /// Sliding window: query `qi` attends to keys `ki ∈ [qi - w, qi]`.
    SlidingWindow(usize),
    /// Mask values come from an input tensor. This is the legacy default
    /// path and matches BERT padding-mask behavior. The tensor has shape
    /// `[batch, key_len]`, with `1.0` meaning valid and `< 0.5` meaning
    /// ignored.
    Custom,
    /// Additive bias tensor `[batch, num_heads, query_len, key_len]`, one
    /// value per head and query. It is added to the `QK^T · scale` scores
    /// before softmax.
    ///
    /// This lets DETR-style boxRPB and other learned position biases stay on
    /// the fast `Op::Attention` path instead of being decomposed into
    /// matmul + add + softmax + matmul.
    Bias,
}
187
/// Selects which forward attention input an [`Op::AttentionBackward`] node
/// produces the gradient for.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AttentionBwdWrt {
    /// Gradient with respect to Q.
    Query,
    /// Gradient with respect to K.
    Key,
    /// Gradient with respect to V.
    Value,
}
196
/// Reduction kinds applied along the axes chosen by `Op::Reduce`.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Sum of the reduced elements.
    Sum,
    /// Arithmetic mean of the reduced elements.
    Mean,
    /// Largest reduced element.
    Max,
    /// Smallest reduced element.
    Min,
    /// Product of the reduced elements.
    Prod,
}
207
/// Fieldless discriminant mirroring each [`Op`] variant (PLAN L4).
///
/// [`Op::kind`] maps an op to its `OpKind`, and backends declare the kinds
/// they can lower through the `Backend::supported_ops` trait method. The
/// `LegalizeForBackend` pass in `rlx-opt` checks a graph against that set.
/// It fails the compile when an unsupported op is present, instead of
/// silently falling back.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    // Graph inputs.
    Input,
    Param,
    Constant,

    // Element-wise, casting and quantization.
    Activation,
    Cast,
    Quantize,
    Dequantize,
    FakeQuantize,
    FakeQuantizeLSQ,
    FakeQuantizeLSQBackwardX,
    FakeQuantizeLSQBackwardScale,
    Binary,
    Compare,
    Where,
    ElementwiseRegion,

    // Linear algebra.
    MatMul,
    DotGeneral,
    DenseSolve,
    BatchedDenseSolve,

    // Normalization and resampling.
    LayerNorm,
    LayerNorm2d,
    GroupNorm,
    RmsNorm,
    ResizeNearest2x,

    // Attention and positional encoding.
    Attention,
    Rope,
    AxialRope2d,

    // Shape manipulation and indexing.
    Reshape,
    Transpose,
    Narrow,
    Concat,
    Expand,
    Gather,

    // Reductions, scans and sampling.
    Reduce,
    Softmax,
    Cumsum,
    TopK,
    Sample,

    // Convolution and pooling.
    Conv,
    ConvTranspose2d,
    Pool,

    // Backward ops.
    ReluBackward,
    ActivationBackward,
    FakeQuantizeBackward,

    // Complex-number helpers.
    ComplexNormSq,
    ComplexNormSqBackward,
    Conjugate,

    // More backward ops.
    MaxPool2dBackward,
    Conv2dBackwardInput,
    Conv2dBackwardWeight,
    SoftmaxCrossEntropyWithLogits,
    SoftmaxCrossEntropyBackward,
    AttentionBackward,
    LayerNormBackwardInput,
    LayerNormBackwardGamma,
    RmsNormBackwardInput,
    RmsNormBackwardGamma,
    RmsNormBackwardBeta,
    RopeBackward,
    GroupNormBackwardInput,
    GroupNormBackwardGamma,
    GroupNormBackwardBeta,
    CumsumBackward,
    GatherBackward,

    // Grouped / MoE, LoRA and quantized matmul kernels.
    GroupedMatMul,
    DequantGroupedMatMul,
    DequantMoEWeights,
    ScatterAdd,
    LoraMatMul,
    DequantMatMul,
    QMatMul,
    QConv2d,

    // Sequence-model recurrences.
    SelectiveScan,
    GatedDeltaNet,

    // Products of fusion passes.
    FusedSwiGLU,
    FusedMatMulBiasAct,
    FusedResidualLN,
    FusedResidualRmsNorm,
    FusedAttentionBlock,
    FusedTransformerLayer,

    // Control flow.
    If,
    While,
    Scan,
    ScanBackward,
    ScanBackwardXs,

    // 3D Gaussian splatting.
    /// CPU reference 3D Gaussian splat raster (project → bin → sort →
    /// raster). See [`Op::GaussianSplatRender`].
    GaussianSplatRender,
    /// Backward of [`Op::GaussianSplatRender`]; yields the packed scene
    /// parameter gradients.
    GaussianSplatRenderBackward,
    /// Strict-IR splat stage 1: project, tile-bin, sort, and build the ray
    /// grid.
    GaussianSplatPrepare,
    /// Strict-IR splat stage 2: per-pixel raster from the prepared buffers.
    GaussianSplatRasterize,

    // Extension points and misc primitives.
    /// A user-registered op dispatched through `op_registry`. Every custom op
    /// (Sparse-LU, FFT, eigensolve, ...) shares this one kind; which op it is
    /// is carried by `Op::Custom::name`.
    Custom,
    /// A user-defined sub-graph with optional overriding AD rules. See
    /// [`Op::CustomFn`] / [`crate::Graph::custom_fn`].
    CustomFn,
    /// 1D FFT primitive, forward or inverse. See [`Op::Fft`].
    Fft,
}
320
/// Operand reference used by a [`ChainStep`] inside an
/// [`Op::ElementwiseRegion`].
///
/// `Input(i)` names the region's `i`-th graph-level input
/// (`0..num_inputs`). `Step(j)` names the result of an earlier step in the
/// same chain (`0..step_position`).
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChainOperand {
    /// Index into the region's inputs.
    Input(u32),
    /// Index of a previously computed chain step.
    Step(u32),
}
330
331/// One step in a fused element-wise chain. Each step produces exactly
332/// one scalar result (per element); later steps can refer to it via
333/// [`ChainOperand::Step`]. The whole chain runs per element in registers.
334#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
335#[derive(Debug, Clone, PartialEq)]
336pub enum ChainStep {
337    Activation(Activation, ChainOperand),
338    Cast(DType, ChainOperand),
339    Binary(BinaryOp, ChainOperand, ChainOperand),
340    Compare(CmpOp, ChainOperand, ChainOperand),
341    /// 3-input element-wise select: `cond ? on_true : on_false`. Mirrors
342    /// `Op::Where` inside a chain. `cond` is treated as truthy iff
343    /// non-zero. Lets the optimizer fold attention masks / clamp-style
344    /// patterns into a single region kernel instead of breaking the
345    /// chain at the first `Op::Where`.
346    Where(ChainOperand, ChainOperand, ChainOperand),
347}
348
349/// An operation in the RLX IR graph.
350///
351/// Operations are categorized for fusion analysis:
352/// - Element-wise ops fuse with anything reading their output
353/// - Matmul/Conv are BLAS-dispatched and form fusion boundaries
354/// - Reductions are fusion roots (drive the loop iteration)
355#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
356#[derive(Debug, Clone, PartialEq)]
357pub enum Op {
358    // ── Graph inputs ────────────────────────────────────────────
359    /// Model input with a name (shape on the Node).
360    Input {
361        name: String,
362    },
363
364    /// Model parameter (weight/bias) with a name.
365    Param {
366        name: String,
367    },
368
369    /// Constant tensor embedded in the graph.
370    Constant {
371        data: Vec<u8>,
372    },
373
374    // ── Element-wise unary ──────────────────────────────────────
375    /// Unary activation: one input, same shape output.
376    Activation(Activation),
377
378    /// Cast to a different dtype.
379    Cast {
380        to: DType,
381    },
382
383    /// INT8 quantization. Input f32; output i8 same shape.
384    ///   `q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])`
385    /// where `c` selects the per-channel scale/zp when `axis = Some(d)`
386    /// (`c = idx[d]`), or always uses index 0 when `axis = None`
387    /// (per-tensor). The `scales` / `zero_points` payload length must
388    /// match `1` for per-tensor and `input.dim(d)` for per-channel.
389    /// Static — typically produced at calibration time and baked
390    /// into the loaded model. Use `Op::Dequantize` for the inverse.
391    Quantize {
392        axis: Option<usize>,
393        scales: Vec<f32>,
394        zero_points: Vec<i32>,
395    },
396
397    /// INT8 dequantization (inverse of `Op::Quantize`). Input i8;
398    /// output f32 same shape.
399    ///   `x[i] = (q[i] - zero_point[c]) · scale[c]`
400    /// where `c` is selected by `axis` exactly as in `Op::Quantize`.
401    Dequantize {
402        axis: Option<usize>,
403        scales: Vec<f32>,
404        zero_points: Vec<i32>,
405    },
406
407    /// "Fake-quantize" op for **quantization-aware training** (QAT).
408    /// Input f32; output f32 same shape. Forward computes a per-axis
409    /// (or per-tensor when `axis = None`) max-abs scale on the fly:
410    ///   `s[c] = max(|x[..., c, ...]|) / q_max(bits)`
411    /// then quantizes-then-dequantizes:
412    ///   `out[i] = clamp(round(x[i] / s[c]), -q_max, q_max) * s[c]`
413    /// where `q_max` is `127` for `bits=8`, `7` for `bits=4`, `1` for
414    /// `bits=2` (ternary). Symmetric only — zero-point is always 0.
415    ///
416    /// The point of this op is to make the SGD optimizer "see" the
417    /// deployment-time rounding during training. Backward is the
418    /// **straight-through estimator** (STE): the gradient passes
419    /// through (variant chosen by `ste`), ignoring the discontinuity
420    /// at the round. Without STE the rounding would have zero
421    /// gradient almost everywhere and learning would stop.
422    ///
423    /// Inserted by the trainer on conv / FC weight tensors when
424    /// `--qat` is on; the existing `Op::Quantize` / packing path at
425    /// the end of training still handles the deployment-side
426    /// conversion to `i8`/`i4`/`i2` codes.
427    FakeQuantize {
428        bits: u8,
429        axis: Option<usize>,
430        ste: SteKind,
431        scale_mode: ScaleMode,
432    },
433
434    /// Learned Step Size Quantization (LSQ; Esser et al. 2020,
435    /// `arXiv:1902.08153`). Like `FakeQuantize` but the per-channel
436    /// `scale` is a *learned parameter*, passed as the second input.
437    /// Forward is identical to `FakeQuantize` with a fixed scale:
438    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
439    /// Backward computes both `dx` (STE) and `dscale[c]` via the
440    /// closed-form gradient:
441    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
442    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
443    /// `sign(z) · q_max`. Routinely beats per-batch and EMA at
444    /// tight bit widths (i2 / i3).
445    ///
446    /// Inputs: `[x, scale]`. `scale` is `[chan_dim]` f32 (matches
447    /// `axis`); for `axis = None` it's `[1]`.
448    FakeQuantizeLSQ {
449        bits: u8,
450        axis: Option<usize>,
451    },
452
453    /// Backward pass for `Op::FakeQuantizeLSQ`. Computes BOTH the
454    /// gradient w.r.t. `x` (STE) and the gradient w.r.t. `scale`
455    /// (closed-form). Output shape matches `x`; the `scale` gradient
456    /// is reduced separately by `LsqScaleGradient`.
457    /// Inputs: `[x, scale, dy]`. Output: `dx`, same shape as `x`.
458    FakeQuantizeLSQBackwardX {
459        bits: u8,
460        axis: Option<usize>,
461    },
462
463    /// Companion to `FakeQuantizeLSQBackwardX`: computes the
464    /// `[chan_dim]` per-channel scale gradient. Inputs `[x, scale, dy]`.
465    /// Output shape matches `scale`.
466    FakeQuantizeLSQBackwardScale {
467        bits: u8,
468        axis: Option<usize>,
469    },
470
471    // ── Element-wise binary ─────────────────────────────────────
472    /// Binary op with broadcasting: two inputs, output shape is broadcast result.
473    Binary(BinaryOp),
474
475    // ── Comparison ──────────────────────────────────────────────
476    /// Element-wise comparison: two inputs, Bool output.
477    Compare(CmpOp),
478
479    /// Select elements: cond (Bool), on_true, on_false → output.
480    Where,
481
482    /// Fused element-wise region (PLAN L2). Holds an N-step chain of
483    /// element-wise operations. Inputs are referenced by index 0..num_inputs;
484    /// each step's result can be referenced by later steps via
485    /// `ChainOperand::Step(idx)`. The output is the last step's result.
486    /// Emitted by `MarkElementwiseRegions` in `rlx-opt` from chains of
487    /// Activation/Cast/Binary/Compare/Where ops with single-consumer
488    /// intermediates and broadcast-compatible shapes. Backends that
489    /// don't have a region kernel can decompose back to the original
490    /// chain via `unfuse::unfuse_elementwise_regions`.
491    ///
492    /// `scalar_input_mask` is a per-input bitfield (bit `i` set ⇒
493    /// input `i` is a scalar broadcast — has shape `[1]`). Kept as a
494    /// fast-path indicator that lets kernels skip the modulo entirely
495    /// when they detect a scalar.
496    ///
497    /// `input_modulus[i]` is the per-input element count, used by
498    /// kernels to compute `arena[input_offs[i] + (gid % input_modulus[i])]`
499    /// — the trailing-shape broadcast pattern. `0` means "no broadcast"
500    /// (input matches the output element count; kernel reads `gid`
501    /// directly). `1` means scalar; any other value means the input
502    /// has fewer elements than the output and they tile by modulo.
503    /// The encoder only allows broadcasts where `out_elems % in_elems
504    /// == 0` so the modulo divides cleanly. Lets chains include bias /
505    /// scale / eps / mask factors that previously broke the chain at
506    /// a Binary op with mismatched shapes.
507    ElementwiseRegion {
508        chain: Vec<ChainStep>,
509        num_inputs: u32,
510        scalar_input_mask: u32,
511        input_modulus: [u32; 16],
512    },
513
514    // ── Linear algebra ──────────────────────────────────────────
515    /// Matrix multiply. Inputs: [.., M, K] × [.., K, N] → [.., M, N].
516    /// Batch dimensions are broadcast.
517    MatMul,
518
519    /// Matrix multiply with explicit dimension specification.
520    /// Like XLA's DotGeneral — handles arbitrary batch/contracting dims.
521    DotGeneral {
522        lhs_contracting: Vec<usize>,
523        rhs_contracting: Vec<usize>,
524        lhs_batch: Vec<usize>,
525        rhs_batch: Vec<usize>,
526    },
527
528    /// Batched dense linear solve. Inputs: `A [B, N, N]`,
529    /// `b [B, N]` or `b [B, N, K]`. Output: same shape as `b`.
530    ///
531    /// Per-batch independent solve — each `A[i]` and `b[i]` are
532    /// solved as a separate `Op::DenseSolve`. Emitted by vmap of
533    /// `Op::DenseSolve`. The CPU lowering loops over the batch
534    /// dimension calling `dgesv` per slice (LAPACK doesn't expose a
535    /// batched solve on Accelerate; cuSOLVER does on NVIDIA).
536    BatchedDenseSolve,
537
538    /// Dense linear solve `x = A⁻¹ · b` via LU factorization.
539    /// Inputs: `A [N, N]`, `b [N]` (or `b [N, K]` for multi-RHS).
540    /// Output: same shape as `b`.
541    ///
542    /// VJP via the implicit-function theorem:
543    ///   `dx = solve(Aᵀ, upstream)`
544    ///   `dA = -outer(dx, x)`   (x is the forward output)
545    ///   `db = dx`
546    /// The rule is dtype-agnostic; lowering is per-backend (Accelerate
547    /// `dgesv` / `sgesv`, cuSOLVER, etc.).
548    DenseSolve,
549
550    // ── Normalization ───────────────────────────────────────────
551    /// Layer normalization: input, gamma, beta → normalized output.
552    /// `axis` is the feature dimension (usually -1).
553    LayerNorm {
554        axis: i32,
555        eps: f32,
556    },
557
558    /// Group normalization on NCHW tensors: `input`, `gamma`, `beta` → same shape.
559    /// Normalizes over `(C/num_groups) × H × W` per group.
560    GroupNorm {
561        num_groups: usize,
562        eps: f32,
563    },
564
565    /// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
566    /// position (candle / SAM `LayerNorm2d` semantics — not PyTorch's H×W norm).
567    LayerNorm2d {
568        eps: f32,
569    },
570
571    /// Nearest-neighbor 2× upsample on NCHW (doubles spatial dims 2 and 3).
572    ResizeNearest2x,
573
574    /// RMS normalization: input, gamma → normalized output.
575    RmsNorm {
576        axis: i32,
577        eps: f32,
578    },
579
580    // ── Attention ───────────────────────────────────────────────
581    /// Scaled dot-product attention: Q, K, V, \[mask\] → output.
582    /// The compiler can lower this to fused SDPA or flash attention.
583    /// `mask_kind` controls how masking is applied — `Custom` reads from
584    /// the 4th input tensor; `None` / `Causal` / `SlidingWindow` skip the
585    /// mask load and apply the mask directly in the inner loop. See
586    /// `MaskKind` for the rationale.
587    ///
588    /// `score_scale`: when `Some(s)`, dot-product scores are multiplied by
589    /// `s` instead of the default `1/sqrt(head_dim)` (Gemma uses `head_dim^-0.5`
590    /// explicitly in config). `attn_logit_softcap`: when `Some(c)`, applies
591    /// `c * tanh(s/c)` to scores before softmax (Gemma 2).
592    Attention {
593        num_heads: usize,
594        head_dim: usize,
595        mask_kind: MaskKind,
596        score_scale: Option<f32>,
597        attn_logit_softcap: Option<f32>,
598    },
599
600    /// Rotary position embedding applied to one tensor: x, cos, sin → x_rotated.
601    /// Apply separately to Q and K. `head_dim` is the per-head width; `n_rot`
602    /// is how many leading dims get NeoX RoPE (pair offset `n_rot/2`). When
603    /// `n_rot < head_dim`, trailing dims are copied unchanged (Qwen3.5 MRoPE).
604    Rope {
605        head_dim: usize,
606        n_rot: usize,
607    },
608
609    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
610    AxialRope2d {
611        end_x: usize,
612        end_y: usize,
613        head_dim: usize,
614        num_heads: usize,
615        theta: f32,
616        repeat_factor: usize,
617    },
618
619    // ── Shape manipulation ──────────────────────────────────────
620    Reshape {
621        new_shape: Vec<i64>,
622    },
623    Transpose {
624        perm: Vec<usize>,
625    },
626    /// Select a contiguous slice along an axis.
627    Narrow {
628        axis: usize,
629        start: usize,
630        len: usize,
631    },
632    /// Concatenate along an axis.
633    Concat {
634        axis: usize,
635    },
636    /// Expand (broadcast) to a target shape.
637    Expand {
638        target_shape: Vec<i64>,
639    },
640    /// Gather elements by index along an axis (embedding lookup).
641    Gather {
642        axis: usize,
643    },
644
645    // ── Reduction ───────────────────────────────────────────────
646    /// Reduce along specified axes.
647    Reduce {
648        op: ReduceOp,
649        axes: Vec<usize>,
650        keep_dim: bool,
651    },
652
653    /// Selective scan (plan #15) — Mamba-style state-space model
654    /// step. The recurrence:
655    ///   `h[t] = exp(Δ[t] * A) * h[t-1] + Δ[t] * B[t] * x[t]`
656    ///   `y[t] = C[t] * h[t]`
657    /// where state `h` has dimension `state_size` and the input has
658    /// `(batch, seq, hidden)`.
659    ///
660    /// Inputs (in order):
661    ///   `x [b, s, h]`      f32 input
662    ///   `delta [b, s, h]`  f32 step size (per-position, per-channel)
663    ///   `a [h, n]`         f32 transition matrix (one per channel)
664    ///   `b [b, s, n]`      f32 input projection
665    ///   `c [b, s, n]`      f32 output projection
666    /// Output: `[b, s, h]` f32. State `h` is implicit; the kernel
667    /// scans through the seq dimension carrying it.
668    ///
669    /// `state_size` = `n` is exposed for the cost model.
670    SelectiveScan {
671        state_size: usize,
672    },
673
674    /// Gated DeltaNet linear-attention recurrence — the per-layer
675    /// kernel used by Qwen3.5/3.6 trunk "linear attention" blocks
676    /// (and Qwen3-Next, Kimi-Linear). Mirrors
677    /// `llama.cpp / src/models/delta-net-base.cpp` autoregressive
678    /// path; chunked + fused variants ride the same op identity.
679    ///
680    /// **Math (per token `t`, head `h`, state size `n`):**
681    /// state matrix `S[h, i, j]` is implicit (reset per batch).
682    /// ```text
683    ///   S[h]     *= exp(g[t,h])                     # scalar gate
684    ///   sk[h,j]   = Σ_i S[h,i,j] * k[t,h,i]
685    ///   d[h,j]    = (v[t,h,j] - sk[h,j]) * b[t,h]   # b = beta
686    ///   S[h,i,j] += k[t,h,i] * d[h,j]               # outer-prod
687    ///   o[t,h,j]  = Σ_i S[h,i,j] * (q[t,h,i] / √n)
688    /// ```
689    ///
690    /// Inputs:
691    ///   `q     [b, s, h_v, n]`  f32 queries (L2-normed by caller)
692    ///   `k     [b, s, h_v, n]`  f32 keys    (L2-normed by caller;
693    ///                            GQA-repeated to match `h_v`)
694    ///   `v     [b, s, h_v, n]`  f32 values
695    ///   `g     [b, s, h_v]`     f32 log-gate (exp'd inside kernel)
696    ///   `beta  [b, s, h_v]`     f32 delta-rule mixing factor
697    ///
698    /// Output: `[b, s, h_v, n]` f32.
699    ///
700    /// When `carry_state` is true, a sixth input `state [b, h_v, n, n]`
701    /// provides the initial SSM matrix per head; the kernel updates it
702    /// in place across the sequence and leaves the final state in the
703    /// same buffer (same layout as the internal scan state:
704    /// `state[h, i, j]` row-major over `(n, n)` per head).
705    GatedDeltaNet {
706        state_size: usize,
707        carry_state: bool,
708    },
709
710    /// Fused dequant + matmul (plan #5). The biggest LLM-bandwidth
711    /// win on Apple Silicon: dequantizes weights inside the matmul
712    /// inner loop, never materializing f32 weights.
713    ///
714    /// **BREAKING CHANGE in 0.2.0:** `num_inputs()` is now
715    /// scheme-dependent — **4** for legacy Int8 schemes, **2** for
716    /// the new GGUF K-quant schemes (their scales/mins live inside
717    /// the packed bytes, so no side-channel `scale` / `zp` tensors
718    /// are fed in). Callers that assumed a fixed 4-input contract
719    /// must inspect `scheme.is_gguf()` before reading inputs.
720    ///
721    /// Inputs (Int8 schemes — `scheme.is_gguf() == false`):
722    ///   `x [m, k]`             f32 activations
723    ///   `w_q [k, n]` packed    quantized weight bytes (i8 per
724    ///                          element for Int8 schemes; 4-bit
725    ///                          packed two-per-byte for Int4)
726    ///   `scale [k/block, n]`   per-block f32 dequant scale
727    ///   `zp    [k/block, n]`   per-block f32 zero-point
728    ///                          (zero-tensor if symmetric)
729    ///
730    /// Inputs (`Nvfp4Block` — fixed group size 16 along K):
731    ///   `x [m, k]`             f32 activations
732    ///   `w_q [k,n/2]` packed   FP4 E2M1 codes (unsigned nibble 0..15)
733    ///   `scale [k/16, n]` u8   FP8 E4M3 block scales (one byte / group)
734    ///   `global_scale [1]` f32 per-tensor scale (pass `[1.0]` if unused)
735    ///
736    /// Inputs (GGUF schemes — `scheme.is_gguf() == true`):
737    ///   `x [m, k]`             f32 activations
738    ///   `packed_w [bytes]`     raw GGUF super-block bytes; the
739    ///                          dequantizer reads the per-sub-block
740    ///                          scales / mins / quants directly out
741    ///                          of the buffer per the K-quant block
742    ///                          layout (no side tensors).
743    ///
744    /// Output: `[m, n]` f32.
745    ///
746    /// `block_size` (on the Int8 schemes only) is the number of
747    /// consecutive elements that share one (scale, zero_point) pair.
748    /// The Op carries enough metadata that the kernel doesn't need
749    /// a separate `QuantMap` lookup at run time.
750    DequantMatMul {
751        scheme: crate::quant::QuantScheme,
752    },
753
754    /// Real INT8-arithmetic matrix multiply with i32 accumulation.
755    /// Inputs (in order):
756    ///   `x      [M, K]`  i8 activations (zero-point = `x_zp`)
757    ///   `w      [K, N]`  i8 weights     (zero-point = `w_zp`)
758    ///   `bias   [N]`     i32 (in accumulator scale = `x_scale·w_scale`),
759    ///                    pass a zeros tensor for "no bias"
760    /// Output:  `[M, N]`  i8 (zero-point = `out_zp`)
761    ///
762    /// Per-element compute:
763    ///   `out[m,n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
764    /// where `mult = x_scale · w_scale / out_scale`.
765    ///
766    /// This is the same kernel shape `rlx-cortexm/src/dense.rs`
767    /// uses for on-device int8 inference, lifted into the IR so the
768    /// rlx-cpu backend can run a quantized graph directly (instead
769    /// of round-tripping through fake-quant Dequantize → MatMul →
770    /// Quantize). 2-D only — generalizing to batched comes when a
771    /// real workload demands it.
772    QMatMul {
773        x_zp: i32,
774        w_zp: i32,
775        out_zp: i32,
776        mult: f32,
777    },
778
779    /// Real INT8-arithmetic 2-D convolution with i32 accumulation.
780    /// Inputs:
781    ///   `x      [N, C_in, H, W]`              i8 (zero-point = `x_zp`)
782    ///   `w      [C_out, C_in/groups, kH, kW]` i8 (zero-point = `w_zp`)
783    ///   `bias   [C_out]`                      i32 in accumulator scale
784    /// Output: `[N, C_out, H_out, W_out]` i8 (zero-point = `out_zp`).
785    /// Same NCHW geometry contract as `Op::Conv`; same requantize
786    /// math as `Op::QMatMul` (per-element `acc·mult` rounded to i8).
787    QConv2d {
788        kernel_size: Vec<usize>,
789        stride: Vec<usize>,
790        padding: Vec<usize>,
791        dilation: Vec<usize>,
792        groups: usize,
793        x_zp: i32,
794        w_zp: i32,
795        out_zp: i32,
796        mult: f32,
797    },
798
799    /// Fused LoRA matmul: `out = x·W + scale * x·A·B`.
800    /// Inputs (in order): `x [m, k]`, `w [k, n]`, `a [k, r]`, `b [r, n]`.
801    /// `r` is the LoRA rank (typically 4-64). `scale` is the
802    /// per-adapter `alpha / rank` knob.
803    /// Plan #9: lifts LoRA from "three matmuls + an add" into one
804    /// kernel that keeps the rank-r intermediate in registers.
805    LoraMatMul {
806        scale: f32,
807    },
808
809    /// Fused sampling kernel: logits → optional top-k filter →
810    /// optional top-p truncation → softmax → multinomial sample.
811    /// One f32-encoded sampled token id per batch row (output
812    /// shape `[batch]`).
813    ///
814    /// `temperature == 1.0` samples from the unscaled softmax;
815    /// lower → sharper (approaching argmax as it nears 0), higher → flatter. `top_k == 0` disables.
816    /// `top_p == 1.0` disables. `seed` is the Philox seed; pass 0
817    /// for "use process-global counter" (still deterministic
818    /// given the call order).
819    /// Borrowed from MAX's nn/sampling.mojo (#42 in PLAN.md).
820    /// Latency-critical: never materializes the full softmax
821    /// distribution on the host.
822    Sample {
823        top_k: usize,     // 0 = disabled
824        top_p: f32,       // 1.0 = disabled
825        temperature: f32, // 1.0 = neutral
826        seed: u64,        // 0 = use thread-local counter
827    },
828
829    /// Inclusive cumulative sum along an axis. Same shape in/out.
830    /// Underpins ragged-tensor offsets, sampling (top-p prefix sum),
831    /// and sequence-position math (#44 in PLAN.md).
832    /// `exclusive=true` shifts the result so output\[0\] = 0 (useful
833    /// for offset arrays where the first segment starts at 0).
834    Cumsum {
835        axis: i32,
836        exclusive: bool,
837    },
838
839    /// Softmax along an axis (reduction + element-wise).
840    Softmax {
841        axis: i32,
842    },
843
844    /// Top-K **indices** along the last axis. Output shape `[..., k]`,
845    /// f32-encoded indices (rlx is f32-only at the I/O boundary).
846    /// To recover the values, follow with a `Gather` against the
847    /// original tensor — works because Gather already supports any axis.
848    /// Ties broken by smaller index (matches NumPy / PyTorch
849    /// `torch.topk(..., largest=True, sorted=True)`).
850    /// Used by MoE gating; also useful for beam search.
851    TopK {
852        k: usize,
853    },
854
855    /// Indexed batched matmul. The MoE GEMM primitive.
856    /// Inputs: `[input, weight, expert_idx]`
857    ///   input       : [M, K]                — per-token activations
858    ///   weight      : [num_experts, K, N]   — stacked expert weights
859    ///   expert_idx  : \[M\]                   — f32-encoded expert id per token
860    /// Output         : [M, N]                — output\[i\] = input\[i\] @ weight[expert_idx\[i\]]
861    /// Naive impl on both backends; future work can replace with a
862    /// segmented/grouped GEMM when there's a real workload.
863    GroupedMatMul,
864
865    /// Fused GGUF K-quant dequant + [`Op::GroupedMatMul`]. Same three
866    /// inputs as `GroupedMatMul`, but `weight` is a U8 tensor holding
867    /// `num_experts` contiguous packed expert slabs (GGML layout, expert
868    /// dimension outermost). Scales live inside the packed bytes.
869    DequantGroupedMatMul {
870        scheme: crate::quant::QuantScheme,
871    },
872
873    /// Dequant a packed MoE expert stack to F32 `[num_experts, K, N]` in
874    /// GroupedMatMul layout. Input: U8 packed bytes; output shape is
875    /// declared on the node (`[E, K, N]`).
876    DequantMoEWeights {
877        scheme: crate::quant::QuantScheme,
878    },
879
880    /// Scatter-add into a destination tensor. The "unpermute" half of
881    /// MoE routing (also useful for embedding gradient updates).
882    /// Inputs: `[updates, indices]`
883    ///   updates : [num_updates, trailing]   — values to add
884    ///   indices : \[num_updates\]             — f32-encoded destination row
885    /// Output    : [out_dim, trailing]       — output[indices\[i\]] += updates\[i\]
886    /// `out_dim` is taken from the node's declared output shape.
887    /// Initial output is zero; multiple updates to the same row
888    /// accumulate (sequentially on CPU; with atomic-add on Metal).
889    ScatterAdd,
890
891    // ── Convolution ─────────────────────────────────────────────
892    /// 2D convolution on NCHW tensors. Also exposed as [`OpKind::Conv`] / `conv2d`.
893    /// Weight layout: `[C_out, C_in / groups, kH, kW]`.
894    Conv {
895        kernel_size: Vec<usize>,
896        stride: Vec<usize>,
897        padding: Vec<usize>,
898        dilation: Vec<usize>,
899        groups: usize,
900    },
901
902    /// 2D transposed convolution on NCHW. Weight layout (PyTorch):
903    /// `[C_in, C_out / groups, kH, kW]`.
904    ConvTranspose2d {
905        kernel_size: Vec<usize>,
906        stride: Vec<usize>,
907        padding: Vec<usize>,
908        dilation: Vec<usize>,
909        output_padding: Vec<usize>,
910        groups: usize,
911    },
912
913    // ── Pooling ─────────────────────────────────────────────────
914    Pool {
915        kind: ReduceOp,
916        kernel_size: Vec<usize>,
917        stride: Vec<usize>,
918        padding: Vec<usize>,
919    },
920
921    // ── Backward / training ops ─────────────────────────────────
922    //
923    // Closed-form gradient nodes emitted by `rlx-opt::autodiff`.
924    // Pairing a forward op with a dedicated backward op (instead of
925    // composing it from primitives) keeps the gradient kernel simple
926    // and lets the backend recompute argmax / masks / softmax inline.
927    /// ReLU backward: `dx = dy where x > 0 else 0`.
928    /// Inputs: `[x, dy]` — both same shape; output matches.
929    ReluBackward,
930
931    /// Element-wise complex squared-magnitude: `|z|² = z.re² + z.im²`.
932    /// Input: 1 tensor with `DType::C64`. Output: same shape but
933    /// `DType::F32`. The natural real-valued loss surface for
934    /// Wirtinger reverse-mode AD on complex graphs — pair with
935    /// [`Op::ComplexNormSqBackward`].
936    ComplexNormSq,
937
938    /// Element-wise complex conjugate: `z̄ = z.re - i·z.im` per element.
939    /// Input: 1 tensor with `DType::C64`. Output: same shape, same dtype.
940    /// Used by Wirtinger VJP rules on `Op::Binary` over C64 (the rule
941    /// for `y = a·b` is `dL/dā = upstream · conj(b)`, etc.).
942    Conjugate,
943
944    /// Backward for [`Op::ComplexNormSq`] under Wirtinger calculus.
945    /// `f(z) = |z|² = z·z̄`, so `∂f/∂z̄ = z`. Given upstream real
946    /// cotangent `g` (same shape as the forward output), the C64
947    /// gradient with respect to `z` is `g · z` element-wise, returned
948    /// in C64 storage `[re_g·re_z, re_g·im_z]` per element.
949    ///
950    /// Inputs: `[z (C64), g (F32)]` — both same logical shape; output
951    /// matches `z` (C64).
952    ComplexNormSqBackward,
953
954    /// LayerNorm backward w.r.t. the input. Computes
955    ///   `d_x[..., d] = inv_std · (dy·γ - mean(dy·γ) - x̂·mean(dy·γ·x̂))`
956    /// over the feature axis, where `x̂ = (x - mean)/std` is recomputed
957    /// inline from `x`. Inputs: `[x, gamma, dy]`; output shape = `x.shape`.
958    /// Currently lowers axis=-1 only (matches the forward thunk).
959    LayerNormBackwardInput {
960        axis: i32,
961        eps: f32,
962    },
963
964    /// LayerNorm backward w.r.t. gamma. Computes
965    ///   `d_gamma[d] = Σ_{batch} dy[..., d] · x̂[..., d]`
966    /// — sums the per-element product of upstream and the (recomputed)
967    /// normalized input over the leading axes. Inputs: `[x, dy]`;
968    /// output shape = `gamma.shape` (= 1-D feature axis).
969    LayerNormBackwardGamma {
970        axis: i32,
971        eps: f32,
972    },
973
974    /// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`; output = `x.shape`.
975    RmsNormBackwardInput {
976        axis: i32,
977        eps: f32,
978    },
979
980    /// RMSNorm backward w.r.t. gamma. Inputs `[x, gamma, beta, dy]`; output = `gamma.shape`.
981    RmsNormBackwardGamma {
982        axis: i32,
983        eps: f32,
984    },
985
986    /// RMSNorm backward w.r.t. beta. Inputs `[x, gamma, beta, dy]`; output = `beta.shape`.
987    RmsNormBackwardBeta {
988        axis: i32,
989        eps: f32,
990    },
991
992    /// RoPE backward w.r.t. `x`. Inputs `[dy, cos, sin]`; output = `dy.shape`.
993    RopeBackward {
994        head_dim: usize,
995        n_rot: usize,
996    },
997
998    /// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
999    GroupNormBackwardInput {
1000        num_groups: usize,
1001        eps: f32,
1002    },
1003
1004    /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`; output = `gamma.shape`.
1005    GroupNormBackwardGamma {
1006        num_groups: usize,
1007        eps: f32,
1008    },
1009
1010    /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`; output = `beta.shape`.
1011    GroupNormBackwardBeta {
1012        num_groups: usize,
1013        eps: f32,
1014    },
1015
1016    /// Cumsum backward along `axis`. Inputs `[dy]`; output matches forward input shape.
1017    CumsumBackward {
1018        axis: i32,
1019        exclusive: bool,
1020    },
1021
1022    /// Gather backward (scatter-add into table). Inputs `[dy, indices]`; output = table shape.
1023    /// `axis` matches forward [`Op::Gather`].
1024    GatherBackward {
1025        axis: i32,
1026    },
1027
1028    /// Generic element-wise activation backward. `kind` selects the
1029    /// closed-form derivative `d/dx act(x)`. Inputs: `[x, dy]`; output
1030    /// shape matches `x`. The kernel computes `d/dx · dy` per element.
1031    ///
1032    /// Closed forms (all element-wise):
1033    /// * `Gelu`     — exact derivative of erf-based GELU.
1034    /// * `GeluApprox` — derivative of the tanh approximation
1035    ///   `0.5 x (1 + tanh(√(2/π)(x + 0.044715 x³)))`.
1036    /// * `Silu`     — `σ(x)·(1 + x·(1 - σ(x)))`.
1037    /// * `Sigmoid`  — `σ(x)·(1 - σ(x))`.
1038    /// * `Tanh`     — `1 - tanh(x)²`.
1039    /// * `Exp`      — `exp(x)`.
1040    /// * `Log`      — `1 / x`.
1041    /// * `Sqrt`     — `0.5 / sqrt(x)`.
1042    /// * `Rsqrt`    — `-0.5 · x^(-3/2)`.
1043    /// * `Neg`      — `-1`.
1044    /// * `Abs`      — `sign(x)` (zero at x=0).
1045    /// * `Sin`      — `cos(x)`.
1046    /// * `Cos`      — `-sin(x)`.
1047    /// * `Tan`      — `1 + tan²(x)`.
1048    /// * `Atan`     — `1 / (1 + x²)`.
1049    /// * `Relu`     — kept here for completeness; the dedicated
1050    ///   `ReluBackward` op is preferred for relu and is what the
1051    ///   autodiff pass actually emits.
1052    ActivationBackward {
1053        kind: Activation,
1054    },
1055
1056    /// Backward for `Op::FakeQuantize` under a non-default STE.
1057    /// Inputs `[x, dy]`: the forward input and the upstream
1058    /// gradient. Output `dx` same shape. The `bits`/`axis`/`ste`
1059    /// fields must match the forward op so the kernel computes the
1060    /// same per-channel scale and applies the right STE attenuation.
1061    /// For `SteKind::Identity` this op is unnecessary — autodiff
1062    /// just routes `upstream` through unchanged.
1063    FakeQuantizeBackward {
1064        bits: u8,
1065        axis: Option<usize>,
1066        ste: SteKind,
1067    },
1068
1069    /// 2D max-pool backward. Routes each element of `dy` back into the
1070    /// position in `x`'s window where the forward max was taken.
1071    /// Inputs: `[x, dy]` with `x [N, C, H, W]` and
1072    /// `dy [N, C, H_out, W_out]`. Output: same shape as `x`.
1073    /// Carries the forward pool's geometry so the kernel can recompute
1074    /// the argmax position per window without a saved-indices tensor.
1075    MaxPool2dBackward {
1076        kernel_size: Vec<usize>,
1077        stride: Vec<usize>,
1078        padding: Vec<usize>,
1079    },
1080
1081    /// 2D conv backward w.r.t. input. Computes `dx = conv_transpose(dy, w)`.
1082    /// Inputs: `[dy, w]` with `dy [N, C_out, H_out, W_out]` and
1083    /// `w [C_out, C_in/groups, kH, kW]`. Output: `[N, C_in, H, W]`
1084    /// (declared on the node — caller knows the original input shape).
1085    /// Geometry is the forward conv's parameters, not the transposed
1086    /// conv's.
1087    Conv2dBackwardInput {
1088        kernel_size: Vec<usize>,
1089        stride: Vec<usize>,
1090        padding: Vec<usize>,
1091        dilation: Vec<usize>,
1092        groups: usize,
1093    },
1094
1095    /// 2D conv backward w.r.t. weight. Computes
1096    /// `dw[c_out, c_in, kh, kw] = sum_{n,h_out,w_out} x[n,c_in,...] * dy[n,c_out,h_out,w_out]`.
1097    /// Inputs: `[x, dy]`. Output: `[C_out, C_in/groups, kH, kW]`.
1098    Conv2dBackwardWeight {
1099        kernel_size: Vec<usize>,
1100        stride: Vec<usize>,
1101        padding: Vec<usize>,
1102        dilation: Vec<usize>,
1103        groups: usize,
1104    },
1105
1106    /// Fused softmax + cross-entropy loss with integer (f32-encoded)
1107    /// targets — the standard classification loss. Per-row output:
1108    /// `loss[n] = -log(softmax(logits[n])[labels[n]])`.
1109    /// Inputs: `[logits, labels]` with `logits [N, C]` and
1110    /// `labels [N]` (f32-encoded class indices). Output: `[N]`.
1111    /// Caller does the `Reduce::Mean` if they want a scalar.
1112    SoftmaxCrossEntropyWithLogits,
1113
1114    /// Backward of the fused loss above. Emits
1115    /// `dlogits[n,c] = (softmax(logits[n])[c] - one_hot(labels)[n,c]) * d_loss[n]`.
1116    /// Inputs: `[logits, labels, d_loss]`. Output: `[N, C]` (same shape
1117    /// as `logits`). Recomputes the softmax inline rather than threading
1118    /// it through from the forward node.
1119    SoftmaxCrossEntropyBackward,
1120
1121    /// Backward of [`Op::Attention`]. Recomputes scaled `QK^T`, applies
1122    /// the same `mask_kind` as the forward op, softmaxes scores, then
1123    /// emits **one** of `dQ`, `dK`, or `dV` selected by [`AttentionBwdWrt`].
1124    /// Autodiff emits three nodes (one per `wrt`) so each output shape
1125    /// stays a normal single-output MIR node.
1126    ///
1127    /// Inputs: `[q, k, v, dy]` plus optional mask when `mask_kind` is
1128    /// [`MaskKind::Custom`] or [`MaskKind::Bias`] (same convention as
1129    /// forward). Output shape matches `q`, `k`, or `v` respectively.
1130    AttentionBackward {
1131        num_heads: usize,
1132        head_dim: usize,
1133        mask_kind: MaskKind,
1134        wrt: AttentionBwdWrt,
1135    },
1136
1137    // ── Fused operations (created by optimization passes) ──────
1138    /// Fused matmul + bias + activation. Created from MatMul → Add → Activation.
1139    FusedMatMulBiasAct {
1140        activation: Option<Activation>,
1141    },
1142
1143    /// Fused residual + optional bias + layer norm.
1144    /// Created from Add(x, residual) → [Add(bias)] → LayerNorm.
1145    FusedResidualLN {
1146        has_bias: bool,
1147        eps: f32,
1148    },
1149
1150    /// Fused residual + optional bias + RMS norm.
1151    /// Created from Add(x, residual) → [Add(bias)] → RmsNorm.
1152    FusedResidualRmsNorm {
1153        has_bias: bool,
1154        eps: f32,
1155    },
1156
1157    /// Fused SwiGLU: split input into up/gate halves, silu(gate) * up.
1158    /// Created from Split → Silu → Mul when fed by a concatenated matmul.
1159    ///
1160    /// `cast_to`: optional output dtype — when `Some(dt)` the kernel casts
1161    /// its result from the input dtype to `dt` in-register, saving a
1162    /// separate cast pass. Reserved for future fp8/fp4 quantization paths;
1163    /// for f32→f16 mixed precision the AutoMixedPrecision pass already
1164    /// inserts a Cast node so this stays `None` in current pipelines.
1165    FusedSwiGLU {
1166        cast_to: Option<DType>,
1167        /// When `true`, the concatenated input stores gate in the low half
1168        /// `[..., 0..N)` and up in the high half `[..., N..2N)` — the layout
1169        /// produced when gate projection is emitted before up in the builder.
1170        /// Default `false`: up @ low, gate @ high (canonical concat order).
1171        gate_first: bool,
1172    },
1173
1174    /// Fused full transformer layer: attention block + residual+LN + FFN + residual+LN.
1175    /// All intermediates resident in registers/threadgroup memory; one kernel
1176    /// per layer instead of ~30 (the CPU's batch=1 win, lifted to IR so any
1177    /// backend can implement it as a monolithic kernel).
1178    ///
1179    /// Inputs: hidden, qkv_w, qkv_b, out_w, out_b,
1180    ///         ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask
1181    /// Output: same shape as hidden.
1182    ///
1183    /// **Backend status:** same as FusedAttentionBlock. CPU implements
1184    /// the L1-cache-resident merge at the thunk level. Metal deferred —
1185    /// requires a single MSL kernel for the whole layer to actually
1186    /// beat the unfused path. Multi-day work; revisit when there's a
1187    /// model whose Metal inference is bottlenecked here rather than on
1188    /// the wait latency floor.
1189    FusedTransformerLayer {
1190        num_heads: usize,
1191        head_dim: usize,
1192        intermediate_size: usize,
1193        eps1: f32,
1194        eps2: f32,
1195        activation: Activation,
1196        has_bias: bool,
1197    },
1198
1199    /// Fused attention block: QKV projection → split → \[RoPE\] → SDPA → output projection.
1200    /// Created by FuseAttentionBlock pass when batch*seq is small.
1201    /// All intermediates stay in L1 cache — no arena writes between ops.
1202    ///
1203    /// Inputs (in order):
1204    ///   hidden, qkv_w, out_w, mask,
1205    ///   [qkv_b, out_b]      if has_bias,
1206    ///   [rope_cos, rope_sin] if has_rope
1207    ///
1208    /// **Backend status (Phase C finalize):**
1209    ///   CPU  — implemented at the *thunk* level: the CPU schedule
1210    ///          recognizes the multi-thunk pattern and merges into
1211    ///          a single FusedAttnBlock that keeps Q/K/V in stack
1212    ///          buffers across stages (the L1-cache win).
1213    ///   Metal — **deferred**. A dispatch-wrapper version (chaining
1214    ///          existing kernels) buys nothing the unfused Metal path
1215    ///          doesn't already get, since per-run cost is dominated
1216    ///          by `wait_until_completed` (~150 µs), not encode. The
1217    ///          real win is a single MSL kernel keeping Q/K/V in
1218    ///          threadgroup memory across stages — multi-day work.
1219    ///          Until then, Metal runs the unfused chain (one matmul,
1220    ///          three narrows, two ropes, attention, one matmul) — all
1221    ///          covered in op_coverage and parity_harness.
1222    FusedAttentionBlock {
1223        num_heads: usize,
1224        head_dim: usize,
1225        has_bias: bool,
1226        has_rope: bool,
1227    },
1228
1229    // ── Control flow (subgraphs as op payloads) ─────────────────
1230    //
1231    // Status: IR is defined; helper `run_if` / `run_while` exist in
1232    // rlx-runtime/src/subgraph.rs; **executor wiring is not yet
1233    // implemented** (both CPU thunk and Metal thunk fall through to
1234    // `Thunk::Nop` for these ops). Wiring requires:
1235    //   1. Recursive subgraph compile at parent-compile time.
1236    //   2. Per-subgraph input/output binding through the arena.
1237    //   3. Schedule-level dispatch when the predicate / loop cond is
1238    //      resolved at runtime.
1239    // Estimate: 4–6 hours of focused work + parity tests. Deferred
1240    // because no current in-tree model uses these ops;
1241    // surface area without a validation target invites silent bugs.
1242    /// Conditional: pick between two subgraphs based on a boolean predicate.
1243    /// Inputs: [predicate, ...captures (used inside both branches)].
1244    /// `then_branch` and `else_branch` are sub-graphs that share the
1245    /// captured inputs and must produce identically-shaped outputs.
1246    /// Used for: shape-dependent execution, batched inference of
1247    /// dynamic-length sequences with padding masks.
1248    If {
1249        then_branch: Box<crate::Graph>,
1250        else_branch: Box<crate::Graph>,
1251    },
1252
1253    /// Loop: iterate `body` while `cond` evaluates true.
1254    /// Inputs: [...initial loop-carried values].
1255    /// `cond`'s single output is a Bool scalar.
1256    /// `body`'s outputs become the next iteration's loop-carried inputs.
1257    /// Outputs of While are the values after the final iteration.
1258    /// Used for: KV-cache-driven autoregressive generation, beam search.
1259    While {
1260        cond: Box<crate::Graph>,
1261        body: Box<crate::Graph>,
1262        max_iterations: Option<usize>,
1263    },
1264
1265    /// Bounded-length loop with a fixed-shape carry, optional per-step
1266    /// inputs, and optional stacked output. Mirrors JAX's `lax.scan`.
1267    ///
1268    /// Body signature: `(carry, x_t_0, ..., x_t_{num_xs-1}) → carry_next`
1269    /// — `1 + num_xs` Op::Inputs in NodeId construction order (first
1270    /// declared is the carry; the remaining `num_xs` are per-step
1271    /// slices). Single output (the next carry).
1272    ///
1273    /// Outer Op::Scan inputs (in order):
1274    ///   `[init_carry, xs_0, xs_1, ..., xs_{num_xs-1}]`
1275    /// Each `xs_i` has shape `[length, *per_step_shape_i]`; the body
1276    /// sees `xs_i[t]` (a `per_step_shape_i` slice) on iteration `t`.
1277    ///
1278    /// Outer Op::Scan output:
1279    ///   * `save_trajectory == false` — final carry, shape `*carry`.
1280    ///   * `save_trajectory == true`  — stacked trajectory of carries,
1281    ///     shape `[length, *carry]`. Row `t` is the carry after step
1282    ///     `t+1`, so row `length-1` matches the no-trajectory case.
1283    ///
1284    /// Mirrors JAX's `lax.scan`. Common uses include time-stepping
1285    /// integrators with time-varying drives, Mamba-style SSM scans
1286    /// reading per-step inputs, and RNN-style sequence processing.
1287    Scan {
1288        body: Box<crate::Graph>,
1289        length: u32,
1290        save_trajectory: bool,
1291        /// Number of "broadcast" inputs — values that are constant
1292        /// across iterations. Outer scan inputs in order:
1293        ///   `[init, bcast_0..bcast_{B-1}, xs_0..xs_{X-1}]`
1294        /// Body Op::Inputs in NodeId order:
1295        ///   `[carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]`
1296        /// CPU executor fills bcast slots ONCE before the iteration
1297        /// loop (xs slots are filled per-step). The reverse-mode AD
1298        /// pre-pass materialises each bcast into an xs of shape
1299        /// `[length, *bcast]` via broadcast `Mul` so the rest of the
1300        /// VJP / executor pipeline can stay unchanged. `0` (default)
1301        /// keeps the original carry+xs scan shape.
1302        num_bcast: u32,
1303        /// Number of per-step `xs` inputs. Total outer Op::Scan
1304        /// inputs is `1 + num_bcast + num_xs`.
1305        num_xs: u32,
1306        /// Number of trajectory checkpoints when `save_trajectory ==
1307        /// true`. `0` means "save all `length` rows" (default). A
1308        /// positive value `K` means save only `K` evenly-spaced rows
1309        /// at indices `floor(t * length / K)` for `t in 0..K`. Used
1310        /// by recursive checkpointed AD: store O(√T) carries during
1311        /// forward, recompute the rest in the backward pass.
1312        ///
1313        /// When `0` (or `K == length`), the saved trajectory has
1314        /// shape `[length, *carry]` — same as the original behavior.
1315        /// When `0 < K < length`, the saved trajectory has shape
1316        /// `[K, *carry]`.
1317        num_checkpoints: u32,
1318    },
1319
1320    /// Reverse-mode AD companion to `Op::Scan` — extracts the carry
1321    /// gradient `dinit`. Walks `t = length-1 .. 0`, applying `body_vjp`
1322    /// to thread `dcarry` back through the time loop.
1323    ///
1324    /// Inputs (in order):
1325    ///   `[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]`
1326    /// Output: `dinit`, shape = carry shape.
1327    ///
1328    /// `body_vjp` is the result of
1329    /// `autodiff::grad(body, [carry_id, xs_0_id, ..., xs_{num_xs-1}_id])`
1330    /// — a graph with `1 + num_xs + 1` Inputs (carry + x_t_i for each
1331    /// xs + `"d_output"`) and `1 + num_xs` outputs
1332    /// (dcarry + dx_t_i for each xs). This op reads `outputs[0]` =
1333    /// dcarry; the sibling [`Self::ScanBackwardXs`] reads the
1334    /// `outputs[1 + xs_idx]` slot for each xs gradient.
1335    ScanBackward {
1336        body_vjp: Box<crate::Graph>,
1337        length: u32,
1338        save_trajectory: bool,
1339        num_xs: u32,
1340        /// When `0` or equal to `length`, the trajectory input has
1341        /// shape `[length, *carry]` — every step's carry is cached
1342        /// (`CheckpointStrategy::All`). When `0 < K < length`, the
1343        /// trajectory input has shape `[K, *carry]` and the executor
1344        /// recomputes intermediate carries via `forward_body` between
1345        /// checkpoints. `forward_body` must be `Some` whenever
1346        /// `0 < num_checkpoints < length`.
1347        num_checkpoints: u32,
1348        /// Forward body (the same `body` from the forward Op::Scan).
1349        /// Required when `0 < num_checkpoints < length` so the
1350        /// executor can recompute carries between saved checkpoints.
1351        /// `None` for the All strategy (no recompute needed).
1352        forward_body: Option<Box<crate::Graph>>,
1353    },
1354
1355    /// Companion to [`Self::ScanBackward`] that extracts one stacked
1356    /// per-step `dxs_i` (shape `[length, *per_step_xs_i]`). Same inputs
1357    /// and same `body_vjp` graph as ScanBackward — `xs_idx` selects
1358    /// which body_vjp output to stack into the result.
1359    ///
1360    /// Note: each ScanBackwardXs runs its own backward loop. A future
1361    /// optimization can fuse them into a single multi-output backward
1362    /// kernel; for now it's `1 + num_xs` independent sweeps.
1363    ScanBackwardXs {
1364        body_vjp: Box<crate::Graph>,
1365        length: u32,
1366        save_trajectory: bool,
1367        num_xs: u32,
1368        xs_idx: u32,
1369        num_checkpoints: u32,
1370        forward_body: Option<Box<crate::Graph>>,
1371    },
1372
1373    /// CPU reference 3D Gaussian splat forward render.
1374    ///
1375    /// Seven flat F32 inputs (scene buffers + camera/render meta):
1376    ///   0. positions `[N*3]`
1377    ///   1. scales `[N*3]` (log-space)
1378    ///   2. rotations `[N*4]` (xyzw)
1379    ///   3. opacities `[N]` (logit)
1380    ///   4. colors `[N*3]` (linear RGB)
1381    ///   5. sh_coeffs `[N * sh_coeff_count * 3]`
1382    ///   6. meta `[23]` — camera position/target/up/fov/near/far, background RGB,
1383    ///      then width/height/tile_size/radius_scale/alpha_cutoff/max_splat_steps/
1384    ///      transmittance_threshold/max_list_entries as f32 bit-patterns.
1385    ///
1386    /// Output: `[height * width * 4]` linear RGBA (display gamma baked in).
1387    /// Build via [`crate::Graph::gaussian_splat_render`].
1388    ///
1389    /// Differentiable backward is not implemented in v1; autodiff treats this
1390    /// op as non-differentiable (same as [`Op::Sample`]).
1391    GaussianSplatRender {
1392        width: u32,
1393        height: u32,
1394        tile_size: u32,
1395        radius_scale: f32,
1396        alpha_cutoff: f32,
1397        max_splat_steps: u32,
1398        transmittance_threshold: f32,
1399        max_list_entries: u32,
1400    },
1401
1402    /// Backward pass for [`Self::GaussianSplatRender`].
1403    ///
1404    /// Eight inputs: the same seven as forward plus `d_loss_rgba` `[W*H*4]`
1405    /// (only RGB channels are used). Re-runs the training forward internally.
1406    ///
1407    /// Output: packed gradients
1408    /// `[positions(3N) | scales(3N) | rotations(4N) | opacities(N) | colors(3N) | sh(N*sh*3)]`.
1409    /// Unpack via [`crate::ops::splat::unpack_gaussian_splat_packed_grads`].
1410    GaussianSplatRenderBackward {
1411        width: u32,
1412        height: u32,
1413        tile_size: u32,
1414        radius_scale: f32,
1415        alpha_cutoff: f32,
1416        max_splat_steps: u32,
1417        transmittance_threshold: f32,
1418        max_list_entries: u32,
1419        loss_grad_clip: f32,
1420        sh_band: u32,
1421        max_anisotropy: f32,
1422    },
1423
1424    /// Strict IR stage 1: project, bin, sort, build per-pixel rays.
1425    ///
1426    /// Seven inputs (same scene + meta as [`Self::GaussianSplatRender`]). Output: packed
1427    /// prepare buffer (see `rlx_splat::prep_layout::prep_packed_len`).
1428    GaussianSplatPrepare {
1429        width: u32,
1430        height: u32,
1431        tile_size: u32,
1432        radius_scale: f32,
1433        alpha_cutoff: f32,
1434        max_splat_steps: u32,
1435        transmittance_threshold: f32,
1436        max_list_entries: u32,
1437    },
1438
1439    /// Strict IR stage 2: tile raster from [`Self::GaussianSplatPrepare`] output.
1440    ///
1441    /// Inputs: `prep` packed buffer, `meta` `[23]`. Output: `[width * height * 4]` RGBA.
1442    GaussianSplatRasterize {
1443        width: u32,
1444        height: u32,
1445        tile_size: u32,
1446        alpha_cutoff: f32,
1447        max_splat_steps: u32,
1448        transmittance_threshold: f32,
1449        max_list_entries: u32,
1450    },
1451
1452    /// User-registered custom op. `name` keys into the
1453    /// [`crate::op_registry`] for shape inference, autodiff, and
1454    /// per-backend execution. `attrs` is an opaque blob passed
1455    /// through to those callbacks (FFT direction, SparseLU
1456    /// reordering strategy, etc.). `num_inputs` is captured at
1457    /// construction time so [`Op::num_inputs`] stays infallible
1458    /// without a registry lookup. Build via [`crate::Graph::custom_op`].
1459    Custom {
1460        name: String,
1461        num_inputs: u32,
1462        attrs: Vec<u8>,
1463    },
1464
1465    /// 1D Fast Fourier Transform along the last axis.
1466    ///
1467    /// Convention: complex tensors are represented as 2N real-block
1468    /// — the input shape is `[..., 2N]` along the last axis, with
1469    /// the first N elements the real part and the second N the
1470    /// imaginary part. Output shape matches the input. Last axis
1471    /// length must be even (and a power of 2 for the v1 radix-2
1472    /// kernel; other sizes will eventually go through mixed-radix).
1473    ///
1474    /// Both forward and inverse are **unnormalized** (no 1/N scale):
1475    ///   `fft(x)[k] = Σ x[n]·exp(-2πi·nk/N)`
1476    ///   `ifft(y)[n] = Σ y[k]·exp(+2πi·nk/N)`
1477    /// so `ifft(fft(x)) = N·x`; divide by N to recover the round-trip
1478    /// identity. This matches numpy's `fft.fft` / `N·fft.ifft` convention.
1479    ///
1480    /// The unnormalized choice keeps both AD rules free of scaling:
1481    ///   * reverse-mode VJP: `VJP(fft) = ifft`, `VJP(ifft) = fft`
1482    ///     (transpose of the DFT matrix over the 2N-real-block view
1483    ///     equals the unnormalized inverse).
1484    ///   * forward-mode JVP: same op, same direction — FFT is linear,
1485    ///     so the JVP is the linear map itself, not its transpose.
1486    ///
1487    /// CPU paths exist for both `DType::F32` and `DType::F64` on the
1488    /// 2N-real-block layout. Native `DType::C64` and non-power-of-two
1489    /// sizes (Bluestein / mixed-radix) are not implemented; ND FFT
1490    /// and non-CPU backend lowerings are deferred.
1491    Fft {
1492        inverse: bool,
1493    },
1494
1495    /// User-defined sub-graph with optional override AD rules.
1496    /// Mirrors JAX's `custom_vjp` / `custom_jvp` decorators: the
1497    /// caller wraps a forward computation and supplies its own
1498    /// reverse- and/or forward-mode AD bodies. Useful when:
1499    ///   * The forward is iterative (Newton, fixed-point) and
1500    ///     differentiating through the loop is wasteful — the
1501    ///     vjp_body computes the implicit-function gradient at the
1502    ///     converged point in one shot.
1503    ///   * The math has a closed-form gradient that's much cheaper
1504    ///     than autodiff.
1505    ///   * The forward op is non-differentiable by tracing
1506    ///     (sampling, argmax) and the user wants a smooth surrogate.
1507    ///
1508    /// **fwd_body**: `num_inputs` Op::Inputs in NodeId construction
1509    /// order, one Op::Output (the primal y). Forward execution
1510    /// inlines this body once.
1511    ///
1512    /// **vjp_body** (optional): Op::Inputs are `num_inputs` primal
1513    /// inputs in NodeId order, plus two special-named Inputs —
1514    /// `"primal_output"` (the y from forward) and `"d_output"` (the
1515    /// upstream gradient). Outputs: `num_inputs` tensors in
1516    /// `set_outputs` order, matching the gradients of each primal
1517    /// input. When `None`, reverse-mode AD recurses into fwd_body
1518    /// — same as if the op were inlined.
1519    ///
1520    /// **jvp_body** (optional): Op::Inputs are `num_inputs` primal
1521    /// inputs in NodeId order, `num_inputs` special-named Inputs
1522    /// `"tangent_0"..="tangent_{num_inputs-1}"` carrying each input's
1523    /// tangent, and an optional special-named `"primal_output"` Input
1524    /// (the y from forward, useful when the JVP must be evaluated at
1525    /// a converged / nonlinear point — e.g. IFT-style forward-mode AD
1526    /// of an iterative solver). Output: 1 tensor (the tangent of y).
1527    /// When `None`, forward-mode AD recurses into fwd_body.
1528    ///
1529    /// `num_inputs` is captured so [`Op::num_inputs`] stays
1530    /// infallible. Build via [`crate::Graph::custom_fn`].
1531    CustomFn {
1532        fwd_body: Box<crate::Graph>,
1533        vjp_body: Option<Box<crate::Graph>>,
1534        jvp_body: Option<Box<crate::Graph>>,
1535        num_inputs: u32,
1536    },
1537}
1538
1539impl Op {
1540    /// PLAN L4: discriminant for backend-supported-set checks.
1541    /// Stable, parameter-free identity per variant — `Op::Activation(_)`
1542    /// and `Op::Activation(Relu)` share the same `OpKind::Activation`.
1543    pub fn kind(&self) -> OpKind {
1544        match self {
1545            Op::Input { .. } => OpKind::Input,
1546            Op::Param { .. } => OpKind::Param,
1547            Op::Constant { .. } => OpKind::Constant,
1548            Op::Activation(_) => OpKind::Activation,
1549            Op::Cast { .. } => OpKind::Cast,
1550            Op::Quantize { .. } => OpKind::Quantize,
1551            Op::Dequantize { .. } => OpKind::Dequantize,
1552            Op::FakeQuantize { .. } => OpKind::FakeQuantize,
1553            Op::FakeQuantizeLSQ { .. } => OpKind::FakeQuantizeLSQ,
1554            Op::FakeQuantizeLSQBackwardX { .. } => OpKind::FakeQuantizeLSQBackwardX,
1555            Op::FakeQuantizeLSQBackwardScale { .. } => OpKind::FakeQuantizeLSQBackwardScale,
1556            Op::Binary(_) => OpKind::Binary,
1557            Op::Compare(_) => OpKind::Compare,
1558            Op::Where => OpKind::Where,
1559            Op::ElementwiseRegion { .. } => OpKind::ElementwiseRegion,
1560            Op::MatMul => OpKind::MatMul,
1561            Op::DotGeneral { .. } => OpKind::DotGeneral,
1562            Op::DenseSolve => OpKind::DenseSolve,
1563            Op::BatchedDenseSolve => OpKind::BatchedDenseSolve,
1564            Op::LayerNorm { .. } => OpKind::LayerNorm,
1565            Op::LayerNorm2d { .. } => OpKind::LayerNorm2d,
1566            Op::GroupNorm { .. } => OpKind::GroupNorm,
1567            Op::RmsNorm { .. } => OpKind::RmsNorm,
1568            Op::ResizeNearest2x => OpKind::ResizeNearest2x,
1569            Op::Attention { .. } => OpKind::Attention,
1570            Op::Rope { .. } => OpKind::Rope,
1571            Op::AxialRope2d { .. } => OpKind::AxialRope2d,
1572            Op::Reshape { .. } => OpKind::Reshape,
1573            Op::Transpose { .. } => OpKind::Transpose,
1574            Op::Narrow { .. } => OpKind::Narrow,
1575            Op::Concat { .. } => OpKind::Concat,
1576            Op::Expand { .. } => OpKind::Expand,
1577            Op::Gather { .. } => OpKind::Gather,
1578            Op::Reduce { .. } => OpKind::Reduce,
1579            Op::Softmax { .. } => OpKind::Softmax,
1580            Op::Cumsum { .. } => OpKind::Cumsum,
1581            Op::TopK { .. } => OpKind::TopK,
1582            Op::Sample { .. } => OpKind::Sample,
1583            Op::Conv { .. } => OpKind::Conv,
1584            Op::ConvTranspose2d { .. } => OpKind::ConvTranspose2d,
1585            Op::Pool { .. } => OpKind::Pool,
1586            Op::ReluBackward => OpKind::ReluBackward,
1587            Op::ActivationBackward { .. } => OpKind::ActivationBackward,
1588            Op::FakeQuantizeBackward { .. } => OpKind::FakeQuantizeBackward,
1589            Op::ComplexNormSq => OpKind::ComplexNormSq,
1590            Op::ComplexNormSqBackward => OpKind::ComplexNormSqBackward,
1591            Op::Conjugate => OpKind::Conjugate,
1592            Op::LayerNormBackwardInput { .. } => OpKind::LayerNormBackwardInput,
1593            Op::LayerNormBackwardGamma { .. } => OpKind::LayerNormBackwardGamma,
1594            Op::RmsNormBackwardInput { .. } => OpKind::RmsNormBackwardInput,
1595            Op::RmsNormBackwardGamma { .. } => OpKind::RmsNormBackwardGamma,
1596            Op::RmsNormBackwardBeta { .. } => OpKind::RmsNormBackwardBeta,
1597            Op::RopeBackward { .. } => OpKind::RopeBackward,
1598            Op::GroupNormBackwardInput { .. } => OpKind::GroupNormBackwardInput,
1599            Op::GroupNormBackwardGamma { .. } => OpKind::GroupNormBackwardGamma,
1600            Op::GroupNormBackwardBeta { .. } => OpKind::GroupNormBackwardBeta,
1601            Op::CumsumBackward { .. } => OpKind::CumsumBackward,
1602            Op::GatherBackward { .. } => OpKind::GatherBackward,
1603            Op::MaxPool2dBackward { .. } => OpKind::MaxPool2dBackward,
1604            Op::Conv2dBackwardInput { .. } => OpKind::Conv2dBackwardInput,
1605            Op::Conv2dBackwardWeight { .. } => OpKind::Conv2dBackwardWeight,
1606            Op::SoftmaxCrossEntropyWithLogits => OpKind::SoftmaxCrossEntropyWithLogits,
1607            Op::SoftmaxCrossEntropyBackward => OpKind::SoftmaxCrossEntropyBackward,
1608            Op::AttentionBackward { .. } => OpKind::AttentionBackward,
1609            Op::GroupedMatMul => OpKind::GroupedMatMul,
1610            Op::DequantGroupedMatMul { .. } => OpKind::DequantGroupedMatMul,
1611            Op::DequantMoEWeights { .. } => OpKind::DequantMoEWeights,
1612            Op::ScatterAdd => OpKind::ScatterAdd,
1613            Op::LoraMatMul { .. } => OpKind::LoraMatMul,
1614            Op::DequantMatMul { .. } => OpKind::DequantMatMul,
1615            Op::QMatMul { .. } => OpKind::QMatMul,
1616            Op::QConv2d { .. } => OpKind::QConv2d,
1617            Op::SelectiveScan { .. } => OpKind::SelectiveScan,
1618            Op::GatedDeltaNet { .. } => OpKind::GatedDeltaNet,
1619            Op::FusedSwiGLU { .. } => OpKind::FusedSwiGLU,
1620            Op::FusedMatMulBiasAct { .. } => OpKind::FusedMatMulBiasAct,
1621            Op::FusedResidualLN { .. } => OpKind::FusedResidualLN,
1622            Op::FusedResidualRmsNorm { .. } => OpKind::FusedResidualRmsNorm,
1623            Op::FusedAttentionBlock { .. } => OpKind::FusedAttentionBlock,
1624            Op::FusedTransformerLayer { .. } => OpKind::FusedTransformerLayer,
1625            Op::If { .. } => OpKind::If,
1626            Op::While { .. } => OpKind::While,
1627            Op::Scan { .. } => OpKind::Scan,
1628            Op::ScanBackward { .. } => OpKind::ScanBackward,
1629            Op::ScanBackwardXs { .. } => OpKind::ScanBackwardXs,
1630            Op::GaussianSplatRender { .. } => OpKind::GaussianSplatRender,
1631            Op::GaussianSplatRenderBackward { .. } => OpKind::GaussianSplatRenderBackward,
1632            Op::GaussianSplatPrepare { .. } => OpKind::GaussianSplatPrepare,
1633            Op::GaussianSplatRasterize { .. } => OpKind::GaussianSplatRasterize,
1634            Op::Custom { .. } => OpKind::Custom,
1635            Op::CustomFn { .. } => OpKind::CustomFn,
1636            Op::Fft { .. } => OpKind::Fft,
1637        }
1638    }
1639
1640    /// True if this op is element-wise (same shape in, same shape out).
1641    /// Element-wise ops are prime fusion candidates.
1642    pub fn is_elementwise(&self) -> bool {
1643        matches!(
1644            self,
1645            Op::Activation(_)
1646                | Op::Cast { .. }
1647                | Op::Binary(_)
1648                | Op::Compare(_)
1649                | Op::Where
1650                | Op::ElementwiseRegion { .. }
1651        )
1652    }
1653
1654    /// True if this op is a BLAS/compute-intensive op that forms a fusion boundary.
1655    pub fn is_blas(&self) -> bool {
1656        matches!(
1657            self,
1658            Op::MatMul
1659                | Op::DotGeneral { .. }
1660                | Op::DenseSolve
1661                | Op::BatchedDenseSolve
1662                | Op::Conv { .. }
1663                | Op::ConvTranspose2d { .. }
1664                | Op::FusedMatMulBiasAct { .. }
1665                | Op::GroupedMatMul
1666                | Op::DequantGroupedMatMul { .. }
1667                | Op::DequantMoEWeights { .. }
1668                | Op::LoraMatMul { .. }
1669                | Op::DequantMatMul { .. }
1670                | Op::QMatMul { .. }
1671                | Op::QConv2d { .. }
1672        )
1673    }
1674
1675    /// True if element-wise fusion must not span across this op.
1676    pub fn is_fusion_boundary(&self) -> bool {
1677        self.is_blas()
1678            || matches!(
1679                self,
1680                Op::GaussianSplatRender { .. }
1681                    | Op::GaussianSplatRenderBackward { .. }
1682                    | Op::GaussianSplatPrepare { .. }
1683                    | Op::GaussianSplatRasterize { .. }
1684            )
1685    }
1686
1687    /// True if this op is a reduction (drives loop iteration in fused kernels).
1688    pub fn is_reduction(&self) -> bool {
1689        matches!(
1690            self,
1691            Op::Reduce { .. } | Op::Softmax { .. } | Op::TopK { .. }
1692        )
1693    }
1694
1695    /// Number of tensor inputs this op expects.
1696    pub fn num_inputs(&self) -> usize {
1697        match self {
1698            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0,
1699            Op::Activation(_)
1700            | Op::Cast { .. }
1701            | Op::Reshape { .. }
1702            | Op::Quantize { .. }
1703            | Op::Dequantize { .. }
1704            | Op::Transpose { .. }
1705            | Op::Narrow { .. }
1706            | Op::Expand { .. }
1707            | Op::Reduce { .. }
1708            | Op::Softmax { .. }
1709            | Op::FusedSwiGLU { .. }
1710            | Op::TopK { .. }
1711            | Op::Cumsum { .. }
1712            | Op::Sample { .. }
1713            | Op::ResizeNearest2x => 1,
1714            // EMA / Fixed scale modes carry a state tensor as a 2nd input;
1715            // PerBatch (default) doesn't need one.
1716            Op::FakeQuantize { scale_mode, .. } => match scale_mode {
1717                ScaleMode::PerBatch => 1,
1718                ScaleMode::EMA { .. } | ScaleMode::Fixed => 2,
1719            },
1720            Op::FakeQuantizeLSQ { .. } => 2, // x, scale (learned param)
1721            Op::FakeQuantizeLSQBackwardX { .. } | Op::FakeQuantizeLSQBackwardScale { .. } => 3, // x, scale, dy
1722            Op::Binary(_) | Op::Compare(_) | Op::Gather { .. } | Op::MatMul | Op::ScatterAdd => 2,
1723            Op::GroupedMatMul => 3,               // input, weight, expert_idx
1724            Op::DequantGroupedMatMul { .. } => 3, // input, packed_w, expert_idx
1725            Op::DequantMoEWeights { .. } => 1,    // packed_w
1726            Op::LoraMatMul { .. } => 4,           // x, w, a, b
1727            // x, w_q, scale, zp — or x, packed_w_bytes for GGUF
1728            // schemes (their scales/mins live inside the packed bytes,
1729            // see `QuantScheme::is_gguf`).
1730            Op::DequantMatMul { scheme } => {
1731                if scheme.is_gguf() {
1732                    2
1733                } else {
1734                    4
1735                }
1736            }
1737            Op::QMatMul { .. } => 3,       // x, w, bias
1738            Op::QConv2d { .. } => 3,       // x, w, bias
1739            Op::SelectiveScan { .. } => 5, // x, delta, a, b, c
1740            Op::GatedDeltaNet { carry_state, .. } if *carry_state => 6, // + state in/out
1741            Op::GatedDeltaNet { .. } => 5, // q, k, v, g, beta
1742            Op::Where => 3,                // cond, on_true, on_false
1743            Op::Attention { mask_kind, .. } => match mask_kind {
1744                MaskKind::Custom | MaskKind::Bias => 4, // Q, K, V, mask
1745                _ => 3,                                 // Q, K, V (mask synthesized in-kernel)
1746            },
1747            Op::AttentionBackward { mask_kind, .. } => match mask_kind {
1748                MaskKind::Custom | MaskKind::Bias => 5, // q, k, v, dy, mask
1749                _ => 4,                                 // q, k, v, dy
1750            },
1751            Op::Rope { .. } => 3, // x, cos, sin
1752            Op::AxialRope2d { .. } => 1,
1753            Op::LayerNorm { .. }
1754            | Op::LayerNorm2d { .. }
1755            | Op::GroupNorm { .. }
1756            | Op::RmsNorm { .. } => 3, // input, gamma, beta
1757            Op::FusedMatMulBiasAct { .. } => 3, // input, weight, bias
1758            Op::FusedResidualLN { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
1759            Op::FusedResidualLN {
1760                has_bias: false, ..
1761            } => 4, // x, residual, gamma, beta
1762            Op::FusedResidualRmsNorm { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
1763            Op::FusedResidualRmsNorm {
1764                has_bias: false, ..
1765            } => 4, // x, residual, gamma, beta
1766            Op::Conv { .. } | Op::ConvTranspose2d { .. } => 2, // input, weight (bias via Add)
1767            Op::Pool { .. } => 1,
1768            Op::ReluBackward => 2,                  // x, dy
1769            Op::ActivationBackward { .. } => 2,     // x, dy
1770            Op::FakeQuantizeBackward { .. } => 2,   // x, dy
1771            Op::ComplexNormSq => 1,                 // z (C64)
1772            Op::ComplexNormSqBackward => 2,         // z, g
1773            Op::Conjugate => 1,                     // z (C64)
1774            Op::LayerNormBackwardInput { .. } => 3, // x, gamma, dy
1775            Op::LayerNormBackwardGamma { .. } => 2, // x, dy
1776            Op::RmsNormBackwardInput { .. } => 4,   // x, gamma, beta, dy
1777            Op::RmsNormBackwardGamma { .. } => 4,
1778            Op::RmsNormBackwardBeta { .. } => 4,
1779            Op::RopeBackward { .. } => 3,           // dy, cos, sin
1780            Op::GroupNormBackwardInput { .. } => 4, // x, gamma, beta, dy
1781            Op::GroupNormBackwardGamma { .. } => 2, // x, dy
1782            Op::GroupNormBackwardBeta { .. } => 2,
1783            Op::CumsumBackward { .. } => 1,         // dy
1784            Op::GatherBackward { .. } => 2,         // dy, indices
1785            Op::MaxPool2dBackward { .. } => 2,      // x, dy
1786            Op::Conv2dBackwardInput { .. } => 2,    // dy, w
1787            Op::Conv2dBackwardWeight { .. } => 2,   // x, dy
1788            Op::SoftmaxCrossEntropyWithLogits => 2, // logits, labels
1789            Op::SoftmaxCrossEntropyBackward => 3,   // logits, labels, d_loss
1790            Op::Concat { .. } => 0,                 // variadic — checked at graph level
1791            Op::DotGeneral { .. } => 2,
1792            Op::DenseSolve => 2,        // A, b
1793            Op::BatchedDenseSolve => 2, // A [B,N,N], b [B,N] or [B,N,K]
1794            Op::FusedAttentionBlock {
1795                has_bias, has_rope, ..
1796            } => 4 + if *has_bias { 2 } else { 0 } + if *has_rope { 2 } else { 0 },
1797            Op::If { .. } => 1,    // predicate (captures handled separately)
1798            Op::While { .. } => 0, // variadic loop-carried; checked at graph level
1799            Op::Scan {
1800                num_bcast, num_xs, ..
1801            } => 1 + *num_bcast as usize + *num_xs as usize,
1802            Op::ScanBackward { num_xs, .. } => 3 + *num_xs as usize, // init, trajectory, upstream, xs_0..
1803            Op::ScanBackwardXs { num_xs, .. } => 3 + *num_xs as usize, // same as ScanBackward
1804            Op::GaussianSplatRender { .. } => 7,
1805            Op::GaussianSplatRenderBackward { .. } => 8,
1806            Op::GaussianSplatPrepare { .. } => 7,
1807            Op::GaussianSplatRasterize { .. } => 2,
1808            Op::FusedTransformerLayer { has_bias, .. } => {
1809                // hidden + qkv_w + out_w + ln1_g + ln1_b + fc1_w + fc2_w + ln2_g + ln2_b + mask = 10
1810                // bias variant adds: qkv_b + out_b + fc1_b + fc2_b = 4 more
1811                10 + if *has_bias { 4 } else { 0 }
1812            }
1813            Op::ElementwiseRegion { num_inputs, .. } => *num_inputs as usize,
1814            Op::Custom { num_inputs, .. } => *num_inputs as usize,
1815            Op::CustomFn { num_inputs, .. } => *num_inputs as usize,
1816            Op::Fft { .. } => 1,
1817        }
1818    }
1819}
1820
1821impl std::fmt::Display for Op {
1822    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1823        match self {
1824            Op::Input { name } => write!(f, "input(\"{name}\")"),
1825            Op::Param { name } => write!(f, "param(\"{name}\")"),
1826            Op::Constant { data } => write!(f, "const({}B)", data.len()),
1827            Op::Activation(a) => write!(f, "{a:?}"),
1828            Op::Quantize { axis, scales, .. } => match axis {
1829                None => write!(f, "quantize(s={})", scales[0]),
1830                Some(d) => write!(f, "quantize(axis={d},nch={})", scales.len()),
1831            },
1832            Op::Dequantize { axis, scales, .. } => match axis {
1833                None => write!(f, "dequantize(s={})", scales[0]),
1834                Some(d) => write!(f, "dequantize(axis={d},nch={})", scales.len()),
1835            },
1836            Op::FakeQuantize {
1837                bits,
1838                axis,
1839                ste,
1840                scale_mode,
1841            } => match axis {
1842                None => write!(
1843                    f,
1844                    "fake_quant(bits={bits},ste={ste:?},scale={scale_mode:?})"
1845                ),
1846                Some(d) => write!(
1847                    f,
1848                    "fake_quant(bits={bits},axis={d},ste={ste:?},scale={scale_mode:?})"
1849                ),
1850            },
1851            Op::FakeQuantizeLSQ { bits, axis } => match axis {
1852                None => write!(f, "fake_quant_lsq(bits={bits})"),
1853                Some(d) => write!(f, "fake_quant_lsq(bits={bits},axis={d})"),
1854            },
1855            Op::FakeQuantizeLSQBackwardX { bits, .. } => {
1856                write!(f, "fake_quant_lsq_bwd_x(bits={bits})")
1857            }
1858            Op::FakeQuantizeLSQBackwardScale { bits, .. } => {
1859                write!(f, "fake_quant_lsq_bwd_s(bits={bits})")
1860            }
1861            Op::Cast { to } => write!(f, "cast({to})"),
1862            Op::Binary(op) => write!(f, "{op:?}"),
1863            Op::Compare(op) => write!(f, "{op:?}"),
1864            Op::Where => write!(f, "where"),
1865            Op::MatMul => write!(f, "matmul"),
1866            Op::DotGeneral { .. } => write!(f, "dot_general"),
1867            Op::DenseSolve => write!(f, "dense_solve"),
1868            Op::BatchedDenseSolve => write!(f, "batched_dense_solve"),
1869            Op::LayerNorm { eps, .. } => write!(f, "layer_norm(eps={eps})"),
1870            Op::GroupNorm { num_groups, eps } => {
1871                write!(f, "group_norm(groups={num_groups},eps={eps})")
1872            }
1873            Op::ResizeNearest2x => write!(f, "resize_nearest_2x"),
1874            Op::RmsNorm { eps, .. } => write!(f, "rms_norm(eps={eps})"),
1875            Op::Attention {
1876                num_heads,
1877                head_dim,
1878                mask_kind,
1879                score_scale,
1880                attn_logit_softcap,
1881            } => {
1882                let mut s = match mask_kind {
1883                    MaskKind::Custom => format!("attention(h={num_heads},d={head_dim})"),
1884                    MaskKind::None => format!("attention(h={num_heads},d={head_dim},nomask)"),
1885                    MaskKind::Causal => format!("attention(h={num_heads},d={head_dim},causal)"),
1886                    MaskKind::SlidingWindow(w) => {
1887                        format!("attention(h={num_heads},d={head_dim},sw={w})")
1888                    }
1889                    MaskKind::Bias => format!("attention(h={num_heads},d={head_dim},bias)"),
1890                };
1891                if let Some(sc) = score_scale {
1892                    s.push_str(&format!(",scale={sc}"));
1893                }
1894                if let Some(cap) = attn_logit_softcap {
1895                    s.push_str(&format!(",softcap={cap}"));
1896                }
1897                write!(f, "{s}")
1898            }
1899            Op::Rope { head_dim, n_rot } => write!(f, "rope(d={head_dim}, n_rot={n_rot})"),
1900            Op::AxialRope2d {
1901                end_x,
1902                end_y,
1903                head_dim,
1904                num_heads,
1905                theta,
1906                repeat_factor,
1907            } => write!(
1908                f,
1909                "axial_rope2d({end_x}x{end_y},h={num_heads},d={head_dim},θ={theta},r={repeat_factor})"
1910            ),
1911            Op::Reshape { new_shape } => write!(f, "reshape({new_shape:?})"),
1912            Op::Transpose { perm } => write!(f, "transpose({perm:?})"),
1913            Op::Narrow { axis, start, len } => write!(f, "narrow({axis},{start},{len})"),
1914            Op::Concat { axis } => write!(f, "concat(axis={axis})"),
1915            Op::Expand { .. } => write!(f, "expand"),
1916            Op::Gather { axis } => write!(f, "gather(axis={axis})"),
1917            Op::Reduce { op, axes, .. } => write!(f, "reduce_{op:?}({axes:?})"),
1918            Op::Softmax { axis } => write!(f, "softmax(axis={axis})"),
1919            Op::Cumsum { axis, exclusive } => {
1920                if *exclusive {
1921                    write!(f, "cumsum(axis={axis},excl)")
1922                } else {
1923                    write!(f, "cumsum(axis={axis})")
1924                }
1925            }
1926            Op::Sample {
1927                top_k,
1928                top_p,
1929                temperature,
1930                ..
1931            } => {
1932                write!(f, "sample(t={temperature}")?;
1933                if *top_k > 0 {
1934                    write!(f, ",k={top_k}")?;
1935                }
1936                if *top_p < 1.0 {
1937                    write!(f, ",p={top_p}")?;
1938                }
1939                write!(f, ")")
1940            }
1941            Op::TopK { k } => write!(f, "topk(k={k})"),
1942            Op::GroupedMatMul => write!(f, "grouped_matmul"),
1943            Op::DequantGroupedMatMul { scheme } => {
1944                write!(f, "dequant_grouped_matmul({scheme})")
1945            }
1946            Op::DequantMoEWeights { scheme } => write!(f, "dequant_moe_weights({scheme})"),
1947            Op::LoraMatMul { scale } => write!(f, "lora_matmul(scale={scale})"),
1948            Op::DequantMatMul { scheme } => write!(f, "dequant_matmul({scheme})"),
1949            Op::QMatMul {
1950                x_zp,
1951                w_zp,
1952                out_zp,
1953                mult,
1954            } => write!(
1955                f,
1956                "q_matmul(x_zp={x_zp},w_zp={w_zp},out_zp={out_zp},mult={mult})"
1957            ),
1958            Op::QConv2d { kernel_size, .. } => write!(f, "q_conv2d({kernel_size:?})"),
1959            Op::SelectiveScan { state_size } => write!(f, "ssm_scan(n={state_size})"),
1960            Op::GatedDeltaNet {
1961                state_size,
1962                carry_state,
1963            } => {
1964                if *carry_state {
1965                    write!(f, "gated_delta_net(n={state_size},carry)")
1966                } else {
1967                    write!(f, "gated_delta_net(n={state_size})")
1968                }
1969            }
1970            Op::ScatterAdd => write!(f, "scatter_add"),
1971            Op::Conv { kernel_size, .. } => write!(f, "conv2d({kernel_size:?})"),
1972            Op::ConvTranspose2d { kernel_size, .. } => {
1973                write!(f, "conv_transpose2d({kernel_size:?})")
1974            }
1975            Op::LayerNorm2d { eps } => write!(f, "layer_norm2d(eps={eps})"),
1976            Op::Pool {
1977                kind, kernel_size, ..
1978            } => write!(f, "pool_{kind:?}({kernel_size:?})"),
1979            Op::ReluBackward => write!(f, "relu_backward"),
1980            Op::ActivationBackward { kind } => write!(f, "{kind:?}_backward"),
1981            Op::ComplexNormSq => write!(f, "complex_norm_sq"),
1982            Op::ComplexNormSqBackward => write!(f, "complex_norm_sq_backward"),
1983            Op::Conjugate => write!(f, "conjugate"),
1984            Op::FakeQuantizeBackward { bits, ste, .. } => {
1985                write!(f, "fake_quant_backward(bits={bits},ste={ste:?})")
1986            }
1987            Op::MaxPool2dBackward { kernel_size, .. } => {
1988                write!(f, "maxpool2d_backward({kernel_size:?})")
1989            }
1990            Op::Conv2dBackwardInput { kernel_size, .. } => {
1991                write!(f, "conv2d_backward_input({kernel_size:?})")
1992            }
1993            Op::Conv2dBackwardWeight { kernel_size, .. } => {
1994                write!(f, "conv2d_backward_weight({kernel_size:?})")
1995            }
1996            Op::SoftmaxCrossEntropyWithLogits => write!(f, "sce_with_logits"),
1997            Op::SoftmaxCrossEntropyBackward => write!(f, "sce_backward"),
1998            Op::AttentionBackward {
1999                num_heads,
2000                head_dim,
2001                mask_kind,
2002                wrt,
2003            } => match mask_kind {
2004                MaskKind::None => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},nomask)"),
2005                MaskKind::Causal => {
2006                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},causal)")
2007                }
2008                MaskKind::SlidingWindow(w) => {
2009                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},sw={w})")
2010                }
2011                MaskKind::Custom => {
2012                    write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},custom)")
2013                }
2014                MaskKind::Bias => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},bias)"),
2015            },
2016            Op::FusedMatMulBiasAct { activation } => {
2017                write!(f, "fused_mm_bias")?;
2018                if let Some(a) = activation {
2019                    write!(f, "_{a:?}")?;
2020                }
2021                Ok(())
2022            }
2023            Op::FusedResidualLN { has_bias, eps } => {
2024                write!(f, "fused_residual")?;
2025                if *has_bias {
2026                    write!(f, "_bias")?;
2027                }
2028                write!(f, "_ln(eps={eps})")
2029            }
2030            Op::FusedResidualRmsNorm { has_bias, eps } => {
2031                write!(f, "fused_residual")?;
2032                if *has_bias {
2033                    write!(f, "_bias")?;
2034                }
2035                write!(f, "_rms(eps={eps})")
2036            }
2037            Op::FusedSwiGLU {
2038                cast_to,
2039                gate_first,
2040            } => {
2041                let mut s = match cast_to {
2042                    Some(dt) => format!("fused_swiglu(cast={dt}"),
2043                    None => "fused_swiglu(".to_string(),
2044                };
2045                if *gate_first {
2046                    s.push_str(",gate_first");
2047                }
2048                s.push(')');
2049                write!(f, "{s}")
2050            }
2051            Op::FusedAttentionBlock {
2052                num_heads,
2053                head_dim,
2054                has_bias,
2055                has_rope,
2056            } => {
2057                write!(f, "fused_attn(h={num_heads},d={head_dim}")?;
2058                if *has_bias {
2059                    write!(f, ",bias")?;
2060                }
2061                if *has_rope {
2062                    write!(f, ",rope")?;
2063                }
2064                write!(f, ")")
2065            }
2066            Op::If { .. } => write!(f, "if(...)"),
2067            Op::While { max_iterations, .. } => match max_iterations {
2068                Some(n) => write!(f, "while(...max={n})"),
2069                None => write!(f, "while(...)"),
2070            },
2071            Op::Scan {
2072                length,
2073                save_trajectory,
2074                num_xs,
2075                ..
2076            } => {
2077                let traj = if *save_trajectory { ",traj" } else { "" };
2078                let xs = if *num_xs > 0 {
2079                    format!(",xs={}", num_xs)
2080                } else {
2081                    String::new()
2082                };
2083                write!(f, "scan(len={length}{xs}{traj})")
2084            }
2085            Op::ScanBackward {
2086                length,
2087                save_trajectory,
2088                num_xs,
2089                ..
2090            } => {
2091                let traj = if *save_trajectory { ",traj" } else { "" };
2092                let xs = if *num_xs > 0 {
2093                    format!(",xs={}", num_xs)
2094                } else {
2095                    String::new()
2096                };
2097                write!(f, "scan_bwd(len={length}{xs}{traj})")
2098            }
2099            Op::ScanBackwardXs {
2100                length,
2101                save_trajectory,
2102                num_xs,
2103                xs_idx,
2104                ..
2105            } => {
2106                let traj = if *save_trajectory { ",traj" } else { "" };
2107                write!(
2108                    f,
2109                    "scan_bwd_xs(len={length},xs={num_xs},idx={xs_idx}{traj})"
2110                )
2111            }
2112            Op::FusedTransformerLayer {
2113                num_heads,
2114                head_dim,
2115                intermediate_size,
2116                has_bias,
2117                ..
2118            } => {
2119                write!(
2120                    f,
2121                    "fused_layer(h={num_heads},d={head_dim},int={intermediate_size}"
2122                )?;
2123                if *has_bias {
2124                    write!(f, ",bias")?;
2125                }
2126                write!(f, ")")
2127            }
2128            Op::ElementwiseRegion {
2129                chain,
2130                num_inputs,
2131                scalar_input_mask,
2132                input_modulus: _,
2133            } => {
2134                if *scalar_input_mask != 0 {
2135                    write!(
2136                        f,
2137                        "ew_region(in={num_inputs},steps={},scalar_mask=0x{:x})",
2138                        chain.len(),
2139                        scalar_input_mask
2140                    )
2141                } else {
2142                    write!(f, "ew_region(in={num_inputs},steps={})", chain.len())
2143                }
2144            }
2145            Op::LayerNormBackwardInput { eps, .. } => {
2146                write!(f, "layer_norm_backward_input(eps={eps})")
2147            }
2148            Op::LayerNormBackwardGamma { eps, .. } => {
2149                write!(f, "layer_norm_backward_gamma(eps={eps})")
2150            }
2151            Op::RmsNormBackwardInput { eps, .. } => write!(f, "rms_norm_backward_input(eps={eps})"),
2152            Op::RmsNormBackwardGamma { eps, .. } => write!(f, "rms_norm_backward_gamma(eps={eps})"),
2153            Op::RmsNormBackwardBeta { eps, .. } => write!(f, "rms_norm_backward_beta(eps={eps})"),
2154            Op::RopeBackward { head_dim, n_rot } => {
2155                write!(f, "rope_backward(d={head_dim},n_rot={n_rot})")
2156            }
2157            Op::GroupNormBackwardInput { num_groups, eps } => {
2158                write!(f, "group_norm_backward_input(g={num_groups},eps={eps})")
2159            }
2160            Op::GroupNormBackwardGamma { num_groups, eps } => {
2161                write!(f, "group_norm_backward_gamma(g={num_groups},eps={eps})")
2162            }
2163            Op::GroupNormBackwardBeta { num_groups, eps } => {
2164                write!(f, "group_norm_backward_beta(g={num_groups},eps={eps})")
2165            }
2166            Op::CumsumBackward { axis, exclusive } => {
2167                write!(f, "cumsum_backward(axis={axis},exclusive={exclusive})")
2168            }
2169            Op::GatherBackward { axis } => write!(f, "gather_backward(axis={axis})"),
2170            Op::GaussianSplatRender {
2171                width,
2172                height,
2173                tile_size,
2174                radius_scale,
2175                alpha_cutoff,
2176                max_splat_steps,
2177                transmittance_threshold,
2178                max_list_entries,
2179            } => write!(
2180                f,
2181                "gaussian_splat_render({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2182            ),
2183            Op::GaussianSplatRenderBackward {
2184                width,
2185                height,
2186                loss_grad_clip,
2187                sh_band,
2188                ..
2189            } => write!(
2190                f,
2191                "gaussian_splat_render_bwd({width}x{height},clip={loss_grad_clip},sh={sh_band})"
2192            ),
2193            Op::GaussianSplatPrepare {
2194                width,
2195                height,
2196                tile_size,
2197                radius_scale,
2198                alpha_cutoff,
2199                max_splat_steps,
2200                transmittance_threshold,
2201                max_list_entries,
2202                ..
2203            } => write!(
2204                f,
2205                "gaussian_splat_prepare({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2206            ),
2207            Op::GaussianSplatRasterize {
2208                width,
2209                height,
2210                tile_size,
2211                alpha_cutoff,
2212                max_splat_steps,
2213                transmittance_threshold,
2214                max_list_entries,
2215                ..
2216            } => write!(
2217                f,
2218                "gaussian_splat_rasterize({width}x{height},tile={tile_size},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2219            ),
2220            Op::Custom {
2221                name,
2222                num_inputs,
2223                attrs,
2224            } => write!(f, "custom({name},in={num_inputs},attrs={}B)", attrs.len()),
2225            Op::CustomFn {
2226                num_inputs,
2227                vjp_body,
2228                jvp_body,
2229                ..
2230            } => {
2231                let v = if vjp_body.is_some() { ",vjp" } else { "" };
2232                let j = if jvp_body.is_some() { ",jvp" } else { "" };
2233                write!(f, "custom_fn(in={num_inputs}{v}{j})")
2234            }
2235            Op::Fft { inverse } => write!(f, "fft(inverse={inverse})"),
2236        }
2237    }
2238}