faer/stats/
meanvar.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use faer_traits::RealReg;
4
/// Selects how NaN entries are treated when computing means and variances.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum NanHandling {
	/// NaN entries take part in the arithmetic exactly like any other value.
	Propagate,
	/// NaN entries are skipped entirely and are not counted toward the number of entries.
	Ignore,
}
13
14#[inline(always)]
15#[math]
16fn from_usize<T: RealField>(n: usize) -> T {
17	from_f64::<T>(n as u32 as f64) + (from_f64::<T>((n as u64 - (n as u32 as u64)) as f64))
18}
19
20#[inline(always)]
21fn reduce<T: ComplexField, S: pulp::Simd>(non_nan_count: T::SimdIndex<S>) -> usize {
22	let slice: &[T::Index] = bytemuck::cast_slice(core::slice::from_ref(&non_nan_count));
23
24	let mut acc = 0usize;
25	for &count in slice {
26		acc += count.zx();
27	}
28	acc
29}
30
/// SIMD kernel for [`NanHandling::Ignore`] row-wise means over a matrix whose rows are
/// contiguous in memory (`ContiguousFwd` column stride).
///
/// For every row `i`, `out[i]` is set to the sum of the row's non-NaN entries multiplied by the
/// reciprocal of the number of non-NaN entries in that row.
// NOTE(review): a row with no non-NaN entries computes `sum * recip(0)`; for IEEE floats this is
// presumably NaN — confirm against `recip` for the scalar type.
fn col_mean_row_major_ignore_nan<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T, usize, usize, isize, ContiguousFwd>) {
	struct Impl<'a, T: ComplexField> {
		out: ColMut<'a, T>,
		mat: MatRef<'a, T, usize, usize, isize, ContiguousFwd>,
	}

	impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
			let Self { out, mat } = self;
			with_dim!(M, mat.nrows());
			with_dim!(N, mat.ncols());

			// SIMD context covering a single row of length `N`.
			let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), N);

			let mut out = out.as_row_shape_mut(M);
			let mat = mat.as_shape(M, N);
			// Partition of a row into: optional head chunk, batches of 4 chunks,
			// leftover single chunks, optional tail chunk. Reused for every row.
			let indices = simd.batch_indices::<4>();

			// Lanes outside the head/tail masks are replaced with this NaN vector, so that
			// `process` neither sums nor counts them.
			let nan = simd.splat(&nan::<T>());
			for i in M.indices() {
				let row = mat.row(i).transpose();
				let (head, mut body4, body1, tail) = indices.clone();

				// Scalar running total of non-NaN entries in this row; per-lane counters are
				// periodically folded into it.
				let mut non_nan_count_total = 0usize;

				// Adds `val` into `acc` and bumps `non_nan_count`, but only in lanes where
				// `val == val`, i.e. lanes that are not NaN. NaN lanes keep their old values.
				#[inline(always)]
				fn process<'M, T: ComplexField, S: pulp::Simd>(
					simd: SimdCtx<'M, T, S>,
					acc: T::SimdVec<S>,
					non_nan_count: T::SimdIndex<S>,
					val: T::SimdVec<S>,
				) -> (T::SimdVec<S>, T::SimdIndex<S>) {
					let is_not_nan = (*simd).eq(val, val);

					(
						simd.select(is_not_nan, simd.add(acc, val), acc),
						simd.iselect(is_not_nan, simd.iadd(non_nan_count, simd.isplat(T::Index::truncate(1))), non_nan_count),
					)
				}

				// Four independent accumulators / counters to break the add dependency chain.
				let mut sum0 = simd.splat(&zero::<T>());
				let mut sum1 = simd.splat(&zero::<T>());
				let mut sum2 = simd.splat(&zero::<T>());
				let mut sum3 = simd.splat(&zero::<T>());
				let mut non_nan_count0 = simd.isplat(T::Index::truncate(0));
				let mut non_nan_count1 = simd.isplat(T::Index::truncate(0));
				let mut non_nan_count2 = simd.isplat(T::Index::truncate(0));
				let mut non_nan_count3 = simd.isplat(T::Index::truncate(0));

				if let Some(i) = head {
					(sum0, non_nan_count0) = process(simd, sum0, non_nan_count0, simd.select(simd.head_mask(), simd.read(row, i), nan));
					non_nan_count_total += reduce::<T, S>(non_nan_count0);
					non_nan_count0 = simd.isplat(T::Index::truncate(0));
				}

				// Process the batched body in rounds of at most 256 batches, folding the per-lane
				// counters into the `usize` total after each round. This is presumably to keep
				// the `T::Index` lane counters from overflowing on long rows.
				loop {
					if body4.len() == 0 {
						break;
					}
					for [i0, i1, i2, i3] in (&mut body4).take(256) {
						(sum0, non_nan_count0) = process(simd, sum0, non_nan_count0, simd.read(row, i0));
						(sum1, non_nan_count1) = process(simd, sum1, non_nan_count1, simd.read(row, i1));
						(sum2, non_nan_count2) = process(simd, sum2, non_nan_count2, simd.read(row, i2));
						(sum3, non_nan_count3) = process(simd, sum3, non_nan_count3, simd.read(row, i3));
					}

					non_nan_count0 = simd.iadd(non_nan_count0, non_nan_count1);
					non_nan_count2 = simd.iadd(non_nan_count2, non_nan_count3);
					non_nan_count0 = simd.iadd(non_nan_count0, non_nan_count2);
					non_nan_count_total += reduce::<T, S>(non_nan_count0);
					non_nan_count0 = simd.isplat(T::Index::truncate(0));
					non_nan_count1 = simd.isplat(T::Index::truncate(0));
					non_nan_count2 = simd.isplat(T::Index::truncate(0));
					non_nan_count3 = simd.isplat(T::Index::truncate(0));
				}

				for i in body1 {
					(sum0, non_nan_count0) = process(simd, sum0, non_nan_count0, simd.read(row, i));
				}

				if let Some(i) = tail {
					(sum0, non_nan_count0) = process(simd, sum0, non_nan_count0, simd.select(simd.tail_mask(), simd.read(row, i), nan));
				}
				non_nan_count_total += reduce::<T, S>(non_nan_count0);

				// Pairwise combination of the four accumulators before the horizontal sum.
				sum0 = simd.add(sum0, sum1);
				sum2 = simd.add(sum2, sum3);
				sum0 = simd.add(sum0, sum2);

				let sum = simd.reduce_sum(sum0);

				out[i] = mul_real(&sum, &recip(&from_usize::<T::Real>(non_nan_count_total)));
			}
		}
	}

	T::Arch::default().dispatch(Impl { out, mat });
}
132
/// SIMD kernel for [`NanHandling::Ignore`] row-wise variances about precomputed means, over a
/// matrix whose rows are contiguous in memory (`ContiguousFwd` column stride).
///
/// For every row `i`, the squared deviations of the row's non-NaN entries from `col_mean[i]`
/// are summed. `out[i]` is set as follows:
/// - NaN if the row has no non-NaN entries;
/// - zero if it has exactly one;
/// - otherwise the sum multiplied by `recip(count - 1)`.
fn col_varm_row_major_ignore_nan<T: ComplexField>(
	out: ColMut<'_, T::Real>,
	mat: MatRef<'_, T, usize, usize, isize, ContiguousFwd>,
	col_mean: ColRef<'_, T>,
) {
	struct Impl<'a, T: ComplexField> {
		out: ColMut<'a, T::Real>,
		mat: MatRef<'a, T, usize, usize, isize, ContiguousFwd>,
		col_mean: ColRef<'a, T>,
	}

	impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
			let Self { out, mat, col_mean } = self;
			with_dim!(M, mat.nrows());
			with_dim!(N, mat.ncols());

			// SIMD context covering a single row of length `N`.
			let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), N);

			let mut out = out.as_row_shape_mut(M);
			let mat = mat.as_shape(M, N);
			let col_mean = col_mean.as_row_shape(M);
			// Head / batches-of-4 / single chunks / tail partition of a row, reused per row.
			let indices = simd.batch_indices::<4>();

			// Lanes outside the head/tail masks are replaced with NaN so `process` skips them.
			let nan_v = simd.splat(&nan::<T>());
			for i in M.indices() {
				let row = mat.row(i).transpose();
				let mean = simd.splat(&col_mean[i]);
				let (head, mut body4, body1, tail) = indices.clone();

				let mut non_nan_count = 0usize;

				// In non-NaN lanes (where `val == val`), folds the squared magnitude of
				// `val - mean` into `acc` via `abs2_add` and bumps the lane counter.
				// NaN lanes keep their previous accumulator and counter.
				#[inline(always)]
				fn process<'M, T: ComplexField, S: pulp::Simd>(
					simd: SimdCtx<'M, T, S>,
					mean: T::SimdVec<S>,
					acc: RealReg<T::SimdVec<S>>,
					non_nan_count: T::SimdIndex<S>,
					val: T::SimdVec<S>,
				) -> (RealReg<T::SimdVec<S>>, T::SimdIndex<S>) {
					let is_not_nan = (*simd).eq(val, val);
					let diff = simd.sub(val, mean);

					(
						RealReg(simd.select(is_not_nan, simd.abs2_add(diff, acc).0, acc.0)),
						simd.iselect(is_not_nan, simd.iadd(non_nan_count, simd.isplat(T::Index::truncate(1))), non_nan_count),
					)
				}

				// Four independent accumulators / counters to break the add dependency chain.
				let mut sum0 = RealReg(simd.splat(&zero::<T>()));
				let mut sum1 = RealReg(simd.splat(&zero::<T>()));
				let mut sum2 = RealReg(simd.splat(&zero::<T>()));
				let mut sum3 = RealReg(simd.splat(&zero::<T>()));
				let mut non_nan_count0 = simd.isplat(T::Index::truncate(0));
				let mut non_nan_count1 = simd.isplat(T::Index::truncate(0));
				let mut non_nan_count2 = simd.isplat(T::Index::truncate(0));
				let mut non_nan_count3 = simd.isplat(T::Index::truncate(0));

				if let Some(i) = head {
					(sum0, non_nan_count0) = process(simd, mean, sum0, non_nan_count0, simd.select(simd.head_mask(), simd.read(row, i), nan_v));
					non_nan_count += reduce::<T, S>(non_nan_count0);
					non_nan_count0 = simd.isplat(T::Index::truncate(0));
				}

				// Rounds of at most 256 batches. Per-lane counters are folded into the `usize`
				// total after each round, presumably to keep `T::Index` lanes from overflowing.
				loop {
					if body4.len() == 0 {
						break;
					}
					for [i0, i1, i2, i3] in (&mut body4).take(256) {
						(sum0, non_nan_count0) = process(simd, mean, sum0, non_nan_count0, simd.read(row, i0));
						(sum1, non_nan_count1) = process(simd, mean, sum1, non_nan_count1, simd.read(row, i1));
						(sum2, non_nan_count2) = process(simd, mean, sum2, non_nan_count2, simd.read(row, i2));
						(sum3, non_nan_count3) = process(simd, mean, sum3, non_nan_count3, simd.read(row, i3));
					}

					non_nan_count0 = simd.iadd(non_nan_count0, non_nan_count1);
					non_nan_count2 = simd.iadd(non_nan_count2, non_nan_count3);
					non_nan_count0 = simd.iadd(non_nan_count0, non_nan_count2);
					non_nan_count += reduce::<T, S>(non_nan_count0);
					non_nan_count0 = simd.isplat(T::Index::truncate(0));
					non_nan_count1 = simd.isplat(T::Index::truncate(0));
					non_nan_count2 = simd.isplat(T::Index::truncate(0));
					non_nan_count3 = simd.isplat(T::Index::truncate(0));
				}

				for i in body1 {
					(sum0, non_nan_count0) = process(simd, mean, sum0, non_nan_count0, simd.read(row, i));
				}

				if let Some(i) = tail {
					(sum0, non_nan_count0) = process(simd, mean, sum0, non_nan_count0, simd.select(simd.tail_mask(), simd.read(row, i), nan_v));
				}
				non_nan_count += reduce::<T, S>(non_nan_count0);

				// Pairwise combination of the four accumulators before the horizontal sum.
				sum0 = RealReg(simd.add(sum0.0, sum1.0));
				sum2 = RealReg(simd.add(sum2.0, sum3.0));
				sum0 = RealReg(simd.add(sum0.0, sum2.0));

				// The accumulated squared deviations are read back as a real scalar.
				let sum = real(&simd.reduce_sum(sum0.0));

				// Unbiased (divide by count - 1) estimate, with explicit 0- and 1-entry cases.
				if non_nan_count == 0 {
					out[i] = nan();
				} else if non_nan_count == 1 {
					out[i] = zero();
				} else {
					out[i] = mul_real(&sum, &recip(&from_usize::<T::Real>(non_nan_count - 1)));
				}
			}
		}
	}

	T::Arch::default().dispatch(Impl { out, mat, col_mean });
}
249
/// SIMD kernel for [`NanHandling::Propagate`] row-wise means over a matrix whose rows are
/// contiguous in memory (`ContiguousFwd` column stride).
///
/// For every row `i`, `out[i]` is set to the sum of all entries of the row multiplied by
/// `recip(ncols)`. NaN entries are not filtered out.
fn col_mean_row_major_propagate_nan<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T, usize, usize, isize, ContiguousFwd>) {
	struct Impl<'a, T: ComplexField> {
		out: ColMut<'a, T>,
		mat: MatRef<'a, T, usize, usize, isize, ContiguousFwd>,
	}

	impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
			let Self { out, mat } = self;
			with_dim!(M, mat.nrows());
			with_dim!(N, mat.ncols());

			// SIMD context covering a single row of length `N`.
			let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), N);

			let mut out = out.as_row_shape_mut(M);
			let mat = mat.as_shape(M, N);
			// Head / batches-of-4 / single chunks / tail partition of a row, reused per row.
			let indices = simd.batch_indices::<4>();

			// Every row has the same length, so the scale factor is computed once.
			let n = recip(&from_usize::<T::Real>(*N));
			for i in M.indices() {
				let row = mat.row(i).transpose();
				let (head, body4, body1, tail) = indices.clone();

				// Four independent accumulators to break the add dependency chain.
				let mut sum0 = simd.splat(&zero::<T>());
				let mut sum1 = simd.splat(&zero::<T>());
				let mut sum2 = simd.splat(&zero::<T>());
				let mut sum3 = simd.splat(&zero::<T>());

				// NOTE(review): head/tail chunks are added without an explicit mask here. This
				// relies on `simd.read` zero-filling lanes outside the valid range — confirm
				// in `SimdCtx::read`.
				if let Some(i) = head {
					sum0 = simd.add(sum0, simd.read(row, i));
				}

				for [i0, i1, i2, i3] in body4 {
					sum0 = simd.add(sum0, simd.read(row, i0));
					sum1 = simd.add(sum1, simd.read(row, i1));
					sum2 = simd.add(sum2, simd.read(row, i2));
					sum3 = simd.add(sum3, simd.read(row, i3));
				}

				for i in body1 {
					sum0 = simd.add(sum0, simd.read(row, i));
				}

				if let Some(i) = tail {
					sum0 = simd.add(sum0, simd.read(row, i));
				}

				// Pairwise combination of the four accumulators before the horizontal sum.
				sum0 = simd.add(sum0, sum1);
				sum2 = simd.add(sum2, sum3);
				sum0 = simd.add(sum0, sum2);

				let sum = simd.reduce_sum(sum0);

				out[i] = mul_real(&sum, &n);
			}
		}
	}

	T::Arch::default().dispatch(Impl { out, mat });
}
313
/// SIMD kernel for [`NanHandling::Propagate`] row-wise variances about precomputed means,
/// over a matrix whose rows are contiguous in memory (`ContiguousFwd` column stride).
///
/// For every row `i`, the squared deviations of all entries from `col_mean[i]` are summed and
/// multiplied by `recip(ncols - 1)`. With zero columns `out` is filled with NaN; with one
/// column it is filled with zero. The dispatchers in this file only call this kernel when
/// `ncols() > 1`.
fn col_varm_row_major_propagate_nan<T: ComplexField>(
	out: ColMut<'_, T::Real>,
	mat: MatRef<'_, T, usize, usize, isize, ContiguousFwd>,
	col_mean: ColRef<'_, T>,
) {
	struct Impl<'a, T: ComplexField> {
		out: ColMut<'a, T::Real>,
		mat: MatRef<'a, T, usize, usize, isize, ContiguousFwd>,
		col_mean: ColRef<'a, T>,
	}

	impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
			let Self { out, mat, col_mean } = self;
			with_dim!(M, mat.nrows());
			with_dim!(N, mat.ncols());

			// SIMD context covering a single row of length `N`.
			let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), N);

			let mut out = out.as_row_shape_mut(M);
			let mat = mat.as_shape(M, N);
			let col_mean = col_mean.as_row_shape(M);
			// Head / batches-of-4 / single chunks / tail partition of a row, reused per row.
			let indices = simd.batch_indices::<4>();

			let n = *N;
			if n == 0 {
				out.fill(nan());
			} else if n == 1 {
				out.fill(zero());
			} else {
				// Unbiased (divide by n - 1) scale factor, shared by every row.
				let n = recip(&from_usize::<T::Real>(n - 1));
				for i in M.indices() {
					let row = mat.row(i).transpose();
					let mean = simd.splat(&col_mean[i]);
					let (head, body4, body1, tail) = indices.clone();

					// Four independent accumulators to break the add dependency chain.
					let mut sum0 = simd.splat(&zero::<T>());
					let mut sum1 = simd.splat(&zero::<T>());
					let mut sum2 = simd.splat(&zero::<T>());
					let mut sum3 = simd.splat(&zero::<T>());

					// Lanes outside the head mask keep their previous accumulator value, so
					// out-of-range lanes (where `x - mean` would be nonzero) add nothing.
					if let Some(i0) = head {
						sum0 = simd.select(simd.head_mask(), simd.abs2_add(simd.sub(simd.read(row, i0), mean), RealReg(sum0)).0, sum0);
					}

					for [i0, i1, i2, i3] in body4 {
						sum0 = simd.abs2_add(simd.sub(simd.read(row, i0), mean), RealReg(sum0)).0;
						sum1 = simd.abs2_add(simd.sub(simd.read(row, i1), mean), RealReg(sum1)).0;
						sum2 = simd.abs2_add(simd.sub(simd.read(row, i2), mean), RealReg(sum2)).0;
						sum3 = simd.abs2_add(simd.sub(simd.read(row, i3), mean), RealReg(sum3)).0;
					}

					for i0 in body1 {
						sum0 = simd.abs2_add(simd.sub(simd.read(row, i0), mean), RealReg(sum0)).0;
					}

					// Same masking as the head chunk.
					if let Some(i0) = tail {
						sum0 = simd.select(simd.tail_mask(), simd.abs2_add(simd.sub(simd.read(row, i0), mean), RealReg(sum0)).0, sum0);
					}

					// Pairwise combination of the four accumulators before the horizontal sum.
					sum0 = simd.add(sum0, sum1);
					sum2 = simd.add(sum2, sum3);
					sum0 = simd.add(sum0, sum2);

					// The accumulated squared deviations are read back as a real scalar.
					let sum = real(&simd.reduce_sum(sum0));

					out[i] = mul_real(&sum, &n);
				}
			}
		}
	}

	T::Arch::default().dispatch(Impl { out, mat, col_mean });
}
391#[math]
392fn col_mean_ignore_nan_fallback<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T>) {
393	with_dim!(M, mat.nrows());
394	with_dim!(N, mat.ncols());
395
396	let non_nan_count = &mut *alloc::vec![0usize; *M];
397	let non_nan_count = Array::from_mut(non_nan_count, M);
398
399	let mut out = out.as_row_shape_mut(M);
400	let mat = mat.as_shape(M, N);
401
402	out.fill(zero());
403
404	for j in N.indices() {
405		for i in M.indices() {
406			let val = copy(mat[(i, j)]);
407			let nan = is_nan(val);
408			let val = if nan { zero::<T>() } else { val };
409
410			non_nan_count[i] += (!nan) as usize;
411			out[i] = out[i] + val;
412		}
413	}
414
415	for i in M.indices() {
416		out[i] = mul_real(out[i], recip(from_usize::<T::Real>(non_nan_count[i])));
417	}
418}
419
420#[math]
421fn col_varm_ignore_nan_fallback<T: ComplexField>(out: ColMut<'_, T::Real>, mat: MatRef<'_, T>, col_mean: ColRef<'_, T>) {
422	with_dim!(M, mat.nrows());
423	with_dim!(N, mat.ncols());
424
425	let non_nan_count = &mut *alloc::vec![0usize; *M];
426	let non_nan_count = Array::from_mut(non_nan_count, M);
427
428	let mut out = out.as_row_shape_mut(M);
429	let mat = mat.as_shape(M, N);
430	let col_mean = col_mean.as_row_shape(M);
431
432	out.fill(zero());
433
434	for j in N.indices() {
435		for i in M.indices() {
436			let val = copy(mat[(i, j)]);
437			let col_mean = copy(col_mean[i]);
438			let nan = is_nan(val);
439			let val = if nan { zero::<T::Real>() } else { abs2(val - col_mean) };
440
441			non_nan_count[i] += (!nan) as usize;
442			out[i] = out[i] + val;
443		}
444	}
445
446	for i in M.indices() {
447		let non_nan_count = non_nan_count[i];
448		if non_nan_count == 0 {
449			out[i] = nan();
450		} else if non_nan_count == 1 {
451			out[i] = zero();
452		} else {
453			out[i] = mul_real(out[i], recip(from_usize::<T::Real>(non_nan_count - 1)));
454		}
455	}
456}
457
458#[math]
459fn col_mean_propagate_nan_fallback<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T>) {
460	with_dim!(M, mat.nrows());
461	with_dim!(N, mat.ncols());
462
463	let mut out = out.as_row_shape_mut(M);
464	let mat = mat.as_shape(M, N);
465
466	out.fill(zero());
467
468	for j in N.indices() {
469		for i in M.indices() {
470			out[i] = out[i] + mat[(i, j)];
471		}
472	}
473
474	let n = recip(from_usize::<T::Real>(*N));
475	for i in M.indices() {
476		out[i] = mul_real(out[i], n);
477	}
478}
479
480#[math]
481fn col_varm_propagate_nan_fallback<T: ComplexField>(out: ColMut<'_, T::Real>, mat: MatRef<'_, T>, col_mean: ColRef<'_, T>) {
482	with_dim!(M, mat.nrows());
483	with_dim!(N, mat.ncols());
484
485	let mut out = out.as_row_shape_mut(M);
486	let mat = mat.as_shape(M, N);
487	let col_mean = col_mean.as_row_shape(M);
488
489	out.fill(zero());
490
491	for j in N.indices() {
492		for i in M.indices() {
493			let val = abs2(mat[(i, j)] - col_mean[i]);
494			out[i] = out[i] + val;
495		}
496	}
497
498	let n = *N;
499	if n == 0 {
500		out.fill(nan());
501	} else if n == 1 {
502		out.fill(zero());
503	} else {
504		let n = recip(from_usize::<T::Real>(*N - 1));
505		for i in M.indices() {
506			out[i] = mul_real(out[i], n);
507		}
508	}
509}
510
511fn col_mean_ignore<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T>) {
512	let mat = if mat.col_stride() >= 0 { mat } else { mat.reverse_cols() };
513	let mat = if mat.row_stride() >= 0 { mat } else { mat.reverse_rows() };
514
515	if const { T::SIMD_CAPABILITIES.is_simd() } {
516		if mat.ncols() > 1 && mat.col_stride() == 1 {
517			col_mean_row_major_ignore_nan(out, mat.try_as_row_major().unwrap());
518		} else {
519			col_mean_ignore_nan_fallback(out, mat);
520		}
521	} else {
522		col_mean_ignore_nan_fallback(out, mat);
523	}
524}
525
526fn col_varm_ignore<T: ComplexField>(out: ColMut<'_, T::Real>, mat: MatRef<'_, T>, col_mean: ColRef<'_, T>) {
527	let mat = if mat.col_stride() >= 0 { mat } else { mat.reverse_cols() };
528	let mat = if mat.row_stride() >= 0 { mat } else { mat.reverse_rows() };
529
530	if const { T::SIMD_CAPABILITIES.is_simd() } {
531		if mat.ncols() > 1 && mat.col_stride() == 1 {
532			col_varm_row_major_ignore_nan(out, mat.try_as_row_major().unwrap(), col_mean);
533		} else {
534			col_varm_ignore_nan_fallback(out, mat, col_mean);
535		}
536	} else {
537		col_varm_ignore_nan_fallback(out, mat, col_mean);
538	}
539}
540
541fn col_mean_propagate<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T>) {
542	let mat = if mat.col_stride() >= 0 { mat } else { mat.reverse_cols() };
543	let mat = if mat.row_stride() >= 0 { mat } else { mat.reverse_rows() };
544
545	if const { T::SIMD_CAPABILITIES.is_simd() } {
546		if mat.ncols() > 1 && mat.col_stride() == 1 {
547			col_mean_row_major_propagate_nan(out, mat.try_as_row_major().unwrap());
548		} else {
549			col_mean_propagate_nan_fallback(out, mat);
550		}
551	} else {
552		col_mean_propagate_nan_fallback(out, mat);
553	}
554}
555
556fn col_varm_propagate<T: ComplexField>(out: ColMut<'_, T::Real>, mat: MatRef<'_, T>, col_mean: ColRef<'_, T>) {
557	let mat = if mat.col_stride() >= 0 { mat } else { mat.reverse_cols() };
558	let mat = if mat.row_stride() >= 0 { mat } else { mat.reverse_rows() };
559
560	if const { T::SIMD_CAPABILITIES.is_simd() } {
561		if mat.ncols() > 1 && mat.col_stride() == 1 {
562			col_varm_row_major_propagate_nan(out, mat.try_as_row_major().unwrap(), col_mean);
563		} else {
564			col_varm_propagate_nan_fallback(out, mat, col_mean);
565		}
566	} else {
567		col_varm_propagate_nan_fallback(out, mat, col_mean);
568	}
569}
570
571/// computes the mean of the columns of `mat` and stores the result in `out`
572#[track_caller]
573pub fn col_mean<T: ComplexField>(out: ColMut<'_, T>, mat: MatRef<'_, T>, nan: NanHandling) {
574	assert!(all(out.nrows() == mat.nrows()));
575
576	match nan {
577		NanHandling::Propagate => col_mean_propagate(out, mat),
578		NanHandling::Ignore => col_mean_ignore(out, mat),
579	}
580}
581
582/// computes the mean of the rows of `mat` and stores the result in `out`
583#[track_caller]
584pub fn row_mean<T: ComplexField>(out: RowMut<'_, T>, mat: MatRef<'_, T>, nan: NanHandling) {
585	assert!(all(out.ncols() == mat.ncols()));
586	col_mean(out.transpose_mut(), mat.transpose(), nan);
587}
588
589/// computes the variance of the columns of `mat` and stores the result in `out`
590#[track_caller]
591pub fn col_varm<T: ComplexField>(out: ColMut<'_, T::Real>, mat: MatRef<'_, T>, col_mean: ColRef<'_, T>, nan: NanHandling) {
592	assert!(all(out.nrows() == mat.nrows(), col_mean.nrows() == mat.nrows()));
593
594	match nan {
595		NanHandling::Propagate => col_varm_propagate(out, mat, col_mean),
596		NanHandling::Ignore => col_varm_ignore(out, mat, col_mean),
597	}
598}
599
600/// computes the variance of the rows of `mat` and stores the result in `out`
601#[track_caller]
602pub fn row_varm<T: ComplexField>(out: RowMut<'_, T::Real>, mat: MatRef<'_, T>, row_mean: RowRef<'_, T>, nan: NanHandling) {
603	assert!(all(out.ncols() == mat.ncols(), row_mean.ncols() == mat.ncols()));
604
605	col_varm(out.transpose_mut(), mat.transpose(), row_mean.transpose(), nan);
606}
607
#[cfg(test)]
mod tests {
	use super::*;
	use equator::assert;

