1use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::{
12 GpuTensorHandle, HostTensorOwned, SetdiffOptions, SetdiffOrder, SetdiffResult,
13};
14use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
15use runmat_macros::runtime_builtin;
16
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::complex_tensor_into_value;
19use crate::builtins::common::spec::{
20 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24#[cfg(feature = "doc_export")]
25use crate::register_builtin_doc_text;
26use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
27
28#[cfg(feature = "doc_export")]
29pub const DOC_MD: &str = r#"---
30title: "setdiff"
31category: "array/sorting_sets"
32keywords: ["setdiff", "difference", "stable", "rows", "indices", "gpu"]
33summary: "Return values that appear in the first input but not the second, matching MATLAB ordering rules."
34references:
35 - https://www.mathworks.com/help/matlab/ref/setdiff.html
36gpu_support:
37 elementwise: false
38 reduction: false
39 precisions: ["f32", "f64"]
40 broadcasting: "none"
41 notes: "When providers lack a dedicated `setdiff` hook, RunMat gathers GPU tensors to host memory and reuses the CPU path."
42fusion:
43 elementwise: false
44 reduction: false
45 max_inputs: 2
46 constants: "inline"
47requires_feature: null
48tested:
49 unit: "builtins::array::sorting_sets::setdiff::tests"
50 integration: "builtins::array::sorting_sets::setdiff::tests::setdiff_gpu_roundtrip"
51---
52
53# What does the `setdiff` function do in MATLAB / RunMat?
54`setdiff(A, B)` returns the set of values (or rows) that appear in `A` but not in `B`. Results are
55unique, and the function can operate in sorted or stable order as well as row mode.
56
57## How does the `setdiff` function behave in MATLAB / RunMat?
58- `setdiff(A, B)` flattens inputs column-major, removes duplicates, subtracts the values of `B` from `A`,
59 and returns the remaining elements **sorted** ascending by default.
60- `[C, IA] = setdiff(A, B)` also returns indices so that `C = A(IA)`.
61- `setdiff(A, B, 'stable')` preserves the first appearance order from `A` instead of sorting.
62- `setdiff(A, B, 'rows')` treats each row as an element. Inputs must share the same number of columns.
63- Character arrays, string arrays, logical arrays, numeric types, and complex values are all supported.
64- Legacy flags (`'legacy'`, `'R2012a'`) are not supported; RunMat always follows modern MATLAB semantics.
65
66## `setdiff` Function GPU Execution Behaviour
67`setdiff` is registered as a residency sink. When tensors reside on the GPU and the active provider
68does not yet implement a `setdiff` hook, RunMat gathers them to host memory, performs the CPU
69implementation, and materialises host-resident results. Future providers can wire a custom hook to
70perform the set difference directly on-device without affecting existing callers.
71
72## Examples of using the `setdiff` function in MATLAB / RunMat
73
74### Finding values exclusive to the first numeric vector
75```matlab
76A = [5 7 5 1];
77B = [7 1 3];
78[C, IA] = setdiff(A, B);
79```
80Expected output:
81```matlab
82C =
83 5
84IA =
85 1
86```
87
88### Preserving input order with `'stable'`
89```matlab
90A = [4 2 4 1 3];
91B = [3 4 5 1];
92[C, IA] = setdiff(A, B, 'stable');
93```
94Expected output:
95```matlab
96C =
97 2
98IA =
99 2
100```
101
102### Working with rows of numeric matrices
103```matlab
104A = [1 2; 3 4; 1 2];
105B = [3 4; 5 6];
106[C, IA] = setdiff(A, B, 'rows');
107```
108Expected output:
109```matlab
110C =
111 1 2
112IA =
113 1
114```
115
116### Computing set difference for character data
117```matlab
118A = ['m','z'; 'm','a'];
119B = ['a','x'; 'm','a'];
120[C, IA] = setdiff(A, B);
121```
122Expected output:
123```matlab
C =
    z
IA =
     3
128```
129
130### Subtracting string arrays by row
131```matlab
132A = ["alpha" "beta"; "gamma" "beta"];
133B = ["gamma" "beta"; "delta" "beta"];
134[C, IA] = setdiff(A, B, 'rows', 'stable');
135```
136Expected output:
137```matlab
138C =
139 1x2 string array
140 "alpha" "beta"
141IA =
142 1
143```
144
145### Using `setdiff` with GPU arrays
146```matlab
147G = gpuArray([10 4 6 4]);
148H = gpuArray([6 4 2]);
149C = setdiff(G, H);
150```
151RunMat gathers `G` and `H` to the host (until providers implement a GPU hook) and returns:
152```matlab
153C =
154 10
155```
156
157## FAQ
158
159### What ordering does `setdiff` use by default?
160Results are sorted ascending. Specify `'stable'` to preserve the first appearance order from the first input.
161
162### How are the index outputs defined?
163`IA` points to the positions in `A` that correspond to each element (or row) returned in `C`, using MATLAB's one-based indexing.
164
165### Can I combine `'rows'` with `'stable'`?
166Yes. `'rows'` can be paired with either `'sorted'` (default) or `'stable'`. Other option combinations that conflict (e.g. `'sorted'` with `'stable'`) are rejected.
167
168### Does `setdiff` remove `NaN` values from `A` when they exist in `B`?
169Yes. `NaN` values are considered equal. If `B` contains `NaN`, all `NaN` entries from `A` are removed.
170
171### Are complex numbers supported?
172Absolutely. Complex values use MATLAB's ordering rules (magnitude, then real part, then imaginary part) for the sorted output.
173
174### Does GPU execution change the results?
175No. Until providers supply a device implementation, RunMat gathers GPU inputs and executes the CPU path to guarantee MATLAB-compatible behaviour.
176
177### What happens if the inputs have different classes?
178RunMat follows MATLAB's rules: both inputs must share the same class (numeric/logical, complex, char, or string). Mixed-class inputs raise descriptive errors.
179
180### Can I request `'legacy'` behaviour?
181No. RunMat implements the modern semantics only. Passing `'legacy'` or `'R2012a'` results in an error.
182
183## See Also
184[unique](./unique), [union](./union), [intersect](./intersect), [ismember](./ismember), [gpuArray](../../acceleration/gpu/gpuArray), [gather](../../acceleration/gpu/gather)
185
186## Source & Feedback
187- Implementation: `crates/runmat-runtime/src/builtins/array/sorting_sets/setdiff.rs`
188- Issues / feedback: https://github.com/runmat-org/runmat/issues/new/choose
189"#;
190
/// GPU capability descriptor for the `setdiff` builtin.
///
/// Declares a custom op kind with a matching custom provider hook, `f32`/`f64` precisions,
/// no broadcasting, and a gather-immediately residency policy. The `notes` field records the
/// host fallback used while providers do not implement the hook.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "setdiff",
    op_kind: GpuOpKind::Custom("setdiff"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("setdiff")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may implement `setdiff`; until then tensors are gathered and processed on the host.",
};

// Hand the descriptor to `register_builtin_gpu_spec!`; the registration mechanics live in the macro.
register_builtin_gpu_spec!(GPU_SPEC);
207
/// Fusion descriptor for the `setdiff` builtin.
///
/// Neither an elementwise nor a reduction template is provided (`None` for both), and the
/// `notes` field states that `setdiff` terminates fusion chains.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "setdiff",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`setdiff` materialises its inputs and terminates fusion chains; upstream GPU tensors are gathered if needed.",
};

