1use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use super::sort;
7use crate::builtins::common::spec::{
8 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
10};
11#[cfg(feature = "doc_export")]
12use crate::register_builtin_doc_text;
13use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
14
/// Markdown documentation for the `argsort` builtin, compiled in only when the
/// `doc_export` feature is enabled. The YAML front matter is consumed by the
/// documentation generator; the fenced matlab example blocks are also picked
/// up by `test_support::doc_examples` (exercised by the
/// `doc_examples_present` test in this file).
///
/// NOTE(review): this is runtime/doc data, not comments — keep the string
/// content byte-stable; the docs site and example extraction depend on it.
#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "argsort"
category: "array/sorting_sets"
keywords: ["argsort", "sort", "indices", "permutation", "gpu"]
summary: "Return the permutation indices that would sort tensors along a dimension."
references:
  - https://www.mathworks.com/help/matlab/ref/sort.html
gpu_support:
  elementwise: false
  reduction: false
  precisions: ["f32", "f64"]
  broadcasting: "none"
  notes: "Uses the same sort kernels as `sort`; falls back to host evaluation when the provider lacks `sort_dim`."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 1
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::array::sorting_sets::argsort::tests"
  integration: "builtins::array::sorting_sets::argsort::tests::argsort_gpu_roundtrip"
---

# What does the `argsort` function do in MATLAB / RunMat?
`argsort(X)` returns the permutation indices that order `X` the same way `sort(X)` would. It matches the indices produced by `[~, I] = sort(X, ...)` in MathWorks MATLAB and honours the same argument forms for dimensions, directions, and comparison methods.

## How does the `argsort` function behave in MATLAB / RunMat?
- Operates along the first non-singleton dimension by default. Pass a dimension argument to override.
- Accepts the same direction keywords as `sort`: `'ascend'` (default) or `'descend'`.
- Supports `'ComparisonMethod'` values `'auto'`, `'real'`, and `'abs'` for real and complex inputs.
- Returns indices as double-precision tensors using MATLAB's one-based indexing.
- Treats NaN values as missing: they appear at the end for ascending permutations and at the beginning for descending permutations.
- Acts as a residency sink. GPU tensors are gathered when the active provider does not expose a specialised sort kernel.

## GPU execution in RunMat
- `argsort` shares the `sort_dim` provider hook with the `sort` builtin. When implemented, indices are computed without leaving the device.
- If the provider lacks `sort_dim`, RunMat gathers tensors to host memory, evaluates the permutation, and returns host-resident indices.
- Outputs are always host-resident double tensors because permutation indices are consumed immediately by host-side logic (e.g., indexing).

## Examples of using `argsort` in MATLAB / RunMat

### Getting indices that sort a vector
```matlab
A = [4; 1; 3];
idx = argsort(A);
```
Expected output:
```matlab
idx =
 2
 3
 1
```

### Reordering data with the permutation indices
```matlab
A = [3 9 1 5];
idx = argsort(A);
sorted = A(idx);
```
Expected output:
```matlab
sorted =
 1 3 5 9
```

### Sorting along a specific dimension
```matlab
A = [1 6 4; 2 3 5];
idx = argsort(A, 2);
```
Expected output:
```matlab
idx =
 1 3 2
 1 2 3
```

### Descending order permutations
```matlab
A = [10 4 7 9];
idx = argsort(A, 'descend');
```
Expected output:
```matlab
idx =
 1 4 3 2
```

### Using `ComparisonMethod` to sort by magnitude
```matlab
A = [-8 -1 3 -2];
idx = argsort(A, 'ComparisonMethod', 'abs');
```
Expected output:
```matlab
idx =
 2 4 3 1
```

### Handling NaN values during permutation
```matlab
A = [NaN 4 1 2];
idx = argsort(A);
```
Expected output:
```matlab
idx =
 3 4 2 1
```

### Argsort on GPU tensors falls back gracefully
```matlab
G = gpuArray(randn(5, 1));
idx = argsort(G);
```
RunMat gathers `G` to the host when no device sort kernel is available, ensuring the returned indices match MATLAB exactly.

## FAQ

### How is `argsort` different from `sort`?
`argsort` returns only the permutation indices. It behaves like calling `[~, I] = sort(X, ...)` without materialising the sorted values.

### Are the indices one-based like MATLAB?
Yes. All indices follow MATLAB's one-based convention so they can be used directly with subsequent indexing operations.

