Skip to main content

runmat_runtime/builtins/array/sorting_sets/
ismember.rs

1//! MATLAB-compatible `ismember` builtin with GPU-aware semantics for RunMat.
2
3use std::collections::HashMap;
4
5use runmat_accelerate_api::{
6    GpuTensorHandle, GpuTensorStorage, HostLogicalOwned, HostTensorOwned,
7    IsMemberOptions as ProviderIsMemberOptions, IsMemberResult,
8};
9use runmat_builtins::{CharArray, ComplexTensor, LogicalArray, StringArray, Tensor, Value};
10use runmat_macros::runtime_builtin;
11
12use super::type_resolvers::logical_output_type;
13use crate::build_runtime_error;
14use crate::builtins::common::gpu_helpers;
15use crate::builtins::common::spec::{
16    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::tensor;
20
21#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::ismember")]
22pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
23    name: "ismember",
24    op_kind: GpuOpKind::Custom("ismember"),
25    supported_precisions: &[ScalarType::F32, ScalarType::F64],
26    broadcast: BroadcastSemantics::None,
27    provider_hooks: &[ProviderHook::Custom("ismember")],
28    constant_strategy: ConstantStrategy::InlineLiteral,
29    residency: ResidencyPolicy::GatherImmediately,
30    nan_mode: ReductionNaN::Include,
31    two_pass_threshold: None,
32    workgroup_size: None,
33    accepts_nan_mode: false,
34    notes: "Providers may supply dedicated membership kernels; until then RunMat gathers GPU tensors to host memory.",
35};
36
37#[runmat_macros::register_fusion_spec(
38    builtin_path = "crate::builtins::array::sorting_sets::ismember"
39)]
40pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
41    name: "ismember",
42    shape: ShapeRequirements::Any,
43    constant_strategy: ConstantStrategy::InlineLiteral,
44    elementwise: None,
45    reduction: None,
46    emits_nan: false,
47    notes: "`ismember` materialises logical outputs and terminates fusion chains; upstream tensors are gathered when necessary.",
48};
49
50fn ismember_error(message: impl Into<String>) -> crate::RuntimeError {
51    build_runtime_error(message)
52        .with_builtin("ismember")
53        .build()
54}
55
56#[runtime_builtin(
57    name = "ismember",
58    category = "array/sorting_sets",
59    summary = "Identify array elements or rows that appear in another array while returning first-match indices.",
60    keywords = "ismember,membership,set,rows,indices,gpu",
61    accel = "array_construct",
62    sink = true,
63    type_resolver(logical_output_type),
64    builtin_path = "crate::builtins::array::sorting_sets::ismember"
65)]
66async fn ismember_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
67    let eval = evaluate(a, b, &rest).await?;
68    if let Some(out_count) = crate::output_count::current_output_count() {
69        if out_count == 0 {
70            return Ok(Value::OutputList(Vec::new()));
71        }
72        if out_count == 1 {
73            return Ok(Value::OutputList(vec![eval.into_mask_value()]));
74        }
75        let (mask, loc) = eval.into_pair();
76        return Ok(crate::output_count::output_list_with_padding(
77            out_count,
78            vec![mask, loc],
79        ));
80    }
81    Ok(eval.into_mask_value())
82}
83
84/// Evaluate the `ismember` builtin once and expose all outputs.
85pub async fn evaluate(
86    a: Value,
87    b: Value,
88    rest: &[Value],
89) -> crate::BuiltinResult<IsMemberEvaluation> {
90    let opts = parse_options(rest)?;
91    match (a, b) {
92        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
93            ismember_gpu_pair(handle_a, handle_b, &opts).await
94        }
95        (Value::GpuTensor(handle_a), other) => {
96            ismember_gpu_mixed(handle_a, other, &opts, true).await
97        }
98        (other, Value::GpuTensor(handle_b)) => {
99            ismember_gpu_mixed(handle_b, other, &opts, false).await
100        }
101        (left, right) => ismember_host(left, right, &opts),
102    }
103}
104
105#[derive(Debug, Clone, Copy)]
106struct IsMemberOptions {
107    rows: bool,
108}
109
110impl IsMemberOptions {
111    fn into_provider_options(self) -> ProviderIsMemberOptions {
112        ProviderIsMemberOptions { rows: self.rows }
113    }
114}
115
116fn parse_options(rest: &[Value]) -> crate::BuiltinResult<IsMemberOptions> {
117    let mut opts = IsMemberOptions { rows: false };
118    for arg in rest {
119        let text = tensor::value_to_string(arg)
120            .ok_or_else(|| ismember_error("ismember: expected string option arguments"))?;
121        let lowered = text.trim().to_ascii_lowercase();
122        match lowered.as_str() {
123            "rows" => opts.rows = true,
124            "legacy" | "r2012a" => {
125                return Err(ismember_error(
126                    "ismember: the 'legacy' behaviour is not supported",
127                ))
128            }
129            other => {
130                return Err(ismember_error(format!(
131                    "ismember: unrecognised option '{other}'"
132                )))
133            }
134        }
135    }
136    Ok(opts)
137}
138
139async fn ismember_gpu_pair(
140    handle_a: GpuTensorHandle,
141    handle_b: GpuTensorHandle,
142    opts: &IsMemberOptions,
143) -> crate::BuiltinResult<IsMemberEvaluation> {
144    if let Some(provider) = runmat_accelerate_api::provider() {
145        let provider_opts = opts.into_provider_options();
146        match provider
147            .ismember(&handle_a, &handle_b, &provider_opts)
148            .await
149        {
150            Ok(result) => return IsMemberEvaluation::from_provider_result(result),
151            Err(_) => {
152                // Fall back to host gather when the provider lacks an ismember implementation.
153            }
154        }
155    }
156    let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
157    let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
158    ismember_numeric_tensors(tensor_a, tensor_b, opts)
159}
160
161async fn ismember_gpu_mixed(
162    handle_gpu: GpuTensorHandle,
163    other: Value,
164    opts: &IsMemberOptions,
165    gpu_is_a: bool,
166) -> crate::BuiltinResult<IsMemberEvaluation> {
167    let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
168    if gpu_is_a {
169        ismember_host(Value::Tensor(tensor_gpu), other, opts)
170    } else {
171        ismember_host(other, Value::Tensor(tensor_gpu), opts)
172    }
173}
174
175fn ismember_host(
176    a: Value,
177    b: Value,
178    opts: &IsMemberOptions,
179) -> crate::BuiltinResult<IsMemberEvaluation> {
180    match (a, b) {
181        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => ismember_complex(at, bt, opts.rows),
182        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
183            let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
184                .map_err(|e| ismember_error(format!("ismember: {e}")))?;
185            ismember_complex(at, bt, opts.rows)
186        }
187        (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
188            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
189                .map_err(|e| ismember_error(format!("ismember: {e}")))?;
190            ismember_complex(at, bt, opts.rows)
191        }
192        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
193            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
194                .map_err(|e| ismember_error(format!("ismember: {e}")))?;
195            let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
196                .map_err(|e| ismember_error(format!("ismember: {e}")))?;
197            ismember_complex(at, bt, opts.rows)
198        }
199
200        (Value::CharArray(ac), Value::CharArray(bc)) => ismember_char(ac, bc, opts.rows),
201
202        (Value::StringArray(astring), Value::StringArray(bstring)) => {
203            ismember_string(astring, bstring, opts.rows)
204        }
205        (Value::StringArray(astring), Value::String(b)) => {
206            let bstring = StringArray::new(vec![b], vec![1, 1])
207                .map_err(|e| ismember_error(format!("ismember: {e}")))?;
208            ismember_string(astring, bstring, opts.rows)
209        }
210        (Value::String(a), Value::StringArray(bstring)) => {
211            let astring = StringArray::new(vec![a], vec![1, 1])
212                .map_err(|e| ismember_error(format!("ismember: {e}")))?;
213            ismember_string(astring, bstring, opts.rows)
214        }
215        (Value::String(a), Value::String(b)) => {
216            let astring = StringArray::new(vec![a], vec![1, 1])
217                .map_err(|e| ismember_error(format!("ismember: {e}")))?;
218            let bstring = StringArray::new(vec![b], vec![1, 1])
219                .map_err(|e| ismember_error(format!("ismember: {e}")))?;
220            ismember_string(astring, bstring, opts.rows)
221        }
222
223        (left, right) => {
224            let tensor_a =
225                tensor::value_into_tensor_for("ismember", left).map_err(|e| ismember_error(e))?;
226            let tensor_b =
227                tensor::value_into_tensor_for("ismember", right).map_err(|e| ismember_error(e))?;
228            ismember_numeric_tensors(tensor_a, tensor_b, opts)
229        }
230    }
231}
232
233fn ismember_numeric_tensors(
234    a: Tensor,
235    b: Tensor,
236    opts: &IsMemberOptions,
237) -> crate::BuiltinResult<IsMemberEvaluation> {
238    if opts.rows {
239        ismember_numeric_rows(a, b)
240    } else {
241        ismember_numeric_elements(a, b)
242    }
243}
244
245/// Helper exposed for acceleration providers handling numeric tensors on the host.
246pub fn ismember_numeric_from_tensors(
247    a: Tensor,
248    b: Tensor,
249    rows: bool,
250) -> crate::BuiltinResult<IsMemberEvaluation> {
251    let opts = IsMemberOptions { rows };
252    ismember_numeric_tensors(a, b, &opts)
253}
254
255fn ismember_numeric_elements(a: Tensor, b: Tensor) -> crate::BuiltinResult<IsMemberEvaluation> {
256    let mut map: HashMap<u64, usize> = HashMap::new();
257    for (idx, &value) in b.data.iter().enumerate() {
258        map.entry(canonicalize_f64(value)).or_insert(idx + 1);
259    }
260
261    let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
262    let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
263
264    for &value in &a.data {
265        let key = canonicalize_f64(value);
266        if let Some(&pos) = map.get(&key) {
267            mask_data.push(1);
268            loc_data.push(pos as f64);
269        } else {
270            mask_data.push(0);
271            loc_data.push(0.0);
272        }
273    }
274
275    let logical = LogicalArray::new(mask_data, a.shape.clone())
276        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
277    let loc_tensor = Tensor::new(loc_data, a.shape.clone())
278        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
279    Ok(IsMemberEvaluation::new(logical, loc_tensor))
280}
281
282fn ismember_numeric_rows(a: Tensor, b: Tensor) -> crate::BuiltinResult<IsMemberEvaluation> {
283    let (rows_a, cols_a) = tensor_rows_cols(&a, "ismember")?;
284    let (rows_b, cols_b) = tensor_rows_cols(&b, "ismember")?;
285    if cols_a != cols_b {
286        return Err(ismember_error(
287            "ismember: inputs must have the same number of columns when using 'rows'",
288        ));
289    }
290
291    let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
292    for r in 0..rows_b {
293        let mut row_values = Vec::with_capacity(cols_b);
294        for c in 0..cols_b {
295            let idx = r + c * rows_b;
296            row_values.push(b.data[idx]);
297        }
298        let key = NumericRowKey::from_slice(&row_values);
299        map.entry(key).or_insert(r + 1);
300    }
301
302    let mut mask_data = vec![0u8; rows_a];
303    let mut loc_data = vec![0.0f64; rows_a];
304
305    for r in 0..rows_a {
306        let mut row_values = Vec::with_capacity(cols_a);
307        for c in 0..cols_a {
308            let idx = r + c * rows_a;
309            row_values.push(a.data[idx]);
310        }
311        let key = NumericRowKey::from_slice(&row_values);
312        if let Some(&pos) = map.get(&key) {
313            mask_data[r] = 1;
314            loc_data[r] = pos as f64;
315        }
316    }
317
318    let shape = vec![rows_a, 1];
319    let logical = LogicalArray::new(mask_data, shape.clone())
320        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
321    let loc_tensor =
322        Tensor::new(loc_data, shape).map_err(|e| ismember_error(format!("ismember: {e}")))?;
323    Ok(IsMemberEvaluation::new(logical, loc_tensor))
324}
325
326fn ismember_complex(
327    a: ComplexTensor,
328    b: ComplexTensor,
329    rows: bool,
330) -> crate::BuiltinResult<IsMemberEvaluation> {
331    if rows {
332        ismember_complex_rows(a, b)
333    } else {
334        ismember_complex_elements(a, b)
335    }
336}
337
338fn ismember_complex_elements(
339    a: ComplexTensor,
340    b: ComplexTensor,
341) -> crate::BuiltinResult<IsMemberEvaluation> {
342    let mut map: HashMap<ComplexKey, usize> = HashMap::new();
343    for (idx, &value) in b.data.iter().enumerate() {
344        map.entry(ComplexKey::new(value)).or_insert(idx + 1);
345    }
346
347    let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
348    let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
349
350    for &value in &a.data {
351        let key = ComplexKey::new(value);
352        if let Some(&pos) = map.get(&key) {
353            mask_data.push(1);
354            loc_data.push(pos as f64);
355        } else {
356            mask_data.push(0);
357            loc_data.push(0.0);
358        }
359    }
360
361    let logical = LogicalArray::new(mask_data, a.shape.clone())
362        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
363    let loc_tensor = Tensor::new(loc_data, a.shape.clone())
364        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
365    Ok(IsMemberEvaluation::new(logical, loc_tensor))
366}
367
368fn ismember_complex_rows(
369    a: ComplexTensor,
370    b: ComplexTensor,
371) -> crate::BuiltinResult<IsMemberEvaluation> {
372    let (rows_a, cols_a) = complex_rows_cols(&a)?;
373    let (rows_b, cols_b) = complex_rows_cols(&b)?;
374    if cols_a != cols_b {
375        return Err(ismember_error(
376            "ismember: complex inputs must have the same number of columns when using 'rows'",
377        )
378        .into());
379    }
380
381    let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
382    for r in 0..rows_b {
383        let mut row_keys = Vec::with_capacity(cols_b);
384        for c in 0..cols_b {
385            let idx = r + c * rows_b;
386            row_keys.push(ComplexKey::new(b.data[idx]));
387        }
388        map.entry(row_keys).or_insert(r + 1);
389    }
390
391    let mut mask_data = vec![0u8; rows_a];
392    let mut loc_data = vec![0.0f64; rows_a];
393
394    for r in 0..rows_a {
395        let mut row_keys = Vec::with_capacity(cols_a);
396        for c in 0..cols_a {
397            let idx = r + c * rows_a;
398            row_keys.push(ComplexKey::new(a.data[idx]));
399        }
400        if let Some(&pos) = map.get(&row_keys) {
401            mask_data[r] = 1;
402            loc_data[r] = pos as f64;
403        }
404    }
405
406    let shape = vec![rows_a, 1];
407    let logical = LogicalArray::new(mask_data, shape.clone())
408        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
409    let loc_tensor =
410        Tensor::new(loc_data, shape).map_err(|e| ismember_error(format!("ismember: {e}")))?;
411    Ok(IsMemberEvaluation::new(logical, loc_tensor))
412}
413
414fn ismember_char(
415    a: CharArray,
416    b: CharArray,
417    rows: bool,
418) -> crate::BuiltinResult<IsMemberEvaluation> {
419    if rows {
420        ismember_char_rows(a, b)
421    } else {
422        ismember_char_elements(a, b)
423    }
424}
425
426fn ismember_char_elements(a: CharArray, b: CharArray) -> crate::BuiltinResult<IsMemberEvaluation> {
427    let rows_b = b.rows;
428    let cols_b = b.cols;
429    let mut map: HashMap<char, usize> = HashMap::new();
430
431    for col in 0..cols_b {
432        for row in 0..rows_b {
433            let data_idx = row * cols_b + col;
434            let ch = b.data[data_idx];
435            let linear_idx = row + col * rows_b;
436            map.entry(ch).or_insert(linear_idx + 1);
437        }
438    }
439
440    let rows_a = a.rows;
441    let cols_a = a.cols;
442    let mut mask_data = vec![0u8; rows_a * cols_a];
443    let mut loc_data = vec![0.0f64; rows_a * cols_a];
444
445    for col in 0..cols_a {
446        for row in 0..rows_a {
447            let data_idx = row * cols_a + col;
448            let ch = a.data[data_idx];
449            let linear_idx = row + col * rows_a;
450            if let Some(&pos) = map.get(&ch) {
451                mask_data[linear_idx] = 1;
452                loc_data[linear_idx] = pos as f64;
453            }
454        }
455    }
456
457    let shape = vec![rows_a, cols_a];
458    let logical = LogicalArray::new(mask_data, shape.clone())
459        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
460    let loc_tensor =
461        Tensor::new(loc_data, shape).map_err(|e| ismember_error(format!("ismember: {e}")))?;
462    Ok(IsMemberEvaluation::new(logical, loc_tensor))
463}
464
465fn ismember_char_rows(a: CharArray, b: CharArray) -> crate::BuiltinResult<IsMemberEvaluation> {
466    if a.cols != b.cols {
467        return Err(ismember_error(
468            "ismember: character inputs must have the same number of columns when using 'rows'",
469        )
470        .into());
471    }
472
473    let rows_b = b.rows;
474    let cols = b.cols;
475    let mut map: HashMap<RowCharKey, usize> = HashMap::new();
476
477    for r in 0..rows_b {
478        let mut row_values = Vec::with_capacity(cols);
479        for c in 0..cols {
480            let idx = r * cols + c;
481            row_values.push(b.data[idx]);
482        }
483        let key = RowCharKey::from_slice(&row_values);
484        map.entry(key).or_insert(r + 1);
485    }
486
487    let rows_a = a.rows;
488    let mut mask_data = vec![0u8; rows_a];
489    let mut loc_data = vec![0.0f64; rows_a];
490
491    for r in 0..rows_a {
492        let mut row_values = Vec::with_capacity(cols);
493        for c in 0..cols {
494            let idx = r * cols + c;
495            row_values.push(a.data[idx]);
496        }
497        let key = RowCharKey::from_slice(&row_values);
498        if let Some(&pos) = map.get(&key) {
499            mask_data[r] = 1;
500            loc_data[r] = pos as f64;
501        }
502    }
503
504    let shape = vec![rows_a, 1];
505    let logical = LogicalArray::new(mask_data, shape.clone())
506        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
507    let loc_tensor =
508        Tensor::new(loc_data, shape).map_err(|e| ismember_error(format!("ismember: {e}")))?;
509    Ok(IsMemberEvaluation::new(logical, loc_tensor))
510}
511
512fn ismember_string(
513    a: StringArray,
514    b: StringArray,
515    rows: bool,
516) -> crate::BuiltinResult<IsMemberEvaluation> {
517    if rows {
518        ismember_string_rows(a, b)
519    } else {
520        ismember_string_elements(a, b)
521    }
522}
523
524fn ismember_string_elements(
525    a: StringArray,
526    b: StringArray,
527) -> crate::BuiltinResult<IsMemberEvaluation> {
528    let mut map: HashMap<String, usize> = HashMap::new();
529    for (idx, value) in b.data.iter().enumerate() {
530        map.entry(value.clone()).or_insert(idx + 1);
531    }
532
533    let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
534    let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
535
536    for value in &a.data {
537        if let Some(&pos) = map.get(value) {
538            mask_data.push(1);
539            loc_data.push(pos as f64);
540        } else {
541            mask_data.push(0);
542            loc_data.push(0.0);
543        }
544    }
545
546    let logical = LogicalArray::new(mask_data, a.shape.clone())
547        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
548    let loc_tensor = Tensor::new(loc_data, a.shape.clone())
549        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
550    Ok(IsMemberEvaluation::new(logical, loc_tensor))
551}
552
553fn ismember_string_rows(
554    a: StringArray,
555    b: StringArray,
556) -> crate::BuiltinResult<IsMemberEvaluation> {
557    if a.shape.len() != 2 || b.shape.len() != 2 {
558        return Err(ismember_error(
559            "ismember: 'rows' option requires 2-D string arrays",
560        ));
561    }
562    if a.shape[1] != b.shape[1] {
563        return Err(ismember_error(
564            "ismember: string inputs must have the same number of columns when using 'rows'",
565        )
566        .into());
567    }
568
569    let rows_a = a.shape[0];
570    let cols = a.shape[1];
571    let rows_b = b.shape[0];
572
573    let mut map: HashMap<RowStringKey, usize> = HashMap::new();
574    for r in 0..rows_b {
575        let mut row_values = Vec::with_capacity(cols);
576        for c in 0..cols {
577            let idx = r + c * rows_b;
578            row_values.push(b.data[idx].clone());
579        }
580        let key = RowStringKey(row_values);
581        map.entry(key).or_insert(r + 1);
582    }
583
584    let mut mask_data = vec![0u8; rows_a];
585    let mut loc_data = vec![0.0f64; rows_a];
586
587    for r in 0..rows_a {
588        let mut row_values = Vec::with_capacity(cols);
589        for c in 0..cols {
590            let idx = r + c * rows_a;
591            row_values.push(a.data[idx].clone());
592        }
593        let key = RowStringKey(row_values);
594        if let Some(&pos) = map.get(&key) {
595            mask_data[r] = 1;
596            loc_data[r] = pos as f64;
597        }
598    }
599
600    let shape = vec![rows_a, 1];
601    let logical = LogicalArray::new(mask_data, shape.clone())
602        .map_err(|e| ismember_error(format!("ismember: {e}")))?;
603    let loc_tensor =
604        Tensor::new(loc_data, shape).map_err(|e| ismember_error(format!("ismember: {e}")))?;
605    Ok(IsMemberEvaluation::new(logical, loc_tensor))
606}
607
608fn tensor_rows_cols(t: &Tensor, name: &str) -> crate::BuiltinResult<(usize, usize)> {
609    match t.shape.len() {
610        0 => Ok((1, 1)),
611        1 => Ok((t.shape[0], 1)),
612        2 => Ok((t.shape[0], t.shape[1])),
613        _ => Err(ismember_error(format!(
614            "{name}: 'rows' option requires 2-D numeric matrices"
615        ))
616        .into()),
617    }
618}
619
620fn complex_rows_cols(t: &ComplexTensor) -> crate::BuiltinResult<(usize, usize)> {
621    match t.shape.len() {
622        0 => Ok((1, 1)),
623        1 => Ok((t.shape[0], 1)),
624        2 => Ok((t.shape[0], t.shape[1])),
625        _ => Err(ismember_error(
626            "ismember: 'rows' option requires 2-D complex matrices",
627        )),
628    }
629}
630
631#[derive(Debug, Clone, PartialEq, Eq, Hash)]
632struct NumericRowKey(Vec<u64>);
633
634impl NumericRowKey {
635    fn from_slice(values: &[f64]) -> Self {
636        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
637    }
638}
639
640#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
641struct ComplexKey {
642    re: u64,
643    im: u64,
644}
645
646impl ComplexKey {
647    fn new(value: (f64, f64)) -> Self {
648        Self {
649            re: canonicalize_f64(value.0),
650            im: canonicalize_f64(value.1),
651        }
652    }
653}
654
/// Hashable key for one row of a char matrix: the Unicode scalar values of its
/// characters, in column order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Convert each character of `values` to its `u32` code point.
    fn from_slice(values: &[char]) -> Self {
        let codes: Vec<u32> = values.iter().copied().map(u32::from).collect();
        Self(codes)
    }
}

