baracuda_kernels_types/ops.rs
1//! Per-category op discriminant enums.
2//!
3//! Each op category (B, C, D, N, …) gets a `*Kind` enum whose variants
4//! correspond to individual PyTorch / JAX ops. The enum value is also
5//! the runtime tag stored as a `u16` in [`crate::KernelSku::op`].
6//!
//! New enums land alongside the Plan type that consumes them. Phase 3
//! contributed [`BinaryKind`] and [`UnaryKind`] first; [`TernaryKind`],
//! [`GatedActivationKind`], [`ShapeLayoutKind`] and the later families
//! defined below were added as their Plan types shipped.
11
/// Discriminant for two-operand elementwise ops.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::BinaryElementwise`. The variant list is the
/// union of the binary elementwise ops offered by PyTorch
/// (`torch.<op>` / `torch.Tensor.<op>`) and JAX (`jax.numpy.<op>` /
/// `jax.lax.<op>`).
///
/// Only [`Self::Add`] — the Phase 3 trailblazer SKU — is wired today.
/// The remaining variants are reserved discriminants for the fanout
/// sessions that ship sub / mul / div / pow / bitwise.
///
/// The comparison slots ([`Self::Eq`], [`Self::Ne`], [`Self::Gt`],
/// [`Self::Ge`], [`Self::Lt`], [`Self::Le`]) are vestigial: comparisons
/// produce a `u8` output and are tagged with [`BinaryCmpKind`] instead,
/// so these slots will never be wired into the same-dtype binary path.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinaryKind {
    /// Elementwise addition, `y = a + b`. This is the trailblazer SKU
    /// of `baracuda-kernels` Phase 3.
    Add = 0,
    /// Elementwise subtraction, `y = a - b`.
    Sub = 1,
    /// Elementwise multiplication, `y = a * b`.
    Mul = 2,
    /// Elementwise division, `y = a / b`.
    Div = 3,
    /// Elementwise floor-divide, `y = floor(a / b)`.
    FloorDivide = 4,
    /// Elementwise Python-style modulo, `y = a mod b`; the result takes
    /// the sign of `b`.
    ///
    /// NOTE(review): PyTorch's `torch.remainder` is the op whose result
    /// follows the divisor's sign, while `torch.fmod` follows the
    /// dividend's — confirm that the `Mod` / `Remainder` naming here
    /// matches what the kernels implement.
    Mod = 5,
    /// Elementwise C-style remainder, `y = remainder(a, b)`; the result
    /// takes the sign of `a`.
    Remainder = 6,
    /// Elementwise power, `y = a ** b` (a broadcast scalar exponent is
    /// OK).
    Pow = 7,
    /// Two-argument arctangent, `y = atan2(a, b)`.
    Atan2 = 8,
    /// Hypotenuse, `y = hypot(a, b) = sqrt(a² + b²)`.
    Hypot = 9,
    /// `a` with its sign bit replaced by the sign bit of `b`.
    Copysign = 10,
    /// The next representable value after `a` in the direction of `b`.
    Nextafter = 11,
    /// `y = a · 2^b`; in practice `b` is an integer broadcast as a
    /// scalar.
    Ldexp = 12,
    /// Elementwise minimum, `y = min(a, b)`, with IEEE 754 (NaN-aware)
    /// semantics.
    Minimum = 13,
    /// Elementwise maximum, `y = max(a, b)`, with IEEE 754 (NaN-aware)
    /// semantics.
    Maximum = 14,
    /// `y = fmin(a, b)` following PyTorch's `fmin`
    /// (NaN-propagating-from-other).
    Fmin = 15,
    /// `y = fmax(a, b)` following PyTorch's `fmax`
    /// (NaN-propagating-from-other).
    Fmax = 16,
    /// Equality test, `y = (a == b)`; returns bool. Vestigial slot —
    /// see [`BinaryCmpKind::Eq`].
    Eq = 17,
    /// Inequality test, `y = (a != b)`; returns bool. Vestigial slot —
    /// see [`BinaryCmpKind::Ne`].
    Ne = 18,
    /// Greater-than test, `y = (a > b)`; returns bool. Vestigial slot —
    /// see [`BinaryCmpKind::Gt`].
    Gt = 19,
    /// Greater-or-equal test, `y = (a >= b)`; returns bool. Vestigial
    /// slot — see [`BinaryCmpKind::Ge`].
    Ge = 20,
    /// Less-than test, `y = (a < b)`; returns bool. Vestigial slot —
    /// see [`BinaryCmpKind::Lt`].
    Lt = 21,
    /// Less-or-equal test, `y = (a <= b)`; returns bool. Vestigial
    /// slot — see [`BinaryCmpKind::Le`].
    Le = 22,
    /// Logical AND, `y = a && b`. Bool only.
    LogicalAnd = 23,
    /// Logical OR, `y = a || b`. Bool only.
    LogicalOr = 24,
    /// Logical XOR, `y = a ^ b`. Bool only.
    LogicalXor = 25,
    /// Bitwise AND, `y = a & b`. Integer only.
    BitwiseAnd = 26,
    /// Bitwise OR, `y = a | b`. Integer only.
    BitwiseOr = 27,
    /// Bitwise XOR, `y = a ^ b`. Integer only.
    BitwiseXor = 28,
    /// Left shift, `y = a << b`. Integer only.
    BitwiseLeftShift = 29,
    /// Right shift, `y = a >> b`. Integer only.
    BitwiseRightShift = 30,
    /// Linear interpolation, `y = a + (b - a) * weight`, with a
    /// broadcast scalar weight — the `torch.lerp` convention.
    Lerp = 31,
}
94
/// Discriminant for single-operand elementwise ops.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::UnaryElementwise`. The variant list is the
/// union of the unary elementwise ops offered by PyTorch
/// (`torch.<op>` / `torch.Tensor.<op>`) and JAX (`jax.numpy.<op>` /
/// `jax.lax.<op>`), together with the activation family from PyTorch
/// `nn.functional`.
///
/// [`Self::Neg`] is the Phase 3 unary trailblazer SKU. Several variants
/// state their own wiring status in their docs (for example
/// [`Self::PReLU`], [`Self::PowI`], [`Self::Step`], [`Self::Cast`] and
/// [`Self::Affine`]). A variant without such a note is a reserved
/// discriminant for the fanout sessions that ship the math (abs / sqrt /
/// exp / log / sin / …) and activation (relu / gelu / silu / …)
/// families — NOTE(review): confirm the current wired set against the
/// kernel crate.
///
/// Ops whose output dtype differs from the input dtype (`isnan`,
/// `isinf`, `isfinite`, `logical_not`) are reserved here, but they will
/// be routed through a future `UnaryToBoolPlan` (or similar) that has a
/// distinct output type, not through this enum's `UnaryPlan<T, N>`.
///
/// Parameterized activations (`leaky_relu(α)`, `elu(α)`,
/// `threshold(t, v)`, `hardshrink(λ)`, `softshrink(λ)`) are meant to
/// carry their parameters via a `UnaryParams` field on the descriptor,
/// which was left out of the trailblazer. Until then, the variants that
/// have kernels use the hardcoded constants stated in their own docs.
/// NOTE(review): [`Self::PowI`] documents a `params: [f32; 2]` slot on
/// `UnaryParamPlan` — confirm whether that is the mechanism meant here.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum UnaryKind {
    // ---- Category B: elementwise unary (math) — trivial ----
    /// Elementwise negation, `y = -x`. Trailblazer SKU.
    Neg = 0,
    /// Elementwise absolute value, `y = |x|`.
    Abs = 1,
    /// Sign of the input, `y = sign(x)`: `-1` / `0` / `+1`.
    Sign = 2,
    /// Elementwise reciprocal, `y = 1 / x`.
    Reciprocal = 3,
    /// Elementwise square, `y = x * x`.
    Square = 4,
    /// Elementwise cube, `y = x * x * x`.
    Cube = 5,

    // ---- Category B: roots ----
    /// Square root, `y = sqrt(x)`.
    Sqrt = 10,
    /// Reciprocal square root, `y = 1 / sqrt(x)`.
    Rsqrt = 11,
    /// Cube root, `y = cbrt(x)`.
    Cbrt = 12,

    // ---- Category B: exp / log family ----
    /// Natural exponential, `y = exp(x)`.
    Exp = 20,
    /// Base-2 exponential, `y = 2^x`.
    Exp2 = 21,
    /// `y = exp(x) - 1`.
    Expm1 = 22,
    /// Natural logarithm, `y = ln(x)`.
    Log = 23,
    /// Base-2 logarithm, `y = log_2(x)`.
    Log2 = 24,
    /// Base-10 logarithm, `y = log_10(x)`.
    Log10 = 25,
    /// `y = ln(1 + x)`.
    Log1p = 26,

    // ---- Category B: trig ----
    /// Sine, `y = sin(x)`.
    Sin = 30,
    /// Cosine, `y = cos(x)`.
    Cos = 31,
    /// Tangent, `y = tan(x)`.
    Tan = 32,
    /// Inverse sine, `y = asin(x)`.
    Asin = 33,
    /// Inverse cosine, `y = acos(x)`.
    Acos = 34,
    /// Inverse tangent, `y = atan(x)`.
    Atan = 35,

    // ---- Category B: hyperbolic ----
    /// Hyperbolic sine, `y = sinh(x)`.
    Sinh = 40,
    /// Hyperbolic cosine, `y = cosh(x)`.
    Cosh = 41,
    /// Hyperbolic tangent, `y = tanh(x)`.
    Tanh = 42,
    /// Inverse hyperbolic sine, `y = asinh(x)`.
    Asinh = 43,
    /// Inverse hyperbolic cosine, `y = acosh(x)`.
    Acosh = 44,
    /// Inverse hyperbolic tangent, `y = atanh(x)`.
    Atanh = 45,

    // ---- Category B: rounding ----
    /// Round toward negative infinity, `y = floor(x)`.
    Floor = 50,
    /// Round toward positive infinity, `y = ceil(x)`.
    Ceil = 51,
    /// `y = round(x)` using round-half-to-even (the PyTorch
    /// convention).
    Round = 52,
    /// Truncate toward zero, `y = trunc(x)`.
    Trunc = 53,
    /// Fractional part, `y = x - trunc(x)`; keeps the sign of `x`.
    Frac = 54,

    // ---- Category B: special functions ----
    /// Error function, `y = erf(x)`.
    Erf = 60,
    /// Complementary error function, `y = erfc(x) = 1 - erf(x)`.
    Erfc = 61,
    /// Inverse error function, `y = erfinv(x)`.
    Erfinv = 62,
    /// Log-gamma, `y = lgamma(x) = ln(|Γ(x)|)`.
    Lgamma = 63,
    /// Digamma, `y = digamma(x) = Γ'(x) / Γ(x)`.
    Digamma = 64,

    // ---- Category B: bitwise / integer (int-typed only) ----
    /// Bitwise NOT, `y = ~x` (integer dtypes).
    BitwiseNot = 70,
    /// Population count of set bits, `y = popcount(x)` (integer).
    Popcount = 71,
    /// Count of leading zeros, `y = clz(x)` (integer).
    Clz = 72,
    /// Count of trailing zeros, `y = ctz(x)` (integer).
    Ctz = 73,

    // ---- Category B': activations ----
    // Mostly unparameterized. The parameterized members of this block
    // either hardcode their constants or go through a dedicated plan;
    // see the per-variant docs.
    /// Rectified linear unit, `y = relu(x) = max(x, 0)`.
    Relu = 100,
    /// ERF-EXACT Gaussian Error Linear Unit,
    /// `y = gelu(x) = 0.5·x·(1+erf(x/√2))`. This is NOT the tanh
    /// approximation; that one is [`Self::GeluTanh`]. The sys-level
    /// `unary_gelu_erf_*` symbols are a bit-identical alias of the
    /// `unary_gelu_*` symbols this variant dispatches to.
    Gelu = 101,
    /// Tanh APPROXIMATION of gelu,
    /// `y = gelu_tanh(x) = 0.5·x·(1+tanh(√(2/π)·(x+0.044715·x³)))`.
    /// Differs from the erf-exact [`Self::Gelu`] by up to ~1e-4.
    GeluTanh = 102,
    /// `y = silu(x) = x · sigmoid(x)`, also called Swish-1.
    Silu = 103,
    /// `y = mish(x) = x · tanh(softplus(x))`.
    Mish = 104,
    /// Logistic sigmoid, `y = sigmoid(x) = 1 / (1 + exp(-x))`.
    Sigmoid = 105,
    /// Inverse of sigmoid, `y = logit(x) = log(x / (1 - x))`.
    Logit = 106,
    /// `y = softplus(x) = ln(1 + exp(x))`.
    Softplus = 107,
    /// `y = softsign(x) = x / (1 + |x|)`.
    Softsign = 108,
    /// `y = tanhshrink(x) = x - tanh(x)`.
    Tanhshrink = 109,
    /// `y = relu6(x) = min(max(x, 0), 6)`.
    Relu6 = 110,
    /// `y = hardswish(x)`, a piecewise-linear approximation of swish.
    Hardswish = 111,
    /// `y = hardsigmoid(x)`, a piecewise-linear approximation of
    /// sigmoid.
    Hardsigmoid = 112,
    /// `y = hardtanh(x, -1, +1)`, a piecewise-linear clamp.
    Hardtanh = 113,
    /// Scaled exponential linear unit, `y = selu(x)`.
    Selu = 114,
    /// `y = leaky_relu(x) = x if x > 0 else α·x`. The current bespoke
    /// kernel hardcodes α = 0.01; the op will be re-emitted as a fanout
    /// from a parameterized-unary plan once that infrastructure lands.
    LeakyRelu = 115,
    /// `y = elu(x) = x if x > 0 else α·(exp(x) - 1)`. The current
    /// bespoke kernel hardcodes α = 1.0; the parameterization story is
    /// the same as for `LeakyRelu`.
    Elu = 116,
    /// `y = hardshrink(x) = x if |x| > λ else 0`. The current bespoke
    /// kernel hardcodes λ = 0.5; the parameterization story is the same
    /// as for `LeakyRelu`.
    Hardshrink = 117,
    /// `y = softshrink(x) = x - λ if x > λ; x + λ if x < -λ; else 0`.
    /// The current bespoke kernel hardcodes λ = 0.5; the
    /// parameterization story is the same as for `LeakyRelu`.
    Softshrink = 118,
    /// Reserved: `threshold(x; t, v) = x if x > t else v`. It has two
    /// scalar parameters and therefore needs the parameterized-unary
    /// plan; not wired yet.
    Threshold = 119,
    /// `prelu(x; α) = x if x > 0 else α·x`, where α is a per-channel
    /// learnable vector (or a single scalar). Because α is a tensor
    /// operand rather than a scalar parameter, this op has its own plan
    /// shape (`PReluPlan` / `PReluBackwardPlan`). Wired in Milestone
    /// 5.3.
    PReLU = 120,
    /// `powi(x; n) = x^n` for a fixed runtime *integer* exponent `n`.
    /// Kept separate from the generic [`BinaryKind::Pow`] (which takes
    /// an f32 exponent tensor) because the integer-only path can use
    /// power-by-squaring: that is faster than `__expf(n · __logf(x))`
    /// and is also well-defined for negative `x` (real
    /// `pow(-1.5, 2) = 2.25`, no NaN). The exponent travels in slot 0
    /// of `params: [f32; 2]` after a host-side cast (`n as f32`); slot
    /// 1 is unused. Reasonable |n| values round-trip through f32
    /// exactly (≤ 2^24). Phase 12.1 wires `{f32, f16, bf16, f64}`
    /// through `UnaryParamPlan`.
    PowI = 121,
    /// Heaviside step function, `y = step(x) = 1 if x > 0 else 0`.
    /// Both zeros map to 0 (`step(0) = 0`, `step(-0.0) = 0`, since
    /// `x > 0` is false there) and NaN maps to 0 (`NaN > 0` is false),
    /// matching PyTorch's `heaviside(x, values=0)` for the `>` branch.
    /// Wires the Phase 31 `unary_step_*` kernels.
    Step = 122,

    // ---- Category B: dtype / scalar-shape ops ----
    /// Dtype conversion, `y = (TOut) x`. Input and output element types
    /// differ, so the op runs through its own `CastPlan` instead of the
    /// same-dtype `UnaryPlan<T, N>`. The discriminant is kept here so
    /// telemetry / SKU tagging stays consistent with the rest of the
    /// unary family. Wired from `fuel-cuda-kernels/cast.cu`.
    Cast = 130,
    /// Fused affine (multiply-add), `y = a * x + b`, with scalar
    /// parameters `a` / `b`. Input and output share a dtype, but the
    /// two scalar parameters mean it needs its own `AffinePlan` (the
    /// unified `UnaryPlan<T, N>` carries no kernel parameters). Wired
    /// from `fuel-cuda-kernels/affine.cu`.
    Affine = 131,
}
313
/// Discriminant for three-operand elementwise ops.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::TernaryElementwise`. Only ops whose inputs
/// and output all share one dtype belong to the same-dtype
/// `TernaryPlan<T, N>`. [`Self::Where`] (a bool cond plus two value
/// tensors) is reserved here but will get its own plan shape in a
/// future session rather than being wired through that plan.
///
/// The Phase 3 ternary trailblazer SKU — [`Self::Clamp`] on `f32` — is
/// the only thing wired today. The other ops and the non-f32 dtypes
/// arrive in fanout sessions.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TernaryKind {
    /// Clamp `x` into `[lo, hi]`: `y = min(max(x, lo), hi)`.
    /// Trailblazer.
    Clamp = 0,
    /// Fused multiply-add, `y = a * b + c`. Equivalent to PyTorch
    /// `torch.addcmul(c, a, b)` with value = 1.
    Fma = 1,
    /// PyTorch `addcmul`: `y = self + value * t1 * t2`. Reserved for a
    /// future parameterized-ternary path, because the scalar `value` is
    /// a runtime parameter and not a tensor operand.
    Addcmul = 2,
    /// PyTorch `addcdiv`: `y = self + value * t1 / t2`. The
    /// parameterization story is the same as for `Addcmul`.
    Addcdiv = 3,
    /// Element-select, `y = cond ? a : b`. The inputs have
    /// heterogeneous dtypes (cond is bool; a / b match the output
    /// type), so this op needs its own plan shape and won't be wired
    /// through the same-dtype `TernaryPlan`.
    Where = 4,
}
347
/// Discriminant for gated activations (category C').
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::GatedActivation`. Every variant has the
/// same structure: the input `x` is split along `split_dim` into two
/// halves `(a, b)` and the output is `y = a · gate(b)`. Only the `gate`
/// function changes from variant to variant.
///
/// FW + BW are wired today for `{Glu, ReGlu, SwiGlu, GeGlu} × {f32,
/// f16, bf16, f64}`. SwiGLU was the trailblazer (highest LLM
/// relevance).
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GatedActivationKind {
    /// Sigmoid gate, `y = a · sigmoid(b)` — PyTorch
    /// `torch.nn.functional.glu`.
    Glu = 0,
    /// ReLU gate, `y = a · relu(b)`.
    ReGlu = 1,
    /// SiLU gate, `y = a · silu(b) = a · b · sigmoid(b)` — Llama /
    /// Mistral / Gemma.
    SwiGlu = 2,
    /// GELU gate, `y = a · gelu(b)` (exact, erf-based).
    GeGlu = 3,
}
371
/// How the pad region is filled for [`ShapeLayoutKind::Pad`].
///
/// The Phase 3 trailblazer wires [`Self::Constant`] only. Reflect /
/// Replicate / Circular come in fanout sessions; each one alters the
/// kernel body's "what value goes in the pad region" branch while the
/// plan shape stays the same.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PadMode {
    /// Fill with a constant value (`PadDescriptor::value`).
    Constant = 0,
    /// Mirror the input across the boundary, without duplicating the
    /// edge.
    Reflect = 1,
    /// Extend the boundary value into the pad region.
    Replicate = 2,
    /// Wrap around to the opposite side (also known as "circular").
    Circular = 3,
}
391
/// Discriminant for shape / layout ops (Category N).
///
/// Used to tag the kernel SKU for telemetry and autotuner-cache keys.
/// The descriptor / args shapes of these ops differ too much for a
/// single `ShapeLayoutPlan<T, N>`, so today every variant has a Plan
/// type of its own (PadPlan, ConcatPlan, …). This enum exists so that
/// all of them fill `KernelSku::op` from one shared discriminant
/// space.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ShapeLayoutKind {
    /// Phase 3 trailblazer:
    /// `F.pad(x, pad, mode='constant', value=v)`.
    Pad = 0,
    /// Reserved. `torch.cat(tensors, dim)`, with variable-arity input.
    Concat = 1,
    /// Reserved. Materialized `torch.permute(x, dims)` (a strided view
    /// is materialized when needed).
    Permute = 2,
    /// Reserved. `x.repeat(...)` / `torch.tile(x, ...)`.
    Repeat = 3,
    /// Reserved. `torch.flip(x, dims)`, reversing along axes.
    Flip = 4,
    /// Reserved. `torch.roll(x, shifts, dims)`, shifting along axes.
    Roll = 5,
    /// Reserved. `torch.meshgrid(*tensors)`: N rank-1 → N rank-N.
    Meshgrid = 6,
    /// Writes a scalar constant into every element of an output tensor
    /// — `torch.full(shape, value)` / `Tensor.fill_(value)`. Wired from
    /// `fuel-cuda-kernels/fill.cu`.
    Fill = 7,
    /// Per-axis range write:
    /// `dest[start_0..end_0, ..., start_{N-1}..end_{N-1}] = source`
    /// (assign, not accumulate). Phase 13.1 trailblazer, driven by the
    /// Fuel team's persistent KV-cache append (autoregressive
    /// decoding). See `baracuda_kernels::WriteSlicePlan`.
    WriteSlice = 8,
    /// Materializes a strided tensor as contiguous
    /// (`torch.Tensor.contiguous`). Phase 13.2: removes the
    /// D2H→CPU contiguize→H2D fallback cliff for non-contiguous CUDA
    /// inputs. The kernel is byte-level and dtype-agnostic
    /// (sizeof-templated), covering every byte-aligned dtype; nibble
    /// (S4 / U4) ships behind a documented innermost-stride constraint.
    /// See `baracuda_kernels::ContiguizePlan`.
    Contiguize = 9,
    /// `torch.triu(input, diagonal)`: retains the upper triangular part
    /// of the last two dims of `input` and zeroes everything below the
    /// `diagonal`-th diagonal. Batch dims (everything before the last
    /// two) are masked independently. Phase 13.4 trailblazer, driven by
    /// the Fuel team's CPU-only triu/tril gap. See
    /// `baracuda_kernels::TriuPlan`.
    Triu = 10,
    /// `torch.tril(input, diagonal)`: retains the lower triangular part
    /// of the last two dims of `input` and zeroes everything above the
    /// `diagonal`-th diagonal. The sibling of [`Self::Triu`], with the
    /// predicate flipped. See `baracuda_kernels::TrilPlan`.
    Tril = 11,
}
448
/// Discriminant for reductions that return an index — Phase 4
/// (`ArgReducePlan`).
///
/// Kept apart from [`ReduceKind`] because the output dtype is i64 (an
/// index) and not the dtype of the input values. This
/// heterogeneous-output-dtype case has a plan shape of its own,
/// `ArgReducePlan<T, N>`.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ArgReduceKind {
    /// Position of the maximum along the reduced axis. On ties the
    /// first occurrence (smallest index) wins — the PyTorch convention.
    Argmax = 0,
    /// Position of the minimum along the reduced axis.
    Argmin = 1,
}
465
/// Discriminant for reductions — Phase 4 (Category E).
///
/// The output shape is not the input shape: the reduced axis collapses
/// to size 1 (keepdim convention). [`Self::Sum`] is the Phase 4
/// trailblazer; the other variants are reserved for fanout unless
/// their own docs describe a dedicated plan.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReduceKind {
    /// Sum over the reduced axis. Phase 4 trailblazer.
    Sum = 0,
    /// Arithmetic mean over the reduced axis.
    Mean = 1,
    /// Largest value along the reduced axis.
    Max = 2,
    /// Smallest value along the reduced axis.
    Min = 3,
    /// Product over the reduced axis.
    Prod = 4,
    /// Sample (Bessel-corrected) variance along the reduced axis.
    Var = 5,
    /// Sample standard deviation along the reduced axis.
    Std = 6,
    /// Euclidean norm `||x||_2` along the reduced axis.
    Norm2 = 7,
    /// `argmax` along the reduced axis. It returns indices, i.e. a
    /// different output dtype, so it needs a separate plan shape; the
    /// slot is reserved here. Index-returning reductions are tagged
    /// with [`ArgReduceKind`] (`ArgReducePlan`).
    Argmax = 8,
    /// `argmin` along the reduced axis. Needs a separate plan shape —
    /// see [`ArgReduceKind`] (`ArgReducePlan`).
    Argmin = 9,
    /// `any` — logical OR along the reduced axis.
    Any = 10,
    /// `all` — logical AND along the reduced axis.
    All = 11,
    /// `logsumexp(x) = max(x) + log(sum(exp(x - max(x))))`, the
    /// numerically stable form of `log(sum(exp(x)))`.
    LogSumExp = 12,
    /// `trace(M) = sum(diag(M))`, the sum of the diagonal of a 2-D
    /// square matrix. It reduces *both* axes through the `i == i`
    /// constraint instead of a single reduce-axis, so it is dispatched
    /// by a dedicated `TracePlan` (separate from `ReducePlan`). The
    /// discriminant is kept here so telemetry / SKU tagging stays
    /// consistent with the rest of the reduction family.
    Trace = 13,
    /// `count_nonzero(x)` along the reduced axis. The output is i64
    /// (PyTorch `torch.count_nonzero` returns int64) whatever the input
    /// dtype, so this heterogeneous-output op is dispatched by a
    /// dedicated `CountReducePlan` (separate from `ReducePlan`). The
    /// discriminant is kept here so telemetry / SKU tagging stays
    /// consistent with the rest of the reduction family.
    CountNonzero = 14,
}
517
/// Discriminant for broadcast-reverse reductions — `ReduceToPlan`.
///
/// This is the autograd primitive that reverses a forward
/// `BroadcastTo`: each output cell is the reduction of every input
/// cell that broadcasts TO it. It is kept apart from [`ReduceKind`]
/// because a single launch collapses an arbitrary *set* of axes (every
/// dim where `output_shape[d] == 1` while `input_shape[d] != 1`)
/// instead of one `reduce_axis`, so it is dispatched by its own plan
/// shape (`ReduceToPlan<T, N>` in `baracuda-kernels`).
///
/// For the same logical op the discriminant equals the one in
/// [`ReduceKind`], which keeps `KernelSku::op` tags consistent across
/// the reduction family. That is why value 1 is skipped: no
/// broadcast-reverse Mean exists.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReduceToOp {
    /// Sum across the broadcast set. For an empty reduce set (any
    /// reduced `input_shape[d] == 0`) the identity is `0`.
    Sum = 0,
    /// Maximum across the broadcast set. For an empty reduce set the
    /// identity is `-FLT_MAX` / `-DBL_MAX` under the kernel's `AccMax`
    /// policy, i.e. the most-negative *finite* value for f32 / f64
    /// outputs. With f16 / bf16 the f32 identity overflows the storage
    /// dtype at the final narrowing store and ends up as `-inf`.
    Max = 2,
    /// Minimum across the broadcast set. For an empty reduce set the
    /// identity is `+FLT_MAX` / `+DBL_MAX` under the kernel's `AccMin`
    /// policy, i.e. the most-positive *finite* value for f32 / f64
    /// outputs, and `+inf` for f16 / bf16 (the same narrowing overflow
    /// as [`Self::Max`]).
    Min = 3,
    /// Product across the broadcast set. For an empty reduce set the
    /// identity is `1`.
    Prod = 4,
}
553
/// Discriminant for the softmax family — category H of the
/// comprehensive plan.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Softmax`. Each variant is a
/// length-preserving transform along one axis: like scans and unlike
/// reductions, the output shape equals the input shape.
///
/// Wired today: `{Softmax, LogSoftmax} × {f32, f16, bf16, f64}`, FW +
/// BW. `GumbelSoftmax` (which needs RNG state from Phase 4 random) and
/// `Sparsemax` (which has a different gradient — a projection onto the
/// simplex) are reserved-but-deferred.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SoftmaxKind {
    /// Numerically stable softmax:
    /// `y[k] = exp(x[k] - max(x)) / Σ_j exp(x[j] - max(x))`.
    Softmax = 0,
    /// Log-domain softmax, `y[k] = x[k] - logsumexp(x)`; also stable.
    /// The output is the elementwise log of `Softmax(x)`.
    LogSoftmax = 1,
    /// Reserved. `y = (x + Gumbel(0,1)) / τ → softmax`.
    GumbelSoftmax = 2,
    /// Reserved. `y = ProjSimplex(x)`; its gradient differs from
    /// softmax's.
    Sparsemax = 3,
}
581
/// Discriminant for scans (associative prefix ops) — category F of the
/// comprehensive plan.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Scan`. A scan preserves length along the
/// scan axis, so the output shape equals the input shape (reductions,
/// by contrast, collapse the axis to size 1). Scans are inclusive by
/// default, following PyTorch: `y[i] = op(x[0], x[1], …, x[i])`. The
/// descriptor's `reverse` flag selects the direction.
///
/// Wired today: `{Cumsum} × {f32, f16, bf16, f64}` (FW + BW), the scan
/// trailblazer. Cumprod / Cummax / Cummin arrive in fanout.
/// LogCumsumExp and the JAX-style generic `associative_scan` are
/// reserved-but-deferred (numerics / generic-functor work).
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ScanKind {
    /// Inclusive prefix sum, `y[i] = Σ_{j ≤ i} x[j]`.
    Cumsum = 0,
    /// Inclusive prefix product, `y[i] = ∏_{j ≤ i} x[j]`.
    Cumprod = 1,
    /// Running maximum, `y[i] = max(x[0..=i])`.
    Cummax = 2,
    /// Running minimum, `y[i] = min(x[0..=i])`.
    Cummin = 3,
    /// Reserved. `y[i] = log(Σ_{j ≤ i} exp(x[j]))`, made numerically
    /// stable by running max subtraction.
    LogCumsumExp = 4,
}
612
/// Discriminant for binary comparison ops.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::BinaryElementwise` and the SKU belongs to
/// the **comparison family**. The family is split off from
/// [`BinaryKind`] because its output dtype is always `u8` — the
/// PyTorch / NumPy convention of a bool stored in 1 byte, 0 = false,
/// 1 = true — whatever the input element type is.
///
/// The Phase 3 comparison trailblazer, [`Self::Eq`] on `f32`, is the
/// only thing wired today. The remaining variants are reserved
/// discriminants for the fanout sessions.
///
/// Why this is a separate enum and not a reuse of [`BinaryKind`]: the
/// dispatch shape is different. These ops produce a dtype other than
/// the one they consume, so they need a Plan type of their own
/// (`BinaryCmpPlan<T, N>`, whose output is `TensorMut<u8>`) in place of
/// `BinaryPlan<T, N>` (whose output is `TensorMut<T>`). The reserved
/// Eq / Ne / Gt / Ge / Lt / Le slots in `BinaryKind` are vestigial and
/// will never be wired into the same-dtype binary path.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinaryCmpKind {
    /// Elementwise equality, `y = (a == b)`. Trailblazer SKU.
    Eq = 0,
    /// Elementwise inequality, `y = (a != b)`.
    Ne = 1,
    /// Elementwise greater-than, `y = (a > b)`.
    Gt = 2,
    /// Elementwise greater-or-equal, `y = (a >= b)`.
    Ge = 3,
    /// Elementwise less-than, `y = (a < b)`.
    Lt = 4,
    /// Elementwise less-or-equal, `y = (a <= b)`.
    Le = 5,
}
650
/// Discriminant for normalization ops — category G of the
/// comprehensive plan.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Normalization`. What separates the variants
/// is the set of axes reduced for the per-row statistics and the way
/// the affine parameters (gamma / beta) are indexed.
///
/// Wired today: `{RMSNorm, LayerNorm, BatchNorm, GroupNorm,
/// InstanceNorm} × {f32, f16, bf16, f64}`, FW + BW. RMSNorm and
/// LayerNorm support **multi-axis normalization** through a bitmask
/// (PyTorch's `normalized_shape`, which must be a suffix of the input
/// shape). InstanceNorm is a thin wrapper over GroupNorm with
/// `num_groups == c_extent` and shares its kernel symbols.
///
/// For the trailblazer, BatchNorm is **training-mode-only**: it
/// derives per-channel stats from the batch and saves them for BW.
/// Inference mode — using running statistics, which reduces to a
/// per-channel affine multiply — is reserved for a follow-up.
/// `WeightNorm` (a parameterization rather than a plain op) and
/// `LocalResponseNorm` (rarely used today) are explicitly deferred.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NormalizationKind {
    /// `y = x / sqrt(mean(x², over norm_axes) + eps) * gamma`. The
    /// block-pre-norm of Llama / Mistral / Gemma. Trailblazer SKU.
    RMSNorm = 0,
    /// `y = (x - mean) / sqrt(var + eps) * gamma + beta`, i.e.
    /// PyTorch's `torch.nn.LayerNorm`, using biased / "population"
    /// variance.
    LayerNorm = 1,
    /// Statistics per group of channels: `y[n, c, ...] = (x[n, c, ...]
    /// - mean[n, g]) / sqrt(var[n, g] + eps) * gamma[c] + beta[c]`
    /// with `g = c / (C / num_groups)`. PyTorch `torch.nn.GroupNorm`.
    GroupNorm = 2,
    /// Statistics per channel, taken across batch + spatial.
    /// Training-mode only: it saves `(saved_mean, saved_rstd)` of shape
    /// `[C]`. Inference mode (running stats) is deferred. PyTorch
    /// `torch.nn.BatchNormNd`.
    BatchNorm = 3,
    /// Statistics per `(sample, channel)`, taken across spatial only.
    /// PyTorch `torch.nn.InstanceNormNd`. The same as GroupNorm with
    /// `num_groups == num_channels`, and it uses the same kernel
    /// symbols.
    InstanceNorm = 4,
}
694
/// Discriminant for loss ops — category R of the comprehensive plan.
///
/// Written as a `u16` into [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Loss`. Argument shapes differ between
/// variants (MSE / BCE / KLDiv take two same-dtype tensor inputs; NLL /
/// CrossEntropy take a `T` input plus an `i64` target index tensor), so
/// today each variant has a Plan type of its own. All of them share
/// the [`LossReduction`] enum, which selects the per-cell / mean / sum
/// output shape.
///
/// Wired today: `{Mse, Nll, CrossEntropy, Bce, KlDiv} × {f32, f16,
/// bf16, f64}`, FW + BW. [`Self::HingeEmbedding`], [`Self::L1`],
/// [`Self::SmoothL1`], [`Self::MarginRanking`],
/// [`Self::TripletMargin`], [`Self::Ctc`] and [`Self::PoissonNll`] are
/// reserved discriminants for future fanout.
#[repr(u16)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LossKind {
    /// Mean squared error, `y = mean((pred - target)²)` (or sum /
    /// per-cell). PyTorch `torch.nn.functional.mse_loss`.
    Mse = 0,
    /// Negative log-likelihood, `y = -mean(input[target_idx[i]])`
    /// taken along the feature axis. PyTorch
    /// `torch.nn.functional.nll_loss`. Heterogeneous-dtype: the input
    /// is `T`, the target is `i64`.
    Nll = 1,
    /// `y = NLLLoss(LogSoftmax(input), target)`, fused for numerical
    /// stability. PyTorch `torch.nn.functional.cross_entropy`. Wired
    /// today for a class-index target only (`i64`); soft-target CE is
    /// reserved.
    ///
    /// NOTE(review): [`CrossEntropyTargetKind::SoftProb`] is defined in
    /// this module — confirm whether soft-target CE is still unwired.
    CrossEntropy = 2,
    /// Binary cross-entropy,
    /// `y = -mean(target·log(pred) + (1-target)·log(1-pred))`. PyTorch
    /// `torch.nn.functional.binary_cross_entropy`. The caller must
    /// ensure pred ∈ (0, 1).
    Bce = 3,
    /// KL divergence, `y = mean(target·(log(target) - input))`.
    /// PyTorch `torch.nn.functional.kl_div`, under the "input is
    /// log-prob" convention.
    KlDiv = 4,
    /// Mean absolute error, `y = mean(|pred - target|)` (or sum /
    /// per-cell). PyTorch `torch.nn.functional.l1_loss`.
    L1 = 5,
    /// Smooth L1, also called "Huber-with-β". PyTorch
    /// `torch.nn.functional.smooth_l1_loss`.
    SmoothL1 = 6,
    /// `y = mean(input if t==1 else max(0, margin - input))`. PyTorch
    /// `torch.nn.functional.hinge_embedding_loss`. Heterogeneous-dtype:
    /// the input is `T`, the target is `i64` (±1).
    HingeEmbedding = 7,
    /// `y = mean(max(0, -t · (x1 - x2) + margin))`. PyTorch
    /// `torch.nn.functional.margin_ranking_loss`. The target `t` is `T`
    /// (±1).
    MarginRanking = 8,
    /// `y = mean(max(0, ||a-p||_p - ||a-n||_p + margin))`. PyTorch
    /// `torch.nn.functional.triplet_margin_loss`. The input is 2-D,
    /// `[N, D]`.
    TripletMargin = 9,
    /// Reserved: `torch.nn.functional.ctc_loss`.
    Ctc = 10,
    /// `y = mean(exp(input) - target · input)` (the default,
    /// `log_input=true`). PyTorch
    /// `torch.nn.functional.poisson_nll_loss`.
    PoissonNll = 11,
    /// Huber loss — PyTorch `torch.nn.functional.huber_loss`, which is
    /// separate from SmoothL1.
    Huber = 12,
    /// BCE on raw logits, computed in a numerically stable way.
    /// PyTorch `torch.nn.functional.binary_cross_entropy_with_logits`.
    BceWithLogits = 13,
    /// Gaussian negative log-likelihood. PyTorch
    /// `torch.nn.GaussianNLLLoss`.
    GaussianNll = 14,
    /// `y = (1 - cos(x1, x2))` when `t==1`, otherwise
    /// `max(0, cos(x1, x2) - margin)`; the mean is taken afterwards.
    /// PyTorch `torch.nn.functional.cosine_embedding_loss`. The input
    /// is 2-D, `[N, D]`; the target is `T` (±1.0).
    CosineEmbedding = 15,
    /// `y = mean_i Σ_{j != t_i} max(0, margin - input[i, t_i] + input[i, j])^p / C`.
    /// PyTorch `torch.nn.functional.multi_margin_loss`. The input is
    /// `[N, C]`; the target is `[N]` of `i64` class indices.
    MultiMargin = 16,
    /// Multi-label margin loss. PyTorch
    /// `torch.nn.functional.multilabel_margin_loss`. The input is
    /// `[N, C]`; the target is `[N, C]` of `i64` (positive class
    /// indices, then a -1 padding sentinel).
    MultilabelMargin = 17,
    /// `y = mean(-mean_c(target·log(sigmoid(x)) + (1-target)·log(1-sigmoid(x))))`.
    /// PyTorch `torch.nn.functional.multilabel_soft_margin_loss`. The
    /// input is `[N, C]`; the target is `[N, C]` of `T`.
    MultilabelSoftMargin = 18,
    /// Fused Linear Cross-Entropy: `loss = CE(input @ weight^T, target)`
    /// computed **without** materializing the `[BT, V]` logits tensor.
    /// The projection GEMM and the cross-entropy reduction run together
    /// inside a chunked outer loop. `grad_input` and `grad_weight` are
    /// produced directly during the forward pass; the backward call
    /// only multiplies them by the upstream `dy` scalar. The algorithm
    /// is LinkedIn Liger-Kernel's
    /// (`liger_kernel/ops/fused_linear_cross_entropy.py`). Streaming
    /// the logits in `chunk_size`-row tiles saves ~5-10 GiB at
    /// `vocab=128K, BT=16K` (Llama-3-class).
    FusedLinearCrossEntropy = 19,
}
789
/// Target-tensor format accepted by the CrossEntropy loss.
///
/// PyTorch supports two target encodings and this tag chooses between
/// them: hard class indices (`i64[N]`) or soft probabilities
/// (`T[N, C]`, the form used for label smoothing / distillation).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum CrossEntropyTargetKind {
    /// Hard labels — an `i64[N]` tensor of class indices.
    ClassIndex = 0,
    /// Soft labels — a `T[N, C]` probability tensor whose dtype matches
    /// the input's.
    SoftProb = 1,
}
801
/// How a [`LossKind`] plan collapses its per-cell loss surface.
///
/// Mirrors PyTorch's `reduction` parameter: the mode decides both the
/// output shape and the scalar scaling applied at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum LossReduction {
    /// No reduction — the output is per-cell and has the same shape as
    /// the loss surface.
    None = 0,
    /// Scalar output — the per-cell terms are summed and then divided
    /// by the element count.
    Mean = 1,
    /// Scalar output — the per-cell terms are summed, with no divide.
    Sum = 2,
}
814
/// Random / sampling op discriminant.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Random`. Phase 4.5 wires:
/// - [`Self::Uniform`] (f32, f64) — `y ~ U(low, high)` via cuRAND.
/// - [`Self::Normal`] (f32, f64) — `y ~ N(mean, std)` via cuRAND.
/// - [`Self::Bernoulli`] (Bool output) — `y = (rand < p) ? 1 : 0` via
///   cuRAND uniform + custom threshold kernel.
///
/// Phase 46 additionally wires [`Self::Multinomial`] (the sort-free
/// Top-K / Top-P / Min-P samplers — see the variant doc).
///
/// Randint / exponential / gamma / quasi-random have no discriminant
/// in this enum yet; they are left for future milestones. The enum is
/// `#[non_exhaustive]`, so adding variants for them later is not a
/// breaking change for downstream `match`es.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum RandomKind {
    /// `y[i] ~ U(low, high)` — uniform on the half-open interval. Plan
    /// descriptor `param1 = low`, `param2 = high`.
    Uniform = 0,
    /// `y[i] ~ N(mean, std)` — Gaussian. Plan descriptor
    /// `param1 = mean`, `param2 = stddev`.
    Normal = 1,
    /// `y[i] = 1 if uniform < p else 0`, Bool output. Plan descriptor
    /// `param1 = p`. `param2` ignored.
    Bernoulli = 2,
    /// `y[b] = sample one cell from row probs[b, :]` using inverse-CDF
    /// sampling. Phase 46 wires the FlashInfer sort-free Top-K /
    /// Top-P / Min-P / combined Top-K + Top-P samplers under this
    /// discriminant via the `TopKTopPSamplingPlan` in baracuda-kernels.
    Multinomial = 3,
}
845
/// Linear-algebra (dense) op discriminant — covers the cuSOLVER family
/// first shipped in Milestone 6.3 and extended by later milestones.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Linalg`. Milestone 6.3 wired the four
/// canonical PyTorch / JAX dense linalg ops:
///
/// - [`Self::Cholesky`] — `A = L · L^T` (symmetric positive-definite).
///   Batched via `cusolverDnSpotrfBatched` / `cusolverDnDpotrfBatched`.
/// - [`Self::Lu`] — `P · A = L · U`. Batched via
///   `cusolverDnSgetrfBatched` / `cusolverDnDgetrfBatched`.
/// - [`Self::Qr`] — `A = Q · R`. This discriminant is the 2-D
///   (non-batched) path; batched inputs use [`Self::BatchedQr`].
/// - [`Self::Svd`] — `A = U · diag(S) · V^T`. This discriminant is the
///   2-D path; batched inputs use [`Self::BatchedSvd`] or
///   [`Self::BatchedSvda`].
///
/// Later milestones, per the variant docs below:
///
/// - Milestone 6.9 — [`Self::Inverse`], [`Self::Solve`].
/// - Milestone 6.11 — [`Self::LeastSquares`], [`Self::BatchedQr`],
///   [`Self::BatchedSvd`].
/// - Milestone 6.12 — [`Self::Eig`], [`Self::Eigh`].
/// - Milestone 6.14 — [`Self::BatchedOrmqr`],
///   [`Self::BatchedQrMaterialize`].
/// - Milestone 6.15 — [`Self::BatchedSvda`].
/// - Milestone 6.17 — [`Self::BatchedOrmqrWy`].
///
/// [`Self::MatrixExp`] is the only discriminant still marked reserved.
///
/// Dtype coverage for the Milestone 6.3 factorizations is `f32` +
/// `f64` — cuSOLVER's dense API does not support `f16` / `bf16` for
/// these factorizations. The eigen-decomposition variants also mention
/// complex eigenvalues / `{C,Z}heevd`; see their docs for details.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum LinalgKind {
    /// Cholesky factorization `A = L · L^T` (lower) or `A = U^T · U`
    /// (upper). Input must be symmetric positive-definite.
    Cholesky = 0,
    /// LU factorization with partial pivoting `P · A = L · U`. Returns
    /// the packed `LU` factors plus an `i32` pivot vector.
    Lu = 1,
    /// QR factorization `A = Q · R`. Computes full `Q` (`[M, M]`) and
    /// the upper-triangular `R` (`[M, N]`) via `geqrf` + `ormqr`.
    Qr = 2,
    /// Singular value decomposition `A = U · diag(S) · V^T`. cuSOLVER
    /// 2-D only; `full_matrices` controls whether `U`/`V^T` are full
    /// (`[M,M]` / `[N,N]`) or thin (`[M,K]` / `[K,N]`) where
    /// `K = min(M, N)`.
    Svd = 3,
    /// Matrix inverse `A^{-1}` via `getrf` + `getrs` over an identity
    /// RHS. Wired in Milestone 6.9.
    Inverse = 4,
    /// General (non-symmetric) eigen-decomposition `A · v = λ · v`. Wired
    /// via `cusolverDnXgeev` in Milestone 6.12. Always emits complex
    /// eigenvalues (and optional left / right complex eigenvectors).
    Eig = 5,
    /// Linear solve `A · X = B` via `getrf` + `getrs`. Wired in
    /// Milestone 6.9.
    Solve = 6,
    /// Least-squares solve `min ||A·x - b||²` via cuSOLVER's
    /// mixed-precision iterative-refinement `_gels` routine. Wired in
    /// Milestone 6.11.
    LeastSquares = 7,
    /// Reserved — matrix exponential / matrix functions.
    MatrixExp = 8,
    /// Batched QR factorization `A_b = Q_b · R_b` via
    /// `cusolverDn*geqrfBatched`. Wired in Milestone 6.11.
    BatchedQr = 9,
    /// Batched SVD via Jacobi `cusolverDn*gesvdjBatched`. Wired in
    /// Milestone 6.11.
    BatchedSvd = 10,
    /// Symmetric / Hermitian eigen-decomposition `A · v = λ · v` (real
    /// eigenvalues). Wired via `cusolverDn{S,D}syevd` /
    /// `cusolverDn{C,Z}heevd` in Milestone 6.12.
    Eigh = 11,
    /// Rectangular batched approximate-SVD via cuSOLVER's
    /// `gesvdaStridedBatched`. Unlike [`Self::BatchedSvd`] (which is
    /// square-only Jacobi), this routine accepts arbitrary `m × n` per
    /// batch slot, uses element-strides between slots, and reports per-
    /// slot residual Frobenius norms to a host array. Wired in
    /// Milestone 6.15.
    BatchedSvda = 12,
    /// Bespoke batched-`ormqr` — applies the implicit `Q` from a
    /// [`Self::BatchedQr`] packed output to a batch of matrices `C`,
    /// all slots fused into one CUDA launch. cuSOLVER's `ormqr` is
    /// non-batched, so in the small-matrix regime where batched-QR is
    /// most useful the per-slot launch latency dominates; this bespoke
    /// kernel amortizes one launch over the whole batch. Side = Left,
    /// op ∈ {N, T} in the trailblazer (Right + complex variants
    /// deferred). Wired in Milestone 6.14.
    BatchedOrmqr = 13,
    /// Bespoke "materialize dense Q and R from batched-`geqrf` packed
    /// output". Tiny upper-triangle-copy kernel for R; identity-stage
    /// + [`Self::BatchedOrmqr`] for Q. Wired in Milestone 6.14 as the
    /// consumer of `BatchedOrmqrPlan`.
    BatchedQrMaterialize = 14,
    /// WY-blocked batched-`ormqr` — applies the implicit `Q` (or `Q^T`)
    /// from a [`Self::BatchedQr`] packed output to a batch of matrices
    /// `C` at GEMM-rates by fusing groups of `nb` consecutive Householder
    /// reflectors into a block reflector `(I - V·T·V^T)` and applying it
    /// via three cuBLAS strided-batched GEMMs per block. Sibling to
    /// [`Self::BatchedOrmqr`] (the reflector-by-reflector GEMV-rates
    /// variant); callers pick by problem size — WY wins decisively for
    /// `M, N > ~16`, the reflector kernel wins for tiny inputs.
    /// Side = Left, op ∈ {N, T} in the trailblazer. Wired in
    /// Milestone 6.17.
    BatchedOrmqrWy = 15,
}
942
/// Which triangle a triangular linalg op (Cholesky / triangular solve)
/// works with.
///
/// The tag says whether the factor occupies the lower or the upper
/// triangle of the in-place output matrix. The plan layer's
/// row-major-input → column-major-cuSOLVER adapter flips the tag before
/// handing the descriptor down to cuSOLVER: the bytes of a row-major
/// lower-triangular `L` are exactly the bytes of the column-major
/// upper-triangular `L^T`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum FillMode {
    /// Lower triangle — the usual PyTorch / scipy default.
    Lower = 0,
    /// Upper triangle.
    Upper = 1,
}
958
/// FFT-family op discriminant (Category U of the comprehensive plan).
///
/// This is the `u16` tag held in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Fft`. Milestone 6.4 wires six ops: the four
/// canonical PyTorch / JAX 1-D transforms (`fft` / `ifft` / `rfft` /
/// `irfft`) and the two index-permutation helpers (`fftshift` /
/// `ifftshift`).
///
/// Scope: the trailblazer is 1-D only. Multi-D transforms (`fft2`,
/// `fftn`, …) and arbitrary-axis transforms are left to fanout
/// sessions; they need no new cuFFT bindings, only additional
/// descriptor shape + plan glue.
///
/// Dtypes: `f32` (single precision) and `f64` (double precision) only,
/// because cuFFT's main API does not expose `f16` / `bf16` for native
/// transforms — callers that need reduced precision cast on either
/// side. Spectrum-domain tensors hold interleaved real/imag pairs as
/// [`crate::Complex32`] / [`crate::Complex64`].
///
/// Normalization: forward transforms stay unnormalized, inverse
/// transforms are scaled by `1/N` (PyTorch's `norm="backward"`
/// default). cuFFT itself returns `N · IFFT(x)`, so the plan layer
/// multiplies by `1/N` after the inverse exec.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum FftKind {
    /// Forward complex-to-complex transform `y = FFT(x)`, unnormalized
    /// (PyTorch `torch.fft.fft`). Input and output are both complex and
    /// share the shape `[batch, n]`.
    Fft = 0,
    /// Inverse complex-to-complex transform `y = IFFT(x)` (PyTorch
    /// `torch.fft.ifft`), scaled by `1/N` to match PyTorch's
    /// `norm="backward"`. Input and output are both complex
    /// `[batch, n]`.
    Ifft = 1,
    /// Forward real-to-complex transform `y = RFFT(x)`, unnormalized
    /// (PyTorch `torch.fft.rfft`). Real `[batch, n]` in, complex
    /// `[batch, n/2 + 1]` (Hermitian-half) out.
    Rfft = 2,
    /// Inverse complex-to-real transform `y = IRFFT(x, n)`, scaled by
    /// `1/N` (PyTorch `torch.fft.irfft`). Complex `[batch, n/2 + 1]`
    /// in, real `[batch, n]` out. The descriptor must carry the output
    /// length `n`: it cannot be recovered from the Hermitian-half input
    /// shape, since `2*(n/2)` and `2*(n/2)+1` both map to the same
    /// Hermitian-half length.
    Irfft = 3,
    /// `fftshift` — moves the zero-frequency component to the center
    /// of the spectrum. PyTorch `torch.fft.fftshift`, which matches
    /// NumPy's `np.fft.fftshift`.
    ///
    /// Same as `roll(x, n // 2)`, i.e.
    /// `y[i] = x[(i - n // 2) mod n] = x[(i + (n+1) // 2) mod n]`.
    ///
    /// Bit-exact: a pure index permutation with no arithmetic on the
    /// values.
    FftShift = 4,
    /// `ifftshift` — the true inverse of `fftshift`, so
    /// `ifftshift(fftshift(x)) == x` for any `n`. PyTorch
    /// `torch.fft.ifftshift`.
    ///
    /// Same as `roll(x, -(n // 2))`, i.e.
    /// `y[i] = x[(i + n // 2) mod n]`.
    ///
    /// With even `n` it coincides with `fftshift` (the `n/2` offset is
    /// self-inverse mod `n`); with odd `n` the two cyclic offsets
    /// differ by one cell. Bit-exact.
    IfftShift = 5,
}
1024
/// Convolution-family op discriminant — Category I from the
/// comprehensive plan.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Convolution`. The discriminant encodes:
///
/// - the exec path — forward, data-gradient, or filter-gradient;
/// - the dimensionality — 1-D / 2-D / 3-D;
/// - the flavor — plain, transposed (`ConvTranspose*`), or depthwise
///   ([`Self::DepthwiseConv2d`]);
/// - the im2col / col2im helpers ([`Self::Im2Col2d`],
///   [`Self::Im2Col1d`], [`Self::Col2Im1d`]).
///
/// Padding / stride / dilation and the group count live on the
/// per-plan descriptor and do not get their own enum slots.
///
/// Trailblazer wiring: `Conv2d` × `{f32, f64, f16, bf16}` (FW + BW
/// data + BW filter) via cuDNN. Variants whose docs say "Reserved" are
/// placeholder discriminants for fanout milestones. For every other
/// variant, consult the variant doc and the consuming Plan type for
/// its current status.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum ConvKind {
    /// 2-D convolution forward pass. PyTorch
    /// `torch.nn.functional.conv2d`. Trailblazer for Phase 7.
    Conv2d = 0,
    /// 2-D convolution data-gradient pass (computes `dx` from `dy`
    /// and the filter `w`). PyTorch's autograd-internal
    /// `conv2d_backward_input`.
    Conv2dBackwardData = 1,
    /// 2-D convolution filter-gradient pass (computes `dw` from `x`
    /// and `dy`). PyTorch's autograd-internal
    /// `conv2d_backward_weight`.
    Conv2dBackwardFilter = 2,
    /// 1-D convolution forward. Reserved.
    Conv1d = 3,
    /// 1-D convolution data-gradient. Reserved.
    Conv1dBackwardData = 4,
    /// 1-D convolution filter-gradient. Reserved.
    Conv1dBackwardFilter = 5,
    /// 3-D convolution forward. Reserved.
    Conv3d = 6,
    /// 3-D convolution data-gradient. Reserved.
    Conv3dBackwardData = 7,
    /// 3-D convolution filter-gradient. Reserved.
    Conv3dBackwardFilter = 8,
    /// 2-D transposed convolution (fractionally-strided / "deconv").
    /// Forward pass.
    ConvTranspose2d = 9,
    /// 2-D transposed convolution backward. Reserved — backward is
    /// dispatched through the same plan via `run_bw_data` / `run_dw`.
    /// The split discriminants are
    /// [`Self::ConvTranspose2dBackwardData`] and
    /// [`Self::ConvTranspose2dBackwardFilter`].
    ConvTranspose2dBackward = 10,
    /// Depthwise 2-D convolution (`groups == c_in`). Today callers
    /// route through the generic `Conv2dPlan` with `groups` set on
    /// the descriptor — cuDNN's `cudnnSetConvolutionGroupCount`
    /// detects the depthwise path automatically.
    DepthwiseConv2d = 11,
    /// `torch.nn.functional.unfold` — extract sliding windows. Reserved.
    Unfold = 12,
    /// `torch.nn.functional.fold` — inverse of unfold. Reserved.
    Fold = 13,
    /// 1-D transposed convolution forward.
    ConvTranspose1d = 14,
    /// 1-D transposed convolution data-gradient.
    ConvTranspose1dBackwardData = 15,
    /// 1-D transposed convolution filter-gradient.
    ConvTranspose1dBackwardFilter = 16,
    /// 2-D transposed convolution data-gradient.
    ConvTranspose2dBackwardData = 17,
    /// 2-D transposed convolution filter-gradient.
    ConvTranspose2dBackwardFilter = 18,
    /// 3-D transposed convolution forward.
    ConvTranspose3d = 19,
    /// 3-D transposed convolution data-gradient.
    ConvTranspose3dBackwardData = 20,
    /// 3-D transposed convolution filter-gradient.
    ConvTranspose3dBackwardFilter = 21,
    /// 2-D im2col — `torch.nn.functional.unfold` (Phase 19.3). Extracts
    /// sliding windows from an NCHW input into an
    /// `[N, C·kh·kw, h_out·w_out]` column-shaped matrix. Distinct from
    /// the reserved [`Self::Unfold`] discriminant for forward-source-
    /// compat; the 19.3 wiring routes through this discriminant.
    Im2Col2d = 22,
    /// 1-D im2col (NCL → `[N, C·kl, l_out]`).
    Im2Col1d = 23,
    /// 1-D col2im — inverse of [`Self::Im2Col1d`]. Atomic-add scatter.
    Col2Im1d = 24,
}
1109
/// Pooling-family op discriminant — Category J from the comprehensive
/// plan.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Pooling`. The cuDNN-backed variants map to
/// a distinct cuDNN pooling exec path (forward / backward) for one of
/// three pooling modes: max, average-include-padding,
/// average-exclude-padding. PyTorch's `nn.MaxPool2d` corresponds to
/// [`Self::MaxPool2d`]. PyTorch's `nn.AvgPool2d` defaults to
/// `count_include_pad=True`, which corresponds to
/// [`Self::AvgPool2dIncludePad`]; passing `count_include_pad=False`
/// selects [`Self::AvgPool2dExcludePad`].
///
/// Trailblazer wiring: `{MaxPool2d, AvgPool2d} × {f32, f64, f16, bf16}`
/// (FW + BW) via cuDNN. Later phases added dimension-specific variants
/// (see the variant docs): 1-D / 3-D pooling and adaptive pooling
/// (Phase 11.8), LP-pool (Phase 16.2, bespoke kernel) and
/// fractional-max-pool (Phase 16.3, bespoke kernel). The
/// dimension-less slots ([`Self::AvgPool1d`], [`Self::AvgPool3d`],
/// [`Self::AdaptiveMaxPool`], [`Self::AdaptiveAvgPool`],
/// [`Self::LpPool`], [`Self::FractionalMaxPool`]) are still marked
/// reserved.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum PoolKind {
    /// 2-D max-pool forward. PyTorch `torch.nn.functional.max_pool2d`.
    /// Trailblazer for Phase 7 Milestone 7.2.
    MaxPool2d = 0,
    /// 2-D max-pool backward (data-gradient). PyTorch's autograd-
    /// internal `max_pool2d_with_indices_backward`.
    MaxPool2dBackward = 1,
    /// 2-D average-pool forward, **count-include-padding** denominator
    /// (PyTorch default — `nn.AvgPool2d` has `count_include_pad=True`).
    /// Matches cuDNN's `*_COUNT_INCLUDE_PADDING` mode.
    AvgPool2dIncludePad = 2,
    /// 2-D average-pool backward, count-include-padding.
    AvgPool2dIncludePadBackward = 3,
    /// 2-D average-pool forward, **count-exclude-padding** denominator
    /// (`nn.AvgPool2d` with `count_include_pad=False`).
    AvgPool2dExcludePad = 4,
    /// 2-D average-pool backward, count-exclude-padding.
    AvgPool2dExcludePadBackward = 5,
    /// 1-D max-pool forward (Phase 11.8). NCL layout via cuDNN's Nd
    /// pool descriptor with `W = 1`.
    MaxPool1d = 6,
    /// 1-D average-pool forward. Reserved.
    AvgPool1d = 7,
    /// 3-D max-pool forward (Phase 11.8). NCDHW layout via cuDNN's
    /// Nd pool descriptor.
    MaxPool3d = 8,
    /// 3-D average-pool forward. Reserved.
    AvgPool3d = 9,
    /// `torch.nn.functional.adaptive_max_pool*` — reserved.
    AdaptiveMaxPool = 10,
    /// `torch.nn.functional.adaptive_avg_pool*` — reserved.
    AdaptiveAvgPool = 11,
    /// `torch.nn.functional.lp_pool*` — reserved.
    LpPool = 12,
    /// `torch.nn.functional.fractional_max_pool*` — reserved.
    FractionalMaxPool = 13,
    /// 1-D max-pool backward.
    MaxPool1dBackward = 14,
    /// 1-D average-pool backward (count-include-padding).
    AvgPool1dIncludePadBackward = 15,
    /// 1-D average-pool forward (count-include-padding — PyTorch
    /// default, `count_include_pad=True`).
    AvgPool1dIncludePad = 16,
    /// 1-D average-pool forward (count-exclude-padding — PyTorch
    /// `count_include_pad=False`).
    AvgPool1dExcludePad = 17,
    /// 1-D average-pool backward (count-exclude-padding).
    AvgPool1dExcludePadBackward = 18,
    /// 3-D max-pool backward.
    MaxPool3dBackward = 19,
    /// 3-D average-pool forward (count-include-padding).
    AvgPool3dIncludePad = 20,
    /// 3-D average-pool backward (count-include-padding).
    AvgPool3dIncludePadBackward = 21,
    /// 3-D average-pool forward (count-exclude-padding).
    AvgPool3dExcludePad = 22,
    /// 3-D average-pool backward (count-exclude-padding).
    AvgPool3dExcludePadBackward = 23,
    /// Adaptive average-pool 1-D (Phase 11.8 — cuDNN approximation).
    AdaptiveAvgPool1d = 24,
    /// Adaptive average-pool 1-D backward.
    AdaptiveAvgPool1dBackward = 25,
    /// Adaptive average-pool 2-D.
    AdaptiveAvgPool2d = 26,
    /// Adaptive average-pool 2-D backward.
    AdaptiveAvgPool2dBackward = 27,
    /// Adaptive average-pool 3-D.
    AdaptiveAvgPool3d = 28,
    /// Adaptive average-pool 3-D backward.
    AdaptiveAvgPool3dBackward = 29,
    /// Adaptive max-pool 1-D.
    AdaptiveMaxPool1d = 30,
    /// Adaptive max-pool 1-D backward.
    AdaptiveMaxPool1dBackward = 31,
    /// Adaptive max-pool 2-D.
    AdaptiveMaxPool2d = 32,
    /// Adaptive max-pool 2-D backward.
    AdaptiveMaxPool2dBackward = 33,
    /// Adaptive max-pool 3-D.
    AdaptiveMaxPool3d = 34,
    /// Adaptive max-pool 3-D backward.
    AdaptiveMaxPool3dBackward = 35,
    /// LP-pool 1-D (Phase 16.2 — bespoke fused kernel:
    /// `y = (Σ |x|^p)^(1/p)` over each pool window in one launch).
    LpPool1d = 36,
    /// LP-pool 2-D (Phase 16.2 — bespoke fused kernel).
    LpPool2d = 37,
    /// Fractional max-pool 2-D (Phase 16.3 — bespoke kernel; cuDNN has
    /// no fractional-pool primitive).
    FractionalMaxPool2d = 38,
    /// Fractional max-pool 3-D (Phase 16.3 — bespoke kernel).
    FractionalMaxPool3d = 39,
    /// LP-pool 1-D backward (Phase 16.2 — atomicAdd scatter from
    /// each output cell over its source window).
    LpPool1dBackward = 40,
    /// LP-pool 2-D backward (Phase 16.2 — atomicAdd scatter).
    LpPool2dBackward = 41,
    /// Fractional max-pool 2-D backward (Phase 16.3 — atomicAdd scatter
    /// from each output cell into `dx[indices[cell]]` via saved
    /// argmax). half / bf16 atomicAdd routes through atomicCAS.
    FractionalMaxPool2dBackward = 42,
    /// Fractional max-pool 3-D backward (Phase 16.3 — atomicAdd scatter).
    FractionalMaxPool3dBackward = 43,
}
1228
/// Attention-family op discriminant — Category K from the comprehensive
/// plan.
///
/// Stored as `u16` in [`crate::KernelSku::op`] when
/// `category == OpCategory::Attention`. Phase 6.1 wired the two
/// positional-encoding ops [`Self::Rope`] and [`Self::Alibi`]. Later
/// work, per the variant docs below, wired [`Self::KvCache`]
/// (Milestone 6.5), [`Self::FlashAttention`] (Milestone 6.6),
/// [`Self::HyperConnection`] (Phase 43), [`Self::SsdChunkScan`]
/// (Phase 50), [`Self::SelectiveScan`] (Phase 50b),
/// [`Self::BlockSparseAttention`] (Phase 54) and
/// [`Self::RingAttention`] (Phase 56). [`Self::Sdpa`] and
/// [`Self::PagedAttention`] are still marked reserved in their variant
/// docs.
///
/// Most variants in this family operate on rank-4 attention-shaped
/// tensors (typically `[batch, num_heads, seq_len, head_dim]` for RoPE
/// or `[batch, num_heads, query_len, key_len]` for attention scores /
/// ALiBi); [`Self::SelectiveScan`] is documented with rank-3 inputs.
/// Plan shapes differ between ops — the discriminant is here for
/// SKU-tagging uniformity, not for shared dispatch.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum AttentionKind {
    /// Rotary position embedding (Llama / Mistral / Gemma / Qwen / Phi).
    /// Rotates pairs of consecutive features `(2i, 2i+1)` of a
    /// `[B, H, S, D]` Q/K tensor by per-position angles
    /// `θ = pos · base^(-2i / D)`. Trailblazer for Phase 6.
    Rope = 0,
    /// Attention with Linear Biases (MPT / BLOOM). Adds the bias
    /// `slope[h] · (j - i)` to attention-score cell `(b, h, i, j)`.
    /// Linear (non-transcendental) FW; BW reduces over the
    /// score-shape axes to recover `dslope[h]`.
    Alibi = 1,
    /// Scaled dot-product attention — reserved.
    Sdpa = 2,
    /// FlashAttention (Tri Dao 2022) — wired in Milestone 6.6. Tiled
    /// fused online-softmax FW kernel that avoids materializing the
    /// `[B, H, Q, K]` attention matrix; instead saves a small
    /// `lse: [B, H, Q]` log-sum-exp tensor for the BW pass. Trailblazer
    /// constraints: `Br = Bc = 64`, `d_k = d_v ≤ 128`, optional causal
    /// mask, no explicit additive mask (use `SdpaPlan` for masked
    /// attention).
    FlashAttention = 3,
    /// KV-cache append — decoder-inference helper that writes
    /// newly-generated `K` / `V` slices into running cache buffers at
    /// per-sample offsets. Wired in Milestone 6.5 (FW only, no BW —
    /// inference-time op).
    KvCache = 4,
    /// Paged attention (vLLM-style) — reserved.
    PagedAttention = 5,
    /// Manifold-Constrained Hyper-Connections (DeepSeek-AI 2025, mHC).
    /// Drop-in replacement for the bare `y = x + sublayer(x)` residual
    /// connection in transformer blocks. Mixes `n` parallel residual
    /// streams through a small Sinkhorn-Knopp-normalized matrix `M`
    /// that lives on the manifold of doubly-stochastic matrices.
    /// Wired in Phase 43 — bf16 weights / f32 activations, static-H
    /// FW only (dynamic-H + BW deferred). Backed by the vendored
    /// `mHC.cu` (Andre Slavescu, MIT) under
    /// `crates/baracuda-kernels-sys/vendor/mhc/`.
    HyperConnection = 6,
    /// Mamba-2 State-Space Duality (SSD) chunk-scan (Phase 50). Bespoke
    /// kernel powering the Mamba-2 family (Mamba-2 8B, Codestral-Mamba,
    /// Falcon-Mamba, Zamba2). Operates on rank-4 `[B, L, H, D]` input
    /// + rank-4 `[B, L, H, N]` `B` / `C` modulation tensors + per-head
    /// scalar SSM eigenvalue `A: [H]`, producing rank-4 `[B, L, H, D]`
    /// output. State residency is `H * D * N` floats in SMEM (trailblazer
    /// caps `D, N ≤ 256` for FW, `≤ 64` for BW). Behind the `mamba`
    /// cargo feature.
    SsdChunkScan = 7,
    /// Mamba-1 selective_scan (Phase 50b). Bespoke kernel powering the
    /// Mamba-1 family (Mamba-7B, Falcon-Mamba, Codestral-Mamba). Operates
    /// on rank-3 `[B, L, D]` input + rank-3 `[B, L, N]` `B` / `C`
    /// modulation tensors + per-channel `[D, N]` state matrix `A`, with
    /// optional `D[d]` skip-connection, SiLU-gated tail `z`, delta-bias,
    /// and softplus-`delta` mapping. State residency is `N` floats in
    /// SMEM per `(b, d)` block (trailblazer caps `N ≤ 256`). Behind the
    /// `mamba` cargo feature.
    SelectiveScan = 8,
    /// Block-sparse SDPA (Phase 54, xFormers algorithmic-reference
    /// hand-port). Attention mask is a per-block boolean pattern
    /// `[B, H, num_blocks_q * num_blocks_k]`; only the active
    /// (q_block, k_block) pairs participate in the QK^T matmul +
    /// online-softmax accumulation. Different from
    /// [`Self::FlashAttention`] (dense) and from the Phase 51
    /// arbitrary-additive-mask path (which still computes every cell).
    /// FW only in Tier 1; backed by bespoke `mma`-free tile kernel
    /// behind the `xformers_blocksparse` cargo feature.
    BlockSparseAttention = 9,
    /// Ring Attention (Phase 56). Sequence-parallel attention where
    /// the Q tensor is sliced across `world_size` ranks and the K/V
    /// chunks rotate around a NCCL ring; each rank folds the resident
    /// K/V chunk into a persistent (o_acc, m_acc, l_acc) accumulator
    /// via online-softmax reconstruction (Flash Attention math), and
    /// after `world_size` rotations every rank has computed
    /// `Q[my_slice] @ K^T @ V` for the full global sequence — but
    /// with O(N/P) memory where N = total seq len, P = ring size.
    /// Algorithm: Liu, Yan, Abbeel 2023 (arXiv:2310.01889; reference
    /// at https://github.com/lhao499/RingAttention, Apache-2.0).
    /// Tier 1 ships FW only, f16/bf16, head_dim=128. Behind the
    /// `ring_attention` cargo feature; pulls in `baracuda-nccl`.
    RingAttention = 10,
}
1326
/// Discriminant for the indexing / scatter / gather ops (Category L of
/// the comprehensive plan).
///
/// This is the `u16` tag held in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Indexing`. Phase 7 Milestone 7.3 wires:
/// - [`Self::Gather`] (FW + BW) — `out[i] = src[index[i]]` along a dim.
/// - [`Self::ScatterAdd`] — `out[index[i]] += updates[i]` along a dim
///   (atomicAdd, dup-safe).
/// - [`Self::IndexSelect`] (FW + BW) —
///   `out[..., j, ...] = src[..., idx[j], ...]` driven by a 1-D i32 idx
///   tensor.
/// - [`Self::MaskedFill`] (FW + BW) —
///   `out[i] = mask[i] ? value : src[i]`.
/// - [`Self::OneHot`] (FW only, non-differentiable) —
///   `out[..., c] = 1 if c == src[...] else 0`.
/// - [`Self::Nonzero`] (FW only) — the coordinates where input != 0,
///   delivered as an `[k, rank]` i32 table plus a count.
///
/// [`Self::Scatter`] and [`Self::IndexAdd`] are documented on the
/// variants as Phase 39 additions.
///
/// The trailblazer accepts `i32` indices only (i64 is deferred). An
/// out-of-bounds or negative index is a no-op: the kernel skips it,
/// and PyTorch-style negative wrap-around is deferred.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum IndexingKind {
    /// PyTorch `torch.gather`: `gather(src, dim, index)` computes
    /// `out[..., j, ...] = src[..., index[..., j, ...], ...]` along the
    /// chosen gather dimension.
    Gather = 0,
    /// Backward of [`Self::Gather`]. `dout` is scattered into `dsrc`
    /// along the gather dim with atomicAdd (dup-safe). Its signature
    /// differs from [`Self::ScatterAdd`]: the destination is `dsrc` and
    /// the index pattern is exactly the FW gather coordinates.
    GatherBackward = 1,
    /// PyTorch `torch.scatter_add_`:
    /// `scatter_add(out, dim, index, updates)` performs
    /// `out[..., index[..., j, ...], ...] += updates[..., j, ...]`
    /// using atomicAdd.
    ScatterAdd = 2,
    /// PyTorch `torch.index_select`: `index_select(src, dim, idx)`
    /// computes `out[..., j, ...] = src[..., idx[j], ...]` from a 1-D
    /// i32 idx tensor. Faster / simpler than `gather` when the index
    /// tensor is 1-D.
    IndexSelect = 3,
    /// Backward of [`Self::IndexSelect`]: `dout` is scatter-added into
    /// `dsrc` along `select_dim` according to `idx` (atomicAdd).
    IndexSelectBackward = 4,
    /// PyTorch `torch.Tensor.masked_fill`:
    /// `masked_fill(src, mask, value)` computes
    /// `out[i] = mask[i] ? value : src[i]`.
    MaskedFill = 5,
    /// Backward of [`Self::MaskedFill`]:
    /// `dsrc[i] = mask[i] ? 0 : dout[i]`. The scalar `value` is
    /// non-differentiable.
    MaskedFillBackward = 6,
    /// PyTorch `torch.nn.functional.one_hot`:
    /// `one_hot(src, num_classes)` computes
    /// `out[indices..., c] = 1 if c == src[indices...] else 0`. The
    /// input holds i32 class indices; the output dtype is configurable.
    /// Non-differentiable.
    OneHot = 7,
    /// PyTorch `torch.nonzero`: `nonzero(x)` yields the coordinates
    /// where `x != 0` as an `[k, rank]` i32 coordinate table plus a
    /// count. The output ordering is NOT row-major (atomic-counter
    /// races); callers that need sorted output sort afterward.
    Nonzero = 8,
    /// PyTorch `torch.scatter_` (the in-place pure-assign variant):
    /// `scatter(out, dim, index, updates)` performs
    /// `out[..., index[..., j, ...], ...] = updates[..., j, ...]` with
    /// NO accumulation — on duplicate-target races the last writer
    /// wins. Distinct from [`Self::ScatterAdd`]. Phase 39
    /// (Fuel 6c.4 Gap 5).
    Scatter = 9,
    /// PyTorch `torch.Tensor.index_add_`:
    /// `index_add(dst, dim, idx, src)` performs
    /// `dst[idx[i], ...] += src[i, ...]` along `add_dim` (atomicAdd-Σ).
    /// Algorithmically the same as [`Self::IndexSelectBackward`], but
    /// exposed under a non-autograd-flavored name and with broader
    /// dtype coverage. Phase 39 (Fuel 6c.4 Gap 5).
    IndexAdd = 10,
}
1401
/// Discriminant for segment / scatter-reduce ops — Category S of the
/// comprehensive plan.
///
/// This is the `u16` tag held in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::SegmentOps`. One variant corresponds to one
/// kernel symbol. Sorted and unsorted families share the enum but take
/// separate `op` slots, because their kernels are implemented
/// differently: sorted uses a binary-search single-pass sweep, unsorted
/// scatters atomically from the input side.
///
/// Wired in Phase 7 Milestone 7.6:
/// - Sorted FW: [`Self::SegmentSum`], [`Self::SegmentMean`],
///   [`Self::SegmentMax`], [`Self::SegmentMin`], [`Self::SegmentProd`];
///   BW for Sum / Mean only ([`Self::SegmentSumBackward`],
///   [`Self::SegmentMeanBackward`]).
/// - Unsorted FW: [`Self::UnsortedSegmentSum`],
///   [`Self::UnsortedSegmentMean`], [`Self::UnsortedSegmentMax`],
///   [`Self::UnsortedSegmentMin`]; BW for Sum / Mean only
///   ([`Self::UnsortedSegmentSumBackward`],
///   [`Self::UnsortedSegmentMeanBackward`]).
///
/// Added in Phase 25 (the remaining BW gaps):
/// - Max / Min BW, sorted and unsorted, recompute the argmax inside the
///   BW kernel. This keeps the FW API source-compatible — the FW
///   signature gains no paired-index tensor.
/// - Prod BW, sorted and unsorted, evaluates `d_output * prod / x` by
///   direct division; callers must keep zeros out of the segment or
///   tolerate NaN/Inf gradients.
/// - Unsorted Prod FW runs an `atomicCAS` retry loop, since there is no
///   native FP `atomicMul`.
///
/// Dtype coverage is `f32, f64` (the atomic-supported FP types). f16 /
/// bf16 are deferred: the kernels rely on `atomicAdd` / `atomicMax` /
/// `atomicMin`, which are limited to native-FP-atomic types.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum SegmentKind {
    // ---- Phase 7 Milestone 7.6: sorted family ----
    /// Sorted segment sum:
    /// `out[s, d] = Σ_{n : seg[n] == s} input[n, d]`. Segment IDs must be
    /// monotonically non-decreasing. TF / JAX `segment_sum`.
    SegmentSum = 0,
    /// Backward of [`Self::SegmentSum`] — a gather along the segment ids:
    /// `d_input[n, d] = d_output[seg[n], d]`.
    SegmentSumBackward = 1,
    /// Sorted segment mean:
    /// `out[s, d] = mean_{n : seg[n] == s} input[n, d]`.
    SegmentMean = 2,
    /// Backward of [`Self::SegmentMean`]:
    /// `d_input[n, d] = d_output[seg[n], d] / count[seg[n]]`.
    SegmentMeanBackward = 3,
    /// Sorted segment max:
    /// `out[s, d] = max_{n : seg[n] == s} input[n, d]`.
    SegmentMax = 4,
    /// Sorted segment min:
    /// `out[s, d] = min_{n : seg[n] == s} input[n, d]`.
    SegmentMin = 5,
    /// Sorted segment product:
    /// `out[s, d] = prod_{n : seg[n] == s} input[n, d]`.
    SegmentProd = 6,

    // ---- Phase 7 Milestone 7.6: unsorted family ----
    /// Unsorted segment sum (segment IDs may appear in any order):
    /// `out[s, d] = Σ_{n : seg[n] == s} input[n, d]`.
    /// TF `unsorted_segment_sum`.
    UnsortedSegmentSum = 7,
    /// Backward of [`Self::UnsortedSegmentSum`]:
    /// `d_input[n, d] = d_output[seg[n], d]`.
    UnsortedSegmentSumBackward = 8,
    /// Unsorted segment mean:
    /// `out[s, d] = mean_{n : seg[n] == s} input[n, d]`.
    UnsortedSegmentMean = 9,
    /// Backward of [`Self::UnsortedSegmentMean`]:
    /// `d_input[n, d] = d_output[seg[n], d] / count[seg[n]]`.
    UnsortedSegmentMeanBackward = 10,
    /// Unsorted segment max:
    /// `out[s, d] = max_{n : seg[n] == s} input[n, d]`.
    UnsortedSegmentMax = 11,
    /// Unsorted segment min:
    /// `out[s, d] = min_{n : seg[n] == s} input[n, d]`.
    UnsortedSegmentMin = 12,

    // ---- Phase 25: remaining backward kernels + unsorted prod ----
    /// Phase 25. Backward of [`Self::SegmentMax`]:
    /// `d_input[k, d] = d_output[seg, d]` at the (lowest-index) `k` for
    /// which `input[k, d] == max`. The BW kernel re-scans the segment to
    /// recompute the argmax, leaving the FW signature untouched.
    SegmentMaxBackward = 13,
    /// Phase 25. Backward of [`Self::SegmentMin`]; mirrors
    /// [`Self::SegmentMaxBackward`].
    SegmentMinBackward = 14,
    /// Phase 25. Backward of [`Self::SegmentProd`]:
    /// `d_input[k, d] = d_output[seg, d] * (prod[seg, d] / x[k, d])`.
    /// Uses direct division, so the caller must keep zero-valued inputs
    /// out of the segment or accept NaN / Inf in the gradient.
    SegmentProdBackward = 15,
    /// Phase 25. Backward of [`Self::UnsortedSegmentMax`]. Follows the
    /// sorted variant's recompute-argmax pattern, except that it scans
    /// the full input array per (seg, d) cell. Tie-breaking is
    /// non-deterministic when the FW was non-deterministic.
    UnsortedSegmentMaxBackward = 16,
    /// Phase 25. Backward of [`Self::UnsortedSegmentMin`]; mirrors
    /// [`Self::UnsortedSegmentMaxBackward`].
    UnsortedSegmentMinBackward = 17,
    /// Phase 25. Unsorted segment product:
    /// `out[s, d] = prod_{n : seg[n] == s} input[n, d]`. Implemented with
    /// an `atomicCAS` retry loop because no native FP `atomicMul`
    /// exists. Non-deterministic.
    UnsortedSegmentProd = 18,
    /// Phase 25. Backward of [`Self::UnsortedSegmentProd`]; uses the same
    /// direct-division pattern as [`Self::SegmentProdBackward`].
    UnsortedSegmentProdBackward = 19,
}
1499
/// Discriminant for the embedding op family — Category M of the
/// comprehensive plan.
///
/// This is the `u16` tag held in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Embedding`. Phase 7 Milestone 7.5 wires:
/// - [`Self::Embedding`] (FW + BW) — plain row lookup,
///   `out[i, :] = weight[indices[i], :]`. An optional `padding_idx`
///   produces an all-zero row in FW and is skipped during BW
///   accumulation.
/// - [`Self::EmbeddingBagSum`] / [`Self::EmbeddingBagMean`] (FW + BW) —
///   row lookup reduced per bag:
///   `out[b, :] = reduce(weight[indices[k], :] for k in offsets[b]..offsets[b+1])`.
///   The mode picks the reducer (sum / divide-by-bag-size).
///   `EmbeddingBagMax` is deferred because its BW needs argmax tracking.
///
/// Only `i32` indices are supported (i64 deferred). FW kernels cover
/// `f32, f64, f16, bf16` (pure copy / reduce); BW kernels cover `f32,
/// f64` (atomicAdd).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum EmbeddingKind {
    /// `embedding(weight, indices, padding_idx)`, i.e.
    /// `out[i, :] = weight[indices[i], :]`.
    /// PyTorch `torch.nn.functional.embedding`.
    Embedding = 0,
    /// Backward of [`Self::Embedding`]:
    /// `dweight[indices[i], :] += dout[i, :]` via atomicAdd. Rows with
    /// `indices[i] == padding_idx` are skipped.
    EmbeddingBackward = 1,

    /// `embedding_bag(weight, indices, offsets, mode=Sum)` —
    /// PyTorch `torch.nn.functional.embedding_bag` with `mode='sum'`.
    EmbeddingBagSum = 2,
    /// `embedding_bag(weight, indices, offsets, mode=Mean)` —
    /// PyTorch `torch.nn.functional.embedding_bag` with `mode='mean'`.
    EmbeddingBagMean = 3,
    /// Backward of Sum-mode `embedding_bag`:
    /// `dweight[indices[k], :] += dout[b, :]` for each k in bag b
    /// (atomicAdd).
    EmbeddingBagSumBackward = 4,
    /// Backward of Mean-mode `embedding_bag`:
    /// `dweight[indices[k], :] += dout[b, :] / bag_size(b)` (atomicAdd).
    EmbeddingBagMeanBackward = 5,

    /// `embedding_bag(weight, indices, offsets, mode=Max)` — reserved.
    /// Max-mode FW has to track the argmax (the per-feature index of the
    /// contributing row) so that BW scatters into that row alone. That
    /// is a different plan shape, hence deferred.
    EmbeddingBagMax = 6,
    /// Backward of Max-mode `embedding_bag` — reserved.
    EmbeddingBagMaxBackward = 7,
}
1549
/// Discriminant for quantization ops — Category P of the comprehensive
/// plan.
///
/// This is the `u16` tag held in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Quantization`. Phase 8 Milestone 8.1 wires
/// the trailblazer set: quantize / dequantize in per-tensor and
/// per-channel flavors, plus fake_quantize (a round-trip carried out in
/// FP space). Entries ship FW + BW where a gradient is meaningful and FW
/// only otherwise.
///
/// **Trailblazer dtype scope** (input FP × output int):
/// - Input FP: `f32, f64, f16, bf16`.
/// - Output int: `s8, u8`; the sub-byte packed types (`s4`, `u4`) are
///   deferred.
/// - `scale` shares the input FP dtype, while `zero_point` is always
///   `i32` (wide enough for any int output qmin/qmax range).
///
/// **Backward convention (Straight-Through Estimator).** `quantize` and
/// `fake_quantize` both use STE in BW: wherever the rounded result fell
/// inside `[qmin, qmax]` the gradient passes through (scaled by
/// `1/scale` for `quantize`, unscaled for `fake_quantize`); elsewhere it
/// is zero. The "in-range mask" is **recomputed in BW from the saved
/// `input` tensor** instead of being emitted as an extra FW output,
/// which matches PyTorch's internal FakeQuantize and keeps the FW
/// signature clean. As a consequence callers have to keep the original
/// input tensor alive for the BW pass (autograd would retain it anyway).
///
/// The `PerToken` / `PerGroup` / `DynamicRange` variants extend the
/// trailblazer set; discriminant gaps were intentionally left for them
/// (that family starts at 16, so 10–15 are unassigned).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum QuantizeKind {
    // ---- Milestone 8.1 — per-tensor / per-channel / fake-quantize ----
    /// `quantize_per_tensor(x, scale, zero_point)`:
    /// `q = clamp(round(x / scale) + zero_point, qmin, qmax)`.
    /// The whole tensor shares a single scalar `scale` (FP) and
    /// `zero_point` (i32). PyTorch `torch.quantize_per_tensor`.
    PerTensor = 0,
    /// STE backward of [`Self::PerTensor`]:
    /// `dx = (dy / scale) * in_range_mask`, the mask being
    /// `qmin <= round(x/scale) + zp <= qmax`.
    PerTensorBackward = 1,
    /// `dequantize_per_tensor(q, scale, zero_point)`:
    /// `x = scale * (q - zero_point)`. Linear, and exactly invertible up
    /// to rounding. PyTorch `torch.Tensor.dequantize`.
    DequantizePerTensor = 2,
    /// Backward of [`Self::DequantizePerTensor`]: `dq = dy * scale`
    /// (linear identity scaled).
    DequantizePerTensorBackward = 3,
    /// `quantize_per_channel(x, scale[C], zero_point[C], axis)`. The
    /// math is that of [`Self::PerTensor`], except each slice along
    /// `axis` gets its own `scale[c]` / `zero_point[c]` pair.
    /// PyTorch `torch.quantize_per_channel`.
    PerChannel = 4,
    /// STE backward of [`Self::PerChannel`]:
    /// `dx = (dy / scale[c]) * in_range_mask[c]`.
    PerChannelBackward = 5,
    /// `dequantize_per_channel(q, scale[C], zero_point[C], axis)`:
    /// `x = scale[c] * (q - zero_point[c])`.
    DequantizePerChannel = 6,
    /// Backward of [`Self::DequantizePerChannel`]:
    /// `dq = dy * scale[c]`.
    DequantizePerChannelBackward = 7,
    /// `fake_quantize_per_tensor(x, scale, zero_point)`:
    /// `y = scale * (clamp(round(x/scale)+zp, qmin, qmax) - zp)`.
    /// A quantize-then-dequantize round-trip performed in FP space,
    /// yielding a lossy FP output.
    /// PyTorch `torch.fake_quantize_per_tensor_affine`.
    FakeQuantize = 8,
    /// STE backward of [`Self::FakeQuantize`]:
    /// `dx = dy * in_range_mask`. **No `1/scale` factor** appears: the
    /// FW dequant-side multiplication by `scale` cancels the `1/scale`
    /// that STE would contribute.
    FakeQuantizeBackward = 9,

    // ---- Reserved family (discriminants 16-20) ----
    /// Reserved — `quantize_per_token`: per-row dynamic-range
    /// quantization, as used for activation quantization.
    PerToken = 16,
    /// Reserved — backward of [`Self::PerToken`].
    PerTokenBackward = 17,
    /// Reserved — `quantize_per_group`: block-wise quantization, as used
    /// by GPTQ / AWQ / GGML.
    PerGroup = 18,
    /// Reserved — backward of [`Self::PerGroup`].
    PerGroupBackward = 19,
    /// Reserved — `dynamic_range_quantize`: post-training dynamic
    /// quantization.
    DynamicRange = 20,

    // ---- Milestone 8.2 completion — per-token / per-group dequant
    //      + backwards (the PerToken / PerGroup discriminants were
    //      reserved above at 16-19) ----
    /// `dequantize_per_token(q, scale[N], zero_point[N])`:
    /// `y[n, d] = scale[n] * (q[n, d] - zp[n])`. The per-row inverse of
    /// [`Self::PerToken`].
    DequantizePerToken = 21,
    /// Backward of [`Self::DequantizePerToken`]:
    /// `dq = dy * scale[n]` (straight-through).
    DequantizePerTokenBackward = 22,
    /// `dequantize_per_group(q, scale[outer, num_groups],
    /// zero_point[outer, num_groups])`. The per-group inverse of
    /// [`Self::PerGroup`].
    DequantizePerGroup = 23,
    /// Backward of [`Self::DequantizePerGroup`]:
    /// `dq[i, j] = dy[i, j] * scale[i, j/g]` (straight-through).
    DequantizePerGroupBackward = 24,

    // ---- Milestone 8.3 — composing quantization ops ----
    /// `quantized_linear(activation_fp, weight_q_s8, weight_scale,
    /// bias?)` — the fused W8A8 quantized matmul. Pipeline:
    /// dynamic-range per-token quantize of the activation → int8 GEMM
    /// with an int32 accumulator → dequantize using the per-row
    /// `scale_a` and per-channel `scale_w`. This is the canonical
    /// inference-time LLM matmul recipe (e.g. SmoothQuant, AWQ-runtime):
    /// FP activation in, FP output out, with int8 storage only on the
    /// GEMM. No backward is shipped — the op is inference-only by
    /// convention.
    QuantizedLinear = 25,

    // ---- Milestone 8.4 — GGUF block-format quant family ----
    /// `gguf_dequantize(packed_bytes) -> fp_tensor` — expands a
    /// GGUF-packed weight buffer (Q4_0 / Q4_1 / Q5_0 / Q5_1 / Q8_0 +
    /// Q2_K / Q3_K / Q4_K / Q5_K / Q6_K / Q8_K) into a dense FP tensor.
    /// The block format travels out-of-band on the plan descriptor (see
    /// [`GgufBlockFormat`]); the kernel surface fans out per block
    /// format while this enum value stays the same. Inference-only by
    /// convention (BW not shipped).
    GgufDequantize = 26,
    /// `gguf_mmvq(packed_weight, fp_activation) -> fp_output` — a fused
    /// dequant + matrix-vector multiply, i.e. the inference-time
    /// "decode-step" matmul llama.cpp runs on GGUF weights. FP
    /// activation in (f32 today), FP output out. Inference-only (BW not
    /// shipped).
    GgufMmvq = 27,
}
1678
/// GGUF block-format selector for [`QuantizeKind::GgufDequantize`] /
/// [`QuantizeKind::GgufMmvq`] plans. Mirrors the discriminants used by
/// llama.cpp / `ggml` so a descriptor can be round-tripped to a GGUF
/// file header without translation.
///
/// Block sizes:
///   * Type-0/1 variants (`Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`)
///     pack 32 quantized values per block plus a shared FP scale
///     (+ min for the `_1` variants).
///   * k-quants variants (`Q2_K` ... `Q8_K`) pack 256 values per
///     super-block with a multi-level scale hierarchy
///     (quantized sub-block scales + FP super-block scale).
///
/// Discriminant values match the `GGML_TYPE_*` enum in upstream
/// `ggml.h`, ensuring binary compatibility with GGUF file headers.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum GgufBlockFormat {
    /// 4-bit, 32-element block, single FP scale. `block_q4_0`.
    Q4_0 = 2,
    /// 4-bit, 32-element block, FP scale + FP min. `block_q4_1`.
    Q4_1 = 3,
    /// 5-bit, 32-element block, single FP scale. `block_q5_0`.
    Q5_0 = 6,
    /// 5-bit, 32-element block, FP scale + FP min. `block_q5_1`.
    Q5_1 = 7,
    /// 8-bit, 32-element block, single FP scale. `block_q8_0`.
    Q8_0 = 8,
    /// 2.5-bit (effective), 256-element super-block. `block_q2_K`.
    Q2K = 10,
    /// 3.4-bit (effective), 256-element super-block. `block_q3_K`.
    Q3K = 11,
    /// 4.5-bit (effective), 256-element super-block. `block_q4_K`.
    Q4K = 12,
    /// 5.5-bit (effective), 256-element super-block. `block_q5_K`.
    Q5K = 13,
    /// 6.6-bit (effective), 256-element super-block. `block_q6_K`.
    Q6K = 14,
    /// 8-bit, 256-element super-block (a CPU-side intermediate
    /// upstream). `block_q8_K`. Dequant and MMVQ are both supported
    /// here: the MMVQ kernel is a baracuda addition with no upstream
    /// llama.cpp specialization (see [`Self::has_mmvq`]).
    Q8K = 15,
}

