use mdarray::{DSlice, DTensor, DView, DViewMut, Layout};
use once_cell::sync::Lazy;
use std::sync::{Arc, RwLock};
#[cfg(feature = "system-blas")]
use blas_sys::dgemm_;
/// Function-pointer type for a BLAS `dgemm_` routine using 32-bit (LP64) integers.
///
/// The argument order follows the reference BLAS `DGEMM` interface,
/// `C := alpha * op(A) * op(B) + beta * C`, with every argument passed by pointer
/// (Fortran calling convention). With the `system-blas` feature, the `blas_sys::dgemm_`
/// import is cast to this type to build the default dispatcher.
pub type DgemmFnPtr = unsafe extern "C" fn(
transa: *const libc::c_char,
transb: *const libc::c_char,
m: *const libc::c_int,
n: *const libc::c_int,
k: *const libc::c_int,
alpha: *const libc::c_double,
a: *const libc::c_double,
lda: *const libc::c_int,
b: *const libc::c_double,
ldb: *const libc::c_int,
beta: *const libc::c_double,
c: *mut libc::c_double,
ldc: *const libc::c_int,
);
/// Function-pointer type for a complex double-precision BLAS `zgemm_` routine using
/// 32-bit (LP64) integers.
///
/// Scalars and matrices are typed as `num_complex::Complex<f64>`. The `system-blas`
/// adapter `zgemm_wrapper` reinterprets them as `blas_sys::c_double_complex` before
/// calling the linked library.
pub type ZgemmFnPtr = unsafe extern "C" fn(
transa: *const libc::c_char,
transb: *const libc::c_char,
m: *const libc::c_int,
n: *const libc::c_int,
k: *const libc::c_int,
alpha: *const num_complex::Complex<f64>,
a: *const num_complex::Complex<f64>,
lda: *const libc::c_int,
b: *const num_complex::Complex<f64>,
ldb: *const libc::c_int,
beta: *const num_complex::Complex<f64>,
c: *mut num_complex::Complex<f64>,
ldc: *const libc::c_int,
);
/// Adapts the `blas_sys::zgemm_` binding to the [`ZgemmFnPtr`] signature.
///
/// `blas_sys` types complex arguments as `c_double_complex`, while this crate passes
/// `num_complex::Complex<f64>`. Only the pointer types are reinterpreted; no data is
/// copied or converted. This assumes the two types share a layout (two `f64`s, real
/// then imaginary).
#[cfg(feature = "system-blas")]
unsafe extern "C" fn zgemm_wrapper(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const libc::c_int,
    b: *const num_complex::Complex<f64>,
    ldb: *const libc::c_int,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const libc::c_int,
) {
    use blas_sys::c_double_complex as BlasComplex;

    // SAFETY: every pointer comes straight from our caller, who upholds the zgemm
    // contract; `Complex<f64>` is `#[repr(C)]` (re, im), matching `BlasComplex`.
    unsafe {
        blas_sys::zgemm_(
            transa,
            transb,
            m,
            n,
            k,
            alpha.cast::<BlasComplex>(),
            a.cast::<BlasComplex>(),
            lda,
            b.cast::<BlasComplex>(),
            ldb,
            beta.cast::<BlasComplex>(),
            c.cast::<BlasComplex>(),
            ldc,
        );
    }
}
/// ILP64 variant of [`DgemmFnPtr`]: identical argument order, but every integer
/// argument (`m`, `n`, `k`, `lda`, `ldb`, `ldc`) is a 64-bit `i64`.
/// Installed through `set_ilp64_backend`.
pub type Dgemm64FnPtr = unsafe extern "C" fn(
transa: *const libc::c_char,
transb: *const libc::c_char,
m: *const i64,
n: *const i64,
k: *const i64,
alpha: *const libc::c_double,
a: *const libc::c_double,
lda: *const i64,
b: *const libc::c_double,
ldb: *const i64,
beta: *const libc::c_double,
c: *mut libc::c_double,
ldc: *const i64,
);
/// ILP64 variant of [`ZgemmFnPtr`]: complex double-precision GEMM whose integer
/// arguments are 64-bit `i64`. Installed through `set_ilp64_backend`.
pub type Zgemm64FnPtr = unsafe extern "C" fn(
transa: *const libc::c_char,
transb: *const libc::c_char,
m: *const i64,
n: *const i64,
k: *const i64,
alpha: *const num_complex::Complex<f64>,
a: *const num_complex::Complex<f64>,
lda: *const i64,
b: *const num_complex::Complex<f64>,
ldb: *const i64,
beta: *const num_complex::Complex<f64>,
c: *mut num_complex::Complex<f64>,
ldc: *const i64,
);
/// A backend that computes dense, row-major matrix products `C = A · B`.
///
/// Every implementation in this file overwrites `c` completely rather than
/// accumulating into it: `FaerBackend` uses `Accum::Replace`, and the BLAS backends
/// pass `beta = 0`.
pub trait GemmBackend: Send + Sync {
/// Real double-precision product `C (m×n) = A (m×k) · B (k×n)`.
///
/// # Safety
///
/// - `a` must point to `m * k` readable `f64`s, row-major (row stride `k`).
/// - `b` must point to `k * n` readable `f64`s, row-major (row stride `n`).
/// - `c` must point to `m * n` writable `f64`s, row-major (row stride `n`), and must not
///   alias `a` or `b`.
unsafe fn dgemm(&self, m: usize, n: usize, k: usize, a: *const f64, b: *const f64, c: *mut f64);
/// Complex double-precision product `C (m×n) = A (m×k) · B (k×n)`.
///
/// # Safety
///
/// Same pointer and layout requirements as [`GemmBackend::dgemm`], with
/// `num_complex::Complex<f64>` elements.
unsafe fn zgemm(
&self,
m: usize,
n: usize,
k: usize,
a: *const num_complex::Complex<f64>,
b: *const num_complex::Complex<f64>,
c: *mut num_complex::Complex<f64>,
);
/// Whether the backend passes 64-bit (ILP64) integers to its kernels.
/// Defaults to `false`.
fn is_ilp64(&self) -> bool {
false
}
/// Human-readable backend name. `get_backend_info` treats any name that does not
/// contain `"Faer"` as an external backend.
fn name(&self) -> &'static str;
}
/// Pure-Rust backend built on `faer`'s matmul kernel, run sequentially (`Par::Seq`).
struct FaerBackend;
impl GemmBackend for FaerBackend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        use faer::linalg::matmul::matmul;
        use faer::mat::{MatMut, MatRef};
        use faer::{Accum, Par};

        // Row-major storage: advancing one row skips a full row; columns are adjacent.
        let (a_row_step, b_row_step, c_row_step) = (k as isize, n as isize, n as isize);
        // SAFETY: the `GemmBackend` contract guarantees `a`, `b` and `c` address
        // `m×k`, `k×n` and `m×n` row-major buffers, with `c` not aliasing the inputs.
        let (a_mat, b_mat, mut c_mat) = unsafe {
            (
                MatRef::from_raw_parts(a, m, k, a_row_step, 1),
                MatRef::from_raw_parts(b, k, n, b_row_step, 1),
                MatMut::from_raw_parts_mut(c, m, n, c_row_step, 1),
            )
        };
        matmul(&mut c_mat, Accum::Replace, &a_mat, &b_mat, 1.0, Par::Seq);
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        use faer::linalg::matmul::matmul;
        use faer::mat::{MatMut, MatRef};
        use faer::{Accum, Par};

        // Row-major storage: advancing one row skips a full row; columns are adjacent.
        let (a_row_step, b_row_step, c_row_step) = (k as isize, n as isize, n as isize);
        // SAFETY: see `dgemm`; identical layout contract with complex elements.
        let (a_mat, b_mat, mut c_mat) = unsafe {
            (
                MatRef::from_raw_parts(a, m, k, a_row_step, 1),
                MatRef::from_raw_parts(b, k, n, b_row_step, 1),
                MatMut::from_raw_parts_mut(c, m, n, c_row_step, 1),
            )
        };
        let unit = num_complex::Complex::new(1.0, 0.0);
        matmul(&mut c_mat, Accum::Replace, &a_mat, &b_mat, unit, Par::Seq);
    }

    fn name(&self) -> &'static str {
        "Faer (Pure Rust)"
    }
}
/// GEMM backend that forwards to caller-supplied BLAS `dgemm_`/`zgemm_` routines
/// taking 32-bit (LP64) integer arguments.
pub struct ExternalBlasBackend {
dgemm: DgemmFnPtr,
zgemm: ZgemmFnPtr,
}
impl ExternalBlasBackend {
/// Wraps the given LP64 BLAS entry points without validating them.
///
/// NOTE(review): this constructor is safe, yet the stored pointers are later invoked
/// inside `unsafe` blocks. The caller is trusted to supply genuine BLAS routines,
/// unlike `set_blas_backend`, which is `unsafe`.
pub fn new(dgemm: DgemmFnPtr, zgemm: ZgemmFnPtr) -> Self {
Self { dgemm, zgemm }
}
}
/// 32-bit BLAS arguments for a row-major product `C (m×n) = A (m×k) · B (k×n)`.
///
/// Column-major BLAS sees a row-major matrix as its transpose, so the product is issued
/// as `Cᵀ (n×m) = Bᵀ (n×k) · Aᵀ (k×m)`: BLAS `M = n`, BLAS `N = m`, and row-major `B`
/// is the first operand. No transpose flags are needed.
///
/// Leading dimensions are clamped to at least 1. BLAS rejects `ld < max(1, rows)` with
/// an XERBLA parameter error even when a dimension is zero; for `k == 0` that would
/// otherwise leave `C` unwritten.
struct Lp64GemmArgs {
    /// BLAS `M` (row-major `n`).
    rows: i32,
    /// BLAS `N` (row-major `m`).
    cols: i32,
    /// Shared inner dimension `k`.
    inner: i32,
    /// Leading dimension of the first operand (row-major `B`, viewed as `n×k`).
    ld_first: i32,
    /// Leading dimension of the second operand (row-major `A`, viewed as `k×m`).
    ld_second: i32,
    /// Leading dimension of the output (row-major `C`, viewed as `n×m`).
    ld_out: i32,
}

