rlx_ir/op.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Operation types — every tensor op in the RLX IR.
17//!
18//! Designed for pattern-matching fusion: ops are grouped by category so
19//! fusion passes can reason about them structurally.
20
21use crate::DType;
22
/// Unary element-wise activation functions.
///
/// Each variant maps one input to an output of the same shape (see
/// `Op::Activation`); the same variants also appear inside fused
/// element-wise chains via `ChainStep::Activation`.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
    /// GELU.
    Gelu,
    /// Approximate GELU (presumably the tanh approximation — TODO confirm
    /// against the backend kernels).
    GeluApprox,
    /// SiLU, `x · sigmoid(x)` — the SwiGLU gate activation.
    Silu,
    /// `max(x, 0)`.
    Relu,
    /// `1 / (1 + exp(-x))`.
    Sigmoid,
    /// `tanh(x)`.
    Tanh,
    /// `exp(x)`.
    Exp,
    /// `log(x)`.
    Log,
    /// `sqrt(x)`.
    Sqrt,
    /// `1 / sqrt(x)`.
    Rsqrt,
    /// `-x`.
    Neg,
    /// `|x|`.
    Abs,
    /// `sin(x)`. Backward: `dx = upstream · cos(x)`.
    Sin,
    /// `cos(x)`. Backward: `dx = -upstream · sin(x)`.
    Cos,
    /// `tan(x)`. Backward: `dx = upstream · sec²(x) = upstream · (1 + tan²(x))`.
    Tan,
    /// `atan(x)`. Backward: `dx = upstream · (1 / (1 + x²))`.
    Atan,
    /// Round to the nearest integer, in f32.
    ///
    /// Forward: `x.round()`. Rust's `f32::round` breaks ties **away from
    /// zero** (`2.5 → 3.0`, `-2.5 → -3.0`); it is *not* half-to-even
    /// (that would be `f32::round_ties_even`).
    /// NOTE(review): this was previously documented as half-to-even —
    /// confirm every backend kernel uses the same tie rule as the
    /// `x.round()` reference.
    ///
    /// Backward: STE — treated as identity, so the gradient passes
    /// through unchanged. Useful as a primitive for composing custom
    /// quantization schemes (Mul-by-recip-scale → Round → Clamp →
    /// Mul-by-scale = a hand-rolled FakeQuantize that the
    /// elementwise-region pass can fuse into a single kernel).
    Round,
}
55
/// Scale-tracking strategy for `Op::FakeQuantize`. Determines how
/// the per-channel `s[c]` is computed each forward pass.
///
/// * `PerBatch` — recompute `s[c] = max(|x|) / q_max` from the
///   current data on every call. Simple, no extra inputs, but
///   noisy for activations (max-abs jumps batch-to-batch).
///
/// * `EMA { decay }` — keep a running `s[c]` in a state tensor
///   (passed as a second op input). On each call, blend the
///   current per-batch max-abs into the state via
///   `state' = decay·state + (1-decay)·max_abs`. Smooth scale
///   over training, makes activation-QAT actually trainable.
///   Typical `decay = 0.99`.
///
/// * `Fixed` — never recompute. The state tensor's value is
///   used as-is each call (set once at construction or by the
///   caller). Useful when scales are pre-calibrated.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Default)]
pub enum ScaleMode {
    /// Recompute the scale from the current batch (the default).
    #[default]
    PerBatch,
    /// Exponential moving average of the per-batch max-abs.
    EMA {
        /// Weight given to the previous state in the blend.
        decay: f32,
    },
    /// Use the state tensor's scale unchanged.
    Fixed,
}

// `PartialEq` is derived, so `decay` is compared with f32 `==`. The `Eq`
// marker is therefore only lawful for non-NaN decays (NaN != NaN breaks
// reflexivity). A NaN decay is not a meaningful configuration.
impl Eq for ScaleMode {}

