Skip to main content

runmat_runtime/builtins/array/sorting_sets/
intersect.rs

1//! MATLAB-compatible `intersect` builtin with GPU-aware semantics for RunMat.
2//!
3//! Supports element-wise and row-wise intersections with optional stable ordering,
4//! and index outputs that mirror MathWorks MATLAB semantics. GPU tensors are
5//! gathered to host memory unless a provider supplies a dedicated `intersect`
6//! kernel hook.
7
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::GpuTensorHandle;
12use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
13use runmat_macros::runtime_builtin;
14
15use super::type_resolvers::set_values_output_type;
16use crate::build_runtime_error;
17use crate::builtins::common::arg_tokens::tokens_from_values;
18use crate::builtins::common::gpu_helpers;
19use crate::builtins::common::random_args::complex_tensor_into_value;
20use crate::builtins::common::spec::{
21    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
23};
24use crate::builtins::common::tensor;
25
/// GPU metadata registered for the `intersect` builtin.
///
/// Advertises a single custom provider hook named `"intersect"`, F32/F64
/// precisions, no broadcasting, and a `GatherImmediately` residency policy.
/// NOTE(review): the dispatch visible in this module (`evaluate` and the
/// `intersect_gpu_*` helpers) always gathers GPU operands to the host and never
/// invokes a provider hook — confirm whether the hook is consulted elsewhere.
#[runmat_macros::register_gpu_spec(
    builtin_path = "crate::builtins::array::sorting_sets::intersect"
)]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "intersect",
    op_kind: GpuOpKind::Custom("intersect"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("intersect")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes:
        "Providers may expose a dedicated intersect hook; otherwise tensors are gathered and processed on the host.",
};
44
/// Fusion metadata registered for the `intersect` builtin.
///
/// Declares no elementwise and no reduction template (both `None`), so the
/// fusion planner has nothing to inline for this builtin; per the notes it
/// materialises its inputs and ends any fusion chain.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::intersect"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "intersect",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`intersect` materialises its inputs and terminates fusion chains; upstream GPU tensors are gathered when necessary.",
};
57
58fn intersect_error(message: impl Into<String>) -> crate::RuntimeError {
59    build_runtime_error(message)
60        .with_builtin("intersect")
61        .build()
62}
63
64#[runtime_builtin(
65    name = "intersect",
66    category = "array/sorting_sets",
67    summary = "Return common elements or rows across arrays with MATLAB-compatible ordering and index outputs.",
68    keywords = "intersect,set,stable,rows,indices,gpu",
69    accel = "array_construct",
70    sink = true,
71    type_resolver(set_values_output_type),
72    builtin_path = "crate::builtins::array::sorting_sets::intersect"
73)]
74async fn intersect_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
75    let eval = evaluate(a, b, &rest).await?;
76    if let Some(out_count) = crate::output_count::current_output_count() {
77        if out_count == 0 {
78            return Ok(Value::OutputList(Vec::new()));
79        }
80        if out_count == 1 {
81            return Ok(Value::OutputList(vec![eval.into_values_value()]));
82        }
83        if out_count == 2 {
84            let (values, ia) = eval.into_pair();
85            return Ok(Value::OutputList(vec![values, ia]));
86        }
87        let (values, ia, ib) = eval.into_triple();
88        return Ok(crate::output_count::output_list_with_padding(
89            out_count,
90            vec![values, ia, ib],
91        ));
92    }
93    Ok(eval.into_values_value())
94}
95
96/// Evaluate the `intersect` builtin once and expose all outputs.
97pub async fn evaluate(
98    a: Value,
99    b: Value,
100    rest: &[Value],
101) -> crate::BuiltinResult<IntersectEvaluation> {
102    let opts = parse_options(rest)?;
103    match (a, b) {
104        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
105            intersect_gpu_pair(handle_a, handle_b, &opts).await
106        }
107        (Value::GpuTensor(handle_a), other) => {
108            intersect_gpu_mixed(handle_a, other, &opts, true).await
109        }
110        (other, Value::GpuTensor(handle_b)) => {
111            intersect_gpu_mixed(handle_b, other, &opts, false).await
112        }
113        (left, right) => intersect_host(left, right, &opts),
114    }
115}
116
/// Output ordering selected by the `'sorted'` / `'stable'` option keywords.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IntersectOrder {
    /// Results ordered by value (the default set in `parse_options`).
    Sorted,
    /// Results in order of first appearance in the first input.
    Stable,
}

