1use std::collections::HashMap;
4
5use runmat_accelerate_api::{
6 GpuTensorHandle, HostLogicalOwned, HostTensorOwned, IsMemberOptions as ProviderIsMemberOptions,
7 IsMemberResult,
8};
9use runmat_builtins::{CharArray, ComplexTensor, LogicalArray, StringArray, Tensor, Value};
10use runmat_macros::runtime_builtin;
11
12use crate::builtins::common::gpu_helpers;
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18#[cfg(feature = "doc_export")]
19use crate::register_builtin_doc_text;
20use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
21
22#[cfg(feature = "doc_export")]
23pub const DOC_MD: &str = r#"---
24title: "ismember"
25category: "array/sorting_sets"
26keywords: ["ismember", "membership", "set", "rows", "indices", "gpu"]
27summary: "Identify array elements or rows that appear in another array while returning first-match indices."
28references:
29 - https://www.mathworks.com/help/matlab/ref/ismember.html
30gpu_support:
31 elementwise: false
32 reduction: false
33 precisions: ["f32", "f64"]
34 broadcasting: "none"
35 notes: "When providers lack a dedicated membership hook RunMat gathers GPU tensors and executes the host implementation."
36fusion:
37 elementwise: false
38 reduction: false
39 max_inputs: 2
40 constants: "inline"
41requires_feature: null
42tested:
43 unit: "builtins::array::sorting_sets::ismember::tests"
44 integration: "builtins::array::sorting_sets::ismember::tests::ismember_gpu_roundtrip"
45---
46
47# What does the `ismember` function do in MATLAB / RunMat?
48`ismember(A, B)` compares the elements (or rows) of `A` against `B` and returns a logical array
49marking which members of `A` are present in `B`. The optional second output reports the index in `B`
50of the first matched element. RunMat follows MATLAB semantics for numeric, logical, complex, string,
51and character arrays.
52
53## How does the `ismember` function behave in MATLAB / RunMat?
54- The first output `tf` has the same shape as `A` (or `size(A,1) × 1` when using `'rows'`).
55- The optional second output `loc` contains one-based indices into `B`, with `0` for values that are
56 not found.
57- Duplicate values in `A` return the index of the first occurrence in `B` every time they match.
58- `NaN` values are treated as identical so they match other `NaN` entries in `B`.
59- Character arrays follow column-major linear indexing, mirroring MATLAB.
60- The `'rows'` option compares complete rows; inputs must agree on the number of columns.
61- Legacy flags (`'legacy'`, `'R2012a'`) are deliberately unsupported in RunMat.
62
63## `ismember` Function GPU Execution Behaviour
64When either input is a GPU tensor, RunMat first checks whether the active acceleration provider
65exposes a custom `ismember` hook. Until providers implement that hook, the runtime transparently
66gathers GPU operands to host memory, performs the membership lookup using the CPU implementation,
67and returns host-resident outputs so results exactly match MATLAB.
68
69## GPU residency in RunMat (Do I need `gpuArray`?)
70
71Most code does not need to call `gpuArray` explicitly. The native auto-offload planner keeps track
72of residency and recognises that `ismember` is a sink: the operation produces logical outputs and
73one-based indices that currently live on the host. If an acceleration provider exposes a full
74`ismember` hook in the future, the planner can keep data on the device automatically. Until then,
75manual `gpuArray` / `gather` calls only serve to mirror MATLAB workflows; RunMat already performs
76the necessary transfers when it detects that tensors reside on the GPU.
77
78## Examples of using the `ismember` function in MATLAB / RunMat
79
80### Checking membership of numeric vectors
81```matlab
82A = [5 7 2 7];
83B = [7 9 5];
84[tf, loc] = ismember(A, B);
85```
86Expected output:
87```matlab
88tf =
89 1 1 0 1
90loc =
91 3 1 0 1
92```
93
94### Finding row membership in a matrix
95```matlab
96A = [1 2; 3 4; 1 2];
97B = [3 4; 5 6; 1 2];
98[tf, loc] = ismember(A, B, 'rows');
99```
100Expected output:
101```matlab
102tf =
103 1
104 1
105 1
106loc =
107 3
108 1
109 3
110```
111
112### Locating values and retrieving the index
113```matlab
114values = [10 20 30];
115set = [30 10 40];
116[tf, loc] = ismember(values, set);
117```
118Expected output:
119```matlab
120tf =
121 1 0 1
122loc =
123 2 0 1
124```
125
126### Testing characters against a set
127```matlab
128chars = ['r','u'; 'n','m'];
129set = ['m','a'; 'r','u'];
130[tf, loc] = ismember(chars, set);
131```
132Expected output:
133```matlab
tf =
     1     1
     0     1
loc =
     2     4
     0     1
140```
141
142### Working with string arrays
143```matlab
144A = ["apple" "pear" "banana"];
145B = ["pear" "orange" "apple"];
146[tf, loc] = ismember(A, B);
147```
148Expected output:
149```matlab
150tf =
151 1×3 logical array
152 1 1 0
153loc =
154 1×3 double
155 3 1 0
156```
157
158### Using `ismember` with `gpuArray` inputs
159```matlab
160G = gpuArray([1 4 2 4]);
161H = gpuArray([4 5]);
162[tf, loc] = ismember(G, H);
163```
164Expected output (RunMat gathers to host unless a provider implements `ismember`):
165```matlab
166tf =
167 0 1 0 1
168loc =
169 0 1 0 1
170```
171
172## FAQ
173
174### Does `ismember` treat `NaN` values as equal?
175Yes. `NaN` values compare equal for membership tests so every `NaN` in `A` matches any `NaN` in `B`.
176
177### What happens when an element of `A` is not found in `B`?
178The corresponding logical entry is `false` and the index output stores `0`, matching MATLAB.
179
180### Can I use `ismember` with string arrays and character arrays?
181Yes. String arrays, scalar strings, and character arrays are supported. Mixed string/char inputs
182should be normalised (for example, convert scalars with `string`).
183
184### How does the `'rows'` option change the output shape?
185`'rows'` compares entire rows and returns outputs of size `size(A,1) × 1`, regardless of how many
186columns the input matrices contain.
187
188### Are the legacy flags supported?
189No. RunMat only implements modern MATLAB semantics. Passing `'legacy'` or `'R2012a'` raises an
190error, just like other set builtins in RunMat.
191
192### Will `ismember` run on the GPU automatically?
193If the active provider advertises an `ismember` hook, the runtime can keep tensors on the device.
194Otherwise the data is gathered to the host with no behavioural differences.
195
196## See Also
197[unique](./unique), [intersect](./intersect), [setdiff](./setdiff), [union](./union), [gpuArray](../../acceleration/gpu/gpuArray), [gather](../../acceleration/gpu/gather)
198
199## Source & Feedback
200- Source code: [`crates/runmat-runtime/src/builtins/array/sorting_sets/ismember.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/sorting_sets/ismember.rs)
201- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
202"#;
203
/// GPU metadata describing how `ismember` interacts with acceleration providers.
///
/// The spec advertises a custom `"ismember"` provider hook and an immediate-gather residency
/// policy; the `notes` field records the host fallback used when no provider kernel exists.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ismember",
    op_kind: GpuOpKind::Custom("ismember"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("ismember")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers may supply dedicated membership kernels; until then RunMat gathers GPU tensors to host memory.",
};

// Make the spec discoverable through the builtin GPU spec registry.
register_builtin_gpu_spec!(GPU_SPEC);
220
/// Fusion metadata for `ismember`.
///
/// No elementwise or reduction fusion template is provided (`elementwise` and `reduction` are
/// both `None`); the `notes` field describes the builtin as a residency sink.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ismember",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Membership queries execute via host set lookups; the fusion planner treats ismember as a residency sink.",
};

// Make the spec discoverable through the builtin fusion spec registry.
register_builtin_fusion_spec!(FUSION_SPEC);

// The Markdown documentation is only registered when docs are being exported.
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("ismember", DOC_MD);
235
236#[runtime_builtin(
237 name = "ismember",
238 category = "array/sorting_sets",
239 summary = "Identify array elements or rows that appear in another array while returning first-match indices.",
240 keywords = "ismember,membership,set,rows,indices,gpu",
241 accel = "array_construct",
242 sink = true
243)]
244fn ismember_builtin(a: Value, b: Value, rest: Vec<Value>) -> Result<Value, String> {
245 evaluate(a, b, &rest).map(|eval| eval.into_mask_value())
246}
247
248pub fn evaluate(a: Value, b: Value, rest: &[Value]) -> Result<IsMemberEvaluation, String> {
250 let opts = parse_options(rest)?;
251 match (a, b) {
252 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
253 ismember_gpu_pair(handle_a, handle_b, &opts)
254 }
255 (Value::GpuTensor(handle_a), other) => ismember_gpu_mixed(handle_a, other, &opts, true),
256 (other, Value::GpuTensor(handle_b)) => ismember_gpu_mixed(handle_b, other, &opts, false),
257 (left, right) => ismember_host(left, right, &opts),
258 }
259}
260
261#[derive(Debug, Clone, Copy)]
262struct IsMemberOptions {
263 rows: bool,
264}
265
266impl IsMemberOptions {
267 fn into_provider_options(self) -> ProviderIsMemberOptions {
268 ProviderIsMemberOptions { rows: self.rows }
269 }
270}
271
272fn parse_options(rest: &[Value]) -> Result<IsMemberOptions, String> {
273 let mut opts = IsMemberOptions { rows: false };
274 for arg in rest {
275 let text = tensor::value_to_string(arg)
276 .ok_or_else(|| "ismember: expected string option arguments".to_string())?;
277 let lowered = text.trim().to_ascii_lowercase();
278 match lowered.as_str() {
279 "rows" => opts.rows = true,
280 "legacy" | "r2012a" => {
281 return Err("ismember: the 'legacy' behaviour is not supported".to_string())
282 }
283 other => return Err(format!("ismember: unrecognised option '{other}'")),
284 }
285 }
286 Ok(opts)
287}
288
289fn ismember_gpu_pair(
290 handle_a: GpuTensorHandle,
291 handle_b: GpuTensorHandle,
292 opts: &IsMemberOptions,
293) -> Result<IsMemberEvaluation, String> {
294 if let Some(provider) = runmat_accelerate_api::provider() {
295 let provider_opts = opts.into_provider_options();
296 match provider.ismember(&handle_a, &handle_b, &provider_opts) {
297 Ok(result) => return IsMemberEvaluation::from_provider_result(result),
298 Err(_) => {
299 }
301 }
302 }
303 let tensor_a = gpu_helpers::gather_tensor(&handle_a)?;
304 let tensor_b = gpu_helpers::gather_tensor(&handle_b)?;
305 ismember_numeric_tensors(tensor_a, tensor_b, opts)
306}
307
308fn ismember_gpu_mixed(
309 handle_gpu: GpuTensorHandle,
310 other: Value,
311 opts: &IsMemberOptions,
312 gpu_is_a: bool,
313) -> Result<IsMemberEvaluation, String> {
314 let tensor_gpu = gpu_helpers::gather_tensor(&handle_gpu)?;
315 if gpu_is_a {
316 ismember_host(Value::Tensor(tensor_gpu), other, opts)
317 } else {
318 ismember_host(other, Value::Tensor(tensor_gpu), opts)
319 }
320}
321
322fn ismember_host(a: Value, b: Value, opts: &IsMemberOptions) -> Result<IsMemberEvaluation, String> {
323 match (a, b) {
324 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => ismember_complex(at, bt, opts.rows),
325 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
326 let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
327 .map_err(|e| format!("ismember: {e}"))?;
328 ismember_complex(at, bt, opts.rows)
329 }
330 (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
331 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
332 .map_err(|e| format!("ismember: {e}"))?;
333 ismember_complex(at, bt, opts.rows)
334 }
335 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
336 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
337 .map_err(|e| format!("ismember: {e}"))?;
338 let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
339 .map_err(|e| format!("ismember: {e}"))?;
340 ismember_complex(at, bt, opts.rows)
341 }
342
343 (Value::CharArray(ac), Value::CharArray(bc)) => ismember_char(ac, bc, opts.rows),
344
345 (Value::StringArray(astring), Value::StringArray(bstring)) => {
346 ismember_string(astring, bstring, opts.rows)
347 }
348 (Value::StringArray(astring), Value::String(b)) => {
349 let bstring =
350 StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("ismember: {e}"))?;
351 ismember_string(astring, bstring, opts.rows)
352 }
353 (Value::String(a), Value::StringArray(bstring)) => {
354 let astring =
355 StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("ismember: {e}"))?;
356 ismember_string(astring, bstring, opts.rows)
357 }
358 (Value::String(a), Value::String(b)) => {
359 let astring =
360 StringArray::new(vec![a], vec![1, 1]).map_err(|e| format!("ismember: {e}"))?;
361 let bstring =
362 StringArray::new(vec![b], vec![1, 1]).map_err(|e| format!("ismember: {e}"))?;
363 ismember_string(astring, bstring, opts.rows)
364 }
365
366 (left, right) => {
367 let tensor_a = tensor::value_into_tensor_for("ismember", left)?;
368 let tensor_b = tensor::value_into_tensor_for("ismember", right)?;
369 ismember_numeric_tensors(tensor_a, tensor_b, opts)
370 }
371 }
372}
373
374fn ismember_numeric_tensors(
375 a: Tensor,
376 b: Tensor,
377 opts: &IsMemberOptions,
378) -> Result<IsMemberEvaluation, String> {
379 if opts.rows {
380 ismember_numeric_rows(a, b)
381 } else {
382 ismember_numeric_elements(a, b)
383 }
384}
385
386pub fn ismember_numeric_from_tensors(
388 a: Tensor,
389 b: Tensor,
390 rows: bool,
391) -> Result<IsMemberEvaluation, String> {
392 let opts = IsMemberOptions { rows };
393 ismember_numeric_tensors(a, b, &opts)
394}
395
396fn ismember_numeric_elements(a: Tensor, b: Tensor) -> Result<IsMemberEvaluation, String> {
397 let mut map: HashMap<u64, usize> = HashMap::new();
398 for (idx, &value) in b.data.iter().enumerate() {
399 map.entry(canonicalize_f64(value)).or_insert(idx + 1);
400 }
401
402 let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
403 let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
404
405 for &value in &a.data {
406 let key = canonicalize_f64(value);
407 if let Some(&pos) = map.get(&key) {
408 mask_data.push(1);
409 loc_data.push(pos as f64);
410 } else {
411 mask_data.push(0);
412 loc_data.push(0.0);
413 }
414 }
415
416 let logical = LogicalArray::new(mask_data, a.shape.clone())?;
417 let loc_tensor =
418 Tensor::new(loc_data, a.shape.clone()).map_err(|e| format!("ismember: {e}"))?;
419 Ok(IsMemberEvaluation::new(logical, loc_tensor))
420}
421
422fn ismember_numeric_rows(a: Tensor, b: Tensor) -> Result<IsMemberEvaluation, String> {
423 let (rows_a, cols_a) = tensor_rows_cols(&a, "ismember")?;
424 let (rows_b, cols_b) = tensor_rows_cols(&b, "ismember")?;
425 if cols_a != cols_b {
426 return Err(
427 "ismember: inputs must have the same number of columns when using 'rows'".to_string(),
428 );
429 }
430
431 let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
432 for r in 0..rows_b {
433 let mut row_values = Vec::with_capacity(cols_b);
434 for c in 0..cols_b {
435 let idx = r + c * rows_b;
436 row_values.push(b.data[idx]);
437 }
438 let key = NumericRowKey::from_slice(&row_values);
439 map.entry(key).or_insert(r + 1);
440 }
441
442 let mut mask_data = vec![0u8; rows_a];
443 let mut loc_data = vec![0.0f64; rows_a];
444
445 for r in 0..rows_a {
446 let mut row_values = Vec::with_capacity(cols_a);
447 for c in 0..cols_a {
448 let idx = r + c * rows_a;
449 row_values.push(a.data[idx]);
450 }
451 let key = NumericRowKey::from_slice(&row_values);
452 if let Some(&pos) = map.get(&key) {
453 mask_data[r] = 1;
454 loc_data[r] = pos as f64;
455 }
456 }
457
458 let shape = vec![rows_a, 1];
459 let logical = LogicalArray::new(mask_data, shape.clone())?;
460 let loc_tensor = Tensor::new(loc_data, shape).map_err(|e| format!("ismember: {e}"))?;
461 Ok(IsMemberEvaluation::new(logical, loc_tensor))
462}
463
464fn ismember_complex(
465 a: ComplexTensor,
466 b: ComplexTensor,
467 rows: bool,
468) -> Result<IsMemberEvaluation, String> {
469 if rows {
470 ismember_complex_rows(a, b)
471 } else {
472 ismember_complex_elements(a, b)
473 }
474}
475
476fn ismember_complex_elements(
477 a: ComplexTensor,
478 b: ComplexTensor,
479) -> Result<IsMemberEvaluation, String> {
480 let mut map: HashMap<ComplexKey, usize> = HashMap::new();
481 for (idx, &value) in b.data.iter().enumerate() {
482 map.entry(ComplexKey::new(value)).or_insert(idx + 1);
483 }
484
485 let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
486 let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
487
488 for &value in &a.data {
489 let key = ComplexKey::new(value);
490 if let Some(&pos) = map.get(&key) {
491 mask_data.push(1);
492 loc_data.push(pos as f64);
493 } else {
494 mask_data.push(0);
495 loc_data.push(0.0);
496 }
497 }
498
499 let logical = LogicalArray::new(mask_data, a.shape.clone())?;
500 let loc_tensor =
501 Tensor::new(loc_data, a.shape.clone()).map_err(|e| format!("ismember: {e}"))?;
502 Ok(IsMemberEvaluation::new(logical, loc_tensor))
503}
504
505fn ismember_complex_rows(a: ComplexTensor, b: ComplexTensor) -> Result<IsMemberEvaluation, String> {
506 let (rows_a, cols_a) = complex_rows_cols(&a)?;
507 let (rows_b, cols_b) = complex_rows_cols(&b)?;
508 if cols_a != cols_b {
509 return Err(
510 "ismember: complex inputs must have the same number of columns when using 'rows'"
511 .to_string(),
512 );
513 }
514
515 let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
516 for r in 0..rows_b {
517 let mut row_keys = Vec::with_capacity(cols_b);
518 for c in 0..cols_b {
519 let idx = r + c * rows_b;
520 row_keys.push(ComplexKey::new(b.data[idx]));
521 }
522 map.entry(row_keys).or_insert(r + 1);
523 }
524
525 let mut mask_data = vec![0u8; rows_a];
526 let mut loc_data = vec![0.0f64; rows_a];
527
528 for r in 0..rows_a {
529 let mut row_keys = Vec::with_capacity(cols_a);
530 for c in 0..cols_a {
531 let idx = r + c * rows_a;
532 row_keys.push(ComplexKey::new(a.data[idx]));
533 }
534 if let Some(&pos) = map.get(&row_keys) {
535 mask_data[r] = 1;
536 loc_data[r] = pos as f64;
537 }
538 }
539
540 let shape = vec![rows_a, 1];
541 let logical = LogicalArray::new(mask_data, shape.clone())?;
542 let loc_tensor = Tensor::new(loc_data, shape).map_err(|e| format!("ismember: {e}"))?;
543 Ok(IsMemberEvaluation::new(logical, loc_tensor))
544}
545
546fn ismember_char(a: CharArray, b: CharArray, rows: bool) -> Result<IsMemberEvaluation, String> {
547 if rows {
548 ismember_char_rows(a, b)
549 } else {
550 ismember_char_elements(a, b)
551 }
552}
553
554fn ismember_char_elements(a: CharArray, b: CharArray) -> Result<IsMemberEvaluation, String> {
555 let rows_b = b.rows;
556 let cols_b = b.cols;
557 let mut map: HashMap<char, usize> = HashMap::new();
558
559 for col in 0..cols_b {
560 for row in 0..rows_b {
561 let data_idx = row * cols_b + col;
562 let ch = b.data[data_idx];
563 let linear_idx = row + col * rows_b;
564 map.entry(ch).or_insert(linear_idx + 1);
565 }
566 }
567
568 let rows_a = a.rows;
569 let cols_a = a.cols;
570 let mut mask_data = vec![0u8; rows_a * cols_a];
571 let mut loc_data = vec![0.0f64; rows_a * cols_a];
572
573 for col in 0..cols_a {
574 for row in 0..rows_a {
575 let data_idx = row * cols_a + col;
576 let ch = a.data[data_idx];
577 let linear_idx = row + col * rows_a;
578 if let Some(&pos) = map.get(&ch) {
579 mask_data[linear_idx] = 1;
580 loc_data[linear_idx] = pos as f64;
581 }
582 }
583 }
584
585 let shape = vec![rows_a, cols_a];
586 let logical = LogicalArray::new(mask_data, shape.clone())?;
587 let loc_tensor = Tensor::new(loc_data, shape).map_err(|e| format!("ismember: {e}"))?;
588 Ok(IsMemberEvaluation::new(logical, loc_tensor))
589}
590
591fn ismember_char_rows(a: CharArray, b: CharArray) -> Result<IsMemberEvaluation, String> {
592 if a.cols != b.cols {
593 return Err(
594 "ismember: character inputs must have the same number of columns when using 'rows'"
595 .to_string(),
596 );
597 }
598
599 let rows_b = b.rows;
600 let cols = b.cols;
601 let mut map: HashMap<RowCharKey, usize> = HashMap::new();
602
603 for r in 0..rows_b {
604 let mut row_values = Vec::with_capacity(cols);
605 for c in 0..cols {
606 let idx = r * cols + c;
607 row_values.push(b.data[idx]);
608 }
609 let key = RowCharKey::from_slice(&row_values);
610 map.entry(key).or_insert(r + 1);
611 }
612
613 let rows_a = a.rows;
614 let mut mask_data = vec![0u8; rows_a];
615 let mut loc_data = vec![0.0f64; rows_a];
616
617 for r in 0..rows_a {
618 let mut row_values = Vec::with_capacity(cols);
619 for c in 0..cols {
620 let idx = r * cols + c;
621 row_values.push(a.data[idx]);
622 }
623 let key = RowCharKey::from_slice(&row_values);
624 if let Some(&pos) = map.get(&key) {
625 mask_data[r] = 1;
626 loc_data[r] = pos as f64;
627 }
628 }
629
630 let shape = vec![rows_a, 1];
631 let logical = LogicalArray::new(mask_data, shape.clone())?;
632 let loc_tensor = Tensor::new(loc_data, shape).map_err(|e| format!("ismember: {e}"))?;
633 Ok(IsMemberEvaluation::new(logical, loc_tensor))
634}
635
636fn ismember_string(
637 a: StringArray,
638 b: StringArray,
639 rows: bool,
640) -> Result<IsMemberEvaluation, String> {
641 if rows {
642 ismember_string_rows(a, b)
643 } else {
644 ismember_string_elements(a, b)
645 }
646}
647
648fn ismember_string_elements(a: StringArray, b: StringArray) -> Result<IsMemberEvaluation, String> {
649 let mut map: HashMap<String, usize> = HashMap::new();
650 for (idx, value) in b.data.iter().enumerate() {
651 map.entry(value.clone()).or_insert(idx + 1);
652 }
653
654 let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
655 let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
656
657 for value in &a.data {
658 if let Some(&pos) = map.get(value) {
659 mask_data.push(1);
660 loc_data.push(pos as f64);
661 } else {
662 mask_data.push(0);
663 loc_data.push(0.0);
664 }
665 }
666
667 let logical = LogicalArray::new(mask_data, a.shape.clone())?;
668 let loc_tensor =
669 Tensor::new(loc_data, a.shape.clone()).map_err(|e| format!("ismember: {e}"))?;
670 Ok(IsMemberEvaluation::new(logical, loc_tensor))
671}
672
673fn ismember_string_rows(a: StringArray, b: StringArray) -> Result<IsMemberEvaluation, String> {
674 if a.shape.len() != 2 || b.shape.len() != 2 {
675 return Err("ismember: 'rows' option requires 2-D string arrays".to_string());
676 }
677 if a.shape[1] != b.shape[1] {
678 return Err(
679 "ismember: string inputs must have the same number of columns when using 'rows'"
680 .to_string(),
681 );
682 }
683
684 let rows_a = a.shape[0];
685 let cols = a.shape[1];
686 let rows_b = b.shape[0];
687
688 let mut map: HashMap<RowStringKey, usize> = HashMap::new();
689 for r in 0..rows_b {
690 let mut row_values = Vec::with_capacity(cols);
691 for c in 0..cols {
692 let idx = r + c * rows_b;
693 row_values.push(b.data[idx].clone());
694 }
695 let key = RowStringKey(row_values);
696 map.entry(key).or_insert(r + 1);
697 }
698
699 let mut mask_data = vec![0u8; rows_a];
700 let mut loc_data = vec![0.0f64; rows_a];
701
702 for r in 0..rows_a {
703 let mut row_values = Vec::with_capacity(cols);
704 for c in 0..cols {
705 let idx = r + c * rows_a;
706 row_values.push(a.data[idx].clone());
707 }
708 let key = RowStringKey(row_values);
709 if let Some(&pos) = map.get(&key) {
710 mask_data[r] = 1;
711 loc_data[r] = pos as f64;
712 }
713 }
714
715 let shape = vec![rows_a, 1];
716 let logical = LogicalArray::new(mask_data, shape.clone())?;
717 let loc_tensor = Tensor::new(loc_data, shape).map_err(|e| format!("ismember: {e}"))?;
718 Ok(IsMemberEvaluation::new(logical, loc_tensor))
719}
720
721fn tensor_rows_cols(t: &Tensor, name: &str) -> Result<(usize, usize), String> {
722 match t.shape.len() {
723 0 => Ok((1, 1)),
724 1 => Ok((t.shape[0], 1)),
725 2 => Ok((t.shape[0], t.shape[1])),
726 _ => Err(format!(
727 "{name}: 'rows' option requires 2-D numeric matrices"
728 )),
729 }
730}
731
732fn complex_rows_cols(t: &ComplexTensor) -> Result<(usize, usize), String> {
733 match t.shape.len() {
734 0 => Ok((1, 1)),
735 1 => Ok((t.shape[0], 1)),
736 2 => Ok((t.shape[0], t.shape[1])),
737 _ => Err("ismember: 'rows' option requires 2-D complex matrices".to_string()),
738 }
739}
740
741#[derive(Debug, Clone, PartialEq, Eq, Hash)]
742struct NumericRowKey(Vec<u64>);
743
744impl NumericRowKey {
745 fn from_slice(values: &[f64]) -> Self {
746 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
747 }
748}
749
750#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
751struct ComplexKey {
752 re: u64,
753 im: u64,
754}
755
756impl ComplexKey {
757 fn new(value: (f64, f64)) -> Self {
758 Self {
759 re: canonicalize_f64(value.0),
760 im: canonicalize_f64(value.1),
761 }
762 }
763}
764
/// Hashable key for one row of characters, stored as Unicode scalar values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Builds a key from the characters of a row, preserving their order.
    fn from_slice(values: &[char]) -> Self {
        let code_points = values.iter().copied().map(u32::from).collect();
        Self(code_points)
    }
}
773
/// Hashable key for one row of strings, used by the row-wise string membership lookup.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RowStringKey(Vec<String>);
776
/// Maps an `f64` to a bit pattern suitable for use as a hash key.
///
/// Every NaN collapses to one fixed pattern and both signed zeros collapse to `0`, so NaNs
/// compare equal to each other and `-0.0` compares equal to `0.0`. All other values keep their
/// exact bit pattern.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN_BITS,
        // `-0.0 == 0.0` is true, so this arm covers both zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
786
787#[derive(Debug, Clone)]
788pub struct IsMemberEvaluation {
789 mask: LogicalArray,
790 loc: Tensor,
791}
792
793impl IsMemberEvaluation {
794 fn new(mask: LogicalArray, loc: Tensor) -> Self {
795 Self { mask, loc }
796 }
797
798 pub fn from_provider_result(result: IsMemberResult) -> Result<Self, String> {
799 let mask = LogicalArray::new(result.mask.data, result.mask.shape)
800 .map_err(|e| format!("ismember: {e}"))?;
801 let loc =
802 Tensor::new(result.loc.data, result.loc.shape).map_err(|e| format!("ismember: {e}"))?;
803 Ok(IsMemberEvaluation::new(mask, loc))
804 }
805
806 pub fn into_numeric_ismember_result(self) -> Result<IsMemberResult, String> {
807 let IsMemberEvaluation { mask, loc } = self;
808 Ok(IsMemberResult {
809 mask: HostLogicalOwned {
810 data: mask.data,
811 shape: mask.shape,
812 },
813 loc: HostTensorOwned {
814 data: loc.data,
815 shape: loc.shape,
816 },
817 })
818 }
819
820 pub fn into_mask_value(self) -> Value {
821 logical_array_into_value(self.mask)
822 }
823
824 pub fn mask_value(&self) -> Value {
825 logical_array_into_value(self.mask.clone())
826 }
827
828 pub fn into_pair(self) -> (Value, Value) {
829 let mask = logical_array_into_value(self.mask);
830 let loc = tensor::tensor_into_value(self.loc);
831 (mask, loc)
832 }
833
834 pub fn loc_value(&self) -> Value {
835 tensor::tensor_into_value(self.loc.clone())
836 }
837}
838
839fn logical_array_into_value(logical: LogicalArray) -> Value {
840 if logical.data.len() == 1 {
841 Value::Bool(logical.data[0] != 0)
842 } else {
843 Value::LogicalArray(logical)
844 }
845}
846
847#[cfg(test)]
848mod tests {
849 use super::*;
850 use crate::builtins::common::test_support;
851 use runmat_builtins::Tensor;
852
853 #[cfg(feature = "wgpu")]
854 use runmat_accelerate_api::HostTensorView;
855
856 #[test]
857 fn numeric_membership_basic() {
858 let a = Tensor::new(vec![5.0, 7.0, 2.0, 7.0], vec![1, 4]).unwrap();
859 let b = Tensor::new(vec![7.0, 9.0, 5.0], vec![1, 3]).unwrap();
860 let eval = ismember_numeric_elements(a, b).expect("ismember");
861 assert_eq!(eval.mask.data, vec![1, 1, 0, 1]);
862 assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0, 1.0]);
863 }
864
865 #[test]
866 fn numeric_nan_membership() {
867 let a = Tensor::new(vec![f64::NAN, 1.0], vec![1, 2]).unwrap();
868 let b = Tensor::new(vec![f64::NAN, 2.0], vec![1, 2]).unwrap();
869 let eval = ismember_numeric_elements(a, b).expect("ismember");
870 assert_eq!(eval.mask.data, vec![1, 0]);
871 assert_eq!(eval.loc.data, vec![1.0, 0.0]);
872 }
873
874 #[test]
875 fn numeric_rows_membership() {
876 let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
877 let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
878 let eval = ismember_numeric_rows(a, b).expect("ismember");
879 assert_eq!(eval.mask.data, vec![1, 1, 1]);
880 assert_eq!(eval.loc.data, vec![3.0, 1.0, 3.0]);
881 assert_eq!(eval.loc.shape, vec![3, 1]);
882 }
883
884 #[test]
885 fn complex_membership() {
886 let a = ComplexTensor::new(vec![(1.0, 2.0), (0.0, 0.0)], vec![1, 2]).unwrap();
887 let b = ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0)], vec![1, 2]).unwrap();
888 let eval = ismember_complex_elements(a, b).expect("ismember");
889 assert_eq!(eval.mask.data, vec![1, 1]);
890 assert_eq!(eval.loc.data, vec![2.0, 1.0]);
891 }
892
893 #[test]
894 fn complex_rows_membership() {
895 let a = ComplexTensor::new(
896 vec![(1.0, 1.0), (3.0, 0.0), (2.0, 0.0), (4.0, 4.0)],
897 vec![2, 2],
898 )
899 .unwrap();
900 let b = ComplexTensor::new(
901 vec![
902 (1.0, 1.0),
903 (5.0, 0.0),
904 (3.0, 0.0),
905 (2.0, 0.0),
906 (6.0, 0.0),
907 (4.0, 4.0),
908 ],
909 vec![3, 2],
910 )
911 .unwrap();
912 let eval = ismember_complex_rows(a, b).expect("ismember");
913 assert_eq!(eval.mask.data, vec![1, 1]);
914 assert_eq!(eval.loc.data, vec![1.0, 3.0]);
915 }
916
917 #[test]
918 fn char_membership() {
919 let a = CharArray::new(vec!['r', 'u', 'n', 'm'], 2, 2).unwrap();
920 let b = CharArray::new(vec!['m', 'a', 'r', 'u'], 2, 2).unwrap();
921 let eval = ismember_char_elements(a, b).expect("ismember");
922 assert_eq!(eval.mask.data, vec![1, 0, 1, 1]);
923 assert_eq!(eval.loc.data, vec![2.0, 0.0, 4.0, 1.0]);
924 }
925
926 #[test]
927 fn char_rows_membership() {
928 let a = CharArray::new(vec!['m', 'a', 't', 'l'], 2, 2).unwrap();
929 let b = CharArray::new(vec!['m', 'a', 'g', 'e', 't', 'l'], 3, 2).unwrap();
930 let eval = ismember_char_rows(a, b).expect("ismember");
931 assert_eq!(eval.mask.data, vec![1, 1]);
932 assert_eq!(eval.loc.data, vec![1.0, 3.0]);
933 }
934
935 #[test]
936 fn string_membership() {
937 let a = StringArray::new(
938 vec![
939 "apple".to_string(),
940 "pear".to_string(),
941 "banana".to_string(),
942 ],
943 vec![1, 3],
944 )
945 .unwrap();
946 let b = StringArray::new(
947 vec![
948 "pear".to_string(),
949 "orange".to_string(),
950 "apple".to_string(),
951 ],
952 vec![1, 3],
953 )
954 .unwrap();
955 let eval = ismember_string_elements(a, b).expect("ismember");
956 assert_eq!(eval.mask.data, vec![1, 1, 0]);
957 assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0]);
958 }
959
960 #[test]
961 fn string_rows_membership() {
962 let a = StringArray::new(
963 vec![
964 "alpha".to_string(),
965 "gamma".to_string(),
966 "beta".to_string(),
967 "delta".to_string(),
968 ],
969 vec![2, 2],
970 )
971 .unwrap();
972 let b = StringArray::new(
973 vec![
974 "alpha".to_string(),
975 "theta".to_string(),
976 "gamma".to_string(),
977 "beta".to_string(),
978 "eta".to_string(),
979 "delta".to_string(),
980 ],
981 vec![3, 2],
982 )
983 .unwrap();
984 let eval = ismember_string_rows(a, b).expect("ismember");
985 assert_eq!(eval.mask.data, vec![1, 1]);
986 assert_eq!(eval.loc.data, vec![1.0, 3.0]);
987 }
988
989 #[test]
990 fn options_reject_legacy() {
991 let err = parse_options(&[Value::from("legacy")]).unwrap_err();
992 assert!(err.contains("legacy"));
993 }
994
995 #[test]
996 fn rejects_unknown_option() {
997 let err = evaluate(Value::Num(1.0), Value::Num(1.0), &[Value::from("stable")]).unwrap_err();
998 assert!(err.contains("unrecognised option"));
999 }
1000
/// End-to-end numeric membership through `evaluate`: checks both the logical
/// mask and the 1-based first-match indices returned as runtime values.
#[test]
fn ismember_runtime_numeric() {
    let candidates = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
    let set = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();

    let outcome = evaluate(Value::Tensor(candidates), Value::Tensor(set), &[]).unwrap();
    let (mask, loc) = outcome.into_pair();

    // 1 and 3 are members; 2 is not.
    match mask {
        Value::LogicalArray(flags) => assert_eq!(flags.data, vec![1, 0, 1]),
        other => panic!("expected logical array, got {other:?}"),
    }
    // 1 sits at position 2 of the set, 3 at position 1; 0 marks "not found".
    match loc {
        Value::Tensor(indices) => assert_eq!(indices.data, vec![2.0, 0.0, 1.0]),
        other => panic!("expected tensor, got {other:?}"),
    }
}
1015
/// A scalar `true` looked up in a logical array must be found at index 1 and
/// produce scalar outputs.
#[test]
fn logical_inputs_promoted() {
    let set = LogicalArray::new(vec![1, 0], vec![2, 1]).expect("logical array construction");
    let eval =
        evaluate(Value::Bool(true), Value::LogicalArray(set), &[]).expect("ismember");

    assert_eq!(eval.mask_value(), Value::Bool(true));
    assert_eq!(eval.loc_value(), Value::Num(1.0));
}
1025
/// Row-wise numeric membership accepts operands with matching column counts
/// and rejects a mismatch with a "same number of columns" error.
#[test]
fn ismember_rows_shape_checks() {
    let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
    let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
    // Two columns against `a`'s single column.
    let two_columns = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

    // Both operands have one column, so this call must succeed.
    assert!(ismember_numeric_rows(a.clone(), b.clone()).is_ok());

    let err = ismember_numeric_rows(a, two_columns).unwrap_err();
    assert!(err.contains("same number of columns"));
}
1035
/// Element-wise membership with both operands resident on the test provider.
///
/// Uploads the operands, evaluates `ismember` on the resulting GPU handles and
/// checks the mask and first-match indices. The handles are released afterwards
/// so this test cleans up the same way the other GPU round-trip tests do.
#[test]
fn ismember_gpu_roundtrip() {
    test_support::with_test_provider(|provider| {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
        let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let view_a = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let view_b = runmat_accelerate_api::HostTensorView {
            data: &set.data,
            shape: &set.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload a");
        let handle_b = provider.upload(&view_b).expect("upload b");

        // Hand clones to `evaluate` so the original handles remain available
        // for the cleanup below.
        let eval = evaluate(
            Value::GpuTensor(handle_a.clone()),
            Value::GpuTensor(handle_b.clone()),
            &[],
        )
        .expect("ismember");

        // Only the two 4.0 entries are members; both resolve to set index 1.
        assert_eq!(eval.mask.data, vec![0, 1, 0, 1]);
        assert_eq!(eval.loc.data, vec![0.0, 1.0, 0.0, 1.0]);

        // Best-effort release of the uploaded buffers; a failure to free is
        // not what this test is checking, so the result is ignored.
        let _ = provider.free(&handle_a);
        let _ = provider.free(&handle_b);
    });
}
1057
/// Row-wise (`"rows"`) membership with both operands resident on the test
/// provider; the uploaded handles are released once the assertions pass.
#[test]
fn ismember_gpu_rows_roundtrip() {
    test_support::with_test_provider(|provider| {
        let queries = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let candidates =
            Tensor::new(vec![1.0, 5.0, 3.0, 2.0, 6.0, 4.0], vec![3, 2]).unwrap();

        let queries_gpu = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &queries.data,
                shape: &queries.shape,
            })
            .expect("upload a");
        let candidates_gpu = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &candidates.data,
                shape: &candidates.shape,
            })
            .expect("upload b");

        let options = [Value::from("rows")];
        let eval = evaluate(
            Value::GpuTensor(queries_gpu.clone()),
            Value::GpuTensor(candidates_gpu.clone()),
            &options,
        )
        .expect("ismember");

        // Both query rows are expected in the bank, at rows 1 and 3.
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![1.0, 3.0]);

        // Cleanup is best-effort; free errors are intentionally ignored.
        let _ = provider.free(&queries_gpu);
        let _ = provider.free(&candidates_gpu);
    });
}
1085
/// Cross-checks the wgpu provider against the host implementation, first for
/// element-wise membership and then for the `"rows"` variant.
#[test]
#[cfg(feature = "wgpu")]
fn ismember_wgpu_numeric_matches_cpu() {
    // The registration result is ignored here; the provider lookup below is
    // what actually fails the test if no provider is available.
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );

    // --- element-wise membership ---
    let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
    let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
    let cpu_eval =
        ismember_numeric_from_tensors(tensor.clone(), set.clone(), false).expect("cpu");

    let provider = runmat_accelerate_api::provider().expect("provider");
    // Wraps a host tensor in a borrowed view and uploads it to the provider.
    let upload = |host: &Tensor| {
        provider.upload(&HostTensorView {
            data: &host.data,
            shape: &host.shape,
        })
    };

    let handle_a = upload(&tensor).expect("upload a");
    let handle_b = upload(&set).expect("upload b");

    let eval = evaluate(
        Value::GpuTensor(handle_a.clone()),
        Value::GpuTensor(handle_b.clone()),
        &[],
    )
    .expect("gpu evaluate");
    assert_eq!(eval.mask.data, cpu_eval.mask.data);
    assert_eq!(eval.loc.data, cpu_eval.loc.data);

    let _ = provider.free(&handle_a);
    let _ = provider.free(&handle_b);

    // --- row-wise membership ---
    let matrix = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let bank = Tensor::new(vec![1.0, 7.0, 3.0, 2.0, 9.0, 4.0], vec![3, 2]).unwrap();
    let cpu_rows =
        ismember_numeric_from_tensors(matrix.clone(), bank.clone(), true).expect("cpu rows");

    let handle_matrix = upload(&matrix).expect("upload matrix");
    let handle_bank = upload(&bank).expect("upload bank");

    let row_options = [Value::from("rows")];
    let eval_rows = evaluate(
        Value::GpuTensor(handle_matrix.clone()),
        Value::GpuTensor(handle_bank.clone()),
        &row_options,
    )
    .expect("gpu rows evaluate");
    assert_eq!(eval_rows.mask.data, cpu_rows.mask.data);
    assert_eq!(eval_rows.loc.data, cpu_rows.loc.data);

    let _ = provider.free(&handle_matrix);
    let _ = provider.free(&handle_bank);
}
1147
/// A 1x1 query must yield a scalar `Value::Bool` mask rather than a logical
/// array.
#[test]
fn scalar_return_is_bool() {
    // Builds a fresh 1x1 tensor value holding 7.0.
    let seven = || Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());

    let outcome = evaluate(seven(), seven(), &[]).unwrap();
    assert_eq!(outcome.into_mask_value(), Value::Bool(true));
}
1155
/// Passing `"rows"` must switch on the `rows` flag of the parsed options.
#[test]
fn parse_rows_option() {
    let args = [Value::from("rows")];
    let opts = parse_options(&args).unwrap();
    assert!(opts.rows);
}
1161
/// Pins the row-wise treatment of NaN: a NaN row in the query is reported as a
/// member when the bank holds a NaN row, while the unmatched row gets index 0.
#[test]
fn numeric_rows_with_nan() {
    let queries = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
    let bank = Tensor::new(vec![f64::NAN, 2.0], vec![2, 1]).unwrap();

    let eval = ismember_numeric_rows(queries, bank).expect("ismember");

    assert_eq!(eval.mask.data, vec![1, 0]);
    assert_eq!(eval.loc.data, vec![1.0, 0.0]);
}
1170
/// With doc export enabled, the embedded documentation must yield at least one
/// block from `test_support::doc_examples`.
#[cfg(feature = "doc_export")]
#[test]
fn doc_examples_present() {
    let blocks = test_support::doc_examples(DOC_MD);
    assert!(!blocks.is_empty());
}
1177}