impl std::hash::Hash for ScaleMode {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            ScaleMode::PerBatch => state.write_u8(0),
            ScaleMode::EMA { decay } => {
                state.write_u8(1);
                // Hash must agree with the derived `PartialEq` (f32 `==`):
                // `0.0 == -0.0` although their bit patterns differ, so fold
                // signed zero onto `+0.0` before hashing. Without this, two
                // equal values would hash differently and break the
                // `Hash`/`Eq` contract (e.g. for hash-consing / CSE maps).
                let bits = if *decay == 0.0 { 0 } else { decay.to_bits() };
                state.write_u32(bits);
            }
            ScaleMode::Fixed => state.write_u8(2),
        }
    }
}
97
/// Straight-through estimator (STE) variants used by the backward pass of
/// `Op::FakeQuantize`.
///
/// The forward pass never depends on this choice: it is always the
/// discrete `clamp(round(x/s)) * s`. Only the gradient w.r.t. `x`
/// produced during training differs between variants.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Default, Debug, Hash, Clone, Copy, Eq, PartialEq)]
pub enum SteKind {
    /// `dx = upstream` — the original STE, which treats the rounding as
    /// identity in the backward direction. Simplest option; fine for
    /// moderate bit widths (i4 / i8). This is the default.
    #[default]
    Identity,
    /// `dx = upstream * (|x| ≤ q_max·s)` — zeroes the gradient wherever
    /// the input fell outside the quantization range (i.e. where the clamp
    /// was active), so the optimizer stops pushing weights further into
    /// saturation.
    ClippedIdentity,
    /// `dx = upstream * (1 - tanh²(x/s))` — a smooth surrogate for the
    /// rounding step that gradually attenuates the gradient as `|x|`
    /// approaches `q_max·s`. Often the best choice at tight bit widths (i2).
    Tanh,
    /// `dx = upstream * (1 - |x/(q_max·s)|).max(0)` — a piecewise-linear
    /// relative of `Tanh`: cheaper to evaluate and nearly as effective.
    HardTanh,
}
128
/// Element-wise binary operations: two operands in, one result per element.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum BinaryOp {
    /// Element-wise addition.
    Add,
    /// Element-wise subtraction.
    Sub,
    /// Element-wise multiplication.
    Mul,
    /// Element-wise division.
    Div,
    /// Element-wise maximum of the two operands.
    Max,
    /// Element-wise minimum of the two operands.
    Min,
    /// Element-wise power.
    Pow,
}
141
/// Element-wise comparison operations; the result is a Bool tensor.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum CmpOp {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
153
/// Selects the attention mask the kernel applies.
///
/// Pattern borrowed from MAX's `nn/attention/mha_mask.mojo` (#20 in
/// PLAN.md): a single attention kernel serves every variant by branching
/// on the mask kind, so callers are not forced to materialize a mask
/// tensor. Two cases benefit most:
/// 1. **`None`** — one unpadded sequence: the inner loop does no mask load
///    and no per-key comparison.
/// 2. **`Causal`** — autoregressive decode: the kernel derives the
///    upper-triangular fill directly from `(qi, ki)`, so no `seq²` mask
///    tensor ever exists.
///
/// `Custom` keeps the pre-existing behavior of reading mask values from
/// the 4th input.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum MaskKind {
    /// No masking: every query position attends to every key position.
    None,
    /// Causal (autoregressive): query `qi` attends only to keys `ki <= qi`.
    Causal,
    /// Sliding window with look-back `w` (the payload): query `qi` attends
    /// to keys `ki ∈ [qi - w, qi]`.
    SlidingWindow(usize),
    /// Mask values are read from the input tensor (default; matches BERT
    /// padding-mask behavior). Shape `[batch, key_len]`; `1.0` marks a
    /// valid key and values `<0.5` mark ignored keys.
    Custom,
    /// Additive per-head, per-query bias tensor of shape
    /// `[batch, num_heads, query_len, key_len]`, added to the
    /// `QK^T · scale` scores before the softmax. This lets DETR-style
    /// boxRPB and other learned position biases stay on the fast
    /// `Op::Attention` path instead of decomposing into
    /// matmul + add + softmax + matmul.
    Bias,
}
187
/// Selects the forward attention input whose gradient an
/// [`Op::AttentionBackward`] node computes.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum AttentionBwdWrt {
    /// Gradient w.r.t. the query input.
    Query,
    /// Gradient w.r.t. the key input.
    Key,
    /// Gradient w.r.t. the value input.
    Value,
}
196
/// Reductions applied along specified axes. Also used as the pooling kind
/// of `Op::Pool`.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum ReduceOp {
    /// Sum of the reduced elements.
    Sum,
    /// Arithmetic mean of the reduced elements.
    Mean,
    /// Maximum of the reduced elements.
    Max,
    /// Minimum of the reduced elements.
    Min,
    /// Product of the reduced elements.
    Prod,
}
207
/// Fieldless discriminant of an [`Op`] variant (PLAN L4).
///
/// Produced by [`Op::kind`] and used by the `Backend::supported_ops` trait
/// method, through which a backend declares which ops it can lower. The
/// `LegalizeForBackend` pass in `rlx-opt` checks the graph against that set
/// and fails the compile when an unsupported op is present, rather than
/// silently falling back.
///
/// Variant order is significant only for the implicit discriminant values;
/// the grouping comments below are purely organizational.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum OpKind {
    // ── Graph inputs ────────────────────────────────────────────
    Input,
    Param,
    Constant,

    // ── Element-wise unary, cast and quantization ───────────────
    Activation,
    Cast,
    Quantize,
    Dequantize,
    FakeQuantize,
    FakeQuantizeLSQ,
    FakeQuantizeLSQBackwardX,
    FakeQuantizeLSQBackwardScale,

    // ── Element-wise binary / compare / select / fused region ───
    Binary,
    Compare,
    Where,
    ElementwiseRegion,

    // ── Linear algebra ──────────────────────────────────────────
    MatMul,
    DotGeneral,
    DenseSolve,
    BatchedDenseSolve,

    // ── Normalization and resizing ──────────────────────────────
    LayerNorm,
    LayerNorm2d,
    GroupNorm,
    RmsNorm,
    ResizeNearest2x,

    // ── Attention and positional embeddings ─────────────────────
    Attention,
    Rope,
    AxialRope2d,

    // ── Shape manipulation ──────────────────────────────────────
    Reshape,
    Transpose,
    Narrow,
    Concat,
    Expand,
    Gather,

    // ── Reductions, scans and sampling ──────────────────────────
    Reduce,
    Softmax,
    Cumsum,
    TopK,
    Sample,

    // ── Convolution and pooling ─────────────────────────────────
    Conv,
    ConvTranspose2d,
    Pool,

    // ── Backward / training ops (plus complex-number helpers) ───
    ReluBackward,
    ActivationBackward,
    FakeQuantizeBackward,
    ComplexNormSq,
    ComplexNormSqBackward,
    Conjugate,
    MaxPool2dBackward,
    Conv2dBackwardInput,
    Conv2dBackwardWeight,
    SoftmaxCrossEntropyWithLogits,
    SoftmaxCrossEntropyBackward,
    AttentionBackward,
    LayerNormBackwardInput,
    LayerNormBackwardGamma,
    RmsNormBackwardInput,
    RmsNormBackwardGamma,
    RmsNormBackwardBeta,
    RopeBackward,
    GroupNormBackwardInput,
    GroupNormBackwardGamma,
    GroupNormBackwardBeta,
    CumsumBackward,
    GatherBackward,

    // ── MoE, quantized and adapter matmuls ──────────────────────
    GroupedMatMul,
    DequantGroupedMatMul,
    DequantMoEWeights,
    ScatterAdd,
    LoraMatMul,
    DequantMatMul,
    QMatMul,
    QConv2d,

    // ── Recurrent / state-space ─────────────────────────────────
    SelectiveScan,
    GatedDeltaNet,

    // ── Fused blocks ────────────────────────────────────────────
    FusedSwiGLU,
    FusedMatMulBiasAct,
    FusedResidualLN,
    FusedResidualRmsNorm,
    FusedAttentionBlock,
    FusedTransformerLayer,

    // ── Control flow ────────────────────────────────────────────
    If,
    While,
    Scan,
    ScanBackward,
    ScanBackwardXs,

    // ── Gaussian splatting ──────────────────────────────────────
    /// CPU reference 3D Gaussian splat raster
    /// (project → bin → sort → raster). See [`Op::GaussianSplatRender`].
    GaussianSplatRender,
    /// Backward of [`Op::GaussianSplatRender`]: packed scene-parameter
    /// gradients.
    GaussianSplatRenderBackward,
    /// Strict IR splat stage 1: project + tile bin + sort + ray grid.
    GaussianSplatPrepare,
    /// Strict IR splat stage 2: per-pixel raster from the prepared buffers.
    GaussianSplatRasterize,

    // ── User-defined ops and FFT ────────────────────────────────
    /// User-registered op dispatched through `op_registry`. Every custom
    /// op (Sparse-LU, FFT, eigensolve, ...) shares this kind; the
    /// per-op identity lives in `Op::Custom::name`.
    Custom,
    /// User-defined sub-graph with optional override AD rules. See
    /// [`Op::CustomFn`] / [`crate::Graph::custom_fn`].
    CustomFn,
    /// 1D FFT primitive (forward or inverse); see [`Op::Fft`].
    Fft,
}
320
/// Operand of a fused [`ChainStep`].
///
/// Refers either to a graph-level input of the enclosing
/// [`Op::ElementwiseRegion`] or to the result of an earlier step in the
/// same chain.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum ChainOperand {
    /// Region input, indexed `0..num_inputs`.
    Input(u32),
    /// Result of a preceding step, indexed `0..step_position`.
    Step(u32),
}
330
331/// One step in a fused element-wise chain. Each step produces exactly
332/// one scalar result (per element); later steps can refer to it via
333/// [`ChainOperand::Step`]. The whole chain runs per element in registers.
334#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
335#[derive(Debug, Clone, PartialEq)]
336pub enum ChainStep {
337 Activation(Activation, ChainOperand),
338 Cast(DType, ChainOperand),
339 Binary(BinaryOp, ChainOperand, ChainOperand),
340 Compare(CmpOp, ChainOperand, ChainOperand),
341 /// 3-input element-wise select: `cond ? on_true : on_false`. Mirrors
342 /// `Op::Where` inside a chain. `cond` is treated as truthy iff
343 /// non-zero. Lets the optimizer fold attention masks / clamp-style
344 /// patterns into a single region kernel instead of breaking the
345 /// chain at the first `Op::Where`.
346 Where(ChainOperand, ChainOperand, ChainOperand),
347}
348
349/// An operation in the RLX IR graph.
350///
351/// Operations are categorized for fusion analysis:
352/// - Element-wise ops fuse with anything reading their output
353/// - Matmul/Conv are BLAS-dispatched and form fusion boundaries
354/// - Reductions are fusion roots (drive the loop iteration)
355#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
356#[derive(Debug, Clone, PartialEq)]
357pub enum Op {
358 // ── Graph inputs ────────────────────────────────────────────
359 /// Model input with a name (shape on the Node).
360 Input {
361 name: String,
362 },
363
364 /// Model parameter (weight/bias) with a name.
365 Param {
366 name: String,
367 },
368
369 /// Constant tensor embedded in the graph.
370 Constant {
371 data: Vec<u8>,
372 },
373
374 // ── Element-wise unary ──────────────────────────────────────
375 /// Unary activation: one input, same shape output.
376 Activation(Activation),
377
378 /// Cast to a different dtype.
379 Cast {
380 to: DType,
381 },
382
383 /// INT8 quantization. Input f32; output i8 same shape.
384 /// `q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])`
385 /// where `c` selects the per-channel scale/zp when `axis = Some(d)`
386 /// (`c = idx[d]`), or always uses index 0 when `axis = None`
387 /// (per-tensor). The `scales` / `zero_points` payload length must
388 /// match `1` for per-tensor and `input.dim(d)` for per-channel.
389 /// Static — typically produced at calibration time and baked
390 /// into the loaded model. Use `Op::Dequantize` for the inverse.
391 Quantize {
392 axis: Option<usize>,
393 scales: Vec<f32>,
394 zero_points: Vec<i32>,
395 },
396
397 /// INT8 dequantization (inverse of `Op::Quantize`). Input i8;
398 /// output f32 same shape.
399 /// `x[i] = (q[i] - zero_point[c]) · scale[c]`
400 /// where `c` is selected by `axis` exactly as in `Op::Quantize`.
401 Dequantize {
402 axis: Option<usize>,
403 scales: Vec<f32>,
404 zero_points: Vec<i32>,
405 },
406
407 /// "Fake-quantize" op for **quantization-aware training** (QAT).
408 /// Input f32; output f32 same shape. Forward computes a per-axis
409 /// (or per-tensor when `axis = None`) max-abs scale on the fly:
410 /// `s[c] = max(|x[..., c, ...]|) / q_max(bits)`
411 /// then quantizes-then-dequantizes:
412 /// `out[i] = clamp(round(x[i] / s[c]), -q_max, q_max) * s[c]`
413 /// where `q_max` is `127` for `bits=8`, `7` for `bits=4`, `1` for
414 /// `bits=2` (ternary). Symmetric only — zero-point is always 0.
415 ///
416 /// The point of this op is to make the SGD optimizer "see" the
417 /// deployment-time rounding during training. Backward is the
418 /// **straight-through estimator** (STE): the gradient passes
419 /// through (variant chosen by `ste`), ignoring the discontinuity
420 /// at the round. Without STE the rounding would have zero
421 /// gradient almost everywhere and learning would stop.
422 ///
423 /// Inserted by the trainer on conv / FC weight tensors when
424 /// `--qat` is on; the existing `Op::Quantize` / packing path at
425 /// the end of training still handles the deployment-side
426 /// conversion to `i8`/`i4`/`i2` codes.
427 FakeQuantize {
428 bits: u8,
429 axis: Option<usize>,
430 ste: SteKind,
431 scale_mode: ScaleMode,
432 },
433
434 /// Learned Step Size Quantization (LSQ; Esser et al. 2020,
435 /// `arXiv:1902.08153`). Like `FakeQuantize` but the per-channel
436 /// `scale` is a *learned parameter*, passed as the second input.
437 /// Forward is identical to `FakeQuantize` with a fixed scale:
438 /// `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
439 /// Backward computes both `dx` (STE) and `dscale[c]` via the
440 /// closed-form gradient:
441 /// `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
442 /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
443 /// `sign(z) · q_max`. Routinely beats per-batch and EMA at
444 /// tight bit widths (i2 / i3).
445 ///
446 /// Inputs: `[x, scale]`. `scale` is `[chan_dim]` f32 (matches
447 /// `axis`); for `axis = None` it's `[1]`.
448 FakeQuantizeLSQ {
449 bits: u8,
450 axis: Option<usize>,
451 },
452
453 /// Backward pass for `Op::FakeQuantizeLSQ`. Computes BOTH the
454 /// gradient w.r.t. `x` (STE) and the gradient w.r.t. `scale`
455 /// (closed-form). Output shape matches `x`; the `scale` gradient
456 /// is reduced separately by `LsqScaleGradient`.
457 /// Inputs: `[x, scale, dy]`. Output: `dx`, same shape as `x`.
458 FakeQuantizeLSQBackwardX {
459 bits: u8,
460 axis: Option<usize>,
461 },
462
463 /// Companion to `FakeQuantizeLSQBackwardX`: computes the
464 /// `[chan_dim]` per-channel scale gradient. Inputs `[x, scale, dy]`.
465 /// Output shape matches `scale`.
466 FakeQuantizeLSQBackwardScale {
467 bits: u8,
468 axis: Option<usize>,
469 },
470
471 // ── Element-wise binary ─────────────────────────────────────
472 /// Binary op with broadcasting: two inputs, output shape is broadcast result.
473 Binary(BinaryOp),
474
475 // ── Comparison ──────────────────────────────────────────────
476 /// Element-wise comparison: two inputs, Bool output.
477 Compare(CmpOp),
478
479 /// Select elements: cond (Bool), on_true, on_false → output.
480 Where,
481
482 /// Fused element-wise region (PLAN L2). Holds an N-step chain of
483 /// element-wise operations. Inputs are referenced by index 0..num_inputs;
484 /// each step's result can be referenced by later steps via
485 /// `ChainOperand::Step(idx)`. The output is the last step's result.
486 /// Emitted by `MarkElementwiseRegions` in `rlx-opt` from chains of
487 /// Activation/Cast/Binary/Compare/Where ops with single-consumer
488 /// intermediates and broadcast-compatible shapes. Backends that
489 /// don't have a region kernel can decompose back to the original
490 /// chain via `unfuse::unfuse_elementwise_regions`.
491 ///
492 /// `scalar_input_mask` is a per-input bitfield (bit `i` set ⇒
493 /// input `i` is a scalar broadcast — has shape `[1]`). Kept as a
494 /// fast-path indicator that lets kernels skip the modulo entirely
495 /// when they detect a scalar.
496 ///
497 /// `input_modulus[i]` is the per-input element count, used by
498 /// kernels to compute `arena[input_offs[i] + (gid % input_modulus[i])]`
499 /// — the trailing-shape broadcast pattern. `0` means "no broadcast"
500 /// (input matches the output element count; kernel reads `gid`
501 /// directly). `1` means scalar; any other value means the input
502 /// has fewer elements than the output and they tile by modulo.
503 /// The encoder only allows broadcasts where `out_elems % in_elems
504 /// == 0` so the modulo divides cleanly. Lets chains include bias /
505 /// scale / eps / mask factors that previously broke the chain at
506 /// a Binary op with mismatched shapes.
507 ElementwiseRegion {
508 chain: Vec<ChainStep>,
509 num_inputs: u32,
510 scalar_input_mask: u32,
511 input_modulus: [u32; 16],
512 },
513
514 // ── Linear algebra ──────────────────────────────────────────
515 /// Matrix multiply. Inputs: [.., M, K] × [.., K, N] → [.., M, N].
516 /// Batch dimensions are broadcast.
517 MatMul,
518
519 /// Matrix multiply with explicit dimension specification.
520 /// Like XLA's DotGeneral — handles arbitrary batch/contracting dims.
521 DotGeneral {
522 lhs_contracting: Vec<usize>,
523 rhs_contracting: Vec<usize>,
524 lhs_batch: Vec<usize>,
525 rhs_batch: Vec<usize>,
526 },
527
528 /// Batched dense linear solve. Inputs: `A [B, N, N]`,
529 /// `b [B, N]` or `b [B, N, K]`. Output: same shape as `b`.
530 ///
531 /// Per-batch independent solve — each `A[i]` and `b[i]` are
532 /// solved as a separate `Op::DenseSolve`. Emitted by vmap of
533 /// `Op::DenseSolve`. The CPU lowering loops over the batch
534 /// dimension calling `dgesv` per slice (LAPACK doesn't expose a
535 /// batched solve on Accelerate; cuSOLVER does on NVIDIA).
536 BatchedDenseSolve,
537
538 /// Dense linear solve `x = A⁻¹ · b` via LU factorization.
539 /// Inputs: `A [N, N]`, `b [N]` (or `b [N, K]` for multi-RHS).
540 /// Output: same shape as `b`.
541 ///
542 /// VJP via the implicit-function theorem:
543 /// `dx = solve(Aᵀ, upstream)`
544 /// `dA = -outer(dx, x)` (x is the forward output)
545 /// `db = dx`
546 /// The rule is dtype-agnostic; lowering is per-backend (Accelerate
547 /// `dgesv` / `sgesv`, cuSOLVER, etc.).
548 DenseSolve,
549
550 // ── Normalization ───────────────────────────────────────────
551 /// Layer normalization: input, gamma, beta → normalized output.
552 /// `axis` is the feature dimension (usually -1).
553 LayerNorm {
554 axis: i32,
555 eps: f32,
556 },
557
558 /// Group normalization on NCHW tensors: `input`, `gamma`, `beta` → same shape.
559 /// Normalizes over `(C/num_groups) × H × W` per group.
560 GroupNorm {
561 num_groups: usize,
562 eps: f32,
563 },
564
565 /// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
566 /// position (candle / SAM `LayerNorm2d` semantics — not PyTorch's H×W norm).
567 LayerNorm2d {
568 eps: f32,
569 },
570
571 /// Nearest-neighbor 2× upsample on NCHW (doubles spatial dims 2 and 3).
572 ResizeNearest2x,
573
574 /// RMS normalization: input, gamma → normalized output.
575 RmsNorm {
576 axis: i32,
577 eps: f32,
578 },
579
580 // ── Attention ───────────────────────────────────────────────
581 /// Scaled dot-product attention: Q, K, V, \[mask\] → output.
582 /// The compiler can lower this to fused SDPA or flash attention.
583 /// `mask_kind` controls how masking is applied — `Custom` reads from
584 /// the 4th input tensor; `None` / `Causal` / `SlidingWindow` skip the
585 /// mask load and apply the mask directly in the inner loop. See
586 /// `MaskKind` for the rationale.
587 ///
588 /// `score_scale`: when `Some(s)`, dot-product scores are multiplied by
589 /// `s` instead of the default `1/sqrt(head_dim)` (Gemma uses `head_dim^-0.5`
590 /// explicitly in config). `attn_logit_softcap`: when `Some(c)`, applies
591 /// `c * tanh(s/c)` to scores before softmax (Gemma 2).
592 Attention {
593 num_heads: usize,
594 head_dim: usize,
595 mask_kind: MaskKind,
596 score_scale: Option<f32>,
597 attn_logit_softcap: Option<f32>,
598 },
599
600 /// Rotary position embedding applied to one tensor: x, cos, sin → x_rotated.
601 /// Apply separately to Q and K. `head_dim` is the per-head width; `n_rot`
602 /// is how many leading dims get NeoX RoPE (pair offset `n_rot/2`). When
603 /// `n_rot < head_dim`, trailing dims are copied unchanged (Qwen3.5 MRoPE).
604 Rope {
605 head_dim: usize,
606 n_rot: usize,
607 },
608
609 /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
610 AxialRope2d {
611 end_x: usize,
612 end_y: usize,
613 head_dim: usize,
614 num_heads: usize,
615 theta: f32,
616 repeat_factor: usize,
617 },
618
619 // ── Shape manipulation ──────────────────────────────────────
620 Reshape {
621 new_shape: Vec<i64>,
622 },
623 Transpose {
624 perm: Vec<usize>,
625 },
626 /// Select a contiguous slice along an axis.
627 Narrow {
628 axis: usize,
629 start: usize,
630 len: usize,
631 },
632 /// Concatenate along an axis.
633 Concat {
634 axis: usize,
635 },
636 /// Expand (broadcast) to a target shape.
637 Expand {
638 target_shape: Vec<i64>,
639 },
640 /// Gather elements by index along an axis (embedding lookup).
641 Gather {
642 axis: usize,
643 },
644
645 // ── Reduction ───────────────────────────────────────────────
646 /// Reduce along specified axes.
647 Reduce {
648 op: ReduceOp,
649 axes: Vec<usize>,
650 keep_dim: bool,
651 },
652
653 /// Selective scan (plan #15) — Mamba-style state-space model
654 /// step. The recurrence:
655 /// `h[t] = exp(Δ[t] * A) * h[t-1] + Δ[t] * B[t] * x[t]`
656 /// `y[t] = C[t] * h[t]`
657 /// where state `h` has dimension `state_size` and the input has
658 /// `(batch, seq, hidden)`.
659 ///
660 /// Inputs (in order):
661 /// `x [b, s, h]` f32 input
662 /// `delta [b, s, h]` f32 step size (per-position, per-channel)
663 /// `a [h, n]` f32 transition matrix (one per channel)
664 /// `b [b, s, n]` f32 input projection
665 /// `c [b, s, n]` f32 output projection
666 /// Output: `[b, s, h]` f32. State `h` is implicit; the kernel
667 /// scans through the seq dimension carrying it.
668 ///
669 /// `state_size` = `n` is exposed for the cost model.
670 SelectiveScan {
671 state_size: usize,
672 },
673
674 /// Gated DeltaNet linear-attention recurrence — the per-layer
675 /// kernel used by Qwen3.5/3.6 trunk "linear attention" blocks
676 /// (and Qwen3-Next, Kimi-Linear). Mirrors
677 /// `llama.cpp / src/models/delta-net-base.cpp` autoregressive
678 /// path; chunked + fused variants ride the same op identity.
679 ///
680 /// **Math (per token `t`, head `h`, state size `n`):**
681 /// state matrix `S[h, i, j]` is implicit (reset per batch).
682 /// ```text
683 /// S[h] *= exp(g[t,h]) # scalar gate
684 /// sk[h,j] = Σ_i S[h,i,j] * k[t,h,i]
685 /// d[h,j] = (v[t,h,j] - sk[h,j]) * b[t,h] # b = beta
686 /// S[h,i,j] += k[t,h,i] * d[h,j] # outer-prod
687 /// o[t,h,j] = Σ_i S[h,i,j] * (q[t,h,i] / √n)
688 /// ```
689 ///
690 /// Inputs:
691 /// `q [b, s, h_v, n]` f32 queries (L2-normed by caller)
692 /// `k [b, s, h_v, n]` f32 keys (L2-normed by caller;
693 /// GQA-repeated to match `h_v`)
694 /// `v [b, s, h_v, n]` f32 values
695 /// `g [b, s, h_v]` f32 log-gate (exp'd inside kernel)
696 /// `beta [b, s, h_v]` f32 delta-rule mixing factor
697 ///
698 /// Output: `[b, s, h_v, n]` f32.
699 ///
700 /// When `carry_state` is true, a sixth input `state [b, h_v, n, n]`
701 /// provides the initial SSM matrix per head; the kernel updates it
702 /// in place across the sequence and leaves the final state in the
703 /// same buffer (same layout as the internal scan state:
704 /// `state[h, i, j]` row-major over `(n, n)` per head).
705 GatedDeltaNet {
706 state_size: usize,
707 carry_state: bool,
708 },
709
710 /// Fused dequant + matmul (plan #5). The biggest LLM-bandwidth
711 /// win on Apple Silicon: dequantizes weights inside the matmul
712 /// inner loop, never materializing f32 weights.
713 ///
714 /// **BREAKING CHANGE in 0.2.0:** `num_inputs()` is now
715 /// scheme-dependent — **4** for legacy Int8 schemes, **2** for
716 /// the new GGUF K-quant schemes (their scales/mins live inside
717 /// the packed bytes, so no side-channel `scale` / `zp` tensors
718 /// are fed in). Callers that assumed a fixed 4-input contract
719 /// must inspect `scheme.is_gguf()` before reading inputs.
720 ///
721 /// Inputs (Int8 schemes — `scheme.is_gguf() == false`):
722 /// `x [m, k]` f32 activations
723 /// `w_q [k, n]` packed quantized weight bytes (i8 per
724 /// element for Int8 schemes; 4-bit
725 /// packed two-per-byte for Int4)
726 /// `scale [k/block, n]` per-block f32 dequant scale
727 /// `zp [k/block, n]` per-block f32 zero-point
728 /// (zero-tensor if symmetric)
729 ///
730 /// Inputs (`Nvfp4Block` — fixed group size 16 along K):
731 /// `x [m, k]` f32 activations
732 /// `w_q [k,n/2]` packed FP4 E2M1 codes (unsigned nibble 0..15)
733 /// `scale [k/16, n]` u8 FP8 E4M3 block scales (one byte / group)
734 /// `global_scale [1]` f32 per-tensor scale (pass `[1.0]` if unused)
735 ///
736 /// Inputs (GGUF schemes — `scheme.is_gguf() == true`):
737 /// `x [m, k]` f32 activations
738 /// `packed_w [bytes]` raw GGUF super-block bytes; the
739 /// dequantizer reads the per-sub-block
740 /// scales / mins / quants directly out
741 /// of the buffer per the K-quant block
742 /// layout (no side tensors).
743 ///
744 /// Output: `[m, n]` f32.
745 ///
746 /// `block_size` (on the Int8 schemes only) is the number of
747 /// consecutive elements that share one (scale, zero_point) pair.
748 /// The Op carries enough metadata that the kernel doesn't need
749 /// a separate `QuantMap` lookup at run time.
750 DequantMatMul {
751 scheme: crate::quant::QuantScheme,
752 },
753
754 /// Real INT8-arithmetic matrix multiply with i32 accumulation.
755 /// Inputs (in order):
756 /// `x [M, K]` i8 activations (zero-point = `x_zp`)
757 /// `w [K, N]` i8 weights (zero-point = `w_zp`)
758 /// `bias [N]` i32 (in accumulator scale = `x_scale·w_scale`),
759 /// pass a zeros tensor for "no bias"
760 /// Output: `[M, N]` i8 (zero-point = `out_zp`)
761 ///
762 /// Per-element compute:
763 /// `out[m,n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
764 /// where `mult = x_scale · w_scale / out_scale`.
765 ///
766 /// This is the same kernel shape `rlx-cortexm/src/dense.rs`
767 /// uses for on-device int8 inference, lifted into the IR so the
768 /// rlx-cpu backend can run a quantized graph directly (instead
769 /// of round-tripping through fake-quant Dequantize → MatMul →
770 /// Quantize). 2-D only — generalizing to batched comes when a
771 /// real workload demands it.
772 QMatMul {
773 x_zp: i32,
774 w_zp: i32,
775 out_zp: i32,
776 mult: f32,
777 },
778
779 /// Real INT8-arithmetic 2-D convolution with i32 accumulation.
780 /// Inputs:
781 /// `x [N, C_in, H, W]` i8 (zero-point = `x_zp`)
782 /// `w [C_out, C_in/groups, kH, kW]` i8 (zero-point = `w_zp`)
783 /// `bias [C_out]` i32 in accumulator scale
784 /// Output: `[N, C_out, H_out, W_out]` i8 (zero-point = `out_zp`).
785 /// Same NCHW geometry contract as `Op::Conv`; same requantize
786 /// math as `Op::QMatMul` (per-element `acc·mult` rounded to i8).
787 QConv2d {
788 kernel_size: Vec<usize>,
789 stride: Vec<usize>,
790 padding: Vec<usize>,
791 dilation: Vec<usize>,
792 groups: usize,
793 x_zp: i32,
794 w_zp: i32,
795 out_zp: i32,
796 mult: f32,
797 },
798
799 /// Fused LoRA matmul: `out = x·W + scale * x·A·B`.
800 /// Inputs (in order): `x [m, k]`, `w [k, n]`, `a [k, r]`, `b [r, n]`.
801 /// `r` is the LoRA rank (typically 4-64). `scale` is the
802 /// per-adapter `alpha / rank` knob.
803 /// Plan #9: lifts LoRA from "three matmuls + an add" into one
804 /// kernel that keeps the rank-r intermediate in registers.
805 LoraMatMul {
806 scale: f32,
807 },
808
809 /// Fused sampling kernel: logits → optional top-k filter →
810 /// optional top-p truncation → softmax → multinomial sample.
811 /// One f32-encoded sampled token id per batch row (output
812 /// shape `[batch]`).
813 ///
814 /// `temperature == 1.0` matches a plain argmax-of-softmax;
815 /// lower → sharper, higher → flatter. `top_k == 0` disables.
816 /// `top_p == 1.0` disables. `seed` is the Philox seed; pass 0
817 /// for "use process-global counter" (still deterministic
818 /// given the call order).
819 /// Borrowed from MAX's nn/sampling.mojo (#42 in PLAN.md).
820 /// Latency-critical: never materializes the full softmax
821 /// distribution on the host.
822 Sample {
823 top_k: usize, // 0 = disabled
824 top_p: f32, // 1.0 = disabled
825 temperature: f32, // 1.0 = neutral
826 seed: u64, // 0 = use thread-local counter
827 },
828
829 /// Inclusive cumulative sum along an axis. Same shape in/out.
830 /// Underpins ragged-tensor offsets, sampling (top-p prefix sum),
831 /// and sequence-position math (#44 in PLAN.md).
832 /// `exclusive=true` shifts the result so output\[0\] = 0 (useful
833 /// for offset arrays where the first segment starts at 0).
834 Cumsum {
835 axis: i32,
836 exclusive: bool,
837 },
838
839 /// Softmax along an axis (reduction + element-wise).
840 Softmax {
841 axis: i32,
842 },
843
844 /// Top-K **indices** along the last axis. Output shape `[..., k]`,
845 /// f32-encoded indices (rlx is f32-only at the I/O boundary).
846 /// To recover the values, follow with a `Gather` against the
847 /// original tensor — works because Gather already supports any axis.
848 /// Ties broken by smaller index (matches NumPy / PyTorch
849 /// `torch.topk(..., largest=True, sorted=True)`).
850 /// Used by MoE gating; also useful for beam search.
851 TopK {
852 k: usize,
853 },
854
855 /// Indexed batched matmul. The MoE GEMM primitive.
856 /// Inputs: `[input, weight, expert_idx]`
857 /// input : [M, K] — per-token activations
858 /// weight : [num_experts, K, N] — stacked expert weights
859 /// expert_idx : \[M\] — f32-encoded expert id per token
860 /// Output : [M, N] — output\[i\] = input\[i\] @ weight[expert_idx\[i\]]
861 /// Naive impl on both backends; future work can replace with a
862 /// segmented/grouped GEMM when there's a real workload.
863 GroupedMatMul,
864
865 /// Fused GGUF K-quant dequant + [`Op::GroupedMatMul`]. Same three
866 /// inputs as `GroupedMatMul`, but `weight` is a U8 tensor holding
867 /// `num_experts` contiguous packed expert slabs (GGML layout, expert
868 /// dimension outermost). Scales live inside the packed bytes.
869 DequantGroupedMatMul {
870 scheme: crate::quant::QuantScheme,
871 },
872
873 /// Dequant a packed MoE expert stack to F32 `[num_experts, K, N]` in
874 /// GroupedMatMul layout. Input: U8 packed bytes; output shape is
875 /// declared on the node (`[E, K, N]`).
876 DequantMoEWeights {
877 scheme: crate::quant::QuantScheme,
878 },
879
880 /// Scatter-add into a destination tensor. The "unpermute" half of
881 /// MoE routing (also useful for embedding gradient updates).
882 /// Inputs: `[updates, indices]`
883 /// updates : [num_updates, trailing] — values to add
884 /// indices : \[num_updates\] — f32-encoded destination row
885 /// Output : [out_dim, trailing] — output[indices\[i\]] += updates\[i\]
886 /// `out_dim` is taken from the node's declared output shape.
887 /// Initial output is zero; multiple updates to the same row
888 /// accumulate (sequentially on CPU; with atomic-add on Metal).
889 ScatterAdd,
890
891 // ── Convolution ─────────────────────────────────────────────
892 /// 2D convolution on NCHW tensors. Also exposed as [`OpKind::Conv`] / `conv2d`.
893 /// Weight layout: `[C_out, C_in / groups, kH, kW]`.
894 Conv {
895 kernel_size: Vec<usize>,
896 stride: Vec<usize>,
897 padding: Vec<usize>,
898 dilation: Vec<usize>,
899 groups: usize,
900 },
901
902 /// 2D transposed convolution on NCHW. Weight layout (PyTorch):
903 /// `[C_in, C_out / groups, kH, kW]`.
904 ConvTranspose2d {
905 kernel_size: Vec<usize>,
906 stride: Vec<usize>,
907 padding: Vec<usize>,
908 dilation: Vec<usize>,
909 output_padding: Vec<usize>,
910 groups: usize,
911 },
912
913 // ── Pooling ─────────────────────────────────────────────────
914 Pool {
915 kind: ReduceOp,
916 kernel_size: Vec<usize>,
917 stride: Vec<usize>,
918 padding: Vec<usize>,
919 },
920
921 // ── Backward / training ops ─────────────────────────────────
922 //
923 // Closed-form gradient nodes emitted by `rlx-opt::autodiff`.
924 // Pairing a forward op with a dedicated backward op (instead of
925 // composing it from primitives) keeps the gradient kernel simple
926 // and lets the backend recompute argmax / masks / softmax inline.
927 /// ReLU backward: `dx = dy where x > 0 else 0`.
928 /// Inputs: `[x, dy]` — both same shape; output matches.
929 ReluBackward,
930
931 /// Element-wise complex squared-magnitude: `|z|² = z.re² + z.im²`.
932 /// Input: 1 tensor with `DType::C64`. Output: same shape but
933 /// `DType::F32`. The natural real-valued loss surface for
934 /// Wirtinger reverse-mode AD on complex graphs — pair with
935 /// [`Op::ComplexNormSqBackward`].
936 ComplexNormSq,
937
938 /// Element-wise complex conjugate: `z̄ = z.re - i·z.im` per element.
939 /// Input: 1 tensor with `DType::C64`. Output: same shape, same dtype.
940 /// Used by Wirtinger VJP rules on `Op::Binary` over C64 (the rule
941 /// for `y = a·b` is `dL/dā = upstream · conj(b)`, etc.).
942 Conjugate,
943
944 /// Backward for [`Op::ComplexNormSq`] under Wirtinger calculus.
945 /// `f(z) = |z|² = z·z̄`, so `∂f/∂z̄ = z`. Given upstream real
946 /// cotangent `g` (same shape as the forward output), the C64
947 /// gradient with respect to `z` is `g · z` element-wise, returned
948 /// in C64 storage `[re_g·re_z, re_g·im_z]` per element.
949 ///
950 /// Inputs: `[z (C64), g (F32)]` — both same logical shape; output
951 /// matches `z` (C64).
952 ComplexNormSqBackward,
953
954 /// LayerNorm backward w.r.t. the input. Computes
955 /// `d_x[..., d] = inv_std · (dy·γ - mean(dy·γ) - x̂·mean(dy·γ·x̂))`
956 /// over the feature axis, where `x̂ = (x - mean)/std` is recomputed
957 /// inline from `x`. Inputs: `[x, gamma, dy]`; output shape = `x.shape`.
958 /// Currently lowers axis=-1 only (matches the forward thunk).
959 LayerNormBackwardInput {
960 axis: i32,
961 eps: f32,
962 },
963
964 /// LayerNorm backward w.r.t. gamma. Computes
965 /// `d_gamma[d] = Σ_{batch} dy[..., d] · x̂[..., d]`
966 /// — sums the per-element product of upstream and the (recomputed)
967 /// normalized input over the leading axes. Inputs: `[x, dy]`;
968 /// output shape = `gamma.shape` (= 1-D feature axis).
969 LayerNormBackwardGamma {
970 axis: i32,
971 eps: f32,
972 },
973
974 /// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`; output = `x.shape`.
975 RmsNormBackwardInput {
976 axis: i32,
977 eps: f32,
978 },
979
980 /// RMSNorm backward w.r.t. gamma. Inputs `[x, gamma, beta, dy]`; output = `gamma.shape`.
981 RmsNormBackwardGamma {
982 axis: i32,
983 eps: f32,
984 },
985
986 /// RMSNorm backward w.r.t. beta. Inputs `[x, gamma, beta, dy]`; output = `beta.shape`.
987 RmsNormBackwardBeta {
988 axis: i32,
989 eps: f32,
990 },
991
992 /// RoPE backward w.r.t. `x`. Inputs `[dy, cos, sin]`; output = `dy.shape`.
993 RopeBackward {
994 head_dim: usize,
995 n_rot: usize,
996 },
997
998 /// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
999 GroupNormBackwardInput {
1000 num_groups: usize,
1001 eps: f32,
1002 },
1003
1004 /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`; output = `gamma.shape`.
1005 GroupNormBackwardGamma {
1006 num_groups: usize,
1007 eps: f32,
1008 },
1009
1010 /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`; output = `beta.shape`.
1011 GroupNormBackwardBeta {
1012 num_groups: usize,
1013 eps: f32,
1014 },
1015
1016 /// Cumsum backward along `axis`. Inputs `[dy]`; output matches forward input shape.
1017 CumsumBackward {
1018 axis: i32,
1019 exclusive: bool,
1020 },
1021
1022 /// Gather backward (scatter-add into table). Inputs `[dy, indices]`; output = table shape.
1023 /// `axis` matches forward [`Op::Gather`].
1024 GatherBackward {
1025 axis: i32,
1026 },
1027
1028 /// Generic element-wise activation backward. `kind` selects the
1029 /// closed-form derivative `d/dx act(x)`. Inputs: `[x, dy]`; output
    /// shape matches `x`. The kernel computes `act'(x) · dy` per element.
1031 ///
1032 /// Closed forms (all element-wise):
1033 /// * `Gelu` — exact derivative of erf-based GELU.
1034 /// * `GeluApprox` — derivative of the tanh approximation
1035 /// `0.5 x (1 + tanh(√(2/π)(x + 0.044715 x³)))`.
1036 /// * `Silu` — `σ(x)·(1 + x·(1 - σ(x)))`.
1037 /// * `Sigmoid` — `σ(x)·(1 - σ(x))`.
1038 /// * `Tanh` — `1 - tanh(x)²`.
1039 /// * `Exp` — `exp(x)`.
1040 /// * `Log` — `1 / x`.
1041 /// * `Sqrt` — `0.5 / sqrt(x)`.
1042 /// * `Rsqrt` — `-0.5 · x^(-3/2)`.
1043 /// * `Neg` — `-1`.
1044 /// * `Abs` — `sign(x)` (zero at x=0).
1045 /// * `Sin` — `cos(x)`.
1046 /// * `Cos` — `-sin(x)`.
1047 /// * `Tan` — `1 + tan²(x)`.
1048 /// * `Atan` — `1 / (1 + x²)`.
1049 /// * `Relu` — kept here for completeness; the dedicated
1050 /// `ReluBackward` op is preferred for relu and is what the
1051 /// autodiff pass actually emits.
1052 ActivationBackward {
1053 kind: Activation,
1054 },
1055
1056 /// Backward for `Op::FakeQuantize` under a non-default STE.
1057 /// Inputs `[x, dy]`: the forward input and the upstream
1058 /// gradient. Output `dx` same shape. The `bits`/`axis`/`ste`
1059 /// fields must match the forward op so the kernel computes the
1060 /// same per-channel scale and applies the right STE attenuation.
1061 /// For `SteKind::Identity` this op is unnecessary — autodiff
1062 /// just routes `upstream` through unchanged.
1063 FakeQuantizeBackward {
1064 bits: u8,
1065 axis: Option<usize>,
1066 ste: SteKind,
1067 },
1068
1069 /// 2D max-pool backward. Routes each element of `dy` back into the
1070 /// position in `x`'s window where the forward max was taken.
1071 /// Inputs: `[x, dy]` with `x [N, C, H, W]` and
1072 /// `dy [N, C, H_out, W_out]`. Output: same shape as `x`.
1073 /// Carries the forward pool's geometry so the kernel can recompute
1074 /// the argmax position per window without a saved-indices tensor.
1075 MaxPool2dBackward {
1076 kernel_size: Vec<usize>,
1077 stride: Vec<usize>,
1078 padding: Vec<usize>,
1079 },
1080
1081 /// 2D conv backward w.r.t. input. Computes `dx = conv_transpose(dy, w)`.
1082 /// Inputs: `[dy, w]` with `dy [N, C_out, H_out, W_out]` and
1083 /// `w [C_out, C_in/groups, kH, kW]`. Output: `[N, C_in, H, W]`
1084 /// (declared on the node — caller knows the original input shape).
1085 /// Geometry is the forward conv's parameters, not the transposed
1086 /// conv's.
1087 Conv2dBackwardInput {
1088 kernel_size: Vec<usize>,
1089 stride: Vec<usize>,
1090 padding: Vec<usize>,
1091 dilation: Vec<usize>,
1092 groups: usize,
1093 },
1094
1095 /// 2D conv backward w.r.t. weight. Computes
1096 /// `dw[c_out, c_in, kh, kw] = sum_{n,h_out,w_out} x[n,c_in,...] * dy[n,c_out,h_out,w_out]`.
1097 /// Inputs: `[x, dy]`. Output: `[C_out, C_in/groups, kH, kW]`.
1098 Conv2dBackwardWeight {
1099 kernel_size: Vec<usize>,
1100 stride: Vec<usize>,
1101 padding: Vec<usize>,
1102 dilation: Vec<usize>,
1103 groups: usize,
1104 },
1105
1106 /// Fused softmax + cross-entropy loss with integer (f32-encoded)
1107 /// targets — the standard classification loss. Per-row output:
1108 /// `loss[n] = -log(softmax(logits[n])[labels[n]])`.
1109 /// Inputs: `[logits, labels]` with `logits [N, C]` and
1110 /// `labels [N]` (f32-encoded class indices). Output: `[N]`.
1111 /// Caller does the `Reduce::Mean` if they want a scalar.
1112 SoftmaxCrossEntropyWithLogits,
1113
1114 /// Backward of the fused loss above. Emits
1115 /// `dlogits[n,c] = (softmax(logits[n])[c] - one_hot(labels)[n,c]) * d_loss[n]`.
1116 /// Inputs: `[logits, labels, d_loss]`. Output: `[N, C]` (same shape
1117 /// as `logits`). Recomputes the softmax inline rather than threading
1118 /// it through from the forward node.
1119 SoftmaxCrossEntropyBackward,
1120
1121 /// Backward of [`Op::Attention`]. Recomputes scaled `QK^T`, applies
1122 /// the same `mask_kind` as the forward op, softmaxes scores, then
1123 /// emits **one** of `dQ`, `dK`, or `dV` selected by [`AttentionBwdWrt`].
1124 /// Autodiff emits three nodes (one per `wrt`) so each output shape
1125 /// stays a normal single-output MIR node.
1126 ///
1127 /// Inputs: `[q, k, v, dy]` plus optional mask when `mask_kind` is
1128 /// [`MaskKind::Custom`] or [`MaskKind::Bias`] (same convention as
1129 /// forward). Output shape matches `q`, `k`, or `v` respectively.
1130 AttentionBackward {
1131 num_heads: usize,
1132 head_dim: usize,
1133 mask_kind: MaskKind,
1134 wrt: AttentionBwdWrt,
1135 },
1136
1137 // ── Fused operations (created by optimization passes) ──────
1138 /// Fused matmul + bias + activation. Created from MatMul → Add → Activation.
1139 FusedMatMulBiasAct {
1140 activation: Option<Activation>,
1141 },
1142
1143 /// Fused residual + optional bias + layer norm.
1144 /// Created from Add(x, residual) → [Add(bias)] → LayerNorm.
1145 FusedResidualLN {
1146 has_bias: bool,
1147 eps: f32,
1148 },
1149
1150 /// Fused residual + optional bias + RMS norm.
1151 /// Created from Add(x, residual) → [Add(bias)] → RmsNorm.
1152 FusedResidualRmsNorm {
1153 has_bias: bool,
1154 eps: f32,
1155 },
1156
1157 /// Fused SwiGLU: split input into up/gate halves, silu(gate) * up.
1158 /// Created from Split → Silu → Mul when fed by a concatenated matmul.
1159 ///
1160 /// `cast_to`: optional output dtype — when `Some(dt)` the kernel casts
1161 /// its result from the input dtype to `dt` in-register, saving a
1162 /// separate cast pass. Reserved for future fp8/fp4 quantization paths;
1163 /// for f32→f16 mixed precision the AutoMixedPrecision pass already
1164 /// inserts a Cast node so this stays `None` in current pipelines.
1165 FusedSwiGLU {
1166 cast_to: Option<DType>,
1167 /// When `true`, the concatenated input stores gate in the low half
1168 /// `[..., 0..N)` and up in the high half `[..., N..2N)` — the layout
1169 /// produced when gate projection is emitted before up in the builder.
1170 /// Default `false`: up @ low, gate @ high (canonical concat order).
1171 gate_first: bool,
1172 },
1173
1174 /// Fused full transformer layer: attention block + residual+LN + FFN + residual+LN.
1175 /// All intermediates resident in registers/threadgroup memory; one kernel
1176 /// per layer instead of ~30 (the CPU's batch=1 win, lifted to IR so any
1177 /// backend can implement it as a monolithic kernel).
1178 ///
1179 /// Inputs: hidden, qkv_w, qkv_b, out_w, out_b,
1180 /// ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask
1181 /// Output: same shape as hidden.
1182 ///
1183 /// **Backend status:** same as FusedAttentionBlock. CPU implements
1184 /// the L1-cache-resident merge at the thunk level. Metal deferred —
1185 /// requires a single MSL kernel for the whole layer to actually
1186 /// beat the unfused path. Multi-day work; revisit when there's a
1187 /// model whose Metal inference is bottlenecked here rather than on
1188 /// the wait latency floor.
1189 FusedTransformerLayer {
1190 num_heads: usize,
1191 head_dim: usize,
1192 intermediate_size: usize,
1193 eps1: f32,
1194 eps2: f32,
1195 activation: Activation,
1196 has_bias: bool,
1197 },
1198
1199 /// Fused attention block: QKV projection → split → \[RoPE\] → SDPA → output projection.
1200 /// Created by FuseAttentionBlock pass when batch*seq is small.
1201 /// All intermediates stay in L1 cache — no arena writes between ops.
1202 ///
1203 /// Inputs (in order):
1204 /// hidden, qkv_w, out_w, mask,
1205 /// [qkv_b, out_b] if has_bias,
1206 /// [rope_cos, rope_sin] if has_rope
1207 ///
1208 /// **Backend status (Phase C finalize):**
1209 /// CPU — implemented at the *thunk* level: the CPU schedule
1210 /// recognizes the multi-thunk pattern and merges into
1211 /// a single FusedAttnBlock that keeps Q/K/V in stack
1212 /// buffers across stages (the L1-cache win).
1213 /// Metal — **deferred**. A dispatch-wrapper version (chaining
1214 /// existing kernels) buys nothing the unfused Metal path
1215 /// doesn't already get, since per-run cost is dominated
1216 /// by `wait_until_completed` (~150 µs), not encode. The
1217 /// real win is a single MSL kernel keeping Q/K/V in
1218 /// threadgroup memory across stages — multi-day work.
1219 /// Until then, Metal runs the unfused chain (one matmul,
1220 /// three narrows, two ropes, attention, one matmul) — all
1221 /// covered in op_coverage and parity_harness.
1222 FusedAttentionBlock {
1223 num_heads: usize,
1224 head_dim: usize,
1225 has_bias: bool,
1226 has_rope: bool,
1227 },
1228
1229 // ── Control flow (subgraphs as op payloads) ─────────────────
1230 //
1231 // Status: IR is defined; helper `run_if` / `run_while` exist in
1232 // rlx-runtime/src/subgraph.rs; **executor wiring is not yet
1233 // implemented** (both CPU thunk and Metal thunk fall through to
1234 // `Thunk::Nop` for these ops). Wiring requires:
1235 // 1. Recursive subgraph compile at parent-compile time.
1236 // 2. Per-subgraph input/output binding through the arena.
1237 // 3. Schedule-level dispatch when the predicate / loop cond is
1238 // resolved at runtime.
1239 // Estimate: 4–6 hours of focused work + parity tests. Deferred
1240 // because no current in-tree model uses these ops;
1241 // surface area without a validation target invites silent bugs.
1242 /// Conditional: pick between two subgraphs based on a boolean predicate.
1243 /// Inputs: [predicate, ...captures (used inside both branches)].
1244 /// `then_branch` and `else_branch` are sub-graphs that share the
1245 /// captured inputs and must produce identically-shaped outputs.
1246 /// Used for: shape-dependent execution, batched inference of
1247 /// dynamic-length sequences with padding masks.
1248 If {
1249 then_branch: Box<crate::Graph>,
1250 else_branch: Box<crate::Graph>,
1251 },
1252
1253 /// Loop: iterate `body` while `cond` evaluates true.
1254 /// Inputs: [...initial loop-carried values].
1255 /// `cond`'s single output is a Bool scalar.
1256 /// `body`'s outputs become the next iteration's loop-carried inputs.
1257 /// Outputs of While are the values after the final iteration.
1258 /// Used for: KV-cache-driven autoregressive generation, beam search.
1259 While {
1260 cond: Box<crate::Graph>,
1261 body: Box<crate::Graph>,
1262 max_iterations: Option<usize>,
1263 },
1264
1265 /// Bounded-length loop with a fixed-shape carry, optional per-step
1266 /// inputs, and optional stacked output. Mirrors JAX's `lax.scan`.
1267 ///
    /// Body signature (shown for `num_bcast == 0`):
    /// `(carry, x_t_0, ..., x_t_{num_xs-1}) → carry_next`
    /// — `1 + num_xs` Op::Inputs in NodeId construction order (first
    /// declared is the carry; the remaining `num_xs` are per-step
    /// slices). Single output (the next carry). When `num_bcast > 0`,
    /// the broadcast inputs sit between the carry and the per-step
    /// slices, both in the body and in the outer op — see `num_bcast`.
    ///
    /// Outer Op::Scan inputs (in order, for `num_bcast == 0`):
    ///   `[init_carry, xs_0, xs_1, ..., xs_{num_xs-1}]`
1275 /// Each `xs_i` has shape `[length, *per_step_shape_i]`; the body
1276 /// sees `xs_i[t]` (a `per_step_shape_i` slice) on iteration `t`.
1277 ///
1278 /// Outer Op::Scan output:
1279 /// * `save_trajectory == false` — final carry, shape `*carry`.
1280 /// * `save_trajectory == true` — stacked trajectory of carries,
1281 /// shape `[length, *carry]`. Row `t` is the carry after step
1282 /// `t+1`, so row `length-1` matches the no-trajectory case.
1283 ///
1284 /// Mirrors JAX's `lax.scan`. Common uses include time-stepping
1285 /// integrators with time-varying drives, Mamba-style SSM scans
1286 /// reading per-step inputs, and RNN-style sequence processing.
1287 Scan {
1288 body: Box<crate::Graph>,
1289 length: u32,
1290 save_trajectory: bool,
1291 /// Number of "broadcast" inputs — values that are constant
1292 /// across iterations. Outer scan inputs in order:
1293 /// `[init, bcast_0..bcast_{B-1}, xs_0..xs_{X-1}]`
1294 /// Body Op::Inputs in NodeId order:
1295 /// `[carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]`
1296 /// CPU executor fills bcast slots ONCE before the iteration
1297 /// loop (xs slots are filled per-step). The reverse-mode AD
1298 /// pre-pass materialises each bcast into an xs of shape
1299 /// `[length, *bcast]` via broadcast `Mul` so the rest of the
1300 /// VJP / executor pipeline can stay unchanged. `0` (default)
1301 /// keeps the original carry+xs scan shape.
1302 num_bcast: u32,
1303 /// Number of per-step `xs` inputs. Total outer Op::Scan
1304 /// inputs is `1 + num_bcast + num_xs`.
1305 num_xs: u32,
1306 /// Number of trajectory checkpoints when `save_trajectory ==
1307 /// true`. `0` means "save all `length` rows" (default). A
1308 /// positive value `K` means save only `K` evenly-spaced rows
1309 /// at indices `floor(t * length / K)` for `t in 0..K`. Used
1310 /// by recursive checkpointed AD: store O(√T) carries during
1311 /// forward, recompute the rest in the backward pass.
1312 ///
1313 /// When `0` (or `K == length`), the saved trajectory has
1314 /// shape `[length, *carry]` — same as the original behavior.
1315 /// When `0 < K < length`, the saved trajectory has shape
1316 /// `[K, *carry]`.
1317 num_checkpoints: u32,
1318 },
1319
1320 /// Reverse-mode AD companion to `Op::Scan` — extracts the carry
1321 /// gradient `dinit`. Walks `t = length-1 .. 0`, applying `body_vjp`
1322 /// to thread `dcarry` back through the time loop.
1323 ///
1324 /// Inputs (in order):
1325 /// `[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]`
1326 /// Output: `dinit`, shape = carry shape.
1327 ///
1328 /// `body_vjp` is the result of
1329 /// `autodiff::grad(body, [carry_id, xs_0_id, ..., xs_{num_xs-1}_id])`
1330 /// — a graph with `1 + num_xs + 1` Inputs (carry + x_t_i for each
1331 /// xs + `"d_output"`) and `1 + num_xs` outputs
1332 /// (dcarry + dx_t_i for each xs). This op reads `outputs[0]` =
1333 /// dcarry; the sibling [`Self::ScanBackwardXs`] reads the
1334 /// `outputs[1 + xs_idx]` slot for each xs gradient.
1335 ScanBackward {
1336 body_vjp: Box<crate::Graph>,
1337 length: u32,
1338 save_trajectory: bool,
1339 num_xs: u32,
1340 /// When `0` or equal to `length`, the trajectory input has
1341 /// shape `[length, *carry]` — every step's carry is cached
1342 /// (`CheckpointStrategy::All`). When `0 < K < length`, the
1343 /// trajectory input has shape `[K, *carry]` and the executor
1344 /// recomputes intermediate carries via `forward_body` between
        /// checkpoints. `forward_body` must be `Some` whenever
        /// `0 < num_checkpoints < length`.
1347 num_checkpoints: u32,
1348 /// Forward body (the same `body` from the forward Op::Scan).
1349 /// Required when `num_checkpoints > 0 && < length` so the
1350 /// executor can recompute carries between saved checkpoints.
1351 /// `None` for the All strategy (no recompute needed).
1352 forward_body: Option<Box<crate::Graph>>,
1353 },
1354
1355 /// Companion to [`Self::ScanBackward`] that extracts one stacked
1356 /// per-step `dxs_i` (shape `[length, *per_step_xs_i]`). Same inputs
1357 /// and same `body_vjp` graph as ScanBackward — `xs_idx` selects
1358 /// which body_vjp output to stack into the result.
1359 ///
1360 /// Note: each ScanBackwardXs runs its own backward loop. A future
1361 /// optimization can fuse them into a single multi-output backward
1362 /// kernel; for now it's `1 + num_xs` independent sweeps.
1363 ScanBackwardXs {
1364 body_vjp: Box<crate::Graph>,
1365 length: u32,
1366 save_trajectory: bool,
1367 num_xs: u32,
1368 xs_idx: u32,
1369 num_checkpoints: u32,
1370 forward_body: Option<Box<crate::Graph>>,
1371 },
1372
1373 /// CPU reference 3D Gaussian splat forward render.
1374 ///
1375 /// Seven flat F32 inputs (scene buffers + camera/render meta):
1376 /// 0. positions `[N*3]`
1377 /// 1. scales `[N*3]` (log-space)
1378 /// 2. rotations `[N*4]` (xyzw)
1379 /// 3. opacities `[N]` (logit)
1380 /// 4. colors `[N*3]` (linear RGB)
1381 /// 5. sh_coeffs `[N * sh_coeff_count * 3]`
1382 /// 6. meta `[23]` — camera position/target/up/fov/near/far, background RGB,
1383 /// then width/height/tile_size/radius_scale/alpha_cutoff/max_splat_steps/
1384 /// transmittance_threshold/max_list_entries as f32 bit-patterns.
1385 ///
1386 /// Output: `[height * width * 4]` linear RGBA (display gamma baked in).
1387 /// Build via [`crate::Graph::gaussian_splat_render`].
1388 ///
1389 /// Differentiable backward is not implemented in v1; autodiff treats this
1390 /// op as non-differentiable (same as [`Op::Sample`]).
1391 GaussianSplatRender {
1392 width: u32,
1393 height: u32,
1394 tile_size: u32,
1395 radius_scale: f32,
1396 alpha_cutoff: f32,
1397 max_splat_steps: u32,
1398 transmittance_threshold: f32,
1399 max_list_entries: u32,
1400 },
1401
1402 /// Backward pass for [`Self::GaussianSplatRender`].
1403 ///
1404 /// Eight inputs: the same seven as forward plus `d_loss_rgba` `[W*H*4]`
1405 /// (only RGB channels are used). Re-runs the training forward internally.
1406 ///
1407 /// Output: packed gradients
1408 /// `[positions(3N) | scales(3N) | rotations(4N) | opacities(N) | colors(3N) | sh(N*sh*3)]`.
1409 /// Unpack via [`crate::ops::splat::unpack_gaussian_splat_packed_grads`].
1410 GaussianSplatRenderBackward {
1411 width: u32,
1412 height: u32,
1413 tile_size: u32,
1414 radius_scale: f32,
1415 alpha_cutoff: f32,
1416 max_splat_steps: u32,
1417 transmittance_threshold: f32,
1418 max_list_entries: u32,
1419 loss_grad_clip: f32,
1420 sh_band: u32,
1421 max_anisotropy: f32,
1422 },
1423
1424 /// Strict IR stage 1: project, bin, sort, build per-pixel rays.
1425 ///
1426 /// Seven inputs (same scene + meta as [`Self::GaussianSplatRender`]). Output: packed
1427 /// prepare buffer (see `rlx_splat::prep_layout::prep_packed_len`).
1428 GaussianSplatPrepare {
1429 width: u32,
1430 height: u32,
1431 tile_size: u32,
1432 radius_scale: f32,
1433 alpha_cutoff: f32,
1434 max_splat_steps: u32,
1435 transmittance_threshold: f32,
1436 max_list_entries: u32,
1437 },
1438
1439 /// Strict IR stage 2: tile raster from [`Self::GaussianSplatPrepare`] output.
1440 ///
1441 /// Inputs: `prep` packed buffer, `meta` `[23]`. Output: `[width * height * 4]` RGBA.
1442 GaussianSplatRasterize {
1443 width: u32,
1444 height: u32,
1445 tile_size: u32,
1446 alpha_cutoff: f32,
1447 max_splat_steps: u32,
1448 transmittance_threshold: f32,
1449 max_list_entries: u32,
1450 },
1451
1452 /// User-registered custom op. `name` keys into the
1453 /// [`crate::op_registry`] for shape inference, autodiff, and
1454 /// per-backend execution. `attrs` is an opaque blob passed
1455 /// through to those callbacks (FFT direction, SparseLU
1456 /// reordering strategy, etc.). `num_inputs` is captured at
1457 /// construction time so [`Op::num_inputs`] stays infallible
1458 /// without a registry lookup. Build via [`crate::Graph::custom_op`].
1459 Custom {
1460 name: String,
1461 num_inputs: u32,
1462 attrs: Vec<u8>,
1463 },
1464
1465 /// 1D Fast Fourier Transform along the last axis.
1466 ///
1467 /// **Layouts**
1468 /// - `F32` / `F64`: 2N real-block — last axis is `[re₀…re_{N-1}, im₀…im_{N-1}]`.
1469 /// - `C64`: interleaved `[re, im]` pairs per complex element along the last axis.
1470 ///
1471 /// **ND transforms** — use `Graph::fftn` / `Graph::ifftn`, which compose
1472 /// `fft_axis` (transpose → Fft → transpose). Multi-axis `fftn` requires
1473 /// `DType::C64`; the 2N-block layout describes a single complex axis.
1474 ///
1475 /// Default (`FftNorm::Backward`) is **unnormalized** on both directions:
1476 /// `fft(x)[k] = Σ x[n]·exp(-2πi·nk/N)`
1477 /// `ifft(y)[n] = Σ y[k]·exp(+2πi·nk/N)`
1478 /// so `ifft(fft(x)) = N·x`. Use `FftNorm::Forward` for gpu-fft-style
1479 /// `1/N` scaling on inverse, or `FftNorm::Ortho` for unitary scaling.
1480 ///
1481 /// AD: VJP(`fft`) = `ifft`, VJP(`ifft`) = `fft` when `norm=Backward`;
1482 /// other norms apply the chain rule via output scaling.
1483 Fft {
1484 inverse: bool,
1485 norm: crate::fft::FftNorm,
1486 },
1487
1488 /// User-defined sub-graph with optional override AD rules.
1489 /// Mirrors JAX's `custom_vjp` / `custom_jvp` decorators: the
1490 /// caller wraps a forward computation and supplies its own
1491 /// reverse- and/or forward-mode AD bodies. Useful when:
1492 /// * The forward is iterative (Newton, fixed-point) and
1493 /// differentiating through the loop is wasteful — the
1494 /// vjp_body computes the implicit-function gradient at the
1495 /// converged point in one shot.
1496 /// * The math has a closed-form gradient that's much cheaper
1497 /// than autodiff.
1498 /// * The forward op is non-differentiable by tracing
1499 /// (sampling, argmax) and the user wants a smooth surrogate.
1500 ///
1501 /// **fwd_body**: `num_inputs` Op::Inputs in NodeId construction
1502 /// order, one Op::Output (the primal y). Forward execution
1503 /// inlines this body once.
1504 ///
1505 /// **vjp_body** (optional): Op::Inputs are `num_inputs` primal
1506 /// inputs in NodeId order, plus two special-named Inputs —
1507 /// `"primal_output"` (the y from forward) and `"d_output"` (the
1508 /// upstream gradient). Outputs: `num_inputs` tensors in
1509 /// `set_outputs` order, matching the gradients of each primal
1510 /// input. When `None`, reverse-mode AD recurses into fwd_body
1511 /// — same as if the op were inlined.
1512 ///
1513 /// **jvp_body** (optional): Op::Inputs are `num_inputs` primal
1514 /// inputs in NodeId order, `num_inputs` special-named Inputs
1515 /// `"tangent_0"..="tangent_{num_inputs-1}"` carrying each input's
1516 /// tangent, and an optional special-named `"primal_output"` Input
1517 /// (the y from forward, useful when the JVP must be evaluated at
1518 /// a converged / nonlinear point — e.g. IFT-style forward-mode AD
1519 /// of an iterative solver). Output: 1 tensor (the tangent of y).
1520 /// When `None`, forward-mode AD recurses into fwd_body.
1521 ///
1522 /// `num_inputs` is captured so [`Op::num_inputs`] stays
1523 /// infallible. Build via [`crate::Graph::custom_fn`].
1524 CustomFn {
1525 fwd_body: Box<crate::Graph>,
1526 vjp_body: Option<Box<crate::Graph>>,
1527 jvp_body: Option<Box<crate::Graph>>,
1528 num_inputs: u32,
1529 },
1530}
1531
1532impl Op {
1533 /// PLAN L4: discriminant for backend-supported-set checks.
1534 /// Stable, parameter-free identity per variant — `Op::Activation(_)`
1535 /// and `Op::Activation(Relu)` share the same `OpKind::Activation`.
1536 pub fn kind(&self) -> OpKind {
1537 match self {
1538 Op::Input { .. } => OpKind::Input,
1539 Op::Param { .. } => OpKind::Param,
1540 Op::Constant { .. } => OpKind::Constant,
1541 Op::Activation(_) => OpKind::Activation,
1542 Op::Cast { .. } => OpKind::Cast,
1543 Op::Quantize { .. } => OpKind::Quantize,
1544 Op::Dequantize { .. } => OpKind::Dequantize,
1545 Op::FakeQuantize { .. } => OpKind::FakeQuantize,
1546 Op::FakeQuantizeLSQ { .. } => OpKind::FakeQuantizeLSQ,
1547 Op::FakeQuantizeLSQBackwardX { .. } => OpKind::FakeQuantizeLSQBackwardX,
1548 Op::FakeQuantizeLSQBackwardScale { .. } => OpKind::FakeQuantizeLSQBackwardScale,
1549 Op::Binary(_) => OpKind::Binary,
1550 Op::Compare(_) => OpKind::Compare,
1551 Op::Where => OpKind::Where,
1552 Op::ElementwiseRegion { .. } => OpKind::ElementwiseRegion,
1553 Op::MatMul => OpKind::MatMul,
1554 Op::DotGeneral { .. } => OpKind::DotGeneral,
1555 Op::DenseSolve => OpKind::DenseSolve,
1556 Op::BatchedDenseSolve => OpKind::BatchedDenseSolve,
1557 Op::LayerNorm { .. } => OpKind::LayerNorm,
1558 Op::LayerNorm2d { .. } => OpKind::LayerNorm2d,
1559 Op::GroupNorm { .. } => OpKind::GroupNorm,
1560 Op::RmsNorm { .. } => OpKind::RmsNorm,
1561 Op::ResizeNearest2x => OpKind::ResizeNearest2x,
1562 Op::Attention { .. } => OpKind::Attention,
1563 Op::Rope { .. } => OpKind::Rope,
1564 Op::AxialRope2d { .. } => OpKind::AxialRope2d,
1565 Op::Reshape { .. } => OpKind::Reshape,
1566 Op::Transpose { .. } => OpKind::Transpose,
1567 Op::Narrow { .. } => OpKind::Narrow,
1568 Op::Concat { .. } => OpKind::Concat,
1569 Op::Expand { .. } => OpKind::Expand,
1570 Op::Gather { .. } => OpKind::Gather,
1571 Op::Reduce { .. } => OpKind::Reduce,
1572 Op::Softmax { .. } => OpKind::Softmax,
1573 Op::Cumsum { .. } => OpKind::Cumsum,
1574 Op::TopK { .. } => OpKind::TopK,
1575 Op::Sample { .. } => OpKind::Sample,
1576 Op::Conv { .. } => OpKind::Conv,
1577 Op::ConvTranspose2d { .. } => OpKind::ConvTranspose2d,
1578 Op::Pool { .. } => OpKind::Pool,
1579 Op::ReluBackward => OpKind::ReluBackward,
1580 Op::ActivationBackward { .. } => OpKind::ActivationBackward,
1581 Op::FakeQuantizeBackward { .. } => OpKind::FakeQuantizeBackward,
1582 Op::ComplexNormSq => OpKind::ComplexNormSq,
1583 Op::ComplexNormSqBackward => OpKind::ComplexNormSqBackward,
1584 Op::Conjugate => OpKind::Conjugate,
1585 Op::LayerNormBackwardInput { .. } => OpKind::LayerNormBackwardInput,
1586 Op::LayerNormBackwardGamma { .. } => OpKind::LayerNormBackwardGamma,
1587 Op::RmsNormBackwardInput { .. } => OpKind::RmsNormBackwardInput,
1588 Op::RmsNormBackwardGamma { .. } => OpKind::RmsNormBackwardGamma,
1589 Op::RmsNormBackwardBeta { .. } => OpKind::RmsNormBackwardBeta,
1590 Op::RopeBackward { .. } => OpKind::RopeBackward,
1591 Op::GroupNormBackwardInput { .. } => OpKind::GroupNormBackwardInput,
1592 Op::GroupNormBackwardGamma { .. } => OpKind::GroupNormBackwardGamma,
1593 Op::GroupNormBackwardBeta { .. } => OpKind::GroupNormBackwardBeta,
1594 Op::CumsumBackward { .. } => OpKind::CumsumBackward,
1595 Op::GatherBackward { .. } => OpKind::GatherBackward,
1596 Op::MaxPool2dBackward { .. } => OpKind::MaxPool2dBackward,
1597 Op::Conv2dBackwardInput { .. } => OpKind::Conv2dBackwardInput,
1598 Op::Conv2dBackwardWeight { .. } => OpKind::Conv2dBackwardWeight,
1599 Op::SoftmaxCrossEntropyWithLogits => OpKind::SoftmaxCrossEntropyWithLogits,
1600 Op::SoftmaxCrossEntropyBackward => OpKind::SoftmaxCrossEntropyBackward,
1601 Op::AttentionBackward { .. } => OpKind::AttentionBackward,
1602 Op::GroupedMatMul => OpKind::GroupedMatMul,
1603 Op::DequantGroupedMatMul { .. } => OpKind::DequantGroupedMatMul,
1604 Op::DequantMoEWeights { .. } => OpKind::DequantMoEWeights,
1605 Op::ScatterAdd => OpKind::ScatterAdd,
1606 Op::LoraMatMul { .. } => OpKind::LoraMatMul,
1607 Op::DequantMatMul { .. } => OpKind::DequantMatMul,
1608 Op::QMatMul { .. } => OpKind::QMatMul,
1609 Op::QConv2d { .. } => OpKind::QConv2d,
1610 Op::SelectiveScan { .. } => OpKind::SelectiveScan,
1611 Op::GatedDeltaNet { .. } => OpKind::GatedDeltaNet,
1612 Op::FusedSwiGLU { .. } => OpKind::FusedSwiGLU,
1613 Op::FusedMatMulBiasAct { .. } => OpKind::FusedMatMulBiasAct,
1614 Op::FusedResidualLN { .. } => OpKind::FusedResidualLN,
1615 Op::FusedResidualRmsNorm { .. } => OpKind::FusedResidualRmsNorm,
1616 Op::FusedAttentionBlock { .. } => OpKind::FusedAttentionBlock,
1617 Op::FusedTransformerLayer { .. } => OpKind::FusedTransformerLayer,
1618 Op::If { .. } => OpKind::If,
1619 Op::While { .. } => OpKind::While,
1620 Op::Scan { .. } => OpKind::Scan,
1621 Op::ScanBackward { .. } => OpKind::ScanBackward,
1622 Op::ScanBackwardXs { .. } => OpKind::ScanBackwardXs,
1623 Op::GaussianSplatRender { .. } => OpKind::GaussianSplatRender,
1624 Op::GaussianSplatRenderBackward { .. } => OpKind::GaussianSplatRenderBackward,
1625 Op::GaussianSplatPrepare { .. } => OpKind::GaussianSplatPrepare,
1626 Op::GaussianSplatRasterize { .. } => OpKind::GaussianSplatRasterize,
1627 Op::Custom { .. } => OpKind::Custom,
1628 Op::CustomFn { .. } => OpKind::CustomFn,
1629 Op::Fft { .. } => OpKind::Fft,
1630 }
1631 }
1632
1633 /// True if this op is element-wise (same shape in, same shape out).
1634 /// Element-wise ops are prime fusion candidates.
1635 pub fn is_elementwise(&self) -> bool {
1636 matches!(
1637 self,
1638 Op::Activation(_)
1639 | Op::Cast { .. }
1640 | Op::Binary(_)
1641 | Op::Compare(_)
1642 | Op::Where
1643 | Op::ElementwiseRegion { .. }
1644 )
1645 }
1646
1647 /// True if this op is a BLAS/compute-intensive op that forms a fusion boundary.
1648 pub fn is_blas(&self) -> bool {
1649 matches!(
1650 self,
1651 Op::MatMul
1652 | Op::DotGeneral { .. }
1653 | Op::DenseSolve
1654 | Op::BatchedDenseSolve
1655 | Op::Conv { .. }
1656 | Op::ConvTranspose2d { .. }
1657 | Op::FusedMatMulBiasAct { .. }
1658 | Op::GroupedMatMul
1659 | Op::DequantGroupedMatMul { .. }
1660 | Op::DequantMoEWeights { .. }
1661 | Op::LoraMatMul { .. }
1662 | Op::DequantMatMul { .. }
1663 | Op::QMatMul { .. }
1664 | Op::QConv2d { .. }
1665 )
1666 }
1667
1668 /// True if element-wise fusion must not span across this op.
1669 pub fn is_fusion_boundary(&self) -> bool {
1670 self.is_blas()
1671 || matches!(
1672 self,
1673 Op::GaussianSplatRender { .. }
1674 | Op::GaussianSplatRenderBackward { .. }
1675 | Op::GaussianSplatPrepare { .. }
1676 | Op::GaussianSplatRasterize { .. }
1677 )
1678 }
1679
1680 /// True if this op is a reduction (drives loop iteration in fused kernels).
1681 pub fn is_reduction(&self) -> bool {
1682 matches!(
1683 self,
1684 Op::Reduce { .. } | Op::Softmax { .. } | Op::TopK { .. }
1685 )
1686 }
1687
1688 /// Number of tensor inputs this op expects.
1689 pub fn num_inputs(&self) -> usize {
1690 match self {
1691 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0,
1692 Op::Activation(_)
1693 | Op::Cast { .. }
1694 | Op::Reshape { .. }
1695 | Op::Quantize { .. }
1696 | Op::Dequantize { .. }
1697 | Op::Transpose { .. }
1698 | Op::Narrow { .. }
1699 | Op::Expand { .. }
1700 | Op::Reduce { .. }
1701 | Op::Softmax { .. }
1702 | Op::FusedSwiGLU { .. }
1703 | Op::TopK { .. }
1704 | Op::Cumsum { .. }
1705 | Op::Sample { .. }
1706 | Op::ResizeNearest2x => 1,
1707 // EMA / Fixed scale modes carry a state tensor as a 2nd input;
1708 // PerBatch (default) doesn't need one.
1709 Op::FakeQuantize { scale_mode, .. } => match scale_mode {
1710 ScaleMode::PerBatch => 1,
1711 ScaleMode::EMA { .. } | ScaleMode::Fixed => 2,
1712 },
1713 Op::FakeQuantizeLSQ { .. } => 2, // x, scale (learned param)
1714 Op::FakeQuantizeLSQBackwardX { .. } | Op::FakeQuantizeLSQBackwardScale { .. } => 3, // x, scale, dy
1715 Op::Binary(_) | Op::Compare(_) | Op::Gather { .. } | Op::MatMul | Op::ScatterAdd => 2,
1716 Op::GroupedMatMul => 3, // input, weight, expert_idx
1717 Op::DequantGroupedMatMul { .. } => 3, // input, packed_w, expert_idx
1718 Op::DequantMoEWeights { .. } => 1, // packed_w
1719 Op::LoraMatMul { .. } => 4, // x, w, a, b
1720 // x, w_q, scale, zp — or x, packed_w_bytes for GGUF
1721 // schemes (their scales/mins live inside the packed bytes,
1722 // see `QuantScheme::is_gguf`).
1723 Op::DequantMatMul { scheme } => {
1724 if scheme.is_gguf() {
1725 2
1726 } else {
1727 4
1728 }
1729 }
1730 Op::QMatMul { .. } => 3, // x, w, bias
1731 Op::QConv2d { .. } => 3, // x, w, bias
1732 Op::SelectiveScan { .. } => 5, // x, delta, a, b, c
1733 Op::GatedDeltaNet { carry_state, .. } if *carry_state => 6, // + state in/out
1734 Op::GatedDeltaNet { .. } => 5, // q, k, v, g, beta
1735 Op::Where => 3, // cond, on_true, on_false
1736 Op::Attention { mask_kind, .. } => match mask_kind {
1737 MaskKind::Custom | MaskKind::Bias => 4, // Q, K, V, mask
1738 _ => 3, // Q, K, V (mask synthesized in-kernel)
1739 },
1740 Op::AttentionBackward { mask_kind, .. } => match mask_kind {
1741 MaskKind::Custom | MaskKind::Bias => 5, // q, k, v, dy, mask
1742 _ => 4, // q, k, v, dy
1743 },
1744 Op::Rope { .. } => 3, // x, cos, sin
1745 Op::AxialRope2d { .. } => 1,
1746 Op::LayerNorm { .. }
1747 | Op::LayerNorm2d { .. }
1748 | Op::GroupNorm { .. }
1749 | Op::RmsNorm { .. } => 3, // input, gamma, beta
1750 Op::FusedMatMulBiasAct { .. } => 3, // input, weight, bias
1751 Op::FusedResidualLN { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
1752 Op::FusedResidualLN {
1753 has_bias: false, ..
1754 } => 4, // x, residual, gamma, beta
1755 Op::FusedResidualRmsNorm { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
1756 Op::FusedResidualRmsNorm {
1757 has_bias: false, ..
1758 } => 4, // x, residual, gamma, beta
1759 Op::Conv { .. } | Op::ConvTranspose2d { .. } => 2, // input, weight (bias via Add)
1760 Op::Pool { .. } => 1,
1761 Op::ReluBackward => 2, // x, dy
1762 Op::ActivationBackward { .. } => 2, // x, dy
1763 Op::FakeQuantizeBackward { .. } => 2, // x, dy
1764 Op::ComplexNormSq => 1, // z (C64)
1765 Op::ComplexNormSqBackward => 2, // z, g
1766 Op::Conjugate => 1, // z (C64)
1767 Op::LayerNormBackwardInput { .. } => 3, // x, gamma, dy
1768 Op::LayerNormBackwardGamma { .. } => 2, // x, dy
1769 Op::RmsNormBackwardInput { .. } => 4, // x, gamma, beta, dy
1770 Op::RmsNormBackwardGamma { .. } => 4,
1771 Op::RmsNormBackwardBeta { .. } => 4,
1772 Op::RopeBackward { .. } => 3, // dy, cos, sin
1773 Op::GroupNormBackwardInput { .. } => 4, // x, gamma, beta, dy
1774 Op::GroupNormBackwardGamma { .. } => 2, // x, dy
1775 Op::GroupNormBackwardBeta { .. } => 2,
1776 Op::CumsumBackward { .. } => 1, // dy
1777 Op::GatherBackward { .. } => 2, // dy, indices
1778 Op::MaxPool2dBackward { .. } => 2, // x, dy
1779 Op::Conv2dBackwardInput { .. } => 2, // dy, w
1780 Op::Conv2dBackwardWeight { .. } => 2, // x, dy
1781 Op::SoftmaxCrossEntropyWithLogits => 2, // logits, labels
1782 Op::SoftmaxCrossEntropyBackward => 3, // logits, labels, d_loss
1783 Op::Concat { .. } => 0, // variadic — checked at graph level
1784 Op::DotGeneral { .. } => 2,
1785 Op::DenseSolve => 2, // A, b
1786 Op::BatchedDenseSolve => 2, // A [B,N,N], b [B,N] or [B,N,K]
1787 Op::FusedAttentionBlock {
1788 has_bias, has_rope, ..
1789 } => 4 + if *has_bias { 2 } else { 0 } + if *has_rope { 2 } else { 0 },
1790 Op::If { .. } => 1, // predicate (captures handled separately)
1791 Op::While { .. } => 0, // variadic loop-carried; checked at graph level
1792 Op::Scan {
1793 num_bcast, num_xs, ..
1794 } => 1 + *num_bcast as usize + *num_xs as usize,
1795 Op::ScanBackward { num_xs, .. } => 3 + *num_xs as usize, // init, trajectory, upstream, xs_0..
1796 Op::ScanBackwardXs { num_xs, .. } => 3 + *num_xs as usize, // same as ScanBackward
1797 Op::GaussianSplatRender { .. } => 7,
1798 Op::GaussianSplatRenderBackward { .. } => 8,
1799 Op::GaussianSplatPrepare { .. } => 7,
1800 Op::GaussianSplatRasterize { .. } => 2,
1801 Op::FusedTransformerLayer { has_bias, .. } => {
1802 // hidden + qkv_w + out_w + ln1_g + ln1_b + fc1_w + fc2_w + ln2_g + ln2_b + mask = 10
1803 // bias variant adds: qkv_b + out_b + fc1_b + fc2_b = 4 more
1804 10 + if *has_bias { 4 } else { 0 }
1805 }
1806 Op::ElementwiseRegion { num_inputs, .. } => *num_inputs as usize,
1807 Op::Custom { num_inputs, .. } => *num_inputs as usize,
1808 Op::CustomFn { num_inputs, .. } => *num_inputs as usize,
1809 Op::Fft { .. } => 1,
1810 }
1811 }
1812}
1813
1814impl std::fmt::Display for Op {
1815 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1816 match self {
1817 Op::Input { name } => write!(f, "input(\"{name}\")"),
1818 Op::Param { name } => write!(f, "param(\"{name}\")"),
1819 Op::Constant { data } => write!(f, "const({}B)", data.len()),
1820 Op::Activation(a) => write!(f, "{a:?}"),
1821 Op::Quantize { axis, scales, .. } => match axis {
1822 None => write!(f, "quantize(s={})", scales[0]),
1823 Some(d) => write!(f, "quantize(axis={d},nch={})", scales.len()),
1824 },
1825 Op::Dequantize { axis, scales, .. } => match axis {
1826 None => write!(f, "dequantize(s={})", scales[0]),
1827 Some(d) => write!(f, "dequantize(axis={d},nch={})", scales.len()),
1828 },
1829 Op::FakeQuantize {
1830 bits,
1831 axis,
1832 ste,
1833 scale_mode,
1834 } => match axis {
1835 None => write!(
1836 f,
1837 "fake_quant(bits={bits},ste={ste:?},scale={scale_mode:?})"
1838 ),
1839 Some(d) => write!(
1840 f,
1841 "fake_quant(bits={bits},axis={d},ste={ste:?},scale={scale_mode:?})"
1842 ),
1843 },
1844 Op::FakeQuantizeLSQ { bits, axis } => match axis {
1845 None => write!(f, "fake_quant_lsq(bits={bits})"),
1846 Some(d) => write!(f, "fake_quant_lsq(bits={bits},axis={d})"),
1847 },
1848 Op::FakeQuantizeLSQBackwardX { bits, .. } => {
1849 write!(f, "fake_quant_lsq_bwd_x(bits={bits})")
1850 }
1851 Op::FakeQuantizeLSQBackwardScale { bits, .. } => {
1852 write!(f, "fake_quant_lsq_bwd_s(bits={bits})")
1853 }
1854 Op::Cast { to } => write!(f, "cast({to})"),
1855 Op::Binary(op) => write!(f, "{op:?}"),
1856 Op::Compare(op) => write!(f, "{op:?}"),
1857 Op::Where => write!(f, "where"),
1858 Op::MatMul => write!(f, "matmul"),
1859 Op::DotGeneral { .. } => write!(f, "dot_general"),
1860 Op::DenseSolve => write!(f, "dense_solve"),
1861 Op::BatchedDenseSolve => write!(f, "batched_dense_solve"),
1862 Op::LayerNorm { eps, .. } => write!(f, "layer_norm(eps={eps})"),
1863 Op::GroupNorm { num_groups, eps } => {
1864 write!(f, "group_norm(groups={num_groups},eps={eps})")
1865 }
1866 Op::ResizeNearest2x => write!(f, "resize_nearest_2x"),
1867 Op::RmsNorm { eps, .. } => write!(f, "rms_norm(eps={eps})"),
1868 Op::Attention {
1869 num_heads,
1870 head_dim,
1871 mask_kind,
1872 score_scale,
1873 attn_logit_softcap,
1874 } => {
1875 let mut s = match mask_kind {
1876 MaskKind::Custom => format!("attention(h={num_heads},d={head_dim})"),
1877 MaskKind::None => format!("attention(h={num_heads},d={head_dim},nomask)"),
1878 MaskKind::Causal => format!("attention(h={num_heads},d={head_dim},causal)"),
1879 MaskKind::SlidingWindow(w) => {
1880 format!("attention(h={num_heads},d={head_dim},sw={w})")
1881 }
1882 MaskKind::Bias => format!("attention(h={num_heads},d={head_dim},bias)"),
1883 };
1884 if let Some(sc) = score_scale {
1885 s.push_str(&format!(",scale={sc}"));
1886 }
1887 if let Some(cap) = attn_logit_softcap {
1888 s.push_str(&format!(",softcap={cap}"));
1889 }
1890 write!(f, "{s}")
1891 }
1892 Op::Rope { head_dim, n_rot } => write!(f, "rope(d={head_dim}, n_rot={n_rot})"),
1893 Op::AxialRope2d {
1894 end_x,
1895 end_y,
1896 head_dim,
1897 num_heads,
1898 theta,
1899 repeat_factor,
1900 } => write!(
1901 f,
1902 "axial_rope2d({end_x}x{end_y},h={num_heads},d={head_dim},θ={theta},r={repeat_factor})"
1903 ),
1904 Op::Reshape { new_shape } => write!(f, "reshape({new_shape:?})"),
1905 Op::Transpose { perm } => write!(f, "transpose({perm:?})"),
1906 Op::Narrow { axis, start, len } => write!(f, "narrow({axis},{start},{len})"),
1907 Op::Concat { axis } => write!(f, "concat(axis={axis})"),
1908 Op::Expand { .. } => write!(f, "expand"),
1909 Op::Gather { axis } => write!(f, "gather(axis={axis})"),
1910 Op::Reduce { op, axes, .. } => write!(f, "reduce_{op:?}({axes:?})"),
1911 Op::Softmax { axis } => write!(f, "softmax(axis={axis})"),
1912 Op::Cumsum { axis, exclusive } => {
1913 if *exclusive {
1914 write!(f, "cumsum(axis={axis},excl)")
1915 } else {
1916 write!(f, "cumsum(axis={axis})")
1917 }
1918 }
1919 Op::Sample {
1920 top_k,
1921 top_p,
1922 temperature,
1923 ..
1924 } => {
1925 write!(f, "sample(t={temperature}")?;
1926 if *top_k > 0 {
1927 write!(f, ",k={top_k}")?;
1928 }
1929 if *top_p < 1.0 {
1930 write!(f, ",p={top_p}")?;
1931 }
1932 write!(f, ")")
1933 }
1934 Op::TopK { k } => write!(f, "topk(k={k})"),
1935 Op::GroupedMatMul => write!(f, "grouped_matmul"),
1936 Op::DequantGroupedMatMul { scheme } => {
1937 write!(f, "dequant_grouped_matmul({scheme})")
1938 }
1939 Op::DequantMoEWeights { scheme } => write!(f, "dequant_moe_weights({scheme})"),
1940 Op::LoraMatMul { scale } => write!(f, "lora_matmul(scale={scale})"),
1941 Op::DequantMatMul { scheme } => write!(f, "dequant_matmul({scheme})"),
1942 Op::QMatMul {
1943 x_zp,
1944 w_zp,
1945 out_zp,
1946 mult,
1947 } => write!(
1948 f,
1949 "q_matmul(x_zp={x_zp},w_zp={w_zp},out_zp={out_zp},mult={mult})"
1950 ),
1951 Op::QConv2d { kernel_size, .. } => write!(f, "q_conv2d({kernel_size:?})"),
1952 Op::SelectiveScan { state_size } => write!(f, "ssm_scan(n={state_size})"),
1953 Op::GatedDeltaNet {
1954 state_size,
1955 carry_state,
1956 } => {
1957 if *carry_state {
1958 write!(f, "gated_delta_net(n={state_size},carry)")
1959 } else {
1960 write!(f, "gated_delta_net(n={state_size})")
1961 }
1962 }
1963 Op::ScatterAdd => write!(f, "scatter_add"),
1964 Op::Conv { kernel_size, .. } => write!(f, "conv2d({kernel_size:?})"),
1965 Op::ConvTranspose2d { kernel_size, .. } => {
1966 write!(f, "conv_transpose2d({kernel_size:?})")
1967 }
1968 Op::LayerNorm2d { eps } => write!(f, "layer_norm2d(eps={eps})"),
1969 Op::Pool {
1970 kind, kernel_size, ..
1971 } => write!(f, "pool_{kind:?}({kernel_size:?})"),
1972 Op::ReluBackward => write!(f, "relu_backward"),
1973 Op::ActivationBackward { kind } => write!(f, "{kind:?}_backward"),
1974 Op::ComplexNormSq => write!(f, "complex_norm_sq"),
1975 Op::ComplexNormSqBackward => write!(f, "complex_norm_sq_backward"),
1976 Op::Conjugate => write!(f, "conjugate"),
1977 Op::FakeQuantizeBackward { bits, ste, .. } => {
1978 write!(f, "fake_quant_backward(bits={bits},ste={ste:?})")
1979 }
1980 Op::MaxPool2dBackward { kernel_size, .. } => {
1981 write!(f, "maxpool2d_backward({kernel_size:?})")
1982 }
1983 Op::Conv2dBackwardInput { kernel_size, .. } => {
1984 write!(f, "conv2d_backward_input({kernel_size:?})")
1985 }
1986 Op::Conv2dBackwardWeight { kernel_size, .. } => {
1987 write!(f, "conv2d_backward_weight({kernel_size:?})")
1988 }
1989 Op::SoftmaxCrossEntropyWithLogits => write!(f, "sce_with_logits"),
1990 Op::SoftmaxCrossEntropyBackward => write!(f, "sce_backward"),
1991 Op::AttentionBackward {
1992 num_heads,
1993 head_dim,
1994 mask_kind,
1995 wrt,
1996 } => match mask_kind {
1997 MaskKind::None => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},nomask)"),
1998 MaskKind::Causal => {
1999 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},causal)")
2000 }
2001 MaskKind::SlidingWindow(w) => {
2002 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},sw={w})")
2003 }
2004 MaskKind::Custom => {
2005 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},custom)")
2006 }
2007 MaskKind::Bias => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},bias)"),
2008 },
2009 Op::FusedMatMulBiasAct { activation } => {
2010 write!(f, "fused_mm_bias")?;
2011 if let Some(a) = activation {
2012 write!(f, "_{a:?}")?;
2013 }
2014 Ok(())
2015 }
2016 Op::FusedResidualLN { has_bias, eps } => {
2017 write!(f, "fused_residual")?;
2018 if *has_bias {
2019 write!(f, "_bias")?;
2020 }
2021 write!(f, "_ln(eps={eps})")
2022 }
2023 Op::FusedResidualRmsNorm { has_bias, eps } => {
2024 write!(f, "fused_residual")?;
2025 if *has_bias {
2026 write!(f, "_bias")?;
2027 }
2028 write!(f, "_rms(eps={eps})")
2029 }
2030 Op::FusedSwiGLU {
2031 cast_to,
2032 gate_first,
2033 } => {
2034 let mut s = match cast_to {
2035 Some(dt) => format!("fused_swiglu(cast={dt}"),
2036 None => "fused_swiglu(".to_string(),
2037 };
2038 if *gate_first {
2039 s.push_str(",gate_first");
2040 }
2041 s.push(')');
2042 write!(f, "{s}")
2043 }
2044 Op::FusedAttentionBlock {
2045 num_heads,
2046 head_dim,
2047 has_bias,
2048 has_rope,
2049 } => {
2050 write!(f, "fused_attn(h={num_heads},d={head_dim}")?;
2051 if *has_bias {
2052 write!(f, ",bias")?;
2053 }
2054 if *has_rope {
2055 write!(f, ",rope")?;
2056 }
2057 write!(f, ")")
2058 }
2059 Op::If { .. } => write!(f, "if(...)"),
2060 Op::While { max_iterations, .. } => match max_iterations {
2061 Some(n) => write!(f, "while(...max={n})"),
2062 None => write!(f, "while(...)"),
2063 },
2064 Op::Scan {
2065 length,
2066 save_trajectory,
2067 num_xs,
2068 ..
2069 } => {
2070 let traj = if *save_trajectory { ",traj" } else { "" };
2071 let xs = if *num_xs > 0 {
2072 format!(",xs={}", num_xs)
2073 } else {
2074 String::new()
2075 };
2076 write!(f, "scan(len={length}{xs}{traj})")
2077 }
2078 Op::ScanBackward {
2079 length,
2080 save_trajectory,
2081 num_xs,
2082 ..
2083 } => {
2084 let traj = if *save_trajectory { ",traj" } else { "" };
2085 let xs = if *num_xs > 0 {
2086 format!(",xs={}", num_xs)
2087 } else {
2088 String::new()
2089 };
2090 write!(f, "scan_bwd(len={length}{xs}{traj})")
2091 }
2092 Op::ScanBackwardXs {
2093 length,
2094 save_trajectory,
2095 num_xs,
2096 xs_idx,
2097 ..
2098 } => {
2099 let traj = if *save_trajectory { ",traj" } else { "" };
2100 write!(
2101 f,
2102 "scan_bwd_xs(len={length},xs={num_xs},idx={xs_idx}{traj})"
2103 )
2104 }
2105 Op::FusedTransformerLayer {
2106 num_heads,
2107 head_dim,
2108 intermediate_size,
2109 has_bias,
2110 ..
2111 } => {
2112 write!(
2113 f,
2114 "fused_layer(h={num_heads},d={head_dim},int={intermediate_size}"
2115 )?;
2116 if *has_bias {
2117 write!(f, ",bias")?;
2118 }
2119 write!(f, ")")
2120 }
2121 Op::ElementwiseRegion {
2122 chain,
2123 num_inputs,
2124 scalar_input_mask,
2125 input_modulus: _,
2126 } => {
2127 if *scalar_input_mask != 0 {
2128 write!(
2129 f,
2130 "ew_region(in={num_inputs},steps={},scalar_mask=0x{:x})",
2131 chain.len(),
2132 scalar_input_mask
2133 )
2134 } else {
2135 write!(f, "ew_region(in={num_inputs},steps={})", chain.len())
2136 }
2137 }
2138 Op::LayerNormBackwardInput { eps, .. } => {
2139 write!(f, "layer_norm_backward_input(eps={eps})")
2140 }
2141 Op::LayerNormBackwardGamma { eps, .. } => {
2142 write!(f, "layer_norm_backward_gamma(eps={eps})")
2143 }
2144 Op::RmsNormBackwardInput { eps, .. } => write!(f, "rms_norm_backward_input(eps={eps})"),
2145 Op::RmsNormBackwardGamma { eps, .. } => write!(f, "rms_norm_backward_gamma(eps={eps})"),
2146 Op::RmsNormBackwardBeta { eps, .. } => write!(f, "rms_norm_backward_beta(eps={eps})"),
2147 Op::RopeBackward { head_dim, n_rot } => {
2148 write!(f, "rope_backward(d={head_dim},n_rot={n_rot})")
2149 }
2150 Op::GroupNormBackwardInput { num_groups, eps } => {
2151 write!(f, "group_norm_backward_input(g={num_groups},eps={eps})")
2152 }
2153 Op::GroupNormBackwardGamma { num_groups, eps } => {
2154 write!(f, "group_norm_backward_gamma(g={num_groups},eps={eps})")
2155 }
2156 Op::GroupNormBackwardBeta { num_groups, eps } => {
2157 write!(f, "group_norm_backward_beta(g={num_groups},eps={eps})")
2158 }
2159 Op::CumsumBackward { axis, exclusive } => {
2160 write!(f, "cumsum_backward(axis={axis},exclusive={exclusive})")
2161 }
2162 Op::GatherBackward { axis } => write!(f, "gather_backward(axis={axis})"),
2163 Op::GaussianSplatRender {
2164 width,
2165 height,
2166 tile_size,
2167 radius_scale,
2168 alpha_cutoff,
2169 max_splat_steps,
2170 transmittance_threshold,
2171 max_list_entries,
2172 } => write!(
2173 f,
2174 "gaussian_splat_render({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2175 ),
2176 Op::GaussianSplatRenderBackward {
2177 width,
2178 height,
2179 loss_grad_clip,
2180 sh_band,
2181 ..
2182 } => write!(
2183 f,
2184 "gaussian_splat_render_bwd({width}x{height},clip={loss_grad_clip},sh={sh_band})"
2185 ),
2186 Op::GaussianSplatPrepare {
2187 width,
2188 height,
2189 tile_size,
2190 radius_scale,
2191 alpha_cutoff,
2192 max_splat_steps,
2193 transmittance_threshold,
2194 max_list_entries,
2195 ..
2196 } => write!(
2197 f,
2198 "gaussian_splat_prepare({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2199 ),
2200 Op::GaussianSplatRasterize {
2201 width,
2202 height,
2203 tile_size,
2204 alpha_cutoff,
2205 max_splat_steps,
2206 transmittance_threshold,
2207 max_list_entries,
2208 ..
2209 } => write!(
2210 f,
2211 "gaussian_splat_rasterize({width}x{height},tile={tile_size},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2212 ),
2213 Op::Custom {
2214 name,
2215 num_inputs,
2216 attrs,
2217 } => write!(f, "custom({name},in={num_inputs},attrs={}B)", attrs.len()),
2218 Op::CustomFn {
2219 num_inputs,
2220 vjp_body,
2221 jvp_body,
2222 ..
2223 } => {
2224 let v = if vjp_body.is_some() { ",vjp" } else { "" };
2225 let j = if jvp_body.is_some() { ",jvp" } else { "" };
2226 write!(f, "custom_fn(in={num_inputs}{v}{j})")
2227 }
2228 Op::Fft { inverse, norm } => {
2229 write!(f, "fft(inverse={inverse}, norm={norm:?})")
2230 }
2231 }
2232 }
2233}