/// Options parsed from the trailing arguments of `intersect`.
#[derive(Debug, Clone)]
struct IntersectOptions {
    /// Set by `'rows'`: compare whole rows instead of individual elements.
    rows: bool,
    /// Requested ordering of the output.
    order: IntersectOrder,
}
128
129fn parse_options(rest: &[Value]) -> crate::BuiltinResult<IntersectOptions> {
130    let mut opts = IntersectOptions {
131        rows: false,
132        order: IntersectOrder::Sorted,
133    };
134    let mut seen_order: Option<IntersectOrder> = None;
135
136    let tokens = tokens_from_values(rest);
137    for (arg, token) in rest.iter().zip(tokens.iter()) {
138        let text = match token {
139            crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
140            _ => {
141                let text = tensor::value_to_string(arg).ok_or_else(|| {
142                    intersect_error("intersect: expected string option arguments")
143                })?;
144                let lowered = text.trim().to_ascii_lowercase();
145                parse_intersect_option(&mut opts, &mut seen_order, &lowered)?;
146                continue;
147            }
148        };
149        parse_intersect_option(&mut opts, &mut seen_order, text)?;
150    }
151
152    Ok(opts)
153}
154
155fn parse_intersect_option(
156    opts: &mut IntersectOptions,
157    seen_order: &mut Option<IntersectOrder>,
158    lowered: &str,
159) -> crate::BuiltinResult<()> {
160    match lowered {
161        "rows" => opts.rows = true,
162        "sorted" => {
163            if let Some(prev) = seen_order {
164                if *prev != IntersectOrder::Sorted {
165                    return Err(intersect_error(
166                        "intersect: cannot combine 'sorted' with 'stable'",
167                    ));
168                }
169            }
170            *seen_order = Some(IntersectOrder::Sorted);
171            opts.order = IntersectOrder::Sorted;
172        }
173        "stable" => {
174            if let Some(prev) = seen_order {
175                if *prev != IntersectOrder::Stable {
176                    return Err(intersect_error(
177                        "intersect: cannot combine 'sorted' with 'stable'",
178                    ));
179                }
180            }
181            *seen_order = Some(IntersectOrder::Stable);
182            opts.order = IntersectOrder::Stable;
183        }
184        "legacy" | "r2012a" => {
185            return Err(intersect_error(
186                "intersect: the 'legacy' behaviour is not supported",
187            ));
188        }
189        other => {
190            return Err(intersect_error(format!(
191                "intersect: unrecognised option '{other}'"
192            )))
193        }
194    }
195    Ok(())
196}
197
198async fn intersect_gpu_pair(
199    handle_a: GpuTensorHandle,
200    handle_b: GpuTensorHandle,
201    opts: &IntersectOptions,
202) -> crate::BuiltinResult<IntersectEvaluation> {
203    let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
204    let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
205    intersect_numeric(tensor_a, tensor_b, opts)
206}
207
208async fn intersect_gpu_mixed(
209    handle_gpu: GpuTensorHandle,
210    other: Value,
211    opts: &IntersectOptions,
212    gpu_is_a: bool,
213) -> crate::BuiltinResult<IntersectEvaluation> {
214    let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
215    let tensor_other =
216        tensor::value_into_tensor_for("intersect", other).map_err(|e| intersect_error(e))?;
217    if gpu_is_a {
218        intersect_numeric(tensor_gpu, tensor_other, opts)
219    } else {
220        intersect_numeric(tensor_other, tensor_gpu, opts)
221    }
222}
223
224fn intersect_host(
225    a: Value,
226    b: Value,
227    opts: &IntersectOptions,
228) -> crate::BuiltinResult<IntersectEvaluation> {
229    match (a, b) {
230        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => intersect_complex(at, bt, opts),
231        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
232            let bt = scalar_complex_tensor(re, im)?;
233            intersect_complex(at, bt, opts)
234        }
235        (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
236            let at = scalar_complex_tensor(re, im)?;
237            intersect_complex(at, bt, opts)
238        }
239        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
240            let at = scalar_complex_tensor(a_re, a_im)?;
241            let bt = scalar_complex_tensor(b_re, b_im)?;
242            intersect_complex(at, bt, opts)
243        }
244        (Value::ComplexTensor(at), other) => {
245            let bt = value_into_complex_tensor(other)?;
246            intersect_complex(at, bt, opts)
247        }
248        (other, Value::ComplexTensor(bt)) => {
249            let at = value_into_complex_tensor(other)?;
250            intersect_complex(at, bt, opts)
251        }
252        (Value::Complex(re, im), other) => {
253            let at = scalar_complex_tensor(re, im)?;
254            let bt = value_into_complex_tensor(other)?;
255            intersect_complex(at, bt, opts)
256        }
257        (other, Value::Complex(re, im)) => {
258            let at = value_into_complex_tensor(other)?;
259            let bt = scalar_complex_tensor(re, im)?;
260            intersect_complex(at, bt, opts)
261        }
262
263        (Value::CharArray(ac), Value::CharArray(bc)) => intersect_char(ac, bc, opts),
264
265        (Value::StringArray(astring), Value::StringArray(bstring)) => {
266            intersect_string(astring, bstring, opts)
267        }
268        (Value::StringArray(astring), Value::String(b)) => {
269            let bstring = StringArray::new(vec![b], vec![1, 1])
270                .map_err(|e| intersect_error(format!("intersect: {e}")))?;
271            intersect_string(astring, bstring, opts)
272        }
273        (Value::String(a), Value::StringArray(bstring)) => {
274            let astring = StringArray::new(vec![a], vec![1, 1])
275                .map_err(|e| intersect_error(format!("intersect: {e}")))?;
276            intersect_string(astring, bstring, opts)
277        }
278        (Value::String(a), Value::String(b)) => {
279            let astring = StringArray::new(vec![a], vec![1, 1])
280                .map_err(|e| intersect_error(format!("intersect: {e}")))?;
281            let bstring = StringArray::new(vec![b], vec![1, 1])
282                .map_err(|e| intersect_error(format!("intersect: {e}")))?;
283            intersect_string(astring, bstring, opts)
284        }
285
286        (left, right) => {
287            let tensor_a =
288                tensor::value_into_tensor_for("intersect", left).map_err(|e| intersect_error(e))?;
289            let tensor_b = tensor::value_into_tensor_for("intersect", right)
290                .map_err(|e| intersect_error(e))?;
291            intersect_numeric(tensor_a, tensor_b, opts)
292        }
293    }
294}
295
296fn intersect_numeric(
297    a: Tensor,
298    b: Tensor,
299    opts: &IntersectOptions,
300) -> crate::BuiltinResult<IntersectEvaluation> {
301    if opts.rows {
302        intersect_numeric_rows(a, b, opts)
303    } else {
304        intersect_numeric_elements(a, b, opts)
305    }
306}
307
308fn intersect_numeric_elements(
309    a: Tensor,
310    b: Tensor,
311    opts: &IntersectOptions,
312) -> crate::BuiltinResult<IntersectEvaluation> {
313    let mut b_map: HashMap<u64, usize> = HashMap::new();
314    for (idx, &value) in b.data.iter().enumerate() {
315        let key = canonicalize_f64(value);
316        b_map.entry(key).or_insert(idx);
317    }
318
319    let mut seen: HashSet<u64> = HashSet::new();
320    let mut entries = Vec::<NumericIntersectEntry>::new();
321    let mut order_counter = 0usize;
322
323    for (idx, &value) in a.data.iter().enumerate() {
324        let key = canonicalize_f64(value);
325        if seen.contains(&key) {
326            continue;
327        }
328        if let Some(&b_idx) = b_map.get(&key) {
329            entries.push(NumericIntersectEntry {
330                value,
331                a_index: idx,
332                b_index: b_idx,
333                order_rank: order_counter,
334            });
335            seen.insert(key);
336            order_counter += 1;
337        }
338    }
339
340    assemble_numeric_intersect(entries, opts)
341}
342
343fn intersect_numeric_rows(
344    a: Tensor,
345    b: Tensor,
346    opts: &IntersectOptions,
347) -> crate::BuiltinResult<IntersectEvaluation> {
348    if a.shape.len() != 2 || b.shape.len() != 2 {
349        return Err(intersect_error(
350            "intersect: 'rows' option requires 2-D numeric matrices",
351        ));
352    }
353    if a.shape[1] != b.shape[1] {
354        return Err(intersect_error(
355            "intersect: inputs must have the same number of columns when using 'rows'",
356        ));
357    }
358    let rows_a = a.shape[0];
359    let cols = a.shape[1];
360    let rows_b = b.shape[0];
361
362    let mut b_map: HashMap<NumericRowKey, usize> = HashMap::new();
363    for r in 0..rows_b {
364        let mut row_values = Vec::with_capacity(cols);
365        for c in 0..cols {
366            let idx = r + c * rows_b;
367            row_values.push(b.data[idx]);
368        }
369        let key = NumericRowKey::from_slice(&row_values);
370        b_map.entry(key).or_insert(r);
371    }
372
373    let mut seen: HashSet<NumericRowKey> = HashSet::new();
374    let mut entries = Vec::<NumericRowIntersectEntry>::new();
375    let mut order_counter = 0usize;
376
377    for r in 0..rows_a {
378        let mut row_values = Vec::with_capacity(cols);
379        for c in 0..cols {
380            let idx = r + c * rows_a;
381            row_values.push(a.data[idx]);
382        }
383        let key = NumericRowKey::from_slice(&row_values);
384        if seen.contains(&key) {
385            continue;
386        }
387        if let Some(&b_row) = b_map.get(&key) {
388            entries.push(NumericRowIntersectEntry {
389                row_data: row_values,
390                a_row: r,
391                b_row,
392                order_rank: order_counter,
393            });
394            seen.insert(key);
395            order_counter += 1;
396        }
397    }
398
399    assemble_numeric_row_intersect(entries, opts, cols)
400}
401
402fn intersect_complex(
403    a: ComplexTensor,
404    b: ComplexTensor,
405    opts: &IntersectOptions,
406) -> crate::BuiltinResult<IntersectEvaluation> {
407    if opts.rows {
408        intersect_complex_rows(a, b, opts)
409    } else {
410        intersect_complex_elements(a, b, opts)
411    }
412}
413
414fn intersect_complex_elements(
415    a: ComplexTensor,
416    b: ComplexTensor,
417    opts: &IntersectOptions,
418) -> crate::BuiltinResult<IntersectEvaluation> {
419    let mut b_map: HashMap<ComplexKey, usize> = HashMap::new();
420    for (idx, &value) in b.data.iter().enumerate() {
421        let key = ComplexKey::new(value);
422        b_map.entry(key).or_insert(idx);
423    }
424
425    let mut seen: HashSet<ComplexKey> = HashSet::new();
426    let mut entries = Vec::<ComplexIntersectEntry>::new();
427    let mut order_counter = 0usize;
428
429    for (idx, &value) in a.data.iter().enumerate() {
430        let key = ComplexKey::new(value);
431        if seen.contains(&key) {
432            continue;
433        }
434        if let Some(&b_idx) = b_map.get(&key) {
435            entries.push(ComplexIntersectEntry {
436                value,
437                a_index: idx,
438                b_index: b_idx,
439                order_rank: order_counter,
440            });
441            seen.insert(key);
442            order_counter += 1;
443        }
444    }
445
446    assemble_complex_intersect(entries, opts)
447}
448
449fn intersect_complex_rows(
450    a: ComplexTensor,
451    b: ComplexTensor,
452    opts: &IntersectOptions,
453) -> crate::BuiltinResult<IntersectEvaluation> {
454    if a.shape.len() != 2 || b.shape.len() != 2 {
455        return Err(intersect_error(
456            "intersect: 'rows' option requires 2-D complex matrices",
457        ));
458    }
459    if a.shape[1] != b.shape[1] {
460        return Err(intersect_error(
461            "intersect: inputs must have the same number of columns when using 'rows'",
462        ));
463    }
464    let rows_a = a.shape[0];
465    let cols = a.shape[1];
466    let rows_b = b.shape[0];
467
468    let mut b_map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
469    for r in 0..rows_b {
470        let mut row_keys = Vec::with_capacity(cols);
471        for c in 0..cols {
472            let idx = r + c * rows_b;
473            row_keys.push(ComplexKey::new(b.data[idx]));
474        }
475        b_map.entry(row_keys).or_insert(r);
476    }
477
478    let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
479    let mut entries = Vec::<ComplexRowIntersectEntry>::new();
480    let mut order_counter = 0usize;
481
482    for r in 0..rows_a {
483        let mut row_values = Vec::with_capacity(cols);
484        let mut row_keys = Vec::with_capacity(cols);
485        for c in 0..cols {
486            let idx = r + c * rows_a;
487            let value = a.data[idx];
488            row_values.push(value);
489            row_keys.push(ComplexKey::new(value));
490        }
491        if seen.contains(&row_keys) {
492            continue;
493        }
494        if let Some(&b_row) = b_map.get(&row_keys) {
495            entries.push(ComplexRowIntersectEntry {
496                row_data: row_values,
497                a_row: r,
498                b_row,
499                order_rank: order_counter,
500            });
501            seen.insert(row_keys);
502            order_counter += 1;
503        }
504    }
505
506    assemble_complex_row_intersect(entries, opts, cols)
507}
508
509fn intersect_char(
510    a: CharArray,
511    b: CharArray,
512    opts: &IntersectOptions,
513) -> crate::BuiltinResult<IntersectEvaluation> {
514    if opts.rows {
515        intersect_char_rows(a, b, opts)
516    } else {
517        intersect_char_elements(a, b, opts)
518    }
519}
520
521fn intersect_char_elements(
522    a: CharArray,
523    b: CharArray,
524    opts: &IntersectOptions,
525) -> crate::BuiltinResult<IntersectEvaluation> {
526    let mut seen: HashSet<u32> = HashSet::new();
527    let mut entries = Vec::<CharIntersectEntry>::new();
528    let mut order_counter = 0usize;
529
530    for col in 0..a.cols {
531        for row in 0..a.rows {
532            let linear_idx = row + col * a.rows;
533            let data_idx = row * a.cols + col;
534            let ch = a.data[data_idx];
535            let key = ch as u32;
536            if seen.contains(&key) {
537                continue;
538            }
539            if let Some(b_idx) = find_char_index(&b, ch) {
540                entries.push(CharIntersectEntry {
541                    ch,
542                    a_index: linear_idx,
543                    b_index: b_idx,
544                    order_rank: order_counter,
545                });
546                seen.insert(key);
547                order_counter += 1;
548            }
549        }
550    }
551
552    assemble_char_intersect(entries, opts, &b)
553}
554
555fn intersect_char_rows(
556    a: CharArray,
557    b: CharArray,
558    opts: &IntersectOptions,
559) -> crate::BuiltinResult<IntersectEvaluation> {
560    if a.cols != b.cols {
561        return Err(intersect_error(
562            "intersect: inputs must have the same number of columns when using 'rows'",
563        ));
564    }
565    let rows_a = a.rows;
566    let rows_b = b.rows;
567    let cols = a.cols;
568
569    let mut b_map: HashMap<RowCharKey, usize> = HashMap::new();
570    for r in 0..rows_b {
571        let mut row_values = Vec::with_capacity(cols);
572        for c in 0..cols {
573            let idx = r * cols + c;
574            row_values.push(b.data[idx]);
575        }
576        let key = RowCharKey::from_slice(&row_values);
577        b_map.entry(key).or_insert(r);
578    }
579
580    let mut seen: HashSet<RowCharKey> = HashSet::new();
581    let mut entries = Vec::<CharRowIntersectEntry>::new();
582    let mut order_counter = 0usize;
583
584    for r in 0..rows_a {
585        let mut row_values = Vec::with_capacity(cols);
586        for c in 0..cols {
587            let idx = r * cols + c;
588            row_values.push(a.data[idx]);
589        }
590        let key = RowCharKey::from_slice(&row_values);
591        if seen.contains(&key) {
592            continue;
593        }
594        if let Some(&b_row) = b_map.get(&key) {
595            entries.push(CharRowIntersectEntry {
596                row_data: row_values,
597                a_row: r,
598                b_row,
599                order_rank: order_counter,
600            });
601            seen.insert(key);
602            order_counter += 1;
603        }
604    }
605
606    assemble_char_row_intersect(entries, opts, cols)
607}
608
609fn find_char_index(array: &CharArray, target: char) -> Option<usize> {
610    for col in 0..array.cols {
611        for row in 0..array.rows {
612            let data_idx = row * array.cols + col;
613            if array.data[data_idx] == target {
614                return Some(row + col * array.rows);
615            }
616        }
617    }
618    None
619}
620
621fn intersect_string(
622    a: StringArray,
623    b: StringArray,
624    opts: &IntersectOptions,
625) -> crate::BuiltinResult<IntersectEvaluation> {
626    if opts.rows {
627        intersect_string_rows(a, b, opts)
628    } else {
629        intersect_string_elements(a, b, opts)
630    }
631}
632
633fn intersect_string_elements(
634    a: StringArray,
635    b: StringArray,
636    opts: &IntersectOptions,
637) -> crate::BuiltinResult<IntersectEvaluation> {
638    let mut b_map: HashMap<String, usize> = HashMap::new();
639    for (idx, value) in b.data.iter().enumerate() {
640        b_map.entry(value.clone()).or_insert(idx);
641    }
642
643    let mut seen: HashSet<String> = HashSet::new();
644    let mut entries = Vec::<StringIntersectEntry>::new();
645    let mut order_counter = 0usize;
646
647    for (idx, value) in a.data.iter().enumerate() {
648        if seen.contains(value) {
649            continue;
650        }
651        if let Some(&b_idx) = b_map.get(value) {
652            entries.push(StringIntersectEntry {
653                value: value.clone(),
654                a_index: idx,
655                b_index: b_idx,
656                order_rank: order_counter,
657            });
658            seen.insert(value.clone());
659            order_counter += 1;
660        }
661    }
662
663    assemble_string_intersect(entries, opts)
664}
665
666fn intersect_string_rows(
667    a: StringArray,
668    b: StringArray,
669    opts: &IntersectOptions,
670) -> crate::BuiltinResult<IntersectEvaluation> {
671    if a.shape.len() != 2 || b.shape.len() != 2 {
672        return Err(intersect_error(
673            "intersect: 'rows' option requires 2-D string arrays",
674        ));
675    }
676    if a.shape[1] != b.shape[1] {
677        return Err(intersect_error(
678            "intersect: inputs must have the same number of columns when using 'rows'",
679        ));
680    }
681    let rows_a = a.shape[0];
682    let cols = a.shape[1];
683    let rows_b = b.shape[0];
684
685    let mut b_map: HashMap<RowStringKey, usize> = HashMap::new();
686    for r in 0..rows_b {
687        let mut row_values = Vec::with_capacity(cols);
688        for c in 0..cols {
689            let idx = r + c * rows_b;
690            row_values.push(b.data[idx].clone());
691        }
692        let key = RowStringKey::from_slice(&row_values);
693        b_map.entry(key).or_insert(r);
694    }
695
696    let mut seen: HashSet<RowStringKey> = HashSet::new();
697    let mut entries = Vec::<StringRowIntersectEntry>::new();
698    let mut order_counter = 0usize;
699
700    for r in 0..rows_a {
701        let mut row_values = Vec::with_capacity(cols);
702        for c in 0..cols {
703            let idx = r + c * rows_a;
704            row_values.push(a.data[idx].clone());
705        }
706        let key = RowStringKey::from_slice(&row_values);
707        if seen.contains(&key) {
708            continue;
709        }
710        if let Some(&b_row) = b_map.get(&key) {
711            entries.push(StringRowIntersectEntry {
712                row_data: row_values,
713                a_row: r,
714                b_row,
715                order_rank: order_counter,
716            });
717            seen.insert(key);
718            order_counter += 1;
719        }
720    }
721
722    assemble_string_row_intersect(entries, opts, cols)
723}
724
725#[derive(Debug, Clone)]
726pub struct IntersectEvaluation {
727    values: Value,
728    ia: Tensor,
729    ib: Tensor,
730}
731
732impl IntersectEvaluation {
733    fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
734        Self { values, ia, ib }
735    }
736
737    pub fn into_values_value(self) -> Value {
738        self.values
739    }
740
741    pub fn into_pair(self) -> (Value, Value) {
742        let ia = tensor::tensor_into_value(self.ia);
743        (self.values, ia)
744    }
745
746    pub fn into_triple(self) -> (Value, Value, Value) {
747        let ia = tensor::tensor_into_value(self.ia);
748        let ib = tensor::tensor_into_value(self.ib);
749        (self.values, ia, ib)
750    }
751
752    pub fn values_value(&self) -> Value {
753        self.values.clone()
754    }
755
756    pub fn ia_value(&self) -> Value {
757        tensor::tensor_into_value(self.ia.clone())
758    }
759
760    pub fn ib_value(&self) -> Value {
761        tensor::tensor_into_value(self.ib.clone())
762    }
763}
764
/// One common real value found by `intersect_numeric_elements`.
#[derive(Debug)]
struct NumericIntersectEntry {
    /// The value as it appeared in `a`.
    value: f64,
    /// Zero-based index of the first occurrence in `a` (made one-based when assembled).
    a_index: usize,
    /// Zero-based index of the first occurrence in `b` (made one-based when assembled).
    b_index: usize,
    /// Position in first-appearance order within `a`; drives `'stable'` output.
    order_rank: usize,
}

