1use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::GpuTensorHandle;
12use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
13use runmat_macros::runtime_builtin;
14
15use crate::builtins::common::gpu_helpers;
16use crate::builtins::common::random_args::complex_tensor_into_value;
17use crate::builtins::common::spec::{
18 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
19 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
20};
21use crate::builtins::common::tensor;
22#[cfg(feature = "doc_export")]
23use crate::register_builtin_doc_text;
24use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
25
/// Markdown reference page for the `intersect` builtin (YAML front matter followed by
/// usage notes, examples, and an FAQ).
///
/// Only compiled when the `doc_export` feature is enabled; the text is handed to
/// `register_builtin_doc_text!` under the name `"intersect"` later in this file.
#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "intersect"
category: "array/sorting_sets"
keywords: ["intersect", "set", "stable", "rows", "indices", "gpu"]
summary: "Return common elements or rows across arrays with MATLAB-compatible ordering and index outputs."
references:
 - https://www.mathworks.com/help/matlab/ref/intersect.html
gpu_support:
 elementwise: false
 reduction: false
 precisions: ["f32", "f64"]
 broadcasting: "none"
 notes: "When providers lack a dedicated `intersect` hook, RunMat gathers GPU tensors and executes the CPU implementation."
fusion:
 elementwise: false
 reduction: false
 max_inputs: 2
 constants: "inline"
requires_feature: null
tested:
 unit: "builtins::array::sorting_sets::intersect::tests"
 integration: "builtins::array::sorting_sets::intersect::tests::intersect_gpu_roundtrip"
---

# What does the `intersect` function do in MATLAB / RunMat?
`intersect(A, B)` computes the set intersection of `A` and `B`. By default it returns the
distinct common values sorted ascending, along with optional index outputs that map each
result back to its originating position in the inputs.

## How does the `intersect` function behave in MATLAB / RunMat?
- `intersect(A, B)` flattens both inputs column-major, removes duplicates, and returns the
 sorted intersection.
- `[C, IA, IB] = intersect(A, B)` also returns index vectors so that `C = A(IA)` and
 `C = B(IB)`.
- `intersect(A, B, 'stable')` preserves the first appearance order from `A`.
- `intersect(A, B, 'rows')` treats each row as an element. Inputs must share the same number
 of columns.
- Character arrays, string arrays, logical arrays, complex values, and numerical types are
 fully supported. Mixed classes raise a descriptive error.
- Legacy flags such as `'legacy'` or `'R2012a'` are not supported in RunMat.

## `intersect` Function GPU Execution Behavior
`intersect` registers as a residency sink. When the active acceleration provider offers a
dedicated `intersect` hook, the set logic can execute directly on the device and return
GPU-resident results. Until such a hook lands, RunMat automatically gathers GPU tensors to
host memory, runs the CPU implementation, and materializes host-side outputs so behavior
matches MathWorks MATLAB exactly.

## Examples of using the `intersect` function in MATLAB / RunMat

### Finding common values in numeric vectors
```matlab
A = [5 7 5 1];
B = [7 1 3];
[C, IA, IB] = intersect(A, B);
```
Expected output:
```matlab
C =
 1
 7
IA =
 4
 2
IB =
 2
 1
```

### Preserving input order with `'stable'`
```matlab
A = [4 2 4 1 3];
B = [3 4 5 1];
C = intersect(A, B, 'stable');
```
Expected output:
```matlab
C =
 4
 1
 3
```

### Intersecting matrix rows
```matlab
A = [1 2; 3 4; 1 2];
B = [2 3; 1 2];
[C, IA, IB] = intersect(A, B, 'rows');
```
Expected output:
```matlab
C =
 1 2
IA =
 1
IB =
 2
```

### Working with strings and character arrays
```matlab
names1 = ["apple" "orange" "pear"];
names2 = ["pear" "grape" "orange"];
[common, IA, IB] = intersect(names1, names2, 'stable');
```
Expected output:
```matlab
common =
 1×2 string array
 "orange" "pear"
IA =
 2
 3
IB =
 3
 1
```

### Using `intersect` on GPU arrays
```matlab
G = gpuArray([10 4 6 4]);
H = gpuArray([6 4 2]);
result = intersect(G, H);
```
RunMat gathers the data to host memory (until providers implement a device kernel) and returns:
```matlab
result =
 4
 6
```

## FAQ

### Does `intersect` keep duplicate values?
No. MATLAB and RunMat return each common value at most once. Use other logic (such as
logical indexing) if you need multiset behavior.

### What is the default ordering?
`intersect` sorts results ascending by default. Specify `'stable'` to preserve the input
order from `A`.

### Can I intersect rows that contain strings?
Yes. String arrays support both element and row intersections. When using `'rows'`, the
inputs must have the same number of columns.

### Are NaN values considered equal?
Yes. `NaN` values are treated as equal for the purposes of intersection, matching MATLAB.

### Is the `'legacy'` flag supported?
No. RunMat only implements the modern MATLAB semantics. Passing `'legacy'` or `'R2012a'`
raises an error.

## GPU residency in RunMat (Do I need `gpuArray`?)
You usually do **not** need to call `gpuArray` manually. RunMat's planner keeps tensors
resident on the GPU when profitable. If a provider lacks an `intersect` kernel the runtime
gathers automatically, so explicit residency management is rarely needed. Explicit calls to
`gpuArray` remain available for compatibility with MathWorks MATLAB workflows.

## See Also
[union](./union), [unique](./unique), [sort](./sort), [sortrows](./sortrows)

## Source & Feedback
- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/intersect.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/intersect.rs)
- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
"#;
192
/// GPU execution metadata for the `intersect` builtin.
///
/// Declares a custom op kind and a custom `intersect` provider hook. Per the `notes`
/// field, tensors are gathered and processed on the host when a provider does not
/// expose that hook.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
 name: "intersect",
 op_kind: GpuOpKind::Custom("intersect"),
 supported_precisions: &[ScalarType::F32, ScalarType::F64],
 // No broadcasting is declared for this builtin.
 broadcast: BroadcastSemantics::None,
 provider_hooks: &[ProviderHook::Custom("intersect")],
 constant_strategy: ConstantStrategy::InlineLiteral,
 residency: ResidencyPolicy::GatherImmediately,
 nan_mode: ReductionNaN::Include,
 two_pass_threshold: None,
 workgroup_size: None,
 // NOTE(review): `parse_options` in this file accepts no NaN-mode flag; confirm this
 // setting is intentional.
 accepts_nan_mode: true,
 notes:
 "Providers may expose a dedicated intersect hook; otherwise tensors are gathered and processed on the host.",
};

// Hand the spec to the crate's GPU-spec registration macro.
register_builtin_gpu_spec!(GPU_SPEC);
210
/// Fusion metadata for the `intersect` builtin.
///
/// Neither an elementwise nor a reduction fusion form is provided (both are `None`);
/// per the `notes` field, `intersect` materialises its inputs and terminates fusion chains.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
 name: "intersect",
 shape: ShapeRequirements::Any,
 constant_strategy: ConstantStrategy::InlineLiteral,
 elementwise: None,
 reduction: None,
 emits_nan: true,
 notes: "`intersect` materialises its inputs and terminates fusion chains; upstream GPU tensors are gathered when necessary.",
};

