Skip to main content

runmat_runtime/builtins/array/sorting_sets/
setdiff.rs

1//! MATLAB-compatible `setdiff` builtin with GPU-aware semantics for RunMat.
2//!
3//! Provides element-wise and row-wise set difference with optional stable
4//! ordering. GPU tensors are gathered to host memory today, but the builtin is
5//! registered as a residency sink so future providers can implement device-side
6//! kernels without impacting behaviour.
7
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::{
12    GpuTensorHandle, GpuTensorStorage, HostTensorOwned, SetdiffOptions, SetdiffOrder, SetdiffResult,
13};
14use runmat_builtins::{CharArray, ComplexTensor, StringArray, Tensor, Value};
15use runmat_macros::runtime_builtin;
16
17use super::type_resolvers::set_values_output_type;
18use crate::build_runtime_error;
19use crate::builtins::common::arg_tokens::tokens_from_values;
20use crate::builtins::common::gpu_helpers;
21use crate::builtins::common::random_args::complex_tensor_into_value;
22use crate::builtins::common::spec::{
23    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
24    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
25};
26use crate::builtins::common::tensor;
27
// GPU capability descriptor for `setdiff`, registered through the
// `register_gpu_spec` attribute under this module's builtin path.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::setdiff")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "setdiff",
    // Declared as a custom op, paired with a custom provider hook of the same name.
    op_kind: GpuOpKind::Custom("setdiff"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("setdiff")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // NOTE(review): presumably this is what makes inputs get gathered to the host
    // (see `notes` below) — confirm against `ResidencyPolicy`.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may implement `setdiff`; until then tensors are gathered and processed on the host.",
};
43
// Fusion descriptor for `setdiff`, registered through the `register_fusion_spec`
// attribute under this module's builtin path.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::setdiff"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "setdiff",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Neither an elementwise nor a reduction template is supplied; per `notes`,
    // the builtin ends a fusion chain rather than taking part in one.
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`setdiff` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
};
56
57fn setdiff_error(message: impl Into<String>) -> crate::RuntimeError {
58    build_runtime_error(message).with_builtin("setdiff").build()
59}
60
61#[runtime_builtin(
62    name = "setdiff",
63    category = "array/sorting_sets",
64    summary = "Return the values that appear in the first input but not the second.",
65    keywords = "setdiff,difference,stable,rows,indices,gpu",
66    accel = "array_construct",
67    sink = true,
68    type_resolver(set_values_output_type),
69    builtin_path = "crate::builtins::array::sorting_sets::setdiff"
70)]
71async fn setdiff_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
72    Ok(evaluate(a, b, &rest).await?.into_values_value())
73}
74
75/// Evaluate the `setdiff` builtin once and expose all outputs.
76pub async fn evaluate(
77    a: Value,
78    b: Value,
79    rest: &[Value],
80) -> crate::BuiltinResult<SetdiffEvaluation> {
81    let opts = parse_options(rest)?;
82    match (a, b) {
83        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
84            setdiff_gpu_pair(handle_a, handle_b, &opts).await
85        }
86        (Value::GpuTensor(handle_a), other) => {
87            setdiff_gpu_mixed(handle_a, other, &opts, true).await
88        }
89        (other, Value::GpuTensor(handle_b)) => {
90            setdiff_gpu_mixed(handle_b, other, &opts, false).await
91        }
92        (left, right) => setdiff_host(left, right, &opts),
93    }
94}
95
96fn parse_options(rest: &[Value]) -> crate::BuiltinResult<SetdiffOptions> {
97    let mut opts = SetdiffOptions {
98        rows: false,
99        order: SetdiffOrder::Sorted,
100    };
101    let mut seen_order: Option<SetdiffOrder> = None;
102
103    let tokens = tokens_from_values(rest);
104    for (arg, token) in rest.iter().zip(tokens.iter()) {
105        let text = match token {
106            crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
107            _ => {
108                let text = tensor::value_to_string(arg)
109                    .ok_or_else(|| setdiff_error("setdiff: expected string option arguments"))?;
110                let lowered = text.trim().to_ascii_lowercase();
111                parse_setdiff_option(&mut opts, &mut seen_order, &lowered)?;
112                continue;
113            }
114        };
115        parse_setdiff_option(&mut opts, &mut seen_order, text)?;
116    }
117
118    Ok(opts)
119}
120
121fn parse_setdiff_option(
122    opts: &mut SetdiffOptions,
123    seen_order: &mut Option<SetdiffOrder>,
124    lowered: &str,
125) -> crate::BuiltinResult<()> {
126    match lowered {
127        "rows" => opts.rows = true,
128        "sorted" => {
129            if let Some(prev) = seen_order {
130                if *prev != SetdiffOrder::Sorted {
131                    return Err(setdiff_error(
132                        "setdiff: cannot combine 'sorted' with 'stable'",
133                    ));
134                }
135            }
136            *seen_order = Some(SetdiffOrder::Sorted);
137            opts.order = SetdiffOrder::Sorted;
138        }
139        "stable" => {
140            if let Some(prev) = seen_order {
141                if *prev != SetdiffOrder::Stable {
142                    return Err(setdiff_error(
143                        "setdiff: cannot combine 'sorted' with 'stable'",
144                    ));
145                }
146            }
147            *seen_order = Some(SetdiffOrder::Stable);
148            opts.order = SetdiffOrder::Stable;
149        }
150        "legacy" | "r2012a" => {
151            return Err(setdiff_error(
152                "setdiff: the 'legacy' behaviour is not supported",
153            ));
154        }
155        other => {
156            return Err(setdiff_error(format!(
157                "setdiff: unrecognised option '{other}'"
158            )))
159        }
160    }
161    Ok(())
162}
163
164async fn setdiff_gpu_pair(
165    handle_a: GpuTensorHandle,
166    handle_b: GpuTensorHandle,
167    opts: &SetdiffOptions,
168) -> crate::BuiltinResult<SetdiffEvaluation> {
169    if let Some(provider) = runmat_accelerate_api::provider() {
170        match provider.setdiff(&handle_a, &handle_b, opts).await {
171            Ok(result) => return SetdiffEvaluation::from_setdiff_result(result),
172            Err(_) => {
173                // Fall back to host gather when provider does not support setdiff.
174            }
175        }
176    }
177    let a_tensor = gpu_helpers::gather_tensor_async(&handle_a).await?;
178    let b_tensor = gpu_helpers::gather_tensor_async(&handle_b).await?;
179    setdiff_numeric(a_tensor, b_tensor, opts)
180}
181
182async fn setdiff_gpu_mixed(
183    handle_gpu: GpuTensorHandle,
184    other: Value,
185    opts: &SetdiffOptions,
186    gpu_is_a: bool,
187) -> crate::BuiltinResult<SetdiffEvaluation> {
188    let gpu_tensor = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
189    let other_tensor =
190        tensor::value_into_tensor_for("setdiff", other).map_err(|e| setdiff_error(e))?;
191    if gpu_is_a {
192        setdiff_numeric(gpu_tensor, other_tensor, opts)
193    } else {
194        setdiff_numeric(other_tensor, gpu_tensor, opts)
195    }
196}
197
198fn setdiff_host(
199    a: Value,
200    b: Value,
201    opts: &SetdiffOptions,
202) -> crate::BuiltinResult<SetdiffEvaluation> {
203    match (a, b) {
204        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => setdiff_complex(at, bt, opts),
205        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
206            let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
207                .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
208            setdiff_complex(at, bt, opts)
209        }
210        (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
211            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
212                .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
213            setdiff_complex(at, bt, opts)
214        }
215        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
216            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
217                .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
218            let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
219                .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
220            setdiff_complex(at, bt, opts)
221        }
222
223        (Value::CharArray(ac), Value::CharArray(bc)) => setdiff_char(ac, bc, opts),
224
225        (Value::StringArray(astring), Value::StringArray(bstring)) => {
226            setdiff_string(astring, bstring, opts)
227        }
228        (Value::StringArray(astring), Value::String(b)) => {
229            let bstring = StringArray::new(vec![b], vec![1, 1])
230                .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
231            setdiff_string(astring, bstring, opts)
232        }
233        (Value::String(a), Value::StringArray(bstring)) => {
234            let astring = StringArray::new(vec![a], vec![1, 1])
235                .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
236            setdiff_string(astring, bstring, opts)
237        }
238        (Value::String(a), Value::String(b)) => {
239            let astring = StringArray::new(vec![a], vec![1, 1])
240                .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
241            let bstring = StringArray::new(vec![b], vec![1, 1])
242                .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
243            setdiff_string(astring, bstring, opts)
244        }
245
246        (left, right) => {
247            let tensor_a =
248                tensor::value_into_tensor_for("setdiff", left).map_err(|e| setdiff_error(e))?;
249            let tensor_b =
250                tensor::value_into_tensor_for("setdiff", right).map_err(|e| setdiff_error(e))?;
251            setdiff_numeric(tensor_a, tensor_b, opts)
252        }
253    }
254}
255
256fn setdiff_numeric(
257    a: Tensor,
258    b: Tensor,
259    opts: &SetdiffOptions,
260) -> crate::BuiltinResult<SetdiffEvaluation> {
261    if opts.rows {
262        setdiff_numeric_rows(a, b, opts)
263    } else {
264        setdiff_numeric_elements(a, b, opts)
265    }
266}
267
268/// Helper exposed for acceleration providers handling numeric tensors entirely on the host.
269pub fn setdiff_numeric_from_tensors(
270    a: Tensor,
271    b: Tensor,
272    opts: &SetdiffOptions,
273) -> crate::BuiltinResult<SetdiffEvaluation> {
274    setdiff_numeric(a, b, opts)
275}
276
277fn setdiff_numeric_elements(
278    a: Tensor,
279    b: Tensor,
280    opts: &SetdiffOptions,
281) -> crate::BuiltinResult<SetdiffEvaluation> {
282    let mut b_keys: HashSet<u64> = HashSet::new();
283    for &value in &b.data {
284        b_keys.insert(canonicalize_f64(value));
285    }
286
287    let mut seen: HashMap<u64, usize> = HashMap::new();
288    let mut entries = Vec::<NumericDiffEntry>::new();
289    let mut order_counter = 0usize;
290
291    for (idx, &value) in a.data.iter().enumerate() {
292        let key = canonicalize_f64(value);
293        if b_keys.contains(&key) {
294            continue;
295        }
296        if seen.contains_key(&key) {
297            continue;
298        }
299        let entry_idx = entries.len();
300        entries.push(NumericDiffEntry {
301            value,
302            index: idx,
303            order_rank: order_counter,
304        });
305        seen.insert(key, entry_idx);
306        order_counter += 1;
307    }
308
309    assemble_numeric_setdiff(entries, opts)
310}
311
312fn setdiff_numeric_rows(
313    a: Tensor,
314    b: Tensor,
315    opts: &SetdiffOptions,
316) -> crate::BuiltinResult<SetdiffEvaluation> {
317    if a.shape.len() != 2 || b.shape.len() != 2 {
318        return Err(setdiff_error(
319            "setdiff: 'rows' option requires 2-D numeric matrices",
320        ));
321    }
322    if a.shape[1] != b.shape[1] {
323        return Err(setdiff_error(
324            "setdiff: inputs must have the same number of columns when using 'rows'",
325        ));
326    }
327
328    let rows_a = a.shape[0];
329    let rows_b = b.shape[0];
330    let cols = a.shape[1];
331
332    let mut b_keys: HashSet<NumericRowKey> = HashSet::new();
333    for r in 0..rows_b {
334        let mut row_values = Vec::with_capacity(cols);
335        for c in 0..cols {
336            let idx = r + c * rows_b;
337            row_values.push(b.data[idx]);
338        }
339        b_keys.insert(NumericRowKey::from_slice(&row_values));
340    }
341
342    let mut seen: HashSet<NumericRowKey> = HashSet::new();
343    let mut entries = Vec::<NumericRowDiffEntry>::new();
344    let mut order_counter = 0usize;
345
346    for r in 0..rows_a {
347        let mut row_values = Vec::with_capacity(cols);
348        for c in 0..cols {
349            let idx = r + c * rows_a;
350            row_values.push(a.data[idx]);
351        }
352        let key = NumericRowKey::from_slice(&row_values);
353        if b_keys.contains(&key) {
354            continue;
355        }
356        if !seen.insert(key) {
357            continue;
358        }
359        entries.push(NumericRowDiffEntry {
360            row_data: row_values,
361            row_index: r,
362            order_rank: order_counter,
363        });
364        order_counter += 1;
365    }
366
367    assemble_numeric_row_setdiff(entries, opts, cols)
368}
369
370fn setdiff_complex(
371    a: ComplexTensor,
372    b: ComplexTensor,
373    opts: &SetdiffOptions,
374) -> crate::BuiltinResult<SetdiffEvaluation> {
375    if opts.rows {
376        setdiff_complex_rows(a, b, opts)
377    } else {
378        setdiff_complex_elements(a, b, opts)
379    }
380}
381
382fn setdiff_complex_elements(
383    a: ComplexTensor,
384    b: ComplexTensor,
385    opts: &SetdiffOptions,
386) -> crate::BuiltinResult<SetdiffEvaluation> {
387    let mut b_keys: HashSet<ComplexKey> = HashSet::new();
388    for &value in &b.data {
389        b_keys.insert(ComplexKey::new(value));
390    }
391
392    let mut seen: HashSet<ComplexKey> = HashSet::new();
393    let mut entries = Vec::<ComplexDiffEntry>::new();
394    let mut order_counter = 0usize;
395
396    for (idx, &value) in a.data.iter().enumerate() {
397        let key = ComplexKey::new(value);
398        if b_keys.contains(&key) {
399            continue;
400        }
401        if !seen.insert(key) {
402            continue;
403        }
404        entries.push(ComplexDiffEntry {
405            value,
406            index: idx,
407            order_rank: order_counter,
408        });
409        order_counter += 1;
410    }
411
412    assemble_complex_setdiff(entries, opts)
413}
414
415fn setdiff_complex_rows(
416    a: ComplexTensor,
417    b: ComplexTensor,
418    opts: &SetdiffOptions,
419) -> crate::BuiltinResult<SetdiffEvaluation> {
420    if a.shape.len() != 2 || b.shape.len() != 2 {
421        return Err(setdiff_error(
422            "setdiff: 'rows' option requires 2-D complex matrices",
423        ));
424    }
425    if a.shape[1] != b.shape[1] {
426        return Err(setdiff_error(
427            "setdiff: inputs must have the same number of columns when using 'rows'",
428        ));
429    }
430
431    let rows_a = a.shape[0];
432    let rows_b = b.shape[0];
433    let cols = a.shape[1];
434
435    let mut b_keys: HashSet<Vec<ComplexKey>> = HashSet::new();
436    for r in 0..rows_b {
437        let mut key_row = Vec::with_capacity(cols);
438        for c in 0..cols {
439            let idx = r + c * rows_b;
440            key_row.push(ComplexKey::new(b.data[idx]));
441        }
442        b_keys.insert(key_row);
443    }
444
445    let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
446    let mut entries = Vec::<ComplexRowDiffEntry>::new();
447    let mut order_counter = 0usize;
448
449    for r in 0..rows_a {
450        let mut row_values = Vec::with_capacity(cols);
451        let mut key_row = Vec::with_capacity(cols);
452        for c in 0..cols {
453            let idx = r + c * rows_a;
454            let value = a.data[idx];
455            row_values.push(value);
456            key_row.push(ComplexKey::new(value));
457        }
458        if b_keys.contains(&key_row) {
459            continue;
460        }
461        if !seen.insert(key_row) {
462            continue;
463        }
464        entries.push(ComplexRowDiffEntry {
465            row_data: row_values,
466            row_index: r,
467            order_rank: order_counter,
468        });
469        order_counter += 1;
470    }
471
472    assemble_complex_row_setdiff(entries, opts, cols)
473}
474
475fn setdiff_char(
476    a: CharArray,
477    b: CharArray,
478    opts: &SetdiffOptions,
479) -> crate::BuiltinResult<SetdiffEvaluation> {
480    if opts.rows {
481        setdiff_char_rows(a, b, opts)
482    } else {
483        setdiff_char_elements(a, b, opts)
484    }
485}
486
487fn setdiff_char_elements(
488    a: CharArray,
489    b: CharArray,
490    opts: &SetdiffOptions,
491) -> crate::BuiltinResult<SetdiffEvaluation> {
492    let mut b_keys: HashSet<u32> = HashSet::new();
493    for ch in &b.data {
494        b_keys.insert(*ch as u32);
495    }
496
497    let mut seen: HashSet<u32> = HashSet::new();
498    let mut entries = Vec::<CharDiffEntry>::new();
499    let mut order_counter = 0usize;
500
501    for col in 0..a.cols {
502        for row in 0..a.rows {
503            let linear_idx = row + col * a.rows;
504            let data_idx = row * a.cols + col;
505            let ch = a.data[data_idx];
506            let key = ch as u32;
507            if b_keys.contains(&key) {
508                continue;
509            }
510            if !seen.insert(key) {
511                continue;
512            }
513            entries.push(CharDiffEntry {
514                ch,
515                index: linear_idx,
516                order_rank: order_counter,
517            });
518            order_counter += 1;
519        }
520    }
521
522    assemble_char_setdiff(entries, opts)
523}
524
525fn setdiff_char_rows(
526    a: CharArray,
527    b: CharArray,
528    opts: &SetdiffOptions,
529) -> crate::BuiltinResult<SetdiffEvaluation> {
530    if a.cols != b.cols {
531        return Err(setdiff_error(
532            "setdiff: inputs must have the same number of columns when using 'rows'",
533        ));
534    }
535
536    let rows_a = a.rows;
537    let rows_b = b.rows;
538    let cols = a.cols;
539
540    let mut b_keys: HashSet<RowCharKey> = HashSet::new();
541    for r in 0..rows_b {
542        let mut row_values = Vec::with_capacity(cols);
543        for c in 0..cols {
544            let idx = r * cols + c;
545            row_values.push(b.data[idx]);
546        }
547        b_keys.insert(RowCharKey::from_slice(&row_values));
548    }
549
550    let mut seen: HashSet<RowCharKey> = HashSet::new();
551    let mut entries = Vec::<CharRowDiffEntry>::new();
552    let mut order_counter = 0usize;
553
554    for r in 0..rows_a {
555        let mut row_values = Vec::with_capacity(cols);
556        for c in 0..cols {
557            let idx = r * cols + c;
558            row_values.push(a.data[idx]);
559        }
560        let key = RowCharKey::from_slice(&row_values);
561        if b_keys.contains(&key) {
562            continue;
563        }
564        if !seen.insert(key) {
565            continue;
566        }
567        entries.push(CharRowDiffEntry {
568            row_data: row_values,
569            row_index: r,
570            order_rank: order_counter,
571        });
572        order_counter += 1;
573    }
574
575    assemble_char_row_setdiff(entries, opts, cols)
576}
577
578fn setdiff_string(
579    a: StringArray,
580    b: StringArray,
581    opts: &SetdiffOptions,
582) -> crate::BuiltinResult<SetdiffEvaluation> {
583    if opts.rows {
584        setdiff_string_rows(a, b, opts)
585    } else {
586        setdiff_string_elements(a, b, opts)
587    }
588}
589
590fn setdiff_string_elements(
591    a: StringArray,
592    b: StringArray,
593    opts: &SetdiffOptions,
594) -> crate::BuiltinResult<SetdiffEvaluation> {
595    let mut b_keys: HashSet<String> = HashSet::new();
596    for value in &b.data {
597        b_keys.insert(value.clone());
598    }
599
600    let mut seen: HashSet<String> = HashSet::new();
601    let mut entries = Vec::<StringDiffEntry>::new();
602    let mut order_counter = 0usize;
603
604    for (idx, value) in a.data.iter().enumerate() {
605        if b_keys.contains(value) {
606            continue;
607        }
608        if !seen.insert(value.clone()) {
609            continue;
610        }
611        entries.push(StringDiffEntry {
612            value: value.clone(),
613            index: idx,
614            order_rank: order_counter,
615        });
616        order_counter += 1;
617    }
618
619    assemble_string_setdiff(entries, opts)
620}
621
622fn setdiff_string_rows(
623    a: StringArray,
624    b: StringArray,
625    opts: &SetdiffOptions,
626) -> crate::BuiltinResult<SetdiffEvaluation> {
627    if a.shape.len() != 2 || b.shape.len() != 2 {
628        return Err(setdiff_error(
629            "setdiff: 'rows' option requires 2-D string arrays",
630        ));
631    }
632    if a.shape[1] != b.shape[1] {
633        return Err(setdiff_error(
634            "setdiff: inputs must have the same number of columns when using 'rows'",
635        ));
636    }
637
638    let rows_a = a.shape[0];
639    let rows_b = b.shape[0];
640    let cols = a.shape[1];
641
642    let mut b_keys: HashSet<RowStringKey> = HashSet::new();
643    for r in 0..rows_b {
644        let mut row_values = Vec::with_capacity(cols);
645        for c in 0..cols {
646            let idx = r + c * rows_b;
647            row_values.push(b.data[idx].clone());
648        }
649        b_keys.insert(RowStringKey(row_values.clone()));
650    }
651
652    let mut seen: HashSet<RowStringKey> = HashSet::new();
653    let mut entries = Vec::<StringRowDiffEntry>::new();
654    let mut order_counter = 0usize;
655
656    for r in 0..rows_a {
657        let mut row_values = Vec::with_capacity(cols);
658        for c in 0..cols {
659            let idx = r + c * rows_a;
660            row_values.push(a.data[idx].clone());
661        }
662        let key = RowStringKey(row_values.clone());
663        if b_keys.contains(&key) {
664            continue;
665        }
666        if !seen.insert(key) {
667            continue;
668        }
669        entries.push(StringRowDiffEntry {
670            row_data: row_values,
671            row_index: r,
672            order_rank: order_counter,
673        });
674        order_counter += 1;
675    }
676
677    assemble_string_row_setdiff(entries, opts, cols)
678}
679
680fn assemble_numeric_setdiff(
681    entries: Vec<NumericDiffEntry>,
682    opts: &SetdiffOptions,
683) -> crate::BuiltinResult<SetdiffEvaluation> {
684    let mut order: Vec<usize> = (0..entries.len()).collect();
685    match opts.order {
686        SetdiffOrder::Sorted => {
687            order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
688        }
689        SetdiffOrder::Stable => {
690            order.sort_by_key(|&idx| entries[idx].order_rank);
691        }
692    }
693
694    let mut values = Vec::with_capacity(order.len());
695    let mut ia = Vec::with_capacity(order.len());
696    for &idx in &order {
697        let entry = &entries[idx];
698        values.push(entry.value);
699        ia.push((entry.index + 1) as f64);
700    }
701
702    let value_tensor = Tensor::new(values, vec![order.len(), 1])
703        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
704    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
705        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
706
707    Ok(SetdiffEvaluation::new(
708        Value::Tensor(value_tensor),
709        ia_tensor,
710    ))
711}
712
713fn assemble_numeric_row_setdiff(
714    entries: Vec<NumericRowDiffEntry>,
715    opts: &SetdiffOptions,
716    cols: usize,
717) -> crate::BuiltinResult<SetdiffEvaluation> {
718    let mut order: Vec<usize> = (0..entries.len()).collect();
719    match opts.order {
720        SetdiffOrder::Sorted => {
721            order.sort_by(|&lhs, &rhs| {
722                compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
723            });
724        }
725        SetdiffOrder::Stable => {
726            order.sort_by_key(|&idx| entries[idx].order_rank);
727        }
728    }
729
730    let unique_rows = order.len();
731    let mut values = vec![0.0f64; unique_rows * cols];
732    let mut ia = Vec::with_capacity(unique_rows);
733
734    for (row_pos, &entry_idx) in order.iter().enumerate() {
735        let entry = &entries[entry_idx];
736        for col in 0..cols {
737            let dest = row_pos + col * unique_rows;
738            values[dest] = entry.row_data[col];
739        }
740        ia.push((entry.row_index + 1) as f64);
741    }
742
743    let value_tensor = Tensor::new(values, vec![unique_rows, cols])
744        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
745    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
746        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
747
748    Ok(SetdiffEvaluation::new(
749        Value::Tensor(value_tensor),
750        ia_tensor,
751    ))
752}
753
754fn assemble_complex_setdiff(
755    entries: Vec<ComplexDiffEntry>,
756    opts: &SetdiffOptions,
757) -> crate::BuiltinResult<SetdiffEvaluation> {
758    let mut order: Vec<usize> = (0..entries.len()).collect();
759    match opts.order {
760        SetdiffOrder::Sorted => {
761            order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
762        }
763        SetdiffOrder::Stable => {
764            order.sort_by_key(|&idx| entries[idx].order_rank);
765        }
766    }
767
768    let mut values = Vec::with_capacity(order.len());
769    let mut ia = Vec::with_capacity(order.len());
770    for &idx in &order {
771        let entry = &entries[idx];
772        values.push(entry.value);
773        ia.push((entry.index + 1) as f64);
774    }
775
776    let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
777        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
778    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
779        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
780
781    Ok(SetdiffEvaluation::new(
782        complex_tensor_into_value(value_tensor),
783        ia_tensor,
784    ))
785}
786
787fn assemble_complex_row_setdiff(
788    entries: Vec<ComplexRowDiffEntry>,
789    opts: &SetdiffOptions,
790    cols: usize,
791) -> crate::BuiltinResult<SetdiffEvaluation> {
792    let mut order: Vec<usize> = (0..entries.len()).collect();
793    match opts.order {
794        SetdiffOrder::Sorted => {
795            order.sort_by(|&lhs, &rhs| {
796                compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
797            });
798        }
799        SetdiffOrder::Stable => {
800            order.sort_by_key(|&idx| entries[idx].order_rank);
801        }
802    }
803
804    let unique_rows = order.len();
805    let mut values = vec![(0.0f64, 0.0f64); unique_rows * cols];
806    let mut ia = Vec::with_capacity(unique_rows);
807
808    for (row_pos, &entry_idx) in order.iter().enumerate() {
809        let entry = &entries[entry_idx];
810        for col in 0..cols {
811            let dest = row_pos + col * unique_rows;
812            values[dest] = entry.row_data[col];
813        }
814        ia.push((entry.row_index + 1) as f64);
815    }
816
817    let value_tensor = ComplexTensor::new(values, vec![unique_rows, cols])
818        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
819    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
820        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
821
822    Ok(SetdiffEvaluation::new(
823        complex_tensor_into_value(value_tensor),
824        ia_tensor,
825    ))
826}
827
828fn assemble_char_setdiff(
829    entries: Vec<CharDiffEntry>,
830    opts: &SetdiffOptions,
831) -> crate::BuiltinResult<SetdiffEvaluation> {
832    let mut order: Vec<usize> = (0..entries.len()).collect();
833    match opts.order {
834        SetdiffOrder::Sorted => {
835            order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
836        }
837        SetdiffOrder::Stable => {
838            order.sort_by_key(|&idx| entries[idx].order_rank);
839        }
840    }
841
842    let mut values = Vec::with_capacity(order.len());
843    let mut ia = Vec::with_capacity(order.len());
844    for &idx in &order {
845        let entry = &entries[idx];
846        values.push(entry.ch);
847        ia.push((entry.index + 1) as f64);
848    }
849
850    let value_array = CharArray::new(values, order.len(), 1)
851        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
852    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
853        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
854
855    Ok(SetdiffEvaluation::new(
856        Value::CharArray(value_array),
857        ia_tensor,
858    ))
859}
860
861fn assemble_char_row_setdiff(
862    entries: Vec<CharRowDiffEntry>,
863    opts: &SetdiffOptions,
864    cols: usize,
865) -> crate::BuiltinResult<SetdiffEvaluation> {
866    let mut order: Vec<usize> = (0..entries.len()).collect();
867    match opts.order {
868        SetdiffOrder::Sorted => {
869            order.sort_by(|&lhs, &rhs| {
870                compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
871            });
872        }
873        SetdiffOrder::Stable => {
874            order.sort_by_key(|&idx| entries[idx].order_rank);
875        }
876    }
877
878    let unique_rows = order.len();
879    let mut values = vec!['\0'; unique_rows * cols];
880    let mut ia = Vec::with_capacity(unique_rows);
881
882    for (row_pos, &entry_idx) in order.iter().enumerate() {
883        let entry = &entries[entry_idx];
884        for col in 0..cols {
885            let dest = row_pos * cols + col;
886            values[dest] = entry.row_data[col];
887        }
888        ia.push((entry.row_index + 1) as f64);
889    }
890
891    let value_array = CharArray::new(values, unique_rows, cols)
892        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
893    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
894        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
895
896    Ok(SetdiffEvaluation::new(
897        Value::CharArray(value_array),
898        ia_tensor,
899    ))
900}
901
902fn assemble_string_setdiff(
903    entries: Vec<StringDiffEntry>,
904    opts: &SetdiffOptions,
905) -> crate::BuiltinResult<SetdiffEvaluation> {
906    let mut order: Vec<usize> = (0..entries.len()).collect();
907    match opts.order {
908        SetdiffOrder::Sorted => {
909            order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
910        }
911        SetdiffOrder::Stable => {
912            order.sort_by_key(|&idx| entries[idx].order_rank);
913        }
914    }
915
916    let mut values = Vec::with_capacity(order.len());
917    let mut ia = Vec::with_capacity(order.len());
918    for &idx in &order {
919        let entry = &entries[idx];
920        values.push(entry.value.clone());
921        ia.push((entry.index + 1) as f64);
922    }
923
924    let value_array = StringArray::new(values, vec![order.len(), 1])
925        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
926    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
927        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
928
929    Ok(SetdiffEvaluation::new(
930        Value::StringArray(value_array),
931        ia_tensor,
932    ))
933}
934
935fn assemble_string_row_setdiff(
936    entries: Vec<StringRowDiffEntry>,
937    opts: &SetdiffOptions,
938    cols: usize,
939) -> crate::BuiltinResult<SetdiffEvaluation> {
940    let mut order: Vec<usize> = (0..entries.len()).collect();
941    match opts.order {
942        SetdiffOrder::Sorted => {
943            order.sort_by(|&lhs, &rhs| {
944                compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
945            });
946        }
947        SetdiffOrder::Stable => {
948            order.sort_by_key(|&idx| entries[idx].order_rank);
949        }
950    }
951
952    let unique_rows = order.len();
953    let mut values = vec![String::new(); unique_rows * cols];
954    let mut ia = Vec::with_capacity(unique_rows);
955
956    for (row_pos, &entry_idx) in order.iter().enumerate() {
957        let entry = &entries[entry_idx];
958        for col in 0..cols {
959            let dest = row_pos + col * unique_rows;
960            values[dest] = entry.row_data[col].clone();
961        }
962        ia.push((entry.row_index + 1) as f64);
963    }
964
965    let value_array = StringArray::new(values, vec![unique_rows, cols])
966        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
967    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
968        .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
969
970    Ok(SetdiffEvaluation::new(
971        Value::StringArray(value_array),
972        ia_tensor,
973    ))
974}
975
/// Element-wise `setdiff` candidate holding a real scalar.
///
/// NOTE(review): the code that fills and consumes this entry is outside this
/// part of the file; the field notes assume it mirrors `StringDiffEntry` — confirm.
#[derive(Clone, Copy, Debug)]
struct NumericDiffEntry {
    /// The element's value.
    value: f64,
    /// Position of the element in its source array (presumably zero-based — confirm).
    index: usize,
    /// Presumably the sort key used for `'stable'` ordering — confirm.
    order_rank: usize,
}
982
/// Row-wise `setdiff` candidate holding one row of real values.
///
/// NOTE(review): the code that fills and consumes this entry is outside this
/// part of the file; the field notes assume it mirrors `StringRowDiffEntry` — confirm.
#[derive(Clone, Debug)]
struct NumericRowDiffEntry {
    /// The values of the row, one per column.
    row_data: Vec<f64>,
    /// Position of the row in its source matrix (presumably zero-based — confirm).
    row_index: usize,
    /// Presumably the sort key used for `'stable'` ordering — confirm.
    order_rank: usize,
}
989
/// Element-wise `setdiff` candidate holding a complex scalar.
///
/// NOTE(review): the code that fills and consumes this entry is outside this
/// part of the file; the field notes assume it mirrors `StringDiffEntry` — confirm.
#[derive(Clone, Copy, Debug)]
struct ComplexDiffEntry {
    /// The element's value as a pair (presumably `(re, im)`, matching `ComplexKey::new` — confirm).
    value: (f64, f64),
    /// Position of the element in its source array (presumably zero-based — confirm).
    index: usize,
    /// Presumably the sort key used for `'stable'` ordering — confirm.
    order_rank: usize,
}
996
/// Row-wise `setdiff` candidate holding one row of complex values.
///
/// NOTE(review): the code that fills and consumes this entry is outside this
/// part of the file; the field notes assume it mirrors `StringRowDiffEntry` — confirm.
#[derive(Clone, Debug)]
struct ComplexRowDiffEntry {
    /// The values of the row, one pair per column (presumably `(re, im)` — confirm).
    row_data: Vec<(f64, f64)>,
    /// Position of the row in its source matrix (presumably zero-based — confirm).
    row_index: usize,
    /// Presumably the sort key used for `'stable'` ordering — confirm.
    order_rank: usize,
}
1003
/// Element-wise `setdiff` candidate holding a single character.
///
/// NOTE(review): the code that fills and consumes this entry is outside this
/// part of the file; the field notes assume it mirrors `StringDiffEntry` — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct CharDiffEntry {
    /// The character itself.
    ch: char,
    /// Position of the character in its source array (presumably zero-based — confirm).
    index: usize,
    /// Presumably the sort key used for `'stable'` ordering — confirm.
    order_rank: usize,
}
1010
/// Row-wise `setdiff` candidate holding one row of characters.
#[derive(Clone, Debug)]
struct CharRowDiffEntry {
    /// The characters of the row, one per column; compared with `compare_char_rows`
    /// when the output is sorted.
    row_data: Vec<char>,
    /// Row position; the char-row assembly code reports `row_index + 1` in `ia`.
    row_index: usize,
    /// Sort key applied when `SetdiffOrder::Stable` is requested.
    order_rank: usize,
}
1017
/// Element-wise `setdiff` candidate holding one string.
#[derive(Clone, Debug)]
struct StringDiffEntry {
    /// The string itself; used as the sort key when the output is sorted.
    value: String,
    /// Element position; `assemble_string_setdiff` reports `index + 1` in `ia`.
    index: usize,
    /// Sort key applied when `SetdiffOrder::Stable` is requested.
    order_rank: usize,
}
1024
/// Row-wise `setdiff` candidate holding one row of strings.
#[derive(Clone, Debug)]
struct StringRowDiffEntry {
    /// The strings of the row, one per column; compared with `compare_string_rows`
    /// when the output is sorted.
    row_data: Vec<String>,
    /// Row position; `assemble_string_row_setdiff` reports `row_index + 1` in `ia`.
    row_index: usize,
    /// Sort key applied when `SetdiffOrder::Stable` is requested.
    order_rank: usize,
}
1031
1032#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1033struct NumericRowKey(Vec<u64>);
1034
1035impl NumericRowKey {
1036    fn from_slice(values: &[f64]) -> Self {
1037        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1038    }
1039}
1040
1041#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1042struct ComplexKey {
1043    re: u64,
1044    im: u64,
1045}
1046
1047impl ComplexKey {
1048    fn new(value: (f64, f64)) -> Self {
1049        Self {
1050            re: canonicalize_f64(value.0),
1051            im: canonicalize_f64(value.1),
1052        }
1053    }
1054}
1055
/// Hashable identity of a row of characters, stored as Unicode scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key for `values`, keeping the characters in their original order.
    fn from_slice(values: &[char]) -> Self {
        let code_points = values.iter().copied().map(u32::from).collect();
        Self(code_points)
    }
}
1064
/// Hashable key wrapping the strings of one row.
///
/// NOTE(review): presumably used for hash-based row lookup like the other
/// `*Key` types; its call sites are outside this part of the file — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1067
1068#[derive(Debug)]
1069pub struct SetdiffEvaluation {
1070    values: Value,
1071    ia: Tensor,
1072}
1073
1074impl SetdiffEvaluation {
1075    fn new(values: Value, ia: Tensor) -> Self {
1076        Self { values, ia }
1077    }
1078
1079    pub fn from_setdiff_result(result: SetdiffResult) -> crate::BuiltinResult<Self> {
1080        let SetdiffResult { values, ia } = result;
1081        let values_tensor = Tensor::new(values.data, values.shape)
1082            .map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
1083        let ia_tensor =
1084            Tensor::new(ia.data, ia.shape).map_err(|e| setdiff_error(format!("setdiff: {e}")))?;
1085        Ok(SetdiffEvaluation::new(
1086            Value::Tensor(values_tensor),
1087            ia_tensor,
1088        ))
1089    }
1090
1091    pub fn into_numeric_setdiff_result(self) -> crate::BuiltinResult<SetdiffResult> {
1092        let SetdiffEvaluation { values, ia } = self;
1093        let values_tensor =
1094            tensor::value_into_tensor_for("setdiff", values).map_err(|e| setdiff_error(e))?;
1095        Ok(SetdiffResult {
1096            values: HostTensorOwned {
1097                data: values_tensor.data,
1098                shape: values_tensor.shape,
1099                storage: GpuTensorStorage::Real,
1100            },
1101            ia: HostTensorOwned {
1102                data: ia.data,
1103                shape: ia.shape,
1104                storage: GpuTensorStorage::Real,
1105            },
1106        })
1107    }
1108
1109    pub fn into_values_value(self) -> Value {
1110        self.values
1111    }
1112
1113    pub fn into_pair(self) -> (Value, Value) {
1114        let ia = tensor::tensor_into_value(self.ia);
1115        (self.values, ia)
1116    }
1117
1118    pub fn values_value(&self) -> Value {
1119        self.values.clone()
1120    }
1121
1122    pub fn ia_value(&self) -> Value {
1123        tensor::tensor_into_value(self.ia.clone())
1124    }
1125}
1126
/// Maps a float to the bit pattern used as its hashing identity.
///
/// Every NaN collapses to one quiet-NaN pattern and both signed zeros collapse
/// to `0`; any other value keeps its own bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;

    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0` is true, so this arm catches both signed zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1136
/// Total ordering over floats in which NaN compares equal to NaN and greater
/// than every other value.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN here; the fallback only satisfies the type.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1150
1151fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1152    for (lhs, rhs) in a.iter().zip(b.iter()) {
1153        let ord = compare_f64(*lhs, *rhs);
1154        if ord != Ordering::Equal {
1155            return ord;
1156        }
1157    }
1158    Ordering::Equal
1159}
1160
/// Reports whether either component of the pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (first, second) = value;
    first.is_nan() || second.is_nan()
}
1164
1165fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1166    match (complex_is_nan(a), complex_is_nan(b)) {
1167        (true, true) => Ordering::Equal,
1168        (true, false) => Ordering::Greater,
1169        (false, true) => Ordering::Less,
1170        (false, false) => {
1171            let mag_a = a.0.hypot(a.1);
1172            let mag_b = b.0.hypot(b.1);
1173            let mag_cmp = compare_f64(mag_a, mag_b);
1174            if mag_cmp != Ordering::Equal {
1175                return mag_cmp;
1176            }
1177            let re_cmp = compare_f64(a.0, b.0);
1178            if re_cmp != Ordering::Equal {
1179                return re_cmp;
1180            }
1181            compare_f64(a.1, b.1)
1182        }
1183    }
1184}
1185
1186fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1187    for (lhs, rhs) in a.iter().zip(b.iter()) {
1188        let ord = compare_complex(*lhs, *rhs);
1189        if ord != Ordering::Equal {
1190            return ord;
1191        }
1192    }
1193    Ordering::Equal
1194}
1195
/// Lexicographic comparison of two rows of characters.
///
/// Only the overlapping prefix is inspected; rows that agree on it compare
/// equal regardless of length.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1205
/// Lexicographic comparison of two rows of strings.
///
/// Only the overlapping prefix is inspected; rows that agree on it compare
/// equal regardless of length.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1215
// Unit tests for the `setdiff` builtin: numeric, char, and string inputs, the
// `'rows'` / `'stable'` options, error paths, and GPU-handle inputs.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CharArray, ResolveContext, StringArray, Tensor, Type, Value};

    /// Extracts the message text from a runtime error for substring assertions.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Drives the async `evaluate` entry point to completion on the current thread.
    fn evaluate_sync(
        a: Value,
        b: Value,
        rest: &[Value],
    ) -> crate::BuiltinResult<SetdiffEvaluation> {
        futures::executor::block_on(evaluate(a, b, rest))
    }

    // Default call: the duplicate 5 in A collapses to one value and `ia`
    // reports one-based index 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_sorted_default() {
        let a = Tensor::new(vec![5.0, 7.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![7.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("setdiff");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data, vec![5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // The output type resolver maps two tensor inputs to a tensor.
    #[test]
    fn setdiff_type_resolver_numeric() {
        assert_eq!(
            set_values_output_type(
                &[Type::tensor(), Type::tensor()],
                &ResolveContext::new(Vec::new()),
            ),
            Type::tensor()
        );
    }

    // The output type resolver maps a cell-of-strings first input to a cell of strings.
    #[test]
    fn setdiff_type_resolver_string_array() {
        assert_eq!(
            set_values_output_type(
                &[Type::cell_of(Type::String), Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::cell_of(Type::String)
        );
    }

    // With "stable", only 2 survives and `ia` points at its position in A.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_stable() {
        let a = Tensor::new(vec![4.0, 2.0, 4.0, 1.0, 3.0], vec![5, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 4.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("stable")])
            .expect("setdiff");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data, vec![2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0]);
    }

    // With "rows", the single surviving row is [1, 2], reported at row 1.
    // Assuming column-major tensor data, A's rows are [1 2], [3 4], [1 2] and
    // B's are [3 4], [5 6] — TODO confirm the storage order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_rows_sorted() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")])
            .expect("setdiff");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![1.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // A NaN in B removes the NaN in A: NaNs are treated as equal to each other.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_removes_nan() {
        let a = Tensor::new(vec![f64::NAN, 2.0, 3.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![f64::NAN], vec![1, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("setdiff");
        let values = tensor::value_into_tensor_for("setdiff", eval.values_value()).expect("values");
        assert_eq!(values.data, vec![2.0, 3.0]);
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0, 3.0]);
    }

    // Element-wise difference of two 2x2 char arrays leaves only 'z', reported
    // at one-based index 3.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_char_elements() {
        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
        let eval = evaluate_sync(Value::CharArray(a), Value::CharArray(b), &[]).expect("setdiff");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 1);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['z']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0 + 2.0]);
    }

    // "rows" combined with "stable" on 2x2 string arrays keeps only the row
    // ["alpha", "beta"], reported at row 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_string_rows_stable() {
        let a = StringArray::new(
            vec![
                "alpha".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "gamma".to_string(),
                "delta".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let eval = evaluate_sync(
            Value::StringArray(a),
            Value::StringArray(b),
            &[Value::from("rows"), Value::from("stable")],
        )
        .expect("setdiff");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![1, 2]);
                assert_eq!(arr.data, vec!["alpha".to_string(), "beta".to_string()]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // Mixing a numeric scalar with a string is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_type_mismatch_errors() {
        let result = evaluate_sync(Value::from(1.0), Value::String("a".into()), &[]);
        assert!(result.is_err());
    }

    // The "legacy" option is rejected with a dedicated message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_rejects_legacy_option() {
        let err = error_message(
            evaluate_sync(Value::from(1.0), Value::from(2.0), &[Value::from("legacy")])
                .unwrap_err(),
        );
        assert!(err.contains("setdiff: the 'legacy' behaviour is not supported"));
    }

    // Inputs uploaded through the test provider produce a host tensor result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor_a = Tensor::new(vec![10.0, 4.0, 6.0, 4.0], vec![4, 1]).unwrap();
            let tensor_b = Tensor::new(vec![6.0, 4.0, 2.0], vec![3, 1]).unwrap();
            let view_a = HostTensorView {
                data: &tensor_a.data,
                shape: &tensor_a.shape,
            };
            let view_b = HostTensorView {
                data: &tensor_b.data,
                shape: &tensor_b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload a");
            let handle_b = provider.upload(&view_b).expect("upload b");
            let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("setdiff");
            match eval.values_value() {
                Value::Tensor(t) => {
                    assert_eq!(t.data, vec![10.0]);
                }
                other => panic!("expected tensor result, got {other:?}"),
            }
            let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
            assert_eq!(ia.data, vec![1.0]);
        });
    }

    // With the wgpu provider registered, GPU-handle inputs must give the same
    // values, indices, and shapes as host tensors.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn setdiff_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = Tensor::new(vec![8.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![2.0, 5.0], vec![2, 1]).unwrap();

        let cpu_eval = evaluate_sync(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[])
            .expect("setdiff");
        let cpu_values = tensor::value_into_tensor_for("setdiff", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("setdiff", cpu_eval.ia_value()).unwrap();

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
            .expect("setdiff");
        let gpu_values = tensor::value_into_tensor_for("setdiff", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("setdiff", gpu_eval.ia_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_values.shape, cpu_values.shape);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ia.shape, cpu_ia.shape);
    }
}