runmat_runtime/builtins/array/sorting_sets/
sortrows.rs

1//! MATLAB-compatible `sortrows` builtin with GPU-aware semantics.
2
3use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6    GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7    SortResult as ProviderSortResult, SortRowsColumnSpec as ProviderSortRowsColumnSpec,
8};
9use runmat_builtins::{CharArray, ComplexTensor, Tensor, Value};
10use runmat_macros::runtime_builtin;
11
12use crate::builtins::common::gpu_helpers;
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18#[cfg(feature = "doc_export")]
19use crate::register_builtin_doc_text;
20use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
21
22#[cfg(feature = "doc_export")]
23pub const DOC_MD: &str = r#"---
24title: "sortrows"
25category: "array/sorting_sets"
26keywords: ["sortrows", "row sort", "lexicographic", "gpu"]
27summary: "Sort matrix rows lexicographically with optional column and direction control."
28references:
29  - https://www.mathworks.com/help/matlab/ref/sortrows.html
30gpu_support:
31  elementwise: false
32  reduction: false
33  precisions: ["f32", "f64"]
34  broadcasting: "none"
35  notes: "Falls back to host memory when providers do not expose a dedicated row sort kernel."
36fusion:
37  elementwise: false
38  reduction: false
39  max_inputs: 1
40  constants: "inline"
41requires_feature: null
42tested:
43  unit: "builtins::array::sorting_sets::sortrows::tests"
44  integration: "builtins::array::sorting_sets::sortrows::tests::sortrows_gpu_roundtrip"
45---
46
47# What does the `sortrows` function do in MATLAB / RunMat?
48`sortrows` reorders the rows of a matrix (or character array) so they appear in lexicographic order.
49You can control which columns participate in the comparison and whether each column uses ascending or descending order.
50
51## How does the `sortrows` function behave in MATLAB / RunMat?
52- `sortrows(A)` sorts by column `1`, then column `2`, and so on, all in ascending order.
53- `sortrows(A, C)` treats the vector `C` as the column order. Positive entries sort ascending; negative entries sort descending.
54- `sortrows(A, 'descend')` sorts all columns in descending order. Combine this with a column vector to mix directions.
55- `[B, I] = sortrows(A, ...)` also returns `I`, the 1-based row permutation indices.
56- `sortrows` is stable: rows that compare equal keep their original order.
- `'ComparisonMethod'` accepts `'auto'`, `'real'`, or `'abs'` for both real and complex inputs, matching MATLAB semantics. With `'auto'`, real data is compared by value and complex data by magnitude.
- NaN handling mirrors MATLAB: when a sort column contains NaN, those rows move to the end for ascending sorts and to the beginning for descending sorts.
59- `'MissingPlacement'` lets you choose whether NaN (and other missing) rows appear `'first'`, `'last'`, or follow MATLAB's `'auto'` default.
60- Character arrays are sorted lexicographically using their character codes.
61
62## `sortrows` Function GPU Execution Behaviour
63- `sortrows` is registered as a sink builtin. When the input tensor already lives on the GPU and the active provider exposes a `sortrows` hook, the runtime delegates to that hook; the current provider contract returns host buffers, so the sorted rows and permutation indices are materialised on the CPU before being returned.
64- When the provider lacks the hook—or cannot honour a specific combination of options such as `'MissingPlacement','first'` or `'MissingPlacement','last'`—RunMat gathers the tensor and performs the sort on the host while preserving MATLAB semantics.
- Name-value options that the provider cannot service fall back to the host automatically; callers do not need to special-case GPU vs CPU execution.
66- The permutation indices are emitted as double-precision column vectors so they can be reused directly for MATLAB-style indexing.
67
68## Examples of using `sortrows` in MATLAB / RunMat
69
70### Sorting rows of a matrix in ascending order
71```matlab
72A = [3 2; 1 4; 2 1];
73B = sortrows(A);
74```
75Expected output:
76```matlab
77B =
78     1     4
79     2     1
80     3     2
81```
82
83### Sorting by a custom column order
84```matlab
85A = [1 4 2; 3 2 5; 3 2 1];
86B = sortrows(A, [2 3 1]);
87```
88Expected output:
89```matlab
90B =
91     3     2     1
92     3     2     5
93     1     4     2
94```
95
96### Sorting rows in descending order
97```matlab
98A = [2 8; 4 1; 3 5];
99B = sortrows(A, 'descend');
100```
101Expected output:
102```matlab
103B =
104     4     1
105     3     5
106     2     8
107```
108
109### Mixing ascending and descending directions
110```matlab
111A = [1 7 3; 1 2 9; 1 2 3];
112B = sortrows(A, [1 -2 3]);
113```
114Expected output:
115```matlab
116B =
117     1     7     3
118     1     2     3
119     1     2     9
120```
121
122### Sorting rows of a character array
123```matlab
124names = ['bob '; 'al  '; 'ally'];
125sorted = sortrows(names);
126```
127Expected output:
128```matlab
129sorted =
130al
131ally
132bob
133```
134
135### Sorting rows of complex data by magnitude
136```matlab
137Z = [3+4i, 3; 1+2i, 4];
138B = sortrows(Z, 'ComparisonMethod', 'abs');
139```
140Expected output:
141```matlab
142B =
143    1.0000 + 2.0000i    4.0000
144    3.0000 + 4.0000i    3.0000
145```
146
147### Forcing NaN rows to the top
148```matlab
149A = [1 NaN; NaN 2];
150B = sortrows(A, 'MissingPlacement', 'first');
151```
152Expected output:
153```matlab
154B =
155   NaN     2
156     1   NaN
157```
158
159### Sorting GPU-resident data with automatic host fallback
160```matlab
161G = gpuArray([3 1; 2 4; 1 2]);
162[B, I] = sortrows(G);
163```
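If the active provider does not implement a row-sort hook (or rejects the request), the runtime gathers `G` and performs the sort on the host. In either case `B` and `I` are returned as host tensors, and the permutation indices `I`
match MATLAB's 1-based output.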
166
167## FAQ
168
169### Can I request the permutation indices?
170Yes. Call `[B, I] = sortrows(A, ...)` to receive the 1-based row permutation indices in `I`.
171
172### How do I sort specific columns?
Provide a vector of column indices, e.g. `sortrows(A, [2 -3])` sorts by column `2` ascending and then by column `3` descending.
174
175### What happens when rows contain NaN values?
176Rows containing NaNs move to the bottom for ascending sorts and to the top for descending sorts when `'MissingPlacement'` is left at its `'auto'` default, matching MATLAB.
177
178### How can I force NaNs or missing values to the top or bottom?
Use the name-value pair `'MissingPlacement','first'` to place missing rows before all non-missing ones, or `'MissingPlacement','last'` to move them to the end regardless of direction.
180
181### Does `sortrows` work with complex numbers?
182Yes. Use `'ComparisonMethod','real'` to sort by the real component or `'abs'` to sort by magnitude (the default behaviour matches MATLAB's `'auto'` rules).
183
184### Can I combine a direction string with a column vector?
Yes. `sortrows(A, [1 3], 'descend')` sorts by column `1` and then by column `3`, both in descending order. A direction string overrides any signs in the column vector.
186
187### Is the operation stable?
188Yes. Rows that compare equal remain in their original order.
189
190### Does `sortrows` mutate its input?
191No. It returns a sorted copy of the input. GPU inputs are gathered to host memory when required.
192
193### Are string arrays supported?
194String arrays are not yet supported. Convert them to character matrices or use tables before sorting.
195
196## See Also
197[sort](./sort), [unique](./unique), [max](../../math/reduction/max), [min](../../math/reduction/min), [permute](../../array/shape/permute)
198
199## Source & Feedback
200- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/sortrows.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/sortrows.rs)
201- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with a minimal reproduction.
202"#;
203
204pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
205    name: "sortrows",
206    op_kind: GpuOpKind::Custom("sortrows"),
207    supported_precisions: &[ScalarType::F32, ScalarType::F64],
208    broadcast: BroadcastSemantics::None,
209    provider_hooks: &[ProviderHook::Custom("sortrows")],
210    constant_strategy: ConstantStrategy::InlineLiteral,
211    residency: ResidencyPolicy::GatherImmediately,
212    nan_mode: ReductionNaN::Include,
213    two_pass_threshold: None,
214    workgroup_size: None,
215    accepts_nan_mode: true,
216    notes:
217        "Providers may implement a row-sort kernel; explicit MissingPlacement overrides fall back to host memory until native support exists.",
218};
219
220register_builtin_gpu_spec!(GPU_SPEC);
221
222pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
223    name: "sortrows",
224    shape: ShapeRequirements::Any,
225    constant_strategy: ConstantStrategy::InlineLiteral,
226    elementwise: None,
227    reduction: None,
228    emits_nan: true,
229    notes: "Acts as a sink operation; upstream fusion chains terminate before sorting rows.",
230};
231
232register_builtin_fusion_spec!(FUSION_SPEC);
233
234#[cfg(feature = "doc_export")]
235register_builtin_doc_text!("sortrows", DOC_MD);
236
237#[runtime_builtin(
238    name = "sortrows",
239    category = "array/sorting_sets",
240    summary = "Sort matrix rows lexicographically with optional column and direction control.",
241    keywords = "sortrows,row sort,lexicographic,gpu",
242    accel = "sink",
243    sink = true
244)]
245fn sortrows_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
246    evaluate(value, &rest).map(|eval| eval.into_sorted_value())
247}
248
249/// Evaluate `sortrows`, returning both sorted values and permutation indices.
250pub fn evaluate(value: Value, rest: &[Value]) -> Result<SortRowsEvaluation, String> {
251    match value {
252        Value::GpuTensor(handle) => sortrows_gpu(handle, rest),
253        other => sortrows_host(other, rest),
254    }
255}
256
257fn sortrows_gpu(handle: GpuTensorHandle, rest: &[Value]) -> Result<SortRowsEvaluation, String> {
258    ensure_matrix_shape(&handle.shape)?;
259    let (_, cols) = rows_cols_from_shape(&handle.shape);
260    let args = SortRowsArgs::parse(rest, cols)?;
261
262    if args.missing_is_auto() {
263        if let Some(provider) = runmat_accelerate_api::provider() {
264            let provider_columns = args.to_provider_columns();
265            let provider_comparison = args.provider_comparison();
266            match provider.sort_rows(&handle, &provider_columns, provider_comparison) {
267                Ok(result) => return sortrows_from_provider_result(result),
268                Err(_err) => {
269                    // fall back to host path when provider cannot service the request
270                }
271            }
272        }
273    }
274
275    let tensor = gpu_helpers::gather_tensor(&handle)?;
276    sortrows_real_tensor_with_args(tensor, &args)
277}
278
279fn sortrows_from_provider_result(result: ProviderSortResult) -> Result<SortRowsEvaluation, String> {
280    let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
281        .map_err(|e| format!("sortrows: {e}"))?;
282    let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
283        .map_err(|e| format!("sortrows: {e}"))?;
284    Ok(SortRowsEvaluation {
285        sorted: tensor::tensor_into_value(sorted_tensor),
286        indices: indices_tensor,
287    })
288}
289
290fn sortrows_host(value: Value, rest: &[Value]) -> Result<SortRowsEvaluation, String> {
291    match value {
292        Value::Tensor(tensor) => sortrows_real_tensor(tensor, rest),
293        Value::LogicalArray(logical) => {
294            let tensor = tensor::logical_to_tensor(&logical)?;
295            sortrows_real_tensor(tensor, rest)
296        }
297        Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
298            let tensor = tensor::value_into_tensor_for("sortrows", value)?;
299            sortrows_real_tensor(tensor, rest)
300        }
301        Value::ComplexTensor(ct) => sortrows_complex_tensor(ct, rest),
302        Value::Complex(re, im) => {
303            let tensor =
304                ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(|e| format!("sortrows: {e}"))?;
305            sortrows_complex_tensor(tensor, rest)
306        }
307        Value::CharArray(ca) => sortrows_char_array(ca, rest),
308        other => Err(format!(
309            "sortrows: unsupported input type {:?}; expected numeric, logical, complex, or char arrays",
310            other
311        )),
312    }
313}
314
315fn sortrows_real_tensor(tensor: Tensor, rest: &[Value]) -> Result<SortRowsEvaluation, String> {
316    ensure_matrix_shape(&tensor.shape)?;
317    let cols = tensor.cols();
318    let args = SortRowsArgs::parse(rest, cols)?;
319    sortrows_real_tensor_with_args(tensor, &args)
320}
321
322fn sortrows_real_tensor_with_args(
323    tensor: Tensor,
324    args: &SortRowsArgs,
325) -> Result<SortRowsEvaluation, String> {
326    let rows = tensor.rows();
327    let cols = tensor.cols();
328
329    if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
330        let indices = identity_indices(rows)?;
331        return Ok(SortRowsEvaluation {
332            sorted: tensor::tensor_into_value(tensor),
333            indices,
334        });
335    }
336
337    let mut order: Vec<usize> = (0..rows).collect();
338    order.sort_by(|&a, &b| compare_real_rows(&tensor, rows, args, a, b));
339
340    let sorted_tensor = reorder_real_rows(&tensor, rows, cols, &order)?;
341    let indices = permutation_indices(&order)?;
342    Ok(SortRowsEvaluation {
343        sorted: tensor::tensor_into_value(sorted_tensor),
344        indices,
345    })
346}
347
348fn sortrows_complex_tensor(
349    tensor: ComplexTensor,
350    rest: &[Value],
351) -> Result<SortRowsEvaluation, String> {
352    ensure_matrix_shape(&tensor.shape)?;
353    let cols = tensor.cols;
354    let args = SortRowsArgs::parse(rest, cols)?;
355    sortrows_complex_tensor_with_args(tensor, &args)
356}
357
358fn sortrows_complex_tensor_with_args(
359    tensor: ComplexTensor,
360    args: &SortRowsArgs,
361) -> Result<SortRowsEvaluation, String> {
362    let rows = tensor.rows;
363    let cols = tensor.cols;
364
365    if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
366        let indices = identity_indices(rows)?;
367        return Ok(SortRowsEvaluation {
368            sorted: complex_tensor_into_value(tensor),
369            indices,
370        });
371    }
372
373    let mut order: Vec<usize> = (0..rows).collect();
374    order.sort_by(|&a, &b| compare_complex_rows(&tensor, rows, args, a, b));
375
376    let sorted_tensor = reorder_complex_rows(&tensor, rows, cols, &order)?;
377    let indices = permutation_indices(&order)?;
378    Ok(SortRowsEvaluation {
379        sorted: complex_tensor_into_value(sorted_tensor),
380        indices,
381    })
382}
383
384fn sortrows_char_array(ca: CharArray, rest: &[Value]) -> Result<SortRowsEvaluation, String> {
385    let cols = ca.cols;
386    let args = SortRowsArgs::parse(rest, cols)?;
387    sortrows_char_array_with_args(ca, &args)
388}
389
390fn sortrows_char_array_with_args(
391    ca: CharArray,
392    args: &SortRowsArgs,
393) -> Result<SortRowsEvaluation, String> {
394    let rows = ca.rows;
395    let cols = ca.cols;
396
397    if rows <= 1 || cols == 0 || ca.data.is_empty() || args.columns.is_empty() {
398        let indices = identity_indices(rows)?;
399        return Ok(SortRowsEvaluation {
400            sorted: Value::CharArray(ca),
401            indices,
402        });
403    }
404
405    let mut order: Vec<usize> = (0..rows).collect();
406    order.sort_by(|&a, &b| compare_char_rows(&ca, args, a, b));
407
408    let sorted = reorder_char_rows(&ca, rows, cols, &order)?;
409    let indices = permutation_indices(&order)?;
410    Ok(SortRowsEvaluation {
411        sorted: Value::CharArray(sorted),
412        indices,
413    })
414}
415
/// Ensure `shape` describes at most a 2-D matrix.
///
/// Trailing singleton dimensions (e.g. `[3, 4, 1]`) are accepted because MATLAB
/// treats such arrays as 2-D (`ndims` ignores trailing ones). Any non-singleton
/// extent beyond the second dimension is rejected.
///
/// # Errors
/// Returns `"sortrows: input must be a 2-D matrix"` for genuinely N-D shapes.
fn ensure_matrix_shape(shape: &[usize]) -> Result<(), String> {
    if shape.iter().skip(2).all(|&extent| extent == 1) {
        Ok(())
    } else {
        Err("sortrows: input must be a 2-D matrix".to_string())
    }
}
423
/// Interpret a shape as `(rows, cols)`.
///
/// An empty shape is a scalar (1x1), and a 1-D shape is treated as a row
/// vector. Otherwise the first two extents are used.
fn rows_cols_from_shape(shape: &[usize]) -> (usize, usize) {
    match shape {
        [] => (1, 1),
        [len] => (1, *len),
        [rows, cols, ..] => (*rows, *cols),
    }
}
431
432fn compare_real_rows(
433    tensor: &Tensor,
434    rows: usize,
435    args: &SortRowsArgs,
436    a: usize,
437    b: usize,
438) -> Ordering {
439    for spec in &args.columns {
440        if spec.index >= tensor.cols() {
441            continue;
442        }
443        let idx_a = a + spec.index * rows;
444        let idx_b = b + spec.index * rows;
445        let va = tensor.data[idx_a];
446        let vb = tensor.data[idx_b];
447        let missing = args.missing_for_direction(spec.direction);
448        let ord = compare_real_scalars(va, vb, spec.direction, args.comparison, missing);
449        if ord != Ordering::Equal {
450            return ord;
451        }
452    }
453    Ordering::Equal
454}
455
456fn compare_complex_rows(
457    tensor: &ComplexTensor,
458    rows: usize,
459    args: &SortRowsArgs,
460    a: usize,
461    b: usize,
462) -> Ordering {
463    for spec in &args.columns {
464        if spec.index >= tensor.cols {
465            continue;
466        }
467        let idx_a = a + spec.index * rows;
468        let idx_b = b + spec.index * rows;
469        let va = tensor.data[idx_a];
470        let vb = tensor.data[idx_b];
471        let missing = args.missing_for_direction(spec.direction);
472        let ord = compare_complex_scalars(va, vb, spec.direction, args.comparison, missing);
473        if ord != Ordering::Equal {
474            return ord;
475        }
476    }
477    Ordering::Equal
478}
479
480fn compare_char_rows(ca: &CharArray, args: &SortRowsArgs, a: usize, b: usize) -> Ordering {
481    for spec in &args.columns {
482        if spec.index >= ca.cols {
483            continue;
484        }
485        let idx_a = a * ca.cols + spec.index;
486        let idx_b = b * ca.cols + spec.index;
487        let va = ca.data[idx_a];
488        let vb = ca.data[idx_b];
489        let ord = match spec.direction {
490            SortDirection::Ascend => va.cmp(&vb),
491            SortDirection::Descend => vb.cmp(&va),
492        };
493        if ord != Ordering::Equal {
494            return ord;
495        }
496    }
497    Ordering::Equal
498}
499
500fn reorder_real_rows(
501    tensor: &Tensor,
502    rows: usize,
503    cols: usize,
504    order: &[usize],
505) -> Result<Tensor, String> {
506    let mut data = vec![0.0; tensor.data.len()];
507    for col in 0..cols {
508        for (dest_row, &src_row) in order.iter().enumerate() {
509            let src_idx = src_row + col * rows;
510            let dst_idx = dest_row + col * rows;
511            data[dst_idx] = tensor.data[src_idx];
512        }
513    }
514    Tensor::new(data, tensor.shape.clone()).map_err(|e| format!("sortrows: {e}"))
515}
516
517fn reorder_complex_rows(
518    tensor: &ComplexTensor,
519    rows: usize,
520    cols: usize,
521    order: &[usize],
522) -> Result<ComplexTensor, String> {
523    let mut data = vec![(0.0, 0.0); tensor.data.len()];
524    for col in 0..cols {
525        for (dest_row, &src_row) in order.iter().enumerate() {
526            let src_idx = src_row + col * rows;
527            let dst_idx = dest_row + col * rows;
528            data[dst_idx] = tensor.data[src_idx];
529        }
530    }
531    ComplexTensor::new(data, tensor.shape.clone()).map_err(|e| format!("sortrows: {e}"))
532}
533
534fn reorder_char_rows(
535    ca: &CharArray,
536    rows: usize,
537    cols: usize,
538    order: &[usize],
539) -> Result<CharArray, String> {
540    let mut data = vec!['\0'; ca.data.len()];
541    for (dest_row, &src_row) in order.iter().enumerate() {
542        for col in 0..cols {
543            let src_idx = src_row * cols + col;
544            let dst_idx = dest_row * cols + col;
545            data[dst_idx] = ca.data[src_idx];
546        }
547    }
548    CharArray::new(data, rows, cols).map_err(|e| format!("sortrows: {e}"))
549}
550
551fn compare_real_scalars(
552    a: f64,
553    b: f64,
554    direction: SortDirection,
555    comparison: ComparisonMethod,
556    missing: MissingPlacementResolved,
557) -> Ordering {
558    match (a.is_nan(), b.is_nan()) {
559        (true, true) => Ordering::Equal,
560        (true, false) => match missing {
561            MissingPlacementResolved::First => Ordering::Less,
562            MissingPlacementResolved::Last => Ordering::Greater,
563        },
564        (false, true) => match missing {
565            MissingPlacementResolved::First => Ordering::Greater,
566            MissingPlacementResolved::Last => Ordering::Less,
567        },
568        (false, false) => compare_real_finite_scalars(a, b, direction, comparison),
569    }
570}
571
572fn compare_real_finite_scalars(
573    a: f64,
574    b: f64,
575    direction: SortDirection,
576    comparison: ComparisonMethod,
577) -> Ordering {
578    if matches!(comparison, ComparisonMethod::Abs) {
579        let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
580        if abs_cmp != Ordering::Equal {
581            return match direction {
582                SortDirection::Ascend => abs_cmp,
583                SortDirection::Descend => abs_cmp.reverse(),
584            };
585        }
586    }
587    match direction {
588        SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
589        SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
590    }
591}
592
593fn compare_complex_scalars(
594    a: (f64, f64),
595    b: (f64, f64),
596    direction: SortDirection,
597    comparison: ComparisonMethod,
598    missing: MissingPlacementResolved,
599) -> Ordering {
600    match (complex_is_nan(a), complex_is_nan(b)) {
601        (true, true) => Ordering::Equal,
602        (true, false) => match missing {
603            MissingPlacementResolved::First => Ordering::Less,
604            MissingPlacementResolved::Last => Ordering::Greater,
605        },
606        (false, true) => match missing {
607            MissingPlacementResolved::First => Ordering::Greater,
608            MissingPlacementResolved::Last => Ordering::Less,
609        },
610        (false, false) => compare_complex_finite_scalars(a, b, direction, comparison),
611    }
612}
613
614fn compare_complex_finite_scalars(
615    a: (f64, f64),
616    b: (f64, f64),
617    direction: SortDirection,
618    comparison: ComparisonMethod,
619) -> Ordering {
620    match comparison {
621        ComparisonMethod::Real => compare_complex_real_first(a, b, direction),
622        ComparisonMethod::Auto | ComparisonMethod::Abs => {
623            let abs_cmp = complex_abs(a)
624                .partial_cmp(&complex_abs(b))
625                .unwrap_or(Ordering::Equal);
626            if abs_cmp != Ordering::Equal {
627                return match direction {
628                    SortDirection::Ascend => abs_cmp,
629                    SortDirection::Descend => abs_cmp.reverse(),
630                };
631            }
632            compare_complex_real_first(a, b, direction)
633        }
634    }
635}
636
637fn compare_complex_real_first(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
638    let real_cmp = match direction {
639        SortDirection::Ascend => a.0.partial_cmp(&b.0),
640        SortDirection::Descend => b.0.partial_cmp(&a.0),
641    }
642    .unwrap_or(Ordering::Equal);
643    if real_cmp != Ordering::Equal {
644        return real_cmp;
645    }
646    match direction {
647        SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
648        SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
649    }
650}
651
/// A complex value is missing when either its real or its imaginary part is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
655
/// Magnitude of a complex value, computed with `hypot` to avoid intermediate
/// overflow.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    re.hypot(im)
}
659
660fn permutation_indices(order: &[usize]) -> Result<Tensor, String> {
661    let rows = order.len();
662    let mut data = Vec::with_capacity(rows);
663    for &idx in order {
664        data.push((idx + 1) as f64);
665    }
666    Tensor::new(data, vec![rows, 1]).map_err(|e| format!("sortrows: {e}"))
667}
668
669fn identity_indices(rows: usize) -> Result<Tensor, String> {
670    let mut data = Vec::with_capacity(rows);
671    for i in 0..rows {
672        data.push((i + 1) as f64);
673    }
674    Tensor::new(data, vec![rows, 1]).map_err(|e| format!("sortrows: {e}"))
675}
676
677fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
678    if tensor.data.len() == 1 {
679        Value::Complex(tensor.data[0].0, tensor.data[0].1)
680    } else {
681        Value::ComplexTensor(tensor)
682    }
683}
684
/// Direction in which a single sort key is ordered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SortDirection {
    Ascend,
    Descend,
}

impl SortDirection {
    /// Parse a direction keyword, ignoring case and surrounding whitespace.
    ///
    /// Accepts `ascend`/`ascending` and `descend`/`descending`. Returns `None`
    /// for anything else.
    fn from_str(value: &str) -> Option<Self> {
        let keyword = value.trim().to_ascii_lowercase();
        if keyword == "ascend" || keyword == "ascending" {
            Some(Self::Ascend)
        } else if keyword == "descend" || keyword == "descending" {
            Some(Self::Descend)
        } else {
            None
        }
    }
}
700
701#[derive(Debug, Clone, Copy, PartialEq, Eq)]
702enum ComparisonMethod {
703    Auto,
704    Real,
705    Abs,
706}
707
708#[derive(Debug, Clone, Copy, PartialEq, Eq)]
709enum MissingPlacement {
710    Auto,
711    First,
712    Last,
713}
714
715#[derive(Debug, Clone, Copy, PartialEq, Eq)]
716enum MissingPlacementResolved {
717    First,
718    Last,
719}
720
721impl MissingPlacement {
722    fn resolve(self, direction: SortDirection) -> MissingPlacementResolved {
723        match self {
724            MissingPlacement::First => MissingPlacementResolved::First,
725            MissingPlacement::Last => MissingPlacementResolved::Last,
726            MissingPlacement::Auto => match direction {
727                SortDirection::Ascend => MissingPlacementResolved::Last,
728                SortDirection::Descend => MissingPlacementResolved::First,
729            },
730        }
731    }
732
733    fn is_auto(self) -> bool {
734        matches!(self, MissingPlacement::Auto)
735    }
736}
737
738#[derive(Debug, Clone)]
739struct ColumnSpec {
740    index: usize,
741    direction: SortDirection,
742}
743
744#[derive(Debug, Clone)]
745struct SortRowsArgs {
746    columns: Vec<ColumnSpec>,
747    comparison: ComparisonMethod,
748    missing: MissingPlacement,
749}
750
751impl SortRowsArgs {
752    fn parse(rest: &[Value], num_cols: usize) -> Result<Self, String> {
753        let mut columns: Option<Vec<ColumnSpec>> = None;
754        let mut override_direction: Option<SortDirection> = None;
755        let mut comparison = ComparisonMethod::Auto;
756        let mut missing = MissingPlacement::Auto;
757        let mut i = 0usize;
758
759        while i < rest.len() {
760            if columns.is_none() {
761                if let Some(parsed) = parse_column_vector(&rest[i], num_cols)? {
762                    columns = Some(parsed);
763                    i += 1;
764                    continue;
765                }
766            }
767            if let Some(direction) = parse_direction(&rest[i]) {
768                override_direction = Some(direction);
769                i += 1;
770                continue;
771            }
772            let Some(keyword) = tensor::value_to_string(&rest[i]) else {
773                return Err(format!("sortrows: invalid argument {:?}", rest[i]));
774            };
775            let lowered = keyword.trim().to_ascii_lowercase();
776            match lowered.as_str() {
777                "comparisonmethod" => {
778                    i += 1;
779                    if i >= rest.len() {
780                        return Err("sortrows: expected a value for 'ComparisonMethod'".to_string());
781                    }
782                    let Some(value_str) = tensor::value_to_string(&rest[i]) else {
783                        return Err(
784                            "sortrows: 'ComparisonMethod' expects a string value".to_string()
785                        );
786                    };
787                    comparison = match value_str.trim().to_ascii_lowercase().as_str() {
788                        "auto" => ComparisonMethod::Auto,
789                        "real" => ComparisonMethod::Real,
790                        "abs" | "magnitude" => ComparisonMethod::Abs,
791                        other => {
792                            return Err(format!("sortrows: unsupported ComparisonMethod '{other}'"))
793                        }
794                    };
795                    i += 1;
796                }
797                "missingplacement" => {
798                    i += 1;
799                    if i >= rest.len() {
800                        return Err("sortrows: expected a value for 'MissingPlacement'".to_string());
801                    }
802                    let Some(value_str) = tensor::value_to_string(&rest[i]) else {
803                        return Err(
804                            "sortrows: 'MissingPlacement' expects a string value".to_string()
805                        );
806                    };
807                    missing = match value_str.trim().to_ascii_lowercase().as_str() {
808                        "auto" => MissingPlacement::Auto,
809                        "first" => MissingPlacement::First,
810                        "last" => MissingPlacement::Last,
811                        other => {
812                            return Err(format!("sortrows: unsupported MissingPlacement '{other}'"))
813                        }
814                    };
815                    i += 1;
816                }
817                other => {
818                    return Err(format!("sortrows: unexpected argument '{other}'"));
819                }
820            }
821        }
822
823        let mut columns = columns.unwrap_or_else(|| default_columns(num_cols));
824        if let Some(dir) = override_direction {
825            for spec in &mut columns {
826                spec.direction = dir;
827            }
828        }
829        validate_columns(&columns, num_cols)?;
830
831        Ok(SortRowsArgs {
832            columns,
833            comparison,
834            missing,
835        })
836    }
837
838    fn to_provider_columns(&self) -> Vec<ProviderSortRowsColumnSpec> {
839        self.columns
840            .iter()
841            .map(|spec| ProviderSortRowsColumnSpec {
842                index: spec.index,
843                order: match spec.direction {
844                    SortDirection::Ascend => ProviderSortOrder::Ascend,
845                    SortDirection::Descend => ProviderSortOrder::Descend,
846                },
847            })
848            .collect()
849    }
850
851    fn provider_comparison(&self) -> ProviderSortComparison {
852        match self.comparison {
853            ComparisonMethod::Auto => ProviderSortComparison::Auto,
854            ComparisonMethod::Real => ProviderSortComparison::Real,
855            ComparisonMethod::Abs => ProviderSortComparison::Abs,
856        }
857    }
858
859    fn missing_for_direction(&self, direction: SortDirection) -> MissingPlacementResolved {
860        self.missing.resolve(direction)
861    }
862
863    fn missing_is_auto(&self) -> bool {
864        self.missing.is_auto()
865    }
866}
867
868fn parse_column_vector(value: &Value, num_cols: usize) -> Result<Option<Vec<ColumnSpec>>, String> {
869    match value {
870        Value::Int(i) => parse_single_column(i.to_i64(), num_cols).map(Some),
871        Value::Num(n) => {
872            if !n.is_finite() {
873                return Err("sortrows: column indices must be finite".to_string());
874            }
875            let rounded = n.round();
876            if (rounded - n).abs() > f64::EPSILON {
877                return Err("sortrows: column indices must be integers".to_string());
878            }
879            parse_single_column(rounded as i64, num_cols).map(Some)
880        }
881        Value::Tensor(tensor) => {
882            if !is_vector(&tensor.shape) {
883                return Err("sortrows: column specification must be a vector".to_string());
884            }
885            let mut specs = Vec::with_capacity(tensor.data.len());
886            for &entry in &tensor.data {
887                if !entry.is_finite() {
888                    return Err("sortrows: column indices must be finite".to_string());
889                }
890                let rounded = entry.round();
891                if (rounded - entry).abs() > f64::EPSILON {
892                    return Err("sortrows: column indices must be integers".to_string());
893                }
894                let column = parse_single_column_i64(rounded as i64, num_cols)?;
895                specs.push(column);
896            }
897            Ok(Some(specs))
898        }
899        _ => Ok(None),
900    }
901}
902
903fn parse_single_column(value: i64, num_cols: usize) -> Result<Vec<ColumnSpec>, String> {
904    parse_single_column_i64(value, num_cols).map(|spec| vec![spec])
905}
906
907fn parse_single_column_i64(value: i64, num_cols: usize) -> Result<ColumnSpec, String> {
908    if value == 0 {
909        return Err("sortrows: column indices must be non-zero".to_string());
910    }
911    let abs = value.unsigned_abs() as usize;
912    if abs == 0 {
913        return Err("sortrows: column indices must be >= 1".to_string());
914    }
915    if num_cols == 0 {
916        return Err("sortrows: column index exceeds matrix with 0 columns".to_string());
917    }
918    if abs > num_cols {
919        return Err(format!(
920            "sortrows: column index {} exceeds matrix with {} columns",
921            abs, num_cols
922        ));
923    }
924    let direction = if value > 0 {
925        SortDirection::Ascend
926    } else {
927        SortDirection::Descend
928    };
929    Ok(ColumnSpec {
930        index: abs - 1,
931        direction,
932    })
933}
934
935fn parse_direction(value: &Value) -> Option<SortDirection> {
936    tensor::value_to_string(value).and_then(|s| SortDirection::from_str(&s))
937}
938
939fn default_columns(num_cols: usize) -> Vec<ColumnSpec> {
940    let mut columns = Vec::with_capacity(num_cols);
941    for col in 0..num_cols {
942        columns.push(ColumnSpec {
943            index: col,
944            direction: SortDirection::Ascend,
945        });
946    }
947    columns
948}
949
950fn validate_columns(columns: &[ColumnSpec], num_cols: usize) -> Result<(), String> {
951    if num_cols == 0 && columns.iter().any(|spec| spec.index > 0) {
952        return Err("sortrows: column index exceeds matrix with 0 columns".to_string());
953    }
954    for spec in columns {
955        if num_cols > 0 && spec.index >= num_cols {
956            return Err(format!(
957                "sortrows: column index {} exceeds matrix with {} columns",
958                spec.index + 1,
959                num_cols
960            ));
961        }
962    }
963    Ok(())
964}
965
/// Returns `true` when `shape` describes a MATLAB vector (including scalars and empties).
///
/// Rank-0 and rank-1 shapes always qualify. For higher ranks, dimensions beyond the second
/// must all be singleton (MATLAB drops trailing singleton dimensions, so `[3, 1, 1]` is a
/// 3x1 column), and at least one of the first two dimensions must be 1.
fn is_vector(shape: &[usize]) -> bool {
    match shape {
        [] | [_] => true,
        [rows, cols, trailing @ ..] => {
            trailing.iter().all(|&dim| dim == 1) && (*rows == 1 || *cols == 1)
        }
    }
}
974
/// Result of evaluating `sortrows`: the reordered input plus the row permutation.
///
/// Callers obtain the outputs through the accessor methods rather than the private fields.
#[derive(Debug)]
pub struct SortRowsEvaluation {
    /// The input with its rows reordered; its variant can differ from the input's
    /// (the GPU tests expect a host `Tensor` back).
    sorted: Value,
    /// Row permutation as 1-based row numbers, e.g. `[2, 3, 1]` in the default-matrix test.
    indices: Tensor,
}
980
981impl SortRowsEvaluation {
982    pub fn into_sorted_value(self) -> Value {
983        self.sorted
984    }
985
986    pub fn into_values(self) -> (Value, Value) {
987        let indices = tensor::tensor_into_value(self.indices);
988        (self.sorted, indices)
989    }
990
991    pub fn indices_value(&self) -> Value {
992        tensor::tensor_into_value(self.indices.clone())
993    }
994}
995
// Unit tests for `sortrows`. Numeric tensors below are built with column-major data
// (e.g. `[3, 1, 2, 4, 1, 5]` with shape `[3, 2]` is rows (3,4), (1,1), (2,5)), and the
// returned index tensors hold 1-based row numbers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{IntValue, Value};

    // Default call: ascending on all columns, left to right; also checks the permutation.
    #[test]
    fn sortrows_default_matrix() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // An explicit column vector sets the key order: column 2, then 3, then 1.
    #[test]
    fn sortrows_with_column_vector() {
        let tensor = Tensor::new(
            vec![1.0, 3.0, 3.0, 4.0, 2.0, 2.0, 2.0, 5.0, 1.0],
            vec![3, 3],
        )
        .unwrap();
        let cols = Tensor::new(vec![2.0, 3.0, 1.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![3.0, 3.0, 1.0, 2.0, 2.0, 4.0, 1.0, 5.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // A bare "descend" flag sorts every key column in descending order.
    #[test]
    fn sortrows_direction_descend() {
        let tensor = Tensor::new(vec![1.0, 2.0, 4.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 3.0, 4.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Signed column keys: ascending on column 1, then descending on column 2 (the `-2`).
    #[test]
    fn sortrows_mixed_directions() {
        let tensor = Tensor::new(vec![1.0, 1.0, 1.0, 1.0, 7.0, 2.0], vec![3, 2]).unwrap();
        let cols = Tensor::new(vec![1.0, -2.0], vec![2, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 1.0, 1.0, 7.0, 2.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // The second output is the row permutation, returned as a tensor.
    #[test]
    fn sortrows_returns_indices() {
        let tensor = Tensor::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (_, indices) = eval.into_values();
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // Character rows (read back row-major below) sort lexicographically; the
    // space-padded "al  " precedes "ally".
    #[test]
    fn sortrows_char_array() {
        let chars = CharArray::new(
            "bob "
                .chars()
                .chain("al  ".chars())
                .chain("ally".chars())
                .collect(),
            3,
            4,
        )
        .unwrap();
        let eval = evaluate(Value::CharArray(chars), &[]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 3);
                assert_eq!(ca.cols, 4);
                let strings: Vec<String> = (0..ca.rows)
                    .map(|r| {
                        ca.data[r * ca.cols..(r + 1) * ca.cols]
                            .iter()
                            .collect::<String>()
                    })
                    .collect();
                assert_eq!(
                    strings,
                    vec!["al  ".to_string(), "ally".to_string(), "bob ".to_string()]
                );
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // ComparisonMethod "abs" on complex input. Both entries have magnitude sqrt(5), so the
    // expected order is decided purely by the comparator's tie-break.
    // NOTE(review): MATLAB breaks magnitude ties by phase angle, which would place 1+2i
    // (angle ~1.11) before -2+1i (angle ~2.68) — confirm this expectation.
    #[test]
    fn sortrows_complex_abs() {
        let tensor = ComplexTensor::new(vec![(1.0, 2.0), (-2.0, 1.0)], vec![2, 1]).unwrap();
        let eval = evaluate(
            Value::ComplexTensor(tensor),
            &[Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.data, vec![(-2.0, 1.0), (1.0, 2.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // A column index wider than the matrix is rejected with a "column index" error.
    #[test]
    fn sortrows_invalid_column_index_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(3))]).unwrap_err();
        assert!(
            err.contains("column index"),
            "unexpected error message: {err}"
        );
    }

    // MissingPlacement "first" moves the NaN-keyed row ahead of finite rows.
    #[test]
    fn sortrows_missingplacement_first_moves_nan_first() {
        let tensor = Tensor::new(vec![1.0, f64::NAN, 2.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("first")],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert!(t.data[0].is_nan());
                assert_eq!(t.data[1], 1.0);
                assert_eq!(t.data[2], 3.0);
                assert_eq!(t.data[3], 2.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // With "descend" plus MissingPlacement "last", the NaN-keyed row ends up last.
    #[test]
    fn sortrows_missingplacement_last_descend_moves_nan_last() {
        let tensor = Tensor::new(vec![f64::NAN, 5.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::from("descend"),
                Value::from("MissingPlacement"),
                Value::from("last"),
            ],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data[0], 5.0);
                assert!(t.data[1].is_nan());
                assert_eq!(t.data[2], 2.0);
                assert_eq!(t.data[3], 1.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // An unknown MissingPlacement value is reported as an error mentioning the option.
    #[test]
    fn sortrows_missingplacement_invalid_value_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("middle")],
        )
        .unwrap_err();
        assert!(
            err.contains("MissingPlacement"),
            "unexpected error message: {err}"
        );
    }

    // GPU input via the test provider: results come back as host tensors matching the
    // CPU default-matrix case. The uploaded handle is not freed in this test.
    #[test]
    fn sortrows_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
            let (sorted, indices) = eval.into_values();
            match sorted {
                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
            match indices {
                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
                other => panic!("unexpected indices {other:?}"),
            }
        });
    }

    // With the `wgpu` feature: the wgpu provider must reproduce the CPU result exactly
    // (both sorted data and indices, including shapes).
    #[test]
    #[cfg(feature = "wgpu")]
    fn sortrows_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0, 2.0, 5.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let (cpu_sorted_val, cpu_indices_val) = cpu_eval.into_values();
        let cpu_sorted = match cpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let cpu_indices = match cpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("provider");
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("gpu evaluate");
        let (gpu_sorted_val, gpu_indices_val) = gpu_eval.into_values();
        let gpu_sorted = match gpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let gpu_indices = match gpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        assert_eq!(gpu_sorted.shape, cpu_sorted.shape);
        assert_eq!(gpu_sorted.data, cpu_sorted.data);
        assert_eq!(gpu_indices.shape, cpu_indices.shape);
        assert_eq!(gpu_indices.data, cpu_indices.data);

        let _ = provider.free(&handle);
    }

    // With `doc_export`: the markdown documentation must contain at least one example block.
    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}