// Hand the spec to the crate's fusion-spec registration macro.
register_builtin_fusion_spec!(FUSION_SPEC);
222
// Register the markdown documentation under the builtin name `"intersect"`; only
// compiled when the `doc_export` feature is enabled.
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("intersect", DOC_MD);
225
226#[runtime_builtin(
227 name = "intersect",
228 category = "array/sorting_sets",
229 summary = "Return common elements or rows across arrays with MATLAB-compatible ordering and index outputs.",
230 keywords = "intersect,set,stable,rows,indices,gpu",
231 accel = "array_construct",
232 sink = true
233)]
234fn intersect_builtin(a: Value, b: Value, rest: Vec<Value>) -> Result<Value, String> {
235 evaluate(a, b, &rest).map(|eval| eval.into_values_value())
236}
237
238pub fn evaluate(a: Value, b: Value, rest: &[Value]) -> Result<IntersectEvaluation, String> {
240 let opts = parse_options(rest)?;
241 match (a, b) {
242 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
243 intersect_gpu_pair(handle_a, handle_b, &opts)
244 }
245 (Value::GpuTensor(handle_a), other) => intersect_gpu_mixed(handle_a, other, &opts, true),
246 (other, Value::GpuTensor(handle_b)) => intersect_gpu_mixed(handle_b, other, &opts, false),
247 (left, right) => intersect_host(left, right, &opts),
248 }
249}
250
/// Ordering applied to the intersection result.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IntersectOrder {
 /// Results are ordered with the per-type comparison helpers; this is the default
 /// when no ordering option is supplied.
 Sorted,
 /// Results keep the order in which matches were first encountered in `A`.
 Stable,
}
256
/// Options parsed from the trailing arguments of an `intersect` call.
#[derive(Debug, Clone)]
struct IntersectOptions {
 /// `true` when the `'rows'` option was given: whole rows are compared instead of elements.
 rows: bool,
 /// Requested ordering of the output.
 order: IntersectOrder,
}
262
263fn parse_options(rest: &[Value]) -> Result<IntersectOptions, String> {
264 let mut opts = IntersectOptions {
265 rows: false,
266 order: IntersectOrder::Sorted,
267 };
268 let mut seen_order: Option<IntersectOrder> = None;
269
270 for arg in rest {
271 let text = tensor::value_to_string(arg)
272 .ok_or_else(|| "intersect: expected string option arguments".to_string())?;
273 let lowered = text.trim().to_ascii_lowercase();
274 match lowered.as_str() {
275 "rows" => opts.rows = true,
276 "sorted" => {
277 if let Some(prev) = seen_order {
278 if prev != IntersectOrder::Sorted {
279 return Err("intersect: cannot combine 'sorted' with 'stable'".to_string());
280 }
281 }
282 seen_order = Some(IntersectOrder::Sorted);
283 opts.order = IntersectOrder::Sorted;
284 }
285 "stable" => {
286 if let Some(prev) = seen_order {
287 if prev != IntersectOrder::Stable {
288 return Err("intersect: cannot combine 'sorted' with 'stable'".to_string());
289 }
290 }
291 seen_order = Some(IntersectOrder::Stable);
292 opts.order = IntersectOrder::Stable;
293 }
294 "legacy" | "r2012a" => {
295 return Err("intersect: the 'legacy' behaviour is not supported".to_string());
296 }
297 other => return Err(format!("intersect: unrecognised option '{other}'")),
298 }
299 }
300
301 Ok(opts)
302}
303
304fn intersect_gpu_pair(
305 handle_a: GpuTensorHandle,
306 handle_b: GpuTensorHandle,
307 opts: &IntersectOptions,
308) -> Result<IntersectEvaluation, String> {
309 let tensor_a = gpu_helpers::gather_tensor(&handle_a)?;
310 let tensor_b = gpu_helpers::gather_tensor(&handle_b)?;
311 intersect_numeric(tensor_a, tensor_b, opts)
312}
313
314fn intersect_gpu_mixed(
315 handle_gpu: GpuTensorHandle,
316 other: Value,
317 opts: &IntersectOptions,
318 gpu_is_a: bool,
319) -> Result<IntersectEvaluation, String> {
320 let tensor_gpu = gpu_helpers::gather_tensor(&handle_gpu)?;
321 let tensor_other = tensor::value_into_tensor_for("intersect", other)?;
322 if gpu_is_a {
323 intersect_numeric(tensor_gpu, tensor_other, opts)
324 } else {
325 intersect_numeric(tensor_other, tensor_gpu, opts)
326 }
327}
328
329fn intersect_host(
330 a: Value,
331 b: Value,
332 opts: &IntersectOptions,
333) -> Result<IntersectEvaluation, String> {
334 match (a, b) {
335 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => intersect_complex(at, bt, opts),
336 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
337 let bt = scalar_complex_tensor(re, im)?;
338 intersect_complex(at, bt, opts)
339 }
340 (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
341 let at = scalar_complex_tensor(re, im)?;
342 intersect_complex(at, bt, opts)
343 }
344 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
345 let at = scalar_complex_tensor(a_re, a_im)?;
346 let bt = scalar_complex_tensor(b_re, b_im)?;
347 intersect_complex(at, bt, opts)
348 }
349 (Value::ComplexTensor(at), other) => {
350 let bt = value_into_complex_tensor(other)?;
351 intersect_complex(at, bt, opts)
352 }
353 (other, Value::ComplexTensor(bt)) => {
354 let at = value_into_complex_tensor(other)?;
355 intersect_complex(at, bt, opts)
356 }
357 (Value::Complex(re, im), other) => {
358 let at = scalar_complex_tensor(re, im)?;
359 let bt = value_into_complex_tensor(other)?;
360 intersect_complex(at, bt, opts)
361 }
362 (other, Value::Complex(re, im)) => {
363 let at = value_into_complex_tensor(other)?;
364 let bt = scalar_complex_tensor(re, im)?;
365 intersect_complex(at, bt, opts)
366 }
367
368 (Value::CharArray(ac), Value::CharArray(bc)) => intersect_char(ac, bc, opts),
369
370 (Value::StringArray(astring), Value::StringArray(bstring)) => {
371 intersect_string(astring, bstring, opts)
372 }
373 (Value::StringArray(astring), Value::String(b)) => {
374 let bstring =
375 StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("intersect: {e}"))?;
376 intersect_string(astring, bstring, opts)
377 }
378 (Value::String(a), Value::StringArray(bstring)) => {
379 let astring =
380 StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("intersect: {e}"))?;
381 intersect_string(astring, bstring, opts)
382 }
383 (Value::String(a), Value::String(b)) => {
384 let astring =
385 StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("intersect: {e}"))?;
386 let bstring =
387 StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("intersect: {e}"))?;
388 intersect_string(astring, bstring, opts)
389 }
390
391 (left, right) => {
392 let tensor_a = tensor::value_into_tensor_for("intersect", left)?;
393 let tensor_b = tensor::value_into_tensor_for("intersect", right)?;
394 intersect_numeric(tensor_a, tensor_b, opts)
395 }
396 }
397}
398
399fn intersect_numeric(
400 a: Tensor,
401 b: Tensor,
402 opts: &IntersectOptions,
403) -> Result<IntersectEvaluation, String> {
404 if opts.rows {
405 intersect_numeric_rows(a, b, opts)
406 } else {
407 intersect_numeric_elements(a, b, opts)
408 }
409}
410
411fn intersect_numeric_elements(
412 a: Tensor,
413 b: Tensor,
414 opts: &IntersectOptions,
415) -> Result<IntersectEvaluation, String> {
416 let mut b_map: HashMap<u64, usize> = HashMap::new();
417 for (idx, &value) in b.data.iter().enumerate() {
418 let key = canonicalize_f64(value);
419 b_map.entry(key).or_insert(idx);
420 }
421
422 let mut seen: HashSet<u64> = HashSet::new();
423 let mut entries = Vec::<NumericIntersectEntry>::new();
424 let mut order_counter = 0usize;
425
426 for (idx, &value) in a.data.iter().enumerate() {
427 let key = canonicalize_f64(value);
428 if seen.contains(&key) {
429 continue;
430 }
431 if let Some(&b_idx) = b_map.get(&key) {
432 entries.push(NumericIntersectEntry {
433 value,
434 a_index: idx,
435 b_index: b_idx,
436 order_rank: order_counter,
437 });
438 seen.insert(key);
439 order_counter += 1;
440 }
441 }
442
443 assemble_numeric_intersect(entries, opts)
444}
445
446fn intersect_numeric_rows(
447 a: Tensor,
448 b: Tensor,
449 opts: &IntersectOptions,
450) -> Result<IntersectEvaluation, String> {
451 if a.shape.len() != 2 || b.shape.len() != 2 {
452 return Err("intersect: 'rows' option requires 2-D numeric matrices".to_string());
453 }
454 if a.shape[1] != b.shape[1] {
455 return Err(
456 "intersect: inputs must have the same number of columns when using 'rows'".to_string(),
457 );
458 }
459 let rows_a = a.shape[0];
460 let cols = a.shape[1];
461 let rows_b = b.shape[0];
462
463 let mut b_map: HashMap<NumericRowKey, usize> = HashMap::new();
464 for r in 0..rows_b {
465 let mut row_values = Vec::with_capacity(cols);
466 for c in 0..cols {
467 let idx = r + c * rows_b;
468 row_values.push(b.data[idx]);
469 }
470 let key = NumericRowKey::from_slice(&row_values);
471 b_map.entry(key).or_insert(r);
472 }
473
474 let mut seen: HashSet<NumericRowKey> = HashSet::new();
475 let mut entries = Vec::<NumericRowIntersectEntry>::new();
476 let mut order_counter = 0usize;
477
478 for r in 0..rows_a {
479 let mut row_values = Vec::with_capacity(cols);
480 for c in 0..cols {
481 let idx = r + c * rows_a;
482 row_values.push(a.data[idx]);
483 }
484 let key = NumericRowKey::from_slice(&row_values);
485 if seen.contains(&key) {
486 continue;
487 }
488 if let Some(&b_row) = b_map.get(&key) {
489 entries.push(NumericRowIntersectEntry {
490 row_data: row_values,
491 a_row: r,
492 b_row,
493 order_rank: order_counter,
494 });
495 seen.insert(key);
496 order_counter += 1;
497 }
498 }
499
500 assemble_numeric_row_intersect(entries, opts, cols)
501}
502
503fn intersect_complex(
504 a: ComplexTensor,
505 b: ComplexTensor,
506 opts: &IntersectOptions,
507) -> Result<IntersectEvaluation, String> {
508 if opts.rows {
509 intersect_complex_rows(a, b, opts)
510 } else {
511 intersect_complex_elements(a, b, opts)
512 }
513}
514
515fn intersect_complex_elements(
516 a: ComplexTensor,
517 b: ComplexTensor,
518 opts: &IntersectOptions,
519) -> Result<IntersectEvaluation, String> {
520 let mut b_map: HashMap<ComplexKey, usize> = HashMap::new();
521 for (idx, &value) in b.data.iter().enumerate() {
522 let key = ComplexKey::new(value);
523 b_map.entry(key).or_insert(idx);
524 }
525
526 let mut seen: HashSet<ComplexKey> = HashSet::new();
527 let mut entries = Vec::<ComplexIntersectEntry>::new();
528 let mut order_counter = 0usize;
529
530 for (idx, &value) in a.data.iter().enumerate() {
531 let key = ComplexKey::new(value);
532 if seen.contains(&key) {
533 continue;
534 }
535 if let Some(&b_idx) = b_map.get(&key) {
536 entries.push(ComplexIntersectEntry {
537 value,
538 a_index: idx,
539 b_index: b_idx,
540 order_rank: order_counter,
541 });
542 seen.insert(key);
543 order_counter += 1;
544 }
545 }
546
547 assemble_complex_intersect(entries, opts)
548}
549
550fn intersect_complex_rows(
551 a: ComplexTensor,
552 b: ComplexTensor,
553 opts: &IntersectOptions,
554) -> Result<IntersectEvaluation, String> {
555 if a.shape.len() != 2 || b.shape.len() != 2 {
556 return Err("intersect: 'rows' option requires 2-D complex matrices".to_string());
557 }
558 if a.shape[1] != b.shape[1] {
559 return Err(
560 "intersect: inputs must have the same number of columns when using 'rows'".to_string(),
561 );
562 }
563 let rows_a = a.shape[0];
564 let cols = a.shape[1];
565 let rows_b = b.shape[0];
566
567 let mut b_map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
568 for r in 0..rows_b {
569 let mut row_keys = Vec::with_capacity(cols);
570 for c in 0..cols {
571 let idx = r + c * rows_b;
572 row_keys.push(ComplexKey::new(b.data[idx]));
573 }
574 b_map.entry(row_keys).or_insert(r);
575 }
576
577 let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
578 let mut entries = Vec::<ComplexRowIntersectEntry>::new();
579 let mut order_counter = 0usize;
580
581 for r in 0..rows_a {
582 let mut row_values = Vec::with_capacity(cols);
583 let mut row_keys = Vec::with_capacity(cols);
584 for c in 0..cols {
585 let idx = r + c * rows_a;
586 let value = a.data[idx];
587 row_values.push(value);
588 row_keys.push(ComplexKey::new(value));
589 }
590 if seen.contains(&row_keys) {
591 continue;
592 }
593 if let Some(&b_row) = b_map.get(&row_keys) {
594 entries.push(ComplexRowIntersectEntry {
595 row_data: row_values,
596 a_row: r,
597 b_row,
598 order_rank: order_counter,
599 });
600 seen.insert(row_keys);
601 order_counter += 1;
602 }
603 }
604
605 assemble_complex_row_intersect(entries, opts, cols)
606}
607
608fn intersect_char(
609 a: CharArray,
610 b: CharArray,
611 opts: &IntersectOptions,
612) -> Result<IntersectEvaluation, String> {
613 if opts.rows {
614 intersect_char_rows(a, b, opts)
615 } else {
616 intersect_char_elements(a, b, opts)
617 }
618}
619
620fn intersect_char_elements(
621 a: CharArray,
622 b: CharArray,
623 opts: &IntersectOptions,
624) -> Result<IntersectEvaluation, String> {
625 let mut seen: HashSet<u32> = HashSet::new();
626 let mut entries = Vec::<CharIntersectEntry>::new();
627 let mut order_counter = 0usize;
628
629 for col in 0..a.cols {
630 for row in 0..a.rows {
631 let linear_idx = row + col * a.rows;
632 let data_idx = row * a.cols + col;
633 let ch = a.data[data_idx];
634 let key = ch as u32;
635 if seen.contains(&key) {
636 continue;
637 }
638 if let Some(b_idx) = find_char_index(&b, ch) {
639 entries.push(CharIntersectEntry {
640 ch,
641 a_index: linear_idx,
642 b_index: b_idx,
643 order_rank: order_counter,
644 });
645 seen.insert(key);
646 order_counter += 1;
647 }
648 }
649 }
650
651 assemble_char_intersect(entries, opts, &b)
652}
653
654fn intersect_char_rows(
655 a: CharArray,
656 b: CharArray,
657 opts: &IntersectOptions,
658) -> Result<IntersectEvaluation, String> {
659 if a.cols != b.cols {
660 return Err(
661 "intersect: inputs must have the same number of columns when using 'rows'".to_string(),
662 );
663 }
664 let rows_a = a.rows;
665 let rows_b = b.rows;
666 let cols = a.cols;
667
668 let mut b_map: HashMap<RowCharKey, usize> = HashMap::new();
669 for r in 0..rows_b {
670 let mut row_values = Vec::with_capacity(cols);
671 for c in 0..cols {
672 let idx = r * cols + c;
673 row_values.push(b.data[idx]);
674 }
675 let key = RowCharKey::from_slice(&row_values);
676 b_map.entry(key).or_insert(r);
677 }
678
679 let mut seen: HashSet<RowCharKey> = HashSet::new();
680 let mut entries = Vec::<CharRowIntersectEntry>::new();
681 let mut order_counter = 0usize;
682
683 for r in 0..rows_a {
684 let mut row_values = Vec::with_capacity(cols);
685 for c in 0..cols {
686 let idx = r * cols + c;
687 row_values.push(a.data[idx]);
688 }
689 let key = RowCharKey::from_slice(&row_values);
690 if seen.contains(&key) {
691 continue;
692 }
693 if let Some(&b_row) = b_map.get(&key) {
694 entries.push(CharRowIntersectEntry {
695 row_data: row_values,
696 a_row: r,
697 b_row,
698 order_rank: order_counter,
699 });
700 seen.insert(key);
701 order_counter += 1;
702 }
703 }
704
705 assemble_char_row_intersect(entries, opts, cols)
706}
707
708fn find_char_index(array: &CharArray, target: char) -> Option<usize> {
709 for col in 0..array.cols {
710 for row in 0..array.rows {
711 let data_idx = row * array.cols + col;
712 if array.data[data_idx] == target {
713 return Some(row + col * array.rows);
714 }
715 }
716 }
717 None
718}
719
720fn intersect_string(
721 a: StringArray,
722 b: StringArray,
723 opts: &IntersectOptions,
724) -> Result<IntersectEvaluation, String> {
725 if opts.rows {
726 intersect_string_rows(a, b, opts)
727 } else {
728 intersect_string_elements(a, b, opts)
729 }
730}
731
732fn intersect_string_elements(
733 a: StringArray,
734 b: StringArray,
735 opts: &IntersectOptions,
736) -> Result<IntersectEvaluation, String> {
737 let mut b_map: HashMap<String, usize> = HashMap::new();
738 for (idx, value) in b.data.iter().enumerate() {
739 b_map.entry(value.clone()).or_insert(idx);
740 }
741
742 let mut seen: HashSet<String> = HashSet::new();
743 let mut entries = Vec::<StringIntersectEntry>::new();
744 let mut order_counter = 0usize;
745
746 for (idx, value) in a.data.iter().enumerate() {
747 if seen.contains(value) {
748 continue;
749 }
750 if let Some(&b_idx) = b_map.get(value) {
751 entries.push(StringIntersectEntry {
752 value: value.clone(),
753 a_index: idx,
754 b_index: b_idx,
755 order_rank: order_counter,
756 });
757 seen.insert(value.clone());
758 order_counter += 1;
759 }
760 }
761
762 assemble_string_intersect(entries, opts)
763}
764
765fn intersect_string_rows(
766 a: StringArray,
767 b: StringArray,
768 opts: &IntersectOptions,
769) -> Result<IntersectEvaluation, String> {
770 if a.shape.len() != 2 || b.shape.len() != 2 {
771 return Err("intersect: 'rows' option requires 2-D string arrays".to_string());
772 }
773 if a.shape[1] != b.shape[1] {
774 return Err(
775 "intersect: inputs must have the same number of columns when using 'rows'".to_string(),
776 );
777 }
778 let rows_a = a.shape[0];
779 let cols = a.shape[1];
780 let rows_b = b.shape[0];
781
782 let mut b_map: HashMap<RowStringKey, usize> = HashMap::new();
783 for r in 0..rows_b {
784 let mut row_values = Vec::with_capacity(cols);
785 for c in 0..cols {
786 let idx = r + c * rows_b;
787 row_values.push(b.data[idx].clone());
788 }
789 let key = RowStringKey::from_slice(&row_values);
790 b_map.entry(key).or_insert(r);
791 }
792
793 let mut seen: HashSet<RowStringKey> = HashSet::new();
794 let mut entries = Vec::<StringRowIntersectEntry>::new();
795 let mut order_counter = 0usize;
796
797 for r in 0..rows_a {
798 let mut row_values = Vec::with_capacity(cols);
799 for c in 0..cols {
800 let idx = r + c * rows_a;
801 row_values.push(a.data[idx].clone());
802 }
803 let key = RowStringKey::from_slice(&row_values);
804 if seen.contains(&key) {
805 continue;
806 }
807 if let Some(&b_row) = b_map.get(&key) {
808 entries.push(StringRowIntersectEntry {
809 row_data: row_values,
810 a_row: r,
811 b_row,
812 order_rank: order_counter,
813 });
814 seen.insert(key);
815 order_counter += 1;
816 }
817 }
818
819 assemble_string_row_intersect(entries, opts, cols)
820}
821
/// Result of an `intersect` evaluation: the common values plus both index vectors.
#[derive(Debug, Clone)]
pub struct IntersectEvaluation {
 /// The intersection itself, already wrapped as a runtime `Value`.
 values: Value,
 /// One-based positions into `A` for each returned value (or row).
 ia: Tensor,
 /// One-based positions into `B` for each returned value (or row).
 ib: Tensor,
}
828
829impl IntersectEvaluation {
830 fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
831 Self { values, ia, ib }
832 }
833
834 pub fn into_values_value(self) -> Value {
835 self.values
836 }
837
838 pub fn into_pair(self) -> (Value, Value) {
839 let ia = tensor::tensor_into_value(self.ia);
840 (self.values, ia)
841 }
842
843 pub fn into_triple(self) -> (Value, Value, Value) {
844 let ia = tensor::tensor_into_value(self.ia);
845 let ib = tensor::tensor_into_value(self.ib);
846 (self.values, ia, ib)
847 }
848
849 pub fn values_value(&self) -> Value {
850 self.values.clone()
851 }
852
853 pub fn ia_value(&self) -> Value {
854 tensor::tensor_into_value(self.ia.clone())
855 }
856
857 pub fn ib_value(&self) -> Value {
858 tensor::tensor_into_value(self.ib.clone())
859 }
860}
861
/// One matched real value found during element-wise intersection.
#[derive(Debug)]
struct NumericIntersectEntry {
 /// The value as it appears in `A`.
 value: f64,
 /// Zero-based position of the first occurrence in `A`'s data.
 a_index: usize,
 /// Zero-based position of the first occurrence in `B`'s data.
 b_index: usize,
 /// Encounter order while scanning `A`; used for `'stable'` output.
 order_rank: usize,
}
869
/// One matched real row found during row-wise intersection.
#[derive(Debug)]
struct NumericRowIntersectEntry {
 /// The row's elements, copied from `A` in column order.
 row_data: Vec<f64>,
 /// Zero-based index of the first matching row in `A`.
 a_row: usize,
 /// Zero-based index of the first matching row in `B`.
 b_row: usize,
 /// Encounter order while scanning `A`; used for `'stable'` output.
 order_rank: usize,
}
877
/// One matched complex value found during element-wise intersection.
#[derive(Debug)]
struct ComplexIntersectEntry {
 /// The complex value as it appears in `A` (presumably `(re, im)` — confirm against `ComplexTensor`).
 value: (f64, f64),
 /// Zero-based position of the first occurrence in `A`'s data.
 a_index: usize,
 /// Zero-based position of the first occurrence in `B`'s data.
 b_index: usize,
 /// Encounter order while scanning `A`; used for `'stable'` output.
 order_rank: usize,
}
885
/// One matched complex row found during row-wise intersection.
#[derive(Debug)]
struct ComplexRowIntersectEntry {
 /// The row's elements, copied from `A` in column order.
 row_data: Vec<(f64, f64)>,
 /// Zero-based index of the first matching row in `A`.
 a_row: usize,
 /// Zero-based index of the first matching row in `B`.
 b_row: usize,
 /// Encounter order while scanning `A`; used for `'stable'` output.
 order_rank: usize,
}
893
/// One matched character found during element-wise intersection.
#[derive(Debug)]
struct CharIntersectEntry {
 /// The matched character.
 ch: char,
 /// Zero-based column-major position of the first occurrence in `A`.
 a_index: usize,
 /// Zero-based column-major position of the first occurrence in `B`.
 b_index: usize,
 /// Encounter order while scanning `A`; used for `'stable'` output.
 order_rank: usize,
}
901
/// One matched character row found during row-wise intersection.
#[derive(Debug)]
struct CharRowIntersectEntry {
 /// The row's characters, copied from `A` in column order.
 row_data: Vec<char>,
 /// Zero-based index of the first matching row in `A`.
 a_row: usize,
 /// Zero-based index of the first matching row in `B`.
 b_row: usize,
 /// Encounter order while scanning `A`; used for `'stable'` output.
 order_rank: usize,
}
909
/// One matched string found during element-wise intersection.
#[derive(Debug)]
struct StringIntersectEntry {
 /// The matched string, cloned from `A`.
 value: String,
 /// Zero-based position of the first occurrence in `A`'s data.
 a_index: usize,
 /// Zero-based position of the first occurrence in `B`'s data.
 b_index: usize,
 /// Encounter order while scanning `A`; used for `'stable'` output.
 order_rank: usize,
}
917
/// One matched string row found during row-wise intersection.
#[derive(Debug)]
struct StringRowIntersectEntry {
 /// The row's strings, cloned from `A` in column order.
 row_data: Vec<String>,
 /// Zero-based index of the first matching row in `A`.
 a_row: usize,
 /// Zero-based index of the first matching row in `B`.
 b_row: usize,
 /// Encounter order while scanning `A`; used for `'stable'` output.
 order_rank: usize,
}
925
926fn assemble_numeric_intersect(
927 entries: Vec<NumericIntersectEntry>,
928 opts: &IntersectOptions,
929) -> Result<IntersectEvaluation, String> {
930 let mut order: Vec<usize> = (0..entries.len()).collect();
931 match opts.order {
932 IntersectOrder::Sorted => {
933 order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
934 }
935 IntersectOrder::Stable => {
936 order.sort_by_key(|&idx| entries[idx].order_rank);
937 }
938 }
939
940 let mut values = Vec::with_capacity(order.len());
941 let mut ia = Vec::with_capacity(order.len());
942 let mut ib = Vec::with_capacity(order.len());
943 for &idx in &order {
944 let entry = &entries[idx];
945 values.push(entry.value);
946 ia.push((entry.a_index + 1) as f64);
947 ib.push((entry.b_index + 1) as f64);
948 }
949
950 let value_tensor =
951 Tensor::new(values, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
952 let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
953 let ib_tensor = Tensor::new(ib, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
954
955 Ok(IntersectEvaluation::new(
956 tensor::tensor_into_value(value_tensor),
957 ia_tensor,
958 ib_tensor,
959 ))
960}
961
962fn assemble_numeric_row_intersect(
963 entries: Vec<NumericRowIntersectEntry>,
964 opts: &IntersectOptions,
965 cols: usize,
966) -> Result<IntersectEvaluation, String> {
967 let mut order: Vec<usize> = (0..entries.len()).collect();
968 match opts.order {
969 IntersectOrder::Sorted => {
970 order.sort_by(|&lhs, &rhs| {
971 compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
972 });
973 }
974 IntersectOrder::Stable => {
975 order.sort_by_key(|&idx| entries[idx].order_rank);
976 }
977 }
978
979 let rows_out = order.len();
980 let mut values = vec![0.0f64; rows_out * cols];
981 let mut ia = Vec::with_capacity(rows_out);
982 let mut ib = Vec::with_capacity(rows_out);
983
984 for (row_pos, &entry_idx) in order.iter().enumerate() {
985 let entry = &entries[entry_idx];
986 for col in 0..cols {
987 let dest = row_pos + col * rows_out;
988 values[dest] = entry.row_data[col];
989 }
990 ia.push((entry.a_row + 1) as f64);
991 ib.push((entry.b_row + 1) as f64);
992 }
993
994 let value_tensor =
995 Tensor::new(values, vec![rows_out, cols]).map_err(|e| format!("intersect: {e}"))?;
996 let ia_tensor = Tensor::new(ia, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
997 let ib_tensor = Tensor::new(ib, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
998
999 Ok(IntersectEvaluation::new(
1000 tensor::tensor_into_value(value_tensor),
1001 ia_tensor,
1002 ib_tensor,
1003 ))
1004}
1005
1006fn assemble_complex_intersect(
1007 entries: Vec<ComplexIntersectEntry>,
1008 opts: &IntersectOptions,
1009) -> Result<IntersectEvaluation, String> {
1010 let mut order: Vec<usize> = (0..entries.len()).collect();
1011 match opts.order {
1012 IntersectOrder::Sorted => {
1013 order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
1014 }
1015 IntersectOrder::Stable => {
1016 order.sort_by_key(|&idx| entries[idx].order_rank);
1017 }
1018 }
1019
1020 let mut values = Vec::with_capacity(order.len());
1021 let mut ia = Vec::with_capacity(order.len());
1022 let mut ib = Vec::with_capacity(order.len());
1023 for &idx in &order {
1024 let entry = &entries[idx];
1025 values.push(entry.value);
1026 ia.push((entry.a_index + 1) as f64);
1027 ib.push((entry.b_index + 1) as f64);
1028 }
1029
1030 let value_tensor =
1031 ComplexTensor::new(values, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1032 let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1033 let ib_tensor = Tensor::new(ib, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1034
1035 Ok(IntersectEvaluation::new(
1036 complex_tensor_into_value(value_tensor),
1037 ia_tensor,
1038 ib_tensor,
1039 ))
1040}
1041
1042fn assemble_complex_row_intersect(
1043 entries: Vec<ComplexRowIntersectEntry>,
1044 opts: &IntersectOptions,
1045 cols: usize,
1046) -> Result<IntersectEvaluation, String> {
1047 let mut order: Vec<usize> = (0..entries.len()).collect();
1048 match opts.order {
1049 IntersectOrder::Sorted => {
1050 order.sort_by(|&lhs, &rhs| {
1051 compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1052 });
1053 }
1054 IntersectOrder::Stable => {
1055 order.sort_by_key(|&idx| entries[idx].order_rank);
1056 }
1057 }
1058
1059 let rows_out = order.len();
1060 let mut values = vec![(0.0f64, 0.0f64); rows_out * cols];
1061 let mut ia = Vec::with_capacity(rows_out);
1062 let mut ib = Vec::with_capacity(rows_out);
1063
1064 for (row_pos, &entry_idx) in order.iter().enumerate() {
1065 let entry = &entries[entry_idx];
1066 for col in 0..cols {
1067 let dest = row_pos + col * rows_out;
1068 values[dest] = entry.row_data[col];
1069 }
1070 ia.push((entry.a_row + 1) as f64);
1071 ib.push((entry.b_row + 1) as f64);
1072 }
1073
1074 let value_tensor =
1075 ComplexTensor::new(values, vec![rows_out, cols]).map_err(|e| format!("intersect: {e}"))?;
1076 let ia_tensor = Tensor::new(ia, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1077 let ib_tensor = Tensor::new(ib, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1078
1079 Ok(IntersectEvaluation::new(
1080 complex_tensor_into_value(value_tensor),
1081 ia_tensor,
1082 ib_tensor,
1083 ))
1084}
1085
1086fn assemble_char_intersect(
1087 entries: Vec<CharIntersectEntry>,
1088 opts: &IntersectOptions,
1089 b: &CharArray,
1090) -> Result<IntersectEvaluation, String> {
1091 let mut order: Vec<usize> = (0..entries.len()).collect();
1092 match opts.order {
1093 IntersectOrder::Sorted => {
1094 order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1095 }
1096 IntersectOrder::Stable => {
1097 order.sort_by_key(|&idx| entries[idx].order_rank);
1098 }
1099 }
1100
1101 let mut values = Vec::with_capacity(order.len());
1102 let mut ia = Vec::with_capacity(order.len());
1103 let mut ib = Vec::with_capacity(order.len());
1104 for &idx in &order {
1105 let entry = &entries[idx];
1106 values.push(entry.ch);
1107 ia.push((entry.a_index + 1) as f64);
1108 let b_idx = find_char_index(b, entry.ch).unwrap_or(entry.b_index);
1109 ib.push((b_idx + 1) as f64);
1110 }
1111
1112 let value_array =
1113 CharArray::new(values, order.len(), 1).map_err(|e| format!("intersect: {e}"))?;
1114 let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1115 let ib_tensor = Tensor::new(ib, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1116
1117 Ok(IntersectEvaluation::new(
1118 Value::CharArray(value_array),
1119 ia_tensor,
1120 ib_tensor,
1121 ))
1122}
1123
1124fn assemble_char_row_intersect(
1125 entries: Vec<CharRowIntersectEntry>,
1126 opts: &IntersectOptions,
1127 cols: usize,
1128) -> Result<IntersectEvaluation, String> {
1129 let mut order: Vec<usize> = (0..entries.len()).collect();
1130 match opts.order {
1131 IntersectOrder::Sorted => {
1132 order.sort_by(|&lhs, &rhs| {
1133 compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1134 });
1135 }
1136 IntersectOrder::Stable => {
1137 order.sort_by_key(|&idx| entries[idx].order_rank);
1138 }
1139 }
1140
1141 let rows_out = order.len();
1142 let mut values = vec!['\0'; rows_out * cols];
1143 let mut ia = Vec::with_capacity(rows_out);
1144 let mut ib = Vec::with_capacity(rows_out);
1145
1146 for (row_pos, &entry_idx) in order.iter().enumerate() {
1147 let entry = &entries[entry_idx];
1148 for col in 0..cols {
1149 let dest = row_pos * cols + col;
1150 values[dest] = entry.row_data[col];
1151 }
1152 ia.push((entry.a_row + 1) as f64);
1153 ib.push((entry.b_row + 1) as f64);
1154 }
1155
1156 let value_array =
1157 CharArray::new(values, rows_out, cols).map_err(|e| format!("intersect: {e}"))?;
1158 let ia_tensor = Tensor::new(ia, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1159 let ib_tensor = Tensor::new(ib, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1160
1161 Ok(IntersectEvaluation::new(
1162 Value::CharArray(value_array),
1163 ia_tensor,
1164 ib_tensor,
1165 ))
1166}
1167
1168fn assemble_string_intersect(
1169 entries: Vec<StringIntersectEntry>,
1170 opts: &IntersectOptions,
1171) -> Result<IntersectEvaluation, String> {
1172 let mut order: Vec<usize> = (0..entries.len()).collect();
1173 match opts.order {
1174 IntersectOrder::Sorted => {
1175 order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1176 }
1177 IntersectOrder::Stable => {
1178 order.sort_by_key(|&idx| entries[idx].order_rank);
1179 }
1180 }
1181
1182 let mut values = Vec::with_capacity(order.len());
1183 let mut ia = Vec::with_capacity(order.len());
1184 let mut ib = Vec::with_capacity(order.len());
1185 for &idx in &order {
1186 let entry = &entries[idx];
1187 values.push(entry.value.clone());
1188 ia.push((entry.a_index + 1) as f64);
1189 ib.push((entry.b_index + 1) as f64);
1190 }
1191
1192 let value_array =
1193 StringArray::new(values, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1194 let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1195 let ib_tensor = Tensor::new(ib, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1196
1197 Ok(IntersectEvaluation::new(
1198 Value::StringArray(value_array),
1199 ia_tensor,
1200 ib_tensor,
1201 ))
1202}
1203
1204fn assemble_string_row_intersect(
1205 entries: Vec<StringRowIntersectEntry>,
1206 opts: &IntersectOptions,
1207 cols: usize,
1208) -> Result<IntersectEvaluation, String> {
1209 let mut order: Vec<usize> = (0..entries.len()).collect();
1210 match opts.order {
1211 IntersectOrder::Sorted => {
1212 order.sort_by(|&lhs, &rhs| {
1213 compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1214 });
1215 }
1216 IntersectOrder::Stable => {
1217 order.sort_by_key(|&idx| entries[idx].order_rank);
1218 }
1219 }
1220
1221 let rows_out = order.len();
1222 let mut values = vec![String::new(); rows_out * cols];
1223 let mut ia = Vec::with_capacity(rows_out);
1224 let mut ib = Vec::with_capacity(rows_out);
1225
1226 for (row_pos, &entry_idx) in order.iter().enumerate() {
1227 let entry = &entries[entry_idx];
1228 for col in 0..cols {
1229 let dest = row_pos + col * rows_out;
1230 values[dest] = entry.row_data[col].clone();
1231 }
1232 ia.push((entry.a_row + 1) as f64);
1233 ib.push((entry.b_row + 1) as f64);
1234 }
1235
1236 let value_array =
1237 StringArray::new(values, vec![rows_out, cols]).map_err(|e| format!("intersect: {e}"))?;
1238 let ia_tensor = Tensor::new(ia, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1239 let ib_tensor = Tensor::new(ib, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1240
1241 Ok(IntersectEvaluation::new(
1242 Value::StringArray(value_array),
1243 ia_tensor,
1244 ib_tensor,
1245 ))
1246}
1247
1248#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1249struct NumericRowKey(Vec<u64>);
1250
1251impl NumericRowKey {
1252 fn from_slice(values: &[f64]) -> Self {
1253 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1254 }
1255}
1256
1257#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1258struct ComplexKey {
1259 re: u64,
1260 im: u64,
1261}
1262
1263impl ComplexKey {
1264 fn new(value: (f64, f64)) -> Self {
1265 Self {
1266 re: canonicalize_f64(value.0),
1267 im: canonicalize_f64(value.1),
1268 }
1269 }
1270}
1271
/// Hashable identity of a char row: the Unicode scalar value of each character.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key by converting every character of `values` to its `u32` code point.
    fn from_slice(values: &[char]) -> Self {
        let code_points = values.iter().copied().map(u32::from).collect();
        Self(code_points)
    }
}
1280
/// Hashable identity of a string row: an owned copy of the row's strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);

impl RowStringKey {
    /// Builds the key by cloning every string in `values`, preserving order.
    fn from_slice(values: &[String]) -> Self {
        let owned: Vec<String> = values.iter().cloned().collect();
        Self(owned)
    }
}
1289
1290fn scalar_complex_tensor(re: f64, im: f64) -> Result<ComplexTensor, String> {
1291 ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(|e| format!("intersect: {e}"))
1292}
1293
1294fn tensor_to_complex_owned(name: &str, tensor: Tensor) -> Result<ComplexTensor, String> {
1295 let Tensor { data, shape, .. } = tensor;
1296 let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
1297 ComplexTensor::new(complex, shape).map_err(|e| format!("{name}: {e}"))
1298}
1299
1300fn value_into_complex_tensor(value: Value) -> Result<ComplexTensor, String> {
1301 match value {
1302 Value::ComplexTensor(tensor) => Ok(tensor),
1303 Value::Complex(re, im) => scalar_complex_tensor(re, im),
1304 other => {
1305 let tensor = tensor::value_into_tensor_for("intersect", other)?;
1306 tensor_to_complex_owned("intersect", tensor)
1307 }
1308 }
1309}
1310
/// Maps an `f64` to a bit pattern suitable for hashing and equality.
///
/// Every NaN collapses to one quiet-NaN pattern and both signed zeros collapse to `0`,
/// so values in either group produce identical keys. All other values keep their raw
/// IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0` is true, so this arm covers both signed zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1320
/// Total ordering over `f64` used when sorting results.
///
/// NaN compares equal to NaN and greater than every non-NaN value, so NaNs end up last
/// in ascending order. Non-NaN values use the usual numeric comparison, which treats
/// `-0.0` and `0.0` as equal.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // With both operands non-NaN `partial_cmp` always yields `Some`; the fallback
        // only keeps the expression total.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1334
1335fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1336 for (lhs, rhs) in a.iter().zip(b.iter()) {
1337 let ord = compare_f64(*lhs, *rhs);
1338 if ord != Ordering::Equal {
1339 return ord;
1340 }
1341 }
1342 Ordering::Equal
1343}
1344
/// Reports whether a complex value `(re, im)` has NaN in either component.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
1348
1349fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1350 match (complex_is_nan(a), complex_is_nan(b)) {
1351 (true, true) => Ordering::Equal,
1352 (true, false) => Ordering::Greater,
1353 (false, true) => Ordering::Less,
1354 (false, false) => {
1355 let mag_a = a.0.hypot(a.1);
1356 let mag_b = b.0.hypot(b.1);
1357 let mag_cmp = compare_f64(mag_a, mag_b);
1358 if mag_cmp != Ordering::Equal {
1359 return mag_cmp;
1360 }
1361 let re_cmp = compare_f64(a.0, b.0);
1362 if re_cmp != Ordering::Equal {
1363 return re_cmp;
1364 }
1365 compare_f64(a.1, b.1)
1366 }
1367 }
1368}
1369
1370fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1371 for (lhs, rhs) in a.iter().zip(b.iter()) {
1372 let ord = compare_complex(*lhs, *rhs);
1373 if ord != Ordering::Equal {
1374 return ord;
1375 }
1376 }
1377 Ordering::Equal
1378}
1379
/// Lexicographically orders two char rows by Unicode scalar value.
///
/// Delegates to the standard slice ordering: the first differing character decides, and
/// when one row is a strict prefix of the other the shorter row sorts first (the
/// hand-written loop used to report such rows as equal). Rows of equal length behave
/// exactly as before.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.cmp(b)
}
1389
/// Lexicographically orders two string rows using `String`'s ordering per element.
///
/// Delegates to the standard slice ordering: the first differing string decides, and
/// when one row is a strict prefix of the other the shorter row sorts first (the
/// hand-written loop used to report such rows as equal). Rows of equal length behave
/// exactly as before.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.cmp(b)
}
1399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;

    /// Shorthand for the option struct every test has to pass.
    fn options(rows: bool, order: IntersectOrder) -> IntersectOptions {
        IntersectOptions { rows, order }
    }

    /// Builds an `n x 1` tensor from `data`.
    fn column(data: Vec<f64>) -> Tensor {
        let len = data.len();
        Tensor::new(data, vec![len, 1]).unwrap()
    }

    /// Converts an output `Value` into a host tensor, panicking if that fails.
    fn as_host_tensor(value: Value) -> Tensor {
        tensor::value_into_tensor_for("intersect", value).unwrap()
    }

    // Sorted element mode: duplicates collapse and indices point at first occurrences.
    #[test]
    fn intersect_numeric_sorted() {
        let eval = intersect_numeric_elements(
            column(vec![5.0, 7.0, 5.0, 1.0]),
            column(vec![7.0, 1.0, 3.0]),
            &options(false, IntersectOrder::Sorted),
        )
        .expect("intersect");
        let values = as_host_tensor(eval.values_value());
        assert_eq!(values.data, vec![1.0, 7.0]);
        let ia = as_host_tensor(eval.ia_value());
        let ib = as_host_tensor(eval.ib_value());
        assert_eq!(ia.data, vec![4.0, 2.0]);
        assert_eq!(ib.data, vec![2.0, 1.0]);
    }

    // Stable element mode: results follow the order of first appearance in A.
    #[test]
    fn intersect_numeric_stable() {
        let eval = intersect_numeric_elements(
            column(vec![4.0, 2.0, 4.0, 1.0, 3.0]),
            column(vec![3.0, 4.0, 5.0, 1.0]),
            &options(false, IntersectOrder::Stable),
        )
        .expect("intersect");
        let values = as_host_tensor(eval.values_value());
        assert_eq!(values.data, vec![4.0, 1.0, 3.0]);
        let ia = as_host_tensor(eval.ia_value());
        let ib = as_host_tensor(eval.ib_value());
        assert_eq!(ia.data, vec![1.0, 4.0, 5.0]);
        assert_eq!(ib.data, vec![2.0, 4.0, 1.0]);
    }

    // NaNs on both sides are treated as one common element.
    #[test]
    fn intersect_numeric_handles_nan() {
        let eval = intersect_numeric_elements(
            column(vec![f64::NAN, 1.0, f64::NAN]),
            column(vec![2.0, f64::NAN]),
            &options(false, IntersectOrder::Sorted),
        )
        .expect("intersect");
        let values = as_host_tensor(eval.values_value());
        assert_eq!(values.data.len(), 1);
        assert!(values.data[0].is_nan());
        let ia = as_host_tensor(eval.ia_value());
        let ib = as_host_tensor(eval.ib_value());
        assert_eq!(ia.data, vec![1.0]);
        assert_eq!(ib.data, vec![2.0]);
    }

    // A real tensor promoted to complex intersects with a genuinely complex tensor.
    #[test]
    fn intersect_complex_with_real_inputs() {
        let complex =
            ComplexTensor::new(vec![(1.0, 0.0), (2.0, 0.0), (3.0, 1.0)], vec![3, 1]).unwrap();
        let promoted =
            tensor_to_complex_owned("intersect", column(vec![2.0, 4.0, 1.0])).unwrap();
        let eval = intersect_complex(
            complex,
            promoted,
            &options(false, IntersectOrder::Sorted),
        )
        .expect("intersect complex");
        match eval.values_value() {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(1.0, 0.0), (2.0, 0.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
        let ia = as_host_tensor(eval.ia_value());
        let ib = as_host_tensor(eval.ib_value());
        assert_eq!(ia.data, vec![1.0, 2.0]);
        assert_eq!(ib.data, vec![3.0, 1.0]);
    }

    // Rows mode: only the row [1 2] is shared between the two matrices.
    #[test]
    fn intersect_numeric_rows_default() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 5.0, 2.0, 6.0], vec![2, 2]).unwrap();
        let eval = intersect_numeric_rows(a, b, &options(true, IntersectOrder::Sorted))
            .expect("intersect rows");
        let values = as_host_tensor(eval.values_value());
        assert_eq!(values.shape, vec![1, 2]);
        assert_eq!(values.data, vec![1.0, 2.0]);
        let ia = as_host_tensor(eval.ia_value());
        let ib = as_host_tensor(eval.ib_value());
        assert_eq!(ia.data, vec![1.0]);
        assert_eq!(ib.data, vec![1.0]);
    }

    // Char element mode returns a column of the shared characters.
    #[test]
    fn intersect_char_elements_basic() {
        let a = CharArray::new("cab".chars().collect(), 1, 3).unwrap();
        let b = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
        assert_eq!(find_char_index(&b, 'b'), Some(0));
        assert_eq!(find_char_index(&b, 'c'), Some(1));
        let eval = intersect_char_elements(a, b, &options(false, IntersectOrder::Sorted))
            .expect("intersect char");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 2);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['b', 'c']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let ia = as_host_tensor(eval.ia_value());
        let ib = as_host_tensor(eval.ib_value());
        assert_eq!(ia.data, vec![3.0, 1.0]);
        assert_eq!(ib.data, vec![1.0, 2.0]);
    }

    // String element mode with stable ordering keeps A's order of appearance.
    #[test]
    fn intersect_string_elements_stable() {
        let a = StringArray::new(
            vec!["apple".into(), "orange".into(), "pear".into()],
            vec![3, 1],
        )
        .unwrap();
        let b = StringArray::new(
            vec!["pear".into(), "grape".into(), "orange".into()],
            vec![3, 1],
        )
        .unwrap();
        let eval = intersect_string_elements(a, b, &options(false, IntersectOrder::Stable))
            .expect("intersect string");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![2, 1]);
                assert_eq!(arr.data, vec!["orange".to_string(), "pear".to_string()]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let ia = as_host_tensor(eval.ia_value());
        let ib = as_host_tensor(eval.ib_value());
        assert_eq!(ia.data, vec![2.0, 3.0]);
        assert_eq!(ib.data, vec![3.0, 1.0]);
    }

    // The 'legacy' flag is rejected and the error message names it.
    #[test]
    fn intersect_rejects_legacy_option() {
        let input = column(vec![1.0, 2.0, 3.0]);
        let err = evaluate(
            Value::Tensor(input.clone()),
            Value::Tensor(input),
            &[Value::from("legacy")],
        )
        .unwrap_err();
        assert!(err.contains("legacy"));
    }

    // Rows mode requires both inputs to have the same number of columns.
    #[test]
    fn intersect_rows_dimension_mismatch() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = column(vec![1.0, 2.0, 3.0]);
        let err = intersect_numeric_rows(a, b, &options(true, IntersectOrder::Sorted))
            .unwrap_err();
        assert!(err.contains("same number of columns"));
    }

    // Mixing a numeric tensor with a char array is reported as unsupported.
    #[test]
    fn intersect_mixed_types_error() {
        let a = column(vec![1.0, 2.0]);
        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
        let err = intersect_host(
            Value::Tensor(a),
            Value::CharArray(b),
            &options(false, IntersectOrder::Sorted),
        )
        .unwrap_err();
        assert!(err.contains("unsupported input type"));
    }

    // GPU handles from the test provider go through `evaluate` and yield host results.
    #[test]
    fn intersect_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let a = column(vec![4.0, 1.0, 2.0, 1.0]);
            let b = column(vec![2.0, 5.0, 1.0]);
            let view_a = HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let view_b = HostTensorView {
                data: &b.data,
                shape: &b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload A");
            let handle_b = provider.upload(&view_b).expect("upload B");
            let eval = evaluate(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("intersect");
            let values = as_host_tensor(eval.values_value());
            assert_eq!(values.data, vec![1.0, 2.0]);
            let ia = as_host_tensor(eval.ia_value());
            let ib = as_host_tensor(eval.ib_value());
            assert_eq!(ia.data, vec![2.0, 3.0]);
            assert_eq!(ib.data, vec![3.0, 1.0]);
        });
    }

    // `into_pair` and `into_triple` expose the same index outputs.
    #[test]
    fn intersect_two_outputs_from_evaluate() {
        let eval = intersect_numeric_elements(
            column(vec![1.0, 2.0, 3.0]),
            column(vec![3.0, 1.0]),
            &options(false, IntersectOrder::Sorted),
        )
        .unwrap();
        let (_c, ia) = eval.clone().into_pair();
        let ia_tensor = as_host_tensor(ia);
        assert_eq!(ia_tensor.data, vec![1.0, 3.0]);
        let (_c, ia2, ib2) = eval.into_triple();
        let ia_tensor2 = as_host_tensor(ia2);
        let ib_tensor2 = as_host_tensor(ib2);
        assert_eq!(ia_tensor2.data, vec![1.0, 3.0]);
        assert_eq!(ib_tensor2.data, vec![2.0, 1.0]);
    }

    // With the wgpu provider registered, the GPU path must agree with the CPU path.
    #[test]
    #[cfg(feature = "wgpu")]
    fn intersect_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = column(vec![4.0, 1.0, 2.0, 3.0]);
        let b = column(vec![2.0, 6.0, 3.0]);

        let cpu_eval = intersect_numeric_elements(
            a.clone(),
            b.clone(),
            &options(false, IntersectOrder::Sorted),
        )
        .unwrap();
        let cpu_values = as_host_tensor(cpu_eval.values_value());
        let cpu_ia = as_host_tensor(cpu_eval.ia_value());
        let cpu_ib = as_host_tensor(cpu_eval.ib_value());

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval = evaluate(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
            .expect("intersect");
        let gpu_values = as_host_tensor(gpu_eval.values_value());
        let gpu_ia = as_host_tensor(gpu_eval.ia_value());
        let gpu_ib = as_host_tensor(gpu_eval.ib_value());

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ib.data, cpu_ib.data);
    }

    // The exported documentation must contain at least one example block.
    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}