Skip to main content

runmat_runtime/builtins/array/indexing/
find.rs

1//! MATLAB-compatible `find` builtin with GPU-aware semantics for RunMat.
2
3use runmat_accelerate_api::{HostTensorView, ProviderFindResult};
4use runmat_builtins::{
5    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7    ComplexTensor, ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::array::type_resolvers::column_vector_type;
12use crate::builtins::common::arg_tokens::ArgToken;
13use crate::builtins::common::random_args::complex_tensor_into_value;
14use crate::builtins::common::spec::{
15    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
16    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
17};
18use crate::builtins::common::{gpu_helpers, tensor};
19use crate::{build_runtime_error, RuntimeError};
20
// GPU execution metadata for `find`.
//
// `find` is wired through a custom provider hook (`ProviderHook::Custom("find")`) rather than
// a generic elementwise/reduction template, and declares `ResidencyPolicy::NewHandle` for its
// outputs. Providers without a device implementation are handled by the host fallback in
// `evaluate`, which re-uploads results (see `notes`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::find")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "find",
    op_kind: GpuOpKind::Custom("find"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("find")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "WGPU provider executes find directly on device; other providers fall back to host and re-upload results to preserve residency.",
};
36
// Fusion metadata for `find`.
//
// No elementwise or reduction template is declared (both are `None`), so the fusion planner
// has nothing to inline; the entry exists only so `find` is present in the fusion registry.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::find")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "find",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Find drives control flow and currently bypasses fusion; metadata is present for completeness only.",
};
47
/// Static type resolver for `find`: always reports a column vector, ignoring the argument
/// types and resolution context.
///
/// NOTE(review): MATLAB returns a row vector when `X` is a row vector; this resolver (like the
/// host output builders, which always produce n×1 tensors) reports a column — confirm intended.
fn find_type(_args: &[Type], _ctx: &ResolveContext) -> Type {
    column_vector_type()
}
51
/// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "find";
53
54const FIND_OUTPUT_LINEAR: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
55    name: "idx",
56    ty: BuiltinParamType::NumericArray,
57    arity: BuiltinParamArity::Required,
58    default: None,
59    description: "Linear indices of non-zero elements.",
60}];
61
62const FIND_OUTPUT_ROW_COL: [BuiltinParamDescriptor; 2] = [
63    BuiltinParamDescriptor {
64        name: "row",
65        ty: BuiltinParamType::NumericArray,
66        arity: BuiltinParamArity::Required,
67        default: None,
68        description: "Row subscripts of non-zero elements.",
69    },
70    BuiltinParamDescriptor {
71        name: "col",
72        ty: BuiltinParamType::NumericArray,
73        arity: BuiltinParamArity::Required,
74        default: None,
75        description: "Column subscripts of non-zero elements.",
76    },
77];
78
79const FIND_OUTPUT_ROW_COL_VAL: [BuiltinParamDescriptor; 3] = [
80    BuiltinParamDescriptor {
81        name: "row",
82        ty: BuiltinParamType::NumericArray,
83        arity: BuiltinParamArity::Required,
84        default: None,
85        description: "Row subscripts of non-zero elements.",
86    },
87    BuiltinParamDescriptor {
88        name: "col",
89        ty: BuiltinParamType::NumericArray,
90        arity: BuiltinParamArity::Required,
91        default: None,
92        description: "Column subscripts of non-zero elements.",
93    },
94    BuiltinParamDescriptor {
95        name: "v",
96        ty: BuiltinParamType::Any,
97        arity: BuiltinParamArity::Required,
98        default: None,
99        description: "Values at the reported row/column locations.",
100    },
101];
102
103const FIND_INPUTS_BASE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
104    name: "X",
105    ty: BuiltinParamType::Any,
106    arity: BuiltinParamArity::Required,
107    default: None,
108    description: "Input array to search.",
109}];
110
111const FIND_INPUTS_LIMIT: [BuiltinParamDescriptor; 2] = [
112    BuiltinParamDescriptor {
113        name: "X",
114        ty: BuiltinParamType::Any,
115        arity: BuiltinParamArity::Required,
116        default: None,
117        description: "Input array to search.",
118    },
119    BuiltinParamDescriptor {
120        name: "K",
121        ty: BuiltinParamType::NumericScalar,
122        arity: BuiltinParamArity::Required,
123        default: None,
124        description: "Maximum number of indices to return.",
125    },
126];
127
128const FIND_INPUTS_LIMIT_DIR: [BuiltinParamDescriptor; 3] = [
129    BuiltinParamDescriptor {
130        name: "X",
131        ty: BuiltinParamType::Any,
132        arity: BuiltinParamArity::Required,
133        default: None,
134        description: "Input array to search.",
135    },
136    BuiltinParamDescriptor {
137        name: "K",
138        ty: BuiltinParamType::NumericScalar,
139        arity: BuiltinParamArity::Required,
140        default: None,
141        description: "Maximum number of indices to return.",
142    },
143    BuiltinParamDescriptor {
144        name: "direction",
145        ty: BuiltinParamType::StringScalar,
146        arity: BuiltinParamArity::Required,
147        default: Some("\"first\""),
148        description: "Direction selector: `\"first\"` or `\"last\"`.",
149    },
150];
151
152const FIND_SIGNATURES: [BuiltinSignatureDescriptor; 7] = [
153    BuiltinSignatureDescriptor {
154        label: "idx = find(X)",
155        inputs: &FIND_INPUTS_BASE,
156        outputs: &FIND_OUTPUT_LINEAR,
157    },
158    BuiltinSignatureDescriptor {
159        label: "idx = find(X, K)",
160        inputs: &FIND_INPUTS_LIMIT,
161        outputs: &FIND_OUTPUT_LINEAR,
162    },
163    BuiltinSignatureDescriptor {
164        label: "idx = find(X, K, direction)",
165        inputs: &FIND_INPUTS_LIMIT_DIR,
166        outputs: &FIND_OUTPUT_LINEAR,
167    },
168    BuiltinSignatureDescriptor {
169        label: "[row, col] = find(X)",
170        inputs: &FIND_INPUTS_BASE,
171        outputs: &FIND_OUTPUT_ROW_COL,
172    },
173    BuiltinSignatureDescriptor {
174        label: "[row, col] = find(X, K, direction)",
175        inputs: &FIND_INPUTS_LIMIT_DIR,
176        outputs: &FIND_OUTPUT_ROW_COL,
177    },
178    BuiltinSignatureDescriptor {
179        label: "[row, col, v] = find(X)",
180        inputs: &FIND_INPUTS_BASE,
181        outputs: &FIND_OUTPUT_ROW_COL_VAL,
182    },
183    BuiltinSignatureDescriptor {
184        label: "[row, col, v] = find(X, K, direction)",
185        inputs: &FIND_INPUTS_LIMIT_DIR,
186        outputs: &FIND_OUTPUT_ROW_COL_VAL,
187    },
188];
189
/// Raised for unsupported input value kinds, a malformed `K`, an unrecognised direction
/// string, or too many option arguments.
const FIND_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FIND.INVALID_INPUT",
    identifier: Some("RunMat:find:InvalidInput"),
    when: "Input type or option arguments are not valid for find.",
    message: "find: invalid input arguments",
};

