1use std::cmp::Ordering;
9use std::collections::{hash_map::Entry, HashMap};
10
11use runmat_accelerate_api::{
12 GpuTensorHandle, HostTensorOwned, UnionOptions, UnionOrder, UnionResult,
13};
14use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
15use runmat_macros::runtime_builtin;
16
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::complex_tensor_into_value;
19use crate::builtins::common::spec::{
20 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24#[cfg(feature = "doc_export")]
25use crate::register_builtin_doc_text;
26use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
27
28#[cfg(feature = "doc_export")]
29pub const DOC_MD: &str = r#"---
30title: "union"
31category: "array/sorting_sets"
32keywords: ["union", "set", "stable", "rows", "indices", "gpu"]
33summary: "Combine two arrays, returning their union with MATLAB-compatible ordering and index outputs."
34references:
35 - https://www.mathworks.com/help/matlab/ref/union.html
36gpu_support:
37 elementwise: false
38 reduction: false
39 precisions: ["f32", "f64"]
40 broadcasting: "none"
41 notes: "When no provider hook is available, tensors are gathered to the host and processed with the CPU implementation."
42fusion:
43 elementwise: false
44 reduction: false
45 max_inputs: 2
46 constants: "inline"
47requires_feature: null
48tested:
49 unit: "builtins::array::sorting_sets::union::tests"
50 integration: "builtins::array::sorting_sets::union::tests::union_gpu_roundtrip"
51---
52
53# What does the `union` function do in MATLAB / RunMat?
54`union(A, B)` creates the set-theoretic union of inputs `A` and `B`, returning a value array `C`
55that contains every distinct element (or row) appearing in either input. Optional second and third
56outputs provide indices back into `A` and `B` that identify which original elements contribute to
57each value in `C`.
58
59## How does the `union` function behave in MATLAB / RunMat?
60- `union(A, B)` flattens the inputs column-major, removes duplicates, and returns the sorted union
61 of their values.
62- `[C, IA, IB] = union(A, B)` also returns index vectors where `IA` points to the elements in `A`
63 that contribute to `C` and `IB` refers to the elements in `B` that only appear in `B`.
64- `union(A, B, 'stable')` preserves the first-seen order: unique values from `A` appear first,
65 followed by new values appearing in `B`.
66- `union(A, B, 'rows')` treats each row as a record. Inputs must share the same number of columns.
67- Character arrays and string arrays are fully supported in both element and row modes.
68- Complex values follow MATLAB's ordering rules (magnitude first, then real/imaginary parts).
69- Legacy compatibility flags (`'legacy'` or `'R2012a'`) are not supported in RunMat.
70
71## GPU execution in RunMat
72- `union` is registered as a residency sink. When tensors live on a GPU and the active provider
73 does not expose a dedicated `union` hook, the runtime gathers them to host memory and executes
74 the CPU implementation to guarantee MATLAB-compatible output.
75- Future providers may implement `ProviderHook::Custom("union")` to keep data on the device.
76
77## Examples of using the `union` function in MATLAB / RunMat
78
79### Combining Two Numeric Vectors
80```matlab
81A = [5 7 1];
82B = [3 1 1];
83[C, IA, IB] = union(A, B);
84```
85Expected output:
86```matlab
87C =
88 1
89 3
90 5
91 7
92IA =
93 3
94 1
95 2
96IB =
97 1
98```
99
100### Preserving Input Order with `'stable'`
101```matlab
102A = [5 7 1];
103B = [3 2 4];
104C = union(A, B, 'stable');
105```
106Expected output:
107```matlab
108C =
109 5
110 7
111 1
112 3
113 2
114 4
115```
116
117### Union of Matrix Rows
118```matlab
119A = [1 2; 3 4; 1 2];
120B = [3 4; 5 6];
121[C, IA, IB] = union(A, B, 'rows');
122```
123Expected output:
124```matlab
125C =
126 1 2
127 3 4
128 5 6
129IA =
130 1
131 2
132IB =
133 2
134```
135
136### Working with Character Arrays
137```matlab
138A = ['m','z'; 'm','a'];
139B = ['a','x'; 'm','a'];
140union(A, B);
141```
142Expected output:
143```matlab
144ans =
145 4x1 char array
146 a
147 m
148 x
149 z
150```
151
152### Union on GPU Data
153```matlab
154G = gpuArray([1 4 2 5]);
155H = gpuArray([2 6]);
156[C, IA, IB] = union(G, H);
157```
158Expected output:
159```matlab
160C =
161 1
162 2
163 4
164 5
165 6
IA =
     1
     3
     2
     4
170IB =
171 2
172```
173
174## FAQ
175
176### How are the index outputs defined?
177`IA` lists the positions in `A` that contribute to `C`. `IB` lists the positions in `B` for union
178values that do not already appear in `A`. Both use MATLAB's one-based indexing.
179
180### What happens if I request `'stable'` ordering?
181The output keeps the first occurrence order: all unique values from `A` appear first, followed by
182new values from `B` in the order they are encountered.
183
184### Can I use `'rows'` with string or character arrays?
185Yes. `union` accepts string arrays and character arrays with the `'rows'` option, provided the inputs
186share the same number of columns.
187
188### How does `union` treat `NaN` values?
189All `NaN` values are considered equal. Sorted unions place them at the end, while stable unions keep
190the first-seen `NaN`.
191
192### Does GPU execution change the behaviour?
193No. When a provider does not implement a GPU kernel, RunMat automatically gathers inputs to the host,
194performs the union there, and returns host-resident outputs that match MATLAB semantics exactly.
195
196### What if the inputs have different shapes?
197Element-wise unions accept any shapes; both inputs are flattened column-major. Row-wise unions require
198matching column counts and two-dimensional inputs.
199
200### Can I mix data types?
201No. Inputs must belong to the same class (numeric/logical, complex, char, or string). Mixing classes
202produces a descriptive error.
203
204### Are scalar outputs returned as vectors?
205Scalars follow MATLAB's behaviour: a `1×1` union result is represented as a scalar double.
206
207### Do trailing spaces matter for character data?
208Yes. Character arrays treat every character, including trailing spaces, as significant—matching MATLAB's
209behaviour.
210
211### Is the `'legacy'` flag supported?
212No. RunMat only implements the modern MATLAB semantics. Passing `'legacy'` or `'R2012a'` raises an error.
213
214## See Also
215[unique](./unique), [sort](./sort), [sortrows](./sortrows)
216
217## Source & Feedback
218- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/union.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/union.rs)
219- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
220"#;
221
/// GPU dispatch metadata for the `union` builtin.
///
/// The spec names a custom `union` operation and a matching provider hook. Per the `notes`
/// field, tensors are gathered and processed on the host when a provider does not expose
/// that hook.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "union",
    op_kind: GpuOpKind::Custom("union"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    // Set operations do not broadcast their operands.
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("union")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may expose a dedicated union hook; otherwise tensors are gathered and processed on the host.",
};

// Hand the spec to the builtin GPU-spec registration macro.
register_builtin_gpu_spec!(GPU_SPEC);
238
/// Fusion metadata for the `union` builtin.
///
/// No elementwise or reduction fusion template is provided (`None` for both). Per the `notes`
/// field, `union` materialises its inputs and terminates fusion chains.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "union",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`union` materialises its inputs and terminates fusion chains; upstream GPU tensors are gathered when necessary.",
};

// Hand the spec to the builtin fusion-spec registration macro.
register_builtin_fusion_spec!(FUSION_SPEC);

