1use std::cmp::Ordering;
4
5use runmat_accelerate_api::{
6 GpuTensorHandle, SortComparison as ProviderSortComparison, SortOrder as ProviderSortOrder,
7 SortResult as ProviderSortResult, SortRowsColumnSpec as ProviderSortRowsColumnSpec,
8};
9use runmat_builtins::{
10 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
11 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
12 CharArray, ComplexTensor, Tensor, Value,
13};
14use runmat_macros::runtime_builtin;
15
16use super::type_resolvers::tensor_output_type;
17use crate::build_runtime_error;
18use crate::builtins::common::gpu_helpers;
19use crate::builtins::common::spec::{
20 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24
/// GPU capability descriptor for `sortrows`.
///
/// Declares a custom `"sortrows"` provider hook for F32/F64 data. Results are
/// gathered immediately (`ResidencyPolicy::GatherImmediately`). `sortrows_gpu`
/// only calls the provider when `MissingPlacement` is left at `auto`; explicit
/// overrides take the host path, as the `notes` text states.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::sortrows")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "sortrows",
    op_kind: GpuOpKind::Custom("sortrows"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("sortrows")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes:
        "Providers may implement a row-sort kernel; explicit MissingPlacement overrides fall back to host memory until native support exists.",
};
41
/// Fusion descriptor for `sortrows`.
///
/// The builtin has no elementwise or reduction template (`None` for both), so
/// it acts as a fusion sink. `emits_nan` is set because NaN inputs are passed
/// through into the sorted output.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::sortrows"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "sortrows",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`sortrows` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
};
54
/// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "sortrows";
56
/// Output descriptor for the single-output form `B = sortrows(...)`.
const SORTROWS_OUTPUT_B: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "B",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Sorted input rows.",
}];

/// Output descriptors for the two-output form `[B, I] = sortrows(...)`.
///
/// `I` holds 1-based row indices (see `permutation_indices`).
const SORTROWS_OUTPUT_BI: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Sorted input rows.",
    },
    BuiltinParamDescriptor {
        name: "I",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Permutation indices mapping sorted rows to original rows.",
    },
];
81
/// Inputs for `sortrows(A)`.
const SORTROWS_INPUTS_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "A",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Input matrix to sort by rows.",
}];

/// Inputs for `sortrows(A, column)`.
const SORTROWS_INPUTS_A_COLUMNS: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input matrix to sort by rows.",
    },
    BuiltinParamDescriptor {
        name: "column",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Column specification vector (negative entries request descending order).",
    },
];

/// Inputs for `sortrows(A, direction)`.
// NOTE(review): `direction` is marked Required while also carrying a default;
// presumably the default is only for completion/help text — confirm.
const SORTROWS_INPUTS_A_DIRECTION: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input matrix to sort by rows.",
    },
    BuiltinParamDescriptor {
        name: "direction",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"ascend\""),
        description: "Global row direction override: 'ascend' or 'descend'.",
    },
];

/// Inputs for `sortrows(A, column, direction)`.
const SORTROWS_INPUTS_A_COLUMNS_DIRECTION: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input matrix to sort by rows.",
    },
    BuiltinParamDescriptor {
        name: "column",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Column specification vector (negative entries request descending order).",
    },
    BuiltinParamDescriptor {
        name: "direction",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"ascend\""),
        description: "Global row direction override: 'ascend' or 'descend'.",
    },
];

/// Inputs for `sortrows(A, ..., "ComparisonMethod", method)`.
const SORTROWS_INPUTS_COMPARISON_METHOD: [BuiltinParamDescriptor; 4] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input matrix to sort by rows.",
    },
    BuiltinParamDescriptor {
        name: "arg",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Optional column and direction arguments.",
    },
    BuiltinParamDescriptor {
        name: "name",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"ComparisonMethod\""),
        description: "Name-value option key.",
    },
    BuiltinParamDescriptor {
        name: "method",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"auto\""),
        description: "Comparison method: 'auto', 'real', or 'abs'.",
    },
];

/// Inputs for `sortrows(A, ..., "MissingPlacement", placement)`.
const SORTROWS_INPUTS_MISSING_PLACEMENT: [BuiltinParamDescriptor; 4] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input matrix to sort by rows.",
    },
    BuiltinParamDescriptor {
        name: "arg",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Optional column and direction arguments.",
    },
    BuiltinParamDescriptor {
        name: "name",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"MissingPlacement\""),
        description: "Name-value option key.",
    },
    BuiltinParamDescriptor {
        name: "placement",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"auto\""),
        description: "NaN placement policy: 'auto', 'first', or 'last'.",
    },
];
209
/// All call signatures advertised for `sortrows`.
///
/// Each input form appears twice: once with the single output `B` and once with
/// the pair `[B, I]`.
const SORTROWS_SIGNATURES: [BuiltinSignatureDescriptor; 12] = [
    // Single-output forms.
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A)",
        inputs: &SORTROWS_INPUTS_A,
        outputs: &SORTROWS_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A, column)",
        inputs: &SORTROWS_INPUTS_A_COLUMNS,
        outputs: &SORTROWS_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A, direction)",
        inputs: &SORTROWS_INPUTS_A_DIRECTION,
        outputs: &SORTROWS_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A, column, direction)",
        inputs: &SORTROWS_INPUTS_A_COLUMNS_DIRECTION,
        outputs: &SORTROWS_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A, ..., \"ComparisonMethod\", method)",
        inputs: &SORTROWS_INPUTS_COMPARISON_METHOD,
        outputs: &SORTROWS_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = sortrows(A, ..., \"MissingPlacement\", placement)",
        inputs: &SORTROWS_INPUTS_MISSING_PLACEMENT,
        outputs: &SORTROWS_OUTPUT_B,
    },
    // Two-output forms (sorted rows plus permutation indices).
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A)",
        inputs: &SORTROWS_INPUTS_A,
        outputs: &SORTROWS_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A, column)",
        inputs: &SORTROWS_INPUTS_A_COLUMNS,
        outputs: &SORTROWS_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A, direction)",
        inputs: &SORTROWS_INPUTS_A_DIRECTION,
        outputs: &SORTROWS_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A, column, direction)",
        inputs: &SORTROWS_INPUTS_A_COLUMNS_DIRECTION,
        outputs: &SORTROWS_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A, ..., \"ComparisonMethod\", method)",
        inputs: &SORTROWS_INPUTS_COMPARISON_METHOD,
        outputs: &SORTROWS_OUTPUT_BI,
    },
    BuiltinSignatureDescriptor {
        label: "[B, I] = sortrows(A, ..., \"MissingPlacement\", placement)",
        inputs: &SORTROWS_INPUTS_MISSING_PLACEMENT,
        outputs: &SORTROWS_OUTPUT_BI,
    },
];
272
/// Raised when a column specification is out of range, zero, non-finite, or non-integer.
const SORTROWS_ERROR_INVALID_COLUMN_INDEX: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.INVALID_COLUMN_INDEX",
    identifier: Some("RunMat:sortrows:InvalidColumnIndex"),
    when: "Column specification indices are out of range, zero, or otherwise invalid.",
    message: "sortrows: invalid column index",
};

