Skip to main content

runmat_runtime/builtins/array/sorting_sets/
intersect.rs

1//! MATLAB-compatible `intersect` builtin with GPU-aware semantics for RunMat.
2//!
3//! Supports element-wise and row-wise intersections with optional stable ordering,
//! and index outputs that mirror MathWorks MATLAB semantics. GPU tensors are
//! gathered to host memory before the intersection is computed; although the
//! GPU spec advertises a custom `intersect` provider hook, the dispatch in this
//! module does not currently call it.
7
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::GpuTensorHandle;
12use runmat_builtins::{
13    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
14    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
15    CharArray, ComplexTensor, StringArray, Tensor, Value,
16};
17use runmat_macros::runtime_builtin;
18
19use super::type_resolvers::set_values_output_type;
20use crate::build_runtime_error;
21use crate::builtins::common::arg_tokens::tokens_from_values;
22use crate::builtins::common::gpu_helpers;
23use crate::builtins::common::random_args::complex_tensor_into_value;
24use crate::builtins::common::spec::{
25    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
26    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
27};
28use crate::builtins::common::tensor;
29
// GPU metadata for `intersect`, registered through `register_gpu_spec`.
//
// Declares a `Custom("intersect")` op kind and provider hook, F32/F64
// precisions, no broadcasting, and a gather-immediately residency policy.
// NOTE(review): `evaluate` in this module always gathers GPU operands to the
// host before computing the result; nothing visible here invokes the
// advertised provider hook — confirm whether a provider path is planned.
#[runmat_macros::register_gpu_spec(
    builtin_path = "crate::builtins::array::sorting_sets::intersect"
)]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "intersect",
    op_kind: GpuOpKind::Custom("intersect"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("intersect")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes:
        "Providers may expose a dedicated intersect hook; otherwise tensors are gathered and processed on the host.",
};
48
// Fusion metadata for `intersect`, registered through `register_fusion_spec`.
//
// No elementwise or reduction template is provided (both `None`); per the
// notes string, `intersect` materialises its inputs and ends a fusion chain.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::intersect"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "intersect",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`intersect` materialises its inputs and terminates fusion chains; upstream GPU tensors are gathered when necessary.",
};
61
62const BUILTIN_NAME: &str = "intersect";
63
64const INTERSECT_OUTPUT_C: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
65    name: "C",
66    ty: BuiltinParamType::Any,
67    arity: BuiltinParamArity::Required,
68    default: None,
69    description: "Intersection values or rows.",
70}];
71
72const INTERSECT_OUTPUT_C_IA: [BuiltinParamDescriptor; 2] = [
73    BuiltinParamDescriptor {
74        name: "C",
75        ty: BuiltinParamType::Any,
76        arity: BuiltinParamArity::Required,
77        default: None,
78        description: "Intersection values or rows.",
79    },
80    BuiltinParamDescriptor {
81        name: "ia",
82        ty: BuiltinParamType::NumericArray,
83        arity: BuiltinParamArity::Required,
84        default: None,
85        description: "Indices selecting matching elements/rows in A.",
86    },
87];
88
89const INTERSECT_OUTPUT_C_IA_IB: [BuiltinParamDescriptor; 3] = [
90    BuiltinParamDescriptor {
91        name: "C",
92        ty: BuiltinParamType::Any,
93        arity: BuiltinParamArity::Required,
94        default: None,
95        description: "Intersection values or rows.",
96    },
97    BuiltinParamDescriptor {
98        name: "ia",
99        ty: BuiltinParamType::NumericArray,
100        arity: BuiltinParamArity::Required,
101        default: None,
102        description: "Indices selecting matching elements/rows in A.",
103    },
104    BuiltinParamDescriptor {
105        name: "ib",
106        ty: BuiltinParamType::NumericArray,
107        arity: BuiltinParamArity::Required,
108        default: None,
109        description: "Indices selecting matching elements/rows in B.",
110    },
111];
112
113const INTERSECT_INPUTS_A_B: [BuiltinParamDescriptor; 2] = [
114    BuiltinParamDescriptor {
115        name: "A",
116        ty: BuiltinParamType::Any,
117        arity: BuiltinParamArity::Required,
118        default: None,
119        description: "First input array.",
120    },
121    BuiltinParamDescriptor {
122        name: "B",
123        ty: BuiltinParamType::Any,
124        arity: BuiltinParamArity::Required,
125        default: None,
126        description: "Second input array.",
127    },
128];
129
130const INTERSECT_INPUTS_A_B_OPTIONS: [BuiltinParamDescriptor; 3] = [
131    BuiltinParamDescriptor {
132        name: "A",
133        ty: BuiltinParamType::Any,
134        arity: BuiltinParamArity::Required,
135        default: None,
136        description: "First input array.",
137    },
138    BuiltinParamDescriptor {
139        name: "B",
140        ty: BuiltinParamType::Any,
141        arity: BuiltinParamArity::Required,
142        default: None,
143        description: "Second input array.",
144    },
145    BuiltinParamDescriptor {
146        name: "option",
147        ty: BuiltinParamType::StringScalar,
148        arity: BuiltinParamArity::Variadic,
149        default: None,
150        description: "Option tokens: 'rows'|'sorted'|'stable'.",
151    },
152];
153
154const INTERSECT_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
155    BuiltinSignatureDescriptor {
156        label: "C = intersect(A, B)",
157        inputs: &INTERSECT_INPUTS_A_B,
158        outputs: &INTERSECT_OUTPUT_C,
159    },
160    BuiltinSignatureDescriptor {
161        label: "C = intersect(A, B, option...)",
162        inputs: &INTERSECT_INPUTS_A_B_OPTIONS,
163        outputs: &INTERSECT_OUTPUT_C,
164    },
165    BuiltinSignatureDescriptor {
166        label: "[C, ia] = intersect(A, B)",
167        inputs: &INTERSECT_INPUTS_A_B,
168        outputs: &INTERSECT_OUTPUT_C_IA,
169    },
170    BuiltinSignatureDescriptor {
171        label: "[C, ia] = intersect(A, B, option...)",
172        inputs: &INTERSECT_INPUTS_A_B_OPTIONS,
173        outputs: &INTERSECT_OUTPUT_C_IA,
174    },
175    BuiltinSignatureDescriptor {
176        label: "[C, ia, ib] = intersect(A, B)",
177        inputs: &INTERSECT_INPUTS_A_B,
178        outputs: &INTERSECT_OUTPUT_C_IA_IB,
179    },
180    BuiltinSignatureDescriptor {
181        label: "[C, ia, ib] = intersect(A, B, option...)",
182        inputs: &INTERSECT_INPUTS_A_B_OPTIONS,
183        outputs: &INTERSECT_OUTPUT_C_IA_IB,
184    },
185];
186
/// Raised when the `legacy` or `r2012a` option token is supplied.
const INTERSECT_ERROR_LEGACY_OPTION_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.LEGACY_OPTION_UNSUPPORTED",
    identifier: Some("RunMat:intersect:LegacyOptionUnsupported"),
    when: "Legacy compatibility options are requested.",
    message: "intersect: the 'legacy' behaviour is not supported",
};

/// Raised when both `sorted` and `stable` appear among the option tokens.
const INTERSECT_ERROR_CONFLICTING_ORDER_OPTIONS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.CONFLICTING_ORDER_OPTIONS",
    identifier: Some("RunMat:intersect:ConflictingOrderOptions"),
    when: "Both 'sorted' and 'stable' options are provided.",
    message: "intersect: cannot combine 'sorted' with 'stable'",
};

/// Raised for any other option token. The raise site overrides `message`
/// with one that names the offending token.
const INTERSECT_ERROR_UNKNOWN_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.UNKNOWN_OPTION",
    identifier: Some("RunMat:intersect:UnknownOption"),
    when: "An unsupported option token is provided.",
    message: "intersect: unrecognised option",
};

/// Raised in `'rows'` mode when `A` and `B` have different column counts.
const INTERSECT_ERROR_ROWS_COLUMN_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.ROWS_COLUMN_MISMATCH",
    identifier: Some("RunMat:intersect:RowsColumnMismatch"),
    when: "'rows' mode is used and column counts differ.",
    message: "intersect: inputs must have the same number of columns when using 'rows'",
};

/// Raised when an operand cannot be converted into a numeric tensor on the
/// host fallback path. The conversion helper's message is used as the text.
const INTERSECT_ERROR_UNSUPPORTED_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.UNSUPPORTED_INPUT_TYPE",
    identifier: Some("RunMat:intersect:UnsupportedInputType"),
    when: "Input values cannot be converted into supported intersect domains.",
    message: "intersect: unsupported input type",
};

