sgemm-bi 0.1.0

Deterministic, batch-invariant CUDA GEMM engine with a full training triad (forward, dW, dX) in f32 / bf16 / f16, plus an opt-in tensor-core tier that is faster than cuBLAS PEDANTIC. Bit-identical results across runs; fixed reduction order; no atomics; no cuBLAS dependency.
//! # sgemm-bi
//!
//! Deterministic, batch-invariant CUDA GEMM engine with a full training
//! triad — forward (`Y = X@W + bias`), weight gradient (`dW += X^T@dY`),
//! and input gradient (`dX = dY@W^T`) — in f32, bf16, and f16, plus an
//! opt-in tensor-core tier.
//!
//! ## Guarantees
//!
//! - **Run-to-run determinism**: every kernel uses a fixed reduction order
//!   (no atomics, no data-dependent splits). Same inputs → bit-identical
//!   outputs, on every call, including through CUDA Graph replay.
//! - **Batch invariance**: within a dispatch bucket, row 0 of the output
//!   is bit-identical regardless of M. The tensor-core forward is
//!   strictly batch-invariant across ALL M.
//! - **Typed bit contract**: bf16/f16 results are bit-identical to
//!   "upcast inputs to f32, run the f32 tier, round-to-nearest-even
//!   downcast" — accumulation never happens in reduced precision, and
//!   exactly one rounding is applied at the output store.
//! - **No vendor BLAS**: cuBLAS is not linked, called, or fallen back to.
//!
//! ## Tiers
//!
//! | tier | entry points | contract |
//! |---|---|---|
//! | f32 | [`SgemmBi::forward_f32`], `backward_dw_f32`, `backward_dx_f32` | reference chain |
//! | typed (bf16/f16) | [`SgemmBi::forward`], `backward_dw`, `backward_dx` | bit-equal to f32 tier on upcast inputs |
//! | tensor cores | [`SgemmBi::forward_tc`], `backward_dw_tc`, `backward_dx_tc` | own deterministic contract (mma.sync, f32 accumulate) |
//!
//! The tensor-core tier is typically 3–7× faster than the scalar tiers on
//! 128x128-tile-friendly shapes and turns the determinism overhead
//! negative against cuBLAS-PEDANTIC-class baselines; it does not
//! bit-match the scalar tiers (a tensor-core reduction tree cannot
//! reproduce a scalar FMA chain).
//!
//! ## Requirements
//!
//! NVIDIA GPU, `sm_80`+ for the typed and tensor-core tiers (cp.async,
//! ldmatrix, bf16 mma). Kernels compile at engine construction via NVRTC
//! for the device's native architecture — no CUDA toolkit needed at run
//! time beyond the driver and NVRTC.
//!
//! ## Example
//!
//! ```no_run
//! use sgemm_bi::{Dtype, SgemmBi, TypedPtr};
//!
//! let context = cudarc::driver::CudaContext::new(0).unwrap();
//! let stream = context.new_stream().unwrap();
//! let engine = SgemmBi::new(&context, stream.clone()).unwrap();
//!
//! let (m, k, n) = (2048, 768, 3072);
//! let x = stream.alloc_zeros::<u16>(m * k).unwrap(); // bf16 storage
//! let w = stream.alloc_zeros::<u16>(k * n).unwrap();
//! let y = stream.alloc_zeros::<u16>(m * n).unwrap();
//! # use cudarc::driver::DevicePtr;
//! let ptr = |b: &cudarc::driver::CudaSlice<u16>| b.device_ptr(&stream).0;
//!
//! engine
//!     .forward(
//!         TypedPtr::new(ptr(&y), Dtype::Bf16),
//!         TypedPtr::new(ptr(&x), Dtype::Bf16),
//!         TypedPtr::new(ptr(&w), Dtype::Bf16),
//!         None,
//!         (m, k, n),
//!     )
//!     .unwrap();
//! stream.synchronize().unwrap();
//! ```

mod dispatch;
mod dtype;
mod engine;
mod error;
mod kernels;

pub use dtype::{Dtype, TypedPtr};
pub use engine::SgemmBi;
pub use error::{Error, Result};