1//! Element types and trait hierarchy shared across baracuda kernel
2//! wrappers.
3//!
4//! # Trait map
5//!
6//! [`KernelDtype`] is the **umbrella marker** every kernel-usable dtype
7//! implements. It captures the minimum a dtype needs to participate in
8//! any kernel: a fixed memory layout ([`DeviceRepr`]), `Copy + 'static`,
9//! and a runtime tag ([`ElementKind`]) for dispatch. Phase 28 added
10//! `KernelDtype` as a `1.0`-freeze stability prereq — code that wants to
11//! accept *any* dtype (sub-byte, FP8, packed-bit) without enumerating
12//! sibling traits now has a single bound to reach for.
13//!
14//! The op-shaped sub-traits all extend [`KernelDtype`]:
15//!
16//! - [`Element`] — the **plan-shaped** family that participates in the
17//!   `<T: Element>`-parameterized elementwise plans
18//!   (`UnaryPlan<T, N>`, `BinaryPlan<T, N>`, …). Today: `f16`, `bf16`,
19//!   `f32`, [`F32Strict`], `f64`, `i32`, `i64`, [`Bool`], [`Complex32`],
20//!   [`Complex64`]. Adds a `type Scalar: ScalarType` projection for the
21//!   kernel's α/β scalar type.
22//! - [`IntElement`] — sub-byte / byte-packed integer GEMM operand types
23//!   ([`S8`], [`U8`], [`S4`], [`U4`]). Distinct trait because the
24//!   int-GEMM kernels use an int32 accumulator with float α/β, a
25//!   programming model that doesn't share kernel shape with the
26//!   elementwise plans.
27//! - [`FpElement`] — 8-bit floating-point GEMM operands ([`Fp8E4M3`],
28//!   [`Fp8E5M2`]). sm_89+ only.
29//! - [`BinElement`] — 1-bit packed-byte binary GEMM operands ([`Bin`]).
30//!   Distinct programming model (`mma.sync ... .b1.b1.s32.xor.popc`).
31//!
32//! Three sibling traits cover the auxiliary slot types that don't fit
33//! [`KernelDtype`]'s `ElementKind` projection (they have their own kind
34//! enum):
35//!
36//! - [`BiasElement`] — bias broadcast element types accepted by integer
37//!   GEMM epilogues. Today: `f32` and `i32`.
38//! - [`IndexElement`] — index element types accepted by indexing /
39//!   embedding / segment kernel families. Today: `i32` (legacy) and
40//!   `i64` (PyTorch default).
41//! - [`IndexOutputElement`] — output index dtype produced by
42//!   arg-reduction kernels. Today: `u32`, `i32`, `i64`.
43//!
44//! # When to use which
45//!
46//! - Reach for [`Element`] when writing a plan parameterized over the
47//!   primitive-FP / int / bool / complex family that goes through the
48//!   shared `BinaryPlan<T, N>` / `UnaryPlan<T, N>` shape.
49//! - Reach for [`IntElement`] / [`FpElement`] / [`BinElement`] when the
50//!   plan is one of the sub-byte / packed GEMM families.
51//! - Reach for [`KernelDtype`] when you genuinely don't care which
52//!   family the dtype belongs to — e.g. a generic "dtype size in bytes"
53//!   helper, a telemetry function that just needs the [`ElementKind`]
54//!   tag, or a downstream wrapper that wants to accept the union of all
55//!   kernel-usable dtypes.
56//!
57//! `Element` was originally named `CutlassElement` in the
58//! `baracuda-cutlass` crate. The rename here unifies the vocabulary
59//! across the wider kernel facade — `baracuda-cutlass` keeps the
60//! `CutlassElement` name available as a re-export for back-compat.
61
62use baracuda_types::DeviceRepr;
63use half::{bf16, f16};
64
65mod sealed {
66    pub trait Sealed {}
67}
68
69mod kerneldtype_sealed {
70    pub trait Sealed {}
71}
72
73mod scalar_sealed {
74    pub trait Sealed {}
75}
76
77mod int_sealed {
78    pub trait Sealed {}
79}
80
81mod fp_sealed {
82    pub trait Sealed {}
83}
84
85mod bin_sealed {
86    pub trait Sealed {}
87}
88
89mod bias_sealed {
90    pub trait Sealed {}
91}
92
93mod index_sealed {
94    pub trait Sealed {}
95}
96
97mod index_output_sealed {
98    pub trait Sealed {}
99}
100
101/// Umbrella marker trait for every dtype usable as a kernel input or
102/// output.
103///
/// The bound captures the three minimum properties a kernel dtype
/// needs: a fixed memory layout ([`DeviceRepr`]) so the host can ship
/// bytes to the device verbatim, `Copy + 'static` so values can be
/// passed by value through plan / args structs without borrowing, and
/// a runtime tag ([`ElementKind`]) for dispatch.
109///
110/// `KernelDtype` is **wider** than [`Element`]: it covers the
111/// sub-byte / FP8 / packed-bit newtypes (`S4`, `U4`, `S8`, `U8`,
112/// `Fp8E4M3`, `Fp8E5M2`, `Bin`) that have their own kernel families
113/// and don't fit the `<T: Element>` plan shape. Every [`Element`],
114/// [`IntElement`], [`FpElement`], and [`BinElement`] type also
115/// implements `KernelDtype` (the sibling traits all use it as a
116/// supertrait), so a function bounded by `<T: KernelDtype>` accepts
117/// any kernel-usable type.
118///
119/// Sealed because adding a new dtype requires a matching kernel
120/// instantiation in `baracuda-kernels-sys`.
121///
122/// # When to use
123///
124/// Prefer [`Element`] when you're parameterizing a plan that lives in
125/// the elementwise / reduce / scan / norm / loss families — those
126/// plan shapes are written against `<T: Element>` and use the
127/// `type Scalar` projection. Reach for `KernelDtype` only when you
128/// genuinely want the **union** of every kernel dtype (sub-byte +
129/// FP8 + packed-bit included) — e.g. a generic dtype-size helper,
130/// telemetry function, or downstream wrapper.
131pub trait KernelDtype:
132    DeviceRepr + kerneldtype_sealed::Sealed + Copy + 'static
133{
134    /// Runtime tag for this dtype. Stable across the workspace —
135    /// keyed by this same enum in [`crate::KernelSku::element`].
136    const KIND: ElementKind;
137}
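
/*
Illustrative sketch (not part of this crate): the kind of family-agnostic
helper the "When to use" note above has in mind. `ElementKind` and
`KernelDtype` below are simplified stand-ins (unsealed, no `DeviceRepr`
bound, only three assumed variants), and `storage_bytes` / `describe` are
hypothetical names.

```rust
use core::mem::size_of;

// Stand-in runtime tag; the real enum has more variants than sketched here.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ElementKind {
    F32,
    I32,
    I64,
}

// Stand-in for the umbrella trait: just `Copy + 'static` plus the tag.
pub trait KernelDtype: Copy + 'static {
    const KIND: ElementKind;
}

impl KernelDtype for f32 {
    const KIND: ElementKind = ElementKind::F32;
}
impl KernelDtype for i32 {
    const KIND: ElementKind = ElementKind::I32;
}
impl KernelDtype for i64 {
    const KIND: ElementKind = ElementKind::I64;
}

// Hypothetical helper: bytes needed for `len` storage slots of any dtype.
pub fn storage_bytes<T: KernelDtype>(len: usize) -> usize {
    len * size_of::<T>()
}

// Hypothetical telemetry helper: only needs the runtime tag and the size.
pub fn describe<T: KernelDtype>() -> String {
    format!("{:?} ({} bytes per slot)", T::KIND, size_of::<T>())
}
```

Neither helper cares which op-shaped sub-trait the dtype belongs to, which
is the case where the umbrella bound is the natural choice.
*/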
138
139/// Sealed marker for the alpha/beta scalar type an [`Element`] uses.
140///
141/// `f32` for f16/bf16/f32/[`F32Strict`] kernels (epilogue compute runs at
142/// f32). `f64` for f64 kernels. Sealed to keep the kernel-side dispatch
143/// closed — adding a new scalar type requires shipping new C ABI
144/// signatures in the underlying `*-kernels-sys` crate.
145pub trait ScalarType: scalar_sealed::Sealed + Copy + Default + PartialEq + 'static {
146    /// Discriminant used by the plan layer to dispatch to f32-scalar vs
147    /// f64-scalar FFI entry points.
148    const IS_F64: bool;
149
150    /// Additive identity (`0.0`). Useful when writing generic code over
151    /// `<S: ScalarType>` that needs to initialize accumulators or default
152    /// alpha/beta values.
153    const ZERO: Self;
154
155    /// Multiplicative identity (`1.0`). Useful when writing generic code
156    /// over `<S: ScalarType>` that needs a unit alpha value.
157    const ONE: Self;
158
159    /// Convert to `f32`. Used by the plan layer to feed the f32-scalar
160    /// FFI dispatchers when `IS_F64` is `false` (round-trip is lossless
161    /// because the underlying type IS `f32` in that branch). When called
162    /// on the `f64` impl this is a narrowing cast — only callers that
163    /// gate on `IS_F64 == false` should reach it.
164    #[doc(hidden)]
165    fn to_f32(self) -> f32;
166
167    /// Convert to `f64`. Used by the plan layer to feed the f64-scalar
168    /// FFI dispatchers when `IS_F64` is `true`. Lossless from both
169    /// underlying types.
170    #[doc(hidden)]
171    fn to_f64(self) -> f64;
172
173    /// Convert from `f32`. Lossless for the `f32` impl, widening for the
174    /// `f64` impl. Use this instead of `as` casts when writing generic
175    /// code over `<S: ScalarType>` — `S::from_f32(0.5)` works regardless
176    /// of which scalar type is bound.
177    fn from_f32(x: f32) -> Self;
178}
179
180impl scalar_sealed::Sealed for f32 {}
181impl scalar_sealed::Sealed for f64 {}
182
183impl ScalarType for f32 {
184    const IS_F64: bool = false;
185    const ZERO: Self = 0.0;
186    const ONE: Self = 1.0;
187    #[inline] fn to_f32(self) -> f32 { self }
188    #[inline] fn to_f64(self) -> f64 { self as f64 }
189    #[inline] fn from_f32(x: f32) -> Self { x }
190}
191impl ScalarType for f64 {
192    const IS_F64: bool = true;
193    const ZERO: Self = 0.0;
194    const ONE: Self = 1.0;
195    #[inline] fn to_f32(self) -> f32 { self as f32 }
196    #[inline] fn to_f64(self) -> f64 { self }
197    #[inline] fn from_f32(x: f32) -> Self { x as f64 }
198}
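
/*
Illustrative sketch (not part of this crate): generic code written against
a scalar bound, using the associated consts and `from_f32` instead of `as`
casts. The trait below is an unsealed mirror reduced to the members used
here; `identity_scaling`, `ffi_family`, and `half` are hypothetical helpers.

```rust
// Reduced, unsealed mirror of the scalar trait.
pub trait ScalarType: Copy + PartialEq + 'static {
    const IS_F64: bool;
    const ZERO: Self;
    const ONE: Self;
    fn from_f32(x: f32) -> Self;
}

impl ScalarType for f32 {
    const IS_F64: bool = false;
    const ZERO: Self = 0.0;
    const ONE: Self = 1.0;
    fn from_f32(x: f32) -> Self {
        x
    }
}

impl ScalarType for f64 {
    const IS_F64: bool = true;
    const ZERO: Self = 0.0;
    const ONE: Self = 1.0;
    fn from_f32(x: f32) -> Self {
        x as f64
    }
}

// (alpha, beta) pair that leaves a result unscaled and ignores the addend.
pub fn identity_scaling<S: ScalarType>() -> (S, S) {
    (S::ONE, S::ZERO)
}

// Which FFI family a plan layer would pick for this scalar type.
pub fn ffi_family<S: ScalarType>() -> &'static str {
    if S::IS_F64 { "f64-scalar" } else { "f32-scalar" }
}

// A constant built once, valid for either scalar type.
pub fn half<S: ScalarType>() -> S {
    S::from_f32(0.5)
}
```

The same generic body serves both scalar widths; only the `IS_F64` branch
differs at the dispatch point.
*/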
199
200/// Element types supported by the kernel facade.
201///
202/// Sealed to prevent downstream `impl`s — adding a new dtype requires
203/// shipping a new kernel instantiation in the corresponding `*-kernels-sys`
204/// crate.
205///
/// The trait spans four families that share the `<T: Element>`-
/// parameterized plan shape but route through distinct kernel SKUs:
///
/// - **Floating-point**: `f16`, `bf16`, `f32`, [`F32Strict`], `f64`.
///   `f32` reduces through TF32 tensor cores (10-bit mantissa);
///   [`F32Strict`] uses SIMT CUDA cores at full IEEE 754 binary32 with
///   bit-stable results. The `Scalar` projection is `f32` for the
///   16-bit / 32-bit float members and `f64` for `f64`.
/// - **Integer**: `i32`, `i64`. Used for elementwise integer arithmetic
///   (bitwise ops, integer comparison). The `Scalar` projection is
///   `f32`; these types don't participate in α/β-scaled epilogues, so
///   the projection is nominal. Note: [`S8`] / [`U8`] / [`S4`] / [`U4`]
///   are GEMM-only operand types and live on the separate [`IntElement`]
///   trait; they don't implement [`Element`].
/// - **Boolean**: [`Bool`] (1-byte storage, 0/non-zero truthiness).
///   Used for logical ops and as the output type of comparison ops.
///   The `Scalar` projection is `f32` (also nominal).
/// - **Complex**: [`Complex32`], [`Complex64`] (interleaved real/imag
///   pairs). Used by the FFT family for spectrum-domain tensors. The
///   `Scalar` projection matches the real width: `f32` for
///   [`Complex32`], `f64` for [`Complex64`].
223///
224/// Sibling traits [`IntElement`], [`FpElement`], [`BinElement`], and
225/// [`BiasElement`] cover GEMM-only / FP8 / packed-bit / bias-broadcast
226/// types respectively; those have their own kernel families and don't
227/// route through `<T: Element>`-parameterized elementwise plans. The
228/// umbrella [`KernelDtype`] supertrait covers the union of `Element`
229/// + `IntElement` + `FpElement` + `BinElement`.
230///
231/// # `KIND` lookup
232///
233/// `Element` does NOT redeclare `const KIND`; the const is inherited
234/// from the [`KernelDtype`] supertrait. This keeps `T::KIND` unambiguous
235/// at every call site under `<T: Element>` bounds. Pre-Phase-28 code
236/// using the fully-qualified form `<T as Element>::KIND` must update
237/// to `<T as KernelDtype>::KIND` (or just plain `T::KIND` which works
238/// regardless of which trait bound is in scope).
239pub trait Element: KernelDtype + sealed::Sealed {
240    /// Scalar type used for the kernel's alpha / beta parameters (and
241    /// the epilogue compute type). `f32` for f16/bf16/f32/[`F32Strict`]
242    /// — the epilogue runs at f32 to match the F32 accumulator. `f64`
243    /// for [`prim@f64`] — the DGEMM path uses an F64 accumulator and
244    /// f64 alpha/beta. For integer / [`Bool`] elements the projection
245    /// is nominally `f32` (no α/β-scaled epilogue applies).
246    type Scalar: ScalarType;
247}
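
/*
Illustrative sketch (not part of this crate): how the `Scalar` projection
and the inherited `KIND` const look from generic code. All traits here are
simplified stand-ins (unsealed, two assumed `ElementKind` variants), and
`Scaling`, `kind_plain`, and `kind_qualified` are hypothetical names.

```rust
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ElementKind {
    F32,
    F64,
}

pub trait KernelDtype: Copy + 'static {
    const KIND: ElementKind;
}

// The sub-trait adds only the projection; `KIND` lives on the supertrait.
pub trait Element: KernelDtype {
    type Scalar: Copy + Default + PartialEq;
}

impl KernelDtype for f32 {
    const KIND: ElementKind = ElementKind::F32;
}
impl KernelDtype for f64 {
    const KIND: ElementKind = ElementKind::F64;
}
impl Element for f32 {
    type Scalar = f32;
}
impl Element for f64 {
    type Scalar = f64;
}

// Hypothetical args struct: alpha / beta are typed by the projection.
pub struct Scaling<T: Element> {
    pub alpha: T::Scalar,
    pub beta: T::Scalar,
}

impl<T: Element> Scaling<T> {
    pub fn zeroed() -> Self {
        Self { alpha: Default::default(), beta: Default::default() }
    }
}

// Plain `T::KIND` resolves through the supertrait without ambiguity.
pub fn kind_plain<T: Element>() -> ElementKind {
    T::KIND
}

// The fully-qualified form names the supertrait, not `Element`.
pub fn kind_qualified<T: Element>() -> ElementKind {
    <T as KernelDtype>::KIND
}
```

Because the sub-trait does not redeclare the const, both spellings return
the same tag.
*/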
248
249impl sealed::Sealed for f16 {}
250impl sealed::Sealed for bf16 {}
251impl sealed::Sealed for f32 {}
252impl sealed::Sealed for F32Strict {}
253impl sealed::Sealed for f64 {}
254impl sealed::Sealed for i32 {}
255impl sealed::Sealed for i64 {}
256impl sealed::Sealed for Bool {}
257
258impl Element for f16 {
259    type Scalar = f32;
260}
261
262impl Element for bf16 {
263    type Scalar = f32;
264}
265
266/// `f32` GEMM routes through TF32 tensor cores — see
267/// [`crate::PrecisionGuarantee::math_precision`] (returns
268/// [`MathPrecision::Tf32`]). Inputs are full F32; the math instruction
269/// reduces to TF32 (10-bit mantissa) and accumulates into F32. Use
270/// [`F32Strict`] instead when bit-stable, full-precision IEEE 754
271/// binary32 math is required.
272impl Element for f32 {
273    type Scalar = f32;
274}
275
276/// `f64` GEMM via Ampere FP64 tensor cores (DGEMM). Full IEEE 754
277/// binary64 inputs, accumulator, and scalars. Analogous to cuBLAS's
278/// `CUBLAS_COMPUTE_64F`.
279impl Element for f64 {
280    type Scalar = f64;
281}
282
283/// `i32` as an elementwise kernel input element. Used by the integer
284/// arithmetic kernels (bitwise and / or / xor / shift, integer
285/// comparison, integer scans). Distinct from [`ElementKind::I32`]'s
286/// historical use as an accumulator-only marker for integer GEMMs —
287/// here `i32` is a first-class kernel *input* type with an `Element`
288/// impl, so the same `BinaryPlan<T, N>` / `UnaryPlan<T, N>` shapes
289/// extend to integer arithmetic.
290///
291/// The `Scalar` projection is `f32` (nominal — integer kernels don't
292/// use α/β-scaled epilogues today).
293impl Element for i32 {
294    type Scalar = f32;
295}
296
297/// `i64` as an elementwise kernel input element. Sibling of the `i32`
298/// impl above for 64-bit integer arithmetic (PyTorch's default integer
299/// tensor dtype). Same kernel families, twice the storage width.
300impl Element for i64 {
301    type Scalar = f32;
302}
303
304/// Boolean as an elementwise kernel input element. Used by the logical
305/// op family (`logical_and` / `logical_or` / `logical_xor`) and as the
306/// output type of comparison ops. Storage is 1 byte per element via the
307/// [`Bool`] wrapper.
308///
309/// The `Scalar` projection is `f32` (nominal).
310impl Element for Bool {
311    type Scalar = f32;
312}
313
314impl sealed::Sealed for Complex32 {}
315impl sealed::Sealed for Complex64 {}
316
317/// Single-precision complex (interleaved real/imag pair of `f32`) as an
318/// elementwise kernel input element. Used by the FFT family (`fft`,
319/// `ifft`, `rfft` output / `irfft` input, etc.) for spectrum-domain
320/// tensors. The `Scalar` projection is `f32` (matches the real width).
321impl Element for Complex32 {
322    type Scalar = f32;
323}
324
325/// Double-precision complex (interleaved real/imag pair of `f64`) as an
326/// elementwise kernel input element. Sibling to [`Complex32`]; the
327/// `Scalar` projection is `f64`.
328impl Element for Complex64 {
329    type Scalar = f64;
330}
331
332// ============================================================================
333// Boolean element type — implements Element directly
334// ============================================================================
335
336/// Boolean element marker. `#[repr(transparent)]` wrapper around `u8`
337/// (1-byte storage).
338///
339/// Truthiness convention follows PyTorch / NumPy: `0` is false; **any**
340/// non-zero byte is true. Kernels that consume `Bool` operands normalize
341/// the input to `0` or `1` before applying the logical op so the result
342/// is always strictly `0` or `1`. The wrapper is `#[repr(transparent)]`
343/// over `u8`, so a `DeviceBuffer<u8>` (byte substrate) can be
344/// reinterpreted as a `DeviceBuffer<Bool>` via `view_as` without
345/// copying.
346///
347/// Used as the element type of comparison-op output tensors (`eq`, `gt`,
348/// …) and as the input element type for the logical-op family
349/// (`logical_and`, `logical_or`, `logical_xor`). Implements [`Element`]
350/// so the same `BinaryPlan<T, N>` / `UnaryPlan<T, N>` shapes extend to
351/// boolean tensors.
352#[repr(transparent)]
353#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
354pub struct Bool(pub u8);
355
356impl Bool {
357    /// Build a [`Bool`] from a Rust `bool`. `true` becomes `1`, `false`
358    /// becomes `0`.
359    #[inline]
360    pub const fn new(b: bool) -> Self {
361        Self(b as u8)
362    }
363
364    /// Convert to a Rust `bool` using the PyTorch convention: any
365    /// non-zero byte is true.
366    #[inline]
367    pub const fn to_bool(self) -> bool {
368        self.0 != 0
369    }
370}
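
/*
Illustrative sketch (not part of this crate): why the truthiness
convention calls for normalizing before combining. The struct is a local
copy so the snippet compiles on its own; `logical_and` and `logical_xor`
are hypothetical host-side references, not the device kernels.

```rust
#[repr(transparent)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Bool(pub u8);

impl Bool {
    pub const fn new(b: bool) -> Self {
        Self(b as u8)
    }
    pub const fn to_bool(self) -> bool {
        self.0 != 0
    }
}

// Normalize each operand to a Rust `bool`, then combine; result is 0 or 1.
pub fn logical_and(a: Bool, b: Bool) -> Bool {
    Bool::new(a.to_bool() && b.to_bool())
}

pub fn logical_xor(a: Bool, b: Bool) -> Bool {
    Bool::new(a.to_bool() != b.to_bool())
}
```

With operands `Bool(4)` and `Bool(2)`, both are true, yet the raw bytes
give `4 & 2 == 0` and `4 ^ 2 == 6`; a bitwise shortcut would report the
wrong answer for both ops.
*/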
371
372// SAFETY: Bool is #[repr(transparent)] around u8, which is DeviceRepr.
373// Same ABI, same Copy + 'static bounds.
374unsafe impl DeviceRepr for Bool {}
375
376// ============================================================================
// Complex element types (implement Element directly)
378// ============================================================================
379
380/// Single-precision complex element. `#[repr(C)]` struct of two `f32`
381/// fields (real, imag) — ABI-compatible with cuFFT's `cufftComplex`
382/// (which is itself an alias for CUDA's `float2`), with NumPy's
383/// `complex64`, and with PyTorch's `torch.complex64`.
384///
385/// Used by the FFT op family (Milestone 6.4) as the element type for
386/// spectrum-domain tensors. Complex arithmetic is not a kernel concern
387/// at this layer — Rust callers build / inspect complex values via the
388/// `re` / `im` fields and pass `DeviceBuffer<Complex32>` directly to
389/// the FFT plans, which reinterpret them as `cufftComplex` over the
390/// FFI boundary.
391///
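/// Layout invariant: `Complex32 { re, im }` and `cufftComplex { x, y }`
/// share identical byte storage on every platform CUDA supports
/// (8 bytes, real part first, no padding). Note that `Complex32` only
/// has the 4-byte alignment of `f32` on the host, while CUDA's `float2`
/// is declared 8-byte aligned. A `DeviceBuffer<Complex32>` can be
/// reinterpreted as a `DeviceBuffer<cufftComplex>` via `view_as`
/// without copying, provided the device allocation is at least 8-byte
/// aligned.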
397#[repr(C)]
398#[derive(Copy, Clone, Debug, Default, PartialEq)]
399pub struct Complex32 {
400    /// Real component.
401    pub re: f32,
402    /// Imaginary component.
403    pub im: f32,
404}
405
406impl Complex32 {
407    /// Build a `Complex32` from real and imaginary `f32` parts.
408    #[inline]
409    pub const fn new(re: f32, im: f32) -> Self {
410        Self { re, im }
411    }
412}
413
414/// Double-precision complex element. `#[repr(C)]` struct of two `f64`
415/// fields — ABI-compatible with cuFFT's `cufftDoubleComplex`, NumPy's
416/// `complex128`, and PyTorch's `torch.complex128`. Sibling to
417/// [`Complex32`].
418#[repr(C)]
419#[derive(Copy, Clone, Debug, Default, PartialEq)]
420pub struct Complex64 {
421    /// Real component.
422    pub re: f64,
423    /// Imaginary component.
424    pub im: f64,
425}
426
427impl Complex64 {
428    /// Build a `Complex64` from real and imaginary `f64` parts.
429    #[inline]
430    pub const fn new(re: f64, im: f64) -> Self {
431        Self { re, im }
432    }
433}
434
// SAFETY: Complex32 / Complex64 are #[repr(C)] structs of two FP fields
// each, with no padding (8 and 16 bytes in size, aligned like `f32` and
// `f64` respectively), so they satisfy DeviceRepr's invariants (no
// uninitialized bytes, no host-side resource handles, byte-for-byte
// transferable between host and device).
439unsafe impl DeviceRepr for Complex32 {}
440unsafe impl DeviceRepr for Complex64 {}
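
/*
Illustrative sketch (not part of this crate): checking the host-side
layout that the byte-for-byte transfer argument relies on. The structs are
local copies; `layout_c32`, `layout_c64`, and `from_interleaved` are
hypothetical helpers.

```rust
use core::mem::{align_of, size_of};

#[repr(C)]
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct Complex32 {
    pub re: f32,
    pub im: f32,
}

#[repr(C)]
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct Complex64 {
    pub re: f64,
    pub im: f64,
}

// (size, alignment) as the host compiler lays the struct out.
pub fn layout_c32() -> (usize, usize) {
    (size_of::<Complex32>(), align_of::<Complex32>())
}

pub fn layout_c64() -> (usize, usize) {
    (size_of::<Complex64>(), align_of::<Complex64>())
}

// Build complex values from an interleaved `[re0, im0, re1, im1, ...]` slice.
pub fn from_interleaved(data: &[f32]) -> Vec<Complex32> {
    data.chunks_exact(2)
        .map(|pair| Complex32 { re: pair[0], im: pair[1] })
        .collect()
}
```

The size equals the sum of the two fields, so there are no padding bytes;
the host alignment is that of the scalar field, not of the whole pair.
*/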
441
442// ============================================================================
443// Integer element family — sibling to Element
444// ============================================================================
445
446/// Signed 8-bit integer element marker. `#[repr(transparent)]` around
447/// `i8`.
448///
449/// Identical memory layout to `i8`, so a `DeviceBuffer<i8>` (or any byte
450/// substrate the caller has) can be reinterpreted as a `DeviceBuffer<S8>`
451/// via `view_as` without copying. The wrapper exists to drive kernel
452/// selection at the Rust type level: integer GEMM plans parameterized on
453/// `S8` route the launch through the signed int8 tensor-core kernels.
454///
455/// Numerical contract: int8 inputs, int32 accumulator, float alpha/beta
456/// scaling, saturating round-to-nearest cast back to int8 on store.
457#[repr(transparent)]
458#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
459pub struct S8(pub i8);
460
461/// Unsigned 8-bit integer element marker. `#[repr(transparent)]` around
462/// `u8`.
463///
464/// Identical memory layout to `u8`, so a `DeviceBuffer<u8>` (byte
465/// substrate) can be reinterpreted as a `DeviceBuffer<U8>` (quantized
466/// GEMM operand) via `view_as` without copying. The wrapper exists to
467/// disambiguate "byte buffer" from "quantized operand" at the Rust type
468/// level — a `DeviceBuffer<U8>` is unambiguously a GEMM operand,
469/// `DeviceBuffer<u8>` stays a byte-storage abstraction.
470///
471/// Numerical contract: same as [`S8`] except the multiply operands are
472/// unsigned. The accumulator is still int32 and alpha/beta are still
473/// float; saturating cast at store clamps to `[0, 255]`.
474#[repr(transparent)]
475#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
476pub struct U8(pub u8);
477
478// SAFETY: S8 / U8 are #[repr(transparent)] around i8 / u8, which are
479// both DeviceRepr. Same ABI, same Copy + 'static bounds.
480unsafe impl DeviceRepr for S8 {}
481unsafe impl DeviceRepr for U8 {}
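
/*
Illustrative sketch (not part of this crate): a host-side reference for
the numerical contract stated above (int32 accumulator, float alpha/beta,
saturating round-to-nearest store). All function names are hypothetical,
and rounding ties away from zero via `f32::round` is an assumption; the
text does not state the kernel's tie-breaking rule.

```rust
// Accumulate an int8 dot product in i32, as the accumulator contract says.
pub fn dot_s8(a: &[i8], b: &[i8]) -> i32 {
    a.iter().zip(b).map(|(&x, &y)| x as i32 * y as i32).sum()
}

// Scale in f32, round, then store as i8. Rust's float-to-int `as` cast
// saturates, so out-of-range values clamp to [-128, 127].
pub fn requantize_s8(acc: i32, alpha: f32, beta: f32, c: i8) -> i8 {
    let scaled = alpha * acc as f32 + beta * c as f32;
    scaled.round() as i8
}

// Unsigned variant: the same cast clamps to [0, 255].
pub fn requantize_u8(acc: i32, alpha: f32, beta: f32, c: u8) -> u8 {
    let scaled = alpha * acc as f32 + beta * c as f32;
    scaled.round() as u8
}
```

Two maximal int8 products already exceed the int8 range by orders of
magnitude, which is why the accumulator is wider than the operands and
the store has to saturate.
*/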
482
483/// Signed 4-bit integer element marker — **packed-pair storage**.
484///
485/// `#[repr(transparent)]` around `u8`. One [`S4`] *storage slot* is one
486/// byte and holds **two** packed s4 elements: the low nibble is the
487/// element at even logical index, the high nibble is the element at
488/// odd logical index (along the K axis for A/B operands, along the
489/// N axis for D output). Sign-extended to s32 on the GPU side via
490/// `((s8)(nibble << 4)) >> 4`.
491///
492/// A `DeviceBuffer<u8>` of `(M*K)/2` bytes can be reinterpreted as a
493/// `DeviceBuffer<S4>` of `(M*K)/2` storage slots via `view_as` without
494/// copying — `S4` is byte-storage at the buffer layer, and *element
495/// count* lives at the plan-layer descriptor (M / N / K).
496///
497/// Numerical range per element: `[-8, +7]`. The plan layer
498/// (`Int4GemmPlan` in `baracuda-kernels`) takes `M`, `N`, `K` in
499/// **element** counts and leading dimensions in **storage-slot
500/// (= byte)** counts — `MatrixRef<S4>::ld` therefore equals `K / 2` for
501/// row-major A with no padding. `K` must be even (packing is byte-
502/// aligned). Routes through Ada Lovelace int4 tensor cores
503/// (`mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32`) with
504/// S32 accumulation and float `alpha` / `beta` scaling. First landed in
505/// baracuda-kernels Phase 2.
506#[repr(transparent)]
507#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
508pub struct S4(pub u8);
509
510/// Unsigned 4-bit integer element marker — **packed-pair storage**.
511///
512/// `#[repr(transparent)]` around `u8`. Packing convention is identical
513/// to [`S4`] (low nibble = even index, high nibble = odd index); the
514/// only difference is zero-extension to s32 on the GPU side
515/// (`nibble & 0xF`).
516///
517/// Numerical range per element: `[0, 15]`. Plan-layer conventions
518/// (M/N/K in elements, LDs in storage slots, K even) match [`S4`].
519/// Routes through Ada Lovelace int4 tensor cores
520/// (`mma.sync.aligned.m16n8k64.row.col.satfinite.s32.u4.u4.s32`) with
521/// the same S32 accumulator and `float` α/β family as [`S4`].
522#[repr(transparent)]
523#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
524pub struct U4(pub u8);
525
526// SAFETY: S4 / U4 are #[repr(transparent)] around u8, which is DeviceRepr.
527// Same ABI, same Copy + 'static bounds.
528unsafe impl DeviceRepr for S4 {}
529unsafe impl DeviceRepr for U4 {}
530
531impl S4 {
532    /// Pack two s4 values `[lo, hi]` (each in `[-8, +7]`) into one
533    /// storage slot. Values outside the range are masked to their low 4
534    /// bits (no saturation — pre-clamp on the caller side if needed).
535    #[inline]
536    pub fn pack(lo: i8, hi: i8) -> Self {
537        Self(((lo as u8) & 0x0F) | (((hi as u8) & 0x0F) << 4))
538    }
539
540    /// Unpack into `[low_nibble_as_s4, high_nibble_as_s4]`. Each
541    /// returned value is sign-extended from the 4-bit nibble.
542    #[inline]
543    pub fn unpack(self) -> [i8; 2] {
544        let lo = ((self.0 & 0x0F) << 4) as i8 >> 4;
545        let hi = (self.0 & 0xF0) as i8 >> 4;
546        [lo, hi]
547    }
548}
549
550impl U4 {
551    /// Pack two u4 values `[lo, hi]` (each in `[0, 15]`) into one
552    /// storage slot. Values outside the range are masked to their low 4
553    /// bits.
554    #[inline]
555    pub fn pack(lo: u8, hi: u8) -> Self {
556        Self((lo & 0x0F) | ((hi & 0x0F) << 4))
557    }
558
559    /// Unpack into `[low_nibble, high_nibble]`. Each returned value is
560    /// in `[0, 15]`.
561    #[inline]
562    pub fn unpack(self) -> [u8; 2] {
563        [self.0 & 0x0F, (self.0 >> 4) & 0x0F]
564    }
565}
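
/*
Illustrative sketch (not part of this crate): a worked round trip through
the packed-pair convention (low nibble = even index, high nibble = odd
index). `S4` is a local copy of the wrapper with the same pack / unpack
bodies; `pack_row` is a hypothetical helper.

```rust
#[repr(transparent)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct S4(pub u8);

impl S4 {
    // Mask each value to 4 bits; `lo` goes in bits 0..4, `hi` in bits 4..8.
    pub fn pack(lo: i8, hi: i8) -> Self {
        Self(((lo as u8) & 0x0F) | (((hi as u8) & 0x0F) << 4))
    }

    // Move each nibble to the top of an i8, then arithmetic-shift back down
    // so the sign bit of the nibble is extended.
    pub fn unpack(self) -> [i8; 2] {
        let lo = ((self.0 & 0x0F) << 4) as i8 >> 4;
        let hi = (self.0 & 0xF0) as i8 >> 4;
        [lo, hi]
    }
}

// Pack a row of K signed 4-bit values into K / 2 storage slots.
pub fn pack_row(values: &[i8]) -> Vec<S4> {
    assert!(values.len() % 2 == 0, "K must be even");
    values.chunks_exact(2).map(|p| S4::pack(p[0], p[1])).collect()
}
```

`S4::pack(-3, 7)` stores `0x7D`: `-3` masks to `0xD` in the low nibble and
`7` lands in the high nibble. An out-of-range input such as `9` is masked,
not saturated, and reads back as `-7`.
*/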
566
567/// Integer element types supported by the int-GEMM kernel set.
568///
/// Sibling trait to [`Element`] (the elementwise plan family), kept
/// separate because the kernel-level dispatch, accumulator type (int32
/// vs f32), and epilogue family differ enough that mixing them through
/// a single trait would smear the type signatures of integer plans.
573///
574/// Sealed to prevent downstream `impl`s — adding a new int dtype
575/// requires shipping new kernel instantiations.
576///
577/// `KIND` is inherited from the [`KernelDtype`] supertrait. Pre-Phase-28
578/// code using `<T as IntElement>::KIND` must update to plain `T::KIND`
579/// or `<T as KernelDtype>::KIND`.
580pub trait IntElement: KernelDtype + int_sealed::Sealed {}
581
582impl int_sealed::Sealed for S8 {}
583impl int_sealed::Sealed for U8 {}
584impl int_sealed::Sealed for S4 {}
585impl int_sealed::Sealed for U4 {}
586
587impl IntElement for S8 {}
588impl IntElement for U8 {}
589impl IntElement for S4 {}
590impl IntElement for U4 {}
591
592// ============================================================================
593// 8-bit floating-point element family — sibling to Element / IntElement
594// ============================================================================
595
596/// 8-bit floating-point, E4M3 encoding (1 sign + 4 exponent + 3 mantissa,
597/// exponent bias 7).
598///
599/// `#[repr(transparent)]` around `u8` storage — bit-compatible with
600/// `__nv_fp8_storage_t` on the CUDA side and with `float8::F8E4M3` on the
601/// host side. A `DeviceBuffer<u8>` (byte substrate) can be reinterpreted
602/// as `DeviceBuffer<Fp8E4M3>` via `view_as` without copying.
603///
604/// Numerical range: ±448 (max finite). One NaN encoding only
605/// (`S.1111.111`); E4M3 has **no infinities**. The conversion path
606/// matches NVIDIA's `__nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3)`:
607/// round-half-to-even, saturating-to-max-finite on overflow.
608///
609/// Routes through Ada Lovelace FP8 tensor cores
610/// (`mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32`) with F32
611/// accumulation and float alpha / beta scaling. First landed in
612/// baracuda-kernels Phase 2.
613#[repr(transparent)]
614#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
615pub struct Fp8E4M3(pub u8);
616
617impl Fp8E4M3 {
618    /// Convert from `f32` using NVIDIA's `SATFINITE` semantics
619    /// (round-half-to-even, clamp `|x|` to the E4M3 max-finite `448.0`).
620    #[inline]
621    pub fn from_f32(x: f32) -> Self {
622        Self(float8::F8E4M3::from_f32(x).to_bits())
623    }
624
625    /// Convert to `f32`. The E4M3 grid is sparse — the result is one of
626    /// 254 finite values (or NaN) on the E4M3 lattice.
627    #[inline]
628    pub fn to_f32(self) -> f32 {
629        float8::F8E4M3::from_bits(self.0).to_f32()
630    }
631}
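
/*
Illustrative sketch (not part of this crate, and not the `float8` crate's
implementation): a reference decoder for the E4M3 bit layout described
above (1 sign, 4 exponent, 3 mantissa bits, bias 7, no infinities).
`e4m3_to_f32` and `finite_encodings` are hypothetical names.

```rust
pub fn e4m3_to_f32(bits: u8) -> f32 {
    let sign = if bits & 0x80 != 0 { -1.0f32 } else { 1.0f32 };
    let exp = ((bits >> 3) & 0x0F) as i32;
    let man = bits & 0x07;

    // The only non-finite pattern: all exponent and mantissa bits set.
    if exp == 0x0F && man == 0x07 {
        return f32::NAN;
    }

    let frac = man as f32 / 8.0;
    let magnitude = if exp == 0 {
        // Subnormal: no implicit leading one, exponent fixed at 1 - bias.
        frac * 2f32.powi(-6)
    } else {
        (1.0 + frac) * 2f32.powi(exp - 7)
    };
    sign * magnitude
}

// Count the bit patterns that decode to a finite value (signed zeros included).
pub fn finite_encodings() -> usize {
    (0u8..=255).filter(|&b| e4m3_to_f32(b).is_finite()).count()
}
```

The largest finite pattern is `0x7E`: exponent field 15, mantissa 6, so
`1.75 * 2^8 = 448`. Only `0x7F` and `0xFF` are NaN, leaving 254 finite
bit patterns.
*/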
632
633// SAFETY: Fp8E4M3 is #[repr(transparent)] over u8, which is DeviceRepr.
634// Same ABI, same Copy + 'static bounds.
635unsafe impl DeviceRepr for Fp8E4M3 {}
636
637/// 8-bit floating-point, E5M2 encoding (1 sign + 5 exponent + 2 mantissa,
638/// exponent bias 15).
639///
640/// `#[repr(transparent)]` around `u8` storage — bit-compatible with
641/// `__nv_fp8_storage_t` on the CUDA side and with `float8::F8E5M2` on the
642/// host side. A `DeviceBuffer<u8>` (byte substrate) can be reinterpreted
643/// as `DeviceBuffer<Fp8E5M2>` via `view_as` without copying.
644///
/// Numerical range: ±57344 (max finite). IEEE-style infinity and NaN
/// encodings (unlike [`Fp8E4M3`], which has no infinities and only a
/// single NaN mantissa pattern). The conversion path matches NVIDIA's
/// `__nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E5M2)`:
/// round-half-to-even, saturating-to-max-finite on overflow.
650///
651/// Routes through Ada Lovelace FP8 tensor cores
652/// (`mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32`) with F32
653/// accumulation and float alpha / beta scaling.
654#[repr(transparent)]
655#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
656pub struct Fp8E5M2(pub u8);
657
658impl Fp8E5M2 {
659    /// Convert from `f32` using NVIDIA's `SATFINITE` semantics
660    /// (round-half-to-even, clamp `|x|` to the E5M2 max-finite `57344.0`).
661    #[inline]
662    pub fn from_f32(x: f32) -> Self {
663        Self(float8::F8E5M2::from_f32(x).to_bits())
664    }
665
666    /// Convert to `f32`. The E5M2 grid is sparse — the result is one of
667    /// the finite values (or inf / NaN) on the E5M2 lattice.
668    #[inline]
669    pub fn to_f32(self) -> f32 {
670        float8::F8E5M2::from_bits(self.0).to_f32()
671    }
672}
673
674// SAFETY: Fp8E5M2 is #[repr(transparent)] over u8, which is DeviceRepr.
675// Same ABI, same Copy + 'static bounds.
676unsafe impl DeviceRepr for Fp8E5M2 {}
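
/*
Illustrative sketch (not part of this crate, and not the `float8` crate's
implementation): a reference decoder for the E5M2 bit layout (1 sign,
5 exponent, 2 mantissa bits, bias 15) with IEEE-style infinity and NaN.
`e5m2_to_f32` and `clamp_to_e5m2_range` are hypothetical names; the clamp
only mirrors the overflow half of the saturating conversion, not the
rounding step.

```rust
pub fn e5m2_to_f32(bits: u8) -> f32 {
    let sign = if bits & 0x80 != 0 { -1.0f32 } else { 1.0f32 };
    let exp = ((bits >> 2) & 0x1F) as i32;
    let man = bits & 0x03;

    let magnitude = match (exp, man) {
        // All-ones exponent: infinity if the mantissa is zero, NaN otherwise.
        (0x1F, 0) => f32::INFINITY,
        (0x1F, _) => f32::NAN,
        // Subnormal: no implicit leading one, exponent fixed at 1 - bias.
        (0, m) => (m as f32 / 4.0) * 2f32.powi(-14),
        (e, m) => (1.0 + m as f32 / 4.0) * 2f32.powi(e - 15),
    };
    sign * magnitude
}

// Clamp a finite input to the representable range before encoding.
pub fn clamp_to_e5m2_range(x: f32) -> f32 {
    x.clamp(-57344.0, 57344.0)
}
```

The largest finite pattern is `0x7B`: exponent field 30, mantissa 3, so
`1.75 * 2^15 = 57344`. The next pattern, `0x7C`, is positive infinity.
*/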
677
678/// 8-bit floating-point element types supported by the kernel facade.
679///
/// Sibling trait to [`Element`] (the elementwise plan family: f16 /
/// bf16 / f32 / [`F32Strict`] / f64, plus integer, [`Bool`], and complex
/// types) and to [`IntElement`] (which covers S8 / U8 / S4 / U4). Kept
/// separate because the FP8 kernel family has its own MMA instruction
/// set (`mma.sync ... .f32.e4m3.e4m3.f32`), arch requirement (sm_89+),
/// and conversion semantics (saturating-to-max-finite vs the int
/// family's saturation to the integer range).
686///
687/// Sealed because adding a new FP8 variant requires shipping new kernel
688/// instantiations in `baracuda-kernels-sys`.
689///
690/// `KIND` is inherited from the [`KernelDtype`] supertrait. Pre-Phase-28
691/// code using `<T as FpElement>::KIND` must update to plain `T::KIND`
692/// or `<T as KernelDtype>::KIND`.
693pub trait FpElement: KernelDtype + fp_sealed::Sealed {}
694
695impl fp_sealed::Sealed for Fp8E4M3 {}
696impl fp_sealed::Sealed for Fp8E5M2 {}
697
698impl FpElement for Fp8E4M3 {}
699impl FpElement for Fp8E5M2 {}
700
701// ============================================================================
702// Binary element family — sibling to Element / IntElement / FpElement
703// ============================================================================
704
705/// 1-bit binary element marker — **packed-byte storage**.
706///
707/// `#[repr(transparent)]` around `u8`. One [`Bin`] *storage slot* is one
708/// byte and holds **eight** packed b1 elements: bit `i` of the byte
709/// (LSB = bit 0) is the element at K offset `8 * byte_idx + i`. Packing
710/// is along the K axis for A/B operands.
711///
712/// A `DeviceBuffer<u8>` of `(M*K)/8` bytes can be reinterpreted as a
713/// `DeviceBuffer<Bin>` of `(M*K)/8` storage slots via `view_as` without
714/// copying.
715///
/// Routes through Ampere+ binary tensor cores
/// (`mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc`) with
/// an **S32 output accumulator**. Unlike the int4 / int8 / FP8
/// families, bin GEMM does **not** quantize its output back to the
/// input element type: the result is the raw popcount accumulator
/// (`popcount(xor(A_row, B_col))` summed over the `K / 8` packed bytes),
/// surfaced as `i32`. No α / β / bias / activation chain (the popcount
/// programming model doesn't have a meaningful place for them).
724///
/// The plan layer (`BinGemmPlan` in `baracuda-kernels`) takes `M`, `N`,
/// `K` in **element** counts and leading dimensions in **storage-slot
/// (= byte)** counts, so `MatrixRef<Bin>::ld` equals `K / 8` for
/// row-major A with no padding. `K` must be divisible by 8 (packing is
/// byte-aligned).
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Bin(pub u8);

// SAFETY: Bin is #[repr(transparent)] around u8, which is DeviceRepr.
unsafe impl DeviceRepr for Bin {}

impl Bin {
    /// Pack eight `bool` values `bits[0..8]` into one storage byte.
    /// `bits[i]` becomes bit `i` of the result (LSB-first).
    #[inline]
    pub fn pack(bits: [bool; 8]) -> Self {
        let mut b = 0u8;
        let mut i = 0;
        while i < 8 {
            if bits[i] {
                b |= 1 << i;
            }
            i += 1;
        }
        Self(b)
    }

    /// Unpack one storage byte into eight `bool` values along K (LSB-first).
    #[inline]
    pub fn unpack(self) -> [bool; 8] {
        let b = self.0;
        [
            (b >> 0) & 1 != 0,
            (b >> 1) & 1 != 0,
            (b >> 2) & 1 != 0,
            (b >> 3) & 1 != 0,
            (b >> 4) & 1 != 0,
            (b >> 5) & 1 != 0,
            (b >> 6) & 1 != 0,
            (b >> 7) & 1 != 0,
        ]
    }
}
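/*
Illustrative sketch, not part of the crate: a host-side reference for the
LSB-first packing and the popcount-of-XOR accumulation described in the `Bin`
docs. `Bin` is re-declared locally so the snippet compiles on its own;
`xor_popc_reference` and `bin_ld` are hypothetical helpers, and nothing here
touches the device path or `BinGemmPlan`.

```rust
/// Local stand-in for the crate's `Bin` (assumption: same LSB-first layout).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct Bin(u8);

impl Bin {
    fn pack(bits: [bool; 8]) -> Self {
        let mut b = 0u8;
        for (i, &bit) in bits.iter().enumerate() {
            if bit {
                b |= 1u8 << i;
            }
        }
        Self(b)
    }

    fn unpack(self) -> [bool; 8] {
        let mut out = [false; 8];
        for (i, slot) in out.iter_mut().enumerate() {
            *slot = (self.0 >> i) & 1 != 0;
        }
        out
    }
}

/// One output element: popcount(xor(a, b)) summed over the K / 8 packed bytes.
fn xor_popc_reference(a_row: &[Bin], b_col: &[Bin]) -> i32 {
    assert_eq!(a_row.len(), b_col.len(), "both operands span the same K");
    a_row
        .iter()
        .zip(b_col)
        .map(|(a, b)| (a.0 ^ b.0).count_ones() as i32)
        .sum()
}

/// Leading dimension in storage slots for an unpadded row of `k` elements,
/// or `None` when `k` is not byte-aligned.
fn bin_ld(k: usize) -> Option<usize> {
    if k % 8 == 0 {
        Some(k / 8)
    } else {
        None
    }
}

fn main() {
    let bits = [true, false, true, true, false, false, false, true];
    let packed = Bin::pack(bits);
    assert_eq!(packed, Bin(0b1000_1101));
    assert_eq!(packed.unpack(), bits);

    // K = 16 elements, so each row occupies 2 storage slots.
    let a = [Bin(0b1111_0000), Bin(0b0000_0001)];
    let b = [Bin(0b1010_1010), Bin(0b0000_0001)];
    assert_eq!(xor_popc_reference(&a, &b), 4);
    assert_eq!(bin_ld(16), Some(2));
    assert_eq!(bin_ld(12), None);
}
```

A reference like this is mainly useful for checking a device result on small
shapes.
*/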

/// Binary (1-bit) element types supported by the kernel facade.
///
/// Sibling trait to [`Element`] / [`IntElement`] / [`FpElement`] —
/// kept separate because the bin kernel family has a distinct
/// programming model (popcount-based, `D = popcount(xor(A, B))`, no
/// α/β/bias/activation chain) and a non-matching output type (raw
/// `i32` accumulator rather than re-quantized to the input type).
///
/// `KIND` is inherited from the [`KernelDtype`] supertrait. Pre-Phase-28
/// code using `<T as BinElement>::KIND` must update to plain `T::KIND`
/// or `<T as KernelDtype>::KIND`.
pub trait BinElement: KernelDtype + bin_sealed::Sealed {}

impl bin_sealed::Sealed for Bin {}

impl BinElement for Bin {}

/// Bias element types accepted by the int-GEMM bias epilogue family.
///
/// Integer GEMM kernels can broadcast either a per-channel `f32` bias
/// (matching the float bias convention used elsewhere) or a per-channel
/// `i32` bias (matching TensorRT's int8 inference convention). The
/// choice is a compile-time generic on integer plans — `<T, f32>` and
/// `<T, i32>` resolve to distinct kernel SKUs.
///
/// Sealed because the bias-element kernel variants are baked into the
/// `*-kernels-sys` crates at build time.
pub trait BiasElement: DeviceRepr + bias_sealed::Sealed + Copy + 'static {
    /// Runtime tag for this bias element type.
    const KIND: BiasElementKind;
}

impl bias_sealed::Sealed for f32 {}
impl bias_sealed::Sealed for i32 {}

impl BiasElement for f32 {
    const KIND: BiasElementKind = BiasElementKind::F32;
}
impl BiasElement for i32 {
    const KIND: BiasElementKind = BiasElementKind::I32;
}

/// Runtime tag for a [`BiasElement`].
///
/// **Intentionally NOT `#[non_exhaustive]`** — the int-GEMM bias
/// dispatchers exhaustively match `(T::KIND, BT::KIND)` to pick
/// per-bias-dtype kernel SKUs. Adding a new bias dtype (e.g. `f16`
/// for quantized-GEMM) should surface as a build break across every
/// match site so each site can wire it up or reject it. New variants
/// are a deliberate breaking-change event.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum BiasElementKind {
    /// IEEE 754 binary32 bias broadcast. The conservative default —
    /// matches the float-GEMM bias convention.
    F32,
    /// Signed 32-bit integer bias broadcast. Matches the convention
    /// TensorRT uses for int8 inference (per-channel int32 bias).
    I32,
}
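/*
Illustrative sketch, not the crate's dispatcher: why leaving
`BiasElementKind` exhaustive turns a new variant into a build break. The
enums are re-declared locally; `OperandKind`, `select_sku`, and the SKU
strings are hypothetical names made up for this example.

```rust
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum BiasElementKind {
    F32,
    I32,
}

/// Hypothetical operand tag; the crate's dispatchers match on `ElementKind`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum OperandKind {
    S8,
    U8,
}

/// Hypothetical SKU picker. There is deliberately no `_ =>` arm, so adding a
/// variant to either enum stops this function from compiling until the new
/// combination is handled.
fn select_sku(operand: OperandKind, bias: BiasElementKind) -> &'static str {
    match (operand, bias) {
        (OperandKind::S8, BiasElementKind::F32) => "s8_bias_f32",
        (OperandKind::S8, BiasElementKind::I32) => "s8_bias_i32",
        (OperandKind::U8, BiasElementKind::F32) => "u8_bias_f32",
        (OperandKind::U8, BiasElementKind::I32) => "u8_bias_i32",
    }
}

fn main() {
    assert_eq!(select_sku(OperandKind::S8, BiasElementKind::I32), "s8_bias_i32");
    assert_eq!(select_sku(OperandKind::U8, BiasElementKind::F32), "u8_bias_f32");
}
```
*/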

/// Sealed marker trait for index-element types accepted by the
/// indexing / embedding / segment kernel families.
///
/// Phase 11.5 (Fuel team feedback #7): split out as a sibling of
/// [`Element`] so plans like [`crate::indexing::GatherPlan`] /
/// [`crate::embedding::EmbeddingPlan`] / [`crate::segment::SegmentSumPlan`]
/// can dispatch over the index dtype without coupling the value-dtype
/// trait hierarchy. Today's members are `i32` (legacy) and `i64`
/// (PyTorch default). Sealed because new members require a matching
/// FFI entry point in the `*-kernels-sys` crate.
pub trait IndexElement: DeviceRepr + index_sealed::Sealed + Copy + 'static {
    /// Runtime tag for this index element type.
    const KIND: IndexElementKind;
}

impl index_sealed::Sealed for i32 {}
impl index_sealed::Sealed for i64 {}

impl IndexElement for i32 {
    const KIND: IndexElementKind = IndexElementKind::I32;
}
impl IndexElement for i64 {
    const KIND: IndexElementKind = IndexElementKind::I64;
}

/// Runtime tag for an [`IndexElement`]. `i32` is the legacy default;
/// `i64` was added in Phase 11.5 to match PyTorch's int64 index
/// convention without an extra cast pass.
///
/// `#[non_exhaustive]` — additional index dtypes (e.g. `u32`, following
/// the [`IndexOutputElement`] precedent) may land in future phases.
/// Match arms must include a `_ =>` catch-all.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum IndexElementKind {
    /// Signed 32-bit index dtype — legacy default.
    I32,
    /// Signed 64-bit index dtype — PyTorch default.
    I64,
}
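/*
Illustrative sketch of the catch-all arm that downstream code needs when
matching a `#[non_exhaustive]` kind enum. The enum is re-declared locally and
`index_width_bytes` is a hypothetical helper. Inside a single crate the
attribute has no effect, so the wildcard arm is flagged as unreachable here
and the lint is silenced; in a crate that depends on the real enum the arm is
required.

```rust
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
enum IndexElementKind {
    I32,
    I64,
}

/// Byte width of an index dtype, or `None` for a kind this code does not
/// know about yet.
#[allow(unreachable_patterns)]
fn index_width_bytes(kind: IndexElementKind) -> Option<usize> {
    match kind {
        IndexElementKind::I32 => Some(4),
        IndexElementKind::I64 => Some(8),
        // Future variants land here instead of breaking the build.
        _ => None,
    }
}

fn main() {
    assert_eq!(index_width_bytes(IndexElementKind::I32), Some(4));
    assert_eq!(index_width_bytes(IndexElementKind::I64), Some(8));
}
```
*/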

/// Sealed marker trait for the *output* index dtype produced by
/// arg-reduction kernels (`argmax` / `argmin` axis ops).
///
/// Phase 12.2 (Fuel team feedback): split out as a sibling of
/// [`IndexElement`] (which marks *input* index dtypes accepted by
/// indexing / embedding / segment kernels) so plans like
/// [`crate::ArgReduceKind`]-driven `ArgReducePlan` can dispatch over the
/// output dtype without affecting the input-index trait hierarchy.
///
/// Today's members are `u32`, `i32`, and `i64`. PyTorch defaults to
/// `i64`; CUB / NVIDIA libraries and some downstream frameworks (e.g.
/// Fuel) prefer `u32`. The trait is sealed because new members require
/// a matching FFI entry point in the `*-kernels-sys` crate.
pub trait IndexOutputElement:
    DeviceRepr + index_output_sealed::Sealed + Copy + Default + 'static
{
    /// Runtime tag for this output index element type.
    const KIND: IndexOutputKind;
}

impl index_output_sealed::Sealed for u32 {}
impl index_output_sealed::Sealed for i32 {}
impl index_output_sealed::Sealed for i64 {}

impl IndexOutputElement for u32 {
    const KIND: IndexOutputKind = IndexOutputKind::U32;
}
impl IndexOutputElement for i32 {
    const KIND: IndexOutputKind = IndexOutputKind::I32;
}
impl IndexOutputElement for i64 {
    const KIND: IndexOutputKind = IndexOutputKind::I64;
}

/// Runtime tag for an [`IndexOutputElement`]. `i64` is the default
/// (PyTorch convention) and was the only variant prior to Phase 12.2;
/// `u32` and `i32` were added so downstream frameworks that prefer
/// narrower index dtypes (Fuel uses `u32`) can avoid a post-pass cast.
///
/// `#[non_exhaustive]` — additional output index dtypes (`u64` for
/// frameworks that prefer unsigned indices end-to-end) may land in
/// future phases. Match arms must include a `_ =>` catch-all.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum IndexOutputKind {
    /// Unsigned 32-bit output index dtype.
    U32,
    /// Signed 32-bit output index dtype.
    I32,
    /// Signed 64-bit output index dtype — PyTorch default.
    I64,
}

/// Strict-precision f32 element marker.
///
/// `#[repr(transparent)]` wrapper around `f32`. Identical memory layout
/// to a plain `f32` device buffer — a `DeviceBuffer<f32>` can be
/// reinterpreted as a `DeviceBuffer<F32Strict>` via `view_as` without
/// copying. The wrapper exists purely to drive kernel selection at the
/// Rust type level: choosing the `F32Strict` element routes the launch
/// through the SIMT (CUDA-cores) GEMM kernels, while the plain `f32`
/// element routes through the TF32 tensor-core kernels.
///
/// Numerical contract: full IEEE 754 binary32 multiply-add throughout
/// (no tensor-core warp-reduction nondeterminism).
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, PartialOrd)]
pub struct F32Strict(pub f32);

// SAFETY: F32Strict is #[repr(transparent)] around f32, which is itself
// DeviceRepr. Same ABI, same Copy + 'static bounds.
unsafe impl DeviceRepr for F32Strict {}

impl Element for F32Strict {
    type Scalar = f32;
}
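/*
Illustrative sketch of what the TF32 path can drop from an `f32` input and
the `F32Strict` path keeps. `truncate_to_tf32` is a hypothetical host-side
emulation that keeps the sign, the 8 exponent bits, and the top 10 mantissa
bits. It assumes plain truncation; the hardware conversion may round instead,
so treat the result as an approximation of the precision loss, not a
bit-exact model.

```rust
/// Zero the low 13 mantissa bits of an `f32`, leaving 10 explicit mantissa
/// bits (assumption: truncation, not rounding).
fn truncate_to_tf32(x: f32) -> f32 {
    f32::from_bits(x.to_bits() & 0xFFFF_E000)
}

fn main() {
    // 1 + 2^-11 needs an 11th mantissa bit, so the increment is lost.
    let fine = 1.0_f32 + 1.0 / 2048.0;
    // 1 + 2^-10 fits in 10 mantissa bits and survives unchanged.
    let coarse = 1.0_f32 + 1.0 / 1024.0;

    assert_eq!(truncate_to_tf32(fine), 1.0);
    assert_eq!(truncate_to_tf32(coarse), coarse);
}
```
*/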

// ============================================================================
// KernelDtype umbrella impls — every concrete kernel dtype is sealed here.
// ============================================================================
//
// One macro keeps the 17 impls visually flat. The Phase 28 refactor
// removed `const KIND` from the per-family sibling traits (`Element`,
// `IntElement`, `FpElement`, `BinElement`); `KIND` now lives only on
// the [`KernelDtype`] supertrait, so `T::KIND` resolves uniquely
// under any subtrait bound.

macro_rules! impl_kerneldtype {
    ($($t:ty => $k:ident,)*) => {
        $(
            impl kerneldtype_sealed::Sealed for $t {}
            impl KernelDtype for $t {
                const KIND: ElementKind = ElementKind::$k;
            }
        )*
    };
}

impl_kerneldtype! {
    // Element family (FP + int + bool + complex)
    f16        => F16,
    bf16       => Bf16,
    f32        => F32,
    F32Strict  => F32Strict,
    f64        => F64,
    i32        => I32,
    i64        => I64,
    Bool       => Bool,
    Complex32  => Complex32,
    Complex64  => Complex64,
    // IntElement family (GEMM operand newtypes)
    S8         => S8,
    U8         => U8,
    S4         => S4,
    U4         => U4,
    // FpElement family (FP8)
    Fp8E4M3    => Fp8E4M3,
    Fp8E5M2    => Fp8E5M2,
    // BinElement family
    Bin        => Bin,
}
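/*
Illustrative sketch of the sealed-supertrait-plus-macro shape, cut down to
two dtypes so it compiles on its own. The module name `sealed` and the
helpers `kind_of_int` / `kind_of_any` are hypothetical; the crate's own
sealing module is `kerneldtype_sealed`. The point is that `T::KIND` resolves
through the supertrait whether the bound is the umbrella trait or a
sub-trait.

```rust
mod sealed {
    pub trait Sealed {}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ElementKind {
    F32,
    S8,
}

/// Umbrella trait: the only place `KIND` is declared.
pub trait KernelDtype: sealed::Sealed + Copy + 'static {
    const KIND: ElementKind;
}

/// Sub-trait with no `KIND` of its own.
pub trait IntElement: KernelDtype {}

#[derive(Copy, Clone)]
pub struct S8(pub i8);

macro_rules! impl_kerneldtype {
    ($($t:ty => $k:ident,)*) => {
        $(
            impl sealed::Sealed for $t {}
            impl KernelDtype for $t {
                const KIND: ElementKind = ElementKind::$k;
            }
        )*
    };
}

impl_kerneldtype! {
    f32 => F32,
    S8  => S8,
}

impl IntElement for S8 {}

/// `T::KIND` is unambiguous under the sub-trait bound.
fn kind_of_int<T: IntElement>() -> ElementKind {
    T::KIND
}

/// The same constant under the umbrella bound.
fn kind_of_any<T: KernelDtype>() -> ElementKind {
    T::KIND
}

fn main() {
    assert_eq!(kind_of_int::<S8>(), ElementKind::S8);
    assert_eq!(kind_of_any::<S8>(), ElementKind::S8);
    assert_eq!(kind_of_any::<f32>(), ElementKind::F32);
}
```

Because `sealed::Sealed` cannot be named outside the defining crate,
downstream code cannot add a `KernelDtype` impl of its own.
*/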

/// Runtime tag for any [`KernelDtype`]: an [`Element`], [`IntElement`],
/// [`FpElement`], or [`BinElement`].
///
/// Unified across all kernel families so that a single kernel-SKU
/// descriptor can describe any baracuda kernel.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ElementKind {
    /// IEEE 754 binary16.
    F16,
    /// Brain-float 16.
    Bf16,
    /// IEEE 754 binary32 inputs reduced through TF32 tensor cores
    /// (10-bit mantissa). Maps to the `f32` Rust type.
    F32,
    /// IEEE 754 binary32 inputs reduced through SIMT CUDA cores at full
    /// f32 precision. Maps to the [`F32Strict`] wrapper type. Bit-stable
    /// on the same hardware.
    F32Strict,
    /// IEEE 754 binary64. Maps to the [`prim@f64`] Rust type.
    F64,
    /// Signed 8-bit integer. Maps to the [`S8`] wrapper type. Routed
    /// through Ampere int8 tensor cores (`mma.sync m16n8k32` integer
    /// variant) with int32 accumulation; float `alpha` / `beta` let
    /// the kernel act as a dequantize-in-epilogue.
    S8,
    /// Unsigned 8-bit integer. Maps to the [`U8`] wrapper type. Same
    /// kernel family as [`S8`] with unsigned operands.
    U8,
    /// Signed 32-bit integer. Maps to the `i32` Rust type via the
    /// [`Element`] impl. Two roles:
    /// 1. **Accumulator marker** for integer GEMM SKUs (reported by
    ///    [`crate::PrecisionGuarantee::accumulator`]).
    /// 2. **Input element** for elementwise integer arithmetic
    ///    (bitwise / comparison / scan ops). The same plan shapes used
    ///    for floating-point inputs extend to `i32` via the [`Element`]
    ///    impl.
    I32,
    /// Signed 64-bit integer. Maps to the `i64` Rust type via the
    /// [`Element`] impl. Used as an input element for the elementwise
    /// integer arithmetic family (bitwise / comparison / scan ops).
    /// PyTorch's default integer tensor dtype.
    I64,
    /// Boolean (1-byte storage). Maps to the [`Bool`] wrapper type via
    /// the [`Element`] impl. Used as the input element for the
    /// logical-op family (`logical_and` / `logical_or` / `logical_xor`)
    /// and as the output element for the comparison-op family
    /// (`eq` / `ne` / `gt` / `ge` / `lt` / `le`). Truthiness convention
    /// follows PyTorch: 0 = false, any non-zero byte = true.
    Bool,
    /// 8-bit floating-point, E4M3 encoding (1 sign + 4 exponent + 3
    /// mantissa, bias 7, max-finite 448, no infinities). Maps to the
    /// [`Fp8E4M3`] wrapper type. Routed through Ada / Hopper FP8 tensor
    /// cores (`mma.sync m16n8k32` FP8 variant) with F32 accumulation.
    Fp8E4M3,
    /// 8-bit floating-point, E5M2 encoding (1 sign + 5 exponent + 2
    /// mantissa, bias 15, IEEE-754-compatible inf / NaN). Maps to the
    /// [`Fp8E5M2`] wrapper type. Same FP8 tensor-core path as
    /// [`Fp8E4M3`] with the alternate operand tag
    /// (`.e5m2.e5m2.f32`).
    Fp8E5M2,
    /// Signed 4-bit integer — packed-pair storage. Maps to the [`S4`]
    /// wrapper type. Routed through Ada Lovelace int4 tensor cores
    /// (`mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32`)
    /// with int32 accumulation; float `alpha` / `beta` let the kernel
    /// act as a dequantize-in-epilogue (same convention as the int8
    /// family).
    S4,
    /// Unsigned 4-bit integer — packed-pair storage. Maps to the [`U4`]
    /// wrapper type. Same kernel family as [`S4`] with the alternate
    /// operand tag (`.u4.u4.s32`).
    U4,
    /// 1-bit binary — packed-byte storage (8 bits per byte, LSB =
    /// lowest K index). Maps to the [`Bin`] wrapper type. Routed
    /// through Ampere+ binary tensor cores
    /// (`mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc`).
    /// Distinct programming model: the output is the raw popcount
    /// accumulator (s32), not a re-quantized b1.
    Bin,
    /// Single-precision complex — interleaved real/imag pair of `f32`
    /// (`#[repr(C)]`). Maps to the [`Complex32`] wrapper type. Used by
    /// the FFT op family (Milestone 6.4) for spectrum-domain tensors.
    /// ABI-compatible with cuFFT's `cufftComplex`, NumPy's `complex64`,
    /// and PyTorch's `torch.complex64`.
    Complex32,
    /// Double-precision complex — interleaved real/imag pair of `f64`.
    /// Maps to the [`Complex64`] wrapper type. ABI-compatible with
    /// cuFFT's `cufftDoubleComplex`, NumPy's `complex128`, and
    /// PyTorch's `torch.complex128`.
    Complex64,
}
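/*
Illustrative sketch of the two FP8 encodings named in the variant docs, as
host-side decoders from a raw byte to `f32`. `decode_e4m3` and `decode_e5m2`
are hypothetical helpers, not crate APIs. The E4M3 NaN pattern (all exponent
and mantissa bits set) is taken from the common OCP FP8 convention and is an
assumption about what the crate's kernels follow.

```rust
/// E4M3: 1 sign + 4 exponent + 3 mantissa bits, bias 7, no infinities.
fn decode_e4m3(byte: u8) -> f32 {
    let sign = if byte & 0x80 != 0 { -1.0_f32 } else { 1.0 };
    let exp = i32::from((byte >> 3) & 0x0F);
    let man_bits = byte & 0x07;
    let man = f32::from(man_bits);
    if exp == 0x0F && man_bits == 0x07 {
        f32::NAN // the only non-finite pattern
    } else if exp == 0 {
        sign * (man / 8.0) * 2.0_f32.powi(-6) // subnormal
    } else {
        sign * (1.0 + man / 8.0) * 2.0_f32.powi(exp - 7)
    }
}

/// E5M2: 1 sign + 5 exponent + 2 mantissa bits, bias 15, IEEE-style inf / NaN.
fn decode_e5m2(byte: u8) -> f32 {
    let sign = if byte & 0x80 != 0 { -1.0_f32 } else { 1.0 };
    let exp = i32::from((byte >> 2) & 0x1F);
    let man_bits = byte & 0x03;
    let man = f32::from(man_bits);
    if exp == 0x1F {
        if man_bits == 0 {
            sign * f32::INFINITY
        } else {
            f32::NAN
        }
    } else if exp == 0 {
        sign * (man / 4.0) * 2.0_f32.powi(-14) // subnormal
    } else {
        sign * (1.0 + man / 4.0) * 2.0_f32.powi(exp - 15)
    }
}

fn main() {
    // Largest finite E4M3 value: exponent 15, mantissa 0b110.
    assert_eq!(decode_e4m3(0x7E), 448.0);
    assert!(decode_e4m3(0x7F).is_nan());
    assert_eq!(decode_e4m3(0x38), 1.0);

    // E5M2 reserves the top exponent for inf / NaN.
    assert_eq!(decode_e5m2(0x7B), 57344.0);
    assert!(decode_e5m2(0x7C).is_infinite());
    assert_eq!(decode_e5m2(0x3C), 1.0);
}
```
*/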

/// Math precision used by the FMA / tensor-core instruction.
///
/// Distinct from the *input* element type because tensor cores can take
/// inputs at one precision and reduce through an instruction at a
/// different precision (most notably TF32: F32 inputs, 10-bit-mantissa
/// math).
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MathPrecision {
    /// IEEE 754 binary16 multiply-add.
    F16,
    /// Brain-float 16 multiply-add.
    Bf16,
    /// TensorFloat-32 (10-bit mantissa) multiply-add. Inputs are stored
    /// as F32 but reduced through TF32 tensor cores.
    Tf32,
    /// IEEE 754 binary32 multiply-add (CUDA cores, no tensor cores).
    F32,
    /// IEEE 754 binary64 multiply-add via Ampere FP64 tensor cores
    /// (DGEMM).
    F64,
    /// 8-bit integer multiply-add (`mma.sync m16n8k32` integer variant)
    /// with int32 accumulation. Used by both signed (s8) and unsigned
    /// (u8) integer GEMM SKUs; the multiply operands are 8-bit, the
    /// accumulator is 32-bit, and the multiply-add uses the
    /// `OpMultiplyAddSaturate` operator (clamps the accumulator on
    /// overflow rather than wrapping).
    Int8,
    /// FP8 E4M3 multiply-add (`mma.sync m16n8k32` FP8 variant) with F32
    /// accumulation. Inputs are E4M3 (8-bit), the accumulator is F32,
    /// and the epilogue cast saturates to the E4M3 max-finite (±448).
    Fp8E4M3,
    /// FP8 E5M2 multiply-add. Same instruction family as
    /// [`Fp8E4M3`](Self::Fp8E4M3) but with the E5M2 encoding (wider
    /// exponent, narrower mantissa).
    Fp8E5M2,
    /// 4-bit integer multiply-add (`mma.sync m16n8k64` int4 variant)
    /// with int32 accumulation. Used by both signed (s4) and unsigned
    /// (u4) integer GEMM SKUs; the multiply operands are 4-bit
    /// (packed-pair storage in memory), the accumulator is 32-bit, and
    /// the multiply-add uses the `.satfinite` qualifier (clamps the
    /// accumulator on overflow rather than wrapping). sm_89+.
    Int4,
    /// 1-bit binary `xor.popc` multiply-add
    /// (`mma.sync m16n8k256` b1 variant) with int32 accumulation. The
    /// "multiply" is per-bit XOR and the "add" is popcount. Used by
    /// the binary GEMM SKU; operands are 1-bit (packed 8-per-byte in
    /// memory), the accumulator is 32-bit, and the output is the
    /// **raw** popcount accumulator — no re-quantization back to b1.
    /// sm_80+.
    Binary,
}
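/*
Illustrative sketch of how an operand kind and a math precision can differ,
following the pairings stated in the doc comments on the two enums. Both
enums are cut-down local copies, and `gemm_math_precision` is a hypothetical
helper; the crate may derive this pairing differently or per SKU.

```rust
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum ElementKind {
    F32,
    F32Strict,
    S8,
    U8,
    Bin,
    Bool,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum MathPrecision {
    Tf32,
    F32,
    Int8,
    Binary,
}

/// GEMM math precision implied by an operand kind, or `None` for kinds that
/// are only described as elementwise inputs or outputs.
fn gemm_math_precision(kind: ElementKind) -> Option<MathPrecision> {
    match kind {
        // Plain f32 storage, reduced through TF32 tensor cores.
        ElementKind::F32 => Some(MathPrecision::Tf32),
        // Same storage, full-precision SIMT path.
        ElementKind::F32Strict => Some(MathPrecision::F32),
        // Signed and unsigned 8-bit operands share one math precision.
        ElementKind::S8 | ElementKind::U8 => Some(MathPrecision::Int8),
        ElementKind::Bin => Some(MathPrecision::Binary),
        ElementKind::Bool => None,
    }
}

fn main() {
    assert_eq!(gemm_math_precision(ElementKind::F32), Some(MathPrecision::Tf32));
    assert_eq!(gemm_math_precision(ElementKind::F32Strict), Some(MathPrecision::F32));
    assert_eq!(gemm_math_precision(ElementKind::U8), Some(MathPrecision::Int8));
    assert_eq!(gemm_math_precision(ElementKind::Bool), None);
}
```
*/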