rlx_ir/op.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Operation types — every tensor op in the RLX IR.
17//!
18//! Designed for pattern-matching fusion: ops are grouped by category so
19//! fusion passes can reason about them structurally.
20
21use crate::DType;
22
23/// Unary element-wise activation functions.
24#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
26pub enum Activation {
27 Gelu,
28 GeluApprox,
29 Silu, // SwiGLU gate activation
30 Relu,
31 Sigmoid,
32 Tanh,
33 Exp,
34 Log,
35 Sqrt,
36 Rsqrt,
37 Neg,
38 Abs,
39 /// `sin(x)`. Backward: `dx = upstream · cos(x)`.
40 Sin,
41 /// `cos(x)`. Backward: `dx = -upstream · sin(x)`.
42 Cos,
43 /// `tan(x)`. Backward: `dx = upstream · sec²(x) = upstream · (1 + tan²(x))`.
44 Tan,
45 /// `atan(x)`. Backward: `dx = upstream · (1 / (1 + x²))`.
46 Atan,
47 /// Round to nearest integer (half-to-even), in f32.
48 /// Forward: `x.round()`. Backward: STE — treats as identity, so
49 /// the gradient passes through unchanged. Useful as a primitive
50 /// for composing custom quantization schemes (Mul-by-recip-scale
51 /// → Round → Clamp → Mul-by-scale = a hand-rolled FakeQuantize
52 /// that the elementwise-region pass can fuse into a single kernel).
53 Round,
54}
55
56/// Scale-tracking strategy for `Op::FakeQuantize`. Determines how
57/// the per-channel `s[c]` is computed each forward pass.
58///
59/// * `PerBatch` — recompute `s[c] = max(|x|) / q_max` from the
60/// current data on every call. Simple, no extra inputs, but
61/// noisy for activations (max-abs jumps batch-to-batch).
62///
63/// * `EMA { decay }` — keep a running `s[c]` in a state tensor
64/// (passed as a second op input). On each call, blend the
65/// current per-batch max-abs into the state via
66/// `state' = decay·state + (1-decay)·max_abs`. Smooth scale
67/// over training, makes activation-QAT actually trainable.
68/// Typical `decay = 0.99`.
69///
70/// * `Fixed` — never recompute. The state tensor's value is
71/// used as-is each call (set once at construction or by the
72/// caller). Useful when scales are pre-calibrated.
73#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
74#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Default)]
75pub enum ScaleMode {
76 #[default]
77 PerBatch,
78 EMA {
79 decay: f32,
80 },
81 Fixed,
82}
83
84impl Eq for ScaleMode {}
85impl std::hash::Hash for ScaleMode {
86 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
87 match self {
88 ScaleMode::PerBatch => state.write_u8(0),
89 ScaleMode::EMA { decay } => {
90 state.write_u8(1);
91 state.write_u32(decay.to_bits());
92 }
93 ScaleMode::Fixed => state.write_u8(2),
94 }
95 }
96}
97
98/// Straight-through estimator variants for `Op::FakeQuantize`'s
99/// backward. The forward is the same regardless: discrete
100/// `clamp(round(x/s)) * s`. The choice here affects only the
101/// gradient w.r.t. `x` during training.
102///
103/// * `Identity` — `dx = upstream`. The original STE; treats the
104/// round as identity in the backward direction. Simplest, fine
105/// for moderate bit widths (i4 / i8).
106///
107/// * `ClippedIdentity` — `dx = upstream * (|x| ≤ q_max·s)`. Zero
108/// the gradient when the input was outside the quantization
109/// range (i.e. the clamp activated). Stops the optimizer from
110/// pushing weights further into saturation.
111///
112/// * `Tanh` — `dx = upstream * (1 - tanh²(x/s))`. Smooth surrogate
113/// for the round step. Slowly attenuates the gradient as `|x|`
114/// approaches `q_max·s`. Often best on tight bit widths (i2).
115///
116/// * `HardTanh` — `dx = upstream * (1 - |x/(q_max·s)|).max(0)`.
117/// Piecewise-linear cousin of `Tanh`; cheaper to compute and
118/// nearly as effective.
119#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
120#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
121pub enum SteKind {
122 #[default]
123 Identity,
124 ClippedIdentity,
125 Tanh,
126 HardTanh,
127}
128
129/// Binary element-wise operations.
130#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
131#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
132pub enum BinaryOp {
133 Add,
134 Sub,
135 Mul,
136 Div,
137 Max,
138 Min,
139 Pow,
140}
141
142/// Comparison operations (return Bool tensor).
143#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
144#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
145pub enum CmpOp {
146 Eq,
147 Ne,
148 Lt,
149 Le,
150 Gt,
151 Ge,
152}
153
154/// What kind of attention mask the kernel should apply.
155///
156/// Borrowed from MAX's `nn/attention/mha_mask.mojo` pattern (#20 in
157/// PLAN.md): one attention kernel handles all variants by branching on
158/// the mask kind, instead of forcing every caller to materialize a mask
159/// tensor. The win is two-fold:
160/// 1. **`None`** — single unpadded sequence: no mask load, no per-key
161/// compare in the inner loop.
162/// 2. **`Causal`** — autoregressive decode: kernel generates the upper-
163/// triangular fill from `(qi, ki)` directly; no `seq²` mask tensor
164/// ever exists.
165///
166/// `Custom` is the existing path — read mask values from the 4th input.
167#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
168#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
169pub enum MaskKind {
170 /// No masking — every position attends to every position.
171 None,
172 /// Causal (autoregressive) — position `qi` attends only to `ki <= qi`.
173 Causal,
174 /// Sliding window — position `qi` attends to `ki ∈ [qi - w, qi]`.
175 SlidingWindow(usize),
176 /// Read mask values from the input tensor (default; matches BERT
177 /// padding-mask behavior). Tensor shape `[batch, key_len]` with
178 /// `1.0` = valid, `<0.5` = ignored.
179 Custom,
180 /// Additive per-head, per-query bias tensor
181 /// `[batch, num_heads, query_len, key_len]` added to the
182 /// `QK^T · scale` scores before softmax. Lets DETR-style boxRPB
183 /// and other learned position biases reuse the fast `Op::Attention`
184 /// path instead of decomposing into matmul + add + softmax + matmul.
185 Bias,
186}
187
188/// Which forward input an [`Op::AttentionBackward`] node differentiates.
189#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
190#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
191pub enum AttentionBwdWrt {
192 Query,
193 Key,
194 Value,
195}
196
197/// Reduction operations along specified axes.
198#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
199#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
200pub enum ReduceOp {
201 Sum,
202 Mean,
203 Max,
204 Min,
205 Prod,
206}
207
208/// PLAN L4: discriminant for each [`Op`] variant. Used by
209/// [`Op::kind`] + the `Backend::supported_ops` trait method to declare
210/// which ops a backend can lower; the `LegalizeForBackend` pass in
211/// `rlx-opt` checks the graph against this set and fails the compile
212/// when an unsupported op is present (instead of silent fallback).
213#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
214#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
215pub enum OpKind {
216 Input,
217 Param,
218 Constant,
219 Activation,
220 Cast,
221 StopGradient,
222 Quantize,
223 Dequantize,
224 FakeQuantize,
225 FakeQuantizeLSQ,
226 FakeQuantizeLSQBackwardX,
227 FakeQuantizeLSQBackwardScale,
228 Binary,
229 Compare,
230 Where,
231 ElementwiseRegion,
232 /// Fused sampling / geometry chain (FKL-style transform region).
233 TransformRegion,
234 /// Same element-wise chain over multiple batch planes (horizontal fusion).
235 BatchElementwiseRegion,
236 MatMul,
237 DotGeneral,
238 DenseSolve,
239 BatchedDenseSolve,
240 LayerNorm,
241 LayerNorm2d,
242 GroupNorm,
243 BatchNormInference,
244 RmsNorm,
245 ResizeNearest2x,
246 Attention,
247 Rope,
248 AxialRope2d,
249 Reshape,
250 Transpose,
251 Narrow,
252 Concat,
253 Expand,
254 Gather,
255 Reduce,
256 Softmax,
257 Cumsum,
258 ArgMax,
259 ArgMin,
260 TopK,
261 Sample,
262 /// ONNX `RandomNormalLike` — shape from input 0, output filled at runtime.
263 RngNormal,
264 /// ONNX `RandomUniformLike`.
265 RngUniform,
266 Conv,
267 Im2Col,
268 ConvTranspose2d,
269 Pool,
270 ReluBackward,
271 ActivationBackward,
272 FakeQuantizeBackward,
273 ComplexNormSq,
274 ComplexNormSqBackward,
275 Conjugate,
276 MaxPool2dBackward,
277 Conv2dBackwardInput,
278 Conv2dBackwardWeight,
279 SoftmaxCrossEntropyWithLogits,
280 SoftmaxCrossEntropyBackward,
281 AttentionBackward,
282 LayerNormBackwardInput,
283 LayerNormBackwardGamma,
284 RmsNormBackwardInput,
285 RmsNormBackwardGamma,
286 RmsNormBackwardBeta,
287 RopeBackward,
288 GroupNormBackwardInput,
289 GroupNormBackwardGamma,
290 GroupNormBackwardBeta,
291 BatchNormInferenceBackwardInput,
292 BatchNormInferenceBackwardGamma,
293 BatchNormInferenceBackwardBeta,
294 CumsumBackward,
295 GatherBackward,
296 GroupedMatMul,
297 DequantGroupedMatMul,
298 DequantMoEWeights,
299 ScatterAdd,
300 LoraMatMul,
301 DequantMatMul,
302 QMatMul,
303 QConv2d,
304 SelectiveScan,
305 GatedDeltaNet,
306 Lstm,
307 Gru,
308 Rnn,
309 Mamba2,
310 FusedSwiGLU,
311 FusedMatMulBiasAct,
312 FusedResidualLN,
313 FusedResidualRmsNorm,
314 FusedAttentionBlock,
315 FusedTransformerLayer,
316 If,
317 While,
318 Scan,
319 ScanBackward,
320 ScanBackwardXs,
321 /// CPU reference 3D Gaussian splat raster (project → bin → sort → raster).
322 /// See [`Op::GaussianSplatRender`].
323 GaussianSplatRender,
324 /// Backward of [`Op::GaussianSplatRender`] — packed scene parameter gradients.
325 GaussianSplatRenderBackward,
326 /// Project + tile bin + sort + ray grid (strict IR splat stage 1).
327 GaussianSplatPrepare,
328 /// Per-pixel raster from prepared buffers (strict IR splat stage 2).
329 GaussianSplatRasterize,
330 /// User-registered op dispatched through `op_registry`. All
331 /// custom ops (Sparse-LU, FFT, eigensolve, ...) share this kind;
332 /// the per-op identity lives in `Op::Custom::name`.
333 Custom,
334 /// User-defined sub-graph with optional override AD rules. See
335 /// [`Op::CustomFn`] / [`crate::Graph::custom_fn`].
336 CustomFn,
337 /// 1D FFT primitive (forward or inverse) — see [`Op::Fft`].
338 Fft,
339 /// Ternary pruned radix-2 butterfly stage — see [`Op::FftButterflyStage`].
340 FftButterflyStage,
341 /// Whisper-style log-mel from block-layout FFT spectrum — see [`Op::LogMel`].
342 LogMel,
343 /// Backward of [`Op::LogMel`] w.r.t. block-layout spectrum input 0.
344 LogMelBackward,
345 /// Welch PSD top-K spikes from block-layout FFT spectra — see [`Op::WelchPeaks`].
346 WelchPeaks,
347}
348
349/// An operand inside a fused [`ChainStep`] — either a graph-level input
350/// to the [`Op::ElementwiseRegion`] (by index 0..num_inputs) or the
351/// result of a previous step in the chain (by index 0..step_position).
352#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
353#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
354pub enum ChainOperand {
355 Input(u32),
356 Step(u32),
357}
358
359/// One step in a fused element-wise chain. Each step produces exactly
360/// one scalar result (per element); later steps can refer to it via
361/// [`ChainOperand::Step`]. The whole chain runs per element in registers.
362#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
363#[derive(Debug, Clone, PartialEq)]
364pub enum ChainStep {
365 Activation(Activation, ChainOperand),
366 Cast(DType, ChainOperand),
367 Binary(BinaryOp, ChainOperand, ChainOperand),
368 Compare(CmpOp, ChainOperand, ChainOperand),
369 /// 3-input element-wise select: `cond ? on_true : on_false`. Mirrors
370 /// `Op::Where` inside a chain. `cond` is treated as truthy iff
371 /// non-zero. Lets the optimizer fold attention masks / clamp-style
372 /// patterns into a single region kernel instead of breaking the
373 /// chain at the first `Op::Where`.
374 Where(ChainOperand, ChainOperand, ChainOperand),
375}
376
377/// Pre-region memory transform fused into [`Op::ElementwiseRegion`].
378#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
379#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
380pub enum RegionPrologue {
381 #[default]
382 None,
383 /// Input is half-resolution NCHW; output shape is 2× H×W (nearest 2×).
384 ResizeNearest2x,
385}
386
387/// One sampling step in [`Op::TransformRegion`].
388#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
389#[derive(Debug, Clone, PartialEq)]
390pub enum TransformStep {
391 ResizeNearest2x(ChainOperand),
392}
393
394/// An operation in the RLX IR graph.
395///
396/// Operations are categorized for fusion analysis:
397/// - Element-wise ops fuse with anything reading their output
398/// - Matmul/Conv are BLAS-dispatched and form fusion boundaries
399/// - Reductions are fusion roots (drive the loop iteration)
400#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
401#[derive(Debug, Clone, PartialEq)]
402pub enum Op {
403 // ── Graph inputs ────────────────────────────────────────────
404 /// Model input with a name (shape on the Node).
405 Input {
406 name: String,
407 },
408
409 /// Model parameter (weight/bias) with a name.
410 Param {
411 name: String,
412 },
413
414 /// Constant tensor embedded in the graph.
415 Constant {
416 data: Vec<u8>,
417 },
418
419 // ── Element-wise unary ──────────────────────────────────────
420 /// Unary activation: one input, same shape output.
421 Activation(Activation),
422
423 /// Cast to a different dtype.
424 Cast {
425 to: DType,
426 },
427
428 /// Stop-gradient (a.k.a. `detach` / `lax.stop_gradient`). Forward is
429 /// identity; the reverse-mode autodiff rule returns **no** gradient
430 /// contribution for the input. Single input, same shape & dtype
431 /// output. Used to build a Gradient-Reverse-Layer with identity
432 /// forward semantics in user code (see maet-rs `dat_loss`).
433 StopGradient,
434
435 /// INT8 quantization. Input f32; output i8 same shape.
436 /// `q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])`
437 /// where `c` selects the per-channel scale/zp when `axis = Some(d)`
438 /// (`c = idx[d]`), or always uses index 0 when `axis = None`
439 /// (per-tensor). The `scales` / `zero_points` payload length must
440 /// match `1` for per-tensor and `input.dim(d)` for per-channel.
441 /// Static — typically produced at calibration time and baked
442 /// into the loaded model. Use `Op::Dequantize` for the inverse.
443 Quantize {
444 axis: Option<usize>,
445 scales: Vec<f32>,
446 zero_points: Vec<i32>,
447 },
448
449 /// INT8 dequantization (inverse of `Op::Quantize`). Input i8;
450 /// output f32 same shape.
451 /// `x[i] = (q[i] - zero_point[c]) · scale[c]`
452 /// where `c` is selected by `axis` exactly as in `Op::Quantize`.
453 Dequantize {
454 axis: Option<usize>,
455 scales: Vec<f32>,
456 zero_points: Vec<i32>,
457 },
458
459 /// "Fake-quantize" op for **quantization-aware training** (QAT).
460 /// Input f32; output f32 same shape. Forward computes a per-axis
461 /// (or per-tensor when `axis = None`) max-abs scale on the fly:
462 /// `s[c] = max(|x[..., c, ...]|) / q_max(bits)`
463 /// then quantizes-then-dequantizes:
464 /// `out[i] = clamp(round(x[i] / s[c]), -q_max, q_max) * s[c]`
465 /// where `q_max` is `127` for `bits=8`, `7` for `bits=4`, `1` for
466 /// `bits=2` (ternary). Symmetric only — zero-point is always 0.
467 ///
468 /// The point of this op is to make the SGD optimizer "see" the
469 /// deployment-time rounding during training. Backward is the
470 /// **straight-through estimator** (STE): the gradient passes
471 /// through (variant chosen by `ste`), ignoring the discontinuity
472 /// at the round. Without STE the rounding would have zero
473 /// gradient almost everywhere and learning would stop.
474 ///
475 /// Inserted by the trainer on conv / FC weight tensors when
476 /// `--qat` is on; the existing `Op::Quantize` / packing path at
477 /// the end of training still handles the deployment-side
478 /// conversion to `i8`/`i4`/`i2` codes.
479 FakeQuantize {
480 bits: u8,
481 axis: Option<usize>,
482 ste: SteKind,
483 scale_mode: ScaleMode,
484 },
485
486 /// Learned Step Size Quantization (LSQ; Esser et al. 2020,
487 /// `arXiv:1902.08153`). Like `FakeQuantize` but the per-channel
488 /// `scale` is a *learned parameter*, passed as the second input.
489 /// Forward is identical to `FakeQuantize` with a fixed scale:
490 /// `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
491 /// Backward computes both `dx` (STE) and `dscale[c]` via the
492 /// closed-form gradient:
493 /// `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
494 /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
495 /// `sign(z) · q_max`. Routinely beats per-batch and EMA at
496 /// tight bit widths (i2 / i3).
497 ///
498 /// Inputs: `[x, scale]`. `scale` is `[chan_dim]` f32 (matches
499 /// `axis`); for `axis = None` it's `[1]`.
500 FakeQuantizeLSQ {
501 bits: u8,
502 axis: Option<usize>,
503 },
504
505 /// Backward pass for `Op::FakeQuantizeLSQ`. Computes BOTH the
506 /// gradient w.r.t. `x` (STE) and the gradient w.r.t. `scale`
507 /// (closed-form). Output shape matches `x`; the `scale` gradient
508 /// is reduced separately by `LsqScaleGradient`.
509 /// Inputs: `[x, scale, dy]`. Output: `dx`, same shape as `x`.
510 FakeQuantizeLSQBackwardX {
511 bits: u8,
512 axis: Option<usize>,
513 },
514
515 /// Companion to `FakeQuantizeLSQBackwardX`: computes the
516 /// `[chan_dim]` per-channel scale gradient. Inputs `[x, scale, dy]`.
517 /// Output shape matches `scale`.
518 FakeQuantizeLSQBackwardScale {
519 bits: u8,
520 axis: Option<usize>,
521 },
522
523 // ── Element-wise binary ─────────────────────────────────────
524 /// Binary op with broadcasting: two inputs, output shape is broadcast result.
525 Binary(BinaryOp),
526
527 // ── Comparison ──────────────────────────────────────────────
528 /// Element-wise comparison: two inputs, Bool output.
529 Compare(CmpOp),
530
531 /// Select elements: cond (Bool), on_true, on_false → output.
532 Where,
533
534 /// Fused element-wise region (PLAN L2). Holds an N-step chain of
535 /// element-wise operations. Inputs are referenced by index 0..num_inputs;
536 /// each step's result can be referenced by later steps via
537 /// `ChainOperand::Step(idx)`. The output is the last step's result.
538 /// Emitted by `MarkElementwiseRegions` in `rlx-opt` from chains of
539 /// Activation/Cast/Binary/Compare/Where ops with single-consumer
540 /// intermediates and broadcast-compatible shapes. Backends that
541 /// don't have a region kernel can decompose back to the original
542 /// chain via `unfuse::unfuse_elementwise_regions`.
543 ///
544 /// `scalar_input_mask` is a per-input bitfield (bit `i` set ⇒
545 /// input `i` is a scalar broadcast — has shape `[1]`). Kept as a
546 /// fast-path indicator that lets kernels skip the modulo entirely
547 /// when they detect a scalar.
548 ///
549 /// `input_modulus[i]` is the per-input element count, used by
550 /// kernels to compute `arena[input_offs[i] + (gid % input_modulus[i])]`
551 /// — the trailing-shape broadcast pattern. `0` means "no broadcast"
552 /// (input matches the output element count; kernel reads `gid`
553 /// directly). `1` means scalar; any other value means the input
554 /// has fewer elements than the output and they tile by modulo.
555 /// The encoder only allows broadcasts where `out_elems % in_elems
556 /// == 0` so the modulo divides cleanly. Lets chains include bias /
557 /// scale / eps / mask factors that previously broke the chain at
558 /// a Binary op with mismatched shapes.
559 ElementwiseRegion {
560 chain: Vec<ChainStep>,
561 num_inputs: u32,
562 scalar_input_mask: u32,
563 input_modulus: [u32; 16],
564 /// FKL-style closed fusion: apply before the element-wise chain.
565 prologue: RegionPrologue,
566 /// External input index that supplies the prologue transform source (default 0).
567 prologue_input: u32,
568 },
569
570 /// Fused transform chain (resize, future crop/color). Decompose via
571 /// [`rlx_fusion::DecomposeFusionRegions`] when no native kernel exists.
572 TransformRegion {
573 steps: Vec<TransformStep>,
574 num_inputs: u32,
575 },
576
577 /// Identical [`Op::ElementwiseRegion`] chain over `num_batch_inputs` tensors
578 /// (horizontal / z-plane fusion). Inputs are separate batch slices.
579 BatchElementwiseRegion {
580 chain: Vec<ChainStep>,
581 num_batch_inputs: u32,
582 scalar_input_mask: u32,
583 input_modulus: [u32; 16],
584 prologue: RegionPrologue,
585 prologue_input: u32,
586 },
587
588 // ── Linear algebra ──────────────────────────────────────────
589 /// Matrix multiply. Inputs: [.., M, K] × [.., K, N] → [.., M, N].
590 /// Batch dimensions are broadcast.
591 MatMul,
592
593 /// Matrix multiply with explicit dimension specification.
594 /// Like XLA's DotGeneral — handles arbitrary batch/contracting dims.
595 DotGeneral {
596 lhs_contracting: Vec<usize>,
597 rhs_contracting: Vec<usize>,
598 lhs_batch: Vec<usize>,
599 rhs_batch: Vec<usize>,
600 },
601
602 /// Batched dense linear solve. Inputs: `A [B, N, N]`,
603 /// `b [B, N]` or `b [B, N, K]`. Output: same shape as `b`.
604 ///
605 /// Per-batch independent solve — each `A[i]` and `b[i]` are
606 /// solved as a separate `Op::DenseSolve`. Emitted by vmap of
607 /// `Op::DenseSolve`. The CPU lowering loops over the batch
608 /// dimension calling `dgesv` per slice (LAPACK doesn't expose a
609 /// batched solve on Accelerate; cuSOLVER does on NVIDIA).
610 BatchedDenseSolve,
611
612 /// Dense linear solve `x = A⁻¹ · b` via LU factorization.
613 /// Inputs: `A [N, N]`, `b [N]` (or `b [N, K]` for multi-RHS).
614 /// Output: same shape as `b`.
615 ///
616 /// VJP via the implicit-function theorem:
617 /// `dx = solve(Aᵀ, upstream)`
618 /// `dA = -outer(dx, x)` (x is the forward output)
619 /// `db = dx`
620 /// The rule is dtype-agnostic; lowering is per-backend (Accelerate
621 /// `dgesv` / `sgesv`, cuSOLVER, etc.).
622 DenseSolve,
623
624 // ── Normalization ───────────────────────────────────────────
625 /// Layer normalization: input, gamma, beta → normalized output.
626 /// `axis` is the feature dimension (usually -1).
627 LayerNorm {
628 axis: i32,
629 eps: f32,
630 },
631
632 /// Group normalization on NCHW tensors: `input`, `gamma`, `beta` → same shape.
633 /// Normalizes over `(C/num_groups) × H × W` per group.
634 GroupNorm {
635 num_groups: usize,
636 eps: f32,
637 },
638
639 /// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
640 /// position (candle / SAM `LayerNorm2d` semantics — not PyTorch's H×W norm).
641 LayerNorm2d {
642 eps: f32,
643 },
644
645 /// Nearest-neighbor 2× upsample on NCHW (doubles spatial dims 2 and 3).
646 ResizeNearest2x,
647
648 /// RMS normalization: input, gamma → normalized output.
649 RmsNorm {
650 axis: i32,
651 eps: f32,
652 },
653
654 /// BatchNorm inference with frozen running statistics.
655 /// Inputs: `x`, `gamma`, `beta`, `running_mean`, `running_var`.
656 /// Feature dimension is the last axis of `x`; stats are 1-D `[C]`.
657 BatchNormInference {
658 eps: f32,
659 },
660
661 // ── Attention ───────────────────────────────────────────────
662 /// Scaled dot-product attention: Q, K, V, \[mask\] → output.
663 /// The compiler can lower this to fused SDPA or flash attention.
664 /// `mask_kind` controls how masking is applied — `Custom` reads from
665 /// the 4th input tensor; `None` / `Causal` / `SlidingWindow` skip the
666 /// mask load and apply the mask directly in the inner loop. See
667 /// `MaskKind` for the rationale.
668 ///
669 /// `score_scale`: when `Some(s)`, dot-product scores are multiplied by
670 /// `s` instead of the default `1/sqrt(head_dim)` (Gemma uses `head_dim^-0.5`
671 /// explicitly in config). `attn_logit_softcap`: when `Some(c)`, applies
672 /// `c * tanh(s/c)` to scores before softmax (Gemma 2).
673 Attention {
674 num_heads: usize,
675 head_dim: usize,
676 mask_kind: MaskKind,
677 score_scale: Option<f32>,
678 attn_logit_softcap: Option<f32>,
679 },
680
681 /// Rotary position embedding applied to one tensor: x, cos, sin → x_rotated.
682 /// Apply separately to Q and K. `head_dim` is the per-head width; `n_rot`
683 /// is how many leading dims get NeoX RoPE (pair offset `n_rot/2`). When
684 /// `n_rot < head_dim`, trailing dims are copied unchanged (Qwen3.5 MRoPE).
685 Rope {
686 head_dim: usize,
687 n_rot: usize,
688 },
689
690 /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
691 AxialRope2d {
692 end_x: usize,
693 end_y: usize,
694 head_dim: usize,
695 num_heads: usize,
696 theta: f32,
697 repeat_factor: usize,
698 },
699
700 // ── Shape manipulation ──────────────────────────────────────
701 Reshape {
702 new_shape: Vec<i64>,
703 },
704 Transpose {
705 perm: Vec<usize>,
706 },
707 /// Select a contiguous slice along an axis.
708 Narrow {
709 axis: usize,
710 start: usize,
711 len: usize,
712 },
713 /// Concatenate along an axis.
714 Concat {
715 axis: usize,
716 },
717 /// Expand (broadcast) to a target shape.
718 Expand {
719 target_shape: Vec<i64>,
720 },
721 /// Gather elements by index along an axis (embedding lookup).
722 Gather {
723 axis: usize,
724 },
725
726 // ── Reduction ───────────────────────────────────────────────
727 /// Reduce along specified axes.
728 Reduce {
729 op: ReduceOp,
730 axes: Vec<usize>,
731 keep_dim: bool,
732 },
733
734 /// Selective scan (plan #15) — Mamba-style state-space model
735 /// step. The recurrence:
736 /// `h[t] = exp(Δ[t] * A) * h[t-1] + Δ[t] * B[t] * x[t]`
737 /// `y[t] = C[t] * h[t]`
738 /// where state `h` has dimension `state_size` and the input has
739 /// `(batch, seq, hidden)`.
740 ///
741 /// Inputs (in order):
742 /// `x [b, s, h]` f32 input
743 /// `delta [b, s, h]` f32 step size (per-position, per-channel)
744 /// `a [h, n]` f32 transition matrix (one per channel)
745 /// `b [b, s, n]` f32 input projection
746 /// `c [b, s, n]` f32 output projection
747 /// Output: `[b, s, h]` f32. State `h` is implicit; the kernel
748 /// scans through the seq dimension carrying it.
749 ///
750 /// `state_size` = `n` is exposed for the cost model.
751 SelectiveScan {
752 state_size: usize,
753 },
754
755 /// Gated DeltaNet linear-attention recurrence — the per-layer
756 /// kernel used by Qwen3.5/3.6 trunk "linear attention" blocks
757 /// (and Qwen3-Next, Kimi-Linear). Mirrors
758 /// `llama.cpp / src/models/delta-net-base.cpp` autoregressive
759 /// path; chunked + fused variants ride the same op identity.
760 ///
761 /// **Math (per token `t`, head `h`, state size `n`):**
762 /// state matrix `S[h, i, j]` is implicit (reset per batch).
763 /// ```text
764 /// S[h] *= exp(g[t,h]) # scalar gate
765 /// sk[h,j] = Σ_i S[h,i,j] * k[t,h,i]
766 /// d[h,j] = (v[t,h,j] - sk[h,j]) * b[t,h] # b = beta
767 /// S[h,i,j] += k[t,h,i] * d[h,j] # outer-prod
768 /// o[t,h,j] = Σ_i S[h,i,j] * (q[t,h,i] / √n)
769 /// ```
770 ///
771 /// Inputs:
772 /// `q [b, s, h_v, n]` f32 queries (L2-normed by caller)
773 /// `k [b, s, h_v, n]` f32 keys (L2-normed by caller;
774 /// GQA-repeated to match `h_v`)
775 /// `v [b, s, h_v, n]` f32 values
776 /// `g [b, s, h_v]` f32 log-gate (exp'd inside kernel)
777 /// `beta [b, s, h_v]` f32 delta-rule mixing factor
778 ///
779 /// Output: `[b, s, h_v, n]` f32.
780 ///
781 /// When `carry_state` is true, a sixth input `state [b, h_v, n, n]`
782 /// provides the initial SSM matrix per head; the kernel updates it
783 /// in place across the sequence and leaves the final state in the
784 /// same buffer (same layout as the internal scan state:
785 /// `state[h, i, j]` row-major over `(n, n)` per head).
786 GatedDeltaNet {
787 state_size: usize,
788 carry_state: bool,
789 },
790
791 /// Multi-layer (optionally bidirectional) LSTM over a
792 /// `[batch, seq, input]` sequence. Gate order i, f, g, o (PyTorch);
793 /// recurrence per step:
794 /// ```text
795 /// z = x_t · w_ihᵀ + h_{t-1} · w_hhᵀ + bias
796 /// i,f,o = σ(z_i), σ(z_f), σ(z_o); g = tanh(z_g)
797 /// c_t = f · c_{t-1} + i · g; h_t = o · tanh(c_t)
798 /// ```
799 /// `D = 2 if bidirectional else 1`. Inputs `[x, w_ih, w_hh, bias]`
800 /// (`+ [h0, c0]` when `carry`):
801 /// * `x`: `[batch, seq, input]`
802 /// * `w_ih`: packed, all `(layer, direction)` blocks concatenated in
803 /// `layer`-major then `direction` order. Block `(l,d)` is
804 /// `[4*hidden, in_l]` with `in_l = input` for `l=0` else `D*hidden`.
805 /// * `w_hh`: packed `L*D` blocks of `[4*hidden, hidden]`.
806 /// * `bias`: packed `L*D` blocks of `[4*hidden]` (combined `b_ih+b_hh`).
807 /// * `h0`, `c0` (carry only): `[L*D, batch, hidden]`; the final
808 /// `hn`/`cn` are written back **in place** (decode threading).
809 ///
810 /// Output `y`: `[batch, seq, D*hidden]` — last layer's hidden states
811 /// (forward ‖ reverse concatenated on the feature axis). With
812 /// `num_layers = 1, bidirectional = false, carry = false` this is the
813 /// plain single-layer LSTM and the weight shapes reduce to
814 /// `[4*hidden, input]`, `[4*hidden, hidden]`, `[4*hidden]`.
815 Lstm {
816 hidden_size: usize,
817 num_layers: usize,
818 bidirectional: bool,
819 carry: bool,
820 },
821
822 /// Gated Recurrent Unit (PyTorch). `D = 2 if bidirectional else 1`;
823 /// gate order r, z, n. Recurrence:
824 /// ```text
825 /// r = σ(x·W_irᵀ + b_ir + h·W_hrᵀ + b_hr)
826 /// z = σ(x·W_izᵀ + b_iz + h·W_hzᵀ + b_hz)
827 /// n = tanh(x·W_inᵀ + b_in + r ⊙ (h·W_hnᵀ + b_hn))
828 /// h' = (1 - z) ⊙ n + z ⊙ h
829 /// ```
830 /// The new-gate reset is applied to the hidden term *after* its bias,
831 /// so `b_ih`/`b_hh` cannot be merged. Inputs `[x, w_ih, w_hh, b_ih, b_hh]`
832 /// (`+ [h0]` when carry); packing matches [`Op::Lstm`] with `3*hidden`
833 /// gate rows. `h0` `[L*D, batch, hidden]`. Output `[batch, seq, D*hidden]`.
834 Gru {
835 hidden_size: usize,
836 num_layers: usize,
837 bidirectional: bool,
838 carry: bool,
839 },
840
841 /// Elman RNN (PyTorch): `h' = act(x·w_ihᵀ + h·w_hhᵀ + bias)` with
842 /// `act = ReLU` when `relu` else `tanh`. Packed `w_ih` `[hidden, in_l]`,
843 /// `w_hh` `[hidden, hidden]`, merged `bias` `[hidden]` (per layer ×
844 /// direction). Inputs `[x, w_ih, w_hh, bias]` (`+ [h0]` when carry).
845 /// Output `[batch, seq, D*hidden]`.
846 Rnn {
847 hidden_size: usize,
848 num_layers: usize,
849 bidirectional: bool,
850 carry: bool,
851 relu: bool,
852 },
853
    /// Mamba-2 / SSD (structured state-space duality) scan — the
    /// scalar-decay SSM at the core of Mamba-2, sibling of
    /// [`Op::SelectiveScan`] / [`Op::GatedDeltaNet`]. Inputs
    /// `[x, dt, a, b, c]` (all `f32`):
    /// * `x`: `[batch, seq, heads, head_dim]`
    /// * `dt`: `[batch, seq, heads]` — discretization step (must already
    ///   be ≥ 0, e.g. softplus'd by the caller)
    /// * `a`: `[heads]` — per-head decay rate (`dA = exp(dt·a)`)
    /// * `b`: `[batch, seq, heads, state_size]`
    /// * `c`: `[batch, seq, heads, state_size]`
    ///
    /// The `head_dim` and `state_size` fields are the trailing dimensions
    /// above. State `S [batch, heads, head_dim, state_size]` is
    /// zero-initialized per sequence. Recurrence per timestep `t` (for each
    /// batch row and head):
    /// ```text
    /// dA_t = exp(dt_t · a)
    /// S_t  = dA_t · S_{t-1} + (dt_t · x_t) ⊗ b_t
    /// y_t  = Σ_n S_t[:, n] · c_t[n]
    /// ```
    /// Output `y`: `[batch, seq, heads, head_dim]` (same shape as `x`).
    Mamba2 {
        head_dim: usize,
        state_size: usize,
    },
877
    /// Fused dequant + matmul (plan #5). The biggest LLM-bandwidth
    /// win on Apple Silicon: dequantizes weights inside the matmul
    /// inner loop, never materializing f32 weights.
    ///
    /// **BREAKING CHANGE in 0.2.0:** `num_inputs()` is now
    /// scheme-dependent — **4** for the legacy Int8 / Int4 schemes, **2**
    /// for the GGUF K-quant schemes (their scales/mins live inside the
    /// packed bytes, so no side-channel `scale` / `zp` tensors are fed in).
    /// `Nvfp4Block` also takes 4 inputs (see below), but its fourth input
    /// is `global_scale` rather than `zp`. Callers that assumed a fixed
    /// 4-input contract must inspect `scheme.is_gguf()` before reading
    /// inputs.
    ///
    /// Inputs (Int8 / Int4 schemes — `scheme.is_gguf() == false`):
    /// - `x [m, k]` — f32 activations
    /// - `w_q [k, n]` — packed quantized weight bytes (one i8 per element
    ///   for Int8 schemes; 4-bit, packed two per byte, for Int4)
    /// - `scale [k/block, n]` — per-block f32 dequant scale
    /// - `zp [k/block, n]` — per-block f32 zero-point (a zero tensor if
    ///   the scheme is symmetric)
    ///
    /// Inputs (`Nvfp4Block` — fixed group size 16 along K):
    /// - `x [m, k]` — f32 activations
    /// - `w_q [k, n/2]` — packed FP4 E2M1 codes (unsigned nibble 0..15)
    /// - `scale [k/16, n]` — u8 FP8 E4M3 block scales (one byte per group)
    /// - `global_scale [1]` — f32 per-tensor scale (pass `[1.0]` if unused)
    ///
    /// Inputs (GGUF schemes — `scheme.is_gguf() == true`):
    /// - `x [m, k]` — f32 activations
    /// - `packed_w [bytes]` — raw GGUF super-block bytes; the dequantizer
    ///   reads the per-sub-block scales / mins / quants directly out of
    ///   the buffer per the K-quant block layout (no side tensors).
    ///
    /// Output: `[m, n]` f32.
    ///
    /// For the Int8 / Int4 schemes, the block size (not a field of this
    /// variant — presumably carried by `scheme`) is the number of
    /// consecutive elements that share one (scale, zero_point) pair.
    /// The op carries enough metadata that the kernel doesn't need
    /// a separate `QuantMap` lookup at run time.
    DequantMatMul {
        scheme: crate::quant::QuantScheme,
    },
921
    /// Real INT8-arithmetic matrix multiply with i32 accumulation.
    ///
    /// Inputs (in order):
    /// - `x [M, K]` — i8 activations (zero-point = `x_zp`)
    /// - `w [K, N]` — i8 weights (zero-point = `w_zp`)
    /// - `bias [N]` — i32 in accumulator scale (`x_scale·w_scale`);
    ///   pass a zeros tensor for "no bias"
    ///
    /// Output: `[M, N]` i8 (zero-point = `out_zp`).
    ///
    /// Per-element compute:
    /// `out[m,n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
    /// where `mult = x_scale · w_scale / out_scale`. Only the folded `mult`
    /// is stored on the op; the individual scales are not.
    ///
    /// This is the same kernel shape `rlx-cortexm/src/dense.rs`
    /// uses for on-device int8 inference, lifted into the IR so the
    /// rlx-cpu backend can run a quantized graph directly (instead
    /// of round-tripping through fake-quant Dequantize → MatMul →
    /// Quantize). 2-D only; batched support will be added when a
    /// real workload demands it.
    QMatMul {
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },
946
    /// Real INT8-arithmetic 2-D convolution with i32 accumulation.
    ///
    /// Inputs:
    /// - `x [N, C_in, H, W]` — i8 (zero-point = `x_zp`)
    /// - `w [C_out, C_in/groups, kH, kW]` — i8 (zero-point = `w_zp`)
    /// - `bias [C_out]` — i32 in accumulator scale
    ///
    /// Output: `[N, C_out, H_out, W_out]` i8 (zero-point = `out_zp`).
    /// Same NCHW geometry contract as `Op::Conv`; same requantize
    /// math as `Op::QMatMul` (per-element `acc·mult`, rounded to i8).
    QConv2d {
        kernel_size: Vec<usize>,
        stride: Vec<usize>,
        padding: Vec<usize>,
        dilation: Vec<usize>,
        groups: usize,
        x_zp: i32,
        w_zp: i32,
        out_zp: i32,
        mult: f32,
    },
966
    /// Fused LoRA matmul: `out = x·W + scale * x·A·B`.
    /// Inputs (in order): `x [m, k]`, `w [k, n]`, `a [k, r]`, `b [r, n]`.
    /// `r` is the LoRA rank (typically 4-64). `scale` is the
    /// per-adapter `alpha / rank` knob.
    /// Plan #9: lifts LoRA from "three matmuls + an add" into one
    /// kernel that keeps the rank-r intermediate in registers.
    LoraMatMul {
        scale: f32,
    },
976
    /// Fused sampling kernel: logits → optional top-k filter →
    /// optional top-p truncation → softmax → multinomial sample.
    /// Emits one f32-encoded sampled token id per batch row (output
    /// shape `[batch]`).
    ///
    /// `temperature == 1.0` samples from the unscaled softmax of the
    /// logits (greedy argmax is the `temperature → 0` limit); lower →
    /// sharper, higher → flatter. `top_k == 0` disables top-k filtering;
    /// `top_p == 1.0` disables top-p truncation. `seed` is the Philox
    /// seed; pass 0 to draw from an internal counter instead (still
    /// deterministic given the call order). NOTE(review): this comment
    /// and the note on the `seed` field disagree on that counter's scope
    /// (process-global vs. thread-local) — confirm against the backends.
    ///
    /// Borrowed from MAX's nn/sampling.mojo (#42 in PLAN.md).
    /// Latency-critical: never materializes the full softmax
    /// distribution on the host.
    Sample {
        top_k: usize,     // 0 = disabled
        top_p: f32,       // 1.0 = disabled
        temperature: f32, // 1.0 = neutral
        seed: u64,        // 0 = use thread-local counter
    },
996
    /// ONNX `RandomNormalLike` / `RandomNormal`: zero or one shape-template
    /// input (`RandomNormalLike` uses the template's shape; `RandomNormal`
    /// with a `shape` attribute needs no input). The output shape is fixed
    /// on the node at compile/execute time.
    ///
    /// `mean` / `scale` mirror the ONNX attributes of the same names
    /// (`scale` is the standard deviation). The optional ONNX `seed`
    /// attribute (f32, presumably stored in `op_seed`) overrides the mixed
    /// seed on the Ort backend. NOTE(review): `key` is presumably the
    /// graph-level RNG key that is mixed into that seed — confirm.
    RngNormal {
        mean: f32,
        scale: f32,
        key: u64,
        op_seed: Option<f32>,
    },
1008
    /// ONNX `RandomUniformLike` — the uniform counterpart of
    /// [`Op::RngNormal`]. `low` / `high` mirror the ONNX attributes of the
    /// same names (the bounds of the distribution). NOTE(review): `key` /
    /// `op_seed` presumably follow the same seeding scheme as
    /// [`Op::RngNormal`] — confirm.
    RngUniform {
        low: f32,
        high: f32,
        key: u64,
        op_seed: Option<f32>,
    },
1016
    /// Inclusive cumulative sum along `axis`. Same shape in and out.
    /// Underpins ragged-tensor offsets, sampling (top-p prefix sum),
    /// and sequence-position math (#44 in PLAN.md).
    /// `exclusive = true` shifts the result so that `output[0] = 0` along
    /// the axis (useful for offset arrays where the first segment starts
    /// at 0).
    Cumsum {
        axis: i32,
        exclusive: bool,
    },
1026
    /// Index of the maximum along `axis`. The output drops that axis (or
    /// keeps it as size 1 when `keep_dim`). Indices are **f32-encoded**
    /// (rlx is f32 at the I/O boundary, matching [`Op::TopK`]).
    ArgMax {
        axis: usize,
        keep_dim: bool,
    },

    /// Index of the minimum along `axis`; same output conventions as
    /// [`Op::ArgMax`].
    ArgMin {
        axis: usize,
        keep_dim: bool,
    },
1040
    /// Softmax along `axis` (a reduction followed by an element-wise
    /// normalize).
    Softmax {
        axis: i32,
    },

    /// Top-K **indices** along the last axis. Output shape `[..., k]`,
    /// f32-encoded indices (rlx is f32-only at the I/O boundary).
    /// To recover the values, follow with a `Gather` against the
    /// original tensor; this works because `Gather` already supports any
    /// axis. Indices are ordered by descending value, with ties broken by
    /// the smaller index. NOTE(review): `torch.topk(..., largest=True,
    /// sorted=True)` does not document its tie order, so exact parity with
    /// PyTorch on ties is not guaranteed.
    /// Used by MoE gating; also useful for beam search.
    TopK {
        k: usize,
    },
1056
    /// Indexed batched matmul — the MoE GEMM primitive.
    ///
    /// Inputs: `[input, weight, expert_idx]`
    /// - `input`: `[M, K]` — per-token activations
    /// - `weight`: `[num_experts, K, N]` — stacked expert weights
    /// - `expert_idx`: `[M]` — f32-encoded expert id per token
    ///
    /// Output: `[M, N]`, with `output[i] = input[i] @ weight[expert_idx[i]]`.
    /// Naive implementation on both backends; future work can replace it
    /// with a segmented/grouped GEMM when there's a real workload.
    GroupedMatMul,

    /// Fused GGUF K-quant dequant + [`Op::GroupedMatMul`]. Same three
    /// inputs as `GroupedMatMul`, but `weight` is a U8 tensor holding
    /// `num_experts` contiguous packed expert slabs (GGML layout, expert
    /// dimension outermost). Scales live inside the packed bytes.
    DequantGroupedMatMul {
        scheme: crate::quant::QuantScheme,
    },

    /// Dequantize a packed MoE expert stack to F32 `[num_experts, K, N]`
    /// in `GroupedMatMul` layout. Input: U8 packed bytes; the output shape
    /// is declared on the node (`[E, K, N]`).
    DequantMoEWeights {
        scheme: crate::quant::QuantScheme,
    },
1081
    /// Scatter-add into a destination tensor. The "unpermute" half of
    /// MoE routing (also useful for embedding gradient updates).
    ///
    /// Inputs: `[updates, indices]`
    /// - `updates`: `[num_updates, trailing]` — values to add
    /// - `indices`: `[num_updates]` — f32-encoded destination row
    ///
    /// Output: `[out_dim, trailing]`, with `output[indices[i]] += updates[i]`.
    /// `out_dim` is taken from the node's declared output shape.
    /// The output starts at zero; multiple updates to the same row
    /// accumulate (sequentially on CPU; with atomic-add on Metal, where
    /// the summation order, and therefore low-order float bits, may vary
    /// between runs).
    ScatterAdd,
1092
    // ── Convolution ─────────────────────────────────────────────
    /// 2D convolution on NCHW tensors. Also exposed as [`OpKind::Conv`] / `conv2d`.
    /// Weight layout: `[C_out, C_in / groups, kH, kW]`.
    Conv {
        kernel_size: Vec<usize>,
        stride: Vec<usize>,
        padding: Vec<usize>,
        dilation: Vec<usize>,
        groups: usize,
    },

    /// NCHW im2col, used to express conv backward-weight as a matmul.
    /// Input `[N, C, H, W]`. Output `[M, C·kH·kW]` with
    /// `M = N · H_out · W_out`. When the batch is [`dynamic::sym::BATCH`],
    /// the output rows use [`dynamic::sym::ROWS`] (bind it to `N · H_out · W_out`).
    Im2Col {
        kernel_size: Vec<usize>,
        stride: Vec<usize>,
        padding: Vec<usize>,
        dilation: Vec<usize>,
    },

    /// 2D transposed convolution on NCHW. Weight layout (PyTorch):
    /// `[C_in, C_out / groups, kH, kW]`.
    ConvTranspose2d {
        kernel_size: Vec<usize>,
        stride: Vec<usize>,
        padding: Vec<usize>,
        dilation: Vec<usize>,
        output_padding: Vec<usize>,
        groups: usize,
    },
1125
    // ── Pooling ─────────────────────────────────────────────────
    /// Windowed pooling: `kind` selects the reduction applied over each
    /// `kernel_size` window, which moves by `stride` over an input padded
    /// by `padding`. NOTE(review): this doc does not state which
    /// `ReduceOp` variants are accepted (e.g. max / mean), the layout
    /// (presumably NCHW, like [`Op::Conv`]), or the number of spatial
    /// dims — confirm against shape inference and the backends. The
    /// max-pool gradient is [`Op::MaxPool2dBackward`].
    Pool {
        kind: ReduceOp,
        kernel_size: Vec<usize>,
        stride: Vec<usize>,
        padding: Vec<usize>,
    },
1133
    // ── Backward / training ops ─────────────────────────────────
    //
    // Closed-form gradient nodes emitted by `rlx-opt::autodiff`.
    // Pairing a forward op with a dedicated backward op (instead of
    // composing it from primitives) keeps the gradient kernel simple
    // and lets the backend recompute argmax / masks / softmax inline.
    /// ReLU backward: `dx = dy where x > 0 else 0`.
    /// Inputs: `[x, dy]`, both the same shape; the output matches.
    ReluBackward,

    /// Element-wise complex squared magnitude: `|z|² = z.re² + z.im²`.
    /// Input: one tensor with `DType::C64`. Output: same shape but
    /// `DType::F32`. This is a forward op, listed here beside its
    /// backward. It is the natural real-valued loss surface for Wirtinger
    /// reverse-mode AD on complex graphs; pair it with
    /// [`Op::ComplexNormSqBackward`].
    ComplexNormSq,

    /// Element-wise complex conjugate: `z̄ = z.re - i·z.im` per element.
    /// Input: one tensor with `DType::C64`. Output: same shape, same dtype.
    /// Used by Wirtinger VJP rules on `Op::Binary` over C64 (the rule
    /// for `y = a·b` is `dL/dā = upstream · conj(b)`, etc.).
    Conjugate,

    /// Backward for [`Op::ComplexNormSq`] under Wirtinger calculus.
    /// `f(z) = |z|² = z·z̄`, so `∂f/∂z̄ = z`. Given the upstream real
    /// cotangent `g` (same shape as the forward output), the C64
    /// gradient with respect to `z` is `g · z` element-wise, returned
    /// in C64 storage as `[re_g·re_z, re_g·im_z]` per element.
    ///
    /// Inputs: `[z (C64), g (F32)]`, both the same logical shape; the
    /// output matches `z` (C64).
    ComplexNormSqBackward,
1166
    /// LayerNorm backward w.r.t. the input. Computes
    /// `d_x[..., d] = inv_std · (dy·γ - mean(dy·γ) - x̂·mean(dy·γ·x̂))`
    /// over the feature axis, where `x̂ = (x - mean)/std` is recomputed
    /// inline from `x`. Inputs: `[x, gamma, dy]`; output shape = `x.shape`.
    /// Currently only axis = -1 is lowered (matching the forward thunk).
    LayerNormBackwardInput {
        axis: i32,
        eps: f32,
    },

    /// LayerNorm backward w.r.t. gamma. Computes
    /// `d_gamma[d] = Σ_{batch} dy[..., d] · x̂[..., d]`,
    /// i.e. it sums the per-element product of the upstream gradient and
    /// the (recomputed) normalized input over the leading axes.
    /// Inputs: `[x, dy]`; output shape = `gamma.shape` (the 1-D feature axis).
    LayerNormBackwardGamma {
        axis: i32,
        eps: f32,
    },

    /// RMSNorm backward w.r.t. the input. Inputs `[x, gamma, beta, dy]`;
    /// output shape = `x.shape`.
    RmsNormBackwardInput {
        axis: i32,
        eps: f32,
    },

    /// RMSNorm backward w.r.t. gamma. Inputs `[x, gamma, beta, dy]`;
    /// output shape = `gamma.shape`.
    RmsNormBackwardGamma {
        axis: i32,
        eps: f32,
    },

    /// RMSNorm backward w.r.t. beta. Inputs `[x, gamma, beta, dy]`;
    /// output shape = `beta.shape`.
    RmsNormBackwardBeta {
        axis: i32,
        eps: f32,
    },
1204
    /// RoPE backward w.r.t. `x`. Inputs `[dy, cos, sin]`; output shape = `dy.shape`.
    RopeBackward {
        head_dim: usize,
        n_rot: usize,
    },

    /// GroupNorm (NCHW) backward w.r.t. the input. Inputs `[x, gamma, beta, dy]`.
    GroupNormBackwardInput {
        num_groups: usize,
        eps: f32,
    },

    /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`; output shape = `gamma.shape`.
    GroupNormBackwardGamma {
        num_groups: usize,
        eps: f32,
    },

    /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`; output shape = `beta.shape`.
    GroupNormBackwardBeta {
        num_groups: usize,
        eps: f32,
    },

    /// BatchNorm (inference statistics) backward w.r.t. `x`.
    /// Inputs `[x, gamma, mean, var, dy]`.
    BatchNormInferenceBackwardInput {
        eps: f32,
    },

    /// BatchNorm (inference statistics) backward w.r.t. `gamma`.
    /// Inputs `[x, mean, var, dy]`.
    BatchNormInferenceBackwardGamma {
        eps: f32,
    },

    /// BatchNorm (inference statistics) backward w.r.t. `beta`.
    /// Inputs `[dy]`; output shape = `beta.shape`.
    BatchNormInferenceBackwardBeta,

    /// Cumsum backward along `axis`. Inputs `[dy]`; output shape matches
    /// the forward input. `axis` / `exclusive` mirror the forward [`Op::Cumsum`].
    CumsumBackward {
        axis: i32,
        exclusive: bool,
    },

    /// Gather backward (scatter-add into the table). Inputs `[dy, indices]`;
    /// output shape = table shape. `axis` matches the forward [`Op::Gather`].
    GatherBackward {
        axis: i32,
    },
1253
    /// Generic element-wise activation backward. `kind` selects the
    /// closed-form derivative `d/dx act(x)`. Inputs: `[x, dy]`; output
    /// shape matches `x`. The kernel computes `act'(x) · dy` per element.
    ///
    /// Closed forms (all element-wise):
    /// * `Gelu` — exact derivative of erf-based GELU.
    /// * `GeluApprox` — derivative of the tanh approximation
    ///   `0.5 x (1 + tanh(√(2/π)(x + 0.044715 x³)))`.
    /// * `Silu` — `σ(x)·(1 + x·(1 - σ(x)))`.
    /// * `Sigmoid` — `σ(x)·(1 - σ(x))`.
    /// * `Tanh` — `1 - tanh(x)²`.
    /// * `Exp` — `exp(x)`.
    /// * `Log` — `1 / x`.
    /// * `Sqrt` — `0.5 / sqrt(x)`.
    /// * `Rsqrt` — `-0.5 · x^(-3/2)`.
    /// * `Neg` — `-1`.
    /// * `Abs` — `sign(x)` (zero at x=0).
    /// * `Sin` — `cos(x)`.
    /// * `Cos` — `-sin(x)`.
    /// * `Tan` — `1 + tan²(x)`.
    /// * `Atan` — `1 / (1 + x²)`.
    /// * `Round` — NOTE(review): `Activation::Round` documents a
    ///   straight-through (identity) backward, i.e. a derivative of `1`;
    ///   confirm that this kernel handles `Round` or that autodiff routes
    ///   it elsewhere.
    /// * `Relu` — kept here for completeness; the dedicated
    ///   `ReluBackward` op is preferred for relu and is what the
    ///   autodiff pass actually emits.
    ActivationBackward {
        kind: Activation,
    },
1281
    /// Backward for `Op::FakeQuantize` under a non-default STE.
    /// Inputs `[x, dy]`: the forward input and the upstream gradient.
    /// Output `dx`, same shape. The `bits` / `axis` / `ste` fields must
    /// match the forward op so the kernel computes the same per-channel
    /// scale and applies the right STE attenuation.
    /// For `SteKind::Identity` this op is unnecessary: autodiff simply
    /// routes `upstream` through unchanged.
    FakeQuantizeBackward {
        bits: u8,
        axis: Option<usize>,
        ste: SteKind,
    },
1294
    /// 2D max-pool backward. Routes each element of `dy` back to the
    /// position in `x`'s window where the forward max was taken.
    /// Inputs: `[x, dy]` with `x [N, C, H, W]` and
    /// `dy [N, C, H_out, W_out]`. Output: same shape as `x`.
    /// Carries the forward pool's geometry so the kernel can recompute
    /// the argmax position per window without a saved-indices tensor.
    MaxPool2dBackward {
        kernel_size: Vec<usize>,
        stride: Vec<usize>,
        padding: Vec<usize>,
    },

    /// 2D conv backward w.r.t. the input. Computes `dx = conv_transpose(dy, w)`.
    /// Inputs: `[dy, w]` with `dy [N, C_out, H_out, W_out]` and
    /// `w [C_out, C_in/groups, kH, kW]`. Output: `[N, C_in, H, W]`
    /// (declared on the node, since the caller knows the original input
    /// shape). The geometry fields are the forward conv's parameters, not
    /// the transposed conv's.
    Conv2dBackwardInput {
        kernel_size: Vec<usize>,
        stride: Vec<usize>,
        padding: Vec<usize>,
        dilation: Vec<usize>,
        groups: usize,
    },

    /// 2D conv backward w.r.t. the weight. Computes
    /// `dw[c_out, c_in, kh, kw] = sum_{n,h_out,w_out} x[n,c_in,...] * dy[n,c_out,h_out,w_out]`.
    /// Inputs: `[x, dy]`. Output: `[C_out, C_in/groups, kH, kW]`.
    Conv2dBackwardWeight {
        kernel_size: Vec<usize>,
        stride: Vec<usize>,
        padding: Vec<usize>,
        dilation: Vec<usize>,
        groups: usize,
    },
1331
    /// Fused softmax + cross-entropy loss with integer (f32-encoded)
    /// targets, the standard classification loss. Per-row output:
    /// `loss[n] = -log(softmax(logits[n])[labels[n]])`.
    /// Inputs: `[logits, labels]` with `logits [N, C]` and
    /// `labels [N]` (f32-encoded class indices). Output: `[N]`.
    /// The caller applies `Reduce::Mean` if it wants a scalar.
    SoftmaxCrossEntropyWithLogits,

    /// Backward of [`Op::SoftmaxCrossEntropyWithLogits`]. Emits
    /// `dlogits[n,c] = (softmax(logits[n])[c] - one_hot(labels)[n,c]) * d_loss[n]`.
    /// Inputs: `[logits, labels, d_loss]`. Output: `[N, C]` (same shape
    /// as `logits`). Recomputes the softmax inline rather than threading
    /// it through from the forward node.
    SoftmaxCrossEntropyBackward,

    /// Backward of [`Op::Attention`]. Recomputes the scaled `QK^T`, applies
    /// the same `mask_kind` as the forward op, softmaxes the scores, then
    /// emits **one** of `dQ`, `dK`, or `dV`, selected by [`AttentionBwdWrt`].
    /// Autodiff emits three nodes (one per `wrt`) so that each output
    /// stays a normal single-output MIR node.
    ///
    /// Inputs: `[q, k, v, dy]`, plus an optional mask when `mask_kind` is
    /// [`MaskKind::Custom`] or [`MaskKind::Bias`] (same convention as the
    /// forward op). The output shape matches `q`, `k`, or `v` respectively.
    AttentionBackward {
        num_heads: usize,
        head_dim: usize,
        mask_kind: MaskKind,
        wrt: AttentionBwdWrt,
    },
1362
    // ── Fused operations (created by optimization passes) ──────
    /// Fused matmul + bias + activation, created from
    /// `MatMul → Add → Activation`. `activation` is optional (presumably
    /// `None` means matmul + bias only — confirm in the fusion pass).
    FusedMatMulBiasAct {
        activation: Option<Activation>,
    },

    /// Fused residual + optional bias + layer norm, created from
    /// `Add(x, residual) → [Add(bias)] → LayerNorm` (the bias add is
    /// present when `has_bias`).
    FusedResidualLN {
        has_bias: bool,
        eps: f32,
    },

    /// Fused residual + optional bias + RMS norm, created from
    /// `Add(x, residual) → [Add(bias)] → RmsNorm` (the bias add is
    /// present when `has_bias`).
    FusedResidualRmsNorm {
        has_bias: bool,
        eps: f32,
    },

    /// Fused SwiGLU: split the input into up/gate halves, then compute
    /// `silu(gate) * up`. Created from `Split → Silu → Mul` when fed by a
    /// concatenated matmul.
    ///
    /// `cast_to`: optional output dtype. When `Some(dt)`, the kernel casts
    /// its result from the input dtype to `dt` in-register, saving a
    /// separate cast pass. Reserved for future fp8/fp4 quantization paths;
    /// for f32→f16 mixed precision the AutoMixedPrecision pass already
    /// inserts a Cast node, so this stays `None` in current pipelines.
    FusedSwiGLU {
        cast_to: Option<DType>,
        /// When `true`, the concatenated input stores gate in the low half
        /// `[..., 0..N)` and up in the high half `[..., N..2N)` — the layout
        /// produced when the gate projection is emitted before up in the builder.
        /// Default `false`: up in the low half, gate in the high half (canonical concat order).
        gate_first: bool,
    },
1399
    /// Fused full transformer layer: attention block + residual+LN + FFN + residual+LN.
    /// All intermediates stay resident in registers/threadgroup memory: one
    /// kernel per layer instead of ~30 (the CPU's batch=1 win, lifted to
    /// the IR so any backend can implement it as a monolithic kernel).
    ///
    /// Inputs: `hidden, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w,
    /// fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask`.
    /// Output: same shape as `hidden`.
    ///
    /// **Backend status:** same as `FusedAttentionBlock`. The CPU implements
    /// the L1-cache-resident merge at the thunk level. Metal is deferred:
    /// beating the unfused path requires a single MSL kernel for the whole
    /// layer. That is multi-day work; revisit when there's a model whose
    /// Metal inference is bottlenecked here rather than on the wait
    /// latency floor.
    FusedTransformerLayer {
        num_heads: usize,
        head_dim: usize,
        intermediate_size: usize,
        eps1: f32,
        eps2: f32,
        activation: Activation,
        has_bias: bool,
    },

    /// Fused attention block: QKV projection → split → \[RoPE\] → SDPA → output projection.
    /// Created by the FuseAttentionBlock pass when batch*seq is small.
    /// All intermediates stay in L1 cache, with no arena writes between ops.
    ///
    /// Inputs (in order):
    /// - `hidden, qkv_w, out_w, mask`
    /// - `qkv_b, out_b` if `has_bias`
    /// - `rope_cos, rope_sin` if `has_rope`
    ///
    /// **Backend status (Phase C finalize):**
    /// - CPU: implemented at the *thunk* level. The CPU schedule recognizes
    ///   the multi-thunk pattern and merges it into a single FusedAttnBlock
    ///   that keeps Q/K/V in stack buffers across stages (the L1-cache win).
    /// - Metal: **deferred**. A dispatch-wrapper version (chaining existing
    ///   kernels) buys nothing the unfused Metal path doesn't already get,
    ///   since per-run cost is dominated by `wait_until_completed`
    ///   (~150 µs), not encode. The real win is a single MSL kernel
    ///   keeping Q/K/V in threadgroup memory across stages, which is
    ///   multi-day work. Until then, Metal runs the unfused chain (one
    ///   matmul, three narrows, two ropes, attention, one matmul), all of
    ///   which is covered in op_coverage and parity_harness.
    FusedAttentionBlock {
        num_heads: usize,
        head_dim: usize,
        has_bias: bool,
        has_rope: bool,
    },
1454
    // ── Control flow (subgraphs as op payloads) ─────────────────
    //
    // Status of `If` / `While`: the IR is defined and the helpers
    // `run_if` / `run_while` exist in rlx-runtime/src/subgraph.rs, but
    // **executor wiring is not yet implemented**. Both the CPU thunk and
    // the Metal thunk fall through to `Thunk::Nop` for these ops. (`Scan`,
    // further below, documents its own CPU executor behavior.) Wiring
    // requires:
    //   1. Recursive subgraph compile at parent-compile time.
    //   2. Per-subgraph input/output binding through the arena.
    //   3. Schedule-level dispatch once the predicate / loop condition is
    //      resolved at runtime.
    // Estimate: 4–6 hours of focused work plus parity tests. Deferred
    // because no current in-tree model uses these ops, and surface area
    // without a validation target invites silent bugs.
    /// Conditional: pick between two subgraphs based on a boolean predicate.
    /// Inputs: `[predicate, ...captures]`, where the captures are used
    /// inside both branches. `then_branch` and `else_branch` share the
    /// captured inputs and must produce identically shaped outputs.
    /// Used for: shape-dependent execution, and batched inference of
    /// dynamic-length sequences with padding masks.
    If {
        then_branch: Box<crate::Graph>,
        else_branch: Box<crate::Graph>,
    },

    /// Loop: iterate `body` while `cond` evaluates true.
    /// Inputs: `[...initial loop-carried values]`.
    /// `cond`'s single output is a Bool scalar; `body`'s outputs become
    /// the next iteration's loop-carried inputs. The outputs of `While`
    /// are the values after the final iteration.
    /// NOTE(review): `max_iterations` presumably caps the trip count
    /// (`None` = unbounded) — confirm once executor wiring lands.
    /// Used for: KV-cache-driven autoregressive generation, beam search.
    While {
        cond: Box<crate::Graph>,
        body: Box<crate::Graph>,
        max_iterations: Option<usize>,
    },
1490
    /// Bounded-length loop with a fixed-shape carry, optional
    /// loop-invariant ("broadcast") and per-step inputs, and an optional
    /// stacked output. Mirrors JAX's `lax.scan`.
    ///
    /// Body signature:
    /// `(carry, bcast_0, ..., bcast_{num_bcast-1}, x_t_0, ..., x_t_{num_xs-1}) → carry_next`
    /// — `1 + num_bcast + num_xs` Op::Inputs in NodeId construction order
    /// (the first declared is the carry, then the `num_bcast` broadcast
    /// inputs, then the `num_xs` per-step slices). Single output: the
    /// next carry. With `num_bcast == 0` this is `(carry, x_t_0, ...)`.
    ///
    /// Outer Op::Scan inputs (in order):
    /// `[init_carry, bcast_0, ..., bcast_{num_bcast-1}, xs_0, ..., xs_{num_xs-1}]`
    /// Each `xs_i` has shape `[length, *per_step_shape_i]`; the body
    /// sees `xs_i[t]` (a `per_step_shape_i` slice) on iteration `t`.
    ///
    /// Outer Op::Scan output:
    /// * `save_trajectory == false` — the final carry, shape `*carry`.
    /// * `save_trajectory == true` — a stacked trajectory of carries,
    ///   shape `[length, *carry]`, or `[K, *carry]` when
    ///   `0 < num_checkpoints = K < length`. Without checkpointing, row `t`
    ///   is the carry after step `t+1`, so row `length-1` matches the
    ///   no-trajectory case.
    ///
    /// Common uses include time-stepping integrators with time-varying
    /// drives, Mamba-style SSM scans reading per-step inputs, and
    /// RNN-style sequence processing.
    Scan {
        body: Box<crate::Graph>,
        length: u32,
        save_trajectory: bool,
        /// Number of "broadcast" inputs: values that are constant
        /// across iterations. Outer scan inputs in order:
        /// `[init, bcast_0..bcast_{B-1}, xs_0..xs_{X-1}]`
        /// Body Op::Inputs in NodeId order:
        /// `[carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]`
        /// The CPU executor fills the bcast slots ONCE before the iteration
        /// loop (xs slots are filled per step). The reverse-mode AD
        /// pre-pass materialises each bcast into an xs of shape
        /// `[length, *bcast]` via a broadcast `Mul`, so the rest of the
        /// VJP / executor pipeline can stay unchanged. `0` (the default)
        /// keeps the original carry+xs scan shape.
        num_bcast: u32,
        /// Number of per-step `xs` inputs. The total number of outer
        /// Op::Scan inputs is `1 + num_bcast + num_xs`.
        num_xs: u32,
        /// Number of trajectory checkpoints when `save_trajectory ==
        /// true`. `0` means "save all `length` rows" (the default). A
        /// positive value `K` means save only `K` evenly spaced rows,
        /// at indices `floor(t * length / K)` for `t in 0..K`. Used
        /// by recursive checkpointed AD: store O(√T) carries during the
        /// forward pass and recompute the rest in the backward pass.
        ///
        /// When `0` (or `K == length`), the saved trajectory has
        /// shape `[length, *carry]`, the same as the original behavior.
        /// When `0 < K < length`, the saved trajectory has shape
        /// `[K, *carry]`.
        num_checkpoints: u32,
    },
1545
    /// Reverse-mode AD companion to `Op::Scan` that extracts the carry
    /// gradient `dinit`. Walks `t = length-1 .. 0`, applying `body_vjp`
    /// to thread `dcarry` back through the time loop.
    ///
    /// Inputs (in order):
    /// `[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]`
    /// There are no broadcast inputs here: the AD pre-pass has already
    /// materialised each forward broadcast input into an `xs`.
    /// Output: `dinit`, with the carry's shape.
    ///
    /// `body_vjp` is the result of
    /// `autodiff::grad(body, [carry_id, xs_0_id, ..., xs_{num_xs-1}_id])`.
    /// It is a graph with `1 + num_xs + 1` Inputs (the carry, one `x_t_i`
    /// per xs, then `"d_output"`) and `1 + num_xs` outputs (`dcarry`,
    /// then one `dx_t_i` per xs). This op reads `outputs[0]` = `dcarry`;
    /// the sibling [`Self::ScanBackwardXs`] reads the `outputs[1 + xs_idx]`
    /// slot for each xs gradient.
    ScanBackward {
        body_vjp: Box<crate::Graph>,
        length: u32,
        save_trajectory: bool,
        num_xs: u32,
        /// When `0` or equal to `length`, the trajectory input has
        /// shape `[length, *carry]`: every step's carry is cached
        /// (`CheckpointStrategy::All`). When `0 < K < length`, the
        /// trajectory input has shape `[K, *carry]` and the executor
        /// recomputes intermediate carries via `forward_body` between
        /// checkpoints.
        num_checkpoints: u32,
        /// Forward body (the same `body` as the forward Op::Scan).
        /// Must be `Some` whenever `0 < num_checkpoints < length`, so the
        /// executor can recompute carries between saved checkpoints.
        /// `None` for the All strategy (no recompute needed).
        forward_body: Option<Box<crate::Graph>>,
    },

    /// Companion to [`Self::ScanBackward`] that extracts one stacked
    /// per-step `dxs_i` (shape `[length, *per_step_xs_i]`). Takes the same
    /// inputs and the same `body_vjp` graph as `ScanBackward`; `xs_idx`
    /// selects which `body_vjp` output (`outputs[1 + xs_idx]`) to stack
    /// into the result.
    ///
    /// Note: each `ScanBackwardXs` runs its own backward loop. A future
    /// optimization can fuse them into a single multi-output backward
    /// kernel; for now it's `1 + num_xs` independent sweeps.
    ScanBackwardXs {
        body_vjp: Box<crate::Graph>,
        length: u32,
        save_trajectory: bool,
        num_xs: u32,
        /// Which `xs` input's gradient this node produces (`0..num_xs`).
        xs_idx: u32,
        /// Same meaning as `num_checkpoints` on [`Self::ScanBackward`].
        num_checkpoints: u32,
        /// Same meaning as `forward_body` on [`Self::ScanBackward`].
        forward_body: Option<Box<crate::Graph>>,
    },
1598
    /// CPU reference 3D Gaussian splat forward render.
    ///
    /// Seven flat F32 inputs (scene buffers + camera/render meta):
    /// 0. positions `[N*3]`
    /// 1. scales `[N*3]` (log-space)
    /// 2. rotations `[N*4]` (xyzw)
    /// 3. opacities `[N]` (logit)
    /// 4. colors `[N*3]` (linear RGB)
    /// 5. sh_coeffs `[N * sh_coeff_count * 3]`
    /// 6. meta `[23]` — camera position/target/up/fov/near/far and
    ///    background RGB (15 floats), then width/height/tile_size/
    ///    radius_scale/alpha_cutoff/max_splat_steps/transmittance_threshold/
    ///    max_list_entries (8 values) as f32 bit-patterns.
    ///
    /// Output: `[height * width * 4]` RGBA. NOTE(review): the output has
    /// been described as "linear RGBA (display gamma baked in)", which is
    /// self-contradictory; confirm whether it is linear or gamma-encoded.
    /// Build via [`crate::Graph::gaussian_splat_render`].
    ///
    /// Differentiable backward is not implemented in v1; autodiff treats
    /// this op as non-differentiable (same as [`Op::Sample`]).
    /// NOTE(review): [`Self::GaussianSplatRenderBackward`] is defined
    /// alongside this op; confirm whether autodiff still skips this op or
    /// now emits that backward.
    GaussianSplatRender {
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
1627
    /// Backward pass for [`Self::GaussianSplatRender`].
    ///
    /// Eight inputs: the same seven as the forward op plus `d_loss_rgba`
    /// `[W*H*4]` (only the RGB channels are used). Re-runs the training
    /// forward pass internally.
    ///
    /// Output: packed gradients
    /// `[positions(3N) | scales(3N) | rotations(4N) | opacities(N) | colors(3N) | sh(N*sh*3)]`.
    /// Unpack via [`crate::ops::splat::unpack_gaussian_splat_packed_grads`].
    ///
    /// The render-geometry fields mirror the forward op. NOTE(review):
    /// `loss_grad_clip`, `sh_band`, and `max_anisotropy` are backward-only
    /// knobs that are presumably a clip on the incoming loss gradient, the
    /// active SH band, and an anisotropy cap; confirm against rlx_splat.
    GaussianSplatRenderBackward {
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },

    /// Strict IR stage 1: project, bin, sort, and build per-pixel rays.
    ///
    /// Seven inputs (the same scene + meta as [`Self::GaussianSplatRender`]).
    /// Output: the packed prepare buffer (see
    /// `rlx_splat::prep_layout::prep_packed_len`).
    GaussianSplatPrepare {
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },

    /// Strict IR stage 2: tile raster from the [`Self::GaussianSplatPrepare`] output.
    ///
    /// Inputs: the `prep` packed buffer and `meta` `[23]`.
    /// Output: `[width * height * 4]` RGBA.
    GaussianSplatRasterize {
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
1677
1678 /// User-registered custom op. `name` keys into the
1679 /// [`crate::op_registry`] for shape inference, autodiff, and
1680 /// per-backend execution. `attrs` is an opaque blob passed
1681 /// through to those callbacks (FFT direction, SparseLU
1682 /// reordering strategy, etc.). `num_inputs` is captured at
1683 /// construction time so [`Op::num_inputs`] stays infallible
1684 /// without a registry lookup. Build via [`crate::Graph::custom_op`].
1685 Custom {
1686 name: String,
1687 num_inputs: u32,
1688 attrs: Vec<u8>,
1689 },
1690
1691 /// 1D Fast Fourier Transform along the last axis.
1692 ///
1693 /// **Layouts**
1694 /// - `F32` / `F64`: 2N real-block — last axis is `[re₀…re_{N-1}, im₀…im_{N-1}]`.
1695 /// - `C64`: interleaved `[re, im]` pairs per complex element along the last axis.
1696 ///
1697 /// **ND transforms** — use `Graph::fftn` / `Graph::ifftn`, which compose
1698 /// `fft_axis` (transpose → Fft → transpose). Multi-axis `fftn` requires
1699 /// `DType::C64`; the 2N-block layout describes a single complex axis.
1700 ///
1701 /// Default (`FftNorm::Backward`) is **unnormalized** on both directions:
1702 /// `fft(x)[k] = Σ x[n]·exp(-2πi·nk/N)`
1703 /// `ifft(y)[n] = Σ y[k]·exp(+2πi·nk/N)`
1704 /// so `ifft(fft(x)) = N·x`. Use `FftNorm::Forward` for gpu-fft-style
1705 /// `1/N` scaling on inverse, or `FftNorm::Ortho` for unitary scaling.
1706 ///
1707 /// AD: VJP(`fft`) = `ifft`, VJP(`ifft`) = `fft` when `norm=Backward`;
1708 /// other norms apply the chain rule via output scaling.
1709 Fft {
1710 inverse: bool,
1711 norm: crate::fft::FftNorm,
1712 },
1713
1714 /// Ternary pruned radix-2 butterfly stage on interleaved complex state.
1715 ///
1716 /// Inputs:
1717 /// 0 — state `[batch, n_fft, 2]` (re/im on axis 2)
1718 /// 1 — gate `[half]` — 0 = identity, 1 = run butterfly (`half = n_fft/2`)
1719 /// 2 — rev `[half]` — 0 = forward, 1 = swap outputs when gate=1
1720 /// 3 — tw_re `[half]`
1721 /// 4 — tw_im `[half]`
1722 ///
1723 /// Output: `[batch, n_fft, 2]` same layout. Slots with gate=0 copy inputs
1724 /// without twiddle math.
1725 FftButterflyStage {
1726 stage: u32,
1727 n_fft: u32,
1728 },
1729
1730 /// Log-mel spectrogram from RLX FFT block-layout spectrum.
1731 ///
1732 /// Inputs:
1733 /// 0 — spectrum `[..., 2*n_fft]` (re plane then im plane, same as `Op::Fft` output)
1734 /// 1 — mel filterbank `[n_mels, n_bins]` with `n_bins = n_fft/2 + 1`
1735 ///
1736 /// Output: `[..., n_mels]` with Whisper dynamic-range compression
1737 /// (`log10`, clamp to max−8 dB, `(x+4)/4`).
1738 LogMel,
1739
1740 /// VJP of [`Op::LogMel`] w.r.t. spectrum (input 0).
1741 ///
1742 /// Inputs: spectrum block, mel filters, upstream `dy`.
1743 /// Output: `d_spectrum` (same shape as input 0).
1744 LogMelBackward,
1745
1746 /// Top-K Welch peaks from block-layout segment spectra.
1747 ///
1748 /// Input 0: spectrum `[batch * n_segments, 2*n_fft]` (re ∥ im planes).
1749 /// Output: `[batch, k*2]` interleaved `(bin, power)` per spike.
1750 WelchPeaks {
1751 k: usize,
1752 n_segments: usize,
1753 },
1754
1755 /// User-defined sub-graph with optional override AD rules.
1756 /// Mirrors JAX's `custom_vjp` / `custom_jvp` decorators: the
1757 /// caller wraps a forward computation and supplies its own
1758 /// reverse- and/or forward-mode AD bodies. Useful when:
1759 /// * The forward is iterative (Newton, fixed-point) and
1760 /// differentiating through the loop is wasteful — the
1761 /// vjp_body computes the implicit-function gradient at the
1762 /// converged point in one shot.
1763 /// * The math has a closed-form gradient that's much cheaper
1764 /// than autodiff.
1765 /// * The forward op is non-differentiable by tracing
1766 /// (sampling, argmax) and the user wants a smooth surrogate.
1767 ///
1768 /// **fwd_body**: `num_inputs` Op::Inputs in NodeId construction
1769 /// order, one Op::Output (the primal y). Forward execution
1770 /// inlines this body once.
1771 ///
1772 /// **vjp_body** (optional): Op::Inputs are `num_inputs` primal
1773 /// inputs in NodeId order, plus two special-named Inputs —
1774 /// `"primal_output"` (the y from forward) and `"d_output"` (the
1775 /// upstream gradient). Outputs: `num_inputs` tensors in
1776 /// `set_outputs` order, matching the gradients of each primal
1777 /// input. When `None`, reverse-mode AD recurses into fwd_body
1778 /// — same as if the op were inlined.
1779 ///
1780 /// **jvp_body** (optional): Op::Inputs are `num_inputs` primal
1781 /// inputs in NodeId order, `num_inputs` special-named Inputs
1782 /// `"tangent_0"..="tangent_{num_inputs-1}"` carrying each input's
1783 /// tangent, and an optional special-named `"primal_output"` Input
1784 /// (the y from forward, useful when the JVP must be evaluated at
1785 /// a converged / nonlinear point — e.g. IFT-style forward-mode AD
1786 /// of an iterative solver). Output: 1 tensor (the tangent of y).
1787 /// When `None`, forward-mode AD recurses into fwd_body.
1788 ///
1789 /// `num_inputs` is captured so [`Op::num_inputs`] stays
1790 /// infallible. Build via [`crate::Graph::custom_fn`].
1791 CustomFn {
1792 fwd_body: Box<crate::Graph>,
1793 vjp_body: Option<Box<crate::Graph>>,
1794 jvp_body: Option<Box<crate::Graph>>,
1795 num_inputs: u32,
1796 },
1797}
1798
1799impl Op {
1800 /// PLAN L4: discriminant for backend-supported-set checks.
1801 /// Stable, parameter-free identity per variant — `Op::Activation(_)`
1802 /// and `Op::Activation(Relu)` share the same `OpKind::Activation`.
1803 pub fn kind(&self) -> OpKind {
1804 match self {
1805 Op::Input { .. } => OpKind::Input,
1806 Op::Param { .. } => OpKind::Param,
1807 Op::Constant { .. } => OpKind::Constant,
1808 Op::Activation(_) => OpKind::Activation,
1809 Op::Cast { .. } => OpKind::Cast,
1810 Op::StopGradient => OpKind::StopGradient,
1811 Op::Quantize { .. } => OpKind::Quantize,
1812 Op::Dequantize { .. } => OpKind::Dequantize,
1813 Op::FakeQuantize { .. } => OpKind::FakeQuantize,
1814 Op::FakeQuantizeLSQ { .. } => OpKind::FakeQuantizeLSQ,
1815 Op::FakeQuantizeLSQBackwardX { .. } => OpKind::FakeQuantizeLSQBackwardX,
1816 Op::FakeQuantizeLSQBackwardScale { .. } => OpKind::FakeQuantizeLSQBackwardScale,
1817 Op::Binary(_) => OpKind::Binary,
1818 Op::Compare(_) => OpKind::Compare,
1819 Op::Where => OpKind::Where,
1820 Op::ElementwiseRegion { .. } => OpKind::ElementwiseRegion,
1821 Op::TransformRegion { .. } => OpKind::TransformRegion,
1822 Op::BatchElementwiseRegion { .. } => OpKind::BatchElementwiseRegion,
1823 Op::MatMul => OpKind::MatMul,
1824 Op::DotGeneral { .. } => OpKind::DotGeneral,
1825 Op::DenseSolve => OpKind::DenseSolve,
1826 Op::BatchedDenseSolve => OpKind::BatchedDenseSolve,
1827 Op::LayerNorm { .. } => OpKind::LayerNorm,
1828 Op::LayerNorm2d { .. } => OpKind::LayerNorm2d,
1829 Op::GroupNorm { .. } => OpKind::GroupNorm,
1830 Op::BatchNormInference { .. } => OpKind::BatchNormInference,
1831 Op::RmsNorm { .. } => OpKind::RmsNorm,
1832 Op::ResizeNearest2x => OpKind::ResizeNearest2x,
1833 Op::Attention { .. } => OpKind::Attention,
1834 Op::Rope { .. } => OpKind::Rope,
1835 Op::AxialRope2d { .. } => OpKind::AxialRope2d,
1836 Op::Reshape { .. } => OpKind::Reshape,
1837 Op::Transpose { .. } => OpKind::Transpose,
1838 Op::Narrow { .. } => OpKind::Narrow,
1839 Op::Concat { .. } => OpKind::Concat,
1840 Op::Expand { .. } => OpKind::Expand,
1841 Op::Gather { .. } => OpKind::Gather,
1842 Op::Reduce { .. } => OpKind::Reduce,
1843 Op::Softmax { .. } => OpKind::Softmax,
1844 Op::Cumsum { .. } => OpKind::Cumsum,
1845 Op::ArgMax { .. } => OpKind::ArgMax,
1846 Op::ArgMin { .. } => OpKind::ArgMin,
1847 Op::TopK { .. } => OpKind::TopK,
1848 Op::Sample { .. } => OpKind::Sample,
1849 Op::RngNormal { .. } => OpKind::RngNormal,
1850 Op::RngUniform { .. } => OpKind::RngUniform,
1851 Op::Conv { .. } => OpKind::Conv,
1852 Op::Im2Col { .. } => OpKind::Im2Col,
1853 Op::ConvTranspose2d { .. } => OpKind::ConvTranspose2d,
1854 Op::Pool { .. } => OpKind::Pool,
1855 Op::ReluBackward => OpKind::ReluBackward,
1856 Op::ActivationBackward { .. } => OpKind::ActivationBackward,
1857 Op::FakeQuantizeBackward { .. } => OpKind::FakeQuantizeBackward,
1858 Op::ComplexNormSq => OpKind::ComplexNormSq,
1859 Op::ComplexNormSqBackward => OpKind::ComplexNormSqBackward,
1860 Op::Conjugate => OpKind::Conjugate,
1861 Op::LayerNormBackwardInput { .. } => OpKind::LayerNormBackwardInput,
1862 Op::LayerNormBackwardGamma { .. } => OpKind::LayerNormBackwardGamma,
1863 Op::RmsNormBackwardInput { .. } => OpKind::RmsNormBackwardInput,
1864 Op::RmsNormBackwardGamma { .. } => OpKind::RmsNormBackwardGamma,
1865 Op::RmsNormBackwardBeta { .. } => OpKind::RmsNormBackwardBeta,
1866 Op::RopeBackward { .. } => OpKind::RopeBackward,
1867 Op::GroupNormBackwardInput { .. } => OpKind::GroupNormBackwardInput,
1868 Op::GroupNormBackwardGamma { .. } => OpKind::GroupNormBackwardGamma,
1869 Op::GroupNormBackwardBeta { .. } => OpKind::GroupNormBackwardBeta,
1870 Op::BatchNormInferenceBackwardInput { .. } => OpKind::BatchNormInferenceBackwardInput,
1871 Op::BatchNormInferenceBackwardGamma { .. } => OpKind::BatchNormInferenceBackwardGamma,
1872 Op::BatchNormInferenceBackwardBeta => OpKind::BatchNormInferenceBackwardBeta,
1873 Op::CumsumBackward { .. } => OpKind::CumsumBackward,
1874 Op::GatherBackward { .. } => OpKind::GatherBackward,
1875 Op::MaxPool2dBackward { .. } => OpKind::MaxPool2dBackward,
1876 Op::Conv2dBackwardInput { .. } => OpKind::Conv2dBackwardInput,
1877 Op::Conv2dBackwardWeight { .. } => OpKind::Conv2dBackwardWeight,
1878 Op::SoftmaxCrossEntropyWithLogits => OpKind::SoftmaxCrossEntropyWithLogits,
1879 Op::SoftmaxCrossEntropyBackward => OpKind::SoftmaxCrossEntropyBackward,
1880 Op::AttentionBackward { .. } => OpKind::AttentionBackward,
1881 Op::GroupedMatMul => OpKind::GroupedMatMul,
1882 Op::DequantGroupedMatMul { .. } => OpKind::DequantGroupedMatMul,
1883 Op::DequantMoEWeights { .. } => OpKind::DequantMoEWeights,
1884 Op::ScatterAdd => OpKind::ScatterAdd,
1885 Op::LoraMatMul { .. } => OpKind::LoraMatMul,
1886 Op::DequantMatMul { .. } => OpKind::DequantMatMul,
1887 Op::QMatMul { .. } => OpKind::QMatMul,
1888 Op::QConv2d { .. } => OpKind::QConv2d,
1889 Op::SelectiveScan { .. } => OpKind::SelectiveScan,
1890 Op::GatedDeltaNet { .. } => OpKind::GatedDeltaNet,
1891 Op::Lstm { .. } => OpKind::Lstm,
1892 Op::Gru { .. } => OpKind::Gru,
1893 Op::Rnn { .. } => OpKind::Rnn,
1894 Op::Mamba2 { .. } => OpKind::Mamba2,
1895 Op::FusedSwiGLU { .. } => OpKind::FusedSwiGLU,
1896 Op::FusedMatMulBiasAct { .. } => OpKind::FusedMatMulBiasAct,
1897 Op::FusedResidualLN { .. } => OpKind::FusedResidualLN,
1898 Op::FusedResidualRmsNorm { .. } => OpKind::FusedResidualRmsNorm,
1899 Op::FusedAttentionBlock { .. } => OpKind::FusedAttentionBlock,
1900 Op::FusedTransformerLayer { .. } => OpKind::FusedTransformerLayer,
1901 Op::If { .. } => OpKind::If,
1902 Op::While { .. } => OpKind::While,
1903 Op::Scan { .. } => OpKind::Scan,
1904 Op::ScanBackward { .. } => OpKind::ScanBackward,
1905 Op::ScanBackwardXs { .. } => OpKind::ScanBackwardXs,
1906 Op::GaussianSplatRender { .. } => OpKind::GaussianSplatRender,
1907 Op::GaussianSplatRenderBackward { .. } => OpKind::GaussianSplatRenderBackward,
1908 Op::GaussianSplatPrepare { .. } => OpKind::GaussianSplatPrepare,
1909 Op::GaussianSplatRasterize { .. } => OpKind::GaussianSplatRasterize,
1910 Op::Custom { .. } => OpKind::Custom,
1911 Op::CustomFn { .. } => OpKind::CustomFn,
1912 Op::Fft { .. } => OpKind::Fft,
1913 Op::FftButterflyStage { .. } => OpKind::FftButterflyStage,
1914 Op::LogMel => OpKind::LogMel,
1915 Op::LogMelBackward => OpKind::LogMelBackward,
1916 Op::WelchPeaks { .. } => OpKind::WelchPeaks,
1917 }
1918 }
1919
1920 /// True if this op is element-wise (same shape in, same shape out).
1921 /// Element-wise ops are prime fusion candidates.
1922 pub fn is_elementwise(&self) -> bool {
1923 matches!(
1924 self,
1925 Op::Activation(_)
1926 | Op::Cast { .. }
1927 | Op::StopGradient
1928 | Op::Binary(_)
1929 | Op::Compare(_)
1930 | Op::Where
1931 | Op::ElementwiseRegion { .. }
1932 | Op::BatchElementwiseRegion { .. }
1933 )
1934 }
1935
1936 /// True if this op may appear in a [`Op::TransformRegion`] chain.
1937 pub fn is_transform_eligible(&self) -> bool {
1938 matches!(self, Op::ResizeNearest2x)
1939 }
1940
1941 /// True if this op is a BLAS/compute-intensive op that forms a fusion boundary.
1942 pub fn is_blas(&self) -> bool {
1943 matches!(
1944 self,
1945 Op::MatMul
1946 | Op::DotGeneral { .. }
1947 | Op::DenseSolve
1948 | Op::BatchedDenseSolve
1949 | Op::Conv { .. }
1950 | Op::Im2Col { .. }
1951 | Op::ConvTranspose2d { .. }
1952 | Op::FusedMatMulBiasAct { .. }
1953 | Op::GroupedMatMul
1954 | Op::DequantGroupedMatMul { .. }
1955 | Op::DequantMoEWeights { .. }
1956 | Op::LoraMatMul { .. }
1957 | Op::DequantMatMul { .. }
1958 | Op::QMatMul { .. }
1959 | Op::QConv2d { .. }
1960 )
1961 }
1962
1963 /// True if element-wise fusion must not span across this op.
1964 pub fn is_fusion_boundary(&self) -> bool {
1965 self.is_blas()
1966 || matches!(
1967 self,
1968 Op::GaussianSplatRender { .. }
1969 | Op::GaussianSplatRenderBackward { .. }
1970 | Op::GaussianSplatPrepare { .. }
1971 | Op::GaussianSplatRasterize { .. }
1972 )
1973 }
1974
1975 /// True if this op is a reduction (drives loop iteration in fused kernels).
1976 pub fn is_reduction(&self) -> bool {
1977 matches!(
1978 self,
1979 Op::Reduce { .. } | Op::Softmax { .. } | Op::TopK { .. }
1980 )
1981 }
1982
1983 /// Number of tensor inputs this op expects.
1984 pub fn num_inputs(&self) -> usize {
1985 match self {
1986 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0,
1987 Op::Activation(_)
1988 | Op::Cast { .. }
1989 | Op::StopGradient
1990 | Op::Reshape { .. }
1991 | Op::Quantize { .. }
1992 | Op::Dequantize { .. }
1993 | Op::Transpose { .. }
1994 | Op::Narrow { .. }
1995 | Op::Expand { .. }
1996 | Op::Reduce { .. }
1997 | Op::Softmax { .. }
1998 | Op::FusedSwiGLU { .. }
1999 | Op::TopK { .. }
2000 | Op::Cumsum { .. }
2001 | Op::ArgMax { .. }
2002 | Op::ArgMin { .. }
2003 | Op::Sample { .. }
2004 | Op::ResizeNearest2x => 1,
2005 Op::RngNormal { .. } | Op::RngUniform { .. } => 0, // 0 or 1 — see verify
2006 // EMA / Fixed scale modes carry a state tensor as a 2nd input;
2007 // PerBatch (default) doesn't need one.
2008 Op::FakeQuantize { scale_mode, .. } => match scale_mode {
2009 ScaleMode::PerBatch => 1,
2010 ScaleMode::EMA { .. } | ScaleMode::Fixed => 2,
2011 },
2012 Op::FakeQuantizeLSQ { .. } => 2, // x, scale (learned param)
2013 Op::FakeQuantizeLSQBackwardX { .. } | Op::FakeQuantizeLSQBackwardScale { .. } => 3, // x, scale, dy
2014 Op::Binary(_) | Op::Compare(_) | Op::Gather { .. } | Op::MatMul | Op::ScatterAdd => 2,
2015 Op::GroupedMatMul => 3, // input, weight, expert_idx
2016 Op::DequantGroupedMatMul { .. } => 3, // input, packed_w, expert_idx
2017 Op::DequantMoEWeights { .. } => 1, // packed_w
2018 Op::LoraMatMul { .. } => 4, // x, w, a, b
2019 // x, w_q, scale, zp — or x, packed_w_bytes for GGUF
2020 // schemes (their scales/mins live inside the packed bytes,
2021 // see `QuantScheme::is_gguf`).
2022 Op::DequantMatMul { scheme } => {
2023 if scheme.is_gguf() {
2024 2
2025 } else {
2026 4
2027 }
2028 }
2029 Op::QMatMul { .. } => 3, // x, w, bias
2030 Op::QConv2d { .. } => 3, // x, w, bias
2031 Op::SelectiveScan { .. } => 5, // x, delta, a, b, c
2032 Op::GatedDeltaNet { carry_state, .. } if *carry_state => 6, // + state in/out
2033 Op::GatedDeltaNet { .. } => 5, // q, k, v, g, beta
2034 Op::Lstm { carry, .. } => {
2035 if *carry { 6 } else { 4 } // x, w_ih, w_hh, bias (+ h0, c0)
2036 }
2037 Op::Gru { carry, .. } => {
2038 if *carry { 6 } else { 5 } // x, w_ih, w_hh, b_ih, b_hh (+ h0)
2039 }
2040 Op::Rnn { carry, .. } => {
2041 if *carry { 5 } else { 4 } // x, w_ih, w_hh, bias (+ h0)
2042 }
2043 Op::Mamba2 { .. } => 5, // x, dt, a, b, c
2044 Op::Where => 3, // cond, on_true, on_false
2045 Op::Attention { mask_kind, .. } => match mask_kind {
2046 MaskKind::Custom | MaskKind::Bias => 4, // Q, K, V, mask
2047 _ => 3, // Q, K, V (mask synthesized in-kernel)
2048 },
2049 Op::AttentionBackward { mask_kind, .. } => match mask_kind {
2050 MaskKind::Custom | MaskKind::Bias => 5, // q, k, v, dy, mask
2051 _ => 4, // q, k, v, dy
2052 },
2053 Op::Rope { .. } => 3, // x, cos, sin
2054 Op::AxialRope2d { .. } => 1,
2055 Op::LayerNorm { .. }
2056 | Op::LayerNorm2d { .. }
2057 | Op::GroupNorm { .. }
2058 | Op::RmsNorm { .. } => 3, // input, gamma, beta
2059 Op::BatchNormInference { .. } => 5, // x, gamma, beta, mean, var
2060 Op::FusedMatMulBiasAct { .. } => 3, // input, weight, bias
2061 Op::FusedResidualLN { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
2062 Op::FusedResidualLN {
2063 has_bias: false, ..
2064 } => 4, // x, residual, gamma, beta
2065 Op::FusedResidualRmsNorm { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
2066 Op::FusedResidualRmsNorm {
2067 has_bias: false, ..
2068 } => 4, // x, residual, gamma, beta
2069 Op::Conv { .. } | Op::ConvTranspose2d { .. } => 2, // input, weight (bias via Add)
2070 Op::Im2Col { .. } => 1,
2071 Op::Pool { .. } => 1,
2072 Op::ReluBackward => 2, // x, dy
2073 Op::ActivationBackward { .. } => 2, // x, dy
2074 Op::FakeQuantizeBackward { .. } => 2, // x, dy
2075 Op::ComplexNormSq => 1, // z (C64)
2076 Op::ComplexNormSqBackward => 2, // z, g
2077 Op::Conjugate => 1, // z (C64)
2078 Op::LayerNormBackwardInput { .. } => 3, // x, gamma, dy
2079 Op::LayerNormBackwardGamma { .. } => 2, // x, dy
2080 Op::RmsNormBackwardInput { .. } => 4, // x, gamma, beta, dy
2081 Op::RmsNormBackwardGamma { .. } => 4,
2082 Op::RmsNormBackwardBeta { .. } => 4,
2083 Op::RopeBackward { .. } => 3, // dy, cos, sin
2084 Op::GroupNormBackwardInput { .. } => 4, // x, gamma, beta, dy
2085 Op::GroupNormBackwardGamma { .. } => 2, // x, dy
2086 Op::GroupNormBackwardBeta { .. } => 2,
2087 Op::BatchNormInferenceBackwardInput { .. } => 5, // x, gamma, mean, var, dy
2088 Op::BatchNormInferenceBackwardGamma { .. } => 4, // x, mean, var, dy
2089 Op::BatchNormInferenceBackwardBeta => 1, // dy
2090 Op::CumsumBackward { .. } => 1, // dy
2091 Op::GatherBackward { .. } => 2, // dy, indices
2092 Op::MaxPool2dBackward { .. } => 2, // x, dy
2093 Op::Conv2dBackwardInput { .. } => 2, // dy, w
2094 Op::Conv2dBackwardWeight { .. } => 2, // x, dy
2095 Op::SoftmaxCrossEntropyWithLogits => 2, // logits, labels
2096 Op::SoftmaxCrossEntropyBackward => 3, // logits, labels, d_loss
2097 Op::Concat { .. } => 0, // variadic — checked at graph level
2098 Op::DotGeneral { .. } => 2,
2099 Op::DenseSolve => 2, // A, b
2100 Op::BatchedDenseSolve => 2, // A [B,N,N], b [B,N] or [B,N,K]
2101 Op::FusedAttentionBlock {
2102 has_bias, has_rope, ..
2103 } => 4 + if *has_bias { 2 } else { 0 } + if *has_rope { 2 } else { 0 },
2104 Op::If { .. } => 1, // predicate (captures handled separately)
2105 Op::While { .. } => 0, // variadic loop-carried; checked at graph level
2106 Op::Scan {
2107 num_bcast, num_xs, ..
2108 } => 1 + *num_bcast as usize + *num_xs as usize,
2109 Op::ScanBackward { num_xs, .. } => 3 + *num_xs as usize, // init, trajectory, upstream, xs_0..
2110 Op::ScanBackwardXs { num_xs, .. } => 3 + *num_xs as usize, // same as ScanBackward
2111 Op::GaussianSplatRender { .. } => 7,
2112 Op::GaussianSplatRenderBackward { .. } => 8,
2113 Op::GaussianSplatPrepare { .. } => 7,
2114 Op::GaussianSplatRasterize { .. } => 2,
2115 Op::FusedTransformerLayer { has_bias, .. } => {
2116 // hidden + qkv_w + out_w + ln1_g + ln1_b + fc1_w + fc2_w + ln2_g + ln2_b + mask = 10
2117 // bias variant adds: qkv_b + out_b + fc1_b + fc2_b = 4 more
2118 10 + if *has_bias { 4 } else { 0 }
2119 }
2120 Op::ElementwiseRegion { num_inputs, .. } => *num_inputs as usize,
2121 Op::TransformRegion { num_inputs, .. } => *num_inputs as usize,
2122 Op::BatchElementwiseRegion {
2123 num_batch_inputs, ..
2124 } => *num_batch_inputs as usize,
2125 Op::Custom { num_inputs, .. } => *num_inputs as usize,
2126 Op::CustomFn { num_inputs, .. } => *num_inputs as usize,
2127 Op::Fft { .. } => 1,
2128 Op::FftButterflyStage { .. } => 5,
2129 Op::LogMel => 2,
2130 Op::LogMelBackward => 3,
2131 Op::WelchPeaks { .. } => 1,
2132 }
2133 }
2134}
2135
2136impl std::fmt::Display for Op {
2137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2138 match self {
2139 Op::Input { name } => write!(f, "input(\"{name}\")"),
2140 Op::Param { name } => write!(f, "param(\"{name}\")"),
2141 Op::Constant { data } => write!(f, "const({}B)", data.len()),
2142 Op::Activation(a) => write!(f, "{a:?}"),
2143 Op::Quantize { axis, scales, .. } => match axis {
2144 None => write!(f, "quantize(s={})", scales[0]),
2145 Some(d) => write!(f, "quantize(axis={d},nch={})", scales.len()),
2146 },
2147 Op::Dequantize { axis, scales, .. } => match axis {
2148 None => write!(f, "dequantize(s={})", scales[0]),
2149 Some(d) => write!(f, "dequantize(axis={d},nch={})", scales.len()),
2150 },
2151 Op::FakeQuantize {
2152 bits,
2153 axis,
2154 ste,
2155 scale_mode,
2156 } => match axis {
2157 None => write!(
2158 f,
2159 "fake_quant(bits={bits},ste={ste:?},scale={scale_mode:?})"
2160 ),
2161 Some(d) => write!(
2162 f,
2163 "fake_quant(bits={bits},axis={d},ste={ste:?},scale={scale_mode:?})"
2164 ),
2165 },
2166 Op::FakeQuantizeLSQ { bits, axis } => match axis {
2167 None => write!(f, "fake_quant_lsq(bits={bits})"),
2168 Some(d) => write!(f, "fake_quant_lsq(bits={bits},axis={d})"),
2169 },
2170 Op::FakeQuantizeLSQBackwardX { bits, .. } => {
2171 write!(f, "fake_quant_lsq_bwd_x(bits={bits})")
2172 }
2173 Op::FakeQuantizeLSQBackwardScale { bits, .. } => {
2174 write!(f, "fake_quant_lsq_bwd_s(bits={bits})")
2175 }
2176 Op::Cast { to } => write!(f, "cast({to})"),
2177 Op::StopGradient => write!(f, "stop_gradient"),
2178 Op::Binary(op) => write!(f, "{op:?}"),
2179 Op::Compare(op) => write!(f, "{op:?}"),
2180 Op::Where => write!(f, "where"),
2181 Op::MatMul => write!(f, "matmul"),
2182 Op::DotGeneral { .. } => write!(f, "dot_general"),
2183 Op::DenseSolve => write!(f, "dense_solve"),
2184 Op::BatchedDenseSolve => write!(f, "batched_dense_solve"),
2185 Op::LayerNorm { eps, .. } => write!(f, "layer_norm(eps={eps})"),
2186 Op::GroupNorm { num_groups, eps } => {
2187 write!(f, "group_norm(groups={num_groups},eps={eps})")
2188 }
2189 Op::BatchNormInference { eps } => write!(f, "batch_norm_inference(eps={eps})"),
2190 Op::ResizeNearest2x => write!(f, "resize_nearest_2x"),
2191 Op::RmsNorm { eps, .. } => write!(f, "rms_norm(eps={eps})"),
2192 Op::Attention {
2193 num_heads,
2194 head_dim,
2195 mask_kind,
2196 score_scale,
2197 attn_logit_softcap,
2198 } => {
2199 let mut s = match mask_kind {
2200 MaskKind::Custom => format!("attention(h={num_heads},d={head_dim})"),
2201 MaskKind::None => format!("attention(h={num_heads},d={head_dim},nomask)"),
2202 MaskKind::Causal => format!("attention(h={num_heads},d={head_dim},causal)"),
2203 MaskKind::SlidingWindow(w) => {
2204 format!("attention(h={num_heads},d={head_dim},sw={w})")
2205 }
2206 MaskKind::Bias => format!("attention(h={num_heads},d={head_dim},bias)"),
2207 };
2208 if let Some(sc) = score_scale {
2209 s.push_str(&format!(",scale={sc}"));
2210 }
2211 if let Some(cap) = attn_logit_softcap {
2212 s.push_str(&format!(",softcap={cap}"));
2213 }
2214 write!(f, "{s}")
2215 }
2216 Op::Rope { head_dim, n_rot } => write!(f, "rope(d={head_dim}, n_rot={n_rot})"),
2217 Op::AxialRope2d {
2218 end_x,
2219 end_y,
2220 head_dim,
2221 num_heads,
2222 theta,
2223 repeat_factor,
2224 } => write!(
2225 f,
2226 "axial_rope2d({end_x}x{end_y},h={num_heads},d={head_dim},θ={theta},r={repeat_factor})"
2227 ),
2228 Op::Reshape { new_shape } => write!(f, "reshape({new_shape:?})"),
2229 Op::Transpose { perm } => write!(f, "transpose({perm:?})"),
2230 Op::Narrow { axis, start, len } => write!(f, "narrow({axis},{start},{len})"),
2231 Op::Concat { axis } => write!(f, "concat(axis={axis})"),
2232 Op::Expand { .. } => write!(f, "expand"),
2233 Op::Gather { axis } => write!(f, "gather(axis={axis})"),
2234 Op::Reduce { op, axes, .. } => write!(f, "reduce_{op:?}({axes:?})"),
2235 Op::Softmax { axis } => write!(f, "softmax(axis={axis})"),
2236 Op::Cumsum { axis, exclusive } => {
2237 if *exclusive {
2238 write!(f, "cumsum(axis={axis},excl)")
2239 } else {
2240 write!(f, "cumsum(axis={axis})")
2241 }
2242 }
2243 Op::ArgMax { axis, keep_dim } => write!(f, "argmax(axis={axis},keep={keep_dim})"),
2244 Op::ArgMin { axis, keep_dim } => write!(f, "argmin(axis={axis},keep={keep_dim})"),
2245 Op::Sample {
2246 top_k,
2247 top_p,
2248 temperature,
2249 ..
2250 } => {
2251 write!(f, "sample(t={temperature}")?;
2252 if *top_k > 0 {
2253 write!(f, ",k={top_k}")?;
2254 }
2255 if *top_p < 1.0 {
2256 write!(f, ",p={top_p}")?;
2257 }
2258 write!(f, ")")
2259 }
2260 Op::RngNormal {
2261 mean,
2262 scale,
2263 key,
2264 op_seed,
2265 } => {
2266 write!(f, "rng_normal({mean},{scale},key={key}")?;
2267 if let Some(s) = op_seed {
2268 write!(f, ",seed={s}")?;
2269 }
2270 write!(f, ")")
2271 }
2272 Op::RngUniform {
2273 low,
2274 high,
2275 key,
2276 op_seed,
2277 } => {
2278 write!(f, "rng_uniform({low},{high},key={key}")?;
2279 if let Some(s) = op_seed {
2280 write!(f, ",seed={s}")?;
2281 }
2282 write!(f, ")")
2283 }
2284 Op::TopK { k } => write!(f, "topk(k={k})"),
2285 Op::GroupedMatMul => write!(f, "grouped_matmul"),
2286 Op::DequantGroupedMatMul { scheme } => {
2287 write!(f, "dequant_grouped_matmul({scheme})")
2288 }
2289 Op::DequantMoEWeights { scheme } => write!(f, "dequant_moe_weights({scheme})"),
2290 Op::LoraMatMul { scale } => write!(f, "lora_matmul(scale={scale})"),
2291 Op::DequantMatMul { scheme } => write!(f, "dequant_matmul({scheme})"),
2292 Op::QMatMul {
2293 x_zp,
2294 w_zp,
2295 out_zp,
2296 mult,
2297 } => write!(
2298 f,
2299 "q_matmul(x_zp={x_zp},w_zp={w_zp},out_zp={out_zp},mult={mult})"
2300 ),
2301 Op::QConv2d { kernel_size, .. } => write!(f, "q_conv2d({kernel_size:?})"),
2302 Op::SelectiveScan { state_size } => write!(f, "ssm_scan(n={state_size})"),
2303 Op::GatedDeltaNet {
2304 state_size,
2305 carry_state,
2306 } => {
2307 if *carry_state {
2308 write!(f, "gated_delta_net(n={state_size},carry)")
2309 } else {
2310 write!(f, "gated_delta_net(n={state_size})")
2311 }
2312 }
2313 Op::Lstm {
2314 hidden_size,
2315 num_layers,
2316 bidirectional,
2317 carry,
2318 } => {
2319 let dir = if *bidirectional { "bi" } else { "uni" };
2320 let c = if *carry { ",carry" } else { "" };
2321 write!(f, "lstm(h={hidden_size},layers={num_layers},{dir}{c})")
2322 }
2323 Op::Gru {
2324 hidden_size,
2325 num_layers,
2326 bidirectional,
2327 carry,
2328 } => {
2329 let dir = if *bidirectional { "bi" } else { "uni" };
2330 let c = if *carry { ",carry" } else { "" };
2331 write!(f, "gru(h={hidden_size},layers={num_layers},{dir}{c})")
2332 }
2333 Op::Rnn {
2334 hidden_size,
2335 num_layers,
2336 bidirectional,
2337 carry,
2338 relu,
2339 } => {
2340 let dir = if *bidirectional { "bi" } else { "uni" };
2341 let act = if *relu { "relu" } else { "tanh" };
2342 let c = if *carry { ",carry" } else { "" };
2343 write!(f, "rnn(h={hidden_size},layers={num_layers},{dir},{act}{c})")
2344 }
2345 Op::Mamba2 {
2346 head_dim,
2347 state_size,
2348 } => write!(f, "mamba2(p={head_dim},n={state_size})"),
2349 Op::ScatterAdd => write!(f, "scatter_add"),
2350 Op::Conv { kernel_size, .. } => write!(f, "conv2d({kernel_size:?})"),
2351 Op::Im2Col { kernel_size, .. } => write!(f, "im2col({kernel_size:?})"),
2352 Op::ConvTranspose2d { kernel_size, .. } => {
2353 write!(f, "conv_transpose2d({kernel_size:?})")
2354 }
2355 Op::LayerNorm2d { eps } => write!(f, "layer_norm2d(eps={eps})"),
2356 Op::Pool {
2357 kind, kernel_size, ..
2358 } => write!(f, "pool_{kind:?}({kernel_size:?})"),
2359 Op::ReluBackward => write!(f, "relu_backward"),
2360 Op::ActivationBackward { kind } => write!(f, "{kind:?}_backward"),
2361 Op::ComplexNormSq => write!(f, "complex_norm_sq"),
2362 Op::ComplexNormSqBackward => write!(f, "complex_norm_sq_backward"),
2363 Op::Conjugate => write!(f, "conjugate"),
2364 Op::FakeQuantizeBackward { bits, ste, .. } => {
2365 write!(f, "fake_quant_backward(bits={bits},ste={ste:?})")
2366 }
2367 Op::MaxPool2dBackward { kernel_size, .. } => {
2368 write!(f, "maxpool2d_backward({kernel_size:?})")
2369 }
2370 Op::Conv2dBackwardInput { kernel_size, .. } => {
2371 write!(f, "conv2d_backward_input({kernel_size:?})")
2372 }
2373 Op::Conv2dBackwardWeight { kernel_size, .. } => {
2374 write!(f, "conv2d_backward_weight({kernel_size:?})")
2375 }
2376 Op::SoftmaxCrossEntropyWithLogits => write!(f, "sce_with_logits"),
2377 Op::SoftmaxCrossEntropyBackward => write!(f, "sce_backward"),
2378 Op::AttentionBackward {
2379 num_heads,
2380 head_dim,
2381 mask_kind,
2382 wrt,
2383 } => match mask_kind {
2384 MaskKind::None => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},nomask)"),
2385 MaskKind::Causal => {
2386 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},causal)")
2387 }
2388 MaskKind::SlidingWindow(w) => {
2389 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},sw={w})")
2390 }
2391 MaskKind::Custom => {
2392 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},custom)")
2393 }
2394 MaskKind::Bias => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},bias)"),
2395 },
2396 Op::FusedMatMulBiasAct { activation } => {
2397 write!(f, "fused_mm_bias")?;
2398 if let Some(a) = activation {
2399 write!(f, "_{a:?}")?;
2400 }
2401 Ok(())
2402 }
2403 Op::FusedResidualLN { has_bias, eps } => {
2404 write!(f, "fused_residual")?;
2405 if *has_bias {
2406 write!(f, "_bias")?;
2407 }
2408 write!(f, "_ln(eps={eps})")
2409 }
2410 Op::FusedResidualRmsNorm { has_bias, eps } => {
2411 write!(f, "fused_residual")?;
2412 if *has_bias {
2413 write!(f, "_bias")?;
2414 }
2415 write!(f, "_rms(eps={eps})")
2416 }
2417 Op::FusedSwiGLU {
2418 cast_to,
2419 gate_first,
2420 } => {
2421 let mut s = match cast_to {
2422 Some(dt) => format!("fused_swiglu(cast={dt}"),
2423 None => "fused_swiglu(".to_string(),
2424 };
2425 if *gate_first {
2426 s.push_str(",gate_first");
2427 }
2428 s.push(')');
2429 write!(f, "{s}")
2430 }
2431 Op::FusedAttentionBlock {
2432 num_heads,
2433 head_dim,
2434 has_bias,
2435 has_rope,
2436 } => {
2437 write!(f, "fused_attn(h={num_heads},d={head_dim}")?;
2438 if *has_bias {
2439 write!(f, ",bias")?;
2440 }
2441 if *has_rope {
2442 write!(f, ",rope")?;
2443 }
2444 write!(f, ")")
2445 }
2446 Op::If { .. } => write!(f, "if(...)"),
2447 Op::While { max_iterations, .. } => match max_iterations {
2448 Some(n) => write!(f, "while(...max={n})"),
2449 None => write!(f, "while(...)"),
2450 },
2451 Op::Scan {
2452 length,
2453 save_trajectory,
2454 num_xs,
2455 ..
2456 } => {
2457 let traj = if *save_trajectory { ",traj" } else { "" };
2458 let xs = if *num_xs > 0 {
2459 format!(",xs={}", num_xs)
2460 } else {
2461 String::new()
2462 };
2463 write!(f, "scan(len={length}{xs}{traj})")
2464 }
2465 Op::ScanBackward {
2466 length,
2467 save_trajectory,
2468 num_xs,
2469 ..
2470 } => {
2471 let traj = if *save_trajectory { ",traj" } else { "" };
2472 let xs = if *num_xs > 0 {
2473 format!(",xs={}", num_xs)
2474 } else {
2475 String::new()
2476 };
2477 write!(f, "scan_bwd(len={length}{xs}{traj})")
2478 }
2479 Op::ScanBackwardXs {
2480 length,
2481 save_trajectory,
2482 num_xs,
2483 xs_idx,
2484 ..
2485 } => {
2486 let traj = if *save_trajectory { ",traj" } else { "" };
2487 write!(
2488 f,
2489 "scan_bwd_xs(len={length},xs={num_xs},idx={xs_idx}{traj})"
2490 )
2491 }
2492 Op::FusedTransformerLayer {
2493 num_heads,
2494 head_dim,
2495 intermediate_size,
2496 has_bias,
2497 ..
2498 } => {
2499 write!(
2500 f,
2501 "fused_layer(h={num_heads},d={head_dim},int={intermediate_size}"
2502 )?;
2503 if *has_bias {
2504 write!(f, ",bias")?;
2505 }
2506 write!(f, ")")
2507 }
2508 Op::ElementwiseRegion {
2509 chain,
2510 num_inputs,
2511 scalar_input_mask,
2512 input_modulus: _,
2513 prologue,
2514 prologue_input: _,
2515 } => {
2516 let pro = match prologue {
2517 RegionPrologue::None => "",
2518 RegionPrologue::ResizeNearest2x => ",prologue=resize2x",
2519 };
2520 if *scalar_input_mask != 0 {
2521 write!(
2522 f,
2523 "ew_region(in={num_inputs},steps={},scalar_mask=0x{:x}{pro})",
2524 chain.len(),
2525 scalar_input_mask
2526 )
2527 } else {
2528 write!(f, "ew_region(in={num_inputs},steps={}{pro})", chain.len())
2529 }
2530 }
2531 Op::TransformRegion { steps, num_inputs } => {
2532 write!(f, "transform_region(in={num_inputs},steps={})", steps.len())
2533 }
2534 Op::BatchElementwiseRegion {
2535 chain,
2536 num_batch_inputs,
2537 scalar_input_mask,
2538 prologue,
2539 ..
2540 } => write!(
2541 f,
2542 "batch_ew_region(batch={num_batch_inputs},steps={},mask=0x{:x},prologue={prologue:?})",
2543 chain.len(),
2544 scalar_input_mask
2545 ),
2546 Op::LayerNormBackwardInput { eps, .. } => {
2547 write!(f, "layer_norm_backward_input(eps={eps})")
2548 }
2549 Op::LayerNormBackwardGamma { eps, .. } => {
2550 write!(f, "layer_norm_backward_gamma(eps={eps})")
2551 }
2552 Op::RmsNormBackwardInput { eps, .. } => write!(f, "rms_norm_backward_input(eps={eps})"),
2553 Op::RmsNormBackwardGamma { eps, .. } => write!(f, "rms_norm_backward_gamma(eps={eps})"),
2554 Op::RmsNormBackwardBeta { eps, .. } => write!(f, "rms_norm_backward_beta(eps={eps})"),
2555 Op::RopeBackward { head_dim, n_rot } => {
2556 write!(f, "rope_backward(d={head_dim},n_rot={n_rot})")
2557 }
2558 Op::GroupNormBackwardInput { num_groups, eps } => {
2559 write!(f, "group_norm_backward_input(g={num_groups},eps={eps})")
2560 }
2561 Op::GroupNormBackwardGamma { num_groups, eps } => {
2562 write!(f, "group_norm_backward_gamma(g={num_groups},eps={eps})")
2563 }
2564 Op::GroupNormBackwardBeta { num_groups, eps } => {
2565 write!(f, "group_norm_backward_beta(g={num_groups},eps={eps})")
2566 }
2567 Op::BatchNormInferenceBackwardInput { eps } => {
2568 write!(f, "batch_norm_inference_backward_input(eps={eps})")
2569 }
2570 Op::BatchNormInferenceBackwardGamma { eps } => {
2571 write!(f, "batch_norm_inference_backward_gamma(eps={eps})")
2572 }
2573 Op::BatchNormInferenceBackwardBeta => {
2574 write!(f, "batch_norm_inference_backward_beta")
2575 }
2576 Op::CumsumBackward { axis, exclusive } => {
2577 write!(f, "cumsum_backward(axis={axis},exclusive={exclusive})")
2578 }
2579 Op::GatherBackward { axis } => write!(f, "gather_backward(axis={axis})"),
2580 Op::GaussianSplatRender {
2581 width,
2582 height,
2583 tile_size,
2584 radius_scale,
2585 alpha_cutoff,
2586 max_splat_steps,
2587 transmittance_threshold,
2588 max_list_entries,
2589 } => write!(
2590 f,
2591 "gaussian_splat_render({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2592 ),
2593 Op::GaussianSplatRenderBackward {
2594 width,
2595 height,
2596 loss_grad_clip,
2597 sh_band,
2598 ..
2599 } => write!(
2600 f,
2601 "gaussian_splat_render_bwd({width}x{height},clip={loss_grad_clip},sh={sh_band})"
2602 ),
2603 Op::GaussianSplatPrepare {
2604 width,
2605 height,
2606 tile_size,
2607 radius_scale,
2608 alpha_cutoff,
2609 max_splat_steps,
2610 transmittance_threshold,
2611 max_list_entries,
2612 ..
2613 } => write!(
2614 f,
2615 "gaussian_splat_prepare({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2616 ),
2617 Op::GaussianSplatRasterize {
2618 width,
2619 height,
2620 tile_size,
2621 alpha_cutoff,
2622 max_splat_steps,
2623 transmittance_threshold,
2624 max_list_entries,
2625 ..
2626 } => write!(
2627 f,
2628 "gaussian_splat_rasterize({width}x{height},tile={tile_size},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2629 ),
2630 Op::Custom {
2631 name,
2632 num_inputs,
2633 attrs,
2634 } => write!(f, "custom({name},in={num_inputs},attrs={}B)", attrs.len()),
2635 Op::CustomFn {
2636 num_inputs,
2637 vjp_body,
2638 jvp_body,
2639 ..
2640 } => {
2641 let v = if vjp_body.is_some() { ",vjp" } else { "" };
2642 let j = if jvp_body.is_some() { ",jvp" } else { "" };
2643 write!(f, "custom_fn(in={num_inputs}{v}{j})")
2644 }
2645 Op::Fft { inverse, norm } => {
2646 write!(f, "fft(inverse={inverse}, norm={norm:?})")
2647 }
2648 Op::FftButterflyStage { stage, n_fft } => {
2649 write!(f, "fft_butterfly_stage(stage={stage}, n_fft={n_fft})")
2650 }
2651 Op::LogMel => write!(f, "log_mel()"),
2652 Op::LogMelBackward => write!(f, "log_mel_backward()"),
2653 Op::WelchPeaks { k, n_segments } => {
2654 write!(f, "welch_peaks(k={k}, n_segments={n_segments})")
2655 }
2656 }
2657 }
2658}