1use std::cmp::Ordering;
9use std::collections::{hash_map::Entry, HashMap};
10
11use runmat_accelerate_api::{
12 GpuTensorHandle, GpuTensorStorage, HostTensorOwned, UnionOptions, UnionOrder, UnionResult,
13};
14use runmat_builtins::{
15 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
16 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
17 CharArray, ComplexTensor, StringArray, Tensor, Value,
18};
19use runmat_macros::runtime_builtin;
20
21use super::type_resolvers::set_values_output_type;
22use crate::build_runtime_error;
23use crate::builtins::common::arg_tokens::tokens_from_values;
24use crate::builtins::common::gpu_helpers;
25use crate::builtins::common::random_args::complex_tensor_into_value;
26use crate::builtins::common::spec::{
27 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
28 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
29};
30use crate::builtins::common::tensor;
31
/// GPU metadata registered for `union`.
///
/// Declares a custom `"union"` provider hook. Per the notes below, tensors are gathered and
/// processed on the host when a provider does not implement that hook (see `union_gpu_pair`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::union")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "union",
    op_kind: GpuOpKind::Custom("union"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("union")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Results are not kept device-resident.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may expose a dedicated union hook; otherwise tensors are gathered and processed on the host.",
};

/// Fusion metadata registered for `union`: it takes part in no elementwise or reduction
/// fusion and, per its notes, terminates fusion chains.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::sorting_sets::union")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "union",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`union` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
};

/// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "union";
60
61const UNION_OUTPUT_C: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
62 name: "C",
63 ty: BuiltinParamType::Any,
64 arity: BuiltinParamArity::Required,
65 default: None,
66 description: "Union values or rows.",
67}];
68
69const UNION_OUTPUT_C_IA: [BuiltinParamDescriptor; 2] = [
70 BuiltinParamDescriptor {
71 name: "C",
72 ty: BuiltinParamType::Any,
73 arity: BuiltinParamArity::Required,
74 default: None,
75 description: "Union values or rows.",
76 },
77 BuiltinParamDescriptor {
78 name: "ia",
79 ty: BuiltinParamType::NumericArray,
80 arity: BuiltinParamArity::Required,
81 default: None,
82 description: "Indices selecting contributions from A.",
83 },
84];
85
86const UNION_OUTPUT_C_IA_IB: [BuiltinParamDescriptor; 3] = [
87 BuiltinParamDescriptor {
88 name: "C",
89 ty: BuiltinParamType::Any,
90 arity: BuiltinParamArity::Required,
91 default: None,
92 description: "Union values or rows.",
93 },
94 BuiltinParamDescriptor {
95 name: "ia",
96 ty: BuiltinParamType::NumericArray,
97 arity: BuiltinParamArity::Required,
98 default: None,
99 description: "Indices selecting contributions from A.",
100 },
101 BuiltinParamDescriptor {
102 name: "ib",
103 ty: BuiltinParamType::NumericArray,
104 arity: BuiltinParamArity::Required,
105 default: None,
106 description: "Indices selecting contributions from B.",
107 },
108];
109
110const UNION_INPUTS_A_B: [BuiltinParamDescriptor; 2] = [
111 BuiltinParamDescriptor {
112 name: "A",
113 ty: BuiltinParamType::Any,
114 arity: BuiltinParamArity::Required,
115 default: None,
116 description: "First input array.",
117 },
118 BuiltinParamDescriptor {
119 name: "B",
120 ty: BuiltinParamType::Any,
121 arity: BuiltinParamArity::Required,
122 default: None,
123 description: "Second input array.",
124 },
125];
126
127const UNION_INPUTS_A_B_OPTIONS: [BuiltinParamDescriptor; 3] = [
128 BuiltinParamDescriptor {
129 name: "A",
130 ty: BuiltinParamType::Any,
131 arity: BuiltinParamArity::Required,
132 default: None,
133 description: "First input array.",
134 },
135 BuiltinParamDescriptor {
136 name: "B",
137 ty: BuiltinParamType::Any,
138 arity: BuiltinParamArity::Required,
139 default: None,
140 description: "Second input array.",
141 },
142 BuiltinParamDescriptor {
143 name: "option",
144 ty: BuiltinParamType::StringScalar,
145 arity: BuiltinParamArity::Variadic,
146 default: None,
147 description: "Option tokens: 'rows'|'sorted'|'stable'.",
148 },
149];
150
151const UNION_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
152 BuiltinSignatureDescriptor {
153 label: "C = union(A, B)",
154 inputs: &UNION_INPUTS_A_B,
155 outputs: &UNION_OUTPUT_C,
156 },
157 BuiltinSignatureDescriptor {
158 label: "C = union(A, B, option...)",
159 inputs: &UNION_INPUTS_A_B_OPTIONS,
160 outputs: &UNION_OUTPUT_C,
161 },
162 BuiltinSignatureDescriptor {
163 label: "[C, ia] = union(A, B)",
164 inputs: &UNION_INPUTS_A_B,
165 outputs: &UNION_OUTPUT_C_IA,
166 },
167 BuiltinSignatureDescriptor {
168 label: "[C, ia] = union(A, B, option...)",
169 inputs: &UNION_INPUTS_A_B_OPTIONS,
170 outputs: &UNION_OUTPUT_C_IA,
171 },
172 BuiltinSignatureDescriptor {
173 label: "[C, ia, ib] = union(A, B)",
174 inputs: &UNION_INPUTS_A_B,
175 outputs: &UNION_OUTPUT_C_IA_IB,
176 },
177 BuiltinSignatureDescriptor {
178 label: "[C, ia, ib] = union(A, B, option...)",
179 inputs: &UNION_INPUTS_A_B_OPTIONS,
180 outputs: &UNION_OUTPUT_C_IA_IB,
181 },
182];
183
// Error catalogue for `union`. In this file, `union_error_with` attaches only the
// `identifier` of a descriptor to the runtime error, and `union_error` also uses its
// `message`. The `code` and `when` fields are descriptive metadata.

/// Raised for the `'legacy'` / `'R2012a'` options.
const UNION_ERROR_LEGACY_OPTION_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.LEGACY_OPTION_UNSUPPORTED",
    identifier: Some("RunMat:union:LegacyOptionUnsupported"),
    when: "Legacy compatibility options are requested.",
    message: "union: the 'legacy' behaviour is not supported",
};

/// Raised when both `'sorted'` and `'stable'` are requested.
const UNION_ERROR_CONFLICTING_ORDER_OPTIONS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.CONFLICTING_ORDER_OPTIONS",
    identifier: Some("RunMat:union:ConflictingOrderOptions"),
    when: "Both 'sorted' and 'stable' options are provided.",
    message: "union: cannot combine 'sorted' with 'stable'",
};

/// Raised for any unrecognised option token. The emitted message names the offending token.
const UNION_ERROR_UNKNOWN_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.UNKNOWN_OPTION",
    identifier: Some("RunMat:union:UnknownOption"),
    when: "An unsupported option token is provided.",
    message: "union: unrecognised option",
};

/// Raised in `'rows'` mode when `A` and `B` have different column counts.
const UNION_ERROR_ROWS_COLUMN_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.ROWS_COLUMN_MISMATCH",
    identifier: Some("RunMat:union:RowsColumnMismatch"),
    when: "'rows' mode is used and column counts differ.",
    message: "union: inputs must have the same number of columns when using 'rows'",
};

/// Raised when an operand cannot be converted to a supported value domain.
const UNION_ERROR_UNSUPPORTED_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.UNSUPPORTED_INPUT_TYPE",
    identifier: Some("RunMat:union:UnsupportedInputType"),
    when: "Input values cannot be converted into supported union domains.",
    message: "union: unsupported input type",
};

/// Raised when an option argument cannot be read as text.
const UNION_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.INVALID_ARGUMENT",
    identifier: Some("RunMat:union:InvalidArgument"),
    when: "Option arguments are not string-like where required.",
    message: "union: expected string option arguments",
};

/// Catch-all identifier for internal failures (raised through `union_internal_error`).
const UNION_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.UNION.INTERNAL",
    identifier: Some("RunMat:union:Internal"),
    when: "Internal conversion/allocation/provider decode fails.",
    message: "union: internal operation failed",
};

