1use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6 GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7 SortResult as ProviderSortResult, SortRowsColumnSpec as ProviderSortRowsColumnSpec,
8};
9use runmat_builtins::{CharArray, ComplexTensor, Tensor, Value};
10use runmat_macros::runtime_builtin;
11
12use crate::builtins::common::gpu_helpers;
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18#[cfg(feature = "doc_export")]
19use crate::register_builtin_doc_text;
20use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
21
22#[cfg(feature = "doc_export")]
23pub const DOC_MD: &str = r#"---
24title: "sortrows"
25category: "array/sorting_sets"
26keywords: ["sortrows", "row sort", "lexicographic", "gpu"]
27summary: "Sort matrix rows lexicographically with optional column and direction control."
28references:
29 - https://www.mathworks.com/help/matlab/ref/sortrows.html
30gpu_support:
31 elementwise: false
32 reduction: false
33 precisions: ["f32", "f64"]
34 broadcasting: "none"
35 notes: "Falls back to host memory when providers do not expose a dedicated row sort kernel."
36fusion:
37 elementwise: false
38 reduction: false
39 max_inputs: 1
40 constants: "inline"
41requires_feature: null
42tested:
43 unit: "builtins::array::sorting_sets::sortrows::tests"
44 integration: "builtins::array::sorting_sets::sortrows::tests::sortrows_gpu_roundtrip"
45---
46
47# What does the `sortrows` function do in MATLAB / RunMat?
48`sortrows` reorders the rows of a matrix (or character array) so they appear in lexicographic order.
49You can control which columns participate in the comparison and whether each column uses ascending or descending order.
50
51## How does the `sortrows` function behave in MATLAB / RunMat?
52- `sortrows(A)` sorts by column `1`, then column `2`, and so on, all in ascending order.
53- `sortrows(A, C)` treats the vector `C` as the column order. Positive entries sort ascending; negative entries sort descending.
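- `sortrows(A, 'descend')` sorts all columns in descending order. When combined with a column vector, as in `sortrows(A, C, 'descend')`, the direction string overrides the signs in `C`, so every listed column sorts in descending order; use signed entries in `C` to mix directions.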
55- `[B, I] = sortrows(A, ...)` also returns `I`, the 1-based row permutation indices.
56- `sortrows` is stable: rows that compare equal keep their original order.
57- For complex inputs, `'ComparisonMethod'` accepts `'auto'`, `'real'`, or `'abs'`, matching MATLAB semantics.
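- NaN handling mirrors MATLAB: within a sort column, NaN entries compare after every other value in ascending sorts and before every other value in descending sorts, so rows with NaN in a sort column move toward the end (ascending) or the beginning (descending).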
59- `'MissingPlacement'` lets you choose whether NaN (and other missing) rows appear `'first'`, `'last'`, or follow MATLAB's `'auto'` default.
60- Character arrays are sorted lexicographically using their character codes.
61
62## `sortrows` Function GPU Execution Behaviour
63- `sortrows` is registered as a sink builtin. When the input tensor already lives on the GPU and the active provider exposes a `sortrows` hook, the runtime delegates to that hook; the current provider contract returns host buffers, so the sorted rows and permutation indices are materialised on the CPU before being returned.
64- When the provider lacks the hook—or cannot honour a specific combination of options such as `'MissingPlacement','first'` or `'MissingPlacement','last'`—RunMat gathers the tensor and performs the sort on the host while preserving MATLAB semantics.
- Name-value options that the provider does not advertise fall back automatically; callers do not need to special-case GPU vs CPU execution.
66- The permutation indices are emitted as double-precision column vectors so they can be reused directly for MATLAB-style indexing.
67
68## Examples of using `sortrows` in MATLAB / RunMat
69
70### Sorting rows of a matrix in ascending order
71```matlab
72A = [3 2; 1 4; 2 1];
73B = sortrows(A);
74```
75Expected output:
76```matlab
77B =
78 1 4
79 2 1
80 3 2
81```
82
83### Sorting by a custom column order
84```matlab
85A = [1 4 2; 3 2 5; 3 2 1];
86B = sortrows(A, [2 3 1]);
87```
88Expected output:
89```matlab
90B =
91 3 2 1
92 3 2 5
93 1 4 2
94```
95
96### Sorting rows in descending order
97```matlab
98A = [2 8; 4 1; 3 5];
99B = sortrows(A, 'descend');
100```
101Expected output:
102```matlab
103B =
104 4 1
105 3 5
106 2 8
107```
108
109### Mixing ascending and descending directions
110```matlab
111A = [1 7 3; 1 2 9; 1 2 3];
112B = sortrows(A, [1 -2 3]);
113```
114Expected output:
115```matlab
116B =
117 1 7 3
118 1 2 3
119 1 2 9
120```
121
122### Sorting rows of a character array
123```matlab
124names = ['bob '; 'al '; 'ally'];
125sorted = sortrows(names);
126```
127Expected output:
128```matlab
129sorted =
130al
131ally
132bob
133```
134
135### Sorting rows of complex data by magnitude
136```matlab
137Z = [3+4i, 3; 1+2i, 4];
138B = sortrows(Z, 'ComparisonMethod', 'abs');
139```
140Expected output:
141```matlab
142B =
143 1.0000 + 2.0000i 4.0000
144 3.0000 + 4.0000i 3.0000
145```
146
147### Forcing NaN rows to the top
148```matlab
149A = [1 NaN; NaN 2];
150B = sortrows(A, 'MissingPlacement', 'first');
151```
152Expected output:
153```matlab
154B =
155 NaN 2
156 1 NaN
157```
158
159### Sorting GPU-resident data with automatic host fallback
160```matlab
161G = gpuArray([3 1; 2 4; 1 2]);
162[B, I] = sortrows(G);
163```
When the active provider does not implement the row-sort hook, the runtime gathers `G` and performs the sort on the host. In every case the results are returned as host tensors, and the permutation indices `I`
match MATLAB's 1-based output.
166
167## FAQ
168
169### Can I request the permutation indices?
170Yes. Call `[B, I] = sortrows(A, ...)` to receive the 1-based row permutation indices in `I`.
171
172### How do I sort specific columns?
Provide a vector of signed column indices, e.g. `sortrows(A, [2 -3])` sorts by column `2` ascending and then by column `3` descending.
174
175### What happens when rows contain NaN values?
176Rows containing NaNs move to the bottom for ascending sorts and to the top for descending sorts when `'MissingPlacement'` is left at its `'auto'` default, matching MATLAB.
177
178### How can I force NaNs or missing values to the top or bottom?
179Use the name-value pair `'MissingPlacement','first'` to place missing rows before finite ones, or `'MissingPlacement','last'` to move them to the end regardless of direction.
180
181### Does `sortrows` work with complex numbers?
182Yes. Use `'ComparisonMethod','real'` to sort by the real component or `'abs'` to sort by magnitude (the default behaviour matches MATLAB's `'auto'` rules).
183
184### Can I combine a direction string with a column vector?
Yes. `sortrows(A, [1 3], 'descend')` sorts by column `1` and then column `3`, both in descending order; the direction string overrides any signs in the column vector.
186
187### Is the operation stable?
188Yes. Rows that compare equal remain in their original order.
189
190### Does `sortrows` mutate its input?
191No. It returns a sorted copy of the input. GPU inputs are gathered to host memory when required.
192
193### Are string arrays supported?
194String arrays are not yet supported. Convert them to character matrices or use tables before sorting.
195
196## See Also
197[sort](./sort), [unique](./unique), [max](../../math/reduction/max), [min](../../math/reduction/min), [permute](../../array/shape/permute)
198
199## Source & Feedback
200- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/sortrows.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/sortrows.rs)
201- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with a minimal reproduction.
202"#;
203
/// GPU execution metadata for `sortrows`, registered with the runtime just below.
///
/// The operation is exposed as a custom op / provider hook named `"sortrows"` for F32 and
/// F64 tensors, with no broadcasting.
// NOTE(review): `GatherImmediately` is consistent with `sortrows_gpu`, which always hands back
// host tensors; confirm the exact scheduling meaning of each field in `builtins::common::spec`.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "sortrows",
    op_kind: GpuOpKind::Custom("sortrows"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("sortrows")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes:
        "Providers may implement a row-sort kernel; explicit MissingPlacement overrides fall back to host memory until native support exists.",
};

register_builtin_gpu_spec!(GPU_SPEC);
221
/// Fusion metadata for `sortrows`, registered with the runtime just below.
///
/// No elementwise or reduction template is provided (`None` for both), so the builtin cannot
/// be fused; per `notes`, fusion chains end before the row sort.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "sortrows",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "Acts as a sink operation; upstream fusion chains terminate before sorting rows.",
};

register_builtin_fusion_spec!(FUSION_SPEC);
233
// Attach the markdown reference text (`DOC_MD`) to the `sortrows` builtin; only compiled when
// the `doc_export` feature is enabled.
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("sortrows", DOC_MD);
236
237#[runtime_builtin(
238 name = "sortrows",
239 category = "array/sorting_sets",
240 summary = "Sort matrix rows lexicographically with optional column and direction control.",
241 keywords = "sortrows,row sort,lexicographic,gpu",
242 accel = "sink",
243 sink = true
244)]
245fn sortrows_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
246 evaluate(value, &rest).map(|eval| eval.into_sorted_value())
247}
248
249pub fn evaluate(value: Value, rest: &[Value]) -> Result<SortRowsEvaluation, String> {
251 match value {
252 Value::GpuTensor(handle) => sortrows_gpu(handle, rest),
253 other => sortrows_host(other, rest),
254 }
255}
256
257fn sortrows_gpu(handle: GpuTensorHandle, rest: &[Value]) -> Result<SortRowsEvaluation, String> {
258 ensure_matrix_shape(&handle.shape)?;
259 let (_, cols) = rows_cols_from_shape(&handle.shape);
260 let args = SortRowsArgs::parse(rest, cols)?;
261
262 if args.missing_is_auto() {
263 if let Some(provider) = runmat_accelerate_api::provider() {
264 let provider_columns = args.to_provider_columns();
265 let provider_comparison = args.provider_comparison();
266 match provider.sort_rows(&handle, &provider_columns, provider_comparison) {
267 Ok(result) => return sortrows_from_provider_result(result),
268 Err(_err) => {
269 }
271 }
272 }
273 }
274
275 let tensor = gpu_helpers::gather_tensor(&handle)?;
276 sortrows_real_tensor_with_args(tensor, &args)
277}
278
279fn sortrows_from_provider_result(result: ProviderSortResult) -> Result<SortRowsEvaluation, String> {
280 let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
281 .map_err(|e| format!("sortrows: {e}"))?;
282 let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
283 .map_err(|e| format!("sortrows: {e}"))?;
284 Ok(SortRowsEvaluation {
285 sorted: tensor::tensor_into_value(sorted_tensor),
286 indices: indices_tensor,
287 })
288}
289
290fn sortrows_host(value: Value, rest: &[Value]) -> Result<SortRowsEvaluation, String> {
291 match value {
292 Value::Tensor(tensor) => sortrows_real_tensor(tensor, rest),
293 Value::LogicalArray(logical) => {
294 let tensor = tensor::logical_to_tensor(&logical)?;
295 sortrows_real_tensor(tensor, rest)
296 }
297 Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
298 let tensor = tensor::value_into_tensor_for("sortrows", value)?;
299 sortrows_real_tensor(tensor, rest)
300 }
301 Value::ComplexTensor(ct) => sortrows_complex_tensor(ct, rest),
302 Value::Complex(re, im) => {
303 let tensor =
304 ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(|e| format!("sortrows: {e}"))?;
305 sortrows_complex_tensor(tensor, rest)
306 }
307 Value::CharArray(ca) => sortrows_char_array(ca, rest),
308 other => Err(format!(
309 "sortrows: unsupported input type {:?}; expected numeric, logical, complex, or char arrays",
310 other
311 )),
312 }
313}
314
315fn sortrows_real_tensor(tensor: Tensor, rest: &[Value]) -> Result<SortRowsEvaluation, String> {
316 ensure_matrix_shape(&tensor.shape)?;
317 let cols = tensor.cols();
318 let args = SortRowsArgs::parse(rest, cols)?;
319 sortrows_real_tensor_with_args(tensor, &args)
320}
321
322fn sortrows_real_tensor_with_args(
323 tensor: Tensor,
324 args: &SortRowsArgs,
325) -> Result<SortRowsEvaluation, String> {
326 let rows = tensor.rows();
327 let cols = tensor.cols();
328
329 if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
330 let indices = identity_indices(rows)?;
331 return Ok(SortRowsEvaluation {
332 sorted: tensor::tensor_into_value(tensor),
333 indices,
334 });
335 }
336
337 let mut order: Vec<usize> = (0..rows).collect();
338 order.sort_by(|&a, &b| compare_real_rows(&tensor, rows, args, a, b));
339
340 let sorted_tensor = reorder_real_rows(&tensor, rows, cols, &order)?;
341 let indices = permutation_indices(&order)?;
342 Ok(SortRowsEvaluation {
343 sorted: tensor::tensor_into_value(sorted_tensor),
344 indices,
345 })
346}
347
348fn sortrows_complex_tensor(
349 tensor: ComplexTensor,
350 rest: &[Value],
351) -> Result<SortRowsEvaluation, String> {
352 ensure_matrix_shape(&tensor.shape)?;
353 let cols = tensor.cols;
354 let args = SortRowsArgs::parse(rest, cols)?;
355 sortrows_complex_tensor_with_args(tensor, &args)
356}
357
358fn sortrows_complex_tensor_with_args(
359 tensor: ComplexTensor,
360 args: &SortRowsArgs,
361) -> Result<SortRowsEvaluation, String> {
362 let rows = tensor.rows;
363 let cols = tensor.cols;
364
365 if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
366 let indices = identity_indices(rows)?;
367 return Ok(SortRowsEvaluation {
368 sorted: complex_tensor_into_value(tensor),
369 indices,
370 });
371 }
372
373 let mut order: Vec<usize> = (0..rows).collect();
374 order.sort_by(|&a, &b| compare_complex_rows(&tensor, rows, args, a, b));
375
376 let sorted_tensor = reorder_complex_rows(&tensor, rows, cols, &order)?;
377 let indices = permutation_indices(&order)?;
378 Ok(SortRowsEvaluation {
379 sorted: complex_tensor_into_value(sorted_tensor),
380 indices,
381 })
382}
383
384fn sortrows_char_array(ca: CharArray, rest: &[Value]) -> Result<SortRowsEvaluation, String> {
385 let cols = ca.cols;
386 let args = SortRowsArgs::parse(rest, cols)?;
387 sortrows_char_array_with_args(ca, &args)
388}
389
390fn sortrows_char_array_with_args(
391 ca: CharArray,
392 args: &SortRowsArgs,
393) -> Result<SortRowsEvaluation, String> {
394 let rows = ca.rows;
395 let cols = ca.cols;
396
397 if rows <= 1 || cols == 0 || ca.data.is_empty() || args.columns.is_empty() {
398 let indices = identity_indices(rows)?;
399 return Ok(SortRowsEvaluation {
400 sorted: Value::CharArray(ca),
401 indices,
402 });
403 }
404
405 let mut order: Vec<usize> = (0..rows).collect();
406 order.sort_by(|&a, &b| compare_char_rows(&ca, args, a, b));
407
408 let sorted = reorder_char_rows(&ca, rows, cols, &order)?;
409 let indices = permutation_indices(&order)?;
410 Ok(SortRowsEvaluation {
411 sorted: Value::CharArray(sorted),
412 indices,
413 })
414}
415
/// Ensures `shape` describes a 2-D matrix.
///
/// Shapes with at most two dimensions are accepted. Higher-rank shapes are also accepted when
/// every dimension past the second is 1 (e.g. `[3, 2, 1]`), because MATLAB treats those as
/// ordinary matrices: the data is still laid out as `shape[0] x shape[1]`.
///
/// # Errors
/// Returns `"sortrows: input must be a 2-D matrix"` when any dimension beyond the second is
/// not a singleton.
fn ensure_matrix_shape(shape: &[usize]) -> Result<(), String> {
    if shape.iter().skip(2).all(|&extent| extent == 1) {
        Ok(())
    } else {
        Err("sortrows: input must be a 2-D matrix".to_string())
    }
}
423
/// Interprets a shape as `(rows, cols)`.
///
/// An empty shape is a 1x1 scalar and a single extent is a row vector. Otherwise the first
/// two extents are used.
fn rows_cols_from_shape(shape: &[usize]) -> (usize, usize) {
    match *shape {
        [] => (1, 1),
        [width] => (1, width),
        [height, width, ..] => (height, width),
    }
}
431
432fn compare_real_rows(
433 tensor: &Tensor,
434 rows: usize,
435 args: &SortRowsArgs,
436 a: usize,
437 b: usize,
438) -> Ordering {
439 for spec in &args.columns {
440 if spec.index >= tensor.cols() {
441 continue;
442 }
443 let idx_a = a + spec.index * rows;
444 let idx_b = b + spec.index * rows;
445 let va = tensor.data[idx_a];
446 let vb = tensor.data[idx_b];
447 let missing = args.missing_for_direction(spec.direction);
448 let ord = compare_real_scalars(va, vb, spec.direction, args.comparison, missing);
449 if ord != Ordering::Equal {
450 return ord;
451 }
452 }
453 Ordering::Equal
454}
455
456fn compare_complex_rows(
457 tensor: &ComplexTensor,
458 rows: usize,
459 args: &SortRowsArgs,
460 a: usize,
461 b: usize,
462) -> Ordering {
463 for spec in &args.columns {
464 if spec.index >= tensor.cols {
465 continue;
466 }
467 let idx_a = a + spec.index * rows;
468 let idx_b = b + spec.index * rows;
469 let va = tensor.data[idx_a];
470 let vb = tensor.data[idx_b];
471 let missing = args.missing_for_direction(spec.direction);
472 let ord = compare_complex_scalars(va, vb, spec.direction, args.comparison, missing);
473 if ord != Ordering::Equal {
474 return ord;
475 }
476 }
477 Ordering::Equal
478}
479
480fn compare_char_rows(ca: &CharArray, args: &SortRowsArgs, a: usize, b: usize) -> Ordering {
481 for spec in &args.columns {
482 if spec.index >= ca.cols {
483 continue;
484 }
485 let idx_a = a * ca.cols + spec.index;
486 let idx_b = b * ca.cols + spec.index;
487 let va = ca.data[idx_a];
488 let vb = ca.data[idx_b];
489 let ord = match spec.direction {
490 SortDirection::Ascend => va.cmp(&vb),
491 SortDirection::Descend => vb.cmp(&va),
492 };
493 if ord != Ordering::Equal {
494 return ord;
495 }
496 }
497 Ordering::Equal
498}
499
500fn reorder_real_rows(
501 tensor: &Tensor,
502 rows: usize,
503 cols: usize,
504 order: &[usize],
505) -> Result<Tensor, String> {
506 let mut data = vec![0.0; tensor.data.len()];
507 for col in 0..cols {
508 for (dest_row, &src_row) in order.iter().enumerate() {
509 let src_idx = src_row + col * rows;
510 let dst_idx = dest_row + col * rows;
511 data[dst_idx] = tensor.data[src_idx];
512 }
513 }
514 Tensor::new(data, tensor.shape.clone()).map_err(|e| format!("sortrows: {e}"))
515}
516
517fn reorder_complex_rows(
518 tensor: &ComplexTensor,
519 rows: usize,
520 cols: usize,
521 order: &[usize],
522) -> Result<ComplexTensor, String> {
523 let mut data = vec![(0.0, 0.0); tensor.data.len()];
524 for col in 0..cols {
525 for (dest_row, &src_row) in order.iter().enumerate() {
526 let src_idx = src_row + col * rows;
527 let dst_idx = dest_row + col * rows;
528 data[dst_idx] = tensor.data[src_idx];
529 }
530 }
531 ComplexTensor::new(data, tensor.shape.clone()).map_err(|e| format!("sortrows: {e}"))
532}
533
534fn reorder_char_rows(
535 ca: &CharArray,
536 rows: usize,
537 cols: usize,
538 order: &[usize],
539) -> Result<CharArray, String> {
540 let mut data = vec!['\0'; ca.data.len()];
541 for (dest_row, &src_row) in order.iter().enumerate() {
542 for col in 0..cols {
543 let src_idx = src_row * cols + col;
544 let dst_idx = dest_row * cols + col;
545 data[dst_idx] = ca.data[src_idx];
546 }
547 }
548 CharArray::new(data, rows, cols).map_err(|e| format!("sortrows: {e}"))
549}
550
551fn compare_real_scalars(
552 a: f64,
553 b: f64,
554 direction: SortDirection,
555 comparison: ComparisonMethod,
556 missing: MissingPlacementResolved,
557) -> Ordering {
558 match (a.is_nan(), b.is_nan()) {
559 (true, true) => Ordering::Equal,
560 (true, false) => match missing {
561 MissingPlacementResolved::First => Ordering::Less,
562 MissingPlacementResolved::Last => Ordering::Greater,
563 },
564 (false, true) => match missing {
565 MissingPlacementResolved::First => Ordering::Greater,
566 MissingPlacementResolved::Last => Ordering::Less,
567 },
568 (false, false) => compare_real_finite_scalars(a, b, direction, comparison),
569 }
570}
571
572fn compare_real_finite_scalars(
573 a: f64,
574 b: f64,
575 direction: SortDirection,
576 comparison: ComparisonMethod,
577) -> Ordering {
578 if matches!(comparison, ComparisonMethod::Abs) {
579 let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
580 if abs_cmp != Ordering::Equal {
581 return match direction {
582 SortDirection::Ascend => abs_cmp,
583 SortDirection::Descend => abs_cmp.reverse(),
584 };
585 }
586 }
587 match direction {
588 SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
589 SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
590 }
591}
592
593fn compare_complex_scalars(
594 a: (f64, f64),
595 b: (f64, f64),
596 direction: SortDirection,
597 comparison: ComparisonMethod,
598 missing: MissingPlacementResolved,
599) -> Ordering {
600 match (complex_is_nan(a), complex_is_nan(b)) {
601 (true, true) => Ordering::Equal,
602 (true, false) => match missing {
603 MissingPlacementResolved::First => Ordering::Less,
604 MissingPlacementResolved::Last => Ordering::Greater,
605 },
606 (false, true) => match missing {
607 MissingPlacementResolved::First => Ordering::Greater,
608 MissingPlacementResolved::Last => Ordering::Less,
609 },
610 (false, false) => compare_complex_finite_scalars(a, b, direction, comparison),
611 }
612}
613
614fn compare_complex_finite_scalars(
615 a: (f64, f64),
616 b: (f64, f64),
617 direction: SortDirection,
618 comparison: ComparisonMethod,
619) -> Ordering {
620 match comparison {
621 ComparisonMethod::Real => compare_complex_real_first(a, b, direction),
622 ComparisonMethod::Auto | ComparisonMethod::Abs => {
623 let abs_cmp = complex_abs(a)
624 .partial_cmp(&complex_abs(b))
625 .unwrap_or(Ordering::Equal);
626 if abs_cmp != Ordering::Equal {
627 return match direction {
628 SortDirection::Ascend => abs_cmp,
629 SortDirection::Descend => abs_cmp.reverse(),
630 };
631 }
632 compare_complex_real_first(a, b, direction)
633 }
634 }
635}
636
637fn compare_complex_real_first(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
638 let real_cmp = match direction {
639 SortDirection::Ascend => a.0.partial_cmp(&b.0),
640 SortDirection::Descend => b.0.partial_cmp(&a.0),
641 }
642 .unwrap_or(Ordering::Equal);
643 if real_cmp != Ordering::Equal {
644 return real_cmp;
645 }
646 match direction {
647 SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
648 SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
649 }
650}
651
/// True when either component of the complex value is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
655
/// Magnitude of a complex value, computed with `hypot` to avoid intermediate overflow.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    re.hypot(im)
}
659
660fn permutation_indices(order: &[usize]) -> Result<Tensor, String> {
661 let rows = order.len();
662 let mut data = Vec::with_capacity(rows);
663 for &idx in order {
664 data.push((idx + 1) as f64);
665 }
666 Tensor::new(data, vec![rows, 1]).map_err(|e| format!("sortrows: {e}"))
667}
668
669fn identity_indices(rows: usize) -> Result<Tensor, String> {
670 let mut data = Vec::with_capacity(rows);
671 for i in 0..rows {
672 data.push((i + 1) as f64);
673 }
674 Tensor::new(data, vec![rows, 1]).map_err(|e| format!("sortrows: {e}"))
675}
676
677fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
678 if tensor.data.len() == 1 {
679 Value::Complex(tensor.data[0].0, tensor.data[0].1)
680 } else {
681 Value::ComplexTensor(tensor)
682 }
683}
684
/// Per-column sort direction: `Ascend` puts smaller keys first, `Descend` reverses the
/// comparison for that column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SortDirection {
    Ascend,
    Descend,
}
690
691impl SortDirection {
692 fn from_str(value: &str) -> Option<Self> {
693 match value.trim().to_ascii_lowercase().as_str() {
694 "ascend" | "ascending" => Some(SortDirection::Ascend),
695 "descend" | "descending" => Some(SortDirection::Descend),
696 _ => None,
697 }
698 }
699}
700
/// Value of the `'ComparisonMethod'` option.
///
/// `Auto` compares real data by value and complex data by magnitude first. `Real` compares
/// complex data real-part first. `Abs` compares magnitudes first for both real and complex
/// data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ComparisonMethod {
    Auto,
    Real,
    Abs,
}
707
/// Value of the `'MissingPlacement'` option as written by the caller.
///
/// `Auto` is resolved per column by `MissingPlacement::resolve`: NaNs go last for ascending
/// columns and first for descending columns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MissingPlacement {
    Auto,
    First,
    Last,
}
714
/// Concrete NaN placement for one column once `MissingPlacement::Auto` has been resolved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MissingPlacementResolved {
    First,
    Last,
}
720
721impl MissingPlacement {
722 fn resolve(self, direction: SortDirection) -> MissingPlacementResolved {
723 match self {
724 MissingPlacement::First => MissingPlacementResolved::First,
725 MissingPlacement::Last => MissingPlacementResolved::Last,
726 MissingPlacement::Auto => match direction {
727 SortDirection::Ascend => MissingPlacementResolved::Last,
728 SortDirection::Descend => MissingPlacementResolved::First,
729 },
730 }
731 }
732
733 fn is_auto(self) -> bool {
734 matches!(self, MissingPlacement::Auto)
735 }
736}
737
/// One sort key: a zero-based column `index` and the direction used for that column.
#[derive(Debug, Clone)]
struct ColumnSpec {
    index: usize,
    direction: SortDirection,
}
743
/// Parsed trailing arguments of `sortrows`.
#[derive(Debug, Clone)]
struct SortRowsArgs {
    // Sort keys in priority order; the first column that differs decides a row comparison.
    columns: Vec<ColumnSpec>,
    // `'ComparisonMethod'` setting, shared by all columns.
    comparison: ComparisonMethod,
    // `'MissingPlacement'` setting as given; resolved per column direction when comparing.
    missing: MissingPlacement,
}
750
751impl SortRowsArgs {
752 fn parse(rest: &[Value], num_cols: usize) -> Result<Self, String> {
753 let mut columns: Option<Vec<ColumnSpec>> = None;
754 let mut override_direction: Option<SortDirection> = None;
755 let mut comparison = ComparisonMethod::Auto;
756 let mut missing = MissingPlacement::Auto;
757 let mut i = 0usize;
758
759 while i < rest.len() {
760 if columns.is_none() {
761 if let Some(parsed) = parse_column_vector(&rest[i], num_cols)? {
762 columns = Some(parsed);
763 i += 1;
764 continue;
765 }
766 }
767 if let Some(direction) = parse_direction(&rest[i]) {
768 override_direction = Some(direction);
769 i += 1;
770 continue;
771 }
772 let Some(keyword) = tensor::value_to_string(&rest[i]) else {
773 return Err(format!("sortrows: invalid argument {:?}", rest[i]));
774 };
775 let lowered = keyword.trim().to_ascii_lowercase();
776 match lowered.as_str() {
777 "comparisonmethod" => {
778 i += 1;
779 if i >= rest.len() {
780 return Err("sortrows: expected a value for 'ComparisonMethod'".to_string());
781 }
782 let Some(value_str) = tensor::value_to_string(&rest[i]) else {
783 return Err(
784 "sortrows: 'ComparisonMethod' expects a string value".to_string()
785 );
786 };
787 comparison = match value_str.trim().to_ascii_lowercase().as_str() {
788 "auto" => ComparisonMethod::Auto,
789 "real" => ComparisonMethod::Real,
790 "abs" | "magnitude" => ComparisonMethod::Abs,
791 other => {
792 return Err(format!("sortrows: unsupported ComparisonMethod '{other}'"))
793 }
794 };
795 i += 1;
796 }
797 "missingplacement" => {
798 i += 1;
799 if i >= rest.len() {
800 return Err("sortrows: expected a value for 'MissingPlacement'".to_string());
801 }
802 let Some(value_str) = tensor::value_to_string(&rest[i]) else {
803 return Err(
804 "sortrows: 'MissingPlacement' expects a string value".to_string()
805 );
806 };
807 missing = match value_str.trim().to_ascii_lowercase().as_str() {
808 "auto" => MissingPlacement::Auto,
809 "first" => MissingPlacement::First,
810 "last" => MissingPlacement::Last,
811 other => {
812 return Err(format!("sortrows: unsupported MissingPlacement '{other}'"))
813 }
814 };
815 i += 1;
816 }
817 other => {
818 return Err(format!("sortrows: unexpected argument '{other}'"));
819 }
820 }
821 }
822
823 let mut columns = columns.unwrap_or_else(|| default_columns(num_cols));
824 if let Some(dir) = override_direction {
825 for spec in &mut columns {
826 spec.direction = dir;
827 }
828 }
829 validate_columns(&columns, num_cols)?;
830
831 Ok(SortRowsArgs {
832 columns,
833 comparison,
834 missing,
835 })
836 }
837
838 fn to_provider_columns(&self) -> Vec<ProviderSortRowsColumnSpec> {
839 self.columns
840 .iter()
841 .map(|spec| ProviderSortRowsColumnSpec {
842 index: spec.index,
843 order: match spec.direction {
844 SortDirection::Ascend => ProviderSortOrder::Ascend,
845 SortDirection::Descend => ProviderSortOrder::Descend,
846 },
847 })
848 .collect()
849 }
850
851 fn provider_comparison(&self) -> ProviderSortComparison {
852 match self.comparison {
853 ComparisonMethod::Auto => ProviderSortComparison::Auto,
854 ComparisonMethod::Real => ProviderSortComparison::Real,
855 ComparisonMethod::Abs => ProviderSortComparison::Abs,
856 }
857 }
858
859 fn missing_for_direction(&self, direction: SortDirection) -> MissingPlacementResolved {
860 self.missing.resolve(direction)
861 }
862
863 fn missing_is_auto(&self) -> bool {
864 self.missing.is_auto()
865 }
866}
867
868fn parse_column_vector(value: &Value, num_cols: usize) -> Result<Option<Vec<ColumnSpec>>, String> {
869 match value {
870 Value::Int(i) => parse_single_column(i.to_i64(), num_cols).map(Some),
871 Value::Num(n) => {
872 if !n.is_finite() {
873 return Err("sortrows: column indices must be finite".to_string());
874 }
875 let rounded = n.round();
876 if (rounded - n).abs() > f64::EPSILON {
877 return Err("sortrows: column indices must be integers".to_string());
878 }
879 parse_single_column(rounded as i64, num_cols).map(Some)
880 }
881 Value::Tensor(tensor) => {
882 if !is_vector(&tensor.shape) {
883 return Err("sortrows: column specification must be a vector".to_string());
884 }
885 let mut specs = Vec::with_capacity(tensor.data.len());
886 for &entry in &tensor.data {
887 if !entry.is_finite() {
888 return Err("sortrows: column indices must be finite".to_string());
889 }
890 let rounded = entry.round();
891 if (rounded - entry).abs() > f64::EPSILON {
892 return Err("sortrows: column indices must be integers".to_string());
893 }
894 let column = parse_single_column_i64(rounded as i64, num_cols)?;
895 specs.push(column);
896 }
897 Ok(Some(specs))
898 }
899 _ => Ok(None),
900 }
901}
902
903fn parse_single_column(value: i64, num_cols: usize) -> Result<Vec<ColumnSpec>, String> {
904 parse_single_column_i64(value, num_cols).map(|spec| vec![spec])
905}
906
907fn parse_single_column_i64(value: i64, num_cols: usize) -> Result<ColumnSpec, String> {
908 if value == 0 {
909 return Err("sortrows: column indices must be non-zero".to_string());
910 }
911 let abs = value.unsigned_abs() as usize;
912 if abs == 0 {
913 return Err("sortrows: column indices must be >= 1".to_string());
914 }
915 if num_cols == 0 {
916 return Err("sortrows: column index exceeds matrix with 0 columns".to_string());
917 }
918 if abs > num_cols {
919 return Err(format!(
920 "sortrows: column index {} exceeds matrix with {} columns",
921 abs, num_cols
922 ));
923 }
924 let direction = if value > 0 {
925 SortDirection::Ascend
926 } else {
927 SortDirection::Descend
928 };
929 Ok(ColumnSpec {
930 index: abs - 1,
931 direction,
932 })
933}
934
935fn parse_direction(value: &Value) -> Option<SortDirection> {
936 tensor::value_to_string(value).and_then(|s| SortDirection::from_str(&s))
937}
938
939fn default_columns(num_cols: usize) -> Vec<ColumnSpec> {
940 let mut columns = Vec::with_capacity(num_cols);
941 for col in 0..num_cols {
942 columns.push(ColumnSpec {
943 index: col,
944 direction: SortDirection::Ascend,
945 });
946 }
947 columns
948}
949
950fn validate_columns(columns: &[ColumnSpec], num_cols: usize) -> Result<(), String> {
951 if num_cols == 0 && columns.iter().any(|spec| spec.index > 0) {
952 return Err("sortrows: column index exceeds matrix with 0 columns".to_string());
953 }
954 for spec in columns {
955 if num_cols > 0 && spec.index >= num_cols {
956 return Err(format!(
957 "sortrows: column index {} exceeds matrix with {} columns",
958 spec.index + 1,
959 num_cols
960 ));
961 }
962 }
963 Ok(())
964}
965
/// True for shapes that describe a vector: rank 0 or 1, or a 2-D shape with a singleton
/// dimension.
fn is_vector(shape: &[usize]) -> bool {
    match *shape {
        [] | [_] => true,
        [height, width] => height == 1 || width == 1,
        _ => false,
    }
}
974
/// Outcome of `evaluate`: the row-sorted array and its permutation indices.
#[derive(Debug)]
pub struct SortRowsEvaluation {
    // Sorted rows, in the same value family as the host input (real, complex or char).
    sorted: Value,
    // 1-based source row for each output row. Host paths build a `rows x 1` tensor; the
    // provider path keeps whatever shape the provider returned.
    indices: Tensor,
}
980
981impl SortRowsEvaluation {
982 pub fn into_sorted_value(self) -> Value {
983 self.sorted
984 }
985
986 pub fn into_values(self) -> (Value, Value) {
987 let indices = tensor::tensor_into_value(self.indices);
988 (self.sorted, indices)
989 }
990
991 pub fn indices_value(&self) -> Value {
992 tensor::tensor_into_value(self.indices.clone())
993 }
994}
995
// Unit tests for `sortrows`. Numeric tensors are column-major: `vec![3, 2]` is a matrix with
// 3 rows and 2 columns whose data lists column 1 first, then column 2.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{IntValue, Value};

    #[test]
    fn sortrows_default_matrix() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    #[test]
    fn sortrows_with_column_vector() {
        let tensor = Tensor::new(
            vec![1.0, 3.0, 3.0, 4.0, 2.0, 2.0, 2.0, 5.0, 1.0],
            vec![3, 3],
        )
        .unwrap();
        let cols = Tensor::new(vec![2.0, 3.0, 1.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![3.0, 3.0, 1.0, 2.0, 2.0, 4.0, 1.0, 5.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn sortrows_direction_descend() {
        let tensor = Tensor::new(vec![1.0, 2.0, 4.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 3.0, 4.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn sortrows_mixed_directions() {
        let tensor = Tensor::new(vec![1.0, 1.0, 1.0, 1.0, 7.0, 2.0], vec![3, 2]).unwrap();
        let cols = Tensor::new(vec![1.0, -2.0], vec![2, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 1.0, 1.0, 7.0, 2.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn sortrows_returns_indices() {
        let tensor = Tensor::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (_, indices) = eval.into_values();
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    #[test]
    fn sortrows_char_array() {
        // A 3x4 char array needs exactly 12 characters laid out row by row, so the shorter
        // name is space-padded to the full width of 4.
        let chars = CharArray::new(
            "bob "
                .chars()
                .chain("al  ".chars())
                .chain("ally".chars())
                .collect(),
            3,
            4,
        )
        .unwrap();
        let eval = evaluate(Value::CharArray(chars), &[]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 3);
                assert_eq!(ca.cols, 4);
                let strings: Vec<String> = (0..ca.rows)
                    .map(|r| {
                        ca.data[r * ca.cols..(r + 1) * ca.cols]
                            .iter()
                            .collect::<String>()
                    })
                    .collect();
                // ' ' sorts before 'l', so the padded "al  " precedes "ally".
                assert_eq!(
                    strings,
                    vec!["al  ".to_string(), "ally".to_string(), "bob ".to_string()]
                );
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[test]
    fn sortrows_complex_abs() {
        // Magnitudes differ (3 vs sqrt(2)), so the expected order is fixed by `abs` alone and
        // is the opposite of ordering by real part (which would put -3 first).
        let tensor = ComplexTensor::new(vec![(-3.0, 0.0), (1.0, 1.0)], vec![2, 1]).unwrap();
        let eval = evaluate(
            Value::ComplexTensor(tensor),
            &[Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.data, vec![(1.0, 1.0), (-3.0, 0.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[test]
    fn sortrows_invalid_column_index_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(3))]).unwrap_err();
        assert!(
            err.contains("column index"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn sortrows_missingplacement_first_moves_nan_first() {
        let tensor = Tensor::new(vec![1.0, f64::NAN, 2.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("first")],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert!(t.data[0].is_nan());
                assert_eq!(t.data[1], 1.0);
                assert_eq!(t.data[2], 3.0);
                assert_eq!(t.data[3], 2.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    #[test]
    fn sortrows_missingplacement_last_descend_moves_nan_last() {
        let tensor = Tensor::new(vec![f64::NAN, 5.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::from("descend"),
                Value::from("MissingPlacement"),
                Value::from("last"),
            ],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data[0], 5.0);
                assert!(t.data[1].is_nan());
                assert_eq!(t.data[2], 2.0);
                assert_eq!(t.data[3], 1.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    #[test]
    fn sortrows_missingplacement_invalid_value_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("middle")],
        )
        .unwrap_err();
        assert!(
            err.contains("MissingPlacement"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn sortrows_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
            let (sorted, indices) = eval.into_values();
            match sorted {
                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
            match indices {
                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
                other => panic!("unexpected indices {other:?}"),
            }
        });
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn sortrows_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0, 2.0, 5.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let (cpu_sorted_val, cpu_indices_val) = cpu_eval.into_values();
        let cpu_sorted = match cpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let cpu_indices = match cpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("provider");
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("gpu evaluate");
        let (gpu_sorted_val, gpu_indices_val) = gpu_eval.into_values();
        let gpu_sorted = match gpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let gpu_indices = match gpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        assert_eq!(gpu_sorted.shape, cpu_sorted.shape);
        assert_eq!(gpu_sorted.data, cpu_sorted.data);
        assert_eq!(gpu_indices.shape, cpu_indices.shape);
        assert_eq!(gpu_indices.data, cpu_indices.data);

        let _ = provider.free(&handle);
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}