use crate::dtype::{Dtype, TypedPtr};
use crate::engine::SgemmBi;
use crate::error::{Error, Result};
use cudarc::driver::CudaContext;
use std::cell::RefCell;
use std::ffi::{CString, c_char};
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::sync::Arc;
// Status codes returned by every `sgb_*` entry point that reports a status.
// On any non-`SGB_OK` result a human-readable message is available via `sgb_last_error`.
/// The call succeeded.
pub const SGB_OK: i32 = 0;
/// A CUDA driver operation failed (`Error::Cuda`), e.g. context/stream creation or synchronize.
pub const SGB_ERR_CUDA: i32 = 1;
/// The engine reported `Error::Uncovered` (presumably no kernel covers the request — TODO confirm).
pub const SGB_ERR_UNCOVERED: i32 = 2;
/// The engine reported `Error::DtypeMismatch`, e.g. an f32 descriptor sent to a tensor-core entry point.
pub const SGB_ERR_DTYPE: i32 = 3;
/// The engine reported `Error::UnsupportedArch` (presumably the device architecture is not supported).
pub const SGB_ERR_UNSUPPORTED_ARCH: i32 = 4;
/// Argument validation in this FFI layer failed (null pointer, bad dimension, unknown dtype, ...).
pub const SGB_ERR_INVALID_ARG: i32 = 5;
/// A Rust panic was caught at the FFI boundary instead of unwinding into the caller.
pub const SGB_ERR_PANIC: i32 = 6;
// Element-type codes accepted in `SgbGemm::dtype`; any other value is rejected as invalid.
/// f32 operands: dispatched to the engine's dedicated `*_f32` entry points.
pub const SGB_F32: i32 = 0;
/// bf16 operands (`Dtype::Bf16`).
pub const SGB_BF16: i32 = 1;
/// f16 operands (`Dtype::F16`).
pub const SGB_F16: i32 = 2;
/// Opaque engine handle handed to C callers by `sgb_engine_create` and released by
/// `sgb_engine_destroy`.
///
/// Field order matters: Rust drops fields in declaration order, so `engine` is torn
/// down before the CUDA context it was built on.
pub struct SgbEngine {
    engine: SgemmBi,
    // Not read in this file; held so the CUDA context outlives `engine`.
    _context: Arc<CudaContext>,
}
/// C-ABI descriptor for a single GEMM call, passed by pointer to the `sgb_*` GEMM
/// entry points.
///
/// The layout is fixed by `#[repr(C)]` and shared with C callers, so the field order
/// must not change. It is plain old data, so it is `Copy`; `Default` yields an
/// all-zero descriptor, which has null pointers and is therefore rejected until it
/// is filled in.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct SgbGemm {
    /// Output device pointer; must be non-null.
    pub out: u64,
    /// First operand device pointer; must be non-null.
    pub a: u64,
    /// Second operand device pointer; must be non-null.
    pub b: u64,
    /// Optional bias device pointer; 0 means "no bias".
    pub bias: u64,
    /// Dimension M; must be positive.
    pub m: i64,
    /// Dimension K; must be positive.
    pub k: i64,
    /// Dimension N; must be positive.
    pub n: i64,
    /// One of `SGB_F32`, `SGB_BF16`, `SGB_F16`.
    pub dtype: i32,
    /// Not read by this layer; presumably should be 0 for forward compatibility.
    pub reserved: i32,
}
// Per-thread storage for the most recent error message, exposed to C via `sgb_last_error`.
thread_local! {
    // Starts as an empty string; each recorded error replaces it, and successful calls
    // leave it untouched.
    static LAST_ERROR: RefCell<CString> = RefCell::new(CString::default());
}
/// Records `msg` as this thread's last error, retrievable through `sgb_last_error`.
///
/// A C string cannot contain interior NUL bytes, so any that appear in `msg` are
/// stripped and the rest of the message is kept.
///
/// If the thread-local slot has already been destroyed (thread teardown), the message
/// is dropped instead of panicking. This function is also called from `guarded`'s
/// panic handler, where a second panic would escape to the C caller.
fn set_error(msg: &str) {
    let text = CString::new(msg).unwrap_or_else(|err| {
        // Recover the original bytes from the error and drop only the offending NULs.
        let mut bytes = err.into_vec();
        bytes.retain(|&b| b != 0);
        CString::new(bytes).expect("all NUL bytes were just removed")
    });
    // Best-effort: an inaccessible slot (thread exiting) is not worth aborting over.
    let _ = LAST_ERROR.try_with(|slot| *slot.borrow_mut() = text);
}
/// Maps a library [`Error`] onto the numeric status code reported across the C ABI.
///
/// The match is deliberately exhaustive (no `_` arm) so a new `Error` variant forces
/// a decision about its C-side code.
fn code_of(e: &Error) -> i32 {
    match *e {
        Error::DtypeMismatch(_) => SGB_ERR_DTYPE,
        Error::UnsupportedArch { .. } => SGB_ERR_UNSUPPORTED_ARCH,
        Error::Uncovered { .. } => SGB_ERR_UNCOVERED,
        Error::Cuda(_) => SGB_ERR_CUDA,
    }
}
/// Converts an engine result into a C status code.
///
/// On failure, the error's `Display` text is recorded for `sgb_last_error` and the
/// matching `SGB_ERR_*` code is returned; on success the result is `SGB_OK`.
fn finish(r: Result<()>) -> i32 {
    if let Err(err) = r {
        set_error(&err.to_string());
        return code_of(&err);
    }
    SGB_OK
}
/// Records `msg` as the last error and returns `SGB_ERR_INVALID_ARG`.
///
/// Used for arguments rejected by this FFI layer before any engine work starts.
fn invalid(msg: &str) -> i32 {
    let status = SGB_ERR_INVALID_ARG;
    set_error(msg);
    status
}
/// Runs the body of an FFI entry point and turns a Rust panic into `SGB_ERR_PANIC`.
///
/// Unwinding out of an `extern "C"` function aborts the process, so entry points
/// funnel their work through here. When the panic payload is a string (the usual case
/// for `panic!`, `unwrap` and `expect`), it is appended to the recorded error so
/// `sgb_last_error` reports what actually went wrong.
fn guarded(f: impl FnOnce() -> i32) -> i32 {
    catch_unwind(AssertUnwindSafe(f)).unwrap_or_else(|payload| {
        // Literal panic messages arrive as `&'static str`, formatted ones as `String`.
        let detail = payload
            .downcast_ref::<&str>()
            .copied()
            .or_else(|| payload.downcast_ref::<String>().map(String::as_str));
        match detail {
            Some(text) => set_error(&format!("internal panic in sgemm-bi: {text}")),
            None => set_error("internal panic in sgemm-bi"),
        }
        SGB_ERR_PANIC
    })
}
/// Validated, Rust-side view of an `SgbGemm` descriptor, produced by `parse_gemm`.
struct GemmArgs {
    // Device pointers, guaranteed non-null by `parse_gemm`.
    out: u64,
    a: u64,
    b: u64,
    // `None` when the descriptor's `bias` was 0.
    bias: Option<u64>,
    // (M, K, N), each guaranteed positive by `parse_gemm`.
    dims: (usize, usize, usize),
    // `None` selects the f32 path (`SGB_F32`); `Some` carries bf16/f16.
    dtype: Option<Dtype>,
}
/// Validates a C-side `SgbGemm` descriptor and converts it into `GemmArgs`.
///
/// # Errors
/// Returns a human-readable message, which the caller reports as
/// `SGB_ERR_INVALID_ARG`, when:
/// - any of `m`, `k`, `n` is not positive;
/// - a dimension, or a pairwise product of dimensions, does not fit in `usize`;
/// - `out`, `a` or `b` is null;
/// - `dtype` is not one of `SGB_F32`, `SGB_BF16`, `SGB_F16`.
///
/// A zero `bias` becomes `None`. The `reserved` field is not inspected.
fn parse_gemm(g: &SgbGemm) -> std::result::Result<GemmArgs, String> {
    if g.m <= 0 || g.k <= 0 || g.n <= 0 {
        return Err(format!(
            "dimensions must be positive: M={} K={} N={}",
            g.m, g.k, g.n
        ));
    }
    let dims = checked_dims(g.m, g.k, g.n).ok_or_else(|| {
        format!(
            "dimensions too large: M={} K={} N={} (element counts overflow usize)",
            g.m, g.k, g.n
        )
    })?;
    if g.out == 0 || g.a == 0 || g.b == 0 {
        return Err("out/a/b device pointers must be non-null".into());
    }
    let dtype = match g.dtype {
        SGB_F32 => None,
        SGB_BF16 => Some(Dtype::Bf16),
        SGB_F16 => Some(Dtype::F16),
        other => return Err(format!("unknown dtype code {other}")),
    };
    Ok(GemmArgs {
        out: g.out,
        a: g.a,
        b: g.b,
        bias: (g.bias != 0).then_some(g.bias),
        dims,
        dtype,
    })
}

