Skip to main content

runmat_runtime/builtins/array/sorting_sets/
argsort.rs

//! `argsort` builtin: returns only the permutation indices computed by the MATLAB-compatible
//! `sort` builtin (the `I` in `[~, I] = sort(A, ...)`).
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use super::sort;
10use super::type_resolvers::index_output_type;
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
14};
15
16#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::argsort")]
17pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
18    name: "argsort",
19    op_kind: GpuOpKind::Custom("sort"),
20    supported_precisions: &[ScalarType::F32, ScalarType::F64],
21    broadcast: BroadcastSemantics::None,
22    provider_hooks: &[ProviderHook::Custom("sort_dim")],
23    constant_strategy: ConstantStrategy::InlineLiteral,
24    residency: ResidencyPolicy::GatherImmediately,
25    nan_mode: ReductionNaN::Include,
26    two_pass_threshold: None,
27    workgroup_size: None,
28    accepts_nan_mode: true,
29    notes: "Shares provider hooks with `sort`; when unavailable tensors are gathered to host memory before computing indices.",
30};
31
32#[runmat_macros::register_fusion_spec(
33    builtin_path = "crate::builtins::array::sorting_sets::argsort"
34)]
35pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
36    name: "argsort",
37    shape: ShapeRequirements::Any,
38    constant_strategy: ConstantStrategy::InlineLiteral,
39    elementwise: None,
40    reduction: None,
41    emits_nan: true,
42    notes: "`argsort` breaks fusion chains and acts as a residency sink; upstream tensors are gathered when no GPU sort kernel is provided.",
43};
44
45const ARGSORT_OUTPUT_I: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
46    name: "I",
47    ty: BuiltinParamType::NumericArray,
48    arity: BuiltinParamArity::Required,
49    default: None,
50    description: "One-based permutation indices that sort each slice.",
51}];
52
53const ARGSORT_INPUTS_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
54    name: "A",
55    ty: BuiltinParamType::Any,
56    arity: BuiltinParamArity::Required,
57    default: None,
58    description: "Input array.",
59}];
60
61const ARGSORT_INPUTS_A_ARG1: [BuiltinParamDescriptor; 2] = [
62    BuiltinParamDescriptor {
63        name: "A",
64        ty: BuiltinParamType::Any,
65        arity: BuiltinParamArity::Required,
66        default: None,
67        description: "Input array.",
68    },
69    BuiltinParamDescriptor {
70        name: "arg1",
71        ty: BuiltinParamType::Any,
72        arity: BuiltinParamArity::Required,
73        default: None,
74        description: "Dimension selector or direction token.",
75    },
76];
77
78const ARGSORT_INPUTS_A_ARG1_ARG2: [BuiltinParamDescriptor; 3] = [
79    BuiltinParamDescriptor {
80        name: "A",
81        ty: BuiltinParamType::Any,
82        arity: BuiltinParamArity::Required,
83        default: None,
84        description: "Input array.",
85    },
86    BuiltinParamDescriptor {
87        name: "arg1",
88        ty: BuiltinParamType::Any,
89        arity: BuiltinParamArity::Required,
90        default: None,
91        description: "Dimension selector, placeholder, or direction token.",
92    },
93    BuiltinParamDescriptor {
94        name: "arg2",
95        ty: BuiltinParamType::Any,
96        arity: BuiltinParamArity::Required,
97        default: None,
98        description: "Dimension selector or direction token.",
99    },
100];
101
102const ARGSORT_INPUTS_COMPARISON_METHOD: [BuiltinParamDescriptor; 4] = [
103    BuiltinParamDescriptor {
104        name: "A",
105        ty: BuiltinParamType::Any,
106        arity: BuiltinParamArity::Required,
107        default: None,
108        description: "Input array.",
109    },
110    BuiltinParamDescriptor {
111        name: "arg",
112        ty: BuiltinParamType::Any,
113        arity: BuiltinParamArity::Variadic,
114        default: None,
115        description: "Optional dimension/direction arguments.",
116    },
117    BuiltinParamDescriptor {
118        name: "name",
119        ty: BuiltinParamType::StringScalar,
120        arity: BuiltinParamArity::Required,
121        default: Some("\"ComparisonMethod\""),
122        description: "Name-value option key.",
123    },
124    BuiltinParamDescriptor {
125        name: "method",
126        ty: BuiltinParamType::StringScalar,
127        arity: BuiltinParamArity::Required,
128        default: Some("\"auto\""),
129        description: "Comparison method: 'auto', 'real', or 'abs'.",
130    },
131];
132
133const ARGSORT_INPUTS_MISSING_PLACEMENT: [BuiltinParamDescriptor; 4] = [
134    BuiltinParamDescriptor {
135        name: "A",
136        ty: BuiltinParamType::Any,
137        arity: BuiltinParamArity::Required,
138        default: None,
139        description: "Input array.",
140    },
141    BuiltinParamDescriptor {
142        name: "arg",
143        ty: BuiltinParamType::Any,
144        arity: BuiltinParamArity::Variadic,
145        default: None,
146        description: "Optional dimension/direction arguments.",
147    },
148    BuiltinParamDescriptor {
149        name: "name",
150        ty: BuiltinParamType::StringScalar,
151        arity: BuiltinParamArity::Required,
152        default: Some("\"MissingPlacement\""),
153        description: "Name-value option key.",
154    },
155    BuiltinParamDescriptor {
156        name: "placement",
157        ty: BuiltinParamType::StringScalar,
158        arity: BuiltinParamArity::Required,
159        default: Some("\"auto\""),
160        description: "Requested NaN placement option (currently unsupported).",
161    },
162];
163
164const ARGSORT_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
165    BuiltinSignatureDescriptor {
166        label: "I = argsort(A)",
167        inputs: &ARGSORT_INPUTS_A,
168        outputs: &ARGSORT_OUTPUT_I,
169    },
170    BuiltinSignatureDescriptor {
171        label: "I = argsort(A, arg1)",
172        inputs: &ARGSORT_INPUTS_A_ARG1,
173        outputs: &ARGSORT_OUTPUT_I,
174    },
175    BuiltinSignatureDescriptor {
176        label: "I = argsort(A, arg1, arg2)",
177        inputs: &ARGSORT_INPUTS_A_ARG1_ARG2,
178        outputs: &ARGSORT_OUTPUT_I,
179    },
180    BuiltinSignatureDescriptor {
181        label: "I = argsort(A, ..., \"ComparisonMethod\", method)",
182        inputs: &ARGSORT_INPUTS_COMPARISON_METHOD,
183        outputs: &ARGSORT_OUTPUT_I,
184    },
185    BuiltinSignatureDescriptor {
186        label: "I = argsort(A, ..., \"MissingPlacement\", placement)",
187        inputs: &ARGSORT_INPUTS_MISSING_PLACEMENT,
188        outputs: &ARGSORT_OUTPUT_I,
189    },
190];
191
192const ARGSORT_ERROR_INVALID_DIMENSION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
193    code: "RM.ARGSORT.INVALID_DIMENSION",
194    identifier: Some("RunMat:sort:InvalidDimension"),
195    when: "Dimension argument is non-positive, non-integer, or otherwise invalid.",
196    message: "sort: invalid dimension argument",
197};
198
199const ARGSORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING: BuiltinErrorDescriptor =
200    BuiltinErrorDescriptor {
201        code: "RM.ARGSORT.COMPARISON_METHOD_REQUIRES_STRING",
202        identifier: Some("RunMat:sort:ComparisonMethodRequiresString"),
203        when: "ComparisonMethod option value is not string-like.",
204        message: "sort: 'ComparisonMethod' requires a string value",
205    };
206
207const ARGSORT_ERROR_COMPARISON_METHOD_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
208    code: "RM.ARGSORT.COMPARISON_METHOD_UNKNOWN",
209    identifier: Some("RunMat:sort:ComparisonMethodUnknown"),
210    when: "ComparisonMethod option value is not one of 'auto'/'real'/'abs'.",
211    message: "sort: unsupported ComparisonMethod",
212};
213
214const ARGSORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
215    code: "RM.ARGSORT.MISSINGPLACEMENT_UNSUPPORTED",
216    identifier: Some("RunMat:sort:MissingPlacementUnsupported"),
217    when: "MissingPlacement option is provided but unsupported.",
218    message: "sort: the 'MissingPlacement' option is not supported yet",
219};
220
221const ARGSORT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
222    code: "RM.ARGSORT.INVALID_ARGUMENT",
223    identifier: Some("RunMat:sort:InvalidArgument"),
224    when: "Parser encounters invalid or unrecognized option/value arguments.",
225    message: "sort: invalid argument sequence",
226};
227
228const ARGSORT_ERRORS: [BuiltinErrorDescriptor; 5] = [
229    ARGSORT_ERROR_INVALID_DIMENSION,
230    ARGSORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING,
231    ARGSORT_ERROR_COMPARISON_METHOD_UNKNOWN,
232    ARGSORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED,
233    ARGSORT_ERROR_INVALID_ARGUMENT,
234];
235
236pub const ARGSORT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
237    signatures: &ARGSORT_SIGNATURES,
238    output_mode: BuiltinOutputMode::Fixed,
239    completion_policy: BuiltinCompletionPolicy::Public,
240    errors: &ARGSORT_ERRORS,
241};
242
243#[runtime_builtin(
244    name = "argsort",
245    category = "array/sorting_sets",
246    summary = "Return permutation indices that sort arrays along a dimension.",
247    keywords = "argsort,sort,indices,permutation,gpu",
248    accel = "sink",
249    sink = true,
250    type_resolver(index_output_type),
251    descriptor(crate::builtins::array::sorting_sets::argsort::ARGSORT_DESCRIPTOR),
252    builtin_path = "crate::builtins::array::sorting_sets::argsort"
253)]
254async fn argsort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
255    let evaluation = sort::evaluate(value, &rest).await?;
256    Ok(evaluation.indices_value())
257}
258
#[cfg(test)]
pub(crate) mod tests {
    use super::{
        index_output_type, sort, ARGSORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING,
        ARGSORT_ERROR_COMPARISON_METHOD_UNKNOWN, ARGSORT_ERROR_INVALID_DIMENSION,
        ARGSORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED,
    };
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor, IntValue, ResolveContext, Tensor, Type, Value};

    /// Drive the async builtin to completion so the tests stay synchronous.
    fn argsort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::argsort_builtin(value, rest))
    }

    /// Unwrap a host tensor result, failing the test on any other `Value` variant.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_vector_default() {
        let tensor = Tensor::new(vec![4.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let t = expect_tensor(argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort"));
        assert_eq!(t.data, vec![2.0, 3.0, 1.0]);
        assert_eq!(t.shape, vec![3, 1]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_type_resolver_indices() {
        assert_eq!(
            index_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_descend_direction() {
        let tensor = Tensor::new(vec![10.0, 4.0, 7.0, 9.0], vec![4, 1]).unwrap();
        let indices =
            argsort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("argsort");
        assert_eq!(expect_tensor(indices).data, vec![1.0, 4.0, 3.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 6.0, 4.0, 2.0, 3.0, 5.0], vec![2, 3]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];
        let indices =
            argsort_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("argsort");
        // `argsort` must report exactly the indices `sort` computes for the same arguments.
        let expected = block_on(sort::evaluate(Value::Tensor(tensor), &args))
            .expect("sort evaluate")
            .indices_value();
        assert_eq!(indices, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_absolute_comparison() {
        let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(
            Value::Tensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("argsort");
        assert_eq!(expect_tensor(indices).data, vec![2.0, 4.0, 3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_handles_nan_like_sort() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        assert_eq!(expect_tensor(indices).data, vec![3.0, 4.0, 2.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_placeholder_then_dim() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
        let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let args = vec![
            Value::Tensor(placeholder),
            Value::Int(IntValue::I32(2)),
            Value::from("descend"),
        ];
        let indices =
            argsort_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("argsort");
        let expected = block_on(sort::evaluate(Value::Tensor(tensor), &args))
            .expect("sort evaluate")
            .indices_value();
        assert_eq!(indices, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_greater_than_ndims_returns_ones() {
        let tensor = Tensor::new(vec![1.0, 3.0, 2.0], vec![3, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(5))])
            .expect("argsort");
        // Check shape and exact data: an `all(..)` check alone would pass vacuously on an
        // empty result.
        let t = expect_tensor(indices);
        assert_eq!(t.shape, vec![3, 1]);
        assert_eq!(t.data, vec![1.0, 1.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_zero_errors() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err =
            argsort_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(0))]).unwrap_err();
        assert_eq!(err.identifier(), ARGSORT_ERROR_INVALID_DIMENSION.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_missing_placement_unsupported_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = argsort_builtin(
            Value::Tensor(tensor),
            vec![Value::from("MissingPlacement"), Value::from("auto")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            ARGSORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_invalid_comparison_method_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = argsort_builtin(
            Value::Tensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("unknown")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            ARGSORT_ERROR_COMPARISON_METHOD_UNKNOWN.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_invalid_comparison_method_value_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = argsort_builtin(
            Value::Tensor(tensor),
            vec![
                Value::from("ComparisonMethod"),
                Value::Int(IntValue::I32(1)),
            ],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            ARGSORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_stable_with_duplicates() {
        let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        // Equal keys keep their original relative order.
        assert_eq!(expect_tensor(indices).data, vec![3.0, 1.0, 2.0, 4.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_complex_real_method() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (1.0, -1.0)], vec![3, 1]).unwrap();
        let indices = argsort_builtin(
            Value::ComplexTensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("real")],
        )
        .expect("argsort");
        assert_eq!(expect_tensor(indices).data, vec![2.0, 3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let indices = argsort_builtin(Value::GpuTensor(handle), Vec::new()).expect("argsort");
            assert_eq!(expect_tensor(indices).data, vec![2.0, 3.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn argsort_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![0.0, 5.0, -1.0, 2.0], vec![4, 1]).unwrap();
        let cpu_indices = argsort_builtin(Value::Tensor(tensor.clone()), Vec::new()).unwrap();
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let gpu_handle = runmat_accelerate_api::provider()
            .unwrap()
            .upload(&view)
            .expect("upload");
        let gpu_indices = argsort_builtin(Value::GpuTensor(gpu_handle), Vec::new()).unwrap();

        let cpu_tensor = expect_tensor(cpu_indices);
        let gpu_tensor = expect_tensor(gpu_indices);
        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        assert_eq!(gpu_tensor.data, cpu_tensor.data);
    }
}