/// Raised when an option argument cannot be read as text.
const INTERSECT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.INVALID_ARGUMENT",
    identifier: Some("RunMat:intersect:InvalidArgument"),
    when: "Option arguments are not string-like where required.",
    message: "intersect: expected string option arguments",
};

/// Fallback identifier for internal failures. Callers normally supply their
/// own message via `intersect_internal_error`.
const INTERSECT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.INTERNAL",
    identifier: Some("RunMat:intersect:Internal"),
    when: "Internal conversion/allocation/provider decode fails.",
    message: "intersect: internal operation failed",
};

/// Every error descriptor this builtin can raise, published through
/// `INTERSECT_DESCRIPTOR`.
const INTERSECT_ERRORS: [BuiltinErrorDescriptor; 7] = [
    INTERSECT_ERROR_LEGACY_OPTION_UNSUPPORTED,
    INTERSECT_ERROR_CONFLICTING_ORDER_OPTIONS,
    INTERSECT_ERROR_UNKNOWN_OPTION,
    INTERSECT_ERROR_ROWS_COLUMN_MISMATCH,
    INTERSECT_ERROR_UNSUPPORTED_INPUT_TYPE,
    INTERSECT_ERROR_INVALID_ARGUMENT,
    INTERSECT_ERROR_INTERNAL,
];

/// Public descriptor referenced by the `runtime_builtin` registration below.
/// Outputs are shaped by the caller's requested output count.
pub const INTERSECT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &INTERSECT_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &INTERSECT_ERRORS,
};
252
253fn intersect_error_with(
254    error: &'static BuiltinErrorDescriptor,
255    message: impl Into<String>,
256) -> crate::RuntimeError {
257    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
258    if let Some(identifier) = error.identifier {
259        builder = builder.with_identifier(identifier);
260    }
261    builder.build()
262}
263
264fn intersect_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
265    intersect_error_with(error, error.message)
266}
267
268fn intersect_internal_error(message: impl Into<String>) -> crate::RuntimeError {
269    intersect_error_with(&INTERSECT_ERROR_INTERNAL, message)
270}
271
272#[runtime_builtin(
273    name = "intersect",
274    category = "array/sorting_sets",
275    summary = "Return common elements or rows across arrays with index outputs.",
276    keywords = "intersect,set,stable,rows,indices,gpu",
277    accel = "array_construct",
278    sink = true,
279    type_resolver(set_values_output_type),
280    descriptor(crate::builtins::array::sorting_sets::intersect::INTERSECT_DESCRIPTOR),
281    builtin_path = "crate::builtins::array::sorting_sets::intersect"
282)]
283async fn intersect_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
284    let eval = evaluate(a, b, &rest).await?;
285    if let Some(out_count) = crate::output_count::current_output_count() {
286        if out_count == 0 {
287            return Ok(Value::OutputList(Vec::new()));
288        }
289        if out_count == 1 {
290            return Ok(Value::OutputList(vec![eval.into_values_value()]));
291        }
292        if out_count == 2 {
293            let (values, ia) = eval.into_pair();
294            return Ok(Value::OutputList(vec![values, ia]));
295        }
296        let (values, ia, ib) = eval.into_triple();
297        return Ok(crate::output_count::output_list_with_padding(
298            out_count,
299            vec![values, ia, ib],
300        ));
301    }
302    Ok(eval.into_values_value())
303}
304
305/// Evaluate the `intersect` builtin once and expose all outputs.
306pub async fn evaluate(
307    a: Value,
308    b: Value,
309    rest: &[Value],
310) -> crate::BuiltinResult<IntersectEvaluation> {
311    let opts = parse_options(rest)?;
312    match (a, b) {
313        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
314            intersect_gpu_pair(handle_a, handle_b, &opts).await
315        }
316        (Value::GpuTensor(handle_a), other) => {
317            intersect_gpu_mixed(handle_a, other, &opts, true).await
318        }
319        (other, Value::GpuTensor(handle_b)) => {
320            intersect_gpu_mixed(handle_b, other, &opts, false).await
321        }
322        (left, right) => intersect_host(left, right, &opts),
323    }
324}
325
/// Output ordering selected with the `'sorted'` / `'stable'` option.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum IntersectOrder {
    /// Default ordering; also selected explicitly with `'sorted'`.
    Sorted,
    /// Selected with `'stable'` (MATLAB: order of first appearance in `A`).
    Stable,
}

