1use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::GpuTensorHandle;
12use runmat_builtins::{
13 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
14 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
15 CharArray, ComplexTensor, StringArray, Tensor, Value,
16};
17use runmat_macros::runtime_builtin;
18
19use super::type_resolvers::set_values_output_type;
20use crate::build_runtime_error;
21use crate::builtins::common::arg_tokens::tokens_from_values;
22use crate::builtins::common::gpu_helpers;
23use crate::builtins::common::random_args::complex_tensor_into_value;
24use crate::builtins::common::spec::{
25 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
26 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
27};
28use crate::builtins::common::tensor;
29
// GPU acceleration metadata for `intersect`.
//
// The host code in this module gathers any `GpuTensor` operand before computing
// the intersection (see `intersect_gpu_pair` / `intersect_gpu_mixed`), and the
// residency policy asks for results to be gathered immediately.
// NOTE(review): no call to the `Custom("intersect")` provider hook is visible in
// this file — confirm whether any provider actually implements it.
#[runmat_macros::register_gpu_spec(
    builtin_path = "crate::builtins::array::sorting_sets::intersect"
)]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "intersect",
    op_kind: GpuOpKind::Custom("intersect"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("intersect")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes:
        "Providers may expose a dedicated intersect hook; otherwise tensors are gathered and processed on the host.",
};
48
// Fusion metadata for `intersect`.
//
// `intersect` is neither elementwise nor a reduction (both are `None`), so it
// acts as a fusion barrier: inputs are materialised before the set operation runs.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::intersect"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "intersect",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`intersect` materialises its inputs and terminates fusion chains; upstream GPU tensors are gathered when necessary.",
};
61
62const BUILTIN_NAME: &str = "intersect";
63
64const INTERSECT_OUTPUT_C: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
65 name: "C",
66 ty: BuiltinParamType::Any,
67 arity: BuiltinParamArity::Required,
68 default: None,
69 description: "Intersection values or rows.",
70}];
71
72const INTERSECT_OUTPUT_C_IA: [BuiltinParamDescriptor; 2] = [
73 BuiltinParamDescriptor {
74 name: "C",
75 ty: BuiltinParamType::Any,
76 arity: BuiltinParamArity::Required,
77 default: None,
78 description: "Intersection values or rows.",
79 },
80 BuiltinParamDescriptor {
81 name: "ia",
82 ty: BuiltinParamType::NumericArray,
83 arity: BuiltinParamArity::Required,
84 default: None,
85 description: "Indices selecting matching elements/rows in A.",
86 },
87];
88
89const INTERSECT_OUTPUT_C_IA_IB: [BuiltinParamDescriptor; 3] = [
90 BuiltinParamDescriptor {
91 name: "C",
92 ty: BuiltinParamType::Any,
93 arity: BuiltinParamArity::Required,
94 default: None,
95 description: "Intersection values or rows.",
96 },
97 BuiltinParamDescriptor {
98 name: "ia",
99 ty: BuiltinParamType::NumericArray,
100 arity: BuiltinParamArity::Required,
101 default: None,
102 description: "Indices selecting matching elements/rows in A.",
103 },
104 BuiltinParamDescriptor {
105 name: "ib",
106 ty: BuiltinParamType::NumericArray,
107 arity: BuiltinParamArity::Required,
108 default: None,
109 description: "Indices selecting matching elements/rows in B.",
110 },
111];
112
113const INTERSECT_INPUTS_A_B: [BuiltinParamDescriptor; 2] = [
114 BuiltinParamDescriptor {
115 name: "A",
116 ty: BuiltinParamType::Any,
117 arity: BuiltinParamArity::Required,
118 default: None,
119 description: "First input array.",
120 },
121 BuiltinParamDescriptor {
122 name: "B",
123 ty: BuiltinParamType::Any,
124 arity: BuiltinParamArity::Required,
125 default: None,
126 description: "Second input array.",
127 },
128];
129
130const INTERSECT_INPUTS_A_B_OPTIONS: [BuiltinParamDescriptor; 3] = [
131 BuiltinParamDescriptor {
132 name: "A",
133 ty: BuiltinParamType::Any,
134 arity: BuiltinParamArity::Required,
135 default: None,
136 description: "First input array.",
137 },
138 BuiltinParamDescriptor {
139 name: "B",
140 ty: BuiltinParamType::Any,
141 arity: BuiltinParamArity::Required,
142 default: None,
143 description: "Second input array.",
144 },
145 BuiltinParamDescriptor {
146 name: "option",
147 ty: BuiltinParamType::StringScalar,
148 arity: BuiltinParamArity::Variadic,
149 default: None,
150 description: "Option tokens: 'rows'|'sorted'|'stable'.",
151 },
152];
153
/// Call forms advertised for `intersect`.
///
/// There are one, two, or three outputs (`C`, `ia`, `ib`), each with and
/// without trailing option tokens, giving six signatures in total.
const INTERSECT_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
    BuiltinSignatureDescriptor {
        label: "C = intersect(A, B)",
        inputs: &INTERSECT_INPUTS_A_B,
        outputs: &INTERSECT_OUTPUT_C,
    },
    BuiltinSignatureDescriptor {
        label: "C = intersect(A, B, option...)",
        inputs: &INTERSECT_INPUTS_A_B_OPTIONS,
        outputs: &INTERSECT_OUTPUT_C,
    },
    BuiltinSignatureDescriptor {
        label: "[C, ia] = intersect(A, B)",
        inputs: &INTERSECT_INPUTS_A_B,
        outputs: &INTERSECT_OUTPUT_C_IA,
    },
    BuiltinSignatureDescriptor {
        label: "[C, ia] = intersect(A, B, option...)",
        inputs: &INTERSECT_INPUTS_A_B_OPTIONS,
        outputs: &INTERSECT_OUTPUT_C_IA,
    },
    BuiltinSignatureDescriptor {
        label: "[C, ia, ib] = intersect(A, B)",
        inputs: &INTERSECT_INPUTS_A_B,
        outputs: &INTERSECT_OUTPUT_C_IA_IB,
    },
    BuiltinSignatureDescriptor {
        label: "[C, ia, ib] = intersect(A, B, option...)",
        inputs: &INTERSECT_INPUTS_A_B_OPTIONS,
        outputs: &INTERSECT_OUTPUT_C_IA_IB,
    },
];
186
/// Raised for the `'legacy'` / `'R2012a'` compatibility flags, which are not implemented.
const INTERSECT_ERROR_LEGACY_OPTION_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.LEGACY_OPTION_UNSUPPORTED",
    identifier: Some("RunMat:intersect:LegacyOptionUnsupported"),
    when: "Legacy compatibility options are requested.",
    message: "intersect: the 'legacy' behaviour is not supported",
};

/// Raised when both `'sorted'` and `'stable'` appear in the option list.
const INTERSECT_ERROR_CONFLICTING_ORDER_OPTIONS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.CONFLICTING_ORDER_OPTIONS",
    identifier: Some("RunMat:intersect:ConflictingOrderOptions"),
    when: "Both 'sorted' and 'stable' options are provided.",
    message: "intersect: cannot combine 'sorted' with 'stable'",
};

/// Raised for any option text that is not a recognised keyword.
const INTERSECT_ERROR_UNKNOWN_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.UNKNOWN_OPTION",
    identifier: Some("RunMat:intersect:UnknownOption"),
    when: "An unsupported option token is provided.",
    message: "intersect: unrecognised option",
};

/// Raised in `'rows'` mode when `A` and `B` have different column counts.
const INTERSECT_ERROR_ROWS_COLUMN_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.ROWS_COLUMN_MISMATCH",
    identifier: Some("RunMat:intersect:RowsColumnMismatch"),
    when: "'rows' mode is used and column counts differ.",
    message: "intersect: inputs must have the same number of columns when using 'rows'",
};

