1use crate::assert;
2use crate::internal_prelude::*;
3use faer_traits::RealReg;
4
/// Policy selecting how NaN entries are treated by the mean and variance routines of this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NanHandling {
    /// NaN entries are fed through the arithmetic unchanged, so they generally end up in the result.
    Propagate,
    /// NaN entries are skipped, and only the non-NaN entries are counted in the divisor.
    Ignore,
}
13
/// Converts the count `n` into the real scalar type `T`.
///
/// `n` is split into its low 32 bits and the remaining high part (a multiple of `2^32`).
/// Each part has at most 32 significant bits, so both are exactly representable as `f64`
/// before going through `from_f64`; the two pieces are then added in `T`.
/// NOTE(review): presumably the split exists so that real types with more precision than
/// `f64` can represent large counts exactly, which a single `n as f64` cast would round — confirm.
#[inline(always)]
#[math]
fn from_usize<T: RealField>(n: usize) -> T {
    // low 32 bits + (n - low 32 bits); each summand is exact in f64.
    from_f64::<T>(n as u32 as f64) + (from_f64::<T>((n as u64 - (n as u32 as u64)) as f64))
}
19
20#[inline(always)]
21fn reduce<T: ComplexField, S: pulp::Simd>(non_nan_count: T::SimdIndex<S>) -> usize {
22 let slice: &[T::Index] = bytemuck::cast_slice(core::slice::from_ref(&non_nan_count));
23
24 let mut acc = 0usize;
25 for &count in slice {
26 acc += count.zx();
27 }
28 acc
29}
30
/// Row-major SIMD kernel for `NanHandling::Ignore` means: writes into `out[i]` the mean of the
/// non-NaN entries of row `i` of `mat`.
///
/// Each row is scanned with four independent vector accumulators (`sum0..sum3`).
/// - Next to each accumulator runs a per-lane counter (`T::SimdIndex<S>`) of how many entries
///   were not NaN. These counters are periodically folded into the scalar `non_nan_count_total`.
/// - The row sum is finally scaled by the reciprocal of that total.
/// - For a row with no non-NaN entries the result is `0 * recip(0)`. That is NaN assuming
///   `recip(0)` is infinite (NOTE(review): confirm for every `T`).
fn col_mean_row_major_ignore_nan<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T, usize, usize, isize, ContiguousFwd>) {
    struct Impl<'a, T: ComplexField> {
        out: ColMut<'a, T>,
        mat: MatRef<'a, T, usize, usize, isize, ContiguousFwd>,
    }

    impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
            let Self { out, mat } = self;
            with_dim!(M, mat.nrows());
            with_dim!(N, mat.ncols());

            // Each row is contiguous, so the SIMD context is sized by the row length `N`.
            let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), N);

            let mut out = out.as_row_shape_mut(M);
            let mat = mat.as_shape(M, N);
            // Partition of a row's indices into an optional masked head, groups of four vectors,
            // single vectors and an optional masked tail; cloned afresh for every row.
            let indices = simd.batch_indices::<4>();

            // Splatted NaN, used to overwrite lanes outside the head/tail masks so `process` skips them.
            let nan = simd.splat(&nan::<T>());
            for i in M.indices() {
                let row = mat.row(i).transpose();
                let (head, mut body4, body1, tail) = indices.clone();

                let mut non_nan_count_total = 0usize;

                // Adds `val` into `acc` and bumps `non_nan_count` only in lanes where `val` is not NaN
                // (a lane is NaN exactly when it compares unequal to itself).
                #[inline(always)]
                fn process<'M, T: ComplexField, S: pulp::Simd>(
                    simd: SimdCtx<'M, T, S>,
                    acc: T::SimdVec<S>,
                    non_nan_count: T::SimdIndex<S>,
                    val: T::SimdVec<S>,
                ) -> (T::SimdVec<S>, T::SimdIndex<S>) {
                    let is_not_nan = (*simd).eq(val, val);

                    (
                        simd.select(is_not_nan, simd.add(acc, val), acc),
                        simd.iselect(is_not_nan, simd.iadd(non_nan_count, simd.isplat(T::Index::truncate(1))), non_nan_count),
                    )
                }

                let mut sum0 = simd.splat(&zero::<T>());
                let mut sum1 = simd.splat(&zero::<T>());
                let mut sum2 = simd.splat(&zero::<T>());
                let mut sum3 = simd.splat(&zero::<T>());
                let mut non_nan_count0 = simd.isplat(T::Index::truncate(0));
                let mut non_nan_count1 = simd.isplat(T::Index::truncate(0));
                let mut non_nan_count2 = simd.isplat(T::Index::truncate(0));
                let mut non_nan_count3 = simd.isplat(T::Index::truncate(0));

                // Masked-out head lanes become NaN and are therefore neither summed nor counted.
                if let Some(i) = head {
                    (sum0, non_nan_count0) = process(simd, sum0, non_nan_count0, simd.select(simd.head_mask(), simd.read(row, i), nan));
                    non_nan_count_total += reduce::<T, S>(non_nan_count0);
                    non_nan_count0 = simd.isplat(T::Index::truncate(0));
                }

                // Consume at most 256 groups per pass, then flush the lane counters into the scalar
                // total and reset them. This caps each merged lane counter at 4 * 256 between flushes.
                loop {
                    if body4.len() == 0 {
                        break;
                    }
                    for [i0, i1, i2, i3] in (&mut body4).take(256) {
                        (sum0, non_nan_count0) = process(simd, sum0, non_nan_count0, simd.read(row, i0));
                        (sum1, non_nan_count1) = process(simd, sum1, non_nan_count1, simd.read(row, i1));
                        (sum2, non_nan_count2) = process(simd, sum2, non_nan_count2, simd.read(row, i2));
                        (sum3, non_nan_count3) = process(simd, sum3, non_nan_count3, simd.read(row, i3));
                    }

                    non_nan_count0 = simd.iadd(non_nan_count0, non_nan_count1);
                    non_nan_count2 = simd.iadd(non_nan_count2, non_nan_count3);
                    non_nan_count0 = simd.iadd(non_nan_count0, non_nan_count2);
                    non_nan_count_total += reduce::<T, S>(non_nan_count0);
                    non_nan_count0 = simd.isplat(T::Index::truncate(0));
                    non_nan_count1 = simd.isplat(T::Index::truncate(0));
                    non_nan_count2 = simd.isplat(T::Index::truncate(0));
                    non_nan_count3 = simd.isplat(T::Index::truncate(0));
                }

                for i in body1 {
                    (sum0, non_nan_count0) = process(simd, sum0, non_nan_count0, simd.read(row, i));
                }

                // Masked-out tail lanes become NaN, as for the head.
                if let Some(i) = tail {
                    (sum0, non_nan_count0) = process(simd, sum0, non_nan_count0, simd.select(simd.tail_mask(), simd.read(row, i), nan));
                }
                non_nan_count_total += reduce::<T, S>(non_nan_count0);

                // Pairwise merge of the four partial sums before the horizontal reduction.
                sum0 = simd.add(sum0, sum1);
                sum2 = simd.add(sum2, sum3);
                sum0 = simd.add(sum0, sum2);

                let sum = simd.reduce_sum(sum0);

                out[i] = mul_real(&sum, &recip(&from_usize::<T::Real>(non_nan_count_total)));
            }
        }
    }

    T::Arch::default().dispatch(Impl { out, mat });
}
132
/// Row-major SIMD kernel for `NanHandling::Ignore` variances.
///
/// For every row `i`, writes into `out[i]` the sum of `|mat[(i, j)] - col_mean[i]|²` over the
/// non-NaN entries of the row, divided by `count - 1`, where `count` is the number of those
/// entries.
/// - A row with no non-NaN entries yields NaN.
/// - A row with exactly one non-NaN entry yields zero.
///
/// NaN detection is done on the matrix entry itself, not on the deviation from the mean.
fn col_varm_row_major_ignore_nan<T: ComplexField>(
    out: ColMut<'_, T::Real>,
    mat: MatRef<'_, T, usize, usize, isize, ContiguousFwd>,
    col_mean: ColRef<'_, T>,
) {
    struct Impl<'a, T: ComplexField> {
        out: ColMut<'a, T::Real>,
        mat: MatRef<'a, T, usize, usize, isize, ContiguousFwd>,
        col_mean: ColRef<'a, T>,
    }

    impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
            let Self { out, mat, col_mean } = self;
            with_dim!(M, mat.nrows());
            with_dim!(N, mat.ncols());

            // Each row is contiguous, so the SIMD context is sized by the row length `N`.
            let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), N);

            let mut out = out.as_row_shape_mut(M);
            let mat = mat.as_shape(M, N);
            let col_mean = col_mean.as_row_shape(M);
            let indices = simd.batch_indices::<4>();

            // Splatted NaN, used to overwrite lanes outside the head/tail masks so `process` skips them.
            let nan_v = simd.splat(&nan::<T>());
            for i in M.indices() {
                let row = mat.row(i).transpose();
                let mean = simd.splat(&col_mean[i]);
                let (head, mut body4, body1, tail) = indices.clone();

                let mut non_nan_count = 0usize;

                // Adds `|val - mean|²` into `acc` and bumps `non_nan_count` only in lanes where
                // `val` is not NaN (a lane is NaN exactly when it compares unequal to itself).
                #[inline(always)]
                fn process<'M, T: ComplexField, S: pulp::Simd>(
                    simd: SimdCtx<'M, T, S>,
                    mean: T::SimdVec<S>,
                    acc: RealReg<T::SimdVec<S>>,
                    non_nan_count: T::SimdIndex<S>,
                    val: T::SimdVec<S>,
                ) -> (RealReg<T::SimdVec<S>>, T::SimdIndex<S>) {
                    let is_not_nan = (*simd).eq(val, val);
                    let diff = simd.sub(val, mean);

                    (
                        RealReg(simd.select(is_not_nan, simd.abs2_add(diff, acc).0, acc.0)),
                        simd.iselect(is_not_nan, simd.iadd(non_nan_count, simd.isplat(T::Index::truncate(1))), non_nan_count),
                    )
                }

                let mut sum0 = RealReg(simd.splat(&zero::<T>()));
                let mut sum1 = RealReg(simd.splat(&zero::<T>()));
                let mut sum2 = RealReg(simd.splat(&zero::<T>()));
                let mut sum3 = RealReg(simd.splat(&zero::<T>()));
                let mut non_nan_count0 = simd.isplat(T::Index::truncate(0));
                let mut non_nan_count1 = simd.isplat(T::Index::truncate(0));
                let mut non_nan_count2 = simd.isplat(T::Index::truncate(0));
                let mut non_nan_count3 = simd.isplat(T::Index::truncate(0));

                // Masked-out head lanes become NaN and are therefore neither accumulated nor counted.
                if let Some(i) = head {
                    (sum0, non_nan_count0) = process(simd, mean, sum0, non_nan_count0, simd.select(simd.head_mask(), simd.read(row, i), nan_v));
                    non_nan_count += reduce::<T, S>(non_nan_count0);
                    non_nan_count0 = simd.isplat(T::Index::truncate(0));
                }

                // At most 256 groups per pass, then the lane counters are flushed into the scalar
                // count and reset, capping their growth between flushes.
                loop {
                    if body4.len() == 0 {
                        break;
                    }
                    for [i0, i1, i2, i3] in (&mut body4).take(256) {
                        (sum0, non_nan_count0) = process(simd, mean, sum0, non_nan_count0, simd.read(row, i0));
                        (sum1, non_nan_count1) = process(simd, mean, sum1, non_nan_count1, simd.read(row, i1));
                        (sum2, non_nan_count2) = process(simd, mean, sum2, non_nan_count2, simd.read(row, i2));
                        (sum3, non_nan_count3) = process(simd, mean, sum3, non_nan_count3, simd.read(row, i3));
                    }

                    non_nan_count0 = simd.iadd(non_nan_count0, non_nan_count1);
                    non_nan_count2 = simd.iadd(non_nan_count2, non_nan_count3);
                    non_nan_count0 = simd.iadd(non_nan_count0, non_nan_count2);
                    non_nan_count += reduce::<T, S>(non_nan_count0);
                    non_nan_count0 = simd.isplat(T::Index::truncate(0));
                    non_nan_count1 = simd.isplat(T::Index::truncate(0));
                    non_nan_count2 = simd.isplat(T::Index::truncate(0));
                    non_nan_count3 = simd.isplat(T::Index::truncate(0));
                }

                for i in body1 {
                    (sum0, non_nan_count0) = process(simd, mean, sum0, non_nan_count0, simd.read(row, i));
                }

                // Masked-out tail lanes become NaN, as for the head.
                if let Some(i) = tail {
                    (sum0, non_nan_count0) = process(simd, mean, sum0, non_nan_count0, simd.select(simd.tail_mask(), simd.read(row, i), nan_v));
                }
                non_nan_count += reduce::<T, S>(non_nan_count0);

                sum0 = RealReg(simd.add(sum0.0, sum1.0));
                sum2 = RealReg(simd.add(sum2.0, sum3.0));
                sum0 = RealReg(simd.add(sum0.0, sum2.0));

                // The squared magnitudes live in a `T`-shaped register; keep the real part of the total.
                let sum = real(&simd.reduce_sum(sum0.0));

                if non_nan_count == 0 {
                    out[i] = nan();
                } else if non_nan_count == 1 {
                    out[i] = zero();
                } else {
                    out[i] = mul_real(&sum, &recip(&from_usize::<T::Real>(non_nan_count - 1)));
                }
            }
        }
    }

    T::Arch::default().dispatch(Impl { out, mat, col_mean });
}
249
/// Row-major SIMD kernel for `NanHandling::Propagate` means: writes into `out[i]` the sum of
/// row `i` of `mat` multiplied by `1 / ncols`.
///
/// NaN entries are added like any other value, so they carry through to the result.
fn col_mean_row_major_propagate_nan<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T, usize, usize, isize, ContiguousFwd>) {
    struct Impl<'a, T: ComplexField> {
        out: ColMut<'a, T>,
        mat: MatRef<'a, T, usize, usize, isize, ContiguousFwd>,
    }

    impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
            let Self { out, mat } = self;
            with_dim!(M, mat.nrows());
            with_dim!(N, mat.ncols());

            // Each row is contiguous, so the SIMD context is sized by the row length `N`.
            let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), N);

            let mut out = out.as_row_shape_mut(M);
            let mat = mat.as_shape(M, N);
            let indices = simd.batch_indices::<4>();

            // The divisor is the same for every row, so its reciprocal is computed once.
            let n = recip(&from_usize::<T::Real>(*N));
            for i in M.indices() {
                let row = mat.row(i).transpose();
                let (head, body4, body1, tail) = indices.clone();

                let mut sum0 = simd.splat(&zero::<T>());
                let mut sum1 = simd.splat(&zero::<T>());
                let mut sum2 = simd.splat(&zero::<T>());
                let mut sum3 = simd.splat(&zero::<T>());

                // NOTE(review): head/tail reads are added without an explicit mask; this relies on
                // `simd.read` leaving out-of-range lanes at zero for head/tail indices — confirm.
                if let Some(i) = head {
                    sum0 = simd.add(sum0, simd.read(row, i));
                }

                // Four independent accumulators shorten the dependency chain of the additions.
                for [i0, i1, i2, i3] in body4 {
                    sum0 = simd.add(sum0, simd.read(row, i0));
                    sum1 = simd.add(sum1, simd.read(row, i1));
                    sum2 = simd.add(sum2, simd.read(row, i2));
                    sum3 = simd.add(sum3, simd.read(row, i3));
                }

                for i in body1 {
                    sum0 = simd.add(sum0, simd.read(row, i));
                }

                if let Some(i) = tail {
                    sum0 = simd.add(sum0, simd.read(row, i));
                }

                sum0 = simd.add(sum0, sum1);
                sum2 = simd.add(sum2, sum3);
                sum0 = simd.add(sum0, sum2);

                let sum = simd.reduce_sum(sum0);

                out[i] = mul_real(&sum, &n);
            }
        }
    }

    T::Arch::default().dispatch(Impl { out, mat });
}
313
/// Row-major SIMD kernel for `NanHandling::Propagate` variances.
///
/// For every row `i`, writes into `out[i]` the sum of `|mat[(i, j)] - col_mean[i]|²` over all
/// columns, divided by `ncols - 1`.
/// - `ncols == 0` fills `out` with NaN.
/// - `ncols == 1` fills `out` with zero.
/// - Those two branches are not reached through `col_varm_propagate`, which only dispatches here
///   when `ncols > 1`.
fn col_varm_row_major_propagate_nan<T: ComplexField>(
    out: ColMut<'_, T::Real>,
    mat: MatRef<'_, T, usize, usize, isize, ContiguousFwd>,
    col_mean: ColRef<'_, T>,
) {
    struct Impl<'a, T: ComplexField> {
        out: ColMut<'a, T::Real>,
        mat: MatRef<'a, T, usize, usize, isize, ContiguousFwd>,
        col_mean: ColRef<'a, T>,
    }

    impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
            let Self { out, mat, col_mean } = self;
            with_dim!(M, mat.nrows());
            with_dim!(N, mat.ncols());

            // Each row is contiguous, so the SIMD context is sized by the row length `N`.
            let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), N);

            let mut out = out.as_row_shape_mut(M);
            let mat = mat.as_shape(M, N);
            let col_mean = col_mean.as_row_shape(M);
            let indices = simd.batch_indices::<4>();

            let n = *N;
            if n == 0 {
                out.fill(nan());
            } else if n == 1 {
                out.fill(zero());
            } else {
                // Same divisor for every row: compute `1 / (n - 1)` once.
                let n = recip(&from_usize::<T::Real>(n - 1));
                for i in M.indices() {
                    let row = mat.row(i).transpose();
                    let mean = simd.splat(&col_mean[i]);
                    let (head, body4, body1, tail) = indices.clone();

                    let mut sum0 = simd.splat(&zero::<T>());
                    let mut sum1 = simd.splat(&zero::<T>());
                    let mut sum2 = simd.splat(&zero::<T>());
                    let mut sum3 = simd.splat(&zero::<T>());

                    // Head/tail lanes outside the mask keep the old accumulator. Otherwise a padding
                    // lane would contribute `|x - mean|²` for whatever `x` the masked read yields.
                    if let Some(i0) = head {
                        sum0 = simd.select(simd.head_mask(), simd.abs2_add(simd.sub(simd.read(row, i0), mean), RealReg(sum0)).0, sum0);
                    }

                    for [i0, i1, i2, i3] in body4 {
                        sum0 = simd.abs2_add(simd.sub(simd.read(row, i0), mean), RealReg(sum0)).0;
                        sum1 = simd.abs2_add(simd.sub(simd.read(row, i1), mean), RealReg(sum1)).0;
                        sum2 = simd.abs2_add(simd.sub(simd.read(row, i2), mean), RealReg(sum2)).0;
                        sum3 = simd.abs2_add(simd.sub(simd.read(row, i3), mean), RealReg(sum3)).0;
                    }

                    for i0 in body1 {
                        sum0 = simd.abs2_add(simd.sub(simd.read(row, i0), mean), RealReg(sum0)).0;
                    }

                    if let Some(i0) = tail {
                        sum0 = simd.select(simd.tail_mask(), simd.abs2_add(simd.sub(simd.read(row, i0), mean), RealReg(sum0)).0, sum0);
                    }

                    sum0 = simd.add(sum0, sum1);
                    sum2 = simd.add(sum2, sum3);
                    sum0 = simd.add(sum0, sum2);

                    // Squared magnitudes are held in a `T`-shaped register; keep the real part.
                    let sum = real(&simd.reduce_sum(sum0));

                    out[i] = mul_real(&sum, &n);
                }
            }
        }
    }

    T::Arch::default().dispatch(Impl { out, mat, col_mean });
}
/// Scalar fallback for `NanHandling::Ignore` means.
///
/// Writes into `out[i]` the sum of the non-NaN entries of row `i` of `mat`, multiplied by the
/// reciprocal of how many such entries there were. A row without non-NaN entries produces
/// `0 * recip(0)` (NaN assuming `recip(0)` is infinite — NOTE(review): confirm for every `T`).
#[math]
fn col_mean_ignore_nan_fallback<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T>) {
    with_dim!(M, mat.nrows());
    with_dim!(N, mat.ncols());

    // One non-NaN counter per row.
    let non_nan_count = &mut *alloc::vec![0usize; *M];
    let non_nan_count = Array::from_mut(non_nan_count, M);

    let mut out = out.as_row_shape_mut(M);
    let mat = mat.as_shape(M, N);

    out.fill(zero());

    // Column-by-column sweep (outer loop over `j`); `out` doubles as the running sums.
    for j in N.indices() {
        for i in M.indices() {
            let val = copy(mat[(i, j)]);
            let nan = is_nan(val);
            // NaN entries are replaced by zero so they do not affect the sum.
            let val = if nan { zero::<T>() } else { val };

            non_nan_count[i] += (!nan) as usize;
            out[i] = out[i] + val;
        }
    }

    for i in M.indices() {
        out[i] = mul_real(out[i], recip(from_usize::<T::Real>(non_nan_count[i])));
    }
}
419
/// Scalar fallback for `NanHandling::Ignore` variances.
///
/// For every row `i`, sums `|mat[(i, j)] - col_mean[i]|²` over the non-NaN entries of the row
/// and divides by `count - 1`, where `count` is the number of non-NaN entries.
/// - A row with no non-NaN entries yields NaN.
/// - A row with exactly one non-NaN entry yields zero.
#[math]
fn col_varm_ignore_nan_fallback<T: ComplexField>(out: ColMut<'_, T::Real>, mat: MatRef<'_, T>, col_mean: ColRef<'_, T>) {
    with_dim!(M, mat.nrows());
    with_dim!(N, mat.ncols());

    // One non-NaN counter per row.
    let non_nan_count = &mut *alloc::vec![0usize; *M];
    let non_nan_count = Array::from_mut(non_nan_count, M);

    let mut out = out.as_row_shape_mut(M);
    let mat = mat.as_shape(M, N);
    let col_mean = col_mean.as_row_shape(M);

    out.fill(zero());

    // Column-by-column sweep (outer loop over `j`); `out` doubles as the running sums.
    for j in N.indices() {
        for i in M.indices() {
            let val = copy(mat[(i, j)]);
            let col_mean = copy(col_mean[i]);
            let nan = is_nan(val);
            // NaN is tested on the entry itself; NaN entries contribute nothing.
            let val = if nan { zero::<T::Real>() } else { abs2(val - col_mean) };

            non_nan_count[i] += (!nan) as usize;
            out[i] = out[i] + val;
        }
    }

    for i in M.indices() {
        let non_nan_count = non_nan_count[i];
        if non_nan_count == 0 {
            out[i] = nan();
        } else if non_nan_count == 1 {
            out[i] = zero();
        } else {
            out[i] = mul_real(out[i], recip(from_usize::<T::Real>(non_nan_count - 1)));
        }
    }
}
457
/// Scalar fallback for `NanHandling::Propagate` means.
///
/// Writes into `out[i]` the sum of all entries of row `i` of `mat`, multiplied by `1 / ncols`.
/// NaN entries are added like any other value, so they carry through. With zero columns the
/// result is `0 * recip(0)` (NaN assuming `recip(0)` is infinite — NOTE(review): confirm).
#[math]
fn col_mean_propagate_nan_fallback<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T>) {
    with_dim!(M, mat.nrows());
    with_dim!(N, mat.ncols());

    let mut out = out.as_row_shape_mut(M);
    let mat = mat.as_shape(M, N);

    out.fill(zero());

    // Column-by-column sweep (outer loop over `j`); `out` doubles as the running sums.
    for j in N.indices() {
        for i in M.indices() {
            out[i] = out[i] + mat[(i, j)];
        }
    }

    // Same divisor for every row: compute `1 / ncols` once.
    let n = recip(from_usize::<T::Real>(*N));
    for i in M.indices() {
        out[i] = mul_real(out[i], n);
    }
}
479
480#[math]
481fn col_varm_propagate_nan_fallback<T: ComplexField>(out: ColMut<'_, T::Real>, mat: MatRef<'_, T>, col_mean: ColRef<'_, T>) {
482 with_dim!(M, mat.nrows());
483 with_dim!(N, mat.ncols());
484
485 let mut out = out.as_row_shape_mut(M);
486 let mat = mat.as_shape(M, N);
487 let col_mean = col_mean.as_row_shape(M);
488
489 out.fill(zero());
490
491 for j in N.indices() {
492 for i in M.indices() {
493 let val = abs2(mat[(i, j)] - col_mean[i]);
494 out[i] = out[i] + val;
495 }
496 }
497
498 let n = *N;
499 if n == 0 {
500 out.fill(nan());
501 } else if n == 1 {
502 out.fill(zero());
503 } else {
504 let n = recip(from_usize::<T::Real>(*N - 1));
505 for i in M.indices() {
506 out[i] = mul_real(out[i], n);
507 }
508 }
509}
510
511fn col_mean_ignore<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T>) {
512 let mat = if mat.col_stride() >= 0 { mat } else { mat.reverse_cols() };
513 let mat = if mat.row_stride() >= 0 { mat } else { mat.reverse_rows() };
514
515 if const { T::SIMD_CAPABILITIES.is_simd() } {
516 if mat.ncols() > 1 && mat.col_stride() == 1 {
517 col_mean_row_major_ignore_nan(out, mat.try_as_row_major().unwrap());
518 } else {
519 col_mean_ignore_nan_fallback(out, mat);
520 }
521 } else {
522 col_mean_ignore_nan_fallback(out, mat);
523 }
524}
525
526fn col_varm_ignore<T: ComplexField>(out: ColMut<'_, T::Real>, mat: MatRef<'_, T>, col_mean: ColRef<'_, T>) {
527 let mat = if mat.col_stride() >= 0 { mat } else { mat.reverse_cols() };
528 let mat = if mat.row_stride() >= 0 { mat } else { mat.reverse_rows() };
529
530 if const { T::SIMD_CAPABILITIES.is_simd() } {
531 if mat.ncols() > 1 && mat.col_stride() == 1 {
532 col_varm_row_major_ignore_nan(out, mat.try_as_row_major().unwrap(), col_mean);
533 } else {
534 col_varm_ignore_nan_fallback(out, mat, col_mean);
535 }
536 } else {
537 col_varm_ignore_nan_fallback(out, mat, col_mean);
538 }
539}
540
541fn col_mean_propagate<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T>) {
542 let mat = if mat.col_stride() >= 0 { mat } else { mat.reverse_cols() };
543 let mat = if mat.row_stride() >= 0 { mat } else { mat.reverse_rows() };
544
545 if const { T::SIMD_CAPABILITIES.is_simd() } {
546 if mat.ncols() > 1 && mat.col_stride() == 1 {
547 col_mean_row_major_propagate_nan(out, mat.try_as_row_major().unwrap());
548 } else {
549 col_mean_propagate_nan_fallback(out, mat);
550 }
551 } else {
552 col_mean_propagate_nan_fallback(out, mat);
553 }
554}
555
556fn col_varm_propagate<T: ComplexField>(out: ColMut<'_, T::Real>, mat: MatRef<'_, T>, col_mean: ColRef<'_, T>) {
557 let mat = if mat.col_stride() >= 0 { mat } else { mat.reverse_cols() };
558 let mat = if mat.row_stride() >= 0 { mat } else { mat.reverse_rows() };
559
560 if const { T::SIMD_CAPABILITIES.is_simd() } {
561 if mat.ncols() > 1 && mat.col_stride() == 1 {
562 col_varm_row_major_propagate_nan(out, mat.try_as_row_major().unwrap(), col_mean);
563 } else {
564 col_varm_propagate_nan_fallback(out, mat, col_mean);
565 }
566 } else {
567 col_varm_propagate_nan_fallback(out, mat, col_mean);
568 }
569}
570
571#[track_caller]
573pub fn col_mean<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T>, nan: NanHandling) {
574 assert!(all(out.nrows() == mat.nrows()));
575
576 match nan {
577 NanHandling::Propagate => col_mean_propagate(out, mat),
578 NanHandling::Ignore => col_mean_ignore(out, mat),
579 }
580}
581
582#[track_caller]
584pub fn row_mean<T: ComplexField>(out: RowMut<'_, T>, mat: MatRef<'_, T>, nan: NanHandling) {
585 assert!(all(out.ncols() == mat.ncols()));
586 col_mean(out.transpose_mut(), mat.transpose(), nan);
587}
588
589#[track_caller]
591pub fn col_varm<T: ComplexField>(out: ColMut<'_, T::Real>, mat: MatRef<'_, T>, col_mean: ColRef<'_, T>, nan: NanHandling) {
592 assert!(all(out.nrows() == mat.nrows(), col_mean.nrows() == mat.nrows()));
593
594 match nan {
595 NanHandling::Propagate => col_varm_propagate(out, mat, col_mean),
596 NanHandling::Ignore => col_varm_ignore(out, mat, col_mean),
597 }
598}
599
600#[track_caller]
602pub fn row_varm<T: ComplexField>(out: RowMut<'_, T::Real>, mat: MatRef<'_, T>, row_mean: RowRef<'_, T>, nan: NanHandling) {
603 assert!(all(out.ncols() == mat.ncols(), row_mean.ncols() == mat.ncols()));
604
605 col_varm(out.transpose_mut(), mat.transpose(), row_mean.transpose(), nan);
606}
607
// Unit tests for the mean / variance routines.
//
// Every test uses a 2x2 complex matrix and compares both the row-wise (`row_*`) and column-wise
// (`col_*`) results against hand-written expressions. The row-wise calls go through a transposed
// view, so (presumably, since `mat!` builds a column-major matrix — confirm) the `row_*` and
// `col_*` calls exercise different dispatch paths.
//
// With two samples the variance divisor `n - 1` equals 1, so the expected variances are plain sums
// of squared deviations. Comparisons use exact equality, which relies on dividing by 2.0 or 1.0
// being exact in floating point.
#[cfg(test)]
mod tests {
    use super::*;
    use equator::assert;

