1use super::ComplexMatrix;
2use crate::{to_i32, Complex64, StrError, CBLAS_COL_MAJOR, CBLAS_LOWER, CBLAS_NO_TRANS, CBLAS_TRANS, CBLAS_UPPER};
3
4extern "C" {
5 fn cblas_zsyrk(
8 layout: i32,
9 uplo: i32,
10 trans: i32,
11 n: i32,
12 k: i32,
13 alpha: *const Complex64,
14 a: *const Complex64,
15 lda: i32,
16 beta: *const Complex64,
17 c: *mut Complex64,
18 ldc: i32,
19 );
20}
21
22pub fn complex_mat_sym_rank_op(
95 c: &mut ComplexMatrix,
96 a: &ComplexMatrix,
97 alpha: Complex64,
98 beta: Complex64,
99 upper: bool,
100 second_case: bool,
101) -> Result<(), StrError> {
102 let (m, n) = c.dims();
103 if m != n {
104 return Err("[c] matrix must be square");
105 }
106 let (row, col) = a.dims();
107 let (lda, k, trans) = if !second_case {
108 if row != n {
111 return Err("[a] matrix is incompatible");
112 }
113 (row, col, CBLAS_NO_TRANS)
114 } else {
115 if col != n {
118 return Err("[a] matrix is incompatible");
119 }
120 (row, row, CBLAS_TRANS)
121 };
122 let uplo = if upper { CBLAS_UPPER } else { CBLAS_LOWER };
123 let n_i32 = to_i32(n);
124 let k_i32 = to_i32(k);
125 let ldc = n_i32;
126 unsafe {
127 cblas_zsyrk(
128 CBLAS_COL_MAJOR,
129 uplo,
130 trans,
131 n_i32,
132 k_i32,
133 &alpha,
134 a.as_data().as_ptr(),
135 to_i32(lda),
136 &beta,
137 c.as_mut_data().as_mut_ptr(),
138 ldc,
139 );
140 }
141 Ok(())
142}
143
#[cfg(test)]
mod tests {
    use super::complex_mat_sym_rank_op;
    use crate::{complex_mat_approx_eq, cpx, Complex64, ComplexMatrix};

    /// Dimension mismatches must be reported as errors
    #[test]
    fn complex_mat_sym_rank_op_fail_on_wrong_dims() {
        let (alpha, beta) = (cpx!(2.0, 1.0), cpx!(3.0, 1.0));
        let a_3x2 = ComplexMatrix::new(3, 2);
        let a_2x3 = ComplexMatrix::new(2, 3);
        let mut square = ComplexMatrix::new(2, 2);
        let mut rectangular = ComplexMatrix::new(3, 2);

        // c is not square
        let res = complex_mat_sym_rank_op(&mut rectangular, &a_3x2, alpha, beta, false, false);
        assert_eq!(res, Err("[c] matrix must be square"));

        // first case: the number of rows of a differs from n
        let res = complex_mat_sym_rank_op(&mut square, &a_3x2, alpha, beta, false, false);
        assert_eq!(res, Err("[a] matrix is incompatible"));

        // second case: the number of columns of a differs from n
        let res = complex_mat_sym_rank_op(&mut square, &a_2x3, alpha, beta, false, true);
        assert_eq!(res, Err("[a] matrix is incompatible"));
    }

