runmat_runtime/builtins/array/sorting_sets/
setdiff.rs

1//! MATLAB-compatible `setdiff` builtin with GPU-aware semantics for RunMat.
2//!
3//! Provides element-wise and row-wise set difference with optional stable
4//! ordering. GPU tensors are gathered to host memory today, but the builtin is
5//! registered as a residency sink so future providers can implement device-side
6//! kernels without impacting behaviour.
7
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::{
12    GpuTensorHandle, HostTensorOwned, SetdiffOptions, SetdiffOrder, SetdiffResult,
13};
14use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
15use runmat_macros::runtime_builtin;
16
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::complex_tensor_into_value;
19use crate::builtins::common::spec::{
20    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24#[cfg(feature = "doc_export")]
25use crate::register_builtin_doc_text;
26use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
27
28#[cfg(feature = "doc_export")]
29pub const DOC_MD: &str = r#"---
30title: "setdiff"
31category: "array/sorting_sets"
32keywords: ["setdiff", "difference", "stable", "rows", "indices", "gpu"]
33summary: "Return values that appear in the first input but not the second, matching MATLAB ordering rules."
34references:
35  - https://www.mathworks.com/help/matlab/ref/setdiff.html
36gpu_support:
37  elementwise: false
38  reduction: false
39  precisions: ["f32", "f64"]
40  broadcasting: "none"
41  notes: "When providers lack a dedicated `setdiff` hook, RunMat gathers GPU tensors to host memory and reuses the CPU path."
42fusion:
43  elementwise: false
44  reduction: false
45  max_inputs: 2
46  constants: "inline"
47requires_feature: null
48tested:
49  unit: "builtins::array::sorting_sets::setdiff::tests"
50  integration: "builtins::array::sorting_sets::setdiff::tests::setdiff_gpu_roundtrip"
51---
52
53# What does the `setdiff` function do in MATLAB / RunMat?
54`setdiff(A, B)` returns the set of values (or rows) that appear in `A` but not in `B`. Results are
55unique, and the function can operate in sorted or stable order as well as row mode.
56
57## How does the `setdiff` function behave in MATLAB / RunMat?
58- `setdiff(A, B)` flattens inputs column-major, removes duplicates, subtracts the values of `B` from `A`,
59  and returns the remaining elements **sorted** ascending by default.
60- `[C, IA] = setdiff(A, B)` also returns indices so that `C = A(IA)`.
61- `setdiff(A, B, 'stable')` preserves the first appearance order from `A` instead of sorting.
62- `setdiff(A, B, 'rows')` treats each row as an element. Inputs must share the same number of columns.
63- Character arrays, string arrays, logical arrays, numeric types, and complex values are all supported.
64- Legacy flags (`'legacy'`, `'R2012a'`) are not supported; RunMat always follows modern MATLAB semantics.
65
66## `setdiff` Function GPU Execution Behaviour
67`setdiff` is registered as a residency sink. When tensors reside on the GPU and the active provider
68does not yet implement a `setdiff` hook, RunMat gathers them to host memory, performs the CPU
69implementation, and materialises host-resident results. Future providers can wire a custom hook to
70perform the set difference directly on-device without affecting existing callers.
71
72## Examples of using the `setdiff` function in MATLAB / RunMat
73
74### Finding values exclusive to the first numeric vector
75```matlab
76A = [5 7 5 1];
77B = [7 1 3];
78[C, IA] = setdiff(A, B);
79```
80Expected output:
81```matlab
82C =
83     5
84IA =
85     1
86```
87
88### Preserving input order with `'stable'`
89```matlab
90A = [4 2 4 1 3];
91B = [3 4 5 1];
92[C, IA] = setdiff(A, B, 'stable');
93```
94Expected output:
95```matlab
96C =
97     2
98IA =
99     2
100```
101
102### Working with rows of numeric matrices
103```matlab
104A = [1 2; 3 4; 1 2];
105B = [3 4; 5 6];
106[C, IA] = setdiff(A, B, 'rows');
107```
108Expected output:
109```matlab
110C =
111     1     2
112IA =
113     1
114```
115
116### Computing set difference for character data
117```matlab
118A = ['m','z'; 'm','a'];
119B = ['a','x'; 'm','a'];
120[C, IA] = setdiff(A, B);
121```
122Expected output:
123```matlab
C =
    z
IA =
     3
128```
129
130### Subtracting string arrays by row
131```matlab
132A = ["alpha" "beta"; "gamma" "beta"];
133B = ["gamma" "beta"; "delta" "beta"];
134[C, IA] = setdiff(A, B, 'rows', 'stable');
135```
136Expected output:
137```matlab
138C =
139  1x2 string array
140    "alpha"    "beta"
141IA =
142     1
143```
144
145### Using `setdiff` with GPU arrays
146```matlab
147G = gpuArray([10 4 6 4]);
148H = gpuArray([6 4 2]);
149C = setdiff(G, H);
150```
151RunMat gathers `G` and `H` to the host (until providers implement a GPU hook) and returns:
152```matlab
153C =
154    10
155```
156
157## FAQ
158
159### What ordering does `setdiff` use by default?
160Results are sorted ascending. Specify `'stable'` to preserve the first appearance order from the first input.
161
162### How are the index outputs defined?
163`IA` points to the positions in `A` that correspond to each element (or row) returned in `C`, using MATLAB's one-based indexing.
164
165### Can I combine `'rows'` with `'stable'`?
166Yes. `'rows'` can be paired with either `'sorted'` (default) or `'stable'`. Other option combinations that conflict (e.g. `'sorted'` with `'stable'`) are rejected.
167
168### Does `setdiff` remove `NaN` values from `A` when they exist in `B`?
169Yes. `NaN` values are considered equal. If `B` contains `NaN`, all `NaN` entries from `A` are removed.
170
171### Are complex numbers supported?
172Absolutely. Complex values use MATLAB's ordering rules (magnitude, then real part, then imaginary part) for the sorted output.
173
174### Does GPU execution change the results?
175No. Until providers supply a device implementation, RunMat gathers GPU inputs and executes the CPU path to guarantee MATLAB-compatible behaviour.
176
177### What happens if the inputs have different classes?
178RunMat follows MATLAB's rules: both inputs must share the same class (numeric/logical, complex, char, or string). Mixed-class inputs raise descriptive errors.
179
180### Can I request `'legacy'` behaviour?
181No. RunMat implements the modern semantics only. Passing `'legacy'` or `'R2012a'` results in an error.
182
183## See Also
184[unique](./unique), [union](./union), [intersect](./intersect), [ismember](./ismember), [gpuArray](../../acceleration/gpu/gpuArray), [gather](../../acceleration/gpu/gather)
185
186## Source & Feedback
187- Implementation: `crates/runmat-runtime/src/builtins/array/sorting_sets/setdiff.rs`
188- Issues / feedback: https://github.com/runmat-org/runmat/issues/new/choose
189"#;
190
/// GPU metadata for the `setdiff` builtin.
///
/// Declares a custom `"setdiff"` op kind and provider hook, F32/F64 precision
/// support, no broadcasting, and a `GatherImmediately` residency policy. As the
/// `notes` field states, tensors are gathered and processed on the host unless a
/// provider implements `setdiff`.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "setdiff",
    op_kind: GpuOpKind::Custom("setdiff"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("setdiff")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may implement `setdiff`; until then tensors are gathered and processed on the host.",
};

// Make the GPU spec discoverable through the builtin registry macro.
register_builtin_gpu_spec!(GPU_SPEC);
207
/// Fusion metadata for the `setdiff` builtin.
///
/// No elementwise or reduction fusion template is provided (`None` for both);
/// per the `notes` field, `setdiff` materialises its inputs and terminates
/// fusion chains.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "setdiff",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`setdiff` materialises its inputs and terminates fusion chains; upstream GPU tensors are gathered if needed.",
};

