runmat_runtime/builtins/array/sorting_sets/
argsort.rs

1//! MATLAB-compatible `argsort` builtin returning permutation indices.
2
3use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use super::sort;
7use crate::builtins::common::spec::{
8    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
10};
11#[cfg(feature = "doc_export")]
12use crate::register_builtin_doc_text;
13use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
14
/// Markdown help page for the `argsort` builtin: YAML front matter (title, category,
/// keywords, GPU/fusion support flags, test locations) followed by the prose body with
/// usage examples and an FAQ.
///
/// Only compiled when the `doc_export` feature is enabled; this file registers it under
/// the name `"argsort"` through `register_builtin_doc_text!`.
#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "argsort"
category: "array/sorting_sets"
keywords: ["argsort", "sort", "indices", "permutation", "gpu"]
summary: "Return the permutation indices that would sort tensors along a dimension."
references:
  - https://www.mathworks.com/help/matlab/ref/sort.html
gpu_support:
  elementwise: false
  reduction: false
  precisions: ["f32", "f64"]
  broadcasting: "none"
  notes: "Uses the same sort kernels as `sort`; falls back to host evaluation when the provider lacks `sort_dim`."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 1
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::array::sorting_sets::argsort::tests"
  integration: "builtins::array::sorting_sets::argsort::tests::argsort_gpu_roundtrip"
---

# What does the `argsort` function do in MATLAB / RunMat?
`argsort(X)` returns the permutation indices that order `X` the same way `sort(X)` would. It matches the indices produced by `[~, I] = sort(X, ...)` in MathWorks MATLAB and honours the same argument forms for dimensions, directions, and comparison methods.

## How does the `argsort` function behave in MATLAB / RunMat?
- Operates along the first non-singleton dimension by default. Pass a dimension argument to override.
- Accepts the same direction keywords as `sort`: `'ascend'` (default) or `'descend'`.
- Supports `'ComparisonMethod'` values `'auto'`, `'real'`, and `'abs'` for real and complex inputs.
- Returns indices as double-precision tensors using MATLAB's one-based indexing.
- Treats NaN values as missing: they appear at the end for ascending permutations and at the beginning for descending permutations.
- Acts as a residency sink. GPU tensors are gathered when the active provider does not expose a specialised sort kernel.

## GPU execution in RunMat
- `argsort` shares the `sort_dim` provider hook with the `sort` builtin. When implemented, indices are computed without leaving the device.
- If the provider lacks `sort_dim`, RunMat gathers tensors to host memory, evaluates the permutation, and returns host-resident indices.
- Outputs are always host-resident double tensors because permutation indices are consumed immediately by host-side logic (e.g., indexing).

## Examples of using `argsort` in MATLAB / RunMat

### Getting indices that sort a vector
```matlab
A = [4; 1; 3];
idx = argsort(A);
```
Expected output:
```matlab
idx =
     2
     3
     1
```

### Reordering data with the permutation indices
```matlab
A = [3 9 1 5];
idx = argsort(A);
sorted = A(idx);
```
Expected output:
```matlab
sorted =
     1     3     5     9
```

### Sorting along a specific dimension
```matlab
A = [1 6 4; 2 3 5];
idx = argsort(A, 2);
```
Expected output:
```matlab
idx =
     1     3     2
     1     2     3
```

### Descending order permutations
```matlab
A = [10 4 7 9];
idx = argsort(A, 'descend');
```
Expected output:
```matlab
idx =
     1     4     3     2
```

### Using `ComparisonMethod` to sort by magnitude
```matlab
A = [-8 -1 3 -2];
idx = argsort(A, 'ComparisonMethod', 'abs');
```
Expected output:
```matlab
idx =
     2     4     3     1
```

### Handling NaN values during permutation
```matlab
A = [NaN 4 1 2];
idx = argsort(A);
```
Expected output:
```matlab
idx =
     3     4     2     1
```

### Argsort on GPU tensors falls back gracefully
```matlab
G = gpuArray(randn(5, 1));
idx = argsort(G);
```
RunMat gathers `G` to the host when no device sort kernel is available, ensuring the returned indices match MATLAB exactly.

## FAQ

### How is `argsort` different from `sort`?
`argsort` returns only the permutation indices. It behaves like calling `[~, I] = sort(X, ...)` without materialising the sorted values.

### Are the indices one-based like MATLAB?
Yes. All indices follow MATLAB's one-based convention so they can be used directly with subsequent indexing operations.

### Does `argsort` support the same arguments as `sort`?
Yes. Dimension arguments, direction keywords, and `'ComparisonMethod'` behave exactly like they do for `sort`.

### How are NaN values ordered?
NaNs are treated as missing. They appear at the end for ascending permutations and at the beginning for descending permutations, matching MATLAB.

