/// Linear-algebra backend that forwards its operations to a CBLAS
/// implementation through the `cblas_sys` bindings.
///
/// The type carries no state: it is a zero-sized marker that is freely
/// copyable and can be created with `Blas::default()` or just `Blas`.
#[derive(Clone, Copy, Default)]
pub struct Blas;
use super::*;
/// Implements `operations::DotProduct` for the real scalar type `$t` by
/// calling the CBLAS dot routine `$blas_fn` (e.g. `cblas_sdot`, `cblas_ddot`)
/// with unit strides over all `LEN` elements.
macro_rules! impl_dot {
    ($t: ty, $blas_fn: ident) => {
        impl operations::DotProduct<$t> for Blas {
            fn dot<const LEN: usize>(
                &self,
                a: &impl StaticVec<$t, LEN>,
                b: &impl StaticVec<$t, LEN>,
            ) -> $t {
                let len = LEN as i32;
                // SAFETY: BLAS reads `len` elements with stride 1 from each
                // pointer; this assumes every StaticVec stores its LEN
                // elements contiguously starting at `as_ptr()`.
                unsafe {
                    let (x, y) = (a.as_ptr(), b.as_ptr());
                    cblas_sys::$blas_fn(len, x, 1, y, 1)
                }
            }
        }
    };
}
/// Implements `operations::DotProduct` for `Complex<$t>` via the CBLAS
/// `?dotu_sub` routine `$comp_blas_fn` (unconjugated complex dot product).
///
/// The `_sub` variants write their result through an out-pointer instead of
/// returning it, so the result is collected in a local `[re, im]` pair.
macro_rules! impl_dot_comp {
    ($t: ty, $comp_blas_fn: ident) => {
        impl operations::DotProduct<Complex<$t>> for Blas {
            fn dot<const LEN: usize>(
                &self,
                a: &impl StaticVec<Complex<$t>, LEN>,
                b: &impl StaticVec<Complex<$t>, LEN>,
            ) -> Complex<$t> {
                // CBLAS exchanges each complex number as a `[re, im]` pair.
                let mut result: [$t; 2] = [0.; 2];
                let out: *mut [$t; 2] = &mut result;
                // SAFETY: assumes each StaticVec holds LEN contiguous
                // `Complex<$t>` values laid out as `[re, im]`; `out` points to
                // the local `result`, which outlives the call.
                unsafe {
                    let x = a.as_ptr() as *const [$t; 2];
                    let y = b.as_ptr() as *const [$t; 2];
                    cblas_sys::$comp_blas_fn(LEN as i32, x, 1, y, 1, out);
                }
                let [re, im] = result;
                Complex { re, im }
            }
        }
    };
}
/// Implements `operations::MatrixMul` for the real scalar type `$t` using the
/// CBLAS routines `$gemm` (matrix x matrix) and `$gemv` (matrix x vector),
/// always in row-major layout.
///
/// Both calls use alpha = 1 and beta = 0, so `buffer` receives
/// `op(A) * op(B)` (or `op(A) * b`) and its previous contents are overwritten,
/// not accumulated.
macro_rules! impl_gemm {
    ($t: ty : $gemm: ident $gemv: ident) => {
        impl operations::MatrixMul<$t> for Blas {
            fn matrix_mul<
                A: StaticVec<$t, ALEN>,
                B: StaticVec<$t, BLEN>,
                C: StaticVec<$t, CLEN>,
                const ALEN: usize,
                const BLEN: usize,
                const CLEN: usize,
            >(
                &self,
                a: &A,
                b: &B,
                buffer: &mut C,
                m: usize,
                n: usize,
                k: usize,
                lda: usize,
                ldb: usize,
                ldc: usize,
                a_trans: bool,
                b_trans: bool,
            ) where
                A: Sized,
                B: Sized,
            {
                use cblas_sys::CBLAS_TRANSPOSE::*;
                // BLAS writes the result through this pointer, so it must be
                // derived from a unique (`&mut`) borrow of `buffer`. Casting
                // `as_ptr()` (obtained through `&self`) to `*mut` and writing
                // through it is undefined behaviour.
                let out: &mut [$t] = buffer.mut_moo_ref();
                let out_ptr = out.as_mut_ptr();
                // SAFETY: dimensions are not checked here. This assumes the
                // caller's m/n/k and lda/ldb/ldc describe matrices that fit in
                // ALEN/BLEN/CLEN elements; BLAS performs no bounds checking.
                unsafe {
                    cblas_sys::$gemm(
                        cblas_sys::CBLAS_LAYOUT::CblasRowMajor,
                        if a_trans { CblasTrans } else { CblasNoTrans },
                        if b_trans { CblasTrans } else { CblasNoTrans },
                        m as i32,
                        n as i32,
                        k as i32,
                        1.,
                        a.as_ptr(),
                        lda as i32,
                        b.as_ptr(),
                        ldb as i32,
                        0.,
                        out_ptr,
                        ldc as i32,
                    )
                }
            }
            fn matrix_vector_mul<
                A: StaticVec<$t, ALEN>,
                B: StaticVec<$t, BLEN>,
                C: StaticVec<$t, CLEN>,
                const ALEN: usize,
                const BLEN: usize,
                const CLEN: usize,
            >(
                &self,
                a: &A,
                b: &B,
                buffer: &mut C,
                m: usize,
                n: usize,
                lda: usize,
                a_trans: bool,
            ) where
                A: Sized,
                B: Sized,
            {
                use cblas_sys::CBLAS_TRANSPOSE::*;
                // The output pointer must come from a unique borrow of
                // `buffer`, because BLAS writes through it.
                let out: &mut [$t] = buffer.mut_moo_ref();
                let out_ptr = out.as_mut_ptr();
                // SAFETY: dimensions are not checked here. This assumes the
                // matrix described by n/m/lda fits in ALEN elements and that
                // `b` and `buffer` are long enough for the chosen transpose.
                unsafe {
                    cblas_sys::$gemv(
                        cblas_sys::CBLAS_LAYOUT::CblasRowMajor,
                        if a_trans { CblasTrans } else { CblasNoTrans },
                        // NOTE(review): `n` is passed as the BLAS row count (M)
                        // and `m` as the column count (N), the reverse of the
                        // order used in `matrix_mul`. Confirm this matches the
                        // `MatrixMul::matrix_vector_mul` contract.
                        n as i32,
                        m as i32,
                        1.,
                        a.as_ptr(),
                        lda as i32,
                        b.as_ptr(),
                        1,
                        0.,
                        out_ptr,
                        1,
                    )
                }
            }
        }
    };
}
/// Implements `operations::Normalize` for `$t` using a CBLAS `?nrm2`-family
/// routine (`$blas_fn`), which computes the Euclidean norm with unit stride.
///
/// * `$t`  - element type stored in the vector.
/// * `$t2` - element type as passed to CBLAS (`[real; 2]` for complex values).
/// * `$t3` - real type of the returned norm (`NormOutput`).
macro_rules! impl_norm {
    ($t: ty, $t2: ty, $t3: ty, $blas_fn: ident) => {
        impl operations::Normalize<$t> for Blas {
            type NormOutput = $t3;
            fn norm<const LEN: usize>(&self, a: &impl StaticVec<$t, LEN>) -> $t3 {
                // SAFETY: assumes the vector's LEN elements are contiguous at
                // `as_ptr()` and have the same layout as `$t2`.
                unsafe { cblas_sys::$blas_fn(LEN as i32, a.as_ptr() as *const $t2, 1) }.into()
            }
            fn normalize<const LEN: usize>(&self, a: &mut impl StaticVec<$t, LEN>) {
                let magnitude: $t3 = Backend::norm(self, a);
                // A zero vector has no direction. Dividing by a zero norm would
                // fill it with NaN, so leave the vector untouched instead.
                if magnitude == 0. {
                    return;
                }
                let norm = <$t>::from(magnitude);
                a.mut_moo_ref().iter_mut().for_each(|n| *n = *n / norm);
            }
        }
    };
}
// Single-precision real: matrix products, dot product, norm.
impl_gemm!(f32: cblas_sgemm cblas_sgemv);
impl_dot!(f32, cblas_sdot);
impl_norm!(f32, f32, f32, cblas_snrm2);
impl Backend<f32> for Blas {}

// Double-precision real: matrix products, dot product, norm.
impl_gemm!(f64: cblas_dgemm cblas_dgemv);
impl_dot!(f64, cblas_ddot);
impl_norm!(f64, f64, f64, cblas_dnrm2);
impl Backend<f64> for Blas {}

// Single-precision complex: CBLAS sees each element as `[f32; 2]`.
impl_dot_comp!(f32, cblas_cdotu_sub);
impl_norm!(Complex<f32>, [f32; 2], f32, cblas_scnrm2);
impl Backend<Complex<f32>> for Blas {}

// Double-precision complex: CBLAS sees each element as `[f64; 2]`.
impl_dot_comp!(f64, cblas_zdotu_sub);
impl_norm!(Complex<f64>, [f64; 2], f64, cblas_dznrm2);
impl Backend<Complex<f64>> for Blas {}