Skip to main content

runmat_runtime/builtins/array/sorting_sets/
ismember.rs

1//! MATLAB-compatible `ismember` builtin with GPU-aware semantics for RunMat.
2
3use std::collections::HashMap;
4
5use runmat_accelerate_api::{
6    GpuTensorHandle, GpuTensorStorage, HostLogicalOwned, HostTensorOwned,
7    IsMemberOptions as ProviderIsMemberOptions, IsMemberResult,
8};
9use runmat_builtins::{
10    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
11    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
12    CharArray, ComplexTensor, LogicalArray, StringArray, Tensor, Value,
13};
14use runmat_macros::runtime_builtin;
15
16use super::type_resolvers::logical_output_type;
17use crate::build_runtime_error;
18use crate::builtins::common::gpu_helpers;
19use crate::builtins::common::spec::{
20    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
21    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::tensor;
24
25#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::ismember")]
26pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
27    name: "ismember",
28    op_kind: GpuOpKind::Custom("ismember"),
29    supported_precisions: &[ScalarType::F32, ScalarType::F64],
30    broadcast: BroadcastSemantics::None,
31    provider_hooks: &[ProviderHook::Custom("ismember")],
32    constant_strategy: ConstantStrategy::InlineLiteral,
33    residency: ResidencyPolicy::GatherImmediately,
34    nan_mode: ReductionNaN::Include,
35    two_pass_threshold: None,
36    workgroup_size: None,
37    accepts_nan_mode: false,
38    notes: "Providers may supply dedicated membership kernels; until then RunMat gathers GPU tensors to host memory.",
39};
40
41#[runmat_macros::register_fusion_spec(
42    builtin_path = "crate::builtins::array::sorting_sets::ismember"
43)]
44pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
45    name: "ismember",
46    shape: ShapeRequirements::Any,
47    constant_strategy: ConstantStrategy::InlineLiteral,
48    elementwise: None,
49    reduction: None,
50    emits_nan: false,
51    notes: "`ismember` materialises logical outputs and terminates fusion chains; upstream tensors are gathered when necessary.",
52};
53
54const BUILTIN_NAME: &str = "ismember";
55
56const ISMEMBER_OUTPUT_MASK: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
57    name: "tf",
58    ty: BuiltinParamType::LogicalArray,
59    arity: BuiltinParamArity::Required,
60    default: None,
61    description: "Membership mask over A.",
62}];
63
64const ISMEMBER_OUTPUT_MASK_LOC: [BuiltinParamDescriptor; 2] = [
65    BuiltinParamDescriptor {
66        name: "tf",
67        ty: BuiltinParamType::LogicalArray,
68        arity: BuiltinParamArity::Required,
69        default: None,
70        description: "Membership mask over A.",
71    },
72    BuiltinParamDescriptor {
73        name: "loc",
74        ty: BuiltinParamType::NumericArray,
75        arity: BuiltinParamArity::Required,
76        default: None,
77        description: "First-match indices into B for each element/row in A (0 when absent).",
78    },
79];
80
81const ISMEMBER_INPUTS_A_B: [BuiltinParamDescriptor; 2] = [
82    BuiltinParamDescriptor {
83        name: "A",
84        ty: BuiltinParamType::Any,
85        arity: BuiltinParamArity::Required,
86        default: None,
87        description: "Values or rows to query.",
88    },
89    BuiltinParamDescriptor {
90        name: "B",
91        ty: BuiltinParamType::Any,
92        arity: BuiltinParamArity::Required,
93        default: None,
94        description: "Reference set of values or rows.",
95    },
96];
97
98const ISMEMBER_INPUTS_A_B_OPTIONS: [BuiltinParamDescriptor; 3] = [
99    BuiltinParamDescriptor {
100        name: "A",
101        ty: BuiltinParamType::Any,
102        arity: BuiltinParamArity::Required,
103        default: None,
104        description: "Values or rows to query.",
105    },
106    BuiltinParamDescriptor {
107        name: "B",
108        ty: BuiltinParamType::Any,
109        arity: BuiltinParamArity::Required,
110        default: None,
111        description: "Reference set of values or rows.",
112    },
113    BuiltinParamDescriptor {
114        name: "option",
115        ty: BuiltinParamType::StringScalar,
116        arity: BuiltinParamArity::Variadic,
117        default: None,
118        description: "Option tokens: 'rows'.",
119    },
120];
121
122const ISMEMBER_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
123    BuiltinSignatureDescriptor {
124        label: "tf = ismember(A, B)",
125        inputs: &ISMEMBER_INPUTS_A_B,
126        outputs: &ISMEMBER_OUTPUT_MASK,
127    },
128    BuiltinSignatureDescriptor {
129        label: "tf = ismember(A, B, option...)",
130        inputs: &ISMEMBER_INPUTS_A_B_OPTIONS,
131        outputs: &ISMEMBER_OUTPUT_MASK,
132    },
133    BuiltinSignatureDescriptor {
134        label: "[tf, loc] = ismember(A, B)",
135        inputs: &ISMEMBER_INPUTS_A_B,
136        outputs: &ISMEMBER_OUTPUT_MASK_LOC,
137    },
138    BuiltinSignatureDescriptor {
139        label: "[tf, loc] = ismember(A, B, option...)",
140        inputs: &ISMEMBER_INPUTS_A_B_OPTIONS,
141        outputs: &ISMEMBER_OUTPUT_MASK_LOC,
142    },
143];
144
145const ISMEMBER_ERROR_LEGACY_OPTION_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
146    code: "RM.ISMEMBER.LEGACY_OPTION_UNSUPPORTED",
147    identifier: Some("RunMat:ismember:LegacyOptionUnsupported"),
148    when: "Legacy compatibility options are requested.",
149    message: "ismember: the 'legacy' behaviour is not supported",
150};
151
152const ISMEMBER_ERROR_UNKNOWN_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
153    code: "RM.ISMEMBER.UNKNOWN_OPTION",
154    identifier: Some("RunMat:ismember:UnknownOption"),
155    when: "An unsupported option token is provided.",
156    message: "ismember: unrecognised option",
157};
158
159const ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
160    code: "RM.ISMEMBER.ROWS_COLUMN_MISMATCH",
161    identifier: Some("RunMat:ismember:RowsColumnMismatch"),
162    when: "'rows' mode is used and column counts differ.",
163    message: "ismember: inputs must have the same number of columns when using 'rows'",
164};
165
166const ISMEMBER_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
167    code: "RM.ISMEMBER.INVALID_ARGUMENT",
168    identifier: Some("RunMat:ismember:InvalidArgument"),
169    when: "Option arguments are not string-like where required.",
170    message: "ismember: expected string option arguments",
171};
172
173const ISMEMBER_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
174    code: "RM.ISMEMBER.INTERNAL",
175    identifier: Some("RunMat:ismember:Internal"),
176    when: "Internal conversion/allocation/provider decode fails.",
177    message: "ismember: internal operation failed",
178};
179
180const ISMEMBER_ERRORS: [BuiltinErrorDescriptor; 5] = [
181    ISMEMBER_ERROR_LEGACY_OPTION_UNSUPPORTED,
182    ISMEMBER_ERROR_UNKNOWN_OPTION,
183    ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH,
184    ISMEMBER_ERROR_INVALID_ARGUMENT,
185    ISMEMBER_ERROR_INTERNAL,
186];
187
188pub const ISMEMBER_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
189    signatures: &ISMEMBER_SIGNATURES,
190    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
191    completion_policy: BuiltinCompletionPolicy::Public,
192    errors: &ISMEMBER_ERRORS,
193};
194
195fn ismember_error_with(
196    error: &'static BuiltinErrorDescriptor,
197    message: impl Into<String>,
198) -> crate::RuntimeError {
199    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
200    if let Some(identifier) = error.identifier {
201        builder = builder.with_identifier(identifier);
202    }
203    builder.build()
204}
205
206fn ismember_error(error: &'static BuiltinErrorDescriptor) -> crate::RuntimeError {
207    ismember_error_with(error, error.message)
208}
209
210fn ismember_internal_error(message: impl Into<String>) -> crate::RuntimeError {
211    ismember_error_with(&ISMEMBER_ERROR_INTERNAL, message)
212}
213
214#[runtime_builtin(
215    name = "ismember",
216    category = "array/sorting_sets",
217    summary = "Identify array elements or rows that appear in another array while returning first-match indices.",
218    keywords = "ismember,membership,set,rows,indices,gpu",
219    accel = "array_construct",
220    sink = true,
221    type_resolver(logical_output_type),
222    descriptor(crate::builtins::array::sorting_sets::ismember::ISMEMBER_DESCRIPTOR),
223    builtin_path = "crate::builtins::array::sorting_sets::ismember"
224)]
225async fn ismember_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
226    let eval = evaluate(a, b, &rest).await?;
227    if let Some(out_count) = crate::output_count::current_output_count() {
228        if out_count == 0 {
229            return Ok(Value::OutputList(Vec::new()));
230        }
231        if out_count == 1 {
232            return Ok(Value::OutputList(vec![eval.into_mask_value()]));
233        }
234        let (mask, loc) = eval.into_pair();
235        return Ok(crate::output_count::output_list_with_padding(
236            out_count,
237            vec![mask, loc],
238        ));
239    }
240    Ok(eval.into_mask_value())
241}
242
243/// Evaluate the `ismember` builtin once and expose all outputs.
244pub async fn evaluate(
245    a: Value,
246    b: Value,
247    rest: &[Value],
248) -> crate::BuiltinResult<IsMemberEvaluation> {
249    let opts = parse_options(rest)?;
250    match (a, b) {
251        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
252            ismember_gpu_pair(handle_a, handle_b, &opts).await
253        }
254        (Value::GpuTensor(handle_a), other) => {
255            ismember_gpu_mixed(handle_a, other, &opts, true).await
256        }
257        (other, Value::GpuTensor(handle_b)) => {
258            ismember_gpu_mixed(handle_b, other, &opts, false).await
259        }
260        (left, right) => ismember_host(left, right, &opts),
261    }
262}
263
264#[derive(Debug, Clone, Copy)]
265struct IsMemberOptions {
266    rows: bool,
267}
268
269impl IsMemberOptions {
270    fn into_provider_options(self) -> ProviderIsMemberOptions {
271        ProviderIsMemberOptions { rows: self.rows }
272    }
273}
274
275fn parse_options(rest: &[Value]) -> crate::BuiltinResult<IsMemberOptions> {
276    let mut opts = IsMemberOptions { rows: false };
277    for arg in rest {
278        let text = tensor::value_to_string(arg)
279            .ok_or_else(|| ismember_error(&ISMEMBER_ERROR_INVALID_ARGUMENT))?;
280        let lowered = text.trim().to_ascii_lowercase();
281        match lowered.as_str() {
282            "rows" => opts.rows = true,
283            "legacy" | "r2012a" => {
284                return Err(ismember_error(&ISMEMBER_ERROR_LEGACY_OPTION_UNSUPPORTED))
285            }
286            other => {
287                return Err(ismember_error_with(
288                    &ISMEMBER_ERROR_UNKNOWN_OPTION,
289                    format!("ismember: unrecognised option '{other}'"),
290                ))
291            }
292        }
293    }
294    Ok(opts)
295}
296
297async fn ismember_gpu_pair(
298    handle_a: GpuTensorHandle,
299    handle_b: GpuTensorHandle,
300    opts: &IsMemberOptions,
301) -> crate::BuiltinResult<IsMemberEvaluation> {
302    if let Some(provider) = runmat_accelerate_api::provider() {
303        let provider_opts = opts.into_provider_options();
304        match provider
305            .ismember(&handle_a, &handle_b, &provider_opts)
306            .await
307        {
308            Ok(result) => return IsMemberEvaluation::from_provider_result(result),
309            Err(_) => {
310                // Fall back to host gather when the provider lacks an ismember implementation.
311            }
312        }
313    }
314    let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
315    let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
316    ismember_numeric_tensors(tensor_a, tensor_b, opts)
317}
318
319async fn ismember_gpu_mixed(
320    handle_gpu: GpuTensorHandle,
321    other: Value,
322    opts: &IsMemberOptions,
323    gpu_is_a: bool,
324) -> crate::BuiltinResult<IsMemberEvaluation> {
325    let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
326    if gpu_is_a {
327        ismember_host(Value::Tensor(tensor_gpu), other, opts)
328    } else {
329        ismember_host(other, Value::Tensor(tensor_gpu), opts)
330    }
331}
332
333fn ismember_host(
334    a: Value,
335    b: Value,
336    opts: &IsMemberOptions,
337) -> crate::BuiltinResult<IsMemberEvaluation> {
338    match (a, b) {
339        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => ismember_complex(at, bt, opts.rows),
340        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
341            let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
342                .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
343            ismember_complex(at, bt, opts.rows)
344        }
345        (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
346            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
347                .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
348            ismember_complex(at, bt, opts.rows)
349        }
350        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
351            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
352                .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
353            let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
354                .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
355            ismember_complex(at, bt, opts.rows)
356        }
357
358        (Value::CharArray(ac), Value::CharArray(bc)) => ismember_char(ac, bc, opts.rows),
359
360        (Value::StringArray(astring), Value::StringArray(bstring)) => {
361            ismember_string(astring, bstring, opts.rows)
362        }
363        (Value::StringArray(astring), Value::String(b)) => {
364            let bstring = StringArray::new(vec![b], vec![1, 1])
365                .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
366            ismember_string(astring, bstring, opts.rows)
367        }
368        (Value::String(a), Value::StringArray(bstring)) => {
369            let astring = StringArray::new(vec![a], vec![1, 1])
370                .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
371            ismember_string(astring, bstring, opts.rows)
372        }
373        (Value::String(a), Value::String(b)) => {
374            let astring = StringArray::new(vec![a], vec![1, 1])
375                .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
376            let bstring = StringArray::new(vec![b], vec![1, 1])
377                .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
378            ismember_string(astring, bstring, opts.rows)
379        }
380
381        (left, right) => {
382            let tensor_a = tensor::value_into_tensor_for("ismember", left)
383                .map_err(|e| ismember_internal_error(e))?;
384            let tensor_b = tensor::value_into_tensor_for("ismember", right)
385                .map_err(|e| ismember_internal_error(e))?;
386            ismember_numeric_tensors(tensor_a, tensor_b, opts)
387        }
388    }
389}
390
391fn ismember_numeric_tensors(
392    a: Tensor,
393    b: Tensor,
394    opts: &IsMemberOptions,
395) -> crate::BuiltinResult<IsMemberEvaluation> {
396    if opts.rows {
397        ismember_numeric_rows(a, b)
398    } else {
399        ismember_numeric_elements(a, b)
400    }
401}
402
403/// Helper exposed for acceleration providers handling numeric tensors on the host.
404pub fn ismember_numeric_from_tensors(
405    a: Tensor,
406    b: Tensor,
407    rows: bool,
408) -> crate::BuiltinResult<IsMemberEvaluation> {
409    let opts = IsMemberOptions { rows };
410    ismember_numeric_tensors(a, b, &opts)
411}
412
413fn ismember_numeric_elements(a: Tensor, b: Tensor) -> crate::BuiltinResult<IsMemberEvaluation> {
414    let mut map: HashMap<u64, usize> = HashMap::new();
415    for (idx, &value) in b.data.iter().enumerate() {
416        map.entry(canonicalize_f64(value)).or_insert(idx + 1);
417    }
418
419    let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
420    let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
421
422    for &value in &a.data {
423        let key = canonicalize_f64(value);
424        if let Some(&pos) = map.get(&key) {
425            mask_data.push(1);
426            loc_data.push(pos as f64);
427        } else {
428            mask_data.push(0);
429            loc_data.push(0.0);
430        }
431    }
432
433    let logical = LogicalArray::new(mask_data, a.shape.clone())
434        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
435    let loc_tensor = Tensor::new(loc_data, a.shape.clone())
436        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
437    Ok(IsMemberEvaluation::new(logical, loc_tensor))
438}
439
440fn ismember_numeric_rows(a: Tensor, b: Tensor) -> crate::BuiltinResult<IsMemberEvaluation> {
441    let (rows_a, cols_a) = tensor_rows_cols(&a, "ismember")?;
442    let (rows_b, cols_b) = tensor_rows_cols(&b, "ismember")?;
443    if cols_a != cols_b {
444        return Err(ismember_error(&ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH));
445    }
446
447    let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
448    for r in 0..rows_b {
449        let mut row_values = Vec::with_capacity(cols_b);
450        for c in 0..cols_b {
451            let idx = r + c * rows_b;
452            row_values.push(b.data[idx]);
453        }
454        let key = NumericRowKey::from_slice(&row_values);
455        map.entry(key).or_insert(r + 1);
456    }
457
458    let mut mask_data = vec![0u8; rows_a];
459    let mut loc_data = vec![0.0f64; rows_a];
460
461    for r in 0..rows_a {
462        let mut row_values = Vec::with_capacity(cols_a);
463        for c in 0..cols_a {
464            let idx = r + c * rows_a;
465            row_values.push(a.data[idx]);
466        }
467        let key = NumericRowKey::from_slice(&row_values);
468        if let Some(&pos) = map.get(&key) {
469            mask_data[r] = 1;
470            loc_data[r] = pos as f64;
471        }
472    }
473
474    let shape = vec![rows_a, 1];
475    let logical = LogicalArray::new(mask_data, shape.clone())
476        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
477    let loc_tensor = Tensor::new(loc_data, shape)
478        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
479    Ok(IsMemberEvaluation::new(logical, loc_tensor))
480}
481
482fn ismember_complex(
483    a: ComplexTensor,
484    b: ComplexTensor,
485    rows: bool,
486) -> crate::BuiltinResult<IsMemberEvaluation> {
487    if rows {
488        ismember_complex_rows(a, b)
489    } else {
490        ismember_complex_elements(a, b)
491    }
492}
493
494fn ismember_complex_elements(
495    a: ComplexTensor,
496    b: ComplexTensor,
497) -> crate::BuiltinResult<IsMemberEvaluation> {
498    let mut map: HashMap<ComplexKey, usize> = HashMap::new();
499    for (idx, &value) in b.data.iter().enumerate() {
500        map.entry(ComplexKey::new(value)).or_insert(idx + 1);
501    }
502
503    let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
504    let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
505
506    for &value in &a.data {
507        let key = ComplexKey::new(value);
508        if let Some(&pos) = map.get(&key) {
509            mask_data.push(1);
510            loc_data.push(pos as f64);
511        } else {
512            mask_data.push(0);
513            loc_data.push(0.0);
514        }
515    }
516
517    let logical = LogicalArray::new(mask_data, a.shape.clone())
518        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
519    let loc_tensor = Tensor::new(loc_data, a.shape.clone())
520        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
521    Ok(IsMemberEvaluation::new(logical, loc_tensor))
522}
523
524fn ismember_complex_rows(
525    a: ComplexTensor,
526    b: ComplexTensor,
527) -> crate::BuiltinResult<IsMemberEvaluation> {
528    let (rows_a, cols_a) = complex_rows_cols(&a)?;
529    let (rows_b, cols_b) = complex_rows_cols(&b)?;
530    if cols_a != cols_b {
531        return Err(ismember_error(&ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH).into());
532    }
533
534    let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
535    for r in 0..rows_b {
536        let mut row_keys = Vec::with_capacity(cols_b);
537        for c in 0..cols_b {
538            let idx = r + c * rows_b;
539            row_keys.push(ComplexKey::new(b.data[idx]));
540        }
541        map.entry(row_keys).or_insert(r + 1);
542    }
543
544    let mut mask_data = vec![0u8; rows_a];
545    let mut loc_data = vec![0.0f64; rows_a];
546
547    for r in 0..rows_a {
548        let mut row_keys = Vec::with_capacity(cols_a);
549        for c in 0..cols_a {
550            let idx = r + c * rows_a;
551            row_keys.push(ComplexKey::new(a.data[idx]));
552        }
553        if let Some(&pos) = map.get(&row_keys) {
554            mask_data[r] = 1;
555            loc_data[r] = pos as f64;
556        }
557    }
558
559    let shape = vec![rows_a, 1];
560    let logical = LogicalArray::new(mask_data, shape.clone())
561        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
562    let loc_tensor = Tensor::new(loc_data, shape)
563        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
564    Ok(IsMemberEvaluation::new(logical, loc_tensor))
565}
566
567fn ismember_char(
568    a: CharArray,
569    b: CharArray,
570    rows: bool,
571) -> crate::BuiltinResult<IsMemberEvaluation> {
572    if rows {
573        ismember_char_rows(a, b)
574    } else {
575        ismember_char_elements(a, b)
576    }
577}
578
579fn ismember_char_elements(a: CharArray, b: CharArray) -> crate::BuiltinResult<IsMemberEvaluation> {
580    let rows_b = b.rows;
581    let cols_b = b.cols;
582    let mut map: HashMap<char, usize> = HashMap::new();
583
584    for col in 0..cols_b {
585        for row in 0..rows_b {
586            let data_idx = row * cols_b + col;
587            let ch = b.data[data_idx];
588            let linear_idx = row + col * rows_b;
589            map.entry(ch).or_insert(linear_idx + 1);
590        }
591    }
592
593    let rows_a = a.rows;
594    let cols_a = a.cols;
595    let mut mask_data = vec![0u8; rows_a * cols_a];
596    let mut loc_data = vec![0.0f64; rows_a * cols_a];
597
598    for col in 0..cols_a {
599        for row in 0..rows_a {
600            let data_idx = row * cols_a + col;
601            let ch = a.data[data_idx];
602            let linear_idx = row + col * rows_a;
603            if let Some(&pos) = map.get(&ch) {
604                mask_data[linear_idx] = 1;
605                loc_data[linear_idx] = pos as f64;
606            }
607        }
608    }
609
610    let shape = vec![rows_a, cols_a];
611    let logical = LogicalArray::new(mask_data, shape.clone())
612        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
613    let loc_tensor = Tensor::new(loc_data, shape)
614        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
615    Ok(IsMemberEvaluation::new(logical, loc_tensor))
616}
617
618fn ismember_char_rows(a: CharArray, b: CharArray) -> crate::BuiltinResult<IsMemberEvaluation> {
619    if a.cols != b.cols {
620        return Err(ismember_error(&ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH).into());
621    }
622
623    let rows_b = b.rows;
624    let cols = b.cols;
625    let mut map: HashMap<RowCharKey, usize> = HashMap::new();
626
627    for r in 0..rows_b {
628        let mut row_values = Vec::with_capacity(cols);
629        for c in 0..cols {
630            let idx = r * cols + c;
631            row_values.push(b.data[idx]);
632        }
633        let key = RowCharKey::from_slice(&row_values);
634        map.entry(key).or_insert(r + 1);
635    }
636
637    let rows_a = a.rows;
638    let mut mask_data = vec![0u8; rows_a];
639    let mut loc_data = vec![0.0f64; rows_a];
640
641    for r in 0..rows_a {
642        let mut row_values = Vec::with_capacity(cols);
643        for c in 0..cols {
644            let idx = r * cols + c;
645            row_values.push(a.data[idx]);
646        }
647        let key = RowCharKey::from_slice(&row_values);
648        if let Some(&pos) = map.get(&key) {
649            mask_data[r] = 1;
650            loc_data[r] = pos as f64;
651        }
652    }
653
654    let shape = vec![rows_a, 1];
655    let logical = LogicalArray::new(mask_data, shape.clone())
656        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
657    let loc_tensor = Tensor::new(loc_data, shape)
658        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
659    Ok(IsMemberEvaluation::new(logical, loc_tensor))
660}
661
662fn ismember_string(
663    a: StringArray,
664    b: StringArray,
665    rows: bool,
666) -> crate::BuiltinResult<IsMemberEvaluation> {
667    if rows {
668        ismember_string_rows(a, b)
669    } else {
670        ismember_string_elements(a, b)
671    }
672}
673
674fn ismember_string_elements(
675    a: StringArray,
676    b: StringArray,
677) -> crate::BuiltinResult<IsMemberEvaluation> {
678    let mut map: HashMap<String, usize> = HashMap::new();
679    for (idx, value) in b.data.iter().enumerate() {
680        map.entry(value.clone()).or_insert(idx + 1);
681    }
682
683    let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
684    let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
685
686    for value in &a.data {
687        if let Some(&pos) = map.get(value) {
688            mask_data.push(1);
689            loc_data.push(pos as f64);
690        } else {
691            mask_data.push(0);
692            loc_data.push(0.0);
693        }
694    }
695
696    let logical = LogicalArray::new(mask_data, a.shape.clone())
697        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
698    let loc_tensor = Tensor::new(loc_data, a.shape.clone())
699        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
700    Ok(IsMemberEvaluation::new(logical, loc_tensor))
701}
702
703fn ismember_string_rows(
704    a: StringArray,
705    b: StringArray,
706) -> crate::BuiltinResult<IsMemberEvaluation> {
707    if a.shape.len() != 2 || b.shape.len() != 2 {
708        return Err(ismember_internal_error(
709            "ismember: 'rows' option requires 2-D string arrays",
710        ));
711    }
712    if a.shape[1] != b.shape[1] {
713        return Err(ismember_error(&ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH).into());
714    }
715
716    let rows_a = a.shape[0];
717    let cols = a.shape[1];
718    let rows_b = b.shape[0];
719
720    let mut map: HashMap<RowStringKey, usize> = HashMap::new();
721    for r in 0..rows_b {
722        let mut row_values = Vec::with_capacity(cols);
723        for c in 0..cols {
724            let idx = r + c * rows_b;
725            row_values.push(b.data[idx].clone());
726        }
727        let key = RowStringKey(row_values);
728        map.entry(key).or_insert(r + 1);
729    }
730
731    let mut mask_data = vec![0u8; rows_a];
732    let mut loc_data = vec![0.0f64; rows_a];
733
734    for r in 0..rows_a {
735        let mut row_values = Vec::with_capacity(cols);
736        for c in 0..cols {
737            let idx = r + c * rows_a;
738            row_values.push(a.data[idx].clone());
739        }
740        let key = RowStringKey(row_values);
741        if let Some(&pos) = map.get(&key) {
742            mask_data[r] = 1;
743            loc_data[r] = pos as f64;
744        }
745    }
746
747    let shape = vec![rows_a, 1];
748    let logical = LogicalArray::new(mask_data, shape.clone())
749        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
750    let loc_tensor = Tensor::new(loc_data, shape)
751        .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
752    Ok(IsMemberEvaluation::new(logical, loc_tensor))
753}
754
755fn tensor_rows_cols(t: &Tensor, name: &str) -> crate::BuiltinResult<(usize, usize)> {
756    match t.shape.len() {
757        0 => Ok((1, 1)),
758        1 => Ok((t.shape[0], 1)),
759        2 => Ok((t.shape[0], t.shape[1])),
760        _ => Err(ismember_internal_error(format!(
761            "{name}: 'rows' option requires 2-D numeric matrices"
762        ))
763        .into()),
764    }
765}
766
767fn complex_rows_cols(t: &ComplexTensor) -> crate::BuiltinResult<(usize, usize)> {
768    match t.shape.len() {
769        0 => Ok((1, 1)),
770        1 => Ok((t.shape[0], 1)),
771        2 => Ok((t.shape[0], t.shape[1])),
772        _ => Err(ismember_internal_error(
773            "ismember: 'rows' option requires 2-D complex matrices",
774        )),
775    }
776}
777
778#[derive(Debug, Clone, PartialEq, Eq, Hash)]
779struct NumericRowKey(Vec<u64>);
780
781impl NumericRowKey {
782    fn from_slice(values: &[f64]) -> Self {
783        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
784    }
785}
786
787#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
788struct ComplexKey {
789    re: u64,
790    im: u64,
791}
792
793impl ComplexKey {
794    fn new(value: (f64, f64)) -> Self {
795        Self {
796            re: canonicalize_f64(value.0),
797            im: canonicalize_f64(value.1),
798        }
799    }
800}
801
/// Hashable key for one char-array row, stored as Unicode scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Build a key from the characters of one row, in order.
    fn from_slice(values: &[char]) -> Self {
        Self(values.iter().map(|&ch| u32::from(ch)).collect())
    }
}

