1use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6 GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7 SortResult as ProviderSortResult, SortRowsColumnSpec as ProviderSortRowsColumnSpec,
8};
9use runmat_builtins::{CharArray, ComplexTensor, Tensor, Value};
10use runmat_macros::runtime_builtin;
11
12use super::type_resolvers::tensor_output_type;
13use crate::build_runtime_error;
14use crate::builtins::common::gpu_helpers;
15use crate::builtins::common::spec::{
16 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::tensor;
20
/// GPU capability descriptor for `sortrows`.
///
/// Advertises a custom `"sortrows"` provider hook for `f32`/`f64` data.
/// `sortrows_gpu` only calls the provider when `MissingPlacement` is left at `auto`;
/// explicit placements are sorted on the host after gathering.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::sortrows")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "sortrows",
    // Row sorting is not an elementwise/reduction op; it is exposed as a custom provider op.
    op_kind: GpuOpKind::Custom("sortrows"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("sortrows")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // NOTE(review): presumably "bring results back to host"; the provider path in this file
    // does convert the provider result into a host `Tensor`.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes:
        "Providers may implement a row-sort kernel; explicit MissingPlacement overrides fall back to host memory until native support exists.",
};
37
/// Fusion descriptor for `sortrows`.
///
/// Declares no elementwise or reduction template; per `notes`, `sortrows` ends a fusion
/// chain and produces host-resident results.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::sortrows"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "sortrows",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    // NaNs in the input are carried through to the sorted output.
    emits_nan: true,
    notes: "`sortrows` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
};
50
51fn sortrows_error(message: impl Into<String>) -> crate::RuntimeError {
52 build_runtime_error(message)
53 .with_builtin("sortrows")
54 .build()
55}
56
57#[runtime_builtin(
58 name = "sortrows",
59 category = "array/sorting_sets",
60 summary = "Sort matrix rows lexicographically with optional column and direction control.",
61 keywords = "sortrows,row sort,lexicographic,gpu",
62 accel = "sink",
63 sink = true,
64 type_resolver(tensor_output_type),
65 builtin_path = "crate::builtins::array::sorting_sets::sortrows"
66)]
67async fn sortrows_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
68 let eval = evaluate(value, &rest).await?;
69 if let Some(out_count) = crate::output_count::current_output_count() {
70 if out_count == 0 {
71 return Ok(Value::OutputList(Vec::new()));
72 }
73 let (sorted, indices) = eval.into_values();
74 let mut outputs = vec![sorted];
75 if out_count >= 2 {
76 outputs.push(indices);
77 }
78 return Ok(crate::output_count::output_list_with_padding(
79 out_count, outputs,
80 ));
81 }
82 Ok(eval.into_sorted_value())
83}
84
85pub async fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
87 match value {
88 Value::GpuTensor(handle) => sortrows_gpu(handle, rest).await,
89 other => sortrows_host(other, rest),
90 }
91}
92
93async fn sortrows_gpu(
94 handle: GpuTensorHandle,
95 rest: &[Value],
96) -> crate::BuiltinResult<SortRowsEvaluation> {
97 ensure_matrix_shape(&handle.shape)?;
98 let (_, cols) = rows_cols_from_shape(&handle.shape);
99 let args = SortRowsArgs::parse(rest, cols)?;
100
101 if args.missing_is_auto() {
102 if let Some(provider) = runmat_accelerate_api::provider() {
103 let provider_columns = args.to_provider_columns();
104 let provider_comparison = args.provider_comparison();
105 match provider
106 .sort_rows(&handle, &provider_columns, provider_comparison)
107 .await
108 {
109 Ok(result) => return sortrows_from_provider_result(result),
110 Err(_err) => {
111 }
113 }
114 }
115 }
116
117 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
118 sortrows_real_tensor_with_args(tensor, &args)
119}
120
121fn sortrows_from_provider_result(
122 result: ProviderSortResult,
123) -> crate::BuiltinResult<SortRowsEvaluation> {
124 let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
125 .map_err(|e| sortrows_error(format!("sortrows: {e}")))?;
126 let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
127 .map_err(|e| sortrows_error(format!("sortrows: {e}")))?;
128 Ok(SortRowsEvaluation {
129 sorted: tensor::tensor_into_value(sorted_tensor),
130 indices: indices_tensor,
131 })
132}
133
134fn sortrows_host(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
135 match value {
136 Value::Tensor(tensor) => sortrows_real_tensor(tensor, rest),
137 Value::LogicalArray(logical) => {
138 let tensor = tensor::logical_to_tensor(&logical)
139 .map_err(|e| sortrows_error(e))?;
140 sortrows_real_tensor(tensor, rest)
141 }
142 Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
143 let tensor = tensor::value_into_tensor_for("sortrows", value)
144 .map_err(|e| sortrows_error(e))?;
145 sortrows_real_tensor(tensor, rest)
146 }
147 Value::ComplexTensor(ct) => sortrows_complex_tensor(ct, rest),
148 Value::Complex(re, im) => {
149 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
150 .map_err(|e| sortrows_error(format!("sortrows: {e}")))?;
151 sortrows_complex_tensor(tensor, rest)
152 }
153 Value::CharArray(ca) => sortrows_char_array(ca, rest),
154 other => Err(sortrows_error(format!(
155 "sortrows: unsupported input type {:?}; expected numeric, logical, complex, or char arrays",
156 other
157 ))
158 .into()),
159 }
160}
161
162fn sortrows_real_tensor(
163 tensor: Tensor,
164 rest: &[Value],
165) -> crate::BuiltinResult<SortRowsEvaluation> {
166 ensure_matrix_shape(&tensor.shape)?;
167 let cols = tensor.cols();
168 let args = SortRowsArgs::parse(rest, cols)?;
169 sortrows_real_tensor_with_args(tensor, &args)
170}
171
172fn sortrows_real_tensor_with_args(
173 tensor: Tensor,
174 args: &SortRowsArgs,
175) -> crate::BuiltinResult<SortRowsEvaluation> {
176 let rows = tensor.rows();
177 let cols = tensor.cols();
178
179 if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
180 let indices = identity_indices(rows)?;
181 return Ok(SortRowsEvaluation {
182 sorted: tensor::tensor_into_value(tensor),
183 indices,
184 });
185 }
186
187 let mut order: Vec<usize> = (0..rows).collect();
188 order.sort_by(|&a, &b| compare_real_rows(&tensor, rows, args, a, b));
189
190 let sorted_tensor = reorder_real_rows(&tensor, rows, cols, &order)?;
191 let indices = permutation_indices(&order)?;
192 Ok(SortRowsEvaluation {
193 sorted: tensor::tensor_into_value(sorted_tensor),
194 indices,
195 })
196}
197
198fn sortrows_complex_tensor(
199 tensor: ComplexTensor,
200 rest: &[Value],
201) -> crate::BuiltinResult<SortRowsEvaluation> {
202 ensure_matrix_shape(&tensor.shape)?;
203 let cols = tensor.cols;
204 let args = SortRowsArgs::parse(rest, cols)?;
205 sortrows_complex_tensor_with_args(tensor, &args)
206}
207
208fn sortrows_complex_tensor_with_args(
209 tensor: ComplexTensor,
210 args: &SortRowsArgs,
211) -> crate::BuiltinResult<SortRowsEvaluation> {
212 let rows = tensor.rows;
213 let cols = tensor.cols;
214
215 if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
216 let indices = identity_indices(rows)?;
217 return Ok(SortRowsEvaluation {
218 sorted: complex_tensor_into_value(tensor),
219 indices,
220 });
221 }
222
223 let mut order: Vec<usize> = (0..rows).collect();
224 order.sort_by(|&a, &b| compare_complex_rows(&tensor, rows, args, a, b));
225
226 let sorted_tensor = reorder_complex_rows(&tensor, rows, cols, &order)?;
227 let indices = permutation_indices(&order)?;
228 Ok(SortRowsEvaluation {
229 sorted: complex_tensor_into_value(sorted_tensor),
230 indices,
231 })
232}
233
234fn sortrows_char_array(ca: CharArray, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
235 let cols = ca.cols;
236 let args = SortRowsArgs::parse(rest, cols)?;
237 sortrows_char_array_with_args(ca, &args)
238}
239
240fn sortrows_char_array_with_args(
241 ca: CharArray,
242 args: &SortRowsArgs,
243) -> crate::BuiltinResult<SortRowsEvaluation> {
244 let rows = ca.rows;
245 let cols = ca.cols;
246
247 if rows <= 1 || cols == 0 || ca.data.is_empty() || args.columns.is_empty() {
248 let indices = identity_indices(rows)?;
249 return Ok(SortRowsEvaluation {
250 sorted: Value::CharArray(ca),
251 indices,
252 });
253 }
254
255 let mut order: Vec<usize> = (0..rows).collect();
256 order.sort_by(|&a, &b| compare_char_rows(&ca, args, a, b));
257
258 let sorted = reorder_char_rows(&ca, rows, cols, &order)?;
259 let indices = permutation_indices(&order)?;
260 Ok(SortRowsEvaluation {
261 sorted: Value::CharArray(sorted),
262 indices,
263 })
264}
265
266fn ensure_matrix_shape(shape: &[usize]) -> crate::BuiltinResult<()> {
267 if shape.len() <= 2 {
268 Ok(())
269 } else {
270 Err(sortrows_error("sortrows: input must be a 2-D matrix"))
271 }
272}
273
/// Interprets a shape as `(rows, cols)`.
///
/// A scalar (empty shape) is `1x1`, a 1-D shape `[n]` is a `1xn` row, and otherwise the
/// first two extents are used.
fn rows_cols_from_shape(shape: &[usize]) -> (usize, usize) {
    match *shape {
        [] => (1, 1),
        [len] => (1, len),
        [rows, cols, ..] => (rows, cols),
    }
}
281
282fn compare_real_rows(
283 tensor: &Tensor,
284 rows: usize,
285 args: &SortRowsArgs,
286 a: usize,
287 b: usize,
288) -> Ordering {
289 for spec in &args.columns {
290 if spec.index >= tensor.cols() {
291 continue;
292 }
293 let idx_a = a + spec.index * rows;
294 let idx_b = b + spec.index * rows;
295 let va = tensor.data[idx_a];
296 let vb = tensor.data[idx_b];
297 let missing = args.missing_for_direction(spec.direction);
298 let ord = compare_real_scalars(va, vb, spec.direction, args.comparison, missing);
299 if ord != Ordering::Equal {
300 return ord;
301 }
302 }
303 Ordering::Equal
304}
305
306fn compare_complex_rows(
307 tensor: &ComplexTensor,
308 rows: usize,
309 args: &SortRowsArgs,
310 a: usize,
311 b: usize,
312) -> Ordering {
313 for spec in &args.columns {
314 if spec.index >= tensor.cols {
315 continue;
316 }
317 let idx_a = a + spec.index * rows;
318 let idx_b = b + spec.index * rows;
319 let va = tensor.data[idx_a];
320 let vb = tensor.data[idx_b];
321 let missing = args.missing_for_direction(spec.direction);
322 let ord = compare_complex_scalars(va, vb, spec.direction, args.comparison, missing);
323 if ord != Ordering::Equal {
324 return ord;
325 }
326 }
327 Ordering::Equal
328}
329
330fn compare_char_rows(ca: &CharArray, args: &SortRowsArgs, a: usize, b: usize) -> Ordering {
331 for spec in &args.columns {
332 if spec.index >= ca.cols {
333 continue;
334 }
335 let idx_a = a * ca.cols + spec.index;
336 let idx_b = b * ca.cols + spec.index;
337 let va = ca.data[idx_a];
338 let vb = ca.data[idx_b];
339 let ord = match spec.direction {
340 SortDirection::Ascend => va.cmp(&vb),
341 SortDirection::Descend => vb.cmp(&va),
342 };
343 if ord != Ordering::Equal {
344 return ord;
345 }
346 }
347 Ordering::Equal
348}
349
350fn reorder_real_rows(
351 tensor: &Tensor,
352 rows: usize,
353 cols: usize,
354 order: &[usize],
355) -> crate::BuiltinResult<Tensor> {
356 let mut data = vec![0.0; tensor.data.len()];
357 for col in 0..cols {
358 for (dest_row, &src_row) in order.iter().enumerate() {
359 let src_idx = src_row + col * rows;
360 let dst_idx = dest_row + col * rows;
361 data[dst_idx] = tensor.data[src_idx];
362 }
363 }
364 Tensor::new(data, tensor.shape.clone()).map_err(|e| sortrows_error(format!("sortrows: {e}")))
365}
366
367fn reorder_complex_rows(
368 tensor: &ComplexTensor,
369 rows: usize,
370 cols: usize,
371 order: &[usize],
372) -> crate::BuiltinResult<ComplexTensor> {
373 let mut data = vec![(0.0, 0.0); tensor.data.len()];
374 for col in 0..cols {
375 for (dest_row, &src_row) in order.iter().enumerate() {
376 let src_idx = src_row + col * rows;
377 let dst_idx = dest_row + col * rows;
378 data[dst_idx] = tensor.data[src_idx];
379 }
380 }
381 ComplexTensor::new(data, tensor.shape.clone())
382 .map_err(|e| sortrows_error(format!("sortrows: {e}")))
383}
384
385fn reorder_char_rows(
386 ca: &CharArray,
387 rows: usize,
388 cols: usize,
389 order: &[usize],
390) -> crate::BuiltinResult<CharArray> {
391 let mut data = vec!['\0'; ca.data.len()];
392 for (dest_row, &src_row) in order.iter().enumerate() {
393 for col in 0..cols {
394 let src_idx = src_row * cols + col;
395 let dst_idx = dest_row * cols + col;
396 data[dst_idx] = ca.data[src_idx];
397 }
398 }
399 CharArray::new(data, rows, cols).map_err(|e| sortrows_error(format!("sortrows: {e}")))
400}
401
402fn compare_real_scalars(
403 a: f64,
404 b: f64,
405 direction: SortDirection,
406 comparison: ComparisonMethod,
407 missing: MissingPlacementResolved,
408) -> Ordering {
409 match (a.is_nan(), b.is_nan()) {
410 (true, true) => Ordering::Equal,
411 (true, false) => match missing {
412 MissingPlacementResolved::First => Ordering::Less,
413 MissingPlacementResolved::Last => Ordering::Greater,
414 },
415 (false, true) => match missing {
416 MissingPlacementResolved::First => Ordering::Greater,
417 MissingPlacementResolved::Last => Ordering::Less,
418 },
419 (false, false) => compare_real_finite_scalars(a, b, direction, comparison),
420 }
421}
422
423fn compare_real_finite_scalars(
424 a: f64,
425 b: f64,
426 direction: SortDirection,
427 comparison: ComparisonMethod,
428) -> Ordering {
429 if matches!(comparison, ComparisonMethod::Abs) {
430 let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
431 if abs_cmp != Ordering::Equal {
432 return match direction {
433 SortDirection::Ascend => abs_cmp,
434 SortDirection::Descend => abs_cmp.reverse(),
435 };
436 }
437 }
438 match direction {
439 SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
440 SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
441 }
442}
443
444fn compare_complex_scalars(
445 a: (f64, f64),
446 b: (f64, f64),
447 direction: SortDirection,
448 comparison: ComparisonMethod,
449 missing: MissingPlacementResolved,
450) -> Ordering {
451 match (complex_is_nan(a), complex_is_nan(b)) {
452 (true, true) => Ordering::Equal,
453 (true, false) => match missing {
454 MissingPlacementResolved::First => Ordering::Less,
455 MissingPlacementResolved::Last => Ordering::Greater,
456 },
457 (false, true) => match missing {
458 MissingPlacementResolved::First => Ordering::Greater,
459 MissingPlacementResolved::Last => Ordering::Less,
460 },
461 (false, false) => compare_complex_finite_scalars(a, b, direction, comparison),
462 }
463}
464
465fn compare_complex_finite_scalars(
466 a: (f64, f64),
467 b: (f64, f64),
468 direction: SortDirection,
469 comparison: ComparisonMethod,
470) -> Ordering {
471 match comparison {
472 ComparisonMethod::Real => compare_complex_real_first(a, b, direction),
473 ComparisonMethod::Auto | ComparisonMethod::Abs => {
474 let abs_cmp = complex_abs(a)
475 .partial_cmp(&complex_abs(b))
476 .unwrap_or(Ordering::Equal);
477 if abs_cmp != Ordering::Equal {
478 return match direction {
479 SortDirection::Ascend => abs_cmp,
480 SortDirection::Descend => abs_cmp.reverse(),
481 };
482 }
483 compare_complex_real_first(a, b, direction)
484 }
485 }
486}
487
488fn compare_complex_real_first(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
489 let real_cmp = match direction {
490 SortDirection::Ascend => a.0.partial_cmp(&b.0),
491 SortDirection::Descend => b.0.partial_cmp(&a.0),
492 }
493 .unwrap_or(Ordering::Equal);
494 if real_cmp != Ordering::Equal {
495 return real_cmp;
496 }
497 match direction {
498 SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
499 SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
500 }
501}
502
/// Returns `true` when either the real or the imaginary part is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
506
/// Magnitude `sqrt(re^2 + im^2)`, computed with `hypot` to avoid intermediate overflow.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    re.hypot(im)
}
510
511fn permutation_indices(order: &[usize]) -> crate::BuiltinResult<Tensor> {
512 let rows = order.len();
513 let mut data = Vec::with_capacity(rows);
514 for &idx in order {
515 data.push((idx + 1) as f64);
516 }
517 Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_error(format!("sortrows: {e}")))
518}
519
520fn identity_indices(rows: usize) -> crate::BuiltinResult<Tensor> {
521 let mut data = Vec::with_capacity(rows);
522 for i in 0..rows {
523 data.push((i + 1) as f64);
524 }
525 Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_error(format!("sortrows: {e}")))
526}
527
528fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
529 if tensor.data.len() == 1 {
530 Value::Complex(tensor.data[0].0, tensor.data[0].1)
531 } else {
532 Value::ComplexTensor(tensor)
533 }
534}
535
/// Ordering requested for a single sort key.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SortDirection {
    /// Smaller values first.
    Ascend,
    /// Larger values first.
    Descend,
}