/// Parsed option flags for a single `intersect` call.
#[derive(Clone, Debug)]
struct IntersectOptions {
    /// `true` when `'rows'` was given, so whole rows are compared.
    rows: bool,
    /// Requested output ordering.
    order: IntersectOrder,
}
337
338fn parse_options(rest: &[Value]) -> crate::BuiltinResult<IntersectOptions> {
339    let mut opts = IntersectOptions {
340        rows: false,
341        order: IntersectOrder::Sorted,
342    };
343    let mut seen_order: Option<IntersectOrder> = None;
344
345    let tokens = tokens_from_values(rest);
346    for (arg, token) in rest.iter().zip(tokens.iter()) {
347        let text = match token {
348            crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
349            _ => {
350                let text = tensor::value_to_string(arg)
351                    .ok_or_else(|| intersect_error(&INTERSECT_ERROR_INVALID_ARGUMENT))?;
352                let lowered = text.trim().to_ascii_lowercase();
353                parse_intersect_option(&mut opts, &mut seen_order, &lowered)?;
354                continue;
355            }
356        };
357        parse_intersect_option(&mut opts, &mut seen_order, text)?;
358    }
359
360    Ok(opts)
361}
362
363fn parse_intersect_option(
364    opts: &mut IntersectOptions,
365    seen_order: &mut Option<IntersectOrder>,
366    lowered: &str,
367) -> crate::BuiltinResult<()> {
368    match lowered {
369        "rows" => opts.rows = true,
370        "sorted" => {
371            if let Some(prev) = seen_order {
372                if *prev != IntersectOrder::Sorted {
373                    return Err(intersect_error(&INTERSECT_ERROR_CONFLICTING_ORDER_OPTIONS));
374                }
375            }
376            *seen_order = Some(IntersectOrder::Sorted);
377            opts.order = IntersectOrder::Sorted;
378        }
379        "stable" => {
380            if let Some(prev) = seen_order {
381                if *prev != IntersectOrder::Stable {
382                    return Err(intersect_error(&INTERSECT_ERROR_CONFLICTING_ORDER_OPTIONS));
383                }
384            }
385            *seen_order = Some(IntersectOrder::Stable);
386            opts.order = IntersectOrder::Stable;
387        }
388        "legacy" | "r2012a" => {
389            return Err(intersect_error(&INTERSECT_ERROR_LEGACY_OPTION_UNSUPPORTED));
390        }
391        other => {
392            return Err(intersect_error_with(
393                &INTERSECT_ERROR_UNKNOWN_OPTION,
394                format!("intersect: unrecognised option '{other}'"),
395            ))
396        }
397    }
398    Ok(())
399}
400
401async fn intersect_gpu_pair(
402    handle_a: GpuTensorHandle,
403    handle_b: GpuTensorHandle,
404    opts: &IntersectOptions,
405) -> crate::BuiltinResult<IntersectEvaluation> {
406    let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
407    let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
408    intersect_numeric(tensor_a, tensor_b, opts)
409}
410
411async fn intersect_gpu_mixed(
412    handle_gpu: GpuTensorHandle,
413    other: Value,
414    opts: &IntersectOptions,
415    gpu_is_a: bool,
416) -> crate::BuiltinResult<IntersectEvaluation> {
417    let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
418    let tensor_other = tensor::value_into_tensor_for("intersect", other)
419        .map_err(|e| intersect_internal_error(e))?;
420    if gpu_is_a {
421        intersect_numeric(tensor_gpu, tensor_other, opts)
422    } else {
423        intersect_numeric(tensor_other, tensor_gpu, opts)
424    }
425}
426
427fn intersect_host(
428    a: Value,
429    b: Value,
430    opts: &IntersectOptions,
431) -> crate::BuiltinResult<IntersectEvaluation> {
432    match (a, b) {
433        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => intersect_complex(at, bt, opts),
434        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
435            let bt = scalar_complex_tensor(re, im)?;
436            intersect_complex(at, bt, opts)
437        }
438        (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
439            let at = scalar_complex_tensor(re, im)?;
440            intersect_complex(at, bt, opts)
441        }
442        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
443            let at = scalar_complex_tensor(a_re, a_im)?;
444            let bt = scalar_complex_tensor(b_re, b_im)?;
445            intersect_complex(at, bt, opts)
446        }
447        (Value::ComplexTensor(at), other) => {
448            let bt = value_into_complex_tensor(other)?;
449            intersect_complex(at, bt, opts)
450        }
451        (other, Value::ComplexTensor(bt)) => {
452            let at = value_into_complex_tensor(other)?;
453            intersect_complex(at, bt, opts)
454        }
455        (Value::Complex(re, im), other) => {
456            let at = scalar_complex_tensor(re, im)?;
457            let bt = value_into_complex_tensor(other)?;
458            intersect_complex(at, bt, opts)
459        }
460        (other, Value::Complex(re, im)) => {
461            let at = value_into_complex_tensor(other)?;
462            let bt = scalar_complex_tensor(re, im)?;
463            intersect_complex(at, bt, opts)
464        }
465
466        (Value::CharArray(ac), Value::CharArray(bc)) => intersect_char(ac, bc, opts),
467
468        (Value::StringArray(astring), Value::StringArray(bstring)) => {
469            intersect_string(astring, bstring, opts)
470        }
471        (Value::StringArray(astring), Value::String(b)) => {
472            let bstring = StringArray::new(vec![b], vec![1, 1])
473                .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
474            intersect_string(astring, bstring, opts)
475        }
476        (Value::String(a), Value::StringArray(bstring)) => {
477            let astring = StringArray::new(vec![a], vec![1, 1])
478                .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
479            intersect_string(astring, bstring, opts)
480        }
481        (Value::String(a), Value::String(b)) => {
482            let astring = StringArray::new(vec![a], vec![1, 1])
483                .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
484            let bstring = StringArray::new(vec![b], vec![1, 1])
485                .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
486            intersect_string(astring, bstring, opts)
487        }
488
489        (left, right) => {
490            let tensor_a = tensor::value_into_tensor_for("intersect", left)
491                .map_err(|e| intersect_error_with(&INTERSECT_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
492            let tensor_b = tensor::value_into_tensor_for("intersect", right)
493                .map_err(|e| intersect_error_with(&INTERSECT_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
494            intersect_numeric(tensor_a, tensor_b, opts)
495        }
496    }
497}
498
499fn intersect_numeric(
500    a: Tensor,
501    b: Tensor,
502    opts: &IntersectOptions,
503) -> crate::BuiltinResult<IntersectEvaluation> {
504    if opts.rows {
505        intersect_numeric_rows(a, b, opts)
506    } else {
507        intersect_numeric_elements(a, b, opts)
508    }
509}
510
511fn intersect_numeric_elements(
512    a: Tensor,
513    b: Tensor,
514    opts: &IntersectOptions,
515) -> crate::BuiltinResult<IntersectEvaluation> {
516    let mut b_map: HashMap<u64, usize> = HashMap::new();
517    for (idx, &value) in b.data.iter().enumerate() {
518        let key = canonicalize_f64(value);
519        b_map.entry(key).or_insert(idx);
520    }
521
522    let mut seen: HashSet<u64> = HashSet::new();
523    let mut entries = Vec::<NumericIntersectEntry>::new();
524    let mut order_counter = 0usize;
525
526    for (idx, &value) in a.data.iter().enumerate() {
527        let key = canonicalize_f64(value);
528        if seen.contains(&key) {
529            continue;
530        }
531        if let Some(&b_idx) = b_map.get(&key) {
532            entries.push(NumericIntersectEntry {
533                value,
534                a_index: idx,
535                b_index: b_idx,
536                order_rank: order_counter,
537            });
538            seen.insert(key);
539            order_counter += 1;
540        }
541    }
542
543    assemble_numeric_intersect(entries, opts)
544}
545
546fn intersect_numeric_rows(
547    a: Tensor,
548    b: Tensor,
549    opts: &IntersectOptions,
550) -> crate::BuiltinResult<IntersectEvaluation> {
551    if a.shape.len() != 2 || b.shape.len() != 2 {
552        return Err(intersect_internal_error(
553            "intersect: 'rows' option requires 2-D numeric matrices",
554        ));
555    }
556    if a.shape[1] != b.shape[1] {
557        return Err(intersect_error(&INTERSECT_ERROR_ROWS_COLUMN_MISMATCH));
558    }
559    let rows_a = a.shape[0];
560    let cols = a.shape[1];
561    let rows_b = b.shape[0];
562
563    let mut b_map: HashMap<NumericRowKey, usize> = HashMap::new();
564    for r in 0..rows_b {
565        let mut row_values = Vec::with_capacity(cols);
566        for c in 0..cols {
567            let idx = r + c * rows_b;
568            row_values.push(b.data[idx]);
569        }
570        let key = NumericRowKey::from_slice(&row_values);
571        b_map.entry(key).or_insert(r);
572    }
573
574    let mut seen: HashSet<NumericRowKey> = HashSet::new();
575    let mut entries = Vec::<NumericRowIntersectEntry>::new();
576    let mut order_counter = 0usize;
577
578    for r in 0..rows_a {
579        let mut row_values = Vec::with_capacity(cols);
580        for c in 0..cols {
581            let idx = r + c * rows_a;
582            row_values.push(a.data[idx]);
583        }
584        let key = NumericRowKey::from_slice(&row_values);
585        if seen.contains(&key) {
586            continue;
587        }
588        if let Some(&b_row) = b_map.get(&key) {
589            entries.push(NumericRowIntersectEntry {
590                row_data: row_values,
591                a_row: r,
592                b_row,
593                order_rank: order_counter,
594            });
595            seen.insert(key);
596            order_counter += 1;
597        }
598    }
599
600    assemble_numeric_row_intersect(entries, opts, cols)
601}
602
603fn intersect_complex(
604    a: ComplexTensor,
605    b: ComplexTensor,
606    opts: &IntersectOptions,
607) -> crate::BuiltinResult<IntersectEvaluation> {
608    if opts.rows {
609        intersect_complex_rows(a, b, opts)
610    } else {
611        intersect_complex_elements(a, b, opts)
612    }
613}
614
615fn intersect_complex_elements(
616    a: ComplexTensor,
617    b: ComplexTensor,
618    opts: &IntersectOptions,
619) -> crate::BuiltinResult<IntersectEvaluation> {
620    let mut b_map: HashMap<ComplexKey, usize> = HashMap::new();
621    for (idx, &value) in b.data.iter().enumerate() {
622        let key = ComplexKey::new(value);
623        b_map.entry(key).or_insert(idx);
624    }
625
626    let mut seen: HashSet<ComplexKey> = HashSet::new();
627    let mut entries = Vec::<ComplexIntersectEntry>::new();
628    let mut order_counter = 0usize;
629
630    for (idx, &value) in a.data.iter().enumerate() {
631        let key = ComplexKey::new(value);
632        if seen.contains(&key) {
633            continue;
634        }
635        if let Some(&b_idx) = b_map.get(&key) {
636            entries.push(ComplexIntersectEntry {
637                value,
638                a_index: idx,
639                b_index: b_idx,
640                order_rank: order_counter,
641            });
642            seen.insert(key);
643            order_counter += 1;
644        }
645    }
646
647    assemble_complex_intersect(entries, opts)
648}
649
650fn intersect_complex_rows(
651    a: ComplexTensor,
652    b: ComplexTensor,
653    opts: &IntersectOptions,
654) -> crate::BuiltinResult<IntersectEvaluation> {
655    if a.shape.len() != 2 || b.shape.len() != 2 {
656        return Err(intersect_internal_error(
657            "intersect: 'rows' option requires 2-D complex matrices",
658        ));
659    }
660    if a.shape[1] != b.shape[1] {
661        return Err(intersect_error(&INTERSECT_ERROR_ROWS_COLUMN_MISMATCH));
662    }
663    let rows_a = a.shape[0];
664    let cols = a.shape[1];
665    let rows_b = b.shape[0];
666
667    let mut b_map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
668    for r in 0..rows_b {
669        let mut row_keys = Vec::with_capacity(cols);
670        for c in 0..cols {
671            let idx = r + c * rows_b;
672            row_keys.push(ComplexKey::new(b.data[idx]));
673        }
674        b_map.entry(row_keys).or_insert(r);
675    }
676
677    let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
678    let mut entries = Vec::<ComplexRowIntersectEntry>::new();
679    let mut order_counter = 0usize;
680
681    for r in 0..rows_a {
682        let mut row_values = Vec::with_capacity(cols);
683        let mut row_keys = Vec::with_capacity(cols);
684        for c in 0..cols {
685            let idx = r + c * rows_a;
686            let value = a.data[idx];
687            row_values.push(value);
688            row_keys.push(ComplexKey::new(value));
689        }
690        if seen.contains(&row_keys) {
691            continue;
692        }
693        if let Some(&b_row) = b_map.get(&row_keys) {
694            entries.push(ComplexRowIntersectEntry {
695                row_data: row_values,
696                a_row: r,
697                b_row,
698                order_rank: order_counter,
699            });
700            seen.insert(row_keys);
701            order_counter += 1;
702        }
703    }
704
705    assemble_complex_row_intersect(entries, opts, cols)
706}
707
708fn intersect_char(
709    a: CharArray,
710    b: CharArray,
711    opts: &IntersectOptions,
712) -> crate::BuiltinResult<IntersectEvaluation> {
713    if opts.rows {
714        intersect_char_rows(a, b, opts)
715    } else {
716        intersect_char_elements(a, b, opts)
717    }
718}
719
720fn intersect_char_elements(
721    a: CharArray,
722    b: CharArray,
723    opts: &IntersectOptions,
724) -> crate::BuiltinResult<IntersectEvaluation> {
725    let mut seen: HashSet<u32> = HashSet::new();
726    let mut entries = Vec::<CharIntersectEntry>::new();
727    let mut order_counter = 0usize;
728
729    for col in 0..a.cols {
730        for row in 0..a.rows {
731            let linear_idx = row + col * a.rows;
732            let data_idx = row * a.cols + col;
733            let ch = a.data[data_idx];
734            let key = ch as u32;
735            if seen.contains(&key) {
736                continue;
737            }
738            if let Some(b_idx) = find_char_index(&b, ch) {
739                entries.push(CharIntersectEntry {
740                    ch,
741                    a_index: linear_idx,
742                    b_index: b_idx,
743                    order_rank: order_counter,
744                });
745                seen.insert(key);
746                order_counter += 1;
747            }
748        }
749    }
750
751    assemble_char_intersect(entries, opts, &b)
752}
753
754fn intersect_char_rows(
755    a: CharArray,
756    b: CharArray,
757    opts: &IntersectOptions,
758) -> crate::BuiltinResult<IntersectEvaluation> {
759    if a.cols != b.cols {
760        return Err(intersect_error(&INTERSECT_ERROR_ROWS_COLUMN_MISMATCH));
761    }
762    let rows_a = a.rows;
763    let rows_b = b.rows;
764    let cols = a.cols;
765
766    let mut b_map: HashMap<RowCharKey, usize> = HashMap::new();
767    for r in 0..rows_b {
768        let mut row_values = Vec::with_capacity(cols);
769        for c in 0..cols {
770            let idx = r * cols + c;
771            row_values.push(b.data[idx]);
772        }
773        let key = RowCharKey::from_slice(&row_values);
774        b_map.entry(key).or_insert(r);
775    }
776
777    let mut seen: HashSet<RowCharKey> = HashSet::new();
778    let mut entries = Vec::<CharRowIntersectEntry>::new();
779    let mut order_counter = 0usize;
780
781    for r in 0..rows_a {
782        let mut row_values = Vec::with_capacity(cols);
783        for c in 0..cols {
784            let idx = r * cols + c;
785            row_values.push(a.data[idx]);
786        }
787        let key = RowCharKey::from_slice(&row_values);
788        if seen.contains(&key) {
789            continue;
790        }
791        if let Some(&b_row) = b_map.get(&key) {
792            entries.push(CharRowIntersectEntry {
793                row_data: row_values,
794                a_row: r,
795                b_row,
796                order_rank: order_counter,
797            });
798            seen.insert(key);
799            order_counter += 1;
800        }
801    }
802
803    assemble_char_row_intersect(entries, opts, cols)
804}
805
806fn find_char_index(array: &CharArray, target: char) -> Option<usize> {
807    for col in 0..array.cols {
808        for row in 0..array.rows {
809            let data_idx = row * array.cols + col;
810            if array.data[data_idx] == target {
811                return Some(row + col * array.rows);
812            }
813        }
814    }
815    None
816}
817
818fn intersect_string(
819    a: StringArray,
820    b: StringArray,
821    opts: &IntersectOptions,
822) -> crate::BuiltinResult<IntersectEvaluation> {
823    if opts.rows {
824        intersect_string_rows(a, b, opts)
825    } else {
826        intersect_string_elements(a, b, opts)
827    }
828}
829
830fn intersect_string_elements(
831    a: StringArray,
832    b: StringArray,
833    opts: &IntersectOptions,
834) -> crate::BuiltinResult<IntersectEvaluation> {
835    let mut b_map: HashMap<String, usize> = HashMap::new();
836    for (idx, value) in b.data.iter().enumerate() {
837        b_map.entry(value.clone()).or_insert(idx);
838    }
839
840    let mut seen: HashSet<String> = HashSet::new();
841    let mut entries = Vec::<StringIntersectEntry>::new();
842    let mut order_counter = 0usize;
843
844    for (idx, value) in a.data.iter().enumerate() {
845        if seen.contains(value) {
846            continue;
847        }
848        if let Some(&b_idx) = b_map.get(value) {
849            entries.push(StringIntersectEntry {
850                value: value.clone(),
851                a_index: idx,
852                b_index: b_idx,
853                order_rank: order_counter,
854            });
855            seen.insert(value.clone());
856            order_counter += 1;
857        }
858    }
859
860    assemble_string_intersect(entries, opts)
861}
862
863fn intersect_string_rows(
864    a: StringArray,
865    b: StringArray,
866    opts: &IntersectOptions,
867) -> crate::BuiltinResult<IntersectEvaluation> {
868    if a.shape.len() != 2 || b.shape.len() != 2 {
869        return Err(intersect_internal_error(
870            "intersect: 'rows' option requires 2-D string arrays",
871        ));
872    }
873    if a.shape[1] != b.shape[1] {
874        return Err(intersect_error(&INTERSECT_ERROR_ROWS_COLUMN_MISMATCH));
875    }
876    let rows_a = a.shape[0];
877    let cols = a.shape[1];
878    let rows_b = b.shape[0];
879
880    let mut b_map: HashMap<RowStringKey, usize> = HashMap::new();
881    for r in 0..rows_b {
882        let mut row_values = Vec::with_capacity(cols);
883        for c in 0..cols {
884            let idx = r + c * rows_b;
885            row_values.push(b.data[idx].clone());
886        }
887        let key = RowStringKey::from_slice(&row_values);
888        b_map.entry(key).or_insert(r);
889    }
890
891    let mut seen: HashSet<RowStringKey> = HashSet::new();
892    let mut entries = Vec::<StringRowIntersectEntry>::new();
893    let mut order_counter = 0usize;
894
895    for r in 0..rows_a {
896        let mut row_values = Vec::with_capacity(cols);
897        for c in 0..cols {
898            let idx = r + c * rows_a;
899            row_values.push(a.data[idx].clone());
900        }
901        let key = RowStringKey::from_slice(&row_values);
902        if seen.contains(&key) {
903            continue;
904        }
905        if let Some(&b_row) = b_map.get(&key) {
906            entries.push(StringRowIntersectEntry {
907                row_data: row_values,
908                a_row: r,
909                b_row,
910                order_rank: order_counter,
911            });
912            seen.insert(key);
913            order_counter += 1;
914        }
915    }
916
917    assemble_string_row_intersect(entries, opts, cols)
918}
919
920#[derive(Debug, Clone)]
921pub struct IntersectEvaluation {
922    values: Value,
923    ia: Tensor,
924    ib: Tensor,
925}
926
927impl IntersectEvaluation {
928    fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
929        Self { values, ia, ib }
930    }
931
932    pub fn into_values_value(self) -> Value {
933        self.values
934    }
935
936    pub fn into_pair(self) -> (Value, Value) {
937        let ia = tensor::tensor_into_value(self.ia);
938        (self.values, ia)
939    }
940
941    pub fn into_triple(self) -> (Value, Value, Value) {
942        let ia = tensor::tensor_into_value(self.ia);
943        let ib = tensor::tensor_into_value(self.ib);
944        (self.values, ia, ib)
945    }
946
947    pub fn values_value(&self) -> Value {
948        self.values.clone()
949    }
950
951    pub fn ia_value(&self) -> Value {
952        tensor::tensor_into_value(self.ia.clone())
953    }
954
955    pub fn ib_value(&self) -> Value {
956        tensor::tensor_into_value(self.ib.clone())
957    }
958}
959
/// One distinct real value found in both inputs (element-wise mode).
#[derive(Debug)]
struct NumericIntersectEntry {
    /// The shared value, copied into the values output `C`.
    value: f64,
    /// Zero-based linear index into the first input; made one-based for `ia`.
    a_index: usize,
    /// Zero-based linear index into the second input; made one-based for `ib`.
    b_index: usize,
    /// Sort key used when the `'stable'` order is requested (ascending).
    order_rank: usize,
}

/// One distinct real row found in both inputs (`'rows'` mode).
#[derive(Debug)]
struct NumericRowIntersectEntry {
    /// The shared row; the first `cols` elements are written to `C`.
    row_data: Vec<f64>,
    /// Zero-based row number in the first input; made one-based for `ia`.
    a_row: usize,
    /// Zero-based row number in the second input; made one-based for `ib`.
    b_row: usize,
    /// Sort key used when the `'stable'` order is requested (ascending).
    order_rank: usize,
}
975
/// One distinct complex value found in both inputs (element-wise mode).
#[derive(Debug)]
struct ComplexIntersectEntry {
    /// The shared value as `(re, im)`, copied into the values output `C`.
    value: (f64, f64),
    /// Zero-based linear index into the first input; made one-based for `ia`.
    a_index: usize,
    /// Zero-based linear index into the second input; made one-based for `ib`.
    b_index: usize,
    /// Sort key used when the `'stable'` order is requested (ascending).
    order_rank: usize,
}

/// One distinct complex row found in both inputs (`'rows'` mode).
#[derive(Debug)]
struct ComplexRowIntersectEntry {
    /// The shared row as `(re, im)` pairs; the first `cols` are written to `C`.
    row_data: Vec<(f64, f64)>,
    /// Zero-based row number in the first input; made one-based for `ia`.
    a_row: usize,
    /// Zero-based row number in the second input; made one-based for `ib`.
    b_row: usize,
    /// Sort key used when the `'stable'` order is requested (ascending).
    order_rank: usize,
}
991
/// One distinct character found in both inputs (element-wise mode).
#[derive(Debug)]
struct CharIntersectEntry {
    /// The shared character, copied into the values output `C`.
    ch: char,
    /// Zero-based linear index into the first input; made one-based for `ia`.
    a_index: usize,
    /// Zero-based index into the second input. `assemble_char_intersect` only
    /// uses it as a fallback when `find_char_index` finds no position.
    b_index: usize,
    /// Sort key used when the `'stable'` order is requested (ascending).
    order_rank: usize,
}

/// One distinct character row found in both inputs (`'rows'` mode).
#[derive(Debug)]
struct CharRowIntersectEntry {
    /// The shared row; the first `cols` characters are written to `C`.
    row_data: Vec<char>,
    /// Zero-based row number in the first input; made one-based for `ia`.
    a_row: usize,
    /// Zero-based row number in the second input; made one-based for `ib`.
    b_row: usize,
    /// Sort key used when the `'stable'` order is requested (ascending).
    order_rank: usize,
}
1007
/// One distinct string found in both inputs (element-wise mode).
///
/// Built by `intersect_string_elements`, which records the first occurrence in
/// each input.
#[derive(Debug)]
struct StringIntersectEntry {
    /// The shared string, copied into the values output `C`.
    value: String,
    /// Zero-based index of the first occurrence in the first input.
    a_index: usize,
    /// Zero-based index of the first occurrence in the second input.
    b_index: usize,
    /// Discovery order while scanning the first input; the `'stable'` sort key.
    order_rank: usize,
}

/// One distinct string row found in both inputs (`'rows'` mode).
///
/// Built by `intersect_string_rows`, which records the first matching row in
/// each input.
#[derive(Debug)]
struct StringRowIntersectEntry {
    /// The shared row, one string per column.
    row_data: Vec<String>,
    /// Zero-based row number of the first occurrence in the first input.
    a_row: usize,
    /// Zero-based row number of the first occurrence in the second input.
    b_row: usize,
    /// Discovery order while scanning the first input; the `'stable'` sort key.
    order_rank: usize,
}
1023
1024fn assemble_numeric_intersect(
1025    entries: Vec<NumericIntersectEntry>,
1026    opts: &IntersectOptions,
1027) -> crate::BuiltinResult<IntersectEvaluation> {
1028    let mut order: Vec<usize> = (0..entries.len()).collect();
1029    match opts.order {
1030        IntersectOrder::Sorted => {
1031            order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
1032        }
1033        IntersectOrder::Stable => {
1034            order.sort_by_key(|&idx| entries[idx].order_rank);
1035        }
1036    }
1037
1038    let mut values = Vec::with_capacity(order.len());
1039    let mut ia = Vec::with_capacity(order.len());
1040    let mut ib = Vec::with_capacity(order.len());
1041    for &idx in &order {
1042        let entry = &entries[idx];
1043        values.push(entry.value);
1044        ia.push((entry.a_index + 1) as f64);
1045        ib.push((entry.b_index + 1) as f64);
1046    }
1047
1048    let value_tensor = Tensor::new(values, vec![order.len(), 1])
1049        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1050    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1051        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1052    let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1053        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1054
1055    Ok(IntersectEvaluation::new(
1056        tensor::tensor_into_value(value_tensor),
1057        ia_tensor,
1058        ib_tensor,
1059    ))
1060}
1061
1062fn assemble_numeric_row_intersect(
1063    entries: Vec<NumericRowIntersectEntry>,
1064    opts: &IntersectOptions,
1065    cols: usize,
1066) -> crate::BuiltinResult<IntersectEvaluation> {
1067    let mut order: Vec<usize> = (0..entries.len()).collect();
1068    match opts.order {
1069        IntersectOrder::Sorted => {
1070            order.sort_by(|&lhs, &rhs| {
1071                compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1072            });
1073        }
1074        IntersectOrder::Stable => {
1075            order.sort_by_key(|&idx| entries[idx].order_rank);
1076        }
1077    }
1078
1079    let rows_out = order.len();
1080    let mut values = vec![0.0f64; rows_out * cols];
1081    let mut ia = Vec::with_capacity(rows_out);
1082    let mut ib = Vec::with_capacity(rows_out);
1083
1084    for (row_pos, &entry_idx) in order.iter().enumerate() {
1085        let entry = &entries[entry_idx];
1086        for col in 0..cols {
1087            let dest = row_pos + col * rows_out;
1088            values[dest] = entry.row_data[col];
1089        }
1090        ia.push((entry.a_row + 1) as f64);
1091        ib.push((entry.b_row + 1) as f64);
1092    }
1093
1094    let value_tensor = Tensor::new(values, vec![rows_out, cols])
1095        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1096    let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1097        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1098    let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1099        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1100
1101    Ok(IntersectEvaluation::new(
1102        tensor::tensor_into_value(value_tensor),
1103        ia_tensor,
1104        ib_tensor,
1105    ))
1106}
1107
1108fn assemble_complex_intersect(
1109    entries: Vec<ComplexIntersectEntry>,
1110    opts: &IntersectOptions,
1111) -> crate::BuiltinResult<IntersectEvaluation> {
1112    let mut order: Vec<usize> = (0..entries.len()).collect();
1113    match opts.order {
1114        IntersectOrder::Sorted => {
1115            order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
1116        }
1117        IntersectOrder::Stable => {
1118            order.sort_by_key(|&idx| entries[idx].order_rank);
1119        }
1120    }
1121
1122    let mut values = Vec::with_capacity(order.len());
1123    let mut ia = Vec::with_capacity(order.len());
1124    let mut ib = Vec::with_capacity(order.len());
1125    for &idx in &order {
1126        let entry = &entries[idx];
1127        values.push(entry.value);
1128        ia.push((entry.a_index + 1) as f64);
1129        ib.push((entry.b_index + 1) as f64);
1130    }
1131
1132    let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
1133        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1134    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1135        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1136    let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1137        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1138
1139    Ok(IntersectEvaluation::new(
1140        complex_tensor_into_value(value_tensor),
1141        ia_tensor,
1142        ib_tensor,
1143    ))
1144}
1145
1146fn assemble_complex_row_intersect(
1147    entries: Vec<ComplexRowIntersectEntry>,
1148    opts: &IntersectOptions,
1149    cols: usize,
1150) -> crate::BuiltinResult<IntersectEvaluation> {
1151    let mut order: Vec<usize> = (0..entries.len()).collect();
1152    match opts.order {
1153        IntersectOrder::Sorted => {
1154            order.sort_by(|&lhs, &rhs| {
1155                compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1156            });
1157        }
1158        IntersectOrder::Stable => {
1159            order.sort_by_key(|&idx| entries[idx].order_rank);
1160        }
1161    }
1162
1163    let rows_out = order.len();
1164    let mut values = vec![(0.0f64, 0.0f64); rows_out * cols];
1165    let mut ia = Vec::with_capacity(rows_out);
1166    let mut ib = Vec::with_capacity(rows_out);
1167
1168    for (row_pos, &entry_idx) in order.iter().enumerate() {
1169        let entry = &entries[entry_idx];
1170        for col in 0..cols {
1171            let dest = row_pos + col * rows_out;
1172            values[dest] = entry.row_data[col];
1173        }
1174        ia.push((entry.a_row + 1) as f64);
1175        ib.push((entry.b_row + 1) as f64);
1176    }
1177
1178    let value_tensor = ComplexTensor::new(values, vec![rows_out, cols])
1179        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1180    let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1181        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1182    let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1183        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1184
1185    Ok(IntersectEvaluation::new(
1186        complex_tensor_into_value(value_tensor),
1187        ia_tensor,
1188        ib_tensor,
1189    ))
1190}
1191
1192fn assemble_char_intersect(
1193    entries: Vec<CharIntersectEntry>,
1194    opts: &IntersectOptions,
1195    b: &CharArray,
1196) -> crate::BuiltinResult<IntersectEvaluation> {
1197    let mut order: Vec<usize> = (0..entries.len()).collect();
1198    match opts.order {
1199        IntersectOrder::Sorted => {
1200            order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1201        }
1202        IntersectOrder::Stable => {
1203            order.sort_by_key(|&idx| entries[idx].order_rank);
1204        }
1205    }
1206
1207    let mut values = Vec::with_capacity(order.len());
1208    let mut ia = Vec::with_capacity(order.len());
1209    let mut ib = Vec::with_capacity(order.len());
1210    for &idx in &order {
1211        let entry = &entries[idx];
1212        values.push(entry.ch);
1213        ia.push((entry.a_index + 1) as f64);
1214        let b_idx = find_char_index(b, entry.ch).unwrap_or(entry.b_index);
1215        ib.push((b_idx + 1) as f64);
1216    }
1217
1218    let value_array = CharArray::new(values, order.len(), 1)
1219        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1220    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1221        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1222    let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1223        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1224
1225    Ok(IntersectEvaluation::new(
1226        Value::CharArray(value_array),
1227        ia_tensor,
1228        ib_tensor,
1229    ))
1230}
1231
1232fn assemble_char_row_intersect(
1233    entries: Vec<CharRowIntersectEntry>,
1234    opts: &IntersectOptions,
1235    cols: usize,
1236) -> crate::BuiltinResult<IntersectEvaluation> {
1237    let mut order: Vec<usize> = (0..entries.len()).collect();
1238    match opts.order {
1239        IntersectOrder::Sorted => {
1240            order.sort_by(|&lhs, &rhs| {
1241                compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1242            });
1243        }
1244        IntersectOrder::Stable => {
1245            order.sort_by_key(|&idx| entries[idx].order_rank);
1246        }
1247    }
1248
1249    let rows_out = order.len();
1250    let mut values = vec!['\0'; rows_out * cols];
1251    let mut ia = Vec::with_capacity(rows_out);
1252    let mut ib = Vec::with_capacity(rows_out);
1253
1254    for (row_pos, &entry_idx) in order.iter().enumerate() {
1255        let entry = &entries[entry_idx];
1256        for col in 0..cols {
1257            let dest = row_pos * cols + col;
1258            values[dest] = entry.row_data[col];
1259        }
1260        ia.push((entry.a_row + 1) as f64);
1261        ib.push((entry.b_row + 1) as f64);
1262    }
1263
1264    let value_array = CharArray::new(values, rows_out, cols)
1265        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1266    let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1267        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1268    let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1269        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1270
1271    Ok(IntersectEvaluation::new(
1272        Value::CharArray(value_array),
1273        ia_tensor,
1274        ib_tensor,
1275    ))
1276}
1277
1278fn assemble_string_intersect(
1279    entries: Vec<StringIntersectEntry>,
1280    opts: &IntersectOptions,
1281) -> crate::BuiltinResult<IntersectEvaluation> {
1282    let mut order: Vec<usize> = (0..entries.len()).collect();
1283    match opts.order {
1284        IntersectOrder::Sorted => {
1285            order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1286        }
1287        IntersectOrder::Stable => {
1288            order.sort_by_key(|&idx| entries[idx].order_rank);
1289        }
1290    }
1291
1292    let mut values = Vec::with_capacity(order.len());
1293    let mut ia = Vec::with_capacity(order.len());
1294    let mut ib = Vec::with_capacity(order.len());
1295    for &idx in &order {
1296        let entry = &entries[idx];
1297        values.push(entry.value.clone());
1298        ia.push((entry.a_index + 1) as f64);
1299        ib.push((entry.b_index + 1) as f64);
1300    }
1301
1302    let value_array = StringArray::new(values, vec![order.len(), 1])
1303        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1304    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1305        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1306    let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1307        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1308
1309    Ok(IntersectEvaluation::new(
1310        Value::StringArray(value_array),
1311        ia_tensor,
1312        ib_tensor,
1313    ))
1314}
1315
1316fn assemble_string_row_intersect(
1317    entries: Vec<StringRowIntersectEntry>,
1318    opts: &IntersectOptions,
1319    cols: usize,
1320) -> crate::BuiltinResult<IntersectEvaluation> {
1321    let mut order: Vec<usize> = (0..entries.len()).collect();
1322    match opts.order {
1323        IntersectOrder::Sorted => {
1324            order.sort_by(|&lhs, &rhs| {
1325                compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1326            });
1327        }
1328        IntersectOrder::Stable => {
1329            order.sort_by_key(|&idx| entries[idx].order_rank);
1330        }
1331    }
1332
1333    let rows_out = order.len();
1334    let mut values = vec![String::new(); rows_out * cols];
1335    let mut ia = Vec::with_capacity(rows_out);
1336    let mut ib = Vec::with_capacity(rows_out);
1337
1338    for (row_pos, &entry_idx) in order.iter().enumerate() {
1339        let entry = &entries[entry_idx];
1340        for col in 0..cols {
1341            let dest = row_pos + col * rows_out;
1342            values[dest] = entry.row_data[col].clone();
1343        }
1344        ia.push((entry.a_row + 1) as f64);
1345        ib.push((entry.b_row + 1) as f64);
1346    }
1347
1348    let value_array = StringArray::new(values, vec![rows_out, cols])
1349        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1350    let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1351        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1352    let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1353        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1354
1355    Ok(IntersectEvaluation::new(
1356        Value::StringArray(value_array),
1357        ia_tensor,
1358        ib_tensor,
1359    ))
1360}
1361
1362#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1363struct NumericRowKey(Vec<u64>);
1364
1365impl NumericRowKey {
1366    fn from_slice(values: &[f64]) -> Self {
1367        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1368    }
1369}
1370
1371#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1372struct ComplexKey {
1373    re: u64,
1374    im: u64,
1375}
1376
1377impl ComplexKey {
1378    fn new(value: (f64, f64)) -> Self {
1379        Self {
1380            re: canonicalize_f64(value.0),
1381            im: canonicalize_f64(value.1),
1382        }
1383    }
1384}
1385
/// Hashable identity of a character row, stored as Unicode scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key by widening each `char` to its `u32` code point.
    fn from_slice(values: &[char]) -> Self {
        Self(values.iter().map(|&ch| u32::from(ch)).collect())
    }
}
1394
/// Hashable identity of a string row (an owned copy of its cells).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);

