runmat_runtime/builtins/array/sorting_sets/
sort.rs

1//! MATLAB-compatible `sort` builtin with multi-output and GPU-aware semantics.
2
3use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6    GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7};
8use runmat_builtins::{ComplexTensor, Tensor, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::gpu_helpers;
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::tensor;
17#[cfg(feature = "doc_export")]
18use crate::register_builtin_doc_text;
19use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
20
21#[cfg(feature = "doc_export")]
22pub const DOC_MD: &str = r#"---
23title: "sort"
24category: "array/sorting_sets"
25keywords: ["sort", "ascending", "descending", "indices", "comparisonmethod", "gpu"]
26summary: "Sort scalars, vectors, matrices, or N-D tensors along a dimension, with optional index outputs."
27references:
28  - https://www.mathworks.com/help/matlab/ref/sort.html
29gpu_support:
30  elementwise: false
31  reduction: false
32  precisions: ["f32", "f64"]
33  broadcasting: "none"
34  notes: "Uses the provider `sort_dim` hook when available; otherwise tensors are gathered and sorted on the host."
35fusion:
36  elementwise: false
37  reduction: false
38  max_inputs: 1
39  constants: "inline"
40requires_feature: null
41tested:
42  unit: "builtins::array::sorting_sets::sort::tests"
43  integration: "builtins::array::sorting_sets::sort::tests::sort_gpu_round_trip"
44---
45
46# What does the `sort` function do in MATLAB / RunMat?
47`sort` orders the elements of its input along a chosen dimension. By default, it sorts ascending along the first non-singleton dimension and can also return the permutation indices that achieve the ordering.
48
49## How does the `sort` function behave in MATLAB / RunMat?
50- Scalars remain unchanged; vectors are reordered into ascending or descending order.
51- For matrices and higher-dimensional tensors, the sort operates along a specified dimension (default: first non-singleton). All other dimensions remain untouched.
52- `[B, I] = sort(A, ...)` returns both the sorted values `B` and the permutation indices `I`, using MATLAB's one-based indexing.
53- Direction arguments accept `'ascend'` (default) or `'descend'`.
54- Name-value pairs support `'ComparisonMethod'` with values `'auto'`, `'real'`, or `'abs'`. The `'abs'` option sorts by absolute value while breaking ties using the signed value.
55- For complex inputs, the default `'ComparisonMethod'='auto'` behaves like `'abs'` (magnitude ordering) and `'real'` compares the real component first with the imaginary component used for tie-breaking.
56- NaN values are treated as missing: they appear at the end for ascending sorts and at the beginning for descending sorts (matching MATLAB's `'MissingPlacement','auto'` behaviour for doubles).
57- Dimensions greater than `ndims(A)` are treated as singleton dimensions (size 1) and therefore leave `A` unchanged while returning index values of `1`.
58
59## GPU execution in RunMat
60- `sort` is registered as a sink builtin. When tensors reside on a GPU without a specialised sort kernel, RunMat gathers them to host memory, performs the sort, and returns host-resident outputs.
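- If the active acceleration provider implements the `sort_dim` hook, RunMat first asks it to sort along the requested dimension; the values and indices it returns are materialised as host-resident tensors. When no provider is registered, the hook is missing or fails, or the sorted dimension has a single element, RunMat falls back to the host path automatically.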
62- The returned index tensor is always host-resident double precision.
63
64## Examples of using `sort` in MATLAB / RunMat
65
66### Sorting a column vector in ascending order
67```matlab
68A = [3; 1; 2];
69B = sort(A);
70```
71Expected output:
72```matlab
73B =
74     1
75     2
76     3
77```
78
79### Sorting rows by specifying the dimension
80```matlab
81A = [1 4 2; 3 2 5];
82B = sort(A, 2);
83```
84Expected output:
85```matlab
86B =
87     1     2     4
88     2     3     5
89```
90
91### Sorting values in descending order
92```matlab
93A = [10 4 7 9];
94B = sort(A, 'descend');
95```
96Expected output:
97```matlab
98B =
99    10     9     7     4
100```
101
102### Retrieving permutation indices alongside the sorted values
103```matlab
104A = [4 1 9 2];
105[B, I] = sort(A);
106```
107Expected output:
108```matlab
109B =
110     1     2     4     9
111I =
112     2     4     1     3
113```
114
115### Sorting by absolute value using `ComparisonMethod`
116```matlab
117A = [-8 -1 3 -2];
118B = sort(A, 'ComparisonMethod', 'abs');
119```
120Expected output:
121```matlab
122B =
123    -1    -2     3    -8
124```
125
126### Sorting tensors containing NaN values
127```matlab
128A = [NaN 4 1 2];
129[B, I] = sort(A);
130```
131Expected output:
132```matlab
133B =
134     1     2     4   NaN
135I =
136     3     4     2     1
137```
138
139### Sorting GPU tensors with automatic host fallback
140```matlab
141G = gpuArray(randn(5, 1));
142[B, I] = sort(G, 'descend');
143```
If the active provider implements the `sort_dim` hook, RunMat delegates the sort to it; otherwise the runtime gathers `G` and performs the sort on the host. In both cases `B` and `I` are returned as host-resident arrays.
145
146## FAQ
147
148### Can `sort` return both values and indices?
149Yes. Use `[B, I] = sort(A, ...)` to receive the permutation indices alongside the sorted values.
150
151### How are NaN values handled?
152NaN values are considered missing. They appear last for ascending sorts and first for descending sorts, matching MATLAB's `'MissingPlacement','auto'` default for doubles.
153
154### What happens when I sort along a dimension that does not exist?
155`sort(A, dim)` treats dimensions beyond `ndims(A)` as singleton dimensions (size 1). The data remains unchanged and the index output is filled with ones.
156
157### Does `sort` support name-value arguments?
158Yes. `'ComparisonMethod'` accepts `'auto'`, `'real'`, or `'abs'`. Other name-value pairs such as `'MissingPlacement'` are currently not supported and raise an error.
159
160### Are GPU tensors sorted in-place?
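No. `sort` is a sink builtin: when the active provider implements the `sort_dim` hook, RunMat delegates the sort to it, and otherwise the GPU tensor is gathered to host memory and sorted there. Either way, the sorted values and indices are returned as new host-resident arrays.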
162
163### Does `sort` preserve the shape of the input?
164Yes. The output is the same size as the input tensor. Only values along the selected dimension are reordered.
165
166### What numeric type do the index outputs use?
167Permutation indices are returned as double-precision tensors (or scalars) using MATLAB's one-based indexing.
168
169### Is the sorting stable?
170Yes. Equal elements (including ties when sorting by absolute value) preserve their original order.
171
172### How does `ComparisonMethod` behave with real inputs?
173`'auto'` and `'real'` behave identically. `'abs'` sorts by absolute value and uses the signed value to break ties so that results match MATLAB.
174
175### How does `ComparisonMethod` behave with complex inputs?
176`'auto'` (the default) and `'abs'` order values by magnitude. `'real'` compares the real component first and falls back to the imaginary component to break ties while preserving stability.
177
178### Do logical or scalar inputs work?
179Yes. Logical inputs are promoted to double precision automatically. Scalars are returned unchanged with an index output of `1` when requested.
180
181## See also
182[sortrows](./sortrows), [unique](./unique), [max](../../math/reduction/max), [min](../../math/reduction/min), [permute](../../array/shape/permute)
183
184## Source & Feedback
185- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/sort.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/sort.rs)
186- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with a minimal reproduction.
187"#;
188
189pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
190    name: "sort",
191    op_kind: GpuOpKind::Custom("sort"),
192    supported_precisions: &[ScalarType::F32, ScalarType::F64],
193    broadcast: BroadcastSemantics::None,
194    provider_hooks: &[ProviderHook::Custom("sort_dim")],
195    constant_strategy: ConstantStrategy::InlineLiteral,
196    residency: ResidencyPolicy::GatherImmediately,
197    nan_mode: ReductionNaN::Include,
198    two_pass_threshold: None,
199    workgroup_size: None,
200    accepts_nan_mode: true,
201    notes: "Providers may add a dedicated sort kernel in the future; today tensors are gathered to host memory before sorting.",
202};
203
204register_builtin_gpu_spec!(GPU_SPEC);
205
206pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
207    name: "sort",
208    shape: ShapeRequirements::Any,
209    constant_strategy: ConstantStrategy::InlineLiteral,
210    elementwise: None,
211    reduction: None,
212    emits_nan: true,
213    notes: "Sorting breaks fusion chains and acts as a residency sink; upstream tensors are gathered to host memory.",
214};
215
216register_builtin_fusion_spec!(FUSION_SPEC);
217
218#[cfg(feature = "doc_export")]
219register_builtin_doc_text!("sort", DOC_MD);
220
221#[runtime_builtin(
222    name = "sort",
223    category = "array/sorting_sets",
224    summary = "Sort scalars, vectors, matrices, or N-D tensors along a dimension, with optional index outputs.",
225    keywords = "sort,ascending,descending,indices,comparisonmethod,gpu",
226    accel = "sink",
227    sink = true
228)]
229fn sort_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
230    evaluate(value, &rest).map(|eval| eval.into_sorted_value())
231}
232
233/// Evaluate the `sort` builtin once and expose both outputs.
234pub fn evaluate(value: Value, rest: &[Value]) -> Result<SortEvaluation, String> {
235    let args = SortArgs::parse(rest)?;
236    match value {
237        Value::GpuTensor(handle) => sort_gpu(handle, &args),
238        other => sort_host(other, &args),
239    }
240}
241
242fn sort_gpu(handle: GpuTensorHandle, args: &SortArgs) -> Result<SortEvaluation, String> {
243    let shape = handle.shape.clone();
244    let dim = args.dimension.unwrap_or_else(|| default_dimension(&shape));
245    if dim == 0 {
246        return Err("sort: dimension must be >= 1".to_string());
247    }
248    let dim_len = dimension_length(&shape, dim);
249    if dim_len > 1 {
250        if let Some(provider) = runmat_accelerate_api::provider() {
251            let order = args.direction.to_provider();
252            let comparison = args.comparison.to_provider();
253            let zero_based = dim - 1;
254            if let Ok(result) = provider.sort_dim(&handle, zero_based, order, comparison) {
255                let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
256                    .map_err(|e| format!("sort: {e}"))?;
257                let sorted_value = tensor::tensor_into_value(sorted_tensor);
258                let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
259                    .map_err(|e| format!("sort: {e}"))?;
260                return Ok(SortEvaluation {
261                    sorted: sorted_value,
262                    indices: indices_tensor,
263                });
264            }
265        }
266    }
267    let tensor = gpu_helpers::gather_tensor(&handle)?;
268    sort_real_tensor(tensor, args)
269}
270
271fn sort_host(value: Value, args: &SortArgs) -> Result<SortEvaluation, String> {
272    match value {
273        Value::ComplexTensor(ct) => sort_complex_tensor(ct, args),
274        Value::Complex(re, im) => {
275            let tensor =
276                ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(|e| format!("sort: {e}"))?;
277            sort_complex_tensor(tensor, args)
278        }
279        other => {
280            let tensor = tensor::value_into_tensor_for("sort", other)?;
281            sort_real_tensor(tensor, args)
282        }
283    }
284}
285
286fn sort_real_tensor(tensor: Tensor, args: &SortArgs) -> Result<SortEvaluation, String> {
287    let dim = args
288        .dimension
289        .unwrap_or_else(|| default_dimension(&tensor.shape));
290    if dim == 0 {
291        return Err("sort: dimension must be >= 1".to_string());
292    }
293
294    let dim_len = dimension_length(&tensor.shape, dim);
295    if tensor.data.is_empty() || dim_len <= 1 {
296        let indices = vec![1.0; tensor.data.len()];
297        let index_tensor =
298            Tensor::new(indices, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;
299        let sorted_value = tensor::tensor_into_value(tensor);
300        return Ok(SortEvaluation {
301            sorted: sorted_value,
302            indices: index_tensor,
303        });
304    }
305
306    let stride_before = stride_before(&tensor.shape, dim);
307    let stride_after = stride_after(&tensor.shape, dim);
308    let mut sorted = tensor.data.clone();
309    let mut indices = vec![0.0f64; tensor.data.len()];
310    let mut buffer: Vec<(usize, f64)> = Vec::with_capacity(dim_len);
311
312    for after in 0..stride_after {
313        for before in 0..stride_before {
314            buffer.clear();
315            for k in 0..dim_len {
316                let idx = before + k * stride_before + after * stride_before * dim_len;
317                let value = tensor.data[idx];
318                buffer.push((k, value));
319            }
320            buffer.sort_by(|a, b| compare_real_values(a.1, b.1, args));
321            for (pos, (original_index, value)) in buffer.iter().enumerate() {
322                let target = before + pos * stride_before + after * stride_before * dim_len;
323                sorted[target] = *value;
324                indices[target] = (*original_index + 1) as f64;
325            }
326        }
327    }
328
329    let sorted_tensor =
330        Tensor::new(sorted, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;
331    let index_tensor =
332        Tensor::new(indices, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;
333
334    Ok(SortEvaluation {
335        sorted: tensor::tensor_into_value(sorted_tensor),
336        indices: index_tensor,
337    })
338}
339
340fn sort_complex_tensor(tensor: ComplexTensor, args: &SortArgs) -> Result<SortEvaluation, String> {
341    let dim = args
342        .dimension
343        .unwrap_or_else(|| default_dimension(&tensor.shape));
344    if dim == 0 {
345        return Err("sort: dimension must be >= 1".to_string());
346    }
347
348    let dim_len = dimension_length(&tensor.shape, dim);
349    if tensor.data.is_empty() || dim_len <= 1 {
350        let indices = vec![1.0; tensor.data.len()];
351        let index_tensor =
352            Tensor::new(indices, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;
353        return Ok(SortEvaluation {
354            sorted: complex_tensor_into_value(tensor),
355            indices: index_tensor,
356        });
357    }
358
359    let stride_before = stride_before(&tensor.shape, dim);
360    let stride_after = stride_after(&tensor.shape, dim);
361    let mut sorted = tensor.data.clone();
362    let mut indices = vec![0.0f64; tensor.data.len()];
363    let mut buffer: Vec<(usize, (f64, f64))> = Vec::with_capacity(dim_len);
364
365    for after in 0..stride_after {
366        for before in 0..stride_before {
367            buffer.clear();
368            for k in 0..dim_len {
369                let idx = before + k * stride_before + after * stride_before * dim_len;
370                let value = tensor.data[idx];
371                buffer.push((k, value));
372            }
373            buffer.sort_by(|a, b| compare_complex_values(a.1, b.1, args));
374            for (pos, (original_index, value)) in buffer.iter().enumerate() {
375                let target = before + pos * stride_before + after * stride_before * dim_len;
376                sorted[target] = *value;
377                indices[target] = (*original_index + 1) as f64;
378            }
379        }
380    }
381
382    let sorted_tensor =
383        ComplexTensor::new(sorted, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;
384    let index_tensor =
385        Tensor::new(indices, tensor.shape.clone()).map_err(|e| format!("sort: {e}"))?;
386
387    Ok(SortEvaluation {
388        sorted: complex_tensor_into_value(sorted_tensor),
389        indices: index_tensor,
390    })
391}
392
393fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
394    if tensor.data.len() == 1 {
395        let (re, im) = tensor.data[0];
396        Value::Complex(re, im)
397    } else {
398        Value::ComplexTensor(tensor)
399    }
400}
401
402fn compare_real_values(a: f64, b: f64, args: &SortArgs) -> Ordering {
403    match (a.is_nan(), b.is_nan()) {
404        (true, true) => Ordering::Equal,
405        (true, false) => match args.direction {
406            SortDirection::Ascend => Ordering::Greater,
407            SortDirection::Descend => Ordering::Less,
408        },
409        (false, true) => match args.direction {
410            SortDirection::Ascend => Ordering::Less,
411            SortDirection::Descend => Ordering::Greater,
412        },
413        (false, false) => compare_real_finite(a, b, args),
414    }
415}
416
417fn compare_real_finite(a: f64, b: f64, args: &SortArgs) -> Ordering {
418    let primary = match args.comparison {
419        ComparisonMethod::Abs => {
420            let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
421            if abs_cmp != Ordering::Equal {
422                return match args.direction {
423                    SortDirection::Ascend => abs_cmp,
424                    SortDirection::Descend => abs_cmp.reverse(),
425                };
426            }
427            Ordering::Equal
428        }
429        ComparisonMethod::Auto | ComparisonMethod::Real => Ordering::Equal,
430    };
431    if primary != Ordering::Equal {
432        return primary;
433    }
434    match args.direction {
435        SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
436        SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
437    }
438}
439
440fn compare_complex_values(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
441    match (complex_is_nan(a), complex_is_nan(b)) {
442        (true, true) => Ordering::Equal,
443        (true, false) => match args.direction {
444            SortDirection::Ascend => Ordering::Greater,
445            SortDirection::Descend => Ordering::Less,
446        },
447        (false, true) => match args.direction {
448            SortDirection::Ascend => Ordering::Less,
449            SortDirection::Descend => Ordering::Greater,
450        },
451        (false, false) => compare_complex_finite(a, b, args),
452    }
453}
454
455fn compare_complex_finite(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
456    match args.comparison {
457        ComparisonMethod::Real => compare_complex_real_imag(a, b, args.direction),
458        ComparisonMethod::Abs | ComparisonMethod::Auto => {
459            let abs_cmp = complex_abs(a)
460                .partial_cmp(&complex_abs(b))
461                .unwrap_or(Ordering::Equal);
462            if abs_cmp != Ordering::Equal {
463                return match args.direction {
464                    SortDirection::Ascend => abs_cmp,
465                    SortDirection::Descend => abs_cmp.reverse(),
466                };
467            }
468            compare_complex_real_imag(a, b, args.direction)
469        }
470    }
471}
472
473fn compare_complex_real_imag(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
474    let real_cmp = match direction {
475        SortDirection::Ascend => a.0.partial_cmp(&b.0),
476        SortDirection::Descend => b.0.partial_cmp(&a.0),
477    }
478    .unwrap_or(Ordering::Equal);
479    if real_cmp != Ordering::Equal {
480        return real_cmp;
481    }
482    match direction {
483        SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
484        SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
485    }
486}
487
/// A complex value counts as NaN when either its real or imaginary part is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    [value.0, value.1].iter().any(|component| component.is_nan())
}
491
/// Magnitude `|re + i·im|`, computed with `hypot` to avoid intermediate
/// overflow/underflow.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    re.hypot(im)
}
495
/// Product of the extents strictly before the 1-based dimension `dim`, which is
/// the column-major step between neighbours along `dim`.
///
/// Missing trailing extents count as 1, and the product saturates rather than
/// overflowing.
fn stride_before(shape: &[usize], dim: usize) -> usize {
    shape
        .iter()
        .take(dim.saturating_sub(1))
        .fold(1usize, |acc, &extent| acc.saturating_mul(extent))
}
506
/// Product of the extents strictly after the 1-based dimension `dim`, which is
/// the number of independent blocks to sort.
///
/// Returns 1 when `dim` is the last (or a nonexistent) dimension, and
/// saturates rather than overflowing.
fn stride_after(shape: &[usize], dim: usize) -> usize {
    shape
        .iter()
        .skip(dim)
        .fold(1usize, |acc, &extent| acc.saturating_mul(extent))
}
517
/// Extent of the 1-based dimension `dim` in `shape`.
///
/// Dimensions beyond `shape.len()` are singleton (length 1), matching MATLAB's
/// trailing-singleton convention. `dim == 0` also reports 1 instead of
/// underflowing `dim - 1`, which previously panicked in debug builds.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    dim.checked_sub(1)
        .and_then(|axis| shape.get(axis))
        .copied()
        .unwrap_or(1)
}
521
/// MATLAB's default sort dimension: the first (1-based) dimension whose extent
/// exceeds 1, or 1 when every dimension is singleton or the shape is empty.
fn default_dimension(shape: &[usize]) -> usize {
    shape
        .iter()
        .position(|&extent| extent > 1)
        .map_or(1, |axis| axis + 1)
}
529
530#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
531enum SortDirection {
532    #[default]
533    Ascend,
534    Descend,
535}
536
537impl SortDirection {
538    fn to_provider(self) -> ProviderSortOrder {
539        match self {
540            SortDirection::Ascend => ProviderSortOrder::Ascend,
541            SortDirection::Descend => ProviderSortOrder::Descend,
542        }
543    }
544}
545
546#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
547enum ComparisonMethod {
548    #[default]
549    Auto,
550    Real,
551    Abs,
552}
553
554impl ComparisonMethod {
555    fn to_provider(self) -> ProviderSortComparison {
556        match self {
557            ComparisonMethod::Auto => ProviderSortComparison::Auto,
558            ComparisonMethod::Real => ProviderSortComparison::Real,
559            ComparisonMethod::Abs => ProviderSortComparison::Abs,
560        }
561    }
562}
563
564#[derive(Debug, Clone, Default)]
565struct SortArgs {
566    dimension: Option<usize>,
567    direction: SortDirection,
568    comparison: ComparisonMethod,
569}
570
571impl SortArgs {
572    fn parse(rest: &[Value]) -> Result<Self, String> {
573        let mut args = SortArgs::default();
574        let mut i = 0usize;
575        while i < rest.len() {
576            if args.dimension.is_none() {
577                if is_dimension_placeholder(&rest[i]) {
578                    i += 1;
579                    continue;
580                }
581                match tensor::parse_dimension(&rest[i], "sort") {
582                    Ok(dim) => {
583                        args.dimension = Some(dim);
584                        i += 1;
585                        continue;
586                    }
587                    Err(err) => {
588                        if matches!(rest[i], Value::Int(_) | Value::Num(_)) {
589                            return Err(err);
590                        }
591                    }
592                }
593            }
594            if let Some(keyword) = tensor::value_to_string(&rest[i]) {
595                let lowered = keyword.trim().to_ascii_lowercase();
596                match lowered.as_str() {
597                    "ascend" | "ascending" => {
598                        args.direction = SortDirection::Ascend;
599                        i += 1;
600                        continue;
601                    }
602                    "descend" | "descending" => {
603                        args.direction = SortDirection::Descend;
604                        i += 1;
605                        continue;
606                    }
607                    "comparisonmethod" => {
608                        i += 1;
609                        if i >= rest.len() {
610                            return Err("sort: expected a value for 'ComparisonMethod'".to_string());
611                        }
612                        let raw = &rest[i];
613                        let value = match raw {
614                            Value::String(s) => s.clone(),
615                            Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
616                            Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect(),
617                            _ => {
618                                return Err(
619                                    "sort: 'ComparisonMethod' requires a string value".to_string()
620                                )
621                            }
622                        };
623                        let lowered_value = value.trim().to_ascii_lowercase();
624                        args.comparison = match lowered_value.as_str() {
625                            "auto" => ComparisonMethod::Auto,
626                            "real" => ComparisonMethod::Real,
627                            "abs" | "magnitude" => ComparisonMethod::Abs,
628                            other => {
629                                return Err(format!("sort: unsupported ComparisonMethod '{other}'"))
630                            }
631                        };
632                        i += 1;
633                        continue;
634                    }
635                    "missingplacement" => {
636                        return Err(
637                            "sort: the 'MissingPlacement' option is not supported yet".to_string()
638                        );
639                    }
640                    _ => {}
641                }
642            }
643            return Err(format!("sort: unrecognised argument {:?}", rest[i]));
644        }
645        Ok(args)
646    }
647}
648
649fn is_dimension_placeholder(value: &Value) -> bool {
650    match value {
651        Value::Tensor(t) => t.data.is_empty(),
652        Value::LogicalArray(logical) => logical.data.is_empty(),
653        _ => false,
654    }
655}
656
657pub struct SortEvaluation {
658    sorted: Value,
659    indices: Tensor,
660}
661
662impl SortEvaluation {
663    pub fn into_sorted_value(self) -> Value {
664        self.sorted
665    }
666
667    pub fn into_values(self) -> (Value, Value) {
668        let indices = tensor::tensor_into_value(self.indices);
669        (self.sorted, indices)
670    }
671
672    pub fn indices_value(&self) -> Value {
673        tensor::tensor_into_value(self.indices.clone())
674    }
675}
676
677#[cfg(test)]
678mod tests {
679    use super::*;
680    use crate::builtins::common::test_support;
681    use runmat_builtins::{ComplexTensor, IntValue, Tensor, Value};
682
683    #[test]
684    fn sort_vector_default() {
685        let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
686        let result = sort_builtin(Value::Tensor(tensor), Vec::new()).expect("sort");
687        match result {
688            Value::Tensor(t) => {
689                assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
690                assert_eq!(t.shape, vec![3, 1]);
691            }
692            other => panic!("expected tensor result, got {other:?}"),
693        }
694    }
695
696    #[test]
697    fn sort_descend_direction() {
698        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
699        let result =
700            sort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("sort");
701        match result {
702            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 2.0, 1.0]),
703            other => panic!("expected tensor, got {other:?}"),
704        }
705    }
706
707    #[test]
708    fn sort_matrix_default_dim1() {
709        let tensor = Tensor::new(vec![4.0, 2.0, 1.0, 5.0, 6.0, 3.0], vec![2, 3]).unwrap();
710        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
711        let (sorted, indices) = eval.into_values();
712        match sorted {
713            Value::Tensor(t) => {
714                assert_eq!(t.data, vec![2.0, 4.0, 1.0, 5.0, 3.0, 6.0]);
715                assert_eq!(t.shape, vec![2, 3]);
716            }
717            other => panic!("expected tensor result, got {other:?}"),
718        }
719        match indices {
720            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 1.0, 2.0, 2.0, 1.0]),
721            other => panic!("expected tensor indices, got {other:?}"),
722        }
723    }
724
725    #[test]
726    fn sort_matrix_along_dimension_two() {
727        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
728        let eval =
729            evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(2))]).expect("evaluate");
730        let (sorted, indices) = eval.into_values();
731        match sorted {
732            Value::Tensor(t) => {
733                assert_eq!(t.data, vec![1.0, 2.0, 2.0, 3.0, 4.0, 5.0]);
734                assert_eq!(t.shape, vec![2, 3]);
735            }
736            other => panic!("expected tensor result, got {other:?}"),
737        }
738        match indices {
739            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]),
740            other => panic!("expected tensor indices, got {other:?}"),
741        }
742    }
743
744    #[test]
745    fn sort_dimension_placeholder_then_dim() {
746        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
747        let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
748        let eval = evaluate(
749            Value::Tensor(tensor),
750            &[
751                Value::Tensor(placeholder),
752                Value::Int(IntValue::I32(2)),
753                Value::from("descend"),
754            ],
755        )
756        .expect("evaluate");
757        let (sorted, _) = eval.into_values();
758        match sorted {
759            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 1.0, 2.0]),
760            other => panic!("expected tensor result, got {other:?}"),
761        }
762    }
763
764    #[test]
765    fn sort_descend_then_dimension() {
766        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
767        let eval = evaluate(
768            Value::Tensor(tensor),
769            &[Value::from("descend"), Value::Int(IntValue::I32(1))],
770        )
771        .expect("evaluate");
772        let (sorted, _) = eval.into_values();
773        match sorted {
774            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 4.0, 2.0, 5.0, 2.0]),
775            other => panic!("expected tensor result, got {other:?}"),
776        }
777    }
778
779    #[test]
780    fn sort_returns_indices() {
781        let tensor = Tensor::new(vec![4.0, 1.0, 9.0, 2.0], vec![4, 1]).unwrap();
782        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
783        let (sorted, indices) = eval.into_values();
784        match sorted {
785            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 4.0, 9.0]),
786            other => panic!("expected tensor, got {other:?}"),
787        }
788        match indices {
789            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 4.0, 1.0, 3.0]),
790            other => panic!("expected tensor, got {other:?}"),
791        }
792    }
793
794    #[test]
795    fn sort_with_nan_handling() {
796        let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
797        let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
798        let (sorted, _) = eval.into_values();
799        match sorted {
800            Value::Tensor(t) => {
801                assert!(t.data[3].is_nan());
802                assert_eq!(&t.data[0..3], &[1.0, 2.0, 4.0]);
803            }
804            other => panic!("expected tensor, got {other:?}"),
805        }
806
807        let eval_desc =
808            evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
809        let (sorted_desc, _) = eval_desc.into_values();
810        match sorted_desc {
811            Value::Tensor(t) => {
812                assert!(t.data[0].is_nan());
813                assert_eq!(&t.data[1..], &[4.0, 2.0, 1.0]);
814            }
815            other => panic!("expected tensor, got {other:?}"),
816        }
817    }
818
819    #[test]
820    fn sort_by_absolute_value() {
821        let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
822        let eval = evaluate(
823            Value::Tensor(tensor),
824            &[Value::from("ComparisonMethod"), Value::from("abs")],
825        )
826        .expect("evaluate");
827        let (sorted, _) = eval.into_values();
828        match sorted {
829            Value::Tensor(t) => assert_eq!(t.data, vec![-1.0, -2.0, 3.0, -8.0]),
830            other => panic!("expected tensor, got {other:?}"),
831        }
832    }
833
834    #[test]
835    fn sort_by_absolute_value_descend() {
836        let tensor = Tensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![4, 1]).unwrap();
837        let eval = evaluate(
838            Value::Tensor(tensor),
839            &[
840                Value::from("descend"),
841                Value::from("ComparisonMethod"),
842                Value::from("abs"),
843            ],
844        )
845        .expect("evaluate");
846        let (sorted, _) = eval.into_values();
847        match sorted {
848            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, -3.0, 2.0, -1.0]),
849            other => panic!("expected tensor, got {other:?}"),
850        }
851    }
852
853    #[test]
854    fn sort_complex_auto_abs() {
855        let tensor =
856            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (0.0, -1.0)], vec![3, 1]).unwrap();
857        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("evaluate");
858        let (sorted, indices) = eval.into_values();
859        match sorted {
860            Value::ComplexTensor(t) => {
861                assert_eq!(t.data, vec![(0.0, -1.0), (1.0, 2.0), (-3.0, 0.5)])
862            }
863            other => panic!("expected complex tensor, got {other:?}"),
864        }
865        match indices {
866            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0]),
867            other => panic!("expected tensor indices, got {other:?}"),
868        }
869    }
870
871    #[test]
872    fn sort_complex_real_descend() {
873        let tensor =
874            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (1.0, -1.0)], vec![3, 1]).unwrap();
875        let eval = evaluate(
876            Value::ComplexTensor(tensor),
877            &[
878                Value::from("descend"),
879                Value::from("ComparisonMethod"),
880                Value::from("real"),
881            ],
882        )
883        .expect("evaluate");
884        let (sorted, _) = eval.into_values();
885        match sorted {
886            Value::ComplexTensor(t) => {
887                assert_eq!(t.data, vec![(1.0, 2.0), (1.0, -1.0), (-3.0, 0.0)]);
888            }
889            other => panic!("expected complex tensor, got {other:?}"),
890        }
891    }
892
893    #[test]
894    fn sort_stable_with_duplicates() {
895        let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
896        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
897        let (sorted, indices) = eval.into_values();
898        match sorted {
899            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 2.0, 2.0]),
900            other => panic!("expected tensor, got {other:?}"),
901        }
902        match indices {
903            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0, 4.0]),
904            other => panic!("expected tensor indices, got {other:?}"),
905        }
906    }
907
908    #[test]
909    fn sort_empty_tensor() {
910        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
911        let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
912        let (sorted, indices) = eval.into_values();
913        match sorted {
914            Value::Tensor(t) => {
915                assert!(t.data.is_empty());
916                assert_eq!(t.shape, tensor.shape);
917            }
918            other => panic!("expected tensor, got {other:?}"),
919        }
920        match indices {
921            Value::Tensor(t) => assert!(t.data.is_empty()),
922            other => panic!("expected tensor, got {other:?}"),
923        }
924    }
925
926    #[test]
927    fn sort_dim_greater_than_ndims() {
928        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0], vec![2, 2]).unwrap();
929        let eval = evaluate(
930            Value::Tensor(tensor.clone()),
931            &[Value::Int(IntValue::I32(3))],
932        )
933        .expect("evaluate");
934        let (sorted, indices) = eval.into_values();
935        match sorted {
936            Value::Tensor(t) => assert_eq!(t.data, tensor.data),
937            other => panic!("expected tensor, got {other:?}"),
938        }
939        match indices {
940            Value::Tensor(t) => assert!(t.data.iter().all(|v| (*v - 1.0).abs() < f64::EPSILON)),
941            other => panic!("expected tensor, got {other:?}"),
942        }
943    }
944
945    #[test]
946    fn sort_invalid_argument_errors() {
947        let err = sort_builtin(
948            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
949            vec![Value::from("missingplacement"), Value::from("first")],
950        )
951        .unwrap_err();
952        assert!(err.contains("MissingPlacement"));
953    }
954
955    #[test]
956    fn sort_invalid_comparison_method_errors() {
957        let err = sort_builtin(
958            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
959            vec![Value::from("ComparisonMethod"), Value::from("unknown")],
960        )
961        .unwrap_err();
962        assert!(err.contains("ComparisonMethod"), "unexpected error: {err}");
963    }
964
965    #[test]
966    fn sort_invalid_comparison_method_value_errors() {
967        let err = sort_builtin(
968            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
969            vec![
970                Value::from("ComparisonMethod"),
971                Value::Int(IntValue::I32(1)),
972            ],
973        )
974        .unwrap_err();
975        assert!(
976            err.contains("requires a string value"),
977            "unexpected error: {err}"
978        );
979    }
980
981    #[test]
982    fn sort_dimension_zero_errors() {
983        let err = sort_builtin(
984            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
985            vec![Value::Num(0.0)],
986        )
987        .unwrap_err();
988        assert!(
989            err.contains("dimension must be >= 1"),
990            "unexpected error: {err}"
991        );
992    }
993
994    #[test]
995    fn sort_gpu_round_trip() {
996        test_support::with_test_provider(|provider| {
997            let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
998            let view = runmat_accelerate_api::HostTensorView {
999                data: &tensor.data,
1000                shape: &tensor.shape,
1001            };
1002            let handle = provider.upload(&view).expect("upload");
1003            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
1004            let (sorted, indices) = eval.into_values();
1005            match sorted {
1006                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0]),
1007                other => panic!("expected tensor, got {other:?}"),
1008            }
1009            match indices {
1010                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
1011                other => panic!("expected tensor, got {other:?}"),
1012            }
1013        });
1014    }
1015
1016    #[test]
1017    #[cfg(feature = "wgpu")]
1018    fn sort_wgpu_matches_cpu() {
1019        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1020            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1021        );
1022        let tensor = Tensor::new(vec![4.0, 1.0, 3.0, 2.0], vec![4, 1]).unwrap();
1023        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu sort");
1024        let (cpu_sorted, cpu_indices) = cpu_eval.into_values();
1025
1026        let gpu_view = runmat_accelerate_api::HostTensorView {
1027            data: &tensor.data,
1028            shape: &tensor.shape,
1029        };
1030        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1031        let handle = provider.upload(&gpu_view).expect("upload");
1032        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu sort");
1033        let (gpu_sorted, gpu_indices) = gpu_eval.into_values();
1034
1035        let cpu_sorted_tensor = match cpu_sorted {
1036            Value::Tensor(t) => t,
1037            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1038            other => panic!("unexpected CPU sorted value {other:?}"),
1039        };
1040        let cpu_indices_tensor = match cpu_indices {
1041            Value::Tensor(t) => t,
1042            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1043            other => panic!("unexpected CPU indices value {other:?}"),
1044        };
1045        let gpu_sorted_tensor = match gpu_sorted {
1046            Value::Tensor(t) => t,
1047            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1048            other => panic!("unexpected GPU sorted value {other:?}"),
1049        };
1050        let gpu_indices_tensor = match gpu_indices {
1051            Value::Tensor(t) => t,
1052            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1053            other => panic!("unexpected GPU indices value {other:?}"),
1054        };
1055
1056        assert_eq!(gpu_sorted_tensor.data, cpu_sorted_tensor.data);
1057        assert_eq!(gpu_indices_tensor.data, cpu_indices_tensor.data);
1058    }
1059
1060    #[cfg(feature = "doc_export")]
1061    #[test]
1062    fn doc_examples_present() {
1063        let blocks = test_support::doc_examples(DOC_MD);
1064        assert!(!blocks.is_empty());
1065    }
1066}