/// Raised when a provider result lacks the values buffer needed for `[row, col, v]`.
const FIND_ERROR_PROVIDER_OUTPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FIND.PROVIDER_OUTPUT",
    identifier: Some("RunMat:find:ProviderOutput"),
    when: "GPU provider does not return expected output buffers for requested nargout.",
    message: "find: provider output buffer mismatch",
};

/// Raised when constructing a host tensor (input conversion or output building) fails.
const FIND_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FIND.INTERNAL",
    identifier: Some("RunMat:find:InternalError"),
    when: "Internal tensor conversion/materialization fails while building outputs.",
    message: "find: internal error",
};

/// Error catalogue published through `FIND_DESCRIPTOR`.
const FIND_ERRORS: [BuiltinErrorDescriptor; 3] = [
    FIND_ERROR_INVALID_INPUT,
    FIND_ERROR_PROVIDER_OUTPUT,
    FIND_ERROR_INTERNAL,
];

/// Public descriptor for `find` (referenced by the `runtime_builtin` attribute): call
/// signatures, output count driven by the caller's requested outputs, and error catalogue.
pub const FIND_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &FIND_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &FIND_ERRORS,
};
223
224fn find_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
225    find_error_with_message(error.message, error)
226}
227
228fn find_error_with_message(
229    message: impl Into<String>,
230    error: &'static BuiltinErrorDescriptor,
231) -> RuntimeError {
232    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
233    if let Some(identifier) = error.identifier {
234        builder = builder.with_identifier(identifier);
235    }
236    builder.build()
237}
238
239fn parse_find_tokens(tokens: &[ArgToken]) -> crate::BuiltinResult<FindOptions> {
240    match tokens.len() {
241        0 => Ok(FindOptions::default()),
242        1 => {
243            if let Some(direction) = token_to_direction(&tokens[0])? {
244                let limit = if matches!(direction, FindDirection::Last) {
245                    Some(1)
246                } else {
247                    None
248                };
249                Ok(FindOptions { limit, direction })
250            } else {
251                let limit = token_to_limit(&tokens[0])?;
252                Ok(FindOptions {
253                    limit: Some(limit),
254                    direction: FindDirection::First,
255                })
256            }
257        }
258        2 => {
259            let limit = token_to_limit(&tokens[0])?;
260            let direction = token_to_direction(&tokens[1])?.ok_or_else(|| {
261                find_error_with_message(
262                    "find: third argument must be 'first' or 'last'",
263                    &FIND_ERROR_INVALID_INPUT,
264                )
265            })?;
266            Ok(FindOptions {
267                limit: Some(limit),
268                direction,
269            })
270        }
271        _ => Err(find_error_with_message(
272            "find: too many input arguments",
273            &FIND_ERROR_INVALID_INPUT,
274        )),
275    }
276}
277
278fn token_to_direction(token: &ArgToken) -> crate::BuiltinResult<Option<FindDirection>> {
279    match token {
280        ArgToken::String(text) => match text.as_str() {
281            "first" => Ok(Some(FindDirection::First)),
282            "last" => Ok(Some(FindDirection::Last)),
283            _ => Err(find_error_with_message(
284                "find: direction must be 'first' or 'last'",
285                &FIND_ERROR_INVALID_INPUT,
286            )),
287        },
288        _ => Ok(None),
289    }
290}
291
292fn token_to_limit(token: &ArgToken) -> crate::BuiltinResult<usize> {
293    match token {
294        ArgToken::Number(value) => parse_limit_scalar(*value),
295        _ => Err(find_error_with_message(
296            "find: second argument must be a scalar",
297            &FIND_ERROR_INVALID_INPUT,
298        )),
299    }
300}
301
302#[runtime_builtin(
303    name = "find",
304    category = "array/indexing",
305    summary = "Locate nonzero indices and values.",
306    keywords = "find,nonzero,indices,row,column,gpu",
307    accel = "custom",
308    type_resolver(find_type),
309    descriptor(crate::builtins::array::indexing::find::FIND_DESCRIPTOR),
310    builtin_path = "crate::builtins::array::indexing::find"
311)]
312async fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
313    let eval = evaluate(value, &rest).await?;
314    if let Some(out_count) = crate::output_count::current_output_count() {
315        if out_count == 0 {
316            return Ok(Value::OutputList(Vec::new()));
317        }
318        if out_count <= 1 {
319            let linear = eval.linear_value()?;
320            return Ok(crate::output_count::output_list_with_padding(
321                out_count,
322                vec![linear],
323            ));
324        }
325        let rows = eval.row_value()?;
326        let cols = eval.column_value()?;
327        let mut outputs = vec![rows, cols];
328        if out_count >= 3 {
329            outputs.push(eval.values_value()?);
330        }
331        return Ok(crate::output_count::output_list_with_padding(
332            out_count, outputs,
333        ));
334    }
335    eval.linear_value()
336}
337
338/// Evaluate `find` and return an object that can materialise the various outputs.
339pub async fn evaluate(value: Value, args: &[Value]) -> crate::BuiltinResult<FindEval> {
340    let options = parse_options(args).await?;
341    match value {
342        Value::GpuTensor(handle) => {
343            if let Some(result) = try_provider_find(&handle, &options) {
344                return Ok(FindEval::from_gpu(result));
345            }
346            let (storage, _) = materialize_input(Value::GpuTensor(handle)).await?;
347            let result = compute_find(&storage, &options);
348            Ok(FindEval::from_host(result, true))
349        }
350        other => {
351            let (storage, input_was_gpu) = materialize_input(other).await?;
352            let result = compute_find(&storage, &options);
353            Ok(FindEval::from_host(result, input_was_gpu))
354        }
355    }
356}
357
358fn try_provider_find(
359    handle: &runmat_accelerate_api::GpuTensorHandle,
360    options: &FindOptions,
361) -> Option<ProviderFindResult> {
362    #[cfg(all(test, feature = "wgpu"))]
363    {
364        if handle.device_id != 0 {
365            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
366                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
367            );
368        }
369    }
370    let provider = runmat_accelerate_api::provider()?;
371    let direction = match options.direction {
372        FindDirection::First => runmat_accelerate_api::FindDirection::First,
373        FindDirection::Last => runmat_accelerate_api::FindDirection::Last,
374    };
375    let limit = options.effective_limit();
376    provider.find(handle, limit, direction).ok()
377}
378
/// Which end of the data the scan starts from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum FindDirection {
    /// Scan from the first element forward (the default).
    #[default]
    First,
    /// Scan from the last element backward.
    Last,
}

