1use runmat_accelerate_api::{HostTensorView, ProviderFindResult};
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 ComplexTensor, ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::array::type_resolvers::column_vector_type;
12use crate::builtins::common::arg_tokens::ArgToken;
13use crate::builtins::common::random_args::complex_tensor_into_value;
14use crate::builtins::common::spec::{
15 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
16 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
17};
18use crate::builtins::common::{gpu_helpers, tensor};
19use crate::{build_runtime_error, RuntimeError};
20
/// GPU metadata for `find`. Device execution goes through the custom `find` provider hook
/// (invoked by `try_provider_find`); `notes` records the host fallback for other providers.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::find")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "find",
    op_kind: GpuOpKind::Custom("find"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("find")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "WGPU provider executes find directly on device; other providers fall back to host and re-upload results to preserve residency.",
};

/// Fusion metadata for `find`. No elementwise or reduction template is provided; per the
/// notes, `find` is not fused and this entry exists only for completeness.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::find")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "find",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Find drives control flow and currently bypasses fusion; metadata is present for completeness only.",
};
47
48fn find_type(_args: &[Type], _ctx: &ResolveContext) -> Type {
49 column_vector_type()
50}
51
52const BUILTIN_NAME: &str = "find";
53
54const FIND_OUTPUT_LINEAR: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
55 name: "idx",
56 ty: BuiltinParamType::NumericArray,
57 arity: BuiltinParamArity::Required,
58 default: None,
59 description: "Linear indices of non-zero elements.",
60}];
61
62const FIND_OUTPUT_ROW_COL: [BuiltinParamDescriptor; 2] = [
63 BuiltinParamDescriptor {
64 name: "row",
65 ty: BuiltinParamType::NumericArray,
66 arity: BuiltinParamArity::Required,
67 default: None,
68 description: "Row subscripts of non-zero elements.",
69 },
70 BuiltinParamDescriptor {
71 name: "col",
72 ty: BuiltinParamType::NumericArray,
73 arity: BuiltinParamArity::Required,
74 default: None,
75 description: "Column subscripts of non-zero elements.",
76 },
77];
78
79const FIND_OUTPUT_ROW_COL_VAL: [BuiltinParamDescriptor; 3] = [
80 BuiltinParamDescriptor {
81 name: "row",
82 ty: BuiltinParamType::NumericArray,
83 arity: BuiltinParamArity::Required,
84 default: None,
85 description: "Row subscripts of non-zero elements.",
86 },
87 BuiltinParamDescriptor {
88 name: "col",
89 ty: BuiltinParamType::NumericArray,
90 arity: BuiltinParamArity::Required,
91 default: None,
92 description: "Column subscripts of non-zero elements.",
93 },
94 BuiltinParamDescriptor {
95 name: "v",
96 ty: BuiltinParamType::Any,
97 arity: BuiltinParamArity::Required,
98 default: None,
99 description: "Values at the reported row/column locations.",
100 },
101];
102
103const FIND_INPUTS_BASE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
104 name: "X",
105 ty: BuiltinParamType::Any,
106 arity: BuiltinParamArity::Required,
107 default: None,
108 description: "Input array to search.",
109}];
110
111const FIND_INPUTS_LIMIT: [BuiltinParamDescriptor; 2] = [
112 BuiltinParamDescriptor {
113 name: "X",
114 ty: BuiltinParamType::Any,
115 arity: BuiltinParamArity::Required,
116 default: None,
117 description: "Input array to search.",
118 },
119 BuiltinParamDescriptor {
120 name: "K",
121 ty: BuiltinParamType::NumericScalar,
122 arity: BuiltinParamArity::Required,
123 default: None,
124 description: "Maximum number of indices to return.",
125 },
126];
127
128const FIND_INPUTS_LIMIT_DIR: [BuiltinParamDescriptor; 3] = [
129 BuiltinParamDescriptor {
130 name: "X",
131 ty: BuiltinParamType::Any,
132 arity: BuiltinParamArity::Required,
133 default: None,
134 description: "Input array to search.",
135 },
136 BuiltinParamDescriptor {
137 name: "K",
138 ty: BuiltinParamType::NumericScalar,
139 arity: BuiltinParamArity::Required,
140 default: None,
141 description: "Maximum number of indices to return.",
142 },
143 BuiltinParamDescriptor {
144 name: "direction",
145 ty: BuiltinParamType::StringScalar,
146 arity: BuiltinParamArity::Required,
147 default: Some("\"first\""),
148 description: "Direction selector: `\"first\"` or `\"last\"`.",
149 },
150];
151
152const FIND_SIGNATURES: [BuiltinSignatureDescriptor; 7] = [
153 BuiltinSignatureDescriptor {
154 label: "idx = find(X)",
155 inputs: &FIND_INPUTS_BASE,
156 outputs: &FIND_OUTPUT_LINEAR,
157 },
158 BuiltinSignatureDescriptor {
159 label: "idx = find(X, K)",
160 inputs: &FIND_INPUTS_LIMIT,
161 outputs: &FIND_OUTPUT_LINEAR,
162 },
163 BuiltinSignatureDescriptor {
164 label: "idx = find(X, K, direction)",
165 inputs: &FIND_INPUTS_LIMIT_DIR,
166 outputs: &FIND_OUTPUT_LINEAR,
167 },
168 BuiltinSignatureDescriptor {
169 label: "[row, col] = find(X)",
170 inputs: &FIND_INPUTS_BASE,
171 outputs: &FIND_OUTPUT_ROW_COL,
172 },
173 BuiltinSignatureDescriptor {
174 label: "[row, col] = find(X, K, direction)",
175 inputs: &FIND_INPUTS_LIMIT_DIR,
176 outputs: &FIND_OUTPUT_ROW_COL,
177 },
178 BuiltinSignatureDescriptor {
179 label: "[row, col, v] = find(X)",
180 inputs: &FIND_INPUTS_BASE,
181 outputs: &FIND_OUTPUT_ROW_COL_VAL,
182 },
183 BuiltinSignatureDescriptor {
184 label: "[row, col, v] = find(X, K, direction)",
185 inputs: &FIND_INPUTS_LIMIT_DIR,
186 outputs: &FIND_OUTPUT_ROW_COL_VAL,
187 },
188];
189
/// Raised for unsupported input value kinds and for malformed `K` / direction arguments.
const FIND_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FIND.INVALID_INPUT",
    identifier: Some("RunMat:find:InvalidInput"),
    when: "Input type or option arguments are not valid for find.",
    message: "find: invalid input arguments",
};

