Skip to main content

runmat_runtime/builtins/array/indexing/
find.rs

1//! MATLAB-compatible `find` builtin with GPU-aware semantics for RunMat.
2
3use runmat_accelerate_api::{HostTensorView, ProviderFindResult};
4use runmat_builtins::{
5    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7    ComplexTensor, ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::array::type_resolvers::column_vector_type;
12use crate::builtins::common::arg_tokens::ArgToken;
13use crate::builtins::common::random_args::complex_tensor_into_value;
14use crate::builtins::common::spec::{
15    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
16    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
17};
18use crate::builtins::common::{gpu_helpers, tensor};
19use crate::{build_runtime_error, RuntimeError};
20
21#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::find")]
22pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
23    name: "find",
24    op_kind: GpuOpKind::Custom("find"),
25    supported_precisions: &[ScalarType::F32, ScalarType::F64],
26    broadcast: BroadcastSemantics::None,
27    provider_hooks: &[ProviderHook::Custom("find")],
28    constant_strategy: ConstantStrategy::InlineLiteral,
29    residency: ResidencyPolicy::NewHandle,
30    nan_mode: ReductionNaN::Include,
31    two_pass_threshold: None,
32    workgroup_size: None,
33    accepts_nan_mode: false,
34    notes: "WGPU provider executes find directly on device; other providers fall back to host and re-upload results to preserve residency.",
35};
36
37#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::find")]
38pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
39    name: "find",
40    shape: ShapeRequirements::Any,
41    constant_strategy: ConstantStrategy::InlineLiteral,
42    elementwise: None,
43    reduction: None,
44    emits_nan: false,
45    notes: "Find drives control flow and currently bypasses fusion; metadata is present for completeness only.",
46};
47
48fn find_type(_args: &[Type], _ctx: &ResolveContext) -> Type {
49    column_vector_type()
50}
51
52const BUILTIN_NAME: &str = "find";
53
54const FIND_OUTPUT_LINEAR: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
55    name: "idx",
56    ty: BuiltinParamType::NumericArray,
57    arity: BuiltinParamArity::Required,
58    default: None,
59    description: "Linear indices of non-zero elements.",
60}];
61
62const FIND_OUTPUT_ROW_COL: [BuiltinParamDescriptor; 2] = [
63    BuiltinParamDescriptor {
64        name: "row",
65        ty: BuiltinParamType::NumericArray,
66        arity: BuiltinParamArity::Required,
67        default: None,
68        description: "Row subscripts of non-zero elements.",
69    },
70    BuiltinParamDescriptor {
71        name: "col",
72        ty: BuiltinParamType::NumericArray,
73        arity: BuiltinParamArity::Required,
74        default: None,
75        description: "Column subscripts of non-zero elements.",
76    },
77];
78
79const FIND_OUTPUT_ROW_COL_VAL: [BuiltinParamDescriptor; 3] = [
80    BuiltinParamDescriptor {
81        name: "row",
82        ty: BuiltinParamType::NumericArray,
83        arity: BuiltinParamArity::Required,
84        default: None,
85        description: "Row subscripts of non-zero elements.",
86    },
87    BuiltinParamDescriptor {
88        name: "col",
89        ty: BuiltinParamType::NumericArray,
90        arity: BuiltinParamArity::Required,
91        default: None,
92        description: "Column subscripts of non-zero elements.",
93    },
94    BuiltinParamDescriptor {
95        name: "v",
96        ty: BuiltinParamType::Any,
97        arity: BuiltinParamArity::Required,
98        default: None,
99        description: "Values at the reported row/column locations.",
100    },
101];
102
103const FIND_INPUTS_BASE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
104    name: "X",
105    ty: BuiltinParamType::Any,
106    arity: BuiltinParamArity::Required,
107    default: None,
108    description: "Input array to search.",
109}];
110
111const FIND_INPUTS_LIMIT: [BuiltinParamDescriptor; 2] = [
112    BuiltinParamDescriptor {
113        name: "X",
114        ty: BuiltinParamType::Any,
115        arity: BuiltinParamArity::Required,
116        default: None,
117        description: "Input array to search.",
118    },
119    BuiltinParamDescriptor {
120        name: "K",
121        ty: BuiltinParamType::NumericScalar,
122        arity: BuiltinParamArity::Required,
123        default: None,
124        description: "Maximum number of indices to return.",
125    },
126];
127
128const FIND_INPUTS_LIMIT_DIR: [BuiltinParamDescriptor; 3] = [
129    BuiltinParamDescriptor {
130        name: "X",
131        ty: BuiltinParamType::Any,
132        arity: BuiltinParamArity::Required,
133        default: None,
134        description: "Input array to search.",
135    },
136    BuiltinParamDescriptor {
137        name: "K",
138        ty: BuiltinParamType::NumericScalar,
139        arity: BuiltinParamArity::Required,
140        default: None,
141        description: "Maximum number of indices to return.",
142    },
143    BuiltinParamDescriptor {
144        name: "direction",
145        ty: BuiltinParamType::StringScalar,
146        arity: BuiltinParamArity::Required,
147        default: Some("\"first\""),
148        description: "Direction selector: `\"first\"` or `\"last\"`.",
149    },
150];
151
152const FIND_SIGNATURES: [BuiltinSignatureDescriptor; 7] = [
153    BuiltinSignatureDescriptor {
154        label: "idx = find(X)",
155        inputs: &FIND_INPUTS_BASE,
156        outputs: &FIND_OUTPUT_LINEAR,
157    },
158    BuiltinSignatureDescriptor {
159        label: "idx = find(X, K)",
160        inputs: &FIND_INPUTS_LIMIT,
161        outputs: &FIND_OUTPUT_LINEAR,
162    },
163    BuiltinSignatureDescriptor {
164        label: "idx = find(X, K, direction)",
165        inputs: &FIND_INPUTS_LIMIT_DIR,
166        outputs: &FIND_OUTPUT_LINEAR,
167    },
168    BuiltinSignatureDescriptor {
169        label: "[row, col] = find(X)",
170        inputs: &FIND_INPUTS_BASE,
171        outputs: &FIND_OUTPUT_ROW_COL,
172    },
173    BuiltinSignatureDescriptor {
174        label: "[row, col] = find(X, K, direction)",
175        inputs: &FIND_INPUTS_LIMIT_DIR,
176        outputs: &FIND_OUTPUT_ROW_COL,
177    },
178    BuiltinSignatureDescriptor {
179        label: "[row, col, v] = find(X)",
180        inputs: &FIND_INPUTS_BASE,
181        outputs: &FIND_OUTPUT_ROW_COL_VAL,
182    },
183    BuiltinSignatureDescriptor {
184        label: "[row, col, v] = find(X, K, direction)",
185        inputs: &FIND_INPUTS_LIMIT_DIR,
186        outputs: &FIND_OUTPUT_ROW_COL_VAL,
187    },
188];
189
190const FIND_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
191    code: "RM.FIND.INVALID_INPUT",
192    identifier: Some("RunMat:find:InvalidInput"),
193    when: "Input type or option arguments are not valid for find.",
194    message: "find: invalid input arguments",
195};
196
197const FIND_ERROR_PROVIDER_OUTPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
198    code: "RM.FIND.PROVIDER_OUTPUT",
199    identifier: Some("RunMat:find:ProviderOutput"),
200    when: "GPU provider does not return expected output buffers for requested nargout.",
201    message: "find: provider output buffer mismatch",
202};
203
204const FIND_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
205    code: "RM.FIND.INTERNAL",
206    identifier: Some("RunMat:find:InternalError"),
207    when: "Internal tensor conversion/materialization fails while building outputs.",
208    message: "find: internal error",
209};
210
211const FIND_ERRORS: [BuiltinErrorDescriptor; 3] = [
212    FIND_ERROR_INVALID_INPUT,
213    FIND_ERROR_PROVIDER_OUTPUT,
214    FIND_ERROR_INTERNAL,
215];
216
217pub const FIND_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
218    signatures: &FIND_SIGNATURES,
219    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
220    completion_policy: BuiltinCompletionPolicy::Public,
221    errors: &FIND_ERRORS,
222};
223
224fn find_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
225    find_error_with_message(error.message, error)
226}
227
228fn find_error_with_message(
229    message: impl Into<String>,
230    error: &'static BuiltinErrorDescriptor,
231) -> RuntimeError {
232    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
233    if let Some(identifier) = error.identifier {
234        builder = builder.with_identifier(identifier);
235    }
236    builder.build()
237}
238
239fn parse_find_tokens(tokens: &[ArgToken]) -> crate::BuiltinResult<FindOptions> {
240    match tokens.len() {
241        0 => Ok(FindOptions::default()),
242        1 => {
243            if let Some(direction) = token_to_direction(&tokens[0])? {
244                let limit = if matches!(direction, FindDirection::Last) {
245                    Some(1)
246                } else {
247                    None
248                };
249                Ok(FindOptions { limit, direction })
250            } else {
251                let limit = token_to_limit(&tokens[0])?;
252                Ok(FindOptions {
253                    limit: Some(limit),
254                    direction: FindDirection::First,
255                })
256            }
257        }
258        2 => {
259            let limit = token_to_limit(&tokens[0])?;
260            let direction = token_to_direction(&tokens[1])?.ok_or_else(|| {
261                find_error_with_message(
262                    "find: third argument must be 'first' or 'last'",
263                    &FIND_ERROR_INVALID_INPUT,
264                )
265            })?;
266            Ok(FindOptions {
267                limit: Some(limit),
268                direction,
269            })
270        }
271        _ => Err(find_error_with_message(
272            "find: too many input arguments",
273            &FIND_ERROR_INVALID_INPUT,
274        )),
275    }
276}
277
278fn token_to_direction(token: &ArgToken) -> crate::BuiltinResult<Option<FindDirection>> {
279    match token {
280        ArgToken::String(text) => match text.as_str() {
281            "first" => Ok(Some(FindDirection::First)),
282            "last" => Ok(Some(FindDirection::Last)),
283            _ => Err(find_error_with_message(
284                "find: direction must be 'first' or 'last'",
285                &FIND_ERROR_INVALID_INPUT,
286            )),
287        },
288        _ => Ok(None),
289    }
290}
291
292fn token_to_limit(token: &ArgToken) -> crate::BuiltinResult<usize> {
293    match token {
294        ArgToken::Number(value) => parse_limit_scalar(*value),
295        _ => Err(find_error_with_message(
296            "find: second argument must be a scalar",
297            &FIND_ERROR_INVALID_INPUT,
298        )),
299    }
300}
301
302#[runtime_builtin(
303    name = "find",
304    category = "array/indexing",
305    summary = "Locate nonzero indices and values.",
306    keywords = "find,nonzero,indices,row,column,gpu",
307    accel = "custom",
308    type_resolver(find_type),
309    descriptor(crate::builtins::array::indexing::find::FIND_DESCRIPTOR),
310    builtin_path = "crate::builtins::array::indexing::find"
311)]
312async fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
313    let eval = evaluate(value, &rest).await?;
314    if let Some(out_count) = crate::output_count::current_output_count() {
315        if out_count == 0 {
316            return Ok(Value::OutputList(Vec::new()));
317        }
318        if out_count <= 1 {
319            let linear = eval.linear_value()?;
320            return Ok(crate::output_count::output_list_with_padding(
321                out_count,
322                vec![linear],
323            ));
324        }
325        let rows = eval.row_value()?;
326        let cols = eval.column_value()?;
327        let mut outputs = vec![rows, cols];
328        if out_count >= 3 {
329            outputs.push(eval.values_value()?);
330        }
331        return Ok(crate::output_count::output_list_with_padding(
332            out_count, outputs,
333        ));
334    }
335    eval.linear_value()
336}
337
338/// Evaluate `find` and return an object that can materialise the various outputs.
339pub async fn evaluate(value: Value, args: &[Value]) -> crate::BuiltinResult<FindEval> {
340    let options = parse_options(args).await?;
341    match value {
342        Value::GpuTensor(handle) => {
343            if let Some(result) = try_provider_find(&handle, &options) {
344                return Ok(FindEval::from_gpu(result));
345            }
346            let (storage, _) = materialize_input(Value::GpuTensor(handle)).await?;
347            let result = compute_find(&storage, &options);
348            Ok(FindEval::from_host(result, true))
349        }
350        Value::SparseTensor(sparse) => {
351            let result = compute_find_sparse(&sparse, &options);
352            Ok(FindEval::from_host(result, false))
353        }
354        other => {
355            let (storage, input_was_gpu) = materialize_input(other).await?;
356            let result = compute_find(&storage, &options);
357            Ok(FindEval::from_host(result, input_was_gpu))
358        }
359    }
360}
361
362fn try_provider_find(
363    handle: &runmat_accelerate_api::GpuTensorHandle,
364    options: &FindOptions,
365) -> Option<ProviderFindResult> {
366    #[cfg(all(test, feature = "wgpu"))]
367    {
368        if handle.device_id != 0 {
369            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
370                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
371            );
372        }
373    }
374    let provider = runmat_accelerate_api::provider()?;
375    let direction = match options.direction {
376        FindDirection::First => runmat_accelerate_api::FindDirection::First,
377        FindDirection::Last => runmat_accelerate_api::FindDirection::Last,
378    };
379    let limit = options.effective_limit();
380    provider.find(handle, limit, direction).ok()
381}
382
/// Which end of the column-major element order `find` scans from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum FindDirection {
    /// Scan from the first element (the default).
    #[default]
    First,
    /// Scan from the last element.
    Last,
}

