Skip to main content

runmat_runtime/builtins/array/sorting_sets/
sortrows.rs

1//! MATLAB-compatible `sortrows` builtin with GPU-aware semantics.
2
3use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6    GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7    SortResult as ProviderSortResult, SortRowsColumnSpec as ProviderSortRowsColumnSpec,
8};
9use runmat_builtins::{
10    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
11    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
12    CharArray, ComplexTensor, Tensor, Value,
13};
14use runmat_macros::runtime_builtin;
15
16use super::type_resolvers::tensor_output_type;
17use crate::build_runtime_error;
18use crate::builtins::common::gpu_helpers;
19use crate::builtins::common::spec::{
20    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24
25#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::sortrows")]
26pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
27    name: "sortrows",
28    op_kind: GpuOpKind::Custom("sortrows"),
29    supported_precisions: &[ScalarType::F32, ScalarType::F64],
30    broadcast: BroadcastSemantics::None,
31    provider_hooks: &[ProviderHook::Custom("sortrows")],
32    constant_strategy: ConstantStrategy::InlineLiteral,
33    residency: ResidencyPolicy::GatherImmediately,
34    nan_mode: ReductionNaN::Include,
35    two_pass_threshold: None,
36    workgroup_size: None,
37    accepts_nan_mode: true,
38    notes:
39        "Providers may implement a row-sort kernel; explicit MissingPlacement overrides fall back to host memory until native support exists.",
40};
41
42#[runmat_macros::register_fusion_spec(
43    builtin_path = "crate::builtins::array::sorting_sets::sortrows"
44)]
45pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
46    name: "sortrows",
47    shape: ShapeRequirements::Any,
48    constant_strategy: ConstantStrategy::InlineLiteral,
49    elementwise: None,
50    reduction: None,
51    emits_nan: true,
52    notes: "`sortrows` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
53};
54
/// Builtin name attached to every runtime error raised from this module.
const BUILTIN_NAME: &'static str = "sortrows";
56
57const SORTROWS_OUTPUT_B: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
58    name: "B",
59    ty: BuiltinParamType::Any,
60    arity: BuiltinParamArity::Required,
61    default: None,
62    description: "Sorted input rows.",
63}];
64
65const SORTROWS_OUTPUT_BI: [BuiltinParamDescriptor; 2] = [
66    BuiltinParamDescriptor {
67        name: "B",
68        ty: BuiltinParamType::Any,
69        arity: BuiltinParamArity::Required,
70        default: None,
71        description: "Sorted input rows.",
72    },
73    BuiltinParamDescriptor {
74        name: "I",
75        ty: BuiltinParamType::NumericArray,
76        arity: BuiltinParamArity::Required,
77        default: None,
78        description: "Permutation indices mapping sorted rows to original rows.",
79    },
80];
81
82const SORTROWS_INPUTS_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
83    name: "A",
84    ty: BuiltinParamType::Any,
85    arity: BuiltinParamArity::Required,
86    default: None,
87    description: "Input matrix to sort by rows.",
88}];
89
90const SORTROWS_INPUTS_A_COLUMNS: [BuiltinParamDescriptor; 2] = [
91    BuiltinParamDescriptor {
92        name: "A",
93        ty: BuiltinParamType::Any,
94        arity: BuiltinParamArity::Required,
95        default: None,
96        description: "Input matrix to sort by rows.",
97    },
98    BuiltinParamDescriptor {
99        name: "column",
100        ty: BuiltinParamType::NumericArray,
101        arity: BuiltinParamArity::Required,
102        default: None,
103        description: "Column specification vector (negative entries request descending order).",
104    },
105];
106
107const SORTROWS_INPUTS_A_DIRECTION: [BuiltinParamDescriptor; 2] = [
108    BuiltinParamDescriptor {
109        name: "A",
110        ty: BuiltinParamType::Any,
111        arity: BuiltinParamArity::Required,
112        default: None,
113        description: "Input matrix to sort by rows.",
114    },
115    BuiltinParamDescriptor {
116        name: "direction",
117        ty: BuiltinParamType::StringScalar,
118        arity: BuiltinParamArity::Required,
119        default: Some("\"ascend\""),
120        description: "Global row direction override: 'ascend' or 'descend'.",
121    },
122];
123
124const SORTROWS_INPUTS_A_COLUMNS_DIRECTION: [BuiltinParamDescriptor; 3] = [
125    BuiltinParamDescriptor {
126        name: "A",
127        ty: BuiltinParamType::Any,
128        arity: BuiltinParamArity::Required,
129        default: None,
130        description: "Input matrix to sort by rows.",
131    },
132    BuiltinParamDescriptor {
133        name: "column",
134        ty: BuiltinParamType::NumericArray,
135        arity: BuiltinParamArity::Required,
136        default: None,
137        description: "Column specification vector (negative entries request descending order).",
138    },
139    BuiltinParamDescriptor {
140        name: "direction",
141        ty: BuiltinParamType::StringScalar,
142        arity: BuiltinParamArity::Required,
143        default: Some("\"ascend\""),
144        description: "Global row direction override: 'ascend' or 'descend'.",
145    },
146];
147
148const SORTROWS_INPUTS_COMPARISON_METHOD: [BuiltinParamDescriptor; 4] = [
149    BuiltinParamDescriptor {
150        name: "A",
151        ty: BuiltinParamType::Any,
152        arity: BuiltinParamArity::Required,
153        default: None,
154        description: "Input matrix to sort by rows.",
155    },
156    BuiltinParamDescriptor {
157        name: "arg",
158        ty: BuiltinParamType::Any,
159        arity: BuiltinParamArity::Variadic,
160        default: None,
161        description: "Optional column and direction arguments.",
162    },
163    BuiltinParamDescriptor {
164        name: "name",
165        ty: BuiltinParamType::StringScalar,
166        arity: BuiltinParamArity::Required,
167        default: Some("\"ComparisonMethod\""),
168        description: "Name-value option key.",
169    },
170    BuiltinParamDescriptor {
171        name: "method",
172        ty: BuiltinParamType::StringScalar,
173        arity: BuiltinParamArity::Required,
174        default: Some("\"auto\""),
175        description: "Comparison method: 'auto', 'real', or 'abs'.",
176    },
177];
178
179const SORTROWS_INPUTS_MISSING_PLACEMENT: [BuiltinParamDescriptor; 4] = [
180    BuiltinParamDescriptor {
181        name: "A",
182        ty: BuiltinParamType::Any,
183        arity: BuiltinParamArity::Required,
184        default: None,
185        description: "Input matrix to sort by rows.",
186    },
187    BuiltinParamDescriptor {
188        name: "arg",
189        ty: BuiltinParamType::Any,
190        arity: BuiltinParamArity::Variadic,
191        default: None,
192        description: "Optional column and direction arguments.",
193    },
194    BuiltinParamDescriptor {
195        name: "name",
196        ty: BuiltinParamType::StringScalar,
197        arity: BuiltinParamArity::Required,
198        default: Some("\"MissingPlacement\""),
199        description: "Name-value option key.",
200    },
201    BuiltinParamDescriptor {
202        name: "placement",
203        ty: BuiltinParamType::StringScalar,
204        arity: BuiltinParamArity::Required,
205        default: Some("\"auto\""),
206        description: "NaN placement policy: 'auto', 'first', or 'last'.",
207    },
208];
209
210const SORTROWS_SIGNATURES: [BuiltinSignatureDescriptor; 12] = [
211    BuiltinSignatureDescriptor {
212        label: "B = sortrows(A)",
213        inputs: &SORTROWS_INPUTS_A,
214        outputs: &SORTROWS_OUTPUT_B,
215    },
216    BuiltinSignatureDescriptor {
217        label: "B = sortrows(A, column)",
218        inputs: &SORTROWS_INPUTS_A_COLUMNS,
219        outputs: &SORTROWS_OUTPUT_B,
220    },
221    BuiltinSignatureDescriptor {
222        label: "B = sortrows(A, direction)",
223        inputs: &SORTROWS_INPUTS_A_DIRECTION,
224        outputs: &SORTROWS_OUTPUT_B,
225    },
226    BuiltinSignatureDescriptor {
227        label: "B = sortrows(A, column, direction)",
228        inputs: &SORTROWS_INPUTS_A_COLUMNS_DIRECTION,
229        outputs: &SORTROWS_OUTPUT_B,
230    },
231    BuiltinSignatureDescriptor {
232        label: "B = sortrows(A, ..., \"ComparisonMethod\", method)",
233        inputs: &SORTROWS_INPUTS_COMPARISON_METHOD,
234        outputs: &SORTROWS_OUTPUT_B,
235    },
236    BuiltinSignatureDescriptor {
237        label: "B = sortrows(A, ..., \"MissingPlacement\", placement)",
238        inputs: &SORTROWS_INPUTS_MISSING_PLACEMENT,
239        outputs: &SORTROWS_OUTPUT_B,
240    },
241    BuiltinSignatureDescriptor {
242        label: "[B, I] = sortrows(A)",
243        inputs: &SORTROWS_INPUTS_A,
244        outputs: &SORTROWS_OUTPUT_BI,
245    },
246    BuiltinSignatureDescriptor {
247        label: "[B, I] = sortrows(A, column)",
248        inputs: &SORTROWS_INPUTS_A_COLUMNS,
249        outputs: &SORTROWS_OUTPUT_BI,
250    },
251    BuiltinSignatureDescriptor {
252        label: "[B, I] = sortrows(A, direction)",
253        inputs: &SORTROWS_INPUTS_A_DIRECTION,
254        outputs: &SORTROWS_OUTPUT_BI,
255    },
256    BuiltinSignatureDescriptor {
257        label: "[B, I] = sortrows(A, column, direction)",
258        inputs: &SORTROWS_INPUTS_A_COLUMNS_DIRECTION,
259        outputs: &SORTROWS_OUTPUT_BI,
260    },
261    BuiltinSignatureDescriptor {
262        label: "[B, I] = sortrows(A, ..., \"ComparisonMethod\", method)",
263        inputs: &SORTROWS_INPUTS_COMPARISON_METHOD,
264        outputs: &SORTROWS_OUTPUT_BI,
265    },
266    BuiltinSignatureDescriptor {
267        label: "[B, I] = sortrows(A, ..., \"MissingPlacement\", placement)",
268        inputs: &SORTROWS_INPUTS_MISSING_PLACEMENT,
269        outputs: &SORTROWS_OUTPUT_BI,
270    },
271];
272
273const SORTROWS_ERROR_INVALID_COLUMN_INDEX: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
274    code: "RM.SORTROWS.INVALID_COLUMN_INDEX",
275    identifier: Some("RunMat:sortrows:InvalidColumnIndex"),
276    when: "Column specification indices are out of range, zero, or otherwise invalid.",
277    message: "sortrows: invalid column index",
278};
279
280const SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
281    code: "RM.SORTROWS.MISSING_PLACEMENT_UNKNOWN",
282    identifier: Some("RunMat:sortrows:MissingPlacementUnknown"),
283    when: "MissingPlacement option value is unsupported.",
284    message: "sortrows: unsupported MissingPlacement value",
285};
286
287const SORTROWS_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
288    code: "RM.SORTROWS.INVALID_ARGUMENT",
289    identifier: Some("RunMat:sortrows:InvalidArgument"),
290    when: "Option parsing receives invalid argument kinds or malformed name-value pairs.",
291    message: "sortrows: invalid argument",
292};
293
294const SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
295    code: "RM.SORTROWS.COMPARISON_METHOD_UNKNOWN",
296    identifier: Some("RunMat:sortrows:ComparisonMethodUnknown"),
297    when: "ComparisonMethod option value is unsupported.",
298    message: "sortrows: unsupported ComparisonMethod value",
299};
300
301const SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
302    code: "RM.SORTROWS.UNSUPPORTED_INPUT_TYPE",
303    identifier: Some("RunMat:sortrows:UnsupportedInputType"),
304    when: "Input cannot be converted to numeric, logical, complex, or char matrix domain.",
305    message: "sortrows: unsupported input type",
306};
307
308const SORTROWS_ERROR_MATRIX_REQUIRED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
309    code: "RM.SORTROWS.MATRIX_REQUIRED",
310    identifier: Some("RunMat:sortrows:MatrixRequired"),
311    when: "Input has rank greater than 2 where matrix input is required.",
312    message: "sortrows: input must be a 2-D matrix",
313};
314
315const SORTROWS_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
316    code: "RM.SORTROWS.INTERNAL",
317    identifier: Some("RunMat:sortrows:Internal"),
318    when: "Internal conversion/allocation/provider decode fails.",
319    message: "sortrows: internal operation failed",
320};
321
322const SORTROWS_ERRORS: [BuiltinErrorDescriptor; 7] = [
323    SORTROWS_ERROR_INVALID_COLUMN_INDEX,
324    SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN,
325    SORTROWS_ERROR_INVALID_ARGUMENT,
326    SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN,
327    SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE,
328    SORTROWS_ERROR_MATRIX_REQUIRED,
329    SORTROWS_ERROR_INTERNAL,
330];
331
332pub const SORTROWS_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
333    signatures: &SORTROWS_SIGNATURES,
334    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
335    completion_policy: BuiltinCompletionPolicy::Public,
336    errors: &SORTROWS_ERRORS,
337};
338
339fn sortrows_error_with(
340    error: &'static BuiltinErrorDescriptor,
341    message: impl Into<String>,
342) -> crate::RuntimeError {
343    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
344    if let Some(identifier) = error.identifier {
345        builder = builder.with_identifier(identifier);
346    }
347    builder.build()
348}
349
350fn sortrows_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
351    sortrows_error_with(error, error.message)
352}
353
354fn sortrows_internal_error(message: impl Into<String>) -> crate::RuntimeError {
355    sortrows_error_with(&SORTROWS_ERROR_INTERNAL, message)
356}
357
358#[runtime_builtin(
359    name = "sortrows",
360    category = "array/sorting_sets",
361    summary = "Sort matrix rows lexicographically with column and direction controls.",
362    keywords = "sortrows,row sort,lexicographic,gpu",
363    accel = "sink",
364    sink = true,
365    type_resolver(tensor_output_type),
366    descriptor(crate::builtins::array::sorting_sets::sortrows::SORTROWS_DESCRIPTOR),
367    builtin_path = "crate::builtins::array::sorting_sets::sortrows"
368)]
369async fn sortrows_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
370    let eval = evaluate(value, &rest).await?;
371    if let Some(out_count) = crate::output_count::current_output_count() {
372        if out_count == 0 {
373            return Ok(Value::OutputList(Vec::new()));
374        }
375        let (sorted, indices) = eval.into_values();
376        let mut outputs = vec![sorted];
377        if out_count >= 2 {
378            outputs.push(indices);
379        }
380        return Ok(crate::output_count::output_list_with_padding(
381            out_count, outputs,
382        ));
383    }
384    Ok(eval.into_sorted_value())
385}
386
387/// Evaluate the `sortrows` builtin once and expose both outputs.
388pub async fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
389    match value {
390        Value::GpuTensor(handle) => sortrows_gpu(handle, rest).await,
391        other => sortrows_host(other, rest),
392    }
393}
394
395async fn sortrows_gpu(
396    handle: GpuTensorHandle,
397    rest: &[Value],
398) -> crate::BuiltinResult<SortRowsEvaluation> {
399    ensure_matrix_shape(&handle.shape)?;
400    let (_, cols) = rows_cols_from_shape(&handle.shape);
401    let args = SortRowsArgs::parse(rest, cols)?;
402
403    if args.missing_is_auto() {
404        if let Some(provider) = runmat_accelerate_api::provider() {
405            let provider_columns = args.to_provider_columns();
406            let provider_comparison = args.provider_comparison();
407            match provider
408                .sort_rows(&handle, &provider_columns, provider_comparison)
409                .await
410            {
411                Ok(result) => return sortrows_from_provider_result(result),
412                Err(_err) => {
413                    // fall back to host path when provider cannot service the request
414                }
415            }
416        }
417    }
418
419    let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
420    sortrows_real_tensor_with_args(tensor, &args)
421}
422
423fn sortrows_from_provider_result(
424    result: ProviderSortResult,
425) -> crate::BuiltinResult<SortRowsEvaluation> {
426    let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
427        .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
428    let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
429        .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
430    Ok(SortRowsEvaluation {
431        sorted: tensor::tensor_into_value(sorted_tensor),
432        indices: indices_tensor,
433    })
434}
435
436fn sortrows_host(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
437    match value {
438        Value::Tensor(tensor) => sortrows_real_tensor(tensor, rest),
439        Value::LogicalArray(logical) => {
440            let tensor = tensor::logical_to_tensor(&logical)
441                .map_err(|e| sortrows_internal_error(e))?;
442            sortrows_real_tensor(tensor, rest)
443        }
444        Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
445            let tensor = tensor::value_into_tensor_for("sortrows", value)
446                .map_err(|e| sortrows_internal_error(e))?;
447            sortrows_real_tensor(tensor, rest)
448        }
449        Value::ComplexTensor(ct) => sortrows_complex_tensor(ct, rest),
450        Value::Complex(re, im) => {
451            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
452                .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
453            sortrows_complex_tensor(tensor, rest)
454        }
455        Value::CharArray(ca) => sortrows_char_array(ca, rest),
456        other => Err(sortrows_error_with(
457            &SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE,
458            format!(
459                "sortrows: unsupported input type {:?}; expected numeric, logical, complex, or char arrays",
460                other
461            ),
462        )
463        .into()),
464    }
465}
466
467fn sortrows_real_tensor(
468    tensor: Tensor,
469    rest: &[Value],
470) -> crate::BuiltinResult<SortRowsEvaluation> {
471    ensure_matrix_shape(&tensor.shape)?;
472    let cols = tensor.cols();
473    let args = SortRowsArgs::parse(rest, cols)?;
474    sortrows_real_tensor_with_args(tensor, &args)
475}
476
477fn sortrows_real_tensor_with_args(
478    tensor: Tensor,
479    args: &SortRowsArgs,
480) -> crate::BuiltinResult<SortRowsEvaluation> {
481    let rows = tensor.rows();
482    let cols = tensor.cols();
483
484    if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
485        let indices = identity_indices(rows)?;
486        return Ok(SortRowsEvaluation {
487            sorted: tensor::tensor_into_value(tensor),
488            indices,
489        });
490    }
491
492    let mut order: Vec<usize> = (0..rows).collect();
493    order.sort_by(|&a, &b| compare_real_rows(&tensor, rows, args, a, b));
494
495    let sorted_tensor = reorder_real_rows(&tensor, rows, cols, &order)?;
496    let indices = permutation_indices(&order)?;
497    Ok(SortRowsEvaluation {
498        sorted: tensor::tensor_into_value(sorted_tensor),
499        indices,
500    })
501}
502
503fn sortrows_complex_tensor(
504    tensor: ComplexTensor,
505    rest: &[Value],
506) -> crate::BuiltinResult<SortRowsEvaluation> {
507    ensure_matrix_shape(&tensor.shape)?;
508    let cols = tensor.cols;
509    let args = SortRowsArgs::parse(rest, cols)?;
510    sortrows_complex_tensor_with_args(tensor, &args)
511}
512
513fn sortrows_complex_tensor_with_args(
514    tensor: ComplexTensor,
515    args: &SortRowsArgs,
516) -> crate::BuiltinResult<SortRowsEvaluation> {
517    let rows = tensor.rows;
518    let cols = tensor.cols;
519
520    if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
521        let indices = identity_indices(rows)?;
522        return Ok(SortRowsEvaluation {
523            sorted: complex_tensor_into_value(tensor),
524            indices,
525        });
526    }
527
528    let mut order: Vec<usize> = (0..rows).collect();
529    order.sort_by(|&a, &b| compare_complex_rows(&tensor, rows, args, a, b));
530
531    let sorted_tensor = reorder_complex_rows(&tensor, rows, cols, &order)?;
532    let indices = permutation_indices(&order)?;
533    Ok(SortRowsEvaluation {
534        sorted: complex_tensor_into_value(sorted_tensor),
535        indices,
536    })
537}
538
539fn sortrows_char_array(ca: CharArray, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
540    let cols = ca.cols;
541    let args = SortRowsArgs::parse(rest, cols)?;
542    sortrows_char_array_with_args(ca, &args)
543}
544
545fn sortrows_char_array_with_args(
546    ca: CharArray,
547    args: &SortRowsArgs,
548) -> crate::BuiltinResult<SortRowsEvaluation> {
549    let rows = ca.rows;
550    let cols = ca.cols;
551
552    if rows <= 1 || cols == 0 || ca.data.is_empty() || args.columns.is_empty() {
553        let indices = identity_indices(rows)?;
554        return Ok(SortRowsEvaluation {
555            sorted: Value::CharArray(ca),
556            indices,
557        });
558    }
559
560    let mut order: Vec<usize> = (0..rows).collect();
561    order.sort_by(|&a, &b| compare_char_rows(&ca, args, a, b));
562
563    let sorted = reorder_char_rows(&ca, rows, cols, &order)?;
564    let indices = permutation_indices(&order)?;
565    Ok(SortRowsEvaluation {
566        sorted: Value::CharArray(sorted),
567        indices,
568    })
569}
570
571fn ensure_matrix_shape(shape: &[usize]) -> crate::BuiltinResult<()> {
572    if shape.len() <= 2 {
573        Ok(())
574    } else {
575        Err(sortrows_error(&SORTROWS_ERROR_MATRIX_REQUIRED))
576    }
577}
578
/// Derives `(rows, cols)` from a shape vector.
///
/// An empty shape is treated as 1x1, a one-element shape as a single row of
/// that length, and for longer shapes the first two extents are used.
fn rows_cols_from_shape(shape: &[usize]) -> (usize, usize) {
    match *shape {
        [] => (1, 1),
        [len] => (1, len),
        [rows, cols, ..] => (rows, cols),
    }
}
586
587fn compare_real_rows(
588    tensor: &Tensor,
589    rows: usize,
590    args: &SortRowsArgs,
591    a: usize,
592    b: usize,
593) -> Ordering {
594    for spec in &args.columns {
595        if spec.index >= tensor.cols() {
596            continue;
597        }
598        let idx_a = a + spec.index * rows;
599        let idx_b = b + spec.index * rows;
600        let va = tensor.data[idx_a];
601        let vb = tensor.data[idx_b];
602        let missing = args.missing_for_direction(spec.direction);
603        let ord = compare_real_scalars(va, vb, spec.direction, args.comparison, missing);
604        if ord != Ordering::Equal {
605            return ord;
606        }
607    }
608    Ordering::Equal
609}
610
611fn compare_complex_rows(
612    tensor: &ComplexTensor,
613    rows: usize,
614    args: &SortRowsArgs,
615    a: usize,
616    b: usize,
617) -> Ordering {
618    for spec in &args.columns {
619        if spec.index >= tensor.cols {
620            continue;
621        }
622        let idx_a = a + spec.index * rows;
623        let idx_b = b + spec.index * rows;
624        let va = tensor.data[idx_a];
625        let vb = tensor.data[idx_b];
626        let missing = args.missing_for_direction(spec.direction);
627        let ord = compare_complex_scalars(va, vb, spec.direction, args.comparison, missing);
628        if ord != Ordering::Equal {
629            return ord;
630        }
631    }
632    Ordering::Equal
633}
634
635fn compare_char_rows(ca: &CharArray, args: &SortRowsArgs, a: usize, b: usize) -> Ordering {
636    for spec in &args.columns {
637        if spec.index >= ca.cols {
638            continue;
639        }
640        let idx_a = a * ca.cols + spec.index;
641        let idx_b = b * ca.cols + spec.index;
642        let va = ca.data[idx_a];
643        let vb = ca.data[idx_b];
644        let ord = match spec.direction {
645            SortDirection::Ascend => va.cmp(&vb),
646            SortDirection::Descend => vb.cmp(&va),
647        };
648        if ord != Ordering::Equal {
649            return ord;
650        }
651    }
652    Ordering::Equal
653}
654
655fn reorder_real_rows(
656    tensor: &Tensor,
657    rows: usize,
658    cols: usize,
659    order: &[usize],
660) -> crate::BuiltinResult<Tensor> {
661    let mut data = vec![0.0; tensor.data.len()];
662    for col in 0..cols {
663        for (dest_row, &src_row) in order.iter().enumerate() {
664            let src_idx = src_row + col * rows;
665            let dst_idx = dest_row + col * rows;
666            data[dst_idx] = tensor.data[src_idx];
667        }
668    }
669    Tensor::new(data, tensor.shape.clone())
670        .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
671}
672
673fn reorder_complex_rows(
674    tensor: &ComplexTensor,
675    rows: usize,
676    cols: usize,
677    order: &[usize],
678) -> crate::BuiltinResult<ComplexTensor> {
679    let mut data = vec![(0.0, 0.0); tensor.data.len()];
680    for col in 0..cols {
681        for (dest_row, &src_row) in order.iter().enumerate() {
682            let src_idx = src_row + col * rows;
683            let dst_idx = dest_row + col * rows;
684            data[dst_idx] = tensor.data[src_idx];
685        }
686    }
687    ComplexTensor::new(data, tensor.shape.clone())
688        .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
689}
690
691fn reorder_char_rows(
692    ca: &CharArray,
693    rows: usize,
694    cols: usize,
695    order: &[usize],
696) -> crate::BuiltinResult<CharArray> {
697    let mut data = vec!['\0'; ca.data.len()];
698    for (dest_row, &src_row) in order.iter().enumerate() {
699        for col in 0..cols {
700            let src_idx = src_row * cols + col;
701            let dst_idx = dest_row * cols + col;
702            data[dst_idx] = ca.data[src_idx];
703        }
704    }
705    CharArray::new(data, rows, cols).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
706}
707
708fn compare_real_scalars(
709    a: f64,
710    b: f64,
711    direction: SortDirection,
712    comparison: ComparisonMethod,
713    missing: MissingPlacementResolved,
714) -> Ordering {
715    match (a.is_nan(), b.is_nan()) {
716        (true, true) => Ordering::Equal,
717        (true, false) => match missing {
718            MissingPlacementResolved::First => Ordering::Less,
719            MissingPlacementResolved::Last => Ordering::Greater,
720        },
721        (false, true) => match missing {
722            MissingPlacementResolved::First => Ordering::Greater,
723            MissingPlacementResolved::Last => Ordering::Less,
724        },
725        (false, false) => compare_real_finite_scalars(a, b, direction, comparison),
726    }
727}
728
729fn compare_real_finite_scalars(
730    a: f64,
731    b: f64,
732    direction: SortDirection,
733    comparison: ComparisonMethod,
734) -> Ordering {
735    if matches!(comparison, ComparisonMethod::Abs) {
736        let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
737        if abs_cmp != Ordering::Equal {
738            return match direction {
739                SortDirection::Ascend => abs_cmp,
740                SortDirection::Descend => abs_cmp.reverse(),
741            };
742        }
743    }
744    match direction {
745        SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
746        SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
747    }
748}
749
750fn compare_complex_scalars(
751    a: (f64, f64),
752    b: (f64, f64),
753    direction: SortDirection,
754    comparison: ComparisonMethod,
755    missing: MissingPlacementResolved,
756) -> Ordering {
757    match (complex_is_nan(a), complex_is_nan(b)) {
758        (true, true) => Ordering::Equal,
759        (true, false) => match missing {
760            MissingPlacementResolved::First => Ordering::Less,
761            MissingPlacementResolved::Last => Ordering::Greater,
762        },
763        (false, true) => match missing {
764            MissingPlacementResolved::First => Ordering::Greater,
765            MissingPlacementResolved::Last => Ordering::Less,
766        },
767        (false, false) => compare_complex_finite_scalars(a, b, direction, comparison),
768    }
769}
770
771fn compare_complex_finite_scalars(
772    a: (f64, f64),
773    b: (f64, f64),
774    direction: SortDirection,
775    comparison: ComparisonMethod,
776) -> Ordering {
777    match comparison {
778        ComparisonMethod::Real => compare_complex_real_first(a, b, direction),
779        ComparisonMethod::Auto | ComparisonMethod::Abs => {
780            let abs_cmp = complex_abs(a)
781                .partial_cmp(&complex_abs(b))
782                .unwrap_or(Ordering::Equal);
783            if abs_cmp != Ordering::Equal {
784                return match direction {
785                    SortDirection::Ascend => abs_cmp,
786                    SortDirection::Descend => abs_cmp.reverse(),
787                };
788            }
789            compare_complex_real_first(a, b, direction)
790        }
791    }
792}
793
794fn compare_complex_real_first(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
795    let real_cmp = match direction {
796        SortDirection::Ascend => a.0.partial_cmp(&b.0),
797        SortDirection::Descend => b.0.partial_cmp(&a.0),
798    }
799    .unwrap_or(Ordering::Equal);
800    if real_cmp != Ordering::Equal {
801        return real_cmp;
802    }
803    match direction {
804        SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
805        SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
806    }
807}
808
/// Returns `true` when either the real or the imaginary part is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
812
/// Magnitude of a complex value given as `(real, imaginary)`, computed with
/// `hypot`.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    re.hypot(im)
}
816
817fn permutation_indices(order: &[usize]) -> crate::BuiltinResult<Tensor> {
818    let rows = order.len();
819    let mut data = Vec::with_capacity(rows);
820    for &idx in order {
821        data.push((idx + 1) as f64);
822    }
823    Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
824}
825
826fn identity_indices(rows: usize) -> crate::BuiltinResult<Tensor> {
827    let mut data = Vec::with_capacity(rows);
828    for i in 0..rows {
829        data.push((i + 1) as f64);
830    }
831    Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
832}
833
834fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
835    if tensor.data.len() == 1 {
836        Value::Complex(tensor.data[0].0, tensor.data[0].1)
837    } else {
838        Value::ComplexTensor(tensor)
839    }
840}
841
/// Direction in which a sort column is ordered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SortDirection {
    /// Ascending order.
    Ascend,
    /// Descending order.
    Descend,
}
847
848impl SortDirection {
849    fn from_str(value: &str) -> Option<Self> {
850        match value.trim().to_ascii_lowercase().as_str() {
851            "ascend" | "ascending" => Some(SortDirection::Ascend),
852            "descend" | "descending" => Some(SortDirection::Descend),
853            _ => None,
854        }
855    }
856}
857
/// Comparison rule selected through the `ComparisonMethod` option.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ComparisonMethod {
    /// Default rule; the comparison helpers in this file treat it like a plain
    /// numeric comparison for real data and like `Abs` for complex data.
    Auto,
    /// Compare by real part (then imaginary part for complex data).
    Real,
    /// Compare by magnitude.
    Abs,
}
864
/// Requested placement of missing (NaN) values, before it is resolved against
/// a column's sort direction.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MissingPlacement {
    /// Placement depends on the sort direction.
    Auto,
    /// Missing values sort before everything else.
    First,
    /// Missing values sort after everything else.
    Last,
}
871
/// Concrete placement of missing (NaN) values once `Auto` has been resolved.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MissingPlacementResolved {
    /// Missing values sort before everything else.
    First,
    /// Missing values sort after everything else.
    Last,
}
877
878impl MissingPlacement {
879    fn resolve(self, direction: SortDirection) -> MissingPlacementResolved {
880        match self {
881            MissingPlacement::First => MissingPlacementResolved::First,
882            MissingPlacement::Last => MissingPlacementResolved::Last,
883            MissingPlacement::Auto => match direction {
884                SortDirection::Ascend => MissingPlacementResolved::Last,
885                SortDirection::Descend => MissingPlacementResolved::First,
886            },
887        }
888    }
889
890    fn is_auto(self) -> bool {
891        matches!(self, MissingPlacement::Auto)
892    }
893}
894
/// One sort key: which column to compare and in which direction.
#[derive(Debug, Clone)]
struct ColumnSpec {
    /// Zero-based column position (the parser stores the one-based input minus one).
    index: usize,
    /// Ordering applied to this column.
    direction: SortDirection,
}
900
/// Parsed form of the optional arguments that follow the input array in a `sortrows`
/// call.
#[derive(Debug, Clone)]
struct SortRowsArgs {
    /// Sort keys in priority order.
    columns: Vec<ColumnSpec>,
    /// Requested `'ComparisonMethod'` (defaults to `Auto`).
    comparison: ComparisonMethod,
    /// Requested `'MissingPlacement'` (defaults to `Auto`).
    missing: MissingPlacement,
}
907
908impl SortRowsArgs {
909    fn parse(rest: &[Value], num_cols: usize) -> crate::BuiltinResult<Self> {
910        let mut columns: Option<Vec<ColumnSpec>> = None;
911        let mut override_direction: Option<SortDirection> = None;
912        let mut comparison = ComparisonMethod::Auto;
913        let mut missing = MissingPlacement::Auto;
914        let mut i = 0usize;
915
916        while i < rest.len() {
917            if columns.is_none() {
918                if let Some(parsed) = parse_column_vector(&rest[i], num_cols)? {
919                    columns = Some(parsed);
920                    i += 1;
921                    continue;
922                }
923            }
924            if let Some(direction) = parse_direction(&rest[i]) {
925                override_direction = Some(direction);
926                i += 1;
927                continue;
928            }
929            let Some(keyword) = tensor::value_to_string(&rest[i]) else {
930                return Err(sortrows_error_with(
931                    &SORTROWS_ERROR_INVALID_ARGUMENT,
932                    format!("sortrows: invalid argument {:?}", rest[i]),
933                ));
934            };
935            let lowered = keyword.trim().to_ascii_lowercase();
936            match lowered.as_str() {
937                "comparisonmethod" => {
938                    i += 1;
939                    if i >= rest.len() {
940                        return Err(sortrows_error_with(
941                            &SORTROWS_ERROR_INVALID_ARGUMENT,
942                            "sortrows: expected a value for 'ComparisonMethod'",
943                        ));
944                    }
945                    let Some(value_str) = tensor::value_to_string(&rest[i]) else {
946                        return Err(sortrows_error_with(
947                            &SORTROWS_ERROR_INVALID_ARGUMENT,
948                            "sortrows: 'ComparisonMethod' expects a string value",
949                        )
950                        .into());
951                    };
952                    comparison = match value_str.trim().to_ascii_lowercase().as_str() {
953                        "auto" => ComparisonMethod::Auto,
954                        "real" => ComparisonMethod::Real,
955                        "abs" | "magnitude" => ComparisonMethod::Abs,
956                        other => {
957                            return Err(sortrows_error_with(
958                                &SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN,
959                                format!("sortrows: unsupported ComparisonMethod '{other}'"),
960                            )
961                            .into())
962                        }
963                    };
964                    i += 1;
965                }
966                "missingplacement" => {
967                    i += 1;
968                    if i >= rest.len() {
969                        return Err(sortrows_error_with(
970                            &SORTROWS_ERROR_INVALID_ARGUMENT,
971                            "sortrows: expected a value for 'MissingPlacement'",
972                        )
973                        .into());
974                    }
975                    let Some(value_str) = tensor::value_to_string(&rest[i]) else {
976                        return Err(sortrows_error_with(
977                            &SORTROWS_ERROR_INVALID_ARGUMENT,
978                            "sortrows: 'MissingPlacement' expects a string value",
979                        )
980                        .into());
981                    };
982                    missing = match value_str.trim().to_ascii_lowercase().as_str() {
983                        "auto" => MissingPlacement::Auto,
984                        "first" => MissingPlacement::First,
985                        "last" => MissingPlacement::Last,
986                        other => {
987                            return Err(sortrows_error_with(
988                                &SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN,
989                                format!("sortrows: unsupported MissingPlacement '{other}'"),
990                            )
991                            .into())
992                        }
993                    };
994                    i += 1;
995                }
996                other => {
997                    return Err(sortrows_error_with(
998                        &SORTROWS_ERROR_INVALID_ARGUMENT,
999                        format!("sortrows: unexpected argument '{other}'"),
1000                    ));
1001                }
1002            }
1003        }
1004
1005        let mut columns = columns.unwrap_or_else(|| default_columns(num_cols));
1006        if let Some(dir) = override_direction {
1007            for spec in &mut columns {
1008                spec.direction = dir;
1009            }
1010        }
1011        validate_columns(&columns, num_cols)?;
1012
1013        Ok(SortRowsArgs {
1014            columns,
1015            comparison,
1016            missing,
1017        })
1018    }
1019
1020    fn to_provider_columns(&self) -> Vec<ProviderSortRowsColumnSpec> {
1021        self.columns
1022            .iter()
1023            .map(|spec| ProviderSortRowsColumnSpec {
1024                index: spec.index,
1025                order: match spec.direction {
1026                    SortDirection::Ascend => ProviderSortOrder::Ascend,
1027                    SortDirection::Descend => ProviderSortOrder::Descend,
1028                },
1029            })
1030            .collect()
1031    }
1032
1033    fn provider_comparison(&self) -> ProviderSortComparison {
1034        match self.comparison {
1035            ComparisonMethod::Auto => ProviderSortComparison::Auto,
1036            ComparisonMethod::Real => ProviderSortComparison::Real,
1037            ComparisonMethod::Abs => ProviderSortComparison::Abs,
1038        }
1039    }
1040
1041    fn missing_for_direction(&self, direction: SortDirection) -> MissingPlacementResolved {
1042        self.missing.resolve(direction)
1043    }
1044
1045    fn missing_is_auto(&self) -> bool {
1046        self.missing.is_auto()
1047    }
1048}
1049
1050fn parse_column_vector(
1051    value: &Value,
1052    num_cols: usize,
1053) -> crate::BuiltinResult<Option<Vec<ColumnSpec>>> {
1054    match value {
1055        Value::Int(i) => parse_single_column(i.to_i64(), num_cols).map(Some),
1056        Value::Num(n) => {
1057            if !n.is_finite() {
1058                return Err(sortrows_error_with(
1059                    &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1060                    "sortrows: column indices must be finite",
1061                ));
1062            }
1063            let rounded = n.round();
1064            if (rounded - n).abs() > f64::EPSILON {
1065                return Err(sortrows_error_with(
1066                    &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1067                    "sortrows: column indices must be integers",
1068                ));
1069            }
1070            parse_single_column(rounded as i64, num_cols).map(Some)
1071        }
1072        Value::Tensor(tensor) => {
1073            if !is_vector(&tensor.shape) {
1074                return Err(sortrows_error_with(
1075                    &SORTROWS_ERROR_INVALID_ARGUMENT,
1076                    "sortrows: column specification must be a vector",
1077                ));
1078            }
1079            let mut specs = Vec::with_capacity(tensor.data.len());
1080            for &entry in &tensor.data {
1081                if !entry.is_finite() {
1082                    return Err(sortrows_error_with(
1083                        &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1084                        "sortrows: column indices must be finite",
1085                    ));
1086                }
1087                let rounded = entry.round();
1088                if (rounded - entry).abs() > f64::EPSILON {
1089                    return Err(sortrows_error_with(
1090                        &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1091                        "sortrows: column indices must be integers",
1092                    ));
1093                }
1094                let column = parse_single_column_i64(rounded as i64, num_cols)?;
1095                specs.push(column);
1096            }
1097            Ok(Some(specs))
1098        }
1099        _ => Ok(None),
1100    }
1101}
1102
1103fn parse_single_column(value: i64, num_cols: usize) -> crate::BuiltinResult<Vec<ColumnSpec>> {
1104    parse_single_column_i64(value, num_cols).map(|spec| vec![spec])
1105}
1106
1107fn parse_single_column_i64(value: i64, num_cols: usize) -> crate::BuiltinResult<ColumnSpec> {
1108    if value == 0 {
1109        return Err(sortrows_error_with(
1110            &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1111            "sortrows: column indices must be non-zero",
1112        ));
1113    }
1114    let abs = value.unsigned_abs() as usize;
1115    if abs == 0 {
1116        return Err(sortrows_error_with(
1117            &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1118            "sortrows: column indices must be >= 1",
1119        ));
1120    }
1121    if num_cols == 0 {
1122        return Err(sortrows_error_with(
1123            &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1124            "sortrows: column index exceeds matrix with 0 columns",
1125        ));
1126    }
1127    if abs > num_cols {
1128        return Err(sortrows_error_with(
1129            &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1130            format!(
1131                "sortrows: column index {} exceeds matrix with {} columns",
1132                abs, num_cols
1133            ),
1134        )
1135        .into());
1136    }
1137    let direction = if value > 0 {
1138        SortDirection::Ascend
1139    } else {
1140        SortDirection::Descend
1141    };
1142    Ok(ColumnSpec {
1143        index: abs - 1,
1144        direction,
1145    })
1146}
1147
1148fn parse_direction(value: &Value) -> Option<SortDirection> {
1149    tensor::value_to_string(value).and_then(|s| SortDirection::from_str(&s))
1150}
1151
1152fn default_columns(num_cols: usize) -> Vec<ColumnSpec> {
1153    let mut columns = Vec::with_capacity(num_cols);
1154    for col in 0..num_cols {
1155        columns.push(ColumnSpec {
1156            index: col,
1157            direction: SortDirection::Ascend,
1158        });
1159    }
1160    columns
1161}
1162
1163fn validate_columns(columns: &[ColumnSpec], num_cols: usize) -> crate::BuiltinResult<()> {
1164    if num_cols == 0 && columns.iter().any(|spec| spec.index > 0) {
1165        return Err(sortrows_error_with(
1166            &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1167            "sortrows: column index exceeds matrix with 0 columns",
1168        ));
1169    }
1170    for spec in columns {
1171        if num_cols > 0 && spec.index >= num_cols {
1172            return Err(sortrows_error_with(
1173                &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1174                format!(
1175                    "sortrows: column index {} exceeds matrix with {} columns",
1176                    spec.index + 1,
1177                    num_cols
1178                ),
1179            )
1180            .into());
1181        }
1182    }
1183    Ok(())
1184}
1185
/// Returns `true` when `shape` describes a vector: zero or one dimension, or two
/// dimensions where at least one extent equals 1. Shapes with three or more dimensions
/// are never vectors.
fn is_vector(shape: &[usize]) -> bool {
    match shape {
        [] | [_] => true,
        [rows, cols] => *rows == 1 || *cols == 1,
        _ => false,
    }
}
1194
/// Outcome of a `sortrows` evaluation: the reordered input together with the row
/// permutation that produced it.
#[derive(Debug)]
pub struct SortRowsEvaluation {
    /// The input with its rows rearranged into sorted order.
    sorted: Value,
    /// Row permutation as a tensor (one-based row numbers in the tests below).
    indices: Tensor,
}
1200
1201impl SortRowsEvaluation {
1202    pub fn into_sorted_value(self) -> Value {
1203        self.sorted
1204    }
1205
1206    pub fn into_values(self) -> (Value, Value) {
1207        let indices = tensor::tensor_into_value(self.indices);
1208        (self.sorted, indices)
1209    }
1210
1211    pub fn indices_value(&self) -> Value {
1212        tensor::tensor_into_value(self.indices.clone())
1213    }
1214}
1215
// Unit tests for `sortrows`. The numeric fixtures below are consistent with tensor data
// being laid out column by column (all of column 1, then all of column 2, ...).
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{IntValue, ResolveContext, Type, Value};

    // Synchronous wrapper: drives the async `super::evaluate` to completion.
    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
        futures::executor::block_on(super::evaluate(value, rest))
    }

    // No extra arguments: rows (3,4), (1,1), (2,5) sort to (1,1), (2,5), (3,4),
    // i.e. original rows 2, 3, 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_default_matrix() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // The type resolver maps a tensor input to a tensor output.
    #[test]
    fn sortrows_type_resolver_tensor() {
        assert_eq!(
            tensor_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    // Column specification [2, 3, 1]: rows (1,4,2), (3,2,5), (3,2,1) are ordered by
    // column 2, then 3, then 1, giving (3,2,1), (3,2,5), (1,4,2).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_with_column_vector() {
        let tensor = Tensor::new(
            vec![1.0, 3.0, 3.0, 4.0, 2.0, 2.0, 2.0, 5.0, 1.0],
            vec![3, 3],
        )
        .unwrap();
        let cols = Tensor::new(vec![2.0, 3.0, 1.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![3.0, 3.0, 1.0, 2.0, 2.0, 4.0, 1.0, 5.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // A bare "descend" keyword reverses the default order: rows (1,4), (2,3) become
    // (2,3), (1,4).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_direction_descend() {
        let tensor = Tensor::new(vec![1.0, 2.0, 4.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 3.0, 4.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Column specification [1, -2]: column 1 ascending (all ties here), then column 2
    // descending, so the second column comes out as 7, 2, 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_mixed_directions() {
        let tensor = Tensor::new(vec![1.0, 1.0, 1.0, 1.0, 7.0, 2.0], vec![3, 2]).unwrap();
        let cols = Tensor::new(vec![1.0, -2.0], vec![2, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 1.0, 1.0, 7.0, 2.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // The second output holds one-based source rows: rows (2,3), (1,4) sort to rows 2, 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_returns_indices() {
        let tensor = Tensor::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (_, indices) = eval.into_values();
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // Character arrays: the three 4-character rows "bob ", "al  ", "ally" are expected
    // back as "al  ", "ally", "bob ". Rows are read back from `ca.data` in contiguous
    // chunks of `ca.cols` characters.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_char_array() {
        let chars = CharArray::new(
            "bob "
                .chars()
                .chain("al  ".chars())
                .chain("ally".chars())
                .collect(),
            3,
            4,
        )
        .unwrap();
        let eval = evaluate(Value::CharArray(chars), &[]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 3);
                assert_eq!(ca.cols, 4);
                let strings: Vec<String> = (0..ca.rows)
                    .map(|r| {
                        ca.data[r * ca.cols..(r + 1) * ca.cols]
                            .iter()
                            .collect::<String>()
                    })
                    .collect();
                assert_eq!(
                    strings,
                    vec!["al  ".to_string(), "ally".to_string(), "bob ".to_string()]
                );
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // 'ComparisonMethod' = 'abs' on complex input. Both entries have the same magnitude;
    // the expected order places (-2, 1) before (1, 2).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_complex_abs() {
        let tensor = ComplexTensor::new(vec![(1.0, 2.0), (-2.0, 1.0)], vec![2, 1]).unwrap();
        let eval = evaluate(
            Value::ComplexTensor(tensor),
            &[Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.data, vec![(-2.0, 1.0), (1.0, 2.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Column 3 requested on a one-column input must fail with the invalid-column-index
    // identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_invalid_column_index_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(3))]).unwrap_err();
        assert_eq!(
            err.identifier(),
            SORTROWS_ERROR_INVALID_COLUMN_INDEX.identifier
        );
    }

    // 'MissingPlacement' = 'first': the row whose first column is NaN, (NaN,3), is moved
    // ahead of (1,2).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_first_moves_nan_first() {
        let tensor = Tensor::new(vec![1.0, f64::NAN, 2.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("first")],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                // NaN != NaN, so the NaN slot is checked with `is_nan` instead of `assert_eq!`.
                assert!(t.data[0].is_nan());
                assert_eq!(t.data[1], 1.0);
                assert_eq!(t.data[2], 3.0);
                assert_eq!(t.data[3], 2.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // "descend" combined with 'MissingPlacement' = 'last': the NaN row (NaN,1) ends up
    // after (5,2).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_last_descend_moves_nan_last() {
        let tensor = Tensor::new(vec![f64::NAN, 5.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::from("descend"),
                Value::from("MissingPlacement"),
                Value::from("last"),
            ],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data[0], 5.0);
                assert!(t.data[1].is_nan());
                assert_eq!(t.data[2], 2.0);
                assert_eq!(t.data[3], 1.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // An unrecognised 'MissingPlacement' value must fail with the dedicated identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_invalid_value_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("middle")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN.identifier
        );
    }

    // Same fixture as `sortrows_default_matrix`, but uploaded through the test provider
    // and passed in as a GPU tensor handle; results are expected back as host tensors.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
            let (sorted, indices) = eval.into_values();
            match sorted {
                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
            match indices {
                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
                other => panic!("unexpected indices {other:?}"),
            }
        });
    }

    // Only built with the `wgpu` feature: sorts the same data on the host and through
    // the registered wgpu provider and requires identical shapes, data, and indices.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn sortrows_wgpu_matches_cpu() {
        // The registration result is deliberately ignored.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0, 2.0, 5.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let (cpu_sorted_val, cpu_indices_val) = cpu_eval.into_values();
        let cpu_sorted = match cpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let cpu_indices = match cpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("provider");
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("gpu evaluate");
        let (gpu_sorted_val, gpu_indices_val) = gpu_eval.into_values();
        let gpu_sorted = match gpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let gpu_indices = match gpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        assert_eq!(gpu_sorted.shape, cpu_sorted.shape);
        assert_eq!(gpu_sorted.data, cpu_sorted.data);
        assert_eq!(gpu_indices.shape, cpu_indices.shape);
        assert_eq!(gpu_indices.data, cpu_indices.data);

        // Release the uploaded buffer; a failure to free is ignored.
        let _ = provider.free(&handle);
    }
}