/// Converts (M, K, N) to `usize` without truncation.
///
/// Returns `None` if any value, or any of the products M*K, K*N or M*N, cannot be
/// represented. `as usize` would silently truncate on targets with a 32-bit `usize`.
/// An (M, K, N) GEMM presumably sizes its buffers from these pairwise products, so an
/// overflow there would let downstream size arithmetic wrap.
fn checked_dims(m: i64, k: i64, n: i64) -> Option<(usize, usize, usize)> {
    let (m, k, n) = (
        usize::try_from(m).ok()?,
        usize::try_from(k).ok()?,
        usize::try_from(n).ok()?,
    );
    m.checked_mul(k)?;
    k.checked_mul(n)?;
    m.checked_mul(n)?;
    Some((m, k, n))
}
/// Shared driver for every GEMM entry point.
///
/// Rejects null handles, then (inside `guarded`) validates the descriptor with
/// `parse_gemm` and hands the parsed arguments to `f`. The result becomes a C status
/// code via `finish`, or `SGB_ERR_INVALID_ARG` for a bad descriptor.
///
/// # Safety
/// `eng` and `gemm` must each be null or point to a live, properly aligned value for
/// the duration of the call.
unsafe fn with_gemm(
    eng: *const SgbEngine,
    gemm: *const SgbGemm,
    f: impl FnOnce(&SgbEngine, GemmArgs) -> Result<()>,
) -> i32 {
    // SAFETY: `as_ref` maps null to `None`; non-null pointers are valid per the contract above.
    let (Some(engine), Some(desc)) = (unsafe { (eng.as_ref(), gemm.as_ref()) }) else {
        return invalid("null engine or descriptor pointer");
    };
    guarded(|| {
        let args = match parse_gemm(desc) {
            Ok(parsed) => parsed,
            Err(reason) => return invalid(&reason),
        };
        finish(f(engine, args))
    })
}
/// Returns this thread's most recent error message as a NUL-terminated C string.
///
/// The buffer belongs to the library and must not be freed by the caller. It stays
/// valid until the next error recorded on the same thread replaces it. Successful
/// calls do not clear it, so consult it only after a non-`SGB_OK` status. Before any
/// error has been recorded, the string is empty.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_last_error() -> *const c_char {
    LAST_ERROR.with(|slot| {
        let message = slot.borrow();
        message.as_ptr()
    })
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_engine_create(device_ordinal: i32, out: *mut *mut SgbEngine) -> i32 {
if out.is_null() {
return invalid("null output handle pointer");
}
if device_ordinal < 0 {
return invalid("device ordinal must be non-negative");
}
guarded(|| {
let built = (|| -> Result<Box<SgbEngine>> {
let context = CudaContext::new(device_ordinal as usize)
.map_err(|e| Error::Cuda(format!("create context: {e:?}")))?;
let stream = context
.new_stream()
.map_err(|e| Error::Cuda(format!("create stream: {e:?}")))?;
let engine = SgemmBi::new(&context, stream)?;
Ok(Box::new(SgbEngine {
engine,
_context: context,
}))
})();
match built {
Ok(handle) => {
unsafe { *out = Box::into_raw(handle) };
SGB_OK
}
Err(e) => {
set_error(&e.to_string());
code_of(&e)
}
}
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_engine_destroy(eng: *mut SgbEngine) {
if !eng.is_null() {
drop(unsafe { Box::from_raw(eng) });
}
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_engine_synchronize(eng: *const SgbEngine) -> i32 {
if eng.is_null() {
return invalid("null engine pointer");
}
let eng = unsafe { &*eng };
guarded(|| {
finish(
eng.engine
.stream()
.synchronize()
.map_err(|e| Error::Cuda(format!("synchronize: {e:?}"))),
)
})
}
/// Returns the raw CUDA stream handle the engine launches work on, as an integer, or 0
/// when `eng` is null.
///
/// # Safety
/// `eng` must be null or point to a live engine created by `sgb_engine_create`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_engine_stream(eng: *const SgbEngine) -> u64 {
    // SAFETY: `as_ref` maps null to `None`; a non-null `eng` is valid per the contract.
    match unsafe { eng.as_ref() } {
        Some(handle) => handle.engine.stream().cu_stream() as u64,
        None => 0,
    }
}
/// Forward GEMM, dispatched on the descriptor's dtype.
///
/// `SGB_F32` uses the engine's f32 path; `SGB_BF16`/`SGB_F16` use the typed path with
/// all three operands tagged with that dtype. A zero `bias` means no bias.
///
/// Returns a status code; details via `sgb_last_error`.
///
/// # Safety
/// `eng` and `gemm` must be null or point to live values. The device pointers in
/// `*gemm` are only checked for null here.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_forward(eng: *const SgbEngine, gemm: *const SgbGemm) -> i32 {
    // SAFETY: the pointers are forwarded unchanged; `with_gemm` upholds the same contract.
    unsafe {
        with_gemm(eng, gemm, |handle, args| {
            let Some(dt) = args.dtype else {
                return handle
                    .engine
                    .forward_f32(args.out, args.a, args.b, args.bias, args.dims);
            };
            handle.engine.forward(
                TypedPtr::new(args.out, dt),
                TypedPtr::new(args.a, dt),
                TypedPtr::new(args.b, dt),
                args.bias,
                args.dims,
            )
        })
    }
}
/// Backward-dW GEMM, dispatched on the descriptor's dtype like `sgb_forward`.
///
/// In the bf16/f16 path only `a` and `b` are dtype-tagged. `out` is passed as an
/// untyped address (presumably an f32 gradient buffer — TODO confirm against
/// `SgemmBi::backward_dw`). The descriptor's `bias` is ignored.
///
/// # Safety
/// `eng` and `gemm` must be null or point to live values. The device pointers in
/// `*gemm` are only checked for null here.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_backward_dw(eng: *const SgbEngine, gemm: *const SgbGemm) -> i32 {
    // SAFETY: the pointers are forwarded unchanged; `with_gemm` upholds the same contract.
    unsafe {
        with_gemm(eng, gemm, |handle, args| {
            let Some(dt) = args.dtype else {
                return handle
                    .engine
                    .backward_dw_f32(args.out, args.a, args.b, args.dims);
            };
            handle.engine.backward_dw(
                args.out,
                TypedPtr::new(args.a, dt),
                TypedPtr::new(args.b, dt),
                args.dims,
            )
        })
    }
}
/// Backward-dX GEMM, dispatched on the descriptor's dtype like `sgb_forward`.
///
/// In the bf16/f16 path `out`, `a` and `b` are all tagged with that dtype. The
/// descriptor's `bias` is ignored.
///
/// # Safety
/// `eng` and `gemm` must be null or point to live values. The device pointers in
/// `*gemm` are only checked for null here.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_backward_dx(eng: *const SgbEngine, gemm: *const SgbGemm) -> i32 {
    // SAFETY: the pointers are forwarded unchanged; `with_gemm` upholds the same contract.
    unsafe {
        with_gemm(eng, gemm, |handle, args| {
            let Some(dt) = args.dtype else {
                return handle
                    .engine
                    .backward_dx_f32(args.out, args.a, args.b, args.dims);
            };
            handle.engine.backward_dx(
                TypedPtr::new(args.out, dt),
                TypedPtr::new(args.a, dt),
                TypedPtr::new(args.b, dt),
                args.dims,
            )
        })
    }
}
fn require_typed(dtype: Option<Dtype>) -> Result<Dtype> {
dtype.ok_or(Error::DtypeMismatch(
"tensor-core tier requires bf16 or f16",
))
}
/// Forward GEMM through the tensor-core tier (bf16/f16 only).
///
/// f32 descriptors fail with `SGB_ERR_DTYPE`. A zero `bias` means no bias.
///
/// # Safety
/// `eng` and `gemm` must be null or point to live values. The device pointers in
/// `*gemm` are only checked for null here.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_forward_tc(eng: *const SgbEngine, gemm: *const SgbGemm) -> i32 {
    // SAFETY: the pointers are forwarded unchanged; `with_gemm` upholds the same contract.
    unsafe {
        with_gemm(eng, gemm, |handle, args| {
            let dt = require_typed(args.dtype)?;
            let typed = |ptr: u64| TypedPtr::new(ptr, dt);
            handle.engine.forward_tc(
                typed(args.out),
                typed(args.a),
                typed(args.b),
                args.bias,
                args.dims,
            )
        })
    }
}
/// Backward-dW GEMM through the tensor-core tier (bf16/f16 only).
///
/// f32 descriptors fail with `SGB_ERR_DTYPE`. `out` is passed untyped and the
/// descriptor's `bias` is ignored.
///
/// # Safety
/// `eng` and `gemm` must be null or point to live values. The device pointers in
/// `*gemm` are only checked for null here.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_backward_dw_tc(eng: *const SgbEngine, gemm: *const SgbGemm) -> i32 {
    // SAFETY: the pointers are forwarded unchanged; `with_gemm` upholds the same contract.
    unsafe {
        with_gemm(eng, gemm, |handle, args| {
            let dt = require_typed(args.dtype)?;
            let typed = |ptr: u64| TypedPtr::new(ptr, dt);
            handle
                .engine
                .backward_dw_tc(args.out, typed(args.a), typed(args.b), args.dims)
        })
    }
}
/// Backward-dX GEMM through the tensor-core tier (bf16/f16 only).
///
/// f32 descriptors fail with `SGB_ERR_DTYPE`. The descriptor's `bias` is ignored.
///
/// # Safety
/// `eng` and `gemm` must be null or point to live values. The device pointers in
/// `*gemm` are only checked for null here.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_backward_dx_tc(eng: *const SgbEngine, gemm: *const SgbGemm) -> i32 {
    // SAFETY: the pointers are forwarded unchanged; `with_gemm` upholds the same contract.
    unsafe {
        with_gemm(eng, gemm, |handle, args| {
            let dt = require_typed(args.dtype)?;
            let typed = |ptr: u64| TypedPtr::new(ptr, dt);
            handle
                .engine
                .backward_dx_tc(typed(args.out), typed(args.a), typed(args.b), args.dims)
        })
    }
}
/// Forwards three element counts to `SgemmBi::presize_upcast_scratch`.
///
/// Returns `SGB_ERR_INVALID_ARG` for a null handle or any negative count; otherwise
/// returns the engine's result as a status code.
///
/// # Safety
/// `eng` must be null or point to a live engine created by `sgb_engine_create`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sgb_presize_upcast_scratch(
    eng: *const SgbEngine,
    a_elems: i64,
    b_elems: i64,
    c_elems: i64,
) -> i32 {
    if eng.is_null() {
        return invalid("null engine pointer");
    }
    let counts = [a_elems, b_elems, c_elems];
    if counts.iter().any(|&count| count < 0) {
        return invalid("scratch element counts must be non-negative");
    }
    // Non-negative per the check above, so the conversion cannot change the sign.
    let [a, b, c] = counts.map(|count| count as usize);
    // SAFETY: non-null checked above; the caller guarantees it points to a live engine.
    let handle = unsafe { &*eng };
    guarded(move || finish(handle.engine.presize_upcast_scratch((a, b, c))))
}