Skip to main content

runmat_runtime/builtins/array/sorting_sets/
union.rs

1//! MATLAB-compatible `union` builtin with GPU-aware semantics for RunMat.
2//!
//! Handles element-wise and row-wise unions with optional stable ordering and
//! index outputs that mirror MathWorks MATLAB semantics. When both operands are
//! GPU tensors, a provider's dedicated `union` hook is tried first; otherwise (or
//! if that hook fails) GPU tensors are gathered to host memory and processed there.
7
8use std::cmp::Ordering;
9use std::collections::{hash_map::Entry, HashMap};
10
11use runmat_accelerate_api::{
12    GpuTensorHandle, GpuTensorStorage, HostTensorOwned, UnionOptions, UnionOrder, UnionResult,
13};
14use runmat_builtins::{
15    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
16    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
17    CharArray, ComplexTensor, StringArray, Tensor, Value,
18};
19use runmat_macros::runtime_builtin;
20
21use super::type_resolvers::set_values_output_type;
22use crate::build_runtime_error;
23use crate::builtins::common::arg_tokens::tokens_from_values;
24use crate::builtins::common::gpu_helpers;
25use crate::builtins::common::random_args::complex_tensor_into_value;
26use crate::builtins::common::spec::{
27    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
28    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
29};
30use crate::builtins::common::tensor;
31
/// GPU capability descriptor for `union`, registered through `register_gpu_spec`.
///
/// Advertises a custom op kind and provider hook both named `"union"` for F32 and
/// F64 tensors, with no broadcasting. The residency policy (`GatherImmediately`)
/// and the notes record that, unless a provider implements the hook, operands are
/// gathered and the union is computed on the host.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::union")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "union",
    op_kind: GpuOpKind::Custom("union"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("union")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may expose a dedicated union hook; otherwise tensors are gathered and processed on the host.",
};
47
/// Fusion descriptor for `union`, registered through `register_fusion_spec`.
///
/// `union` declares no elementwise or reduction template (`None` for both), so it
/// cannot be fused; per the notes it ends a fusion chain and produces host results.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::sorting_sets::union")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "union",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`union` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
};
58
/// Builtin name attached (via `with_builtin`) to errors built by `union_error_with`.
const BUILTIN_NAME: &str = "union";
60
/// Output descriptor for the single-output form `C = union(...)`.
const UNION_OUTPUT_C: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "C",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Union values or rows.",
}];

/// Output descriptors for the two-output form `[C, ia] = union(...)`.
const UNION_OUTPUT_C_IA: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "C",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Union values or rows.",
    },
    BuiltinParamDescriptor {
        name: "ia",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Indices selecting contributions from A.",
    },
];

/// Output descriptors for the three-output form `[C, ia, ib] = union(...)`.
const UNION_OUTPUT_C_IA_IB: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "C",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Union values or rows.",
    },
    BuiltinParamDescriptor {
        name: "ia",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Indices selecting contributions from A.",
    },
    BuiltinParamDescriptor {
        name: "ib",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Indices selecting contributions from B.",
    },
];
109
/// Input descriptors for the option-free call form `union(A, B)`.
const UNION_INPUTS_A_B: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "First input array.",
    },
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Second input array.",
    },
];

/// Input descriptors for `union(A, B, option...)`.
///
/// The trailing option tokens are variadic; `parse_union_option` accepts `rows`,
/// `sorted` and `stable` and rejects `legacy`/`R2012a` and anything else.
const UNION_INPUTS_A_B_OPTIONS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "First input array.",
    },
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Second input array.",
    },
    BuiltinParamDescriptor {
        name: "option",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Option tokens: 'rows'|'sorted'|'stable'.",
    },
];
150
/// Every call form advertised for `union`.
///
/// One, two or three outputs (`C`, `[C, ia]`, `[C, ia, ib]`), each listed both
/// without and with trailing option tokens.
const UNION_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
    BuiltinSignatureDescriptor {
        label: "C = union(A, B)",
        inputs: &UNION_INPUTS_A_B,
        outputs: &UNION_OUTPUT_C,
    },
    BuiltinSignatureDescriptor {
        label: "C = union(A, B, option...)",
        inputs: &UNION_INPUTS_A_B_OPTIONS,
        outputs: &UNION_OUTPUT_C,
    },
    BuiltinSignatureDescriptor {
        label: "[C, ia] = union(A, B)",
        inputs: &UNION_INPUTS_A_B,
        outputs: &UNION_OUTPUT_C_IA,
    },
    BuiltinSignatureDescriptor {
        label: "[C, ia] = union(A, B, option...)",
        inputs: &UNION_INPUTS_A_B_OPTIONS,
        outputs: &UNION_OUTPUT_C_IA,
    },
    BuiltinSignatureDescriptor {
        label: "[C, ia, ib] = union(A, B)",
        inputs: &UNION_INPUTS_A_B,
        outputs: &UNION_OUTPUT_C_IA_IB,
    },
    BuiltinSignatureDescriptor {
        label: "[C, ia, ib] = union(A, B, option...)",
        inputs: &UNION_INPUTS_A_B_OPTIONS,
        outputs: &UNION_OUTPUT_C_IA_IB,
    },
];
183
/// Raised by `parse_union_option` for the `legacy` / `r2012a` option tokens.
const UNION_ERROR_LEGACY_OPTION_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.LEGACY_OPTION_UNSUPPORTED",
    identifier: Some("RunMat:union:LegacyOptionUnsupported"),
    when: "Legacy compatibility options are requested.",
    message: "union: the 'legacy' behaviour is not supported",
};

/// Raised when `sorted` and `stable` are both requested in one call.
const UNION_ERROR_CONFLICTING_ORDER_OPTIONS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.CONFLICTING_ORDER_OPTIONS",
    identifier: Some("RunMat:union:ConflictingOrderOptions"),
    when: "Both 'sorted' and 'stable' options are provided.",
    message: "union: cannot combine 'sorted' with 'stable'",
};

/// Raised for an option token that is not recognised. The parser substitutes a
/// message that names the offending token.
const UNION_ERROR_UNKNOWN_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.UNKNOWN_OPTION",
    identifier: Some("RunMat:union:UnknownOption"),
    when: "An unsupported option token is provided.",
    message: "union: unrecognised option",
};

/// Raised in `'rows'` mode when `A` and `B` have different column counts.
const UNION_ERROR_ROWS_COLUMN_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.ROWS_COLUMN_MISMATCH",
    identifier: Some("RunMat:union:RowsColumnMismatch"),
    when: "'rows' mode is used and column counts differ.",
    message: "union: inputs must have the same number of columns when using 'rows'",
};

/// Raised when an operand cannot be converted into a numeric tensor.
const UNION_ERROR_UNSUPPORTED_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.UNSUPPORTED_INPUT_TYPE",
    identifier: Some("RunMat:union:UnsupportedInputType"),
    when: "Input values cannot be converted into supported union domains.",
    message: "union: unsupported input type",
};

/// Raised when a non-string option argument cannot be read as text.
const UNION_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.INVALID_ARGUMENT",
    identifier: Some("RunMat:union:InvalidArgument"),
    when: "Option arguments are not string-like where required.",
    message: "union: expected string option arguments",
};

/// Catch-all for internal failures; see `union_internal_error`.
const UNION_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.INTERNAL",
    identifier: Some("RunMat:union:Internal"),
    when: "Internal conversion/allocation/provider decode fails.",
    message: "union: internal operation failed",
};
232
/// All error descriptors `union` can raise, published through `UNION_DESCRIPTOR`.
const UNION_ERRORS: [BuiltinErrorDescriptor; 7] = [
    UNION_ERROR_LEGACY_OPTION_UNSUPPORTED,
    UNION_ERROR_CONFLICTING_ORDER_OPTIONS,
    UNION_ERROR_UNKNOWN_OPTION,
    UNION_ERROR_ROWS_COLUMN_MISMATCH,
    UNION_ERROR_UNSUPPORTED_INPUT_TYPE,
    UNION_ERROR_INVALID_ARGUMENT,
    UNION_ERROR_INTERNAL,
];

