1use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use super::sort;
7use super::type_resolvers::index_output_type;
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
11};
12
13#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::argsort")]
14pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
15 name: "argsort",
16 op_kind: GpuOpKind::Custom("sort"),
17 supported_precisions: &[ScalarType::F32, ScalarType::F64],
18 broadcast: BroadcastSemantics::None,
19 provider_hooks: &[ProviderHook::Custom("sort_dim")],
20 constant_strategy: ConstantStrategy::InlineLiteral,
21 residency: ResidencyPolicy::GatherImmediately,
22 nan_mode: ReductionNaN::Include,
23 two_pass_threshold: None,
24 workgroup_size: None,
25 accepts_nan_mode: true,
26 notes: "Shares provider hooks with `sort`; when unavailable tensors are gathered to host memory before computing indices.",
27};
28
29#[runmat_macros::register_fusion_spec(
30 builtin_path = "crate::builtins::array::sorting_sets::argsort"
31)]
32pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
33 name: "argsort",
34 shape: ShapeRequirements::Any,
35 constant_strategy: ConstantStrategy::InlineLiteral,
36 elementwise: None,
37 reduction: None,
38 emits_nan: true,
39 notes: "`argsort` breaks fusion chains and acts as a residency sink; upstream tensors are gathered when no GPU sort kernel is provided.",
40};
41
42#[runtime_builtin(
43 name = "argsort",
44 category = "array/sorting_sets",
45 summary = "Return the permutation indices that would sort tensors along a dimension.",
46 keywords = "argsort,sort,indices,permutation,gpu",
47 accel = "sink",
48 sink = true,
49 type_resolver(index_output_type),
50 builtin_path = "crate::builtins::array::sorting_sets::argsort"
51)]
52async fn argsort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
53 let evaluation = sort::evaluate(value, &rest).await?;
54 Ok(evaluation.indices_value())
55}
56
#[cfg(test)]
pub(crate) mod tests {
    use super::{index_output_type, sort};
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor, IntValue, ResolveContext, Tensor, Type, Value};

    /// Blocking wrapper around the async builtin so each test can call it synchronously.
    fn argsort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::argsort_builtin(value, rest))
    }

    /// Returns the message text carried by a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Unwraps a `Value::Tensor`, panicking (at the caller's location) with the offending
    /// value when the builtin returned any other variant.
    #[track_caller]
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(tensor) => tensor,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_vector_default() {
        let tensor = Tensor::new(vec![4.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let indices =
            expect_tensor(argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort"));
        assert_eq!(indices.data, vec![2.0, 3.0, 1.0]);
        assert_eq!(indices.shape, vec![3, 1]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_type_resolver_indices() {
        assert_eq!(
            index_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_descend_direction() {
        let tensor = Tensor::new(vec![10.0, 4.0, 7.0, 9.0], vec![4, 1]).unwrap();
        let indices = expect_tensor(
            argsort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("argsort"),
        );
        assert_eq!(indices.data, vec![1.0, 4.0, 3.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 6.0, 4.0, 2.0, 3.0, 5.0], vec![2, 3]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];
        // Reference result comes straight from `sort::evaluate`; computing it first lets
        // `args` be moved into the builtin call afterwards instead of cloned.
        let expected = block_on(sort::evaluate(Value::Tensor(tensor.clone()), &args))
            .expect("sort evaluate")
            .indices_value();
        let indices = argsort_builtin(Value::Tensor(tensor), args).expect("argsort");
        assert_eq!(indices, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_absolute_comparison() {
        let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
        let indices = expect_tensor(
            argsort_builtin(
                Value::Tensor(tensor),
                vec![Value::from("ComparisonMethod"), Value::from("abs")],
            )
            .expect("argsort"),
        );
        assert_eq!(indices.data, vec![2.0, 4.0, 3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_handles_nan_like_sort() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let indices =
            expect_tensor(argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort"));
        // The NaN element (index 1) is expected last in the default ordering.
        assert_eq!(indices.data, vec![3.0, 4.0, 2.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_placeholder_then_dim() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
        // An empty 0x0 tensor stands in for the skipped argument ahead of the dimension.
        let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let args = vec![
            Value::Tensor(placeholder),
            Value::Int(IntValue::I32(2)),
            Value::from("descend"),
        ];
        let expected = block_on(sort::evaluate(Value::Tensor(tensor.clone()), &args))
            .expect("sort evaluate")
            .indices_value();
        let indices = argsort_builtin(Value::Tensor(tensor), args).expect("argsort");
        assert_eq!(indices, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_greater_than_ndims_returns_ones() {
        let tensor = Tensor::new(vec![1.0, 3.0, 2.0], vec![3, 1]).unwrap();
        let indices = expect_tensor(
            argsort_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(5))])
                .expect("argsort"),
        );
        assert!(indices
            .data
            .iter()
            .all(|&index| (index - 1.0).abs() < f64::EPSILON));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_dimension_zero_errors() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = error_message(
            argsort_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(0))]).unwrap_err(),
        );
        assert!(
            err.contains("dimension must be >= 1"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_invalid_argument_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = error_message(
            argsort_builtin(
                Value::Tensor(tensor),
                vec![Value::from("MissingPlacement"), Value::from("auto")],
            )
            .unwrap_err(),
        );
        // The message is prefixed with `sort:` because option parsing is delegated to `sort`.
        assert!(
            err.contains("sort: the 'MissingPlacement' option is not supported"),
            "{err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_invalid_comparison_method_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = error_message(
            argsort_builtin(
                Value::Tensor(tensor),
                vec![Value::from("ComparisonMethod"), Value::from("unknown")],
            )
            .unwrap_err(),
        );
        assert!(
            err.contains("unsupported ComparisonMethod"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_invalid_comparison_method_value_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = error_message(
            argsort_builtin(
                Value::Tensor(tensor),
                vec![
                    Value::from("ComparisonMethod"),
                    Value::Int(IntValue::I32(1)),
                ],
            )
            .unwrap_err(),
        );
        assert!(
            err.contains("requires a string value"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_stable_with_duplicates() {
        let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let indices =
            expect_tensor(argsort_builtin(Value::Tensor(tensor), Vec::new()).expect("argsort"));
        // Equal elements must keep their original relative order (1, 2, 4).
        assert_eq!(indices.data, vec![3.0, 1.0, 2.0, 4.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_complex_real_method() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (1.0, -1.0)], vec![3, 1]).unwrap();
        let indices = expect_tensor(
            argsort_builtin(
                Value::ComplexTensor(tensor),
                vec![Value::from("ComparisonMethod"), Value::from("real")],
            )
            .expect("argsort"),
        );
        assert_eq!(indices.data, vec![2.0, 3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn argsort_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            // Even with a GPU-resident input the result is expected as a host `Value::Tensor`.
            let indices = expect_tensor(
                argsort_builtin(Value::GpuTensor(handle), Vec::new()).expect("argsort"),
            );
            assert_eq!(indices.data, vec![2.0, 3.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn argsort_wgpu_matches_cpu() {
        // Registration may already have happened in another test; the result is ignored.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![0.0, 5.0, -1.0, 2.0], vec![4, 1]).unwrap();
        let cpu_tensor = expect_tensor(
            argsort_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("argsort on host"),
        );

        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let gpu_handle = runmat_accelerate_api::provider()
            .expect("a provider should be registered")
            .upload(&view)
            .expect("upload");
        let gpu_tensor = expect_tensor(
            argsort_builtin(Value::GpuTensor(gpu_handle), Vec::new()).expect("argsort on gpu"),
        );

        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        assert_eq!(gpu_tensor.data, cpu_tensor.data);
    }
}