// The Markdown documentation is only registered when the `doc_export` feature is enabled,
// matching the feature gate on `DOC_MD` itself.
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("union", DOC_MD);
253
254#[runtime_builtin(
255 name = "union",
256 category = "array/sorting_sets",
257 summary = "Combine two arrays, returning their union with MATLAB-compatible ordering and index outputs.",
258 keywords = "union,set,stable,rows,indices,gpu",
259 accel = "array_construct",
260 sink = true
261)]
262fn union_builtin(a: Value, b: Value, rest: Vec<Value>) -> Result<Value, String> {
263 evaluate(a, b, &rest).map(|eval| eval.into_values_value())
264}
265
266pub fn evaluate(a: Value, b: Value, rest: &[Value]) -> Result<UnionEvaluation, String> {
268 let opts = parse_options(rest)?;
269 match (a, b) {
270 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
271 union_gpu_pair(handle_a, handle_b, &opts)
272 }
273 (Value::GpuTensor(handle_a), other) => union_gpu_mixed(handle_a, other, &opts, true),
274 (other, Value::GpuTensor(handle_b)) => union_gpu_mixed(handle_b, other, &opts, false),
275 (left, right) => union_host(left, right, &opts),
276 }
277}
278
279fn parse_options(rest: &[Value]) -> Result<UnionOptions, String> {
280 let mut opts = UnionOptions {
281 rows: false,
282 order: UnionOrder::Sorted,
283 };
284 let mut seen_order: Option<UnionOrder> = None;
285
286 for arg in rest {
287 let text = tensor::value_to_string(arg)
288 .ok_or_else(|| "union: expected string option arguments".to_string())?;
289 let lowered = text.trim().to_ascii_lowercase();
290 match lowered.as_str() {
291 "rows" => opts.rows = true,
292 "sorted" => {
293 if let Some(prev) = seen_order {
294 if prev != UnionOrder::Sorted {
295 return Err("union: cannot combine 'sorted' with 'stable'".to_string());
296 }
297 }
298 seen_order = Some(UnionOrder::Sorted);
299 opts.order = UnionOrder::Sorted;
300 }
301 "stable" => {
302 if let Some(prev) = seen_order {
303 if prev != UnionOrder::Stable {
304 return Err("union: cannot combine 'sorted' with 'stable'".to_string());
305 }
306 }
307 seen_order = Some(UnionOrder::Stable);
308 opts.order = UnionOrder::Stable;
309 }
310 "legacy" | "r2012a" => {
311 return Err("union: the 'legacy' behaviour is not supported".to_string());
312 }
313 other => return Err(format!("union: unrecognised option '{other}'")),
314 }
315 }
316
317 Ok(opts)
318}
319
320fn union_gpu_pair(
321 handle_a: GpuTensorHandle,
322 handle_b: GpuTensorHandle,
323 opts: &UnionOptions,
324) -> Result<UnionEvaluation, String> {
325 if let Some(provider) = runmat_accelerate_api::provider() {
326 match provider.union(&handle_a, &handle_b, opts) {
327 Ok(result) => return UnionEvaluation::from_union_result(result),
328 Err(_) => {
329 }
331 }
332 }
333 let tensor_a = gpu_helpers::gather_tensor(&handle_a)?;
334 let tensor_b = gpu_helpers::gather_tensor(&handle_b)?;
335 union_numeric(tensor_a, tensor_b, opts)
336}
337
338fn union_gpu_mixed(
339 handle_gpu: GpuTensorHandle,
340 other: Value,
341 opts: &UnionOptions,
342 gpu_is_a: bool,
343) -> Result<UnionEvaluation, String> {
344 let tensor_gpu = gpu_helpers::gather_tensor(&handle_gpu)?;
345 let tensor_other = tensor::value_into_tensor_for("union", other)?;
346 if gpu_is_a {
347 union_numeric(tensor_gpu, tensor_other, opts)
348 } else {
349 union_numeric(tensor_other, tensor_gpu, opts)
350 }
351}
352
353fn union_host(a: Value, b: Value, opts: &UnionOptions) -> Result<UnionEvaluation, String> {
354 match (a, b) {
355 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => union_complex(at, bt, opts),
357 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
358 let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
359 .map_err(|e| format!("union: {e}"))?;
360 union_complex(at, bt, opts)
361 }
362 (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
363 let at = ComplexTensor::new(vec![(re, im)], vec![1, 1])
364 .map_err(|e| format!("union: {e}"))?;
365 union_complex(at, bt, opts)
366 }
367 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
368 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
369 .map_err(|e| format!("union: {e}"))?;
370 let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
371 .map_err(|e| format!("union: {e}"))?;
372 union_complex(at, bt, opts)
373 }
374
375 (Value::CharArray(ac), Value::CharArray(bc)) => union_char(ac, bc, opts),
377
378 (Value::StringArray(astring), Value::StringArray(bstring)) => {
380 union_string(astring, bstring, opts)
381 }
382 (Value::StringArray(astring), Value::String(b)) => {
383 let bstring =
384 StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("union: {e}"))?;
385 union_string(astring, bstring, opts)
386 }
387 (Value::String(a), Value::StringArray(bstring)) => {
388 let astring =
389 StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("union: {e}"))?;
390 union_string(astring, bstring, opts)
391 }
392 (Value::String(a), Value::String(b)) => {
393 let astring =
394 StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("union: {e}"))?;
395 let bstring =
396 StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("union: {e}"))?;
397 union_string(astring, bstring, opts)
398 }
399
400 (left, right) => {
402 let tensor_a = tensor::value_into_tensor_for("union", left)?;
403 let tensor_b = tensor::value_into_tensor_for("union", right)?;
404 union_numeric(tensor_a, tensor_b, opts)
405 }
406 }
407}
408
409fn union_numeric(a: Tensor, b: Tensor, opts: &UnionOptions) -> Result<UnionEvaluation, String> {
410 if opts.rows {
411 union_numeric_rows(a, b, opts)
412 } else {
413 union_numeric_elements(a, b, opts)
414 }
415}
416
/// Public entry point for running the host numeric union directly on two tensors.
///
/// This is a thin wrapper that forwards to the private `union_numeric`, so `opts.rows`
/// selects between the row-wise and element-wise implementations.
pub fn union_numeric_from_tensors(
    a: Tensor,
    b: Tensor,
    opts: &UnionOptions,
) -> Result<UnionEvaluation, String> {
    union_numeric(a, b, opts)
}
425
426fn union_numeric_elements(
427 a: Tensor,
428 b: Tensor,
429 opts: &UnionOptions,
430) -> Result<UnionEvaluation, String> {
431 let mut entries = Vec::<NumericUnionEntry>::new();
432 let mut map: HashMap<u64, usize> = HashMap::new();
433 let mut order_counter = 0usize;
434
435 for (idx, &value) in a.data.iter().enumerate() {
436 let key = canonicalize_f64(value);
437 match map.entry(key) {
438 Entry::Occupied(_) => {
439 }
441 Entry::Vacant(v) => {
442 let entry_idx = entries.len();
443 entries.push(NumericUnionEntry {
444 value,
445 a_index: Some(idx),
446 b_index: None,
447 order_rank: order_counter,
448 });
449 v.insert(entry_idx);
450 order_counter += 1;
451 }
452 }
453 }
454
455 for (idx, &value) in b.data.iter().enumerate() {
456 let key = canonicalize_f64(value);
457 match map.entry(key) {
458 Entry::Occupied(occ) => {
459 let entry = &mut entries[*occ.get()];
460 if entry.a_index.is_none() && entry.b_index.is_none() {
461 entry.b_index = Some(idx);
462 }
463 }
464 Entry::Vacant(v) => {
465 let entry_idx = entries.len();
466 entries.push(NumericUnionEntry {
467 value,
468 a_index: None,
469 b_index: Some(idx),
470 order_rank: order_counter,
471 });
472 v.insert(entry_idx);
473 order_counter += 1;
474 }
475 }
476 }
477
478 assemble_numeric_union(entries, opts)
479}
480
481fn union_numeric_rows(
482 a: Tensor,
483 b: Tensor,
484 opts: &UnionOptions,
485) -> Result<UnionEvaluation, String> {
486 if a.shape.len() != 2 || b.shape.len() != 2 {
487 return Err("union: 'rows' option requires 2-D numeric matrices".to_string());
488 }
489 if a.shape[1] != b.shape[1] {
490 return Err(
491 "union: inputs must have the same number of columns when using 'rows'".to_string(),
492 );
493 }
494 let rows_a = a.shape[0];
495 let cols = a.shape[1];
496 let rows_b = b.shape[0];
497
498 let mut entries = Vec::<NumericRowUnionEntry>::new();
499 let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
500 let mut order_counter = 0usize;
501
502 for r in 0..rows_a {
503 let mut row_values = Vec::with_capacity(cols);
504 for c in 0..cols {
505 let idx = r + c * rows_a;
506 row_values.push(a.data[idx]);
507 }
508 let key = NumericRowKey::from_slice(&row_values);
509 match map.entry(key) {
510 Entry::Occupied(_) => {}
511 Entry::Vacant(v) => {
512 let entry_idx = entries.len();
513 entries.push(NumericRowUnionEntry {
514 row_data: row_values,
515 a_row: Some(r),
516 b_row: None,
517 order_rank: order_counter,
518 });
519 v.insert(entry_idx);
520 order_counter += 1;
521 }
522 }
523 }
524
525 for r in 0..rows_b {
526 let mut row_values = Vec::with_capacity(cols);
527 for c in 0..cols {
528 let idx = r + c * rows_b;
529 row_values.push(b.data[idx]);
530 }
531 let key = NumericRowKey::from_slice(&row_values);
532 match map.entry(key) {
533 Entry::Occupied(occ) => {
534 let entry = &mut entries[*occ.get()];
535 if entry.a_row.is_none() && entry.b_row.is_none() {
536 entry.b_row = Some(r);
537 }
538 }
539 Entry::Vacant(v) => {
540 let entry_idx = entries.len();
541 entries.push(NumericRowUnionEntry {
542 row_data: row_values,
543 a_row: None,
544 b_row: Some(r),
545 order_rank: order_counter,
546 });
547 v.insert(entry_idx);
548 order_counter += 1;
549 }
550 }
551 }
552
553 assemble_numeric_row_union(entries, opts, cols)
554}
555
556fn union_complex(
557 a: ComplexTensor,
558 b: ComplexTensor,
559 opts: &UnionOptions,
560) -> Result<UnionEvaluation, String> {
561 if opts.rows {
562 union_complex_rows(a, b, opts)
563 } else {
564 union_complex_elements(a, b, opts)
565 }
566}
567
568fn union_complex_elements(
569 a: ComplexTensor,
570 b: ComplexTensor,
571 opts: &UnionOptions,
572) -> Result<UnionEvaluation, String> {
573 let mut entries = Vec::<ComplexUnionEntry>::new();
574 let mut map: HashMap<ComplexKey, usize> = HashMap::new();
575 let mut order_counter = 0usize;
576
577 for (idx, &value) in a.data.iter().enumerate() {
578 let key = ComplexKey::new(value);
579 match map.entry(key) {
580 Entry::Occupied(_) => {}
581 Entry::Vacant(v) => {
582 let entry_idx = entries.len();
583 entries.push(ComplexUnionEntry {
584 value,
585 a_index: Some(idx),
586 b_index: None,
587 order_rank: order_counter,
588 });
589 v.insert(entry_idx);
590 order_counter += 1;
591 }
592 }
593 }
594
595 for (idx, &value) in b.data.iter().enumerate() {
596 let key = ComplexKey::new(value);
597 match map.entry(key) {
598 Entry::Occupied(occ) => {
599 let entry = &mut entries[*occ.get()];
600 if entry.a_index.is_none() && entry.b_index.is_none() {
601 entry.b_index = Some(idx);
602 }
603 }
604 Entry::Vacant(v) => {
605 let entry_idx = entries.len();
606 entries.push(ComplexUnionEntry {
607 value,
608 a_index: None,
609 b_index: Some(idx),
610 order_rank: order_counter,
611 });
612 v.insert(entry_idx);
613 order_counter += 1;
614 }
615 }
616 }
617
618 assemble_complex_union(entries, opts)
619}
620
621fn union_complex_rows(
622 a: ComplexTensor,
623 b: ComplexTensor,
624 opts: &UnionOptions,
625) -> Result<UnionEvaluation, String> {
626 if a.shape.len() != 2 || b.shape.len() != 2 {
627 return Err("union: 'rows' option requires 2-D complex matrices".to_string());
628 }
629 if a.shape[1] != b.shape[1] {
630 return Err(
631 "union: inputs must have the same number of columns when using 'rows'".to_string(),
632 );
633 }
634 let rows_a = a.shape[0];
635 let cols = a.shape[1];
636 let rows_b = b.shape[0];
637
638 let mut entries = Vec::<ComplexRowUnionEntry>::new();
639 let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
640 let mut order_counter = 0usize;
641
642 for r in 0..rows_a {
643 let mut row_values = Vec::with_capacity(cols);
644 let mut key_row = Vec::with_capacity(cols);
645 for c in 0..cols {
646 let idx = r + c * rows_a;
647 let value = a.data[idx];
648 row_values.push(value);
649 key_row.push(ComplexKey::new(value));
650 }
651 match map.entry(key_row) {
652 Entry::Occupied(_) => {}
653 Entry::Vacant(v) => {
654 let entry_idx = entries.len();
655 entries.push(ComplexRowUnionEntry {
656 row_data: row_values,
657 a_row: Some(r),
658 b_row: None,
659 order_rank: order_counter,
660 });
661 v.insert(entry_idx);
662 order_counter += 1;
663 }
664 }
665 }
666
667 for r in 0..rows_b {
668 let mut row_values = Vec::with_capacity(cols);
669 let mut key_row = Vec::with_capacity(cols);
670 for c in 0..cols {
671 let idx = r + c * rows_b;
672 let value = b.data[idx];
673 row_values.push(value);
674 key_row.push(ComplexKey::new(value));
675 }
676 match map.entry(key_row) {
677 Entry::Occupied(occ) => {
678 let entry = &mut entries[*occ.get()];
679 if entry.a_row.is_none() && entry.b_row.is_none() {
680 entry.b_row = Some(r);
681 }
682 }
683 Entry::Vacant(v) => {
684 let entry_idx = entries.len();
685 entries.push(ComplexRowUnionEntry {
686 row_data: row_values,
687 a_row: None,
688 b_row: Some(r),
689 order_rank: order_counter,
690 });
691 v.insert(entry_idx);
692 order_counter += 1;
693 }
694 }
695 }
696
697 assemble_complex_row_union(entries, opts, cols)
698}
699
700fn union_char(a: CharArray, b: CharArray, opts: &UnionOptions) -> Result<UnionEvaluation, String> {
701 if opts.rows {
702 union_char_rows(a, b, opts)
703 } else {
704 union_char_elements(a, b, opts)
705 }
706}
707
708fn union_char_elements(
709 a: CharArray,
710 b: CharArray,
711 opts: &UnionOptions,
712) -> Result<UnionEvaluation, String> {
713 let mut entries = Vec::<CharUnionEntry>::new();
714 let mut map: HashMap<u32, usize> = HashMap::new();
715 let mut order_counter = 0usize;
716
717 for col in 0..a.cols {
718 for row in 0..a.rows {
719 let linear_idx = row + col * a.rows;
720 let data_idx = row * a.cols + col;
721 let ch = a.data[data_idx];
722 let key = ch as u32;
723 match map.entry(key) {
724 Entry::Occupied(_) => {}
725 Entry::Vacant(v) => {
726 let entry_idx = entries.len();
727 entries.push(CharUnionEntry {
728 ch,
729 a_index: Some(linear_idx),
730 b_index: None,
731 order_rank: order_counter,
732 });
733 v.insert(entry_idx);
734 order_counter += 1;
735 }
736 }
737 }
738 }
739
740 for col in 0..b.cols {
741 for row in 0..b.rows {
742 let linear_idx = row + col * b.rows;
743 let data_idx = row * b.cols + col;
744 let ch = b.data[data_idx];
745 let key = ch as u32;
746 match map.entry(key) {
747 Entry::Occupied(occ) => {
748 let entry = &mut entries[*occ.get()];
749 if entry.a_index.is_none() && entry.b_index.is_none() {
750 entry.b_index = Some(linear_idx);
751 }
752 }
753 Entry::Vacant(v) => {
754 let entry_idx = entries.len();
755 entries.push(CharUnionEntry {
756 ch,
757 a_index: None,
758 b_index: Some(linear_idx),
759 order_rank: order_counter,
760 });
761 v.insert(entry_idx);
762 order_counter += 1;
763 }
764 }
765 }
766 }
767
768 assemble_char_union(entries, opts)
769}
770
771fn union_char_rows(
772 a: CharArray,
773 b: CharArray,
774 opts: &UnionOptions,
775) -> Result<UnionEvaluation, String> {
776 if a.cols != b.cols {
777 return Err(
778 "union: inputs must have the same number of columns when using 'rows'".to_string(),
779 );
780 }
781 let rows_a = a.rows;
782 let rows_b = b.rows;
783 let cols = a.cols;
784
785 let mut entries = Vec::<CharRowUnionEntry>::new();
786 let mut map: HashMap<RowCharKey, usize> = HashMap::new();
787 let mut order_counter = 0usize;
788
789 for r in 0..rows_a {
790 let mut row_values = Vec::with_capacity(cols);
791 for c in 0..cols {
792 let idx = r * cols + c;
793 row_values.push(a.data[idx]);
794 }
795 let key = RowCharKey::from_slice(&row_values);
796 match map.entry(key) {
797 Entry::Occupied(_) => {}
798 Entry::Vacant(v) => {
799 let entry_idx = entries.len();
800 entries.push(CharRowUnionEntry {
801 row_data: row_values,
802 a_row: Some(r),
803 b_row: None,
804 order_rank: order_counter,
805 });
806 v.insert(entry_idx);
807 order_counter += 1;
808 }
809 }
810 }
811
812 for r in 0..rows_b {
813 let mut row_values = Vec::with_capacity(cols);
814 for c in 0..cols {
815 let idx = r * cols + c;
816 row_values.push(b.data[idx]);
817 }
818 let key = RowCharKey::from_slice(&row_values);
819 match map.entry(key) {
820 Entry::Occupied(occ) => {
821 let entry = &mut entries[*occ.get()];
822 if entry.a_row.is_none() && entry.b_row.is_none() {
823 entry.b_row = Some(r);
824 }
825 }
826 Entry::Vacant(v) => {
827 let entry_idx = entries.len();
828 entries.push(CharRowUnionEntry {
829 row_data: row_values,
830 a_row: None,
831 b_row: Some(r),
832 order_rank: order_counter,
833 });
834 v.insert(entry_idx);
835 order_counter += 1;
836 }
837 }
838 }
839
840 assemble_char_row_union(entries, opts, cols)
841}
842
843fn union_string(
844 a: StringArray,
845 b: StringArray,
846 opts: &UnionOptions,
847) -> Result<UnionEvaluation, String> {
848 if opts.rows {
849 union_string_rows(a, b, opts)
850 } else {
851 union_string_elements(a, b, opts)
852 }
853}
854
855fn union_string_elements(
856 a: StringArray,
857 b: StringArray,
858 opts: &UnionOptions,
859) -> Result<UnionEvaluation, String> {
860 let mut entries = Vec::<StringUnionEntry>::new();
861 let mut map: HashMap<String, usize> = HashMap::new();
862 let mut order_counter = 0usize;
863
864 for (idx, value) in a.data.iter().enumerate() {
865 match map.entry(value.clone()) {
866 Entry::Occupied(_) => {}
867 Entry::Vacant(v) => {
868 let entry_idx = entries.len();
869 entries.push(StringUnionEntry {
870 value: value.clone(),
871 a_index: Some(idx),
872 b_index: None,
873 order_rank: order_counter,
874 });
875 v.insert(entry_idx);
876 order_counter += 1;
877 }
878 }
879 }
880
881 for (idx, value) in b.data.iter().enumerate() {
882 match map.entry(value.clone()) {
883 Entry::Occupied(occ) => {
884 let entry = &mut entries[*occ.get()];
885 if entry.a_index.is_none() && entry.b_index.is_none() {
886 entry.b_index = Some(idx);
887 }
888 }
889 Entry::Vacant(v) => {
890 let entry_idx = entries.len();
891 entries.push(StringUnionEntry {
892 value: value.clone(),
893 a_index: None,
894 b_index: Some(idx),
895 order_rank: order_counter,
896 });
897 v.insert(entry_idx);
898 order_counter += 1;
899 }
900 }
901 }
902
903 assemble_string_union(entries, opts)
904}
905
906fn union_string_rows(
907 a: StringArray,
908 b: StringArray,
909 opts: &UnionOptions,
910) -> Result<UnionEvaluation, String> {
911 if a.shape.len() != 2 || b.shape.len() != 2 {
912 return Err("union: 'rows' option requires 2-D string arrays".to_string());
913 }
914 if a.shape[1] != b.shape[1] {
915 return Err(
916 "union: inputs must have the same number of columns when using 'rows'".to_string(),
917 );
918 }
919 let rows_a = a.shape[0];
920 let cols = a.shape[1];
921 let rows_b = b.shape[0];
922
923 let mut entries = Vec::<StringRowUnionEntry>::new();
924 let mut map: HashMap<RowStringKey, usize> = HashMap::new();
925 let mut order_counter = 0usize;
926
927 for r in 0..rows_a {
928 let mut row_values = Vec::with_capacity(cols);
929 for c in 0..cols {
930 let idx = r + c * rows_a;
931 row_values.push(a.data[idx].clone());
932 }
933 let key = RowStringKey(row_values.clone());
934 match map.entry(key) {
935 Entry::Occupied(_) => {}
936 Entry::Vacant(v) => {
937 let entry_idx = entries.len();
938 entries.push(StringRowUnionEntry {
939 row_data: row_values,
940 a_row: Some(r),
941 b_row: None,
942 order_rank: order_counter,
943 });
944 v.insert(entry_idx);
945 order_counter += 1;
946 }
947 }
948 }
949
950 for r in 0..rows_b {
951 let mut row_values = Vec::with_capacity(cols);
952 for c in 0..cols {
953 let idx = r + c * rows_b;
954 row_values.push(b.data[idx].clone());
955 }
956 let key = RowStringKey(row_values.clone());
957 match map.entry(key) {
958 Entry::Occupied(occ) => {
959 let entry = &mut entries[*occ.get()];
960 if entry.a_row.is_none() && entry.b_row.is_none() {
961 entry.b_row = Some(r);
962 }
963 }
964 Entry::Vacant(v) => {
965 let entry_idx = entries.len();
966 entries.push(StringRowUnionEntry {
967 row_data: row_values,
968 a_row: None,
969 b_row: Some(r),
970 order_rank: order_counter,
971 });
972 v.insert(entry_idx);
973 order_counter += 1;
974 }
975 }
976 }
977
978 assemble_string_row_union(entries, opts, cols)
979}
980
/// Result of a `union` evaluation: the union values plus the two index outputs.
#[derive(Debug, Clone)]
pub struct UnionEvaluation {
    /// The union values (`C`), already converted to a runtime `Value`.
    values: Value,
    /// One-based indices into `A` (`IA`).
    ia: Tensor,
    /// One-based indices into `B` (`IB`).
    ib: Tensor,
}
987
988impl UnionEvaluation {
989 fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
990 Self { values, ia, ib }
991 }
992
993 pub fn from_union_result(result: UnionResult) -> Result<Self, String> {
994 let UnionResult { values, ia, ib } = result;
995 let values_tensor =
996 Tensor::new(values.data, values.shape).map_err(|e| format!("union: {e}"))?;
997 let ia_tensor = Tensor::new(ia.data, ia.shape).map_err(|e| format!("union: {e}"))?;
998 let ib_tensor = Tensor::new(ib.data, ib.shape).map_err(|e| format!("union: {e}"))?;
999 Ok(UnionEvaluation::new(
1000 tensor::tensor_into_value(values_tensor),
1001 ia_tensor,
1002 ib_tensor,
1003 ))
1004 }
1005
1006 pub fn into_numeric_union_result(self) -> Result<UnionResult, String> {
1007 let UnionEvaluation { values, ia, ib } = self;
1008 let values_tensor = tensor::value_into_tensor_for("union", values)?;
1009 Ok(UnionResult {
1010 values: HostTensorOwned {
1011 data: values_tensor.data,
1012 shape: values_tensor.shape,
1013 },
1014 ia: HostTensorOwned {
1015 data: ia.data,
1016 shape: ia.shape,
1017 },
1018 ib: HostTensorOwned {
1019 data: ib.data,
1020 shape: ib.shape,
1021 },
1022 })
1023 }
1024
1025 pub fn into_values_value(self) -> Value {
1026 self.values
1027 }
1028
1029 pub fn into_pair(self) -> (Value, Value) {
1030 let ia = tensor::tensor_into_value(self.ia);
1031 (self.values, ia)
1032 }
1033
1034 pub fn into_triple(self) -> (Value, Value, Value) {
1035 let ia = tensor::tensor_into_value(self.ia);
1036 let ib = tensor::tensor_into_value(self.ib);
1037 (self.values, ia, ib)
1038 }
1039
1040 pub fn values_value(&self) -> Value {
1041 self.values.clone()
1042 }
1043
1044 pub fn ia_value(&self) -> Value {
1045 tensor::tensor_into_value(self.ia.clone())
1046 }
1047
1048 pub fn ib_value(&self) -> Value {
1049 tensor::tensor_into_value(self.ib.clone())
1050 }
1051}
1052
/// One distinct real value collected during an element-wise numeric union.
#[derive(Debug)]
struct NumericUnionEntry {
    /// The value as first encountered.
    value: f64,
    /// Zero-based position in `A` of the first occurrence, when the value was first seen in `A`.
    a_index: Option<usize>,
    /// Zero-based position in `B` of the first occurrence, when the value only appears in `B`.
    b_index: Option<usize>,
    /// First-seen rank across `A` then `B`; used as the sort key for `'stable'` ordering.
    order_rank: usize,
}
1060
/// One distinct real row collected during a row-wise numeric union.
#[derive(Debug)]
struct NumericRowUnionEntry {
    /// The row's values, one per column.
    row_data: Vec<f64>,
    /// Zero-based row in `A` of the first occurrence, when the row was first seen in `A`.
    a_row: Option<usize>,
    /// Zero-based row in `B` of the first occurrence, when the row only appears in `B`.
    b_row: Option<usize>,
    /// First-seen rank across `A` then `B`; used as the sort key for `'stable'` ordering.
    order_rank: usize,
}
1068
/// One distinct complex value collected during an element-wise complex union.
#[derive(Debug)]
struct ComplexUnionEntry {
    /// The value as first encountered, stored as `(re, im)`.
    value: (f64, f64),
    /// Zero-based position in `A` of the first occurrence, when the value was first seen in `A`.
    a_index: Option<usize>,
    /// Zero-based position in `B` of the first occurrence, when the value only appears in `B`.
    b_index: Option<usize>,
    /// First-seen rank across `A` then `B`, assigned at insertion time.
    order_rank: usize,
}
1076
/// One distinct complex row collected during a row-wise complex union.
#[derive(Debug)]
struct ComplexRowUnionEntry {
    /// The row's values, one `(re, im)` pair per column.
    row_data: Vec<(f64, f64)>,
    /// Zero-based row in `A` of the first occurrence, when the row was first seen in `A`.
    a_row: Option<usize>,
    /// Zero-based row in `B` of the first occurrence, when the row only appears in `B`.
    b_row: Option<usize>,
    /// First-seen rank across `A` then `B`, assigned at insertion time.
    order_rank: usize,
}
1084
/// One distinct character collected during an element-wise char union.
#[derive(Debug)]
struct CharUnionEntry {
    /// The character as first encountered.
    ch: char,
    /// Zero-based column-major linear index in `A`, when the character was first seen in `A`.
    a_index: Option<usize>,
    /// Zero-based column-major linear index in `B`, when the character only appears in `B`.
    b_index: Option<usize>,
    /// First-seen rank across `A` then `B`, assigned at insertion time.
    order_rank: usize,
}
1092
/// One distinct character row collected during a row-wise char union.
#[derive(Debug)]
struct CharRowUnionEntry {
    /// The row's characters, one per column.
    row_data: Vec<char>,
    /// Zero-based row in `A` of the first occurrence, when the row was first seen in `A`.
    a_row: Option<usize>,
    /// Zero-based row in `B` of the first occurrence, when the row only appears in `B`.
    b_row: Option<usize>,
    /// First-seen rank across `A` then `B`, assigned at insertion time.
    order_rank: usize,
}
1100
/// One distinct string collected during an element-wise string union.
#[derive(Debug)]
struct StringUnionEntry {
    /// The string as first encountered.
    value: String,
    /// Zero-based position in `A` of the first occurrence, when the string was first seen in `A`.
    a_index: Option<usize>,
    /// Zero-based position in `B` of the first occurrence, when the string only appears in `B`.
    b_index: Option<usize>,
    /// First-seen rank across `A` then `B`, assigned at insertion time.
    order_rank: usize,
}
1108
/// One distinct string row collected during a row-wise string union.
#[derive(Debug)]
struct StringRowUnionEntry {
    /// The row's strings, one per column.
    row_data: Vec<String>,
    /// Zero-based row in `A` of the first occurrence, when the row was first seen in `A`.
    a_row: Option<usize>,
    /// Zero-based row in `B` of the first occurrence, when the row only appears in `B`.
    b_row: Option<usize>,
    /// First-seen rank across `A` then `B`, assigned at insertion time.
    order_rank: usize,
}
1116
1117#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1118struct NumericRowKey(Vec<u64>);
1119
1120impl NumericRowKey {
1121 fn from_slice(values: &[f64]) -> Self {
1122 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1123 }
1124}
1125
1126#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1127struct ComplexKey {
1128 re: u64,
1129 im: u64,
1130}
1131
1132impl ComplexKey {
1133 fn new(value: (f64, f64)) -> Self {
1134 Self {
1135 re: canonicalize_f64(value.0),
1136 im: canonicalize_f64(value.1),
1137 }
1138 }
1139}
1140
/// Hashable key for a character row: the Unicode scalar value of each character, in order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key for `values` by mapping every character to its code point.
    fn from_slice(values: &[char]) -> Self {
        let code_points = values.iter().copied().map(u32::from).collect();
        Self(code_points)
    }
}
1149
/// Hashable key for a string row: the row's strings in column order, compared exactly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1152
1153fn assemble_numeric_union(
1154 entries: Vec<NumericUnionEntry>,
1155 opts: &UnionOptions,
1156) -> Result<UnionEvaluation, String> {
1157 let mut order: Vec<usize> = (0..entries.len()).collect();
1158 match opts.order {
1159 UnionOrder::Sorted => {
1160 order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
1161 }
1162 UnionOrder::Stable => {
1163 order.sort_by_key(|&idx| entries[idx].order_rank);
1164 }
1165 }
1166
1167 let mut values = Vec::with_capacity(order.len());
1168 let mut ia = Vec::new();
1169 let mut ib = Vec::new();
1170 for &idx in &order {
1171 let entry = &entries[idx];
1172 values.push(entry.value);
1173 if let Some(a_idx) = entry.a_index {
1174 ia.push((a_idx + 1) as f64);
1175 } else if let Some(b_idx) = entry.b_index {
1176 ib.push((b_idx + 1) as f64);
1177 }
1178 }
1179
1180 let value_tensor =
1181 Tensor::new(values, vec![order.len(), 1]).map_err(|e| format!("union: {e}"))?;
1182 let ia_len = ia.len();
1183 let ib_len = ib.len();
1184 let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1185 let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1186
1187 Ok(UnionEvaluation::new(
1188 tensor::tensor_into_value(value_tensor),
1189 ia_tensor,
1190 ib_tensor,
1191 ))
1192}
1193
1194fn assemble_numeric_row_union(
1195 entries: Vec<NumericRowUnionEntry>,
1196 opts: &UnionOptions,
1197 cols: usize,
1198) -> Result<UnionEvaluation, String> {
1199 let mut order: Vec<usize> = (0..entries.len()).collect();
1200 match opts.order {
1201 UnionOrder::Sorted => {
1202 order.sort_by(|&lhs, &rhs| {
1203 compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1204 });
1205 }
1206 UnionOrder::Stable => {
1207 order.sort_by_key(|&idx| entries[idx].order_rank);
1208 }
1209 }
1210
1211 let unique_rows = order.len();
1212 let mut values = vec![0.0f64; unique_rows * cols];
1213 let mut ia = Vec::new();
1214 let mut ib = Vec::new();
1215
1216 for (row_pos, &entry_idx) in order.iter().enumerate() {
1217 let entry = &entries[entry_idx];
1218 for col in 0..cols {
1219 let dest = row_pos + col * unique_rows;
1220 values[dest] = entry.row_data[col];
1221 }
1222 if let Some(a_row) = entry.a_row {
1223 ia.push((a_row + 1) as f64);
1224 } else if let Some(b_row) = entry.b_row {
1225 ib.push((b_row + 1) as f64);
1226 }
1227 }
1228
1229 let value_tensor =
1230 Tensor::new(values, vec![unique_rows, cols]).map_err(|e| format!("union: {e}"))?;
1231 let ia_len = ia.len();
1232 let ib_len = ib.len();
1233 let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1234 let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1235
1236 Ok(UnionEvaluation::new(
1237 tensor::tensor_into_value(value_tensor),
1238 ia_tensor,
1239 ib_tensor,
1240 ))
1241}
1242
1243fn assemble_complex_union(
1244 entries: Vec<ComplexUnionEntry>,
1245 opts: &UnionOptions,
1246) -> Result<UnionEvaluation, String> {
1247 let mut order: Vec<usize> = (0..entries.len()).collect();
1248 match opts.order {
1249 UnionOrder::Sorted => {
1250 order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
1251 }
1252 UnionOrder::Stable => {
1253 order.sort_by_key(|&idx| entries[idx].order_rank);
1254 }
1255 }
1256
1257 let mut values = Vec::with_capacity(order.len());
1258 let mut ia = Vec::new();
1259 let mut ib = Vec::new();
1260 for &idx in &order {
1261 let entry = &entries[idx];
1262 values.push(entry.value);
1263 if let Some(a_idx) = entry.a_index {
1264 ia.push((a_idx + 1) as f64);
1265 } else if let Some(b_idx) = entry.b_index {
1266 ib.push((b_idx + 1) as f64);
1267 }
1268 }
1269
1270 let value_tensor =
1271 ComplexTensor::new(values, vec![order.len(), 1]).map_err(|e| format!("union: {e}"))?;
1272 let ia_len = ia.len();
1273 let ib_len = ib.len();
1274 let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1275 let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1276
1277 Ok(UnionEvaluation::new(
1278 complex_tensor_into_value(value_tensor),
1279 ia_tensor,
1280 ib_tensor,
1281 ))
1282}
1283
1284fn assemble_complex_row_union(
1285 entries: Vec<ComplexRowUnionEntry>,
1286 opts: &UnionOptions,
1287 cols: usize,
1288) -> Result<UnionEvaluation, String> {
1289 let mut order: Vec<usize> = (0..entries.len()).collect();
1290 match opts.order {
1291 UnionOrder::Sorted => {
1292 order.sort_by(|&lhs, &rhs| {
1293 compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1294 });
1295 }
1296 UnionOrder::Stable => {
1297 order.sort_by_key(|&idx| entries[idx].order_rank);
1298 }
1299 }
1300
1301 let unique_rows = order.len();
1302 let mut values = vec![(0.0, 0.0); unique_rows * cols];
1303 let mut ia = Vec::new();
1304 let mut ib = Vec::new();
1305
1306 for (row_pos, &entry_idx) in order.iter().enumerate() {
1307 let entry = &entries[entry_idx];
1308 for col in 0..cols {
1309 let dest = row_pos + col * unique_rows;
1310 values[dest] = entry.row_data[col];
1311 }
1312 if let Some(a_row) = entry.a_row {
1313 ia.push((a_row + 1) as f64);
1314 } else if let Some(b_row) = entry.b_row {
1315 ib.push((b_row + 1) as f64);
1316 }
1317 }
1318
1319 let value_tensor =
1320 ComplexTensor::new(values, vec![unique_rows, cols]).map_err(|e| format!("union: {e}"))?;
1321 let ia_len = ia.len();
1322 let ib_len = ib.len();
1323 let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1324 let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1325
1326 Ok(UnionEvaluation::new(
1327 complex_tensor_into_value(value_tensor),
1328 ia_tensor,
1329 ib_tensor,
1330 ))
1331}
1332
1333fn assemble_char_union(
1334 entries: Vec<CharUnionEntry>,
1335 opts: &UnionOptions,
1336) -> Result<UnionEvaluation, String> {
1337 let mut order: Vec<usize> = (0..entries.len()).collect();
1338 match opts.order {
1339 UnionOrder::Sorted => {
1340 order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1341 }
1342 UnionOrder::Stable => {
1343 order.sort_by_key(|&idx| entries[idx].order_rank);
1344 }
1345 }
1346
1347 let mut values = Vec::with_capacity(order.len());
1348 let mut ia = Vec::new();
1349 let mut ib = Vec::new();
1350 for &idx in &order {
1351 let entry = &entries[idx];
1352 values.push(entry.ch);
1353 if let Some(a_idx) = entry.a_index {
1354 ia.push((a_idx + 1) as f64);
1355 } else if let Some(b_idx) = entry.b_index {
1356 ib.push((b_idx + 1) as f64);
1357 }
1358 }
1359
1360 let value_array = CharArray::new(values, order.len(), 1).map_err(|e| format!("union: {e}"))?;
1361 let ia_len = ia.len();
1362 let ib_len = ib.len();
1363 let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1364 let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1365
1366 Ok(UnionEvaluation::new(
1367 Value::CharArray(value_array),
1368 ia_tensor,
1369 ib_tensor,
1370 ))
1371}
1372
1373fn assemble_char_row_union(
1374 entries: Vec<CharRowUnionEntry>,
1375 opts: &UnionOptions,
1376 cols: usize,
1377) -> Result<UnionEvaluation, String> {
1378 let mut order: Vec<usize> = (0..entries.len()).collect();
1379 match opts.order {
1380 UnionOrder::Sorted => {
1381 order.sort_by(|&lhs, &rhs| {
1382 compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1383 });
1384 }
1385 UnionOrder::Stable => {
1386 order.sort_by_key(|&idx| entries[idx].order_rank);
1387 }
1388 }
1389
1390 let unique_rows = order.len();
1391 let mut values = vec!['\0'; unique_rows * cols];
1392 let mut ia = Vec::new();
1393 let mut ib = Vec::new();
1394
1395 for (row_pos, &entry_idx) in order.iter().enumerate() {
1396 let entry = &entries[entry_idx];
1397 for col in 0..cols {
1398 let dest = row_pos * cols + col;
1399 values[dest] = entry.row_data[col];
1400 }
1401 if let Some(a_row) = entry.a_row {
1402 ia.push((a_row + 1) as f64);
1403 } else if let Some(b_row) = entry.b_row {
1404 ib.push((b_row + 1) as f64);
1405 }
1406 }
1407
1408 let value_array =
1409 CharArray::new(values, unique_rows, cols).map_err(|e| format!("union: {e}"))?;
1410 let ia_len = ia.len();
1411 let ib_len = ib.len();
1412 let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1413 let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1414
1415 Ok(UnionEvaluation::new(
1416 Value::CharArray(value_array),
1417 ia_tensor,
1418 ib_tensor,
1419 ))
1420}
1421
1422fn assemble_string_union(
1423 entries: Vec<StringUnionEntry>,
1424 opts: &UnionOptions,
1425) -> Result<UnionEvaluation, String> {
1426 let mut order: Vec<usize> = (0..entries.len()).collect();
1427 match opts.order {
1428 UnionOrder::Sorted => {
1429 order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1430 }
1431 UnionOrder::Stable => {
1432 order.sort_by_key(|&idx| entries[idx].order_rank);
1433 }
1434 }
1435
1436 let mut values = Vec::with_capacity(order.len());
1437 let mut ia = Vec::new();
1438 let mut ib = Vec::new();
1439 for &idx in &order {
1440 let entry = &entries[idx];
1441 values.push(entry.value.clone());
1442 if let Some(a_idx) = entry.a_index {
1443 ia.push((a_idx + 1) as f64);
1444 } else if let Some(b_idx) = entry.b_index {
1445 ib.push((b_idx + 1) as f64);
1446 }
1447 }
1448
1449 let value_array =
1450 StringArray::new(values, vec![order.len(), 1]).map_err(|e| format!("union: {e}"))?;
1451 let ia_len = ia.len();
1452 let ib_len = ib.len();
1453 let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1454 let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1455
1456 Ok(UnionEvaluation::new(
1457 Value::StringArray(value_array),
1458 ia_tensor,
1459 ib_tensor,
1460 ))
1461}
1462
1463fn assemble_string_row_union(
1464 entries: Vec<StringRowUnionEntry>,
1465 opts: &UnionOptions,
1466 cols: usize,
1467) -> Result<UnionEvaluation, String> {
1468 let mut order: Vec<usize> = (0..entries.len()).collect();
1469 match opts.order {
1470 UnionOrder::Sorted => {
1471 order.sort_by(|&lhs, &rhs| {
1472 compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1473 });
1474 }
1475 UnionOrder::Stable => {
1476 order.sort_by_key(|&idx| entries[idx].order_rank);
1477 }
1478 }
1479
1480 let unique_rows = order.len();
1481 let mut values = vec![String::new(); unique_rows * cols];
1482 let mut ia = Vec::new();
1483 let mut ib = Vec::new();
1484
1485 for (row_pos, &entry_idx) in order.iter().enumerate() {
1486 let entry = &entries[entry_idx];
1487 for col in 0..cols {
1488 let dest = row_pos + col * unique_rows;
1489 values[dest] = entry.row_data[col].clone();
1490 }
1491 if let Some(a_row) = entry.a_row {
1492 ia.push((a_row + 1) as f64);
1493 } else if let Some(b_row) = entry.b_row {
1494 ib.push((b_row + 1) as f64);
1495 }
1496 }
1497
1498 let value_array =
1499 StringArray::new(values, vec![unique_rows, cols]).map_err(|e| format!("union: {e}"))?;
1500 let ia_len = ia.len();
1501 let ib_len = ib.len();
1502 let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1503 let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1504
1505 Ok(UnionEvaluation::new(
1506 Value::StringArray(value_array),
1507 ia_tensor,
1508 ib_tensor,
1509 ))
1510}
1511
/// Maps a float to a bit pattern suitable for hashing and equality.
///
/// Every NaN collapses to one quiet-NaN pattern and both signed zeros collapse to `0`;
/// any other value keeps its own IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0`, so this arm catches both zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1521
/// Orders two floats with NaN placed after every other value.
///
/// Two NaNs compare equal; otherwise the usual numeric order applies.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1535
1536fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1537 for (lhs, rhs) in a.iter().zip(b.iter()) {
1538 let ord = compare_f64(*lhs, *rhs);
1539 if ord != Ordering::Equal {
1540 return ord;
1541 }
1542 }
1543 Ordering::Equal
1544}
1545
/// Reports whether either component of the `(re, im)` pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    value.0.is_nan() || value.1.is_nan()
}

