Skip to main content

runmat_runtime/builtins/array/sorting_sets/
sort.rs

1//! MATLAB-compatible `sort` builtin with multi-output and GPU-aware semantics.
2
3use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6    GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7};
8use runmat_builtins::{ComplexTensor, Tensor, Value};
9use runmat_macros::runtime_builtin;
10
11use super::type_resolvers::tensor_output_type;
12use crate::build_runtime_error;
13use crate::builtins::common::arg_tokens::{tokens_from_values, ArgToken};
14use crate::builtins::common::gpu_helpers;
15use crate::builtins::common::spec::{
16    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::tensor;
20
21#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::sort")]
22pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
23    name: "sort",
24    op_kind: GpuOpKind::Custom("sort"),
25    supported_precisions: &[ScalarType::F32, ScalarType::F64],
26    broadcast: BroadcastSemantics::None,
27    provider_hooks: &[ProviderHook::Custom("sort_dim")],
28    constant_strategy: ConstantStrategy::InlineLiteral,
29    residency: ResidencyPolicy::GatherImmediately,
30    nan_mode: ReductionNaN::Include,
31    two_pass_threshold: None,
32    workgroup_size: None,
33    accepts_nan_mode: true,
34    notes: "Providers may add a dedicated sort kernel in the future; today tensors are gathered to host memory before sorting.",
35};
36
37#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::sorting_sets::sort")]
38pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
39    name: "sort",
40    shape: ShapeRequirements::Any,
41    constant_strategy: ConstantStrategy::InlineLiteral,
42    elementwise: None,
43    reduction: None,
44    emits_nan: true,
45    notes: "Sorting breaks fusion chains and acts as a residency sink; upstream tensors are gathered to host memory.",
46};
47
48fn sort_error(message: impl Into<String>) -> crate::RuntimeError {
49    build_runtime_error(message).with_builtin("sort").build()
50}
51
52#[runtime_builtin(
53    name = "sort",
54    category = "array/sorting_sets",
55    summary = "Sort scalars, vectors, matrices, or N-D tensors along a dimension, with optional index outputs.",
56    keywords = "sort,ascending,descending,indices,comparisonmethod,gpu",
57    accel = "sink",
58    sink = true,
59    type_resolver(tensor_output_type),
60    builtin_path = "crate::builtins::array::sorting_sets::sort"
61)]
62async fn sort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
63    let eval = evaluate(value, &rest).await?;
64    if let Some(out_count) = crate::output_count::current_output_count() {
65        if out_count == 0 {
66            return Ok(Value::OutputList(Vec::new()));
67        }
68        let (sorted, indices) = eval.into_values();
69        let mut outputs = vec![sorted];
70        if out_count >= 2 {
71            outputs.push(indices);
72        }
73        return Ok(crate::output_count::output_list_with_padding(
74            out_count, outputs,
75        ));
76    }
77    Ok(eval.into_sorted_value())
78}
79
80/// Evaluate the `sort` builtin once and expose both outputs.
81pub async fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortEvaluation> {
82    let args = SortArgs::parse(rest)?;
83    match value {
84        Value::GpuTensor(handle) => sort_gpu(handle, &args).await,
85        other => sort_host(other, &args),
86    }
87}
88
89async fn sort_gpu(
90    handle: GpuTensorHandle,
91    args: &SortArgs,
92) -> crate::BuiltinResult<SortEvaluation> {
93    let shape = handle.shape.clone();
94    let dim = args.dimension.unwrap_or_else(|| default_dimension(&shape));
95    if dim == 0 {
96        return Err(sort_error("sort: dimension must be >= 1"));
97    }
98    let dim_len = dimension_length(&shape, dim);
99    if dim_len > 1 {
100        if let Some(provider) = runmat_accelerate_api::provider() {
101            let order = args.direction.to_provider();
102            let comparison = args.comparison.to_provider();
103            let zero_based = dim - 1;
104            if let Ok(result) = provider
105                .sort_dim(&handle, zero_based, order, comparison)
106                .await
107            {
108                let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
109                    .map_err(|e| sort_error(format!("sort: {e}")))?;
110                let sorted_value = tensor::tensor_into_value(sorted_tensor);
111                let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
112                    .map_err(|e| sort_error(format!("sort: {e}")))?;
113                return Ok(SortEvaluation {
114                    sorted: sorted_value,
115                    indices: indices_tensor,
116                });
117            }
118        }
119    }
120    let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
121    sort_real_tensor(tensor, args)
122}
123
124fn sort_host(value: Value, args: &SortArgs) -> crate::BuiltinResult<SortEvaluation> {
125    match value {
126        Value::ComplexTensor(ct) => sort_complex_tensor(ct, args),
127        Value::Complex(re, im) => {
128            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
129                .map_err(|e| sort_error(format!("sort: {e}")))?;
130            sort_complex_tensor(tensor, args)
131        }
132        other => {
133            let tensor = tensor::value_into_tensor_for("sort", other).map_err(|e| sort_error(e))?;
134            sort_real_tensor(tensor, args)
135        }
136    }
137}
138
139fn sort_real_tensor(tensor: Tensor, args: &SortArgs) -> crate::BuiltinResult<SortEvaluation> {
140    let dim = args
141        .dimension
142        .unwrap_or_else(|| default_dimension(&tensor.shape));
143    if dim == 0 {
144        return Err(sort_error("sort: dimension must be >= 1"));
145    }
146
147    let dim_len = dimension_length(&tensor.shape, dim);
148    if tensor.data.is_empty() || dim_len <= 1 {
149        let indices = vec![1.0; tensor.data.len()];
150        let index_tensor = Tensor::new(indices, tensor.shape.clone())
151            .map_err(|e| sort_error(format!("sort: {e}")))?;
152        let sorted_value = tensor::tensor_into_value(tensor);
153        return Ok(SortEvaluation {
154            sorted: sorted_value,
155            indices: index_tensor,
156        });
157    }
158
159    let stride_before = stride_before(&tensor.shape, dim);
160    let stride_after = stride_after(&tensor.shape, dim);
161    let mut sorted = tensor.data.clone();
162    let mut indices = vec![0.0f64; tensor.data.len()];
163    let mut buffer: Vec<(usize, f64)> = Vec::with_capacity(dim_len);
164
165    for after in 0..stride_after {
166        for before in 0..stride_before {
167            buffer.clear();
168            for k in 0..dim_len {
169                let idx = before + k * stride_before + after * stride_before * dim_len;
170                let value = tensor.data[idx];
171                buffer.push((k, value));
172            }
173            buffer.sort_by(|a, b| compare_real_values(a.1, b.1, args));
174            for (pos, (original_index, value)) in buffer.iter().enumerate() {
175                let target = before + pos * stride_before + after * stride_before * dim_len;
176                sorted[target] = *value;
177                indices[target] = (*original_index + 1) as f64;
178            }
179        }
180    }
181
182    let sorted_tensor =
183        Tensor::new(sorted, tensor.shape.clone()).map_err(|e| sort_error(format!("sort: {e}")))?;
184    let index_tensor =
185        Tensor::new(indices, tensor.shape.clone()).map_err(|e| sort_error(format!("sort: {e}")))?;
186
187    Ok(SortEvaluation {
188        sorted: tensor::tensor_into_value(sorted_tensor),
189        indices: index_tensor,
190    })
191}
192
193fn sort_complex_tensor(
194    tensor: ComplexTensor,
195    args: &SortArgs,
196) -> crate::BuiltinResult<SortEvaluation> {
197    let dim = args
198        .dimension
199        .unwrap_or_else(|| default_dimension(&tensor.shape));
200    if dim == 0 {
201        return Err(sort_error("sort: dimension must be >= 1"));
202    }
203
204    let dim_len = dimension_length(&tensor.shape, dim);
205    if tensor.data.is_empty() || dim_len <= 1 {
206        let indices = vec![1.0; tensor.data.len()];
207        let index_tensor = Tensor::new(indices, tensor.shape.clone())
208            .map_err(|e| sort_error(format!("sort: {e}")))?;
209        return Ok(SortEvaluation {
210            sorted: complex_tensor_into_value(tensor),
211            indices: index_tensor,
212        });
213    }
214
215    let stride_before = stride_before(&tensor.shape, dim);
216    let stride_after = stride_after(&tensor.shape, dim);
217    let mut sorted = tensor.data.clone();
218    let mut indices = vec![0.0f64; tensor.data.len()];
219    let mut buffer: Vec<(usize, (f64, f64))> = Vec::with_capacity(dim_len);
220
221    for after in 0..stride_after {
222        for before in 0..stride_before {
223            buffer.clear();
224            for k in 0..dim_len {
225                let idx = before + k * stride_before + after * stride_before * dim_len;
226                let value = tensor.data[idx];
227                buffer.push((k, value));
228            }
229            buffer.sort_by(|a, b| compare_complex_values(a.1, b.1, args));
230            for (pos, (original_index, value)) in buffer.iter().enumerate() {
231                let target = before + pos * stride_before + after * stride_before * dim_len;
232                sorted[target] = *value;
233                indices[target] = (*original_index + 1) as f64;
234            }
235        }
236    }
237
238    let sorted_tensor = ComplexTensor::new(sorted, tensor.shape.clone())
239        .map_err(|e| sort_error(format!("sort: {e}")))?;
240    let index_tensor =
241        Tensor::new(indices, tensor.shape.clone()).map_err(|e| sort_error(format!("sort: {e}")))?;
242
243    Ok(SortEvaluation {
244        sorted: complex_tensor_into_value(sorted_tensor),
245        indices: index_tensor,
246    })
247}
248
249fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
250    if tensor.data.len() == 1 {
251        let (re, im) = tensor.data[0];
252        Value::Complex(re, im)
253    } else {
254        Value::ComplexTensor(tensor)
255    }
256}
257
258fn compare_real_values(a: f64, b: f64, args: &SortArgs) -> Ordering {
259    match (a.is_nan(), b.is_nan()) {
260        (true, true) => Ordering::Equal,
261        (true, false) => match args.direction {
262            SortDirection::Ascend => Ordering::Greater,
263            SortDirection::Descend => Ordering::Less,
264        },
265        (false, true) => match args.direction {
266            SortDirection::Ascend => Ordering::Less,
267            SortDirection::Descend => Ordering::Greater,
268        },
269        (false, false) => compare_real_finite(a, b, args),
270    }
271}
272
273fn compare_real_finite(a: f64, b: f64, args: &SortArgs) -> Ordering {
274    let primary = match args.comparison {
275        ComparisonMethod::Abs => {
276            let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
277            if abs_cmp != Ordering::Equal {
278                return match args.direction {
279                    SortDirection::Ascend => abs_cmp,
280                    SortDirection::Descend => abs_cmp.reverse(),
281                };
282            }
283            Ordering::Equal
284        }
285        ComparisonMethod::Auto | ComparisonMethod::Real => Ordering::Equal,
286    };
287    if primary != Ordering::Equal {
288        return primary;
289    }
290    match args.direction {
291        SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
292        SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
293    }
294}
295
296fn compare_complex_values(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
297    match (complex_is_nan(a), complex_is_nan(b)) {
298        (true, true) => Ordering::Equal,
299        (true, false) => match args.direction {
300            SortDirection::Ascend => Ordering::Greater,
301            SortDirection::Descend => Ordering::Less,
302        },
303        (false, true) => match args.direction {
304            SortDirection::Ascend => Ordering::Less,
305            SortDirection::Descend => Ordering::Greater,
306        },
307        (false, false) => compare_complex_finite(a, b, args),
308    }
309}
310
311fn compare_complex_finite(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
312    match args.comparison {
313        ComparisonMethod::Real => compare_complex_real_imag(a, b, args.direction),
314        ComparisonMethod::Abs | ComparisonMethod::Auto => {
315            let abs_cmp = complex_abs(a)
316                .partial_cmp(&complex_abs(b))
317                .unwrap_or(Ordering::Equal);
318            if abs_cmp != Ordering::Equal {
319                return match args.direction {
320                    SortDirection::Ascend => abs_cmp,
321                    SortDirection::Descend => abs_cmp.reverse(),
322                };
323            }
324            compare_complex_real_imag(a, b, args.direction)
325        }
326    }
327}
328
329fn compare_complex_real_imag(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
330    let real_cmp = match direction {
331        SortDirection::Ascend => a.0.partial_cmp(&b.0),
332        SortDirection::Descend => b.0.partial_cmp(&a.0),
333    }
334    .unwrap_or(Ordering::Equal);
335    if real_cmp != Ordering::Equal {
336        return real_cmp;
337    }
338    match direction {
339        SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
340        SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
341    }
342}
343
/// A complex value is NaN when either its real or imaginary part is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
347
/// Magnitude `|re + i*im|`, computed with `hypot` to avoid intermediate overflow.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    f64::hypot(re, im)
}
351
/// Number of elements spanned by the dimensions before the 1-based `dim`,
/// i.e. the column-major step between neighbours along `dim`.
///
/// Dimensions missing from `shape` count as 1. The product saturates at
/// `usize::MAX` instead of overflowing.
fn stride_before(shape: &[usize], dim: usize) -> usize {
    let leading = dim.saturating_sub(1);
    shape
        .iter()
        .take(leading)
        .fold(1usize, |acc, &extent| acc.saturating_mul(extent))
}
362
/// Product of the extents after the 1-based `dim`, i.e. how many times a full
/// slab along `dim` repeats in column-major order.
///
/// Returns 1 when `dim` is at or beyond the last stored dimension. The product
/// saturates at `usize::MAX` instead of overflowing.
fn stride_after(shape: &[usize], dim: usize) -> usize {
    match shape.get(dim..) {
        Some(trailing) => trailing
            .iter()
            .fold(1usize, |acc, &extent| acc.saturating_mul(extent)),
        None => 1,
    }
}
373
/// Extent of the 1-based dimension `dim` in `shape`.
///
/// Dimensions beyond the stored shape are trailing singletons and report 1.
/// `dim == 0` is not a valid dimension; it also reports 1 rather than
/// underflowing `dim - 1` (which panics in debug builds). Callers are expected
/// to reject 0 before relying on the result.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    dim.checked_sub(1)
        .and_then(|index| shape.get(index).copied())
        .unwrap_or(1)
}
377
/// Default sort dimension (1-based): the first dimension whose size is not 1.
///
/// This follows MATLAB's definition, so empty arrays resolve like MATLAB:
/// `0x3` sorts along dimension 1 and `1x0` along dimension 2. Returns 1 when
/// every dimension is a singleton (including an empty `shape`). For non-empty
/// arrays this is the same as "first dimension larger than 1".
fn default_dimension(shape: &[usize]) -> usize {
    shape
        .iter()
        .position(|&extent| extent != 1)
        .map_or(1, |index| index + 1)
}
385
386#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
387enum SortDirection {
388    #[default]
389    Ascend,
390    Descend,
391}
392
393impl SortDirection {
394    fn to_provider(self) -> ProviderSortOrder {
395        match self {
396            SortDirection::Ascend => ProviderSortOrder::Ascend,
397            SortDirection::Descend => ProviderSortOrder::Descend,
398        }
399    }
400}
401
402#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
403enum ComparisonMethod {
404    #[default]
405    Auto,
406    Real,
407    Abs,
408}
409
410impl ComparisonMethod {
411    fn to_provider(self) -> ProviderSortComparison {
412        match self {
413            ComparisonMethod::Auto => ProviderSortComparison::Auto,
414            ComparisonMethod::Real => ProviderSortComparison::Real,
415            ComparisonMethod::Abs => ProviderSortComparison::Abs,
416        }
417    }
418}
419
420#[derive(Debug, Clone, Default)]
421struct SortArgs {
422    dimension: Option<usize>,
423    direction: SortDirection,
424    comparison: ComparisonMethod,
425}
426
427impl SortArgs {
428    fn parse(rest: &[Value]) -> crate::BuiltinResult<Self> {
429        let mut args = SortArgs::default();
430        let tokens = tokens_from_values(rest);
431        let mut i = 0usize;
432        while i < rest.len() {
433            if args.dimension.is_none() {
434                if is_dimension_placeholder(&rest[i]) {
435                    i += 1;
436                    continue;
437                }
438                match tensor::parse_dimension(&rest[i], "sort") {
439                    Ok(dim) => {
440                        args.dimension = Some(dim);
441                        i += 1;
442                        continue;
443                    }
444                    Err(err) => {
445                        if matches!(rest[i], Value::Int(_) | Value::Num(_)) {
446                            return Err(sort_error(err));
447                        }
448                    }
449                }
450            }
451            if let Some(ArgToken::String(text)) = tokens.get(i) {
452                match text.as_str() {
453                    "ascend" | "ascending" => {
454                        args.direction = SortDirection::Ascend;
455                        i += 1;
456                        continue;
457                    }
458                    "descend" | "descending" => {
459                        args.direction = SortDirection::Descend;
460                        i += 1;
461                        continue;
462                    }
463                    "comparisonmethod" => {
464                        i += 1;
465                        if i >= rest.len() {
466                            return Err(sort_error(
467                                "sort: expected a value for 'ComparisonMethod'",
468                            ));
469                        }
470                        let value = match tokens.get(i) {
471                            Some(ArgToken::String(value)) => value.as_str(),
472                            _ => {
473                                return Err(sort_error(
474                                    "sort: 'ComparisonMethod' requires a string value",
475                                ))
476                            }
477                        };
478                        args.comparison = match value {
479                            "auto" => ComparisonMethod::Auto,
480                            "real" => ComparisonMethod::Real,
481                            "abs" | "magnitude" => ComparisonMethod::Abs,
482                            other => {
483                                return Err(sort_error(format!(
484                                    "sort: unsupported ComparisonMethod '{other}'"
485                                ))
486                                .into())
487                            }
488                        };
489                        i += 1;
490                        continue;
491                    }
492                    "missingplacement" => {
493                        return Err(sort_error(
494                            "sort: the 'MissingPlacement' option is not supported yet",
495                        )
496                        .into());
497                    }
498                    _ => {}
499                }
500            }
501            if let Some(keyword) = tensor::value_to_string(&rest[i]) {
502                let lowered = keyword.trim().to_ascii_lowercase();
503                match lowered.as_str() {
504                    "ascend" | "ascending" => {
505                        args.direction = SortDirection::Ascend;
506                        i += 1;
507                        continue;
508                    }
509                    "descend" | "descending" => {
510                        args.direction = SortDirection::Descend;
511                        i += 1;
512                        continue;
513                    }
514                    "comparisonmethod" => {
515                        i += 1;
516                        if i >= rest.len() {
517                            return Err(sort_error(
518                                "sort: expected a value for 'ComparisonMethod'",
519                            ));
520                        }
521                        let raw = &rest[i];
522                        let value = match raw {
523                            Value::String(s) => s.clone(),
524                            Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
525                            Value::CharArray(ca) if ca.rows == 1 => {
526                                ca.data.iter().copied().collect()
527                            }
528                            _ => {
529                                return Err(sort_error(
530                                    "sort: 'ComparisonMethod' requires a string value",
531                                ))
532                            }
533                        };
534                        let lowered_value = value.trim().to_ascii_lowercase();
535                        args.comparison = match lowered_value.as_str() {
536                            "auto" => ComparisonMethod::Auto,
537                            "real" => ComparisonMethod::Real,
538                            "abs" | "magnitude" => ComparisonMethod::Abs,
539                            other => {
540                                return Err(sort_error(format!(
541                                    "sort: unsupported ComparisonMethod '{other}'"
542                                ))
543                                .into())
544                            }
545                        };
546                        i += 1;
547                        continue;
548                    }
549                    "missingplacement" => {
550                        return Err(sort_error(
551                            "sort: the 'MissingPlacement' option is not supported yet",
552                        )
553                        .into());
554                    }
555                    _ => {}
556                }
557            }
558            return Err(sort_error(format!(
559                "sort: unrecognised argument {:?}",
560                rest[i]
561            )));
562        }
563        Ok(args)
564    }
565}
566
567fn is_dimension_placeholder(value: &Value) -> bool {
568    match value {
569        Value::Tensor(t) => t.data.is_empty(),
570        Value::LogicalArray(logical) => logical.data.is_empty(),
571        _ => false,
572    }
573}
574
575pub struct SortEvaluation {
576    sorted: Value,
577    indices: Tensor,
578}
579
580impl SortEvaluation {
581    pub fn into_sorted_value(self) -> Value {
582        self.sorted
583    }
584
585    pub fn into_values(self) -> (Value, Value) {
586        let indices = tensor::tensor_into_value(self.indices);
587        (self.sorted, indices)
588    }
589
590    pub fn indices_value(&self) -> Value {
591        tensor::tensor_into_value(self.indices.clone())
592    }
593}
594
595#[cfg(test)]
596pub(crate) mod tests {
597    use super::*;
598    use crate::builtins::common::test_support;
599    use futures::executor::block_on;
600    use runmat_builtins::{ComplexTensor, IntValue, ResolveContext, Tensor, Type, Value};
601
602    fn sort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
603        block_on(super::sort_builtin(value, rest))
604    }
605
606    fn error_message(err: crate::RuntimeError) -> String {
607        err.message().to_string()
608    }
609
610    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortEvaluation> {
611        block_on(super::evaluate(value, rest))
612    }
613
614    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
615    #[test]
616    fn sort_vector_default() {
617        let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
618        let result = sort_builtin(Value::Tensor(tensor), Vec::new()).expect("sort");
619        match result {
620            Value::Tensor(t) => {
621                assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
622                assert_eq!(t.shape, vec![3, 1]);
623            }
624            other => panic!("expected tensor result, got {other:?}"),
625        }
626    }
627
628    #[test]
629    fn sort_type_resolver_tensor() {
630        assert_eq!(
631            tensor_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
632            Type::tensor()
633        );
634    }
635
636    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
637    #[test]
638    fn sort_descend_direction() {
639        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
640        let result =
641            sort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("sort");
642        match result {
643            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 2.0, 1.0]),
644            other => panic!("expected tensor, got {other:?}"),
645        }
646    }
647
648    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
649    #[test]
650    fn sort_matrix_default_dim1() {
651        let tensor = Tensor::new(vec![4.0, 2.0, 1.0, 5.0, 6.0, 3.0], vec![2, 3]).unwrap();
652        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
653        let (sorted, indices) = eval.into_values();
654        match sorted {
655            Value::Tensor(t) => {
656                assert_eq!(t.data, vec![2.0, 4.0, 1.0, 5.0, 3.0, 6.0]);
657                assert_eq!(t.shape, vec![2, 3]);
658            }
659            other => panic!("expected tensor result, got {other:?}"),
660        }
661        match indices {
662            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 1.0, 2.0, 2.0, 1.0]),
663            other => panic!("expected tensor indices, got {other:?}"),
664        }
665    }
666
667    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
668    #[test]
669    fn sort_matrix_along_dimension_two() {
670        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
671        let eval =
672            evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(2))]).expect("evaluate");
673        let (sorted, indices) = eval.into_values();
674        match sorted {
675            Value::Tensor(t) => {
676                assert_eq!(t.data, vec![1.0, 2.0, 2.0, 3.0, 4.0, 5.0]);
677                assert_eq!(t.shape, vec![2, 3]);
678            }
679            other => panic!("expected tensor result, got {other:?}"),
680        }
681        match indices {
682            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]),
683            other => panic!("expected tensor indices, got {other:?}"),
684        }
685    }
686
687    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
688    #[test]
689    fn sort_dimension_placeholder_then_dim() {
690        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
691        let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
692        let eval = evaluate(
693            Value::Tensor(tensor),
694            &[
695                Value::Tensor(placeholder),
696                Value::Int(IntValue::I32(2)),
697                Value::from("descend"),
698            ],
699        )
700        .expect("evaluate");
701        let (sorted, _) = eval.into_values();
702        match sorted {
703            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 1.0, 2.0]),
704            other => panic!("expected tensor result, got {other:?}"),
705        }
706    }
707
708    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
709    #[test]
710    fn sort_descend_then_dimension() {
711        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
712        let eval = evaluate(
713            Value::Tensor(tensor),
714            &[Value::from("descend"), Value::Int(IntValue::I32(1))],
715        )
716        .expect("evaluate");
717        let (sorted, _) = eval.into_values();
718        match sorted {
719            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 4.0, 2.0, 5.0, 2.0]),
720            other => panic!("expected tensor result, got {other:?}"),
721        }
722    }
723
724    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
725    #[test]
726    fn sort_returns_indices() {
727        let tensor = Tensor::new(vec![4.0, 1.0, 9.0, 2.0], vec![4, 1]).unwrap();
728        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
729        let (sorted, indices) = eval.into_values();
730        match sorted {
731            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 4.0, 9.0]),
732            other => panic!("expected tensor, got {other:?}"),
733        }
734        match indices {
735            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 4.0, 1.0, 3.0]),
736            other => panic!("expected tensor, got {other:?}"),
737        }
738    }
739
740    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
741    #[test]
742    fn sort_with_nan_handling() {
743        let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
744        let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
745        let (sorted, _) = eval.into_values();
746        match sorted {
747            Value::Tensor(t) => {
748                assert!(t.data[3].is_nan());
749                assert_eq!(&t.data[0..3], &[1.0, 2.0, 4.0]);
750            }
751            other => panic!("expected tensor, got {other:?}"),
752        }
753
754        let eval_desc =
755            evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
756        let (sorted_desc, _) = eval_desc.into_values();
757        match sorted_desc {
758            Value::Tensor(t) => {
759                assert!(t.data[0].is_nan());
760                assert_eq!(&t.data[1..], &[4.0, 2.0, 1.0]);
761            }
762            other => panic!("expected tensor, got {other:?}"),
763        }
764    }
765
766    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
767    #[test]
768    fn sort_by_absolute_value() {
769        let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
770        let eval = evaluate(
771            Value::Tensor(tensor),
772            &[Value::from("ComparisonMethod"), Value::from("abs")],
773        )
774        .expect("evaluate");
775        let (sorted, _) = eval.into_values();
776        match sorted {
777            Value::Tensor(t) => assert_eq!(t.data, vec![-1.0, -2.0, 3.0, -8.0]),
778            other => panic!("expected tensor, got {other:?}"),
779        }
780    }
781
782    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
783    #[test]
784    fn sort_by_absolute_value_descend() {
785        let tensor = Tensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![4, 1]).unwrap();
786        let eval = evaluate(
787            Value::Tensor(tensor),
788            &[
789                Value::from("descend"),
790                Value::from("ComparisonMethod"),
791                Value::from("abs"),
792            ],
793        )
794        .expect("evaluate");
795        let (sorted, _) = eval.into_values();
796        match sorted {
797            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, -3.0, 2.0, -1.0]),
798            other => panic!("expected tensor, got {other:?}"),
799        }
800    }
801
802    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
803    #[test]
804    fn sort_complex_auto_abs() {
805        let tensor =
806            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (0.0, -1.0)], vec![3, 1]).unwrap();
807        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("evaluate");
808        let (sorted, indices) = eval.into_values();
809        match sorted {
810            Value::ComplexTensor(t) => {
811                assert_eq!(t.data, vec![(0.0, -1.0), (1.0, 2.0), (-3.0, 0.5)])
812            }
813            other => panic!("expected complex tensor, got {other:?}"),
814        }
815        match indices {
816            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0]),
817            other => panic!("expected tensor indices, got {other:?}"),
818        }
819    }
820
821    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
822    #[test]
823    fn sort_complex_real_descend() {
824        let tensor =
825            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (1.0, -1.0)], vec![3, 1]).unwrap();
826        let eval = evaluate(
827            Value::ComplexTensor(tensor),
828            &[
829                Value::from("descend"),
830                Value::from("ComparisonMethod"),
831                Value::from("real"),
832            ],
833        )
834        .expect("evaluate");
835        let (sorted, _) = eval.into_values();
836        match sorted {
837            Value::ComplexTensor(t) => {
838                assert_eq!(t.data, vec![(1.0, 2.0), (1.0, -1.0), (-3.0, 0.0)]);
839            }
840            other => panic!("expected complex tensor, got {other:?}"),
841        }
842    }
843
844    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
845    #[test]
846    fn sort_stable_with_duplicates() {
847        let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
848        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
849        let (sorted, indices) = eval.into_values();
850        match sorted {
851            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 2.0, 2.0]),
852            other => panic!("expected tensor, got {other:?}"),
853        }
854        match indices {
855            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0, 4.0]),
856            other => panic!("expected tensor indices, got {other:?}"),
857        }
858    }
859
860    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
861    #[test]
862    fn sort_empty_tensor() {
863        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
864        let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
865        let (sorted, indices) = eval.into_values();
866        match sorted {
867            Value::Tensor(t) => {
868                assert!(t.data.is_empty());
869                assert_eq!(t.shape, tensor.shape);
870            }
871            other => panic!("expected tensor, got {other:?}"),
872        }
873        match indices {
874            Value::Tensor(t) => assert!(t.data.is_empty()),
875            other => panic!("expected tensor, got {other:?}"),
876        }
877    }
878
879    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
880    #[test]
881    fn sort_dim_greater_than_ndims() {
882        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0], vec![2, 2]).unwrap();
883        let eval = evaluate(
884            Value::Tensor(tensor.clone()),
885            &[Value::Int(IntValue::I32(3))],
886        )
887        .expect("evaluate");
888        let (sorted, indices) = eval.into_values();
889        match sorted {
890            Value::Tensor(t) => assert_eq!(t.data, tensor.data),
891            other => panic!("expected tensor, got {other:?}"),
892        }
893        match indices {
894            Value::Tensor(t) => assert!(t.data.iter().all(|v| (*v - 1.0).abs() < f64::EPSILON)),
895            other => panic!("expected tensor, got {other:?}"),
896        }
897    }
898
899    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
900    #[test]
901    fn sort_invalid_argument_errors() {
902        let err = error_message(
903            sort_builtin(
904                Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
905                vec![Value::from("missingplacement"), Value::from("first")],
906            )
907            .unwrap_err(),
908        );
909        assert!(err.contains("MissingPlacement"));
910    }
911
912    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
913    #[test]
914    fn sort_invalid_comparison_method_errors() {
915        let err = error_message(
916            sort_builtin(
917                Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
918                vec![Value::from("ComparisonMethod"), Value::from("unknown")],
919            )
920            .unwrap_err(),
921        );
922        assert!(err.contains("ComparisonMethod"), "unexpected error: {err}");
923    }
924
925    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
926    #[test]
927    fn sort_invalid_comparison_method_value_errors() {
928        let err = error_message(
929            sort_builtin(
930                Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
931                vec![
932                    Value::from("ComparisonMethod"),
933                    Value::Int(IntValue::I32(1)),
934                ],
935            )
936            .unwrap_err(),
937        );
938        assert!(
939            err.contains("requires a string value"),
940            "unexpected error: {err}"
941        );
942    }
943
944    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
945    #[test]
946    fn sort_dimension_zero_errors() {
947        let err = error_message(
948            sort_builtin(
949                Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
950                vec![Value::Num(0.0)],
951            )
952            .unwrap_err(),
953        );
954        assert!(
955            err.contains("dimension must be >= 1"),
956            "unexpected error: {err}"
957        );
958    }
959
960    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
961    #[test]
962    fn sort_gpu_round_trip() {
963        test_support::with_test_provider(|provider| {
964            let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
965            let view = runmat_accelerate_api::HostTensorView {
966                data: &tensor.data,
967                shape: &tensor.shape,
968            };
969            let handle = provider.upload(&view).expect("upload");
970            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
971            let (sorted, indices) = eval.into_values();
972            match sorted {
973                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0]),
974                other => panic!("expected tensor, got {other:?}"),
975            }
976            match indices {
977                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
978                other => panic!("expected tensor, got {other:?}"),
979            }
980        });
981    }
982
983    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
984    #[test]
985    #[cfg(feature = "wgpu")]
986    fn sort_wgpu_matches_cpu() {
987        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
988            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
989        );
990        let tensor = Tensor::new(vec![4.0, 1.0, 3.0, 2.0], vec![4, 1]).unwrap();
991        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu sort");
992        let (cpu_sorted, cpu_indices) = cpu_eval.into_values();
993
994        let gpu_view = runmat_accelerate_api::HostTensorView {
995            data: &tensor.data,
996            shape: &tensor.shape,
997        };
998        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
999        let handle = provider.upload(&gpu_view).expect("upload");
1000        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu sort");
1001        let (gpu_sorted, gpu_indices) = gpu_eval.into_values();
1002
1003        let cpu_sorted_tensor = match cpu_sorted {
1004            Value::Tensor(t) => t,
1005            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1006            other => panic!("unexpected CPU sorted value {other:?}"),
1007        };
1008        let cpu_indices_tensor = match cpu_indices {
1009            Value::Tensor(t) => t,
1010            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1011            other => panic!("unexpected CPU indices value {other:?}"),
1012        };
1013        let gpu_sorted_tensor = match gpu_sorted {
1014            Value::Tensor(t) => t,
1015            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1016            other => panic!("unexpected GPU sorted value {other:?}"),
1017        };
1018        let gpu_indices_tensor = match gpu_indices {
1019            Value::Tensor(t) => t,
1020            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
1021            other => panic!("unexpected GPU indices value {other:?}"),
1022        };
1023
1024        assert_eq!(gpu_sorted_tensor.data, cpu_sorted_tensor.data);
1025        assert_eq!(gpu_indices_tensor.data, cpu_indices_tensor.data);
1026    }
1027}