use sparse_ir::gemm::{
Dgemm64FnPtr, DgemmFnPtr, ExternalBlas64Backend, ExternalBlasBackend, GemmBackendHandle,
Zgemm64FnPtr, ZgemmFnPtr,
};
#[repr(C)]
pub struct spir_gemm_backend {
pub(crate) _private: *const std::ffi::c_void,
}
impl spir_gemm_backend {
pub(crate) fn inner(&self) -> &GemmBackendHandle {
unsafe { &*(self._private as *const GemmBackendHandle) }
}
pub(crate) fn new(handle: GemmBackendHandle) -> Self {
Self {
_private: Box::into_raw(Box::new(handle)) as *const std::ffi::c_void,
}
}
}
impl Drop for spir_gemm_backend {
fn drop(&mut self) {
if !self._private.is_null() {
unsafe {
let _ = Box::from_raw(
self._private as *const GemmBackendHandle as *mut GemmBackendHandle,
);
}
}
}
}
impl Clone for spir_gemm_backend {
fn clone(&self) -> Self {
let inner = self.inner().clone();
Self::new(inner)
}
}
/// Creates a GEMM backend from caller-supplied Fortran BLAS routines that take 32-bit
/// (LP64) integer arguments.
///
/// `dgemm` and `zgemm` are reinterpreted as [`DgemmFnPtr`] / [`ZgemmFnPtr`]. Beyond a null
/// check they are not validated, so they must be the addresses of functions with exactly
/// those signatures.
///
/// Returns an owned handle, which must be freed with [`spir_gemm_backend_release`]. Returns
/// null if either pointer is null or if construction panics.
#[unsafe(no_mangle)]
pub extern "C" fn spir_gemm_backend_new_from_fblas_lp64(
    dgemm: *const libc::c_void,
    zgemm: *const libc::c_void,
) -> *mut spir_gemm_backend {
    let any_missing = dgemm.is_null() || zgemm.is_null();
    if any_missing {
        return std::ptr::null_mut();
    }
    // A panic must never unwind across the C ABI boundary; map it to a null result instead.
    std::panic::catch_unwind(|| {
        // SAFETY: both pointers are non-null and, per this function's contract, address
        // functions whose signatures match the LP64 dgemm/zgemm pointer types.
        let (dgemm_fn, zgemm_fn) = unsafe {
            (
                std::mem::transmute::<*const libc::c_void, DgemmFnPtr>(dgemm),
                std::mem::transmute::<*const libc::c_void, ZgemmFnPtr>(zgemm),
            )
        };
        let handle =
            GemmBackendHandle::new(Box::new(ExternalBlasBackend::new(dgemm_fn, zgemm_fn)));
        Box::into_raw(Box::new(spir_gemm_backend::new(handle)))
    })
    .unwrap_or(std::ptr::null_mut())
}
/// Creates a GEMM backend from caller-supplied Fortran BLAS routines that take 64-bit
/// (ILP64) integer arguments.
///
/// `dgemm64` and `zgemm64` are reinterpreted as [`Dgemm64FnPtr`] / [`Zgemm64FnPtr`]. Beyond
/// a null check they are not validated, so they must be the addresses of functions with
/// exactly those signatures.
///
/// Returns an owned handle, which must be freed with [`spir_gemm_backend_release`]. Returns
/// null if either pointer is null or if construction panics.
#[unsafe(no_mangle)]
pub extern "C" fn spir_gemm_backend_new_from_fblas_ilp64(
    dgemm64: *const libc::c_void,
    zgemm64: *const libc::c_void,
) -> *mut spir_gemm_backend {
    let any_missing = dgemm64.is_null() || zgemm64.is_null();
    if any_missing {
        return std::ptr::null_mut();
    }
    // A panic must never unwind across the C ABI boundary; map it to a null result instead.
    std::panic::catch_unwind(|| {
        // SAFETY: both pointers are non-null and, per this function's contract, address
        // functions whose signatures match the ILP64 dgemm/zgemm pointer types.
        let (dgemm64_fn, zgemm64_fn) = unsafe {
            (
                std::mem::transmute::<*const libc::c_void, Dgemm64FnPtr>(dgemm64),
                std::mem::transmute::<*const libc::c_void, Zgemm64FnPtr>(zgemm64),
            )
        };
        let handle = GemmBackendHandle::new(Box::new(ExternalBlas64Backend::new(
            dgemm64_fn, zgemm64_fn,
        )));
        Box::into_raw(Box::new(spir_gemm_backend::new(handle)))
    })
    .unwrap_or(std::ptr::null_mut())
}
/// Releases a backend returned by one of the `spir_gemm_backend_new_from_*` constructors.
///
/// Passing null is a no-op. A non-null pointer must not be used again after this call.
#[unsafe(no_mangle)]
pub extern "C" fn spir_gemm_backend_release(backend: *mut spir_gemm_backend) {
    if backend.is_null() {
        return;
    }
    // SAFETY: non-null pointers passed here are expected to come from `Box::into_raw` in a
    // constructor and not to have been released yet; ownership returns to Rust and is dropped.
    drop(unsafe { Box::from_raw(backend) });
}
pub(crate) unsafe fn get_backend_handle<'a>(
backend: *const spir_gemm_backend,
) -> Option<&'a GemmBackendHandle> {
if backend.is_null() {
None
} else {
unsafe { Some((*backend).inner()) }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Mock BLAS kernels with the LP64/ILP64 dgemm/zgemm ABIs. The constructors only store the
    // pointers, so these bodies are never executed by the tests below.
    unsafe extern "C" fn mock_dgemm(
        _transa: *const libc::c_char,
        _transb: *const libc::c_char,
        _m: *const libc::c_int,
        _n: *const libc::c_int,
        _k: *const libc::c_int,
        _alpha: *const libc::c_double,
        _a: *const libc::c_double,
        _lda: *const libc::c_int,
        _b: *const libc::c_double,
        _ldb: *const libc::c_int,
        _beta: *const libc::c_double,
        _c: *mut libc::c_double,
        _ldc: *const libc::c_int,
    ) {
    }
    unsafe extern "C" fn mock_zgemm(
        _transa: *const libc::c_char,
        _transb: *const libc::c_char,
        _m: *const libc::c_int,
        _n: *const libc::c_int,
        _k: *const libc::c_int,
        _alpha: *const num_complex::Complex<f64>,
        _a: *const num_complex::Complex<f64>,
        _lda: *const libc::c_int,
        _b: *const num_complex::Complex<f64>,
        _ldb: *const libc::c_int,
        _beta: *const num_complex::Complex<f64>,
        _c: *mut num_complex::Complex<f64>,
        _ldc: *const libc::c_int,
    ) {
    }
    unsafe extern "C" fn mock_dgemm64(
        _transa: *const libc::c_char,
        _transb: *const libc::c_char,
        _m: *const i64,
        _n: *const i64,
        _k: *const i64,
        _alpha: *const libc::c_double,
        _a: *const libc::c_double,
        _lda: *const i64,
        _b: *const libc::c_double,
        _ldb: *const i64,
        _beta: *const libc::c_double,
        _c: *mut libc::c_double,
        _ldc: *const i64,
    ) {
    }
    unsafe extern "C" fn mock_zgemm64(
        _transa: *const libc::c_char,
        _transb: *const libc::c_char,
        _m: *const i64,
        _n: *const i64,
        _k: *const i64,
        _alpha: *const num_complex::Complex<f64>,
        _a: *const num_complex::Complex<f64>,
        _lda: *const i64,
        _b: *const num_complex::Complex<f64>,
        _ldb: *const i64,
        _beta: *const num_complex::Complex<f64>,
        _c: *mut num_complex::Complex<f64>,
        _ldc: *const i64,
    ) {
    }

    /// Builds an LP64 backend from the mock kernels (the FFI constructor itself is safe to call).
    fn mock_lp64_backend() -> *mut spir_gemm_backend {
        spir_gemm_backend_new_from_fblas_lp64(mock_dgemm as *const _, mock_zgemm as *const _)
    }

    #[test]
    fn test_backend_new_from_fblas_lp64_success() {
        let backend = mock_lp64_backend();
        assert!(!backend.is_null(), "Backend should not be null");
        spir_gemm_backend_release(backend);
    }

    #[test]
    fn test_backend_new_from_fblas_ilp64_success() {
        let backend = spir_gemm_backend_new_from_fblas_ilp64(
            mock_dgemm64 as *const _,
            mock_zgemm64 as *const _,
        );
        assert!(!backend.is_null(), "Backend should not be null");
        spir_gemm_backend_release(backend);
    }

    #[test]
    fn test_backend_new_from_fblas_lp64_null_dgemm() {
        let backend =
            spir_gemm_backend_new_from_fblas_lp64(std::ptr::null(), mock_zgemm as *const _);
        assert!(backend.is_null(), "Backend should be null when dgemm is null");
    }

    #[test]
    fn test_backend_new_from_fblas_lp64_null_zgemm() {
        let backend =
            spir_gemm_backend_new_from_fblas_lp64(mock_dgemm as *const _, std::ptr::null());
        assert!(backend.is_null(), "Backend should be null when zgemm is null");
    }

    #[test]
    fn test_backend_new_from_fblas_ilp64_null_pointers() {
        let backend = spir_gemm_backend_new_from_fblas_ilp64(std::ptr::null(), std::ptr::null());
        assert!(backend.is_null(), "Backend should be null when pointers are null");
    }

    #[test]
    fn test_backend_new_from_fblas_ilp64_single_null() {
        let no_zgemm =
            spir_gemm_backend_new_from_fblas_ilp64(mock_dgemm64 as *const _, std::ptr::null());
        assert!(no_zgemm.is_null(), "Backend should be null when zgemm64 is null");
        let no_dgemm =
            spir_gemm_backend_new_from_fblas_ilp64(std::ptr::null(), mock_zgemm64 as *const _);
        assert!(no_dgemm.is_null(), "Backend should be null when dgemm64 is null");
    }

    #[test]
    fn test_backend_release_null() {
        spir_gemm_backend_release(std::ptr::null_mut());
    }

    #[test]
    fn test_get_backend_handle_null_and_live() {
        // SAFETY: null is explicitly allowed by `get_backend_handle`'s contract.
        assert!(unsafe { get_backend_handle(std::ptr::null()) }.is_none());

        let backend = mock_lp64_backend();
        assert!(!backend.is_null());
        // SAFETY: `backend` is live until the release below, and the result is not kept.
        assert!(unsafe { get_backend_handle(backend) }.is_some());
        spir_gemm_backend_release(backend);
    }

    #[test]
    fn test_backend_clone_outlives_original() {
        let backend = mock_lp64_backend();
        assert!(!backend.is_null());
        // SAFETY: `backend` is a live pointer returned by the constructor above.
        let copy = unsafe { &*backend }.clone();
        spir_gemm_backend_release(backend);
        // The clone owns its own boxed handle, so it must stay usable after the original is freed.
        let _ = copy.inner();
        drop(copy);
    }

    #[cfg(feature = "system-blas")]
    mod system_blas_tests {
        use super::*;
        use blas_sys::{dgemm_, zgemm_};
        use mdarray::tensor;
        use num_complex::Complex;
        use sparse_ir::gemm::matmul_par;

        /// Asserts every entry of the real matrix `$c` is within 1e-10 of `$expected`.
        macro_rules! assert_real_matrix {
            ($c:expr, $expected:expr) => {
                for (i, row) in $expected.iter().enumerate() {
                    for (j, &want) in row.iter().enumerate() {
                        let got: f64 = $c[[i, j]];
                        assert!(
                            (got - want).abs() < 1e-10,
                            "c[{},{}] should be {}, got {}",
                            i,
                            j,
                            want,
                            got
                        );
                    }
                }
            };
        }

        /// Asserts every entry of the complex matrix `$c` is within 1e-10 (per component) of
        /// the `(re, im)` pairs in `$expected`.
        macro_rules! assert_complex_matrix {
            ($c:expr, $expected:expr) => {
                for (i, row) in $expected.iter().enumerate() {
                    for (j, &(re, im)) in row.iter().enumerate() {
                        let got: Complex<f64> = $c[[i, j]];
                        assert!(
                            (got.re - re).abs() < 1e-10 && (got.im - im).abs() < 1e-10,
                            "c[{},{}] should be ({}, {}), got {:?}",
                            i,
                            j,
                            re,
                            im,
                            got
                        );
                    }
                }
            };
        }

        /// Creates an LP64 backend wired to the system BLAS `dgemm_`/`zgemm_`.
        ///
        /// The FFI constructor takes type-erased addresses, so both symbols are passed as
        /// plain pointers. No transmute or `unsafe` is needed here.
        fn create_blas_backend() -> *mut spir_gemm_backend {
            spir_gemm_backend_new_from_fblas_lp64(dgemm_ as *const _, zgemm_ as *const _)
        }

        /// Runs a 2x2 real product through `backend` (null selects the default backend).
        fn check_real_square(backend: *const spir_gemm_backend) {
            let a: mdarray::DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0]];
            let b: mdarray::DTensor<f64, 2> = tensor![[5.0, 6.0], [7.0, 8.0]];
            // SAFETY: `backend` is null or a live backend owned by the calling test.
            let handle = unsafe { get_backend_handle(backend) };
            let c = matmul_par(&a, &b, handle);
            assert_real_matrix!(c, [[19.0, 22.0], [43.0, 50.0]]);
        }

        /// Runs a 2x2 complex product with non-zero imaginary parts through `backend`.
        fn check_complex_square(backend: *const spir_gemm_backend) {
            let a: mdarray::DTensor<Complex<f64>, 2> = tensor![
                [Complex::new(1.0, 1.0), Complex::new(2.0, 0.0)],
                [Complex::new(3.0, 0.0), Complex::new(4.0, -1.0)]
            ];
            let b: mdarray::DTensor<Complex<f64>, 2> = tensor![
                [Complex::new(5.0, 0.0), Complex::new(0.0, 6.0)],
                [Complex::new(7.0, 0.0), Complex::new(8.0, 0.0)]
            ];
            // SAFETY: `backend` is null or a live backend owned by the calling test.
            let handle = unsafe { get_backend_handle(backend) };
            let c = matmul_par(&a, &b, handle);
            assert_complex_matrix!(
                c,
                [[(19.0, 5.0), (10.0, 6.0)], [(43.0, -7.0), (32.0, 10.0)]]
            );
        }

        /// Runs a (3x2)·(2x4) real product through `backend` and checks every output entry.
        fn check_real_rectangular(backend: *const spir_gemm_backend) {
            let a: mdarray::DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
            let b: mdarray::DTensor<f64, 2> =
                tensor![[7.0, 8.0, 9.0, 10.0], [11.0, 12.0, 13.0, 14.0]];
            // SAFETY: `backend` is null or a live backend owned by the calling test.
            let handle = unsafe { get_backend_handle(backend) };
            let c = matmul_par(&a, &b, handle);
            assert_real_matrix!(
                c,
                [
                    [29.0, 32.0, 35.0, 38.0],
                    [65.0, 72.0, 79.0, 86.0],
                    [101.0, 112.0, 123.0, 134.0]
                ]
            );
        }

        #[test]
        fn test_default_backend_matrix_multiplication_f64() {
            check_real_square(std::ptr::null());
        }

        #[test]
        fn test_lp64_backend_matrix_multiplication_f64() {
            let backend = create_blas_backend();
            assert!(!backend.is_null());
            check_real_square(backend);
            spir_gemm_backend_release(backend);
        }

        #[test]
        fn test_default_backend_matrix_multiplication_complex() {
            check_complex_square(std::ptr::null());
        }

        #[test]
        fn test_lp64_backend_matrix_multiplication_complex() {
            let backend = create_blas_backend();
            assert!(!backend.is_null());
            check_complex_square(backend);
            spir_gemm_backend_release(backend);
        }

        #[test]
        fn test_default_backend_larger_matrix() {
            check_real_rectangular(std::ptr::null());
        }

        #[test]
        fn test_lp64_backend_larger_matrix() {
            let backend = create_blas_backend();
            assert!(!backend.is_null());
            check_real_rectangular(backend);
            spir_gemm_backend_release(backend);
        }
    }
}