1use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::{
12 GpuTensorHandle, GpuTensorStorage, HostTensorOwned, SetdiffOptions, SetdiffOrder, SetdiffResult,
13};
14use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
15use runmat_macros::runtime_builtin;
16
17use super::type_resolvers::set_values_output_type;
18use crate::build_runtime_error;
19use crate::builtins::common::arg_tokens::tokens_from_values;
20use crate::builtins::common::gpu_helpers;
21use crate::builtins::common::random_args::complex_tensor_into_value;
22use crate::builtins::common::spec::{
23 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
24 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
25};
26use crate::builtins::common::tensor;
27
/// GPU dispatch metadata for `setdiff`.
///
/// Declares a custom provider hook named `setdiff`. At runtime
/// `setdiff_gpu_pair` first asks the active acceleration provider for
/// `setdiff`; when no provider is registered or the call returns an error, both
/// operands are gathered and the difference is computed on the host.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::setdiff")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "setdiff",
    op_kind: GpuOpKind::Custom("setdiff"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("setdiff")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // NOTE(review): presumably tells the planner results are materialised on the
    // host rather than kept device-resident — confirm against `ResidencyPolicy` docs.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may implement `setdiff`; until then tensors are gathered and processed on the host.",
};
43
/// Fusion metadata for `setdiff`.
///
/// No elementwise or reduction template is supplied (both are `None`); per the
/// notes string, `setdiff` terminates fusion chains and produces its result on
/// the host.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::setdiff"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "setdiff",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`setdiff` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
};
56
57fn setdiff_error(message: impl Into<String>) -> crate::RuntimeError {
58 build_runtime_error(message).with_builtin("setdiff").build()
59}
60
61#[runtime_builtin(
62 name = "setdiff",
63 category = "array/sorting_sets",
64 summary = "Return the values that appear in the first input but not the second.",
65 keywords = "setdiff,difference,stable,rows,indices,gpu",
66 accel = "array_construct",
67 sink = true,
68 type_resolver(set_values_output_type),
69 builtin_path = "crate::builtins::array::sorting_sets::setdiff"
70)]
71async fn setdiff_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
72 Ok(evaluate(a, b, &rest).await?.into_values_value())
73}
74
75pub async fn evaluate(
77 a: Value,
78 b: Value,
79 rest: &[Value],
80) -> crate::BuiltinResult<SetdiffEvaluation> {
81 let opts = parse_options(rest)?;
82 match (a, b) {
83 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
84 setdiff_gpu_pair(handle_a, handle_b, &opts).await
85 }
86 (Value::GpuTensor(handle_a), other) => {
87 setdiff_gpu_mixed(handle_a, other, &opts, true).await
88 }
89 (other, Value::GpuTensor(handle_b)) => {
90 setdiff_gpu_mixed(handle_b, other, &opts, false).await
91 }
92 (left, right) => setdiff_host(left, right, &opts),
93 }
94}
95
96fn parse_options(rest: &[Value]) -> crate::BuiltinResult<SetdiffOptions> {
97 let mut opts = SetdiffOptions {
98 rows: false,
99 order: SetdiffOrder::Sorted,
100 };
101 let mut seen_order: Option<SetdiffOrder> = None;
102
103 let tokens = tokens_from_values(rest);
104 for (arg, token) in rest.iter().zip(tokens.iter()) {
105 let text = match token {
106 crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
107 _ => {
108 let text = tensor::value_to_string(arg)
109 .ok_or_else(|| setdiff_error("setdiff: expected string option arguments"))?;
110 let lowered = text.trim().to_ascii_lowercase();
111 parse_setdiff_option(&mut opts, &mut seen_order, &lowered)?;
112 continue;
113 }
114 };
115 parse_setdiff_option(&mut opts, &mut seen_order, text)?;
116 }
117
118 Ok(opts)
119}
120
121fn parse_setdiff_option(
122 opts: &mut SetdiffOptions,
123 seen_order: &mut Option<SetdiffOrder>,
124 lowered: &str,
125) -> crate::BuiltinResult<()> {
126 match lowered {
127 "rows" => opts.rows = true,
128 "sorted" => {
129 if let Some(prev) = seen_order {
130 if *prev != SetdiffOrder::Sorted {
131 return Err(setdiff_error(
132 "setdiff: cannot combine 'sorted' with 'stable'",
133 ));
134 }
135 }
136 *seen_order = Some(SetdiffOrder::Sorted);
137 opts.order = SetdiffOrder::Sorted;
138 }
139 "stable" => {
140 if let Some(prev) = seen_order {
141 if *prev != SetdiffOrder::Stable {
142 return Err(setdiff_error(
143 "setdiff: cannot combine 'sorted' with 'stable'",
144 ));
145 }
146 }
147 *seen_order = Some(SetdiffOrder::Stable);
148 opts.order = SetdiffOrder::Stable;
149 }
150 "legacy" | "r2012a" => {
151 return Err(setdiff_error(
152 "setdiff: the 'legacy' behaviour is not supported",
153 ));
154 }
155 other => {
156 return Err(setdiff_error(format!(
157 "setdiff: unrecognised option '{other}'"
158 )))
159 }
160 }
161 Ok(())
162}
163
164async fn setdiff_gpu_pair(
165 handle_a: GpuTensorHandle,
166 handle_b: GpuTensorHandle,
167 opts: &SetdiffOptions,
168) -> crate::BuiltinResult<SetdiffEvaluation> {
169 if let Some(provider) = runmat_accelerate_api::provider() {
170 match provider.setdiff(&handle_a, &handle_b, opts).await {
171 Ok(result) => return SetdiffEvaluation::from_setdiff_result(result),
172 Err(_) => {
173 }
175 }
176 }
177 let a_tensor = gpu_helpers::gather_tensor_async(&handle_a).await?;
178 let b_tensor = gpu_helpers::gather_tensor_async(&handle_b).await?;
179 setdiff_numeric(a_tensor, b_tensor, opts)
180}
181
182async fn setdiff_gpu_mixed(
183 handle_gpu: GpuTensorHandle,
184 other: Value,
185 opts: &SetdiffOptions,
186 gpu_is_a: bool,
187) -> crate::BuiltinResult<SetdiffEvaluation> {
188 let gpu_tensor = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
189 let other_tensor =
190 tensor::value_into_tensor_for("setdiff", other).map_err(|e| setdiff_error(e))?;
191 if gpu_is_a {
192 setdiff_numeric(gpu_tensor, other_tensor, opts)
193 } else {
194 setdiff_numeric(other_tensor, gpu_tensor, opts)
195 }
196}
197
198fn setdiff_host(
199 a: Value,
200 b: Value,
201 opts: &SetdiffOptions,
202) -> crate::BuiltinResult<SetdiffEvaluation> {
203 match (a, b) {
204 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => setdiff_complex(at, bt, opts),
205 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
206 let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
207 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
208 setdiff_complex(at, bt, opts)
209 }
210 (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
211 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
212 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
213 setdiff_complex(at, bt, opts)
214 }
215 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
216 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
217 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
218 let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
219 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
220 setdiff_complex(at, bt, opts)
221 }
222
223 (Value::CharArray(ac), Value::CharArray(bc)) => setdiff_char(ac, bc, opts),
224
225 (Value::StringArray(astring), Value::StringArray(bstring)) => {
226 setdiff_string(astring, bstring, opts)
227 }
228 (Value::StringArray(astring), Value::String(b)) => {
229 let bstring = StringArray::new(vec![b], vec![1, 1])
230 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
231 setdiff_string(astring, bstring, opts)
232 }
233 (Value::String(a), Value::StringArray(bstring)) => {
234 let astring = StringArray::new(vec![a], vec![1, 1])
235 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
236 setdiff_string(astring, bstring, opts)
237 }
238 (Value::String(a), Value::String(b)) => {
239 let astring = StringArray::new(vec![a], vec![1, 1])
240 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
241 let bstring = StringArray::new(vec![b], vec![1, 1])
242 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
243 setdiff_string(astring, bstring, opts)
244 }
245
246 (left, right) => {
247 let tensor_a =
248 tensor::value_into_tensor_for("setdiff", left).map_err(|e| setdiff_error(e))?;
249 let tensor_b =
250 tensor::value_into_tensor_for("setdiff", right).map_err(|e| setdiff_error(e))?;
251 setdiff_numeric(tensor_a, tensor_b, opts)
252 }
253 }
254}
255
256fn setdiff_numeric(
257 a: Tensor,
258 b: Tensor,
259 opts: &SetdiffOptions,
260) -> crate::BuiltinResult<SetdiffEvaluation> {
261 if opts.rows {
262 setdiff_numeric_rows(a, b, opts)
263 } else {
264 setdiff_numeric_elements(a, b, opts)
265 }
266}
267
268pub fn setdiff_numeric_from_tensors(
270 a: Tensor,
271 b: Tensor,
272 opts: &SetdiffOptions,
273) -> crate::BuiltinResult<SetdiffEvaluation> {
274 setdiff_numeric(a, b, opts)
275}
276
277fn setdiff_numeric_elements(
278 a: Tensor,
279 b: Tensor,
280 opts: &SetdiffOptions,
281) -> crate::BuiltinResult<SetdiffEvaluation> {
282 let mut b_keys: HashSet<u64> = HashSet::new();
283 for &value in &b.data {
284 b_keys.insert(canonicalize_f64(value));
285 }
286
287 let mut seen: HashMap<u64, usize> = HashMap::new();
288 let mut entries = Vec::<NumericDiffEntry>::new();
289 let mut order_counter = 0usize;
290
291 for (idx, &value) in a.data.iter().enumerate() {
292 let key = canonicalize_f64(value);
293 if b_keys.contains(&key) {
294 continue;
295 }
296 if seen.contains_key(&key) {
297 continue;
298 }
299 let entry_idx = entries.len();
300 entries.push(NumericDiffEntry {
301 value,
302 index: idx,
303 order_rank: order_counter,
304 });
305 seen.insert(key, entry_idx);
306 order_counter += 1;
307 }
308
309 assemble_numeric_setdiff(entries, opts)
310}
311
312fn setdiff_numeric_rows(
313 a: Tensor,
314 b: Tensor,
315 opts: &SetdiffOptions,
316) -> crate::BuiltinResult<SetdiffEvaluation> {
317 if a.shape.len() != 2 || b.shape.len() != 2 {
318 return Err(setdiff_error(
319 "setdiff: 'rows' option requires 2-D numeric matrices",
320 ));
321 }
322 if a.shape[1] != b.shape[1] {
323 return Err(setdiff_error(
324 "setdiff: inputs must have the same number of columns when using 'rows'",
325 ));
326 }
327
328 let rows_a = a.shape[0];
329 let rows_b = b.shape[0];
330 let cols = a.shape[1];
331
332 let mut b_keys: HashSet<NumericRowKey> = HashSet::new();
333 for r in 0..rows_b {
334 let mut row_values = Vec::with_capacity(cols);
335 for c in 0..cols {
336 let idx = r + c * rows_b;
337 row_values.push(b.data[idx]);
338 }
339 b_keys.insert(NumericRowKey::from_slice(&row_values));
340 }
341
342 let mut seen: HashSet<NumericRowKey> = HashSet::new();
343 let mut entries = Vec::<NumericRowDiffEntry>::new();
344 let mut order_counter = 0usize;
345
346 for r in 0..rows_a {
347 let mut row_values = Vec::with_capacity(cols);
348 for c in 0..cols {
349 let idx = r + c * rows_a;
350 row_values.push(a.data[idx]);
351 }
352 let key = NumericRowKey::from_slice(&row_values);
353 if b_keys.contains(&key) {
354 continue;
355 }
356 if !seen.insert(key) {
357 continue;
358 }
359 entries.push(NumericRowDiffEntry {
360 row_data: row_values,
361 row_index: r,
362 order_rank: order_counter,
363 });
364 order_counter += 1;
365 }
366
367 assemble_numeric_row_setdiff(entries, opts, cols)
368}
369
370fn setdiff_complex(
371 a: ComplexTensor,
372 b: ComplexTensor,
373 opts: &SetdiffOptions,
374) -> crate::BuiltinResult<SetdiffEvaluation> {
375 if opts.rows {
376 setdiff_complex_rows(a, b, opts)
377 } else {
378 setdiff_complex_elements(a, b, opts)
379 }
380}
381
382fn setdiff_complex_elements(
383 a: ComplexTensor,
384 b: ComplexTensor,
385 opts: &SetdiffOptions,
386) -> crate::BuiltinResult<SetdiffEvaluation> {
387 let mut b_keys: HashSet<ComplexKey> = HashSet::new();
388 for &value in &b.data {
389 b_keys.insert(ComplexKey::new(value));
390 }
391
392 let mut seen: HashSet<ComplexKey> = HashSet::new();
393 let mut entries = Vec::<ComplexDiffEntry>::new();
394 let mut order_counter = 0usize;
395
396 for (idx, &value) in a.data.iter().enumerate() {
397 let key = ComplexKey::new(value);
398 if b_keys.contains(&key) {
399 continue;
400 }
401 if !seen.insert(key) {
402 continue;
403 }
404 entries.push(ComplexDiffEntry {
405 value,
406 index: idx,
407 order_rank: order_counter,
408 });
409 order_counter += 1;
410 }
411
412 assemble_complex_setdiff(entries, opts)
413}
414
415fn setdiff_complex_rows(
416 a: ComplexTensor,
417 b: ComplexTensor,
418 opts: &SetdiffOptions,
419) -> crate::BuiltinResult<SetdiffEvaluation> {
420 if a.shape.len() != 2 || b.shape.len() != 2 {
421 return Err(setdiff_error(
422 "setdiff: 'rows' option requires 2-D complex matrices",
423 ));
424 }
425 if a.shape[1] != b.shape[1] {
426 return Err(setdiff_error(
427 "setdiff: inputs must have the same number of columns when using 'rows'",
428 ));
429 }
430
431 let rows_a = a.shape[0];
432 let rows_b = b.shape[0];
433 let cols = a.shape[1];
434
435 let mut b_keys: HashSet<Vec<ComplexKey>> = HashSet::new();
436 for r in 0..rows_b {
437 let mut key_row = Vec::with_capacity(cols);
438 for c in 0..cols {
439 let idx = r + c * rows_b;
440 key_row.push(ComplexKey::new(b.data[idx]));
441 }
442 b_keys.insert(key_row);
443 }
444
445 let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
446 let mut entries = Vec::<ComplexRowDiffEntry>::new();
447 let mut order_counter = 0usize;
448
449 for r in 0..rows_a {
450 let mut row_values = Vec::with_capacity(cols);
451 let mut key_row = Vec::with_capacity(cols);
452 for c in 0..cols {
453 let idx = r + c * rows_a;
454 let value = a.data[idx];
455 row_values.push(value);
456 key_row.push(ComplexKey::new(value));
457 }
458 if b_keys.contains(&key_row) {
459 continue;
460 }
461 if !seen.insert(key_row) {
462 continue;
463 }
464 entries.push(ComplexRowDiffEntry {
465 row_data: row_values,
466 row_index: r,
467 order_rank: order_counter,
468 });
469 order_counter += 1;
470 }
471
472 assemble_complex_row_setdiff(entries, opts, cols)
473}
474
475fn setdiff_char(
476 a: CharArray,
477 b: CharArray,
478 opts: &SetdiffOptions,
479) -> crate::BuiltinResult<SetdiffEvaluation> {
480 if opts.rows {
481 setdiff_char_rows(a, b, opts)
482 } else {
483 setdiff_char_elements(a, b, opts)
484 }
485}
486
487fn setdiff_char_elements(
488 a: CharArray,
489 b: CharArray,
490 opts: &SetdiffOptions,
491) -> crate::BuiltinResult<SetdiffEvaluation> {
492 let mut b_keys: HashSet<u32> = HashSet::new();
493 for ch in &b.data {
494 b_keys.insert(*ch as u32);
495 }
496
497 let mut seen: HashSet<u32> = HashSet::new();
498 let mut entries = Vec::<CharDiffEntry>::new();
499 let mut order_counter = 0usize;
500
501 for col in 0..a.cols {
502 for row in 0..a.rows {
503 let linear_idx = row + col * a.rows;
504 let data_idx = row * a.cols + col;
505 let ch = a.data[data_idx];
506 let key = ch as u32;
507 if b_keys.contains(&key) {
508 continue;
509 }
510 if !seen.insert(key) {
511 continue;
512 }
513 entries.push(CharDiffEntry {
514 ch,
515 index: linear_idx,
516 order_rank: order_counter,
517 });
518 order_counter += 1;
519 }
520 }
521
522 assemble_char_setdiff(entries, opts)
523}
524
525fn setdiff_char_rows(
526 a: CharArray,
527 b: CharArray,
528 opts: &SetdiffOptions,
529) -> crate::BuiltinResult<SetdiffEvaluation> {
530 if a.cols != b.cols {
531 return Err(setdiff_error(
532 "setdiff: inputs must have the same number of columns when using 'rows'",
533 ));
534 }
535
536 let rows_a = a.rows;
537 let rows_b = b.rows;
538 let cols = a.cols;
539
540 let mut b_keys: HashSet<RowCharKey> = HashSet::new();
541 for r in 0..rows_b {
542 let mut row_values = Vec::with_capacity(cols);
543 for c in 0..cols {
544 let idx = r * cols + c;
545 row_values.push(b.data[idx]);
546 }
547 b_keys.insert(RowCharKey::from_slice(&row_values));
548 }
549
550 let mut seen: HashSet<RowCharKey> = HashSet::new();
551 let mut entries = Vec::<CharRowDiffEntry>::new();
552 let mut order_counter = 0usize;
553
554 for r in 0..rows_a {
555 let mut row_values = Vec::with_capacity(cols);
556 for c in 0..cols {
557 let idx = r * cols + c;
558 row_values.push(a.data[idx]);
559 }
560 let key = RowCharKey::from_slice(&row_values);
561 if b_keys.contains(&key) {
562 continue;
563 }
564 if !seen.insert(key) {
565 continue;
566 }
567 entries.push(CharRowDiffEntry {
568 row_data: row_values,
569 row_index: r,
570 order_rank: order_counter,
571 });
572 order_counter += 1;
573 }
574
575 assemble_char_row_setdiff(entries, opts, cols)
576}
577
578fn setdiff_string(
579 a: StringArray,
580 b: StringArray,
581 opts: &SetdiffOptions,
582) -> crate::BuiltinResult<SetdiffEvaluation> {
583 if opts.rows {
584 setdiff_string_rows(a, b, opts)
585 } else {
586 setdiff_string_elements(a, b, opts)
587 }
588}
589
590fn setdiff_string_elements(
591 a: StringArray,
592 b: StringArray,
593 opts: &SetdiffOptions,
594) -> crate::BuiltinResult<SetdiffEvaluation> {
595 let mut b_keys: HashSet<String> = HashSet::new();
596 for value in &b.data {
597 b_keys.insert(value.clone());
598 }
599
600 let mut seen: HashSet<String> = HashSet::new();
601 let mut entries = Vec::<StringDiffEntry>::new();
602 let mut order_counter = 0usize;
603
604 for (idx, value) in a.data.iter().enumerate() {
605 if b_keys.contains(value) {
606 continue;
607 }
608 if !seen.insert(value.clone()) {
609 continue;
610 }
611 entries.push(StringDiffEntry {
612 value: value.clone(),
613 index: idx,
614 order_rank: order_counter,
615 });
616 order_counter += 1;
617 }
618
619 assemble_string_setdiff(entries, opts)
620}
621
622fn setdiff_string_rows(
623 a: StringArray,
624 b: StringArray,
625 opts: &SetdiffOptions,
626) -> crate::BuiltinResult<SetdiffEvaluation> {
627 if a.shape.len() != 2 || b.shape.len() != 2 {
628 return Err(setdiff_error(
629 "setdiff: 'rows' option requires 2-D string arrays",
630 ));
631 }
632 if a.shape[1] != b.shape[1] {
633 return Err(setdiff_error(
634 "setdiff: inputs must have the same number of columns when using 'rows'",
635 ));
636 }
637
638 let rows_a = a.shape[0];
639 let rows_b = b.shape[0];
640 let cols = a.shape[1];
641
642 let mut b_keys: HashSet<RowStringKey> = HashSet::new();
643 for r in 0..rows_b {
644 let mut row_values = Vec::with_capacity(cols);
645 for c in 0..cols {
646 let idx = r + c * rows_b;
647 row_values.push(b.data[idx].clone());
648 }
649 b_keys.insert(RowStringKey(row_values.clone()));
650 }
651
652 let mut seen: HashSet<RowStringKey> = HashSet::new();
653 let mut entries = Vec::<StringRowDiffEntry>::new();
654 let mut order_counter = 0usize;
655
656 for r in 0..rows_a {
657 let mut row_values = Vec::with_capacity(cols);
658 for c in 0..cols {
659 let idx = r + c * rows_a;
660 row_values.push(a.data[idx].clone());
661 }
662 let key = RowStringKey(row_values.clone());
663 if b_keys.contains(&key) {
664 continue;
665 }
666 if !seen.insert(key) {
667 continue;
668 }
669 entries.push(StringRowDiffEntry {
670 row_data: row_values,
671 row_index: r,
672 order_rank: order_counter,
673 });
674 order_counter += 1;
675 }
676
677 assemble_string_row_setdiff(entries, opts, cols)
678}
679
680fn assemble_numeric_setdiff(
681 entries: Vec<NumericDiffEntry>,
682 opts: &SetdiffOptions,
683) -> crate::BuiltinResult<SetdiffEvaluation> {
684 let mut order: Vec<usize> = (0..entries.len()).collect();
685 match opts.order {
686 SetdiffOrder::Sorted => {
687 order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
688 }
689 SetdiffOrder::Stable => {
690 order.sort_by_key(|&idx| entries[idx].order_rank);
691 }
692 }
693
694 let mut values = Vec::with_capacity(order.len());
695 let mut ia = Vec::with_capacity(order.len());
696 for &idx in &order {
697 let entry = &entries[idx];
698 values.push(entry.value);
699 ia.push((entry.index + 1) as f64);
700 }
701
702 let value_tensor = Tensor::new(values, vec![order.len(), 1])
703 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
704 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
705 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
706
707 Ok(SetdiffEvaluation::new(
708 Value::Tensor(value_tensor),
709 ia_tensor,
710 ))
711}
712
713fn assemble_numeric_row_setdiff(
714 entries: Vec<NumericRowDiffEntry>,
715 opts: &SetdiffOptions,
716 cols: usize,
717) -> crate::BuiltinResult<SetdiffEvaluation> {
718 let mut order: Vec<usize> = (0..entries.len()).collect();
719 match opts.order {
720 SetdiffOrder::Sorted => {
721 order.sort_by(|&lhs, &rhs| {
722 compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
723 });
724 }
725 SetdiffOrder::Stable => {
726 order.sort_by_key(|&idx| entries[idx].order_rank);
727 }
728 }
729
730 let unique_rows = order.len();
731 let mut values = vec![0.0f64; unique_rows * cols];
732 let mut ia = Vec::with_capacity(unique_rows);
733
734 for (row_pos, &entry_idx) in order.iter().enumerate() {
735 let entry = &entries[entry_idx];
736 for col in 0..cols {
737 let dest = row_pos + col * unique_rows;
738 values[dest] = entry.row_data[col];
739 }
740 ia.push((entry.row_index + 1) as f64);
741 }
742
743 let value_tensor = Tensor::new(values, vec![unique_rows, cols])
744 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
745 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
746 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
747
748 Ok(SetdiffEvaluation::new(
749 Value::Tensor(value_tensor),
750 ia_tensor,
751 ))
752}
753
754fn assemble_complex_setdiff(
755 entries: Vec<ComplexDiffEntry>,
756 opts: &SetdiffOptions,
757) -> crate::BuiltinResult<SetdiffEvaluation> {
758 let mut order: Vec<usize> = (0..entries.len()).collect();
759 match opts.order {
760 SetdiffOrder::Sorted => {
761 order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
762 }
763 SetdiffOrder::Stable => {
764 order.sort_by_key(|&idx| entries[idx].order_rank);
765 }
766 }
767
768 let mut values = Vec::with_capacity(order.len());
769 let mut ia = Vec::with_capacity(order.len());
770 for &idx in &order {
771 let entry = &entries[idx];
772 values.push(entry.value);
773 ia.push((entry.index + 1) as f64);
774 }
775
776 let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
777 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
778 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
779 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
780
781 Ok(SetdiffEvaluation::new(
782 complex_tensor_into_value(value_tensor),
783 ia_tensor,
784 ))
785}
786
787fn assemble_complex_row_setdiff(
788 entries: Vec<ComplexRowDiffEntry>,
789 opts: &SetdiffOptions,
790 cols: usize,
791) -> crate::BuiltinResult<SetdiffEvaluation> {
792 let mut order: Vec<usize> = (0..entries.len()).collect();
793 match opts.order {
794 SetdiffOrder::Sorted => {
795 order.sort_by(|&lhs, &rhs| {
796 compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
797 });
798 }
799 SetdiffOrder::Stable => {
800 order.sort_by_key(|&idx| entries[idx].order_rank);
801 }
802 }
803
804 let unique_rows = order.len();
805 let mut values = vec![(0.0f64, 0.0f64); unique_rows * cols];
806 let mut ia = Vec::with_capacity(unique_rows);
807
808 for (row_pos, &entry_idx) in order.iter().enumerate() {
809 let entry = &entries[entry_idx];
810 for col in 0..cols {
811 let dest = row_pos + col * unique_rows;
812 values[dest] = entry.row_data[col];
813 }
814 ia.push((entry.row_index + 1) as f64);
815 }
816
817 let value_tensor = ComplexTensor::new(values, vec![unique_rows, cols])
818 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
819 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
820 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
821
822 Ok(SetdiffEvaluation::new(
823 complex_tensor_into_value(value_tensor),
824 ia_tensor,
825 ))
826}
827
828fn assemble_char_setdiff(
829 entries: Vec<CharDiffEntry>,
830 opts: &SetdiffOptions,
831) -> crate::BuiltinResult<SetdiffEvaluation> {
832 let mut order: Vec<usize> = (0..entries.len()).collect();
833 match opts.order {
834 SetdiffOrder::Sorted => {
835 order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
836 }
837 SetdiffOrder::Stable => {
838 order.sort_by_key(|&idx| entries[idx].order_rank);
839 }
840 }
841
842 let mut values = Vec::with_capacity(order.len());
843 let mut ia = Vec::with_capacity(order.len());
844 for &idx in &order {
845 let entry = &entries[idx];
846 values.push(entry.ch);
847 ia.push((entry.index + 1) as f64);
848 }
849
850 let value_array = CharArray::new(values, order.len(), 1)
851 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
852 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
853 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
854
855 Ok(SetdiffEvaluation::new(
856 Value::CharArray(value_array),
857 ia_tensor,
858 ))
859}
860
861fn assemble_char_row_setdiff(
862 entries: Vec<CharRowDiffEntry>,
863 opts: &SetdiffOptions,
864 cols: usize,
865) -> crate::BuiltinResult<SetdiffEvaluation> {
866 let mut order: Vec<usize> = (0..entries.len()).collect();
867 match opts.order {
868 SetdiffOrder::Sorted => {
869 order.sort_by(|&lhs, &rhs| {
870 compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
871 });
872 }
873 SetdiffOrder::Stable => {
874 order.sort_by_key(|&idx| entries[idx].order_rank);
875 }
876 }
877
878 let unique_rows = order.len();
879 let mut values = vec!['\0'; unique_rows * cols];
880 let mut ia = Vec::with_capacity(unique_rows);
881
882 for (row_pos, &entry_idx) in order.iter().enumerate() {
883 let entry = &entries[entry_idx];
884 for col in 0..cols {
885 let dest = row_pos * cols + col;
886 values[dest] = entry.row_data[col];
887 }
888 ia.push((entry.row_index + 1) as f64);
889 }
890
891 let value_array = CharArray::new(values, unique_rows, cols)
892 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
893 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
894 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
895
896 Ok(SetdiffEvaluation::new(
897 Value::CharArray(value_array),
898 ia_tensor,
899 ))
900}
901
902fn assemble_string_setdiff(
903 entries: Vec<StringDiffEntry>,
904 opts: &SetdiffOptions,
905) -> crate::BuiltinResult<SetdiffEvaluation> {
906 let mut order: Vec<usize> = (0..entries.len()).collect();
907 match opts.order {
908 SetdiffOrder::Sorted => {
909 order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
910 }
911 SetdiffOrder::Stable => {
912 order.sort_by_key(|&idx| entries[idx].order_rank);
913 }
914 }
915
916 let mut values = Vec::with_capacity(order.len());
917 let mut ia = Vec::with_capacity(order.len());
918 for &idx in &order {
919 let entry = &entries[idx];
920 values.push(entry.value.clone());
921 ia.push((entry.index + 1) as f64);
922 }
923
924 let value_array = StringArray::new(values, vec![order.len(), 1])
925 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
926 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
927 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
928
929 Ok(SetdiffEvaluation::new(
930 Value::StringArray(value_array),
931 ia_tensor,
932 ))
933}
934
935fn assemble_string_row_setdiff(
936 entries: Vec<StringRowDiffEntry>,
937 opts: &SetdiffOptions,
938 cols: usize,
939) -> crate::BuiltinResult<SetdiffEvaluation> {
940 let mut order: Vec<usize> = (0..entries.len()).collect();
941 match opts.order {
942 SetdiffOrder::Sorted => {
943 order.sort_by(|&lhs, &rhs| {
944 compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
945 });
946 }
947 SetdiffOrder::Stable => {
948 order.sort_by_key(|&idx| entries[idx].order_rank);
949 }
950 }
951
952 let unique_rows = order.len();
953 let mut values = vec![String::new(); unique_rows * cols];
954 let mut ia = Vec::with_capacity(unique_rows);
955
956 for (row_pos, &entry_idx) in order.iter().enumerate() {
957 let entry = &entries[entry_idx];
958 for col in 0..cols {
959 let dest = row_pos + col * unique_rows;
960 values[dest] = entry.row_data[col].clone();
961 }
962 ia.push((entry.row_index + 1) as f64);
963 }
964
965 let value_array = StringArray::new(values, vec![unique_rows, cols])
966 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
967 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
968 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
969
970 Ok(SetdiffEvaluation::new(
971 Value::StringArray(value_array),
972 ia_tensor,
973 ))
974}
975
/// A surviving element of a real element-wise `setdiff`.
#[derive(Clone, Copy, Debug)]
struct NumericDiffEntry {
    /// The element value as it appeared in the first input.
    value: f64,
    /// Zero-based position of its first occurrence in the first input's data.
    index: usize,
    /// Rank among survivors in first-occurrence order; drives `'stable'`.
    order_rank: usize,
}