// Hand the descriptor to `register_builtin_fusion_spec!`; the registration mechanics live in the macro.
register_builtin_fusion_spec!(FUSION_SPEC);
219
// Only compiled with the `doc_export` feature: passes the `DOC_MD` Markdown to
// `register_builtin_doc_text!` under the name "setdiff".
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("setdiff", DOC_MD);
222
223#[runtime_builtin(
224 name = "setdiff",
225 category = "array/sorting_sets",
226 summary = "Return the values that appear in the first input but not the second.",
227 keywords = "setdiff,difference,stable,rows,indices,gpu",
228 accel = "array_construct",
229 sink = true
230)]
231fn setdiff_builtin(a: Value, b: Value, rest: Vec<Value>) -> Result<Value, String> {
232 evaluate(a, b, &rest).map(|eval| eval.into_values_value())
233}
234
235pub fn evaluate(a: Value, b: Value, rest: &[Value]) -> Result<SetdiffEvaluation, String> {
237 let opts = parse_options(rest)?;
238 match (a, b) {
239 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
240 setdiff_gpu_pair(handle_a, handle_b, &opts)
241 }
242 (Value::GpuTensor(handle_a), other) => setdiff_gpu_mixed(handle_a, other, &opts, true),
243 (other, Value::GpuTensor(handle_b)) => setdiff_gpu_mixed(handle_b, other, &opts, false),
244 (left, right) => setdiff_host(left, right, &opts),
245 }
246}
247
248fn parse_options(rest: &[Value]) -> Result<SetdiffOptions, String> {
249 let mut opts = SetdiffOptions {
250 rows: false,
251 order: SetdiffOrder::Sorted,
252 };
253 let mut seen_order: Option<SetdiffOrder> = None;
254
255 for arg in rest {
256 let text = tensor::value_to_string(arg)
257 .ok_or_else(|| "setdiff: expected string option arguments".to_string())?;
258 let lowered = text.trim().to_ascii_lowercase();
259 match lowered.as_str() {
260 "rows" => opts.rows = true,
261 "sorted" => {
262 if let Some(prev) = seen_order {
263 if prev != SetdiffOrder::Sorted {
264 return Err("setdiff: cannot combine 'sorted' with 'stable'".to_string());
265 }
266 }
267 seen_order = Some(SetdiffOrder::Sorted);
268 opts.order = SetdiffOrder::Sorted;
269 }
270 "stable" => {
271 if let Some(prev) = seen_order {
272 if prev != SetdiffOrder::Stable {
273 return Err("setdiff: cannot combine 'sorted' with 'stable'".to_string());
274 }
275 }
276 seen_order = Some(SetdiffOrder::Stable);
277 opts.order = SetdiffOrder::Stable;
278 }
279 "legacy" | "r2012a" => {
280 return Err("setdiff: the 'legacy' behaviour is not supported".to_string());
281 }
282 other => return Err(format!("setdiff: unrecognised option '{other}'")),
283 }
284 }
285
286 Ok(opts)
287}
288
289fn setdiff_gpu_pair(
290 handle_a: GpuTensorHandle,
291 handle_b: GpuTensorHandle,
292 opts: &SetdiffOptions,
293) -> Result<SetdiffEvaluation, String> {
294 if let Some(provider) = runmat_accelerate_api::provider() {
295 match provider.setdiff(&handle_a, &handle_b, opts) {
296 Ok(result) => return SetdiffEvaluation::from_setdiff_result(result),
297 Err(_) => {
298 }
300 }
301 }
302 let a_tensor = gpu_helpers::gather_tensor(&handle_a)?;
303 let b_tensor = gpu_helpers::gather_tensor(&handle_b)?;
304 setdiff_numeric(a_tensor, b_tensor, opts)
305}
306
307fn setdiff_gpu_mixed(
308 handle_gpu: GpuTensorHandle,
309 other: Value,
310 opts: &SetdiffOptions,
311 gpu_is_a: bool,
312) -> Result<SetdiffEvaluation, String> {
313 let gpu_tensor = gpu_helpers::gather_tensor(&handle_gpu)?;
314 let other_tensor = tensor::value_into_tensor_for("setdiff", other)?;
315 if gpu_is_a {
316 setdiff_numeric(gpu_tensor, other_tensor, opts)
317 } else {
318 setdiff_numeric(other_tensor, gpu_tensor, opts)
319 }
320}
321
322fn setdiff_host(a: Value, b: Value, opts: &SetdiffOptions) -> Result<SetdiffEvaluation, String> {
323 match (a, b) {
324 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => setdiff_complex(at, bt, opts),
325 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
326 let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
327 .map_err(|e| format!("setdiff: {e}"))?;
328 setdiff_complex(at, bt, opts)
329 }
330 (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
331 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
332 .map_err(|e| format!("setdiff: {e}"))?;
333 setdiff_complex(at, bt, opts)
334 }
335 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
336 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
337 .map_err(|e| format!("setdiff: {e}"))?;
338 let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
339 .map_err(|e| format!("setdiff: {e}"))?;
340 setdiff_complex(at, bt, opts)
341 }
342
343 (Value::CharArray(ac), Value::CharArray(bc)) => setdiff_char(ac, bc, opts),
344
345 (Value::StringArray(astring), Value::StringArray(bstring)) => {
346 setdiff_string(astring, bstring, opts)
347 }
348 (Value::StringArray(astring), Value::String(b)) => {
349 let bstring =
350 StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("setdiff: {e}"))?;
351 setdiff_string(astring, bstring, opts)
352 }
353 (Value::String(a), Value::StringArray(bstring)) => {
354 let astring =
355 StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("setdiff: {e}"))?;
356 setdiff_string(astring, bstring, opts)
357 }
358 (Value::String(a), Value::String(b)) => {
359 let astring =
360 StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("setdiff: {e}"))?;
361 let bstring =
362 StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("setdiff: {e}"))?;
363 setdiff_string(astring, bstring, opts)
364 }
365
366 (left, right) => {
367 let tensor_a = tensor::value_into_tensor_for("setdiff", left)?;
368 let tensor_b = tensor::value_into_tensor_for("setdiff", right)?;
369 setdiff_numeric(tensor_a, tensor_b, opts)
370 }
371 }
372}
373
374fn setdiff_numeric(
375 a: Tensor,
376 b: Tensor,
377 opts: &SetdiffOptions,
378) -> Result<SetdiffEvaluation, String> {
379 if opts.rows {
380 setdiff_numeric_rows(a, b, opts)
381 } else {
382 setdiff_numeric_elements(a, b, opts)
383 }
384}
385
386pub fn setdiff_numeric_from_tensors(
388 a: Tensor,
389 b: Tensor,
390 opts: &SetdiffOptions,
391) -> Result<SetdiffEvaluation, String> {
392 setdiff_numeric(a, b, opts)
393}
394
395fn setdiff_numeric_elements(
396 a: Tensor,
397 b: Tensor,
398 opts: &SetdiffOptions,
399) -> Result<SetdiffEvaluation, String> {
400 let mut b_keys: HashSet<u64> = HashSet::new();
401 for &value in &b.data {
402 b_keys.insert(canonicalize_f64(value));
403 }
404
405 let mut seen: HashMap<u64, usize> = HashMap::new();
406 let mut entries = Vec::<NumericDiffEntry>::new();
407 let mut order_counter = 0usize;
408
409 for (idx, &value) in a.data.iter().enumerate() {
410 let key = canonicalize_f64(value);
411 if b_keys.contains(&key) {
412 continue;
413 }
414 if seen.contains_key(&key) {
415 continue;
416 }
417 let entry_idx = entries.len();
418 entries.push(NumericDiffEntry {
419 value,
420 index: idx,
421 order_rank: order_counter,
422 });
423 seen.insert(key, entry_idx);
424 order_counter += 1;
425 }
426
427 assemble_numeric_setdiff(entries, opts)
428}
429
430fn setdiff_numeric_rows(
431 a: Tensor,
432 b: Tensor,
433 opts: &SetdiffOptions,
434) -> Result<SetdiffEvaluation, String> {
435 if a.shape.len() != 2 || b.shape.len() != 2 {
436 return Err("setdiff: 'rows' option requires 2-D numeric matrices".to_string());
437 }
438 if a.shape[1] != b.shape[1] {
439 return Err(
440 "setdiff: inputs must have the same number of columns when using 'rows'".to_string(),
441 );
442 }
443
444 let rows_a = a.shape[0];
445 let rows_b = b.shape[0];
446 let cols = a.shape[1];
447
448 let mut b_keys: HashSet<NumericRowKey> = HashSet::new();
449 for r in 0..rows_b {
450 let mut row_values = Vec::with_capacity(cols);
451 for c in 0..cols {
452 let idx = r + c * rows_b;
453 row_values.push(b.data[idx]);
454 }
455 b_keys.insert(NumericRowKey::from_slice(&row_values));
456 }
457
458 let mut seen: HashSet<NumericRowKey> = HashSet::new();
459 let mut entries = Vec::<NumericRowDiffEntry>::new();
460 let mut order_counter = 0usize;
461
462 for r in 0..rows_a {
463 let mut row_values = Vec::with_capacity(cols);
464 for c in 0..cols {
465 let idx = r + c * rows_a;
466 row_values.push(a.data[idx]);
467 }
468 let key = NumericRowKey::from_slice(&row_values);
469 if b_keys.contains(&key) {
470 continue;
471 }
472 if !seen.insert(key) {
473 continue;
474 }
475 entries.push(NumericRowDiffEntry {
476 row_data: row_values,
477 row_index: r,
478 order_rank: order_counter,
479 });
480 order_counter += 1;
481 }
482
483 assemble_numeric_row_setdiff(entries, opts, cols)
484}
485
486fn setdiff_complex(
487 a: ComplexTensor,
488 b: ComplexTensor,
489 opts: &SetdiffOptions,
490) -> Result<SetdiffEvaluation, String> {
491 if opts.rows {
492 setdiff_complex_rows(a, b, opts)
493 } else {
494 setdiff_complex_elements(a, b, opts)
495 }
496}
497
498fn setdiff_complex_elements(
499 a: ComplexTensor,
500 b: ComplexTensor,
501 opts: &SetdiffOptions,
502) -> Result<SetdiffEvaluation, String> {
503 let mut b_keys: HashSet<ComplexKey> = HashSet::new();
504 for &value in &b.data {
505 b_keys.insert(ComplexKey::new(value));
506 }
507
508 let mut seen: HashSet<ComplexKey> = HashSet::new();
509 let mut entries = Vec::<ComplexDiffEntry>::new();
510 let mut order_counter = 0usize;
511
512 for (idx, &value) in a.data.iter().enumerate() {
513 let key = ComplexKey::new(value);
514 if b_keys.contains(&key) {
515 continue;
516 }
517 if !seen.insert(key) {
518 continue;
519 }
520 entries.push(ComplexDiffEntry {
521 value,
522 index: idx,
523 order_rank: order_counter,
524 });
525 order_counter += 1;
526 }
527
528 assemble_complex_setdiff(entries, opts)
529}
530
531fn setdiff_complex_rows(
532 a: ComplexTensor,
533 b: ComplexTensor,
534 opts: &SetdiffOptions,
535) -> Result<SetdiffEvaluation, String> {
536 if a.shape.len() != 2 || b.shape.len() != 2 {
537 return Err("setdiff: 'rows' option requires 2-D complex matrices".to_string());
538 }
539 if a.shape[1] != b.shape[1] {
540 return Err(
541 "setdiff: inputs must have the same number of columns when using 'rows'".to_string(),
542 );
543 }
544
545 let rows_a = a.shape[0];
546 let rows_b = b.shape[0];
547 let cols = a.shape[1];
548
549 let mut b_keys: HashSet<Vec<ComplexKey>> = HashSet::new();
550 for r in 0..rows_b {
551 let mut key_row = Vec::with_capacity(cols);
552 for c in 0..cols {
553 let idx = r + c * rows_b;
554 key_row.push(ComplexKey::new(b.data[idx]));
555 }
556 b_keys.insert(key_row);
557 }
558
559 let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
560 let mut entries = Vec::<ComplexRowDiffEntry>::new();
561 let mut order_counter = 0usize;
562
563 for r in 0..rows_a {
564 let mut row_values = Vec::with_capacity(cols);
565 let mut key_row = Vec::with_capacity(cols);
566 for c in 0..cols {
567 let idx = r + c * rows_a;
568 let value = a.data[idx];
569 row_values.push(value);
570 key_row.push(ComplexKey::new(value));
571 }
572 if b_keys.contains(&key_row) {
573 continue;
574 }
575 if !seen.insert(key_row) {
576 continue;
577 }
578 entries.push(ComplexRowDiffEntry {
579 row_data: row_values,
580 row_index: r,
581 order_rank: order_counter,
582 });
583 order_counter += 1;
584 }
585
586 assemble_complex_row_setdiff(entries, opts, cols)
587}
588
589fn setdiff_char(
590 a: CharArray,
591 b: CharArray,
592 opts: &SetdiffOptions,
593) -> Result<SetdiffEvaluation, String> {
594 if opts.rows {
595 setdiff_char_rows(a, b, opts)
596 } else {
597 setdiff_char_elements(a, b, opts)
598 }
599}
600
601fn setdiff_char_elements(
602 a: CharArray,
603 b: CharArray,
604 opts: &SetdiffOptions,
605) -> Result<SetdiffEvaluation, String> {
606 let mut b_keys: HashSet<u32> = HashSet::new();
607 for ch in &b.data {
608 b_keys.insert(*ch as u32);
609 }
610
611 let mut seen: HashSet<u32> = HashSet::new();
612 let mut entries = Vec::<CharDiffEntry>::new();
613 let mut order_counter = 0usize;
614
615 for col in 0..a.cols {
616 for row in 0..a.rows {
617 let linear_idx = row + col * a.rows;
618 let data_idx = row * a.cols + col;
619 let ch = a.data[data_idx];
620 let key = ch as u32;
621 if b_keys.contains(&key) {
622 continue;
623 }
624 if !seen.insert(key) {
625 continue;
626 }
627 entries.push(CharDiffEntry {
628 ch,
629 index: linear_idx,
630 order_rank: order_counter,
631 });
632 order_counter += 1;
633 }
634 }
635
636 assemble_char_setdiff(entries, opts)
637}
638
639fn setdiff_char_rows(
640 a: CharArray,
641 b: CharArray,
642 opts: &SetdiffOptions,
643) -> Result<SetdiffEvaluation, String> {
644 if a.cols != b.cols {
645 return Err(
646 "setdiff: inputs must have the same number of columns when using 'rows'".to_string(),
647 );
648 }
649
650 let rows_a = a.rows;
651 let rows_b = b.rows;
652 let cols = a.cols;
653
654 let mut b_keys: HashSet<RowCharKey> = HashSet::new();
655 for r in 0..rows_b {
656 let mut row_values = Vec::with_capacity(cols);
657 for c in 0..cols {
658 let idx = r * cols + c;
659 row_values.push(b.data[idx]);
660 }
661 b_keys.insert(RowCharKey::from_slice(&row_values));
662 }
663
664 let mut seen: HashSet<RowCharKey> = HashSet::new();
665 let mut entries = Vec::<CharRowDiffEntry>::new();
666 let mut order_counter = 0usize;
667
668 for r in 0..rows_a {
669 let mut row_values = Vec::with_capacity(cols);
670 for c in 0..cols {
671 let idx = r * cols + c;
672 row_values.push(a.data[idx]);
673 }
674 let key = RowCharKey::from_slice(&row_values);
675 if b_keys.contains(&key) {
676 continue;
677 }
678 if !seen.insert(key) {
679 continue;
680 }
681 entries.push(CharRowDiffEntry {
682 row_data: row_values,
683 row_index: r,
684 order_rank: order_counter,
685 });
686 order_counter += 1;
687 }
688
689 assemble_char_row_setdiff(entries, opts, cols)
690}
691
692fn setdiff_string(
693 a: StringArray,
694 b: StringArray,
695 opts: &SetdiffOptions,
696) -> Result<SetdiffEvaluation, String> {
697 if opts.rows {
698 setdiff_string_rows(a, b, opts)
699 } else {
700 setdiff_string_elements(a, b, opts)
701 }
702}
703
704fn setdiff_string_elements(
705 a: StringArray,
706 b: StringArray,
707 opts: &SetdiffOptions,
708) -> Result<SetdiffEvaluation, String> {
709 let mut b_keys: HashSet<String> = HashSet::new();
710 for value in &b.data {
711 b_keys.insert(value.clone());
712 }
713
714 let mut seen: HashSet<String> = HashSet::new();
715 let mut entries = Vec::<StringDiffEntry>::new();
716 let mut order_counter = 0usize;
717
718 for (idx, value) in a.data.iter().enumerate() {
719 if b_keys.contains(value) {
720 continue;
721 }
722 if !seen.insert(value.clone()) {
723 continue;
724 }
725 entries.push(StringDiffEntry {
726 value: value.clone(),
727 index: idx,
728 order_rank: order_counter,
729 });
730 order_counter += 1;
731 }
732
733 assemble_string_setdiff(entries, opts)
734}
735
736fn setdiff_string_rows(
737 a: StringArray,
738 b: StringArray,
739 opts: &SetdiffOptions,
740) -> Result<SetdiffEvaluation, String> {
741 if a.shape.len() != 2 || b.shape.len() != 2 {
742 return Err("setdiff: 'rows' option requires 2-D string arrays".to_string());
743 }
744 if a.shape[1] != b.shape[1] {
745 return Err(
746 "setdiff: inputs must have the same number of columns when using 'rows'".to_string(),
747 );
748 }
749
750 let rows_a = a.shape[0];
751 let rows_b = b.shape[0];
752 let cols = a.shape[1];
753
754 let mut b_keys: HashSet<RowStringKey> = HashSet::new();
755 for r in 0..rows_b {
756 let mut row_values = Vec::with_capacity(cols);
757 for c in 0..cols {
758 let idx = r + c * rows_b;
759 row_values.push(b.data[idx].clone());
760 }
761 b_keys.insert(RowStringKey(row_values.clone()));
762 }
763
764 let mut seen: HashSet<RowStringKey> = HashSet::new();
765 let mut entries = Vec::<StringRowDiffEntry>::new();
766 let mut order_counter = 0usize;
767
768 for r in 0..rows_a {
769 let mut row_values = Vec::with_capacity(cols);
770 for c in 0..cols {
771 let idx = r + c * rows_a;
772 row_values.push(a.data[idx].clone());
773 }
774 let key = RowStringKey(row_values.clone());
775 if b_keys.contains(&key) {
776 continue;
777 }
778 if !seen.insert(key) {
779 continue;
780 }
781 entries.push(StringRowDiffEntry {
782 row_data: row_values,
783 row_index: r,
784 order_rank: order_counter,
785 });
786 order_counter += 1;
787 }
788
789 assemble_string_row_setdiff(entries, opts, cols)
790}
791
792fn assemble_numeric_setdiff(
793 entries: Vec<NumericDiffEntry>,
794 opts: &SetdiffOptions,
795) -> Result<SetdiffEvaluation, String> {
796 let mut order: Vec<usize> = (0..entries.len()).collect();
797 match opts.order {
798 SetdiffOrder::Sorted => {
799 order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
800 }
801 SetdiffOrder::Stable => {
802 order.sort_by_key(|&idx| entries[idx].order_rank);
803 }
804 }
805
806 let mut values = Vec::with_capacity(order.len());
807 let mut ia = Vec::with_capacity(order.len());
808 for &idx in &order {
809 let entry = &entries[idx];
810 values.push(entry.value);
811 ia.push((entry.index + 1) as f64);
812 }
813
814 let value_tensor =
815 Tensor::new(values, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
816 let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
817
818 Ok(SetdiffEvaluation::new(
819 Value::Tensor(value_tensor),
820 ia_tensor,
821 ))
822}
823
824fn assemble_numeric_row_setdiff(
825 entries: Vec<NumericRowDiffEntry>,
826 opts: &SetdiffOptions,
827 cols: usize,
828) -> Result<SetdiffEvaluation, String> {
829 let mut order: Vec<usize> = (0..entries.len()).collect();
830 match opts.order {
831 SetdiffOrder::Sorted => {
832 order.sort_by(|&lhs, &rhs| {
833 compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
834 });
835 }
836 SetdiffOrder::Stable => {
837 order.sort_by_key(|&idx| entries[idx].order_rank);
838 }
839 }
840
841 let unique_rows = order.len();
842 let mut values = vec![0.0f64; unique_rows * cols];
843 let mut ia = Vec::with_capacity(unique_rows);
844
845 for (row_pos, &entry_idx) in order.iter().enumerate() {
846 let entry = &entries[entry_idx];
847 for col in 0..cols {
848 let dest = row_pos + col * unique_rows;
849 values[dest] = entry.row_data[col];
850 }
851 ia.push((entry.row_index + 1) as f64);
852 }
853
854 let value_tensor =
855 Tensor::new(values, vec![unique_rows, cols]).map_err(|e| format!("setdiff: {e}"))?;
856 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1]).map_err(|e| format!("setdiff: {e}"))?;
857
858 Ok(SetdiffEvaluation::new(
859 Value::Tensor(value_tensor),
860 ia_tensor,
861 ))
862}
863
864fn assemble_complex_setdiff(
865 entries: Vec<ComplexDiffEntry>,
866 opts: &SetdiffOptions,
867) -> Result<SetdiffEvaluation, String> {
868 let mut order: Vec<usize> = (0..entries.len()).collect();
869 match opts.order {
870 SetdiffOrder::Sorted => {
871 order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
872 }
873 SetdiffOrder::Stable => {
874 order.sort_by_key(|&idx| entries[idx].order_rank);
875 }
876 }
877
878 let mut values = Vec::with_capacity(order.len());
879 let mut ia = Vec::with_capacity(order.len());
880 for &idx in &order {
881 let entry = &entries[idx];
882 values.push(entry.value);
883 ia.push((entry.index + 1) as f64);
884 }
885
886 let value_tensor =
887 ComplexTensor::new(values, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
888 let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
889
890 Ok(SetdiffEvaluation::new(
891 complex_tensor_into_value(value_tensor),
892 ia_tensor,
893 ))
894}
895
896fn assemble_complex_row_setdiff(
897 entries: Vec<ComplexRowDiffEntry>,
898 opts: &SetdiffOptions,
899 cols: usize,
900) -> Result<SetdiffEvaluation, String> {
901 let mut order: Vec<usize> = (0..entries.len()).collect();
902 match opts.order {
903 SetdiffOrder::Sorted => {
904 order.sort_by(|&lhs, &rhs| {
905 compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
906 });
907 }
908 SetdiffOrder::Stable => {
909 order.sort_by_key(|&idx| entries[idx].order_rank);
910 }
911 }
912
913 let unique_rows = order.len();
914 let mut values = vec![(0.0f64, 0.0f64); unique_rows * cols];
915 let mut ia = Vec::with_capacity(unique_rows);
916
917 for (row_pos, &entry_idx) in order.iter().enumerate() {
918 let entry = &entries[entry_idx];
919 for col in 0..cols {
920 let dest = row_pos + col * unique_rows;
921 values[dest] = entry.row_data[col];
922 }
923 ia.push((entry.row_index + 1) as f64);
924 }
925
926 let value_tensor =
927 ComplexTensor::new(values, vec![unique_rows, cols]).map_err(|e| format!("setdiff: {e}"))?;
928 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1]).map_err(|e| format!("setdiff: {e}"))?;
929
930 Ok(SetdiffEvaluation::new(
931 complex_tensor_into_value(value_tensor),
932 ia_tensor,
933 ))
934}
935
936fn assemble_char_setdiff(
937 entries: Vec<CharDiffEntry>,
938 opts: &SetdiffOptions,
939) -> Result<SetdiffEvaluation, String> {
940 let mut order: Vec<usize> = (0..entries.len()).collect();
941 match opts.order {
942 SetdiffOrder::Sorted => {
943 order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
944 }
945 SetdiffOrder::Stable => {
946 order.sort_by_key(|&idx| entries[idx].order_rank);
947 }
948 }
949
950 let mut values = Vec::with_capacity(order.len());
951 let mut ia = Vec::with_capacity(order.len());
952 for &idx in &order {
953 let entry = &entries[idx];
954 values.push(entry.ch);
955 ia.push((entry.index + 1) as f64);
956 }
957
958 let value_array =
959 CharArray::new(values, order.len(), 1).map_err(|e| format!("setdiff: {e}"))?;
960 let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
961
962 Ok(SetdiffEvaluation::new(
963 Value::CharArray(value_array),
964 ia_tensor,
965 ))
966}
967
968fn assemble_char_row_setdiff(
969 entries: Vec<CharRowDiffEntry>,
970 opts: &SetdiffOptions,
971 cols: usize,
972) -> Result<SetdiffEvaluation, String> {
973 let mut order: Vec<usize> = (0..entries.len()).collect();
974 match opts.order {
975 SetdiffOrder::Sorted => {
976 order.sort_by(|&lhs, &rhs| {
977 compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
978 });
979 }
980 SetdiffOrder::Stable => {
981 order.sort_by_key(|&idx| entries[idx].order_rank);
982 }
983 }
984
985 let unique_rows = order.len();
986 let mut values = vec!['\0'; unique_rows * cols];
987 let mut ia = Vec::with_capacity(unique_rows);
988
989 for (row_pos, &entry_idx) in order.iter().enumerate() {
990 let entry = &entries[entry_idx];
991 for col in 0..cols {
992 let dest = row_pos * cols + col;
993 values[dest] = entry.row_data[col];
994 }
995 ia.push((entry.row_index + 1) as f64);
996 }
997
998 let value_array =
999 CharArray::new(values, unique_rows, cols).map_err(|e| format!("setdiff: {e}"))?;
1000 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1]).map_err(|e| format!("setdiff: {e}"))?;
1001
1002 Ok(SetdiffEvaluation::new(
1003 Value::CharArray(value_array),
1004 ia_tensor,
1005 ))
1006}
1007
1008fn assemble_string_setdiff(
1009 entries: Vec<StringDiffEntry>,
1010 opts: &SetdiffOptions,
1011) -> Result<SetdiffEvaluation, String> {
1012 let mut order: Vec<usize> = (0..entries.len()).collect();
1013 match opts.order {
1014 SetdiffOrder::Sorted => {
1015 order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1016 }
1017 SetdiffOrder::Stable => {
1018 order.sort_by_key(|&idx| entries[idx].order_rank);
1019 }
1020 }
1021
1022 let mut values = Vec::with_capacity(order.len());
1023 let mut ia = Vec::with_capacity(order.len());
1024 for &idx in &order {
1025 let entry = &entries[idx];
1026 values.push(entry.value.clone());
1027 ia.push((entry.index + 1) as f64);
1028 }
1029
1030 let value_array =
1031 StringArray::new(values, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
1032 let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
1033
1034 Ok(SetdiffEvaluation::new(
1035 Value::StringArray(value_array),
1036 ia_tensor,
1037 ))
1038}
1039
1040fn assemble_string_row_setdiff(
1041 entries: Vec<StringRowDiffEntry>,
1042 opts: &SetdiffOptions,
1043 cols: usize,
1044) -> Result<SetdiffEvaluation, String> {
1045 let mut order: Vec<usize> = (0..entries.len()).collect();
1046 match opts.order {
1047 SetdiffOrder::Sorted => {
1048 order.sort_by(|&lhs, &rhs| {
1049 compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1050 });
1051 }
1052 SetdiffOrder::Stable => {
1053 order.sort_by_key(|&idx| entries[idx].order_rank);
1054 }
1055 }
1056
1057 let unique_rows = order.len();
1058 let mut values = vec![String::new(); unique_rows * cols];
1059 let mut ia = Vec::with_capacity(unique_rows);
1060
1061 for (row_pos, &entry_idx) in order.iter().enumerate() {
1062 let entry = &entries[entry_idx];
1063 for col in 0..cols {
1064 let dest = row_pos + col * unique_rows;
1065 values[dest] = entry.row_data[col].clone();
1066 }
1067 ia.push((entry.row_index + 1) as f64);
1068 }
1069
1070 let value_array =
1071 StringArray::new(values, vec![unique_rows, cols]).map_err(|e| format!("setdiff: {e}"))?;
1072 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1]).map_err(|e| format!("setdiff: {e}"))?;
1073
1074 Ok(SetdiffEvaluation::new(
1075 Value::StringArray(value_array),
1076 ia_tensor,
1077 ))
1078}
1079
/// One real-valued element tracked while computing an element-wise `setdiff`.
///
/// NOTE(review): field roles are inferred from the string assemblers in this file, where
/// `index` is reported as `index + 1` and `order_rank` is the sort key for stable
/// ordering — confirm the numeric assembler uses them the same way.
#[derive(Debug, Clone, Copy)]
struct NumericDiffEntry {
    value: f64,
    index: usize,
    order_rank: usize,
}
1086
/// One real-valued row tracked while computing a row-wise `setdiff`.
///
/// NOTE(review): field roles are inferred from the string row assembler in this file,
/// where `row_index` is reported as `row_index + 1` and `order_rank` is the sort key for
/// stable ordering — confirm the numeric row assembler uses them the same way.
#[derive(Debug, Clone)]
struct NumericRowDiffEntry {
    row_data: Vec<f64>,
    row_index: usize,
    order_rank: usize,
}
1093
/// One complex element, stored as a `(re, im)` pair, tracked during an element-wise
/// `setdiff`.
///
/// NOTE(review): `index` and `order_rank` presumably play the same roles as in the string
/// assemblers in this file (reported index and stable-order sort key) — confirm.
#[derive(Debug, Clone, Copy)]
struct ComplexDiffEntry {
    value: (f64, f64),
    index: usize,
    order_rank: usize,
}
1100
/// One row of complex `(re, im)` pairs tracked during a row-wise `setdiff`.
///
/// NOTE(review): `row_index` and `order_rank` presumably play the same roles as in the
/// string row assembler in this file (reported row index and stable-order sort key) —
/// confirm.
#[derive(Debug, Clone)]
struct ComplexRowDiffEntry {
    row_data: Vec<(f64, f64)>,
    row_index: usize,
    order_rank: usize,
}
1107
/// One character tracked during an element-wise `setdiff` over char arrays.
///
/// NOTE(review): `index` and `order_rank` presumably play the same roles as in the string
/// assemblers in this file (reported index and stable-order sort key) — confirm.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
struct CharDiffEntry {
    ch: char,
    index: usize,
    order_rank: usize,
}
1114
/// One row of characters tracked during a row-wise `setdiff` over char arrays.
///
/// NOTE(review): `row_index` and `order_rank` presumably play the same roles as in the
/// string row assembler in this file (reported row index and stable-order sort key) —
/// confirm.
#[derive(Debug, Clone)]
struct CharRowDiffEntry {
    row_data: Vec<char>,
    row_index: usize,
    order_rank: usize,
}
1121
/// One string element consumed by `assemble_string_setdiff`.
///
/// * `value` — the string itself; the sort key for sorted output.
/// * `index` — emitted as `index + 1` in the index output.
/// * `order_rank` — the sort key for stable output.
#[derive(Debug, Clone)]
struct StringDiffEntry {
    value: String,
    index: usize,
    order_rank: usize,
}
1128
/// One row of strings consumed by `assemble_string_row_setdiff`.
///
/// * `row_data` — the row's cells, one per column; the sort key for sorted output.
/// * `row_index` — emitted as `row_index + 1` in the index output.
/// * `order_rank` — the sort key for stable output.
#[derive(Debug, Clone)]
struct StringRowDiffEntry {
    row_data: Vec<String>,
    row_index: usize,
    order_rank: usize,
}
1135
1136#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1137struct NumericRowKey(Vec<u64>);
1138
1139impl NumericRowKey {
1140 fn from_slice(values: &[f64]) -> Self {
1141 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1142 }
1143}
1144
1145#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1146struct ComplexKey {
1147 re: u64,
1148 im: u64,
1149}
1150
1151impl ComplexKey {
1152 fn new(value: (f64, f64)) -> Self {
1153 Self {
1154 re: canonicalize_f64(value.0),
1155 im: canonicalize_f64(value.1),
1156 }
1157 }
1158}
1159
/// Hashable identity of a character row: the `u32` scalar value of every character.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key by converting each character of `values` to its `u32` code point.
    fn from_slice(values: &[char]) -> Self {
        Self(values.iter().copied().map(u32::from).collect())
    }
}
1168
/// Hashable identity of a string row: the row's cells, compared and hashed in order.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct RowStringKey(Vec<String>);
1171
1172pub struct SetdiffEvaluation {
1173 values: Value,
1174 ia: Tensor,
1175}
1176
1177impl SetdiffEvaluation {
1178 fn new(values: Value, ia: Tensor) -> Self {
1179 Self { values, ia }
1180 }
1181
1182 pub fn from_setdiff_result(result: SetdiffResult) -> Result<Self, String> {
1183 let SetdiffResult { values, ia } = result;
1184 let values_tensor =
1185 Tensor::new(values.data, values.shape).map_err(|e| format!("setdiff: {e}"))?;
1186 let ia_tensor = Tensor::new(ia.data, ia.shape).map_err(|e| format!("setdiff: {e}"))?;
1187 Ok(SetdiffEvaluation::new(
1188 Value::Tensor(values_tensor),
1189 ia_tensor,
1190 ))
1191 }
1192
1193 pub fn into_numeric_setdiff_result(self) -> Result<SetdiffResult, String> {
1194 let SetdiffEvaluation { values, ia } = self;
1195 let values_tensor = tensor::value_into_tensor_for("setdiff", values)?;
1196 Ok(SetdiffResult {
1197 values: HostTensorOwned {
1198 data: values_tensor.data,
1199 shape: values_tensor.shape,
1200 },
1201 ia: HostTensorOwned {
1202 data: ia.data,
1203 shape: ia.shape,
1204 },
1205 })
1206 }
1207
1208 pub fn into_values_value(self) -> Value {
1209 self.values
1210 }
1211
1212 pub fn into_pair(self) -> (Value, Value) {
1213 let ia = tensor::tensor_into_value(self.ia);
1214 (self.values, ia)
1215 }
1216
1217 pub fn values_value(&self) -> Value {
1218 self.values.clone()
1219 }
1220
1221 pub fn ia_value(&self) -> Value {
1222 tensor::tensor_into_value(self.ia.clone())
1223 }
1224}
1225
/// Maps a float to a bit pattern suitable for hashing and equality checks.
///
/// Every NaN collapses to one quiet-NaN pattern and both signed zeros collapse to `0`;
/// any other value keeps its raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0`, so this arm catches both zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1235
/// Total ordering for floats in which NaN sorts after every other value.
///
/// Two NaNs compare equal; otherwise the usual numeric order applies.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1249
1250fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1251 for (lhs, rhs) in a.iter().zip(b.iter()) {
1252 let ord = compare_f64(*lhs, *rhs);
1253 if ord != Ordering::Equal {
1254 return ord;
1255 }
1256 }
1257 Ordering::Equal
1258}
1259
/// Reports whether either component of a `(re, im)` pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
1263
1264fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1265 match (complex_is_nan(a), complex_is_nan(b)) {
1266 (true, true) => Ordering::Equal,
1267 (true, false) => Ordering::Greater,
1268 (false, true) => Ordering::Less,
1269 (false, false) => {
1270 let mag_a = a.0.hypot(a.1);
1271 let mag_b = b.0.hypot(b.1);
1272 let mag_cmp = compare_f64(mag_a, mag_b);
1273 if mag_cmp != Ordering::Equal {
1274 return mag_cmp;
1275 }
1276 let re_cmp = compare_f64(a.0, b.0);
1277 if re_cmp != Ordering::Equal {
1278 return re_cmp;
1279 }
1280 compare_f64(a.1, b.1)
1281 }
1282 }
1283}
1284
1285fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1286 for (lhs, rhs) in a.iter().zip(b.iter()) {
1287 let ord = compare_complex(*lhs, *rhs);
1288 if ord != Ordering::Equal {
1289 return ord;
1290 }
1291 }
1292 Ordering::Equal
1293}
1294
/// Lexicographic comparison of two character rows.
///
/// Only the overlapping prefix is inspected: rows that agree on every paired character
/// compare equal even if their lengths differ.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1304
/// Lexicographic comparison of two string rows.
///
/// Only the overlapping prefix is inspected: rows that agree on every paired string
/// compare equal even if their lengths differ.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CharArray, StringArray, Tensor, Value};

    /// Unwraps a `Value::Tensor`, failing the test for any other variant.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    /// Wraps `data` as an N-by-1 tensor value.
    fn column(data: Vec<f64>) -> Value {
        let rows = data.len();
        Value::Tensor(Tensor::new(data, vec![rows, 1]).unwrap())
    }

    // Default ordering: duplicates collapse and `ia` points at the first occurrence.
    #[test]
    fn setdiff_numeric_sorted_default() {
        let a = column(vec![5.0, 7.0, 5.0, 1.0]);
        let b = column(vec![7.0, 1.0, 3.0]);
        let eval = evaluate(a, b, &[]).expect("setdiff");

        let values = expect_tensor(eval.values_value());
        assert_eq!(values.shape, vec![1, 1]);
        assert_eq!(values.data, vec![5.0]);

        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // The "stable" option keeps first-appearance order.
    #[test]
    fn setdiff_numeric_stable() {
        let a = column(vec![4.0, 2.0, 4.0, 1.0, 3.0]);
        let b = column(vec![3.0, 4.0, 5.0, 1.0]);
        let eval = evaluate(a, b, &[Value::from("stable")]).expect("setdiff");

        let values = expect_tensor(eval.values_value());
        assert_eq!(values.shape, vec![1, 1]);
        assert_eq!(values.data, vec![2.0]);

        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0]);
    }

    // Row mode compares whole rows of the inputs.
    #[test]
    fn setdiff_numeric_rows_sorted() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
        let options = [Value::from("rows")];
        let eval = evaluate(Value::Tensor(a), Value::Tensor(b), &options).expect("setdiff");

        let values = expect_tensor(eval.values_value());
        assert_eq!(values.shape, vec![1, 2]);
        assert_eq!(values.data, vec![1.0, 2.0]);

        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // A NaN in B removes the NaN in A.
    #[test]
    fn setdiff_numeric_removes_nan() {
        let a = column(vec![f64::NAN, 2.0, 3.0]);
        let b = Value::Tensor(Tensor::new(vec![f64::NAN], vec![1, 1]).unwrap());
        let eval = evaluate(a, b, &[]).expect("setdiff");

        let values = tensor::value_into_tensor_for("setdiff", eval.values_value()).expect("values");
        assert_eq!(values.data, vec![2.0, 3.0]);

        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0, 3.0]);
    }

    #[test]
    fn setdiff_char_elements() {
        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
        let eval = evaluate(Value::CharArray(a), Value::CharArray(b), &[]).expect("setdiff");

        let chars = match eval.values_value() {
            Value::CharArray(arr) => arr,
            other => panic!("expected char array, got {other:?}"),
        };
        assert_eq!(chars.rows, 1);
        assert_eq!(chars.cols, 1);
        assert_eq!(chars.data, vec!['z']);

        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![3.0]);
    }

    #[test]
    fn setdiff_string_rows_stable() {
        let strings =
            |items: [&str; 4]| -> Vec<String> { items.iter().map(|s| s.to_string()).collect() };
        let a = StringArray::new(strings(["alpha", "gamma", "beta", "beta"]), vec![2, 2]).unwrap();
        let b = StringArray::new(strings(["gamma", "delta", "beta", "beta"]), vec![2, 2]).unwrap();

        let options = [Value::from("rows"), Value::from("stable")];
        let eval =
            evaluate(Value::StringArray(a), Value::StringArray(b), &options).expect("setdiff");

        let rows = match eval.values_value() {
            Value::StringArray(arr) => arr,
            other => panic!("expected string array, got {other:?}"),
        };
        assert_eq!(rows.shape, vec![1, 2]);
        assert_eq!(rows.data, vec!["alpha".to_string(), "beta".to_string()]);

        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    #[test]
    fn setdiff_type_mismatch_errors() {
        let outcome = evaluate(Value::from(1.0), Value::String("a".into()), &[]);
        assert!(outcome.is_err());
    }

    #[test]
    fn setdiff_rejects_legacy_option() {
        let outcome = evaluate(Value::from(1.0), Value::from(2.0), &[Value::from("legacy")]);
        // `.err().unwrap()` avoids requiring `Debug` on the success type.
        let message = outcome.err().unwrap();
        assert!(message.contains("setdiff: the 'legacy' behaviour is not supported"));
    }

    // Inputs uploaded through the test provider still produce a host tensor result.
    #[test]
    fn setdiff_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor_a = Tensor::new(vec![10.0, 4.0, 6.0, 4.0], vec![4, 1]).unwrap();
            let tensor_b = Tensor::new(vec![6.0, 4.0, 2.0], vec![3, 1]).unwrap();

            let handle_a = provider
                .upload(&HostTensorView {
                    data: &tensor_a.data,
                    shape: &tensor_a.shape,
                })
                .expect("upload a");
            let handle_b = provider
                .upload(&HostTensorView {
                    data: &tensor_b.data,
                    shape: &tensor_b.shape,
                })
                .expect("upload b");

            let eval = evaluate(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("setdiff");

            let values = expect_tensor(eval.values_value());
            assert_eq!(values.data, vec![10.0]);

            let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
            assert_eq!(ia.data, vec![1.0]);
        });
    }

    // The wgpu-backed path must agree with the CPU path on values, indices and shapes.
    #[cfg(feature = "wgpu")]
    #[test]
    fn setdiff_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = Tensor::new(vec![8.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![2.0, 5.0], vec![2, 1]).unwrap();

        let cpu_eval =
            evaluate(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[]).expect("setdiff");
        let cpu_values = tensor::value_into_tensor_for("setdiff", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("setdiff", cpu_eval.ia_value()).unwrap();

        let provider = runmat_accelerate_api::provider().expect("provider");
        let handle_a = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let handle_b = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");

        let gpu_eval =
            evaluate(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[]).expect("setdiff");
        let gpu_values = tensor::value_into_tensor_for("setdiff", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("setdiff", gpu_eval.ia_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_values.shape, cpu_values.shape);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ia.shape, cpu_ia.shape);
    }

    #[cfg(feature = "doc_export")]
    #[test]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}