impl RowStringKey {
    /// Builds the key by copying every cell of `values`.
    fn from_slice(values: &[String]) -> Self {
        let owned: Vec<String> = values.to_owned();
        RowStringKey(owned)
    }
}
1403
1404fn scalar_complex_tensor(re: f64, im: f64) -> crate::BuiltinResult<ComplexTensor> {
1405    ComplexTensor::new(vec![(re, im)], vec![1, 1])
1406        .map_err(|e| intersect_internal_error(format!("intersect: {e}")))
1407}
1408
1409fn tensor_to_complex_owned(name: &str, tensor: Tensor) -> crate::BuiltinResult<ComplexTensor> {
1410    let Tensor { data, shape, .. } = tensor;
1411    let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
1412    ComplexTensor::new(complex, shape).map_err(|e| intersect_internal_error(format!("{name}: {e}")))
1413}
1414
1415fn value_into_complex_tensor(value: Value) -> crate::BuiltinResult<ComplexTensor> {
1416    match value {
1417        Value::ComplexTensor(tensor) => Ok(tensor),
1418        Value::Complex(re, im) => scalar_complex_tensor(re, im),
1419        other => {
1420            let tensor = tensor::value_into_tensor_for("intersect", other)
1421                .map_err(|e| intersect_internal_error(e))?;
1422            tensor_to_complex_owned("intersect", tensor)
1423        }
1424    }
1425}
1426
/// Maps an `f64` to hashable bits where "equal for set purposes" means "same bits".
///
/// Every NaN (any sign or payload) becomes the quiet-NaN pattern
/// `0x7ff8_0000_0000_0000`. Both `+0.0` and `-0.0` become `0`. All other values
/// keep their raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1436
/// Total order on `f64` for sorting set outputs.
///
/// Non-NaN values compare numerically, so `-0.0 == 0.0`. Every NaN compares
/// equal to every other NaN and greater than any non-NaN value, so NaNs sort
/// last in ascending order.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1450
1451fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1452    for (lhs, rhs) in a.iter().zip(b.iter()) {
1453        let ord = compare_f64(*lhs, *rhs);
1454        if ord != Ordering::Equal {
1455            return ord;
1456        }
1457    }
1458    Ordering::Equal
1459}
1460
/// True when either component of the `(re, im)` pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
1464
1465fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1466    match (complex_is_nan(a), complex_is_nan(b)) {
1467        (true, true) => Ordering::Equal,
1468        (true, false) => Ordering::Greater,
1469        (false, true) => Ordering::Less,
1470        (false, false) => {
1471            let mag_a = a.0.hypot(a.1);
1472            let mag_b = b.0.hypot(b.1);
1473            let mag_cmp = compare_f64(mag_a, mag_b);
1474            if mag_cmp != Ordering::Equal {
1475                return mag_cmp;
1476            }
1477            let re_cmp = compare_f64(a.0, b.0);
1478            if re_cmp != Ordering::Equal {
1479                return re_cmp;
1480            }
1481            compare_f64(a.1, b.1)
1482        }
1483    }
1484}
1485
1486fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1487    for (lhs, rhs) in a.iter().zip(b.iter()) {
1488        let ord = compare_complex(*lhs, *rhs);
1489        if ord != Ordering::Equal {
1490            return ord;
1491        }
1492    }
1493    Ordering::Equal
1494}
1495
/// Lexicographic comparison of two character rows by code point.
///
/// Only the overlapping prefix is compared; rows that agree on it are `Equal`
/// regardless of length.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1505
/// Lexicographic comparison of two string rows using `String`'s `Ord` per cell.
///
/// Only the overlapping prefix is compared; rows that agree on it are `Equal`
/// regardless of length.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1515
1516#[cfg(test)]
1517pub(crate) mod tests {
1518    use super::*;
1519    use crate::builtins::common::test_support;
1520    use runmat_accelerate_api::HostTensorView;
1521    use runmat_builtins::{ResolveContext, Type};
1522
1523    fn evaluate_sync(
1524        a: Value,
1525        b: Value,
1526        rest: &[Value],
1527    ) -> crate::BuiltinResult<IntersectEvaluation> {
1528        futures::executor::block_on(evaluate(a, b, rest))
1529    }
1530
1531    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1532    #[test]
1533    fn intersect_numeric_sorted() {
1534        let a = Tensor::new(vec![5.0, 7.0, 5.0, 1.0], vec![4, 1]).unwrap();
1535        let b = Tensor::new(vec![7.0, 1.0, 3.0], vec![3, 1]).unwrap();
1536        let eval = intersect_numeric_elements(
1537            a,
1538            b,
1539            &IntersectOptions {
1540                rows: false,
1541                order: IntersectOrder::Sorted,
1542            },
1543        )
1544        .expect("intersect");
1545        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
1546        assert_eq!(values.data, vec![1.0, 7.0]);
1547        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
1548        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
1549        assert_eq!(ia.data, vec![4.0, 2.0]);
1550        assert_eq!(ib.data, vec![2.0, 1.0]);
1551    }
1552
1553    #[test]
1554    fn intersect_type_resolver_numeric() {
1555        assert_eq!(
1556            set_values_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
1557            Type::tensor()
1558        );
1559    }
1560
1561    #[test]
1562    fn intersect_type_resolver_string_array() {
1563        assert_eq!(
1564            set_values_output_type(
1565                &[Type::cell_of(Type::String)],
1566                &ResolveContext::new(Vec::new()),
1567            ),
1568            Type::cell_of(Type::String)
1569        );
1570    }
1571
1572    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1573    #[test]
1574    fn intersect_numeric_stable() {
1575        let a = Tensor::new(vec![4.0, 2.0, 4.0, 1.0, 3.0], vec![5, 1]).unwrap();
1576        let b = Tensor::new(vec![3.0, 4.0, 5.0, 1.0], vec![4, 1]).unwrap();
1577        let eval = intersect_numeric_elements(
1578            a,
1579            b,
1580            &IntersectOptions {
1581                rows: false,
1582                order: IntersectOrder::Stable,
1583            },
1584        )
1585        .expect("intersect");
1586        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
1587        assert_eq!(values.data, vec![4.0, 1.0, 3.0]);
1588        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
1589        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
1590        assert_eq!(ia.data, vec![1.0, 4.0, 5.0]);
1591        assert_eq!(ib.data, vec![2.0, 4.0, 1.0]);
1592    }
1593
1594    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1595    #[test]
1596    fn intersect_numeric_handles_nan() {
1597        let a = Tensor::new(vec![f64::NAN, 1.0, f64::NAN], vec![3, 1]).unwrap();
1598        let b = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
1599        let eval = intersect_numeric_elements(
1600            a,
1601            b,
1602            &IntersectOptions {
1603                rows: false,
1604                order: IntersectOrder::Sorted,
1605            },
1606        )
1607        .expect("intersect");
1608        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
1609        assert_eq!(values.data.len(), 1);
1610        assert!(values.data[0].is_nan());
1611        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
1612        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
1613        assert_eq!(ia.data, vec![1.0]);
1614        assert_eq!(ib.data, vec![2.0]);
1615    }
1616
1617    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1618    #[test]
1619    fn intersect_complex_with_real_inputs() {
1620        let complex =
1621            ComplexTensor::new(vec![(1.0, 0.0), (2.0, 0.0), (3.0, 1.0)], vec![3, 1]).unwrap();
1622        let real = Tensor::new(vec![2.0, 4.0, 1.0], vec![3, 1]).unwrap();
1623        let real_complex = tensor_to_complex_owned("intersect", real).unwrap();
1624        let eval = intersect_complex(
1625            complex,
1626            real_complex,
1627            &IntersectOptions {
1628                rows: false,
1629                order: IntersectOrder::Sorted,
1630            },
1631        )
1632        .expect("intersect complex");
1633        match eval.values_value() {
1634            Value::ComplexTensor(t) => {
1635                assert_eq!(t.data, vec![(1.0, 0.0), (2.0, 0.0)]);
1636            }
1637            other => panic!("expected complex tensor, got {other:?}"),
1638        }
1639        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
1640        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
1641        assert_eq!(ia.data, vec![1.0, 2.0]);
1642        assert_eq!(ib.data, vec![3.0, 1.0]);
1643    }
1644
1645    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1646    #[test]
1647    fn intersect_numeric_rows_default() {
1648        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
1649        let b = Tensor::new(vec![1.0, 5.0, 2.0, 6.0], vec![2, 2]).unwrap();
1650        let eval = intersect_numeric_rows(
1651            a,
1652            b,
1653            &IntersectOptions {
1654                rows: true,
1655                order: IntersectOrder::Sorted,
1656            },
1657        )
1658        .expect("intersect rows");
1659        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
1660        assert_eq!(values.shape, vec![1, 2]);
1661        assert_eq!(values.data, vec![1.0, 2.0]);
1662        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
1663        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
1664        assert_eq!(ia.data, vec![1.0]);
1665        assert_eq!(ib.data, vec![1.0]);
1666    }
1667
1668    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1669    #[test]
1670    fn intersect_char_elements_basic() {
1671        let a = CharArray::new("cab".chars().collect(), 1, 3).unwrap();
1672        let b = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
1673        assert_eq!(find_char_index(&b, 'b'), Some(0));
1674        assert_eq!(find_char_index(&b, 'c'), Some(1));
1675        let b_for_eval = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
1676        let eval = intersect_char_elements(
1677            a,
1678            b_for_eval,
1679            &IntersectOptions {
1680                rows: false,
1681                order: IntersectOrder::Sorted,
1682            },
1683        )
1684        .expect("intersect char");
1685        match eval.values_value() {
1686            Value::CharArray(arr) => {
1687                assert_eq!(arr.rows, 2);
1688                assert_eq!(arr.cols, 1);
1689                assert_eq!(arr.data, vec!['b', 'c']);
1690            }
1691            other => panic!("expected char array, got {other:?}"),
1692        }
1693        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
1694        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
1695        assert_eq!(ia.data, vec![3.0, 1.0]);
1696        assert_eq!(ib.data, vec![1.0, 2.0]);
1697    }
1698
1699    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1700    #[test]
1701    fn intersect_string_elements_stable() {
1702        let a = StringArray::new(
1703            vec!["apple".into(), "orange".into(), "pear".into()],
1704            vec![3, 1],
1705        )
1706        .unwrap();
1707        let b = StringArray::new(
1708            vec!["pear".into(), "grape".into(), "orange".into()],
1709            vec![3, 1],
1710        )
1711        .unwrap();
1712        let eval = intersect_string_elements(
1713            a,
1714            b,
1715            &IntersectOptions {
1716                rows: false,
1717                order: IntersectOrder::Stable,
1718            },
1719        )
1720        .expect("intersect string");
1721        match eval.values_value() {
1722            Value::StringArray(arr) => {
1723                assert_eq!(arr.shape, vec![2, 1]);
1724                assert_eq!(arr.data, vec!["orange".to_string(), "pear".to_string()]);
1725            }
1726            other => panic!("expected string array, got {other:?}"),
1727        }
1728        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
1729        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
1730        assert_eq!(ia.data, vec![2.0, 3.0]);
1731        assert_eq!(ib.data, vec![3.0, 1.0]);
1732    }
1733
1734    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1735    #[test]
1736    fn intersect_rejects_legacy_option() {
1737        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1738        let err = evaluate_sync(
1739            Value::Tensor(tensor.clone()),
1740            Value::Tensor(tensor),
1741            &[Value::from("legacy")],
1742        )
1743        .unwrap_err();
1744        assert_eq!(
1745            err.identifier(),
1746            INTERSECT_ERROR_LEGACY_OPTION_UNSUPPORTED.identifier
1747        );
1748    }
1749
1750    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1751    #[test]
1752    fn intersect_rejects_conflicting_order_options() {
1753        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1754        let err = evaluate_sync(
1755            Value::Tensor(tensor.clone()),
1756            Value::Tensor(tensor),
1757            &[Value::from("stable"), Value::from("sorted")],
1758        )
1759        .unwrap_err();
1760        assert_eq!(
1761            err.identifier(),
1762            INTERSECT_ERROR_CONFLICTING_ORDER_OPTIONS.identifier
1763        );
1764    }
1765
1766    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1767    #[test]
1768    fn intersect_rejects_unknown_option() {
1769        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1770        let err = evaluate_sync(
1771            Value::Tensor(tensor.clone()),
1772            Value::Tensor(tensor),
1773            &[Value::from("bogus")],
1774        )
1775        .unwrap_err();
1776        assert_eq!(err.identifier(), INTERSECT_ERROR_UNKNOWN_OPTION.identifier);
1777    }
1778
1779    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1780    #[test]
1781    fn intersect_rows_dimension_mismatch() {
1782        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1783        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1784        let err = intersect_numeric_rows(
1785            a,
1786            b,
1787            &IntersectOptions {
1788                rows: true,
1789                order: IntersectOrder::Sorted,
1790            },
1791        )
1792        .unwrap_err();
1793        assert_eq!(
1794            err.identifier(),
1795            INTERSECT_ERROR_ROWS_COLUMN_MISMATCH.identifier
1796        );
1797    }
1798
1799    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1800    #[test]
1801    fn intersect_mixed_types_error() {
1802        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1803        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
1804        let err = intersect_host(
1805            Value::Tensor(a),
1806            Value::CharArray(b),
1807            &IntersectOptions {
1808                rows: false,
1809                order: IntersectOrder::Sorted,
1810            },
1811        )
1812        .unwrap_err();
1813        assert_eq!(
1814            err.identifier(),
1815            INTERSECT_ERROR_UNSUPPORTED_INPUT_TYPE.identifier
1816        );
1817    }
1818
1819    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1820    #[test]
1821    fn intersect_gpu_roundtrip() {
1822        test_support::with_test_provider(|provider| {
1823            let a = Tensor::new(vec![4.0, 1.0, 2.0, 1.0], vec![4, 1]).unwrap();
1824            let b = Tensor::new(vec![2.0, 5.0, 1.0], vec![3, 1]).unwrap();
1825            let view_a = HostTensorView {
1826                data: &a.data,
1827                shape: &a.shape,
1828            };
1829            let view_b = HostTensorView {
1830                data: &b.data,
1831                shape: &b.shape,
1832            };
1833            let handle_a = provider.upload(&view_a).expect("upload A");
1834            let handle_b = provider.upload(&view_b).expect("upload B");
1835            let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
1836                .expect("intersect");
1837            let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
1838            assert_eq!(values.data, vec![1.0, 2.0]);
1839            let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
1840            let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
1841            assert_eq!(ia.data, vec![2.0, 3.0]);
1842            assert_eq!(ib.data, vec![3.0, 1.0]);
1843        });
1844    }
1845
1846    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1847    #[test]
1848    fn intersect_two_outputs_from_evaluate() {
1849        let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1850        let b = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
1851        let eval = intersect_numeric_elements(
1852            a,
1853            b,
1854            &IntersectOptions {
1855                rows: false,
1856                order: IntersectOrder::Sorted,
1857            },
1858        )
1859        .unwrap();
1860        let (_c, ia) = eval.clone().into_pair();
1861        let ia_tensor = tensor::value_into_tensor_for("intersect", ia).unwrap();
1862        assert_eq!(ia_tensor.data, vec![1.0, 3.0]);
1863        let (_c, ia2, ib2) = eval.into_triple();
1864        let ia_tensor2 = tensor::value_into_tensor_for("intersect", ia2).unwrap();
1865        let ib_tensor2 = tensor::value_into_tensor_for("intersect", ib2).unwrap();
1866        assert_eq!(ia_tensor2.data, vec![1.0, 3.0]);
1867        assert_eq!(ib_tensor2.data, vec![2.0, 1.0]);
1868    }
1869
1870    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1871    #[test]
1872    #[cfg(feature = "wgpu")]
1873    fn intersect_wgpu_matches_cpu() {
1874        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1875            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1876        );
1877        let a = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
1878        let b = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();
1879
1880        let cpu_eval = intersect_numeric_elements(
1881            a.clone(),
1882            b.clone(),
1883            &IntersectOptions {
1884                rows: false,
1885                order: IntersectOrder::Sorted,
1886            },
1887        )
1888        .unwrap();
1889        let cpu_values =
1890            tensor::value_into_tensor_for("intersect", cpu_eval.values_value()).unwrap();
1891        let cpu_ia = tensor::value_into_tensor_for("intersect", cpu_eval.ia_value()).unwrap();
1892        let cpu_ib = tensor::value_into_tensor_for("intersect", cpu_eval.ib_value()).unwrap();
1893
1894        let provider = runmat_accelerate_api::provider().expect("provider");
1895        let view_a = HostTensorView {
1896            data: &a.data,
1897            shape: &a.shape,
1898        };
1899        let view_b = HostTensorView {
1900            data: &b.data,
1901            shape: &b.shape,
1902        };
1903        let handle_a = provider.upload(&view_a).expect("upload A");
1904        let handle_b = provider.upload(&view_b).expect("upload B");
1905        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
1906            .expect("intersect");
1907        let gpu_values =
1908            tensor::value_into_tensor_for("intersect", gpu_eval.values_value()).unwrap();
1909        let gpu_ia = tensor::value_into_tensor_for("intersect", gpu_eval.ia_value()).unwrap();
1910        let gpu_ib = tensor::value_into_tensor_for("intersect", gpu_eval.ib_value()).unwrap();
1911
1912        assert_eq!(gpu_values.data, cpu_values.data);
1913        assert_eq!(gpu_ia.data, cpu_ia.data);
1914        assert_eq!(gpu_ib.data, cpu_ib.data);
1915    }
1916}