use super::ComplexMatrix;
use crate::{array_plus_opx_complex, Complex64, StrError};
/// Performs the complex matrix addition `c := α⋅a + β⋅b`.
///
/// The previous contents of `c` are overwritten. All three matrices must
/// have exactly the same number of rows and columns.
///
/// # Errors
///
/// Returns `Err("matrices are incompatible")` if `a` or `b` does not have the
/// same dimensions as `c`. Any error reported by `array_plus_opx_complex` is
/// propagated unchanged.
pub fn complex_mat_add(
    c: &mut ComplexMatrix,
    alpha: Complex64,
    a: &ComplexMatrix,
    beta: Complex64,
    b: &ComplexMatrix,
) -> Result<(), StrError> {
    let (m, n) = c.dims();
    // an operand is acceptable only when its shape equals the output shape
    let same_shape_as_c = |mat: &ComplexMatrix| mat.nrow() == m && mat.ncol() == n;
    if !(same_shape_as_c(a) && same_shape_as_c(b)) {
        return Err("matrices are incompatible");
    }
    // the matrices share one shape, so their flat data buffers line up element-wise
    array_plus_opx_complex(c.as_mut_data(), alpha, a.as_data(), beta, b.as_data())
}
#[cfg(test)]
mod tests {
    use super::{complex_mat_add, ComplexMatrix};
    use crate::{complex_mat_approx_eq, cpx, Complex64};

    /// Every dimension mismatch among `c`, `a`, and `b` must be rejected,
    /// including the case where `a` and `b` agree but `c` has another shape.
    #[test]
    fn complex_mat_add_fails_on_wrong_dims() {
        let a_2x2 = ComplexMatrix::new(2, 2);
        let a_2x3 = ComplexMatrix::new(2, 3);
        let a_3x2 = ComplexMatrix::new(3, 2);
        let b_2x2 = ComplexMatrix::new(2, 2);
        let b_2x3 = ComplexMatrix::new(2, 3);
        let b_3x2 = ComplexMatrix::new(3, 2);
        let mut c_2x2 = ComplexMatrix::new(2, 2);
        let mut c_2x3 = ComplexMatrix::new(2, 3);
        let mut c_3x2 = ComplexMatrix::new(3, 2);
        let alpha = cpx!(1.0, 0.0);
        let beta = cpx!(1.0, 0.0);
        // `a` has the wrong shape
        assert_eq!(
            complex_mat_add(&mut c_2x2, alpha, &a_2x3, beta, &b_2x2),
            Err("matrices are incompatible")
        );
        assert_eq!(
            complex_mat_add(&mut c_2x2, alpha, &a_3x2, beta, &b_2x2),
            Err("matrices are incompatible")
        );
        // `b` has the wrong shape
        assert_eq!(
            complex_mat_add(&mut c_2x2, alpha, &a_2x2, beta, &b_2x3),
            Err("matrices are incompatible")
        );
        assert_eq!(
            complex_mat_add(&mut c_2x2, alpha, &a_2x2, beta, &b_3x2),
            Err("matrices are incompatible")
        );
        // only the output `c` has the wrong shape
        assert_eq!(
            complex_mat_add(&mut c_2x3, alpha, &a_2x2, beta, &b_2x2),
            Err("matrices are incompatible")
        );
        assert_eq!(
            complex_mat_add(&mut c_3x2, alpha, &a_2x2, beta, &b_2x2),
            Err("matrices are incompatible")
        );
    }

