1use runmat_accelerate_api::{HostTensorView, ProviderFindResult};
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 ComplexTensor, ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::array::type_resolvers::column_vector_type;
12use crate::builtins::common::arg_tokens::ArgToken;
13use crate::builtins::common::random_args::complex_tensor_into_value;
14use crate::builtins::common::spec::{
15 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
16 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
17};
18use crate::builtins::common::{gpu_helpers, tensor};
19use crate::{build_runtime_error, RuntimeError};
20
// GPU metadata for the `find` builtin, registered against this module's builtin path.
// The operation is declared as a custom kernel with a custom provider hook, supports
// F32 and F64 inputs, does not broadcast, and yields new device handles for its outputs.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::find")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "find",
    op_kind: GpuOpKind::Custom("find"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("find")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "WGPU provider executes find directly on device; other providers fall back to host and re-upload results to preserve residency.",
};
36
// Fusion metadata for the `find` builtin. No elementwise or reduction template is
// provided; per the `notes` field the entry exists for completeness only.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::find")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "find",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Find drives control flow and currently bypasses fusion; metadata is present for completeness only.",
};
47
48fn find_type(_args: &[Type], _ctx: &ResolveContext) -> Type {
49 column_vector_type()
50}
51
/// Builtin name attached to every runtime error raised from this module.
const BUILTIN_NAME: &str = "find";
53
54const FIND_OUTPUT_LINEAR: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
55 name: "idx",
56 ty: BuiltinParamType::NumericArray,
57 arity: BuiltinParamArity::Required,
58 default: None,
59 description: "Linear indices of non-zero elements.",
60}];
61
62const FIND_OUTPUT_ROW_COL: [BuiltinParamDescriptor; 2] = [
63 BuiltinParamDescriptor {
64 name: "row",
65 ty: BuiltinParamType::NumericArray,
66 arity: BuiltinParamArity::Required,
67 default: None,
68 description: "Row subscripts of non-zero elements.",
69 },
70 BuiltinParamDescriptor {
71 name: "col",
72 ty: BuiltinParamType::NumericArray,
73 arity: BuiltinParamArity::Required,
74 default: None,
75 description: "Column subscripts of non-zero elements.",
76 },
77];
78
79const FIND_OUTPUT_ROW_COL_VAL: [BuiltinParamDescriptor; 3] = [
80 BuiltinParamDescriptor {
81 name: "row",
82 ty: BuiltinParamType::NumericArray,
83 arity: BuiltinParamArity::Required,
84 default: None,
85 description: "Row subscripts of non-zero elements.",
86 },
87 BuiltinParamDescriptor {
88 name: "col",
89 ty: BuiltinParamType::NumericArray,
90 arity: BuiltinParamArity::Required,
91 default: None,
92 description: "Column subscripts of non-zero elements.",
93 },
94 BuiltinParamDescriptor {
95 name: "v",
96 ty: BuiltinParamType::Any,
97 arity: BuiltinParamArity::Required,
98 default: None,
99 description: "Values at the reported row/column locations.",
100 },
101];
102
103const FIND_INPUTS_BASE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
104 name: "X",
105 ty: BuiltinParamType::Any,
106 arity: BuiltinParamArity::Required,
107 default: None,
108 description: "Input array to search.",
109}];
110
111const FIND_INPUTS_LIMIT: [BuiltinParamDescriptor; 2] = [
112 BuiltinParamDescriptor {
113 name: "X",
114 ty: BuiltinParamType::Any,
115 arity: BuiltinParamArity::Required,
116 default: None,
117 description: "Input array to search.",
118 },
119 BuiltinParamDescriptor {
120 name: "K",
121 ty: BuiltinParamType::NumericScalar,
122 arity: BuiltinParamArity::Required,
123 default: None,
124 description: "Maximum number of indices to return.",
125 },
126];
127
128const FIND_INPUTS_LIMIT_DIR: [BuiltinParamDescriptor; 3] = [
129 BuiltinParamDescriptor {
130 name: "X",
131 ty: BuiltinParamType::Any,
132 arity: BuiltinParamArity::Required,
133 default: None,
134 description: "Input array to search.",
135 },
136 BuiltinParamDescriptor {
137 name: "K",
138 ty: BuiltinParamType::NumericScalar,
139 arity: BuiltinParamArity::Required,
140 default: None,
141 description: "Maximum number of indices to return.",
142 },
143 BuiltinParamDescriptor {
144 name: "direction",
145 ty: BuiltinParamType::StringScalar,
146 arity: BuiltinParamArity::Required,
147 default: Some("\"first\""),
148 description: "Direction selector: `\"first\"` or `\"last\"`.",
149 },
150];
151
/// Documented call forms of `find`, pairing each label with its input and output
/// descriptor tables. The first three forms return linear indices, the next two
/// return row/column subscripts, and the last two additionally return values.
const FIND_SIGNATURES: [BuiltinSignatureDescriptor; 7] = [
    BuiltinSignatureDescriptor {
        label: "idx = find(X)",
        inputs: &FIND_INPUTS_BASE,
        outputs: &FIND_OUTPUT_LINEAR,
    },
    BuiltinSignatureDescriptor {
        label: "idx = find(X, K)",
        inputs: &FIND_INPUTS_LIMIT,
        outputs: &FIND_OUTPUT_LINEAR,
    },
    BuiltinSignatureDescriptor {
        label: "idx = find(X, K, direction)",
        inputs: &FIND_INPUTS_LIMIT_DIR,
        outputs: &FIND_OUTPUT_LINEAR,
    },
    BuiltinSignatureDescriptor {
        label: "[row, col] = find(X)",
        inputs: &FIND_INPUTS_BASE,
        outputs: &FIND_OUTPUT_ROW_COL,
    },
    BuiltinSignatureDescriptor {
        label: "[row, col] = find(X, K, direction)",
        inputs: &FIND_INPUTS_LIMIT_DIR,
        outputs: &FIND_OUTPUT_ROW_COL,
    },
    BuiltinSignatureDescriptor {
        label: "[row, col, v] = find(X)",
        inputs: &FIND_INPUTS_BASE,
        outputs: &FIND_OUTPUT_ROW_COL_VAL,
    },
    BuiltinSignatureDescriptor {
        label: "[row, col, v] = find(X, K, direction)",
        inputs: &FIND_INPUTS_LIMIT_DIR,
        outputs: &FIND_OUTPUT_ROW_COL_VAL,
    },
];
189
/// Raised for unsupported input types and malformed `K` / direction arguments.
const FIND_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FIND.INVALID_INPUT",
    identifier: Some("RunMat:find:InvalidInput"),
    when: "Input type or option arguments are not valid for find.",
    message: "find: invalid input arguments",
};

