Skip to main content

runmat_runtime/builtins/array/sorting_sets/
sortrows.rs

1//! MATLAB-compatible `sortrows` builtin with GPU-aware semantics.
2
3use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6    GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7    SortResult as ProviderSortResult, SortRowsColumnSpec as ProviderSortRowsColumnSpec,
8};
9use runmat_builtins::{
10    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
11    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
12    CharArray, ComplexTensor, Tensor, Value,
13};
14use runmat_macros::runtime_builtin;
15
16use super::type_resolvers::tensor_output_type;
17use crate::build_runtime_error;
18use crate::builtins::common::gpu_helpers;
19use crate::builtins::common::spec::{
20    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24
/// GPU capability metadata for `sortrows`.
///
/// Providers can service the sort through the custom `sortrows` hook. `sortrows_gpu` only
/// tries the provider when `MissingPlacement` is left at its automatic default; explicit
/// placements (and any provider error) take the host path instead.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::sortrows")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "sortrows",
    op_kind: GpuOpKind::Custom("sortrows"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("sortrows")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes:
        "Providers may implement a row-sort kernel; explicit MissingPlacement overrides fall back to host memory until native support exists.",
};

/// Fusion metadata for `sortrows`.
///
/// No elementwise or reduction template is provided (both are `None`), so the planner
/// cannot fuse this builtin into a kernel chain.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::sortrows"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "sortrows",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`sortrows` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
};

/// Builtin name attached to every error built by `sortrows_error_with`.
const BUILTIN_NAME: &str = "sortrows";
56
57const SORTROWS_OUTPUT_B: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
58    name: "B",
59    ty: BuiltinParamType::Any,
60    arity: BuiltinParamArity::Required,
61    default: None,
62    description: "Sorted input rows.",
63}];
64
65const SORTROWS_OUTPUT_BI: [BuiltinParamDescriptor; 2] = [
66    BuiltinParamDescriptor {
67        name: "B",
68        ty: BuiltinParamType::Any,
69        arity: BuiltinParamArity::Required,
70        default: None,
71        description: "Sorted input rows.",
72    },
73    BuiltinParamDescriptor {
74        name: "I",
75        ty: BuiltinParamType::NumericArray,
76        arity: BuiltinParamArity::Required,
77        default: None,
78        description: "Permutation indices mapping sorted rows to original rows.",
79    },
80];
81
82const SORTROWS_INPUTS_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
83    name: "A",
84    ty: BuiltinParamType::Any,
85    arity: BuiltinParamArity::Required,
86    default: None,
87    description: "Input matrix to sort by rows.",
88}];
89
90const SORTROWS_INPUTS_A_COLUMNS: [BuiltinParamDescriptor; 2] = [
91    BuiltinParamDescriptor {
92        name: "A",
93        ty: BuiltinParamType::Any,
94        arity: BuiltinParamArity::Required,
95        default: None,
96        description: "Input matrix to sort by rows.",
97    },
98    BuiltinParamDescriptor {
99        name: "column",
100        ty: BuiltinParamType::NumericArray,
101        arity: BuiltinParamArity::Required,
102        default: None,
103        description: "Column specification vector (negative entries request descending order).",
104    },
105];
106
107const SORTROWS_INPUTS_A_DIRECTION: [BuiltinParamDescriptor; 2] = [
108    BuiltinParamDescriptor {
109        name: "A",
110        ty: BuiltinParamType::Any,
111        arity: BuiltinParamArity::Required,
112        default: None,
113        description: "Input matrix to sort by rows.",
114    },
115    BuiltinParamDescriptor {
116        name: "direction",
117        ty: BuiltinParamType::StringScalar,
118        arity: BuiltinParamArity::Required,
119        default: Some("\"ascend\""),
120        description: "Global row direction override: 'ascend' or 'descend'.",
121    },
122];
123
124const SORTROWS_INPUTS_A_COLUMNS_DIRECTION: [BuiltinParamDescriptor; 3] = [
125    BuiltinParamDescriptor {
126        name: "A",
127        ty: BuiltinParamType::Any,
128        arity: BuiltinParamArity::Required,
129        default: None,
130        description: "Input matrix to sort by rows.",
131    },
132    BuiltinParamDescriptor {
133        name: "column",
134        ty: BuiltinParamType::NumericArray,
135        arity: BuiltinParamArity::Required,
136        default: None,
137        description: "Column specification vector (negative entries request descending order).",
138    },
139    BuiltinParamDescriptor {
140        name: "direction",
141        ty: BuiltinParamType::StringScalar,
142        arity: BuiltinParamArity::Required,
143        default: Some("\"ascend\""),
144        description: "Global row direction override: 'ascend' or 'descend'.",
145    },
146];
147
148const SORTROWS_INPUTS_COMPARISON_METHOD: [BuiltinParamDescriptor; 4] = [
149    BuiltinParamDescriptor {
150        name: "A",
151        ty: BuiltinParamType::Any,
152        arity: BuiltinParamArity::Required,
153        default: None,
154        description: "Input matrix to sort by rows.",
155    },
156    BuiltinParamDescriptor {
157        name: "arg",
158        ty: BuiltinParamType::Any,
159        arity: BuiltinParamArity::Variadic,
160        default: None,
161        description: "Optional column and direction arguments.",
162    },
163    BuiltinParamDescriptor {
164        name: "name",
165        ty: BuiltinParamType::StringScalar,
166        arity: BuiltinParamArity::Required,
167        default: Some("\"ComparisonMethod\""),
168        description: "Name-value option key.",
169    },
170    BuiltinParamDescriptor {
171        name: "method",
172        ty: BuiltinParamType::StringScalar,
173        arity: BuiltinParamArity::Required,
174        default: Some("\"auto\""),
175        description: "Comparison method: 'auto', 'real', or 'abs'.",
176    },
177];
178
179const SORTROWS_INPUTS_MISSING_PLACEMENT: [BuiltinParamDescriptor; 4] = [
180    BuiltinParamDescriptor {
181        name: "A",
182        ty: BuiltinParamType::Any,
183        arity: BuiltinParamArity::Required,
184        default: None,
185        description: "Input matrix to sort by rows.",
186    },
187    BuiltinParamDescriptor {
188        name: "arg",
189        ty: BuiltinParamType::Any,
190        arity: BuiltinParamArity::Variadic,
191        default: None,
192        description: "Optional column and direction arguments.",
193    },
194    BuiltinParamDescriptor {
195        name: "name",
196        ty: BuiltinParamType::StringScalar,
197        arity: BuiltinParamArity::Required,
198        default: Some("\"MissingPlacement\""),
199        description: "Name-value option key.",
200    },
201    BuiltinParamDescriptor {
202        name: "placement",
203        ty: BuiltinParamType::StringScalar,
204        arity: BuiltinParamArity::Required,
205        default: Some("\"auto\""),
206        description: "NaN placement policy: 'auto', 'first', or 'last'.",
207    },
208];
209
/// Call forms advertised through [`SORTROWS_DESCRIPTOR`].
///
/// The six input forms are each listed twice: once with the single output `B`, and once
/// with the two outputs `[B, I]`.
const SORTROWS_SIGNATURES: [BuiltinSignatureDescriptor; 12] = [
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A)",
        inputs: &SORTROWS_INPUTS_A,
        outputs: &SORTROWS_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A, column)",
        inputs: &SORTROWS_INPUTS_A_COLUMNS,
        outputs: &SORTROWS_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A, direction)",
        inputs: &SORTROWS_INPUTS_A_DIRECTION,
        outputs: &SORTROWS_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A, column, direction)",
        inputs: &SORTROWS_INPUTS_A_COLUMNS_DIRECTION,
        outputs: &SORTROWS_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A, ..., \"ComparisonMethod\", method)",
        inputs: &SORTROWS_INPUTS_COMPARISON_METHOD,
        outputs: &SORTROWS_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A, ..., \"MissingPlacement\", placement)",
        inputs: &SORTROWS_INPUTS_MISSING_PLACEMENT,
        outputs: &SORTROWS_OUTPUT_B,
    },
    // Two-output forms: same inputs, plus the permutation vector `I`.
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A)",
        inputs: &SORTROWS_INPUTS_A,
        outputs: &SORTROWS_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A, column)",
        inputs: &SORTROWS_INPUTS_A_COLUMNS,
        outputs: &SORTROWS_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A, direction)",
        inputs: &SORTROWS_INPUTS_A_DIRECTION,
        outputs: &SORTROWS_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A, column, direction)",
        inputs: &SORTROWS_INPUTS_A_COLUMNS_DIRECTION,
        outputs: &SORTROWS_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A, ..., \"ComparisonMethod\", method)",
        inputs: &SORTROWS_INPUTS_COMPARISON_METHOD,
        outputs: &SORTROWS_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A, ..., \"MissingPlacement\", placement)",
        inputs: &SORTROWS_INPUTS_MISSING_PLACEMENT,
        outputs: &SORTROWS_OUTPUT_BI,
    },
];
272
// Error catalogue for `sortrows`. Each descriptor's `identifier` is attached to the
// runtime error by `sortrows_error_with`; `message` is the default text used by
// `sortrows_error`.

