Skip to main content

runmat_runtime/builtins/array/indexing/
find.rs

1//! MATLAB-compatible `find` builtin with GPU-aware semantics for RunMat.
2
3use runmat_accelerate_api::{HostTensorView, ProviderFindResult};
4use runmat_builtins::{ComplexTensor, ResolveContext, Tensor, Type, Value};
5use runmat_macros::runtime_builtin;
6
7use crate::builtins::array::type_resolvers::column_vector_type;
8use crate::builtins::common::arg_tokens::ArgToken;
9use crate::builtins::common::random_args::complex_tensor_into_value;
10use crate::builtins::common::spec::{
11    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
13};
14use crate::builtins::common::{gpu_helpers, tensor};
15use crate::{build_runtime_error, RuntimeError};
16
/// GPU execution metadata for `find`, registered for
/// `crate::builtins::array::indexing::find`.
///
/// Declares a custom op and provider hook, both named `"find"`, for `F32`/`F64` data
/// with no broadcasting. See `notes` for how non-WGPU providers are handled.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::find")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "find",
    op_kind: GpuOpKind::Custom("find"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    // Output sizes depend on the data, so there is no elementwise broadcasting.
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("find")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    // The host scan tests `value != 0.0`, so NaN elements count as hits.
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "WGPU provider executes find directly on device; other providers fall back to host and re-upload results to preserve residency.",
};
32
/// Fusion metadata for `find`.
///
/// `find` is neither elementwise nor a reduction (both fields are `None`). Per `notes`,
/// it is not fused today; this entry exists so the builtin is described consistently.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::find")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "find",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Find drives control flow and currently bypasses fusion; metadata is present for completeness only.",
};
43
44fn find_type(_args: &[Type], _ctx: &ResolveContext) -> Type {
45    column_vector_type()
46}
47
48fn parse_find_tokens(tokens: &[ArgToken]) -> crate::BuiltinResult<FindOptions> {
49    match tokens.len() {
50        0 => Ok(FindOptions::default()),
51        1 => {
52            if let Some(direction) = token_to_direction(&tokens[0])? {
53                let limit = if matches!(direction, FindDirection::Last) {
54                    Some(1)
55                } else {
56                    None
57                };
58                Ok(FindOptions { limit, direction })
59            } else {
60                let limit = token_to_limit(&tokens[0])?;
61                Ok(FindOptions {
62                    limit: Some(limit),
63                    direction: FindDirection::First,
64                })
65            }
66        }
67        2 => {
68            let limit = token_to_limit(&tokens[0])?;
69            let direction = token_to_direction(&tokens[1])?
70                .ok_or_else(|| find_error("find: third argument must be 'first' or 'last'"))?;
71            Ok(FindOptions {
72                limit: Some(limit),
73                direction,
74            })
75        }
76        _ => Err(find_error("find: too many input arguments")),
77    }
78}
79
80fn token_to_direction(token: &ArgToken) -> crate::BuiltinResult<Option<FindDirection>> {
81    match token {
82        ArgToken::String(text) => match text.as_str() {
83            "first" => Ok(Some(FindDirection::First)),
84            "last" => Ok(Some(FindDirection::Last)),
85            _ => Err(find_error("find: direction must be 'first' or 'last'")),
86        },
87        _ => Ok(None),
88    }
89}
90
91fn token_to_limit(token: &ArgToken) -> crate::BuiltinResult<usize> {
92    match token {
93        ArgToken::Number(value) => parse_limit_scalar(*value),
94        _ => Err(find_error("find: second argument must be a scalar")),
95    }
96}
97
98#[runtime_builtin(
99    name = "find",
100    category = "array/indexing",
101    summary = "Locate indices and values of nonzero elements.",
102    keywords = "find,nonzero,indices,row,column,gpu",
103    accel = "custom",
104    type_resolver(find_type),
105    builtin_path = "crate::builtins::array::indexing::find"
106)]
107async fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
108    let eval = evaluate(value, &rest).await?;
109    if let Some(out_count) = crate::output_count::current_output_count() {
110        if out_count == 0 {
111            return Ok(Value::OutputList(Vec::new()));
112        }
113        if out_count <= 1 {
114            let linear = eval.linear_value()?;
115            return Ok(crate::output_count::output_list_with_padding(
116                out_count,
117                vec![linear],
118            ));
119        }
120        let rows = eval.row_value()?;
121        let cols = eval.column_value()?;
122        let mut outputs = vec![rows, cols];
123        if out_count >= 3 {
124            outputs.push(eval.values_value()?);
125        }
126        return Ok(crate::output_count::output_list_with_padding(
127            out_count, outputs,
128        ));
129    }
130    eval.linear_value()
131}
132
133/// Evaluate `find` and return an object that can materialise the various outputs.
134pub async fn evaluate(value: Value, args: &[Value]) -> crate::BuiltinResult<FindEval> {
135    let options = parse_options(args).await?;
136    match value {
137        Value::GpuTensor(handle) => {
138            if let Some(result) = try_provider_find(&handle, &options) {
139                return Ok(FindEval::from_gpu(result));
140            }
141            let (storage, _) = materialize_input(Value::GpuTensor(handle)).await?;
142            let result = compute_find(&storage, &options);
143            Ok(FindEval::from_host(result, true))
144        }
145        other => {
146            let (storage, input_was_gpu) = materialize_input(other).await?;
147            let result = compute_find(&storage, &options);
148            Ok(FindEval::from_host(result, input_was_gpu))
149        }
150    }
151}
152
153fn try_provider_find(
154    handle: &runmat_accelerate_api::GpuTensorHandle,
155    options: &FindOptions,
156) -> Option<ProviderFindResult> {
157    #[cfg(all(test, feature = "wgpu"))]
158    {
159        if handle.device_id != 0 {
160            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
161                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
162            );
163        }
164    }
165    let provider = runmat_accelerate_api::provider()?;
166    let direction = match options.direction {
167        FindDirection::First => runmat_accelerate_api::FindDirection::First,
168        FindDirection::Last => runmat_accelerate_api::FindDirection::Last,
169    };
170    let limit = options.effective_limit();
171    provider.find(handle, limit, direction).ok()
172}
173
/// End of the element order from which `find` starts scanning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum FindDirection {
    /// Scan forward from the first element (the default).
    #[default]
    First,
    /// Scan backward from the last element.
    Last,
}