/// All error descriptors advertised by `UNION_DESCRIPTOR`.
const UNION_ERRORS: [BuiltinErrorDescriptor; 7] = [
    UNION_ERROR_LEGACY_OPTION_UNSUPPORTED,
    UNION_ERROR_CONFLICTING_ORDER_OPTIONS,
    UNION_ERROR_UNKNOWN_OPTION,
    UNION_ERROR_ROWS_COLUMN_MISMATCH,
    UNION_ERROR_UNSUPPORTED_INPUT_TYPE,
    UNION_ERROR_INVALID_ARGUMENT,
    UNION_ERROR_INTERNAL,
];

/// Public descriptor for `union`, referenced by the `runtime_builtin` attribute on
/// `union_builtin`. The output shape follows the requested output count.
pub const UNION_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &UNION_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &UNION_ERRORS,
};
249
250fn union_error_with(
251 error: &'static BuiltinErrorDescriptor,
252 message: impl Into<String>,
253) -> crate::RuntimeError {
254 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
255 if let Some(identifier) = error.identifier {
256 builder = builder.with_identifier(identifier);
257 }
258 builder.build()
259}
260
261fn union_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
262 union_error_with(error, error.message)
263}
264
265fn union_internal_error(message: impl Into<String>) -> crate::RuntimeError {
266 union_error_with(&UNION_ERROR_INTERNAL, message)
267}
268
269#[runtime_builtin(
270 name = "union",
271 category = "array/sorting_sets",
272 summary = "Return unions of input arrays with ordering and index-output controls.",
273 keywords = "union,set,stable,rows,indices,gpu",
274 accel = "array_construct",
275 sink = true,
276 type_resolver(set_values_output_type),
277 descriptor(crate::builtins::array::sorting_sets::union::UNION_DESCRIPTOR),
278 builtin_path = "crate::builtins::array::sorting_sets::union"
279)]
280async fn union_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
281 let eval = evaluate(a, b, &rest).await?;
282 if let Some(out_count) = crate::output_count::current_output_count() {
283 if out_count == 0 {
284 return Ok(Value::OutputList(Vec::new()));
285 }
286 if out_count == 1 {
287 return Ok(Value::OutputList(vec![eval.into_values_value()]));
288 }
289 if out_count == 2 {
290 let (values, ia) = eval.into_pair();
291 return Ok(Value::OutputList(vec![values, ia]));
292 }
293 let (values, ia, ib) = eval.into_triple();
294 return Ok(crate::output_count::output_list_with_padding(
295 out_count,
296 vec![values, ia, ib],
297 ));
298 }
299 Ok(eval.into_values_value())
300}
301
302pub async fn evaluate(a: Value, b: Value, rest: &[Value]) -> crate::BuiltinResult<UnionEvaluation> {
304 let opts = parse_options(rest)?;
305 match (a, b) {
306 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
307 union_gpu_pair(handle_a, handle_b, &opts).await
308 }
309 (Value::GpuTensor(handle_a), other) => union_gpu_mixed(handle_a, other, &opts, true).await,
310 (other, Value::GpuTensor(handle_b)) => union_gpu_mixed(handle_b, other, &opts, false).await,
311 (left, right) => union_host(left, right, &opts),
312 }
313}
314
315fn parse_options(rest: &[Value]) -> crate::BuiltinResult<UnionOptions> {
316 let mut opts = UnionOptions {
317 rows: false,
318 order: UnionOrder::Sorted,
319 };
320 let mut seen_order: Option<UnionOrder> = None;
321
322 let tokens = tokens_from_values(rest);
323 for (arg, token) in rest.iter().zip(tokens.iter()) {
324 let text = match token {
325 crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
326 _ => {
327 let text = tensor::value_to_string(arg)
328 .ok_or_else(|| union_error(&UNION_ERROR_INVALID_ARGUMENT))?;
329 let lowered = text.trim().to_ascii_lowercase();
330 parse_union_option(&mut opts, &mut seen_order, &lowered)?;
331 continue;
332 }
333 };
334 parse_union_option(&mut opts, &mut seen_order, text)?;
335 }
336
337 Ok(opts)
338}
339
340fn parse_union_option(
341 opts: &mut UnionOptions,
342 seen_order: &mut Option<UnionOrder>,
343 lowered: &str,
344) -> crate::BuiltinResult<()> {
345 match lowered {
346 "rows" => opts.rows = true,
347 "sorted" => {
348 if let Some(prev) = seen_order {
349 if *prev != UnionOrder::Sorted {
350 return Err(union_error(&UNION_ERROR_CONFLICTING_ORDER_OPTIONS));
351 }
352 }
353 *seen_order = Some(UnionOrder::Sorted);
354 opts.order = UnionOrder::Sorted;
355 }
356 "stable" => {
357 if let Some(prev) = seen_order {
358 if *prev != UnionOrder::Stable {
359 return Err(union_error(&UNION_ERROR_CONFLICTING_ORDER_OPTIONS));
360 }
361 }
362 *seen_order = Some(UnionOrder::Stable);
363 opts.order = UnionOrder::Stable;
364 }
365 "legacy" | "r2012a" => {
366 return Err(union_error(&UNION_ERROR_LEGACY_OPTION_UNSUPPORTED));
367 }
368 other => {
369 return Err(union_error_with(
370 &UNION_ERROR_UNKNOWN_OPTION,
371 format!("union: unrecognised option '{other}'"),
372 ))
373 }
374 }
375 Ok(())
376}
377
378async fn union_gpu_pair(
379 handle_a: GpuTensorHandle,
380 handle_b: GpuTensorHandle,
381 opts: &UnionOptions,
382) -> crate::BuiltinResult<UnionEvaluation> {
383 if let Some(provider) = runmat_accelerate_api::provider() {
384 match provider.union(&handle_a, &handle_b, opts).await {
385 Ok(result) => return UnionEvaluation::from_union_result(result),
386 Err(_) => {
387 }
389 }
390 }
391 let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
392 let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
393 union_numeric(tensor_a, tensor_b, opts)
394}
395
396async fn union_gpu_mixed(
397 handle_gpu: GpuTensorHandle,
398 other: Value,
399 opts: &UnionOptions,
400 gpu_is_a: bool,
401) -> crate::BuiltinResult<UnionEvaluation> {
402 let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
403 let tensor_other =
404 tensor::value_into_tensor_for("union", other).map_err(|e| union_internal_error(e))?;
405 if gpu_is_a {
406 union_numeric(tensor_gpu, tensor_other, opts)
407 } else {
408 union_numeric(tensor_other, tensor_gpu, opts)
409 }
410}
411
412fn union_host(a: Value, b: Value, opts: &UnionOptions) -> crate::BuiltinResult<UnionEvaluation> {
413 match (a, b) {
414 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => union_complex(at, bt, opts),
416 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
417 let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
418 .map_err(|e| union_internal_error(format!("union: {e}")))?;
419 union_complex(at, bt, opts)
420 }
421 (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
422 let at = ComplexTensor::new(vec![(re, im)], vec![1, 1])
423 .map_err(|e| union_internal_error(format!("union: {e}")))?;
424 union_complex(at, bt, opts)
425 }
426 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
427 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
428 .map_err(|e| union_internal_error(format!("union: {e}")))?;
429 let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
430 .map_err(|e| union_internal_error(format!("union: {e}")))?;
431 union_complex(at, bt, opts)
432 }
433
434 (Value::CharArray(ac), Value::CharArray(bc)) => union_char(ac, bc, opts),
436
437 (Value::StringArray(astring), Value::StringArray(bstring)) => {
439 union_string(astring, bstring, opts)
440 }
441 (Value::StringArray(astring), Value::String(b)) => {
442 let bstring = StringArray::new(vec![b], vec![1, 1])
443 .map_err(|e| union_internal_error(format!("union: {e}")))?;
444 union_string(astring, bstring, opts)
445 }
446 (Value::String(a), Value::StringArray(bstring)) => {
447 let astring = StringArray::new(vec![a], vec![1, 1])
448 .map_err(|e| union_internal_error(format!("union: {e}")))?;
449 union_string(astring, bstring, opts)
450 }
451 (Value::String(a), Value::String(b)) => {
452 let astring = StringArray::new(vec![a], vec![1, 1])
453 .map_err(|e| union_internal_error(format!("union: {e}")))?;
454 let bstring = StringArray::new(vec![b], vec![1, 1])
455 .map_err(|e| union_internal_error(format!("union: {e}")))?;
456 union_string(astring, bstring, opts)
457 }
458
459 (left, right) => {
461 let tensor_a = tensor::value_into_tensor_for("union", left)
462 .map_err(|e| union_error_with(&UNION_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
463 let tensor_b = tensor::value_into_tensor_for("union", right)
464 .map_err(|e| union_error_with(&UNION_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
465 union_numeric(tensor_a, tensor_b, opts)
466 }
467 }
468}
469
470fn union_numeric(
471 a: Tensor,
472 b: Tensor,
473 opts: &UnionOptions,
474) -> crate::BuiltinResult<UnionEvaluation> {
475 if opts.rows {
476 union_numeric_rows(a, b, opts)
477 } else {
478 union_numeric_elements(a, b, opts)
479 }
480}
481
482pub fn union_numeric_from_tensors(
484 a: Tensor,
485 b: Tensor,
486 opts: &UnionOptions,
487) -> crate::BuiltinResult<UnionEvaluation> {
488 union_numeric(a, b, opts)
489}
490
491fn union_numeric_elements(
492 a: Tensor,
493 b: Tensor,
494 opts: &UnionOptions,
495) -> crate::BuiltinResult<UnionEvaluation> {
496 let mut entries = Vec::<NumericUnionEntry>::new();
497 let mut map: HashMap<u64, usize> = HashMap::new();
498 let mut order_counter = 0usize;
499
500 for (idx, &value) in a.data.iter().enumerate() {
501 let key = canonicalize_f64(value);
502 match map.entry(key) {
503 Entry::Occupied(_) => {
504 }
506 Entry::Vacant(v) => {
507 let entry_idx = entries.len();
508 entries.push(NumericUnionEntry {
509 value,
510 a_index: Some(idx),
511 b_index: None,
512 order_rank: order_counter,
513 });
514 v.insert(entry_idx);
515 order_counter += 1;
516 }
517 }
518 }
519
520 for (idx, &value) in b.data.iter().enumerate() {
521 let key = canonicalize_f64(value);
522 match map.entry(key) {
523 Entry::Occupied(occ) => {
524 let entry = &mut entries[*occ.get()];
525 if entry.a_index.is_none() && entry.b_index.is_none() {
526 entry.b_index = Some(idx);
527 }
528 }
529 Entry::Vacant(v) => {
530 let entry_idx = entries.len();
531 entries.push(NumericUnionEntry {
532 value,
533 a_index: None,
534 b_index: Some(idx),
535 order_rank: order_counter,
536 });
537 v.insert(entry_idx);
538 order_counter += 1;
539 }
540 }
541 }
542
543 assemble_numeric_union(entries, opts)
544}
545
546fn union_numeric_rows(
547 a: Tensor,
548 b: Tensor,
549 opts: &UnionOptions,
550) -> crate::BuiltinResult<UnionEvaluation> {
551 if a.shape.len() != 2 || b.shape.len() != 2 {
552 return Err(union_internal_error(
553 "union: 'rows' option requires 2-D numeric matrices",
554 ));
555 }
556 if a.shape[1] != b.shape[1] {
557 return Err(union_error_with(
558 &UNION_ERROR_ROWS_COLUMN_MISMATCH,
559 UNION_ERROR_ROWS_COLUMN_MISMATCH.message,
560 ));
561 }
562 let rows_a = a.shape[0];
563 let cols = a.shape[1];
564 let rows_b = b.shape[0];
565
566 let mut entries = Vec::<NumericRowUnionEntry>::new();
567 let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
568 let mut order_counter = 0usize;
569
570 for r in 0..rows_a {
571 let mut row_values = Vec::with_capacity(cols);
572 for c in 0..cols {
573 let idx = r + c * rows_a;
574 row_values.push(a.data[idx]);
575 }
576 let key = NumericRowKey::from_slice(&row_values);
577 match map.entry(key) {
578 Entry::Occupied(_) => {}
579 Entry::Vacant(v) => {
580 let entry_idx = entries.len();
581 entries.push(NumericRowUnionEntry {
582 row_data: row_values,
583 a_row: Some(r),
584 b_row: None,
585 order_rank: order_counter,
586 });
587 v.insert(entry_idx);
588 order_counter += 1;
589 }
590 }
591 }
592
593 for r in 0..rows_b {
594 let mut row_values = Vec::with_capacity(cols);
595 for c in 0..cols {
596 let idx = r + c * rows_b;
597 row_values.push(b.data[idx]);
598 }
599 let key = NumericRowKey::from_slice(&row_values);
600 match map.entry(key) {
601 Entry::Occupied(occ) => {
602 let entry = &mut entries[*occ.get()];
603 if entry.a_row.is_none() && entry.b_row.is_none() {
604 entry.b_row = Some(r);
605 }
606 }
607 Entry::Vacant(v) => {
608 let entry_idx = entries.len();
609 entries.push(NumericRowUnionEntry {
610 row_data: row_values,
611 a_row: None,
612 b_row: Some(r),
613 order_rank: order_counter,
614 });
615 v.insert(entry_idx);
616 order_counter += 1;
617 }
618 }
619 }
620
621 assemble_numeric_row_union(entries, opts, cols)
622}
623
624fn union_complex(
625 a: ComplexTensor,
626 b: ComplexTensor,
627 opts: &UnionOptions,
628) -> crate::BuiltinResult<UnionEvaluation> {
629 if opts.rows {
630 union_complex_rows(a, b, opts)
631 } else {
632 union_complex_elements(a, b, opts)
633 }
634}
635
636fn union_complex_elements(
637 a: ComplexTensor,
638 b: ComplexTensor,
639 opts: &UnionOptions,
640) -> crate::BuiltinResult<UnionEvaluation> {
641 let mut entries = Vec::<ComplexUnionEntry>::new();
642 let mut map: HashMap<ComplexKey, usize> = HashMap::new();
643 let mut order_counter = 0usize;
644
645 for (idx, &value) in a.data.iter().enumerate() {
646 let key = ComplexKey::new(value);
647 match map.entry(key) {
648 Entry::Occupied(_) => {}
649 Entry::Vacant(v) => {
650 let entry_idx = entries.len();
651 entries.push(ComplexUnionEntry {
652 value,
653 a_index: Some(idx),
654 b_index: None,
655 order_rank: order_counter,
656 });
657 v.insert(entry_idx);
658 order_counter += 1;
659 }
660 }
661 }
662
663 for (idx, &value) in b.data.iter().enumerate() {
664 let key = ComplexKey::new(value);
665 match map.entry(key) {
666 Entry::Occupied(occ) => {
667 let entry = &mut entries[*occ.get()];
668 if entry.a_index.is_none() && entry.b_index.is_none() {
669 entry.b_index = Some(idx);
670 }
671 }
672 Entry::Vacant(v) => {
673 let entry_idx = entries.len();
674 entries.push(ComplexUnionEntry {
675 value,
676 a_index: None,
677 b_index: Some(idx),
678 order_rank: order_counter,
679 });
680 v.insert(entry_idx);
681 order_counter += 1;
682 }
683 }
684 }
685
686 assemble_complex_union(entries, opts)
687}
688
689fn union_complex_rows(
690 a: ComplexTensor,
691 b: ComplexTensor,
692 opts: &UnionOptions,
693) -> crate::BuiltinResult<UnionEvaluation> {
694 if a.shape.len() != 2 || b.shape.len() != 2 {
695 return Err(union_internal_error(
696 "union: 'rows' option requires 2-D complex matrices",
697 ));
698 }
699 if a.shape[1] != b.shape[1] {
700 return Err(union_error_with(
701 &UNION_ERROR_ROWS_COLUMN_MISMATCH,
702 UNION_ERROR_ROWS_COLUMN_MISMATCH.message,
703 ));
704 }
705 let rows_a = a.shape[0];
706 let cols = a.shape[1];
707 let rows_b = b.shape[0];
708
709 let mut entries = Vec::<ComplexRowUnionEntry>::new();
710 let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
711 let mut order_counter = 0usize;
712
713 for r in 0..rows_a {
714 let mut row_values = Vec::with_capacity(cols);
715 let mut key_row = Vec::with_capacity(cols);
716 for c in 0..cols {
717 let idx = r + c * rows_a;
718 let value = a.data[idx];
719 row_values.push(value);
720 key_row.push(ComplexKey::new(value));
721 }
722 match map.entry(key_row) {
723 Entry::Occupied(_) => {}
724 Entry::Vacant(v) => {
725 let entry_idx = entries.len();
726 entries.push(ComplexRowUnionEntry {
727 row_data: row_values,
728 a_row: Some(r),
729 b_row: None,
730 order_rank: order_counter,
731 });
732 v.insert(entry_idx);
733 order_counter += 1;
734 }
735 }
736 }
737
738 for r in 0..rows_b {
739 let mut row_values = Vec::with_capacity(cols);
740 let mut key_row = Vec::with_capacity(cols);
741 for c in 0..cols {
742 let idx = r + c * rows_b;
743 let value = b.data[idx];
744 row_values.push(value);
745 key_row.push(ComplexKey::new(value));
746 }
747 match map.entry(key_row) {
748 Entry::Occupied(occ) => {
749 let entry = &mut entries[*occ.get()];
750 if entry.a_row.is_none() && entry.b_row.is_none() {
751 entry.b_row = Some(r);
752 }
753 }
754 Entry::Vacant(v) => {
755 let entry_idx = entries.len();
756 entries.push(ComplexRowUnionEntry {
757 row_data: row_values,
758 a_row: None,
759 b_row: Some(r),
760 order_rank: order_counter,
761 });
762 v.insert(entry_idx);
763 order_counter += 1;
764 }
765 }
766 }
767
768 assemble_complex_row_union(entries, opts, cols)
769}
770
771fn union_char(
772 a: CharArray,
773 b: CharArray,
774 opts: &UnionOptions,
775) -> crate::BuiltinResult<UnionEvaluation> {
776 if opts.rows {
777 union_char_rows(a, b, opts)
778 } else {
779 union_char_elements(a, b, opts)
780 }
781}
782
783fn union_char_elements(
784 a: CharArray,
785 b: CharArray,
786 opts: &UnionOptions,
787) -> crate::BuiltinResult<UnionEvaluation> {
788 let mut entries = Vec::<CharUnionEntry>::new();
789 let mut map: HashMap<u32, usize> = HashMap::new();
790 let mut order_counter = 0usize;
791
792 for col in 0..a.cols {
793 for row in 0..a.rows {
794 let linear_idx = row + col * a.rows;
795 let data_idx = row * a.cols + col;
796 let ch = a.data[data_idx];
797 let key = ch as u32;
798 match map.entry(key) {
799 Entry::Occupied(_) => {}
800 Entry::Vacant(v) => {
801 let entry_idx = entries.len();
802 entries.push(CharUnionEntry {
803 ch,
804 a_index: Some(linear_idx),
805 b_index: None,
806 order_rank: order_counter,
807 });
808 v.insert(entry_idx);
809 order_counter += 1;
810 }
811 }
812 }
813 }
814
815 for col in 0..b.cols {
816 for row in 0..b.rows {
817 let linear_idx = row + col * b.rows;
818 let data_idx = row * b.cols + col;
819 let ch = b.data[data_idx];
820 let key = ch as u32;
821 match map.entry(key) {
822 Entry::Occupied(occ) => {
823 let entry = &mut entries[*occ.get()];
824 if entry.a_index.is_none() && entry.b_index.is_none() {
825 entry.b_index = Some(linear_idx);
826 }
827 }
828 Entry::Vacant(v) => {
829 let entry_idx = entries.len();
830 entries.push(CharUnionEntry {
831 ch,
832 a_index: None,
833 b_index: Some(linear_idx),
834 order_rank: order_counter,
835 });
836 v.insert(entry_idx);
837 order_counter += 1;
838 }
839 }
840 }
841 }
842
843 assemble_char_union(entries, opts)
844}
845
846fn union_char_rows(
847 a: CharArray,
848 b: CharArray,
849 opts: &UnionOptions,
850) -> crate::BuiltinResult<UnionEvaluation> {
851 if a.cols != b.cols {
852 return Err(union_error_with(
853 &UNION_ERROR_ROWS_COLUMN_MISMATCH,
854 UNION_ERROR_ROWS_COLUMN_MISMATCH.message,
855 ));
856 }
857 let rows_a = a.rows;
858 let rows_b = b.rows;
859 let cols = a.cols;
860
861 let mut entries = Vec::<CharRowUnionEntry>::new();
862 let mut map: HashMap<RowCharKey, usize> = HashMap::new();
863 let mut order_counter = 0usize;
864
865 for r in 0..rows_a {
866 let mut row_values = Vec::with_capacity(cols);
867 for c in 0..cols {
868 let idx = r * cols + c;
869 row_values.push(a.data[idx]);
870 }
871 let key = RowCharKey::from_slice(&row_values);
872 match map.entry(key) {
873 Entry::Occupied(_) => {}
874 Entry::Vacant(v) => {
875 let entry_idx = entries.len();
876 entries.push(CharRowUnionEntry {
877 row_data: row_values,
878 a_row: Some(r),
879 b_row: None,
880 order_rank: order_counter,
881 });
882 v.insert(entry_idx);
883 order_counter += 1;
884 }
885 }
886 }
887
888 for r in 0..rows_b {
889 let mut row_values = Vec::with_capacity(cols);
890 for c in 0..cols {
891 let idx = r * cols + c;
892 row_values.push(b.data[idx]);
893 }
894 let key = RowCharKey::from_slice(&row_values);
895 match map.entry(key) {
896 Entry::Occupied(occ) => {
897 let entry = &mut entries[*occ.get()];
898 if entry.a_row.is_none() && entry.b_row.is_none() {
899 entry.b_row = Some(r);
900 }
901 }
902 Entry::Vacant(v) => {
903 let entry_idx = entries.len();
904 entries.push(CharRowUnionEntry {
905 row_data: row_values,
906 a_row: None,
907 b_row: Some(r),
908 order_rank: order_counter,
909 });
910 v.insert(entry_idx);
911 order_counter += 1;
912 }
913 }
914 }
915
916 assemble_char_row_union(entries, opts, cols)
917}
918
919fn union_string(
920 a: StringArray,
921 b: StringArray,
922 opts: &UnionOptions,
923) -> crate::BuiltinResult<UnionEvaluation> {
924 if opts.rows {
925 union_string_rows(a, b, opts)
926 } else {
927 union_string_elements(a, b, opts)
928 }
929}
930
931fn union_string_elements(
932 a: StringArray,
933 b: StringArray,
934 opts: &UnionOptions,
935) -> crate::BuiltinResult<UnionEvaluation> {
936 let mut entries = Vec::<StringUnionEntry>::new();
937 let mut map: HashMap<String, usize> = HashMap::new();
938 let mut order_counter = 0usize;
939
940 for (idx, value) in a.data.iter().enumerate() {
941 match map.entry(value.clone()) {
942 Entry::Occupied(_) => {}
943 Entry::Vacant(v) => {
944 let entry_idx = entries.len();
945 entries.push(StringUnionEntry {
946 value: value.clone(),
947 a_index: Some(idx),
948 b_index: None,
949 order_rank: order_counter,
950 });
951 v.insert(entry_idx);
952 order_counter += 1;
953 }
954 }
955 }
956
957 for (idx, value) in b.data.iter().enumerate() {
958 match map.entry(value.clone()) {
959 Entry::Occupied(occ) => {
960 let entry = &mut entries[*occ.get()];
961 if entry.a_index.is_none() && entry.b_index.is_none() {
962 entry.b_index = Some(idx);
963 }
964 }
965 Entry::Vacant(v) => {
966 let entry_idx = entries.len();
967 entries.push(StringUnionEntry {
968 value: value.clone(),
969 a_index: None,
970 b_index: Some(idx),
971 order_rank: order_counter,
972 });
973 v.insert(entry_idx);
974 order_counter += 1;
975 }
976 }
977 }
978
979 assemble_string_union(entries, opts)
980}
981
982fn union_string_rows(
983 a: StringArray,
984 b: StringArray,
985 opts: &UnionOptions,
986) -> crate::BuiltinResult<UnionEvaluation> {
987 if a.shape.len() != 2 || b.shape.len() != 2 {
988 return Err(union_internal_error(
989 "union: 'rows' option requires 2-D string arrays",
990 ));
991 }
992 if a.shape[1] != b.shape[1] {
993 return Err(union_error_with(
994 &UNION_ERROR_ROWS_COLUMN_MISMATCH,
995 UNION_ERROR_ROWS_COLUMN_MISMATCH.message,
996 ));
997 }
998 let rows_a = a.shape[0];
999 let cols = a.shape[1];
1000 let rows_b = b.shape[0];
1001
1002 let mut entries = Vec::<StringRowUnionEntry>::new();
1003 let mut map: HashMap<RowStringKey, usize> = HashMap::new();
1004 let mut order_counter = 0usize;
1005
1006 for r in 0..rows_a {
1007 let mut row_values = Vec::with_capacity(cols);
1008 for c in 0..cols {
1009 let idx = r + c * rows_a;
1010 row_values.push(a.data[idx].clone());
1011 }
1012 let key = RowStringKey(row_values.clone());
1013 match map.entry(key) {
1014 Entry::Occupied(_) => {}
1015 Entry::Vacant(v) => {
1016 let entry_idx = entries.len();
1017 entries.push(StringRowUnionEntry {
1018 row_data: row_values,
1019 a_row: Some(r),
1020 b_row: None,
1021 order_rank: order_counter,
1022 });
1023 v.insert(entry_idx);
1024 order_counter += 1;
1025 }
1026 }
1027 }
1028
1029 for r in 0..rows_b {
1030 let mut row_values = Vec::with_capacity(cols);
1031 for c in 0..cols {
1032 let idx = r + c * rows_b;
1033 row_values.push(b.data[idx].clone());
1034 }
1035 let key = RowStringKey(row_values.clone());
1036 match map.entry(key) {
1037 Entry::Occupied(occ) => {
1038 let entry = &mut entries[*occ.get()];
1039 if entry.a_row.is_none() && entry.b_row.is_none() {
1040 entry.b_row = Some(r);
1041 }
1042 }
1043 Entry::Vacant(v) => {
1044 let entry_idx = entries.len();
1045 entries.push(StringRowUnionEntry {
1046 row_data: row_values,
1047 a_row: None,
1048 b_row: Some(r),
1049 order_rank: order_counter,
1050 });
1051 v.insert(entry_idx);
1052 order_counter += 1;
1053 }
1054 }
1055 }
1056
1057 assemble_string_row_union(entries, opts, cols)
1058}
1059
/// Result of evaluating `union`: the union values plus the `ia` / `ib` index tensors.
///
/// The methods on this type convert the fields into runtime `Value`s for one, two, or
/// three outputs.
#[derive(Debug, Clone)]
pub struct UnionEvaluation {
    /// The union values `C`, already converted to a runtime value.
    values: Value,
    /// Indices selecting contributions from `A`.
    ia: Tensor,
    /// Indices selecting contributions from `B`.
    ib: Tensor,
}
1066
1067impl UnionEvaluation {
1068 fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
1069 Self { values, ia, ib }
1070 }
1071
1072 pub fn from_union_result(result: UnionResult) -> crate::BuiltinResult<Self> {
1073 let UnionResult { values, ia, ib } = result;
1074 let values_tensor = Tensor::new(values.data, values.shape)
1075 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1076 let ia_tensor = Tensor::new(ia.data, ia.shape)
1077 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1078 let ib_tensor = Tensor::new(ib.data, ib.shape)
1079 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1080 Ok(UnionEvaluation::new(
1081 tensor::tensor_into_value(values_tensor),
1082 ia_tensor,
1083 ib_tensor,
1084 ))
1085 }
1086
1087 pub fn into_numeric_union_result(self) -> crate::BuiltinResult<UnionResult> {
1088 let UnionEvaluation { values, ia, ib } = self;
1089 let values_tensor =
1090 tensor::value_into_tensor_for("union", values).map_err(|e| union_internal_error(e))?;
1091 Ok(UnionResult {
1092 values: HostTensorOwned {
1093 data: values_tensor.data,
1094 shape: values_tensor.shape,
1095 storage: GpuTensorStorage::Real,
1096 },
1097 ia: HostTensorOwned {
1098 data: ia.data,
1099 shape: ia.shape,
1100 storage: GpuTensorStorage::Real,
1101 },
1102 ib: HostTensorOwned {
1103 data: ib.data,
1104 shape: ib.shape,
1105 storage: GpuTensorStorage::Real,
1106 },
1107 })
1108 }
1109
1110 pub fn into_values_value(self) -> Value {
1111 self.values
1112 }
1113
1114 pub fn into_pair(self) -> (Value, Value) {
1115 let ia = tensor::tensor_into_value(self.ia);
1116 (self.values, ia)
1117 }
1118
1119 pub fn into_triple(self) -> (Value, Value, Value) {
1120 let ia = tensor::tensor_into_value(self.ia);
1121 let ib = tensor::tensor_into_value(self.ib);
1122 (self.values, ia, ib)
1123 }
1124
1125 pub fn values_value(&self) -> Value {
1126 self.values.clone()
1127 }
1128
1129 pub fn ia_value(&self) -> Value {
1130 tensor::tensor_into_value(self.ia.clone())
1131 }
1132
1133 pub fn ib_value(&self) -> Value {
1134 tensor::tensor_into_value(self.ib.clone())
1135 }
1136}
1137
// Deduplication records built while scanning `A` then `B`. Each distinct
// element (or row) gets one entry. The `*_index` / `*_row` fields are 0-based
// positions; the assemble helpers add 1 when they emit `ia` / `ib`, and they
// report the `B` position only when the entry has no `A` position.
// `order_rank` is an insertion counter used for `'stable'` ordering.

/// Distinct real scalar seen in a non-`rows` union.
#[derive(Debug)]
struct NumericUnionEntry {
    // Representative value emitted in the output.
    value: f64,
    // 0-based linear position recorded for this value in `A`, if any.
    a_index: Option<usize>,
    // 0-based linear position recorded for this value in `B`, if any.
    b_index: Option<usize>,
    // Insertion rank used for stable ordering.
    order_rank: usize,
}

/// Distinct real row seen in a `'rows'` union.
#[derive(Debug)]
struct NumericRowUnionEntry {
    // Row contents, one value per column.
    row_data: Vec<f64>,
    // 0-based row index recorded in `A`, if any.
    a_row: Option<usize>,
    // 0-based row index recorded in `B`, if any.
    b_row: Option<usize>,
    // Insertion rank used for stable ordering.
    order_rank: usize,
}

/// Distinct complex scalar, stored as `(re, im)`.
#[derive(Debug)]
struct ComplexUnionEntry {
    value: (f64, f64),
    a_index: Option<usize>,
    b_index: Option<usize>,
    order_rank: usize,
}

/// Distinct complex row, each element stored as `(re, im)`.
#[derive(Debug)]
struct ComplexRowUnionEntry {
    row_data: Vec<(f64, f64)>,
    a_row: Option<usize>,
    b_row: Option<usize>,
    order_rank: usize,
}

/// Distinct character in a non-`rows` char union.
#[derive(Debug)]
struct CharUnionEntry {
    ch: char,
    a_index: Option<usize>,
    b_index: Option<usize>,
    order_rank: usize,
}

/// Distinct character row in a `'rows'` char union.
#[derive(Debug)]
struct CharRowUnionEntry {
    row_data: Vec<char>,
    a_row: Option<usize>,
    b_row: Option<usize>,
    order_rank: usize,
}

/// Distinct string in a non-`rows` string union.
#[derive(Debug)]
struct StringUnionEntry {
    value: String,
    a_index: Option<usize>,
    b_index: Option<usize>,
    order_rank: usize,
}

/// Distinct string row in a `'rows'` string union.
#[derive(Debug)]
struct StringRowUnionEntry {
    row_data: Vec<String>,
    a_row: Option<usize>,
    b_row: Option<usize>,
    order_rank: usize,
}
1201
1202#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1203struct NumericRowKey(Vec<u64>);
1204
1205impl NumericRowKey {
1206 fn from_slice(values: &[f64]) -> Self {
1207 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1208 }
1209}
1210
1211#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1212struct ComplexKey {
1213 re: u64,
1214 im: u64,
1215}
1216
1217impl ComplexKey {
1218 fn new(value: (f64, f64)) -> Self {
1219 Self {
1220 re: canonicalize_f64(value.0),
1221 im: canonicalize_f64(value.1),
1222 }
1223 }
1224}
1225
/// Hashable identity of a character row, stored as Unicode scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key for `values`, one code point per character.
    fn from_slice(values: &[char]) -> Self {
        Self(values.iter().copied().map(u32::from).collect())
    }
}
1234
/// Hash-map key identifying a string row by its exact cell contents, compared
/// cell by cell.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1237
1238fn assemble_numeric_union(
1239 entries: Vec<NumericUnionEntry>,
1240 opts: &UnionOptions,
1241) -> crate::BuiltinResult<UnionEvaluation> {
1242 let mut order: Vec<usize> = (0..entries.len()).collect();
1243 match opts.order {
1244 UnionOrder::Sorted => {
1245 order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
1246 }
1247 UnionOrder::Stable => {
1248 order.sort_by_key(|&idx| entries[idx].order_rank);
1249 }
1250 }
1251
1252 let mut values = Vec::with_capacity(order.len());
1253 let mut ia = Vec::new();
1254 let mut ib = Vec::new();
1255 for &idx in &order {
1256 let entry = &entries[idx];
1257 values.push(entry.value);
1258 if let Some(a_idx) = entry.a_index {
1259 ia.push((a_idx + 1) as f64);
1260 } else if let Some(b_idx) = entry.b_index {
1261 ib.push((b_idx + 1) as f64);
1262 }
1263 }
1264
1265 let value_tensor = Tensor::new(values, vec![order.len(), 1])
1266 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1267 let ia_len = ia.len();
1268 let ib_len = ib.len();
1269 let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1270 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1271 let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1272 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1273
1274 Ok(UnionEvaluation::new(
1275 tensor::tensor_into_value(value_tensor),
1276 ia_tensor,
1277 ib_tensor,
1278 ))
1279}
1280
1281fn assemble_numeric_row_union(
1282 entries: Vec<NumericRowUnionEntry>,
1283 opts: &UnionOptions,
1284 cols: usize,
1285) -> crate::BuiltinResult<UnionEvaluation> {
1286 let mut order: Vec<usize> = (0..entries.len()).collect();
1287 match opts.order {
1288 UnionOrder::Sorted => {
1289 order.sort_by(|&lhs, &rhs| {
1290 compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1291 });
1292 }
1293 UnionOrder::Stable => {
1294 order.sort_by_key(|&idx| entries[idx].order_rank);
1295 }
1296 }
1297
1298 let unique_rows = order.len();
1299 let mut values = vec![0.0f64; unique_rows * cols];
1300 let mut ia = Vec::new();
1301 let mut ib = Vec::new();
1302
1303 for (row_pos, &entry_idx) in order.iter().enumerate() {
1304 let entry = &entries[entry_idx];
1305 for col in 0..cols {
1306 let dest = row_pos + col * unique_rows;
1307 values[dest] = entry.row_data[col];
1308 }
1309 if let Some(a_row) = entry.a_row {
1310 ia.push((a_row + 1) as f64);
1311 } else if let Some(b_row) = entry.b_row {
1312 ib.push((b_row + 1) as f64);
1313 }
1314 }
1315
1316 let value_tensor = Tensor::new(values, vec![unique_rows, cols])
1317 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1318 let ia_len = ia.len();
1319 let ib_len = ib.len();
1320 let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1321 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1322 let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1323 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1324
1325 Ok(UnionEvaluation::new(
1326 tensor::tensor_into_value(value_tensor),
1327 ia_tensor,
1328 ib_tensor,
1329 ))
1330}
1331
1332fn assemble_complex_union(
1333 entries: Vec<ComplexUnionEntry>,
1334 opts: &UnionOptions,
1335) -> crate::BuiltinResult<UnionEvaluation> {
1336 let mut order: Vec<usize> = (0..entries.len()).collect();
1337 match opts.order {
1338 UnionOrder::Sorted => {
1339 order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
1340 }
1341 UnionOrder::Stable => {
1342 order.sort_by_key(|&idx| entries[idx].order_rank);
1343 }
1344 }
1345
1346 let mut values = Vec::with_capacity(order.len());
1347 let mut ia = Vec::new();
1348 let mut ib = Vec::new();
1349 for &idx in &order {
1350 let entry = &entries[idx];
1351 values.push(entry.value);
1352 if let Some(a_idx) = entry.a_index {
1353 ia.push((a_idx + 1) as f64);
1354 } else if let Some(b_idx) = entry.b_index {
1355 ib.push((b_idx + 1) as f64);
1356 }
1357 }
1358
1359 let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
1360 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1361 let ia_len = ia.len();
1362 let ib_len = ib.len();
1363 let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1364 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1365 let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1366 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1367
1368 Ok(UnionEvaluation::new(
1369 complex_tensor_into_value(value_tensor),
1370 ia_tensor,
1371 ib_tensor,
1372 ))
1373}
1374
1375fn assemble_complex_row_union(
1376 entries: Vec<ComplexRowUnionEntry>,
1377 opts: &UnionOptions,
1378 cols: usize,
1379) -> crate::BuiltinResult<UnionEvaluation> {
1380 let mut order: Vec<usize> = (0..entries.len()).collect();
1381 match opts.order {
1382 UnionOrder::Sorted => {
1383 order.sort_by(|&lhs, &rhs| {
1384 compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1385 });
1386 }
1387 UnionOrder::Stable => {
1388 order.sort_by_key(|&idx| entries[idx].order_rank);
1389 }
1390 }
1391
1392 let unique_rows = order.len();
1393 let mut values = vec![(0.0, 0.0); unique_rows * cols];
1394 let mut ia = Vec::new();
1395 let mut ib = Vec::new();
1396
1397 for (row_pos, &entry_idx) in order.iter().enumerate() {
1398 let entry = &entries[entry_idx];
1399 for col in 0..cols {
1400 let dest = row_pos + col * unique_rows;
1401 values[dest] = entry.row_data[col];
1402 }
1403 if let Some(a_row) = entry.a_row {
1404 ia.push((a_row + 1) as f64);
1405 } else if let Some(b_row) = entry.b_row {
1406 ib.push((b_row + 1) as f64);
1407 }
1408 }
1409
1410 let value_tensor = ComplexTensor::new(values, vec![unique_rows, cols])
1411 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1412 let ia_len = ia.len();
1413 let ib_len = ib.len();
1414 let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1415 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1416 let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1417 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1418
1419 Ok(UnionEvaluation::new(
1420 complex_tensor_into_value(value_tensor),
1421 ia_tensor,
1422 ib_tensor,
1423 ))
1424}
1425
1426fn assemble_char_union(
1427 entries: Vec<CharUnionEntry>,
1428 opts: &UnionOptions,
1429) -> crate::BuiltinResult<UnionEvaluation> {
1430 let mut order: Vec<usize> = (0..entries.len()).collect();
1431 match opts.order {
1432 UnionOrder::Sorted => {
1433 order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1434 }
1435 UnionOrder::Stable => {
1436 order.sort_by_key(|&idx| entries[idx].order_rank);
1437 }
1438 }
1439
1440 let mut values = Vec::with_capacity(order.len());
1441 let mut ia = Vec::new();
1442 let mut ib = Vec::new();
1443 for &idx in &order {
1444 let entry = &entries[idx];
1445 values.push(entry.ch);
1446 if let Some(a_idx) = entry.a_index {
1447 ia.push((a_idx + 1) as f64);
1448 } else if let Some(b_idx) = entry.b_index {
1449 ib.push((b_idx + 1) as f64);
1450 }
1451 }
1452
1453 let value_array = CharArray::new(values, order.len(), 1)
1454 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1455 let ia_len = ia.len();
1456 let ib_len = ib.len();
1457 let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1458 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1459 let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1460 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1461
1462 Ok(UnionEvaluation::new(
1463 Value::CharArray(value_array),
1464 ia_tensor,
1465 ib_tensor,
1466 ))
1467}
1468
1469fn assemble_char_row_union(
1470 entries: Vec<CharRowUnionEntry>,
1471 opts: &UnionOptions,
1472 cols: usize,
1473) -> crate::BuiltinResult<UnionEvaluation> {
1474 let mut order: Vec<usize> = (0..entries.len()).collect();
1475 match opts.order {
1476 UnionOrder::Sorted => {
1477 order.sort_by(|&lhs, &rhs| {
1478 compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1479 });
1480 }
1481 UnionOrder::Stable => {
1482 order.sort_by_key(|&idx| entries[idx].order_rank);
1483 }
1484 }
1485
1486 let unique_rows = order.len();
1487 let mut values = vec!['\0'; unique_rows * cols];
1488 let mut ia = Vec::new();
1489 let mut ib = Vec::new();
1490
1491 for (row_pos, &entry_idx) in order.iter().enumerate() {
1492 let entry = &entries[entry_idx];
1493 for col in 0..cols {
1494 let dest = row_pos * cols + col;
1495 values[dest] = entry.row_data[col];
1496 }
1497 if let Some(a_row) = entry.a_row {
1498 ia.push((a_row + 1) as f64);
1499 } else if let Some(b_row) = entry.b_row {
1500 ib.push((b_row + 1) as f64);
1501 }
1502 }
1503
1504 let value_array = CharArray::new(values, unique_rows, cols)
1505 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1506 let ia_len = ia.len();
1507 let ib_len = ib.len();
1508 let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1509 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1510 let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1511 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1512
1513 Ok(UnionEvaluation::new(
1514 Value::CharArray(value_array),
1515 ia_tensor,
1516 ib_tensor,
1517 ))
1518}
1519
1520fn assemble_string_union(
1521 entries: Vec<StringUnionEntry>,
1522 opts: &UnionOptions,
1523) -> crate::BuiltinResult<UnionEvaluation> {
1524 let mut order: Vec<usize> = (0..entries.len()).collect();
1525 match opts.order {
1526 UnionOrder::Sorted => {
1527 order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1528 }
1529 UnionOrder::Stable => {
1530 order.sort_by_key(|&idx| entries[idx].order_rank);
1531 }
1532 }
1533
1534 let mut values = Vec::with_capacity(order.len());
1535 let mut ia = Vec::new();
1536 let mut ib = Vec::new();
1537 for &idx in &order {
1538 let entry = &entries[idx];
1539 values.push(entry.value.clone());
1540 if let Some(a_idx) = entry.a_index {
1541 ia.push((a_idx + 1) as f64);
1542 } else if let Some(b_idx) = entry.b_index {
1543 ib.push((b_idx + 1) as f64);
1544 }
1545 }
1546
1547 let value_array = StringArray::new(values, vec![order.len(), 1])
1548 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1549 let ia_len = ia.len();
1550 let ib_len = ib.len();
1551 let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1552 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1553 let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1554 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1555
1556 Ok(UnionEvaluation::new(
1557 Value::StringArray(value_array),
1558 ia_tensor,
1559 ib_tensor,
1560 ))
1561}
1562
1563fn assemble_string_row_union(
1564 entries: Vec<StringRowUnionEntry>,
1565 opts: &UnionOptions,
1566 cols: usize,
1567) -> crate::BuiltinResult<UnionEvaluation> {
1568 let mut order: Vec<usize> = (0..entries.len()).collect();
1569 match opts.order {
1570 UnionOrder::Sorted => {
1571 order.sort_by(|&lhs, &rhs| {
1572 compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1573 });
1574 }
1575 UnionOrder::Stable => {
1576 order.sort_by_key(|&idx| entries[idx].order_rank);
1577 }
1578 }
1579
1580 let unique_rows = order.len();
1581 let mut values = vec![String::new(); unique_rows * cols];
1582 let mut ia = Vec::new();
1583 let mut ib = Vec::new();
1584
1585 for (row_pos, &entry_idx) in order.iter().enumerate() {
1586 let entry = &entries[entry_idx];
1587 for col in 0..cols {
1588 let dest = row_pos + col * unique_rows;
1589 values[dest] = entry.row_data[col].clone();
1590 }
1591 if let Some(a_row) = entry.a_row {
1592 ia.push((a_row + 1) as f64);
1593 } else if let Some(b_row) = entry.b_row {
1594 ib.push((b_row + 1) as f64);
1595 }
1596 }
1597
1598 let value_array = StringArray::new(values, vec![unique_rows, cols])
1599 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1600 let ia_len = ia.len();
1601 let ib_len = ib.len();
1602 let ia_tensor = Tensor::new(ia, vec![ia_len, 1])
1603 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1604 let ib_tensor = Tensor::new(ib, vec![ib_len, 1])
1605 .map_err(|e| union_internal_error(format!("union: {e}")))?;
1606
1607 Ok(UnionEvaluation::new(
1608 Value::StringArray(value_array),
1609 ia_tensor,
1610 ib_tensor,
1611 ))
1612}
1613
/// Maps an `f64` to the bit pattern used for hashing/equality in set keys.
///
/// Every NaN collapses to `0x7ff8_0000_0000_0000` and both zeros collapse to
/// `0`, so `-0.0`/`0.0` and all NaN payloads deduplicate together. Any other
/// value keeps its raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1623
/// Ordering for real values in sorted `union` output.
///
/// Non-NaN values compare numerically (so `-0.0` and `0.0` are equal). Every
/// NaN compares equal to every other NaN and greater than any non-NaN value,
/// so NaNs sort last.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither operand is NaN, so `partial_cmp` always returns `Some`.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}

