1#![allow(clippy::manual_div_ceil, clippy::needless_range_loop)]
2#![allow(dead_code)]
14#![allow(clippy::too_many_arguments)]
15
16use rayon::prelude::*;
17
/// An owned, contiguous run of `f64` values that is reduced or scanned as one unit.
///
/// Every operation here walks `data` serially from front to back.
#[derive(Debug, Clone)]
pub struct Tile {
    /// The tile's element storage.
    pub data: Vec<f64>,
}

impl Tile {
    /// Creates a tile holding a copy of `s`.
    pub fn from_slice(s: &[f64]) -> Self {
        let data = s.to_vec();
        Self { data }
    }

    /// Returns the sum of all elements.
    pub fn reduce_sum(&self) -> f64 {
        self.data.iter().sum()
    }

    /// Returns the largest element, or `-inf` for an empty tile.
    ///
    /// NaN elements are skipped because `f64::max` prefers the non-NaN operand.
    pub fn reduce_max(&self) -> f64 {
        self.data
            .iter()
            .fold(f64::NEG_INFINITY, |best, &x| f64::max(best, x))
    }

    /// Returns the smallest element, or `+inf` for an empty tile.
    pub fn reduce_min(&self) -> f64 {
        self.data
            .iter()
            .fold(f64::INFINITY, |best, &x| f64::min(best, x))
    }

    /// Replaces each element with the sum of all elements strictly before it.
    pub fn exclusive_scan_inplace(&mut self) {
        let mut running = 0.0;
        for slot in self.data.iter_mut() {
            // Store the prefix so far and carry the element's old value forward.
            let previous = std::mem::replace(slot, running);
            running += previous;
        }
    }

    /// Replaces each element with the sum of itself and all elements before it.
    pub fn inclusive_scan_inplace(&mut self) {
        let mut running = 0.0;
        for slot in self.data.iter_mut() {
            running += *slot;
            *slot = running;
        }
    }
}
69
70#[derive(Debug, Clone)]
79pub struct TiledReducer {
80 pub tile_size: usize,
82}
83
84impl TiledReducer {
85 pub fn new(tile_size: usize) -> Self {
87 assert!(tile_size > 0, "tile_size must be > 0");
88 Self { tile_size }
89 }
90
91 pub fn sum(&self, data: &[f64]) -> f64 {
93 if data.is_empty() {
94 return 0.0;
95 }
96 let tile_sums: Vec<f64> = data
97 .par_chunks(self.tile_size)
98 .map(|chunk| chunk.iter().copied().sum::<f64>())
99 .collect();
100 tile_sums.iter().copied().sum()
101 }
102
103 pub fn max(&self, data: &[f64]) -> f64 {
105 if data.is_empty() {
106 return f64::NEG_INFINITY;
107 }
108 let tile_maxs: Vec<f64> = data
109 .par_chunks(self.tile_size)
110 .map(|chunk| chunk.iter().copied().fold(f64::NEG_INFINITY, f64::max))
111 .collect();
112 tile_maxs.iter().copied().fold(f64::NEG_INFINITY, f64::max)
113 }
114
115 pub fn min(&self, data: &[f64]) -> f64 {
117 if data.is_empty() {
118 return f64::INFINITY;
119 }
120 let tile_mins: Vec<f64> = data
121 .par_chunks(self.tile_size)
122 .map(|chunk| chunk.iter().copied().fold(f64::INFINITY, f64::min))
123 .collect();
124 tile_mins.iter().copied().fold(f64::INFINITY, f64::min)
125 }
126
127 pub fn dot(&self, a: &[f64], b: &[f64]) -> f64 {
129 assert_eq!(a.len(), b.len(), "dot product requires equal-length inputs");
130 a.par_iter()
131 .zip(b.par_iter())
132 .map(|(&ai, &bi)| ai * bi)
133 .sum()
134 }
135
136 pub fn tile_sums(&self, data: &[f64]) -> Vec<f64> {
138 data.par_chunks(self.tile_size)
139 .map(|chunk| chunk.iter().copied().sum::<f64>())
140 .collect()
141 }
142}
143
/// Exclusive prefix sum that restarts at zero wherever `flags[i]` is `true`.
///
/// # Panics
/// Panics if `data` and `flags` have different lengths.
pub fn segmented_exclusive_scan(data: &[f64], flags: &[bool]) -> Vec<f64> {
    assert_eq!(
        data.len(),
        flags.len(),
        "data and flags must be same length"
    );
    let mut running = 0.0;
    data.iter()
        .zip(flags)
        .map(|(&x, &opens_segment)| {
            if opens_segment {
                running = 0.0;
            }
            let prefix = running;
            running += x;
            prefix
        })
        .collect()
}
169
/// Inclusive prefix sum that restarts wherever `flags[i]` is `true`. The flagged element is
/// the first element of its new segment.
///
/// # Panics
/// Panics if `data` and `flags` have different lengths.
pub fn segmented_inclusive_scan(data: &[f64], flags: &[bool]) -> Vec<f64> {
    assert_eq!(data.len(), flags.len());
    let mut running = 0.0;
    data.iter()
        .zip(flags)
        .map(|(&x, &opens_segment)| {
            if opens_segment {
                running = 0.0;
            }
            running += x;
            running
        })
        .collect()
}
184
/// Sums each flag-delimited segment and returns one total per segment.
///
/// Index 0 always opens a segment, whatever `flags[0]` says. A later `true` flag closes the
/// current segment and opens a new one. An empty input yields a single `0.0` total.
///
/// # Panics
/// Panics if `data` and `flags` have different lengths.
pub fn segmented_reduce_sum(data: &[f64], flags: &[bool]) -> Vec<f64> {
    assert_eq!(data.len(), flags.len());
    let mut totals: Vec<f64> = Vec::new();
    let mut running = 0.0;
    for (i, (&x, &flag)) in data.iter().zip(flags).enumerate() {
        if i > 0 && flag {
            totals.push(running);
            running = 0.0;
        }
        running += x;
    }
    totals.push(running);
    totals
}
200
201pub fn filter_compact<T, F>(data: &[T], predicate: F) -> Vec<T>
210where
211 T: Clone + Send + Sync,
212 F: Fn(&T) -> bool + Sync,
213{
214 data.par_iter().filter(|x| predicate(x)).cloned().collect()
215}
216
/// Splits `data` into `(matching, non_matching)` clones. Each side keeps the input order.
pub fn partition_stable<T, F>(data: &[T], predicate: F) -> (Vec<T>, Vec<T>)
where
    T: Clone,
    F: Fn(&T) -> bool,
{
    let mut accepted: Vec<T> = Vec::new();
    let mut rejected: Vec<T> = Vec::new();
    for item in data {
        let bucket = if predicate(item) {
            &mut accepted
        } else {
            &mut rejected
        };
        bucket.push(item.clone());
    }
    (accepted, rejected)
}
234
/// Writes `src[k]` to `dst[indices[k]]` for every `k`, in order.
///
/// When an index repeats, the later write wins.
///
/// # Panics
/// Panics if `src` and `indices` differ in length, or if any index is out of bounds for `dst`.
pub fn scatter(dst: &mut [f64], src: &[f64], indices: &[usize]) {
    assert_eq!(
        src.len(),
        indices.len(),
        "src and indices must have equal length"
    );
    indices
        .iter()
        .zip(src)
        .for_each(|(&slot, &value)| dst[slot] = value);
}
252
/// Builds a vector whose `k`-th element is `src[indices[k]]`.
///
/// # Panics
/// Panics if any index is out of bounds for `src`.
pub fn gather(src: &[f64], indices: &[usize]) -> Vec<f64> {
    let mut picked = Vec::with_capacity(indices.len());
    for &i in indices {
        picked.push(src[i]);
    }
    picked
}
257
/// Adds `src[k]` into `dst[indices[k]]` for every `k`. Repeated indices accumulate.
///
/// This runs serially; the name mirrors the GPU atomic-add idiom it models.
///
/// # Panics
/// Panics if `src` and `indices` differ in length, or if any index is out of bounds for `dst`.
pub fn atomic_scatter_add(dst: &mut [f64], src: &[f64], indices: &[usize]) {
    assert_eq!(src.len(), indices.len());
    indices
        .iter()
        .zip(src)
        .for_each(|(&slot, &value)| dst[slot] += value);
}
267
/// Number of lanes in a warp (2^5 = 32).
pub const WARP_SIZE: usize = 1 << 5;
274
/// Returns a vector the same length as `lanes` in which every entry is `lanes[leader]`.
///
/// # Panics
/// Panics if `leader` is not a valid lane index.
pub fn warp_broadcast(lanes: &[f64], leader: usize) -> Vec<f64> {
    assert!(leader < lanes.len(), "leader lane out of range");
    let value = lanes[leader];
    lanes.iter().map(|_| value).collect()
}
280
/// Sums all lanes and returns that total in every lane position.
pub fn warp_reduce_sum(lanes: &[f64]) -> Vec<f64> {
    let total = lanes.iter().copied().sum::<f64>();
    lanes.iter().map(|_| total).collect()
}
286
/// Exclusive prefix sum across lanes. Lane `i` receives the sum of lanes `0..i`.
pub fn warp_exclusive_scan(lanes: &[f64]) -> Vec<f64> {
    lanes
        .iter()
        .scan(0.0, |running, &v| {
            let prefix = *running;
            *running += v;
            Some(prefix)
        })
        .collect()
}
297
/// Returns `true` if at least one lane satisfies `pred`. An empty warp gives `false`.
pub fn warp_vote_any<F: Fn(f64) -> bool>(lanes: &[f64], pred: F) -> bool {
    // "Some lane passes" is equivalent to "not every lane fails" (short-circuits identically).
    !lanes.iter().all(|&v| !pred(v))
}
302
/// Returns `true` if every lane satisfies `pred`. An empty warp gives `true`.
pub fn warp_vote_all<F: Fn(f64) -> bool>(lanes: &[f64], pred: F) -> bool {
    // "Every lane passes" is equivalent to "no lane fails".
    !lanes.iter().any(|&v| !pred(v))
}
307
/// Estimates the fraction of an SM's thread slots occupied by resident work-groups.
///
/// The number of resident work-groups is the smallest of four limits:
/// - the thread limit (`max_threads_per_sm / wg_size`),
/// - the register limit,
/// - the shared-memory limit,
/// - the hard per-SM cap `max_wgs_per_sm`.
///
/// A zero `regs_per_thread` or `shared_mem_bytes` means that resource imposes no limit.
///
/// Returns a value in `[0.0, 1.0]`. Returns `0.0` when `wg_size` or `max_threads_per_sm`
/// is zero.
#[allow(clippy::too_many_arguments)]
pub fn estimate_occupancy(
    wg_size: usize,
    regs_per_thread: usize,
    shared_mem_bytes: usize,
    max_wgs_per_sm: usize,
    max_threads_per_sm: usize,
    max_regs_per_sm: usize,
    max_smem_per_sm: usize,
) -> f64 {
    // No work-group or no thread slots means nothing can be resident. Without the second
    // guard the final ratio would be 0.0 / 0.0 = NaN, and `NaN.min(1.0)` returns 1.0,
    // reporting full occupancy.
    if wg_size == 0 || max_threads_per_sm == 0 {
        return 0.0;
    }
    let by_threads = max_threads_per_sm / wg_size;
    // Saturate so an absurd register request yields zero resident groups instead of
    // overflowing. A zero product (regs_per_thread == 0) means "unconstrained".
    let by_regs = max_regs_per_sm
        .checked_div(regs_per_thread.saturating_mul(wg_size))
        .unwrap_or(max_wgs_per_sm);
    let by_smem = max_smem_per_sm
        .checked_div(shared_mem_bytes)
        .unwrap_or(max_wgs_per_sm);
    let actual_wgs = by_threads.min(by_regs).min(by_smem).min(max_wgs_per_sm);
    // actual_wgs <= max_threads_per_sm / wg_size, so this product cannot overflow.
    let active_threads = actual_wgs * wg_size;
    (active_threads as f64 / max_threads_per_sm as f64).min(1.0)
}
351
352#[derive(Debug, Clone)]
358pub struct GridReduceStats {
359 pub count: usize,
361 pub sum: f64,
363 pub mean: f64,
365 pub variance: f64,
367 pub min: f64,
369 pub max: f64,
371}
372
373impl GridReduceStats {
374 pub fn compute(data: &[f64]) -> Self {
376 let count = data.len();
377 if count == 0 {
378 return Self {
379 count: 0,
380 sum: 0.0,
381 mean: 0.0,
382 variance: 0.0,
383 min: 0.0,
384 max: 0.0,
385 };
386 }
387 let sum: f64 = data.par_iter().copied().sum();
388 let mean = sum / count as f64;
389 let variance: f64 = data
390 .par_iter()
391 .map(|&v| (v - mean) * (v - mean))
392 .sum::<f64>()
393 / count as f64;
394 let min = data.par_iter().copied().reduce(|| f64::INFINITY, f64::min);
395 let max = data
396 .par_iter()
397 .copied()
398 .reduce(|| f64::NEG_INFINITY, f64::max);
399 Self {
400 count,
401 sum,
402 mean,
403 variance,
404 min,
405 max,
406 }
407 }
408
409 pub fn std_dev(&self) -> f64 {
411 self.variance.sqrt()
412 }
413}
414
/// Fixed-width histogram over `[lo, hi]`.
///
/// Samples outside the range are clamped into the first or last bin. NaN samples land in
/// bin 0, because a float-to-int `as` cast maps NaN to 0.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Per-bin sample counts.
    pub bins: Vec<u64>,
    /// Lower edge of the first bin.
    pub lo: f64,
    /// Upper edge of the last bin.
    pub hi: f64,
}