/// Raised when a provider-computed result lacks a buffer needed for the requested
/// outputs (e.g. no values buffer when a third output is requested).
const FIND_ERROR_PROVIDER_OUTPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FIND.PROVIDER_OUTPUT",
    identifier: Some("RunMat:find:ProviderOutput"),
    when: "GPU provider does not return expected output buffers for requested nargout.",
    message: "find: provider output buffer mismatch",
};

/// Raised when constructing an input or output tensor fails on the host.
const FIND_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.FIND.INTERNAL",
    identifier: Some("RunMat:find:InternalError"),
    when: "Internal tensor conversion/materialization fails while building outputs.",
    message: "find: internal error",
};

/// Every error descriptor `find` can raise, published via `FIND_DESCRIPTOR`.
const FIND_ERRORS: [BuiltinErrorDescriptor; 3] = [
    FIND_ERROR_INVALID_INPUT,
    FIND_ERROR_PROVIDER_OUTPUT,
    FIND_ERROR_INTERNAL,
];

/// Public descriptor for `find`: its call signatures, its errors, and the fact that the
/// set of returned values depends on the requested output count.
pub const FIND_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &FIND_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &FIND_ERRORS,
};
223
224fn find_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
225 find_error_with_message(error.message, error)
226}
227
228fn find_error_with_message(
229 message: impl Into<String>,
230 error: &'static BuiltinErrorDescriptor,
231) -> RuntimeError {
232 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
233 if let Some(identifier) = error.identifier {
234 builder = builder.with_identifier(identifier);
235 }
236 builder.build()
237}
238
239fn parse_find_tokens(tokens: &[ArgToken]) -> crate::BuiltinResult<FindOptions> {
240 match tokens.len() {
241 0 => Ok(FindOptions::default()),
242 1 => {
243 if let Some(direction) = token_to_direction(&tokens[0])? {
244 let limit = if matches!(direction, FindDirection::Last) {
245 Some(1)
246 } else {
247 None
248 };
249 Ok(FindOptions { limit, direction })
250 } else {
251 let limit = token_to_limit(&tokens[0])?;
252 Ok(FindOptions {
253 limit: Some(limit),
254 direction: FindDirection::First,
255 })
256 }
257 }
258 2 => {
259 let limit = token_to_limit(&tokens[0])?;
260 let direction = token_to_direction(&tokens[1])?.ok_or_else(|| {
261 find_error_with_message(
262 "find: third argument must be 'first' or 'last'",
263 &FIND_ERROR_INVALID_INPUT,
264 )
265 })?;
266 Ok(FindOptions {
267 limit: Some(limit),
268 direction,
269 })
270 }
271 _ => Err(find_error_with_message(
272 "find: too many input arguments",
273 &FIND_ERROR_INVALID_INPUT,
274 )),
275 }
276}
277
278fn token_to_direction(token: &ArgToken) -> crate::BuiltinResult<Option<FindDirection>> {
279 match token {
280 ArgToken::String(text) => match text.as_str() {
281 "first" => Ok(Some(FindDirection::First)),
282 "last" => Ok(Some(FindDirection::Last)),
283 _ => Err(find_error_with_message(
284 "find: direction must be 'first' or 'last'",
285 &FIND_ERROR_INVALID_INPUT,
286 )),
287 },
288 _ => Ok(None),
289 }
290}
291
292fn token_to_limit(token: &ArgToken) -> crate::BuiltinResult<usize> {
293 match token {
294 ArgToken::Number(value) => parse_limit_scalar(*value),
295 _ => Err(find_error_with_message(
296 "find: second argument must be a scalar",
297 &FIND_ERROR_INVALID_INPUT,
298 )),
299 }
300}
301
302#[runtime_builtin(
303 name = "find",
304 category = "array/indexing",
305 summary = "Locate nonzero indices and values.",
306 keywords = "find,nonzero,indices,row,column,gpu",
307 accel = "custom",
308 type_resolver(find_type),
309 descriptor(crate::builtins::array::indexing::find::FIND_DESCRIPTOR),
310 builtin_path = "crate::builtins::array::indexing::find"
311)]
312async fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
313 let eval = evaluate(value, &rest).await?;
314 if let Some(out_count) = crate::output_count::current_output_count() {
315 if out_count == 0 {
316 return Ok(Value::OutputList(Vec::new()));
317 }
318 if out_count <= 1 {
319 let linear = eval.linear_value()?;
320 return Ok(crate::output_count::output_list_with_padding(
321 out_count,
322 vec![linear],
323 ));
324 }
325 let rows = eval.row_value()?;
326 let cols = eval.column_value()?;
327 let mut outputs = vec![rows, cols];
328 if out_count >= 3 {
329 outputs.push(eval.values_value()?);
330 }
331 return Ok(crate::output_count::output_list_with_padding(
332 out_count, outputs,
333 ));
334 }
335 eval.linear_value()
336}
337
338pub async fn evaluate(value: Value, args: &[Value]) -> crate::BuiltinResult<FindEval> {
340 let options = parse_options(args).await?;
341 match value {
342 Value::GpuTensor(handle) => {
343 if let Some(result) = try_provider_find(&handle, &options) {
344 return Ok(FindEval::from_gpu(result));
345 }
346 let (storage, _) = materialize_input(Value::GpuTensor(handle)).await?;
347 let result = compute_find(&storage, &options);
348 Ok(FindEval::from_host(result, true))
349 }
350 other => {
351 let (storage, input_was_gpu) = materialize_input(other).await?;
352 let result = compute_find(&storage, &options);
353 Ok(FindEval::from_host(result, input_was_gpu))
354 }
355 }
356}
357
358fn try_provider_find(
359 handle: &runmat_accelerate_api::GpuTensorHandle,
360 options: &FindOptions,
361) -> Option<ProviderFindResult> {
362 #[cfg(all(test, feature = "wgpu"))]
363 {
364 if handle.device_id != 0 {
365 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
366 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
367 );
368 }
369 }
370 let provider = runmat_accelerate_api::provider()?;
371 let direction = match options.direction {
372 FindDirection::First => runmat_accelerate_api::FindDirection::First,
373 FindDirection::Last => runmat_accelerate_api::FindDirection::Last,
374 };
375 let limit = options.effective_limit();
376 provider.find(handle, limit, direction).ok()
377}
378
/// End of the linear element order from which `find` starts scanning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum FindDirection {
    #[default]
    First,
    Last,
}

