runmat_runtime/builtins/array/sorting_sets/
intersect.rs

//! MATLAB-compatible `intersect` builtin with GPU-aware semantics for RunMat.
//!
//! Supports element-wise and row-wise intersections with optional stable ordering,
//! and index outputs that mirror MathWorks MATLAB semantics. GPU tensors are
//! gathered to host memory before the set logic runs; the GPU spec declares a
//! provider `intersect` hook, but the evaluation path here does not dispatch to it.
7
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::GpuTensorHandle;
12use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
13use runmat_macros::runtime_builtin;
14
15use crate::builtins::common::gpu_helpers;
16use crate::builtins::common::random_args::complex_tensor_into_value;
17use crate::builtins::common::spec::{
18    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
19    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
20};
21use crate::builtins::common::tensor;
22#[cfg(feature = "doc_export")]
23use crate::register_builtin_doc_text;
24use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
25
/// Markdown reference page for `intersect`: YAML front matter followed by the
/// user-facing guide. Only compiled with the `doc_export` feature, where it is
/// registered under the builtin name via `register_builtin_doc_text!`.
#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "intersect"
category: "array/sorting_sets"
keywords: ["intersect", "set", "stable", "rows", "indices", "gpu"]
summary: "Return common elements or rows across arrays with MATLAB-compatible ordering and index outputs."
references:
  - https://www.mathworks.com/help/matlab/ref/intersect.html
gpu_support:
  elementwise: false
  reduction: false
  precisions: ["f32", "f64"]
  broadcasting: "none"
  notes: "When providers lack a dedicated `intersect` hook, RunMat gathers GPU tensors and executes the CPU implementation."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 2
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::array::sorting_sets::intersect::tests"
  integration: "builtins::array::sorting_sets::intersect::tests::intersect_gpu_roundtrip"
---

# What does the `intersect` function do in MATLAB / RunMat?
`intersect(A, B)` computes the set intersection of `A` and `B`. By default it returns the
distinct common values sorted ascending, along with optional index outputs that map each
result back to its originating position in the inputs.

## How does the `intersect` function behave in MATLAB / RunMat?
- `intersect(A, B)` flattens both inputs column-major, removes duplicates, and returns the
  sorted intersection.
- `[C, IA, IB] = intersect(A, B)` also returns index vectors so that `C = A(IA)` and
  `C = B(IB)`.
- `intersect(A, B, 'stable')` preserves the first appearance order from `A`.
- `intersect(A, B, 'rows')` treats each row as an element. Inputs must share the same number
  of columns.
- Character arrays, string arrays, logical arrays, complex values, and numerical types are
  fully supported. Mixed classes raise a descriptive error.
- Legacy flags such as `'legacy'` or `'R2012a'` are not supported in RunMat.

## `intersect` Function GPU Execution Behavior
`intersect` registers as a residency sink. When the active acceleration provider offers a
dedicated `intersect` hook, the set logic can execute directly on the device and return
GPU-resident results. Until such a hook lands, RunMat automatically gathers GPU tensors to
host memory, runs the CPU implementation, and materializes host-side outputs so behavior
matches MathWorks MATLAB exactly.

## Examples of using the `intersect` function in MATLAB / RunMat

### Finding common values in numeric vectors
```matlab
A = [5 7 5 1];
B = [7 1 3];
[C, IA, IB] = intersect(A, B);
```
Expected output:
```matlab
C =
     1
     7
IA =
     4
     2
IB =
     2
     1
```

### Preserving input order with `'stable'`
```matlab
A = [4 2 4 1 3];
B = [3 4 5 1];
C = intersect(A, B, 'stable');
```
Expected output:
```matlab
C =
     4
     1
     3
```

### Intersecting matrix rows
```matlab
A = [1 2; 3 4; 1 2];
B = [2 3; 1 2];
[C, IA, IB] = intersect(A, B, 'rows');
```
Expected output:
```matlab
C =
     1     2
IA =
     1
IB =
     2
```

### Working with strings and character arrays
```matlab
names1 = ["apple" "orange" "pear"];
names2 = ["pear" "grape" "orange"];
[common, IA, IB] = intersect(names1, names2, 'stable');
```
Expected output:
```matlab
common =
  1×2 string array
    "orange"    "pear"
IA =
     2
     3
IB =
     3
     1
```

### Using `intersect` on GPU arrays
```matlab
G = gpuArray([10 4 6 4]);
H = gpuArray([6 4 2]);
result = intersect(G, H);
```
RunMat gathers the data to host memory (until providers implement a device kernel) and returns:
```matlab
result =
     4
     6
```

## FAQ

### Does `intersect` keep duplicate values?
No. MATLAB and RunMat return each common value at most once. Use other logic (such as
logical indexing) if you need multiset behavior.

### What is the default ordering?
`intersect` sorts results ascending by default. Specify `'stable'` to preserve the input
order from `A`.

### Can I intersect rows that contain strings?
Yes. String arrays support both element and row intersections. When using `'rows'`, the
inputs must have the same number of columns.

### Are NaN values considered equal?
Yes. `NaN` values are treated as equal for the purposes of intersection, matching MATLAB.

### Is the `'legacy'` flag supported?
No. RunMat only implements the modern MATLAB semantics. Passing `'legacy'` or `'R2012a'`
raises an error.

## GPU residency in RunMat (Do I need `gpuArray`?)
You usually do **not** need to call `gpuArray` manually. RunMat's planner keeps tensors
resident on the GPU when profitable. If a provider lacks an `intersect` kernel the runtime
gathers automatically, so explicit residency management is rarely needed. Explicit calls to
`gpuArray` remain available for compatibility with MathWorks MATLAB workflows.

## See Also
[union](./union), [unique](./unique), [sort](./sort), [sortrows](./sortrows)