impl Histogram {
    /// Bins `data` into `n_bins` equal-width bins spanning `[lo, hi]`.
    ///
    /// # Panics
    /// Panics if `n_bins == 0` or if `lo < hi` does not hold.
    pub fn compute(data: &[f64], lo: f64, hi: f64, n_bins: usize) -> Self {
        assert!(n_bins > 0, "n_bins must be > 0");
        assert!(lo < hi, "lo must be < hi");
        let width = hi - lo;
        let last_bin = n_bins as isize - 1;
        let mut bins = vec![0u64; n_bins];
        data.iter().for_each(|&v| {
            let raw = ((v - lo) / width * n_bins as f64) as isize;
            bins[raw.clamp(0, last_bin) as usize] += 1;
        });
        Self { bins, lo, hi }
    }

    /// Total number of binned samples.
    pub fn total(&self) -> u64 {
        self.bins.iter().copied().sum()
    }

    /// Midpoint of bin `i`. This is not bounds-checked against `bins.len()`.
    pub fn bin_centre(&self, i: usize) -> f64 {
        let bin_width = (self.hi - self.lo) / self.bins.len() as f64;
        self.lo + (i as f64 + 0.5) * bin_width
    }

    /// Index of the most populated bin. On ties the highest index wins. Returns 0 when
    /// there are no bins.
    pub fn mode_bin(&self) -> usize {
        let mut best = 0;
        for (i, &count) in self.bins.iter().enumerate() {
            // `>=` keeps moving to later ties, matching `Iterator::max_by_key`.
            if count >= self.bins[best] {
                best = i;
            }
        }
        best
    }

