1use std::collections::HashMap;
4
5use runmat_accelerate_api::{
6 GpuTensorHandle, GpuTensorStorage, HostLogicalOwned, HostTensorOwned,
7 IsMemberOptions as ProviderIsMemberOptions, IsMemberResult,
8};
9use runmat_builtins::{CharArray, ComplexTensor, LogicalArray, StringArray, Tensor, Value};
10use runmat_macros::runtime_builtin;
11
12use super::type_resolvers::logical_output_type;
13use crate::build_runtime_error;
14use crate::builtins::common::gpu_helpers;
15use crate::builtins::common::spec::{
16 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::tensor;
20
/// GPU capability descriptor for `ismember`.
///
/// Declares a custom `ismember` provider hook for F32/F64 tensors. In this module,
/// `ismember_gpu_pair` first calls the registered provider's `ismember` and falls back to
/// gathering both operands to host memory when no provider is registered or the call fails.
/// Mixed host/GPU operands are always gathered (`ismember_gpu_mixed`). The evaluation types
/// here (`IsMemberEvaluation`) hold host `LogicalArray`/`Tensor` results.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::sorting_sets::ismember")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ismember",
    op_kind: GpuOpKind::Custom("ismember"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("ismember")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    // NaN participates in matching on the host path: `canonicalize_f64` maps every NaN to one key.
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers may supply dedicated membership kernels; until then RunMat gathers GPU tensors to host memory.",
};
36
/// Fusion descriptor for `ismember`.
///
/// No elementwise or reduction template is supplied (`elementwise: None`, `reduction: None`).
/// As the notes state, `ismember` is treated as a fusion-chain terminator that materialises
/// logical outputs.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::array::sorting_sets::ismember"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ismember",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "`ismember` materialises logical outputs and terminates fusion chains; upstream tensors are gathered when necessary.",
};
49
50fn ismember_error(message: impl Into<String>) -> crate::RuntimeError {
51 build_runtime_error(message)
52 .with_builtin("ismember")
53 .build()
54}
55
56#[runtime_builtin(
57 name = "ismember",
58 category = "array/sorting_sets",
59 summary = "Identify array elements or rows that appear in another array while returning first-match indices.",
60 keywords = "ismember,membership,set,rows,indices,gpu",
61 accel = "array_construct",
62 sink = true,
63 type_resolver(logical_output_type),
64 builtin_path = "crate::builtins::array::sorting_sets::ismember"
65)]
66async fn ismember_builtin(a: Value, b: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
67 let eval = evaluate(a, b, &rest).await?;
68 if let Some(out_count) = crate::output_count::current_output_count() {
69 if out_count == 0 {
70 return Ok(Value::OutputList(Vec::new()));
71 }
72 if out_count == 1 {
73 return Ok(Value::OutputList(vec![eval.into_mask_value()]));
74 }
75 let (mask, loc) = eval.into_pair();
76 return Ok(crate::output_count::output_list_with_padding(
77 out_count,
78 vec![mask, loc],
79 ));
80 }
81 Ok(eval.into_mask_value())
82}
83
84pub async fn evaluate(
86 a: Value,
87 b: Value,
88 rest: &[Value],
89) -> crate::BuiltinResult<IsMemberEvaluation> {
90 let opts = parse_options(rest)?;
91 match (a, b) {
92 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
93 ismember_gpu_pair(handle_a, handle_b, &opts).await
94 }
95 (Value::GpuTensor(handle_a), other) => {
96 ismember_gpu_mixed(handle_a, other, &opts, true).await
97 }
98 (other, Value::GpuTensor(handle_b)) => {
99 ismember_gpu_mixed(handle_b, other, &opts, false).await
100 }
101 (left, right) => ismember_host(left, right, &opts),
102 }
103}
104
/// Option flags parsed from the trailing `ismember` arguments.
#[derive(Clone, Copy, Debug)]
struct IsMemberOptions {
    // Set when the `'rows'` flag was supplied: compare whole rows instead of elements.
    rows: bool,
}
109
110impl IsMemberOptions {
111 fn into_provider_options(self) -> ProviderIsMemberOptions {
112 ProviderIsMemberOptions { rows: self.rows }
113 }
114}
115
116fn parse_options(rest: &[Value]) -> crate::BuiltinResult<IsMemberOptions> {
117 let mut opts = IsMemberOptions { rows: false };
118 for arg in rest {
119 let text = tensor::value_to_string(arg)
120 .ok_or_else(|| ismember_error("ismember: expected string option arguments"))?;
121 let lowered = text.trim().to_ascii_lowercase();
122 match lowered.as_str() {
123 "rows" => opts.rows = true,
124 "legacy" | "r2012a" => {
125 return Err(ismember_error(
126 "ismember: the 'legacy' behaviour is not supported",
127 ))
128 }
129 other => {
130 return Err(ismember_error(format!(
131 "ismember: unrecognised option '{other}'"
132 )))
133 }
134 }
135 }
136 Ok(opts)
137}
138
139async fn ismember_gpu_pair(
140 handle_a: GpuTensorHandle,
141 handle_b: GpuTensorHandle,
142 opts: &IsMemberOptions,
143) -> crate::BuiltinResult<IsMemberEvaluation> {
144 if let Some(provider) = runmat_accelerate_api::provider() {
145 let provider_opts = opts.into_provider_options();
146 match provider
147 .ismember(&handle_a, &handle_b, &provider_opts)
148 .await
149 {
150 Ok(result) => return IsMemberEvaluation::from_provider_result(result),
151 Err(_) => {
152 }
154 }
155 }
156 let tensor_a = gpu_helpers::gather_tensor_async(&handle_a).await?;
157 let tensor_b = gpu_helpers::gather_tensor_async(&handle_b).await?;
158 ismember_numeric_tensors(tensor_a, tensor_b, opts)
159}
160
161async fn ismember_gpu_mixed(
162 handle_gpu: GpuTensorHandle,
163 other: Value,
164 opts: &IsMemberOptions,
165 gpu_is_a: bool,
166) -> crate::BuiltinResult<IsMemberEvaluation> {
167 let tensor_gpu = gpu_helpers::gather_tensor_async(&handle_gpu).await?;
168 if gpu_is_a {
169 ismember_host(Value::Tensor(tensor_gpu), other, opts)
170 } else {
171 ismember_host(other, Value::Tensor(tensor_gpu), opts)
172 }
173}
174
175fn ismember_host(
176 a: Value,
177 b: Value,
178 opts: &IsMemberOptions,
179) -> crate::BuiltinResult<IsMemberEvaluation> {
180 match (a, b) {
181 (Value::ComplexTensor(at), Value::ComplexTensor(bt)) => ismember_complex(at, bt, opts.rows),
182 (Value::ComplexTensor(at), Value::Complex(re, im)) => {
183 let bt = ComplexTensor::new(vec![(re, im)], vec![1, 1])
184 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
185 ismember_complex(at, bt, opts.rows)
186 }
187 (Value::Complex(a_re, a_im), Value::ComplexTensor(bt)) => {
188 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
189 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
190 ismember_complex(at, bt, opts.rows)
191 }
192 (Value::Complex(a_re, a_im), Value::Complex(b_re, b_im)) => {
193 let at = ComplexTensor::new(vec![(a_re, a_im)], vec![1, 1])
194 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
195 let bt = ComplexTensor::new(vec![(b_re, b_im)], vec![1, 1])
196 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
197 ismember_complex(at, bt, opts.rows)
198 }
199
200 (Value::CharArray(ac), Value::CharArray(bc)) => ismember_char(ac, bc, opts.rows),
201
202 (Value::StringArray(astring), Value::StringArray(bstring)) => {
203 ismember_string(astring, bstring, opts.rows)
204 }
205 (Value::StringArray(astring), Value::String(b)) => {
206 let bstring = StringArray::new(vec![b], vec![1, 1])
207 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
208 ismember_string(astring, bstring, opts.rows)
209 }
210 (Value::String(a), Value::StringArray(bstring)) => {
211 let astring = StringArray::new(vec![a], vec![1, 1])
212 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
213 ismember_string(astring, bstring, opts.rows)
214 }
215 (Value::String(a), Value::String(b)) => {
216 let astring = StringArray::new(vec![a], vec![1, 1])
217 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
218 let bstring = StringArray::new(vec![b], vec![1, 1])
219 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
220 ismember_string(astring, bstring, opts.rows)
221 }
222
223 (left, right) => {
224 let tensor_a =
225 tensor::value_into_tensor_for("ismember", left).map_err(|e| ismember_error(e))?;
226 let tensor_b =
227 tensor::value_into_tensor_for("ismember", right).map_err(|e| ismember_error(e))?;
228 ismember_numeric_tensors(tensor_a, tensor_b, opts)
229 }
230 }
231}
232
233fn ismember_numeric_tensors(
234 a: Tensor,
235 b: Tensor,
236 opts: &IsMemberOptions,
237) -> crate::BuiltinResult<IsMemberEvaluation> {
238 if opts.rows {
239 ismember_numeric_rows(a, b)
240 } else {
241 ismember_numeric_elements(a, b)
242 }
243}
244
245pub fn ismember_numeric_from_tensors(
247 a: Tensor,
248 b: Tensor,
249 rows: bool,
250) -> crate::BuiltinResult<IsMemberEvaluation> {
251 let opts = IsMemberOptions { rows };
252 ismember_numeric_tensors(a, b, &opts)
253}
254
255fn ismember_numeric_elements(a: Tensor, b: Tensor) -> crate::BuiltinResult<IsMemberEvaluation> {
256 let mut map: HashMap<u64, usize> = HashMap::new();
257 for (idx, &value) in b.data.iter().enumerate() {
258 map.entry(canonicalize_f64(value)).or_insert(idx + 1);
259 }
260
261 let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
262 let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
263
264 for &value in &a.data {
265 let key = canonicalize_f64(value);
266 if let Some(&pos) = map.get(&key) {
267 mask_data.push(1);
268 loc_data.push(pos as f64);
269 } else {
270 mask_data.push(0);
271 loc_data.push(0.0);
272 }
273 }
274
275 let logical = LogicalArray::new(mask_data, a.shape.clone())
276 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
277 let loc_tensor = Tensor::new(loc_data, a.shape.clone())
278 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
279 Ok(IsMemberEvaluation::new(logical, loc_tensor))
280}
281
282fn ismember_numeric_rows(a: Tensor, b: Tensor) -> crate::BuiltinResult<IsMemberEvaluation> {
283 let (rows_a, cols_a) = tensor_rows_cols(&a, "ismember")?;
284 let (rows_b, cols_b) = tensor_rows_cols(&b, "ismember")?;
285 if cols_a != cols_b {
286 return Err(ismember_error(
287 "ismember: inputs must have the same number of columns when using 'rows'",
288 ));
289 }
290
291 let mut map: HashMap<NumericRowKey, usize> = HashMap::new();
292 for r in 0..rows_b {
293 let mut row_values = Vec::with_capacity(cols_b);
294 for c in 0..cols_b {
295 let idx = r + c * rows_b;
296 row_values.push(b.data[idx]);
297 }
298 let key = NumericRowKey::from_slice(&row_values);
299 map.entry(key).or_insert(r + 1);
300 }
301
302 let mut mask_data = vec![0u8; rows_a];
303 let mut loc_data = vec![0.0f64; rows_a];
304
305 for r in 0..rows_a {
306 let mut row_values = Vec::with_capacity(cols_a);
307 for c in 0..cols_a {
308 let idx = r + c * rows_a;
309 row_values.push(a.data[idx]);
310 }
311 let key = NumericRowKey::from_slice(&row_values);
312 if let Some(&pos) = map.get(&key) {
313 mask_data[r] = 1;
314 loc_data[r] = pos as f64;
315 }
316 }
317
318 let shape = vec![rows_a, 1];
319 let logical = LogicalArray::new(mask_data, shape.clone())
320 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
321 let loc_tensor =
322 Tensor::new(loc_data, shape).map_err(|e| ismember_error(format!("ismember: {e}")))?;
323 Ok(IsMemberEvaluation::new(logical, loc_tensor))
324}
325
326fn ismember_complex(
327 a: ComplexTensor,
328 b: ComplexTensor,
329 rows: bool,
330) -> crate::BuiltinResult<IsMemberEvaluation> {
331 if rows {
332 ismember_complex_rows(a, b)
333 } else {
334 ismember_complex_elements(a, b)
335 }
336}
337
338fn ismember_complex_elements(
339 a: ComplexTensor,
340 b: ComplexTensor,
341) -> crate::BuiltinResult<IsMemberEvaluation> {
342 let mut map: HashMap<ComplexKey, usize> = HashMap::new();
343 for (idx, &value) in b.data.iter().enumerate() {
344 map.entry(ComplexKey::new(value)).or_insert(idx + 1);
345 }
346
347 let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
348 let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
349
350 for &value in &a.data {
351 let key = ComplexKey::new(value);
352 if let Some(&pos) = map.get(&key) {
353 mask_data.push(1);
354 loc_data.push(pos as f64);
355 } else {
356 mask_data.push(0);
357 loc_data.push(0.0);
358 }
359 }
360
361 let logical = LogicalArray::new(mask_data, a.shape.clone())
362 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
363 let loc_tensor = Tensor::new(loc_data, a.shape.clone())
364 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
365 Ok(IsMemberEvaluation::new(logical, loc_tensor))
366}
367
368fn ismember_complex_rows(
369 a: ComplexTensor,
370 b: ComplexTensor,
371) -> crate::BuiltinResult<IsMemberEvaluation> {
372 let (rows_a, cols_a) = complex_rows_cols(&a)?;
373 let (rows_b, cols_b) = complex_rows_cols(&b)?;
374 if cols_a != cols_b {
375 return Err(ismember_error(
376 "ismember: complex inputs must have the same number of columns when using 'rows'",
377 )
378 .into());
379 }
380
381 let mut map: HashMap<Vec<ComplexKey>, usize> = HashMap::new();
382 for r in 0..rows_b {
383 let mut row_keys = Vec::with_capacity(cols_b);
384 for c in 0..cols_b {
385 let idx = r + c * rows_b;
386 row_keys.push(ComplexKey::new(b.data[idx]));
387 }
388 map.entry(row_keys).or_insert(r + 1);
389 }
390
391 let mut mask_data = vec![0u8; rows_a];
392 let mut loc_data = vec![0.0f64; rows_a];
393
394 for r in 0..rows_a {
395 let mut row_keys = Vec::with_capacity(cols_a);
396 for c in 0..cols_a {
397 let idx = r + c * rows_a;
398 row_keys.push(ComplexKey::new(a.data[idx]));
399 }
400 if let Some(&pos) = map.get(&row_keys) {
401 mask_data[r] = 1;
402 loc_data[r] = pos as f64;
403 }
404 }
405
406 let shape = vec![rows_a, 1];
407 let logical = LogicalArray::new(mask_data, shape.clone())
408 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
409 let loc_tensor =
410 Tensor::new(loc_data, shape).map_err(|e| ismember_error(format!("ismember: {e}")))?;
411 Ok(IsMemberEvaluation::new(logical, loc_tensor))
412}
413
414fn ismember_char(
415 a: CharArray,
416 b: CharArray,
417 rows: bool,
418) -> crate::BuiltinResult<IsMemberEvaluation> {
419 if rows {
420 ismember_char_rows(a, b)
421 } else {
422 ismember_char_elements(a, b)
423 }
424}
425
426fn ismember_char_elements(a: CharArray, b: CharArray) -> crate::BuiltinResult<IsMemberEvaluation> {
427 let rows_b = b.rows;
428 let cols_b = b.cols;
429 let mut map: HashMap<char, usize> = HashMap::new();
430
431 for col in 0..cols_b {
432 for row in 0..rows_b {
433 let data_idx = row * cols_b + col;
434 let ch = b.data[data_idx];
435 let linear_idx = row + col * rows_b;
436 map.entry(ch).or_insert(linear_idx + 1);
437 }
438 }
439
440 let rows_a = a.rows;
441 let cols_a = a.cols;
442 let mut mask_data = vec![0u8; rows_a * cols_a];
443 let mut loc_data = vec![0.0f64; rows_a * cols_a];
444
445 for col in 0..cols_a {
446 for row in 0..rows_a {
447 let data_idx = row * cols_a + col;
448 let ch = a.data[data_idx];
449 let linear_idx = row + col * rows_a;
450 if let Some(&pos) = map.get(&ch) {
451 mask_data[linear_idx] = 1;
452 loc_data[linear_idx] = pos as f64;
453 }
454 }
455 }
456
457 let shape = vec![rows_a, cols_a];
458 let logical = LogicalArray::new(mask_data, shape.clone())
459 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
460 let loc_tensor =
461 Tensor::new(loc_data, shape).map_err(|e| ismember_error(format!("ismember: {e}")))?;
462 Ok(IsMemberEvaluation::new(logical, loc_tensor))
463}
464
465fn ismember_char_rows(a: CharArray, b: CharArray) -> crate::BuiltinResult<IsMemberEvaluation> {
466 if a.cols != b.cols {
467 return Err(ismember_error(
468 "ismember: character inputs must have the same number of columns when using 'rows'",
469 )
470 .into());
471 }
472
473 let rows_b = b.rows;
474 let cols = b.cols;
475 let mut map: HashMap<RowCharKey, usize> = HashMap::new();
476
477 for r in 0..rows_b {
478 let mut row_values = Vec::with_capacity(cols);
479 for c in 0..cols {
480 let idx = r * cols + c;
481 row_values.push(b.data[idx]);
482 }
483 let key = RowCharKey::from_slice(&row_values);
484 map.entry(key).or_insert(r + 1);
485 }
486
487 let rows_a = a.rows;
488 let mut mask_data = vec![0u8; rows_a];
489 let mut loc_data = vec![0.0f64; rows_a];
490
491 for r in 0..rows_a {
492 let mut row_values = Vec::with_capacity(cols);
493 for c in 0..cols {
494 let idx = r * cols + c;
495 row_values.push(a.data[idx]);
496 }
497 let key = RowCharKey::from_slice(&row_values);
498 if let Some(&pos) = map.get(&key) {
499 mask_data[r] = 1;
500 loc_data[r] = pos as f64;
501 }
502 }
503
504 let shape = vec![rows_a, 1];
505 let logical = LogicalArray::new(mask_data, shape.clone())
506 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
507 let loc_tensor =
508 Tensor::new(loc_data, shape).map_err(|e| ismember_error(format!("ismember: {e}")))?;
509 Ok(IsMemberEvaluation::new(logical, loc_tensor))
510}
511
512fn ismember_string(
513 a: StringArray,
514 b: StringArray,
515 rows: bool,
516) -> crate::BuiltinResult<IsMemberEvaluation> {
517 if rows {
518 ismember_string_rows(a, b)
519 } else {
520 ismember_string_elements(a, b)
521 }
522}
523
524fn ismember_string_elements(
525 a: StringArray,
526 b: StringArray,
527) -> crate::BuiltinResult<IsMemberEvaluation> {
528 let mut map: HashMap<String, usize> = HashMap::new();
529 for (idx, value) in b.data.iter().enumerate() {
530 map.entry(value.clone()).or_insert(idx + 1);
531 }
532
533 let mut mask_data = Vec::<u8>::with_capacity(a.data.len());
534 let mut loc_data = Vec::<f64>::with_capacity(a.data.len());
535
536 for value in &a.data {
537 if let Some(&pos) = map.get(value) {
538 mask_data.push(1);
539 loc_data.push(pos as f64);
540 } else {
541 mask_data.push(0);
542 loc_data.push(0.0);
543 }
544 }
545
546 let logical = LogicalArray::new(mask_data, a.shape.clone())
547 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
548 let loc_tensor = Tensor::new(loc_data, a.shape.clone())
549 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
550 Ok(IsMemberEvaluation::new(logical, loc_tensor))
551}
552
553fn ismember_string_rows(
554 a: StringArray,
555 b: StringArray,
556) -> crate::BuiltinResult<IsMemberEvaluation> {
557 if a.shape.len() != 2 || b.shape.len() != 2 {
558 return Err(ismember_error(
559 "ismember: 'rows' option requires 2-D string arrays",
560 ));
561 }
562 if a.shape[1] != b.shape[1] {
563 return Err(ismember_error(
564 "ismember: string inputs must have the same number of columns when using 'rows'",
565 )
566 .into());
567 }
568
569 let rows_a = a.shape[0];
570 let cols = a.shape[1];
571 let rows_b = b.shape[0];
572
573 let mut map: HashMap<RowStringKey, usize> = HashMap::new();
574 for r in 0..rows_b {
575 let mut row_values = Vec::with_capacity(cols);
576 for c in 0..cols {
577 let idx = r + c * rows_b;
578 row_values.push(b.data[idx].clone());
579 }
580 let key = RowStringKey(row_values);
581 map.entry(key).or_insert(r + 1);
582 }
583
584 let mut mask_data = vec![0u8; rows_a];
585 let mut loc_data = vec![0.0f64; rows_a];
586
587 for r in 0..rows_a {
588 let mut row_values = Vec::with_capacity(cols);
589 for c in 0..cols {
590 let idx = r + c * rows_a;
591 row_values.push(a.data[idx].clone());
592 }
593 let key = RowStringKey(row_values);
594 if let Some(&pos) = map.get(&key) {
595 mask_data[r] = 1;
596 loc_data[r] = pos as f64;
597 }
598 }
599
600 let shape = vec![rows_a, 1];
601 let logical = LogicalArray::new(mask_data, shape.clone())
602 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
603 let loc_tensor =
604 Tensor::new(loc_data, shape).map_err(|e| ismember_error(format!("ismember: {e}")))?;
605 Ok(IsMemberEvaluation::new(logical, loc_tensor))
606}
607
608fn tensor_rows_cols(t: &Tensor, name: &str) -> crate::BuiltinResult<(usize, usize)> {
609 match t.shape.len() {
610 0 => Ok((1, 1)),
611 1 => Ok((t.shape[0], 1)),
612 2 => Ok((t.shape[0], t.shape[1])),
613 _ => Err(ismember_error(format!(
614 "{name}: 'rows' option requires 2-D numeric matrices"
615 ))
616 .into()),
617 }
618}
619
620fn complex_rows_cols(t: &ComplexTensor) -> crate::BuiltinResult<(usize, usize)> {
621 match t.shape.len() {
622 0 => Ok((1, 1)),
623 1 => Ok((t.shape[0], 1)),
624 2 => Ok((t.shape[0], t.shape[1])),
625 _ => Err(ismember_error(
626 "ismember: 'rows' option requires 2-D complex matrices",
627 )),
628 }
629}
630
631#[derive(Debug, Clone, PartialEq, Eq, Hash)]
632struct NumericRowKey(Vec<u64>);
633
634impl NumericRowKey {
635 fn from_slice(values: &[f64]) -> Self {
636 NumericRowKey(values.iter().map(|&v| canonicalize_f64(v)).collect())
637 }
638}
639
640#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
641struct ComplexKey {
642 re: u64,
643 im: u64,
644}
645
646impl ComplexKey {
647 fn new(value: (f64, f64)) -> Self {
648 Self {
649 re: canonicalize_f64(value.0),
650 im: canonicalize_f64(value.1),
651 }
652 }
653}
654
/// Hashable key for one character row, stored as Unicode scalar values.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct RowCharKey(Vec<u32>);