/// Parsed optional arguments of `find(X, K, direction)`.
///
/// The default is no limit, scanning from the front.
#[derive(Debug, Clone, Default)]
struct FindOptions {
    /// Maximum number of hits requested (`K`). `None` means unlimited.
    limit: Option<usize>,
    /// Which end to scan from.
    direction: FindDirection,
}

impl FindOptions {
    /// Limit actually applied by the scan.
    ///
    /// A `'last'` search without an explicit `K` is capped at a single hit.
    /// A `'first'` search uses `limit` unchanged.
    fn effective_limit(&self) -> Option<usize> {
        match self.direction {
            FindDirection::Last => self.limit.or(Some(1)),
            FindDirection::First => self.limit,
        }
    }
}
203
204#[derive(Clone)]
205enum DataStorage {
206    Real(Tensor),
207    Complex(ComplexTensor),
208}
209
210impl DataStorage {
211    fn shape(&self) -> &[usize] {
212        match self {
213            DataStorage::Real(t) => &t.shape,
214            DataStorage::Complex(t) => &t.shape,
215        }
216    }
217}
218
/// Host-side outcome of a `find` scan.
#[derive(Debug, Clone)]
struct FindResult {
    /// Shape of the searched array. Its first extent turns linear indices into subscripts.
    shape: Vec<usize>,
    /// 1-based linear indices of the hits, in scan order.
    indices: Vec<usize>,
    /// Element values at each hit, parallel to `indices`.
    values: FindValues,
}

/// Values of the nonzero elements, matching the input's real/complex storage.
#[derive(Debug, Clone)]
enum FindValues {
    Real(Vec<f64>),
    /// `(re, im)` pairs.
    Complex(Vec<(f64, f64)>),
}
231
/// Result of evaluating `find`. Each output (linear indices, row subscripts,
/// column subscripts, values) is materialised on demand.
pub struct FindEval {
    inner: FindEvalInner,
}