impl SortDirection {
    /// Parses a direction keyword.
    ///
    /// Accepts `"ascend"`/`"ascending"` and `"descend"`/`"descending"`, ignoring ASCII case
    /// and surrounding whitespace. Anything else yields `None`.
    fn from_str(value: &str) -> Option<Self> {
        let keyword = value.trim().to_ascii_lowercase();
        let keyword = keyword.as_str();
        if keyword == "ascend" || keyword == "ascending" {
            Some(Self::Ascend)
        } else if keyword == "descend" || keyword == "descending" {
            Some(Self::Descend)
        } else {
            None
        }
    }
}

/// How two values are compared before direction is applied (`'ComparisonMethod'`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ComparisonMethod {
    /// Real values by signed value; complex values by magnitude, then real/imag parts.
    Auto,
    /// Complex values by real part, then imaginary part; real values by signed value.
    Real,
    /// By magnitude first; ties fall back to value (real) or real/imag parts (complex).
    Abs,
}

/// User-requested position of NaN ("missing") values (`'MissingPlacement'`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MissingPlacement {
    /// Last when ascending, first when descending.
    Auto,
    First,
    Last,
}

/// `MissingPlacement` after `Auto` has been resolved against a sort direction.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MissingPlacementResolved {
    First,
    Last,
}