/// Hashable key for one row of a string array: the row's strings in column order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
666
/// Map an `f64` to a hashable bit pattern, with equality semantics for membership.
///
/// Every NaN (any sign or payload) maps to one quiet-NaN pattern, so NaNs match
/// each other. `+0.0` and `-0.0` both map to `0`. Every other value keeps its
/// raw IEEE-754 bits.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
676
677#[derive(Debug, Clone)]
678pub struct IsMemberEvaluation {
679    mask: LogicalArray,
680    loc: Tensor,
681}
682
683impl IsMemberEvaluation {
684    fn new(mask: LogicalArray, loc: Tensor) -> Self {
685        Self { mask, loc }
686    }
687
688    pub fn from_provider_result(result: IsMemberResult) -> crate::BuiltinResult<Self> {
689        let mask = LogicalArray::new(result.mask.data, result.mask.shape)
690            .map_err(|e| ismember_error(format!("ismember: {e}")))?;
691        let loc = Tensor::new(result.loc.data, result.loc.shape)
692            .map_err(|e| ismember_error(format!("ismember: {e}")))?;
693        Ok(IsMemberEvaluation::new(mask, loc))
694    }
695
696    pub fn into_numeric_ismember_result(self) -> crate::BuiltinResult<IsMemberResult> {
697        let IsMemberEvaluation { mask, loc } = self;
698        Ok(IsMemberResult {
699            mask: HostLogicalOwned {
700                data: mask.data,
701                shape: mask.shape,
702            },
703            loc: HostTensorOwned {
704                data: loc.data,
705                shape: loc.shape,
706                storage: GpuTensorStorage::Real,
707            },
708        })
709    }
710
711    pub fn into_mask_value(self) -> Value {
712        logical_array_into_value(self.mask)
713    }
714
715    pub fn mask_value(&self) -> Value {
716        logical_array_into_value(self.mask.clone())
717    }
718
719    pub fn into_pair(self) -> (Value, Value) {
720        let mask = logical_array_into_value(self.mask);
721        let loc = tensor::tensor_into_value(self.loc);
722        (mask, loc)
723    }
724
725    pub fn loc_value(&self) -> Value {
726        tensor::tensor_into_value(self.loc.clone())
727    }
728}
729
730fn logical_array_into_value(logical: LogicalArray) -> Value {
731    if logical.data.len() == 1 {
732        Value::Bool(logical.data[0] != 0)
733    } else {
734        Value::LogicalArray(logical)
735    }
736}
737
738#[cfg(test)]
739pub(crate) mod tests {
740    use super::*;
741    use crate::builtins::common::test_support;
742    use runmat_builtins::{ResolveContext, Tensor, Type};
743
744    #[cfg(feature = "wgpu")]
745    use runmat_accelerate_api::HostTensorView;
746
747    fn error_message(err: crate::RuntimeError) -> String {
748        err.message().to_string()
749    }
750
751    fn evaluate_sync(
752        a: Value,
753        b: Value,
754        rest: &[Value],
755    ) -> crate::BuiltinResult<IsMemberEvaluation> {
756        futures::executor::block_on(evaluate(a, b, rest))
757    }
758
759    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
760    #[test]
761    fn numeric_membership_basic() {
762        let a = Tensor::new(vec![5.0, 7.0, 2.0, 7.0], vec![1, 4]).unwrap();
763        let b = Tensor::new(vec![7.0, 9.0, 5.0], vec![1, 3]).unwrap();
764        let eval = ismember_numeric_elements(a, b).expect("ismember");
765        assert_eq!(eval.mask.data, vec![1, 1, 0, 1]);
766        assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0, 1.0]);
767    }
768
769    #[test]
770    fn ismember_type_resolver_logical() {
771        assert_eq!(
772            logical_output_type(
773                &[Type::tensor(), Type::tensor()],
774                &ResolveContext::new(Vec::new()),
775            ),
776            Type::logical()
777        );
778    }
779
780    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
781    #[test]
782    fn numeric_nan_membership() {
783        let a = Tensor::new(vec![f64::NAN, 1.0], vec![1, 2]).unwrap();
784        let b = Tensor::new(vec![f64::NAN, 2.0], vec![1, 2]).unwrap();
785        let eval = ismember_numeric_elements(a, b).expect("ismember");
786        assert_eq!(eval.mask.data, vec![1, 0]);
787        assert_eq!(eval.loc.data, vec![1.0, 0.0]);
788    }
789
790    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
791    #[test]
792    fn numeric_rows_membership() {
793        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
794        let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
795        let eval = ismember_numeric_rows(a, b).expect("ismember");
796        assert_eq!(eval.mask.data, vec![1, 1, 1]);
797        assert_eq!(eval.loc.data, vec![3.0, 1.0, 3.0]);
798        assert_eq!(eval.loc.shape, vec![3, 1]);
799    }
800
801    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
802    #[test]
803    fn complex_membership() {
804        let a = ComplexTensor::new(vec![(1.0, 2.0), (0.0, 0.0)], vec![1, 2]).unwrap();
805        let b = ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0)], vec![1, 2]).unwrap();
806        let eval = ismember_complex_elements(a, b).expect("ismember");
807        assert_eq!(eval.mask.data, vec![1, 1]);
808        assert_eq!(eval.loc.data, vec![2.0, 1.0]);
809    }
810
811    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
812    #[test]
813    fn complex_rows_membership() {
814        let a = ComplexTensor::new(
815            vec![(1.0, 1.0), (3.0, 0.0), (2.0, 0.0), (4.0, 4.0)],
816            vec![2, 2],
817        )
818        .unwrap();
819        let b = ComplexTensor::new(
820            vec![
821                (1.0, 1.0),
822                (5.0, 0.0),
823                (3.0, 0.0),
824                (2.0, 0.0),
825                (6.0, 0.0),
826                (4.0, 4.0),
827            ],
828            vec![3, 2],
829        )
830        .unwrap();
831        let eval = ismember_complex_rows(a, b).expect("ismember");
832        assert_eq!(eval.mask.data, vec![1, 1]);
833        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
834    }
835
836    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
837    #[test]
838    fn char_membership() {
839        let a = CharArray::new(vec!['r', 'u', 'n', 'm'], 2, 2).unwrap();
840        let b = CharArray::new(vec!['m', 'a', 'r', 'u'], 2, 2).unwrap();
841        let eval = ismember_char_elements(a, b).expect("ismember");
842        assert_eq!(eval.mask.data, vec![1, 0, 1, 1]);
843        assert_eq!(eval.loc.data, vec![2.0, 0.0, 4.0, 1.0]);
844    }
845
846    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
847    #[test]
848    fn char_rows_membership() {
849        let a = CharArray::new(vec!['m', 'a', 't', 'l'], 2, 2).unwrap();
850        let b = CharArray::new(vec!['m', 'a', 'g', 'e', 't', 'l'], 3, 2).unwrap();
851        let eval = ismember_char_rows(a, b).expect("ismember");
852        assert_eq!(eval.mask.data, vec![1, 1]);
853        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
854    }
855
856    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
857    #[test]
858    fn string_membership() {
859        let a = StringArray::new(
860            vec![
861                "apple".to_string(),
862                "pear".to_string(),
863                "banana".to_string(),
864            ],
865            vec![1, 3],
866        )
867        .unwrap();
868        let b = StringArray::new(
869            vec![
870                "pear".to_string(),
871                "orange".to_string(),
872                "apple".to_string(),
873            ],
874            vec![1, 3],
875        )
876        .unwrap();
877        let eval = ismember_string_elements(a, b).expect("ismember");
878        assert_eq!(eval.mask.data, vec![1, 1, 0]);
879        assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0]);
880    }
881
882    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
883    #[test]
884    fn string_rows_membership() {
885        let a = StringArray::new(
886            vec![
887                "alpha".to_string(),
888                "gamma".to_string(),
889                "beta".to_string(),
890                "delta".to_string(),
891            ],
892            vec![2, 2],
893        )
894        .unwrap();
895        let b = StringArray::new(
896            vec![
897                "alpha".to_string(),
898                "theta".to_string(),
899                "gamma".to_string(),
900                "beta".to_string(),
901                "eta".to_string(),
902                "delta".to_string(),
903            ],
904            vec![3, 2],
905        )
906        .unwrap();
907        let eval = ismember_string_rows(a, b).expect("ismember");
908        assert_eq!(eval.mask.data, vec![1, 1]);
909        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
910    }
911
912    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
913    #[test]
914    fn options_reject_legacy() {
915        let err = error_message(parse_options(&[Value::from("legacy")]).unwrap_err());
916        assert!(err.contains("legacy"));
917    }
918
919    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
920    #[test]
921    fn rejects_unknown_option() {
922        let err = error_message(
923            evaluate_sync(Value::Num(1.0), Value::Num(1.0), &[Value::from("stable")]).unwrap_err(),
924        );
925        assert!(err.contains("unrecognised option"));
926    }
927
928    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
929    #[test]
930    fn ismember_runtime_numeric() {
931        let a = Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap());
932        let b = Value::Tensor(Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap());
933        let (mask, loc) = evaluate_sync(a, b, &[]).unwrap().into_pair();
934        match mask {
935            Value::LogicalArray(arr) => assert_eq!(arr.data, vec![1, 0, 1]),
936            other => panic!("expected logical array, got {other:?}"),
937        }
938        match loc {
939            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 0.0, 1.0]),
940            other => panic!("expected tensor, got {other:?}"),
941        }
942    }
943
944    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
945    #[test]
946    fn logical_inputs_promoted() {
947        let a = Value::Bool(true);
948        let logical_b =
949            LogicalArray::new(vec![1, 0], vec![2, 1]).expect("logical array construction");
950        let eval = evaluate_sync(a, Value::LogicalArray(logical_b), &[]).expect("ismember");
951        assert_eq!(eval.mask_value(), Value::Bool(true));
952        assert_eq!(eval.loc_value(), Value::Num(1.0));
953    }
954
955    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
956    #[test]
957    fn ismember_rows_shape_checks() {
958        let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
959        let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
960        assert!(ismember_numeric_rows(a.clone(), b.clone()).is_ok());
961        let bad = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
962        let err = error_message(ismember_numeric_rows(a, bad).unwrap_err());
963        assert!(err.contains("same number of columns"));
964    }
965
966    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
967    #[test]
968    fn ismember_gpu_roundtrip() {
969        test_support::with_test_provider(|provider| {
970            let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
971            let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
972            let view_a = runmat_accelerate_api::HostTensorView {
973                data: &tensor.data,
974                shape: &tensor.shape,
975            };
976            let view_b = runmat_accelerate_api::HostTensorView {
977                data: &set.data,
978                shape: &set.shape,
979            };
980            let handle_a = provider.upload(&view_a).expect("upload a");
981            let handle_b = provider.upload(&view_b).expect("upload b");
982            let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
983                .expect("ismember");
984            assert_eq!(eval.mask.data, vec![0, 1, 0, 1]);
985            assert_eq!(eval.loc.data, vec![0.0, 1.0, 0.0, 1.0]);
986        });
987    }
988
989    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
990    #[test]
991    fn ismember_gpu_rows_roundtrip() {
992        test_support::with_test_provider(|provider| {
993            let rows = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
994            let bank = Tensor::new(vec![1.0, 5.0, 3.0, 2.0, 6.0, 4.0], vec![3, 2]).unwrap();
995            let view_a = runmat_accelerate_api::HostTensorView {
996                data: &rows.data,
997                shape: &rows.shape,
998            };
999            let view_b = runmat_accelerate_api::HostTensorView {
1000                data: &bank.data,
1001                shape: &bank.shape,
1002            };
1003            let handle_a = provider.upload(&view_a).expect("upload a");
1004            let handle_b = provider.upload(&view_b).expect("upload b");
1005            let eval = evaluate_sync(
1006                Value::GpuTensor(handle_a.clone()),
1007                Value::GpuTensor(handle_b.clone()),
1008                &[Value::from("rows")],
1009            )
1010            .expect("ismember");
1011            assert_eq!(eval.mask.data, vec![1, 1]);
1012            assert_eq!(eval.loc.data, vec![1.0, 3.0]);
1013            let _ = provider.free(&handle_a);
1014            let _ = provider.free(&handle_b);
1015        });
1016    }
1017
1018    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1019    #[test]
1020    #[cfg(feature = "wgpu")]
1021    fn ismember_wgpu_numeric_matches_cpu() {
1022        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1023            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1024        );
1025
1026        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
1027        let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1028        let cpu_eval =
1029            ismember_numeric_from_tensors(tensor.clone(), set.clone(), false).expect("cpu");
1030
1031        let provider = runmat_accelerate_api::provider().expect("provider");
1032        let view_a = HostTensorView {
1033            data: &tensor.data,
1034            shape: &tensor.shape,
1035        };
1036        let view_b = HostTensorView {
1037            data: &set.data,
1038            shape: &set.shape,
1039        };
1040        let handle_a = provider.upload(&view_a).expect("upload a");
1041        let handle_b = provider.upload(&view_b).expect("upload b");
1042
1043        let eval = evaluate_sync(
1044            Value::GpuTensor(handle_a.clone()),
1045            Value::GpuTensor(handle_b.clone()),
1046            &[],
1047        )
1048        .expect("gpu evaluate");
1049        assert_eq!(eval.mask.data, cpu_eval.mask.data);
1050        assert_eq!(eval.loc.data, cpu_eval.loc.data);
1051
1052        let _ = provider.free(&handle_a);
1053        let _ = provider.free(&handle_b);
1054
1055        let matrix = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
1056        let bank = Tensor::new(vec![1.0, 7.0, 3.0, 2.0, 9.0, 4.0], vec![3, 2]).unwrap();
1057        let cpu_rows =
1058            ismember_numeric_from_tensors(matrix.clone(), bank.clone(), true).expect("cpu rows");
1059        let view_matrix = HostTensorView {
1060            data: &matrix.data,
1061            shape: &matrix.shape,
1062        };
1063        let view_bank = HostTensorView {
1064            data: &bank.data,
1065            shape: &bank.shape,
1066        };
1067        let handle_matrix = provider.upload(&view_matrix).expect("upload matrix");
1068        let handle_bank = provider.upload(&view_bank).expect("upload bank");
1069        let eval_rows = evaluate_sync(
1070            Value::GpuTensor(handle_matrix.clone()),
1071            Value::GpuTensor(handle_bank.clone()),
1072            &[Value::from("rows")],
1073        )
1074        .expect("gpu rows evaluate");
1075        assert_eq!(eval_rows.mask.data, cpu_rows.mask.data);
1076        assert_eq!(eval_rows.loc.data, cpu_rows.loc.data);
1077        let _ = provider.free(&handle_matrix);
1078        let _ = provider.free(&handle_bank);
1079    }
1080
1081    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1082    #[test]
1083    fn scalar_return_is_bool() {
1084        let a = Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());
1085        let b = Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());
1086        let mask = evaluate_sync(a, b, &[]).unwrap().into_mask_value();
1087        assert_eq!(mask, Value::Bool(true));
1088    }
1089
1090    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1091    #[test]
1092    fn parse_rows_option() {
1093        let opts = parse_options(&[Value::from("rows")]).unwrap();
1094        assert!(opts.rows);
1095    }
1096
1097    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1098    #[test]
1099    fn numeric_rows_with_nan() {
1100        let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
1101        let b = Tensor::new(vec![f64::NAN, 2.0], vec![2, 1]).unwrap();
1102        let eval = ismember_numeric_rows(a, b).expect("ismember");
1103        assert_eq!(eval.mask.data, vec![1, 0]);
1104        assert_eq!(eval.loc.data, vec![1.0, 0.0]);
1105    }
1106}