/// Hashable key for one string-array row: that row's strings in column order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
813
/// Map an `f64` to a hashable bit pattern under which equal-comparing values collide.
///
/// - `0.0` and `-0.0` both map to `0`.
/// - Every NaN (any sign or payload) maps to the single quiet-NaN pattern, so NaNs
///   match each other.
/// - All other values map to their raw IEEE-754 bits.
///
/// NOTE(review): MATLAB's `ismember` appears to treat NaN as unequal to itself
/// (`ismember(NaN, NaN)` returns false); confirm that matching NaNs here is intended.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    const CANONICAL_ZERO_BITS: u64 = 0;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        v if v == 0.0 => CANONICAL_ZERO_BITS,
        v => v.to_bits(),
    }
}
823
824#[derive(Debug, Clone)]
825pub struct IsMemberEvaluation {
826    mask: LogicalArray,
827    loc: Tensor,
828}
829
830impl IsMemberEvaluation {
831    fn new(mask: LogicalArray, loc: Tensor) -> Self {
832        Self { mask, loc }
833    }
834
835    pub fn from_provider_result(result: IsMemberResult) -> crate::BuiltinResult<Self> {
836        let mask = LogicalArray::new(result.mask.data, result.mask.shape)
837            .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
838        let loc = Tensor::new(result.loc.data, result.loc.shape)
839            .map_err(|e| ismember_internal_error(format!("ismember: {e}")))?;
840        Ok(IsMemberEvaluation::new(mask, loc))
841    }
842
843    pub fn into_numeric_ismember_result(self) -> crate::BuiltinResult<IsMemberResult> {
844        let IsMemberEvaluation { mask, loc } = self;
845        Ok(IsMemberResult {
846            mask: HostLogicalOwned {
847                data: mask.data,
848                shape: mask.shape,
849            },
850            loc: HostTensorOwned {
851                data: loc.data,
852                shape: loc.shape,
853                storage: GpuTensorStorage::Real,
854            },
855        })
856    }
857
858    pub fn into_mask_value(self) -> Value {
859        logical_array_into_value(self.mask)
860    }
861
862    pub fn mask_value(&self) -> Value {
863        logical_array_into_value(self.mask.clone())
864    }
865
866    pub fn into_pair(self) -> (Value, Value) {
867        let mask = logical_array_into_value(self.mask);
868        let loc = tensor::tensor_into_value(self.loc);
869        (mask, loc)
870    }
871
872    pub fn loc_value(&self) -> Value {
873        tensor::tensor_into_value(self.loc.clone())
874    }
875}
876
877fn logical_array_into_value(logical: LogicalArray) -> Value {
878    if logical.data.len() == 1 {
879        Value::Bool(logical.data[0] != 0)
880    } else {
881        Value::LogicalArray(logical)
882    }
883}
884
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{ResolveContext, Tensor, Type};

    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;

    /// Drives the async `evaluate` entry point to completion on the current thread.
    fn evaluate_sync(
        a: Value,
        b: Value,
        rest: &[Value],
    ) -> crate::BuiltinResult<IsMemberEvaluation> {
        futures::executor::block_on(evaluate(a, b, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_membership_basic() {
        let a = Tensor::new(vec![5.0, 7.0, 2.0, 7.0], vec![1, 4]).unwrap();
        let b = Tensor::new(vec![7.0, 9.0, 5.0], vec![1, 3]).unwrap();
        let eval = ismember_numeric_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1, 0, 1]);
        assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0, 1.0]);
    }

    // Carries the same wasm32 attribute as every other test in this module.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ismember_type_resolver_logical() {
        assert_eq!(
            logical_output_type(
                &[Type::tensor(), Type::tensor()],
                &ResolveContext::new(Vec::new()),
            ),
            Type::logical()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_nan_membership() {
        let a = Tensor::new(vec![f64::NAN, 1.0], vec![1, 2]).unwrap();
        let b = Tensor::new(vec![f64::NAN, 2.0], vec![1, 2]).unwrap();
        let eval = ismember_numeric_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 0]);
        assert_eq!(eval.loc.data, vec![1.0, 0.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_rows_membership() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
        let eval = ismember_numeric_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1, 1]);
        assert_eq!(eval.loc.data, vec![3.0, 1.0, 3.0]);
        assert_eq!(eval.loc.shape, vec![3, 1]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn complex_membership() {
        let a = ComplexTensor::new(vec![(1.0, 2.0), (0.0, 0.0)], vec![1, 2]).unwrap();
        let b = ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0)], vec![1, 2]).unwrap();
        let eval = ismember_complex_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![2.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn complex_rows_membership() {
        let a = ComplexTensor::new(
            vec![(1.0, 1.0), (3.0, 0.0), (2.0, 0.0), (4.0, 4.0)],
            vec![2, 2],
        )
        .unwrap();
        let b = ComplexTensor::new(
            vec![
                (1.0, 1.0),
                (5.0, 0.0),
                (3.0, 0.0),
                (2.0, 0.0),
                (6.0, 0.0),
                (4.0, 4.0),
            ],
            vec![3, 2],
        )
        .unwrap();
        let eval = ismember_complex_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn char_membership() {
        let a = CharArray::new(vec!['r', 'u', 'n', 'm'], 2, 2).unwrap();
        let b = CharArray::new(vec!['m', 'a', 'r', 'u'], 2, 2).unwrap();
        let eval = ismember_char_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 0, 1, 1]);
        assert_eq!(eval.loc.data, vec![2.0, 0.0, 4.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn char_rows_membership() {
        let a = CharArray::new(vec!['m', 'a', 't', 'l'], 2, 2).unwrap();
        let b = CharArray::new(vec!['m', 'a', 'g', 'e', 't', 'l'], 3, 2).unwrap();
        let eval = ismember_char_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn string_membership() {
        let a = StringArray::new(
            vec![
                "apple".to_string(),
                "pear".to_string(),
                "banana".to_string(),
            ],
            vec![1, 3],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "pear".to_string(),
                "orange".to_string(),
                "apple".to_string(),
            ],
            vec![1, 3],
        )
        .unwrap();
        let eval = ismember_string_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1, 0]);
        assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn string_rows_membership() {
        let a = StringArray::new(
            vec![
                "alpha".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "delta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "alpha".to_string(),
                "theta".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "eta".to_string(),
                "delta".to_string(),
            ],
            vec![3, 2],
        )
        .unwrap();
        let eval = ismember_string_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn options_reject_legacy() {
        let err = parse_options(&[Value::from("legacy")]).unwrap_err();
        assert_eq!(
            err.identifier(),
            ISMEMBER_ERROR_LEGACY_OPTION_UNSUPPORTED.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_unknown_option() {
        let err =
            evaluate_sync(Value::Num(1.0), Value::Num(1.0), &[Value::from("stable")]).unwrap_err();
        assert_eq!(err.identifier(), ISMEMBER_ERROR_UNKNOWN_OPTION.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ismember_runtime_numeric() {
        let a = Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap());
        let b = Value::Tensor(Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap());
        let (mask, loc) = evaluate_sync(a, b, &[]).unwrap().into_pair();
        match mask {
            Value::LogicalArray(arr) => assert_eq!(arr.data, vec![1, 0, 1]),
            other => panic!("expected logical array, got {other:?}"),
        }
        match loc {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 0.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_inputs_promoted() {
        let a = Value::Bool(true);
        let logical_b =
            LogicalArray::new(vec![1, 0], vec![2, 1]).expect("logical array construction");
        let eval = evaluate_sync(a, Value::LogicalArray(logical_b), &[]).expect("ismember");
        assert_eq!(eval.mask_value(), Value::Bool(true));
        assert_eq!(eval.loc_value(), Value::Num(1.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ismember_rows_shape_checks() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        assert!(ismember_numeric_rows(a.clone(), b.clone()).is_ok());
        let bad = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let err = ismember_numeric_rows(a, bad).unwrap_err();
        assert_eq!(
            err.identifier(),
            ISMEMBER_ERROR_ROWS_COLUMN_MISMATCH.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ismember_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
            let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
            let view_a = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let view_b = runmat_accelerate_api::HostTensorView {
                data: &set.data,
                shape: &set.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload a");
            let handle_b = provider.upload(&view_b).expect("upload b");
            // Clone the handles into the call so they can be released afterwards,
            // matching the cleanup done by the other GPU tests.
            let eval = evaluate_sync(
                Value::GpuTensor(handle_a.clone()),
                Value::GpuTensor(handle_b.clone()),
                &[],
            )
            .expect("ismember");
            assert_eq!(eval.mask.data, vec![0, 1, 0, 1]);
            assert_eq!(eval.loc.data, vec![0.0, 1.0, 0.0, 1.0]);
            let _ = provider.free(&handle_a);
            let _ = provider.free(&handle_b);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ismember_gpu_rows_roundtrip() {
        test_support::with_test_provider(|provider| {
            let rows = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
            let bank = Tensor::new(vec![1.0, 5.0, 3.0, 2.0, 6.0, 4.0], vec![3, 2]).unwrap();
            let view_a = runmat_accelerate_api::HostTensorView {
                data: &rows.data,
                shape: &rows.shape,
            };
            let view_b = runmat_accelerate_api::HostTensorView {
                data: &bank.data,
                shape: &bank.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload a");
            let handle_b = provider.upload(&view_b).expect("upload b");
            let eval = evaluate_sync(
                Value::GpuTensor(handle_a.clone()),
                Value::GpuTensor(handle_b.clone()),
                &[Value::from("rows")],
            )
            .expect("ismember");
            assert_eq!(eval.mask.data, vec![1, 1]);
            assert_eq!(eval.loc.data, vec![1.0, 3.0]);
            let _ = provider.free(&handle_a);
            let _ = provider.free(&handle_b);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn ismember_wgpu_numeric_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
        let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let cpu_eval =
            ismember_numeric_from_tensors(tensor.clone(), set.clone(), false).expect("cpu");

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let view_b = HostTensorView {
            data: &set.data,
            shape: &set.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload a");
        let handle_b = provider.upload(&view_b).expect("upload b");

        let eval = evaluate_sync(
            Value::GpuTensor(handle_a.clone()),
            Value::GpuTensor(handle_b.clone()),
            &[],
        )
        .expect("gpu evaluate");
        assert_eq!(eval.mask.data, cpu_eval.mask.data);
        assert_eq!(eval.loc.data, cpu_eval.loc.data);

        let _ = provider.free(&handle_a);
        let _ = provider.free(&handle_b);

        let matrix = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let bank = Tensor::new(vec![1.0, 7.0, 3.0, 2.0, 9.0, 4.0], vec![3, 2]).unwrap();
        let cpu_rows =
            ismember_numeric_from_tensors(matrix.clone(), bank.clone(), true).expect("cpu rows");
        let view_matrix = HostTensorView {
            data: &matrix.data,
            shape: &matrix.shape,
        };
        let view_bank = HostTensorView {
            data: &bank.data,
            shape: &bank.shape,
        };
        let handle_matrix = provider.upload(&view_matrix).expect("upload matrix");
        let handle_bank = provider.upload(&view_bank).expect("upload bank");
        let eval_rows = evaluate_sync(
            Value::GpuTensor(handle_matrix.clone()),
            Value::GpuTensor(handle_bank.clone()),
            &[Value::from("rows")],
        )
        .expect("gpu rows evaluate");
        assert_eq!(eval_rows.mask.data, cpu_rows.mask.data);
        assert_eq!(eval_rows.loc.data, cpu_rows.loc.data);
        let _ = provider.free(&handle_matrix);
        let _ = provider.free(&handle_bank);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn scalar_return_is_bool() {
        let a = Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());
        let b = Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());
        let mask = evaluate_sync(a, b, &[]).unwrap().into_mask_value();
        assert_eq!(mask, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn parse_rows_option() {
        let opts = parse_options(&[Value::from("rows")]).unwrap();
        assert!(opts.rows);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_rows_with_nan() {
        let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
        let b = Tensor::new(vec![f64::NAN, 2.0], vec![2, 1]).unwrap();
        let eval = ismember_numeric_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 0]);
        assert_eq!(eval.loc.data, vec![1.0, 0.0]);
    }
}