1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
//! # sgemm-bi
//!
//! Deterministic, batch-invariant CUDA GEMM engine with a full training
//! triad — forward (`Y = X@W + bias`), weight gradient (`dW += X^T@dY`),
//! and input gradient (`dX = dY@W^T`) — in f32, bf16, and f16, plus an
//! opt-in tensor-core tier.
//!
//! ## Guarantees
//!
//! - **Run-to-run determinism**: every kernel uses a fixed reduction order
//! (no atomics, no data-dependent splits). Same inputs → bit-identical
//! outputs, on every call, including through CUDA Graph replay.
//! - **Batch invariance**: within a dispatch bucket, row 0 of the output
//! is bit-identical regardless of M. The tensor-core forward is
//! strictly batch-invariant across ALL M.
//! - **Typed bit contract**: bf16/f16 results are bit-identical to
//! "upcast inputs to f32, run the f32 tier, round-to-nearest-even
//! downcast" — accumulation never happens in reduced precision, and
//! exactly one rounding is applied at the output store.
//! - **No vendor BLAS**: cuBLAS is not linked, called, or fallen back to.
//!
//! ## Tiers
//!
//! | tier | entry points | contract |
//! |---|---|---|
//! | f32 | [`SgemmBi::forward_f32`], `backward_dw_f32`, `backward_dx_f32` | reference chain |
//! | typed (bf16/f16) | [`SgemmBi::forward`], `backward_dw`, `backward_dx` | bit-equal to f32 tier on upcast inputs |
//! | tensor cores | [`SgemmBi::forward_tc`], `backward_dw_tc`, `backward_dx_tc` | own deterministic contract (mma.sync, f32 accumulate) |
//!
//! The tensor-core tier is typically 3.5–6.3× faster than the scalar
//! tiers (two bit-identical tile families, 128x128 and 64x64, routed by
//! shape, covering both output dims >= 64) and turns the determinism
//! overhead zero-to-negative against cuBLAS-PEDANTIC-class baselines;
//! it does not bit-match the scalar tiers (a tensor-core reduction tree
//! cannot reproduce a scalar FMA chain).
//!
//! ## Requirements
//!
//! NVIDIA Ampere or newer (`sm_80`+) — the kernel blob uses cp.async,
//! ldmatrix, and native bf16 throughout, so [`SgemmBi::new`] rejects
//! older devices with [`Error::UnsupportedArch`]. Kernels compile at
//! engine construction via NVRTC for the device's native architecture —
//! no CUDA toolkit needed at run time beyond the driver and NVRTC.
//!
//! ## C interface
//!
//! Enable the `capi` feature for a flat `extern "C"` surface (module
//! [`capi`], header `include/sgemm_bi.h`): engine create/destroy, the
//! six GEMM entry points on raw `CUdeviceptr`s, per-thread error
//! strings, and scratch pre-sizing for CUDA Graph capture.
//!
//! ## Example
//!
//! ```no_run
//! use sgemm_bi::{Dtype, SgemmBi, TypedPtr};
//!
//! let context = cudarc::driver::CudaContext::new(0).unwrap();
//! let stream = context.new_stream().unwrap();
//! let engine = SgemmBi::new(&context, stream.clone()).unwrap();
//!
//! let (m, k, n) = (2048, 768, 3072);
//! let x = stream.alloc_zeros::<u16>(m * k).unwrap(); // bf16 storage
//! let w = stream.alloc_zeros::<u16>(k * n).unwrap();
//! let y = stream.alloc_zeros::<u16>(m * n).unwrap();
//! # use cudarc::driver::DevicePtr;
//! let ptr = |b: &cudarc::driver::CudaSlice<u16>| b.device_ptr(&stream).0;
//!
//! engine
//! .forward(
//! TypedPtr::new(ptr(&y), Dtype::Bf16),
//! TypedPtr::new(ptr(&x), Dtype::Bf16),
//! TypedPtr::new(ptr(&w), Dtype::Bf16),
//! None,
//! (m, k, n),
//! )
//! .unwrap();
//! stream.synchronize().unwrap();
//! ```
pub use ;
pub use SgemmBi;
pub use ;