/// A surviving row of a real `'rows'` `setdiff`.
#[derive(Clone, Debug)]
struct NumericRowDiffEntry {
    /// The row's values, one per column.
    row_data: Vec<f64>,
    /// Zero-based row index of its first occurrence in the first input.
    row_index: usize,
    /// Rank among survivors in first-occurrence order; drives `'stable'`.
    order_rank: usize,
}

/// A surviving element of a complex element-wise `setdiff`.
#[derive(Clone, Copy, Debug)]
struct ComplexDiffEntry {
    /// The element as `(re, im)`.
    value: (f64, f64),
    /// Zero-based position of its first occurrence in the first input's data.
    index: usize,
    /// Rank among survivors in first-occurrence order; drives `'stable'`.
    order_rank: usize,
}

/// A surviving row of a complex `'rows'` `setdiff`.
#[derive(Clone, Debug)]
struct ComplexRowDiffEntry {
    /// The row's `(re, im)` values, one per column.
    row_data: Vec<(f64, f64)>,
    /// Zero-based row index of its first occurrence in the first input.
    row_index: usize,
    /// Rank among survivors in first-occurrence order; drives `'stable'`.
    order_rank: usize,
}

/// A surviving character of a char element-wise `setdiff`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct CharDiffEntry {
    /// The character itself.
    ch: char,
    /// Zero-based column-major linear index of its first occurrence.
    index: usize,
    /// Rank among survivors in first-occurrence order; drives `'stable'`.
    order_rank: usize,
}