impl RowCharKey {
    /// Converts each character to its `u32` code point.
    fn from_slice(values: &[char]) -> Self {
        Self(values.iter().map(|&c| u32::from(c)).collect())
    }
}
663
/// Hashable key for one row of a string array (the row's strings in column order).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct RowStringKey(Vec<String>);
666
/// Maps an `f64` to a hashable bit pattern for equality-based membership.
///
/// Every NaN (any sign or payload) maps to one canonical quiet-NaN pattern. `0.0` and `-0.0`
/// both map to `0`. All other values keep their IEEE-754 bits.
///
/// NOTE(review): MATLAB's `ismember` documentation treats NaN values as distinct (never
/// members). This canonicalisation makes NaN match NaN, which the module's tests rely on.
/// Confirm the divergence is intended.
fn canonicalize_f64(value: f64) -> u64 {
    const CANONICAL_NAN: u64 = 0x7ff8_0000_0000_0000;
    match value {
        v if v.is_nan() => CANONICAL_NAN,
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
676
677#[derive(Debug, Clone)]
678pub struct IsMemberEvaluation {
679 mask: LogicalArray,
680 loc: Tensor,
681}
682
683impl IsMemberEvaluation {
684 fn new(mask: LogicalArray, loc: Tensor) -> Self {
685 Self { mask, loc }
686 }
687
688 pub fn from_provider_result(result: IsMemberResult) -> crate::BuiltinResult<Self> {
689 let mask = LogicalArray::new(result.mask.data, result.mask.shape)
690 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
691 let loc = Tensor::new(result.loc.data, result.loc.shape)
692 .map_err(|e| ismember_error(format!("ismember: {e}")))?;
693 Ok(IsMemberEvaluation::new(mask, loc))
694 }
695
696 pub fn into_numeric_ismember_result(self) -> crate::BuiltinResult<IsMemberResult> {
697 let IsMemberEvaluation { mask, loc } = self;
698 Ok(IsMemberResult {
699 mask: HostLogicalOwned {
700 data: mask.data,
701 shape: mask.shape,
702 },
703 loc: HostTensorOwned {
704 data: loc.data,
705 shape: loc.shape,
706 storage: GpuTensorStorage::Real,
707 },
708 })
709 }
710
711 pub fn into_mask_value(self) -> Value {
712 logical_array_into_value(self.mask)
713 }
714
715 pub fn mask_value(&self) -> Value {
716 logical_array_into_value(self.mask.clone())
717 }
718
719 pub fn into_pair(self) -> (Value, Value) {
720 let mask = logical_array_into_value(self.mask);
721 let loc = tensor::tensor_into_value(self.loc);
722 (mask, loc)
723 }
724
725 pub fn loc_value(&self) -> Value {
726 tensor::tensor_into_value(self.loc.clone())
727 }
728}
729
730fn logical_array_into_value(logical: LogicalArray) -> Value {
731 if logical.data.len() == 1 {
732 Value::Bool(logical.data[0] != 0)
733 } else {
734 Value::LogicalArray(logical)
735 }
736}
737
// Unit tests for `ismember`. They cover:
// - the host numeric, complex, char and string paths (element-wise and 'rows');
// - option parsing and scalar-to-bool output conversion;
// - the GPU paths through the test provider, plus a wgpu-vs-CPU comparison behind the
//   `wgpu` feature.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{ResolveContext, Tensor, Type};

    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;

    // Extracts the message text of a runtime error for substring assertions.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    // Drives the async `evaluate` entry point to completion on the current thread.
    fn evaluate_sync(
        a: Value,
        b: Value,
        rest: &[Value],
    ) -> crate::BuiltinResult<IsMemberEvaluation> {
        futures::executor::block_on(evaluate(a, b, rest))
    }

    // Locations are 1-based indices of the first match in `b`; 0 marks a non-member.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_membership_basic() {
        let a = Tensor::new(vec![5.0, 7.0, 2.0, 7.0], vec![1, 4]).unwrap();
        let b = Tensor::new(vec![7.0, 9.0, 5.0], vec![1, 3]).unwrap();
        let eval = ismember_numeric_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1, 0, 1]);
        assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0, 1.0]);
    }

    // The registered type resolver reports a logical result for tensor inputs.
    #[test]
    fn ismember_type_resolver_logical() {
        assert_eq!(
            logical_output_type(
                &[Type::tensor(), Type::tensor()],
                &ResolveContext::new(Vec::new()),
            ),
            Type::logical()
        );
    }

    // NaN matches NaN here because `canonicalize_f64` maps every NaN to one key.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_nan_membership() {
        let a = Tensor::new(vec![f64::NAN, 1.0], vec![1, 2]).unwrap();
        let b = Tensor::new(vec![f64::NAN, 2.0], vec![1, 2]).unwrap();
        let eval = ismember_numeric_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 0]);
        assert_eq!(eval.loc.data, vec![1.0, 0.0]);
    }

    // Column-major data: rows of `a` are [1 2], [3 4], [1 2]; rows of `b` are [3 4], [5 6], [1 2].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_rows_membership() {
        let a = Tensor::new(vec![1.0, 3.0, 1.0, 2.0, 4.0, 2.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![3.0, 5.0, 1.0, 4.0, 6.0, 2.0], vec![3, 2]).unwrap();
        let eval = ismember_numeric_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1, 1]);
        assert_eq!(eval.loc.data, vec![3.0, 1.0, 3.0]);
        assert_eq!(eval.loc.shape, vec![3, 1]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn complex_membership() {
        let a = ComplexTensor::new(vec![(1.0, 2.0), (0.0, 0.0)], vec![1, 2]).unwrap();
        let b = ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0)], vec![1, 2]).unwrap();
        let eval = ismember_complex_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![2.0, 1.0]);
    }

    // Column-major complex rows: `a` rows match `b` rows 1 and 3.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn complex_rows_membership() {
        let a = ComplexTensor::new(
            vec![(1.0, 1.0), (3.0, 0.0), (2.0, 0.0), (4.0, 4.0)],
            vec![2, 2],
        )
        .unwrap();
        let b = ComplexTensor::new(
            vec![
                (1.0, 1.0),
                (5.0, 0.0),
                (3.0, 0.0),
                (2.0, 0.0),
                (6.0, 0.0),
                (4.0, 4.0),
            ],
            vec![3, 2],
        )
        .unwrap();
        let eval = ismember_complex_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
    }

    // CharArray data is row-major (`a` = ['r' 'u'; 'n' 'm']). Outputs and locations are column-major.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn char_membership() {
        let a = CharArray::new(vec!['r', 'u', 'n', 'm'], 2, 2).unwrap();
        let b = CharArray::new(vec!['m', 'a', 'r', 'u'], 2, 2).unwrap();
        let eval = ismember_char_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 0, 1, 1]);
        assert_eq!(eval.loc.data, vec![2.0, 0.0, 4.0, 1.0]);
    }

    // Row-major char rows: "ma", "tl" in `a`; "ma", "ge", "tl" in `b`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn char_rows_membership() {
        let a = CharArray::new(vec!['m', 'a', 't', 'l'], 2, 2).unwrap();
        let b = CharArray::new(vec!['m', 'a', 'g', 'e', 't', 'l'], 3, 2).unwrap();
        let eval = ismember_char_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn string_membership() {
        let a = StringArray::new(
            vec![
                "apple".to_string(),
                "pear".to_string(),
                "banana".to_string(),
            ],
            vec![1, 3],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "pear".to_string(),
                "orange".to_string(),
                "apple".to_string(),
            ],
            vec![1, 3],
        )
        .unwrap();
        let eval = ismember_string_elements(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1, 0]);
        assert_eq!(eval.loc.data, vec![3.0, 1.0, 0.0]);
    }

    // Column-major string rows: (alpha, beta), (gamma, delta) match `b` rows 1 and 3.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn string_rows_membership() {
        let a = StringArray::new(
            vec![
                "alpha".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "delta".to_string(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let b = StringArray::new(
            vec![
                "alpha".to_string(),
                "theta".to_string(),
                "gamma".to_string(),
                "beta".to_string(),
                "eta".to_string(),
                "delta".to_string(),
            ],
            vec![3, 2],
        )
        .unwrap();
        let eval = ismember_string_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 1]);
        assert_eq!(eval.loc.data, vec![1.0, 3.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn options_reject_legacy() {
        let err = error_message(parse_options(&[Value::from("legacy")]).unwrap_err());
        assert!(err.contains("legacy"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_unknown_option() {
        let err = error_message(
            evaluate_sync(Value::Num(1.0), Value::Num(1.0), &[Value::from("stable")]).unwrap_err(),
        );
        assert!(err.contains("unrecognised option"));
    }

    // End-to-end through `evaluate`: multi-element results stay arrays.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ismember_runtime_numeric() {
        let a = Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap());
        let b = Value::Tensor(Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap());
        let (mask, loc) = evaluate_sync(a, b, &[]).unwrap().into_pair();
        match mask {
            Value::LogicalArray(arr) => assert_eq!(arr.data, vec![1, 0, 1]),
            other => panic!("expected logical array, got {other:?}"),
        }
        match loc {
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0, 0.0, 1.0]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Bool/LogicalArray inputs take the numeric path; scalar results collapse to Bool/Num.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_inputs_promoted() {
        let a = Value::Bool(true);
        let logical_b =
            LogicalArray::new(vec![1, 0], vec![2, 1]).expect("logical array construction");
        let eval = evaluate_sync(a, Value::LogicalArray(logical_b), &[]).expect("ismember");
        assert_eq!(eval.mask_value(), Value::Bool(true));
        assert_eq!(eval.loc_value(), Value::Num(1.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ismember_rows_shape_checks() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        assert!(ismember_numeric_rows(a.clone(), b.clone()).is_ok());
        let bad = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let err = error_message(ismember_numeric_rows(a, bad).unwrap_err());
        assert!(err.contains("same number of columns"));
    }

    // Both operands uploaded to the test provider; results must match host semantics.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ismember_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
            let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
            let view_a = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let view_b = runmat_accelerate_api::HostTensorView {
                data: &set.data,
                shape: &set.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload a");
            let handle_b = provider.upload(&view_b).expect("upload b");
            let eval = evaluate_sync(Value::GpuTensor(handle_a), Value::GpuTensor(handle_b), &[])
                .expect("ismember");
            assert_eq!(eval.mask.data, vec![0, 1, 0, 1]);
            assert_eq!(eval.loc.data, vec![0.0, 1.0, 0.0, 1.0]);
        });
    }

    // 'rows' through the GPU pair path; handles are freed afterwards (best effort).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ismember_gpu_rows_roundtrip() {
        test_support::with_test_provider(|provider| {
            let rows = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
            let bank = Tensor::new(vec![1.0, 5.0, 3.0, 2.0, 6.0, 4.0], vec![3, 2]).unwrap();
            let view_a = runmat_accelerate_api::HostTensorView {
                data: &rows.data,
                shape: &rows.shape,
            };
            let view_b = runmat_accelerate_api::HostTensorView {
                data: &bank.data,
                shape: &bank.shape,
            };
            let handle_a = provider.upload(&view_a).expect("upload a");
            let handle_b = provider.upload(&view_b).expect("upload b");
            let eval = evaluate_sync(
                Value::GpuTensor(handle_a.clone()),
                Value::GpuTensor(handle_b.clone()),
                &[Value::from("rows")],
            )
            .expect("ismember");
            assert_eq!(eval.mask.data, vec![1, 1]);
            assert_eq!(eval.loc.data, vec![1.0, 3.0]);
            let _ = provider.free(&handle_a);
            let _ = provider.free(&handle_b);
        });
    }

    // Registers the wgpu provider (the registration result is ignored). It then checks that GPU
    // element-wise and 'rows' results agree with `ismember_numeric_from_tensors` on the host.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn ismember_wgpu_numeric_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 4.0], vec![4, 1]).unwrap();
        let set = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let cpu_eval =
            ismember_numeric_from_tensors(tensor.clone(), set.clone(), false).expect("cpu");

        let provider = runmat_accelerate_api::provider().expect("provider");
        let view_a = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let view_b = HostTensorView {
            data: &set.data,
            shape: &set.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload a");
        let handle_b = provider.upload(&view_b).expect("upload b");

        let eval = evaluate_sync(
            Value::GpuTensor(handle_a.clone()),
            Value::GpuTensor(handle_b.clone()),
            &[],
        )
        .expect("gpu evaluate");
        assert_eq!(eval.mask.data, cpu_eval.mask.data);
        assert_eq!(eval.loc.data, cpu_eval.loc.data);

        let _ = provider.free(&handle_a);
        let _ = provider.free(&handle_b);

        let matrix = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let bank = Tensor::new(vec![1.0, 7.0, 3.0, 2.0, 9.0, 4.0], vec![3, 2]).unwrap();
        let cpu_rows =
            ismember_numeric_from_tensors(matrix.clone(), bank.clone(), true).expect("cpu rows");
        let view_matrix = HostTensorView {
            data: &matrix.data,
            shape: &matrix.shape,
        };
        let view_bank = HostTensorView {
            data: &bank.data,
            shape: &bank.shape,
        };
        let handle_matrix = provider.upload(&view_matrix).expect("upload matrix");
        let handle_bank = provider.upload(&view_bank).expect("upload bank");
        let eval_rows = evaluate_sync(
            Value::GpuTensor(handle_matrix.clone()),
            Value::GpuTensor(handle_bank.clone()),
            &[Value::from("rows")],
        )
        .expect("gpu rows evaluate");
        assert_eq!(eval_rows.mask.data, cpu_rows.mask.data);
        assert_eq!(eval_rows.loc.data, cpu_rows.loc.data);
        let _ = provider.free(&handle_matrix);
        let _ = provider.free(&handle_bank);
    }

    // A single-element mask is returned as `Value::Bool`, not a 1x1 logical array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn scalar_return_is_bool() {
        let a = Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());
        let b = Value::Tensor(Tensor::new(vec![7.0], vec![1, 1]).unwrap());
        let mask = evaluate_sync(a, b, &[]).unwrap().into_mask_value();
        assert_eq!(mask, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn parse_rows_option() {
        let opts = parse_options(&[Value::from("rows")]).unwrap();
        assert!(opts.rows);
    }

    // NaN rows match each other under the canonicalised keys.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_rows_with_nan() {
        let a = Tensor::new(vec![f64::NAN, 1.0], vec![2, 1]).unwrap();
        let b = Tensor::new(vec![f64::NAN, 2.0], vec![2, 1]).unwrap();
        let eval = ismember_numeric_rows(a, b).expect("ismember");
        assert_eq!(eval.mask.data, vec![1, 0]);
        assert_eq!(eval.loc.data, vec![1.0, 0.0]);
    }
}