jax_rs/ops/
sort.rs

//! Sorting, searching, selection, set, and counting operations on `Array`.
2
3use crate::{Array, DType, Shape};
4
5impl Array {
6    /// Sort array elements.
7    ///
8    /// # Examples
9    ///
10    /// ```
11    /// # use jax_rs::{Array, Shape};
12    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
13    /// let sorted = a.sort();
14    /// assert_eq!(sorted.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
15    /// ```
16    pub fn sort(&self) -> Array {
17        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
18        let mut data = self.to_vec();
19        data.sort_by(|a, b| a.partial_cmp(b).unwrap());
20        Array::from_vec(data, self.shape().clone())
21    }
22
23    /// Sort array elements in descending order.
24    ///
25    /// # Examples
26    ///
27    /// ```
28    /// # use jax_rs::{Array, Shape};
29    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
30    /// let sorted = a.sort_descending();
31    /// assert_eq!(sorted.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
32    /// ```
33    pub fn sort_descending(&self) -> Array {
34        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
35        let mut data = self.to_vec();
36        data.sort_by(|a, b| b.partial_cmp(a).unwrap());
37        Array::from_vec(data, self.shape().clone())
38    }
39
40    /// Return indices that would sort the array.
41    ///
42    /// # Examples
43    ///
44    /// ```
45    /// # use jax_rs::{Array, Shape};
46    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
47    /// let indices = a.argsort();
48    /// assert_eq!(indices, vec![1, 3, 0, 2]);
49    /// ```
50    pub fn argsort(&self) -> Vec<usize> {
51        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
52        let data = self.to_vec();
53        let mut indices: Vec<usize> = (0..data.len()).collect();
54        indices.sort_by(|&a, &b| data[a].partial_cmp(&data[b]).unwrap());
55        indices
56    }
57
58    /// Return indices that would sort the array in descending order.
59    ///
60    /// # Examples
61    ///
62    /// ```
63    /// # use jax_rs::{Array, Shape};
64    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
65    /// let indices = a.argsort_descending();
66    /// assert_eq!(indices, vec![2, 0, 3, 1]);
67    /// ```
68    pub fn argsort_descending(&self) -> Vec<usize> {
69        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
70        let data = self.to_vec();
71        let mut indices: Vec<usize> = (0..data.len()).collect();
72        indices.sort_by(|&a, &b| data[b].partial_cmp(&data[a]).unwrap());
73        indices
74    }
75
76    /// Find the k smallest elements and return their indices.
77    ///
78    /// # Examples
79    ///
80    /// ```
81    /// # use jax_rs::{Array, Shape};
82    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0, 5.0], Shape::new(vec![5]));
83    /// let top2 = a.top_k_smallest(2);
84    /// assert_eq!(top2, vec![1, 3]);
85    /// ```
86    pub fn top_k_smallest(&self, k: usize) -> Vec<usize> {
87        assert!(k <= self.size(), "k must be <= array size");
88        let indices = self.argsort();
89        indices.into_iter().take(k).collect()
90    }
91
92    /// Find the k largest elements and return their indices.
93    ///
94    /// # Examples
95    ///
96    /// ```
97    /// # use jax_rs::{Array, Shape};
98    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0, 5.0], Shape::new(vec![5]));
99    /// let top2 = a.top_k_largest(2);
100    /// assert_eq!(top2, vec![4, 2]);
101    /// ```
102    pub fn top_k_largest(&self, k: usize) -> Vec<usize> {
103        assert!(k <= self.size(), "k must be <= array size");
104        let indices = self.argsort_descending();
105        indices.into_iter().take(k).collect()
106    }
107
108    /// Find indices where elements should be inserted to maintain order.
109    ///
110    /// # Examples
111    ///
112    /// ```
113    /// # use jax_rs::{Array, Shape};
114    /// let a = Array::from_vec(vec![1.0, 3.0, 5.0, 7.0], Shape::new(vec![4]));
115    /// let idx = a.searchsorted(4.0);
116    /// assert_eq!(idx, 2);
117    /// ```
118    pub fn searchsorted(&self, value: f32) -> usize {
119        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
120        let data = self.to_vec();
121
122        // Binary search
123        let mut left = 0;
124        let mut right = data.len();
125
126        while left < right {
127            let mid = left + (right - left) / 2;
128            if data[mid] < value {
129                left = mid + 1;
130            } else {
131                right = mid;
132            }
133        }
134
135        left
136    }
137
138    /// Find unique elements in the array.
139    ///
140    /// # Examples
141    ///
142    /// ```
143    /// # use jax_rs::{Array, Shape};
144    /// let a = Array::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0], Shape::new(vec![5]));
145    /// let unique = a.unique();
146    /// assert_eq!(unique.to_vec(), vec![1.0, 2.0, 3.0]);
147    /// ```
148    pub fn unique(&self) -> Array {
149        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
150        let mut data = self.to_vec();
151        data.sort_by(|a, b| a.partial_cmp(b).unwrap());
152        data.dedup_by(|a, b| (*a - *b).abs() < 1e-7);
153        let len = data.len();
154        Array::from_vec(data, Shape::new(vec![len]))
155    }
156
157    /// Count occurrences of each unique value.
158    ///
159    /// Returns (unique_values, counts).
160    ///
161    /// # Examples
162    ///
163    /// ```
164    /// # use jax_rs::{Array, Shape};
165    /// let a = Array::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], Shape::new(vec![6]));
166    /// let (values, counts) = a.unique_counts();
167    /// assert_eq!(values.to_vec(), vec![1.0, 2.0, 3.0]);
168    /// assert_eq!(counts, vec![3, 2, 1]);
169    /// ```
170    pub fn unique_counts(&self) -> (Array, Vec<usize>) {
171        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
172        let mut data = self.to_vec();
173        data.sort_by(|a, b| a.partial_cmp(b).unwrap());
174
175        let mut unique_vals = Vec::new();
176        let mut counts = Vec::new();
177
178        if !data.is_empty() {
179            let mut current = data[0];
180            let mut count = 1;
181
182            for &val in data.iter().skip(1) {
183                if (val - current).abs() < 1e-7 {
184                    count += 1;
185                } else {
186                    unique_vals.push(current);
187                    counts.push(count);
188                    current = val;
189                    count = 1;
190                }
191            }
192            unique_vals.push(current);
193            counts.push(count);
194        }
195
196        (
197            Array::from_vec(unique_vals, Shape::new(vec![counts.len()])),
198            counts,
199        )
200    }
201
202    /// Find the set difference of two arrays.
203    ///
204    /// Returns the unique values in the first array that are not in the second array.
205    ///
206    /// # Examples
207    ///
208    /// ```
209    /// # use jax_rs::{Array, Shape};
210    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
211    /// let b = Array::from_vec(vec![2.0, 4.0, 5.0], Shape::new(vec![3]));
212    /// let diff = a.setdiff1d(&b);
213    /// assert_eq!(diff.to_vec(), vec![1.0, 3.0]);
214    /// ```
215    pub fn setdiff1d(&self, other: &Array) -> Array {
216        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
217        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
218
219        let self_unique = self.unique();
220        let other_data = other.to_vec();
221
222        let result: Vec<f32> = self_unique
223            .to_vec()
224            .into_iter()
225            .filter(|&val| !other_data.iter().any(|&x| (x - val).abs() < 1e-7))
226            .collect();
227
228        let len = result.len();
229        Array::from_vec(result, Shape::new(vec![len]))
230    }
231
232    /// Find the union of two arrays.
233    ///
234    /// Returns the unique values that are in either of the two arrays.
235    ///
236    /// # Examples
237    ///
238    /// ```
239    /// # use jax_rs::{Array, Shape};
240    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
241    /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
242    /// let union = a.union1d(&b);
243    /// assert_eq!(union.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
244    /// ```
245    pub fn union1d(&self, other: &Array) -> Array {
246        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
247        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
248
249        let mut combined = self.to_vec();
250        combined.extend(other.to_vec());
251
252        let temp = Array::from_vec(combined, Shape::new(vec![self.size() + other.size()]));
253        temp.unique()
254    }
255
256    /// Find the intersection of two arrays.
257    ///
258    /// Returns the unique values that are in both arrays.
259    ///
260    /// # Examples
261    ///
262    /// ```
263    /// # use jax_rs::{Array, Shape};
264    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
265    /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
266    /// let intersect = a.intersect1d(&b);
267    /// assert_eq!(intersect.to_vec(), vec![2.0, 3.0]);
268    /// ```
269    pub fn intersect1d(&self, other: &Array) -> Array {
270        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
271        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
272
273        let self_unique = self.unique();
274        let other_data = other.to_vec();
275
276        let result: Vec<f32> = self_unique
277            .to_vec()
278            .into_iter()
279            .filter(|&val| other_data.iter().any(|&x| (x - val).abs() < 1e-7))
280            .collect();
281
282        let len = result.len();
283        Array::from_vec(result, Shape::new(vec![len]))
284    }
285
286    /// Find the exclusive-or of two arrays.
287    ///
288    /// Returns the unique values that are in exactly one of the two arrays.
289    ///
290    /// # Examples
291    ///
292    /// ```
293    /// # use jax_rs::{Array, Shape};
294    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
295    /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
296    /// let xor = a.setxor1d(&b);
297    /// assert_eq!(xor.to_vec(), vec![1.0, 4.0]);
298    /// ```
299    pub fn setxor1d(&self, other: &Array) -> Array {
300        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
301        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
302
303        let union = self.union1d(other);
304        let intersect = self.intersect1d(other);
305        union.setdiff1d(&intersect)
306    }
307
308    /// Test whether each element of a 1D array is also present in a second array.
309    ///
310    /// Returns a boolean-like array (1.0 for true, 0.0 for false).
311    ///
312    /// # Examples
313    ///
314    /// ```
315    /// # use jax_rs::{Array, Shape};
316    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
317    /// let b = Array::from_vec(vec![2.0, 4.0], Shape::new(vec![2]));
318    /// let result = a.in1d(&b);
319    /// assert_eq!(result.to_vec(), vec![0.0, 1.0, 0.0, 1.0]);
320    /// ```
321    pub fn in1d(&self, test_elements: &Array) -> Array {
322        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
323        assert_eq!(
324            test_elements.dtype(),
325            DType::Float32,
326            "Only Float32 supported"
327        );
328
329        let data = self.to_vec();
330        let test_data = test_elements.to_vec();
331
332        let result: Vec<f32> = data
333            .iter()
334            .map(|&val| {
335                if test_data.iter().any(|&x| (x - val).abs() < 1e-7) {
336                    1.0
337                } else {
338                    0.0
339                }
340            })
341            .collect();
342
343        Array::from_vec(result, self.shape().clone())
344    }
345
346    /// Return the indices of the bins to which each value belongs.
347    ///
348    /// # Examples
349    ///
350    /// ```
351    /// # use jax_rs::{Array, Shape};
352    /// let x = Array::from_vec(vec![0.2, 6.4, 3.0, 1.6], Shape::new(vec![4]));
353    /// let bins = Array::from_vec(vec![0.0, 1.0, 2.5, 4.0, 10.0], Shape::new(vec![5]));
354    /// let indices = x.digitize(&bins);
355    /// assert_eq!(indices, vec![1, 4, 3, 2]);
356    /// ```
357    pub fn digitize(&self, bins: &Array) -> Vec<usize> {
358        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
359        assert_eq!(bins.dtype(), DType::Float32, "Only Float32 supported");
360
361        let data = self.to_vec();
362        let bin_edges = bins.to_vec();
363
364        data.iter()
365            .map(|&val| {
366                // Find the bin index using binary search
367                let mut left = 0;
368                let mut right = bin_edges.len();
369
370                while left < right {
371                    let mid = left + (right - left) / 2;
372                    if bin_edges[mid] <= val {
373                        left = mid + 1;
374                    } else {
375                        right = mid;
376                    }
377                }
378                left
379            })
380            .collect()
381    }
382
383    /// Compute the histogram of a dataset.
384    ///
385    /// Returns (hist, bin_edges) where hist contains the counts and bin_edges
386    /// contains the bin boundaries.
387    ///
388    /// # Examples
389    ///
390    /// ```
391    /// # use jax_rs::{Array, Shape};
392    /// let a = Array::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], Shape::new(vec![6]));
393    /// let (hist, edges) = a.histogram(3, 0.0, 4.0);
394    /// assert_eq!(hist, vec![3, 2, 1]);
395    /// ```
396    pub fn histogram(&self, bins: usize, range_min: f32, range_max: f32) -> (Vec<usize>, Vec<f32>) {
397        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
398        assert!(bins > 0, "Number of bins must be positive");
399        assert!(range_max > range_min, "range_max must be > range_min");
400
401        let data = self.to_vec();
402        let bin_width = (range_max - range_min) / bins as f32;
403
404        // Create bin edges
405        let mut bin_edges = Vec::with_capacity(bins + 1);
406        for i in 0..=bins {
407            bin_edges.push(range_min + i as f32 * bin_width);
408        }
409
410        // Count values in each bin
411        let mut hist = vec![0; bins];
412        for &val in data.iter() {
413            if val >= range_min && val <= range_max {
414                let bin_idx = ((val - range_min) / bin_width).floor() as usize;
415                let bin_idx = bin_idx.min(bins - 1); // Handle edge case where val == range_max
416                hist[bin_idx] += 1;
417            }
418        }
419
420        (hist, bin_edges)
421    }
422
423    /// Count number of occurrences of each value in array of non-negative integers.
424    ///
425    /// # Examples
426    ///
427    /// ```
428    /// # use jax_rs::{Array, Shape};
429    /// let a = Array::from_vec(vec![0.0, 1.0, 1.0, 3.0, 2.0, 1.0, 7.0], Shape::new(vec![7]));
430    /// let counts = a.bincount();
431    /// assert_eq!(counts, vec![1, 3, 1, 1, 0, 0, 0, 1]);
432    /// ```
433    pub fn bincount(&self) -> Vec<usize> {
434        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
435        let data = self.to_vec();
436
437        // Find the maximum value to determine array size
438        let max_val = data
439            .iter()
440            .map(|&x| x as usize)
441            .max()
442            .unwrap_or(0);
443
444        let mut counts = vec![0; max_val + 1];
445        for &val in data.iter() {
446            let idx = val as usize;
447            counts[idx] += 1;
448        }
449
450        counts
451    }
452
453    /// Count number of occurrences with optional weights.
454    ///
455    /// # Examples
456    ///
457    /// ```
458    /// # use jax_rs::{Array, Shape};
459    /// let a = Array::from_vec(vec![0.0, 1.0, 1.0, 2.0], Shape::new(vec![4]));
460    /// let weights = Array::from_vec(vec![0.3, 0.5, 0.2, 0.7], Shape::new(vec![4]));
461    /// let counts = a.bincount_weighted(&weights);
462    /// assert!((counts[0] - 0.3).abs() < 1e-6);
463    /// assert!((counts[1] - 0.7).abs() < 1e-6);
464    /// assert!((counts[2] - 0.7).abs() < 1e-6);
465    /// ```
466    pub fn bincount_weighted(&self, weights: &Array) -> Vec<f32> {
467        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
468        assert_eq!(weights.dtype(), DType::Float32, "Only Float32 supported");
469        assert_eq!(
470            self.size(),
471            weights.size(),
472            "Array and weights must have same size"
473        );
474
475        let data = self.to_vec();
476        let weight_data = weights.to_vec();
477
478        // Find the maximum value to determine array size
479        let max_val = data
480            .iter()
481            .map(|&x| x as usize)
482            .max()
483            .unwrap_or(0);
484
485        let mut counts = vec![0.0; max_val + 1];
486        for (i, &val) in data.iter().enumerate() {
487            let idx = val as usize;
488            counts[idx] += weight_data[i];
489        }
490
491        counts
492    }
493
494    /// Partially sort array so that the k-th element is in sorted position.
495    ///
496    /// Elements smaller than the k-th element are moved before it,
497    /// and elements larger are moved after it.
498    ///
499    /// # Examples
500    ///
501    /// ```
502    /// # use jax_rs::{Array, Shape};
503    /// let a = Array::from_vec(vec![3.0, 4.0, 2.0, 1.0], Shape::new(vec![4]));
504    /// let partitioned = a.partition(2);
505    /// // Element at index 2 is in correct sorted position
506    /// let data = partitioned.to_vec();
507    /// assert!(data[0] <= data[2] && data[1] <= data[2]);
508    /// assert!(data[2] <= data[3]);
509    /// ```
510    pub fn partition(&self, kth: usize) -> Array {
511        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
512        assert!(kth < self.size(), "kth must be less than array size");
513
514        let mut data = self.to_vec();
515
516        // Use selection algorithm (quickselect-style partitioning)
517        let n = data.len();
518        let mut left = 0;
519        let mut right = n - 1;
520
521        while left < right {
522            let pivot = data[right];
523            let mut store_idx = left;
524
525            for i in left..right {
526                if data[i] < pivot {
527                    data.swap(i, store_idx);
528                    store_idx += 1;
529                }
530            }
531            data.swap(store_idx, right);
532
533            if store_idx == kth {
534                break;
535            } else if store_idx < kth {
536                left = store_idx + 1;
537            } else {
538                right = store_idx.saturating_sub(1);
539            }
540        }
541
542        Array::from_vec(data, self.shape().clone())
543    }
544
545    /// Return indices that would partition the array.
546    ///
547    /// # Examples
548    ///
549    /// ```
550    /// # use jax_rs::{Array, Shape};
551    /// let a = Array::from_vec(vec![3.0, 4.0, 2.0, 1.0], Shape::new(vec![4]));
552    /// let indices = a.argpartition(2);
553    /// // Indices are such that a[indices[0]] and a[indices[1]] are smaller than a[indices[2]]
554    /// ```
555    pub fn argpartition(&self, kth: usize) -> Vec<usize> {
556        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
557        assert!(kth < self.size(), "kth must be less than array size");
558
559        let data = self.to_vec();
560        let mut indices: Vec<usize> = (0..data.len()).collect();
561
562        // Partition indices based on values
563        let n = indices.len();
564        let mut left = 0;
565        let mut right = n - 1;
566
567        while left < right {
568            let pivot_val = data[indices[right]];
569            let mut store_idx = left;
570
571            for i in left..right {
572                if data[indices[i]] < pivot_val {
573                    indices.swap(i, store_idx);
574                    store_idx += 1;
575                }
576            }
577            indices.swap(store_idx, right);
578
579            if store_idx == kth {
580                break;
581            } else if store_idx < kth {
582                left = store_idx + 1;
583            } else {
584                right = store_idx.saturating_sub(1);
585            }
586        }
587
588        indices
589    }
590
591    /// Perform indirect stable sort using a sequence of keys.
592    ///
593    /// Sort by the last key first, then by second-to-last, etc.
594    /// This is equivalent to numpy's lexsort.
595    ///
596    /// # Examples
597    ///
598    /// ```
599    /// # use jax_rs::{Array, Shape};
600    /// // Sort by surname, then by first name
601    /// let surnames = Array::from_vec(vec![1.0, 2.0, 1.0, 2.0], Shape::new(vec![4]));  // Hertz, Move, Hertz, Newton
602    /// let first_names = Array::from_vec(vec![1.0, 2.0, 2.0, 1.0], Shape::new(vec![4])); // Heinrich, Never, Lansen, Isaac
603    /// let indices = Array::lexsort(&[&first_names, &surnames]);
604    /// // Should be sorted by surname first, then first_name
605    /// ```
606    pub fn lexsort(keys: &[&Array]) -> Vec<usize> {
607        assert!(!keys.is_empty(), "Need at least one key");
608
609        let n = keys[0].size();
610        for key in keys {
611            assert_eq!(key.size(), n, "All keys must have same length");
612            assert_eq!(key.dtype(), DType::Float32, "Only Float32 supported");
613        }
614
615        let mut indices: Vec<usize> = (0..n).collect();
616
617        // Sort by keys in reverse order (last key is primary sort key)
618        indices.sort_by(|&a, &b| {
619            // Compare from last key to first
620            for key in keys.iter().rev() {
621                let key_data = key.to_vec();
622                let cmp = key_data[a].partial_cmp(&key_data[b]).unwrap();
623                if cmp != std::cmp::Ordering::Equal {
624                    return cmp;
625                }
626            }
627            std::cmp::Ordering::Equal
628        });
629
630        indices
631    }
632
633    /// Return the median element without full sorting.
634    ///
635    /// Uses quickselect algorithm for O(n) average performance.
636    ///
637    /// # Examples
638    ///
639    /// ```
640    /// # use jax_rs::{Array, Shape};
641    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], Shape::new(vec![5]));
642    /// let med = a.median_select();
643    /// assert_eq!(med, 3.0);
644    /// ```
645    pub fn median_select(&self) -> f32 {
646        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
647        let n = self.size();
648        assert!(n > 0, "Array must not be empty");
649
650        let partitioned = self.partition(n / 2);
651        let data = partitioned.to_vec();
652
653        if n % 2 == 1 {
654            data[n / 2]
655        } else {
656            // For even length, also need the element before
657            let left = self.partition(n / 2 - 1);
658            let left_data = left.to_vec();
659            (left_data[n / 2 - 1] + data[n / 2]) / 2.0
660        }
661    }
662
663    /// Return the k-th smallest element using selection algorithm.
664    ///
665    /// # Examples
666    ///
667    /// ```
668    /// # use jax_rs::{Array, Shape};
669    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], Shape::new(vec![5]));
670    /// let kth = a.select_kth(2); // 0-indexed, so 3rd smallest
671    /// assert_eq!(kth, 3.0);
672    /// ```
673    pub fn select_kth(&self, k: usize) -> f32 {
674        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
675        assert!(k < self.size(), "k must be less than array size");
676
677        let partitioned = self.partition(k);
678        partitioned.to_vec()[k]
679    }
680
681    /// Sort array along the last axis.
682    ///
683    /// For 2D arrays, sorts each row independently.
684    ///
685    /// # Examples
686    ///
687    /// ```
688    /// # use jax_rs::{Array, Shape};
689    /// let a = Array::from_vec(vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0], Shape::new(vec![2, 3]));
690    /// let sorted = a.sort_axis(-1);
691    /// assert_eq!(sorted.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
692    /// ```
693    pub fn sort_axis(&self, _axis: i32) -> Array {
694        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
695        let shape = self.shape().as_slice();
696
697        if shape.len() == 1 {
698            return self.sort();
699        }
700
701        if shape.len() == 2 {
702            // Sort each row
703            let (rows, cols) = (shape[0], shape[1]);
704            let data = self.to_vec();
705            let mut result = Vec::with_capacity(data.len());
706
707            for r in 0..rows {
708                let start = r * cols;
709                let mut row: Vec<f32> = data[start..start + cols].to_vec();
710                row.sort_by(|a, b| a.partial_cmp(b).unwrap());
711                result.extend(row);
712            }
713
714            return Array::from_vec(result, self.shape().clone());
715        }
716
717        // For higher dimensions, just sort flat
718        self.sort()
719    }
720
721    /// Return indices that would sort the array in stable order.
722    ///
723    /// Unlike argsort, stable_argsort preserves the relative order
724    /// of equal elements.
725    ///
726    /// # Examples
727    ///
728    /// ```
729    /// # use jax_rs::{Array, Shape};
730    /// let a = Array::from_vec(vec![3.0, 1.0, 1.0, 2.0], Shape::new(vec![4]));
731    /// let indices = a.stable_argsort();
732    /// assert_eq!(indices, vec![1, 2, 3, 0]); // The two 1.0s maintain order
733    /// ```
734    pub fn stable_argsort(&self) -> Vec<usize> {
735        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
736        let data = self.to_vec();
737        let mut indexed: Vec<(usize, f32)> = data.into_iter().enumerate().collect();
738        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
739        indexed.into_iter().map(|(i, _)| i).collect()
740    }
741}
742
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sort() {
        let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
        let sorted = a.sort();
        assert_eq!(sorted.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_sort_descending() {
        let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
        let sorted = a.sort_descending();
        assert_eq!(sorted.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
        let indices = a.argsort();
        assert_eq!(indices, vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_argsort_descending() {
        let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
        assert_eq!(a.argsort_descending(), vec![2, 0, 3, 1]);
    }

    #[test]
    fn test_stable_argsort_keeps_tie_order() {
        let a = Array::from_vec(vec![3.0, 1.0, 1.0, 2.0], Shape::new(vec![4]));
        assert_eq!(a.stable_argsort(), vec![1, 2, 3, 0]);
    }

    #[test]
    fn test_top_k() {
        let a = Array::from_vec(
            vec![3.0, 1.0, 4.0, 2.0, 5.0],
            Shape::new(vec![5]),
        );
        let smallest = a.top_k_smallest(2);
        assert_eq!(smallest, vec![1, 3]);

        let largest = a.top_k_largest(2);
        assert_eq!(largest, vec![4, 2]);
    }

    #[test]
    fn test_top_k_zero_is_empty() {
        let a = Array::from_vec(vec![3.0, 1.0, 4.0], Shape::new(vec![3]));
        assert!(a.top_k_smallest(0).is_empty());
        assert!(a.top_k_largest(0).is_empty());
    }

    #[test]
    fn test_searchsorted() {
        let a = Array::from_vec(vec![1.0, 3.0, 5.0, 7.0], Shape::new(vec![4]));
        assert_eq!(a.searchsorted(4.0), 2);
        assert_eq!(a.searchsorted(0.0), 0);
        assert_eq!(a.searchsorted(10.0), 4);
        assert_eq!(a.searchsorted(5.0), 2);
    }

    #[test]
    fn test_unique() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 1.0, 3.0, 2.0],
            Shape::new(vec![5]),
        );
        let unique = a.unique();
        assert_eq!(unique.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_unique_counts() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0],
            Shape::new(vec![6]),
        );
        let (values, counts) = a.unique_counts();
        assert_eq!(values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(counts, vec![3, 2, 1]);
    }

    #[test]
    fn test_unique_counts_single_element() {
        let a = Array::from_vec(vec![5.0], Shape::new(vec![1]));
        let (values, counts) = a.unique_counts();
        assert_eq!(values.to_vec(), vec![5.0]);
        assert_eq!(counts, vec![1]);
    }

    #[test]
    fn test_setdiff1d() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
        let b = Array::from_vec(vec![2.0, 4.0, 5.0], Shape::new(vec![3]));
        let diff = a.setdiff1d(&b);
        assert_eq!(diff.to_vec(), vec![1.0, 3.0]);
    }

    #[test]
    fn test_union1d() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
        let union = a.union1d(&b);
        assert_eq!(union.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_union1d_with_duplicates() {
        let a = Array::from_vec(vec![1.0, 1.0, 2.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![2.0, 3.0, 3.0], Shape::new(vec![3]));
        assert_eq!(a.union1d(&b).to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_intersect1d() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
        let intersect = a.intersect1d(&b);
        assert_eq!(intersect.to_vec(), vec![2.0, 3.0]);
    }

    #[test]
    fn test_setxor1d() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
        let xor = a.setxor1d(&b);
        assert_eq!(xor.to_vec(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_in1d() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
        let b = Array::from_vec(vec![2.0, 4.0], Shape::new(vec![2]));
        let result = a.in1d(&b);
        assert_eq!(result.to_vec(), vec![0.0, 1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_digitize() {
        let x = Array::from_vec(vec![0.2, 6.4, 3.0, 1.6], Shape::new(vec![4]));
        let bins =
            Array::from_vec(vec![0.0, 1.0, 2.5, 4.0, 10.0], Shape::new(vec![5]));
        let indices = x.digitize(&bins);
        assert_eq!(indices, vec![1, 4, 3, 2]);
    }

    #[test]
    fn test_digitize_boundaries() {
        // Below the first edge -> 0; exactly on an edge -> right-hand bin;
        // at the last edge -> number of edges.
        let x = Array::from_vec(vec![-1.0, 1.0, 10.0], Shape::new(vec![3]));
        let bins =
            Array::from_vec(vec![0.0, 1.0, 2.5, 4.0, 10.0], Shape::new(vec![5]));
        assert_eq!(x.digitize(&bins), vec![0, 2, 5]);
    }

    #[test]
    fn test_histogram() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0],
            Shape::new(vec![6]),
        );
        let (hist, edges) = a.histogram(3, 0.0, 4.0);
        assert_eq!(hist, vec![3, 2, 1]);
        assert_eq!(edges.len(), 4); // bins + 1
        assert_eq!(edges[0], 0.0);
    }

    #[test]
    fn test_bincount() {
        let a = Array::from_vec(
            vec![0.0, 1.0, 1.0, 3.0, 2.0, 1.0, 7.0],
            Shape::new(vec![7]),
        );
        let counts = a.bincount();
        assert_eq!(counts, vec![1, 3, 1, 1, 0, 0, 0, 1]);
    }

    #[test]
    fn test_bincount_weighted() {
        let a = Array::from_vec(vec![0.0, 1.0, 1.0, 2.0], Shape::new(vec![4]));
        let weights =
            Array::from_vec(vec![0.3, 0.5, 0.2, 0.7], Shape::new(vec![4]));
        let counts = a.bincount_weighted(&weights);
        assert!((counts[0] - 0.3).abs() < 1e-6);
        assert!((counts[1] - 0.7).abs() < 1e-6);
        assert!((counts[2] - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_partition() {
        let a = Array::from_vec(vec![3.0, 4.0, 2.0, 1.0], Shape::new(vec![4]));
        let data = a.partition(2).to_vec();
        assert_eq!(data[2], 3.0);
        assert!(data[..2].iter().all(|&x| x <= 3.0));
        assert!(data[3] >= 3.0);
    }

    #[test]
    fn test_argpartition() {
        let a = Array::from_vec(vec![3.0, 4.0, 2.0, 1.0], Shape::new(vec![4]));
        let values = a.to_vec();
        let indices = a.argpartition(2);

        let mut seen = indices.clone();
        seen.sort_unstable();
        assert_eq!(seen, vec![0, 1, 2, 3]);

        assert_eq!(values[indices[2]], 3.0);
        assert!(indices[..2].iter().all(|&i| values[i] <= 3.0));
        assert!(values[indices[3]] >= 3.0);
    }

    #[test]
    fn test_median_select() {
        let odd = Array::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], Shape::new(vec![5]));
        assert_eq!(odd.median_select(), 3.0);

        let even = Array::from_vec(vec![4.0, 1.0, 3.0, 2.0], Shape::new(vec![4]));
        assert_eq!(even.median_select(), 2.5);
    }

    #[test]
    fn test_select_kth() {
        let a = Array::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], Shape::new(vec![5]));
        assert_eq!(a.select_kth(0), 1.0);
        assert_eq!(a.select_kth(2), 3.0);
        assert_eq!(a.select_kth(4), 5.0);
    }

    #[test]
    fn test_lexsort() {
        let surnames = Array::from_vec(vec![1.0, 2.0, 1.0, 2.0], Shape::new(vec![4]));
        let first_names = Array::from_vec(vec![1.0, 2.0, 2.0, 1.0], Shape::new(vec![4]));
        let indices = Array::lexsort(&[&first_names, &surnames]);
        assert_eq!(indices, vec![0, 2, 3, 1]);
    }

    #[test]
    fn test_sort_axis_rows() {
        let a = Array::from_vec(vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0], Shape::new(vec![2, 3]));
        assert_eq!(a.sort_axis(-1).to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }
}