// Make the fusion spec discoverable through the builtin registry macro.
register_builtin_fusion_spec!(FUSION_SPEC);

// The Markdown documentation is only registered when doc export is enabled.
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("setdiff", DOC_MD);
222
223#[runtime_builtin(
224    name = "setdiff",
225    category = "array/sorting_sets",
226    summary = "Return the values that appear in the first input but not the second.",
227    keywords = "setdiff,difference,stable,rows,indices,gpu",
228    accel = "array_construct",
229    sink = true
230)]
231fn setdiff_builtin(a: Value, b: Value, rest: Vec<Value>) -> Result<Value, String> {
232    evaluate(a, b, &rest).map(|eval| eval.into_values_value())
233}
234
235/// Evaluate `setdiff` once and expose all outputs to the caller.
236pub fn evaluate(a: Value, b: Value, rest: &[Value]) -> Result<SetdiffEvaluation, String> {
237    let opts = parse_options(rest)?;
238    match (a, b) {
239        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
240            setdiff_gpu_pair(handle_a, handle_b, &opts)
241        }
242        (Value::GpuTensor(handle_a), other) => setdiff_gpu_mixed(handle_a, other, &opts, true),
243        (other, Value::GpuTensor(handle_b)) => setdiff_gpu_mixed(handle_b, other, &opts, false),
244        (left, right) => setdiff_host(left, right, &opts),
245    }
246}
247
248fn parse_options(rest: &[Value]) -> Result<SetdiffOptions, String> {
249    let mut opts = SetdiffOptions {
250        rows: false,
251        order: SetdiffOrder::Sorted,
252    };
253    let mut seen_order: Option<SetdiffOrder> = None;
254
255    for arg in rest {
256        let text = tensor::value_to_string(arg)
257            .ok_or_else(|| "setdiff: expected string option arguments".to_string())?;
258        let lowered = text.trim().to_ascii_lowercase();
259        match lowered.as_str() {
260            "rows" => opts.rows = true,
261            "sorted" => {
262                if let Some(prev) = seen_order {
263                    if prev != SetdiffOrder::Sorted {
264                        return Err("setdiff: cannot combine 'sorted' with 'stable'".to_string());
265                    }
266                }
267                seen_order = Some(SetdiffOrder::Sorted);
268                opts.order = SetdiffOrder::Sorted;
269            }
270            "stable" => {
271                if let Some(prev) = seen_order {
272                    if prev != SetdiffOrder::Stable {
273                        return Err("setdiff: cannot combine 'sorted' with 'stable'".to_string());
274                    }
275                }
276                seen_order = Some(SetdiffOrder::Stable);
277                opts.order = SetdiffOrder::Stable;
278            }
279            "legacy" | "r2012a" => {
280                return Err("setdiff: the 'legacy' behaviour is not supported".to_string());
281            }
282            other => return Err(format!("setdiff: unrecognised option '{other}'")),
283        }
284    }
285
286    Ok(opts)
287}
288
289fn setdiff_gpu_pair(
290    handle_a: GpuTensorHandle,
291    handle_b: GpuTensorHandle,
292    opts: &SetdiffOptions,
293) -> Result<SetdiffEvaluation, String> {
294    if let Some(provider) = runmat_accelerate_api::provider() {
295        match provider.setdiff(&handle_a, &handle_b, opts) {
296            Ok(result) => return SetdiffEvaluation::from_setdiff_result(result),
297            Err(_) => {
298                // Fall back to host gather when provider does not support setdiff.
299            }
300        }
301    }
302    let a_tensor = gpu_helpers::gather_tensor(&handle_a)?;
303    let b_tensor = gpu_helpers::gather_tensor(&handle_b)?;
304    setdiff_numeric(a_tensor, b_tensor, opts)
305}
306
307fn setdiff_gpu_mixed(
308    handle_gpu: GpuTensorHandle,
309    other: Value,
310    opts: &SetdiffOptions,
311    gpu_is_a: bool,
312) -> Result<SetdiffEvaluation, String> {
313    let gpu_tensor = gpu_helpers::gather_tensor(&handle_gpu)?;
314    let other_tensor = tensor::value_into_tensor_for("setdiff", other)?;
315    if gpu_is_a {
316        setdiff_numeric(gpu_tensor, other_tensor, opts)
317    } else {
318        setdiff_numeric(other_tensor, gpu_tensor, opts)
319    }
320}
321
322fn setdiff_host(a: Value, b: Value, opts: &SetdiffOptions) -> Result<SetdiffEvaluation, String> {
323    match (a, b) {
324        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => setdiff_complex(at, bt, opts),
325        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
326            let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
327                .map_err(|e| format!("setdiff: {e}"))?;
328            setdiff_complex(at, bt, opts)
329        }
330        (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
331            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
332                .map_err(|e| format!("setdiff: {e}"))?;
333            setdiff_complex(at, bt, opts)
334        }
335        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
336            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
337                .map_err(|e| format!("setdiff: {e}"))?;
338            let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
339                .map_err(|e| format!("setdiff: {e}"))?;
340            setdiff_complex(at, bt, opts)
341        }
342
343        (Value::CharArray(ac), Value::CharArray(bc)) => setdiff_char(ac, bc, opts),
344
345        (Value::StringArray(astring), Value::StringArray(bstring)) => {
346            setdiff_string(astring, bstring, opts)
347        }
348        (Value::StringArray(astring), Value::String(b)) => {
349            let bstring =
350                StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("setdiff: {e}"))?;
351            setdiff_string(astring, bstring, opts)
352        }
353        (Value::String(a), Value::StringArray(bstring)) => {
354            let astring =
355                StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("setdiff: {e}"))?;
356            setdiff_string(astring, bstring, opts)
357        }
358        (Value::String(a), Value::String(b)) => {
359            let astring =
360                StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("setdiff: {e}"))?;
361            let bstring =
362                StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("setdiff: {e}"))?;
363            setdiff_string(astring, bstring, opts)
364        }
365
366        (left, right) => {
367            let tensor_a = tensor::value_into_tensor_for("setdiff", left)?;
368            let tensor_b = tensor::value_into_tensor_for("setdiff", right)?;
369            setdiff_numeric(tensor_a, tensor_b, opts)
370        }
371    }
372}
373
374fn setdiff_numeric(
375    a: Tensor,
376    b: Tensor,
377    opts: &SetdiffOptions,
378) -> Result<SetdiffEvaluation, String> {
379    if opts.rows {
380        setdiff_numeric_rows(a, b, opts)
381    } else {
382        setdiff_numeric_elements(a, b, opts)
383    }
384}
385
386/// Helper exposed for acceleration providers handling numeric tensors entirely on the host.
387pub fn setdiff_numeric_from_tensors(
388    a: Tensor,
389    b: Tensor,
390    opts: &SetdiffOptions,
391) -> Result<SetdiffEvaluation, String> {
392    setdiff_numeric(a, b, opts)
393}
394
395fn setdiff_numeric_elements(
396    a: Tensor,
397    b: Tensor,
398    opts: &SetdiffOptions,
399) -> Result<SetdiffEvaluation, String> {
400    let mut b_keys: HashSet<u64> = HashSet::new();
401    for &value in &b.data {
402        b_keys.insert(canonicalize_f64(value));
403    }
404
405    let mut seen: HashMap<u64, usize> = HashMap::new();
406    let mut entries = Vec::<NumericDiffEntry>::new();
407    let mut order_counter = 0usize;
408
409    for (idx, &value) in a.data.iter().enumerate() {
410        let key = canonicalize_f64(value);
411        if b_keys.contains(&key) {
412            continue;
413        }
414        if seen.contains_key(&key) {
415            continue;
416        }
417        let entry_idx = entries.len();
418        entries.push(NumericDiffEntry {
419            value,
420            index: idx,
421            order_rank: order_counter,
422        });
423        seen.insert(key, entry_idx);
424        order_counter += 1;
425    }
426
427    assemble_numeric_setdiff(entries, opts)
428}
429
430fn setdiff_numeric_rows(
431    a: Tensor,
432    b: Tensor,
433    opts: &SetdiffOptions,
434) -> Result<SetdiffEvaluation, String> {
435    if a.shape.len() != 2 || b.shape.len() != 2 {
436        return Err("setdiff: 'rows' option requires 2-D numeric matrices".to_string());
437    }
438    if a.shape[1] != b.shape[1] {
439        return Err(
440            "setdiff: inputs must have the same number of columns when using 'rows'".to_string(),
441        );
442    }
443
444    let rows_a = a.shape[0];
445    let rows_b = b.shape[0];
446    let cols = a.shape[1];
447
448    let mut b_keys: HashSet<NumericRowKey> = HashSet::new();
449    for r in 0..rows_b {
450        let mut row_values = Vec::with_capacity(cols);
451        for c in 0..cols {
452            let idx = r + c * rows_b;
453            row_values.push(b.data[idx]);
454        }
455        b_keys.insert(NumericRowKey::from_slice(&row_values));
456    }
457
458    let mut seen: HashSet<NumericRowKey> = HashSet::new();
459    let mut entries = Vec::<NumericRowDiffEntry>::new();
460    let mut order_counter = 0usize;
461
462    for r in 0..rows_a {
463        let mut row_values = Vec::with_capacity(cols);
464        for c in 0..cols {
465            let idx = r + c * rows_a;
466            row_values.push(a.data[idx]);
467        }
468        let key = NumericRowKey::from_slice(&row_values);
469        if b_keys.contains(&key) {
470            continue;
471        }
472        if !seen.insert(key) {
473            continue;
474        }
475        entries.push(NumericRowDiffEntry {
476            row_data: row_values,
477            row_index: r,
478            order_rank: order_counter,
479        });
480        order_counter += 1;
481    }
482
483    assemble_numeric_row_setdiff(entries, opts, cols)
484}
485
486fn setdiff_complex(
487    a: ComplexTensor,
488    b: ComplexTensor,
489    opts: &SetdiffOptions,
490) -> Result<SetdiffEvaluation, String> {
491    if opts.rows {
492        setdiff_complex_rows(a, b, opts)
493    } else {
494        setdiff_complex_elements(a, b, opts)
495    }
496}
497
498fn setdiff_complex_elements(
499    a: ComplexTensor,
500    b: ComplexTensor,
501    opts: &SetdiffOptions,
502) -> Result<SetdiffEvaluation, String> {
503    let mut b_keys: HashSet<ComplexKey> = HashSet::new();
504    for &value in &b.data {
505        b_keys.insert(ComplexKey::new(value));
506    }
507
508    let mut seen: HashSet<ComplexKey> = HashSet::new();
509    let mut entries = Vec::<ComplexDiffEntry>::new();
510    let mut order_counter = 0usize;
511
512    for (idx, &value) in a.data.iter().enumerate() {
513        let key = ComplexKey::new(value);
514        if b_keys.contains(&key) {
515            continue;
516        }
517        if !seen.insert(key) {
518            continue;
519        }
520        entries.push(ComplexDiffEntry {
521            value,
522            index: idx,
523            order_rank: order_counter,
524        });
525        order_counter += 1;
526    }
527
528    assemble_complex_setdiff(entries, opts)
529}
530
531fn setdiff_complex_rows(
532    a: ComplexTensor,
533    b: ComplexTensor,
534    opts: &SetdiffOptions,
535) -> Result<SetdiffEvaluation, String> {
536    if a.shape.len() != 2 || b.shape.len() != 2 {
537        return Err("setdiff: 'rows' option requires 2-D complex matrices".to_string());
538    }
539    if a.shape[1] != b.shape[1] {
540        return Err(
541            "setdiff: inputs must have the same number of columns when using 'rows'".to_string(),
542        );
543    }
544
545    let rows_a = a.shape[0];
546    let rows_b = b.shape[0];
547    let cols = a.shape[1];
548
549    let mut b_keys: HashSet<Vec<ComplexKey>> = HashSet::new();
550    for r in 0..rows_b {
551        let mut key_row = Vec::with_capacity(cols);
552        for c in 0..cols {
553            let idx = r + c * rows_b;
554            key_row.push(ComplexKey::new(b.data[idx]));
555        }
556        b_keys.insert(key_row);
557    }
558
559    let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
560    let mut entries = Vec::<ComplexRowDiffEntry>::new();
561    let mut order_counter = 0usize;
562
563    for r in 0..rows_a {
564        let mut row_values = Vec::with_capacity(cols);
565        let mut key_row = Vec::with_capacity(cols);
566        for c in 0..cols {
567            let idx = r + c * rows_a;
568            let value = a.data[idx];
569            row_values.push(value);
570            key_row.push(ComplexKey::new(value));
571        }
572        if b_keys.contains(&key_row) {
573            continue;
574        }
575        if !seen.insert(key_row) {
576            continue;
577        }
578        entries.push(ComplexRowDiffEntry {
579            row_data: row_values,
580            row_index: r,
581            order_rank: order_counter,
582        });
583        order_counter += 1;
584    }
585
586    assemble_complex_row_setdiff(entries, opts, cols)
587}
588
589fn setdiff_char(
590    a: CharArray,
591    b: CharArray,
592    opts: &SetdiffOptions,
593) -> Result<SetdiffEvaluation, String> {
594    if opts.rows {
595        setdiff_char_rows(a, b, opts)
596    } else {
597        setdiff_char_elements(a, b, opts)
598    }
599}
600
601fn setdiff_char_elements(
602    a: CharArray,
603    b: CharArray,
604    opts: &SetdiffOptions,
605) -> Result<SetdiffEvaluation, String> {
606    let mut b_keys: HashSet<u32> = HashSet::new();
607    for ch in &b.data {
608        b_keys.insert(*ch as u32);
609    }
610
611    let mut seen: HashSet<u32> = HashSet::new();
612    let mut entries = Vec::<CharDiffEntry>::new();
613    let mut order_counter = 0usize;
614
615    for col in 0..a.cols {
616        for row in 0..a.rows {
617            let linear_idx = row + col * a.rows;
618            let data_idx = row * a.cols + col;
619            let ch = a.data[data_idx];
620            let key = ch as u32;
621            if b_keys.contains(&key) {
622                continue;
623            }
624            if !seen.insert(key) {
625                continue;
626            }
627            entries.push(CharDiffEntry {
628                ch,
629                index: linear_idx,
630                order_rank: order_counter,
631            });
632            order_counter += 1;
633        }
634    }
635
636    assemble_char_setdiff(entries, opts)
637}
638
639fn setdiff_char_rows(
640    a: CharArray,
641    b: CharArray,
642    opts: &SetdiffOptions,
643) -> Result<SetdiffEvaluation, String> {
644    if a.cols != b.cols {
645        return Err(
646            "setdiff: inputs must have the same number of columns when using 'rows'".to_string(),
647        );
648    }
649
650    let rows_a = a.rows;
651    let rows_b = b.rows;
652    let cols = a.cols;
653
654    let mut b_keys: HashSet<RowCharKey> = HashSet::new();
655    for r in 0..rows_b {
656        let mut row_values = Vec::with_capacity(cols);
657        for c in 0..cols {
658            let idx = r * cols + c;
659            row_values.push(b.data[idx]);
660        }
661        b_keys.insert(RowCharKey::from_slice(&row_values));
662    }
663
664    let mut seen: HashSet<RowCharKey> = HashSet::new();
665    let mut entries = Vec::<CharRowDiffEntry>::new();
666    let mut order_counter = 0usize;
667
668    for r in 0..rows_a {
669        let mut row_values = Vec::with_capacity(cols);
670        for c in 0..cols {
671            let idx = r * cols + c;
672            row_values.push(a.data[idx]);
673        }
674        let key = RowCharKey::from_slice(&row_values);
675        if b_keys.contains(&key) {
676            continue;
677        }
678        if !seen.insert(key) {
679            continue;
680        }
681        entries.push(CharRowDiffEntry {
682            row_data: row_values,
683            row_index: r,
684            order_rank: order_counter,
685        });
686        order_counter += 1;
687    }
688
689    assemble_char_row_setdiff(entries, opts, cols)
690}
691
692fn setdiff_string(
693    a: StringArray,
694    b: StringArray,
695    opts: &SetdiffOptions,
696) -> Result<SetdiffEvaluation, String> {
697    if opts.rows {
698        setdiff_string_rows(a, b, opts)
699    } else {
700        setdiff_string_elements(a, b, opts)
701    }
702}
703
704fn setdiff_string_elements(
705    a: StringArray,
706    b: StringArray,
707    opts: &SetdiffOptions,
708) -> Result<SetdiffEvaluation, String> {
709    let mut b_keys: HashSet<String> = HashSet::new();
710    for value in &b.data {
711        b_keys.insert(value.clone());
712    }
713
714    let mut seen: HashSet<String> = HashSet::new();
715    let mut entries = Vec::<StringDiffEntry>::new();
716    let mut order_counter = 0usize;
717
718    for (idx, value) in a.data.iter().enumerate() {
719        if b_keys.contains(value) {
720            continue;
721        }
722        if !seen.insert(value.clone()) {
723            continue;
724        }
725        entries.push(StringDiffEntry {
726            value: value.clone(),
727            index: idx,
728            order_rank: order_counter,
729        });
730        order_counter += 1;
731    }
732
733    assemble_string_setdiff(entries, opts)
734}
735
736fn setdiff_string_rows(
737    a: StringArray,
738    b: StringArray,
739    opts: &SetdiffOptions,
740) -> Result<SetdiffEvaluation, String> {
741    if a.shape.len() != 2 || b.shape.len() != 2 {
742        return Err("setdiff: 'rows' option requires 2-D string arrays".to_string());
743    }
744    if a.shape[1] != b.shape[1] {
745        return Err(
746            "setdiff: inputs must have the same number of columns when using 'rows'".to_string(),
747        );
748    }
749
750    let rows_a = a.shape[0];
751    let rows_b = b.shape[0];
752    let cols = a.shape[1];
753
754    let mut b_keys: HashSet<RowStringKey> = HashSet::new();
755    for r in 0..rows_b {
756        let mut row_values = Vec::with_capacity(cols);
757        for c in 0..cols {
758            let idx = r + c * rows_b;
759            row_values.push(b.data[idx].clone());
760        }
761        b_keys.insert(RowStringKey(row_values.clone()));
762    }
763
764    let mut seen: HashSet<RowStringKey> = HashSet::new();
765    let mut entries = Vec::<StringRowDiffEntry>::new();
766    let mut order_counter = 0usize;
767
768    for r in 0..rows_a {
769        let mut row_values = Vec::with_capacity(cols);
770        for c in 0..cols {
771            let idx = r + c * rows_a;
772            row_values.push(a.data[idx].clone());
773        }
774        let key = RowStringKey(row_values.clone());
775        if b_keys.contains(&key) {
776            continue;
777        }
778        if !seen.insert(key) {
779            continue;
780        }
781        entries.push(StringRowDiffEntry {
782            row_data: row_values,
783            row_index: r,
784            order_rank: order_counter,
785        });
786        order_counter += 1;
787    }
788
789    assemble_string_row_setdiff(entries, opts, cols)
790}
791
792fn assemble_numeric_setdiff(
793    entries: Vec<NumericDiffEntry>,
794    opts: &SetdiffOptions,
795) -> Result<SetdiffEvaluation, String> {
796    let mut order: Vec<usize> = (0..entries.len()).collect();
797    match opts.order {
798        SetdiffOrder::Sorted => {
799            order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
800        }
801        SetdiffOrder::Stable => {
802            order.sort_by_key(|&idx| entries[idx].order_rank);
803        }
804    }
805
806    let mut values = Vec::with_capacity(order.len());
807    let mut ia = Vec::with_capacity(order.len());
808    for &idx in &order {
809        let entry = &entries[idx];
810        values.push(entry.value);
811        ia.push((entry.index + 1) as f64);
812    }
813
814    let value_tensor =
815        Tensor::new(values, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
816    let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
817
818    Ok(SetdiffEvaluation::new(
819        Value::Tensor(value_tensor),
820        ia_tensor,
821    ))
822}
823
824fn assemble_numeric_row_setdiff(
825    entries: Vec<NumericRowDiffEntry>,
826    opts: &SetdiffOptions,
827    cols: usize,
828) -> Result<SetdiffEvaluation, String> {
829    let mut order: Vec<usize> = (0..entries.len()).collect();
830    match opts.order {
831        SetdiffOrder::Sorted => {
832            order.sort_by(|&lhs, &rhs| {
833                compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
834            });
835        }
836        SetdiffOrder::Stable => {
837            order.sort_by_key(|&idx| entries[idx].order_rank);
838        }
839    }
840
841    let unique_rows = order.len();
842    let mut values = vec![0.0f64; unique_rows * cols];
843    let mut ia = Vec::with_capacity(unique_rows);
844
845    for (row_pos, &entry_idx) in order.iter().enumerate() {
846        let entry = &entries[entry_idx];
847        for col in 0..cols {
848            let dest = row_pos + col * unique_rows;
849            values[dest] = entry.row_data[col];
850        }
851        ia.push((entry.row_index + 1) as f64);
852    }
853
854    let value_tensor =
855        Tensor::new(values, vec![unique_rows, cols]).map_err(|e| format!("setdiff: {e}"))?;
856    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1]).map_err(|e| format!("setdiff: {e}"))?;
857
858    Ok(SetdiffEvaluation::new(
859        Value::Tensor(value_tensor),
860        ia_tensor,
861    ))
862}
863
864fn assemble_complex_setdiff(
865    entries: Vec<ComplexDiffEntry>,
866    opts: &SetdiffOptions,
867) -> Result<SetdiffEvaluation, String> {
868    let mut order: Vec<usize> = (0..entries.len()).collect();
869    match opts.order {
870        SetdiffOrder::Sorted => {
871            order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
872        }
873        SetdiffOrder::Stable => {
874            order.sort_by_key(|&idx| entries[idx].order_rank);
875        }
876    }
877
878    let mut values = Vec::with_capacity(order.len());
879    let mut ia = Vec::with_capacity(order.len());
880    for &idx in &order {
881        let entry = &entries[idx];
882        values.push(entry.value);
883        ia.push((entry.index + 1) as f64);
884    }
885
886    let value_tensor =
887        ComplexTensor::new(values, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
888    let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
889
890    Ok(SetdiffEvaluation::new(
891        complex_tensor_into_value(value_tensor),
892        ia_tensor,
893    ))
894}
895
896fn assemble_complex_row_setdiff(
897    entries: Vec<ComplexRowDiffEntry>,
898    opts: &SetdiffOptions,
899    cols: usize,
900) -> Result<SetdiffEvaluation, String> {
901    let mut order: Vec<usize> = (0..entries.len()).collect();
902    match opts.order {
903        SetdiffOrder::Sorted => {
904            order.sort_by(|&lhs, &rhs| {
905                compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
906            });
907        }
908        SetdiffOrder::Stable => {
909            order.sort_by_key(|&idx| entries[idx].order_rank);
910        }
911    }
912
913    let unique_rows = order.len();
914    let mut values = vec![(0.0f64, 0.0f64); unique_rows * cols];
915    let mut ia = Vec::with_capacity(unique_rows);
916
917    for (row_pos, &entry_idx) in order.iter().enumerate() {
918        let entry = &entries[entry_idx];
919        for col in 0..cols {
920            let dest = row_pos + col * unique_rows;
921            values[dest] = entry.row_data[col];
922        }
923        ia.push((entry.row_index + 1) as f64);
924    }
925
926    let value_tensor =
927        ComplexTensor::new(values, vec![unique_rows, cols]).map_err(|e| format!("setdiff: {e}"))?;
928    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1]).map_err(|e| format!("setdiff: {e}"))?;
929
930    Ok(SetdiffEvaluation::new(
931        complex_tensor_into_value(value_tensor),
932        ia_tensor,
933    ))
934}
935
936fn assemble_char_setdiff(
937    entries: Vec<CharDiffEntry>,
938    opts: &SetdiffOptions,
939) -> Result<SetdiffEvaluation, String> {
940    let mut order: Vec<usize> = (0..entries.len()).collect();
941    match opts.order {
942        SetdiffOrder::Sorted => {
943            order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
944        }
945        SetdiffOrder::Stable => {
946            order.sort_by_key(|&idx| entries[idx].order_rank);
947        }
948    }
949
950    let mut values = Vec::with_capacity(order.len());
951    let mut ia = Vec::with_capacity(order.len());
952    for &idx in &order {
953        let entry = &entries[idx];
954        values.push(entry.ch);
955        ia.push((entry.index + 1) as f64);
956    }
957
958    let value_array =
959        CharArray::new(values, order.len(), 1).map_err(|e| format!("setdiff: {e}"))?;
960    let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
961
962    Ok(SetdiffEvaluation::new(
963        Value::CharArray(value_array),
964        ia_tensor,
965    ))
966}
967
968fn assemble_char_row_setdiff(
969    entries: Vec<CharRowDiffEntry>,
970    opts: &SetdiffOptions,
971    cols: usize,
972) -> Result<SetdiffEvaluation, String> {
973    let mut order: Vec<usize> = (0..entries.len()).collect();
974    match opts.order {
975        SetdiffOrder::Sorted => {
976            order.sort_by(|&lhs, &rhs| {
977                compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
978            });
979        }
980        SetdiffOrder::Stable => {
981            order.sort_by_key(|&idx| entries[idx].order_rank);
982        }
983    }
984
985    let unique_rows = order.len();
986    let mut values = vec!['\0'; unique_rows * cols];
987    let mut ia = Vec::with_capacity(unique_rows);
988
989    for (row_pos, &entry_idx) in order.iter().enumerate() {
990        let entry = &entries[entry_idx];
991        for col in 0..cols {
992            let dest = row_pos * cols + col;
993            values[dest] = entry.row_data[col];
994        }
995        ia.push((entry.row_index + 1) as f64);
996    }
997
998    let value_array =
999        CharArray::new(values, unique_rows, cols).map_err(|e| format!("setdiff: {e}"))?;
1000    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1]).map_err(|e| format!("setdiff: {e}"))?;
1001
1002    Ok(SetdiffEvaluation::new(
1003        Value::CharArray(value_array),
1004        ia_tensor,
1005    ))
1006}
1007
1008fn assemble_string_setdiff(
1009    entries: Vec<StringDiffEntry>,
1010    opts: &SetdiffOptions,
1011) -> Result<SetdiffEvaluation, String> {
1012    let mut order: Vec<usize> = (0..entries.len()).collect();
1013    match opts.order {
1014        SetdiffOrder::Sorted => {
1015            order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1016        }
1017        SetdiffOrder::Stable => {
1018            order.sort_by_key(|&idx| entries[idx].order_rank);
1019        }
1020    }
1021
1022    let mut values = Vec::with_capacity(order.len());
1023    let mut ia = Vec::with_capacity(order.len());
1024    for &idx in &order {
1025        let entry = &entries[idx];
1026        values.push(entry.value.clone());
1027        ia.push((entry.index + 1) as f64);
1028    }
1029
1030    let value_array =
1031        StringArray::new(values, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
1032    let ia_tensor = Tensor::new(ia, vec![order.len(), 1]).map_err(|e| format!("setdiff: {e}"))?;
1033
1034    Ok(SetdiffEvaluation::new(
1035        Value::StringArray(value_array),
1036        ia_tensor,
1037    ))
1038}
1039
1040fn assemble_string_row_setdiff(
1041    entries: Vec<StringRowDiffEntry>,
1042    opts: &SetdiffOptions,
1043    cols: usize,
1044) -> Result<SetdiffEvaluation, String> {
1045    let mut order: Vec<usize> = (0..entries.len()).collect();
1046    match opts.order {
1047        SetdiffOrder::Sorted => {
1048            order.sort_by(|&lhs, &rhs| {
1049                compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1050            });
1051        }
1052        SetdiffOrder::Stable => {
1053            order.sort_by_key(|&idx| entries[idx].order_rank);
1054        }
1055    }
1056
1057    let unique_rows = order.len();
1058    let mut values = vec![String::new(); unique_rows * cols];
1059    let mut ia = Vec::with_capacity(unique_rows);
1060
1061    for (row_pos, &entry_idx) in order.iter().enumerate() {
1062        let entry = &entries[entry_idx];
1063        for col in 0..cols {
1064            let dest = row_pos + col * unique_rows;
1065            values[dest] = entry.row_data[col].clone();
1066        }
1067        ia.push((entry.row_index + 1) as f64);
1068    }
1069
1070    let value_array =
1071        StringArray::new(values, vec![unique_rows, cols]).map_err(|e| format!("setdiff: {e}"))?;
1072    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1]).map_err(|e| format!("setdiff: {e}"))?;
1073
1074    Ok(SetdiffEvaluation::new(
1075        Value::StringArray(value_array),
1076        ia_tensor,
1077    ))
1078}
1079
/// One candidate element of a real-valued `setdiff` result.
///
/// NOTE(review): the code that fills and consumes this type lies outside this
/// part of the file; the field notes assume it mirrors the char/string entry
/// types — confirm.
#[derive(Clone, Copy, Debug)]
struct NumericDiffEntry {
    /// The element value.
    value: f64,
    /// Presumably the zero-based source position, reported one-based in `ia`.
    index: usize,
    /// Presumably the sort key used for `SetdiffOrder::Stable` output.
    order_rank: usize,
}
1086
/// One candidate row of a real-valued `setdiff(..., 'rows')` result.
///
/// NOTE(review): the code that fills and consumes this type lies outside this
/// part of the file; the field notes assume it mirrors the char/string row
/// entry types — confirm.
#[derive(Clone, Debug)]
struct NumericRowDiffEntry {
    /// The elements of the row.
    row_data: Vec<f64>,
    /// Presumably the zero-based source row, reported one-based in `ia`.
    row_index: usize,
    /// Presumably the sort key used for `SetdiffOrder::Stable` output.
    order_rank: usize,
}
1093
/// One candidate element of a complex `setdiff` result.
///
/// NOTE(review): the assembler that consumes this type is only partly visible
/// here; the field notes assume it matches the other element-wise assemblers
/// — confirm.
#[derive(Clone, Copy, Debug)]
struct ComplexDiffEntry {
    /// The element as a `(re, im)` pair.
    value: (f64, f64),
    /// Presumably the zero-based source position, reported one-based in `ia`.
    index: usize,
    /// Presumably the sort key used for `SetdiffOrder::Stable` output.
    order_rank: usize,
}
1100
/// One candidate row of a complex `setdiff(..., 'rows')` result, consumed by
/// `assemble_complex_row_setdiff`.
#[derive(Clone, Debug)]
struct ComplexRowDiffEntry {
    /// The row's elements as `(re, im)` pairs; compared with
    /// `compare_complex_rows` for sorted output.
    row_data: Vec<(f64, f64)>,
    /// Row position; the assembler writes `row_index + 1` into `ia`.
    row_index: usize,
    /// Sort key used when `SetdiffOrder::Stable` is requested.
    order_rank: usize,
}
1107
/// One candidate character of a `setdiff` result, consumed by
/// `assemble_char_setdiff`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct CharDiffEntry {
    /// The character; sorted output orders entries by this value.
    ch: char,
    /// Element position; the assembler writes `index + 1` into `ia`.
    index: usize,
    /// Sort key used when `SetdiffOrder::Stable` is requested.
    order_rank: usize,
}
1114
/// One candidate row of a character `setdiff(..., 'rows')` result, consumed by
/// `assemble_char_row_setdiff`.
#[derive(Clone, Debug)]
struct CharRowDiffEntry {
    /// The row's characters; compared with `compare_char_rows` for sorted
    /// output.
    row_data: Vec<char>,
    /// Row position; the assembler writes `row_index + 1` into `ia`.
    row_index: usize,
    /// Sort key used when `SetdiffOrder::Stable` is requested.
    order_rank: usize,
}
1121
/// One candidate string of a `setdiff` result, consumed by
/// `assemble_string_setdiff`.
#[derive(Clone, Debug)]
struct StringDiffEntry {
    /// The string; sorted output orders entries by `String` comparison.
    value: String,
    /// Element position; the assembler writes `index + 1` into `ia`.
    index: usize,
    /// Sort key used when `SetdiffOrder::Stable` is requested.
    order_rank: usize,
}
1128
/// One candidate row of a string `setdiff(..., 'rows')` result, consumed by
/// `assemble_string_row_setdiff`.
#[derive(Clone, Debug)]
struct StringRowDiffEntry {
    /// The row's strings; compared with `compare_string_rows` for sorted
    /// output.
    row_data: Vec<String>,
    /// Row position; the assembler writes `row_index + 1` into `ia`.
    row_index: usize,
    /// Sort key used when `SetdiffOrder::Stable` is requested.
    order_rank: usize,
}
1135
1136#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1137struct NumericRowKey(Vec<u64>);
1138
1139impl NumericRowKey {
1140    fn from_slice(values: &[f64]) -> Self {
1141        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1142    }
1143}
1144
1145#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1146struct ComplexKey {
1147    re: u64,
1148    im: u64,
1149}
1150
1151impl ComplexKey {
1152    fn new(value: (f64, f64)) -> Self {
1153        Self {
1154            re: canonicalize_f64(value.0),
1155            im: canonicalize_f64(value.1),
1156        }
1157    }
1158}
1159
/// Hashable, equality-comparable identity of a row of characters, stored as
/// the characters' `u32` scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key for `values`, preserving element order.
    fn from_slice(values: &[char]) -> Self {
        let codes: Vec<u32> = values.iter().copied().map(u32::from).collect();
        Self(codes)
    }
}
1168
/// Hashable, equality-comparable identity of a row of strings.
///
/// NOTE(review): the code that builds and looks up these keys lies outside
/// this part of the file; presumably used like the other row key types —
/// confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1171
1172pub struct SetdiffEvaluation {
1173    values: Value,
1174    ia: Tensor,
1175}
1176
1177impl SetdiffEvaluation {
1178    fn new(values: Value, ia: Tensor) -> Self {
1179        Self { values, ia }
1180    }
1181
1182    pub fn from_setdiff_result(result: SetdiffResult) -> Result<Self, String> {
1183        let SetdiffResult { values, ia } = result;
1184        let values_tensor =
1185            Tensor::new(values.data, values.shape).map_err(|e| format!("setdiff: {e}"))?;
1186        let ia_tensor = Tensor::new(ia.data, ia.shape).map_err(|e| format!("setdiff: {e}"))?;
1187        Ok(SetdiffEvaluation::new(
1188            Value::Tensor(values_tensor),
1189            ia_tensor,
1190        ))
1191    }
1192
1193    pub fn into_numeric_setdiff_result(self) -> Result<SetdiffResult, String> {
1194        let SetdiffEvaluation { values, ia } = self;
1195        let values_tensor = tensor::value_into_tensor_for("setdiff", values)?;
1196        Ok(SetdiffResult {
1197            values: HostTensorOwned {
1198                data: values_tensor.data,
1199                shape: values_tensor.shape,
1200            },
1201            ia: HostTensorOwned {
1202                data: ia.data,
1203                shape: ia.shape,
1204            },
1205        })
1206    }
1207
1208    pub fn into_values_value(self) -> Value {
1209        self.values
1210    }
1211
1212    pub fn into_pair(self) -> (Value, Value) {
1213        let ia = tensor::tensor_into_value(self.ia);
1214        (self.values, ia)
1215    }
1216
1217    pub fn values_value(&self) -> Value {
1218        self.values.clone()
1219    }
1220
1221    pub fn ia_value(&self) -> Value {
1222        tensor::tensor_into_value(self.ia.clone())
1223    }
1224}
1225
/// Maps `value` to a bit pattern suitable for hashing and equality checks.
///
/// Every NaN collapses to a single quiet-NaN pattern and both signed zeros
/// collapse to `0`; any other value keeps its raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0`, so this arm covers both zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1235
/// Total ordering for `f64` in which NaN sorts after every other value.
///
/// Two NaNs compare equal; non-NaN values use the ordinary numeric comparison,
/// so `-0.0` and `0.0` compare equal.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1249
1250fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1251    for (lhs, rhs) in a.iter().zip(b.iter()) {
1252        let ord = compare_f64(*lhs, *rhs);
1253        if ord != Ordering::Equal {
1254            return ord;
1255        }
1256    }
1257    Ordering::Equal
1258}
1259
/// Reports whether either part of the `(re, im)` pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    [re, im].iter().any(|part| part.is_nan())
}
1263
1264fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1265    match (complex_is_nan(a), complex_is_nan(b)) {
1266        (true, true) => Ordering::Equal,
1267        (true, false) => Ordering::Greater,
1268        (false, true) => Ordering::Less,
1269        (false, false) => {
1270            let mag_a = a.0.hypot(a.1);
1271            let mag_b = b.0.hypot(b.1);
1272            let mag_cmp = compare_f64(mag_a, mag_b);
1273            if mag_cmp != Ordering::Equal {
1274                return mag_cmp;
1275            }
1276            let re_cmp = compare_f64(a.0, b.0);
1277            if re_cmp != Ordering::Equal {
1278                return re_cmp;
1279            }
1280            compare_f64(a.1, b.1)
1281        }
1282    }
1283}
1284
1285fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1286    for (lhs, rhs) in a.iter().zip(b.iter()) {
1287        let ord = compare_complex(*lhs, *rhs);
1288        if ord != Ordering::Equal {
1289            return ord;
1290        }
1291    }
1292    Ordering::Equal
1293}
1294
/// Lexicographic comparison of two character rows.
///
/// Only the positions present in both slices are examined; if they all match
/// the result is `Ordering::Equal`, whatever the lengths.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    // The first mismatching pair decides the ordering.
    match a.iter().zip(b).find(|(lhs, rhs)| lhs != rhs) {
        Some((lhs, rhs)) => lhs.cmp(rhs),
        None => Ordering::Equal,
    }
}
1304
/// Lexicographic comparison of two string rows, cell by cell.
///
/// Only the positions present in both slices are examined; if they all compare
/// equal the result is `Ordering::Equal`, whatever the lengths.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CharArray, StringArray, Tensor, Value};

    /// Builds an `n x 1` tensor from `data`.
    fn column(data: Vec<f64>) -> Tensor {
        let len = data.len();
        Tensor::new(data, vec![len, 1]).unwrap()
    }

    /// Builds a string array of the given shape from string literals.
    fn strings(cells: &[&str], shape: Vec<usize>) -> StringArray {
        let data: Vec<String> = cells.iter().map(|cell| cell.to_string()).collect();
        StringArray::new(data, shape).unwrap()
    }

    /// Unwraps a values output that must be a real tensor.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    /// Converts the `ia` output of `eval` into a host tensor.
    fn ia_of(eval: &SetdiffEvaluation) -> Tensor {
        tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor")
    }

    #[test]
    fn setdiff_numeric_sorted_default() {
        let lhs = column(vec![5.0, 7.0, 5.0, 1.0]);
        let rhs = column(vec![7.0, 1.0, 3.0]);
        let eval = evaluate(Value::Tensor(lhs), Value::Tensor(rhs), &[]).expect("setdiff");

        let values = expect_tensor(eval.values_value());
        assert_eq!(values.shape, vec![1, 1]);
        assert_eq!(values.data, vec![5.0]);
        assert_eq!(ia_of(&eval).data, vec![1.0]);
    }

    #[test]
    fn setdiff_numeric_stable() {
        let lhs = column(vec![4.0, 2.0, 4.0, 1.0, 3.0]);
        let rhs = column(vec![3.0, 4.0, 5.0, 1.0]);
        let options = [Value::from("stable")];
        let eval = evaluate(Value::Tensor(lhs), Value::Tensor(rhs), &options).expect("setdiff");

        let values = expect_tensor(eval.values_value());
        assert_eq!(values.shape, vec![1, 1]);
        assert_eq!(values.data, vec![2.0]);
        assert_eq!(ia_of(&eval).data, vec![2.0]);
    }

    #[test]
    fn setdiff_numeric_rows_sorted() {
        // Rows of `lhs` are [1 2], [3 4], [1 2]; rows of `rhs` are [3 4], [5 6].
        let lhs = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let rhs = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
        let options = [Value::from("rows")];
        let eval = evaluate(Value::Tensor(lhs), Value::Tensor(rhs), &options).expect("setdiff");

        let values = expect_tensor(eval.values_value());
        assert_eq!(values.shape, vec![1, 2]);
        assert_eq!(values.data, vec![1.0, 2.0]);
        assert_eq!(ia_of(&eval).data, vec![1.0]);
    }

    #[test]
    fn setdiff_numeric_removes_nan() {
        let lhs = column(vec![f64::NAN, 2.0, 3.0]);
        let rhs = column(vec![f64::NAN]);
        let eval = evaluate(Value::Tensor(lhs), Value::Tensor(rhs), &[]).expect("setdiff");

        let values = tensor::value_into_tensor_for("setdiff", eval.values_value()).expect("values");
        assert_eq!(values.data, vec![2.0, 3.0]);
        assert_eq!(ia_of(&eval).data, vec![2.0, 3.0]);
    }

    #[test]
    fn setdiff_char_elements() {
        let lhs = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
        let rhs = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
        let eval = evaluate(Value::CharArray(lhs), Value::CharArray(rhs), &[]).expect("setdiff");

        let arr = match eval.values_value() {
            Value::CharArray(arr) => arr,
            other => panic!("expected char array, got {other:?}"),
        };
        assert_eq!(arr.rows, 1);
        assert_eq!(arr.cols, 1);
        assert_eq!(arr.data, vec!['z']);
        assert_eq!(ia_of(&eval).data, vec![3.0]);
    }

    #[test]
    fn setdiff_string_rows_stable() {
        let lhs = strings(&["alpha", "gamma", "beta", "beta"], vec![2, 2]);
        let rhs = strings(&["gamma", "delta", "beta", "beta"], vec![2, 2]);
        let options = [Value::from("rows"), Value::from("stable")];
        let eval = evaluate(Value::StringArray(lhs), Value::StringArray(rhs), &options)
            .expect("setdiff");

        let arr = match eval.values_value() {
            Value::StringArray(arr) => arr,
            other => panic!("expected string array, got {other:?}"),
        };
        assert_eq!(arr.shape, vec![1, 2]);
        assert_eq!(arr.data, vec![String::from("alpha"), String::from("beta")]);
        assert_eq!(ia_of(&eval).data, vec![1.0]);
    }

    #[test]
    fn setdiff_type_mismatch_errors() {
        let outcome = evaluate(Value::from(1.0), Value::String("a".into()), &[]);
        assert!(outcome.is_err());
    }

    #[test]
    fn setdiff_rejects_legacy_option() {
        let options = [Value::from("legacy")];
        // `SetdiffEvaluation` has no `Debug` impl, so go through `err()`.
        let message = evaluate(Value::from(1.0), Value::from(2.0), &options)
            .err()
            .unwrap();
        assert!(message.contains("setdiff: the 'legacy' behaviour is not supported"));
    }

    #[test]
    fn setdiff_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor_a = column(vec![10.0, 4.0, 6.0, 4.0]);
            let tensor_b = column(vec![6.0, 4.0, 2.0]);
            let handle_a = provider
                .upload(&HostTensorView {
                    data: &tensor_a.data,
                    shape: &tensor_a.shape,
                })
                .expect("upload a");
            let handle_b = provider
                .upload(&HostTensorView {
                    data: &tensor_b.data,
                    shape: &tensor_b.shape,
                })
                .expect("upload b");

            let eval = evaluate(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("setdiff");
            let values = expect_tensor(eval.values_value());
            assert_eq!(values.data, vec![10.0]);
            assert_eq!(ia_of(&eval).data, vec![1.0]);
        });
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn setdiff_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = column(vec![8.0, 4.0, 2.0, 4.0]);
        let b = column(vec![2.0, 5.0]);
        let as_tensor = |value: Value| tensor::value_into_tensor_for("setdiff", value).unwrap();

        let cpu_eval =
            evaluate(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[]).expect("setdiff");
        let cpu_values = as_tensor(cpu_eval.values_value());
        let cpu_ia = as_tensor(cpu_eval.ia_value());

        let provider = runmat_accelerate_api::provider().expect("provider");
        let handle_a = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let handle_b = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");
        let gpu_eval =
            evaluate(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[]).expect("setdiff");
        let gpu_values = as_tensor(gpu_eval.values_value());
        let gpu_ia = as_tensor(gpu_eval.ia_value());

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_values.shape, cpu_values.shape);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ia.shape, cpu_ia.shape);
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        assert!(!test_support::doc_examples(DOC_MD).is_empty());
    }
}