impl Lp64GemmArgs {
    /// Validates and maps row-major `m`, `n`, `k` to LP64 BLAS arguments.
    ///
    /// # Panics
    ///
    /// If any dimension exceeds `i32::MAX`.
    fn new(m: usize, n: usize, k: usize) -> Self {
        assert!(
            m <= i32::MAX as usize,
            "Matrix dimension m too large for LP64 BLAS"
        );
        assert!(
            n <= i32::MAX as usize,
            "Matrix dimension n too large for LP64 BLAS"
        );
        assert!(
            k <= i32::MAX as usize,
            "Matrix dimension k too large for LP64 BLAS"
        );
        // Lossless: each value was range-checked above.
        let (m, n, k) = (m as i32, n as i32, k as i32);
        Self {
            rows: n,
            cols: m,
            inner: k,
            ld_first: n.max(1),
            ld_second: k.max(1),
            ld_out: n.max(1),
        }
    }
}

impl GemmBackend for ExternalBlasBackend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        let args = Lp64GemmArgs::new(m, n, k);
        if m == 0 || n == 0 {
            // Empty output: nothing to compute or write.
            return;
        }
        let trans = b'N' as libc::c_char;
        let alpha = 1.0f64;
        let beta = 0.0f64;
        // SAFETY: the `GemmBackend` contract guarantees `a`, `b`, `c` are valid row-major
        // buffers of `m×k`, `k×n`, `m×n` elements, and `args` describes exactly those
        // buffers to column-major BLAS; the function pointer was supplied as a BLAS dgemm.
        unsafe {
            (self.dgemm)(
                &trans,
                &trans,
                &args.rows,
                &args.cols,
                &args.inner,
                &alpha,
                b,
                &args.ld_first,
                a,
                &args.ld_second,
                &beta,
                c,
                &args.ld_out,
            );
        }
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        let args = Lp64GemmArgs::new(m, n, k);
        if m == 0 || n == 0 {
            // Empty output: nothing to compute or write.
            return;
        }
        let trans = b'N' as libc::c_char;
        let alpha = num_complex::Complex::new(1.0, 0.0);
        let beta = num_complex::Complex::new(0.0, 0.0);
        // SAFETY: as in `dgemm`, with complex elements and a BLAS zgemm pointer.
        unsafe {
            (self.zgemm)(
                &trans,
                &trans,
                &args.rows,
                &args.cols,
                &args.inner,
                &alpha,
                b,
                &args.ld_first,
                a,
                &args.ld_second,
                &beta,
                c,
                &args.ld_out,
            );
        }
    }

    fn name(&self) -> &'static str {
        "External BLAS (LP64)"
    }
}
/// GEMM backend that forwards to caller-supplied BLAS `dgemm`/`zgemm` routines
/// taking 64-bit (ILP64) integer arguments.
pub struct ExternalBlas64Backend {
dgemm64: Dgemm64FnPtr,
zgemm64: Zgemm64FnPtr,
}
impl ExternalBlas64Backend {
/// Wraps the given ILP64 BLAS entry points without validating them.
///
/// NOTE(review): this constructor is safe, yet the stored pointers are later invoked
/// inside `unsafe` blocks. The caller is trusted to supply genuine BLAS routines.
pub fn new(dgemm64: Dgemm64FnPtr, zgemm64: Zgemm64FnPtr) -> Self {
Self { dgemm64, zgemm64 }
}
}
/// 64-bit BLAS arguments for a row-major product `C (m×n) = A (m×k) · B (k×n)`.
///
/// The product is issued to column-major BLAS as `Cᵀ = Bᵀ · Aᵀ`: BLAS `M = n`,
/// BLAS `N = m`, and row-major `B` is the first operand.
///
/// Leading dimensions are clamped to at least 1. BLAS rejects `ld < max(1, rows)` even
/// when a dimension is zero, which would otherwise leave `C` unwritten for `k == 0`.
struct Ilp64GemmArgs {
    /// BLAS `M` (row-major `n`).
    rows: i64,
    /// BLAS `N` (row-major `m`).
    cols: i64,
    /// Shared inner dimension `k`.
    inner: i64,
    /// Leading dimension of the first operand (row-major `B`, viewed as `n×k`).
    ld_first: i64,
    /// Leading dimension of the second operand (row-major `A`, viewed as `k×m`).
    ld_second: i64,
    /// Leading dimension of the output (row-major `C`, viewed as `n×m`).
    ld_out: i64,
}