/// Parsed optional arguments of `find`. The default is an unlimited forward scan.
#[derive(Debug, Clone, Default)]
struct FindOptions {
    /// Maximum number of matches (`K`); `None` means no limit.
    limit: Option<usize>,
    /// Scan direction.
    direction: FindDirection,
}

impl FindOptions {
    /// Limit actually applied to the scan. A backward scan with no explicit `K` stops
    /// after one match; a forward scan uses `limit` unchanged.
    fn effective_limit(&self) -> Option<usize> {
        if self.direction == FindDirection::Last {
            self.limit.or(Some(1))
        } else {
            self.limit
        }
    }
}
408
409#[derive(Clone)]
410enum DataStorage {
411 Real(Tensor),
412 Complex(ComplexTensor),
413}
414
415impl DataStorage {
416 fn shape(&self) -> &[usize] {
417 match self {
418 DataStorage::Real(t) => &t.shape,
419 DataStorage::Complex(t) => &t.shape,
420 }
421 }
422}
423
/// Host-side search outcome: the input shape, plus the 1-based linear indices of the
/// matches and the matched values, in the order they were found.
#[derive(Clone)]
struct FindResult {
    // Shape of the searched input; its first extent converts linear indices to subscripts.
    shape: Vec<usize>,
    // 1-based linear indices of the matching elements.
    indices: Vec<usize>,
    // Values at `indices`; real or complex to mirror the input storage.
    values: FindValues,
}