/// Parsed `K` / direction options for a `find` call.
///
/// The default is an unlimited forward scan.
#[derive(Debug, Clone, Default)]
struct FindOptions {
    /// Explicit hit limit `K`, if one was supplied.
    limit: Option<usize>,
    /// Scan direction.
    direction: FindDirection,
}

impl FindOptions {
    /// The limit actually applied to the scan.
    ///
    /// An explicit `K` always wins. Otherwise a backward (`'last'`) scan stops after one hit,
    /// and a forward scan is unlimited.
    fn effective_limit(&self) -> Option<usize> {
        let implicit = matches!(self.direction, FindDirection::Last).then_some(1);
        self.limit.or(implicit)
    }
}
408
409#[derive(Clone)]
410enum DataStorage {
411    Real(Tensor),
412    Complex(ComplexTensor),
413}
414
415impl DataStorage {
416    fn shape(&self) -> &[usize] {
417        match self {
418            DataStorage::Real(t) => &t.shape,
419            DataStorage::Complex(t) => &t.shape,
420        }
421    }
422}
423
/// Host-side outcome of a `find` scan.
#[derive(Clone)]
struct FindResult {
    /// Shape of the searched input; its first dimension splits linear indices into
    /// row/column subscripts.
    shape: Vec<usize>,
    /// 1-based linear indices of the selected nonzero elements, in scan order.
    indices: Vec<usize>,
    /// Values at `indices`, in the same order.
    values: FindValues,
}