## Source & Feedback
- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/intersect.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/intersect.rs)
- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
"#;
192
193pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
194    name: "intersect",
195    op_kind: GpuOpKind::Custom("intersect"),
196    supported_precisions: &[ScalarType::F32, ScalarType::F64],
197    broadcast: BroadcastSemantics::None,
198    provider_hooks: &[ProviderHook::Custom("intersect")],
199    constant_strategy: ConstantStrategy::InlineLiteral,
200    residency: ResidencyPolicy::GatherImmediately,
201    nan_mode: ReductionNaN::Include,
202    two_pass_threshold: None,
203    workgroup_size: None,
204    accepts_nan_mode: true,
205    notes:
206        "Providers may expose a dedicated intersect hook; otherwise tensors are gathered and processed on the host.",
207};
208
209register_builtin_gpu_spec!(GPU_SPEC);
210
211pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
212    name: "intersect",
213    shape: ShapeRequirements::Any,
214    constant_strategy: ConstantStrategy::InlineLiteral,
215    elementwise: None,
216    reduction: None,
217    emits_nan: true,
218    notes: "`intersect` materialises its inputs and terminates fusion chains; upstream GPU tensors are gathered when necessary.",
219};
220
221register_builtin_fusion_spec!(FUSION_SPEC);
222
223#[cfg(feature = "doc_export")]
224register_builtin_doc_text!("intersect", DOC_MD);
225
226#[runtime_builtin(
227    name = "intersect",
228    category = "array/sorting_sets",
229    summary = "Return common elements or rows across arrays with MATLAB-compatible ordering and index outputs.",
230    keywords = "intersect,set,stable,rows,indices,gpu",
231    accel = "array_construct",
232    sink = true
233)]
234fn intersect_builtin(a: Value, b: Value, rest: Vec<Value>) -> Result<Value, String> {
235    evaluate(a, b, &rest).map(|eval| eval.into_values_value())
236}
237
238/// Evaluate the `intersect` builtin once and expose all outputs.
239pub fn evaluate(a: Value, b: Value, rest: &[Value]) -> Result<IntersectEvaluation, String> {
240    let opts = parse_options(rest)?;
241    match (a, b) {
242        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
243            intersect_gpu_pair(handle_a, handle_b, &opts)
244        }
245        (Value::GpuTensor(handle_a), other) => intersect_gpu_mixed(handle_a, other, &opts, true),
246        (other, Value::GpuTensor(handle_b)) => intersect_gpu_mixed(handle_b, other, &opts, false),
247        (left, right) => intersect_host(left, right, &opts),
248    }
249}
250
/// Ordering requested for the intersection result.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IntersectOrder {
    /// Order results with the type-specific comparison (the default).
    Sorted,
    /// Keep results in the order they were first matched while scanning `A`.
    Stable,
}
256
257#[derive(Debug, Clone)]
258struct IntersectOptions {
259    rows: bool,
260    order: IntersectOrder,
261}
262
263fn parse_options(rest: &[Value]) -> Result<IntersectOptions, String> {
264    let mut opts = IntersectOptions {
265        rows: false,
266        order: IntersectOrder::Sorted,
267    };
268    let mut seen_order: Option<IntersectOrder> = None;
269
270    for arg in rest {
271        let text = tensor::value_to_string(arg)
272            .ok_or_else(|| "intersect: expected string option arguments".to_string())?;
273        let lowered = text.trim().to_ascii_lowercase();
274        match lowered.as_str() {
275            "rows" => opts.rows = true,
276            "sorted" => {
277                if let Some(prev) = seen_order {
278                    if prev != IntersectOrder::Sorted {
279                        return Err("intersect: cannot combine 'sorted' with 'stable'".to_string());
280                    }
281                }
282                seen_order = Some(IntersectOrder::Sorted);
283                opts.order = IntersectOrder::Sorted;
284            }
285            "stable" => {
286                if let Some(prev) = seen_order {
287                    if prev != IntersectOrder::Stable {
288                        return Err("intersect: cannot combine 'sorted' with 'stable'".to_string());
289                    }
290                }
291                seen_order = Some(IntersectOrder::Stable);
292                opts.order = IntersectOrder::Stable;
293            }
294            "legacy" | "r2012a" => {
295                return Err("intersect: the 'legacy' behaviour is not supported".to_string());
296            }
297            other => return Err(format!("intersect: unrecognised option '{other}'")),
298        }
299    }
300
301    Ok(opts)
302}
303
304fn intersect_gpu_pair(
305    handle_a: GpuTensorHandle,
306    handle_b: GpuTensorHandle,
307    opts: &IntersectOptions,
308) -> Result<IntersectEvaluation, String> {
309    let tensor_a = gpu_helpers::gather_tensor(&handle_a)?;
310    let tensor_b = gpu_helpers::gather_tensor(&handle_b)?;
311    intersect_numeric(tensor_a, tensor_b, opts)
312}
313
314fn intersect_gpu_mixed(
315    handle_gpu: GpuTensorHandle,
316    other: Value,
317    opts: &IntersectOptions,
318    gpu_is_a: bool,
319) -> Result<IntersectEvaluation, String> {
320    let tensor_gpu = gpu_helpers::gather_tensor(&handle_gpu)?;
321    let tensor_other = tensor::value_into_tensor_for("intersect", other)?;
322    if gpu_is_a {
323        intersect_numeric(tensor_gpu, tensor_other, opts)
324    } else {
325        intersect_numeric(tensor_other, tensor_gpu, opts)
326    }
327}
328
329fn intersect_host(
330    a: Value,
331    b: Value,
332    opts: &IntersectOptions,
333) -> Result<IntersectEvaluation, String> {
334    match (a, b) {
335        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => intersect_complex(at, bt, opts),
336        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
337            let bt = scalar_complex_tensor(re, im)?;
338            intersect_complex(at, bt, opts)
339        }
340        (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
341            let at = scalar_complex_tensor(re, im)?;
342            intersect_complex(at, bt, opts)
343        }
344        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
345            let at = scalar_complex_tensor(a_re, a_im)?;
346            let bt = scalar_complex_tensor(b_re, b_im)?;
347            intersect_complex(at, bt, opts)
348        }
349        (Value::ComplexTensor(at), other) => {
350            let bt = value_into_complex_tensor(other)?;
351            intersect_complex(at, bt, opts)
352        }
353        (other, Value::ComplexTensor(bt)) => {
354            let at = value_into_complex_tensor(other)?;
355            intersect_complex(at, bt, opts)
356        }
357        (Value::Complex(re, im), other) => {
358            let at = scalar_complex_tensor(re, im)?;
359            let bt = value_into_complex_tensor(other)?;
360            intersect_complex(at, bt, opts)
361        }
362        (other, Value::Complex(re, im)) => {
363            let at = value_into_complex_tensor(other)?;
364            let bt = scalar_complex_tensor(re, im)?;
365            intersect_complex(at, bt, opts)
366        }
367
368        (Value::CharArray(ac), Value::CharArray(bc)) => intersect_char(ac, bc, opts),
369
370        (Value::StringArray(astring), Value::StringArray(bstring)) => {
371            intersect_string(astring, bstring, opts)
372        }
373        (Value::StringArray(astring), Value::String(b)) => {
374            let bstring =
375                StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("intersect: {e}"))?;
376            intersect_string(astring, bstring, opts)
377        }
378        (Value::String(a), Value::StringArray(bstring)) => {
379            let astring =
380                StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("intersect: {e}"))?;
381            intersect_string(astring, bstring, opts)
382        }
383        (Value::String(a), Value::String(b)) => {
384            let astring =
385                StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("intersect: {e}"))?;
386            let bstring =
387                StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("intersect: {e}"))?;
388            intersect_string(astring, bstring, opts)
389        }
390
391        (left, right) => {
392            let tensor_a = tensor::value_into_tensor_for("intersect", left)?;
393            let tensor_b = tensor::value_into_tensor_for("intersect", right)?;
394            intersect_numeric(tensor_a, tensor_b, opts)
395        }
396    }
397}
398
399fn intersect_numeric(
400    a: Tensor,
401    b: Tensor,
402    opts: &IntersectOptions,
403) -> Result<IntersectEvaluation, String> {
404    if opts.rows {
405        intersect_numeric_rows(a, b, opts)
406    } else {
407        intersect_numeric_elements(a, b, opts)
408    }
409}
410
411fn intersect_numeric_elements(
412    a: Tensor,
413    b: Tensor,
414    opts: &IntersectOptions,
415) -> Result<IntersectEvaluation, String> {
416    let mut b_map: HashMap<u64, usize> = HashMap::new();
417    for (idx, &value) in b.data.iter().enumerate() {
418        let key = canonicalize_f64(value);
419        b_map.entry(key).or_insert(idx);
420    }
421
422    let mut seen: HashSet<u64> = HashSet::new();
423    let mut entries = Vec::<NumericIntersectEntry>::new();
424    let mut order_counter = 0usize;
425
426    for (idx, &value) in a.data.iter().enumerate() {
427        let key = canonicalize_f64(value);
428        if seen.contains(&key) {
429            continue;
430        }
431        if let Some(&b_idx) = b_map.get(&key) {
432            entries.push(NumericIntersectEntry {
433                value,
434                a_index: idx,
435                b_index: b_idx,
436                order_rank: order_counter,
437            });
438            seen.insert(key);
439            order_counter += 1;
440        }
441    }
442
443    assemble_numeric_intersect(entries, opts)
444}
445
446fn intersect_numeric_rows(
447    a: Tensor,
448    b: Tensor,
449    opts: &IntersectOptions,
450) -> Result<IntersectEvaluation, String> {
451    if a.shape.len() != 2 || b.shape.len() != 2 {
452        return Err("intersect: 'rows' option requires 2-D numeric matrices".to_string());
453    }
454    if a.shape[1] != b.shape[1] {
455        return Err(
456            "intersect: inputs must have the same number of columns when using 'rows'".to_string(),
457        );
458    }
459    let rows_a = a.shape[0];
460    let cols = a.shape[1];
461    let rows_b = b.shape[0];
462
463    let mut b_map: HashMap<NumericRowKey, usize> = HashMap::new();
464    for r in 0..rows_b {
465        let mut row_values = Vec::with_capacity(cols);
466        for c in 0..cols {
467            let idx = r + c * rows_b;
468            row_values.push(b.data[idx]);
469        }
470        let key = NumericRowKey::from_slice(&row_values);
471        b_map.entry(key).or_insert(r);
472    }
473
474    let mut seen: HashSet<NumericRowKey> = HashSet::new();
475    let mut entries = Vec::<NumericRowIntersectEntry>::new();
476    let mut order_counter = 0usize;
477
478    for r in 0..rows_a {
479        let mut row_values = Vec::with_capacity(cols);
480        for c in 0..cols {
481            let idx = r + c * rows_a;
482            row_values.push(a.data[idx]);
483        }
484        let key = NumericRowKey::from_slice(&row_values);
485        if seen.contains(&key) {
486            continue;
487        }
488        if let Some(&b_row) = b_map.get(&key) {
489            entries.push(NumericRowIntersectEntry {
490                row_data: row_values,
491                a_row: r,
492                b_row,
493                order_rank: order_counter,
494            });
495            seen.insert(key);
496            order_counter += 1;
497        }
498    }
499
500    assemble_numeric_row_intersect(entries, opts, cols)
501}
502
503fn intersect_complex(
504    a: ComplexTensor,
505    b: ComplexTensor,
506    opts: &IntersectOptions,
507) -> Result<IntersectEvaluation, String> {
508    if opts.rows {
509        intersect_complex_rows(a, b, opts)
510    } else {
511        intersect_complex_elements(a, b, opts)
512    }
513}
514
515fn intersect_complex_elements(
516    a: ComplexTensor,
517    b: ComplexTensor,
518    opts: &IntersectOptions,
519) -> Result<IntersectEvaluation, String> {
520    let mut b_map: HashMap<ComplexKey, usize> = HashMap::new();
521    for (idx, &value) in b.data.iter().enumerate() {
522        let key = ComplexKey::new(value);
523        b_map.entry(key).or_insert(idx);
524    }
525
526    let mut seen: HashSet<ComplexKey> = HashSet::new();
527    let mut entries = Vec::<ComplexIntersectEntry>::new();
528    let mut order_counter = 0usize;
529
530    for (idx, &value) in a.data.iter().enumerate() {
531        let key = ComplexKey::new(value);
532        if seen.contains(&key) {
533            continue;
534        }
535        if let Some(&b_idx) = b_map.get(&key) {
536            entries.push(ComplexIntersectEntry {
537                value,
538                a_index: idx,
539                b_index: b_idx,
540                order_rank: order_counter,
541            });
542            seen.insert(key);
543            order_counter += 1;
544        }
545    }
546
547    assemble_complex_intersect(entries, opts)
548}
549
550fn intersect_complex_rows(
551    a: ComplexTensor,
552    b: ComplexTensor,
553    opts: &IntersectOptions,
554) -> Result<IntersectEvaluation, String> {
555    if a.shape.len() != 2 || b.shape.len() != 2 {
556        return Err("intersect: 'rows' option requires 2-D complex matrices".to_string());
557    }
558    if a.shape[1] != b.shape[1] {
559        return Err(
560            "intersect: inputs must have the same number of columns when using 'rows'".to_string(),
561        );
562    }
563    let rows_a = a.shape[0];
564    let cols = a.shape[1];
565    let rows_b = b.shape[0];
566
567    let mut b_map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
568    for r in 0..rows_b {
569        let mut row_keys = Vec::with_capacity(cols);
570        for c in 0..cols {
571            let idx = r + c * rows_b;
572            row_keys.push(ComplexKey::new(b.data[idx]));
573        }
574        b_map.entry(row_keys).or_insert(r);
575    }
576
577    let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
578    let mut entries = Vec::<ComplexRowIntersectEntry>::new();
579    let mut order_counter = 0usize;
580
581    for r in 0..rows_a {
582        let mut row_values = Vec::with_capacity(cols);
583        let mut row_keys = Vec::with_capacity(cols);
584        for c in 0..cols {
585            let idx = r + c * rows_a;
586            let value = a.data[idx];
587            row_values.push(value);
588            row_keys.push(ComplexKey::new(value));
589        }
590        if seen.contains(&row_keys) {
591            continue;
592        }
593        if let Some(&b_row) = b_map.get(&row_keys) {
594            entries.push(ComplexRowIntersectEntry {
595                row_data: row_values,
596                a_row: r,
597                b_row,
598                order_rank: order_counter,
599            });
600            seen.insert(row_keys);
601            order_counter += 1;
602        }
603    }
604
605    assemble_complex_row_intersect(entries, opts, cols)
606}
607
608fn intersect_char(
609    a: CharArray,
610    b: CharArray,
611    opts: &IntersectOptions,
612) -> Result<IntersectEvaluation, String> {
613    if opts.rows {
614        intersect_char_rows(a, b, opts)
615    } else {
616        intersect_char_elements(a, b, opts)
617    }
618}
619
620fn intersect_char_elements(
621    a: CharArray,
622    b: CharArray,
623    opts: &IntersectOptions,
624) -> Result<IntersectEvaluation, String> {
625    let mut seen: HashSet<u32> = HashSet::new();
626    let mut entries = Vec::<CharIntersectEntry>::new();
627    let mut order_counter = 0usize;
628
629    for col in 0..a.cols {
630        for row in 0..a.rows {
631            let linear_idx = row + col * a.rows;
632            let data_idx = row * a.cols + col;
633            let ch = a.data[data_idx];
634            let key = ch as u32;
635            if seen.contains(&key) {
636                continue;
637            }
638            if let Some(b_idx) = find_char_index(&b, ch) {
639                entries.push(CharIntersectEntry {
640                    ch,
641                    a_index: linear_idx,
642                    b_index: b_idx,
643                    order_rank: order_counter,
644                });
645                seen.insert(key);
646                order_counter += 1;
647            }
648        }
649    }
650
651    assemble_char_intersect(entries, opts, &b)
652}
653
654fn intersect_char_rows(
655    a: CharArray,
656    b: CharArray,
657    opts: &IntersectOptions,
658) -> Result<IntersectEvaluation, String> {
659    if a.cols != b.cols {
660        return Err(
661            "intersect: inputs must have the same number of columns when using 'rows'".to_string(),
662        );
663    }
664    let rows_a = a.rows;
665    let rows_b = b.rows;
666    let cols = a.cols;
667
668    let mut b_map: HashMap<RowCharKey, usize> = HashMap::new();
669    for r in 0..rows_b {
670        let mut row_values = Vec::with_capacity(cols);
671        for c in 0..cols {
672            let idx = r * cols + c;
673            row_values.push(b.data[idx]);
674        }
675        let key = RowCharKey::from_slice(&row_values);
676        b_map.entry(key).or_insert(r);
677    }
678
679    let mut seen: HashSet<RowCharKey> = HashSet::new();
680    let mut entries = Vec::<CharRowIntersectEntry>::new();
681    let mut order_counter = 0usize;
682
683    for r in 0..rows_a {
684        let mut row_values = Vec::with_capacity(cols);
685        for c in 0..cols {
686            let idx = r * cols + c;
687            row_values.push(a.data[idx]);
688        }
689        let key = RowCharKey::from_slice(&row_values);
690        if seen.contains(&key) {
691            continue;
692        }
693        if let Some(&b_row) = b_map.get(&key) {
694            entries.push(CharRowIntersectEntry {
695                row_data: row_values,
696                a_row: r,
697                b_row,
698                order_rank: order_counter,
699            });
700            seen.insert(key);
701            order_counter += 1;
702        }
703    }
704
705    assemble_char_row_intersect(entries, opts, cols)
706}
707
708fn find_char_index(array: &CharArray, target: char) -> Option<usize> {
709    for col in 0..array.cols {
710        for row in 0..array.rows {
711            let data_idx = row * array.cols + col;
712            if array.data[data_idx] == target {
713                return Some(row + col * array.rows);
714            }
715        }
716    }
717    None
718}
719
720fn intersect_string(
721    a: StringArray,
722    b: StringArray,
723    opts: &IntersectOptions,
724) -> Result<IntersectEvaluation, String> {
725    if opts.rows {
726        intersect_string_rows(a, b, opts)
727    } else {
728        intersect_string_elements(a, b, opts)
729    }
730}
731
732fn intersect_string_elements(
733    a: StringArray,
734    b: StringArray,
735    opts: &IntersectOptions,
736) -> Result<IntersectEvaluation, String> {
737    let mut b_map: HashMap<String, usize> = HashMap::new();
738    for (idx, value) in b.data.iter().enumerate() {
739        b_map.entry(value.clone()).or_insert(idx);
740    }
741
742    let mut seen: HashSet<String> = HashSet::new();
743    let mut entries = Vec::<StringIntersectEntry>::new();
744    let mut order_counter = 0usize;
745
746    for (idx, value) in a.data.iter().enumerate() {
747        if seen.contains(value) {
748            continue;
749        }
750        if let Some(&b_idx) = b_map.get(value) {
751            entries.push(StringIntersectEntry {
752                value: value.clone(),
753                a_index: idx,
754                b_index: b_idx,
755                order_rank: order_counter,
756            });
757            seen.insert(value.clone());
758            order_counter += 1;
759        }
760    }
761
762    assemble_string_intersect(entries, opts)
763}
764
765fn intersect_string_rows(
766    a: StringArray,
767    b: StringArray,
768    opts: &IntersectOptions,
769) -> Result<IntersectEvaluation, String> {
770    if a.shape.len() != 2 || b.shape.len() != 2 {
771        return Err("intersect: 'rows' option requires 2-D string arrays".to_string());
772    }
773    if a.shape[1] != b.shape[1] {
774        return Err(
775            "intersect: inputs must have the same number of columns when using 'rows'".to_string(),
776        );
777    }
778    let rows_a = a.shape[0];
779    let cols = a.shape[1];
780    let rows_b = b.shape[0];
781
782    let mut b_map: HashMap<RowStringKey, usize> = HashMap::new();
783    for r in 0..rows_b {
784        let mut row_values = Vec::with_capacity(cols);
785        for c in 0..cols {
786            let idx = r + c * rows_b;
787            row_values.push(b.data[idx].clone());
788        }
789        let key = RowStringKey::from_slice(&row_values);
790        b_map.entry(key).or_insert(r);
791    }
792
793    let mut seen: HashSet<RowStringKey> = HashSet::new();
794    let mut entries = Vec::<StringRowIntersectEntry>::new();
795    let mut order_counter = 0usize;
796
797    for r in 0..rows_a {
798        let mut row_values = Vec::with_capacity(cols);
799        for c in 0..cols {
800            let idx = r + c * rows_a;
801            row_values.push(a.data[idx].clone());
802        }
803        let key = RowStringKey::from_slice(&row_values);
804        if seen.contains(&key) {
805            continue;
806        }
807        if let Some(&b_row) = b_map.get(&key) {
808            entries.push(StringRowIntersectEntry {
809                row_data: row_values,
810                a_row: r,
811                b_row,
812                order_rank: order_counter,
813            });
814            seen.insert(key);
815            order_counter += 1;
816        }
817    }
818
819    assemble_string_row_intersect(entries, opts, cols)
820}
821
822#[derive(Debug, Clone)]
823pub struct IntersectEvaluation {
824    values: Value,
825    ia: Tensor,
826    ib: Tensor,
827}
828
829impl IntersectEvaluation {
830    fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
831        Self { values, ia, ib }
832    }
833
834    pub fn into_values_value(self) -> Value {
835        self.values
836    }
837
838    pub fn into_pair(self) -> (Value, Value) {
839        let ia = tensor::tensor_into_value(self.ia);
840        (self.values, ia)
841    }
842
843    pub fn into_triple(self) -> (Value, Value, Value) {
844        let ia = tensor::tensor_into_value(self.ia);
845        let ib = tensor::tensor_into_value(self.ib);
846        (self.values, ia, ib)
847    }
848
849    pub fn values_value(&self) -> Value {
850        self.values.clone()
851    }
852
853    pub fn ia_value(&self) -> Value {
854        tensor::tensor_into_value(self.ia.clone())
855    }
856
857    pub fn ib_value(&self) -> Value {
858        tensor::tensor_into_value(self.ib.clone())
859    }
860}
861
/// One matched value found by the real element-wise intersection.
#[derive(Debug)]
struct NumericIntersectEntry {
    /// The value as read from `A`.
    value: f64,
    /// Zero-based position of the first occurrence in `A`.
    a_index: usize,
    /// Zero-based position of the first occurrence in `B`.
    b_index: usize,
    /// Discovery order while scanning `A`; used for `'stable'` ordering.
    order_rank: usize,
}
869
/// One matched row found by the real row-wise intersection.
#[derive(Debug)]
struct NumericRowIntersectEntry {
    /// The row's values, one per column, as read from `A`.
    row_data: Vec<f64>,
    /// Zero-based index of the first matching row in `A`.
    a_row: usize,
    /// Zero-based index of the first matching row in `B`.
    b_row: usize,
    /// Discovery order while scanning `A`; used for `'stable'` ordering.
    order_rank: usize,
}
877
/// One matched value found by the complex element-wise intersection.
#[derive(Debug)]
struct ComplexIntersectEntry {
    /// The complex value as the pair read from `A`.
    value: (f64, f64),
    /// Zero-based position of the first occurrence in `A`.
    a_index: usize,
    /// Zero-based position of the first occurrence in `B`.
    b_index: usize,
    /// Discovery order while scanning `A`.
    order_rank: usize,
}
885
/// One matched row found by the complex row-wise intersection.
#[derive(Debug)]
struct ComplexRowIntersectEntry {
    /// The row's complex values, one pair per column, as read from `A`.
    row_data: Vec<(f64, f64)>,
    /// Zero-based index of the first matching row in `A`.
    a_row: usize,
    /// Zero-based index of the first matching row in `B`.
    b_row: usize,
    /// Discovery order while scanning `A`.
    order_rank: usize,
}
893
/// One matched character found by the character-wise intersection.
#[derive(Debug)]
struct CharIntersectEntry {
    /// The matched character.
    ch: char,
    /// Zero-based column-major position of the first occurrence in `A`.
    a_index: usize,
    /// Zero-based column-major position of the first occurrence in `B`.
    b_index: usize,
    /// Discovery order while scanning `A`.
    order_rank: usize,
}
901
/// One matched row found by the character row-wise intersection.
#[derive(Debug)]
struct CharRowIntersectEntry {
    /// The row's characters, one per column, as read from `A`.
    row_data: Vec<char>,
    /// Zero-based index of the first matching row in `A`.
    a_row: usize,
    /// Zero-based index of the first matching row in `B`.
    b_row: usize,
    /// Discovery order while scanning `A`.
    order_rank: usize,
}
909
/// One matched string found by the string element-wise intersection.
#[derive(Debug)]
struct StringIntersectEntry {
    /// Owned copy of the matched string from `A`.
    value: String,
    /// Zero-based position of the first occurrence in `A`.
    a_index: usize,
    /// Zero-based position of the first occurrence in `B`.
    b_index: usize,
    /// Discovery order while scanning `A`.
    order_rank: usize,
}
917
/// One matched row found by the string row-wise intersection.
#[derive(Debug)]
struct StringRowIntersectEntry {
    /// Owned copies of the row's strings, one per column, from `A`.
    row_data: Vec<String>,
    /// Zero-based index of the first matching row in `A`.
    a_row: usize,
    /// Zero-based index of the first matching row in `B`.
    b_row: usize,
    /// Discovery order while scanning `A`.
    order_rank: usize,
}
925
926fn assemble_numeric_intersect(
927    entries: Vec<NumericIntersectEntry>,
928    opts: &IntersectOptions,
929) -> Result<IntersectEvaluation, String> {
930    let mut order: Vec<usize> = (0..entries.len()).collect();
931    match opts.order {
932        IntersectOrder::Sorted => {
933            order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
934        }
935        IntersectOrder::Stable => {
936            order.sort_by_key(|&idx| entries[idx].order_rank);
937        }
938    }
939
940    let mut values = Vec::with_capacity(order.len());
941    let mut ia = Vec::with_capacity(order.len());
942    let mut ib = Vec::with_capacity(order.len());
943    for &idx in &order {
944        let entry = &entries[idx];
945        values.push(entry.value);
946        ia.push((entry.a_index + 1) as f64);
947        ib.push((entry.b_index + 1) as f64);
948    }
949
950    let value_tensor =
951        Tensor::new(values, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
952    let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
953    let ib_tensor = Tensor::new(ib, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
954
955    Ok(IntersectEvaluation::new(
956        tensor::tensor_into_value(value_tensor),
957        ia_tensor,
958        ib_tensor,
959    ))
960}
961
962fn assemble_numeric_row_intersect(
963    entries: Vec<NumericRowIntersectEntry>,
964    opts: &IntersectOptions,
965    cols: usize,
966) -> Result<IntersectEvaluation, String> {
967    let mut order: Vec<usize> = (0..entries.len()).collect();
968    match opts.order {
969        IntersectOrder::Sorted => {
970            order.sort_by(|&lhs, &rhs| {
971                compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
972            });
973        }
974        IntersectOrder::Stable => {
975            order.sort_by_key(|&idx| entries[idx].order_rank);
976        }
977    }
978
979    let rows_out = order.len();
980    let mut values = vec![0.0f64; rows_out * cols];
981    let mut ia = Vec::with_capacity(rows_out);
982    let mut ib = Vec::with_capacity(rows_out);
983
984    for (row_pos, &entry_idx) in order.iter().enumerate() {
985        let entry = &entries[entry_idx];
986        for col in 0..cols {
987            let dest = row_pos + col * rows_out;
988            values[dest] = entry.row_data[col];
989        }
990        ia.push((entry.a_row + 1) as f64);
991        ib.push((entry.b_row + 1) as f64);
992    }
993
994    let value_tensor =
995        Tensor::new(values, vec![rows_out, cols]).map_err(|e| format!("intersect: {e}"))?;
996    let ia_tensor = Tensor::new(ia, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
997    let ib_tensor = Tensor::new(ib, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
998
999    Ok(IntersectEvaluation::new(
1000        tensor::tensor_into_value(value_tensor),
1001        ia_tensor,
1002        ib_tensor,
1003    ))
1004}
1005
1006fn assemble_complex_intersect(
1007    entries: Vec<ComplexIntersectEntry>,
1008    opts: &IntersectOptions,
1009) -> Result<IntersectEvaluation, String> {
1010    let mut order: Vec<usize> = (0..entries.len()).collect();
1011    match opts.order {
1012        IntersectOrder::Sorted => {
1013            order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
1014        }
1015        IntersectOrder::Stable => {
1016            order.sort_by_key(|&idx| entries[idx].order_rank);
1017        }
1018    }
1019
1020    let mut values = Vec::with_capacity(order.len());
1021    let mut ia = Vec::with_capacity(order.len());
1022    let mut ib = Vec::with_capacity(order.len());
1023    for &idx in &order {
1024        let entry = &entries[idx];
1025        values.push(entry.value);
1026        ia.push((entry.a_index + 1) as f64);
1027        ib.push((entry.b_index + 1) as f64);
1028    }
1029
1030    let value_tensor =
1031        ComplexTensor::new(values, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1032    let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1033    let ib_tensor = Tensor::new(ib, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1034
1035    Ok(IntersectEvaluation::new(
1036        complex_tensor_into_value(value_tensor),
1037        ia_tensor,
1038        ib_tensor,
1039    ))
1040}
1041
1042fn assemble_complex_row_intersect(
1043    entries: Vec<ComplexRowIntersectEntry>,
1044    opts: &IntersectOptions,
1045    cols: usize,
1046) -> Result<IntersectEvaluation, String> {
1047    let mut order: Vec<usize> = (0..entries.len()).collect();
1048    match opts.order {
1049        IntersectOrder::Sorted => {
1050            order.sort_by(|&lhs, &rhs| {
1051                compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1052            });
1053        }
1054        IntersectOrder::Stable => {
1055            order.sort_by_key(|&idx| entries[idx].order_rank);
1056        }
1057    }
1058
1059    let rows_out = order.len();
1060    let mut values = vec![(0.0f64, 0.0f64); rows_out * cols];
1061    let mut ia = Vec::with_capacity(rows_out);
1062    let mut ib = Vec::with_capacity(rows_out);
1063
1064    for (row_pos, &entry_idx) in order.iter().enumerate() {
1065        let entry = &entries[entry_idx];
1066        for col in 0..cols {
1067            let dest = row_pos + col * rows_out;
1068            values[dest] = entry.row_data[col];
1069        }
1070        ia.push((entry.a_row + 1) as f64);
1071        ib.push((entry.b_row + 1) as f64);
1072    }
1073
1074    let value_tensor =
1075        ComplexTensor::new(values, vec![rows_out, cols]).map_err(|e| format!("intersect: {e}"))?;
1076    let ia_tensor = Tensor::new(ia, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1077    let ib_tensor = Tensor::new(ib, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1078
1079    Ok(IntersectEvaluation::new(
1080        complex_tensor_into_value(value_tensor),
1081        ia_tensor,
1082        ib_tensor,
1083    ))
1084}
1085
1086fn assemble_char_intersect(
1087    entries: Vec<CharIntersectEntry>,
1088    opts: &IntersectOptions,
1089    b: &CharArray,
1090) -> Result<IntersectEvaluation, String> {
1091    let mut order: Vec<usize> = (0..entries.len()).collect();
1092    match opts.order {
1093        IntersectOrder::Sorted => {
1094            order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1095        }
1096        IntersectOrder::Stable => {
1097            order.sort_by_key(|&idx| entries[idx].order_rank);
1098        }
1099    }
1100
1101    let mut values = Vec::with_capacity(order.len());
1102    let mut ia = Vec::with_capacity(order.len());
1103    let mut ib = Vec::with_capacity(order.len());
1104    for &idx in &order {
1105        let entry = &entries[idx];
1106        values.push(entry.ch);
1107        ia.push((entry.a_index + 1) as f64);
1108        let b_idx = find_char_index(b, entry.ch).unwrap_or(entry.b_index);
1109        ib.push((b_idx + 1) as f64);
1110    }
1111
1112    let value_array =
1113        CharArray::new(values, order.len(), 1).map_err(|e| format!("intersect: {e}"))?;
1114    let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1115    let ib_tensor = Tensor::new(ib, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1116
1117    Ok(IntersectEvaluation::new(
1118        Value::CharArray(value_array),
1119        ia_tensor,
1120        ib_tensor,
1121    ))
1122}
1123
1124fn assemble_char_row_intersect(
1125    entries: Vec<CharRowIntersectEntry>,
1126    opts: &IntersectOptions,
1127    cols: usize,
1128) -> Result<IntersectEvaluation, String> {
1129    let mut order: Vec<usize> = (0..entries.len()).collect();
1130    match opts.order {
1131        IntersectOrder::Sorted => {
1132            order.sort_by(|&lhs, &rhs| {
1133                compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1134            });
1135        }
1136        IntersectOrder::Stable => {
1137            order.sort_by_key(|&idx| entries[idx].order_rank);
1138        }
1139    }
1140
1141    let rows_out = order.len();
1142    let mut values = vec!['\0'; rows_out * cols];
1143    let mut ia = Vec::with_capacity(rows_out);
1144    let mut ib = Vec::with_capacity(rows_out);
1145
1146    for (row_pos, &entry_idx) in order.iter().enumerate() {
1147        let entry = &entries[entry_idx];
1148        for col in 0..cols {
1149            let dest = row_pos * cols + col;
1150            values[dest] = entry.row_data[col];
1151        }
1152        ia.push((entry.a_row + 1) as f64);
1153        ib.push((entry.b_row + 1) as f64);
1154    }
1155
1156    let value_array =
1157        CharArray::new(values, rows_out, cols).map_err(|e| format!("intersect: {e}"))?;
1158    let ia_tensor = Tensor::new(ia, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1159    let ib_tensor = Tensor::new(ib, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1160
1161    Ok(IntersectEvaluation::new(
1162        Value::CharArray(value_array),
1163        ia_tensor,
1164        ib_tensor,
1165    ))
1166}
1167
1168fn assemble_string_intersect(
1169    entries: Vec<StringIntersectEntry>,
1170    opts: &IntersectOptions,
1171) -> Result<IntersectEvaluation, String> {
1172    let mut order: Vec<usize> = (0..entries.len()).collect();
1173    match opts.order {
1174        IntersectOrder::Sorted => {
1175            order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1176        }
1177        IntersectOrder::Stable => {
1178            order.sort_by_key(|&idx| entries[idx].order_rank);
1179        }
1180    }
1181
1182    let mut values = Vec::with_capacity(order.len());
1183    let mut ia = Vec::with_capacity(order.len());
1184    let mut ib = Vec::with_capacity(order.len());
1185    for &idx in &order {
1186        let entry = &entries[idx];
1187        values.push(entry.value.clone());
1188        ia.push((entry.a_index + 1) as f64);
1189        ib.push((entry.b_index + 1) as f64);
1190    }
1191
1192    let value_array =
1193        StringArray::new(values, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1194    let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1195    let ib_tensor = Tensor::new(ib, vec![order.len(), 1]).map_err(|e| format!("intersect: {e}"))?;
1196
1197    Ok(IntersectEvaluation::new(
1198        Value::StringArray(value_array),
1199        ia_tensor,
1200        ib_tensor,
1201    ))
1202}
1203
1204fn assemble_string_row_intersect(
1205    entries: Vec<StringRowIntersectEntry>,
1206    opts: &IntersectOptions,
1207    cols: usize,
1208) -> Result<IntersectEvaluation, String> {
1209    let mut order: Vec<usize> = (0..entries.len()).collect();
1210    match opts.order {
1211        IntersectOrder::Sorted => {
1212            order.sort_by(|&lhs, &rhs| {
1213                compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1214            });
1215        }
1216        IntersectOrder::Stable => {
1217            order.sort_by_key(|&idx| entries[idx].order_rank);
1218        }
1219    }
1220
1221    let rows_out = order.len();
1222    let mut values = vec![String::new(); rows_out * cols];
1223    let mut ia = Vec::with_capacity(rows_out);
1224    let mut ib = Vec::with_capacity(rows_out);
1225
1226    for (row_pos, &entry_idx) in order.iter().enumerate() {
1227        let entry = &entries[entry_idx];
1228        for col in 0..cols {
1229            let dest = row_pos + col * rows_out;
1230            values[dest] = entry.row_data[col].clone();
1231        }
1232        ia.push((entry.a_row + 1) as f64);
1233        ib.push((entry.b_row + 1) as f64);
1234    }
1235
1236    let value_array =
1237        StringArray::new(values, vec![rows_out, cols]).map_err(|e| format!("intersect: {e}"))?;
1238    let ia_tensor = Tensor::new(ia, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1239    let ib_tensor = Tensor::new(ib, vec![rows_out, 1]).map_err(|e| format!("intersect: {e}"))?;
1240
1241    Ok(IntersectEvaluation::new(
1242        Value::StringArray(value_array),
1243        ia_tensor,
1244        ib_tensor,
1245    ))
1246}
1247
1248#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1249struct NumericRowKey(Vec<u64>);
1250
1251impl NumericRowKey {
1252    fn from_slice(values: &[f64]) -> Self {
1253        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1254    }
1255}
1256
1257#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1258struct ComplexKey {
1259    re: u64,
1260    im: u64,
1261}
1262
1263impl ComplexKey {
1264    fn new(value: (f64, f64)) -> Self {
1265        Self {
1266            re: canonicalize_f64(value.0),
1267            im: canonicalize_f64(value.1),
1268        }
1269    }
1270}
1271
/// Hashable key for a row of characters, stored as Unicode scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key by converting each character to its `u32` code point.
    fn from_slice(values: &[char]) -> Self {
        let code_points = values.iter().copied().map(u32::from).collect();
        Self(code_points)
    }
}
1280
/// Hashable key for a row of strings; owns a copy of every element.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);

impl RowStringKey {
    /// Builds the key by cloning each string of the row, preserving order.
    fn from_slice(values: &[String]) -> Self {
        let owned: Vec<String> = values.iter().cloned().collect();
        Self(owned)
    }
}
1289
1290fn scalar_complex_tensor(re: f64, im: f64) -> Result<ComplexTensor, String> {
1291    ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(|e| format!("intersect: {e}"))
1292}
1293
1294fn tensor_to_complex_owned(name: &str, tensor: Tensor) -> Result<ComplexTensor, String> {
1295    let Tensor { data, shape, .. } = tensor;
1296    let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
1297    ComplexTensor::new(complex, shape).map_err(|e| format!("{name}: {e}"))
1298}
1299
1300fn value_into_complex_tensor(value: Value) -> Result<ComplexTensor, String> {
1301    match value {
1302        Value::ComplexTensor(tensor) => Ok(tensor),
1303        Value::Complex(re, im) => scalar_complex_tensor(re, im),
1304        other => {
1305            let tensor = tensor::value_into_tensor_for("intersect", other)?;
1306            tensor_to_complex_owned("intersect", tensor)
1307        }
1308    }
1309}
1310
/// Maps an `f64` to a bit pattern suitable for hashing and equality.
///
/// Every NaN collapses to one fixed pattern and both signed zeros collapse to
/// `0`; all other values keep their raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0`, so this arm covers both zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1320
/// Total ordering for `f64` values: NaNs compare equal to each other and sort
/// after every non-NaN value; everything else uses the usual numeric order.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN here, so `partial_cmp` is expected to succeed;
        // the fallback mirrors the defensive default.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1334
1335fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1336    for (lhs, rhs) in a.iter().zip(b.iter()) {
1337        let ord = compare_f64(*lhs, *rhs);
1338        if ord != Ordering::Equal {
1339            return ord;
1340        }
1341    }
1342    Ordering::Equal
1343}
1344
/// Reports whether either part of a `(re, im)` pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    [re, im].iter().any(|part| part.is_nan())
}
1348
1349fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1350    match (complex_is_nan(a), complex_is_nan(b)) {
1351        (true, true) => Ordering::Equal,
1352        (true, false) => Ordering::Greater,
1353        (false, true) => Ordering::Less,
1354        (false, false) => {
1355            let mag_a = a.0.hypot(a.1);
1356            let mag_b = b.0.hypot(b.1);
1357            let mag_cmp = compare_f64(mag_a, mag_b);
1358            if mag_cmp != Ordering::Equal {
1359                return mag_cmp;
1360            }
1361            let re_cmp = compare_f64(a.0, b.0);
1362            if re_cmp != Ordering::Equal {
1363                return re_cmp;
1364            }
1365            compare_f64(a.1, b.1)
1366        }
1367    }
1368}
1369
1370fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1371    for (lhs, rhs) in a.iter().zip(b.iter()) {
1372        let ord = compare_complex(*lhs, *rhs);
1373        if ord != Ordering::Equal {
1374            return ord;
1375        }
1376    }
1377    Ordering::Equal
1378}
1379
/// Lexicographic comparison of two character rows by code point.
///
/// Only the overlapping prefix is inspected; rows that agree on it compare
/// equal regardless of length.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|&ord| ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1389
/// Lexicographic comparison of two string rows using `String` ordering per element.
///
/// Only the overlapping prefix is inspected; rows that agree on it compare
/// equal regardless of length.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|&ord| ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;

    /// Options for an element-wise (non-`rows`) intersection with the given ordering.
    fn element_opts(order: IntersectOrder) -> IntersectOptions {
        IntersectOptions { rows: false, order }
    }

    /// Options for a row-wise intersection with the given ordering.
    fn row_opts(order: IntersectOrder) -> IntersectOptions {
        IntersectOptions { rows: true, order }
    }

    /// Converts an output value to a host tensor, panicking on failure.
    fn host(value: Value) -> Tensor {
        tensor::value_into_tensor_for("intersect", value).unwrap()
    }

    /// Returns the raw `ia` and `ib` index data of an evaluation.
    fn index_outputs(eval: &IntersectEvaluation) -> (Vec<f64>, Vec<f64>) {
        let ia = host(eval.ia_value());
        let ib = host(eval.ib_value());
        (ia.data, ib.data)
    }

    #[test]
    fn intersect_numeric_sorted() {
        let lhs = Tensor::new(vec![5.0, 7.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let rhs = Tensor::new(vec![7.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let eval = intersect_numeric_elements(lhs, rhs, &element_opts(IntersectOrder::Sorted))
            .expect("intersect");
        assert_eq!(host(eval.values_value()).data, vec![1.0, 7.0]);
        let (ia, ib) = index_outputs(&eval);
        assert_eq!(ia, vec![4.0, 2.0]);
        assert_eq!(ib, vec![2.0, 1.0]);
    }

    #[test]
    fn intersect_numeric_stable() {
        let lhs = Tensor::new(vec![4.0, 2.0, 4.0, 1.0, 3.0], vec![5, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 4.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let eval = intersect_numeric_elements(lhs, rhs, &element_opts(IntersectOrder::Stable))
            .expect("intersect");
        assert_eq!(host(eval.values_value()).data, vec![4.0, 1.0, 3.0]);
        let (ia, ib) = index_outputs(&eval);
        assert_eq!(ia, vec![1.0, 4.0, 5.0]);
        assert_eq!(ib, vec![2.0, 4.0, 1.0]);
    }

    #[test]
    fn intersect_numeric_handles_nan() {
        let lhs = Tensor::new(vec![f64::NAN, 1.0, f64::NAN], vec![3, 1]).unwrap();
        let rhs = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
        let eval = intersect_numeric_elements(lhs, rhs, &element_opts(IntersectOrder::Sorted))
            .expect("intersect");
        let common = host(eval.values_value());
        assert_eq!(common.data.len(), 1);
        assert!(common.data[0].is_nan());
        let (ia, ib) = index_outputs(&eval);
        assert_eq!(ia, vec![1.0]);
        assert_eq!(ib, vec![2.0]);
    }

    #[test]
    fn intersect_complex_with_real_inputs() {
        let lhs =
            ComplexTensor::new(vec![(1.0, 0.0), (2.0, 0.0), (3.0, 1.0)], vec![3, 1]).unwrap();
        let real = Tensor::new(vec![2.0, 4.0, 1.0], vec![3, 1]).unwrap();
        let rhs = tensor_to_complex_owned("intersect", real).unwrap();
        let eval = intersect_complex(lhs, rhs, &element_opts(IntersectOrder::Sorted))
            .expect("intersect complex");
        match eval.values_value() {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(1.0, 0.0), (2.0, 0.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
        let (ia, ib) = index_outputs(&eval);
        assert_eq!(ia, vec![1.0, 2.0]);
        assert_eq!(ib, vec![3.0, 1.0]);
    }

    #[test]
    fn intersect_numeric_rows_default() {
        let lhs = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let rhs = Tensor::new(vec![1.0, 5.0, 2.0, 6.0], vec![2, 2]).unwrap();
        let eval = intersect_numeric_rows(lhs, rhs, &row_opts(IntersectOrder::Sorted))
            .expect("intersect rows");
        let common = host(eval.values_value());
        assert_eq!(common.shape, vec![1, 2]);
        assert_eq!(common.data, vec![1.0, 2.0]);
        let (ia, ib) = index_outputs(&eval);
        assert_eq!(ia, vec![1.0]);
        assert_eq!(ib, vec![1.0]);
    }

    #[test]
    fn intersect_char_elements_basic() {
        let lhs = CharArray::new("cab".chars().collect(), 1, 3).unwrap();
        let probe = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
        assert_eq!(find_char_index(&probe, 'b'), Some(0));
        assert_eq!(find_char_index(&probe, 'c'), Some(1));
        // A second, identical array is handed to the evaluation itself.
        let rhs = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
        let eval = intersect_char_elements(lhs, rhs, &element_opts(IntersectOrder::Sorted))
            .expect("intersect char");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 2);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['b', 'c']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let (ia, ib) = index_outputs(&eval);
        assert_eq!(ia, vec![3.0, 1.0]);
        assert_eq!(ib, vec![1.0, 2.0]);
    }

    #[test]
    fn intersect_string_elements_stable() {
        let lhs = StringArray::new(
            vec!["apple".into(), "orange".into(), "pear".into()],
            vec![3, 1],
        )
        .unwrap();
        let rhs = StringArray::new(
            vec!["pear".into(), "grape".into(), "orange".into()],
            vec![3, 1],
        )
        .unwrap();
        let eval = intersect_string_elements(lhs, rhs, &element_opts(IntersectOrder::Stable))
            .expect("intersect string");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![2, 1]);
                assert_eq!(arr.data, vec!["orange".to_string(), "pear".to_string()]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let (ia, ib) = index_outputs(&eval);
        assert_eq!(ia, vec![2.0, 3.0]);
        assert_eq!(ib, vec![3.0, 1.0]);
    }

    #[test]
    fn intersect_rejects_legacy_option() {
        let input = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let extra_args = [Value::from("legacy")];
        let err = evaluate(
            Value::Tensor(input.clone()),
            Value::Tensor(input),
            &extra_args,
        )
        .unwrap_err();
        assert!(err.contains("legacy"));
    }

    #[test]
    fn intersect_rows_dimension_mismatch() {
        let lhs = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err =
            intersect_numeric_rows(lhs, rhs, &row_opts(IntersectOrder::Sorted)).unwrap_err();
        assert!(err.contains("same number of columns"));
    }

    #[test]
    fn intersect_mixed_types_error() {
        let numeric = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let chars = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
        let err = intersect_host(
            Value::Tensor(numeric),
            Value::CharArray(chars),
            &element_opts(IntersectOrder::Sorted),
        )
        .unwrap_err();
        assert!(err.contains("unsupported input type"));
    }

    #[test]
    fn intersect_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let lhs = Tensor::new(vec![4.0, 1.0, 2.0, 1.0], vec![4, 1]).unwrap();
            let rhs = Tensor::new(vec![2.0, 5.0, 1.0], vec![3, 1]).unwrap();
            let lhs_view = HostTensorView {
                data: &lhs.data,
                shape: &lhs.shape,
            };
            let rhs_view = HostTensorView {
                data: &rhs.data,
                shape: &rhs.shape,
            };
            let lhs_handle = provider.upload(&lhs_view).expect("upload A");
            let rhs_handle = provider.upload(&rhs_view).expect("upload B");
            let eval = evaluate(
                Value::GpuTensor(lhs_handle),
                Value::GpuTensor(rhs_handle),
                &[],
            )
            .expect("intersect");
            assert_eq!(host(eval.values_value()).data, vec![1.0, 2.0]);
            let (ia, ib) = index_outputs(&eval);
            assert_eq!(ia, vec![2.0, 3.0]);
            assert_eq!(ib, vec![3.0, 1.0]);
        });
    }

    #[test]
    fn intersect_two_outputs_from_evaluate() {
        let lhs = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
        let eval =
            intersect_numeric_elements(lhs, rhs, &element_opts(IntersectOrder::Sorted)).unwrap();

        // Two-output form: values plus `ia`.
        let (_c, ia) = eval.clone().into_pair();
        assert_eq!(host(ia).data, vec![1.0, 3.0]);

        // Three-output form: values, `ia`, and `ib`.
        let (_c, ia2, ib2) = eval.into_triple();
        assert_eq!(host(ia2).data, vec![1.0, 3.0]);
        assert_eq!(host(ib2).data, vec![2.0, 1.0]);
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn intersect_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let lhs = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let rhs = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();

        // Reference result from the host implementation.
        let cpu_eval = intersect_numeric_elements(
            lhs.clone(),
            rhs.clone(),
            &element_opts(IntersectOrder::Sorted),
        )
        .unwrap();
        let cpu_values = host(cpu_eval.values_value());
        let (cpu_ia, cpu_ib) = index_outputs(&cpu_eval);

        // Same inputs routed through the registered provider.
        let provider = runmat_accelerate_api::provider().expect("provider");
        let lhs_view = HostTensorView {
            data: &lhs.data,
            shape: &lhs.shape,
        };
        let rhs_view = HostTensorView {
            data: &rhs.data,
            shape: &rhs.shape,
        };
        let lhs_handle = provider.upload(&lhs_view).expect("upload A");
        let rhs_handle = provider.upload(&rhs_view).expect("upload B");
        let gpu_eval = evaluate(
            Value::GpuTensor(lhs_handle),
            Value::GpuTensor(rhs_handle),
            &[],
        )
        .expect("intersect");
        let gpu_values = host(gpu_eval.values_value());
        let (gpu_ia, gpu_ib) = index_outputs(&gpu_eval);

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_ia, cpu_ia);
        assert_eq!(gpu_ib, cpu_ib);
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}