1use crate::{Array, DType, Shape};
4
5impl Array {
6 pub fn sort(&self) -> Array {
17 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
18 let mut data = self.to_vec();
19 data.sort_by(|a, b| a.partial_cmp(b).unwrap());
20 Array::from_vec(data, self.shape().clone())
21 }
22
23 pub fn sort_descending(&self) -> Array {
34 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
35 let mut data = self.to_vec();
36 data.sort_by(|a, b| b.partial_cmp(a).unwrap());
37 Array::from_vec(data, self.shape().clone())
38 }
39
40 pub fn argsort(&self) -> Vec<usize> {
51 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
52 let data = self.to_vec();
53 let mut indices: Vec<usize> = (0..data.len()).collect();
54 indices.sort_by(|&a, &b| data[a].partial_cmp(&data[b]).unwrap());
55 indices
56 }
57
58 pub fn argsort_descending(&self) -> Vec<usize> {
69 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
70 let data = self.to_vec();
71 let mut indices: Vec<usize> = (0..data.len()).collect();
72 indices.sort_by(|&a, &b| data[b].partial_cmp(&data[a]).unwrap());
73 indices
74 }
75
76 pub fn top_k_smallest(&self, k: usize) -> Vec<usize> {
87 assert!(k <= self.size(), "k must be <= array size");
88 let indices = self.argsort();
89 indices.into_iter().take(k).collect()
90 }
91
92 pub fn top_k_largest(&self, k: usize) -> Vec<usize> {
103 assert!(k <= self.size(), "k must be <= array size");
104 let indices = self.argsort_descending();
105 indices.into_iter().take(k).collect()
106 }
107
108 pub fn searchsorted(&self, value: f32) -> usize {
119 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
120 let data = self.to_vec();
121
122 let mut left = 0;
124 let mut right = data.len();
125
126 while left < right {
127 let mid = left + (right - left) / 2;
128 if data[mid] < value {
129 left = mid + 1;
130 } else {
131 right = mid;
132 }
133 }
134
135 left
136 }
137
138 pub fn unique(&self) -> Array {
149 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
150 let mut data = self.to_vec();
151 data.sort_by(|a, b| a.partial_cmp(b).unwrap());
152 data.dedup_by(|a, b| (*a - *b).abs() < 1e-7);
153 let len = data.len();
154 Array::from_vec(data, Shape::new(vec![len]))
155 }
156
157 pub fn unique_counts(&self) -> (Array, Vec<usize>) {
171 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
172 let mut data = self.to_vec();
173 data.sort_by(|a, b| a.partial_cmp(b).unwrap());
174
175 let mut unique_vals = Vec::new();
176 let mut counts = Vec::new();
177
178 if !data.is_empty() {
179 let mut current = data[0];
180 let mut count = 1;
181
182 for &val in data.iter().skip(1) {
183 if (val - current).abs() < 1e-7 {
184 count += 1;
185 } else {
186 unique_vals.push(current);
187 counts.push(count);
188 current = val;
189 count = 1;
190 }
191 }
192 unique_vals.push(current);
193 counts.push(count);
194 }
195
196 (
197 Array::from_vec(unique_vals, Shape::new(vec![counts.len()])),
198 counts,
199 )
200 }
201
202 pub fn setdiff1d(&self, other: &Array) -> Array {
216 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
217 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
218
219 let self_unique = self.unique();
220 let other_data = other.to_vec();
221
222 let result: Vec<f32> = self_unique
223 .to_vec()
224 .into_iter()
225 .filter(|&val| !other_data.iter().any(|&x| (x - val).abs() < 1e-7))
226 .collect();
227
228 let len = result.len();
229 Array::from_vec(result, Shape::new(vec![len]))
230 }
231
232 pub fn union1d(&self, other: &Array) -> Array {
246 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
247 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
248
249 let mut combined = self.to_vec();
250 combined.extend(other.to_vec());
251
252 let temp = Array::from_vec(combined, Shape::new(vec![self.size() + other.size()]));
253 temp.unique()
254 }
255
256 pub fn intersect1d(&self, other: &Array) -> Array {
270 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
271 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
272
273 let self_unique = self.unique();
274 let other_data = other.to_vec();
275
276 let result: Vec<f32> = self_unique
277 .to_vec()
278 .into_iter()
279 .filter(|&val| other_data.iter().any(|&x| (x - val).abs() < 1e-7))
280 .collect();
281
282 let len = result.len();
283 Array::from_vec(result, Shape::new(vec![len]))
284 }
285
286 pub fn setxor1d(&self, other: &Array) -> Array {
300 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
301 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
302
303 let union = self.union1d(other);
304 let intersect = self.intersect1d(other);
305 union.setdiff1d(&intersect)
306 }
307
308 pub fn in1d(&self, test_elements: &Array) -> Array {
322 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
323 assert_eq!(
324 test_elements.dtype(),
325 DType::Float32,
326 "Only Float32 supported"
327 );
328
329 let data = self.to_vec();
330 let test_data = test_elements.to_vec();
331
332 let result: Vec<f32> = data
333 .iter()
334 .map(|&val| {
335 if test_data.iter().any(|&x| (x - val).abs() < 1e-7) {
336 1.0
337 } else {
338 0.0
339 }
340 })
341 .collect();
342
343 Array::from_vec(result, self.shape().clone())
344 }
345
346 pub fn digitize(&self, bins: &Array) -> Vec<usize> {
358 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
359 assert_eq!(bins.dtype(), DType::Float32, "Only Float32 supported");
360
361 let data = self.to_vec();
362 let bin_edges = bins.to_vec();
363
364 data.iter()
365 .map(|&val| {
366 let mut left = 0;
368 let mut right = bin_edges.len();
369
370 while left < right {
371 let mid = left + (right - left) / 2;
372 if bin_edges[mid] <= val {
373 left = mid + 1;
374 } else {
375 right = mid;
376 }
377 }
378 left
379 })
380 .collect()
381 }
382
383 pub fn histogram(&self, bins: usize, range_min: f32, range_max: f32) -> (Vec<usize>, Vec<f32>) {
397 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
398 assert!(bins > 0, "Number of bins must be positive");
399 assert!(range_max > range_min, "range_max must be > range_min");
400
401 let data = self.to_vec();
402 let bin_width = (range_max - range_min) / bins as f32;
403
404 let mut bin_edges = Vec::with_capacity(bins + 1);
406 for i in 0..=bins {
407 bin_edges.push(range_min + i as f32 * bin_width);
408 }
409
410 let mut hist = vec![0; bins];
412 for &val in data.iter() {
413 if val >= range_min && val <= range_max {
414 let bin_idx = ((val - range_min) / bin_width).floor() as usize;
415 let bin_idx = bin_idx.min(bins - 1); hist[bin_idx] += 1;
417 }
418 }
419
420 (hist, bin_edges)
421 }
422
423 pub fn bincount(&self) -> Vec<usize> {
434 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
435 let data = self.to_vec();
436
437 let max_val = data
439 .iter()
440 .map(|&x| x as usize)
441 .max()
442 .unwrap_or(0);
443
444 let mut counts = vec![0; max_val + 1];
445 for &val in data.iter() {
446 let idx = val as usize;
447 counts[idx] += 1;
448 }
449
450 counts
451 }
452
453 pub fn bincount_weighted(&self, weights: &Array) -> Vec<f32> {
467 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
468 assert_eq!(weights.dtype(), DType::Float32, "Only Float32 supported");
469 assert_eq!(
470 self.size(),
471 weights.size(),
472 "Array and weights must have same size"
473 );
474
475 let data = self.to_vec();
476 let weight_data = weights.to_vec();
477
478 let max_val = data
480 .iter()
481 .map(|&x| x as usize)
482 .max()
483 .unwrap_or(0);
484
485 let mut counts = vec![0.0; max_val + 1];
486 for (i, &val) in data.iter().enumerate() {
487 let idx = val as usize;
488 counts[idx] += weight_data[i];
489 }
490
491 counts
492 }
493
494 pub fn partition(&self, kth: usize) -> Array {
511 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
512 assert!(kth < self.size(), "kth must be less than array size");
513
514 let mut data = self.to_vec();
515
516 let n = data.len();
518 let mut left = 0;
519 let mut right = n - 1;
520
521 while left < right {
522 let pivot = data[right];
523 let mut store_idx = left;
524
525 for i in left..right {
526 if data[i] < pivot {
527 data.swap(i, store_idx);
528 store_idx += 1;
529 }
530 }
531 data.swap(store_idx, right);
532
533 if store_idx == kth {
534 break;
535 } else if store_idx < kth {
536 left = store_idx + 1;
537 } else {
538 right = store_idx.saturating_sub(1);
539 }
540 }
541
542 Array::from_vec(data, self.shape().clone())
543 }
544
545 pub fn argpartition(&self, kth: usize) -> Vec<usize> {
556 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
557 assert!(kth < self.size(), "kth must be less than array size");
558
559 let data = self.to_vec();
560 let mut indices: Vec<usize> = (0..data.len()).collect();
561
562 let n = indices.len();
564 let mut left = 0;
565 let mut right = n - 1;
566
567 while left < right {
568 let pivot_val = data[indices[right]];
569 let mut store_idx = left;
570
571 for i in left..right {
572 if data[indices[i]] < pivot_val {
573 indices.swap(i, store_idx);
574 store_idx += 1;
575 }
576 }
577 indices.swap(store_idx, right);
578
579 if store_idx == kth {
580 break;
581 } else if store_idx < kth {
582 left = store_idx + 1;
583 } else {
584 right = store_idx.saturating_sub(1);
585 }
586 }
587
588 indices
589 }
590
591 pub fn lexsort(keys: &[&Array]) -> Vec<usize> {
607 assert!(!keys.is_empty(), "Need at least one key");
608
609 let n = keys[0].size();
610 for key in keys {
611 assert_eq!(key.size(), n, "All keys must have same length");
612 assert_eq!(key.dtype(), DType::Float32, "Only Float32 supported");
613 }
614
615 let mut indices: Vec<usize> = (0..n).collect();
616
617 indices.sort_by(|&a, &b| {
619 for key in keys.iter().rev() {
621 let key_data = key.to_vec();
622 let cmp = key_data[a].partial_cmp(&key_data[b]).unwrap();
623 if cmp != std::cmp::Ordering::Equal {
624 return cmp;
625 }
626 }
627 std::cmp::Ordering::Equal
628 });
629
630 indices
631 }
632
633 pub fn median_select(&self) -> f32 {
646 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
647 let n = self.size();
648 assert!(n > 0, "Array must not be empty");
649
650 let partitioned = self.partition(n / 2);
651 let data = partitioned.to_vec();
652
653 if n % 2 == 1 {
654 data[n / 2]
655 } else {
656 let left = self.partition(n / 2 - 1);
658 let left_data = left.to_vec();
659 (left_data[n / 2 - 1] + data[n / 2]) / 2.0
660 }
661 }
662
663 pub fn select_kth(&self, k: usize) -> f32 {
674 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
675 assert!(k < self.size(), "k must be less than array size");
676
677 let partitioned = self.partition(k);
678 partitioned.to_vec()[k]
679 }
680
681 pub fn sort_axis(&self, _axis: i32) -> Array {
694 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
695 let shape = self.shape().as_slice();
696
697 if shape.len() == 1 {
698 return self.sort();
699 }
700
701 if shape.len() == 2 {
702 let (rows, cols) = (shape[0], shape[1]);
704 let data = self.to_vec();
705 let mut result = Vec::with_capacity(data.len());
706
707 for r in 0..rows {
708 let start = r * cols;
709 let mut row: Vec<f32> = data[start..start + cols].to_vec();
710 row.sort_by(|a, b| a.partial_cmp(b).unwrap());
711 result.extend(row);
712 }
713
714 return Array::from_vec(result, self.shape().clone());
715 }
716
717 self.sort()
719 }
720
721 pub fn stable_argsort(&self) -> Vec<usize> {
735 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
736 let data = self.to_vec();
737 let mut indexed: Vec<(usize, f32)> = data.into_iter().enumerate().collect();
738 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
739 indexed.into_iter().map(|(i, _)| i).collect()
740 }
741}
742
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D array holding `values`.
    fn arr(values: &[f32]) -> Array {
        Array::from_vec(values.to_vec(), Shape::new(vec![values.len()]))
    }

    #[test]
    fn test_sort() {
        let a = arr(&[3.0, 1.0, 4.0, 2.0]);
        assert_eq!(a.sort().to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_sort_descending() {
        let a = arr(&[3.0, 1.0, 4.0, 2.0]);
        assert_eq!(a.sort_descending().to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let a = arr(&[3.0, 1.0, 4.0, 2.0]);
        assert_eq!(a.argsort(), vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_argsort_descending() {
        let a = arr(&[3.0, 1.0, 4.0, 2.0]);
        assert_eq!(a.argsort_descending(), vec![2, 0, 3, 1]);
    }

    #[test]
    fn test_stable_argsort_keeps_tie_order() {
        let a = arr(&[2.0, 1.0, 2.0, 1.0]);
        assert_eq!(a.stable_argsort(), vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_top_k() {
        let a = arr(&[3.0, 1.0, 4.0, 2.0, 5.0]);
        assert_eq!(a.top_k_smallest(2), vec![1, 3]);
        assert_eq!(a.top_k_largest(2), vec![4, 2]);
    }

    #[test]
    fn test_searchsorted() {
        let a = arr(&[1.0, 3.0, 5.0, 7.0]);
        assert_eq!(a.searchsorted(4.0), 2);
        assert_eq!(a.searchsorted(0.0), 0);
        assert_eq!(a.searchsorted(10.0), 4);
        assert_eq!(a.searchsorted(5.0), 2);
    }

    #[test]
    fn test_unique() {
        let a = arr(&[1.0, 2.0, 1.0, 3.0, 2.0]);
        assert_eq!(a.unique().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_unique_counts() {
        let a = arr(&[1.0, 2.0, 1.0, 3.0, 2.0, 1.0]);
        let (values, counts) = a.unique_counts();
        assert_eq!(values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(counts, vec![3, 2, 1]);
    }

    #[test]
    fn test_setdiff1d() {
        let a = arr(&[1.0, 2.0, 3.0, 4.0]);
        let b = arr(&[2.0, 4.0, 5.0]);
        assert_eq!(a.setdiff1d(&b).to_vec(), vec![1.0, 3.0]);
    }

    #[test]
    fn test_union1d() {
        let a = arr(&[1.0, 2.0, 3.0]);
        let b = arr(&[2.0, 3.0, 4.0]);
        assert_eq!(a.union1d(&b).to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_intersect1d() {
        let a = arr(&[1.0, 2.0, 3.0]);
        let b = arr(&[2.0, 3.0, 4.0]);
        assert_eq!(a.intersect1d(&b).to_vec(), vec![2.0, 3.0]);
    }

    #[test]
    fn test_setxor1d() {
        let a = arr(&[1.0, 2.0, 3.0]);
        let b = arr(&[2.0, 3.0, 4.0]);
        assert_eq!(a.setxor1d(&b).to_vec(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_in1d() {
        let a = arr(&[1.0, 2.0, 3.0, 4.0]);
        let b = arr(&[2.0, 4.0]);
        assert_eq!(a.in1d(&b).to_vec(), vec![0.0, 1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_digitize() {
        let x = arr(&[0.2, 6.4, 3.0, 1.6]);
        let bins = arr(&[0.0, 1.0, 2.5, 4.0, 10.0]);
        assert_eq!(x.digitize(&bins), vec![1, 4, 3, 2]);
    }

    #[test]
    fn test_histogram() {
        let a = arr(&[1.0, 2.0, 1.0, 3.0, 2.0, 1.0]);
        let (hist, edges) = a.histogram(3, 0.0, 4.0);
        assert_eq!(hist, vec![3, 2, 1]);
        assert_eq!(edges.len(), 4);
    }

    #[test]
    fn test_bincount() {
        let a = arr(&[0.0, 1.0, 1.0, 3.0, 2.0, 1.0, 7.0]);
        assert_eq!(a.bincount(), vec![1, 3, 1, 1, 0, 0, 0, 1]);
    }

    #[test]
    fn test_bincount_weighted() {
        let a = arr(&[0.0, 1.0, 1.0, 2.0]);
        let weights = arr(&[0.3, 0.5, 0.2, 0.7]);
        let counts = a.bincount_weighted(&weights);
        assert!((counts[0] - 0.3).abs() < 1e-6);
        assert!((counts[1] - 0.7).abs() < 1e-6);
        assert!((counts[2] - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_partition() {
        let a = arr(&[5.0, 1.0, 4.0, 2.0, 3.0, 3.0]);
        let p = a.partition(2).to_vec();
        // Sorted order is [1, 2, 3, 3, 4, 5]; only the kth slot and the two sides are fixed.
        assert_eq!(p[2], 3.0);
        assert!(p[..2].iter().all(|&x| x <= 3.0));
        assert!(p[3..].iter().all(|&x| x >= 3.0));
    }

    #[test]
    fn test_argpartition() {
        let a = arr(&[5.0, 1.0, 4.0, 2.0, 3.0, 3.0]);
        let data = a.to_vec();
        let idx = a.argpartition(4);
        assert_eq!(data[idx[4]], 4.0);
        assert!(idx[..4].iter().all(|&i| data[i] <= 4.0));
        assert!(idx[5..].iter().all(|&i| data[i] >= 4.0));
        let mut seen = idx.clone();
        seen.sort_unstable();
        assert_eq!(seen, (0..6).collect::<Vec<_>>());
    }

    #[test]
    fn test_select_kth_and_median() {
        let a = arr(&[5.0, 1.0, 4.0, 2.0, 3.0, 3.0]);
        assert_eq!(a.select_kth(0), 1.0);
        assert_eq!(a.select_kth(5), 5.0);
        assert_eq!(a.median_select(), 3.0);
        assert_eq!(arr(&[4.0, 1.0, 3.0, 2.0]).median_select(), 2.5);
        assert_eq!(arr(&[7.0, 1.0, 3.0]).median_select(), 3.0);
    }

    #[test]
    fn test_lexsort() {
        // NumPy `lexsort` docs example: the last key is the primary sort key.
        let primary = arr(&[1.0, 5.0, 1.0, 4.0, 3.0, 4.0, 4.0]);
        let secondary = arr(&[9.0, 4.0, 0.0, 4.0, 0.0, 2.0, 1.0]);
        assert_eq!(
            Array::lexsort(&[&secondary, &primary]),
            vec![2, 0, 4, 6, 5, 3, 1]
        );
    }

    #[test]
    fn test_sort_axis_last_axis() {
        let m = Array::from_vec(vec![3.0, 8.0, 2.0, 1.0, 7.0, 9.0], Shape::new(vec![2, 3]));
        assert_eq!(m.sort_axis(-1).to_vec(), vec![2.0, 3.0, 8.0, 1.0, 7.0, 9.0]);
        assert_eq!(m.sort_axis(1).to_vec(), vec![2.0, 3.0, 8.0, 1.0, 7.0, 9.0]);
    }
}