impl MissingPlacement {
    /// Resolves the placement for a key sorted in `direction`.
    fn resolve(self, direction: SortDirection) -> MissingPlacementResolved {
        match (self, direction) {
            (MissingPlacement::First, _) | (MissingPlacement::Auto, SortDirection::Descend) => {
                MissingPlacementResolved::First
            }
            (MissingPlacement::Last, _) | (MissingPlacement::Auto, SortDirection::Ascend) => {
                MissingPlacementResolved::Last
            }
        }
    }

    /// `true` when the caller did not request an explicit placement.
    fn is_auto(self) -> bool {
        self == MissingPlacement::Auto
    }
}
588
589#[derive(Debug, Clone)]
590struct ColumnSpec {
591 index: usize,
592 direction: SortDirection,
593}
594
595#[derive(Debug, Clone)]
596struct SortRowsArgs {
597 columns: Vec<ColumnSpec>,
598 comparison: ComparisonMethod,
599 missing: MissingPlacement,
600}
601
602impl SortRowsArgs {
603 fn parse(rest: &[Value], num_cols: usize) -> crate::BuiltinResult<Self> {
604 let mut columns: Option<Vec<ColumnSpec>> = None;
605 let mut override_direction: Option<SortDirection> = None;
606 let mut comparison = ComparisonMethod::Auto;
607 let mut missing = MissingPlacement::Auto;
608 let mut i = 0usize;
609
610 while i < rest.len() {
611 if columns.is_none() {
612 if let Some(parsed) = parse_column_vector(&rest[i], num_cols)? {
613 columns = Some(parsed);
614 i += 1;
615 continue;
616 }
617 }
618 if let Some(direction) = parse_direction(&rest[i]) {
619 override_direction = Some(direction);
620 i += 1;
621 continue;
622 }
623 let Some(keyword) = tensor::value_to_string(&rest[i]) else {
624 return Err(sortrows_error(format!(
625 "sortrows: invalid argument {:?}",
626 rest[i]
627 )));
628 };
629 let lowered = keyword.trim().to_ascii_lowercase();
630 match lowered.as_str() {
631 "comparisonmethod" => {
632 i += 1;
633 if i >= rest.len() {
634 return Err(sortrows_error(
635 "sortrows: expected a value for 'ComparisonMethod'",
636 ));
637 }
638 let Some(value_str) = tensor::value_to_string(&rest[i]) else {
639 return Err(sortrows_error(
640 "sortrows: 'ComparisonMethod' expects a string value",
641 )
642 .into());
643 };
644 comparison = match value_str.trim().to_ascii_lowercase().as_str() {
645 "auto" => ComparisonMethod::Auto,
646 "real" => ComparisonMethod::Real,
647 "abs" | "magnitude" => ComparisonMethod::Abs,
648 other => {
649 return Err(sortrows_error(format!(
650 "sortrows: unsupported ComparisonMethod '{other}'"
651 ))
652 .into())
653 }
654 };
655 i += 1;
656 }
657 "missingplacement" => {
658 i += 1;
659 if i >= rest.len() {
660 return Err(sortrows_error(
661 "sortrows: expected a value for 'MissingPlacement'",
662 )
663 .into());
664 }
665 let Some(value_str) = tensor::value_to_string(&rest[i]) else {
666 return Err(sortrows_error(
667 "sortrows: 'MissingPlacement' expects a string value",
668 )
669 .into());
670 };
671 missing = match value_str.trim().to_ascii_lowercase().as_str() {
672 "auto" => MissingPlacement::Auto,
673 "first" => MissingPlacement::First,
674 "last" => MissingPlacement::Last,
675 other => {
676 return Err(sortrows_error(format!(
677 "sortrows: unsupported MissingPlacement '{other}'"
678 ))
679 .into())
680 }
681 };
682 i += 1;
683 }
684 other => {
685 return Err(sortrows_error(format!(
686 "sortrows: unexpected argument '{other}'"
687 )));
688 }
689 }
690 }
691
692 let mut columns = columns.unwrap_or_else(|| default_columns(num_cols));
693 if let Some(dir) = override_direction {
694 for spec in &mut columns {
695 spec.direction = dir;
696 }
697 }
698 validate_columns(&columns, num_cols)?;
699
700 Ok(SortRowsArgs {
701 columns,
702 comparison,
703 missing,
704 })
705 }
706
707 fn to_provider_columns(&self) -> Vec<ProviderSortRowsColumnSpec> {
708 self.columns
709 .iter()
710 .map(|spec| ProviderSortRowsColumnSpec {
711 index: spec.index,
712 order: match spec.direction {
713 SortDirection::Ascend => ProviderSortOrder::Ascend,
714 SortDirection::Descend => ProviderSortOrder::Descend,
715 },
716 })
717 .collect()
718 }
719
720 fn provider_comparison(&self) -> ProviderSortComparison {
721 match self.comparison {
722 ComparisonMethod::Auto => ProviderSortComparison::Auto,
723 ComparisonMethod::Real => ProviderSortComparison::Real,
724 ComparisonMethod::Abs => ProviderSortComparison::Abs,
725 }
726 }
727
728 fn missing_for_direction(&self, direction: SortDirection) -> MissingPlacementResolved {
729 self.missing.resolve(direction)
730 }
731
732 fn missing_is_auto(&self) -> bool {
733 self.missing.is_auto()
734 }
735}
736
737fn parse_column_vector(
738 value: &Value,
739 num_cols: usize,
740) -> crate::BuiltinResult<Option<Vec<ColumnSpec>>> {
741 match value {
742 Value::Int(i) => parse_single_column(i.to_i64(), num_cols).map(Some),
743 Value::Num(n) => {
744 if !n.is_finite() {
745 return Err(sortrows_error("sortrows: column indices must be finite"));
746 }
747 let rounded = n.round();
748 if (rounded - n).abs() > f64::EPSILON {
749 return Err(sortrows_error("sortrows: column indices must be integers"));
750 }
751 parse_single_column(rounded as i64, num_cols).map(Some)
752 }
753 Value::Tensor(tensor) => {
754 if !is_vector(&tensor.shape) {
755 return Err(sortrows_error(
756 "sortrows: column specification must be a vector",
757 ));
758 }
759 let mut specs = Vec::with_capacity(tensor.data.len());
760 for &entry in &tensor.data {
761 if !entry.is_finite() {
762 return Err(sortrows_error("sortrows: column indices must be finite"));
763 }
764 let rounded = entry.round();
765 if (rounded - entry).abs() > f64::EPSILON {
766 return Err(sortrows_error("sortrows: column indices must be integers"));
767 }
768 let column = parse_single_column_i64(rounded as i64, num_cols)?;
769 specs.push(column);
770 }
771 Ok(Some(specs))
772 }
773 _ => Ok(None),
774 }
775}
776
777fn parse_single_column(value: i64, num_cols: usize) -> crate::BuiltinResult<Vec<ColumnSpec>> {
778 parse_single_column_i64(value, num_cols).map(|spec| vec![spec])
779}
780
781fn parse_single_column_i64(value: i64, num_cols: usize) -> crate::BuiltinResult<ColumnSpec> {
782 if value == 0 {
783 return Err(sortrows_error("sortrows: column indices must be non-zero"));
784 }
785 let abs = value.unsigned_abs() as usize;
786 if abs == 0 {
787 return Err(sortrows_error("sortrows: column indices must be >= 1"));
788 }
789 if num_cols == 0 {
790 return Err(sortrows_error(
791 "sortrows: column index exceeds matrix with 0 columns",
792 ));
793 }
794 if abs > num_cols {
795 return Err(sortrows_error(format!(
796 "sortrows: column index {} exceeds matrix with {} columns",
797 abs, num_cols
798 ))
799 .into());
800 }
801 let direction = if value > 0 {
802 SortDirection::Ascend
803 } else {
804 SortDirection::Descend
805 };
806 Ok(ColumnSpec {
807 index: abs - 1,
808 direction,
809 })
810}
811
812fn parse_direction(value: &Value) -> Option<SortDirection> {
813 tensor::value_to_string(value).and_then(|s| SortDirection::from_str(&s))
814}
815
816fn default_columns(num_cols: usize) -> Vec<ColumnSpec> {
817 let mut columns = Vec::with_capacity(num_cols);
818 for col in 0..num_cols {
819 columns.push(ColumnSpec {
820 index: col,
821 direction: SortDirection::Ascend,
822 });
823 }
824 columns
825}
826
827fn validate_columns(columns: &[ColumnSpec], num_cols: usize) -> crate::BuiltinResult<()> {
828 if num_cols == 0 && columns.iter().any(|spec| spec.index > 0) {
829 return Err(sortrows_error(
830 "sortrows: column index exceeds matrix with 0 columns",
831 ));
832 }
833 for spec in columns {
834 if num_cols > 0 && spec.index >= num_cols {
835 return Err(sortrows_error(format!(
836 "sortrows: column index {} exceeds matrix with {} columns",
837 spec.index + 1,
838 num_cols
839 ))
840 .into());
841 }
842 }
843 Ok(())
844}
845
/// `true` for scalar/1-D shapes and for 2-D shapes with a singleton dimension.
fn is_vector(shape: &[usize]) -> bool {
    match *shape {
        [] | [_] => true,
        [rows, cols] => rows == 1 || cols == 1,
        _ => false,
    }
}
854
855#[derive(Debug)]
856pub struct SortRowsEvaluation {
857 sorted: Value,
858 indices: Tensor,
859}
860
861impl SortRowsEvaluation {
862 pub fn into_sorted_value(self) -> Value {
863 self.sorted
864 }
865
866 pub fn into_values(self) -> (Value, Value) {
867 let indices = tensor::tensor_into_value(self.indices);
868 (self.sorted, indices)
869 }
870
871 pub fn indices_value(&self) -> Value {
872 tensor::tensor_into_value(self.indices.clone())
873 }
874}
875
876#[cfg(test)]
877pub(crate) mod tests {
878 use super::*;
879 use crate::builtins::common::test_support;
880 use runmat_builtins::{IntValue, ResolveContext, Type, Value};
881
882 fn error_message(err: crate::RuntimeError) -> String {
883 err.message().to_string()
884 }
885
886 fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
887 futures::executor::block_on(super::evaluate(value, rest))
888 }
889
890 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
891 #[test]
892 fn sortrows_default_matrix() {
893 let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
894 let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
895 let (sorted, indices) = eval.into_values();
896 match sorted {
897 Value::Tensor(t) => {
898 assert_eq!(t.shape, vec![3, 2]);
899 assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]);
900 }
901 other => panic!("expected tensor, got {other:?}"),
902 }
903 match indices {
904 Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
905 Value::Num(_) => panic!("expected tensor indices"),
906 other => panic!("unexpected indices {other:?}"),
907 }
908 }
909
910 #[test]
911 fn sortrows_type_resolver_tensor() {
912 assert_eq!(
913 tensor_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
914 Type::tensor()
915 );
916 }
917
918 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
919 #[test]
920 fn sortrows_with_column_vector() {
921 let tensor = Tensor::new(
922 vec![1.0, 3.0, 3.0, 4.0, 2.0, 2.0, 2.0, 5.0, 1.0],
923 vec![3, 3],
924 )
925 .unwrap();
926 let cols = Tensor::new(vec![2.0, 3.0, 1.0], vec![3, 1]).unwrap();
927 let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
928 let (sorted, _) = eval.into_values();
929 match sorted {
930 Value::Tensor(t) => {
931 assert_eq!(t.data, vec![3.0, 3.0, 1.0, 2.0, 2.0, 4.0, 1.0, 5.0, 2.0]);
932 }
933 other => panic!("expected tensor, got {other:?}"),
934 }
935 }
936
937 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
938 #[test]
939 fn sortrows_direction_descend() {
940 let tensor = Tensor::new(vec![1.0, 2.0, 4.0, 3.0], vec![2, 2]).unwrap();
941 let eval = evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
942 let (sorted, _) = eval.into_values();
943 match sorted {
944 Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 3.0, 4.0]),
945 other => panic!("expected tensor, got {other:?}"),
946 }
947 }
948
949 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
950 #[test]
951 fn sortrows_mixed_directions() {
952 let tensor = Tensor::new(vec![1.0, 1.0, 1.0, 1.0, 7.0, 2.0], vec![3, 2]).unwrap();
953 let cols = Tensor::new(vec![1.0, -2.0], vec![2, 1]).unwrap();
954 let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
955 let (sorted, _) = eval.into_values();
956 match sorted {
957 Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 1.0, 1.0, 7.0, 2.0, 1.0]),
958 other => panic!("expected tensor, got {other:?}"),
959 }
960 }
961
962 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
963 #[test]
964 fn sortrows_returns_indices() {
965 let tensor = Tensor::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
966 let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
967 let (_, indices) = eval.into_values();
968 match indices {
969 Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
970 Value::Num(_) => panic!("expected tensor indices"),
971 other => panic!("unexpected indices {other:?}"),
972 }
973 }
974
975 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
976 #[test]
977 fn sortrows_char_array() {
978 let chars = CharArray::new(
979 "bob "
980 .chars()
981 .chain("al ".chars())
982 .chain("ally".chars())
983 .collect(),
984 3,
985 4,
986 )
987 .unwrap();
988 let eval = evaluate(Value::CharArray(chars), &[]).expect("evaluate");
989 let (sorted, _) = eval.into_values();
990 match sorted {
991 Value::CharArray(ca) => {
992 assert_eq!(ca.rows, 3);
993 assert_eq!(ca.cols, 4);
994 let strings: Vec<String> = (0..ca.rows)
995 .map(|r| {
996 ca.data[r * ca.cols..(r + 1) * ca.cols]
997 .iter()
998 .collect::<String>()
999 })
1000 .collect();
1001 assert_eq!(
1002 strings,
1003 vec!["al ".to_string(), "ally".to_string(), "bob ".to_string()]
1004 );
1005 }
1006 other => panic!("expected char array, got {other:?}"),
1007 }
1008 }
1009
1010 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1011 #[test]
1012 fn sortrows_complex_abs() {
1013 let tensor = ComplexTensor::new(vec![(1.0, 2.0), (-2.0, 1.0)], vec![2, 1]).unwrap();
1014 let eval = evaluate(
1015 Value::ComplexTensor(tensor),
1016 &[Value::from("ComparisonMethod"), Value::from("abs")],
1017 )
1018 .expect("evaluate");
1019 let (sorted, _) = eval.into_values();
1020 match sorted {
1021 Value::ComplexTensor(ct) => {
1022 assert_eq!(ct.data, vec![(-2.0, 1.0), (1.0, 2.0)]);
1023 }
1024 other => panic!("expected complex tensor, got {other:?}"),
1025 }
1026 }
1027
1028 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1029 #[test]
1030 fn sortrows_invalid_column_index_errors() {
1031 let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1032 let err = error_message(
1033 evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(3))]).unwrap_err(),
1034 );
1035 assert!(
1036 err.contains("column index"),
1037 "unexpected error message: {err}"
1038 );
1039 }
1040
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn sortrows_missingplacement_first_moves_nan_first() {
    // Column-major 2x2 input: row 1 = [1, 2], row 2 = [NaN, 3].
    let input = Tensor::new(vec![1.0, f64::NAN, 2.0, 3.0], vec![2, 2]).unwrap();
    let options = [Value::from("MissingPlacement"), Value::from("first")];
    let (sorted, indices) = evaluate(Value::Tensor(input), &options)
        .expect("evaluate")
        .into_values();

    // With the NaN row placed first the result is [NaN, 3; 1, 2],
    // i.e. column-major data [NaN, 1, 3, 2].
    let sorted_tensor = match sorted {
        Value::Tensor(t) => t,
        other => panic!("expected tensor, got {other:?}"),
    };
    assert!(sorted_tensor.data[0].is_nan());
    assert_eq!(sorted_tensor.data[1], 1.0);
    assert_eq!(sorted_tensor.data[2], 3.0);
    assert_eq!(sorted_tensor.data[3], 2.0);

    // Permutation: original row 2 (the NaN row) now leads.
    let index_tensor = match indices {
        Value::Tensor(t) => t,
        Value::Num(_) => panic!("expected tensor indices"),
        other => panic!("unexpected indices {other:?}"),
    };
    assert_eq!(index_tensor.data, vec![2.0, 1.0]);
}
1066
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn sortrows_missingplacement_last_descend_moves_nan_last() {
    // Column-major 2x2 input: row 1 = [NaN, 1], row 2 = [5, 2].
    let input = Tensor::new(vec![f64::NAN, 5.0, 1.0, 2.0], vec![2, 2]).unwrap();
    let options = [
        Value::from("descend"),
        Value::from("MissingPlacement"),
        Value::from("last"),
    ];
    let (sorted, indices) = evaluate(Value::Tensor(input), &options)
        .expect("evaluate")
        .into_values();

    // Descending with the NaN row pushed last gives [5, 2; NaN, 1],
    // i.e. column-major data [5, NaN, 2, 1].
    let sorted_tensor = match sorted {
        Value::Tensor(t) => t,
        other => panic!("expected tensor, got {other:?}"),
    };
    assert_eq!(sorted_tensor.data[0], 5.0);
    assert!(sorted_tensor.data[1].is_nan());
    assert_eq!(sorted_tensor.data[2], 2.0);
    assert_eq!(sorted_tensor.data[3], 1.0);

    // Permutation: original row 2 leads, the NaN row (row 1) trails.
    let index_tensor = match indices {
        Value::Tensor(t) => t,
        Value::Num(_) => panic!("expected tensor indices"),
        other => panic!("unexpected indices {other:?}"),
    };
    assert_eq!(index_tensor.data, vec![2.0, 1.0]);
}
1096
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn sortrows_missingplacement_invalid_value_errors() {
    // "middle" is expected to be rejected, with an error that names the option.
    let input = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
    let args = [Value::from("MissingPlacement"), Value::from("middle")];
    let outcome = evaluate(Value::Tensor(input), &args);
    let message = error_message(outcome.unwrap_err());
    assert!(
        message.contains("MissingPlacement"),
        "unexpected error message: {message}"
    );
}
1113
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn sortrows_gpu_roundtrip() {
    test_support::with_test_provider(|provider| {
        // 3x2 column-major input: rows [3, 4], [1, 1], [2, 5].
        let host = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
        let handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            })
            .expect("upload");

        let (sorted, indices) = evaluate(Value::GpuTensor(handle), &[])
            .expect("evaluate")
            .into_values();

        // Ascending by the first column: rows [1, 1], [2, 5], [3, 4]; the result is
        // expected back as a host tensor.
        let expected_sorted = vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0];
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, expected_sorted),
            other => panic!("expected tensor, got {other:?}"),
        }

        let expected_indices = vec![2.0, 3.0, 1.0];
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, expected_indices),
            other => panic!("unexpected indices {other:?}"),
        }
    });
}
1136
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "wgpu")]
fn sortrows_wgpu_matches_cpu() {
    use runmat_accelerate::backend::wgpu::provider::{
        register_wgpu_provider, WgpuProviderOptions,
    };

    // Splits an evaluation's (sorted, indices) pair into host tensors, panicking
    // with a descriptive message if either output is not a plain Tensor.
    let unpack = |(sorted, indices): (Value, Value)| -> (Tensor, Tensor) {
        let sorted = match sorted {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let indices = match indices {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };
        (sorted, indices)
    };

    // Registration result is deliberately ignored, as before.
    let _ = register_wgpu_provider(WgpuProviderOptions::default());

    // 3x2 column-major input evaluated once on the host as the reference.
    let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0, 2.0, 5.0], vec![3, 2]).unwrap();
    let (cpu_sorted, cpu_indices) = unpack(
        evaluate(Value::Tensor(tensor.clone()), &[])
            .expect("cpu evaluate")
            .into_values(),
    );

    // Same data uploaded to the registered provider and evaluated again.
    let provider = runmat_accelerate_api::provider().expect("provider");
    let handle = provider
        .upload(&runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        })
        .expect("upload");
    let (gpu_sorted, gpu_indices) = unpack(
        evaluate(Value::GpuTensor(handle.clone()), &[])
            .expect("gpu evaluate")
            .into_values(),
    );

    assert_eq!(gpu_sorted.shape, cpu_sorted.shape);
    assert_eq!(gpu_sorted.data, cpu_sorted.data);
    assert_eq!(gpu_indices.shape, cpu_indices.shape);
    assert_eq!(gpu_indices.data, cpu_indices.data);

    let _ = provider.free(&handle);
}
1181}