	/// Runs all four public entry points on `mat` with the given NaN policy.
	///
	/// Returns `(row_mean, row_var, col_mean, col_var)`, where each variance is taken about the
	/// matching mean.
	fn mean_and_var<T: ComplexField>(mat: MatRef<'_, T>, nan: NanHandling) -> (Row<T>, Row<T::Real>, Col<T>, Col<T::Real>) {
		let mut row_mean = Row::zeros(mat.ncols());
		let mut row_var = Row::zeros(mat.ncols());
		super::row_mean(row_mean.as_mut(), mat, nan);
		super::row_varm(row_var.as_mut(), mat, row_mean.as_ref(), nan);

		let mut col_mean = Col::zeros(mat.nrows());
		let mut col_var = Col::zeros(mat.nrows());
		super::col_mean(col_mean.as_mut(), mat, nan);
		super::col_varm(col_var.as_mut(), mat, col_mean.as_ref(), nan);

		(row_mean, row_var, col_mean, col_var)
	}

	#[test]
	fn test_meanvar_propagate() {
		let c32 = c32::new;
		let A = mat![[c32(1.2, 2.3), c32(3.4, 1.2)], [c32(1.7, -1.0), c32(-3.8, 1.95)],];

		let (row_mean, row_var, col_mean, col_var) = mean_and_var(A.as_ref(), NanHandling::Propagate);

		assert!(row_mean == row![(A[(0, 0)] + A[(1, 0)]) / 2.0, (A[(0, 1)] + A[(1, 1)]) / 2.0,]);
		assert!(
			row_var
				== row![
					abs2(&(A[(0, 0)] - row_mean[0])) + abs2(&(A[(1, 0)] - row_mean[0])),
					abs2(&(A[(0, 1)] - row_mean[1])) + abs2(&(A[(1, 1)] - row_mean[1])),
				]
		);

		assert!(col_mean == col![(A[(0, 0)] + A[(0, 1)]) / 2.0, (A[(1, 0)] + A[(1, 1)]) / 2.0,]);
		assert!(
			col_var
				== col![
					abs2(&(A[(0, 0)] - col_mean[0])) + abs2(&(A[(0, 1)] - col_mean[0])),
					abs2(&(A[(1, 0)] - col_mean[1])) + abs2(&(A[(1, 1)] - col_mean[1])),
				]
		);
	}