impl Ilp64GemmArgs {
    /// Maps row-major `m`, `n`, `k` to ILP64 BLAS arguments.
    ///
    /// # Panics
    ///
    /// If any dimension exceeds `i64::MAX` (only possible on targets with a usize wider
    /// than 63 bits of value range).
    fn new(m: usize, n: usize, k: usize) -> Self {
        let m = i64::try_from(m).expect("Matrix dimension m too large for ILP64 BLAS");
        let n = i64::try_from(n).expect("Matrix dimension n too large for ILP64 BLAS");
        let k = i64::try_from(k).expect("Matrix dimension k too large for ILP64 BLAS");
        Self {
            rows: n,
            cols: m,
            inner: k,
            ld_first: n.max(1),
            ld_second: k.max(1),
            ld_out: n.max(1),
        }
    }
}

impl GemmBackend for ExternalBlas64Backend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        let args = Ilp64GemmArgs::new(m, n, k);
        if m == 0 || n == 0 {
            // Empty output: nothing to compute or write.
            return;
        }
        let trans = b'N' as libc::c_char;
        let alpha = 1.0f64;
        let beta = 0.0f64;
        // SAFETY: the `GemmBackend` contract guarantees `a`, `b`, `c` are valid row-major
        // buffers of `m×k`, `k×n`, `m×n` elements, and `args` describes exactly those
        // buffers to column-major BLAS; the function pointer was supplied as an ILP64 dgemm.
        unsafe {
            (self.dgemm64)(
                &trans,
                &trans,
                &args.rows,
                &args.cols,
                &args.inner,
                &alpha,
                b,
                &args.ld_first,
                a,
                &args.ld_second,
                &beta,
                c,
                &args.ld_out,
            );
        }
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        let args = Ilp64GemmArgs::new(m, n, k);
        if m == 0 || n == 0 {
            // Empty output: nothing to compute or write.
            return;
        }
        let trans = b'N' as libc::c_char;
        let alpha = num_complex::Complex::new(1.0, 0.0);
        let beta = num_complex::Complex::new(0.0, 0.0);
        // SAFETY: as in `dgemm`, with complex elements and an ILP64 zgemm pointer.
        unsafe {
            (self.zgemm64)(
                &trans,
                &trans,
                &args.rows,
                &args.cols,
                &args.inner,
                &alpha,
                b,
                &args.ld_first,
                a,
                &args.ld_second,
                &beta,
                c,
                &args.ld_out,
            );
        }
    }

    fn is_ilp64(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "External BLAS (ILP64)"
    }
}
/// Shareable handle to a [`GemmBackend`], passed explicitly to the `matmul_par*`
/// functions to bypass the process-wide dispatcher.
///
/// Cloning is cheap: clones share the same backend through an `Arc`.
#[derive(Clone)]
pub struct GemmBackendHandle {
    inner: Arc<dyn GemmBackend>,
}