    // NaN-free input, `Propagate` policy.
    #[test]
    fn test_meanvar_propagate() {
        let c32 = c32::new;
        let A = mat![[c32(1.2, 2.3), c32(3.4, 1.2)], [c32(1.7, -1.0), c32(-3.8, 1.95)],];

        let mut row_mean = Row::zeros(A.ncols());
        let mut row_var = Row::zeros(A.ncols());
        super::row_mean(row_mean.as_mut(), A.as_ref(), NanHandling::Propagate);
        super::row_varm(row_var.as_mut(), A.as_ref(), row_mean.as_ref(), NanHandling::Propagate);

        let mut col_mean = Col::zeros(A.nrows());
        let mut col_var = Col::zeros(A.nrows());
        super::col_mean(col_mean.as_mut(), A.as_ref(), NanHandling::Propagate);
        super::col_varm(col_var.as_mut(), A.as_ref(), col_mean.as_ref(), NanHandling::Propagate);

        assert!(row_mean == row![(A[(0, 0)] + A[(1, 0)]) / 2.0, (A[(0, 1)] + A[(1, 1)]) / 2.0,]);
        assert!(
            row_var
                == row![
                    abs2(&(A[(0, 0)] - row_mean[0])) + abs2(&(A[(1, 0)] - row_mean[0])),
                    abs2(&(A[(0, 1)] - row_mean[1])) + abs2(&(A[(1, 1)] - row_mean[1])),
                ]
        );

        assert!(col_mean == col![(A[(0, 0)] + A[(0, 1)]) / 2.0, (A[(1, 0)] + A[(1, 1)]) / 2.0,]);
        assert!(
            col_var
                == col![
                    abs2(&(A[(0, 0)] - col_mean[0])) + abs2(&(A[(0, 1)] - col_mean[0])),
                    abs2(&(A[(1, 0)] - col_mean[1])) + abs2(&(A[(1, 1)] - col_mean[1])),
                ]
        );
    }

