1use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6 GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7};
8use runmat_builtins::{
9 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
10 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
11 ComplexTensor, Tensor, Value,
12};
13use runmat_macros::runtime_builtin;
14
15use super::type_resolvers::tensor_output_type;
16use crate::build_runtime_error;
17use crate::builtins::common::arg_tokens::{tokens_from_values, ArgToken};
18use crate::builtins::common::gpu_helpers;
19use crate::builtins::common::spec::{
20 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24
25#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::sort")]
26pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
27 name: "sort",
28 op_kind: GpuOpKind::Custom("sort"),
29 supported_precisions: &[ScalarType::F32, ScalarType::F64],
30 broadcast: BroadcastSemantics::None,
31 provider_hooks: &[ProviderHook::Custom("sort_dim")],
32 constant_strategy: ConstantStrategy::InlineLiteral,
33 residency: ResidencyPolicy::GatherImmediately,
34 nan_mode: ReductionNaN::Include,
35 two_pass_threshold: None,
36 workgroup_size: None,
37 accepts_nan_mode: true,
38 notes: "Providers may add a dedicated sort kernel in the future; today tensors are gathered to host memory before sorting.",
39};
40
// Fusion metadata for `sort`.
//
// No elementwise or reduction template is provided (both are `None`). Per the notes,
// sorting ends a fusion chain and upstream tensors are gathered to host memory.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::sorting_sets::sort")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "sort",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    // NaN inputs are kept (placed last for 'ascend', first for 'descend'), so the
    // output may contain NaN.
    emits_nan: true,
    notes: "Sorting breaks fusion chains and acts as a residency sink; upstream tensors are gathered to host memory.",
};
51
/// Builtin name attached to every runtime error raised by this module (see `sort_error`).
const BUILTIN_NAME: &str = "sort";
53
54const SORT_OUTPUT_B: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
55 name: "B",
56 ty: BuiltinParamType::Any,
57 arity: BuiltinParamArity::Required,
58 default: None,
59 description: "Sorted values.",
60}];
61
62const SORT_OUTPUT_BI: [BuiltinParamDescriptor; 2] = [
63 BuiltinParamDescriptor {
64 name: "B",
65 ty: BuiltinParamType::Any,
66 arity: BuiltinParamArity::Required,
67 default: None,
68 description: "Sorted values.",
69 },
70 BuiltinParamDescriptor {
71 name: "I",
72 ty: BuiltinParamType::NumericArray,
73 arity: BuiltinParamArity::Required,
74 default: None,
75 description: "One-based index permutation for each sorted slice.",
76 },
77];
78
79const SORT_INPUTS_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
80 name: "A",
81 ty: BuiltinParamType::Any,
82 arity: BuiltinParamArity::Required,
83 default: None,
84 description: "Input array.",
85}];
86
87const SORT_INPUTS_A_ARG1: [BuiltinParamDescriptor; 2] = [
88 BuiltinParamDescriptor {
89 name: "A",
90 ty: BuiltinParamType::Any,
91 arity: BuiltinParamArity::Required,
92 default: None,
93 description: "Input array.",
94 },
95 BuiltinParamDescriptor {
96 name: "arg1",
97 ty: BuiltinParamType::Any,
98 arity: BuiltinParamArity::Required,
99 default: None,
100 description: "Dimension selector or direction token ('ascend'/'descend').",
101 },
102];
103
104const SORT_INPUTS_A_ARG1_ARG2: [BuiltinParamDescriptor; 3] = [
105 BuiltinParamDescriptor {
106 name: "A",
107 ty: BuiltinParamType::Any,
108 arity: BuiltinParamArity::Required,
109 default: None,
110 description: "Input array.",
111 },
112 BuiltinParamDescriptor {
113 name: "arg1",
114 ty: BuiltinParamType::Any,
115 arity: BuiltinParamArity::Required,
116 default: None,
117 description: "Dimension selector, placeholder, or direction token.",
118 },
119 BuiltinParamDescriptor {
120 name: "arg2",
121 ty: BuiltinParamType::Any,
122 arity: BuiltinParamArity::Required,
123 default: None,
124 description: "Dimension selector or direction token.",
125 },
126];
127
128const SORT_INPUTS_COMPARISON_METHOD: [BuiltinParamDescriptor; 4] = [
129 BuiltinParamDescriptor {
130 name: "A",
131 ty: BuiltinParamType::Any,
132 arity: BuiltinParamArity::Required,
133 default: None,
134 description: "Input array.",
135 },
136 BuiltinParamDescriptor {
137 name: "arg",
138 ty: BuiltinParamType::Any,
139 arity: BuiltinParamArity::Variadic,
140 default: None,
141 description: "Optional dimension/direction arguments.",
142 },
143 BuiltinParamDescriptor {
144 name: "name",
145 ty: BuiltinParamType::StringScalar,
146 arity: BuiltinParamArity::Required,
147 default: Some("\"ComparisonMethod\""),
148 description: "Name-value option key.",
149 },
150 BuiltinParamDescriptor {
151 name: "method",
152 ty: BuiltinParamType::StringScalar,
153 arity: BuiltinParamArity::Required,
154 default: Some("\"auto\""),
155 description: "Comparison method: 'auto', 'real', or 'abs'.",
156 },
157];
158
159const SORT_INPUTS_MISSING_PLACEMENT: [BuiltinParamDescriptor; 4] = [
160 BuiltinParamDescriptor {
161 name: "A",
162 ty: BuiltinParamType::Any,
163 arity: BuiltinParamArity::Required,
164 default: None,
165 description: "Input array.",
166 },
167 BuiltinParamDescriptor {
168 name: "arg",
169 ty: BuiltinParamType::Any,
170 arity: BuiltinParamArity::Variadic,
171 default: None,
172 description: "Optional dimension/direction arguments.",
173 },
174 BuiltinParamDescriptor {
175 name: "name",
176 ty: BuiltinParamType::StringScalar,
177 arity: BuiltinParamArity::Required,
178 default: Some("\"MissingPlacement\""),
179 description: "Name-value option key.",
180 },
181 BuiltinParamDescriptor {
182 name: "placement",
183 ty: BuiltinParamType::StringScalar,
184 arity: BuiltinParamArity::Required,
185 default: Some("\"auto\""),
186 description: "Requested NaN placement option (currently unsupported).",
187 },
188];
189
/// Call forms advertised for `sort`.
///
/// The first five cover the single-output forms. The last five repeat the same inputs
/// with the two-output `[B, I]` form.
const SORT_SIGNATURES: [BuiltinSignatureDescriptor; 10] = [
    // --- single output: sorted values only ---
    BuiltinSignatureDescriptor {
        label: "B = sort(A)",
        inputs: &SORT_INPUTS_A,
        outputs: &SORT_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sort(A, arg1)",
        inputs: &SORT_INPUTS_A_ARG1,
        outputs: &SORT_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sort(A, arg1, arg2)",
        inputs: &SORT_INPUTS_A_ARG1_ARG2,
        outputs: &SORT_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sort(A, ..., \"ComparisonMethod\", method)",
        inputs: &SORT_INPUTS_COMPARISON_METHOD,
        outputs: &SORT_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sort(A, ..., \"MissingPlacement\", placement)",
        inputs: &SORT_INPUTS_MISSING_PLACEMENT,
        outputs: &SORT_OUTPUT_B,
    },
    // --- two outputs: sorted values plus one-based permutation indices ---
    BuiltinSignatureDescriptor {
        label: "[B, I] = sort(A)",
        inputs: &SORT_INPUTS_A,
        outputs: &SORT_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sort(A, arg1)",
        inputs: &SORT_INPUTS_A_ARG1,
        outputs: &SORT_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sort(A, arg1, arg2)",
        inputs: &SORT_INPUTS_A_ARG1_ARG2,
        outputs: &SORT_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sort(A, ..., \"ComparisonMethod\", method)",
        inputs: &SORT_INPUTS_COMPARISON_METHOD,
        outputs: &SORT_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sort(A, ..., \"MissingPlacement\", placement)",
        inputs: &SORT_INPUTS_MISSING_PLACEMENT,
        outputs: &SORT_OUTPUT_BI,
    },
];
242
/// Raised when a numeric dimension argument is rejected by `tensor::parse_dimension`,
/// or when the resolved sort dimension is 0.
const SORT_ERROR_INVALID_DIMENSION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORT.INVALID_DIMENSION",
    identifier: Some("RunMat:sort:InvalidDimension"),
    when: "Dimension argument is non-positive, non-integer, or otherwise invalid.",
    message: "sort: invalid dimension argument",
};

/// Raised when the value following `'ComparisonMethod'` is not a string.
const SORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING: BuiltinErrorDescriptor =
    BuiltinErrorDescriptor {
        code: "RM.SORT.COMPARISON_METHOD_REQUIRES_STRING",
        identifier: Some("RunMat:sort:ComparisonMethodRequiresString"),
        when: "ComparisonMethod option value is not string-like.",
        message: "sort: 'ComparisonMethod' requires a string value",
    };

/// Raised when the `'ComparisonMethod'` value is not a recognised method name.
const SORT_ERROR_COMPARISON_METHOD_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORT.COMPARISON_METHOD_UNKNOWN",
    identifier: Some("RunMat:sort:ComparisonMethodUnknown"),
    when: "ComparisonMethod option value is not one of 'auto'/'real'/'abs'.",
    message: "sort: unsupported ComparisonMethod",
};

/// Raised whenever `'MissingPlacement'` appears; its `message` is also used verbatim
/// as the runtime error text by the argument parser.
const SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORT.MISSINGPLACEMENT_UNSUPPORTED",
    identifier: Some("RunMat:sort:MissingPlacementUnsupported"),
    when: "MissingPlacement option is provided but unsupported.",
    message: "sort: the 'MissingPlacement' option is not supported yet",
};

/// Raised for unrecognised arguments, a missing option value, or a value that cannot
/// be converted to a tensor.
const SORT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORT.INVALID_ARGUMENT",
    identifier: Some("RunMat:sort:InvalidArgument"),
    when: "Parser encounters invalid or unrecognized option/value arguments.",
    message: "sort: invalid argument sequence",
};