    /// Computes `c := 1⋅a − 4⋅b` on a 3x4 matrix; `c` starts filled with noise
    /// to show that its previous contents are overwritten.
    #[test]
    fn complex_mat_add_works() {
        const NOISE: Complex64 = cpx!(1234.567, 3456.789);
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [1.0, 2.0, 3.0, 4.0],
            [1.0, 2.0, 3.0, 4.0],
            [1.0, 2.0, 3.0, 4.0],
        ]);
        #[rustfmt::skip]
        let b = ComplexMatrix::from(&[
            [0.5, 1.0, 1.5, 2.0],
            [0.5, 1.0, 1.5, 2.0],
            [0.5, 1.0, 1.5, 2.0],
        ]);
        let mut c = ComplexMatrix::from(&[
            [NOISE, NOISE, NOISE, NOISE],
            [NOISE, NOISE, NOISE, NOISE],
            [NOISE, NOISE, NOISE, NOISE],
        ]);
        let alpha = cpx!(1.0, 0.0);
        let beta = cpx!(-4.0, 0.0);
        complex_mat_add(&mut c, alpha, &a, beta, &b).unwrap();
        #[rustfmt::skip]
        let correct = &[
            [cpx!(-1.0, 0.0), cpx!(-2.0, 0.0), cpx!(-3.0, 0.0), cpx!(-4.0, 0.0)],
            [cpx!(-1.0, 0.0), cpx!(-2.0, 0.0), cpx!(-3.0, 0.0), cpx!(-4.0, 0.0)],
            [cpx!(-1.0, 0.0), cpx!(-2.0, 0.0), cpx!(-3.0, 0.0), cpx!(-4.0, 0.0)],
        ];
        complex_mat_approx_eq(&c, correct, 1e-15);
    }

    /// Same check as `complex_mat_add_works`, on a larger 5x5 square matrix.
    #[test]
    fn complex_mat_add_5x5_works() {
        const NOISE: Complex64 = cpx!(1234.567, 3456.789);
        let a = ComplexMatrix::from(&[
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [1.0, 2.0, 3.0, 4.0, 5.0],
        ]);
        let b = ComplexMatrix::from(&[
            [0.5, 1.0, 1.5, 2.0, 2.5],
            [0.5, 1.0, 1.5, 2.0, 2.5],
            [0.5, 1.0, 1.5, 2.0, 2.5],
            [0.5, 1.0, 1.5, 2.0, 2.5],
            [0.5, 1.0, 1.5, 2.0, 2.5],
        ]);
        let mut c = ComplexMatrix::from(&[
            [NOISE, NOISE, NOISE, NOISE, NOISE],
            [NOISE, NOISE, NOISE, NOISE, NOISE],
            [NOISE, NOISE, NOISE, NOISE, NOISE],
            [NOISE, NOISE, NOISE, NOISE, NOISE],
            [NOISE, NOISE, NOISE, NOISE, NOISE],
        ]);
        let alpha = cpx!(1.0, 0.0);
        let beta = cpx!(-4.0, 0.0);
        complex_mat_add(&mut c, alpha, &a, beta, &b).unwrap();
        #[rustfmt::skip]
        let correct = &[
            [cpx!(-1.0,0.0), cpx!(-2.0,0.0), cpx!(-3.0,0.0), cpx!(-4.0,0.0), cpx!(-5.0,0.0)],
            [cpx!(-1.0,0.0), cpx!(-2.0,0.0), cpx!(-3.0,0.0), cpx!(-4.0,0.0), cpx!(-5.0,0.0)],
            [cpx!(-1.0,0.0), cpx!(-2.0,0.0), cpx!(-3.0,0.0), cpx!(-4.0,0.0), cpx!(-5.0,0.0)],
            [cpx!(-1.0,0.0), cpx!(-2.0,0.0), cpx!(-3.0,0.0), cpx!(-4.0,0.0), cpx!(-5.0,0.0)],
            [cpx!(-1.0,0.0), cpx!(-2.0,0.0), cpx!(-3.0,0.0), cpx!(-4.0,0.0), cpx!(-5.0,0.0)],
        ];
        complex_mat_approx_eq(&c, correct, 1e-15);
    }

    /// Uses genuinely complex coefficients and entries so that a mix-up of
    /// real and imaginary parts would be detected. With `α = i` and
    /// `β = 2 + i`, every product and sum below is exact in floating point.
    #[test]
    fn complex_mat_add_complex_coefficients_works() {
        const NOISE: Complex64 = cpx!(1234.567, 3456.789);
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [cpx!(1.0, 1.0), cpx!(2.0, -1.0)],
            [cpx!(0.0, 3.0), cpx!(-1.0, 0.0)],
        ]);
        #[rustfmt::skip]
        let b = ComplexMatrix::from(&[
            [cpx!(1.0, 0.0), cpx!(0.0, 1.0)],
            [cpx!(2.0, 2.0), cpx!(1.0, -1.0)],
        ]);
        let mut c = ComplexMatrix::from(&[[NOISE, NOISE], [NOISE, NOISE]]);
        let alpha = cpx!(0.0, 1.0);
        let beta = cpx!(2.0, 1.0);
        complex_mat_add(&mut c, alpha, &a, beta, &b).unwrap();
        // i⋅a          = [[-1+1i,  1+2i], [-3+0i, 0-1i]]
        // (2+i)⋅b      = [[ 2+1i, -1+2i], [ 2+6i, 3-1i]]
        // sum (c)      = [[ 1+2i,  0+4i], [-1+6i, 3-2i]]
        #[rustfmt::skip]
        let correct = &[
            [cpx!(1.0, 2.0), cpx!(0.0, 4.0)],
            [cpx!(-1.0, 6.0), cpx!(3.0, -2.0)],
        ];
        complex_mat_approx_eq(&c, correct, 1e-15);
    }

    /// Empty (0x0) matrices are accepted and leave an empty output.
    #[test]
    fn complex_mat_add_skip() {
        let a = ComplexMatrix::new(0, 0);
        let b = ComplexMatrix::new(0, 0);
        let mut c = ComplexMatrix::new(0, 0);
        let alpha = cpx!(1.0, 0.0);
        let beta = cpx!(1.0, 0.0);
        complex_mat_add(&mut c, alpha, &a, beta, &b).unwrap();
        assert_eq!(c.as_data().len(), 0);
    }
}