/// Orders two complex numbers, given as `(re, im)` pairs, the way MATLAB's `sort` does.
///
/// Values containing a NaN component sort after everything else and compare equal to one
/// another. Otherwise values are ordered by magnitude and, when magnitudes tie, by phase
/// angle `atan2(im, re)` in `(-π, π]` — so `1` precedes `i`, which precedes `-1`. The real
/// and imaginary parts serve only as a final deterministic tie-breaker.
fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
    match (complex_is_nan(a), complex_is_nan(b)) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => {
            // Neither operand holds a NaN, so `hypot` and `atan2` cannot produce one and
            // `partial_cmp` always yields an ordering; the fallback is purely defensive.
            let cmp = |x: f64, y: f64| x.partial_cmp(&y).unwrap_or(Ordering::Equal);
            cmp(a.0.hypot(a.1), b.0.hypot(b.1))
                .then_with(|| cmp(a.1.atan2(a.0), b.1.atan2(b.0)))
                .then_with(|| cmp(a.0, b.0))
                .then_with(|| cmp(a.1, b.1))
        }
    }
}
1570
1571fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1572 for (lhs, rhs) in a.iter().zip(b.iter()) {
1573 let ord = compare_complex(*lhs, *rhs);
1574 if ord != Ordering::Equal {
1575 return ord;
1576 }
1577 }
1578 Ordering::Equal
1579}
1580
/// Lexicographically compares two character rows column by column.
///
/// Only the columns both slices share are inspected; if all of them tie the rows are
/// reported as equal, regardless of any extra trailing columns.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1590
/// Lexicographically compares two string rows column by column.
///
/// Only the columns both slices share are inspected; if all of them tie the rows are
/// reported as equal, regardless of any extra trailing columns.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1600
// Unit tests for `union`: ordering modes, `rows` handling, the `ia`/`ib` index outputs,
// option validation, and evaluation from GPU tensor handles.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, Tensor, Value};

    // With no options the result is sorted; `ia` lists the elements taken from A and `ib`
    // only the element of B that A does not contain.
    #[test]
    fn union_numeric_sorted_default() {
        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 1.0, 1.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 7.0]);
                assert_eq!(t.shape, vec![4, 1]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![3.0, 1.0, 2.0]);
        assert_eq!(ia.shape, vec![3, 1]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![1.0]);
        assert_eq!(ib.shape, vec![1, 1]);
    }

    // "stable" keeps A's elements in input order, followed by B's.
    #[test]
    fn union_numeric_stable_order() {
        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 2.0, 4.0], vec![3, 1]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), Value::Tensor(b), &[Value::from("stable")]).expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![5.0, 7.0, 1.0, 3.0, 2.0, 4.0]);
                assert_eq!(t.shape, vec![6, 1]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![1.0, 2.0, 3.0]);
    }

    // NaNs from both inputs collapse into a single entry that sorts after the numbers.
    #[test]
    fn union_numeric_sorted_places_nan_last() {
        let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
        let b = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
        let eval = evaluate(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
        let values = tensor::value_into_tensor_for("union", eval.values_value()).expect("values");
        assert_eq!(values.shape, vec![3, 1]);
        assert_eq!(values.data[0], 1.0);
        assert_eq!(values.data[1], 2.0);
        assert!(values.data[2].is_nan());
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0, 1.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![1.0]);
    }

    // "rows" treats each row as one element; A = [1 2; 3 4; 1 2], B = [3 4; 5 6].
    #[test]
    fn union_numeric_rows_sorted() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
        let eval =
            evaluate(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")]).expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0, 2.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![2.0]);
    }

    // "rows" + "stable": duplicate rows report the index of their first occurrence.
    #[test]
    fn union_numeric_rows_stable_preserves_first_occurrence() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
        let eval = evaluate(
            Value::Tensor(a),
            Value::Tensor(b),
            &[Value::from("rows"), Value::from("stable")],
        )
        .expect("union");
        let (values, ia, ib) = eval.into_triple();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia_tensor = tensor::value_into_tensor_for("union", ia).expect("ia tensor");
        assert_eq!(ia_tensor.data, vec![1.0, 2.0]);
        let ib_tensor = tensor::value_into_tensor_for("union", ib).expect("ib tensor");
        assert_eq!(ib_tensor.data, vec![3.0]);
    }

    // Character inputs yield a sorted N×1 char array with one-based indices.
    #[test]
    fn union_char_elements() {
        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
        let eval = evaluate(Value::CharArray(a), Value::CharArray(b), &[]).expect("union");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 4);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['a', 'm', 'x', 'z']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![4.0, 1.0, 3.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![3.0]);
    }

    // String rows with "rows" + "stable": the row shared by A and B is attributed to A.
    #[test]
    fn union_string_rows_stable() {
        let a = StringArray::new(
            vec![
                "alpha".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "gamma".to_string(),
                "delta".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let eval = evaluate(
            Value::StringArray(a),
            Value::StringArray(b),
            &[Value::from("rows"), Value::from("stable")],
        )
        .expect("union");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![3, 2]);
                assert_eq!(
                    arr.data,
                    vec![
                        "alpha".to_string(),
                        "gamma".to_string(),
                        "delta".to_string(),
                        "beta".to_string(),
                        "beta".to_string(),
                        "beta".to_string()
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0, 2.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![2.0]);
    }

    // Uploads both operands through the test provider and evaluates from GPU handles.
    #[test]
    fn union_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let a = Tensor::new(vec![4.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let b = Tensor::new(vec![2.0, 5.0], vec![2, 1]).unwrap();
            let view_a = HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let view_b = HostTensorView {
                data: &b.data,
                shape: &b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload A");
            let handle_b = provider.upload(&view_b).expect("upload B");
            let eval = evaluate(
                Value::GpuTensor(handle_a),
                Value::GpuTensor(handle_b),
                &[Value::from("stable")],
            )
            .expect("union");
            let values = tensor::value_into_tensor_for("union", eval.values_value()).unwrap();
            assert_eq!(values.data, vec![4.0, 1.0, 2.0, 5.0]);
            let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
            assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
            let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
            assert_eq!(ib.data, vec![2.0]);
        });
    }

    // The "legacy" option is rejected with an error that mentions it.
    #[test]
    fn union_rejects_legacy_option() {
        let tensor =
            Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).expect("tensor construction failed");
        let err = evaluate(
            Value::Tensor(tensor.clone()),
            Value::Tensor(tensor),
            &[Value::from("legacy")],
        )
        .unwrap_err();
        assert!(err.contains("legacy"));
    }

    // "rows" requires both inputs to have the same number of columns.
    #[test]
    fn union_rows_dimension_mismatch() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = evaluate(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")]).unwrap_err();
        assert!(err.contains("same number of columns"));
    }

    // Mixing a numeric tensor with a char array is reported as an unsupported input type.
    #[test]
    fn union_requires_matching_types() {
        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
        let err = union_host(
            Value::Tensor(a),
            Value::CharArray(b),
            &UnionOptions {
                rows: false,
                order: UnionOrder::Sorted,
            },
        )
        .unwrap_err();
        assert!(err.contains("unsupported input type"));
    }

    // Scalar integer and double inputs are accepted and produce a numeric tensor.
    #[test]
    fn union_accepts_scalar_inputs() {
        let eval = evaluate(Value::Int(IntValue::I32(1)), Value::Num(3.0), &[]).expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 3.0]);
                assert_eq!(t.shape, vec![2, 1]);
            }
            other => panic!("expected numeric tensor, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
        assert_eq!(ia.data, vec![1.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
        assert_eq!(ib.data, vec![1.0]);
    }

    // With the wgpu provider registered, GPU-handle inputs must give the same values, `ia`
    // and `ib` as host tensors.
    #[test]
    #[cfg(feature = "wgpu")]
    fn union_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();

        let cpu_eval =
            evaluate(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[]).expect("union");
        let cpu_values = tensor::value_into_tensor_for("union", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("union", cpu_eval.ia_value()).unwrap();
        let cpu_ib = tensor::value_into_tensor_for("union", cpu_eval.ib_value()).unwrap();

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval =
            evaluate(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[]).expect("union");
        let gpu_values = tensor::value_into_tensor_for("union", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("union", gpu_eval.ia_value()).unwrap();
        let gpu_ib = tensor::value_into_tensor_for("union", gpu_eval.ib_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_values.shape, cpu_values.shape);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ib.data, cpu_ib.data);
    }

    // The embedded documentation must contain at least one example block.
    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}