// russell_lab/matrix/complex_mat_sym_rank_op.rs

1use super::ComplexMatrix;
2use crate::{to_i32, Complex64, StrError, CBLAS_COL_MAJOR, CBLAS_LOWER, CBLAS_NO_TRANS, CBLAS_TRANS, CBLAS_UPPER};
3
4extern "C" {
5    // Performs one of the symmetric rank k operations
6    // <https://www.netlib.org/lapack/explore-html/de/d54/zsyrk_8f.html>
7    fn cblas_zsyrk(
8        layout: i32,
9        uplo: i32,
10        trans: i32,
11        n: i32,
12        k: i32,
13        alpha: *const Complex64,
14        a: *const Complex64,
15        lda: i32,
16        beta: *const Complex64,
17        c: *mut Complex64,
18        ldc: i32,
19    );
20}
21
22/// (zsyrk) Performs a symmetric rank k operation
23///
24/// Performs one of the symmetric rank k operations:
25///
26/// ```text
27/// First case:
28///
29///   c   := α ⋅ a   ⋅  aᵀ + β ⋅ c
30/// (n,n)      (n,k)  (k,n)    (n,n)
31/// ```
32///
33/// or
34///
35/// ```text
36/// Second case:
37///
38///   c   := α ⋅  aᵀ  ⋅  a + β ⋅ c
39/// (n,n)       (n,k)  (k,n)   (n,n)
40/// ```
41///
42/// where `c = cᵀ`
43///
44/// See also: <https://www.netlib.org/lapack/explore-html/de/d54/zsyrk_8f.html>
45///
46/// # Input
47///
48/// * `c` -- the (n,n) **symmetric** matrix (will be modified)
49/// * `a` -- the (n,k) matrix on the first case or (k,n) on the second case
50/// * `alpha` -- the α coefficient
51/// * `beta` -- the β coefficient
52/// * `upper` -- whether the upper triangle of `a` must be considered instead of the lower triangle
53/// * `second_case` -- indicates the second case illustrated above
54///
55/// # Examples
56///
57/// ```
58/// use russell_lab::*;
59///
60/// fn main() -> Result<(), StrError> {
61///     //  -1   2   0,
62///     //   2   1   2,
63///     //   0   2   1,
64///     let ___ = 0.0;
65///     #[rustfmt::skip]
66///     let mut c_lower = ComplexMatrix::from(&[
67///         [-1.0, ___, ___],
68///         [ 2.0, 1.0, ___],
69///         [ 0.0, 2.0, 1.0],
70///     ]);
71///
72///     #[rustfmt::skip]
73///     let a = ComplexMatrix::from(&[
74///         [ 1.0,  2.0, -1.0],
75///         [-1.0,  2.0,  0.0],
76///     ]);
77///
78///     let (alpha, beta) = (cpx!(-1.0, 1.0), cpx!(2.0, -1.0));
79///
80///     // c := (-1+1i) aᵀ⋅a + (2-1i) c
81///     complex_mat_sym_rank_op(&mut c_lower, &a, alpha, beta, false, true).unwrap();
82///
83///     let ________________ = cpx!(0.0, 0.0);
84///     #[rustfmt::skip]
85///     let c_ref = ComplexMatrix::from(&[
86///         [cpx!(-4.0,  3.0), ________________, ________________],
87///         [cpx!( 4.0, -2.0), cpx!(-6.0,  7.0), ________________],
88///         [cpx!( 1.0, -1.0), cpx!( 6.0, -4.0), cpx!( 1.0,  0.0)],
89///     ]);
90///     complex_mat_approx_eq(&c_lower, &c_ref, 1e-15);
91///     Ok(())
92/// }
93/// ```
94pub fn complex_mat_sym_rank_op(
95    c: &mut ComplexMatrix,
96    a: &ComplexMatrix,
97    alpha: Complex64,
98    beta: Complex64,
99    upper: bool,
100    second_case: bool,
101) -> Result<(), StrError> {
102    let (m, n) = c.dims();
103    if m != n {
104        return Err("[c] matrix must be square");
105    }
106    let (row, col) = a.dims();
107    let (lda, k, trans) = if !second_case {
108        //   c   := α ⋅ a   ⋅  aᵀ + β ⋅ c
109        // (n,n)      (n,k)  (k,n)    (n,n)
110        if row != n {
111            return Err("[a] matrix is incompatible");
112        }
113        (row, col, CBLAS_NO_TRANS)
114    } else {
115        //   c   := α ⋅  aᵀ  ⋅  a + β ⋅ c
116        // (n,n)       (n,k)  (k,n)   (n,n)
117        if col != n {
118            return Err("[a] matrix is incompatible");
119        }
120        (row, row, CBLAS_TRANS)
121    };
122    let uplo = if upper { CBLAS_UPPER } else { CBLAS_LOWER };
123    let n_i32 = to_i32(n);
124    let k_i32 = to_i32(k);
125    let ldc = n_i32;
126    unsafe {
127        cblas_zsyrk(
128            CBLAS_COL_MAJOR,
129            uplo,
130            trans,
131            n_i32,
132            k_i32,
133            &alpha,
134            a.as_data().as_ptr(),
135            to_i32(lda),
136            &beta,
137            c.as_mut_data().as_mut_ptr(),
138            ldc,
139        );
140    }
141    Ok(())
142}
143
144////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
145
#[cfg(test)]
mod tests {
    use super::complex_mat_sym_rank_op;
    use crate::{complex_mat_approx_eq, cpx, Complex64, ComplexMatrix};