    /// First case: c is (4, 4) and a is (4, 6)
    #[test]
    fn complex_mat_sym_rank_op_works_first_case() {
        // factor matrix
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [cpx!( 1.0, -1.0), cpx!(2.0, 0.0), cpx!(1.0, 0.0), cpx!( 1.0, 0.0), cpx!(-1.0, 0.0), cpx!( 0.0, 0.0)],
            [cpx!( 2.0,  0.0), cpx!(2.0, 0.0), cpx!(1.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 1.0)],
            [cpx!( 3.0,  1.0), cpx!(1.0, 0.0), cpx!(3.0, 0.0), cpx!( 1.0, 0.0), cpx!( 2.0, 0.0), cpx!(-1.0, 0.0)],
            [cpx!( 1.0,  0.0), cpx!(0.0, 0.0), cpx!(1.0, 0.0), cpx!(-1.0, 0.0), cpx!( 0.0, 0.0), cpx!( 0.0, 1.0)],
        ]);

        // scalars
        let alpha = cpx!(3.0, 0.0);
        let beta = cpx!(1.0, 0.0);

        // matrix with data in the lower triangle only
        #[rustfmt::skip]
        let mut lower = ComplexMatrix::from(&[
            [cpx!( 3.0, 1.0), cpx!(0.0, 0.0), cpx!(0.0, 0.0), cpx!(0.0,  0.0)],
            [cpx!(-1.0, 0.0), cpx!(3.0, 0.0), cpx!(0.0, 0.0), cpx!(0.0,  0.0)],
            [cpx!(-4.0, 0.0), cpx!(1.0, 0.0), cpx!(3.0, 0.0), cpx!(0.0,  0.0)],
            [cpx!(-1.0, 0.0), cpx!(2.0, 0.0), cpx!(0.0, 0.0), cpx!(3.0, -1.0)],
        ]);

        // matrix with data in the upper triangle only
        #[rustfmt::skip]
        let mut upper = ComplexMatrix::from(&[
            [cpx!( 3.0, 1.0), cpx!(0.0, 0.0), cpx!(-2.0, 0.0), cpx!(0.0,  0.0)],
            [cpx!( 0.0, 0.0), cpx!(3.0, 0.0), cpx!( 0.0, 0.0), cpx!(2.0,  0.0)],
            [cpx!( 0.0, 0.0), cpx!(0.0, 0.0), cpx!( 3.0, 0.0), cpx!(1.0,  0.0)],
            [cpx!( 0.0, 0.0), cpx!(0.0, 0.0), cpx!( 0.0, 0.0), cpx!(3.0, -1.0)],
        ]);

        // update using the lower triangle
        complex_mat_sym_rank_op(&mut lower, &a, alpha, beta, false, false).unwrap();
        #[rustfmt::skip]
        let expected_lower = ComplexMatrix::from(&[
            [cpx!(24.0, -5.0), cpx!( 0.0, 0.0), cpx!( 0.0,  0.0), cpx!(0.0,  0.0)],
            [cpx!(20.0, -6.0), cpx!(27.0, 0.0), cpx!( 0.0,  0.0), cpx!(0.0,  0.0)],
            [cpx!(20.0, -6.0), cpx!(34.0, 3.0), cpx!(75.0, 18.0), cpx!(0.0,  0.0)],
            [cpx!( 2.0, -3.0), cpx!( 8.0, 0.0), cpx!(15.0,  0.0), cpx!(9.0, -1.0)],
        ]);
        complex_mat_approx_eq(&lower, &expected_lower, 1e-15);

        // update using the upper triangle
        complex_mat_sym_rank_op(&mut upper, &a, alpha, beta, true, false).unwrap();
        #[rustfmt::skip]
        let expected_upper = ComplexMatrix::from(&[
            [cpx!(24.0, -5.0), cpx!(21.0, -6.0), cpx!(22.0, -6.0), cpx!( 3.0, -3.0)],
            [cpx!( 0.0,  0.0), cpx!(27.0,  0.0), cpx!(33.0,  3.0), cpx!( 8.0,  0.0)],
            [cpx!( 0.0,  0.0), cpx!( 0.0,  0.0), cpx!(75.0, 18.0), cpx!(16.0,  0.0)],
            [cpx!( 0.0,  0.0), cpx!( 0.0,  0.0), cpx!( 0.0,  0.0), cpx!( 9.0, -1.0)],
        ]);
        complex_mat_approx_eq(&upper, &expected_upper, 1e-15);
    }

    /// Second case: c is (6, 6) and a is (4, 6)
    #[test]
    fn complex_mat_sym_rank_op_works_second_case() {
        // factor matrix
        #[rustfmt::skip]
        let a = ComplexMatrix::from(&[
            [ 1.0,  2.0,  1.0,  1.0, -1.0,  0.0],
            [ 2.0,  2.0,  1.0,  0.0,  0.0,  0.0],
            [ 3.0,  1.0,  3.0,  1.0,  2.0, -1.0],
            [ 1.0,  0.0,  1.0, -1.0,  0.0,  0.0],
        ]);

        // scalars
        let alpha = cpx!(3.0, 0.0);
        let beta = cpx!(1.0, 0.0);

        // matrix with data in the lower triangle only
        #[rustfmt::skip]
        let mut lower = ComplexMatrix::from(&[
            [ 3.0,  0.0,  0.0,  0.0,  0.0,  0.0],
            [ 0.0,  3.0,  0.0,  0.0,  0.0,  0.0],
            [-3.0,  1.0,  4.0,  0.0,  0.0,  0.0],
            [ 0.0,  2.0,  1.0,  3.0,  0.0,  0.0],
            [ 0.0,  2.0,  1.0,  3.0,  4.0,  0.0],
            [ 0.0,  2.0,  1.0,  3.0,  3.0,  4.0],
        ]);

        // matrix with data in the upper triangle only
        #[rustfmt::skip]
        let mut upper = ComplexMatrix::from(&[
            [ 3.0,  0.0, -3.0,  0.0,  0.0,  0.0],
            [ 0.0,  3.0,  1.0,  2.0,  2.0,  2.0],
            [ 0.0,  0.0,  4.0,  1.0,  1.0,  1.0],
            [ 0.0,  0.0,  0.0,  3.0,  3.0,  3.0],
            [ 0.0,  0.0,  0.0,  0.0,  4.0,  3.0],
            [ 0.0,  0.0,  0.0,  0.0,  0.0,  4.0],
        ]);

        // update using the lower triangle
        complex_mat_sym_rank_op(&mut lower, &a, alpha, beta, false, true).unwrap();
        #[rustfmt::skip]
        let expected_lower = ComplexMatrix::from(&[
            [48.0,  0.0,  0.0,  0.0,  0.0,  0.0],
            [27.0, 30.0,  0.0,  0.0,  0.0,  0.0],
            [36.0, 22.0, 40.0,  0.0,  0.0,  0.0],
            [ 9.0, 11.0, 10.0, 12.0,  0.0,  0.0],
            [15.0,  2.0, 16.0,  6.0, 19.0,  0.0],
            [-9.0, -1.0, -8.0,  0.0, -3.0,  7.0],
        ]);
        complex_mat_approx_eq(&lower, &expected_lower, 1e-15);

        // update using the upper triangle
        complex_mat_sym_rank_op(&mut upper, &a, alpha, beta, true, true).unwrap();
        #[rustfmt::skip]
        let expected_upper = ComplexMatrix::from(&[
            [48.0, 27.0, 36.0,  9.0, 15.0, -9.0],
            [ 0.0, 30.0, 22.0, 11.0,  2.0, -1.0],
            [ 0.0,  0.0, 40.0, 10.0, 16.0, -8.0],
            [ 0.0,  0.0,  0.0, 12.0,  6.0,  0.0],
            [ 0.0,  0.0,  0.0,  0.0, 19.0, -3.0],
            [ 0.0,  0.0,  0.0,  0.0,  0.0,  7.0],
        ]);
        complex_mat_approx_eq(&upper, &expected_upper, 1e-15);
    }
}