Skip to main content

runmat_runtime/builtins/array/sorting_sets/
sort.rs

1//! MATLAB-compatible `sort` builtin with multi-output and GPU-aware semantics.
2
3use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6    GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7};
8use runmat_builtins::{
9    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
10    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
11    ComplexTensor, Tensor, Value,
12};
13use runmat_macros::runtime_builtin;
14
15use super::type_resolvers::tensor_output_type;
16use crate::build_runtime_error;
17use crate::builtins::common::arg_tokens::{tokens_from_values, ArgToken};
18use crate::builtins::common::gpu_helpers;
19use crate::builtins::common::spec::{
20    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24
25#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::sort")]
26pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
27    name: "sort",
28    op_kind: GpuOpKind::Custom("sort"),
29    supported_precisions: &[ScalarType::F32, ScalarType::F64],
30    broadcast: BroadcastSemantics::None,
31    provider_hooks: &[ProviderHook::Custom("sort_dim")],
32    constant_strategy: ConstantStrategy::InlineLiteral,
33    residency: ResidencyPolicy::GatherImmediately,
34    nan_mode: ReductionNaN::Include,
35    two_pass_threshold: None,
36    workgroup_size: None,
37    accepts_nan_mode: true,
38    notes: "Providers may add a dedicated sort kernel in the future; today tensors are gathered to host memory before sorting.",
39};
40
41#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::sorting_sets::sort")]
42pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
43    name: "sort",
44    shape: ShapeRequirements::Any,
45    constant_strategy: ConstantStrategy::InlineLiteral,
46    elementwise: None,
47    reduction: None,
48    emits_nan: true,
49    notes: "Sorting breaks fusion chains and acts as a residency sink; upstream tensors are gathered to host memory.",
50};
51
/// Builtin name attached to every error raised by this module.
const BUILTIN_NAME: &str = "sort";
53
54const SORT_OUTPUT_B: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
55    name: "B",
56    ty: BuiltinParamType::Any,
57    arity: BuiltinParamArity::Required,
58    default: None,
59    description: "Sorted values.",
60}];
61
62const SORT_OUTPUT_BI: [BuiltinParamDescriptor; 2] = [
63    BuiltinParamDescriptor {
64        name: "B",
65        ty: BuiltinParamType::Any,
66        arity: BuiltinParamArity::Required,
67        default: None,
68        description: "Sorted values.",
69    },
70    BuiltinParamDescriptor {
71        name: "I",
72        ty: BuiltinParamType::NumericArray,
73        arity: BuiltinParamArity::Required,
74        default: None,
75        description: "One-based index permutation for each sorted slice.",
76    },
77];
78
79const SORT_INPUTS_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
80    name: "A",
81    ty: BuiltinParamType::Any,
82    arity: BuiltinParamArity::Required,
83    default: None,
84    description: "Input array.",
85}];
86
87const SORT_INPUTS_A_ARG1: [BuiltinParamDescriptor; 2] = [
88    BuiltinParamDescriptor {
89        name: "A",
90        ty: BuiltinParamType::Any,
91        arity: BuiltinParamArity::Required,
92        default: None,
93        description: "Input array.",
94    },
95    BuiltinParamDescriptor {
96        name: "arg1",
97        ty: BuiltinParamType::Any,
98        arity: BuiltinParamArity::Required,
99        default: None,
100        description: "Dimension selector or direction token ('ascend'/'descend').",
101    },
102];
103
104const SORT_INPUTS_A_ARG1_ARG2: [BuiltinParamDescriptor; 3] = [
105    BuiltinParamDescriptor {
106        name: "A",
107        ty: BuiltinParamType::Any,
108        arity: BuiltinParamArity::Required,
109        default: None,
110        description: "Input array.",
111    },
112    BuiltinParamDescriptor {
113        name: "arg1",
114        ty: BuiltinParamType::Any,
115        arity: BuiltinParamArity::Required,
116        default: None,
117        description: "Dimension selector, placeholder, or direction token.",
118    },
119    BuiltinParamDescriptor {
120        name: "arg2",
121        ty: BuiltinParamType::Any,
122        arity: BuiltinParamArity::Required,
123        default: None,
124        description: "Dimension selector or direction token.",
125    },
126];
127
128const SORT_INPUTS_COMPARISON_METHOD: [BuiltinParamDescriptor; 4] = [
129    BuiltinParamDescriptor {
130        name: "A",
131        ty: BuiltinParamType::Any,
132        arity: BuiltinParamArity::Required,
133        default: None,
134        description: "Input array.",
135    },
136    BuiltinParamDescriptor {
137        name: "arg",
138        ty: BuiltinParamType::Any,
139        arity: BuiltinParamArity::Variadic,
140        default: None,
141        description: "Optional dimension/direction arguments.",
142    },
143    BuiltinParamDescriptor {
144        name: "name",
145        ty: BuiltinParamType::StringScalar,
146        arity: BuiltinParamArity::Required,
147        default: Some("\"ComparisonMethod\""),
148        description: "Name-value option key.",
149    },
150    BuiltinParamDescriptor {
151        name: "method",
152        ty: BuiltinParamType::StringScalar,
153        arity: BuiltinParamArity::Required,
154        default: Some("\"auto\""),
155        description: "Comparison method: 'auto', 'real', or 'abs'.",
156    },
157];
158
159const SORT_INPUTS_MISSING_PLACEMENT: [BuiltinParamDescriptor; 4] = [
160    BuiltinParamDescriptor {
161        name: "A",
162        ty: BuiltinParamType::Any,
163        arity: BuiltinParamArity::Required,
164        default: None,
165        description: "Input array.",
166    },
167    BuiltinParamDescriptor {
168        name: "arg",
169        ty: BuiltinParamType::Any,
170        arity: BuiltinParamArity::Variadic,
171        default: None,
172        description: "Optional dimension/direction arguments.",
173    },
174    BuiltinParamDescriptor {
175        name: "name",
176        ty: BuiltinParamType::StringScalar,
177        arity: BuiltinParamArity::Required,
178        default: Some("\"MissingPlacement\""),
179        description: "Name-value option key.",
180    },
181    BuiltinParamDescriptor {
182        name: "placement",
183        ty: BuiltinParamType::StringScalar,
184        arity: BuiltinParamArity::Required,
185        default: Some("\"auto\""),
186        description: "Requested NaN placement option (currently unsupported).",
187    },
188];
189
190const SORT_SIGNATURES: [BuiltinSignatureDescriptor; 10] = [
191    BuiltinSignatureDescriptor {
192        label: "B = sort(A)",
193        inputs: &SORT_INPUTS_A,
194        outputs: &SORT_OUTPUT_B,
195    },
196    BuiltinSignatureDescriptor {
197        label: "B = sort(A, arg1)",
198        inputs: &SORT_INPUTS_A_ARG1,
199        outputs: &SORT_OUTPUT_B,
200    },
201    BuiltinSignatureDescriptor {
202        label: "B = sort(A, arg1, arg2)",
203        inputs: &SORT_INPUTS_A_ARG1_ARG2,
204        outputs: &SORT_OUTPUT_B,
205    },
206    BuiltinSignatureDescriptor {
207        label: "B = sort(A, ..., \"ComparisonMethod\", method)",
208        inputs: &SORT_INPUTS_COMPARISON_METHOD,
209        outputs: &SORT_OUTPUT_B,
210    },
211    BuiltinSignatureDescriptor {
212        label: "B = sort(A, ..., \"MissingPlacement\", placement)",
213        inputs: &SORT_INPUTS_MISSING_PLACEMENT,
214        outputs: &SORT_OUTPUT_B,
215    },
216    BuiltinSignatureDescriptor {
217        label: "[B, I] = sort(A)",
218        inputs: &SORT_INPUTS_A,
219        outputs: &SORT_OUTPUT_BI,
220    },
221    BuiltinSignatureDescriptor {
222        label: "[B, I] = sort(A, arg1)",
223        inputs: &SORT_INPUTS_A_ARG1,
224        outputs: &SORT_OUTPUT_BI,
225    },
226    BuiltinSignatureDescriptor {
227        label: "[B, I] = sort(A, arg1, arg2)",
228        inputs: &SORT_INPUTS_A_ARG1_ARG2,
229        outputs: &SORT_OUTPUT_BI,
230    },
231    BuiltinSignatureDescriptor {
232        label: "[B, I] = sort(A, ..., \"ComparisonMethod\", method)",
233        inputs: &SORT_INPUTS_COMPARISON_METHOD,
234        outputs: &SORT_OUTPUT_BI,
235    },
236    BuiltinSignatureDescriptor {
237        label: "[B, I] = sort(A, ..., \"MissingPlacement\", placement)",
238        inputs: &SORT_INPUTS_MISSING_PLACEMENT,
239        outputs: &SORT_OUTPUT_BI,
240    },
241];
242
243const SORT_ERROR_INVALID_DIMENSION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
244    code: "RM.SORT.INVALID_DIMENSION",
245    identifier: Some("RunMat:sort:InvalidDimension"),
246    when: "Dimension argument is non-positive, non-integer, or otherwise invalid.",
247    message: "sort: invalid dimension argument",
248};
249
250const SORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING: BuiltinErrorDescriptor =
251    BuiltinErrorDescriptor {
252        code: "RM.SORT.COMPARISON_METHOD_REQUIRES_STRING",
253        identifier: Some("RunMat:sort:ComparisonMethodRequiresString"),
254        when: "ComparisonMethod option value is not string-like.",
255        message: "sort: 'ComparisonMethod' requires a string value",
256    };
257
258const SORT_ERROR_COMPARISON_METHOD_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
259    code: "RM.SORT.COMPARISON_METHOD_UNKNOWN",
260    identifier: Some("RunMat:sort:ComparisonMethodUnknown"),
261    when: "ComparisonMethod option value is not one of 'auto'/'real'/'abs'.",
262    message: "sort: unsupported ComparisonMethod",
263};
264
265const SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
266    code: "RM.SORT.MISSINGPLACEMENT_UNSUPPORTED",
267    identifier: Some("RunMat:sort:MissingPlacementUnsupported"),
268    when: "MissingPlacement option is provided but unsupported.",
269    message: "sort: the 'MissingPlacement' option is not supported yet",
270};
271
272const SORT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
273    code: "RM.SORT.INVALID_ARGUMENT",
274    identifier: Some("RunMat:sort:InvalidArgument"),
275    when: "Parser encounters invalid or unrecognized option/value arguments.",
276    message: "sort: invalid argument sequence",
277};
278
279const SORT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
280    code: "RM.SORT.INTERNAL",
281    identifier: Some("RunMat:sort:Internal"),
282    when: "Internal conversion, allocation, or provider result construction fails.",
283    message: "sort: internal operation failed",
284};
285
286const SORT_ERRORS: [BuiltinErrorDescriptor; 6] = [
287    SORT_ERROR_INVALID_DIMENSION,
288    SORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING,
289    SORT_ERROR_COMPARISON_METHOD_UNKNOWN,
290    SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED,
291    SORT_ERROR_INVALID_ARGUMENT,
292    SORT_ERROR_INTERNAL,
293];
294
295pub const SORT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
296    signatures: &SORT_SIGNATURES,
297    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
298    completion_policy: BuiltinCompletionPolicy::Public,
299    errors: &SORT_ERRORS,
300};
301
302fn sort_error(
303    error: &'static BuiltinErrorDescriptor,
304    message: impl Into<String>,
305) -> crate::RuntimeError {
306    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
307    if let Some(identifier) = error.identifier {
308        builder = builder.with_identifier(identifier);
309    }
310    builder.build()
311}
312
313fn sort_internal(message: impl Into<String>) -> crate::RuntimeError {
314    sort_error(&SORT_ERROR_INTERNAL, message)
315}
316
317fn sort_invalid_argument(message: impl Into<String>) -> crate::RuntimeError {
318    sort_error(&SORT_ERROR_INVALID_ARGUMENT, message)
319}
320
321#[runtime_builtin(
322    name = "sort",
323    category = "array/sorting_sets",
324    summary = "Sort array elements along a dimension with optional index outputs.",
325    keywords = "sort,ascending,descending,indices,comparisonmethod,gpu",
326    accel = "sink",
327    sink = true,
328    type_resolver(tensor_output_type),
329    descriptor(crate::builtins::array::sorting_sets::sort::SORT_DESCRIPTOR),
330    builtin_path = "crate::builtins::array::sorting_sets::sort"
331)]
332async fn sort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
333    let eval = evaluate(value, &rest).await?;
334    if let Some(out_count) = crate::output_count::current_output_count() {
335        if out_count == 0 {
336            return Ok(Value::OutputList(Vec::new()));
337        }
338        let (sorted, indices) = eval.into_values();
339        let mut outputs = vec![sorted];
340        if out_count >= 2 {
341            outputs.push(indices);
342        }
343        return Ok(crate::output_count::output_list_with_padding(
344            out_count, outputs,
345        ));
346    }
347    Ok(eval.into_sorted_value())
348}
349
350/// Evaluate the `sort` builtin once and expose both outputs.
351pub async fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortEvaluation> {
352    let args = SortArgs::parse(rest)?;
353    match value {
354        Value::GpuTensor(handle) => sort_gpu(handle, &args).await,
355        other => sort_host(other, &args),
356    }
357}
358
359async fn sort_gpu(
360    handle: GpuTensorHandle,
361    args: &SortArgs,
362) -> crate::BuiltinResult<SortEvaluation> {
363    let shape = handle.shape.clone();
364    let dim = args.dimension.unwrap_or_else(|| default_dimension(&shape));
365    if dim == 0 {
366        return Err(sort_error(
367            &SORT_ERROR_INVALID_DIMENSION,
368            "sort: dimension must be >= 1",
369        ));
370    }
371    let dim_len = dimension_length(&shape, dim);
372    if dim_len > 1 {
373        if let Some(provider) = runmat_accelerate_api::provider() {
374            let order = args.direction.to_provider();
375            let comparison = args.comparison.to_provider();
376            let zero_based = dim - 1;
377            if let Ok(result) = provider
378                .sort_dim(&handle, zero_based, order, comparison)
379                .await
380            {
381                let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
382                    .map_err(|e| sort_internal(format!("sort: {e}")))?;
383                let sorted_value = tensor::tensor_into_value(sorted_tensor);
384                let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
385                    .map_err(|e| sort_internal(format!("sort: {e}")))?;
386                return Ok(SortEvaluation {
387                    sorted: sorted_value,
388                    indices: indices_tensor,
389                });
390            }
391        }
392    }
393    let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
394    sort_real_tensor(tensor, args)
395}
396
397fn sort_host(value: Value, args: &SortArgs) -> crate::BuiltinResult<SortEvaluation> {
398    match value {
399        Value::ComplexTensor(ct) => sort_complex_tensor(ct, args),
400        Value::Complex(re, im) => {
401            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
402                .map_err(|e| sort_internal(format!("sort: {e}")))?;
403            sort_complex_tensor(tensor, args)
404        }
405        other => {
406            let tensor =
407                tensor::value_into_tensor_for("sort", other).map_err(sort_invalid_argument)?;
408            sort_real_tensor(tensor, args)
409        }
410    }
411}
412
413fn sort_real_tensor(tensor: Tensor, args: &SortArgs) -> crate::BuiltinResult<SortEvaluation> {
414    let dim = args
415        .dimension
416        .unwrap_or_else(|| default_dimension(&tensor.shape));
417    if dim == 0 {
418        return Err(sort_error(
419            &SORT_ERROR_INVALID_DIMENSION,
420            "sort: dimension must be >= 1",
421        ));
422    }
423
424    let dim_len = dimension_length(&tensor.shape, dim);
425    if tensor.data.is_empty() || dim_len <= 1 {
426        let indices = vec![1.0; tensor.data.len()];
427        let index_tensor = Tensor::new(indices, tensor.shape.clone())
428            .map_err(|e| sort_internal(format!("sort: {e}")))?;
429        let sorted_value = tensor::tensor_into_value(tensor);
430        return Ok(SortEvaluation {
431            sorted: sorted_value,
432            indices: index_tensor,
433        });
434    }
435
436    let stride_before = stride_before(&tensor.shape, dim);
437    let stride_after = stride_after(&tensor.shape, dim);
438    let mut sorted = tensor.data.clone();
439    let mut indices = vec![0.0f64; tensor.data.len()];
440    let mut buffer: Vec<(usize, f64)> = Vec::with_capacity(dim_len);
441
442    for after in 0..stride_after {
443        for before in 0..stride_before {
444            buffer.clear();
445            for k in 0..dim_len {
446                let idx = before + k * stride_before + after * stride_before * dim_len;
447                let value = tensor.data[idx];
448                buffer.push((k, value));
449            }
450            buffer.sort_by(|a, b| compare_real_values(a.1, b.1, args));
451            for (pos, (original_index, value)) in buffer.iter().enumerate() {
452                let target = before + pos * stride_before + after * stride_before * dim_len;
453                sorted[target] = *value;
454                indices[target] = (*original_index + 1) as f64;
455            }
456        }
457    }
458
459    let sorted_tensor = Tensor::new(sorted, tensor.shape.clone())
460        .map_err(|e| sort_internal(format!("sort: {e}")))?;
461    let index_tensor = Tensor::new(indices, tensor.shape.clone())
462        .map_err(|e| sort_internal(format!("sort: {e}")))?;
463
464    Ok(SortEvaluation {
465        sorted: tensor::tensor_into_value(sorted_tensor),
466        indices: index_tensor,
467    })
468}
469
470fn sort_complex_tensor(
471    tensor: ComplexTensor,
472    args: &SortArgs,
473) -> crate::BuiltinResult<SortEvaluation> {
474    let dim = args
475        .dimension
476        .unwrap_or_else(|| default_dimension(&tensor.shape));
477    if dim == 0 {
478        return Err(sort_error(
479            &SORT_ERROR_INVALID_DIMENSION,
480            "sort: dimension must be >= 1",
481        ));
482    }
483
484    let dim_len = dimension_length(&tensor.shape, dim);
485    if tensor.data.is_empty() || dim_len <= 1 {
486        let indices = vec![1.0; tensor.data.len()];
487        let index_tensor = Tensor::new(indices, tensor.shape.clone())
488            .map_err(|e| sort_internal(format!("sort: {e}")))?;
489        return Ok(SortEvaluation {
490            sorted: complex_tensor_into_value(tensor),
491            indices: index_tensor,
492        });
493    }
494
495    let stride_before = stride_before(&tensor.shape, dim);
496    let stride_after = stride_after(&tensor.shape, dim);
497    let mut sorted = tensor.data.clone();
498    let mut indices = vec![0.0f64; tensor.data.len()];
499    let mut buffer: Vec<(usize, (f64, f64))> = Vec::with_capacity(dim_len);
500
501    for after in 0..stride_after {
502        for before in 0..stride_before {
503            buffer.clear();
504            for k in 0..dim_len {
505                let idx = before + k * stride_before + after * stride_before * dim_len;
506                let value = tensor.data[idx];
507                buffer.push((k, value));
508            }
509            buffer.sort_by(|a, b| compare_complex_values(a.1, b.1, args));
510            for (pos, (original_index, value)) in buffer.iter().enumerate() {
511                let target = before + pos * stride_before + after * stride_before * dim_len;
512                sorted[target] = *value;
513                indices[target] = (*original_index + 1) as f64;
514            }
515        }
516    }
517
518    let sorted_tensor = ComplexTensor::new(sorted, tensor.shape.clone())
519        .map_err(|e| sort_internal(format!("sort: {e}")))?;
520    let index_tensor = Tensor::new(indices, tensor.shape.clone())
521        .map_err(|e| sort_internal(format!("sort: {e}")))?;
522
523    Ok(SortEvaluation {
524        sorted: complex_tensor_into_value(sorted_tensor),
525        indices: index_tensor,
526    })
527}
528
529fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
530    if tensor.data.len() == 1 {
531        let (re, im) = tensor.data[0];
532        Value::Complex(re, im)
533    } else {
534        Value::ComplexTensor(tensor)
535    }
536}
537
538fn compare_real_values(a: f64, b: f64, args: &SortArgs) -> Ordering {
539    match (a.is_nan(), b.is_nan()) {
540        (true, true) => Ordering::Equal,
541        (true, false) => match args.direction {
542            SortDirection::Ascend => Ordering::Greater,
543            SortDirection::Descend => Ordering::Less,
544        },
545        (false, true) => match args.direction {
546            SortDirection::Ascend => Ordering::Less,
547            SortDirection::Descend => Ordering::Greater,
548        },
549        (false, false) => compare_real_finite(a, b, args),
550    }
551}
552
553fn compare_real_finite(a: f64, b: f64, args: &SortArgs) -> Ordering {
554    let primary = match args.comparison {
555        ComparisonMethod::Abs => {
556            let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
557            if abs_cmp != Ordering::Equal {
558                return match args.direction {
559                    SortDirection::Ascend => abs_cmp,
560                    SortDirection::Descend => abs_cmp.reverse(),
561                };
562            }
563            Ordering::Equal
564        }
565        ComparisonMethod::Auto | ComparisonMethod::Real => Ordering::Equal,
566    };
567    if primary != Ordering::Equal {
568        return primary;
569    }
570    match args.direction {
571        SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
572        SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
573    }
574}
575
576fn compare_complex_values(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
577    match (complex_is_nan(a), complex_is_nan(b)) {
578        (true, true) => Ordering::Equal,
579        (true, false) => match args.direction {
580            SortDirection::Ascend => Ordering::Greater,
581            SortDirection::Descend => Ordering::Less,
582        },
583        (false, true) => match args.direction {
584            SortDirection::Ascend => Ordering::Less,
585            SortDirection::Descend => Ordering::Greater,
586        },
587        (false, false) => compare_complex_finite(a, b, args),
588    }
589}
590
591fn compare_complex_finite(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
592    match args.comparison {
593        ComparisonMethod::Real => compare_complex_real_imag(a, b, args.direction),
594        ComparisonMethod::Abs | ComparisonMethod::Auto => {
595            let abs_cmp = complex_abs(a)
596                .partial_cmp(&complex_abs(b))
597                .unwrap_or(Ordering::Equal);
598            if abs_cmp != Ordering::Equal {
599                return match args.direction {
600                    SortDirection::Ascend => abs_cmp,
601                    SortDirection::Descend => abs_cmp.reverse(),
602                };
603            }
604            compare_complex_real_imag(a, b, args.direction)
605        }
606    }
607}
608
609fn compare_complex_real_imag(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
610    let real_cmp = match direction {
611        SortDirection::Ascend => a.0.partial_cmp(&b.0),
612        SortDirection::Descend => b.0.partial_cmp(&a.0),
613    }
614    .unwrap_or(Ordering::Equal);
615    if real_cmp != Ordering::Equal {
616        return real_cmp;
617    }
618    match direction {
619        SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
620        SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
621    }
622}
623
/// True when either component of the complex value is NaN.
fn complex_is_nan((re, im): (f64, f64)) -> bool {
    [re, im].iter().any(|component| component.is_nan())
}
627
/// Magnitude of a complex value, computed with `hypot` to avoid intermediate overflow.
fn complex_abs((re, im): (f64, f64)) -> f64 {
    f64::hypot(re, im)
}
631
/// Number of elements spanned by the dimensions before the one-based dimension `dim`.
///
/// This is the column-major step between consecutive elements along `dim`. Missing leading
/// extents count as 1, `dim <= 1` yields 1, and the product saturates instead of overflowing.
fn stride_before(shape: &[usize], dim: usize) -> usize {
    shape
        .iter()
        .take(dim.saturating_sub(1))
        .fold(1usize, |acc, &extent| acc.saturating_mul(extent))
}
642
/// Number of slices along the dimensions after the one-based dimension `dim`.
///
/// Returns 1 when `dim` is the last (or a trailing implicit) dimension. The product saturates
/// instead of overflowing.
fn stride_after(shape: &[usize], dim: usize) -> usize {
    shape
        .iter()
        .skip(dim)
        .fold(1usize, |acc, &extent| acc.saturating_mul(extent))
}
653
/// Extent of the one-based dimension `dim` in `shape`.
///
/// Dimensions past the end of `shape` are implicit trailing singletons and report `1`.
/// `dim == 0` is not a valid one-based dimension. It also reports `1` instead of underflowing
/// `dim - 1`, which previously panicked in debug builds and wrapped in release builds.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    dim.checked_sub(1)
        .and_then(|index| shape.get(index))
        .copied()
        .unwrap_or(1)
}
657
/// One-based index of the first dimension with extent greater than 1, or 1 if there is none.
fn default_dimension(shape: &[usize]) -> usize {
    for (axis, &extent) in shape.iter().enumerate() {
        if extent > 1 {
            return axis + 1;
        }
    }
    1
}
665
/// Requested ordering of each sorted slice.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum SortDirection {
    /// Smallest first; NaN values are placed last. This is the default.
    #[default]
    Ascend,
    /// Largest first; NaN values are placed first.
    Descend,
}
672
673impl SortDirection {
674    fn to_provider(self) -> ProviderSortOrder {
675        match self {
676            SortDirection::Ascend => ProviderSortOrder::Ascend,
677            SortDirection::Descend => ProviderSortOrder::Descend,
678        }
679    }
680}
681
/// Element comparison selected by the `'ComparisonMethod'` option.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum ComparisonMethod {
    /// Signed values for real input, magnitude for complex input. This is the default.
    #[default]
    Auto,
    /// Signed real value; complex ties are broken by the imaginary part.
    Real,
    /// Magnitude first (also selected by the `'magnitude'` spelling).
    Abs,
}
689
690impl ComparisonMethod {
691    fn to_provider(self) -> ProviderSortComparison {
692        match self {
693            ComparisonMethod::Auto => ProviderSortComparison::Auto,
694            ComparisonMethod::Real => ProviderSortComparison::Real,
695            ComparisonMethod::Abs => ProviderSortComparison::Abs,
696        }
697    }
698}
699
700#[derive(Debug, Clone, Default)]
701struct SortArgs {
702    dimension: Option<usize>,
703    direction: SortDirection,
704    comparison: ComparisonMethod,
705}
706
707impl SortArgs {
708    fn parse(rest: &[Value]) -> crate::BuiltinResult<Self> {
709        let mut args = SortArgs::default();
710        let tokens = tokens_from_values(rest);
711        let mut i = 0usize;
712        while i < rest.len() {
713            if args.dimension.is_none() {
714                if is_dimension_placeholder(&rest[i]) {
715                    i += 1;
716                    continue;
717                }
718                match tensor::parse_dimension(&rest[i], "sort") {
719                    Ok(dim) => {
720                        args.dimension = Some(dim);
721                        i += 1;
722                        continue;
723                    }
724                    Err(err) => {
725                        if matches!(rest[i], Value::Int(_) | Value::Num(_)) {
726                            return Err(sort_error(&SORT_ERROR_INVALID_DIMENSION, err));
727                        }
728                    }
729                }
730            }
731            if let Some(ArgToken::String(text)) = tokens.get(i) {
732                match text.as_str() {
733                    "ascend" | "ascending" => {
734                        args.direction = SortDirection::Ascend;
735                        i += 1;
736                        continue;
737                    }
738                    "descend" | "descending" => {
739                        args.direction = SortDirection::Descend;
740                        i += 1;
741                        continue;
742                    }
743                    "comparisonmethod" => {
744                        i += 1;
745                        if i >= rest.len() {
746                            return Err(sort_invalid_argument(
747                                "sort: expected a value for 'ComparisonMethod'",
748                            ));
749                        }
750                        let value = match tokens.get(i) {
751                            Some(ArgToken::String(value)) => value.as_str(),
752                            _ => {
753                                return Err(sort_error(
754                                    &SORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING,
755                                    "sort: 'ComparisonMethod' requires a string value",
756                                ))
757                            }
758                        };
759                        args.comparison = match value {
760                            "auto" => ComparisonMethod::Auto,
761                            "real" => ComparisonMethod::Real,
762                            "abs" | "magnitude" => ComparisonMethod::Abs,
763                            other => {
764                                return Err(sort_error(
765                                    &SORT_ERROR_COMPARISON_METHOD_UNKNOWN,
766                                    format!("sort: unsupported ComparisonMethod '{other}'"),
767                                )
768                                .into())
769                            }
770                        };
771                        i += 1;
772                        continue;
773                    }
774                    "missingplacement" => {
775                        return Err(sort_error(
776                            &SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED,
777                            SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED.message,
778                        )
779                        .into());
780                    }
781                    _ => {}
782                }
783            }
784            if let Some(keyword) = tensor::value_to_string(&rest[i]) {
785                let lowered = keyword.trim().to_ascii_lowercase();
786                match lowered.as_str() {
787                    "ascend" | "ascending" => {
788                        args.direction = SortDirection::Ascend;
789                        i += 1;
790                        continue;
791                    }
792                    "descend" | "descending" => {
793                        args.direction = SortDirection::Descend;
794                        i += 1;
795                        continue;
796                    }
797                    "comparisonmethod" => {
798                        i += 1;
799                        if i >= rest.len() {
800                            return Err(sort_invalid_argument(
801                                "sort: expected a value for 'ComparisonMethod'",
802                            ));
803                        }
804                        let raw = &rest[i];
805                        let value = match raw {
806                            Value::String(s) => s.clone(),
807                            Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
808                            Value::CharArray(ca) if ca.rows == 1 => {
809                                ca.data.iter().copied().collect()
810                            }
811                            _ => {
812                                return Err(sort_error(
813                                    &SORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING,
814                                    "sort: 'ComparisonMethod' requires a string value",
815                                ))
816                            }
817                        };
818                        let lowered_value = value.trim().to_ascii_lowercase();
819                        args.comparison = match lowered_value.as_str() {
820                            "auto" => ComparisonMethod::Auto,
821                            "real" => ComparisonMethod::Real,
822                            "abs" | "magnitude" => ComparisonMethod::Abs,
823                            other => {
824                                return Err(sort_error(
825                                    &SORT_ERROR_COMPARISON_METHOD_UNKNOWN,
826                                    format!("sort: unsupported ComparisonMethod '{other}'"),
827                                )
828                                .into())
829                            }
830                        };
831                        i += 1;
832                        continue;
833                    }
834                    "missingplacement" => {
835                        return Err(sort_error(
836                            &SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED,
837                            SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED.message,
838                        )
839                        .into());
840                    }
841                    _ => {}
842                }
843            }
844            return Err(sort_invalid_argument(format!(
845                "sort: unrecognised argument {:?}",
846                rest[i]
847            )));
848        }
849        Ok(args)
850    }
851}
852
853fn is_dimension_placeholder(value: &Value) -> bool {
854    match value {
855        Value::Tensor(t) => t.data.is_empty(),
856        Value::LogicalArray(logical) => logical.data.is_empty(),
857        _ => false,
858    }
859}
860
861pub struct SortEvaluation {
862    sorted: Value,
863    indices: Tensor,
864}
865
866impl SortEvaluation {
867    pub fn into_sorted_value(self) -> Value {
868        self.sorted
869    }
870
871    pub fn into_values(self) -> (Value, Value) {
872        let indices = tensor::tensor_into_value(self.indices);
873        (self.sorted, indices)
874    }
875
876    pub fn indices_value(&self) -> Value {
877        tensor::tensor_into_value(self.indices.clone())
878    }
879}
880
/// Unit tests for the `sort` builtin: argument parsing, directions, dimensions,
/// comparison methods, NaN placement, stability, error identifiers and GPU round trips.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor, IntValue, ResolveContext, Tensor, Type, Value};

    fn sort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::sort_builtin(value, rest))
    }

    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortEvaluation> {
        block_on(super::evaluate(value, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_vector_default() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
        let result = sort_builtin(Value::Tensor(tensor), Vec::new()).expect("sort");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
                assert_eq!(t.shape, vec![3, 1]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_type_resolver_tensor() {
        assert_eq!(
            tensor_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_descend_direction() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
        let result =
            sort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("sort");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 2.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_matrix_default_dim1() {
        let tensor = Tensor::new(vec![4.0, 2.0, 1.0, 5.0, 6.0, 3.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 4.0, 1.0, 5.0, 3.0, 6.0]);
                assert_eq!(t.shape, vec![2, 3]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 1.0, 2.0, 2.0, 1.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_matrix_along_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
        let eval =
            evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(2))]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 2.0, 2.0, 3.0, 4.0, 5.0]);
                assert_eq!(t.shape, vec![2, 3]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_dimension_placeholder_then_dim() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
        let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::Tensor(placeholder),
                Value::Int(IntValue::I32(2)),
                Value::from("descend"),
            ],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 1.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_descend_then_dimension() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("descend"), Value::Int(IntValue::I32(1))],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 4.0, 2.0, 5.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_returns_indices() {
        let tensor = Tensor::new(vec![4.0, 1.0, 9.0, 2.0], vec![4, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 4.0, 9.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 4.0, 1.0, 3.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_with_nan_handling() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert!(t.data[3].is_nan());
                assert_eq!(&t.data[0..3], &[1.0, 2.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }

        let eval_desc =
            evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
        let (sorted_desc, _) = eval_desc.into_values();
        match sorted_desc {
            Value::Tensor(t) => {
                assert!(t.data[0].is_nan());
                assert_eq!(&t.data[1..], &[4.0, 2.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_by_absolute_value() {
        let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![-1.0, -2.0, 3.0, -8.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_by_absolute_value_descend() {
        let tensor = Tensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![4, 1]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::from("descend"),
                Value::from("ComparisonMethod"),
                Value::from("abs"),
            ],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, -3.0, 2.0, -1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// The long-form keyword aliases accepted by the parser ("ascending", "descending",
    /// "magnitude") must behave exactly like their short forms.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_keyword_aliases() {
        let ascending = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(ascending), &[Value::from("ascending")])
            .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0]),
            other => panic!("expected tensor, got {other:?}"),
        }

        let magnitude = Tensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![4, 1]).unwrap();
        let eval = evaluate(
            Value::Tensor(magnitude),
            &[
                Value::from("descending"),
                Value::from("ComparisonMethod"),
                Value::from("magnitude"),
            ],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, -3.0, 2.0, -1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_complex_auto_abs() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (0.0, -1.0)], vec![3, 1]).unwrap();
        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(0.0, -1.0), (1.0, 2.0), (-3.0, 0.5)])
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_complex_real_descend() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (1.0, -1.0)], vec![3, 1]).unwrap();
        let eval = evaluate(
            Value::ComplexTensor(tensor),
            &[
                Value::from("descend"),
                Value::from("ComparisonMethod"),
                Value::from("real"),
            ],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(1.0, 2.0), (1.0, -1.0), (-3.0, 0.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_stable_with_duplicates() {
        let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 2.0, 2.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0, 4.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_empty_tensor() {
        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert!(t.data.is_empty());
                assert_eq!(t.shape, tensor.shape);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert!(t.data.is_empty()),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_dim_greater_than_ndims() {
        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor.clone()),
            &[Value::Int(IntValue::I32(3))],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, tensor.data),
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert!(t.data.iter().all(|v| (*v - 1.0).abs() < f64::EPSILON)),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_invalid_argument_errors() {
        let err = sort_builtin(
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            vec![Value::from("missingplacement"), Value::from("first")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            SORT_ERROR_MISSINGPLACEMENT_UNSUPPORTED.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_invalid_comparison_method_errors() {
        let err = sort_builtin(
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
            vec![Value::from("ComparisonMethod"), Value::from("unknown")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            SORT_ERROR_COMPARISON_METHOD_UNKNOWN.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_invalid_comparison_method_value_errors() {
        let err = sort_builtin(
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
            vec![
                Value::from("ComparisonMethod"),
                Value::Int(IntValue::I32(1)),
            ],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            SORT_ERROR_COMPARISON_METHOD_REQUIRES_STRING.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_dimension_zero_errors() {
        let err = sort_builtin(
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            vec![Value::Num(0.0)],
        )
        .unwrap_err();
        assert_eq!(err.identifier(), SORT_ERROR_INVALID_DIMENSION.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_gpu_round_trip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
            let (sorted, indices) = eval.into_values();
            match sorted {
                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
            match indices {
                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn sort_wgpu_matches_cpu() {
        // Normalises a host-side result (a tensor, or a scalar for 1x1 outputs) into a
        // `Tensor` so CPU and GPU results can be compared element-wise.
        fn into_host_tensor(value: Value, label: &str) -> Tensor {
            match value {
                Value::Tensor(t) => t,
                Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
                other => panic!("unexpected {label} value {other:?}"),
            }
        }

        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![4.0, 1.0, 3.0, 2.0], vec![4, 1]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu sort");
        let (cpu_sorted, cpu_indices) = cpu_eval.into_values();

        let gpu_view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&gpu_view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu sort");
        let (gpu_sorted, gpu_indices) = gpu_eval.into_values();

        let cpu_sorted_tensor = into_host_tensor(cpu_sorted, "CPU sorted");
        let cpu_indices_tensor = into_host_tensor(cpu_indices, "CPU indices");
        let gpu_sorted_tensor = into_host_tensor(gpu_sorted, "GPU sorted");
        let gpu_indices_tensor = into_host_tensor(gpu_indices, "GPU indices");

        assert_eq!(gpu_sorted_tensor.data, cpu_sorted_tensor.data);
        assert_eq!(gpu_indices_tensor.data, cpu_indices_tensor.data);
    }
}