1use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::GpuTensorHandle;
12use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
13use runmat_macros::runtime_builtin;
14
15use super::type_resolvers::set_values_output_type;
16use crate::build_runtime_error;
17use crate::builtins::common::arg_tokens::tokens_from_values;
18use crate::builtins::common::gpu_helpers;
19use crate::builtins::common::random_args::complex_tensor_into_value;
20use crate::builtins::common::spec::{
21 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
23};
24use crate::builtins::common::tensor;
25
/// GPU capability metadata for the `intersect` builtin.
///
/// Declares a custom op / provider hook named `"intersect"` for F32 and F64 data
/// with a `GatherImmediately` residency policy. The visible `evaluate` path does
/// not invoke a provider hook: GPU operands are gathered with
/// `gpu_helpers::gather_tensor_async` and intersected on the host.
#[runmat_macros::register_gpu_spec(
    builtin_path = "crate::builtins::array::sorting_sets::intersect"
)]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "intersect",
    op_kind: GpuOpKind::Custom("intersect"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("intersect")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes:
        "Providers may expose a dedicated intersect hook; otherwise tensors are gathered and processed on the host.",
};
44
/// Fusion metadata for `intersect`.
///
/// No elementwise or reduction template is provided (`elementwise`/`reduction`
/// are `None`), so `intersect` is not fused with neighbouring operations; per the
/// notes it materialises its inputs and ends a fusion chain.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::intersect"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "intersect",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`intersect` materialises its inputs and terminates fusion chains; upstream GPU tensors are gathered when necessary.",
};
57
58fn intersect_error(message: impl Into<String>) -> crate::RuntimeError {
59 build_runtime_error(message)
60 .with_builtin("intersect")
61 .build()
62}
63
64#[runtime_builtin(
65 name = "intersect",
66 category = "array/sorting_sets",
67 summary = "Return common elements or rows across arrays with MATLAB-compatible ordering and index outputs.",
68 keywords = "intersect,set,stable,rows,indices,gpu",
69 accel = "array_construct",
70 sink = true,
71 type_resolver(set_values_output_type),
72 builtin_path = "crate::builtins::array::sorting_sets::intersect"
73)]
74async fn intersect_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
75 let eval = evaluate(a, b, &rest).await?;
76 if let Some(out_count) = crate::output_count::current_output_count() {
77 if out_count == 0 {
78 return Ok(Value::OutputList(Vec::new()));
79 }
80 if out_count == 1 {
81 return Ok(Value::OutputList(vec![eval.into_values_value()]));
82 }
83 if out_count == 2 {
84 let (values, ia) = eval.into_pair();
85 return Ok(Value::OutputList(vec![values, ia]));
86 }
87 let (values, ia, ib) = eval.into_triple();
88 return Ok(crate::output_count::output_list_with_padding(
89 out_count,
90 vec![values, ia, ib],
91 ));
92 }
93 Ok(eval.into_values_value())
94}
95
96pub async fn evaluate(
98 a: Value,
99 b: Value,
100 rest: &[Value],
101) -> crate::BuiltinResult<IntersectEvaluation> {
102 let opts = parse_options(rest)?;
103 match (a, b) {
104 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
105 intersect_gpu_pair(handle_a, handle_b, &opts).await
106 }
107 (Value::GpuTensor(handle_a), other) => {
108 intersect_gpu_mixed(handle_a, other, &opts, true).await
109 }
110 (other, Value::GpuTensor(handle_b)) => {
111 intersect_gpu_mixed(handle_b, other, &opts, false).await
112 }
113 (left, right) => intersect_host(left, right, &opts),
114 }
115}
116
/// Output ordering requested for `intersect`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IntersectOrder {
    /// Default (`'sorted'`): outputs ordered by value / row comparison.
    Sorted,
    /// `'stable'`: outputs in order of first appearance in the first input.
    Stable,
}
122
/// Parsed trailing option flags for `intersect`.
#[derive(Debug, Clone)]
struct IntersectOptions {
    /// `true` when `'rows'` was given: compare whole rows instead of elements.
    rows: bool,
    /// Output ordering; `Sorted` unless `'stable'` was given.
    order: IntersectOrder,
}
128
129fn parse_options(rest: &[Value]) -> crate::BuiltinResult<IntersectOptions> {
130 let mut opts = IntersectOptions {
131 rows: false,
132 order: IntersectOrder::Sorted,
133 };
134 let mut seen_order: Option<IntersectOrder> = None;
135
136 let tokens = tokens_from_values(rest);
137 for (arg, token) in rest.iter().zip(tokens.iter()) {
138 let text = match token {
139 crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
140 _ => {
141 let text = tensor::value_to_string(arg).ok_or_else(|| {
142 intersect_error("intersect: expected string option arguments")
143 })?;
144 let lowered = text.trim().to_ascii_lowercase();
145 parse_intersect_option(&mut opts, &mut seen_order, &lowered)?;
146 continue;
147 }
148 };
149 parse_intersect_option(&mut opts, &mut seen_order, text)?;
150 }
151
152 Ok(opts)
153}
154
155fn parse_intersect_option(
156 opts: &mut IntersectOptions,
157 seen_order: &mut Option<IntersectOrder>,
158 lowered: &str,
159) -> crate::BuiltinResult<()> {
160 match lowered {
161 "rows" => opts.rows = true,
162 "sorted" => {
163 if let Some(prev) = seen_order {
164 if *prev != IntersectOrder::Sorted {
165 return Err(intersect_error(
166 "intersect: cannot combine 'sorted' with 'stable'",
167 ));
168 }
169 }
170 *seen_order = Some(IntersectOrder::Sorted);
171 opts.order = IntersectOrder::Sorted;
172 }
173 "stable" => {
174 if let Some(prev) = seen_order {
175 if *prev != IntersectOrder::Stable {
176 return Err(intersect_error(
177 "intersect: cannot combine 'sorted' with 'stable'",
178 ));
179 }
180 }
181 *seen_order = Some(IntersectOrder::Stable);
182 opts.order = IntersectOrder::Stable;
183 }
184 "legacy" | "r2012a" => {
185 return Err(intersect_error(
186 "intersect: the 'legacy' behaviour is not supported",
187 ));
188 }
189 other => {
190 return Err(intersect_error(format!(
191 "intersect: unrecognised option '{other}'"
192 )))
193 }
194 }
195 Ok(())
196}
197
198async fn intersect_gpu_pair(
199 handle_a: GpuTensorHandle,
200 handle_b: GpuTensorHandle,
201 opts: &IntersectOptions,
202) -> crate::BuiltinResult<IntersectEvaluation> {
203 let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
204 let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
205 intersect_numeric(tensor_a, tensor_b, opts)
206}
207
208async fn intersect_gpu_mixed(
209 handle_gpu: GpuTensorHandle,
210 other: Value,
211 opts: &IntersectOptions,
212 gpu_is_a: bool,
213) -> crate::BuiltinResult<IntersectEvaluation> {
214 let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
215 let tensor_other =
216 tensor::value_into_tensor_for("intersect", other).map_err(|e| intersect_error(e))?;
217 if gpu_is_a {
218 intersect_numeric(tensor_gpu, tensor_other, opts)
219 } else {
220 intersect_numeric(tensor_other, tensor_gpu, opts)
221 }
222}
223
224fn intersect_host(
225 a: Value,
226 b: Value,
227 opts: &IntersectOptions,
228) -> crate::BuiltinResult<IntersectEvaluation> {
229 match (a, b) {
230 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => intersect_complex(at, bt, opts),
231 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
232 let bt = scalar_complex_tensor(re, im)?;
233 intersect_complex(at, bt, opts)
234 }
235 (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
236 let at = scalar_complex_tensor(re, im)?;
237 intersect_complex(at, bt, opts)
238 }
239 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
240 let at = scalar_complex_tensor(a_re, a_im)?;
241 let bt = scalar_complex_tensor(b_re, b_im)?;
242 intersect_complex(at, bt, opts)
243 }
244 (Value::ComplexTensor(at), other) => {
245 let bt = value_into_complex_tensor(other)?;
246 intersect_complex(at, bt, opts)
247 }
248 (other, Value::ComplexTensor(bt)) => {
249 let at = value_into_complex_tensor(other)?;
250 intersect_complex(at, bt, opts)
251 }
252 (Value::Complex(re, im), other) => {
253 let at = scalar_complex_tensor(re, im)?;
254 let bt = value_into_complex_tensor(other)?;
255 intersect_complex(at, bt, opts)
256 }
257 (other, Value::Complex(re, im)) => {
258 let at = value_into_complex_tensor(other)?;
259 let bt = scalar_complex_tensor(re, im)?;
260 intersect_complex(at, bt, opts)
261 }
262
263 (Value::CharArray(ac), Value::CharArray(bc)) => intersect_char(ac, bc, opts),
264
265 (Value::StringArray(astring), Value::StringArray(bstring)) => {
266 intersect_string(astring, bstring, opts)
267 }
268 (Value::StringArray(astring), Value::String(b)) => {
269 let bstring = StringArray::new(vec![b], vec![1, 1])
270 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
271 intersect_string(astring, bstring, opts)
272 }
273 (Value::String(a), Value::StringArray(bstring)) => {
274 let astring = StringArray::new(vec![a], vec![1, 1])
275 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
276 intersect_string(astring, bstring, opts)
277 }
278 (Value::String(a), Value::String(b)) => {
279 let astring = StringArray::new(vec![a], vec![1, 1])
280 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
281 let bstring = StringArray::new(vec![b], vec![1, 1])
282 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
283 intersect_string(astring, bstring, opts)
284 }
285
286 (left, right) => {
287 let tensor_a =
288 tensor::value_into_tensor_for("intersect", left).map_err(|e| intersect_error(e))?;
289 let tensor_b = tensor::value_into_tensor_for("intersect", right)
290 .map_err(|e| intersect_error(e))?;
291 intersect_numeric(tensor_a, tensor_b, opts)
292 }
293 }
294}
295
296fn intersect_numeric(
297 a: Tensor,
298 b: Tensor,
299 opts: &IntersectOptions,
300) -> crate::BuiltinResult<IntersectEvaluation> {
301 if opts.rows {
302 intersect_numeric_rows(a, b, opts)
303 } else {
304 intersect_numeric_elements(a, b, opts)
305 }
306}
307
308fn intersect_numeric_elements(
309 a: Tensor,
310 b: Tensor,
311 opts: &IntersectOptions,
312) -> crate::BuiltinResult<IntersectEvaluation> {
313 let mut b_map: HashMap<u64, usize> = HashMap::new();
314 for (idx, &value) in b.data.iter().enumerate() {
315 let key = canonicalize_f64(value);
316 b_map.entry(key).or_insert(idx);
317 }
318
319 let mut seen: HashSet<u64> = HashSet::new();
320 let mut entries = Vec::<NumericIntersectEntry>::new();
321 let mut order_counter = 0usize;
322
323 for (idx, &value) in a.data.iter().enumerate() {
324 let key = canonicalize_f64(value);
325 if seen.contains(&key) {
326 continue;
327 }
328 if let Some(&b_idx) = b_map.get(&key) {
329 entries.push(NumericIntersectEntry {
330 value,
331 a_index: idx,
332 b_index: b_idx,
333 order_rank: order_counter,
334 });
335 seen.insert(key);
336 order_counter += 1;
337 }
338 }
339
340 assemble_numeric_intersect(entries, opts)
341}
342
343fn intersect_numeric_rows(
344 a: Tensor,
345 b: Tensor,
346 opts: &IntersectOptions,
347) -> crate::BuiltinResult<IntersectEvaluation> {
348 if a.shape.len() != 2 || b.shape.len() != 2 {
349 return Err(intersect_error(
350 "intersect: 'rows' option requires 2-D numeric matrices",
351 ));
352 }
353 if a.shape[1] != b.shape[1] {
354 return Err(intersect_error(
355 "intersect: inputs must have the same number of columns when using 'rows'",
356 ));
357 }
358 let rows_a = a.shape[0];
359 let cols = a.shape[1];
360 let rows_b = b.shape[0];
361
362 let mut b_map: HashMap<NumericRowKey, usize> = HashMap::new();
363 for r in 0..rows_b {
364 let mut row_values = Vec::with_capacity(cols);
365 for c in 0..cols {
366 let idx = r + c * rows_b;
367 row_values.push(b.data[idx]);
368 }
369 let key = NumericRowKey::from_slice(&row_values);
370 b_map.entry(key).or_insert(r);
371 }
372
373 let mut seen: HashSet<NumericRowKey> = HashSet::new();
374 let mut entries = Vec::<NumericRowIntersectEntry>::new();
375 let mut order_counter = 0usize;
376
377 for r in 0..rows_a {
378 let mut row_values = Vec::with_capacity(cols);
379 for c in 0..cols {
380 let idx = r + c * rows_a;
381 row_values.push(a.data[idx]);
382 }
383 let key = NumericRowKey::from_slice(&row_values);
384 if seen.contains(&key) {
385 continue;
386 }
387 if let Some(&b_row) = b_map.get(&key) {
388 entries.push(NumericRowIntersectEntry {
389 row_data: row_values,
390 a_row: r,
391 b_row,
392 order_rank: order_counter,
393 });
394 seen.insert(key);
395 order_counter += 1;
396 }
397 }
398
399 assemble_numeric_row_intersect(entries, opts, cols)
400}
401
402fn intersect_complex(
403 a: ComplexTensor,
404 b: ComplexTensor,
405 opts: &IntersectOptions,
406) -> crate::BuiltinResult<IntersectEvaluation> {
407 if opts.rows {
408 intersect_complex_rows(a, b, opts)
409 } else {
410 intersect_complex_elements(a, b, opts)
411 }
412}
413
414fn intersect_complex_elements(
415 a: ComplexTensor,
416 b: ComplexTensor,
417 opts: &IntersectOptions,
418) -> crate::BuiltinResult<IntersectEvaluation> {
419 let mut b_map: HashMap<ComplexKey, usize> = HashMap::new();
420 for (idx, &value) in b.data.iter().enumerate() {
421 let key = ComplexKey::new(value);
422 b_map.entry(key).or_insert(idx);
423 }
424
425 let mut seen: HashSet<ComplexKey> = HashSet::new();
426 let mut entries = Vec::<ComplexIntersectEntry>::new();
427 let mut order_counter = 0usize;
428
429 for (idx, &value) in a.data.iter().enumerate() {
430 let key = ComplexKey::new(value);
431 if seen.contains(&key) {
432 continue;
433 }
434 if let Some(&b_idx) = b_map.get(&key) {
435 entries.push(ComplexIntersectEntry {
436 value,
437 a_index: idx,
438 b_index: b_idx,
439 order_rank: order_counter,
440 });
441 seen.insert(key);
442 order_counter += 1;
443 }
444 }
445
446 assemble_complex_intersect(entries, opts)
447}
448
449fn intersect_complex_rows(
450 a: ComplexTensor,
451 b: ComplexTensor,
452 opts: &IntersectOptions,
453) -> crate::BuiltinResult<IntersectEvaluation> {
454 if a.shape.len() != 2 || b.shape.len() != 2 {
455 return Err(intersect_error(
456 "intersect: 'rows' option requires 2-D complex matrices",
457 ));
458 }
459 if a.shape[1] != b.shape[1] {
460 return Err(intersect_error(
461 "intersect: inputs must have the same number of columns when using 'rows'",
462 ));
463 }
464 let rows_a = a.shape[0];
465 let cols = a.shape[1];
466 let rows_b = b.shape[0];
467
468 let mut b_map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
469 for r in 0..rows_b {
470 let mut row_keys = Vec::with_capacity(cols);
471 for c in 0..cols {
472 let idx = r + c * rows_b;
473 row_keys.push(ComplexKey::new(b.data[idx]));
474 }
475 b_map.entry(row_keys).or_insert(r);
476 }
477
478 let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
479 let mut entries = Vec::<ComplexRowIntersectEntry>::new();
480 let mut order_counter = 0usize;
481
482 for r in 0..rows_a {
483 let mut row_values = Vec::with_capacity(cols);
484 let mut row_keys = Vec::with_capacity(cols);
485 for c in 0..cols {
486 let idx = r + c * rows_a;
487 let value = a.data[idx];
488 row_values.push(value);
489 row_keys.push(ComplexKey::new(value));
490 }
491 if seen.contains(&row_keys) {
492 continue;
493 }
494 if let Some(&b_row) = b_map.get(&row_keys) {
495 entries.push(ComplexRowIntersectEntry {
496 row_data: row_values,
497 a_row: r,
498 b_row,
499 order_rank: order_counter,
500 });
501 seen.insert(row_keys);
502 order_counter += 1;
503 }
504 }
505
506 assemble_complex_row_intersect(entries, opts, cols)
507}
508
509fn intersect_char(
510 a: CharArray,
511 b: CharArray,
512 opts: &IntersectOptions,
513) -> crate::BuiltinResult<IntersectEvaluation> {
514 if opts.rows {
515 intersect_char_rows(a, b, opts)
516 } else {
517 intersect_char_elements(a, b, opts)
518 }
519}
520
521fn intersect_char_elements(
522 a: CharArray,
523 b: CharArray,
524 opts: &IntersectOptions,
525) -> crate::BuiltinResult<IntersectEvaluation> {
526 let mut seen: HashSet<u32> = HashSet::new();
527 let mut entries = Vec::<CharIntersectEntry>::new();
528 let mut order_counter = 0usize;
529
530 for col in 0..a.cols {
531 for row in 0..a.rows {
532 let linear_idx = row + col * a.rows;
533 let data_idx = row * a.cols + col;
534 let ch = a.data[data_idx];
535 let key = ch as u32;
536 if seen.contains(&key) {
537 continue;
538 }
539 if let Some(b_idx) = find_char_index(&b, ch) {
540 entries.push(CharIntersectEntry {
541 ch,
542 a_index: linear_idx,
543 b_index: b_idx,
544 order_rank: order_counter,
545 });
546 seen.insert(key);
547 order_counter += 1;
548 }
549 }
550 }
551
552 assemble_char_intersect(entries, opts, &b)
553}
554
555fn intersect_char_rows(
556 a: CharArray,
557 b: CharArray,
558 opts: &IntersectOptions,
559) -> crate::BuiltinResult<IntersectEvaluation> {
560 if a.cols != b.cols {
561 return Err(intersect_error(
562 "intersect: inputs must have the same number of columns when using 'rows'",
563 ));
564 }
565 let rows_a = a.rows;
566 let rows_b = b.rows;
567 let cols = a.cols;
568
569 let mut b_map: HashMap<RowCharKey, usize> = HashMap::new();
570 for r in 0..rows_b {
571 let mut row_values = Vec::with_capacity(cols);
572 for c in 0..cols {
573 let idx = r * cols + c;
574 row_values.push(b.data[idx]);
575 }
576 let key = RowCharKey::from_slice(&row_values);
577 b_map.entry(key).or_insert(r);
578 }
579
580 let mut seen: HashSet<RowCharKey> = HashSet::new();
581 let mut entries = Vec::<CharRowIntersectEntry>::new();
582 let mut order_counter = 0usize;
583
584 for r in 0..rows_a {
585 let mut row_values = Vec::with_capacity(cols);
586 for c in 0..cols {
587 let idx = r * cols + c;
588 row_values.push(a.data[idx]);
589 }
590 let key = RowCharKey::from_slice(&row_values);
591 if seen.contains(&key) {
592 continue;
593 }
594 if let Some(&b_row) = b_map.get(&key) {
595 entries.push(CharRowIntersectEntry {
596 row_data: row_values,
597 a_row: r,
598 b_row,
599 order_rank: order_counter,
600 });
601 seen.insert(key);
602 order_counter += 1;
603 }
604 }
605
606 assemble_char_row_intersect(entries, opts, cols)
607}
608
609fn find_char_index(array: &CharArray, target: char) -> Option<usize> {
610 for col in 0..array.cols {
611 for row in 0..array.rows {
612 let data_idx = row * array.cols + col;
613 if array.data[data_idx] == target {
614 return Some(row + col * array.rows);
615 }
616 }
617 }
618 None
619}
620
621fn intersect_string(
622 a: StringArray,
623 b: StringArray,
624 opts: &IntersectOptions,
625) -> crate::BuiltinResult<IntersectEvaluation> {
626 if opts.rows {
627 intersect_string_rows(a, b, opts)
628 } else {
629 intersect_string_elements(a, b, opts)
630 }
631}
632
633fn intersect_string_elements(
634 a: StringArray,
635 b: StringArray,
636 opts: &IntersectOptions,
637) -> crate::BuiltinResult<IntersectEvaluation> {
638 let mut b_map: HashMap<String, usize> = HashMap::new();
639 for (idx, value) in b.data.iter().enumerate() {
640 b_map.entry(value.clone()).or_insert(idx);
641 }
642
643 let mut seen: HashSet<String> = HashSet::new();
644 let mut entries = Vec::<StringIntersectEntry>::new();
645 let mut order_counter = 0usize;
646
647 for (idx, value) in a.data.iter().enumerate() {
648 if seen.contains(value) {
649 continue;
650 }
651 if let Some(&b_idx) = b_map.get(value) {
652 entries.push(StringIntersectEntry {
653 value: value.clone(),
654 a_index: idx,
655 b_index: b_idx,
656 order_rank: order_counter,
657 });
658 seen.insert(value.clone());
659 order_counter += 1;
660 }
661 }
662
663 assemble_string_intersect(entries, opts)
664}
665
666fn intersect_string_rows(
667 a: StringArray,
668 b: StringArray,
669 opts: &IntersectOptions,
670) -> crate::BuiltinResult<IntersectEvaluation> {
671 if a.shape.len() != 2 || b.shape.len() != 2 {
672 return Err(intersect_error(
673 "intersect: 'rows' option requires 2-D string arrays",
674 ));
675 }
676 if a.shape[1] != b.shape[1] {
677 return Err(intersect_error(
678 "intersect: inputs must have the same number of columns when using 'rows'",
679 ));
680 }
681 let rows_a = a.shape[0];
682 let cols = a.shape[1];
683 let rows_b = b.shape[0];
684
685 let mut b_map: HashMap<RowStringKey, usize> = HashMap::new();
686 for r in 0..rows_b {
687 let mut row_values = Vec::with_capacity(cols);
688 for c in 0..cols {
689 let idx = r + c * rows_b;
690 row_values.push(b.data[idx].clone());
691 }
692 let key = RowStringKey::from_slice(&row_values);
693 b_map.entry(key).or_insert(r);
694 }
695
696 let mut seen: HashSet<RowStringKey> = HashSet::new();
697 let mut entries = Vec::<StringRowIntersectEntry>::new();
698 let mut order_counter = 0usize;
699
700 for r in 0..rows_a {
701 let mut row_values = Vec::with_capacity(cols);
702 for c in 0..cols {
703 let idx = r + c * rows_a;
704 row_values.push(a.data[idx].clone());
705 }
706 let key = RowStringKey::from_slice(&row_values);
707 if seen.contains(&key) {
708 continue;
709 }
710 if let Some(&b_row) = b_map.get(&key) {
711 entries.push(StringRowIntersectEntry {
712 row_data: row_values,
713 a_row: r,
714 b_row,
715 order_rank: order_counter,
716 });
717 seen.insert(key);
718 order_counter += 1;
719 }
720 }
721
722 assemble_string_row_intersect(entries, opts, cols)
723}
724
/// Result of an `intersect` evaluation: the common values plus both index vectors.
#[derive(Debug, Clone)]
pub struct IntersectEvaluation {
    /// The `C` output (numeric, complex, char, or string, depending on the inputs).
    values: Value,
    /// 1-based positions in the first input (`[n, 1]` as built by the assemble helpers).
    ia: Tensor,
    /// 1-based positions in the second input (`[n, 1]` as built by the assemble helpers).
    ib: Tensor,
}
731
732impl IntersectEvaluation {
733 fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
734 Self { values, ia, ib }
735 }
736
737 pub fn into_values_value(self) -> Value {
738 self.values
739 }
740
741 pub fn into_pair(self) -> (Value, Value) {
742 let ia = tensor::tensor_into_value(self.ia);
743 (self.values, ia)
744 }
745
746 pub fn into_triple(self) -> (Value, Value, Value) {
747 let ia = tensor::tensor_into_value(self.ia);
748 let ib = tensor::tensor_into_value(self.ib);
749 (self.values, ia, ib)
750 }
751
752 pub fn values_value(&self) -> Value {
753 self.values.clone()
754 }
755
756 pub fn ia_value(&self) -> Value {
757 tensor::tensor_into_value(self.ia.clone())
758 }
759
760 pub fn ib_value(&self) -> Value {
761 tensor::tensor_into_value(self.ib.clone())
762 }
763}
764
// Match records collected by the `intersect_*` helpers and consumed by the
// `assemble_*` functions. In every record, indices are 0-based (the assemble
// step adds 1). `order_rank` is the position in which the match was first found
// while scanning the first input, which is what `'stable'` ordering sorts by.