    // NaN-free `c32` input with the `Ignore` policy: results must match plain mean/variance.
    #[test]
    fn test_meanvar_ignore_nan_nonan_c32() {
        let c32 = c32::new;
        let A = mat![[c32(1.2, 2.3), c32(3.4, 1.2)], [c32(1.7, -1.0), c32(-3.8, 1.95)],];

        let mut row_mean = Row::zeros(A.ncols());
        let mut row_var = Row::zeros(A.ncols());
        super::row_mean(row_mean.as_mut(), A.as_ref(), NanHandling::Ignore);
        super::row_varm(row_var.as_mut(), A.as_ref(), row_mean.as_ref(), NanHandling::Ignore);

        let mut col_mean = Col::zeros(A.nrows());
        let mut col_var = Col::zeros(A.nrows());
        super::col_mean(col_mean.as_mut(), A.as_ref(), NanHandling::Ignore);
        super::col_varm(col_var.as_mut(), A.as_ref(), col_mean.as_ref(), NanHandling::Ignore);

        assert!(row_mean == row![(A[(0, 0)] + A[(1, 0)]) / 2.0, (A[(0, 1)] + A[(1, 1)]) / 2.0,]);
        assert!(
            row_var
                == row![
                    abs2(&(A[(0, 0)] - row_mean[0])) + abs2(&(A[(1, 0)] - row_mean[0])),
                    abs2(&(A[(0, 1)] - row_mean[1])) + abs2(&(A[(1, 1)] - row_mean[1])),
                ]
        );

        assert!(col_mean == col![(A[(0, 0)] + A[(0, 1)]) / 2.0, (A[(1, 0)] + A[(1, 1)]) / 2.0,]);
        assert!(
            col_var
                == col![
                    abs2(&(A[(0, 0)] - col_mean[0])) + abs2(&(A[(0, 1)] - col_mean[0])),
                    abs2(&(A[(1, 0)] - col_mean[1])) + abs2(&(A[(1, 1)] - col_mean[1])),
                ]
        );
    }