/// Bad column specification index.
const SORTROWS_ERROR_INVALID_COLUMN_INDEX: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.INVALID_COLUMN_INDEX",
    identifier: Some("RunMat:sortrows:InvalidColumnIndex"),
    when: "Column specification indices are out of range, zero, or otherwise invalid.",
    message: "sortrows: invalid column index",
};

/// Unrecognised `MissingPlacement` value.
const SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.MISSING_PLACEMENT_UNKNOWN",
    identifier: Some("RunMat:sortrows:MissingPlacementUnknown"),
    when: "MissingPlacement option value is unsupported.",
    message: "sortrows: unsupported MissingPlacement value",
};

/// Malformed optional arguments.
const SORTROWS_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.INVALID_ARGUMENT",
    identifier: Some("RunMat:sortrows:InvalidArgument"),
    when: "Option parsing receives invalid argument kinds or malformed name-value pairs.",
    message: "sortrows: invalid argument",
};

/// Unrecognised `ComparisonMethod` value.
const SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.COMPARISON_METHOD_UNKNOWN",
    identifier: Some("RunMat:sortrows:ComparisonMethodUnknown"),
    when: "ComparisonMethod option value is unsupported.",
    message: "sortrows: unsupported ComparisonMethod value",
};

/// Input value kind that `sortrows_host` cannot sort.
const SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.UNSUPPORTED_INPUT_TYPE",
    identifier: Some("RunMat:sortrows:UnsupportedInputType"),
    when: "Input cannot be converted to numeric, logical, complex, or char matrix domain.",
    message: "sortrows: unsupported input type",
};

/// Input with more than two dimensions (raised by `ensure_matrix_shape`).
const SORTROWS_ERROR_MATRIX_REQUIRED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.MATRIX_REQUIRED",
    identifier: Some("RunMat:sortrows:MatrixRequired"),
    when: "Input has rank greater than 2 where matrix input is required.",
    message: "sortrows: input must be a 2-D matrix",
};

/// Internal failure (tensor construction, conversion, provider decode).
const SORTROWS_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.INTERNAL",
    identifier: Some("RunMat:sortrows:Internal"),
    when: "Internal conversion/allocation/provider decode fails.",
    message: "sortrows: internal operation failed",
};

