runmat_runtime/builtins/array/sorting_sets/
union.rs

1//! MATLAB-compatible `union` builtin with GPU-aware semantics for RunMat.
2//!
3//! Handles element-wise and row-wise unions with optional stable ordering and
4//! index outputs that mirror MathWorks MATLAB semantics. GPU tensors are
5//! gathered to host memory unless a provider supplies a dedicated `union`
6//! kernel hook.
7
8use std::cmp::Ordering;
9use std::collections::{hash_map::Entry, HashMap};
10
11use runmat_accelerate_api::{
12    GpuTensorHandle, HostTensorOwned, UnionOptions, UnionOrder, UnionResult,
13};
14use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
15use runmat_macros::runtime_builtin;
16
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::complex_tensor_into_value;
19use crate::builtins::common::spec::{
20    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24#[cfg(feature = "doc_export")]
25use crate::register_builtin_doc_text;
26use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
27
28#[cfg(feature = "doc_export")]
29pub const DOC_MD: &str = r#"---
30title: "union"
31category: "array/sorting_sets"
32keywords: ["union", "set", "stable", "rows", "indices", "gpu"]
33summary: "Combine two arrays, returning their union with MATLAB-compatible ordering and index outputs."
34references:
35  - https://www.mathworks.com/help/matlab/ref/union.html
36gpu_support:
37  elementwise: false
38  reduction: false
39  precisions: ["f32", "f64"]
40  broadcasting: "none"
41  notes: "When no provider hook is available, tensors are gathered to the host and processed with the CPU implementation."
42fusion:
43  elementwise: false
44  reduction: false
45  max_inputs: 2
46  constants: "inline"
47requires_feature: null
48tested:
49  unit: "builtins::array::sorting_sets::union::tests"
50  integration: "builtins::array::sorting_sets::union::tests::union_gpu_roundtrip"
51---
52
53# What does the `union` function do in MATLAB / RunMat?
54`union(A, B)` creates the set-theoretic union of inputs `A` and `B`, returning a value array `C`
55that contains every distinct element (or row) appearing in either input. Optional second and third
56outputs provide indices back into `A` and `B` that identify which original elements contribute to
57each value in `C`.
58
59## How does the `union` function behave in MATLAB / RunMat?
60- `union(A, B)` flattens the inputs column-major, removes duplicates, and returns the sorted union
61  of their values.
62- `[C, IA, IB] = union(A, B)` also returns index vectors where `IA` points to the elements in `A`
63  that contribute to `C` and `IB` refers to the elements in `B` that only appear in `B`.
64- `union(A, B, 'stable')` preserves the first-seen order: unique values from `A` appear first,
65  followed by new values appearing in `B`.
66- `union(A, B, 'rows')` treats each row as a record. Inputs must share the same number of columns.
67- Character arrays and string arrays are fully supported in both element and row modes.
68- Complex values follow MATLAB's ordering rules (magnitude first, then real/imaginary parts).
69- Legacy compatibility flags (`'legacy'` or `'R2012a'`) are not supported in RunMat.
70
71## GPU execution in RunMat
72- `union` is registered as a residency sink. When tensors live on a GPU and the active provider
73  does not expose a dedicated `union` hook, the runtime gathers them to host memory and executes
74  the CPU implementation to guarantee MATLAB-compatible output.
75- Future providers may implement `ProviderHook::Custom("union")` to keep data on the device.
76
77## Examples of using the `union` function in MATLAB / RunMat
78
79### Combining Two Numeric Vectors
80```matlab
81A = [5 7 1];
82B = [3 1 1];
83[C, IA, IB] = union(A, B);
84```
85Expected output:
86```matlab
87C =
88     1
89     3
90     5
91     7
92IA =
93     3
94     1
95     2
96IB =
97     1
98```
99
100### Preserving Input Order with `'stable'`
101```matlab
102A = [5 7 1];
103B = [3 2 4];
104C = union(A, B, 'stable');
105```
106Expected output:
107```matlab
108C =
109     5
110     7
111     1
112     3
113     2
114     4
115```
116
117### Union of Matrix Rows
118```matlab
119A = [1 2; 3 4; 1 2];
120B = [3 4; 5 6];
121[C, IA, IB] = union(A, B, 'rows');
122```
123Expected output:
124```matlab
125C =
126     1     2
127     3     4
128     5     6
129IA =
130     1
131     2
132IB =
133     2
134```
135
136### Working with Character Arrays
137```matlab
138A = ['m','z'; 'm','a'];
139B = ['a','x'; 'm','a'];
140union(A, B);
141```
142Expected output:
143```matlab
144ans =
145  4x1 char array
146    a
147    m
148    x
149    z
150```
151
152### Union on GPU Data
153```matlab
154G = gpuArray([1 4 2 5]);
155H = gpuArray([2 6]);
156[C, IA, IB] = union(G, H);
157```
158Expected output:
159```matlab
160C =
161     1
162     2
163     4
164     5
165     6
IA =
     1
     3
     2
     4
170IB =
171     2
172```
173
174## FAQ
175
176### How are the index outputs defined?
177`IA` lists the positions in `A` that contribute to `C`. `IB` lists the positions in `B` for union
178values that do not already appear in `A`. Both use MATLAB's one-based indexing.
179
180### What happens if I request `'stable'` ordering?
181The output keeps the first occurrence order: all unique values from `A` appear first, followed by
182new values from `B` in the order they are encountered.
183
184### Can I use `'rows'` with string or character arrays?
185Yes. `union` accepts string arrays and character arrays with the `'rows'` option, provided the inputs
186share the same number of columns.
187
188### How does `union` treat `NaN` values?
189All `NaN` values are considered equal. Sorted unions place them at the end, while stable unions keep
190the first-seen `NaN`.
191
192### Does GPU execution change the behaviour?
193No. When a provider does not implement a GPU kernel, RunMat automatically gathers inputs to the host,
194performs the union there, and returns host-resident outputs that match MATLAB semantics exactly.
195
196### What if the inputs have different shapes?
197Element-wise unions accept any shapes; both inputs are flattened column-major. Row-wise unions require
198matching column counts and two-dimensional inputs.
199
200### Can I mix data types?
201No. Inputs must belong to the same class (numeric/logical, complex, char, or string). Mixing classes
202produces a descriptive error.
203
204### Are scalar outputs returned as vectors?
205Scalars follow MATLAB's behaviour: a `1×1` union result is represented as a scalar double.
206
207### Do trailing spaces matter for character data?
208Yes. Character arrays treat every character, including trailing spaces, as significant—matching MATLAB's
209behaviour.
210
211### Is the `'legacy'` flag supported?
212No. RunMat only implements the modern MATLAB semantics. Passing `'legacy'` or `'R2012a'` raises an error.
213
214## See Also
215[unique](./unique), [sort](./sort), [sortrows](./sortrows)
216
217## Source & Feedback
218- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/union.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/union.rs)
219- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
220"#;
221
222pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
223    name: "union",
224    op_kind: GpuOpKind::Custom("union"),
225    supported_precisions: &[ScalarType::F32, ScalarType::F64],
226    broadcast: BroadcastSemantics::None,
227    provider_hooks: &[ProviderHook::Custom("union")],
228    constant_strategy: ConstantStrategy::InlineLiteral,
229    residency: ResidencyPolicy::GatherImmediately,
230    nan_mode: ReductionNaN::Include,
231    two_pass_threshold: None,
232    workgroup_size: None,
233    accepts_nan_mode: true,
234    notes: "Providers may expose a dedicated union hook; otherwise tensors are gathered and processed on the host.",
235};
236
237register_builtin_gpu_spec!(GPU_SPEC);
238
239pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
240    name: "union",
241    shape: ShapeRequirements::Any,
242    constant_strategy: ConstantStrategy::InlineLiteral,
243    elementwise: None,
244    reduction: None,
245    emits_nan: true,
246    notes: "`union` materialises its inputs and terminates fusion chains; upstream GPU tensors are gathered when necessary.",
247};
248
249register_builtin_fusion_spec!(FUSION_SPEC);
250
251#[cfg(feature = "doc_export")]
252register_builtin_doc_text!("union", DOC_MD);
253
254#[runtime_builtin(
255    name = "union",
256    category = "array/sorting_sets",
257    summary = "Combine two arrays, returning their union with MATLAB-compatible ordering and index outputs.",
258    keywords = "union,set,stable,rows,indices,gpu",
259    accel = "array_construct",
260    sink = true
261)]
262fn union_builtin(a: Value, b: Value, rest: Vec<Value>) -> Result<Value, String> {
263    evaluate(a, b, &rest).map(|eval| eval.into_values_value())
264}
265
266/// Evaluate the `union` builtin once and expose all outputs.
267pub fn evaluate(a: Value, b: Value, rest: &[Value]) -> Result<UnionEvaluation, String> {
268    let opts = parse_options(rest)?;
269    match (a, b) {
270        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
271            union_gpu_pair(handle_a, handle_b, &opts)
272        }
273        (Value::GpuTensor(handle_a), other) => union_gpu_mixed(handle_a, other, &opts, true),
274        (other, Value::GpuTensor(handle_b)) => union_gpu_mixed(handle_b, other, &opts, false),
275        (left, right) => union_host(left, right, &opts),
276    }
277}
278
279fn parse_options(rest: &[Value]) -> Result<UnionOptions, String> {
280    let mut opts = UnionOptions {
281        rows: false,
282        order: UnionOrder::Sorted,
283    };
284    let mut seen_order: Option<UnionOrder> = None;
285
286    for arg in rest {
287        let text = tensor::value_to_string(arg)
288            .ok_or_else(|| "union: expected string option arguments".to_string())?;
289        let lowered = text.trim().to_ascii_lowercase();
290        match lowered.as_str() {
291            "rows" => opts.rows = true,
292            "sorted" => {
293                if let Some(prev) = seen_order {
294                    if prev != UnionOrder::Sorted {
295                        return Err("union: cannot combine 'sorted' with 'stable'".to_string());
296                    }
297                }
298                seen_order = Some(UnionOrder::Sorted);
299                opts.order = UnionOrder::Sorted;
300            }
301            "stable" => {
302                if let Some(prev) = seen_order {
303                    if prev != UnionOrder::Stable {
304                        return Err("union: cannot combine 'sorted' with 'stable'".to_string());
305                    }
306                }
307                seen_order = Some(UnionOrder::Stable);
308                opts.order = UnionOrder::Stable;
309            }
310            "legacy" | "r2012a" => {
311                return Err("union: the 'legacy' behaviour is not supported".to_string());
312            }
313            other => return Err(format!("union: unrecognised option '{other}'")),
314        }
315    }
316
317    Ok(opts)
318}
319
320fn union_gpu_pair(
321    handle_a: GpuTensorHandle,
322    handle_b: GpuTensorHandle,
323    opts: &UnionOptions,
324) -> Result<UnionEvaluation, String> {
325    if let Some(provider) = runmat_accelerate_api::provider() {
326        match provider.union(&handle_a, &handle_b, opts) {
327            Ok(result) => return UnionEvaluation::from_union_result(result),
328            Err(_) => {
329                // Fall back to host gather when provider union is unavailable.
330            }
331        }
332    }
333    let tensor_a = gpu_helpers::gather_tensor(&handle_a)?;
334    let tensor_b = gpu_helpers::gather_tensor(&handle_b)?;
335    union_numeric(tensor_a, tensor_b, opts)
336}
337
338fn union_gpu_mixed(
339    handle_gpu: GpuTensorHandle,
340    other: Value,
341    opts: &UnionOptions,
342    gpu_is_a: bool,
343) -> Result<UnionEvaluation, String> {
344    let tensor_gpu = gpu_helpers::gather_tensor(&handle_gpu)?;
345    let tensor_other = tensor::value_into_tensor_for("union", other)?;
346    if gpu_is_a {
347        union_numeric(tensor_gpu, tensor_other, opts)
348    } else {
349        union_numeric(tensor_other, tensor_gpu, opts)
350    }
351}
352
353fn union_host(a: Value, b: Value, opts: &UnionOptions) -> Result<UnionEvaluation, String> {
354    match (a, b) {
355        // Complex cases
356        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => union_complex(at, bt, opts),
357        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
358            let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
359                .map_err(|e| format!("union: {e}"))?;
360            union_complex(at, bt, opts)
361        }
362        (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
363            let at = ComplexTensor::new(vec![(re, im)], vec![1, 1])
364                .map_err(|e| format!("union: {e}"))?;
365            union_complex(at, bt, opts)
366        }
367        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
368            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
369                .map_err(|e| format!("union: {e}"))?;
370            let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
371                .map_err(|e| format!("union: {e}"))?;
372            union_complex(at, bt, opts)
373        }
374
375        // Character arrays
376        (Value::CharArray(ac), Value::CharArray(bc)) => union_char(ac, bc, opts),
377
378        // String arrays / scalars
379        (Value::StringArray(astring), Value::StringArray(bstring)) => {
380            union_string(astring, bstring, opts)
381        }
382        (Value::StringArray(astring), Value::String(b)) => {
383            let bstring =
384                StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("union: {e}"))?;
385            union_string(astring, bstring, opts)
386        }
387        (Value::String(a), Value::StringArray(bstring)) => {
388            let astring =
389                StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("union: {e}"))?;
390            union_string(astring, bstring, opts)
391        }
392        (Value::String(a), Value::String(b)) => {
393            let astring =
394                StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("union: {e}"))?;
395            let bstring =
396                StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("union: {e}"))?;
397            union_string(astring, bstring, opts)
398        }
399
400        // Fallback to numeric (includes tensors, logical arrays, ints, bools, doubles)
401        (left, right) => {
402            let tensor_a = tensor::value_into_tensor_for("union", left)?;
403            let tensor_b = tensor::value_into_tensor_for("union", right)?;
404            union_numeric(tensor_a, tensor_b, opts)
405        }
406    }
407}
408
409fn union_numeric(a: Tensor, b: Tensor, opts: &UnionOptions) -> Result<UnionEvaluation, String> {
410    if opts.rows {
411        union_numeric_rows(a, b, opts)
412    } else {
413        union_numeric_elements(a, b, opts)
414    }
415}
416
417/// Helper exposed for acceleration providers handling numeric tensors entirely on the host.
418pub fn union_numeric_from_tensors(
419    a: Tensor,
420    b: Tensor,
421    opts: &UnionOptions,
422) -> Result<UnionEvaluation, String> {
423    union_numeric(a, b, opts)
424}
425
426fn union_numeric_elements(
427    a: Tensor,
428    b: Tensor,
429    opts: &UnionOptions,
430) -> Result<UnionEvaluation, String> {
431    let mut entries = Vec::<NumericUnionEntry>::new();
432    let mut map: HashMap<u64, usize> = HashMap::new();
433    let mut order_counter = 0usize;
434
435    for (idx, &value) in a.data.iter().enumerate() {
436        let key = canonicalize_f64(value);
437        match map.entry(key) {
438            Entry::Occupied(_) => {
439                // Already recorded from A; keep first occurrence only.
440            }
441            Entry::Vacant(v) => {
442                let entry_idx = entries.len();
443                entries.push(NumericUnionEntry {
444                    value,
445                    a_index: Some(idx),
446                    b_index: None,
447                    order_rank: order_counter,
448                });
449                v.insert(entry_idx);
450                order_counter += 1;
451            }
452        }
453    }
454
455    for (idx, &value) in b.data.iter().enumerate() {
456        let key = canonicalize_f64(value);
457        match map.entry(key) {
458            Entry::Occupied(occ) => {
459                let entry = &mut entries[*occ.get()];
460                if entry.a_index.is_none() && entry.b_index.is_none() {
461                    entry.b_index = Some(idx);
462                }
463            }
464            Entry::Vacant(v) => {
465                let entry_idx = entries.len();
466                entries.push(NumericUnionEntry {
467                    value,
468                    a_index: None,
469                    b_index: Some(idx),
470                    order_rank: order_counter,
471                });
472                v.insert(entry_idx);
473                order_counter += 1;
474            }
475        }
476    }
477
478    assemble_numeric_union(entries, opts)
479}
480
481fn union_numeric_rows(
482    a: Tensor,
483    b: Tensor,
484    opts: &UnionOptions,
485) -> Result<UnionEvaluation, String> {
486    if a.shape.len() != 2 || b.shape.len() != 2 {
487        return Err("union: 'rows' option requires 2-D numeric matrices".to_string());
488    }
489    if a.shape[1] != b.shape[1] {
490        return Err(
491            "union: inputs must have the same number of columns when using 'rows'".to_string(),
492        );
493    }
494    let rows_a = a.shape[0];
495    let cols = a.shape[1];
496    let rows_b = b.shape[0];
497
498    let mut entries = Vec::<NumericRowUnionEntry>::new();
499    let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
500    let mut order_counter = 0usize;
501
502    for r in 0..rows_a {
503        let mut row_values = Vec::with_capacity(cols);
504        for c in 0..cols {
505            let idx = r + c * rows_a;
506            row_values.push(a.data[idx]);
507        }
508        let key = NumericRowKey::from_slice(&row_values);
509        match map.entry(key) {
510            Entry::Occupied(_) => {}
511            Entry::Vacant(v) => {
512                let entry_idx = entries.len();
513                entries.push(NumericRowUnionEntry {
514                    row_data: row_values,
515                    a_row: Some(r),
516                    b_row: None,
517                    order_rank: order_counter,
518                });
519                v.insert(entry_idx);
520                order_counter += 1;
521            }
522        }
523    }
524
525    for r in 0..rows_b {
526        let mut row_values = Vec::with_capacity(cols);
527        for c in 0..cols {
528            let idx = r + c * rows_b;
529            row_values.push(b.data[idx]);
530        }
531        let key = NumericRowKey::from_slice(&row_values);
532        match map.entry(key) {
533            Entry::Occupied(occ) => {
534                let entry = &mut entries[*occ.get()];
535                if entry.a_row.is_none() && entry.b_row.is_none() {
536                    entry.b_row = Some(r);
537                }
538            }
539            Entry::Vacant(v) => {
540                let entry_idx = entries.len();
541                entries.push(NumericRowUnionEntry {
542                    row_data: row_values,
543                    a_row: None,
544                    b_row: Some(r),
545                    order_rank: order_counter,
546                });
547                v.insert(entry_idx);
548                order_counter += 1;
549            }
550        }
551    }
552
553    assemble_numeric_row_union(entries, opts, cols)
554}
555
556fn union_complex(
557    a: ComplexTensor,
558    b: ComplexTensor,
559    opts: &UnionOptions,
560) -> Result<UnionEvaluation, String> {
561    if opts.rows {
562        union_complex_rows(a, b, opts)
563    } else {
564        union_complex_elements(a, b, opts)
565    }
566}
567
568fn union_complex_elements(
569    a: ComplexTensor,
570    b: ComplexTensor,
571    opts: &UnionOptions,
572) -> Result<UnionEvaluation, String> {
573    let mut entries = Vec::<ComplexUnionEntry>::new();
574    let mut map: HashMap<ComplexKey, usize> = HashMap::new();
575    let mut order_counter = 0usize;
576
577    for (idx, &value) in a.data.iter().enumerate() {
578        let key = ComplexKey::new(value);
579        match map.entry(key) {
580            Entry::Occupied(_) => {}
581            Entry::Vacant(v) => {
582                let entry_idx = entries.len();
583                entries.push(ComplexUnionEntry {
584                    value,
585                    a_index: Some(idx),
586                    b_index: None,
587                    order_rank: order_counter,
588                });
589                v.insert(entry_idx);
590                order_counter += 1;
591            }
592        }
593    }
594
595    for (idx, &value) in b.data.iter().enumerate() {
596        let key = ComplexKey::new(value);
597        match map.entry(key) {
598            Entry::Occupied(occ) => {
599                let entry = &mut entries[*occ.get()];
600                if entry.a_index.is_none() && entry.b_index.is_none() {
601                    entry.b_index = Some(idx);
602                }
603            }
604            Entry::Vacant(v) => {
605                let entry_idx = entries.len();
606                entries.push(ComplexUnionEntry {
607                    value,
608                    a_index: None,
609                    b_index: Some(idx),
610                    order_rank: order_counter,
611                });
612                v.insert(entry_idx);
613                order_counter += 1;
614            }
615        }
616    }
617
618    assemble_complex_union(entries, opts)
619}
620
621fn union_complex_rows(
622    a: ComplexTensor,
623    b: ComplexTensor,
624    opts: &UnionOptions,
625) -> Result<UnionEvaluation, String> {
626    if a.shape.len() != 2 || b.shape.len() != 2 {
627        return Err("union: 'rows' option requires 2-D complex matrices".to_string());
628    }
629    if a.shape[1] != b.shape[1] {
630        return Err(
631            "union: inputs must have the same number of columns when using 'rows'".to_string(),
632        );
633    }
634    let rows_a = a.shape[0];
635    let cols = a.shape[1];
636    let rows_b = b.shape[0];
637
638    let mut entries = Vec::<ComplexRowUnionEntry>::new();
639    let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
640    let mut order_counter = 0usize;
641
642    for r in 0..rows_a {
643        let mut row_values = Vec::with_capacity(cols);
644        let mut key_row = Vec::with_capacity(cols);
645        for c in 0..cols {
646            let idx = r + c * rows_a;
647            let value = a.data[idx];
648            row_values.push(value);
649            key_row.push(ComplexKey::new(value));
650        }
651        match map.entry(key_row) {
652            Entry::Occupied(_) => {}
653            Entry::Vacant(v) => {
654                let entry_idx = entries.len();
655                entries.push(ComplexRowUnionEntry {
656                    row_data: row_values,
657                    a_row: Some(r),
658                    b_row: None,
659                    order_rank: order_counter,
660                });
661                v.insert(entry_idx);
662                order_counter += 1;
663            }
664        }
665    }
666
667    for r in 0..rows_b {
668        let mut row_values = Vec::with_capacity(cols);
669        let mut key_row = Vec::with_capacity(cols);
670        for c in 0..cols {
671            let idx = r + c * rows_b;
672            let value = b.data[idx];
673            row_values.push(value);
674            key_row.push(ComplexKey::new(value));
675        }
676        match map.entry(key_row) {
677            Entry::Occupied(occ) => {
678                let entry = &mut entries[*occ.get()];
679                if entry.a_row.is_none() && entry.b_row.is_none() {
680                    entry.b_row = Some(r);
681                }
682            }
683            Entry::Vacant(v) => {
684                let entry_idx = entries.len();
685                entries.push(ComplexRowUnionEntry {
686                    row_data: row_values,
687                    a_row: None,
688                    b_row: Some(r),
689                    order_rank: order_counter,
690                });
691                v.insert(entry_idx);
692                order_counter += 1;
693            }
694        }
695    }
696
697    assemble_complex_row_union(entries, opts, cols)
698}
699
700fn union_char(a: CharArray, b: CharArray, opts: &UnionOptions) -> Result<UnionEvaluation, String> {
701    if opts.rows {
702        union_char_rows(a, b, opts)
703    } else {
704        union_char_elements(a, b, opts)
705    }
706}
707
708fn union_char_elements(
709    a: CharArray,
710    b: CharArray,
711    opts: &UnionOptions,
712) -> Result<UnionEvaluation, String> {
713    let mut entries = Vec::<CharUnionEntry>::new();
714    let mut map: HashMap<u32, usize> = HashMap::new();
715    let mut order_counter = 0usize;
716
717    for col in 0..a.cols {
718        for row in 0..a.rows {
719            let linear_idx = row + col * a.rows;
720            let data_idx = row * a.cols + col;
721            let ch = a.data[data_idx];
722            let key = ch as u32;
723            match map.entry(key) {
724                Entry::Occupied(_) => {}
725                Entry::Vacant(v) => {
726                    let entry_idx = entries.len();
727                    entries.push(CharUnionEntry {
728                        ch,
729                        a_index: Some(linear_idx),
730                        b_index: None,
731                        order_rank: order_counter,
732                    });
733                    v.insert(entry_idx);
734                    order_counter += 1;
735                }
736            }
737        }
738    }
739
740    for col in 0..b.cols {
741        for row in 0..b.rows {
742            let linear_idx = row + col * b.rows;
743            let data_idx = row * b.cols + col;
744            let ch = b.data[data_idx];
745            let key = ch as u32;
746            match map.entry(key) {
747                Entry::Occupied(occ) => {
748                    let entry = &mut entries[*occ.get()];
749                    if entry.a_index.is_none() && entry.b_index.is_none() {
750                        entry.b_index = Some(linear_idx);
751                    }
752                }
753                Entry::Vacant(v) => {
754                    let entry_idx = entries.len();
755                    entries.push(CharUnionEntry {
756                        ch,
757                        a_index: None,
758                        b_index: Some(linear_idx),
759                        order_rank: order_counter,
760                    });
761                    v.insert(entry_idx);
762                    order_counter += 1;
763                }
764            }
765        }
766    }
767
768    assemble_char_union(entries, opts)
769}
770
771fn union_char_rows(
772    a: CharArray,
773    b: CharArray,
774    opts: &UnionOptions,
775) -> Result<UnionEvaluation, String> {
776    if a.cols != b.cols {
777        return Err(
778            "union: inputs must have the same number of columns when using 'rows'".to_string(),
779        );
780    }
781    let rows_a = a.rows;
782    let rows_b = b.rows;
783    let cols = a.cols;
784
785    let mut entries = Vec::<CharRowUnionEntry>::new();
786    let mut map: HashMap<RowCharKey, usize> = HashMap::new();
787    let mut order_counter = 0usize;
788
789    for r in 0..rows_a {
790        let mut row_values = Vec::with_capacity(cols);
791        for c in 0..cols {
792            let idx = r * cols + c;
793            row_values.push(a.data[idx]);
794        }
795        let key = RowCharKey::from_slice(&row_values);
796        match map.entry(key) {
797            Entry::Occupied(_) => {}
798            Entry::Vacant(v) => {
799                let entry_idx = entries.len();
800                entries.push(CharRowUnionEntry {
801                    row_data: row_values,
802                    a_row: Some(r),
803                    b_row: None,
804                    order_rank: order_counter,
805                });
806                v.insert(entry_idx);
807                order_counter += 1;
808            }
809        }
810    }
811
812    for r in 0..rows_b {
813        let mut row_values = Vec::with_capacity(cols);
814        for c in 0..cols {
815            let idx = r * cols + c;
816            row_values.push(b.data[idx]);
817        }
818        let key = RowCharKey::from_slice(&row_values);
819        match map.entry(key) {
820            Entry::Occupied(occ) => {
821                let entry = &mut entries[*occ.get()];
822                if entry.a_row.is_none() && entry.b_row.is_none() {
823                    entry.b_row = Some(r);
824                }
825            }
826            Entry::Vacant(v) => {
827                let entry_idx = entries.len();
828                entries.push(CharRowUnionEntry {
829                    row_data: row_values,
830                    a_row: None,
831                    b_row: Some(r),
832                    order_rank: order_counter,
833                });
834                v.insert(entry_idx);
835                order_counter += 1;
836            }
837        }
838    }
839
840    assemble_char_row_union(entries, opts, cols)
841}
842
843fn union_string(
844    a: StringArray,
845    b: StringArray,
846    opts: &UnionOptions,
847) -> Result<UnionEvaluation, String> {
848    if opts.rows {
849        union_string_rows(a, b, opts)
850    } else {
851        union_string_elements(a, b, opts)
852    }
853}
854
855fn union_string_elements(
856    a: StringArray,
857    b: StringArray,
858    opts: &UnionOptions,
859) -> Result<UnionEvaluation, String> {
860    let mut entries = Vec::<StringUnionEntry>::new();
861    let mut map: HashMap<String, usize> = HashMap::new();
862    let mut order_counter = 0usize;
863
864    for (idx, value) in a.data.iter().enumerate() {
865        match map.entry(value.clone()) {
866            Entry::Occupied(_) => {}
867            Entry::Vacant(v) => {
868                let entry_idx = entries.len();
869                entries.push(StringUnionEntry {
870                    value: value.clone(),
871                    a_index: Some(idx),
872                    b_index: None,
873                    order_rank: order_counter,
874                });
875                v.insert(entry_idx);
876                order_counter += 1;
877            }
878        }
879    }
880
881    for (idx, value) in b.data.iter().enumerate() {
882        match map.entry(value.clone()) {
883            Entry::Occupied(occ) => {
884                let entry = &mut entries[*occ.get()];
885                if entry.a_index.is_none() && entry.b_index.is_none() {
886                    entry.b_index = Some(idx);
887                }
888            }
889            Entry::Vacant(v) => {
890                let entry_idx = entries.len();
891                entries.push(StringUnionEntry {
892                    value: value.clone(),
893                    a_index: None,
894                    b_index: Some(idx),
895                    order_rank: order_counter,
896                });
897                v.insert(entry_idx);
898                order_counter += 1;
899            }
900        }
901    }
902
903    assemble_string_union(entries, opts)
904}
905
906fn union_string_rows(
907    a: StringArray,
908    b: StringArray,
909    opts: &UnionOptions,
910) -> Result<UnionEvaluation, String> {
911    if a.shape.len() != 2 || b.shape.len() != 2 {
912        return Err("union: 'rows' option requires 2-D string arrays".to_string());
913    }
914    if a.shape[1] != b.shape[1] {
915        return Err(
916            "union: inputs must have the same number of columns when using 'rows'".to_string(),
917        );
918    }
919    let rows_a = a.shape[0];
920    let cols = a.shape[1];
921    let rows_b = b.shape[0];
922
923    let mut entries = Vec::<StringRowUnionEntry>::new();
924    let mut map: HashMap<RowStringKey, usize> = HashMap::new();
925    let mut order_counter = 0usize;
926
927    for r in 0..rows_a {
928        let mut row_values = Vec::with_capacity(cols);
929        for c in 0..cols {
930            let idx = r + c * rows_a;
931            row_values.push(a.data[idx].clone());
932        }
933        let key = RowStringKey(row_values.clone());
934        match map.entry(key) {
935            Entry::Occupied(_) => {}
936            Entry::Vacant(v) => {
937                let entry_idx = entries.len();
938                entries.push(StringRowUnionEntry {
939                    row_data: row_values,
940                    a_row: Some(r),
941                    b_row: None,
942                    order_rank: order_counter,
943                });
944                v.insert(entry_idx);
945                order_counter += 1;
946            }
947        }
948    }
949
950    for r in 0..rows_b {
951        let mut row_values = Vec::with_capacity(cols);
952        for c in 0..cols {
953            let idx = r + c * rows_b;
954            row_values.push(b.data[idx].clone());
955        }
956        let key = RowStringKey(row_values.clone());
957        match map.entry(key) {
958            Entry::Occupied(occ) => {
959                let entry = &mut entries[*occ.get()];
960                if entry.a_row.is_none() && entry.b_row.is_none() {
961                    entry.b_row = Some(r);
962                }
963            }
964            Entry::Vacant(v) => {
965                let entry_idx = entries.len();
966                entries.push(StringRowUnionEntry {
967                    row_data: row_values,
968                    a_row: None,
969                    b_row: Some(r),
970                    order_rank: order_counter,
971                });
972                v.insert(entry_idx);
973                order_counter += 1;
974            }
975        }
976    }
977
978    assemble_string_row_union(entries, opts, cols)
979}
980
981#[derive(Debug, Clone)]
982pub struct UnionEvaluation {
983    values: Value,
984    ia: Tensor,
985    ib: Tensor,
986}
987
988impl UnionEvaluation {
989    fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
990        Self { values, ia, ib }
991    }
992
993    pub fn from_union_result(result: UnionResult) -> Result<Self, String> {
994        let UnionResult { values, ia, ib } = result;
995        let values_tensor =
996            Tensor::new(values.data, values.shape).map_err(|e| format!("union: {e}"))?;
997        let ia_tensor = Tensor::new(ia.data, ia.shape).map_err(|e| format!("union: {e}"))?;
998        let ib_tensor = Tensor::new(ib.data, ib.shape).map_err(|e| format!("union: {e}"))?;
999        Ok(UnionEvaluation::new(
1000            tensor::tensor_into_value(values_tensor),
1001            ia_tensor,
1002            ib_tensor,
1003        ))
1004    }
1005
1006    pub fn into_numeric_union_result(self) -> Result<UnionResult, String> {
1007        let UnionEvaluation { values, ia, ib } = self;
1008        let values_tensor = tensor::value_into_tensor_for("union", values)?;
1009        Ok(UnionResult {
1010            values: HostTensorOwned {
1011                data: values_tensor.data,
1012                shape: values_tensor.shape,
1013            },
1014            ia: HostTensorOwned {
1015                data: ia.data,
1016                shape: ia.shape,
1017            },
1018            ib: HostTensorOwned {
1019                data: ib.data,
1020                shape: ib.shape,
1021            },
1022        })
1023    }
1024
1025    pub fn into_values_value(self) -> Value {
1026        self.values
1027    }
1028
1029    pub fn into_pair(self) -> (Value, Value) {
1030        let ia = tensor::tensor_into_value(self.ia);
1031        (self.values, ia)
1032    }
1033
1034    pub fn into_triple(self) -> (Value, Value, Value) {
1035        let ia = tensor::tensor_into_value(self.ia);
1036        let ib = tensor::tensor_into_value(self.ib);
1037        (self.values, ia, ib)
1038    }
1039
1040    pub fn values_value(&self) -> Value {
1041        self.values.clone()
1042    }
1043
1044    pub fn ia_value(&self) -> Value {
1045        tensor::tensor_into_value(self.ia.clone())
1046    }
1047
1048    pub fn ib_value(&self) -> Value {
1049        tensor::tensor_into_value(self.ib.clone())
1050    }
1051}
1052
/// One distinct real value collected for an element-wise numeric union.
#[derive(Debug)]
struct NumericUnionEntry {
    /// The value itself.
    value: f64,
    /// Zero-based position in the first input, when recorded; it takes
    /// precedence over `b_index` when the index outputs are assembled.
    a_index: Option<usize>,
    /// Zero-based position in the second input, when recorded.
    b_index: Option<usize>,
    /// Sort key used for `UnionOrder::Stable` output.
    order_rank: usize,
}
1060
/// One distinct real-valued row collected for a row-wise numeric union.
#[derive(Debug)]
struct NumericRowUnionEntry {
    /// The row's elements, one per column.
    row_data: Vec<f64>,
    /// Zero-based row number in the first input, when recorded; it takes
    /// precedence over `b_row` when the index outputs are assembled.
    a_row: Option<usize>,
    /// Zero-based row number in the second input, when recorded.
    b_row: Option<usize>,
    /// Sort key used for `UnionOrder::Stable` output.
    order_rank: usize,
}
1068
/// One distinct complex value collected for an element-wise complex union.
#[derive(Debug)]
struct ComplexUnionEntry {
    /// The value as a `(real, imaginary)` pair.
    value: (f64, f64),
    /// Zero-based position in the first input, when recorded; it takes
    /// precedence over `b_index` when the index outputs are assembled.
    a_index: Option<usize>,
    /// Zero-based position in the second input, when recorded.
    b_index: Option<usize>,
    /// Sort key used for `UnionOrder::Stable` output.
    order_rank: usize,
}
1076
/// One distinct complex-valued row collected for a row-wise complex union.
#[derive(Debug)]
struct ComplexRowUnionEntry {
    /// The row's elements as `(real, imaginary)` pairs, one per column.
    row_data: Vec<(f64, f64)>,
    /// Zero-based row number in the first input, when recorded; it takes
    /// precedence over `b_row` when the index outputs are assembled.
    a_row: Option<usize>,
    /// Zero-based row number in the second input, when recorded.
    b_row: Option<usize>,
    /// Sort key used for `UnionOrder::Stable` output.
    order_rank: usize,
}
1084
/// One distinct character collected for an element-wise char-array union.
#[derive(Debug)]
struct CharUnionEntry {
    /// The character itself.
    ch: char,
    /// Zero-based position in the first input, when recorded; it takes
    /// precedence over `b_index` when the index outputs are assembled.
    a_index: Option<usize>,
    /// Zero-based position in the second input, when recorded.
    b_index: Option<usize>,
    /// Sort key used for `UnionOrder::Stable` output.
    order_rank: usize,
}
1092
/// One distinct row of characters collected for a row-wise char-array union.
#[derive(Debug)]
struct CharRowUnionEntry {
    /// The row's characters, one per column.
    row_data: Vec<char>,
    /// Zero-based row number in the first input, when recorded; it takes
    /// precedence over `b_row` when the index outputs are assembled.
    a_row: Option<usize>,
    /// Zero-based row number in the second input, when recorded.
    b_row: Option<usize>,
    /// Sort key used for `UnionOrder::Stable` output.
    order_rank: usize,
}
1100
/// One distinct string collected for an element-wise string-array union.
#[derive(Debug)]
struct StringUnionEntry {
    /// The string itself.
    value: String,
    /// Zero-based position in the first input, when recorded; it takes
    /// precedence over `b_index` when the index outputs are assembled.
    a_index: Option<usize>,
    /// Zero-based position in the second input, when recorded.
    b_index: Option<usize>,
    /// Sort key used for `UnionOrder::Stable` output.
    order_rank: usize,
}
1108
/// One distinct row of strings collected for a row-wise string-array union.
#[derive(Debug)]
struct StringRowUnionEntry {
    /// The row's strings, one per column.
    row_data: Vec<String>,
    /// Zero-based row number in the first input, when recorded; it takes
    /// precedence over `b_row` when the index outputs are assembled.
    a_row: Option<usize>,
    /// Zero-based row number in the second input, when recorded.
    b_row: Option<usize>,
    /// Sort key used for `UnionOrder::Stable` output.
    order_rank: usize,
}
1116
1117#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1118struct NumericRowKey(Vec<u64>);
1119
1120impl NumericRowKey {
1121    fn from_slice(values: &[f64]) -> Self {
1122        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1123    }
1124}
1125
1126#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1127struct ComplexKey {
1128    re: u64,
1129    im: u64,
1130}
1131
1132impl ComplexKey {
1133    fn new(value: (f64, f64)) -> Self {
1134        Self {
1135            re: canonicalize_f64(value.0),
1136            im: canonicalize_f64(value.1),
1137        }
1138    }
1139}
1140
/// Hashable identity of a row of characters, stored as Unicode scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key for `values`, converting each character to its `u32`
    /// code point in order.
    fn from_slice(values: &[char]) -> Self {
        Self(values.iter().copied().map(u32::from).collect())
    }
}
1149
/// Hashable identity of a row of strings; two keys are equal when they hold
/// the same strings in the same order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1152
1153fn assemble_numeric_union(
1154    entries: Vec<NumericUnionEntry>,
1155    opts: &UnionOptions,
1156) -> Result<UnionEvaluation, String> {
1157    let mut order: Vec<usize> = (0..entries.len()).collect();
1158    match opts.order {
1159        UnionOrder::Sorted => {
1160            order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
1161        }
1162        UnionOrder::Stable => {
1163            order.sort_by_key(|&idx| entries[idx].order_rank);
1164        }
1165    }
1166
1167    let mut values = Vec::with_capacity(order.len());
1168    let mut ia = Vec::new();
1169    let mut ib = Vec::new();
1170    for &idx in &order {
1171        let entry = &entries[idx];
1172        values.push(entry.value);
1173        if let Some(a_idx) = entry.a_index {
1174            ia.push((a_idx + 1) as f64);
1175        } else if let Some(b_idx) = entry.b_index {
1176            ib.push((b_idx + 1) as f64);
1177        }
1178    }
1179
1180    let value_tensor =
1181        Tensor::new(values, vec![order.len(), 1]).map_err(|e| format!("union: {e}"))?;
1182    let ia_len = ia.len();
1183    let ib_len = ib.len();
1184    let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1185    let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1186
1187    Ok(UnionEvaluation::new(
1188        tensor::tensor_into_value(value_tensor),
1189        ia_tensor,
1190        ib_tensor,
1191    ))
1192}
1193
1194fn assemble_numeric_row_union(
1195    entries: Vec<NumericRowUnionEntry>,
1196    opts: &UnionOptions,
1197    cols: usize,
1198) -> Result<UnionEvaluation, String> {
1199    let mut order: Vec<usize> = (0..entries.len()).collect();
1200    match opts.order {
1201        UnionOrder::Sorted => {
1202            order.sort_by(|&lhs, &rhs| {
1203                compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1204            });
1205        }
1206        UnionOrder::Stable => {
1207            order.sort_by_key(|&idx| entries[idx].order_rank);
1208        }
1209    }
1210
1211    let unique_rows = order.len();
1212    let mut values = vec![0.0f64; unique_rows * cols];
1213    let mut ia = Vec::new();
1214    let mut ib = Vec::new();
1215
1216    for (row_pos, &entry_idx) in order.iter().enumerate() {
1217        let entry = &entries[entry_idx];
1218        for col in 0..cols {
1219            let dest = row_pos + col * unique_rows;
1220            values[dest] = entry.row_data[col];
1221        }
1222        if let Some(a_row) = entry.a_row {
1223            ia.push((a_row + 1) as f64);
1224        } else if let Some(b_row) = entry.b_row {
1225            ib.push((b_row + 1) as f64);
1226        }
1227    }
1228
1229    let value_tensor =
1230        Tensor::new(values, vec![unique_rows, cols]).map_err(|e| format!("union: {e}"))?;
1231    let ia_len = ia.len();
1232    let ib_len = ib.len();
1233    let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1234    let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1235
1236    Ok(UnionEvaluation::new(
1237        tensor::tensor_into_value(value_tensor),
1238        ia_tensor,
1239        ib_tensor,
1240    ))
1241}
1242
1243fn assemble_complex_union(
1244    entries: Vec<ComplexUnionEntry>,
1245    opts: &UnionOptions,
1246) -> Result<UnionEvaluation, String> {
1247    let mut order: Vec<usize> = (0..entries.len()).collect();
1248    match opts.order {
1249        UnionOrder::Sorted => {
1250            order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
1251        }
1252        UnionOrder::Stable => {
1253            order.sort_by_key(|&idx| entries[idx].order_rank);
1254        }
1255    }
1256
1257    let mut values = Vec::with_capacity(order.len());
1258    let mut ia = Vec::new();
1259    let mut ib = Vec::new();
1260    for &idx in &order {
1261        let entry = &entries[idx];
1262        values.push(entry.value);
1263        if let Some(a_idx) = entry.a_index {
1264            ia.push((a_idx + 1) as f64);
1265        } else if let Some(b_idx) = entry.b_index {
1266            ib.push((b_idx + 1) as f64);
1267        }
1268    }
1269
1270    let value_tensor =
1271        ComplexTensor::new(values, vec![order.len(), 1]).map_err(|e| format!("union: {e}"))?;
1272    let ia_len = ia.len();
1273    let ib_len = ib.len();
1274    let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1275    let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1276
1277    Ok(UnionEvaluation::new(
1278        complex_tensor_into_value(value_tensor),
1279        ia_tensor,
1280        ib_tensor,
1281    ))
1282}
1283
1284fn assemble_complex_row_union(
1285    entries: Vec<ComplexRowUnionEntry>,
1286    opts: &UnionOptions,
1287    cols: usize,
1288) -> Result<UnionEvaluation, String> {
1289    let mut order: Vec<usize> = (0..entries.len()).collect();
1290    match opts.order {
1291        UnionOrder::Sorted => {
1292            order.sort_by(|&lhs, &rhs| {
1293                compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1294            });
1295        }
1296        UnionOrder::Stable => {
1297            order.sort_by_key(|&idx| entries[idx].order_rank);
1298        }
1299    }
1300
1301    let unique_rows = order.len();
1302    let mut values = vec![(0.0, 0.0); unique_rows * cols];
1303    let mut ia = Vec::new();
1304    let mut ib = Vec::new();
1305
1306    for (row_pos, &entry_idx) in order.iter().enumerate() {
1307        let entry = &entries[entry_idx];
1308        for col in 0..cols {
1309            let dest = row_pos + col * unique_rows;
1310            values[dest] = entry.row_data[col];
1311        }
1312        if let Some(a_row) = entry.a_row {
1313            ia.push((a_row + 1) as f64);
1314        } else if let Some(b_row) = entry.b_row {
1315            ib.push((b_row + 1) as f64);
1316        }
1317    }
1318
1319    let value_tensor =
1320        ComplexTensor::new(values, vec![unique_rows, cols]).map_err(|e| format!("union: {e}"))?;
1321    let ia_len = ia.len();
1322    let ib_len = ib.len();
1323    let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1324    let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1325
1326    Ok(UnionEvaluation::new(
1327        complex_tensor_into_value(value_tensor),
1328        ia_tensor,
1329        ib_tensor,
1330    ))
1331}
1332
1333fn assemble_char_union(
1334    entries: Vec<CharUnionEntry>,
1335    opts: &UnionOptions,
1336) -> Result<UnionEvaluation, String> {
1337    let mut order: Vec<usize> = (0..entries.len()).collect();
1338    match opts.order {
1339        UnionOrder::Sorted => {
1340            order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1341        }
1342        UnionOrder::Stable => {
1343            order.sort_by_key(|&idx| entries[idx].order_rank);
1344        }
1345    }
1346
1347    let mut values = Vec::with_capacity(order.len());
1348    let mut ia = Vec::new();
1349    let mut ib = Vec::new();
1350    for &idx in &order {
1351        let entry = &entries[idx];
1352        values.push(entry.ch);
1353        if let Some(a_idx) = entry.a_index {
1354            ia.push((a_idx + 1) as f64);
1355        } else if let Some(b_idx) = entry.b_index {
1356            ib.push((b_idx + 1) as f64);
1357        }
1358    }
1359
1360    let value_array = CharArray::new(values, order.len(), 1).map_err(|e| format!("union: {e}"))?;
1361    let ia_len = ia.len();
1362    let ib_len = ib.len();
1363    let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1364    let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1365
1366    Ok(UnionEvaluation::new(
1367        Value::CharArray(value_array),
1368        ia_tensor,
1369        ib_tensor,
1370    ))
1371}
1372
1373fn assemble_char_row_union(
1374    entries: Vec<CharRowUnionEntry>,
1375    opts: &UnionOptions,
1376    cols: usize,
1377) -> Result<UnionEvaluation, String> {
1378    let mut order: Vec<usize> = (0..entries.len()).collect();
1379    match opts.order {
1380        UnionOrder::Sorted => {
1381            order.sort_by(|&lhs, &rhs| {
1382                compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1383            });
1384        }
1385        UnionOrder::Stable => {
1386            order.sort_by_key(|&idx| entries[idx].order_rank);
1387        }
1388    }
1389
1390    let unique_rows = order.len();
1391    let mut values = vec!['\0'; unique_rows * cols];
1392    let mut ia = Vec::new();
1393    let mut ib = Vec::new();
1394
1395    for (row_pos, &entry_idx) in order.iter().enumerate() {
1396        let entry = &entries[entry_idx];
1397        for col in 0..cols {
1398            let dest = row_pos * cols + col;
1399            values[dest] = entry.row_data[col];
1400        }
1401        if let Some(a_row) = entry.a_row {
1402            ia.push((a_row + 1) as f64);
1403        } else if let Some(b_row) = entry.b_row {
1404            ib.push((b_row + 1) as f64);
1405        }
1406    }
1407
1408    let value_array =
1409        CharArray::new(values, unique_rows, cols).map_err(|e| format!("union: {e}"))?;
1410    let ia_len = ia.len();
1411    let ib_len = ib.len();
1412    let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1413    let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1414
1415    Ok(UnionEvaluation::new(
1416        Value::CharArray(value_array),
1417        ia_tensor,
1418        ib_tensor,
1419    ))
1420}
1421
1422fn assemble_string_union(
1423    entries: Vec<StringUnionEntry>,
1424    opts: &UnionOptions,
1425) -> Result<UnionEvaluation, String> {
1426    let mut order: Vec<usize> = (0..entries.len()).collect();
1427    match opts.order {
1428        UnionOrder::Sorted => {
1429            order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1430        }
1431        UnionOrder::Stable => {
1432            order.sort_by_key(|&idx| entries[idx].order_rank);
1433        }
1434    }
1435
1436    let mut values = Vec::with_capacity(order.len());
1437    let mut ia = Vec::new();
1438    let mut ib = Vec::new();
1439    for &idx in &order {
1440        let entry = &entries[idx];
1441        values.push(entry.value.clone());
1442        if let Some(a_idx) = entry.a_index {
1443            ia.push((a_idx + 1) as f64);
1444        } else if let Some(b_idx) = entry.b_index {
1445            ib.push((b_idx + 1) as f64);
1446        }
1447    }
1448
1449    let value_array =
1450        StringArray::new(values, vec![order.len(), 1]).map_err(|e| format!("union: {e}"))?;
1451    let ia_len = ia.len();
1452    let ib_len = ib.len();
1453    let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1454    let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1455
1456    Ok(UnionEvaluation::new(
1457        Value::StringArray(value_array),
1458        ia_tensor,
1459        ib_tensor,
1460    ))
1461}
1462
1463fn assemble_string_row_union(
1464    entries: Vec<StringRowUnionEntry>,
1465    opts: &UnionOptions,
1466    cols: usize,
1467) -> Result<UnionEvaluation, String> {
1468    let mut order: Vec<usize> = (0..entries.len()).collect();
1469    match opts.order {
1470        UnionOrder::Sorted => {
1471            order.sort_by(|&lhs, &rhs| {
1472                compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1473            });
1474        }
1475        UnionOrder::Stable => {
1476            order.sort_by_key(|&idx| entries[idx].order_rank);
1477        }
1478    }
1479
1480    let unique_rows = order.len();
1481    let mut values = vec![String::new(); unique_rows * cols];
1482    let mut ia = Vec::new();
1483    let mut ib = Vec::new();
1484
1485    for (row_pos, &entry_idx) in order.iter().enumerate() {
1486        let entry = &entries[entry_idx];
1487        for col in 0..cols {
1488            let dest = row_pos + col * unique_rows;
1489            values[dest] = entry.row_data[col].clone();
1490        }
1491        if let Some(a_row) = entry.a_row {
1492            ia.push((a_row + 1) as f64);
1493        } else if let Some(b_row) = entry.b_row {
1494            ib.push((b_row + 1) as f64);
1495        }
1496    }
1497
1498    let value_array =
1499        StringArray::new(values, vec![unique_rows, cols]).map_err(|e| format!("union: {e}"))?;
1500    let ia_len = ia.len();
1501    let ib_len = ib.len();
1502    let ia_tensor = Tensor::new(ia, vec![ia_len, 1]).map_err(|e| format!("union: {e}"))?;
1503    let ib_tensor = Tensor::new(ib, vec![ib_len, 1]).map_err(|e| format!("union: {e}"))?;
1504
1505    Ok(UnionEvaluation::new(
1506        Value::StringArray(value_array),
1507        ia_tensor,
1508        ib_tensor,
1509    ))
1510}
1511
/// Maps a float to a bit pattern suitable for hashing and equality.
///
/// Every NaN collapses to a single quiet-NaN pattern, and both `0.0` and
/// `-0.0` collapse to `0`; every other value keeps its own IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0`, so both signed zeros take this arm.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1521
/// Total ordering for floats that places NaN after every other value.
///
/// Two NaNs compare equal; a NaN is greater than any non-NaN; otherwise the
/// usual numeric comparison applies.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither operand is NaN, so `partial_cmp` is expected to succeed;
        // the fallback only keeps the function total.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1535
