#![allow(clippy::too_many_arguments)]
use core::ffi::c_void;
use core::ptr;
use core::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
use super::{
cublasCreate_v2, cublasDestroy_v2, cublasGemmEx, cublasGemmStridedBatchedEx, cublasHandle_t,
cublasSetStream_v2, CUBLAS_COMPUTE_32F, CUBLAS_COMPUTE_64F, CUBLAS_GEMM_DEFAULT, CUBLAS_OP_N,
CUBLAS_OP_T, CUDA_R_16BF, CUDA_R_16F, CUDA_R_32F, CUDA_R_64F,
};
// CUDA driver API symbol declared directly here; every other FFI item this
// file uses comes from `super`.
unsafe extern "C" {
// Stores the calling thread's current context in `*ctx` and returns a status
// code. `current_ctx_key` treats a zero return as success and a null `*ctx`
// as "no context bound".
fn cuCtxGetCurrent(ctx: *mut *mut c_void) -> i32;
}
// Status codes returned by the functions in this file.
// `OK`: arguments accepted / GEMM issued without a reported error.
const OK: i32 = 0;
// `INVALID`: returned by `validate` for rejected arguments and by the run path
// when an operand pointer is null.
const INVALID: i32 = 2;
// `INTERNAL`: a cuBLAS call (handle creation, stream binding, or the GEMM
// itself) returned a non-zero status.
const INTERNAL: i32 = 5;
// Values accepted for the `layout` argument.
// NOTE(review): presumably one letter per operand (A, B, D) with R = row-major
// and C = column-major; the leading-dimension minimums in `validate` are
// consistent with that reading. Confirm against the module docs.
const LAYOUT_RRR: i32 = 0;
const LAYOUT_RCR: i32 = 1;
const LAYOUT_CRR: i32 = 2;
// Number of entries in the static handle pool `POOL`.
const POOL_SLOTS: usize = 8;
// One entry of the handle pool: a context key plus at most one parked handle.
struct Slot {
// Key of the context this slot was claimed for (the context pointer cast to
// `usize`). `0` means unclaimed; `put_handle` claims a slot by CAS from 0 and
// nothing in this file resets it afterwards.
ctx: AtomicUsize,
// Parked cuBLAS handle, or null while the slot is empty or its handle is
// checked out by `take_handle`.
handle: AtomicPtr<c_void>,
}
// Initializer for every pool entry: unclaimed (`ctx == 0`) with no handle.
// It is a `const` on purpose so the array-repeat expression in `POOL` can
// stamp out a fresh, independent pair of atomics per element; that is exactly
// the pattern `clippy::declare_interior_mutable_const` warns about, hence the
// allow.
#[allow(clippy::declare_interior_mutable_const)]
const EMPTY_SLOT: Slot = Slot {
ctx: AtomicUsize::new(0),
handle: AtomicPtr::new(ptr::null_mut()),
};
// Process-wide pool of parked cuBLAS handles, accessed only through atomic
// operations in `take_handle` and `put_handle`. Holds `POOL_SLOTS` entries,
// all starting out unclaimed.
static POOL: [Slot; POOL_SLOTS] = [EMPTY_SLOT; POOL_SLOTS];
/// Returns the pool key for the context currently bound to the calling thread.
///
/// The key is the context pointer reinterpreted as `usize`. When the query
/// reports a non-zero status, or reports success but yields a null context,
/// the sentinel `usize::MAX` is returned instead.
fn current_ctx_key() -> usize {
    let mut current: *mut c_void = ptr::null_mut();
    // SAFETY: the out-pointer refers to a live local for the whole call.
    let status = unsafe { cuCtxGetCurrent(&mut current as *mut _) };
    match (status, current.is_null()) {
        (0, false) => current as usize,
        _ => usize::MAX,
    }
}
/// Obtains a cuBLAS handle for the context identified by `key`.
///
/// The pool is scanned in order; the first slot whose key matches and whose
/// parked handle is non-null gives that handle up (the slot is left holding
/// null). If nothing is parked, a new handle is created with up to five
/// attempts, each retry preceded by a busy-wait that grows linearly with the
/// attempt number.
///
/// # Errors
///
/// Returns `Err(INTERNAL)` when every creation attempt fails.
///
/// # Safety
///
/// Calls into cuBLAS through FFI.
/// NOTE(review): presumably requires a usable CUDA installation/device on the
/// calling thread — confirm against callers.
unsafe fn take_handle(key: usize) -> Result<cublasHandle_t, i32> {
    // Fast path: check out a handle that was parked for this context.
    let pooled = POOL.iter().find_map(|slot| {
        if slot.ctx.load(Ordering::Acquire) != key {
            return None;
        }
        let parked = slot.handle.swap(ptr::null_mut(), Ordering::AcqRel);
        (!parked.is_null()).then_some(parked)
    });
    if let Some(reused) = pooled {
        return Ok(reused);
    }

    // Slow path: create a fresh handle, backing off between failed attempts.
    const MAX_ATTEMPTS: u64 = 5;
    const SPINS_PER_STEP: u64 = 4_000_000;
    for attempt in 0..MAX_ATTEMPTS {
        // The first attempt (attempt == 0) spins zero times.
        for _ in 0..SPINS_PER_STEP * attempt {
            core::hint::spin_loop();
        }
        let mut fresh: cublasHandle_t = ptr::null_mut();
        // SAFETY: the out-pointer refers to a live local for the whole call.
        let status = unsafe { cublasCreate_v2(&mut fresh as *mut _) };
        if status == 0 && !fresh.is_null() {
            return Ok(fresh);
        }
    }
    Err(INTERNAL)
}
/// Gives a handle back after use: parks it in the pool when possible,
/// otherwise destroys it.
///
/// Parking is attempted in two passes and only when `key` is not the
/// `usize::MAX` sentinel:
/// 1. any slot already keyed to `key` whose handle cell is currently null;
/// 2. failing that, the first unclaimed slot (`ctx == 0`) is claimed for
///    `key`, and the handle is stored there if its cell is still null.
///
/// If neither pass stores the handle, it is destroyed (the destroy status is
/// ignored).
///
/// # Safety
///
/// `h` is either stored for later reuse or passed to `cublasDestroy_v2`, so
/// the caller must hand over a handle it owns and must not use it afterwards.
unsafe fn put_handle(key: usize, h: cublasHandle_t) {
    let parked = key != usize::MAX && {
        // Pass 1: a slot that already belongs to this context and is empty.
        let reused_slot = POOL.iter().any(|slot| {
            slot.ctx.load(Ordering::Acquire) == key
                && slot
                    .handle
                    .compare_exchange(ptr::null_mut(), h, Ordering::AcqRel, Ordering::Relaxed)
                    .is_ok()
        });
        // Pass 2: claim the first free slot; at most one slot is claimed, and
        // if storing into it fails no further slot is tried.
        reused_slot
            || POOL
                .iter()
                .find(|slot| {
                    slot.ctx
                        .compare_exchange(0, key, Ordering::AcqRel, Ordering::Relaxed)
                        .is_ok()
                })
                .is_some_and(|slot| {
                    slot.handle
                        .compare_exchange(ptr::null_mut(), h, Ordering::AcqRel, Ordering::Relaxed)
                        .is_ok()
                })
    };
    if !parked {
        // SAFETY: per this function's contract the caller has relinquished `h`.
        unsafe {
            let _ = cublasDestroy_v2(h);
        }
    }
}
fn validate(
m: i32,
n: i32,
k: i32,
batch: i32,
layout: i32,
lda: i64,
ldb: i64,
ldd: i64,
stride_d: i64,
) -> i32 {
if m < 0 || n < 0 || k < 0 || batch < 0 {
return INVALID;
}
let (min_lda, min_ldb) = match layout {
LAYOUT_RRR => (k as i64, n as i64),
LAYOUT_RCR => (k as i64, k as i64),
LAYOUT_CRR => (m as i64, n as i64),
_ => return INVALID,
};
if lda < min_lda.max(1) || ldb < min_ldb.max(1) || ldd < (n as i64).max(1) {
return INVALID;
}
let i32_max = i32::MAX as i64;
if lda > i32_max || ldb > i32_max || ldd > i32_max {
return INVALID;
}
if batch > 1 && stride_d == 0 {
return INVALID;
}
OK
}
unsafe fn gemm_dense_run_impl(
m: i32,
n: i32,
k: i32,
batch: i32,
layout: i32,
alpha: *const c_void,
beta: *const c_void,
a: *const c_void,
lda: i64,
stride_a: i64,
b: *const c_void,
ldb: i64,
stride_b: i64,
d: *mut c_void,
ldd: i64,
stride_d: i64,
data_type: i32,
compute_type: i32,
stream: *mut c_void,
) -> i32 {
let st = validate(m, n, k, batch, layout, lda, ldb, ldd, stride_d);
if st != OK {
return st;
}
if m == 0 || n == 0 || batch == 0 {
return OK;
}
if a.is_null() || b.is_null() || d.is_null() {
return INVALID;
}
let (transa, transb) = match layout {
LAYOUT_RRR => (CUBLAS_OP_N, CUBLAS_OP_N),
LAYOUT_RCR => (CUBLAS_OP_T, CUBLAS_OP_N),
LAYOUT_CRR => (CUBLAS_OP_N, CUBLAS_OP_T),
_ => return INVALID,
};
let key = current_ctx_key();
let handle = match unsafe { take_handle(key) } {
Ok(h) => h,
Err(e) => return e,
};
let key = if key == usize::MAX { current_ctx_key() } else { key };
let st = unsafe { cublasSetStream_v2(handle, stream) };
if st != 0 {
unsafe {
let _ = cublasDestroy_v2(handle);
}
return INTERNAL;
}
let status = if batch == 1 {
unsafe {
cublasGemmEx(
handle,
transa,
transb,
n,
m,
k,
alpha,
b,
data_type,
ldb as i32,
a,
data_type,
lda as i32,
beta,
d,
data_type,
ldd as i32,
compute_type,
CUBLAS_GEMM_DEFAULT,
)
}
} else {
unsafe {
cublasGemmStridedBatchedEx(
handle,
transa,
transb,
n,
m,
k,
alpha,
b,
data_type,
ldb as i32,
stride_b,
a,
data_type,
lda as i32,
stride_a,
beta,
d,
data_type,
ldd as i32,
stride_d,
batch,
compute_type,
CUBLAS_GEMM_DEFAULT,
)
}
};
unsafe { put_handle(key, handle) };
if status != 0 {
INTERNAL
} else {
OK
}
}
// Generates the three exported C-ABI entry points of one dense-GEMM family:
//   `$run` — executes the GEMM by delegating to `gemm_dense_run_impl`;
//   `$can` — host-side argument check (delegates to `validate`);
//   `$ws`  — workspace-size query (always 0).
// Macro arguments:
//   `$scalar`       Rust type of `alpha` / `beta` in the C signature;
//   `$data_type`    value forwarded as `data_type` (used for A, B and D);
//   `$compute_type` value forwarded as `compute_type`;
//   `$dtype_doc`, `$acc_doc`  string fragments spliced into the generated docs.
macro_rules! gemm_dense_family {
(
$run:ident,
$can:ident,
$ws:ident,
$scalar:ty,
$data_type:expr,
$compute_type:expr,
$dtype_doc:literal,
$acc_doc:literal
) => {
#[doc = concat!(
"Dense ", $dtype_doc, " GEMM (cuBLAS-backed): ",
"`D[g] = α · A[g] · B[g] + β · D[g]` for `g ∈ [0, batch)`, ",
"accumulating in ", $acc_doc, ". Row-major problem; see the ",
"module docs for the `layout` tag (0 = RRR, 1 = RCR, ",
"2 = CRR), leading-dim minimums, and the batch-stride ",
"contract (element strides; `stride_a`/`stride_b` may be 0 ",
"to broadcast; strides ignored at `batch == 1`).",
)]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn $run(
m: i32,
n: i32,
k: i32,
batch: i32,
layout: i32,
alpha: $scalar,
beta: $scalar,
a: *const c_void,
lda: i64,
stride_a: i64,
b: *const c_void,
ldb: i64,
stride_b: i64,
d: *mut c_void,
ldd: i64,
stride_d: i64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32 {
// The workspace arguments are accepted but deliberately unused; `$ws`
// reports a required size of 0.
let _ = (workspace, workspace_bytes);
// The scalars arrive by value; copy them into locals so their addresses
// can be passed to the implementation as untyped pointers.
let alpha_v: $scalar = alpha;
let beta_v: $scalar = beta;
unsafe {
gemm_dense_run_impl(
m,
n,
k,
batch,
layout,
&alpha_v as *const $scalar as *const c_void,
&beta_v as *const $scalar as *const c_void,
a,
lda,
stride_a,
b,
ldb,
stride_b,
d,
ldd,
stride_d,
$data_type,
$compute_type,
stream,
)
}
}
#[doc = concat!(
"Host-side validity check for [`", stringify!($run), "`]. ",
"Validates extents, the `layout` tag, leading-dim minimums, ",
"i32-fit of leading dims, and `stride_d != 0` at ",
"`batch > 1`. `stride_a` / `stride_b` are accepted ",
"unconditionally (any value, including 0-broadcast).",
)]
#[unsafe(no_mangle)]
pub extern "C" fn $can(
m: i32,
n: i32,
k: i32,
batch: i32,
layout: i32,
lda: i64,
ldb: i64,
ldd: i64,
// The A/B strides are part of the signature but take no part in validation.
_stride_a: i64,
_stride_b: i64,
stride_d: i64,
) -> i32 {
// Same check that `$run` performs first, so both give the same verdict on
// these arguments.
validate(m, n, k, batch, layout, lda, ldb, ldd, stride_d)
}
#[doc = concat!(
"Workspace query for [`", stringify!($run), "`]. Always ",
"`0` — cuBLAS allocates its workspace internally per handle.",
)]
#[unsafe(no_mangle)]
pub extern "C" fn $ws(_m: i32, _n: i32, _k: i32, _batch: i32, _layout: i32) -> usize {
// Independent of every argument.
0
}
};
}
// f32 family: operands tagged `CUDA_R_32F`, `alpha`/`beta` are `f32`,
// compute type `CUBLAS_COMPUTE_32F`.
gemm_dense_family!(
baracuda_kernels_gemm_dense_f32_run,
baracuda_kernels_gemm_dense_f32_can_implement,
baracuda_kernels_gemm_dense_f32_workspace_size,
f32,
CUDA_R_32F,
CUBLAS_COMPUTE_32F,
"f32",
"IEEE binary32 (default math mode — NOT TF32)"
);
// f64 family: operands tagged `CUDA_R_64F`, `alpha`/`beta` are `f64`,
// compute type `CUBLAS_COMPUTE_64F`.
gemm_dense_family!(
baracuda_kernels_gemm_dense_f64_run,
baracuda_kernels_gemm_dense_f64_can_implement,
baracuda_kernels_gemm_dense_f64_workspace_size,
f64,
CUDA_R_64F,
CUBLAS_COMPUTE_64F,
"f64",
"f64"
);
// f16 family: operands tagged `CUDA_R_16F`; `alpha`/`beta` are still `f32`,
// paired with the `CUBLAS_COMPUTE_32F` compute type.
gemm_dense_family!(
baracuda_kernels_gemm_dense_f16_run,
baracuda_kernels_gemm_dense_f16_can_implement,
baracuda_kernels_gemm_dense_f16_workspace_size,
f32,
CUDA_R_16F,
CUBLAS_COMPUTE_32F,
"f16",
"f32"
);
// bf16 family: operands tagged `CUDA_R_16BF`; `alpha`/`beta` are still `f32`,
// paired with the `CUBLAS_COMPUTE_32F` compute type.
gemm_dense_family!(
baracuda_kernels_gemm_dense_bf16_run,
baracuda_kernels_gemm_dense_bf16_can_implement,
baracuda_kernels_gemm_dense_bf16_workspace_size,
f32,
CUDA_R_16BF,
CUBLAS_COMPUTE_32F,
"bf16",
"f32"
);