	#[test]
	fn test_meanvar_ignore_nan_nonan_c32() {
		let c32 = c32::new;
		let A = mat![[c32(1.2, 2.3), c32(3.4, 1.2)], [c32(1.7, -1.0), c32(-3.8, 1.95)],];

		let (row_mean, row_var, col_mean, col_var) = mean_and_var(A.as_ref(), NanHandling::Ignore);

		assert!(row_mean == row![(A[(0, 0)] + A[(1, 0)]) / 2.0, (A[(0, 1)] + A[(1, 1)]) / 2.0,]);
		assert!(
			row_var
				== row![
					abs2(&(A[(0, 0)] - row_mean[0])) + abs2(&(A[(1, 0)] - row_mean[0])),
					abs2(&(A[(0, 1)] - row_mean[1])) + abs2(&(A[(1, 1)] - row_mean[1])),
				]
		);

		assert!(col_mean == col![(A[(0, 0)] + A[(0, 1)]) / 2.0, (A[(1, 0)] + A[(1, 1)]) / 2.0,]);
		assert!(
			col_var
				== col![
					abs2(&(A[(0, 0)] - col_mean[0])) + abs2(&(A[(0, 1)] - col_mean[0])),
					abs2(&(A[(1, 0)] - col_mean[1])) + abs2(&(A[(1, 1)] - col_mean[1])),
				]
		);
	}