/// Parsed `find` options: an optional match limit `K` plus a scan direction.
///
/// `Default` means "no limit, scan from the start".
#[derive(Debug, Clone, Default)]
struct FindOptions {
    limit: Option<usize>,
    direction: FindDirection,
}

impl FindOptions {
    /// The limit actually applied: an explicit `K` always wins. Otherwise
    /// `"last"` implies a single match and `"first"` is unbounded.
    fn effective_limit(&self) -> Option<usize> {
        self.limit
            .or_else(|| (self.direction == FindDirection::Last).then_some(1))
    }
}
412
413#[derive(Clone)]
414enum DataStorage {
415    Real(Tensor),
416    Complex(ComplexTensor),
417}
418
419impl DataStorage {
420    fn shape(&self) -> &[usize] {
421        match self {
422            DataStorage::Real(t) => &t.shape,
423            DataStorage::Complex(t) => &t.shape,
424        }
425    }
426}
427
428#[derive(Clone)]
429struct FindResult {
430    shape: Vec<usize>,
431    indices: Vec<usize>,
432    values: FindValues,
433}
434
435#[derive(Clone)]
436enum FindValues {
437    Real(Vec<f64>),
438    Complex(Vec<(f64, f64)>),
439}
440
441pub struct FindEval {
442    inner: FindEvalInner,
443}
444
445enum FindEvalInner {
446    Host {
447        result: FindResult,
448        prefer_gpu: bool,
449    },
450    Gpu {
451        result: ProviderFindResult,
452    },
453}
454
455impl FindEval {
456    fn from_host(result: FindResult, prefer_gpu: bool) -> Self {
457        Self {
458            inner: FindEvalInner::Host { result, prefer_gpu },
459        }
460    }
461
462    fn from_gpu(result: ProviderFindResult) -> Self {
463        Self {
464            inner: FindEvalInner::Gpu { result },
465        }
466    }
467
468    pub fn linear_value(&self) -> crate::BuiltinResult<Value> {
469        match &self.inner {
470            FindEvalInner::Host { result, prefer_gpu } => {
471                let tensor = result.linear_tensor()?;
472                Ok(tensor_to_value(tensor, *prefer_gpu))
473            }
474            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.linear.clone())),
475        }
476    }
477
478    pub fn row_value(&self) -> crate::BuiltinResult<Value> {
479        match &self.inner {
480            FindEvalInner::Host { result, prefer_gpu } => {
481                let tensor = result.row_tensor()?;
482                Ok(tensor_to_value(tensor, *prefer_gpu))
483            }
484            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.rows.clone())),
485        }
486    }
487
488    pub fn column_value(&self) -> crate::BuiltinResult<Value> {
489        match &self.inner {
490            FindEvalInner::Host { result, prefer_gpu } => {
491                let tensor = result.column_tensor()?;
492                Ok(tensor_to_value(tensor, *prefer_gpu))
493            }
494            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.cols.clone())),
495        }
496    }
497
498    pub fn values_value(&self) -> crate::BuiltinResult<Value> {
499        match &self.inner {
500            FindEvalInner::Host { result, prefer_gpu } => result.values_value(*prefer_gpu),
501            FindEvalInner::Gpu { result } => result
502                .values
503                .as_ref()
504                .map(|handle| Value::GpuTensor(handle.clone()))
505                .ok_or_else(|| find_error(&FIND_ERROR_PROVIDER_OUTPUT)),
506        }
507    }
508}
509
510async fn parse_options(args: &[Value]) -> crate::BuiltinResult<FindOptions> {
511    parse_find_tokens(&crate::builtins::common::arg_tokens::tokens_from_values(
512        args,
513    ))
514}
515
516fn parse_limit_scalar(value: f64) -> crate::BuiltinResult<usize> {
517    if !value.is_finite() {
518        return Err(find_error_with_message(
519            "find: K must be a finite, non-negative integer",
520            &FIND_ERROR_INVALID_INPUT,
521        ));
522    }
523    let rounded = value.round();
524    if (rounded - value).abs() > f64::EPSILON {
525        return Err(find_error_with_message(
526            "find: K must be a finite, non-negative integer",
527            &FIND_ERROR_INVALID_INPUT,
528        ));
529    }
530    if rounded < 0.0 {
531        return Err(find_error_with_message(
532            "find: K must be >= 0",
533            &FIND_ERROR_INVALID_INPUT,
534        ));
535    }
536    Ok(rounded as usize)
537}
538
539async fn materialize_input(value: Value) -> crate::BuiltinResult<(DataStorage, bool)> {
540    match value {
541        Value::GpuTensor(handle) => {
542            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
543            Ok((DataStorage::Real(tensor), true))
544        }
545        Value::Tensor(tensor) => Ok((DataStorage::Real(tensor), false)),
546        Value::SparseTensor(sparse) => Ok((
547            DataStorage::Real(sparse.to_dense().map_err(|e| {
548                find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL)
549            })?),
550            false,
551        )),
552        Value::LogicalArray(logical) => {
553            let tensor = tensor::logical_to_tensor(&logical)
554                .map_err(|message| find_error_with_message(message, &FIND_ERROR_INTERNAL))?;
555            Ok((DataStorage::Real(tensor), false))
556        }
557        Value::Num(n) => {
558            let tensor = Tensor::new(vec![n], vec![1, 1])
559                .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
560            Ok((DataStorage::Real(tensor), false))
561        }
562        Value::Int(i) => {
563            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
564                .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
565            Ok((DataStorage::Real(tensor), false))
566        }
567        Value::Bool(b) => {
568            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
569                .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
570            Ok((DataStorage::Real(tensor), false))
571        }
572        Value::Complex(re, im) => {
573            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
574                .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
575            Ok((DataStorage::Complex(tensor), false))
576        }
577        Value::ComplexTensor(tensor) => Ok((DataStorage::Complex(tensor), false)),
578        Value::CharArray(chars) => {
579            let mut data = Vec::with_capacity(chars.data.len());
580            for c in 0..chars.cols {
581                for r in 0..chars.rows {
582                    let ch = chars.data[r * chars.cols + c] as u32;
583                    data.push(ch as f64);
584                }
585            }
586            let tensor = Tensor::new(data, vec![chars.rows, chars.cols])
587                .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
588            Ok((DataStorage::Real(tensor), false))
589        }
590        other => Err(find_error_with_message(
591            format!(
592                "find: unsupported input type {:?}; expected numeric, logical, or char data",
593                other
594            ),
595            &FIND_ERROR_INVALID_INPUT,
596        )),
597    }
598}
599
600fn compute_find(storage: &DataStorage, options: &FindOptions) -> FindResult {
601    let shape = storage.shape().to_vec();
602    let limit = options.effective_limit();
603
604    match storage {
605        DataStorage::Real(tensor) => {
606            let mut indices = Vec::new();
607            let mut values = Vec::new();
608
609            if matches!(limit, Some(0)) {
610                return FindResult::new(shape, indices, FindValues::Real(values));
611            }
612
613            let len = tensor.data.len();
614            match options.direction {
615                FindDirection::First => {
616                    for idx in 0..len {
617                        let value = tensor.data[idx];
618                        if value != 0.0 {
619                            indices.push(idx + 1);
620                            values.push(value);
621                            if limit.is_some_and(|k| indices.len() >= k) {
622                                break;
623                            }
624                        }
625                    }
626                }
627                FindDirection::Last => {
628                    for idx in (0..len).rev() {
629                        let value = tensor.data[idx];
630                        if value != 0.0 {
631                            indices.push(idx + 1);
632                            values.push(value);
633                            if limit.is_some_and(|k| indices.len() >= k) {
634                                break;
635                            }
636                        }
637                    }
638                }
639            }
640
641            FindResult::new(shape, indices, FindValues::Real(values))
642        }
643        DataStorage::Complex(tensor) => {
644            let mut indices = Vec::new();
645            let mut values = Vec::new();
646
647            if matches!(limit, Some(0)) {
648                return FindResult::new(shape, indices, FindValues::Complex(values));
649            }
650
651            let len = tensor.data.len();
652            match options.direction {
653                FindDirection::First => {
654                    for idx in 0..len {
655                        let (re, im) = tensor.data[idx];
656                        if re != 0.0 || im != 0.0 {
657                            indices.push(idx + 1);
658                            values.push((re, im));
659                            if limit.is_some_and(|k| indices.len() >= k) {
660                                break;
661                            }
662                        }
663                    }
664                }
665                FindDirection::Last => {
666                    for idx in (0..len).rev() {
667                        let (re, im) = tensor.data[idx];
668                        if re != 0.0 || im != 0.0 {
669                            indices.push(idx + 1);
670                            values.push((re, im));
671                            if limit.is_some_and(|k| indices.len() >= k) {
672                                break;
673                            }
674                        }
675                    }
676                }
677            }
678
679            FindResult::new(shape, indices, FindValues::Complex(values))
680        }
681    }
682}
683
684fn compute_find_sparse(
685    sparse: &runmat_builtins::SparseTensor,
686    options: &FindOptions,
687) -> FindResult {
688    let shape = vec![sparse.rows, sparse.cols];
689    let limit = options.effective_limit();
690
691    let mut indices = Vec::new();
692    let mut values = Vec::new();
693
694    if matches!(limit, Some(0)) {
695        return FindResult::new(shape, indices, FindValues::Real(values));
696    }
697
698    match options.direction {
699        FindDirection::First => {
700            for col in 0..sparse.cols {
701                let col_start = sparse.col_ptrs[col];
702                let col_end = sparse.col_ptrs[col + 1];
703                for idx in col_start..col_end {
704                    let row = sparse.row_indices[idx];
705                    let value = sparse.values[idx];
706                    if value != 0.0 {
707                        let linear_idx = row + col * sparse.rows;
708                        indices.push(linear_idx + 1);
709                        values.push(value);
710                        if limit.is_some_and(|k| indices.len() >= k) {
711                            return FindResult::new(shape, indices, FindValues::Real(values));
712                        }
713                    }
714                }
715            }
716        }
717        FindDirection::Last => {
718            for col in (0..sparse.cols).rev() {
719                let col_start = sparse.col_ptrs[col];
720                let col_end = sparse.col_ptrs[col + 1];
721                for idx in (col_start..col_end).rev() {
722                    let row = sparse.row_indices[idx];
723                    let value = sparse.values[idx];
724                    if value != 0.0 {
725                        let linear_idx = row + col * sparse.rows;
726                        indices.push(linear_idx + 1);
727                        values.push(value);
728                        if limit.is_some_and(|k| indices.len() >= k) {
729                            return FindResult::new(shape, indices, FindValues::Real(values));
730                        }
731                    }
732                }
733            }
734        }
735    }
736
737    FindResult::new(shape, indices, FindValues::Real(values))
738}
739
740impl FindResult {
741    fn new(shape: Vec<usize>, indices: Vec<usize>, values: FindValues) -> Self {
742        Self {
743            shape,
744            indices,
745            values,
746        }
747    }
748
749    fn linear_tensor(&self) -> crate::BuiltinResult<Tensor> {
750        let data: Vec<f64> = self.indices.iter().map(|&idx| idx as f64).collect();
751        let rows = data.len();
752        Tensor::new(data, vec![rows, 1])
753            .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
754    }
755
756    fn row_tensor(&self) -> crate::BuiltinResult<Tensor> {
757        let mut data = Vec::with_capacity(self.indices.len());
758        let rows = self.shape.first().copied().unwrap_or(1).max(1);
759        for &idx in &self.indices {
760            let zero_based = idx - 1;
761            let row = (zero_based % rows) + 1;
762            data.push(row as f64);
763        }
764        Tensor::new(data, vec![self.indices.len(), 1])
765            .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
766    }
767
768    fn column_tensor(&self) -> crate::BuiltinResult<Tensor> {
769        let mut data = Vec::with_capacity(self.indices.len());
770        let rows = self.shape.first().copied().unwrap_or(1).max(1);
771        for &idx in &self.indices {
772            let zero_based = idx - 1;
773            let col = (zero_based / rows) + 1;
774            data.push(col as f64);
775        }
776        Tensor::new(data, vec![self.indices.len(), 1])
777            .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
778    }
779
780    fn values_value(&self, prefer_gpu: bool) -> crate::BuiltinResult<Value> {
781        match &self.values {
782            FindValues::Real(values) => {
783                let tensor = Tensor::new(values.clone(), vec![values.len(), 1]).map_err(|e| {
784                    find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL)
785                })?;
786                Ok(tensor_to_value(tensor, prefer_gpu))
787            }
788            FindValues::Complex(values) => {
789                let tensor =
790                    ComplexTensor::new(values.clone(), vec![values.len(), 1]).map_err(|e| {
791                        find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL)
792                    })?;
793                Ok(complex_tensor_into_value(tensor))
794            }
795        }
796    }
797}
798
799fn tensor_to_value(tensor: Tensor, prefer_gpu: bool) -> Value {
800    if prefer_gpu {
801        if let Some(provider) = runmat_accelerate_api::provider() {
802            let view = HostTensorView {
803                data: &tensor.data,
804                shape: &tensor.shape,
805            };
806            if let Ok(handle) = provider.upload(&view) {
807                return Value::GpuTensor(handle);
808            }
809        }
810    }
811    tensor::tensor_into_value(tensor)
812}
813
/// Unit tests for the `find` builtin.
///
/// Coverage includes:
/// - host tensors, complex tensors and char arrays;
/// - the count and direction arguments;
/// - the multi-output (`rows`/`cols`/`values`) evaluation path;
/// - GPU-resident inputs via the shared test provider (and, behind the `wgpu` feature, the
///   real WGPU provider).
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{CharArray, IntValue, Type};

    /// Synchronous wrapper around the async `super::find_builtin` so tests can call it directly.
    fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::find_builtin(value, rest))
    }

    /// Synchronous wrapper around the async `super::evaluate`, which exposes the multi-output
    /// (`FindEval`) view of a `find` call.
    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<FindEval> {
        block_on(super::evaluate(value, rest))
    }

    /// Single-output `find` returns 1-based linear indices as a column vector.
    ///
    /// In column-major order the data `[0, 4, 0, 7, 0, 9]` has nonzeros at positions 2, 4 and 6.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_linear_indices_basic() {
        let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0, 0.0, 9.0], vec![2, 3]).unwrap();
        let value = find_builtin(Value::Tensor(tensor), Vec::new()).expect("find");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 1]);
                assert_eq!(t.data, vec![2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// The type resolver reports an `N x 1` column vector with an unknown row count.
    // Carries the wasm attribute like every other test in this module so the wasm suite also
    // covers the type resolver.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_type_is_column_vector() {
        assert_eq!(
            find_type(
                &[Type::Tensor { shape: None }],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Tensor {
                shape: Some(vec![None, Some(1)])
            }
        );
    }

    /// A count argument of 2 keeps only the first two nonzero indices (2 and 3).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_limited_first() {
        let tensor = Tensor::new(vec![0.0, 3.0, 5.0, 0.0, 8.0], vec![1, 5]).unwrap();
        let result =
            find_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(2))]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 3.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// Passing only the direction `"last"` (no count) yields just the final nonzero index.
    ///
    /// Nonzeros sit at positions 1, 4 and 6; the expected result is 6. The result may come back
    /// either as a scalar or as a one-element tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_last_single() {
        let tensor = Tensor::new(vec![1.0, 0.0, 0.0, 6.0, 0.0, 2.0], vec![1, 6]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), vec![Value::from("last")]).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 6.0),
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![6.0]);
            }
            other => panic!("unexpected result {other:?}"),
        }
    }

    /// For complex input, the values output contains only the nonzero element `1 + 2i`.
    ///
    /// It may be returned as a complex scalar or as a 1x1 complex tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_complex_values() {
        let tensor =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0), (0.0, 0.0)], vec![3, 1]).unwrap();
        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("find compute");
        let values = eval.values_value().expect("values");
        match values {
            Value::Complex(re, im) => {
                assert_eq!(re, 1.0);
                assert_eq!(im, 2.0);
            }
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 1]);
                assert_eq!(ct.data, vec![(1.0, 2.0)]);
            }
            other => panic!("expected complex result, got {other:?}"),
        }
    }

    /// `find` on a GPU handle from the test provider matches the host result once gathered.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = find_builtin(Value::GpuTensor(handle), Vec::new()).expect("find");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![2.0, 4.0]);
        });
    }

    /// An unrecognised direction string is rejected.
    ///
    /// The error message must mention "direction" and the error must carry the
    /// invalid-input identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_direction_error() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::from("invalid")],
        )
        .expect_err("expected error");
        assert!(err.to_string().contains("direction"));
        assert_eq!(err.identifier(), super::FIND_ERROR_INVALID_INPUT.identifier);
    }

    /// Multi-output evaluation returns matching row, column and value column vectors.
    ///
    /// The 2x3 data `[0, 2, 3, 0, 0, 6]` (column-major) has nonzeros at (2,1)=2, (1,2)=3 and
    /// (2,3)=6.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_multi_output_rows_cols_values() {
        let tensor = Tensor::new(vec![0.0, 2.0, 3.0, 0.0, 0.0, 6.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");

        let rows = test_support::gather(eval.row_value().expect("rows")).expect("gather rows");
        assert_eq!(rows.shape, vec![3, 1]);
        assert_eq!(rows.data, vec![2.0, 1.0, 2.0]);

        let cols = test_support::gather(eval.column_value().expect("cols")).expect("gather cols");
        assert_eq!(cols.shape, vec![3, 1]);
        assert_eq!(cols.data, vec![1.0, 2.0, 3.0]);

        let vals = test_support::gather(eval.values_value().expect("vals")).expect("gather vals");
        assert_eq!(vals.shape, vec![3, 1]);
        assert_eq!(vals.data, vec![2.0, 3.0, 6.0]);
    }

    /// `find(X, 2, "last")` on `[1, 0, 2, 3, 0]` (nonzeros at 1, 3, 4) returns the last two
    /// indices as a 2x1 column.
    // NOTE(review): this pins descending order ([4; 3]). MATLAB/Octave appear to return the
    // last n indices in ascending order ([3; 4]) — confirm the descending order is intended.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_last_order_descending() {
        let tensor = Tensor::new(vec![1.0, 0.0, 2.0, 3.0, 0.0], vec![1, 5]).unwrap();
        let result = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(2)), Value::from("last")],
        )
        .expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                assert_eq!(t.data, vec![4.0, 3.0]);
            }
            Value::Num(_) => panic!("expected column vector"),
            other => panic!("unexpected result {other:?}"),
        }
    }

    /// A count of zero produces an empty 0x1 tensor rather than an error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_limit_zero_returns_empty() {
        let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), vec![Value::Num(0.0)]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    /// Char arrays are searched by character code; only `'A'` (position 2) is nonzero.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_char_array_supports_nonzero_codes() {
        let chars = CharArray::new(vec!['\0', 'A', '\0'], 1, 3).unwrap();
        let result = find_builtin(Value::CharArray(chars), Vec::new()).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 2.0),
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0]),
            other => panic!("unexpected result {other:?}"),
        }
    }

    /// For a GPU input, the row, column and value outputs all stay device-resident.
    ///
    /// Each output must be a `Value::GpuTensor`, and each must gather to the expected host data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_gpu_multi_outputs_return_gpu_handles() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 4.0, 5.0, 0.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");

            let rows = eval.row_value().expect("rows");
            assert!(matches!(rows, Value::GpuTensor(_)));
            let rows_host = test_support::gather(rows).expect("gather rows");
            assert_eq!(rows_host.data, vec![2.0, 1.0]);

            let cols = eval.column_value().expect("cols");
            assert!(matches!(cols, Value::GpuTensor(_)));
            let cols_host = test_support::gather(cols).expect("gather cols");
            assert_eq!(cols_host.data, vec![1.0, 2.0]);

            let vals = eval.values_value().expect("vals");
            assert!(matches!(vals, Value::GpuTensor(_)));
            let vals_host = test_support::gather(vals).expect("gather vals");
            assert_eq!(vals_host.data, vec![4.0, 5.0]);
        });
    }

    /// With the `wgpu` feature, the WGPU provider's linear indices match the host computation.
    ///
    /// The registration result is intentionally ignored. The test then requires that some
    /// provider is registered.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn find_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![0.0, 2.0, 0.0, 3.0, 4.0, 0.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let cpu_linear =
            test_support::gather(cpu_eval.linear_value().expect("cpu linear")).expect("cpu gather");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");
        let gpu_linear =
            test_support::gather(gpu_eval.linear_value().expect("gpu linear")).expect("gpu gather");
        assert_eq!(gpu_linear.data, cpu_linear.data);
    }
}