    // `c32` input with a NaN at (0, 0) and the `Ignore` policy: row 0 / column 0 reduce over
    // a single sample, so their mean is that sample and their variance is zero (the expected
    // `abs2(x - mean)` below evaluates to zero).
    #[test]
    fn test_meanvar_ignore_nan_yesnan_c32() {
        let c32 = c32::new;
        let nan = f32::NAN;
        let A = mat![[c32(1.2, nan), c32(3.4, 1.2)], [c32(1.7, -1.0), c32(-3.8, 1.95)],];

        let mut row_mean = Row::zeros(A.ncols());
        let mut row_var = Row::zeros(A.ncols());
        super::row_mean(row_mean.as_mut(), A.as_ref(), NanHandling::Ignore);
        super::row_varm(row_var.as_mut(), A.as_ref(), row_mean.as_ref(), NanHandling::Ignore);

        let mut col_mean = Col::zeros(A.nrows());
        let mut col_var = Col::zeros(A.nrows());
        super::col_mean(col_mean.as_mut(), A.as_ref(), NanHandling::Ignore);
        super::col_varm(col_var.as_mut(), A.as_ref(), col_mean.as_ref(), NanHandling::Ignore);

        assert!(row_mean == row![A[(1, 0)] / 1.0, (A[(0, 1)] + A[(1, 1)]) / 2.0,]);
        assert!(
            row_var
                == row![
                    abs2(&(A[(1, 0)] - row_mean[0])),
                    abs2(&(A[(0, 1)] - row_mean[1])) + abs2(&(A[(1, 1)] - row_mean[1])),
                ]
        );

        assert!(col_mean == col![A[(0, 1)] / 1.0, (A[(1, 0)] + A[(1, 1)]) / 2.0,]);
        assert!(
            col_var
                == col![
                    abs2(&(A[(0, 1)] - col_mean[0])),
                    abs2(&(A[(1, 0)] - col_mean[1])) + abs2(&(A[(1, 1)] - col_mean[1])),
                ]
        );
    }