/// One common real value.
#[derive(Debug)]
struct NumericIntersectEntry {
    /// The value as it appears in the first input.
    value: f64,
    /// First linear index of the value in the first input.
    a_index: usize,
    /// First linear index of the value in the second input.
    b_index: usize,
    /// First-appearance rank within the first input.
    order_rank: usize,
}

/// One common real row (`'rows'` mode).
#[derive(Debug)]
struct NumericRowIntersectEntry {
    /// Row contents, one value per column.
    row_data: Vec<f64>,
    /// First matching row in the first input.
    a_row: usize,
    /// First matching row in the second input.
    b_row: usize,
    /// First-appearance rank within the first input.
    order_rank: usize,
}

/// One common complex value, stored as `(re, im)`.
#[derive(Debug)]
struct ComplexIntersectEntry {
    value: (f64, f64),
    a_index: usize,
    b_index: usize,
    order_rank: usize,
}

/// One common complex row (`'rows'` mode).
#[derive(Debug)]
struct ComplexRowIntersectEntry {
    row_data: Vec<(f64, f64)>,
    a_row: usize,
    b_row: usize,
    order_rank: usize,
}

/// One common character; indices are column-major linear indices.
#[derive(Debug)]
struct CharIntersectEntry {
    ch: char,
    a_index: usize,
    b_index: usize,
    order_rank: usize,
}