/// Raised when the `MissingPlacement` value is not `auto`, `first`, or `last`.
const SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.MISSING_PLACEMENT_UNKNOWN",
    identifier: Some("RunMat:sortrows:MissingPlacementUnknown"),
    when: "MissingPlacement option value is unsupported.",
    message: "sortrows: unsupported MissingPlacement value",
};

/// Raised for unrecognised trailing arguments or malformed name-value pairs.
const SORTROWS_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.INVALID_ARGUMENT",
    identifier: Some("RunMat:sortrows:InvalidArgument"),
    when: "Option parsing receives invalid argument kinds or malformed name-value pairs.",
    message: "sortrows: invalid argument",
};

/// Raised when the `ComparisonMethod` value is not recognised.
const SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.COMPARISON_METHOD_UNKNOWN",
    identifier: Some("RunMat:sortrows:ComparisonMethodUnknown"),
    when: "ComparisonMethod option value is unsupported.",
    message: "sortrows: unsupported ComparisonMethod value",
};

/// Raised by `sortrows_host` for value kinds it has no row-sort path for.
const SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.UNSUPPORTED_INPUT_TYPE",
    identifier: Some("RunMat:sortrows:UnsupportedInputType"),
    when: "Input cannot be converted to numeric, logical, complex, or char matrix domain.",
    message: "sortrows: unsupported input type",
};

/// Raised by `ensure_matrix_shape` when the shape has more than two dimensions.
const SORTROWS_ERROR_MATRIX_REQUIRED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.MATRIX_REQUIRED",
    identifier: Some("RunMat:sortrows:MatrixRequired"),
    when: "Input has rank greater than 2 where matrix input is required.",
    message: "sortrows: input must be a 2-D matrix",
};

/// Raised when an internal tensor construction/conversion step fails.
const SORTROWS_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SORTROWS.INTERNAL",
    identifier: Some("RunMat:sortrows:Internal"),
    when: "Internal conversion/allocation/provider decode fails.",
    message: "sortrows: internal operation failed",
};

