Skip to main content

runmat_runtime/builtins/array/sorting_sets/
union.rs

//! MATLAB-compatible `union` builtin with GPU-aware semantics for RunMat.
//!
//! Handles element-wise and row-wise unions with optional stable ordering and
//! index outputs that mirror MathWorks MATLAB semantics. GPU tensors are
//! gathered to host memory unless a provider supplies a dedicated `union`
//! kernel hook.
7
8use std::cmp::Ordering;
9use std::collections::{hash_map::Entry, HashMap};
10
11use runmat_accelerate_api::{
12    GpuTensorHandle, GpuTensorStorage, HostTensorOwned, UnionOptions, UnionOrder, UnionResult,
13};
14use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
15use runmat_macros::runtime_builtin;
16
17use super::type_resolvers::set_values_output_type;
18use crate::build_runtime_error;
19use crate::builtins::common::arg_tokens::tokens_from_values;
20use crate::builtins::common::gpu_helpers;
21use crate::builtins::common::random_args::complex_tensor_into_value;
22use crate::builtins::common::spec::{
23    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
24    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
25};
26use crate::builtins::common::tensor;
27
28#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::union")]
29pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
30    name: "union",
31    op_kind: GpuOpKind::Custom("union"),
32    supported_precisions: &[ScalarType::F32, ScalarType::F64],
33    broadcast: BroadcastSemantics::None,
34    provider_hooks: &[ProviderHook::Custom("union")],
35    constant_strategy: ConstantStrategy::InlineLiteral,
36    residency: ResidencyPolicy::GatherImmediately,
37    nan_mode: ReductionNaN::Include,
38    two_pass_threshold: None,
39    workgroup_size: None,
40    accepts_nan_mode: true,
41    notes: "Providers may expose a dedicated union hook; otherwise tensors are gathered and processed on the host.",
42};
43
44#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::sorting_sets::union")]
45pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
46    name: "union",
47    shape: ShapeRequirements::Any,
48    constant_strategy: ConstantStrategy::InlineLiteral,
49    elementwise: None,
50    reduction: None,
51    emits_nan: true,
52    notes: "`union` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
53};
54
55fn union_error(message: impl Into<String>) -> crate::RuntimeError {
56    build_runtime_error(message).with_builtin("union").build()
57}
58
59#[runtime_builtin(
60    name = "union",
61    category = "array/sorting_sets",
62    summary = "Combine two arrays, returning their union with MATLAB-compatible ordering and index outputs.",
63    keywords = "union,set,stable,rows,indices,gpu",
64    accel = "array_construct",
65    sink = true,
66    type_resolver(set_values_output_type),
67    builtin_path = "crate::builtins::array::sorting_sets::union"
68)]
69async fn union_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
70    let eval = evaluate(a, b, &rest).await?;
71    if let Some(out_count) = crate::output_count::current_output_count() {
72        if out_count == 0 {
73            return Ok(Value::OutputList(Vec::new()));
74        }
75        if out_count == 1 {
76            return Ok(Value::OutputList(vec![eval.into_values_value()]));
77        }
78        if out_count == 2 {
79            let (values, ia) = eval.into_pair();
80            return Ok(Value::OutputList(vec![values, ia]));
81        }
82        let (values, ia, ib) = eval.into_triple();
83        return Ok(crate::output_count::output_list_with_padding(
84            out_count,
85            vec![values, ia, ib],
86        ));
87    }
88    Ok(eval.into_values_value())
89}
90
91/// Evaluate the `union` builtin once and expose all outputs.
92pub async fn evaluate(a: Value, b: Value, rest: &[Value]) -> crate::BuiltinResult<UnionEvaluation> {
93    let opts = parse_options(rest)?;
94    match (a, b) {
95        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
96            union_gpu_pair(handle_a, handle_b, &opts).await
97        }
98        (Value::GpuTensor(handle_a), other) => union_gpu_mixed(handle_a, other, &opts, true).await,
99        (other, Value::GpuTensor(handle_b)) => union_gpu_mixed(handle_b, other, &opts, false).await,
100        (left, right) => union_host(left, right, &opts),
101    }
102}
103
104fn parse_options(rest: &[Value]) -> crate::BuiltinResult<UnionOptions> {
105    let mut opts = UnionOptions {
106        rows: false,
107        order: UnionOrder::Sorted,
108    };
109    let mut seen_order: Option<UnionOrder> = None;
110
111    let tokens = tokens_from_values(rest);
112    for (arg, token) in rest.iter().zip(tokens.iter()) {
113        let text = match token {
114            crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
115            _ => {
116                let text = tensor::value_to_string(arg)
117                    .ok_or_else(|| union_error("union: expected string option arguments"))?;
118                let lowered = text.trim().to_ascii_lowercase();
119                parse_union_option(&mut opts, &mut seen_order, &lowered)?;
120                continue;
121            }
122        };
123        parse_union_option(&mut opts, &mut seen_order, text)?;
124    }
125
126    Ok(opts)
127}
128
129fn parse_union_option(
130    opts: &mut UnionOptions,
131    seen_order: &mut Option<UnionOrder>,
132    lowered: &str,
133) -> crate::BuiltinResult<()> {
134    match lowered {
135        "rows" => opts.rows = true,
136        "sorted" => {
137            if let Some(prev) = seen_order {
138                if *prev != UnionOrder::Sorted {
139                    return Err(union_error("union: cannot combine 'sorted' with 'stable'"));
140                }
141            }
142            *seen_order = Some(UnionOrder::Sorted);
143            opts.order = UnionOrder::Sorted;
144        }
145        "stable" => {
146            if let Some(prev) = seen_order {
147                if *prev != UnionOrder::Stable {
148                    return Err(union_error("union: cannot combine 'sorted' with 'stable'"));
149                }
150            }
151            *seen_order = Some(UnionOrder::Stable);
152            opts.order = UnionOrder::Stable;
153        }
154        "legacy" | "r2012a" => {
155            return Err(union_error(
156                "union: the 'legacy' behaviour is not supported",
157            ));
158        }
159        other => return Err(union_error(format!("union: unrecognised option '{other}'"))),
160    }
161    Ok(())
162}
163
164async fn union_gpu_pair(
165    handle_a: GpuTensorHandle,
166    handle_b: GpuTensorHandle,
167    opts: &UnionOptions,
168) -> crate::BuiltinResult<UnionEvaluation> {
169    if let Some(provider) = runmat_accelerate_api::provider() {
170        match provider.union(&handle_a, &handle_b, opts).await {
171            Ok(result) => return UnionEvaluation::from_union_result(result),
172            Err(_) => {
173                // Fall back to host gather when provider union is unavailable.
174            }
175        }
176    }
177    let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
178    let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
179    union_numeric(tensor_a, tensor_b, opts)
180}
181
182async fn union_gpu_mixed(
183    handle_gpu: GpuTensorHandle,
184    other: Value,
185    opts: &UnionOptions,
186    gpu_is_a: bool,
187) -> crate::BuiltinResult<UnionEvaluation> {
188    let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
189    let tensor_other = tensor::value_into_tensor_for("union", other).map_err(|e| union_error(e))?;
190    if gpu_is_a {
191        union_numeric(tensor_gpu, tensor_other, opts)
192    } else {
193        union_numeric(tensor_other, tensor_gpu, opts)
194    }
195}
196
197fn union_host(a: Value, b: Value, opts: &UnionOptions) -> crate::BuiltinResult<UnionEvaluation> {
198    match (a, b) {
199        // Complex cases
200        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => union_complex(at, bt, opts),
201        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
202            let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
203                .map_err(|e| union_error(format!("union: {e}")))?;
204            union_complex(at, bt, opts)
205        }
206        (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
207            let at = ComplexTensor::new(vec![(re, im)], vec![1, 1])
208                .map_err(|e| union_error(format!("union: {e}")))?;
209            union_complex(at, bt, opts)
210        }
211        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
212            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
213                .map_err(|e| union_error(format!("union: {e}")))?;
214            let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
215                .map_err(|e| union_error(format!("union: {e}")))?;
216            union_complex(at, bt, opts)
217        }
218
219        // Character arrays
220        (Value::CharArray(ac), Value::CharArray(bc)) => union_char(ac, bc, opts),
221
222        // String arrays / scalars
223        (Value::StringArray(astring), Value::StringArray(bstring)) => {
224            union_string(astring, bstring, opts)
225        }
226        (Value::StringArray(astring), Value::String(b)) => {
227            let bstring = StringArray::new(vec![b], vec![1, 1])
228                .map_err(|e| union_error(format!("union: {e}")))?;
229            union_string(astring, bstring, opts)
230        }
231        (Value::String(a), Value::StringArray(bstring)) => {
232            let astring = StringArray::new(vec![a], vec![1, 1])
233                .map_err(|e| union_error(format!("union: {e}")))?;
234            union_string(astring, bstring, opts)
235        }
236        (Value::String(a), Value::String(b)) => {
237            let astring = StringArray::new(vec![a], vec![1, 1])
238                .map_err(|e| union_error(format!("union: {e}")))?;
239            let bstring = StringArray::new(vec![b], vec![1, 1])
240                .map_err(|e| union_error(format!("union: {e}")))?;
241            union_string(astring, bstring, opts)
242        }
243
244        // Fallback to numeric (includes tensors, logical arrays, ints, bools, doubles)
245        (left, right) => {
246            let tensor_a =
247                tensor::value_into_tensor_for("union", left).map_err(|e| union_error(e))?;
248            let tensor_b =
249                tensor::value_into_tensor_for("union", right).map_err(|e| union_error(e))?;
250            union_numeric(tensor_a, tensor_b, opts)
251        }
252    }
253}
254
255fn union_numeric(
256    a: Tensor,
257    b: Tensor,
258    opts: &UnionOptions,
259) -> crate::BuiltinResult<UnionEvaluation> {
260    if opts.rows {
261        union_numeric_rows(a, b, opts)
262    } else {
263        union_numeric_elements(a, b, opts)
264    }
265}
266
267/// Helper exposed for acceleration providers handling numeric tensors entirely on the host.
268pub fn union_numeric_from_tensors(
269    a: Tensor,
270    b: Tensor,
271    opts: &UnionOptions,
272) -> crate::BuiltinResult<UnionEvaluation> {
273    union_numeric(a, b, opts)
274}
275
276fn union_numeric_elements(
277    a: Tensor,
278    b: Tensor,
279    opts: &UnionOptions,
280) -> crate::BuiltinResult<UnionEvaluation> {
281    let mut entries = Vec::<NumericUnionEntry>::new();
282    let mut map: HashMap<u64, usize> = HashMap::new();
283    let mut order_counter = 0usize;
284
285    for (idx, &value) in a.data.iter().enumerate() {
286        let key = canonicalize_f64(value);
287        match map.entry(key) {
288            Entry::Occupied(_) => {
289                // Already recorded from A; keep first occurrence only.
290            }
291            Entry::Vacant(v) => {
292                let entry_idx = entries.len();
293                entries.push(NumericUnionEntry {
294                    value,
295                    a_index: Some(idx),
296                    b_index: None,
297                    order_rank: order_counter,
298                });
299                v.insert(entry_idx);
300                order_counter += 1;
301            }
302        }
303    }
304
305    for (idx, &value) in b.data.iter().enumerate() {
306        let key = canonicalize_f64(value);
307        match map.entry(key) {
308            Entry::Occupied(occ) => {
309                let entry = &mut entries[*occ.get()];
310                if entry.a_index.is_none() && entry.b_index.is_none() {
311                    entry.b_index = Some(idx);
312                }
313            }
314            Entry::Vacant(v) => {
315                let entry_idx = entries.len();
316                entries.push(NumericUnionEntry {
317                    value,
318                    a_index: None,
319                    b_index: Some(idx),
320                    order_rank: order_counter,
321                });
322                v.insert(entry_idx);
323                order_counter += 1;
324            }
325        }
326    }
327
328    assemble_numeric_union(entries, opts)
329}
330
331fn union_numeric_rows(
332    a: Tensor,
333    b: Tensor,
334    opts: &UnionOptions,
335) -> crate::BuiltinResult<UnionEvaluation> {
336    if a.shape.len() != 2 || b.shape.len() != 2 {
337        return Err(union_error(
338            "union: 'rows' option requires 2-D numeric matrices",
339        ));
340    }
341    if a.shape[1] != b.shape[1] {
342        return Err(union_error(
343            "union: inputs must have the same number of columns when using 'rows'",
344        ));
345    }
346    let rows_a = a.shape[0];
347    let cols = a.shape[1];
348    let rows_b = b.shape[0];
349
350    let mut entries = Vec::<NumericRowUnionEntry>::new();
351    let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
352    let mut order_counter = 0usize;
353
354    for r in 0..rows_a {
355        let mut row_values = Vec::with_capacity(cols);
356        for c in 0..cols {
357            let idx = r + c * rows_a;
358            row_values.push(a.data[idx]);
359        }
360        let key = NumericRowKey::from_slice(&row_values);
361        match map.entry(key) {
362            Entry::Occupied(_) => {}
363            Entry::Vacant(v) => {
364                let entry_idx = entries.len();
365                entries.push(NumericRowUnionEntry {
366                    row_data: row_values,
367                    a_row: Some(r),
368                    b_row: None,
369                    order_rank: order_counter,
370                });
371                v.insert(entry_idx);
372                order_counter += 1;
373            }
374        }
375    }
376
377    for r in 0..rows_b {
378        let mut row_values = Vec::with_capacity(cols);
379        for c in 0..cols {
380            let idx = r + c * rows_b;
381            row_values.push(b.data[idx]);
382        }
383        let key = NumericRowKey::from_slice(&row_values);
384        match map.entry(key) {
385            Entry::Occupied(occ) => {
386                let entry = &mut entries[*occ.get()];
387                if entry.a_row.is_none() && entry.b_row.is_none() {
388                    entry.b_row = Some(r);
389                }
390            }
391            Entry::Vacant(v) => {
392                let entry_idx = entries.len();
393                entries.push(NumericRowUnionEntry {
394                    row_data: row_values,
395                    a_row: None,
396                    b_row: Some(r),
397                    order_rank: order_counter,
398                });
399                v.insert(entry_idx);
400                order_counter += 1;
401            }
402        }
403    }
404
405    assemble_numeric_row_union(entries, opts, cols)
406}
407
408fn union_complex(
409    a: ComplexTensor,
410    b: ComplexTensor,
411    opts: &UnionOptions,
412) -> crate::BuiltinResult<UnionEvaluation> {
413    if opts.rows {
414        union_complex_rows(a, b, opts)
415    } else {
416        union_complex_elements(a, b, opts)
417    }
418}
419
420fn union_complex_elements(
421    a: ComplexTensor,
422    b: ComplexTensor,
423    opts: &UnionOptions,
424) -> crate::BuiltinResult<UnionEvaluation> {
425    let mut entries = Vec::<ComplexUnionEntry>::new();
426    let mut map: HashMap<ComplexKey, usize> = HashMap::new();
427    let mut order_counter = 0usize;
428
429    for (idx, &value) in a.data.iter().enumerate() {
430        let key = ComplexKey::new(value);
431        match map.entry(key) {
432            Entry::Occupied(_) => {}
433            Entry::Vacant(v) => {
434                let entry_idx = entries.len();
435                entries.push(ComplexUnionEntry {
436                    value,
437                    a_index: Some(idx),
438                    b_index: None,
439                    order_rank: order_counter,
440                });
441                v.insert(entry_idx);
442                order_counter += 1;
443            }
444        }
445    }
446
447    for (idx, &value) in b.data.iter().enumerate() {
448        let key = ComplexKey::new(value);
449        match map.entry(key) {
450            Entry::Occupied(occ) => {
451                let entry = &mut entries[*occ.get()];
452                if entry.a_index.is_none() && entry.b_index.is_none() {
453                    entry.b_index = Some(idx);
454                }
455            }
456            Entry::Vacant(v) => {
457                let entry_idx = entries.len();
458                entries.push(ComplexUnionEntry {
459                    value,
460                    a_index: None,
461                    b_index: Some(idx),
462                    order_rank: order_counter,
463                });
464                v.insert(entry_idx);
465                order_counter += 1;
466            }
467        }
468    }
469
470    assemble_complex_union(entries, opts)
471}
472
473fn union_complex_rows(
474    a: ComplexTensor,
475    b: ComplexTensor,
476    opts: &UnionOptions,
477) -> crate::BuiltinResult<UnionEvaluation> {
478    if a.shape.len() != 2 || b.shape.len() != 2 {
479        return Err(union_error(
480            "union: 'rows' option requires 2-D complex matrices",
481        ));
482    }
483    if a.shape[1] != b.shape[1] {
484        return Err(union_error(
485            "union: inputs must have the same number of columns when using 'rows'",
486        ));
487    }
488    let rows_a = a.shape[0];
489    let cols = a.shape[1];
490    let rows_b = b.shape[0];
491
492    let mut entries = Vec::<ComplexRowUnionEntry>::new();
493    let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
494    let mut order_counter = 0usize;
495
496    for r in 0..rows_a {
497        let mut row_values = Vec::with_capacity(cols);
498        let mut key_row = Vec::with_capacity(cols);
499        for c in 0..cols {
500            let idx = r + c * rows_a;
501            let value = a.data[idx];
502            row_values.push(value);
503            key_row.push(ComplexKey::new(value));
504        }
505        match map.entry(key_row) {
506            Entry::Occupied(_) => {}
507            Entry::Vacant(v) => {
508                let entry_idx = entries.len();
509                entries.push(ComplexRowUnionEntry {
510                    row_data: row_values,
511                    a_row: Some(r),
512                    b_row: None,
513                    order_rank: order_counter,
514                });
515                v.insert(entry_idx);
516                order_counter += 1;
517            }
518        }
519    }
520
521    for r in 0..rows_b {
522        let mut row_values = Vec::with_capacity(cols);
523        let mut key_row = Vec::with_capacity(cols);
524        for c in 0..cols {
525            let idx = r + c * rows_b;
526            let value = b.data[idx];
527            row_values.push(value);
528            key_row.push(ComplexKey::new(value));
529        }
530        match map.entry(key_row) {
531            Entry::Occupied(occ) => {
532                let entry = &mut entries[*occ.get()];
533                if entry.a_row.is_none() && entry.b_row.is_none() {
534                    entry.b_row = Some(r);
535                }
536            }
537            Entry::Vacant(v) => {
538                let entry_idx = entries.len();
539                entries.push(ComplexRowUnionEntry {
540                    row_data: row_values,
541                    a_row: None,
542                    b_row: Some(r),
543                    order_rank: order_counter,
544                });
545                v.insert(entry_idx);
546                order_counter += 1;
547            }
548        }
549    }
550
551    assemble_complex_row_union(entries, opts, cols)
552}
553
554fn union_char(
555    a: CharArray,
556    b: CharArray,
557    opts: &UnionOptions,
558) -> crate::BuiltinResult<UnionEvaluation> {
559    if opts.rows {
560        union_char_rows(a, b, opts)
561    } else {
562        union_char_elements(a, b, opts)
563    }
564}
565
566fn union_char_elements(
567    a: CharArray,
568    b: CharArray,
569    opts: &UnionOptions,
570) -> crate::BuiltinResult<UnionEvaluation> {
571    let mut entries = Vec::<CharUnionEntry>::new();
572    let mut map: HashMap<u32, usize> = HashMap::new();
573    let mut order_counter = 0usize;
574
575    for col in 0..a.cols {
576        for row in 0..a.rows {
577            let linear_idx = row + col * a.rows;
578            let data_idx = row * a.cols + col;
579            let ch = a.data[data_idx];
580            let key = ch as u32;
581            match map.entry(key) {
582                Entry::Occupied(_) => {}
583                Entry::Vacant(v) => {
584                    let entry_idx = entries.len();
585                    entries.push(CharUnionEntry {
586                        ch,
587                        a_index: Some(linear_idx),
588                        b_index: None,
589                        order_rank: order_counter,
590                    });
591                    v.insert(entry_idx);
592                    order_counter += 1;
593                }
594            }
595        }
596    }
597
598    for col in 0..b.cols {
599        for row in 0..b.rows {
600            let linear_idx = row + col * b.rows;
601            let data_idx = row * b.cols + col;
602            let ch = b.data[data_idx];
603            let key = ch as u32;
604            match map.entry(key) {
605                Entry::Occupied(occ) => {
606                    let entry = &mut entries[*occ.get()];
607                    if entry.a_index.is_none() && entry.b_index.is_none() {
608                        entry.b_index = Some(linear_idx);
609                    }
610                }
611                Entry::Vacant(v) => {
612                    let entry_idx = entries.len();
613                    entries.push(CharUnionEntry {
614                        ch,
615                        a_index: None,
616                        b_index: Some(linear_idx),
617                        order_rank: order_counter,
618                    });
619                    v.insert(entry_idx);
620                    order_counter += 1;
621                }
622            }
623        }
624    }
625
626    assemble_char_union(entries, opts)
627}
628
629fn union_char_rows(
630    a: CharArray,
631    b: CharArray,
632    opts: &UnionOptions,
633) -> crate::BuiltinResult<UnionEvaluation> {
634    if a.cols != b.cols {
635        return Err(union_error(
636            "union: inputs must have the same number of columns when using 'rows'",
637        ));
638    }
639    let rows_a = a.rows;
640    let rows_b = b.rows;
641    let cols = a.cols;
642
643    let mut entries = Vec::<CharRowUnionEntry>::new();
644    let mut map: HashMap<RowCharKey, usize> = HashMap::new();
645    let mut order_counter = 0usize;
646
647    for r in 0..rows_a {
648        let mut row_values = Vec::with_capacity(cols);
649        for c in 0..cols {
650            let idx = r * cols + c;
651            row_values.push(a.data[idx]);
652        }
653        let key = RowCharKey::from_slice(&row_values);
654        match map.entry(key) {
655            Entry::Occupied(_) => {}
656            Entry::Vacant(v) => {
657                let entry_idx = entries.len();
658                entries.push(CharRowUnionEntry {
659                    row_data: row_values,
660                    a_row: Some(r),
661                    b_row: None,
662                    order_rank: order_counter,
663                });
664                v.insert(entry_idx);
665                order_counter += 1;
666            }
667        }
668    }
669
670    for r in 0..rows_b {
671        let mut row_values = Vec::with_capacity(cols);
672        for c in 0..cols {
673            let idx = r * cols + c;
674            row_values.push(b.data[idx]);
675        }
676        let key = RowCharKey::from_slice(&row_values);
677        match map.entry(key) {
678            Entry::Occupied(occ) => {
679                let entry = &mut entries[*occ.get()];
680                if entry.a_row.is_none() && entry.b_row.is_none() {
681                    entry.b_row = Some(r);
682                }
683            }
684            Entry::Vacant(v) => {
685                let entry_idx = entries.len();
686                entries.push(CharRowUnionEntry {
687                    row_data: row_values,
688                    a_row: None,
689                    b_row: Some(r),
690                    order_rank: order_counter,
691                });
692                v.insert(entry_idx);
693                order_counter += 1;
694            }
695        }
696    }
697
698    assemble_char_row_union(entries, opts, cols)
699}
700
701fn union_string(
702    a: StringArray,
703    b: StringArray,
704    opts: &UnionOptions,
705) -> crate::BuiltinResult<UnionEvaluation> {
706    if opts.rows {
707        union_string_rows(a, b, opts)
708    } else {
709        union_string_elements(a, b, opts)
710    }
711}
712
713fn union_string_elements(
714    a: StringArray,
715    b: StringArray,
716    opts: &UnionOptions,
717) -> crate::BuiltinResult<UnionEvaluation> {
718    let mut entries = Vec::<StringUnionEntry>::new();
719    let mut map: HashMap<String, usize> = HashMap::new();
720    let mut order_counter = 0usize;
721
722    for (idx, value) in a.data.iter().enumerate() {
723        match map.entry(value.clone()) {
724            Entry::Occupied(_) => {}
725            Entry::Vacant(v) => {
726                let entry_idx = entries.len();
727                entries.push(StringUnionEntry {
728                    value: value.clone(),
729                    a_index: Some(idx),
730                    b_index: None,
731                    order_rank: order_counter,
732                });
733                v.insert(entry_idx);
734                order_counter += 1;
735            }
736        }
737    }
738
739    for (idx, value) in b.data.iter().enumerate() {
740        match map.entry(value.clone()) {
741            Entry::Occupied(occ) => {
742                let entry = &mut entries[*occ.get()];
743                if entry.a_index.is_none() && entry.b_index.is_none() {
744                    entry.b_index = Some(idx);
745                }
746            }
747            Entry::Vacant(v) => {
748                let entry_idx = entries.len();
749                entries.push(StringUnionEntry {
750                    value: value.clone(),
751                    a_index: None,
752                    b_index: Some(idx),
753                    order_rank: order_counter,
754                });
755                v.insert(entry_idx);
756                order_counter += 1;
757            }
758        }
759    }
760
761    assemble_string_union(entries, opts)
762}
763
764fn union_string_rows(
765    a: StringArray,
766    b: StringArray,
767    opts: &UnionOptions,
768) -> crate::BuiltinResult<UnionEvaluation> {
769    if a.shape.len() != 2 || b.shape.len() != 2 {
770        return Err(union_error(
771            "union: 'rows' option requires 2-D string arrays",
772        ));
773    }
774    if a.shape[1] != b.shape[1] {
775        return Err(union_error(
776            "union: inputs must have the same number of columns when using 'rows'",
777        ));
778    }
779    let rows_a = a.shape[0];
780    let cols = a.shape[1];
781    let rows_b = b.shape[0];
782
783    let mut entries = Vec::<StringRowUnionEntry>::new();
784    let mut map: HashMap<RowStringKey, usize> = HashMap::new();
785    let mut order_counter = 0usize;
786
787    for r in 0..rows_a {
788        let mut row_values = Vec::with_capacity(cols);
789        for c in 0..cols {
790            let idx = r + c * rows_a;
791            row_values.push(a.data[idx].clone());
792        }
793        let key = RowStringKey(row_values.clone());
794        match map.entry(key) {
795            Entry::Occupied(_) => {}
796            Entry::Vacant(v) => {
797                let entry_idx = entries.len();
798                entries.push(StringRowUnionEntry {
799                    row_data: row_values,
800                    a_row: Some(r),
801                    b_row: None,
802                    order_rank: order_counter,
803                });
804                v.insert(entry_idx);
805                order_counter += 1;
806            }
807        }
808    }
809
810    for r in 0..rows_b {
811        let mut row_values = Vec::with_capacity(cols);
812        for c in 0..cols {
813            let idx = r + c * rows_b;
814            row_values.push(b.data[idx].clone());
815        }
816        let key = RowStringKey(row_values.clone());
817        match map.entry(key) {
818            Entry::Occupied(occ) => {
819                let entry = &mut entries[*occ.get()];
820                if entry.a_row.is_none() && entry.b_row.is_none() {
821                    entry.b_row = Some(r);
822                }
823            }
824            Entry::Vacant(v) => {
825                let entry_idx = entries.len();
826                entries.push(StringRowUnionEntry {
827                    row_data: row_values,
828                    a_row: None,
829                    b_row: Some(r),
830                    order_rank: order_counter,
831                });
832                v.insert(entry_idx);
833                order_counter += 1;
834            }
835        }
836    }
837
838    assemble_string_row_union(entries, opts, cols)
839}
840
841#[derive(Debug, Clone)]
842pub struct UnionEvaluation {
843    values: Value,
844    ia: Tensor,
845    ib: Tensor,
846}
847
848impl UnionEvaluation {
849    fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
850        Self { values, ia, ib }
851    }
852
853    pub fn from_union_result(result: UnionResult) -> crate::BuiltinResult<Self> {
854        let UnionResult { values, ia, ib } = result;
855        let values_tensor = Tensor::new(values.data, values.shape)
856            .map_err(|e| union_error(format!("union: {e}")))?;
857        let ia_tensor =
858            Tensor::new(ia.data, ia.shape).map_err(|e| union_error(format!("union: {e}")))?;
859        let ib_tensor =
860            Tensor::new(ib.data, ib.shape).map_err(|e| union_error(format!("union: {e}")))?;
861        Ok(UnionEvaluation::new(
862            tensor::tensor_into_value(values_tensor),
863            ia_tensor,
864            ib_tensor,
865        ))
866    }
867
868    pub fn into_numeric_union_result(self) -> crate::BuiltinResult<UnionResult> {
869        let UnionEvaluation { values, ia, ib } = self;
870        let values_tensor =
871            tensor::value_into_tensor_for("union", values).map_err(|e| union_error(e))?;
872        Ok(UnionResult {
873            values: HostTensorOwned {
874                data: values_tensor.data,
875                shape: values_tensor.shape,
876                storage: GpuTensorStorage::Real,
877            },
878            ia: HostTensorOwned {
879                data: ia.data,
880                shape: ia.shape,
881                storage: GpuTensorStorage::Real,
882            },
883            ib: HostTensorOwned {
884                data: ib.data,
885                shape: ib.shape,
886                storage: GpuTensorStorage::Real,
887            },
888        })
889    }
890
891    pub fn into_values_value(self) -> Value {
892        self.values
893    }
894
895    pub fn into_pair(self) -> (Value, Value) {
896        let ia = tensor::tensor_into_value(self.ia);
897        (self.values, ia)
898    }
899
900    pub fn into_triple(self) -> (Value, Value, Value) {
901        let ia = tensor::tensor_into_value(self.ia);
902        let ib = tensor::tensor_into_value(self.ib);
903        (self.values, ia, ib)
904    }
905
906    pub fn values_value(&self) -> Value {
907        self.values.clone()
908    }
909
910    pub fn ia_value(&self) -> Value {
911        tensor::tensor_into_value(self.ia.clone())
912    }
913
914    pub fn ib_value(&self) -> Value {
915        tensor::tensor_into_value(self.ib.clone())
916    }
917}
918
/// One distinct real value collected while building an element-wise union.
#[derive(Debug)]
struct NumericUnionEntry {
    /// The value emitted into the union output.
    value: f64,
    /// Zero-based index reported (as one-based) in `ia` when present.
    a_index: Option<usize>,
    /// Zero-based index reported (as one-based) in `ib` when `a_index` is absent.
    b_index: Option<usize>,
    /// Sort key used when the stable ordering is requested.
    order_rank: usize,
}
926
/// One distinct real-valued row collected while building a row-wise union.
#[derive(Debug)]
struct NumericRowUnionEntry {
    /// The row's elements, one per column.
    row_data: Vec<f64>,
    /// Zero-based row index reported (as one-based) in `ia` when present.
    a_row: Option<usize>,
    /// Zero-based row index reported (as one-based) in `ib` when `a_row` is absent.
    b_row: Option<usize>,
    /// Sort key used when the stable ordering is requested.
    order_rank: usize,
}
934
/// One distinct complex value collected while building an element-wise union.
#[derive(Debug)]
struct ComplexUnionEntry {
    /// The complex value as a `(re, im)` pair.
    value: (f64, f64),
    /// Zero-based index reported (as one-based) in `ia` when present.
    a_index: Option<usize>,
    /// Zero-based index reported (as one-based) in `ib` when `a_index` is absent.
    b_index: Option<usize>,
    /// Sort key used when the stable ordering is requested.
    order_rank: usize,
}
942
/// One distinct complex-valued row collected while building a row-wise union.
#[derive(Debug)]
struct ComplexRowUnionEntry {
    /// The row's elements as `(re, im)` pairs, one per column.
    row_data: Vec<(f64, f64)>,
    /// Zero-based row index reported (as one-based) in `ia` when present.
    a_row: Option<usize>,
    /// Zero-based row index reported (as one-based) in `ib` when `a_row` is absent.
    b_row: Option<usize>,
    /// Sort key used when the stable ordering is requested.
    order_rank: usize,
}
950
/// One distinct character collected while building an element-wise union.
#[derive(Debug)]
struct CharUnionEntry {
    /// The character emitted into the union output.
    ch: char,
    /// Zero-based index reported (as one-based) in `ia` when present.
    a_index: Option<usize>,
    /// Zero-based index reported (as one-based) in `ib` when `a_index` is absent.
    b_index: Option<usize>,
    /// Sort key used when the stable ordering is requested.
    order_rank: usize,
}
958
/// One distinct character row collected while building a row-wise union.
#[derive(Debug)]
struct CharRowUnionEntry {
    /// The row's characters, one per column.
    row_data: Vec<char>,
    /// Zero-based row index reported (as one-based) in `ia` when present.
    a_row: Option<usize>,
    /// Zero-based row index reported (as one-based) in `ib` when `a_row` is absent.
    b_row: Option<usize>,
    /// Sort key used when the stable ordering is requested.
    order_rank: usize,
}
966
/// One distinct string collected while building an element-wise union.
#[derive(Debug)]
struct StringUnionEntry {
    /// The string emitted into the union output.
    value: String,
    /// Zero-based index reported (as one-based) in `ia` when present.
    a_index: Option<usize>,
    /// Zero-based index reported (as one-based) in `ib` when `a_index` is absent.
    b_index: Option<usize>,
    /// Sort key used when the stable ordering is requested.
    order_rank: usize,
}
974
/// One distinct row of strings collected while building a row-wise union.
#[derive(Debug)]
struct StringRowUnionEntry {
    /// The row's strings, one per column.
    row_data: Vec<String>,
    /// Zero-based row index reported (as one-based) in `ia` when present.
    a_row: Option<usize>,
    /// Zero-based row index reported (as one-based) in `ib` when `a_row` is absent.
    b_row: Option<usize>,
    /// Sort key used when the stable ordering is requested.
    order_rank: usize,
}
982
983#[derive(Debug, Clone, PartialEq, Eq, Hash)]
984struct NumericRowKey(Vec<u64>);
985
986impl NumericRowKey {
987    fn from_slice(values: &[f64]) -> Self {
988        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
989    }
990}
991
992#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
993struct ComplexKey {
994    re: u64,
995    im: u64,
996}
997
998impl ComplexKey {
999    fn new(value: (f64, f64)) -> Self {
1000        Self {
1001            re: canonicalize_f64(value.0),
1002            im: canonicalize_f64(value.1),
1003        }
1004    }
1005}
1006
/// Hashable identity of a character row, stored as Unicode scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds a key holding the `u32` code point of every character in the row.
    fn from_slice(values: &[char]) -> Self {
        let code_points = values.iter().copied().map(u32::from).collect();
        Self(code_points)
    }
}
1015
/// Hashable, comparable wrapper around the strings of one row.
// NOTE(review): presumably the map key used to de-duplicate string rows;
// confirm at the call site.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1018
1019fn assemble_numeric_union(
1020    entries: Vec<NumericUnionEntry>,
1021    opts: &UnionOptions,
1022) -> crate::BuiltinResult<UnionEvaluation> {
1023    let mut order: Vec<usize> = (0..entries.len()).collect();
1024    match opts.order {
1025        UnionOrder::Sorted => {
1026            order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
1027        }
1028        UnionOrder::Stable => {
1029            order.sort_by_key(|&idx| entries[idx].order_rank);
1030        }
1031    }
1032
1033    let mut values = Vec::with_capacity(order.len());
1034    let mut ia = Vec::new();
1035    let mut ib = Vec::new();
1036    for &idx in &order {
1037        let entry = &entries[idx];
1038        values.push(entry.value);
1039        if let Some(a_idx) = entry.a_index {
1040            ia.push((a_idx + 1) as f64);
1041        } else if let Some(b_idx) = entry.b_index {
1042            ib.push((b_idx + 1) as f64);
1043        }
1044    }
1045
1046    let value_tensor = Tensor::new(values, vec![order.len(), 1])
1047        .map_err(|e| union_error(format!("union: {e}")))?;
1048    let ia_len = ia.len();
1049    let ib_len = ib.len();
1050    let ia_tensor =
1051        Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1052    let ib_tensor =
1053        Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1054
1055    Ok(UnionEvaluation::new(
1056        tensor::tensor_into_value(value_tensor),
1057        ia_tensor,
1058        ib_tensor,
1059    ))
1060}
1061
1062fn assemble_numeric_row_union(
1063    entries: Vec<NumericRowUnionEntry>,
1064    opts: &UnionOptions,
1065    cols: usize,
1066) -> crate::BuiltinResult<UnionEvaluation> {
1067    let mut order: Vec<usize> = (0..entries.len()).collect();
1068    match opts.order {
1069        UnionOrder::Sorted => {
1070            order.sort_by(|&lhs, &rhs| {
1071                compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1072            });
1073        }
1074        UnionOrder::Stable => {
1075            order.sort_by_key(|&idx| entries[idx].order_rank);
1076        }
1077    }
1078
1079    let unique_rows = order.len();
1080    let mut values = vec![0.0f64; unique_rows * cols];
1081    let mut ia = Vec::new();
1082    let mut ib = Vec::new();
1083
1084    for (row_pos, &entry_idx) in order.iter().enumerate() {
1085        let entry = &entries[entry_idx];
1086        for col in 0..cols {
1087            let dest = row_pos + col * unique_rows;
1088            values[dest] = entry.row_data[col];
1089        }
1090        if let Some(a_row) = entry.a_row {
1091            ia.push((a_row + 1) as f64);
1092        } else if let Some(b_row) = entry.b_row {
1093            ib.push((b_row + 1) as f64);
1094        }
1095    }
1096
1097    let value_tensor = Tensor::new(values, vec![unique_rows, cols])
1098        .map_err(|e| union_error(format!("union: {e}")))?;
1099    let ia_len = ia.len();
1100    let ib_len = ib.len();
1101    let ia_tensor =
1102        Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1103    let ib_tensor =
1104        Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1105
1106    Ok(UnionEvaluation::new(
1107        tensor::tensor_into_value(value_tensor),
1108        ia_tensor,
1109        ib_tensor,
1110    ))
1111}
1112
1113fn assemble_complex_union(
1114    entries: Vec<ComplexUnionEntry>,
1115    opts: &UnionOptions,
1116) -> crate::BuiltinResult<UnionEvaluation> {
1117    let mut order: Vec<usize> = (0..entries.len()).collect();
1118    match opts.order {
1119        UnionOrder::Sorted => {
1120            order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
1121        }
1122        UnionOrder::Stable => {
1123            order.sort_by_key(|&idx| entries[idx].order_rank);
1124        }
1125    }
1126
1127    let mut values = Vec::with_capacity(order.len());
1128    let mut ia = Vec::new();
1129    let mut ib = Vec::new();
1130    for &idx in &order {
1131        let entry = &entries[idx];
1132        values.push(entry.value);
1133        if let Some(a_idx) = entry.a_index {
1134            ia.push((a_idx + 1) as f64);
1135        } else if let Some(b_idx) = entry.b_index {
1136            ib.push((b_idx + 1) as f64);
1137        }
1138    }
1139
1140    let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
1141        .map_err(|e| union_error(format!("union: {e}")))?;
1142    let ia_len = ia.len();
1143    let ib_len = ib.len();
1144    let ia_tensor =
1145        Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1146    let ib_tensor =
1147        Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1148
1149    Ok(UnionEvaluation::new(
1150        complex_tensor_into_value(value_tensor),
1151        ia_tensor,
1152        ib_tensor,
1153    ))
1154}
1155
1156fn assemble_complex_row_union(
1157    entries: Vec<ComplexRowUnionEntry>,
1158    opts: &UnionOptions,
1159    cols: usize,
1160) -> crate::BuiltinResult<UnionEvaluation> {
1161    let mut order: Vec<usize> = (0..entries.len()).collect();
1162    match opts.order {
1163        UnionOrder::Sorted => {
1164            order.sort_by(|&lhs, &rhs| {
1165                compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1166            });
1167        }
1168        UnionOrder::Stable => {
1169            order.sort_by_key(|&idx| entries[idx].order_rank);
1170        }
1171    }
1172
1173    let unique_rows = order.len();
1174    let mut values = vec![(0.0, 0.0); unique_rows * cols];
1175    let mut ia = Vec::new();
1176    let mut ib = Vec::new();
1177
1178    for (row_pos, &entry_idx) in order.iter().enumerate() {
1179        let entry = &entries[entry_idx];
1180        for col in 0..cols {
1181            let dest = row_pos + col * unique_rows;
1182            values[dest] = entry.row_data[col];
1183        }
1184        if let Some(a_row) = entry.a_row {
1185            ia.push((a_row + 1) as f64);
1186        } else if let Some(b_row) = entry.b_row {
1187            ib.push((b_row + 1) as f64);
1188        }
1189    }
1190
1191    let value_tensor = ComplexTensor::new(values, vec![unique_rows, cols])
1192        .map_err(|e| union_error(format!("union: {e}")))?;
1193    let ia_len = ia.len();
1194    let ib_len = ib.len();
1195    let ia_tensor =
1196        Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1197    let ib_tensor =
1198        Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1199
1200    Ok(UnionEvaluation::new(
1201        complex_tensor_into_value(value_tensor),
1202        ia_tensor,
1203        ib_tensor,
1204    ))
1205}
1206
1207fn assemble_char_union(
1208    entries: Vec<CharUnionEntry>,
1209    opts: &UnionOptions,
1210) -> crate::BuiltinResult<UnionEvaluation> {
1211    let mut order: Vec<usize> = (0..entries.len()).collect();
1212    match opts.order {
1213        UnionOrder::Sorted => {
1214            order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1215        }
1216        UnionOrder::Stable => {
1217            order.sort_by_key(|&idx| entries[idx].order_rank);
1218        }
1219    }
1220
1221    let mut values = Vec::with_capacity(order.len());
1222    let mut ia = Vec::new();
1223    let mut ib = Vec::new();
1224    for &idx in &order {
1225        let entry = &entries[idx];
1226        values.push(entry.ch);
1227        if let Some(a_idx) = entry.a_index {
1228            ia.push((a_idx + 1) as f64);
1229        } else if let Some(b_idx) = entry.b_index {
1230            ib.push((b_idx + 1) as f64);
1231        }
1232    }
1233
1234    let value_array =
1235        CharArray::new(values, order.len(), 1).map_err(|e| union_error(format!("union: {e}")))?;
1236    let ia_len = ia.len();
1237    let ib_len = ib.len();
1238    let ia_tensor =
1239        Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1240    let ib_tensor =
1241        Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1242
1243    Ok(UnionEvaluation::new(
1244        Value::CharArray(value_array),
1245        ia_tensor,
1246        ib_tensor,
1247    ))
1248}
1249
1250fn assemble_char_row_union(
1251    entries: Vec<CharRowUnionEntry>,
1252    opts: &UnionOptions,
1253    cols: usize,
1254) -> crate::BuiltinResult<UnionEvaluation> {
1255    let mut order: Vec<usize> = (0..entries.len()).collect();
1256    match opts.order {
1257        UnionOrder::Sorted => {
1258            order.sort_by(|&lhs, &rhs| {
1259                compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1260            });
1261        }
1262        UnionOrder::Stable => {
1263            order.sort_by_key(|&idx| entries[idx].order_rank);
1264        }
1265    }
1266
1267    let unique_rows = order.len();
1268    let mut values = vec!['\0'; unique_rows * cols];
1269    let mut ia = Vec::new();
1270    let mut ib = Vec::new();
1271
1272    for (row_pos, &entry_idx) in order.iter().enumerate() {
1273        let entry = &entries[entry_idx];
1274        for col in 0..cols {
1275            let dest = row_pos * cols + col;
1276            values[dest] = entry.row_data[col];
1277        }
1278        if let Some(a_row) = entry.a_row {
1279            ia.push((a_row + 1) as f64);
1280        } else if let Some(b_row) = entry.b_row {
1281            ib.push((b_row + 1) as f64);
1282        }
1283    }
1284
1285    let value_array = CharArray::new(values, unique_rows, cols)
1286        .map_err(|e| union_error(format!("union: {e}")))?;
1287    let ia_len = ia.len();
1288    let ib_len = ib.len();
1289    let ia_tensor =
1290        Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1291    let ib_tensor =
1292        Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1293
1294    Ok(UnionEvaluation::new(
1295        Value::CharArray(value_array),
1296        ia_tensor,
1297        ib_tensor,
1298    ))
1299}
1300
1301fn assemble_string_union(
1302    entries: Vec<StringUnionEntry>,
1303    opts: &UnionOptions,
1304) -> crate::BuiltinResult<UnionEvaluation> {
1305    let mut order: Vec<usize> = (0..entries.len()).collect();
1306    match opts.order {
1307        UnionOrder::Sorted => {
1308            order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1309        }
1310        UnionOrder::Stable => {
1311            order.sort_by_key(|&idx| entries[idx].order_rank);
1312        }
1313    }
1314
1315    let mut values = Vec::with_capacity(order.len());
1316    let mut ia = Vec::new();
1317    let mut ib = Vec::new();
1318    for &idx in &order {
1319        let entry = &entries[idx];
1320        values.push(entry.value.clone());
1321        if let Some(a_idx) = entry.a_index {
1322            ia.push((a_idx + 1) as f64);
1323        } else if let Some(b_idx) = entry.b_index {
1324            ib.push((b_idx + 1) as f64);
1325        }
1326    }
1327
1328    let value_array = StringArray::new(values, vec![order.len(), 1])
1329        .map_err(|e| union_error(format!("union: {e}")))?;
1330    let ia_len = ia.len();
1331    let ib_len = ib.len();
1332    let ia_tensor =
1333        Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1334    let ib_tensor =
1335        Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1336
1337    Ok(UnionEvaluation::new(
1338        Value::StringArray(value_array),
1339        ia_tensor,
1340        ib_tensor,
1341    ))
1342}
1343
1344fn assemble_string_row_union(
1345    entries: Vec<StringRowUnionEntry>,
1346    opts: &UnionOptions,
1347    cols: usize,
1348) -> crate::BuiltinResult<UnionEvaluation> {
1349    let mut order: Vec<usize> = (0..entries.len()).collect();
1350    match opts.order {
1351        UnionOrder::Sorted => {
1352            order.sort_by(|&lhs, &rhs| {
1353                compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1354            });
1355        }
1356        UnionOrder::Stable => {
1357            order.sort_by_key(|&idx| entries[idx].order_rank);
1358        }
1359    }
1360
1361    let unique_rows = order.len();
1362    let mut values = vec![String::new(); unique_rows * cols];
1363    let mut ia = Vec::new();
1364    let mut ib = Vec::new();
1365
1366    for (row_pos, &entry_idx) in order.iter().enumerate() {
1367        let entry = &entries[entry_idx];
1368        for col in 0..cols {
1369            let dest = row_pos + col * unique_rows;
1370            values[dest] = entry.row_data[col].clone();
1371        }
1372        if let Some(a_row) = entry.a_row {
1373            ia.push((a_row + 1) as f64);
1374        } else if let Some(b_row) = entry.b_row {
1375            ib.push((b_row + 1) as f64);
1376        }
1377    }
1378
1379    let value_array = StringArray::new(values, vec![unique_rows, cols])
1380        .map_err(|e| union_error(format!("union: {e}")))?;
1381    let ia_len = ia.len();
1382    let ib_len = ib.len();
1383    let ia_tensor =
1384        Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1385    let ib_tensor =
1386        Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1387
1388    Ok(UnionEvaluation::new(
1389        Value::StringArray(value_array),
1390        ia_tensor,
1391        ib_tensor,
1392    ))
1393}
1394
/// Maps an `f64` to a bit pattern suitable for hashing and equality.
///
/// Every NaN collapses to the single pattern `0x7ff8_0000_0000_0000`, both
/// signed zeros collapse to `0`, and any other value keeps its raw bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0`, so this arm covers both zero signs.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1404
/// Total ordering over `f64` that places NaN after every other value.
///
/// Two NaNs compare equal; non-NaN values use their ordinary numeric order.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1418
1419fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1420    for (lhs, rhs) in a.iter().zip(b.iter()) {
1421        let ord = compare_f64(*lhs, *rhs);
1422        if ord != Ordering::Equal {
1423            return ord;
1424        }
1425    }
1426    Ordering::Equal
1427}
1428
/// Returns `true` when either component of the `(re, im)` pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
1432
1433fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1434    match (complex_is_nan(a), complex_is_nan(b)) {
1435        (true, true) => Ordering::Equal,
1436        (true, false) => Ordering::Greater,
1437        (false, true) => Ordering::Less,
1438        (false, false) => {
1439            let mag_a = a.0.hypot(a.1);
1440            let mag_b = b.0.hypot(b.1);
1441            let mag_cmp = compare_f64(mag_a, mag_b);
1442            if mag_cmp != Ordering::Equal {
1443                return mag_cmp;
1444            }
1445            let re_cmp = compare_f64(a.0, b.0);
1446            if re_cmp != Ordering::Equal {
1447                return re_cmp;
1448            }
1449            compare_f64(a.1, b.1)
1450        }
1451    }
1452}
1453
1454fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1455    for (lhs, rhs) in a.iter().zip(b.iter()) {
1456        let ord = compare_complex(*lhs, *rhs);
1457        if ord != Ordering::Equal {
1458            return ord;
1459        }
1460    }
1461    Ordering::Equal
1462}
1463
/// Lexicographically compares two character rows by `char` ordering.
///
/// Only the positions present in both slices are examined; if all of them
/// compare equal the result is `Ordering::Equal`, regardless of length.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1473
/// Lexicographically compares two rows of strings by `String` ordering.
///
/// Only the positions present in both slices are examined; if all of them
/// compare equal the result is `Ordering::Equal`, regardless of length.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1483
1484#[cfg(test)]
1485pub(crate) mod tests {
1486    use super::*;
1487    use crate::builtins::common::test_support;
1488    use runmat_accelerate_api::HostTensorView;
1489    use runmat_builtins::{IntValue, ResolveContext, Tensor, Type, Value};
1490
1491    fn error_message(err: crate::RuntimeError) -> String {
1492        err.message().to_string()
1493    }
1494
1495    fn evaluate_sync(a: Value, b: Value, rest: &[Value]) -> crate::BuiltinResult<UnionEvaluation> {
1496        futures::executor::block_on(evaluate(a, b, rest))
1497    }
1498
1499    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1500    #[test]
1501    fn union_numeric_sorted_default() {
1502        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
1503        let b = Tensor::new(vec![3.0, 1.0, 1.0], vec![3, 1]).unwrap();
1504        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
1505        match eval.values_value() {
1506            Value::Tensor(t) => {
1507                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 7.0]);
1508                assert_eq!(t.shape, vec![4, 1]);
1509            }
1510            other => panic!("expected tensor result, got {other:?}"),
1511        }
1512        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1513        assert_eq!(ia.data, vec![3.0, 1.0, 2.0]);
1514        assert_eq!(ia.shape, vec![3, 1]);
1515        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1516        assert_eq!(ib.data, vec![1.0]);
1517        assert_eq!(ib.shape, vec![1, 1]);
1518    }
1519
1520    #[test]
1521    fn union_type_resolver_numeric() {
1522        assert_eq!(
1523            set_values_output_type(
1524                &[Type::tensor(), Type::tensor()],
1525                &ResolveContext::new(Vec::new()),
1526            ),
1527            Type::tensor()
1528        );
1529    }
1530
1531    #[test]
1532    fn union_type_resolver_string_array() {
1533        assert_eq!(
1534            set_values_output_type(
1535                &[Type::cell_of(Type::String), Type::String],
1536                &ResolveContext::new(Vec::new()),
1537            ),
1538            Type::cell_of(Type::String)
1539        );
1540    }
1541
1542    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1543    #[test]
1544    fn union_numeric_stable_order() {
1545        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
1546        let b = Tensor::new(vec![3.0, 2.0, 4.0], vec![3, 1]).unwrap();
1547        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("stable")])
1548            .expect("union");
1549        match eval.values_value() {
1550            Value::Tensor(t) => {
1551                assert_eq!(t.data, vec![5.0, 7.0, 1.0, 3.0, 2.0, 4.0]);
1552                assert_eq!(t.shape, vec![6, 1]);
1553            }
1554            other => panic!("expected tensor result, got {other:?}"),
1555        }
1556        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1557        assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
1558        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1559        assert_eq!(ib.data, vec![1.0, 2.0, 3.0]);
1560    }
1561
1562    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1563    #[test]
1564    fn union_numeric_sorted_places_nan_last() {
1565        let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
1566        let b = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
1567        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
1568        let values = tensor::value_into_tensor_for("union", eval.values_value()).expect("values");
1569        assert_eq!(values.shape, vec![3, 1]);
1570        assert_eq!(values.data[0], 1.0);
1571        assert_eq!(values.data[1], 2.0);
1572        assert!(values.data[2].is_nan());
1573        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1574        assert_eq!(ia.data, vec![2.0, 1.0]);
1575        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1576        assert_eq!(ib.data, vec![1.0]);
1577    }
1578
1579    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1580    #[test]
1581    fn union_numeric_rows_sorted() {
1582        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
1583        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
1584        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")])
1585            .expect("union");
1586        match eval.values_value() {
1587            Value::Tensor(t) => {
1588                assert_eq!(t.shape, vec![3, 2]);
1589                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
1590            }
1591            other => panic!("expected tensor result, got {other:?}"),
1592        }
1593        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1594        assert_eq!(ia.data, vec![1.0, 2.0]);
1595        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1596        assert_eq!(ib.data, vec![2.0]);
1597    }
1598
1599    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1600    #[test]
1601    fn union_numeric_rows_stable_preserves_first_occurrence() {
1602        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
1603        let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
1604        let eval = evaluate_sync(
1605            Value::Tensor(a),
1606            Value::Tensor(b),
1607            &[Value::from("rows"), Value::from("stable")],
1608        )
1609        .expect("union");
1610        let (values, ia, ib) = eval.into_triple();
1611        match values {
1612            Value::Tensor(t) => {
1613                assert_eq!(t.shape, vec![3, 2]);
1614                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
1615            }
1616            other => panic!("expected tensor result, got {other:?}"),
1617        }
1618        let ia_tensor = tensor::value_into_tensor_for("union", ia).expect("ia tensor");
1619        assert_eq!(ia_tensor.data, vec![1.0, 2.0]);
1620        let ib_tensor = tensor::value_into_tensor_for("union", ib).expect("ib tensor");
1621        assert_eq!(ib_tensor.data, vec![2.0]);
1622    }
1623
1624    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1625    #[test]
1626    fn union_char_elements() {
1627        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
1628        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
1629        let eval = evaluate_sync(Value::CharArray(a), Value::CharArray(b), &[]).expect("union");
1630        match eval.values_value() {
1631            Value::CharArray(arr) => {
1632                assert_eq!(arr.rows, 4);
1633                assert_eq!(arr.cols, 1);
1634                assert_eq!(arr.data, vec!['a', 'm', 'x', 'z']);
1635            }
1636            other => panic!("expected char array, got {other:?}"),
1637        }
1638        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1639        assert_eq!(ia.data, vec![4.0, 1.0, 3.0]);
1640        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1641        assert_eq!(ib.data, vec![3.0]);
1642    }
1643
1644    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1645    #[test]
1646    fn union_string_rows_stable() {
1647        let a = StringArray::new(
1648            vec![
1649                "alpha".to_string(),
1650                "gamma".to_string(),
1651                "beta".to_string(),
1652                "beta".to_string(),
1653            ],
1654            vec![2, 2],
1655        )
1656        .unwrap();
1657        let b = StringArray::new(
1658            vec![
1659                "gamma".to_string(),
1660                "delta".to_string(),
1661                "beta".to_string(),
1662                "beta".to_string(),
1663            ],
1664            vec![2, 2],
1665        )
1666        .unwrap();
1667        let eval = evaluate_sync(
1668            Value::StringArray(a),
1669            Value::StringArray(b),
1670            &[Value::from("rows"), Value::from("stable")],
1671        )
1672        .expect("union");
1673        match eval.values_value() {
1674            Value::StringArray(arr) => {
1675                assert_eq!(arr.shape, vec![3, 2]);
1676                assert_eq!(
1677                    arr.data,
1678                    vec![
1679                        "alpha".to_string(),
1680                        "gamma".to_string(),
1681                        "delta".to_string(),
1682                        "beta".to_string(),
1683                        "beta".to_string(),
1684                        "beta".to_string()
1685                    ]
1686                );
1687            }
1688            other => panic!("expected string array, got {other:?}"),
1689        }
1690        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1691        assert_eq!(ia.data, vec![1.0, 2.0]);
1692        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1693        assert_eq!(ib.data, vec![2.0]);
1694    }
1695
1696    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1697    #[test]
1698    fn union_gpu_roundtrip() {
1699        test_support::with_test_provider(|provider| {
1700            let a = Tensor::new(vec![4.0, 1.0, 2.0], vec![3, 1]).unwrap();
1701            let b = Tensor::new(vec![2.0, 5.0], vec![2, 1]).unwrap();
1702            let view_a = HostTensorView {
1703                data: &a.data,
1704                shape: &a.shape,
1705            };
1706            let view_b = HostTensorView {
1707                data: &b.data,
1708                shape: &b.shape,
1709            };
1710            let handle_a = provider.upload(&view_a).expect("upload A");
1711            let handle_b = provider.upload(&view_b).expect("upload B");
1712            let eval = evaluate_sync(
1713                Value::GpuTensor(handle_a),
1714                Value::GpuTensor(handle_b),
1715                &[Value::from("stable")],
1716            )
1717            .expect("union");
1718            let values = tensor::value_into_tensor_for("union", eval.values_value()).unwrap();
1719            assert_eq!(values.data, vec![4.0, 1.0, 2.0, 5.0]);
1720            let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
1721            assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
1722            let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
1723            assert_eq!(ib.data, vec![2.0]);
1724        });
1725    }
1726
1727    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1728    #[test]
1729    fn union_rejects_legacy_option() {
1730        let tensor =
1731            Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).expect("tensor construction failed");
1732        let err = error_message(
1733            evaluate_sync(
1734                Value::Tensor(tensor.clone()),
1735                Value::Tensor(tensor),
1736                &[Value::from("legacy")],
1737            )
1738            .unwrap_err(),
1739        );
1740        assert!(err.contains("legacy"));
1741    }
1742
1743    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1744    #[test]
1745    fn union_rows_dimension_mismatch() {
1746        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1747        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1748        let err = error_message(
1749            evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")]).unwrap_err(),
1750        );
1751        assert!(err.contains("same number of columns"));
1752    }
1753
1754    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1755    #[test]
1756    fn union_requires_matching_types() {
1757        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1758        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
1759        let err = error_message(
1760            union_host(
1761                Value::Tensor(a),
1762                Value::CharArray(b),
1763                &UnionOptions {
1764                    rows: false,
1765                    order: UnionOrder::Sorted,
1766                },
1767            )
1768            .unwrap_err(),
1769        );
1770        assert!(err.contains("unsupported input type"));
1771    }
1772
1773    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1774    #[test]
1775    fn union_accepts_scalar_inputs() {
1776        let eval =
1777            evaluate_sync(Value::Int(IntValue::I32(1)), Value::Num(3.0), &[]).expect("union");
1778        match eval.values_value() {
1779            Value::Tensor(t) => {
1780                assert_eq!(t.data, vec![1.0, 3.0]);
1781                assert_eq!(t.shape, vec![2, 1]);
1782            }
1783            other => panic!("expected numeric tensor, got {other:?}"),
1784        }
1785        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
1786        assert_eq!(ia.data, vec![1.0]);
1787        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
1788        assert_eq!(ib.data, vec![1.0]);
1789    }
1790
1791    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1792    #[test]
1793    #[cfg(feature = "wgpu")]
1794    fn union_wgpu_matches_cpu() {
1795        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1796            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1797        );
1798        let a = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
1799        let b = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();
1800
1801        let cpu_eval =
1802            evaluate_sync(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[]).expect("union");
1803        let cpu_values = tensor::value_into_tensor_for("union", cpu_eval.values_value()).unwrap();
1804        let cpu_ia = tensor::value_into_tensor_for("union", cpu_eval.ia_value()).unwrap();
1805        let cpu_ib = tensor::value_into_tensor_for("union", cpu_eval.ib_value()).unwrap();
1806
1807        let provider = runmat_accelerate_api::provider().expect("provider");
1808        let view_a = HostTensorView {
1809            data: &a.data,
1810            shape: &a.shape,
1811        };
1812        let view_b = HostTensorView {
1813            data: &b.data,
1814            shape: &b.shape,
1815        };
1816        let handle_a = provider.upload(&view_a).expect("upload A");
1817        let handle_b = provider.upload(&view_b).expect("upload B");
1818        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
1819            .expect("union");
1820        let gpu_values = tensor::value_into_tensor_for("union", gpu_eval.values_value()).unwrap();
1821        let gpu_ia = tensor::value_into_tensor_for("union", gpu_eval.ia_value()).unwrap();
1822        let gpu_ib = tensor::value_into_tensor_for("union", gpu_eval.ib_value()).unwrap();
1823
1824        assert_eq!(gpu_values.data, cpu_values.data);
1825        assert_eq!(gpu_values.shape, cpu_values.shape);
1826        assert_eq!(gpu_ia.data, cpu_ia.data);
1827        assert_eq!(gpu_ib.data, cpu_ib.data);
1828    }
1829}