/// One common char row (`'rows'` mode).
#[derive(Debug)]
struct CharRowIntersectEntry {
    row_data: Vec<char>,
    a_row: usize,
    b_row: usize,
    order_rank: usize,
}

/// One common string.
#[derive(Debug)]
struct StringIntersectEntry {
    value: String,
    a_index: usize,
    b_index: usize,
    order_rank: usize,
}

/// One common string row (`'rows'` mode).
#[derive(Debug)]
struct StringRowIntersectEntry {
    row_data: Vec<String>,
    a_row: usize,
    b_row: usize,
    order_rank: usize,
}
828
829fn assemble_numeric_intersect(
830 entries: Vec<NumericIntersectEntry>,
831 opts: &IntersectOptions,
832) -> crate::BuiltinResult<IntersectEvaluation> {
833 let mut order: Vec<usize> = (0..entries.len()).collect();
834 match opts.order {
835 IntersectOrder::Sorted => {
836 order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
837 }
838 IntersectOrder::Stable => {
839 order.sort_by_key(|&idx| entries[idx].order_rank);
840 }
841 }
842
843 let mut values = Vec::with_capacity(order.len());
844 let mut ia = Vec::with_capacity(order.len());
845 let mut ib = Vec::with_capacity(order.len());
846 for &idx in &order {
847 let entry = &entries[idx];
848 values.push(entry.value);
849 ia.push((entry.a_index + 1) as f64);
850 ib.push((entry.b_index + 1) as f64);
851 }
852
853 let value_tensor = Tensor::new(values, vec![order.len(), 1])
854 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
855 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
856 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
857 let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
858 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
859
860 Ok(IntersectEvaluation::new(
861 tensor::tensor_into_value(value_tensor),
862 ia_tensor,
863 ib_tensor,
864 ))
865}
866
867fn assemble_numeric_row_intersect(
868 entries: Vec<NumericRowIntersectEntry>,
869 opts: &IntersectOptions,
870 cols: usize,
871) -> crate::BuiltinResult<IntersectEvaluation> {
872 let mut order: Vec<usize> = (0..entries.len()).collect();
873 match opts.order {
874 IntersectOrder::Sorted => {
875 order.sort_by(|&lhs, &rhs| {
876 compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
877 });
878 }
879 IntersectOrder::Stable => {
880 order.sort_by_key(|&idx| entries[idx].order_rank);
881 }
882 }
883
884 let rows_out = order.len();
885 let mut values = vec![0.0f64; rows_out * cols];
886 let mut ia = Vec::with_capacity(rows_out);
887 let mut ib = Vec::with_capacity(rows_out);
888
889 for (row_pos, &entry_idx) in order.iter().enumerate() {
890 let entry = &entries[entry_idx];
891 for col in 0..cols {
892 let dest = row_pos + col * rows_out;
893 values[dest] = entry.row_data[col];
894 }
895 ia.push((entry.a_row + 1) as f64);
896 ib.push((entry.b_row + 1) as f64);
897 }
898
899 let value_tensor = Tensor::new(values, vec![rows_out, cols])
900 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
901 let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
902 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
903 let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
904 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
905
906 Ok(IntersectEvaluation::new(
907 tensor::tensor_into_value(value_tensor),
908 ia_tensor,
909 ib_tensor,
910 ))
911}
912
913fn assemble_complex_intersect(
914 entries: Vec<ComplexIntersectEntry>,
915 opts: &IntersectOptions,
916) -> crate::BuiltinResult<IntersectEvaluation> {
917 let mut order: Vec<usize> = (0..entries.len()).collect();
918 match opts.order {
919 IntersectOrder::Sorted => {
920 order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
921 }
922 IntersectOrder::Stable => {
923 order.sort_by_key(|&idx| entries[idx].order_rank);
924 }
925 }
926
927 let mut values = Vec::with_capacity(order.len());
928 let mut ia = Vec::with_capacity(order.len());
929 let mut ib = Vec::with_capacity(order.len());
930 for &idx in &order {
931 let entry = &entries[idx];
932 values.push(entry.value);
933 ia.push((entry.a_index + 1) as f64);
934 ib.push((entry.b_index + 1) as f64);
935 }
936
937 let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
938 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
939 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
940 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
941 let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
942 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
943
944 Ok(IntersectEvaluation::new(
945 complex_tensor_into_value(value_tensor),
946 ia_tensor,
947 ib_tensor,
948 ))
949}
950
951fn assemble_complex_row_intersect(
952 entries: Vec<ComplexRowIntersectEntry>,
953 opts: &IntersectOptions,
954 cols: usize,
955) -> crate::BuiltinResult<IntersectEvaluation> {
956 let mut order: Vec<usize> = (0..entries.len()).collect();
957 match opts.order {
958 IntersectOrder::Sorted => {
959 order.sort_by(|&lhs, &rhs| {
960 compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
961 });
962 }
963 IntersectOrder::Stable => {
964 order.sort_by_key(|&idx| entries[idx].order_rank);
965 }
966 }
967
968 let rows_out = order.len();
969 let mut values = vec![(0.0f64, 0.0f64); rows_out * cols];
970 let mut ia = Vec::with_capacity(rows_out);
971 let mut ib = Vec::with_capacity(rows_out);
972
973 for (row_pos, &entry_idx) in order.iter().enumerate() {
974 let entry = &entries[entry_idx];
975 for col in 0..cols {
976 let dest = row_pos + col * rows_out;
977 values[dest] = entry.row_data[col];
978 }
979 ia.push((entry.a_row + 1) as f64);
980 ib.push((entry.b_row + 1) as f64);
981 }
982
983 let value_tensor = ComplexTensor::new(values, vec![rows_out, cols])
984 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
985 let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
986 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
987 let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
988 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
989
990 Ok(IntersectEvaluation::new(
991 complex_tensor_into_value(value_tensor),
992 ia_tensor,
993 ib_tensor,
994 ))
995}
996
997fn assemble_char_intersect(
998 entries: Vec<CharIntersectEntry>,
999 opts: &IntersectOptions,
1000 b: &CharArray,
1001) -> crate::BuiltinResult<IntersectEvaluation> {
1002 let mut order: Vec<usize> = (0..entries.len()).collect();
1003 match opts.order {
1004 IntersectOrder::Sorted => {
1005 order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1006 }
1007 IntersectOrder::Stable => {
1008 order.sort_by_key(|&idx| entries[idx].order_rank);
1009 }
1010 }
1011
1012 let mut values = Vec::with_capacity(order.len());
1013 let mut ia = Vec::with_capacity(order.len());
1014 let mut ib = Vec::with_capacity(order.len());
1015 for &idx in &order {
1016 let entry = &entries[idx];
1017 values.push(entry.ch);
1018 ia.push((entry.a_index + 1) as f64);
1019 let b_idx = find_char_index(b, entry.ch).unwrap_or(entry.b_index);
1020 ib.push((b_idx + 1) as f64);
1021 }
1022
1023 let value_array = CharArray::new(values, order.len(), 1)
1024 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1025 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1026 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1027 let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1028 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1029
1030 Ok(IntersectEvaluation::new(
1031 Value::CharArray(value_array),
1032 ia_tensor,
1033 ib_tensor,
1034 ))
1035}
1036
1037fn assemble_char_row_intersect(
1038 entries: Vec<CharRowIntersectEntry>,
1039 opts: &IntersectOptions,
1040 cols: usize,
1041) -> crate::BuiltinResult<IntersectEvaluation> {
1042 let mut order: Vec<usize> = (0..entries.len()).collect();
1043 match opts.order {
1044 IntersectOrder::Sorted => {
1045 order.sort_by(|&lhs, &rhs| {
1046 compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1047 });
1048 }
1049 IntersectOrder::Stable => {
1050 order.sort_by_key(|&idx| entries[idx].order_rank);
1051 }
1052 }
1053
1054 let rows_out = order.len();
1055 let mut values = vec!['\0'; rows_out * cols];
1056 let mut ia = Vec::with_capacity(rows_out);
1057 let mut ib = Vec::with_capacity(rows_out);
1058
1059 for (row_pos, &entry_idx) in order.iter().enumerate() {
1060 let entry = &entries[entry_idx];
1061 for col in 0..cols {
1062 let dest = row_pos * cols + col;
1063 values[dest] = entry.row_data[col];
1064 }
1065 ia.push((entry.a_row + 1) as f64);
1066 ib.push((entry.b_row + 1) as f64);
1067 }
1068
1069 let value_array = CharArray::new(values, rows_out, cols)
1070 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1071 let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1072 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1073 let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1074 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1075
1076 Ok(IntersectEvaluation::new(
1077 Value::CharArray(value_array),
1078 ia_tensor,
1079 ib_tensor,
1080 ))
1081}
1082
1083fn assemble_string_intersect(
1084 entries: Vec<StringIntersectEntry>,
1085 opts: &IntersectOptions,
1086) -> crate::BuiltinResult<IntersectEvaluation> {
1087 let mut order: Vec<usize> = (0..entries.len()).collect();
1088 match opts.order {
1089 IntersectOrder::Sorted => {
1090 order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1091 }
1092 IntersectOrder::Stable => {
1093 order.sort_by_key(|&idx| entries[idx].order_rank);
1094 }
1095 }
1096
1097 let mut values = Vec::with_capacity(order.len());
1098 let mut ia = Vec::with_capacity(order.len());
1099 let mut ib = Vec::with_capacity(order.len());
1100 for &idx in &order {
1101 let entry = &entries[idx];
1102 values.push(entry.value.clone());
1103 ia.push((entry.a_index + 1) as f64);
1104 ib.push((entry.b_index + 1) as f64);
1105 }
1106
1107 let value_array = StringArray::new(values, vec![order.len(), 1])
1108 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1109 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1110 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1111 let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1112 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1113
1114 Ok(IntersectEvaluation::new(
1115 Value::StringArray(value_array),
1116 ia_tensor,
1117 ib_tensor,
1118 ))
1119}
1120
1121fn assemble_string_row_intersect(
1122 entries: Vec<StringRowIntersectEntry>,
1123 opts: &IntersectOptions,
1124 cols: usize,
1125) -> crate::BuiltinResult<IntersectEvaluation> {
1126 let mut order: Vec<usize> = (0..entries.len()).collect();
1127 match opts.order {
1128 IntersectOrder::Sorted => {
1129 order.sort_by(|&lhs, &rhs| {
1130 compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1131 });
1132 }
1133 IntersectOrder::Stable => {
1134 order.sort_by_key(|&idx| entries[idx].order_rank);
1135 }
1136 }
1137
1138 let rows_out = order.len();
1139 let mut values = vec![String::new(); rows_out * cols];
1140 let mut ia = Vec::with_capacity(rows_out);
1141 let mut ib = Vec::with_capacity(rows_out);
1142
1143 for (row_pos, &entry_idx) in order.iter().enumerate() {
1144 let entry = &entries[entry_idx];
1145 for col in 0..cols {
1146 let dest = row_pos + col * rows_out;
1147 values[dest] = entry.row_data[col].clone();
1148 }
1149 ia.push((entry.a_row + 1) as f64);
1150 ib.push((entry.b_row + 1) as f64);
1151 }
1152
1153 let value_array = StringArray::new(values, vec![rows_out, cols])
1154 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1155 let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1156 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1157 let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1158 .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1159
1160 Ok(IntersectEvaluation::new(
1161 Value::StringArray(value_array),
1162 ia_tensor,
1163 ib_tensor,
1164 ))
1165}
1166
1167#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1168struct NumericRowKey(Vec<u64>);
1169
1170impl NumericRowKey {
1171 fn from_slice(values: &[f64]) -> Self {
1172 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1173 }
1174}
1175
1176#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1177struct ComplexKey {
1178 re: u64,
1179 im: u64,
1180}
1181
1182impl ComplexKey {
1183 fn new(value: (f64, f64)) -> Self {
1184 Self {
1185 re: canonicalize_f64(value.0),
1186 im: canonicalize_f64(value.1),
1187 }
1188 }
1189}
1190
/// Hashable identity of a character row, stored as Unicode scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key by converting each `char` to its `u32` code point.
    fn from_slice(values: &[char]) -> Self {
        Self(values.iter().copied().map(u32::from).collect())
    }
}
1199
/// Hashable identity of a string row; holds owned copies of every cell.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);

