rlx_ir/op.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Operation types — every tensor op in the RLX IR.
17//!
18//! Designed for pattern-matching fusion: ops are grouped by category so
19//! fusion passes can reason about them structurally.
20
21use crate::DType;
22
/// Unary element-wise activation functions.
///
/// Each variant maps one input element to one output element (same shape).
/// Backward rules are spelled out on the variants that document them.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
    Gelu,
    GeluApprox,
    /// SwiGLU gate activation.
    Silu,
    Relu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
    Rsqrt,
    Neg,
    Abs,
    /// `sin(x)`. Backward: `dx = upstream · cos(x)`.
    Sin,
    /// `cos(x)`. Backward: `dx = -upstream · sin(x)`.
    Cos,
    /// `tan(x)`. Backward: `dx = upstream · sec²(x) = upstream · (1 + tan²(x))`.
    Tan,
    /// `atan(x)`. Backward: `dx = upstream · (1 / (1 + x²))`.
    Atan,
    /// Round to the nearest integer, in f32.
    ///
    /// Forward: `x.round()`. Backward: straight-through estimator (STE) —
    /// treated as identity, so the gradient passes through unchanged.
    /// Useful as a primitive for composing custom quantization schemes
    /// (Mul-by-recip-scale → Round → Clamp → Mul-by-scale = a hand-rolled
    /// FakeQuantize that the elementwise-region pass can fuse into a single
    /// kernel).
    ///
    /// NOTE(review): an earlier version of this doc promised "half-to-even"
    /// rounding, but Rust's `f32::round` breaks ties *away from zero*
    /// (`2.5 → 3.0`, `-2.5 → -3.0`); half-to-even is `f32::round_ties_even`
    /// (`2.5 → 2.0`). Confirm which tie-breaking rule each backend kernel
    /// implements and pin it here.
    Round,
}
55
/// Scale-tracking strategy for `Op::FakeQuantize`. Determines how
/// the per-channel `s[c]` is computed each forward pass.
///
/// * `PerBatch` — recompute `s[c] = max(|x|) / q_max` from the
///   current data on every call. Simple, no extra inputs, but
///   noisy for activations (max-abs jumps batch-to-batch).
///
/// * `EMA { decay }` — keep a running `s[c]` in a state tensor
///   (passed as a second op input). On each call, blend the
///   current per-batch max-abs into the state via
///   `state' = decay·state + (1-decay)·max_abs`. Smooth scale
///   over training, makes activation-QAT actually trainable.
///   Typical `decay = 0.99`.
///
/// * `Fixed` — never recompute. The state tensor's value is
///   used as-is each call (set once at construction or by the
///   caller). Useful when scales are pre-calibrated.
///
/// Equality is the derived field-wise comparison, so `decay` is compared
/// with IEEE-754 `==`. `Hash` is implemented by hand to agree with it.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Default)]
pub enum ScaleMode {
    #[default]
    PerBatch,
    EMA {
        decay: f32,
    },
    Fixed,
}

// `f32` is only `PartialEq`, so `Eq` is asserted by hand. This holds for any
// non-NaN `decay`; a NaN `decay` would violate `Eq` reflexivity
// (`NaN != NaN`), so such a value should never be constructed.
impl Eq for ScaleMode {}

