use super::ComplexMatrix;
use crate::{to_i32, Complex64, StrError, CBLAS_COL_MAJOR, CBLAS_LOWER, CBLAS_NO_TRANS, CBLAS_TRANS, CBLAS_UPPER};
extern "C" {
// Foreign declaration of the CBLAS routine `cblas_zsyrk` (complex, double precision).
// Per the BLAS reference, it performs one of the symmetric rank-k operations
//     C := alpha*A*A^T + beta*C    or    C := alpha*A^T*A + beta*C
// and only references/updates one triangle of C.
// <https://www.netlib.org/lapack/explore-html/de/d54/zsyrk_8f.html>
// NOTE(review): the i32 arguments assume the linked CBLAS uses 32-bit integers/enums
// (i.e. not an ILP64 build) -- confirm against the build configuration.
fn cblas_zsyrk(
layout: i32, // storage-order flag; the wrapper in this file passes CBLAS_COL_MAJOR
uplo: i32, // triangle selector; the wrapper passes CBLAS_UPPER or CBLAS_LOWER
trans: i32, // operation selector; the wrapper passes CBLAS_NO_TRANS or CBLAS_TRANS
n: i32, // order of the square matrix C
k: i32, // inner dimension of the product (columns of A if no-trans, rows of A if trans)
alpha: *const Complex64, // pointer to the scalar multiplying the A-product
a: *const Complex64, // pointer to the first element of A (read only)
lda: i32, // leading dimension of A
beta: *const Complex64, // pointer to the scalar multiplying C
c: *mut Complex64, // pointer to the first element of C (updated in place)
ldc: i32, // leading dimension of C
);
}
/// Performs a symmetric rank-k update of a complex matrix (wraps CBLAS `zsyrk`)
///
/// ```text
/// First case  (second_case = false):   c := alpha ⋅ a ⋅ aᵀ + beta ⋅ c
/// Second case (second_case = true):    c := alpha ⋅ aᵀ ⋅ a + beta ⋅ c
/// ```
///
/// The transpose is a plain transpose (no complex conjugation).
/// Only one triangle of `c` is read and written; the other triangle is left untouched.
///
/// # Arguments
///
/// * `c` -- the `(n, n)` square matrix updated in place
/// * `a` -- the `(n, k)` matrix in the first case, or the `(k, n)` matrix in the second case
/// * `alpha` -- scalar multiplying the product of `a` with its transpose
/// * `beta` -- scalar multiplying the incoming `c`
/// * `upper` -- if true, the upper triangle of `c` is updated; otherwise the lower triangle
/// * `second_case` -- selects `aᵀ ⋅ a` (true) instead of `a ⋅ aᵀ` (false)
///
/// # Errors
///
/// * `"[c] matrix must be square"` if `c` is not square
/// * `"[a] matrix is incompatible"` if the relevant dimension of `a` differs from `n`
pub fn complex_mat_sym_rank_op(
    c: &mut ComplexMatrix,
    a: &ComplexMatrix,
    alpha: Complex64,
    beta: Complex64,
    upper: bool,
    second_case: bool,
) -> Result<(), StrError> {
    let (m, n) = c.dims();
    if m != n {
        return Err("[c] matrix must be square");
    }
    let (row, col) = a.dims();
    let (k, trans) = if second_case {
        // c := alpha ⋅ aᵀ ⋅ a + beta ⋅ c  requires a to be (k, n)
        if col != n {
            return Err("[a] matrix is incompatible");
        }
        (row, CBLAS_TRANS)
    } else {
        // c := alpha ⋅ a ⋅ aᵀ + beta ⋅ c  requires a to be (n, k)
        if row != n {
            return Err("[a] matrix is incompatible");
        }
        (col, CBLAS_NO_TRANS)
    };
    let uplo = if upper { CBLAS_UPPER } else { CBLAS_LOWER };

    // In column-major layout the leading dimension is the number of rows. BLAS requires
    // lda >= max(1, rows of a) and ldc >= max(1, n); clamping to 1 keeps the call legal
    // for empty matrices (where BLAS performs no access to the corresponding data)
    // instead of triggering an illegal-parameter error.
    let lda = to_i32(row.max(1));
    let ldc = to_i32(n.max(1));
    let n_i32 = to_i32(n);
    let k_i32 = to_i32(k);

    // SAFETY: all pointers come from live references that outlive the call, and the
    // dimensions passed were validated above against the shapes reported by `dims()`.
    // This relies on `as_data()`/`as_mut_data()` exposing the full column-major storage
    // of each matrix, consistent with passing CBLAS_COL_MAJOR.
    unsafe {
        cblas_zsyrk(
            CBLAS_COL_MAJOR,
            uplo,
            trans,
            n_i32,
            k_i32,
            &alpha,
            a.as_data().as_ptr(),
            lda,
            &beta,
            c.as_mut_data().as_mut_ptr(),
            ldc,
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::complex_mat_sym_rank_op;
    use crate::{complex_mat_approx_eq, cpx, Complex64, ComplexMatrix};

    /// Shape mismatches must be rejected with the corresponding error message
    #[test]
    fn complex_mat_sym_rank_op_fail_on_wrong_dims() {
        let (alpha, beta) = (cpx!(2.0, 1.0), cpx!(3.0, 1.0));
        let mut square_c = ComplexMatrix::new(2, 2);
        let mut rect_c = ComplexMatrix::new(3, 2);
        let wide_a = ComplexMatrix::new(2, 3);
        let tall_a = ComplexMatrix::new(3, 2);

        // c is not square
        let res = complex_mat_sym_rank_op(&mut rect_c, &tall_a, alpha, beta, false, false);
        assert_eq!(res.err(), Some("[c] matrix must be square"));

        // first case: the number of rows of a must match n
        let res = complex_mat_sym_rank_op(&mut square_c, &tall_a, alpha, beta, false, false);
        assert_eq!(res.err(), Some("[a] matrix is incompatible"));

        // second case: the number of columns of a must match n
        let res = complex_mat_sym_rank_op(&mut square_c, &wide_a, alpha, beta, false, true);
        assert_eq!(res.err(), Some("[a] matrix is incompatible"));
    }

    /// Checks c := 3 ⋅ a ⋅ aᵀ + 1 ⋅ c on both triangles
    #[test]
    fn complex_mat_sym_rank_op_works_first_case() {
        let alpha = cpx!(3.0, 0.0);
        let beta = cpx!(1.0, 0.0);

        // (4, 6) input matrix
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [cpx!(1.0, -1.0), cpx!(2.0, 0.0), cpx!(1.0, 0.0), cpx!( 1.0, 0.0), cpx!(-1.0, 0.0), cpx!( 0.0, 0.0)],
            [cpx!(2.0,  0.0), cpx!(2.0, 0.0), cpx!(1.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 1.0)],
            [cpx!(3.0,  1.0), cpx!(1.0, 0.0), cpx!(3.0, 0.0), cpx!( 1.0, 0.0), cpx!( 2.0, 0.0), cpx!(-1.0, 0.0)],
            [cpx!(1.0,  0.0), cpx!(0.0, 0.0), cpx!(1.0, 0.0), cpx!(-1.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 1.0)],
        ]);

        // lower triangle
        #[rustfmt::skip]
        let mut lower = ComplexMatrix::from(&[
            [cpx!( 3.0, 1.0), cpx!(0.0, 0.0), cpx!(0.0, 0.0), cpx!(0.0,  0.0)],
            [cpx!(-1.0, 0.0), cpx!(3.0, 0.0), cpx!(0.0, 0.0), cpx!(0.0,  0.0)],
            [cpx!(-4.0, 0.0), cpx!(1.0, 0.0), cpx!(3.0, 0.0), cpx!(0.0,  0.0)],
            [cpx!(-1.0, 0.0), cpx!(2.0, 0.0), cpx!(0.0, 0.0), cpx!(3.0, -1.0)],
        ]);
        #[rustfmt::skip]
        let expected_lower = ComplexMatrix::from(&[
            [cpx!(24.0, -5.0), cpx!( 0.0, 0.0), cpx!( 0.0,  0.0), cpx!(0.0,  0.0)],
            [cpx!(20.0, -6.0), cpx!(27.0, 0.0), cpx!( 0.0,  0.0), cpx!(0.0,  0.0)],
            [cpx!(20.0, -6.0), cpx!(34.0, 3.0), cpx!(75.0, 18.0), cpx!(0.0,  0.0)],
            [cpx!( 2.0, -3.0), cpx!( 8.0, 0.0), cpx!(15.0,  0.0), cpx!(9.0, -1.0)],
        ]);
        complex_mat_sym_rank_op(&mut lower, &a, alpha, beta, false, false).unwrap();
        complex_mat_approx_eq(&lower, &expected_lower, 1e-15);

        // upper triangle
        #[rustfmt::skip]
        let mut upper = ComplexMatrix::from(&[
            [cpx!(3.0, 1.0), cpx!(0.0, 0.0), cpx!(-2.0, 0.0), cpx!(0.0,  0.0)],
            [cpx!(0.0, 0.0), cpx!(3.0, 0.0), cpx!( 0.0, 0.0), cpx!(2.0,  0.0)],
            [cpx!(0.0, 0.0), cpx!(0.0, 0.0), cpx!( 3.0, 0.0), cpx!(1.0,  0.0)],
            [cpx!(0.0, 0.0), cpx!(0.0, 0.0), cpx!( 0.0, 0.0), cpx!(3.0, -1.0)],
        ]);
        #[rustfmt::skip]
        let expected_upper = ComplexMatrix::from(&[
            [cpx!(24.0, -5.0), cpx!(21.0, -6.0), cpx!(22.0, -6.0), cpx!( 3.0, -3.0)],
            [cpx!( 0.0,  0.0), cpx!(27.0,  0.0), cpx!(33.0,  3.0), cpx!( 8.0,  0.0)],
            [cpx!( 0.0,  0.0), cpx!( 0.0,  0.0), cpx!(75.0, 18.0), cpx!(16.0,  0.0)],
            [cpx!( 0.0,  0.0), cpx!( 0.0,  0.0), cpx!( 0.0,  0.0), cpx!( 9.0, -1.0)],
        ]);
        complex_mat_sym_rank_op(&mut upper, &a, alpha, beta, true, false).unwrap();
        complex_mat_approx_eq(&upper, &expected_upper, 1e-15);
    }

    /// Checks c := 3 ⋅ aᵀ ⋅ a + 1 ⋅ c on both triangles
    #[test]
    fn complex_mat_sym_rank_op_works_second_case() {
        let alpha = cpx!(3.0, 0.0);
        let beta = cpx!(1.0, 0.0);

        // (4, 6) input matrix
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [1.0, 2.0, 1.0,  1.0, -1.0,  0.0],
            [2.0, 2.0, 1.0,  0.0,  0.0,  0.0],
            [3.0, 1.0, 3.0,  1.0,  2.0, -1.0],
            [1.0, 0.0, 1.0, -1.0,  0.0,  0.0],
        ]);

        // lower triangle
        #[rustfmt::skip]
        let mut lower = ComplexMatrix::from(&[
            [ 3.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [ 0.0, 3.0, 0.0, 0.0, 0.0, 0.0],
            [-3.0, 1.0, 4.0, 0.0, 0.0, 0.0],
            [ 0.0, 2.0, 1.0, 3.0, 0.0, 0.0],
            [ 0.0, 2.0, 1.0, 3.0, 4.0, 0.0],
            [ 0.0, 2.0, 1.0, 3.0, 3.0, 4.0],
        ]);
        #[rustfmt::skip]
        let expected_lower = ComplexMatrix::from(&[
            [48.0,  0.0,  0.0,  0.0,  0.0, 0.0],
            [27.0, 30.0,  0.0,  0.0,  0.0, 0.0],
            [36.0, 22.0, 40.0,  0.0,  0.0, 0.0],
            [ 9.0, 11.0, 10.0, 12.0,  0.0, 0.0],
            [15.0,  2.0, 16.0,  6.0, 19.0, 0.0],
            [-9.0, -1.0, -8.0,  0.0, -3.0, 7.0],
        ]);
        complex_mat_sym_rank_op(&mut lower, &a, alpha, beta, false, true).unwrap();
        complex_mat_approx_eq(&lower, &expected_lower, 1e-15);

        // upper triangle
        #[rustfmt::skip]
        let mut upper = ComplexMatrix::from(&[
            [3.0, 0.0, -3.0, 0.0, 0.0, 0.0],
            [0.0, 3.0,  1.0, 2.0, 2.0, 2.0],
            [0.0, 0.0,  4.0, 1.0, 1.0, 1.0],
            [0.0, 0.0,  0.0, 3.0, 3.0, 3.0],
            [0.0, 0.0,  0.0, 0.0, 4.0, 3.0],
            [0.0, 0.0,  0.0, 0.0, 0.0, 4.0],
        ]);
        #[rustfmt::skip]
        let expected_upper = ComplexMatrix::from(&[
            [48.0, 27.0, 36.0,  9.0, 15.0, -9.0],
            [ 0.0, 30.0, 22.0, 11.0,  2.0, -1.0],
            [ 0.0,  0.0, 40.0, 10.0, 16.0, -8.0],
            [ 0.0,  0.0,  0.0, 12.0,  6.0,  0.0],
            [ 0.0,  0.0,  0.0,  0.0, 19.0, -3.0],
            [ 0.0,  0.0,  0.0,  0.0,  0.0,  7.0],
        ]);
        complex_mat_sym_rank_op(&mut upper, &a, alpha, beta, true, true).unwrap();
        complex_mat_approx_eq(&upper, &expected_upper, 1e-15);
    }
}