impl GgufBlockFormat {
    /// Number of FP elements per packed block: 32 for the type-0/1
    /// family, 256 for the k-quants family.
    #[inline]
    pub const fn block_size(self) -> usize {
        // Deliberately exhaustive (no `_` arm): a newly added format must
        // state its block size explicitly instead of silently getting 256.
        match self {
            GgufBlockFormat::Q4_0
            | GgufBlockFormat::Q4_1
            | GgufBlockFormat::Q5_0
            | GgufBlockFormat::Q5_1
            | GgufBlockFormat::Q8_0 => 32,
            GgufBlockFormat::Q2K
            | GgufBlockFormat::Q3K
            | GgufBlockFormat::Q4K
            | GgufBlockFormat::Q5K
            | GgufBlockFormat::Q6K
            | GgufBlockFormat::Q8K => 256,
        }
    }

    /// Number of bytes per packed block. Matches `sizeof(block_q*)`
    /// from `ggml.h`. Used by the Rust plan layer to size the input
    /// weight buffer.
    #[inline]
    pub const fn type_size(self) -> usize {
        match self {
            // 2 (fp16 d) + 16 (qs[16])
            GgufBlockFormat::Q4_0 => 18,
            // 2*2 (half2 dm) + 16 (qs[16])
            GgufBlockFormat::Q4_1 => 20,
            // 2 (fp16 d) + 4 (qh) + 16 (qs[16])
            GgufBlockFormat::Q5_0 => 22,
            // 2*2 (half2 dm) + 4 (qh) + 16 (qs[16])
            GgufBlockFormat::Q5_1 => 24,
            // 2 (fp16 d) + 32 (qs[32])
            GgufBlockFormat::Q8_0 => 34,
            // 2*2 (half2 dm) + QK_K/16 (16 scales) + QK_K/4 (64 qs) = 4+16+64
            GgufBlockFormat::Q2K => 84,
            // hmask(32) + qs(64) + scales(12) + d(2)
            GgufBlockFormat::Q3K => 110,
            // dm(4) + scales(12) + qs(128)
            GgufBlockFormat::Q4K => 144,
            // dm(4) + scales(12) + qh(32) + qs(128)
            GgufBlockFormat::Q5K => 176,
            // ql(128) + qh(64) + scales(16) + d(2)
            GgufBlockFormat::Q6K => 210,
            // d(4) + qs(256) + bsums(32)
            GgufBlockFormat::Q8K => 292,
        }
    }

