1use std::collections::HashMap;
4
5use runmat_accelerate_api::{
6 GpuTensorHandle, GpuTensorStorage, HostLogicalOwned, HostTensorOwned,
7 IsMemberOptions as ProviderIsMemberOptions, IsMemberResult,
8};
9use runmat_builtins::{
10 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
11 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
12 CharArray, ComplexTensor, LogicalArray, StringArray, Tensor, Value,
13};
14use runmat_macros::runtime_builtin;
15
16use super::type_resolvers::logical_output_type;
17use crate::build_runtime_error;
18use crate::builtins::common::gpu_helpers;
19use crate::builtins::common::spec::{
20 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24
// GPU capability metadata for `ismember`.
//
// Advertises a custom provider hook named "ismember" (used by `ismember_gpu_pair` when an
// acceleration provider is installed) and declares `ResidencyPolicy::GatherImmediately`,
// since membership results are returned as host-owned logical/index arrays.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::ismember")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ismember",
    op_kind: GpuOpKind::Custom("ismember"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("ismember")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers may supply dedicated membership kernels; until then RunMat gathers GPU tensors to host memory.",
};
40
// Fusion metadata for `ismember`.
//
// `ismember` is neither an elementwise nor a reduction kernel (both fields are `None`),
// so it cannot be fused into surrounding expressions.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::ismember"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ismember",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "`ismember` materialises logical outputs and terminates fusion chains; upstream tensors are gathered when necessary.",
};
53
// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "ismember";

// Output list for the single-output forms: only the logical mask `tf`.
const ISMEMBER_OUTPUT_MASK: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "tf",
    ty: BuiltinParamType::LogicalArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Membership mask over A.",
}];

// Output list for the two-output forms: the mask `tf` plus the 1-based first-match
// indices `loc` (0 where an element/row of A is absent from B).
const ISMEMBER_OUTPUT_MASK_LOC: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "tf",
        ty: BuiltinParamType::LogicalArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Membership mask over A.",
    },
    BuiltinParamDescriptor {
        name: "loc",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "First-match indices into B for each element/row in A (0 when absent).",
    },
];

// Inputs for the option-free forms: the query `A` and the reference set `B`.
const ISMEMBER_INPUTS_A_B: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Values or rows to query.",
    },
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Reference set of values or rows.",
    },
];

// Inputs for the forms with trailing string options (see `parse_options`).
const ISMEMBER_INPUTS_A_B_OPTIONS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Values or rows to query.",
    },
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Reference set of values or rows.",
    },
    BuiltinParamDescriptor {
        name: "option",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Option tokens: 'rows'.",
    },
];
121
// Every call shape `ismember` accepts: one or two outputs, with or without options.
const ISMEMBER_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
    BuiltinSignatureDescriptor {
        label: "tf = ismember(A, B)",
        inputs: &ISMEMBER_INPUTS_A_B,
        outputs: &ISMEMBER_OUTPUT_MASK,
    },
    BuiltinSignatureDescriptor {
        label: "tf = ismember(A, B, option...)",
        inputs: &ISMEMBER_INPUTS_A_B_OPTIONS,
        outputs: &ISMEMBER_OUTPUT_MASK,
    },
    BuiltinSignatureDescriptor {
        label: "[tf, loc] = ismember(A, B)",
        inputs: &ISMEMBER_INPUTS_A_B,
        outputs: &ISMEMBER_OUTPUT_MASK_LOC,
    },
    BuiltinSignatureDescriptor {
        label: "[tf, loc] = ismember(A, B, option...)",
        inputs: &ISMEMBER_INPUTS_A_B_OPTIONS,
        outputs: &ISMEMBER_OUTPUT_MASK_LOC,
    },
];
144
// Raised by `parse_options` for the 'legacy' / 'R2012a' compatibility flags.
const ISMEMBER_ERROR_LEGACY_OPTION_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ISMEMBER.LEGACY_OPTION_UNSUPPORTED",
    identifier: Some("RunMat:ismember:LegacyOptionUnsupported"),
    when: "Legacy compatibility options are requested.",
    message: "ismember: the 'legacy' behaviour is not supported",
};

// Raised by `parse_options` for any string option other than 'rows' and the legacy flags.
const ISMEMBER_ERROR_UNKNOWN_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ISMEMBER.UNKNOWN_OPTION",
    identifier: Some("RunMat:ismember:UnknownOption"),
    when: "An unsupported option token is provided.",
    message: "ismember: unrecognised option",
};

// Raised by the `*_rows` kernels when A and B have different column counts.
const ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ISMEMBER.ROWS_COLUMN_MISMATCH",
    identifier: Some("RunMat:ismember:RowsColumnMismatch"),
    when: "'rows' mode is used and column counts differ.",
    message: "ismember: inputs must have the same number of columns when using 'rows'",
};

// Raised by `parse_options` when an option argument cannot be read as a string.
const ISMEMBER_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ISMEMBER.INVALID_ARGUMENT",
    identifier: Some("RunMat:ismember:InvalidArgument"),
    when: "Option arguments are not string-like where required.",
    message: "ismember: expected string option arguments",
};

// Catch-all used through `ismember_internal_error` for conversion/allocation failures.
// NOTE(review): it is also used for non-2-D 'rows' inputs, which are user errors.
const ISMEMBER_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ISMEMBER.INTERNAL",
    identifier: Some("RunMat:ismember:Internal"),
    when: "Internal conversion/allocation/provider decode fails.",
    message: "ismember: internal operation failed",
};