### Can I call `argsort` on GPU arrays?
Yes. When the active provider implements the `sort_dim` hook, permutations stay on the device. Otherwise tensors are gathered automatically and sorted on the host.

### Is the permutation stable?
Yes. Equal elements keep their relative order so that `argsort` remains consistent with MATLAB's stable sorting semantics.

### What type is returned?
A double-precision tensor (or scalar) with the same shape as the input, containing permutation indices.

### Does `argsort` mutate its input?
No. It only returns indices. Combine the result with indexing (`A(idx)`) to obtain reordered values when needed.

## See also
[sort](./sort), [sortrows](./sortrows), [randperm](../../array/creation/randperm), [max](../../math/reduction/max), [min](../../math/reduction/min)

## Source & Feedback
- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/argsort.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/argsort.rs)
- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
"#;
168
169pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
170    name: "argsort",
171    op_kind: GpuOpKind::Custom("sort"),
172    supported_precisions: &[ScalarType::F32, ScalarType::F64],
173    broadcast: BroadcastSemantics::None,
174    provider_hooks: &[ProviderHook::Custom("sort_dim")],
175    constant_strategy: ConstantStrategy::InlineLiteral,
176    residency: ResidencyPolicy::GatherImmediately,
177    nan_mode: ReductionNaN::Include,
178    two_pass_threshold: None,
179    workgroup_size: None,
180    accepts_nan_mode: true,
181    notes: "Shares provider hooks with `sort`; when unavailable tensors are gathered to host memory before computing indices.",
182};
183
184register_builtin_gpu_spec!(GPU_SPEC);
185
186pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
187    name: "argsort",
188    shape: ShapeRequirements::Any,
189    constant_strategy: ConstantStrategy::InlineLiteral,
190    elementwise: None,
191    reduction: None,
192    emits_nan: true,
193    notes: "`argsort` breaks fusion chains and acts as a residency sink; upstream tensors are gathered when no GPU sort kernel is provided.",
194};
195
196register_builtin_fusion_spec!(FUSION_SPEC);
197
// Register the Markdown help text under the builtin's name. Gated on `doc_export`
// because `DOC_MD` itself only exists when that feature is enabled.
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("argsort", DOC_MD);
200
201#[runtime_builtin(
202    name = "argsort",
203    category = "array/sorting_sets",
204    summary = "Return the permutation indices that would sort tensors along a dimension.",
205    keywords = "argsort,sort,indices,permutation,gpu",
206    accel = "sink",
207    sink = true
208)]
209fn argsort_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
210    let evaluation = sort::evaluate(value, &rest)?;
211    Ok(evaluation.indices_value())
212}
213
#[cfg(test)]
mod tests {
    use super::sort;
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{ComplexTensor, IntValue, Tensor, Value};

    /// Builds an `n x 1` tensor from the given values.
    fn column_vector(values: &[f64]) -> Tensor {
        Tensor::new(values.to_vec(), vec![values.len(), 1]).unwrap()
    }

    /// Extracts the host tensor from a builtin result, panicking on any other variant.
    #[track_caller]
    fn into_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    /// Runs `argsort` on `input` with `args` and checks the flat index data.
    #[track_caller]
    fn assert_permutation(input: Value, args: Vec<Value>, expected: &[f64]) {
        let result = argsort_builtin(input, args).expect("argsort");
        assert_eq!(into_tensor(result).data, expected.to_vec());
    }

    /// Checks that `argsort` returns exactly the index output of `sort::evaluate`.
    #[track_caller]
    fn assert_matches_sort_indices(tensor: Tensor, args: Vec<Value>) {
        let actual = argsort_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("argsort");
        let expected = sort::evaluate(Value::Tensor(tensor), &args)
            .expect("sort evaluate")
            .indices_value();
        assert_eq!(actual, expected);
    }

    /// Builds the `'ComparisonMethod', <setting>` name/value argument pair.
    fn comparison_method(setting: Value) -> Vec<Value> {
        vec![Value::from("ComparisonMethod"), setting]
    }

    #[test]
    fn argsort_vector_default() {
        let input = Value::Tensor(column_vector(&[4.0, 1.0, 3.0]));
        let permutation = into_tensor(argsort_builtin(input, Vec::new()).expect("argsort"));
        assert_eq!(permutation.data, vec![2.0, 3.0, 1.0]);
        assert_eq!(permutation.shape, vec![3, 1]);
    }

    #[test]
    fn argsort_descend_direction() {
        assert_permutation(
            Value::Tensor(column_vector(&[10.0, 4.0, 7.0, 9.0])),
            vec![Value::from("descend")],
            &[1.0, 4.0, 3.0, 2.0],
        );
    }