/// Raised when an operand cannot be converted to a numeric tensor on the host path.
const INTERSECT_ERROR_UNSUPPORTED_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.UNSUPPORTED_INPUT_TYPE",
    identifier: Some("RunMat:intersect:UnsupportedInputType"),
    when: "Input values cannot be converted into supported intersect domains.",
    message: "intersect: unsupported input type",
};

/// Raised when an option argument cannot be read as text.
const INTERSECT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.INVALID_ARGUMENT",
    identifier: Some("RunMat:intersect:InvalidArgument"),
    when: "Option arguments are not string-like where required.",
    message: "intersect: expected string option arguments",
};

/// Catch-all for allocation/conversion failures (e.g. `Tensor::new` rejecting a shape).
const INTERSECT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.INTERSECT.INTERNAL",
    identifier: Some("RunMat:intersect:Internal"),
    when: "Internal conversion/allocation/provider decode fails.",
    message: "intersect: internal operation failed",
};

/// Every error this builtin can raise, published through `INTERSECT_DESCRIPTOR`.
const INTERSECT_ERRORS: [BuiltinErrorDescriptor; 7] = [
    INTERSECT_ERROR_LEGACY_OPTION_UNSUPPORTED,
    INTERSECT_ERROR_CONFLICTING_ORDER_OPTIONS,
    INTERSECT_ERROR_UNKNOWN_OPTION,
    INTERSECT_ERROR_ROWS_COLUMN_MISMATCH,
    INTERSECT_ERROR_UNSUPPORTED_INPUT_TYPE,
    INTERSECT_ERROR_INVALID_ARGUMENT,
    INTERSECT_ERROR_INTERNAL,
];