// All error descriptors this builtin can raise, published through `ISMEMBER_DESCRIPTOR`.
const ISMEMBER_ERRORS: [BuiltinErrorDescriptor; 5] = [
    ISMEMBER_ERROR_LEGACY_OPTION_UNSUPPORTED,
    ISMEMBER_ERROR_UNKNOWN_OPTION,
    ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH,
    ISMEMBER_ERROR_INVALID_ARGUMENT,
    ISMEMBER_ERROR_INTERNAL,
];

// Descriptor referenced by the `descriptor(...)` argument of `#[runtime_builtin]` on
// `ismember_builtin`. Outputs depend on the number of outputs the caller requests.
pub const ISMEMBER_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &ISMEMBER_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &ISMEMBER_ERRORS,
};
194
195fn ismember_error_with(
196 error: &'static BuiltinErrorDescriptor,
197 message: impl Into<String>,
198) -> crate::RuntimeError {
199 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
200 if let Some(identifier) = error.identifier {
201 builder = builder.with_identifier(identifier);
202 }
203 builder.build()
204}
205
206fn ismember_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
207 ismember_error_with(error, error.message)
208}
209
210fn ismember_internal_error(message: impl Into<String>) -> crate::RuntimeError {
211 ismember_error_with(&ISMEMBER_ERROR_INTERNAL, message)
212}
213
214#[runtime_builtin(
215 name = "ismember",
216 category = "array/sorting_sets",
217 summary = "Identify array elements or rows that appear in another array while returning first-match indices.",
218 keywords = "ismember,membership,set,rows,indices,gpu",
219 accel = "array_construct",
220 sink = true,
221 type_resolver(logical_output_type),
222 descriptor(crate::builtins::array::sorting_sets::ismember::ISMEMBER_DESCRIPTOR),
223 builtin_path = "crate::builtins::array::sorting_sets::ismember"
224)]
225async fn ismember_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
226 let eval = evaluate(a, b, &rest).await?;
227 if let Some(out_count) = crate::output_count::current_output_count() {
228 if out_count == 0 {
229 return Ok(Value::OutputList(Vec::new()));
230 }
231 if out_count == 1 {
232 return Ok(Value::OutputList(vec![eval.into_mask_value()]));
233 }
234 let (mask, loc) = eval.into_pair();
235 return Ok(crate::output_count::output_list_with_padding(
236 out_count,
237 vec![mask, loc],
238 ));
239 }
240 Ok(eval.into_mask_value())
241}
242
243pub async fn evaluate(
245 a: Value,
246 b: Value,
247 rest: &[Value],
248) -> crate::BuiltinResult<IsMemberEvaluation> {
249 let opts = parse_options(rest)?;
250 match (a, b) {
251 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
252 ismember_gpu_pair(handle_a, handle_b, &opts).await
253 }
254 (Value::GpuTensor(handle_a), other) => {
255 ismember_gpu_mixed(handle_a, other, &opts, true).await
256 }
257 (other, Value::GpuTensor(handle_b)) => {
258 ismember_gpu_mixed(handle_b, other, &opts, false).await
259 }
260 (left, right) => ismember_host(left, right, &opts),
261 }
262}
263
264#[derive(Debug, Clone, Copy)]
265struct IsMemberOptions {
266 rows: bool,
267}
268
269impl IsMemberOptions {
270 fn into_provider_options(self) -> ProviderIsMemberOptions {
271 ProviderIsMemberOptions { rows: self.rows }
272 }
273}
274
275fn parse_options(rest: &[Value]) -> crate::BuiltinResult<IsMemberOptions> {
276 let mut opts = IsMemberOptions { rows: false };
277 for arg in rest {
278 let text = tensor::value_to_string(arg)
279 .ok_or_else(|| ismember_error(&ISMEMBER_ERROR_INVALID_ARGUMENT))?;
280 let lowered = text.trim().to_ascii_lowercase();
281 match lowered.as_str() {
282 "rows" => opts.rows = true,
283 "legacy" | "r2012a" => {
284 return Err(ismember_error(&ISMEMBER_ERROR_LEGACY_OPTION_UNSUPPORTED))
285 }
286 other => {
287 return Err(ismember_error_with(
288 &ISMEMBER_ERROR_UNKNOWN_OPTION,
289 format!("ismember: unrecognised option '{other}'"),
290 ))
291 }
292 }
293 }
294 Ok(opts)
295}
296
297async fn ismember_gpu_pair(
298 handle_a: GpuTensorHandle,
299 handle_b: GpuTensorHandle,
300 opts: &IsMemberOptions,
301) -> crate::BuiltinResult<IsMemberEvaluation> {
302 if let Some(provider) = runmat_accelerate_api::provider() {
303 let provider_opts = opts.into_provider_options();
304 match provider
305 .ismember(&handle_a, &handle_b, &provider_opts)
306 .await
307 {
308 Ok(result) => return IsMemberEvaluation::from_provider_result(result),
309 Err(_) => {
310 }
312 }
313 }
314 let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
315 let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
316 ismember_numeric_tensors(tensor_a, tensor_b, opts)
317}
318
319async fn ismember_gpu_mixed(
320 handle_gpu: GpuTensorHandle,
321 other: Value,
322 opts: &IsMemberOptions,
323 gpu_is_a: bool,
324) -> crate::BuiltinResult<IsMemberEvaluation> {
325 let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
326 if gpu_is_a {
327 ismember_host(Value::Tensor(tensor_gpu), other, opts)
328 } else {
329 ismember_host(other, Value::Tensor(tensor_gpu), opts)
330 }
331}
332
333fn ismember_host(
334 a: Value,
335 b: Value,
336 opts: &IsMemberOptions,
337) -> crate::BuiltinResult<IsMemberEvaluation> {
338 match (a, b) {
339 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => ismember_complex(at, bt, opts.rows),
340 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
341 let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
342 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
343 ismember_complex(at, bt, opts.rows)
344 }
345 (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
346 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
347 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
348 ismember_complex(at, bt, opts.rows)
349 }
350 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
351 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
352 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
353 let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
354 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
355 ismember_complex(at, bt, opts.rows)
356 }
357
358 (Value::CharArray(ac), Value::CharArray(bc)) => ismember_char(ac, bc, opts.rows),
359
360 (Value::StringArray(astring), Value::StringArray(bstring)) => {
361 ismember_string(astring, bstring, opts.rows)
362 }
363 (Value::StringArray(astring), Value::String(b)) => {
364 let bstring = StringArray::new(vec![b], vec![1, 1])
365 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
366 ismember_string(astring, bstring, opts.rows)
367 }
368 (Value::String(a), Value::StringArray(bstring)) => {
369 let astring = StringArray::new(vec![a], vec![1, 1])
370 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
371 ismember_string(astring, bstring, opts.rows)
372 }
373 (Value::String(a), Value::String(b)) => {
374 let astring = StringArray::new(vec![a], vec![1, 1])
375 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
376 let bstring = StringArray::new(vec![b], vec![1, 1])
377 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
378 ismember_string(astring, bstring, opts.rows)
379 }
380
381 (left, right) => {
382 let tensor_a = tensor::value_into_tensor_for("ismember", left)
383 .map_err(|e| ismember_internal_error(e))?;
384 let tensor_b = tensor::value_into_tensor_for("ismember", right)
385 .map_err(|e| ismember_internal_error(e))?;
386 ismember_numeric_tensors(tensor_a, tensor_b, opts)
387 }
388 }
389}
390
391fn ismember_numeric_tensors(
392 a: Tensor,
393 b: Tensor,
394 opts: &IsMemberOptions,
395) -> crate::BuiltinResult<IsMemberEvaluation> {
396 if opts.rows {
397 ismember_numeric_rows(a, b)
398 } else {
399 ismember_numeric_elements(a, b)
400 }
401}
402
403pub fn ismember_numeric_from_tensors(
405 a: Tensor,
406 b: Tensor,
407 rows: bool,
408) -> crate::BuiltinResult<IsMemberEvaluation> {
409 let opts = IsMemberOptions { rows };
410 ismember_numeric_tensors(a, b, &opts)
411}
412
413fn ismember_numeric_elements(a: Tensor, b: Tensor) -> crate::BuiltinResult<IsMemberEvaluation> {
414 let mut map: HashMap<u64, usize> = HashMap::new();
415 for (idx, &value) in b.data.iter().enumerate() {
416 map.entry(canonicalize_f64(value)).or_insert(idx + 1);
417 }
418
419 let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
420 let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
421
422 for &value in &a.data {
423 let key = canonicalize_f64(value);
424 if let Some(&pos) = map.get(&key) {
425 mask_data.push(1);
426 loc_data.push(pos as f64);
427 } else {
428 mask_data.push(0);
429 loc_data.push(0.0);
430 }
431 }
432
433 let logical = LogicalArray::new(mask_data, a.shape.clone())
434 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
435 let loc_tensor = Tensor::new(loc_data, a.shape.clone())
436 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
437 Ok(IsMemberEvaluation::new(logical, loc_tensor))
438}
439
440fn ismember_numeric_rows(a: Tensor, b: Tensor) -> crate::BuiltinResult<IsMemberEvaluation> {
441 let (rows_a, cols_a) = tensor_rows_cols(&a, "ismember")?;
442 let (rows_b, cols_b) = tensor_rows_cols(&b, "ismember")?;
443 if cols_a != cols_b {
444 return Err(ismember_error(&ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH));
445 }
446
447 let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
448 for r in 0..rows_b {
449 let mut row_values = Vec::with_capacity(cols_b);
450 for c in 0..cols_b {
451 let idx = r + c * rows_b;
452 row_values.push(b.data[idx]);
453 }
454 let key = NumericRowKey::from_slice(&row_values);
455 map.entry(key).or_insert(r + 1);
456 }
457
458 let mut mask_data = vec![0u8; rows_a];
459 let mut loc_data = vec![0.0f64; rows_a];
460
461 for r in 0..rows_a {
462 let mut row_values = Vec::with_capacity(cols_a);
463 for c in 0..cols_a {
464 let idx = r + c * rows_a;
465 row_values.push(a.data[idx]);
466 }
467 let key = NumericRowKey::from_slice(&row_values);
468 if let Some(&pos) = map.get(&key) {
469 mask_data[r] = 1;
470 loc_data[r] = pos as f64;
471 }
472 }
473
474 let shape = vec![rows_a, 1];
475 let logical = LogicalArray::new(mask_data, shape.clone())
476 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
477 let loc_tensor = Tensor::new(loc_data, shape)
478 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
479 Ok(IsMemberEvaluation::new(logical, loc_tensor))
480}
481
482fn ismember_complex(
483 a: ComplexTensor,
484 b: ComplexTensor,
485 rows: bool,
486) -> crate::BuiltinResult<IsMemberEvaluation> {
487 if rows {
488 ismember_complex_rows(a, b)
489 } else {
490 ismember_complex_elements(a, b)
491 }
492}
493
494fn ismember_complex_elements(
495 a: ComplexTensor,
496 b: ComplexTensor,
497) -> crate::BuiltinResult<IsMemberEvaluation> {
498 let mut map: HashMap<ComplexKey, usize> = HashMap::new();
499 for (idx, &value) in b.data.iter().enumerate() {
500 map.entry(ComplexKey::new(value)).or_insert(idx + 1);
501 }
502
503 let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
504 let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
505
506 for &value in &a.data {
507 let key = ComplexKey::new(value);
508 if let Some(&pos) = map.get(&key) {
509 mask_data.push(1);
510 loc_data.push(pos as f64);
511 } else {
512 mask_data.push(0);
513 loc_data.push(0.0);
514 }
515 }
516
517 let logical = LogicalArray::new(mask_data, a.shape.clone())
518 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
519 let loc_tensor = Tensor::new(loc_data, a.shape.clone())
520 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
521 Ok(IsMemberEvaluation::new(logical, loc_tensor))
522}
523
524fn ismember_complex_rows(
525 a: ComplexTensor,
526 b: ComplexTensor,
527) -> crate::BuiltinResult<IsMemberEvaluation> {
528 let (rows_a, cols_a) = complex_rows_cols(&a)?;
529 let (rows_b, cols_b) = complex_rows_cols(&b)?;
530 if cols_a != cols_b {
531 return Err(ismember_error(&ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH).into());
532 }
533
534 let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
535 for r in 0..rows_b {
536 let mut row_keys = Vec::with_capacity(cols_b);
537 for c in 0..cols_b {
538 let idx = r + c * rows_b;
539 row_keys.push(ComplexKey::new(b.data[idx]));
540 }
541 map.entry(row_keys).or_insert(r + 1);
542 }
543
544 let mut mask_data = vec![0u8; rows_a];
545 let mut loc_data = vec![0.0f64; rows_a];
546
547 for r in 0..rows_a {
548 let mut row_keys = Vec::with_capacity(cols_a);
549 for c in 0..cols_a {
550 let idx = r + c * rows_a;
551 row_keys.push(ComplexKey::new(a.data[idx]));
552 }
553 if let Some(&pos) = map.get(&row_keys) {
554 mask_data[r] = 1;
555 loc_data[r] = pos as f64;
556 }
557 }
558
559 let shape = vec![rows_a, 1];
560 let logical = LogicalArray::new(mask_data, shape.clone())
561 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
562 let loc_tensor = Tensor::new(loc_data, shape)
563 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
564 Ok(IsMemberEvaluation::new(logical, loc_tensor))
565}
566
567fn ismember_char(
568 a: CharArray,
569 b: CharArray,
570 rows: bool,
571) -> crate::BuiltinResult<IsMemberEvaluation> {
572 if rows {
573 ismember_char_rows(a, b)
574 } else {
575 ismember_char_elements(a, b)
576 }
577}
578
579fn ismember_char_elements(a: CharArray, b: CharArray) -> crate::BuiltinResult<IsMemberEvaluation> {
580 let rows_b = b.rows;
581 let cols_b = b.cols;
582 let mut map: HashMap<char, usize> = HashMap::new();
583
584 for col in 0..cols_b {
585 for row in 0..rows_b {
586 let data_idx = row * cols_b + col;
587 let ch = b.data[data_idx];
588 let linear_idx = row + col * rows_b;
589 map.entry(ch).or_insert(linear_idx + 1);
590 }
591 }
592
593 let rows_a = a.rows;
594 let cols_a = a.cols;
595 let mut mask_data = vec![0u8; rows_a * cols_a];
596 let mut loc_data = vec![0.0f64; rows_a * cols_a];
597
598 for col in 0..cols_a {
599 for row in 0..rows_a {
600 let data_idx = row * cols_a + col;
601 let ch = a.data[data_idx];
602 let linear_idx = row + col * rows_a;
603 if let Some(&pos) = map.get(&ch) {
604 mask_data[linear_idx] = 1;
605 loc_data[linear_idx] = pos as f64;
606 }
607 }
608 }
609
610 let shape = vec![rows_a, cols_a];
611 let logical = LogicalArray::new(mask_data, shape.clone())
612 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
613 let loc_tensor = Tensor::new(loc_data, shape)
614 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
615 Ok(IsMemberEvaluation::new(logical, loc_tensor))
616}
617
618fn ismember_char_rows(a: CharArray, b: CharArray) -> crate::BuiltinResult<IsMemberEvaluation> {
619 if a.cols != b.cols {
620 return Err(ismember_error(&ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH).into());
621 }
622
623 let rows_b = b.rows;
624 let cols = b.cols;
625 let mut map: HashMap<RowCharKey, usize> = HashMap::new();
626
627 for r in 0..rows_b {
628 let mut row_values = Vec::with_capacity(cols);
629 for c in 0..cols {
630 let idx = r * cols + c;
631 row_values.push(b.data[idx]);
632 }
633 let key = RowCharKey::from_slice(&row_values);
634 map.entry(key).or_insert(r + 1);
635 }
636
637 let rows_a = a.rows;
638 let mut mask_data = vec![0u8; rows_a];
639 let mut loc_data = vec![0.0f64; rows_a];
640
641 for r in 0..rows_a {
642 let mut row_values = Vec::with_capacity(cols);
643 for c in 0..cols {
644 let idx = r * cols + c;
645 row_values.push(a.data[idx]);
646 }
647 let key = RowCharKey::from_slice(&row_values);
648 if let Some(&pos) = map.get(&key) {
649 mask_data[r] = 1;
650 loc_data[r] = pos as f64;
651 }
652 }
653
654 let shape = vec![rows_a, 1];
655 let logical = LogicalArray::new(mask_data, shape.clone())
656 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
657 let loc_tensor = Tensor::new(loc_data, shape)
658 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
659 Ok(IsMemberEvaluation::new(logical, loc_tensor))
660}
661
662fn ismember_string(
663 a: StringArray,
664 b: StringArray,
665 rows: bool,
666) -> crate::BuiltinResult<IsMemberEvaluation> {
667 if rows {
668 ismember_string_rows(a, b)
669 } else {
670 ismember_string_elements(a, b)
671 }
672}
673
674fn ismember_string_elements(
675 a: StringArray,
676 b: StringArray,
677) -> crate::BuiltinResult<IsMemberEvaluation> {
678 let mut map: HashMap<String, usize> = HashMap::new();
679 for (idx, value) in b.data.iter().enumerate() {
680 map.entry(value.clone()).or_insert(idx + 1);
681 }
682
683 let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
684 let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
685
686 for value in &a.data {
687 if let Some(&pos) = map.get(value) {
688 mask_data.push(1);
689 loc_data.push(pos as f64);
690 } else {
691 mask_data.push(0);
692 loc_data.push(0.0);
693 }
694 }
695
696 let logical = LogicalArray::new(mask_data, a.shape.clone())
697 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
698 let loc_tensor = Tensor::new(loc_data, a.shape.clone())
699 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
700 Ok(IsMemberEvaluation::new(logical, loc_tensor))
701}
702
703fn ismember_string_rows(
704 a: StringArray,
705 b: StringArray,
706) -> crate::BuiltinResult<IsMemberEvaluation> {
707 if a.shape.len() != 2 || b.shape.len() != 2 {
708 return Err(ismember_internal_error(
709 "ismember: 'rows' option requires 2-D string arrays",
710 ));
711 }
712 if a.shape[1] != b.shape[1] {
713 return Err(ismember_error(&ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH).into());
714 }
715
716 let rows_a = a.shape[0];
717 let cols = a.shape[1];
718 let rows_b = b.shape[0];
719
720 let mut map: HashMap<RowStringKey, usize> = HashMap::new();
721 for r in 0..rows_b {
722 let mut row_values = Vec::with_capacity(cols);
723 for c in 0..cols {
724 let idx = r + c * rows_b;
725 row_values.push(b.data[idx].clone());
726 }
727 let key = RowStringKey(row_values);
728 map.entry(key).or_insert(r + 1);
729 }
730
731 let mut mask_data = vec![0u8; rows_a];
732 let mut loc_data = vec![0.0f64; rows_a];
733
734 for r in 0..rows_a {
735 let mut row_values = Vec::with_capacity(cols);
736 for c in 0..cols {
737 let idx = r + c * rows_a;
738 row_values.push(a.data[idx].clone());
739 }
740 let key = RowStringKey(row_values);
741 if let Some(&pos) = map.get(&key) {
742 mask_data[r] = 1;
743 loc_data[r] = pos as f64;
744 }
745 }
746
747 let shape = vec![rows_a, 1];
748 let logical = LogicalArray::new(mask_data, shape.clone())
749 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
750 let loc_tensor = Tensor::new(loc_data, shape)
751 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
752 Ok(IsMemberEvaluation::new(logical, loc_tensor))
753}
754
755fn tensor_rows_cols(t: &Tensor, name: &str) -> crate::BuiltinResult<(usize, usize)> {
756 match t.shape.len() {
757 0 => Ok((1, 1)),
758 1 => Ok((t.shape[0], 1)),
759 2 => Ok((t.shape[0], t.shape[1])),
760 _ => Err(ismember_internal_error(format!(
761 "{name}: 'rows' option requires 2-D numeric matrices"
762 ))
763 .into()),
764 }
765}
766
767fn complex_rows_cols(t: &ComplexTensor) -> crate::BuiltinResult<(usize, usize)> {
768 match t.shape.len() {
769 0 => Ok((1, 1)),
770 1 => Ok((t.shape[0], 1)),
771 2 => Ok((t.shape[0], t.shape[1])),
772 _ => Err(ismember_internal_error(
773 "ismember: 'rows' option requires 2-D complex matrices",
774 )),
775 }
776}
777
778#[derive(Debug, Clone, PartialEq, Eq, Hash)]
779struct NumericRowKey(Vec<u64>);
780
781impl NumericRowKey {
782 fn from_slice(values: &[f64]) -> Self {
783 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
784 }
785}
786
787#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
788struct ComplexKey {
789 re: u64,
790 im: u64,
791}
792
793impl ComplexKey {
794 fn new(value: (f64, f64)) -> Self {
795 Self {
796 re: canonicalize_f64(value.0),
797 im: canonicalize_f64(value.1),
798 }
799 }
800}
801
802#[derive(Debug, Clone, PartialEq, Eq, Hash)]
803struct RowCharKey(Vec<u32>);
804
805impl RowCharKey {
806 fn from_slice(values: &[char]) -> Self {
807 RowCharKey(values.iter().map(|&ch| ch as u32).collect())
808 }
809}
810
811#[derive(Debug, Clone, PartialEq, Eq, Hash)]
812struct RowStringKey(Vec<String>);
813
/// Maps an `f64` to a hashable bit pattern under `ismember`'s equality rules.
///
/// Every NaN (any sign or payload) maps to the quiet-NaN pattern
/// `0x7ff8_0000_0000_0000`, so NaN matches NaN. Both `0.0` and `-0.0` map to `0`.
/// Every other value maps to its raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    match value {
        v if v.is_nan() => 0x7ff8_0000_0000_0000u64,
        v if v == 0.0 => 0u64,
        v => v.to_bits(),
    }
}
823
824#[derive(Debug, Clone)]
825pub struct IsMemberEvaluation {
826 mask: LogicalArray,
827 loc: Tensor,
828}
829
830impl IsMemberEvaluation {
831 fn new(mask: LogicalArray, loc: Tensor) -> Self {
832 Self { mask, loc }
833 }
834
835 pub fn from_provider_result(result: IsMemberResult) -> crate::BuiltinResult<Self> {
836 let mask = LogicalArray::new(result.mask.data, result.mask.shape)
837 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
838 let loc = Tensor::new(result.loc.data, result.loc.shape)
839 .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
840 Ok(IsMemberEvaluation::new(mask, loc))
841 }
842
843 pub fn into_numeric_ismember_result(self) -> crate::BuiltinResult<IsMemberResult> {
844 let IsMemberEvaluation { mask, loc } = self;
845 Ok(IsMemberResult {
846 mask: HostLogicalOwned {
847 data: mask.data,
848 shape: mask.shape,
849 },
850 loc: HostTensorOwned {
851 data: loc.data,
852 shape: loc.shape,
853 storage: GpuTensorStorage::Real,
854 },
855 })
856 }
857
858 pub fn into_mask_value(self) -> Value {
859 logical_array_into_value(self.mask)
860 }
861
862 pub fn mask_value(&self) -> Value {
863 logical_array_into_value(self.mask.clone())
864 }
865
866 pub fn into_pair(self) -> (Value, Value) {
867 let mask = logical_array_into_value(self.mask);
868 let loc = tensor::tensor_into_value(self.loc);
869 (mask, loc)
870 }
871
872 pub fn loc_value(&self) -> Value {
873 tensor::tensor_into_value(self.loc.clone())
874 }
875}
876
877fn logical_array_into_value(logical: LogicalArray) -> Value {
878 if logical.data.len() == 1 {
879 Value::Bool(logical.data[0] != 0)
880 } else {
881 Value::LogicalArray(logical)
882 }
883}
884
885#[cfg(test)]
886pub(crate) mod tests {
887 use super::*;
888 use crate::builtins::common::test_support;
889 use runmat_builtins::{ResolveContext, Tensor, Type};
890
891 #[cfg(feature = "wgpu")]
892 use runmat_accelerate_api::HostTensorView;
893
894 fn evaluate_sync(
895 a: Value,
896 b: Value,
897 rest: &[Value],
898 ) -> crate::BuiltinResult<IsMemberEvaluation> {
899 futures::executor::block_on(evaluate(a, b, rest))
900 }
901
902 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
903 #[test]
904 fn numeric_membership_basic() {
905 let a = Tensor::new(vec![5.0, 7.0, 2.0, 7.0], vec![1, 4]).unwrap();
906 let b = Tensor::new(vec![7.0, 9.0, 5.0], vec![1, 3]).unwrap();
907 let eval = ismember_numeric_elements(a, b).expect("ismember");
908 assert_eq!(eval.mask.data, vec![1, 1, 0, 1]);
909 assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0, 1.0]);
910 }
911
912 #[test]
913 fn ismember_type_resolver_logical() {
914 assert_eq!(
915 logical_output_type(
916 &[Type::tensor(), Type::tensor()],
917 &ResolveContext::new(Vec::new()),
918 ),
919 Type::logical()
920 );
921 }
922
923 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
924 #[test]
925 fn numeric_nan_membership() {
926 let a = Tensor::new(vec![f64::NAN, 1.0], vec![1, 2]).unwrap();
927 let b = Tensor::new(vec![f64::NAN, 2.0], vec![1, 2]).unwrap();
928 let eval = ismember_numeric_elements(a, b).expect("ismember");
929 assert_eq!(eval.mask.data, vec![1, 0]);
930 assert_eq!(eval.loc.data, vec![1.0, 0.0]);
931 }
932
933 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
934 #[test]
935 fn numeric_rows_membership() {
936 let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
937 let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
938 let eval = ismember_numeric_rows(a, b).expect("ismember");
939 assert_eq!(eval.mask.data, vec![1, 1, 1]);
940 assert_eq!(eval.loc.data, vec![3.0, 1.0, 3.0]);
941 assert_eq!(eval.loc.shape, vec![3, 1]);
942 }
943
944 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
945 #[test]
946 fn complex_membership() {
947 let a = ComplexTensor::new(vec![(1.0, 2.0), (0.0, 0.0)], vec![1, 2]).unwrap();
948 let b = ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0)], vec![1, 2]).unwrap();
949 let eval = ismember_complex_elements(a, b).expect("ismember");
950 assert_eq!(eval.mask.data, vec![1, 1]);
951 assert_eq!(eval.loc.data, vec![2.0, 1.0]);
952 }
953
954 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
955 #[test]
956 fn complex_rows_membership() {
957 let a = ComplexTensor::new(
958 vec![(1.0, 1.0), (3.0, 0.0), (2.0, 0.0), (4.0, 4.0)],
959 vec![2, 2],
960 )
961 .unwrap();
962 let b = ComplexTensor::new(
963 vec![
964 (1.0, 1.0),
965 (5.0, 0.0),
966 (3.0, 0.0),
967 (2.0, 0.0),
968 (6.0, 0.0),
969 (4.0, 4.0),
970 ],
971 vec![3, 2],
972 )
973 .unwrap();
974 let eval = ismember_complex_rows(a, b).expect("ismember");
975 assert_eq!(eval.mask.data, vec![1, 1]);
976 assert_eq!(eval.loc.data, vec![1.0, 3.0]);
977 }
978
979 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
980 #[test]
981 fn char_membership() {
982 let a = CharArray::new(vec!['r', 'u', 'n', 'm'], 2, 2).unwrap();
983 let b = CharArray::new(vec!['m', 'a', 'r', 'u'], 2, 2).unwrap();
984 let eval = ismember_char_elements(a, b).expect("ismember");
985 assert_eq!(eval.mask.data, vec![1, 0, 1, 1]);
986 assert_eq!(eval.loc.data, vec![2.0, 0.0, 4.0, 1.0]);
987 }
988
989 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
990 #[test]
991 fn char_rows_membership() {
992 let a = CharArray::new(vec!['m', 'a', 't', 'l'], 2, 2).unwrap();
993 let b = CharArray::new(vec!['m', 'a', 'g', 'e', 't', 'l'], 3, 2).unwrap();
994 let eval = ismember_char_rows(a, b).expect("ismember");
995 assert_eq!(eval.mask.data, vec![1, 1]);
996 assert_eq!(eval.loc.data, vec![1.0, 3.0]);
997 }
998
999 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1000 #[test]
1001 fn string_membership() {
1002 let a = StringArray::new(
1003 vec![
1004 "apple".to_string(),
1005 "pear".to_string(),
1006 "banana".to_string(),
1007 ],
1008 vec![1, 3],
1009 )
1010 .unwrap();
1011 let b = StringArray::new(
1012 vec![
1013 "pear".to_string(),
1014 "orange".to_string(),
1015 "apple".to_string(),
1016 ],
1017 vec![1, 3],
1018 )
1019 .unwrap();
1020 let eval = ismember_string_elements(a, b).expect("ismember");
1021 assert_eq!(eval.mask.data, vec![1, 1, 0]);
1022 assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0]);
1023 }
1024
1025 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1026 #[test]
1027 fn string_rows_membership() {
1028 let a = StringArray::new(
1029 vec![
1030 "alpha".to_string(),
1031 "gamma".to_string(),
1032 "beta".to_string(),
1033 "delta".to_string(),
1034 ],
1035 vec![2, 2],
1036 )
1037 .unwrap();
1038 let b = StringArray::new(
1039 vec![
1040 "alpha".to_string(),
1041 "theta".to_string(),
1042 "gamma".to_string(),
1043 "beta".to_string(),
1044 "eta".to_string(),
1045 "delta".to_string(),
1046 ],
1047 vec![3, 2],
1048 )
1049 .unwrap();
1050 let eval = ismember_string_rows(a, b).expect("ismember");
1051 assert_eq!(eval.mask.data, vec![1, 1]);
1052 assert_eq!(eval.loc.data, vec![1.0, 3.0]);
1053 }
1054
1055 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1056 #[test]
1057 fn options_reject_legacy() {
1058 let err = parse_options(&[Value::from("legacy")]).unwrap_err();
1059 assert_eq!(
1060 err.identifier(),
1061 ISMEMBER_ERROR_LEGACY_OPTION_UNSUPPORTED.identifier
1062 );
1063 }
1064
1065 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1066 #[test]
1067 fn rejects_unknown_option() {
1068 let err =
1069 evaluate_sync(Value::Num(1.0), Value::Num(1.0), &[Value::from("stable")]).unwrap_err();
1070 assert_eq!(err.identifier(), ISMEMBER_ERROR_UNKNOWN_OPTION.identifier);
1071 }
1072
1073 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1074 #[test]
1075 fn ismember_runtime_numeric() {
1076 let a = Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap());
1077 let b = Value::Tensor(Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap());
1078 let (mask, loc) = evaluate_sync(a, b, &[]).unwrap().into_pair();
1079 match mask {
1080 Value::LogicalArray(arr) => assert_eq!(arr.data, vec![1, 0, 1]),
1081 other => panic!("expected logical array, got {other:?}"),
1082 }
1083 match loc {
1084 Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 0.0, 1.0]),
1085 other => panic!("expected tensor, got {other:?}"),
1086 }
1087 }
1088
1089 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1090 #[test]
1091 fn logical_inputs_promoted() {
1092 let a = Value::Bool(true);
1093 let logical_b =
1094 LogicalArray::new(vec![1, 0], vec![2, 1]).expect("logical array construction");
1095 let eval = evaluate_sync(a, Value::LogicalArray(logical_b), &[]).expect("ismember");
1096 assert_eq!(eval.mask_value(), Value::Bool(true));
1097 assert_eq!(eval.loc_value(), Value::Num(1.0));
1098 }
1099
1100 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1101 #[test]
1102 fn ismember_rows_shape_checks() {
1103 let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1104 let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1105 assert!(ismember_numeric_rows(a.clone(), b.clone()).is_ok());
1106 let bad = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1107 let err = ismember_numeric_rows(a, bad).unwrap_err();
1108 assert_eq!(
1109 err.identifier(),
1110 ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH.identifier
1111 );
1112 }
1113
1114 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1115 #[test]
1116 fn ismember_gpu_roundtrip() {
1117 test_support::with_test_provider(|provider| {
1118 let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
1119 let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1120 let view_a = runmat_accelerate_api::HostTensorView {
1121 data: &tensor.data,
1122 shape: &tensor.shape,
1123 };
1124 let view_b = runmat_accelerate_api::HostTensorView {
1125 data: &set.data,
1126 shape: &set.shape,
1127 };
1128 let handle_a = provider.upload(&view_a).expect("upload a");
1129 let handle_b = provider.upload(&view_b).expect("upload b");
1130 let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
1131 .expect("ismember");
1132 assert_eq!(eval.mask.data, vec![0, 1, 0, 1]);
1133 assert_eq!(eval.loc.data, vec![0.0, 1.0, 0.0, 1.0]);
1134 });
1135 }
1136
1137 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1138 #[test]
1139 fn ismember_gpu_rows_roundtrip() {
1140 test_support::with_test_provider(|provider| {
1141 let rows = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
1142 let bank = Tensor::new(vec![1.0, 5.0, 3.0, 2.0, 6.0, 4.0], vec![3, 2]).unwrap();
1143 let view_a = runmat_accelerate_api::HostTensorView {
1144 data: &rows.data,
1145 shape: &rows.shape,
1146 };
1147 let view_b = runmat_accelerate_api::HostTensorView {
1148 data: &bank.data,
1149 shape: &bank.shape,
1150 };
1151 let handle_a = provider.upload(&view_a).expect("upload a");
1152 let handle_b = provider.upload(&view_b).expect("upload b");
1153 let eval = evaluate_sync(
1154 Value::GpuTensor(handle_a.clone()),
1155 Value::GpuTensor(handle_b.clone()),
1156 &[Value::from("rows")],
1157 )
1158 .expect("ismember");
1159 assert_eq!(eval.mask.data, vec![1, 1]);
1160 assert_eq!(eval.loc.data, vec![1.0, 3.0]);
1161 let _ = provider.free(&handle_a);
1162 let _ = provider.free(&handle_b);
1163 });
1164 }
1165
1166 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1167 #[test]
1168 #[cfg(feature = "wgpu")]
1169 fn ismember_wgpu_numeric_matches_cpu() {
1170 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1171 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1172 );
1173
1174 let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
1175 let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1176 let cpu_eval =
1177 ismember_numeric_from_tensors(tensor.clone(), set.clone(), false).expect("cpu");
1178
1179 let provider = runmat_accelerate_api::provider().expect("provider");
1180 let view_a = HostTensorView {
1181 data: &tensor.data,
1182 shape: &tensor.shape,
1183 };
1184 let view_b = HostTensorView {
1185 data: &set.data,
1186 shape: &set.shape,
1187 };
1188 let handle_a = provider.upload(&view_a).expect("upload a");
1189 let handle_b = provider.upload(&view_b).expect("upload b");
1190
1191 let eval = evaluate_sync(
1192 Value::GpuTensor(handle_a.clone()),
1193 Value::GpuTensor(handle_b.clone()),
1194 &[],
1195 )
1196 .expect("gpu evaluate");
1197 assert_eq!(eval.mask.data, cpu_eval.mask.data);
1198 assert_eq!(eval.loc.data, cpu_eval.loc.data);
1199
1200 let _ = provider.free(&handle_a);
1201 let _ = provider.free(&handle_b);
1202
1203 let matrix = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
1204 let bank = Tensor::new(vec![1.0, 7.0, 3.0, 2.0, 9.0, 4.0], vec![3, 2]).unwrap();
1205 let cpu_rows =
1206 ismember_numeric_from_tensors(matrix.clone(), bank.clone(), true).expect("cpu rows");
1207 let view_matrix = HostTensorView {
1208 data: &matrix.data,
1209 shape: &matrix.shape,
1210 };
1211 let view_bank = HostTensorView {
1212 data: &bank.data,
1213 shape: &bank.shape,
1214 };
1215 let handle_matrix = provider.upload(&view_matrix).expect("upload matrix");
1216 let handle_bank = provider.upload(&view_bank).expect("upload bank");
1217 let eval_rows = evaluate_sync(
1218 Value::GpuTensor(handle_matrix.clone()),
1219 Value::GpuTensor(handle_bank.clone()),
1220 &[Value::from("rows")],
1221 )
1222 .expect("gpu rows evaluate");
1223 assert_eq!(eval_rows.mask.data, cpu_rows.mask.data);
1224 assert_eq!(eval_rows.loc.data, cpu_rows.loc.data);
1225 let _ = provider.free(&handle_matrix);
1226 let _ = provider.free(&handle_bank);
1227 }
1228
1229 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1230 #[test]
1231 fn scalar_return_is_bool() {
1232 let a = Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());
1233 let b = Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());
1234 let mask = evaluate_sync(a, b, &[]).unwrap().into_mask_value();
1235 assert_eq!(mask, Value::Bool(true));
1236 }
1237
1238 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1239 #[test]
1240 fn parse_rows_option() {
1241 let opts = parse_options(&[Value::from("rows")]).unwrap();
1242 assert!(opts.rows);
1243 }
1244
1245 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1246 #[test]
1247 fn numeric_rows_with_nan() {
1248 let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
1249 let b = Tensor::new(vec![f64::NAN, 2.0], vec![2, 1]).unwrap();
1250 let eval = ismember_numeric_rows(a, b).expect("ismember");
1251 assert_eq!(eval.mask.data, vec![1, 0]);
1252 assert_eq!(eval.loc.data, vec![1.0, 0.0]);
1253 }
1254}