Skip to main content

runmat_runtime/builtins/array/sorting_sets/
argsort.rs

1//! MATLAB-compatible `argsort` builtin returning permutation indices.
2
3use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use super::sort;
7use super::type_resolvers::index_output_type;
8use crate::builtins::common::spec::{
9    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
11};
12
13#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::argsort")]
14pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
15    name: "argsort",
16    op_kind: GpuOpKind::Custom("sort"),
17    supported_precisions: &[ScalarType::F32, ScalarType::F64],
18    broadcast: BroadcastSemantics::None,
19    provider_hooks: &[ProviderHook::Custom("sort_dim")],
20    constant_strategy: ConstantStrategy::InlineLiteral,
21    residency: ResidencyPolicy::GatherImmediately,
22    nan_mode: ReductionNaN::Include,
23    two_pass_threshold: None,
24    workgroup_size: None,
25    accepts_nan_mode: true,
26    notes: "Shares provider hooks with `sort`; when unavailable tensors are gathered to host memory before computing indices.",
27};
28
29#[runmat_macros::register_fusion_spec(
30    builtin_path = "crate::builtins::array::sorting_sets::argsort"
31)]
32pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
33    name: "argsort",
34    shape: ShapeRequirements::Any,
35    constant_strategy: ConstantStrategy::InlineLiteral,
36    elementwise: None,
37    reduction: None,
38    emits_nan: true,
39    notes: "`argsort` breaks fusion chains and acts as a residency sink; upstream tensors are gathered when no GPU sort kernel is provided.",
40};
41
42#[runtime_builtin(
43    name = "argsort",
44    category = "array/sorting_sets",
45    summary = "Return the permutation indices that would sort tensors along a dimension.",
46    keywords = "argsort,sort,indices,permutation,gpu",
47    accel = "sink",
48    sink = true,
49    type_resolver(index_output_type),
50    builtin_path = "crate::builtins::array::sorting_sets::argsort"
51)]
52async fn argsort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
53    let evaluation = sort::evaluate(value, &rest).await?;
54    Ok(evaluation.indices_value())
55}
56
#[cfg(test)]
pub(crate) mod tests {
    use super::index_output_type;
    use super::sort;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor, IntValue, ResolveContext, Tensor, Type, Value};

    /// Drives the async `argsort` builtin to completion on the current thread.
    fn argsort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::argsort_builtin(value, rest))
    }

    /// Unwraps a `Value::Tensor`, panicking at the caller's location on any other variant.
    #[track_caller]
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    /// Returns the message carried by a runtime error, for substring assertions.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_vector_default() {
        let tensor = Tensor::new(vec![4.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        let t = expect_tensor(indices);
        // Indices are one-based: the smallest value (1.0) sits at position 2.
        assert_eq!(t.data, vec![2.0, 3.0, 1.0]);
        assert_eq!(t.shape, vec![3, 1]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_type_resolver_indices() {
        assert_eq!(
            index_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_descend_direction() {
        let tensor = Tensor::new(vec![10.0, 4.0, 7.0, 9.0], vec![4, 1]).unwrap();
        let indices =
            argsort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("argsort");
        // Largest first: 10 (1), 9 (4), 7 (3), 4 (2).
        assert_eq!(expect_tensor(indices).data, vec![1.0, 4.0, 3.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_two() {
        // Column-major 2x3 matrix [1 4 3; 6 2 5]; sorting each row ascending yields the index
        // matrix [1 3 2; 2 3 1], i.e. column-major data [1, 2, 3, 3, 2, 1].
        let tensor = Tensor::new(vec![1.0, 6.0, 4.0, 2.0, 3.0, 5.0], vec![2, 3]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];
        let indices =
            argsort_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("argsort");
        let expected = block_on(sort::evaluate(Value::Tensor(tensor), &args))
            .expect("sort evaluate")
            .indices_value();
        assert_eq!(indices, expected);
        let t = expect_tensor(indices);
        assert_eq!(t.data, vec![1.0, 2.0, 3.0, 3.0, 2.0, 1.0]);
        assert_eq!(t.shape, vec![2, 3]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_absolute_comparison() {
        let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(
            Value::Tensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("argsort");
        // Magnitudes are 8, 1, 3, 2, so the ascending order is positions 2, 4, 3, 1.
        assert_eq!(expect_tensor(indices).data, vec![2.0, 4.0, 3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_handles_nan_like_sort() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        // NaN is placed last in ascending order, matching `sort`.
        assert_eq!(expect_tensor(indices).data, vec![3.0, 4.0, 2.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_placeholder_then_dim() {
        // An empty `[]` placeholder precedes the dimension argument. Column-major 2x2 matrix
        // [1 4; 3 2] sorted descending along rows gives [2 1; 1 2], i.e. data [2, 1, 1, 2].
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
        let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let args = vec![
            Value::Tensor(placeholder),
            Value::Int(IntValue::I32(2)),
            Value::from("descend"),
        ];
        let indices =
            argsort_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("argsort");
        let expected = block_on(sort::evaluate(Value::Tensor(tensor), &args))
            .expect("sort evaluate")
            .indices_value();
        assert_eq!(indices, expected);
        assert_eq!(expect_tensor(indices).data, vec![2.0, 1.0, 1.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_greater_than_ndims_returns_ones() {
        let tensor = Tensor::new(vec![1.0, 3.0, 2.0], vec![3, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(5))])
            .expect("argsort");
        // Every slice along dimension 5 has length 1, so each element maps to index 1.
        // Pin the exact data (not just `all(..)`, which would pass on an empty result).
        assert_eq!(expect_tensor(indices).data, vec![1.0; 3]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_zero_errors() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = error_message(
            argsort_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(0))]).unwrap_err(),
        );
        assert!(
            err.contains("dimension must be >= 1"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_invalid_argument_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = error_message(
            argsort_builtin(
                Value::Tensor(tensor),
                vec![Value::from("MissingPlacement"), Value::from("auto")],
            )
            .unwrap_err(),
        );
        // Errors originate in `sort`, so they carry its prefix.
        assert!(
            err.contains("sort: the 'MissingPlacement' option is not supported"),
            "{err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_invalid_comparison_method_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = error_message(
            argsort_builtin(
                Value::Tensor(tensor),
                vec![Value::from("ComparisonMethod"), Value::from("unknown")],
            )
            .unwrap_err(),
        );
        assert!(
            err.contains("unsupported ComparisonMethod"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_invalid_comparison_method_value_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = error_message(
            argsort_builtin(
                Value::Tensor(tensor),
                vec![
                    Value::from("ComparisonMethod"),
                    Value::Int(IntValue::I32(1)),
                ],
            )
            .unwrap_err(),
        );
        assert!(
            err.contains("requires a string value"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_stable_with_duplicates() {
        let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        // Equal keys (the three 2.0 entries) keep their original relative order.
        assert_eq!(expect_tensor(indices).data, vec![3.0, 1.0, 2.0, 4.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_complex_real_method() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (1.0, -1.0)], vec![3, 1]).unwrap();
        let indices = argsort_builtin(
            Value::ComplexTensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("real")],
        )
        .expect("argsort");
        // Real parts are 1, -3, 1; the tie between positions 1 and 3 is broken by the
        // imaginary part (-1 < 2), so position 3 precedes position 1.
        assert_eq!(expect_tensor(indices).data, vec![2.0, 3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let indices = argsort_builtin(Value::GpuTensor(handle), Vec::new()).expect("argsort");
            // The index result comes back as a host tensor, not a GPU handle.
            assert_eq!(expect_tensor(indices).data, vec![2.0, 3.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn argsort_wgpu_matches_cpu() {
        // Registration may already have happened in another test; the result is ignored.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![0.0, 5.0, -1.0, 2.0], vec![4, 1]).unwrap();
        let cpu_tensor =
            expect_tensor(argsort_builtin(Value::Tensor(tensor.clone()), Vec::new()).unwrap());
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let gpu_handle = runmat_accelerate_api::provider()
            .unwrap()
            .upload(&view)
            .expect("upload");
        let gpu_tensor =
            expect_tensor(argsort_builtin(Value::GpuTensor(gpu_handle), Vec::new()).unwrap());

        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        assert_eq!(gpu_tensor.data, cpu_tensor.data);
    }
}