1use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6 GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7};
8use runmat_builtins::{ComplexTensor, Tensor, Value};
9use runmat_macros::runtime_builtin;
10
11use super::type_resolvers::tensor_output_type;
12use crate::build_runtime_error;
13use crate::builtins::common::arg_tokens::{tokens_from_values, ArgToken};
14use crate::builtins::common::gpu_helpers;
15use crate::builtins::common::spec::{
16 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::tensor;
20
// Acceleration metadata for `sort`. At runtime `sort_gpu` first tries the provider's
// `sort_dim` hook (listed in `provider_hooks`) and falls back to gathering the tensor
// to the host and sorting there.
// NOTE(review): the `notes` text says tensors are always gathered before sorting, which
// predates the `sort_dim` path in `sort_gpu` — consider refreshing it.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::sort")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "sort",
    op_kind: GpuOpKind::Custom("sort"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("sort_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may add a dedicated sort kernel in the future; today tensors are gathered to host memory before sorting.",
};
36
// Fusion metadata for `sort`. It declares neither an elementwise nor a reduction
// template (both `None`), and per its notes sorting terminates fusion chains.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::sorting_sets::sort")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "sort",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "Sorting breaks fusion chains and acts as a residency sink; upstream tensors are gathered to host memory.",
};
47
48fn sort_error(message: impl Into<String>) -> crate::RuntimeError {
49 build_runtime_error(message).with_builtin("sort").build()
50}
51
52#[runtime_builtin(
53 name = "sort",
54 category = "array/sorting_sets",
55 summary = "Sort scalars, vectors, matrices, or N-D tensors along a dimension, with optional index outputs.",
56 keywords = "sort,ascending,descending,indices,comparisonmethod,gpu",
57 accel = "sink",
58 sink = true,
59 type_resolver(tensor_output_type),
60 builtin_path = "crate::builtins::array::sorting_sets::sort"
61)]
62async fn sort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
63 let eval = evaluate(value, &rest).await?;
64 if let Some(out_count) = crate::output_count::current_output_count() {
65 if out_count == 0 {
66 return Ok(Value::OutputList(Vec::new()));
67 }
68 let (sorted, indices) = eval.into_values();
69 let mut outputs = vec![sorted];
70 if out_count >= 2 {
71 outputs.push(indices);
72 }
73 return Ok(crate::output_count::output_list_with_padding(
74 out_count, outputs,
75 ));
76 }
77 Ok(eval.into_sorted_value())
78}
79
80pub async fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortEvaluation> {
82 let args = SortArgs::parse(rest)?;
83 match value {
84 Value::GpuTensor(handle) => sort_gpu(handle, &args).await,
85 other => sort_host(other, &args),
86 }
87}
88
89async fn sort_gpu(
90 handle: GpuTensorHandle,
91 args: &SortArgs,
92) -> crate::BuiltinResult<SortEvaluation> {
93 let shape = handle.shape.clone();
94 let dim = args.dimension.unwrap_or_else(|| default_dimension(&shape));
95 if dim == 0 {
96 return Err(sort_error("sort: dimension must be >= 1"));
97 }
98 let dim_len = dimension_length(&shape, dim);
99 if dim_len > 1 {
100 if let Some(provider) = runmat_accelerate_api::provider() {
101 let order = args.direction.to_provider();
102 let comparison = args.comparison.to_provider();
103 let zero_based = dim - 1;
104 if let Ok(result) = provider
105 .sort_dim(&handle, zero_based, order, comparison)
106 .await
107 {
108 let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
109 .map_err(|e| sort_error(format!("sort: {e}")))?;
110 let sorted_value = tensor::tensor_into_value(sorted_tensor);
111 let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
112 .map_err(|e| sort_error(format!("sort: {e}")))?;
113 return Ok(SortEvaluation {
114 sorted: sorted_value,
115 indices: indices_tensor,
116 });
117 }
118 }
119 }
120 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
121 sort_real_tensor(tensor, args)
122}
123
124fn sort_host(value: Value, args: &SortArgs) -> crate::BuiltinResult<SortEvaluation> {
125 match value {
126 Value::ComplexTensor(ct) => sort_complex_tensor(ct, args),
127 Value::Complex(re, im) => {
128 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
129 .map_err(|e| sort_error(format!("sort: {e}")))?;
130 sort_complex_tensor(tensor, args)
131 }
132 other => {
133 let tensor = tensor::value_into_tensor_for("sort", other).map_err(|e| sort_error(e))?;
134 sort_real_tensor(tensor, args)
135 }
136 }
137}
138
139fn sort_real_tensor(tensor: Tensor, args: &SortArgs) -> crate::BuiltinResult<SortEvaluation> {
140 let dim = args
141 .dimension
142 .unwrap_or_else(|| default_dimension(&tensor.shape));
143 if dim == 0 {
144 return Err(sort_error("sort: dimension must be >= 1"));
145 }
146
147 let dim_len = dimension_length(&tensor.shape, dim);
148 if tensor.data.is_empty() || dim_len <= 1 {
149 let indices = vec![1.0; tensor.data.len()];
150 let index_tensor = Tensor::new(indices, tensor.shape.clone())
151 .map_err(|e| sort_error(format!("sort: {e}")))?;
152 let sorted_value = tensor::tensor_into_value(tensor);
153 return Ok(SortEvaluation {
154 sorted: sorted_value,
155 indices: index_tensor,
156 });
157 }
158
159 let stride_before = stride_before(&tensor.shape, dim);
160 let stride_after = stride_after(&tensor.shape, dim);
161 let mut sorted = tensor.data.clone();
162 let mut indices = vec![0.0f64; tensor.data.len()];
163 let mut buffer: Vec<(usize, f64)> = Vec::with_capacity(dim_len);
164
165 for after in 0..stride_after {
166 for before in 0..stride_before {
167 buffer.clear();
168 for k in 0..dim_len {
169 let idx = before + k * stride_before + after * stride_before * dim_len;
170 let value = tensor.data[idx];
171 buffer.push((k, value));
172 }
173 buffer.sort_by(|a, b| compare_real_values(a.1, b.1, args));
174 for (pos, (original_index, value)) in buffer.iter().enumerate() {
175 let target = before + pos * stride_before + after * stride_before * dim_len;
176 sorted[target] = *value;
177 indices[target] = (*original_index + 1) as f64;
178 }
179 }
180 }
181
182 let sorted_tensor =
183 Tensor::new(sorted, tensor.shape.clone()).map_err(|e| sort_error(format!("sort: {e}")))?;
184 let index_tensor =
185 Tensor::new(indices, tensor.shape.clone()).map_err(|e| sort_error(format!("sort: {e}")))?;
186
187 Ok(SortEvaluation {
188 sorted: tensor::tensor_into_value(sorted_tensor),
189 indices: index_tensor,
190 })
191}
192
193fn sort_complex_tensor(
194 tensor: ComplexTensor,
195 args: &SortArgs,
196) -> crate::BuiltinResult<SortEvaluation> {
197 let dim = args
198 .dimension
199 .unwrap_or_else(|| default_dimension(&tensor.shape));
200 if dim == 0 {
201 return Err(sort_error("sort: dimension must be >= 1"));
202 }
203
204 let dim_len = dimension_length(&tensor.shape, dim);
205 if tensor.data.is_empty() || dim_len <= 1 {
206 let indices = vec![1.0; tensor.data.len()];
207 let index_tensor = Tensor::new(indices, tensor.shape.clone())
208 .map_err(|e| sort_error(format!("sort: {e}")))?;
209 return Ok(SortEvaluation {
210 sorted: complex_tensor_into_value(tensor),
211 indices: index_tensor,
212 });
213 }
214
215 let stride_before = stride_before(&tensor.shape, dim);
216 let stride_after = stride_after(&tensor.shape, dim);
217 let mut sorted = tensor.data.clone();
218 let mut indices = vec![0.0f64; tensor.data.len()];
219 let mut buffer: Vec<(usize, (f64, f64))> = Vec::with_capacity(dim_len);
220
221 for after in 0..stride_after {
222 for before in 0..stride_before {
223 buffer.clear();
224 for k in 0..dim_len {
225 let idx = before + k * stride_before + after * stride_before * dim_len;
226 let value = tensor.data[idx];
227 buffer.push((k, value));
228 }
229 buffer.sort_by(|a, b| compare_complex_values(a.1, b.1, args));
230 for (pos, (original_index, value)) in buffer.iter().enumerate() {
231 let target = before + pos * stride_before + after * stride_before * dim_len;
232 sorted[target] = *value;
233 indices[target] = (*original_index + 1) as f64;
234 }
235 }
236 }
237
238 let sorted_tensor = ComplexTensor::new(sorted, tensor.shape.clone())
239 .map_err(|e| sort_error(format!("sort: {e}")))?;
240 let index_tensor =
241 Tensor::new(indices, tensor.shape.clone()).map_err(|e| sort_error(format!("sort: {e}")))?;
242
243 Ok(SortEvaluation {
244 sorted: complex_tensor_into_value(sorted_tensor),
245 indices: index_tensor,
246 })
247}
248
249fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
250 if tensor.data.len() == 1 {
251 let (re, im) = tensor.data[0];
252 Value::Complex(re, im)
253 } else {
254 Value::ComplexTensor(tensor)
255 }
256}
257
258fn compare_real_values(a: f64, b: f64, args: &SortArgs) -> Ordering {
259 match (a.is_nan(), b.is_nan()) {
260 (true, true) => Ordering::Equal,
261 (true, false) => match args.direction {
262 SortDirection::Ascend => Ordering::Greater,
263 SortDirection::Descend => Ordering::Less,
264 },
265 (false, true) => match args.direction {
266 SortDirection::Ascend => Ordering::Less,
267 SortDirection::Descend => Ordering::Greater,
268 },
269 (false, false) => compare_real_finite(a, b, args),
270 }
271}
272
273fn compare_real_finite(a: f64, b: f64, args: &SortArgs) -> Ordering {
274 let primary = match args.comparison {
275 ComparisonMethod::Abs => {
276 let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
277 if abs_cmp != Ordering::Equal {
278 return match args.direction {
279 SortDirection::Ascend => abs_cmp,
280 SortDirection::Descend => abs_cmp.reverse(),
281 };
282 }
283 Ordering::Equal
284 }
285 ComparisonMethod::Auto | ComparisonMethod::Real => Ordering::Equal,
286 };
287 if primary != Ordering::Equal {
288 return primary;
289 }
290 match args.direction {
291 SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
292 SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
293 }
294}
295
296fn compare_complex_values(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
297 match (complex_is_nan(a), complex_is_nan(b)) {
298 (true, true) => Ordering::Equal,
299 (true, false) => match args.direction {
300 SortDirection::Ascend => Ordering::Greater,
301 SortDirection::Descend => Ordering::Less,
302 },
303 (false, true) => match args.direction {
304 SortDirection::Ascend => Ordering::Less,
305 SortDirection::Descend => Ordering::Greater,
306 },
307 (false, false) => compare_complex_finite(a, b, args),
308 }
309}
310
311fn compare_complex_finite(a: (f64, f64), b: (f64, f64), args: &SortArgs) -> Ordering {
312 match args.comparison {
313 ComparisonMethod::Real => compare_complex_real_imag(a, b, args.direction),
314 ComparisonMethod::Abs | ComparisonMethod::Auto => {
315 let abs_cmp = complex_abs(a)
316 .partial_cmp(&complex_abs(b))
317 .unwrap_or(Ordering::Equal);
318 if abs_cmp != Ordering::Equal {
319 return match args.direction {
320 SortDirection::Ascend => abs_cmp,
321 SortDirection::Descend => abs_cmp.reverse(),
322 };
323 }
324 compare_complex_real_imag(a, b, args.direction)
325 }
326 }
327}
328
329fn compare_complex_real_imag(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
330 let real_cmp = match direction {
331 SortDirection::Ascend => a.0.partial_cmp(&b.0),
332 SortDirection::Descend => b.0.partial_cmp(&a.0),
333 }
334 .unwrap_or(Ordering::Equal);
335 if real_cmp != Ordering::Equal {
336 return real_cmp;
337 }
338 match direction {
339 SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
340 SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
341 }
342}
343
/// Returns `true` if either the real or the imaginary component is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    [re, im].iter().any(|component| component.is_nan())
}
347
/// Magnitude `sqrt(re^2 + im^2)` of a complex value, computed with `hypot` to avoid
/// intermediate overflow or underflow.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    re.hypot(im)
}
351
/// Product of the extents that precede 1-based dimension `dim` (saturating).
///
/// This is the column-major distance between consecutive elements along `dim`.
/// Dimensions beyond `shape.len()` are treated as having extent 1.
fn stride_before(shape: &[usize], dim: usize) -> usize {
    shape
        .iter()
        .take(dim.saturating_sub(1))
        .fold(1usize, |acc, &extent| acc.saturating_mul(extent))
}
362
/// Product of the extents that follow 1-based dimension `dim` (saturating).
///
/// This is the number of independent slabs along `dim` in column-major order.
/// Returns 1 when `dim` is the last dimension or lies beyond `shape.len()`.
fn stride_after(shape: &[usize], dim: usize) -> usize {
    shape
        .iter()
        .skip(dim)
        .fold(1usize, |acc, &extent| acc.saturating_mul(extent))
}
373
/// Extent of 1-based dimension `dim`.
///
/// Trailing (implicit) dimensions beyond `shape.len()` have length 1. `dim == 0` is
/// not a valid MATLAB dimension; it previously underflowed `dim - 1` (a panic in debug
/// builds) and now also yields 1.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    dim.checked_sub(1)
        .and_then(|axis| shape.get(axis).copied())
        .unwrap_or(1)
}
377
/// Default 1-based sort dimension: the first dimension whose extent exceeds 1,
/// or 1 when there is none (scalars, all-singleton shapes, empty shape lists).
fn default_dimension(shape: &[usize]) -> usize {
    for (axis, &extent) in shape.iter().enumerate() {
        if extent > 1 {
            return axis + 1;
        }
    }
    1
}
385
386#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
387enum SortDirection {
388 #[default]
389 Ascend,
390 Descend,
391}
392
393impl SortDirection {
394 fn to_provider(self) -> ProviderSortOrder {
395 match self {
396 SortDirection::Ascend => ProviderSortOrder::Ascend,
397 SortDirection::Descend => ProviderSortOrder::Descend,
398 }
399 }
400}
401
402#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
403enum ComparisonMethod {
404 #[default]
405 Auto,
406 Real,
407 Abs,
408}
409
410impl ComparisonMethod {
411 fn to_provider(self) -> ProviderSortComparison {
412 match self {
413 ComparisonMethod::Auto => ProviderSortComparison::Auto,
414 ComparisonMethod::Real => ProviderSortComparison::Real,
415 ComparisonMethod::Abs => ProviderSortComparison::Abs,
416 }
417 }
418}
419
420#[derive(Debug, Clone, Default)]
421struct SortArgs {
422 dimension: Option<usize>,
423 direction: SortDirection,
424 comparison: ComparisonMethod,
425}
426
427impl SortArgs {
428 fn parse(rest: &[Value]) -> crate::BuiltinResult<Self> {
429 let mut args = SortArgs::default();
430 let tokens = tokens_from_values(rest);
431 let mut i = 0usize;
432 while i < rest.len() {
433 if args.dimension.is_none() {
434 if is_dimension_placeholder(&rest[i]) {
435 i += 1;
436 continue;
437 }
438 match tensor::parse_dimension(&rest[i], "sort") {
439 Ok(dim) => {
440 args.dimension = Some(dim);
441 i += 1;
442 continue;
443 }
444 Err(err) => {
445 if matches!(rest[i], Value::Int(_) | Value::Num(_)) {
446 return Err(sort_error(err));
447 }
448 }
449 }
450 }
451 if let Some(ArgToken::String(text)) = tokens.get(i) {
452 match text.as_str() {
453 "ascend" | "ascending" => {
454 args.direction = SortDirection::Ascend;
455 i += 1;
456 continue;
457 }
458 "descend" | "descending" => {
459 args.direction = SortDirection::Descend;
460 i += 1;
461 continue;
462 }
463 "comparisonmethod" => {
464 i += 1;
465 if i >= rest.len() {
466 return Err(sort_error(
467 "sort: expected a value for 'ComparisonMethod'",
468 ));
469 }
470 let value = match tokens.get(i) {
471 Some(ArgToken::String(value)) => value.as_str(),
472 _ => {
473 return Err(sort_error(
474 "sort: 'ComparisonMethod' requires a string value",
475 ))
476 }
477 };
478 args.comparison = match value {
479 "auto" => ComparisonMethod::Auto,
480 "real" => ComparisonMethod::Real,
481 "abs" | "magnitude" => ComparisonMethod::Abs,
482 other => {
483 return Err(sort_error(format!(
484 "sort: unsupported ComparisonMethod '{other}'"
485 ))
486 .into())
487 }
488 };
489 i += 1;
490 continue;
491 }
492 "missingplacement" => {
493 return Err(sort_error(
494 "sort: the 'MissingPlacement' option is not supported yet",
495 )
496 .into());
497 }
498 _ => {}
499 }
500 }
501 if let Some(keyword) = tensor::value_to_string(&rest[i]) {
502 let lowered = keyword.trim().to_ascii_lowercase();
503 match lowered.as_str() {
504 "ascend" | "ascending" => {
505 args.direction = SortDirection::Ascend;
506 i += 1;
507 continue;
508 }
509 "descend" | "descending" => {
510 args.direction = SortDirection::Descend;
511 i += 1;
512 continue;
513 }
514 "comparisonmethod" => {
515 i += 1;
516 if i >= rest.len() {
517 return Err(sort_error(
518 "sort: expected a value for 'ComparisonMethod'",
519 ));
520 }
521 let raw = &rest[i];
522 let value = match raw {
523 Value::String(s) => s.clone(),
524 Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
525 Value::CharArray(ca) if ca.rows == 1 => {
526 ca.data.iter().copied().collect()
527 }
528 _ => {
529 return Err(sort_error(
530 "sort: 'ComparisonMethod' requires a string value",
531 ))
532 }
533 };
534 let lowered_value = value.trim().to_ascii_lowercase();
535 args.comparison = match lowered_value.as_str() {
536 "auto" => ComparisonMethod::Auto,
537 "real" => ComparisonMethod::Real,
538 "abs" | "magnitude" => ComparisonMethod::Abs,
539 other => {
540 return Err(sort_error(format!(
541 "sort: unsupported ComparisonMethod '{other}'"
542 ))
543 .into())
544 }
545 };
546 i += 1;
547 continue;
548 }
549 "missingplacement" => {
550 return Err(sort_error(
551 "sort: the 'MissingPlacement' option is not supported yet",
552 )
553 .into());
554 }
555 _ => {}
556 }
557 }
558 return Err(sort_error(format!(
559 "sort: unrecognised argument {:?}",
560 rest[i]
561 )));
562 }
563 Ok(args)
564 }
565}
566
567fn is_dimension_placeholder(value: &Value) -> bool {
568 match value {
569 Value::Tensor(t) => t.data.is_empty(),
570 Value::LogicalArray(logical) => logical.data.is_empty(),
571 _ => false,
572 }
573}
574
575pub struct SortEvaluation {
576 sorted: Value,
577 indices: Tensor,
578}
579
580impl SortEvaluation {
581 pub fn into_sorted_value(self) -> Value {
582 self.sorted
583 }
584
585 pub fn into_values(self) -> (Value, Value) {
586 let indices = tensor::tensor_into_value(self.indices);
587 (self.sorted, indices)
588 }
589
590 pub fn indices_value(&self) -> Value {
591 tensor::tensor_into_value(self.indices.clone())
592 }
593}
594
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for `sort`. Tests go through either `evaluate` (sorted values plus
    // indices) or `sort_builtin` (the user-facing entry point). Both are wrapped in
    // `block_on` below because the real entry points are async.
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor, IntValue, ResolveContext, Tensor, Type, Value};

    // Synchronous wrapper around the async builtin entry point.
    fn sort_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::sort_builtin(value, rest))
    }

    // Extracts the error message so tests can assert on substrings.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    // Synchronous wrapper around `super::evaluate`.
    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortEvaluation> {
        block_on(super::evaluate(value, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_vector_default() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
        let result = sort_builtin(Value::Tensor(tensor), Vec::new()).expect("sort");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
                assert_eq!(t.shape, vec![3, 1]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn sort_type_resolver_tensor() {
        assert_eq!(
            tensor_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_descend_direction() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
        let result =
            sort_builtin(Value::Tensor(tensor), vec![Value::from("descend")]).expect("sort");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 2.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Column-major 2x3 input: each column is sorted independently along dim 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_matrix_default_dim1() {
        let tensor = Tensor::new(vec![4.0, 2.0, 1.0, 5.0, 6.0, 3.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 4.0, 1.0, 5.0, 3.0, 6.0]);
                assert_eq!(t.shape, vec![2, 3]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 1.0, 2.0, 2.0, 1.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    // Rows [1 4 2] and [3 2 5] sorted along dim 2; data is stored column-major.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_matrix_along_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
        let eval =
            evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(2))]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 2.0, 2.0, 3.0, 4.0, 5.0]);
                assert_eq!(t.shape, vec![2, 3]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    // An empty `[]` before the dimension is skipped as a placeholder.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_dimension_placeholder_then_dim() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0], vec![2, 2]).unwrap();
        let placeholder = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::Tensor(placeholder),
                Value::Int(IntValue::I32(2)),
                Value::from("descend"),
            ],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, 3.0, 1.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // A dimension may follow the direction keyword.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_descend_then_dimension() {
        let tensor = Tensor::new(vec![1.0, 3.0, 4.0, 2.0, 2.0, 5.0], vec![2, 3]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("descend"), Value::Int(IntValue::I32(1))],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 4.0, 2.0, 5.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_returns_indices() {
        let tensor = Tensor::new(vec![4.0, 1.0, 9.0, 2.0], vec![4, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 4.0, 9.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 4.0, 1.0, 3.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // NaN goes last when ascending and first when descending.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_with_nan_handling() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert!(t.data[3].is_nan());
                assert_eq!(&t.data[0..3], &[1.0, 2.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }

        let eval_desc =
            evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
        let (sorted_desc, _) = eval_desc.into_values();
        match sorted_desc {
            Value::Tensor(t) => {
                assert!(t.data[0].is_nan());
                assert_eq!(&t.data[1..], &[4.0, 2.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Magnitudes are distinct here, so no tie-break rule is exercised.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_by_absolute_value() {
        let tensor = Tensor::new(vec![-8.0, -1.0, 3.0, -2.0], vec![4, 1]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![-1.0, -2.0, 3.0, -8.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_by_absolute_value_descend() {
        let tensor = Tensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![4, 1]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::from("descend"),
                Value::from("ComparisonMethod"),
                Value::from("abs"),
            ],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![4.0, -3.0, 2.0, -1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Complex data defaults to ordering by magnitude.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_complex_auto_abs() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.5), (0.0, -1.0)], vec![3, 1]).unwrap();
        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(0.0, -1.0), (1.0, 2.0), (-3.0, 0.5)])
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    // 'real' ordering: the equal real parts (1.0) are separated by the imaginary part.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_complex_real_descend() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 2.0), (-3.0, 0.0), (1.0, -1.0)], vec![3, 1]).unwrap();
        let eval = evaluate(
            Value::ComplexTensor(tensor),
            &[
                Value::from("descend"),
                Value::from("ComparisonMethod"),
                Value::from("real"),
            ],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(1.0, 2.0), (1.0, -1.0), (-3.0, 0.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Equal values keep their original relative order (stable sort).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_stable_with_duplicates() {
        let tensor = Tensor::new(vec![2.0, 2.0, 1.0, 2.0], vec![4, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 2.0, 2.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 1.0, 2.0, 4.0]),
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_empty_tensor() {
        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert!(t.data.is_empty());
                assert_eq!(t.shape, tensor.shape);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert!(t.data.is_empty()),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Sorting along a trailing singleton dimension leaves data unchanged; indices are 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_dim_greater_than_ndims() {
        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor.clone()),
            &[Value::Int(IntValue::I32(3))],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, tensor.data),
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert!(t.data.iter().all(|v| (*v - 1.0).abs() < f64::EPSILON)),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_invalid_argument_errors() {
        let err = error_message(
            sort_builtin(
                Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
                vec![Value::from("missingplacement"), Value::from("first")],
            )
            .unwrap_err(),
        );
        assert!(err.contains("MissingPlacement"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_invalid_comparison_method_errors() {
        let err = error_message(
            sort_builtin(
                Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
                vec![Value::from("ComparisonMethod"), Value::from("unknown")],
            )
            .unwrap_err(),
        );
        assert!(err.contains("ComparisonMethod"), "unexpected error: {err}");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_invalid_comparison_method_value_errors() {
        let err = error_message(
            sort_builtin(
                Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
                vec![
                    Value::from("ComparisonMethod"),
                    Value::Int(IntValue::I32(1)),
                ],
            )
            .unwrap_err(),
        );
        assert!(
            err.contains("requires a string value"),
            "unexpected error: {err}"
        );
    }

    // The expected message is whatever the dimension parser reports for 0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_dimension_zero_errors() {
        let err = error_message(
            sort_builtin(
                Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
                vec![Value::Num(0.0)],
            )
            .unwrap_err(),
        );
        assert!(
            err.contains("dimension must be >= 1"),
            "unexpected error: {err}"
        );
    }

    // Uploads through the test provider. Whether the provider's `sort_dim` hook or the
    // host fallback produces the result depends on the test provider (not visible here).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sort_gpu_round_trip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
            let (sorted, indices) = eval.into_values();
            match sorted {
                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
            match indices {
                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
        });
    }

    // Cross-checks the wgpu provider path against the host implementation.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn sort_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![4.0, 1.0, 3.0, 2.0], vec![4, 1]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu sort");
        let (cpu_sorted, cpu_indices) = cpu_eval.into_values();

        let gpu_view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&gpu_view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu sort");
        let (gpu_sorted, gpu_indices) = gpu_eval.into_values();

        let cpu_sorted_tensor = match cpu_sorted {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected CPU sorted value {other:?}"),
        };
        let cpu_indices_tensor = match cpu_indices {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected CPU indices value {other:?}"),
        };
        let gpu_sorted_tensor = match gpu_sorted {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected GPU sorted value {other:?}"),
        };
        let gpu_indices_tensor = match gpu_indices {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).unwrap(),
            other => panic!("unexpected GPU indices value {other:?}"),
        };

        assert_eq!(gpu_sorted_tensor.data, cpu_sorted_tensor.data);
        assert_eq!(gpu_indices_tensor.data, cpu_indices_tensor.data);
    }
}