/// Selected element values, real or complex to match the input storage.
#[derive(Clone)]
enum FindValues {
    Real(Vec<f64>),
    Complex(Vec<(f64, f64)>),
}

/// Result of `evaluate`; the individual outputs (linear indices, row/column subscripts,
/// values) are built only when the corresponding accessor is called.
pub struct FindEval {
    inner: FindEvalInner,
}

/// Where the `find` result lives.
enum FindEvalInner {
    /// Computed on the host. When `prefer_gpu` is true, real outputs are uploaded to the
    /// active provider if one is available.
    Host {
        result: FindResult,
        prefer_gpu: bool,
    },
    /// Computed by the acceleration provider; its handles are returned as-is.
    Gpu {
        result: ProviderFindResult,
    },
}
450
451impl FindEval {
452    fn from_host(result: FindResult, prefer_gpu: bool) -> Self {
453        Self {
454            inner: FindEvalInner::Host { result, prefer_gpu },
455        }
456    }
457
458    fn from_gpu(result: ProviderFindResult) -> Self {
459        Self {
460            inner: FindEvalInner::Gpu { result },
461        }
462    }
463
464    pub fn linear_value(&self) -> crate::BuiltinResult<Value> {
465        match &self.inner {
466            FindEvalInner::Host { result, prefer_gpu } => {
467                let tensor = result.linear_tensor()?;
468                Ok(tensor_to_value(tensor, *prefer_gpu))
469            }
470            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.linear.clone())),
471        }
472    }
473
474    pub fn row_value(&self) -> crate::BuiltinResult<Value> {
475        match &self.inner {
476            FindEvalInner::Host { result, prefer_gpu } => {
477                let tensor = result.row_tensor()?;
478                Ok(tensor_to_value(tensor, *prefer_gpu))
479            }
480            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.rows.clone())),
481        }
482    }
483
484    pub fn column_value(&self) -> crate::BuiltinResult<Value> {
485        match &self.inner {
486            FindEvalInner::Host { result, prefer_gpu } => {
487                let tensor = result.column_tensor()?;
488                Ok(tensor_to_value(tensor, *prefer_gpu))
489            }
490            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.cols.clone())),
491        }
492    }
493
494    pub fn values_value(&self) -> crate::BuiltinResult<Value> {
495        match &self.inner {
496            FindEvalInner::Host { result, prefer_gpu } => result.values_value(*prefer_gpu),
497            FindEvalInner::Gpu { result } => result
498                .values
499                .as_ref()
500                .map(|handle| Value::GpuTensor(handle.clone()))
501                .ok_or_else(|| find_error(&FIND_ERROR_PROVIDER_OUTPUT)),
502        }
503    }
504}
505
506async fn parse_options(args: &[Value]) -> crate::BuiltinResult<FindOptions> {
507    parse_find_tokens(&crate::builtins::common::arg_tokens::tokens_from_values(
508        args,
509    ))
510}
511
512fn parse_limit_scalar(value: f64) -> crate::BuiltinResult<usize> {
513    if !value.is_finite() {
514        return Err(find_error_with_message(
515            "find: K must be a finite, non-negative integer",
516            &FIND_ERROR_INVALID_INPUT,
517        ));
518    }
519    let rounded = value.round();
520    if (rounded - value).abs() > f64::EPSILON {
521        return Err(find_error_with_message(
522            "find: K must be a finite, non-negative integer",
523            &FIND_ERROR_INVALID_INPUT,
524        ));
525    }
526    if rounded < 0.0 {
527        return Err(find_error_with_message(
528            "find: K must be >= 0",
529            &FIND_ERROR_INVALID_INPUT,
530        ));
531    }
532    Ok(rounded as usize)
533}
534
535async fn materialize_input(value: Value) -> crate::BuiltinResult<(DataStorage, bool)> {
536    match value {
537        Value::GpuTensor(handle) => {
538            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
539            Ok((DataStorage::Real(tensor), true))
540        }
541        Value::Tensor(tensor) => Ok((DataStorage::Real(tensor), false)),
542        Value::LogicalArray(logical) => {
543            let tensor = tensor::logical_to_tensor(&logical)
544                .map_err(|message| find_error_with_message(message, &FIND_ERROR_INTERNAL))?;
545            Ok((DataStorage::Real(tensor), false))
546        }
547        Value::Num(n) => {
548            let tensor = Tensor::new(vec![n], vec![1, 1])
549                .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
550            Ok((DataStorage::Real(tensor), false))
551        }
552        Value::Int(i) => {
553            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
554                .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
555            Ok((DataStorage::Real(tensor), false))
556        }
557        Value::Bool(b) => {
558            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
559                .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
560            Ok((DataStorage::Real(tensor), false))
561        }
562        Value::Complex(re, im) => {
563            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
564                .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
565            Ok((DataStorage::Complex(tensor), false))
566        }
567        Value::ComplexTensor(tensor) => Ok((DataStorage::Complex(tensor), false)),
568        Value::CharArray(chars) => {
569            let mut data = Vec::with_capacity(chars.data.len());
570            for c in 0..chars.cols {
571                for r in 0..chars.rows {
572                    let ch = chars.data[r * chars.cols + c] as u32;
573                    data.push(ch as f64);
574                }
575            }
576            let tensor = Tensor::new(data, vec![chars.rows, chars.cols])
577                .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
578            Ok((DataStorage::Real(tensor), false))
579        }
580        other => Err(find_error_with_message(
581            format!(
582                "find: unsupported input type {:?}; expected numeric, logical, or char data",
583                other
584            ),
585            &FIND_ERROR_INVALID_INPUT,
586        )),
587    }
588}
589
590fn compute_find(storage: &DataStorage, options: &FindOptions) -> FindResult {
591    let shape = storage.shape().to_vec();
592    let limit = options.effective_limit();
593
594    match storage {
595        DataStorage::Real(tensor) => {
596            let mut indices = Vec::new();
597            let mut values = Vec::new();
598
599            if matches!(limit, Some(0)) {
600                return FindResult::new(shape, indices, FindValues::Real(values));
601            }
602
603            let len = tensor.data.len();
604            match options.direction {
605                FindDirection::First => {
606                    for idx in 0..len {
607                        let value = tensor.data[idx];
608                        if value != 0.0 {
609                            indices.push(idx + 1);
610                            values.push(value);
611                            if limit.is_some_and(|k| indices.len() >= k) {
612                                break;
613                            }
614                        }
615                    }
616                }
617                FindDirection::Last => {
618                    for idx in (0..len).rev() {
619                        let value = tensor.data[idx];
620                        if value != 0.0 {
621                            indices.push(idx + 1);
622                            values.push(value);
623                            if limit.is_some_and(|k| indices.len() >= k) {
624                                break;
625                            }
626                        }
627                    }
628                }
629            }
630
631            FindResult::new(shape, indices, FindValues::Real(values))
632        }
633        DataStorage::Complex(tensor) => {
634            let mut indices = Vec::new();
635            let mut values = Vec::new();
636
637            if matches!(limit, Some(0)) {
638                return FindResult::new(shape, indices, FindValues::Complex(values));
639            }
640
641            let len = tensor.data.len();
642            match options.direction {
643                FindDirection::First => {
644                    for idx in 0..len {
645                        let (re, im) = tensor.data[idx];
646                        if re != 0.0 || im != 0.0 {
647                            indices.push(idx + 1);
648                            values.push((re, im));
649                            if limit.is_some_and(|k| indices.len() >= k) {
650                                break;
651                            }
652                        }
653                    }
654                }
655                FindDirection::Last => {
656                    for idx in (0..len).rev() {
657                        let (re, im) = tensor.data[idx];
658                        if re != 0.0 || im != 0.0 {
659                            indices.push(idx + 1);
660                            values.push((re, im));
661                            if limit.is_some_and(|k| indices.len() >= k) {
662                                break;
663                            }
664                        }
665                    }
666                }
667            }
668
669            FindResult::new(shape, indices, FindValues::Complex(values))
670        }
671    }
672}
673
674impl FindResult {
675    fn new(shape: Vec<usize>, indices: Vec<usize>, values: FindValues) -> Self {
676        Self {
677            shape,
678            indices,
679            values,
680        }
681    }
682
683    fn linear_tensor(&self) -> crate::BuiltinResult<Tensor> {
684        let data: Vec<f64> = self.indices.iter().map(|&idx| idx as f64).collect();
685        let rows = data.len();
686        Tensor::new(data, vec![rows, 1])
687            .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
688    }
689
690    fn row_tensor(&self) -> crate::BuiltinResult<Tensor> {
691        let mut data = Vec::with_capacity(self.indices.len());
692        let rows = self.shape.first().copied().unwrap_or(1).max(1);
693        for &idx in &self.indices {
694            let zero_based = idx - 1;
695            let row = (zero_based % rows) + 1;
696            data.push(row as f64);
697        }
698        Tensor::new(data, vec![self.indices.len(), 1])
699            .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
700    }
701
702    fn column_tensor(&self) -> crate::BuiltinResult<Tensor> {
703        let mut data = Vec::with_capacity(self.indices.len());
704        let rows = self.shape.first().copied().unwrap_or(1).max(1);
705        for &idx in &self.indices {
706            let zero_based = idx - 1;
707            let col = (zero_based / rows) + 1;
708            data.push(col as f64);
709        }
710        Tensor::new(data, vec![self.indices.len(), 1])
711            .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
712    }
713
714    fn values_value(&self, prefer_gpu: bool) -> crate::BuiltinResult<Value> {
715        match &self.values {
716            FindValues::Real(values) => {
717                let tensor = Tensor::new(values.clone(), vec![values.len(), 1]).map_err(|e| {
718                    find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL)
719                })?;
720                Ok(tensor_to_value(tensor, prefer_gpu))
721            }
722            FindValues::Complex(values) => {
723                let tensor =
724                    ComplexTensor::new(values.clone(), vec![values.len(), 1]).map_err(|e| {
725                        find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL)
726                    })?;
727                Ok(complex_tensor_into_value(tensor))
728            }
729        }
730    }
731}
732
733fn tensor_to_value(tensor: Tensor, prefer_gpu: bool) -> Value {
734    if prefer_gpu {
735        if let Some(provider) = runmat_accelerate_api::provider() {
736            let view = HostTensorView {
737                data: &tensor.data,
738                shape: &tensor.shape,
739            };
740            if let Ok(handle) = provider.upload(&view) {
741                return Value::GpuTensor(handle);
742            }
743        }
744    }
745    tensor::tensor_into_value(tensor)
746}
747
748#[cfg(test)]
749pub(crate) mod tests {
750    use super::*;
751    use crate::builtins::common::test_support;
752    use futures::executor::block_on;
753    use runmat_builtins::{CharArray, IntValue, Type};
754
755    fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
756        block_on(super::find_builtin(value, rest))
757    }
758
759    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<FindEval> {
760        block_on(super::evaluate(value, rest))
761    }
762
763    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
764    #[test]
765    fn find_linear_indices_basic() {
766        let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0, 0.0, 9.0], vec![2, 3]).unwrap();
767        let value = find_builtin(Value::Tensor(tensor), Vec::new()).expect("find");
768        match value {
769            Value::Tensor(t) => {
770                assert_eq!(t.shape, vec![3, 1]);
771                assert_eq!(t.data, vec![2.0, 4.0, 6.0]);
772            }
773            other => panic!("expected tensor, got {other:?}"),
774        }
775    }
776
777    #[test]
778    fn find_type_is_column_vector() {
779        assert_eq!(
780            find_type(
781                &[Type::Tensor { shape: None }],
782                &ResolveContext::new(Vec::new()),
783            ),
784            Type::Tensor {
785                shape: Some(vec![None, Some(1)])
786            }
787        );
788    }
789
790    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
791    #[test]
792    fn find_limited_first() {
793        let tensor = Tensor::new(vec![0.0, 3.0, 5.0, 0.0, 8.0], vec![1, 5]).unwrap();
794        let result =
795            find_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(2))]).expect("find");
796        match result {
797            Value::Tensor(t) => {
798                assert_eq!(t.data, vec![2.0, 3.0]);
799            }
800            other => panic!("expected tensor, got {other:?}"),
801        }
802    }
803
804    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
805    #[test]
806    fn find_last_single() {
807        let tensor = Tensor::new(vec![1.0, 0.0, 0.0, 6.0, 0.0, 2.0], vec![1, 6]).unwrap();
808        let result = find_builtin(Value::Tensor(tensor), vec![Value::from("last")]).expect("find");
809        match result {
810            Value::Num(n) => assert_eq!(n, 6.0),
811            Value::Tensor(t) => {
812                assert_eq!(t.data, vec![6.0]);
813            }
814            other => panic!("unexpected result {other:?}"),
815        }
816    }
817
818    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
819    #[test]
820    fn find_complex_values() {
821        let tensor =
822            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0), (0.0, 0.0)], vec![3, 1]).unwrap();
823        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("find compute");
824        let values = eval.values_value().expect("values");
825        match values {
826            Value::Complex(re, im) => {
827                assert_eq!(re, 1.0);
828                assert_eq!(im, 2.0);
829            }
830            Value::ComplexTensor(ct) => {
831                assert_eq!(ct.shape, vec![1, 1]);
832                assert_eq!(ct.data, vec![(1.0, 2.0)]);
833            }
834            other => panic!("expected complex result, got {other:?}"),
835        }
836    }
837
838    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
839    #[test]
840    fn find_gpu_roundtrip() {
841        test_support::with_test_provider(|provider| {
842            let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0], vec![2, 2]).unwrap();
843            let view = HostTensorView {
844                data: &tensor.data,
845                shape: &tensor.shape,
846            };
847            let handle = provider.upload(&view).expect("upload");
848            let result = find_builtin(Value::GpuTensor(handle), Vec::new()).expect("find");
849            let gathered = test_support::gather(result).expect("gather");
850            assert_eq!(gathered.shape, vec![2, 1]);
851            assert_eq!(gathered.data, vec![2.0, 4.0]);
852        });
853    }
854
855    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
856    #[test]
857    fn find_direction_error() {
858        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
859        let err = find_builtin(
860            Value::Tensor(tensor),
861            vec![Value::Int(IntValue::I32(1)), Value::from("invalid")],
862        )
863        .expect_err("expected error");
864        assert!(err.to_string().contains("direction"));
865        assert_eq!(err.identifier(), super::FIND_ERROR_INVALID_INPUT.identifier);
866    }
867
868    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
869    #[test]
870    fn find_multi_output_rows_cols_values() {
871        let tensor = Tensor::new(vec![0.0, 2.0, 3.0, 0.0, 0.0, 6.0], vec![2, 3]).unwrap();
872        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
873
874        let rows = test_support::gather(eval.row_value().expect("rows")).expect("gather rows");
875        assert_eq!(rows.shape, vec![3, 1]);
876        assert_eq!(rows.data, vec![2.0, 1.0, 2.0]);
877
878        let cols = test_support::gather(eval.column_value().expect("cols")).expect("gather cols");
879        assert_eq!(cols.shape, vec![3, 1]);
880        assert_eq!(cols.data, vec![1.0, 2.0, 3.0]);
881
882        let vals = test_support::gather(eval.values_value().expect("vals")).expect("gather vals");
883        assert_eq!(vals.shape, vec![3, 1]);
884        assert_eq!(vals.data, vec![2.0, 3.0, 6.0]);
885    }
886
887    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
888    #[test]
889    fn find_last_order_descending() {
890        let tensor = Tensor::new(vec![1.0, 0.0, 2.0, 3.0, 0.0], vec![1, 5]).unwrap();
891        let result = find_builtin(
892            Value::Tensor(tensor),
893            vec![Value::Int(IntValue::I32(2)), Value::from("last")],
894        )
895        .expect("find");
896        match result {
897            Value::Tensor(t) => {
898                assert_eq!(t.shape, vec![2, 1]);
899                assert_eq!(t.data, vec![4.0, 3.0]);
900            }
901            Value::Num(_) => panic!("expected column vector"),
902            other => panic!("unexpected result {other:?}"),
903        }
904    }
905
906    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
907    #[test]
908    fn find_limit_zero_returns_empty() {
909        let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
910        let result = find_builtin(Value::Tensor(tensor), vec![Value::Num(0.0)]).expect("find");
911        match result {
912            Value::Tensor(t) => {
913                assert_eq!(t.shape, vec![0, 1]);
914                assert!(t.data.is_empty());
915            }
916            other => panic!("expected empty tensor, got {other:?}"),
917        }
918    }
919
920    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
921    #[test]
922    fn find_char_array_supports_nonzero_codes() {
923        let chars = CharArray::new(vec!['\0', 'A', '\0'], 1, 3).unwrap();
924        let result = find_builtin(Value::CharArray(chars), Vec::new()).expect("find");
925        match result {
926            Value::Num(n) => assert_eq!(n, 2.0),
927            Value::Tensor(t) => assert_eq!(t.data, vec![2.0]),
928            other => panic!("unexpected result {other:?}"),
929        }
930    }
931
932    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
933    #[test]
934    fn find_gpu_multi_outputs_return_gpu_handles() {
935        test_support::with_test_provider(|provider| {
936            let tensor = Tensor::new(vec![0.0, 4.0, 5.0, 0.0], vec![2, 2]).unwrap();
937            let view = HostTensorView {
938                data: &tensor.data,
939                shape: &tensor.shape,
940            };
941            let handle = provider.upload(&view).expect("upload");
942            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");
943
944            let rows = eval.row_value().expect("rows");
945            assert!(matches!(rows, Value::GpuTensor(_)));
946            let rows_host = test_support::gather(rows).expect("gather rows");
947            assert_eq!(rows_host.data, vec![2.0, 1.0]);
948
949            let cols = eval.column_value().expect("cols");
950            assert!(matches!(cols, Value::GpuTensor(_)));
951            let cols_host = test_support::gather(cols).expect("gather cols");
952            assert_eq!(cols_host.data, vec![1.0, 2.0]);
953
954            let vals = eval.values_value().expect("vals");
955            assert!(matches!(vals, Value::GpuTensor(_)));
956            let vals_host = test_support::gather(vals).expect("gather vals");
957            assert_eq!(vals_host.data, vec![4.0, 5.0]);
958        });
959    }
960
961    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
962    #[test]
963    #[cfg(feature = "wgpu")]
964    fn find_wgpu_matches_cpu() {
965        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
966            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
967        );
968        let tensor = Tensor::new(vec![0.0, 2.0, 0.0, 3.0, 4.0, 0.0], vec![3, 2]).unwrap();
969        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
970        let cpu_linear =
971            test_support::gather(cpu_eval.linear_value().expect("cpu linear")).expect("cpu gather");
972        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
973        let view = HostTensorView {
974            data: &tensor.data,
975            shape: &tensor.shape,
976        };
977        let handle = provider.upload(&view).expect("upload");
978        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");
979        let gpu_linear =
980            test_support::gather(gpu_eval.linear_value().expect("gpu linear")).expect("gpu gather");
981        assert_eq!(gpu_linear.data, cpu_linear.data);
982    }
983}