    /// `true` for the type-0/1 family (32-element blocks); `false`
    /// for the k-quants family (256-element super-blocks).
    #[inline]
    pub const fn is_type_01(self) -> bool {
        matches!(
            self,
            GgufBlockFormat::Q4_0
                | GgufBlockFormat::Q4_1
                | GgufBlockFormat::Q5_0
                | GgufBlockFormat::Q5_1
                | GgufBlockFormat::Q8_0
        )
    }

    /// `true` if MMVQ (fused dequant + matvec) is supported for this
    /// block format. As of Phase 11.4, all 11 GGUF block formats ship a
    /// MMVQ kernel. Q8_K MMVQ is a bespoke baracuda addition (upstream
    /// llama.cpp / Fuel reserve Q8_K as a CPU-side intermediate and ship
    /// dequant only); we close that gap to avoid 2× memory traffic on
    /// the inference decode step.
    #[inline]
    pub const fn has_mmvq(self) -> bool {
        // Kept as an exhaustive match so that adding a format forces an
        // explicit decision about MMVQ support.
        match self {
            GgufBlockFormat::Q4_0
            | GgufBlockFormat::Q4_1
            | GgufBlockFormat::Q5_0
            | GgufBlockFormat::Q5_1
            | GgufBlockFormat::Q8_0
            | GgufBlockFormat::Q2K
            | GgufBlockFormat::Q3K
            | GgufBlockFormat::Q4K
            | GgufBlockFormat::Q5K
            | GgufBlockFormat::Q6K
            | GgufBlockFormat::Q8K => true,
        }
    }
}
1806
1807
/// Mixture-of-Experts (MoE) variant selector: the `op` discriminant of
/// kernel SKUs whose [`crate::OpCategory`] is
/// [`crate::OpCategory::Moe`]. Phase 8 Milestone 8.5 wires the three
/// fused kernels (per-token dispatch + expert matmul + accumulate).
///
/// Shape of the MoE forward pass:
///   * Input activations: `[T, D_model]`.
///   * Per-token top-k expert indices: `[T, top_k]` (i32).
///   * Per-token top-k expert weights: `[T, top_k]` (FP).
///   * Per-expert weight matrices: `[num_experts, D_model, D_expert]`.
///     Their dtype is variant-dependent — FP for `Wmma`, GGUF-packed
///     bytes for `ScalarGguf` / `WmmaGguf`.
///   * Output (after expert mixing): `[T, D_model]`.
///
/// By convention all three variants are inference-only and ship no
/// backward pass; MoE training goes through higher-level autograd
/// surfaces that compose the per-expert FFN ops manually.
///
/// Lineage: vendored from `attention.rs` via `fuel-cuda-kernels`. The
/// full attribution chain is in
/// `crates/baracuda-kernels-sys/LICENSE-thirdparty.md`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum MoeKind {
    /// Scalar dispatch path over GGUF-quantized expert weights, staged
    /// through a q8_1 intermediate (FP32 activations in, FP32 output
    /// out). Uses no tensor cores. Serves as a portability fallback and
    /// as the slower-but-simpler reference for the WMMA + GGUF hot path.
    /// Block formats: `Q8_0`, `Q2_K`, `Q3_K`, `Q4_K`, `Q5_K`, `Q6_K`
    /// (matches Fuel's `moe_gemm_gguf` switch).
    ScalarGguf = 0,
    /// WMMA tensor-core path over dense FP expert weights (f16 / bf16).
    /// This is the FP MoE hot path, used when full-precision expert
    /// weights are available — typically training-time or FP-deployment
    /// inference. Requires sm_70+.
    Wmma = 1,
    /// WMMA tensor cores combined with GGUF-quantized weights. The
    /// dispatcher dequantizes one GGUF block per N-row into shared
    /// memory and then issues a 16×16×16 WMMA mma.sync against the dense
    /// activation tile. This is the production hot path for quantized
    /// LLM inference. Activation dtype: f16 / bf16. Weight block
    /// formats: the same set as [`Self::ScalarGguf`]. Requires sm_70+.
    WmmaGguf = 2,
}
1853
/// Discriminant for sorting / order-statistics ops — Category O of the
/// comprehensive plan (Phase 9).
///
/// This is the `u16` tag held in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Sorting`. Phase 9 wires the block-bitonic
/// trailblazer family (`row_len ≤ 1024`, `k ≤ 64`):
///
/// - [`Self::Sort`] / [`Self::SortBackward`]: full sort; the indices are
///   saved for BW. PyTorch `torch.sort`.
/// - [`Self::Argsort`]: the indices-only variant. PyTorch
///   `torch.argsort`.
/// - [`Self::Msort`] / [`Self::MsortBackward`]: stable sort (ties are
///   broken on the original index, preserving input order). PyTorch
///   `torch.msort`.
/// - [`Self::Topk`] / [`Self::TopkBackward`]: top-k by value, or
///   bottom-k when `largest == false`. PyTorch `torch.topk`.
/// - [`Self::Kthvalue`] / [`Self::KthvalueBackward`]: built on top of
///   topk; yields the k-th value together with its index.
/// - [`Self::Unique`] / [`Self::UniqueConsecutive`]: set-valued ops.
///   `unique` chains sort + consecutive-dedup, whereas
///   `unique_consecutive` assumes an already-sorted input (or that only
///   run-equal cells matter). No BW (set-valued).
/// - [`Self::Histogram`] / [`Self::Histogramdd`] / [`Self::Bincount`]:
///   atomic-bin accumulation. Histogram + bincount FW are shipped;
///   histogramdd is reserved (rank > 1 trailblazer follow-up).
/// - [`Self::Searchsorted`]: per-query binary search in a 1-D sorted
///   array. PyTorch `torch.searchsorted`. No BW.
///
/// Dtype coverage:
/// - sort / argsort / msort FW: `f32, f64, i32, i64`.
/// - sort / msort BW: `f32, f64` (FP grads only).
/// - topk FW + BW: `f32, f64`.
/// - kthvalue: composes topk; same dtype set.
/// - unique / unique_consecutive: `f32, f64, i32`.
/// - histogram: `f32, f64` input → `i32` counts.
/// - bincount: `i32, i64` input → `i32` counts.
/// - searchsorted: `f32, f64, i32, i64`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum SortKind {
    // ---- Full sorts ----
    /// `sort(x, dim, descending)` — yields both the sorted values and
    /// the sorted indices. PyTorch `torch.sort`.
    Sort = 0,
    /// Backward of [`Self::Sort`]: `dy` is scattered back to the
    /// original positions using the saved indices.
    SortBackward = 1,
    /// `argsort(x, dim, descending)` — yields the sorted indices only.
    /// PyTorch `torch.argsort`.
    Argsort = 2,
    /// `msort(x)` — stable sort along the last dimension; ties are
    /// broken on the original index, which preserves input order.
    /// PyTorch `torch.msort`.
    ///
    /// NOTE(review): upstream `torch.msort` is documented as sorting
    /// along the *first* dimension (`dim=0`); confirm which axis this
    /// kernel actually targets.
    Msort = 3,
    /// Backward of [`Self::Msort`]: the same scatter as
    /// [`Self::SortBackward`].
    MsortBackward = 4,

    // ---- Partial order statistics ----
    /// `topk(x, k, dim, largest)` — the top-k (or bottom-k) values and
    /// their indices. PyTorch `torch.topk`. The trailblazer caps
    /// `k ≤ 64`.
    Topk = 5,
    /// Backward of [`Self::Topk`]: the k-wide `dy` is scattered via the
    /// saved indices into a zero-init `row_len`-wide `dx`.
    TopkBackward = 6,
    /// `kthvalue(x, k, dim)` — the k-th smallest value and its index.
    /// Composed at the Rust plan layer on top of [`Self::Topk`] using
    /// the "bottom-k" order.
    Kthvalue = 7,
    /// Backward of [`Self::Kthvalue`]: the scalar `dy` is scattered back
    /// to the single source position.
    KthvalueBackward = 8,

    // ---- Set-valued ops (no BW) ----
    /// `unique(x, sorted=True)` — the unique values in `x`. The Rust
    /// plan layer implements it as [`Self::Sort`] followed by the
    /// consecutive dedup. Set-valued — no BW.
    Unique = 9,
    /// `unique_consecutive(x)` — one output cell per run-start (the
    /// input must be sorted, or only consecutive-equal cells should be
    /// collapsed). Set-valued — no BW.
    UniqueConsecutive = 10,

    // ---- Binning / searching (FW only) ----
    /// `histogram(x, bins, range)` — 1-D histogram with uniform bins.
    /// PyTorch `torch.histogram`. FW only.
    Histogram = 11,
    /// `histogramdd(x, bins, range)` — N-D histogram. The discriminant
    /// is reserved; rank > 1 is a trailblazer follow-up.
    Histogramdd = 12,
    /// `bincount(x, minlength)` — occurrence count of each integer in
    /// `x`. PyTorch `torch.bincount`. FW only.
    Bincount = 13,
    /// `searchsorted(sorted_seq, values, right)` — per-query lower/upper
    /// bound binary search. PyTorch `torch.searchsorted`. FW only.
    Searchsorted = 14,
}
1943
/// Discriminant for image / spatial-transform ops — Category T of the
/// comprehensive plan.
///
/// This is the `u16` tag held in [`crate::KernelSku::op`] whenever
/// `category == OpCategory::Image`. Phase 9 Category T wires the
/// trailblazer set:
/// - [`Self::InterpolateBilinear2d`] /
///   [`Self::InterpolateBilinear2dBackward`]: spatial up/downsampling by
///   bilinear interpolation.
/// - [`Self::GridSample2d`] / [`Self::GridSample2dBackward`]: samples
///   the input at arbitrary normalized coordinates (PyTorch
///   `torch.nn.functional.grid_sample`, default config:
///   `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=false`).
/// - [`Self::AffineGrid2d`]: builds a sampling grid from a 2×3 affine
///   matrix (the companion of GridSample).
/// - [`Self::PixelShuffle`] / [`Self::PixelUnshuffle`]: a pure index
///   permutation between `[N, C·r², H, W]` and `[N, C, H·r, W·r]`. Each
///   one is the other's backward.
/// - [`Self::RoiAlign`] / [`Self::RoiAlignBackward`]: extracts a
///   fixed-size feature from variable RoIs by bilinear sampling.
/// - [`Self::RoiPool`] / [`Self::RoiPoolBackward`]: the max-pool variant
///   of RoiAlign (BW routes through the argmax).
/// - [`Self::Nms`]: non-max suppression over bounding boxes. Produces a
///   boolean keep mask + count; no BW (set-valued op).
///
/// The remaining interpolation modes (`nearest`, `bicubic`,
/// `trilinear`, `linear`, `area`) have discriminants reserved here, but
/// in the trailblazer their kernels are stubbed `Unsupported`.
///
/// Trailblazer dtype coverage: `f32, f64` for the math-bearing ops;
/// `pixel_shuffle` / `pixel_unshuffle` also cover `f16, bf16` (pure
/// layout — dtype-agnostic).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[repr(u16)]
pub enum ImageKind {
    // ---- Interpolation (0-11) ----
    /// `interpolate(x, mode='bilinear', size=…)` — 2-D spatial resample
    /// using bilinear weights. The trailblazer wired today.
    InterpolateBilinear2d = 0,
    /// Backward of [`Self::InterpolateBilinear2d`]: every output cell
    /// atomic-adds its weighted contribution into the 4 input cells it
    /// bilinearly sampled.
    InterpolateBilinear2dBackward = 1,
    /// `interpolate(x, mode='nearest')` — reserved.
    InterpolateNearest2d = 2,
    /// Backward of [`Self::InterpolateNearest2d`] — reserved.
    InterpolateNearest2dBackward = 3,
    /// `interpolate(x, mode='bicubic')` — reserved.
    InterpolateBicubic2d = 4,
    /// Backward of [`Self::InterpolateBicubic2d`] — reserved.
    InterpolateBicubic2dBackward = 5,
    /// `interpolate(x, mode='trilinear')` — reserved.
    InterpolateTrilinear3d = 6,
    /// Backward of [`Self::InterpolateTrilinear3d`] — reserved.
    InterpolateTrilinear3dBackward = 7,
    /// `interpolate(x, mode='linear')` — reserved (1-D).
    InterpolateLinear1d = 8,
    /// Backward of [`Self::InterpolateLinear1d`] — reserved.
    InterpolateLinear1dBackward = 9,
    /// `interpolate(x, mode='area')` — reserved (adaptive avg pool).
    InterpolateArea2d = 10,
    /// Backward of [`Self::InterpolateArea2d`] — reserved.
    InterpolateArea2dBackward = 11,