/// Matched element values, mirroring the storage kind of the input.
#[derive(Clone)]
enum FindValues {
    Real(Vec<f64>),
    // `(re, im)` pairs.
    Complex(Vec<(f64, f64)>),
}

/// Result of evaluating `find`; each output (linear, rows, columns, values) is built from
/// it on request.
pub struct FindEval {
    inner: FindEvalInner,
}

/// Where the evaluated result lives.
enum FindEvalInner {
    // Computed on the host. `prefer_gpu` is set by `evaluate` when the input was a GPU
    // tensor, so outputs are re-uploaded when a provider is available.
    Host {
        result: FindResult,
        prefer_gpu: bool,
    },
    // Computed by the acceleration provider; outputs are returned as device handles.
    Gpu {
        result: ProviderFindResult,
    },
}
450
451impl FindEval {
452 fn from_host(result: FindResult, prefer_gpu: bool) -> Self {
453 Self {
454 inner: FindEvalInner::Host { result, prefer_gpu },
455 }
456 }
457
458 fn from_gpu(result: ProviderFindResult) -> Self {
459 Self {
460 inner: FindEvalInner::Gpu { result },
461 }
462 }
463
464 pub fn linear_value(&self) -> crate::BuiltinResult<Value> {
465 match &self.inner {
466 FindEvalInner::Host { result, prefer_gpu } => {
467 let tensor = result.linear_tensor()?;
468 Ok(tensor_to_value(tensor, *prefer_gpu))
469 }
470 FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.linear.clone())),
471 }
472 }
473
474 pub fn row_value(&self) -> crate::BuiltinResult<Value> {
475 match &self.inner {
476 FindEvalInner::Host { result, prefer_gpu } => {
477 let tensor = result.row_tensor()?;
478 Ok(tensor_to_value(tensor, *prefer_gpu))
479 }
480 FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.rows.clone())),
481 }
482 }
483
484 pub fn column_value(&self) -> crate::BuiltinResult<Value> {
485 match &self.inner {
486 FindEvalInner::Host { result, prefer_gpu } => {
487 let tensor = result.column_tensor()?;
488 Ok(tensor_to_value(tensor, *prefer_gpu))
489 }
490 FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.cols.clone())),
491 }
492 }
493
494 pub fn values_value(&self) -> crate::BuiltinResult<Value> {
495 match &self.inner {
496 FindEvalInner::Host { result, prefer_gpu } => result.values_value(*prefer_gpu),
497 FindEvalInner::Gpu { result } => result
498 .values
499 .as_ref()
500 .map(|handle| Value::GpuTensor(handle.clone()))
501 .ok_or_else(|| find_error(&FIND_ERROR_PROVIDER_OUTPUT)),
502 }
503 }
504}
505
506async fn parse_options(args: &[Value]) -> crate::BuiltinResult<FindOptions> {
507 parse_find_tokens(&crate::builtins::common::arg_tokens::tokens_from_values(
508 args,
509 ))
510}
511
512fn parse_limit_scalar(value: f64) -> crate::BuiltinResult<usize> {
513 if !value.is_finite() {
514 return Err(find_error_with_message(
515 "find: K must be a finite, non-negative integer",
516 &FIND_ERROR_INVALID_INPUT,
517 ));
518 }
519 let rounded = value.round();
520 if (rounded - value).abs() > f64::EPSILON {
521 return Err(find_error_with_message(
522 "find: K must be a finite, non-negative integer",
523 &FIND_ERROR_INVALID_INPUT,
524 ));
525 }
526 if rounded < 0.0 {
527 return Err(find_error_with_message(
528 "find: K must be >= 0",
529 &FIND_ERROR_INVALID_INPUT,
530 ));
531 }
532 Ok(rounded as usize)
533}
534
535async fn materialize_input(value: Value) -> crate::BuiltinResult<(DataStorage, bool)> {
536 match value {
537 Value::GpuTensor(handle) => {
538 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
539 Ok((DataStorage::Real(tensor), true))
540 }
541 Value::Tensor(tensor) => Ok((DataStorage::Real(tensor), false)),
542 Value::LogicalArray(logical) => {
543 let tensor = tensor::logical_to_tensor(&logical)
544 .map_err(|message| find_error_with_message(message, &FIND_ERROR_INTERNAL))?;
545 Ok((DataStorage::Real(tensor), false))
546 }
547 Value::Num(n) => {
548 let tensor = Tensor::new(vec![n], vec![1, 1])
549 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
550 Ok((DataStorage::Real(tensor), false))
551 }
552 Value::Int(i) => {
553 let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
554 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
555 Ok((DataStorage::Real(tensor), false))
556 }
557 Value::Bool(b) => {
558 let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
559 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
560 Ok((DataStorage::Real(tensor), false))
561 }
562 Value::Complex(re, im) => {
563 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
564 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
565 Ok((DataStorage::Complex(tensor), false))
566 }
567 Value::ComplexTensor(tensor) => Ok((DataStorage::Complex(tensor), false)),
568 Value::CharArray(chars) => {
569 let mut data = Vec::with_capacity(chars.data.len());
570 for c in 0..chars.cols {
571 for r in 0..chars.rows {
572 let ch = chars.data[r * chars.cols + c] as u32;
573 data.push(ch as f64);
574 }
575 }
576 let tensor = Tensor::new(data, vec![chars.rows, chars.cols])
577 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))?;
578 Ok((DataStorage::Real(tensor), false))
579 }
580 other => Err(find_error_with_message(
581 format!(
582 "find: unsupported input type {:?}; expected numeric, logical, or char data",
583 other
584 ),
585 &FIND_ERROR_INVALID_INPUT,
586 )),
587 }
588}
589
590fn compute_find(storage: &DataStorage, options: &FindOptions) -> FindResult {
591 let shape = storage.shape().to_vec();
592 let limit = options.effective_limit();
593
594 match storage {
595 DataStorage::Real(tensor) => {
596 let mut indices = Vec::new();
597 let mut values = Vec::new();
598
599 if matches!(limit, Some(0)) {
600 return FindResult::new(shape, indices, FindValues::Real(values));
601 }
602
603 let len = tensor.data.len();
604 match options.direction {
605 FindDirection::First => {
606 for idx in 0..len {
607 let value = tensor.data[idx];
608 if value != 0.0 {
609 indices.push(idx + 1);
610 values.push(value);
611 if limit.is_some_and(|k| indices.len() >= k) {
612 break;
613 }
614 }
615 }
616 }
617 FindDirection::Last => {
618 for idx in (0..len).rev() {
619 let value = tensor.data[idx];
620 if value != 0.0 {
621 indices.push(idx + 1);
622 values.push(value);
623 if limit.is_some_and(|k| indices.len() >= k) {
624 break;
625 }
626 }
627 }
628 }
629 }
630
631 FindResult::new(shape, indices, FindValues::Real(values))
632 }
633 DataStorage::Complex(tensor) => {
634 let mut indices = Vec::new();
635 let mut values = Vec::new();
636
637 if matches!(limit, Some(0)) {
638 return FindResult::new(shape, indices, FindValues::Complex(values));
639 }
640
641 let len = tensor.data.len();
642 match options.direction {
643 FindDirection::First => {
644 for idx in 0..len {
645 let (re, im) = tensor.data[idx];
646 if re != 0.0 || im != 0.0 {
647 indices.push(idx + 1);
648 values.push((re, im));
649 if limit.is_some_and(|k| indices.len() >= k) {
650 break;
651 }
652 }
653 }
654 }
655 FindDirection::Last => {
656 for idx in (0..len).rev() {
657 let (re, im) = tensor.data[idx];
658 if re != 0.0 || im != 0.0 {
659 indices.push(idx + 1);
660 values.push((re, im));
661 if limit.is_some_and(|k| indices.len() >= k) {
662 break;
663 }
664 }
665 }
666 }
667 }
668
669 FindResult::new(shape, indices, FindValues::Complex(values))
670 }
671 }
672}
673
674impl FindResult {
675 fn new(shape: Vec<usize>, indices: Vec<usize>, values: FindValues) -> Self {
676 Self {
677 shape,
678 indices,
679 values,
680 }
681 }
682
683 fn linear_tensor(&self) -> crate::BuiltinResult<Tensor> {
684 let data: Vec<f64> = self.indices.iter().map(|&idx| idx as f64).collect();
685 let rows = data.len();
686 Tensor::new(data, vec![rows, 1])
687 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
688 }
689
690 fn row_tensor(&self) -> crate::BuiltinResult<Tensor> {
691 let mut data = Vec::with_capacity(self.indices.len());
692 let rows = self.shape.first().copied().unwrap_or(1).max(1);
693 for &idx in &self.indices {
694 let zero_based = idx - 1;
695 let row = (zero_based % rows) + 1;
696 data.push(row as f64);
697 }
698 Tensor::new(data, vec![self.indices.len(), 1])
699 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
700 }
701
702 fn column_tensor(&self) -> crate::BuiltinResult<Tensor> {
703 let mut data = Vec::with_capacity(self.indices.len());
704 let rows = self.shape.first().copied().unwrap_or(1).max(1);
705 for &idx in &self.indices {
706 let zero_based = idx - 1;
707 let col = (zero_based / rows) + 1;
708 data.push(col as f64);
709 }
710 Tensor::new(data, vec![self.indices.len(), 1])
711 .map_err(|e| find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL))
712 }
713
714 fn values_value(&self, prefer_gpu: bool) -> crate::BuiltinResult<Value> {
715 match &self.values {
716 FindValues::Real(values) => {
717 let tensor = Tensor::new(values.clone(), vec![values.len(), 1]).map_err(|e| {
718 find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL)
719 })?;
720 Ok(tensor_to_value(tensor, prefer_gpu))
721 }
722 FindValues::Complex(values) => {
723 let tensor =
724 ComplexTensor::new(values.clone(), vec![values.len(), 1]).map_err(|e| {
725 find_error_with_message(format!("find: {e}"), &FIND_ERROR_INTERNAL)
726 })?;
727 Ok(complex_tensor_into_value(tensor))
728 }
729 }
730 }
731}
732
733fn tensor_to_value(tensor: Tensor, prefer_gpu: bool) -> Value {
734 if prefer_gpu {
735 if let Some(provider) = runmat_accelerate_api::provider() {
736 let view = HostTensorView {
737 data: &tensor.data,
738 shape: &tensor.shape,
739 };
740 if let Ok(handle) = provider.upload(&view) {
741 return Value::GpuTensor(handle);
742 }
743 }
744 }
745 tensor::tensor_into_value(tensor)
746}
747
/// Unit tests for `find`: host search semantics, option parsing errors, multi-output
/// subscripts and values, and GPU residency via the test provider (and WGPU when enabled).
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{CharArray, IntValue, Type};

    // Synchronous wrappers so tests can drive the async entry points directly.
    fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::find_builtin(value, rest))
    }

    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<FindEval> {
        block_on(super::evaluate(value, rest))
    }

    // 2x3 matrix: nonzeros sit at linear (column-major) positions 2, 4 and 6.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_linear_indices_basic() {
        let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0, 0.0, 9.0], vec![2, 3]).unwrap();
        let value = find_builtin(Value::Tensor(tensor), Vec::new()).expect("find");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 1]);
                assert_eq!(t.data, vec![2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // The static type resolver always reports a column vector.
    #[test]
    fn find_type_is_column_vector() {
        assert_eq!(
            find_type(
                &[Type::Tensor { shape: None }],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Tensor {
                shape: Some(vec![None, Some(1)])
            }
        );
    }

    // K = 2 keeps only the first two nonzero positions.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_limited_first() {
        let tensor = Tensor::new(vec![0.0, 3.0, 5.0, 0.0, 8.0], vec![1, 5]).unwrap();
        let result =
            find_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(2))]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 3.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // A lone 'last' returns just the final nonzero index (6 here); it may arrive as a scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_last_single() {
        let tensor = Tensor::new(vec![1.0, 0.0, 0.0, 6.0, 0.0, 2.0], vec![1, 6]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), vec![Value::from("last")]).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 6.0),
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![6.0]);
            }
            other => panic!("unexpected result {other:?}"),
        }
    }

    // Complex input: the values output carries the complex element (1 + 2i).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_complex_values() {
        let tensor =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0), (0.0, 0.0)], vec![3, 1]).unwrap();
        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("find compute");
        let values = eval.values_value().expect("values");
        match values {
            Value::Complex(re, im) => {
                assert_eq!(re, 1.0);
                assert_eq!(im, 2.0);
            }
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 1]);
                assert_eq!(ct.data, vec![(1.0, 2.0)]);
            }
            other => panic!("expected complex result, got {other:?}"),
        }
    }

    // GPU input through the test provider yields the same linear indices once gathered.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = find_builtin(Value::GpuTensor(handle), Vec::new()).expect("find");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![2.0, 4.0]);
        });
    }

    // An unknown direction string is rejected with the InvalidInput identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_direction_error() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::from("invalid")],
        )
        .expect_err("expected error");
        assert!(err.to_string().contains("direction"));
        assert_eq!(err.identifier(), super::FIND_ERROR_INVALID_INPUT.identifier);
    }

    // 2x3 matrix with nonzeros at linear 2, 3, 6 -> (row, col) = (2,1), (1,2), (2,3).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_multi_output_rows_cols_values() {
        let tensor = Tensor::new(vec![0.0, 2.0, 3.0, 0.0, 0.0, 6.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");

        let rows = test_support::gather(eval.row_value().expect("rows")).expect("gather rows");
        assert_eq!(rows.shape, vec![3, 1]);
        assert_eq!(rows.data, vec![2.0, 1.0, 2.0]);

        let cols = test_support::gather(eval.column_value().expect("cols")).expect("gather cols");
        assert_eq!(cols.shape, vec![3, 1]);
        assert_eq!(cols.data, vec![1.0, 2.0, 3.0]);

        let vals = test_support::gather(eval.values_value().expect("vals")).expect("gather vals");
        assert_eq!(vals.shape, vec![3, 1]);
        assert_eq!(vals.data, vec![2.0, 3.0, 6.0]);
    }

    // Pins the current 'last' ordering: indices come back in descending order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_last_order_descending() {
        let tensor = Tensor::new(vec![1.0, 0.0, 2.0, 3.0, 0.0], vec![1, 5]).unwrap();
        let result = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(2)), Value::from("last")],
        )
        .expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                assert_eq!(t.data, vec![4.0, 3.0]);
            }
            Value::Num(_) => panic!("expected column vector"),
            other => panic!("unexpected result {other:?}"),
        }
    }

    // K = 0 is accepted and yields a 0x1 empty result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_limit_zero_returns_empty() {
        let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), vec![Value::Num(0.0)]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    // Char input: NUL characters count as zero, other code points as nonzero.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_char_array_supports_nonzero_codes() {
        let chars = CharArray::new(vec!['\0', 'A', '\0'], 1, 3).unwrap();
        let result = find_builtin(Value::CharArray(chars), Vec::new()).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 2.0),
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0]),
            other => panic!("unexpected result {other:?}"),
        }
    }

    // GPU input keeps every multi-output result resident as a GPU handle.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_gpu_multi_outputs_return_gpu_handles() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 4.0, 5.0, 0.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");

            let rows = eval.row_value().expect("rows");
            assert!(matches!(rows, Value::GpuTensor(_)));
            let rows_host = test_support::gather(rows).expect("gather rows");
            assert_eq!(rows_host.data, vec![2.0, 1.0]);

            let cols = eval.column_value().expect("cols");
            assert!(matches!(cols, Value::GpuTensor(_)));
            let cols_host = test_support::gather(cols).expect("gather cols");
            assert_eq!(cols_host.data, vec![1.0, 2.0]);

            let vals = eval.values_value().expect("vals");
            assert!(matches!(vals, Value::GpuTensor(_)));
            let vals_host = test_support::gather(vals).expect("gather vals");
            assert_eq!(vals_host.data, vec![4.0, 5.0]);
        });
    }

    // With the `wgpu` feature, device results must match the host computation.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn find_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![0.0, 2.0, 0.0, 3.0, 4.0, 0.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let cpu_linear =
            test_support::gather(cpu_eval.linear_value().expect("cpu linear")).expect("cpu gather");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");
        let gpu_linear =
            test_support::gather(gpu_eval.linear_value().expect("gpu linear")).expect("gpu gather");
        assert_eq!(gpu_linear.data, cpu_linear.data);
    }
}