/// One common real row found by `intersect_numeric_rows`.
#[derive(Debug)]
struct NumericRowIntersectEntry {
    /// The row's values, one per column.
    row_data: Vec<f64>,
    /// Zero-based row of the first occurrence in `a`.
    a_row: usize,
    /// Zero-based row of the first occurrence in `b`.
    b_row: usize,
    /// Position in first-appearance order within `a`.
    order_rank: usize,
}

/// One common complex value, stored as `(re, im)`.
#[derive(Debug)]
struct ComplexIntersectEntry {
    value: (f64, f64),
    /// Zero-based index of the first occurrence in `a`.
    a_index: usize,
    /// Zero-based index of the first occurrence in `b`.
    b_index: usize,
    /// Position in first-appearance order within `a`.
    order_rank: usize,
}

/// One common complex row; each element is `(re, im)`.
#[derive(Debug)]
struct ComplexRowIntersectEntry {
    row_data: Vec<(f64, f64)>,
    /// Zero-based row of the first occurrence in `a`.
    a_row: usize,
    /// Zero-based row of the first occurrence in `b`.
    b_row: usize,
    /// Position in first-appearance order within `a`.
    order_rank: usize,
}

/// One common character found by `intersect_char_elements`.
#[derive(Debug)]
struct CharIntersectEntry {
    ch: char,
    /// Zero-based column-major index of the first occurrence in `a`.
    a_index: usize,
    /// Zero-based column-major index of the first occurrence in `b`.
    b_index: usize,
    /// Position in first-appearance order within `a`.
    order_rank: usize,
}