impl std::hash::Hash for ScaleMode {
    /// Hashes a per-variant tag, plus the bit pattern of `decay` for `EMA`.
    ///
    /// The derived `PartialEq` treats `-0.0` and `+0.0` as equal, but their
    /// bit patterns differ. Both are canonicalised to `+0.0` before hashing so
    /// that `a == b` always implies `hash(a) == hash(b)`, as `Hash` requires.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            ScaleMode::PerBatch => state.write_u8(0),
            ScaleMode::EMA { decay } => {
                state.write_u8(1);
                let canonical = if *decay == 0.0 { 0.0_f32 } else { *decay };
                state.write_u32(canonical.to_bits());
            }
            ScaleMode::Fixed => state.write_u8(2),
        }
    }
}
97
/// Gradient surrogate (straight-through-estimator flavour) used in the
/// backward pass of `Op::FakeQuantize`.
///
/// Whatever variant is chosen, the forward computation is unchanged: the
/// discrete `clamp(round(x/s)) * s`. The variant only decides how the
/// gradient with respect to `x` is formed during training.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum SteKind {
    /// `dx = upstream` — the original STE. The round is treated as the
    /// identity on the way back. Simplest option, fine for moderate bit
    /// widths (i4 / i8). This is the default.
    #[default]
    Identity,
    /// `dx = upstream * (|x| ≤ q_max·s)` — like `Identity`, but the gradient
    /// is zeroed wherever the input fell outside the quantization range
    /// (i.e. where the clamp was active). This keeps the optimizer from
    /// pushing weights further into saturation.
    ClippedIdentity,
    /// `dx = upstream * (1 - tanh²(x/s))` — a smooth surrogate for the round
    /// step that attenuates the gradient as `|x|` grows. Often the best
    /// choice at tight bit widths (i2).
    ///
    /// NOTE(review): `tanh²(x/s)` saturates within a few multiples of `s`,
    /// well before `q_max·s` when `q_max` is large. Confirm this is intended
    /// for i4 / i8.
    Tanh,
    /// `dx = upstream * (1 - |x/(q_max·s)|).max(0)` — a piecewise-linear
    /// relative of `Tanh`: cheaper to evaluate and nearly as effective.
    HardTanh,
}
128
/// Element-wise operations that combine two tensors.
///
/// Appears both as a standalone `Op::Binary` node (with broadcasting) and as
/// a step inside fused element-wise chains (`ChainStep::Binary`).
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Max,
    Min,
    Pow,
}
141
/// Element-wise comparison operators.
///
/// A standalone `Op::Compare` node yields a Bool tensor. The same operators
/// are also used by `ChainStep::Compare` inside fused element-wise regions.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CmpOp {
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,
}
153
/// Selects how the attention kernel masks its scores.
///
/// The design follows MAX's `nn/attention/mha_mask.mojo` (#20 in PLAN.md).
/// Instead of making every caller materialize a mask tensor, a single
/// attention kernel branches on the mask kind. That pays off twice:
///
/// 1. **`None`** — a single unpadded sequence needs no mask load and no
///    per-key compare in the inner loop.
/// 2. **`Causal`** — for autoregressive decode the kernel derives the
///    upper-triangular fill directly from `(qi, ki)`, so no `seq²` mask
///    tensor is ever created.
///
/// `Custom` is the pre-existing tensor-driven path: mask values are read
/// from the 4th input.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MaskKind {
    /// No masking: every position attends to every position.
    None,
    /// Causal (autoregressive) masking: position `qi` attends only to keys
    /// `ki <= qi`.
    Causal,
    /// Sliding window: position `qi` attends to keys `ki ∈ [qi - w, qi]`,
    /// where `w` is the payload.
    SlidingWindow(usize),
    /// Mask values are read from the input tensor (the default kind,
    /// matching BERT padding-mask behavior). The tensor has shape
    /// `[batch, key_len]`; `1.0` marks a valid key and values `< 0.5` are
    /// ignored.
    Custom,
    /// Additive per-head, per-query bias tensor of shape
    /// `[batch, num_heads, query_len, key_len]`, added to the
    /// `QK^T · scale` scores before softmax. This lets DETR-style boxRPB and
    /// other learned position biases reuse the fast `Op::Attention` path
    /// instead of decomposing into matmul + add + softmax + matmul.
    Bias,
}
187
/// Identifies which forward input (query, key or value) an
/// [`Op::AttentionBackward`] node differentiates with respect to.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AttentionBwdWrt {
    Query,
    Key,
    Value,
}
196
/// Reduction applied across a set of axes; carried by `Op::Reduce`
/// alongside the axis list and a `keep_dim` flag.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReduceOp {
    Sum,
    Mean,
    Max,
    Min,
    Prod,
}
207
/// Field-free tag for each [`Op`] variant (PLAN L4).
///
/// Used by [`Op::kind`] and by the `Backend::supported_ops` trait method,
/// through which a backend declares the ops it can lower. The
/// `LegalizeForBackend` pass in `rlx-opt` checks the graph against that set
/// and fails the compile when an unsupported op is present, rather than
/// silently falling back.
///
/// Variant order is significant for serialized forms; append new kinds
/// instead of inserting them.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum OpKind {
    // Graph inputs.
    Input,
    Param,
    Constant,
    // Element-wise unary / quantization.
    Activation,
    Cast,
    StopGradient,
    Quantize,
    Dequantize,
    FakeQuantize,
    FakeQuantizeLSQ,
    FakeQuantizeLSQBackwardX,
    FakeQuantizeLSQBackwardScale,
    // Element-wise binary, comparison and fused regions.
    Binary,
    Compare,
    Where,
    ElementwiseRegion,
    /// Fused sampling / geometry chain (an FKL-style transform region).
    TransformRegion,
    /// The same element-wise chain applied over several batch planes
    /// (horizontal fusion).
    BatchElementwiseRegion,
    // Linear algebra.
    MatMul,
    DotGeneral,
    DenseSolve,
    BatchedDenseSolve,
    // Normalization and resampling.
    LayerNorm,
    LayerNorm2d,
    GroupNorm,
    BatchNormInference,
    RmsNorm,
    ResizeNearest2x,
    // Attention and positional encodings.
    Attention,
    Rope,
    AxialRope2d,
    // Shape manipulation.
    Reshape,
    Transpose,
    Narrow,
    Concat,
    Expand,
    Gather,
    // Reductions, sampling, convolution, pooling.
    Reduce,
    Softmax,
    Cumsum,
    TopK,
    Sample,
    Conv,
    Im2Col,
    ConvTranspose2d,
    Pool,
    // Backward ops and complex helpers.
    ReluBackward,
    ActivationBackward,
    FakeQuantizeBackward,
    ComplexNormSq,
    ComplexNormSqBackward,
    Conjugate,
    MaxPool2dBackward,
    Conv2dBackwardInput,
    Conv2dBackwardWeight,
    SoftmaxCrossEntropyWithLogits,
    SoftmaxCrossEntropyBackward,
    AttentionBackward,
    LayerNormBackwardInput,
    LayerNormBackwardGamma,
    RmsNormBackwardInput,
    RmsNormBackwardGamma,
    // NOTE(review): `Op::RmsNorm` is documented as taking only input + gamma;
    // confirm which op actually produces this beta gradient.
    RmsNormBackwardBeta,
    RopeBackward,
    GroupNormBackwardInput,
    GroupNormBackwardGamma,
    GroupNormBackwardBeta,
    BatchNormInferenceBackwardInput,
    BatchNormInferenceBackwardGamma,
    BatchNormInferenceBackwardBeta,
    CumsumBackward,
    GatherBackward,
    // Grouped / quantized / low-rank matmuls.
    GroupedMatMul,
    DequantGroupedMatMul,
    DequantMoEWeights,
    ScatterAdd,
    LoraMatMul,
    DequantMatMul,
    QMatMul,
    QConv2d,
    // Recurrences.
    SelectiveScan,
    GatedDeltaNet,
    // Fused blocks.
    FusedSwiGLU,
    FusedMatMulBiasAct,
    FusedResidualLN,
    FusedResidualRmsNorm,
    FusedAttentionBlock,
    FusedTransformerLayer,
    // Control flow.
    If,
    While,
    Scan,
    ScanBackward,
    ScanBackwardXs,
    /// CPU reference 3D Gaussian splat rasterizer
    /// (project → bin → sort → raster). See [`Op::GaussianSplatRender`].
    GaussianSplatRender,
    /// Backward of [`Op::GaussianSplatRender`], producing packed gradients
    /// for the scene parameters.
    GaussianSplatRenderBackward,
    /// Strict-IR splat stage 1: project, tile-bin, sort and build the ray grid.
    GaussianSplatPrepare,
    /// Strict-IR splat stage 2: per-pixel raster from the prepared buffers.
    GaussianSplatRasterize,
    /// A user-registered op dispatched through `op_registry`. Every custom op
    /// (Sparse-LU, FFT, eigensolve, ...) shares this kind; the individual
    /// identity is carried by `Op::Custom::name`.
    Custom,
    /// A user-defined sub-graph with optional overriding AD rules. See
    /// [`Op::CustomFn`] / [`crate::Graph::custom_fn`].
    CustomFn,
    /// 1-D FFT primitive, forward or inverse — see [`Op::Fft`].
    Fft,
    /// Ternary pruned radix-2 butterfly stage — see [`Op::FftButterflyStage`].
    FftButterflyStage,
    /// Whisper-style log-mel computed from a block-layout FFT spectrum — see
    /// [`Op::LogMel`].
    LogMel,
    /// Backward of [`Op::LogMel`] with respect to its block-layout spectrum
    /// input 0.
    LogMelBackward,
    /// Welch PSD top-K spikes from block-layout FFT spectra — see
    /// [`Op::WelchPeaks`].
    WelchPeaks,
}
338
/// Operand reference used inside a fused [`ChainStep`].
///
/// * `Input(i)` — the `i`-th graph-level input of the enclosing
///   [`Op::ElementwiseRegion`] (index in `0..num_inputs`).
/// * `Step(j)` — the result of an earlier step in the same chain
///   (index in `0..step_position`).
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ChainOperand {
    /// Index of a region-level input.
    Input(u32),
    /// Index of a previously computed chain step.
    Step(u32),
}
348
349/// One step in a fused element-wise chain. Each step produces exactly
350/// one scalar result (per element); later steps can refer to it via
351/// [`ChainOperand::Step`]. The whole chain runs per element in registers.
352#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
353#[derive(Debug, Clone, PartialEq)]
354pub enum ChainStep {
355 Activation(Activation, ChainOperand),
356 Cast(DType, ChainOperand),
357 Binary(BinaryOp, ChainOperand, ChainOperand),
358 Compare(CmpOp, ChainOperand, ChainOperand),
359 /// 3-input element-wise select: `cond ? on_true : on_false`. Mirrors
360 /// `Op::Where` inside a chain. `cond` is treated as truthy iff
361 /// non-zero. Lets the optimizer fold attention masks / clamp-style
362 /// patterns into a single region kernel instead of breaking the
363 /// chain at the first `Op::Where`.
364 Where(ChainOperand, ChainOperand, ChainOperand),
365}
366
/// Memory transform fused in front of an [`Op::ElementwiseRegion`]'s chain.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum RegionPrologue {
    /// No prologue transform (the default).
    #[default]
    None,
    /// Nearest-neighbour 2× upsample: the input is half-resolution NCHW and
    /// the output's H×W is doubled.
    ResizeNearest2x,
}
376
377/// One sampling step in [`Op::TransformRegion`].
378#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
379#[derive(Debug, Clone, PartialEq)]
380pub enum TransformStep {
381 ResizeNearest2x(ChainOperand),
382}
383
384/// An operation in the RLX IR graph.
385///
386/// Operations are categorized for fusion analysis:
387/// - Element-wise ops fuse with anything reading their output
388/// - Matmul/Conv are BLAS-dispatched and form fusion boundaries
389/// - Reductions are fusion roots (drive the loop iteration)
390#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
391#[derive(Debug, Clone, PartialEq)]
392pub enum Op {
393 // ── Graph inputs ────────────────────────────────────────────
394 /// Model input with a name (shape on the Node).
395 Input {
396 name: String,
397 },
398
399 /// Model parameter (weight/bias) with a name.
400 Param {
401 name: String,
402 },
403
404 /// Constant tensor embedded in the graph.
405 Constant {
406 data: Vec<u8>,
407 },
408
409 // ── Element-wise unary ──────────────────────────────────────
410 /// Unary activation: one input, same shape output.
411 Activation(Activation),
412
413 /// Cast to a different dtype.
414 Cast {
415 to: DType,
416 },
417
418 /// Stop-gradient (a.k.a. `detach` / `lax.stop_gradient`). Forward is
419 /// identity; the reverse-mode autodiff rule returns **no** gradient
420 /// contribution for the input. Single input, same shape & dtype
421 /// output. Used to build a Gradient-Reverse-Layer with identity
422 /// forward semantics in user code (see maet-rs `dat_loss`).
423 StopGradient,
424
425 /// INT8 quantization. Input f32; output i8 same shape.
426 /// `q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])`
427 /// where `c` selects the per-channel scale/zp when `axis = Some(d)`
428 /// (`c = idx[d]`), or always uses index 0 when `axis = None`
429 /// (per-tensor). The `scales` / `zero_points` payload length must
430 /// match `1` for per-tensor and `input.dim(d)` for per-channel.
431 /// Static — typically produced at calibration time and baked
432 /// into the loaded model. Use `Op::Dequantize` for the inverse.
433 Quantize {
434 axis: Option<usize>,
435 scales: Vec<f32>,
436 zero_points: Vec<i32>,
437 },
438
439 /// INT8 dequantization (inverse of `Op::Quantize`). Input i8;
440 /// output f32 same shape.
441 /// `x[i] = (q[i] - zero_point[c]) · scale[c]`
442 /// where `c` is selected by `axis` exactly as in `Op::Quantize`.
443 Dequantize {
444 axis: Option<usize>,
445 scales: Vec<f32>,
446 zero_points: Vec<i32>,
447 },
448
449 /// "Fake-quantize" op for **quantization-aware training** (QAT).
450 /// Input f32; output f32 same shape. Forward computes a per-axis
451 /// (or per-tensor when `axis = None`) max-abs scale on the fly:
452 /// `s[c] = max(|x[..., c, ...]|) / q_max(bits)`
453 /// then quantizes-then-dequantizes:
454 /// `out[i] = clamp(round(x[i] / s[c]), -q_max, q_max) * s[c]`
455 /// where `q_max` is `127` for `bits=8`, `7` for `bits=4`, `1` for
456 /// `bits=2` (ternary). Symmetric only — zero-point is always 0.
457 ///
458 /// The point of this op is to make the SGD optimizer "see" the
459 /// deployment-time rounding during training. Backward is the
460 /// **straight-through estimator** (STE): the gradient passes
461 /// through (variant chosen by `ste`), ignoring the discontinuity
462 /// at the round. Without STE the rounding would have zero
463 /// gradient almost everywhere and learning would stop.
464 ///
465 /// Inserted by the trainer on conv / FC weight tensors when
466 /// `--qat` is on; the existing `Op::Quantize` / packing path at
467 /// the end of training still handles the deployment-side
468 /// conversion to `i8`/`i4`/`i2` codes.
469 FakeQuantize {
470 bits: u8,
471 axis: Option<usize>,
472 ste: SteKind,
473 scale_mode: ScaleMode,
474 },
475
476 /// Learned Step Size Quantization (LSQ; Esser et al. 2020,
477 /// `arXiv:1902.08153`). Like `FakeQuantize` but the per-channel
478 /// `scale` is a *learned parameter*, passed as the second input.
479 /// Forward is identical to `FakeQuantize` with a fixed scale:
480 /// `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
481 /// Backward computes both `dx` (STE) and `dscale[c]` via the
482 /// closed-form gradient:
483 /// `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
484 /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
485 /// `sign(z) · q_max`. Routinely beats per-batch and EMA at
486 /// tight bit widths (i2 / i3).
487 ///
488 /// Inputs: `[x, scale]`. `scale` is `[chan_dim]` f32 (matches
489 /// `axis`); for `axis = None` it's `[1]`.
490 FakeQuantizeLSQ {
491 bits: u8,
492 axis: Option<usize>,
493 },
494
495 /// Backward pass for `Op::FakeQuantizeLSQ`. Computes BOTH the
496 /// gradient w.r.t. `x` (STE) and the gradient w.r.t. `scale`
497 /// (closed-form). Output shape matches `x`; the `scale` gradient
498 /// is reduced separately by `LsqScaleGradient`.
499 /// Inputs: `[x, scale, dy]`. Output: `dx`, same shape as `x`.
500 FakeQuantizeLSQBackwardX {
501 bits: u8,
502 axis: Option<usize>,
503 },
504
505 /// Companion to `FakeQuantizeLSQBackwardX`: computes the
506 /// `[chan_dim]` per-channel scale gradient. Inputs `[x, scale, dy]`.
507 /// Output shape matches `scale`.
508 FakeQuantizeLSQBackwardScale {
509 bits: u8,
510 axis: Option<usize>,
511 },
512
513 // ── Element-wise binary ─────────────────────────────────────
514 /// Binary op with broadcasting: two inputs, output shape is broadcast result.
515 Binary(BinaryOp),
516
517 // ── Comparison ──────────────────────────────────────────────
518 /// Element-wise comparison: two inputs, Bool output.
519 Compare(CmpOp),
520
521 /// Select elements: cond (Bool), on_true, on_false → output.
522 Where,
523
524 /// Fused element-wise region (PLAN L2). Holds an N-step chain of
525 /// element-wise operations. Inputs are referenced by index 0..num_inputs;
526 /// each step's result can be referenced by later steps via
527 /// `ChainOperand::Step(idx)`. The output is the last step's result.
528 /// Emitted by `MarkElementwiseRegions` in `rlx-opt` from chains of
529 /// Activation/Cast/Binary/Compare/Where ops with single-consumer
530 /// intermediates and broadcast-compatible shapes. Backends that
531 /// don't have a region kernel can decompose back to the original
532 /// chain via `unfuse::unfuse_elementwise_regions`.
533 ///
534 /// `scalar_input_mask` is a per-input bitfield (bit `i` set ⇒
535 /// input `i` is a scalar broadcast — has shape `[1]`). Kept as a
536 /// fast-path indicator that lets kernels skip the modulo entirely
537 /// when they detect a scalar.
538 ///
539 /// `input_modulus[i]` is the per-input element count, used by
540 /// kernels to compute `arena[input_offs[i] + (gid % input_modulus[i])]`
541 /// — the trailing-shape broadcast pattern. `0` means "no broadcast"
542 /// (input matches the output element count; kernel reads `gid`
543 /// directly). `1` means scalar; any other value means the input
544 /// has fewer elements than the output and they tile by modulo.
545 /// The encoder only allows broadcasts where `out_elems % in_elems
546 /// == 0` so the modulo divides cleanly. Lets chains include bias /
547 /// scale / eps / mask factors that previously broke the chain at
548 /// a Binary op with mismatched shapes.
549 ElementwiseRegion {
550 chain: Vec<ChainStep>,
551 num_inputs: u32,
552 scalar_input_mask: u32,
553 input_modulus: [u32; 16],
554 /// FKL-style closed fusion: apply before the element-wise chain.
555 prologue: RegionPrologue,
556 /// External input index that supplies the prologue transform source (default 0).
557 prologue_input: u32,
558 },
559
560 /// Fused transform chain (resize, future crop/color). Decompose via
561 /// [`rlx_fusion::DecomposeFusionRegions`] when no native kernel exists.
562 TransformRegion {
563 steps: Vec<TransformStep>,
564 num_inputs: u32,
565 },
566
567 /// Identical [`Op::ElementwiseRegion`] chain over `num_batch_inputs` tensors
568 /// (horizontal / z-plane fusion). Inputs are separate batch slices.
569 BatchElementwiseRegion {
570 chain: Vec<ChainStep>,
571 num_batch_inputs: u32,
572 scalar_input_mask: u32,
573 input_modulus: [u32; 16],
574 prologue: RegionPrologue,
575 prologue_input: u32,
576 },
577
578 // ── Linear algebra ──────────────────────────────────────────
579 /// Matrix multiply. Inputs: [.., M, K] × [.., K, N] → [.., M, N].
580 /// Batch dimensions are broadcast.
581 MatMul,
582
583 /// Matrix multiply with explicit dimension specification.
584 /// Like XLA's DotGeneral — handles arbitrary batch/contracting dims.
585 DotGeneral {
586 lhs_contracting: Vec<usize>,
587 rhs_contracting: Vec<usize>,
588 lhs_batch: Vec<usize>,
589 rhs_batch: Vec<usize>,
590 },
591
592 /// Batched dense linear solve. Inputs: `A [B, N, N]`,
593 /// `b [B, N]` or `b [B, N, K]`. Output: same shape as `b`.
594 ///
595 /// Per-batch independent solve — each `A[i]` and `b[i]` are
596 /// solved as a separate `Op::DenseSolve`. Emitted by vmap of
597 /// `Op::DenseSolve`. The CPU lowering loops over the batch
598 /// dimension calling `dgesv` per slice (LAPACK doesn't expose a
599 /// batched solve on Accelerate; cuSOLVER does on NVIDIA).
600 BatchedDenseSolve,
601
602 /// Dense linear solve `x = A⁻¹ · b` via LU factorization.
603 /// Inputs: `A [N, N]`, `b [N]` (or `b [N, K]` for multi-RHS).
604 /// Output: same shape as `b`.
605 ///
606 /// VJP via the implicit-function theorem:
607 /// `dx = solve(Aᵀ, upstream)`
608 /// `dA = -outer(dx, x)` (x is the forward output)
609 /// `db = dx`
610 /// The rule is dtype-agnostic; lowering is per-backend (Accelerate
611 /// `dgesv` / `sgesv`, cuSOLVER, etc.).
612 DenseSolve,
613
614 // ── Normalization ───────────────────────────────────────────
615 /// Layer normalization: input, gamma, beta → normalized output.
616 /// `axis` is the feature dimension (usually -1).
617 LayerNorm {
618 axis: i32,
619 eps: f32,
620 },
621
622 /// Group normalization on NCHW tensors: `input`, `gamma`, `beta` → same shape.
623 /// Normalizes over `(C/num_groups) × H × W` per group.
624 GroupNorm {
625 num_groups: usize,
626 eps: f32,
627 },
628
629 /// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
630 /// position (candle / SAM `LayerNorm2d` semantics — not PyTorch's H×W norm).
631 LayerNorm2d {
632 eps: f32,
633 },
634
635 /// Nearest-neighbor 2× upsample on NCHW (doubles spatial dims 2 and 3).
636 ResizeNearest2x,
637
638 /// RMS normalization: input, gamma → normalized output.
639 RmsNorm {
640 axis: i32,
641 eps: f32,
642 },
643
644 /// BatchNorm inference with frozen running statistics.
645 /// Inputs: `x`, `gamma`, `beta`, `running_mean`, `running_var`.
646 /// Feature dimension is the last axis of `x`; stats are 1-D `[C]`.
647 BatchNormInference {
648 eps: f32,
649 },
650
651 // ── Attention ───────────────────────────────────────────────
652 /// Scaled dot-product attention: Q, K, V, \[mask\] → output.
653 /// The compiler can lower this to fused SDPA or flash attention.
654 /// `mask_kind` controls how masking is applied — `Custom` reads from
655 /// the 4th input tensor; `None` / `Causal` / `SlidingWindow` skip the
656 /// mask load and apply the mask directly in the inner loop. See
657 /// `MaskKind` for the rationale.
658 ///
659 /// `score_scale`: when `Some(s)`, dot-product scores are multiplied by
660 /// `s` instead of the default `1/sqrt(head_dim)` (Gemma uses `head_dim^-0.5`
661 /// explicitly in config). `attn_logit_softcap`: when `Some(c)`, applies
662 /// `c * tanh(s/c)` to scores before softmax (Gemma 2).
663 Attention {
664 num_heads: usize,
665 head_dim: usize,
666 mask_kind: MaskKind,
667 score_scale: Option<f32>,
668 attn_logit_softcap: Option<f32>,
669 },
670
671 /// Rotary position embedding applied to one tensor: x, cos, sin → x_rotated.
672 /// Apply separately to Q and K. `head_dim` is the per-head width; `n_rot`
673 /// is how many leading dims get NeoX RoPE (pair offset `n_rot/2`). When
674 /// `n_rot < head_dim`, trailing dims are copied unchanged (Qwen3.5 MRoPE).
675 Rope {
676 head_dim: usize,
677 n_rot: usize,
678 },
679
680 /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
681 AxialRope2d {
682 end_x: usize,
683 end_y: usize,
684 head_dim: usize,
685 num_heads: usize,
686 theta: f32,
687 repeat_factor: usize,
688 },
689
690 // ── Shape manipulation ──────────────────────────────────────
691 Reshape {
692 new_shape: Vec<i64>,
693 },
694 Transpose {
695 perm: Vec<usize>,
696 },
697 /// Select a contiguous slice along an axis.
698 Narrow {
699 axis: usize,
700 start: usize,
701 len: usize,
702 },
703 /// Concatenate along an axis.
704 Concat {
705 axis: usize,
706 },
707 /// Expand (broadcast) to a target shape.
708 Expand {
709 target_shape: Vec<i64>,
710 },
711 /// Gather elements by index along an axis (embedding lookup).
712 Gather {
713 axis: usize,
714 },
715
716 // ── Reduction ───────────────────────────────────────────────
717 /// Reduce along specified axes.
718 Reduce {
719 op: ReduceOp,
720 axes: Vec<usize>,
721 keep_dim: bool,
722 },
723
724 /// Selective scan (plan #15) — Mamba-style state-space model
725 /// step. The recurrence:
726 /// `h[t] = exp(Δ[t] * A) * h[t-1] + Δ[t] * B[t] * x[t]`
727 /// `y[t] = C[t] * h[t]`
728 /// where state `h` has dimension `state_size` and the input has
729 /// `(batch, seq, hidden)`.
730 ///
731 /// Inputs (in order):
732 /// `x [b, s, h]` f32 input
733 /// `delta [b, s, h]` f32 step size (per-position, per-channel)
734 /// `a [h, n]` f32 transition matrix (one per channel)
735 /// `b [b, s, n]` f32 input projection
736 /// `c [b, s, n]` f32 output projection
737 /// Output: `[b, s, h]` f32. State `h` is implicit; the kernel
738 /// scans through the seq dimension carrying it.
739 ///
740 /// `state_size` = `n` is exposed for the cost model.
741 SelectiveScan {
742 state_size: usize,
743 },
744
745 /// Gated DeltaNet linear-attention recurrence — the per-layer
746 /// kernel used by Qwen3.5/3.6 trunk "linear attention" blocks
747 /// (and Qwen3-Next, Kimi-Linear). Mirrors
748 /// `llama.cpp / src/models/delta-net-base.cpp` autoregressive
749 /// path; chunked + fused variants ride the same op identity.
750 ///
751 /// **Math (per token `t`, head `h`, state size `n`):**
752 /// state matrix `S[h, i, j]` is implicit (reset per batch).
753 /// ```text
754 /// S[h] *= exp(g[t,h]) # scalar gate
755 /// sk[h,j] = Σ_i S[h,i,j] * k[t,h,i]
756 /// d[h,j] = (v[t,h,j] - sk[h,j]) * b[t,h] # b = beta
757 /// S[h,i,j] += k[t,h,i] * d[h,j] # outer-prod
758 /// o[t,h,j] = Σ_i S[h,i,j] * (q[t,h,i] / √n)
759 /// ```
760 ///
761 /// Inputs:
762 /// `q [b, s, h_v, n]` f32 queries (L2-normed by caller)
763 /// `k [b, s, h_v, n]` f32 keys (L2-normed by caller;
764 /// GQA-repeated to match `h_v`)
765 /// `v [b, s, h_v, n]` f32 values
766 /// `g [b, s, h_v]` f32 log-gate (exp'd inside kernel)
767 /// `beta [b, s, h_v]` f32 delta-rule mixing factor
768 ///
769 /// Output: `[b, s, h_v, n]` f32.
770 ///
771 /// When `carry_state` is true, a sixth input `state [b, h_v, n, n]`
772 /// provides the initial SSM matrix per head; the kernel updates it
773 /// in place across the sequence and leaves the final state in the
774 /// same buffer (same layout as the internal scan state:
775 /// `state[h, i, j]` row-major over `(n, n)` per head).
776 GatedDeltaNet {
777 state_size: usize,
778 carry_state: bool,
779 },
780
781 /// Fused dequant + matmul (plan #5). The biggest LLM-bandwidth
782 /// win on Apple Silicon: dequantizes weights inside the matmul
783 /// inner loop, never materializing f32 weights.
784 ///
785 /// **BREAKING CHANGE in 0.2.0:** `num_inputs()` is now
786 /// scheme-dependent — **4** for legacy Int8 schemes, **2** for
787 /// the new GGUF K-quant schemes (their scales/mins live inside
788 /// the packed bytes, so no side-channel `scale` / `zp` tensors
789 /// are fed in). Callers that assumed a fixed 4-input contract
790 /// must inspect `scheme.is_gguf()` before reading inputs.
791 ///
792 /// Inputs (Int8 schemes — `scheme.is_gguf() == false`):
793 /// `x [m, k]` f32 activations
794 /// `w_q [k, n]` packed quantized weight bytes (i8 per
795 /// element for Int8 schemes; 4-bit
796 /// packed two-per-byte for Int4)
797 /// `scale [k/block, n]` per-block f32 dequant scale
798 /// `zp [k/block, n]` per-block f32 zero-point
799 /// (zero-tensor if symmetric)
800 ///
801 /// Inputs (`Nvfp4Block` — fixed group size 16 along K):
802 /// `x [m, k]` f32 activations
803 /// `w_q [k,n/2]` packed FP4 E2M1 codes (unsigned nibble 0..15)
804 /// `scale [k/16, n]` u8 FP8 E4M3 block scales (one byte / group)
805 /// `global_scale [1]` f32 per-tensor scale (pass `[1.0]` if unused)
806 ///
807 /// Inputs (GGUF schemes — `scheme.is_gguf() == true`):
808 /// `x [m, k]` f32 activations
809 /// `packed_w [bytes]` raw GGUF super-block bytes; the
810 /// dequantizer reads the per-sub-block
811 /// scales / mins / quants directly out
812 /// of the buffer per the K-quant block
813 /// layout (no side tensors).
814 ///
815 /// Output: `[m, n]` f32.
816 ///
817 /// `block_size` (on the Int8 schemes only) is the number of
818 /// consecutive elements that share one (scale, zero_point) pair.
819 /// The Op carries enough metadata that the kernel doesn't need
820 /// a separate `QuantMap` lookup at run time.
821 DequantMatMul {
822 scheme: crate::quant::QuantScheme,
823 },
824
825 /// Real INT8-arithmetic matrix multiply with i32 accumulation.
826 /// Inputs (in order):
827 /// `x [M, K]` i8 activations (zero-point = `x_zp`)
828 /// `w [K, N]` i8 weights (zero-point = `w_zp`)
829 /// `bias [N]` i32 (in accumulator scale = `x_scale·w_scale`),
830 /// pass a zeros tensor for "no bias"
831 /// Output: `[M, N]` i8 (zero-point = `out_zp`)
832 ///
833 /// Per-element compute:
834 /// `out[m,n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
835 /// where `mult = x_scale · w_scale / out_scale`.
836 ///
837 /// This is the same kernel shape `rlx-cortexm/src/dense.rs`
838 /// uses for on-device int8 inference, lifted into the IR so the
839 /// rlx-cpu backend can run a quantized graph directly (instead
840 /// of round-tripping through fake-quant Dequantize → MatMul →
841 /// Quantize). 2-D only — generalizing to batched comes when a
842 /// real workload demands it.
843 QMatMul {
844 x_zp: i32,
845 w_zp: i32,
846 out_zp: i32,
847 mult: f32,
848 },
849
850 /// Real INT8-arithmetic 2-D convolution with i32 accumulation.
851 /// Inputs:
852 /// `x [N, C_in, H, W]` i8 (zero-point = `x_zp`)
853 /// `w [C_out, C_in/groups, kH, kW]` i8 (zero-point = `w_zp`)
854 /// `bias [C_out]` i32 in accumulator scale
855 /// Output: `[N, C_out, H_out, W_out]` i8 (zero-point = `out_zp`).
856 /// Same NCHW geometry contract as `Op::Conv`; same requantize
857 /// math as `Op::QMatMul` (per-element `acc·mult` rounded to i8).
858 QConv2d {
859 kernel_size: Vec<usize>,
860 stride: Vec<usize>,
861 padding: Vec<usize>,
862 dilation: Vec<usize>,
863 groups: usize,
864 x_zp: i32,
865 w_zp: i32,
866 out_zp: i32,
867 mult: f32,
868 },
869
870 /// Fused LoRA matmul: `out = x·W + scale * x·A·B`.
871 /// Inputs (in order): `x [m, k]`, `w [k, n]`, `a [k, r]`, `b [r, n]`.
872 /// `r` is the LoRA rank (typically 4-64). `scale` is the
873 /// per-adapter `alpha / rank` knob.
874 /// Plan #9: lifts LoRA from "three matmuls + an add" into one
875 /// kernel that keeps the rank-r intermediate in registers.
876 LoraMatMul {
877 scale: f32,
878 },
879
880 /// Fused sampling kernel: logits → optional top-k filter →
881 /// optional top-p truncation → softmax → multinomial sample.
882 /// One f32-encoded sampled token id per batch row (output
883 /// shape `[batch]`).
884 ///
885 /// `temperature == 1.0` matches a plain argmax-of-softmax;
886 /// lower → sharper, higher → flatter. `top_k == 0` disables.
887 /// `top_p == 1.0` disables. `seed` is the Philox seed; pass 0
888 /// for "use process-global counter" (still deterministic
889 /// given the call order).
890 /// Borrowed from MAX's nn/sampling.mojo (#42 in PLAN.md).
891 /// Latency-critical: never materializes the full softmax
892 /// distribution on the host.
893 Sample {
894 top_k: usize, // 0 = disabled
895 top_p: f32, // 1.0 = disabled
896 temperature: f32, // 1.0 = neutral
897 seed: u64, // 0 = use thread-local counter
898 },
899
900 /// Inclusive cumulative sum along an axis. Same shape in/out.
901 /// Underpins ragged-tensor offsets, sampling (top-p prefix sum),
902 /// and sequence-position math (#44 in PLAN.md).
903 /// `exclusive=true` shifts the result so output\[0\] = 0 (useful
904 /// for offset arrays where the first segment starts at 0).
905 Cumsum {
906 axis: i32,
907 exclusive: bool,
908 },
909
910 /// Softmax along an axis (reduction + element-wise).
911 Softmax {
912 axis: i32,
913 },
914
915 /// Top-K **indices** along the last axis. Output shape `[..., k]`,
916 /// f32-encoded indices (rlx is f32-only at the I/O boundary).
917 /// To recover the values, follow with a `Gather` against the
918 /// original tensor — works because Gather already supports any axis.
919 /// Ties broken by smaller index (matches NumPy / PyTorch
920 /// `torch.topk(..., largest=True, sorted=True)`).
921 /// Used by MoE gating; also useful for beam search.
922 TopK {
923 k: usize,
924 },
925
926 /// Indexed batched matmul. The MoE GEMM primitive.
927 /// Inputs: `[input, weight, expert_idx]`
928 /// input : [M, K] — per-token activations
929 /// weight : [num_experts, K, N] — stacked expert weights
930 /// expert_idx : \[M\] — f32-encoded expert id per token
931 /// Output : [M, N] — output\[i\] = input\[i\] @ weight[expert_idx\[i\]]
932 /// Naive impl on both backends; future work can replace with a
933 /// segmented/grouped GEMM when there's a real workload.
934 GroupedMatMul,
935
936 /// Fused GGUF K-quant dequant + [`Op::GroupedMatMul`]. Same three
937 /// inputs as `GroupedMatMul`, but `weight` is a U8 tensor holding
938 /// `num_experts` contiguous packed expert slabs (GGML layout, expert
939 /// dimension outermost). Scales live inside the packed bytes.
940 DequantGroupedMatMul {
941 scheme: crate::quant::QuantScheme,
942 },
943
944 /// Dequant a packed MoE expert stack to F32 `[num_experts, K, N]` in
945 /// GroupedMatMul layout. Input: U8 packed bytes; output shape is
946 /// declared on the node (`[E, K, N]`).
947 DequantMoEWeights {
948 scheme: crate::quant::QuantScheme,
949 },
950
951 /// Scatter-add into a destination tensor. The "unpermute" half of
952 /// MoE routing (also useful for embedding gradient updates).
953 /// Inputs: `[updates, indices]`
954 /// updates : [num_updates, trailing] — values to add
955 /// indices : \[num_updates\] — f32-encoded destination row
956 /// Output : [out_dim, trailing] — output[indices\[i\]] += updates\[i\]
957 /// `out_dim` is taken from the node's declared output shape.
958 /// Initial output is zero; multiple updates to the same row
959 /// accumulate (sequentially on CPU; with atomic-add on Metal).
960 ScatterAdd,
961
962 // ── Convolution ─────────────────────────────────────────────
963 /// 2D convolution on NCHW tensors. Also exposed as [`OpKind::Conv`] / `conv2d`.
964 /// Weight layout: `[C_out, C_in / groups, kH, kW]`.
965 Conv {
966 kernel_size: Vec<usize>,
967 stride: Vec<usize>,
968 padding: Vec<usize>,
969 dilation: Vec<usize>,
970 groups: usize,
971 },
972
973 /// NCHW im2col for conv backward-weight style matmul.
974 /// Input `[N, C, H, W]`. Output `[M, C·kH·kW]` with
975 /// `M = N · H_out · W_out`. When batch is [`dynamic::sym::BATCH`],
976 /// output rows use [`dynamic::sym::ROWS`] (bind `N · H_out · W_out`).
977 Im2Col {
978 kernel_size: Vec<usize>,
979 stride: Vec<usize>,
980 padding: Vec<usize>,
981 dilation: Vec<usize>,
982 },
983
984 /// 2D transposed convolution on NCHW. Weight layout (PyTorch):
985 /// `[C_in, C_out / groups, kH, kW]`.
986 ConvTranspose2d {
987 kernel_size: Vec<usize>,
988 stride: Vec<usize>,
989 padding: Vec<usize>,
990 dilation: Vec<usize>,
991 output_padding: Vec<usize>,
992 groups: usize,
993 },
994
995 // ── Pooling ─────────────────────────────────────────────────
996 Pool {
997 kind: ReduceOp,
998 kernel_size: Vec<usize>,
999 stride: Vec<usize>,
1000 padding: Vec<usize>,
1001 },
1002
1003 // ── Backward / training ops ─────────────────────────────────
1004 //
1005 // Closed-form gradient nodes emitted by `rlx-opt::autodiff`.
1006 // Pairing a forward op with a dedicated backward op (instead of
1007 // composing it from primitives) keeps the gradient kernel simple
1008 // and lets the backend recompute argmax / masks / softmax inline.
1009 /// ReLU backward: `dx = dy where x > 0 else 0`.
1010 /// Inputs: `[x, dy]` — both same shape; output matches.
1011 ReluBackward,
1012
1013 /// Element-wise complex squared-magnitude: `|z|² = z.re² + z.im²`.
1014 /// Input: 1 tensor with `DType::C64`. Output: same shape but
1015 /// `DType::F32`. The natural real-valued loss surface for
1016 /// Wirtinger reverse-mode AD on complex graphs — pair with
1017 /// [`Op::ComplexNormSqBackward`].
1018 ComplexNormSq,
1019
1020 /// Element-wise complex conjugate: `z̄ = z.re - i·z.im` per element.
1021 /// Input: 1 tensor with `DType::C64`. Output: same shape, same dtype.
1022 /// Used by Wirtinger VJP rules on `Op::Binary` over C64 (the rule
1023 /// for `y = a·b` is `dL/dā = upstream · conj(b)`, etc.).
1024 Conjugate,
1025
1026 /// Backward for [`Op::ComplexNormSq`] under Wirtinger calculus.
1027 /// `f(z) = |z|² = z·z̄`, so `∂f/∂z̄ = z`. Given upstream real
1028 /// cotangent `g` (same shape as the forward output), the C64
1029 /// gradient with respect to `z` is `g · z` element-wise, returned
1030 /// in C64 storage `[re_g·re_z, re_g·im_z]` per element.
1031 ///
1032 /// Inputs: `[z (C64), g (F32)]` — both same logical shape; output
1033 /// matches `z` (C64).
1034 ComplexNormSqBackward,
1035
1036 /// LayerNorm backward w.r.t. the input. Computes
1037 /// `d_x[..., d] = inv_std · (dy·γ - mean(dy·γ) - x̂·mean(dy·γ·x̂))`
1038 /// over the feature axis, where `x̂ = (x - mean)/std` is recomputed
1039 /// inline from `x`. Inputs: `[x, gamma, dy]`; output shape = `x.shape`.
1040 /// Currently lowers axis=-1 only (matches the forward thunk).
1041 LayerNormBackwardInput {
1042 axis: i32,
1043 eps: f32,
1044 },
1045
1046 /// LayerNorm backward w.r.t. gamma. Computes
1047 /// `d_gamma[d] = Σ_{batch} dy[..., d] · x̂[..., d]`
1048 /// — sums the per-element product of upstream and the (recomputed)
1049 /// normalized input over the leading axes. Inputs: `[x, dy]`;
1050 /// output shape = `gamma.shape` (= 1-D feature axis).
1051 LayerNormBackwardGamma {
1052 axis: i32,
1053 eps: f32,
1054 },
1055
1056 /// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`; output = `x.shape`.
1057 RmsNormBackwardInput {
1058 axis: i32,
1059 eps: f32,
1060 },
1061
1062 /// RMSNorm backward w.r.t. gamma. Inputs `[x, gamma, beta, dy]`; output = `gamma.shape`.
1063 RmsNormBackwardGamma {
1064 axis: i32,
1065 eps: f32,
1066 },
1067
1068 /// RMSNorm backward w.r.t. beta. Inputs `[x, gamma, beta, dy]`; output = `beta.shape`.
1069 RmsNormBackwardBeta {
1070 axis: i32,
1071 eps: f32,
1072 },
1073
1074 /// RoPE backward w.r.t. `x`. Inputs `[dy, cos, sin]`; output = `dy.shape`.
1075 RopeBackward {
1076 head_dim: usize,
1077 n_rot: usize,
1078 },
1079
1080 /// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
1081 GroupNormBackwardInput {
1082 num_groups: usize,
1083 eps: f32,
1084 },
1085
1086 /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`; output = `gamma.shape`.
1087 GroupNormBackwardGamma {
1088 num_groups: usize,
1089 eps: f32,
1090 },
1091
1092 /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`; output = `beta.shape`.
1093 GroupNormBackwardBeta {
1094 num_groups: usize,
1095 eps: f32,
1096 },
1097
1098 /// BatchNorm inference backward w.r.t. `x`. Inputs `[x, gamma, mean, var, dy]`.
1099 BatchNormInferenceBackwardInput {
1100 eps: f32,
1101 },
1102
1103 /// BatchNorm inference backward w.r.t. `gamma`. Inputs `[x, mean, var, dy]`.
1104 BatchNormInferenceBackwardGamma {
1105 eps: f32,
1106 },
1107
1108 /// BatchNorm inference backward w.r.t. `beta`. Inputs `[dy]`; output = `beta.shape`.
1109 BatchNormInferenceBackwardBeta,
1110
1111 /// Cumsum backward along `axis`. Inputs `[dy]`; output matches forward input shape.
1112 CumsumBackward {
1113 axis: i32,
1114 exclusive: bool,
1115 },
1116
1117 /// Gather backward (scatter-add into table). Inputs `[dy, indices]`; output = table shape.
1118 /// `axis` matches forward [`Op::Gather`].
1119 GatherBackward {
1120 axis: i32,
1121 },
1122
1123 /// Generic element-wise activation backward. `kind` selects the
1124 /// closed-form derivative `d/dx act(x)`. Inputs: `[x, dy]`; output
1125 /// shape matches `x`. The kernel computes `d/dx act(x) · dy` per element.
1126 ///
1127 /// Closed forms (all element-wise):
1128 /// * `Gelu` — exact derivative of erf-based GELU.
1129 /// * `GeluApprox` — derivative of the tanh approximation
1130 /// `0.5 x (1 + tanh(√(2/π)(x + 0.044715 x³)))`.
1131 /// * `Silu` — `σ(x)·(1 + x·(1 - σ(x)))`.
1132 /// * `Sigmoid` — `σ(x)·(1 - σ(x))`.
1133 /// * `Tanh` — `1 - tanh(x)²`.
1134 /// * `Exp` — `exp(x)`.
1135 /// * `Log` — `1 / x`.
1136 /// * `Sqrt` — `0.5 / sqrt(x)`.
1137 /// * `Rsqrt` — `-0.5 · x^(-3/2)`.
1138 /// * `Neg` — `-1`.
1139 /// * `Abs` — `sign(x)` (zero at x=0).
1140 /// * `Sin` — `cos(x)`.
1141 /// * `Cos` — `-sin(x)`.
1142 /// * `Tan` — `1 + tan²(x)`.
1143 /// * `Atan` — `1 / (1 + x²)`.
1144 /// * `Relu` — kept here for completeness; the dedicated
1145 /// `ReluBackward` op is preferred for relu and is what the
1146 /// autodiff pass actually emits.
1147 ActivationBackward {
1148 kind: Activation,
1149 },
1150
1151 /// Backward for `Op::FakeQuantize` under a non-default STE.
1152 /// Inputs `[x, dy]`: the forward input and the upstream
1153 /// gradient. Output `dx` same shape. The `bits`/`axis`/`ste`
1154 /// fields must match the forward op so the kernel computes the
1155 /// same per-channel scale and applies the right STE attenuation.
1156 /// For `SteKind::Identity` this op is unnecessary — autodiff
1157 /// just routes `upstream` through unchanged.
1158 FakeQuantizeBackward {
1159 bits: u8,
1160 axis: Option<usize>,
1161 ste: SteKind,
1162 },
1163
1164 /// 2D max-pool backward. Routes each element of `dy` back into the
1165 /// position in `x`'s window where the forward max was taken.
1166 /// Inputs: `[x, dy]` with `x [N, C, H, W]` and
1167 /// `dy [N, C, H_out, W_out]`. Output: same shape as `x`.
1168 /// Carries the forward pool's geometry so the kernel can recompute
1169 /// the argmax position per window without a saved-indices tensor.
1170 MaxPool2dBackward {
1171 kernel_size: Vec<usize>,
1172 stride: Vec<usize>,
1173 padding: Vec<usize>,
1174 },
1175
1176 /// 2D conv backward w.r.t. input. Computes `dx = conv_transpose(dy, w)`.
1177 /// Inputs: `[dy, w]` with `dy [N, C_out, H_out, W_out]` and
1178 /// `w [C_out, C_in/groups, kH, kW]`. Output: `[N, C_in, H, W]`
1179 /// (declared on the node — caller knows the original input shape).
1180 /// Geometry is the forward conv's parameters, not the transposed
1181 /// conv's.
1182 Conv2dBackwardInput {
1183 kernel_size: Vec<usize>,
1184 stride: Vec<usize>,
1185 padding: Vec<usize>,
1186 dilation: Vec<usize>,
1187 groups: usize,
1188 },
1189
1190 /// 2D conv backward w.r.t. weight. Computes
1191 /// `dw[c_out, c_in, kh, kw] = sum_{n,h_out,w_out} x[n,c_in,...] * dy[n,c_out,h_out,w_out]`.
1192 /// Inputs: `[x, dy]`. Output: `[C_out, C_in/groups, kH, kW]`.
1193 Conv2dBackwardWeight {
1194 kernel_size: Vec<usize>,
1195 stride: Vec<usize>,
1196 padding: Vec<usize>,
1197 dilation: Vec<usize>,
1198 groups: usize,
1199 },
1200
1201 /// Fused softmax + cross-entropy loss with integer (f32-encoded)
1202 /// targets — the standard classification loss. Per-row output:
1203 /// `loss[n] = -log(softmax(logits[n])[labels[n]])`.
1204 /// Inputs: `[logits, labels]` with `logits [N, C]` and
1205 /// `labels [N]` (f32-encoded class indices). Output: `[N]`.
1206 /// Caller does the `Reduce::Mean` if they want a scalar.
1207 SoftmaxCrossEntropyWithLogits,
1208
1209 /// Backward of the fused loss above. Emits
1210 /// `dlogits[n,c] = (softmax(logits[n])[c] - one_hot(labels)[n,c]) * d_loss[n]`.
1211 /// Inputs: `[logits, labels, d_loss]`. Output: `[N, C]` (same shape
1212 /// as `logits`). Recomputes the softmax inline rather than threading
1213 /// it through from the forward node.
1214 SoftmaxCrossEntropyBackward,
1215
1216 /// Backward of [`Op::Attention`]. Recomputes scaled `QK^T`, applies
1217 /// the same `mask_kind` as the forward op, softmaxes scores, then
1218 /// emits **one** of `dQ`, `dK`, or `dV` selected by [`AttentionBwdWrt`].
1219 /// Autodiff emits three nodes (one per `wrt`) so each output shape
1220 /// stays a normal single-output MIR node.
1221 ///
1222 /// Inputs: `[q, k, v, dy]` plus optional mask when `mask_kind` is
1223 /// [`MaskKind::Custom`] or [`MaskKind::Bias`] (same convention as
1224 /// forward). Output shape matches `q`, `k`, or `v` respectively.
1225 AttentionBackward {
1226 num_heads: usize,
1227 head_dim: usize,
1228 mask_kind: MaskKind,
1229 wrt: AttentionBwdWrt,
1230 },
1231
1232 // ── Fused operations (created by optimization passes) ──────
1233 /// Fused matmul + bias + activation. Created from MatMul → Add → Activation.
1234 FusedMatMulBiasAct {
1235 activation: Option<Activation>,
1236 },
1237
1238 /// Fused residual + optional bias + layer norm.
1239 /// Created from Add(x, residual) → [Add(bias)] → LayerNorm.
1240 FusedResidualLN {
1241 has_bias: bool,
1242 eps: f32,
1243 },
1244
1245 /// Fused residual + optional bias + RMS norm.
1246 /// Created from Add(x, residual) → [Add(bias)] → RmsNorm.
1247 FusedResidualRmsNorm {
1248 has_bias: bool,
1249 eps: f32,
1250 },
1251
1252 /// Fused SwiGLU: split input into up/gate halves, silu(gate) * up.
1253 /// Created from Split → Silu → Mul when fed by a concatenated matmul.
1254 ///
1255 /// `cast_to`: optional output dtype — when `Some(dt)` the kernel casts
1256 /// its result from the input dtype to `dt` in-register, saving a
1257 /// separate cast pass. Reserved for future fp8/fp4 quantization paths;
1258 /// for f32→f16 mixed precision the AutoMixedPrecision pass already
1259 /// inserts a Cast node so this stays `None` in current pipelines.
1260 FusedSwiGLU {
1261 cast_to: Option<DType>,
1262 /// When `true`, the concatenated input stores gate in the low half
1263 /// `[..., 0..N)` and up in the high half `[..., N..2N)` — the layout
1264 /// produced when gate projection is emitted before up in the builder.
1265 /// Default `false`: up @ low, gate @ high (canonical concat order).
1266 gate_first: bool,
1267 },
1268
1269 /// Fused full transformer layer: attention block + residual+LN + FFN + residual+LN.
1270 /// All intermediates resident in registers/threadgroup memory; one kernel
1271 /// per layer instead of ~30 (the CPU's batch=1 win, lifted to IR so any
1272 /// backend can implement it as a monolithic kernel).
1273 ///
1274 /// Inputs: hidden, qkv_w, qkv_b, out_w, out_b,
1275 /// ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask
1276 /// Output: same shape as hidden.
1277 ///
1278 /// **Backend status:** same as FusedAttentionBlock. CPU implements
1279 /// the L1-cache-resident merge at the thunk level. Metal deferred —
1280 /// requires a single MSL kernel for the whole layer to actually
1281 /// beat the unfused path. Multi-day work; revisit when there's a
1282 /// model whose Metal inference is bottlenecked here rather than on
1283 /// the wait latency floor.
1284 FusedTransformerLayer {
1285 num_heads: usize,
1286 head_dim: usize,
1287 intermediate_size: usize,
1288 eps1: f32,
1289 eps2: f32,
1290 activation: Activation,
1291 has_bias: bool,
1292 },
1293
1294 /// Fused attention block: QKV projection → split → \[RoPE\] → SDPA → output projection.
1295 /// Created by FuseAttentionBlock pass when batch*seq is small.
1296 /// All intermediates stay in L1 cache — no arena writes between ops.
1297 ///
1298 /// Inputs (in order):
1299 /// hidden, qkv_w, out_w, mask,
1300 /// [qkv_b, out_b] if has_bias,
1301 /// [rope_cos, rope_sin] if has_rope
1302 ///
1303 /// **Backend status (Phase C finalize):**
1304 /// CPU — implemented at the *thunk* level: the CPU schedule
1305 /// recognizes the multi-thunk pattern and merges into
1306 /// a single FusedAttnBlock that keeps Q/K/V in stack
1307 /// buffers across stages (the L1-cache win).
1308 /// Metal — **deferred**. A dispatch-wrapper version (chaining
1309 /// existing kernels) buys nothing the unfused Metal path
1310 /// doesn't already get, since per-run cost is dominated
1311 /// by `wait_until_completed` (~150 µs), not encode. The
1312 /// real win is a single MSL kernel keeping Q/K/V in
1313 /// threadgroup memory across stages — multi-day work.
1314 /// Until then, Metal runs the unfused chain (one matmul,
1315 /// three narrows, two ropes, attention, one matmul) — all
1316 /// covered in op_coverage and parity_harness.
1317 FusedAttentionBlock {
1318 num_heads: usize,
1319 head_dim: usize,
1320 has_bias: bool,
1321 has_rope: bool,
1322 },
1323
1324 // ── Control flow (subgraphs as op payloads) ─────────────────
1325 //
1326 // Status: IR is defined; helper `run_if` / `run_while` exist in
1327 // rlx-runtime/src/subgraph.rs; **executor wiring is not yet
1328 // implemented** (both CPU thunk and Metal thunk fall through to
1329 // `Thunk::Nop` for these ops). Wiring requires:
1330 // 1. Recursive subgraph compile at parent-compile time.
1331 // 2. Per-subgraph input/output binding through the arena.
1332 // 3. Schedule-level dispatch when the predicate / loop cond is
1333 // resolved at runtime.
1334 // Estimate: 4–6 hours of focused work + parity tests. Deferred
1335 // because no current in-tree model uses these ops;
1336 // surface area without a validation target invites silent bugs.
1337 /// Conditional: pick between two subgraphs based on a boolean predicate.
1338 /// Inputs: [predicate, ...captures (used inside both branches)].
1339 /// `then_branch` and `else_branch` are sub-graphs that share the
1340 /// captured inputs and must produce identically-shaped outputs.
1341 /// Used for: shape-dependent execution, batched inference of
1342 /// dynamic-length sequences with padding masks.
1343 If {
1344 then_branch: Box<crate::Graph>,
1345 else_branch: Box<crate::Graph>,
1346 },
1347
1348 /// Loop: iterate `body` while `cond` evaluates true.
1349 /// Inputs: [...initial loop-carried values].
1350 /// `cond`'s single output is a Bool scalar.
1351 /// `body`'s outputs become the next iteration's loop-carried inputs.
1352 /// Outputs of While are the values after the final iteration.
1353 /// Used for: KV-cache-driven autoregressive generation, beam search.
1354 While {
1355 cond: Box<crate::Graph>,
1356 body: Box<crate::Graph>,
1357 max_iterations: Option<usize>,
1358 },
1359
1360 /// Bounded-length loop with a fixed-shape carry, optional per-step
1361 /// inputs, and optional stacked output. Mirrors JAX's `lax.scan`.
1362 ///
1363 /// Body signature: `(carry, x_t_0, ..., x_t_{num_xs-1}) → carry_next`
1364 /// — `1 + num_xs` Op::Inputs in NodeId construction order (first
1365 /// declared is the carry; the remaining `num_xs` are per-step
1366 /// slices). Single output (the next carry).
1367 ///
1368 /// Outer Op::Scan inputs (in order):
1369 /// `[init_carry, xs_0, xs_1, ..., xs_{num_xs-1}]`
1370 /// Each `xs_i` has shape `[length, *per_step_shape_i]`; the body
1371 /// sees `xs_i[t]` (a `per_step_shape_i` slice) on iteration `t`.
1372 ///
1373 /// Outer Op::Scan output:
1374 /// * `save_trajectory == false` — final carry, shape `*carry`.
1375 /// * `save_trajectory == true` — stacked trajectory of carries,
1376 /// shape `[length, *carry]`. Row `t` is the carry after step
1377 /// `t+1`, so row `length-1` matches the no-trajectory case.
1378 ///
1379 /// Common uses include time-stepping
1380 /// integrators with time-varying drives, Mamba-style SSM scans
1381 /// reading per-step inputs, and RNN-style sequence processing.
1382 Scan {
1383 body: Box<crate::Graph>,
1384 length: u32,
1385 save_trajectory: bool,
1386 /// Number of "broadcast" inputs — values that are constant
1387 /// across iterations. Outer scan inputs in order:
1388 /// `[init, bcast_0..bcast_{B-1}, xs_0..xs_{X-1}]`
1389 /// Body Op::Inputs in NodeId order:
1390 /// `[carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]`
1391 /// CPU executor fills bcast slots ONCE before the iteration
1392 /// loop (xs slots are filled per-step). The reverse-mode AD
1393 /// pre-pass materialises each bcast into an xs of shape
1394 /// `[length, *bcast]` via broadcast `Mul` so the rest of the
1395 /// VJP / executor pipeline can stay unchanged. `0` (default)
1396 /// keeps the original carry+xs scan shape.
1397 num_bcast: u32,
1398 /// Number of per-step `xs` inputs. Total outer Op::Scan
1399 /// inputs is `1 + num_bcast + num_xs`.
1400 num_xs: u32,
1401 /// Number of trajectory checkpoints when `save_trajectory ==
1402 /// true`. `0` means "save all `length` rows" (default). A
1403 /// positive value `K` means save only `K` evenly-spaced rows
1404 /// at indices `floor(t * length / K)` for `t in 0..K`. Used
1405 /// by recursive checkpointed AD: store O(√T) carries during
1406 /// forward, recompute the rest in the backward pass.
1407 ///
1408 /// When `0` (or `K == length`), the saved trajectory has
1409 /// shape `[length, *carry]` — same as the original behavior.
1410 /// When `0 < K < length`, the saved trajectory has shape
1411 /// `[K, *carry]`.
1412 num_checkpoints: u32,
1413 },
1414
1415 /// Reverse-mode AD companion to `Op::Scan` — extracts the carry
1416 /// gradient `dinit`. Walks `t = length-1 .. 0`, applying `body_vjp`
1417 /// to thread `dcarry` back through the time loop.
1418 ///
1419 /// Inputs (in order):
1420 /// `[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]`
1421 /// Output: `dinit`, shape = carry shape.
1422 ///
1423 /// `body_vjp` is the result of
1424 /// `autodiff::grad(body, [carry_id, xs_0_id, ..., xs_{num_xs-1}_id])`
1425 /// — a graph with `1 + num_xs + 1` Inputs (carry + x_t_i for each
1426 /// xs + `"d_output"`) and `1 + num_xs` outputs
1427 /// (dcarry + dx_t_i for each xs). This op reads `outputs[0]` =
1428 /// dcarry; the sibling [`Self::ScanBackwardXs`] reads the
1429 /// `outputs[1 + xs_idx]` slot for each xs gradient.
1430 ScanBackward {
1431 body_vjp: Box<crate::Graph>,
1432 length: u32,
1433 save_trajectory: bool,
1434 num_xs: u32,
1435 /// When `0` or equal to `length`, the trajectory input has
1436 /// shape `[length, *carry]` — every step's carry is cached
1437 /// (`CheckpointStrategy::All`). When `0 < K < length`, the
1438 /// trajectory input has shape `[K, *carry]` and the executor
1439 /// recomputes intermediate carries via `forward_body` between
1440 /// checkpoints. `forward_body` must be `Some` whenever
1441 /// `0 < num_checkpoints < length`.
1442 num_checkpoints: u32,
1443 /// Forward body (the same `body` from the forward Op::Scan).
1444 /// Required when `0 < num_checkpoints < length` so the
1445 /// executor can recompute carries between saved checkpoints.
1446 /// `None` for the All strategy (no recompute needed).
1447 forward_body: Option<Box<crate::Graph>>,
1448 },
1449
1450 /// Companion to [`Self::ScanBackward`] that extracts one stacked
1451 /// per-step `dxs_i` (shape `[length, *per_step_xs_i]`). Same inputs
1452 /// and same `body_vjp` graph as ScanBackward — `xs_idx` selects
1453 /// which body_vjp output to stack into the result.
1454 ///
1455 /// Note: each ScanBackwardXs runs its own backward loop. A future
1456 /// optimization can fuse them into a single multi-output backward
1457 /// kernel; for now it's `1 + num_xs` independent sweeps.
1458 ScanBackwardXs {
1459 body_vjp: Box<crate::Graph>,
1460 length: u32,
1461 save_trajectory: bool,
1462 num_xs: u32,
1463 xs_idx: u32,
1464 num_checkpoints: u32,
1465 forward_body: Option<Box<crate::Graph>>,
1466 },
1467
1468 /// CPU reference 3D Gaussian splat forward render.
1469 ///
1470 /// Seven flat F32 inputs (scene buffers + camera/render meta):
1471 /// 0. positions `[N*3]`
1472 /// 1. scales `[N*3]` (log-space)
1473 /// 2. rotations `[N*4]` (xyzw)
1474 /// 3. opacities `[N]` (logit)
1475 /// 4. colors `[N*3]` (linear RGB)
1476 /// 5. sh_coeffs `[N * sh_coeff_count * 3]`
1477 /// 6. meta `[23]` — camera position/target/up/fov/near/far, background RGB,
1478 /// then width/height/tile_size/radius_scale/alpha_cutoff/max_splat_steps/
1479 /// transmittance_threshold/max_list_entries as f32 bit-patterns.
1480 ///
1481 /// Output: `[height * width * 4]` linear RGBA (display gamma baked in).
1482 /// Build via [`crate::Graph::gaussian_splat_render`].
1483 ///
1484 /// Differentiable backward is not implemented in v1; autodiff treats this
1485 /// op as non-differentiable (same as [`Op::Sample`]).
1486 GaussianSplatRender {
1487 width: u32,
1488 height: u32,
1489 tile_size: u32,
1490 radius_scale: f32,
1491 alpha_cutoff: f32,
1492 max_splat_steps: u32,
1493 transmittance_threshold: f32,
1494 max_list_entries: u32,
1495 },
1496
1497 /// Backward pass for [`Self::GaussianSplatRender`].
1498 ///
1499 /// Eight inputs: the same seven as forward plus `d_loss_rgba` `[W*H*4]`
1500 /// (only RGB channels are used). Re-runs the training forward internally.
1501 ///
1502 /// Output: packed gradients
1503 /// `[positions(3N) | scales(3N) | rotations(4N) | opacities(N) | colors(3N) | sh(N*sh*3)]`.
1504 /// Unpack via [`crate::ops::splat::unpack_gaussian_splat_packed_grads`].
1505 GaussianSplatRenderBackward {
1506 width: u32,
1507 height: u32,
1508 tile_size: u32,
1509 radius_scale: f32,
1510 alpha_cutoff: f32,
1511 max_splat_steps: u32,
1512 transmittance_threshold: f32,
1513 max_list_entries: u32,
1514 loss_grad_clip: f32,
1515 sh_band: u32,
1516 max_anisotropy: f32,
1517 },
1518
1519 /// Strict IR stage 1: project, bin, sort, build per-pixel rays.
1520 ///
1521 /// Seven inputs (same scene + meta as [`Self::GaussianSplatRender`]). Output: packed
1522 /// prepare buffer (see `rlx_splat::prep_layout::prep_packed_len`).
1523 GaussianSplatPrepare {
1524 width: u32,
1525 height: u32,
1526 tile_size: u32,
1527 radius_scale: f32,
1528 alpha_cutoff: f32,
1529 max_splat_steps: u32,
1530 transmittance_threshold: f32,
1531 max_list_entries: u32,
1532 },
1533
1534 /// Strict IR stage 2: tile raster from [`Self::GaussianSplatPrepare`] output.
1535 ///
1536 /// Inputs: `prep` packed buffer, `meta` `[23]`. Output: `[width * height * 4]` RGBA.
1537 GaussianSplatRasterize {
1538 width: u32,
1539 height: u32,
1540 tile_size: u32,
1541 alpha_cutoff: f32,
1542 max_splat_steps: u32,
1543 transmittance_threshold: f32,
1544 max_list_entries: u32,
1545 },
1546
1547 /// User-registered custom op. `name` keys into the
1548 /// [`crate::op_registry`] for shape inference, autodiff, and
1549 /// per-backend execution. `attrs` is an opaque blob passed
1550 /// through to those callbacks (FFT direction, SparseLU
1551 /// reordering strategy, etc.). `num_inputs` is captured at
1552 /// construction time so [`Op::num_inputs`] stays infallible
1553 /// without a registry lookup. Build via [`crate::Graph::custom_op`].
1554 Custom {
1555 name: String,
1556 num_inputs: u32,
1557 attrs: Vec<u8>,
1558 },
1559
1560 /// 1D Fast Fourier Transform along the last axis.
1561 ///
1562 /// **Layouts**
1563 /// - `F32` / `F64`: 2N real-block — last axis is `[re₀…re_{N-1}, im₀…im_{N-1}]`.
1564 /// - `C64`: interleaved `[re, im]` pairs per complex element along the last axis.
1565 ///
1566 /// **ND transforms** — use `Graph::fftn` / `Graph::ifftn`, which compose
1567 /// `fft_axis` (transpose → Fft → transpose). Multi-axis `fftn` requires
1568 /// `DType::C64`; the 2N-block layout describes a single complex axis.
1569 ///
1570 /// Default (`FftNorm::Backward`) is **unnormalized** on both directions:
1571 /// `fft(x)[k] = Σ x[n]·exp(-2πi·nk/N)`
1572 /// `ifft(y)[n] = Σ y[k]·exp(+2πi·nk/N)`
1573 /// so `ifft(fft(x)) = N·x`. Use `FftNorm::Forward` for gpu-fft-style
1574 /// `1/N` scaling on inverse, or `FftNorm::Ortho` for unitary scaling.
1575 ///
1576 /// AD: VJP(`fft`) = `ifft`, VJP(`ifft`) = `fft` when `norm=Backward`;
1577 /// other norms apply the chain rule via output scaling.
1578 Fft {
1579 inverse: bool,
1580 norm: crate::fft::FftNorm,
1581 },
1582
1583 /// Ternary pruned radix-2 butterfly stage on interleaved complex state.
1584 ///
1585 /// Inputs:
1586 /// 0 — state `[batch, n_fft, 2]` (re/im on axis 2)
1587 /// 1 — gate `[half]` — 0 = identity, 1 = run butterfly (`half = n_fft/2`)
1588 /// 2 — rev `[half]` — 0 = forward, 1 = swap outputs when gate=1
1589 /// 3 — tw_re `[half]`
1590 /// 4 — tw_im `[half]`
1591 ///
1592 /// Output: `[batch, n_fft, 2]` same layout. Slots with gate=0 copy inputs
1593 /// without twiddle math.
1594 FftButterflyStage {
1595 stage: u32,
1596 n_fft: u32,
1597 },
1598
1599 /// Log-mel spectrogram from RLX FFT block-layout spectrum.
1600 ///
1601 /// Inputs:
1602 /// 0 — spectrum `[..., 2*n_fft]` (re plane then im plane, same as `Op::Fft` output)
1603 /// 1 — mel filterbank `[n_mels, n_bins]` with `n_bins = n_fft/2 + 1`
1604 ///
1605 /// Output: `[..., n_mels]` with Whisper dynamic-range compression
1606 /// (`log10`, clamp to max−8 dB, `(x+4)/4`).
1607 LogMel,
1608
1609 /// VJP of [`Op::LogMel`] w.r.t. spectrum (input 0).
1610 ///
1611 /// Inputs: spectrum block, mel filters, upstream `dy`.
1612 /// Output: `d_spectrum` (same shape as input 0).
1613 LogMelBackward,
1614
1615 /// Top-K Welch peaks from block-layout segment spectra.
1616 ///
1617 /// Input 0: spectrum `[batch * n_segments, 2*n_fft]` (re ∥ im planes).
1618 /// Output: `[batch, k*2]` interleaved `(bin, power)` per spike.
1619 WelchPeaks {
1620 k: usize,
1621 n_segments: usize,
1622 },
1623
1624 /// User-defined sub-graph with optional override AD rules.
1625 /// Mirrors JAX's `custom_vjp` / `custom_jvp` decorators: the
1626 /// caller wraps a forward computation and supplies its own
1627 /// reverse- and/or forward-mode AD bodies. Useful when:
1628 /// * The forward is iterative (Newton, fixed-point) and
1629 /// differentiating through the loop is wasteful — the
1630 /// vjp_body computes the implicit-function gradient at the
1631 /// converged point in one shot.
1632 /// * The math has a closed-form gradient that's much cheaper
1633 /// than autodiff.
1634 /// * The forward op is non-differentiable by tracing
1635 /// (sampling, argmax) and the user wants a smooth surrogate.
1636 ///
1637 /// **fwd_body**: `num_inputs` Op::Inputs in NodeId construction
1638 /// order, one Op::Output (the primal y). Forward execution
1639 /// inlines this body once.
1640 ///
1641 /// **vjp_body** (optional): Op::Inputs are `num_inputs` primal
1642 /// inputs in NodeId order, plus two special-named Inputs —
1643 /// `"primal_output"` (the y from forward) and `"d_output"` (the
1644 /// upstream gradient). Outputs: `num_inputs` tensors in
1645 /// `set_outputs` order, matching the gradients of each primal
1646 /// input. When `None`, reverse-mode AD recurses into fwd_body
1647 /// — same as if the op were inlined.
1648 ///
1649 /// **jvp_body** (optional): Op::Inputs are `num_inputs` primal
1650 /// inputs in NodeId order, `num_inputs` special-named Inputs
1651 /// `"tangent_0"..="tangent_{num_inputs-1}"` carrying each input's
1652 /// tangent, and an optional special-named `"primal_output"` Input
1653 /// (the y from forward, useful when the JVP must be evaluated at
1654 /// a converged / nonlinear point — e.g. IFT-style forward-mode AD
1655 /// of an iterative solver). Output: 1 tensor (the tangent of y).
1656 /// When `None`, forward-mode AD recurses into fwd_body.
1657 ///
1658 /// `num_inputs` is captured so [`Op::num_inputs`] stays
1659 /// infallible. Build via [`crate::Graph::custom_fn`].
1660 CustomFn {
1661 fwd_body: Box<crate::Graph>,
1662 vjp_body: Option<Box<crate::Graph>>,
1663 jvp_body: Option<Box<crate::Graph>>,
1664 num_inputs: u32,
1665 },
1666}
1667
// Structural queries over `Op`: a parameter-free kind discriminant,
// fusion-category predicates, and the expected tensor-input arity.
impl Op {
    /// PLAN L4: discriminant for backend-supported-set checks.
    /// Stable, parameter-free identity per variant — `Op::Activation(_)`
    /// and `Op::Activation(Relu)` share the same `OpKind::Activation`.
    ///
    /// The match below has no wildcard arm, so every new `Op` variant must
    /// be given an explicit `OpKind` mapping here before the crate compiles.
    pub fn kind(&self) -> OpKind {
        match self {
            Op::Input { .. } => OpKind::Input,
            Op::Param { .. } => OpKind::Param,
            Op::Constant { .. } => OpKind::Constant,
            Op::Activation(_) => OpKind::Activation,
            Op::Cast { .. } => OpKind::Cast,
            Op::StopGradient => OpKind::StopGradient,
            Op::Quantize { .. } => OpKind::Quantize,
            Op::Dequantize { .. } => OpKind::Dequantize,
            Op::FakeQuantize { .. } => OpKind::FakeQuantize,
            Op::FakeQuantizeLSQ { .. } => OpKind::FakeQuantizeLSQ,
            Op::FakeQuantizeLSQBackwardX { .. } => OpKind::FakeQuantizeLSQBackwardX,
            Op::FakeQuantizeLSQBackwardScale { .. } => OpKind::FakeQuantizeLSQBackwardScale,
            Op::Binary(_) => OpKind::Binary,
            Op::Compare(_) => OpKind::Compare,
            Op::Where => OpKind::Where,
            Op::ElementwiseRegion { .. } => OpKind::ElementwiseRegion,
            Op::TransformRegion { .. } => OpKind::TransformRegion,
            Op::BatchElementwiseRegion { .. } => OpKind::BatchElementwiseRegion,
            Op::MatMul => OpKind::MatMul,
            Op::DotGeneral { .. } => OpKind::DotGeneral,
            Op::DenseSolve => OpKind::DenseSolve,
            Op::BatchedDenseSolve => OpKind::BatchedDenseSolve,
            Op::LayerNorm { .. } => OpKind::LayerNorm,
            Op::LayerNorm2d { .. } => OpKind::LayerNorm2d,
            Op::GroupNorm { .. } => OpKind::GroupNorm,
            Op::BatchNormInference { .. } => OpKind::BatchNormInference,
            Op::RmsNorm { .. } => OpKind::RmsNorm,
            Op::ResizeNearest2x => OpKind::ResizeNearest2x,
            Op::Attention { .. } => OpKind::Attention,
            Op::Rope { .. } => OpKind::Rope,
            Op::AxialRope2d { .. } => OpKind::AxialRope2d,
            Op::Reshape { .. } => OpKind::Reshape,
            Op::Transpose { .. } => OpKind::Transpose,
            Op::Narrow { .. } => OpKind::Narrow,
            Op::Concat { .. } => OpKind::Concat,
            Op::Expand { .. } => OpKind::Expand,
            Op::Gather { .. } => OpKind::Gather,
            Op::Reduce { .. } => OpKind::Reduce,
            Op::Softmax { .. } => OpKind::Softmax,
            Op::Cumsum { .. } => OpKind::Cumsum,
            Op::TopK { .. } => OpKind::TopK,
            Op::Sample { .. } => OpKind::Sample,
            Op::Conv { .. } => OpKind::Conv,
            Op::Im2Col { .. } => OpKind::Im2Col,
            Op::ConvTranspose2d { .. } => OpKind::ConvTranspose2d,
            Op::Pool { .. } => OpKind::Pool,
            Op::ReluBackward => OpKind::ReluBackward,
            Op::ActivationBackward { .. } => OpKind::ActivationBackward,
            Op::FakeQuantizeBackward { .. } => OpKind::FakeQuantizeBackward,
            Op::ComplexNormSq => OpKind::ComplexNormSq,
            Op::ComplexNormSqBackward => OpKind::ComplexNormSqBackward,
            Op::Conjugate => OpKind::Conjugate,
            Op::LayerNormBackwardInput { .. } => OpKind::LayerNormBackwardInput,
            Op::LayerNormBackwardGamma { .. } => OpKind::LayerNormBackwardGamma,
            Op::RmsNormBackwardInput { .. } => OpKind::RmsNormBackwardInput,
            Op::RmsNormBackwardGamma { .. } => OpKind::RmsNormBackwardGamma,
            Op::RmsNormBackwardBeta { .. } => OpKind::RmsNormBackwardBeta,
            Op::RopeBackward { .. } => OpKind::RopeBackward,
            Op::GroupNormBackwardInput { .. } => OpKind::GroupNormBackwardInput,
            Op::GroupNormBackwardGamma { .. } => OpKind::GroupNormBackwardGamma,
            Op::GroupNormBackwardBeta { .. } => OpKind::GroupNormBackwardBeta,
            Op::BatchNormInferenceBackwardInput { .. } => OpKind::BatchNormInferenceBackwardInput,
            Op::BatchNormInferenceBackwardGamma { .. } => OpKind::BatchNormInferenceBackwardGamma,
            Op::BatchNormInferenceBackwardBeta => OpKind::BatchNormInferenceBackwardBeta,
            Op::CumsumBackward { .. } => OpKind::CumsumBackward,
            Op::GatherBackward { .. } => OpKind::GatherBackward,
            Op::MaxPool2dBackward { .. } => OpKind::MaxPool2dBackward,
            Op::Conv2dBackwardInput { .. } => OpKind::Conv2dBackwardInput,
            Op::Conv2dBackwardWeight { .. } => OpKind::Conv2dBackwardWeight,
            Op::SoftmaxCrossEntropyWithLogits => OpKind::SoftmaxCrossEntropyWithLogits,
            Op::SoftmaxCrossEntropyBackward => OpKind::SoftmaxCrossEntropyBackward,
            Op::AttentionBackward { .. } => OpKind::AttentionBackward,
            Op::GroupedMatMul => OpKind::GroupedMatMul,
            Op::DequantGroupedMatMul { .. } => OpKind::DequantGroupedMatMul,
            Op::DequantMoEWeights { .. } => OpKind::DequantMoEWeights,
            Op::ScatterAdd => OpKind::ScatterAdd,
            Op::LoraMatMul { .. } => OpKind::LoraMatMul,
            Op::DequantMatMul { .. } => OpKind::DequantMatMul,
            Op::QMatMul { .. } => OpKind::QMatMul,
            Op::QConv2d { .. } => OpKind::QConv2d,
            Op::SelectiveScan { .. } => OpKind::SelectiveScan,
            Op::GatedDeltaNet { .. } => OpKind::GatedDeltaNet,
            Op::FusedSwiGLU { .. } => OpKind::FusedSwiGLU,
            Op::FusedMatMulBiasAct { .. } => OpKind::FusedMatMulBiasAct,
            Op::FusedResidualLN { .. } => OpKind::FusedResidualLN,
            Op::FusedResidualRmsNorm { .. } => OpKind::FusedResidualRmsNorm,
            Op::FusedAttentionBlock { .. } => OpKind::FusedAttentionBlock,
            Op::FusedTransformerLayer { .. } => OpKind::FusedTransformerLayer,
            Op::If { .. } => OpKind::If,
            Op::While { .. } => OpKind::While,
            Op::Scan { .. } => OpKind::Scan,
            Op::ScanBackward { .. } => OpKind::ScanBackward,
            Op::ScanBackwardXs { .. } => OpKind::ScanBackwardXs,
            Op::GaussianSplatRender { .. } => OpKind::GaussianSplatRender,
            Op::GaussianSplatRenderBackward { .. } => OpKind::GaussianSplatRenderBackward,
            Op::GaussianSplatPrepare { .. } => OpKind::GaussianSplatPrepare,
            Op::GaussianSplatRasterize { .. } => OpKind::GaussianSplatRasterize,
            Op::Custom { .. } => OpKind::Custom,
            Op::CustomFn { .. } => OpKind::CustomFn,
            Op::Fft { .. } => OpKind::Fft,
            Op::FftButterflyStage { .. } => OpKind::FftButterflyStage,
            Op::LogMel => OpKind::LogMel,
            Op::LogMelBackward => OpKind::LogMelBackward,
            Op::WelchPeaks { .. } => OpKind::WelchPeaks,
        }
    }