/// Where the `find` result was computed.
enum FindEvalInner {
    /// Computed on the host. `prefer_gpu` asks for outputs to be uploaded to
    /// the active provider when possible (see `tensor_to_value`).
    Host {
        result: FindResult,
        prefer_gpu: bool,
    },
    /// Computed by the acceleration provider. Outputs are its device buffers.
    Gpu {
        result: ProviderFindResult,
    },
}
245
246impl FindEval {
247    fn from_host(result: FindResult, prefer_gpu: bool) -> Self {
248        Self {
249            inner: FindEvalInner::Host { result, prefer_gpu },
250        }
251    }
252
253    fn from_gpu(result: ProviderFindResult) -> Self {
254        Self {
255            inner: FindEvalInner::Gpu { result },
256        }
257    }
258
259    pub fn linear_value(&self) -> crate::BuiltinResult<Value> {
260        match &self.inner {
261            FindEvalInner::Host { result, prefer_gpu } => {
262                let tensor = result.linear_tensor()?;
263                Ok(tensor_to_value(tensor, *prefer_gpu))
264            }
265            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.linear.clone())),
266        }
267    }
268
269    pub fn row_value(&self) -> crate::BuiltinResult<Value> {
270        match &self.inner {
271            FindEvalInner::Host { result, prefer_gpu } => {
272                let tensor = result.row_tensor()?;
273                Ok(tensor_to_value(tensor, *prefer_gpu))
274            }
275            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.rows.clone())),
276        }
277    }
278
279    pub fn column_value(&self) -> crate::BuiltinResult<Value> {
280        match &self.inner {
281            FindEvalInner::Host { result, prefer_gpu } => {
282                let tensor = result.column_tensor()?;
283                Ok(tensor_to_value(tensor, *prefer_gpu))
284            }
285            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.cols.clone())),
286        }
287    }
288
289    pub fn values_value(&self) -> crate::BuiltinResult<Value> {
290        match &self.inner {
291            FindEvalInner::Host { result, prefer_gpu } => result.values_value(*prefer_gpu),
292            FindEvalInner::Gpu { result } => result
293                .values
294                .as_ref()
295                .map(|handle| Value::GpuTensor(handle.clone()))
296                .ok_or_else(|| find_error("find: provider did not return values buffer")),
297        }
298    }
299}
300
301async fn parse_options(args: &[Value]) -> crate::BuiltinResult<FindOptions> {
302    parse_find_tokens(&crate::builtins::common::arg_tokens::tokens_from_values(
303        args,
304    ))
305}
306
307fn parse_limit_scalar(value: f64) -> crate::BuiltinResult<usize> {
308    if !value.is_finite() {
309        return Err(find_error("find: K must be a finite, non-negative integer"));
310    }
311    let rounded = value.round();
312    if (rounded - value).abs() > f64::EPSILON {
313        return Err(find_error("find: K must be a finite, non-negative integer"));
314    }
315    if rounded < 0.0 {
316        return Err(find_error("find: K must be >= 0"));
317    }
318    Ok(rounded as usize)
319}
320
321async fn materialize_input(value: Value) -> crate::BuiltinResult<(DataStorage, bool)> {
322    match value {
323        Value::GpuTensor(handle) => {
324            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
325            Ok((DataStorage::Real(tensor), true))
326        }
327        Value::Tensor(tensor) => Ok((DataStorage::Real(tensor), false)),
328        Value::LogicalArray(logical) => {
329            let tensor =
330                tensor::logical_to_tensor(&logical).map_err(|message| find_error(message))?;
331            Ok((DataStorage::Real(tensor), false))
332        }
333        Value::Num(n) => {
334            let tensor =
335                Tensor::new(vec![n], vec![1, 1]).map_err(|e| find_error(format!("find: {e}")))?;
336            Ok((DataStorage::Real(tensor), false))
337        }
338        Value::Int(i) => {
339            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
340                .map_err(|e| find_error(format!("find: {e}")))?;
341            Ok((DataStorage::Real(tensor), false))
342        }
343        Value::Bool(b) => {
344            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
345                .map_err(|e| find_error(format!("find: {e}")))?;
346            Ok((DataStorage::Real(tensor), false))
347        }
348        Value::Complex(re, im) => {
349            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
350                .map_err(|e| find_error(format!("find: {e}")))?;
351            Ok((DataStorage::Complex(tensor), false))
352        }
353        Value::ComplexTensor(tensor) => Ok((DataStorage::Complex(tensor), false)),
354        Value::CharArray(chars) => {
355            let mut data = Vec::with_capacity(chars.data.len());
356            for c in 0..chars.cols {
357                for r in 0..chars.rows {
358                    let ch = chars.data[r * chars.cols + c] as u32;
359                    data.push(ch as f64);
360                }
361            }
362            let tensor = Tensor::new(data, vec![chars.rows, chars.cols])
363                .map_err(|e| find_error(format!("find: {e}")))?;
364            Ok((DataStorage::Real(tensor), false))
365        }
366        other => Err(find_error(format!(
367            "find: unsupported input type {:?}; expected numeric, logical, or char data",
368            other
369        ))),
370    }
371}
372
373fn compute_find(storage: &DataStorage, options: &FindOptions) -> FindResult {
374    let shape = storage.shape().to_vec();
375    let limit = options.effective_limit();
376
377    match storage {
378        DataStorage::Real(tensor) => {
379            let mut indices = Vec::new();
380            let mut values = Vec::new();
381
382            if matches!(limit, Some(0)) {
383                return FindResult::new(shape, indices, FindValues::Real(values));
384            }
385
386            let len = tensor.data.len();
387            match options.direction {
388                FindDirection::First => {
389                    for idx in 0..len {
390                        let value = tensor.data[idx];
391                        if value != 0.0 {
392                            indices.push(idx + 1);
393                            values.push(value);
394                            if limit.is_some_and(|k| indices.len() >= k) {
395                                break;
396                            }
397                        }
398                    }
399                }
400                FindDirection::Last => {
401                    for idx in (0..len).rev() {
402                        let value = tensor.data[idx];
403                        if value != 0.0 {
404                            indices.push(idx + 1);
405                            values.push(value);
406                            if limit.is_some_and(|k| indices.len() >= k) {
407                                break;
408                            }
409                        }
410                    }
411                }
412            }
413
414            FindResult::new(shape, indices, FindValues::Real(values))
415        }
416        DataStorage::Complex(tensor) => {
417            let mut indices = Vec::new();
418            let mut values = Vec::new();
419
420            if matches!(limit, Some(0)) {
421                return FindResult::new(shape, indices, FindValues::Complex(values));
422            }
423
424            let len = tensor.data.len();
425            match options.direction {
426                FindDirection::First => {
427                    for idx in 0..len {
428                        let (re, im) = tensor.data[idx];
429                        if re != 0.0 || im != 0.0 {
430                            indices.push(idx + 1);
431                            values.push((re, im));
432                            if limit.is_some_and(|k| indices.len() >= k) {
433                                break;
434                            }
435                        }
436                    }
437                }
438                FindDirection::Last => {
439                    for idx in (0..len).rev() {
440                        let (re, im) = tensor.data[idx];
441                        if re != 0.0 || im != 0.0 {
442                            indices.push(idx + 1);
443                            values.push((re, im));
444                            if limit.is_some_and(|k| indices.len() >= k) {
445                                break;
446                            }
447                        }
448                    }
449                }
450            }
451
452            FindResult::new(shape, indices, FindValues::Complex(values))
453        }
454    }
455}
456
457impl FindResult {
458    fn new(shape: Vec<usize>, indices: Vec<usize>, values: FindValues) -> Self {
459        Self {
460            shape,
461            indices,
462            values,
463        }
464    }
465
466    fn linear_tensor(&self) -> crate::BuiltinResult<Tensor> {
467        let data: Vec<f64> = self.indices.iter().map(|&idx| idx as f64).collect();
468        let rows = data.len();
469        Tensor::new(data, vec![rows, 1]).map_err(|e| find_error(format!("find: {e}")))
470    }
471
472    fn row_tensor(&self) -> crate::BuiltinResult<Tensor> {
473        let mut data = Vec::with_capacity(self.indices.len());
474        let rows = self.shape.first().copied().unwrap_or(1).max(1);
475        for &idx in &self.indices {
476            let zero_based = idx - 1;
477            let row = (zero_based % rows) + 1;
478            data.push(row as f64);
479        }
480        Tensor::new(data, vec![self.indices.len(), 1]).map_err(|e| find_error(format!("find: {e}")))
481    }
482
483    fn column_tensor(&self) -> crate::BuiltinResult<Tensor> {
484        let mut data = Vec::with_capacity(self.indices.len());
485        let rows = self.shape.first().copied().unwrap_or(1).max(1);
486        for &idx in &self.indices {
487            let zero_based = idx - 1;
488            let col = (zero_based / rows) + 1;
489            data.push(col as f64);
490        }
491        Tensor::new(data, vec![self.indices.len(), 1]).map_err(|e| find_error(format!("find: {e}")))
492    }
493
494    fn values_value(&self, prefer_gpu: bool) -> crate::BuiltinResult<Value> {
495        match &self.values {
496            FindValues::Real(values) => {
497                let tensor = Tensor::new(values.clone(), vec![values.len(), 1])
498                    .map_err(|e| find_error(format!("find: {e}")))?;
499                Ok(tensor_to_value(tensor, prefer_gpu))
500            }
501            FindValues::Complex(values) => {
502                let tensor = ComplexTensor::new(values.clone(), vec![values.len(), 1])
503                    .map_err(|e| find_error(format!("find: {e}")))?;
504                Ok(complex_tensor_into_value(tensor))
505            }
506        }
507    }
508}
509
510fn tensor_to_value(tensor: Tensor, prefer_gpu: bool) -> Value {
511    if prefer_gpu {
512        if let Some(provider) = runmat_accelerate_api::provider() {
513            let view = HostTensorView {
514                data: &tensor.data,
515                shape: &tensor.shape,
516            };
517            if let Ok(handle) = provider.upload(&view) {
518                return Value::GpuTensor(handle);
519            }
520        }
521    }
522    tensor::tensor_into_value(tensor)
523}
524
525fn find_error(message: impl Into<String>) -> RuntimeError {
526    build_runtime_error(message).with_builtin("find").build()
527}
528
/// Unit tests for `find`: host semantics, argument validation, and GPU residency.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{CharArray, IntValue, Type};

    // Synchronous wrappers so tests can drive the async entry points.
    fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::find_builtin(value, rest))
    }

    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<FindEval> {
        block_on(super::evaluate(value, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_linear_indices_basic() {
        let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0, 0.0, 9.0], vec![2, 3]).unwrap();
        let value = find_builtin(Value::Tensor(tensor), Vec::new()).expect("find");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 1]);
                assert_eq!(t.data, vec![2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn find_type_is_column_vector() {
        assert_eq!(
            find_type(
                &[Type::Tensor { shape: None }],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Tensor {
                shape: Some(vec![None, Some(1)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_limited_first() {
        let tensor = Tensor::new(vec![0.0, 3.0, 5.0, 0.0, 8.0], vec![1, 5]).unwrap();
        let result =
            find_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(2))]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 3.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_last_single() {
        let tensor = Tensor::new(vec![1.0, 0.0, 0.0, 6.0, 0.0, 2.0], vec![1, 6]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), vec![Value::from("last")]).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 6.0),
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![6.0]);
            }
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_complex_values() {
        let tensor =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0), (0.0, 0.0)], vec![3, 1]).unwrap();
        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("find compute");
        let values = eval.values_value().expect("values");
        match values {
            Value::Complex(re, im) => {
                assert_eq!(re, 1.0);
                assert_eq!(im, 2.0);
            }
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 1]);
                assert_eq!(ct.data, vec![(1.0, 2.0)]);
            }
            other => panic!("expected complex result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = find_builtin(Value::GpuTensor(handle), Vec::new()).expect("find");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![2.0, 4.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_direction_error() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::from("invalid")],
        )
        .expect_err("expected error");
        assert!(err.to_string().contains("direction"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_multi_output_rows_cols_values() {
        let tensor = Tensor::new(vec![0.0, 2.0, 3.0, 0.0, 0.0, 6.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");

        let rows = test_support::gather(eval.row_value().expect("rows")).expect("gather rows");
        assert_eq!(rows.shape, vec![3, 1]);
        assert_eq!(rows.data, vec![2.0, 1.0, 2.0]);

        let cols = test_support::gather(eval.column_value().expect("cols")).expect("gather cols");
        assert_eq!(cols.shape, vec![3, 1]);
        assert_eq!(cols.data, vec![1.0, 2.0, 3.0]);

        let vals = test_support::gather(eval.values_value().expect("vals")).expect("gather vals");
        assert_eq!(vals.shape, vec![3, 1]);
        assert_eq!(vals.data, vec![2.0, 3.0, 6.0]);
    }

    // NOTE(review): pins descending order for 'last'; MATLAB appears to return
    // ascending indices here. Confirm this is intended.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_last_order_descending() {
        let tensor = Tensor::new(vec![1.0, 0.0, 2.0, 3.0, 0.0], vec![1, 5]).unwrap();
        let result = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(2)), Value::from("last")],
        )
        .expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                assert_eq!(t.data, vec![4.0, 3.0]);
            }
            Value::Num(_) => panic!("expected column vector"),
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_limit_zero_returns_empty() {
        let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), vec![Value::Num(0.0)]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_char_array_supports_nonzero_codes() {
        let chars = CharArray::new(vec!['\0', 'A', '\0'], 1, 3).unwrap();
        let result = find_builtin(Value::CharArray(chars), Vec::new()).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 2.0),
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0]),
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_all_zero_returns_empty_column() {
        let tensor = Tensor::new(vec![0.0, 0.0, 0.0], vec![3, 1]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), Vec::new()).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_treats_nan_as_nonzero() {
        let tensor = Tensor::new(vec![0.0, f64::NAN, 0.0], vec![3, 1]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), Vec::new()).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 2.0),
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0]),
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_negative_limit_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let _ = find_builtin(Value::Tensor(tensor), vec![Value::Num(-1.0)])
            .expect_err("negative K must be rejected");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_fractional_limit_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let _ = find_builtin(Value::Tensor(tensor), vec![Value::Num(1.5)])
            .expect_err("fractional K must be rejected");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_too_many_arguments_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Num(1.0), Value::from("first"), Value::from("last")],
        )
        .expect_err("expected error");
        assert!(err.to_string().contains("too many"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_gpu_multi_outputs_return_gpu_handles() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 4.0, 5.0, 0.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");

            let rows = eval.row_value().expect("rows");
            assert!(matches!(rows, Value::GpuTensor(_)));
            let rows_host = test_support::gather(rows).expect("gather rows");
            assert_eq!(rows_host.data, vec![2.0, 1.0]);

            let cols = eval.column_value().expect("cols");
            assert!(matches!(cols, Value::GpuTensor(_)));
            let cols_host = test_support::gather(cols).expect("gather cols");
            assert_eq!(cols_host.data, vec![1.0, 2.0]);

            let vals = eval.values_value().expect("vals");
            assert!(matches!(vals, Value::GpuTensor(_)));
            let vals_host = test_support::gather(vals).expect("gather vals");
            assert_eq!(vals_host.data, vec![4.0, 5.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn find_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![0.0, 2.0, 0.0, 3.0, 4.0, 0.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let cpu_linear =
            test_support::gather(cpu_eval.linear_value().expect("cpu linear")).expect("cpu gather");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");
        let gpu_linear =
            test_support::gather(gpu_eval.linear_value().expect("gpu linear")).expect("gpu gather");
        assert_eq!(gpu_linear.data, cpu_linear.data);
    }
}