1use runmat_accelerate_api::{HostTensorView, ProviderFindResult};
4use runmat_builtins::{ComplexTensor, Tensor, Value};
5use runmat_macros::runtime_builtin;
6
7use crate::builtins::common::random_args::complex_tensor_into_value;
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
11};
12use crate::builtins::common::{gpu_helpers, tensor};
13#[cfg(feature = "doc_export")]
14use crate::register_builtin_doc_text;
15use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
16
17#[cfg(feature = "doc_export")]
18pub const DOC_MD: &str = r#"---
19title: "find"
20category: "array/indexing"
21keywords: ["find", "nonzero", "indices", "row", "column", "gpu"]
22summary: "Locate indices and values of nonzero elements in scalars, vectors, matrices, or N-D tensors."
23references: []
24gpu_support:
25 elementwise: false
26 reduction: false
27 precisions: ["f32", "f64"]
28 broadcasting: "none"
29 notes: "WGPU provider executes find directly on the device; other providers fall back to the host and re-upload results to preserve residency."
30fusion:
31 elementwise: false
32 reduction: false
33 max_inputs: 1
34 constants: "inline"
35requires_feature: null
36tested:
37 unit: "builtins::array::indexing::find::tests"
38 integration: "builtins::array::indexing::find::tests::find_gpu_roundtrip"
39---
40
41# What does the `find` function do in MATLAB / RunMat?
42`find(X)` returns the indices of nonzero elements of `X`. With a single output it produces MATLAB's 1-based linear indices. With multiple outputs it returns row/column (and optionally value) vectors describing each nonzero element.
43
44## How does the `find` function behave in MATLAB / RunMat?
45- `find(X)` scans in column-major order and returns a column vector of linear indices.
46- `find(X, K)` limits the result to the first `K` matches; `K = 0` yields an empty result.
47- `find(X, K, 'first')` (default) scans from the start, while `'last'` scans from the end.
48- `find(X, 'last')` is equivalent to `find(X, 1, 'last')` and returns the final nonzero index.
49- `[row, col] = find(X)` returns per-element row and column subscripts for 2-D or N-D inputs (higher dimensions are flattened into the column index, matching MATLAB semantics).
50- `[row, col, val] = find(X)` also returns the corresponding values; complex inputs preserve their complex values.
- Logical, char, integer, double, and complex inputs are all supported. Empty inputs, and inputs with no nonzero elements, return `0×1` empty column vectors.
52
53## `find` Function GPU Execution Behaviour
When the input already resides on the GPU (i.e., a `gpuArray`), RunMat first asks the active acceleration provider to run `find` on the device. The WGPU provider implements this hook, so the outputs are produced and kept on the GPU. If the active provider does not implement `find` (or the device call fails), RunMat gathers the input, performs the computation on the host, and then uploads the results back to the provider. This preserves residency, so downstream fused kernels can continue on the device without an explicit `gather`. The fallback path only adds a small one-off transfer cost.
55
56## Examples of using the `find` function in MATLAB / RunMat
57
58### Finding linear indices of nonzero elements
59
60```matlab
61A = [0 4 0; 7 0 9];
62k = find(A);
63```
64
65Expected output:
66
67```matlab
k =
     2
     3
     6
72```
73
74### Limiting the number of matches
75
76```matlab
77A = [0 3 5 0 8];
78first_two = find(A, 2);
79```
80
81Expected output:
82
83```matlab
84first_two =
85 2
86 3
87```
88
89### Locating the last nonzero element
90
91```matlab
92A = [1 0 0 6 0 2];
93last_index = find(A, 'last');
94```
95
96Expected output:
97
98```matlab
99last_index =
100 6
101```
102
103### Retrieving row and column subscripts
104
105```matlab
106A = [0 4 0; 7 0 9];
107[rows, cols] = find(A);
108```
109
110Expected outputs:
111
112```matlab
113rows =
114 2
115 1
116 2
117
118cols =
119 1
120 2
121 3
122```
123
124### Capturing values alongside indices (including complex inputs)
125
126```matlab
127Z = [0 1+2i; 0 0; 3-4i 0];
128[r, c, v] = find(Z);
129```
130
131Expected outputs:
132
133```matlab
r =
     3
     1

c =
     1
     2

v =
   3.0000 - 4.0000i
   1.0000 + 2.0000i
145```
146
147## GPU residency in RunMat (Do I need `gpuArray`?)
148Usually you do **not** need to move data with `gpuArray` manually. If a provider backs `find` directly, the entire operation stays on the GPU. Otherwise, RunMat gathers once, computes on the host, and then uploads results back to the active provider so subsequent kernels remain device-resident. This means GPU pipelines continue seamlessly without additional `gather`/`gpuArray` calls from user code.
149
150## FAQ
151
152### What elements does `find` consider nonzero?
153Any element whose real or imaginary component is nonzero. For logical inputs, `true` maps to 1 and is considered nonzero; `false` is ignored.
154
155### How are higher-dimensional arrays handled when requesting row/column outputs?
156`find` treats the first dimension as rows and flattens the remaining dimensions into the column index, matching MATLAB's column-major storage.
157
158### What happens when I request more matches than exist?
159`find` returns all available nonzero elements—no error is raised. For example, `find(A, 10)` simply returns every nonzero in `A` if it has fewer than 10.
160
161### Does `find` support char arrays and integers?
162Yes. Characters are converted to their numeric code points during the test for nonzero values; integers are promoted to double precision for the result vectors.
163
164### Can I run `find` entirely on the GPU today?
Yes, when the active provider implements the `find` hook; the WGPU provider does. With other providers, the runtime gathers GPU inputs, computes on the host, and re-uploads the results so they remain device-resident.
166
167### What shapes do empty results take?
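When no element matches, the returned arrays are `0×1` column vectors. RunMat returns column vectors for every input shape, whereas MATLAB returns row vectors (for example `1×0` when nothing matches) for row-vector inputs.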
169
170### How does `find` interact with fusion or auto-offload?
171`find` is a control-flow style operation, so it does not participate in fusion. Auto-offload still keeps data resident on the GPU where possible by uploading results after the host computation.
172
173### Does `find` preserve complex values in the third output?
174Yes. When you request the value output, complex inputs return a complex column vector that matches MATLAB's behaviour.
175
176### Can I combine `find` with `gpuArray` explicitly?
177Absolutely. If you call `find(gpuArray(X))`, the runtime ensures outputs stay on the GPU so later GPU-aware builtins can consume them without additional transfers.
178
179### Is there a way to obtain subscripts for every dimension?
180Use `find` to get linear indices and then call `ind2sub(size(X), ...)` if you need explicit per-dimension subscripts for N-D arrays.
181
182## See Also
183[ind2sub](./ind2sub), [sub2ind](./sub2ind), [logical](../../comparison/logical), [gpuArray](../../acceleration/gpu/gpuArray)
184
185## Source & Feedback
186- The full source code for the implementation of the `find` function is available at: [`crates/runmat-runtime/src/builtins/array/indexing/find.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/indexing/find.rs)
187- Found a bug or behavioural difference? Please [open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
188"#;
189
190pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
191 name: "find",
192 op_kind: GpuOpKind::Custom("find"),
193 supported_precisions: &[ScalarType::F32, ScalarType::F64],
194 broadcast: BroadcastSemantics::None,
195 provider_hooks: &[ProviderHook::Custom("find")],
196 constant_strategy: ConstantStrategy::InlineLiteral,
197 residency: ResidencyPolicy::NewHandle,
198 nan_mode: ReductionNaN::Include,
199 two_pass_threshold: None,
200 workgroup_size: None,
201 accepts_nan_mode: false,
202 notes: "WGPU provider executes find directly on device; other providers fall back to host and re-upload results to preserve residency.",
203};
204
205register_builtin_gpu_spec!(GPU_SPEC);
206
207pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
208 name: "find",
209 shape: ShapeRequirements::Any,
210 constant_strategy: ConstantStrategy::InlineLiteral,
211 elementwise: None,
212 reduction: None,
213 emits_nan: false,
214 notes: "Find drives control flow and currently bypasses fusion; metadata is present for completeness only.",
215};
216
217register_builtin_fusion_spec!(FUSION_SPEC);
218
// Publish `DOC_MD` as the reference text for `find`; only compiled with the `doc_export` feature.
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("find", DOC_MD);
221
222#[runtime_builtin(
223 name = "find",
224 category = "array/indexing",
225 summary = "Locate indices and values of nonzero elements.",
226 keywords = "find,nonzero,indices,row,column,gpu",
227 accel = "custom"
228)]
229fn find_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
230 let eval = evaluate(value, &rest)?;
231 eval.linear_value()
232}
233
234pub fn evaluate(value: Value, args: &[Value]) -> Result<FindEval, String> {
236 let options = parse_options(args)?;
237 match value {
238 Value::GpuTensor(handle) => {
239 if let Some(result) = try_provider_find(&handle, &options) {
240 return Ok(FindEval::from_gpu(result));
241 }
242 let (storage, _) = materialize_input(Value::GpuTensor(handle))?;
243 let result = compute_find(&storage, &options);
244 Ok(FindEval::from_host(result, true))
245 }
246 other => {
247 let (storage, input_was_gpu) = materialize_input(other)?;
248 let result = compute_find(&storage, &options);
249 Ok(FindEval::from_host(result, input_was_gpu))
250 }
251 }
252}
253
254fn try_provider_find(
255 handle: &runmat_accelerate_api::GpuTensorHandle,
256 options: &FindOptions,
257) -> Option<ProviderFindResult> {
258 #[cfg(all(test, feature = "wgpu"))]
259 {
260 if handle.device_id != 0 {
261 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
262 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
263 );
264 }
265 }
266 let provider = runmat_accelerate_api::provider()?;
267 let direction = match options.direction {
268 FindDirection::First => runmat_accelerate_api::FindDirection::First,
269 FindDirection::Last => runmat_accelerate_api::FindDirection::Last,
270 };
271 let limit = options.effective_limit();
272 provider.find(handle, limit, direction).ok()
273}
274
/// Which end of the column-major data the scan starts from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FindDirection {
    /// Scan forward from the first element (the default).
    First,
    /// Scan backward from the last element.
    Last,
}
280
281#[derive(Debug, Clone)]
282struct FindOptions {
283 limit: Option<usize>,
284 direction: FindDirection,
285}
286
287impl Default for FindOptions {
288 fn default() -> Self {
289 Self {
290 limit: None,
291 direction: FindDirection::First,
292 }
293 }
294}
295
296impl FindOptions {
297 fn effective_limit(&self) -> Option<usize> {
298 match self.direction {
299 FindDirection::Last => self.limit.or(Some(1)),
300 FindDirection::First => self.limit,
301 }
302 }
303}
304
305#[derive(Clone)]
306enum DataStorage {
307 Real(Tensor),
308 Complex(ComplexTensor),
309}
310
311impl DataStorage {
312 fn shape(&self) -> &[usize] {
313 match self {
314 DataStorage::Real(t) => &t.shape,
315 DataStorage::Complex(t) => &t.shape,
316 }
317 }
318}
319
320#[derive(Clone)]
321struct FindResult {
322 shape: Vec<usize>,
323 indices: Vec<usize>,
324 values: FindValues,
325}
326
/// Values of the matched elements, kept in the input's numeric domain.
#[derive(Clone)]
enum FindValues {
    /// Real values.
    Real(Vec<f64>),
    /// Complex values as `(re, im)` pairs.
    Complex(Vec<(f64, f64)>),
}
332
333pub struct FindEval {
334 inner: FindEvalInner,
335}
336
337enum FindEvalInner {
338 Host {
339 result: FindResult,
340 prefer_gpu: bool,
341 },
342 Gpu {
343 result: ProviderFindResult,
344 },
345}
346
347impl FindEval {
348 fn from_host(result: FindResult, prefer_gpu: bool) -> Self {
349 Self {
350 inner: FindEvalInner::Host { result, prefer_gpu },
351 }
352 }
353
354 fn from_gpu(result: ProviderFindResult) -> Self {
355 Self {
356 inner: FindEvalInner::Gpu { result },
357 }
358 }
359
360 pub fn linear_value(&self) -> Result<Value, String> {
361 match &self.inner {
362 FindEvalInner::Host { result, prefer_gpu } => {
363 let tensor = result.linear_tensor()?;
364 Ok(tensor_to_value(tensor, *prefer_gpu))
365 }
366 FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.linear.clone())),
367 }
368 }
369
370 pub fn row_value(&self) -> Result<Value, String> {
371 match &self.inner {
372 FindEvalInner::Host { result, prefer_gpu } => {
373 let tensor = result.row_tensor()?;
374 Ok(tensor_to_value(tensor, *prefer_gpu))
375 }
376 FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.rows.clone())),
377 }
378 }
379
380 pub fn column_value(&self) -> Result<Value, String> {
381 match &self.inner {
382 FindEvalInner::Host { result, prefer_gpu } => {
383 let tensor = result.column_tensor()?;
384 Ok(tensor_to_value(tensor, *prefer_gpu))
385 }
386 FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.cols.clone())),
387 }
388 }
389
390 pub fn values_value(&self) -> Result<Value, String> {
391 match &self.inner {
392 FindEvalInner::Host { result, prefer_gpu } => result.values_value(*prefer_gpu),
393 FindEvalInner::Gpu { result } => result
394 .values
395 .as_ref()
396 .map(|handle| Value::GpuTensor(handle.clone()))
397 .ok_or_else(|| "find: provider did not return values buffer".to_string()),
398 }
399 }
400}
401
402fn parse_options(args: &[Value]) -> Result<FindOptions, String> {
403 match args.len() {
404 0 => Ok(FindOptions::default()),
405 1 => {
406 if is_direction_like(&args[0]) {
407 let direction_opt = parse_direction(&args[0])?;
408 let limit = if matches!(direction_opt, Some(FindDirection::Last)) {
409 Some(1)
410 } else {
411 None
412 };
413 let direction = direction_opt.unwrap_or(FindDirection::First);
414 Ok(FindOptions { limit, direction })
415 } else {
416 let limit = parse_limit(&args[0])?;
417 Ok(FindOptions {
418 limit: Some(limit),
419 direction: FindDirection::First,
420 })
421 }
422 }
423 2 => {
424 let limit = parse_limit(&args[0])?;
425 let direction = parse_direction(&args[1])?
426 .ok_or_else(|| "find: third argument must be 'first' or 'last'".to_string())?;
427 Ok(FindOptions {
428 limit: Some(limit),
429 direction,
430 })
431 }
432 _ => Err("find: too many input arguments".to_string()),
433 }
434}
435
436fn parse_direction(value: &Value) -> Result<Option<FindDirection>, String> {
437 if let Some(text) = tensor::value_to_string(value) {
438 let lowered = text.trim().to_ascii_lowercase();
439 match lowered.as_str() {
440 "first" => Ok(Some(FindDirection::First)),
441 "last" => Ok(Some(FindDirection::Last)),
442 _ => Err("find: direction must be 'first' or 'last'".to_string()),
443 }
444 } else {
445 Ok(None)
446 }
447}
448
449fn is_direction_like(value: &Value) -> bool {
450 match value {
451 Value::String(_) => true,
452 Value::StringArray(sa) => sa.data.len() == 1,
453 Value::CharArray(ca) => ca.rows == 1,
454 _ => false,
455 }
456}
457
458fn parse_limit(value: &Value) -> Result<usize, String> {
459 match value {
460 Value::GpuTensor(handle) => {
461 let tensor = gpu_helpers::gather_tensor(handle)?;
462 parse_limit_tensor(&tensor)
463 }
464 _ => {
465 let tensor = tensor::value_to_tensor(value)?;
466 parse_limit_tensor(&tensor)
467 }
468 }
469}
470
471fn parse_limit_tensor(tensor: &Tensor) -> Result<usize, String> {
472 if tensor.data.len() != 1 {
473 return Err("find: second argument must be a scalar".to_string());
474 }
475 parse_limit_scalar(tensor.data[0])
476}
477
/// Validates a scalar `K` from `find(X, K, ...)` and converts it to a match count.
///
/// `K` must be finite, exactly integral, and non-negative; `K = 0` is accepted and yields an
/// empty result. Values larger than `usize::MAX` saturate on conversion, which simply means
/// "no effective limit".
///
/// # Errors
/// Returns an error for NaN/±Inf, non-integral values, and negative integers.
fn parse_limit_scalar(value: f64) -> Result<usize, String> {
    // Exact integrality test. An absolute tolerance such as `f64::EPSILON` would accept tiny
    // non-integers (e.g. `1e-20`, silently treated as `K = 0`, or `1 + eps`).
    if !value.is_finite() || value.fract() != 0.0 {
        return Err("find: K must be a finite, non-negative integer".to_string());
    }
    if value < 0.0 {
        return Err("find: K must be >= 0".to_string());
    }
    // `-0.0` passes both checks and converts to 0; `as` saturates for huge values.
    Ok(value as usize)
}
491
492fn materialize_input(value: Value) -> Result<(DataStorage, bool), String> {
493 match value {
494 Value::GpuTensor(handle) => {
495 let tensor = gpu_helpers::gather_tensor(&handle)?;
496 Ok((DataStorage::Real(tensor), true))
497 }
498 Value::Tensor(tensor) => Ok((DataStorage::Real(tensor), false)),
499 Value::LogicalArray(logical) => {
500 let tensor = tensor::logical_to_tensor(&logical)?;
501 Ok((DataStorage::Real(tensor), false))
502 }
503 Value::Num(n) => {
504 let tensor = Tensor::new(vec![n], vec![1, 1]).map_err(|e| format!("find: {e}"))?;
505 Ok((DataStorage::Real(tensor), false))
506 }
507 Value::Int(i) => {
508 let tensor =
509 Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|e| format!("find: {e}"))?;
510 Ok((DataStorage::Real(tensor), false))
511 }
512 Value::Bool(b) => {
513 let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
514 .map_err(|e| format!("find: {e}"))?;
515 Ok((DataStorage::Real(tensor), false))
516 }
517 Value::Complex(re, im) => {
518 let tensor =
519 ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(|e| format!("find: {e}"))?;
520 Ok((DataStorage::Complex(tensor), false))
521 }
522 Value::ComplexTensor(tensor) => Ok((DataStorage::Complex(tensor), false)),
523 Value::CharArray(chars) => {
524 let mut data = Vec::with_capacity(chars.data.len());
525 for c in 0..chars.cols {
526 for r in 0..chars.rows {
527 let ch = chars.data[r * chars.cols + c] as u32;
528 data.push(ch as f64);
529 }
530 }
531 let tensor = Tensor::new(data, vec![chars.rows, chars.cols])
532 .map_err(|e| format!("find: {e}"))?;
533 Ok((DataStorage::Real(tensor), false))
534 }
535 other => Err(format!(
536 "find: unsupported input type {:?}; expected numeric, logical, or char data",
537 other
538 )),
539 }
540}
541
542fn compute_find(storage: &DataStorage, options: &FindOptions) -> FindResult {
543 let shape = storage.shape().to_vec();
544 let limit = options.effective_limit();
545
546 match storage {
547 DataStorage::Real(tensor) => {
548 let mut indices = Vec::new();
549 let mut values = Vec::new();
550
551 if matches!(limit, Some(0)) {
552 return FindResult::new(shape, indices, FindValues::Real(values));
553 }
554
555 let len = tensor.data.len();
556 match options.direction {
557 FindDirection::First => {
558 for idx in 0..len {
559 let value = tensor.data[idx];
560 if value != 0.0 {
561 indices.push(idx + 1);
562 values.push(value);
563 if limit.is_some_and(|k| indices.len() >= k) {
564 break;
565 }
566 }
567 }
568 }
569 FindDirection::Last => {
570 for idx in (0..len).rev() {
571 let value = tensor.data[idx];
572 if value != 0.0 {
573 indices.push(idx + 1);
574 values.push(value);
575 if limit.is_some_and(|k| indices.len() >= k) {
576 break;
577 }
578 }
579 }
580 }
581 }
582
583 FindResult::new(shape, indices, FindValues::Real(values))
584 }
585 DataStorage::Complex(tensor) => {
586 let mut indices = Vec::new();
587 let mut values = Vec::new();
588
589 if matches!(limit, Some(0)) {
590 return FindResult::new(shape, indices, FindValues::Complex(values));
591 }
592
593 let len = tensor.data.len();
594 match options.direction {
595 FindDirection::First => {
596 for idx in 0..len {
597 let (re, im) = tensor.data[idx];
598 if re != 0.0 || im != 0.0 {
599 indices.push(idx + 1);
600 values.push((re, im));
601 if limit.is_some_and(|k| indices.len() >= k) {
602 break;
603 }
604 }
605 }
606 }
607 FindDirection::Last => {
608 for idx in (0..len).rev() {
609 let (re, im) = tensor.data[idx];
610 if re != 0.0 || im != 0.0 {
611 indices.push(idx + 1);
612 values.push((re, im));
613 if limit.is_some_and(|k| indices.len() >= k) {
614 break;
615 }
616 }
617 }
618 }
619 }
620
621 FindResult::new(shape, indices, FindValues::Complex(values))
622 }
623 }
624}
625
626impl FindResult {
627 fn new(shape: Vec<usize>, indices: Vec<usize>, values: FindValues) -> Self {
628 Self {
629 shape,
630 indices,
631 values,
632 }
633 }
634
635 fn linear_tensor(&self) -> Result<Tensor, String> {
636 let data: Vec<f64> = self.indices.iter().map(|&idx| idx as f64).collect();
637 let rows = data.len();
638 Tensor::new(data, vec![rows, 1]).map_err(|e| format!("find: {e}"))
639 }
640
641 fn row_tensor(&self) -> Result<Tensor, String> {
642 let mut data = Vec::with_capacity(self.indices.len());
643 let rows = self.shape.first().copied().unwrap_or(1).max(1);
644 for &idx in &self.indices {
645 let zero_based = idx - 1;
646 let row = (zero_based % rows) + 1;
647 data.push(row as f64);
648 }
649 Tensor::new(data, vec![self.indices.len(), 1]).map_err(|e| format!("find: {e}"))
650 }
651
652 fn column_tensor(&self) -> Result<Tensor, String> {
653 let mut data = Vec::with_capacity(self.indices.len());
654 let rows = self.shape.first().copied().unwrap_or(1).max(1);
655 for &idx in &self.indices {
656 let zero_based = idx - 1;
657 let col = (zero_based / rows) + 1;
658 data.push(col as f64);
659 }
660 Tensor::new(data, vec![self.indices.len(), 1]).map_err(|e| format!("find: {e}"))
661 }
662
663 fn values_value(&self, prefer_gpu: bool) -> Result<Value, String> {
664 match &self.values {
665 FindValues::Real(values) => {
666 let tensor = Tensor::new(values.clone(), vec![values.len(), 1])
667 .map_err(|e| format!("find: {e}"))?;
668 Ok(tensor_to_value(tensor, prefer_gpu))
669 }
670 FindValues::Complex(values) => {
671 let tensor = ComplexTensor::new(values.clone(), vec![values.len(), 1])
672 .map_err(|e| format!("find: {e}"))?;
673 Ok(complex_tensor_into_value(tensor))
674 }
675 }
676 }
677}
678
679fn tensor_to_value(tensor: Tensor, prefer_gpu: bool) -> Value {
680 if prefer_gpu {
681 if let Some(provider) = runmat_accelerate_api::provider() {
682 let view = HostTensorView {
683 data: &tensor.data,
684 shape: &tensor.shape,
685 };
686 if let Ok(handle) = provider.upload(&view) {
687 return Value::GpuTensor(handle);
688 }
689 }
690 }
691 tensor::tensor_into_value(tensor)
692}
693
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{CharArray, IntValue};

    #[test]
    fn find_linear_indices_basic() {
        let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0, 0.0, 9.0], vec![2, 3]).unwrap();
        let value = find_builtin(Value::Tensor(tensor), Vec::new()).expect("find");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 1]);
                assert_eq!(t.data, vec![2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn find_doc_example_linear_indices() {
        // A = [0 4 0; 7 0 9], stored column-major.
        let tensor = Tensor::new(vec![0.0, 7.0, 4.0, 0.0, 0.0, 9.0], vec![2, 3]).unwrap();
        let value = find_builtin(Value::Tensor(tensor), Vec::new()).expect("find");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 1]);
                assert_eq!(t.data, vec![2.0, 3.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn find_limited_first() {
        let tensor = Tensor::new(vec![0.0, 3.0, 5.0, 0.0, 8.0], vec![1, 5]).unwrap();
        let result =
            find_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(2))]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 3.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn find_first_keyword_returns_all_matches() {
        let tensor = Tensor::new(vec![1.0, 0.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let result =
            find_builtin(Value::Tensor(tensor), vec![Value::from("first")]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 1]);
                assert_eq!(t.data, vec![1.0, 3.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn find_last_single() {
        let tensor = Tensor::new(vec![1.0, 0.0, 0.0, 6.0, 0.0, 2.0], vec![1, 6]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), vec![Value::from("last")]).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 6.0),
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![6.0]);
            }
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[test]
    fn find_complex_values() {
        let tensor =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0), (0.0, 0.0)], vec![3, 1]).unwrap();
        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("find compute");
        let values = eval.values_value().expect("values");
        match values {
            Value::Complex(re, im) => {
                assert_eq!(re, 1.0);
                assert_eq!(im, 2.0);
            }
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 1]);
                assert_eq!(ct.data, vec![(1.0, 2.0)]);
            }
            other => panic!("expected complex result, got {other:?}"),
        }
    }

    #[test]
    fn find_complex_doc_example_subscripts_and_values() {
        // Z = [0 1+2i; 0 0; 3-4i 0], stored column-major.
        let tensor = ComplexTensor::new(
            vec![
                (0.0, 0.0),
                (0.0, 0.0),
                (3.0, -4.0),
                (1.0, 2.0),
                (0.0, 0.0),
                (0.0, 0.0),
            ],
            vec![3, 2],
        )
        .unwrap();
        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("evaluate");

        let rows = test_support::gather(eval.row_value().expect("rows")).expect("gather rows");
        assert_eq!(rows.data, vec![3.0, 1.0]);

        let cols = test_support::gather(eval.column_value().expect("cols")).expect("gather cols");
        assert_eq!(cols.data, vec![1.0, 2.0]);

        match eval.values_value().expect("values") {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 1]);
                assert_eq!(ct.data, vec![(3.0, -4.0), (1.0, 2.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[test]
    fn find_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = find_builtin(Value::GpuTensor(handle), Vec::new()).expect("find");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![2.0, 4.0]);
        });
    }

    #[test]
    fn find_direction_error() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::from("invalid")],
        )
        .expect_err("expected error");
        assert!(err.contains("direction"));
    }

    #[test]
    fn find_rejects_fractional_limit() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = find_builtin(Value::Tensor(tensor), vec![Value::Num(2.5)])
            .expect_err("expected error");
        assert!(err.contains("integer"));
    }

    #[test]
    fn find_multi_output_rows_cols_values() {
        let tensor = Tensor::new(vec![0.0, 2.0, 3.0, 0.0, 0.0, 6.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");

        let rows = test_support::gather(eval.row_value().expect("rows")).expect("gather rows");
        assert_eq!(rows.shape, vec![3, 1]);
        assert_eq!(rows.data, vec![2.0, 1.0, 2.0]);

        let cols = test_support::gather(eval.column_value().expect("cols")).expect("gather cols");
        assert_eq!(cols.shape, vec![3, 1]);
        assert_eq!(cols.data, vec![1.0, 2.0, 3.0]);

        let vals = test_support::gather(eval.values_value().expect("vals")).expect("gather vals");
        assert_eq!(vals.shape, vec![3, 1]);
        assert_eq!(vals.data, vec![2.0, 3.0, 6.0]);
    }

    #[test]
    fn find_last_order_descending() {
        let tensor = Tensor::new(vec![1.0, 0.0, 2.0, 3.0, 0.0], vec![1, 5]).unwrap();
        let result = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(2)), Value::from("last")],
        )
        .expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                assert_eq!(t.data, vec![4.0, 3.0]);
            }
            Value::Num(_) => panic!("expected column vector"),
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[test]
    fn find_limit_zero_returns_empty() {
        let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), vec![Value::Num(0.0)]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[test]
    fn find_char_array_supports_nonzero_codes() {
        let chars = CharArray::new(vec!['\0', 'A', '\0'], 1, 3).unwrap();
        let result = find_builtin(Value::CharArray(chars), Vec::new()).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 2.0),
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0]),
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[test]
    fn find_gpu_multi_outputs_return_gpu_handles() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 4.0, 5.0, 0.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");

            let rows = eval.row_value().expect("rows");
            assert!(matches!(rows, Value::GpuTensor(_)));
            let rows_host = test_support::gather(rows).expect("gather rows");
            assert_eq!(rows_host.data, vec![2.0, 1.0]);

            let cols = eval.column_value().expect("cols");
            assert!(matches!(cols, Value::GpuTensor(_)));
            let cols_host = test_support::gather(cols).expect("gather cols");
            assert_eq!(cols_host.data, vec![1.0, 2.0]);

            let vals = eval.values_value().expect("vals");
            assert!(matches!(vals, Value::GpuTensor(_)));
            let vals_host = test_support::gather(vals).expect("gather vals");
            assert_eq!(vals_host.data, vec![4.0, 5.0]);
        });
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn find_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![0.0, 2.0, 0.0, 3.0, 4.0, 0.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let cpu_linear =
            test_support::gather(cpu_eval.linear_value().expect("cpu linear")).expect("cpu gather");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");
        let gpu_linear =
            test_support::gather(gpu_eval.linear_value().expect("gpu linear")).expect("gpu gather");
        assert_eq!(gpu_linear.data, cpu_linear.data);
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}