	#[test]
	fn test_meanvar_ignore_nan_yesnan_c32() {
		let c32 = c32::new;
		let nan = f32::NAN;
		// A NaN in either component makes the whole complex entry count as missing.
		let A = mat![[c32(1.2, nan), c32(3.4, 1.2)], [c32(1.7, -1.0), c32(-3.8, 1.95)],];

		let (row_mean, row_var, col_mean, col_var) = mean_and_var(A.as_ref(), NanHandling::Ignore);

		assert!(row_mean == row![A[(1, 0)] / 1.0, (A[(0, 1)] + A[(1, 1)]) / 2.0,]);
		assert!(
			row_var
				== row![
					abs2(&(A[(1, 0)] - row_mean[0])),
					abs2(&(A[(0, 1)] - row_mean[1])) + abs2(&(A[(1, 1)] - row_mean[1])),
				]
		);

		assert!(col_mean == col![A[(0, 1)] / 1.0, (A[(1, 0)] + A[(1, 1)]) / 2.0,]);
		assert!(
			col_var
				== col![
					abs2(&(A[(0, 1)] - col_mean[0])),
					abs2(&(A[(1, 0)] - col_mean[1])) + abs2(&(A[(1, 1)] - col_mean[1])),
				]
		);
	}

	#[test]
	fn test_meanvar_ignore_nan_nonan_c64() {
		let c64 = c64::new;
		let A = mat![[c64(1.2, 2.3), c64(3.4, 1.2)], [c64(1.7, -1.0), c64(-3.8, 1.95)],];

		let (row_mean, row_var, col_mean, col_var) = mean_and_var(A.as_ref(), NanHandling::Ignore);

		assert!(row_mean == row![(A[(0, 0)] + A[(1, 0)]) / 2.0, (A[(0, 1)] + A[(1, 1)]) / 2.0,]);
		assert!(
			row_var
				== row![
					abs2(&(A[(0, 0)] - row_mean[0])) + abs2(&(A[(1, 0)] - row_mean[0])),
					abs2(&(A[(0, 1)] - row_mean[1])) + abs2(&(A[(1, 1)] - row_mean[1])),
				]
		);

		assert!(col_mean == col![(A[(0, 0)] + A[(0, 1)]) / 2.0, (A[(1, 0)] + A[(1, 1)]) / 2.0,]);
		assert!(
			col_var
				== col![
					abs2(&(A[(0, 0)] - col_mean[0])) + abs2(&(A[(0, 1)] - col_mean[0])),
					abs2(&(A[(1, 0)] - col_mean[1])) + abs2(&(A[(1, 1)] - col_mean[1])),
				]
		);
	}

	#[test]
	fn test_meanvar_ignore_nan_yesnan_c64() {
		let c64 = c64::new;
		let nan = f64::NAN;
		// A NaN in either component makes the whole complex entry count as missing.
		let A = mat![[c64(1.2, nan), c64(3.4, 1.2)], [c64(1.7, -1.0), c64(-3.8, 1.95)],];

		let (row_mean, row_var, col_mean, col_var) = mean_and_var(A.as_ref(), NanHandling::Ignore);

		assert!(row_mean == row![A[(1, 0)] / 1.0, (A[(0, 1)] + A[(1, 1)]) / 2.0,]);
		assert!(
			row_var
				== row![
					abs2(&(A[(1, 0)] - row_mean[0])),
					abs2(&(A[(0, 1)] - row_mean[1])) + abs2(&(A[(1, 1)] - row_mean[1])),
				]
		);

		assert!(col_mean == col![A[(0, 1)] / 1.0, (A[(1, 0)] + A[(1, 1)]) / 2.0,]);
		assert!(
			col_var
				== col![
					abs2(&(A[(0, 1)] - col_mean[0])),
					abs2(&(A[(1, 0)] - col_mean[1])) + abs2(&(A[(1, 1)] - col_mean[1])),
				]
		);
	}
}