use super::ComplexMatrix;
use crate::{to_i32, Complex64, StrError, CBLAS_COL_MAJOR, CBLAS_CONJ_TRANS, CBLAS_LOWER, CBLAS_NO_TRANS, CBLAS_UPPER};
extern "C" {
/// Hermitian rank-k update from CBLAS (double-precision complex):
///
/// ```text
/// trans = NoTrans:   c := alpha ⋅ a ⋅ aᴴ + beta ⋅ c    (a is n×k)
/// trans = ConjTrans: c := alpha ⋅ aᴴ ⋅ a + beta ⋅ c    (a is k×n)
/// ```
///
/// * `layout` -- storage order constant (this file always passes `CBLAS_COL_MAJOR`)
/// * `uplo` -- which triangle of `c` is referenced and updated (`CBLAS_UPPER`/`CBLAS_LOWER`)
/// * `trans` -- `CBLAS_NO_TRANS` or `CBLAS_CONJ_TRANS`, selecting the formula above
/// * `n`, `k` -- order of `c` and the inner dimension of the product
/// * `alpha`, `beta` -- real scalars (not complex, unlike most `z*` routines)
/// * `a`, `lda` -- pointer to `a` and its leading dimension
/// * `c`, `ldc` -- pointer to `c` (updated in place) and its leading dimension
///
/// NOTE: argument order and types are part of the C ABI; they must match the
/// CBLAS prototype exactly and must not be reordered.
fn cblas_zherk(
layout: i32,
uplo: i32,
trans: i32,
n: i32,
k: i32,
alpha: f64,
a: *const Complex64,
lda: i32,
beta: f64,
c: *mut Complex64,
ldc: i32,
);
}
/// Performs the Hermitian rank-k update of a square complex matrix (BLAS `zherk`)
///
/// ```text
/// first case (second_case = false):
///
///     c := α ⋅ a ⋅ aᴴ + β ⋅ c        with a: (n, k)
///
/// second case (second_case = true):
///
///     c := α ⋅ aᴴ ⋅ a + β ⋅ c        with a: (k, n)
/// ```
///
/// where `c` is an `(n, n)` Hermitian matrix, `aᴴ` is the conjugate transpose of `a`,
/// and `α`, `β` are real scalars.
///
/// # Input
///
/// * `c` -- (n, n) Hermitian matrix, updated in place; only the triangle selected by
///   `upper` is used (BLAS `zherk` semantics)
/// * `a` -- (n, k) matrix in the first case; (k, n) matrix in the second case
/// * `alpha` -- real scalar multiplying the product involving `a`
/// * `beta` -- real scalar multiplying `c`
/// * `upper` -- if true, use/update the upper triangle of `c`; otherwise the lower triangle
/// * `second_case` -- if true, compute with `aᴴ ⋅ a`; otherwise with `a ⋅ aᴴ`
///
/// # Errors
///
/// * `"[c] matrix must be square"` -- `c` is not square
/// * `"[a] matrix is incompatible"` -- `a` has the wrong number of rows (first case)
///   or columns (second case) for `c`
///
/// Empty inputs are handled: an empty `c` is a no-op, and `k = 0` yields `c := β ⋅ c`.
pub fn complex_mat_herm_rank_op(
    c: &mut ComplexMatrix,
    a: &ComplexMatrix,
    alpha: f64,
    beta: f64,
    upper: bool,
    second_case: bool,
) -> Result<(), StrError> {
    let (m, n) = c.dims();
    if m != n {
        return Err("[c] matrix must be square");
    }
    let (row, col) = a.dims();

    // `lda` is the number of rows of `a` (its leading dimension in column-major
    // storage, as implied by CBLAS_COL_MAJOR below); `k` is the inner dimension.
    let (lda, k, trans) = if !second_case {
        if row != n {
            return Err("[a] matrix is incompatible");
        }
        (row, col, CBLAS_NO_TRANS)
    } else {
        if col != n {
            return Err("[a] matrix is incompatible");
        }
        (row, row, CBLAS_CONJ_TRANS)
    };

    // An empty `c` has nothing to update. Calling BLAS would pass ldc = 0, which
    // violates `ldc >= max(1, n)` and triggers a parameter error (xerbla).
    if n == 0 {
        return Ok(());
    }

    let uplo = if upper { CBLAS_UPPER } else { CBLAS_LOWER };
    let n_i32 = to_i32(n);
    let k_i32 = to_i32(k);
    // BLAS requires lda >= max(1, rows of a). In the second case `a` may have zero
    // rows (k = 0); the clamp keeps that valid. BLAS then only scales `c` by beta
    // and never reads `a`.
    let lda_i32 = to_i32(lda.max(1));
    let ldc = n_i32; // n >= 1 here, so ldc >= max(1, n) holds

    // SAFETY: the dimensions passed to BLAS were validated against `a.dims()` and
    // `c.dims()` above. The pointers come from buffers owned by `a` and `c`, which
    // presumably hold exactly rows*cols column-major elements (TODO confirm in
    // ComplexMatrix). `c` is borrowed mutably, so it cannot alias `a`.
    unsafe {
        cblas_zherk(
            CBLAS_COL_MAJOR,
            uplo,
            trans,
            n_i32,
            k_i32,
            alpha,
            a.as_data().as_ptr(),
            lda_i32,
            beta,
            c.as_mut_data().as_mut_ptr(),
            ldc,
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::{complex_mat_herm_rank_op, ComplexMatrix};
    use crate::matrix::testing::check_hermitian_uplo;
    use crate::{complex_mat_approx_eq, cpx, Complex64};

    /// `c` must be square, and `a` must have `n` rows (first case) or `n` columns (second case)
    #[test]
    fn complex_mat_herm_rank_op_fail_on_wrong_dims() {
        let mut c_2x2 = ComplexMatrix::new(2, 2);
        let mut c_3x2 = ComplexMatrix::new(3, 2);
        let a_2x3 = ComplexMatrix::new(2, 3);
        let a_3x2 = ComplexMatrix::new(3, 2);
        let alpha = 2.0;
        let beta = 3.0;
        // non-square c
        assert_eq!(
            complex_mat_herm_rank_op(&mut c_3x2, &a_3x2, alpha, beta, false, false).err(),
            Some("[c] matrix must be square")
        );
        // first case: a has 3 rows but c is 2x2
        assert_eq!(
            complex_mat_herm_rank_op(&mut c_2x2, &a_3x2, alpha, beta, false, false).err(),
            Some("[a] matrix is incompatible")
        );
        // second case: a has 3 columns but c is 2x2
        assert_eq!(
            complex_mat_herm_rank_op(&mut c_2x2, &a_2x3, alpha, beta, false, true).err(),
            Some("[a] matrix is incompatible")
        );
    }

    /// First case: c := alpha ⋅ a ⋅ aᴴ + beta ⋅ c with a (4×6); only the selected
    /// triangle is updated, so the other triangle must remain zero
    #[test]
    fn complex_mat_herm_rank_op_works_first_case() {
        #[rustfmt::skip]
        let c = ComplexMatrix::from(&[
            [cpx!( 4.0, 0.0), cpx!(0.0, 1.0), cpx!(-3.0, 1.0), cpx!(0.0, 2.0)],
            [cpx!( 0.0, -1.0), cpx!(3.0, 0.0), cpx!( 1.0, 0.0), cpx!(2.0, 0.0)],
            [cpx!(-3.0, -1.0), cpx!(1.0, 0.0), cpx!( 4.0, 0.0), cpx!(1.0, -1.0)],
            [cpx!( 0.0, -2.0), cpx!(2.0, 0.0), cpx!( 1.0, 1.0), cpx!(4.0, 0.0)],
        ]);
        #[rustfmt::skip]
        let mut c_lower = ComplexMatrix::from(&[
            [cpx!( 4.0, 0.0), cpx!(0.0, 0.0), cpx!( 0.0, 0.0), cpx!(0.0, 0.0)],
            [cpx!( 0.0, -1.0), cpx!(3.0, 0.0), cpx!( 0.0, 0.0), cpx!(0.0, 0.0)],
            [cpx!(-3.0, -1.0), cpx!(1.0, 0.0), cpx!( 4.0, 0.0), cpx!(0.0, 0.0)],
            [cpx!( 0.0, -2.0), cpx!(2.0, 0.0), cpx!( 1.0, 1.0), cpx!(4.0, 0.0)],
        ]);
        #[rustfmt::skip]
        let mut c_upper = ComplexMatrix::from(&[
            [cpx!( 4.0, 0.0), cpx!(0.0, 1.0), cpx!(-3.0, 1.0), cpx!(0.0, 2.0)],
            [cpx!( 0.0, 0.0), cpx!(3.0, 0.0), cpx!( 1.0, 0.0), cpx!(2.0, 0.0)],
            [cpx!( 0.0, 0.0), cpx!(0.0, 0.0), cpx!( 4.0, 0.0), cpx!(1.0, -1.0)],
            [cpx!( 0.0, 0.0), cpx!(0.0, 0.0), cpx!( 0.0, 0.0), cpx!(4.0, 0.0)],
        ]);
        check_hermitian_uplo(&c, &c_lower, &c_upper);
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [cpx!( 1.0, -1.0), cpx!( 2.0, 0.0), cpx!( 1.0, 0.0), cpx!( 1.0, 0.0), cpx!(-1.0, 0.0), cpx!( 0.0, 0.0)],
            [cpx!( 2.0, 0.0), cpx!( 2.0, 0.0), cpx!( 1.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 1.0)],
            [cpx!( 3.0, 1.0), cpx!( 1.0, 0.0), cpx!( 3.0, 0.0), cpx!( 1.0, 0.0), cpx!( 2.0, 0.0), cpx!(-1.0, 0.0)],
            [cpx!( 1.0, 0.0), cpx!( 0.0, 0.0), cpx!( 1.0, 0.0), cpx!(-1.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 1.0)],
        ]);
        let (alpha, beta) = (3.0, -2.0);
        #[rustfmt::skip]
        let c_ref_full = ComplexMatrix::from(&[
            [cpx!(19.0, 0.0), cpx!(21.0, -8.0), cpx!(24.0, -14.0), cpx!( 3.0, -7.0)],
            [cpx!(21.0, 8.0), cpx!(24.0, 0.0), cpx!(31.0, -9.0), cpx!( 8.0, 0.0)],
            [cpx!(24.0, 14.0), cpx!(31.0, 9.0), cpx!(70.0, 0.0), cpx!(13.0, 8.0)],
            [cpx!( 3.0, 7.0), cpx!( 8.0, 0.0), cpx!(13.0, -8.0), cpx!( 4.0, 0.0)],
        ]);
        #[rustfmt::skip]
        let c_ref_lower = ComplexMatrix::from(&[
            [cpx!(19.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0)],
            [cpx!(21.0, 8.0), cpx!(24.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0)],
            [cpx!(24.0, 14.0), cpx!(31.0, 9.0), cpx!(70.0, 0.0), cpx!( 0.0, 0.0)],
            [cpx!( 3.0, 7.0), cpx!( 8.0, 0.0), cpx!(13.0, -8.0), cpx!( 4.0, 0.0)],
        ]);
        #[rustfmt::skip]
        let c_ref_upper = ComplexMatrix::from(&[
            [cpx!(19.0, 0.0), cpx!(21.0, -8.0), cpx!(24.0, -14.0), cpx!( 3.0, -7.0)],
            [cpx!( 0.0, 0.0), cpx!(24.0, 0.0), cpx!(31.0, -9.0), cpx!( 8.0, 0.0)],
            [cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!(70.0, 0.0), cpx!(13.0, 8.0)],
            [cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!( 4.0, 0.0)],
        ]);
        check_hermitian_uplo(&c_ref_full, &c_ref_lower, &c_ref_upper);
        complex_mat_herm_rank_op(&mut c_lower, &a, alpha, beta, false, false).unwrap();
        complex_mat_approx_eq(&c_lower, &c_ref_lower, 1e-15);
        complex_mat_herm_rank_op(&mut c_upper, &a, alpha, beta, true, false).unwrap();
        complex_mat_approx_eq(&c_upper, &c_ref_upper, 1e-15);
    }

    /// Second case: c := alpha ⋅ aᴴ ⋅ a + beta ⋅ c with a (2×4); only the selected
    /// triangle is updated, so the other triangle must remain zero
    #[test]
    fn complex_mat_herm_rank_op_works_second_case() {
        #[rustfmt::skip]
        let c = ComplexMatrix::from(&[
            [cpx!( 4.0, 0.0), cpx!(0.0, 1.0), cpx!(-3.0, 1.0), cpx!(0.0, 2.0)],
            [cpx!( 0.0, -1.0), cpx!(3.0, 0.0), cpx!( 1.0, 0.0), cpx!(2.0, 0.0)],
            [cpx!(-3.0, -1.0), cpx!(1.0, 0.0), cpx!( 4.0, 0.0), cpx!(1.0, -1.0)],
            [cpx!( 0.0, -2.0), cpx!(2.0, 0.0), cpx!( 1.0, 1.0), cpx!(4.0, 0.0)],
        ]);
        #[rustfmt::skip]
        let mut c_lower = ComplexMatrix::from(&[
            [cpx!( 4.0, 0.0), cpx!(0.0, 0.0), cpx!( 0.0, 0.0), cpx!(0.0, 0.0)],
            [cpx!( 0.0, -1.0), cpx!(3.0, 0.0), cpx!( 0.0, 0.0), cpx!(0.0, 0.0)],
            [cpx!(-3.0, -1.0), cpx!(1.0, 0.0), cpx!( 4.0, 0.0), cpx!(0.0, 0.0)],
            [cpx!( 0.0, -2.0), cpx!(2.0, 0.0), cpx!( 1.0, 1.0), cpx!(4.0, 0.0)],
        ]);
        #[rustfmt::skip]
        let mut c_upper = ComplexMatrix::from(&[
            [cpx!( 4.0, 0.0), cpx!(0.0, 1.0), cpx!(-3.0, 1.0), cpx!(0.0, 2.0)],
            [cpx!( 0.0, 0.0), cpx!(3.0, 0.0), cpx!( 1.0, 0.0), cpx!(2.0, 0.0)],
            [cpx!( 0.0, 0.0), cpx!(0.0, 0.0), cpx!( 4.0, 0.0), cpx!(1.0, -1.0)],
            [cpx!( 0.0, 0.0), cpx!(0.0, 0.0), cpx!( 0.0, 0.0), cpx!(4.0, 0.0)],
        ]);
        check_hermitian_uplo(&c, &c_lower, &c_upper);
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [cpx!( 1.0, -1.0), cpx!( 2.0, 0.0), cpx!( 1.0, 0.0), cpx!( 1.0, 0.0)],
            [cpx!( 3.0, 1.0), cpx!( 1.0, 0.0), cpx!( 3.0, 0.0), cpx!( 1.0, 2.0)],
        ]);
        let (alpha, beta) = (3.0, -2.0);
        #[rustfmt::skip]
        let c_ref_full = ComplexMatrix::from(&[
            [cpx!(28.0, 0.0), cpx!(15.0, 1.0), cpx!(36.0, -8.0), cpx!(18.0, 14.0)],
            [cpx!(15.0, -1.0), cpx!( 9.0, 0.0), cpx!(13.0, 0.0), cpx!( 5.0, 6.0)],
            [cpx!(36.0, 8.0), cpx!(13.0, 0.0), cpx!(22.0, 0.0), cpx!(10.0, 20.0)],
            [cpx!(18.0, -14.0), cpx!( 5.0, -6.0), cpx!(10.0, -20.0), cpx!(10.0, 0.0)],
        ]);
        #[rustfmt::skip]
        let c_ref_lower = ComplexMatrix::from(&[
            [cpx!(28.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0)],
            [cpx!(15.0, -1.0), cpx!( 9.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0)],
            [cpx!(36.0, 8.0), cpx!(13.0, 0.0), cpx!(22.0, 0.0), cpx!( 0.0, 0.0)],
            [cpx!(18.0, -14.0), cpx!( 5.0, -6.0), cpx!(10.0, -20.0), cpx!(10.0, 0.0)],
        ]);
        #[rustfmt::skip]
        let c_ref_upper = ComplexMatrix::from(&[
            [cpx!(28.0, 0.0), cpx!(15.0, 1.0), cpx!(36.0, -8.0), cpx!(18.0, 14.0)],
            [cpx!( 0.0, 0.0), cpx!( 9.0, 0.0), cpx!(13.0, 0.0), cpx!( 5.0, 6.0)],
            [cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!(22.0, 0.0), cpx!(10.0, 20.0)],
            [cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!(10.0, 0.0)],
        ]);
        check_hermitian_uplo(&c_ref_full, &c_ref_lower, &c_ref_upper);
        complex_mat_herm_rank_op(&mut c_lower, &a, alpha, beta, false, true).unwrap();
        complex_mat_approx_eq(&c_lower, &c_ref_lower, 1e-15);
        complex_mat_herm_rank_op(&mut c_upper, &a, alpha, beta, true, true).unwrap();
        complex_mat_approx_eq(&c_upper, &c_ref_upper, 1e-15);
    }
}