    #[test]
    fn argsort_dimension_two() {
        let matrix = Tensor::new(vec![1.0, 6.0, 4.0, 2.0, 3.0, 5.0], vec![2, 3]).unwrap();
        assert_matches_sort_indices(matrix, vec![Value::Int(IntValue::I32(2))]);
    }

    #[test]
    fn argsort_absolute_comparison() {
        assert_permutation(
            Value::Tensor(column_vector(&[-8.0, -1.0, 3.0, -2.0])),
            comparison_method(Value::from("abs")),
            &[2.0, 4.0, 3.0, 1.0],
        );
    }

    #[test]
    fn argsort_handles_nan_like_sort() {
        assert_permutation(
            Value::Tensor(column_vector(&[f64::NAN, 4.0, 1.0, 2.0])),
            Vec::new(),
            &[3.0, 4.0, 2.0, 1.0],
        );
    }

    #[test]
    fn argsort_dimension_placeholder_then_dim() {
        let matrix = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
        let empty_placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        assert_matches_sort_indices(
            matrix,
            vec![
                Value::Tensor(empty_placeholder),
                Value::Int(IntValue::I32(2)),
                Value::from("descend"),
            ],
        );
    }

    #[test]
    fn argsort_dimension_greater_than_ndims_returns_ones() {
        let input = Value::Tensor(column_vector(&[1.0, 3.0, 2.0]));
        let result =
            argsort_builtin(input, vec![Value::Int(IntValue::I32(5))]).expect("argsort");
        let permutation = into_tensor(result);
        assert!(permutation
            .data
            .iter()
            .all(|&index| (index - 1.0).abs() < f64::EPSILON));
    }

    #[test]
    fn argsort_dimension_zero_errors() {
        let input = Value::Tensor(column_vector(&[1.0]));
        let err = argsort_builtin(input, vec![Value::Int(IntValue::I32(0))]).unwrap_err();
        assert!(
            err.contains("dimension must be >= 1"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn argsort_invalid_argument_errors() {
        let input = Value::Tensor(column_vector(&[1.0, 2.0]));
        let args = vec![Value::from("MissingPlacement"), Value::from("auto")];
        let err = argsort_builtin(input, args).unwrap_err();
        assert!(
            err.contains("sort: the 'MissingPlacement' option is not supported"),
            "{err}"
        );
    }

    #[test]
    fn argsort_invalid_comparison_method_errors() {
        let input = Value::Tensor(column_vector(&[1.0, 2.0]));
        let err = argsort_builtin(input, comparison_method(Value::from("unknown"))).unwrap_err();
        assert!(
            err.contains("unsupported ComparisonMethod"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn argsort_invalid_comparison_method_value_errors() {
        let input = Value::Tensor(column_vector(&[1.0, 2.0]));
        let err = argsort_builtin(input, comparison_method(Value::Int(IntValue::I32(1))))
            .unwrap_err();
        assert!(
            err.contains("requires a string value"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn argsort_stable_with_duplicates() {
        assert_permutation(
            Value::Tensor(column_vector(&[2.0, 2.0, 1.0, 2.0])),
            Vec::new(),
            &[3.0, 1.0, 2.0, 4.0],
        );
    }

    #[test]
    fn argsort_complex_real_method() {
        let samples = vec![(1.0, 2.0), (-3.0, 0.5), (1.0, -1.0)];
        let input = ComplexTensor::new(samples, vec![3, 1]).unwrap();
        assert_permutation(
            Value::ComplexTensor(input),
            comparison_method(Value::from("real")),
            &[2.0, 3.0, 1.0],
        );
    }

    #[test]
    fn argsort_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = column_vector(&[3.0, 1.0, 2.0]);
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            assert_permutation(Value::GpuTensor(handle), Vec::new(), &[2.0, 3.0, 1.0]);
        });
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn argsort_wgpu_matches_cpu() {
        /// Unwraps a host tensor, keeping this test's own panic wording.
        fn host_tensor(value: Value) -> Tensor {
            match value {
                Value::Tensor(t) => t,
                other => panic!("expected tensor, got {other:?}"),
            }
        }

        // Registration may already have happened in another test; the result is ignored.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let host = column_vector(&[0.0, 5.0, -1.0, 2.0]);
        let cpu_indices = argsort_builtin(Value::Tensor(host.clone()), Vec::new()).unwrap();

        let view = runmat_accelerate_api::HostTensorView {
            data: &host.data,
            shape: &host.shape,
        };
        let gpu_handle = runmat_accelerate_api::provider()
            .unwrap()
            .upload(&view)
            .expect("upload");
        let gpu_indices = argsort_builtin(Value::GpuTensor(gpu_handle), Vec::new()).unwrap();

        let cpu_tensor = host_tensor(cpu_indices);
        let gpu_tensor = host_tensor(gpu_indices);
        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        assert_eq!(gpu_tensor.data, cpu_tensor.data);
    }
}