    // ---- Grid sampling (16-18) ----
    /// `grid_sample(input, grid)` — 2-D bilinear with zeros padding and
    /// `align_corners=false`, i.e. the PyTorch defaults.
    GridSample2d = 16,
    /// Backward of [`Self::GridSample2d`]: atomic-add into `dinput`,
    /// plus the analytical bilinear coordinate derivatives written into
    /// `dgrid`.
    GridSample2dBackward = 17,
    /// `affine_grid(theta, size)` — produces the normalized sampling
    /// grid for a 2×3 affine matrix. The companion of GridSample2d.
    AffineGrid2d = 18,

    // ---- Pixel shuffle (24-25) ----
    /// `pixel_shuffle(x, r)` — `[N, C·r², H, W] → [N, C, H·r, W·r]`.
    /// A pure index permutation whose BW is `PixelUnshuffle`.
    PixelShuffle = 24,
    /// `pixel_unshuffle(x, r)` — `[N, C, H·r, W·r] → [N, C·r², H, W]`.
    /// The inverse of `PixelShuffle`, whose BW is `PixelShuffle`.
    PixelUnshuffle = 25,

    // ---- RoI ops (32-35) ----
    /// `roi_align(input, rois, output_size, spatial_scale,
    /// sampling_ratio=0, aligned=false)`. PyTorch convention.
    RoiAlign = 32,
    /// Backward of [`Self::RoiAlign`]: bilinear-weighted atomic-add into
    /// `dinput`.
    RoiAlignBackward = 33,
    /// `roi_pool(input, rois, output_size, spatial_scale)` — the
    /// max-pool variant of RoiAlign. The argmax indices are saved for
    /// BW.
    RoiPool = 34,
    /// Backward of [`Self::RoiPool`]: `dout[i, c, h, w]` is atomic-added
    /// into `dinput` at the saved argmax cell.
    RoiPoolBackward = 35,

    // ---- Detection post-processing (40) ----
    /// `nms(boxes, scores, iou_threshold)` — non-max suppression.
    /// Produces a boolean keep mask `[num_boxes]` plus a count scalar.
    /// No BW (set-valued op).
    Nms = 40,
}