/// Public descriptor for `intersect`: signatures, error catalogue, and an output
/// mode driven by how many outputs the caller requested.
pub const INTERSECT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &INTERSECT_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &INTERSECT_ERRORS,
};
252
253fn intersect_error_with(
254 error: &'static BuiltinErrorDescriptor,
255 message: impl Into<String>,
256) -> crate::RuntimeError {
257 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
258 if let Some(identifier) = error.identifier {
259 builder = builder.with_identifier(identifier);
260 }
261 builder.build()
262}
263
264fn intersect_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
265 intersect_error_with(error, error.message)
266}
267
268fn intersect_internal_error(message: impl Into<String>) -> crate::RuntimeError {
269 intersect_error_with(&INTERSECT_ERROR_INTERNAL, message)
270}
271
272#[runtime_builtin(
273 name = "intersect",
274 category = "array/sorting_sets",
275 summary = "Return common elements or rows across arrays with index outputs.",
276 keywords = "intersect,set,stable,rows,indices,gpu",
277 accel = "array_construct",
278 sink = true,
279 type_resolver(set_values_output_type),
280 descriptor(crate::builtins::array::sorting_sets::intersect::INTERSECT_DESCRIPTOR),
281 builtin_path = "crate::builtins::array::sorting_sets::intersect"
282)]
283async fn intersect_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
284 let eval = evaluate(a, b, &rest).await?;
285 if let Some(out_count) = crate::output_count::current_output_count() {
286 if out_count == 0 {
287 return Ok(Value::OutputList(Vec::new()));
288 }
289 if out_count == 1 {
290 return Ok(Value::OutputList(vec![eval.into_values_value()]));
291 }
292 if out_count == 2 {
293 let (values, ia) = eval.into_pair();
294 return Ok(Value::OutputList(vec![values, ia]));
295 }
296 let (values, ia, ib) = eval.into_triple();
297 return Ok(crate::output_count::output_list_with_padding(
298 out_count,
299 vec![values, ia, ib],
300 ));
301 }
302 Ok(eval.into_values_value())
303}
304
305pub async fn evaluate(
307 a: Value,
308 b: Value,
309 rest: &[Value],
310) -> crate::BuiltinResult<IntersectEvaluation> {
311 let opts = parse_options(rest)?;
312 match (a, b) {
313 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
314 intersect_gpu_pair(handle_a, handle_b, &opts).await
315 }
316 (Value::GpuTensor(handle_a), other) => {
317 intersect_gpu_mixed(handle_a, other, &opts, true).await
318 }
319 (other, Value::GpuTensor(handle_b)) => {
320 intersect_gpu_mixed(handle_b, other, &opts, false).await
321 }
322 (left, right) => intersect_host(left, right, &opts),
323 }
324}
325
/// Output ordering selected by the `'sorted'` (default) or `'stable'` option.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IntersectOrder {
    /// Results are ordered with the per-domain comparators (`compare_f64`,
    /// `compare_complex`, `compare_numeric_rows`, ...).
    Sorted,
    /// Results keep the order in which they were first found while scanning `A`.
    Stable,
}
331
/// Parsed trailing options for `intersect`.
#[derive(Debug, Clone)]
struct IntersectOptions {
    /// `true` when `'rows'` was given: whole rows are compared instead of elements.
    rows: bool,
    /// Requested output ordering; `Sorted` unless `'stable'` was given.
    order: IntersectOrder,
}
337
338fn parse_options(rest: &[Value]) -> crate::BuiltinResult<IntersectOptions> {
339 let mut opts = IntersectOptions {
340 rows: false,
341 order: IntersectOrder::Sorted,
342 };
343 let mut seen_order: Option<IntersectOrder> = None;
344
345 let tokens = tokens_from_values(rest);
346 for (arg, token) in rest.iter().zip(tokens.iter()) {
347 let text = match token {
348 crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
349 _ => {
350 let text = tensor::value_to_string(arg)
351 .ok_or_else(|| intersect_error(&INTERSECT_ERROR_INVALID_ARGUMENT))?;
352 let lowered = text.trim().to_ascii_lowercase();
353 parse_intersect_option(&mut opts, &mut seen_order, &lowered)?;
354 continue;
355 }
356 };
357 parse_intersect_option(&mut opts, &mut seen_order, text)?;
358 }
359
360 Ok(opts)
361}
362
363fn parse_intersect_option(
364 opts: &mut IntersectOptions,
365 seen_order: &mut Option<IntersectOrder>,
366 lowered: &str,
367) -> crate::BuiltinResult<()> {
368 match lowered {
369 "rows" => opts.rows = true,
370 "sorted" => {
371 if let Some(prev) = seen_order {
372 if *prev != IntersectOrder::Sorted {
373 return Err(intersect_error(&INTERSECT_ERROR_CONFLICTING_ORDER_OPTIONS));
374 }
375 }
376 *seen_order = Some(IntersectOrder::Sorted);
377 opts.order = IntersectOrder::Sorted;
378 }
379 "stable" => {
380 if let Some(prev) = seen_order {
381 if *prev != IntersectOrder::Stable {
382 return Err(intersect_error(&INTERSECT_ERROR_CONFLICTING_ORDER_OPTIONS));
383 }
384 }
385 *seen_order = Some(IntersectOrder::Stable);
386 opts.order = IntersectOrder::Stable;
387 }
388 "legacy" | "r2012a" => {
389 return Err(intersect_error(&INTERSECT_ERROR_LEGACY_OPTION_UNSUPPORTED));
390 }
391 other => {
392 return Err(intersect_error_with(
393 &INTERSECT_ERROR_UNKNOWN_OPTION,
394 format!("intersect: unrecognised option '{other}'"),
395 ))
396 }
397 }
398 Ok(())
399}
400
401async fn intersect_gpu_pair(
402 handle_a: GpuTensorHandle,
403 handle_b: GpuTensorHandle,
404 opts: &IntersectOptions,
405) -> crate::BuiltinResult<IntersectEvaluation> {
406 let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
407 let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
408 intersect_numeric(tensor_a, tensor_b, opts)
409}
410
411async fn intersect_gpu_mixed(
412 handle_gpu: GpuTensorHandle,
413 other: Value,
414 opts: &IntersectOptions,
415 gpu_is_a: bool,
416) -> crate::BuiltinResult<IntersectEvaluation> {
417 let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
418 let tensor_other = tensor::value_into_tensor_for("intersect", other)
419 .map_err(|e| intersect_internal_error(e))?;
420 if gpu_is_a {
421 intersect_numeric(tensor_gpu, tensor_other, opts)
422 } else {
423 intersect_numeric(tensor_other, tensor_gpu, opts)
424 }
425}
426
427fn intersect_host(
428 a: Value,
429 b: Value,
430 opts: &IntersectOptions,
431) -> crate::BuiltinResult<IntersectEvaluation> {
432 match (a, b) {
433 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => intersect_complex(at, bt, opts),
434 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
435 let bt = scalar_complex_tensor(re, im)?;
436 intersect_complex(at, bt, opts)
437 }
438 (Value::Complex(re, im), Value::ComplexTensor(bt)) => {
439 let at = scalar_complex_tensor(re, im)?;
440 intersect_complex(at, bt, opts)
441 }
442 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
443 let at = scalar_complex_tensor(a_re, a_im)?;
444 let bt = scalar_complex_tensor(b_re, b_im)?;
445 intersect_complex(at, bt, opts)
446 }
447 (Value::ComplexTensor(at), other) => {
448 let bt = value_into_complex_tensor(other)?;
449 intersect_complex(at, bt, opts)
450 }
451 (other, Value::ComplexTensor(bt)) => {
452 let at = value_into_complex_tensor(other)?;
453 intersect_complex(at, bt, opts)
454 }
455 (Value::Complex(re, im), other) => {
456 let at = scalar_complex_tensor(re, im)?;
457 let bt = value_into_complex_tensor(other)?;
458 intersect_complex(at, bt, opts)
459 }
460 (other, Value::Complex(re, im)) => {
461 let at = value_into_complex_tensor(other)?;
462 let bt = scalar_complex_tensor(re, im)?;
463 intersect_complex(at, bt, opts)
464 }
465
466 (Value::CharArray(ac), Value::CharArray(bc)) => intersect_char(ac, bc, opts),
467
468 (Value::StringArray(astring), Value::StringArray(bstring)) => {
469 intersect_string(astring, bstring, opts)
470 }
471 (Value::StringArray(astring), Value::String(b)) => {
472 let bstring = StringArray::new(vec![b], vec![1, 1])
473 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
474 intersect_string(astring, bstring, opts)
475 }
476 (Value::String(a), Value::StringArray(bstring)) => {
477 let astring = StringArray::new(vec![a], vec![1, 1])
478 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
479 intersect_string(astring, bstring, opts)
480 }
481 (Value::String(a), Value::String(b)) => {
482 let astring = StringArray::new(vec![a], vec![1, 1])
483 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
484 let bstring = StringArray::new(vec![b], vec![1, 1])
485 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
486 intersect_string(astring, bstring, opts)
487 }
488
489 (left, right) => {
490 let tensor_a = tensor::value_into_tensor_for("intersect", left)
491 .map_err(|e| intersect_error_with(&INTERSECT_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
492 let tensor_b = tensor::value_into_tensor_for("intersect", right)
493 .map_err(|e| intersect_error_with(&INTERSECT_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
494 intersect_numeric(tensor_a, tensor_b, opts)
495 }
496 }
497}
498
499fn intersect_numeric(
500 a: Tensor,
501 b: Tensor,
502 opts: &IntersectOptions,
503) -> crate::BuiltinResult<IntersectEvaluation> {
504 if opts.rows {
505 intersect_numeric_rows(a, b, opts)
506 } else {
507 intersect_numeric_elements(a, b, opts)
508 }
509}
510
511fn intersect_numeric_elements(
512 a: Tensor,
513 b: Tensor,
514 opts: &IntersectOptions,
515) -> crate::BuiltinResult<IntersectEvaluation> {
516 let mut b_map: HashMap<u64, usize> = HashMap::new();
517 for (idx, &value) in b.data.iter().enumerate() {
518 let key = canonicalize_f64(value);
519 b_map.entry(key).or_insert(idx);
520 }
521
522 let mut seen: HashSet<u64> = HashSet::new();
523 let mut entries = Vec::<NumericIntersectEntry>::new();
524 let mut order_counter = 0usize;
525
526 for (idx, &value) in a.data.iter().enumerate() {
527 let key = canonicalize_f64(value);
528 if seen.contains(&key) {
529 continue;
530 }
531 if let Some(&b_idx) = b_map.get(&key) {
532 entries.push(NumericIntersectEntry {
533 value,
534 a_index: idx,
535 b_index: b_idx,
536 order_rank: order_counter,
537 });
538 seen.insert(key);
539 order_counter += 1;
540 }
541 }
542
543 assemble_numeric_intersect(entries, opts)
544}
545
546fn intersect_numeric_rows(
547 a: Tensor,
548 b: Tensor,
549 opts: &IntersectOptions,
550) -> crate::BuiltinResult<IntersectEvaluation> {
551 if a.shape.len() != 2 || b.shape.len() != 2 {
552 return Err(intersect_internal_error(
553 "intersect: 'rows' option requires 2-D numeric matrices",
554 ));
555 }
556 if a.shape[1] != b.shape[1] {
557 return Err(intersect_error(&INTERSECT_ERROR_ROWS_COLUMN_MISMATCH));
558 }
559 let rows_a = a.shape[0];
560 let cols = a.shape[1];
561 let rows_b = b.shape[0];
562
563 let mut b_map: HashMap<NumericRowKey, usize> = HashMap::new();
564 for r in 0..rows_b {
565 let mut row_values = Vec::with_capacity(cols);
566 for c in 0..cols {
567 let idx = r + c * rows_b;
568 row_values.push(b.data[idx]);
569 }
570 let key = NumericRowKey::from_slice(&row_values);
571 b_map.entry(key).or_insert(r);
572 }
573
574 let mut seen: HashSet<NumericRowKey> = HashSet::new();
575 let mut entries = Vec::<NumericRowIntersectEntry>::new();
576 let mut order_counter = 0usize;
577
578 for r in 0..rows_a {
579 let mut row_values = Vec::with_capacity(cols);
580 for c in 0..cols {
581 let idx = r + c * rows_a;
582 row_values.push(a.data[idx]);
583 }
584 let key = NumericRowKey::from_slice(&row_values);
585 if seen.contains(&key) {
586 continue;
587 }
588 if let Some(&b_row) = b_map.get(&key) {
589 entries.push(NumericRowIntersectEntry {
590 row_data: row_values,
591 a_row: r,
592 b_row,
593 order_rank: order_counter,
594 });
595 seen.insert(key);
596 order_counter += 1;
597 }
598 }
599
600 assemble_numeric_row_intersect(entries, opts, cols)
601}
602
603fn intersect_complex(
604 a: ComplexTensor,
605 b: ComplexTensor,
606 opts: &IntersectOptions,
607) -> crate::BuiltinResult<IntersectEvaluation> {
608 if opts.rows {
609 intersect_complex_rows(a, b, opts)
610 } else {
611 intersect_complex_elements(a, b, opts)
612 }
613}
614
615fn intersect_complex_elements(
616 a: ComplexTensor,
617 b: ComplexTensor,
618 opts: &IntersectOptions,
619) -> crate::BuiltinResult<IntersectEvaluation> {
620 let mut b_map: HashMap<ComplexKey, usize> = HashMap::new();
621 for (idx, &value) in b.data.iter().enumerate() {
622 let key = ComplexKey::new(value);
623 b_map.entry(key).or_insert(idx);
624 }
625
626 let mut seen: HashSet<ComplexKey> = HashSet::new();
627 let mut entries = Vec::<ComplexIntersectEntry>::new();
628 let mut order_counter = 0usize;
629
630 for (idx, &value) in a.data.iter().enumerate() {
631 let key = ComplexKey::new(value);
632 if seen.contains(&key) {
633 continue;
634 }
635 if let Some(&b_idx) = b_map.get(&key) {
636 entries.push(ComplexIntersectEntry {
637 value,
638 a_index: idx,
639 b_index: b_idx,
640 order_rank: order_counter,
641 });
642 seen.insert(key);
643 order_counter += 1;
644 }
645 }
646
647 assemble_complex_intersect(entries, opts)
648}
649
650fn intersect_complex_rows(
651 a: ComplexTensor,
652 b: ComplexTensor,
653 opts: &IntersectOptions,
654) -> crate::BuiltinResult<IntersectEvaluation> {
655 if a.shape.len() != 2 || b.shape.len() != 2 {
656 return Err(intersect_internal_error(
657 "intersect: 'rows' option requires 2-D complex matrices",
658 ));
659 }
660 if a.shape[1] != b.shape[1] {
661 return Err(intersect_error(&INTERSECT_ERROR_ROWS_COLUMN_MISMATCH));
662 }
663 let rows_a = a.shape[0];
664 let cols = a.shape[1];
665 let rows_b = b.shape[0];
666
667 let mut b_map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
668 for r in 0..rows_b {
669 let mut row_keys = Vec::with_capacity(cols);
670 for c in 0..cols {
671 let idx = r + c * rows_b;
672 row_keys.push(ComplexKey::new(b.data[idx]));
673 }
674 b_map.entry(row_keys).or_insert(r);
675 }
676
677 let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
678 let mut entries = Vec::<ComplexRowIntersectEntry>::new();
679 let mut order_counter = 0usize;
680
681 for r in 0..rows_a {
682 let mut row_values = Vec::with_capacity(cols);
683 let mut row_keys = Vec::with_capacity(cols);
684 for c in 0..cols {
685 let idx = r + c * rows_a;
686 let value = a.data[idx];
687 row_values.push(value);
688 row_keys.push(ComplexKey::new(value));
689 }
690 if seen.contains(&row_keys) {
691 continue;
692 }
693 if let Some(&b_row) = b_map.get(&row_keys) {
694 entries.push(ComplexRowIntersectEntry {
695 row_data: row_values,
696 a_row: r,
697 b_row,
698 order_rank: order_counter,
699 });
700 seen.insert(row_keys);
701 order_counter += 1;
702 }
703 }
704
705 assemble_complex_row_intersect(entries, opts, cols)
706}
707
708fn intersect_char(
709 a: CharArray,
710 b: CharArray,
711 opts: &IntersectOptions,
712) -> crate::BuiltinResult<IntersectEvaluation> {
713 if opts.rows {
714 intersect_char_rows(a, b, opts)
715 } else {
716 intersect_char_elements(a, b, opts)
717 }
718}
719
720fn intersect_char_elements(
721 a: CharArray,
722 b: CharArray,
723 opts: &IntersectOptions,
724) -> crate::BuiltinResult<IntersectEvaluation> {
725 let mut seen: HashSet<u32> = HashSet::new();
726 let mut entries = Vec::<CharIntersectEntry>::new();
727 let mut order_counter = 0usize;
728
729 for col in 0..a.cols {
730 for row in 0..a.rows {
731 let linear_idx = row + col * a.rows;
732 let data_idx = row * a.cols + col;
733 let ch = a.data[data_idx];
734 let key = ch as u32;
735 if seen.contains(&key) {
736 continue;
737 }
738 if let Some(b_idx) = find_char_index(&b, ch) {
739 entries.push(CharIntersectEntry {
740 ch,
741 a_index: linear_idx,
742 b_index: b_idx,
743 order_rank: order_counter,
744 });
745 seen.insert(key);
746 order_counter += 1;
747 }
748 }
749 }
750
751 assemble_char_intersect(entries, opts, &b)
752}
753
754fn intersect_char_rows(
755 a: CharArray,
756 b: CharArray,
757 opts: &IntersectOptions,
758) -> crate::BuiltinResult<IntersectEvaluation> {
759 if a.cols != b.cols {
760 return Err(intersect_error(&INTERSECT_ERROR_ROWS_COLUMN_MISMATCH));
761 }
762 let rows_a = a.rows;
763 let rows_b = b.rows;
764 let cols = a.cols;
765
766 let mut b_map: HashMap<RowCharKey, usize> = HashMap::new();
767 for r in 0..rows_b {
768 let mut row_values = Vec::with_capacity(cols);
769 for c in 0..cols {
770 let idx = r * cols + c;
771 row_values.push(b.data[idx]);
772 }
773 let key = RowCharKey::from_slice(&row_values);
774 b_map.entry(key).or_insert(r);
775 }
776
777 let mut seen: HashSet<RowCharKey> = HashSet::new();
778 let mut entries = Vec::<CharRowIntersectEntry>::new();
779 let mut order_counter = 0usize;
780
781 for r in 0..rows_a {
782 let mut row_values = Vec::with_capacity(cols);
783 for c in 0..cols {
784 let idx = r * cols + c;
785 row_values.push(a.data[idx]);
786 }
787 let key = RowCharKey::from_slice(&row_values);
788 if seen.contains(&key) {
789 continue;
790 }
791 if let Some(&b_row) = b_map.get(&key) {
792 entries.push(CharRowIntersectEntry {
793 row_data: row_values,
794 a_row: r,
795 b_row,
796 order_rank: order_counter,
797 });
798 seen.insert(key);
799 order_counter += 1;
800 }
801 }
802
803 assemble_char_row_intersect(entries, opts, cols)
804}
805
806fn find_char_index(array: &CharArray, target: char) -> Option<usize> {
807 for col in 0..array.cols {
808 for row in 0..array.rows {
809 let data_idx = row * array.cols + col;
810 if array.data[data_idx] == target {
811 return Some(row + col * array.rows);
812 }
813 }
814 }
815 None
816}
817
818fn intersect_string(
819 a: StringArray,
820 b: StringArray,
821 opts: &IntersectOptions,
822) -> crate::BuiltinResult<IntersectEvaluation> {
823 if opts.rows {
824 intersect_string_rows(a, b, opts)
825 } else {
826 intersect_string_elements(a, b, opts)
827 }
828}
829
830fn intersect_string_elements(
831 a: StringArray,
832 b: StringArray,
833 opts: &IntersectOptions,
834) -> crate::BuiltinResult<IntersectEvaluation> {
835 let mut b_map: HashMap<String, usize> = HashMap::new();
836 for (idx, value) in b.data.iter().enumerate() {
837 b_map.entry(value.clone()).or_insert(idx);
838 }
839
840 let mut seen: HashSet<String> = HashSet::new();
841 let mut entries = Vec::<StringIntersectEntry>::new();
842 let mut order_counter = 0usize;
843
844 for (idx, value) in a.data.iter().enumerate() {
845 if seen.contains(value) {
846 continue;
847 }
848 if let Some(&b_idx) = b_map.get(value) {
849 entries.push(StringIntersectEntry {
850 value: value.clone(),
851 a_index: idx,
852 b_index: b_idx,
853 order_rank: order_counter,
854 });
855 seen.insert(value.clone());
856 order_counter += 1;
857 }
858 }
859
860 assemble_string_intersect(entries, opts)
861}
862
863fn intersect_string_rows(
864 a: StringArray,
865 b: StringArray,
866 opts: &IntersectOptions,
867) -> crate::BuiltinResult<IntersectEvaluation> {
868 if a.shape.len() != 2 || b.shape.len() != 2 {
869 return Err(intersect_internal_error(
870 "intersect: 'rows' option requires 2-D string arrays",
871 ));
872 }
873 if a.shape[1] != b.shape[1] {
874 return Err(intersect_error(&INTERSECT_ERROR_ROWS_COLUMN_MISMATCH));
875 }
876 let rows_a = a.shape[0];
877 let cols = a.shape[1];
878 let rows_b = b.shape[0];
879
880 let mut b_map: HashMap<RowStringKey, usize> = HashMap::new();
881 for r in 0..rows_b {
882 let mut row_values = Vec::with_capacity(cols);
883 for c in 0..cols {
884 let idx = r + c * rows_b;
885 row_values.push(b.data[idx].clone());
886 }
887 let key = RowStringKey::from_slice(&row_values);
888 b_map.entry(key).or_insert(r);
889 }
890
891 let mut seen: HashSet<RowStringKey> = HashSet::new();
892 let mut entries = Vec::<StringRowIntersectEntry>::new();
893 let mut order_counter = 0usize;
894
895 for r in 0..rows_a {
896 let mut row_values = Vec::with_capacity(cols);
897 for c in 0..cols {
898 let idx = r + c * rows_a;
899 row_values.push(a.data[idx].clone());
900 }
901 let key = RowStringKey::from_slice(&row_values);
902 if seen.contains(&key) {
903 continue;
904 }
905 if let Some(&b_row) = b_map.get(&key) {
906 entries.push(StringRowIntersectEntry {
907 row_data: row_values,
908 a_row: r,
909 b_row,
910 order_rank: order_counter,
911 });
912 seen.insert(key);
913 order_counter += 1;
914 }
915 }
916
917 assemble_string_row_intersect(entries, opts, cols)
918}
919
/// Result of one `intersect` evaluation: `C` plus the 1-based index vectors.
///
/// The `into_*` / `*_value` methods convert the index tensors into runtime
/// values on demand, so callers that only need `C` skip that conversion.
#[derive(Debug, Clone)]
pub struct IntersectEvaluation {
    /// The intersection `C`, already converted to a runtime value.
    values: Value,
    /// 1-based positions in `A` (linear indices, or row numbers in `'rows'`
    /// mode); the numeric assemblers build this as an `n x 1` tensor.
    ia: Tensor,
    /// 1-based positions in `B`, aligned element-for-element with `ia`.
    ib: Tensor,
}
926
927impl IntersectEvaluation {
928 fn new(values: Value, ia: Tensor, ib: Tensor) -> Self {
929 Self { values, ia, ib }
930 }
931
932 pub fn into_values_value(self) -> Value {
933 self.values
934 }
935
936 pub fn into_pair(self) -> (Value, Value) {
937 let ia = tensor::tensor_into_value(self.ia);
938 (self.values, ia)
939 }
940
941 pub fn into_triple(self) -> (Value, Value, Value) {
942 let ia = tensor::tensor_into_value(self.ia);
943 let ib = tensor::tensor_into_value(self.ib);
944 (self.values, ia, ib)
945 }
946
947 pub fn values_value(&self) -> Value {
948 self.values.clone()
949 }
950
951 pub fn ia_value(&self) -> Value {
952 tensor::tensor_into_value(self.ia.clone())
953 }
954
955 pub fn ib_value(&self) -> Value {
956 tensor::tensor_into_value(self.ib.clone())
957 }
958}
959
// Per-domain records collected while scanning `A`. In every record:
// - `a_index` / `b_index` (or `a_row` / `b_row`) are 0-based positions of the
//   first occurrence in each input;
// - `order_rank` is the discovery order during the scan of `A`, used for
//   `'stable'` output.

/// A common real value.
#[derive(Debug)]
struct NumericIntersectEntry {
    value: f64,
    a_index: usize,
    b_index: usize,
    order_rank: usize,
}

/// A common real row (`'rows'` mode); `row_data` holds one value per column.
#[derive(Debug)]
struct NumericRowIntersectEntry {
    row_data: Vec<f64>,
    a_row: usize,
    b_row: usize,
    order_rank: usize,
}

/// A common complex value stored as `(re, im)`.
#[derive(Debug)]
struct ComplexIntersectEntry {
    value: (f64, f64),
    a_index: usize,
    b_index: usize,
    order_rank: usize,
}

/// A common complex row (`'rows'` mode).
#[derive(Debug)]
struct ComplexRowIntersectEntry {
    row_data: Vec<(f64, f64)>,
    a_row: usize,
    b_row: usize,
    order_rank: usize,
}

/// A common character; indices are column-major linear positions.
#[derive(Debug)]
struct CharIntersectEntry {
    ch: char,
    a_index: usize,
    b_index: usize,
    order_rank: usize,
}

/// A common character row (`'rows'` mode).
#[derive(Debug)]
struct CharRowIntersectEntry {
    row_data: Vec<char>,
    a_row: usize,
    b_row: usize,
    order_rank: usize,
}

/// A common string.
#[derive(Debug)]
struct StringIntersectEntry {
    value: String,
    a_index: usize,
    b_index: usize,
    order_rank: usize,
}

/// A common string row (`'rows'` mode).
#[derive(Debug)]
struct StringRowIntersectEntry {
    row_data: Vec<String>,
    a_row: usize,
    b_row: usize,
    order_rank: usize,
}
1023
1024fn assemble_numeric_intersect(
1025 entries: Vec<NumericIntersectEntry>,
1026 opts: &IntersectOptions,
1027) -> crate::BuiltinResult<IntersectEvaluation> {
1028 let mut order: Vec<usize> = (0..entries.len()).collect();
1029 match opts.order {
1030 IntersectOrder::Sorted => {
1031 order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
1032 }
1033 IntersectOrder::Stable => {
1034 order.sort_by_key(|&idx| entries[idx].order_rank);
1035 }
1036 }
1037
1038 let mut values = Vec::with_capacity(order.len());
1039 let mut ia = Vec::with_capacity(order.len());
1040 let mut ib = Vec::with_capacity(order.len());
1041 for &idx in &order {
1042 let entry = &entries[idx];
1043 values.push(entry.value);
1044 ia.push((entry.a_index + 1) as f64);
1045 ib.push((entry.b_index + 1) as f64);
1046 }
1047
1048 let value_tensor = Tensor::new(values, vec![order.len(), 1])
1049 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1050 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1051 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1052 let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1053 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1054
1055 Ok(IntersectEvaluation::new(
1056 tensor::tensor_into_value(value_tensor),
1057 ia_tensor,
1058 ib_tensor,
1059 ))
1060}
1061
1062fn assemble_numeric_row_intersect(
1063 entries: Vec<NumericRowIntersectEntry>,
1064 opts: &IntersectOptions,
1065 cols: usize,
1066) -> crate::BuiltinResult<IntersectEvaluation> {
1067 let mut order: Vec<usize> = (0..entries.len()).collect();
1068 match opts.order {
1069 IntersectOrder::Sorted => {
1070 order.sort_by(|&lhs, &rhs| {
1071 compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1072 });
1073 }
1074 IntersectOrder::Stable => {
1075 order.sort_by_key(|&idx| entries[idx].order_rank);
1076 }
1077 }
1078
1079 let rows_out = order.len();
1080 let mut values = vec![0.0f64; rows_out * cols];
1081 let mut ia = Vec::with_capacity(rows_out);
1082 let mut ib = Vec::with_capacity(rows_out);
1083
1084 for (row_pos, &entry_idx) in order.iter().enumerate() {
1085 let entry = &entries[entry_idx];
1086 for col in 0..cols {
1087 let dest = row_pos + col * rows_out;
1088 values[dest] = entry.row_data[col];
1089 }
1090 ia.push((entry.a_row + 1) as f64);
1091 ib.push((entry.b_row + 1) as f64);
1092 }
1093
1094 let value_tensor = Tensor::new(values, vec![rows_out, cols])
1095 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1096 let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1097 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1098 let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1099 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1100
1101 Ok(IntersectEvaluation::new(
1102 tensor::tensor_into_value(value_tensor),
1103 ia_tensor,
1104 ib_tensor,
1105 ))
1106}
1107
1108fn assemble_complex_intersect(
1109 entries: Vec<ComplexIntersectEntry>,
1110 opts: &IntersectOptions,
1111) -> crate::BuiltinResult<IntersectEvaluation> {
1112 let mut order: Vec<usize> = (0..entries.len()).collect();
1113 match opts.order {
1114 IntersectOrder::Sorted => {
1115 order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
1116 }
1117 IntersectOrder::Stable => {
1118 order.sort_by_key(|&idx| entries[idx].order_rank);
1119 }
1120 }
1121
1122 let mut values = Vec::with_capacity(order.len());
1123 let mut ia = Vec::with_capacity(order.len());
1124 let mut ib = Vec::with_capacity(order.len());
1125 for &idx in &order {
1126 let entry = &entries[idx];
1127 values.push(entry.value);
1128 ia.push((entry.a_index + 1) as f64);
1129 ib.push((entry.b_index + 1) as f64);
1130 }
1131
1132 let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
1133 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1134 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1135 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1136 let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1137 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1138
1139 Ok(IntersectEvaluation::new(
1140 complex_tensor_into_value(value_tensor),
1141 ia_tensor,
1142 ib_tensor,
1143 ))
1144}
1145
1146fn assemble_complex_row_intersect(
1147 entries: Vec<ComplexRowIntersectEntry>,
1148 opts: &IntersectOptions,
1149 cols: usize,
1150) -> crate::BuiltinResult<IntersectEvaluation> {
1151 let mut order: Vec<usize> = (0..entries.len()).collect();
1152 match opts.order {
1153 IntersectOrder::Sorted => {
1154 order.sort_by(|&lhs, &rhs| {
1155 compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1156 });
1157 }
1158 IntersectOrder::Stable => {
1159 order.sort_by_key(|&idx| entries[idx].order_rank);
1160 }
1161 }
1162
1163 let rows_out = order.len();
1164 let mut values = vec![(0.0f64, 0.0f64); rows_out * cols];
1165 let mut ia = Vec::with_capacity(rows_out);
1166 let mut ib = Vec::with_capacity(rows_out);
1167
1168 for (row_pos, &entry_idx) in order.iter().enumerate() {
1169 let entry = &entries[entry_idx];
1170 for col in 0..cols {
1171 let dest = row_pos + col * rows_out;
1172 values[dest] = entry.row_data[col];
1173 }
1174 ia.push((entry.a_row + 1) as f64);
1175 ib.push((entry.b_row + 1) as f64);
1176 }
1177
1178 let value_tensor = ComplexTensor::new(values, vec![rows_out, cols])
1179 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1180 let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1181 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1182 let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1183 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1184
1185 Ok(IntersectEvaluation::new(
1186 complex_tensor_into_value(value_tensor),
1187 ia_tensor,
1188 ib_tensor,
1189 ))
1190}
1191
1192fn assemble_char_intersect(
1193 entries: Vec<CharIntersectEntry>,
1194 opts: &IntersectOptions,
1195 b: &CharArray,
1196) -> crate::BuiltinResult<IntersectEvaluation> {
1197 let mut order: Vec<usize> = (0..entries.len()).collect();
1198 match opts.order {
1199 IntersectOrder::Sorted => {
1200 order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1201 }
1202 IntersectOrder::Stable => {
1203 order.sort_by_key(|&idx| entries[idx].order_rank);
1204 }
1205 }
1206
1207 let mut values = Vec::with_capacity(order.len());
1208 let mut ia = Vec::with_capacity(order.len());
1209 let mut ib = Vec::with_capacity(order.len());
1210 for &idx in &order {
1211 let entry = &entries[idx];
1212 values.push(entry.ch);
1213 ia.push((entry.a_index + 1) as f64);
1214 let b_idx = find_char_index(b, entry.ch).unwrap_or(entry.b_index);
1215 ib.push((b_idx + 1) as f64);
1216 }
1217
1218 let value_array = CharArray::new(values, order.len(), 1)
1219 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1220 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1221 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1222 let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1223 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1224
1225 Ok(IntersectEvaluation::new(
1226 Value::CharArray(value_array),
1227 ia_tensor,
1228 ib_tensor,
1229 ))
1230}
1231
1232fn assemble_char_row_intersect(
1233 entries: Vec<CharRowIntersectEntry>,
1234 opts: &IntersectOptions,
1235 cols: usize,
1236) -> crate::BuiltinResult<IntersectEvaluation> {
1237 let mut order: Vec<usize> = (0..entries.len()).collect();
1238 match opts.order {
1239 IntersectOrder::Sorted => {
1240 order.sort_by(|&lhs, &rhs| {
1241 compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1242 });
1243 }
1244 IntersectOrder::Stable => {
1245 order.sort_by_key(|&idx| entries[idx].order_rank);
1246 }
1247 }
1248
1249 let rows_out = order.len();
1250 let mut values = vec!['\0'; rows_out * cols];
1251 let mut ia = Vec::with_capacity(rows_out);
1252 let mut ib = Vec::with_capacity(rows_out);
1253
1254 for (row_pos, &entry_idx) in order.iter().enumerate() {
1255 let entry = &entries[entry_idx];
1256 for col in 0..cols {
1257 let dest = row_pos * cols + col;
1258 values[dest] = entry.row_data[col];
1259 }
1260 ia.push((entry.a_row + 1) as f64);
1261 ib.push((entry.b_row + 1) as f64);
1262 }
1263
1264 let value_array = CharArray::new(values, rows_out, cols)
1265 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1266 let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1267 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1268 let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1269 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1270
1271 Ok(IntersectEvaluation::new(
1272 Value::CharArray(value_array),
1273 ia_tensor,
1274 ib_tensor,
1275 ))
1276}
1277
1278fn assemble_string_intersect(
1279 entries: Vec<StringIntersectEntry>,
1280 opts: &IntersectOptions,
1281) -> crate::BuiltinResult<IntersectEvaluation> {
1282 let mut order: Vec<usize> = (0..entries.len()).collect();
1283 match opts.order {
1284 IntersectOrder::Sorted => {
1285 order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1286 }
1287 IntersectOrder::Stable => {
1288 order.sort_by_key(|&idx| entries[idx].order_rank);
1289 }
1290 }
1291
1292 let mut values = Vec::with_capacity(order.len());
1293 let mut ia = Vec::with_capacity(order.len());
1294 let mut ib = Vec::with_capacity(order.len());
1295 for &idx in &order {
1296 let entry = &entries[idx];
1297 values.push(entry.value.clone());
1298 ia.push((entry.a_index + 1) as f64);
1299 ib.push((entry.b_index + 1) as f64);
1300 }
1301
1302 let value_array = StringArray::new(values, vec![order.len(), 1])
1303 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1304 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1305 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1306 let ib_tensor = Tensor::new(ib, vec![order.len(), 1])
1307 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1308
1309 Ok(IntersectEvaluation::new(
1310 Value::StringArray(value_array),
1311 ia_tensor,
1312 ib_tensor,
1313 ))
1314}
1315
1316fn assemble_string_row_intersect(
1317 entries: Vec<StringRowIntersectEntry>,
1318 opts: &IntersectOptions,
1319 cols: usize,
1320) -> crate::BuiltinResult<IntersectEvaluation> {
1321 let mut order: Vec<usize> = (0..entries.len()).collect();
1322 match opts.order {
1323 IntersectOrder::Sorted => {
1324 order.sort_by(|&lhs, &rhs| {
1325 compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1326 });
1327 }
1328 IntersectOrder::Stable => {
1329 order.sort_by_key(|&idx| entries[idx].order_rank);
1330 }
1331 }
1332
1333 let rows_out = order.len();
1334 let mut values = vec![String::new(); rows_out * cols];
1335 let mut ia = Vec::with_capacity(rows_out);
1336 let mut ib = Vec::with_capacity(rows_out);
1337
1338 for (row_pos, &entry_idx) in order.iter().enumerate() {
1339 let entry = &entries[entry_idx];
1340 for col in 0..cols {
1341 let dest = row_pos + col * rows_out;
1342 values[dest] = entry.row_data[col].clone();
1343 }
1344 ia.push((entry.a_row + 1) as f64);
1345 ib.push((entry.b_row + 1) as f64);
1346 }
1347
1348 let value_array = StringArray::new(values, vec![rows_out, cols])
1349 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1350 let ia_tensor = Tensor::new(ia, vec![rows_out, 1])
1351 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1352 let ib_tensor = Tensor::new(ib, vec![rows_out, 1])
1353 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))?;
1354
1355 Ok(IntersectEvaluation::new(
1356 Value::StringArray(value_array),
1357 ia_tensor,
1358 ib_tensor,
1359 ))
1360}
1361
1362#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1363struct NumericRowKey(Vec<u64>);
1364
1365impl NumericRowKey {
1366 fn from_slice(values: &[f64]) -> Self {
1367 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1368 }
1369}
1370
1371#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1372struct ComplexKey {
1373 re: u64,
1374 im: u64,
1375}
1376
1377impl ComplexKey {
1378 fn new(value: (f64, f64)) -> Self {
1379 Self {
1380 re: canonicalize_f64(value.0),
1381 im: canonicalize_f64(value.1),
1382 }
1383 }
1384}
1385
/// Hashable identity for a character row, stored as Unicode scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    fn from_slice(values: &[char]) -> Self {
        Self(values.iter().copied().map(u32::from).collect())
    }
}
1394
/// Hashable identity for a string row: an owned copy of its elements.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);