/// One common char row found by `intersect_char_rows`.
#[derive(Debug)]
struct CharRowIntersectEntry {
    row_data: Vec<char>,
    /// Zero-based row of the first occurrence in `a`.
    a_row: usize,
    /// Zero-based row of the first occurrence in `b`.
    b_row: usize,
    /// Position in first-appearance order within `a`.
    order_rank: usize,
}

/// One common string found by `intersect_string_elements`.
#[derive(Debug)]
struct StringIntersectEntry {
    value: String,
    /// Zero-based index of the first occurrence in `a`.
    a_index: usize,
    /// Zero-based index of the first occurrence in `b`.
    b_index: usize,
    /// Position in first-appearance order within `a`.
    order_rank: usize,
}

/// One common string row found by `intersect_string_rows`.
#[derive(Debug)]
struct StringRowIntersectEntry {
    row_data: Vec<String>,
    /// Zero-based row of the first occurrence in `a`.
    a_row: usize,
    /// Zero-based row of the first occurrence in `b`.
    b_row: usize,
    /// Position in first-appearance order within `a`.
    order_rank: usize,
}
828
829fn assemble_numeric_intersect(
830    entries: Vec<NumericIntersectEntry>,
831    opts: &IntersectOptions,
832) -> crate::BuiltinResult<IntersectEvaluation> {
833    let mut order: Vec<usize> = (0..entries.len()).collect();
834    match opts.order {
835        IntersectOrder::Sorted => {
836            order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
837        }
838        IntersectOrder::Stable => {
839            order.sort_by_key(|&idx| entries[idx].order_rank);
840        }
841    }
842
843    let mut values = Vec::with_capacity(order.len());
844    let mut ia = Vec::with_capacity(order.len());
845    let mut ib = Vec::with_capacity(order.len());
846    for &idx in &order {
847        let entry = &entries[idx];
848        values.push(entry.value);
849        ia.push((entry.a_index + 1) as f64);
850        ib.push((entry.b_index + 1) as f64);
851    }
852
853    let value_tensor = Tensor::new(values, vec![order.len(), 1])
854        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
855    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
856        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
857    let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
858        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
859
860    Ok(IntersectEvaluation::new(
861        tensor::tensor_into_value(value_tensor),
862        ia_tensor,
863        ib_tensor,
864    ))
865}
866
867fn assemble_numeric_row_intersect(
868    entries: Vec<NumericRowIntersectEntry>,
869    opts: &IntersectOptions,
870    cols: usize,
871) -> crate::BuiltinResult<IntersectEvaluation> {
872    let mut order: Vec<usize> = (0..entries.len()).collect();
873    match opts.order {
874        IntersectOrder::Sorted => {
875            order.sort_by(|&lhs, &rhs| {
876                compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
877            });
878        }
879        IntersectOrder::Stable => {
880            order.sort_by_key(|&idx| entries[idx].order_rank);
881        }
882    }
883
884    let rows_out = order.len();
885    let mut values = vec![0.0f64; rows_out * cols];
886    let mut ia = Vec::with_capacity(rows_out);
887    let mut ib = Vec::with_capacity(rows_out);
888
889    for (row_pos, &entry_idx) in order.iter().enumerate() {
890        let entry = &entries[entry_idx];
891        for col in 0..cols {
892            let dest = row_pos + col * rows_out;
893            values[dest] = entry.row_data[col];
894        }
895        ia.push((entry.a_row + 1) as f64);
896        ib.push((entry.b_row + 1) as f64);
897    }
898
899    let value_tensor = Tensor::new(values, vec![rows_out, cols])
900        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
901    let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
902        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
903    let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
904        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
905
906    Ok(IntersectEvaluation::new(
907        tensor::tensor_into_value(value_tensor),
908        ia_tensor,
909        ib_tensor,
910    ))
911}
912
913fn assemble_complex_intersect(
914    entries: Vec<ComplexIntersectEntry>,
915    opts: &IntersectOptions,
916) -> crate::BuiltinResult<IntersectEvaluation> {
917    let mut order: Vec<usize> = (0..entries.len()).collect();
918    match opts.order {
919        IntersectOrder::Sorted => {
920            order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
921        }
922        IntersectOrder::Stable => {
923            order.sort_by_key(|&idx| entries[idx].order_rank);
924        }
925    }
926
927    let mut values = Vec::with_capacity(order.len());
928    let mut ia = Vec::with_capacity(order.len());
929    let mut ib = Vec::with_capacity(order.len());
930    for &idx in &order {
931        let entry = &entries[idx];
932        values.push(entry.value);
933        ia.push((entry.a_index + 1) as f64);
934        ib.push((entry.b_index + 1) as f64);
935    }
936
937    let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
938        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
939    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
940        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
941    let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
942        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
943
944    Ok(IntersectEvaluation::new(
945        complex_tensor_into_value(value_tensor),
946        ia_tensor,
947        ib_tensor,
948    ))
949}
950
951fn assemble_complex_row_intersect(
952    entries: Vec<ComplexRowIntersectEntry>,
953    opts: &IntersectOptions,
954    cols: usize,
955) -> crate::BuiltinResult<IntersectEvaluation> {
956    let mut order: Vec<usize> = (0..entries.len()).collect();
957    match opts.order {
958        IntersectOrder::Sorted => {
959            order.sort_by(|&lhs, &rhs| {
960                compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
961            });
962        }
963        IntersectOrder::Stable => {
964            order.sort_by_key(|&idx| entries[idx].order_rank);
965        }
966    }
967
968    let rows_out = order.len();
969    let mut values = vec![(0.0f64, 0.0f64); rows_out * cols];
970    let mut ia = Vec::with_capacity(rows_out);
971    let mut ib = Vec::with_capacity(rows_out);
972
973    for (row_pos, &entry_idx) in order.iter().enumerate() {
974        let entry = &entries[entry_idx];
975        for col in 0..cols {
976            let dest = row_pos + col * rows_out;
977            values[dest] = entry.row_data[col];
978        }
979        ia.push((entry.a_row + 1) as f64);
980        ib.push((entry.b_row + 1) as f64);
981    }
982
983    let value_tensor = ComplexTensor::new(values, vec![rows_out, cols])
984        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
985    let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
986        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
987    let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
988        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
989
990    Ok(IntersectEvaluation::new(
991        complex_tensor_into_value(value_tensor),
992        ia_tensor,
993        ib_tensor,
994    ))
995}
996
997fn assemble_char_intersect(
998    entries: Vec<CharIntersectEntry>,
999    opts: &IntersectOptions,
1000    b: &CharArray,
1001) -> crate::BuiltinResult<IntersectEvaluation> {
1002    let mut order: Vec<usize> = (0..entries.len()).collect();
1003    match opts.order {
1004        IntersectOrder::Sorted => {
1005            order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1006        }
1007        IntersectOrder::Stable => {
1008            order.sort_by_key(|&idx| entries[idx].order_rank);
1009        }
1010    }
1011
1012    let mut values = Vec::with_capacity(order.len());
1013    let mut ia = Vec::with_capacity(order.len());
1014    let mut ib = Vec::with_capacity(order.len());
1015    for &idx in &order {
1016        let entry = &entries[idx];
1017        values.push(entry.ch);
1018        ia.push((entry.a_index + 1) as f64);
1019        let b_idx = find_char_index(b, entry.ch).unwrap_or(entry.b_index);
1020        ib.push((b_idx + 1) as f64);
1021    }
1022
1023    let value_array = CharArray::new(values, order.len(), 1)
1024        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1025    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1026        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1027    let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1028        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1029
1030    Ok(IntersectEvaluation::new(
1031        Value::CharArray(value_array),
1032        ia_tensor,
1033        ib_tensor,
1034    ))
1035}
1036
1037fn assemble_char_row_intersect(
1038    entries: Vec<CharRowIntersectEntry>,
1039    opts: &IntersectOptions,
1040    cols: usize,
1041) -> crate::BuiltinResult<IntersectEvaluation> {
1042    let mut order: Vec<usize> = (0..entries.len()).collect();
1043    match opts.order {
1044        IntersectOrder::Sorted => {
1045            order.sort_by(|&lhs, &rhs| {
1046                compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1047            });
1048        }
1049        IntersectOrder::Stable => {
1050            order.sort_by_key(|&idx| entries[idx].order_rank);
1051        }
1052    }
1053
1054    let rows_out = order.len();
1055    let mut values = vec!['\0'; rows_out * cols];
1056    let mut ia = Vec::with_capacity(rows_out);
1057    let mut ib = Vec::with_capacity(rows_out);
1058
1059    for (row_pos, &entry_idx) in order.iter().enumerate() {
1060        let entry = &entries[entry_idx];
1061        for col in 0..cols {
1062            let dest = row_pos * cols + col;
1063            values[dest] = entry.row_data[col];
1064        }
1065        ia.push((entry.a_row + 1) as f64);
1066        ib.push((entry.b_row + 1) as f64);
1067    }
1068
1069    let value_array = CharArray::new(values, rows_out, cols)
1070        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1071    let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1072        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1073    let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1074        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1075
1076    Ok(IntersectEvaluation::new(
1077        Value::CharArray(value_array),
1078        ia_tensor,
1079        ib_tensor,
1080    ))
1081}
1082
1083fn assemble_string_intersect(
1084    entries: Vec<StringIntersectEntry>,
1085    opts: &IntersectOptions,
1086) -> crate::BuiltinResult<IntersectEvaluation> {
1087    let mut order: Vec<usize> = (0..entries.len()).collect();
1088    match opts.order {
1089        IntersectOrder::Sorted => {
1090            order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1091        }
1092        IntersectOrder::Stable => {
1093            order.sort_by_key(|&idx| entries[idx].order_rank);
1094        }
1095    }
1096
1097    let mut values = Vec::with_capacity(order.len());
1098    let mut ia = Vec::with_capacity(order.len());
1099    let mut ib = Vec::with_capacity(order.len());
1100    for &idx in &order {
1101        let entry = &entries[idx];
1102        values.push(entry.value.clone());
1103        ia.push((entry.a_index + 1) as f64);
1104        ib.push((entry.b_index + 1) as f64);
1105    }
1106
1107    let value_array = StringArray::new(values, vec![order.len(), 1])
1108        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1109    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1110        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1111    let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1112        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1113
1114    Ok(IntersectEvaluation::new(
1115        Value::StringArray(value_array),
1116        ia_tensor,
1117        ib_tensor,
1118    ))
1119}
1120
1121fn assemble_string_row_intersect(
1122    entries: Vec<StringRowIntersectEntry>,
1123    opts: &IntersectOptions,
1124    cols: usize,
1125) -> crate::BuiltinResult<IntersectEvaluation> {
1126    let mut order: Vec<usize> = (0..entries.len()).collect();
1127    match opts.order {
1128        IntersectOrder::Sorted => {
1129            order.sort_by(|&lhs, &rhs| {
1130                compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1131            });
1132        }
1133        IntersectOrder::Stable => {
1134            order.sort_by_key(|&idx| entries[idx].order_rank);
1135        }
1136    }
1137
1138    let rows_out = order.len();
1139    let mut values = vec![String::new(); rows_out * cols];
1140    let mut ia = Vec::with_capacity(rows_out);
1141    let mut ib = Vec::with_capacity(rows_out);
1142
1143    for (row_pos, &entry_idx) in order.iter().enumerate() {
1144        let entry = &entries[entry_idx];
1145        for col in 0..cols {
1146            let dest = row_pos + col * rows_out;
1147            values[dest] = entry.row_data[col].clone();
1148        }
1149        ia.push((entry.a_row + 1) as f64);
1150        ib.push((entry.b_row + 1) as f64);
1151    }
1152
1153    let value_array = StringArray::new(values, vec![rows_out, cols])
1154        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1155    let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1156        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1157    let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1158        .map_err(|e| intersect_error(format!("intersect: {e}")))?;
1159
1160    Ok(IntersectEvaluation::new(
1161        Value::StringArray(value_array),
1162        ia_tensor,
1163        ib_tensor,
1164    ))
1165}
1166
1167#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1168struct NumericRowKey(Vec<u64>);
1169
1170impl NumericRowKey {
1171    fn from_slice(values: &[f64]) -> Self {
1172        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1173    }
1174}
1175
1176#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1177struct ComplexKey {
1178    re: u64,
1179    im: u64,
1180}
1181
1182impl ComplexKey {
1183    fn new(value: (f64, f64)) -> Self {
1184        Self {
1185            re: canonicalize_f64(value.0),
1186            im: canonicalize_f64(value.1),
1187        }
1188    }
1189}
1190
/// Hashable key for a char row, storing each character's Unicode scalar value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Converts every `char` in `values` to its `u32` code point, in order.
    fn from_slice(values: &[char]) -> Self {
        Self(values.iter().map(|&ch| u32::from(ch)).collect())
    }
}
1199
/// Hashable key for a string row: an owned copy of the row's strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);