### Does `argsort` support the same arguments as `sort`?
Yes. Dimension arguments, direction keywords, and `'ComparisonMethod'` behave exactly like they do for `sort`.

### How are NaN values ordered?
NaNs are treated as missing. They appear at the end for ascending permutations and at the beginning for descending permutations, matching MATLAB.

### Can I call `argsort` on GPU arrays?
Yes. When the active provider implements the `sort_dim` hook, permutations stay on the device. Otherwise tensors are gathered automatically and sorted on the host.

### Is the permutation stable?
Yes. Equal elements keep their relative order so that `argsort` remains consistent with MATLAB's stable sorting semantics.

### What type is returned?
A double-precision tensor (or scalar) with the same shape as the input, containing permutation indices.

### Does `argsort` mutate its input?
No. It only returns indices. Combine the result with indexing (`A(idx)`) to obtain reordered values when needed.

## See also
[sort](./sort), [sortrows](./sortrows), [randperm](../../array/creation/randperm), [max](../../math/reduction/max), [min](../../math/reduction/min)

## Source & Feedback
- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/argsort.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/argsort.rs)
- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
"#;
168
/// GPU execution metadata for `argsort`, registered with the acceleration
/// layer below so providers can advertise (or decline) device-side support.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "argsort",
    // Not a standard elementwise/reduction op; modelled as a custom sort.
    op_kind: GpuOpKind::Custom("sort"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    // Single-input operation: no broadcasting semantics apply.
    broadcast: BroadcastSemantics::None,
    // Same provider hook as the `sort` builtin; absence triggers host fallback.
    provider_hooks: &[ProviderHook::Custom("sort_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Permutation indices are consumed by host-side logic (indexing), so
    // results are gathered immediately rather than kept device-resident.
    residency: ResidencyPolicy::GatherImmediately,
    // NOTE(review): `Include` here while the docs describe missing-value
    // placement for NaNs — presumably `ReductionNaN` only applies to
    // reduction-style kernels; confirm against the spec definitions.
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Shares provider hooks with `sort`; when unavailable tensors are gathered to host memory before computing indices.",
};

register_builtin_gpu_spec!(GPU_SPEC);
185
/// Fusion-planner metadata for `argsort`. Sorting cannot participate in
/// elementwise/reduction fusion, so both kernels are `None` and the builtin
/// terminates any fusion chain it appears in.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "argsort",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // No fusible elementwise or reduction kernel: `argsort` is a fusion sink.
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`argsort` breaks fusion chains and acts as a residency sink; upstream tensors are gathered when no GPU sort kernel is provided.",
};

register_builtin_fusion_spec!(FUSION_SPEC);
197
// Expose the markdown documentation to the doc exporter under the builtin's
// public name (only compiled with the `doc_export` feature, matching DOC_MD).
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("argsort", DOC_MD);
200
201#[runtime_builtin(
202 name = "argsort",
203 category = "array/sorting_sets",
204 summary = "Return the permutation indices that would sort tensors along a dimension.",
205 keywords = "argsort,sort,indices,permutation,gpu",
206 accel = "sink",
207 sink = true
208)]
209fn argsort_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
210 let evaluation = sort::evaluate(value, &rest)?;
211 Ok(evaluation.indices_value())
212}
213
#[cfg(test)]
mod tests {
    use super::sort;
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{ComplexTensor, IntValue, Tensor, Value};

    // Default ascending argsort of a column vector: indices are one-based
    // doubles and the result keeps the input's shape.
    #[test]
    fn argsort_vector_default() {
        let tensor = Tensor::new(vec![4.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 3.0, 1.0]);
                assert_eq!(t.shape, vec![3, 1]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // The 'descend' direction keyword is forwarded to the sort evaluator.
    #[test]
    fn argsort_descend_direction() {
        let tensor = Tensor::new(vec![10.0, 4.0, 7.0, 9.0], vec![4, 1]).unwrap();
        let indices =
            argsort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("argsort");
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 4.0, 3.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // An explicit dimension argument must produce exactly the indices that
    // `sort::evaluate` reports for the same arguments.
    #[test]
    fn argsort_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 6.0, 4.0, 2.0, 3.0, 5.0], vec![2, 3]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];
        let indices =
            argsort_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("argsort");
        let expected = sort::evaluate(Value::Tensor(tensor), &args)
            .expect("sort evaluate")
            .indices_value();
        assert_eq!(indices, expected);
    }