/// All error descriptors published through [`SORTROWS_DESCRIPTOR`].
const SORTROWS_ERRORS: [BuiltinErrorDescriptor; 7] = [
    SORTROWS_ERROR_INVALID_COLUMN_INDEX,
    SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN,
    SORTROWS_ERROR_INVALID_ARGUMENT,
    SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN,
    SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE,
    SORTROWS_ERROR_MATRIX_REQUIRED,
    SORTROWS_ERROR_INTERNAL,
];
331
/// Public descriptor for `sortrows`: its call forms, its error catalogue, and the fact that
/// the number of returned values follows the caller's requested output count.
pub const SORTROWS_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &SORTROWS_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &SORTROWS_ERRORS,
};
338
339fn sortrows_error_with(
340    error: &'static BuiltinErrorDescriptor,
341    message: impl Into<String>,
342) -> crate::RuntimeError {
343    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
344    if let Some(identifier) = error.identifier {
345        builder = builder.with_identifier(identifier);
346    }
347    builder.build()
348}
349
350fn sortrows_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
351    sortrows_error_with(error, error.message)
352}
353
354fn sortrows_internal_error(message: impl Into<String>) -> crate::RuntimeError {
355    sortrows_error_with(&SORTROWS_ERROR_INTERNAL, message)
356}
357
358#[runtime_builtin(
359    name = "sortrows",
360    category = "array/sorting_sets",
361    summary = "Sort matrix rows lexicographically with column and direction controls.",
362    keywords = "sortrows,row sort,lexicographic,gpu",
363    accel = "sink",
364    sink = true,
365    type_resolver(tensor_output_type),
366    descriptor(crate::builtins::array::sorting_sets::sortrows::SORTROWS_DESCRIPTOR),
367    builtin_path = "crate::builtins::array::sorting_sets::sortrows"
368)]
369async fn sortrows_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
370    let eval = evaluate(value, &rest).await?;
371    if let Some(out_count) = crate::output_count::current_output_count() {
372        if out_count == 0 {
373            return Ok(Value::OutputList(Vec::new()));
374        }
375        let (sorted, indices) = eval.into_values();
376        let mut outputs = vec![sorted];
377        if out_count >= 2 {
378            outputs.push(indices);
379        }
380        return Ok(crate::output_count::output_list_with_padding(
381            out_count, outputs,
382        ));
383    }
384    Ok(eval.into_sorted_value())
385}
386
387/// Evaluate the `sortrows` builtin once and expose both outputs.
388pub async fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
389    match value {
390        Value::GpuTensor(handle) => sortrows_gpu(handle, rest).await,
391        other => sortrows_host(other, rest),
392    }
393}
394
395async fn sortrows_gpu(
396    handle: GpuTensorHandle,
397    rest: &[Value],
398) -> crate::BuiltinResult<SortRowsEvaluation> {
399    ensure_matrix_shape(&handle.shape)?;
400    let (_, cols) = rows_cols_from_shape(&handle.shape);
401    let args = SortRowsArgs::parse(rest, cols)?;
402
403    if args.missing_is_auto() {
404        if let Some(provider) = runmat_accelerate_api::provider() {
405            let provider_columns = args.to_provider_columns();
406            let provider_comparison = args.provider_comparison();
407            match provider
408                .sort_rows(&handle, &provider_columns, provider_comparison)
409                .await
410            {
411                Ok(result) => return sortrows_from_provider_result(result),
412                Err(_err) => {
413                    // fall back to host path when provider cannot service the request
414                }
415            }
416        }
417    }
418
419    let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
420    sortrows_real_tensor_with_args(tensor, &args)
421}
422
423fn sortrows_from_provider_result(
424    result: ProviderSortResult,
425) -> crate::BuiltinResult<SortRowsEvaluation> {
426    let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
427        .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
428    let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
429        .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
430    Ok(SortRowsEvaluation {
431        sorted: tensor::tensor_into_value(sorted_tensor),
432        indices: indices_tensor,
433    })
434}
435
436fn sortrows_host(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
437    match value {
438        Value::Tensor(tensor) => sortrows_real_tensor(tensor, rest),
439        Value::LogicalArray(logical) => {
440            let tensor = tensor::logical_to_tensor(&logical)
441                .map_err(|e| sortrows_internal_error(e))?;
442            sortrows_real_tensor(tensor, rest)
443        }
444        Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
445            let tensor = tensor::value_into_tensor_for("sortrows", value)
446                .map_err(|e| sortrows_internal_error(e))?;
447            sortrows_real_tensor(tensor, rest)
448        }
449        Value::ComplexTensor(ct) => sortrows_complex_tensor(ct, rest),
450        Value::Complex(re, im) => {
451            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
452                .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
453            sortrows_complex_tensor(tensor, rest)
454        }
455        Value::CharArray(ca) => sortrows_char_array(ca, rest),
456        Value::Object(obj) if obj.is_class(crate::builtins::table::TABLE_CLASS) => {
457            let (sorted, indices) =
458                crate::builtins::table::sortrows_table(Value::Object(obj), rest)?;
459            Ok(SortRowsEvaluation::from_parts(sorted, indices))
460        }
461        other => Err(sortrows_error_with(
462            &SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE,
463            format!(
464                "sortrows: unsupported input type {:?}; expected numeric, logical, complex, or char arrays",
465                other
466            ),
467        )
468        .into()),
469    }
470}
471
472fn sortrows_real_tensor(
473    tensor: Tensor,
474    rest: &[Value],
475) -> crate::BuiltinResult<SortRowsEvaluation> {
476    ensure_matrix_shape(&tensor.shape)?;
477    let cols = tensor.cols();
478    let args = SortRowsArgs::parse(rest, cols)?;
479    sortrows_real_tensor_with_args(tensor, &args)
480}
481
482fn sortrows_real_tensor_with_args(
483    tensor: Tensor,
484    args: &SortRowsArgs,
485) -> crate::BuiltinResult<SortRowsEvaluation> {
486    let rows = tensor.rows();
487    let cols = tensor.cols();
488
489    if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
490        let indices = identity_indices(rows)?;
491        return Ok(SortRowsEvaluation {
492            sorted: tensor::tensor_into_value(tensor),
493            indices,
494        });
495    }
496
497    let mut order: Vec<usize> = (0..rows).collect();
498    order.sort_by(|&a, &b| compare_real_rows(&tensor, rows, args, a, b));
499
500    let sorted_tensor = reorder_real_rows(&tensor, rows, cols, &order)?;
501    let indices = permutation_indices(&order)?;
502    Ok(SortRowsEvaluation {
503        sorted: tensor::tensor_into_value(sorted_tensor),
504        indices,
505    })
506}
507
508fn sortrows_complex_tensor(
509    tensor: ComplexTensor,
510    rest: &[Value],
511) -> crate::BuiltinResult<SortRowsEvaluation> {
512    ensure_matrix_shape(&tensor.shape)?;
513    let cols = tensor.cols;
514    let args = SortRowsArgs::parse(rest, cols)?;
515    sortrows_complex_tensor_with_args(tensor, &args)
516}
517
518fn sortrows_complex_tensor_with_args(
519    tensor: ComplexTensor,
520    args: &SortRowsArgs,
521) -> crate::BuiltinResult<SortRowsEvaluation> {
522    let rows = tensor.rows;
523    let cols = tensor.cols;
524
525    if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
526        let indices = identity_indices(rows)?;
527        return Ok(SortRowsEvaluation {
528            sorted: complex_tensor_into_value(tensor),
529            indices,
530        });
531    }
532
533    let mut order: Vec<usize> = (0..rows).collect();
534    order.sort_by(|&a, &b| compare_complex_rows(&tensor, rows, args, a, b));
535
536    let sorted_tensor = reorder_complex_rows(&tensor, rows, cols, &order)?;
537    let indices = permutation_indices(&order)?;
538    Ok(SortRowsEvaluation {
539        sorted: complex_tensor_into_value(sorted_tensor),
540        indices,
541    })
542}
543
544fn sortrows_char_array(ca: CharArray, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
545    let cols = ca.cols;
546    let args = SortRowsArgs::parse(rest, cols)?;
547    sortrows_char_array_with_args(ca, &args)
548}
549
550fn sortrows_char_array_with_args(
551    ca: CharArray,
552    args: &SortRowsArgs,
553) -> crate::BuiltinResult<SortRowsEvaluation> {
554    let rows = ca.rows;
555    let cols = ca.cols;
556
557    if rows <= 1 || cols == 0 || ca.data.is_empty() || args.columns.is_empty() {
558        let indices = identity_indices(rows)?;
559        return Ok(SortRowsEvaluation {
560            sorted: Value::CharArray(ca),
561            indices,
562        });
563    }
564
565    let mut order: Vec<usize> = (0..rows).collect();
566    order.sort_by(|&a, &b| compare_char_rows(&ca, args, a, b));
567
568    let sorted = reorder_char_rows(&ca, rows, cols, &order)?;
569    let indices = permutation_indices(&order)?;
570    Ok(SortRowsEvaluation {
571        sorted: Value::CharArray(sorted),
572        indices,
573    })
574}
575
576fn ensure_matrix_shape(shape: &[usize]) -> crate::BuiltinResult<()> {
577    if shape.len() <= 2 {
578        Ok(())
579    } else {
580        Err(sortrows_error(&SORTROWS_ERROR_MATRIX_REQUIRED))
581    }
582}
583
/// Interprets `shape` as a matrix and returns `(rows, cols)`.
///
/// An empty shape is treated as a scalar (1x1), a 1-D shape as a row vector, and longer
/// shapes use their first two extents.
fn rows_cols_from_shape(shape: &[usize]) -> (usize, usize) {
    match *shape {
        [] => (1, 1),
        [len] => (1, len),
        [rows, cols, ..] => (rows, cols),
    }
}
591
592fn compare_real_rows(
593    tensor: &Tensor,
594    rows: usize,
595    args: &SortRowsArgs,
596    a: usize,
597    b: usize,
598) -> Ordering {
599    for spec in &args.columns {
600        if spec.index >= tensor.cols() {
601            continue;
602        }
603        let idx_a = a + spec.index * rows;
604        let idx_b = b + spec.index * rows;
605        let va = tensor.data[idx_a];
606        let vb = tensor.data[idx_b];
607        let missing = args.missing_for_direction(spec.direction);
608        let ord = compare_real_scalars(va, vb, spec.direction, args.comparison, missing);
609        if ord != Ordering::Equal {
610            return ord;
611        }
612    }
613    Ordering::Equal
614}
615
616fn compare_complex_rows(
617    tensor: &ComplexTensor,
618    rows: usize,
619    args: &SortRowsArgs,
620    a: usize,
621    b: usize,
622) -> Ordering {
623    for spec in &args.columns {
624        if spec.index >= tensor.cols {
625            continue;
626        }
627        let idx_a = a + spec.index * rows;
628        let idx_b = b + spec.index * rows;
629        let va = tensor.data[idx_a];
630        let vb = tensor.data[idx_b];
631        let missing = args.missing_for_direction(spec.direction);
632        let ord = compare_complex_scalars(va, vb, spec.direction, args.comparison, missing);
633        if ord != Ordering::Equal {
634            return ord;
635        }
636    }
637    Ordering::Equal
638}
639
640fn compare_char_rows(ca: &CharArray, args: &SortRowsArgs, a: usize, b: usize) -> Ordering {
641    for spec in &args.columns {
642        if spec.index >= ca.cols {
643            continue;
644        }
645        let idx_a = a * ca.cols + spec.index;
646        let idx_b = b * ca.cols + spec.index;
647        let va = ca.data[idx_a];
648        let vb = ca.data[idx_b];
649        let ord = match spec.direction {
650            SortDirection::Ascend => va.cmp(&vb),
651            SortDirection::Descend => vb.cmp(&va),
652        };
653        if ord != Ordering::Equal {
654            return ord;
655        }
656    }
657    Ordering::Equal
658}
659
660fn reorder_real_rows(
661    tensor: &Tensor,
662    rows: usize,
663    cols: usize,
664    order: &[usize],
665) -> crate::BuiltinResult<Tensor> {
666    let mut data = vec![0.0; tensor.data.len()];
667    for col in 0..cols {
668        for (dest_row, &src_row) in order.iter().enumerate() {
669            let src_idx = src_row + col * rows;
670            let dst_idx = dest_row + col * rows;
671            data[dst_idx] = tensor.data[src_idx];
672        }
673    }
674    Tensor::new(data, tensor.shape.clone())
675        .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
676}
677
678fn reorder_complex_rows(
679    tensor: &ComplexTensor,
680    rows: usize,
681    cols: usize,
682    order: &[usize],
683) -> crate::BuiltinResult<ComplexTensor> {
684    let mut data = vec![(0.0, 0.0); tensor.data.len()];
685    for col in 0..cols {
686        for (dest_row, &src_row) in order.iter().enumerate() {
687            let src_idx = src_row + col * rows;
688            let dst_idx = dest_row + col * rows;
689            data[dst_idx] = tensor.data[src_idx];
690        }
691    }
692    ComplexTensor::new(data, tensor.shape.clone())
693        .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
694}
695
696fn reorder_char_rows(
697    ca: &CharArray,
698    rows: usize,
699    cols: usize,
700    order: &[usize],
701) -> crate::BuiltinResult<CharArray> {
702    let mut data = vec!['\0'; ca.data.len()];
703    for (dest_row, &src_row) in order.iter().enumerate() {
704        for col in 0..cols {
705            let src_idx = src_row * cols + col;
706            let dst_idx = dest_row * cols + col;
707            data[dst_idx] = ca.data[src_idx];
708        }
709    }
710    CharArray::new(data, rows, cols).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
711}
712
713fn compare_real_scalars(
714    a: f64,
715    b: f64,
716    direction: SortDirection,
717    comparison: ComparisonMethod,
718    missing: MissingPlacementResolved,
719) -> Ordering {
720    match (a.is_nan(), b.is_nan()) {
721        (true, true) => Ordering::Equal,
722        (true, false) => match missing {
723            MissingPlacementResolved::First => Ordering::Less,
724            MissingPlacementResolved::Last => Ordering::Greater,
725        },
726        (false, true) => match missing {
727            MissingPlacementResolved::First => Ordering::Greater,
728            MissingPlacementResolved::Last => Ordering::Less,
729        },
730        (false, false) => compare_real_finite_scalars(a, b, direction, comparison),
731    }
732}
733
734fn compare_real_finite_scalars(
735    a: f64,
736    b: f64,
737    direction: SortDirection,
738    comparison: ComparisonMethod,
739) -> Ordering {
740    if matches!(comparison, ComparisonMethod::Abs) {
741        let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
742        if abs_cmp != Ordering::Equal {
743            return match direction {
744                SortDirection::Ascend => abs_cmp,
745                SortDirection::Descend => abs_cmp.reverse(),
746            };
747        }
748    }
749    match direction {
750        SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
751        SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
752    }
753}
754
755fn compare_complex_scalars(
756    a: (f64, f64),
757    b: (f64, f64),
758    direction: SortDirection,
759    comparison: ComparisonMethod,
760    missing: MissingPlacementResolved,
761) -> Ordering {
762    match (complex_is_nan(a), complex_is_nan(b)) {
763        (true, true) => Ordering::Equal,
764        (true, false) => match missing {
765            MissingPlacementResolved::First => Ordering::Less,
766            MissingPlacementResolved::Last => Ordering::Greater,
767        },
768        (false, true) => match missing {
769            MissingPlacementResolved::First => Ordering::Greater,
770            MissingPlacementResolved::Last => Ordering::Less,
771        },
772        (false, false) => compare_complex_finite_scalars(a, b, direction, comparison),
773    }
774}
775
776fn compare_complex_finite_scalars(
777    a: (f64, f64),
778    b: (f64, f64),
779    direction: SortDirection,
780    comparison: ComparisonMethod,
781) -> Ordering {
782    match comparison {
783        ComparisonMethod::Real => compare_complex_real_first(a, b, direction),
784        ComparisonMethod::Auto | ComparisonMethod::Abs => {
785            let abs_cmp = complex_abs(a)
786                .partial_cmp(&complex_abs(b))
787                .unwrap_or(Ordering::Equal);
788            if abs_cmp != Ordering::Equal {
789                return match direction {
790                    SortDirection::Ascend => abs_cmp,
791                    SortDirection::Descend => abs_cmp.reverse(),
792                };
793            }
794            compare_complex_real_first(a, b, direction)
795        }
796    }
797}
798
799fn compare_complex_real_first(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
800    let real_cmp = match direction {
801        SortDirection::Ascend => a.0.partial_cmp(&b.0),
802        SortDirection::Descend => b.0.partial_cmp(&a.0),
803    }
804    .unwrap_or(Ordering::Equal);
805    if real_cmp != Ordering::Equal {
806        return real_cmp;
807    }
808    match direction {
809        SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
810        SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
811    }
812}
813
/// Returns `true` when either component of the complex value is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
817
/// Magnitude of a complex value, computed with `hypot` to avoid intermediate overflow.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    re.hypot(im)
}
821
822fn permutation_indices(order: &[usize]) -> crate::BuiltinResult<Tensor> {
823    let rows = order.len();
824    let mut data = Vec::with_capacity(rows);
825    for &idx in order {
826        data.push((idx + 1) as f64);
827    }
828    Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
829}
830
831fn identity_indices(rows: usize) -> crate::BuiltinResult<Tensor> {
832    let mut data = Vec::with_capacity(rows);
833    for i in 0..rows {
834        data.push((i + 1) as f64);
835    }
836    Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
837}
838
839fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
840    if tensor.data.len() == 1 {
841        Value::Complex(tensor.data[0].0, tensor.data[0].1)
842    } else {
843        Value::ComplexTensor(tensor)
844    }
845}
846
/// Row ordering for one sort key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SortDirection {
    Ascend,
    Descend,
}