/// Raised when a provider result lacks the values buffer that was requested.
const FIND_ERROR_PROVIDER_OUTPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FIND.PROVIDER_OUTPUT",
    identifier: Some("RunMat:find:ProviderOutput"),
    when: "GPU provider does not return expected output buffers for requested nargout.",
    message: "find: provider output buffer mismatch",
};

/// Raised when building or converting a tensor fails inside this module.
const FIND_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FIND.INTERNAL",
    identifier: Some("RunMat:find:InternalError"),
    when: "Internal tensor conversion/materialization fails while building outputs.",
    message: "find: internal error",
};

/// Every error descriptor exposed through `FIND_DESCRIPTOR`.
const FIND_ERRORS: [BuiltinErrorDescriptor; 3] = [
    FIND_ERROR_INVALID_INPUT,
    FIND_ERROR_PROVIDER_OUTPUT,
    FIND_ERROR_INTERNAL,
];
216
/// Public descriptor for `find`: its signatures, its error table, and the fact that
/// the outputs produced depend on the requested output count.
pub const FIND_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &FIND_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &FIND_ERRORS,
};
223
224fn find_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
225 find_error_with_message(error.message, error)
226}
227
228fn find_error_with_message(
229 message: impl Into<String>,
230 error: &'static BuiltinErrorDescriptor,
231) -> RuntimeError {
232 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
233 if let Some(identifier) = error.identifier {
234 builder = builder.with_identifier(identifier);
235 }
236 builder.build()
237}
238
239fn parse_find_tokens(tokens: &[ArgToken]) -> crate::BuiltinResult<FindOptions> {
240 match tokens.len() {
241 0 => Ok(FindOptions::default()),
242 1 => {
243 if let Some(direction) = token_to_direction(&tokens[0])? {
244 let limit = if matches!(direction, FindDirection::Last) {
245 Some(1)
246 } else {
247 None
248 };
249 Ok(FindOptions { limit, direction })
250 } else {
251 let limit = token_to_limit(&tokens[0])?;
252 Ok(FindOptions {
253 limit: Some(limit),
254 direction: FindDirection::First,
255 })
256 }
257 }
258 2 => {
259 let limit = token_to_limit(&tokens[0])?;
260 let direction = token_to_direction(&tokens[1])?.ok_or_else(|| {
261 find_error_with_message(
262 "find: third argument must be 'first' or 'last'",
263 &FIND_ERROR_INVALID_INPUT,
264 )
265 })?;
266 Ok(FindOptions {
267 limit: Some(limit),
268 direction,
269 })
270 }
271 _ => Err(find_error_with_message(
272 "find: too many input arguments",
273 &FIND_ERROR_INVALID_INPUT,
274 )),
275 }
276}
277
278fn token_to_direction(token: &ArgToken) -> crate::BuiltinResult<Option<FindDirection>> {
279 match token {
280 ArgToken::String(text) => match text.as_str() {
281 "first" => Ok(Some(FindDirection::First)),
282 "last" => Ok(Some(FindDirection::Last)),
283 _ => Err(find_error_with_message(
284 "find: direction must be 'first' or 'last'",
285 &FIND_ERROR_INVALID_INPUT,
286 )),
287 },
288 _ => Ok(None),
289 }
290}
291
292fn token_to_limit(token: &ArgToken) -> crate::BuiltinResult<usize> {
293 match token {
294 ArgToken::Number(value) => parse_limit_scalar(*value),
295 _ => Err(find_error_with_message(
296 "find: second argument must be a scalar",
297 &FIND_ERROR_INVALID_INPUT,
298 )),
299 }
300}
301
302#[runtime_builtin(
303 name = "find",
304 category = "array/indexing",
305 summary = "Locate nonzero indices and values.",
306 keywords = "find,nonzero,indices,row,column,gpu",
307 accel = "custom",
308 type_resolver(find_type),
309 descriptor(crate::builtins::array::indexing::find::FIND_DESCRIPTOR),
310 builtin_path = "crate::builtins::array::indexing::find"
311)]
312async fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
313 let eval = evaluate(value, &rest).await?;
314 if let Some(out_count) = crate::output_count::current_output_count() {
315 if out_count == 0 {
316 return Ok(Value::OutputList(Vec::new()));
317 }
318 if out_count <= 1 {
319 let linear = eval.linear_value()?;
320 return Ok(crate::output_count::output_list_with_padding(
321 out_count,
322 vec![linear],
323 ));
324 }
325 let rows = eval.row_value()?;
326 let cols = eval.column_value()?;
327 let mut outputs = vec![rows, cols];
328 if out_count >= 3 {
329 outputs.push(eval.values_value()?);
330 }
331 return Ok(crate::output_count::output_list_with_padding(
332 out_count, outputs,
333 ));
334 }
335 eval.linear_value()
336}
337
338pub async fn evaluate(value: Value, args: &[Value]) -> crate::BuiltinResult<FindEval> {
340 let options = parse_options(args).await?;
341 match value {
342 Value::GpuTensor(handle) => {
343 if let Some(result) = try_provider_find(&handle, &options) {
344 return Ok(FindEval::from_gpu(result));
345 }
346 let (storage, _) = materialize_input(Value::GpuTensor(handle)).await?;
347 let result = compute_find(&storage, &options);
348 Ok(FindEval::from_host(result, true))
349 }
350 Value::SparseTensor(sparse) => {
351 let result = compute_find_sparse(&sparse, &options);
352 Ok(FindEval::from_host(result, false))
353 }
354 other => {
355 let (storage, input_was_gpu) = materialize_input(other).await?;
356 let result = compute_find(&storage, &options);
357 Ok(FindEval::from_host(result, input_was_gpu))
358 }
359 }
360}
361
362fn try_provider_find(
363 handle: &runmat_accelerate_api::GpuTensorHandle,
364 options: &FindOptions,
365) -> Option<ProviderFindResult> {
366 #[cfg(all(test, feature = "wgpu"))]
367 {
368 if handle.device_id != 0 {
369 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
370 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
371 );
372 }
373 }
374 let provider = runmat_accelerate_api::provider()?;
375 let direction = match options.direction {
376 FindDirection::First => runmat_accelerate_api::FindDirection::First,
377 FindDirection::Last => runmat_accelerate_api::FindDirection::Last,
378 };
379 let limit = options.effective_limit();
380 provider.find(handle, limit, direction).ok()
381}
382
/// Which end of the input the search starts from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FindDirection {
    /// Scan from the first element forward.
    First,
    /// Scan from the last element backward.
    Last,
}
388
/// Parsed optional arguments of `find`.
#[derive(Debug, Clone)]
struct FindOptions {
    /// Maximum number of matches requested by the caller, if any.
    limit: Option<usize>,
    /// End of the input from which matches are collected.
    direction: FindDirection,
}
394
395impl Default for FindOptions {
396 fn default() -> Self {
397 Self {
398 limit: None,
399 direction: FindDirection::First,
400 }
401 }
402}
403
404impl FindOptions {
405 fn effective_limit(&self) -> Option<usize> {
406 match self.direction {
407 FindDirection::Last => self.limit.or(Some(1)),
408 FindDirection::First => self.limit,
409 }
410 }
411}
412
/// Dense host-side input to the search, either real or complex.
#[derive(Clone)]
enum DataStorage {
    /// Real-valued data.
    Real(Tensor),
    /// Complex-valued data.
    Complex(ComplexTensor),
}
418
419impl DataStorage {
420 fn shape(&self) -> &[usize] {
421 match self {
422 DataStorage::Real(t) => &t.shape,
423 DataStorage::Complex(t) => &t.shape,
424 }
425 }
426}
427
/// Host-side result of a search.
#[derive(Clone)]
struct FindResult {
    /// Shape of the searched input, used to derive row/column subscripts.
    shape: Vec<usize>,
    /// One-based linear indices of the matches, in the order they were found.
    indices: Vec<usize>,
    /// Values at the matched positions, parallel to `indices`.
    values: FindValues,
}
434
/// Matched values, stored according to the input's numeric kind.
#[derive(Clone)]
enum FindValues {
    /// Values from a real-valued input.
    Real(Vec<f64>),
    /// `(re, im)` pairs from a complex input.
    Complex(Vec<(f64, f64)>),
}
440
/// Evaluated `find` call from which individual outputs (linear indices, rows,
/// columns, values) are produced on request.
pub struct FindEval {
    /// Where the result lives: on the host or as provider-owned GPU buffers.
    inner: FindEvalInner,
}
444
/// Backing storage of a [`FindEval`].
enum FindEvalInner {
    /// Result computed on the host.
    Host {
        /// The computed indices and values.
        result: FindResult,
        /// Whether real-valued outputs should be uploaded to the GPU when possible.
        prefer_gpu: bool,
    },
    /// Result produced directly by the acceleration provider.
    Gpu {
        /// Provider-owned output handles.
        result: ProviderFindResult,
    },
}
454
455impl FindEval {
456 fn from_host(result: FindResult, prefer_gpu: bool) -> Self {
457 Self {
458 inner: FindEvalInner::Host { result, prefer_gpu },
459 }
460 }
461
462 fn from_gpu(result: ProviderFindResult) -> Self {
463 Self {
464 inner: FindEvalInner::Gpu { result },
465 }
466 }
467
468 pub fn linear_value(&self) -> crate::BuiltinResult<Value> {
469 match &self.inner {
470 FindEvalInner::Host { result, prefer_gpu } => {
471 let tensor = result.linear_tensor()?;
472 Ok(tensor_to_value(tensor, *prefer_gpu))
473 }
474 FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.linear.clone())),
475 }
476 }
477
478 pub fn row_value(&self) -> crate::BuiltinResult<Value> {
479 match &self.inner {
480 FindEvalInner::Host { result, prefer_gpu } => {
481 let tensor = result.row_tensor()?;
482 Ok(tensor_to_value(tensor, *prefer_gpu))
483 }
484 FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.rows.clone())),
485 }
486 }
487
488 pub fn column_value(&self) -> crate::BuiltinResult<Value> {
489 match &self.inner {
490 FindEvalInner::Host { result, prefer_gpu } => {
491 let tensor = result.column_tensor()?;
492 Ok(tensor_to_value(tensor, *prefer_gpu))
493 }
494 FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.cols.clone())),
495 }
496 }
497
498 pub fn values_value(&self) -> crate::BuiltinResult<Value> {
499 match &self.inner {
500 FindEvalInner::Host { result, prefer_gpu } => result.values_value(*prefer_gpu),
501 FindEvalInner::Gpu { result } => result
502 .values
503 .as_ref()
504 .map(|handle| Value::GpuTensor(handle.clone()))
505 .ok_or_else(|| find_error(&FIND_ERROR_PROVIDER_OUTPUT)),
506 }
507 }
508}
509
510async fn parse_options(args: &[Value]) -> crate::BuiltinResult<FindOptions> {
511 parse_find_tokens(&crate::builtins::common::arg_tokens::tokens_from_values(
512 args,
513 ))
514}
515
516fn parse_limit_scalar(value: f64) -> crate::BuiltinResult<usize> {
517 if !value.is_finite() {
518 return Err(find_error_with_message(
519 "find: K must be a finite, non-negative integer",
520 &FIND_ERROR_INVALID_INPUT,
521 ));
522 }
523 let rounded = value.round();
524 if (rounded - value).abs() > f64::EPSILON {
525 return Err(find_error_with_message(
526 "find: K must be a finite, non-negative integer",
527 &FIND_ERROR_INVALID_INPUT,
528 ));
529 }
530 if rounded < 0.0 {
531 return Err(find_error_with_message(
532 "find: K must be >= 0",
533 &FIND_ERROR_INVALID_INPUT,
534 ));
535 }
536 Ok(rounded as usize)
537}
538
539async fn materialize_input(value: Value) -> crate::BuiltinResult<(DataStorage, bool)> {
540 match value {
541 Value::GpuTensor(handle) => {
542 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
543 Ok((DataStorage::Real(tensor), true))
544 }
545 Value::Tensor(tensor) => Ok((DataStorage::Real(tensor), false)),
546 Value::SparseTensor(sparse) => Ok((
547 DataStorage::Real(sparse.to_dense().map_err(|e| {
548 find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL)
549 })?),
550 false,
551 )),
552 Value::LogicalArray(logical) => {
553 let tensor = tensor::logical_to_tensor(&logical)
554 .map_err(|message| find_error_with_message(message, &FIND_ERROR_INTERNAL))?;
555 Ok((DataStorage::Real(tensor), false))
556 }
557 Value::Num(n) => {
558 let tensor = Tensor::new(vec![n], vec![1, 1])
559 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
560 Ok((DataStorage::Real(tensor), false))
561 }
562 Value::Int(i) => {
563 let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
564 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
565 Ok((DataStorage::Real(tensor), false))
566 }
567 Value::Bool(b) => {
568 let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
569 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
570 Ok((DataStorage::Real(tensor), false))
571 }
572 Value::Complex(re, im) => {
573 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
574 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
575 Ok((DataStorage::Complex(tensor), false))
576 }
577 Value::ComplexTensor(tensor) => Ok((DataStorage::Complex(tensor), false)),
578 Value::CharArray(chars) => {
579 let mut data = Vec::with_capacity(chars.data.len());
580 for c in 0..chars.cols {
581 for r in 0..chars.rows {
582 let ch = chars.data[r * chars.cols + c] as u32;
583 data.push(ch as f64);
584 }
585 }
586 let tensor = Tensor::new(data, vec![chars.rows, chars.cols])
587 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
588 Ok((DataStorage::Real(tensor), false))
589 }
590 other => Err(find_error_with_message(
591 format!(
592 "find: unsupported input type {:?}; expected numeric, logical, or char data",
593 other
594 ),
595 &FIND_ERROR_INVALID_INPUT,
596 )),
597 }
598}
599
600fn compute_find(storage: &DataStorage, options: &FindOptions) -> FindResult {
601 let shape = storage.shape().to_vec();
602 let limit = options.effective_limit();
603
604 match storage {
605 DataStorage::Real(tensor) => {
606 let mut indices = Vec::new();
607 let mut values = Vec::new();
608
609 if matches!(limit, Some(0)) {
610 return FindResult::new(shape, indices, FindValues::Real(values));
611 }
612
613 let len = tensor.data.len();
614 match options.direction {
615 FindDirection::First => {
616 for idx in 0..len {
617 let value = tensor.data[idx];
618 if value != 0.0 {
619 indices.push(idx + 1);
620 values.push(value);
621 if limit.is_some_and(|k| indices.len() >= k) {
622 break;
623 }
624 }
625 }
626 }
627 FindDirection::Last => {
628 for idx in (0..len).rev() {
629 let value = tensor.data[idx];
630 if value != 0.0 {
631 indices.push(idx + 1);
632 values.push(value);
633 if limit.is_some_and(|k| indices.len() >= k) {
634 break;
635 }
636 }
637 }
638 }
639 }
640
641 FindResult::new(shape, indices, FindValues::Real(values))
642 }
643 DataStorage::Complex(tensor) => {
644 let mut indices = Vec::new();
645 let mut values = Vec::new();
646
647 if matches!(limit, Some(0)) {
648 return FindResult::new(shape, indices, FindValues::Complex(values));
649 }
650
651 let len = tensor.data.len();
652 match options.direction {
653 FindDirection::First => {
654 for idx in 0..len {
655 let (re, im) = tensor.data[idx];
656 if re != 0.0 || im != 0.0 {
657 indices.push(idx + 1);
658 values.push((re, im));
659 if limit.is_some_and(|k| indices.len() >= k) {
660 break;
661 }
662 }
663 }
664 }
665 FindDirection::Last => {
666 for idx in (0..len).rev() {
667 let (re, im) = tensor.data[idx];
668 if re != 0.0 || im != 0.0 {
669 indices.push(idx + 1);
670 values.push((re, im));
671 if limit.is_some_and(|k| indices.len() >= k) {
672 break;
673 }
674 }
675 }
676 }
677 }
678
679 FindResult::new(shape, indices, FindValues::Complex(values))
680 }
681 }
682}
683
684fn compute_find_sparse(
685 sparse: &runmat_builtins::SparseTensor,
686 options: &FindOptions,
687) -> FindResult {
688 let shape = vec![sparse.rows, sparse.cols];
689 let limit = options.effective_limit();
690
691 let mut indices = Vec::new();
692 let mut values = Vec::new();
693
694 if matches!(limit, Some(0)) {
695 return FindResult::new(shape, indices, FindValues::Real(values));
696 }
697
698 match options.direction {
699 FindDirection::First => {
700 for col in 0..sparse.cols {
701 let col_start = sparse.col_ptrs[col];
702 let col_end = sparse.col_ptrs[col + 1];
703 for idx in col_start..col_end {
704 let row = sparse.row_indices[idx];
705 let value = sparse.values[idx];
706 if value != 0.0 {
707 let linear_idx = row + col * sparse.rows;
708 indices.push(linear_idx + 1);
709 values.push(value);
710 if limit.is_some_and(|k| indices.len() >= k) {
711 return FindResult::new(shape, indices, FindValues::Real(values));
712 }
713 }
714 }
715 }
716 }
717 FindDirection::Last => {
718 for col in (0..sparse.cols).rev() {
719 let col_start = sparse.col_ptrs[col];
720 let col_end = sparse.col_ptrs[col + 1];
721 for idx in (col_start..col_end).rev() {
722 let row = sparse.row_indices[idx];
723 let value = sparse.values[idx];
724 if value != 0.0 {
725 let linear_idx = row + col * sparse.rows;
726 indices.push(linear_idx + 1);
727 values.push(value);
728 if limit.is_some_and(|k| indices.len() >= k) {
729 return FindResult::new(shape, indices, FindValues::Real(values));
730 }
731 }
732 }
733 }
734 }
735 }
736
737 FindResult::new(shape, indices, FindValues::Real(values))
738}
739
740impl FindResult {
741 fn new(shape: Vec<usize>, indices: Vec<usize>, values: FindValues) -> Self {
742 Self {
743 shape,
744 indices,
745 values,
746 }
747 }
748
749 fn linear_tensor(&self) -> crate::BuiltinResult<Tensor> {
750 let data: Vec<f64> = self.indices.iter().map(|&idx| idx as f64).collect();
751 let rows = data.len();
752 Tensor::new(data, vec![rows, 1])
753 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
754 }
755
756 fn row_tensor(&self) -> crate::BuiltinResult<Tensor> {
757 let mut data = Vec::with_capacity(self.indices.len());
758 let rows = self.shape.first().copied().unwrap_or(1).max(1);
759 for &idx in &self.indices {
760 let zero_based = idx - 1;
761 let row = (zero_based % rows) + 1;
762 data.push(row as f64);
763 }
764 Tensor::new(data, vec![self.indices.len(), 1])
765 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
766 }
767
768 fn column_tensor(&self) -> crate::BuiltinResult<Tensor> {
769 let mut data = Vec::with_capacity(self.indices.len());
770 let rows = self.shape.first().copied().unwrap_or(1).max(1);
771 for &idx in &self.indices {
772 let zero_based = idx - 1;
773 let col = (zero_based / rows) + 1;
774 data.push(col as f64);
775 }
776 Tensor::new(data, vec![self.indices.len(), 1])
777 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
778 }
779
780 fn values_value(&self, prefer_gpu: bool) -> crate::BuiltinResult<Value> {
781 match &self.values {
782 FindValues::Real(values) => {
783 let tensor = Tensor::new(values.clone(), vec![values.len(), 1]).map_err(|e| {
784 find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL)
785 })?;
786 Ok(tensor_to_value(tensor, prefer_gpu))
787 }
788 FindValues::Complex(values) => {
789 let tensor =
790 ComplexTensor::new(values.clone(), vec![values.len(), 1]).map_err(|e| {
791 find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL)
792 })?;
793 Ok(complex_tensor_into_value(tensor))
794 }
795 }
796 }
797}
798
799fn tensor_to_value(tensor: Tensor, prefer_gpu: bool) -> Value {
800 if prefer_gpu {
801 if let Some(provider) = runmat_accelerate_api::provider() {
802 let view = HostTensorView {
803 data: &tensor.data,
804 shape: &tensor.shape,
805 };
806 if let Ok(handle) = provider.upload(&view) {
807 return Value::GpuTensor(handle);
808 }
809 }
810 }
811 tensor::tensor_into_value(tensor)
812}
813
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{CharArray, IntValue, Type};

    /// Blocking wrapper around the async builtin entry point.
    fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        let pending = super::find_builtin(value, rest);
        block_on(pending)
    }

    /// Blocking wrapper around the async `evaluate` entry point.
    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<FindEval> {
        let pending = super::evaluate(value, rest);
        block_on(pending)
    }

    /// Builds a real tensor fixture, panicking on an invalid shape.
    fn dense(data: Vec<f64>, shape: Vec<usize>) -> Tensor {
        Tensor::new(data, shape).unwrap()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_linear_indices_basic() {
        let input = dense(vec![0.0, 4.0, 0.0, 7.0, 0.0, 9.0], vec![2, 3]);
        match find_builtin(Value::Tensor(input), Vec::new()).expect("find") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 1]);
                assert_eq!(t.data, vec![2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn find_type_is_column_vector() {
        let resolved = find_type(
            &[Type::Tensor { shape: None }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            resolved,
            Type::Tensor {
                shape: Some(vec![None, Some(1)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_limited_first() {
        let input = dense(vec![0.0, 3.0, 5.0, 0.0, 8.0], vec![1, 5]);
        let limit = vec![Value::Int(IntValue::I32(2))];
        match find_builtin(Value::Tensor(input), limit).expect("find") {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 3.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_last_single() {
        let input = dense(vec![1.0, 0.0, 0.0, 6.0, 0.0, 2.0], vec![1, 6]);
        match find_builtin(Value::Tensor(input), vec![Value::from("last")]).expect("find") {
            Value::Num(n) => assert_eq!(n, 6.0),
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![6.0]);
            }
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_complex_values() {
        let input =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0), (0.0, 0.0)], vec![3, 1]).unwrap();
        let eval = evaluate(Value::ComplexTensor(input), &[]).expect("find compute");
        match eval.values_value().expect("values") {
            Value::Complex(re, im) => {
                assert_eq!(re, 1.0);
                assert_eq!(im, 2.0);
            }
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 1]);
                assert_eq!(ct.data, vec![(1.0, 2.0)]);
            }
            other => panic!("expected complex result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let host = dense(vec![0.0, 4.0, 0.0, 7.0], vec![2, 2]);
            let view = HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let found = find_builtin(Value::GpuTensor(handle), Vec::new()).expect("find");
            let gathered = test_support::gather(found).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![2.0, 4.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_direction_error() {
        let input = dense(vec![1.0], vec![1, 1]);
        let bad_args = vec![Value::Int(IntValue::I32(1)), Value::from("invalid")];
        let err = find_builtin(Value::Tensor(input), bad_args).expect_err("expected error");
        assert!(err.to_string().contains("direction"));
        assert_eq!(err.identifier(), super::FIND_ERROR_INVALID_INPUT.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_multi_output_rows_cols_values() {
        let input = dense(vec![0.0, 2.0, 3.0, 0.0, 0.0, 6.0], vec![2, 3]);
        let eval = evaluate(Value::Tensor(input), &[]).expect("evaluate");

        let rows = test_support::gather(eval.row_value().expect("rows")).expect("gather rows");
        assert_eq!(rows.shape, vec![3, 1]);
        assert_eq!(rows.data, vec![2.0, 1.0, 2.0]);

        let cols = test_support::gather(eval.column_value().expect("cols")).expect("gather cols");
        assert_eq!(cols.shape, vec![3, 1]);
        assert_eq!(cols.data, vec![1.0, 2.0, 3.0]);

        let vals = test_support::gather(eval.values_value().expect("vals")).expect("gather vals");
        assert_eq!(vals.shape, vec![3, 1]);
        assert_eq!(vals.data, vec![2.0, 3.0, 6.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_last_order_descending() {
        let input = dense(vec![1.0, 0.0, 2.0, 3.0, 0.0], vec![1, 5]);
        let args = vec![Value::Int(IntValue::I32(2)), Value::from("last")];
        match find_builtin(Value::Tensor(input), args).expect("find") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                assert_eq!(t.data, vec![4.0, 3.0]);
            }
            Value::Num(_) => panic!("expected column vector"),
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_limit_zero_returns_empty() {
        let input = dense(vec![1.0, 0.0, 3.0], vec![3, 1]);
        match find_builtin(Value::Tensor(input), vec![Value::Num(0.0)]).expect("find") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_char_array_supports_nonzero_codes() {
        let text = CharArray::new(vec!['\0', 'A', '\0'], 1, 3).unwrap();
        match find_builtin(Value::CharArray(text), Vec::new()).expect("find") {
            Value::Num(n) => assert_eq!(n, 2.0),
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0]),
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_gpu_multi_outputs_return_gpu_handles() {
        test_support::with_test_provider(|provider| {
            let host = dense(vec![0.0, 4.0, 5.0, 0.0], vec![2, 2]);
            let view = HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");

            let rows = eval.row_value().expect("rows");
            assert!(matches!(rows, Value::GpuTensor(_)));
            let rows_host = test_support::gather(rows).expect("gather rows");
            assert_eq!(rows_host.data, vec![2.0, 1.0]);

            let cols = eval.column_value().expect("cols");
            assert!(matches!(cols, Value::GpuTensor(_)));
            let cols_host = test_support::gather(cols).expect("gather cols");
            assert_eq!(cols_host.data, vec![1.0, 2.0]);

            let vals = eval.values_value().expect("vals");
            assert!(matches!(vals, Value::GpuTensor(_)));
            let vals_host = test_support::gather(vals).expect("gather vals");
            assert_eq!(vals_host.data, vec![4.0, 5.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn find_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let host = dense(vec![0.0, 2.0, 0.0, 3.0, 4.0, 0.0], vec![3, 2]);

        let cpu_eval = evaluate(Value::Tensor(host.clone()), &[]).expect("cpu evaluate");
        let cpu_linear =
            test_support::gather(cpu_eval.linear_value().expect("cpu linear")).expect("cpu gather");

        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = HostTensorView {
            data: &host.data,
            shape: &host.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");
        let gpu_linear =
            test_support::gather(gpu_eval.linear_value().expect("gpu linear")).expect("gpu gather");

        assert_eq!(gpu_linear.data, cpu_linear.data);
    }
}