// runmat_runtime/builtins/array/sorting_sets/setdiff.rs

//! MATLAB-compatible `setdiff` builtin with GPU-aware semantics for RunMat.
//!
//! Provides element-wise and row-wise set difference with optional stable
//! ordering. When both operands are GPU tensors the active acceleration
//! provider is tried first; in every other case GPU tensors are gathered to
//! host memory. The builtin is registered as a residency sink so future
//! providers can implement device-side kernels without impacting behaviour.
7
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::{
12    GpuTensorHandle, GpuTensorStorage, HostTensorOwned, SetdiffOptions, SetdiffOrder, SetdiffResult,
13};
14use runmat_builtins::{
15    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
16    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
17    CharArray, ComplexTensor, StringArray, Tensor, Value,
18};
19use runmat_macros::runtime_builtin;
20
21use super::type_resolvers::set_values_output_type;
22use crate::build_runtime_error;
23use crate::builtins::common::arg_tokens::tokens_from_values;
24use crate::builtins::common::gpu_helpers;
25use crate::builtins::common::random_args::complex_tensor_into_value;
26use crate::builtins::common::spec::{
27    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
28    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
29};
30use crate::builtins::common::tensor;
31
32#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::setdiff")]
33pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
34    name: "setdiff",
35    op_kind: GpuOpKind::Custom("setdiff"),
36    supported_precisions: &[ScalarType::F32, ScalarType::F64],
37    broadcast: BroadcastSemantics::None,
38    provider_hooks: &[ProviderHook::Custom("setdiff")],
39    constant_strategy: ConstantStrategy::InlineLiteral,
40    residency: ResidencyPolicy::GatherImmediately,
41    nan_mode: ReductionNaN::Include,
42    two_pass_threshold: None,
43    workgroup_size: None,
44    accepts_nan_mode: true,
45    notes: "Providers may implement `setdiff`; until then tensors are gathered and processed on the host.",
46};
47
48#[runmat_macros::register_fusion_spec(
49    builtin_path = "crate::builtins::array::sorting_sets::setdiff"
50)]
51pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
52    name: "setdiff",
53    shape: ShapeRequirements::Any,
54    constant_strategy: ConstantStrategy::InlineLiteral,
55    elementwise: None,
56    reduction: None,
57    emits_nan: true,
58    notes: "`setdiff` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
59};
60
/// Builtin name attached to every runtime error raised from this module.
const BUILTIN_NAME: &str = "setdiff";
62
63const SETDIFF_OUTPUT_C: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
64    name: "C",
65    ty: BuiltinParamType::Any,
66    arity: BuiltinParamArity::Required,
67    default: None,
68    description: "Values that appear in A but not in B.",
69}];
70
71const SETDIFF_OUTPUT_C_IA: [BuiltinParamDescriptor; 2] = [
72    BuiltinParamDescriptor {
73        name: "C",
74        ty: BuiltinParamType::Any,
75        arity: BuiltinParamArity::Required,
76        default: None,
77        description: "Values that appear in A but not in B.",
78    },
79    BuiltinParamDescriptor {
80        name: "ia",
81        ty: BuiltinParamType::NumericArray,
82        arity: BuiltinParamArity::Required,
83        default: None,
84        description: "Indices selecting retained values/rows from A.",
85    },
86];
87
88const SETDIFF_INPUTS_A_B: [BuiltinParamDescriptor; 2] = [
89    BuiltinParamDescriptor {
90        name: "A",
91        ty: BuiltinParamType::Any,
92        arity: BuiltinParamArity::Required,
93        default: None,
94        description: "First input array.",
95    },
96    BuiltinParamDescriptor {
97        name: "B",
98        ty: BuiltinParamType::Any,
99        arity: BuiltinParamArity::Required,
100        default: None,
101        description: "Second input array.",
102    },
103];
104
105const SETDIFF_INPUTS_A_B_OPTIONS: [BuiltinParamDescriptor; 3] = [
106    BuiltinParamDescriptor {
107        name: "A",
108        ty: BuiltinParamType::Any,
109        arity: BuiltinParamArity::Required,
110        default: None,
111        description: "First input array.",
112    },
113    BuiltinParamDescriptor {
114        name: "B",
115        ty: BuiltinParamType::Any,
116        arity: BuiltinParamArity::Required,
117        default: None,
118        description: "Second input array.",
119    },
120    BuiltinParamDescriptor {
121        name: "option",
122        ty: BuiltinParamType::StringScalar,
123        arity: BuiltinParamArity::Variadic,
124        default: None,
125        description: "Option tokens: 'rows'|'sorted'|'stable'.",
126    },
127];
128
129const SETDIFF_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
130    BuiltinSignatureDescriptor {
131        label: "C = setdiff(A, B)",
132        inputs: &SETDIFF_INPUTS_A_B,
133        outputs: &SETDIFF_OUTPUT_C,
134    },
135    BuiltinSignatureDescriptor {
136        label: "C = setdiff(A, B, option...)",
137        inputs: &SETDIFF_INPUTS_A_B_OPTIONS,
138        outputs: &SETDIFF_OUTPUT_C,
139    },
140    BuiltinSignatureDescriptor {
141        label: "[C, ia] = setdiff(A, B)",
142        inputs: &SETDIFF_INPUTS_A_B,
143        outputs: &SETDIFF_OUTPUT_C_IA,
144    },
145    BuiltinSignatureDescriptor {
146        label: "[C, ia] = setdiff(A, B, option...)",
147        inputs: &SETDIFF_INPUTS_A_B_OPTIONS,
148        outputs: &SETDIFF_OUTPUT_C_IA,
149    },
150];
151
152const SETDIFF_ERROR_LEGACY_OPTION_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
153    code: "RM.SETDIFF.LEGACY_OPTION_UNSUPPORTED",
154    identifier: Some("RunMat:setdiff:LegacyOptionUnsupported"),
155    when: "Legacy compatibility options are requested.",
156    message: "setdiff: the 'legacy' behaviour is not supported",
157};
158
159const SETDIFF_ERROR_CONFLICTING_ORDER_OPTIONS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
160    code: "RM.SETDIFF.CONFLICTING_ORDER_OPTIONS",
161    identifier: Some("RunMat:setdiff:ConflictingOrderOptions"),
162    when: "Both 'sorted' and 'stable' options are provided.",
163    message: "setdiff: cannot combine 'sorted' with 'stable'",
164};
165
166const SETDIFF_ERROR_UNKNOWN_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
167    code: "RM.SETDIFF.UNKNOWN_OPTION",
168    identifier: Some("RunMat:setdiff:UnknownOption"),
169    when: "An unsupported option token is provided.",
170    message: "setdiff: unrecognised option",
171};
172
173const SETDIFF_ERROR_ROWS_COLUMN_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
174    code: "RM.SETDIFF.ROWS_COLUMN_MISMATCH",
175    identifier: Some("RunMat:setdiff:RowsColumnMismatch"),
176    when: "'rows' mode is used and column counts differ.",
177    message: "setdiff: inputs must have the same number of columns when using 'rows'",
178};
179
180const SETDIFF_ERROR_UNSUPPORTED_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
181    code: "RM.SETDIFF.UNSUPPORTED_INPUT_TYPE",
182    identifier: Some("RunMat:setdiff:UnsupportedInputType"),
183    when: "Input values cannot be converted into supported setdiff domains.",
184    message: "setdiff: unsupported input type",
185};
186
187const SETDIFF_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
188    code: "RM.SETDIFF.INVALID_ARGUMENT",
189    identifier: Some("RunMat:setdiff:InvalidArgument"),
190    when: "Option arguments are not string-like where required.",
191    message: "setdiff: expected string option arguments",
192};
193
194const SETDIFF_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
195    code: "RM.SETDIFF.INTERNAL",
196    identifier: Some("RunMat:setdiff:Internal"),
197    when: "Internal conversion/allocation/provider decode fails.",
198    message: "setdiff: internal operation failed",
199};
200
201const SETDIFF_ERRORS: [BuiltinErrorDescriptor; 7] = [
202    SETDIFF_ERROR_LEGACY_OPTION_UNSUPPORTED,
203    SETDIFF_ERROR_CONFLICTING_ORDER_OPTIONS,
204    SETDIFF_ERROR_UNKNOWN_OPTION,
205    SETDIFF_ERROR_ROWS_COLUMN_MISMATCH,
206    SETDIFF_ERROR_UNSUPPORTED_INPUT_TYPE,
207    SETDIFF_ERROR_INVALID_ARGUMENT,
208    SETDIFF_ERROR_INTERNAL,
209];
210
211pub const SETDIFF_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
212    signatures: &SETDIFF_SIGNATURES,
213    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
214    completion_policy: BuiltinCompletionPolicy::Public,
215    errors: &SETDIFF_ERRORS,
216};
217
218fn setdiff_error_with(
219    error: &'static BuiltinErrorDescriptor,
220    message: impl Into<String>,
221) -> crate::RuntimeError {
222    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
223    if let Some(identifier) = error.identifier {
224        builder = builder.with_identifier(identifier);
225    }
226    builder.build()
227}
228
229fn setdiff_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
230    setdiff_error_with(error, error.message)
231}
232
233fn setdiff_internal_error(message: impl Into<String>) -> crate::RuntimeError {
234    setdiff_error_with(&SETDIFF_ERROR_INTERNAL, message)
235}
236
237#[runtime_builtin(
238    name = "setdiff",
239    category = "array/sorting_sets",
240    summary = "Return values that appear in the first input but not the second.",
241    keywords = "setdiff,difference,stable,rows,indices,gpu",
242    accel = "array_construct",
243    sink = true,
244    type_resolver(set_values_output_type),
245    descriptor(crate::builtins::array::sorting_sets::setdiff::SETDIFF_DESCRIPTOR),
246    builtin_path = "crate::builtins::array::sorting_sets::setdiff"
247)]
248async fn setdiff_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
249    let eval = evaluate(a, b, &rest).await?;
250    if let Some(out_count) = crate::output_count::current_output_count() {
251        if out_count == 0 {
252            return Ok(Value::OutputList(Vec::new()));
253        }
254        if out_count == 1 {
255            return Ok(Value::OutputList(vec![eval.into_values_value()]));
256        }
257        let (values, ia) = eval.into_pair();
258        return Ok(crate::output_count::output_list_with_padding(
259            out_count,
260            vec![values, ia],
261        ));
262    }
263    Ok(eval.into_values_value())
264}
265
266/// Evaluate the `setdiff` builtin once and expose all outputs.
267pub async fn evaluate(
268    a: Value,
269    b: Value,
270    rest: &[Value],
271) -> crate::BuiltinResult<SetdiffEvaluation> {
272    let opts = parse_options(rest)?;
273    match (a, b) {
274        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
275            setdiff_gpu_pair(handle_a, handle_b, &opts).await
276        }
277        (Value::GpuTensor(handle_a), other) => {
278            setdiff_gpu_mixed(handle_a, other, &opts, true).await
279        }
280        (other, Value::GpuTensor(handle_b)) => {
281            setdiff_gpu_mixed(handle_b, other, &opts, false).await
282        }
283        (left, right) => setdiff_host(left, right, &opts),
284    }
285}
286
287fn parse_options(rest: &[Value]) -> crate::BuiltinResult<SetdiffOptions> {
288    let mut opts = SetdiffOptions {
289        rows: false,
290        order: SetdiffOrder::Sorted,
291    };
292    let mut seen_order: Option<SetdiffOrder> = None;
293
294    let tokens = tokens_from_values(rest);
295    for (arg, token) in rest.iter().zip(tokens.iter()) {
296        let text = match token {
297            crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
298            _ => {
299                let text = tensor::value_to_string(arg)
300                    .ok_or_else(|| setdiff_error(&SETDIFF_ERROR_INVALID_ARGUMENT))?;
301                let lowered = text.trim().to_ascii_lowercase();
302                parse_setdiff_option(&mut opts, &mut seen_order, &lowered)?;
303                continue;
304            }
305        };
306        parse_setdiff_option(&mut opts, &mut seen_order, text)?;
307    }
308
309    Ok(opts)
310}
311
312fn parse_setdiff_option(
313    opts: &mut SetdiffOptions,
314    seen_order: &mut Option<SetdiffOrder>,
315    lowered: &str,
316) -> crate::BuiltinResult<()> {
317    match lowered {
318        "rows" => opts.rows = true,
319        "sorted" => {
320            if let Some(prev) = seen_order {
321                if *prev != SetdiffOrder::Sorted {
322                    return Err(setdiff_error(&SETDIFF_ERROR_CONFLICTING_ORDER_OPTIONS));
323                }
324            }
325            *seen_order = Some(SetdiffOrder::Sorted);
326            opts.order = SetdiffOrder::Sorted;
327        }
328        "stable" => {
329            if let Some(prev) = seen_order {
330                if *prev != SetdiffOrder::Stable {
331                    return Err(setdiff_error(&SETDIFF_ERROR_CONFLICTING_ORDER_OPTIONS));
332                }
333            }
334            *seen_order = Some(SetdiffOrder::Stable);
335            opts.order = SetdiffOrder::Stable;
336        }
337        "legacy" | "r2012a" => {
338            return Err(setdiff_error(&SETDIFF_ERROR_LEGACY_OPTION_UNSUPPORTED));
339        }
340        other => {
341            return Err(setdiff_error_with(
342                &SETDIFF_ERROR_UNKNOWN_OPTION,
343                format!("setdiff: unrecognised option '{other}'"),
344            ))
345        }
346    }
347    Ok(())
348}
349
350async fn setdiff_gpu_pair(
351    handle_a: GpuTensorHandle,
352    handle_b: GpuTensorHandle,
353    opts: &SetdiffOptions,
354) -> crate::BuiltinResult<SetdiffEvaluation> {
355    if let Some(provider) = runmat_accelerate_api::provider() {
356        match provider.setdiff(&handle_a, &handle_b, opts).await {
357            Ok(result) => return SetdiffEvaluation::from_setdiff_result(result),
358            Err(_) => {
359                // Fall back to host gather when provider does not support setdiff.
360            }
361        }
362    }
363    let a_tensor = gpu_helpers::gather_tensor_async(&handle_a).await?;
364    let b_tensor = gpu_helpers::gather_tensor_async(&handle_b).await?;
365    setdiff_numeric(a_tensor, b_tensor, opts)
366}
367
368async fn setdiff_gpu_mixed(
369    handle_gpu: GpuTensorHandle,
370    other: Value,
371    opts: &SetdiffOptions,
372    gpu_is_a: bool,
373) -> crate::BuiltinResult<SetdiffEvaluation> {
374    let gpu_tensor = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
375    let other_tensor =
376        tensor::value_into_tensor_for("setdiff", other).map_err(setdiff_internal_error)?;
377    if gpu_is_a {
378        setdiff_numeric(gpu_tensor, other_tensor, opts)
379    } else {
380        setdiff_numeric(other_tensor, gpu_tensor, opts)
381    }
382}
383
384fn setdiff_host(
385    a: Value,
386    b: Value,
387    opts: &SetdiffOptions,
388) -> crate::BuiltinResult<SetdiffEvaluation> {
389    match (a, b) {
390        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => setdiff_complex(at, bt, opts),
391        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
392            let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
393                .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
394            setdiff_complex(at, bt, opts)
395        }
396        (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
397            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
398                .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
399            setdiff_complex(at, bt, opts)
400        }
401        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
402            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
403                .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
404            let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
405                .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
406            setdiff_complex(at, bt, opts)
407        }
408
409        (Value::CharArray(ac), Value::CharArray(bc)) => setdiff_char(ac, bc, opts),
410
411        (Value::StringArray(astring), Value::StringArray(bstring)) => {
412            setdiff_string(astring, bstring, opts)
413        }
414        (Value::StringArray(astring), Value::String(b)) => {
415            let bstring = StringArray::new(vec![b], vec![1, 1])
416                .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
417            setdiff_string(astring, bstring, opts)
418        }
419        (Value::String(a), Value::StringArray(bstring)) => {
420            let astring = StringArray::new(vec![a], vec![1, 1])
421                .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
422            setdiff_string(astring, bstring, opts)
423        }
424        (Value::String(a), Value::String(b)) => {
425            let astring = StringArray::new(vec![a], vec![1, 1])
426                .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
427            let bstring = StringArray::new(vec![b], vec![1, 1])
428                .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
429            setdiff_string(astring, bstring, opts)
430        }
431
432        (left, right) => {
433            let tensor_a = tensor::value_into_tensor_for("setdiff", left)
434                .map_err(|e| setdiff_error_with(&SETDIFF_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
435            let tensor_b = tensor::value_into_tensor_for("setdiff", right)
436                .map_err(|e| setdiff_error_with(&SETDIFF_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
437            setdiff_numeric(tensor_a, tensor_b, opts)
438        }
439    }
440}
441
442fn setdiff_numeric(
443    a: Tensor,
444    b: Tensor,
445    opts: &SetdiffOptions,
446) -> crate::BuiltinResult<SetdiffEvaluation> {
447    if opts.rows {
448        setdiff_numeric_rows(a, b, opts)
449    } else {
450        setdiff_numeric_elements(a, b, opts)
451    }
452}
453
454/// Helper exposed for acceleration providers handling numeric tensors entirely on the host.
455pub fn setdiff_numeric_from_tensors(
456    a: Tensor,
457    b: Tensor,
458    opts: &SetdiffOptions,
459) -> crate::BuiltinResult<SetdiffEvaluation> {
460    setdiff_numeric(a, b, opts)
461}
462
463fn setdiff_numeric_elements(
464    a: Tensor,
465    b: Tensor,
466    opts: &SetdiffOptions,
467) -> crate::BuiltinResult<SetdiffEvaluation> {
468    let mut b_keys: HashSet<u64> = HashSet::new();
469    for &value in &b.data {
470        b_keys.insert(canonicalize_f64(value));
471    }
472
473    let mut seen: HashMap<u64, usize> = HashMap::new();
474    let mut entries = Vec::<NumericDiffEntry>::new();
475    let mut order_counter = 0usize;
476
477    for (idx, &value) in a.data.iter().enumerate() {
478        let key = canonicalize_f64(value);
479        if b_keys.contains(&key) {
480            continue;
481        }
482        if seen.contains_key(&key) {
483            continue;
484        }
485        let entry_idx = entries.len();
486        entries.push(NumericDiffEntry {
487            value,
488            index: idx,
489            order_rank: order_counter,
490        });
491        seen.insert(key, entry_idx);
492        order_counter += 1;
493    }
494
495    assemble_numeric_setdiff(entries, opts)
496}
497
498fn setdiff_numeric_rows(
499    a: Tensor,
500    b: Tensor,
501    opts: &SetdiffOptions,
502) -> crate::BuiltinResult<SetdiffEvaluation> {
503    if a.shape.len() != 2 || b.shape.len() != 2 {
504        return Err(setdiff_internal_error(
505            "setdiff: 'rows' option requires 2-D numeric matrices",
506        ));
507    }
508    if a.shape[1] != b.shape[1] {
509        return Err(setdiff_error(&SETDIFF_ERROR_ROWS_COLUMN_MISMATCH));
510    }
511
512    let rows_a = a.shape[0];
513    let rows_b = b.shape[0];
514    let cols = a.shape[1];
515
516    let mut b_keys: HashSet<NumericRowKey> = HashSet::new();
517    for r in 0..rows_b {
518        let mut row_values = Vec::with_capacity(cols);
519        for c in 0..cols {
520            let idx = r + c * rows_b;
521            row_values.push(b.data[idx]);
522        }
523        b_keys.insert(NumericRowKey::from_slice(&row_values));
524    }
525
526    let mut seen: HashSet<NumericRowKey> = HashSet::new();
527    let mut entries = Vec::<NumericRowDiffEntry>::new();
528    let mut order_counter = 0usize;
529
530    for r in 0..rows_a {
531        let mut row_values = Vec::with_capacity(cols);
532        for c in 0..cols {
533            let idx = r + c * rows_a;
534            row_values.push(a.data[idx]);
535        }
536        let key = NumericRowKey::from_slice(&row_values);
537        if b_keys.contains(&key) {
538            continue;
539        }
540        if !seen.insert(key) {
541            continue;
542        }
543        entries.push(NumericRowDiffEntry {
544            row_data: row_values,
545            row_index: r,
546            order_rank: order_counter,
547        });
548        order_counter += 1;
549    }
550
551    assemble_numeric_row_setdiff(entries, opts, cols)
552}
553
554fn setdiff_complex(
555    a: ComplexTensor,
556    b: ComplexTensor,
557    opts: &SetdiffOptions,
558) -> crate::BuiltinResult<SetdiffEvaluation> {
559    if opts.rows {
560        setdiff_complex_rows(a, b, opts)
561    } else {
562        setdiff_complex_elements(a, b, opts)
563    }
564}
565
566fn setdiff_complex_elements(
567    a: ComplexTensor,
568    b: ComplexTensor,
569    opts: &SetdiffOptions,
570) -> crate::BuiltinResult<SetdiffEvaluation> {
571    let mut b_keys: HashSet<ComplexKey> = HashSet::new();
572    for &value in &b.data {
573        b_keys.insert(ComplexKey::new(value));
574    }
575
576    let mut seen: HashSet<ComplexKey> = HashSet::new();
577    let mut entries = Vec::<ComplexDiffEntry>::new();
578    let mut order_counter = 0usize;
579
580    for (idx, &value) in a.data.iter().enumerate() {
581        let key = ComplexKey::new(value);
582        if b_keys.contains(&key) {
583            continue;
584        }
585        if !seen.insert(key) {
586            continue;
587        }
588        entries.push(ComplexDiffEntry {
589            value,
590            index: idx,
591            order_rank: order_counter,
592        });
593        order_counter += 1;
594    }
595
596    assemble_complex_setdiff(entries, opts)
597}
598
599fn setdiff_complex_rows(
600    a: ComplexTensor,
601    b: ComplexTensor,
602    opts: &SetdiffOptions,
603) -> crate::BuiltinResult<SetdiffEvaluation> {
604    if a.shape.len() != 2 || b.shape.len() != 2 {
605        return Err(setdiff_internal_error(
606            "setdiff: 'rows' option requires 2-D complex matrices",
607        ));
608    }
609    if a.shape[1] != b.shape[1] {
610        return Err(setdiff_error(&SETDIFF_ERROR_ROWS_COLUMN_MISMATCH));
611    }
612
613    let rows_a = a.shape[0];
614    let rows_b = b.shape[0];
615    let cols = a.shape[1];
616
617    let mut b_keys: HashSet<Vec<ComplexKey>> = HashSet::new();
618    for r in 0..rows_b {
619        let mut key_row = Vec::with_capacity(cols);
620        for c in 0..cols {
621            let idx = r + c * rows_b;
622            key_row.push(ComplexKey::new(b.data[idx]));
623        }
624        b_keys.insert(key_row);
625    }
626
627    let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
628    let mut entries = Vec::<ComplexRowDiffEntry>::new();
629    let mut order_counter = 0usize;
630
631    for r in 0..rows_a {
632        let mut row_values = Vec::with_capacity(cols);
633        let mut key_row = Vec::with_capacity(cols);
634        for c in 0..cols {
635            let idx = r + c * rows_a;
636            let value = a.data[idx];
637            row_values.push(value);
638            key_row.push(ComplexKey::new(value));
639        }
640        if b_keys.contains(&key_row) {
641            continue;
642        }
643        if !seen.insert(key_row) {
644            continue;
645        }
646        entries.push(ComplexRowDiffEntry {
647            row_data: row_values,
648            row_index: r,
649            order_rank: order_counter,
650        });
651        order_counter += 1;
652    }
653
654    assemble_complex_row_setdiff(entries, opts, cols)
655}
656
657fn setdiff_char(
658    a: CharArray,
659    b: CharArray,
660    opts: &SetdiffOptions,
661) -> crate::BuiltinResult<SetdiffEvaluation> {
662    if opts.rows {
663        setdiff_char_rows(a, b, opts)
664    } else {
665        setdiff_char_elements(a, b, opts)
666    }
667}
668
669fn setdiff_char_elements(
670    a: CharArray,
671    b: CharArray,
672    opts: &SetdiffOptions,
673) -> crate::BuiltinResult<SetdiffEvaluation> {
674    let mut b_keys: HashSet<u32> = HashSet::new();
675    for ch in &b.data {
676        b_keys.insert(*ch as u32);
677    }
678
679    let mut seen: HashSet<u32> = HashSet::new();
680    let mut entries = Vec::<CharDiffEntry>::new();
681    let mut order_counter = 0usize;
682
683    for col in 0..a.cols {
684        for row in 0..a.rows {
685            let linear_idx = row + col * a.rows;
686            let data_idx = row * a.cols + col;
687            let ch = a.data[data_idx];
688            let key = ch as u32;
689            if b_keys.contains(&key) {
690                continue;
691            }
692            if !seen.insert(key) {
693                continue;
694            }
695            entries.push(CharDiffEntry {
696                ch,
697                index: linear_idx,
698                order_rank: order_counter,
699            });
700            order_counter += 1;
701        }
702    }
703
704    assemble_char_setdiff(entries, opts)
705}
706
707fn setdiff_char_rows(
708    a: CharArray,
709    b: CharArray,
710    opts: &SetdiffOptions,
711) -> crate::BuiltinResult<SetdiffEvaluation> {
712    if a.cols != b.cols {
713        return Err(setdiff_error(&SETDIFF_ERROR_ROWS_COLUMN_MISMATCH));
714    }
715
716    let rows_a = a.rows;
717    let rows_b = b.rows;
718    let cols = a.cols;
719
720    let mut b_keys: HashSet<RowCharKey> = HashSet::new();
721    for r in 0..rows_b {
722        let mut row_values = Vec::with_capacity(cols);
723        for c in 0..cols {
724            let idx = r * cols + c;
725            row_values.push(b.data[idx]);
726        }
727        b_keys.insert(RowCharKey::from_slice(&row_values));
728    }
729
730    let mut seen: HashSet<RowCharKey> = HashSet::new();
731    let mut entries = Vec::<CharRowDiffEntry>::new();
732    let mut order_counter = 0usize;
733
734    for r in 0..rows_a {
735        let mut row_values = Vec::with_capacity(cols);
736        for c in 0..cols {
737            let idx = r * cols + c;
738            row_values.push(a.data[idx]);
739        }
740        let key = RowCharKey::from_slice(&row_values);
741        if b_keys.contains(&key) {
742            continue;
743        }
744        if !seen.insert(key) {
745            continue;
746        }
747        entries.push(CharRowDiffEntry {
748            row_data: row_values,
749            row_index: r,
750            order_rank: order_counter,
751        });
752        order_counter += 1;
753    }
754
755    assemble_char_row_setdiff(entries, opts, cols)
756}
757
758fn setdiff_string(
759    a: StringArray,
760    b: StringArray,
761    opts: &SetdiffOptions,
762) -> crate::BuiltinResult<SetdiffEvaluation> {
763    if opts.rows {
764        setdiff_string_rows(a, b, opts)
765    } else {
766        setdiff_string_elements(a, b, opts)
767    }
768}
769
770fn setdiff_string_elements(
771    a: StringArray,
772    b: StringArray,
773    opts: &SetdiffOptions,
774) -> crate::BuiltinResult<SetdiffEvaluation> {
775    let mut b_keys: HashSet<String> = HashSet::new();
776    for value in &b.data {
777        b_keys.insert(value.clone());
778    }
779
780    let mut seen: HashSet<String> = HashSet::new();
781    let mut entries = Vec::<StringDiffEntry>::new();
782    let mut order_counter = 0usize;
783
784    for (idx, value) in a.data.iter().enumerate() {
785        if b_keys.contains(value) {
786            continue;
787        }
788        if !seen.insert(value.clone()) {
789            continue;
790        }
791        entries.push(StringDiffEntry {
792            value: value.clone(),
793            index: idx,
794            order_rank: order_counter,
795        });
796        order_counter += 1;
797    }
798
799    assemble_string_setdiff(entries, opts)
800}
801
802fn setdiff_string_rows(
803    a: StringArray,
804    b: StringArray,
805    opts: &SetdiffOptions,
806) -> crate::BuiltinResult<SetdiffEvaluation> {
807    if a.shape.len() != 2 || b.shape.len() != 2 {
808        return Err(setdiff_internal_error(
809            "setdiff: 'rows' option requires 2-D string arrays",
810        ));
811    }
812    if a.shape[1] != b.shape[1] {
813        return Err(setdiff_error(&SETDIFF_ERROR_ROWS_COLUMN_MISMATCH));
814    }
815
816    let rows_a = a.shape[0];
817    let rows_b = b.shape[0];
818    let cols = a.shape[1];
819
820    let mut b_keys: HashSet<RowStringKey> = HashSet::new();
821    for r in 0..rows_b {
822        let mut row_values = Vec::with_capacity(cols);
823        for c in 0..cols {
824            let idx = r + c * rows_b;
825            row_values.push(b.data[idx].clone());
826        }
827        b_keys.insert(RowStringKey(row_values.clone()));
828    }
829
830    let mut seen: HashSet<RowStringKey> = HashSet::new();
831    let mut entries = Vec::<StringRowDiffEntry>::new();
832    let mut order_counter = 0usize;
833
834    for r in 0..rows_a {
835        let mut row_values = Vec::with_capacity(cols);
836        for c in 0..cols {
837            let idx = r + c * rows_a;
838            row_values.push(a.data[idx].clone());
839        }
840        let key = RowStringKey(row_values.clone());
841        if b_keys.contains(&key) {
842            continue;
843        }
844        if !seen.insert(key) {
845            continue;
846        }
847        entries.push(StringRowDiffEntry {
848            row_data: row_values,
849            row_index: r,
850            order_rank: order_counter,
851        });
852        order_counter += 1;
853    }
854
855    assemble_string_row_setdiff(entries, opts, cols)
856}
857
858fn assemble_numeric_setdiff(
859    entries: Vec<NumericDiffEntry>,
860    opts: &SetdiffOptions,
861) -> crate::BuiltinResult<SetdiffEvaluation> {
862    let mut order: Vec<usize> = (0..entries.len()).collect();
863    match opts.order {
864        SetdiffOrder::Sorted => {
865            order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
866        }
867        SetdiffOrder::Stable => {
868            order.sort_by_key(|&idx| entries[idx].order_rank);
869        }
870    }
871
872    let mut values = Vec::with_capacity(order.len());
873    let mut ia = Vec::with_capacity(order.len());
874    for &idx in &order {
875        let entry = &entries[idx];
876        values.push(entry.value);
877        ia.push((entry.index + 1) as f64);
878    }
879
880    let value_tensor = Tensor::new(values, vec![order.len(), 1])
881        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
882    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
883        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
884
885    Ok(SetdiffEvaluation::new(
886        Value::Tensor(value_tensor),
887        ia_tensor,
888    ))
889}
890
891fn assemble_numeric_row_setdiff(
892    entries: Vec<NumericRowDiffEntry>,
893    opts: &SetdiffOptions,
894    cols: usize,
895) -> crate::BuiltinResult<SetdiffEvaluation> {
896    let mut order: Vec<usize> = (0..entries.len()).collect();
897    match opts.order {
898        SetdiffOrder::Sorted => {
899            order.sort_by(|&lhs, &rhs| {
900                compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
901            });
902        }
903        SetdiffOrder::Stable => {
904            order.sort_by_key(|&idx| entries[idx].order_rank);
905        }
906    }
907
908    let unique_rows = order.len();
909    let mut values = vec![0.0f64; unique_rows * cols];
910    let mut ia = Vec::with_capacity(unique_rows);
911
912    for (row_pos, &entry_idx) in order.iter().enumerate() {
913        let entry = &entries[entry_idx];
914        for col in 0..cols {
915            let dest = row_pos + col * unique_rows;
916            values[dest] = entry.row_data[col];
917        }
918        ia.push((entry.row_index + 1) as f64);
919    }
920
921    let value_tensor = Tensor::new(values, vec![unique_rows, cols])
922        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
923    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
924        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
925
926    Ok(SetdiffEvaluation::new(
927        Value::Tensor(value_tensor),
928        ia_tensor,
929    ))
930}
931
932fn assemble_complex_setdiff(
933    entries: Vec<ComplexDiffEntry>,
934    opts: &SetdiffOptions,
935) -> crate::BuiltinResult<SetdiffEvaluation> {
936    let mut order: Vec<usize> = (0..entries.len()).collect();
937    match opts.order {
938        SetdiffOrder::Sorted => {
939            order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
940        }
941        SetdiffOrder::Stable => {
942            order.sort_by_key(|&idx| entries[idx].order_rank);
943        }
944    }
945
946    let mut values = Vec::with_capacity(order.len());
947    let mut ia = Vec::with_capacity(order.len());
948    for &idx in &order {
949        let entry = &entries[idx];
950        values.push(entry.value);
951        ia.push((entry.index + 1) as f64);
952    }
953
954    let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
955        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
956    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
957        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
958
959    Ok(SetdiffEvaluation::new(
960        complex_tensor_into_value(value_tensor),
961        ia_tensor,
962    ))
963}
964
965fn assemble_complex_row_setdiff(
966    entries: Vec<ComplexRowDiffEntry>,
967    opts: &SetdiffOptions,
968    cols: usize,
969) -> crate::BuiltinResult<SetdiffEvaluation> {
970    let mut order: Vec<usize> = (0..entries.len()).collect();
971    match opts.order {
972        SetdiffOrder::Sorted => {
973            order.sort_by(|&lhs, &rhs| {
974                compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
975            });
976        }
977        SetdiffOrder::Stable => {
978            order.sort_by_key(|&idx| entries[idx].order_rank);
979        }
980    }
981
982    let unique_rows = order.len();
983    let mut values = vec![(0.0f64, 0.0f64); unique_rows * cols];
984    let mut ia = Vec::with_capacity(unique_rows);
985
986    for (row_pos, &entry_idx) in order.iter().enumerate() {
987        let entry = &entries[entry_idx];
988        for col in 0..cols {
989            let dest = row_pos + col * unique_rows;
990            values[dest] = entry.row_data[col];
991        }
992        ia.push((entry.row_index + 1) as f64);
993    }
994
995    let value_tensor = ComplexTensor::new(values, vec![unique_rows, cols])
996        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
997    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
998        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
999
1000    Ok(SetdiffEvaluation::new(
1001        complex_tensor_into_value(value_tensor),
1002        ia_tensor,
1003    ))
1004}
1005
1006fn assemble_char_setdiff(
1007    entries: Vec<CharDiffEntry>,
1008    opts: &SetdiffOptions,
1009) -> crate::BuiltinResult<SetdiffEvaluation> {
1010    let mut order: Vec<usize> = (0..entries.len()).collect();
1011    match opts.order {
1012        SetdiffOrder::Sorted => {
1013            order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1014        }
1015        SetdiffOrder::Stable => {
1016            order.sort_by_key(|&idx| entries[idx].order_rank);
1017        }
1018    }
1019
1020    let mut values = Vec::with_capacity(order.len());
1021    let mut ia = Vec::with_capacity(order.len());
1022    for &idx in &order {
1023        let entry = &entries[idx];
1024        values.push(entry.ch);
1025        ia.push((entry.index + 1) as f64);
1026    }
1027
1028    let value_array = CharArray::new(values, order.len(), 1)
1029        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1030    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1031        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1032
1033    Ok(SetdiffEvaluation::new(
1034        Value::CharArray(value_array),
1035        ia_tensor,
1036    ))
1037}
1038
1039fn assemble_char_row_setdiff(
1040    entries: Vec<CharRowDiffEntry>,
1041    opts: &SetdiffOptions,
1042    cols: usize,
1043) -> crate::BuiltinResult<SetdiffEvaluation> {
1044    let mut order: Vec<usize> = (0..entries.len()).collect();
1045    match opts.order {
1046        SetdiffOrder::Sorted => {
1047            order.sort_by(|&lhs, &rhs| {
1048                compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1049            });
1050        }
1051        SetdiffOrder::Stable => {
1052            order.sort_by_key(|&idx| entries[idx].order_rank);
1053        }
1054    }
1055
1056    let unique_rows = order.len();
1057    let mut values = vec!['\0'; unique_rows * cols];
1058    let mut ia = Vec::with_capacity(unique_rows);
1059
1060    for (row_pos, &entry_idx) in order.iter().enumerate() {
1061        let entry = &entries[entry_idx];
1062        for col in 0..cols {
1063            let dest = row_pos * cols + col;
1064            values[dest] = entry.row_data[col];
1065        }
1066        ia.push((entry.row_index + 1) as f64);
1067    }
1068
1069    let value_array = CharArray::new(values, unique_rows, cols)
1070        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1071    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
1072        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1073
1074    Ok(SetdiffEvaluation::new(
1075        Value::CharArray(value_array),
1076        ia_tensor,
1077    ))
1078}
1079
1080fn assemble_string_setdiff(
1081    entries: Vec<StringDiffEntry>,
1082    opts: &SetdiffOptions,
1083) -> crate::BuiltinResult<SetdiffEvaluation> {
1084    let mut order: Vec<usize> = (0..entries.len()).collect();
1085    match opts.order {
1086        SetdiffOrder::Sorted => {
1087            order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1088        }
1089        SetdiffOrder::Stable => {
1090            order.sort_by_key(|&idx| entries[idx].order_rank);
1091        }
1092    }
1093
1094    let mut values = Vec::with_capacity(order.len());
1095    let mut ia = Vec::with_capacity(order.len());
1096    for &idx in &order {
1097        let entry = &entries[idx];
1098        values.push(entry.value.clone());
1099        ia.push((entry.index + 1) as f64);
1100    }
1101
1102    let value_array = StringArray::new(values, vec![order.len(), 1])
1103        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1104    let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1105        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1106
1107    Ok(SetdiffEvaluation::new(
1108        Value::StringArray(value_array),
1109        ia_tensor,
1110    ))
1111}
1112
1113fn assemble_string_row_setdiff(
1114    entries: Vec<StringRowDiffEntry>,
1115    opts: &SetdiffOptions,
1116    cols: usize,
1117) -> crate::BuiltinResult<SetdiffEvaluation> {
1118    let mut order: Vec<usize> = (0..entries.len()).collect();
1119    match opts.order {
1120        SetdiffOrder::Sorted => {
1121            order.sort_by(|&lhs, &rhs| {
1122                compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1123            });
1124        }
1125        SetdiffOrder::Stable => {
1126            order.sort_by_key(|&idx| entries[idx].order_rank);
1127        }
1128    }
1129
1130    let unique_rows = order.len();
1131    let mut values = vec![String::new(); unique_rows * cols];
1132    let mut ia = Vec::with_capacity(unique_rows);
1133
1134    for (row_pos, &entry_idx) in order.iter().enumerate() {
1135        let entry = &entries[entry_idx];
1136        for col in 0..cols {
1137            let dest = row_pos + col * unique_rows;
1138            values[dest] = entry.row_data[col].clone();
1139        }
1140        ia.push((entry.row_index + 1) as f64);
1141    }
1142
1143    let value_array = StringArray::new(values, vec![unique_rows, cols])
1144        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1145    let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
1146        .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1147
1148    Ok(SetdiffEvaluation::new(
1149        Value::StringArray(value_array),
1150        ia_tensor,
1151    ))
1152}
1153
/// One real-valued candidate consumed by `assemble_numeric_setdiff`.
#[derive(Debug, Clone, Copy)]
struct NumericDiffEntry {
    /// Value written to the output tensor.
    value: f64,
    /// Zero-based position; reported in `ia` as `index + 1`.
    index: usize,
    /// Sort key used when stable ordering is requested.
    order_rank: usize,
}
1160
/// One real-valued candidate row consumed by `assemble_numeric_row_setdiff`.
#[derive(Debug, Clone)]
struct NumericRowDiffEntry {
    /// Cells of the row, indexed by column.
    row_data: Vec<f64>,
    /// Zero-based row position; reported in `ia` as `row_index + 1`.
    row_index: usize,
    /// Sort key used when stable ordering is requested.
    order_rank: usize,
}
1167
/// One complex candidate consumed by `assemble_complex_setdiff`.
#[derive(Debug, Clone, Copy)]
struct ComplexDiffEntry {
    /// Value as a `(re, im)` pair, written to the output tensor.
    value: (f64, f64),
    /// Zero-based position; reported in `ia` as `index + 1`.
    index: usize,
    /// Sort key used when stable ordering is requested.
    order_rank: usize,
}
1174
/// One complex candidate row consumed by `assemble_complex_row_setdiff`.
#[derive(Debug, Clone)]
struct ComplexRowDiffEntry {
    /// Cells of the row as `(re, im)` pairs, indexed by column.
    row_data: Vec<(f64, f64)>,
    /// Zero-based row position; reported in `ia` as `row_index + 1`.
    row_index: usize,
    /// Sort key used when stable ordering is requested.
    order_rank: usize,
}
1181
/// One character candidate consumed by `assemble_char_setdiff`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CharDiffEntry {
    /// Character written to the output array.
    ch: char,
    /// Zero-based position; reported in `ia` as `index + 1`.
    index: usize,
    /// Sort key used when stable ordering is requested.
    order_rank: usize,
}
1188
/// One character candidate row consumed by `assemble_char_row_setdiff`.
#[derive(Debug, Clone)]
struct CharRowDiffEntry {
    /// Characters of the row, indexed by column.
    row_data: Vec<char>,
    /// Zero-based row position; reported in `ia` as `row_index + 1`.
    row_index: usize,
    /// Sort key used when stable ordering is requested.
    order_rank: usize,
}
1195
/// One string candidate consumed by `assemble_string_setdiff`.
#[derive(Debug, Clone)]
struct StringDiffEntry {
    /// String written to the output array.
    value: String,
    /// Zero-based position; reported in `ia` as `index + 1`.
    index: usize,
    /// Sort key used when stable ordering is requested.
    order_rank: usize,
}
1202
/// One string candidate row consumed by `assemble_string_row_setdiff`.
#[derive(Debug, Clone)]
struct StringRowDiffEntry {
    /// Cells of the row, indexed by column.
    row_data: Vec<String>,
    /// Zero-based row position; reported in `ia` as `row_index + 1`.
    row_index: usize,
    /// Sort key used when stable ordering is requested.
    order_rank: usize,
}
1209
1210#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1211struct NumericRowKey(Vec<u64>);
1212
1213impl NumericRowKey {
1214    fn from_slice(values: &[f64]) -> Self {
1215        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1216    }
1217}
1218
1219#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1220struct ComplexKey {
1221    re: u64,
1222    im: u64,
1223}
1224
1225impl ComplexKey {
1226    fn new(value: (f64, f64)) -> Self {
1227        Self {
1228            re: canonicalize_f64(value.0),
1229            im: canonicalize_f64(value.1),
1230        }
1231    }
1232}
1233
/// Hashable identity of a character row, stored as Unicode scalar values.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key by converting every character to its `u32` code point.
    fn from_slice(values: &[char]) -> Self {
        let code_points = values.iter().copied().map(u32::from).collect();
        Self(code_points)
    }
}
1242
/// Hashable identity of a string row: the row's cells, compared exactly.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1245
1246#[derive(Debug)]
1247pub struct SetdiffEvaluation {
1248    values: Value,
1249    ia: Tensor,
1250}
1251
1252impl SetdiffEvaluation {
1253    fn new(values: Value, ia: Tensor) -> Self {
1254        Self { values, ia }
1255    }
1256
1257    pub fn from_setdiff_result(result: SetdiffResult) -> crate::BuiltinResult<Self> {
1258        let SetdiffResult { values, ia } = result;
1259        let values_tensor = Tensor::new(values.data, values.shape)
1260            .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1261        let ia_tensor = Tensor::new(ia.data, ia.shape)
1262            .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1263        Ok(SetdiffEvaluation::new(
1264            Value::Tensor(values_tensor),
1265            ia_tensor,
1266        ))
1267    }
1268
1269    pub fn into_numeric_setdiff_result(self) -> crate::BuiltinResult<SetdiffResult> {
1270        let SetdiffEvaluation { values, ia } = self;
1271        let values_tensor = tensor::value_into_tensor_for("setdiff", values)
1272            .map_err(|e| setdiff_internal_error(e))?;
1273        Ok(SetdiffResult {
1274            values: HostTensorOwned {
1275                data: values_tensor.data,
1276                shape: values_tensor.shape,
1277                storage: GpuTensorStorage::Real,
1278            },
1279            ia: HostTensorOwned {
1280                data: ia.data,
1281                shape: ia.shape,
1282                storage: GpuTensorStorage::Real,
1283            },
1284        })
1285    }
1286
1287    pub fn into_values_value(self) -> Value {
1288        self.values
1289    }
1290
1291    pub fn into_pair(self) -> (Value, Value) {
1292        let ia = tensor::tensor_into_value(self.ia);
1293        (self.values, ia)
1294    }
1295
1296    pub fn values_value(&self) -> Value {
1297        self.values.clone()
1298    }
1299
1300    pub fn ia_value(&self) -> Value {
1301        tensor::tensor_into_value(self.ia.clone())
1302    }
1303}
1304
/// Maps a float to a bit pattern usable as a hash/equality key.
///
/// Every NaN collapses to one fixed pattern and both `0.0` and `-0.0` collapse
/// to zero; all other values keep their raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;

    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0` holds, so this arm covers both signed zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1314
/// Total ordering over floats in which NaN sorts after every other value.
///
/// Two NaNs compare equal; non-NaN values use the ordinary numeric comparison
/// (so `0.0` and `-0.0` are equal).
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // With NaNs excluded `partial_cmp` always yields `Some`; the fallback
        // only keeps the expression total.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1328
1329fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1330    for (lhs, rhs) in a.iter().zip(b.iter()) {
1331        let ord = compare_f64(*lhs, *rhs);
1332        if ord != Ordering::Equal {
1333            return ord;
1334        }
1335    }
1336    Ordering::Equal
1337}
1338
/// Reports whether either component of a `(re, im)` pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (re, im) = value;
    re.is_nan() || im.is_nan()
}
1342
1343fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1344    match (complex_is_nan(a), complex_is_nan(b)) {
1345        (true, true) => Ordering::Equal,
1346        (true, false) => Ordering::Greater,
1347        (false, true) => Ordering::Less,
1348        (false, false) => {
1349            let mag_a = a.0.hypot(a.1);
1350            let mag_b = b.0.hypot(b.1);
1351            let mag_cmp = compare_f64(mag_a, mag_b);
1352            if mag_cmp != Ordering::Equal {
1353                return mag_cmp;
1354            }
1355            let re_cmp = compare_f64(a.0, b.0);
1356            if re_cmp != Ordering::Equal {
1357                return re_cmp;
1358            }
1359            compare_f64(a.1, b.1)
1360        }
1361    }
1362}
1363
1364fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1365    for (lhs, rhs) in a.iter().zip(b.iter()) {
1366        let ord = compare_complex(*lhs, *rhs);
1367        if ord != Ordering::Equal {
1368            return ord;
1369        }
1370    }
1371    Ordering::Equal
1372}
1373
/// Lexicographic comparison of two character rows using `char` ordering.
///
/// Only the overlapping prefix is inspected; if one row is shorter and that
/// prefix matches, the rows compare equal.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|&ord| ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1383
/// Lexicographic comparison of two string rows using `String` ordering per cell.
///
/// Only the overlapping prefix is inspected; if one row is shorter and that
/// prefix matches, the rows compare equal.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|&ord| ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1393
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CharArray, ResolveContext, StringArray, Tensor, Type, Value};

    /// Drives `evaluate` to completion on the calling thread so the tests can
    /// stay synchronous.
    fn evaluate_sync(
        a: Value,
        b: Value,
        rest: &[Value],
    ) -> crate::BuiltinResult<SetdiffEvaluation> {
        let pending = evaluate(a, b, rest);
        futures::executor::block_on(pending)
    }

    /// Builds an N-by-1 real tensor from `data`.
    fn column(data: Vec<f64>) -> Tensor {
        let rows = data.len();
        Tensor::new(data, vec![rows, 1]).unwrap()
    }

    /// Unwraps a `Value::Tensor`, panicking on any other variant.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    /// Extracts the `ia` output of an evaluation as plain data.
    fn ia_data(eval: &SetdiffEvaluation) -> Vec<f64> {
        tensor::value_into_tensor_for("setdiff", eval.ia_value())
            .expect("ia tensor")
            .data
    }

    /// Converts string literals into owned strings.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_sorted_default() {
        let a = column(vec![5.0, 7.0, 5.0, 1.0]);
        let b = column(vec![7.0, 1.0, 3.0]);
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("setdiff");

        let values = expect_tensor(eval.values_value());
        assert_eq!(values.shape, vec![1, 1]);
        assert_eq!(values.data, vec![5.0]);
        assert_eq!(ia_data(&eval), vec![1.0]);
    }

    #[test]
    fn setdiff_type_resolver_numeric() {
        let resolved = set_values_output_type(
            &[Type::tensor(), Type::tensor()],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(resolved, Type::tensor());
    }

    #[test]
    fn setdiff_type_resolver_string_array() {
        let resolved = set_values_output_type(
            &[Type::cell_of(Type::String), Type::String],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(resolved, Type::cell_of(Type::String));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_stable() {
        let a = column(vec![4.0, 2.0, 4.0, 1.0, 3.0]);
        let b = column(vec![3.0, 4.0, 5.0, 1.0]);
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("stable")])
            .expect("setdiff");

        let values = expect_tensor(eval.values_value());
        assert_eq!(values.shape, vec![1, 1]);
        assert_eq!(values.data, vec![2.0]);
        assert_eq!(ia_data(&eval), vec![2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_rows_sorted() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")])
            .expect("setdiff");

        let values = expect_tensor(eval.values_value());
        assert_eq!(values.shape, vec![1, 2]);
        assert_eq!(values.data, vec![1.0, 2.0]);
        assert_eq!(ia_data(&eval), vec![1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_removes_nan() {
        let a = column(vec![f64::NAN, 2.0, 3.0]);
        let b = column(vec![f64::NAN]);
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("setdiff");

        let values = tensor::value_into_tensor_for("setdiff", eval.values_value()).expect("values");
        assert_eq!(values.data, vec![2.0, 3.0]);
        assert_eq!(ia_data(&eval), vec![2.0, 3.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_char_elements() {
        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
        let eval = evaluate_sync(Value::CharArray(a), Value::CharArray(b), &[]).expect("setdiff");

        let arr = match eval.values_value() {
            Value::CharArray(arr) => arr,
            other => panic!("expected char array, got {other:?}"),
        };
        assert_eq!(arr.rows, 1);
        assert_eq!(arr.cols, 1);
        assert_eq!(arr.data, vec!['z']);
        assert_eq!(ia_data(&eval), vec![3.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_string_rows_stable() {
        let a = StringArray::new(strings(&["alpha", "gamma", "beta", "beta"]), vec![2, 2]).unwrap();
        let b = StringArray::new(strings(&["gamma", "delta", "beta", "beta"]), vec![2, 2]).unwrap();
        let options = [Value::from("rows"), Value::from("stable")];
        let eval = evaluate_sync(Value::StringArray(a), Value::StringArray(b), &options)
            .expect("setdiff");

        let arr = match eval.values_value() {
            Value::StringArray(arr) => arr,
            other => panic!("expected string array, got {other:?}"),
        };
        assert_eq!(arr.shape, vec![1, 2]);
        assert_eq!(arr.data, strings(&["alpha", "beta"]));
        assert_eq!(ia_data(&eval), vec![1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_type_mismatch_errors() {
        let err = evaluate_sync(Value::from(1.0), Value::String("a".into()), &[]).unwrap_err();
        let expected = SETDIFF_ERROR_UNSUPPORTED_INPUT_TYPE.identifier;
        assert_eq!(err.identifier(), expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_rows_dimension_mismatch_reports_identifier() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("tensor a");
        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).expect("tensor b");
        let err =
            evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")]).unwrap_err();
        let expected = SETDIFF_ERROR_ROWS_COLUMN_MISMATCH.identifier;
        assert_eq!(err.identifier(), expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_rejects_legacy_option() {
        let err = evaluate_sync(Value::from(1.0), Value::from(2.0), &[Value::from("legacy")])
            .unwrap_err();
        let expected = SETDIFF_ERROR_LEGACY_OPTION_UNSUPPORTED.identifier;
        assert_eq!(err.identifier(), expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_rejects_conflicting_order_options() {
        let options = [Value::from("stable"), Value::from("sorted")];
        let err = evaluate_sync(Value::from(1.0), Value::from(2.0), &options).unwrap_err();
        let expected = SETDIFF_ERROR_CONFLICTING_ORDER_OPTIONS.identifier;
        assert_eq!(err.identifier(), expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_rejects_unknown_option() {
        let err =
            evaluate_sync(Value::from(1.0), Value::from(2.0), &[Value::from("bogus")]).unwrap_err();
        assert_eq!(err.identifier(), SETDIFF_ERROR_UNKNOWN_OPTION.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor_a = column(vec![10.0, 4.0, 6.0, 4.0]);
            let tensor_b = column(vec![6.0, 4.0, 2.0]);
            let view_a = HostTensorView {
                data: &tensor_a.data,
                shape: &tensor_a.shape,
            };
            let view_b = HostTensorView {
                data: &tensor_b.data,
                shape: &tensor_b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload a");
            let handle_b = provider.upload(&view_b).expect("upload b");

            let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("setdiff");
            let values = expect_tensor(eval.values_value());
            assert_eq!(values.data, vec![10.0]);
            assert_eq!(ia_data(&eval), vec![1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn setdiff_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = column(vec![8.0, 4.0, 2.0, 4.0]);
        let b = column(vec![2.0, 5.0]);

        // Reference result computed from host tensors.
        let cpu_eval = evaluate_sync(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[])
            .expect("setdiff");
        let cpu_values = tensor::value_into_tensor_for("setdiff", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("setdiff", cpu_eval.ia_value()).unwrap();

        // Same inputs, uploaded through the registered provider.
        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
            .expect("setdiff");
        let gpu_values = tensor::value_into_tensor_for("setdiff", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("setdiff", gpu_eval.ia_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_values.shape, cpu_values.shape);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ia.shape, cpu_ia.shape);
    }
}