impl RowStringKey {
    /// Builds the key by cloning each cell of `values` in order.
    fn from_slice(values: &[String]) -> Self {
        Self(values.iter().cloned().collect())
    }
}
1208
1209fn scalar_complex_tensor(re: f64, im: f64) -> crate::BuiltinResult<ComplexTensor> {
1210 ComplexTensor::new(vec![(re, im)], vec![1, 1])
1211 .map_err(|e| intersect_error(format!("intersect: {e}")))
1212}
1213
1214fn tensor_to_complex_owned(name: &str, tensor: Tensor) -> crate::BuiltinResult<ComplexTensor> {
1215 let Tensor { data, shape, .. } = tensor;
1216 let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
1217 ComplexTensor::new(complex, shape).map_err(|e| intersect_error(format!("{name}: {e}")))
1218}
1219
1220fn value_into_complex_tensor(value: Value) -> crate::BuiltinResult<ComplexTensor> {
1221 match value {
1222 Value::ComplexTensor(tensor) => Ok(tensor),
1223 Value::Complex(re, im) => scalar_complex_tensor(re, im),
1224 other => {
1225 let tensor = tensor::value_into_tensor_for("intersect", other)
1226 .map_err(|e| intersect_error(e))?;
1227 tensor_to_complex_owned("intersect", tensor)
1228 }
1229 }
1230}
1231
/// Maps an `f64` to a bit pattern suitable for hashing and equality.
///
/// Every NaN collapses to one quiet-NaN pattern, and both `0.0` and `-0.0`
/// become `0`. All other values keep their raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1241
/// Total ordering on `f64` in which NaN sorts after every other value.
///
/// Two NaNs compare equal, and `-0.0` compares equal to `0.0`.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1255
1256fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1257 for (lhs, rhs) in a.iter().zip(b.iter()) {
1258 let ord = compare_f64(*lhs, *rhs);
1259 if ord != Ordering::Equal {
1260 return ord;
1261 }
1262 }
1263 Ordering::Equal
1264}
1265
/// Returns `true` when either the real or the imaginary part is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
1269
1270fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1271 match (complex_is_nan(a), complex_is_nan(b)) {
1272 (true, true) => Ordering::Equal,
1273 (true, false) => Ordering::Greater,
1274 (false, true) => Ordering::Less,
1275 (false, false) => {
1276 let mag_a = a.0.hypot(a.1);
1277 let mag_b = b.0.hypot(b.1);
1278 let mag_cmp = compare_f64(mag_a, mag_b);
1279 if mag_cmp != Ordering::Equal {
1280 return mag_cmp;
1281 }
1282 let re_cmp = compare_f64(a.0, b.0);
1283 if re_cmp != Ordering::Equal {
1284 return re_cmp;
1285 }
1286 compare_f64(a.1, b.1)
1287 }
1288 }
1289}
1290
1291fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1292 for (lhs, rhs) in a.iter().zip(b.iter()) {
1293 let ord = compare_complex(*lhs, *rhs);
1294 if ord != Ordering::Equal {
1295 return ord;
1296 }
1297 }
1298 Ordering::Equal
1299}
1300
/// Lexicographic comparison of two character rows by code point.
///
/// Only the overlapping prefix is compared; a trailing length difference does
/// not affect the result.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1310
/// Lexicographic comparison of two string rows, one cell at a time.
///
/// Cells use `String`'s byte-wise ordering. Only the overlapping prefix is
/// compared.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    for (lhs, rhs) in a.iter().zip(b) {
        match lhs.cmp(rhs) {
            Ordering::Equal => continue,
            decided => return decided,
        }
    }
    Ordering::Equal
}
1320
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for the host `intersect` paths (numeric, complex, char, string, rows),
    // option validation, and GPU round-trips. Every `ia`/`ib` expectation below is
    // 1-based, and the returned indices point at the first matching occurrence.
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ResolveContext, Type};

    // Extracts the error message text so tests can assert on substrings.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    // Runs the async `evaluate` entry point to completion on the current thread.
    fn evaluate_sync(
        a: Value,
        b: Value,
        rest: &[Value],
    ) -> crate::BuiltinResult<IntersectEvaluation> {
        futures::executor::block_on(evaluate(a, b, rest))
    }

    // A = [5;7;5;1], B = [7;1;3]: the common values {1, 7} come back ascending.
    // 1 is A(4)/B(2) and 7 is A(2)/B(1).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_sorted() {
        let a = Tensor::new(vec![5.0, 7.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![7.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.data, vec![1.0, 7.0]);
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![4.0, 2.0]);
        assert_eq!(ib.data, vec![2.0, 1.0]);
    }

    // The output-type resolver maps a tensor argument to a tensor result.
    // NOTE(review): unlike the other tests, this one has no wasm32
    // `wasm_bindgen_test` attribute — confirm whether that is intentional.
    #[test]
    fn intersect_type_resolver_numeric() {
        assert_eq!(
            set_values_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    // Despite the test name, this checks a cell-of-string argument, which must resolve
    // to the same cell-of-string type.
    #[test]
    fn intersect_type_resolver_string_array() {
        assert_eq!(
            set_values_output_type(
                &[Type::cell_of(Type::String)],
                &ResolveContext::new(Vec::new()),
            ),
            Type::cell_of(Type::String)
        );
    }

    // 'stable' keeps the order of first appearance in A: A = [4;2;4;1;3] and
    // B = [3;4;5;1] give [4;1;3], not the sorted [1;3;4].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_stable() {
        let a = Tensor::new(vec![4.0, 2.0, 4.0, 1.0, 3.0], vec![5, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 4.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Stable,
            },
        )
        .expect("intersect");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.data, vec![4.0, 1.0, 3.0]);
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0, 4.0, 5.0]);
        assert_eq!(ib.data, vec![2.0, 4.0, 1.0]);
    }

    // In this implementation NaNs are treated as matching each other, and
    // duplicates collapse. A single NaN is reported, indexed at its first
    // occurrence in A (1) and in B (2).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_handles_nan() {
        let a = Tensor::new(vec![f64::NAN, 1.0, f64::NAN], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.data.len(), 1);
        assert!(values.data[0].is_nan());
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0]);
        assert_eq!(ib.data, vec![2.0]);
    }

    // A real input lifted to complex (zero imaginary part) matches complex
    // values whose imaginary part is zero. The result stays a ComplexTensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_complex_with_real_inputs() {
        let complex =
            ComplexTensor::new(vec![(1.0, 0.0), (2.0, 0.0), (3.0, 1.0)], vec![3, 1]).unwrap();
        let real = Tensor::new(vec![2.0, 4.0, 1.0], vec![3, 1]).unwrap();
        let real_complex = tensor_to_complex_owned("intersect", real).unwrap();
        let eval = intersect_complex(
            complex,
            real_complex,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect complex");
        match eval.values_value() {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(1.0, 0.0), (2.0, 0.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0, 2.0]);
        assert_eq!(ib.data, vec![3.0, 1.0]);
    }

    // The column-major data gives A rows [1 2; 3 4; 1 2] and B rows [1 2; 5 6].
    // Only [1 2] is shared; it is returned as a 1×2 row with ia = ib = 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_rows_default() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 5.0, 2.0, 6.0], vec![2, 2]).unwrap();
        let eval = intersect_numeric_rows(
            a,
            b,
            &IntersectOptions {
                rows: true,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect rows");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.shape, vec![1, 2]);
        assert_eq!(values.data, vec![1.0, 2.0]);
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0]);
        assert_eq!(ib.data, vec![1.0]);
    }

    // Char inputs 'cab' and 'bcd' share 'b' and 'c'. Even though both inputs are
    // 1×3 rows, the result here is a 2×1 char column.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_char_elements_basic() {
        let a = CharArray::new("cab".chars().collect(), 1, 3).unwrap();
        let b = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
        assert_eq!(find_char_index(&b, 'b'), Some(0));
        assert_eq!(find_char_index(&b, 'c'), Some(1));
        let b_for_eval = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
        let eval = intersect_char_elements(
            a,
            b_for_eval,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect char");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 2);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['b', 'c']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![3.0, 1.0]);
        assert_eq!(ib.data, vec![1.0, 2.0]);
    }

    // String arrays with 'stable' order: the shared values keep A's order
    // (orange, pear) in a 2×1 string array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_string_elements_stable() {
        let a = StringArray::new(
            vec!["apple".into(), "orange".into(), "pear".into()],
            vec![3, 1],
        )
        .unwrap();
        let b = StringArray::new(
            vec!["pear".into(), "grape".into(), "orange".into()],
            vec![3, 1],
        )
        .unwrap();
        let eval = intersect_string_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Stable,
            },
        )
        .expect("intersect string");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![2, 1]);
                assert_eq!(arr.data, vec!["orange".to_string(), "pear".to_string()]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![2.0, 3.0]);
        assert_eq!(ib.data, vec![3.0, 1.0]);
    }

    // The unsupported 'legacy' option is rejected through `evaluate`, and the
    // error message names the offending option.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_rejects_legacy_option() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = error_message(
            evaluate_sync(
                Value::Tensor(tensor.clone()),
                Value::Tensor(tensor),
                &[Value::from("legacy")],
            )
            .unwrap_err(),
        );
        assert!(err.contains("legacy"));
    }

    // 'rows' mode requires both inputs to have the same column count
    // (here 2 columns vs 1).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_rows_dimension_mismatch() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = error_message(
            intersect_numeric_rows(
                a,
                b,
                &IntersectOptions {
                    rows: true,
                    order: IntersectOrder::Sorted,
                },
            )
            .unwrap_err(),
        );
        assert!(err.contains("same number of columns"));
    }

    // Mixing a numeric tensor with a char array is rejected by `intersect_host`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_mixed_types_error() {
        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
        let err = error_message(
            intersect_host(
                Value::Tensor(a),
                Value::CharArray(b),
                &IntersectOptions {
                    rows: false,
                    order: IntersectOrder::Sorted,
                },
            )
            .unwrap_err(),
        );
        assert!(err.contains("unsupported input type"));
    }

    // GPU-resident inputs, uploaded through the test provider, yield the same
    // results as the host path. Duplicate 1s in A report the first occurrence
    // (ia = 2).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let a = Tensor::new(vec![4.0, 1.0, 2.0, 1.0], vec![4, 1]).unwrap();
            let b = Tensor::new(vec![2.0, 5.0, 1.0], vec![3, 1]).unwrap();
            let view_a = HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let view_b = HostTensorView {
                data: &b.data,
                shape: &b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload A");
            let handle_b = provider.upload(&view_b).expect("upload B");
            let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("intersect");
            let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
            assert_eq!(values.data, vec![1.0, 2.0]);
            let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
            let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
            assert_eq!(ia.data, vec![2.0, 3.0]);
            assert_eq!(ib.data, vec![3.0, 1.0]);
        });
    }

    // Checks the multi-output accessors `into_pair` (C, ia) and `into_triple`
    // (C, ia, ib) on one evaluation. Despite the name, the evaluation comes from
    // `intersect_numeric_elements`, not `evaluate`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_two_outputs_from_evaluate() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .unwrap();
        let (_c, ia) = eval.clone().into_pair();
        let ia_tensor = tensor::value_into_tensor_for("intersect", ia).unwrap();
        assert_eq!(ia_tensor.data, vec![1.0, 3.0]);
        let (_c, ia2, ib2) = eval.into_triple();
        let ia_tensor2 = tensor::value_into_tensor_for("intersect", ia2).unwrap();
        let ib_tensor2 = tensor::value_into_tensor_for("intersect", ib2).unwrap();
        assert_eq!(ia_tensor2.data, vec![1.0, 3.0]);
        assert_eq!(ib_tensor2.data, vec![2.0, 1.0]);
    }

    // With the `wgpu` feature, registers the WGPU provider and checks that the
    // GPU-input path reproduces the CPU values, ia, and ib exactly.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn intersect_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();

        let cpu_eval = intersect_numeric_elements(
            a.clone(),
            b.clone(),
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .unwrap();
        let cpu_values =
            tensor::value_into_tensor_for("intersect", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("intersect", cpu_eval.ia_value()).unwrap();
        let cpu_ib = tensor::value_into_tensor_for("intersect", cpu_eval.ib_value()).unwrap();

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
            .expect("intersect");
        let gpu_values =
            tensor::value_into_tensor_for("intersect", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("intersect", gpu_eval.ia_value()).unwrap();
        let gpu_ib = tensor::value_into_tensor_for("intersect", gpu_eval.ib_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ib.data, cpu_ib.data);
    }
}