    /// Mean estimated by placing each sample at its bin's centre. Returns `0.0` when the
    /// histogram is empty.
    pub fn approx_mean(&self) -> f64 {
        let total = self.total();
        if total == 0 {
            return 0.0;
        }
        let weighted: f64 = (0..self.bins.len())
            .zip(self.bins.iter())
            .map(|(i, &count)| self.bin_centre(i) * count as f64)
            .sum();
        weighted / total as f64
    }
}
486
487pub fn norm_l1(data: &[f64]) -> f64 {
493 data.par_iter().map(|&v| v.abs()).sum()
494}
495
496pub fn norm_l2(data: &[f64]) -> f64 {
498 let sq: f64 = data.par_iter().map(|&v| v * v).sum();
499 sq.sqrt()
500}
501
502pub fn norm_linf(data: &[f64]) -> f64 {
504 data.par_iter()
505 .map(|&v| v.abs())
506 .reduce(|| 0.0_f64, f64::max)
507}
508
509pub fn dist_sq_l2(a: &[f64], b: &[f64]) -> f64 {
511 assert_eq!(a.len(), b.len());
512 a.par_iter()
513 .zip(b.par_iter())
514 .map(|(&ai, &bi)| (ai - bi) * (ai - bi))
515 .sum()
516}
517
518pub fn dist_l2(a: &[f64], b: &[f64]) -> f64 {
520 dist_sq_l2(a, b).sqrt()
521}
522
/// Population covariance matrix of `n` samples with `d` features.
///
/// `data` is row-major: sample `r` occupies `data[r * d..(r + 1) * d]`. The result is a
/// row-major `d x d` matrix divided by `n`. With `n == 0` every entry is NaN (0/0).
///
/// # Panics
/// Panics if `data.len() != n * d`.
pub fn covariance_matrix(data: &[f64], n: usize, d: usize) -> Vec<f64> {
    assert_eq!(data.len(), n * d, "data must have n*d elements");
    let sample = |r: usize| &data[r * d..(r + 1) * d];

    // Column means.
    let mut mean = vec![0.0f64; d];
    for r in 0..n {
        for (m, &x) in mean.iter_mut().zip(sample(r)) {
            *m += x;
        }
    }
    mean.iter_mut().for_each(|m| *m /= n as f64);

    // Accumulate outer products of the centred samples.
    let mut cov = vec![0.0f64; d * d];
    for r in 0..n {
        let row = sample(r);
        for i in 0..d {
            let xi = row[i] - mean[i];
            for j in 0..d {
                cov[i * d + j] += xi * (row[j] - mean[j]);
            }
        }
    }
    cov.iter_mut().for_each(|c| *c /= n as f64);
    cov
}
560
/// Extracts the main diagonal of a row-major `d x d` matrix.
///
/// # Panics
/// Panics if `mat` holds fewer than `d * d` elements (for `d > 0`).
pub fn matrix_diagonal(mat: &[f64], d: usize) -> Vec<f64> {
    // Diagonal entries are `d + 1` apart in row-major storage.
    let stride = d + 1;
    (0..d).map(|i| mat[i * stride]).collect()
}
565
/// Multiplies the row-major `m x n` matrix `a` by the length-`n` vector `x`.
///
/// # Panics
/// Panics if `a.len() != m * n` or `x.len() != n`.
pub fn matvec(a: &[f64], m: usize, n: usize, x: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), m * n);
    assert_eq!(x.len(), n);
    let mut y = Vec::with_capacity(m);
    for row in 0..m {
        let coeffs = &a[row * n..(row + 1) * n];
        let dot: f64 = coeffs.iter().zip(x.iter()).map(|(&c, &v)| c * v).sum();
        y.push(dot);
    }
    y
}
585
/// Row-major matrix product `C = A * B`, with `A` of size `m x k` and `B` of size `k x n`.
///
/// Uses i-p-j loop order, so the innermost loop streams contiguous rows of `B` and `C`.
///
/// # Panics
/// Panics if `a.len() != m * k` or `b.len() != k * n`.
#[allow(clippy::too_many_arguments)]
pub fn matmul(a: &[f64], m: usize, k: usize, b: &[f64], n: usize) -> Vec<f64> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut c = vec![0.0f64; m * n];
    for i in 0..m {
        let c_row = &mut c[i * n..(i + 1) * n];
        let a_row = &a[i * k..(i + 1) * k];
        for (p, &scale) in a_row.iter().enumerate() {
            let b_row = &b[p * n..(p + 1) * n];
            for (dst, &bv) in c_row.iter_mut().zip(b_row) {
                *dst += scale * bv;
            }
        }
    }
    c
}
603
/// Single-pass running mean/variance accumulator using Welford's update.
#[derive(Debug, Clone, Default)]
pub struct WelfordStats {
    /// Number of samples folded in so far.
    pub count: u64,
    /// Running mean of the samples so far.
    pub mean: f64,
    /// Running sum of squared deviations from the mean.
    m2: f64,
}

impl WelfordStats {
    /// Folds a single sample into the running statistics.
    pub fn update(&mut self, x: f64) {
        self.count += 1;
        let n = self.count as f64;
        // Deviations measured before and after the mean moves.
        let before = x - self.mean;
        self.mean += before / n;
        let after = x - self.mean;
        self.m2 += before * after;
    }

    /// Population variance (`m2 / count`). Returns `0.0` with fewer than two samples.
    pub fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => self.m2 / n as f64,
        }
    }

    /// Sample variance (`m2 / (count - 1)`). Returns `0.0` with fewer than two samples.
    pub fn sample_variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => self.m2 / (n - 1) as f64,
        }
    }

    /// Population standard deviation.
    pub fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }
}
651
652pub fn parallel_histogram(
661 data: &[f64],
662 lo: f64,
663 hi: f64,
664 n_bins: usize,
665 n_workers: usize,
666) -> Vec<u64> {
667 assert!(n_bins > 0);
668 assert!(lo < hi);
669 let chunk_size = data.len().div_ceil(n_workers.max(1));
670 if chunk_size == 0 {
671 return vec![0u64; n_bins];
672 }
673 let partial: Vec<Vec<u64>> = data
674 .par_chunks(chunk_size)
675 .map(|chunk| {
676 let width = hi - lo;
677 let mut bins = vec![0u64; n_bins];
678 for &v in chunk {
679 let idx = ((v - lo) / width * n_bins as f64) as isize;
680 let idx = idx.max(0).min(n_bins as isize - 1) as usize;
681 bins[idx] += 1;
682 }
683 bins
684 })
685 .collect();
686
687 let mut merged = vec![0u64; n_bins];
689 for part in &partial {
690 for (m, &p) in merged.iter_mut().zip(part.iter()) {
691 *m += p;
692 }
693 }
694 merged
695}
696
/// Exclusive prefix sum over `u64`. The running total saturates at `u64::MAX`.
pub fn exclusive_scan_u64(data: &[u64]) -> Vec<u64> {
    data.iter()
        .scan(0u64, |running, &v| {
            let prefix = *running;
            *running = running.saturating_add(v);
            Some(prefix)
        })
        .collect()
}
711
/// Inclusive prefix sum over `u64`. The running total saturates at `u64::MAX`.
pub fn inclusive_scan_u64(data: &[u64]) -> Vec<u64> {
    data.iter()
        .scan(0u64, |running, &v| {
            *running = running.saturating_add(v);
            Some(*running)
        })
        .collect()
}
722
/// Full 1-D linear convolution. The output has length `signal.len() + kernel.len() - 1`.
///
/// Returns an empty vector if either input is empty.
pub fn convolve1d(signal: &[f64], kernel: &[f64]) -> Vec<f64> {
    if signal.is_empty() || kernel.is_empty() {
        return vec![];
    }
    let k = kernel.len();
    let mut out = vec![0.0f64; signal.len() + k - 1];
    for (i, &s) in signal.iter().enumerate() {
        // Sample i contributes s * kernel[j] to output position i + j.
        for (o, &w) in out[i..i + k].iter_mut().zip(kernel) {
            *o += s * w;
        }
    }
    out
}
744
/// "Valid-mode" 1-D cross-correlation. Output `i` is the dot product of `pattern` with
/// `signal[i..i + pattern.len()]`.
///
/// Returns an empty vector when `pattern` is longer than `signal`.
pub fn correlate1d_valid(signal: &[f64], pattern: &[f64]) -> Vec<f64> {
    let width = pattern.len();
    if width > signal.len() {
        return vec![];
    }
    let positions = signal.len() - width + 1;
    let mut out = Vec::with_capacity(positions);
    for start in 0..positions {
        let window = &signal[start..start + width];
        out.push(window.iter().zip(pattern).map(|(&s, &p)| s * p).sum::<f64>());
    }
    out
}
762
#[cfg(test)]
mod grid_reduce_tests {
    // `super::*` brings every item of the parent module into scope, including private ones.
    // The former explicit `use crate::grid_reduce::...` lines duplicated it and tied these
    // tests to the module's path in the crate.
    use super::*;