/// Raised when constructing an output tensor fails.
const SORT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORT.INTERNAL",
    identifier: Some("RunMat:sort:Internal"),
    when: "Internal conversion, allocation, or provider result construction fails.",
    message: "sort: internal operation failed",
};

/// Every error `sort` can raise, published through `SORT_DESCRIPTOR`.
const SORT_ERRORS: [BuiltinErrorDescriptor; 6] = [
    SORT_ERROR_INVALID_DIMENSION,
    SORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING,
    SORT_ERROR_COMPARISON_METHOD_UNKNOWN,
    SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED,
    SORT_ERROR_INVALID_ARGUMENT,
    SORT_ERROR_INTERNAL,
];
294
/// Public descriptor for `sort`: call signatures, error catalogue, and output policy.
///
/// The outputs depend on the requested output count: `B` alone, or `[B, I]`.
pub const SORT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &SORT_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &SORT_ERRORS,
};
301
302fn sort_error(
303 error: &'static BuiltinErrorDescriptor,
304 message: impl Into<String>,
305) -> crate::RuntimeError {
306 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
307 if let Some(identifier) = error.identifier {
308 builder = builder.with_identifier(identifier);
309 }
310 builder.build()
311}
312
313fn sort_internal(message: impl Into<String>) -> crate::RuntimeError {
314 sort_error(&SORT_ERROR_INTERNAL, message)
315}
316
317fn sort_invalid_argument(message: impl Into<String>) -> crate::RuntimeError {
318 sort_error(&SORT_ERROR_INVALID_ARGUMENT, message)
319}
320
321#[runtime_builtin(
322 name = "sort",
323 category = "array/sorting_sets",
324 summary = "Sort array elements along a dimension with optional index outputs.",
325 keywords = "sort,ascending,descending,indices,comparisonmethod,gpu",
326 accel = "sink",
327 sink = true,
328 type_resolver(tensor_output_type),
329 descriptor(crate::builtins::array::sorting_sets::sort::SORT_DESCRIPTOR),
330 builtin_path = "crate::builtins::array::sorting_sets::sort"
331)]
332async fn sort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
333 let eval = evaluate(value, &rest).await?;
334 if let Some(out_count) = crate::output_count::current_output_count() {
335 if out_count == 0 {
336 return Ok(Value::OutputList(Vec::new()));
337 }
338 let (sorted, indices) = eval.into_values();
339 let mut outputs = vec![sorted];
340 if out_count >= 2 {
341 outputs.push(indices);
342 }
343 return Ok(crate::output_count::output_list_with_padding(
344 out_count, outputs,
345 ));
346 }
347 Ok(eval.into_sorted_value())
348}
349
350pub async fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortEvaluation> {
352 let args = SortArgs::parse(rest)?;
353 match value {
354 Value::GpuTensor(handle) => sort_gpu(handle, &args).await,
355 other => sort_host(other, &args),
356 }
357}
358
359async fn sort_gpu(
360 handle: GpuTensorHandle,
361 args: &SortArgs,
362) -> crate::BuiltinResult<SortEvaluation> {
363 let shape = handle.shape.clone();
364 let dim = args.dimension.unwrap_or_else(|| default_dimension(&shape));
365 if dim == 0 {
366 return Err(sort_error(
367 &SORT_ERROR_INVALID_DIMENSION,
368 "sort: dimension must be >= 1",
369 ));
370 }
371 let dim_len = dimension_length(&shape, dim);
372 if dim_len > 1 {
373 if let Some(provider) = runmat_accelerate_api::provider() {
374 let order = args.direction.to_provider();
375 let comparison = args.comparison.to_provider();
376 let zero_based = dim - 1;
377 if let Ok(result) = provider
378 .sort_dim(&handle, zero_based, order, comparison)
379 .await
380 {
381 let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
382 .map_err(|e| sort_internal(format!("sort: {e}")))?;
383 let sorted_value = tensor::tensor_into_value(sorted_tensor);
384 let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
385 .map_err(|e| sort_internal(format!("sort: {e}")))?;
386 return Ok(SortEvaluation {
387 sorted: sorted_value,
388 indices: indices_tensor,
389 });
390 }
391 }
392 }
393 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
394 sort_real_tensor(tensor, args)
395}
396
397fn sort_host(value: Value, args: &SortArgs) -> crate::BuiltinResult<SortEvaluation> {
398 match value {
399 Value::ComplexTensor(ct) => sort_complex_tensor(ct, args),
400 Value::Complex(re, im) => {
401 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
402 .map_err(|e| sort_internal(format!("sort: {e}")))?;
403 sort_complex_tensor(tensor, args)
404 }
405 other => {
406 let tensor =
407 tensor::value_into_tensor_for("sort", other).map_err(sort_invalid_argument)?;
408 sort_real_tensor(tensor, args)
409 }
410 }
411}
412
413fn sort_real_tensor(tensor: Tensor, args: &SortArgs) -> crate::BuiltinResult<SortEvaluation> {
414 let dim = args
415 .dimension
416 .unwrap_or_else(|| default_dimension(&tensor.shape));
417 if dim == 0 {
418 return Err(sort_error(
419 &SORT_ERROR_INVALID_DIMENSION,
420 "sort: dimension must be >= 1",
421 ));
422 }
423
424 let dim_len = dimension_length(&tensor.shape, dim);
425 if tensor.data.is_empty() || dim_len <= 1 {
426 let indices = vec![1.0; tensor.data.len()];
427 let index_tensor = Tensor::new(indices, tensor.shape.clone())
428 .map_err(|e| sort_internal(format!("sort: {e}")))?;
429 let sorted_value = tensor::tensor_into_value(tensor);
430 return Ok(SortEvaluation {
431 sorted: sorted_value,
432 indices: index_tensor,
433 });
434 }
435
436 let stride_before = stride_before(&tensor.shape, dim);
437 let stride_after = stride_after(&tensor.shape, dim);
438 let mut sorted = tensor.data.clone();
439 let mut indices = vec![0.0f64; tensor.data.len()];
440 let mut buffer: Vec<(usize, f64)> = Vec::with_capacity(dim_len);
441
442 for after in 0..stride_after {
443 for before in 0..stride_before {
444 buffer.clear();
445 for k in 0..dim_len {
446 let idx = before + k * stride_before + after * stride_before * dim_len;
447 let value = tensor.data[idx];
448 buffer.push((k, value));
449 }
450 buffer.sort_by(|a, b| compare_real_values(a.1, b.1, args));
451 for (pos, (original_index, value)) in buffer.iter().enumerate() {
452 let target = before + pos * stride_before + after * stride_before * dim_len;
453 sorted[target] = *value;
454 indices[target] = (*original_index + 1) as f64;
455 }
456 }
457 }
458
459 let sorted_tensor = Tensor::new(sorted, tensor.shape.clone())
460 .map_err(|e| sort_internal(format!("sort: {e}")))?;
461 let index_tensor = Tensor::new(indices, tensor.shape.clone())
462 .map_err(|e| sort_internal(format!("sort: {e}")))?;
463
464 Ok(SortEvaluation {
465 sorted: tensor::tensor_into_value(sorted_tensor),
466 indices: index_tensor,
467 })
468}
469
470fn sort_complex_tensor(
471 tensor: ComplexTensor,
472 args: &SortArgs,
473) -> crate::BuiltinResult<SortEvaluation> {
474 let dim = args
475 .dimension
476 .unwrap_or_else(|| default_dimension(&tensor.shape));
477 if dim == 0 {
478 return Err(sort_error(
479 &SORT_ERROR_INVALID_DIMENSION,
480 "sort: dimension must be >= 1",
481 ));
482 }
483
484 let dim_len = dimension_length(&tensor.shape, dim);
485 if tensor.data.is_empty() || dim_len <= 1 {
486 let indices = vec![1.0; tensor.data.len()];
487 let index_tensor = Tensor::new(indices, tensor.shape.clone())
488 .map_err(|e| sort_internal(format!("sort: {e}")))?;
489 return Ok(SortEvaluation {
490 sorted: complex_tensor_into_value(tensor),
491 indices: index_tensor,
492 });
493 }
494
495 let stride_before = stride_before(&tensor.shape, dim);
496 let stride_after = stride_after(&tensor.shape, dim);
497 let mut sorted = tensor.data.clone();
498 let mut indices = vec![0.0f64; tensor.data.len()];
499 let mut buffer: Vec<(usize, (f64, f64))> = Vec::with_capacity(dim_len);
500
501 for after in 0..stride_after {
502 for before in 0..stride_before {
503 buffer.clear();
504 for k in 0..dim_len {
505 let idx = before + k * stride_before + after * stride_before * dim_len;
506 let value = tensor.data[idx];
507 buffer.push((k, value));
508 }
509 buffer.sort_by(|a, b| compare_complex_values(a.1, b.1, args));
510 for (pos, (original_index, value)) in buffer.iter().enumerate() {
511 let target = before + pos * stride_before + after * stride_before * dim_len;
512 sorted[target] = *value;
513 indices[target] = (*original_index + 1) as f64;
514 }
515 }
516 }
517
518 let sorted_tensor = ComplexTensor::new(sorted, tensor.shape.clone())
519 .map_err(|e| sort_internal(format!("sort: {e}")))?;
520 let index_tensor = Tensor::new(indices, tensor.shape.clone())
521 .map_err(|e| sort_internal(format!("sort: {e}")))?;
522
523 Ok(SortEvaluation {
524 sorted: complex_tensor_into_value(sorted_tensor),
525 indices: index_tensor,
526 })
527}
528
529fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
530 if tensor.data.len() == 1 {
531 let (re, im) = tensor.data[0];
532 Value::Complex(re, im)
533 } else {
534 Value::ComplexTensor(tensor)
535 }
536}
537
538fn compare_real_values(a: f64, b: f64, args: &SortArgs) -> Ordering {
539 match (a.is_nan(), b.is_nan()) {
540 (true, true) => Ordering::Equal,
541 (true, false) => match args.direction {
542 SortDirection::Ascend => Ordering::Greater,
543 SortDirection::Descend => Ordering::Less,
544 },
545 (false, true) => match args.direction {
546 SortDirection::Ascend => Ordering::Less,
547 SortDirection::Descend => Ordering::Greater,
548 },
549 (false, false) => compare_real_finite(a, b, args),
550 }
551}
552
553fn compare_real_finite(a: f64, b: f64, args: &SortArgs) -> Ordering {
554 let primary = match args.comparison {
555 ComparisonMethod::Abs => {
556 let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
557 if abs_cmp != Ordering::Equal {
558 return match args.direction {
559 SortDirection::Ascend => abs_cmp,
560 SortDirection::Descend => abs_cmp.reverse(),
561 };
562 }
563 Ordering::Equal
564 }
565 ComparisonMethod::Auto | ComparisonMethod::Real => Ordering::Equal,
566 };
567 if primary != Ordering::Equal {
568 return primary;
569 }
570 match args.direction {
571 SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
572 SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
573 }
574}
575
576fn compare_complex_values(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
577 match (complex_is_nan(a), complex_is_nan(b)) {
578 (true, true) => Ordering::Equal,
579 (true, false) => match args.direction {
580 SortDirection::Ascend => Ordering::Greater,
581 SortDirection::Descend => Ordering::Less,
582 },
583 (false, true) => match args.direction {
584 SortDirection::Ascend => Ordering::Less,
585 SortDirection::Descend => Ordering::Greater,
586 },
587 (false, false) => compare_complex_finite(a, b, args),
588 }
589}
590
591fn compare_complex_finite(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
592 match args.comparison {
593 ComparisonMethod::Real => compare_complex_real_imag(a, b, args.direction),
594 ComparisonMethod::Abs | ComparisonMethod::Auto => {
595 let abs_cmp = complex_abs(a)
596 .partial_cmp(&complex_abs(b))
597 .unwrap_or(Ordering::Equal);
598 if abs_cmp != Ordering::Equal {
599 return match args.direction {
600 SortDirection::Ascend => abs_cmp,
601 SortDirection::Descend => abs_cmp.reverse(),
602 };
603 }
604 compare_complex_real_imag(a, b, args.direction)
605 }
606 }
607}
608
609fn compare_complex_real_imag(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
610 let real_cmp = match direction {
611 SortDirection::Ascend => a.0.partial_cmp(&b.0),
612 SortDirection::Descend => b.0.partial_cmp(&a.0),
613 }
614 .unwrap_or(Ordering::Equal);
615 if real_cmp != Ordering::Equal {
616 return real_cmp;
617 }
618 match direction {
619 SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
620 SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
621 }
622}
623
/// True when either the real or the imaginary part of `value` is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    [re, im].iter().any(|part| part.is_nan())
}
627
/// Magnitude `|re + i*im|`, computed with `hypot` to avoid intermediate overflow.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    f64::hypot(re, im)
}
631
/// Product of the extents preceding 1-based dimension `dim` (column-major stride of
/// `dim`).
///
/// Missing trailing dimensions count as 1. The product saturates instead of overflowing.
fn stride_before(shape: &[usize], dim: usize) -> usize {
    let leading = dim.saturating_sub(1);
    shape
        .iter()
        .take(leading)
        .fold(1usize, |acc, &extent| acc.saturating_mul(extent))
}
642
/// Product of the extents following 1-based dimension `dim` (number of outer blocks).
///
/// Returns 1 when `dim` is the last stored dimension or beyond. The product saturates
/// instead of overflowing.
fn stride_after(shape: &[usize], dim: usize) -> usize {
    shape
        .iter()
        .skip(dim)
        .fold(1usize, |acc, &extent| acc.saturating_mul(extent))
}
653
/// Extent of 1-based dimension `dim` in `shape`.
///
/// Dimensions beyond the stored rank are implicit singletons (extent 1). `dim == 0` is
/// not a valid dimension and also yields 1. The old `dim - 1` form underflowed here:
/// it panicked in debug builds and relied on wrap-around in release builds.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    dim.checked_sub(1)
        .and_then(|index| shape.get(index))
        .copied()
        .unwrap_or(1)
}
657
/// Default sort dimension: the 1-based index of the first extent greater than 1, or 1
/// when every extent is 0 or 1.
fn default_dimension(shape: &[usize]) -> usize {
    for (index, &extent) in shape.iter().enumerate() {
        if extent > 1 {
            return index + 1;
        }
    }
    1
}
665
666#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
667enum SortDirection {
668 #[default]
669 Ascend,
670 Descend,
671}
672
673impl SortDirection {
674 fn to_provider(self) -> ProviderSortOrder {
675 match self {
676 SortDirection::Ascend => ProviderSortOrder::Ascend,
677 SortDirection::Descend => ProviderSortOrder::Descend,
678 }
679 }
680}
681
682#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
683enum ComparisonMethod {
684 #[default]
685 Auto,
686 Real,
687 Abs,
688}
689
690impl ComparisonMethod {
691 fn to_provider(self) -> ProviderSortComparison {
692 match self {
693 ComparisonMethod::Auto => ProviderSortComparison::Auto,
694 ComparisonMethod::Real => ProviderSortComparison::Real,
695 ComparisonMethod::Abs => ProviderSortComparison::Abs,
696 }
697 }
698}
699
700#[derive(Debug, Clone, Default)]
701struct SortArgs {
702 dimension: Option<usize>,
703 direction: SortDirection,
704 comparison: ComparisonMethod,
705}
706
707impl SortArgs {
708 fn parse(rest: &[Value]) -> crate::BuiltinResult<Self> {
709 let mut args = SortArgs::default();
710 let tokens = tokens_from_values(rest);
711 let mut i = 0usize;
712 while i < rest.len() {
713 if args.dimension.is_none() {
714 if is_dimension_placeholder(&rest[i]) {
715 i += 1;
716 continue;
717 }
718 match tensor::parse_dimension(&rest[i], "sort") {
719 Ok(dim) => {
720 args.dimension = Some(dim);
721 i += 1;
722 continue;
723 }
724 Err(err) => {
725 if matches!(rest[i], Value::Int(_) | Value::Num(_)) {
726 return Err(sort_error(&SORT_ERROR_INVALID_DIMENSION, err));
727 }
728 }
729 }
730 }
731 if let Some(ArgToken::String(text)) = tokens.get(i) {
732 match text.as_str() {
733 "ascend" | "ascending" => {
734 args.direction = SortDirection::Ascend;
735 i += 1;
736 continue;
737 }
738 "descend" | "descending" => {
739 args.direction = SortDirection::Descend;
740 i += 1;
741 continue;
742 }
743 "comparisonmethod" => {
744 i += 1;
745 if i >= rest.len() {
746 return Err(sort_invalid_argument(
747 "sort: expected a value for 'ComparisonMethod'",
748 ));
749 }
750 let value = match tokens.get(i) {
751 Some(ArgToken::String(value)) => value.as_str(),
752 _ => {
753 return Err(sort_error(
754 &SORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING,
755 "sort: 'ComparisonMethod' requires a string value",
756 ))
757 }
758 };
759 args.comparison = match value {
760 "auto" => ComparisonMethod::Auto,
761 "real" => ComparisonMethod::Real,
762 "abs" | "magnitude" => ComparisonMethod::Abs,
763 other => {
764 return Err(sort_error(
765 &SORT_ERROR_COMPARISON_METHOD_UNKNOWN,
766 format!("sort: unsupported ComparisonMethod '{other}'"),
767 )
768 .into())
769 }
770 };
771 i += 1;
772 continue;
773 }
774 "missingplacement" => {
775 return Err(sort_error(
776 &SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED,
777 SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED.message,
778 )
779 .into());
780 }
781 _ => {}
782 }
783 }
784 if let Some(keyword) = tensor::value_to_string(&rest[i]) {
785 let lowered = keyword.trim().to_ascii_lowercase();
786 match lowered.as_str() {
787 "ascend" | "ascending" => {
788 args.direction = SortDirection::Ascend;
789 i += 1;
790 continue;
791 }
792 "descend" | "descending" => {
793 args.direction = SortDirection::Descend;
794 i += 1;
795 continue;
796 }
797 "comparisonmethod" => {
798 i += 1;
799 if i >= rest.len() {
800 return Err(sort_invalid_argument(
801 "sort: expected a value for 'ComparisonMethod'",
802 ));
803 }
804 let raw = &rest[i];
805 let value = match raw {
806 Value::String(s) => s.clone(),
807 Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
808 Value::CharArray(ca) if ca.rows == 1 => {
809 ca.data.iter().copied().collect()
810 }
811 _ => {
812 return Err(sort_error(
813 &SORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING,
814 "sort: 'ComparisonMethod' requires a string value",
815 ))
816 }
817 };
818 let lowered_value = value.trim().to_ascii_lowercase();
819 args.comparison = match lowered_value.as_str() {
820 "auto" => ComparisonMethod::Auto,
821 "real" => ComparisonMethod::Real,
822 "abs" | "magnitude" => ComparisonMethod::Abs,
823 other => {
824 return Err(sort_error(
825 &SORT_ERROR_COMPARISON_METHOD_UNKNOWN,
826 format!("sort: unsupported ComparisonMethod '{other}'"),
827 )
828 .into())
829 }
830 };
831 i += 1;
832 continue;
833 }
834 "missingplacement" => {
835 return Err(sort_error(
836 &SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED,
837 SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED.message,
838 )
839 .into());
840 }
841 _ => {}
842 }
843 }
844 return Err(sort_invalid_argument(format!(
845 "sort: unrecognised argument {:?}",
846 rest[i]
847 )));
848 }
849 Ok(args)
850 }
851}
852
853fn is_dimension_placeholder(value: &Value) -> bool {
854 match value {
855 Value::Tensor(t) => t.data.is_empty(),
856 Value::LogicalArray(logical) => logical.data.is_empty(),
857 _ => false,
858 }
859}
860
861pub struct SortEvaluation {
862 sorted: Value,
863 indices: Tensor,
864}
865
866impl SortEvaluation {
867 pub fn into_sorted_value(self) -> Value {
868 self.sorted
869 }
870
871 pub fn into_values(self) -> (Value, Value) {
872 let indices = tensor::tensor_into_value(self.indices);
873 (self.sorted, indices)
874 }
875
876 pub fn indices_value(&self) -> Value {
877 tensor::tensor_into_value(self.indices.clone())
878 }
879}
880
881#[cfg(test)]
882pub(crate) mod tests {
883 use super::*;
884 use crate::builtins::common::test_support;
885 use futures::executor::block_on;
886 use runmat_builtins::{ComplexTensor, IntValue, ResolveContext, Tensor, Type, Value};
887
888 fn sort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
889 block_on(super::sort_builtin(value, rest))
890 }
891
892 fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortEvaluation> {
893 block_on(super::evaluate(value, rest))
894 }
895
896 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
897 #[test]
898 fn sort_vector_default() {
899 let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
900 let result = sort_builtin(Value::Tensor(tensor), Vec::new()).expect("sort");
901 match result {
902 Value::Tensor(t) => {
903 assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
904 assert_eq!(t.shape, vec![3, 1]);
905 }
906 other => panic!("expected tensor result, got {other:?}"),
907 }
908 }
909
910 #[test]
911 fn sort_type_resolver_tensor() {
912 assert_eq!(
913 tensor_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
914 Type::tensor()
915 );
916 }
917
918 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
919 #[test]
920 fn sort_descend_direction() {
921 let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
922 let result =
923 sort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("sort");
924 match result {
925 Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 2.0, 1.0]),
926 other => panic!("expected tensor, got {other:?}"),
927 }
928 }
929
930 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
931 #[test]
932 fn sort_matrix_default_dim1() {
933 let tensor = Tensor::new(vec![4.0, 2.0, 1.0, 5.0, 6.0, 3.0], vec![2, 3]).unwrap();
934 let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
935 let (sorted, indices) = eval.into_values();
936 match sorted {
937 Value::Tensor(t) => {
938 assert_eq!(t.data, vec![2.0, 4.0, 1.0, 5.0, 3.0, 6.0]);
939 assert_eq!(t.shape, vec![2, 3]);
940 }
941 other => panic!("expected tensor result, got {other:?}"),
942 }
943 match indices {
944 Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 1.0, 2.0, 2.0, 1.0]),
945 other => panic!("expected tensor indices, got {other:?}"),
946 }
947 }
948
949 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
950 #[test]
951 fn sort_matrix_along_dimension_two() {
952 let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
953 let eval =
954 evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(2))]).expect("evaluate");
955 let (sorted, indices) = eval.into_values();
956 match sorted {
957 Value::Tensor(t) => {
958 assert_eq!(t.data, vec![1.0, 2.0, 2.0, 3.0, 4.0, 5.0]);
959 assert_eq!(t.shape, vec![2, 3]);
960 }
961 other => panic!("expected tensor result, got {other:?}"),
962 }
963 match indices {
964 Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]),
965 other => panic!("expected tensor indices, got {other:?}"),
966 }
967 }
968
969 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
970 #[test]
971 fn sort_dimension_placeholder_then_dim() {
972 let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
973 let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
974 let eval = evaluate(
975 Value::Tensor(tensor),
976 &[
977 Value::Tensor(placeholder),
978 Value::Int(IntValue::I32(2)),
979 Value::from("descend"),
980 ],
981 )
982 .expect("evaluate");
983 let (sorted, _) = eval.into_values();
984 match sorted {
985 Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 1.0, 2.0]),
986 other => panic!("expected tensor result, got {other:?}"),
987 }
988 }
989
990 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
991 #[test]
992 fn sort_descend_then_dimension() {
993 let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
994 let eval = evaluate(
995 Value::Tensor(tensor),
996 &[Value::from("descend"), Value::Int(IntValue::I32(1))],
997 )
998 .expect("evaluate");
999 let (sorted, _) = eval.into_values();
1000 match sorted {
1001 Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 4.0, 2.0, 5.0, 2.0]),
1002 other => panic!("expected tensor result, got {other:?}"),
1003 }
1004 }
1005
1006 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1007 #[test]
1008 fn sort_returns_indices() {
1009 let tensor = Tensor::new(vec![4.0, 1.0, 9.0, 2.0], vec![4, 1]).unwrap();
1010 let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
1011 let (sorted, indices) = eval.into_values();
1012 match sorted {
1013 Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 4.0, 9.0]),
1014 other => panic!("expected tensor, got {other:?}"),
1015 }
1016 match indices {
1017 Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 4.0, 1.0, 3.0]),
1018 other => panic!("expected tensor, got {other:?}"),
1019 }
1020 }
1021
1022 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1023 #[test]
1024 fn sort_with_nan_handling() {
1025 let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
1026 let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
1027 let (sorted, _) = eval.into_values();
1028 match sorted {
1029 Value::Tensor(t) => {
1030 assert!(t.data[3].is_nan());
1031 assert_eq!(&t.data[0..3], &[1.0, 2.0, 4.0]);
1032 }
1033 other => panic!("expected tensor, got {other:?}"),
1034 }
1035
1036 let eval_desc =
1037 evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
1038 let (sorted_desc, _) = eval_desc.into_values();
1039 match sorted_desc {
1040 Value::Tensor(t) => {
1041 assert!(t.data[0].is_nan());
1042 assert_eq!(&t.data[1..], &[4.0, 2.0, 1.0]);
1043 }
1044 other => panic!("expected tensor, got {other:?}"),
1045 }
1046 }
1047
1048 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1049 #[test]
1050 fn sort_by_absolute_value() {
1051 let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
1052 let eval = evaluate(
1053 Value::Tensor(tensor),
1054 &[Value::from("ComparisonMethod"), Value::from("abs")],
1055 )
1056 .expect("evaluate");
1057 let (sorted, _) = eval.into_values();
1058 match sorted {
1059 Value::Tensor(t) => assert_eq!(t.data, vec![-1.0, -2.0, 3.0, -8.0]),
1060 other => panic!("expected tensor, got {other:?}"),
1061 }
1062 }
1063
1064 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1065 #[test]
1066 fn sort_by_absolute_value_descend() {
1067 let tensor = Tensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![4, 1]).unwrap();
1068 let eval = evaluate(
1069 Value::Tensor(tensor),
1070 &[
1071 Value::from("descend"),
1072 Value::from("ComparisonMethod"),
1073 Value::from("abs"),
1074 ],
1075 )
1076 .expect("evaluate");
1077 let (sorted, _) = eval.into_values();
1078 match sorted {
1079 Value::Tensor(t) => assert_eq!(t.data, vec![4.0, -3.0, 2.0, -1.0]),
1080 other => panic!("expected tensor, got {other:?}"),
1081 }
1082 }
1083
1084 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1085 #[test]
1086 fn sort_complex_auto_abs() {
1087 let tensor =
1088 ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (0.0, -1.0)], vec![3, 1]).unwrap();
1089 let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("evaluate");
1090 let (sorted, indices) = eval.into_values();
1091 match sorted {
1092 Value::ComplexTensor(t) => {
1093 assert_eq!(t.data, vec![(0.0, -1.0), (1.0, 2.0), (-3.0, 0.5)])
1094 }
1095 other => panic!("expected complex tensor, got {other:?}"),
1096 }
1097 match indices {
1098 Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0]),
1099 other => panic!("expected tensor indices, got {other:?}"),
1100 }
1101 }
1102
1103 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1104 #[test]
1105 fn sort_complex_real_descend() {
1106 let tensor =
1107 ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (1.0, -1.0)], vec![3, 1]).unwrap();
1108 let eval = evaluate(
1109 Value::ComplexTensor(tensor),
1110 &[
1111 Value::from("descend"),
1112 Value::from("ComparisonMethod"),
1113 Value::from("real"),
1114 ],
1115 )
1116 .expect("evaluate");
1117 let (sorted, _) = eval.into_values();
1118 match sorted {
1119 Value::ComplexTensor(t) => {
1120 assert_eq!(t.data, vec![(1.0, 2.0), (1.0, -1.0), (-3.0, 0.0)]);
1121 }
1122 other => panic!("expected complex tensor, got {other:?}"),
1123 }
1124 }
1125
1126 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1127 #[test]
1128 fn sort_stable_with_duplicates() {
1129 let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
1130 let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
1131 let (sorted, indices) = eval.into_values();
1132 match sorted {
1133 Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 2.0, 2.0]),
1134 other => panic!("expected tensor, got {other:?}"),
1135 }
1136 match indices {
1137 Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0, 4.0]),
1138 other => panic!("expected tensor indices, got {other:?}"),
1139 }
1140 }
1141
1142 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1143 #[test]
1144 fn sort_empty_tensor() {
1145 let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
1146 let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
1147 let (sorted, indices) = eval.into_values();
1148 match sorted {
1149 Value::Tensor(t) => {
1150 assert!(t.data.is_empty());
1151 assert_eq!(t.shape, tensor.shape);
1152 }
1153 other => panic!("expected tensor, got {other:?}"),
1154 }
1155 match indices {
1156 Value::Tensor(t) => assert!(t.data.is_empty()),
1157 other => panic!("expected tensor, got {other:?}"),
1158 }
1159 }
1160
1161 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1162 #[test]
1163 fn sort_dim_greater_than_ndims() {
1164 let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0], vec![2, 2]).unwrap();
1165 let eval = evaluate(
1166 Value::Tensor(tensor.clone()),
1167 &[Value::Int(IntValue::I32(3))],
1168 )
1169 .expect("evaluate");
1170 let (sorted, indices) = eval.into_values();
1171 match sorted {
1172 Value::Tensor(t) => assert_eq!(t.data, tensor.data),
1173 other => panic!("expected tensor, got {other:?}"),
1174 }
1175 match indices {
1176 Value::Tensor(t) => assert!(t.data.iter().all(|v| (*v - 1.0).abs() < f64::EPSILON)),
1177 other => panic!("expected tensor, got {other:?}"),
1178 }
1179 }
1180
1181 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1182 #[test]
1183 fn sort_invalid_argument_errors() {
1184 let err = sort_builtin(
1185 Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
1186 vec![Value::from("missingplacement"), Value::from("first")],
1187 )
1188 .unwrap_err();
1189 assert_eq!(
1190 err.identifier(),
1191 SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED.identifier
1192 );
1193 }
1194
1195 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1196 #[test]
1197 fn sort_invalid_comparison_method_errors() {
1198 let err = sort_builtin(
1199 Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
1200 vec![Value::from("ComparisonMethod"), Value::from("unknown")],
1201 )
1202 .unwrap_err();
1203 assert_eq!(
1204 err.identifier(),
1205 SORT_ERROR_COMPARISON_METHOD_UNKNOWN.identifier
1206 );
1207 }
1208
1209 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1210 #[test]
1211 fn sort_invalid_comparison_method_value_errors() {
1212 let err = sort_builtin(
1213 Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
1214 vec![
1215 Value::from("ComparisonMethod"),
1216 Value::Int(IntValue::I32(1)),
1217 ],
1218 )
1219 .unwrap_err();
1220 assert_eq!(
1221 err.identifier(),
1222 SORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING.identifier
1223 );
1224 }
1225
1226 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1227 #[test]
1228 fn sort_dimension_zero_errors() {
1229 let err = sort_builtin(
1230 Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
1231 vec![Value::Num(0.0)],
1232 )
1233 .unwrap_err();
1234 assert_eq!(err.identifier(), SORT_ERROR_INVALID_DIMENSION.identifier);
1235 }
1236
1237 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1238 #[test]
1239 fn sort_gpu_round_trip() {
1240 test_support::with_test_provider(|provider| {
1241 let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
1242 let view = runmat_accelerate_api::HostTensorView {
1243 data: &tensor.data,
1244 shape: &tensor.shape,
1245 };
1246 let handle = provider.upload(&view).expect("upload");
1247 let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
1248 let (sorted, indices) = eval.into_values();
1249 match sorted {
1250 Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0]),
1251 other => panic!("expected tensor, got {other:?}"),
1252 }
1253 match indices {
1254 Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
1255 other => panic!("expected tensor, got {other:?}"),
1256 }
1257 });
1258 }
1259
1260 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1261 #[test]
1262 #[cfg(feature = "wgpu")]
1263 fn sort_wgpu_matches_cpu() {
1264 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1265 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1266 );
1267 let tensor = Tensor::new(vec![4.0, 1.0, 3.0, 2.0], vec![4, 1]).unwrap();
1268 let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu sort");
1269 let (cpu_sorted, cpu_indices) = cpu_eval.into_values();
1270
1271 let gpu_view = runmat_accelerate_api::HostTensorView {
1272 data: &tensor.data,
1273 shape: &tensor.shape,
1274 };
1275 let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1276 let handle = provider.upload(&gpu_view).expect("upload");
1277 let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu sort");
1278 let (gpu_sorted, gpu_indices) = gpu_eval.into_values();
1279
1280 let cpu_sorted_tensor = match cpu_sorted {
1281 Value::Tensor(t) => t,
1282 Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1283 other => panic!("unexpected CPU sorted value {other:?}"),
1284 };
1285 let cpu_indices_tensor = match cpu_indices {
1286 Value::Tensor(t) => t,
1287 Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1288 other => panic!("unexpected CPU indices value {other:?}"),
1289 };
1290 let gpu_sorted_tensor = match gpu_sorted {
1291 Value::Tensor(t) => t,
1292 Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1293 other => panic!("unexpected GPU sorted value {other:?}"),
1294 };
1295 let gpu_indices_tensor = match gpu_indices {
1296 Value::Tensor(t) => t,
1297 Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1298 other => panic!("unexpected GPU indices value {other:?}"),
1299 };
1300
1301 assert_eq!(gpu_sorted_tensor.data, cpu_sorted_tensor.data);
1302 assert_eq!(gpu_indices_tensor.data, cpu_indices_tensor.data);
1303 }
1304}