/// A surviving row of a char `'rows'` `setdiff`.
#[derive(Clone, Debug)]
struct CharRowDiffEntry {
    /// The row's characters, one per column.
    row_data: Vec<char>,
    /// Zero-based row index of its first occurrence in the first input.
    row_index: usize,
    /// Rank among survivors in first-occurrence order; drives `'stable'`.
    order_rank: usize,
}

/// A surviving element of a string element-wise `setdiff`.
#[derive(Clone, Debug)]
struct StringDiffEntry {
    /// The string value (owned copy).
    value: String,
    /// Zero-based position of its first occurrence in the first input's data.
    index: usize,
    /// Rank among survivors in first-occurrence order; drives `'stable'`.
    order_rank: usize,
}

/// A surviving row of a string `'rows'` `setdiff`.
#[derive(Clone, Debug)]
struct StringRowDiffEntry {
    /// The row's strings, one per column.
    row_data: Vec<String>,
    /// Zero-based row index of its first occurrence in the first input.
    row_index: usize,
    /// Rank among survivors in first-occurrence order; drives `'stable'`.
    order_rank: usize,
}
1031
1032#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1033struct NumericRowKey(Vec<u64>);
1034
1035impl NumericRowKey {
1036 fn from_slice(values: &[f64]) -> Self {
1037 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1038 }
1039}
1040
1041#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1042struct ComplexKey {
1043 re: u64,
1044 im: u64,
1045}
1046
1047impl ComplexKey {
1048 fn new(value: (f64, f64)) -> Self {
1049 Self {
1050 re: canonicalize_f64(value.0),
1051 im: canonicalize_f64(value.1),
1052 }
1053 }
1054}
1055
/// Hashable identity of a char row: the Unicode scalar value of each character.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key for a row of characters.
    fn from_slice(values: &[char]) -> Self {
        Self(values.iter().copied().map(u32::from).collect())
    }
}
1064
/// Hashable identity of a string row: the row's strings in column order,
/// compared with exact (case-sensitive) equality.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1067
1068#[derive(Debug)]
1069pub struct SetdiffEvaluation {
1070 values: Value,
1071 ia: Tensor,
1072}
1073
1074impl SetdiffEvaluation {
1075 fn new(values: Value, ia: Tensor) -> Self {
1076 Self { values, ia }
1077 }
1078
1079 pub fn from_setdiff_result(result: SetdiffResult) -> crate::BuiltinResult<Self> {
1080 let SetdiffResult { values, ia } = result;
1081 let values_tensor = Tensor::new(values.data, values.shape)
1082 .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
1083 let ia_tensor =
1084 Tensor::new(ia.data, ia.shape).map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
1085 Ok(SetdiffEvaluation::new(
1086 Value::Tensor(values_tensor),
1087 ia_tensor,
1088 ))
1089 }
1090
1091 pub fn into_numeric_setdiff_result(self) -> crate::BuiltinResult<SetdiffResult> {
1092 let SetdiffEvaluation { values, ia } = self;
1093 let values_tensor =
1094 tensor::value_into_tensor_for("setdiff", values).map_err(|e| setdiff_error(e))?;
1095 Ok(SetdiffResult {
1096 values: HostTensorOwned {
1097 data: values_tensor.data,
1098 shape: values_tensor.shape,
1099 storage: GpuTensorStorage::Real,
1100 },
1101 ia: HostTensorOwned {
1102 data: ia.data,
1103 shape: ia.shape,
1104 storage: GpuTensorStorage::Real,
1105 },
1106 })
1107 }
1108
1109 pub fn into_values_value(self) -> Value {
1110 self.values
1111 }
1112
1113 pub fn into_pair(self) -> (Value, Value) {
1114 let ia = tensor::tensor_into_value(self.ia);
1115 (self.values, ia)
1116 }
1117
1118 pub fn values_value(&self) -> Value {
1119 self.values.clone()
1120 }
1121
1122 pub fn ia_value(&self) -> Value {
1123 tensor::tensor_into_value(self.ia.clone())
1124 }
1125}
1126
/// Maps an `f64` to a hashable bit pattern for set membership.
///
/// Every NaN maps to a single quiet-NaN pattern, and `-0.0` and `0.0` both map
/// to `0`. All other values use their raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN,
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1136
/// Total ordering for `f64` that places every NaN after all other values.
///
/// NaNs compare equal to each other. Non-NaN values use IEEE comparison, so `-0.0 == 0.0`.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN, so `partial_cmp` always yields `Some`.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1150
1151fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1152 for (lhs, rhs) in a.iter().zip(b.iter()) {
1153 let ord = compare_f64(*lhs, *rhs);
1154 if ord != Ordering::Equal {
1155 return ord;
1156 }
1157 }
1158 Ordering::Equal
1159}
1160
/// Returns `true` when either component of the complex value `(re, im)` is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    value.0.is_nan() || value.1.is_nan()
}

