rlx_ir/op.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Operation types — every tensor op in the RLX IR.
17//!
18//! Designed for pattern-matching fusion: ops are grouped by category so
19//! fusion passes can reason about them structurally.
20
21use crate::DType;
22
/// Element-wise unary functions usable as an activation.
///
/// Each variant maps one input element to one output element.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Activation {
    Gelu,
    GeluApprox,
    /// Used as the gate activation in SwiGLU.
    Silu,
    Relu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
    Rsqrt,
    Neg,
    Abs,
    /// Sine, `sin(x)`. Gradient: `dx = upstream · cos(x)`.
    Sin,
    /// Cosine, `cos(x)`. Gradient: `dx = -upstream · sin(x)`.
    Cos,
    /// Tangent, `tan(x)`.
    /// Gradient: `dx = upstream · sec²(x) = upstream · (1 + tan²(x))`.
    Tan,
    /// Arctangent, `atan(x)`. Gradient: `dx = upstream · (1 / (1 + x²))`.
    Atan,
    /// Rounds to the nearest integer (documented as half-to-even) while
    /// staying in f32. Forward is documented as `x.round()`.
    ///
    /// The backward pass is a straight-through estimator: the round is
    /// treated as the identity, so the upstream gradient is forwarded
    /// untouched. This makes `Round` a building block for hand-written
    /// quantization schemes — a Mul-by-recip-scale → Round → Clamp →
    /// Mul-by-scale chain is a hand-rolled FakeQuantize that the
    /// elementwise-region pass can fuse into a single kernel.
    ///
    /// NOTE(review): Rust's `f32::round` rounds ties *away from zero*,
    /// whereas half-to-even corresponds to `f32::round_ties_even`.
    /// Confirm which tie-breaking rule the backends actually implement.
    Round,
}
55
/// Scale-tracking strategy for `Op::FakeQuantize`. Determines how
/// the per-channel `s[c]` is computed each forward pass.
///
/// * `PerBatch` — recompute `s[c] = max(|x|) / q_max` from the
///   current data on every call. Simple, no extra inputs, but
///   noisy for activations (max-abs jumps batch-to-batch).
///
/// * `EMA { decay }` — keep a running `s[c]` in a state tensor
///   (passed as a second op input). On each call, blend the
///   current per-batch max-abs into the state via
///   `state' = decay·state + (1-decay)·max_abs`. Smooth scale
///   over training, makes activation-QAT actually trainable.
///   Typical `decay = 0.99`.
///
/// * `Fixed` — never recompute. The state tensor's value is
///   used as-is each call (set once at construction or by the
///   caller). Useful when scales are pre-calibrated.
///
/// # Equality and hashing
///
/// Equality is the derived, value-based comparison, so `decay` is
/// compared with IEEE-754 `==`. `Eq` and `Hash` are implemented by
/// hand because `f32` provides neither. A NaN `decay` compares
/// unequal to itself and therefore breaks the reflexivity `Eq`
/// promises; do not construct one.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Default)]
pub enum ScaleMode {
    #[default]
    PerBatch,
    EMA {
        decay: f32,
    },
    Fixed,
}

impl Eq for ScaleMode {}

