1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use super::sort;
10use super::type_resolvers::index_output_type;
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
14};
15
16#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::argsort")]
17pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
18 name: "argsort",
19 op_kind: GpuOpKind::Custom("sort"),
20 supported_precisions: &[ScalarType::F32, ScalarType::F64],
21 broadcast: BroadcastSemantics::None,
22 provider_hooks: &[ProviderHook::Custom("sort_dim")],
23 constant_strategy: ConstantStrategy::InlineLiteral,
24 residency: ResidencyPolicy::GatherImmediately,
25 nan_mode: ReductionNaN::Include,
26 two_pass_threshold: None,
27 workgroup_size: None,
28 accepts_nan_mode: true,
29 notes: "Shares provider hooks with `sort`; when unavailable tensors are gathered to host memory before computing indices.",
30};
31
32#[runmat_macros::register_fusion_spec(
33 builtin_path = "crate::builtins::array::sorting_sets::argsort"
34)]
35pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
36 name: "argsort",
37 shape: ShapeRequirements::Any,
38 constant_strategy: ConstantStrategy::InlineLiteral,
39 elementwise: None,
40 reduction: None,
41 emits_nan: true,
42 notes: "`argsort` breaks fusion chains and acts as a residency sink; upstream tensors are gathered when no GPU sort kernel is provided.",
43};
44
45const ARGSORT_OUTPUT_I: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
46 name: "I",
47 ty: BuiltinParamType::NumericArray,
48 arity: BuiltinParamArity::Required,
49 default: None,
50 description: "One-based permutation indices that sort each slice.",
51}];
52
53const ARGSORT_INPUTS_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
54 name: "A",
55 ty: BuiltinParamType::Any,
56 arity: BuiltinParamArity::Required,
57 default: None,
58 description: "Input array.",
59}];
60
61const ARGSORT_INPUTS_A_ARG1: [BuiltinParamDescriptor; 2] = [
62 BuiltinParamDescriptor {
63 name: "A",
64 ty: BuiltinParamType::Any,
65 arity: BuiltinParamArity::Required,
66 default: None,
67 description: "Input array.",
68 },
69 BuiltinParamDescriptor {
70 name: "arg1",
71 ty: BuiltinParamType::Any,
72 arity: BuiltinParamArity::Required,
73 default: None,
74 description: "Dimension selector or direction token.",
75 },
76];
77
78const ARGSORT_INPUTS_A_ARG1_ARG2: [BuiltinParamDescriptor; 3] = [
79 BuiltinParamDescriptor {
80 name: "A",
81 ty: BuiltinParamType::Any,
82 arity: BuiltinParamArity::Required,
83 default: None,
84 description: "Input array.",
85 },
86 BuiltinParamDescriptor {
87 name: "arg1",
88 ty: BuiltinParamType::Any,
89 arity: BuiltinParamArity::Required,
90 default: None,
91 description: "Dimension selector, placeholder, or direction token.",
92 },
93 BuiltinParamDescriptor {
94 name: "arg2",
95 ty: BuiltinParamType::Any,
96 arity: BuiltinParamArity::Required,
97 default: None,
98 description: "Dimension selector or direction token.",
99 },
100];
101
102const ARGSORT_INPUTS_COMPARISON_METHOD: [BuiltinParamDescriptor; 4] = [
103 BuiltinParamDescriptor {
104 name: "A",
105 ty: BuiltinParamType::Any,
106 arity: BuiltinParamArity::Required,
107 default: None,
108 description: "Input array.",
109 },
110 BuiltinParamDescriptor {
111 name: "arg",
112 ty: BuiltinParamType::Any,
113 arity: BuiltinParamArity::Variadic,
114 default: None,
115 description: "Optional dimension/direction arguments.",
116 },
117 BuiltinParamDescriptor {
118 name: "name",
119 ty: BuiltinParamType::StringScalar,
120 arity: BuiltinParamArity::Required,
121 default: Some("\"ComparisonMethod\""),
122 description: "Name-value option key.",
123 },
124 BuiltinParamDescriptor {
125 name: "method",
126 ty: BuiltinParamType::StringScalar,
127 arity: BuiltinParamArity::Required,
128 default: Some("\"auto\""),
129 description: "Comparison method: 'auto', 'real', or 'abs'.",
130 },
131];
132
133const ARGSORT_INPUTS_MISSING_PLACEMENT: [BuiltinParamDescriptor; 4] = [
134 BuiltinParamDescriptor {
135 name: "A",
136 ty: BuiltinParamType::Any,
137 arity: BuiltinParamArity::Required,
138 default: None,
139 description: "Input array.",
140 },
141 BuiltinParamDescriptor {
142 name: "arg",
143 ty: BuiltinParamType::Any,
144 arity: BuiltinParamArity::Variadic,
145 default: None,
146 description: "Optional dimension/direction arguments.",
147 },
148 BuiltinParamDescriptor {
149 name: "name",
150 ty: BuiltinParamType::StringScalar,
151 arity: BuiltinParamArity::Required,
152 default: Some("\"MissingPlacement\""),
153 description: "Name-value option key.",
154 },
155 BuiltinParamDescriptor {
156 name: "placement",
157 ty: BuiltinParamType::StringScalar,
158 arity: BuiltinParamArity::Required,
159 default: Some("\"auto\""),
160 description: "Requested NaN placement option (currently unsupported).",
161 },
162];
163
164const ARGSORT_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
165 BuiltinSignatureDescriptor {
166 label: "I = argsort(A)",
167 inputs: &ARGSORT_INPUTS_A,
168 outputs: &ARGSORT_OUTPUT_I,
169 },
170 BuiltinSignatureDescriptor {
171 label: "I = argsort(A, arg1)",
172 inputs: &ARGSORT_INPUTS_A_ARG1,
173 outputs: &ARGSORT_OUTPUT_I,
174 },
175 BuiltinSignatureDescriptor {
176 label: "I = argsort(A, arg1, arg2)",
177 inputs: &ARGSORT_INPUTS_A_ARG1_ARG2,
178 outputs: &ARGSORT_OUTPUT_I,
179 },
180 BuiltinSignatureDescriptor {
181 label: "I = argsort(A, ..., \"ComparisonMethod\", method)",
182 inputs: &ARGSORT_INPUTS_COMPARISON_METHOD,
183 outputs: &ARGSORT_OUTPUT_I,
184 },
185 BuiltinSignatureDescriptor {
186 label: "I = argsort(A, ..., \"MissingPlacement\", placement)",
187 inputs: &ARGSORT_INPUTS_MISSING_PLACEMENT,
188 outputs: &ARGSORT_OUTPUT_I,
189 },
190];
191
192const ARGSORT_ERROR_INVALID_DIMENSION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
193 code: "RM.ARGSORT.INVALID_DIMENSION",
194 identifier: Some("RunMat:sort:InvalidDimension"),
195 when: "Dimension argument is non-positive, non-integer, or otherwise invalid.",
196 message: "sort: invalid dimension argument",
197};
198
199const ARGSORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING: BuiltinErrorDescriptor =
200 BuiltinErrorDescriptor {
201 code: "RM.ARGSORT.COMPARISON_METHOD_REQUIRES_STRING",
202 identifier: Some("RunMat:sort:ComparisonMethodRequiresString"),
203 when: "ComparisonMethod option value is not string-like.",
204 message: "sort: 'ComparisonMethod' requires a string value",
205 };
206
207const ARGSORT_ERROR_COMPARISON_METHOD_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
208 code: "RM.ARGSORT.COMPARISON_METHOD_UNKNOWN",
209 identifier: Some("RunMat:sort:ComparisonMethodUnknown"),
210 when: "ComparisonMethod option value is not one of 'auto'/'real'/'abs'.",
211 message: "sort: unsupported ComparisonMethod",
212};
213
214const ARGSORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
215 code: "RM.ARGSORT.MISSINGPLACEMENT_UNSUPPORTED",
216 identifier: Some("RunMat:sort:MissingPlacementUnsupported"),
217 when: "MissingPlacement option is provided but unsupported.",
218 message: "sort: the 'MissingPlacement' option is not supported yet",
219};
220
221const ARGSORT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
222 code: "RM.ARGSORT.INVALID_ARGUMENT",
223 identifier: Some("RunMat:sort:InvalidArgument"),
224 when: "Parser encounters invalid or unrecognized option/value arguments.",
225 message: "sort: invalid argument sequence",
226};
227
228const ARGSORT_ERRORS: [BuiltinErrorDescriptor; 5] = [
229 ARGSORT_ERROR_INVALID_DIMENSION,
230 ARGSORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING,
231 ARGSORT_ERROR_COMPARISON_METHOD_UNKNOWN,
232 ARGSORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED,
233 ARGSORT_ERROR_INVALID_ARGUMENT,
234];
235
236pub const ARGSORT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
237 signatures: &ARGSORT_SIGNATURES,
238 output_mode: BuiltinOutputMode::Fixed,
239 completion_policy: BuiltinCompletionPolicy::Public,
240 errors: &ARGSORT_ERRORS,
241};
242
243#[runtime_builtin(
244 name = "argsort",
245 category = "array/sorting_sets",
246 summary = "Return permutation indices that sort arrays along a dimension.",
247 keywords = "argsort,sort,indices,permutation,gpu",
248 accel = "sink",
249 sink = true,
250 type_resolver(index_output_type),
251 descriptor(crate::builtins::array::sorting_sets::argsort::ARGSORT_DESCRIPTOR),
252 builtin_path = "crate::builtins::array::sorting_sets::argsort"
253)]
254async fn argsort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
255 let evaluation = sort::evaluate(value, &rest).await?;
256 Ok(evaluation.indices_value())
257}
258
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for the `argsort` builtin.
    //!
    //! NOTE(review): `ARGSORT_ERROR_INVALID_ARGUMENT` is declared in the parent module but
    //! no test in this module asserts it.

    use super::index_output_type;
    use super::sort;
    use super::{
        ARGSORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING, ARGSORT_ERROR_COMPARISON_METHOD_UNKNOWN,
        ARGSORT_ERROR_INVALID_DIMENSION, ARGSORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED,
    };
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor, IntValue, ResolveContext, Tensor, Type, Value};

    /// Drives the async builtin to completion so tests can call it synchronously.
    fn argsort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::argsort_builtin(value, rest))
    }

    /// Unwraps a host tensor result, panicking with the offending value otherwise.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_vector_default() {
        let tensor = Tensor::new(vec![4.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        let t = expect_tensor(indices);
        assert_eq!(t.data, vec![2.0, 3.0, 1.0]);
        assert_eq!(t.shape, vec![3, 1]);
    }

    #[test]
    fn argsort_type_resolver_indices() {
        assert_eq!(
            index_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_descend_direction() {
        let tensor = Tensor::new(vec![10.0, 4.0, 7.0, 9.0], vec![4, 1]).unwrap();
        let indices =
            argsort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("argsort");
        assert_eq!(expect_tensor(indices).data, vec![1.0, 4.0, 3.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 6.0, 4.0, 2.0, 3.0, 5.0], vec![2, 3]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];
        let indices =
            argsort_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("argsort");
        // The builtin must agree with the index output of a direct `sort` evaluation.
        let expected = block_on(sort::evaluate(Value::Tensor(tensor), &args))
            .expect("sort evaluate")
            .indices_value();
        assert_eq!(indices, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_absolute_comparison() {
        let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(
            Value::Tensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("argsort");
        assert_eq!(expect_tensor(indices).data, vec![2.0, 4.0, 3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_handles_nan_like_sort() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        // The NaN element (index 1) is expected last in the ascending order.
        assert_eq!(expect_tensor(indices).data, vec![3.0, 4.0, 2.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_placeholder_then_dim() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
        let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let args = vec![
            Value::Tensor(placeholder),
            Value::Int(IntValue::I32(2)),
            Value::from("descend"),
        ];
        let indices =
            argsort_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("argsort");
        let expected = block_on(sort::evaluate(Value::Tensor(tensor), &args))
            .expect("sort evaluate")
            .indices_value();
        assert_eq!(indices, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_greater_than_ndims_returns_ones() {
        let tensor = Tensor::new(vec![1.0, 3.0, 2.0], vec![3, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(5))])
            .expect("argsort");
        let t = expect_tensor(indices);
        assert!(t.data.iter().all(|v| (*v - 1.0).abs() < f64::EPSILON));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_zero_errors() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err =
            argsort_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(0))]).unwrap_err();
        assert_eq!(err.identifier(), ARGSORT_ERROR_INVALID_DIMENSION.identifier);
    }

    // Previously named `argsort_invalid_argument_errors`; renamed because the assertion
    // targets the MissingPlacement-unsupported identifier, not the invalid-argument one.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_missing_placement_unsupported_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = argsort_builtin(
            Value::Tensor(tensor),
            vec![Value::from("MissingPlacement"), Value::from("auto")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            ARGSORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_invalid_comparison_method_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = argsort_builtin(
            Value::Tensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("unknown")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            ARGSORT_ERROR_COMPARISON_METHOD_UNKNOWN.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_invalid_comparison_method_value_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = argsort_builtin(
            Value::Tensor(tensor),
            vec![
                Value::from("ComparisonMethod"),
                Value::Int(IntValue::I32(1)),
            ],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            ARGSORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_stable_with_duplicates() {
        let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let indices = argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort");
        // Equal keys must keep their original relative order (1, 2, 4).
        assert_eq!(expect_tensor(indices).data, vec![3.0, 1.0, 2.0, 4.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_complex_real_method() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (1.0, -1.0)], vec![3, 1]).unwrap();
        let indices = argsort_builtin(
            Value::ComplexTensor(tensor),
            vec![Value::from("ComparisonMethod"), Value::from("real")],
        )
        .expect("argsort");
        assert_eq!(expect_tensor(indices).data, vec![2.0, 3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let indices = argsort_builtin(Value::GpuTensor(handle), Vec::new()).expect("argsort");
            // A GPU-resident input is expected to yield a host tensor of indices.
            assert_eq!(expect_tensor(indices).data, vec![2.0, 3.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn argsort_wgpu_matches_cpu() {
        // Registration result is ignored on purpose, as in the original test.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![0.0, 5.0, -1.0, 2.0], vec![4, 1]).unwrap();
        let cpu_indices = argsort_builtin(Value::Tensor(tensor.clone()), Vec::new()).unwrap();
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let gpu_handle = runmat_accelerate_api::provider()
            .unwrap()
            .upload(&view)
            .expect("upload");
        let gpu_indices = argsort_builtin(Value::GpuTensor(gpu_handle), Vec::new()).unwrap();

        let cpu_tensor = expect_tensor(cpu_indices);
        let gpu_tensor = expect_tensor(gpu_indices);
        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        assert_eq!(gpu_tensor.data, cpu_tensor.data);
    }
}