    /// True if this op is element-wise (same shape in, same shape out).
    /// Element-wise ops are prime fusion candidates.
    ///
    /// Already-fused `ElementwiseRegion` / `BatchElementwiseRegion` ops count
    /// as element-wise too, so later fusion passes can keep extending them.
    pub fn is_elementwise(&self) -> bool {
        matches!(
            self,
            Op::Activation(_)
                | Op::Cast { .. }
                | Op::StopGradient
                | Op::Binary(_)
                | Op::Compare(_)
                | Op::Where
                | Op::ElementwiseRegion { .. }
                | Op::BatchElementwiseRegion { .. }
        )
    }

    /// True if this op may appear in a [`Op::TransformRegion`] chain.
    /// Currently the only eligible op is `ResizeNearest2x`.
    pub fn is_transform_eligible(&self) -> bool {
        matches!(self, Op::ResizeNearest2x)
    }

    /// True if this op is a BLAS/compute-intensive op that forms a fusion boundary.
    ///
    /// Covers plain and batched matmuls and solves, the convolution family
    /// (including `Im2Col`), and the quantized / LoRA / MoE matmul variants.
    pub fn is_blas(&self) -> bool {
        matches!(
            self,
            Op::MatMul
                | Op::DotGeneral { .. }
                | Op::DenseSolve
                | Op::BatchedDenseSolve
                | Op::Conv { .. }
                | Op::Im2Col { .. }
                | Op::ConvTranspose2d { .. }
                | Op::FusedMatMulBiasAct { .. }
                | Op::GroupedMatMul
                | Op::DequantGroupedMatMul { .. }
                | Op::DequantMoEWeights { .. }
                | Op::LoraMatMul { .. }
                | Op::DequantMatMul { .. }
                | Op::QMatMul { .. }
                | Op::QConv2d { .. }
        )
    }

