
//! MATLAB-compatible `ismember` builtin with GPU-aware semantics for RunMat.

use std::collections::HashMap;

use runmat_accelerate_api::{
    GpuTensorHandle, HostLogicalOwned, HostTensorOwned, IsMemberOptions as ProviderIsMemberOptions,
    IsMemberResult,
};
use runmat_builtins::{CharArray, ComplexTensor, LogicalArray, StringArray, Tensor, Value};
use runmat_macros::runtime_builtin;

use crate::builtins::common::gpu_helpers;
use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::tensor;
#[cfg(feature = "doc_export")]
use crate::register_builtin_doc_text;
use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};

#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "ismember"
category: "array/sorting_sets"
keywords: ["ismember", "membership", "set", "rows", "indices", "gpu"]
summary: "Identify array elements or rows that appear in another array while returning first-match indices."
references:
  - https://www.mathworks.com/help/matlab/ref/ismember.html
gpu_support:
  elementwise: false
  reduction: false
  precisions: ["f32", "f64"]
  broadcasting: "none"
  notes: "When providers lack a dedicated membership hook RunMat gathers GPU tensors and executes the host implementation."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 2
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::array::sorting_sets::ismember::tests"
  integration: "builtins::array::sorting_sets::ismember::tests::ismember_gpu_roundtrip"
---

# What does the `ismember` function do in MATLAB / RunMat?
`ismember(A, B)` compares the elements (or rows) of `A` against `B` and returns a logical array
marking which members of `A` are present in `B`. The optional second output reports the index in `B`
of the first matched element. RunMat follows MATLAB semantics for numeric, logical, complex, string,
and character arrays.

## How does the `ismember` function behave in MATLAB / RunMat?
- The first output `tf` has the same shape as `A` (or `size(A,1) × 1` when using `'rows'`).
- The optional second output `loc` contains one-based indices into `B`, with `0` for values that are
  not found.
- When `B` contains duplicates, `loc` reports the lowest matching index. Repeated values in `A` all
  receive that same index.
- `NaN` values are treated as identical, so a `NaN` in `A` matches a `NaN` in `B`.
- Character arrays follow column-major linear indexing, mirroring MATLAB.
- The `'rows'` option compares complete rows; inputs must agree on the number of columns.
- Legacy flags (`'legacy'`, `'R2012a'`) are deliberately unsupported in RunMat.
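
The first-match lookup described above fits in a single hash-map pass. The Rust sketch below is a minimal, standalone illustration, not RunMat's actual implementation. `canonical_key` and `ismember_sketch` are hypothetical names. The sketch assumes both inputs are already flattened into slices in column-major order.

```rust
use std::collections::HashMap;

/// Hypothetical helper: collapse every NaN payload to one key and -0.0 to +0.0,
/// so that NaN matches NaN and the two zeros match each other.
fn canonical_key(x: f64) -> u64 {
    if x.is_nan() {
        f64::NAN.to_bits()
    } else if x == 0.0 {
        0
    } else {
        x.to_bits()
    }
}

/// Hypothetical sketch returning (tf, loc): `loc` holds the 1-based index of the
/// first match in `b`, or 0 when the value is absent.
fn ismember_sketch(a: &[f64], b: &[f64]) -> (Vec<bool>, Vec<usize>) {
    let mut first_index: HashMap<u64, usize> = HashMap::new();
    for (i, &v) in b.iter().enumerate() {
        // `or_insert` keeps the earliest position, so duplicates in `b` resolve to the first one.
        first_index.entry(canonical_key(v)).or_insert(i + 1);
    }
    let loc: Vec<usize> = a
        .iter()
        .map(|&v| first_index.get(&canonical_key(v)).copied().unwrap_or(0))
        .collect();
    let tf: Vec<bool> = loc.iter().map(|&p| p != 0).collect();
    (tf, loc)
}
```

Using `or_insert` rather than `insert` is what makes `loc` report the earliest position when `B` contains duplicates.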

## `ismember` Function GPU Execution Behaviour
When both inputs are GPU tensors, RunMat first asks the active acceleration provider to run its
`ismember` hook. If only one input is on the GPU, or the provider does not implement that hook, the
runtime transparently gathers the GPU operands to host memory and performs the membership lookup
with the CPU implementation. In every case the outputs are host-resident, so results exactly match
MATLAB.
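
The provider-first dispatch with a host fallback can be sketched as follows. `MembershipProvider`, `NoKernelProvider`, `host_ismember`, and `dispatch` are hypothetical stand-ins, not RunMat's provider API. The host fallback uses plain `==` for brevity, so unlike RunMat it never matches `NaN`.

```rust
/// Hypothetical stand-in for an acceleration provider (not RunMat's real API).
trait MembershipProvider {
    /// Returns `Err` when the provider has no membership kernel.
    fn ismember(&self, a: &[f64], b: &[f64]) -> Result<Vec<usize>, String>;
}

/// Hypothetical provider that, like one lacking the hook, always declines.
struct NoKernelProvider;

impl MembershipProvider for NoKernelProvider {
    fn ismember(&self, _a: &[f64], _b: &[f64]) -> Result<Vec<usize>, String> {
        Err("ismember hook not implemented".to_string())
    }
}

/// Host fallback: 1-based index of the first element of `b` equal (`==`) to each
/// element of `a`, or 0 when absent. NaN never matches here.
fn host_ismember(a: &[f64], b: &[f64]) -> Vec<usize> {
    a.iter()
        .map(|x| b.iter().position(|y| y == x).map_or(0, |i| i + 1))
        .collect()
}

/// Ask the provider first; if none is registered or it returns an error,
/// compute the result on the host instead.
fn dispatch(provider: Option<&dyn MembershipProvider>, a: &[f64], b: &[f64]) -> Vec<usize> {
    if let Some(p) = provider {
        if let Ok(loc) = p.ismember(a, b) {
            return loc;
        }
    }
    host_ismember(a, b)
}
```

Treating a provider error as "use the host path" keeps results identical whether or not a kernel exists. The cost is an extra device-to-host transfer.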

## GPU residency in RunMat (Do I need `gpuArray`?)

Most code does not need to call `gpuArray` explicitly. The native auto-offload planner keeps track
of residency and recognises that `ismember` is a sink: the operation produces logical outputs and
one-based indices that currently live on the host. If an acceleration provider exposes a full
`ismember` hook in the future, the planner can keep data on the device automatically. Until then,
manual `gpuArray` / `gather` calls only serve to mirror MATLAB workflows; RunMat already performs
the necessary transfers when it detects that tensors reside on the GPU.

## Examples of using the `ismember` function in MATLAB / RunMat

### Checking membership of numeric vectors
```matlab
A = [5 7 2 7];
B = [7 9 5];
[tf, loc] = ismember(A, B);
```
Expected output:
```matlab
tf =
     1     1     0     1
loc =
     3     1     0     1
```

### Finding row membership in a matrix
```matlab
A = [1 2; 3 4; 1 2];
B = [3 4; 5 6; 1 2];
[tf, loc] = ismember(A, B, 'rows');
```
Expected output:
```matlab
tf =
     1
     1
     1
loc =
     3
     1
     3
```
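
Row membership works the same way, except the hash key is an entire row gathered from column-major storage. The Rust sketch below illustrates that idea; `row_key` and `ismember_rows_sketch` are hypothetical names. It assumes both matrices are column-major slices with the same column count.

```rust
use std::collections::HashMap;

/// Hypothetical helper: row `r` of a column-major `rows x cols` matrix as canonical bit
/// patterns (all NaNs share one key, -0.0 maps to the key of +0.0).
fn row_key(data: &[f64], rows: usize, cols: usize, r: usize) -> Vec<u64> {
    (0..cols)
        .map(|c| {
            let v = data[r + c * rows];
            if v.is_nan() {
                f64::NAN.to_bits()
            } else if v == 0.0 {
                0
            } else {
                v.to_bits()
            }
        })
        .collect()
}

/// Hypothetical sketch: 1-based row indices into `b` for each row of `a` (0 when absent).
fn ismember_rows_sketch(a: &[f64], rows_a: usize, b: &[f64], rows_b: usize, cols: usize) -> Vec<usize> {
    let mut first_row: HashMap<Vec<u64>, usize> = HashMap::new();
    for r in 0..rows_b {
        // Keep only the first occurrence of each distinct row of `b`.
        first_row.entry(row_key(b, rows_b, cols, r)).or_insert(r + 1);
    }
    (0..rows_a)
        .map(|r| first_row.get(&row_key(a, rows_a, cols, r)).copied().unwrap_or(0))
        .collect()
}
```

Each lookup costs O(cols) to build and hash the key, but matching a row of `A` remains a single hash probe.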

### Locating values and retrieving the index
```matlab
values = [10 20 30];
set = [30 10 40];
[tf, loc] = ismember(values, set);
```
Expected output:
```matlab
tf =
     1     0     1
loc =
     2     0     1
```

### Testing characters against a set
```matlab
chars = ['r','u'; 'n','m'];
set = ['m','a'; 'r','u'];
[tf, loc] = ismember(chars, set);
```
Expected output:
```matlab
tf =
     1     1
     0     1
loc =
     2     4
     0     1
```
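
The column-major numbering of `loc` for character arrays can be sketched as follows. The sketch assumes, as the implementation in this file does, that characters are stored row by row (`data[row * cols + col]`). `ismember_chars_sketch` is a hypothetical name.

```rust
use std::collections::HashMap;

/// Hypothetical sketch: `a` and `b` are stored row-major, but the returned `loc`
/// uses MATLAB's 1-based column-major linear indices (0 = absent) and is itself
/// laid out in column-major order.
fn ismember_chars_sketch(
    a: &[char], rows_a: usize, cols_a: usize,
    b: &[char], rows_b: usize, cols_b: usize,
) -> Vec<usize> {
    let mut first: HashMap<char, usize> = HashMap::new();
    // Walk `b` column by column so `or_insert` records the lowest linear index.
    for col in 0..cols_b {
        for row in 0..rows_b {
            first.entry(b[row * cols_b + col]).or_insert(row + col * rows_b + 1);
        }
    }
    let mut loc = vec![0; rows_a * cols_a];
    for col in 0..cols_a {
        for row in 0..rows_a {
            if let Some(&pos) = first.get(&a[row * cols_a + col]) {
                loc[row + col * rows_a] = pos;
            }
        }
    }
    loc
}
```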

### Working with string arrays
```matlab
A = ["apple" "pear" "banana"];
B = ["pear" "orange" "apple"];
[tf, loc] = ismember(A, B);
```
Expected output:
```matlab
tf =
  1×3 logical array
   1   1   0
loc =
  1×3 double
   3   1   0
```

### Using `ismember` with `gpuArray` inputs
```matlab
G = gpuArray([1 4 2 4]);
H = gpuArray([4 5]);
[tf, loc] = ismember(G, H);
```
Expected output (RunMat gathers to host unless a provider implements `ismember`):
```matlab
tf =
     0     1     0     1
loc =
     0     1     0     1
```

## FAQ

### Does `ismember` treat `NaN` values as equal?
Yes. `NaN` values compare equal for membership tests, so every `NaN` in `A` matches any `NaN` in `B`.

### What happens when an element of `A` is not found in `B`?
The corresponding logical entry is `false` and the index output stores `0`, matching MATLAB.

### Can I use `ismember` with string arrays and character arrays?
Yes. String arrays, scalar strings, and character arrays are supported. Mixed string/char inputs
should be normalised to a single type first (for example, convert the character array with `string`).

### How does the `'rows'` option change the output shape?
`'rows'` compares entire rows and returns outputs of size `size(A,1) × 1`, regardless of how many
columns the input matrices contain.

### Are the legacy flags supported?
No. RunMat only implements modern MATLAB semantics. Passing `'legacy'` or `'R2012a'` raises an
error, just like other set builtins in RunMat.

### Will `ismember` run on the GPU automatically?
When both inputs are `gpuArray` values and the active provider implements an `ismember` hook, the
lookup is delegated to that hook. Otherwise the data is gathered to the host with no behavioural
differences. In both cases the outputs are returned as host arrays.

## See Also
[unique](./unique), [intersect](./intersect), [setdiff](./setdiff), [union](./union), [gpuArray](../../acceleration/gpu/gpuArray), [gather](../../acceleration/gpu/gather)

## Source & Feedback
- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/ismember.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/ismember.rs)
- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
"#;

pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ismember",
    op_kind: GpuOpKind::Custom("ismember"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("ismember")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers may supply dedicated membership kernels; until then RunMat gathers GPU tensors to host memory.",
};

register_builtin_gpu_spec!(GPU_SPEC);

pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ismember",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Membership queries execute via host set lookups; the fusion planner treats ismember as a residency sink.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

#[cfg(feature = "doc_export")]
register_builtin_doc_text!("ismember", DOC_MD);

#[runtime_builtin(
    name = "ismember",
    category = "array/sorting_sets",
    summary = "Identify array elements or rows that appear in another array while returning first-match indices.",
    keywords = "ismember,membership,set,rows,indices,gpu",
    accel = "array_construct",
    sink = true
)]
fn ismember_builtin(a: Value, b: Value, rest: Vec<Value>) -> Result<Value, String> {
    evaluate(a, b, &rest).map(|eval| eval.into_mask_value())
}

/// Evaluate `ismember` once and expose both outputs.
pub fn evaluate(a: Value, b: Value, rest: &[Value]) -> Result<IsMemberEvaluation, String> {
    let opts = parse_options(rest)?;
    match (a, b) {
        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
            ismember_gpu_pair(handle_a, handle_b, &opts)
        }
        (Value::GpuTensor(handle_a), other) => ismember_gpu_mixed(handle_a, other, &opts, true),
        (other, Value::GpuTensor(handle_b)) => ismember_gpu_mixed(handle_b, other, &opts, false),
        (left, right) => ismember_host(left, right, &opts),
    }
}

#[derive(Debug, Clone, Copy)]
struct IsMemberOptions {
    rows: bool,
}

impl IsMemberOptions {
    fn into_provider_options(self) -> ProviderIsMemberOptions {
        ProviderIsMemberOptions { rows: self.rows }
    }
}

fn parse_options(rest: &[Value]) -> Result<IsMemberOptions, String> {
    let mut opts = IsMemberOptions { rows: false };
    for arg in rest {
        let text = tensor::value_to_string(arg)
            .ok_or_else(|| "ismember: expected string option arguments".to_string())?;
        let lowered = text.trim().to_ascii_lowercase();
        match lowered.as_str() {
            "rows" => opts.rows = true,
            "legacy" | "r2012a" => {
                return Err("ismember: the 'legacy' behaviour is not supported".to_string())
            }
            other => return Err(format!("ismember: unrecognised option '{other}'")),
        }
    }
    Ok(opts)
}

fn ismember_gpu_pair(
    handle_a: GpuTensorHandle,
    handle_b: GpuTensorHandle,
    opts: &IsMemberOptions,
) -> Result<IsMemberEvaluation, String> {
    if let Some(provider) = runmat_accelerate_api::provider() {
        let provider_opts = opts.into_provider_options();
        match provider.ismember(&handle_a, &handle_b, &provider_opts) {
            Ok(result) => return IsMemberEvaluation::from_provider_result(result),
            Err(_) => {
                // Fall back to host gather when the provider lacks an ismember implementation.
            }
        }
    }
    let tensor_a = gpu_helpers::gather_tensor(&handle_a)?;
    let tensor_b = gpu_helpers::gather_tensor(&handle_b)?;
    ismember_numeric_tensors(tensor_a, tensor_b, opts)
}

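/// Handles the case where only one operand lives on the GPU: that operand is gathered and the
/// host implementation runs. The provider hook is not consulted on this path.
fn ismember_gpu_mixed(
    handle_gpu: GpuTensorHandle,
    other: Value,
    opts: &IsMemberOptions,
    gpu_is_a: bool,
) -> Result<IsMemberEvaluation, String> {
    let tensor_gpu = gpu_helpers::gather_tensor(&handle_gpu)?;
    // Preserve argument order: `gpu_is_a` records whether the gathered tensor was `A` or `B`.
    if gpu_is_a {
        ismember_host(Value::Tensor(tensor_gpu), other, opts)
    } else {
        ismember_host(other, Value::Tensor(tensor_gpu), opts)
    }
}
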
fn ismember_host(a: Value, b: Value, opts: &IsMemberOptions) -> Result<IsMemberEvaluation, String> {
    match (a, b) {
        (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => ismember_complex(at, bt, opts.rows),
        (Value::ComplexTensor(at), Value::Complex(re, im)) => {
            let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
                .map_err(|e| format!("ismember: {e}"))?;
            ismember_complex(at, bt, opts.rows)
        }
        (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
                .map_err(|e| format!("ismember: {e}"))?;
            ismember_complex(at, bt, opts.rows)
        }
        (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
            let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
                .map_err(|e| format!("ismember: {e}"))?;
            let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
                .map_err(|e| format!("ismember: {e}"))?;
            ismember_complex(at, bt, opts.rows)
        }

        (Value::CharArray(ac), Value::CharArray(bc)) => ismember_char(ac, bc, opts.rows),

        (Value::StringArray(astring), Value::StringArray(bstring)) => {
            ismember_string(astring, bstring, opts.rows)
        }
        (Value::StringArray(astring), Value::String(b)) => {
            let bstring =
                StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("ismember: {e}"))?;
            ismember_string(astring, bstring, opts.rows)
        }
        (Value::String(a), Value::StringArray(bstring)) => {
            let astring =
                StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("ismember: {e}"))?;
            ismember_string(astring, bstring, opts.rows)
        }
        (Value::String(a), Value::String(b)) => {
            let astring =
                StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("ismember: {e}"))?;
            let bstring =
                StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("ismember: {e}"))?;
            ismember_string(astring, bstring, opts.rows)
        }

        (left, right) => {
            let tensor_a = tensor::value_into_tensor_for("ismember", left)?;
            let tensor_b = tensor::value_into_tensor_for("ismember", right)?;
            ismember_numeric_tensors(tensor_a, tensor_b, opts)
        }
    }
}

fn ismember_numeric_tensors(
    a: Tensor,
    b: Tensor,
    opts: &IsMemberOptions,
) -> Result<IsMemberEvaluation, String> {
    if opts.rows {
        ismember_numeric_rows(a, b)
    } else {
        ismember_numeric_elements(a, b)
    }
}

/// Helper exposed for acceleration providers handling numeric tensors on the host.
pub fn ismember_numeric_from_tensors(
    a: Tensor,
    b: Tensor,
    rows: bool,
) -> Result<IsMemberEvaluation, String> {
    let opts = IsMemberOptions { rows };
    ismember_numeric_tensors(a, b, &opts)
}

fn ismember_numeric_elements(a: Tensor, b: Tensor) -> Result<IsMemberEvaluation, String> {
    let mut map: HashMap<u64, usize> = HashMap::new();
    for (idx, &value) in b.data.iter().enumerate() {
        map.entry(canonicalize_f64(value)).or_insert(idx + 1);
    }

    let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
    let mut loc_data = Vec::<f64>::with_capacity(a.data.len());

    for &value in &a.data {
        let key = canonicalize_f64(value);
        if let Some(&pos) = map.get(&key) {
            mask_data.push(1);
            loc_data.push(pos as f64);
        } else {
            mask_data.push(0);
            loc_data.push(0.0);
        }
    }

    let logical = LogicalArray::new(mask_data, a.shape.clone())?;
    let loc_tensor =
        Tensor::new(loc_data, a.shape.clone()).map_err(|e| format!("ismember: {e}"))?;
    Ok(IsMemberEvaluation::new(logical, loc_tensor))
}

fn ismember_numeric_rows(a: Tensor, b: Tensor) -> Result<IsMemberEvaluation, String> {
    let (rows_a, cols_a) = tensor_rows_cols(&a, "ismember")?;
    let (rows_b, cols_b) = tensor_rows_cols(&b, "ismember")?;
    if cols_a != cols_b {
        return Err(
            "ismember: inputs must have the same number of columns when using 'rows'".to_string(),
        );
    }

    let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
    for r in 0..rows_b {
        let mut row_values = Vec::with_capacity(cols_b);
        for c in 0..cols_b {
            let idx = r + c * rows_b;
            row_values.push(b.data[idx]);
        }
        let key = NumericRowKey::from_slice(&row_values);
        map.entry(key).or_insert(r + 1);
    }

    let mut mask_data = vec![0u8; rows_a];
    let mut loc_data = vec![0.0f64; rows_a];

    for r in 0..rows_a {
        let mut row_values = Vec::with_capacity(cols_a);
        for c in 0..cols_a {
            let idx = r + c * rows_a;
            row_values.push(a.data[idx]);
        }
        let key = NumericRowKey::from_slice(&row_values);
        if let Some(&pos) = map.get(&key) {
            mask_data[r] = 1;
            loc_data[r] = pos as f64;
        }
    }

    let shape = vec![rows_a, 1];
    let logical = LogicalArray::new(mask_data, shape.clone())?;
    let loc_tensor = Tensor::new(loc_data, shape).map_err(|e| format!("ismember: {e}"))?;
    Ok(IsMemberEvaluation::new(logical, loc_tensor))
}

fn ismember_complex(
    a: ComplexTensor,
    b: ComplexTensor,
    rows: bool,
) -> Result<IsMemberEvaluation, String> {
    if rows {
        ismember_complex_rows(a, b)
    } else {
        ismember_complex_elements(a, b)
    }
}

fn ismember_complex_elements(
    a: ComplexTensor,
    b: ComplexTensor,
) -> Result<IsMemberEvaluation, String> {
    let mut map: HashMap<ComplexKey, usize> = HashMap::new();
    for (idx, &value) in b.data.iter().enumerate() {
        map.entry(ComplexKey::new(value)).or_insert(idx + 1);
    }

    let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
    let mut loc_data = Vec::<f64>::with_capacity(a.data.len());

    for &value in &a.data {
        let key = ComplexKey::new(value);
        if let Some(&pos) = map.get(&key) {
            mask_data.push(1);
            loc_data.push(pos as f64);
        } else {
            mask_data.push(0);
            loc_data.push(0.0);
        }
    }

    let logical = LogicalArray::new(mask_data, a.shape.clone())?;
    let loc_tensor =
        Tensor::new(loc_data, a.shape.clone()).map_err(|e| format!("ismember: {e}"))?;
    Ok(IsMemberEvaluation::new(logical, loc_tensor))
}

fn ismember_complex_rows(a: ComplexTensor, b: ComplexTensor) -> Result<IsMemberEvaluation, String> {
    let (rows_a, cols_a) = complex_rows_cols(&a)?;
    let (rows_b, cols_b) = complex_rows_cols(&b)?;
    if cols_a != cols_b {
        return Err(
            "ismember: complex inputs must have the same number of columns when using 'rows'"
                .to_string(),
        );
    }

    let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
    for r in 0..rows_b {
        let mut row_keys = Vec::with_capacity(cols_b);
        for c in 0..cols_b {
            let idx = r + c * rows_b;
            row_keys.push(ComplexKey::new(b.data[idx]));
        }
        map.entry(row_keys).or_insert(r + 1);
    }

    let mut mask_data = vec![0u8; rows_a];
    let mut loc_data = vec![0.0f64; rows_a];

    for r in 0..rows_a {
        let mut row_keys = Vec::with_capacity(cols_a);
        for c in 0..cols_a {
            let idx = r + c * rows_a;
            row_keys.push(ComplexKey::new(a.data[idx]));
        }
        if let Some(&pos) = map.get(&row_keys) {
            mask_data[r] = 1;
            loc_data[r] = pos as f64;
        }
    }

    let shape = vec![rows_a, 1];
    let logical = LogicalArray::new(mask_data, shape.clone())?;
    let loc_tensor = Tensor::new(loc_data, shape).map_err(|e| format!("ismember: {e}"))?;
    Ok(IsMemberEvaluation::new(logical, loc_tensor))
}

fn ismember_char(a: CharArray, b: CharArray, rows: bool) -> Result<IsMemberEvaluation, String> {
    if rows {
        ismember_char_rows(a, b)
    } else {
        ismember_char_elements(a, b)
    }
}

/// `CharArray` stores characters row-major (`data[row * cols + col]`), while MATLAB reports `loc`
/// and lays out its outputs using column-major linear indices, so both loops walk columns first.
fn ismember_char_elements(a: CharArray, b: CharArray) -> Result<IsMemberEvaluation, String> {
    let rows_b = b.rows;
    let cols_b = b.cols;
    let mut map: HashMap<char, usize> = HashMap::new();

    for col in 0..cols_b {
        for row in 0..rows_b {
            let data_idx = row * cols_b + col;
            let ch = b.data[data_idx];
            let linear_idx = row + col * rows_b;
            map.entry(ch).or_insert(linear_idx + 1);
        }
    }

    let rows_a = a.rows;
    let cols_a = a.cols;
    let mut mask_data = vec![0u8; rows_a * cols_a];
    let mut loc_data = vec![0.0f64; rows_a * cols_a];

    for col in 0..cols_a {
        for row in 0..rows_a {
            let data_idx = row * cols_a + col;
            let ch = a.data[data_idx];
            let linear_idx = row + col * rows_a;
            if let Some(&pos) = map.get(&ch) {
                mask_data[linear_idx] = 1;
                loc_data[linear_idx] = pos as f64;
            }
        }
    }

    let shape = vec![rows_a, cols_a];
    let logical = LogicalArray::new(mask_data, shape.clone())?;
    let loc_tensor = Tensor::new(loc_data, shape).map_err(|e| format!("ismember: {e}"))?;
    Ok(IsMemberEvaluation::new(logical, loc_tensor))
}

fn ismember_char_rows(a: CharArray, b: CharArray) -> Result<IsMemberEvaluation, String> {
    if a.cols != b.cols {
        return Err(
            "ismember: character inputs must have the same number of columns when using 'rows'"
                .to_string(),
        );
    }

    let rows_b = b.rows;
    let cols = b.cols;
    let mut map: HashMap<RowCharKey, usize> = HashMap::new();

    for r in 0..rows_b {
        let mut row_values = Vec::with_capacity(cols);
        for c in 0..cols {
            let idx = r * cols + c;
            row_values.push(b.data[idx]);
        }
        let key = RowCharKey::from_slice(&row_values);
        map.entry(key).or_insert(r + 1);
    }

    let rows_a = a.rows;
    let mut mask_data = vec![0u8; rows_a];
    let mut loc_data = vec![0.0f64; rows_a];

    for r in 0..rows_a {
        let mut row_values = Vec::with_capacity(cols);
        for c in 0..cols {
            let idx = r * cols + c;
            row_values.push(a.data[idx]);
        }
        let key = RowCharKey::from_slice(&row_values);
        if let Some(&pos) = map.get(&key) {
            mask_data[r] = 1;
            loc_data[r] = pos as f64;
        }
    }

    let shape = vec![rows_a, 1];
    let logical = LogicalArray::new(mask_data, shape.clone())?;
    let loc_tensor = Tensor::new(loc_data, shape).map_err(|e| format!("ismember: {e}"))?;
    Ok(IsMemberEvaluation::new(logical, loc_tensor))
}

fn ismember_string(
    a: StringArray,
    b: StringArray,
    rows: bool,
) -> Result<IsMemberEvaluation, String> {
    if rows {
        ismember_string_rows(a, b)
    } else {
        ismember_string_elements(a, b)
    }
}

fn ismember_string_elements(a: StringArray, b: StringArray) -> Result<IsMemberEvaluation, String> {
    let mut map: HashMap<String, usize> = HashMap::new();
    for (idx, value) in b.data.iter().enumerate() {
        map.entry(value.clone()).or_insert(idx + 1);
    }

    let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
    let mut loc_data = Vec::<f64>::with_capacity(a.data.len());

    for value in &a.data {
        if let Some(&pos) = map.get(value) {
            mask_data.push(1);
            loc_data.push(pos as f64);
        } else {
            mask_data.push(0);
            loc_data.push(0.0);
        }
    }

    let logical = LogicalArray::new(mask_data, a.shape.clone())?;
    let loc_tensor =
        Tensor::new(loc_data, a.shape.clone()).map_err(|e| format!("ismember: {e}"))?;
    Ok(IsMemberEvaluation::new(logical, loc_tensor))
}

fn ismember_string_rows(a: StringArray, b: StringArray) -> Result<IsMemberEvaluation, String> {
    if a.shape.len() != 2 || b.shape.len() != 2 {
        return Err("ismember: 'rows' option requires 2-D string arrays".to_string());
    }
    if a.shape[1] != b.shape[1] {
        return Err(
            "ismember: string inputs must have the same number of columns when using 'rows'"
                .to_string(),
        );
    }

    let rows_a = a.shape[0];
    let cols = a.shape[1];
    let rows_b = b.shape[0];

    let mut map: HashMap<RowStringKey, usize> = HashMap::new();
    for r in 0..rows_b {
        let mut row_values = Vec::with_capacity(cols);
        for c in 0..cols {
            let idx = r + c * rows_b;
            row_values.push(b.data[idx].clone());
        }
        let key = RowStringKey(row_values);
        map.entry(key).or_insert(r + 1);
    }

    let mut mask_data = vec![0u8; rows_a];
    let mut loc_data = vec![0.0f64; rows_a];

    for r in 0..rows_a {
        let mut row_values = Vec::with_capacity(cols);
        for c in 0..cols {
            let idx = r + c * rows_a;
            row_values.push(a.data[idx].clone());
        }
        let key = RowStringKey(row_values);
        if let Some(&pos) = map.get(&key) {
            mask_data[r] = 1;
            loc_data[r] = pos as f64;
        }
    }

    let shape = vec![rows_a, 1];
    let logical = LogicalArray::new(mask_data, shape.clone())?;
    let loc_tensor = Tensor::new(loc_data, shape).map_err(|e| format!("ismember: {e}"))?;
    Ok(IsMemberEvaluation::new(logical, loc_tensor))
}

fn tensor_rows_cols(t: &Tensor, name: &str) -> Result<(usize, usize), String> {
    match t.shape.len() {
        0 => Ok((1, 1)),
        1 => Ok((t.shape[0], 1)),
        2 => Ok((t.shape[0], t.shape[1])),
        _ => Err(format!("{name}: 'rows' option requires 2-D numeric matrices")),
    }
}

fn complex_rows_cols(t: &ComplexTensor) -> Result<(usize, usize), String> {
    match t.shape.len() {
        0 => Ok((1, 1)),
        1 => Ok((t.shape[0], 1)),
        2 => Ok((t.shape[0], t.shape[1])),
        _ => Err("ismember: 'rows' option requires 2-D complex matrices".to_string()),
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct NumericRowKey(Vec<u64>);

impl NumericRowKey {
    fn from_slice(values: &[f64]) -> Self {
        NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ComplexKey {
    re: u64,
    im: u64,
}

impl ComplexKey {
    fn new(value: (f64, f64)) -> Self {
        Self {
            re: canonicalize_f64(value.0),
            im: canonicalize_f64(value.1),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    fn from_slice(values: &[char]) -> Self {
        RowCharKey(values.iter().map(|&ch| ch as u32).collect())
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);

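/// Map an `f64` to a hashable key. Every `NaN` payload collapses to one canonical quiet-NaN
/// pattern (so `NaN` matches `NaN`), and `-0.0` collapses to `+0.0` (so the two zeros match).
fn canonicalize_f64(value: f64) -> u64 {
    if value.is_nan() {
        0x7ff8_0000_0000_0000u64
    } else if value == 0.0 {
        0u64
    } else {
        value.to_bits()
    }
}
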
787#[derive(Debug, Clone)]
788pub struct IsMemberEvaluation {
789    mask: LogicalArray,
790    loc: Tensor,
791}
792
793impl IsMemberEvaluation {
794    fn new(mask: LogicalArray, loc: Tensor) -> Self {
795        Self { mask, loc }
796    }
797
798    pub fn from_provider_result(result: IsMemberResult) -> Result<Self, String> {
799        let mask = LogicalArray::new(result.mask.data, result.mask.shape)
800            .map_err(|e| format!("ismember: {e}"))?;
801        let loc =
802            Tensor::new(result.loc.data, result.loc.shape).map_err(|e| format!("ismember: {e}"))?;
803        Ok(IsMemberEvaluation::new(mask, loc))
804    }
805
806    pub fn into_numeric_ismember_result(self) -> Result<IsMemberResult, String> {
807        let IsMemberEvaluation { mask, loc } = self;
808        Ok(IsMemberResult {
809            mask: HostLogicalOwned {
810                data: mask.data,
811                shape: mask.shape,
812            },
813            loc: HostTensorOwned {
814                data: loc.data,
815                shape: loc.shape,
816            },
817        })
818    }
819
820    pub fn into_mask_value(self) -> Value {
821        logical_array_into_value(self.mask)
822    }
823
824    pub fn mask_value(&self) -> Value {
825        logical_array_into_value(self.mask.clone())
826    }
827
828    pub fn into_pair(self) -> (Value, Value) {
829        let mask = logical_array_into_value(self.mask);
830        let loc = tensor::tensor_into_value(self.loc);
831        (mask, loc)
832    }
833
834    pub fn loc_value(&self) -> Value {
835        tensor::tensor_into_value(self.loc.clone())
836    }
837}
838
839fn logical_array_into_value(logical: LogicalArray) -> Value {
840    if logical.data.len() == 1 {
841        Value::Bool(logical.data[0] != 0)
842    } else {
843        Value::LogicalArray(logical)
844    }
845}
846
847#[cfg(test)]
848mod tests {
849    use super::*;
850    use crate::builtins::common::test_support;
851    use runmat_builtins::Tensor;
852
853    #[cfg(feature = "wgpu")]
854    use runmat_accelerate_api::HostTensorView;
855
856    #[test]
857    fn numeric_membership_basic() {
858        let a = Tensor::new(vec![5.0, 7.0, 2.0, 7.0], vec![1, 4]).unwrap();
859        let b = Tensor::new(vec![7.0, 9.0, 5.0], vec![1, 3]).unwrap();
860        let eval = ismember_numeric_elements(a, b).expect("ismember");
861        assert_eq!(eval.mask.data, vec![1, 1, 0, 1]);
862        assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0, 1.0]);
863    }
864
865    #[test]
866    fn numeric_nan_membership() {
867        let a = Tensor::new(vec![f64::NAN, 1.0], vec![1, 2]).unwrap();
868        let b = Tensor::new(vec![f64::NAN, 2.0], vec![1, 2]).unwrap();
869        let eval = ismember_numeric_elements(a, b).expect("ismember");
870        assert_eq!(eval.mask.data, vec![1, 0]);
871        assert_eq!(eval.loc.data, vec![1.0, 0.0]);
872    }
873
874    #[test]
875    fn numeric_rows_membership() {
876        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
877        let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
878        let eval = ismember_numeric_rows(a, b).expect("ismember");
879        assert_eq!(eval.mask.data, vec![1, 1, 1]);
880        assert_eq!(eval.loc.data, vec![3.0, 1.0, 3.0]);
881        assert_eq!(eval.loc.shape, vec![3, 1]);
882    }
883
    #[test]
    fn complex_membership() {
        let a = ComplexTensor::new(vec![(1.0, 2.0), (0.0, 0.0)], vec![1, 2]).unwrap();
        let b = ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0)], vec![1, 2]).unwrap();
        let eval = ismember_complex_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![2.0, 1.0]);
    }

    #[test]
    fn complex_rows_membership() {
        // Column-major data: rows of `a` are [1+1i 2] and [3 4+4i],
        // which appear as rows 1 and 3 of `b`.
        let a = ComplexTensor::new(
            vec![(1.0, 1.0), (3.0, 0.0), (2.0, 0.0), (4.0, 4.0)],
            vec![2, 2],
        )
        .unwrap();
        let b = ComplexTensor::new(
            vec![
                (1.0, 1.0),
                (5.0, 0.0),
                (3.0, 0.0),
                (2.0, 0.0),
                (6.0, 0.0),
                (4.0, 4.0),
            ],
            vec![3, 2],
        )
        .unwrap();
        let eval = ismember_complex_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
    }

    #[test]
    fn char_membership() {
        // CharArray data is stored row by row, so `a` is ['r' 'u'; 'n' 'm'] and
        // `b` is ['m' 'a'; 'r' 'u']. Results follow column-major order: r, n, u, m.
        let a = CharArray::new(vec!['r', 'u', 'n', 'm'], 2, 2).unwrap();
        let b = CharArray::new(vec!['m', 'a', 'r', 'u'], 2, 2).unwrap();
        let eval = ismember_char_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 0, 1, 1]);
        assert_eq!(eval.loc.data, vec![2.0, 0.0, 4.0, 1.0]);
    }

    #[test]
    fn char_rows_membership() {
        // Row-major char data: rows of `a` are "ma" and "tl";
        // rows of `b` are "ma", "ge", and "tl".
        let a = CharArray::new(vec!['m', 'a', 't', 'l'], 2, 2).unwrap();
        let b = CharArray::new(vec!['m', 'a', 'g', 'e', 't', 'l'], 3, 2).unwrap();
        let eval = ismember_char_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
    }

    #[test]
    fn string_membership() {
        let a = StringArray::new(
            vec![
                "apple".to_string(),
                "pear".to_string(),
                "banana".to_string(),
            ],
            vec![1, 3],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "pear".to_string(),
                "orange".to_string(),
                "apple".to_string(),
            ],
            vec![1, 3],
        )
        .unwrap();
        let eval = ismember_string_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1, 0]);
        assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0]);
    }

    #[test]
    fn string_rows_membership() {
        // StringArray data is column-major: rows of `a` are ["alpha" "beta"] and
        // ["gamma" "delta"], which appear as rows 1 and 3 of `b`.
        let a = StringArray::new(
            vec![
                "alpha".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "delta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "alpha".to_string(),
                "theta".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "eta".to_string(),
                "delta".to_string(),
            ],
            vec![3, 2],
        )
        .unwrap();
        let eval = ismember_string_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
    }

    #[test]
    fn options_reject_legacy() {
        let err = parse_options(&[Value::from("legacy")]).unwrap_err();
        assert!(err.contains("legacy"));
    }

    #[test]
    fn rejects_unknown_option() {
        let err = evaluate(Value::Num(1.0), Value::Num(1.0), &[Value::from("stable")]).unwrap_err();
        assert!(err.contains("unrecognised option"));
    }

    #[test]
    fn ismember_runtime_numeric() {
        let a = Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap());
        let b = Value::Tensor(Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap());
        let (mask, loc) = evaluate(a, b, &[]).unwrap().into_pair();
        match mask {
            Value::LogicalArray(arr) => assert_eq!(arr.data, vec![1, 0, 1]),
            other => panic!("expected logical array, got {other:?}"),
        }
        match loc {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 0.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn logical_inputs_promoted() {
        let a = Value::Bool(true);
        let logical_b =
            LogicalArray::new(vec![1, 0], vec![2, 1]).expect("logical array construction");
        let eval = evaluate(a, Value::LogicalArray(logical_b), &[]).expect("ismember");
        assert_eq!(eval.mask_value(), Value::Bool(true));
        assert_eq!(eval.loc_value(), Value::Num(1.0));
    }

    #[test]
    fn ismember_rows_shape_checks() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        assert!(ismember_numeric_rows(a.clone(), b.clone()).is_ok());
        let bad = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let err = ismember_numeric_rows(a, bad).unwrap_err();
        assert!(err.contains("same number of columns"));
    }

    #[test]
    fn ismember_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
            let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
            let view_a = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let view_b = runmat_accelerate_api::HostTensorView {
                data: &set.data,
                shape: &set.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload a");
            let handle_b = provider.upload(&view_b).expect("upload b");
            let eval = evaluate(
                Value::GpuTensor(handle_a.clone()),
                Value::GpuTensor(handle_b.clone()),
                &[],
            )
            .expect("ismember");
            assert_eq!(eval.mask.data, vec![0, 1, 0, 1]);
            assert_eq!(eval.loc.data, vec![0.0, 1.0, 0.0, 1.0]);
            // Release the uploaded buffers once the result has been checked.
            let _ = provider.free(&handle_a);
            let _ = provider.free(&handle_b);
        });
    }

    #[test]
    fn ismember_gpu_rows_roundtrip() {
        test_support::with_test_provider(|provider| {
            let rows = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
            let bank = Tensor::new(vec![1.0, 5.0, 3.0, 2.0, 6.0, 4.0], vec![3, 2]).unwrap();
            let view_a = runmat_accelerate_api::HostTensorView {
                data: &rows.data,
                shape: &rows.shape,
            };
            let view_b = runmat_accelerate_api::HostTensorView {
                data: &bank.data,
                shape: &bank.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload a");
            let handle_b = provider.upload(&view_b).expect("upload b");
            let eval = evaluate(
                Value::GpuTensor(handle_a.clone()),
                Value::GpuTensor(handle_b.clone()),
                &[Value::from("rows")],
            )
            .expect("ismember");
            assert_eq!(eval.mask.data, vec![1, 1]);
            assert_eq!(eval.loc.data, vec![1.0, 3.0]);
            let _ = provider.free(&handle_a);
            let _ = provider.free(&handle_b);
        });
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn ismember_wgpu_numeric_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
        let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let cpu_eval =
            ismember_numeric_from_tensors(tensor.clone(), set.clone(), false).expect("cpu");

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let view_b = HostTensorView {
            data: &set.data,
            shape: &set.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload a");
        let handle_b = provider.upload(&view_b).expect("upload b");

        let eval = evaluate(
            Value::GpuTensor(handle_a.clone()),
            Value::GpuTensor(handle_b.clone()),
            &[],
        )
        .expect("gpu evaluate");
        assert_eq!(eval.mask.data, cpu_eval.mask.data);
        assert_eq!(eval.loc.data, cpu_eval.loc.data);

        let _ = provider.free(&handle_a);
        let _ = provider.free(&handle_b);

        let matrix = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let bank = Tensor::new(vec![1.0, 7.0, 3.0, 2.0, 9.0, 4.0], vec![3, 2]).unwrap();
        let cpu_rows =
            ismember_numeric_from_tensors(matrix.clone(), bank.clone(), true).expect("cpu rows");
        let view_matrix = HostTensorView {
            data: &matrix.data,
            shape: &matrix.shape,
        };
        let view_bank = HostTensorView {
            data: &bank.data,
            shape: &bank.shape,
        };
        let handle_matrix = provider.upload(&view_matrix).expect("upload matrix");
        let handle_bank = provider.upload(&view_bank).expect("upload bank");
        let eval_rows = evaluate(
            Value::GpuTensor(handle_matrix.clone()),
            Value::GpuTensor(handle_bank.clone()),
            &[Value::from("rows")],
        )
        .expect("gpu rows evaluate");
        assert_eq!(eval_rows.mask.data, cpu_rows.mask.data);
        assert_eq!(eval_rows.loc.data, cpu_rows.loc.data);
        let _ = provider.free(&handle_matrix);
        let _ = provider.free(&handle_bank);
    }

    #[test]
    fn scalar_return_is_bool() {
        let a = Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());
        let b = Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());
        let mask = evaluate(a, b, &[]).unwrap().into_mask_value();
        assert_eq!(mask, Value::Bool(true));
    }

    #[test]
    fn parse_rows_option() {
        let opts = parse_options(&[Value::from("rows")]).unwrap();
        assert!(opts.rows);
    }

    #[test]
    fn numeric_rows_with_nan() {
        let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
        let b = Tensor::new(vec![f64::NAN, 2.0], vec![2, 1]).unwrap();
        let eval = ismember_numeric_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 0]);
        assert_eq!(eval.loc.data, vec![1.0, 0.0]);
    }

    #[cfg(feature = "doc_export")]
    #[test]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}