/// Every error descriptor `sortrows` may raise, published through `SORTROWS_DESCRIPTOR`.
const SORTROWS_ERRORS: [BuiltinErrorDescriptor; 7] = [
    SORTROWS_ERROR_INVALID_COLUMN_INDEX,
    SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN,
    SORTROWS_ERROR_INVALID_ARGUMENT,
    SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN,
    SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE,
    SORTROWS_ERROR_MATRIX_REQUIRED,
    SORTROWS_ERROR_INTERNAL,
];
331
/// Public builtin descriptor for `sortrows`: signatures, output mode, and errors.
///
/// `ByRequestedOutputCount` matches `sortrows_builtin`, which returns `B` or
/// `[B, I]` depending on how many outputs the caller requested.
pub const SORTROWS_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &SORTROWS_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &SORTROWS_ERRORS,
};
338
339fn sortrows_error_with(
340 error: &'static BuiltinErrorDescriptor,
341 message: impl Into<String>,
342) -> crate::RuntimeError {
343 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
344 if let Some(identifier) = error.identifier {
345 builder = builder.with_identifier(identifier);
346 }
347 builder.build()
348}
349
350fn sortrows_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
351 sortrows_error_with(error, error.message)
352}
353
354fn sortrows_internal_error(message: impl Into<String>) -> crate::RuntimeError {
355 sortrows_error_with(&SORTROWS_ERROR_INTERNAL, message)
356}
357
358#[runtime_builtin(
359 name = "sortrows",
360 category = "array/sorting_sets",
361 summary = "Sort matrix rows lexicographically with column and direction controls.",
362 keywords = "sortrows,row sort,lexicographic,gpu",
363 accel = "sink",
364 sink = true,
365 type_resolver(tensor_output_type),
366 descriptor(crate::builtins::array::sorting_sets::sortrows::SORTROWS_DESCRIPTOR),
367 builtin_path = "crate::builtins::array::sorting_sets::sortrows"
368)]
369async fn sortrows_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
370 let eval = evaluate(value, &rest).await?;
371 if let Some(out_count) = crate::output_count::current_output_count() {
372 if out_count == 0 {
373 return Ok(Value::OutputList(Vec::new()));
374 }
375 let (sorted, indices) = eval.into_values();
376 let mut outputs = vec![sorted];
377 if out_count >= 2 {
378 outputs.push(indices);
379 }
380 return Ok(crate::output_count::output_list_with_padding(
381 out_count, outputs,
382 ));
383 }
384 Ok(eval.into_sorted_value())
385}
386
387pub async fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
389 match value {
390 Value::GpuTensor(handle) => sortrows_gpu(handle, rest).await,
391 other => sortrows_host(other, rest),
392 }
393}
394
395async fn sortrows_gpu(
396 handle: GpuTensorHandle,
397 rest: &[Value],
398) -> crate::BuiltinResult<SortRowsEvaluation> {
399 ensure_matrix_shape(&handle.shape)?;
400 let (_, cols) = rows_cols_from_shape(&handle.shape);
401 let args = SortRowsArgs::parse(rest, cols)?;
402
403 if args.missing_is_auto() {
404 if let Some(provider) = runmat_accelerate_api::provider() {
405 let provider_columns = args.to_provider_columns();
406 let provider_comparison = args.provider_comparison();
407 match provider
408 .sort_rows(&handle, &provider_columns, provider_comparison)
409 .await
410 {
411 Ok(result) => return sortrows_from_provider_result(result),
412 Err(_err) => {
413 }
415 }
416 }
417 }
418
419 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
420 sortrows_real_tensor_with_args(tensor, &args)
421}
422
423fn sortrows_from_provider_result(
424 result: ProviderSortResult,
425) -> crate::BuiltinResult<SortRowsEvaluation> {
426 let sorted_tensor = Tensor::new(result.values.data, result.values.shape)
427 .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
428 let indices_tensor = Tensor::new(result.indices.data, result.indices.shape)
429 .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
430 Ok(SortRowsEvaluation {
431 sorted: tensor::tensor_into_value(sorted_tensor),
432 indices: indices_tensor,
433 })
434}
435
436fn sortrows_host(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
437 match value {
438 Value::Tensor(tensor) => sortrows_real_tensor(tensor, rest),
439 Value::LogicalArray(logical) => {
440 let tensor = tensor::logical_to_tensor(&logical)
441 .map_err(|e| sortrows_internal_error(e))?;
442 sortrows_real_tensor(tensor, rest)
443 }
444 Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
445 let tensor = tensor::value_into_tensor_for("sortrows", value)
446 .map_err(|e| sortrows_internal_error(e))?;
447 sortrows_real_tensor(tensor, rest)
448 }
449 Value::ComplexTensor(ct) => sortrows_complex_tensor(ct, rest),
450 Value::Complex(re, im) => {
451 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
452 .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))?;
453 sortrows_complex_tensor(tensor, rest)
454 }
455 Value::CharArray(ca) => sortrows_char_array(ca, rest),
456 Value::Object(obj) if obj.is_class(crate::builtins::table::TABLE_CLASS) => {
457 let (sorted, indices) =
458 crate::builtins::table::sortrows_table(Value::Object(obj), rest)?;
459 Ok(SortRowsEvaluation::from_parts(sorted, indices))
460 }
461 other => Err(sortrows_error_with(
462 &SORTROWS_ERROR_UNSUPPORTED_INPUT_TYPE,
463 format!(
464 "sortrows: unsupported input type {:?}; expected numeric, logical, complex, or char arrays",
465 other
466 ),
467 )
468 .into()),
469 }
470}
471
472fn sortrows_real_tensor(
473 tensor: Tensor,
474 rest: &[Value],
475) -> crate::BuiltinResult<SortRowsEvaluation> {
476 ensure_matrix_shape(&tensor.shape)?;
477 let cols = tensor.cols();
478 let args = SortRowsArgs::parse(rest, cols)?;
479 sortrows_real_tensor_with_args(tensor, &args)
480}
481
482fn sortrows_real_tensor_with_args(
483 tensor: Tensor,
484 args: &SortRowsArgs,
485) -> crate::BuiltinResult<SortRowsEvaluation> {
486 let rows = tensor.rows();
487 let cols = tensor.cols();
488
489 if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
490 let indices = identity_indices(rows)?;
491 return Ok(SortRowsEvaluation {
492 sorted: tensor::tensor_into_value(tensor),
493 indices,
494 });
495 }
496
497 let mut order: Vec<usize> = (0..rows).collect();
498 order.sort_by(|&a, &b| compare_real_rows(&tensor, rows, args, a, b));
499
500 let sorted_tensor = reorder_real_rows(&tensor, rows, cols, &order)?;
501 let indices = permutation_indices(&order)?;
502 Ok(SortRowsEvaluation {
503 sorted: tensor::tensor_into_value(sorted_tensor),
504 indices,
505 })
506}
507
508fn sortrows_complex_tensor(
509 tensor: ComplexTensor,
510 rest: &[Value],
511) -> crate::BuiltinResult<SortRowsEvaluation> {
512 ensure_matrix_shape(&tensor.shape)?;
513 let cols = tensor.cols;
514 let args = SortRowsArgs::parse(rest, cols)?;
515 sortrows_complex_tensor_with_args(tensor, &args)
516}
517
518fn sortrows_complex_tensor_with_args(
519 tensor: ComplexTensor,
520 args: &SortRowsArgs,
521) -> crate::BuiltinResult<SortRowsEvaluation> {
522 let rows = tensor.rows;
523 let cols = tensor.cols;
524
525 if rows <= 1 || cols == 0 || tensor.data.is_empty() || args.columns.is_empty() {
526 let indices = identity_indices(rows)?;
527 return Ok(SortRowsEvaluation {
528 sorted: complex_tensor_into_value(tensor),
529 indices,
530 });
531 }
532
533 let mut order: Vec<usize> = (0..rows).collect();
534 order.sort_by(|&a, &b| compare_complex_rows(&tensor, rows, args, a, b));
535
536 let sorted_tensor = reorder_complex_rows(&tensor, rows, cols, &order)?;
537 let indices = permutation_indices(&order)?;
538 Ok(SortRowsEvaluation {
539 sorted: complex_tensor_into_value(sorted_tensor),
540 indices,
541 })
542}
543
544fn sortrows_char_array(ca: CharArray, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
545 let cols = ca.cols;
546 let args = SortRowsArgs::parse(rest, cols)?;
547 sortrows_char_array_with_args(ca, &args)
548}
549
550fn sortrows_char_array_with_args(
551 ca: CharArray,
552 args: &SortRowsArgs,
553) -> crate::BuiltinResult<SortRowsEvaluation> {
554 let rows = ca.rows;
555 let cols = ca.cols;
556
557 if rows <= 1 || cols == 0 || ca.data.is_empty() || args.columns.is_empty() {
558 let indices = identity_indices(rows)?;
559 return Ok(SortRowsEvaluation {
560 sorted: Value::CharArray(ca),
561 indices,
562 });
563 }
564
565 let mut order: Vec<usize> = (0..rows).collect();
566 order.sort_by(|&a, &b| compare_char_rows(&ca, args, a, b));
567
568 let sorted = reorder_char_rows(&ca, rows, cols, &order)?;
569 let indices = permutation_indices(&order)?;
570 Ok(SortRowsEvaluation {
571 sorted: Value::CharArray(sorted),
572 indices,
573 })
574}
575
576fn ensure_matrix_shape(shape: &[usize]) -> crate::BuiltinResult<()> {
577 if shape.len() <= 2 {
578 Ok(())
579 } else {
580 Err(sortrows_error(&SORTROWS_ERROR_MATRIX_REQUIRED))
581 }
582}
583
/// Interprets a shape as `(rows, cols)`.
///
/// An empty shape is a 1x1 scalar and a 1-D shape `[n]` is a 1xn row.
/// Otherwise the first two extents are used.
fn rows_cols_from_shape(shape: &[usize]) -> (usize, usize) {
    match *shape {
        [] => (1, 1),
        [len] => (1, len),
        [rows, cols, ..] => (rows, cols),
    }
}
591
592fn compare_real_rows(
593 tensor: &Tensor,
594 rows: usize,
595 args: &SortRowsArgs,
596 a: usize,
597 b: usize,
598) -> Ordering {
599 for spec in &args.columns {
600 if spec.index >= tensor.cols() {
601 continue;
602 }
603 let idx_a = a + spec.index * rows;
604 let idx_b = b + spec.index * rows;
605 let va = tensor.data[idx_a];
606 let vb = tensor.data[idx_b];
607 let missing = args.missing_for_direction(spec.direction);
608 let ord = compare_real_scalars(va, vb, spec.direction, args.comparison, missing);
609 if ord != Ordering::Equal {
610 return ord;
611 }
612 }
613 Ordering::Equal
614}
615
616fn compare_complex_rows(
617 tensor: &ComplexTensor,
618 rows: usize,
619 args: &SortRowsArgs,
620 a: usize,
621 b: usize,
622) -> Ordering {
623 for spec in &args.columns {
624 if spec.index >= tensor.cols {
625 continue;
626 }
627 let idx_a = a + spec.index * rows;
628 let idx_b = b + spec.index * rows;
629 let va = tensor.data[idx_a];
630 let vb = tensor.data[idx_b];
631 let missing = args.missing_for_direction(spec.direction);
632 let ord = compare_complex_scalars(va, vb, spec.direction, args.comparison, missing);
633 if ord != Ordering::Equal {
634 return ord;
635 }
636 }
637 Ordering::Equal
638}
639
640fn compare_char_rows(ca: &CharArray, args: &SortRowsArgs, a: usize, b: usize) -> Ordering {
641 for spec in &args.columns {
642 if spec.index >= ca.cols {
643 continue;
644 }
645 let idx_a = a * ca.cols + spec.index;
646 let idx_b = b * ca.cols + spec.index;
647 let va = ca.data[idx_a];
648 let vb = ca.data[idx_b];
649 let ord = match spec.direction {
650 SortDirection::Ascend => va.cmp(&vb),
651 SortDirection::Descend => vb.cmp(&va),
652 };
653 if ord != Ordering::Equal {
654 return ord;
655 }
656 }
657 Ordering::Equal
658}
659
660fn reorder_real_rows(
661 tensor: &Tensor,
662 rows: usize,
663 cols: usize,
664 order: &[usize],
665) -> crate::BuiltinResult<Tensor> {
666 let mut data = vec![0.0; tensor.data.len()];
667 for col in 0..cols {
668 for (dest_row, &src_row) in order.iter().enumerate() {
669 let src_idx = src_row + col * rows;
670 let dst_idx = dest_row + col * rows;
671 data[dst_idx] = tensor.data[src_idx];
672 }
673 }
674 Tensor::new(data, tensor.shape.clone())
675 .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
676}
677
678fn reorder_complex_rows(
679 tensor: &ComplexTensor,
680 rows: usize,
681 cols: usize,
682 order: &[usize],
683) -> crate::BuiltinResult<ComplexTensor> {
684 let mut data = vec![(0.0, 0.0); tensor.data.len()];
685 for col in 0..cols {
686 for (dest_row, &src_row) in order.iter().enumerate() {
687 let src_idx = src_row + col * rows;
688 let dst_idx = dest_row + col * rows;
689 data[dst_idx] = tensor.data[src_idx];
690 }
691 }
692 ComplexTensor::new(data, tensor.shape.clone())
693 .map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
694}
695
696fn reorder_char_rows(
697 ca: &CharArray,
698 rows: usize,
699 cols: usize,
700 order: &[usize],
701) -> crate::BuiltinResult<CharArray> {
702 let mut data = vec!['\0'; ca.data.len()];
703 for (dest_row, &src_row) in order.iter().enumerate() {
704 for col in 0..cols {
705 let src_idx = src_row * cols + col;
706 let dst_idx = dest_row * cols + col;
707 data[dst_idx] = ca.data[src_idx];
708 }
709 }
710 CharArray::new(data, rows, cols).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
711}
712
713fn compare_real_scalars(
714 a: f64,
715 b: f64,
716 direction: SortDirection,
717 comparison: ComparisonMethod,
718 missing: MissingPlacementResolved,
719) -> Ordering {
720 match (a.is_nan(), b.is_nan()) {
721 (true, true) => Ordering::Equal,
722 (true, false) => match missing {
723 MissingPlacementResolved::First => Ordering::Less,
724 MissingPlacementResolved::Last => Ordering::Greater,
725 },
726 (false, true) => match missing {
727 MissingPlacementResolved::First => Ordering::Greater,
728 MissingPlacementResolved::Last => Ordering::Less,
729 },
730 (false, false) => compare_real_finite_scalars(a, b, direction, comparison),
731 }
732}
733
734fn compare_real_finite_scalars(
735 a: f64,
736 b: f64,
737 direction: SortDirection,
738 comparison: ComparisonMethod,
739) -> Ordering {
740 if matches!(comparison, ComparisonMethod::Abs) {
741 let abs_cmp = a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal);
742 if abs_cmp != Ordering::Equal {
743 return match direction {
744 SortDirection::Ascend => abs_cmp,
745 SortDirection::Descend => abs_cmp.reverse(),
746 };
747 }
748 }
749 match direction {
750 SortDirection::Ascend => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
751 SortDirection::Descend => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
752 }
753}
754
755fn compare_complex_scalars(
756 a: (f64, f64),
757 b: (f64, f64),
758 direction: SortDirection,
759 comparison: ComparisonMethod,
760 missing: MissingPlacementResolved,
761) -> Ordering {
762 match (complex_is_nan(a), complex_is_nan(b)) {
763 (true, true) => Ordering::Equal,
764 (true, false) => match missing {
765 MissingPlacementResolved::First => Ordering::Less,
766 MissingPlacementResolved::Last => Ordering::Greater,
767 },
768 (false, true) => match missing {
769 MissingPlacementResolved::First => Ordering::Greater,
770 MissingPlacementResolved::Last => Ordering::Less,
771 },
772 (false, false) => compare_complex_finite_scalars(a, b, direction, comparison),
773 }
774}
775
776fn compare_complex_finite_scalars(
777 a: (f64, f64),
778 b: (f64, f64),
779 direction: SortDirection,
780 comparison: ComparisonMethod,
781) -> Ordering {
782 match comparison {
783 ComparisonMethod::Real => compare_complex_real_first(a, b, direction),
784 ComparisonMethod::Auto | ComparisonMethod::Abs => {
785 let abs_cmp = complex_abs(a)
786 .partial_cmp(&complex_abs(b))
787 .unwrap_or(Ordering::Equal);
788 if abs_cmp != Ordering::Equal {
789 return match direction {
790 SortDirection::Ascend => abs_cmp,
791 SortDirection::Descend => abs_cmp.reverse(),
792 };
793 }
794 compare_complex_real_first(a, b, direction)
795 }
796 }
797}
798
799fn compare_complex_real_first(a: (f64, f64), b: (f64, f64), direction: SortDirection) -> Ordering {
800 let real_cmp = match direction {
801 SortDirection::Ascend => a.0.partial_cmp(&b.0),
802 SortDirection::Descend => b.0.partial_cmp(&a.0),
803 }
804 .unwrap_or(Ordering::Equal);
805 if real_cmp != Ordering::Equal {
806 return real_cmp;
807 }
808 match direction {
809 SortDirection::Ascend => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
810 SortDirection::Descend => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
811 }
812}
813
/// Returns `true` when either the real or imaginary component is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
817
/// Magnitude of a complex value, computed with `hypot` to avoid intermediate
/// overflow/underflow.
fn complex_abs(value: (f64, f64)) -> f64 {
    let (re, im) = value;
    f64::hypot(re, im)
}
821
822fn permutation_indices(order: &[usize]) -> crate::BuiltinResult<Tensor> {
823 let rows = order.len();
824 let mut data = Vec::with_capacity(rows);
825 for &idx in order {
826 data.push((idx + 1) as f64);
827 }
828 Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
829}
830
831fn identity_indices(rows: usize) -> crate::BuiltinResult<Tensor> {
832 let mut data = Vec::with_capacity(rows);
833 for i in 0..rows {
834 data.push((i + 1) as f64);
835 }
836 Tensor::new(data, vec![rows, 1]).map_err(|e| sortrows_internal_error(format!("sortrows: {e}")))
837}
838
839fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
840 if tensor.data.len() == 1 {
841 Value::Complex(tensor.data[0].0, tensor.data[0].1)
842 } else {
843 Value::ComplexTensor(tensor)
844 }
845}
846
/// Ordering requested for a sort key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SortDirection {
    Ascend,
    Descend,
}

