Skip to main content

runmat_runtime/builtins/array/sorting_sets/
sortrows.rs

1//! MATLAB-compatible `sortrows` builtin with GPU-aware semantics.
2
3use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6    GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7    SortResult as ProviderSortResult, SortRowsColumnSpec as ProviderSortRowsColumnSpec,
8};
9use runmat_builtins::{CharArray, ComplexTensor, Tensor, Value};
10use runmat_macros::runtime_builtin;
11
12use super::type_resolvers::tensor_output_type;
13use crate::build_runtime_error;
14use crate::builtins::common::gpu_helpers;
15use crate::builtins::common::spec::{
16    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::tensor;
20
21#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::sortrows")]
22pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
23    name: "sortrows",
24    op_kind: GpuOpKind::Custom("sortrows"),
25    supported_precisions: &[ScalarType::F32, ScalarType::F64],
26    broadcast: BroadcastSemantics::None,
27    provider_hooks: &[ProviderHook::Custom("sortrows")],
28    constant_strategy: ConstantStrategy::InlineLiteral,
29    residency: ResidencyPolicy::GatherImmediately,
30    nan_mode: ReductionNaN::Include,
31    two_pass_threshold: None,
32    workgroup_size: None,
33    accepts_nan_mode: true,
34    notes:
35        "Providers may implement a row-sort kernel; explicit MissingPlacement overrides fall back to host memory until native support exists.",
36};
37
38#[runmat_macros::register_fusion_spec(
39    builtin_path = "crate::builtins::array::sorting_sets::sortrows"
40)]
41pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
42    name: "sortrows",
43    shape: ShapeRequirements::Any,
44    constant_strategy: ConstantStrategy::InlineLiteral,
45    elementwise: None,
46    reduction: None,
47    emits_nan: true,
48    notes: "`sortrows` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
49};
50
51fn sortrows_error(message: impl Into<String>) -> crate::RuntimeError {
52    build_runtime_error(message)
53        .with_builtin("sortrows")
54        .build()
55}
56
57#[runtime_builtin(
58    name = "sortrows",
59    category = "array/sorting_sets",
60    summary = "Sort matrix rows lexicographically with optional column and direction control.",
61    keywords = "sortrows,row sort,lexicographic,gpu",
62    accel = "sink",
63    sink = true,
64    type_resolver(tensor_output_type),
65    builtin_path = "crate::builtins::array::sorting_sets::sortrows"
66)]
67async fn sortrows_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
68    let eval = evaluate(value, &rest).await?;
69    if let Some(out_count) = crate::output_count::current_output_count() {
70        if out_count == 0 {
71            return Ok(Value::OutputList(Vec::new()));
72        }
73        let (sorted, indices) = eval.into_values();
74        let mut outputs = vec![sorted];
75        if out_count >= 2 {
76            outputs.push(indices);
77        }
78        return Ok(crate::output_count::output_list_with_padding(
79            out_count, outputs,
80        ));
81    }
82    Ok(eval.into_sorted_value())
83}
84
85/// Evaluate the `sortrows` builtin once and expose both outputs.
86pub async fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
87    match value {
88        Value::GpuTensor(handle) => sortrows_gpu(handle, rest).await,
89        other => sortrows_host(other, rest),
90    }
91}
92
93async fn sortrows_gpu(
94    handle: GpuTensorHandle,
95    rest: &[Value],
96) -> crate::BuiltinResult<SortRowsEvaluation> {
97    ensure_matrix_shape(&handle.shape)?;
98    let (_, cols) = rows_cols_from_shape(&handle.shape);
99    let args = SortRowsArgs::parse(rest, cols)?;
100
101    if args.missing_is_auto() {
102        if let Some(provider) = runmat_accelerate_api::provider() {
103            let provider_columns = args.to_provider_columns();
104            let provider_comparison = args.provider_comparison();
105            match provider
106                .sort_rows(&handle, &provider_columns, provider_comparison)
107                .await
108            {
109                Ok(result) => return sortrows_from_provider_result(result),
110                Err(_err) => {
111                    // fall back to host path when provider cannot service the request
112                }
113            }
114        }
115    }
116
117    let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
118    sortrows_real_tensor_with_args(tensor, &args)
119}
120
121fn sortrows_from_provider_result(
122    result: ProviderSortResult,
123) -> crate::BuiltinResult<SortRowsEvaluation> {
124    let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
125        .map_err(|e| sortrows_error(format!("sortrows: {e}")))?;
126    let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
127        .map_err(|e| sortrows_error(format!("sortrows: {e}")))?;
128    Ok(SortRowsEvaluation {
129        sorted: tensor::tensor_into_value(sorted_tensor),
130        indices: indices_tensor,
131    })
132}
133
134fn sortrows_host(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
135    match value {
136        Value::Tensor(tensor) => sortrows_real_tensor(tensor, rest),
137        Value::LogicalArray(logical) => {
138            let tensor = tensor::logical_to_tensor(&logical)
139                .map_err(|e| sortrows_error(e))?;
140            sortrows_real_tensor(tensor, rest)
141        }
142        Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
143            let tensor = tensor::value_into_tensor_for("sortrows", value)
144                .map_err(|e| sortrows_error(e))?;
145            sortrows_real_tensor(tensor, rest)
146        }
147        Value::ComplexTensor(ct) => sortrows_complex_tensor(ct, rest),
148        Value::Complex(re, im) => {
149            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
150                .map_err(|e| sortrows_error(format!("sortrows: {e}")))?;
151            sortrows_complex_tensor(tensor, rest)
152        }
153        Value::CharArray(ca) => sortrows_char_array(ca, rest),
154        other => Err(sortrows_error(format!(
155            "sortrows: unsupported input type {:?}; expected numeric, logical, complex, or char arrays",
156            other
157        ))
158        .into()),
159    }
160}
161
162fn sortrows_real_tensor(
163    tensor: Tensor,
164    rest: &[Value],
165) -> crate::BuiltinResult<SortRowsEvaluation> {
166    ensure_matrix_shape(&tensor.shape)?;
167    let cols = tensor.cols();
168    let args = SortRowsArgs::parse(rest, cols)?;
169    sortrows_real_tensor_with_args(tensor, &args)
170}
171
172fn sortrows_real_tensor_with_args(
173    tensor: Tensor,
174    args: &SortRowsArgs,
175) -> crate::BuiltinResult<SortRowsEvaluation> {
176    let rows = tensor.rows();
177    let cols = tensor.cols();
178
179    if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
180        let indices = identity_indices(rows)?;
181        return Ok(SortRowsEvaluation {
182            sorted: tensor::tensor_into_value(tensor),
183            indices,
184        });
185    }
186
187    let mut order: Vec<usize> = (0..rows).collect();
188    order.sort_by(|&a, &b| compare_real_rows(&tensor, rows, args, a, b));
189
190    let sorted_tensor = reorder_real_rows(&tensor, rows, cols, &order)?;
191    let indices = permutation_indices(&order)?;
192    Ok(SortRowsEvaluation {
193        sorted: tensor::tensor_into_value(sorted_tensor),
194        indices,
195    })
196}
197
198fn sortrows_complex_tensor(
199    tensor: ComplexTensor,
200    rest: &[Value],
201) -> crate::BuiltinResult<SortRowsEvaluation> {
202    ensure_matrix_shape(&tensor.shape)?;
203    let cols = tensor.cols;
204    let args = SortRowsArgs::parse(rest, cols)?;
205    sortrows_complex_tensor_with_args(tensor, &args)
206}
207
208fn sortrows_complex_tensor_with_args(
209    tensor: ComplexTensor,
210    args: &SortRowsArgs,
211) -> crate::BuiltinResult<SortRowsEvaluation> {
212    let rows = tensor.rows;
213    let cols = tensor.cols;
214
215    if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
216        let indices = identity_indices(rows)?;
217        return Ok(SortRowsEvaluation {
218            sorted: complex_tensor_into_value(tensor),
219            indices,
220        });
221    }
222
223    let mut order: Vec<usize> = (0..rows).collect();
224    order.sort_by(|&a, &b| compare_complex_rows(&tensor, rows, args, a, b));
225
226    let sorted_tensor = reorder_complex_rows(&tensor, rows, cols, &order)?;
227    let indices = permutation_indices(&order)?;
228    Ok(SortRowsEvaluation {
229        sorted: complex_tensor_into_value(sorted_tensor),
230        indices,
231    })
232}
233
234fn sortrows_char_array(ca: CharArray, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
235    let cols = ca.cols;
236    let args = SortRowsArgs::parse(rest, cols)?;
237    sortrows_char_array_with_args(ca, &args)
238}
239
240fn sortrows_char_array_with_args(
241    ca: CharArray,
242    args: &SortRowsArgs,
243) -> crate::BuiltinResult<SortRowsEvaluation> {
244    let rows = ca.rows;
245    let cols = ca.cols;
246
247    if rows <= 1 || cols == 0 || ca.data.is_empty() || args.columns.is_empty() {
248        let indices = identity_indices(rows)?;
249        return Ok(SortRowsEvaluation {
250            sorted: Value::CharArray(ca),
251            indices,
252        });
253    }
254
255    let mut order: Vec<usize> = (0..rows).collect();
256    order.sort_by(|&a, &b| compare_char_rows(&ca, args, a, b));
257
258    let sorted = reorder_char_rows(&ca, rows, cols, &order)?;
259    let indices = permutation_indices(&order)?;
260    Ok(SortRowsEvaluation {
261        sorted: Value::CharArray(sorted),
262        indices,
263    })
264}
265
266fn ensure_matrix_shape(shape: &[usize]) -> crate::BuiltinResult<()> {
267    if shape.len() <= 2 {
268        Ok(())
269    } else {
270        Err(sortrows_error("sortrows: input must be a 2-D matrix"))
271    }
272}
273
/// Derive `(rows, cols)` from a shape slice.
///
/// An empty shape is treated as 1x1, a single dimension as a row vector, and
/// for two or more dimensions the first two entries are used.
fn rows_cols_from_shape(shape: &[usize]) -> (usize, usize) {
    match shape {
        [] => (1, 1),
        [len] => (1, *len),
        [rows, cols, ..] => (*rows, *cols),
    }
}
281
282fn compare_real_rows(
283    tensor: &Tensor,
284    rows: usize,
285    args: &SortRowsArgs,
286    a: usize,
287    b: usize,
288) -> Ordering {
289    for spec in &args.columns {
290        if spec.index >= tensor.cols() {
291            continue;
292        }
293        let idx_a = a + spec.index * rows;
294        let idx_b = b + spec.index * rows;
295        let va = tensor.data[idx_a];
296        let vb = tensor.data[idx_b];
297        let missing = args.missing_for_direction(spec.direction);
298        let ord = compare_real_scalars(va, vb, spec.direction, args.comparison, missing);
299        if ord != Ordering::Equal {
300            return ord;
301        }
302    }
303    Ordering::Equal
304}
305
306fn compare_complex_rows(
307    tensor: &ComplexTensor,
308    rows: usize,
309    args: &SortRowsArgs,
310    a: usize,
311    b: usize,
312) -> Ordering {
313    for spec in &args.columns {
314        if spec.index >= tensor.cols {
315            continue;
316        }
317        let idx_a = a + spec.index * rows;
318        let idx_b = b + spec.index * rows;
319        let va = tensor.data[idx_a];
320        let vb = tensor.data[idx_b];
321        let missing = args.missing_for_direction(spec.direction);
322        let ord = compare_complex_scalars(va, vb, spec.direction, args.comparison, missing);
323        if ord != Ordering::Equal {
324            return ord;
325        }
326    }
327    Ordering::Equal
328}
329
330fn compare_char_rows(ca: &CharArray, args: &SortRowsArgs, a: usize, b: usize) -> Ordering {
331    for spec in &args.columns {
332        if spec.index >= ca.cols {
333            continue;
334        }
335        let idx_a = a * ca.cols + spec.index;
336        let idx_b = b * ca.cols + spec.index;
337        let va = ca.data[idx_a];
338        let vb = ca.data[idx_b];
339        let ord = match spec.direction {
340            SortDirection::Ascend => va.cmp(&vb),
341            SortDirection::Descend => vb.cmp(&va),
342        };
343        if ord != Ordering::Equal {
344            return ord;
345        }
346    }
347    Ordering::Equal
348}
349
350fn reorder_real_rows(
351    tensor: &Tensor,
352    rows: usize,
353    cols: usize,
354    order: &[usize],
355) -> crate::BuiltinResult<Tensor> {
356    let mut data = vec![0.0; tensor.data.len()];
357    for col in 0..cols {
358        for (dest_row, &src_row) in order.iter().enumerate() {
359            let src_idx = src_row + col * rows;
360            let dst_idx = dest_row + col * rows;
361            data[dst_idx] = tensor.data[src_idx];
362        }
363    }
364    Tensor::new(data, tensor.shape.clone()).map_err(|e| sortrows_error(format!("sortrows: {e}")))
365}
366
367fn reorder_complex_rows(
368    tensor: &ComplexTensor,
369    rows: usize,
370    cols: usize,
371    order: &[usize],
372) -> crate::BuiltinResult<ComplexTensor> {
373    let mut data = vec![(0.0, 0.0); tensor.data.len()];
374    for col in 0..cols {
375        for (dest_row, &src_row) in order.iter().enumerate() {
376            let src_idx = src_row + col * rows;
377            let dst_idx = dest_row + col * rows;
378            data[dst_idx] = tensor.data[src_idx];
379        }
380    }
381    ComplexTensor::new(data, tensor.shape.clone())
382        .map_err(|e| sortrows_error(format!("sortrows: {e}")))
383}
384
385fn reorder_char_rows(
386    ca: &CharArray,
387    rows: usize,
388    cols: usize,
389    order: &[usize],
390) -> crate::BuiltinResult<CharArray> {
391    let mut data = vec!['\0'; ca.data.len()];
392    for (dest_row, &src_row) in order.iter().enumerate() {
393        for col in 0..cols {
394            let src_idx = src_row * cols + col;
395            let dst_idx = dest_row * cols + col;
396            data[dst_idx] = ca.data[src_idx];
397        }
398    }
399    CharArray::new(data, rows, cols).map_err(|e| sortrows_error(format!("sortrows: {e}")))
400}
401
402fn compare_real_scalars(
403    a: f64,
404    b: f64,
405    direction: SortDirection,
406    comparison: ComparisonMethod,
407    missing: MissingPlacementResolved,
408) -> Ordering {
409    match (a.is_nan(), b.is_nan()) {
410        (true, true) => Ordering::Equal,
411        (true, false) => match missing {
412            MissingPlacementResolved::First => Ordering::Less,
413            MissingPlacementResolved::Last => Ordering::Greater,
414        },
415        (false, true) => match missing {
416            MissingPlacementResolved::First => Ordering::Greater,
417            MissingPlacementResolved::Last => Ordering::Less,
418        },
419        (false, false) => compare_real_finite_scalars(a, b, direction, comparison),
420    }
421}
422
423fn compare_real_finite_scalars(
424    a: f64,
425    b: f64,
426    direction: SortDirection,
427    comparison: ComparisonMethod,
428) -> Ordering {
429    if matches!(comparison, ComparisonMethod::Abs) {
430        let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
431        if abs_cmp != Ordering::Equal {
432            return match direction {
433                SortDirection::Ascend => abs_cmp,
434                SortDirection::Descend => abs_cmp.reverse(),
435            };
436        }
437    }
438    match direction {
439        SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
440        SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
441    }
442}
443
444fn compare_complex_scalars(
445    a: (f64, f64),
446    b: (f64, f64),
447    direction: SortDirection,
448    comparison: ComparisonMethod,
449    missing: MissingPlacementResolved,
450) -> Ordering {
451    match (complex_is_nan(a), complex_is_nan(b)) {
452        (true, true) => Ordering::Equal,
453        (true, false) => match missing {
454            MissingPlacementResolved::First => Ordering::Less,
455            MissingPlacementResolved::Last => Ordering::Greater,
456        },
457        (false, true) => match missing {
458            MissingPlacementResolved::First => Ordering::Greater,
459            MissingPlacementResolved::Last => Ordering::Less,
460        },
461        (false, false) => compare_complex_finite_scalars(a, b, direction, comparison),
462    }
463}
464
465fn compare_complex_finite_scalars(
466    a: (f64, f64),
467    b: (f64, f64),
468    direction: SortDirection,
469    comparison: ComparisonMethod,
470) -> Ordering {
471    match comparison {
472        ComparisonMethod::Real => compare_complex_real_first(a, b, direction),
473        ComparisonMethod::Auto | ComparisonMethod::Abs => {
474            let abs_cmp = complex_abs(a)
475                .partial_cmp(&complex_abs(b))
476                .unwrap_or(Ordering::Equal);
477            if abs_cmp != Ordering::Equal {
478                return match direction {
479                    SortDirection::Ascend => abs_cmp,
480                    SortDirection::Descend => abs_cmp.reverse(),
481                };
482            }
483            compare_complex_real_first(a, b, direction)
484        }
485    }
486}
487
488fn compare_complex_real_first(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
489    let real_cmp = match direction {
490        SortDirection::Ascend => a.0.partial_cmp(&b.0),
491        SortDirection::Descend => b.0.partial_cmp(&a.0),
492    }
493    .unwrap_or(Ordering::Equal);
494    if real_cmp != Ordering::Equal {
495        return real_cmp;
496    }
497    match direction {
498        SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
499        SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
500    }
501}
502
/// Return `true` when either the real or the imaginary component is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
506
/// Magnitude of a complex number given as `(re, im)`, computed with `hypot`.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    re.hypot(im)
}
510
511fn permutation_indices(order: &[usize]) -> crate::BuiltinResult<Tensor> {
512    let rows = order.len();
513    let mut data = Vec::with_capacity(rows);
514    for &idx in order {
515        data.push((idx + 1) as f64);
516    }
517    Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_error(format!("sortrows: {e}")))
518}
519
520fn identity_indices(rows: usize) -> crate::BuiltinResult<Tensor> {
521    let mut data = Vec::with_capacity(rows);
522    for i in 0..rows {
523        data.push((i + 1) as f64);
524    }
525    Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_error(format!("sortrows: {e}")))
526}
527
528fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
529    if tensor.data.len() == 1 {
530        Value::Complex(tensor.data[0].0, tensor.data[0].1)
531    } else {
532        Value::ComplexTensor(tensor)
533    }
534}
535
/// Direction in which a sort key is ordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SortDirection {
    Ascend,
    Descend,
}

