use crate::matrix::ComplexMatrix;
use crate::vector::ComplexVector;
use crate::{to_i32, Complex64, StrError, CBLAS_COL_MAJOR, CBLAS_TRANS};
extern "C" {
    /// CBLAS complex double-precision general matrix-vector product (ZGEMV)
    ///
    /// Per the BLAS reference, this computes `y := alpha ⋅ op(A) ⋅ x + beta ⋅ y`, where
    /// `op(A)` is selected by `trans` and `A` is an `m × n` matrix stored according to
    /// `layout` with leading dimension `lda`. The scalars `alpha` and `beta` are passed
    /// by pointer; `incx` and `incy` are the strides between consecutive entries of
    /// `x` and `y`.
    fn cblas_zgemv(
        layout: i32,
        trans: i32,
        m: i32,
        n: i32,
        alpha: *const Complex64,
        a: *const Complex64,
        lda: i32,
        x: *const Complex64,
        incx: i32,
        beta: *const Complex64,
        y: *mut Complex64,
        incy: i32,
    );
}
/// Computes the vector-matrix product `v := alpha ⋅ uᵀ ⋅ a` for complex data
///
/// Component-wise: `v[j] = alpha ⋅ Σ_i u[i] ⋅ a[i][j]`. The product is delegated to
/// the CBLAS routine `cblas_zgemv`, called with `CBLAS_TRANS` and a zero `beta`,
/// so whatever `v` held before the call is overwritten rather than accumulated.
///
/// # Input
///
/// * `v` -- output vector; its dimension must equal `a.ncol()`
/// * `alpha` -- scalar multiplier applied to the product
/// * `u` -- input vector; its dimension must equal `a.nrow()`
/// * `a` -- matrix
///
/// # Errors
///
/// Returns `Err("matrix and vectors are incompatible")` when `u.dim() != a.nrow()`
/// or `v.dim() != a.ncol()`.
///
/// # Special cases
///
/// * `v.dim() == 0` -- there is nothing to compute; returns `Ok(())` immediately
/// * `u.dim() == 0` with a non-empty `v` -- `v` is filled with zeros and BLAS is not called
pub fn complex_vec_mat_mul(
    v: &mut ComplexVector,
    alpha: Complex64,
    u: &ComplexVector,
    a: &ComplexMatrix,
) -> Result<(), StrError> {
    let nrow = u.dim();
    let ncol = v.dim();
    let compatible = nrow == a.nrow() && ncol == a.ncol();
    if !compatible {
        return Err("matrix and vectors are incompatible");
    }

    let zero = Complex64::new(0.0, 0.0);
    match (nrow, ncol) {
        // empty output: nothing to write
        (_, 0) => return Ok(()),
        // empty sum: every component of the result is zero
        (0, _) => {
            v.fill(zero);
            return Ok(());
        }
        _ => {}
    }

    let nrow_i32: i32 = to_i32(nrow);
    let ncol_i32: i32 = to_i32(ncol);
    // the leading dimension passed to BLAS is the number of rows (column-major layout)
    let lda = nrow_i32;
    let (stride_u, stride_v) = (1, 1);

    // SAFETY: the dimensions of `u` and `v` were checked against the shape of `a` above
    // and both are non-zero here; all pointers come from live borrows that outlive the
    // call. This relies on `a.as_data()` holding `nrow * ncol` entries in column-major
    // order -- NOTE(review): confirm against the `ComplexMatrix` storage layout.
    unsafe {
        cblas_zgemv(
            CBLAS_COL_MAJOR,
            CBLAS_TRANS,
            nrow_i32,
            ncol_i32,
            &alpha,
            a.as_data().as_ptr(),
            lda,
            u.as_data().as_ptr(),
            stride_u,
            &zero,
            v.as_mut_data().as_mut_ptr(),
            stride_v,
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::complex_vec_mat_mul;
    use crate::{complex_vec_approx_eq, cpx, Complex64, ComplexMatrix, ComplexVector};

    /// Both halves of the shape check must reject incompatible operands
    #[test]
    fn vec_mat_mul_fails_on_wrong_dims() {
        let u = ComplexVector::new(2);
        let a_1x2 = ComplexMatrix::new(1, 2);
        let a_3x1 = ComplexMatrix::new(3, 1);
        let a_2x1 = ComplexMatrix::new(2, 1);
        let mut v = ComplexVector::new(3);
        let one = cpx!(1.0, 0.0);
        // the number of rows does not match u
        assert_eq!(
            complex_vec_mat_mul(&mut v, one, &u, &a_1x2),
            Err("matrix and vectors are incompatible")
        );
        assert_eq!(
            complex_vec_mat_mul(&mut v, one, &u, &a_3x1),
            Err("matrix and vectors are incompatible")
        );
        // the rows match u, so only the column check (against v) can reject this one
        assert_eq!(
            complex_vec_mat_mul(&mut v, one, &u, &a_2x1),
            Err("matrix and vectors are incompatible")
        );
    }

    /// Checks `v = alpha ⋅ uᵀ ⋅ a` for a real and for a purely imaginary `alpha`
    #[test]
    fn vec_mat_mul_works() {
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [ 5.0, -2.0, 0.0, 1.0],
            [10.0, -4.0, 0.0, 2.0],
            [15.0, -6.0, 0.0, 3.0],
        ]);
        let u = ComplexVector::from(&[1.0, 3.0, 8.0]);
        let mut v = ComplexVector::new(a.ncol());
        let one = cpx!(1.0, 0.0);
        complex_vec_mat_mul(&mut v, one, &u, &a).unwrap();
        let correct = &[cpx!(155.0, 0.0), cpx!(-62.0, 0.0), cpx!(0.0, 0.0), cpx!(31.0, 0.0)];
        complex_vec_approx_eq(&v, correct, 1e-15);

        // reuse v (now holding the previous result) with an imaginary alpha: the old
        // contents must be overwritten, not accumulated, and the scaling must be complex
        let two_i = cpx!(0.0, 2.0);
        complex_vec_mat_mul(&mut v, two_i, &u, &a).unwrap();
        let correct = &[cpx!(0.0, 310.0), cpx!(0.0, -124.0), cpx!(0.0, 0.0), cpx!(0.0, 62.0)];
        complex_vec_approx_eq(&v, correct, 1e-15);
    }

    /// Degenerate shapes must succeed without calling into BLAS
    #[test]
    fn vec_mat_mul_zero_works() {
        let a_0x0 = ComplexMatrix::new(0, 0);
        let a_0x1 = ComplexMatrix::new(0, 1);
        let a_1x0 = ComplexMatrix::new(1, 0);
        let u0 = ComplexVector::new(0);
        let u1 = ComplexVector::new(1);
        let mut v0 = ComplexVector::new(0);
        let mut v1 = ComplexVector::new(1);
        let one = cpx!(1.0, 0.0);
        let zero = cpx!(0.0, 0.0);
        complex_vec_mat_mul(&mut v0, one, &u0, &a_0x0).unwrap();
        assert_eq!(v0.as_data(), &[] as &[Complex64]);
        // start from a non-zero value so the assertion proves that v is actually zeroed
        v1.fill(one);
        complex_vec_mat_mul(&mut v1, one, &u0, &a_0x1).unwrap();
        assert_eq!(v1.as_data(), &[zero]);
        complex_vec_mat_mul(&mut v0, one, &u1, &a_1x0).unwrap();
        assert_eq!(v0.as_data(), &[] as &[Complex64]);
    }
}