impl SortDirection {
    /// Parses a direction keyword.
    ///
    /// Accepts `ascend`/`ascending` and `descend`/`descending`,
    /// case-insensitively and ignoring surrounding whitespace. Returns `None`
    /// for anything else.
    fn from_str(value: &str) -> Option<Self> {
        let keyword = value.trim().to_ascii_lowercase();
        if keyword == "ascend" || keyword == "ascending" {
            Some(Self::Ascend)
        } else if keyword == "descend" || keyword == "descending" {
            Some(Self::Descend)
        } else {
            None
        }
    }
}
862
/// Comparison rule selected by the `ComparisonMethod` option.
///
/// For complex values, `Auto` and `Abs` compare magnitudes first. For real
/// values, only `Abs` compares magnitudes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ComparisonMethod {
    /// `'auto'` (the default).
    Auto,
    /// `'real'`: real part first, then imaginary part for complex input.
    Real,
    /// `'abs'` or `'magnitude'`.
    Abs,
}
869
/// NaN placement requested through the `MissingPlacement` option.
///
/// `Auto` depends on the sort direction and is turned into a concrete
/// placement by `MissingPlacement::resolve`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MissingPlacement {
    /// `'auto'`: last when ascending, first when descending.
    Auto,
    /// `'first'`: NaNs before all non-NaN values.
    First,
    /// `'last'`: NaNs after all non-NaN values.
    Last,
}
876
/// Concrete NaN placement for a single sort key, after `Auto` is resolved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MissingPlacementResolved {
    /// NaN values sort before non-NaN values.
    First,
    /// NaN values sort after non-NaN values.
    Last,
}
882
883impl MissingPlacement {
884 fn resolve(self, direction: SortDirection) -> MissingPlacementResolved {
885 match self {
886 MissingPlacement::First => MissingPlacementResolved::First,
887 MissingPlacement::Last => MissingPlacementResolved::Last,
888 MissingPlacement::Auto => match direction {
889 SortDirection::Ascend => MissingPlacementResolved::Last,
890 SortDirection::Descend => MissingPlacementResolved::First,
891 },
892 }
893 }
894
895 fn is_auto(self) -> bool {
896 matches!(self, MissingPlacement::Auto)
897 }
898}
899
/// One sort key: a zero-based column index and its sort direction.
#[derive(Debug, Clone)]
struct ColumnSpec {
    /// Zero-based column index; comparators skip indices not below the column count.
    index: usize,
    /// Direction in which this column is compared.
    direction: SortDirection,
}
905
/// Parsed trailing arguments for `sortrows`.
#[derive(Debug, Clone)]
struct SortRowsArgs {
    /// Sort keys, applied in order (first non-equal key decides).
    columns: Vec<ColumnSpec>,
    /// Comparison rule from `ComparisonMethod` (default `Auto`).
    comparison: ComparisonMethod,
    /// NaN placement from `MissingPlacement` (default `Auto`).
    missing: MissingPlacement,
}
912
913impl SortRowsArgs {
914 fn parse(rest: &[Value], num_cols: usize) -> crate::BuiltinResult<Self> {
915 let mut columns: Option<Vec<ColumnSpec>> = None;
916 let mut override_direction: Option<SortDirection> = None;
917 let mut comparison = ComparisonMethod::Auto;
918 let mut missing = MissingPlacement::Auto;
919 let mut i = 0usize;
920
921 while i < rest.len() {
922 if columns.is_none() {
923 if let Some(parsed) = parse_column_vector(&rest[i], num_cols)? {
924 columns = Some(parsed);
925 i += 1;
926 continue;
927 }
928 }
929 if let Some(direction) = parse_direction(&rest[i]) {
930 override_direction = Some(direction);
931 i += 1;
932 continue;
933 }
934 let Some(keyword) = tensor::value_to_string(&rest[i]) else {
935 return Err(sortrows_error_with(
936 &SORTROWS_ERROR_INVALID_ARGUMENT,
937 format!("sortrows: invalid argument {:?}", rest[i]),
938 ));
939 };
940 let lowered = keyword.trim().to_ascii_lowercase();
941 match lowered.as_str() {
942 "comparisonmethod" => {
943 i += 1;
944 if i >= rest.len() {
945 return Err(sortrows_error_with(
946 &SORTROWS_ERROR_INVALID_ARGUMENT,
947 "sortrows: expected a value for 'ComparisonMethod'",
948 ));
949 }
950 let Some(value_str) = tensor::value_to_string(&rest[i]) else {
951 return Err(sortrows_error_with(
952 &SORTROWS_ERROR_INVALID_ARGUMENT,
953 "sortrows: 'ComparisonMethod' expects a string value",
954 )
955 .into());
956 };
957 comparison = match value_str.trim().to_ascii_lowercase().as_str() {
958 "auto" => ComparisonMethod::Auto,
959 "real" => ComparisonMethod::Real,
960 "abs" | "magnitude" => ComparisonMethod::Abs,
961 other => {
962 return Err(sortrows_error_with(
963 &SORTROWS_ERROR_COMPARISON_METHOD_UNKNOWN,
964 format!("sortrows: unsupported ComparisonMethod '{other}'"),
965 )
966 .into())
967 }
968 };
969 i += 1;
970 }
971 "missingplacement" => {
972 i += 1;
973 if i >= rest.len() {
974 return Err(sortrows_error_with(
975 &SORTROWS_ERROR_INVALID_ARGUMENT,
976 "sortrows: expected a value for 'MissingPlacement'",
977 )
978 .into());
979 }
980 let Some(value_str) = tensor::value_to_string(&rest[i]) else {
981 return Err(sortrows_error_with(
982 &SORTROWS_ERROR_INVALID_ARGUMENT,
983 "sortrows: 'MissingPlacement' expects a string value",
984 )
985 .into());
986 };
987 missing = match value_str.trim().to_ascii_lowercase().as_str() {
988 "auto" => MissingPlacement::Auto,
989 "first" => MissingPlacement::First,
990 "last" => MissingPlacement::Last,
991 other => {
992 return Err(sortrows_error_with(
993 &SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN,
994 format!("sortrows: unsupported MissingPlacement '{other}'"),
995 )
996 .into())
997 }
998 };
999 i += 1;
1000 }
1001 other => {
1002 return Err(sortrows_error_with(
1003 &SORTROWS_ERROR_INVALID_ARGUMENT,
1004 format!("sortrows: unexpected argument '{other}'"),
1005 ));
1006 }
1007 }
1008 }
1009
1010 let mut columns = columns.unwrap_or_else(|| default_columns(num_cols));
1011 if let Some(dir) = override_direction {
1012 for spec in &mut columns {
1013 spec.direction = dir;
1014 }
1015 }
1016 validate_columns(&columns, num_cols)?;
1017
1018 Ok(SortRowsArgs {
1019 columns,
1020 comparison,
1021 missing,
1022 })
1023 }
1024
1025 fn to_provider_columns(&self) -> Vec<ProviderSortRowsColumnSpec> {
1026 self.columns
1027 .iter()
1028 .map(|spec| ProviderSortRowsColumnSpec {
1029 index: spec.index,
1030 order: match spec.direction {
1031 SortDirection::Ascend => ProviderSortOrder::Ascend,
1032 SortDirection::Descend => ProviderSortOrder::Descend,
1033 },
1034 })
1035 .collect()
1036 }
1037
1038 fn provider_comparison(&self) -> ProviderSortComparison {
1039 match self.comparison {
1040 ComparisonMethod::Auto => ProviderSortComparison::Auto,
1041 ComparisonMethod::Real => ProviderSortComparison::Real,
1042 ComparisonMethod::Abs => ProviderSortComparison::Abs,
1043 }
1044 }
1045
1046 fn missing_for_direction(&self, direction: SortDirection) -> MissingPlacementResolved {
1047 self.missing.resolve(direction)
1048 }
1049
1050 fn missing_is_auto(&self) -> bool {
1051 self.missing.is_auto()
1052 }
1053}
1054
1055fn parse_column_vector(
1056 value: &Value,
1057 num_cols: usize,
1058) -> crate::BuiltinResult<Option<Vec<ColumnSpec>>> {
1059 match value {
1060 Value::Int(i) => parse_single_column(i.to_i64(), num_cols).map(Some),
1061 Value::Num(n) => {
1062 if !n.is_finite() {
1063 return Err(sortrows_error_with(
1064 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1065 "sortrows: column indices must be finite",
1066 ));
1067 }
1068 let rounded = n.round();
1069 if (rounded - n).abs() > f64::EPSILON {
1070 return Err(sortrows_error_with(
1071 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1072 "sortrows: column indices must be integers",
1073 ));
1074 }
1075 parse_single_column(rounded as i64, num_cols).map(Some)
1076 }
1077 Value::Tensor(tensor) => {
1078 if !is_vector(&tensor.shape) {
1079 return Err(sortrows_error_with(
1080 &SORTROWS_ERROR_INVALID_ARGUMENT,
1081 "sortrows: column specification must be a vector",
1082 ));
1083 }
1084 let mut specs = Vec::with_capacity(tensor.data.len());
1085 for &entry in &tensor.data {
1086 if !entry.is_finite() {
1087 return Err(sortrows_error_with(
1088 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1089 "sortrows: column indices must be finite",
1090 ));
1091 }
1092 let rounded = entry.round();
1093 if (rounded - entry).abs() > f64::EPSILON {
1094 return Err(sortrows_error_with(
1095 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1096 "sortrows: column indices must be integers",
1097 ));
1098 }
1099 let column = parse_single_column_i64(rounded as i64, num_cols)?;
1100 specs.push(column);
1101 }
1102 Ok(Some(specs))
1103 }
1104 _ => Ok(None),
1105 }
1106}
1107
1108fn parse_single_column(value: i64, num_cols: usize) -> crate::BuiltinResult<Vec<ColumnSpec>> {
1109 parse_single_column_i64(value, num_cols).map(|spec| vec![spec])
1110}
1111
1112fn parse_single_column_i64(value: i64, num_cols: usize) -> crate::BuiltinResult<ColumnSpec> {
1113 if value == 0 {
1114 return Err(sortrows_error_with(
1115 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1116 "sortrows: column indices must be non-zero",
1117 ));
1118 }
1119 let abs = value.unsigned_abs() as usize;
1120 if abs == 0 {
1121 return Err(sortrows_error_with(
1122 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1123 "sortrows: column indices must be >= 1",
1124 ));
1125 }
1126 if num_cols == 0 {
1127 return Err(sortrows_error_with(
1128 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1129 "sortrows: column index exceeds matrix with 0 columns",
1130 ));
1131 }
1132 if abs > num_cols {
1133 return Err(sortrows_error_with(
1134 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1135 format!(
1136 "sortrows: column index {} exceeds matrix with {} columns",
1137 abs, num_cols
1138 ),
1139 )
1140 .into());
1141 }
1142 let direction = if value > 0 {
1143 SortDirection::Ascend
1144 } else {
1145 SortDirection::Descend
1146 };
1147 Ok(ColumnSpec {
1148 index: abs - 1,
1149 direction,
1150 })
1151}
1152
1153fn parse_direction(value: &Value) -> Option<SortDirection> {
1154 tensor::value_to_string(value).and_then(|s| SortDirection::from_str(&s))
1155}
1156
1157fn default_columns(num_cols: usize) -> Vec<ColumnSpec> {
1158 let mut columns = Vec::with_capacity(num_cols);
1159 for col in 0..num_cols {
1160 columns.push(ColumnSpec {
1161 index: col,
1162 direction: SortDirection::Ascend,
1163 });
1164 }
1165 columns
1166}
1167
1168fn validate_columns(columns: &[ColumnSpec], num_cols: usize) -> crate::BuiltinResult<()> {
1169 if num_cols == 0 && columns.iter().any(|spec| spec.index > 0) {
1170 return Err(sortrows_error_with(
1171 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1172 "sortrows: column index exceeds matrix with 0 columns",
1173 ));
1174 }
1175 for spec in columns {
1176 if num_cols > 0 && spec.index >= num_cols {
1177 return Err(sortrows_error_with(
1178 &SORTROWS_ERROR_INVALID_COLUMN_INDEX,
1179 format!(
1180 "sortrows: column index {} exceeds matrix with {} columns",
1181 spec.index + 1,
1182 num_cols
1183 ),
1184 )
1185 .into());
1186 }
1187 }
1188 Ok(())
1189}
1190
/// Returns `true` when `shape` describes a vector.
///
/// Rank-0 and rank-1 shapes always qualify. A 2-D shape qualifies when either dimension
/// is 1. Shapes of rank 3 or more never qualify, even if all but one dimension is 1.
fn is_vector(shape: &[usize]) -> bool {
    match shape {
        [] | [_] => true,
        [rows, cols] => *rows == 1 || *cols == 1,
        _ => false,
    }
}
1199
1200#[derive(Debug)]
1201pub struct SortRowsEvaluation {
1202 sorted: Value,
1203 indices: Tensor,
1204}
1205
1206impl SortRowsEvaluation {
1207 pub(crate) fn from_parts(sorted: Value, indices: Tensor) -> Self {
1208 Self { sorted, indices }
1209 }
1210
1211 pub fn into_sorted_value(self) -> Value {
1212 self.sorted
1213 }
1214
1215 pub fn into_values(self) -> (Value, Value) {
1216 let indices = tensor::tensor_into_value(self.indices);
1217 (self.sorted, indices)
1218 }
1219
1220 pub fn indices_value(&self) -> Value {
1221 tensor::tensor_into_value(self.indices.clone())
1222 }
1223}
1224
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{IntValue, ResolveContext, Type, Value};

    /// Drives the async `super::evaluate` entry point to completion for synchronous tests.
    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<SortRowsEvaluation> {
        futures::executor::block_on(super::evaluate(value, rest))
    }

    // Default call: rows ordered by column 1, ties broken by column 2.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_default_matrix() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_type_resolver_tensor() {
        assert_eq!(
            tensor_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    // Explicit column order [2, 3, 1].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_with_column_vector() {
        let tensor = Tensor::new(
            vec![1.0, 3.0, 3.0, 4.0, 2.0, 2.0, 2.0, 5.0, 1.0],
            vec![3, 3],
        )
        .unwrap();
        let cols = Tensor::new(vec![2.0, 3.0, 1.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![3.0, 3.0, 1.0, 2.0, 2.0, 4.0, 1.0, 5.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_direction_descend() {
        let tensor = Tensor::new(vec![1.0, 2.0, 4.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::from("descend")]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0, 3.0, 4.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Column 1 ascending (all ties), then column 2 descending via a negative index.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_mixed_directions() {
        let tensor = Tensor::new(vec![1.0, 1.0, 1.0, 1.0, 7.0, 2.0], vec![3, 2]).unwrap();
        let cols = Tensor::new(vec![1.0, -2.0], vec![2, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[Value::Tensor(cols)]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 1.0, 1.0, 7.0, 2.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_returns_indices() {
        let tensor = Tensor::new(vec![2.0, 1.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (_, indices) = eval.into_values();
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    // Rows are built row-major from three 4-character strings. "al" is space-padded to
    // the 4-column width so the 3x4 CharArray receives exactly 12 characters.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_char_array() {
        let chars = CharArray::new(
            "bob "
                .chars()
                .chain("al  ".chars())
                .chain("ally".chars())
                .collect(),
            3,
            4,
        )
        .unwrap();
        let eval = evaluate(Value::CharArray(chars), &[]).expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 3);
                assert_eq!(ca.cols, 4);
                let strings: Vec<String> = (0..ca.rows)
                    .map(|r| {
                        ca.data[r * ca.cols..(r + 1) * ca.cols]
                            .iter()
                            .collect::<String>()
                    })
                    .collect();
                assert_eq!(
                    strings,
                    vec!["al  ".to_string(), "ally".to_string(), "bob ".to_string()]
                );
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // |-2+1i| = sqrt(5) sorts before |1+2i| = sqrt(5)? Equal magnitudes; expected order
    // below is the implementation's tie resolution for this input.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_complex_abs() {
        let tensor = ComplexTensor::new(vec![(1.0, 2.0), (-2.0, 1.0)], vec![2, 1]).unwrap();
        let eval = evaluate(
            Value::ComplexTensor(tensor),
            &[Value::from("ComparisonMethod"), Value::from("abs")],
        )
        .expect("evaluate");
        let (sorted, _) = eval.into_values();
        match sorted {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.data, vec![(-2.0, 1.0), (1.0, 2.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_invalid_column_index_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(3))]).unwrap_err();
        assert_eq!(
            err.identifier(),
            SORTROWS_ERROR_INVALID_COLUMN_INDEX.identifier
        );
    }

    // Column index 0 is never valid (indices are 1-based and signed).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_zero_column_index_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(Value::Tensor(tensor), &[Value::Int(IntValue::I32(0))]).unwrap_err();
        assert_eq!(
            err.identifier(),
            SORTROWS_ERROR_INVALID_COLUMN_INDEX.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn is_vector_classifies_shapes() {
        assert!(is_vector(&[]));
        assert!(is_vector(&[5]));
        assert!(is_vector(&[1, 4]));
        assert!(is_vector(&[4, 1]));
        assert!(!is_vector(&[2, 3]));
        assert!(!is_vector(&[1, 1, 1]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_first_moves_nan_first() {
        let tensor = Tensor::new(vec![1.0, f64::NAN, 2.0, 3.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("first")],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert!(t.data[0].is_nan());
                assert_eq!(t.data[1], 1.0);
                assert_eq!(t.data[2], 3.0);
                assert_eq!(t.data[3], 2.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_last_descend_moves_nan_last() {
        let tensor = Tensor::new(vec![f64::NAN, 5.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(tensor),
            &[
                Value::from("descend"),
                Value::from("MissingPlacement"),
                Value::from("last"),
            ],
        )
        .expect("evaluate");
        let (sorted, indices) = eval.into_values();
        match sorted {
            Value::Tensor(t) => {
                assert_eq!(t.data[0], 5.0);
                assert!(t.data[1].is_nan());
                assert_eq!(t.data[2], 2.0);
                assert_eq!(t.data[3], 1.0);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 1.0]),
            Value::Num(_) => panic!("expected tensor indices"),
            other => panic!("unexpected indices {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_missingplacement_invalid_value_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate(
            Value::Tensor(tensor),
            &[Value::from("MissingPlacement"), Value::from("middle")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            SORTROWS_ERROR_MISSING_PLACEMENT_UNKNOWN.identifier
        );
    }

    // A GPU-resident input is sorted and returned as a host tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sortrows_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0, 1.0, 5.0], vec![3, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
            let (sorted, indices) = eval.into_values();
            match sorted {
                Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0, 3.0, 1.0, 5.0, 4.0]),
                other => panic!("expected tensor, got {other:?}"),
            }
            match indices {
                Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 3.0, 1.0]),
                other => panic!("unexpected indices {other:?}"),
            }
        });
    }

    // Cross-checks the wgpu provider path against the host implementation.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn sortrows_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![4.0, 2.0, 3.0, 1.0, 2.0, 5.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let (cpu_sorted_val, cpu_indices_val) = cpu_eval.into_values();
        let cpu_sorted = match cpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let cpu_indices = match cpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("provider");
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle.clone()), &[]).expect("gpu evaluate");
        let (gpu_sorted_val, gpu_indices_val) = gpu_eval.into_values();
        let gpu_sorted = match gpu_sorted_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        let gpu_indices = match gpu_indices_val {
            Value::Tensor(t) => t,
            other => panic!("expected tensor indices, got {other:?}"),
        };

        assert_eq!(gpu_sorted.shape, cpu_sorted.shape);
        assert_eq!(gpu_sorted.data, cpu_sorted.data);
        assert_eq!(gpu_indices.shape, cpu_indices.shape);
        assert_eq!(gpu_indices.data, cpu_indices.data);

        let _ = provider.free(&handle);
    }
}