impl RowStringKey {
    /// Clones every string of `values` into a new key, preserving order.
    fn from_slice(values: &[String]) -> Self {
        Self(values.iter().cloned().collect())
    }
}
1208
1209fn scalar_complex_tensor(re: f64, im: f64) -> crate::BuiltinResult<ComplexTensor> {
1210    ComplexTensor::new(vec![(re, im)], vec![1, 1])
1211        .map_err(|e| intersect_error(format!("intersect: {e}")))
1212}
1213
1214fn tensor_to_complex_owned(name: &str, tensor: Tensor) -> crate::BuiltinResult<ComplexTensor> {
1215    let Tensor { data, shape, .. } = tensor;
1216    let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
1217    ComplexTensor::new(complex, shape).map_err(|e| intersect_error(format!("{name}: {e}")))
1218}
1219
1220fn value_into_complex_tensor(value: Value) -> crate::BuiltinResult<ComplexTensor> {
1221    match value {
1222        Value::ComplexTensor(tensor) => Ok(tensor),
1223        Value::Complex(re, im) => scalar_complex_tensor(re, im),
1224        other => {
1225            let tensor = tensor::value_into_tensor_for("intersect", other)
1226                .map_err(|e| intersect_error(e))?;
1227            tensor_to_complex_owned("intersect", tensor)
1228        }
1229    }
1230}
1231
/// Maps an `f64` to a bit pattern suitable for hashing/equality.
///
/// Every NaN payload collapses to one quiet-NaN pattern, and `-0.0` / `+0.0`
/// both map to `0`. All other values keep their raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    if value.is_nan() {
        return CANONICAL_NAN_BITS;
    }
    if value == 0.0 {
        0
    } else {
        value.to_bits()
    }
}
1241
/// Total ordering on `f64` where every NaN sorts after all non-NaN values and
/// NaNs compare equal to each other. Non-NaN values use IEEE comparison, so
/// `-0.0` and `0.0` compare equal.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1255
1256fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1257    for (lhs, rhs) in a.iter().zip(b.iter()) {
1258        let ord = compare_f64(*lhs, *rhs);
1259        if ord != Ordering::Equal {
1260            return ord;
1261        }
1262    }
1263    Ordering::Equal
1264}
1265
/// True when either the real or the imaginary component is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
1269
/// Total ordering used to sort complex values, following MATLAB's rules for
/// sorting complex input.
///
/// Values with a NaN component sort after every NaN-free value and compare equal
/// to one another. NaN-free values are ordered first by magnitude (`hypot`), then
/// by phase angle on `(-pi, pi]`. MATLAB breaks magnitude ties by angle, so
/// `1 < -1` and `3+4i < -3+4i`. A final real-part / imaginary-part comparison
/// keeps the order deterministic for distinct values whose magnitude and angle
/// round to the same `f64`.
fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
    // Magnitude and phase of NaN-free inputs are never NaN (hypot/atan2 return
    // finite or infinite results), so the `Equal` fallback is never taken; it
    // only avoids an `unwrap`.
    fn cmp_non_nan(x: f64, y: f64) -> Ordering {
        x.partial_cmp(&y).unwrap_or(Ordering::Equal)
    }
    // `x + 0.0` maps `-0.0` to `+0.0`. Without it, `atan2(-0.0, -1.0)` is `-pi`
    // while `atan2(0.0, -1.0)` is `pi`. That would split values that
    // `canonicalize_f64` treats as identical, and put `-1 - 0i` outside `(-pi, pi]`.
    fn phase(value: (f64, f64)) -> f64 {
        (value.1 + 0.0).atan2(value.0 + 0.0)
    }

    let a_nan = a.0.is_nan() || a.1.is_nan();
    let b_nan = b.0.is_nan() || b.1.is_nan();
    match (a_nan, b_nan) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => cmp_non_nan(a.0.hypot(a.1), b.0.hypot(b.1))
            .then_with(|| cmp_non_nan(phase(a), phase(b)))
            .then_with(|| cmp_non_nan(a.0, b.0))
            .then_with(|| cmp_non_nan(a.1, b.1)),
    }
}
1290
1291fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1292    for (lhs, rhs) in a.iter().zip(b.iter()) {
1293        let ord = compare_complex(*lhs, *rhs);
1294        if ord != Ordering::Equal {
1295            return ord;
1296        }
1297    }
1298    Ordering::Equal
1299}
1300
/// Lexicographic comparison of two char rows by code point, column by column.
/// Only the overlapping prefix is compared.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1310
/// Lexicographic comparison of two string rows, column by column, using
/// `String`'s ordering. Only the overlapping prefix is compared.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1320
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for the `intersect` builtin: element-wise and row-wise
    //! intersections for numeric, complex, char and string inputs, option
    //! parsing, type resolution, and the GPU gather path.
    //!
    //! Every test that does not need a native-only backend carries the
    //! `wasm_bindgen_test` attribute so it also runs on `wasm32`.

    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ResolveContext, Type};

    /// Extracts the human-readable message from a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Drives the async `evaluate` entry point to completion on the current thread.
    fn evaluate_sync(
        a: Value,
        b: Value,
        rest: &[Value],
    ) -> crate::BuiltinResult<IntersectEvaluation> {
        futures::executor::block_on(evaluate(a, b, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_sorted() {
        let a = Tensor::new(vec![5.0, 7.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![7.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.data, vec![1.0, 7.0]);
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![4.0, 2.0]);
        assert_eq!(ib.data, vec![2.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_type_resolver_numeric() {
        assert_eq!(
            set_values_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_type_resolver_string_array() {
        assert_eq!(
            set_values_output_type(
                &[Type::cell_of(Type::String)],
                &ResolveContext::new(Vec::new()),
            ),
            Type::cell_of(Type::String)
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_stable() {
        let a = Tensor::new(vec![4.0, 2.0, 4.0, 1.0, 3.0], vec![5, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 4.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Stable,
            },
        )
        .expect("intersect");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.data, vec![4.0, 1.0, 3.0]);
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0, 4.0, 5.0]);
        assert_eq!(ib.data, vec![2.0, 4.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_handles_nan() {
        let a = Tensor::new(vec![f64::NAN, 1.0, f64::NAN], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.data.len(), 1);
        assert!(values.data[0].is_nan());
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0]);
        assert_eq!(ib.data, vec![2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_complex_with_real_inputs() {
        let complex =
            ComplexTensor::new(vec![(1.0, 0.0), (2.0, 0.0), (3.0, 1.0)], vec![3, 1]).unwrap();
        let real = Tensor::new(vec![2.0, 4.0, 1.0], vec![3, 1]).unwrap();
        let real_complex = tensor_to_complex_owned("intersect", real).unwrap();
        let eval = intersect_complex(
            complex,
            real_complex,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect complex");
        match eval.values_value() {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(1.0, 0.0), (2.0, 0.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0, 2.0]);
        assert_eq!(ib.data, vec![3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_rows_default() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 5.0, 2.0, 6.0], vec![2, 2]).unwrap();
        let eval = intersect_numeric_rows(
            a,
            b,
            &IntersectOptions {
                rows: true,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect rows");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.shape, vec![1, 2]);
        assert_eq!(values.data, vec![1.0, 2.0]);
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0]);
        assert_eq!(ib.data, vec![1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_char_elements_basic() {
        let a = CharArray::new("cab".chars().collect(), 1, 3).unwrap();
        let b = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
        assert_eq!(find_char_index(&b, 'b'), Some(0));
        assert_eq!(find_char_index(&b, 'c'), Some(1));
        let b_for_eval = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
        let eval = intersect_char_elements(
            a,
            b_for_eval,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect char");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 2);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['b', 'c']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![3.0, 1.0]);
        assert_eq!(ib.data, vec![1.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_string_elements_stable() {
        let a = StringArray::new(
            vec!["apple".into(), "orange".into(), "pear".into()],
            vec![3, 1],
        )
        .unwrap();
        let b = StringArray::new(
            vec!["pear".into(), "grape".into(), "orange".into()],
            vec![3, 1],
        )
        .unwrap();
        let eval = intersect_string_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Stable,
            },
        )
        .expect("intersect string");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![2, 1]);
                assert_eq!(arr.data, vec!["orange".to_string(), "pear".to_string()]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![2.0, 3.0]);
        assert_eq!(ib.data, vec![3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_rejects_legacy_option() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = error_message(
            evaluate_sync(
                Value::Tensor(tensor.clone()),
                Value::Tensor(tensor),
                &[Value::from("legacy")],
            )
            .unwrap_err(),
        );
        assert!(err.contains("legacy"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_rows_dimension_mismatch() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = error_message(
            intersect_numeric_rows(
                a,
                b,
                &IntersectOptions {
                    rows: true,
                    order: IntersectOrder::Sorted,
                },
            )
            .unwrap_err(),
        );
        assert!(err.contains("same number of columns"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_mixed_types_error() {
        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
        let err = error_message(
            intersect_host(
                Value::Tensor(a),
                Value::CharArray(b),
                &IntersectOptions {
                    rows: false,
                    order: IntersectOrder::Sorted,
                },
            )
            .unwrap_err(),
        );
        assert!(err.contains("unsupported input type"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let a = Tensor::new(vec![4.0, 1.0, 2.0, 1.0], vec![4, 1]).unwrap();
            let b = Tensor::new(vec![2.0, 5.0, 1.0], vec![3, 1]).unwrap();
            let view_a = HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let view_b = HostTensorView {
                data: &b.data,
                shape: &b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload A");
            let handle_b = provider.upload(&view_b).expect("upload B");
            let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("intersect");
            let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
            assert_eq!(values.data, vec![1.0, 2.0]);
            let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
            let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
            assert_eq!(ia.data, vec![2.0, 3.0]);
            assert_eq!(ib.data, vec![3.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_two_outputs_from_evaluate() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .unwrap();
        let (_c, ia) = eval.clone().into_pair();
        let ia_tensor = tensor::value_into_tensor_for("intersect", ia).unwrap();
        assert_eq!(ia_tensor.data, vec![1.0, 3.0]);
        let (_c, ia2, ib2) = eval.into_triple();
        let ia_tensor2 = tensor::value_into_tensor_for("intersect", ia2).unwrap();
        let ib_tensor2 = tensor::value_into_tensor_for("intersect", ib2).unwrap();
        assert_eq!(ia_tensor2.data, vec![1.0, 3.0]);
        assert_eq!(ib_tensor2.data, vec![2.0, 1.0]);
    }

    // Compares the provider path against the host implementation; only built
    // when the `wgpu` feature is enabled.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn intersect_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();

        let cpu_eval = intersect_numeric_elements(
            a.clone(),
            b.clone(),
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .unwrap();
        let cpu_values =
            tensor::value_into_tensor_for("intersect", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("intersect", cpu_eval.ia_value()).unwrap();
        let cpu_ib = tensor::value_into_tensor_for("intersect", cpu_eval.ib_value()).unwrap();

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
            .expect("intersect");
        let gpu_values =
            tensor::value_into_tensor_for("intersect", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("intersect", gpu_eval.ia_value()).unwrap();
        let gpu_ib = tensor::value_into_tensor_for("intersect", gpu_eval.ib_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ib.data, cpu_ib.data);
    }
}