/// Orders two complex values `(re, im)` the way MATLAB's `sort` orders complex data.
///
/// This is the order `setdiff` uses for its sorted output:
///
/// * Values with a NaN component sort after every non-NaN value and compare equal to
///   each other.
/// * Other values are ordered by magnitude (`abs`), then by phase angle on `(-π, π]`.
/// * Remaining ties (same magnitude and angle) fall back to the real part, then the
///   imaginary part, so the ordering stays total and deterministic.
fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
    // Only reached with non-NaN inputs. Magnitudes and angles of non-NaN values are never
    // NaN (hypot/atan2 handle infinities), so the `Equal` fallback is defensive only.
    fn cmp_non_nan(x: f64, y: f64) -> Ordering {
        x.partial_cmp(&y).unwrap_or(Ordering::Equal)
    }

    // Phase angle on (-π, π]. Adding `0.0` turns `-0.0` into `+0.0`, so a signed zero
    // cannot move the angle between -π/π or 0/π. Values that are numerically equal (and
    // hash equal via `canonicalize_f64`) must compare `Equal` here.
    fn phase(z: (f64, f64)) -> f64 {
        (z.1 + 0.0).atan2(z.0 + 0.0)
    }

    match (complex_is_nan(a), complex_is_nan(b)) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => cmp_non_nan(a.0.hypot(a.1), b.0.hypot(b.1))
            .then_with(|| cmp_non_nan(phase(a), phase(b)))
            .then_with(|| cmp_non_nan(a.0, b.0))
            .then_with(|| cmp_non_nan(a.1, b.1)),
    }
}
1185
1186fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1187 for (lhs, rhs) in a.iter().zip(b.iter()) {
1188 let ord = compare_complex(*lhs, *rhs);
1189 if ord != Ordering::Equal {
1190 return ord;
1191 }
1192 }
1193 Ordering::Equal
1194}
1195
/// Lexicographically compares two char rows by Unicode scalar value.
///
/// Only the overlapping prefix is inspected. Rows that agree on it compare `Equal`, even if
/// their lengths differ.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
1205
/// Lexicographically compares two string rows using `String`'s byte-wise ordering.
///
/// Only the overlapping prefix is inspected. The first differing pair decides the result.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    for (left, right) in a.iter().zip(b) {
        match left.cmp(right) {
            Ordering::Equal => continue,
            decided => return decided,
        }
    }
    Ordering::Equal
}
1215
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for `setdiff`. They cover numeric, char and string inputs, the 'rows' and
    // 'stable' options, type resolution, error paths, and GPU-handle inputs.
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CharArray, ResolveContext, StringArray, Tensor, Type, Value};

    // Extracts the message text from a runtime error so tests can match on substrings.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    // Drives the async `evaluate` entry point to completion with `futures::executor::block_on`.
    fn evaluate_sync(
        a: Value,
        b: Value,
        rest: &[Value],
    ) -> crate::BuiltinResult<SetdiffEvaluation> {
        futures::executor::block_on(evaluate(a, b, rest))
    }

    // Default (sorted) mode: only 5 is absent from B, and `ia` reports its first
    // position in A (1-based).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_sorted_default() {
        let a = Tensor::new(vec![5.0, 7.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![7.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("setdiff");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data, vec![5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // Type resolution: two tensor inputs resolve to a tensor output type.
    #[test]
    fn setdiff_type_resolver_numeric() {
        assert_eq!(
            set_values_output_type(
                &[Type::tensor(), Type::tensor()],
                &ResolveContext::new(Vec::new()),
            ),
            Type::tensor()
        );
    }

    // Type resolution: a cell of strings combined with a string resolves to a cell of strings.
    #[test]
    fn setdiff_type_resolver_string_array() {
        assert_eq!(
            set_values_output_type(
                &[Type::cell_of(Type::String), Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::cell_of(Type::String)
        );
    }

    // 'stable' mode: survivors keep A's order; only 2 survives, at position 2 in A.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_stable() {
        let a = Tensor::new(vec![4.0, 2.0, 4.0, 1.0, 3.0], vec![5, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 4.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("stable")])
            .expect("setdiff");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data, vec![2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0]);
    }

    // 'rows' mode on column-major data. A has rows (1,2), (3,4), (1,2) and B has rows
    // (3,4), (5,6), so the single surviving row is (1,2), first seen at row 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_rows_sorted() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")])
            .expect("setdiff");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![1.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // A NaN in B removes the NaN from A in this implementation.
    // NOTE(review): MATLAB's set functions generally treat NaN values as distinct, which would
    // keep NaN in the output. Confirm this divergence is intentional.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_removes_nan() {
        let a = Tensor::new(vec![f64::NAN, 2.0, 3.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![f64::NAN], vec![1, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("setdiff");
        let values = tensor::value_into_tensor_for("setdiff", eval.values_value()).expect("values");
        assert_eq!(values.data, vec![2.0, 3.0]);
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0, 3.0]);
    }

    // Element-wise char difference. The expected ia of 3 shows that 'z' sits at linear
    // index 3 of A as `setdiff` sees it.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_char_elements() {
        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
        let eval = evaluate_sync(Value::CharArray(a), Value::CharArray(b), &[]).expect("setdiff");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 1);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['z']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![3.0]);
    }

    // 'rows' + 'stable' on string arrays: only the row ("alpha", "beta") is missing from B.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_string_rows_stable() {
        let a = StringArray::new(
            vec![
                "alpha".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "gamma".to_string(),
                "delta".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let eval = evaluate_sync(
            Value::StringArray(a),
            Value::StringArray(b),
            &[Value::from("rows"), Value::from("stable")],
        )
        .expect("setdiff");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![1, 2]);
                assert_eq!(arr.data, vec!["alpha".to_string(), "beta".to_string()]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // Mixing a numeric A with a string B must be rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_type_mismatch_errors() {
        let result = evaluate_sync(Value::from(1.0), Value::String("a".into()), &[]);
        assert!(result.is_err());
    }

    // The 'legacy' option is rejected with a specific message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_rejects_legacy_option() {
        let err = error_message(
            evaluate_sync(Value::from(1.0), Value::from(2.0), &[Value::from("legacy")])
                .unwrap_err(),
        );
        assert!(err.contains("setdiff: the 'legacy' behaviour is not supported"));
    }

    // GPU-resident inputs from the test provider produce a host `Value::Tensor` result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor_a = Tensor::new(vec![10.0, 4.0, 6.0, 4.0], vec![4, 1]).unwrap();
            let tensor_b = Tensor::new(vec![6.0, 4.0, 2.0], vec![3, 1]).unwrap();
            let view_a = HostTensorView {
                data: &tensor_a.data,
                shape: &tensor_a.shape,
            };
            let view_b = HostTensorView {
                data: &tensor_b.data,
                shape: &tensor_b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload a");
            let handle_b = provider.upload(&view_b).expect("upload b");
            let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("setdiff");
            match eval.values_value() {
                Value::Tensor(t) => {
                    assert_eq!(t.data, vec![10.0]);
                }
                other => panic!("expected tensor result, got {other:?}"),
            }
            let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
            assert_eq!(ia.data, vec![1.0]);
        });
    }

    // With the `wgpu` feature: results (values and ia, data and shape) from GPU handles
    // must match the CPU path for the same inputs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn setdiff_wgpu_matches_cpu() {
        // Registration result is ignored; presumably a provider may already be registered.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = Tensor::new(vec![8.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![2.0, 5.0], vec![2, 1]).unwrap();

        let cpu_eval = evaluate_sync(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[])
            .expect("setdiff");
        let cpu_values = tensor::value_into_tensor_for("setdiff", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("setdiff", cpu_eval.ia_value()).unwrap();

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
            .expect("setdiff");
        let gpu_values = tensor::value_into_tensor_for("setdiff", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("setdiff", gpu_eval.ia_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_values.shape, cpu_values.shape);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ia.shape, cpu_ia.shape);
    }
}