/// Public descriptor for `union`, referenced by the `#[runtime_builtin]`
/// `descriptor(...)` argument on `union_builtin`.
///
/// The output mode `ByRequestedOutputCount` corresponds to `union_builtin`
/// inspecting `current_output_count()` to decide how many results to return.
pub const UNION_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &UNION_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &UNION_ERRORS,
};
249
250fn union_error_with(
251    error: &'static BuiltinErrorDescriptor,
252    message: impl Into<String>,
253) -> crate::RuntimeError {
254    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
255    if let Some(identifier) = error.identifier {
256        builder = builder.with_identifier(identifier);
257    }
258    builder.build()
259}
260
261fn union_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
262    union_error_with(error, error.message)
263}
264
265fn union_internal_error(message: impl Into<String>) -> crate::RuntimeError {
266    union_error_with(&UNION_ERROR_INTERNAL, message)
267}
268
269#[runtime_builtin(
270    name = "union",
271    category = "array/sorting_sets",
272    summary = "Return unions of input arrays with ordering and index-output controls.",
273    keywords = "union,set,stable,rows,indices,gpu",
274    accel = "array_construct",
275    sink = true,
276    type_resolver(set_values_output_type),
277    descriptor(crate::builtins::array::sorting_sets::union::UNION_DESCRIPTOR),
278    builtin_path = "crate::builtins::array::sorting_sets::union"
279)]
280async fn union_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
281    let eval = evaluate(a, b, &rest).await?;
282    if let Some(out_count) = crate::output_count::current_output_count() {
283        if out_count == 0 {
284            return Ok(Value::OutputList(Vec::new()));
285        }
286        if out_count == 1 {
287            return Ok(Value::OutputList(vec![eval.into_values_value()]));
288        }
289        if out_count == 2 {
290            let (values, ia) = eval.into_pair();
291            return Ok(Value::OutputList(vec![values, ia]));
292        }
293        let (values, ia, ib) = eval.into_triple();
294        return Ok(crate::output_count::output_list_with_padding(
295            out_count,
296            vec![values, ia, ib],
297        ));
298    }
299    Ok(eval.into_values_value())
300}
301
302/// Evaluate the `union` builtin once and expose all outputs.
303pub async fn evaluate(a: Value, b: Value, rest: &[Value]) -> crate::BuiltinResult<UnionEvaluation> {
304    let opts = parse_options(rest)?;
305    match (a, b) {
306        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
307            union_gpu_pair(handle_a, handle_b, &opts).await
308        }
309        (Value::GpuTensor(handle_a), other) => union_gpu_mixed(handle_a, other, &opts, true).await,
310        (other, Value::GpuTensor(handle_b)) => union_gpu_mixed(handle_b, other, &opts, false).await,
311        (left, right) => union_host(left, right, &opts),
312    }
313}
314
315fn parse_options(rest: &[Value]) -> crate::BuiltinResult<UnionOptions> {
316    let mut opts = UnionOptions {
317        rows: false,
318        order: UnionOrder::Sorted,
319    };
320    let mut seen_order: Option<UnionOrder> = None;
321
322    let tokens = tokens_from_values(rest);
323    for (arg, token) in rest.iter().zip(tokens.iter()) {
324        let text = match token {
325            crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
326            _ => {
327                let text = tensor::value_to_string(arg)
328                    .ok_or_else(|| union_error(&UNION_ERROR_INVALID_ARGUMENT))?;
329                let lowered = text.trim().to_ascii_lowercase();
330                parse_union_option(&mut opts, &mut seen_order, &lowered)?;
331                continue;
332            }
333        };
334        parse_union_option(&mut opts, &mut seen_order, text)?;
335    }
336
337    Ok(opts)
338}
339
340fn parse_union_option(
341    opts: &mut UnionOptions,
342    seen_order: &mut Option<UnionOrder>,
343    lowered: &str,
344) -> crate::BuiltinResult<()> {
345    match lowered {
346        "rows" => opts.rows = true,
347        "sorted" => {
348            if let Some(prev) = seen_order {
349                if *prev != UnionOrder::Sorted {
350                    return Err(union_error(&UNION_ERROR_CONFLICTING_ORDER_OPTIONS));
351                }
352            }
353            *seen_order = Some(UnionOrder::Sorted);
354            opts.order = UnionOrder::Sorted;
355        }
356        "stable" => {
357            if let Some(prev) = seen_order {
358                if *prev != UnionOrder::Stable {
359                    return Err(union_error(&UNION_ERROR_CONFLICTING_ORDER_OPTIONS));
360                }
361            }
362            *seen_order = Some(UnionOrder::Stable);
363            opts.order = UnionOrder::Stable;
364        }
365        "legacy" | "r2012a" => {
366            return Err(union_error(&UNION_ERROR_LEGACY_OPTION_UNSUPPORTED));
367        }
368        other => {
369            return Err(union_error_with(
370                &UNION_ERROR_UNKNOWN_OPTION,
371                format!("union: unrecognised option '{other}'"),
372            ))
373        }
374    }
375    Ok(())
376}
377
378async fn union_gpu_pair(
379    handle_a: GpuTensorHandle,
380    handle_b: GpuTensorHandle,
381    opts: &UnionOptions,
382) -> crate::BuiltinResult<UnionEvaluation> {
383    if let Some(provider) = runmat_accelerate_api::provider() {
384        match provider.union(&handle_a, &handle_b, opts).await {
385            Ok(result) => return UnionEvaluation::from_union_result(result),
386            Err(_) => {
387                // Fall back to host gather when provider union is unavailable.
388            }
389        }
390    }
391    let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
392    let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
393    union_numeric(tensor_a, tensor_b, opts)
394}
395
396async fn union_gpu_mixed(
397    handle_gpu: GpuTensorHandle,
398    other: Value,
399    opts: &UnionOptions,
400    gpu_is_a: bool,
401) -> crate::BuiltinResult<UnionEvaluation> {
402    let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
403    let tensor_other =
404        tensor::value_into_tensor_for("union", other).map_err(|e| union_internal_error(e))?;
405    if gpu_is_a {
406        union_numeric(tensor_gpu, tensor_other, opts)
407    } else {
408        union_numeric(tensor_other, tensor_gpu, opts)
409    }
410}
411
412fn union_host(a: Value, b: Value, opts: &UnionOptions) -> crate::BuiltinResult<UnionEvaluation> {
413    match (a, b) {
414        // Complex cases
415        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => union_complex(at, bt, opts),
416        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
417            let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
418                .map_err(|e| union_internal_error(format!("union: {e}")))?;
419            union_complex(at, bt, opts)
420        }
421        (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
422            let at = ComplexTensor::new(vec![(re, im)], vec![1, 1])
423                .map_err(|e| union_internal_error(format!("union: {e}")))?;
424            union_complex(at, bt, opts)
425        }
426        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
427            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
428                .map_err(|e| union_internal_error(format!("union: {e}")))?;
429            let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
430                .map_err(|e| union_internal_error(format!("union: {e}")))?;
431            union_complex(at, bt, opts)
432        }
433
434        // Character arrays
435        (Value::CharArray(ac), Value::CharArray(bc)) => union_char(ac, bc, opts),
436
437        // String arrays / scalars
438        (Value::StringArray(astring), Value::StringArray(bstring)) => {
439            union_string(astring, bstring, opts)
440        }
441        (Value::StringArray(astring), Value::String(b)) => {
442            let bstring = StringArray::new(vec![b], vec![1, 1])
443                .map_err(|e| union_internal_error(format!("union: {e}")))?;
444            union_string(astring, bstring, opts)
445        }
446        (Value::String(a), Value::StringArray(bstring)) => {
447            let astring = StringArray::new(vec![a], vec![1, 1])
448                .map_err(|e| union_internal_error(format!("union: {e}")))?;
449            union_string(astring, bstring, opts)
450        }
451        (Value::String(a), Value::String(b)) => {
452            let astring = StringArray::new(vec![a], vec![1, 1])
453                .map_err(|e| union_internal_error(format!("union: {e}")))?;
454            let bstring = StringArray::new(vec![b], vec![1, 1])
455                .map_err(|e| union_internal_error(format!("union: {e}")))?;
456            union_string(astring, bstring, opts)
457        }
458
459        // Fallback to numeric (includes tensors, logical arrays, ints, bools, doubles)
460        (left, right) => {
461            let tensor_a = tensor::value_into_tensor_for("union", left)
462                .map_err(|e| union_error_with(&UNION_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
463            let tensor_b = tensor::value_into_tensor_for("union", right)
464                .map_err(|e| union_error_with(&UNION_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
465            union_numeric(tensor_a, tensor_b, opts)
466        }
467    }
468}
469
470fn union_numeric(
471    a: Tensor,
472    b: Tensor,
473    opts: &UnionOptions,
474) -> crate::BuiltinResult<UnionEvaluation> {
475    if opts.rows {
476        union_numeric_rows(a, b, opts)
477    } else {
478        union_numeric_elements(a, b, opts)
479    }
480}
481
482/// Helper exposed for acceleration providers handling numeric tensors entirely on the host.
483pub fn union_numeric_from_tensors(
484    a: Tensor,
485    b: Tensor,
486    opts: &UnionOptions,
487) -> crate::BuiltinResult<UnionEvaluation> {
488    union_numeric(a, b, opts)
489}
490
491fn union_numeric_elements(
492    a: Tensor,
493    b: Tensor,
494    opts: &UnionOptions,
495) -> crate::BuiltinResult<UnionEvaluation> {
496    let mut entries = Vec::<NumericUnionEntry>::new();
497    let mut map: HashMap<u64, usize> = HashMap::new();
498    let mut order_counter = 0usize;
499
500    for (idx, &value) in a.data.iter().enumerate() {
501        let key = canonicalize_f64(value);
502        match map.entry(key) {
503            Entry::Occupied(_) => {
504                // Already recorded from A; keep first occurrence only.
505            }
506            Entry::Vacant(v) => {
507                let entry_idx = entries.len();
508                entries.push(NumericUnionEntry {
509                    value,
510                    a_index: Some(idx),
511                    b_index: None,
512                    order_rank: order_counter,
513                });
514                v.insert(entry_idx);
515                order_counter += 1;
516            }
517        }
518    }
519
520    for (idx, &value) in b.data.iter().enumerate() {
521        let key = canonicalize_f64(value);
522        match map.entry(key) {
523            Entry::Occupied(occ) => {
524                let entry = &mut entries[*occ.get()];
525                if entry.a_index.is_none() && entry.b_index.is_none() {
526                    entry.b_index = Some(idx);
527                }
528            }
529            Entry::Vacant(v) => {
530                let entry_idx = entries.len();
531                entries.push(NumericUnionEntry {
532                    value,
533                    a_index: None,
534                    b_index: Some(idx),
535                    order_rank: order_counter,
536                });
537                v.insert(entry_idx);
538                order_counter += 1;
539            }
540        }
541    }
542
543    assemble_numeric_union(entries, opts)
544}
545
546fn union_numeric_rows(
547    a: Tensor,
548    b: Tensor,
549    opts: &UnionOptions,
550) -> crate::BuiltinResult<UnionEvaluation> {
551    if a.shape.len() != 2 || b.shape.len() != 2 {
552        return Err(union_internal_error(
553            "union: 'rows' option requires 2-D numeric matrices",
554        ));
555    }
556    if a.shape[1] != b.shape[1] {
557        return Err(union_error_with(
558            &UNION_ERROR_ROWS_COLUMN_MISMATCH,
559            UNION_ERROR_ROWS_COLUMN_MISMATCH.message,
560        ));
561    }
562    let rows_a = a.shape[0];
563    let cols = a.shape[1];
564    let rows_b = b.shape[0];
565
566    let mut entries = Vec::<NumericRowUnionEntry>::new();
567    let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
568    let mut order_counter = 0usize;
569
570    for r in 0..rows_a {
571        let mut row_values = Vec::with_capacity(cols);
572        for c in 0..cols {
573            let idx = r + c * rows_a;
574            row_values.push(a.data[idx]);
575        }
576        let key = NumericRowKey::from_slice(&row_values);
577        match map.entry(key) {
578            Entry::Occupied(_) => {}
579            Entry::Vacant(v) => {
580                let entry_idx = entries.len();
581                entries.push(NumericRowUnionEntry {
582                    row_data: row_values,
583                    a_row: Some(r),
584                    b_row: None,
585                    order_rank: order_counter,
586                });
587                v.insert(entry_idx);
588                order_counter += 1;
589            }
590        }
591    }
592
593    for r in 0..rows_b {
594        let mut row_values = Vec::with_capacity(cols);
595        for c in 0..cols {
596            let idx = r + c * rows_b;
597            row_values.push(b.data[idx]);
598        }
599        let key = NumericRowKey::from_slice(&row_values);
600        match map.entry(key) {
601            Entry::Occupied(occ) => {
602                let entry = &mut entries[*occ.get()];
603                if entry.a_row.is_none() && entry.b_row.is_none() {
604                    entry.b_row = Some(r);
605                }
606            }
607            Entry::Vacant(v) => {
608                let entry_idx = entries.len();
609                entries.push(NumericRowUnionEntry {
610                    row_data: row_values,
611                    a_row: None,
612                    b_row: Some(r),
613                    order_rank: order_counter,
614                });
615                v.insert(entry_idx);
616                order_counter += 1;
617            }
618        }
619    }
620
621    assemble_numeric_row_union(entries, opts, cols)
622}
623
624fn union_complex(
625    a: ComplexTensor,
626    b: ComplexTensor,
627    opts: &UnionOptions,
628) -> crate::BuiltinResult<UnionEvaluation> {
629    if opts.rows {
630        union_complex_rows(a, b, opts)
631    } else {
632        union_complex_elements(a, b, opts)
633    }
634}
635
636fn union_complex_elements(
637    a: ComplexTensor,
638    b: ComplexTensor,
639    opts: &UnionOptions,
640) -> crate::BuiltinResult<UnionEvaluation> {
641    let mut entries = Vec::<ComplexUnionEntry>::new();
642    let mut map: HashMap<ComplexKey, usize> = HashMap::new();
643    let mut order_counter = 0usize;
644
645    for (idx, &value) in a.data.iter().enumerate() {
646        let key = ComplexKey::new(value);
647        match map.entry(key) {
648            Entry::Occupied(_) => {}
649            Entry::Vacant(v) => {
650                let entry_idx = entries.len();
651                entries.push(ComplexUnionEntry {
652                    value,
653                    a_index: Some(idx),
654                    b_index: None,
655                    order_rank: order_counter,
656                });
657                v.insert(entry_idx);
658                order_counter += 1;
659            }
660        }
661    }
662
663    for (idx, &value) in b.data.iter().enumerate() {
664        let key = ComplexKey::new(value);
665        match map.entry(key) {
666            Entry::Occupied(occ) => {
667                let entry = &mut entries[*occ.get()];
668                if entry.a_index.is_none() && entry.b_index.is_none() {
669                    entry.b_index = Some(idx);
670                }
671            }
672            Entry::Vacant(v) => {
673                let entry_idx = entries.len();
674                entries.push(ComplexUnionEntry {
675                    value,
676                    a_index: None,
677                    b_index: Some(idx),
678                    order_rank: order_counter,
679                });
680                v.insert(entry_idx);
681                order_counter += 1;
682            }
683        }
684    }
685
686    assemble_complex_union(entries, opts)
687}
688
689fn union_complex_rows(
690    a: ComplexTensor,
691    b: ComplexTensor,
692    opts: &UnionOptions,
693) -> crate::BuiltinResult<UnionEvaluation> {
694    if a.shape.len() != 2 || b.shape.len() != 2 {
695        return Err(union_internal_error(
696            "union: 'rows' option requires 2-D complex matrices",
697        ));
698    }
699    if a.shape[1] != b.shape[1] {
700        return Err(union_error_with(
701            &UNION_ERROR_ROWS_COLUMN_MISMATCH,
702            UNION_ERROR_ROWS_COLUMN_MISMATCH.message,
703        ));
704    }
705    let rows_a = a.shape[0];
706    let cols = a.shape[1];
707    let rows_b = b.shape[0];
708
709    let mut entries = Vec::<ComplexRowUnionEntry>::new();
710    let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
711    let mut order_counter = 0usize;
712
713    for r in 0..rows_a {
714        let mut row_values = Vec::with_capacity(cols);
715        let mut key_row = Vec::with_capacity(cols);
716        for c in 0..cols {
717            let idx = r + c * rows_a;
718            let value = a.data[idx];
719            row_values.push(value);
720            key_row.push(ComplexKey::new(value));
721        }
722        match map.entry(key_row) {
723            Entry::Occupied(_) => {}
724            Entry::Vacant(v) => {
725                let entry_idx = entries.len();
726                entries.push(ComplexRowUnionEntry {
727                    row_data: row_values,
728                    a_row: Some(r),
729                    b_row: None,
730                    order_rank: order_counter,
731                });
732                v.insert(entry_idx);
733                order_counter += 1;
734            }
735        }
736    }
737
738    for r in 0..rows_b {
739        let mut row_values = Vec::with_capacity(cols);
740        let mut key_row = Vec::with_capacity(cols);
741        for c in 0..cols {
742            let idx = r + c * rows_b;
743            let value = b.data[idx];
744            row_values.push(value);
745            key_row.push(ComplexKey::new(value));
746        }
747        match map.entry(key_row) {
748            Entry::Occupied(occ) => {
749                let entry = &mut entries[*occ.get()];
750                if entry.a_row.is_none() && entry.b_row.is_none() {
751                    entry.b_row = Some(r);
752                }
753            }
754            Entry::Vacant(v) => {
755                let entry_idx = entries.len();
756                entries.push(ComplexRowUnionEntry {
757                    row_data: row_values,
758                    a_row: None,
759                    b_row: Some(r),
760                    order_rank: order_counter,
761                });
762                v.insert(entry_idx);
763                order_counter += 1;
764            }
765        }
766    }
767
768    assemble_complex_row_union(entries, opts, cols)
769}
770
771fn union_char(
772    a: CharArray,
773    b: CharArray,
774    opts: &UnionOptions,
775) -> crate::BuiltinResult<UnionEvaluation> {
776    if opts.rows {
777        union_char_rows(a, b, opts)
778    } else {
779        union_char_elements(a, b, opts)
780    }
781}
782
783fn union_char_elements(
784    a: CharArray,
785    b: CharArray,
786    opts: &UnionOptions,
787) -> crate::BuiltinResult<UnionEvaluation> {
788    let mut entries = Vec::<CharUnionEntry>::new();
789    let mut map: HashMap<u32, usize> = HashMap::new();
790    let mut order_counter = 0usize;
791
792    for col in 0..a.cols {
793        for row in 0..a.rows {
794            let linear_idx = row + col * a.rows;
795            let data_idx = row * a.cols + col;
796            let ch = a.data[data_idx];
797            let key = ch as u32;
798            match map.entry(key) {
799                Entry::Occupied(_) => {}
800                Entry::Vacant(v) => {
801                    let entry_idx = entries.len();
802                    entries.push(CharUnionEntry {
803                        ch,
804                        a_index: Some(linear_idx),
805                        b_index: None,
806                        order_rank: order_counter,
807                    });
808                    v.insert(entry_idx);
809                    order_counter += 1;
810                }
811            }
812        }
813    }
814
815    for col in 0..b.cols {
816        for row in 0..b.rows {
817            let linear_idx = row + col * b.rows;
818            let data_idx = row * b.cols + col;
819            let ch = b.data[data_idx];
820            let key = ch as u32;
821            match map.entry(key) {
822                Entry::Occupied(occ) => {
823                    let entry = &mut entries[*occ.get()];
824                    if entry.a_index.is_none() && entry.b_index.is_none() {
825                        entry.b_index = Some(linear_idx);
826                    }
827                }
828                Entry::Vacant(v) => {
829                    let entry_idx = entries.len();
830                    entries.push(CharUnionEntry {
831                        ch,
832                        a_index: None,
833                        b_index: Some(linear_idx),
834                        order_rank: order_counter,
835                    });
836                    v.insert(entry_idx);
837                    order_counter += 1;
838                }
839            }
840        }
841    }
842
843    assemble_char_union(entries, opts)
844}
845
846fn union_char_rows(
847    a: CharArray,
848    b: CharArray,
849    opts: &UnionOptions,
850) -> crate::BuiltinResult<UnionEvaluation> {
851    if a.cols != b.cols {
852        return Err(union_error_with(
853            &UNION_ERROR_ROWS_COLUMN_MISMATCH,
854            UNION_ERROR_ROWS_COLUMN_MISMATCH.message,
855        ));
856    }
857    let rows_a = a.rows;
858    let rows_b = b.rows;
859    let cols = a.cols;
860
861    let mut entries = Vec::<CharRowUnionEntry>::new();
862    let mut map: HashMap<RowCharKey, usize> = HashMap::new();
863    let mut order_counter = 0usize;
864
865    for r in 0..rows_a {
866        let mut row_values = Vec::with_capacity(cols);
867        for c in 0..cols {
868            let idx = r * cols + c;
869            row_values.push(a.data[idx]);
870        }
871        let key = RowCharKey::from_slice(&row_values);
872        match map.entry(key) {
873            Entry::Occupied(_) => {}
874            Entry::Vacant(v) => {
875                let entry_idx = entries.len();
876                entries.push(CharRowUnionEntry {
877                    row_data: row_values,
878                    a_row: Some(r),
879                    b_row: None,
880                    order_rank: order_counter,
881                });
882                v.insert(entry_idx);
883                order_counter += 1;
884            }
885        }
886    }
887
888    for r in 0..rows_b {
889        let mut row_values = Vec::with_capacity(cols);
890        for c in 0..cols {
891            let idx = r * cols + c;
892            row_values.push(b.data[idx]);
893        }
894        let key = RowCharKey::from_slice(&row_values);
895        match map.entry(key) {
896            Entry::Occupied(occ) => {
897                let entry = &mut entries[*occ.get()];
898                if entry.a_row.is_none() && entry.b_row.is_none() {
899                    entry.b_row = Some(r);
900                }
901            }
902            Entry::Vacant(v) => {
903                let entry_idx = entries.len();
904                entries.push(CharRowUnionEntry {
905                    row_data: row_values,
906                    a_row: None,
907                    b_row: Some(r),
908                    order_rank: order_counter,
909                });
910                v.insert(entry_idx);
911                order_counter += 1;
912            }
913        }
914    }
915
916    assemble_char_row_union(entries, opts, cols)
917}
918
919fn union_string(
920    a: StringArray,
921    b: StringArray,
922    opts: &UnionOptions,
923) -> crate::BuiltinResult<UnionEvaluation> {
924    if opts.rows {
925        union_string_rows(a, b, opts)
926    } else {
927        union_string_elements(a, b, opts)
928    }
929}
930
931fn union_string_elements(
932    a: StringArray,
933    b: StringArray,
934    opts: &UnionOptions,
935) -> crate::BuiltinResult<UnionEvaluation> {
936    let mut entries = Vec::<StringUnionEntry>::new();
937    let mut map: HashMap<String, usize> = HashMap::new();
938    let mut order_counter = 0usize;
939
940    for (idx, value) in a.data.iter().enumerate() {
941        match map.entry(value.clone()) {
942            Entry::Occupied(_) => {}
943            Entry::Vacant(v) => {
944                let entry_idx = entries.len();
945                entries.push(StringUnionEntry {
946                    value: value.clone(),
947                    a_index: Some(idx),
948                    b_index: None,
949                    order_rank: order_counter,
950                });
951                v.insert(entry_idx);
952                order_counter += 1;
953            }
954        }
955    }
956
957    for (idx, value) in b.data.iter().enumerate() {
958        match map.entry(value.clone()) {
959            Entry::Occupied(occ) => {
960                let entry = &mut entries[*occ.get()];
961                if entry.a_index.is_none() && entry.b_index.is_none() {
962                    entry.b_index = Some(idx);
963                }
964            }
965            Entry::Vacant(v) => {
966                let entry_idx = entries.len();
967                entries.push(StringUnionEntry {
968                    value: value.clone(),
969                    a_index: None,
970                    b_index: Some(idx),
971                    order_rank: order_counter,
972                });
973                v.insert(entry_idx);
974                order_counter += 1;
975            }
976        }
977    }
978
979    assemble_string_union(entries, opts)
980}
981
982fn union_string_rows(
983    a: StringArray,
984    b: StringArray,
985    opts: &UnionOptions,
986) -> crate::BuiltinResult<UnionEvaluation> {
987    if a.shape.len() != 2 || b.shape.len() != 2 {
988        return Err(union_internal_error(
989            "union: 'rows' option requires 2-D string arrays",
990        ));
991    }
992    if a.shape[1] != b.shape[1] {
993        return Err(union_error_with(
994            &UNION_ERROR_ROWS_COLUMN_MISMATCH,
995            UNION_ERROR_ROWS_COLUMN_MISMATCH.message,
996        ));
997    }
998    let rows_a = a.shape[0];
999    let cols = a.shape[1];
1000    let rows_b = b.shape[0];
1001
1002    let mut entries = Vec::<StringRowUnionEntry>::new();
1003    let mut map: HashMap<RowStringKey, usize> = HashMap::new();
1004    let mut order_counter = 0usize;
1005
1006    for r in 0..rows_a {
1007        let mut row_values = Vec::with_capacity(cols);
1008        for c in 0..cols {
1009            let idx = r + c * rows_a;
1010            row_values.push(a.data[idx].clone());
1011        }
1012        let key = RowStringKey(row_values.clone());
1013        match map.entry(key) {
1014            Entry::Occupied(_) => {}
1015            Entry::Vacant(v) => {
1016                let entry_idx = entries.len();
1017                entries.push(StringRowUnionEntry {
1018                    row_data: row_values,
1019                    a_row: Some(r),
1020                    b_row: None,
1021                    order_rank: order_counter,
1022                });
1023                v.insert(entry_idx);
1024                order_counter += 1;
1025            }
1026        }
1027    }
1028
1029    for r in 0..rows_b {
1030        let mut row_values = Vec::with_capacity(cols);
1031        for c in 0..cols {
1032            let idx = r + c * rows_b;
1033            row_values.push(b.data[idx].clone());
1034        }
1035        let key = RowStringKey(row_values.clone());
1036        match map.entry(key) {
1037            Entry::Occupied(occ) => {
1038                let entry = &mut entries[*occ.get()];
1039                if entry.a_row.is_none() && entry.b_row.is_none() {
1040                    entry.b_row = Some(r);
1041                }
1042            }
1043            Entry::Vacant(v) => {
1044                let entry_idx = entries.len();
1045                entries.push(StringRowUnionEntry {
1046                    row_data: row_values,
1047                    a_row: None,
1048                    b_row: Some(r),
1049                    order_rank: order_counter,
1050                });
1051                v.insert(entry_idx);
1052                order_counter += 1;
1053            }
1054        }
1055    }
1056
1057    assemble_string_row_union(entries, opts, cols)
1058}
1059
/// Host-side outputs of one `union` evaluation.
///
/// `values` is the union itself as a runtime `Value` (numeric, complex, char or
/// string, depending on the assembler that produced it). `ia` and `ib` are
/// column vectors of 1-based positions, listed in output order: `ia` holds the
/// positions in `A` of the values taken from `A`, and `ib` the positions in `B`
/// of the values that came only from `B`.
#[derive(Debug, Clone)]
pub struct UnionEvaluation {
    /// The union result (`C` in `[C, ia, ib] = union(A, B)`).
    values: Value,
    /// 1-based indices into `A`.
    ia: Tensor,
    /// 1-based indices into `B`.
    ib: Tensor,
}
1066
1067impl UnionEvaluation {
1068    fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
1069        Self { values, ia, ib }
1070    }
1071
1072    pub fn from_union_result(result: UnionResult) -> crate::BuiltinResult<Self> {
1073        let UnionResult { values, ia, ib } = result;
1074        let values_tensor = Tensor::new(values.data, values.shape)
1075            .map_err(|e| union_internal_error(format!("union: {e}")))?;
1076        let ia_tensor = Tensor::new(ia.data, ia.shape)
1077            .map_err(|e| union_internal_error(format!("union: {e}")))?;
1078        let ib_tensor = Tensor::new(ib.data, ib.shape)
1079            .map_err(|e| union_internal_error(format!("union: {e}")))?;
1080        Ok(UnionEvaluation::new(
1081            tensor::tensor_into_value(values_tensor),
1082            ia_tensor,
1083            ib_tensor,
1084        ))
1085    }
1086
1087    pub fn into_numeric_union_result(self) -> crate::BuiltinResult<UnionResult> {
1088        let UnionEvaluation { values, ia, ib } = self;
1089        let values_tensor =
1090            tensor::value_into_tensor_for("union", values).map_err(|e| union_internal_error(e))?;
1091        Ok(UnionResult {
1092            values: HostTensorOwned {
1093                data: values_tensor.data,
1094                shape: values_tensor.shape,
1095                storage: GpuTensorStorage::Real,
1096            },
1097            ia: HostTensorOwned {
1098                data: ia.data,
1099                shape: ia.shape,
1100                storage: GpuTensorStorage::Real,
1101            },
1102            ib: HostTensorOwned {
1103                data: ib.data,
1104                shape: ib.shape,
1105                storage: GpuTensorStorage::Real,
1106            },
1107        })
1108    }
1109
1110    pub fn into_values_value(self) -> Value {
1111        self.values
1112    }
1113
1114    pub fn into_pair(self) -> (Value, Value) {
1115        let ia = tensor::tensor_into_value(self.ia);
1116        (self.values, ia)
1117    }
1118
1119    pub fn into_triple(self) -> (Value, Value, Value) {
1120        let ia = tensor::tensor_into_value(self.ia);
1121        let ib = tensor::tensor_into_value(self.ib);
1122        (self.values, ia, ib)
1123    }
1124
1125    pub fn values_value(&self) -> Value {
1126        self.values.clone()
1127    }
1128
1129    pub fn ia_value(&self) -> Value {
1130        tensor::tensor_into_value(self.ia.clone())
1131    }
1132
1133    pub fn ib_value(&self) -> Value {
1134        tensor::tensor_into_value(self.ib.clone())
1135    }
1136}
1137
// Bookkeeping records, one per distinct value or row, built while scanning `A`
// and then `B`. In every record:
// - `a_*` / `b_*` hold 0-based positions (the assemble_* functions add 1 when
//   emitting `ia` / `ib`);
// - a present `a_*` position takes precedence over `b_*` when the index outputs
//   are built;
// - `order_rank` is the first-seen rank used as the sort key for `'stable'`
//   output.

/// Distinct real scalar for an element-wise numeric union.
#[derive(Debug)]
struct NumericUnionEntry {
    value: f64,
    a_index: Option<usize>,
    b_index: Option<usize>,
    order_rank: usize,
}

/// Distinct real row for a numeric `'rows'` union; `row_data` has one value per column.
#[derive(Debug)]
struct NumericRowUnionEntry {
    row_data: Vec<f64>,
    a_row: Option<usize>,
    b_row: Option<usize>,
    order_rank: usize,
}

/// Distinct complex scalar, stored as `(re, im)`, for an element-wise complex union.
#[derive(Debug)]
struct ComplexUnionEntry {
    value: (f64, f64),
    a_index: Option<usize>,
    b_index: Option<usize>,
    order_rank: usize,
}

/// Distinct complex row, stored as `(re, im)` per column, for a complex `'rows'` union.
#[derive(Debug)]
struct ComplexRowUnionEntry {
    row_data: Vec<(f64, f64)>,
    a_row: Option<usize>,
    b_row: Option<usize>,
    order_rank: usize,
}

/// Distinct character for an element-wise char union; the indices are
/// column-major linear positions.
#[derive(Debug)]
struct CharUnionEntry {
    ch: char,
    a_index: Option<usize>,
    b_index: Option<usize>,
    order_rank: usize,
}

/// Distinct character row for a char `'rows'` union.
#[derive(Debug)]
struct CharRowUnionEntry {
    row_data: Vec<char>,
    a_row: Option<usize>,
    b_row: Option<usize>,
    order_rank: usize,
}

/// Distinct string for an element-wise string-array union.
#[derive(Debug)]
struct StringUnionEntry {
    value: String,
    a_index: Option<usize>,
    b_index: Option<usize>,
    order_rank: usize,
}

/// Distinct string row for a string-array `'rows'` union.
#[derive(Debug)]
struct StringRowUnionEntry {
    row_data: Vec<String>,
    a_row: Option<usize>,
    b_row: Option<usize>,
    order_rank: usize,
}
1201
1202#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1203struct NumericRowKey(Vec<u64>);
1204
1205impl NumericRowKey {
1206    fn from_slice(values: &[f64]) -> Self {
1207        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1208    }
1209}
1210
1211#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1212struct ComplexKey {
1213    re: u64,
1214    im: u64,
1215}
1216
1217impl ComplexKey {
1218    fn new(value: (f64, f64)) -> Self {
1219        Self {
1220            re: canonicalize_f64(value.0),
1221            im: canonicalize_f64(value.1),
1222        }
1223    }
1224}
1225
/// Hashable identity of a character row: the Unicode scalar value of each char.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    fn from_slice(values: &[char]) -> Self {
        let code_points = values.iter().map(|&ch| u32::from(ch)).collect();
        RowCharKey(code_points)
    }
}
1234
/// Hashable identity of a string row: the row's strings compared by exact
/// `String` equality, position by position.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1237
1238fn assemble_numeric_union(
1239    entries: Vec<NumericUnionEntry>,
1240    opts: &UnionOptions,
1241) -> crate::BuiltinResult<UnionEvaluation> {
1242    let mut order: Vec<usize> = (0..entries.len()).collect();
1243    match opts.order {
1244        UnionOrder::Sorted => {
1245            order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
1246        }
1247        UnionOrder::Stable => {
1248            order.sort_by_key(|&idx| entries[idx].order_rank);
1249        }
1250    }
1251
1252    let mut values = Vec::with_capacity(order.len());
1253    let mut ia = Vec::new();
1254    let mut ib = Vec::new();
1255    for &idx in &order {
1256        let entry = &entries[idx];
1257        values.push(entry.value);
1258        if let Some(a_idx) = entry.a_index {
1259            ia.push((a_idx + 1) as f64);
1260        } else if let Some(b_idx) = entry.b_index {
1261            ib.push((b_idx + 1) as f64);
1262        }
1263    }
1264
1265    let value_tensor = Tensor::new(values, vec![order.len(), 1])
1266        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1267    let ia_len = ia.len();
1268    let ib_len = ib.len();
1269    let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1270        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1271    let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1272        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1273
1274    Ok(UnionEvaluation::new(
1275        tensor::tensor_into_value(value_tensor),
1276        ia_tensor,
1277        ib_tensor,
1278    ))
1279}
1280
1281fn assemble_numeric_row_union(
1282    entries: Vec<NumericRowUnionEntry>,
1283    opts: &UnionOptions,
1284    cols: usize,
1285) -> crate::BuiltinResult<UnionEvaluation> {
1286    let mut order: Vec<usize> = (0..entries.len()).collect();
1287    match opts.order {
1288        UnionOrder::Sorted => {
1289            order.sort_by(|&lhs, &rhs| {
1290                compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1291            });
1292        }
1293        UnionOrder::Stable => {
1294            order.sort_by_key(|&idx| entries[idx].order_rank);
1295        }
1296    }
1297
1298    let unique_rows = order.len();
1299    let mut values = vec![0.0f64; unique_rows * cols];
1300    let mut ia = Vec::new();
1301    let mut ib = Vec::new();
1302
1303    for (row_pos, &entry_idx) in order.iter().enumerate() {
1304        let entry = &entries[entry_idx];
1305        for col in 0..cols {
1306            let dest = row_pos + col * unique_rows;
1307            values[dest] = entry.row_data[col];
1308        }
1309        if let Some(a_row) = entry.a_row {
1310            ia.push((a_row + 1) as f64);
1311        } else if let Some(b_row) = entry.b_row {
1312            ib.push((b_row + 1) as f64);
1313        }
1314    }
1315
1316    let value_tensor = Tensor::new(values, vec![unique_rows, cols])
1317        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1318    let ia_len = ia.len();
1319    let ib_len = ib.len();
1320    let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1321        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1322    let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1323        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1324
1325    Ok(UnionEvaluation::new(
1326        tensor::tensor_into_value(value_tensor),
1327        ia_tensor,
1328        ib_tensor,
1329    ))
1330}
1331
1332fn assemble_complex_union(
1333    entries: Vec<ComplexUnionEntry>,
1334    opts: &UnionOptions,
1335) -> crate::BuiltinResult<UnionEvaluation> {
1336    let mut order: Vec<usize> = (0..entries.len()).collect();
1337    match opts.order {
1338        UnionOrder::Sorted => {
1339            order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
1340        }
1341        UnionOrder::Stable => {
1342            order.sort_by_key(|&idx| entries[idx].order_rank);
1343        }
1344    }
1345
1346    let mut values = Vec::with_capacity(order.len());
1347    let mut ia = Vec::new();
1348    let mut ib = Vec::new();
1349    for &idx in &order {
1350        let entry = &entries[idx];
1351        values.push(entry.value);
1352        if let Some(a_idx) = entry.a_index {
1353            ia.push((a_idx + 1) as f64);
1354        } else if let Some(b_idx) = entry.b_index {
1355            ib.push((b_idx + 1) as f64);
1356        }
1357    }
1358
1359    let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
1360        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1361    let ia_len = ia.len();
1362    let ib_len = ib.len();
1363    let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1364        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1365    let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1366        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1367
1368    Ok(UnionEvaluation::new(
1369        complex_tensor_into_value(value_tensor),
1370        ia_tensor,
1371        ib_tensor,
1372    ))
1373}
1374
1375fn assemble_complex_row_union(
1376    entries: Vec<ComplexRowUnionEntry>,
1377    opts: &UnionOptions,
1378    cols: usize,
1379) -> crate::BuiltinResult<UnionEvaluation> {
1380    let mut order: Vec<usize> = (0..entries.len()).collect();
1381    match opts.order {
1382        UnionOrder::Sorted => {
1383            order.sort_by(|&lhs, &rhs| {
1384                compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1385            });
1386        }
1387        UnionOrder::Stable => {
1388            order.sort_by_key(|&idx| entries[idx].order_rank);
1389        }
1390    }
1391
1392    let unique_rows = order.len();
1393    let mut values = vec![(0.0, 0.0); unique_rows * cols];
1394    let mut ia = Vec::new();
1395    let mut ib = Vec::new();
1396
1397    for (row_pos, &entry_idx) in order.iter().enumerate() {
1398        let entry = &entries[entry_idx];
1399        for col in 0..cols {
1400            let dest = row_pos + col * unique_rows;
1401            values[dest] = entry.row_data[col];
1402        }
1403        if let Some(a_row) = entry.a_row {
1404            ia.push((a_row + 1) as f64);
1405        } else if let Some(b_row) = entry.b_row {
1406            ib.push((b_row + 1) as f64);
1407        }
1408    }
1409
1410    let value_tensor = ComplexTensor::new(values, vec![unique_rows, cols])
1411        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1412    let ia_len = ia.len();
1413    let ib_len = ib.len();
1414    let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1415        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1416    let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1417        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1418
1419    Ok(UnionEvaluation::new(
1420        complex_tensor_into_value(value_tensor),
1421        ia_tensor,
1422        ib_tensor,
1423    ))
1424}
1425
1426fn assemble_char_union(
1427    entries: Vec<CharUnionEntry>,
1428    opts: &UnionOptions,
1429) -> crate::BuiltinResult<UnionEvaluation> {
1430    let mut order: Vec<usize> = (0..entries.len()).collect();
1431    match opts.order {
1432        UnionOrder::Sorted => {
1433            order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1434        }
1435        UnionOrder::Stable => {
1436            order.sort_by_key(|&idx| entries[idx].order_rank);
1437        }
1438    }
1439
1440    let mut values = Vec::with_capacity(order.len());
1441    let mut ia = Vec::new();
1442    let mut ib = Vec::new();
1443    for &idx in &order {
1444        let entry = &entries[idx];
1445        values.push(entry.ch);
1446        if let Some(a_idx) = entry.a_index {
1447            ia.push((a_idx + 1) as f64);
1448        } else if let Some(b_idx) = entry.b_index {
1449            ib.push((b_idx + 1) as f64);
1450        }
1451    }
1452
1453    let value_array = CharArray::new(values, order.len(), 1)
1454        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1455    let ia_len = ia.len();
1456    let ib_len = ib.len();
1457    let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1458        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1459    let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1460        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1461
1462    Ok(UnionEvaluation::new(
1463        Value::CharArray(value_array),
1464        ia_tensor,
1465        ib_tensor,
1466    ))
1467}
1468
1469fn assemble_char_row_union(
1470    entries: Vec<CharRowUnionEntry>,
1471    opts: &UnionOptions,
1472    cols: usize,
1473) -> crate::BuiltinResult<UnionEvaluation> {
1474    let mut order: Vec<usize> = (0..entries.len()).collect();
1475    match opts.order {
1476        UnionOrder::Sorted => {
1477            order.sort_by(|&lhs, &rhs| {
1478                compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1479            });
1480        }
1481        UnionOrder::Stable => {
1482            order.sort_by_key(|&idx| entries[idx].order_rank);
1483        }
1484    }
1485
1486    let unique_rows = order.len();
1487    let mut values = vec!['\0'; unique_rows * cols];
1488    let mut ia = Vec::new();
1489    let mut ib = Vec::new();
1490
1491    for (row_pos, &entry_idx) in order.iter().enumerate() {
1492        let entry = &entries[entry_idx];
1493        for col in 0..cols {
1494            let dest = row_pos * cols + col;
1495            values[dest] = entry.row_data[col];
1496        }
1497        if let Some(a_row) = entry.a_row {
1498            ia.push((a_row + 1) as f64);
1499        } else if let Some(b_row) = entry.b_row {
1500            ib.push((b_row + 1) as f64);
1501        }
1502    }
1503
1504    let value_array = CharArray::new(values, unique_rows, cols)
1505        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1506    let ia_len = ia.len();
1507    let ib_len = ib.len();
1508    let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1509        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1510    let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1511        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1512
1513    Ok(UnionEvaluation::new(
1514        Value::CharArray(value_array),
1515        ia_tensor,
1516        ib_tensor,
1517    ))
1518}
1519
1520fn assemble_string_union(
1521    entries: Vec<StringUnionEntry>,
1522    opts: &UnionOptions,
1523) -> crate::BuiltinResult<UnionEvaluation> {
1524    let mut order: Vec<usize> = (0..entries.len()).collect();
1525    match opts.order {
1526        UnionOrder::Sorted => {
1527            order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1528        }
1529        UnionOrder::Stable => {
1530            order.sort_by_key(|&idx| entries[idx].order_rank);
1531        }
1532    }
1533
1534    let mut values = Vec::with_capacity(order.len());
1535    let mut ia = Vec::new();
1536    let mut ib = Vec::new();
1537    for &idx in &order {
1538        let entry = &entries[idx];
1539        values.push(entry.value.clone());
1540        if let Some(a_idx) = entry.a_index {
1541            ia.push((a_idx + 1) as f64);
1542        } else if let Some(b_idx) = entry.b_index {
1543            ib.push((b_idx + 1) as f64);
1544        }
1545    }
1546
1547    let value_array = StringArray::new(values, vec![order.len(), 1])
1548        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1549    let ia_len = ia.len();
1550    let ib_len = ib.len();
1551    let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1552        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1553    let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1554        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1555
1556    Ok(UnionEvaluation::new(
1557        Value::StringArray(value_array),
1558        ia_tensor,
1559        ib_tensor,
1560    ))
1561}
1562
1563fn assemble_string_row_union(
1564    entries: Vec<StringRowUnionEntry>,
1565    opts: &UnionOptions,
1566    cols: usize,
1567) -> crate::BuiltinResult<UnionEvaluation> {
1568    let mut order: Vec<usize> = (0..entries.len()).collect();
1569    match opts.order {
1570        UnionOrder::Sorted => {
1571            order.sort_by(|&lhs, &rhs| {
1572                compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1573            });
1574        }
1575        UnionOrder::Stable => {
1576            order.sort_by_key(|&idx| entries[idx].order_rank);
1577        }
1578    }
1579
1580    let unique_rows = order.len();
1581    let mut values = vec![String::new(); unique_rows * cols];
1582    let mut ia = Vec::new();
1583    let mut ib = Vec::new();
1584
1585    for (row_pos, &entry_idx) in order.iter().enumerate() {
1586        let entry = &entries[entry_idx];
1587        for col in 0..cols {
1588            let dest = row_pos + col * unique_rows;
1589            values[dest] = entry.row_data[col].clone();
1590        }
1591        if let Some(a_row) = entry.a_row {
1592            ia.push((a_row + 1) as f64);
1593        } else if let Some(b_row) = entry.b_row {
1594            ib.push((b_row + 1) as f64);
1595        }
1596    }
1597
1598    let value_array = StringArray::new(values, vec![unique_rows, cols])
1599        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1600    let ia_len = ia.len();
1601    let ib_len = ib.len();
1602    let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1603        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1604    let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1605        .map_err(|e| union_internal_error(format!("union: {e}")))?;
1606
1607    Ok(UnionEvaluation::new(
1608        Value::StringArray(value_array),
1609        ia_tensor,
1610        ib_tensor,
1611    ))
1612}
1613
/// Maps an `f64` to a bit pattern suitable for hashing set members.
///
/// Every NaN (any payload or sign) maps to one quiet-NaN pattern, and `-0.0`
/// maps to the same key as `0.0`. All other values keep their exact bits.
/// NOTE(review): MATLAB's set functions document NaN values as distinct;
/// collapsing them here looks intentional (a test in this file expects one NaN
/// in the union) — confirm.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1623
/// Total ordering for reals used by sorted unions: ordinary numeric order,
/// with every NaN placed after all non-NaN values and NaNs equal to each other.
/// `-0.0` and `0.0` compare equal.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN, so partial_cmp always succeeds here.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1637
1638fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1639    for (lhs, rhs) in a.iter().zip(b.iter()) {
1640        let ord = compare_f64(*lhs, *rhs);
1641        if ord != Ordering::Equal {
1642            return ord;
1643        }
1644    }
1645    Ordering::Equal
1646}
1647
/// True when either the real or the imaginary part of `(re, im)` is NaN.
fn complex_is_nan((re, im): (f64, f64)) -> bool {
    [re, im].iter().any(|part| part.is_nan())
}
1651
1652fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1653    match (complex_is_nan(a), complex_is_nan(b)) {
1654        (true, true) => Ordering::Equal,
1655        (true, false) => Ordering::Greater,
1656        (false, true) => Ordering::Less,
1657        (false, false) => {
1658            let mag_a = a.0.hypot(a.1);
1659            let mag_b = b.0.hypot(b.1);
1660            let mag_cmp = compare_f64(mag_a, mag_b);
1661            if mag_cmp != Ordering::Equal {
1662                return mag_cmp;
1663            }
1664            let re_cmp = compare_f64(a.0, b.0);
1665            if re_cmp != Ordering::Equal {
1666                return re_cmp;
1667            }
1668            compare_f64(a.1, b.1)
1669        }
1670    }
1671}
1672
1673fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1674    for (lhs, rhs) in a.iter().zip(b.iter()) {
1675        let ord = compare_complex(*lhs, *rhs);
1676        if ord != Ordering::Equal {
1677            return ord;
1678        }
1679    }
1680    Ordering::Equal
1681}
1682
/// Lexicographic comparison of two character rows by code point. Only the
/// common prefix is compared.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|&ord| ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1692
/// Lexicographic comparison of two string rows using `String`'s `Ord` per
/// element. Only the common prefix is compared.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|&ord| ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1702
1703#[cfg(test)]
1704pub(crate) mod tests {
1705    use super::*;
1706    use crate::builtins::common::test_support;
1707    use runmat_accelerate_api::HostTensorView;
1708    use runmat_builtins::{IntValue, ResolveContext, Tensor, Type, Value};
1709
1710    fn evaluate_sync(a: Value, b: Value, rest: &[Value]) -> crate::BuiltinResult<UnionEvaluation> {
1711        futures::executor::block_on(evaluate(a, b, rest))
1712    }
1713
1714    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1715    #[test]
1716    fn union_numeric_sorted_default() {
1717        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
1718        let b = Tensor::new(vec![3.0, 1.0, 1.0], vec![3, 1]).unwrap();
1719        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
1720        match eval.values_value() {
1721            Value::Tensor(t) => {
1722                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 7.0]);
1723                assert_eq!(t.shape, vec![4, 1]);
1724            }
1725            other => panic!("expected tensor result, got {other:?}"),
1726        }
1727        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1728        assert_eq!(ia.data, vec![3.0, 1.0, 2.0]);
1729        assert_eq!(ia.shape, vec![3, 1]);
1730        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1731        assert_eq!(ib.data, vec![1.0]);
1732        assert_eq!(ib.shape, vec![1, 1]);
1733    }
1734
1735    #[test]
1736    fn union_type_resolver_numeric() {
1737        assert_eq!(
1738            set_values_output_type(
1739                &[Type::tensor(), Type::tensor()],
1740                &ResolveContext::new(Vec::new()),
1741            ),
1742            Type::tensor()
1743        );
1744    }
1745
1746    #[test]
1747    fn union_type_resolver_string_array() {
1748        assert_eq!(
1749            set_values_output_type(
1750                &[Type::cell_of(Type::String), Type::String],
1751                &ResolveContext::new(Vec::new()),
1752            ),
1753            Type::cell_of(Type::String)
1754        );
1755    }
1756
1757    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1758    #[test]
1759    fn union_numeric_stable_order() {
1760        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
1761        let b = Tensor::new(vec![3.0, 2.0, 4.0], vec![3, 1]).unwrap();
1762        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("stable")])
1763            .expect("union");
1764        match eval.values_value() {
1765            Value::Tensor(t) => {
1766                assert_eq!(t.data, vec![5.0, 7.0, 1.0, 3.0, 2.0, 4.0]);
1767                assert_eq!(t.shape, vec![6, 1]);
1768            }
1769            other => panic!("expected tensor result, got {other:?}"),
1770        }
1771        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1772        assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
1773        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1774        assert_eq!(ib.data, vec![1.0, 2.0, 3.0]);
1775    }
1776
1777    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1778    #[test]
1779    fn union_numeric_sorted_places_nan_last() {
1780        let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
1781        let b = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
1782        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
1783        let values = tensor::value_into_tensor_for("union", eval.values_value()).expect("values");
1784        assert_eq!(values.shape, vec![3, 1]);
1785        assert_eq!(values.data[0], 1.0);
1786        assert_eq!(values.data[1], 2.0);
1787        assert!(values.data[2].is_nan());
1788        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1789        assert_eq!(ia.data, vec![2.0, 1.0]);
1790        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1791        assert_eq!(ib.data, vec![1.0]);
1792    }
1793
1794    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1795    #[test]
1796    fn union_numeric_rows_sorted() {
1797        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
1798        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
1799        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")])
1800            .expect("union");
1801        match eval.values_value() {
1802            Value::Tensor(t) => {
1803                assert_eq!(t.shape, vec![3, 2]);
1804                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
1805            }
1806            other => panic!("expected tensor result, got {other:?}"),
1807        }
1808        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1809        assert_eq!(ia.data, vec![1.0, 2.0]);
1810        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1811        assert_eq!(ib.data, vec![2.0]);
1812    }
1813
1814    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1815    #[test]
1816    fn union_numeric_rows_stable_preserves_first_occurrence() {
1817        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
1818        let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
1819        let eval = evaluate_sync(
1820            Value::Tensor(a),
1821            Value::Tensor(b),
1822            &[Value::from("rows"), Value::from("stable")],
1823        )
1824        .expect("union");
1825        let (values, ia, ib) = eval.into_triple();
1826        match values {
1827            Value::Tensor(t) => {
1828                assert_eq!(t.shape, vec![3, 2]);
1829                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
1830            }
1831            other => panic!("expected tensor result, got {other:?}"),
1832        }
1833        let ia_tensor = tensor::value_into_tensor_for("union", ia).expect("ia tensor");
1834        assert_eq!(ia_tensor.data, vec![1.0, 2.0]);
1835        let ib_tensor = tensor::value_into_tensor_for("union", ib).expect("ib tensor");
1836        assert_eq!(ib_tensor.data, vec![2.0]);
1837    }
1838
1839    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1840    #[test]
1841    fn union_char_elements() {
1842        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
1843        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
1844        let eval = evaluate_sync(Value::CharArray(a), Value::CharArray(b), &[]).expect("union");
1845        match eval.values_value() {
1846            Value::CharArray(arr) => {
1847                assert_eq!(arr.rows, 4);
1848                assert_eq!(arr.cols, 1);
1849                assert_eq!(arr.data, vec!['a', 'm', 'x', 'z']);
1850            }
1851            other => panic!("expected char array, got {other:?}"),
1852        }
1853        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1854        assert_eq!(ia.data, vec![4.0, 1.0, 3.0]);
1855        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1856        assert_eq!(ib.data, vec![3.0]);
1857    }
1858
1859    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1860    #[test]
1861    fn union_string_rows_stable() {
1862        let a = StringArray::new(
1863            vec![
1864                "alpha".to_string(),
1865                "gamma".to_string(),
1866                "beta".to_string(),
1867                "beta".to_string(),
1868            ],
1869            vec![2, 2],
1870        )
1871        .unwrap();
1872        let b = StringArray::new(
1873            vec![
1874                "gamma".to_string(),
1875                "delta".to_string(),
1876                "beta".to_string(),
1877                "beta".to_string(),
1878            ],
1879            vec![2, 2],
1880        )
1881        .unwrap();
1882        let eval = evaluate_sync(
1883            Value::StringArray(a),
1884            Value::StringArray(b),
1885            &[Value::from("rows"), Value::from("stable")],
1886        )
1887        .expect("union");
1888        match eval.values_value() {
1889            Value::StringArray(arr) => {
1890                assert_eq!(arr.shape, vec![3, 2]);
1891                assert_eq!(
1892                    arr.data,
1893                    vec![
1894                        "alpha".to_string(),
1895                        "gamma".to_string(),
1896                        "delta".to_string(),
1897                        "beta".to_string(),
1898                        "beta".to_string(),
1899                        "beta".to_string()
1900                    ]
1901                );
1902            }
1903            other => panic!("expected string array, got {other:?}"),
1904        }
1905        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
1906        assert_eq!(ia.data, vec![1.0, 2.0]);
1907        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
1908        assert_eq!(ib.data, vec![2.0]);
1909    }
1910
1911    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1912    #[test]
1913    fn union_gpu_roundtrip() {
1914        test_support::with_test_provider(|provider| {
1915            let a = Tensor::new(vec![4.0, 1.0, 2.0], vec![3, 1]).unwrap();
1916            let b = Tensor::new(vec![2.0, 5.0], vec![2, 1]).unwrap();
1917            let view_a = HostTensorView {
1918                data: &a.data,
1919                shape: &a.shape,
1920            };
1921            let view_b = HostTensorView {
1922                data: &b.data,
1923                shape: &b.shape,
1924            };
1925            let handle_a = provider.upload(&view_a).expect("upload A");
1926            let handle_b = provider.upload(&view_b).expect("upload B");
1927            let eval = evaluate_sync(
1928                Value::GpuTensor(handle_a),
1929                Value::GpuTensor(handle_b),
1930                &[Value::from("stable")],
1931            )
1932            .expect("union");
1933            let values = tensor::value_into_tensor_for("union", eval.values_value()).unwrap();
1934            assert_eq!(values.data, vec![4.0, 1.0, 2.0, 5.0]);
1935            let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
1936            assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
1937            let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
1938            assert_eq!(ib.data, vec![2.0]);
1939        });
1940    }
1941
1942    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1943    #[test]
1944    fn union_rejects_legacy_option() {
1945        let tensor =
1946            Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).expect("tensor construction failed");
1947        let err = evaluate_sync(
1948            Value::Tensor(tensor.clone()),
1949            Value::Tensor(tensor),
1950            &[Value::from("legacy")],
1951        )
1952        .unwrap_err();
1953        assert_eq!(
1954            err.identifier(),
1955            UNION_ERROR_LEGACY_OPTION_UNSUPPORTED.identifier
1956        );
1957    }
1958
1959    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1960    #[test]
1961    fn union_rejects_conflicting_order_options() {
1962        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).expect("tensor construction failed");
1963        let err = evaluate_sync(
1964            Value::Tensor(tensor.clone()),
1965            Value::Tensor(tensor),
1966            &[Value::from("stable"), Value::from("sorted")],
1967        )
1968        .unwrap_err();
1969        assert_eq!(
1970            err.identifier(),
1971            UNION_ERROR_CONFLICTING_ORDER_OPTIONS.identifier
1972        );
1973    }
1974
1975    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1976    #[test]
1977    fn union_rejects_unknown_option() {
1978        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).expect("tensor construction failed");
1979        let err = evaluate_sync(
1980            Value::Tensor(tensor.clone()),
1981            Value::Tensor(tensor),
1982            &[Value::from("bogus")],
1983        )
1984        .unwrap_err();
1985        assert_eq!(err.identifier(), UNION_ERROR_UNKNOWN_OPTION.identifier);
1986    }
1987
1988    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1989    #[test]
1990    fn union_rows_dimension_mismatch() {
1991        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1992        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1993        let err =
1994            evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")]).unwrap_err();
1995        assert_eq!(
1996            err.identifier(),
1997            UNION_ERROR_ROWS_COLUMN_MISMATCH.identifier
1998        );
1999    }
2000
2001    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2002    #[test]
2003    fn union_requires_matching_types() {
2004        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
2005        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
2006        let err = union_host(
2007            Value::Tensor(a),
2008            Value::CharArray(b),
2009            &UnionOptions {
2010                rows: false,
2011                order: UnionOrder::Sorted,
2012            },
2013        )
2014        .unwrap_err();
2015        assert_eq!(
2016            err.identifier(),
2017            UNION_ERROR_UNSUPPORTED_INPUT_TYPE.identifier
2018        );
2019    }
2020
2021    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2022    #[test]
2023    fn union_accepts_scalar_inputs() {
2024        let eval =
2025            evaluate_sync(Value::Int(IntValue::I32(1)), Value::Num(3.0), &[]).expect("union");
2026        match eval.values_value() {
2027            Value::Tensor(t) => {
2028                assert_eq!(t.data, vec![1.0, 3.0]);
2029                assert_eq!(t.shape, vec![2, 1]);
2030            }
2031            other => panic!("expected numeric tensor, got {other:?}"),
2032        }
2033        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
2034        assert_eq!(ia.data, vec![1.0]);
2035        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
2036        assert_eq!(ib.data, vec![1.0]);
2037    }
2038
2039    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2040    #[test]
2041    #[cfg(feature = "wgpu")]
2042    fn union_wgpu_matches_cpu() {
2043        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2044            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2045        );
2046        let a = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
2047        let b = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();
2048
2049        let cpu_eval =
2050            evaluate_sync(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[]).expect("union");
2051        let cpu_values = tensor::value_into_tensor_for("union", cpu_eval.values_value()).unwrap();
2052        let cpu_ia = tensor::value_into_tensor_for("union", cpu_eval.ia_value()).unwrap();
2053        let cpu_ib = tensor::value_into_tensor_for("union", cpu_eval.ib_value()).unwrap();
2054
2055        let provider = runmat_accelerate_api::provider().expect("provider");
2056        let view_a = HostTensorView {
2057            data: &a.data,
2058            shape: &a.shape,
2059        };
2060        let view_b = HostTensorView {
2061            data: &b.data,
2062            shape: &b.shape,
2063        };
2064        let handle_a = provider.upload(&view_a).expect("upload A");
2065        let handle_b = provider.upload(&view_b).expect("upload B");
2066        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
2067            .expect("union");
2068        let gpu_values = tensor::value_into_tensor_for("union", gpu_eval.values_value()).unwrap();
2069        let gpu_ia = tensor::value_into_tensor_for("union", gpu_eval.ia_value()).unwrap();
2070        let gpu_ib = tensor::value_into_tensor_for("union", gpu_eval.ib_value()).unwrap();
2071
2072        assert_eq!(gpu_values.data, cpu_values.data);
2073        assert_eq!(gpu_values.shape, cpu_values.shape);
2074        assert_eq!(gpu_ia.data, cpu_ia.data);
2075        assert_eq!(gpu_ib.data, cpu_ib.data);
2076    }
2077}