1use std::cmp::Ordering;
9use std::collections::{hash_map::Entry, HashMap};
10
11use runmat_accelerate_api::{
12 GpuTensorHandle, GpuTensorStorage, HostTensorOwned, UnionOptions, UnionOrder, UnionResult,
13};
14use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
15use runmat_macros::runtime_builtin;
16
17use super::type_resolvers::set_values_output_type;
18use crate::build_runtime_error;
19use crate::builtins::common::arg_tokens::tokens_from_values;
20use crate::builtins::common::gpu_helpers;
21use crate::builtins::common::random_args::complex_tensor_into_value;
22use crate::builtins::common::spec::{
23 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
24 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
25};
26use crate::builtins::common::tensor;
27
/// GPU dispatch metadata for the `union` builtin.
///
/// Declares `union` as a custom provider operation (hook name `"union"`) over `f32`/`f64`
/// data with no broadcasting. As the `notes` field states, tensors are gathered and
/// processed on the host when the provider exposes no dedicated hook.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::union")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "union",
    op_kind: GpuOpKind::Custom("union"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("union")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may expose a dedicated union hook; otherwise tensors are gathered and processed on the host.",
};
43
/// Fusion metadata for the `union` builtin.
///
/// No elementwise or reduction fusion template is provided; per the `notes` field,
/// `union` terminates fusion chains and materialises its result on the host.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::sorting_sets::union")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "union",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`union` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
};
54
55fn union_error(message: impl Into<String>) -> crate::RuntimeError {
56 build_runtime_error(message).with_builtin("union").build()
57}
58
59#[runtime_builtin(
60 name = "union",
61 category = "array/sorting_sets",
62 summary = "Combine two arrays, returning their union with MATLAB-compatible ordering and index outputs.",
63 keywords = "union,set,stable,rows,indices,gpu",
64 accel = "array_construct",
65 sink = true,
66 type_resolver(set_values_output_type),
67 builtin_path = "crate::builtins::array::sorting_sets::union"
68)]
69async fn union_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
70 let eval = evaluate(a, b, &rest).await?;
71 if let Some(out_count) = crate::output_count::current_output_count() {
72 if out_count == 0 {
73 return Ok(Value::OutputList(Vec::new()));
74 }
75 if out_count == 1 {
76 return Ok(Value::OutputList(vec![eval.into_values_value()]));
77 }
78 if out_count == 2 {
79 let (values, ia) = eval.into_pair();
80 return Ok(Value::OutputList(vec![values, ia]));
81 }
82 let (values, ia, ib) = eval.into_triple();
83 return Ok(crate::output_count::output_list_with_padding(
84 out_count,
85 vec![values, ia, ib],
86 ));
87 }
88 Ok(eval.into_values_value())
89}
90
91pub async fn evaluate(a: Value, b: Value, rest: &[Value]) -> crate::BuiltinResult<UnionEvaluation> {
93 let opts = parse_options(rest)?;
94 match (a, b) {
95 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
96 union_gpu_pair(handle_a, handle_b, &opts).await
97 }
98 (Value::GpuTensor(handle_a), other) => union_gpu_mixed(handle_a, other, &opts, true).await,
99 (other, Value::GpuTensor(handle_b)) => union_gpu_mixed(handle_b, other, &opts, false).await,
100 (left, right) => union_host(left, right, &opts),
101 }
102}
103
104fn parse_options(rest: &[Value]) -> crate::BuiltinResult<UnionOptions> {
105 let mut opts = UnionOptions {
106 rows: false,
107 order: UnionOrder::Sorted,
108 };
109 let mut seen_order: Option<UnionOrder> = None;
110
111 let tokens = tokens_from_values(rest);
112 for (arg, token) in rest.iter().zip(tokens.iter()) {
113 let text = match token {
114 crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
115 _ => {
116 let text = tensor::value_to_string(arg)
117 .ok_or_else(|| union_error("union: expected string option arguments"))?;
118 let lowered = text.trim().to_ascii_lowercase();
119 parse_union_option(&mut opts, &mut seen_order, &lowered)?;
120 continue;
121 }
122 };
123 parse_union_option(&mut opts, &mut seen_order, text)?;
124 }
125
126 Ok(opts)
127}
128
129fn parse_union_option(
130 opts: &mut UnionOptions,
131 seen_order: &mut Option<UnionOrder>,
132 lowered: &str,
133) -> crate::BuiltinResult<()> {
134 match lowered {
135 "rows" => opts.rows = true,
136 "sorted" => {
137 if let Some(prev) = seen_order {
138 if *prev != UnionOrder::Sorted {
139 return Err(union_error("union: cannot combine 'sorted' with 'stable'"));
140 }
141 }
142 *seen_order = Some(UnionOrder::Sorted);
143 opts.order = UnionOrder::Sorted;
144 }
145 "stable" => {
146 if let Some(prev) = seen_order {
147 if *prev != UnionOrder::Stable {
148 return Err(union_error("union: cannot combine 'sorted' with 'stable'"));
149 }
150 }
151 *seen_order = Some(UnionOrder::Stable);
152 opts.order = UnionOrder::Stable;
153 }
154 "legacy" | "r2012a" => {
155 return Err(union_error(
156 "union: the 'legacy' behaviour is not supported",
157 ));
158 }
159 other => return Err(union_error(format!("union: unrecognised option '{other}'"))),
160 }
161 Ok(())
162}
163
164async fn union_gpu_pair(
165 handle_a: GpuTensorHandle,
166 handle_b: GpuTensorHandle,
167 opts: &UnionOptions,
168) -> crate::BuiltinResult<UnionEvaluation> {
169 if let Some(provider) = runmat_accelerate_api::provider() {
170 match provider.union(&handle_a, &handle_b, opts).await {
171 Ok(result) => return UnionEvaluation::from_union_result(result),
172 Err(_) => {
173 }
175 }
176 }
177 let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
178 let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
179 union_numeric(tensor_a, tensor_b, opts)
180}
181
182async fn union_gpu_mixed(
183 handle_gpu: GpuTensorHandle,
184 other: Value,
185 opts: &UnionOptions,
186 gpu_is_a: bool,
187) -> crate::BuiltinResult<UnionEvaluation> {
188 let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
189 let tensor_other = tensor::value_into_tensor_for("union", other).map_err(|e| union_error(e))?;
190 if gpu_is_a {
191 union_numeric(tensor_gpu, tensor_other, opts)
192 } else {
193 union_numeric(tensor_other, tensor_gpu, opts)
194 }
195}
196
197fn union_host(a: Value, b: Value, opts: &UnionOptions) -> crate::BuiltinResult<UnionEvaluation> {
198 match (a, b) {
199 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => union_complex(at, bt, opts),
201 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
202 let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
203 .map_err(|e| union_error(format!("union: {e}")))?;
204 union_complex(at, bt, opts)
205 }
206 (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
207 let at = ComplexTensor::new(vec![(re, im)], vec![1, 1])
208 .map_err(|e| union_error(format!("union: {e}")))?;
209 union_complex(at, bt, opts)
210 }
211 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
212 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
213 .map_err(|e| union_error(format!("union: {e}")))?;
214 let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
215 .map_err(|e| union_error(format!("union: {e}")))?;
216 union_complex(at, bt, opts)
217 }
218
219 (Value::CharArray(ac), Value::CharArray(bc)) => union_char(ac, bc, opts),
221
222 (Value::StringArray(astring), Value::StringArray(bstring)) => {
224 union_string(astring, bstring, opts)
225 }
226 (Value::StringArray(astring), Value::String(b)) => {
227 let bstring = StringArray::new(vec![b], vec![1, 1])
228 .map_err(|e| union_error(format!("union: {e}")))?;
229 union_string(astring, bstring, opts)
230 }
231 (Value::String(a), Value::StringArray(bstring)) => {
232 let astring = StringArray::new(vec![a], vec![1, 1])
233 .map_err(|e| union_error(format!("union: {e}")))?;
234 union_string(astring, bstring, opts)
235 }
236 (Value::String(a), Value::String(b)) => {
237 let astring = StringArray::new(vec![a], vec![1, 1])
238 .map_err(|e| union_error(format!("union: {e}")))?;
239 let bstring = StringArray::new(vec![b], vec![1, 1])
240 .map_err(|e| union_error(format!("union: {e}")))?;
241 union_string(astring, bstring, opts)
242 }
243
244 (left, right) => {
246 let tensor_a =
247 tensor::value_into_tensor_for("union", left).map_err(|e| union_error(e))?;
248 let tensor_b =
249 tensor::value_into_tensor_for("union", right).map_err(|e| union_error(e))?;
250 union_numeric(tensor_a, tensor_b, opts)
251 }
252 }
253}
254
255fn union_numeric(
256 a: Tensor,
257 b: Tensor,
258 opts: &UnionOptions,
259) -> crate::BuiltinResult<UnionEvaluation> {
260 if opts.rows {
261 union_numeric_rows(a, b, opts)
262 } else {
263 union_numeric_elements(a, b, opts)
264 }
265}
266
/// Public entry point to the host numeric union for callers that already hold tensors.
///
/// Forwards directly to `union_numeric`, so `opts.rows` selects between the row-wise and
/// element-wise implementations.
pub fn union_numeric_from_tensors(
    a: Tensor,
    b: Tensor,
    opts: &UnionOptions,
) -> crate::BuiltinResult<UnionEvaluation> {
    union_numeric(a, b, opts)
}
275
276fn union_numeric_elements(
277 a: Tensor,
278 b: Tensor,
279 opts: &UnionOptions,
280) -> crate::BuiltinResult<UnionEvaluation> {
281 let mut entries = Vec::<NumericUnionEntry>::new();
282 let mut map: HashMap<u64, usize> = HashMap::new();
283 let mut order_counter = 0usize;
284
285 for (idx, &value) in a.data.iter().enumerate() {
286 let key = canonicalize_f64(value);
287 match map.entry(key) {
288 Entry::Occupied(_) => {
289 }
291 Entry::Vacant(v) => {
292 let entry_idx = entries.len();
293 entries.push(NumericUnionEntry {
294 value,
295 a_index: Some(idx),
296 b_index: None,
297 order_rank: order_counter,
298 });
299 v.insert(entry_idx);
300 order_counter += 1;
301 }
302 }
303 }
304
305 for (idx, &value) in b.data.iter().enumerate() {
306 let key = canonicalize_f64(value);
307 match map.entry(key) {
308 Entry::Occupied(occ) => {
309 let entry = &mut entries[*occ.get()];
310 if entry.a_index.is_none() && entry.b_index.is_none() {
311 entry.b_index = Some(idx);
312 }
313 }
314 Entry::Vacant(v) => {
315 let entry_idx = entries.len();
316 entries.push(NumericUnionEntry {
317 value,
318 a_index: None,
319 b_index: Some(idx),
320 order_rank: order_counter,
321 });
322 v.insert(entry_idx);
323 order_counter += 1;
324 }
325 }
326 }
327
328 assemble_numeric_union(entries, opts)
329}
330
331fn union_numeric_rows(
332 a: Tensor,
333 b: Tensor,
334 opts: &UnionOptions,
335) -> crate::BuiltinResult<UnionEvaluation> {
336 if a.shape.len() != 2 || b.shape.len() != 2 {
337 return Err(union_error(
338 "union: 'rows' option requires 2-D numeric matrices",
339 ));
340 }
341 if a.shape[1] != b.shape[1] {
342 return Err(union_error(
343 "union: inputs must have the same number of columns when using 'rows'",
344 ));
345 }
346 let rows_a = a.shape[0];
347 let cols = a.shape[1];
348 let rows_b = b.shape[0];
349
350 let mut entries = Vec::<NumericRowUnionEntry>::new();
351 let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
352 let mut order_counter = 0usize;
353
354 for r in 0..rows_a {
355 let mut row_values = Vec::with_capacity(cols);
356 for c in 0..cols {
357 let idx = r + c * rows_a;
358 row_values.push(a.data[idx]);
359 }
360 let key = NumericRowKey::from_slice(&row_values);
361 match map.entry(key) {
362 Entry::Occupied(_) => {}
363 Entry::Vacant(v) => {
364 let entry_idx = entries.len();
365 entries.push(NumericRowUnionEntry {
366 row_data: row_values,
367 a_row: Some(r),
368 b_row: None,
369 order_rank: order_counter,
370 });
371 v.insert(entry_idx);
372 order_counter += 1;
373 }
374 }
375 }
376
377 for r in 0..rows_b {
378 let mut row_values = Vec::with_capacity(cols);
379 for c in 0..cols {
380 let idx = r + c * rows_b;
381 row_values.push(b.data[idx]);
382 }
383 let key = NumericRowKey::from_slice(&row_values);
384 match map.entry(key) {
385 Entry::Occupied(occ) => {
386 let entry = &mut entries[*occ.get()];
387 if entry.a_row.is_none() && entry.b_row.is_none() {
388 entry.b_row = Some(r);
389 }
390 }
391 Entry::Vacant(v) => {
392 let entry_idx = entries.len();
393 entries.push(NumericRowUnionEntry {
394 row_data: row_values,
395 a_row: None,
396 b_row: Some(r),
397 order_rank: order_counter,
398 });
399 v.insert(entry_idx);
400 order_counter += 1;
401 }
402 }
403 }
404
405 assemble_numeric_row_union(entries, opts, cols)
406}
407
408fn union_complex(
409 a: ComplexTensor,
410 b: ComplexTensor,
411 opts: &UnionOptions,
412) -> crate::BuiltinResult<UnionEvaluation> {
413 if opts.rows {
414 union_complex_rows(a, b, opts)
415 } else {
416 union_complex_elements(a, b, opts)
417 }
418}
419
420fn union_complex_elements(
421 a: ComplexTensor,
422 b: ComplexTensor,
423 opts: &UnionOptions,
424) -> crate::BuiltinResult<UnionEvaluation> {
425 let mut entries = Vec::<ComplexUnionEntry>::new();
426 let mut map: HashMap<ComplexKey, usize> = HashMap::new();
427 let mut order_counter = 0usize;
428
429 for (idx, &value) in a.data.iter().enumerate() {
430 let key = ComplexKey::new(value);
431 match map.entry(key) {
432 Entry::Occupied(_) => {}
433 Entry::Vacant(v) => {
434 let entry_idx = entries.len();
435 entries.push(ComplexUnionEntry {
436 value,
437 a_index: Some(idx),
438 b_index: None,
439 order_rank: order_counter,
440 });
441 v.insert(entry_idx);
442 order_counter += 1;
443 }
444 }
445 }
446
447 for (idx, &value) in b.data.iter().enumerate() {
448 let key = ComplexKey::new(value);
449 match map.entry(key) {
450 Entry::Occupied(occ) => {
451 let entry = &mut entries[*occ.get()];
452 if entry.a_index.is_none() && entry.b_index.is_none() {
453 entry.b_index = Some(idx);
454 }
455 }
456 Entry::Vacant(v) => {
457 let entry_idx = entries.len();
458 entries.push(ComplexUnionEntry {
459 value,
460 a_index: None,
461 b_index: Some(idx),
462 order_rank: order_counter,
463 });
464 v.insert(entry_idx);
465 order_counter += 1;
466 }
467 }
468 }
469
470 assemble_complex_union(entries, opts)
471}
472
473fn union_complex_rows(
474 a: ComplexTensor,
475 b: ComplexTensor,
476 opts: &UnionOptions,
477) -> crate::BuiltinResult<UnionEvaluation> {
478 if a.shape.len() != 2 || b.shape.len() != 2 {
479 return Err(union_error(
480 "union: 'rows' option requires 2-D complex matrices",
481 ));
482 }
483 if a.shape[1] != b.shape[1] {
484 return Err(union_error(
485 "union: inputs must have the same number of columns when using 'rows'",
486 ));
487 }
488 let rows_a = a.shape[0];
489 let cols = a.shape[1];
490 let rows_b = b.shape[0];
491
492 let mut entries = Vec::<ComplexRowUnionEntry>::new();
493 let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
494 let mut order_counter = 0usize;
495
496 for r in 0..rows_a {
497 let mut row_values = Vec::with_capacity(cols);
498 let mut key_row = Vec::with_capacity(cols);
499 for c in 0..cols {
500 let idx = r + c * rows_a;
501 let value = a.data[idx];
502 row_values.push(value);
503 key_row.push(ComplexKey::new(value));
504 }
505 match map.entry(key_row) {
506 Entry::Occupied(_) => {}
507 Entry::Vacant(v) => {
508 let entry_idx = entries.len();
509 entries.push(ComplexRowUnionEntry {
510 row_data: row_values,
511 a_row: Some(r),
512 b_row: None,
513 order_rank: order_counter,
514 });
515 v.insert(entry_idx);
516 order_counter += 1;
517 }
518 }
519 }
520
521 for r in 0..rows_b {
522 let mut row_values = Vec::with_capacity(cols);
523 let mut key_row = Vec::with_capacity(cols);
524 for c in 0..cols {
525 let idx = r + c * rows_b;
526 let value = b.data[idx];
527 row_values.push(value);
528 key_row.push(ComplexKey::new(value));
529 }
530 match map.entry(key_row) {
531 Entry::Occupied(occ) => {
532 let entry = &mut entries[*occ.get()];
533 if entry.a_row.is_none() && entry.b_row.is_none() {
534 entry.b_row = Some(r);
535 }
536 }
537 Entry::Vacant(v) => {
538 let entry_idx = entries.len();
539 entries.push(ComplexRowUnionEntry {
540 row_data: row_values,
541 a_row: None,
542 b_row: Some(r),
543 order_rank: order_counter,
544 });
545 v.insert(entry_idx);
546 order_counter += 1;
547 }
548 }
549 }
550
551 assemble_complex_row_union(entries, opts, cols)
552}
553
554fn union_char(
555 a: CharArray,
556 b: CharArray,
557 opts: &UnionOptions,
558) -> crate::BuiltinResult<UnionEvaluation> {
559 if opts.rows {
560 union_char_rows(a, b, opts)
561 } else {
562 union_char_elements(a, b, opts)
563 }
564}
565
566fn union_char_elements(
567 a: CharArray,
568 b: CharArray,
569 opts: &UnionOptions,
570) -> crate::BuiltinResult<UnionEvaluation> {
571 let mut entries = Vec::<CharUnionEntry>::new();
572 let mut map: HashMap<u32, usize> = HashMap::new();
573 let mut order_counter = 0usize;
574
575 for col in 0..a.cols {
576 for row in 0..a.rows {
577 let linear_idx = row + col * a.rows;
578 let data_idx = row * a.cols + col;
579 let ch = a.data[data_idx];
580 let key = ch as u32;
581 match map.entry(key) {
582 Entry::Occupied(_) => {}
583 Entry::Vacant(v) => {
584 let entry_idx = entries.len();
585 entries.push(CharUnionEntry {
586 ch,
587 a_index: Some(linear_idx),
588 b_index: None,
589 order_rank: order_counter,
590 });
591 v.insert(entry_idx);
592 order_counter += 1;
593 }
594 }
595 }
596 }
597
598 for col in 0..b.cols {
599 for row in 0..b.rows {
600 let linear_idx = row + col * b.rows;
601 let data_idx = row * b.cols + col;
602 let ch = b.data[data_idx];
603 let key = ch as u32;
604 match map.entry(key) {
605 Entry::Occupied(occ) => {
606 let entry = &mut entries[*occ.get()];
607 if entry.a_index.is_none() && entry.b_index.is_none() {
608 entry.b_index = Some(linear_idx);
609 }
610 }
611 Entry::Vacant(v) => {
612 let entry_idx = entries.len();
613 entries.push(CharUnionEntry {
614 ch,
615 a_index: None,
616 b_index: Some(linear_idx),
617 order_rank: order_counter,
618 });
619 v.insert(entry_idx);
620 order_counter += 1;
621 }
622 }
623 }
624 }
625
626 assemble_char_union(entries, opts)
627}
628
629fn union_char_rows(
630 a: CharArray,
631 b: CharArray,
632 opts: &UnionOptions,
633) -> crate::BuiltinResult<UnionEvaluation> {
634 if a.cols != b.cols {
635 return Err(union_error(
636 "union: inputs must have the same number of columns when using 'rows'",
637 ));
638 }
639 let rows_a = a.rows;
640 let rows_b = b.rows;
641 let cols = a.cols;
642
643 let mut entries = Vec::<CharRowUnionEntry>::new();
644 let mut map: HashMap<RowCharKey, usize> = HashMap::new();
645 let mut order_counter = 0usize;
646
647 for r in 0..rows_a {
648 let mut row_values = Vec::with_capacity(cols);
649 for c in 0..cols {
650 let idx = r * cols + c;
651 row_values.push(a.data[idx]);
652 }
653 let key = RowCharKey::from_slice(&row_values);
654 match map.entry(key) {
655 Entry::Occupied(_) => {}
656 Entry::Vacant(v) => {
657 let entry_idx = entries.len();
658 entries.push(CharRowUnionEntry {
659 row_data: row_values,
660 a_row: Some(r),
661 b_row: None,
662 order_rank: order_counter,
663 });
664 v.insert(entry_idx);
665 order_counter += 1;
666 }
667 }
668 }
669
670 for r in 0..rows_b {
671 let mut row_values = Vec::with_capacity(cols);
672 for c in 0..cols {
673 let idx = r * cols + c;
674 row_values.push(b.data[idx]);
675 }
676 let key = RowCharKey::from_slice(&row_values);
677 match map.entry(key) {
678 Entry::Occupied(occ) => {
679 let entry = &mut entries[*occ.get()];
680 if entry.a_row.is_none() && entry.b_row.is_none() {
681 entry.b_row = Some(r);
682 }
683 }
684 Entry::Vacant(v) => {
685 let entry_idx = entries.len();
686 entries.push(CharRowUnionEntry {
687 row_data: row_values,
688 a_row: None,
689 b_row: Some(r),
690 order_rank: order_counter,
691 });
692 v.insert(entry_idx);
693 order_counter += 1;
694 }
695 }
696 }
697
698 assemble_char_row_union(entries, opts, cols)
699}
700
701fn union_string(
702 a: StringArray,
703 b: StringArray,
704 opts: &UnionOptions,
705) -> crate::BuiltinResult<UnionEvaluation> {
706 if opts.rows {
707 union_string_rows(a, b, opts)
708 } else {
709 union_string_elements(a, b, opts)
710 }
711}
712
713fn union_string_elements(
714 a: StringArray,
715 b: StringArray,
716 opts: &UnionOptions,
717) -> crate::BuiltinResult<UnionEvaluation> {
718 let mut entries = Vec::<StringUnionEntry>::new();
719 let mut map: HashMap<String, usize> = HashMap::new();
720 let mut order_counter = 0usize;
721
722 for (idx, value) in a.data.iter().enumerate() {
723 match map.entry(value.clone()) {
724 Entry::Occupied(_) => {}
725 Entry::Vacant(v) => {
726 let entry_idx = entries.len();
727 entries.push(StringUnionEntry {
728 value: value.clone(),
729 a_index: Some(idx),
730 b_index: None,
731 order_rank: order_counter,
732 });
733 v.insert(entry_idx);
734 order_counter += 1;
735 }
736 }
737 }
738
739 for (idx, value) in b.data.iter().enumerate() {
740 match map.entry(value.clone()) {
741 Entry::Occupied(occ) => {
742 let entry = &mut entries[*occ.get()];
743 if entry.a_index.is_none() && entry.b_index.is_none() {
744 entry.b_index = Some(idx);
745 }
746 }
747 Entry::Vacant(v) => {
748 let entry_idx = entries.len();
749 entries.push(StringUnionEntry {
750 value: value.clone(),
751 a_index: None,
752 b_index: Some(idx),
753 order_rank: order_counter,
754 });
755 v.insert(entry_idx);
756 order_counter += 1;
757 }
758 }
759 }
760
761 assemble_string_union(entries, opts)
762}
763
764fn union_string_rows(
765 a: StringArray,
766 b: StringArray,
767 opts: &UnionOptions,
768) -> crate::BuiltinResult<UnionEvaluation> {
769 if a.shape.len() != 2 || b.shape.len() != 2 {
770 return Err(union_error(
771 "union: 'rows' option requires 2-D string arrays",
772 ));
773 }
774 if a.shape[1] != b.shape[1] {
775 return Err(union_error(
776 "union: inputs must have the same number of columns when using 'rows'",
777 ));
778 }
779 let rows_a = a.shape[0];
780 let cols = a.shape[1];
781 let rows_b = b.shape[0];
782
783 let mut entries = Vec::<StringRowUnionEntry>::new();
784 let mut map: HashMap<RowStringKey, usize> = HashMap::new();
785 let mut order_counter = 0usize;
786
787 for r in 0..rows_a {
788 let mut row_values = Vec::with_capacity(cols);
789 for c in 0..cols {
790 let idx = r + c * rows_a;
791 row_values.push(a.data[idx].clone());
792 }
793 let key = RowStringKey(row_values.clone());
794 match map.entry(key) {
795 Entry::Occupied(_) => {}
796 Entry::Vacant(v) => {
797 let entry_idx = entries.len();
798 entries.push(StringRowUnionEntry {
799 row_data: row_values,
800 a_row: Some(r),
801 b_row: None,
802 order_rank: order_counter,
803 });
804 v.insert(entry_idx);
805 order_counter += 1;
806 }
807 }
808 }
809
810 for r in 0..rows_b {
811 let mut row_values = Vec::with_capacity(cols);
812 for c in 0..cols {
813 let idx = r + c * rows_b;
814 row_values.push(b.data[idx].clone());
815 }
816 let key = RowStringKey(row_values.clone());
817 match map.entry(key) {
818 Entry::Occupied(occ) => {
819 let entry = &mut entries[*occ.get()];
820 if entry.a_row.is_none() && entry.b_row.is_none() {
821 entry.b_row = Some(r);
822 }
823 }
824 Entry::Vacant(v) => {
825 let entry_idx = entries.len();
826 entries.push(StringRowUnionEntry {
827 row_data: row_values,
828 a_row: None,
829 b_row: Some(r),
830 order_rank: order_counter,
831 });
832 v.insert(entry_idx);
833 order_counter += 1;
834 }
835 }
836 }
837
838 assemble_string_row_union(entries, opts, cols)
839}
840
/// Outcome of a `union` evaluation: the combined values plus the two index vectors.
#[derive(Debug, Clone)]
pub struct UnionEvaluation {
    /// The union itself, already converted to a runtime [`Value`].
    values: Value,
    /// Index vector for entries taken from the first input.
    ia: Tensor,
    /// Index vector for entries taken from the second input.
    ib: Tensor,
}
847
848impl UnionEvaluation {
849 fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
850 Self { values, ia, ib }
851 }
852
853 pub fn from_union_result(result: UnionResult) -> crate::BuiltinResult<Self> {
854 let UnionResult { values, ia, ib } = result;
855 let values_tensor = Tensor::new(values.data, values.shape)
856 .map_err(|e| union_error(format!("union: {e}")))?;
857 let ia_tensor =
858 Tensor::new(ia.data, ia.shape).map_err(|e| union_error(format!("union: {e}")))?;
859 let ib_tensor =
860 Tensor::new(ib.data, ib.shape).map_err(|e| union_error(format!("union: {e}")))?;
861 Ok(UnionEvaluation::new(
862 tensor::tensor_into_value(values_tensor),
863 ia_tensor,
864 ib_tensor,
865 ))
866 }
867
868 pub fn into_numeric_union_result(self) -> crate::BuiltinResult<UnionResult> {
869 let UnionEvaluation { values, ia, ib } = self;
870 let values_tensor =
871 tensor::value_into_tensor_for("union", values).map_err(|e| union_error(e))?;
872 Ok(UnionResult {
873 values: HostTensorOwned {
874 data: values_tensor.data,
875 shape: values_tensor.shape,
876 storage: GpuTensorStorage::Real,
877 },
878 ia: HostTensorOwned {
879 data: ia.data,
880 shape: ia.shape,
881 storage: GpuTensorStorage::Real,
882 },
883 ib: HostTensorOwned {
884 data: ib.data,
885 shape: ib.shape,
886 storage: GpuTensorStorage::Real,
887 },
888 })
889 }
890
891 pub fn into_values_value(self) -> Value {
892 self.values
893 }
894
895 pub fn into_pair(self) -> (Value, Value) {
896 let ia = tensor::tensor_into_value(self.ia);
897 (self.values, ia)
898 }
899
900 pub fn into_triple(self) -> (Value, Value, Value) {
901 let ia = tensor::tensor_into_value(self.ia);
902 let ib = tensor::tensor_into_value(self.ib);
903 (self.values, ia, ib)
904 }
905
906 pub fn values_value(&self) -> Value {
907 self.values.clone()
908 }
909
910 pub fn ia_value(&self) -> Value {
911 tensor::tensor_into_value(self.ia.clone())
912 }
913
914 pub fn ib_value(&self) -> Value {
915 tensor::tensor_into_value(self.ib.clone())
916 }
917}
918
/// One distinct real value collected while building an element-wise union.
#[derive(Debug)]
struct NumericUnionEntry {
    /// The value as first encountered.
    value: f64,
    /// Zero-based index in `a` when the value was first seen there.
    a_index: Option<usize>,
    /// Zero-based index in `b` when the value was first seen there.
    b_index: Option<usize>,
    /// First-seen position; used as the sort key for stable ordering.
    order_rank: usize,
}
926
/// One distinct real row collected while building a row-wise union.
#[derive(Debug)]
struct NumericRowUnionEntry {
    /// The row's elements, one per column.
    row_data: Vec<f64>,
    /// Zero-based row number in `a` when the row was first seen there.
    a_row: Option<usize>,
    /// Zero-based row number in `b` when the row was first seen there.
    b_row: Option<usize>,
    /// First-seen position; used as the sort key for stable ordering.
    order_rank: usize,
}
934
/// One distinct complex value collected while building an element-wise union.
#[derive(Debug)]
struct ComplexUnionEntry {
    /// The value as first encountered; `ComplexKey::new` reads `.0` as the real part and `.1` as the imaginary part.
    value: (f64, f64),
    /// Zero-based index in `a` when the value was first seen there.
    a_index: Option<usize>,
    /// Zero-based index in `b` when the value was first seen there.
    b_index: Option<usize>,
    /// First-seen position; used as the sort key for stable ordering.
    order_rank: usize,
}
942
/// One distinct complex row collected while building a row-wise union.
#[derive(Debug)]
struct ComplexRowUnionEntry {
    /// The row's elements, one per column.
    row_data: Vec<(f64, f64)>,
    /// Zero-based row number in `a` when the row was first seen there.
    a_row: Option<usize>,
    /// Zero-based row number in `b` when the row was first seen there.
    b_row: Option<usize>,
    /// First-seen position; used as the sort key for stable ordering.
    order_rank: usize,
}
950
/// One distinct character collected while building an element-wise union.
#[derive(Debug)]
struct CharUnionEntry {
    /// The character as first encountered.
    ch: char,
    /// Zero-based linear index (`row + col * rows`) in `a` when first seen there.
    a_index: Option<usize>,
    /// Zero-based linear index (`row + col * rows`) in `b` when first seen there.
    b_index: Option<usize>,
    /// First-seen position; used as the sort key for stable ordering.
    order_rank: usize,
}
958
/// One distinct character row collected while building a row-wise union.
#[derive(Debug)]
struct CharRowUnionEntry {
    /// The row's characters, one per column.
    row_data: Vec<char>,
    /// Zero-based row number in `a` when the row was first seen there.
    a_row: Option<usize>,
    /// Zero-based row number in `b` when the row was first seen there.
    b_row: Option<usize>,
    /// First-seen position; used as the sort key for stable ordering.
    order_rank: usize,
}
966
/// One distinct string collected while building an element-wise union.
#[derive(Debug)]
struct StringUnionEntry {
    /// The string as first encountered.
    value: String,
    /// Zero-based index in `a` when the string was first seen there.
    a_index: Option<usize>,
    /// Zero-based index in `b` when the string was first seen there.
    b_index: Option<usize>,
    /// First-seen position; used as the sort key for stable ordering.
    order_rank: usize,
}
974
/// One distinct string row collected while building a row-wise union.
#[derive(Debug)]
struct StringRowUnionEntry {
    /// The row's strings, one per column.
    row_data: Vec<String>,
    /// Zero-based row number in `a` when the row was first seen there.
    a_row: Option<usize>,
    /// Zero-based row number in `b` when the row was first seen there.
    b_row: Option<usize>,
    /// First-seen position; used as the sort key for stable ordering.
    order_rank: usize,
}
982
983#[derive(Debug, Clone, PartialEq, Eq, Hash)]
984struct NumericRowKey(Vec<u64>);
985
986impl NumericRowKey {
987 fn from_slice(values: &[f64]) -> Self {
988 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
989 }
990}
991
992#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
993struct ComplexKey {
994 re: u64,
995 im: u64,
996}
997
998impl ComplexKey {
999 fn new(value: (f64, f64)) -> Self {
1000 Self {
1001 re: canonicalize_f64(value.0),
1002 im: canonicalize_f64(value.1),
1003 }
1004 }
1005}
1006
/// Hashable identity of a character row: each character's Unicode scalar value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key for `values`, preserving character order.
    fn from_slice(values: &[char]) -> Self {
        let code_points = values.iter().copied().map(u32::from).collect();
        Self(code_points)
    }
}
1015
/// Hashable identity of a string row: the row's strings in column order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1018
1019fn assemble_numeric_union(
1020 entries: Vec<NumericUnionEntry>,
1021 opts: &UnionOptions,
1022) -> crate::BuiltinResult<UnionEvaluation> {
1023 let mut order: Vec<usize> = (0..entries.len()).collect();
1024 match opts.order {
1025 UnionOrder::Sorted => {
1026 order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
1027 }
1028 UnionOrder::Stable => {
1029 order.sort_by_key(|&idx| entries[idx].order_rank);
1030 }
1031 }
1032
1033 let mut values = Vec::with_capacity(order.len());
1034 let mut ia = Vec::new();
1035 let mut ib = Vec::new();
1036 for &idx in &order {
1037 let entry = &entries[idx];
1038 values.push(entry.value);
1039 if let Some(a_idx) = entry.a_index {
1040 ia.push((a_idx + 1) as f64);
1041 } else if let Some(b_idx) = entry.b_index {
1042 ib.push((b_idx + 1) as f64);
1043 }
1044 }
1045
1046 let value_tensor = Tensor::new(values, vec![order.len(), 1])
1047 .map_err(|e| union_error(format!("union: {e}")))?;
1048 let ia_len = ia.len();
1049 let ib_len = ib.len();
1050 let ia_tensor =
1051 Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1052 let ib_tensor =
1053 Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1054
1055 Ok(UnionEvaluation::new(
1056 tensor::tensor_into_value(value_tensor),
1057 ia_tensor,
1058 ib_tensor,
1059 ))
1060}
1061
1062fn assemble_numeric_row_union(
1063 entries: Vec<NumericRowUnionEntry>,
1064 opts: &UnionOptions,
1065 cols: usize,
1066) -> crate::BuiltinResult<UnionEvaluation> {
1067 let mut order: Vec<usize> = (0..entries.len()).collect();
1068 match opts.order {
1069 UnionOrder::Sorted => {
1070 order.sort_by(|&lhs, &rhs| {
1071 compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1072 });
1073 }
1074 UnionOrder::Stable => {
1075 order.sort_by_key(|&idx| entries[idx].order_rank);
1076 }
1077 }
1078
1079 let unique_rows = order.len();
1080 let mut values = vec![0.0f64; unique_rows * cols];
1081 let mut ia = Vec::new();
1082 let mut ib = Vec::new();
1083
1084 for (row_pos, &entry_idx) in order.iter().enumerate() {
1085 let entry = &entries[entry_idx];
1086 for col in 0..cols {
1087 let dest = row_pos + col * unique_rows;
1088 values[dest] = entry.row_data[col];
1089 }
1090 if let Some(a_row) = entry.a_row {
1091 ia.push((a_row + 1) as f64);
1092 } else if let Some(b_row) = entry.b_row {
1093 ib.push((b_row + 1) as f64);
1094 }
1095 }
1096
1097 let value_tensor = Tensor::new(values, vec![unique_rows, cols])
1098 .map_err(|e| union_error(format!("union: {e}")))?;
1099 let ia_len = ia.len();
1100 let ib_len = ib.len();
1101 let ia_tensor =
1102 Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1103 let ib_tensor =
1104 Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1105
1106 Ok(UnionEvaluation::new(
1107 tensor::tensor_into_value(value_tensor),
1108 ia_tensor,
1109 ib_tensor,
1110 ))
1111}
1112
1113fn assemble_complex_union(
1114 entries: Vec<ComplexUnionEntry>,
1115 opts: &UnionOptions,
1116) -> crate::BuiltinResult<UnionEvaluation> {
1117 let mut order: Vec<usize> = (0..entries.len()).collect();
1118 match opts.order {
1119 UnionOrder::Sorted => {
1120 order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
1121 }
1122 UnionOrder::Stable => {
1123 order.sort_by_key(|&idx| entries[idx].order_rank);
1124 }
1125 }
1126
1127 let mut values = Vec::with_capacity(order.len());
1128 let mut ia = Vec::new();
1129 let mut ib = Vec::new();
1130 for &idx in &order {
1131 let entry = &entries[idx];
1132 values.push(entry.value);
1133 if let Some(a_idx) = entry.a_index {
1134 ia.push((a_idx + 1) as f64);
1135 } else if let Some(b_idx) = entry.b_index {
1136 ib.push((b_idx + 1) as f64);
1137 }
1138 }
1139
1140 let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
1141 .map_err(|e| union_error(format!("union: {e}")))?;
1142 let ia_len = ia.len();
1143 let ib_len = ib.len();
1144 let ia_tensor =
1145 Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1146 let ib_tensor =
1147 Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1148
1149 Ok(UnionEvaluation::new(
1150 complex_tensor_into_value(value_tensor),
1151 ia_tensor,
1152 ib_tensor,
1153 ))
1154}
1155
1156fn assemble_complex_row_union(
1157 entries: Vec<ComplexRowUnionEntry>,
1158 opts: &UnionOptions,
1159 cols: usize,
1160) -> crate::BuiltinResult<UnionEvaluation> {
1161 let mut order: Vec<usize> = (0..entries.len()).collect();
1162 match opts.order {
1163 UnionOrder::Sorted => {
1164 order.sort_by(|&lhs, &rhs| {
1165 compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1166 });
1167 }
1168 UnionOrder::Stable => {
1169 order.sort_by_key(|&idx| entries[idx].order_rank);
1170 }
1171 }
1172
1173 let unique_rows = order.len();
1174 let mut values = vec![(0.0, 0.0); unique_rows * cols];
1175 let mut ia = Vec::new();
1176 let mut ib = Vec::new();
1177
1178 for (row_pos, &entry_idx) in order.iter().enumerate() {
1179 let entry = &entries[entry_idx];
1180 for col in 0..cols {
1181 let dest = row_pos + col * unique_rows;
1182 values[dest] = entry.row_data[col];
1183 }
1184 if let Some(a_row) = entry.a_row {
1185 ia.push((a_row + 1) as f64);
1186 } else if let Some(b_row) = entry.b_row {
1187 ib.push((b_row + 1) as f64);
1188 }
1189 }
1190
1191 let value_tensor = ComplexTensor::new(values, vec![unique_rows, cols])
1192 .map_err(|e| union_error(format!("union: {e}")))?;
1193 let ia_len = ia.len();
1194 let ib_len = ib.len();
1195 let ia_tensor =
1196 Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1197 let ib_tensor =
1198 Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1199
1200 Ok(UnionEvaluation::new(
1201 complex_tensor_into_value(value_tensor),
1202 ia_tensor,
1203 ib_tensor,
1204 ))
1205}
1206
1207fn assemble_char_union(
1208 entries: Vec<CharUnionEntry>,
1209 opts: &UnionOptions,
1210) -> crate::BuiltinResult<UnionEvaluation> {
1211 let mut order: Vec<usize> = (0..entries.len()).collect();
1212 match opts.order {
1213 UnionOrder::Sorted => {
1214 order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1215 }
1216 UnionOrder::Stable => {
1217 order.sort_by_key(|&idx| entries[idx].order_rank);
1218 }
1219 }
1220
1221 let mut values = Vec::with_capacity(order.len());
1222 let mut ia = Vec::new();
1223 let mut ib = Vec::new();
1224 for &idx in &order {
1225 let entry = &entries[idx];
1226 values.push(entry.ch);
1227 if let Some(a_idx) = entry.a_index {
1228 ia.push((a_idx + 1) as f64);
1229 } else if let Some(b_idx) = entry.b_index {
1230 ib.push((b_idx + 1) as f64);
1231 }
1232 }
1233
1234 let value_array =
1235 CharArray::new(values, order.len(), 1).map_err(|e| union_error(format!("union: {e}")))?;
1236 let ia_len = ia.len();
1237 let ib_len = ib.len();
1238 let ia_tensor =
1239 Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1240 let ib_tensor =
1241 Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1242
1243 Ok(UnionEvaluation::new(
1244 Value::CharArray(value_array),
1245 ia_tensor,
1246 ib_tensor,
1247 ))
1248}
1249
1250fn assemble_char_row_union(
1251 entries: Vec<CharRowUnionEntry>,
1252 opts: &UnionOptions,
1253 cols: usize,
1254) -> crate::BuiltinResult<UnionEvaluation> {
1255 let mut order: Vec<usize> = (0..entries.len()).collect();
1256 match opts.order {
1257 UnionOrder::Sorted => {
1258 order.sort_by(|&lhs, &rhs| {
1259 compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1260 });
1261 }
1262 UnionOrder::Stable => {
1263 order.sort_by_key(|&idx| entries[idx].order_rank);
1264 }
1265 }
1266
1267 let unique_rows = order.len();
1268 let mut values = vec!['\0'; unique_rows * cols];
1269 let mut ia = Vec::new();
1270 let mut ib = Vec::new();
1271
1272 for (row_pos, &entry_idx) in order.iter().enumerate() {
1273 let entry = &entries[entry_idx];
1274 for col in 0..cols {
1275 let dest = row_pos * cols + col;
1276 values[dest] = entry.row_data[col];
1277 }
1278 if let Some(a_row) = entry.a_row {
1279 ia.push((a_row + 1) as f64);
1280 } else if let Some(b_row) = entry.b_row {
1281 ib.push((b_row + 1) as f64);
1282 }
1283 }
1284
1285 let value_array = CharArray::new(values, unique_rows, cols)
1286 .map_err(|e| union_error(format!("union: {e}")))?;
1287 let ia_len = ia.len();
1288 let ib_len = ib.len();
1289 let ia_tensor =
1290 Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1291 let ib_tensor =
1292 Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1293
1294 Ok(UnionEvaluation::new(
1295 Value::CharArray(value_array),
1296 ia_tensor,
1297 ib_tensor,
1298 ))
1299}
1300
1301fn assemble_string_union(
1302 entries: Vec<StringUnionEntry>,
1303 opts: &UnionOptions,
1304) -> crate::BuiltinResult<UnionEvaluation> {
1305 let mut order: Vec<usize> = (0..entries.len()).collect();
1306 match opts.order {
1307 UnionOrder::Sorted => {
1308 order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1309 }
1310 UnionOrder::Stable => {
1311 order.sort_by_key(|&idx| entries[idx].order_rank);
1312 }
1313 }
1314
1315 let mut values = Vec::with_capacity(order.len());
1316 let mut ia = Vec::new();
1317 let mut ib = Vec::new();
1318 for &idx in &order {
1319 let entry = &entries[idx];
1320 values.push(entry.value.clone());
1321 if let Some(a_idx) = entry.a_index {
1322 ia.push((a_idx + 1) as f64);
1323 } else if let Some(b_idx) = entry.b_index {
1324 ib.push((b_idx + 1) as f64);
1325 }
1326 }
1327
1328 let value_array = StringArray::new(values, vec![order.len(), 1])
1329 .map_err(|e| union_error(format!("union: {e}")))?;
1330 let ia_len = ia.len();
1331 let ib_len = ib.len();
1332 let ia_tensor =
1333 Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1334 let ib_tensor =
1335 Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1336
1337 Ok(UnionEvaluation::new(
1338 Value::StringArray(value_array),
1339 ia_tensor,
1340 ib_tensor,
1341 ))
1342}
1343
1344fn assemble_string_row_union(
1345 entries: Vec<StringRowUnionEntry>,
1346 opts: &UnionOptions,
1347 cols: usize,
1348) -> crate::BuiltinResult<UnionEvaluation> {
1349 let mut order: Vec<usize> = (0..entries.len()).collect();
1350 match opts.order {
1351 UnionOrder::Sorted => {
1352 order.sort_by(|&lhs, &rhs| {
1353 compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1354 });
1355 }
1356 UnionOrder::Stable => {
1357 order.sort_by_key(|&idx| entries[idx].order_rank);
1358 }
1359 }
1360
1361 let unique_rows = order.len();
1362 let mut values = vec![String::new(); unique_rows * cols];
1363 let mut ia = Vec::new();
1364 let mut ib = Vec::new();
1365
1366 for (row_pos, &entry_idx) in order.iter().enumerate() {
1367 let entry = &entries[entry_idx];
1368 for col in 0..cols {
1369 let dest = row_pos + col * unique_rows;
1370 values[dest] = entry.row_data[col].clone();
1371 }
1372 if let Some(a_row) = entry.a_row {
1373 ia.push((a_row + 1) as f64);
1374 } else if let Some(b_row) = entry.b_row {
1375 ib.push((b_row + 1) as f64);
1376 }
1377 }
1378
1379 let value_array = StringArray::new(values, vec![unique_rows, cols])
1380 .map_err(|e| union_error(format!("union: {e}")))?;
1381 let ia_len = ia.len();
1382 let ib_len = ib.len();
1383 let ia_tensor =
1384 Tensor::new(ia, vec![ia_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1385 let ib_tensor =
1386 Tensor::new(ib, vec![ib_len, 1]).map_err(|e| union_error(format!("union: {e}")))?;
1387
1388 Ok(UnionEvaluation::new(
1389 Value::StringArray(value_array),
1390 ia_tensor,
1391 ib_tensor,
1392 ))
1393}
1394
/// Maps a float to a canonical bit pattern.
///
/// Every NaN collapses to one fixed quiet-NaN pattern and both signed zeros
/// collapse to `0`; any other value keeps its own IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0` holds, so this arm catches both zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1404
/// Total ordering for `f64` in which NaN sorts after every other value.
///
/// Two NaNs compare equal; otherwise the ordinary numeric comparison applies,
/// so `0.0` and `-0.0` compare equal.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither operand is NaN here, so `partial_cmp` always yields `Some`.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1418
1419fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1420 for (lhs, rhs) in a.iter().zip(b.iter()) {
1421 let ord = compare_f64(*lhs, *rhs);
1422 if ord != Ordering::Equal {
1423 return ord;
1424 }
1425 }
1426 Ordering::Equal
1427}
1428
/// Reports whether either component of the `(f64, f64)` pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (first, second) = value;
    first.is_nan() || second.is_nan()
}
1432
1433fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1434 match (complex_is_nan(a), complex_is_nan(b)) {
1435 (true, true) => Ordering::Equal,
1436 (true, false) => Ordering::Greater,
1437 (false, true) => Ordering::Less,
1438 (false, false) => {
1439 let mag_a = a.0.hypot(a.1);
1440 let mag_b = b.0.hypot(b.1);
1441 let mag_cmp = compare_f64(mag_a, mag_b);
1442 if mag_cmp != Ordering::Equal {
1443 return mag_cmp;
1444 }
1445 let re_cmp = compare_f64(a.0, b.0);
1446 if re_cmp != Ordering::Equal {
1447 return re_cmp;
1448 }
1449 compare_f64(a.1, b.1)
1450 }
1451 }
1452}
1453
1454fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1455 for (lhs, rhs) in a.iter().zip(b.iter()) {
1456 let ord = compare_complex(*lhs, *rhs);
1457 if ord != Ordering::Equal {
1458 return ord;
1459 }
1460 }
1461 Ordering::Equal
1462}
1463
/// Lexicographically compares two character rows.
///
/// Only the positions present in both slices are inspected, so a row that is
/// a prefix of the other compares equal to it.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .find_map(|(lhs, rhs)| match lhs.cmp(rhs) {
            Ordering::Equal => None,
            decided => Some(decided),
        })
        .unwrap_or(Ordering::Equal)
}
1473
/// Lexicographically compares two rows of strings.
///
/// Only the positions present in both slices are inspected, so a row that is
/// a prefix of the other compares equal to it.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .find_map(|(lhs, rhs)| match lhs.cmp(rhs) {
            Ordering::Equal => None,
            decided => Some(decided),
        })
        .unwrap_or(Ordering::Equal)
}
1483
// Unit tests for the `union` builtin. They exercise numeric, char and string
// inputs, the "rows"/"stable" options, option and shape errors, the output
// type resolver, and inputs supplied as GPU tensor handles.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, ResolveContext, Tensor, Type, Value};

    // Extracts the message text from a runtime error as an owned `String`.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    // Drives the async `evaluate` entry point to completion on the current thread.
    fn evaluate_sync(a: Value, b: Value, rest: &[Value]) -> crate::BuiltinResult<UnionEvaluation> {
        futures::executor::block_on(evaluate(a, b, rest))
    }

    // With no options the union is sorted and de-duplicated; `ia` lists the
    // positions taken from A and `ib` those contributed only by B.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_numeric_sorted_default() {
        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 1.0, 1.0], vec![3, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 7.0]);
                assert_eq!(t.shape, vec![4, 1]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![3.0, 1.0, 2.0]);
        assert_eq!(ia.shape, vec![3, 1]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![1.0]);
        assert_eq!(ib.shape, vec![1, 1]);
    }

    // Two tensor inputs resolve to a tensor output type.
    #[test]
    fn union_type_resolver_numeric() {
        assert_eq!(
            set_values_output_type(
                &[Type::tensor(), Type::tensor()],
                &ResolveContext::new(Vec::new()),
            ),
            Type::tensor()
        );
    }

    // A cell-of-strings input paired with a string resolves to cell-of-strings.
    #[test]
    fn union_type_resolver_string_array() {
        assert_eq!(
            set_values_output_type(
                &[Type::cell_of(Type::String), Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::cell_of(Type::String)
        );
    }

    // "stable" keeps first-appearance order: all of A, then the new values of B.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_numeric_stable_order() {
        let a = Tensor::new(vec![5.0, 7.0, 1.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 2.0, 4.0], vec![3, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("stable")])
            .expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![5.0, 7.0, 1.0, 3.0, 2.0, 4.0]);
                assert_eq!(t.shape, vec![6, 1]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![1.0, 2.0, 3.0]);
    }

    // NaNs from both inputs collapse into a single NaN placed after the finite
    // values; the NaN is attributed to A (index 1), so `ib` only lists 2.0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_numeric_sorted_places_nan_last() {
        let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
        let b = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("union");
        let values = tensor::value_into_tensor_for("union", eval.values_value()).expect("values");
        assert_eq!(values.shape, vec![3, 1]);
        assert_eq!(values.data[0], 1.0);
        assert_eq!(values.data[1], 2.0);
        assert!(values.data[2].is_nan());
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0, 1.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![1.0]);
    }

    // "rows" treats each row as one element. Data is column-major, so A holds
    // rows (1,2), (3,4), (1,2) and B holds rows (3,4), (5,6).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_numeric_rows_sorted() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")])
            .expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0, 2.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![2.0]);
    }

    // "rows" + "stable": duplicate rows keep their first occurrence, and only
    // B's row (5,6) is new. Also exercises `into_triple`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_numeric_rows_stable_preserves_first_occurrence() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
        let eval = evaluate_sync(
            Value::Tensor(a),
            Value::Tensor(b),
            &[Value::from("rows"), Value::from("stable")],
        )
        .expect("union");
        let (values, ia, ib) = eval.into_triple();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 2]);
                assert_eq!(t.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia_tensor = tensor::value_into_tensor_for("union", ia).expect("ia tensor");
        assert_eq!(ia_tensor.data, vec![1.0, 2.0]);
        let ib_tensor = tensor::value_into_tensor_for("union", ib).expect("ib tensor");
        assert_eq!(ib_tensor.data, vec![3.0]);
    }

    // Char matrices are flattened element-wise into a sorted column of unique
    // characters; the expected indices (e.g. 'a' at 4, 'z' at 3 in A) follow
    // column-major linear indexing of the 2x2 inputs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_char_elements() {
        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
        let eval = evaluate_sync(Value::CharArray(a), Value::CharArray(b), &[]).expect("union");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 4);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['a', 'm', 'x', 'z']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![4.0, 1.0, 3.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![3.0]);
    }

    // String "rows" + "stable": A has rows (alpha,beta), (gamma,beta); B has
    // rows (gamma,beta), (delta,beta). Only B's second row is new.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_string_rows_stable() {
        let a = StringArray::new(
            vec![
                "alpha".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "gamma".to_string(),
                "delta".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let eval = evaluate_sync(
            Value::StringArray(a),
            Value::StringArray(b),
            &[Value::from("rows"), Value::from("stable")],
        )
        .expect("union");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![3, 2]);
                assert_eq!(
                    arr.data,
                    vec![
                        "alpha".to_string(),
                        "gamma".to_string(),
                        "delta".to_string(),
                        "beta".to_string(),
                        "beta".to_string(),
                        "beta".to_string()
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0, 2.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).expect("ib tensor");
        assert_eq!(ib.data, vec![2.0]);
    }

    // Inputs uploaded through the test provider and passed as GPU handles give
    // the expected stable-order union and index vectors.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let a = Tensor::new(vec![4.0, 1.0, 2.0], vec![3, 1]).unwrap();
            let b = Tensor::new(vec![2.0, 5.0], vec![2, 1]).unwrap();
            let view_a = HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let view_b = HostTensorView {
                data: &b.data,
                shape: &b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload A");
            let handle_b = provider.upload(&view_b).expect("upload B");
            let eval = evaluate_sync(
                Value::GpuTensor(handle_a),
                Value::GpuTensor(handle_b),
                &[Value::from("stable")],
            )
            .expect("union");
            let values = tensor::value_into_tensor_for("union", eval.values_value()).unwrap();
            assert_eq!(values.data, vec![4.0, 1.0, 2.0, 5.0]);
            let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
            assert_eq!(ia.data, vec![1.0, 2.0, 3.0]);
            let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
            assert_eq!(ib.data, vec![2.0]);
        });
    }

    // The "legacy" option is rejected, and the error message names it.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_rejects_legacy_option() {
        let tensor =
            Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).expect("tensor construction failed");
        let err = error_message(
            evaluate_sync(
                Value::Tensor(tensor.clone()),
                Value::Tensor(tensor),
                &[Value::from("legacy")],
            )
            .unwrap_err(),
        );
        assert!(err.contains("legacy"));
    }

    // "rows" with a 2x2 and a 3x1 input fails with a column-count error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_rows_dimension_mismatch() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = error_message(
            evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")]).unwrap_err(),
        );
        assert!(err.contains("same number of columns"));
    }

    // Calling `union_host` directly with a numeric tensor and a char array
    // fails with an "unsupported input type" error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_requires_matching_types() {
        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
        let err = error_message(
            union_host(
                Value::Tensor(a),
                Value::CharArray(b),
                &UnionOptions {
                    rows: false,
                    order: UnionOrder::Sorted,
                },
            )
            .unwrap_err(),
        );
        assert!(err.contains("unsupported input type"));
    }

    // An integer scalar and a double scalar are accepted and produce a 2x1
    // numeric tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn union_accepts_scalar_inputs() {
        let eval =
            evaluate_sync(Value::Int(IntValue::I32(1)), Value::Num(3.0), &[]).expect("union");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 3.0]);
                assert_eq!(t.shape, vec![2, 1]);
            }
            other => panic!("expected numeric tensor, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("union", eval.ia_value()).unwrap();
        assert_eq!(ia.data, vec![1.0]);
        let ib = tensor::value_into_tensor_for("union", eval.ib_value()).unwrap();
        assert_eq!(ib.data, vec![1.0]);
    }

    // Only built with the `wgpu` feature: the same inputs evaluated from host
    // tensors and from uploaded GPU handles must give identical results.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn union_wgpu_matches_cpu() {
        // The registration result is deliberately ignored.
        // NOTE(review): presumably because another test may already have
        // registered the provider — confirm.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();

        let cpu_eval =
            evaluate_sync(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[]).expect("union");
        let cpu_values = tensor::value_into_tensor_for("union", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("union", cpu_eval.ia_value()).unwrap();
        let cpu_ib = tensor::value_into_tensor_for("union", cpu_eval.ib_value()).unwrap();

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
            .expect("union");
        let gpu_values = tensor::value_into_tensor_for("union", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("union", gpu_eval.ia_value()).unwrap();
        let gpu_ib = tensor::value_into_tensor_for("union", gpu_eval.ib_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_values.shape, cpu_values.shape);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ib.data, cpu_ib.data);
    }
}
1829}