impl SortDirection {
    /// Parses a direction keyword, ignoring case and surrounding whitespace.
    ///
    /// Accepts `ascend`/`ascending` and `descend`/`descending`; anything else is `None`.
    fn from_str(value: &str) -> Option<Self> {
        let keyword = value.trim().to_ascii_lowercase();
        if keyword == "ascend" || keyword == "ascending" {
            Some(Self::Ascend)
        } else if keyword == "descend" || keyword == "descending" {
            Some(Self::Descend)
        } else {
            None
        }
    }
}
862
/// Ordering rule selected by the `ComparisonMethod` option.
///
/// `Auto`: real inputs compare by value, complex inputs by magnitude first.
/// `Real`: compare by value; complex inputs use the real part, then the imaginary part.
/// `Abs`: compare by magnitude first, for both real and complex inputs.
/// Tie-breaking details live in `compare_real_finite_scalars` and
/// `compare_complex_finite_scalars`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ComparisonMethod {
    Auto,
    Real,
    Abs,
}
869
/// Value of the `'MissingPlacement'` option: where missing (NaN) entries are placed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MissingPlacement {
    /// `'auto'` (the default): depends on the sort direction; see `MissingPlacement::resolve`.
    Auto,
    /// `'first'`: missing entries go before all other values.
    First,
    /// `'last'`: missing entries go after all other values.
    Last,
}
876
/// A `MissingPlacement` after `Auto` has been resolved against a concrete sort direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MissingPlacementResolved {
    /// Missing entries sort before everything else.
    First,
    /// Missing entries sort after everything else.
    Last,
}
882
883impl MissingPlacement {
884    fn resolve(self, direction: SortDirection) -> MissingPlacementResolved {
885        match self {
886            MissingPlacement::First => MissingPlacementResolved::First,
887            MissingPlacement::Last => MissingPlacementResolved::Last,
888            MissingPlacement::Auto => match direction {
889                SortDirection::Ascend => MissingPlacementResolved::Last,
890                SortDirection::Descend => MissingPlacementResolved::First,
891            },
892        }
893    }
894
895    fn is_auto(self) -> bool {
896        matches!(self, MissingPlacement::Auto)
897    }
898}
899
/// One sort key: which column to compare and in which direction.
#[derive(Debug, Clone)]
struct ColumnSpec {
    /// 0-based column index (user input is 1-based and converted on parse).
    index: usize,
    /// Direction for this column; negative user indices select `Descend`.
    direction: SortDirection,
}
905
/// Fully parsed optional arguments of a `sortrows` call.
#[derive(Debug, Clone)]
struct SortRowsArgs {
    /// Sort keys in priority order (first key is compared first).
    columns: Vec<ColumnSpec>,
    /// Requested `'ComparisonMethod'`.
    comparison: ComparisonMethod,
    /// Requested `'MissingPlacement'` (may still be `Auto`).
    missing: MissingPlacement,
}
912
913impl SortRowsArgs {
914    fn parse(rest: &[Value], num_cols: usize) -> crate::BuiltinResult<Self> {
915        let mut columns: Option<Vec<ColumnSpec>> = None;
916        let mut override_direction: Option<SortDirection> = None;
917        let mut comparison = ComparisonMethod::Auto;
918        let mut missing = MissingPlacement::Auto;
919        let mut i = 0usize;
920
921        while i < rest.len() {
922            if columns.is_none() {
923                if let Some(parsed) = parse_column_vector(&rest[i], num_cols)? {
924                    columns = Some(parsed);
925                    i += 1;
926                    continue;
927                }
928            }
929            if let Some(direction) = parse_direction(&rest[i]) {
930                override_direction = Some(direction);
931                i += 1;
932                continue;
933            }
934            let Some(keyword) = tensor::value_to_string(&rest[i]) else {
935                return Err(sortrows_error_with(
936                    &SORTROWS_ERROR_INVALID_ARGUMENT,
937                    format!("sortrows: invalid argument {:?}", rest[i]),
938                ));
939            };
940            let lowered = keyword.trim().to_ascii_lowercase();
941            match lowered.as_str() {
942                "comparisonmethod" => {
943                    i += 1;
944                    if i >= rest.len() {
945                        return Err(sortrows_error_with(
946                            &SORTROWS_ERROR_INVALID_ARGUMENT,
947                            "sortrows: expected a value for 'ComparisonMethod'",
948                        ));
949                    }
950                    let Some(value_str) = tensor::value_to_string(&rest[i]) else {
951                        return Err(sortrows_error_with(
952                            &SORTROWS_ERROR_INVALID_ARGUMENT,
953                            "sortrows: 'ComparisonMethod' expects a string value",
954                        )
955                        .into());
956                    };
957                    comparison = match value_str.trim().to_ascii_lowercase().as_str() {
958                        "auto" => ComparisonMethod::Auto,
959                        "real" => ComparisonMethod::Real,
960                        "abs" | "magnitude" => ComparisonMethod::Abs,
961                        other => {
962                            return Err(sortrows_error_with(
963                                &SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN,
964                                format!("sortrows: unsupported ComparisonMethod '{other}'"),
965                            )
966                            .into())
967                        }
968                    };
969                    i += 1;
970                }
971                "missingplacement" => {
972                    i += 1;
973                    if i >= rest.len() {
974                        return Err(sortrows_error_with(
975                            &SORTROWS_ERROR_INVALID_ARGUMENT,
976                            "sortrows: expected a value for 'MissingPlacement'",
977                        )
978                        .into());
979                    }
980                    let Some(value_str) = tensor::value_to_string(&rest[i]) else {
981                        return Err(sortrows_error_with(
982                            &SORTROWS_ERROR_INVALID_ARGUMENT,
983                            "sortrows: 'MissingPlacement' expects a string value",
984                        )
985                        .into());
986                    };
987                    missing = match value_str.trim().to_ascii_lowercase().as_str() {
988                        "auto" => MissingPlacement::Auto,
989                        "first" => MissingPlacement::First,
990                        "last" => MissingPlacement::Last,
991                        other => {
992                            return Err(sortrows_error_with(
993                                &SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN,
994                                format!("sortrows: unsupported MissingPlacement '{other}'"),
995                            )
996                            .into())
997                        }
998                    };
999                    i += 1;
1000                }
1001                other => {
1002                    return Err(sortrows_error_with(
1003                        &SORTROWS_ERROR_INVALID_ARGUMENT,
1004                        format!("sortrows: unexpected argument '{other}'"),
1005                    ));
1006                }
1007            }
1008        }
1009
1010        let mut columns = columns.unwrap_or_else(|| default_columns(num_cols));
1011        if let Some(dir) = override_direction {
1012            for spec in &mut columns {
1013                spec.direction = dir;
1014            }
1015        }
1016        validate_columns(&columns, num_cols)?;
1017
1018        Ok(SortRowsArgs {
1019            columns,
1020            comparison,
1021            missing,
1022        })
1023    }
1024
1025    fn to_provider_columns(&self) -> Vec<ProviderSortRowsColumnSpec> {
1026        self.columns
1027            .iter()
1028            .map(|spec| ProviderSortRowsColumnSpec {
1029                index: spec.index,
1030                order: match spec.direction {
1031                    SortDirection::Ascend => ProviderSortOrder::Ascend,
1032                    SortDirection::Descend => ProviderSortOrder::Descend,
1033                },
1034            })
1035            .collect()
1036    }
1037
1038    fn provider_comparison(&self) -> ProviderSortComparison {
1039        match self.comparison {
1040            ComparisonMethod::Auto => ProviderSortComparison::Auto,
1041            ComparisonMethod::Real => ProviderSortComparison::Real,
1042            ComparisonMethod::Abs => ProviderSortComparison::Abs,
1043        }
1044    }
1045
1046    fn missing_for_direction(&self, direction: SortDirection) -> MissingPlacementResolved {
1047        self.missing.resolve(direction)
1048    }
1049
1050    fn missing_is_auto(&self) -> bool {
1051        self.missing.is_auto()
1052    }
1053}
1054
1055fn parse_column_vector(
1056    value: &Value,
1057    num_cols: usize,
1058) -> crate::BuiltinResult<Option<Vec<ColumnSpec>>> {
1059    match value {
1060        Value::Int(i) => parse_single_column(i.to_i64(), num_cols).map(Some),
1061        Value::Num(n) => {
1062            if !n.is_finite() {
1063                return Err(sortrows_error_with(
1064                    &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1065                    "sortrows: column indices must be finite",
1066                ));
1067            }
1068            let rounded = n.round();
1069            if (rounded - n).abs() > f64::EPSILON {
1070                return Err(sortrows_error_with(
1071                    &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1072                    "sortrows: column indices must be integers",
1073                ));
1074            }
1075            parse_single_column(rounded as i64, num_cols).map(Some)
1076        }
1077        Value::Tensor(tensor) => {
1078            if !is_vector(&tensor.shape) {
1079                return Err(sortrows_error_with(
1080                    &SORTROWS_ERROR_INVALID_ARGUMENT,
1081                    "sortrows: column specification must be a vector",
1082                ));
1083            }
1084            let mut specs = Vec::with_capacity(tensor.data.len());
1085            for &entry in &tensor.data {
1086                if !entry.is_finite() {
1087                    return Err(sortrows_error_with(
1088                        &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1089                        "sortrows: column indices must be finite",
1090                    ));
1091                }
1092                let rounded = entry.round();
1093                if (rounded - entry).abs() > f64::EPSILON {
1094                    return Err(sortrows_error_with(
1095                        &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1096                        "sortrows: column indices must be integers",
1097                    ));
1098                }
1099                let column = parse_single_column_i64(rounded as i64, num_cols)?;
1100                specs.push(column);
1101            }
1102            Ok(Some(specs))
1103        }
1104        _ => Ok(None),
1105    }
1106}
1107
1108fn parse_single_column(value: i64, num_cols: usize) -> crate::BuiltinResult<Vec<ColumnSpec>> {
1109    parse_single_column_i64(value, num_cols).map(|spec| vec![spec])
1110}
1111
1112fn parse_single_column_i64(value: i64, num_cols: usize) -> crate::BuiltinResult<ColumnSpec> {
1113    if value == 0 {
1114        return Err(sortrows_error_with(
1115            &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1116            "sortrows: column indices must be non-zero",
1117        ));
1118    }
1119    let abs = value.unsigned_abs() as usize;
1120    if abs == 0 {
1121        return Err(sortrows_error_with(
1122            &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1123            "sortrows: column indices must be >= 1",
1124        ));
1125    }
1126    if num_cols == 0 {
1127        return Err(sortrows_error_with(
1128            &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1129            "sortrows: column index exceeds matrix with 0 columns",
1130        ));
1131    }
1132    if abs > num_cols {
1133        return Err(sortrows_error_with(
1134            &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1135            format!(
1136                "sortrows: column index {} exceeds matrix with {} columns",
1137                abs, num_cols
1138            ),
1139        )
1140        .into());
1141    }
1142    let direction = if value > 0 {
1143        SortDirection::Ascend
1144    } else {
1145        SortDirection::Descend
1146    };
1147    Ok(ColumnSpec {
1148        index: abs - 1,
1149        direction,
1150    })
1151}
1152
1153fn parse_direction(value: &Value) -> Option<SortDirection> {
1154    tensor::value_to_string(value).and_then(|s| SortDirection::from_str(&s))
1155}
1156
1157fn default_columns(num_cols: usize) -> Vec<ColumnSpec> {
1158    let mut columns = Vec::with_capacity(num_cols);
1159    for col in 0..num_cols {
1160        columns.push(ColumnSpec {
1161            index: col,
1162            direction: SortDirection::Ascend,
1163        });
1164    }
1165    columns
1166}
1167
1168fn validate_columns(columns: &[ColumnSpec], num_cols: usize) -> crate::BuiltinResult<()> {
1169    if num_cols == 0 && columns.iter().any(|spec| spec.index > 0) {
1170        return Err(sortrows_error_with(
1171            &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1172            "sortrows: column index exceeds matrix with 0 columns",
1173        ));
1174    }
1175    for spec in columns {
1176        if num_cols > 0 && spec.index >= num_cols {
1177            return Err(sortrows_error_with(
1178                &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1179                format!(
1180                    "sortrows: column index {} exceeds matrix with {} columns",
1181                    spec.index + 1,
1182                    num_cols
1183                ),
1184            )
1185            .into());
1186        }
1187    }
1188    Ok(())
1189}
1190
/// Returns `true` for shapes that describe a vector: rank 0 or 1, or a 2-D shape with a
/// singleton dimension (`1 x n` or `n x 1`). Higher ranks are never vectors.
fn is_vector(shape: &[usize]) -> bool {
    match *shape {
        [] | [_] => true,
        [rows, cols] => rows == 1 || cols == 1,
        _ => false,
    }
}
1199
/// Result of evaluating `sortrows`: the sorted array plus the row-index output.
#[derive(Debug)]
pub struct SortRowsEvaluation {
    /// The sorted rows, as a runtime value.
    sorted: Value,
    /// Row indices for the second output. Host helpers in this file build them 1-based
    /// (`n x 1`); presumably provider results follow the same convention (confirm).
    indices: Tensor,
}
1205
1206impl SortRowsEvaluation {
1207    pub(crate) fn from_parts(sorted: Value, indices: Tensor) -> Self {
1208        Self { sorted, indices }
1209    }
1210
1211    pub fn into_sorted_value(self) -> Value {
1212        self.sorted
1213    }
1214
1215    pub fn into_values(self) -> (Value, Value) {
1216        let indices = tensor::tensor_into_value(self.indices);
1217        (self.sorted, indices)
1218    }
1219
1220    pub fn indices_value(&self) -> Value {
1221        tensor::tensor_into_value(self.indices.clone())
1222    }
1223}
1224
// Unit tests for `sortrows`. They cover host numeric/char/complex inputs, argument parsing
// and error identifiers, and GPU-provider round trips.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{IntValue, ResolveContext, Type, Value};

    // Blocks on the async `super::evaluate` so tests can call it synchronously.
    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
        futures::executor::block_on(super::evaluate(value, rest))
    }

    // 3x2 column-major input with rows (3,4), (1,1), (2,5). The default sort is ascending
    // on all columns, so the expected row order is 2, 3, 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_default_matrix() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // The output type resolver maps a tensor input to a tensor output.
    #[test]
    fn sortrows_type_resolver_tensor() {
        assert_eq!(
            tensor_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    // Keys [2 3 1]: rows (3,2,5) and (3,2,1) tie on column 2 and are ordered by column 3.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_with_column_vector() {
        let tensor = Tensor::new(
            vec![1.0, 3.0, 3.0, 4.0, 2.0, 2.0, 2.0, 5.0, 1.0],
            vec![3, 3],
        )
        .unwrap();
        let cols = Tensor::new(vec![2.0, 3.0, 1.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![3.0, 3.0, 1.0, 2.0, 2.0, 4.0, 1.0, 5.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // 'descend' with no column spec sorts every key descending.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_direction_descend() {
        let tensor = Tensor::new(vec![1.0, 2.0, 4.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 3.0, 4.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Keys [1 -2]: column 1 ascending (all tied), then column 2 descending.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_mixed_directions() {
        let tensor = Tensor::new(vec![1.0, 1.0, 1.0, 1.0, 7.0, 2.0], vec![3, 2]).unwrap();
        let cols = Tensor::new(vec![1.0, -2.0], vec![2, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 1.0, 1.0, 7.0, 2.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // The second output is the 1-based row permutation.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_returns_indices() {
        let tensor = Tensor::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (_, indices) = eval.into_values();
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // Char rows sort by character code. Space (32) sorts before letters, so "al  " < "ally".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_char_array() {
        let chars = CharArray::new(
            "bob "
                .chars()
                .chain("al  ".chars())
                .chain("ally".chars())
                .collect(),
            3,
            4,
        )
        .unwrap();
        let eval = evaluate(Value::CharArray(chars), &[]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 3);
                assert_eq!(ca.cols, 4);
                let strings: Vec<String> = (0..ca.rows)
                    .map(|r| {
                        ca.data[r * ca.cols..(r + 1) * ca.cols]
                            .iter()
                            .collect::<String>()
                    })
                    .collect();
                assert_eq!(
                    strings,
                    vec!["al  ".to_string(), "ally".to_string(), "bob ".to_string()]
                );
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // 'ComparisonMethod','abs' on 1+2i and -2+1i. Both magnitudes are sqrt(5), so the
    // expected order comes entirely from the tie-break applied after the magnitude check.
    // NOTE(review): MATLAB documents that 'abs' ties are broken by angle in (-pi, pi]. That
    // would put 1+2i (angle ~1.11) before -2+1i (angle ~2.68), the opposite of what this
    // test expects. Confirm the intended tie-break.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_complex_abs() {
        let tensor = ComplexTensor::new(vec![(1.0, 2.0), (-2.0, 1.0)], vec![2, 1]).unwrap();
        let eval = evaluate(
            Value::ComplexTensor(tensor),
            &[Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.data, vec![(-2.0, 1.0), (1.0, 2.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Column 3 on a 2x1 matrix is rejected with the invalid-column-index identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_invalid_column_index_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(3))]).unwrap_err();
        assert_eq!(
            err.identifier(),
            SORTROWS_ERROR_INVALID_COLUMN_INDEX.identifier
        );
    }

    // Rows (1,2) and (NaN,3). 'first' puts the NaN row on top even though the sort is
    // ascending (where auto placement would put it last).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_first_moves_nan_first() {
        let tensor = Tensor::new(vec![1.0, f64::NAN, 2.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("first")],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert!(t.data[0].is_nan());
                assert_eq!(t.data[1], 1.0);
                assert_eq!(t.data[2], 3.0);
                assert_eq!(t.data[3], 2.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // Rows (NaN,1) and (5,2). Descending auto placement would put NaN first; 'last'
    // overrides that.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_last_descend_moves_nan_last() {
        let tensor = Tensor::new(vec![f64::NAN, 5.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::from("descend"),
                Value::from("MissingPlacement"),
                Value::from("last"),
            ],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data[0], 5.0);
                assert!(t.data[1].is_nan());
                assert_eq!(t.data[2], 2.0);
                assert_eq!(t.data[3], 1.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // Unknown MissingPlacement values map to the dedicated error identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_invalid_value_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("middle")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN.identifier
        );
    }

    // GPU input through the test provider returns host tensors. Same data and expectations
    // as sortrows_default_matrix.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
            let (sorted, indices) = eval.into_values();
            match sorted {
                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
            match indices {
                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
                other => panic!("unexpected indices {other:?}"),
            }
        });
    }

    // With the `wgpu` feature, checks that the wgpu provider path matches the CPU path
    // (values, indices and shapes) on the same input.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn sortrows_wgpu_matches_cpu() {
        // Registration result is ignored, presumably so the test tolerates an
        // already-registered provider (confirm).
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0, 2.0, 5.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let (cpu_sorted_val, cpu_indices_val) = cpu_eval.into_values();
        let cpu_sorted = match cpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let cpu_indices = match cpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("provider");
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("gpu evaluate");
        let (gpu_sorted_val, gpu_indices_val) = gpu_eval.into_values();
        let gpu_sorted = match gpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let gpu_indices = match gpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        assert_eq!(gpu_sorted.shape, cpu_sorted.shape);
        assert_eq!(gpu_sorted.data, cpu_sorted.data);
        assert_eq!(gpu_indices.shape, cpu_indices.shape);
        assert_eq!(gpu_indices.data, cpu_indices.data);

        let _ = provider.free(&handle);
    }
}