/// Lexicographic comparison of two real rows using `compare_f64`.
///
/// Only the common prefix is compared; row lengths are not considered.
fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| compare_f64(lhs, rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}

/// True when either the real or the imaginary part of `(re, im)` is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    value.0.is_nan() || value.1.is_nan()
}

/// Orders complex values the way MATLAB's `sort` does: by magnitude, then by
/// phase angle on the interval `(-π, π]`.
///
/// Values with a NaN component sort last and compare equal to each other.
/// Real, then imaginary, parts act as a final tie-breaker so the order stays
/// deterministic if rounding gives distinct values the same magnitude and
/// angle.
fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
    // `atan2(-0.0, x)` is `-π` for negative `x`, which falls outside `(-π, π]`.
    // Normalise the sign of a zero imaginary part; the set keys already treat
    // `-0.0` and `0.0` as the same value.
    let phase = |(re, im): (f64, f64)| {
        let im = if im == 0.0 { 0.0 } else { im };
        im.atan2(re)
    };
    match (complex_is_nan(a), complex_is_nan(b)) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => compare_f64(a.0.hypot(a.1), b.0.hypot(b.1))
            .then_with(|| compare_f64(phase(a), phase(b)))
            .then_with(|| compare_f64(a.0, b.0))
            .then_with(|| compare_f64(a.1, b.1)),
    }
}
1672
1673fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1674 for (lhs, rhs) in a.iter().zip(b.iter()) {
1675 let ord = compare_complex(*lhs, *rhs);
1676 if ord != Ordering::Equal {
1677 return ord;
1678 }
1679 }
1680 Ordering::Equal
1681}
1682
/// Lexicographic comparison of two character rows by code point.
///
/// Only the common prefix is compared; row lengths are not considered.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1692
/// Lexicographic comparison of two string rows, cell by cell.
///
/// Only the common prefix is compared; row lengths are not considered.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1702
// Unit tests for the `union` builtin: value ordering, index outputs (`ia` /
// `ib`), option parsing errors, and GPU/provider round trips.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, ResolveContext, Tensor, Type, Value};

    // Drives the async `evaluate` entry point to completion on the current thread.
    fn evaluate_sync(a: Value, b: Value, rest: &[Value]) -> crate::BuiltinResult<UnionEvaluation> {
        futures::executor::block_on(evaluate(a, b, rest))
    }

    // Default (sorted) union of column vectors; duplicates collapse, and ia/ib
    // list 1-based source positions in output order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_numeric_sorted_default() {
        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 1.0, 1.0], vec![3, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 7.0]);
                assert_eq!(t.shape, vec![4, 1]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        // 1 -> A(3), 5 -> A(1), 7 -> A(2); only 3 comes from B, at B(1).
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![3.0, 1.0, 2.0]);
        assert_eq!(ia.shape, vec![3, 1]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![1.0]);
        assert_eq!(ib.shape, vec![1, 1]);
    }

    // The type resolver maps numeric inputs to a numeric tensor output.
    #[test]
    fn union_type_resolver_numeric() {
        assert_eq!(
            set_values_output_type(
                &[Type::tensor(), Type::tensor()],
                &ResolveContext::new(Vec::new()),
            ),
            Type::tensor()
        );
    }

    // The type resolver keeps a cell-of-string output for string inputs.
    #[test]
    fn union_type_resolver_string_array() {
        assert_eq!(
            set_values_output_type(
                &[Type::cell_of(Type::String), Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::cell_of(Type::String)
        );
    }

    // 'stable' keeps first-appearance order: all of A, then new values from B.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_numeric_stable_order() {
        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 2.0, 4.0], vec![3, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("stable")])
            .expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![5.0, 7.0, 1.0, 3.0, 2.0, 4.0]);
                assert_eq!(t.shape, vec![6, 1]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![1.0, 2.0, 3.0]);
    }

    // NaNs from A and B collapse to a single output NaN that sorts last; it is
    // attributed to A, so B contributes only the 2.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_numeric_sorted_places_nan_last() {
        let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
        let b = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
        let values = tensor::value_into_tensor_for("union", eval.values_value()).expect("values");
        assert_eq!(values.shape, vec![3, 1]);
        assert_eq!(values.data[0], 1.0);
        assert_eq!(values.data[1], 2.0);
        assert!(values.data[2].is_nan());
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0, 1.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![1.0]);
    }

    // 'rows' on column-major data: A has rows (1,2),(3,4),(1,2); B has rows
    // (3,4),(5,6). The duplicate (1,2) resolves to A's first occurrence.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_numeric_rows_sorted() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")])
            .expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0, 2.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![2.0]);
    }

    // 'rows','stable': rows keep first-appearance order and B's rows already in
    // A are not reported in ib.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_numeric_rows_stable_preserves_first_occurrence() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
        let eval = evaluate_sync(
            Value::Tensor(a),
            Value::Tensor(b),
            &[Value::from("rows"), Value::from("stable")],
        )
        .expect("union");
        let (values, ia, ib) = eval.into_triple();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia_tensor = tensor::value_into_tensor_for("union", ia).expect("ia tensor");
        assert_eq!(ia_tensor.data, vec![1.0, 2.0]);
        let ib_tensor = tensor::value_into_tensor_for("union", ib).expect("ib tensor");
        assert_eq!(ib_tensor.data, vec![2.0]);
    }

    // Char matrices are flattened into a sorted char column. The expected
    // ia/ib imply `CharArray::new` takes row-major data while indices are
    // reported in column-major (linear) order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_char_elements() {
        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
        let eval = evaluate_sync(Value::CharArray(a), Value::CharArray(b), &[]).expect("union");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 4);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['a', 'm', 'x', 'z']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![4.0, 1.0, 3.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![3.0]);
    }

    // String 'rows','stable' on column-major 2x2 arrays: A rows are
    // (alpha,beta),(gamma,beta); B rows are (gamma,beta),(delta,beta).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_string_rows_stable() {
        let a = StringArray::new(
            vec![
                "alpha".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "gamma".to_string(),
                "delta".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let eval = evaluate_sync(
            Value::StringArray(a),
            Value::StringArray(b),
            &[Value::from("rows"), Value::from("stable")],
        )
        .expect("union");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![3, 2]);
                assert_eq!(
                    arr.data,
                    vec![
                        "alpha".to_string(),
                        "gamma".to_string(),
                        "delta".to_string(),
                        "beta".to_string(),
                        "beta".to_string(),
                        "beta".to_string()
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0, 2.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![2.0]);
    }

    // GPU-resident inputs (via the test provider) produce the same stable
    // union as host inputs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let a = Tensor::new(vec![4.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let b = Tensor::new(vec![2.0, 5.0], vec![2, 1]).unwrap();
            let view_a = HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let view_b = HostTensorView {
                data: &b.data,
                shape: &b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload A");
            let handle_b = provider.upload(&view_b).expect("upload B");
            let eval = evaluate_sync(
                Value::GpuTensor(handle_a),
                Value::GpuTensor(handle_b),
                &[Value::from("stable")],
            )
            .expect("union");
            let values = tensor::value_into_tensor_for("union", eval.values_value()).unwrap();
            assert_eq!(values.data, vec![4.0, 1.0, 2.0, 5.0]);
            let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
            assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
            let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
            assert_eq!(ib.data, vec![2.0]);
        });
    }

    // The 'legacy' flag is rejected with a dedicated error identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_rejects_legacy_option() {
        let tensor =
            Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).expect("tensor construction failed");
        let err = evaluate_sync(
            Value::Tensor(tensor.clone()),
            Value::Tensor(tensor),
            &[Value::from("legacy")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            UNION_ERROR_LEGACY_OPTION_UNSUPPORTED.identifier
        );
    }

    // Passing both 'stable' and 'sorted' is an error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_rejects_conflicting_order_options() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).expect("tensor construction failed");
        let err = evaluate_sync(
            Value::Tensor(tensor.clone()),
            Value::Tensor(tensor),
            &[Value::from("stable"), Value::from("sorted")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            UNION_ERROR_CONFLICTING_ORDER_OPTIONS.identifier
        );
    }

    // Unrecognised option strings are rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_rejects_unknown_option() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).expect("tensor construction failed");
        let err = evaluate_sync(
            Value::Tensor(tensor.clone()),
            Value::Tensor(tensor),
            &[Value::from("bogus")],
        )
        .unwrap_err();
        assert_eq!(err.identifier(), UNION_ERROR_UNKNOWN_OPTION.identifier);
    }

    // 'rows' requires both inputs to have the same number of columns.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_rows_dimension_mismatch() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err =
            evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")]).unwrap_err();
        assert_eq!(
            err.identifier(),
            UNION_ERROR_ROWS_COLUMN_MISMATCH.identifier
        );
    }

    // Mixing numeric and char inputs on the host path is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_requires_matching_types() {
        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
        let err = union_host(
            Value::Tensor(a),
            Value::CharArray(b),
            &UnionOptions {
                rows: false,
                order: UnionOrder::Sorted,
            },
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            UNION_ERROR_UNSUPPORTED_INPUT_TYPE.identifier
        );
    }

    // Integer and double scalars are accepted and combine into a numeric column.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_accepts_scalar_inputs() {
        let eval =
            evaluate_sync(Value::Int(IntValue::I32(1)), Value::Num(3.0), &[]).expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 3.0]);
                assert_eq!(t.shape, vec![2, 1]);
            }
            other => panic!("expected numeric tensor, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
        assert_eq!(ia.data, vec![1.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
        assert_eq!(ib.data, vec![1.0]);
    }

    // With the wgpu backend enabled, GPU results must match the host path
    // exactly (values, shape, ia, ib).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn union_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();

        let cpu_eval =
            evaluate_sync(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[]).expect("union");
        let cpu_values = tensor::value_into_tensor_for("union", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("union", cpu_eval.ia_value()).unwrap();
        let cpu_ib = tensor::value_into_tensor_for("union", cpu_eval.ib_value()).unwrap();

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
            .expect("union");
        let gpu_values = tensor::value_into_tensor_for("union", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("union", gpu_eval.ia_value()).unwrap();
        let gpu_ib = tensor::value_into_tensor_for("union", gpu_eval.ib_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_values.shape, cpu_values.shape);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ib.data, cpu_ib.data);
    }
}