    // 'ComparisonMethod','abs' orders elements by magnitude.
    #[test]
    fn argsort_absolute_comparison() {
        let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(
            Value::Tensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("argsort");
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 4.0, 3.0, 1.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // NaN is treated as missing: its index lands at the end for ascending.
    #[test]
    fn argsort_handles_nan_like_sort() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 4.0, 2.0, 1.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // An empty-tensor placeholder (MATLAB `[]`) before the dimension argument
    // is accepted, mirroring `sort([], dim, ...)` argument parsing.
    #[test]
    fn argsort_dimension_placeholder_then_dim() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
        let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let args = vec![
            Value::Tensor(placeholder),
            Value::Int(IntValue::I32(2)),
            Value::from("descend"),
        ];
        let indices =
            argsort_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("argsort");
        let expected = sort::evaluate(Value::Tensor(tensor), &args)
            .expect("sort evaluate")
            .indices_value();
        assert_eq!(indices, expected);
    }

    // Sorting along a dimension beyond ndims is a no-op: every index is 1.
    #[test]
    fn argsort_dimension_greater_than_ndims_returns_ones() {
        let tensor = Tensor::new(vec![1.0, 3.0, 2.0], vec![3, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(5))])
            .expect("argsort");
        match indices {
            Value::Tensor(t) => assert!(t.data.iter().all(|v| (*v - 1.0).abs() < f64::EPSILON)),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Dimension 0 is rejected with the shared sort error message.
    #[test]
    fn argsort_dimension_zero_errors() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err =
            argsort_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(0))]).unwrap_err();
        assert!(
            err.contains("dimension must be >= 1"),
            "unexpected error: {err}"
        );
    }

    // Unsupported name/value options (here 'MissingPlacement') surface the
    // sort evaluator's error.
    #[test]
    fn argsort_invalid_argument_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = argsort_builtin(
            Value::Tensor(tensor),
            vec![Value::from("MissingPlacement"), Value::from("auto")],
        )
        .unwrap_err();
        assert!(
            err.contains("sort: the 'MissingPlacement' option is not supported"),
            "{err}"
        );
    }

    // Unknown 'ComparisonMethod' values are rejected.
    #[test]
    fn argsort_invalid_comparison_method_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = argsort_builtin(
            Value::Tensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("unknown")],
        )
        .unwrap_err();
        assert!(
            err.contains("unsupported ComparisonMethod"),
            "unexpected error: {err}"
        );
    }

    // 'ComparisonMethod' must be followed by a string value, not a number.
    #[test]
    fn argsort_invalid_comparison_method_value_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = argsort_builtin(
            Value::Tensor(tensor),
            vec![
                Value::from("ComparisonMethod"),
                Value::Int(IntValue::I32(1)),
            ],
        )
        .unwrap_err();
        assert!(
            err.contains("requires a string value"),
            "unexpected error: {err}"
        );
    }

    // Stability: equal elements keep their original relative order.
    #[test]
    fn argsort_stable_with_duplicates() {
        let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0, 4.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Complex inputs with 'ComparisonMethod','real' sort by the real part.
    #[test]
    fn argsort_complex_real_method() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (1.0, -1.0)], vec![3, 1]).unwrap();
        let indices = argsort_builtin(
            Value::ComplexTensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("real")],
        )
        .expect("argsort");
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Upload to the test provider, argsort the device handle, and check the
    // host-resident indices match the CPU result (exercises the gather path).
    #[test]
    fn argsort_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let indices = argsort_builtin(Value::GpuTensor(handle), Vec::new()).expect("argsort");
            match indices {
                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
                other => panic!("expected tensor result, got {other:?}"),
            }
        });
    }

    // DOC_MD must contain at least one extractable example block.
    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }

    // With a real wgpu provider registered, device results must agree with
    // the CPU path bit-for-bit (shape and data).
    #[test]
    #[cfg(feature = "wgpu")]
    fn argsort_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![0.0, 5.0, -1.0, 2.0], vec![4, 1]).unwrap();
        let cpu_indices = argsort_builtin(Value::Tensor(tensor.clone()), Vec::new()).unwrap();
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let gpu_handle = runmat_accelerate_api::provider()
            .unwrap()
            .upload(&view)
            .expect("upload");
        let gpu_indices = argsort_builtin(Value::GpuTensor(gpu_handle), Vec::new()).unwrap();

        let cpu_tensor = match cpu_indices {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let gpu_tensor = match gpu_indices {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        assert_eq!(gpu_tensor.data, cpu_tensor.data);
    }
}