impl SortDirection {
    /// Parse a direction keyword, ignoring surrounding whitespace and ASCII
    /// case. Accepts `ascend`/`ascending` and `descend`/`descending`; any
    /// other text yields `None`.
    fn from_str(value: &str) -> Option<Self> {
        let token = value.trim();
        let is_one_of = |names: [&str; 2]| names.iter().any(|name| token.eq_ignore_ascii_case(name));

        if is_one_of(["ascend", "ascending"]) {
            Some(SortDirection::Ascend)
        } else if is_one_of(["descend", "descending"]) {
            Some(SortDirection::Descend)
        } else {
            None
        }
    }
}
551
552#[derive(Debug, Clone, Copy, PartialEq, Eq)]
553enum ComparisonMethod {
554    Auto,
555    Real,
556    Abs,
557}
558
559#[derive(Debug, Clone, Copy, PartialEq, Eq)]
560enum MissingPlacement {
561    Auto,
562    First,
563    Last,
564}
565
566#[derive(Debug, Clone, Copy, PartialEq, Eq)]
567enum MissingPlacementResolved {
568    First,
569    Last,
570}
571
572impl MissingPlacement {
573    fn resolve(self, direction: SortDirection) -> MissingPlacementResolved {
574        match self {
575            MissingPlacement::First => MissingPlacementResolved::First,
576            MissingPlacement::Last => MissingPlacementResolved::Last,
577            MissingPlacement::Auto => match direction {
578                SortDirection::Ascend => MissingPlacementResolved::Last,
579                SortDirection::Descend => MissingPlacementResolved::First,
580            },
581        }
582    }
583
584    fn is_auto(self) -> bool {
585        matches!(self, MissingPlacement::Auto)
586    }
587}
588
589#[derive(Debug, Clone)]
590struct ColumnSpec {
591    index: usize,
592    direction: SortDirection,
593}
594
595#[derive(Debug, Clone)]
596struct SortRowsArgs {
597    columns: Vec<ColumnSpec>,
598    comparison: ComparisonMethod,
599    missing: MissingPlacement,
600}
601
602impl SortRowsArgs {
603    fn parse(rest: &[Value], num_cols: usize) -> crate::BuiltinResult<Self> {
604        let mut columns: Option<Vec<ColumnSpec>> = None;
605        let mut override_direction: Option<SortDirection> = None;
606        let mut comparison = ComparisonMethod::Auto;
607        let mut missing = MissingPlacement::Auto;
608        let mut i = 0usize;
609
610        while i < rest.len() {
611            if columns.is_none() {
612                if let Some(parsed) = parse_column_vector(&rest[i], num_cols)? {
613                    columns = Some(parsed);
614                    i += 1;
615                    continue;
616                }
617            }
618            if let Some(direction) = parse_direction(&rest[i]) {
619                override_direction = Some(direction);
620                i += 1;
621                continue;
622            }
623            let Some(keyword) = tensor::value_to_string(&rest[i]) else {
624                return Err(sortrows_error(format!(
625                    "sortrows: invalid argument {:?}",
626                    rest[i]
627                )));
628            };
629            let lowered = keyword.trim().to_ascii_lowercase();
630            match lowered.as_str() {
631                "comparisonmethod" => {
632                    i += 1;
633                    if i >= rest.len() {
634                        return Err(sortrows_error(
635                            "sortrows: expected a value for 'ComparisonMethod'",
636                        ));
637                    }
638                    let Some(value_str) = tensor::value_to_string(&rest[i]) else {
639                        return Err(sortrows_error(
640                            "sortrows: 'ComparisonMethod' expects a string value",
641                        )
642                        .into());
643                    };
644                    comparison = match value_str.trim().to_ascii_lowercase().as_str() {
645                        "auto" => ComparisonMethod::Auto,
646                        "real" => ComparisonMethod::Real,
647                        "abs" | "magnitude" => ComparisonMethod::Abs,
648                        other => {
649                            return Err(sortrows_error(format!(
650                                "sortrows: unsupported ComparisonMethod '{other}'"
651                            ))
652                            .into())
653                        }
654                    };
655                    i += 1;
656                }
657                "missingplacement" => {
658                    i += 1;
659                    if i >= rest.len() {
660                        return Err(sortrows_error(
661                            "sortrows: expected a value for 'MissingPlacement'",
662                        )
663                        .into());
664                    }
665                    let Some(value_str) = tensor::value_to_string(&rest[i]) else {
666                        return Err(sortrows_error(
667                            "sortrows: 'MissingPlacement' expects a string value",
668                        )
669                        .into());
670                    };
671                    missing = match value_str.trim().to_ascii_lowercase().as_str() {
672                        "auto" => MissingPlacement::Auto,
673                        "first" => MissingPlacement::First,
674                        "last" => MissingPlacement::Last,
675                        other => {
676                            return Err(sortrows_error(format!(
677                                "sortrows: unsupported MissingPlacement '{other}'"
678                            ))
679                            .into())
680                        }
681                    };
682                    i += 1;
683                }
684                other => {
685                    return Err(sortrows_error(format!(
686                        "sortrows: unexpected argument '{other}'"
687                    )));
688                }
689            }
690        }
691
692        let mut columns = columns.unwrap_or_else(|| default_columns(num_cols));
693        if let Some(dir) = override_direction {
694            for spec in &mut columns {
695                spec.direction = dir;
696            }
697        }
698        validate_columns(&columns, num_cols)?;
699
700        Ok(SortRowsArgs {
701            columns,
702            comparison,
703            missing,
704        })
705    }
706
707    fn to_provider_columns(&self) -> Vec<ProviderSortRowsColumnSpec> {
708        self.columns
709            .iter()
710            .map(|spec| ProviderSortRowsColumnSpec {
711                index: spec.index,
712                order: match spec.direction {
713                    SortDirection::Ascend => ProviderSortOrder::Ascend,
714                    SortDirection::Descend => ProviderSortOrder::Descend,
715                },
716            })
717            .collect()
718    }
719
720    fn provider_comparison(&self) -> ProviderSortComparison {
721        match self.comparison {
722            ComparisonMethod::Auto => ProviderSortComparison::Auto,
723            ComparisonMethod::Real => ProviderSortComparison::Real,
724            ComparisonMethod::Abs => ProviderSortComparison::Abs,
725        }
726    }
727
728    fn missing_for_direction(&self, direction: SortDirection) -> MissingPlacementResolved {
729        self.missing.resolve(direction)
730    }
731
732    fn missing_is_auto(&self) -> bool {
733        self.missing.is_auto()
734    }
735}
736
737fn parse_column_vector(
738    value: &Value,
739    num_cols: usize,
740) -> crate::BuiltinResult<Option<Vec<ColumnSpec>>> {
741    match value {
742        Value::Int(i) => parse_single_column(i.to_i64(), num_cols).map(Some),
743        Value::Num(n) => {
744            if !n.is_finite() {
745                return Err(sortrows_error("sortrows: column indices must be finite"));
746            }
747            let rounded = n.round();
748            if (rounded - n).abs() > f64::EPSILON {
749                return Err(sortrows_error("sortrows: column indices must be integers"));
750            }
751            parse_single_column(rounded as i64, num_cols).map(Some)
752        }
753        Value::Tensor(tensor) => {
754            if !is_vector(&tensor.shape) {
755                return Err(sortrows_error(
756                    "sortrows: column specification must be a vector",
757                ));
758            }
759            let mut specs = Vec::with_capacity(tensor.data.len());
760            for &entry in &tensor.data {
761                if !entry.is_finite() {
762                    return Err(sortrows_error("sortrows: column indices must be finite"));
763                }
764                let rounded = entry.round();
765                if (rounded - entry).abs() > f64::EPSILON {
766                    return Err(sortrows_error("sortrows: column indices must be integers"));
767                }
768                let column = parse_single_column_i64(rounded as i64, num_cols)?;
769                specs.push(column);
770            }
771            Ok(Some(specs))
772        }
773        _ => Ok(None),
774    }
775}
776
777fn parse_single_column(value: i64, num_cols: usize) -> crate::BuiltinResult<Vec<ColumnSpec>> {
778    parse_single_column_i64(value, num_cols).map(|spec| vec![spec])
779}
780
781fn parse_single_column_i64(value: i64, num_cols: usize) -> crate::BuiltinResult<ColumnSpec> {
782    if value == 0 {
783        return Err(sortrows_error("sortrows: column indices must be non-zero"));
784    }
785    let abs = value.unsigned_abs() as usize;
786    if abs == 0 {
787        return Err(sortrows_error("sortrows: column indices must be >= 1"));
788    }
789    if num_cols == 0 {
790        return Err(sortrows_error(
791            "sortrows: column index exceeds matrix with 0 columns",
792        ));
793    }
794    if abs > num_cols {
795        return Err(sortrows_error(format!(
796            "sortrows: column index {} exceeds matrix with {} columns",
797            abs, num_cols
798        ))
799        .into());
800    }
801    let direction = if value > 0 {
802        SortDirection::Ascend
803    } else {
804        SortDirection::Descend
805    };
806    Ok(ColumnSpec {
807        index: abs - 1,
808        direction,
809    })
810}
811
812fn parse_direction(value: &Value) -> Option<SortDirection> {
813    tensor::value_to_string(value).and_then(|s| SortDirection::from_str(&s))
814}
815
816fn default_columns(num_cols: usize) -> Vec<ColumnSpec> {
817    let mut columns = Vec::with_capacity(num_cols);
818    for col in 0..num_cols {
819        columns.push(ColumnSpec {
820            index: col,
821            direction: SortDirection::Ascend,
822        });
823    }
824    columns
825}
826
827fn validate_columns(columns: &[ColumnSpec], num_cols: usize) -> crate::BuiltinResult<()> {
828    if num_cols == 0 && columns.iter().any(|spec| spec.index > 0) {
829        return Err(sortrows_error(
830            "sortrows: column index exceeds matrix with 0 columns",
831        ));
832    }
833    for spec in columns {
834        if num_cols > 0 && spec.index >= num_cols {
835            return Err(sortrows_error(format!(
836                "sortrows: column index {} exceeds matrix with {} columns",
837                spec.index + 1,
838                num_cols
839            ))
840            .into());
841        }
842    }
843    Ok(())
844}
845
/// Whether `shape` describes a vector: zero or one dimension, or two
/// dimensions with at least one of them equal to 1.
fn is_vector(shape: &[usize]) -> bool {
    match shape {
        [] | [_] => true,
        [rows, cols] => *rows == 1 || *cols == 1,
        _ => false,
    }
}
854
855#[derive(Debug)]
856pub struct SortRowsEvaluation {
857    sorted: Value,
858    indices: Tensor,
859}
860
861impl SortRowsEvaluation {
862    pub fn into_sorted_value(self) -> Value {
863        self.sorted
864    }
865
866    pub fn into_values(self) -> (Value, Value) {
867        let indices = tensor::tensor_into_value(self.indices);
868        (self.sorted, indices)
869    }
870
871    pub fn indices_value(&self) -> Value {
872        tensor::tensor_into_value(self.indices.clone())
873    }
874}
875
876#[cfg(test)]
877pub(crate) mod tests {
878    use super::*;
879    use crate::builtins::common::test_support;
880    use runmat_builtins::{IntValue, ResolveContext, Type, Value};
881
882    fn error_message(err: crate::RuntimeError) -> String {
883        err.message().to_string()
884    }
885
886    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
887        futures::executor::block_on(super::evaluate(value, rest))
888    }
889
890    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
891    #[test]
892    fn sortrows_default_matrix() {
893        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
894        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
895        let (sorted, indices) = eval.into_values();
896        match sorted {
897            Value::Tensor(t) => {
898                assert_eq!(t.shape, vec![3, 2]);
899                assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]);
900            }
901            other => panic!("expected tensor, got {other:?}"),
902        }
903        match indices {
904            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
905            Value::Num(_) => panic!("expected tensor indices"),
906            other => panic!("unexpected indices {other:?}"),
907        }
908    }
909
910    #[test]
911    fn sortrows_type_resolver_tensor() {
912        assert_eq!(
913            tensor_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
914            Type::tensor()
915        );
916    }
917
918    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
919    #[test]
920    fn sortrows_with_column_vector() {
921        let tensor = Tensor::new(
922            vec![1.0, 3.0, 3.0, 4.0, 2.0, 2.0, 2.0, 5.0, 1.0],
923            vec![3, 3],
924        )
925        .unwrap();
926        let cols = Tensor::new(vec![2.0, 3.0, 1.0], vec![3, 1]).unwrap();
927        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
928        let (sorted, _) = eval.into_values();
929        match sorted {
930            Value::Tensor(t) => {
931                assert_eq!(t.data, vec![3.0, 3.0, 1.0, 2.0, 2.0, 4.0, 1.0, 5.0, 2.0]);
932            }
933            other => panic!("expected tensor, got {other:?}"),
934        }
935    }
936
937    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
938    #[test]
939    fn sortrows_direction_descend() {
940        let tensor = Tensor::new(vec![1.0, 2.0, 4.0, 3.0], vec![2, 2]).unwrap();
941        let eval = evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
942        let (sorted, _) = eval.into_values();
943        match sorted {
944            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 3.0, 4.0]),
945            other => panic!("expected tensor, got {other:?}"),
946        }
947    }
948
949    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
950    #[test]
951    fn sortrows_mixed_directions() {
952        let tensor = Tensor::new(vec![1.0, 1.0, 1.0, 1.0, 7.0, 2.0], vec![3, 2]).unwrap();
953        let cols = Tensor::new(vec![1.0, -2.0], vec![2, 1]).unwrap();
954        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
955        let (sorted, _) = eval.into_values();
956        match sorted {
957            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 1.0, 1.0, 7.0, 2.0, 1.0]),
958            other => panic!("expected tensor, got {other:?}"),
959        }
960    }
961
962    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
963    #[test]
964    fn sortrows_returns_indices() {
965        let tensor = Tensor::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
966        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
967        let (_, indices) = eval.into_values();
968        match indices {
969            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
970            Value::Num(_) => panic!("expected tensor indices"),
971            other => panic!("unexpected indices {other:?}"),
972        }
973    }
974
975    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
976    #[test]
977    fn sortrows_char_array() {
978        let chars = CharArray::new(
979            "bob "
980                .chars()
981                .chain("al  ".chars())
982                .chain("ally".chars())
983                .collect(),
984            3,
985            4,
986        )
987        .unwrap();
988        let eval = evaluate(Value::CharArray(chars), &[]).expect("evaluate");
989        let (sorted, _) = eval.into_values();
990        match sorted {
991            Value::CharArray(ca) => {
992                assert_eq!(ca.rows, 3);
993                assert_eq!(ca.cols, 4);
994                let strings: Vec<String> = (0..ca.rows)
995                    .map(|r| {
996                        ca.data[r * ca.cols..(r + 1) * ca.cols]
997                            .iter()
998                            .collect::<String>()
999                    })
1000                    .collect();
1001                assert_eq!(
1002                    strings,
1003                    vec!["al  ".to_string(), "ally".to_string(), "bob ".to_string()]
1004                );
1005            }
1006            other => panic!("expected char array, got {other:?}"),
1007        }
1008    }
1009
1010    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1011    #[test]
1012    fn sortrows_complex_abs() {
1013        let tensor = ComplexTensor::new(vec![(1.0, 2.0), (-2.0, 1.0)], vec![2, 1]).unwrap();
1014        let eval = evaluate(
1015            Value::ComplexTensor(tensor),
1016            &[Value::from("ComparisonMethod"), Value::from("abs")],
1017        )
1018        .expect("evaluate");
1019        let (sorted, _) = eval.into_values();
1020        match sorted {
1021            Value::ComplexTensor(ct) => {
1022                assert_eq!(ct.data, vec![(-2.0, 1.0), (1.0, 2.0)]);
1023            }
1024            other => panic!("expected complex tensor, got {other:?}"),
1025        }
1026    }
1027
1028    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1029    #[test]
1030    fn sortrows_invalid_column_index_errors() {
1031        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1032        let err = error_message(
1033            evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(3))]).unwrap_err(),
1034        );
1035        assert!(
1036            err.contains("column index"),
1037            "unexpected error message: {err}"
1038        );
1039    }
1040
1041    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1042    #[test]
1043    fn sortrows_missingplacement_first_moves_nan_first() {
1044        let tensor = Tensor::new(vec![1.0, f64::NAN, 2.0, 3.0], vec![2, 2]).unwrap();
1045        let eval = evaluate(
1046            Value::Tensor(tensor),
1047            &[Value::from("MissingPlacement"), Value::from("first")],
1048        )
1049        .expect("evaluate");
1050        let (sorted, indices) = eval.into_values();
1051        match sorted {
1052            Value::Tensor(t) => {
1053                assert!(t.data[0].is_nan());
1054                assert_eq!(t.data[1], 1.0);
1055                assert_eq!(t.data[2], 3.0);
1056                assert_eq!(t.data[3], 2.0);
1057            }
1058            other => panic!("expected tensor, got {other:?}"),
1059        }
1060        match indices {
1061            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
1062            Value::Num(_) => panic!("expected tensor indices"),
1063            other => panic!("unexpected indices {other:?}"),
1064        }
1065    }
1066
1067    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1068    #[test]
1069    fn sortrows_missingplacement_last_descend_moves_nan_last() {
1070        let tensor = Tensor::new(vec![f64::NAN, 5.0, 1.0, 2.0], vec![2, 2]).unwrap();
1071        let eval = evaluate(
1072            Value::Tensor(tensor),
1073            &[
1074                Value::from("descend"),
1075                Value::from("MissingPlacement"),
1076                Value::from("last"),
1077            ],
1078        )
1079        .expect("evaluate");
1080        let (sorted, indices) = eval.into_values();
1081        match sorted {
1082            Value::Tensor(t) => {
1083                assert_eq!(t.data[0], 5.0);
1084                assert!(t.data[1].is_nan());
1085                assert_eq!(t.data[2], 2.0);
1086                assert_eq!(t.data[3], 1.0);
1087            }
1088            other => panic!("expected tensor, got {other:?}"),
1089        }
1090        match indices {
1091            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
1092            Value::Num(_) => panic!("expected tensor indices"),
1093            other => panic!("unexpected indices {other:?}"),
1094        }
1095    }
1096
1097    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1098    #[test]
1099    fn sortrows_missingplacement_invalid_value_errors() {
1100        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1101        let err = error_message(
1102            evaluate(
1103                Value::Tensor(tensor),
1104                &[Value::from("MissingPlacement"), Value::from("middle")],
1105            )
1106            .unwrap_err(),
1107        );
1108        assert!(
1109            err.contains("MissingPlacement"),
1110            "unexpected error message: {err}"
1111        );
1112    }
1113
1114    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1115    #[test]
1116    fn sortrows_gpu_roundtrip() {
1117        test_support::with_test_provider(|provider| {
1118            let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
1119            let view = runmat_accelerate_api::HostTensorView {
1120                data: &tensor.data,
1121                shape: &tensor.shape,
1122            };
1123            let handle = provider.upload(&view).expect("upload");
1124            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
1125            let (sorted, indices) = eval.into_values();
1126            match sorted {
1127                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]),
1128                other => panic!("expected tensor, got {other:?}"),
1129            }
1130            match indices {
1131                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
1132                other => panic!("unexpected indices {other:?}"),
1133            }
1134        });
1135    }
1136
1137    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1138    #[test]
1139    #[cfg(feature = "wgpu")]
1140    fn sortrows_wgpu_matches_cpu() {
1141        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1142            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1143        );
1144
1145        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0, 2.0, 5.0], vec![3, 2]).unwrap();
1146        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
1147        let (cpu_sorted_val, cpu_indices_val) = cpu_eval.into_values();
1148        let cpu_sorted = match cpu_sorted_val {
1149            Value::Tensor(t) => t,
1150            other => panic!("expected tensor, got {other:?}"),
1151        };
1152        let cpu_indices = match cpu_indices_val {
1153            Value::Tensor(t) => t,
1154            other => panic!("expected tensor indices, got {other:?}"),
1155        };
1156
1157        let view = runmat_accelerate_api::HostTensorView {
1158            data: &tensor.data,
1159            shape: &tensor.shape,
1160        };
1161        let provider = runmat_accelerate_api::provider().expect("provider");
1162        let handle = provider.upload(&view).expect("upload");
1163        let gpu_eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("gpu evaluate");
1164        let (gpu_sorted_val, gpu_indices_val) = gpu_eval.into_values();
1165        let gpu_sorted = match gpu_sorted_val {
1166            Value::Tensor(t) => t,
1167            other => panic!("expected tensor, got {other:?}"),
1168        };
1169        let gpu_indices = match gpu_indices_val {
1170            Value::Tensor(t) => t,
1171            other => panic!("expected tensor indices, got {other:?}"),
1172        };
1173
1174        assert_eq!(gpu_sorted.shape, cpu_sorted.shape);
1175        assert_eq!(gpu_sorted.data, cpu_sorted.data);
1176        assert_eq!(gpu_indices.shape, cpu_indices.shape);
1177        assert_eq!(gpu_indices.data, cpu_indices.data);
1178
1179        let _ = provider.free(&handle);
1180    }
1181}