use super::ComplexMatrix;
use crate::{to_i32, Complex64, StrError, CBLAS_COL_MAJOR, CBLAS_NO_TRANS, CBLAS_TRANS};
// Foreign (C ABI) declarations resolved at link time against a CBLAS implementation.
extern "C" {
/// Binding to the CBLAS `zgemm` routine (general matrix-matrix multiplication for
/// double-precision complex numbers). Per the CBLAS interface it computes
/// `c := alpha ⋅ op(a) ⋅ op(b) + beta ⋅ c`, where `op(x)` is selected by `transa`/`transb`.
///
/// All pointers are raw: the caller must guarantee that they reference buffers
/// large enough for the given dimensions and leading dimensions.
fn cblas_zgemm(
// storage layout flag (this file passes CBLAS_COL_MAJOR)
layout: i32,
// operation applied to `a` (this file passes CBLAS_TRANS)
transa: i32,
// operation applied to `b` (this file passes CBLAS_NO_TRANS)
transb: i32,
// number of rows of op(a) and of c
m: i32,
// number of columns of op(b) and of c
n: i32,
// number of columns of op(a) and rows of op(b)
k: i32,
// pointer to the scalar multiplying the product
alpha: *const Complex64,
// pointer to the first element of `a` (read only)
a: *const Complex64,
// leading dimension of `a`
lda: i32,
// pointer to the first element of `b` (read only)
b: *const Complex64,
// leading dimension of `b`
ldb: i32,
// pointer to the scalar multiplying the incoming `c`
beta: *const Complex64,
// pointer to the first element of `c` (read and overwritten)
c: *mut Complex64,
// leading dimension of `c`
ldc: i32,
);
}
/// (zgemm) Performs the transpose(matrix)-matrix multiplication resulting in a matrix
///
/// ```text
///   c  := α ⋅  aᵀ  ⋅   b   +   β ⋅ c
/// (m,n)      (m,k)   (k,n)       (m,n)
/// ```
///
/// Note that `a` itself is a `(k,m)` matrix; only its transpose enters the product.
///
/// # Input
///
/// * `c` -- the `(m,n)` matrix that is scaled by `beta` and receives the result
/// * `alpha` -- scalar multiplying the product `aᵀ ⋅ b`
/// * `a` -- the `(k,m)` matrix whose transpose is used
/// * `b` -- the `(k,n)` matrix
/// * `beta` -- scalar multiplying the incoming values of `c`
///
/// # Errors
///
/// Returns `Err("matrices are incompatible")` if `a.ncol() != m`, `b.nrow() != k`,
/// or `b.ncol() != n`, where `(m,n)` are the dimensions of `c` and `k = a.nrow()`.
///
/// # Special cases
///
/// * `m == 0` or `n == 0` -- `c` has no entries; nothing is done.
/// * `k == 0` -- the product `aᵀ ⋅ b` is an empty sum (zero), hence `c := β ⋅ c`.
///   As in reference BLAS, `beta == 0` sets `c` to exactly zero (clearing any
///   NaN/Inf entries) and `beta == 1` leaves `c` untouched.
pub fn complex_mat_t_mat_mul(
    c: &mut ComplexMatrix,
    alpha: Complex64,
    a: &ComplexMatrix,
    b: &ComplexMatrix,
    beta: Complex64,
) -> Result<(), StrError> {
    let (m, n) = c.dims();
    let k = a.nrow();
    if a.ncol() != m || b.nrow() != k || b.ncol() != n {
        return Err("matrices are incompatible");
    }

    // c is empty: there is nothing to compute or to scale
    if m == 0 || n == 0 {
        return Ok(());
    }

    // The inner dimension is empty, thus aᵀ⋅b vanishes and only β⋅c remains.
    // BLAS is not called here because the leading dimensions (lda = ldb = k)
    // would be zero, which is not a valid argument.
    if k == 0 {
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        if beta == zero {
            // assign instead of multiplying so that NaN/Inf entries are cleared
            c.fill(zero);
        } else if beta != one {
            for value in c.as_mut_data().iter_mut() {
                *value *= beta;
            }
        }
        return Ok(());
    }

    // NOTE(review): to_i32 is presumed to reject (not truncate) values that do
    // not fit in an i32 -- confirm against its definition.
    let m_i32: i32 = to_i32(m);
    let n_i32: i32 = to_i32(n);
    let k_i32: i32 = to_i32(k);

    // With column-major storage, the leading dimension is the number of rows
    // of the matrix as stored: k for both a (k,m) and b (k,n), and m for c (m,n).
    let lda = k_i32;
    let ldb = k_i32;
    let ldc = m_i32;

    // SAFETY: the pointers come from live borrows of `a`, `b`, and `c` that
    // outlive the call, and `c` is exclusively borrowed so it cannot alias the
    // inputs. The dimension checks above ensure that the sizes handed to BLAS
    // agree with the dimensions reported by each matrix. This assumes that each
    // matrix stores nrow*ncol contiguous entries in column-major order
    // (TODO confirm against the ComplexMatrix definition).
    unsafe {
        cblas_zgemm(
            CBLAS_COL_MAJOR,
            CBLAS_TRANS,
            CBLAS_NO_TRANS,
            m_i32,
            n_i32,
            k_i32,
            &alpha,
            a.as_data().as_ptr(),
            lda,
            b.as_data().as_ptr(),
            ldb,
            &beta,
            c.as_mut_data().as_mut_ptr(),
            ldc,
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::complex_mat_t_mat_mul;
    use crate::{complex_mat_approx_eq, cpx, Complex64, ComplexMatrix};

    /// Every dimension check must reject incompatible operands
    #[test]
    fn complex_mat_t_mat_mul_fails_on_wrong_dims() {
        let a_2x1 = ComplexMatrix::new(2, 1);
        let a_1x2 = ComplexMatrix::new(1, 2);
        let b_2x1 = ComplexMatrix::new(2, 1);
        let b_1x3 = ComplexMatrix::new(1, 3);
        let mut c_2x2 = ComplexMatrix::new(2, 2);
        let alpha = cpx!(1.0, 0.0);
        let beta = cpx!(0.0, 0.0);
        // a.ncol() != m
        assert_eq!(
            complex_mat_t_mat_mul(&mut c_2x2, alpha, &a_2x1, &b_2x1, beta),
            Err("matrices are incompatible")
        );
        // b.nrow() != k
        assert_eq!(
            complex_mat_t_mat_mul(&mut c_2x2, alpha, &a_1x2, &b_2x1, beta),
            Err("matrices are incompatible")
        );
        // a.ncol() != m (b is also incompatible, but the check on a fails first)
        assert_eq!(
            complex_mat_t_mat_mul(&mut c_2x2, alpha, &a_2x1, &b_1x3, beta),
            Err("matrices are incompatible")
        );
        // b.ncol() != n, with a fully compatible so that only the last check fails
        assert_eq!(
            complex_mat_t_mat_mul(&mut c_2x2, alpha, &a_1x2, &b_1x3, beta),
            Err("matrices are incompatible")
        );
    }

    /// Empty operands must be accepted without calling into BLAS
    #[test]
    fn complex_mat_t_mat_mul_0x0_works() {
        // all matrices are empty: nothing to do
        let a = ComplexMatrix::new(0, 0);
        let b = ComplexMatrix::new(0, 0);
        let mut c = ComplexMatrix::new(0, 0);
        let alpha = cpx!(2.0, 0.0);
        let beta = cpx!(0.0, 0.0);
        complex_mat_t_mat_mul(&mut c, alpha, &a, &b, beta).unwrap();

        // empty inner dimension (k = 0) with beta = 0: c must be zeroed
        let a = ComplexMatrix::new(0, 1);
        let b = ComplexMatrix::new(0, 1);
        let mut c = ComplexMatrix::from(&[[cpx!(123.0, 456.0)]]);
        complex_mat_t_mat_mul(&mut c, alpha, &a, &b, beta).unwrap();
        let correct = &[[cpx!(0.0, 0.0)]];
        complex_mat_approx_eq(&c, correct, 1e-15);
    }

    /// c := 2 ⋅ aᵀ ⋅ b with beta = 0 (the initial c does not contribute)
    #[test]
    fn complex_mat_t_mat_mul_works_1() {
        let a = ComplexMatrix::from(&[
            [1.0, 0.5],
            [2.0, 0.75],
            [3.0, 1.5],
        ]);
        let b = ComplexMatrix::from(&[
            [0.1, 0.5, 0.5, 0.75],
            [0.2, 2.0, 2.0, 2.00],
            [0.3, 0.5, 0.5, 0.50],
        ]);
        let mut c = ComplexMatrix::new(2, 4);
        let alpha = cpx!(2.0, 0.0);
        let beta = cpx!(0.0, 0.0);
        complex_mat_t_mat_mul(&mut c, alpha, &a, &b, beta).unwrap();
        #[rustfmt::skip]
        let correct = &[
            [cpx!(2.80, 0.0), cpx!(12.0, 0.0), cpx!(12.0, 0.0), cpx!(12.50, 0.0)],
            [cpx!(1.30, 0.0), cpx!( 5.0, 0.0), cpx!( 5.0, 0.0), cpx!( 5.25, 0.0)],
        ];
        complex_mat_approx_eq(&c, correct, 1e-15);
    }

    /// c := 2 ⋅ aᵀ ⋅ b + 10 ⋅ c with every entry of c initially equal to 100
    #[test]
    fn complex_mat_t_mat_mul_works_2() {
        let a = ComplexMatrix::from(&[
            [1.0, 0.5],
            [2.0, 0.75],
            [3.0, 1.5],
        ]);
        let b = ComplexMatrix::from(&[
            [0.1, 0.5, 0.5, 0.75],
            [0.2, 2.0, 2.0, 2.00],
            [0.3, 0.5, 0.5, 0.50],
        ]);
        let mut c = ComplexMatrix::filled(2, 4, cpx!(100.0, 0.0));
        let alpha = cpx!(2.0, 0.0);
        let beta = cpx!(10.0, 0.0);
        complex_mat_t_mat_mul(&mut c, alpha, &a, &b, beta).unwrap();
        #[rustfmt::skip]
        let correct = &[
            [cpx!(1002.80, 0.0), cpx!(1012.0, 0.0), cpx!(1012.0, 0.0), cpx!(1012.50, 0.0)],
            [cpx!(1001.30, 0.0), cpx!(1005.0, 0.0), cpx!(1005.0, 0.0), cpx!(1005.25, 0.0)],
        ];
        complex_mat_approx_eq(&c, correct, 1e-15);
    }
}