use std::cmp::Ordering;

use runmat_accelerate_api::{
    GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
};
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;

use crate::builtins::common::gpu_helpers;
use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::tensor;
#[cfg(feature = "doc_export")]
use crate::register_builtin_doc_text;
use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};

#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "sort"
category: "array/sorting_sets"
keywords: ["sort", "ascending", "descending", "indices", "comparisonmethod", "gpu"]
summary: "Sort scalars, vectors, matrices, or N-D tensors along a dimension, with optional index outputs."
references:
  - https://www.mathworks.com/help/matlab/ref/sort.html
gpu_support:
  elementwise: false
  reduction: false
  precisions: ["f32", "f64"]
  broadcasting: "none"
  notes: "Uses the provider `sort_dim` hook when available; otherwise tensors are gathered and sorted on the host."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 1
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::array::sorting_sets::sort::tests"
  integration: "builtins::array::sorting_sets::sort::tests::sort_gpu_round_trip"
---

# What does the `sort` function do in MATLAB / RunMat?
`sort` orders the elements of its input along a chosen dimension. By default, it sorts ascending along the first non-singleton dimension and can also return the permutation indices that achieve the ordering.

## How does the `sort` function behave in MATLAB / RunMat?
- Scalars remain unchanged; vectors are reordered into ascending or descending order.
- For matrices and higher-dimensional tensors, the sort operates along a specified dimension (default: the first non-singleton dimension); all other dimensions remain untouched (see the example just after this list).
- `[B, I] = sort(A, ...)` returns both the sorted values `B` and the permutation indices `I`, using MATLAB's one-based indexing.
- Direction arguments accept `'ascend'` (default) or `'descend'`.
- Name-value pairs support `'ComparisonMethod'` with values `'auto'`, `'real'`, or `'abs'`. The `'abs'` option sorts by absolute value while breaking ties using the signed value.
- For complex inputs, the default `'ComparisonMethod'='auto'` behaves like `'abs'` (magnitude ordering), while `'real'` compares the real component first and uses the imaginary component for tie-breaking.
- NaN values are treated as missing: they appear at the end for ascending sorts and at the beginning for descending sorts (matching MATLAB's `'MissingPlacement','auto'` behaviour for doubles).
- Dimensions greater than `ndims(A)` are treated as singleton dimensions (size 1) and therefore leave `A` unchanged while returning index values of `1`.
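
For example, sorting a matrix without a dimension argument sorts each column independently and reports per-column indices:
```matlab
A = [4 1 6; 2 5 3];
[B, I] = sort(A);
```
Expected output:
```matlab
B =
     2     1     3
     4     5     6
I =
     2     1     2
     1     2     1
```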

## GPU execution in RunMat
- `sort` is registered as a sink builtin. When tensors reside on a GPU without a specialised sort kernel, RunMat gathers them to host memory, performs the sort, and returns host-resident outputs.
- Providers can implement the `sort_dim` hook to keep data on the device; when the hook is unavailable or reports an error, RunMat falls back to the host path automatically.
- The returned index tensor is always host-resident double precision.

## Examples of using `sort` in MATLAB / RunMat

### Sorting a column vector in ascending order
```matlab
A = [3; 1; 2];
B = sort(A);
```
Expected output:
```matlab
B =
     1
     2
     3
```

### Sorting rows by specifying the dimension
```matlab
A = [1 4 2; 3 2 5];
B = sort(A, 2);
```
Expected output:
```matlab
B =
     1     2     4
     2     3     5
```

### Sorting values in descending order
```matlab
A = [10 4 7 9];
B = sort(A, 'descend');
```
Expected output:
```matlab
B =
    10     9     7     4
```

### Retrieving permutation indices alongside the sorted values
```matlab
A = [4 1 9 2];
[B, I] = sort(A);
```
Expected output:
```matlab
B =
     1     2     4     9
I =
     2     4     1     3
```

### Sorting by absolute value using `ComparisonMethod`
```matlab
A = [-8 -1 3 -2];
B = sort(A, 'ComparisonMethod', 'abs');
```
Expected output:
```matlab
B =
    -1    -2     3    -8
```

### Sorting tensors containing NaN values
```matlab
A = [NaN 4 1 2];
[B, I] = sort(A);
```
Expected output:
```matlab
B =
     1     2     4   NaN
I =
     3     4     2     1
```

### Sorting GPU tensors with automatic host fallback
```matlab
G = gpuArray(randn(5, 1));
[B, I] = sort(G, 'descend');
```
The runtime gathers `G`, performs the sort on the host, and returns host-resident results. The ordering matches MATLAB's semantics exactly.

## FAQ

### Can `sort` return both values and indices?
Yes. Use `[B, I] = sort(A, ...)` to receive the permutation indices alongside the sorted values.

### How are NaN values handled?
NaN values are considered missing. They appear last for ascending sorts and first for descending sorts, matching MATLAB's `'MissingPlacement','auto'` default for doubles.

### What happens when I sort along a dimension that does not exist?
`sort(A, dim)` treats dimensions beyond `ndims(A)` as singleton dimensions (size 1). The data remains unchanged and the index output is filled with ones.
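
For example, sorting a 2-by-2 matrix along dimension 3 returns the data unchanged together with an all-ones index output:
```matlab
A = [4 3; 2 1];
[B, I] = sort(A, 3);
```
Expected output:
```matlab
B =
     4     3
     2     1
I =
     1     1
     1     1
```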

### Does `sort` support name-value arguments?
Yes. `'ComparisonMethod'` accepts `'auto'`, `'real'`, or `'abs'`. Other name-value pairs such as `'MissingPlacement'` are currently not supported and raise an error.

### Are GPU tensors sorted in-place?
Not yet. `sort` is a sink builtin that gathers GPU tensors to host memory when no GPU sort kernel is available. Providers can implement specialised hooks in the future.

### Does `sort` preserve the shape of the input?
Yes. The output is the same size as the input tensor. Only values along the selected dimension are reordered.

### What numeric type do the index outputs use?
Permutation indices are returned as double-precision tensors (or scalars) using MATLAB's one-based indexing.

### Is the sorting stable?
Yes. Equal elements (including ties when sorting by absolute value) preserve their original order.
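
For example, duplicate values keep their original relative order, which is visible in the index output:
```matlab
A = [2; 2; 1; 2];
[B, I] = sort(A);
```
Expected output:
```matlab
B =
     1
     2
     2
     2
I =
     3
     1
     2
     4
```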

### How does `ComparisonMethod` behave with real inputs?
`'auto'` and `'real'` behave identically. `'abs'` sorts by absolute value and uses the signed value to break ties so that results match MATLAB.
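
For example, combining a direction with `'ComparisonMethod','abs'` orders the values by magnitude while keeping their signs:
```matlab
A = [-1 2 -3 4];
B = sort(A, 'descend', 'ComparisonMethod', 'abs');
```
Expected output:
```matlab
B =
     4    -3     2    -1
```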

### How does `ComparisonMethod` behave with complex inputs?
`'auto'` (the default) and `'abs'` order values by magnitude. `'real'` compares the real component first and falls back to the imaginary component to break ties while preserving stability.
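
For example, the default magnitude ordering places the smallest-modulus element first:
```matlab
A = [1+2i; -3+0.5i; 0-1i];
[B, I] = sort(A);
% B is [0-1i; 1+2i; -3+0.5i] and I is [3; 1; 2]
```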

### Do logical or scalar inputs work?
Yes. Logical inputs are promoted to double precision automatically. Scalars are returned unchanged with an index output of `1` when requested.
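
A minimal illustration (the logical vector is promoted to double before sorting):
```matlab
B = sort([true false true]);   % B is [0 1 1], returned as doubles
[b, i] = sort(5);              % b is 5 and i is 1
```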

## See also
[sortrows](./sortrows), [unique](./unique), [max](../../math/reduction/max), [min](../../math/reduction/min), [permute](../../array/shape/permute)

## Source & Feedback
- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/sort.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/sort.rs)
- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with a minimal reproduction.
"#;

pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "sort",
    op_kind: GpuOpKind::Custom("sort"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("sort_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may add a dedicated sort kernel in the future; today tensors are gathered to host memory before sorting.",
};

register_builtin_gpu_spec!(GPU_SPEC);

pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "sort",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "Sorting breaks fusion chains and acts as a residency sink; upstream tensors are gathered to host memory.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

#[cfg(feature = "doc_export")]
register_builtin_doc_text!("sort", DOC_MD);

#[runtime_builtin(
    name = "sort",
    category = "array/sorting_sets",
    summary = "Sort scalars, vectors, matrices, or N-D tensors along a dimension, with optional index outputs.",
    keywords = "sort,ascending,descending,indices,comparisonmethod,gpu",
    accel = "sink",
    sink = true
)]
fn sort_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
    evaluate(value, &rest).map(|eval| eval.into_sorted_value())
}

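/// Shared evaluation path for `sort`: parses the trailing arguments in `rest`, dispatches to the
/// GPU or host implementation, and yields both the sorted values and the one-based permutation
/// indices (callers pick what they need via `SortEvaluation::into_sorted_value` or
/// `SortEvaluation::into_values`).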
pub fn evaluate(value: Value, rest: &[Value]) -> Result<SortEvaluation, String> {
    let args = SortArgs::parse(rest)?;
    match value {
        Value::GpuTensor(handle) => sort_gpu(handle, &args),
        other => sort_host(other, &args),
    }
}

241
242fn sort_gpu(handle: GpuTensorHandle, args: &SortArgs) -> Result<SortEvaluation, String> {
243 let shape = handle.shape.clone();
244 let dim = args.dimension.unwrap_or_else(|| default_dimension(&shape));
245 if dim == 0 {
246 return Err("sort: dimension must be >= 1".to_string());
247 }
248 let dim_len = dimension_length(&shape, dim);
249 if dim_len > 1 {
250 if let Some(provider) = runmat_accelerate_api::provider() {
251 let order = args.direction.to_provider();
252 let comparison = args.comparison.to_provider();
253 let zero_based = dim - 1;
254 if let Ok(result) = provider.sort_dim(&handle, zero_based, order, comparison) {
255 let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
256 .map_err(|e| format!("sort: {e}"))?;
257 let sorted_value = tensor::tensor_into_value(sorted_tensor);
258 let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
259 .map_err(|e| format!("sort: {e}"))?;
260 return Ok(SortEvaluation {
261 sorted: sorted_value,
262 indices: indices_tensor,
263 });
264 }
265 }
266 }
267 let tensor = gpu_helpers::gather_tensor(&handle)?;
268 sort_real_tensor(tensor, args)
269}
270
fn sort_host(value: Value, args: &SortArgs) -> Result<SortEvaluation, String> {
    match value {
        Value::ComplexTensor(ct) => sort_complex_tensor(ct, args),
        Value::Complex(re, im) => {
            let tensor =
                ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(|e| format!("sort: {e}"))?;
            sort_complex_tensor(tensor, args)
        }
        other => {
            let tensor = tensor::value_into_tensor_for("sort", other)?;
            sort_real_tensor(tensor, args)
        }
    }
}

fn sort_real_tensor(tensor: Tensor, args: &SortArgs) -> Result<SortEvaluation, String> {
    let dim = args
        .dimension
        .unwrap_or_else(|| default_dimension(&tensor.shape));
    if dim == 0 {
        return Err("sort: dimension must be >= 1".to_string());
    }

    let dim_len = dimension_length(&tensor.shape, dim);
    if tensor.data.is_empty() || dim_len <= 1 {
        // Nothing to reorder: return the input as-is with an all-ones index tensor.
        let indices = vec![1.0; tensor.data.len()];
        let index_tensor =
            Tensor::new(indices, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;
        let sorted_value = tensor::tensor_into_value(tensor);
        return Ok(SortEvaluation {
            sorted: sorted_value,
            indices: index_tensor,
        });
    }

    let stride_before = stride_before(&tensor.shape, dim);
    let stride_after = stride_after(&tensor.shape, dim);
    let mut sorted = tensor.data.clone();
    let mut indices = vec![0.0f64; tensor.data.len()];
    let mut buffer: Vec<(usize, f64)> = Vec::with_capacity(dim_len);

    // Sort each column-major slice along `dim` independently, recording one-based indices.
    for after in 0..stride_after {
        for before in 0..stride_before {
            buffer.clear();
            for k in 0..dim_len {
                let idx = before + k * stride_before + after * stride_before * dim_len;
                let value = tensor.data[idx];
                buffer.push((k, value));
            }
            buffer.sort_by(|a, b| compare_real_values(a.1, b.1, args));
            for (pos, (original_index, value)) in buffer.iter().enumerate() {
                let target = before + pos * stride_before + after * stride_before * dim_len;
                sorted[target] = *value;
                indices[target] = (*original_index + 1) as f64;
            }
        }
    }

    let sorted_tensor =
        Tensor::new(sorted, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;
    let index_tensor =
        Tensor::new(indices, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;

    Ok(SortEvaluation {
        sorted: tensor::tensor_into_value(sorted_tensor),
        indices: index_tensor,
    })
}

fn sort_complex_tensor(tensor: ComplexTensor, args: &SortArgs) -> Result<SortEvaluation, String> {
    let dim = args
        .dimension
        .unwrap_or_else(|| default_dimension(&tensor.shape));
    if dim == 0 {
        return Err("sort: dimension must be >= 1".to_string());
    }

    let dim_len = dimension_length(&tensor.shape, dim);
    if tensor.data.is_empty() || dim_len <= 1 {
        let indices = vec![1.0; tensor.data.len()];
        let index_tensor =
            Tensor::new(indices, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;
        return Ok(SortEvaluation {
            sorted: complex_tensor_into_value(tensor),
            indices: index_tensor,
        });
    }

    let stride_before = stride_before(&tensor.shape, dim);
    let stride_after = stride_after(&tensor.shape, dim);
    let mut sorted = tensor.data.clone();
    let mut indices = vec![0.0f64; tensor.data.len()];
    let mut buffer: Vec<(usize, (f64, f64))> = Vec::with_capacity(dim_len);

    for after in 0..stride_after {
        for before in 0..stride_before {
            buffer.clear();
            for k in 0..dim_len {
                let idx = before + k * stride_before + after * stride_before * dim_len;
                let value = tensor.data[idx];
                buffer.push((k, value));
            }
            buffer.sort_by(|a, b| compare_complex_values(a.1, b.1, args));
            for (pos, (original_index, value)) in buffer.iter().enumerate() {
                let target = before + pos * stride_before + after * stride_before * dim_len;
                sorted[target] = *value;
                indices[target] = (*original_index + 1) as f64;
            }
        }
    }

    let sorted_tensor =
        ComplexTensor::new(sorted, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;
    let index_tensor =
        Tensor::new(indices, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;

    Ok(SortEvaluation {
        sorted: complex_tensor_into_value(sorted_tensor),
        indices: index_tensor,
    })
}

fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
    if tensor.data.len() == 1 {
        let (re, im) = tensor.data[0];
        Value::Complex(re, im)
    } else {
        Value::ComplexTensor(tensor)
    }
}

// NaN values are treated as missing: pushed to the end for ascending sorts and to the front
// for descending sorts, matching the documented MATLAB behaviour.
fn compare_real_values(a: f64, b: f64, args: &SortArgs) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => match args.direction {
            SortDirection::Ascend => Ordering::Greater,
            SortDirection::Descend => Ordering::Less,
        },
        (false, true) => match args.direction {
            SortDirection::Ascend => Ordering::Less,
            SortDirection::Descend => Ordering::Greater,
        },
        (false, false) => compare_real_finite(a, b, args),
    }
}

fn compare_real_finite(a: f64, b: f64, args: &SortArgs) -> Ordering {
    let primary = match args.comparison {
        ComparisonMethod::Abs => {
            let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
            if abs_cmp != Ordering::Equal {
                return match args.direction {
                    SortDirection::Ascend => abs_cmp,
                    SortDirection::Descend => abs_cmp.reverse(),
                };
            }
            Ordering::Equal
        }
        ComparisonMethod::Auto | ComparisonMethod::Real => Ordering::Equal,
    };
    if primary != Ordering::Equal {
        return primary;
    }
    match args.direction {
        SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
        SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
    }
}

fn compare_complex_values(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
    match (complex_is_nan(a), complex_is_nan(b)) {
        (true, true) => Ordering::Equal,
        (true, false) => match args.direction {
            SortDirection::Ascend => Ordering::Greater,
            SortDirection::Descend => Ordering::Less,
        },
        (false, true) => match args.direction {
            SortDirection::Ascend => Ordering::Less,
            SortDirection::Descend => Ordering::Greater,
        },
        (false, false) => compare_complex_finite(a, b, args),
    }
}

fn compare_complex_finite(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
    match args.comparison {
        ComparisonMethod::Real => compare_complex_real_imag(a, b, args.direction),
        ComparisonMethod::Abs | ComparisonMethod::Auto => {
            let abs_cmp = complex_abs(a)
                .partial_cmp(&complex_abs(b))
                .unwrap_or(Ordering::Equal);
            if abs_cmp != Ordering::Equal {
                return match args.direction {
                    SortDirection::Ascend => abs_cmp,
                    SortDirection::Descend => abs_cmp.reverse(),
                };
            }
            compare_complex_real_imag(a, b, args.direction)
        }
    }
}

fn compare_complex_real_imag(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
    let real_cmp = match direction {
        SortDirection::Ascend => a.0.partial_cmp(&b.0),
        SortDirection::Descend => b.0.partial_cmp(&a.0),
    }
    .unwrap_or(Ordering::Equal);
    if real_cmp != Ordering::Equal {
        return real_cmp;
    }
    match direction {
        SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
        SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
    }
}

fn complex_is_nan(value: (f64, f64)) -> bool {
    value.0.is_nan() || value.1.is_nan()
}

fn complex_abs(value: (f64, f64)) -> f64 {
    value.0.hypot(value.1)
}

/// Product of the extents before `dim` (one-based): the column-major stride between
/// consecutive elements along `dim`.
fn stride_before(shape: &[usize], dim: usize) -> usize {
    if dim <= 1 {
        return 1;
    }
    let mut product = 1usize;
    for i in 0..(dim - 1) {
        product = product.saturating_mul(*shape.get(i).unwrap_or(&1));
    }
    product
}

/// Product of the extents after `dim` (one-based): the number of independent slices that
/// follow the sorted dimension in column-major order.
fn stride_after(shape: &[usize], dim: usize) -> usize {
    if dim >= shape.len() {
        return 1;
    }
    let mut product = 1usize;
    for extent in shape.iter().skip(dim) {
        product = product.saturating_mul(*extent);
    }
    product
}

fn dimension_length(shape: &[usize], dim: usize) -> usize {
    shape.get(dim - 1).copied().unwrap_or(1)
}

/// MATLAB's default sort dimension: the first non-singleton dimension, or 1 if none exists.
fn default_dimension(shape: &[usize]) -> usize {
    shape
        .iter()
        .position(|&extent| extent > 1)
        .map(|idx| idx + 1)
        .unwrap_or(1)
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum SortDirection {
    #[default]
    Ascend,
    Descend,
}

impl SortDirection {
    fn to_provider(self) -> ProviderSortOrder {
        match self {
            SortDirection::Ascend => ProviderSortOrder::Ascend,
            SortDirection::Descend => ProviderSortOrder::Descend,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum ComparisonMethod {
    #[default]
    Auto,
    Real,
    Abs,
}

impl ComparisonMethod {
    fn to_provider(self) -> ProviderSortComparison {
        match self {
            ComparisonMethod::Auto => ProviderSortComparison::Auto,
            ComparisonMethod::Real => ProviderSortComparison::Real,
            ComparisonMethod::Abs => ProviderSortComparison::Abs,
        }
    }
}

#[derive(Debug, Clone, Default)]
struct SortArgs {
    dimension: Option<usize>,
    direction: SortDirection,
    comparison: ComparisonMethod,
}

impl SortArgs {
    fn parse(rest: &[Value]) -> Result<Self, String> {
        let mut args = SortArgs::default();
        let mut i = 0usize;
        while i < rest.len() {
            // An empty array in the dimension position is ignored as a placeholder; otherwise a
            // leading numeric argument selects the sort dimension.
            if args.dimension.is_none() {
                if is_dimension_placeholder(&rest[i]) {
                    i += 1;
                    continue;
                }
                match tensor::parse_dimension(&rest[i], "sort") {
                    Ok(dim) => {
                        args.dimension = Some(dim);
                        i += 1;
                        continue;
                    }
                    Err(err) => {
                        if matches!(rest[i], Value::Int(_) | Value::Num(_)) {
                            return Err(err);
                        }
                    }
                }
            }
            if let Some(keyword) = tensor::value_to_string(&rest[i]) {
                let lowered = keyword.trim().to_ascii_lowercase();
                match lowered.as_str() {
                    "ascend" | "ascending" => {
                        args.direction = SortDirection::Ascend;
                        i += 1;
                        continue;
                    }
                    "descend" | "descending" => {
                        args.direction = SortDirection::Descend;
                        i += 1;
                        continue;
                    }
                    "comparisonmethod" => {
                        i += 1;
                        if i >= rest.len() {
                            return Err("sort: expected a value for 'ComparisonMethod'".to_string());
                        }
                        let raw = &rest[i];
                        let value = match raw {
                            Value::String(s) => s.clone(),
                            Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
                            Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect(),
                            _ => {
                                return Err(
                                    "sort: 'ComparisonMethod' requires a string value".to_string()
                                )
                            }
                        };
                        let lowered_value = value.trim().to_ascii_lowercase();
                        args.comparison = match lowered_value.as_str() {
                            "auto" => ComparisonMethod::Auto,
                            "real" => ComparisonMethod::Real,
                            "abs" | "magnitude" => ComparisonMethod::Abs,
                            other => {
                                return Err(format!("sort: unsupported ComparisonMethod '{other}'"))
                            }
                        };
                        i += 1;
                        continue;
                    }
                    "missingplacement" => {
                        return Err(
                            "sort: the 'MissingPlacement' option is not supported yet".to_string()
                        );
                    }
                    _ => {}
                }
            }
            return Err(format!("sort: unrecognised argument {:?}", rest[i]));
        }
        Ok(args)
    }
}

fn is_dimension_placeholder(value: &Value) -> bool {
    match value {
        Value::Tensor(t) => t.data.is_empty(),
        Value::LogicalArray(logical) => logical.data.is_empty(),
        _ => false,
    }
}

pub struct SortEvaluation {
    sorted: Value,
    indices: Tensor,
}

impl SortEvaluation {
    pub fn into_sorted_value(self) -> Value {
        self.sorted
    }

    pub fn into_values(self) -> (Value, Value) {
        let indices = tensor::tensor_into_value(self.indices);
        (self.sorted, indices)
    }

    pub fn indices_value(&self) -> Value {
        tensor::tensor_into_value(self.indices.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{ComplexTensor, IntValue, Tensor, Value};

    #[test]
    fn sort_vector_default() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
        let result = sort_builtin(Value::Tensor(tensor), Vec::new()).expect("sort");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
                assert_eq!(t.shape, vec![3, 1]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn sort_descend_direction() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
        let result =
            sort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("sort");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 2.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn sort_matrix_default_dim1() {
        let tensor = Tensor::new(vec![4.0, 2.0, 1.0, 5.0, 6.0, 3.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 4.0, 1.0, 5.0, 3.0, 6.0]);
                assert_eq!(t.shape, vec![2, 3]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 1.0, 2.0, 2.0, 1.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    #[test]
    fn sort_matrix_along_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
        let eval =
            evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(2))]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 2.0, 2.0, 3.0, 4.0, 5.0]);
                assert_eq!(t.shape, vec![2, 3]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    #[test]
    fn sort_dimension_placeholder_then_dim() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
        let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::Tensor(placeholder),
                Value::Int(IntValue::I32(2)),
                Value::from("descend"),
            ],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 1.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn sort_descend_then_dimension() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("descend"), Value::Int(IntValue::I32(1))],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 4.0, 2.0, 5.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn sort_returns_indices() {
        let tensor = Tensor::new(vec![4.0, 1.0, 9.0, 2.0], vec![4, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 4.0, 9.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 4.0, 1.0, 3.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn sort_with_nan_handling() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert!(t.data[3].is_nan());
                assert_eq!(&t.data[0..3], &[1.0, 2.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }

        let eval_desc =
            evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
        let (sorted_desc, _) = eval_desc.into_values();
        match sorted_desc {
            Value::Tensor(t) => {
                assert!(t.data[0].is_nan());
                assert_eq!(&t.data[1..], &[4.0, 2.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn sort_by_absolute_value() {
        let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![-1.0, -2.0, 3.0, -8.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn sort_by_absolute_value_descend() {
        let tensor = Tensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![4, 1]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::from("descend"),
                Value::from("ComparisonMethod"),
                Value::from("abs"),
            ],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, -3.0, 2.0, -1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn sort_complex_auto_abs() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (0.0, -1.0)], vec![3, 1]).unwrap();
        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(0.0, -1.0), (1.0, 2.0), (-3.0, 0.5)])
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    #[test]
    fn sort_complex_real_descend() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (1.0, -1.0)], vec![3, 1]).unwrap();
        let eval = evaluate(
            Value::ComplexTensor(tensor),
            &[
                Value::from("descend"),
                Value::from("ComparisonMethod"),
                Value::from("real"),
            ],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(1.0, 2.0), (1.0, -1.0), (-3.0, 0.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[test]
    fn sort_stable_with_duplicates() {
        let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 2.0, 2.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0, 4.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    #[test]
    fn sort_empty_tensor() {
        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert!(t.data.is_empty());
                assert_eq!(t.shape, tensor.shape);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert!(t.data.is_empty()),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn sort_dim_greater_than_ndims() {
        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor.clone()),
            &[Value::Int(IntValue::I32(3))],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, tensor.data),
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert!(t.data.iter().all(|v| (*v - 1.0).abs() < f64::EPSILON)),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn sort_invalid_argument_errors() {
        let err = sort_builtin(
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            vec![Value::from("missingplacement"), Value::from("first")],
        )
        .unwrap_err();
        assert!(err.contains("MissingPlacement"));
    }

    #[test]
    fn sort_invalid_comparison_method_errors() {
        let err = sort_builtin(
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
            vec![Value::from("ComparisonMethod"), Value::from("unknown")],
        )
        .unwrap_err();
        assert!(err.contains("ComparisonMethod"), "unexpected error: {err}");
    }

    #[test]
    fn sort_invalid_comparison_method_value_errors() {
        let err = sort_builtin(
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
            vec![
                Value::from("ComparisonMethod"),
                Value::Int(IntValue::I32(1)),
            ],
        )
        .unwrap_err();
        assert!(
            err.contains("requires a string value"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn sort_dimension_zero_errors() {
        let err = sort_builtin(
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            vec![Value::Num(0.0)],
        )
        .unwrap_err();
        assert!(
            err.contains("dimension must be >= 1"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn sort_gpu_round_trip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
            let (sorted, indices) = eval.into_values();
            match sorted {
                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
            match indices {
                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
        });
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn sort_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![4.0, 1.0, 3.0, 2.0], vec![4, 1]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu sort");
        let (cpu_sorted, cpu_indices) = cpu_eval.into_values();

        let gpu_view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&gpu_view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu sort");
        let (gpu_sorted, gpu_indices) = gpu_eval.into_values();

        let cpu_sorted_tensor = match cpu_sorted {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected CPU sorted value {other:?}"),
        };
        let cpu_indices_tensor = match cpu_indices {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected CPU indices value {other:?}"),
        };
        let gpu_sorted_tensor = match gpu_sorted {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected GPU sorted value {other:?}"),
        };
        let gpu_indices_tensor = match gpu_indices {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected GPU indices value {other:?}"),
        };

        assert_eq!(gpu_sorted_tensor.data, cpu_sorted_tensor.data);
        assert_eq!(gpu_indices_tensor.data, cpu_indices_tensor.data);
    }

    #[cfg(feature = "doc_export")]
    #[test]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}