1use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6 GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7 SortResult as ProviderSortResult, SortRowsColumnSpec as ProviderSortRowsColumnSpec,
8};
9use runmat_builtins::{
10 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
11 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
12 CharArray, ComplexTensor, Tensor, Value,
13};
14use runmat_macros::runtime_builtin;
15
16use super::type_resolvers::tensor_output_type;
17use crate::build_runtime_error;
18use crate::builtins::common::gpu_helpers;
19use crate::builtins::common::spec::{
20 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24
/// GPU metadata for the `sortrows` builtin.
///
/// `sortrows_gpu` asks the active acceleration provider for a native row sort only
/// while `MissingPlacement` is `'auto'`. Explicit overrides, or a failed provider
/// call, gather the tensor and sort it on the host instead.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::sortrows")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "sortrows",
    op_kind: GpuOpKind::Custom("sortrows"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("sortrows")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // NOTE(review): exact semantics of GatherImmediately live in the spec module — confirm there.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes:
        "Providers may implement a row-sort kernel; explicit MissingPlacement overrides fall back to host memory until native support exists.",
};
41
/// Fusion metadata for `sortrows`.
///
/// The builtin has no elementwise or reduction template (both `None`), and its notes
/// state that it terminates fusion chains.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::sortrows"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "sortrows",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    // NaN inputs are passed through to the output (NaN rows are reordered, not dropped).
    emits_nan: true,
    notes: "`sortrows` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
};
54
55const BUILTIN_NAME: &str = "sortrows";
56
57const SORTROWS_OUTPUT_B: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
58 name: "B",
59 ty: BuiltinParamType::Any,
60 arity: BuiltinParamArity::Required,
61 default: None,
62 description: "Sorted input rows.",
63}];
64
65const SORTROWS_OUTPUT_BI: [BuiltinParamDescriptor; 2] = [
66 BuiltinParamDescriptor {
67 name: "B",
68 ty: BuiltinParamType::Any,
69 arity: BuiltinParamArity::Required,
70 default: None,
71 description: "Sorted input rows.",
72 },
73 BuiltinParamDescriptor {
74 name: "I",
75 ty: BuiltinParamType::NumericArray,
76 arity: BuiltinParamArity::Required,
77 default: None,
78 description: "Permutation indices mapping sorted rows to original rows.",
79 },
80];
81
82const SORTROWS_INPUTS_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
83 name: "A",
84 ty: BuiltinParamType::Any,
85 arity: BuiltinParamArity::Required,
86 default: None,
87 description: "Input matrix to sort by rows.",
88}];
89
90const SORTROWS_INPUTS_A_COLUMNS: [BuiltinParamDescriptor; 2] = [
91 BuiltinParamDescriptor {
92 name: "A",
93 ty: BuiltinParamType::Any,
94 arity: BuiltinParamArity::Required,
95 default: None,
96 description: "Input matrix to sort by rows.",
97 },
98 BuiltinParamDescriptor {
99 name: "column",
100 ty: BuiltinParamType::NumericArray,
101 arity: BuiltinParamArity::Required,
102 default: None,
103 description: "Column specification vector (negative entries request descending order).",
104 },
105];
106
107const SORTROWS_INPUTS_A_DIRECTION: [BuiltinParamDescriptor; 2] = [
108 BuiltinParamDescriptor {
109 name: "A",
110 ty: BuiltinParamType::Any,
111 arity: BuiltinParamArity::Required,
112 default: None,
113 description: "Input matrix to sort by rows.",
114 },
115 BuiltinParamDescriptor {
116 name: "direction",
117 ty: BuiltinParamType::StringScalar,
118 arity: BuiltinParamArity::Required,
119 default: Some("\"ascend\""),
120 description: "Global row direction override: 'ascend' or 'descend'.",
121 },
122];
123
124const SORTROWS_INPUTS_A_COLUMNS_DIRECTION: [BuiltinParamDescriptor; 3] = [
125 BuiltinParamDescriptor {
126 name: "A",
127 ty: BuiltinParamType::Any,
128 arity: BuiltinParamArity::Required,
129 default: None,
130 description: "Input matrix to sort by rows.",
131 },
132 BuiltinParamDescriptor {
133 name: "column",
134 ty: BuiltinParamType::NumericArray,
135 arity: BuiltinParamArity::Required,
136 default: None,
137 description: "Column specification vector (negative entries request descending order).",
138 },
139 BuiltinParamDescriptor {
140 name: "direction",
141 ty: BuiltinParamType::StringScalar,
142 arity: BuiltinParamArity::Required,
143 default: Some("\"ascend\""),
144 description: "Global row direction override: 'ascend' or 'descend'.",
145 },
146];
147
148const SORTROWS_INPUTS_COMPARISON_METHOD: [BuiltinParamDescriptor; 4] = [
149 BuiltinParamDescriptor {
150 name: "A",
151 ty: BuiltinParamType::Any,
152 arity: BuiltinParamArity::Required,
153 default: None,
154 description: "Input matrix to sort by rows.",
155 },
156 BuiltinParamDescriptor {
157 name: "arg",
158 ty: BuiltinParamType::Any,
159 arity: BuiltinParamArity::Variadic,
160 default: None,
161 description: "Optional column and direction arguments.",
162 },
163 BuiltinParamDescriptor {
164 name: "name",
165 ty: BuiltinParamType::StringScalar,
166 arity: BuiltinParamArity::Required,
167 default: Some("\"ComparisonMethod\""),
168 description: "Name-value option key.",
169 },
170 BuiltinParamDescriptor {
171 name: "method",
172 ty: BuiltinParamType::StringScalar,
173 arity: BuiltinParamArity::Required,
174 default: Some("\"auto\""),
175 description: "Comparison method: 'auto', 'real', or 'abs'.",
176 },
177];
178
179const SORTROWS_INPUTS_MISSING_PLACEMENT: [BuiltinParamDescriptor; 4] = [
180 BuiltinParamDescriptor {
181 name: "A",
182 ty: BuiltinParamType::Any,
183 arity: BuiltinParamArity::Required,
184 default: None,
185 description: "Input matrix to sort by rows.",
186 },
187 BuiltinParamDescriptor {
188 name: "arg",
189 ty: BuiltinParamType::Any,
190 arity: BuiltinParamArity::Variadic,
191 default: None,
192 description: "Optional column and direction arguments.",
193 },
194 BuiltinParamDescriptor {
195 name: "name",
196 ty: BuiltinParamType::StringScalar,
197 arity: BuiltinParamArity::Required,
198 default: Some("\"MissingPlacement\""),
199 description: "Name-value option key.",
200 },
201 BuiltinParamDescriptor {
202 name: "placement",
203 ty: BuiltinParamType::StringScalar,
204 arity: BuiltinParamArity::Required,
205 default: Some("\"auto\""),
206 description: "NaN placement policy: 'auto', 'first', or 'last'.",
207 },
208];
209
210const SORTROWS_SIGNATURES: [BuiltinSignatureDescriptor; 12] = [
211 BuiltinSignatureDescriptor {
212 label: "B = sortrows(A)",
213 inputs: &SORTROWS_INPUTS_A,
214 outputs: &SORTROWS_OUTPUT_B,
215 },
216 BuiltinSignatureDescriptor {
217 label: "B = sortrows(A, column)",
218 inputs: &SORTROWS_INPUTS_A_COLUMNS,
219 outputs: &SORTROWS_OUTPUT_B,
220 },
221 BuiltinSignatureDescriptor {
222 label: "B = sortrows(A, direction)",
223 inputs: &SORTROWS_INPUTS_A_DIRECTION,
224 outputs: &SORTROWS_OUTPUT_B,
225 },
226 BuiltinSignatureDescriptor {
227 label: "B = sortrows(A, column, direction)",
228 inputs: &SORTROWS_INPUTS_A_COLUMNS_DIRECTION,
229 outputs: &SORTROWS_OUTPUT_B,
230 },
231 BuiltinSignatureDescriptor {
232 label: "B = sortrows(A, ..., \"ComparisonMethod\", method)",
233 inputs: &SORTROWS_INPUTS_COMPARISON_METHOD,
234 outputs: &SORTROWS_OUTPUT_B,
235 },
236 BuiltinSignatureDescriptor {
237 label: "B = sortrows(A, ..., \"MissingPlacement\", placement)",
238 inputs: &SORTROWS_INPUTS_MISSING_PLACEMENT,
239 outputs: &SORTROWS_OUTPUT_B,
240 },
241 BuiltinSignatureDescriptor {
242 label: "[B, I] = sortrows(A)",
243 inputs: &SORTROWS_INPUTS_A,
244 outputs: &SORTROWS_OUTPUT_BI,
245 },
246 BuiltinSignatureDescriptor {
247 label: "[B, I] = sortrows(A, column)",
248 inputs: &SORTROWS_INPUTS_A_COLUMNS,
249 outputs: &SORTROWS_OUTPUT_BI,
250 },
251 BuiltinSignatureDescriptor {
252 label: "[B, I] = sortrows(A, direction)",
253 inputs: &SORTROWS_INPUTS_A_DIRECTION,
254 outputs: &SORTROWS_OUTPUT_BI,
255 },
256 BuiltinSignatureDescriptor {
257 label: "[B, I] = sortrows(A, column, direction)",
258 inputs: &SORTROWS_INPUTS_A_COLUMNS_DIRECTION,
259 outputs: &SORTROWS_OUTPUT_BI,
260 },
261 BuiltinSignatureDescriptor {
262 label: "[B, I] = sortrows(A, ..., \"ComparisonMethod\", method)",
263 inputs: &SORTROWS_INPUTS_COMPARISON_METHOD,
264 outputs: &SORTROWS_OUTPUT_BI,
265 },
266 BuiltinSignatureDescriptor {
267 label: "[B, I] = sortrows(A, ..., \"MissingPlacement\", placement)",
268 inputs: &SORTROWS_INPUTS_MISSING_PLACEMENT,
269 outputs: &SORTROWS_OUTPUT_BI,
270 },
271];
272
273const SORTROWS_ERROR_INVALID_COLUMN_INDEX: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
274 code: "RM.SORTROWS.INVALID_COLUMN_INDEX",
275 identifier: Some("RunMat:sortrows:InvalidColumnIndex"),
276 when: "Column specification indices are out of range, zero, or otherwise invalid.",
277 message: "sortrows: invalid column index",
278};
279
280const SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
281 code: "RM.SORTROWS.MISSING_PLACEMENT_UNKNOWN",
282 identifier: Some("RunMat:sortrows:MissingPlacementUnknown"),
283 when: "MissingPlacement option value is unsupported.",
284 message: "sortrows: unsupported MissingPlacement value",
285};
286
287const SORTROWS_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
288 code: "RM.SORTROWS.INVALID_ARGUMENT",
289 identifier: Some("RunMat:sortrows:InvalidArgument"),
290 when: "Option parsing receives invalid argument kinds or malformed name-value pairs.",
291 message: "sortrows: invalid argument",
292};
293
294const SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
295 code: "RM.SORTROWS.COMPARISON_METHOD_UNKNOWN",
296 identifier: Some("RunMat:sortrows:ComparisonMethodUnknown"),
297 when: "ComparisonMethod option value is unsupported.",
298 message: "sortrows: unsupported ComparisonMethod value",
299};
300
301const SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
302 code: "RM.SORTROWS.UNSUPPORTED_INPUT_TYPE",
303 identifier: Some("RunMat:sortrows:UnsupportedInputType"),
304 when: "Input cannot be converted to numeric, logical, complex, or char matrix domain.",
305 message: "sortrows: unsupported input type",
306};
307
308const SORTROWS_ERROR_MATRIX_REQUIRED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
309 code: "RM.SORTROWS.MATRIX_REQUIRED",
310 identifier: Some("RunMat:sortrows:MatrixRequired"),
311 when: "Input has rank greater than 2 where matrix input is required.",
312 message: "sortrows: input must be a 2-D matrix",
313};
314
315const SORTROWS_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
316 code: "RM.SORTROWS.INTERNAL",
317 identifier: Some("RunMat:sortrows:Internal"),
318 when: "Internal conversion/allocation/provider decode fails.",
319 message: "sortrows: internal operation failed",
320};
321
322const SORTROWS_ERRORS: [BuiltinErrorDescriptor; 7] = [
323 SORTROWS_ERROR_INVALID_COLUMN_INDEX,
324 SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN,
325 SORTROWS_ERROR_INVALID_ARGUMENT,
326 SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN,
327 SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE,
328 SORTROWS_ERROR_MATRIX_REQUIRED,
329 SORTROWS_ERROR_INTERNAL,
330];
331
332pub const SORTROWS_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
333 signatures: &SORTROWS_SIGNATURES,
334 output_mode: BuiltinOutputMode::ByRequestedOutputCount,
335 completion_policy: BuiltinCompletionPolicy::Public,
336 errors: &SORTROWS_ERRORS,
337};
338
339fn sortrows_error_with(
340 error: &'static BuiltinErrorDescriptor,
341 message: impl Into<String>,
342) -> crate::RuntimeError {
343 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
344 if let Some(identifier) = error.identifier {
345 builder = builder.with_identifier(identifier);
346 }
347 builder.build()
348}
349
350fn sortrows_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
351 sortrows_error_with(error, error.message)
352}
353
354fn sortrows_internal_error(message: impl Into<String>) -> crate::RuntimeError {
355 sortrows_error_with(&SORTROWS_ERROR_INTERNAL, message)
356}
357
358#[runtime_builtin(
359 name = "sortrows",
360 category = "array/sorting_sets",
361 summary = "Sort matrix rows lexicographically with column and direction controls.",
362 keywords = "sortrows,row sort,lexicographic,gpu",
363 accel = "sink",
364 sink = true,
365 type_resolver(tensor_output_type),
366 descriptor(crate::builtins::array::sorting_sets::sortrows::SORTROWS_DESCRIPTOR),
367 builtin_path = "crate::builtins::array::sorting_sets::sortrows"
368)]
369async fn sortrows_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
370 let eval = evaluate(value, &rest).await?;
371 if let Some(out_count) = crate::output_count::current_output_count() {
372 if out_count == 0 {
373 return Ok(Value::OutputList(Vec::new()));
374 }
375 let (sorted, indices) = eval.into_values();
376 let mut outputs = vec![sorted];
377 if out_count >= 2 {
378 outputs.push(indices);
379 }
380 return Ok(crate::output_count::output_list_with_padding(
381 out_count, outputs,
382 ));
383 }
384 Ok(eval.into_sorted_value())
385}
386
387pub async fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
389 match value {
390 Value::GpuTensor(handle) => sortrows_gpu(handle, rest).await,
391 other => sortrows_host(other, rest),
392 }
393}
394
395async fn sortrows_gpu(
396 handle: GpuTensorHandle,
397 rest: &[Value],
398) -> crate::BuiltinResult<SortRowsEvaluation> {
399 ensure_matrix_shape(&handle.shape)?;
400 let (_, cols) = rows_cols_from_shape(&handle.shape);
401 let args = SortRowsArgs::parse(rest, cols)?;
402
403 if args.missing_is_auto() {
404 if let Some(provider) = runmat_accelerate_api::provider() {
405 let provider_columns = args.to_provider_columns();
406 let provider_comparison = args.provider_comparison();
407 match provider
408 .sort_rows(&handle, &provider_columns, provider_comparison)
409 .await
410 {
411 Ok(result) => return sortrows_from_provider_result(result),
412 Err(_err) => {
413 }
415 }
416 }
417 }
418
419 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
420 sortrows_real_tensor_with_args(tensor, &args)
421}
422
423fn sortrows_from_provider_result(
424 result: ProviderSortResult,
425) -> crate::BuiltinResult<SortRowsEvaluation> {
426 let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
427 .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
428 let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
429 .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
430 Ok(SortRowsEvaluation {
431 sorted: tensor::tensor_into_value(sorted_tensor),
432 indices: indices_tensor,
433 })
434}
435
436fn sortrows_host(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
437 match value {
438 Value::Tensor(tensor) => sortrows_real_tensor(tensor, rest),
439 Value::LogicalArray(logical) => {
440 let tensor = tensor::logical_to_tensor(&logical)
441 .map_err(|e| sortrows_internal_error(e))?;
442 sortrows_real_tensor(tensor, rest)
443 }
444 Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
445 let tensor = tensor::value_into_tensor_for("sortrows", value)
446 .map_err(|e| sortrows_internal_error(e))?;
447 sortrows_real_tensor(tensor, rest)
448 }
449 Value::ComplexTensor(ct) => sortrows_complex_tensor(ct, rest),
450 Value::Complex(re, im) => {
451 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
452 .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
453 sortrows_complex_tensor(tensor, rest)
454 }
455 Value::CharArray(ca) => sortrows_char_array(ca, rest),
456 other => Err(sortrows_error_with(
457 &SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE,
458 format!(
459 "sortrows: unsupported input type {:?}; expected numeric, logical, complex, or char arrays",
460 other
461 ),
462 )
463 .into()),
464 }
465}
466
467fn sortrows_real_tensor(
468 tensor: Tensor,
469 rest: &[Value],
470) -> crate::BuiltinResult<SortRowsEvaluation> {
471 ensure_matrix_shape(&tensor.shape)?;
472 let cols = tensor.cols();
473 let args = SortRowsArgs::parse(rest, cols)?;
474 sortrows_real_tensor_with_args(tensor, &args)
475}
476
477fn sortrows_real_tensor_with_args(
478 tensor: Tensor,
479 args: &SortRowsArgs,
480) -> crate::BuiltinResult<SortRowsEvaluation> {
481 let rows = tensor.rows();
482 let cols = tensor.cols();
483
484 if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
485 let indices = identity_indices(rows)?;
486 return Ok(SortRowsEvaluation {
487 sorted: tensor::tensor_into_value(tensor),
488 indices,
489 });
490 }
491
492 let mut order: Vec<usize> = (0..rows).collect();
493 order.sort_by(|&a, &b| compare_real_rows(&tensor, rows, args, a, b));
494
495 let sorted_tensor = reorder_real_rows(&tensor, rows, cols, &order)?;
496 let indices = permutation_indices(&order)?;
497 Ok(SortRowsEvaluation {
498 sorted: tensor::tensor_into_value(sorted_tensor),
499 indices,
500 })
501}
502
503fn sortrows_complex_tensor(
504 tensor: ComplexTensor,
505 rest: &[Value],
506) -> crate::BuiltinResult<SortRowsEvaluation> {
507 ensure_matrix_shape(&tensor.shape)?;
508 let cols = tensor.cols;
509 let args = SortRowsArgs::parse(rest, cols)?;
510 sortrows_complex_tensor_with_args(tensor, &args)
511}
512
513fn sortrows_complex_tensor_with_args(
514 tensor: ComplexTensor,
515 args: &SortRowsArgs,
516) -> crate::BuiltinResult<SortRowsEvaluation> {
517 let rows = tensor.rows;
518 let cols = tensor.cols;
519
520 if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
521 let indices = identity_indices(rows)?;
522 return Ok(SortRowsEvaluation {
523 sorted: complex_tensor_into_value(tensor),
524 indices,
525 });
526 }
527
528 let mut order: Vec<usize> = (0..rows).collect();
529 order.sort_by(|&a, &b| compare_complex_rows(&tensor, rows, args, a, b));
530
531 let sorted_tensor = reorder_complex_rows(&tensor, rows, cols, &order)?;
532 let indices = permutation_indices(&order)?;
533 Ok(SortRowsEvaluation {
534 sorted: complex_tensor_into_value(sorted_tensor),
535 indices,
536 })
537}
538
539fn sortrows_char_array(ca: CharArray, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
540 let cols = ca.cols;
541 let args = SortRowsArgs::parse(rest, cols)?;
542 sortrows_char_array_with_args(ca, &args)
543}
544
545fn sortrows_char_array_with_args(
546 ca: CharArray,
547 args: &SortRowsArgs,
548) -> crate::BuiltinResult<SortRowsEvaluation> {
549 let rows = ca.rows;
550 let cols = ca.cols;
551
552 if rows <= 1 || cols == 0 || ca.data.is_empty() || args.columns.is_empty() {
553 let indices = identity_indices(rows)?;
554 return Ok(SortRowsEvaluation {
555 sorted: Value::CharArray(ca),
556 indices,
557 });
558 }
559
560 let mut order: Vec<usize> = (0..rows).collect();
561 order.sort_by(|&a, &b| compare_char_rows(&ca, args, a, b));
562
563 let sorted = reorder_char_rows(&ca, rows, cols, &order)?;
564 let indices = permutation_indices(&order)?;
565 Ok(SortRowsEvaluation {
566 sorted: Value::CharArray(sorted),
567 indices,
568 })
569}
570
571fn ensure_matrix_shape(shape: &[usize]) -> crate::BuiltinResult<()> {
572 if shape.len() <= 2 {
573 Ok(())
574 } else {
575 Err(sortrows_error(&SORTROWS_ERROR_MATRIX_REQUIRED))
576 }
577}
578
/// Interprets a tensor shape as `(rows, cols)`.
///
/// An empty shape is a scalar `(1, 1)`. A one-element shape `[n]` is a row
/// vector `(1, n)`. Otherwise the first two extents are used.
fn rows_cols_from_shape(shape: &[usize]) -> (usize, usize) {
    match shape {
        [] => (1, 1),
        [len] => (1, *len),
        [rows, cols, ..] => (*rows, *cols),
    }
}
586
587fn compare_real_rows(
588 tensor: &Tensor,
589 rows: usize,
590 args: &SortRowsArgs,
591 a: usize,
592 b: usize,
593) -> Ordering {
594 for spec in &args.columns {
595 if spec.index >= tensor.cols() {
596 continue;
597 }
598 let idx_a = a + spec.index * rows;
599 let idx_b = b + spec.index * rows;
600 let va = tensor.data[idx_a];
601 let vb = tensor.data[idx_b];
602 let missing = args.missing_for_direction(spec.direction);
603 let ord = compare_real_scalars(va, vb, spec.direction, args.comparison, missing);
604 if ord != Ordering::Equal {
605 return ord;
606 }
607 }
608 Ordering::Equal
609}
610
611fn compare_complex_rows(
612 tensor: &ComplexTensor,
613 rows: usize,
614 args: &SortRowsArgs,
615 a: usize,
616 b: usize,
617) -> Ordering {
618 for spec in &args.columns {
619 if spec.index >= tensor.cols {
620 continue;
621 }
622 let idx_a = a + spec.index * rows;
623 let idx_b = b + spec.index * rows;
624 let va = tensor.data[idx_a];
625 let vb = tensor.data[idx_b];
626 let missing = args.missing_for_direction(spec.direction);
627 let ord = compare_complex_scalars(va, vb, spec.direction, args.comparison, missing);
628 if ord != Ordering::Equal {
629 return ord;
630 }
631 }
632 Ordering::Equal
633}
634
635fn compare_char_rows(ca: &CharArray, args: &SortRowsArgs, a: usize, b: usize) -> Ordering {
636 for spec in &args.columns {
637 if spec.index >= ca.cols {
638 continue;
639 }
640 let idx_a = a * ca.cols + spec.index;
641 let idx_b = b * ca.cols + spec.index;
642 let va = ca.data[idx_a];
643 let vb = ca.data[idx_b];
644 let ord = match spec.direction {
645 SortDirection::Ascend => va.cmp(&vb),
646 SortDirection::Descend => vb.cmp(&va),
647 };
648 if ord != Ordering::Equal {
649 return ord;
650 }
651 }
652 Ordering::Equal
653}
654
655fn reorder_real_rows(
656 tensor: &Tensor,
657 rows: usize,
658 cols: usize,
659 order: &[usize],
660) -> crate::BuiltinResult<Tensor> {
661 let mut data = vec![0.0; tensor.data.len()];
662 for col in 0..cols {
663 for (dest_row, &src_row) in order.iter().enumerate() {
664 let src_idx = src_row + col * rows;
665 let dst_idx = dest_row + col * rows;
666 data[dst_idx] = tensor.data[src_idx];
667 }
668 }
669 Tensor::new(data, tensor.shape.clone())
670 .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
671}
672
673fn reorder_complex_rows(
674 tensor: &ComplexTensor,
675 rows: usize,
676 cols: usize,
677 order: &[usize],
678) -> crate::BuiltinResult<ComplexTensor> {
679 let mut data = vec![(0.0, 0.0); tensor.data.len()];
680 for col in 0..cols {
681 for (dest_row, &src_row) in order.iter().enumerate() {
682 let src_idx = src_row + col * rows;
683 let dst_idx = dest_row + col * rows;
684 data[dst_idx] = tensor.data[src_idx];
685 }
686 }
687 ComplexTensor::new(data, tensor.shape.clone())
688 .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
689}
690
691fn reorder_char_rows(
692 ca: &CharArray,
693 rows: usize,
694 cols: usize,
695 order: &[usize],
696) -> crate::BuiltinResult<CharArray> {
697 let mut data = vec!['\0'; ca.data.len()];
698 for (dest_row, &src_row) in order.iter().enumerate() {
699 for col in 0..cols {
700 let src_idx = src_row * cols + col;
701 let dst_idx = dest_row * cols + col;
702 data[dst_idx] = ca.data[src_idx];
703 }
704 }
705 CharArray::new(data, rows, cols).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
706}
707
708fn compare_real_scalars(
709 a: f64,
710 b: f64,
711 direction: SortDirection,
712 comparison: ComparisonMethod,
713 missing: MissingPlacementResolved,
714) -> Ordering {
715 match (a.is_nan(), b.is_nan()) {
716 (true, true) => Ordering::Equal,
717 (true, false) => match missing {
718 MissingPlacementResolved::First => Ordering::Less,
719 MissingPlacementResolved::Last => Ordering::Greater,
720 },
721 (false, true) => match missing {
722 MissingPlacementResolved::First => Ordering::Greater,
723 MissingPlacementResolved::Last => Ordering::Less,
724 },
725 (false, false) => compare_real_finite_scalars(a, b, direction, comparison),
726 }
727}
728
729fn compare_real_finite_scalars(
730 a: f64,
731 b: f64,
732 direction: SortDirection,
733 comparison: ComparisonMethod,
734) -> Ordering {
735 if matches!(comparison, ComparisonMethod::Abs) {
736 let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
737 if abs_cmp != Ordering::Equal {
738 return match direction {
739 SortDirection::Ascend => abs_cmp,
740 SortDirection::Descend => abs_cmp.reverse(),
741 };
742 }
743 }
744 match direction {
745 SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
746 SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
747 }
748}
749
750fn compare_complex_scalars(
751 a: (f64, f64),
752 b: (f64, f64),
753 direction: SortDirection,
754 comparison: ComparisonMethod,
755 missing: MissingPlacementResolved,
756) -> Ordering {
757 match (complex_is_nan(a), complex_is_nan(b)) {
758 (true, true) => Ordering::Equal,
759 (true, false) => match missing {
760 MissingPlacementResolved::First => Ordering::Less,
761 MissingPlacementResolved::Last => Ordering::Greater,
762 },
763 (false, true) => match missing {
764 MissingPlacementResolved::First => Ordering::Greater,
765 MissingPlacementResolved::Last => Ordering::Less,
766 },
767 (false, false) => compare_complex_finite_scalars(a, b, direction, comparison),
768 }
769}
770
771fn compare_complex_finite_scalars(
772 a: (f64, f64),
773 b: (f64, f64),
774 direction: SortDirection,
775 comparison: ComparisonMethod,
776) -> Ordering {
777 match comparison {
778 ComparisonMethod::Real => compare_complex_real_first(a, b, direction),
779 ComparisonMethod::Auto | ComparisonMethod::Abs => {
780 let abs_cmp = complex_abs(a)
781 .partial_cmp(&complex_abs(b))
782 .unwrap_or(Ordering::Equal);
783 if abs_cmp != Ordering::Equal {
784 return match direction {
785 SortDirection::Ascend => abs_cmp,
786 SortDirection::Descend => abs_cmp.reverse(),
787 };
788 }
789 compare_complex_real_first(a, b, direction)
790 }
791 }
792}
793
794fn compare_complex_real_first(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
795 let real_cmp = match direction {
796 SortDirection::Ascend => a.0.partial_cmp(&b.0),
797 SortDirection::Descend => b.0.partial_cmp(&a.0),
798 }
799 .unwrap_or(Ordering::Equal);
800 if real_cmp != Ordering::Equal {
801 return real_cmp;
802 }
803 match direction {
804 SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
805 SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
806 }
807}
808
/// Returns `true` when either the real or the imaginary component is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
812
/// Magnitude of a complex value, computed with `hypot` to avoid intermediate overflow.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    re.hypot(im)
}
816
817fn permutation_indices(order: &[usize]) -> crate::BuiltinResult<Tensor> {
818 let rows = order.len();
819 let mut data = Vec::with_capacity(rows);
820 for &idx in order {
821 data.push((idx + 1) as f64);
822 }
823 Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
824}
825
826fn identity_indices(rows: usize) -> crate::BuiltinResult<Tensor> {
827 let mut data = Vec::with_capacity(rows);
828 for i in 0..rows {
829 data.push((i + 1) as f64);
830 }
831 Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
832}
833
834fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
835 if tensor.data.len() == 1 {
836 Value::Complex(tensor.data[0].0, tensor.data[0].1)
837 } else {
838 Value::ComplexTensor(tensor)
839 }
840}
841
/// Sort direction for a single key column.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SortDirection {
    /// Smallest values first.
    Ascend,
    /// Largest values first.
    Descend,
}
847
848impl SortDirection {
849 fn from_str(value: &str) -> Option<Self> {
850 match value.trim().to_ascii_lowercase().as_str() {
851 "ascend" | "ascending" => Some(SortDirection::Ascend),
852 "descend" | "descending" => Some(SortDirection::Descend),
853 _ => None,
854 }
855 }
856}
857
/// `ComparisonMethod` option value.
///
/// For real data, `Abs` orders by magnitude and the others use plain numeric
/// order. For complex data, `Auto` and `Abs` order by magnitude, while `Real`
/// orders by real part and then imaginary part.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ComparisonMethod {
    /// Default method (`'auto'`).
    Auto,
    /// `'real'`.
    Real,
    /// `'abs'` (also accepted as `'magnitude'`).
    Abs,
}
864
/// `MissingPlacement` option value as written by the user.
///
/// `Auto` depends on each column's direction: NaN goes last when ascending and
/// first when descending.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MissingPlacement {
    /// `'auto'`: placement follows the column's sort direction.
    Auto,
    /// `'first'`: NaN always precedes numbers.
    First,
    /// `'last'`: NaN always follows numbers.
    Last,
}
871
/// Concrete NaN placement once `'auto'` has been resolved against a direction.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MissingPlacementResolved {
    /// NaN sorts before every number.
    First,
    /// NaN sorts after every number.
    Last,
}
877
878impl MissingPlacement {
879 fn resolve(self, direction: SortDirection) -> MissingPlacementResolved {
880 match self {
881 MissingPlacement::First => MissingPlacementResolved::First,
882 MissingPlacement::Last => MissingPlacementResolved::Last,
883 MissingPlacement::Auto => match direction {
884 SortDirection::Ascend => MissingPlacementResolved::Last,
885 SortDirection::Descend => MissingPlacementResolved::First,
886 },
887 }
888 }
889
890 fn is_auto(self) -> bool {
891 matches!(self, MissingPlacement::Auto)
892 }
893}
894
895#[derive(Debug, Clone)]
896struct ColumnSpec {
897 index: usize,
898 direction: SortDirection,
899}
900
901#[derive(Debug, Clone)]
902struct SortRowsArgs {
903 columns: Vec<ColumnSpec>,
904 comparison: ComparisonMethod,
905 missing: MissingPlacement,
906}
907
908impl SortRowsArgs {
909 fn parse(rest: &[Value], num_cols: usize) -> crate::BuiltinResult<Self> {
910 let mut columns: Option<Vec<ColumnSpec>> = None;
911 let mut override_direction: Option<SortDirection> = None;
912 let mut comparison = ComparisonMethod::Auto;
913 let mut missing = MissingPlacement::Auto;
914 let mut i = 0usize;
915
916 while i < rest.len() {
917 if columns.is_none() {
918 if let Some(parsed) = parse_column_vector(&rest[i], num_cols)? {
919 columns = Some(parsed);
920 i += 1;
921 continue;
922 }
923 }
924 if let Some(direction) = parse_direction(&rest[i]) {
925 override_direction = Some(direction);
926 i += 1;
927 continue;
928 }
929 let Some(keyword) = tensor::value_to_string(&rest[i]) else {
930 return Err(sortrows_error_with(
931 &SORTROWS_ERROR_INVALID_ARGUMENT,
932 format!("sortrows: invalid argument {:?}", rest[i]),
933 ));
934 };
935 let lowered = keyword.trim().to_ascii_lowercase();
936 match lowered.as_str() {
937 "comparisonmethod" => {
938 i += 1;
939 if i >= rest.len() {
940 return Err(sortrows_error_with(
941 &SORTROWS_ERROR_INVALID_ARGUMENT,
942 "sortrows: expected a value for 'ComparisonMethod'",
943 ));
944 }
945 let Some(value_str) = tensor::value_to_string(&rest[i]) else {
946 return Err(sortrows_error_with(
947 &SORTROWS_ERROR_INVALID_ARGUMENT,
948 "sortrows: 'ComparisonMethod' expects a string value",
949 )
950 .into());
951 };
952 comparison = match value_str.trim().to_ascii_lowercase().as_str() {
953 "auto" => ComparisonMethod::Auto,
954 "real" => ComparisonMethod::Real,
955 "abs" | "magnitude" => ComparisonMethod::Abs,
956 other => {
957 return Err(sortrows_error_with(
958 &SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN,
959 format!("sortrows: unsupported ComparisonMethod '{other}'"),
960 )
961 .into())
962 }
963 };
964 i += 1;
965 }
966 "missingplacement" => {
967 i += 1;
968 if i >= rest.len() {
969 return Err(sortrows_error_with(
970 &SORTROWS_ERROR_INVALID_ARGUMENT,
971 "sortrows: expected a value for 'MissingPlacement'",
972 )
973 .into());
974 }
975 let Some(value_str) = tensor::value_to_string(&rest[i]) else {
976 return Err(sortrows_error_with(
977 &SORTROWS_ERROR_INVALID_ARGUMENT,
978 "sortrows: 'MissingPlacement' expects a string value",
979 )
980 .into());
981 };
982 missing = match value_str.trim().to_ascii_lowercase().as_str() {
983 "auto" => MissingPlacement::Auto,
984 "first" => MissingPlacement::First,
985 "last" => MissingPlacement::Last,
986 other => {
987 return Err(sortrows_error_with(
988 &SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN,
989 format!("sortrows: unsupported MissingPlacement '{other}'"),
990 )
991 .into())
992 }
993 };
994 i += 1;
995 }
996 other => {
997 return Err(sortrows_error_with(
998 &SORTROWS_ERROR_INVALID_ARGUMENT,
999 format!("sortrows: unexpected argument '{other}'"),
1000 ));
1001 }
1002 }
1003 }
1004
1005 let mut columns = columns.unwrap_or_else(|| default_columns(num_cols));
1006 if let Some(dir) = override_direction {
1007 for spec in &mut columns {
1008 spec.direction = dir;
1009 }
1010 }
1011 validate_columns(&columns, num_cols)?;
1012
1013 Ok(SortRowsArgs {
1014 columns,
1015 comparison,
1016 missing,
1017 })
1018 }
1019
1020 fn to_provider_columns(&self) -> Vec<ProviderSortRowsColumnSpec> {
1021 self.columns
1022 .iter()
1023 .map(|spec| ProviderSortRowsColumnSpec {
1024 index: spec.index,
1025 order: match spec.direction {
1026 SortDirection::Ascend => ProviderSortOrder::Ascend,
1027 SortDirection::Descend => ProviderSortOrder::Descend,
1028 },
1029 })
1030 .collect()
1031 }
1032
1033 fn provider_comparison(&self) -> ProviderSortComparison {
1034 match self.comparison {
1035 ComparisonMethod::Auto => ProviderSortComparison::Auto,
1036 ComparisonMethod::Real => ProviderSortComparison::Real,
1037 ComparisonMethod::Abs => ProviderSortComparison::Abs,
1038 }
1039 }
1040
1041 fn missing_for_direction(&self, direction: SortDirection) -> MissingPlacementResolved {
1042 self.missing.resolve(direction)
1043 }
1044
1045 fn missing_is_auto(&self) -> bool {
1046 self.missing.is_auto()
1047 }
1048}
1049
1050fn parse_column_vector(
1051 value: &Value,
1052 num_cols: usize,
1053) -> crate::BuiltinResult<Option<Vec<ColumnSpec>>> {
1054 match value {
1055 Value::Int(i) => parse_single_column(i.to_i64(), num_cols).map(Some),
1056 Value::Num(n) => {
1057 if !n.is_finite() {
1058 return Err(sortrows_error_with(
1059 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1060 "sortrows: column indices must be finite",
1061 ));
1062 }
1063 let rounded = n.round();
1064 if (rounded - n).abs() > f64::EPSILON {
1065 return Err(sortrows_error_with(
1066 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1067 "sortrows: column indices must be integers",
1068 ));
1069 }
1070 parse_single_column(rounded as i64, num_cols).map(Some)
1071 }
1072 Value::Tensor(tensor) => {
1073 if !is_vector(&tensor.shape) {
1074 return Err(sortrows_error_with(
1075 &SORTROWS_ERROR_INVALID_ARGUMENT,
1076 "sortrows: column specification must be a vector",
1077 ));
1078 }
1079 let mut specs = Vec::with_capacity(tensor.data.len());
1080 for &entry in &tensor.data {
1081 if !entry.is_finite() {
1082 return Err(sortrows_error_with(
1083 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1084 "sortrows: column indices must be finite",
1085 ));
1086 }
1087 let rounded = entry.round();
1088 if (rounded - entry).abs() > f64::EPSILON {
1089 return Err(sortrows_error_with(
1090 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1091 "sortrows: column indices must be integers",
1092 ));
1093 }
1094 let column = parse_single_column_i64(rounded as i64, num_cols)?;
1095 specs.push(column);
1096 }
1097 Ok(Some(specs))
1098 }
1099 _ => Ok(None),
1100 }
1101}
1102
1103fn parse_single_column(value: i64, num_cols: usize) -> crate::BuiltinResult<Vec<ColumnSpec>> {
1104 parse_single_column_i64(value, num_cols).map(|spec| vec![spec])
1105}
1106
1107fn parse_single_column_i64(value: i64, num_cols: usize) -> crate::BuiltinResult<ColumnSpec> {
1108 if value == 0 {
1109 return Err(sortrows_error_with(
1110 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1111 "sortrows: column indices must be non-zero",
1112 ));
1113 }
1114 let abs = value.unsigned_abs() as usize;
1115 if abs == 0 {
1116 return Err(sortrows_error_with(
1117 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1118 "sortrows: column indices must be >= 1",
1119 ));
1120 }
1121 if num_cols == 0 {
1122 return Err(sortrows_error_with(
1123 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1124 "sortrows: column index exceeds matrix with 0 columns",
1125 ));
1126 }
1127 if abs > num_cols {
1128 return Err(sortrows_error_with(
1129 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1130 format!(
1131 "sortrows: column index {} exceeds matrix with {} columns",
1132 abs, num_cols
1133 ),
1134 )
1135 .into());
1136 }
1137 let direction = if value > 0 {
1138 SortDirection::Ascend
1139 } else {
1140 SortDirection::Descend
1141 };
1142 Ok(ColumnSpec {
1143 index: abs - 1,
1144 direction,
1145 })
1146}
1147
1148fn parse_direction(value: &Value) -> Option<SortDirection> {
1149 tensor::value_to_string(value).and_then(|s| SortDirection::from_str(&s))
1150}
1151
1152fn default_columns(num_cols: usize) -> Vec<ColumnSpec> {
1153 let mut columns = Vec::with_capacity(num_cols);
1154 for col in 0..num_cols {
1155 columns.push(ColumnSpec {
1156 index: col,
1157 direction: SortDirection::Ascend,
1158 });
1159 }
1160 columns
1161}
1162
1163fn validate_columns(columns: &[ColumnSpec], num_cols: usize) -> crate::BuiltinResult<()> {
1164 if num_cols == 0 && columns.iter().any(|spec| spec.index > 0) {
1165 return Err(sortrows_error_with(
1166 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1167 "sortrows: column index exceeds matrix with 0 columns",
1168 ));
1169 }
1170 for spec in columns {
1171 if num_cols > 0 && spec.index >= num_cols {
1172 return Err(sortrows_error_with(
1173 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1174 format!(
1175 "sortrows: column index {} exceeds matrix with {} columns",
1176 spec.index + 1,
1177 num_cols
1178 ),
1179 )
1180 .into());
1181 }
1182 }
1183 Ok(())
1184}
1185
/// Returns `true` when `shape` describes a vector.
///
/// A vector here is a shape of rank 0 or 1, or a 2-D shape with a singleton dimension.
/// Shapes of rank 3 or more are never treated as vectors.
fn is_vector(shape: &[usize]) -> bool {
    match shape {
        [] | [_] => true,
        [rows, cols] => *rows == 1 || *cols == 1,
        _ => false,
    }
}
1194
1195#[derive(Debug)]
1196pub struct SortRowsEvaluation {
1197 sorted: Value,
1198 indices: Tensor,
1199}
1200
1201impl SortRowsEvaluation {
1202 pub fn into_sorted_value(self) -> Value {
1203 self.sorted
1204 }
1205
1206 pub fn into_values(self) -> (Value, Value) {
1207 let indices = tensor::tensor_into_value(self.indices);
1208 (self.sorted, indices)
1209 }
1210
1211 pub fn indices_value(&self) -> Value {
1212 tensor::tensor_into_value(self.indices.clone())
1213 }
1214}
1215
// Unit tests for `sortrows`. Fixtures are built with `Tensor::new(data, shape)`; the expected
// outputs are consistent with `data` being stored column-major (all of column 1 first), so
// the row tuples in the comments below are read that way.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{IntValue, ResolveContext, Type, Value};

    // Synchronous wrapper around the async `super::evaluate` so tests can call it directly.
    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
        futures::executor::block_on(super::evaluate(value, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_default_matrix() {
        // 3x2 rows: (3, 4), (1, 1), (2, 5). Sorting on column 1 gives row order 2, 3, 1.
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        // Indices are one-based row numbers of the original matrix.
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // The shared tensor output-type resolver maps a tensor input type to a tensor output type.
    #[test]
    fn sortrows_type_resolver_tensor() {
        assert_eq!(
            tensor_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_with_column_vector() {
        // 3x3 rows: (1, 4, 2), (3, 2, 5), (3, 2, 1). Keys are column 2, then 3, then 1.
        // Rows 2 and 3 tie on column 2 and are ordered by column 3 (1 < 5).
        let tensor = Tensor::new(
            vec![1.0, 3.0, 3.0, 4.0, 2.0, 2.0, 2.0, 5.0, 1.0],
            vec![3, 3],
        )
        .unwrap();
        let cols = Tensor::new(vec![2.0, 3.0, 1.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![3.0, 3.0, 1.0, 2.0, 2.0, 4.0, 1.0, 5.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_direction_descend() {
        // 2x2 rows: (1, 4), (2, 3). 'descend' puts the row with the larger first column first.
        let tensor = Tensor::new(vec![1.0, 2.0, 4.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 3.0, 4.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_mixed_directions() {
        // 3x2 rows: (1, 1), (1, 7), (1, 2). Column 1 ties everywhere; the negative index -2
        // sorts column 2 in descending order, giving 7, 2, 1.
        let tensor = Tensor::new(vec![1.0, 1.0, 1.0, 1.0, 7.0, 2.0], vec![3, 2]).unwrap();
        let cols = Tensor::new(vec![1.0, -2.0], vec![2, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 1.0, 1.0, 7.0, 2.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_returns_indices() {
        // 2x2 rows: (2, 3), (1, 4). The second row sorts first, so the permutation is [2, 1].
        let tensor = Tensor::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (_, indices) = eval.into_values();
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_char_array() {
        // Three 4-character rows. The result is read back row by row (`r * cols..`),
        // so each 4-char chunk of `data` is one row.
        let chars = CharArray::new(
            "bob "
                .chars()
                .chain("al  ".chars())
                .chain("ally".chars())
                .collect(),
            3,
            4,
        )
        .unwrap();
        let eval = evaluate(Value::CharArray(chars), &[]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 3);
                assert_eq!(ca.cols, 4);
                let strings: Vec<String> = (0..ca.rows)
                    .map(|r| {
                        ca.data[r * ca.cols..(r + 1) * ca.cols]
                            .iter()
                            .collect::<String>()
                    })
                    .collect();
                assert_eq!(
                    strings,
                    vec!["al  ".to_string(), "ally".to_string(), "bob ".to_string()]
                );
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_complex_abs() {
        // Both entries have magnitude sqrt(5), so the order is decided purely by the tie-break.
        // NOTE(review): MATLAB's 'abs' comparison breaks equal-magnitude ties by angle(A) in
        // (-pi, pi]. That would place 1+2i (angle ~1.11) before -2+1i (angle ~2.68), the
        // opposite of the order asserted here. Confirm which tie-break is intended.
        let tensor = ComplexTensor::new(vec![(1.0, 2.0), (-2.0, 1.0)], vec![2, 1]).unwrap();
        let eval = evaluate(
            Value::ComplexTensor(tensor),
            &[Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.data, vec![(-2.0, 1.0), (1.0, 2.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_invalid_column_index_errors() {
        // Column 3 does not exist in a 2x1 matrix.
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(3))]).unwrap_err();
        assert_eq!(
            err.identifier(),
            SORTROWS_ERROR_INVALID_COLUMN_INDEX.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_first_moves_nan_first() {
        // 2x2 rows: (1, 2), (NaN, 3). 'first' moves the NaN row ahead of the finite row.
        let tensor = Tensor::new(vec![1.0, f64::NAN, 2.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("first")],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                // NaN != NaN, so the NaN slot is checked with `is_nan` instead of `assert_eq`.
                assert!(t.data[0].is_nan());
                assert_eq!(t.data[1], 1.0);
                assert_eq!(t.data[2], 3.0);
                assert_eq!(t.data[3], 2.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_last_descend_moves_nan_last() {
        // 2x2 rows: (NaN, 1), (5, 2). Descending with 'last' keeps the NaN row at the bottom.
        let tensor = Tensor::new(vec![f64::NAN, 5.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::from("descend"),
                Value::from("MissingPlacement"),
                Value::from("last"),
            ],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data[0], 5.0);
                assert!(t.data[1].is_nan());
                assert_eq!(t.data[2], 2.0);
                assert_eq!(t.data[3], 1.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_invalid_value_errors() {
        // 'middle' is not one of the accepted placements ('auto', 'first', 'last').
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("middle")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_gpu_roundtrip() {
        // Same fixture as `sortrows_default_matrix`, uploaded to the test provider. The results
        // are expected back as host `Value::Tensor`s.
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
            let (sorted, indices) = eval.into_values();
            match sorted {
                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
            match indices {
                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
                other => panic!("unexpected indices {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn sortrows_wgpu_matches_cpu() {
        // Registers the wgpu provider, ignoring the result, then checks that the GPU path
        // reproduces the CPU path's sorted data and indices exactly.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0, 2.0, 5.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let (cpu_sorted_val, cpu_indices_val) = cpu_eval.into_values();
        let cpu_sorted = match cpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let cpu_indices = match cpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("provider");
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("gpu evaluate");
        let (gpu_sorted_val, gpu_indices_val) = gpu_eval.into_values();
        let gpu_sorted = match gpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let gpu_indices = match gpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        assert_eq!(gpu_sorted.shape, cpu_sorted.shape);
        assert_eq!(gpu_sorted.data, cpu_sorted.data);
        assert_eq!(gpu_indices.shape, cpu_indices.shape);
        assert_eq!(gpu_indices.data, cpu_indices.data);

        // Release the uploaded buffer; a failure to free is ignored in this test.
        let _ = provider.free(&handle);
    }
}