    #[test]
    fn complex_mat_sym_rank_op_fail_on_wrong_dims() {
        let alpha = cpx!(2.0, 1.0);
        let beta = cpx!(3.0, 1.0);
        let mut square = ComplexMatrix::new(2, 2);
        let mut rectangular = ComplexMatrix::new(3, 2);
        let wide = ComplexMatrix::new(2, 3);
        let tall = ComplexMatrix::new(3, 2);

        // c itself must be square
        let outcome = complex_mat_sym_rank_op(&mut rectangular, &tall, alpha, beta, false, false);
        assert_eq!(outcome.err(), Some("[c] matrix must be square"));

        // first case: a must have as many rows as c
        let outcome = complex_mat_sym_rank_op(&mut square, &tall, alpha, beta, false, false);
        assert_eq!(outcome.err(), Some("[a] matrix is incompatible"));

        // second case: a must have as many columns as c
        let outcome = complex_mat_sym_rank_op(&mut square, &wide, alpha, beta, false, true);
        assert_eq!(outcome.err(), Some("[a] matrix is incompatible"));
    }

    #[test]
    fn complex_mat_sym_rank_op_works_first_case() {
        // Lower and upper storage of the same 4×4 input: each keeps only its own triangle,
        // and the opposite triangle is filled with zeros.
        #[rustfmt::skip]
        let mut lower = ComplexMatrix::from(&[
            [cpx!( 3.0,  1.0), cpx!(0.0,  0.0),  cpx!(0.0,  0.0), cpx!(0.0,  0.0)],
            [cpx!(-1.0,  0.0), cpx!(3.0,  0.0),  cpx!(0.0,  0.0), cpx!(0.0,  0.0)],
            [cpx!(-4.0,  0.0), cpx!(1.0,  0.0),  cpx!(3.0,  0.0), cpx!(0.0,  0.0)],
            [cpx!(-1.0,  0.0), cpx!(2.0,  0.0),  cpx!(0.0,  0.0), cpx!(3.0, -1.0)],
        ]);
        #[rustfmt::skip]
        let mut upper = ComplexMatrix::from(&[
            [cpx!( 3.0,  1.0), cpx!(0.0,  0.0), cpx!(-2.0,  0.0), cpx!(0.0,  0.0)],
            [cpx!( 0.0,  0.0), cpx!(3.0,  0.0), cpx!( 0.0,  0.0), cpx!(2.0,  0.0)],
            [cpx!( 0.0,  0.0), cpx!(0.0,  0.0), cpx!( 3.0,  0.0), cpx!(1.0,  0.0)],
            [cpx!( 0.0,  0.0), cpx!(0.0,  0.0), cpx!( 0.0,  0.0), cpx!(3.0, -1.0)],
        ]);

        // a is (n,k) = (4,6)
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [cpx!( 1.0, -1.0),  cpx!(2.0, 0.0),  cpx!(1.0, 0.0), cpx!( 1.0, 0.0), cpx!(-1.0, 0.0), cpx!( 0.0,  0.0)],
            [cpx!( 2.0,  0.0),  cpx!(2.0, 0.0),  cpx!(1.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0,  1.0)],
            [cpx!( 3.0,  1.0),  cpx!(1.0, 0.0),  cpx!(3.0, 0.0), cpx!( 1.0, 0.0), cpx!( 2.0, 0.0), cpx!(-1.0,  0.0)],
            [cpx!( 1.0,  0.0),  cpx!(0.0, 0.0),  cpx!(1.0, 0.0), cpx!(-1.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0,  1.0)],
        ]);

        let alpha = cpx!(3.0, 0.0);
        let beta = cpx!(1.0, 0.0);

        // lower triangle: c := 3⋅a⋅aᵀ + c
        complex_mat_sym_rank_op(&mut lower, &a, alpha, beta, false, false).unwrap();
        #[rustfmt::skip]
        let expected_lower = ComplexMatrix::from(&[
            [cpx!(24.0, -5.0), cpx!( 0.0,  0.0), cpx!( 0.0,   0.0),  cpx!(0.0,  0.0)],
            [cpx!(20.0, -6.0), cpx!(27.0,  0.0), cpx!( 0.0,   0.0),  cpx!(0.0,  0.0)],
            [cpx!(20.0, -6.0), cpx!(34.0,  3.0), cpx!(75.0,  18.0),  cpx!(0.0,  0.0)],
            [cpx!( 2.0, -3.0), cpx!( 8.0,  0.0), cpx!(15.0,   0.0),  cpx!(9.0, -1.0)],
        ]);
        complex_mat_approx_eq(&lower, &expected_lower, 1e-15);

        // upper triangle: c := 3⋅a⋅aᵀ + c
        complex_mat_sym_rank_op(&mut upper, &a, alpha, beta, true, false).unwrap();
        #[rustfmt::skip]
        let expected_upper = ComplexMatrix::from(&[
            [cpx!(24.0, -5.0), cpx!(21.0, -6.0), cpx!(22.0,  -6.0),  cpx!(3.0, -3.0)],
            [cpx!( 0.0,  0.0), cpx!(27.0,  0.0), cpx!(33.0,   3.0),  cpx!(8.0,  0.0)],
            [cpx!( 0.0,  0.0), cpx!( 0.0,  0.0), cpx!(75.0,  18.0), cpx!(16.0,  0.0)],
            [cpx!( 0.0,  0.0), cpx!( 0.0,  0.0), cpx!( 0.0,   0.0),  cpx!(9.0, -1.0)],
        ]);
        complex_mat_approx_eq(&upper, &expected_upper, 1e-15);
    }

    #[test]
    fn complex_mat_sym_rank_op_works_second_case() {
        // Lower and upper storage of the same symmetric 6×6 input (real-valued entries).
        #[rustfmt::skip]
        let mut lower = ComplexMatrix::from(&[
            [ 3.0, 0.0,  0.0, 0.0, 0.0, 0.0],
            [ 0.0, 3.0,  0.0, 0.0, 0.0, 0.0],
            [-3.0, 1.0,  4.0, 0.0, 0.0, 0.0],
            [ 0.0, 2.0,  1.0, 3.0, 0.0, 0.0],
            [ 0.0, 2.0,  1.0, 3.0, 4.0, 0.0],
            [ 0.0, 2.0,  1.0, 3.0, 3.0, 4.0],
        ]);
        #[rustfmt::skip]
        let mut upper = ComplexMatrix::from(&[
            [ 3.0, 0.0, -3.0, 0.0, 0.0, 0.0],
            [ 0.0, 3.0,  1.0, 2.0, 2.0, 2.0],
            [ 0.0, 0.0,  4.0, 1.0, 1.0, 1.0],
            [ 0.0, 0.0,  0.0, 3.0, 3.0, 3.0],
            [ 0.0, 0.0,  0.0, 0.0, 4.0, 3.0],
            [ 0.0, 0.0,  0.0, 0.0, 0.0, 4.0],
        ]);

        // a is (k,n) = (4,6)
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [ 1.0,  2.0,  1.0,  1.0, -1.0,  0.0],
            [ 2.0,  2.0,  1.0,  0.0,  0.0,  0.0],
            [ 3.0,  1.0,  3.0,  1.0,  2.0, -1.0],
            [ 1.0,  0.0,  1.0, -1.0,  0.0,  0.0],
        ]);

        let alpha = cpx!(3.0, 0.0);
        let beta = cpx!(1.0, 0.0);

        // lower triangle: c := 3⋅aᵀ⋅a + c
        complex_mat_sym_rank_op(&mut lower, &a, alpha, beta, false, true).unwrap();
        #[rustfmt::skip]
        let expected_lower = ComplexMatrix::from(&[
            [48.0,  0.0,  0.0,  0.0,  0.0,  0.0],
            [27.0, 30.0,  0.0,  0.0,  0.0,  0.0],
            [36.0, 22.0, 40.0,  0.0,  0.0,  0.0],
            [ 9.0, 11.0, 10.0, 12.0,  0.0,  0.0],
            [15.0,  2.0, 16.0,  6.0, 19.0,  0.0],
            [-9.0, -1.0, -8.0,  0.0, -3.0,  7.0],
        ]);
        complex_mat_approx_eq(&lower, &expected_lower, 1e-15);

        // upper triangle: c := 3⋅aᵀ⋅a + c
        complex_mat_sym_rank_op(&mut upper, &a, alpha, beta, true, true).unwrap();
        #[rustfmt::skip]
        let expected_upper = ComplexMatrix::from(&[
            [48.0, 27.0, 36.0,  9.0, 15.0, -9.0],
            [ 0.0, 30.0, 22.0, 11.0,  2.0, -1.0],
            [ 0.0,  0.0, 40.0, 10.0, 16.0, -8.0],
            [ 0.0,  0.0,  0.0, 12.0,  6.0,  0.0],
            [ 0.0,  0.0,  0.0,  0.0, 19.0, -3.0],
            [ 0.0,  0.0,  0.0,  0.0,  0.0,  7.0],
        ]);
        complex_mat_approx_eq(&upper, &expected_upper, 1e-15);
    }
}