    #[test]
    fn test_tile_reduce_sum() {
        let t = Tile::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        assert!((t.reduce_sum() - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_tile_exclusive_scan() {
        let mut t = Tile::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        t.exclusive_scan_inplace();
        assert_eq!(t.data, vec![0.0, 1.0, 3.0, 6.0]);
    }

    #[test]
    fn test_tile_inclusive_scan() {
        let mut t = Tile::from_slice(&[1.0, 2.0, 3.0]);
        t.inclusive_scan_inplace();
        assert_eq!(t.data, vec![1.0, 3.0, 6.0]);
    }

    #[test]
    fn test_tiled_reducer_sum() {
        let data: Vec<f64> = (1..=100).map(|i| i as f64).collect();
        let r = TiledReducer::new(16);
        let s = r.sum(&data);
        assert!((s - 5050.0).abs() < 1e-8, "sum 1..100 = 5050, got {s}");
    }

    #[test]
    fn test_tiled_reducer_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let r = TiledReducer::new(8);
        let d = r.dot(&a, &b);
        assert!((d - 32.0).abs() < 1e-12, "dot([1,2,3],[4,5,6]) = 32");
    }

    #[test]
    fn test_segmented_exclusive_scan() {
        let data = [1.0, 2.0, 3.0, 1.0, 2.0];
        let flags = [true, false, false, true, false];
        let out = segmented_exclusive_scan(&data, &flags);
        assert_eq!(out, vec![0.0, 1.0, 3.0, 0.0, 1.0]);
    }

    #[test]
    fn test_segmented_reduce_sum() {
        let data = [1.0, 2.0, 3.0, 10.0, 20.0];
        let flags = [true, false, false, true, false];
        let sums = segmented_reduce_sum(&data, &flags);
        assert_eq!(sums.len(), 2);
        assert!((sums[0] - 6.0).abs() < 1e-12, "first segment sum = 6");
        assert!((sums[1] - 30.0).abs() < 1e-12, "second segment sum = 30");
    }

    #[test]
    fn test_filter_compact() {
        let data = vec![1.0, -2.0, 3.0, -4.0, 5.0];
        let pos: Vec<f64> = filter_compact(&data, |&x| x > 0.0);
        assert_eq!(pos, vec![1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_scatter_gather_roundtrip() {
        let mut dst = vec![0.0; 5];
        let src = vec![10.0, 20.0, 30.0];
        let indices = vec![4, 1, 2];
        scatter(&mut dst, &src, &indices);
        assert!((dst[4] - 10.0).abs() < 1e-12);
        assert!((dst[1] - 20.0).abs() < 1e-12);
        let gathered = gather(&dst, &[4, 1, 2]);
        assert_eq!(gathered, vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_warp_reduce_sum_all_lanes_equal() {
        let lanes = vec![1.0, 2.0, 3.0, 4.0];
        let result = warp_reduce_sum(&lanes);
        assert!(
            result.iter().all(|&v| (v - 10.0).abs() < 1e-12),
            "all lanes should get the total sum"
        );
    }

    #[test]
    fn test_warp_exclusive_scan() {
        let lanes = vec![1.0, 1.0, 1.0, 1.0];
        let out = warp_exclusive_scan(&lanes);
        assert_eq!(out, vec![0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_occupancy_estimate_full() {
        // 64-thread groups, 32 regs/thread, no shared memory: threads, registers and the
        // work-group cap all allow 32 groups -> 2048 of 2048 thread slots.
        let occ = estimate_occupancy(64, 32, 0, 32, 2048, 65536, 49152);
        assert!((occ - 1.0).abs() < 1e-9, "should be 100% occupancy");
    }

    #[test]
    fn test_grid_reduce_stats() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = GridReduceStats::compute(&data);
        assert_eq!(stats.count, 5);
        assert!((stats.sum - 15.0).abs() < 1e-10);
        assert!((stats.mean - 3.0).abs() < 1e-10);
        assert!((stats.min - 1.0).abs() < 1e-10);
        assert!((stats.max - 5.0).abs() < 1e-10);
        assert!((stats.variance - 2.0).abs() < 1e-10);
        assert!((stats.std_dev() - 2.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_grid_reduce_stats_empty() {
        let stats = GridReduceStats::compute(&[]);
        assert_eq!(stats.count, 0);
        assert!((stats.sum).abs() < 1e-12);
    }

    #[test]
    fn test_histogram_basic() {
        let data = vec![0.1, 0.5, 0.9, 1.5, 1.9];
        let h = Histogram::compute(&data, 0.0, 2.0, 2);
        // 0.1, 0.5, 0.9 fall in [0, 1).
        assert_eq!(h.bins[0], 3);
        // 1.5, 1.9 fall in [1, 2).
        assert_eq!(h.bins[1], 2);
        assert_eq!(h.total(), 5);
    }

    #[test]
    fn test_histogram_mode_bin() {
        let data = vec![0.1, 0.2, 0.3, 1.5];
        let h = Histogram::compute(&data, 0.0, 2.0, 2);
        assert_eq!(h.mode_bin(), 0);
    }

    #[test]
    fn test_histogram_bin_centre() {
        let h = Histogram::compute(&[], 0.0, 4.0, 4);
        assert!((h.bin_centre(0) - 0.5).abs() < 1e-10);
        assert!((h.bin_centre(3) - 3.5).abs() < 1e-10);
    }

    #[test]
    fn test_histogram_approx_mean() {
        // A single bin spanning [0, 1] places every sample at its centre, 0.5.
        let data = vec![0.1, 0.2, 0.3, 0.4];
        let h = Histogram::compute(&data, 0.0, 1.0, 1);
        assert!((h.approx_mean() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_norm_l1() {
        let v = vec![1.0, -2.0, 3.0];
        assert!((norm_l1(&v) - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_norm_l2() {
        let v = vec![3.0, 4.0];
        assert!((norm_l2(&v) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_norm_linf() {
        let v = vec![1.0, -5.0, 3.0];
        assert!((norm_linf(&v) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_dist_l2() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((dist_l2(&a, &b) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_covariance_identity_pattern() {
        // Three samples of two perfectly correlated features: (0,0), (1,1), (2,2).
        let data = vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0];
        let cov = covariance_matrix(&data, 3, 2);
        let expected_var = 2.0 / 3.0;
        assert!(
            (cov[0] - expected_var).abs() < 1e-10,
            "cov[0,0] = {}",
            cov[0]
        );
        assert!(
            (cov[1] - expected_var).abs() < 1e-10,
            "cov[0,1] = {}",
            cov[1]
        );
        assert!(
            (cov[3] - expected_var).abs() < 1e-10,
            "cov[1,1] = {}",
            cov[3]
        );
    }

    #[test]
    fn test_matrix_diagonal() {
        let mat = vec![1.0, 2.0, 3.0, 4.0];
        let diag = matrix_diagonal(&mat, 2);
        assert_eq!(diag, vec![1.0, 4.0]);
    }

    #[test]
    fn test_matvec_identity() {
        let identity = vec![1.0, 0.0, 0.0, 1.0];
        let x = vec![3.0, 7.0];
        let y = matvec(&identity, 2, 2, &x);
        assert_eq!(y, x);
    }

    #[test]
    fn test_matvec_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let x = vec![1.0, 1.0];
        let y = matvec(&a, 2, 2, &x);
        assert!((y[0] - 3.0).abs() < 1e-12);
        assert!((y[1] - 7.0).abs() < 1e-12);
    }

    #[test]
    fn test_matmul_2x2() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let c = matmul(&a, 2, 2, &b, 2);
        assert!((c[0] - 19.0).abs() < 1e-12);
        assert!((c[1] - 22.0).abs() < 1e-12);
        assert!((c[2] - 43.0).abs() < 1e-12);
        assert!((c[3] - 50.0).abs() < 1e-12);
    }

    #[test]
    fn test_welford_mean_and_variance() {
        let mut w = WelfordStats::default();
        for &v in &[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            w.update(v);
        }
        assert!((w.mean - 5.0).abs() < 1e-10, "mean = {}", w.mean);
        assert!((w.variance() - 4.0).abs() < 1e-10, "var = {}", w.variance());
    }

    #[test]
    fn test_welford_single_sample() {
        let mut w = WelfordStats::default();
        w.update(42.0);
        assert!((w.mean - 42.0).abs() < 1e-12);
        assert!((w.variance()).abs() < 1e-12);
    }

    #[test]
    fn test_parallel_histogram_matches_serial() {
        let data: Vec<f64> = (0..200).map(|i| i as f64 / 10.0).collect();
        let serial = Histogram::compute(&data, 0.0, 20.0, 10);
        let par = parallel_histogram(&data, 0.0, 20.0, 10, 4);
        assert_eq!(
            serial.bins, par,
            "parallel and serial histograms must agree"
        );
    }

    #[test]
    fn test_exclusive_scan_u64() {
        let data = [1u64, 2, 3, 4];
        let out = exclusive_scan_u64(&data);
        assert_eq!(out, vec![0, 1, 3, 6]);
    }

    #[test]
    fn test_inclusive_scan_u64() {
        let data = [1u64, 2, 3, 4];
        let out = inclusive_scan_u64(&data);
        assert_eq!(out, vec![1, 3, 6, 10]);
    }

    #[test]
    fn test_convolve1d_basic() {
        // Convolving with the unit impulse returns the signal unchanged.
        let sig = vec![1.0, 2.0, 3.0];
        let ker = vec![1.0];
        let out = convolve1d(&sig, &ker);
        assert_eq!(out, sig);
    }

    #[test]
    fn test_convolve1d_box_filter() {
        let sig = vec![0.0, 6.0, 0.0];
        let ker = vec![1.0, 1.0, 1.0];
        let out = convolve1d(&sig, &ker);
        // Full convolution length is 3 + 3 - 1.
        assert_eq!(out.len(), 5);
        assert!((out[0]).abs() < 1e-12);
        assert!((out[1] - 6.0).abs() < 1e-12);
        assert!((out[2] - 6.0).abs() < 1e-12);
        assert!((out[3] - 6.0).abs() < 1e-12);
        assert!((out[4]).abs() < 1e-12);
    }

    #[test]
    fn test_correlate1d_valid() {
        let sig = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let pat = vec![1.0, 0.0, -1.0];
        let out = correlate1d_valid(&sig, &pat);
        // 5 - 3 + 1 valid positions.
        assert_eq!(out.len(), 3);
        assert!((out[0] - (1.0 - 3.0)).abs() < 1e-12);
        assert!((out[1] - (2.0 - 4.0)).abs() < 1e-12);
        assert!((out[2] - (3.0 - 5.0)).abs() < 1e-12);
    }
}
1119
/// Exclusive prefix sum using the Blelloch work-efficient up-sweep / down-sweep scheme.
///
/// The input is zero-padded to the next power of two, scanned in place, and truncated back
/// to `data.len()`. Returns an empty vector for empty input.
pub fn blelloch_exclusive_scan(data: &[f64]) -> Vec<f64> {
    let n = data.len();
    if n == 0 {
        return vec![];
    }
    let padded_len = n.next_power_of_two();
    let mut tree = data.to_vec();
    tree.resize(padded_len, 0.0);

    // Up-sweep (reduce). For each distance `d`, the right end of every 2d-wide block adds in
    // the value sitting `d` to its left. Updates within one level touch disjoint slots.
    let mut d = 1usize;
    while d < padded_len {
        for right in (2 * d - 1..padded_len).step_by(2 * d) {
            tree[right] += tree[right - d];
        }
        d *= 2;
    }

    // Clear the root (the identity for +), then down-sweep. Each left child takes its parent's
    // prefix, and each right child takes the parent's prefix plus the old left value.
    tree[padded_len - 1] = 0.0;
    let mut d = padded_len / 2;
    while d > 0 {
        for right in (2 * d - 1..padded_len).step_by(2 * d) {
            let left = right - d;
            let left_val = tree[left];
            tree[left] = tree[right];
            tree[right] += left_val;
        }
        d /= 2;
    }

    tree.truncate(n);
    tree
}
1176
1177pub fn blelloch_inclusive_scan(data: &[f64]) -> Vec<f64> {
1179 let excl = blelloch_exclusive_scan(data);
1180 excl.into_iter()
1181 .zip(data.iter())
1182 .map(|(e, &v)| e + v)
1183 .collect()
1184}
1185
1186pub fn blelloch_segmented_exclusive_scan(data: &[f64], flags: &[bool]) -> Vec<f64> {
1195 assert_eq!(data.len(), flags.len());
1196 segmented_exclusive_scan(data, flags)
1199}
1200
/// Per-segment sums computed from explicit segment boundaries.
///
/// Index 0 always opens a segment, and each later `true` flag opens another. This produces
/// the same result as `segmented_reduce_sum`, including a single total for empty input.
/// Despite the name, the computation runs serially.
///
/// # Panics
/// Panics if `data` and `flags` have different lengths.
pub fn parallel_segmented_reduce_sum(data: &[f64], flags: &[bool]) -> Vec<f64> {
    assert_eq!(data.len(), flags.len());
    // Boundaries: 0, then every flagged index after 0, then the end of the data.
    let boundaries: Vec<usize> = std::iter::once(0)
        .chain((1..flags.len()).filter(|&i| flags[i]))
        .chain(std::iter::once(data.len()))
        .collect();
    boundaries
        .windows(2)
        .map(|pair| data[pair[0]..pair[1]].iter().sum::<f64>())
        .collect()
}
1219
/// Keeps the elements that satisfy `predicate`.
///
/// Returns the kept values and their original indices, both in input order.
pub fn filter_compact_indexed(
    data: &[f64],
    predicate: impl Fn(f64) -> bool,
) -> (Vec<f64>, Vec<usize>) {
    data.iter()
        .copied()
        .enumerate()
        .filter(|&(_, v)| predicate(v))
        .map(|(i, v)| (v, i))
        .unzip()
}
1242
/// Keeps clones of the elements that satisfy `predicate`, in order.
///
/// Also returns how many elements were dropped.
pub fn filter_compact_counted<T: Clone>(
    data: &[T],
    predicate: impl Fn(&T) -> bool,
) -> (Vec<T>, usize) {
    let mut kept: Vec<T> = Vec::new();
    for item in data {
        if predicate(item) {
            kept.push(item.clone());
        }
    }
    let removed = data.len() - kept.len();
    (kept, removed)
}
1252
/// One stable counting-sort pass keyed on the `log2(radix)`-bit digit that starts at
/// `bit_pos`.
///
/// Elements with equal digits keep their relative order. Applying passes from the least
/// significant digit upward therefore yields a full LSD radix sort. A digit that lies
/// entirely above bit 63 (`bit_pos >= 64`) is treated as zero, so such a pass returns the
/// input order unchanged instead of overflowing the shift.
///
/// # Panics
/// Panics if `radix` is not a power of two.
pub fn radix_sort_pass_u64(data: &[u64], bit_pos: u32, radix: usize) -> Vec<u64> {
    assert!(radix.is_power_of_two(), "radix must be a power of two");
    let mask = (radix - 1) as u64;
    // `checked_shr` yields None when bit_pos >= 64. A plain `>>` would panic in debug builds
    // and silently shift by `bit_pos % 64` in release builds.
    let digit = |v: u64| (v.checked_shr(bit_pos).unwrap_or(0) & mask) as usize;

    // Count occurrences of each digit.
    let mut counts = vec![0usize; radix];
    for &v in data {
        counts[digit(v)] += 1;
    }

    // Exclusive prefix sum of counts gives each bucket's first output slot. The total is
    // bounded by data.len(), so this cannot overflow.
    let mut next_slot = Vec::with_capacity(radix);
    let mut running = 0usize;
    for c in counts {
        next_slot.push(running);
        running += c;
    }

    // Stable scatter: visit inputs in order and append each to its bucket.
    let mut out = vec![0u64; data.len()];
    for &v in data {
        let d = digit(v);
        out[next_slot[d]] = v;
        next_slot[d] += 1;
    }
    out
}
1283
1284pub fn radix_sort_u64(data: &[u64]) -> Vec<u64> {
1286 let mut buf = data.to_vec();
1287 for pass in 0..8u32 {
1288 buf = radix_sort_pass_u64(&buf, pass * 8, 256);
1289 }
1290 buf
1291}
1292
1293pub fn radix_sort_f64(data: &[f64]) -> Vec<f64> {
1298 let mut keys: Vec<u64> = data
1299 .iter()
1300 .map(|&v| {
1301 let bits = v.to_bits();
1302 if bits >> 63 == 0 {
1303 bits | (1u64 << 63) } else {
1305 !bits }
1307 })
1308 .collect();
1309 keys = radix_sort_u64(&keys);
1310 keys.iter()
1311 .map(|&bits| {
1312 let recovered = if bits >> 63 == 1 {
1313 bits ^ (1u64 << 63) } else {
1315 !bits };
1317 f64::from_bits(recovered)
1318 })
1319 .collect()
1320}
1321
/// Sums `data` by pairwise halving. Each round adds the back half into the front half.
///
/// With an odd count, the unpaired last element is folded into the final front slot.
/// Returns `0.0` for an empty slice.
pub fn tree_reduce_sum(data: &[f64]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }
    let mut scratch = data.to_vec();
    let mut live = scratch.len();
    while live > 1 {
        let half = live / 2;
        let (front, back) = scratch[..live].split_at_mut(half);
        for (slot, &other) in front.iter_mut().zip(back.iter()) {
            *slot += other;
        }
        if live % 2 == 1 {
            // back[half] is the leftover element scratch[live - 1].
            front[half - 1] += back[half];
        }
        live = half;
    }
    scratch[0]
}
1348
/// Maximum of `data`, computed by pairwise halving. Returns `-inf` for an empty slice.
pub fn tree_reduce_max(data: &[f64]) -> f64 {
    if data.is_empty() {
        return f64::NEG_INFINITY;
    }
    let mut scratch = data.to_vec();
    let mut live = scratch.len();
    while live > 1 {
        let half = live / 2;
        let (front, back) = scratch[..live].split_at_mut(half);
        for (slot, &other) in front.iter_mut().zip(back.iter()) {
            *slot = f64::max(*slot, other);
        }
        if live % 2 == 1 {
            front[half - 1] = f64::max(front[half - 1], back[half]);
        }
        live = half;
    }
    scratch[0]
}
1368
/// Minimum of `data`, computed by pairwise halving. Returns `+inf` for an empty slice.
pub fn tree_reduce_min(data: &[f64]) -> f64 {
    if data.is_empty() {
        return f64::INFINITY;
    }
    let mut scratch = data.to_vec();
    let mut live = scratch.len();
    while live > 1 {
        let half = live / 2;
        let (front, back) = scratch[..live].split_at_mut(half);
        for (slot, &other) in front.iter_mut().zip(back.iter()) {
            *slot = f64::min(*slot, other);
        }
        if live % 2 == 1 {
            front[half - 1] = f64::min(front[half - 1], back[half]);
        }
        live = half;
    }
    scratch[0]
}
1388
/// Replaces every element with the sum of all elements.
pub fn reduce_broadcast(data: &[f64]) -> Vec<f64> {
    let total = data.iter().copied().sum::<f64>();
    data.iter().map(|_| total).collect()
}
1400
/// Divides every element by the total so the result sums to 1.
///
/// If the absolute total is below `1e-30`, the input is returned unchanged rather than
/// dividing by (nearly) zero.
pub fn normalise_by_sum(data: &[f64]) -> Vec<f64> {
    let total: f64 = data.iter().sum();
    if total.abs() < 1e-30 {
        data.to_vec()
    } else {
        data.iter().map(|&v| v / total).collect()
    }
}
1409
1410#[derive(Debug, Clone)]
1418pub struct TwoLevelHistogram {
1419 pub bins: Vec<u64>,
1421 pub lo: f64,
1423 pub hi: f64,
1425 pub n_tiles: usize,
1427}
1428
1429impl TwoLevelHistogram {
1430 pub fn compute(data: &[f64], lo: f64, hi: f64, n_bins: usize, tile_size: usize) -> Self {
1432 let n_tiles = (data.len() + tile_size - 1) / tile_size.max(1);
1433 let bins = parallel_histogram(data, lo, hi, n_bins, n_tiles.max(1));
1434 Self {
1435 bins,
1436 lo,
1437 hi,
1438 n_tiles,
1439 }
1440 }
1441
1442 pub fn total(&self) -> u64 {
1444 self.bins.iter().sum()
1445 }
1446
1447 pub fn approx_median(&self) -> f64 {
1449 let total = self.total();
1450 if total == 0 {
1451 return (self.lo + self.hi) / 2.0;
1452 }
1453 let half = total / 2;
1454 let n = self.bins.len() as f64;
1455 let mut acc = 0u64;
1456 for (i, &c) in self.bins.iter().enumerate() {
1457 acc += c;
1458 if acc >= half {
1459 let bin_width = (self.hi - self.lo) / n;
1460 return self.lo + (i as f64 + 0.5) * bin_width;
1461 }
1462 }
1463 self.hi
1464 }
1465}
1466
/// Streaming tracker of the minimum, the maximum and the number of observed values.
///
/// `RunningMinMax::new()` and `Default::default()` both start from the empty state:
/// `min = +inf`, `max = -inf`, `count = 0`. The first update therefore sets both bounds.
/// NaN inputs are counted but do not move the bounds, because `f64::min` and `f64::max`
/// prefer the non-NaN operand.
#[derive(Debug, Clone)]
pub struct RunningMinMax {
    /// Smallest value seen (`+inf` when empty).
    pub min: f64,
    /// Largest value seen (`-inf` when empty).
    pub max: f64,
    /// Number of values seen.
    pub count: u64,
}

impl Default for RunningMinMax {
    // A derived Default would start min and max at 0.0. For example, feeding only positive
    // values would then report min = 0.0. Delegate to `new` to start from the identities.
    fn default() -> Self {
        Self::new()
    }
}

impl RunningMinMax {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self {
            min: f64::INFINITY,
            max: f64::NEG_INFINITY,
            count: 0,
        }
    }

    /// Records one value.
    pub fn update(&mut self, v: f64) {
        self.min = f64::min(self.min, v);
        self.max = f64::max(self.max, v);
        self.count += 1;
    }

    /// Records every value in `data`, in order.
    pub fn update_slice(&mut self, data: &[f64]) {
        for &v in data {
            self.update(v);
        }
    }

    /// Returns `max - min`, or `0.0` if nothing has been recorded.
    pub fn range(&self) -> f64 {
        if self.count == 0 {
            return 0.0;
        }
        self.max - self.min
    }
}
1514
/// Appends to `dst` the elements of `src` whose `mask` entry is `true`, in order.
///
/// Returns the number of elements appended. Existing contents of `dst` are kept.
///
/// # Panics
/// Panics if `src` and `mask` differ in length.
pub fn compact_scatter(src: &[f64], mask: &[bool], dst: &mut Vec<f64>) -> usize {
    assert_eq!(src.len(), mask.len());
    let start_len = dst.len();
    dst.extend(
        src.iter()
            .zip(mask)
            .filter_map(|(&v, &keep)| keep.then_some(v)),
    );
    dst.len() - start_len
}
1533
/// For each kept position (`mask[i] == true`), gives its index in the compacted output.
///
/// Dropped positions get `usize::MAX` as a sentinel.
pub fn compaction_offsets(mask: &[bool]) -> Vec<usize> {
    let mut next = 0usize;
    mask.iter()
        .map(|&keep| {
            if keep {
                next += 1;
                next - 1
            } else {
                usize::MAX
            }
        })
        .collect()
}
1549
1550#[cfg(test)]
1555mod extended_tests {
1556 use crate::grid_reduce::Histogram;
1557 use crate::grid_reduce::RunningMinMax;
1558 use crate::grid_reduce::Tile;
1559 use crate::grid_reduce::TiledReducer;
1560 use crate::grid_reduce::TwoLevelHistogram;
1561 use crate::grid_reduce::WelfordStats;
1562 use crate::grid_reduce::blelloch_exclusive_scan;
1563 use crate::grid_reduce::blelloch_inclusive_scan;
1564 use crate::grid_reduce::compact_scatter;
1565 use crate::grid_reduce::compaction_offsets;
1566 use crate::grid_reduce::exclusive_scan_u64;
1567 use crate::grid_reduce::filter_compact_counted;
1568 use crate::grid_reduce::filter_compact_indexed;
1569 use crate::grid_reduce::inclusive_scan_u64;
1570 use crate::grid_reduce::normalise_by_sum;
1571 use crate::grid_reduce::parallel_segmented_reduce_sum;
1572 use crate::grid_reduce::radix_sort_f64;
1573 use crate::grid_reduce::radix_sort_pass_u64;
1574 use crate::grid_reduce::radix_sort_u64;
1575 use crate::grid_reduce::reduce_broadcast;
1576 use crate::grid_reduce::segmented_reduce_sum;
1577 use crate::grid_reduce::tree_reduce_max;
1578 use crate::grid_reduce::tree_reduce_min;
1579 use crate::grid_reduce::tree_reduce_sum;
1580
1581 #[test]
1584 fn blelloch_exclusive_scan_matches_serial() {
1585 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1586 let serial = {
1587 let mut r = Vec::new();
1588 let mut acc = 0.0f64;
1589 for &v in &data {
1590 r.push(acc);
1591 acc += v;
1592 }
1593 r
1594 };
1595 let blelloch = blelloch_exclusive_scan(&data);
1596 for (a, b) in serial.iter().zip(blelloch.iter()) {
1597 assert!((a - b).abs() < 1e-10, "mismatch: serial={a} blelloch={b}");
1598 }
1599 }
1600
1601 #[test]
1602 fn blelloch_exclusive_scan_non_pow2() {
1603 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0]; let result = blelloch_exclusive_scan(&data);
1605 assert_eq!(result.len(), 5);
1606 assert!((result[0] - 0.0).abs() < 1e-10);
1607 assert!((result[1] - 1.0).abs() < 1e-10);
1608 assert!((result[2] - 3.0).abs() < 1e-10);
1609 assert!((result[3] - 6.0).abs() < 1e-10);
1610 assert!((result[4] - 10.0).abs() < 1e-10);
1611 }
1612
1613 #[test]
1614 fn blelloch_inclusive_scan_correct() {
1615 let data = vec![1.0, 2.0, 3.0, 4.0];
1616 let result = blelloch_inclusive_scan(&data);
1617 assert_eq!(result, vec![1.0, 3.0, 6.0, 10.0]);
1618 }
1619
1620 #[test]
1621 fn blelloch_exclusive_scan_single_element() {
1622 let result = blelloch_exclusive_scan(&[42.0]);
1623 assert_eq!(result, vec![0.0]);
1624 }
1625
1626 #[test]
1627 fn blelloch_exclusive_scan_all_zeros() {
1628 let data = vec![0.0; 8];
1629 let result = blelloch_exclusive_scan(&data);
1630 assert!(result.iter().all(|&v| v.abs() < 1e-12));
1631 }
1632
1633 #[test]
1636 fn parallel_segmented_reduce_matches_serial() {
1637 let data = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0];
1638 let flags = [true, false, false, true, false, false];
1639 let par = parallel_segmented_reduce_sum(&data, &flags);
1640 let ser = segmented_reduce_sum(&data, &flags);
1641 assert_eq!(par, ser);
1642 }
1643
1644 #[test]
1645 fn parallel_segmented_reduce_single_segment() {
1646 let data = [1.0, 2.0, 3.0];
1647 let flags = [true, false, false];
1648 let result = parallel_segmented_reduce_sum(&data, &flags);
1649 assert_eq!(result.len(), 1);
1650 assert!((result[0] - 6.0).abs() < 1e-10);
1651 }
1652
1653 #[test]
1656 fn filter_compact_indexed_positive() {
1657 let data = vec![-1.0, 2.0, -3.0, 4.0, 5.0];
1658 let (vals, idxs) = filter_compact_indexed(&data, |v| v > 0.0);
1659 assert_eq!(vals, vec![2.0, 4.0, 5.0]);
1660 assert_eq!(idxs, vec![1, 3, 4]);
1661 }
1662
1663 #[test]
1664 fn filter_compact_indexed_empty_result() {
1665 let data = vec![-1.0, -2.0, -3.0];
1666 let (vals, idxs) = filter_compact_indexed(&data, |v| v > 0.0);
1667 assert!(vals.is_empty());
1668 assert!(idxs.is_empty());
1669 }
1670
1671 #[test]
1672 fn filter_compact_counted_removes_negatives() {
1673 let data = vec![1.0, -2.0, 3.0, -4.0, 5.0];
1674 let (kept, removed) = filter_compact_counted(&data, |v| *v >= 0.0);
1675 assert_eq!(kept, vec![1.0, 3.0, 5.0]);
1676 assert_eq!(removed, 2);
1677 }
1678
1679 #[test]
1682 fn radix_sort_u64_ascending() {
1683 let mut data = vec![5u64, 3, 8, 1, 9, 2, 7, 4, 6, 0];
1684 let sorted = radix_sort_u64(&data);
1685 data.sort_unstable();
1686 assert_eq!(sorted, data);
1687 }
1688
1689 #[test]
1690 fn radix_sort_u64_empty() {
1691 let sorted = radix_sort_u64(&[]);
1692 assert!(sorted.is_empty());
1693 }
1694
1695 #[test]
1696 fn radix_sort_u64_already_sorted() {
1697 let data = vec![1u64, 2, 3, 4, 5];
1698 assert_eq!(radix_sort_u64(&data), data);
1699 }
1700
1701 #[test]
1702 fn radix_sort_u64_reverse() {
1703 let data = vec![5u64, 4, 3, 2, 1];
1704 let sorted = radix_sort_u64(&data);
1705 assert_eq!(sorted, vec![1u64, 2, 3, 4, 5]);
1706 }
1707
1708 #[test]
1709 fn radix_sort_f64_positive_values() {
1710 let data = vec![3.125, 1.41, 2.71, 0.57, 1.73];
1711 let sorted = radix_sort_f64(&data);
1712 let mut expected = data.clone();
1713 expected.sort_by(|a, b| a.partial_cmp(b).unwrap());
1714 for (a, b) in sorted.iter().zip(expected.iter()) {
1715 assert!((a - b).abs() < 1e-12, "a={a} b={b}");
1716 }
1717 }
1718
1719 #[test]
1720 fn radix_sort_pass_u64_single_pass() {
1721 let data = vec![0x03u64, 0x01, 0x04, 0x01, 0x05];
1723 let sorted = radix_sort_pass_u64(&data, 0, 256);
1724 assert_eq!(sorted.len(), data.len());
1725 for w in sorted.windows(2) {
1727 assert!(w[0] & 0xFF <= w[1] & 0xFF, "not sorted by low byte");
1728 }
1729 }
1730
1731 #[test]
1734 fn tree_reduce_sum_correct() {
1735 let data: Vec<f64> = (1..=16).map(|i| i as f64).collect();
1736 let s = tree_reduce_sum(&data);
1737 assert!((s - 136.0).abs() < 1e-10, "sum = {s}");
1738 }
1739
1740 #[test]
1741 fn tree_reduce_sum_odd_length() {
1742 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1743 let s = tree_reduce_sum(&data);
1744 assert!((s - 15.0).abs() < 1e-10, "sum = {s}");
1745 }
1746
1747 #[test]
1748 fn tree_reduce_max_correct() {
1749 let data = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
1750 assert!((tree_reduce_max(&data) - 9.0).abs() < 1e-12);
1751 }
1752
1753 #[test]
1754 fn tree_reduce_min_correct() {
1755 let data = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
1756 assert!((tree_reduce_min(&data) - 1.0).abs() < 1e-12);
1757 }
1758
1759 #[test]
1760 fn tree_reduce_empty() {
1761 assert!((tree_reduce_sum(&[])).abs() < 1e-12);
1762 assert!(tree_reduce_max(&[]) == f64::NEG_INFINITY);
1763 assert!(tree_reduce_min(&[]) == f64::INFINITY);
1764 }
1765
1766 #[test]
1767 fn tree_reduce_single() {
1768 assert!((tree_reduce_sum(&[42.0]) - 42.0).abs() < 1e-12);
1769 assert!((tree_reduce_max(&[42.0]) - 42.0).abs() < 1e-12);
1770 assert!((tree_reduce_min(&[42.0]) - 42.0).abs() < 1e-12);
1771 }
1772
1773 #[test]
1774 fn tree_reduce_matches_tiled_reducer() {
1775 let data: Vec<f64> = (0..100).map(|i| i as f64).collect();
1776 let tr = TiledReducer::new(16);
1777 let tiled_sum = tr.sum(&data);
1778 let tree_sum = tree_reduce_sum(&data);
1779 assert!(
1780 (tiled_sum - tree_sum).abs() < 1e-8,
1781 "tiled={tiled_sum} tree={tree_sum}"
1782 );
1783 }
1784
1785 #[test]
1788 fn reduce_broadcast_all_equal() {
1789 let data = vec![1.0, 2.0, 3.0];
1790 let result = reduce_broadcast(&data);
1791 assert!(
1792 result.iter().all(|&v| (v - 6.0).abs() < 1e-12),
1793 "all should equal 6"
1794 );
1795 }
1796
1797 #[test]
1798 fn normalise_by_sum_sums_to_one() {
1799 let data = vec![1.0, 2.0, 3.0, 4.0];
1800 let normed = normalise_by_sum(&data);
1801 let s: f64 = normed.iter().sum();
1802 assert!((s - 1.0).abs() < 1e-10, "sum = {s}");
1803 }
1804
1805 #[test]
1806 fn normalise_by_sum_zero_input_unchanged() {
1807 let data = vec![0.0, 0.0, 0.0];
1808 let result = normalise_by_sum(&data);
1809 assert_eq!(result, data);
1810 }
1811
1812 #[test]
1815 fn two_level_histogram_total_correct() {
1816 let data: Vec<f64> = (0..100).map(|i| i as f64 / 10.0).collect();
1817 let h = TwoLevelHistogram::compute(&data, 0.0, 10.0, 10, 16);
1818 assert_eq!(h.total(), 100);
1819 }
1820
1821 #[test]
1822 fn two_level_histogram_approx_median() {
1823 let data: Vec<f64> = (0..1000).map(|i| i as f64 / 100.0).collect();
1825 let h = TwoLevelHistogram::compute(&data, 0.0, 10.0, 100, 64);
1826 let med = h.approx_median();
1827 assert!((med - 5.0).abs() < 0.2, "approx median = {med}");
1828 }
1829
1830 #[test]
1831 fn two_level_histogram_bins_count_matches() {
1832 let data = vec![0.5, 1.5, 2.5, 3.5];
1833 let h = TwoLevelHistogram::compute(&data, 0.0, 4.0, 4, 2);
1834 assert_eq!(h.total(), 4);
1835 for &c in &h.bins {
1836 assert_eq!(c, 1, "each bin should have 1 element");
1837 }
1838 }
1839
1840 #[test]
1843 fn running_min_max_basic() {
1844 let mut t = RunningMinMax::new();
1845 t.update_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
1846 assert!((t.min - 1.0).abs() < 1e-12);
1847 assert!((t.max - 5.0).abs() < 1e-12);
1848 assert_eq!(t.count, 5);
1849 assert!((t.range() - 4.0).abs() < 1e-12);
1850 }
1851
1852 #[test]
1853 fn running_min_max_single() {
1854 let mut t = RunningMinMax::new();
1855 t.update(42.0);
1856 assert!((t.min - 42.0).abs() < 1e-12);
1857 assert!((t.max - 42.0).abs() < 1e-12);
1858 assert!((t.range()).abs() < 1e-12);
1859 }
1860
1861 #[test]
1862 fn running_min_max_empty_range() {
1863 let t = RunningMinMax::new();
1864 assert!((t.range()).abs() < 1e-12);
1865 }
1866
1867 #[test]
1870 fn compact_scatter_basic() {
1871 let src = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1872 let mask = vec![true, false, true, false, true];
1873 let mut dst = Vec::new();
1874 let n = compact_scatter(&src, &mask, &mut dst);
1875 assert_eq!(n, 3);
1876 assert_eq!(dst, vec![1.0, 3.0, 5.0]);
1877 }
1878
1879 #[test]
1880 fn compact_scatter_appends_to_existing() {
1881 let src = vec![10.0, 20.0];
1882 let mask = vec![true, true];
1883 let mut dst = vec![0.0, 0.0];
1884 compact_scatter(&src, &mask, &mut dst);
1885 assert_eq!(dst, vec![0.0, 0.0, 10.0, 20.0]);
1886 }
1887
1888 #[test]
1889 fn compaction_offsets_correct() {
1890 let mask = vec![true, false, true, false, true];
1891 let offsets = compaction_offsets(&mask);
1892 assert_eq!(offsets[0], 0);
1893 assert_eq!(offsets[1], usize::MAX);
1894 assert_eq!(offsets[2], 1);
1895 assert_eq!(offsets[3], usize::MAX);
1896 assert_eq!(offsets[4], 2);
1897 }
1898
1899 #[test]
1900 fn compaction_offsets_all_false() {
1901 let mask = vec![false; 5];
1902 let offsets = compaction_offsets(&mask);
1903 assert!(offsets.iter().all(|&o| o == usize::MAX));
1904 }
1905
1906 #[test]
1909 fn histogram_uniform_distribution() {
1910 let data: Vec<f64> = (0..10).map(|i| i as f64 + 0.5).collect();
1911 let h = Histogram::compute(&data, 0.0, 10.0, 10);
1912 for &c in &h.bins {
1913 assert_eq!(c, 1, "each bin should have exactly 1 element");
1914 }
1915 }
1916
1917 #[test]
1918 fn histogram_clamped_out_of_range() {
1919 let data = vec![-5.0, 5.0, 15.0]; let h = Histogram::compute(&data, 0.0, 10.0, 2);
1921 assert_eq!(
1922 h.total(),
1923 3,
1924 "out-of-range values should be clamped into boundary bins"
1925 );
1926 }
1927
1928 #[test]
1931 fn welford_sample_variance_two_samples() {
1932 let mut w = WelfordStats::default();
1933 w.update(2.0);
1934 w.update(4.0);
1935 let sv = w.sample_variance();
1937 assert!((sv - 2.0).abs() < 1e-10, "sample_var = {sv}");
1938 }
1939
1940 #[test]
1941 fn welford_std_dev_known_dataset() {
1942 let mut w = WelfordStats::default();
1943 for &v in &[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
1944 w.update(v);
1945 }
1946 assert!(
1947 (w.std_dev() - 2.0).abs() < 1e-10,
1948 "std_dev = {}",
1949 w.std_dev()
1950 );
1951 }
1952
1953 #[test]
1956 fn exclusive_scan_u64_empty() {
1957 let r = exclusive_scan_u64(&[]);
1958 assert!(r.is_empty());
1959 }
1960
1961 #[test]
1962 fn inclusive_scan_u64_single() {
1963 let r = inclusive_scan_u64(&[7u64]);
1964 assert_eq!(r, vec![7]);
1965 }
1966
1967 #[test]
1970 fn tile_reduce_max_and_min() {
1971 let t = Tile::from_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
1972 assert!((t.reduce_max() - 5.0).abs() < 1e-12);
1973 assert!((t.reduce_min() - 1.0).abs() < 1e-12);
1974 }
1975
1976 #[test]
1977 fn tiled_reducer_tile_sums_length() {
1978 let data: Vec<f64> = (0..100).map(|i| i as f64).collect();
1979 let r = TiledReducer::new(16);
1980 let ts = r.tile_sums(&data);
1981 assert_eq!(ts.len(), 7); }
1983
1984 #[test]
1985 fn tiled_reducer_max_and_min() {
1986 let data = vec![-5.0, 3.0, 8.0, -1.0, 2.0];
1987 let r = TiledReducer::new(4);
1988 assert!((r.max(&data) - 8.0).abs() < 1e-12);
1989 assert!((r.min(&data) - (-5.0)).abs() < 1e-12);
1990 }
1991}