impl std::hash::Hash for ScaleMode {
    /// Feeds a one-byte variant tag into `state`, followed by the bit
    /// pattern of `decay` for the `EMA` variant.
    ///
    /// Signed zero is canonicalised before hashing: `0.0 == -0.0`
    /// under the derived `PartialEq`, but the two values have
    /// different bit patterns. Hashing the raw bits would let two
    /// equal `ScaleMode`s produce different hashes, which violates
    /// the `Hash`/`Eq` contract and breaks `HashMap`/`HashSet` lookups.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            ScaleMode::PerBatch => state.write_u8(0),
            ScaleMode::EMA { decay } => {
                state.write_u8(1);
                // Every other pair of `==`-equal f32 values already
                // shares a bit pattern; only ±0.0 needs folding.
                let canonical = if *decay == 0.0 { 0.0_f32 } else { *decay };
                state.write_u32(canonical.to_bits());
            }
            ScaleMode::Fixed => state.write_u8(2),
        }
    }
}
97
/// Selects the straight-through estimator used by the backward pass of
/// `Op::FakeQuantize`.
///
/// The forward computation does not depend on this choice — it is
/// always the discrete `clamp(round(x/s)) * s`. Only the gradient
/// w.r.t. `x` during training changes.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub enum SteKind {
    /// `dx = upstream`. The original STE: the round is treated as the
    /// identity in the backward direction. Simplest option, fine for
    /// moderate bit widths (i4 / i8).
    #[default]
    Identity,
    /// `dx = upstream * (|x| ≤ q_max·s)`. The gradient is zeroed where
    /// the input fell outside the quantization range (i.e. the clamp
    /// activated), which stops the optimizer from pushing weights
    /// further into saturation.
    ClippedIdentity,
    /// `dx = upstream * (1 - tanh²(x/s))`. A smooth surrogate for the
    /// round step that slowly attenuates the gradient as `|x|`
    /// approaches `q_max·s`. Often best on tight bit widths (i2).
    Tanh,
    /// `dx = upstream * (1 - |x/(q_max·s)|).max(0)`. Piecewise-linear
    /// cousin of `Tanh`; cheaper to compute and nearly as effective.
    HardTanh,
}
128
/// Element-wise operations taking two operands.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum BinaryOp {
    // Arithmetic.
    Add,
    Sub,
    Mul,
    Div,
    // Element-wise extrema.
    Max,
    Min,
    // Exponentiation.
    Pow,
}
141
/// Element-wise comparisons; the result is a Bool tensor.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum CmpOp {
    // Equality tests.
    Eq,
    Ne,
    // Ordering tests.
    Lt,
    Le,
    Gt,
    Ge,
}
153
/// The flavour of attention mask the attention kernel applies.
///
/// Pattern borrowed from MAX's `nn/attention/mha_mask.mojo` (#20 in
/// PLAN.md): a single attention kernel branches on the mask kind so
/// callers are not forced to materialize a mask tensor. Two wins:
/// 1. **`None`** — single unpadded sequence: no mask load and no
///    per-key compare in the inner loop.
/// 2. **`Causal`** — autoregressive decode: the kernel derives the
///    upper-triangular fill from `(qi, ki)` directly, so no `seq²`
///    mask tensor ever exists.
///
/// `Custom` is the pre-existing path — mask values are read from the
/// 4th input.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum MaskKind {
    /// Unmasked: every position attends to every position.
    None,
    /// Causal (autoregressive): position `qi` attends only to `ki <= qi`.
    Causal,
    /// Sliding window of width `w` (the payload): position `qi`
    /// attends to `ki ∈ [qi - w, qi]`.
    SlidingWindow(usize),
    /// Mask values come from the input tensor (default; matches BERT
    /// padding-mask behavior). Tensor shape is `[batch, key_len]`,
    /// where `1.0` = valid and `<0.5` = ignored.
    Custom,
    /// Additive per-head, per-query bias tensor of shape
    /// `[batch, num_heads, query_len, key_len]`, added to the
    /// `QK^T · scale` scores before softmax. Lets DETR-style boxRPB
    /// and other learned position biases stay on the fast
    /// `Op::Attention` path rather than decomposing into
    /// matmul + add + softmax + matmul.
    Bias,
}
187
/// Names the forward input that an [`Op::AttentionBackward`] node
/// takes the gradient with respect to.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum AttentionBwdWrt {
    /// Differentiate w.r.t. the query input.
    Query,
    /// Differentiate w.r.t. the key input.
    Key,
    /// Differentiate w.r.t. the value input.
    Value,
}
196
/// The combining function applied when reducing along specified axes.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum ReduceOp {
    // Additive reductions.
    Sum,
    Mean,
    // Extrema.
    Max,
    Min,
    // Multiplicative reduction.
    Prod,
}
207
/// PLAN L4: payload-free discriminant, one per [`Op`] variant.
///
/// Consumed by [`Op::kind`] and the `Backend::supported_ops` trait
/// method, through which a backend declares the ops it can lower. The
/// `LegalizeForBackend` pass in `rlx-opt` checks the graph against that
/// set and fails the compile when an unsupported op is present, rather
/// than silently falling back.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum OpKind {
    // ── Graph inputs ────────────────────────────────────────────
    Input,
    Param,
    Constant,

    // ── Element-wise unary / quantization ───────────────────────
    Activation,
    Cast,
    StopGradient,
    Quantize,
    Dequantize,
    FakeQuantize,
    FakeQuantizeLSQ,
    FakeQuantizeLSQBackwardX,
    FakeQuantizeLSQBackwardScale,

    // ── Element-wise binary / comparison / select ───────────────
    Binary,
    Compare,
    Where,

    // ── Fused regions ───────────────────────────────────────────
    ElementwiseRegion,
    /// Fused sampling / geometry chain (FKL-style transform region).
    TransformRegion,
    /// The same element-wise chain applied over multiple batch planes
    /// (horizontal fusion).
    BatchElementwiseRegion,

    // ── Linear algebra ──────────────────────────────────────────
    MatMul,
    DotGeneral,
    DenseSolve,
    BatchedDenseSolve,

    // ── Normalization / resize ──────────────────────────────────
    LayerNorm,
    LayerNorm2d,
    GroupNorm,
    BatchNormInference,
    RmsNorm,
    ResizeNearest2x,

    // ── Attention ───────────────────────────────────────────────
    Attention,
    Rope,
    AxialRope2d,

    // ── Shape manipulation ──────────────────────────────────────
    Reshape,
    Transpose,
    Narrow,
    Concat,
    Expand,
    Gather,

    // ── Reduction and friends ───────────────────────────────────
    Reduce,
    Softmax,
    Cumsum,
    TopK,
    Sample,

    // ── Convolution / pooling ───────────────────────────────────
    Conv,
    Im2Col,
    ConvTranspose2d,
    Pool,

    // ── Backward and training-related kinds ─────────────────────
    ReluBackward,
    ActivationBackward,
    FakeQuantizeBackward,
    ComplexNormSq,
    ComplexNormSqBackward,
    Conjugate,
    MaxPool2dBackward,
    Conv2dBackwardInput,
    Conv2dBackwardWeight,
    SoftmaxCrossEntropyWithLogits,
    SoftmaxCrossEntropyBackward,
    AttentionBackward,
    LayerNormBackwardInput,
    LayerNormBackwardGamma,
    RmsNormBackwardInput,
    RmsNormBackwardGamma,
    RmsNormBackwardBeta,
    RopeBackward,
    GroupNormBackwardInput,
    GroupNormBackwardGamma,
    GroupNormBackwardBeta,
    BatchNormInferenceBackwardInput,
    BatchNormInferenceBackwardGamma,
    BatchNormInferenceBackwardBeta,
    CumsumBackward,
    GatherBackward,

    // ── Grouped / quantized / adapter matmuls ───────────────────
    GroupedMatMul,
    DequantGroupedMatMul,
    DequantMoEWeights,
    ScatterAdd,
    LoraMatMul,
    DequantMatMul,
    QMatMul,
    QConv2d,

    // ── Recurrences ─────────────────────────────────────────────
    SelectiveScan,
    GatedDeltaNet,

    // ── Fused kernels ───────────────────────────────────────────
    FusedSwiGLU,
    FusedMatMulBiasAct,
    FusedResidualLN,
    FusedResidualRmsNorm,
    FusedAttentionBlock,
    FusedTransformerLayer,

    // ── Control flow ────────────────────────────────────────────
    If,
    While,
    Scan,
    ScanBackward,
    ScanBackwardXs,

    // ── Gaussian splatting ──────────────────────────────────────
    /// CPU reference 3D Gaussian splat raster
    /// (project → bin → sort → raster). See [`Op::GaussianSplatRender`].
    GaussianSplatRender,
    /// Backward of [`Op::GaussianSplatRender`] — packed scene
    /// parameter gradients.
    GaussianSplatRenderBackward,
    /// Strict IR splat stage 1: project + tile bin + sort + ray grid.
    GaussianSplatPrepare,
    /// Strict IR splat stage 2: per-pixel raster from prepared buffers.
    GaussianSplatRasterize,

    // ── User extension points ───────────────────────────────────
    /// User-registered op dispatched through `op_registry`. Every
    /// custom op (Sparse-LU, FFT, eigensolve, ...) shares this one
    /// kind; what distinguishes them is `Op::Custom::name`.
    Custom,
    /// User-defined sub-graph with optional override AD rules. See
    /// [`Op::CustomFn`] / [`crate::Graph::custom_fn`].
    CustomFn,

    // ── Signal processing ───────────────────────────────────────
    /// 1D FFT primitive (forward or inverse) — see [`Op::Fft`].
    Fft,
    /// Ternary pruned radix-2 butterfly stage — see
    /// [`Op::FftButterflyStage`].
    FftButterflyStage,
    /// Whisper-style log-mel from a block-layout FFT spectrum — see
    /// [`Op::LogMel`].
    LogMel,
    /// Backward of [`Op::LogMel`] w.r.t. block-layout spectrum input 0.
    LogMelBackward,
}
336
/// Where a fused [`ChainStep`] reads one of its operands from.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum ChainOperand {
    /// A graph-level input to the [`Op::ElementwiseRegion`], addressed
    /// by index in `0..num_inputs`.
    Input(u32),
    /// The result of an earlier step in the same chain, addressed by
    /// index in `0..step_position`.
    Step(u32),
}
346
347/// One step in a fused element-wise chain. Each step produces exactly
348/// one scalar result (per element); later steps can refer to it via
349/// [`ChainOperand::Step`]. The whole chain runs per element in registers.
350#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
351#[derive(Debug, Clone, PartialEq)]
352pub enum ChainStep {
353 Activation(Activation, ChainOperand),
354 Cast(DType, ChainOperand),
355 Binary(BinaryOp, ChainOperand, ChainOperand),
356 Compare(CmpOp, ChainOperand, ChainOperand),
357 /// 3-input element-wise select: `cond ? on_true : on_false`. Mirrors
358 /// `Op::Where` inside a chain. `cond` is treated as truthy iff
359 /// non-zero. Lets the optimizer fold attention masks / clamp-style
360 /// patterns into a single region kernel instead of breaking the
361 /// chain at the first `Op::Where`.
362 Where(ChainOperand, ChainOperand, ChainOperand),
363}
364
/// Memory transform that is fused in front of an
/// [`Op::ElementwiseRegion`] chain.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub enum RegionPrologue {
    /// No prologue transform (the default).
    #[default]
    None,
    /// The input is half-resolution NCHW; the output shape is 2× H×W
    /// (nearest 2×).
    ResizeNearest2x,
}
374
375/// One sampling step in [`Op::TransformRegion`].
376#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
377#[derive(Debug, Clone, PartialEq)]
378pub enum TransformStep {
379 ResizeNearest2x(ChainOperand),
380}
381
382/// An operation in the RLX IR graph.
383///
384/// Operations are categorized for fusion analysis:
385/// - Element-wise ops fuse with anything reading their output
386/// - Matmul/Conv are BLAS-dispatched and form fusion boundaries
387/// - Reductions are fusion roots (drive the loop iteration)
388#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
389#[derive(Debug, Clone, PartialEq)]
390pub enum Op {
391 // ── Graph inputs ────────────────────────────────────────────
392 /// Model input with a name (shape on the Node).
393 Input {
394 name: String,
395 },
396
397 /// Model parameter (weight/bias) with a name.
398 Param {
399 name: String,
400 },
401
402 /// Constant tensor embedded in the graph.
403 Constant {
404 data: Vec<u8>,
405 },
406
407 // ── Element-wise unary ──────────────────────────────────────
408 /// Unary activation: one input, same shape output.
409 Activation(Activation),
410
411 /// Cast to a different dtype.
412 Cast {
413 to: DType,
414 },
415
416 /// Stop-gradient (a.k.a. `detach` / `lax.stop_gradient`). Forward is
417 /// identity; the reverse-mode autodiff rule returns **no** gradient
418 /// contribution for the input. Single input, same shape & dtype
419 /// output. Used to build a Gradient-Reverse-Layer with identity
420 /// forward semantics in user code (see maet-rs `dat_loss`).
421 StopGradient,
422
423 /// INT8 quantization. Input f32; output i8 same shape.
424 /// `q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])`
425 /// where `c` selects the per-channel scale/zp when `axis = Some(d)`
426 /// (`c = idx[d]`), or always uses index 0 when `axis = None`
427 /// (per-tensor). The `scales` / `zero_points` payload length must
428 /// match `1` for per-tensor and `input.dim(d)` for per-channel.
429 /// Static — typically produced at calibration time and baked
430 /// into the loaded model. Use `Op::Dequantize` for the inverse.
431 Quantize {
432 axis: Option<usize>,
433 scales: Vec<f32>,
434 zero_points: Vec<i32>,
435 },
436
437 /// INT8 dequantization (inverse of `Op::Quantize`). Input i8;
438 /// output f32 same shape.
439 /// `x[i] = (q[i] - zero_point[c]) · scale[c]`
440 /// where `c` is selected by `axis` exactly as in `Op::Quantize`.
441 Dequantize {
442 axis: Option<usize>,
443 scales: Vec<f32>,
444 zero_points: Vec<i32>,
445 },
446
447 /// "Fake-quantize" op for **quantization-aware training** (QAT).
448 /// Input f32; output f32 same shape. Forward computes a per-axis
449 /// (or per-tensor when `axis = None`) max-abs scale on the fly:
450 /// `s[c] = max(|x[..., c, ...]|) / q_max(bits)`
451 /// then quantizes-then-dequantizes:
452 /// `out[i] = clamp(round(x[i] / s[c]), -q_max, q_max) * s[c]`
453 /// where `q_max` is `127` for `bits=8`, `7` for `bits=4`, `1` for
454 /// `bits=2` (ternary). Symmetric only — zero-point is always 0.
455 ///
456 /// The point of this op is to make the SGD optimizer "see" the
457 /// deployment-time rounding during training. Backward is the
458 /// **straight-through estimator** (STE): the gradient passes
459 /// through (variant chosen by `ste`), ignoring the discontinuity
460 /// at the round. Without STE the rounding would have zero
461 /// gradient almost everywhere and learning would stop.
462 ///
463 /// Inserted by the trainer on conv / FC weight tensors when
464 /// `--qat` is on; the existing `Op::Quantize` / packing path at
465 /// the end of training still handles the deployment-side
466 /// conversion to `i8`/`i4`/`i2` codes.
467 FakeQuantize {
468 bits: u8,
469 axis: Option<usize>,
470 ste: SteKind,
471 scale_mode: ScaleMode,
472 },
473
474 /// Learned Step Size Quantization (LSQ; Esser et al. 2020,
475 /// `arXiv:1902.08153`). Like `FakeQuantize` but the per-channel
476 /// `scale` is a *learned parameter*, passed as the second input.
477 /// Forward is identical to `FakeQuantize` with a fixed scale:
478 /// `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
479 /// Backward computes both `dx` (STE) and `dscale[c]` via the
480 /// closed-form gradient:
481 /// `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
482 /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
483 /// `sign(z) · q_max`. Routinely beats per-batch and EMA at
484 /// tight bit widths (i2 / i3).
485 ///
486 /// Inputs: `[x, scale]`. `scale` is `[chan_dim]` f32 (matches
487 /// `axis`); for `axis = None` it's `[1]`.
488 FakeQuantizeLSQ {
489 bits: u8,
490 axis: Option<usize>,
491 },
492
    /// Backward pass for `Op::FakeQuantizeLSQ` w.r.t. `x` (STE).
    /// This node outputs only `dx`, whose shape matches `x`; the
    /// closed-form gradient w.r.t. `scale` is reduced separately
    /// (`LsqScaleGradient`; see the companion
    /// `Op::FakeQuantizeLSQBackwardScale`).
    /// Inputs: `[x, scale, dy]`. Output: `dx`, same shape as `x`.
498 FakeQuantizeLSQBackwardX {
499 bits: u8,
500 axis: Option<usize>,
501 },
502
503 /// Companion to `FakeQuantizeLSQBackwardX`: computes the
504 /// `[chan_dim]` per-channel scale gradient. Inputs `[x, scale, dy]`.
505 /// Output shape matches `scale`.
506 FakeQuantizeLSQBackwardScale {
507 bits: u8,
508 axis: Option<usize>,
509 },
510
511 // ── Element-wise binary ─────────────────────────────────────
512 /// Binary op with broadcasting: two inputs, output shape is broadcast result.
513 Binary(BinaryOp),
514
515 // ── Comparison ──────────────────────────────────────────────
516 /// Element-wise comparison: two inputs, Bool output.
517 Compare(CmpOp),
518
519 /// Select elements: cond (Bool), on_true, on_false → output.
520 Where,
521
522 /// Fused element-wise region (PLAN L2). Holds an N-step chain of
523 /// element-wise operations. Inputs are referenced by index 0..num_inputs;
524 /// each step's result can be referenced by later steps via
525 /// `ChainOperand::Step(idx)`. The output is the last step's result.
526 /// Emitted by `MarkElementwiseRegions` in `rlx-opt` from chains of
527 /// Activation/Cast/Binary/Compare/Where ops with single-consumer
528 /// intermediates and broadcast-compatible shapes. Backends that
529 /// don't have a region kernel can decompose back to the original
530 /// chain via `unfuse::unfuse_elementwise_regions`.
531 ///
532 /// `scalar_input_mask` is a per-input bitfield (bit `i` set ⇒
533 /// input `i` is a scalar broadcast — has shape `[1]`). Kept as a
534 /// fast-path indicator that lets kernels skip the modulo entirely
535 /// when they detect a scalar.
536 ///
537 /// `input_modulus[i]` is the per-input element count, used by
538 /// kernels to compute `arena[input_offs[i] + (gid % input_modulus[i])]`
539 /// — the trailing-shape broadcast pattern. `0` means "no broadcast"
540 /// (input matches the output element count; kernel reads `gid`
541 /// directly). `1` means scalar; any other value means the input
542 /// has fewer elements than the output and they tile by modulo.
543 /// The encoder only allows broadcasts where `out_elems % in_elems
544 /// == 0` so the modulo divides cleanly. Lets chains include bias /
545 /// scale / eps / mask factors that previously broke the chain at
546 /// a Binary op with mismatched shapes.
547 ElementwiseRegion {
548 chain: Vec<ChainStep>,
549 num_inputs: u32,
550 scalar_input_mask: u32,
551 input_modulus: [u32; 16],
552 /// FKL-style closed fusion: apply before the element-wise chain.
553 prologue: RegionPrologue,
554 /// External input index that supplies the prologue transform source (default 0).
555 prologue_input: u32,
556 },
557
558 /// Fused transform chain (resize, future crop/color). Decompose via
559 /// [`rlx_fusion::DecomposeFusionRegions`] when no native kernel exists.
560 TransformRegion {
561 steps: Vec<TransformStep>,
562 num_inputs: u32,
563 },
564
565 /// Identical [`Op::ElementwiseRegion`] chain over `num_batch_inputs` tensors
566 /// (horizontal / z-plane fusion). Inputs are separate batch slices.
567 BatchElementwiseRegion {
568 chain: Vec<ChainStep>,
569 num_batch_inputs: u32,
570 scalar_input_mask: u32,
571 input_modulus: [u32; 16],
572 prologue: RegionPrologue,
573 prologue_input: u32,
574 },
575
576 // ── Linear algebra ──────────────────────────────────────────
577 /// Matrix multiply. Inputs: [.., M, K] × [.., K, N] → [.., M, N].
578 /// Batch dimensions are broadcast.
579 MatMul,
580
581 /// Matrix multiply with explicit dimension specification.
582 /// Like XLA's DotGeneral — handles arbitrary batch/contracting dims.
583 DotGeneral {
584 lhs_contracting: Vec<usize>,
585 rhs_contracting: Vec<usize>,
586 lhs_batch: Vec<usize>,
587 rhs_batch: Vec<usize>,
588 },
589
590 /// Batched dense linear solve. Inputs: `A [B, N, N]`,
591 /// `b [B, N]` or `b [B, N, K]`. Output: same shape as `b`.
592 ///
593 /// Per-batch independent solve — each `A[i]` and `b[i]` are
594 /// solved as a separate `Op::DenseSolve`. Emitted by vmap of
595 /// `Op::DenseSolve`. The CPU lowering loops over the batch
596 /// dimension calling `dgesv` per slice (LAPACK doesn't expose a
597 /// batched solve on Accelerate; cuSOLVER does on NVIDIA).
598 BatchedDenseSolve,
599
600 /// Dense linear solve `x = A⁻¹ · b` via LU factorization.
601 /// Inputs: `A [N, N]`, `b [N]` (or `b [N, K]` for multi-RHS).
602 /// Output: same shape as `b`.
603 ///
604 /// VJP via the implicit-function theorem:
605 /// `dx = solve(Aᵀ, upstream)`
606 /// `dA = -outer(dx, x)` (x is the forward output)
607 /// `db = dx`
608 /// The rule is dtype-agnostic; lowering is per-backend (Accelerate
609 /// `dgesv` / `sgesv`, cuSOLVER, etc.).
610 DenseSolve,
611
612 // ── Normalization ───────────────────────────────────────────
613 /// Layer normalization: input, gamma, beta → normalized output.
614 /// `axis` is the feature dimension (usually -1).
615 LayerNorm {
616 axis: i32,
617 eps: f32,
618 },
619
620 /// Group normalization on NCHW tensors: `input`, `gamma`, `beta` → same shape.
621 /// Normalizes over `(C/num_groups) × H × W` per group.
622 GroupNorm {
623 num_groups: usize,
624 eps: f32,
625 },
626
627 /// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
628 /// position (candle / SAM `LayerNorm2d` semantics — not PyTorch's H×W norm).
629 LayerNorm2d {
630 eps: f32,
631 },
632
633 /// Nearest-neighbor 2× upsample on NCHW (doubles spatial dims 2 and 3).
634 ResizeNearest2x,
635
636 /// RMS normalization: input, gamma → normalized output.
637 RmsNorm {
638 axis: i32,
639 eps: f32,
640 },
641
642 /// BatchNorm inference with frozen running statistics.
643 /// Inputs: `x`, `gamma`, `beta`, `running_mean`, `running_var`.
644 /// Feature dimension is the last axis of `x`; stats are 1-D `[C]`.
645 BatchNormInference {
646 eps: f32,
647 },
648
649 // ── Attention ───────────────────────────────────────────────
650 /// Scaled dot-product attention: Q, K, V, \[mask\] → output.
651 /// The compiler can lower this to fused SDPA or flash attention.
652 /// `mask_kind` controls how masking is applied — `Custom` reads from
653 /// the 4th input tensor; `None` / `Causal` / `SlidingWindow` skip the
654 /// mask load and apply the mask directly in the inner loop. See
655 /// `MaskKind` for the rationale.
656 ///
657 /// `score_scale`: when `Some(s)`, dot-product scores are multiplied by
658 /// `s` instead of the default `1/sqrt(head_dim)` (Gemma uses `head_dim^-0.5`
659 /// explicitly in config). `attn_logit_softcap`: when `Some(c)`, applies
660 /// `c * tanh(s/c)` to scores before softmax (Gemma 2).
661 Attention {
662 num_heads: usize,
663 head_dim: usize,
664 mask_kind: MaskKind,
665 score_scale: Option<f32>,
666 attn_logit_softcap: Option<f32>,
667 },
668
669 /// Rotary position embedding applied to one tensor: x, cos, sin → x_rotated.
670 /// Apply separately to Q and K. `head_dim` is the per-head width; `n_rot`
671 /// is how many leading dims get NeoX RoPE (pair offset `n_rot/2`). When
672 /// `n_rot < head_dim`, trailing dims are copied unchanged (Qwen3.5 MRoPE).
673 Rope {
674 head_dim: usize,
675 n_rot: usize,
676 },
677
678 /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
679 AxialRope2d {
680 end_x: usize,
681 end_y: usize,
682 head_dim: usize,
683 num_heads: usize,
684 theta: f32,
685 repeat_factor: usize,
686 },
687
688 // ── Shape manipulation ──────────────────────────────────────
689 Reshape {
690 new_shape: Vec<i64>,
691 },
692 Transpose {
693 perm: Vec<usize>,
694 },
695 /// Select a contiguous slice along an axis.
696 Narrow {
697 axis: usize,
698 start: usize,
699 len: usize,
700 },
701 /// Concatenate along an axis.
702 Concat {
703 axis: usize,
704 },
705 /// Expand (broadcast) to a target shape.
706 Expand {
707 target_shape: Vec<i64>,
708 },
709 /// Gather elements by index along an axis (embedding lookup).
710 Gather {
711 axis: usize,
712 },
713
714 // ── Reduction ───────────────────────────────────────────────
715 /// Reduce along specified axes.
716 Reduce {
717 op: ReduceOp,
718 axes: Vec<usize>,
719 keep_dim: bool,
720 },
721
722 /// Selective scan (plan #15) — Mamba-style state-space model
723 /// step. The recurrence:
724 /// `h[t] = exp(Δ[t] * A) * h[t-1] + Δ[t] * B[t] * x[t]`
725 /// `y[t] = C[t] * h[t]`
726 /// where state `h` has dimension `state_size` and the input has
727 /// `(batch, seq, hidden)`.
728 ///
729 /// Inputs (in order):
730 /// `x [b, s, h]` f32 input
731 /// `delta [b, s, h]` f32 step size (per-position, per-channel)
732 /// `a [h, n]` f32 transition matrix (one per channel)
733 /// `b [b, s, n]` f32 input projection
734 /// `c [b, s, n]` f32 output projection
735 /// Output: `[b, s, h]` f32. State `h` is implicit; the kernel
736 /// scans through the seq dimension carrying it.
737 ///
738 /// `state_size` = `n` is exposed for the cost model.
739 SelectiveScan {
740 state_size: usize,
741 },
742
743 /// Gated DeltaNet linear-attention recurrence — the per-layer
744 /// kernel used by Qwen3.5/3.6 trunk "linear attention" blocks
745 /// (and Qwen3-Next, Kimi-Linear). Mirrors
746 /// `llama.cpp / src/models/delta-net-base.cpp` autoregressive
747 /// path; chunked + fused variants ride the same op identity.
748 ///
749 /// **Math (per token `t`, head `h`, state size `n`):**
750 /// state matrix `S[h, i, j]` is implicit (reset per batch).
751 /// ```text
752 /// S[h] *= exp(g[t,h]) # scalar gate
753 /// sk[h,j] = Σ_i S[h,i,j] * k[t,h,i]
754 /// d[h,j] = (v[t,h,j] - sk[h,j]) * b[t,h] # b = beta
755 /// S[h,i,j] += k[t,h,i] * d[h,j] # outer-prod
756 /// o[t,h,j] = Σ_i S[h,i,j] * (q[t,h,i] / √n)
757 /// ```
758 ///
759 /// Inputs:
760 /// `q [b, s, h_v, n]` f32 queries (L2-normed by caller)
761 /// `k [b, s, h_v, n]` f32 keys (L2-normed by caller;
762 /// GQA-repeated to match `h_v`)
763 /// `v [b, s, h_v, n]` f32 values
764 /// `g [b, s, h_v]` f32 log-gate (exp'd inside kernel)
765 /// `beta [b, s, h_v]` f32 delta-rule mixing factor
766 ///
767 /// Output: `[b, s, h_v, n]` f32.
768 ///
769 /// When `carry_state` is true, a sixth input `state [b, h_v, n, n]`
770 /// provides the initial SSM matrix per head; the kernel updates it
771 /// in place across the sequence and leaves the final state in the
772 /// same buffer (same layout as the internal scan state:
773 /// `state[h, i, j]` row-major over `(n, n)` per head).
774 GatedDeltaNet {
775 state_size: usize,
776 carry_state: bool,
777 },
778
779 /// Fused dequant + matmul (plan #5). The biggest LLM-bandwidth
780 /// win on Apple Silicon: dequantizes weights inside the matmul
781 /// inner loop, never materializing f32 weights.
782 ///
783 /// **BREAKING CHANGE in 0.2.0:** `num_inputs()` is now
784 /// scheme-dependent — **4** for legacy Int8 schemes, **2** for
785 /// the new GGUF K-quant schemes (their scales/mins live inside
786 /// the packed bytes, so no side-channel `scale` / `zp` tensors
787 /// are fed in). Callers that assumed a fixed 4-input contract
788 /// must inspect `scheme.is_gguf()` before reading inputs.
789 ///
790 /// Inputs (Int8 schemes — `scheme.is_gguf() == false`):
791 /// `x [m, k]` f32 activations
792 /// `w_q [k, n]` packed quantized weight bytes (i8 per
793 /// element for Int8 schemes; 4-bit
794 /// packed two-per-byte for Int4)
795 /// `scale [k/block, n]` per-block f32 dequant scale
796 /// `zp [k/block, n]` per-block f32 zero-point
797 /// (zero-tensor if symmetric)
798 ///
799 /// Inputs (`Nvfp4Block` — fixed group size 16 along K):
800 /// `x [m, k]` f32 activations
801 /// `w_q [k,n/2]` packed FP4 E2M1 codes (unsigned nibble 0..15)
802 /// `scale [k/16, n]` u8 FP8 E4M3 block scales (one byte / group)
803 /// `global_scale [1]` f32 per-tensor scale (pass `[1.0]` if unused)
804 ///
805 /// Inputs (GGUF schemes — `scheme.is_gguf() == true`):
806 /// `x [m, k]` f32 activations
807 /// `packed_w [bytes]` raw GGUF super-block bytes; the
808 /// dequantizer reads the per-sub-block
809 /// scales / mins / quants directly out
810 /// of the buffer per the K-quant block
811 /// layout (no side tensors).
812 ///
813 /// Output: `[m, n]` f32.
814 ///
815 /// `block_size` (on the Int8 schemes only) is the number of
816 /// consecutive elements that share one (scale, zero_point) pair.
817 /// The Op carries enough metadata that the kernel doesn't need
818 /// a separate `QuantMap` lookup at run time.
819 DequantMatMul {
820 scheme: crate::quant::QuantScheme,
821 },
822
823 /// Real INT8-arithmetic matrix multiply with i32 accumulation.
824 /// Inputs (in order):
825 /// `x [M, K]` i8 activations (zero-point = `x_zp`)
826 /// `w [K, N]` i8 weights (zero-point = `w_zp`)
827 /// `bias [N]` i32 (in accumulator scale = `x_scale·w_scale`),
828 /// pass a zeros tensor for "no bias"
829 /// Output: `[M, N]` i8 (zero-point = `out_zp`)
830 ///
831 /// Per-element compute:
832 /// `out[m,n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
833 /// where `mult = x_scale · w_scale / out_scale`.
834 ///
835 /// This is the same kernel shape `rlx-cortexm/src/dense.rs`
836 /// uses for on-device int8 inference, lifted into the IR so the
837 /// rlx-cpu backend can run a quantized graph directly (instead
838 /// of round-tripping through fake-quant Dequantize → MatMul →
839 /// Quantize). 2-D only — generalizing to batched comes when a
840 /// real workload demands it.
841 QMatMul {
842 x_zp: i32,
843 w_zp: i32,
844 out_zp: i32,
845 mult: f32,
846 },
847
848 /// Real INT8-arithmetic 2-D convolution with i32 accumulation.
849 /// Inputs:
850 /// `x [N, C_in, H, W]` i8 (zero-point = `x_zp`)
851 /// `w [C_out, C_in/groups, kH, kW]` i8 (zero-point = `w_zp`)
852 /// `bias [C_out]` i32 in accumulator scale
853 /// Output: `[N, C_out, H_out, W_out]` i8 (zero-point = `out_zp`).
854 /// Same NCHW geometry contract as `Op::Conv`; same requantize
855 /// math as `Op::QMatMul` (per-element `acc·mult` rounded to i8).
856 QConv2d {
857 kernel_size: Vec<usize>,
858 stride: Vec<usize>,
859 padding: Vec<usize>,
860 dilation: Vec<usize>,
861 groups: usize,
862 x_zp: i32,
863 w_zp: i32,
864 out_zp: i32,
865 mult: f32,
866 },
867
868 /// Fused LoRA matmul: `out = x·W + scale * x·A·B`.
869 /// Inputs (in order): `x [m, k]`, `w [k, n]`, `a [k, r]`, `b [r, n]`.
870 /// `r` is the LoRA rank (typically 4-64). `scale` is the
871 /// per-adapter `alpha / rank` knob.
872 /// Plan #9: lifts LoRA from "three matmuls + an add" into one
873 /// kernel that keeps the rank-r intermediate in registers.
874 LoraMatMul {
875 scale: f32,
876 },
877
878 /// Fused sampling kernel: logits → optional top-k filter →
879 /// optional top-p truncation → softmax → multinomial sample.
880 /// One f32-encoded sampled token id per batch row (output
881 /// shape `[batch]`).
882 ///
883 /// `temperature == 1.0` is neutral (samples from the unscaled softmax);
884 /// lower → sharper, higher → flatter. `top_k == 0` disables.
885 /// `top_p == 1.0` disables. `seed` is the Philox seed; pass 0
886 /// for "use process-global counter" (still deterministic given the call
887 /// order). NOTE(review): the `seed` field comment says thread-local — confirm.
888 /// Borrowed from MAX's nn/sampling.mojo (#42 in PLAN.md).
889 /// Latency-critical: never materializes the full softmax
890 /// distribution on the host.
891 Sample {
892 top_k: usize, // 0 = disabled
893 top_p: f32, // 1.0 = disabled
894 temperature: f32, // 1.0 = neutral
895 seed: u64, // 0 = use thread-local counter
896 },
897
898 /// Inclusive cumulative sum along an axis. Same shape in/out.
899 /// Underpins ragged-tensor offsets, sampling (top-p prefix sum),
900 /// and sequence-position math (#44 in PLAN.md).
901 /// `exclusive=true` shifts the result so output\[0\] = 0 (useful
902 /// for offset arrays where the first segment starts at 0).
903 Cumsum {
904 axis: i32,
905 exclusive: bool,
906 },
907
908 /// Softmax along an axis (reduction + element-wise).
909 Softmax {
910 axis: i32,
911 },
912
913 /// Top-K **indices** along the last axis. Output shape `[..., k]`,
914 /// f32-encoded indices (rlx is f32-only at the I/O boundary).
915 /// To recover the values, follow with a `Gather` against the
916 /// original tensor — works because Gather already supports any axis.
917 /// Ties broken by smaller index (matches NumPy / PyTorch
918 /// `torch.topk(..., largest=True, sorted=True)`).
919 /// Used by MoE gating; also useful for beam search.
920 TopK {
921 k: usize,
922 },
923
924 /// Indexed batched matmul. The MoE GEMM primitive.
925 /// Inputs: `[input, weight, expert_idx]`
926 /// input : [M, K] — per-token activations
927 /// weight : [num_experts, K, N] — stacked expert weights
928 /// expert_idx : \[M\] — f32-encoded expert id per token
929 /// Output : [M, N] — output\[i\] = input\[i\] @ weight[expert_idx\[i\]]
930 /// Naive impl on both backends; future work can replace with a
931 /// segmented/grouped GEMM when there's a real workload.
932 GroupedMatMul,
933
934 /// Fused GGUF K-quant dequant + [`Op::GroupedMatMul`]. Same three
935 /// inputs as `GroupedMatMul`, but `weight` is a U8 tensor holding
936 /// `num_experts` contiguous packed expert slabs (GGML layout, expert
937 /// dimension outermost). Scales live inside the packed bytes.
938 DequantGroupedMatMul {
939 scheme: crate::quant::QuantScheme,
940 },
941
942 /// Dequant a packed MoE expert stack to F32 `[num_experts, K, N]` in
943 /// GroupedMatMul layout. Input: U8 packed bytes; output shape is
944 /// declared on the node (`[E, K, N]`).
945 DequantMoEWeights {
946 scheme: crate::quant::QuantScheme,
947 },
948
949 /// Scatter-add into a destination tensor. The "unpermute" half of
950 /// MoE routing (also useful for embedding gradient updates).
951 /// Inputs: `[updates, indices]`
952 /// updates : [num_updates, trailing] — values to add
953 /// indices : \[num_updates\] — f32-encoded destination row
954 /// Output : [out_dim, trailing] — output[indices\[i\]] += updates\[i\]
955 /// `out_dim` is taken from the node's declared output shape.
956 /// Initial output is zero; multiple updates to the same row
957 /// accumulate (sequentially on CPU; with atomic-add on Metal).
958 ScatterAdd,
959
960 // ── Convolution ─────────────────────────────────────────────
961 /// 2D convolution on NCHW tensors. Also exposed as [`OpKind::Conv`] / `conv2d`.
962 /// Weight layout: `[C_out, C_in / groups, kH, kW]`.
963 Conv {
964 kernel_size: Vec<usize>,
965 stride: Vec<usize>,
966 padding: Vec<usize>,
967 dilation: Vec<usize>,
968 groups: usize,
969 },
970
971 /// NCHW im2col for conv backward-weight style matmul.
972 /// Input `[N, C, H, W]`. Output `[M, C·kH·kW]` with
973 /// `M = N · H_out · W_out`. When batch is [`dynamic::sym::BATCH`],
974 /// output rows use [`dynamic::sym::ROWS`] (bind `N · H_out · W_out`).
975 Im2Col {
976 kernel_size: Vec<usize>,
977 stride: Vec<usize>,
978 padding: Vec<usize>,
979 dilation: Vec<usize>,
980 },
981
982 /// 2D transposed convolution on NCHW. Weight layout (PyTorch):
983 /// `[C_in, C_out / groups, kH, kW]`.
984 ConvTranspose2d {
985 kernel_size: Vec<usize>,
986 stride: Vec<usize>,
987 padding: Vec<usize>,
988 dilation: Vec<usize>,
989 output_padding: Vec<usize>,
990 groups: usize,
991 },
992
993 // ── Pooling ─────────────────────────────────────────────────
994 Pool {
995 kind: ReduceOp,
996 kernel_size: Vec<usize>,
997 stride: Vec<usize>,
998 padding: Vec<usize>,
999 },
1000
1001 // ── Backward / training ops ─────────────────────────────────
1002 //
1003 // Closed-form gradient nodes emitted by `rlx-opt::autodiff`.
1004 // Pairing a forward op with a dedicated backward op (instead of
1005 // composing it from primitives) keeps the gradient kernel simple
1006 // and lets the backend recompute argmax / masks / softmax inline.
1007 /// ReLU backward: `dx = dy where x > 0 else 0`.
1008 /// Inputs: `[x, dy]` — both same shape; output matches.
1009 ReluBackward,
1010
1011 /// Element-wise complex squared-magnitude: `|z|² = z.re² + z.im²`.
1012 /// Input: 1 tensor with `DType::C64`. Output: same shape but
1013 /// `DType::F32`. The natural real-valued loss surface for
1014 /// Wirtinger reverse-mode AD on complex graphs — pair with
1015 /// [`Op::ComplexNormSqBackward`].
1016 ComplexNormSq,
1017
1018 /// Element-wise complex conjugate: `z̄ = z.re - i·z.im` per element.
1019 /// Input: 1 tensor with `DType::C64`. Output: same shape, same dtype.
1020 /// Used by Wirtinger VJP rules on `Op::Binary` over C64 (the rule
1021 /// for `y = a·b` is `dL/dā = upstream · conj(b)`, etc.).
1022 Conjugate,
1023
1024 /// Backward for [`Op::ComplexNormSq`] under Wirtinger calculus.
1025 /// `f(z) = |z|² = z·z̄`, so `∂f/∂z̄ = z`. Given upstream real
1026 /// cotangent `g` (same shape as the forward output), the C64
1027 /// gradient with respect to `z` is `g · z` element-wise, returned
1028 /// in C64 storage `[g·z.re, g·z.im]` per element.
1029 ///
1030 /// Inputs: `[z (C64), g (F32)]` — both same logical shape; output
1031 /// matches `z` (C64).
1032 ComplexNormSqBackward,
1033
1034 /// LayerNorm backward w.r.t. the input. Computes
1035 /// `d_x[..., d] = inv_std · (dy·γ - mean(dy·γ) - x̂·mean(dy·γ·x̂))`
1036 /// over the feature axis, where `x̂ = (x - mean)/std` is recomputed
1037 /// inline from `x`. Inputs: `[x, gamma, dy]`; output shape = `x.shape`.
1038 /// Currently lowers axis=-1 only (matches the forward thunk).
1039 LayerNormBackwardInput {
1040 axis: i32,
1041 eps: f32,
1042 },
1043
1044 /// LayerNorm backward w.r.t. gamma. Computes
1045 /// `d_gamma[d] = Σ_{batch} dy[..., d] · x̂[..., d]`
1046 /// — sums the per-element product of upstream and the (recomputed)
1047 /// normalized input over the leading axes. Inputs: `[x, dy]`;
1048 /// output shape = `gamma.shape` (= 1-D feature axis).
1049 LayerNormBackwardGamma {
1050 axis: i32,
1051 eps: f32,
1052 },
1053
1054 /// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`; output = `x.shape`.
1055 RmsNormBackwardInput {
1056 axis: i32,
1057 eps: f32,
1058 },
1059
1060 /// RMSNorm backward w.r.t. gamma. Inputs `[x, gamma, beta, dy]`; output = `gamma.shape`.
1061 RmsNormBackwardGamma {
1062 axis: i32,
1063 eps: f32,
1064 },
1065
1066 /// RMSNorm backward w.r.t. beta. Inputs `[x, gamma, beta, dy]`; output = `beta.shape`.
1067 RmsNormBackwardBeta {
1068 axis: i32,
1069 eps: f32,
1070 },
1071
1072 /// RoPE backward w.r.t. `x`. Inputs `[dy, cos, sin]`; output = `dy.shape`.
1073 RopeBackward {
1074 head_dim: usize,
1075 n_rot: usize,
1076 },
1077
1078 /// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
1079 GroupNormBackwardInput {
1080 num_groups: usize,
1081 eps: f32,
1082 },
1083
1084 /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`; output = `gamma.shape`.
1085 GroupNormBackwardGamma {
1086 num_groups: usize,
1087 eps: f32,
1088 },
1089
1090 /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`; output = `beta.shape`.
1091 GroupNormBackwardBeta {
1092 num_groups: usize,
1093 eps: f32,
1094 },
1095
1096 /// BatchNorm inference backward w.r.t. `x`. Inputs `[x, gamma, mean, var, dy]`.
1097 BatchNormInferenceBackwardInput {
1098 eps: f32,
1099 },
1100
1101 /// BatchNorm inference backward w.r.t. `gamma`. Inputs `[x, mean, var, dy]`.
1102 BatchNormInferenceBackwardGamma {
1103 eps: f32,
1104 },
1105
1106 /// BatchNorm inference backward w.r.t. `beta`. Inputs `[dy]`; output = `beta.shape`.
1107 BatchNormInferenceBackwardBeta,
1108
1109 /// Cumsum backward along `axis`. Inputs `[dy]`; output matches forward input shape.
1110 CumsumBackward {
1111 axis: i32,
1112 exclusive: bool,
1113 },
1114
1115 /// Gather backward (scatter-add into table). Inputs `[dy, indices]`; output = table shape.
1116 /// `axis` matches forward [`Op::Gather`].
1117 GatherBackward {
1118 axis: i32,
1119 },
1120
1121 /// Generic element-wise activation backward. `kind` selects the
1122 /// closed-form derivative `d/dx act(x)`. Inputs: `[x, dy]`; output
1123 /// shape matches `x`. The kernel computes `(d/dx act(x)) · dy` per element.
1124 ///
1125 /// Closed forms (all element-wise):
1126 /// * `Gelu` — exact derivative of erf-based GELU.
1127 /// * `GeluApprox` — derivative of the tanh approximation
1128 /// `0.5 x (1 + tanh(√(2/π)(x + 0.044715 x³)))`.
1129 /// * `Silu` — `σ(x)·(1 + x·(1 - σ(x)))`.
1130 /// * `Sigmoid` — `σ(x)·(1 - σ(x))`.
1131 /// * `Tanh` — `1 - tanh(x)²`.
1132 /// * `Exp` — `exp(x)`.
1133 /// * `Log` — `1 / x`.
1134 /// * `Sqrt` — `0.5 / sqrt(x)`.
1135 /// * `Rsqrt` — `-0.5 · x^(-3/2)`.
1136 /// * `Neg` — `-1`.
1137 /// * `Abs` — `sign(x)` (zero at x=0).
1138 /// * `Sin` — `cos(x)`.
1139 /// * `Cos` — `-sin(x)`.
1140 /// * `Tan` — `1 + tan²(x)`.
1141 /// * `Atan` — `1 / (1 + x²)`.
1142 /// * `Relu` — kept here for completeness; the dedicated
1143 /// `ReluBackward` op is preferred for relu and is what the
1144 /// autodiff pass actually emits.
1145 ActivationBackward {
1146 kind: Activation,
1147 },
1148
1149 /// Backward for `Op::FakeQuantize` under a non-default STE.
1150 /// Inputs `[x, dy]`: the forward input and the upstream
1151 /// gradient. Output `dx` same shape. The `bits`/`axis`/`ste`
1152 /// fields must match the forward op so the kernel computes the
1153 /// same per-channel scale and applies the right STE attenuation.
1154 /// For `SteKind::Identity` this op is unnecessary — autodiff
1155 /// just routes `upstream` through unchanged.
1156 FakeQuantizeBackward {
1157 bits: u8,
1158 axis: Option<usize>,
1159 ste: SteKind,
1160 },
1161
1162 /// 2D max-pool backward. Routes each element of `dy` back into the
1163 /// position in `x`'s window where the forward max was taken.
1164 /// Inputs: `[x, dy]` with `x [N, C, H, W]` and
1165 /// `dy [N, C, H_out, W_out]`. Output: same shape as `x`.
1166 /// Carries the forward pool's geometry so the kernel can recompute
1167 /// the argmax position per window without a saved-indices tensor.
1168 MaxPool2dBackward {
1169 kernel_size: Vec<usize>,
1170 stride: Vec<usize>,
1171 padding: Vec<usize>,
1172 },
1173
1174 /// 2D conv backward w.r.t. input. Computes `dx = conv_transpose(dy, w)`.
1175 /// Inputs: `[dy, w]` with `dy [N, C_out, H_out, W_out]` and
1176 /// `w [C_out, C_in/groups, kH, kW]`. Output: `[N, C_in, H, W]`
1177 /// (declared on the node — caller knows the original input shape).
1178 /// Geometry is the forward conv's parameters, not the transposed
1179 /// conv's.
1180 Conv2dBackwardInput {
1181 kernel_size: Vec<usize>,
1182 stride: Vec<usize>,
1183 padding: Vec<usize>,
1184 dilation: Vec<usize>,
1185 groups: usize,
1186 },
1187
1188 /// 2D conv backward w.r.t. weight. Computes
1189 /// `dw[c_out, c_in, kh, kw] = sum_{n,h_out,w_out} x[n,c_in,...] * dy[n,c_out,h_out,w_out]`.
1190 /// Inputs: `[x, dy]`. Output: `[C_out, C_in/groups, kH, kW]`.
1191 Conv2dBackwardWeight {
1192 kernel_size: Vec<usize>,
1193 stride: Vec<usize>,
1194 padding: Vec<usize>,
1195 dilation: Vec<usize>,
1196 groups: usize,
1197 },
1198
1199 /// Fused softmax + cross-entropy loss with integer (f32-encoded)
1200 /// targets — the standard classification loss. Per-row output:
1201 /// `loss[n] = -log(softmax(logits[n])[labels[n]])`.
1202 /// Inputs: `[logits, labels]` with `logits [N, C]` and
1203 /// `labels [N]` (f32-encoded class indices). Output: `[N]`.
1204 /// Caller does the `Reduce::Mean` if they want a scalar.
1205 SoftmaxCrossEntropyWithLogits,
1206
1207 /// Backward of the fused loss above. Emits
1208 /// `dlogits[n,c] = (softmax(logits[n])[c] - one_hot(labels)[n,c]) * d_loss[n]`.
1209 /// Inputs: `[logits, labels, d_loss]`. Output: `[N, C]` (same shape
1210 /// as `logits`). Recomputes the softmax inline rather than threading
1211 /// it through from the forward node.
1212 SoftmaxCrossEntropyBackward,
1213
1214 /// Backward of [`Op::Attention`]. Recomputes scaled `QK^T`, applies
1215 /// the same `mask_kind` as the forward op, softmaxes scores, then
1216 /// emits **one** of `dQ`, `dK`, or `dV` selected by [`AttentionBwdWrt`].
1217 /// Autodiff emits three nodes (one per `wrt`) so each output shape
1218 /// stays a normal single-output MIR node.
1219 ///
1220 /// Inputs: `[q, k, v, dy]` plus optional mask when `mask_kind` is
1221 /// [`MaskKind::Custom`] or [`MaskKind::Bias`] (same convention as
1222 /// forward). Output shape matches `q`, `k`, or `v` respectively.
1223 AttentionBackward {
1224 num_heads: usize,
1225 head_dim: usize,
1226 mask_kind: MaskKind,
1227 wrt: AttentionBwdWrt,
1228 },
1229
1230 // ── Fused operations (created by optimization passes) ──────
1231 /// Fused matmul + bias + activation. Created from MatMul → Add → Activation.
1232 FusedMatMulBiasAct {
1233 activation: Option<Activation>,
1234 },
1235
1236 /// Fused residual + optional bias + layer norm.
1237 /// Created from Add(x, residual) → [Add(bias)] → LayerNorm.
1238 FusedResidualLN {
1239 has_bias: bool,
1240 eps: f32,
1241 },
1242
1243 /// Fused residual + optional bias + RMS norm.
1244 /// Created from Add(x, residual) → [Add(bias)] → RmsNorm.
1245 FusedResidualRmsNorm {
1246 has_bias: bool,
1247 eps: f32,
1248 },
1249
1250 /// Fused SwiGLU: split input into up/gate halves, silu(gate) * up.
1251 /// Created from Split → Silu → Mul when fed by a concatenated matmul.
1252 ///
1253 /// `cast_to`: optional output dtype — when `Some(dt)` the kernel casts
1254 /// its result from the input dtype to `dt` in-register, saving a
1255 /// separate cast pass. Reserved for future fp8/fp4 quantization paths;
1256 /// for f32→f16 mixed precision the AutoMixedPrecision pass already
1257 /// inserts a Cast node so this stays `None` in current pipelines.
1258 FusedSwiGLU {
1259 cast_to: Option<DType>,
1260 /// When `true`, the concatenated input stores gate in the low half
1261 /// `[..., 0..N)` and up in the high half `[..., N..2N)` — the layout
1262 /// produced when gate projection is emitted before up in the builder.
1263 /// Default `false`: up @ low, gate @ high (canonical concat order).
1264 gate_first: bool,
1265 },
1266
1267 /// Fused full transformer layer: attention block + residual+LN + FFN + residual+LN.
1268 /// All intermediates resident in registers/threadgroup memory; one kernel
1269 /// per layer instead of ~30 (the CPU's batch=1 win, lifted to IR so any
1270 /// backend can implement it as a monolithic kernel).
1271 ///
1272 /// Inputs: hidden, qkv_w, qkv_b, out_w, out_b,
1273 /// ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask
1274 /// Output: same shape as hidden.
1275 ///
1276 /// **Backend status:** same as FusedAttentionBlock. CPU implements
1277 /// the L1-cache-resident merge at the thunk level. Metal deferred —
1278 /// requires a single MSL kernel for the whole layer to actually
1279 /// beat the unfused path. Multi-day work; revisit when there's a
1280 /// model whose Metal inference is bottlenecked here rather than on
1281 /// the wait latency floor.
1282 FusedTransformerLayer {
1283 num_heads: usize,
1284 head_dim: usize,
1285 intermediate_size: usize,
1286 eps1: f32,
1287 eps2: f32,
1288 activation: Activation,
1289 has_bias: bool,
1290 },
1291
1292 /// Fused attention block: QKV projection → split → \[RoPE\] → SDPA → output projection.
1293 /// Created by FuseAttentionBlock pass when batch*seq is small.
1294 /// All intermediates stay in L1 cache — no arena writes between ops.
1295 ///
1296 /// Inputs (in order):
1297 /// hidden, qkv_w, out_w, mask,
1298 /// [qkv_b, out_b] if has_bias,
1299 /// [rope_cos, rope_sin] if has_rope
1300 ///
1301 /// **Backend status (Phase C finalize):**
1302 /// CPU — implemented at the *thunk* level: the CPU schedule
1303 /// recognizes the multi-thunk pattern and merges into
1304 /// a single FusedAttnBlock that keeps Q/K/V in stack
1305 /// buffers across stages (the L1-cache win).
1306 /// Metal — **deferred**. A dispatch-wrapper version (chaining
1307 /// existing kernels) buys nothing the unfused Metal path
1308 /// doesn't already get, since per-run cost is dominated
1309 /// by `wait_until_completed` (~150 µs), not encode. The
1310 /// real win is a single MSL kernel keeping Q/K/V in
1311 /// threadgroup memory across stages — multi-day work.
1312 /// Until then, Metal runs the unfused chain (one matmul,
1313 /// three narrows, two ropes, attention, one matmul) — all
1314 /// covered in op_coverage and parity_harness.
1315 FusedAttentionBlock {
1316 num_heads: usize,
1317 head_dim: usize,
1318 has_bias: bool,
1319 has_rope: bool,
1320 },
1321
1322 // ── Control flow (subgraphs as op payloads) ─────────────────
1323 //
1324 // Status: IR is defined; helper `run_if` / `run_while` exist in
1325 // rlx-runtime/src/subgraph.rs; **executor wiring is not yet
1326 // implemented** (both CPU thunk and Metal thunk fall through to
1327 // `Thunk::Nop` for these ops). Wiring requires:
1328 // 1. Recursive subgraph compile at parent-compile time.
1329 // 2. Per-subgraph input/output binding through the arena.
1330 // 3. Schedule-level dispatch when the predicate / loop cond is
1331 // resolved at runtime.
1332 // Estimate: 4–6 hours of focused work + parity tests. Deferred
1333 // because no current in-tree model uses these ops;
1334 // surface area without a validation target invites silent bugs.
1335 /// Conditional: pick between two subgraphs based on a boolean predicate.
1336 /// Inputs: [predicate, ...captures (used inside both branches)].
1337 /// `then_branch` and `else_branch` are sub-graphs that share the
1338 /// captured inputs and must produce identically-shaped outputs.
1339 /// Used for: shape-dependent execution, batched inference of
1340 /// dynamic-length sequences with padding masks.
1341 If {
1342 then_branch: Box<crate::Graph>,
1343 else_branch: Box<crate::Graph>,
1344 },
1345
1346 /// Loop: iterate `body` while `cond` evaluates true.
1347 /// Inputs: [...initial loop-carried values].
1348 /// `cond`'s single output is a Bool scalar.
1349 /// `body`'s outputs become the next iteration's loop-carried inputs.
1350 /// Outputs of While are the values after the final iteration.
1351 /// Used for: KV-cache-driven autoregressive generation, beam search.
1352 While {
1353 cond: Box<crate::Graph>,
1354 body: Box<crate::Graph>,
1355 max_iterations: Option<usize>,
1356 },
1357
1358 /// Bounded-length loop with a fixed-shape carry, optional per-step
1359 /// inputs, and optional stacked output. Mirrors JAX's `lax.scan`.
1360 ///
1361 /// Body signature: `(carry, x_t_0, ..., x_t_{num_xs-1}) → carry_next`
1362 /// — `1 + num_xs` Op::Inputs in NodeId construction order (carry
1363 /// first, then the `num_xs` per-step slices) when `num_bcast == 0`;
1364 /// see `num_bcast` for the general layout. Single output (next carry).
1365 ///
1366 /// Outer Op::Scan inputs (in order, for `num_bcast == 0`):
1367 /// `[init_carry, xs_0, xs_1, ..., xs_{num_xs-1}]`
1368 /// Each `xs_i` has shape `[length, *per_step_shape_i]`; the body
1369 /// sees `xs_i[t]` (a `per_step_shape_i` slice) on iteration `t`.
1370 ///
1371 /// Outer Op::Scan output:
1372 /// * `save_trajectory == false` — final carry, shape `*carry`.
1373 /// * `save_trajectory == true` — stacked trajectory of carries,
1374 /// shape `[length, *carry]`. Row `t` is the carry after step
1375 /// `t+1`, so row `length-1` matches the no-trajectory case.
1376 ///
1377 /// Common uses include time-stepping integrators with
1378 /// time-varying drives, Mamba-style SSM scans reading per-step
1379 /// inputs, and RNN-style sequence processing.
1380 Scan {
1381 body: Box<crate::Graph>,
1382 length: u32,
1383 save_trajectory: bool,
1384 /// Number of "broadcast" inputs — values that are constant
1385 /// across iterations. Outer scan inputs in order:
1386 /// `[init, bcast_0..bcast_{B-1}, xs_0..xs_{X-1}]`
1387 /// Body Op::Inputs in NodeId order:
1388 /// `[carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]`
1389 /// CPU executor fills bcast slots ONCE before the iteration
1390 /// loop (xs slots are filled per-step). The reverse-mode AD
1391 /// pre-pass materialises each bcast into an xs of shape
1392 /// `[length, *bcast]` via broadcast `Mul` so the rest of the
1393 /// VJP / executor pipeline can stay unchanged. `0` (default)
1394 /// keeps the original carry+xs scan shape.
1395 num_bcast: u32,
1396 /// Number of per-step `xs` inputs. Total outer Op::Scan
1397 /// inputs is `1 + num_bcast + num_xs`.
1398 num_xs: u32,
1399 /// Number of trajectory checkpoints when `save_trajectory ==
1400 /// true`. `0` means "save all `length` rows" (default). A
1401 /// positive value `K` means save only `K` evenly-spaced rows
1402 /// at indices `floor(t * length / K)` for `t in 0..K`. Used
1403 /// by recursive checkpointed AD: store O(√T) carries during
1404 /// forward, recompute the rest in the backward pass.
1405 ///
1406 /// When `0` (or `K == length`), the saved trajectory has
1407 /// shape `[length, *carry]` — same as the original behavior.
1408 /// When `0 < K < length`, the saved trajectory has shape
1409 /// `[K, *carry]`.
1410 num_checkpoints: u32,
1411 },
1412
1413 /// Reverse-mode AD companion to `Op::Scan` — extracts the carry
1414 /// gradient `dinit`. Walks `t = length-1 .. 0`, applying `body_vjp`
1415 /// to thread `dcarry` back through the time loop.
1416 ///
1417 /// Inputs (in order):
1418 /// `[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]`
1419 /// Output: `dinit`, shape = carry shape.
1420 ///
1421 /// `body_vjp` is the result of
1422 /// `autodiff::grad(body, [carry_id, xs_0_id, ..., xs_{num_xs-1}_id])`
1423 /// — a graph with `1 + num_xs + 1` Inputs (carry + x_t_i for each
1424 /// xs + `"d_output"`) and `1 + num_xs` outputs
1425 /// (dcarry + dx_t_i for each xs). This op reads `outputs[0]` =
1426 /// dcarry; the sibling [`Self::ScanBackwardXs`] reads the
1427 /// `outputs[1 + xs_idx]` slot for each xs gradient.
1428 ScanBackward {
1429 body_vjp: Box<crate::Graph>,
1430 length: u32,
1431 save_trajectory: bool,
1432 num_xs: u32,
1433 /// When `0` or equal to `length`, the trajectory input has
1434 /// shape `[length, *carry]` — every step's carry is cached
1435 /// (`CheckpointStrategy::All`). When `0 < K < length`, the
1436 /// trajectory input has shape `[K, *carry]` and the executor
1437 /// recomputes intermediate carries via `forward_body` between
1438 /// checkpoints. `forward_body` must be `Some` whenever
1439 /// `0 < K < length`.
1440 num_checkpoints: u32,
1441 /// Forward body (the same `body` from the forward Op::Scan).
1442 /// Required when `num_checkpoints > 0 && < length` so the
1443 /// executor can recompute carries between saved checkpoints.
1444 /// `None` for the All strategy (no recompute needed).
1445 forward_body: Option<Box<crate::Graph>>,
1446 },
1447
1448 /// Companion to [`Self::ScanBackward`] that extracts one stacked
1449 /// per-step `dxs_i` (shape `[length, *per_step_xs_i]`). Same inputs
1450 /// and same `body_vjp` graph as ScanBackward — `xs_idx` selects
1451 /// which body_vjp output to stack into the result.
1452 ///
1453 /// Note: each ScanBackwardXs runs its own backward loop. A future
1454 /// optimization can fuse them into a single multi-output backward
1455 /// kernel; for now it's `1 + num_xs` independent sweeps.
1456 ScanBackwardXs {
1457 body_vjp: Box<crate::Graph>,
1458 length: u32,
1459 save_trajectory: bool,
1460 num_xs: u32,
1461 xs_idx: u32,
1462 num_checkpoints: u32,
1463 forward_body: Option<Box<crate::Graph>>,
1464 },
1465
1466 /// CPU reference 3D Gaussian splat forward render.
1467 ///
1468 /// Seven flat F32 inputs (scene buffers + camera/render meta):
1469 /// 0. positions `[N*3]`
1470 /// 1. scales `[N*3]` (log-space)
1471 /// 2. rotations `[N*4]` (xyzw)
1472 /// 3. opacities `[N]` (logit)
1473 /// 4. colors `[N*3]` (linear RGB)
1474 /// 5. sh_coeffs `[N * sh_coeff_count * 3]`
1475 /// 6. meta `[23]` — camera position/target/up/fov/near/far, background RGB,
1476 /// then width/height/tile_size/radius_scale/alpha_cutoff/max_splat_steps/
1477 /// transmittance_threshold/max_list_entries as f32 bit-patterns.
1478 ///
1479 /// Output: `[height * width * 4]` linear RGBA (display gamma baked in).
1480 /// Build via [`crate::Graph::gaussian_splat_render`].
1481 ///
1482 /// Differentiable backward is not implemented in v1; autodiff treats this
1483 /// op as non-differentiable (same as [`Op::Sample`]).
1484 GaussianSplatRender {
1485 width: u32,
1486 height: u32,
1487 tile_size: u32,
1488 radius_scale: f32,
1489 alpha_cutoff: f32,
1490 max_splat_steps: u32,
1491 transmittance_threshold: f32,
1492 max_list_entries: u32,
1493 },
1494
1495 /// Backward pass for [`Self::GaussianSplatRender`].
1496 ///
1497 /// Eight inputs: the same seven as forward plus `d_loss_rgba` `[W*H*4]`
1498 /// (only RGB channels are used). Re-runs the training forward internally.
1499 ///
1500 /// Output: packed gradients
1501 /// `[positions(3N) | scales(3N) | rotations(4N) | opacities(N) | colors(3N) | sh(N*sh*3)]`.
1502 /// Unpack via [`crate::ops::splat::unpack_gaussian_splat_packed_grads`].
1503 GaussianSplatRenderBackward {
1504 width: u32,
1505 height: u32,
1506 tile_size: u32,
1507 radius_scale: f32,
1508 alpha_cutoff: f32,
1509 max_splat_steps: u32,
1510 transmittance_threshold: f32,
1511 max_list_entries: u32,
1512 loss_grad_clip: f32,
1513 sh_band: u32,
1514 max_anisotropy: f32,
1515 },
1516
1517 /// Strict IR stage 1: project, bin, sort, build per-pixel rays.
1518 ///
1519 /// Seven inputs (same scene + meta as [`Self::GaussianSplatRender`]). Output: packed
1520 /// prepare buffer (see `rlx_splat::prep_layout::prep_packed_len`).
1521 GaussianSplatPrepare {
1522 width: u32,
1523 height: u32,
1524 tile_size: u32,
1525 radius_scale: f32,
1526 alpha_cutoff: f32,
1527 max_splat_steps: u32,
1528 transmittance_threshold: f32,
1529 max_list_entries: u32,
1530 },
1531
1532 /// Strict IR stage 2: tile raster from [`Self::GaussianSplatPrepare`] output.
1533 ///
1534 /// Inputs: `prep` packed buffer, `meta` `[23]`. Output: `[width * height * 4]` RGBA.
1535 GaussianSplatRasterize {
1536 width: u32,
1537 height: u32,
1538 tile_size: u32,
1539 alpha_cutoff: f32,
1540 max_splat_steps: u32,
1541 transmittance_threshold: f32,
1542 max_list_entries: u32,
1543 },
1544
1545 /// User-registered custom op. `name` keys into the
1546 /// [`crate::op_registry`] for shape inference, autodiff, and
1547 /// per-backend execution. `attrs` is an opaque blob passed
1548 /// through to those callbacks (FFT direction, SparseLU
1549 /// reordering strategy, etc.). `num_inputs` is captured at
1550 /// construction time so [`Op::num_inputs`] stays infallible
1551 /// without a registry lookup. Build via [`crate::Graph::custom_op`].
1552 Custom {
1553 name: String,
1554 num_inputs: u32,
1555 attrs: Vec<u8>,
1556 },
1557
1558 /// 1D Fast Fourier Transform along the last axis.
1559 ///
1560 /// **Layouts**
1561 /// - `F32` / `F64`: 2N real-block — last axis is `[re₀…re_{N-1}, im₀…im_{N-1}]`.
1562 /// - `C64`: interleaved `[re, im]` pairs per complex element along the last axis.
1563 ///
1564 /// **ND transforms** — use `Graph::fftn` / `Graph::ifftn`, which compose
1565 /// `fft_axis` (transpose → Fft → transpose). Multi-axis `fftn` requires
1566 /// `DType::C64`; the 2N-block layout describes a single complex axis.
1567 ///
1568 /// Default (`FftNorm::Backward`) is **unnormalized** on both directions:
1569 /// `fft(x)[k] = Σ x[n]·exp(-2πi·nk/N)`
1570 /// `ifft(y)[n] = Σ y[k]·exp(+2πi·nk/N)`
1571 /// so `ifft(fft(x)) = N·x`. Use `FftNorm::Forward` for gpu-fft-style
1572 /// `1/N` scaling on inverse, or `FftNorm::Ortho` for unitary scaling.
1573 ///
1574 /// AD: VJP(`fft`) = `ifft`, VJP(`ifft`) = `fft` when `norm=Backward`;
1575 /// other norms apply the chain rule via output scaling.
1576 Fft {
1577 inverse: bool,
1578 norm: crate::fft::FftNorm,
1579 },
1580
1581 /// Ternary pruned radix-2 butterfly stage on interleaved complex state.
1582 ///
1583 /// Inputs:
1584 /// 0 — state `[batch, n_fft, 2]` (re/im on axis 2)
1585 /// 1 — gate `[half]` — 0 = identity, 1 = run butterfly (`half = n_fft/2`)
1586 /// 2 — rev `[half]` — 0 = forward, 1 = swap outputs when gate=1
1587 /// 3 — tw_re `[half]`
1588 /// 4 — tw_im `[half]`
1589 ///
1590 /// Output: `[batch, n_fft, 2]` same layout. Slots with gate=0 copy inputs
1591 /// without twiddle math.
1592 FftButterflyStage {
1593 stage: u32,
1594 n_fft: u32,
1595 },
1596
1597 /// Log-mel spectrogram from RLX FFT block-layout spectrum.
1598 ///
1599 /// Inputs:
1600 /// 0 — spectrum `[..., 2*n_fft]` (re plane then im plane, same as `Op::Fft` output)
1601 /// 1 — mel filterbank `[n_mels, n_bins]` with `n_bins = n_fft/2 + 1`
1602 ///
1603 /// Output: `[..., n_mels]` with Whisper dynamic-range compression (`log10`,
1604 /// floor at `max − 8.0`, `(x+4)/4`). NOTE(review): 8.0 is in log10 units, not dB — confirm.
1605 LogMel,
1606
1607 /// VJP of [`Op::LogMel`] w.r.t. spectrum (input 0).
1608 ///
1609 /// Inputs: spectrum block, mel filters, upstream `dy`.
1610 /// Output: `d_spectrum` (same shape as input 0).
1611 LogMelBackward,
1612
1613 /// User-defined sub-graph with optional override AD rules.
1614 /// Mirrors JAX's `custom_vjp` / `custom_jvp` decorators: the
1615 /// caller wraps a forward computation and supplies its own
1616 /// reverse- and/or forward-mode AD bodies. Useful when:
1617 /// * The forward is iterative (Newton, fixed-point) and
1618 /// differentiating through the loop is wasteful — the
1619 /// vjp_body computes the implicit-function gradient at the
1620 /// converged point in one shot.
1621 /// * The math has a closed-form gradient that's much cheaper
1622 /// than autodiff.
1623 /// * The forward op is non-differentiable by tracing
1624 /// (sampling, argmax) and the user wants a smooth surrogate.
1625 ///
1626 /// **fwd_body**: `num_inputs` Op::Inputs in NodeId construction
1627 /// order, one Op::Output (the primal y). Forward execution
1628 /// inlines this body once.
1629 ///
1630 /// **vjp_body** (optional): Op::Inputs are `num_inputs` primal
1631 /// inputs in NodeId order, plus two special-named Inputs —
1632 /// `"primal_output"` (the y from forward) and `"d_output"` (the
1633 /// upstream gradient). Outputs: `num_inputs` tensors in
1634 /// `set_outputs` order, matching the gradients of each primal
1635 /// input. When `None`, reverse-mode AD recurses into fwd_body
1636 /// — same as if the op were inlined.
1637 ///
1638 /// **jvp_body** (optional): Op::Inputs are `num_inputs` primal
1639 /// inputs in NodeId order, `num_inputs` special-named Inputs
1640 /// `"tangent_0"..="tangent_{num_inputs-1}"` carrying each input's
1641 /// tangent, and an optional special-named `"primal_output"` Input
1642 /// (the y from forward, useful when the JVP must be evaluated at
1643 /// a converged / nonlinear point — e.g. IFT-style forward-mode AD
1644 /// of an iterative solver). Output: 1 tensor (the tangent of y).
1645 /// When `None`, forward-mode AD recurses into fwd_body.
1646 ///
1647 /// `num_inputs` is captured so [`Op::num_inputs`] stays
1648 /// infallible. Build via [`crate::Graph::custom_fn`].
1649 CustomFn {
1650 fwd_body: Box<crate::Graph>,
1651 vjp_body: Option<Box<crate::Graph>>,
1652 jvp_body: Option<Box<crate::Graph>>,
1653 num_inputs: u32,
1654 },
1655}
1656
1657impl Op {
1658 /// PLAN L4: discriminant for backend-supported-set checks.
1659 /// Stable, parameter-free identity per variant — `Op::Activation(_)`
1660 /// and `Op::Activation(Relu)` share the same `OpKind::Activation`.
1661 pub fn kind(&self) -> OpKind {
1662 match self {
1663 Op::Input { .. } => OpKind::Input,
1664 Op::Param { .. } => OpKind::Param,
1665 Op::Constant { .. } => OpKind::Constant,
1666 Op::Activation(_) => OpKind::Activation,
1667 Op::Cast { .. } => OpKind::Cast,
1668 Op::StopGradient => OpKind::StopGradient,
1669 Op::Quantize { .. } => OpKind::Quantize,
1670 Op::Dequantize { .. } => OpKind::Dequantize,
1671 Op::FakeQuantize { .. } => OpKind::FakeQuantize,
1672 Op::FakeQuantizeLSQ { .. } => OpKind::FakeQuantizeLSQ,
1673 Op::FakeQuantizeLSQBackwardX { .. } => OpKind::FakeQuantizeLSQBackwardX,
1674 Op::FakeQuantizeLSQBackwardScale { .. } => OpKind::FakeQuantizeLSQBackwardScale,
1675 Op::Binary(_) => OpKind::Binary,
1676 Op::Compare(_) => OpKind::Compare,
1677 Op::Where => OpKind::Where,
1678 Op::ElementwiseRegion { .. } => OpKind::ElementwiseRegion,
1679 Op::TransformRegion { .. } => OpKind::TransformRegion,
1680 Op::BatchElementwiseRegion { .. } => OpKind::BatchElementwiseRegion,
1681 Op::MatMul => OpKind::MatMul,
1682 Op::DotGeneral { .. } => OpKind::DotGeneral,
1683 Op::DenseSolve => OpKind::DenseSolve,
1684 Op::BatchedDenseSolve => OpKind::BatchedDenseSolve,
1685 Op::LayerNorm { .. } => OpKind::LayerNorm,
1686 Op::LayerNorm2d { .. } => OpKind::LayerNorm2d,
1687 Op::GroupNorm { .. } => OpKind::GroupNorm,
1688 Op::BatchNormInference { .. } => OpKind::BatchNormInference,
1689 Op::RmsNorm { .. } => OpKind::RmsNorm,
1690 Op::ResizeNearest2x => OpKind::ResizeNearest2x,
1691 Op::Attention { .. } => OpKind::Attention,
1692 Op::Rope { .. } => OpKind::Rope,
1693 Op::AxialRope2d { .. } => OpKind::AxialRope2d,
1694 Op::Reshape { .. } => OpKind::Reshape,
1695 Op::Transpose { .. } => OpKind::Transpose,
1696 Op::Narrow { .. } => OpKind::Narrow,
1697 Op::Concat { .. } => OpKind::Concat,
1698 Op::Expand { .. } => OpKind::Expand,
1699 Op::Gather { .. } => OpKind::Gather,
1700 Op::Reduce { .. } => OpKind::Reduce,
1701 Op::Softmax { .. } => OpKind::Softmax,
1702 Op::Cumsum { .. } => OpKind::Cumsum,
1703 Op::TopK { .. } => OpKind::TopK,
1704 Op::Sample { .. } => OpKind::Sample,
1705 Op::Conv { .. } => OpKind::Conv,
1706 Op::Im2Col { .. } => OpKind::Im2Col,
1707 Op::ConvTranspose2d { .. } => OpKind::ConvTranspose2d,
1708 Op::Pool { .. } => OpKind::Pool,
1709 Op::ReluBackward => OpKind::ReluBackward,
1710 Op::ActivationBackward { .. } => OpKind::ActivationBackward,
1711 Op::FakeQuantizeBackward { .. } => OpKind::FakeQuantizeBackward,
1712 Op::ComplexNormSq => OpKind::ComplexNormSq,
1713 Op::ComplexNormSqBackward => OpKind::ComplexNormSqBackward,
1714 Op::Conjugate => OpKind::Conjugate,
1715 Op::LayerNormBackwardInput { .. } => OpKind::LayerNormBackwardInput,
1716 Op::LayerNormBackwardGamma { .. } => OpKind::LayerNormBackwardGamma,
1717 Op::RmsNormBackwardInput { .. } => OpKind::RmsNormBackwardInput,
1718 Op::RmsNormBackwardGamma { .. } => OpKind::RmsNormBackwardGamma,
1719 Op::RmsNormBackwardBeta { .. } => OpKind::RmsNormBackwardBeta,
1720 Op::RopeBackward { .. } => OpKind::RopeBackward,
1721 Op::GroupNormBackwardInput { .. } => OpKind::GroupNormBackwardInput,
1722 Op::GroupNormBackwardGamma { .. } => OpKind::GroupNormBackwardGamma,
1723 Op::GroupNormBackwardBeta { .. } => OpKind::GroupNormBackwardBeta,
1724 Op::BatchNormInferenceBackwardInput { .. } => OpKind::BatchNormInferenceBackwardInput,
1725 Op::BatchNormInferenceBackwardGamma { .. } => OpKind::BatchNormInferenceBackwardGamma,
1726 Op::BatchNormInferenceBackwardBeta => OpKind::BatchNormInferenceBackwardBeta,
1727 Op::CumsumBackward { .. } => OpKind::CumsumBackward,
1728 Op::GatherBackward { .. } => OpKind::GatherBackward,
1729 Op::MaxPool2dBackward { .. } => OpKind::MaxPool2dBackward,
1730 Op::Conv2dBackwardInput { .. } => OpKind::Conv2dBackwardInput,
1731 Op::Conv2dBackwardWeight { .. } => OpKind::Conv2dBackwardWeight,
1732 Op::SoftmaxCrossEntropyWithLogits => OpKind::SoftmaxCrossEntropyWithLogits,
1733 Op::SoftmaxCrossEntropyBackward => OpKind::SoftmaxCrossEntropyBackward,
1734 Op::AttentionBackward { .. } => OpKind::AttentionBackward,
1735 Op::GroupedMatMul => OpKind::GroupedMatMul,
1736 Op::DequantGroupedMatMul { .. } => OpKind::DequantGroupedMatMul,
1737 Op::DequantMoEWeights { .. } => OpKind::DequantMoEWeights,
1738 Op::ScatterAdd => OpKind::ScatterAdd,
1739 Op::LoraMatMul { .. } => OpKind::LoraMatMul,
1740 Op::DequantMatMul { .. } => OpKind::DequantMatMul,
1741 Op::QMatMul { .. } => OpKind::QMatMul,
1742 Op::QConv2d { .. } => OpKind::QConv2d,
1743 Op::SelectiveScan { .. } => OpKind::SelectiveScan,
1744 Op::GatedDeltaNet { .. } => OpKind::GatedDeltaNet,
1745 Op::FusedSwiGLU { .. } => OpKind::FusedSwiGLU,
1746 Op::FusedMatMulBiasAct { .. } => OpKind::FusedMatMulBiasAct,
1747 Op::FusedResidualLN { .. } => OpKind::FusedResidualLN,
1748 Op::FusedResidualRmsNorm { .. } => OpKind::FusedResidualRmsNorm,
1749 Op::FusedAttentionBlock { .. } => OpKind::FusedAttentionBlock,
1750 Op::FusedTransformerLayer { .. } => OpKind::FusedTransformerLayer,
1751 Op::If { .. } => OpKind::If,
1752 Op::While { .. } => OpKind::While,
1753 Op::Scan { .. } => OpKind::Scan,
1754 Op::ScanBackward { .. } => OpKind::ScanBackward,
1755 Op::ScanBackwardXs { .. } => OpKind::ScanBackwardXs,
1756 Op::GaussianSplatRender { .. } => OpKind::GaussianSplatRender,
1757 Op::GaussianSplatRenderBackward { .. } => OpKind::GaussianSplatRenderBackward,
1758 Op::GaussianSplatPrepare { .. } => OpKind::GaussianSplatPrepare,
1759 Op::GaussianSplatRasterize { .. } => OpKind::GaussianSplatRasterize,
1760 Op::Custom { .. } => OpKind::Custom,
1761 Op::CustomFn { .. } => OpKind::CustomFn,
1762 Op::Fft { .. } => OpKind::Fft,
1763 Op::FftButterflyStage { .. } => OpKind::FftButterflyStage,
1764 Op::LogMel => OpKind::LogMel,
1765 Op::LogMelBackward => OpKind::LogMelBackward,
1766 }
1767 }
1768
1769 /// True if this op is element-wise (same shape in, same shape out).
1770 /// Element-wise ops are prime fusion candidates.
1771 pub fn is_elementwise(&self) -> bool {
1772 matches!(
1773 self,
1774 Op::Activation(_)
1775 | Op::Cast { .. }
1776 | Op::StopGradient
1777 | Op::Binary(_)
1778 | Op::Compare(_)
1779 | Op::Where
1780 | Op::ElementwiseRegion { .. }
1781 | Op::BatchElementwiseRegion { .. }
1782 )
1783 }
1784
1785 /// True if this op may appear in a [`Op::TransformRegion`] chain.
1786 pub fn is_transform_eligible(&self) -> bool {
1787 matches!(self, Op::ResizeNearest2x)
1788 }
1789
1790 /// True if this op is a BLAS/compute-intensive op that forms a fusion boundary.
1791 pub fn is_blas(&self) -> bool {
1792 matches!(
1793 self,
1794 Op::MatMul
1795 | Op::DotGeneral { .. }
1796 | Op::DenseSolve
1797 | Op::BatchedDenseSolve
1798 | Op::Conv { .. }
1799 | Op::Im2Col { .. }
1800 | Op::ConvTranspose2d { .. }
1801 | Op::FusedMatMulBiasAct { .. }
1802 | Op::GroupedMatMul
1803 | Op::DequantGroupedMatMul { .. }
1804 | Op::DequantMoEWeights { .. }
1805 | Op::LoraMatMul { .. }
1806 | Op::DequantMatMul { .. }
1807 | Op::QMatMul { .. }
1808 | Op::QConv2d { .. }
1809 )
1810 }
1811
1812 /// True if element-wise fusion must not span across this op.
1813 pub fn is_fusion_boundary(&self) -> bool {
1814 self.is_blas()
1815 || matches!(
1816 self,
1817 Op::GaussianSplatRender { .. }
1818 | Op::GaussianSplatRenderBackward { .. }
1819 | Op::GaussianSplatPrepare { .. }
1820 | Op::GaussianSplatRasterize { .. }
1821 )
1822 }
1823
1824 /// True if this op is a reduction (drives loop iteration in fused kernels).
1825 pub fn is_reduction(&self) -> bool {
1826 matches!(
1827 self,
1828 Op::Reduce { .. } | Op::Softmax { .. } | Op::TopK { .. }
1829 )
1830 }
1831
1832 /// Number of tensor inputs this op expects.
1833 pub fn num_inputs(&self) -> usize {
1834 match self {
1835 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => 0,
1836 Op::Activation(_)
1837 | Op::Cast { .. }
1838 | Op::StopGradient
1839 | Op::Reshape { .. }
1840 | Op::Quantize { .. }
1841 | Op::Dequantize { .. }
1842 | Op::Transpose { .. }
1843 | Op::Narrow { .. }
1844 | Op::Expand { .. }
1845 | Op::Reduce { .. }
1846 | Op::Softmax { .. }
1847 | Op::FusedSwiGLU { .. }
1848 | Op::TopK { .. }
1849 | Op::Cumsum { .. }
1850 | Op::Sample { .. }
1851 | Op::ResizeNearest2x => 1,
1852 // EMA / Fixed scale modes carry a state tensor as a 2nd input;
1853 // PerBatch (default) doesn't need one.
1854 Op::FakeQuantize { scale_mode, .. } => match scale_mode {
1855 ScaleMode::PerBatch => 1,
1856 ScaleMode::EMA { .. } | ScaleMode::Fixed => 2,
1857 },
1858 Op::FakeQuantizeLSQ { .. } => 2, // x, scale (learned param)
1859 Op::FakeQuantizeLSQBackwardX { .. } | Op::FakeQuantizeLSQBackwardScale { .. } => 3, // x, scale, dy
1860 Op::Binary(_) | Op::Compare(_) | Op::Gather { .. } | Op::MatMul | Op::ScatterAdd => 2,
1861 Op::GroupedMatMul => 3, // input, weight, expert_idx
1862 Op::DequantGroupedMatMul { .. } => 3, // input, packed_w, expert_idx
1863 Op::DequantMoEWeights { .. } => 1, // packed_w
1864 Op::LoraMatMul { .. } => 4, // x, w, a, b
1865 // x, w_q, scale, zp — or x, packed_w_bytes for GGUF
1866 // schemes (their scales/mins live inside the packed bytes,
1867 // see `QuantScheme::is_gguf`).
1868 Op::DequantMatMul { scheme } => {
1869 if scheme.is_gguf() {
1870 2
1871 } else {
1872 4
1873 }
1874 }
1875 Op::QMatMul { .. } => 3, // x, w, bias
1876 Op::QConv2d { .. } => 3, // x, w, bias
1877 Op::SelectiveScan { .. } => 5, // x, delta, a, b, c
1878 Op::GatedDeltaNet { carry_state, .. } if *carry_state => 6, // + state in/out
1879 Op::GatedDeltaNet { .. } => 5, // q, k, v, g, beta
1880 Op::Where => 3, // cond, on_true, on_false
1881 Op::Attention { mask_kind, .. } => match mask_kind {
1882 MaskKind::Custom | MaskKind::Bias => 4, // Q, K, V, mask
1883 _ => 3, // Q, K, V (mask synthesized in-kernel)
1884 },
1885 Op::AttentionBackward { mask_kind, .. } => match mask_kind {
1886 MaskKind::Custom | MaskKind::Bias => 5, // q, k, v, dy, mask
1887 _ => 4, // q, k, v, dy
1888 },
1889 Op::Rope { .. } => 3, // x, cos, sin
1890 Op::AxialRope2d { .. } => 1,
1891 Op::LayerNorm { .. }
1892 | Op::LayerNorm2d { .. }
1893 | Op::GroupNorm { .. }
1894 | Op::RmsNorm { .. } => 3, // input, gamma, beta
1895 Op::BatchNormInference { .. } => 5, // x, gamma, beta, mean, var
1896 Op::FusedMatMulBiasAct { .. } => 3, // input, weight, bias
1897 Op::FusedResidualLN { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
1898 Op::FusedResidualLN {
1899 has_bias: false, ..
1900 } => 4, // x, residual, gamma, beta
1901 Op::FusedResidualRmsNorm { has_bias: true, .. } => 5, // x, residual, bias, gamma, beta
1902 Op::FusedResidualRmsNorm {
1903 has_bias: false, ..
1904 } => 4, // x, residual, gamma, beta
1905 Op::Conv { .. } | Op::ConvTranspose2d { .. } => 2, // input, weight (bias via Add)
1906 Op::Im2Col { .. } => 1,
1907 Op::Pool { .. } => 1,
1908 Op::ReluBackward => 2, // x, dy
1909 Op::ActivationBackward { .. } => 2, // x, dy
1910 Op::FakeQuantizeBackward { .. } => 2, // x, dy
1911 Op::ComplexNormSq => 1, // z (C64)
1912 Op::ComplexNormSqBackward => 2, // z, g
1913 Op::Conjugate => 1, // z (C64)
1914 Op::LayerNormBackwardInput { .. } => 3, // x, gamma, dy
1915 Op::LayerNormBackwardGamma { .. } => 2, // x, dy
1916 Op::RmsNormBackwardInput { .. } => 4, // x, gamma, beta, dy
1917 Op::RmsNormBackwardGamma { .. } => 4,
1918 Op::RmsNormBackwardBeta { .. } => 4,
1919 Op::RopeBackward { .. } => 3, // dy, cos, sin
1920 Op::GroupNormBackwardInput { .. } => 4, // x, gamma, beta, dy
1921 Op::GroupNormBackwardGamma { .. } => 2, // x, dy
1922 Op::GroupNormBackwardBeta { .. } => 2,
1923 Op::BatchNormInferenceBackwardInput { .. } => 5, // x, gamma, mean, var, dy
1924 Op::BatchNormInferenceBackwardGamma { .. } => 4, // x, mean, var, dy
1925 Op::BatchNormInferenceBackwardBeta => 1, // dy
1926 Op::CumsumBackward { .. } => 1, // dy
1927 Op::GatherBackward { .. } => 2, // dy, indices
1928 Op::MaxPool2dBackward { .. } => 2, // x, dy
1929 Op::Conv2dBackwardInput { .. } => 2, // dy, w
1930 Op::Conv2dBackwardWeight { .. } => 2, // x, dy
1931 Op::SoftmaxCrossEntropyWithLogits => 2, // logits, labels
1932 Op::SoftmaxCrossEntropyBackward => 3, // logits, labels, d_loss
1933 Op::Concat { .. } => 0, // variadic — checked at graph level
1934 Op::DotGeneral { .. } => 2,
1935 Op::DenseSolve => 2, // A, b
1936 Op::BatchedDenseSolve => 2, // A [B,N,N], b [B,N] or [B,N,K]
1937 Op::FusedAttentionBlock {
1938 has_bias, has_rope, ..
1939 } => 4 + if *has_bias { 2 } else { 0 } + if *has_rope { 2 } else { 0 },
1940 Op::If { .. } => 1, // predicate (captures handled separately)
1941 Op::While { .. } => 0, // variadic loop-carried; checked at graph level
1942 Op::Scan {
1943 num_bcast, num_xs, ..
1944 } => 1 + *num_bcast as usize + *num_xs as usize,
1945 Op::ScanBackward { num_xs, .. } => 3 + *num_xs as usize, // init, trajectory, upstream, xs_0..
1946 Op::ScanBackwardXs { num_xs, .. } => 3 + *num_xs as usize, // same as ScanBackward
1947 Op::GaussianSplatRender { .. } => 7,
1948 Op::GaussianSplatRenderBackward { .. } => 8,
1949 Op::GaussianSplatPrepare { .. } => 7,
1950 Op::GaussianSplatRasterize { .. } => 2,
1951 Op::FusedTransformerLayer { has_bias, .. } => {
1952 // hidden + qkv_w + out_w + ln1_g + ln1_b + fc1_w + fc2_w + ln2_g + ln2_b + mask = 10
1953 // bias variant adds: qkv_b + out_b + fc1_b + fc2_b = 4 more
1954 10 + if *has_bias { 4 } else { 0 }
1955 }
1956 Op::ElementwiseRegion { num_inputs, .. } => *num_inputs as usize,
1957 Op::TransformRegion { num_inputs, .. } => *num_inputs as usize,
1958 Op::BatchElementwiseRegion {
1959 num_batch_inputs, ..
1960 } => *num_batch_inputs as usize,
1961 Op::Custom { num_inputs, .. } => *num_inputs as usize,
1962 Op::CustomFn { num_inputs, .. } => *num_inputs as usize,
1963 Op::Fft { .. } => 1,
1964 Op::FftButterflyStage { .. } => 5,
1965 Op::LogMel => 2,
1966 Op::LogMelBackward => 3,
1967 }
1968 }
1969}
1970
1971impl std::fmt::Display for Op {
1972 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1973 match self {
1974 Op::Input { name } => write!(f, "input(\"{name}\")"),
1975 Op::Param { name } => write!(f, "param(\"{name}\")"),
1976 Op::Constant { data } => write!(f, "const({}B)", data.len()),
1977 Op::Activation(a) => write!(f, "{a:?}"),
1978 Op::Quantize { axis, scales, .. } => match axis {
1979 None => write!(f, "quantize(s={})", scales[0]),
1980 Some(d) => write!(f, "quantize(axis={d},nch={})", scales.len()),
1981 },
1982 Op::Dequantize { axis, scales, .. } => match axis {
1983 None => write!(f, "dequantize(s={})", scales[0]),
1984 Some(d) => write!(f, "dequantize(axis={d},nch={})", scales.len()),
1985 },
1986 Op::FakeQuantize {
1987 bits,
1988 axis,
1989 ste,
1990 scale_mode,
1991 } => match axis {
1992 None => write!(
1993 f,
1994 "fake_quant(bits={bits},ste={ste:?},scale={scale_mode:?})"
1995 ),
1996 Some(d) => write!(
1997 f,
1998 "fake_quant(bits={bits},axis={d},ste={ste:?},scale={scale_mode:?})"
1999 ),
2000 },
2001 Op::FakeQuantizeLSQ { bits, axis } => match axis {
2002 None => write!(f, "fake_quant_lsq(bits={bits})"),
2003 Some(d) => write!(f, "fake_quant_lsq(bits={bits},axis={d})"),
2004 },
2005 Op::FakeQuantizeLSQBackwardX { bits, .. } => {
2006 write!(f, "fake_quant_lsq_bwd_x(bits={bits})")
2007 }
2008 Op::FakeQuantizeLSQBackwardScale { bits, .. } => {
2009 write!(f, "fake_quant_lsq_bwd_s(bits={bits})")
2010 }
2011 Op::Cast { to } => write!(f, "cast({to})"),
2012 Op::StopGradient => write!(f, "stop_gradient"),
2013 Op::Binary(op) => write!(f, "{op:?}"),
2014 Op::Compare(op) => write!(f, "{op:?}"),
2015 Op::Where => write!(f, "where"),
2016 Op::MatMul => write!(f, "matmul"),
2017 Op::DotGeneral { .. } => write!(f, "dot_general"),
2018 Op::DenseSolve => write!(f, "dense_solve"),
2019 Op::BatchedDenseSolve => write!(f, "batched_dense_solve"),
2020 Op::LayerNorm { eps, .. } => write!(f, "layer_norm(eps={eps})"),
2021 Op::GroupNorm { num_groups, eps } => {
2022 write!(f, "group_norm(groups={num_groups},eps={eps})")
2023 }
2024 Op::BatchNormInference { eps } => write!(f, "batch_norm_inference(eps={eps})"),
2025 Op::ResizeNearest2x => write!(f, "resize_nearest_2x"),
2026 Op::RmsNorm { eps, .. } => write!(f, "rms_norm(eps={eps})"),
2027 Op::Attention {
2028 num_heads,
2029 head_dim,
2030 mask_kind,
2031 score_scale,
2032 attn_logit_softcap,
2033 } => {
2034 let mut s = match mask_kind {
2035 MaskKind::Custom => format!("attention(h={num_heads},d={head_dim})"),
2036 MaskKind::None => format!("attention(h={num_heads},d={head_dim},nomask)"),
2037 MaskKind::Causal => format!("attention(h={num_heads},d={head_dim},causal)"),
2038 MaskKind::SlidingWindow(w) => {
2039 format!("attention(h={num_heads},d={head_dim},sw={w})")
2040 }
2041 MaskKind::Bias => format!("attention(h={num_heads},d={head_dim},bias)"),
2042 };
2043 if let Some(sc) = score_scale {
2044 s.push_str(&format!(",scale={sc}"));
2045 }
2046 if let Some(cap) = attn_logit_softcap {
2047 s.push_str(&format!(",softcap={cap}"));
2048 }
2049 write!(f, "{s}")
2050 }
2051 Op::Rope { head_dim, n_rot } => write!(f, "rope(d={head_dim}, n_rot={n_rot})"),
2052 Op::AxialRope2d {
2053 end_x,
2054 end_y,
2055 head_dim,
2056 num_heads,
2057 theta,
2058 repeat_factor,
2059 } => write!(
2060 f,
2061 "axial_rope2d({end_x}x{end_y},h={num_heads},d={head_dim},θ={theta},r={repeat_factor})"
2062 ),
2063 Op::Reshape { new_shape } => write!(f, "reshape({new_shape:?})"),
2064 Op::Transpose { perm } => write!(f, "transpose({perm:?})"),
2065 Op::Narrow { axis, start, len } => write!(f, "narrow({axis},{start},{len})"),
2066 Op::Concat { axis } => write!(f, "concat(axis={axis})"),
2067 Op::Expand { .. } => write!(f, "expand"),
2068 Op::Gather { axis } => write!(f, "gather(axis={axis})"),
2069 Op::Reduce { op, axes, .. } => write!(f, "reduce_{op:?}({axes:?})"),
2070 Op::Softmax { axis } => write!(f, "softmax(axis={axis})"),
2071 Op::Cumsum { axis, exclusive } => {
2072 if *exclusive {
2073 write!(f, "cumsum(axis={axis},excl)")
2074 } else {
2075 write!(f, "cumsum(axis={axis})")
2076 }
2077 }
2078 Op::Sample {
2079 top_k,
2080 top_p,
2081 temperature,
2082 ..
2083 } => {
2084 write!(f, "sample(t={temperature}")?;
2085 if *top_k > 0 {
2086 write!(f, ",k={top_k}")?;
2087 }
2088 if *top_p < 1.0 {
2089 write!(f, ",p={top_p}")?;
2090 }
2091 write!(f, ")")
2092 }
2093 Op::TopK { k } => write!(f, "topk(k={k})"),
2094 Op::GroupedMatMul => write!(f, "grouped_matmul"),
2095 Op::DequantGroupedMatMul { scheme } => {
2096 write!(f, "dequant_grouped_matmul({scheme})")
2097 }
2098 Op::DequantMoEWeights { scheme } => write!(f, "dequant_moe_weights({scheme})"),
2099 Op::LoraMatMul { scale } => write!(f, "lora_matmul(scale={scale})"),
2100 Op::DequantMatMul { scheme } => write!(f, "dequant_matmul({scheme})"),
2101 Op::QMatMul {
2102 x_zp,
2103 w_zp,
2104 out_zp,
2105 mult,
2106 } => write!(
2107 f,
2108 "q_matmul(x_zp={x_zp},w_zp={w_zp},out_zp={out_zp},mult={mult})"
2109 ),
2110 Op::QConv2d { kernel_size, .. } => write!(f, "q_conv2d({kernel_size:?})"),
2111 Op::SelectiveScan { state_size } => write!(f, "ssm_scan(n={state_size})"),
2112 Op::GatedDeltaNet {
2113 state_size,
2114 carry_state,
2115 } => {
2116 if *carry_state {
2117 write!(f, "gated_delta_net(n={state_size},carry)")
2118 } else {
2119 write!(f, "gated_delta_net(n={state_size})")
2120 }
2121 }
2122 Op::ScatterAdd => write!(f, "scatter_add"),
2123 Op::Conv { kernel_size, .. } => write!(f, "conv2d({kernel_size:?})"),
2124 Op::Im2Col { kernel_size, .. } => write!(f, "im2col({kernel_size:?})"),
2125 Op::ConvTranspose2d { kernel_size, .. } => {
2126 write!(f, "conv_transpose2d({kernel_size:?})")
2127 }
2128 Op::LayerNorm2d { eps } => write!(f, "layer_norm2d(eps={eps})"),
2129 Op::Pool {
2130 kind, kernel_size, ..
2131 } => write!(f, "pool_{kind:?}({kernel_size:?})"),
2132 Op::ReluBackward => write!(f, "relu_backward"),
2133 Op::ActivationBackward { kind } => write!(f, "{kind:?}_backward"),
2134 Op::ComplexNormSq => write!(f, "complex_norm_sq"),
2135 Op::ComplexNormSqBackward => write!(f, "complex_norm_sq_backward"),
2136 Op::Conjugate => write!(f, "conjugate"),
2137 Op::FakeQuantizeBackward { bits, ste, .. } => {
2138 write!(f, "fake_quant_backward(bits={bits},ste={ste:?})")
2139 }
2140 Op::MaxPool2dBackward { kernel_size, .. } => {
2141 write!(f, "maxpool2d_backward({kernel_size:?})")
2142 }
2143 Op::Conv2dBackwardInput { kernel_size, .. } => {
2144 write!(f, "conv2d_backward_input({kernel_size:?})")
2145 }
2146 Op::Conv2dBackwardWeight { kernel_size, .. } => {
2147 write!(f, "conv2d_backward_weight({kernel_size:?})")
2148 }
2149 Op::SoftmaxCrossEntropyWithLogits => write!(f, "sce_with_logits"),
2150 Op::SoftmaxCrossEntropyBackward => write!(f, "sce_backward"),
2151 Op::AttentionBackward {
2152 num_heads,
2153 head_dim,
2154 mask_kind,
2155 wrt,
2156 } => match mask_kind {
2157 MaskKind::None => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},nomask)"),
2158 MaskKind::Causal => {
2159 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},causal)")
2160 }
2161 MaskKind::SlidingWindow(w) => {
2162 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},sw={w})")
2163 }
2164 MaskKind::Custom => {
2165 write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},custom)")
2166 }
2167 MaskKind::Bias => write!(f, "attn_bwd_{wrt:?}(h={num_heads},d={head_dim},bias)"),
2168 },
2169 Op::FusedMatMulBiasAct { activation } => {
2170 write!(f, "fused_mm_bias")?;
2171 if let Some(a) = activation {
2172 write!(f, "_{a:?}")?;
2173 }
2174 Ok(())
2175 }
2176 Op::FusedResidualLN { has_bias, eps } => {
2177 write!(f, "fused_residual")?;
2178 if *has_bias {
2179 write!(f, "_bias")?;
2180 }
2181 write!(f, "_ln(eps={eps})")
2182 }
2183 Op::FusedResidualRmsNorm { has_bias, eps } => {
2184 write!(f, "fused_residual")?;
2185 if *has_bias {
2186 write!(f, "_bias")?;
2187 }
2188 write!(f, "_rms(eps={eps})")
2189 }
2190 Op::FusedSwiGLU {
2191 cast_to,
2192 gate_first,
2193 } => {
2194 let mut s = match cast_to {
2195 Some(dt) => format!("fused_swiglu(cast={dt}"),
2196 None => "fused_swiglu(".to_string(),
2197 };
2198 if *gate_first {
2199 s.push_str(",gate_first");
2200 }
2201 s.push(')');
2202 write!(f, "{s}")
2203 }
2204 Op::FusedAttentionBlock {
2205 num_heads,
2206 head_dim,
2207 has_bias,
2208 has_rope,
2209 } => {
2210 write!(f, "fused_attn(h={num_heads},d={head_dim}")?;
2211 if *has_bias {
2212 write!(f, ",bias")?;
2213 }
2214 if *has_rope {
2215 write!(f, ",rope")?;
2216 }
2217 write!(f, ")")
2218 }
2219 Op::If { .. } => write!(f, "if(...)"),
2220 Op::While { max_iterations, .. } => match max_iterations {
2221 Some(n) => write!(f, "while(...max={n})"),
2222 None => write!(f, "while(...)"),
2223 },
2224 Op::Scan {
2225 length,
2226 save_trajectory,
2227 num_xs,
2228 ..
2229 } => {
2230 let traj = if *save_trajectory { ",traj" } else { "" };
2231 let xs = if *num_xs > 0 {
2232 format!(",xs={}", num_xs)
2233 } else {
2234 String::new()
2235 };
2236 write!(f, "scan(len={length}{xs}{traj})")
2237 }
2238 Op::ScanBackward {
2239 length,
2240 save_trajectory,
2241 num_xs,
2242 ..
2243 } => {
2244 let traj = if *save_trajectory { ",traj" } else { "" };
2245 let xs = if *num_xs > 0 {
2246 format!(",xs={}", num_xs)
2247 } else {
2248 String::new()
2249 };
2250 write!(f, "scan_bwd(len={length}{xs}{traj})")
2251 }
2252 Op::ScanBackwardXs {
2253 length,
2254 save_trajectory,
2255 num_xs,
2256 xs_idx,
2257 ..
2258 } => {
2259 let traj = if *save_trajectory { ",traj" } else { "" };
2260 write!(
2261 f,
2262 "scan_bwd_xs(len={length},xs={num_xs},idx={xs_idx}{traj})"
2263 )
2264 }
2265 Op::FusedTransformerLayer {
2266 num_heads,
2267 head_dim,
2268 intermediate_size,
2269 has_bias,
2270 ..
2271 } => {
2272 write!(
2273 f,
2274 "fused_layer(h={num_heads},d={head_dim},int={intermediate_size}"
2275 )?;
2276 if *has_bias {
2277 write!(f, ",bias")?;
2278 }
2279 write!(f, ")")
2280 }
2281 Op::ElementwiseRegion {
2282 chain,
2283 num_inputs,
2284 scalar_input_mask,
2285 input_modulus: _,
2286 prologue,
2287 prologue_input: _,
2288 } => {
2289 let pro = match prologue {
2290 RegionPrologue::None => "",
2291 RegionPrologue::ResizeNearest2x => ",prologue=resize2x",
2292 };
2293 if *scalar_input_mask != 0 {
2294 write!(
2295 f,
2296 "ew_region(in={num_inputs},steps={},scalar_mask=0x{:x}{pro})",
2297 chain.len(),
2298 scalar_input_mask
2299 )
2300 } else {
2301 write!(f, "ew_region(in={num_inputs},steps={}{pro})", chain.len())
2302 }
2303 }
2304 Op::TransformRegion { steps, num_inputs } => {
2305 write!(f, "transform_region(in={num_inputs},steps={})", steps.len())
2306 }
2307 Op::BatchElementwiseRegion {
2308 chain,
2309 num_batch_inputs,
2310 scalar_input_mask,
2311 prologue,
2312 ..
2313 } => write!(
2314 f,
2315 "batch_ew_region(batch={num_batch_inputs},steps={},mask=0x{:x},prologue={prologue:?})",
2316 chain.len(),
2317 scalar_input_mask
2318 ),
2319 Op::LayerNormBackwardInput { eps, .. } => {
2320 write!(f, "layer_norm_backward_input(eps={eps})")
2321 }
2322 Op::LayerNormBackwardGamma { eps, .. } => {
2323 write!(f, "layer_norm_backward_gamma(eps={eps})")
2324 }
2325 Op::RmsNormBackwardInput { eps, .. } => write!(f, "rms_norm_backward_input(eps={eps})"),
2326 Op::RmsNormBackwardGamma { eps, .. } => write!(f, "rms_norm_backward_gamma(eps={eps})"),
2327 Op::RmsNormBackwardBeta { eps, .. } => write!(f, "rms_norm_backward_beta(eps={eps})"),
2328 Op::RopeBackward { head_dim, n_rot } => {
2329 write!(f, "rope_backward(d={head_dim},n_rot={n_rot})")
2330 }
2331 Op::GroupNormBackwardInput { num_groups, eps } => {
2332 write!(f, "group_norm_backward_input(g={num_groups},eps={eps})")
2333 }
2334 Op::GroupNormBackwardGamma { num_groups, eps } => {
2335 write!(f, "group_norm_backward_gamma(g={num_groups},eps={eps})")
2336 }
2337 Op::GroupNormBackwardBeta { num_groups, eps } => {
2338 write!(f, "group_norm_backward_beta(g={num_groups},eps={eps})")
2339 }
2340 Op::BatchNormInferenceBackwardInput { eps } => {
2341 write!(f, "batch_norm_inference_backward_input(eps={eps})")
2342 }
2343 Op::BatchNormInferenceBackwardGamma { eps } => {
2344 write!(f, "batch_norm_inference_backward_gamma(eps={eps})")
2345 }
2346 Op::BatchNormInferenceBackwardBeta => {
2347 write!(f, "batch_norm_inference_backward_beta")
2348 }
2349 Op::CumsumBackward { axis, exclusive } => {
2350 write!(f, "cumsum_backward(axis={axis},exclusive={exclusive})")
2351 }
2352 Op::GatherBackward { axis } => write!(f, "gather_backward(axis={axis})"),
2353 Op::GaussianSplatRender {
2354 width,
2355 height,
2356 tile_size,
2357 radius_scale,
2358 alpha_cutoff,
2359 max_splat_steps,
2360 transmittance_threshold,
2361 max_list_entries,
2362 } => write!(
2363 f,
2364 "gaussian_splat_render({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2365 ),
2366 Op::GaussianSplatRenderBackward {
2367 width,
2368 height,
2369 loss_grad_clip,
2370 sh_band,
2371 ..
2372 } => write!(
2373 f,
2374 "gaussian_splat_render_bwd({width}x{height},clip={loss_grad_clip},sh={sh_band})"
2375 ),
2376 Op::GaussianSplatPrepare {
2377 width,
2378 height,
2379 tile_size,
2380 radius_scale,
2381 alpha_cutoff,
2382 max_splat_steps,
2383 transmittance_threshold,
2384 max_list_entries,
2385 ..
2386 } => write!(
2387 f,
2388 "gaussian_splat_prepare({width}x{height},tile={tile_size},r={radius_scale},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2389 ),
2390 Op::GaussianSplatRasterize {
2391 width,
2392 height,
2393 tile_size,
2394 alpha_cutoff,
2395 max_splat_steps,
2396 transmittance_threshold,
2397 max_list_entries,
2398 ..
2399 } => write!(
2400 f,
2401 "gaussian_splat_rasterize({width}x{height},tile={tile_size},a={alpha_cutoff},steps={max_splat_steps},t={transmittance_threshold},list={max_list_entries})"
2402 ),
2403 Op::Custom {
2404 name,
2405 num_inputs,
2406 attrs,
2407 } => write!(f, "custom({name},in={num_inputs},attrs={}B)", attrs.len()),
2408 Op::CustomFn {
2409 num_inputs,
2410 vjp_body,
2411 jvp_body,
2412 ..
2413 } => {
2414 let v = if vjp_body.is_some() { ",vjp" } else { "" };
2415 let j = if jvp_body.is_some() { ",jvp" } else { "" };
2416 write!(f, "custom_fn(in={num_inputs}{v}{j})")
2417 }
2418 Op::Fft { inverse, norm } => {
2419 write!(f, "fft(inverse={inverse}, norm={norm:?})")
2420 }
2421 Op::FftButterflyStage { stage, n_fft } => {
2422 write!(f, "fft_butterfly_stage(stage={stage}, n_fft={n_fft})")
2423 }
2424 Op::LogMel => write!(f, "log_mel()"),
2425 Op::LogMelBackward => write!(f, "log_mel_backward()"),
2426 }
2427 }
2428}