1536fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1537    for (lhs, rhs) in a.iter().zip(b.iter()) {
1538        let ord = compare_f64(*lhs, *rhs);
1539        if ord != Ordering::Equal {
1540            return ord;
1541        }
1542    }
1543    Ordering::Equal
1544}
1545
/// Reports whether either component of a `(real, imaginary)` pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
1549
1550fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1551    match (complex_is_nan(a), complex_is_nan(b)) {
1552        (true, true) => Ordering::Equal,
1553        (true, false) => Ordering::Greater,
1554        (false, true) => Ordering::Less,
1555        (false, false) => {
1556            let mag_a = a.0.hypot(a.1);
1557            let mag_b = b.0.hypot(b.1);
1558            let mag_cmp = compare_f64(mag_a, mag_b);
1559            if mag_cmp != Ordering::Equal {
1560                return mag_cmp;
1561            }
1562            let re_cmp = compare_f64(a.0, b.0);
1563            if re_cmp != Ordering::Equal {
1564                return re_cmp;
1565            }
1566            compare_f64(a.1, b.1)
1567        }
1568    }
1569}
1570
1571fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1572    for (lhs, rhs) in a.iter().zip(b.iter()) {
1573        let ord = compare_complex(*lhs, *rhs);
1574        if ord != Ordering::Equal {
1575            return ord;
1576        }
1577    }
1578    Ordering::Equal
1579}
1580
/// Lexicographic comparison of two character rows.
///
/// Characters are compared pairwise over the shorter of the two slices; the
/// first non-equal pair decides the result. If every compared pair is equal
/// the rows compare equal, regardless of any length difference.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1590
/// Lexicographic comparison of two string rows.
///
/// Strings are compared pairwise over the shorter of the two slices; the
/// first non-equal pair decides the result. If every compared pair is equal
/// the rows compare equal, regardless of any length difference.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1600
1601#[cfg(test)]
1602mod tests {
1603    use super::*;
1604    use crate::builtins::common::test_support;
1605    use runmat_accelerate_api::HostTensorView;
1606    use runmat_builtins::{IntValue, Tensor, Value};
1607
1608    #[test]
1609    fn union_numeric_sorted_default() {
1610        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
1611        let b = Tensor::new(vec![3.0, 1.0, 1.0], vec![3, 1]).unwrap();
1612        let eval = evaluate(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
1613        match eval.values_value() {
1614            Value::Tensor(t) => {
1615                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 7.0]);
1616                assert_eq!(t.shape, vec![4, 1]);
1617            }
1618            other => panic!("expected tensor result, got {other:?}"),
1619        }
1620        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1621        assert_eq!(ia.data, vec![3.0, 1.0, 2.0]);
1622        assert_eq!(ia.shape, vec![3, 1]);
1623        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1624        assert_eq!(ib.data, vec![1.0]);
1625        assert_eq!(ib.shape, vec![1, 1]);
1626    }
1627
1628    #[test]
1629    fn union_numeric_stable_order() {
1630        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
1631        let b = Tensor::new(vec![3.0, 2.0, 4.0], vec![3, 1]).unwrap();
1632        let eval =
1633            evaluate(Value::Tensor(a), Value::Tensor(b), &[Value::from("stable")]).expect("union");
1634        match eval.values_value() {
1635            Value::Tensor(t) => {
1636                assert_eq!(t.data, vec![5.0, 7.0, 1.0, 3.0, 2.0, 4.0]);
1637                assert_eq!(t.shape, vec![6, 1]);
1638            }
1639            other => panic!("expected tensor result, got {other:?}"),
1640        }
1641        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1642        assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
1643        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1644        assert_eq!(ib.data, vec![1.0, 2.0, 3.0]);
1645    }
1646
1647    #[test]
1648    fn union_numeric_sorted_places_nan_last() {
1649        let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
1650        let b = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
1651        let eval = evaluate(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
1652        let values = tensor::value_into_tensor_for("union", eval.values_value()).expect("values");
1653        assert_eq!(values.shape, vec![3, 1]);
1654        assert_eq!(values.data[0], 1.0);
1655        assert_eq!(values.data[1], 2.0);
1656        assert!(values.data[2].is_nan());
1657        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1658        assert_eq!(ia.data, vec![2.0, 1.0]);
1659        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1660        assert_eq!(ib.data, vec![1.0]);
1661    }
1662
1663    #[test]
1664    fn union_numeric_rows_sorted() {
1665        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
1666        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
1667        let eval =
1668            evaluate(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")]).expect("union");
1669        match eval.values_value() {
1670            Value::Tensor(t) => {
1671                assert_eq!(t.shape, vec![3, 2]);
1672                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
1673            }
1674            other => panic!("expected tensor result, got {other:?}"),
1675        }
1676        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1677        assert_eq!(ia.data, vec![1.0, 2.0]);
1678        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1679        assert_eq!(ib.data, vec![2.0]);
1680    }
1681
1682    #[test]
1683    fn union_numeric_rows_stable_preserves_first_occurrence() {
1684        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
1685        let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
1686        let eval = evaluate(
1687            Value::Tensor(a),
1688            Value::Tensor(b),
1689            &[Value::from("rows"), Value::from("stable")],
1690        )
1691        .expect("union");
1692        let (values, ia, ib) = eval.into_triple();
1693        match values {
1694            Value::Tensor(t) => {
1695                assert_eq!(t.shape, vec![3, 2]);
1696                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
1697            }
1698            other => panic!("expected tensor result, got {other:?}"),
1699        }
1700        let ia_tensor = tensor::value_into_tensor_for("union", ia).expect("ia tensor");
1701        assert_eq!(ia_tensor.data, vec![1.0, 2.0]);
1702        let ib_tensor = tensor::value_into_tensor_for("union", ib).expect("ib tensor");
1703        assert_eq!(ib_tensor.data, vec![2.0]);
1704    }
1705
1706    #[test]
1707    fn union_char_elements() {
1708        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
1709        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
1710        let eval = evaluate(Value::CharArray(a), Value::CharArray(b), &[]).expect("union");
1711        match eval.values_value() {
1712            Value::CharArray(arr) => {
1713                assert_eq!(arr.rows, 4);
1714                assert_eq!(arr.cols, 1);
1715                assert_eq!(arr.data, vec!['a', 'm', 'x', 'z']);
1716            }
1717            other => panic!("expected char array, got {other:?}"),
1718        }
1719        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1720        assert_eq!(ia.data, vec![4.0, 1.0, 3.0]);
1721        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1722        assert_eq!(ib.data, vec![3.0]);
1723    }
1724
1725    #[test]
1726    fn union_string_rows_stable() {
1727        let a = StringArray::new(
1728            vec![
1729                "alpha".to_string(),
1730                "gamma".to_string(),
1731                "beta".to_string(),
1732                "beta".to_string(),
1733            ],
1734            vec![2, 2],
1735        )
1736        .unwrap();
1737        let b = StringArray::new(
1738            vec![
1739                "gamma".to_string(),
1740                "delta".to_string(),
1741                "beta".to_string(),
1742                "beta".to_string(),
1743            ],
1744            vec![2, 2],
1745        )
1746        .unwrap();
1747        let eval = evaluate(
1748            Value::StringArray(a),
1749            Value::StringArray(b),
1750            &[Value::from("rows"), Value::from("stable")],
1751        )
1752        .expect("union");
1753        match eval.values_value() {
1754            Value::StringArray(arr) => {
1755                assert_eq!(arr.shape, vec![3, 2]);
1756                assert_eq!(
1757                    arr.data,
1758                    vec![
1759                        "alpha".to_string(),
1760                        "gamma".to_string(),
1761                        "delta".to_string(),
1762                        "beta".to_string(),
1763                        "beta".to_string(),
1764                        "beta".to_string()
1765                    ]
1766                );
1767            }
1768            other => panic!("expected string array, got {other:?}"),
1769        }
1770        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1771        assert_eq!(ia.data, vec![1.0, 2.0]);
1772        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1773        assert_eq!(ib.data, vec![2.0]);
1774    }
1775
1776    #[test]
1777    fn union_gpu_roundtrip() {
1778        test_support::with_test_provider(|provider| {
1779            let a = Tensor::new(vec![4.0, 1.0, 2.0], vec![3, 1]).unwrap();
1780            let b = Tensor::new(vec![2.0, 5.0], vec![2, 1]).unwrap();
1781            let view_a = HostTensorView {
1782                data: &a.data,
1783                shape: &a.shape,
1784            };
1785            let view_b = HostTensorView {
1786                data: &b.data,
1787                shape: &b.shape,
1788            };
1789            let handle_a = provider.upload(&view_a).expect("upload A");
1790            let handle_b = provider.upload(&view_b).expect("upload B");
1791            let eval = evaluate(
1792                Value::GpuTensor(handle_a),
1793                Value::GpuTensor(handle_b),
1794                &[Value::from("stable")],
1795            )
1796            .expect("union");
1797            let values = tensor::value_into_tensor_for("union", eval.values_value()).unwrap();
1798            assert_eq!(values.data, vec![4.0, 1.0, 2.0, 5.0]);
1799            let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
1800            assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
1801            let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
1802            assert_eq!(ib.data, vec![2.0]);
1803        });
1804    }
1805
1806    #[test]
1807    fn union_rejects_legacy_option() {
1808        let tensor =
1809            Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).expect("tensor construction failed");
1810        let err = evaluate(
1811            Value::Tensor(tensor.clone()),
1812            Value::Tensor(tensor),
1813            &[Value::from("legacy")],
1814        )
1815        .unwrap_err();
1816        assert!(err.contains("legacy"));
1817    }
1818
1819    #[test]
1820    fn union_rows_dimension_mismatch() {
1821        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1822        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1823        let err = evaluate(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")]).unwrap_err();
1824        assert!(err.contains("same number of columns"));
1825    }
1826
1827    #[test]
1828    fn union_requires_matching_types() {
1829        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1830        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
1831        let err = union_host(
1832            Value::Tensor(a),
1833            Value::CharArray(b),
1834            &UnionOptions {
1835                rows: false,
1836                order: UnionOrder::Sorted,
1837            },
1838        )
1839        .unwrap_err();
1840        assert!(err.contains("unsupported input type"));
1841    }
1842
1843    #[test]
1844    fn union_accepts_scalar_inputs() {
1845        let eval = evaluate(Value::Int(IntValue::I32(1)), Value::Num(3.0), &[]).expect("union");
1846        match eval.values_value() {
1847            Value::Tensor(t) => {
1848                assert_eq!(t.data, vec![1.0, 3.0]);
1849                assert_eq!(t.shape, vec![2, 1]);
1850            }
1851            other => panic!("expected numeric tensor, got {other:?}"),
1852        }
1853        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
1854        assert_eq!(ia.data, vec![1.0]);
1855        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
1856        assert_eq!(ib.data, vec![1.0]);
1857    }
1858
1859    #[test]
1860    #[cfg(feature = "wgpu")]
1861    fn union_wgpu_matches_cpu() {
1862        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1863            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1864        );
1865        let a = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
1866        let b = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();
1867
1868        let cpu_eval =
1869            evaluate(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[]).expect("union");
1870        let cpu_values = tensor::value_into_tensor_for("union", cpu_eval.values_value()).unwrap();
1871        let cpu_ia = tensor::value_into_tensor_for("union", cpu_eval.ia_value()).unwrap();
1872        let cpu_ib = tensor::value_into_tensor_for("union", cpu_eval.ib_value()).unwrap();
1873
1874        let provider = runmat_accelerate_api::provider().expect("provider");
1875        let view_a = HostTensorView {
1876            data: &a.data,
1877            shape: &a.shape,
1878        };
1879        let view_b = HostTensorView {
1880            data: &b.data,
1881            shape: &b.shape,
1882        };
1883        let handle_a = provider.upload(&view_a).expect("upload A");
1884        let handle_b = provider.upload(&view_b).expect("upload B");
1885        let gpu_eval =
1886            evaluate(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[]).expect("union");
1887        let gpu_values = tensor::value_into_tensor_for("union", gpu_eval.values_value()).unwrap();
1888        let gpu_ia = tensor::value_into_tensor_for("union", gpu_eval.ia_value()).unwrap();
1889        let gpu_ib = tensor::value_into_tensor_for("union", gpu_eval.ib_value()).unwrap();
1890
1891        assert_eq!(gpu_values.data, cpu_values.data);
1892        assert_eq!(gpu_values.shape, cpu_values.shape);
1893        assert_eq!(gpu_ia.data, cpu_ia.data);
1894        assert_eq!(gpu_ib.data, cpu_ib.data);
1895    }
1896
1897    #[test]
1898    #[cfg(feature = "doc_export")]
1899    fn doc_examples_present() {
1900        let blocks = test_support::doc_examples(DOC_MD);
1901        assert!(!blocks.is_empty());
1902    }
1903}