    // Same as the NaN-free `c32` case, in double precision.
    #[test]
    fn test_meanvar_ignore_nan_nonan_c64() {
        let c64 = c64::new;
        let A = mat![[c64(1.2, 2.3), c64(3.4, 1.2)], [c64(1.7, -1.0), c64(-3.8, 1.95)],];

        let mut row_mean = Row::zeros(A.ncols());
        let mut row_var = Row::zeros(A.ncols());
        super::row_mean(row_mean.as_mut(), A.as_ref(), NanHandling::Ignore);
        super::row_varm(row_var.as_mut(), A.as_ref(), row_mean.as_ref(), NanHandling::Ignore);

        let mut col_mean = Col::zeros(A.nrows());
        let mut col_var = Col::zeros(A.nrows());
        super::col_mean(col_mean.as_mut(), A.as_ref(), NanHandling::Ignore);
        super::col_varm(col_var.as_mut(), A.as_ref(), col_mean.as_ref(), NanHandling::Ignore);

        assert!(row_mean == row![(A[(0, 0)] + A[(1, 0)]) / 2.0, (A[(0, 1)] + A[(1, 1)]) / 2.0,]);
        assert!(
            row_var
                == row![
                    abs2(&(A[(0, 0)] - row_mean[0])) + abs2(&(A[(1, 0)] - row_mean[0])),
                    abs2(&(A[(0, 1)] - row_mean[1])) + abs2(&(A[(1, 1)] - row_mean[1])),
                ]
        );

        assert!(col_mean == col![(A[(0, 0)] + A[(0, 1)]) / 2.0, (A[(1, 0)] + A[(1, 1)]) / 2.0,]);
        assert!(
            col_var
                == col![
                    abs2(&(A[(0, 0)] - col_mean[0])) + abs2(&(A[(0, 1)] - col_mean[0])),
                    abs2(&(A[(1, 0)] - col_mean[1])) + abs2(&(A[(1, 1)] - col_mean[1])),
                ]
        );
    }

