1use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10
11use runmat_accelerate_api::{
12 GpuTensorHandle, GpuTensorStorage, HostTensorOwned, SetdiffOptions, SetdiffOrder, SetdiffResult,
13};
14use runmat_builtins::{
15 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
16 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
17 CharArray, ComplexTensor, StringArray, Tensor, Value,
18};
19use runmat_macros::runtime_builtin;
20
21use super::type_resolvers::set_values_output_type;
22use crate::build_runtime_error;
23use crate::builtins::common::arg_tokens::tokens_from_values;
24use crate::builtins::common::gpu_helpers;
25use crate::builtins::common::random_args::complex_tensor_into_value;
26use crate::builtins::common::spec::{
27 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
28 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
29};
30use crate::builtins::common::tensor;
31
// GPU dispatch metadata for the `setdiff` builtin, registered against this
// module's builtin path.
//
// The spec declares a custom op kind and a custom provider hook, both named
// "setdiff", with F32/F64 as the supported precisions. As stated in `notes`,
// tensors are gathered and processed on the host until a provider implements
// the hook.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::setdiff")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "setdiff",
    op_kind: GpuOpKind::Custom("setdiff"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("setdiff")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: true,
    notes: "Providers may implement `setdiff`; until then tensors are gathered and processed on the host.",
};
47
// Fusion metadata for the `setdiff` builtin, registered against this module's
// builtin path.
//
// No elementwise or reduction fusion is declared (`None` for both); per
// `notes`, `setdiff` terminates fusion chains and materialises its results on
// the host.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::setdiff"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "setdiff",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "`setdiff` terminates fusion chains and materialises results on the host; upstream tensors are gathered when necessary.",
};
60
/// Builtin name attached to every runtime error produced by this module.
const BUILTIN_NAME: &str = "setdiff";
62
/// Output descriptor list for the single-output form: only `C`.
const SETDIFF_OUTPUT_C: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "C",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Values that appear in A but not in B.",
}];
70
/// Output descriptor list for the two-output form: `C` followed by the index
/// vector `ia`.
const SETDIFF_OUTPUT_C_IA: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "C",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Values that appear in A but not in B.",
    },
    BuiltinParamDescriptor {
        name: "ia",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Indices selecting retained values/rows from A.",
    },
];
87
/// Input descriptor list for the option-free call form: the two required
/// arrays `A` and `B`.
const SETDIFF_INPUTS_A_B: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "First input array.",
    },
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Second input array.",
    },
];
104
/// Input descriptor list for the call form that accepts trailing options:
/// `A`, `B`, then a variadic string `option` parameter.
const SETDIFF_INPUTS_A_B_OPTIONS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "First input array.",
    },
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Second input array.",
    },
    BuiltinParamDescriptor {
        name: "option",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Option tokens: 'rows'|'sorted'|'stable'.",
    },
];
128
/// The four advertised call signatures: one or two outputs, each with and
/// without trailing option tokens.
const SETDIFF_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
    BuiltinSignatureDescriptor {
        label: "C = setdiff(A, B)",
        inputs: &SETDIFF_INPUTS_A_B,
        outputs: &SETDIFF_OUTPUT_C,
    },
    BuiltinSignatureDescriptor {
        label: "C = setdiff(A, B, option...)",
        inputs: &SETDIFF_INPUTS_A_B_OPTIONS,
        outputs: &SETDIFF_OUTPUT_C,
    },
    BuiltinSignatureDescriptor {
        label: "[C, ia] = setdiff(A, B)",
        inputs: &SETDIFF_INPUTS_A_B,
        outputs: &SETDIFF_OUTPUT_C_IA,
    },
    BuiltinSignatureDescriptor {
        label: "[C, ia] = setdiff(A, B, option...)",
        inputs: &SETDIFF_INPUTS_A_B_OPTIONS,
        outputs: &SETDIFF_OUTPUT_C_IA,
    },
];
151
/// Raised when a legacy compatibility option is requested.
const SETDIFF_ERROR_LEGACY_OPTION_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SETDIFF.LEGACY_OPTION_UNSUPPORTED",
    identifier: Some("RunMat:setdiff:LegacyOptionUnsupported"),
    when: "Legacy compatibility options are requested.",
    message: "setdiff: the 'legacy' behaviour is not supported",
};

/// Raised when both `'sorted'` and `'stable'` are supplied in one call.
const SETDIFF_ERROR_CONFLICTING_ORDER_OPTIONS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SETDIFF.CONFLICTING_ORDER_OPTIONS",
    identifier: Some("RunMat:setdiff:ConflictingOrderOptions"),
    when: "Both 'sorted' and 'stable' options are provided.",
    message: "setdiff: cannot combine 'sorted' with 'stable'",
};

/// Raised for an option token that is not recognised.
const SETDIFF_ERROR_UNKNOWN_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SETDIFF.UNKNOWN_OPTION",
    identifier: Some("RunMat:setdiff:UnknownOption"),
    when: "An unsupported option token is provided.",
    message: "setdiff: unrecognised option",
};

/// Raised in `'rows'` mode when the inputs have different column counts.
const SETDIFF_ERROR_ROWS_COLUMN_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SETDIFF.ROWS_COLUMN_MISMATCH",
    identifier: Some("RunMat:setdiff:RowsColumnMismatch"),
    when: "'rows' mode is used and column counts differ.",
    message: "setdiff: inputs must have the same number of columns when using 'rows'",
};

/// Raised when an input value cannot be converted into a supported domain.
const SETDIFF_ERROR_UNSUPPORTED_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SETDIFF.UNSUPPORTED_INPUT_TYPE",
    identifier: Some("RunMat:setdiff:UnsupportedInputType"),
    when: "Input values cannot be converted into supported setdiff domains.",
    message: "setdiff: unsupported input type",
};

/// Raised when an option argument is not string-like.
const SETDIFF_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SETDIFF.INVALID_ARGUMENT",
    identifier: Some("RunMat:setdiff:InvalidArgument"),
    when: "Option arguments are not string-like where required.",
    message: "setdiff: expected string option arguments",
};