    /// True if element-wise fusion must not span across this op.
    ///
    /// This is every op for which [`Op::is_blas`] holds, plus the four
    /// Gaussian-splat ops (render, render-backward, prepare, rasterize).
    pub fn is_fusion_boundary(&self) -> bool {
        self.is_blas()
            || matches!(
                self,
                Op::GaussianSplatRender { .. }
                    | Op::GaussianSplatRenderBackward { .. }
                    | Op::GaussianSplatPrepare { .. }
                    | Op::GaussianSplatRasterize { .. }
            )
    }

    /// True if this op is a reduction (drives loop iteration in fused kernels).
    /// Currently `Reduce`, `Softmax` and `TopK`.
    pub fn is_reduction(&self) -> bool {
        matches!(
            self,
            Op::Reduce { .. } | Op::Softmax { .. } | Op::TopK { .. }
        )
    }

    /// Number of tensor inputs this op expects.
    ///
    /// Leaf ops (`Input`, `Param`, `Constant`) return 0. The variadic ops
    /// `Concat` and `While` also return 0, because their operand count is
    /// checked at graph level instead. Several arities depend on variant
    /// fields: `FakeQuantize`'s scale mode, `Attention`'s mask kind,
    /// `DequantMatMul`'s scheme, the `has_bias` / `has_rope` / `carry_state`
    /// flags on fused ops, and the `num_*` counts stored on region, scan
    /// and custom ops.
    pub fn num_inputs(&self) -> usize {
        match self {
            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0,
            Op::Activation(_)
            | Op::Cast { .. }
            | Op::StopGradient
            | Op::Reshape { .. }
            | Op::Quantize { .. }
            | Op::Dequantize { .. }
            | Op::Transpose { .. }
            | Op::Narrow { .. }
            | Op::Expand { .. }
            | Op::Reduce { .. }
            | Op::Softmax { .. }
            | Op::FusedSwiGLU { .. }
            | Op::TopK { .. }
            | Op::Cumsum { .. }
            | Op::Sample { .. }
            | Op::ResizeNearest2x => 1,
            // EMA / Fixed scale modes carry a state tensor as a 2nd input;
            // PerBatch (default) doesn't need one.
            Op::FakeQuantize { scale_mode, .. } => match scale_mode {
                ScaleMode::PerBatch => 1,
                ScaleMode::EMA { .. } | ScaleMode::Fixed => 2,
            },
            Op::FakeQuantizeLSQ { .. } => 2, // x, scale (learned param)
            Op::FakeQuantizeLSQBackwardX { .. } | Op::FakeQuantizeLSQBackwardScale { .. } => 3, // x, scale, dy
            Op::Binary(_) | Op::Compare(_) | Op::Gather { .. } | Op::MatMul | Op::ScatterAdd => 2,
            Op::GroupedMatMul => 3,               // input, weight, expert_idx
            Op::DequantGroupedMatMul { .. } => 3, // input, packed_w, expert_idx
            Op::DequantMoEWeights { .. } => 1,    // packed_w
            Op::LoraMatMul { .. } => 4,           // x, w, a, b
            // x, w_q, scale, zp — or x, packed_w_bytes for GGUF
            // schemes (their scales/mins live inside the packed bytes,
            // see `QuantScheme::is_gguf`).
            Op::DequantMatMul { scheme } => {
                if scheme.is_gguf() {
                    2
                } else {
                    4
                }
            }
            Op::QMatMul { .. } => 3,       // x, w, bias
            Op::QConv2d { .. } => 3,       // x, w, bias
            Op::SelectiveScan { .. } => 5, // x, delta, a, b, c
            // Match arms are tried in order: this guarded arm must stay
            // above the unguarded `GatedDeltaNet` arm that follows it.
            Op::GatedDeltaNet { carry_state, .. } if *carry_state => 6, // + state in/out
            Op::GatedDeltaNet { .. } => 5, // q, k, v, g, beta
            Op::Where => 3,                // cond, on_true, on_false
            Op::Attention { mask_kind, .. } => match mask_kind {
                MaskKind::Custom | MaskKind::Bias => 4, // Q, K, V, mask
                _ => 3, // Q, K, V (mask synthesized in-kernel)
            },
            Op::AttentionBackward { mask_kind, .. } => match mask_kind {
                MaskKind::Custom | MaskKind::Bias => 5, // q, k, v, dy, mask
                _ => 4,                                 // q, k, v, dy
            },
            Op::Rope { .. } => 3, // x, cos, sin
            // Single tensor operand; the frequency grid is described by op
            // fields (end_x / end_y / theta) rather than cos/sin inputs.
            Op::AxialRope2d { .. } => 1,
            Op::LayerNorm { .. }
            | Op::LayerNorm2d { .. }
            | Op::GroupNorm { .. }
            | Op::RmsNorm { .. } => 3, // input, gamma, beta
            Op::BatchNormInference { .. } => 5, // x, gamma, beta, mean, var
            Op::FusedMatMulBiasAct { .. } => 3, // input, weight, bias
            Op::FusedResidualLN { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
            Op::FusedResidualLN {
                has_bias: false, ..
            } => 4, // x, residual, gamma, beta
            Op::FusedResidualRmsNorm { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
            Op::FusedResidualRmsNorm {
                has_bias: false, ..
            } => 4, // x, residual, gamma, beta
            Op::Conv { .. } | Op::ConvTranspose2d { .. } => 2, // input, weight (bias via Add)
            Op::Im2Col { .. } => 1,
            Op::Pool { .. } => 1,
            Op::ReluBackward => 2,                     // x, dy
            Op::ActivationBackward { .. } => 2,        // x, dy
            Op::FakeQuantizeBackward { .. } => 2,      // x, dy
            Op::ComplexNormSq => 1,                    // z (C64)
            Op::ComplexNormSqBackward => 2,            // z, g
            Op::Conjugate => 1,                        // z (C64)
            Op::LayerNormBackwardInput { .. } => 3,    // x, gamma, dy
            Op::LayerNormBackwardGamma { .. } => 2,    // x, dy
            Op::RmsNormBackwardInput { .. } => 4,      // x, gamma, beta, dy
            // NOTE(review): presumably the same x, gamma, beta, dy operand
            // list as RmsNormBackwardInput — confirm against the AD builder.
            Op::RmsNormBackwardGamma { .. } => 4,
            Op::RmsNormBackwardBeta { .. } => 4,
            Op::RopeBackward { .. } => 3,              // dy, cos, sin
            Op::GroupNormBackwardInput { .. } => 4,    // x, gamma, beta, dy
            Op::GroupNormBackwardGamma { .. } => 2,    // x, dy
            // NOTE(review): presumably x, dy like the gamma variant — confirm.
            Op::GroupNormBackwardBeta { .. } => 2,
            Op::BatchNormInferenceBackwardInput { .. } => 5, // x, gamma, mean, var, dy
            Op::BatchNormInferenceBackwardGamma { .. } => 4, // x, mean, var, dy
            Op::BatchNormInferenceBackwardBeta => 1,         // dy
            Op::CumsumBackward { .. } => 1,                  // dy
            Op::GatherBackward { .. } => 2,                  // dy, indices
            Op::MaxPool2dBackward { .. } => 2,               // x, dy
            Op::Conv2dBackwardInput { .. } => 2,             // dy, w
            Op::Conv2dBackwardWeight { .. } => 2,            // x, dy
            Op::SoftmaxCrossEntropyWithLogits => 2,          // logits, labels
            Op::SoftmaxCrossEntropyBackward => 3,            // logits, labels, d_loss
            Op::Concat { .. } => 0, // variadic — checked at graph level
            Op::DotGeneral { .. } => 2,
            Op::DenseSolve => 2,        // A, b
            Op::BatchedDenseSolve => 2, // A [B,N,N], b [B,N] or [B,N,K]
            // 4 base operands, +2 when has_bias, +2 when has_rope.
            Op::FusedAttentionBlock {
                has_bias, has_rope, ..
            } => 4 + if *has_bias { 2 } else { 0 } + if *has_rope { 2 } else { 0 },
            Op::If { .. } => 1,    // predicate (captures handled separately)
            Op::While { .. } => 0, // variadic loop-carried; checked at graph level
            // 1 leading operand, then `num_bcast` broadcast operands, then
            // `num_xs` per-step sequence operands.
            Op::Scan {
                num_bcast, num_xs, ..
            } => 1 + *num_bcast as usize + *num_xs as usize,
            Op::ScanBackward { num_xs, .. } => 3 + *num_xs as usize, // init, trajectory, upstream, xs_0..
            Op::ScanBackwardXs { num_xs, .. } => 3 + *num_xs as usize, // same as ScanBackward
            Op::GaussianSplatRender { .. } => 7,
            Op::GaussianSplatRenderBackward { .. } => 8,
            Op::GaussianSplatPrepare { .. } => 7,
            Op::GaussianSplatRasterize { .. } => 2,
            Op::FusedTransformerLayer { has_bias, .. } => {
                // hidden + qkv_w + out_w + ln1_g + ln1_b + fc1_w + fc2_w + ln2_g + ln2_b + mask = 10
                // bias variant adds: qkv_b + out_b + fc1_b + fc2_b = 4 more
                10 + if *has_bias { 4 } else { 0 }
            }
            // Region / custom ops record their own arity when they are built.
            Op::ElementwiseRegion { num_inputs, .. } => *num_inputs as usize,
            Op::TransformRegion { num_inputs, .. } => *num_inputs as usize,
            Op::BatchElementwiseRegion {
                num_batch_inputs, ..
            } => *num_batch_inputs as usize,
            Op::Custom { num_inputs, .. } => *num_inputs as usize,
            // Captured at construction so this stays infallible without
            // inspecting the nested `fwd_body` graph.
            Op::CustomFn { num_inputs, .. } => *num_inputs as usize,
            Op::Fft { .. } => 1,
            Op::FftButterflyStage { .. } => 5,
            Op::LogMel => 2,         // spectrum, mel filterbank
            Op::LogMelBackward => 3, // spectrum, mel filters, dy
            Op::WelchPeaks { .. } => 1, // segment spectra
        }
    }
}
1983
1984impl std::fmt::Display for Op {
1985 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1986 match self {
1987 Op::Input { name } => write!(f, "input(\"{name}\")"),
1988 Op::Param { name } => write!(f, "param(\"{name}\")"),
1989 Op::Constant { data } => write!(f, "const({}B)", data.len()),
1990 Op::Activation(a) => write!(f, "{a:?}"),
1991 Op::Quantize { axis, scales, .. } => match axis {
1992 None => write!(f, "quantize(s={})", scales[0]),
1993 Some(d) => write!(f, "quantize(axis={d},nch={})", scales.len()),
1994 },
1995 Op::Dequantize { axis, scales, .. } => match axis {
1996 None => write!(f, "dequantize(s={})", scales[0]),
1997 Some(d) => write!(f, "dequantize(axis={d},nch={})", scales.len()),
1998 },
1999 Op::FakeQuantize {
2000 bits,
2001 axis,
2002 ste,
2003 scale_mode,
2004 } => match axis {
2005 None => write!(
2006 f,
2007 "fake_quant(bits={bits},ste={ste:?},scale={scale_mode:?})"
2008 ),
2009 Some(d) => write!(
2010 f,
2011 "fake_quant(bits={bits},axis={d},ste={ste:?},scale={scale_mode:?})"
2012 ),
2013 },
2014 Op::FakeQuantizeLSQ { bits, axis } => match axis {
2015 None => write!(f, "fake_quant_lsq(bits={bits})"),
2016 Some(d) => write!(f, "fake_quant_lsq(bits={bits},axis={d})"),
2017 },
2018 Op::FakeQuantizeLSQBackwardX { bits, .. } => {
2019 write!(f, "fake_quant_lsq_bwd_x(bits={bits})")
2020 }
2021 Op::FakeQuantizeLSQBackwardScale { bits, .. } => {
2022 write!(f, "fake_quant_lsq_bwd_s(bits={bits})")
2023 }
2024 Op::Cast { to } => write!(f, "cast({to})"),
2025 Op::StopGradient => write!(f, "stop_gradient"),
2026 Op::Binary(op) => write!(f, "{op:?}"),
2027 Op::Compare(op) => write!(f, "{op:?}"),
2028 Op::Where => write!(f, "where"),
2029 Op::MatMul => write!(f, "matmul"),
2030 Op::DotGeneral { .. } => write!(f, "dot_general"),
2031 Op::DenseSolve => write!(f, "dense_solve"),
2032 Op::BatchedDenseSolve => write!(f, "batched_dense_solve"),
2033 Op::LayerNorm { eps, .. } => write!(f, "layer_norm(eps={eps})"),
2034 Op::GroupNorm { num_groups, eps } => {
2035 write!(f, "group_norm(groups={num_groups},eps={eps})")
2036 }
2037 Op::BatchNormInference { eps } => write!(f, "batch_norm_inference(eps={eps})"),
2038 Op::ResizeNearest2x => write!(f, "resize_nearest_2x"),
2039 Op::RmsNorm { eps, .. } => write!(f, "rms_norm(eps={eps})"),
2040 Op::Attention {
2041 num_heads,
2042 head_dim,
2043 mask_kind,
2044 score_scale,
2045 attn_logit_softcap,
2046 } => {
2047 let mut s = match mask_kind {
2048 MaskKind::Custom => format!("attention(h={num_heads},d={head_dim})"),
2049 MaskKind::None => format!("attention(h={num_heads},d={head_dim},nomask)"),
2050 MaskKind::Causal => format!("attention(h={num_heads},d={head_dim},causal)"),
2051 MaskKind::SlidingWindow(w) => {
2052 format!("attention(h={num_heads},d={head_dim},sw={w})")
2053 }
2054 MaskKind::Bias => format!("attention(h={num_heads},d={head_dim},bias)"),
2055 };
2056 if let Some(sc) = score_scale {
2057 s.push_str(&format!(",scale={sc}"));
2058 }
2059 if let Some(cap) = attn_logit_softcap {
2060 s.push_str(&format!(",softcap={cap}"));
2061 }
2062 write!(f, "{s}")
2063 }
2064 Op::Rope { head_dim, n_rot } => write!(f, "rope(d={head_dim}, n_rot={n_rot})"),
2065 Op::AxialRope2d {
2066 end_x,
2067 end_y,
2068 head_dim,
2069 num_heads,
2070 theta,
2071 repeat_factor,
2072 } => write!(
2073 f,
2074 "axial_rope2d({end_x}x{end_y},h={num_heads},d={head_dim},θ={theta},r={repeat_factor})"
2075 ),
2076 Op::Reshape { new_shape } => write!(f, "reshape({new_shape:?})"),
2077 Op::Transpose { perm } => write!(f, "transpose({perm:?})"),
2078 Op::Narrow { axis, start, len } => write!(f, "narrow({axis},{start},{len})"),
2079 Op::Concat { axis } => write!(f, "concat(axis={axis})"),
2080 Op::Expand { .. } => write!(f, "expand"),
2081 Op::Gather { axis } => write!(f, "gather(axis={axis})"),
2082 Op::Reduce { op, axes, .. } => write!(f, "reduce_{op:?}({axes:?})"),
2083 Op::Softmax { axis } => write!(f, "softmax(axis={axis})"),
2084 Op::Cumsum { axis, exclusive } => {
2085 if *exclusive {
2086 write!(f, "cumsum(axis={axis},excl)")
2087 } else {
2088 write!(f, "cumsum(axis={axis})")
2089 }
2090 }
2091 Op::Sample {
2092 top_k,
2093 top_p,
2094 temperature,
2095 ..
2096 } => {
2097 write!(f, "sample(t={temperature}")?;
2098 if *top_k > 0 {
2099 write!(f, ",k={top_k}")?;
2100 }
2101 if *top_p < 1.0 {
2102 write!(f, ",p={top_p}")?;
2103 }
2104 write!(f, ")")
2105 }
2106 Op::TopK { k } => write!(f, "topk(k={k})"),
2107 Op::GroupedMatMul => write!(f, "grouped_matmul"),
2108 Op::DequantGroupedMatMul { scheme } => {
2109 write!(f, "dequant_grouped_matmul({scheme})")
2110 }
2111 Op::DequantMoEWeights { scheme } => write!(f, "dequant_moe_weights({scheme})"),
2112 Op::LoraMatMul { scale } => write!(f, "lora_matmul(scale={scale})"),
2113 Op::DequantMatMul { scheme } => write!(f, "dequant_matmul({scheme})"),
2114 Op::QMatMul {
2115 x_zp,
2116 w_zp,
2117 out_zp,
2118 mult,
2119 } => write!(
2120 f,
2121 "q_matmul(x_zp={x_zp},w_zp={w_zp},out_zp={out_zp},mult={mult})"
2122 ),
2123 Op::QConv2d { kernel_size, .. } => write!(f, "q_conv2d({kernel_size:?})"),
2124 Op::SelectiveScan { state_size } => write!(f, "ssm_scan(n={state_size})"),
2125 Op::GatedDeltaNet {
2126 state_size,
2127 carry_state,
2128 } => {
2129 if *carry_state {
2130 write!(f, "gated_delta_net(n={state_size},carry)")
2131 } else {
2132 write!(f, "gated_delta_net(n={state_size})")
2133 }
2134 }
2135 Op::ScatterAdd => write!(f, "scatter_add"),
2136 Op::Conv { kernel_size, .. } => write!(f, "conv2d({kernel_size:?})"),
2137 Op::Im2Col { kernel_size, .. } => write!(f, "im2col({kernel_size:?})"),
2138 Op::ConvTranspose2d { kernel_size, .. } => {
2139 write!(f, "conv_transpose2d({kernel_size:?})")
2140 }
2141 Op::LayerNorm2d { eps } => write!(f, "layer_norm2d(eps={eps})"),
2142 Op::Pool {
2143 kind, kernel_size, ..
2144 } => write!(f, "pool_{kind:?}({kernel_size:?})"),
2145 Op::ReluBackward => write!(f, "relu_backward"),
2146 Op::ActivationBackward { kind } => write!(f, "{kind:?}_backward"),
2147 Op::ComplexNormSq => write!(f, "complex_norm_sq"),
2148 Op::ComplexNormSqBackward => write!(f, "complex_norm_sq_backward"),
2149 Op::Conjugate => write!(f, "conjugate"),
2150 Op::FakeQuantizeBackward { bits, ste, .. } => {
2151 write!(f, "fake_quant_backward(bits={bits},ste={ste:?})")
2152 }
2153 Op::MaxPool2dBackward { kernel_size, .. } => {
2154 write!(f, "maxpool2d_backward({kernel_size:?})")
2155 }
2156 Op::Conv2dBackwardInput { kernel_size, .. } => {
2157 write!(f, "conv2d_backward_input({kernel_size:?})")
2158 }
2159 Op::Conv2dBackwardWeight { kernel_size, .. } => {
2160 write!(f, "conv2d_backward_weight({kernel_size:?})")
2161 }
2162 Op::SoftmaxCrossEntropyWithLogits => write!(f, "sce_with_logits"),
2163 Op::SoftmaxCrossEntropyBackward => write!(f, "sce_backward"),
2164 Op::AttentionBackward {
2165 num_heads,
2166 head_dim,
2167 mask_kind,
2168 wrt,
2169 } => match mask_kind {
2170 MaskKind::None => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},nomask)"),
2171 MaskKind::Causal => {
2172 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},causal)")
2173 }
2174 MaskKind::SlidingWindow(w) => {
2175 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},sw={w})")
2176 }
2177 MaskKind::Custom => {
2178 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},custom)")
2179 }
2180 MaskKind::Bias => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},bias)"),
2181 },
2182 Op::FusedMatMulBiasAct { activation } => {
2183 write!(f, "fused_mm_bias")?;
2184 if let Some(a) = activation {
2185 write!(f, "_{a:?}")?;
2186 }
2187 Ok(())
2188 }
2189 Op::FusedResidualLN { has_bias, eps } => {
2190 write!(f, "fused_residual")?;
2191 if *has_bias {
2192 write!(f, "_bias")?;
2193 }
2194 write!(f, "_ln(eps={eps})")
2195 }
2196 Op::FusedResidualRmsNorm { has_bias, eps } => {
2197 write!(f, "fused_residual")?;
2198 if *has_bias {
2199 write!(f, "_bias")?;
2200 }
2201 write!(f, "_rms(eps={eps})")
2202 }
2203 Op::FusedSwiGLU {
2204 cast_to,
2205 gate_first,
2206 } => {
2207 let mut s = match cast_to {
2208 Some(dt) => format!("fused_swiglu(cast={dt}"),
2209 None => "fused_swiglu(".to_string(),
2210 };
2211 if *gate_first {
2212 s.push_str(",gate_first");
2213 }
2214 s.push(')');
2215 write!(f, "{s}")
2216 }
2217 Op::FusedAttentionBlock {
2218 num_heads,
2219 head_dim,
2220 has_bias,
2221 has_rope,
2222 } => {
2223 write!(f, "fused_attn(h={num_heads},d={head_dim}")?;
2224 if *has_bias {
2225 write!(f, ",bias")?;
2226 }
2227 if *has_rope {
2228 write!(f, ",rope")?;
2229 }
2230 write!(f, ")")
2231 }
2232 Op::If { .. } => write!(f, "if(...)"),
2233 Op::While { max_iterations, .. } => match max_iterations {
2234 Some(n) => write!(f, "while(...max={n})"),
2235 None => write!(f, "while(...)"),
2236 },
2237 Op::Scan {
2238 length,
2239 save_trajectory,
2240 num_xs,
2241 ..
2242 } => {
2243 let traj = if *save_trajectory { ",traj" } else { "" };
2244 let xs = if *num_xs > 0 {
2245 format!(",xs={}", num_xs)
2246 } else {
2247 String::new()
2248 };
2249 write!(f, "scan(len={length}{xs}{traj})")
2250 }
2251 Op::ScanBackward {
2252 length,
2253 save_trajectory,
2254 num_xs,
2255 ..
2256 } => {
2257 let traj = if *save_trajectory { ",traj" } else { "" };
2258 let xs = if *num_xs > 0 {
2259 format!(",xs={}", num_xs)
2260 } else {
2261 String::new()
2262 };
2263 write!(f, "scan_bwd(len={length}{xs}{traj})")
2264 }
2265 Op::ScanBackwardXs {
2266 length,
2267 save_trajectory,
2268 num_xs,
2269 xs_idx,
2270 ..
2271 } => {
2272 let traj = if *save_trajectory { ",traj" } else { "" };
2273 write!(
2274 f,
2275 "scan_bwd_xs(len={length},xs={num_xs},idx={xs_idx}{traj})"
2276 )
2277 }
2278 Op::FusedTransformerLayer {
2279 num_heads,
2280 head_dim,
2281 intermediate_size,
2282 has_bias,
2283 ..
2284 } => {
2285 write!(
2286 f,
2287 "fused_layer(h={num_heads},d={head_dim},int={intermediate_size}"
2288 )?;
2289 if *has_bias {
2290 write!(f, ",bias")?;
2291 }
2292 write!(f, ")")
2293 }
2294 Op::ElementwiseRegion {
2295 chain,
2296 num_inputs,
2297 scalar_input_mask,
2298 input_modulus: _,
2299 prologue,
2300 prologue_input: _,
2301 } => {
2302 let pro = match prologue {
2303 RegionPrologue::None => "",
2304 RegionPrologue::ResizeNearest2x => ",prologue=resize2x",
2305 };
2306 if *scalar_input_mask != 0 {
2307 write!(
2308 f,
2309 "ew_region(in={num_inputs},steps={},scalar_mask=0x{:x}{pro})",
2310 chain.len(),
2311 scalar_input_mask
2312 )
2313 } else {
2314 write!(f, "ew_region(in={num_inputs},steps={}{pro})", chain.len())
2315 }
2316 }
2317 Op::TransformRegion { steps, num_inputs } => {
2318 write!(f, "transform_region(in={num_inputs},steps={})", steps.len())
2319 }
2320 Op::BatchElementwiseRegion {
2321 chain,
2322 num_batch_inputs,
2323 scalar_input_mask,
2324 prologue,
2325 ..
2326 } => write!(
2327 f,
2328 "batch_ew_region(batch={num_batch_inputs},steps={},mask=0x{:x},prologue={prologue:?})",
2329 chain.len(),
2330 scalar_input_mask
2331 ),
2332 Op::LayerNormBackwardInput { eps, .. } => {
2333 write!(f, "layer_norm_backward_input(eps={eps})")
2334 }
2335 Op::LayerNormBackwardGamma { eps, .. } => {
2336 write!(f, "layer_norm_backward_gamma(eps={eps})")
2337 }
2338 Op::RmsNormBackwardInput { eps, .. } => write!(f, "rms_norm_backward_input(eps={eps})"),
2339 Op::RmsNormBackwardGamma { eps, .. } => write!(f, "rms_norm_backward_gamma(eps={eps})"),
2340 Op::RmsNormBackwardBeta { eps, .. } => write!(f, "rms_norm_backward_beta(eps={eps})"),
2341 Op::RopeBackward { head_dim, n_rot } => {
2342 write!(f, "rope_backward(d={head_dim},n_rot={n_rot})")
2343 }
2344 Op::GroupNormBackwardInput { num_groups, eps } => {
2345 write!(f, "group_norm_backward_input(g={num_groups},eps={eps})")
2346 }
2347 Op::GroupNormBackwardGamma { num_groups, eps } => {
2348 write!(f, "group_norm_backward_gamma(g={num_groups},eps={eps})")
2349 }
2350 Op::GroupNormBackwardBeta { num_groups, eps } => {
2351 write!(f, "group_norm_backward_beta(g={num_groups},eps={eps})")
2352 }
2353 Op::BatchNormInferenceBackwardInput { eps } => {
2354 write!(f, "batch_norm_inference_backward_input(eps={eps})")
2355 }
2356 Op::BatchNormInferenceBackwardGamma { eps } => {
2357 write!(f, "batch_norm_inference_backward_gamma(eps={eps})")
2358 }
2359 Op::BatchNormInferenceBackwardBeta => {
2360 write!(f, "batch_norm_inference_backward_beta")
2361 }
2362 Op::CumsumBackward { axis, exclusive } => {
2363 write!(f, "cumsum_backward(axis={axis},exclusive={exclusive})")
2364 }
2365 Op::GatherBackward { axis } => write!(f, "gather_backward(axis={axis})"),
2366 Op::GaussianSplatRender {
2367 width,
2368 height,
2369 tile_size,
2370 radius_scale,
2371 alpha_cutoff,
2372 max_splat_steps,
2373 transmittance_threshold,
2374 max_list_entries,
2375 } => write!(
2376 f,
2377 "gaussian_splat_render({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2378 ),
2379 Op::GaussianSplatRenderBackward {
2380 width,
2381 height,
2382 loss_grad_clip,
2383 sh_band,
2384 ..
2385 } => write!(
2386 f,
2387 "gaussian_splat_render_bwd({width}x{height},clip={loss_grad_clip},sh={sh_band})"
2388 ),
2389 Op::GaussianSplatPrepare {
2390 width,
2391 height,
2392 tile_size,
2393 radius_scale,
2394 alpha_cutoff,
2395 max_splat_steps,
2396 transmittance_threshold,
2397 max_list_entries,
2398 ..
2399 } => write!(
2400 f,
2401 "gaussian_splat_prepare({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2402 ),
2403 Op::GaussianSplatRasterize {
2404 width,
2405 height,
2406 tile_size,
2407 alpha_cutoff,
2408 max_splat_steps,
2409 transmittance_threshold,
2410 max_list_entries,
2411 ..
2412 } => write!(
2413 f,
2414 "gaussian_splat_rasterize({width}x{height},tile={tile_size},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2415 ),
2416 Op::Custom {
2417 name,
2418 num_inputs,
2419 attrs,
2420 } => write!(f, "custom({name},in={num_inputs},attrs={}B)", attrs.len()),
2421 Op::CustomFn {
2422 num_inputs,
2423 vjp_body,
2424 jvp_body,
2425 ..
2426 } => {
2427 let v = if vjp_body.is_some() { ",vjp" } else { "" };
2428 let j = if jvp_body.is_some() { ",jvp" } else { "" };
2429 write!(f, "custom_fn(in={num_inputs}{v}{j})")
2430 }
2431 Op::Fft { inverse, norm } => {
2432 write!(f, "fft(inverse={inverse}, norm={norm:?})")
2433 }
2434 Op::FftButterflyStage { stage, n_fft } => {
2435 write!(f, "fft_butterfly_stage(stage={stage}, n_fft={n_fft})")
2436 }
2437 Op::LogMel => write!(f, "log_mel()"),
2438 Op::LogMelBackward => write!(f, "log_mel_backward()"),
2439 Op::WelchPeaks { k, n_segments } => {
2440 write!(f, "welch_peaks(k={k}, n_segments={n_segments})")
2441 }
2442 }
2443 }
2444}