    // Same as the NaN-containing `c32` case, in double precision.
    #[test]
    fn test_meanvar_ignore_nan_yesnan_c64() {
        let c64 = c64::new;
        let nan = f64::NAN;
        let A = mat![[c64(1.2, nan), c64(3.4, 1.2)], [c64(1.7, -1.0), c64(-3.8, 1.95)],];

        let mut row_mean = Row::zeros(A.ncols());
        let mut row_var = Row::zeros(A.ncols());
        super::row_mean(row_mean.as_mut(), A.as_ref(), NanHandling::Ignore);
        super::row_varm(row_var.as_mut(), A.as_ref(), row_mean.as_ref(), NanHandling::Ignore);

        let mut col_mean = Col::zeros(A.nrows());
        let mut col_var = Col::zeros(A.nrows());
        super::col_mean(col_mean.as_mut(), A.as_ref(), NanHandling::Ignore);
        super::col_varm(col_var.as_mut(), A.as_ref(), col_mean.as_ref(), NanHandling::Ignore);

        assert!(row_mean == row![A[(1, 0)] / 1.0, (A[(0, 1)] + A[(1, 1)]) / 2.0,]);
        assert!(
            row_var
                == row![
                    abs2(&(A[(1, 0)] - row_mean[0])),
                    abs2(&(A[(0, 1)] - row_mean[1])) + abs2(&(A[(1, 1)] - row_mean[1])),
                ]
        );

        assert!(col_mean == col![A[(0, 1)] / 1.0, (A[(1, 0)] + A[(1, 1)]) / 2.0,]);
        assert!(
            col_var
                == col![
                    abs2(&(A[(0, 1)] - col_mean[0])),
                    abs2(&(A[(1, 0)] - col_mean[1])) + abs2(&(A[(1, 1)] - col_mean[1])),
                ]
        );
    }
}