/// Raised when an internal conversion, allocation, or provider decode fails.
const SETDIFF_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SETDIFF.INTERNAL",
    identifier: Some("RunMat:setdiff:Internal"),
    when: "Internal conversion/allocation/provider decode fails.",
    message: "setdiff: internal operation failed",
};
200
/// All error descriptors published by the `setdiff` builtin descriptor.
const SETDIFF_ERRORS: [BuiltinErrorDescriptor; 7] = [
    SETDIFF_ERROR_LEGACY_OPTION_UNSUPPORTED,
    SETDIFF_ERROR_CONFLICTING_ORDER_OPTIONS,
    SETDIFF_ERROR_UNKNOWN_OPTION,
    SETDIFF_ERROR_ROWS_COLUMN_MISMATCH,
    SETDIFF_ERROR_UNSUPPORTED_INPUT_TYPE,
    SETDIFF_ERROR_INVALID_ARGUMENT,
    SETDIFF_ERROR_INTERNAL,
];
210
/// Public descriptor for `setdiff`, bundling its signatures and error table.
///
/// The output mode is `ByRequestedOutputCount`, matching the one- and
/// two-output signatures listed in `SETDIFF_SIGNATURES`.
pub const SETDIFF_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &SETDIFF_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &SETDIFF_ERRORS,
};
217
218fn setdiff_error_with(
219 error: &'static BuiltinErrorDescriptor,
220 message: impl Into<String>,
221) -> crate::RuntimeError {
222 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
223 if let Some(identifier) = error.identifier {
224 builder = builder.with_identifier(identifier);
225 }
226 builder.build()
227}
228
/// Builds a runtime error from a descriptor, using the descriptor's own
/// `message` as the error text.
fn setdiff_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
    setdiff_error_with(error, error.message)
}
232
/// Builds a runtime error with the `SETDIFF_ERROR_INTERNAL` identifier and a
/// caller-supplied message.
fn setdiff_internal_error(message: impl Into<String>) -> crate::RuntimeError {
    setdiff_error_with(&SETDIFF_ERROR_INTERNAL, message)
}
236
237#[runtime_builtin(
238 name = "setdiff",
239 category = "array/sorting_sets",
240 summary = "Return values that appear in the first input but not the second.",
241 keywords = "setdiff,difference,stable,rows,indices,gpu",
242 accel = "array_construct",
243 sink = true,
244 type_resolver(set_values_output_type),
245 descriptor(crate::builtins::array::sorting_sets::setdiff::SETDIFF_DESCRIPTOR),
246 builtin_path = "crate::builtins::array::sorting_sets::setdiff"
247)]
248async fn setdiff_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
249 let eval = evaluate(a, b, &rest).await?;
250 if let Some(out_count) = crate::output_count::current_output_count() {
251 if out_count == 0 {
252 return Ok(Value::OutputList(Vec::new()));
253 }
254 if out_count == 1 {
255 return Ok(Value::OutputList(vec![eval.into_values_value()]));
256 }
257 let (values, ia) = eval.into_pair();
258 return Ok(crate::output_count::output_list_with_padding(
259 out_count,
260 vec![values, ia],
261 ));
262 }
263 Ok(eval.into_values_value())
264}
265
266pub async fn evaluate(
268 a: Value,
269 b: Value,
270 rest: &[Value],
271) -> crate::BuiltinResult<SetdiffEvaluation> {
272 let opts = parse_options(rest)?;
273 match (a, b) {
274 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
275 setdiff_gpu_pair(handle_a, handle_b, &opts).await
276 }
277 (Value::GpuTensor(handle_a), other) => {
278 setdiff_gpu_mixed(handle_a, other, &opts, true).await
279 }
280 (other, Value::GpuTensor(handle_b)) => {
281 setdiff_gpu_mixed(handle_b, other, &opts, false).await
282 }
283 (left, right) => setdiff_host(left, right, &opts),
284 }
285}
286
287fn parse_options(rest: &[Value]) -> crate::BuiltinResult<SetdiffOptions> {
288 let mut opts = SetdiffOptions {
289 rows: false,
290 order: SetdiffOrder::Sorted,
291 };
292 let mut seen_order: Option<SetdiffOrder> = None;
293
294 let tokens = tokens_from_values(rest);
295 for (arg, token) in rest.iter().zip(tokens.iter()) {
296 let text = match token {
297 crate::builtins::common::arg_tokens::ArgToken::String(text) => text.as_str(),
298 _ => {
299 let text = tensor::value_to_string(arg)
300 .ok_or_else(|| setdiff_error(&SETDIFF_ERROR_INVALID_ARGUMENT))?;
301 let lowered = text.trim().to_ascii_lowercase();
302 parse_setdiff_option(&mut opts, &mut seen_order, &lowered)?;
303 continue;
304 }
305 };
306 parse_setdiff_option(&mut opts, &mut seen_order, text)?;
307 }
308
309 Ok(opts)
310}
311
312fn parse_setdiff_option(
313 opts: &mut SetdiffOptions,
314 seen_order: &mut Option<SetdiffOrder>,
315 lowered: &str,
316) -> crate::BuiltinResult<()> {
317 match lowered {
318 "rows" => opts.rows = true,
319 "sorted" => {
320 if let Some(prev) = seen_order {
321 if *prev != SetdiffOrder::Sorted {
322 return Err(setdiff_error(&SETDIFF_ERROR_CONFLICTING_ORDER_OPTIONS));
323 }
324 }
325 *seen_order = Some(SetdiffOrder::Sorted);
326 opts.order = SetdiffOrder::Sorted;
327 }
328 "stable" => {
329 if let Some(prev) = seen_order {
330 if *prev != SetdiffOrder::Stable {
331 return Err(setdiff_error(&SETDIFF_ERROR_CONFLICTING_ORDER_OPTIONS));
332 }
333 }
334 *seen_order = Some(SetdiffOrder::Stable);
335 opts.order = SetdiffOrder::Stable;
336 }
337 "legacy" | "r2012a" => {
338 return Err(setdiff_error(&SETDIFF_ERROR_LEGACY_OPTION_UNSUPPORTED));
339 }
340 other => {
341 return Err(setdiff_error_with(
342 &SETDIFF_ERROR_UNKNOWN_OPTION,
343 format!("setdiff: unrecognised option '{other}'"),
344 ))
345 }
346 }
347 Ok(())
348}
349
350async fn setdiff_gpu_pair(
351 handle_a: GpuTensorHandle,
352 handle_b: GpuTensorHandle,
353 opts: &SetdiffOptions,
354) -> crate::BuiltinResult<SetdiffEvaluation> {
355 if let Some(provider) = runmat_accelerate_api::provider() {
356 match provider.setdiff(&handle_a, &handle_b, opts).await {
357 Ok(result) => return SetdiffEvaluation::from_setdiff_result(result),
358 Err(_) => {
359 }
361 }
362 }
363 let a_tensor = gpu_helpers::gather_tensor_async(&handle_a).await?;
364 let b_tensor = gpu_helpers::gather_tensor_async(&handle_b).await?;
365 setdiff_numeric(a_tensor, b_tensor, opts)
366}
367
368async fn setdiff_gpu_mixed(
369 handle_gpu: GpuTensorHandle,
370 other: Value,
371 opts: &SetdiffOptions,
372 gpu_is_a: bool,
373) -> crate::BuiltinResult<SetdiffEvaluation> {
374 let gpu_tensor = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
375 let other_tensor =
376 tensor::value_into_tensor_for("setdiff", other).map_err(setdiff_internal_error)?;
377 if gpu_is_a {
378 setdiff_numeric(gpu_tensor, other_tensor, opts)
379 } else {
380 setdiff_numeric(other_tensor, gpu_tensor, opts)
381 }
382}
383
384fn setdiff_host(
385 a: Value,
386 b: Value,
387 opts: &SetdiffOptions,
388) -> crate::BuiltinResult<SetdiffEvaluation> {
389 match (a, b) {
390 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => setdiff_complex(at, bt, opts),
391 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
392 let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
393 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
394 setdiff_complex(at, bt, opts)
395 }
396 (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
397 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
398 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
399 setdiff_complex(at, bt, opts)
400 }
401 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
402 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
403 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
404 let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
405 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
406 setdiff_complex(at, bt, opts)
407 }
408
409 (Value::CharArray(ac), Value::CharArray(bc)) => setdiff_char(ac, bc, opts),
410
411 (Value::StringArray(astring), Value::StringArray(bstring)) => {
412 setdiff_string(astring, bstring, opts)
413 }
414 (Value::StringArray(astring), Value::String(b)) => {
415 let bstring = StringArray::new(vec![b], vec![1, 1])
416 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
417 setdiff_string(astring, bstring, opts)
418 }
419 (Value::String(a), Value::StringArray(bstring)) => {
420 let astring = StringArray::new(vec![a], vec![1, 1])
421 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
422 setdiff_string(astring, bstring, opts)
423 }
424 (Value::String(a), Value::String(b)) => {
425 let astring = StringArray::new(vec![a], vec![1, 1])
426 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
427 let bstring = StringArray::new(vec![b], vec![1, 1])
428 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
429 setdiff_string(astring, bstring, opts)
430 }
431
432 (left, right) => {
433 let tensor_a = tensor::value_into_tensor_for("setdiff", left)
434 .map_err(|e| setdiff_error_with(&SETDIFF_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
435 let tensor_b = tensor::value_into_tensor_for("setdiff", right)
436 .map_err(|e| setdiff_error_with(&SETDIFF_ERROR_UNSUPPORTED_INPUT_TYPE, e))?;
437 setdiff_numeric(tensor_a, tensor_b, opts)
438 }
439 }
440}
441
442fn setdiff_numeric(
443 a: Tensor,
444 b: Tensor,
445 opts: &SetdiffOptions,
446) -> crate::BuiltinResult<SetdiffEvaluation> {
447 if opts.rows {
448 setdiff_numeric_rows(a, b, opts)
449 } else {
450 setdiff_numeric_elements(a, b, opts)
451 }
452}
453
/// Public wrapper that forwards to the private `setdiff_numeric`, running the
/// numeric set difference directly on host tensors.
pub fn setdiff_numeric_from_tensors(
    a: Tensor,
    b: Tensor,
    opts: &SetdiffOptions,
) -> crate::BuiltinResult<SetdiffEvaluation> {
    setdiff_numeric(a, b, opts)
}
462
463fn setdiff_numeric_elements(
464 a: Tensor,
465 b: Tensor,
466 opts: &SetdiffOptions,
467) -> crate::BuiltinResult<SetdiffEvaluation> {
468 let mut b_keys: HashSet<u64> = HashSet::new();
469 for &value in &b.data {
470 b_keys.insert(canonicalize_f64(value));
471 }
472
473 let mut seen: HashMap<u64, usize> = HashMap::new();
474 let mut entries = Vec::<NumericDiffEntry>::new();
475 let mut order_counter = 0usize;
476
477 for (idx, &value) in a.data.iter().enumerate() {
478 let key = canonicalize_f64(value);
479 if b_keys.contains(&key) {
480 continue;
481 }
482 if seen.contains_key(&key) {
483 continue;
484 }
485 let entry_idx = entries.len();
486 entries.push(NumericDiffEntry {
487 value,
488 index: idx,
489 order_rank: order_counter,
490 });
491 seen.insert(key, entry_idx);
492 order_counter += 1;
493 }
494
495 assemble_numeric_setdiff(entries, opts)
496}
497
498fn setdiff_numeric_rows(
499 a: Tensor,
500 b: Tensor,
501 opts: &SetdiffOptions,
502) -> crate::BuiltinResult<SetdiffEvaluation> {
503 if a.shape.len() != 2 || b.shape.len() != 2 {
504 return Err(setdiff_internal_error(
505 "setdiff: 'rows' option requires 2-D numeric matrices",
506 ));
507 }
508 if a.shape[1] != b.shape[1] {
509 return Err(setdiff_error(&SETDIFF_ERROR_ROWS_COLUMN_MISMATCH));
510 }
511
512 let rows_a = a.shape[0];
513 let rows_b = b.shape[0];
514 let cols = a.shape[1];
515
516 let mut b_keys: HashSet<NumericRowKey> = HashSet::new();
517 for r in 0..rows_b {
518 let mut row_values = Vec::with_capacity(cols);
519 for c in 0..cols {
520 let idx = r + c * rows_b;
521 row_values.push(b.data[idx]);
522 }
523 b_keys.insert(NumericRowKey::from_slice(&row_values));
524 }
525
526 let mut seen: HashSet<NumericRowKey> = HashSet::new();
527 let mut entries = Vec::<NumericRowDiffEntry>::new();
528 let mut order_counter = 0usize;
529
530 for r in 0..rows_a {
531 let mut row_values = Vec::with_capacity(cols);
532 for c in 0..cols {
533 let idx = r + c * rows_a;
534 row_values.push(a.data[idx]);
535 }
536 let key = NumericRowKey::from_slice(&row_values);
537 if b_keys.contains(&key) {
538 continue;
539 }
540 if !seen.insert(key) {
541 continue;
542 }
543 entries.push(NumericRowDiffEntry {
544 row_data: row_values,
545 row_index: r,
546 order_rank: order_counter,
547 });
548 order_counter += 1;
549 }
550
551 assemble_numeric_row_setdiff(entries, opts, cols)
552}
553
554fn setdiff_complex(
555 a: ComplexTensor,
556 b: ComplexTensor,
557 opts: &SetdiffOptions,
558) -> crate::BuiltinResult<SetdiffEvaluation> {
559 if opts.rows {
560 setdiff_complex_rows(a, b, opts)
561 } else {
562 setdiff_complex_elements(a, b, opts)
563 }
564}
565
566fn setdiff_complex_elements(
567 a: ComplexTensor,
568 b: ComplexTensor,
569 opts: &SetdiffOptions,
570) -> crate::BuiltinResult<SetdiffEvaluation> {
571 let mut b_keys: HashSet<ComplexKey> = HashSet::new();
572 for &value in &b.data {
573 b_keys.insert(ComplexKey::new(value));
574 }
575
576 let mut seen: HashSet<ComplexKey> = HashSet::new();
577 let mut entries = Vec::<ComplexDiffEntry>::new();
578 let mut order_counter = 0usize;
579
580 for (idx, &value) in a.data.iter().enumerate() {
581 let key = ComplexKey::new(value);
582 if b_keys.contains(&key) {
583 continue;
584 }
585 if !seen.insert(key) {
586 continue;
587 }
588 entries.push(ComplexDiffEntry {
589 value,
590 index: idx,
591 order_rank: order_counter,
592 });
593 order_counter += 1;
594 }
595
596 assemble_complex_setdiff(entries, opts)
597}
598
599fn setdiff_complex_rows(
600 a: ComplexTensor,
601 b: ComplexTensor,
602 opts: &SetdiffOptions,
603) -> crate::BuiltinResult<SetdiffEvaluation> {
604 if a.shape.len() != 2 || b.shape.len() != 2 {
605 return Err(setdiff_internal_error(
606 "setdiff: 'rows' option requires 2-D complex matrices",
607 ));
608 }
609 if a.shape[1] != b.shape[1] {
610 return Err(setdiff_error(&SETDIFF_ERROR_ROWS_COLUMN_MISMATCH));
611 }
612
613 let rows_a = a.shape[0];
614 let rows_b = b.shape[0];
615 let cols = a.shape[1];
616
617 let mut b_keys: HashSet<Vec<ComplexKey>> = HashSet::new();
618 for r in 0..rows_b {
619 let mut key_row = Vec::with_capacity(cols);
620 for c in 0..cols {
621 let idx = r + c * rows_b;
622 key_row.push(ComplexKey::new(b.data[idx]));
623 }
624 b_keys.insert(key_row);
625 }
626
627 let mut seen: HashSet<Vec<ComplexKey>> = HashSet::new();
628 let mut entries = Vec::<ComplexRowDiffEntry>::new();
629 let mut order_counter = 0usize;
630
631 for r in 0..rows_a {
632 let mut row_values = Vec::with_capacity(cols);
633 let mut key_row = Vec::with_capacity(cols);
634 for c in 0..cols {
635 let idx = r + c * rows_a;
636 let value = a.data[idx];
637 row_values.push(value);
638 key_row.push(ComplexKey::new(value));
639 }
640 if b_keys.contains(&key_row) {
641 continue;
642 }
643 if !seen.insert(key_row) {
644 continue;
645 }
646 entries.push(ComplexRowDiffEntry {
647 row_data: row_values,
648 row_index: r,
649 order_rank: order_counter,
650 });
651 order_counter += 1;
652 }
653
654 assemble_complex_row_setdiff(entries, opts, cols)
655}
656
657fn setdiff_char(
658 a: CharArray,
659 b: CharArray,
660 opts: &SetdiffOptions,
661) -> crate::BuiltinResult<SetdiffEvaluation> {
662 if opts.rows {
663 setdiff_char_rows(a, b, opts)
664 } else {
665 setdiff_char_elements(a, b, opts)
666 }
667}
668
669fn setdiff_char_elements(
670 a: CharArray,
671 b: CharArray,
672 opts: &SetdiffOptions,
673) -> crate::BuiltinResult<SetdiffEvaluation> {
674 let mut b_keys: HashSet<u32> = HashSet::new();
675 for ch in &b.data {
676 b_keys.insert(*ch as u32);
677 }
678
679 let mut seen: HashSet<u32> = HashSet::new();
680 let mut entries = Vec::<CharDiffEntry>::new();
681 let mut order_counter = 0usize;
682
683 for col in 0..a.cols {
684 for row in 0..a.rows {
685 let linear_idx = row + col * a.rows;
686 let data_idx = row * a.cols + col;
687 let ch = a.data[data_idx];
688 let key = ch as u32;
689 if b_keys.contains(&key) {
690 continue;
691 }
692 if !seen.insert(key) {
693 continue;
694 }
695 entries.push(CharDiffEntry {
696 ch,
697 index: linear_idx,
698 order_rank: order_counter,
699 });
700 order_counter += 1;
701 }
702 }
703
704 assemble_char_setdiff(entries, opts)
705}
706
707fn setdiff_char_rows(
708 a: CharArray,
709 b: CharArray,
710 opts: &SetdiffOptions,
711) -> crate::BuiltinResult<SetdiffEvaluation> {
712 if a.cols != b.cols {
713 return Err(setdiff_error(&SETDIFF_ERROR_ROWS_COLUMN_MISMATCH));
714 }
715
716 let rows_a = a.rows;
717 let rows_b = b.rows;
718 let cols = a.cols;
719
720 let mut b_keys: HashSet<RowCharKey> = HashSet::new();
721 for r in 0..rows_b {
722 let mut row_values = Vec::with_capacity(cols);
723 for c in 0..cols {
724 let idx = r * cols + c;
725 row_values.push(b.data[idx]);
726 }
727 b_keys.insert(RowCharKey::from_slice(&row_values));
728 }
729
730 let mut seen: HashSet<RowCharKey> = HashSet::new();
731 let mut entries = Vec::<CharRowDiffEntry>::new();
732 let mut order_counter = 0usize;
733
734 for r in 0..rows_a {
735 let mut row_values = Vec::with_capacity(cols);
736 for c in 0..cols {
737 let idx = r * cols + c;
738 row_values.push(a.data[idx]);
739 }
740 let key = RowCharKey::from_slice(&row_values);
741 if b_keys.contains(&key) {
742 continue;
743 }
744 if !seen.insert(key) {
745 continue;
746 }
747 entries.push(CharRowDiffEntry {
748 row_data: row_values,
749 row_index: r,
750 order_rank: order_counter,
751 });
752 order_counter += 1;
753 }
754
755 assemble_char_row_setdiff(entries, opts, cols)
756}
757
758fn setdiff_string(
759 a: StringArray,
760 b: StringArray,
761 opts: &SetdiffOptions,
762) -> crate::BuiltinResult<SetdiffEvaluation> {
763 if opts.rows {
764 setdiff_string_rows(a, b, opts)
765 } else {
766 setdiff_string_elements(a, b, opts)
767 }
768}
769
770fn setdiff_string_elements(
771 a: StringArray,
772 b: StringArray,
773 opts: &SetdiffOptions,
774) -> crate::BuiltinResult<SetdiffEvaluation> {
775 let mut b_keys: HashSet<String> = HashSet::new();
776 for value in &b.data {
777 b_keys.insert(value.clone());
778 }
779
780 let mut seen: HashSet<String> = HashSet::new();
781 let mut entries = Vec::<StringDiffEntry>::new();
782 let mut order_counter = 0usize;
783
784 for (idx, value) in a.data.iter().enumerate() {
785 if b_keys.contains(value) {
786 continue;
787 }
788 if !seen.insert(value.clone()) {
789 continue;
790 }
791 entries.push(StringDiffEntry {
792 value: value.clone(),
793 index: idx,
794 order_rank: order_counter,
795 });
796 order_counter += 1;
797 }
798
799 assemble_string_setdiff(entries, opts)
800}
801
802fn setdiff_string_rows(
803 a: StringArray,
804 b: StringArray,
805 opts: &SetdiffOptions,
806) -> crate::BuiltinResult<SetdiffEvaluation> {
807 if a.shape.len() != 2 || b.shape.len() != 2 {
808 return Err(setdiff_internal_error(
809 "setdiff: 'rows' option requires 2-D string arrays",
810 ));
811 }
812 if a.shape[1] != b.shape[1] {
813 return Err(setdiff_error(&SETDIFF_ERROR_ROWS_COLUMN_MISMATCH));
814 }
815
816 let rows_a = a.shape[0];
817 let rows_b = b.shape[0];
818 let cols = a.shape[1];
819
820 let mut b_keys: HashSet<RowStringKey> = HashSet::new();
821 for r in 0..rows_b {
822 let mut row_values = Vec::with_capacity(cols);
823 for c in 0..cols {
824 let idx = r + c * rows_b;
825 row_values.push(b.data[idx].clone());
826 }
827 b_keys.insert(RowStringKey(row_values.clone()));
828 }
829
830 let mut seen: HashSet<RowStringKey> = HashSet::new();
831 let mut entries = Vec::<StringRowDiffEntry>::new();
832 let mut order_counter = 0usize;
833
834 for r in 0..rows_a {
835 let mut row_values = Vec::with_capacity(cols);
836 for c in 0..cols {
837 let idx = r + c * rows_a;
838 row_values.push(a.data[idx].clone());
839 }
840 let key = RowStringKey(row_values.clone());
841 if b_keys.contains(&key) {
842 continue;
843 }
844 if !seen.insert(key) {
845 continue;
846 }
847 entries.push(StringRowDiffEntry {
848 row_data: row_values,
849 row_index: r,
850 order_rank: order_counter,
851 });
852 order_counter += 1;
853 }
854
855 assemble_string_row_setdiff(entries, opts, cols)
856}
857
858fn assemble_numeric_setdiff(
859 entries: Vec<NumericDiffEntry>,
860 opts: &SetdiffOptions,
861) -> crate::BuiltinResult<SetdiffEvaluation> {
862 let mut order: Vec<usize> = (0..entries.len()).collect();
863 match opts.order {
864 SetdiffOrder::Sorted => {
865 order.sort_by(|&lhs, &rhs| compare_f64(entries[lhs].value, entries[rhs].value));
866 }
867 SetdiffOrder::Stable => {
868 order.sort_by_key(|&idx| entries[idx].order_rank);
869 }
870 }
871
872 let mut values = Vec::with_capacity(order.len());
873 let mut ia = Vec::with_capacity(order.len());
874 for &idx in &order {
875 let entry = &entries[idx];
876 values.push(entry.value);
877 ia.push((entry.index + 1) as f64);
878 }
879
880 let value_tensor = Tensor::new(values, vec![order.len(), 1])
881 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
882 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
883 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
884
885 Ok(SetdiffEvaluation::new(
886 Value::Tensor(value_tensor),
887 ia_tensor,
888 ))
889}
890
891fn assemble_numeric_row_setdiff(
892 entries: Vec<NumericRowDiffEntry>,
893 opts: &SetdiffOptions,
894 cols: usize,
895) -> crate::BuiltinResult<SetdiffEvaluation> {
896 let mut order: Vec<usize> = (0..entries.len()).collect();
897 match opts.order {
898 SetdiffOrder::Sorted => {
899 order.sort_by(|&lhs, &rhs| {
900 compare_numeric_rows(&entries[lhs].row_data, &entries[rhs].row_data)
901 });
902 }
903 SetdiffOrder::Stable => {
904 order.sort_by_key(|&idx| entries[idx].order_rank);
905 }
906 }
907
908 let unique_rows = order.len();
909 let mut values = vec![0.0f64; unique_rows * cols];
910 let mut ia = Vec::with_capacity(unique_rows);
911
912 for (row_pos, &entry_idx) in order.iter().enumerate() {
913 let entry = &entries[entry_idx];
914 for col in 0..cols {
915 let dest = row_pos + col * unique_rows;
916 values[dest] = entry.row_data[col];
917 }
918 ia.push((entry.row_index + 1) as f64);
919 }
920
921 let value_tensor = Tensor::new(values, vec![unique_rows, cols])
922 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
923 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
924 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
925
926 Ok(SetdiffEvaluation::new(
927 Value::Tensor(value_tensor),
928 ia_tensor,
929 ))
930}
931
932fn assemble_complex_setdiff(
933 entries: Vec<ComplexDiffEntry>,
934 opts: &SetdiffOptions,
935) -> crate::BuiltinResult<SetdiffEvaluation> {
936 let mut order: Vec<usize> = (0..entries.len()).collect();
937 match opts.order {
938 SetdiffOrder::Sorted => {
939 order.sort_by(|&lhs, &rhs| compare_complex(entries[lhs].value, entries[rhs].value));
940 }
941 SetdiffOrder::Stable => {
942 order.sort_by_key(|&idx| entries[idx].order_rank);
943 }
944 }
945
946 let mut values = Vec::with_capacity(order.len());
947 let mut ia = Vec::with_capacity(order.len());
948 for &idx in &order {
949 let entry = &entries[idx];
950 values.push(entry.value);
951 ia.push((entry.index + 1) as f64);
952 }
953
954 let value_tensor = ComplexTensor::new(values, vec![order.len(), 1])
955 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
956 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
957 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
958
959 Ok(SetdiffEvaluation::new(
960 complex_tensor_into_value(value_tensor),
961 ia_tensor,
962 ))
963}
964
965fn assemble_complex_row_setdiff(
966 entries: Vec<ComplexRowDiffEntry>,
967 opts: &SetdiffOptions,
968 cols: usize,
969) -> crate::BuiltinResult<SetdiffEvaluation> {
970 let mut order: Vec<usize> = (0..entries.len()).collect();
971 match opts.order {
972 SetdiffOrder::Sorted => {
973 order.sort_by(|&lhs, &rhs| {
974 compare_complex_rows(&entries[lhs].row_data, &entries[rhs].row_data)
975 });
976 }
977 SetdiffOrder::Stable => {
978 order.sort_by_key(|&idx| entries[idx].order_rank);
979 }
980 }
981
982 let unique_rows = order.len();
983 let mut values = vec![(0.0f64, 0.0f64); unique_rows * cols];
984 let mut ia = Vec::with_capacity(unique_rows);
985
986 for (row_pos, &entry_idx) in order.iter().enumerate() {
987 let entry = &entries[entry_idx];
988 for col in 0..cols {
989 let dest = row_pos + col * unique_rows;
990 values[dest] = entry.row_data[col];
991 }
992 ia.push((entry.row_index + 1) as f64);
993 }
994
995 let value_tensor = ComplexTensor::new(values, vec![unique_rows, cols])
996 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
997 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
998 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
999
1000 Ok(SetdiffEvaluation::new(
1001 complex_tensor_into_value(value_tensor),
1002 ia_tensor,
1003 ))
1004}
1005
1006fn assemble_char_setdiff(
1007 entries: Vec<CharDiffEntry>,
1008 opts: &SetdiffOptions,
1009) -> crate::BuiltinResult<SetdiffEvaluation> {
1010 let mut order: Vec<usize> = (0..entries.len()).collect();
1011 match opts.order {
1012 SetdiffOrder::Sorted => {
1013 order.sort_by(|&lhs, &rhs| entries[lhs].ch.cmp(&entries[rhs].ch));
1014 }
1015 SetdiffOrder::Stable => {
1016 order.sort_by_key(|&idx| entries[idx].order_rank);
1017 }
1018 }
1019
1020 let mut values = Vec::with_capacity(order.len());
1021 let mut ia = Vec::with_capacity(order.len());
1022 for &idx in &order {
1023 let entry = &entries[idx];
1024 values.push(entry.ch);
1025 ia.push((entry.index + 1) as f64);
1026 }
1027
1028 let value_array = CharArray::new(values, order.len(), 1)
1029 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1030 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1031 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1032
1033 Ok(SetdiffEvaluation::new(
1034 Value::CharArray(value_array),
1035 ia_tensor,
1036 ))
1037}
1038
1039fn assemble_char_row_setdiff(
1040 entries: Vec<CharRowDiffEntry>,
1041 opts: &SetdiffOptions,
1042 cols: usize,
1043) -> crate::BuiltinResult<SetdiffEvaluation> {
1044 let mut order: Vec<usize> = (0..entries.len()).collect();
1045 match opts.order {
1046 SetdiffOrder::Sorted => {
1047 order.sort_by(|&lhs, &rhs| {
1048 compare_char_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1049 });
1050 }
1051 SetdiffOrder::Stable => {
1052 order.sort_by_key(|&idx| entries[idx].order_rank);
1053 }
1054 }
1055
1056 let unique_rows = order.len();
1057 let mut values = vec!['\0'; unique_rows * cols];
1058 let mut ia = Vec::with_capacity(unique_rows);
1059
1060 for (row_pos, &entry_idx) in order.iter().enumerate() {
1061 let entry = &entries[entry_idx];
1062 for col in 0..cols {
1063 let dest = row_pos * cols + col;
1064 values[dest] = entry.row_data[col];
1065 }
1066 ia.push((entry.row_index + 1) as f64);
1067 }
1068
1069 let value_array = CharArray::new(values, unique_rows, cols)
1070 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1071 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
1072 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1073
1074 Ok(SetdiffEvaluation::new(
1075 Value::CharArray(value_array),
1076 ia_tensor,
1077 ))
1078}
1079
1080fn assemble_string_setdiff(
1081 entries: Vec<StringDiffEntry>,
1082 opts: &SetdiffOptions,
1083) -> crate::BuiltinResult<SetdiffEvaluation> {
1084 let mut order: Vec<usize> = (0..entries.len()).collect();
1085 match opts.order {
1086 SetdiffOrder::Sorted => {
1087 order.sort_by(|&lhs, &rhs| entries[lhs].value.cmp(&entries[rhs].value));
1088 }
1089 SetdiffOrder::Stable => {
1090 order.sort_by_key(|&idx| entries[idx].order_rank);
1091 }
1092 }
1093
1094 let mut values = Vec::with_capacity(order.len());
1095 let mut ia = Vec::with_capacity(order.len());
1096 for &idx in &order {
1097 let entry = &entries[idx];
1098 values.push(entry.value.clone());
1099 ia.push((entry.index + 1) as f64);
1100 }
1101
1102 let value_array = StringArray::new(values, vec![order.len(), 1])
1103 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1104 let ia_tensor = Tensor::new(ia, vec![order.len(), 1])
1105 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1106
1107 Ok(SetdiffEvaluation::new(
1108 Value::StringArray(value_array),
1109 ia_tensor,
1110 ))
1111}
1112
1113fn assemble_string_row_setdiff(
1114 entries: Vec<StringRowDiffEntry>,
1115 opts: &SetdiffOptions,
1116 cols: usize,
1117) -> crate::BuiltinResult<SetdiffEvaluation> {
1118 let mut order: Vec<usize> = (0..entries.len()).collect();
1119 match opts.order {
1120 SetdiffOrder::Sorted => {
1121 order.sort_by(|&lhs, &rhs| {
1122 compare_string_rows(&entries[lhs].row_data, &entries[rhs].row_data)
1123 });
1124 }
1125 SetdiffOrder::Stable => {
1126 order.sort_by_key(|&idx| entries[idx].order_rank);
1127 }
1128 }
1129
1130 let unique_rows = order.len();
1131 let mut values = vec![String::new(); unique_rows * cols];
1132 let mut ia = Vec::with_capacity(unique_rows);
1133
1134 for (row_pos, &entry_idx) in order.iter().enumerate() {
1135 let entry = &entries[entry_idx];
1136 for col in 0..cols {
1137 let dest = row_pos + col * unique_rows;
1138 values[dest] = entry.row_data[col].clone();
1139 }
1140 ia.push((entry.row_index + 1) as f64);
1141 }
1142
1143 let value_array = StringArray::new(values, vec![unique_rows, cols])
1144 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1145 let ia_tensor = Tensor::new(ia, vec![unique_rows, 1])
1146 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1147
1148 Ok(SetdiffEvaluation::new(
1149 Value::StringArray(value_array),
1150 ia_tensor,
1151 ))
1152}
1153
/// Candidate element for the real-valued, element-wise `setdiff` result.
///
/// NOTE(review): the code that fills and consumes this struct is outside this
/// section; the field notes assume it mirrors `CharDiffEntry` — confirm.
#[derive(Clone, Copy, Debug)]
struct NumericDiffEntry {
    /// The element value.
    value: f64,
    /// Presumably the zero-based position reported (plus one) in `ia`.
    index: usize,
    /// Presumably the sort key applied under `stable` ordering.
    order_rank: usize,
}
1160
/// Candidate row for the real-valued, row-wise `setdiff` result.
///
/// NOTE(review): the code that fills and consumes this struct is outside this
/// section; the field notes assume it mirrors `CharRowDiffEntry` — confirm.
#[derive(Clone, Debug)]
struct NumericRowDiffEntry {
    /// The elements of the row.
    row_data: Vec<f64>,
    /// Presumably the zero-based row position reported (plus one) in `ia`.
    row_index: usize,
    /// Presumably the sort key applied under `stable` ordering.
    order_rank: usize,
}
1167
/// Candidate element for the complex, element-wise `setdiff` result.
///
/// NOTE(review): the code that fills and consumes this struct is outside this
/// section; the field notes assume it mirrors `CharDiffEntry` — confirm.
#[derive(Clone, Copy, Debug)]
struct ComplexDiffEntry {
    /// The element as a pair; presumably `(real, imaginary)`, matching how
    /// `ComplexKey::new` and `compare_complex` read tuples.
    value: (f64, f64),
    /// Presumably the zero-based position reported (plus one) in `ia`.
    index: usize,
    /// Presumably the sort key applied under `stable` ordering.
    order_rank: usize,
}
1174
/// Candidate row for the complex, row-wise `setdiff` result.
///
/// NOTE(review): the code that fills and consumes this struct is not fully
/// visible in this section; the field notes assume it mirrors
/// `CharRowDiffEntry` — confirm.
#[derive(Clone, Debug)]
struct ComplexRowDiffEntry {
    /// The elements of the row, each stored as a pair of `f64`.
    row_data: Vec<(f64, f64)>,
    /// Presumably the zero-based row position reported (plus one) in `ia`.
    row_index: usize,
    /// Presumably the sort key applied under `stable` ordering.
    order_rank: usize,
}
1181
/// Candidate element for the character, element-wise `setdiff` result,
/// consumed by `assemble_char_setdiff`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct CharDiffEntry {
    /// The character emitted into the values output; also the sort key under
    /// `SetdiffOrder::Sorted`.
    ch: char,
    /// Position of the element; `assemble_char_setdiff` writes `index + 1`
    /// into the `ia` output.
    index: usize,
    /// Sort key applied under `SetdiffOrder::Stable`.
    order_rank: usize,
}
1188
/// Candidate row for the character, row-wise `setdiff` result, consumed by
/// `assemble_char_row_setdiff`.
#[derive(Clone, Debug)]
struct CharRowDiffEntry {
    /// The characters of the row; compared with `compare_char_rows` under
    /// `SetdiffOrder::Sorted`.
    row_data: Vec<char>,
    /// Position of the row; `assemble_char_row_setdiff` writes
    /// `row_index + 1` into the `ia` output.
    row_index: usize,
    /// Sort key applied under `SetdiffOrder::Stable`.
    order_rank: usize,
}
1195
/// Candidate element for the string, element-wise `setdiff` result, consumed
/// by `assemble_string_setdiff`.
#[derive(Clone, Debug)]
struct StringDiffEntry {
    /// The string emitted into the values output; also the sort key under
    /// `SetdiffOrder::Sorted`.
    value: String,
    /// Position of the element; `assemble_string_setdiff` writes `index + 1`
    /// into the `ia` output.
    index: usize,
    /// Sort key applied under `SetdiffOrder::Stable`.
    order_rank: usize,
}
1202
/// Candidate row for the string, row-wise `setdiff` result, consumed by
/// `assemble_string_row_setdiff`.
#[derive(Clone, Debug)]
struct StringRowDiffEntry {
    /// The strings of the row; compared with `compare_string_rows` under
    /// `SetdiffOrder::Sorted`.
    row_data: Vec<String>,
    /// Position of the row; `assemble_string_row_setdiff` writes
    /// `row_index + 1` into the `ia` output.
    row_index: usize,
    /// Sort key applied under `SetdiffOrder::Stable`.
    order_rank: usize,
}
1209
1210#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1211struct NumericRowKey(Vec<u64>);
1212
1213impl NumericRowKey {
1214 fn from_slice(values: &[f64]) -> Self {
1215 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
1216 }
1217}
1218
1219#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1220struct ComplexKey {
1221 re: u64,
1222 im: u64,
1223}
1224
1225impl ComplexKey {
1226 fn new(value: (f64, f64)) -> Self {
1227 Self {
1228 re: canonicalize_f64(value.0),
1229 im: canonicalize_f64(value.1),
1230 }
1231 }
1232}
1233
/// Hashable identity of a character row: the code point of every character.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds the key by converting each character of `values` to its `u32`
    /// code point, preserving order.
    fn from_slice(values: &[char]) -> Self {
        let code_points = values.iter().copied().map(u32::from).collect();
        Self(code_points)
    }
}
1242
/// Hashable identity of a string row: the row's strings, in order.
///
/// NOTE(review): construction sites are outside this section; presumably the
/// string counterpart of `RowCharKey` — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
1245
1246#[derive(Debug)]
1247pub struct SetdiffEvaluation {
1248 values: Value,
1249 ia: Tensor,
1250}
1251
1252impl SetdiffEvaluation {
1253 fn new(values: Value, ia: Tensor) -> Self {
1254 Self { values, ia }
1255 }
1256
1257 pub fn from_setdiff_result(result: SetdiffResult) -> crate::BuiltinResult<Self> {
1258 let SetdiffResult { values, ia } = result;
1259 let values_tensor = Tensor::new(values.data, values.shape)
1260 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1261 let ia_tensor = Tensor::new(ia.data, ia.shape)
1262 .map_err(|e| setdiff_internal_error(format!("setdiff: {e}")))?;
1263 Ok(SetdiffEvaluation::new(
1264 Value::Tensor(values_tensor),
1265 ia_tensor,
1266 ))
1267 }
1268
1269 pub fn into_numeric_setdiff_result(self) -> crate::BuiltinResult<SetdiffResult> {
1270 let SetdiffEvaluation { values, ia } = self;
1271 let values_tensor = tensor::value_into_tensor_for("setdiff", values)
1272 .map_err(|e| setdiff_internal_error(e))?;
1273 Ok(SetdiffResult {
1274 values: HostTensorOwned {
1275 data: values_tensor.data,
1276 shape: values_tensor.shape,
1277 storage: GpuTensorStorage::Real,
1278 },
1279 ia: HostTensorOwned {
1280 data: ia.data,
1281 shape: ia.shape,
1282 storage: GpuTensorStorage::Real,
1283 },
1284 })
1285 }
1286
1287 pub fn into_values_value(self) -> Value {
1288 self.values
1289 }
1290
1291 pub fn into_pair(self) -> (Value, Value) {
1292 let ia = tensor::tensor_into_value(self.ia);
1293 (self.values, ia)
1294 }
1295
1296 pub fn values_value(&self) -> Value {
1297 self.values.clone()
1298 }
1299
1300 pub fn ia_value(&self) -> Value {
1301 tensor::tensor_into_value(self.ia.clone())
1302 }
1303}
1304
/// Maps a float to a bit pattern suitable for hashing and equality.
///
/// Every NaN collapses to one quiet-NaN pattern, both signed zeros collapse
/// to `0`, and every other value keeps its own IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;

    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0`, so this arm catches both zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
1314
/// Total ordering for floats in which NaN sorts after every other value.
///
/// Two NaNs compare equal; all remaining pairs use the usual numeric order.
fn compare_f64(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN, so `partial_cmp` always yields an ordering;
        // the fallback is kept only to avoid an unwrap.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
    }
}
1328
1329fn compare_numeric_rows(a: &[f64], b: &[f64]) -> Ordering {
1330 for (lhs, rhs) in a.iter().zip(b.iter()) {
1331 let ord = compare_f64(*lhs, *rhs);
1332 if ord != Ordering::Equal {
1333 return ord;
1334 }
1335 }
1336 Ordering::Equal
1337}
1338
/// Reports whether either component of the pair is NaN.
fn complex_is_nan(value: (f64, f64)) -> bool {
    let (first, second) = value;
    first.is_nan() || second.is_nan()
}
1342
1343fn compare_complex(a: (f64, f64), b: (f64, f64)) -> Ordering {
1344 match (complex_is_nan(a), complex_is_nan(b)) {
1345 (true, true) => Ordering::Equal,
1346 (true, false) => Ordering::Greater,
1347 (false, true) => Ordering::Less,
1348 (false, false) => {
1349 let mag_a = a.0.hypot(a.1);
1350 let mag_b = b.0.hypot(b.1);
1351 let mag_cmp = compare_f64(mag_a, mag_b);
1352 if mag_cmp != Ordering::Equal {
1353 return mag_cmp;
1354 }
1355 let re_cmp = compare_f64(a.0, b.0);
1356 if re_cmp != Ordering::Equal {
1357 return re_cmp;
1358 }
1359 compare_f64(a.1, b.1)
1360 }
1361 }
1362}
1363
1364fn compare_complex_rows(a: &[(f64, f64)], b: &[(f64, f64)]) -> Ordering {
1365 for (lhs, rhs) in a.iter().zip(b.iter()) {
1366 let ord = compare_complex(*lhs, *rhs);
1367 if ord != Ordering::Equal {
1368 return ord;
1369 }
1370 }
1371 Ordering::Equal
1372}
1373
/// Compares two character rows position by position.
///
/// The first differing pair decides the result. Only the overlapping prefix
/// is examined, so rows that agree on it compare equal.
fn compare_char_rows(a: &[char], b: &[char]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|&ord| ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1383
/// Compares two string rows position by position.
///
/// The first differing pair decides the result. Only the overlapping prefix
/// is examined, so rows that agree on it compare equal.
fn compare_string_rows(a: &[String], b: &[String]) -> Ordering {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs.cmp(rhs))
        .find(|&ord| ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
1393
/// Unit tests for the `setdiff` builtin: host evaluation for numeric, char and
/// string inputs, option validation, type resolution, and GPU-handle inputs.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CharArray, ResolveContext, StringArray, Tensor, Type, Value};

    /// Drives the async `evaluate` entry point to completion on the current
    /// thread so tests can stay synchronous.
    fn evaluate_sync(
        a: Value,
        b: Value,
        rest: &[Value],
    ) -> crate::BuiltinResult<SetdiffEvaluation> {
        futures::executor::block_on(evaluate(a, b, rest))
    }

    // With no options, the repeated 5.0 in A yields a single result whose `ia`
    // is 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_sorted_default() {
        let a = Tensor::new(vec![5.0, 7.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![7.0, 1.0, 3.0], vec![3, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("setdiff");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data, vec![5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // Two tensor arguments resolve to a tensor output type.
    #[test]
    fn setdiff_type_resolver_numeric() {
        assert_eq!(
            set_values_output_type(
                &[Type::tensor(), Type::tensor()],
                &ResolveContext::new(Vec::new()),
            ),
            Type::tensor()
        );
    }

    // A cell-of-string first argument resolves to a cell-of-string output type.
    #[test]
    fn setdiff_type_resolver_string_array() {
        assert_eq!(
            set_values_output_type(
                &[Type::cell_of(Type::String), Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::cell_of(Type::String)
        );
    }

    // With "stable", only 2.0 is absent from B; it is reported with `ia` 2.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_stable() {
        let a = Tensor::new(vec![4.0, 2.0, 4.0, 1.0, 3.0], vec![5, 1]).unwrap();
        let b = Tensor::new(vec![3.0, 4.0, 5.0, 1.0], vec![4, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("stable")])
            .expect("setdiff");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert_eq!(t.data, vec![2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0]);
    }

    // With "rows", the expected output is the single row [1 2] with `ia` 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_rows_sorted() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 4.0, 6.0], vec![2, 2]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")])
            .expect("setdiff");
        match eval.values_value() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![1.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // A NaN in B removes the NaN in A: the result keeps only 2.0 and 3.0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_numeric_removes_nan() {
        let a = Tensor::new(vec![f64::NAN, 2.0, 3.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![f64::NAN], vec![1, 1]).unwrap();
        let eval = evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[]).expect("setdiff");
        let values = tensor::value_into_tensor_for("setdiff", eval.values_value()).expect("values");
        assert_eq!(values.data, vec![2.0, 3.0]);
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![2.0, 3.0]);
    }

    // Element-wise char inputs: only 'z' is absent from B; expected `ia` is 3.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_char_elements() {
        let a = CharArray::new(vec!['m', 'z', 'm', 'a'], 2, 2).unwrap();
        let b = CharArray::new(vec!['a', 'x', 'm', 'a'], 2, 2).unwrap();
        let eval = evaluate_sync(Value::CharArray(a), Value::CharArray(b), &[]).expect("setdiff");
        match eval.values_value() {
            Value::CharArray(arr) => {
                assert_eq!(arr.rows, 1);
                assert_eq!(arr.cols, 1);
                assert_eq!(arr.data, vec!['z']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0].iter().map(|_| 3.0).collect::<Vec<f64>>());
    }

    // "rows" combined with "stable" on string matrices: the expected output is
    // the single row ["alpha", "beta"] with `ia` 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_string_rows_stable() {
        let a = StringArray::new(
            vec![
                "alpha".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "gamma".to_string(),
                "delta".to_string(),
                "beta".to_string(),
                "beta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let eval = evaluate_sync(
            Value::StringArray(a),
            Value::StringArray(b),
            &[Value::from("rows"), Value::from("stable")],
        )
        .expect("setdiff");
        match eval.values_value() {
            Value::StringArray(arr) => {
                assert_eq!(arr.shape, vec![1, 2]);
                assert_eq!(arr.data, vec!["alpha".to_string(), "beta".to_string()]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
        let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
        assert_eq!(ia.data, vec![1.0]);
    }

    // Mixing a numeric scalar with a string is rejected with the
    // unsupported-input-type identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_type_mismatch_errors() {
        let err = evaluate_sync(Value::from(1.0), Value::String("a".into()), &[]).unwrap_err();
        assert_eq!(
            err.identifier(),
            SETDIFF_ERROR_UNSUPPORTED_INPUT_TYPE.identifier
        );
    }

    // "rows" with a 2x2 and a 3x1 operand reports the column-mismatch
    // identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_rows_dimension_mismatch_reports_identifier() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("tensor a");
        let b = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).expect("tensor b");
        let err =
            evaluate_sync(Value::Tensor(a), Value::Tensor(b), &[Value::from("rows")]).unwrap_err();
        assert_eq!(
            err.identifier(),
            SETDIFF_ERROR_ROWS_COLUMN_MISMATCH.identifier
        );
    }

    // The "legacy" option is rejected with its dedicated identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_rejects_legacy_option() {
        let err = evaluate_sync(Value::from(1.0), Value::from(2.0), &[Value::from("legacy")])
            .unwrap_err();
        assert_eq!(
            err.identifier(),
            SETDIFF_ERROR_LEGACY_OPTION_UNSUPPORTED.identifier
        );
    }

    // Passing both "stable" and "sorted" is rejected as conflicting.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_rejects_conflicting_order_options() {
        let err = evaluate_sync(
            Value::from(1.0),
            Value::from(2.0),
            &[Value::from("stable"), Value::from("sorted")],
        )
        .unwrap_err();
        assert_eq!(
            err.identifier(),
            SETDIFF_ERROR_CONFLICTING_ORDER_OPTIONS.identifier
        );
    }

    // An unrecognized option string is rejected with the unknown-option
    // identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_rejects_unknown_option() {
        let err =
            evaluate_sync(Value::from(1.0), Value::from(2.0), &[Value::from("bogus")]).unwrap_err();
        assert_eq!(err.identifier(), SETDIFF_ERROR_UNKNOWN_OPTION.identifier);
    }

    // Both operands are uploaded through the test provider and passed as GPU
    // handles; the result is expected back as a host tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn setdiff_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor_a = Tensor::new(vec![10.0, 4.0, 6.0, 4.0], vec![4, 1]).unwrap();
            let tensor_b = Tensor::new(vec![6.0, 4.0, 2.0], vec![3, 1]).unwrap();
            let view_a = HostTensorView {
                data: &tensor_a.data,
                shape: &tensor_a.shape,
            };
            let view_b = HostTensorView {
                data: &tensor_b.data,
                shape: &tensor_b.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload a");
            let handle_b = provider.upload(&view_b).expect("upload b");
            let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("setdiff");
            match eval.values_value() {
                Value::Tensor(t) => {
                    assert_eq!(t.data, vec![10.0]);
                }
                other => panic!("expected tensor result, got {other:?}"),
            }
            let ia = tensor::value_into_tensor_for("setdiff", eval.ia_value()).expect("ia tensor");
            assert_eq!(ia.data, vec![1.0]);
        });
    }

    // With the wgpu provider registered, GPU-handle inputs must produce the
    // same values and indices (data and shape) as host tensor inputs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn setdiff_wgpu_matches_cpu() {
        // The registration result is deliberately ignored.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let a = Tensor::new(vec![8.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
        let b = Tensor::new(vec![2.0, 5.0], vec![2, 1]).unwrap();

        let cpu_eval = evaluate_sync(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[])
            .expect("setdiff");
        let cpu_values = tensor::value_into_tensor_for("setdiff", cpu_eval.values_value()).unwrap();
        let cpu_ia = tensor::value_into_tensor_for("setdiff", cpu_eval.ia_value()).unwrap();

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");
        let gpu_eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
            .expect("setdiff");
        let gpu_values = tensor::value_into_tensor_for("setdiff", gpu_eval.values_value()).unwrap();
        let gpu_ia = tensor::value_into_tensor_for("setdiff", gpu_eval.ia_value()).unwrap();

        assert_eq!(gpu_values.data, cpu_values.data);
        assert_eq!(gpu_values.shape, cpu_values.shape);
        assert_eq!(gpu_ia.data, cpu_ia.data);
        assert_eq!(gpu_ia.shape, cpu_ia.shape);
    }
}