impl RowStringKey {
    fn from_slice(values: &[String]) -> Self {
        Self(Vec::from(values))
    }
}
1403
1404fn scalar_complex_tensor(re: f64, im: f64) -> crate::BuiltinResult<ComplexTensor> {
1405 ComplexTensor::new(vec![(re, im)], vec![1, 1])
1406 .map_err(|e| intersect_internal_error(format!("intersect: {e}")))
1407}
1408
1409fn tensor_to_complex_owned(name: &str, tensor: Tensor) -> crate::BuiltinResult<ComplexTensor> {
1410 let Tensor { data, shape, .. } = tensor;
1411 let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
1412 ComplexTensor::new(complex, shape).map_err(|e| intersect_internal_error(format!("{name}: {e}")))
1413}
1414
1415fn value_into_complex_tensor(value: Value) -> crate::BuiltinResult<ComplexTensor> {
1416 match value {
1417 Value::ComplexTensor(tensor) => Ok(tensor),
1418 Value::Complex(re, im) => scalar_complex_tensor(re, im),
1419 other => {
1420 let tensor = tensor::value_into_tensor_for("intersect", other)
1421 .map_err(|e| intersect_internal_error(e))?;
1422 tensor_to_complex_owned("intersect", tensor)
1423 }
1424 }
1425}
1426
/// Maps an `f64` to a bit pattern suitable for hashing and equality keys.
///
/// Every NaN (any sign or payload) becomes the canonical quiet-NaN bits, and
/// both `0.0` and `-0.0` become `0`. All other values keep their IEEE-754
/// bit representation.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1436
/// Total order on `f64` used for sorted set output.
///
/// NaN sorts after every non-NaN value and all NaNs compare equal. Non-NaN
/// values use IEEE ordering, so `-0.0 == 0.0`.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1450
1451fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1452 for (lhs, rhs) in a.iter().zip(b.iter()) {
1453 let ord = compare_f64(*lhs, *rhs);
1454 if ord != Ordering::Equal {
1455 return ord;
1456 }
1457 }
1458 Ordering::Equal
1459}
1460
/// True when either the real or the imaginary component is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    [re, im].iter().any(|part| part.is_nan())
}
1464
1465fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1466 match (complex_is_nan(a), complex_is_nan(b)) {
1467 (true, true) => Ordering::Equal,
1468 (true, false) => Ordering::Greater,
1469 (false, true) => Ordering::Less,
1470 (false, false) => {
1471 let mag_a = a.0.hypot(a.1);
1472 let mag_b = b.0.hypot(b.1);
1473 let mag_cmp = compare_f64(mag_a, mag_b);
1474 if mag_cmp != Ordering::Equal {
1475 return mag_cmp;
1476 }
1477 let re_cmp = compare_f64(a.0, b.0);
1478 if re_cmp != Ordering::Equal {
1479 return re_cmp;
1480 }
1481 compare_f64(a.1, b.1)
1482 }
1483 }
1484}
1485
1486fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1487 for (lhs, rhs) in a.iter().zip(b.iter()) {
1488 let ord = compare_complex(*lhs, *rhs);
1489 if ord != Ordering::Equal {
1490 return ord;
1491 }
1492 }
1493 Ordering::Equal
1494}
1495
/// Lexicographic comparison of two character rows by code point.
///
/// Only the overlapping prefix is compared, so a row and its own prefix
/// compare `Equal`.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1505
/// Lexicographic comparison of two string rows, element by element.
///
/// Only the overlapping prefix is compared, so a row and its own prefix
/// compare `Equal`.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .fold(Ordering::Equal, |decided, (lhs, rhs)| decided.then(lhs.cmp(rhs)))
}
1515
/// Unit tests for the `intersect` builtin: the element-wise and `'rows'`
/// paths, ordering options, option validation, type resolution, and the
/// GPU gather path.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ResolveContext, Type};

    /// Drives the async `evaluate` entry point to completion on the current thread.
    fn evaluate_sync(
        a: Value,
        b: Value,
        rest: &[Value],
    ) -> crate::BuiltinResult<IntersectEvaluation> {
        futures::executor::block_on(evaluate(a, b, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_sorted() {
        // Duplicate 5.0 in `a` must not appear in the result; `ia`/`ib` are 1-based.
        let a = Tensor::new(vec![5.0, 7.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![7.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.data, vec![1.0, 7.0]);
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![4.0, 2.0]);
        assert_eq!(ib.data, vec![2.0, 1.0]);
    }

    #[test]
    fn intersect_type_resolver_numeric() {
        assert_eq!(
            set_values_output_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    // NOTE(review): despite the name, this exercises a cell-of-string input
    // type rather than a string-array type.
    #[test]
    fn intersect_type_resolver_string_array() {
        assert_eq!(
            set_values_output_type(
                &[Type::cell_of(Type::String)],
                &ResolveContext::new(Vec::new()),
            ),
            Type::cell_of(Type::String)
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_stable() {
        // 'stable' keeps the order of first appearance in `a`.
        let a = Tensor::new(vec![4.0, 2.0, 4.0, 1.0, 3.0], vec![5, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 4.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Stable,
            },
        )
        .expect("intersect");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.data, vec![4.0, 1.0, 3.0]);
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0, 4.0, 5.0]);
        assert_eq!(ib.data, vec![2.0, 4.0, 1.0]);
    }

    // Pins the current behaviour that NaN values match each other and appear
    // once in the output. NOTE(review): MATLAB's set functions treat NaNs as
    // distinct — confirm this divergence is intended.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_handles_nan() {
        let a = Tensor::new(vec![f64::NAN, 1.0, f64::NAN], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![2.0, f64::NAN], vec![2, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.data.len(), 1);
        assert!(values.data[0].is_nan());
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0]);
        assert_eq!(ib.data, vec![2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_complex_with_real_inputs() {
        // The real operand is promoted to complex with zero imaginary parts.
        let complex =
            ComplexTensor::new(vec![(1.0, 0.0), (2.0, 0.0), (3.0, 1.0)], vec![3, 1]).unwrap();
        let real = Tensor::new(vec![2.0, 4.0, 1.0], vec![3, 1]).unwrap();
        let real_complex = tensor_to_complex_owned("intersect", real).unwrap();
        let eval = intersect_complex(
            complex,
            real_complex,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect complex");
        match eval.values_value() {
            Value::ComplexTensor(t) => {
                assert_eq!(t.data, vec![(1.0, 0.0), (2.0, 0.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0, 2.0]);
        assert_eq!(ib.data, vec![3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_numeric_rows_default() {
        // Column-major data: `a` rows are [1 2], [3 4], [1 2]; `b` rows are [1 2], [5 6].
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 5.0, 2.0, 6.0], vec![2, 2]).unwrap();
        let eval = intersect_numeric_rows(
            a,
            b,
            &IntersectOptions {
                rows: true,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect rows");
        let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
        assert_eq!(values.shape, vec![1, 2]);
        assert_eq!(values.data, vec![1.0, 2.0]);
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![1.0]);
        assert_eq!(ib.data, vec![1.0]);
    }

    // Both inputs are 1x3 rows, yet the result is pinned as a 2x1 column.
    // NOTE(review): MATLAB returns a row vector when both inputs are rows —
    // confirm the column orientation is intended.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_char_elements_basic() {
        let a = CharArray::new("cab".chars().collect(), 1, 3).unwrap();
        let b = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
        // `find_char_index` reports zero-based positions.
        assert_eq!(find_char_index(&b, 'b'), Some(0));
        assert_eq!(find_char_index(&b, 'c'), Some(1));
        let b_for_eval = CharArray::new("bcd".chars().collect(), 1, 3).unwrap();
        let eval = intersect_char_elements(
            a,
            b_for_eval,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .expect("intersect char");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 2);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['b', 'c']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![3.0, 1.0]);
        assert_eq!(ib.data, vec![1.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_string_elements_stable() {
        let a = StringArray::new(
            vec!["apple".into(), "orange".into(), "pear".into()],
            vec![3, 1],
        )
        .unwrap();
        let b = StringArray::new(
            vec!["pear".into(), "grape".into(), "orange".into()],
            vec![3, 1],
        )
        .unwrap();
        let eval = intersect_string_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Stable,
            },
        )
        .expect("intersect string");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![2, 1]);
                assert_eq!(arr.data, vec!["orange".to_string(), "pear".to_string()]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
        let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
        assert_eq!(ia.data, vec![2.0, 3.0]);
        assert_eq!(ib.data, vec![3.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_rejects_legacy_option() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = evaluate_sync(
            Value::Tensor(tensor.clone()),
            Value::Tensor(tensor),
            &[Value::from("legacy")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            INTERSECT_ERROR_LEGACY_OPTION_UNSUPPORTED.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_rejects_conflicting_order_options() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate_sync(
            Value::Tensor(tensor.clone()),
            Value::Tensor(tensor),
            &[Value::from("stable"), Value::from("sorted")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            INTERSECT_ERROR_CONFLICTING_ORDER_OPTIONS.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_rejects_unknown_option() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = evaluate_sync(
            Value::Tensor(tensor.clone()),
            Value::Tensor(tensor),
            &[Value::from("bogus")],
        )
        .unwrap_err();
        assert_eq!(err.identifier(), INTERSECT_ERROR_UNKNOWN_OPTION.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_rows_dimension_mismatch() {
        // 'rows' requires both inputs to have the same number of columns.
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = intersect_numeric_rows(
            a,
            b,
            &IntersectOptions {
                rows: true,
                order: IntersectOrder::Sorted,
            },
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            INTERSECT_ERROR_ROWS_COLUMN_MISMATCH.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_mixed_types_error() {
        let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let b = CharArray::new(vec!['a', 'b'], 1, 2).unwrap();
        let err = intersect_host(
            Value::Tensor(a),
            Value::CharArray(b),
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            INTERSECT_ERROR_UNSUPPORTED_INPUT_TYPE.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_gpu_roundtrip() {
        // Uploads both operands to the test provider and checks `evaluate`
        // returns the same results as the host path would.
        test_support::with_test_provider(|provider| {
            let a = Tensor::new(vec![4.0, 1.0, 2.0, 1.0], vec![4, 1]).unwrap();
            let b = Tensor::new(vec![2.0, 5.0, 1.0], vec![3, 1]).unwrap();
            let view_a = HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let view_b = HostTensorView {
                data: &b.data,
                shape: &b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload A");
            let handle_b = provider.upload(&view_b).expect("upload B");
            let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("intersect");
            let values = tensor::value_into_tensor_for("intersect", eval.values_value()).unwrap();
            assert_eq!(values.data, vec![1.0, 2.0]);
            let ia = tensor::value_into_tensor_for("intersect", eval.ia_value()).unwrap();
            let ib = tensor::value_into_tensor_for("intersect", eval.ib_value()).unwrap();
            assert_eq!(ia.data, vec![2.0, 3.0]);
            assert_eq!(ib.data, vec![3.0, 1.0]);
        });
    }

    // Exercises `into_pair` / `into_triple` on an evaluation built directly by
    // `intersect_numeric_elements`; despite the name, `evaluate` is not called.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn intersect_two_outputs_from_evaluate() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
        let eval = intersect_numeric_elements(
            a,
            b,
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .unwrap();
        let (_c, ia) = eval.clone().into_pair();
        let ia_tensor = tensor::value_into_tensor_for("intersect", ia).unwrap();
        assert_eq!(ia_tensor.data, vec![1.0, 3.0]);
        let (_c, ia2, ib2) = eval.into_triple();
        let ia_tensor2 = tensor::value_into_tensor_for("intersect", ia2).unwrap();
        let ib_tensor2 = tensor::value_into_tensor_for("intersect", ib2).unwrap();
        assert_eq!(ia_tensor2.data, vec![1.0, 3.0]);
        assert_eq!(ib_tensor2.data, vec![2.0, 1.0]);
    }

    // Compares the host path against `evaluate` on tensors uploaded to the
    // registered wgpu provider; only built with the `wgpu` feature.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn intersect_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = Tensor::new(vec![4.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![2.0, 6.0, 3.0], vec![3, 1]).unwrap();

        let cpu_eval = intersect_numeric_elements(
            a.clone(),
            b.clone(),
            &IntersectOptions {
                rows: false,
                order: IntersectOrder::Sorted,
            },
        )
        .unwrap();
        let cpu_values =
            tensor::value_into_tensor_for("intersect", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("intersect", cpu_eval.ia_value()).unwrap();
        let cpu_ib = tensor::value_into_tensor_for("intersect", cpu_eval.ib_value()).unwrap();

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
            .expect("intersect");
        let gpu_values =
            tensor::value_into_tensor_for("intersect", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("intersect", gpu_eval.ia_value()).unwrap();
        let gpu_ib = tensor::value_into_tensor_for("intersect", gpu_eval.ib_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ib.data, cpu_ib.data);
    }
}