impl GemmBackendHandle {
    /// Wraps an arbitrary backend implementation.
    pub fn new(backend: Box<dyn GemmBackend>) -> Self {
        Self {
            inner: Arc::from(backend),
        }
    }

    /// Borrows the wrapped backend.
    pub(crate) fn as_ref(&self) -> &dyn GemmBackend {
        self.inner.as_ref()
    }
}

/// The default handle uses the pure-Rust faer backend, regardless of the
/// `system-blas` feature. `GemmBackendHandle::default()` keeps working through this
/// trait impl (`Default` is in the prelude), and the handle now also composes with
/// `unwrap_or_default`, `#[derive(Default)]`, etc.
impl Default for GemmBackendHandle {
    fn default() -> Self {
        Self {
            inner: Arc::new(FaerBackend),
        }
    }
}
/// Process-wide GEMM backend, used whenever a `matmul_par*` caller passes `backend: None`.
///
/// It is initialized on first use. With the `system-blas` feature it wraps the linked
/// `blas_sys` `dgemm_`/`zgemm_` routines as an LP64 backend; otherwise it is the
/// pure-Rust `FaerBackend`. It is replaced at runtime by `set_blas_backend`,
/// `set_ilp64_backend` and `clear_blas_backend`.
///
/// The matmul functions hold the read lock for the whole GEMM call, so a backend swap
/// waits for in-flight multiplications to finish.
static BLAS_DISPATCHER: Lazy<RwLock<Box<dyn GemmBackend>>> = Lazy::new(|| {
#[cfg(feature = "system-blas")]
{
// `zgemm_wrapper` adapts blas_sys's complex pointer type to `ZgemmFnPtr`.
let backend = ExternalBlasBackend::new(dgemm_ as DgemmFnPtr, zgemm_wrapper as ZgemmFnPtr);
RwLock::new(Box::new(backend) as Box<dyn GemmBackend>)
}
#[cfg(not(feature = "system-blas"))]
{
RwLock::new(Box::new(FaerBackend) as Box<dyn GemmBackend>)
}
});
pub unsafe fn set_blas_backend(dgemm: DgemmFnPtr, zgemm: ZgemmFnPtr) {
let backend = ExternalBlasBackend { dgemm, zgemm };
let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
*dispatcher = Box::new(backend);
}
pub unsafe fn set_ilp64_backend(dgemm64: Dgemm64FnPtr, zgemm64: Zgemm64FnPtr) {
let backend = ExternalBlas64Backend { dgemm64, zgemm64 };
let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
*dispatcher = Box::new(backend);
}
pub fn clear_blas_backend() {
let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
*dispatcher = Box::new(FaerBackend);
}
/// Describes the current process-wide backend.
///
/// Returns `(name, is_external, is_ilp64)`. `is_external` is `true` when the backend
/// name does not contain `"Faer"`.
pub fn get_backend_info() -> (&'static str, bool, bool) {
    let current = BLAS_DISPATCHER.read().unwrap();
    let label = current.name();
    (label, !label.contains("Faer"), current.is_ilp64())
}
/// Computes `A · B` into a newly allocated `A.rows × B.cols` tensor.
///
/// `f64` and `Complex<f64>` products go through `backend`, or the process-wide
/// dispatcher when `None`. Other element types use `mdarray_linalg_faer`
/// (see `matmul_par_overwrite`).
///
/// # Panics
///
/// If `A.cols != B.rows`.
pub fn matmul_par<T>(
    a: &DTensor<T, 2>,
    b: &DTensor<T, 2>,
    backend: Option<&GemmBackendHandle>,
) -> DTensor<T, 2>
where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    let (rows, inner) = *a.shape();
    let (inner_b, cols) = *b.shape();
    assert_eq!(
        inner, inner_b,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        inner, inner_b
    );
    let mut product = DTensor::<T, 2>::from_elem([rows, cols], T::zero());
    matmul_par_overwrite(a, b, &mut product, backend);
    product
}
/// Computes `A · B` for two borrowed 2-D views, returning a newly allocated tensor.
///
/// # Panics
///
/// If either view is not contiguous in memory, or if `A.cols != B.rows`.
pub fn matmul_par_view<T>(
    a: &DView<'_, T, 2>,
    b: &DView<'_, T, 2>,
    backend: Option<&GemmBackendHandle>,
) -> DTensor<T, 2>
where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    assert!(a.is_contiguous(), "Matrix A view must be contiguous in memory");
    assert!(b.is_contiguous(), "Matrix B view must be contiguous in memory");

    let (rows, inner) = *a.shape();
    let (inner_b, cols) = *b.shape();
    assert_eq!(
        inner, inner_b,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        inner, inner_b
    );

    let mut product = DTensor::<T, 2>::from_elem([rows, cols], T::zero());
    matmul_par_overwrite_view(a, b, &mut product, backend);
    product
}
/// Computes `A · B` for two borrowed views and overwrites the owned tensor `c`.
///
/// `f64` and `Complex<f64>` are handed to `backend`, or the process-wide dispatcher
/// when `None`, through raw pointers. Other element types are first copied into owned
/// tensors and multiplied with `mdarray_linalg_faer`.
///
/// # Panics
///
/// If either input view is not contiguous, or the shapes are not `m×k`, `k×n`, `m×n`.
pub fn matmul_par_overwrite_view<T>(
    a: &DView<'_, T, 2>,
    b: &DView<'_, T, 2>,
    c: &mut DTensor<T, 2>,
    backend: Option<&GemmBackendHandle>,
) where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    assert!(a.is_contiguous(), "Matrix A view must be contiguous in memory");
    assert!(b.is_contiguous(), "Matrix B view must be contiguous in memory");

    let (rows, inner) = *a.shape();
    let (inner_b, cols) = *b.shape();
    let (out_rows, out_cols) = *c.shape();
    assert_eq!(
        inner, inner_b,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        inner, inner_b
    );
    assert_eq!(
        rows, out_rows,
        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
        out_rows, rows
    );
    assert_eq!(
        cols, out_cols,
        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
        out_cols, cols
    );

    let elem = std::any::TypeId::of::<T>();
    if elem == std::any::TypeId::of::<f64>() {
        // `T` is exactly `f64`, so these pointer casts only restate the type.
        let (lhs, rhs) = (a.as_ptr().cast::<f64>(), b.as_ptr().cast::<f64>());
        let out = c.as_mut_ptr().cast::<f64>();
        if let Some(handle) = backend {
            // SAFETY: contiguous `rows×inner`, `inner×cols` inputs and an owned
            // `rows×cols` output were all checked above.
            unsafe { handle.as_ref().dgemm(rows, cols, inner, lhs, rhs, out) };
        } else {
            let global = BLAS_DISPATCHER.read().unwrap();
            // SAFETY: as above.
            unsafe { global.dgemm(rows, cols, inner, lhs, rhs, out) };
        }
    } else if elem == std::any::TypeId::of::<num_complex::Complex<f64>>() {
        // `T` is exactly `Complex<f64>`, so these pointer casts only restate the type.
        let lhs = a.as_ptr().cast::<num_complex::Complex<f64>>();
        let rhs = b.as_ptr().cast::<num_complex::Complex<f64>>();
        let out = c.as_mut_ptr().cast::<num_complex::Complex<f64>>();
        if let Some(handle) = backend {
            // SAFETY: same shape/contiguity checks as the `f64` branch.
            unsafe { handle.as_ref().zgemm(rows, cols, inner, lhs, rhs, out) };
        } else {
            let global = BLAS_DISPATCHER.read().unwrap();
            // SAFETY: as above.
            unsafe { global.zgemm(rows, cols, inner, lhs, rhs, out) };
        }
    } else {
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;
        // The Faer builder is driven with owned `DTensor` operands, so copy the views.
        let lhs = DTensor::<T, 2>::from_fn(*a.shape(), |idx| a[idx]);
        let rhs = DTensor::<T, 2>::from_fn(*b.shape(), |idx| b[idx]);
        Faer.matmul(&lhs, &rhs).parallelize().overwrite(c);
    }
}
/// Computes `A · B` and writes the product into the mutable view `c`.
///
/// `f64` and `Complex<f64>` products are written in place through `backend`, or the
/// process-wide dispatcher when `None`. Other element types are computed into a
/// temporary tensor with `mdarray_linalg_faer` and then copied into `c` element by
/// element.
///
/// # Panics
///
/// If any of `a`, `b`, `c` is not contiguous, or the shapes are not `m×k`, `k×n`, `m×n`.
pub fn matmul_par_to_viewmut<T>(
    a: &DView<'_, T, 2>,
    b: &DView<'_, T, 2>,
    c: &mut DViewMut<'_, T, 2>,
    backend: Option<&GemmBackendHandle>,
) where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    assert!(a.is_contiguous(), "Matrix A view must be contiguous in memory");
    assert!(b.is_contiguous(), "Matrix B view must be contiguous in memory");
    assert!(c.is_contiguous(), "Matrix C view must be contiguous in memory");

    let (rows, inner) = *a.shape();
    let (inner_b, cols) = *b.shape();
    let (out_rows, out_cols) = *c.shape();
    assert_eq!(
        inner, inner_b,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        inner, inner_b
    );
    assert_eq!(
        rows, out_rows,
        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
        out_rows, rows
    );
    assert_eq!(
        cols, out_cols,
        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
        out_cols, cols
    );

    let elem = std::any::TypeId::of::<T>();
    if elem == std::any::TypeId::of::<f64>() {
        // `T` is exactly `f64`, so these pointer casts only restate the type.
        let (lhs, rhs) = (a.as_ptr().cast::<f64>(), b.as_ptr().cast::<f64>());
        let out = c.as_mut_ptr().cast::<f64>();
        if let Some(handle) = backend {
            // SAFETY: all three views were checked contiguous with matching shapes.
            unsafe { handle.as_ref().dgemm(rows, cols, inner, lhs, rhs, out) };
        } else {
            let global = BLAS_DISPATCHER.read().unwrap();
            // SAFETY: as above.
            unsafe { global.dgemm(rows, cols, inner, lhs, rhs, out) };
        }
    } else if elem == std::any::TypeId::of::<num_complex::Complex<f64>>() {
        // `T` is exactly `Complex<f64>`, so these pointer casts only restate the type.
        let lhs = a.as_ptr().cast::<num_complex::Complex<f64>>();
        let rhs = b.as_ptr().cast::<num_complex::Complex<f64>>();
        let out = c.as_mut_ptr().cast::<num_complex::Complex<f64>>();
        if let Some(handle) = backend {
            // SAFETY: same checks as the `f64` branch.
            unsafe { handle.as_ref().zgemm(rows, cols, inner, lhs, rhs, out) };
        } else {
            let global = BLAS_DISPATCHER.read().unwrap();
            // SAFETY: as above.
            unsafe { global.zgemm(rows, cols, inner, lhs, rhs, out) };
        }
    } else {
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;
        let lhs = DTensor::<T, 2>::from_fn(*a.shape(), |idx| a[idx]);
        let rhs = DTensor::<T, 2>::from_fn(*b.shape(), |idx| b[idx]);
        // The product is computed into an owned scratch tensor, then copied into the view.
        let mut scratch = DTensor::<T, 2>::from_fn(*c.shape(), |_| T::zero());
        Faer.matmul(&lhs, &rhs).parallelize().overwrite(&mut scratch);
        for row in 0..out_rows {
            for col in 0..out_cols {
                c[[row, col]] = scratch[[row, col]];
            }
        }
    }
}
/// Computes `A · B` and overwrites `c` with the product.
///
/// Dispatch:
/// - **Fast path:** for `f64` and `Complex<f64>`, when `c` is contiguous (dense
///   row-major), the product is computed by `backend` (or the process-wide dispatcher
///   when `None`) through raw pointers.
/// - **Fallback:** everything else, including any output slice whose layout `Lc` is
///   not contiguous, is multiplied with `mdarray_linalg_faer`, which honours `c`'s
///   strides.
///
/// # Panics
///
/// If `A.cols != B.rows`, or `c` is not `A.rows × B.cols`.
pub fn matmul_par_overwrite<T, Lc: Layout>(
    a: &DTensor<T, 2>,
    b: &DTensor<T, 2>,
    c: &mut DSlice<T, 2, Lc>,
    backend: Option<&GemmBackendHandle>,
) where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    use mdarray_linalg::matmul::MatMulBuilder;
    use mdarray_linalg::prelude::MatMul;
    use mdarray_linalg_faer::Faer;
    use std::any::TypeId;

    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();
    let (mc, nc) = *c.shape();
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );
    assert_eq!(
        m, mc,
        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
        mc, m
    );
    assert_eq!(
        n, nc,
        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
        nc, n
    );

    // The raw-pointer kernels assume `c` is a dense row-major `m×n` buffer with row
    // stride `n`. A strided `Lc` (e.g. a column sub-view) would be written at the wrong
    // offsets, possibly out of bounds, so such outputs take the layout-aware fallback.
    // `a` and `b` are owned `DTensor`s and therefore always dense.
    if c.is_contiguous() {
        if TypeId::of::<T>() == TypeId::of::<f64>() {
            let a_ptr = a.as_ptr() as *const f64;
            let b_ptr = b.as_ptr() as *const f64;
            let c_ptr = c.as_mut_ptr() as *mut f64;
            match backend {
                // SAFETY: dense `m×k`, `k×n` inputs and a contiguous `m×n` output, checked above.
                Some(handle) => unsafe { handle.as_ref().dgemm(m, n, k, a_ptr, b_ptr, c_ptr) },
                None => {
                    let dispatcher = BLAS_DISPATCHER.read().unwrap();
                    // SAFETY: as above.
                    unsafe { dispatcher.dgemm(m, n, k, a_ptr, b_ptr, c_ptr) };
                }
            }
            return;
        }
        if TypeId::of::<T>() == TypeId::of::<num_complex::Complex<f64>>() {
            let a_ptr = a.as_ptr() as *const num_complex::Complex<f64>;
            let b_ptr = b.as_ptr() as *const num_complex::Complex<f64>;
            let c_ptr = c.as_mut_ptr() as *mut num_complex::Complex<f64>;
            match backend {
                // SAFETY: same checks as the `f64` branch.
                Some(handle) => unsafe { handle.as_ref().zgemm(m, n, k, a_ptr, b_ptr, c_ptr) },
                None => {
                    let dispatcher = BLAS_DISPATCHER.read().unwrap();
                    // SAFETY: as above.
                    unsafe { dispatcher.zgemm(m, n, k, a_ptr, b_ptr, c_ptr) };
                }
            }
            return;
        }
    }

    Faer.matmul(a, b).parallelize().overwrite(c);
}
#[cfg(test)]
mod tests {
use super::*;
use mdarray::DView;
// Without `system-blas` the dispatcher starts out as the pure-Rust faer backend.
#[test]
#[cfg(not(feature = "system-blas"))]
fn test_default_backend_is_faer() {
let (name, is_external, is_ilp64) = get_backend_info();
assert_eq!(name, "Faer (Pure Rust)");
assert!(!is_external);
assert!(!is_ilp64);
}
// The view-based entry point must agree with the owned-tensor entry point (f64 path).
#[test]
fn test_matmul_par_view() {
let a = DTensor::<f64, 2>::from([[1.0, 2.0], [3.0, 4.0]]);
let b = DTensor::<f64, 2>::from([[5.0, 6.0], [7.0, 8.0]]);
let a_view: DView<'_, f64, 2> = a.view(.., ..);
let b_view: DView<'_, f64, 2> = b.view(.., ..);
let c_view = matmul_par_view(&a_view, &b_view, None);
let c_expected = matmul_par(&a, &b, None);
assert_eq!(c_view.shape(), c_expected.shape());
for i in 0..c_view.shape().0 {
for j in 0..c_view.shape().1 {
assert!((c_view[[i, j]] - c_expected[[i, j]]).abs() < 1e-10);
}
}
}
// Same cross-check for the Complex<f64> (zgemm) path of the overwrite-into-tensor API.
#[test]
fn test_matmul_par_overwrite_view() {
use num_complex::Complex;
let a = DTensor::<Complex<f64>, 2>::from_fn([2, 2], |idx| {
Complex::new((idx[0] * 2 + idx[1]) as f64, 0.0)
});
let b = DTensor::<Complex<f64>, 2>::from_fn([2, 2], |idx| {
Complex::new((idx[0] * 2 + idx[1] + 10) as f64, 0.0)
});
let a_view: DView<'_, Complex<f64>, 2> = a.view(.., ..);
let b_view: DView<'_, Complex<f64>, 2> = b.view(.., ..);
let mut c_view = DTensor::<Complex<f64>, 2>::from_elem([2, 2], Complex::new(0.0, 0.0));
matmul_par_overwrite_view(&a_view, &b_view, &mut c_view, None);
let c_expected = matmul_par(&a, &b, None);
assert_eq!(c_view.shape(), c_expected.shape());
for i in 0..c_view.shape().0 {
for j in 0..c_view.shape().1 {
assert!((c_view[[i, j]] - c_expected[[i, j]]).norm() < 1e-10);
}
}
}
// NOTE(review): mutates the process-wide dispatcher shared by every test in this
// binary; harmless today because no test here installs a different backend.
#[test]
fn test_clear_backend() {
clear_blas_backend();
let (name, _, _) = get_backend_info();
assert_eq!(name, "Faer (Pure Rust)");
}
// [[1,2,3],[4,5,6]] · [[7,8],[9,10],[11,12]] = [[58,64],[139,154]].
#[test]
fn test_matmul_f64() {
let a_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let a = DTensor::<f64, 2>::from_fn([2, 3], |idx| a_data[idx[0] * 3 + idx[1]]);
let b = DTensor::<f64, 2>::from_fn([3, 2], |idx| b_data[idx[0] * 2 + idx[1]]);
let c = matmul_par(&a, &b, None);
assert_eq!(*c.shape(), (2, 2));
assert!((c[[0, 0]] - 58.0).abs() < 1e-10);
assert!((c[[0, 1]] - 64.0).abs() < 1e-10);
assert!((c[[1, 0]] - 139.0).abs() < 1e-10);
assert!((c[[1, 1]] - 154.0).abs() < 1e-10);
}
// [[1,2],[3,4]] · [[5,6],[7,8]] = [[19,22],[43,50]].
#[test]
fn test_matmul_par_basic() {
use mdarray::tensor;
let a: DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0]];
let b: DTensor<f64, 2> = tensor![[5.0, 6.0], [7.0, 8.0]];
let c = matmul_par(&a, &b, None);
assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
}
// 2×3 · 3×1 matrix-vector product: [1,2,3]·[7,8,9] = 50, [4,5,6]·[7,8,9] = 122.
#[test]
fn test_matmul_par_non_square() {
use mdarray::tensor;
let a: DTensor<f64, 2> = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; let b: DTensor<f64, 2> = tensor![[7.0], [8.0], [9.0]]; let c = matmul_par(&a, &b, None);
assert!((c[[0, 0]] - 50.0).abs() < 1e-10);
assert!((c[[1, 0]] - 122.0).abs() < 1e-10);
}
}