Skip to main content

baracuda_kernels_types/
ops.rs

//! Per-category op discriminant enums.
//!
//! Each op category (B, C, D, N, …) gets a `*Kind` enum whose variants
//! correspond to individual PyTorch / JAX ops. The enum value is also
//! the runtime tag stored as a `u16` in [`crate::KernelSku::op`].
//!
//! New enums land alongside the Plan type that consumes them. Phase 3
//! introduced [`BinaryKind`] and [`UnaryKind`]. [`TernaryKind`],
//! [`GatedActivationKind`], [`ShapeLayoutKind`] and the reduction,
//! softmax, scan, comparison, normalization and loss enums below have
//! since landed alongside their Plan types. Variants documented as
//! "reserved" hold a discriminant but are not yet wired to a kernel.
11
/// Binary elementwise op discriminant.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::BinaryElementwise`. Variants correspond to
/// the union of PyTorch (`torch.<op>` / `torch.Tensor.<op>`) and JAX
/// (`jax.numpy.<op>` / `jax.lax.<op>`) binary elementwise ops.
///
/// Today only [`Self::Add`] is wired — the Phase 3 trailblazer SKU. The
/// other variants are reserved discriminants for the fanout sessions
/// that ship sub / mul / div / pow / bitwise. Comparisons are the
/// exception: they produce a `u8` output and are dispatched through
/// [`BinaryCmpKind`], so the `Eq`..`Le` slots below are vestigial.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum BinaryKind {
    /// `y = a + b` — elementwise addition. Trailblazer SKU for
    /// `baracuda-kernels` Phase 3.
    Add = 0,
    /// `y = a - b` — elementwise subtraction.
    Sub = 1,
    /// `y = a * b` — elementwise multiplication.
    Mul = 2,
    /// `y = a / b` — elementwise division.
    Div = 3,
    /// `y = floor(a / b)` — elementwise floor-divide.
    FloorDivide = 4,
    /// `y = a mod b` — floored-division modulo: the result's sign
    /// matches `b` (Python `%`). This is what PyTorch calls
    /// `torch.remainder`.
    Mod = 5,
    /// Truncated-division remainder: the result's sign matches `a`
    /// (C `%` / `fmod` semantics, PyTorch `torch.fmod`).
    ///
    /// Naming caveat: despite the variant name, this is *not* PyTorch's
    /// `torch.remainder` (Python-style — that is [`Self::Mod`]), nor C's
    /// IEEE `remainder()` (round-to-nearest quotient).
    Remainder = 6,
    /// `y = a ** b` — elementwise power (broadcast scalar exponent OK).
    Pow = 7,
    /// `y = atan2(a, b)`.
    Atan2 = 8,
    /// `y = hypot(a, b) = sqrt(a² + b²)`.
    Hypot = 9,
    /// `y = a` with sign-bit copied from `b`.
    Copysign = 10,
    /// `y` = next representable value from `a` toward `b`.
    Nextafter = 11,
    /// `y = a · 2^b` (integer `b` broadcast as scalar in practice).
    Ldexp = 12,
    /// `y = min(a, b)`, NaN-propagating: if either operand is NaN the
    /// result is NaN (PyTorch `torch.minimum`).
    Minimum = 13,
    /// `y = max(a, b)`, NaN-propagating: if either operand is NaN the
    /// result is NaN (PyTorch `torch.maximum`).
    Maximum = 14,
    /// `y = fmin(a, b)`, NaN-ignoring: if exactly one operand is NaN the
    /// other operand is returned; the result is NaN only when both are
    /// NaN (PyTorch `torch.fmin`).
    Fmin = 15,
    /// `y = fmax(a, b)`, NaN-ignoring: if exactly one operand is NaN the
    /// other operand is returned; the result is NaN only when both are
    /// NaN (PyTorch `torch.fmax`).
    Fmax = 16,

    // ---- Vestigial comparison slots ----
    // Comparisons change the output dtype to `u8`, so they dispatch through
    // `BinaryCmpKind` / `BinaryCmpPlan`. These discriminants stay reserved
    // (never reused) but will not be wired into the same-dtype binary path.
    /// `y = (a == b)` — returns bool. Vestigial; see [`BinaryCmpKind::Eq`].
    Eq = 17,
    /// `y = (a != b)` — returns bool. Vestigial; see [`BinaryCmpKind::Ne`].
    Ne = 18,
    /// `y = (a > b)` — returns bool. Vestigial; see [`BinaryCmpKind::Gt`].
    Gt = 19,
    /// `y = (a >= b)` — returns bool. Vestigial; see [`BinaryCmpKind::Ge`].
    Ge = 20,
    /// `y = (a < b)` — returns bool. Vestigial; see [`BinaryCmpKind::Lt`].
    Lt = 21,
    /// `y = (a <= b)` — returns bool. Vestigial; see [`BinaryCmpKind::Le`].
    Le = 22,

    /// `y = a && b` — bool only.
    LogicalAnd = 23,
    /// `y = a || b` — bool only.
    LogicalOr = 24,
    /// `y = a ^ b` (logical) — bool only.
    LogicalXor = 25,
    /// `y = a & b` — integer only.
    BitwiseAnd = 26,
    /// `y = a | b` — integer only.
    BitwiseOr = 27,
    /// `y = a ^ b` (bitwise) — integer only.
    BitwiseXor = 28,
    /// `y = a << b` — integer only.
    BitwiseLeftShift = 29,
    /// `y = a >> b` — integer only.
    BitwiseRightShift = 30,
    /// `y = a + (b - a) * weight` (broadcast scalar weight). Per
    /// PyTorch's `torch.lerp` convention.
    Lerp = 31,
}
94
/// Unary elementwise op discriminant.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::UnaryElementwise`. Variants correspond to
/// the union of PyTorch (`torch.<op>` / `torch.Tensor.<op>`) and JAX
/// (`jax.numpy.<op>` / `jax.lax.<op>`) unary elementwise ops, plus the
/// activation family from PyTorch `nn.functional`.
///
/// [`Self::Neg`] was the Phase 3 unary trailblazer SKU. Wiring status
/// is documented per variant. Several have since shipped (e.g.
/// [`Self::PReLU`], [`Self::PowI`], [`Self::Step`], [`Self::Cast`],
/// [`Self::Affine`], and the bespoke [`Self::LeakyRelu`] kernel), while
/// variants marked "Reserved" hold a discriminant only.
///
/// Ops that return a different dtype than the input (`isnan`, `isinf`,
/// `isfinite`, `logical_not`) have **no variant here yet**. They will
/// route through a future `UnaryToBoolPlan` (or similar) with a distinct
/// output type, not through the same-dtype `UnaryPlan<T, N>`.
///
/// Parameterized ops carry their scalars outside this enum:
/// - [`Self::PowI`] threads its exponent through the `params: [f32; 2]`
///   slot of `UnaryParamPlan`.
/// - [`Self::LeakyRelu`], [`Self::Elu`], [`Self::Hardshrink`] and
///   [`Self::Softshrink`] currently use hardcoded constants in bespoke
///   kernels.
/// - [`Self::Threshold`] awaits the parameterized-unary plan.
///
/// Discriminants are grouped in blocks of ten per sub-family (0.., 10..,
/// 20.., …, 100.., 130..). This leaves gaps that future ops can fill.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum UnaryKind {
    // ---- Category B: elementwise unary (math) — trivial ----
    /// `y = -x` — elementwise negation. Trailblazer SKU.
    Neg = 0,
    /// `y = |x|` — elementwise absolute value.
    Abs = 1,
    /// `y = sign(x)` — `-1` / `0` / `+1` per the input's sign.
    Sign = 2,
    /// `y = 1 / x` — elementwise reciprocal.
    Reciprocal = 3,
    /// `y = x * x` — elementwise square.
    Square = 4,
    /// `y = x * x * x` — elementwise cube.
    Cube = 5,

    // ---- Category B: roots ----
    /// `y = sqrt(x)`.
    Sqrt = 10,
    /// `y = 1 / sqrt(x)` — reciprocal square root.
    Rsqrt = 11,
    /// `y = cbrt(x)` — cube root.
    Cbrt = 12,

    // ---- Category B: exp / log family ----
    /// `y = exp(x)`.
    Exp = 20,
    /// `y = 2^x`.
    Exp2 = 21,
    /// `y = exp(x) - 1`.
    Expm1 = 22,
    /// `y = ln(x)` — natural log.
    Log = 23,
    /// `y = log_2(x)`.
    Log2 = 24,
    /// `y = log_10(x)`.
    Log10 = 25,
    /// `y = ln(1 + x)`.
    Log1p = 26,

    // ---- Category B: trig ----
    /// `y = sin(x)`.
    Sin = 30,
    /// `y = cos(x)`.
    Cos = 31,
    /// `y = tan(x)`.
    Tan = 32,
    /// `y = asin(x)`.
    Asin = 33,
    /// `y = acos(x)`.
    Acos = 34,
    /// `y = atan(x)`.
    Atan = 35,

    // ---- Category B: hyperbolic ----
    /// `y = sinh(x)`.
    Sinh = 40,
    /// `y = cosh(x)`.
    Cosh = 41,
    /// `y = tanh(x)`.
    Tanh = 42,
    /// `y = asinh(x)`.
    Asinh = 43,
    /// `y = acosh(x)`.
    Acosh = 44,
    /// `y = atanh(x)`.
    Atanh = 45,

    // ---- Category B: rounding ----
    /// `y = floor(x)`.
    Floor = 50,
    /// `y = ceil(x)`.
    Ceil = 51,
    /// `y = round(x)` — round-half-to-even (PyTorch convention).
    Round = 52,
    /// `y = trunc(x)` — truncate toward zero.
    Trunc = 53,
    /// `y = x - trunc(x)` — fractional part with sign of `x`.
    Frac = 54,

    // ---- Category B: special functions ----
    /// `y = erf(x)`.
    Erf = 60,
    /// `y = erfc(x) = 1 - erf(x)`.
    Erfc = 61,
    /// `y = erfinv(x)`.
    Erfinv = 62,
    /// `y = lgamma(x) = ln(|Γ(x)|)`.
    Lgamma = 63,
    /// `y = digamma(x) = Γ'(x) / Γ(x)`.
    Digamma = 64,

    // ---- Category B: bitwise / integer (int-typed only) ----
    /// `y = ~x` — bitwise NOT (integer dtypes).
    BitwiseNot = 70,
    /// `y = popcount(x)` — population count of set bits (integer).
    Popcount = 71,
    /// `y = clz(x)` — count leading zeros (integer).
    Clz = 72,
    /// `y = ctz(x)` — count trailing zeros (integer).
    Ctz = 73,

    // ---- Category B': activations (unparameterized) ----
    /// `y = relu(x) = max(x, 0)`.
    Relu = 100,
    /// `y = gelu(x)` — ERF-EXACT Gaussian Error Linear Unit,
    /// `0.5·x·(1+erf(x/√2))` — NOT the tanh approximation (that's
    /// [`Self::GeluTanh`]). The sys-level `unary_gelu_erf_*` symbols
    /// are a bit-identical alias of the `unary_gelu_*` symbols this
    /// variant dispatches to.
    Gelu = 101,
    /// `y = gelu_tanh(x)` — tanh APPROXIMATION of gelu,
    /// `0.5·x·(1+tanh(√(2/π)·(x+0.044715·x³)))`. Diverges from the
    /// erf-exact [`Self::Gelu`] by up to ~1e-4.
    GeluTanh = 102,
    /// `y = silu(x) = x · sigmoid(x)`. Also known as Swish-1.
    Silu = 103,
    /// `y = mish(x) = x · tanh(softplus(x))`.
    Mish = 104,
    /// `y = sigmoid(x) = 1 / (1 + exp(-x))`.
    Sigmoid = 105,
    /// `y = logit(x) = log(x / (1 - x))`. Inverse of sigmoid.
    Logit = 106,
    /// `y = softplus(x) = ln(1 + exp(x))`.
    Softplus = 107,
    /// `y = softsign(x) = x / (1 + |x|)`.
    Softsign = 108,
    /// `y = tanhshrink(x) = x - tanh(x)`.
    Tanhshrink = 109,
    /// `y = relu6(x) = min(max(x, 0), 6)`.
    Relu6 = 110,
    /// `y = hardswish(x)` — piecewise-linear approximation of swish.
    Hardswish = 111,
    /// `y = hardsigmoid(x)` — piecewise-linear approximation of sigmoid.
    Hardsigmoid = 112,
    /// `y = hardtanh(x, -1, +1)` — piecewise-linear clamp.
    Hardtanh = 113,
    /// `y = selu(x)` — scaled exponential linear unit.
    Selu = 114,

    // ---- Category B': parameterized activations / ops ----
    // Parameters are hardcoded in bespoke kernels, carried by a dedicated
    // plan, or still pending (see each variant).
    /// `y = leaky_relu(x) = x if x > 0 else α·x`. Hardcoded α = 0.01 in
    /// the current bespoke kernel; will re-emit as a fanout from a
    /// parameterized-unary plan once that infrastructure lands.
    LeakyRelu = 115,
    /// `y = elu(x) = x if x > 0 else α·(exp(x) - 1)`. Hardcoded α = 1.0
    /// in the current bespoke kernel; same parameterization story as
    /// `LeakyRelu`.
    Elu = 116,
    /// `y = hardshrink(x) = x if |x| > λ else 0`. Hardcoded λ = 0.5 in
    /// the current bespoke kernel; same parameterization story as
    /// `LeakyRelu`.
    Hardshrink = 117,
    /// `y = softshrink(x) = x - λ if x > λ; x + λ if x < -λ; else 0`.
    /// Hardcoded λ = 0.5 in the current bespoke kernel; same
    /// parameterization story as `LeakyRelu`.
    Softshrink = 118,
    /// Reserved — `threshold(x; t, v) = x if x > t else v`. Needs the
    /// parameterized-unary plan (two scalar parameters); not wired yet.
    Threshold = 119,
    /// `prelu(x; α) = x if x > 0 else α·x` with per-channel learnable α
    /// vector (or single scalar α). Uses a distinct plan shape
    /// (`PReluPlan` / `PReluBackwardPlan`) because α is a tensor operand,
    /// not a scalar parameter. Wired in Milestone 5.3.
    PReLU = 120,
    /// `powi(x; n) = x^n` for a fixed runtime *integer* exponent `n`.
    /// Distinct from the generic [`BinaryKind::Pow`] (which takes an
    /// f32 exponent tensor) because the integer-only path can use
    /// power-by-squaring — faster than `__expf(n · __logf(x))` and
    /// also well-defined for negative `x` (real `pow(-1.5, 2) = 2.25`,
    /// no NaN). The exponent is threaded via the `params: [f32; 2]`
    /// slot 0 with a host-side cast (`n as f32`); slot 1 is unused.
    /// Reasonable |n| values round-trip through f32 exactly (≤ 2^24).
    /// Phase 12.1 wires `{f32, f16, bf16, f64}` through `UnaryParamPlan`.
    PowI = 121,
    /// `y = step(x) = 1 if x > 0 else 0` — Heaviside step function.
    /// `step(0) = 0` and `step(-0.0) = 0` (`x > 0` is false at both
    /// zeros); NaN → 0 (`NaN > 0` is false), matching PyTorch's
    /// `heaviside(x, values=0)` for the `>` branch. Wires the Phase 31
    /// `unary_step_*` kernels.
    Step = 122,

    // ---- Category B: dtype / scalar-shape ops ----
    /// `y = (TOut) x` — dtype conversion. Heterogeneous input / output
    /// element types, so it goes through its own `CastPlan` (not the
    /// same-dtype `UnaryPlan<T, N>`). The discriminant lives here for
    /// telemetry / SKU-tagging consistency with the rest of the unary
    /// family. Wired from `fuel-cuda-kernels/cast.cu`.
    Cast = 130,
    /// `y = a * x + b` — fused affine (multiply-add) with scalar
    /// parameters `a` / `b`. Same-dtype input/output but carries two
    /// scalar parameters, so it gets its own `AffinePlan` (the unified
    /// `UnaryPlan<T, N>` doesn't carry kernel parameters). Wired from
    /// `fuel-cuda-kernels/affine.cu`.
    Affine = 131,
}
313
/// Discriminant for ternary (three-operand) elementwise ops.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::TernaryElementwise`. The same-dtype
/// `TernaryPlan<T, N>` only covers ops whose inputs and output all share
/// one dtype. [`Self::Where`] (bool condition plus two value tensors)
/// holds a slot here but will get a dedicated plan shape in a later
/// session.
///
/// Wiring status: [`Self::Clamp`] on `f32` is the Phase 3 ternary
/// trailblazer. Other ops and non-`f32` dtypes arrive in fanout sessions.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum TernaryKind {
    /// Clamp `x` into `[lo, hi]`: `y = min(max(x, lo), hi)`. Trailblazer.
    Clamp = 0,
    /// Fused multiply-add `y = a * b + c`, i.e. PyTorch
    /// `torch.addcmul(c, a, b)` with `value = 1`.
    Fma = 1,
    /// PyTorch `addcmul`: `y = self + value * t1 * t2`. Reserved until a
    /// parameterized-ternary path exists, because `value` is a runtime
    /// scalar rather than a tensor operand.
    Addcmul = 2,
    /// PyTorch `addcdiv`: `y = self + value * t1 / t2`. Reserved for the
    /// same parameterization reason as `Addcmul`.
    Addcdiv = 3,
    /// Element select `y = cond ? a : b`. The inputs are heterogeneous
    /// (`cond` is bool; `a` / `b` share the output dtype), so it needs its
    /// own plan shape and will not go through the same-dtype `TernaryPlan`.
    Where = 4,
}
347
/// Discriminant for gated activations (category C').
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::GatedActivation`. Every variant has the same
/// structure. The input `x` is split along `split_dim` into two halves
/// `(a, b)`, and the output is `y = a · gate(b)`. Only the `gate`
/// function changes from variant to variant.
///
/// Forward and backward kernels are wired for all four variants across
/// `{f32, f16, bf16, f64}`. [`Self::SwiGlu`] served as the trailblazer
/// because it matters most for LLM workloads.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum GatedActivationKind {
    /// Sigmoid gate, `y = a · sigmoid(b)` (PyTorch
    /// `torch.nn.functional.glu`).
    Glu = 0,
    /// ReLU gate, `y = a · relu(b)`.
    ReGlu = 1,
    /// SiLU gate, `y = a · silu(b) = a · b · sigmoid(b)`; the gating used
    /// by Llama / Mistral / Gemma.
    SwiGlu = 2,
    /// GELU gate, `y = a · gelu(b)` using the exact (erf-based) GELU.
    GeGlu = 3,
}
371
/// How [`crate::ops::ShapeLayoutKind::Pad`] fills the padded region.
///
/// The Phase 3 trailblazer wires [`Self::Constant`] only. Reflect,
/// Replicate and Circular follow in fanout sessions. Each one changes
/// only the kernel branch that chooses the value written into the pad
/// region; the plan shape stays the same.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum PadMode {
    /// Fill with a constant (`PadDescriptor::value`).
    Constant = 0,
    /// Mirror the input across the boundary, without repeating the edge
    /// element.
    Reflect = 1,
    /// Repeat the boundary element throughout the pad region.
    Replicate = 2,
    /// Wrap around to the opposite edge ("circular" padding).
    Circular = 3,
}
391
/// Shape / layout op discriminant — Category N.
///
/// Tags the kernel SKU for telemetry / autotuner-cache keys. Their
/// descriptor / argument shapes differ too much for one
/// `ShapeLayoutPlan<T, N>`, so each wired variant has its own Plan type
/// (e.g. `PadPlan`, `WriteSlicePlan`, `ContiguizePlan`, `TriuPlan` /
/// `TrilPlan`). Variants marked "Reserved" have no plan yet. The enum
/// exists so all of them populate `KernelSku::op` from a shared
/// discriminant space.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum ShapeLayoutKind {
    /// `F.pad(x, pad, mode=…, value=v)` — Phase 3 trailblazer. The
    /// padding strategy is selected by [`PadMode`]; the trailblazer
    /// wires [`PadMode::Constant`].
    Pad = 0,
    /// `torch.cat(tensors, dim)` — variable-arity input. Reserved.
    Concat = 1,
    /// Materialized `torch.permute(x, dims)` (strided-view materialization
    /// when needed). Reserved.
    Permute = 2,
    /// `x.repeat(...)` / `torch.tile(x, ...)`. Reserved.
    Repeat = 3,
    /// `torch.flip(x, dims)` — reverse along axes. Reserved.
    Flip = 4,
    /// `torch.roll(x, shifts, dims)` — shift along axes. Reserved.
    Roll = 5,
    /// `torch.meshgrid(*tensors)` — N rank-1 → N rank-N. Reserved.
    Meshgrid = 6,
    /// `torch.full(shape, value)` / `Tensor.fill_(value)` — fill every
    /// element of an output tensor with a scalar constant. Wired from
    /// `fuel-cuda-kernels/fill.cu`.
    Fill = 7,
    /// `dest[start_0..end_0, ..., start_{N-1}..end_{N-1}] = source`
    /// (assign, not accumulate). Per-axis range write. Phase 13.1
    /// trailblazer — driven by Fuel team's persistent KV-cache append
    /// (autoregressive decoding). See
    /// `baracuda_kernels::WriteSlicePlan`.
    WriteSlice = 8,
    /// Strided→contiguous materialization (`torch.Tensor.contiguous`).
    /// Phase 13.2: closes the D2H→CPU contiguize→H2D fallback cliff
    /// for non-contiguous CUDA inputs. Byte-level and dtype-agnostic
    /// (sizeof-templated kernel), covering every byte-aligned dtype.
    /// Nibble dtypes (S4 / U4) shipped behind a documented
    /// innermost-stride constraint. See `baracuda_kernels::ContiguizePlan`.
    Contiguize = 9,
    /// `torch.triu(input, diagonal)` — keep upper triangular part of
    /// the last two dims of `input`; zero everything below the
    /// `diagonal`-th diagonal. Batch dims (anything before the last
    /// two) are independently masked. Phase 13.4 trailblazer — driven
    /// by Fuel team's CPU-only triu/tril gap. See
    /// `baracuda_kernels::TriuPlan`.
    Triu = 10,
    /// `torch.tril(input, diagonal)` — keep lower triangular part of
    /// the last two dims of `input`; zero everything above the
    /// `diagonal`-th diagonal. Sibling of [`Self::Triu`] with the
    /// predicate flipped. See `baracuda_kernels::TrilPlan`.
    Tril = 11,
}
448
/// Discriminant for index-returning reductions (Phase 4, `ArgReducePlan`).
///
/// Kept separate from [`ReduceKind`] because the result is an `i64`
/// index rather than a value in the input dtype. That heterogeneous
/// output is why these ops dispatch through their own plan shape,
/// `ArgReducePlan<T, N>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum ArgReduceKind {
    /// Position of the largest element along the reduced axis. On ties
    /// the first occurrence (smallest index) wins, as in PyTorch.
    Argmax = 0,
    /// Position of the smallest element along the reduced axis.
    /// NOTE(review): tie-breaking is presumably first-occurrence like
    /// `Argmax` — confirm against the kernel.
    Argmin = 1,
}
465
/// Reduction op discriminant — Phase 4 (Category E).
///
/// Output shape differs from input: the reduced axis collapses to size
/// 1 (keepdim convention). [`Self::Sum`] is the Phase 4 trailblazer, and
/// the other variants were reserved for fanout sessions.
///
/// Some variants are SKU tags only. Their values live here so
/// `KernelSku::op` stays consistent across the reduction family, but
/// dispatch goes through a dedicated plan:
/// - [`Self::Trace`] → `TracePlan`
/// - [`Self::CountNonzero`] → `CountReducePlan`
///
/// [`Self::Argmax`] / [`Self::Argmin`] are superseded by
/// [`ArgReduceKind`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum ReduceKind {
    /// Sum along the reduced axis. Phase 4 trailblazer.
    Sum = 0,
    /// Arithmetic mean along the reduced axis.
    Mean = 1,
    /// Maximum value along the reduced axis.
    Max = 2,
    /// Minimum value along the reduced axis.
    Min = 3,
    /// Product along the reduced axis.
    Prod = 4,
    /// Sample variance (Bessel-corrected) along the reduced axis.
    Var = 5,
    /// Sample standard deviation along the reduced axis.
    Std = 6,
    /// `||x||_2` along the reduced axis.
    Norm2 = 7,
    /// `argmax` along the reduced axis. Vestigial: index-returning
    /// reductions have an `i64` output and dispatch through
    /// `ArgReducePlan`, tagged with [`ArgReduceKind::Argmax`]. This slot
    /// stays reserved so its value is not reused.
    Argmax = 8,
    /// `argmin` along the reduced axis. Vestigial: dispatches through
    /// `ArgReducePlan`, tagged with [`ArgReduceKind::Argmin`]. This slot
    /// stays reserved so its value is not reused.
    Argmin = 9,
    /// `any` (logical OR) along the reduced axis.
    Any = 10,
    /// `all` (logical AND) along the reduced axis.
    All = 11,
    /// `logsumexp(x) = log(sum(exp(x - max))) + max`, numerically
    /// stable (the running max is subtracted before exponentiating and
    /// added back after the log).
    LogSumExp = 12,
    /// `trace(M) = sum(diag(M))` — sum of the diagonal of a 2-D
    /// square matrix. Reduces *both* axes via the `i == i` constraint
    /// rather than a single reduce-axis, so dispatch goes through a
    /// dedicated `TracePlan` (separate from `ReducePlan`). The
    /// discriminant lives here for telemetry / SKU-tagging consistency
    /// with the rest of the reduction family.
    Trace = 13,
    /// `count_nonzero(x)` along the reduced axis — output is i64
    /// (PyTorch `torch.count_nonzero` returns int64). The output dtype
    /// is always i64 regardless of input, so dispatch goes through a
    /// dedicated `CountReducePlan` (separate from `ReducePlan`). The
    /// discriminant lives here for telemetry / SKU-tagging consistency
    /// with the rest of the reduction family.
    CountNonzero = 14,
}
517
/// Reduction kind for `ReduceToPlan`, the reverse of a broadcast.
///
/// `ReduceToPlan` is the autograd primitive that undoes a forward
/// `BroadcastTo`. Every output cell reduces all input cells that were
/// broadcast into it. It differs from [`ReduceKind`] because one launch
/// collapses an arbitrary *set* of axes instead of a single
/// `reduce_axis`. That set is every `d` with `output_shape[d] == 1` but
/// `input_shape[d] != 1`. Hence the separate plan shape
/// (`ReduceToPlan<T, N>` in `baracuda-kernels`).
///
/// Values deliberately match [`ReduceKind`]'s value for the same logical
/// op, so `KernelSku::op` tags line up across the reduction family. That
/// is why 1 is skipped: there is no broadcast-reverse Mean.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum ReduceToOp {
    /// Sum over the broadcast set. An empty reduce set (some reduced
    /// `input_shape[d] == 0`) yields the identity `0`.
    Sum = 0,
    /// Maximum over the broadcast set. On an empty reduce set the result
    /// is the kernel's `AccMax` identity. For f32 / f64 outputs that is
    /// `-FLT_MAX` / `-DBL_MAX`, the most-negative *finite* value. For
    /// f16 / bf16 the f32 identity overflows on the final narrowing
    /// store and ends up as `-inf`.
    Max = 2,
    /// Minimum over the broadcast set. On an empty reduce set the result
    /// is the kernel's `AccMin` identity. For f32 / f64 outputs that is
    /// `+FLT_MAX` / `+DBL_MAX`, the most-positive *finite* value. For
    /// f16 / bf16 it is `+inf`, by the same narrowing overflow as
    /// [`Self::Max`].
    Min = 3,
    /// Product over the broadcast set. An empty reduce set yields the
    /// identity `1`.
    Prod = 4,
}
553
/// Softmax-family op discriminant — category H from the comprehensive
/// plan.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Softmax`. All variants apply a
/// length-preserving transform along a single axis: output shape ==
/// input shape. Scans behave the same way; reductions do not, since they
/// collapse the axis.
///
/// Today wired: `{Softmax, LogSoftmax} × {f32, f16, bf16, f64}` —
/// FW + BW. Two variants are reserved but deferred:
/// - `GumbelSoftmax` needs RNG state from Phase 4 random.
/// - `Sparsemax` has a different gradient (projection onto the simplex).
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum SoftmaxKind {
    /// `y[k] = exp(x[k] - max(x)) / Σ_j exp(x[j] - max(x))`
    /// — numerically stable softmax.
    Softmax = 0,
    /// `y[k] = x[k] - logsumexp(x)` — log-domain softmax, also stable.
    /// Output is the elementwise log of `Softmax(x)`.
    LogSoftmax = 1,
    /// `y = softmax((x + g) / τ)` with `g ~ Gumbel(0, 1)` sampled per
    /// element — reserved.
    GumbelSoftmax = 2,
    /// `y = ProjSimplex(x)` — reserved (different gradient than softmax).
    Sparsemax = 3,
}
581
/// Discriminant for scans (associative prefix ops) — category F of the
/// comprehensive plan.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Scan`. A scan keeps the scan axis at full
/// length, so output shape equals input shape. A reduction, by contrast,
/// shrinks that axis to size 1. Scans are inclusive by default,
/// following PyTorch (`y[i] = op(x[0], x[1], …, x[i])`). The descriptor's
/// `reverse` flag picks the direction.
///
/// Wiring status:
/// - `Cumsum` over `{f32, f16, bf16, f64}` (forward and backward) is the
///   scan trailblazer.
/// - Cumprod / Cummax / Cummin come in fanout sessions.
/// - LogCumsumExp and a JAX-style generic `associative_scan` are
///   reserved but deferred (numerics / generic-functor work).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum ScanKind {
    /// Inclusive prefix sum, `y[i] = Σ_{j ≤ i} x[j]`.
    Cumsum = 0,
    /// Inclusive prefix product, `y[i] = ∏_{j ≤ i} x[j]`.
    Cumprod = 1,
    /// Running maximum, `y[i] = max(x[0..=i])`.
    Cummax = 2,
    /// Running minimum, `y[i] = min(x[0..=i])`.
    Cummin = 3,
    /// `y[i] = log(Σ_{j ≤ i} exp(x[j]))`, kept stable by subtracting the
    /// running max. Reserved.
    LogCumsumExp = 4,
}
612
/// Binary comparison op discriminant.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::BinaryElementwise` and the SKU is from the
/// **comparison family**. It is distinguished from [`BinaryKind`]
/// because the output dtype is fixed to `u8` regardless of the input
/// element type. This follows the PyTorch / NumPy convention of storing
/// bool as 1 byte (0 = false, 1 = true).
///
/// Today only [`Self::Eq`] on `f32` is wired — the Phase 3 comparison
/// trailblazer. The other variants are reserved discriminants for the
/// fanout sessions.
///
/// Why a separate enum rather than reusing [`BinaryKind`]: the dispatch
/// shape differs. These ops produce a different dtype than they consume,
/// so they need their own Plan type (`BinaryCmpPlan<T, N>` with
/// `TensorMut<u8>` output) instead of `BinaryPlan<T, N>` with
/// `TensorMut<T>` output. The reserved Eq / Ne / Gt / Ge / Lt / Le slots
/// in `BinaryKind` are vestigial — they will never be wired into the
/// same-dtype binary path.
// NOTE(review): this enum shares `OpCategory::BinaryElementwise` with
// `BinaryKind`, and the discriminant ranges overlap (e.g. `Eq = 0` here vs
// `BinaryKind::Add = 0`). Confirm that `KernelSku` carries something else
// (e.g. the `u8` output dtype) that disambiguates the two families;
// otherwise telemetry / autotuner-cache keys for `==` and `+` can collide.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum BinaryCmpKind {
    /// `y = (a == b)` — elementwise equality. Trailblazer SKU.
    Eq = 0,
    /// `y = (a != b)`.
    Ne = 1,
    /// `y = (a > b)`.
    Gt = 2,
    /// `y = (a >= b)`.
    Ge = 3,
    /// `y = (a < b)`.
    Lt = 4,
    /// `y = (a <= b)`.
    Le = 5,
}
650
/// Discriminant for normalization ops — category G of the comprehensive
/// plan.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Normalization`. Variants differ in two
/// respects: which axes feed the per-row statistics, and how the affine
/// parameters (gamma / beta) are indexed.
///
/// Wired today, forward and backward, for `{f32, f16, bf16, f64}`:
/// RMSNorm, LayerNorm, BatchNorm, GroupNorm and InstanceNorm.
/// - RMSNorm and LayerNorm support **multi-axis normalization** through a
///   bitmask, mirroring PyTorch's `normalized_shape`, which must be a
///   suffix of the input shape.
/// - InstanceNorm is a thin wrapper over GroupNorm with
///   `num_groups == c_extent`, reusing its kernel symbols.
///
/// For the trailblazer, BatchNorm is **training-mode-only**. It computes
/// per-channel statistics from the batch and saves them for the backward
/// pass. Inference mode (running statistics, which reduces to a
/// per-channel affine multiply) is left for a follow-up. `WeightNorm` (a
/// parameterization, not a plain op) and `LocalResponseNorm` (rarely
/// used today) are explicitly deferred.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum NormalizationKind {
    /// Root-mean-square norm,
    /// `y = x / sqrt(mean(x², over norm_axes) + eps) * gamma`. This is
    /// the block pre-norm of Llama / Mistral / Gemma. Trailblazer SKU.
    RMSNorm = 0,
    /// Layer norm, `y = (x - mean) / sqrt(var + eps) * gamma + beta`,
    /// matching PyTorch's `torch.nn.LayerNorm` (biased / "population"
    /// variance).
    LayerNorm = 1,
    /// Statistics per group of channels:
    /// `y[n, c, ...] = (x[n, c, ...] - mean[n, g]) / sqrt(var[n, g] + eps) * gamma[c] + beta[c]`
    /// where `g = c / (C / num_groups)`. PyTorch `torch.nn.GroupNorm`.
    GroupNorm = 2,
    /// Per-channel statistics over batch and spatial dims. Training mode
    /// only; saves `(saved_mean, saved_rstd)` with shape `[C]`. Inference
    /// mode (running stats) is deferred. PyTorch `torch.nn.BatchNormNd`.
    BatchNorm = 3,
    /// Statistics per `(sample, channel)` over spatial dims only
    /// (PyTorch `torch.nn.InstanceNormNd`). Equivalent to GroupNorm with
    /// `num_groups == num_channels` and shares its kernel symbols.
    InstanceNorm = 4,
}
694
695/// Loss op discriminant — category R from the comprehensive plan.
696///
697/// Stored as `u16` in [`crate::KernelSku::op`] when
698/// `category == OpCategory::Loss`. Each variant has its own Plan type
699/// today (different argument shapes — MSE / BCE / KLDiv take two
700/// same-dtype tensor inputs, NLL / CrossEntropy take a `T` input plus an
701/// `i64` target index tensor) but they share the [`LossReduction`]
702/// enum for selecting per-cell / mean / sum output shape.
703///
704/// Today wired: `{Mse, Nll, CrossEntropy, Bce, KlDiv} × {f32, f16, bf16,
705/// f64}` — FW + BW. `HingeEmbedding`, `L1`, `SmoothL1`, `MarginRanking`,
706/// `TripletMargin`, `CtcLoss`, and `PoissonNllLoss` are reserved
707/// discriminants for future fanout.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LossKind {
    /// Mean squared error, `y = mean((pred - target)²)` (or the sum /
    /// per-cell form). Mirrors PyTorch's `torch.nn.functional.mse_loss`.
    Mse = 0,
    /// Negative log-likelihood, `y = -mean(input[target_idx[i]])` taken
    /// along the feature axis. Mirrors `torch.nn.functional.nll_loss`.
    /// Mixed dtypes: `input` is `T`, `target` is `i64`.
    Nll = 1,
    /// Cross-entropy, fused as `y = NLLLoss(LogSoftmax(input), target)` so
    /// the log-softmax stays numerically stable. Mirrors
    /// `torch.nn.functional.cross_entropy`. Currently wired only for
    /// class-index (`i64`) targets; soft-target CE is reserved.
    CrossEntropy = 2,
    /// Binary cross-entropy,
    /// `y = -mean(target·log(pred) + (1-target)·log(1-pred))`. Mirrors
    /// `torch.nn.functional.binary_cross_entropy`. The caller must keep
    /// pred ∈ (0, 1).
    Bce = 3,
    /// KL divergence, `y = mean(target·(log(target) - input))`, using the
    /// convention that `input` already holds log-probabilities. Mirrors
    /// `torch.nn.functional.kl_div`.
    KlDiv = 4,
    /// Mean absolute error, `y = mean(|pred - target|)` (or the sum /
    /// per-cell form). Mirrors `torch.nn.functional.l1_loss`.
    L1 = 5,
    /// Smooth-L1 ("Huber with β") loss. Mirrors
    /// `torch.nn.functional.smooth_l1_loss`.
    SmoothL1 = 6,
    /// Hinge-embedding loss,
    /// `y = mean(input if t==1 else max(0, margin - input))`. Mirrors
    /// `torch.nn.functional.hinge_embedding_loss`. Mixed dtypes: `input`
    /// is `T`, `target` is `i64` holding ±1.
    HingeEmbedding = 7,
    /// Margin-ranking loss, `y = mean(max(0, -t · (x1 - x2) + margin))`.
    /// Mirrors `torch.nn.functional.margin_ranking_loss`. The target `t`
    /// has dtype `T` and holds ±1.
    MarginRanking = 8,
    /// Triplet-margin loss,
    /// `y = mean(max(0, ||a-p||_p - ||a-n||_p + margin))`. Mirrors
    /// `torch.nn.functional.triplet_margin_loss`; inputs are 2-D `[N, D]`.
    TripletMargin = 9,
    /// Reserved slot for `torch.nn.functional.ctc_loss`.
    Ctc = 10,
    /// Poisson NLL, `y = mean(exp(input) - target · input)` (the
    /// `log_input=true` default). Mirrors
    /// `torch.nn.functional.poisson_nll_loss`.
    PoissonNll = 11,
    /// Huber loss, kept as its own slot separate from SmoothL1. Mirrors
    /// `torch.nn.functional.huber_loss`.
    Huber = 12,
    /// BCE computed directly from raw logits in a numerically stable way.
    /// Mirrors `torch.nn.functional.binary_cross_entropy_with_logits`.
    BceWithLogits = 13,
    /// Gaussian negative log-likelihood. Mirrors `torch.nn.GaussianNLLLoss`.
    GaussianNll = 14,
    /// Cosine-embedding loss: `1 - cos(x1, x2)` when `t==1`, otherwise
    /// `max(0, cos(x1, x2) - margin)`, then averaged. Mirrors
    /// `torch.nn.functional.cosine_embedding_loss`. Inputs are 2-D
    /// `[N, D]`; the target has dtype `T` and holds ±1.0.
    CosineEmbedding = 15,
    /// Multi-class margin loss,
    /// `y = mean_i Σ_{j != t_i} max(0, margin - input[i, t_i] + input[i, j])^p / C`.
    /// Mirrors `torch.nn.functional.multi_margin_loss`. Input is `[N, C]`;
    /// target is `[N]` `i64` class indices.
    MultiMargin = 16,
    /// Multi-label margin loss. Mirrors
    /// `torch.nn.functional.multilabel_margin_loss`. Input is `[N, C]`;
    /// target is `[N, C]` `i64` holding the positive class indices
    /// followed by a `-1` padding sentinel.
    MultilabelMargin = 17,
    /// Multi-label soft-margin loss,
    /// `y = mean(-mean_c(target·log(sigmoid(x)) + (1-target)·log(1-sigmoid(x))))`.
    /// Mirrors `torch.nn.functional.multilabel_soft_margin_loss`. Input
    /// and target are both `[N, C]` of dtype `T`.
    MultilabelSoftMargin = 18,
    /// Fused Linear Cross-Entropy: `loss = CE(input @ weight^T, target)`
    /// computed **without** ever materializing the `[BT, V]` logits. The
    /// projection GEMM and the cross-entropy reduction share one chunked
    /// outer loop. `grad_input` and `grad_weight` are produced during the
    /// forward pass, so the backward call only scales them by the upstream
    /// `dy` scalar. Algorithm follows LinkedIn Liger-Kernel
    /// (`liger_kernel/ops/fused_linear_cross_entropy.py`). Streaming the
    /// logits in `chunk_size`-row tiles saves ~5-10 GiB at
    /// `vocab=128K, BT=16K` (Llama-3-class).
    FusedLinearCrossEntropy = 19,
}
789
/// Format of the target tensor fed to CrossEntropy. PyTorch accepts two
/// target formats: class indices (`i64[N]`) or soft probabilities
/// (`T[N, C]`, as used for label smoothing / distillation).
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CrossEntropyTargetKind {
    /// Class-index target of shape `i64[N]`.
    ClassIndex = 0,
    /// Soft-probability target of shape `T[N, C]`, sharing the input dtype.
    SoftProb = 1,
}
801
/// PyTorch's `reduction` parameter for a [`LossKind`] plan. It picks both
/// the output shape and the final scalar scaling.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LossReduction {
    /// No reduction: the output keeps the per-cell shape of the loss
    /// surface.
    None = 0,
    /// Scalar output: the per-cell terms are summed, then divided by the
    /// element count.
    Mean = 1,
    /// Scalar output: the plain sum of the per-cell terms, with no divide.
    Sum = 2,
}
814
/// Random / sampling op discriminant.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Random`. Phase 4.5 wires:
/// - [`Self::Uniform`] (f32, f64) — `y ~ U(low, high)` via cuRAND.
/// - [`Self::Normal`] (f32, f64) — `y ~ N(mean, std)` via cuRAND.
/// - [`Self::Bernoulli`] (Bool output) — `y = (rand < p) ? 1 : 0` via
///   cuRAND uniform + custom threshold kernel.
///
/// Phase 46 adds [`Self::Multinomial`], which backs the FlashInfer
/// sort-free Top-K / Top-P / Min-P samplers (see the variant docs).
///
/// Randint / exponential / gamma / quasi-random sampling do not have a
/// discriminant yet. They will be appended as new variants in future
/// milestones. The enum is `#[non_exhaustive]`, so adding them is not a
/// breaking change for downstream matches.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum RandomKind {
    /// `y[i] ~ U(low, high)` — uniform on the half-open interval. Plan
    /// descriptor `param1 = low`, `param2 = high`.
    Uniform = 0,
    /// `y[i] ~ N(mean, std)` — Gaussian. Plan descriptor
    /// `param1 = mean`, `param2 = stddev`.
    Normal = 1,
    /// `y[i] = 1 if uniform < p else 0`, Bool output. Plan descriptor
    /// `param1 = p`. `param2` ignored.
    Bernoulli = 2,
    /// `y[b] = sample one cell from row probs[b, :]` using inverse-CDF
    /// sampling. Phase 46 wires the FlashInfer sort-free Top-K /
    /// Top-P / Min-P / combined Top-K + Top-P samplers under this
    /// discriminant via the `TopKTopPSamplingPlan` in baracuda-kernels.
    Multinomial = 3,
}
845
/// Linear-algebra (dense) op discriminant. Covers the cuSOLVER family
/// first shipped in Milestone 6.3, plus the batched / bespoke kernels
/// added since.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Linalg`. Milestone 6.3 wired the four
/// canonical PyTorch / JAX dense factorizations:
///
/// - [`Self::Cholesky`] — `A = L · L^T` (symmetric positive-definite).
///   Batched via `cusolverDnSpotrfBatched` / `cusolverDnDpotrfBatched`.
/// - [`Self::Lu`] — `P · A = L · U`. Batched via `getrfBatched`, which
///   the CUDA math libraries provide in cuBLAS (`cublasSgetrfBatched` /
///   `cublasDgetrfBatched`); cuSOLVER's dense API has no batched `getrf`.
/// - [`Self::Qr`] — `A = Q · R`. cuSOLVER's `geqrf` has no batched
///   variant, so this path is 2-D only. The batched form is
///   [`Self::BatchedQr`].
/// - [`Self::Svd`] — `A = U · diag(S) · V^T`. 2-D only. The batched forms
///   are [`Self::BatchedSvd`] and [`Self::BatchedSvda`].
///
/// Later milestones (6.9 – 6.17) wired the remaining variants:
/// [`Self::Inverse`], [`Self::Solve`], [`Self::LeastSquares`],
/// [`Self::Eig`], [`Self::Eigh`], the batched QR / SVD routines, and the
/// bespoke batched-`ormqr` kernels. See each variant for details.
/// [`Self::MatrixExp`] is the only discriminant still reserved.
///
/// Dtype coverage for the factorizations is `f32` + `f64`, because
/// cuSOLVER's dense API does not support `f16` / `bf16` for them.
/// [`Self::Eig`] always emits complex eigenvalues.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum LinalgKind {
    /// Cholesky factorization `A = L · L^T` (lower) or `A = U^T · U`
    /// (upper). Input must be symmetric positive-definite.
    Cholesky = 0,
    /// LU factorization with partial pivoting `P · A = L · U`. Returns
    /// the packed `LU` factors plus an `i32` pivot vector.
    Lu = 1,
    /// QR factorization `A = Q · R`. Computes full `Q` (`[M, M]`) and
    /// the upper-triangular `R` (`[M, N]`) via `geqrf` + `ormqr`.
    Qr = 2,
    /// Singular value decomposition `A = U · diag(S) · V^T`. cuSOLVER
    /// 2-D only; `full_matrices` controls whether `U`/`V^T` are full
    /// (`[M,M]` / `[N,N]`) or thin (`[M,K]` / `[K,N]`) where
    /// `K = min(M, N)`.
    Svd = 3,
    /// Matrix inverse `A^{-1}` via `getrf` + `getrs` over an identity
    /// RHS. Wired in Milestone 6.9.
    Inverse = 4,
    /// General (non-symmetric) eigen-decomposition `A · v = λ · v`. Wired
    /// via `cusolverDnXgeev` in Milestone 6.12. Always emits complex
    /// eigenvalues (and optional left / right complex eigenvectors).
    Eig = 5,
    /// Linear solve `A · X = B` via `getrf` + `getrs`. Wired in
    /// Milestone 6.9.
    Solve = 6,
    /// Least-squares solve `min ||A·x - b||²` via cuSOLVER's
    /// mixed-precision iterative-refinement `_gels` routine. Wired in
    /// Milestone 6.11.
    LeastSquares = 7,
    /// Reserved — matrix exponential / matrix functions.
    MatrixExp = 8,
    /// Batched QR factorization `A_b = Q_b · R_b` via the batched
    /// `geqrfBatched` routine. The CUDA math libraries provide it in cuBLAS
    /// (`cublas{S,D}geqrfBatched`); cuSOLVER's dense API has no batched
    /// `geqrf`. Wired in Milestone 6.11.
    BatchedQr = 9,
    /// Batched SVD via Jacobi `cusolverDn*gesvdjBatched`. Wired in
    /// Milestone 6.11.
    BatchedSvd = 10,
    /// Symmetric / Hermitian eigen-decomposition `A · v = λ · v` (real
    /// eigenvalues). Wired via `cusolverDn{S,D}syevd` /
    /// `cusolverDn{C,Z}heevd` in Milestone 6.12.
    Eigh = 11,
    /// Rectangular batched approximate-SVD via cuSOLVER's
    /// `gesvdaStridedBatched`. Unlike [`Self::BatchedSvd`] (which is
    /// square-only Jacobi), this routine accepts arbitrary `m × n` per
    /// batch slot and uses element-strides between slots. It also reports
    /// per-slot residual Frobenius norms to a host array. Wired in
    /// Milestone 6.15.
    BatchedSvda = 12,
    /// Bespoke batched-`ormqr`. Applies the implicit `Q` from a
    /// [`Self::BatchedQr`] packed output to a batch of matrices `C`, with
    /// all slots fused into one CUDA launch.
    ///
    /// cuSOLVER's `ormqr` is non-batched. Batched QR is most useful for
    /// small matrices, where per-slot launch latency dominates; this kernel
    /// amortizes a single launch over the whole batch. The trailblazer
    /// supports Side = Left and op ∈ {N, T}; Right and complex variants are
    /// deferred. Wired in Milestone 6.14.
    BatchedOrmqr = 13,
    /// Bespoke "materialize dense Q and R from batched-`geqrf` packed
    /// output". R uses a tiny upper-triangle-copy kernel. Q uses an
    /// identity stage followed by [`Self::BatchedOrmqr`]. Wired in
    /// Milestone 6.14 as the consumer of `BatchedOrmqrPlan`.
    BatchedQrMaterialize = 14,
    /// WY-blocked batched-`ormqr`. Applies the implicit `Q` (or `Q^T`) from
    /// a [`Self::BatchedQr`] packed output to a batch of matrices `C` at
    /// GEMM rates.
    ///
    /// It fuses each group of `nb` consecutive Householder reflectors into
    /// a block reflector `(I - V·T·V^T)`. Each block is applied with three
    /// cuBLAS strided-batched GEMMs.
    ///
    /// This is the sibling of [`Self::BatchedOrmqr`], the
    /// reflector-by-reflector GEMV-rate variant. Callers pick by problem
    /// size: WY wins decisively for `M, N > ~16`, and the reflector kernel
    /// wins for tiny inputs. The trailblazer supports Side = Left and
    /// op ∈ {N, T}. Wired in Milestone 6.17.
    BatchedOrmqrWy = 15,
}
942
/// Fill-mode tag for triangular linalg ops (Cholesky / triangular solve).
///
/// Selects whether the factor lives in the lower or upper triangle of the
/// in-place output matrix.
///
/// The plan layer's row-major-input → column-major-cuSOLVER adapter flips
/// this tag when handing the descriptor down to cuSOLVER. The bytes of a
/// row-major lower-triangular `L`, read in column-major order, are exactly
/// `L^T`, which is upper-triangular. The same storage therefore serves
/// both interpretations.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u8)]
pub enum FillMode {
    /// Lower triangular. This is the PyTorch / NumPy default
    /// (`torch.linalg.cholesky(upper=False)`, `numpy.linalg.cholesky`).
    /// Note that `scipy.linalg.cholesky` instead defaults to upper
    /// (`lower=False`).
    Lower = 0,
    /// Upper triangular.
    Upper = 1,
}
958
/// Discriminant for FFT-family ops (Category U of the comprehensive plan).
///
/// Held as a `u16` in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Fft`. Milestone 6.4 covers:
/// - the four canonical PyTorch / JAX 1-D transforms (`fft` / `ifft` /
///   `rfft` / `irfft`);
/// - the two index-permutation helpers (`fftshift` / `ifftshift`).
///
/// The trailblazer is 1-D only. Multi-dimensional transforms (`fft2`,
/// `fftn`, …) and arbitrary-axis transforms arrive in fanout sessions.
/// They need no new cuFFT bindings, only extra descriptor shape and plan
/// glue.
///
/// Supported dtypes are `f32` and `f64`. cuFFT's main API offers no native
/// `f16` / `bf16` transforms, so reduced-precision callers cast on either
/// side. Spectrum-domain tensors store interleaved real/imag pairs as
/// [`crate::Complex32`] / [`crate::Complex64`].
///
/// Normalization follows PyTorch's `norm="backward"` default: forward
/// transforms stay unnormalized and inverse transforms are scaled by
/// `1/N`. cuFFT's inverse yields `N · IFFT(x)`, so the plan layer applies
/// the `1/N` factor after the inverse exec.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FftKind {
    /// Complex-to-complex forward transform `y = FFT(x)`, unnormalized.
    /// Mirrors `torch.fft.fft`. Input and output are both complex with the
    /// same `[batch, n]` shape.
    Fft = 0,
    /// Complex-to-complex inverse transform `y = IFFT(x)`, scaled by `1/N`
    /// to match PyTorch's `norm="backward"`. Mirrors `torch.fft.ifft`.
    /// Input and output are both complex `[batch, n]`.
    Ifft = 1,
    /// Real-to-complex forward transform `y = RFFT(x)`, unnormalized.
    /// Mirrors `torch.fft.rfft`. It takes real `[batch, n]` and produces
    /// complex `[batch, n/2 + 1]` (the Hermitian half).
    Rfft = 2,
    /// Complex-to-real inverse transform `y = IRFFT(x, n)`, scaled by
    /// `1/N`. Mirrors `torch.fft.irfft`. It takes complex
    /// `[batch, n/2 + 1]` and produces real `[batch, n]`.
    ///
    /// The output length `n` must be supplied on the descriptor. It cannot
    /// be recovered from the input shape, because `2*(n/2)` and
    /// `2*(n/2)+1` share the same Hermitian-half length.
    Irfft = 3,
    /// `fftshift`: moves the zero-frequency bin to the centre of the
    /// spectrum. Mirrors `torch.fft.fftshift` (identical to NumPy's
    /// `np.fft.fftshift`).
    ///
    /// Same as `roll(x, n // 2)`, which gives
    /// `y[i] = x[(i - n // 2) mod n] = x[(i + (n+1) // 2) mod n]`.
    ///
    /// Bit-exact: a pure index permutation that does no arithmetic on the
    /// values.
    FftShift = 4,
    /// `ifftshift`: the exact inverse of `fftshift`, so
    /// `ifftshift(fftshift(x)) == x` for every `n`. Mirrors
    /// `torch.fft.ifftshift`.
    ///
    /// Same as `roll(x, -(n // 2))`, which gives
    /// `y[i] = x[(i + n // 2) mod n]`.
    ///
    /// For even `n` it coincides with `fftshift`, since an offset of `n/2`
    /// is its own inverse mod `n`. For odd `n` the two cyclic offsets
    /// differ by one cell. Bit-exact.
    IfftShift = 5,
}
1024
/// Convolution-family op discriminant — Category I from the
/// comprehensive plan.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Convolution`. Each variant maps to a
/// distinct exec path (forward, data-gradient, filter-gradient) of the
/// underlying convolution descriptor. The dimensional axis (1-D / 2-D /
/// 3-D) also gets its own slot. Padding, stride, dilation and groups live
/// on the per-plan descriptor and do not fan out separate enum slots.
///
/// The trailblazer wired `Conv2d` × `{f32, f64, f16, bf16}` (FW + BW data
/// + BW filter) via cuDNN. Later phases added discriminants for
/// transposed convolutions, depthwise convolution (routed through
/// `Conv2dPlan`), and im2col / col2im ([`Self::Im2Col2d`],
/// [`Self::Im2Col1d`], [`Self::Col2Im1d`]); see each variant.
///
/// Variants documented as "Reserved" have no wired exec path yet: the
/// `Conv1d*` / `Conv3d*` family, [`Self::ConvTranspose2dBackward`],
/// [`Self::Unfold`] and [`Self::Fold`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum ConvKind {
    /// 2-D convolution forward pass. PyTorch
    /// `torch.nn.functional.conv2d`. Trailblazer for Phase 7.
    Conv2d = 0,
    /// 2-D convolution data-gradient pass (computes `dx` from `dy`
    /// and the filter `w`). PyTorch's autograd-internal
    /// `conv2d_backward_input`.
    Conv2dBackwardData = 1,
    /// 2-D convolution filter-gradient pass (computes `dw` from `x`
    /// and `dy`). PyTorch's autograd-internal
    /// `conv2d_backward_weight`.
    Conv2dBackwardFilter = 2,
    /// 1-D convolution forward. Reserved.
    Conv1d = 3,
    /// 1-D convolution data-gradient. Reserved.
    Conv1dBackwardData = 4,
    /// 1-D convolution filter-gradient. Reserved.
    Conv1dBackwardFilter = 5,
    /// 3-D convolution forward. Reserved.
    Conv3d = 6,
    /// 3-D convolution data-gradient. Reserved.
    Conv3dBackwardData = 7,
    /// 3-D convolution filter-gradient. Reserved.
    Conv3dBackwardFilter = 8,
    /// 2-D transposed convolution (fractionally-strided / "deconv").
    /// Forward pass.
    ConvTranspose2d = 9,
    /// 2-D transposed convolution backward. Reserved — backward is
    /// dispatched through the same plan via `run_bw_data` / `run_dw`.
    /// The split data / filter gradients have their own slots:
    /// [`Self::ConvTranspose2dBackwardData`] and
    /// [`Self::ConvTranspose2dBackwardFilter`].
    ConvTranspose2dBackward = 10,
    /// Depthwise 2-D convolution (`groups == c_in`). Today callers route
    /// through the generic `Conv2dPlan` with `groups` set on the
    /// descriptor; cuDNN's `cudnnSetConvolutionGroupCount` detects the
    /// depthwise path automatically.
    DepthwiseConv2d = 11,
    /// `torch.nn.functional.unfold` — extract sliding windows. Reserved;
    /// the wired 2-D im2col path uses [`Self::Im2Col2d`].
    Unfold = 12,
    /// `torch.nn.functional.fold` — inverse of unfold. Reserved.
    Fold = 13,
    /// 1-D transposed convolution forward.
    ConvTranspose1d = 14,
    /// 1-D transposed convolution data-gradient.
    ConvTranspose1dBackwardData = 15,
    /// 1-D transposed convolution filter-gradient.
    ConvTranspose1dBackwardFilter = 16,
    /// 2-D transposed convolution data-gradient.
    ConvTranspose2dBackwardData = 17,
    /// 2-D transposed convolution filter-gradient.
    ConvTranspose2dBackwardFilter = 18,
    /// 3-D transposed convolution forward.
    ConvTranspose3d = 19,
    /// 3-D transposed convolution data-gradient.
    ConvTranspose3dBackwardData = 20,
    /// 3-D transposed convolution filter-gradient.
    ConvTranspose3dBackwardFilter = 21,
    /// 2-D im2col — `torch.nn.functional.unfold` (Phase 19.3). Extracts
    /// sliding windows from an NCHW input into an
    /// `[N, C·kh·kw, h_out·w_out]` column-shaped matrix. It is kept
    /// distinct from the reserved [`Self::Unfold`] discriminant for
    /// forward source compatibility; the 19.3 wiring routes through this
    /// discriminant.
    Im2Col2d = 22,
    /// 1-D im2col (NCL → `[N, C·kl, l_out]`).
    Im2Col1d = 23,
    /// 1-D col2im — inverse of [`Self::Im2Col1d`]. Atomic-add scatter.
    Col2Im1d = 24,
}
1109
/// Pooling-family op discriminant — Category J from the comprehensive
/// plan.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Pooling`. The max / average variants map to
/// cuDNN pooling exec paths (forward / backward), each in one of three
/// modes: max, average-include-padding, or average-exclude-padding. The
/// LP-pool and fractional-max-pool variants are bespoke kernels (see
/// their docs).
///
/// PyTorch mapping:
/// - `nn.MaxPool2d` corresponds to [`Self::MaxPool2d`].
/// - `nn.AvgPool2d` defaults to `count_include_pad=True`, which maps to
///   [`Self::AvgPool2dIncludePad`].
/// - `count_include_pad=False` maps to [`Self::AvgPool2dExcludePad`].
///
/// The 1-D / 3-D average-pool variants follow the same convention.
///
/// The trailblazer wired the 2-D max / average variants ×
/// `{f32, f64, f16, bf16}` (FW + BW) via cuDNN. Phases 11.8, 16.2 and
/// 16.3 added the 1-D / 3-D, adaptive, LP-pool and fractional-max-pool
/// discriminants documented below. Variants documented as "reserved"
/// have no wired kernel yet.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum PoolKind {
    /// 2-D max-pool forward. PyTorch `torch.nn.functional.max_pool2d`.
    /// Trailblazer for Phase 7 Milestone 7.2.
    MaxPool2d = 0,
    /// 2-D max-pool backward (data-gradient). PyTorch's autograd-
    /// internal `max_pool2d_with_indices_backward`.
    MaxPool2dBackward = 1,
    /// 2-D average-pool forward, **count-include-padding** denominator
    /// (PyTorch default — `nn.AvgPool2d` with `count_include_pad=True`).
    /// Matches cuDNN's `*_COUNT_INCLUDE_PADDING` mode.
    AvgPool2dIncludePad = 2,
    /// 2-D average-pool backward, count-include-padding.
    AvgPool2dIncludePadBackward = 3,
    /// 2-D average-pool forward, **count-exclude-padding** denominator
    /// (`nn.AvgPool2d` with `count_include_pad=False`). Matches cuDNN's
    /// `*_COUNT_EXCLUDE_PADDING` mode.
    AvgPool2dExcludePad = 4,
    /// 2-D average-pool backward, count-exclude-padding.
    AvgPool2dExcludePadBackward = 5,
    /// 1-D max-pool forward (Phase 11.8). NCL layout via cuDNN's Nd
    /// pool descriptor with `W = 1`.
    MaxPool1d = 6,
    /// 1-D average-pool forward. Reserved.
    AvgPool1d = 7,
    /// 3-D max-pool forward (Phase 11.8). NCDHW layout via cuDNN's
    /// Nd pool descriptor.
    MaxPool3d = 8,
    /// 3-D average-pool forward. Reserved.
    AvgPool3d = 9,
    /// `torch.nn.functional.adaptive_max_pool*` — reserved.
    AdaptiveMaxPool = 10,
    /// `torch.nn.functional.adaptive_avg_pool*` — reserved.
    AdaptiveAvgPool = 11,
    /// `torch.nn.functional.lp_pool*` — reserved.
    LpPool = 12,
    /// `torch.nn.functional.fractional_max_pool*` — reserved.
    FractionalMaxPool = 13,
    /// 1-D max-pool backward.
    MaxPool1dBackward = 14,
    /// 1-D average-pool backward (count-include-padding).
    AvgPool1dIncludePadBackward = 15,
    /// 1-D average-pool forward (count-include-padding — PyTorch default,
    /// `count_include_pad=True`).
    AvgPool1dIncludePad = 16,
    /// 1-D average-pool forward (count-exclude-padding —
    /// `count_include_pad=False`).
    AvgPool1dExcludePad = 17,
    /// 1-D average-pool backward (count-exclude-padding).
    AvgPool1dExcludePadBackward = 18,
    /// 3-D max-pool backward.
    MaxPool3dBackward = 19,
    /// 3-D average-pool forward (count-include-padding).
    AvgPool3dIncludePad = 20,
    /// 3-D average-pool backward (count-include-padding).
    AvgPool3dIncludePadBackward = 21,
    /// 3-D average-pool forward (count-exclude-padding).
    AvgPool3dExcludePad = 22,
    /// 3-D average-pool backward (count-exclude-padding).
    AvgPool3dExcludePadBackward = 23,
    /// Adaptive average-pool 1-D (Phase 11.8 — cuDNN approximation).
    AdaptiveAvgPool1d = 24,
    /// Adaptive average-pool 1-D backward.
    AdaptiveAvgPool1dBackward = 25,
    /// Adaptive average-pool 2-D.
    AdaptiveAvgPool2d = 26,
    /// Adaptive average-pool 2-D backward.
    AdaptiveAvgPool2dBackward = 27,
    /// Adaptive average-pool 3-D.
    AdaptiveAvgPool3d = 28,
    /// Adaptive average-pool 3-D backward.
    AdaptiveAvgPool3dBackward = 29,
    /// Adaptive max-pool 1-D.
    AdaptiveMaxPool1d = 30,
    /// Adaptive max-pool 1-D backward.
    AdaptiveMaxPool1dBackward = 31,
    /// Adaptive max-pool 2-D.
    AdaptiveMaxPool2d = 32,
    /// Adaptive max-pool 2-D backward.
    AdaptiveMaxPool2dBackward = 33,
    /// Adaptive max-pool 3-D.
    AdaptiveMaxPool3d = 34,
    /// Adaptive max-pool 3-D backward.
    AdaptiveMaxPool3dBackward = 35,
    /// LP-pool 1-D (Phase 16.2 — bespoke fused kernel:
    /// `y = (Σ |x|^p)^(1/p)` over each pool window in one launch).
    LpPool1d = 36,
    /// LP-pool 2-D (Phase 16.2 — bespoke fused kernel).
    LpPool2d = 37,
    /// Fractional max-pool 2-D (Phase 16.3 — bespoke kernel; cuDNN has
    /// no fractional-pool primitive).
    FractionalMaxPool2d = 38,
    /// Fractional max-pool 3-D (Phase 16.3 — bespoke kernel).
    FractionalMaxPool3d = 39,
    /// LP-pool 1-D backward (Phase 16.2 — atomicAdd scatter from
    /// each output cell over its source window).
    LpPool1dBackward = 40,
    /// LP-pool 2-D backward (Phase 16.2 — atomicAdd scatter).
    LpPool2dBackward = 41,
    /// Fractional max-pool 2-D backward (Phase 16.3). Each output cell is
    /// scattered via atomicAdd into `dx[indices[cell]]`, using the saved
    /// argmax. The half / bf16 atomicAdd is routed through atomicCAS.
    FractionalMaxPool2dBackward = 42,
    /// Fractional max-pool 3-D backward (Phase 16.3 — atomicAdd scatter).
    FractionalMaxPool3dBackward = 43,
}
1228
/// Attention-family op discriminant — Category K from the comprehensive
/// plan.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Attention`. Phase 6.1 wired the two
/// positional-encoding ops [`Self::Rope`] and [`Self::Alibi`]. Later
/// milestones added:
/// - KV-cache append ([`Self::KvCache`], 6.5);
/// - [`Self::FlashAttention`] (6.6);
/// - [`Self::HyperConnection`] (Phase 43);
/// - the Mamba scans ([`Self::SsdChunkScan`] / [`Self::SelectiveScan`],
///   Phases 50 / 50b);
/// - [`Self::BlockSparseAttention`] (Phase 54);
/// - [`Self::RingAttention`] (Phase 56).
///
/// Several of these are gated behind cargo features (see each variant).
/// [`Self::Sdpa`] and [`Self::PagedAttention`] are documented as reserved.
///
/// Most variants operate on rank-4 attention-shaped tensors. Typical
/// shapes are `[batch, num_heads, seq_len, head_dim]` for RoPE and
/// `[batch, num_heads, query_len, key_len]` for attention scores / ALiBi.
/// [`Self::SelectiveScan`] is the exception, operating on rank-3
/// `[B, L, D]` input. Plan shapes differ between ops — the discriminant is
/// here for SKU-tagging uniformity, not for shared dispatch.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum AttentionKind {
    /// Rotary position embedding (Llama / Mistral / Gemma / Qwen / Phi).
    /// Rotates pairs of consecutive features `(2i, 2i+1)` of a
    /// `[B, H, S, D]` Q/K tensor by per-position angles
    /// `θ = pos · base^(-2i / D)`. Trailblazer for Phase 6.
    Rope = 0,
    /// Attention with Linear Biases (MPT / BLOOM). Adds the bias
    /// `slope[h] · (j - i)` to attention-score cell `(b, h, i, j)`.
    /// Linear (non-transcendental) FW; BW reduces over the
    /// score-shape axes to recover `dslope[h]`.
    Alibi = 1,
    /// Scaled dot-product attention — reserved.
    Sdpa = 2,
    /// FlashAttention (Tri Dao 2022) — wired in Milestone 6.6. Tiled,
    /// fused online-softmax FW kernel. It avoids materializing the
    /// `[B, H, Q, K]` attention matrix and instead saves a small
    /// `lse: [B, H, Q]` log-sum-exp tensor for the BW pass.
    ///
    /// Trailblazer constraints: `Br = Bc = 64`, `d_k = d_v ≤ 128`, optional
    /// causal mask, no explicit additive mask (use `SdpaPlan` for masked
    /// attention).
    FlashAttention = 3,
    /// KV-cache append — decoder-inference helper that writes
    /// newly-generated `K` / `V` slices into running cache buffers at
    /// per-sample offsets. Wired in Milestone 6.5 (FW only, no BW —
    /// inference-time op).
    KvCache = 4,
    /// Paged attention (vLLM-style) — reserved.
    PagedAttention = 5,
    /// Manifold-Constrained Hyper-Connections (DeepSeek-AI 2025, mHC).
    /// Drop-in replacement for the bare `y = x + sublayer(x)` residual
    /// connection in transformer blocks. It mixes `n` parallel residual
    /// streams through a small Sinkhorn-Knopp-normalized matrix `M` that
    /// lives on the manifold of doubly-stochastic matrices.
    ///
    /// Wired in Phase 43: bf16 weights / f32 activations, static-H FW only
    /// (dynamic-H and BW deferred). Backed by the vendored `mHC.cu`
    /// (Andre Slavescu, MIT) under
    /// `crates/baracuda-kernels-sys/vendor/mhc/`.
    HyperConnection = 6,
    /// Mamba-2 State-Space Duality (SSD) chunk-scan (Phase 50). Bespoke
    /// kernel powering the Mamba-2 family (Mamba-2 8B, Codestral-Mamba,
    /// Falcon-Mamba, Zamba2).
    ///
    /// Inputs: rank-4 `[B, L, H, D]` input, rank-4 `[B, L, H, N]` `B` / `C`
    /// modulation tensors, and a per-head scalar SSM eigenvalue `A: [H]`.
    /// Output: rank-4 `[B, L, H, D]`.
    ///
    /// State residency is `H * D * N` floats in SMEM (trailblazer caps
    /// `D, N ≤ 256` for FW, `≤ 64` for BW). Behind the `mamba` cargo
    /// feature.
    SsdChunkScan = 7,
    /// Mamba-1 selective_scan (Phase 50b). Bespoke kernel powering the
    /// Mamba-1 family (Mamba-7B, Falcon-Mamba, Codestral-Mamba).
    ///
    /// Inputs: rank-3 `[B, L, D]` input, rank-3 `[B, L, N]` `B` / `C`
    /// modulation tensors, and a per-channel `[D, N]` state matrix `A`.
    /// Optional extras: a `D[d]` skip-connection, a SiLU-gated tail `z`, a
    /// delta-bias, and softplus-`delta` mapping.
    ///
    /// State residency is `N` floats in SMEM per `(b, d)` block (trailblazer
    /// caps `N ≤ 256`). Behind the `mamba` cargo feature.
    SelectiveScan = 8,
    /// Block-sparse SDPA (Phase 54, xFormers algorithmic-reference
    /// hand-port). The attention mask is a per-block boolean pattern
    /// `[B, H, num_blocks_q * num_blocks_k]`. Only the active
    /// (q_block, k_block) pairs participate in the QK^T matmul and the
    /// online-softmax accumulation.
    ///
    /// This differs from [`Self::FlashAttention`] (dense) and from the
    /// Phase 51 arbitrary-additive-mask path (which still computes every
    /// cell). FW only in Tier 1; backed by a bespoke `mma`-free tile kernel
    /// behind the `xformers_blocksparse` cargo feature.
    BlockSparseAttention = 9,
    /// Ring Attention (Phase 56). Sequence-parallel attention: the Q tensor
    /// is sliced across `world_size` ranks, and the K/V chunks rotate
    /// around a NCCL ring.
    ///
    /// Each rank folds the resident K/V chunk into a persistent
    /// (o_acc, m_acc, l_acc) accumulator via online-softmax reconstruction
    /// (Flash Attention math). After `world_size` rotations every rank has
    /// computed `softmax(Q[my_slice] @ K^T) @ V` for the full global
    /// sequence, but with O(N/P) memory, where N = total seq len and
    /// P = ring size.
    ///
    /// Algorithm: Liu, Yan, Abbeel 2023 (arXiv:2310.01889; reference at
    /// https://github.com/lhao499/RingAttention, Apache-2.0). Tier 1 ships
    /// FW only, f16/bf16, head_dim=128. Behind the `ring_attention` cargo
    /// feature; pulls in `baracuda-nccl`.
    RingAttention = 10,
}
1326
/// Discriminant for indexing / scatter / gather ops (Category L of the
/// comprehensive plan).
///
/// Held as a `u16` in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Indexing`. Phase 7 Milestone 7.3 covers:
/// - [`Self::Gather`] (FW + BW) — `out[i] = src[index[i]]` along a dim.
/// - [`Self::ScatterAdd`] — `out[index[i]] += updates[i]` along a dim,
///   using atomicAdd so duplicate targets are safe.
/// - [`Self::IndexSelect`] (FW + BW) —
///   `out[..., j, ...] = src[..., idx[j], ...]`, driven by a 1-D i32 `idx`.
/// - [`Self::MaskedFill`] (FW + BW) — `out[i] = mask[i] ? value : src[i]`.
/// - [`Self::OneHot`] (FW only; non-differentiable) —
///   `out[..., c] = 1 if c == src[...] else 0`.
/// - [`Self::Nonzero`] (FW only) — coordinates where the input is non-zero,
///   returned as an `[k, rank]` i32 table together with a count.
///
/// Phase 39 later adds [`Self::Scatter`] and [`Self::IndexAdd`].
///
/// The trailblazer accepts `i32` indices only; `i64` is deferred. The
/// kernel skips indices that are out of bounds or negative, so they act as
/// no-ops. PyTorch-style negative wrap-around is deferred.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IndexingKind {
    /// `gather(src, dim, index)`:
    /// `out[..., j, ...] = src[..., index[..., j, ...], ...]` along the
    /// chosen gather dimension. Mirrors `torch.gather`.
    Gather = 0,
    /// Backward of [`Self::Gather`]. It scatters `dout` into `dsrc` along
    /// the gather dim with atomicAdd (dup-safe). Its signature differs from
    /// [`Self::ScatterAdd`] because the destination is `dsrc` and the index
    /// pattern is exactly the FW gather's coordinates.
    GatherBackward = 1,
    /// `scatter_add(out, dim, index, updates)`:
    /// `out[..., index[..., j, ...], ...] += updates[..., j, ...]` via
    /// atomicAdd. Mirrors `torch.scatter_add_`.
    ScatterAdd = 2,
    /// `index_select(src, dim, idx)`:
    /// `out[..., j, ...] = src[..., idx[j], ...]` with a 1-D i32 `idx`.
    /// Faster and simpler than `gather` when the index tensor is 1-D.
    /// Mirrors `torch.index_select`.
    IndexSelect = 3,
    /// Backward of [`Self::IndexSelect`]: scatter-adds `dout` into `dsrc`
    /// along `select_dim` using `idx` (atomicAdd).
    IndexSelectBackward = 4,
    /// `masked_fill(src, mask, value)`:
    /// `out[i] = mask[i] ? value : src[i]`. Mirrors
    /// `torch.Tensor.masked_fill`.
    MaskedFill = 5,
    /// Backward of [`Self::MaskedFill`]: `dsrc[i] = mask[i] ? 0 : dout[i]`.
    /// The scalar `value` is not differentiated.
    MaskedFillBackward = 6,
    /// `one_hot(src, num_classes)`:
    /// `out[indices..., c] = 1 if c == src[indices...] else 0`. The input
    /// holds i32 class indices; the output dtype is configurable. Mirrors
    /// `torch.nn.functional.one_hot`. Non-differentiable.
    OneHot = 7,
    /// `nonzero(x)`: coordinates where `x != 0`, returned as an `[k, rank]`
    /// i32 coordinate table plus a count. Mirrors `torch.nonzero`. The
    /// output is NOT row-major ordered, because atomic-counter races decide
    /// the order; callers needing sorted output sort afterward.
    Nonzero = 8,
    /// `scatter(out, dim, index, updates)`:
    /// `out[..., index[..., j, ...], ...] = updates[..., j, ...]`. There is
    /// NO accumulation: the last writer wins when targets collide. Mirrors
    /// the pure-assign in-place `torch.scatter_`. Distinct from
    /// [`Self::ScatterAdd`]. Phase 39 (Fuel 6c.4 Gap 5).
    Scatter = 9,
    /// `index_add(dst, dim, idx, src)`:
    /// `dst[idx[i], ...] += src[i, ...]` along `add_dim` (atomicAdd-Σ).
    /// Mirrors `torch.Tensor.index_add_`. The algorithm is identical to
    /// [`Self::IndexSelectBackward`], but it is exposed under a name that
    /// is not autograd-flavored and with broader dtype coverage. Phase 39
    /// (Fuel 6c.4 Gap 5).
    IndexAdd = 10,
}
1401
/// Category S (segment / scatter-reduce) op discriminant, per the
/// comprehensive plan.
///
/// Encoded as a `u16` in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::SegmentOps`. Every variant names its own
/// kernel symbol. Sorted and unsorted families share this enum but
/// occupy distinct `op` slots, since their implementations differ: the
/// sorted kernels do a single-pass binary-search sweep, while the
/// unsorted kernels scatter atomically from the input side.
///
/// Wired by Phase 7 Milestone 7.6:
/// - Sorted FW: [`Self::SegmentSum`], [`Self::SegmentMean`],
///   [`Self::SegmentMax`], [`Self::SegmentMin`], [`Self::SegmentProd`].
///   Sum and Mean also have BW variants ([`Self::SegmentSumBackward`],
///   [`Self::SegmentMeanBackward`]).
/// - Unsorted FW: [`Self::UnsortedSegmentSum`],
///   [`Self::UnsortedSegmentMean`], [`Self::UnsortedSegmentMax`],
///   [`Self::UnsortedSegmentMin`]. Sum and Mean also have BW variants
///   ([`Self::UnsortedSegmentSumBackward`],
///   [`Self::UnsortedSegmentMeanBackward`]).
///
/// Phase 25 fills in the remaining BW coverage. The Max / Min BW
/// kernels (sorted and unsorted) recompute the argmax themselves, so the
/// FW API stays source-compatible: no paired-index tensor is added to
/// the FW signature. The Prod BW kernels (sorted and unsorted) compute
/// `d_output * prod / x` by direct division. Callers must either keep
/// zero-valued inputs out of a segment or tolerate NaN / Inf in the
/// gradient. Unsorted Prod FW relies on an `atomicCAS` retry loop,
/// because there is no native FP `atomicMul`.
///
/// Supported dtypes: `f32, f64` (the FP types with atomic support).
/// f16 / bf16 are deferred — the kernels depend on `atomicAdd` /
/// `atomicMax` / `atomicMin`, which are limited to native-FP-atomic types.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum SegmentKind {
    /// Sorted segment sum: `out[s, d] = Σ_{n : seg[n] == s} input[n, d]`.
    /// Segment IDs must be monotonically non-decreasing. Mirrors
    /// TF / JAX `segment_sum`.
    SegmentSum = 0,
    /// BW of [`Self::SegmentSum`]: gathers along the segment ids,
    /// `d_input[n, d] = d_output[seg[n], d]`.
    SegmentSumBackward = 1,
    /// Sorted segment mean: `out[s, d] = mean_{n : seg[n] == s} input[n, d]`.
    SegmentMean = 2,
    /// BW of [`Self::SegmentMean`]:
    /// `d_input[n, d] = d_output[seg[n], d] / count[seg[n]]`.
    SegmentMeanBackward = 3,
    /// Sorted segment max: `out[s, d] = max_{n : seg[n] == s} input[n, d]`.
    SegmentMax = 4,
    /// Sorted segment min: `out[s, d] = min_{n : seg[n] == s} input[n, d]`.
    SegmentMin = 5,
    /// Sorted segment product: `out[s, d] = prod_{n : seg[n] == s} input[n, d]`.
    SegmentProd = 6,
    /// Unsorted segment sum (segment IDs may appear in any order):
    /// `out[s, d] = Σ_{n : seg[n] == s} input[n, d]`. Mirrors TF
    /// `unsorted_segment_sum`.
    UnsortedSegmentSum = 7,
    /// BW of [`Self::UnsortedSegmentSum`]:
    /// `d_input[n, d] = d_output[seg[n], d]`.
    UnsortedSegmentSumBackward = 8,
    /// Unsorted segment mean: `out[s, d] = mean_{n : seg[n] == s} input[n, d]`.
    UnsortedSegmentMean = 9,
    /// BW of [`Self::UnsortedSegmentMean`]:
    /// `d_input[n, d] = d_output[seg[n], d] / count[seg[n]]`.
    UnsortedSegmentMeanBackward = 10,
    /// Unsorted segment max: `out[s, d] = max_{n : seg[n] == s} input[n, d]`.
    UnsortedSegmentMax = 11,
    /// Unsorted segment min: `out[s, d] = min_{n : seg[n] == s} input[n, d]`.
    UnsortedSegmentMin = 12,
    /// Phase 25. BW of [`Self::SegmentMax`]: routes
    /// `d_input[k, d] = d_output[seg, d]` to the lowest-index `k` with
    /// `input[k, d] == max`. The BW kernel re-scans the segment to
    /// recompute the argmax, which keeps the FW signature unchanged.
    SegmentMaxBackward = 13,
    /// Phase 25. BW of [`Self::SegmentMin`]; the min-side counterpart of
    /// [`Self::SegmentMaxBackward`].
    SegmentMinBackward = 14,
    /// Phase 25. BW of [`Self::SegmentProd`]:
    /// `d_input[k, d] = d_output[seg, d] * (prod[seg, d] / x[k, d])`.
    /// Because this divides directly, callers must keep zero-valued
    /// inputs out of the segment or accept NaN / Inf in the gradient.
    SegmentProdBackward = 15,
    /// Phase 25. BW of [`Self::UnsortedSegmentMax`]. Uses the same
    /// recompute-argmax approach as the sorted variant, but scans the
    /// whole input array for each (seg, d) cell. Tie-breaking is
    /// non-deterministic whenever the FW was non-deterministic.
    UnsortedSegmentMaxBackward = 16,
    /// Phase 25. BW of [`Self::UnsortedSegmentMin`]; the min-side
    /// counterpart of [`Self::UnsortedSegmentMaxBackward`].
    UnsortedSegmentMinBackward = 17,
    /// Phase 25. Unsorted segment product:
    /// `out[s, d] = prod_{n : seg[n] == s} input[n, d]`. Implemented with
    /// an `atomicCAS` retry loop, since no native FP `atomicMul` exists.
    /// Non-deterministic.
    UnsortedSegmentProd = 18,
    /// Phase 25. BW of [`Self::UnsortedSegmentProd`]. Uses the same
    /// direct-division scheme as [`Self::SegmentProdBackward`].
    UnsortedSegmentProdBackward = 19,
}
1499
/// Category M (embedding family) op discriminant, per the comprehensive
/// plan.
///
/// Encoded as a `u16` in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Embedding`. Wired by Phase 7 Milestone 7.5:
/// - [`Self::Embedding`] (FW + BW): per-row lookup
///   `out[i, :] = weight[indices[i], :]`. An optional `padding_idx`
///   yields an all-zero output row in FW and is skipped during BW
///   accumulation.
/// - [`Self::EmbeddingBagSum`] / [`Self::EmbeddingBagMean`] (FW + BW):
///   per-bag reduced lookup
///   `out[b, :] = reduce(weight[indices[k], :] for k in offsets[b]..offsets[b+1])`.
///   The mode picks the reducer: a plain sum, or the sum divided by the
///   bag size. `EmbeddingBagMax` is deferred because its BW needs argmax
///   tracking.
///
/// Indices are `i32` only (i64 deferred). FW kernels cover
/// `f32, f64, f16, bf16` (pure copy / reduce). BW kernels cover
/// `f32, f64`, since they rely on atomicAdd.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum EmbeddingKind {
    /// `embedding(weight, indices, padding_idx)`: gathers
    /// `out[i, :] = weight[indices[i], :]`. Mirrors PyTorch
    /// `torch.nn.functional.embedding`.
    Embedding = 0,
    /// BW of [`Self::Embedding`]: atomically accumulates
    /// `dweight[indices[i], :] += dout[i, :]`, skipping every row whose
    /// `indices[i] == padding_idx`.
    EmbeddingBackward = 1,
    /// `embedding_bag(weight, indices, offsets, mode=Sum)`. Mirrors
    /// PyTorch `torch.nn.functional.embedding_bag` with `mode='sum'`.
    EmbeddingBagSum = 2,
    /// `embedding_bag(weight, indices, offsets, mode=Mean)`. Mirrors
    /// PyTorch `torch.nn.functional.embedding_bag` with `mode='mean'`.
    EmbeddingBagMean = 3,
    /// BW of `embedding_bag` in Sum mode: atomically accumulates
    /// `dweight[indices[k], :] += dout[b, :]` for each k in bag b.
    EmbeddingBagSumBackward = 4,
    /// BW of `embedding_bag` in Mean mode: atomically accumulates
    /// `dweight[indices[k], :] += dout[b, :] / bag_size(b)`.
    EmbeddingBagMeanBackward = 5,
    /// Reserved: `embedding_bag(weight, indices, offsets, mode=Max)`.
    /// Max mode needs FW argmax tracking (which row contributed each
    /// feature) so that the BW can scatter into only that row. That is a
    /// different plan shape, so it is deferred.
    EmbeddingBagMax = 6,
    /// Reserved: BW of `embedding_bag` in Max mode.
    EmbeddingBagMaxBackward = 7,
}
1549
/// Quantization op discriminant — Category P from the comprehensive plan.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Quantization`. Phase 8 Milestone 8.1 wires the
/// trailblazer set: per-tensor + per-channel quantize / dequantize plus
/// fake_quantize (round-trip in FP space). All entries support FW + BW
/// where applicable (FW-only for kinds that have no meaningful gradient).
///
/// **Trailblazer dtype scope.** Input FP × output int:
/// - Input FP: `f32, f64, f16, bf16`.
/// - Output int: `s8, u8`. Sub-byte packed types (`s4`, `u4`) are deferred.
/// - `scale` matches the input FP dtype; `zero_point` is always `i32`
///   (wide enough for any int output qmin/qmax range).
///
/// **Backward conventions.** The BW of `quantize` and `fake_quantize`
/// uses the Straight-Through Estimator (STE). The gradient passes through
/// (with a `1/scale` factor for `quantize`, no factor for `fake_quantize`)
/// where the rounded result was in-range `[qmin, qmax]`, and is zero
/// elsewhere. The "in-range mask" is **recomputed in BW from the saved
/// `input` tensor** rather than saved as a separate FW output. This
/// matches PyTorch's internal FakeQuantize and keeps the FW signature
/// clean. Callers must therefore retain the original input tensor for the
/// BW pass (which they would do anyway for autograd). The `Dequantize*`
/// BW kinds need no estimator: dequantization is affine in `q`, so its
/// gradient is exactly `dy * scale`.
///
/// **Discriminant layout.**
/// - `0..=9`: the Milestone 8.1 trailblazer set.
/// - `10..=15`: currently unassigned.
/// - `16..=20`: reserved up front for [`Self::PerToken`] /
///   [`Self::PerGroup`] (and their backwards) and
///   [`Self::DynamicRange`].
/// - `21..`: later milestones append here (8.2 per-token / per-group
///   dequant, 8.3 quantized linear, 8.4 GGUF).
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum QuantizeKind {
    /// `quantize_per_tensor(x, scale, zero_point)` —
    /// `q = clamp(round(x / scale) + zero_point, qmin, qmax)`.
    /// One scalar `scale` (FP) and `zero_point` (i32) for the whole
    /// tensor. PyTorch `torch.quantize_per_tensor`.
    PerTensor = 0,
    /// Gradient of [`Self::PerTensor`] via STE:
    /// `dx = (dy / scale) * in_range_mask`, where the mask is
    /// `qmin <= round(x/scale) + zp <= qmax`.
    PerTensorBackward = 1,
    /// `dequantize_per_tensor(q, scale, zero_point)` —
    /// `x = scale * (q - zero_point)`. Linear; inverts [`Self::PerTensor`]
    /// up to rounding. PyTorch `torch.Tensor.dequantize`.
    DequantizePerTensor = 2,
    /// Gradient of [`Self::DequantizePerTensor`]: `dq = dy * scale`
    /// (exact gradient of the affine map).
    DequantizePerTensorBackward = 3,
    /// `quantize_per_channel(x, scale[C], zero_point[C], axis)` — same
    /// math as [`Self::PerTensor`] but with one `scale[c]` /
    /// `zero_point[c]` pair per slice along `axis`. PyTorch
    /// `torch.quantize_per_channel`.
    PerChannel = 4,
    /// Gradient of [`Self::PerChannel`] via STE:
    /// `dx = (dy / scale[c]) * in_range_mask[c]`.
    PerChannelBackward = 5,
    /// `dequantize_per_channel(q, scale[C], zero_point[C], axis)` —
    /// `x = scale[c] * (q - zero_point[c])`.
    DequantizePerChannel = 6,
    /// Gradient of [`Self::DequantizePerChannel`]:
    /// `dq = dy * scale[c]` (exact gradient of the affine map).
    DequantizePerChannelBackward = 7,
    /// `fake_quantize_per_tensor(x, scale, zero_point)` —
    /// `y = scale * (clamp(round(x/scale)+zp, qmin, qmax) - zp)`. The
    /// roundtrip quantize-then-dequantize in FP space; produces a lossy
    /// FP output. PyTorch
    /// `torch.fake_quantize_per_tensor_affine`.
    FakeQuantize = 8,
    /// Gradient of [`Self::FakeQuantize`] via STE:
    /// `dx = dy * in_range_mask`. **No `1/scale` factor** — the
    /// dequant-side multiplication by `scale` in FW cancels the
    /// `1/scale` from STE.
    FakeQuantizeBackward = 9,
    // Discriminants 10..=15 are currently unassigned.
    /// Reserved — `quantize_per_token` (per-row dynamic-range
    /// quantization used by activation quantization).
    PerToken = 16,
    /// Reserved — gradient of [`Self::PerToken`].
    PerTokenBackward = 17,
    /// Reserved — `quantize_per_group` (block-wise quantization used by
    /// GPTQ / AWQ / GGML).
    PerGroup = 18,
    /// Reserved — gradient of [`Self::PerGroup`].
    PerGroupBackward = 19,
    /// Reserved — `dynamic_range_quantize` (post-training dynamic
    /// quantization).
    DynamicRange = 20,
    // ---- Milestone 8.2 completion: per-token / per-group dequant and
    //      their backwards. The matching FW quantize discriminants
    //      (PerToken / PerGroup and their backwards) sit at 16..=19. ----
    /// `dequantize_per_token(q, scale[N], zero_point[N])` —
    /// `y[n, d] = scale[n] * (q[n, d] - zp[n])`. Per-row inverse of
    /// [`Self::PerToken`].
    DequantizePerToken = 21,
    /// Gradient of [`Self::DequantizePerToken`]:
    /// `dq = dy * scale[n]` (exact gradient of the affine map; no
    /// estimator involved).
    DequantizePerTokenBackward = 22,
    /// `dequantize_per_group(q, scale[outer, num_groups],
    /// zero_point[outer, num_groups])` — per-group inverse of
    /// [`Self::PerGroup`].
    DequantizePerGroup = 23,
    /// Gradient of [`Self::DequantizePerGroup`]:
    /// `dq[i, j] = dy[i, j] * scale[i, j/g]`, where `g` is the number of
    /// elements per group (exact gradient of the affine map).
    DequantizePerGroupBackward = 24,
    // ---- Milestone 8.3 — composing quantization ops ----
    /// `quantized_linear(activation_fp, weight_q_s8, weight_scale,
    /// bias?)` — W8A8 fused quantized matmul. Pipeline: dynamic-range
    /// per-token quantize the activation → int8 GEMM with int32
    /// accumulator → dequantize via per-row `scale_a` and per-channel
    /// `scale_w`. The canonical inference-time LLM matmul recipe
    /// (e.g. SmoothQuant, AWQ-runtime); FP activation in, FP output out,
    /// int8 storage only on the GEMM. Backward isn't shipped — this op
    /// is inference-only by convention.
    QuantizedLinear = 25,
    // ---- Milestone 8.4 — GGUF block-format quant family ----
    /// `gguf_dequantize(packed_bytes) -> fp_tensor` — unpack a
    /// GGUF-packed weight buffer (Q4_0 / Q4_1 / Q5_0 / Q5_1 / Q8_0 /
    /// Q2_K / Q3_K / Q4_K / Q5_K / Q6_K / Q8_K) into a dense FP
    /// tensor. The block format is carried out-of-band on the plan
    /// descriptor (see [`GgufBlockFormat`]); the kernel surface
    /// fans out across block formats but the enum value is the same.
    /// Inference-only by convention (BW not shipped).
    GgufDequantize = 26,
    /// `gguf_mmvq(packed_weight, fp_activation) -> fp_output` —
    /// fused dequant + matrix-vector multiply: the inference-time
    /// "decode-step" matmul used by llama.cpp on GGUF weights.
    /// FP activation in (f32 today), FP output out. Inference-only
    /// (BW not shipped).
    GgufMmvq = 27,
}
1678
/// GGUF block-format selector for [`QuantizeKind::GgufDequantize`] /
/// [`QuantizeKind::GgufMmvq`] plans. Mirrors the discriminants used by
/// llama.cpp / `ggml` so a descriptor can be round-tripped to a GGUF
/// file header without translation.
///
/// Block sizes:
///   * Type-0/1 variants (`Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`)
///     pack 32 quantized values per block plus a shared FP scale
///     (+ min for the `_1` variants).
///   * k-quants variants (`Q2_K` ... `Q8_K`) pack 256 values per
///     super-block with a multi-level scale hierarchy
///     (quantized sub-block scales + FP super-block scale).
///
/// Discriminant values match the `GGML_TYPE_*` enum in upstream
/// `ggml.h`, ensuring binary compatibility with GGUF file headers.
/// Only the formats listed here are supported. For example, value 9
/// (`GGML_TYPE_Q8_1`) has no variant.
///
/// The "effective" bit widths quoted on each variant are
/// `type_size() * 8 / block_size()`, i.e. they include the per-block
/// scale / min overhead.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
// Variant names deliberately keep ggml's `Q4_0`-style spelling so they
// read the same as the upstream `GGML_TYPE_*` constants.
#[allow(non_camel_case_types)]
pub enum GgufBlockFormat {
    /// 4-bit, 32-element block, single FP scale (4.5 bits effective,
    /// 18 B/block). `block_q4_0`.
    Q4_0 = 2,
    /// 4-bit, 32-element block, FP scale + FP min (5.0 bits effective,
    /// 20 B/block). `block_q4_1`.
    Q4_1 = 3,
    /// 5-bit, 32-element block, single FP scale (5.5 bits effective,
    /// 22 B/block). `block_q5_0`.
    Q5_0 = 6,
    /// 5-bit, 32-element block, FP scale + FP min (6.0 bits effective,
    /// 24 B/block). `block_q5_1`.
    Q5_1 = 7,
    /// 8-bit, 32-element block, single FP scale (8.5 bits effective,
    /// 34 B/block). `block_q8_0`.
    Q8_0 = 8,
    /// 2-bit k-quant, 256-element super-block (2.625 bits effective,
    /// 84 B/block). `block_q2_K`.
    Q2K = 10,
    /// 3-bit k-quant, 256-element super-block (3.4375 bits effective,
    /// 110 B/block). `block_q3_K`.
    Q3K = 11,
    /// 4-bit k-quant, 256-element super-block (4.5 bits effective,
    /// 144 B/block). `block_q4_K`.
    Q4K = 12,
    /// 5-bit k-quant, 256-element super-block (5.5 bits effective,
    /// 176 B/block). `block_q5_K`.
    Q5K = 13,
    /// 6-bit k-quant, 256-element super-block (6.5625 bits effective,
    /// 210 B/block). `block_q6_K`.
    Q6K = 14,
    /// 8-bit, 256-element super-block (9.125 bits effective, 292 B/block).
    /// `block_q8_K`. Upstream llama.cpp treats Q8_K as a CPU-side
    /// intermediate with dequant only. Here both dequant and MMVQ are
    /// supported; the MMVQ kernel is a baracuda addition (see
    /// [`Self::has_mmvq`]).
    Q8K = 15,
}
1723
1724impl GgufBlockFormat {
1725    /// Number of FP elements per packed block.
1726    #[inline]
1727    pub const fn block_size(self) -> usize {
1728        match self {
1729            GgufBlockFormat::Q4_0
1730            | GgufBlockFormat::Q4_1
1731            | GgufBlockFormat::Q5_0
1732            | GgufBlockFormat::Q5_1
1733            | GgufBlockFormat::Q8_0 => 32,
1734            _ => 256,
1735        }
1736    }
1737
1738    /// Number of bytes per packed block. Matches `sizeof(block_q*)`
1739    /// from `ggml.h`. Used by the Rust plan layer to size the input
1740    /// weight buffer.
1741    #[inline]
1742    pub const fn type_size(self) -> usize {
1743        match self {
1744            // 2 (fp16 d) + 16 (qs[16])
1745            GgufBlockFormat::Q4_0 => 18,
1746            // 2*2 (half2 dm) + 16 (qs[16])
1747            GgufBlockFormat::Q4_1 => 20,
1748            // 2 (fp16 d) + 4 (qh) + 16 (qs[16])
1749            GgufBlockFormat::Q5_0 => 22,
1750            // 2*2 (half2 dm) + 4 (qh) + 16 (qs[16])
1751            GgufBlockFormat::Q5_1 => 24,
1752            // 2 (fp16 d) + 32 (qs[32])
1753            GgufBlockFormat::Q8_0 => 34,
1754            // 2*2 (half2 dm) + QK_K/16 (16 scales) + QK_K/4 (64 qs) = 4+16+64
1755            GgufBlockFormat::Q2K => 84,
1756            // hmask(32) + qs(64) + scales(12) + d(2)
1757            GgufBlockFormat::Q3K => 110,
1758            // dm(4) + scales(12) + qs(128)
1759            GgufBlockFormat::Q4K => 144,
1760            // dm(4) + scales(12) + qh(32) + qs(128)
1761            GgufBlockFormat::Q5K => 176,
1762            // ql(128) + qh(64) + scales(16) + d(2)
1763            GgufBlockFormat::Q6K => 210,
1764            // d(4) + qs(256) + bsums(32)
1765            GgufBlockFormat::Q8K => 292,
1766        }
1767    }
1768
1769    /// `true` for the type-0/1 family (32-element blocks); `false`
1770    /// for the k-quants family (256-element super-blocks).
1771    #[inline]
1772    pub const fn is_type_01(self) -> bool {
1773        matches!(
1774            self,
1775            GgufBlockFormat::Q4_0
1776                | GgufBlockFormat::Q4_1
1777                | GgufBlockFormat::Q5_0
1778                | GgufBlockFormat::Q5_1
1779                | GgufBlockFormat::Q8_0
1780        )
1781    }
1782
1783    /// `true` if MMVQ (fused dequant + matvec) is supported for this
1784    /// block format. As of Phase 11.4, all 11 GGUF block formats ship a
1785    /// MMVQ kernel. Q8_K MMVQ is a bespoke baracuda addition (upstream
1786    /// llama.cpp / Fuel reserve Q8_K as a CPU-side intermediate and ship
1787    /// dequant only); we close that gap to avoid 2× memory traffic on
1788    /// the inference decode step.
1789    #[inline]
1790    pub const fn has_mmvq(self) -> bool {
1791        match self {
1792            GgufBlockFormat::Q4_0
1793            | GgufBlockFormat::Q4_1
1794            | GgufBlockFormat::Q5_0
1795            | GgufBlockFormat::Q5_1
1796            | GgufBlockFormat::Q8_0
1797            | GgufBlockFormat::Q2K
1798            | GgufBlockFormat::Q3K
1799            | GgufBlockFormat::Q4K
1800            | GgufBlockFormat::Q5K
1801            | GgufBlockFormat::Q6K
1802            | GgufBlockFormat::Q8K => true,
1803        }
1804    }
1805}
1806
1807
/// Mixture-of-Experts (MoE) variant selector — used as the `op`
/// discriminant for kernel SKUs whose [`crate::OpCategory`] is
/// [`crate::OpCategory::Moe`]. Phase 8 Milestone 8.5 wires the three
/// fused per-token-dispatch + expert-matmul + accumulate kernels.
///
/// MoE forward pass shape:
///   * Input activations `[T, D_model]`.
///   * Per-token top-k expert indices `[T, top_k]` (i32).
///   * Per-token top-k expert weights `[T, top_k]` (FP).
///   * Per-expert weight matrices `[num_experts, D_model, D_expert]`
///     (dtype depends on the variant: FP for `Wmma`, GGUF-packed bytes
///     for `ScalarGguf` / `WmmaGguf`).
///   * Output `[T, D_model]` (after expert mixing).
///
/// All three variants are inference-only by convention: no backward
/// passes are shipped. MoE training uses higher-level autograd surfaces
/// that compose the per-expert FFN ops manually.
///
/// Lineage: vendored from `attention.rs` via `fuel-cuda-kernels`. See
/// `crates/baracuda-kernels-sys/LICENSE-thirdparty.md` for the full
/// attribution chain.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum MoeKind {
    /// Scalar dispatch path operating on GGUF-quantized expert weights
    /// staged through a q8_1 intermediate (FP32 activations in, FP32
    /// output out). No tensor cores. Used as a portability fallback
    /// and as the slower-but-simpler reference for the WMMA + GGUF
    /// hot path. Block formats: `Q8_0`, `Q2_K`, `Q3_K`, `Q4_K`,
    /// `Q5_K`, `Q6_K` (matches Fuel's `moe_gemm_gguf` switch).
    ScalarGguf = 0,
    /// WMMA tensor-core path operating on dense FP expert weights
    /// (f16 / bf16). The FP MoE hot path, used when full-precision
    /// expert weights are available (e.g. FP-deployment inference).
    /// Like the other variants it is inference-only: no backward is
    /// shipped. sm_70+ required.
    Wmma = 1,
    /// Combined WMMA tensor-core + GGUF-quantized weight path. The
    /// dispatcher dequantizes one GGUF block per N-row into shared
    /// memory, then issues a 16×16×16 WMMA mma.sync against the
    /// dense activation tile. The production hot path for quantized
    /// LLM inference. Activation dtype: f16 / bf16. Weight block
    /// formats: same set as [`Self::ScalarGguf`]. sm_70+ required.
    WmmaGguf = 2,
}
1853
/// Sorting / order-statistics op discriminant — Category O from the
/// comprehensive plan (Phase 9).
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Sorting`. Phase 9 wires the block-bitonic
/// trailblazer family (`row_len ≤ 1024`, `k ≤ 64`):
///
/// - [`Self::Sort`] / [`Self::SortBackward`] — full sort with saved
///   indices for BW. PyTorch `torch.sort`.
/// - [`Self::Argsort`] — indices-only variant. PyTorch `torch.argsort`.
/// - [`Self::Msort`] / [`Self::MsortBackward`] — stable sort (tie-break
///   on original index preserves input order). Named after PyTorch
///   `torch.msort`, which sorts along dim 0; see [`Self::Msort`].
/// - [`Self::Topk`] / [`Self::TopkBackward`] — top-k by value (or
///   bottom-k when `largest == false`). PyTorch `torch.topk`.
/// - [`Self::Kthvalue`] / [`Self::KthvalueBackward`] — composed atop
///   topk; returns the k-th value + its index.
/// - [`Self::Unique`] / [`Self::UniqueConsecutive`] — set-valued ops.
///   `unique` chains sort + consecutive-dedup; `unique_consecutive`
///   assumes the input is already sorted (or only run-equal cells
///   matter). No BW (set-valued).
/// - [`Self::Histogram`] / [`Self::Histogramdd`] / [`Self::Bincount`]
///   — atomic-bin accumulation. Histogram + bincount FW are shipped;
///   histogramdd is reserved (rank > 1 trailblazer follow-up).
/// - [`Self::Searchsorted`] — per-query binary search in a 1-D sorted
///   array. PyTorch `torch.searchsorted`. No BW.
///
/// Dtype coverage:
/// - sort / argsort / msort FW: `f32, f64, i32, i64`.
/// - sort / msort BW: `f32, f64` (FP grads only).
/// - topk FW + BW: `f32, f64`.
/// - kthvalue: composes topk; same dtype set.
/// - unique / unique_consecutive: `f32, f64, i32`.
/// - histogram: `f32, f64` input → `i32` counts.
/// - bincount: `i32, i64` input → `i32` counts.
/// - searchsorted: `f32, f64, i32, i64`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum SortKind {
    /// `sort(x, dim, descending)` — returns sorted values + sorted
    /// indices. PyTorch `torch.sort`.
    Sort = 0,
    /// Gradient of [`Self::Sort`] — scatter `dy` back to the original
    /// positions via the saved indices.
    SortBackward = 1,
    /// `argsort(x, dim, descending)` — returns sorted indices only.
    /// PyTorch `torch.argsort`.
    Argsort = 2,
    /// `msort(x)` — stable sort along the last dimension. Tie-break on
    /// original index preserves input order.
    ///
    /// Named after PyTorch `torch.msort`, but PyTorch's `msort` sorts
    /// along the *first* dimension (it is `torch.sort(x, dim=0).values`)
    /// and does not promise stability. A caller lowering `torch.msort`
    /// onto this kind must present dim 0 as the sorted axis.
    /// NOTE(review): confirm the plan layer performs that transpose.
    Msort = 3,
    /// Gradient of [`Self::Msort`] — same scatter as
    /// [`Self::SortBackward`].
    MsortBackward = 4,
    /// `topk(x, k, dim, largest)` — top-k (or bottom-k) values + their
    /// indices. PyTorch `torch.topk`. Trailblazer caps `k ≤ 64`.
    Topk = 5,
    /// Gradient of [`Self::Topk`] — scatter the k-wide `dy` back to a
    /// zero-init `row_len`-wide `dx` via saved indices.
    TopkBackward = 6,
    /// `kthvalue(x, k, dim)` — the k-th smallest value + its index.
    /// Composed at the Rust plan layer atop [`Self::Topk`] with the
    /// "bottom-k" order, so it inherits the `k ≤ 64` trailblazer cap.
    Kthvalue = 7,
    /// Gradient of [`Self::Kthvalue`] — scatter the scalar `dy` back
    /// to the single source position.
    KthvalueBackward = 8,
    /// `unique(x, sorted=True)` — returns the unique values in `x`. At
    /// the Rust plan layer this chains [`Self::Sort`] + the consecutive
    /// dedup. Set-valued — no BW.
    Unique = 9,
    /// `unique_consecutive(x)` — emits one cell per run-start (input
    /// must be sorted, or only consecutive-equal cells should be
    /// collapsed). Set-valued — no BW.
    UniqueConsecutive = 10,
    /// `histogram(x, bins, range)` — 1-D uniform-bin histogram.
    /// PyTorch `torch.histogram`. FW only.
    Histogram = 11,
    /// `histogramdd(x, bins, range)` — N-D histogram. Reserved
    /// discriminant; rank > 1 trailblazer follow-up.
    Histogramdd = 12,
    /// `bincount(x, minlength)` — count occurrences of each integer
    /// in `x`. PyTorch `torch.bincount`. FW only.
    Bincount = 13,
    /// `searchsorted(sorted_seq, values, right)` — per-query
    /// lower/upper bound binary search. PyTorch `torch.searchsorted`.
    /// FW only.
    Searchsorted = 14,
}
1943
/// Image / spatial-transform op discriminant — Category T from the
/// comprehensive plan.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Image`. Phase 9 Category T wires the
/// trailblazer set:
/// - [`Self::InterpolateBilinear2d`] / [`Self::InterpolateBilinear2dBackward`]
///   — spatial up/downsample via bilinear interpolation.
/// - [`Self::GridSample2d`] / [`Self::GridSample2dBackward`] — sample
///   input at arbitrary normalized coordinates (PyTorch
///   `torch.nn.functional.grid_sample`, default config:
///   `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=false`).
/// - [`Self::AffineGrid2d`] — generate a sampling grid from a 2×3
///   affine matrix (companion to GridSample).
/// - [`Self::PixelShuffle`] / [`Self::PixelUnshuffle`] — pure index
///   permutation between `[N, C·r², H, W]` and `[N, C, H·r, W·r]`.
///   Each is the other's backward.
/// - [`Self::RoiAlign`] / [`Self::RoiAlignBackward`] — extract
///   fixed-size features from variable RoIs via bilinear sampling.
/// - [`Self::RoiPool`] / [`Self::RoiPoolBackward`] — max-pool variant
///   of RoiAlign (argmax routing on BW).
/// - [`Self::Nms`] — non-max suppression on bounding boxes. Returns a
///   boolean keep mask + count; no BW (set-valued op).
///
/// Other interpolation modes (`nearest`, `bicubic`, `trilinear`,
/// `linear`, `area`) have discriminants reserved here, but their kernels
/// are stubbed `Unsupported` in the trailblazer.
///
/// Discriminants are grouped by family (interpolate `0..=11`,
/// grid sample `16..=18`, pixel shuffle `24..=25`, RoI `32..=35`,
/// NMS `40`); the gaps between groups are unassigned.
///
/// Trailblazer dtype coverage: `f32, f64` for math-bearing ops.
/// `pixel_shuffle` / `pixel_unshuffle` additionally cover `f16, bf16`
/// (pure layout — dtype-agnostic).
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum ImageKind {
    /// `interpolate(x, mode='bilinear', size=…)` — 2-D spatial
    /// resample with bilinear weights. Trailblazer wired today.
    InterpolateBilinear2d = 0,
    /// Gradient of [`Self::InterpolateBilinear2d`] — atomic-add of
    /// weighted contributions from each output cell to the 4 input
    /// cells it bilinearly sampled.
    InterpolateBilinear2dBackward = 1,
    /// `interpolate(x, mode='nearest')` — reserved.
    InterpolateNearest2d = 2,
    /// Gradient of [`Self::InterpolateNearest2d`] — reserved.
    InterpolateNearest2dBackward = 3,
    /// `interpolate(x, mode='bicubic')` — reserved.
    InterpolateBicubic2d = 4,
    /// Gradient of [`Self::InterpolateBicubic2d`] — reserved.
    InterpolateBicubic2dBackward = 5,
    /// `interpolate(x, mode='trilinear')` — reserved.
    InterpolateTrilinear3d = 6,
    /// Gradient of [`Self::InterpolateTrilinear3d`] — reserved.
    InterpolateTrilinear3dBackward = 7,
    /// `interpolate(x, mode='linear')` — reserved (1-D).
    InterpolateLinear1d = 8,
    /// Gradient of [`Self::InterpolateLinear1d`] — reserved.
    InterpolateLinear1dBackward = 9,
    /// `interpolate(x, mode='area')` — reserved (adaptive avg pool).
    InterpolateArea2d = 10,
    /// Gradient of [`Self::InterpolateArea2d`] — reserved.
    InterpolateArea2dBackward = 11,

    /// `grid_sample(input, grid)` — 2-D bilinear, zeros-pad,
    /// `align_corners=false`. PyTorch defaults.
    GridSample2d = 16,
    /// Gradient of [`Self::GridSample2d`] — atomic-add into `dinput`,
    /// plus analytical bilinear coordinate derivatives into `dgrid`.
    GridSample2dBackward = 17,
    /// `affine_grid(theta, size)` — generate the normalized sampling
    /// grid for a 2×3 affine matrix. Companion to GridSample2d.
    AffineGrid2d = 18,

    /// `pixel_shuffle(x, r)` — `[N, C·r², H, W] → [N, C, H·r, W·r]`.
    /// Pure index permutation. BW is `PixelUnshuffle`.
    PixelShuffle = 24,
    /// `pixel_unshuffle(x, r)` — `[N, C, H·r, W·r] → [N, C·r², H, W]`.
    /// Inverse of `PixelShuffle`. BW is `PixelShuffle`.
    PixelUnshuffle = 25,

    /// `roi_align(input, rois, output_size, spatial_scale,
    /// sampling_ratio=0, aligned=false)`. PyTorch convention.
    RoiAlign = 32,
    /// Gradient of [`Self::RoiAlign`] — bilinear-weighted atomic-add
    /// into `dinput`.
    RoiAlignBackward = 33,
    /// `roi_pool(input, rois, output_size, spatial_scale)` — max-pool
    /// variant of RoiAlign. Saves argmax indices for BW.
    RoiPool = 34,
    /// Gradient of [`Self::RoiPool`] — atomic-add of `dout[i, c, h, w]`
    /// into `dinput` at the saved argmax cell.
    RoiPoolBackward = 35,

    /// `nms(boxes, scores, iou_threshold)` — non-max suppression.
    /// Returns a boolean keep mask `[num_boxes]` and a count scalar.
    /// No BW (set-valued op).
    Nms = 40,
}