use runmat_accelerate_api::{HostTensorView, ProviderFindResult};
use runmat_builtins::{ComplexTensor, ResolveContext, Tensor, Type, Value};
use runmat_macros::runtime_builtin;

use crate::builtins::array::type_resolvers::column_vector_type;
use crate::builtins::common::arg_tokens::ArgToken;
use crate::builtins::common::random_args::complex_tensor_into_value;
use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::{gpu_helpers, tensor};
use crate::{build_runtime_error, RuntimeError};

#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::find")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "find",
    op_kind: GpuOpKind::Custom("find"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("find")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "WGPU provider executes find directly on device; other providers fall back to host and re-upload results to preserve residency.",
};

#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::find")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "find",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Find drives control flow and currently bypasses fusion; metadata is present for completeness only.",
};

/// Type resolver: `find` always produces a column vector of indices,
/// regardless of the input shape.
fn find_type(_args: &[Type], _ctx: &ResolveContext) -> Type {
    column_vector_type()
}

/// Parse the optional trailing arguments of `find` (`K` and/or a direction
/// string) into `FindOptions`.
fn parse_find_tokens(tokens: &[ArgToken]) -> crate::BuiltinResult<FindOptions> {
    match tokens.len() {
        0 => Ok(FindOptions::default()),
        1 => {
            if let Some(direction) = token_to_direction(&tokens[0])? {
                // A bare 'last' with no explicit K means "the last nonzero".
                let limit = if matches!(direction, FindDirection::Last) {
                    Some(1)
                } else {
                    None
                };
                Ok(FindOptions { limit, direction })
            } else {
                let limit = token_to_limit(&tokens[0])?;
                Ok(FindOptions {
                    limit: Some(limit),
                    direction: FindDirection::First,
                })
            }
        }
        2 => {
            let limit = token_to_limit(&tokens[0])?;
            let direction = token_to_direction(&tokens[1])?
                .ok_or_else(|| find_error("find: third argument must be 'first' or 'last'"))?;
            Ok(FindOptions {
                limit: Some(limit),
                direction,
            })
        }
        _ => Err(find_error("find: too many input arguments")),
    }
}
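
// Sketch of the option mapping for the MATLAB-level call forms, derived from
// `parse_find_tokens` above (the surface syntax shown is illustrative):
//   find(X)            -> limit: None,    direction: First
//   find(X, 2)         -> limit: Some(2), direction: First
//   find(X, 'last')    -> limit: Some(1), direction: Last
//   find(X, 2, 'last') -> limit: Some(2), direction: Last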

fn token_to_direction(token: &ArgToken) -> crate::BuiltinResult<Option<FindDirection>> {
    match token {
        ArgToken::String(text) => match text.as_str() {
            "first" => Ok(Some(FindDirection::First)),
            "last" => Ok(Some(FindDirection::Last)),
            _ => Err(find_error("find: direction must be 'first' or 'last'")),
        },
        _ => Ok(None),
    }
}

fn token_to_limit(token: &ArgToken) -> crate::BuiltinResult<usize> {
    match token {
        ArgToken::Number(value) => parse_limit_scalar(*value),
        _ => Err(find_error("find: second argument must be a scalar")),
    }
}

#[runtime_builtin(
    name = "find",
    category = "array/indexing",
    summary = "Locate indices and values of nonzero elements.",
    keywords = "find,nonzero,indices,row,column,gpu",
    accel = "custom",
    type_resolver(find_type),
    builtin_path = "crate::builtins::array::indexing::find"
)]
async fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    let eval = evaluate(value, &rest).await?;
    if let Some(out_count) = crate::output_count::current_output_count() {
        if out_count == 0 {
            return Ok(Value::OutputList(Vec::new()));
        }
        if out_count == 1 {
            let linear = eval.linear_value()?;
            return Ok(crate::output_count::output_list_with_padding(
                out_count,
                vec![linear],
            ));
        }
        // Two or more outputs: row and column subscripts, plus the nonzero
        // values themselves when a third output is requested.
        let rows = eval.row_value()?;
        let cols = eval.column_value()?;
        let mut outputs = vec![rows, cols];
        if out_count >= 3 {
            outputs.push(eval.values_value()?);
        }
        return Ok(crate::output_count::output_list_with_padding(
            out_count, outputs,
        ));
    }
    eval.linear_value()
}
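
// Output-arity contract served by `find_builtin` (sketch, derived from the
// nargout handling above):
//   idx = find(X)       -> linear indices as an Nx1 column vector
//   [r, c] = find(X)    -> row and column subscripts
//   [r, c, v] = find(X) -> subscripts plus the nonzero values themselves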

/// Evaluate `find` and return a `FindEval` whose accessors produce the
/// linear-index, row, column, and value outputs on demand.
pub async fn evaluate(value: Value, args: &[Value]) -> crate::BuiltinResult<FindEval> {
    let options = parse_options(args).await?;
    match value {
        Value::GpuTensor(handle) => {
            // Prefer the provider's device-side find; otherwise gather to the
            // host, compute there, and re-upload the results.
            if let Some(result) = try_provider_find(&handle, &options) {
                return Ok(FindEval::from_gpu(result));
            }
            let (storage, _) = materialize_input(Value::GpuTensor(handle)).await?;
            let result = compute_find(&storage, &options);
            Ok(FindEval::from_host(result, true))
        }
        other => {
            let (storage, input_was_gpu) = materialize_input(other).await?;
            let result = compute_find(&storage, &options);
            Ok(FindEval::from_host(result, input_was_gpu))
        }
    }
}

fn try_provider_find(
    handle: &runmat_accelerate_api::GpuTensorHandle,
    options: &FindOptions,
) -> Option<ProviderFindResult> {
    // In wgpu-enabled test builds, ensure a provider is registered before the
    // lookup below.
    #[cfg(all(test, feature = "wgpu"))]
    {
        if handle.device_id != 0 {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let provider = runmat_accelerate_api::provider()?;
    let direction = match options.direction {
        FindDirection::First => runmat_accelerate_api::FindDirection::First,
        FindDirection::Last => runmat_accelerate_api::FindDirection::Last,
    };
    let limit = options.effective_limit();
    provider.find(handle, limit, direction).ok()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FindDirection {
    First,
    Last,
}

#[derive(Debug, Clone)]
struct FindOptions {
    limit: Option<usize>,
    direction: FindDirection,
}

impl Default for FindOptions {
    fn default() -> Self {
        Self {
            limit: None,
            direction: FindDirection::First,
        }
    }
}

impl FindOptions {
    /// The limit actually applied: `'last'` defaults to a single match when
    /// no explicit `K` was given, while `'first'` stays unbounded.
    fn effective_limit(&self) -> Option<usize> {
        match self.direction {
            FindDirection::Last => self.limit.or(Some(1)),
            FindDirection::First => self.limit,
        }
    }
}
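
// For example (sketch): `FindOptions { limit: None, direction: Last }` yields
// an effective limit of 1, matching the bare `find(X, 'last')` form, while
// `{ limit: None, direction: First }` leaves the scan unbounded.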

#[derive(Clone)]
enum DataStorage {
    Real(Tensor),
    Complex(ComplexTensor),
}

impl DataStorage {
    fn shape(&self) -> &[usize] {
        match self {
            DataStorage::Real(t) => &t.shape,
            DataStorage::Complex(t) => &t.shape,
        }
    }
}

#[derive(Clone)]
struct FindResult {
    shape: Vec<usize>,
    indices: Vec<usize>,
    values: FindValues,
}

#[derive(Clone)]
enum FindValues {
    Real(Vec<f64>),
    Complex(Vec<(f64, f64)>),
}

/// Result of a `find` evaluation; accessors convert the underlying host or
/// device result into the requested output `Value`s.
pub struct FindEval {
    inner: FindEvalInner,
}

enum FindEvalInner {
    Host {
        result: FindResult,
        prefer_gpu: bool,
    },
    Gpu {
        result: ProviderFindResult,
    },
}

impl FindEval {
    fn from_host(result: FindResult, prefer_gpu: bool) -> Self {
        Self {
            inner: FindEvalInner::Host { result, prefer_gpu },
        }
    }

    fn from_gpu(result: ProviderFindResult) -> Self {
        Self {
            inner: FindEvalInner::Gpu { result },
        }
    }

    pub fn linear_value(&self) -> crate::BuiltinResult<Value> {
        match &self.inner {
            FindEvalInner::Host { result, prefer_gpu } => {
                let tensor = result.linear_tensor()?;
                Ok(tensor_to_value(tensor, *prefer_gpu))
            }
            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.linear.clone())),
        }
    }

    pub fn row_value(&self) -> crate::BuiltinResult<Value> {
        match &self.inner {
            FindEvalInner::Host { result, prefer_gpu } => {
                let tensor = result.row_tensor()?;
                Ok(tensor_to_value(tensor, *prefer_gpu))
            }
            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.rows.clone())),
        }
    }

    pub fn column_value(&self) -> crate::BuiltinResult<Value> {
        match &self.inner {
            FindEvalInner::Host { result, prefer_gpu } => {
                let tensor = result.column_tensor()?;
                Ok(tensor_to_value(tensor, *prefer_gpu))
            }
            FindEvalInner::Gpu { result } => Ok(Value::GpuTensor(result.cols.clone())),
        }
    }

    pub fn values_value(&self) -> crate::BuiltinResult<Value> {
        match &self.inner {
            FindEvalInner::Host { result, prefer_gpu } => result.values_value(*prefer_gpu),
            FindEvalInner::Gpu { result } => result
                .values
                .as_ref()
                .map(|handle| Value::GpuTensor(handle.clone()))
                .ok_or_else(|| find_error("find: provider did not return values buffer")),
        }
    }
}

async fn parse_options(args: &[Value]) -> crate::BuiltinResult<FindOptions> {
    parse_find_tokens(&crate::builtins::common::arg_tokens::tokens_from_values(
        args,
    ))
}

fn parse_limit_scalar(value: f64) -> crate::BuiltinResult<usize> {
    if !value.is_finite() {
        return Err(find_error("find: K must be a finite, non-negative integer"));
    }
    let rounded = value.round();
    if (rounded - value).abs() > f64::EPSILON {
        return Err(find_error("find: K must be a finite, non-negative integer"));
    }
    if rounded < 0.0 {
        return Err(find_error("find: K must be >= 0"));
    }
    Ok(rounded as usize)
}
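
// Accepted/rejected K examples (sketch): 0.0 and 3.0 parse to 0 and 3, while
// 2.5, -1.0, NaN, and infinity are all rejected by the checks above.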

async fn materialize_input(value: Value) -> crate::BuiltinResult<(DataStorage, bool)> {
    match value {
        Value::GpuTensor(handle) => {
            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
            Ok((DataStorage::Real(tensor), true))
        }
        Value::Tensor(tensor) => Ok((DataStorage::Real(tensor), false)),
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(&logical).map_err(find_error)?;
            Ok((DataStorage::Real(tensor), false))
        }
        Value::Num(n) => {
            let tensor =
                Tensor::new(vec![n], vec![1, 1]).map_err(|e| find_error(format!("find: {e}")))?;
            Ok((DataStorage::Real(tensor), false))
        }
        Value::Int(i) => {
            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
                .map_err(|e| find_error(format!("find: {e}")))?;
            Ok((DataStorage::Real(tensor), false))
        }
        Value::Bool(b) => {
            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
                .map_err(|e| find_error(format!("find: {e}")))?;
            Ok((DataStorage::Real(tensor), false))
        }
        Value::Complex(re, im) => {
            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
                .map_err(|e| find_error(format!("find: {e}")))?;
            Ok((DataStorage::Complex(tensor), false))
        }
        Value::ComplexTensor(tensor) => Ok((DataStorage::Complex(tensor), false)),
        Value::CharArray(chars) => {
            // Char data is stored row-major; emit code points in column-major
            // order so linear indices match the numeric-tensor convention.
            let mut data = Vec::with_capacity(chars.data.len());
            for c in 0..chars.cols {
                for r in 0..chars.rows {
                    let ch = chars.data[r * chars.cols + c] as u32;
                    data.push(ch as f64);
                }
            }
            let tensor = Tensor::new(data, vec![chars.rows, chars.cols])
                .map_err(|e| find_error(format!("find: {e}")))?;
            Ok((DataStorage::Real(tensor), false))
        }
        other => Err(find_error(format!(
            "find: unsupported input type {:?}; expected numeric, logical, or char data",
            other
        ))),
    }
}

fn compute_find(storage: &DataStorage, options: &FindOptions) -> FindResult {
    let shape = storage.shape().to_vec();
    let limit = options.effective_limit();

    match storage {
        DataStorage::Real(tensor) => {
            let mut indices = Vec::new();
            let mut values = Vec::new();

            if matches!(limit, Some(0)) {
                return FindResult::new(shape, indices, FindValues::Real(values));
            }

            let len = tensor.data.len();
            match options.direction {
                FindDirection::First => {
                    for idx in 0..len {
                        let value = tensor.data[idx];
                        if value != 0.0 {
                            indices.push(idx + 1);
                            values.push(value);
                            if limit.is_some_and(|k| indices.len() >= k) {
                                break;
                            }
                        }
                    }
                }
                // 'last' scans backwards, so indices are emitted in
                // descending order.
                FindDirection::Last => {
                    for idx in (0..len).rev() {
                        let value = tensor.data[idx];
                        if value != 0.0 {
                            indices.push(idx + 1);
                            values.push(value);
                            if limit.is_some_and(|k| indices.len() >= k) {
                                break;
                            }
                        }
                    }
                }
            }

            FindResult::new(shape, indices, FindValues::Real(values))
        }
        DataStorage::Complex(tensor) => {
            let mut indices = Vec::new();
            let mut values = Vec::new();

            if matches!(limit, Some(0)) {
                return FindResult::new(shape, indices, FindValues::Complex(values));
            }

            let len = tensor.data.len();
            // A complex element counts as nonzero when either component is.
            match options.direction {
                FindDirection::First => {
                    for idx in 0..len {
                        let (re, im) = tensor.data[idx];
                        if re != 0.0 || im != 0.0 {
                            indices.push(idx + 1);
                            values.push((re, im));
                            if limit.is_some_and(|k| indices.len() >= k) {
                                break;
                            }
                        }
                    }
                }
                FindDirection::Last => {
                    for idx in (0..len).rev() {
                        let (re, im) = tensor.data[idx];
                        if re != 0.0 || im != 0.0 {
                            indices.push(idx + 1);
                            values.push((re, im));
                            if limit.is_some_and(|k| indices.len() >= k) {
                                break;
                            }
                        }
                    }
                }
            }

            FindResult::new(shape, indices, FindValues::Complex(values))
        }
    }
}
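
// Worked example (sketch): for data [0, 4, 0, 7, 0, 9] in a 2x3 array,
// 'first' yields the 1-based linear indices [2, 4, 6]; with K = 2 and 'last',
// the backward scan yields [6, 4].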

impl FindResult {
    fn new(shape: Vec<usize>, indices: Vec<usize>, values: FindValues) -> Self {
        Self {
            shape,
            indices,
            values,
        }
    }

    fn linear_tensor(&self) -> crate::BuiltinResult<Tensor> {
        let data: Vec<f64> = self.indices.iter().map(|&idx| idx as f64).collect();
        let rows = data.len();
        Tensor::new(data, vec![rows, 1]).map_err(|e| find_error(format!("find: {e}")))
    }

    fn row_tensor(&self) -> crate::BuiltinResult<Tensor> {
        let mut data = Vec::with_capacity(self.indices.len());
        let rows = self.shape.first().copied().unwrap_or(1).max(1);
        for &idx in &self.indices {
            // Decompose the 1-based column-major linear index into its row.
            let zero_based = idx - 1;
            let row = (zero_based % rows) + 1;
            data.push(row as f64);
        }
        Tensor::new(data, vec![self.indices.len(), 1]).map_err(|e| find_error(format!("find: {e}")))
    }

    fn column_tensor(&self) -> crate::BuiltinResult<Tensor> {
        let mut data = Vec::with_capacity(self.indices.len());
        let rows = self.shape.first().copied().unwrap_or(1).max(1);
        for &idx in &self.indices {
            // The column is the zero-based index divided by the row count.
            let zero_based = idx - 1;
            let col = (zero_based / rows) + 1;
            data.push(col as f64);
        }
        Tensor::new(data, vec![self.indices.len(), 1]).map_err(|e| find_error(format!("find: {e}")))
    }

    fn values_value(&self, prefer_gpu: bool) -> crate::BuiltinResult<Value> {
        match &self.values {
            FindValues::Real(values) => {
                let tensor = Tensor::new(values.clone(), vec![values.len(), 1])
                    .map_err(|e| find_error(format!("find: {e}")))?;
                Ok(tensor_to_value(tensor, prefer_gpu))
            }
            FindValues::Complex(values) => {
                let tensor = ComplexTensor::new(values.clone(), vec![values.len(), 1])
                    .map_err(|e| find_error(format!("find: {e}")))?;
                Ok(complex_tensor_into_value(tensor))
            }
        }
    }
}
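
// Worked decomposition example (sketch): in a 2x3 array, linear index 4 has
// zero_based = 3, so row = 3 % 2 + 1 = 2 and col = 3 / 2 + 1 = 2, i.e. the
// element at subscript (2, 2).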

/// Convert a host tensor into a `Value`, re-uploading to the active provider
/// when the original input lived on the GPU; falls back to a host value if no
/// provider is available or the upload fails.
fn tensor_to_value(tensor: Tensor, prefer_gpu: bool) -> Value {
    if prefer_gpu {
        if let Some(provider) = runmat_accelerate_api::provider() {
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            if let Ok(handle) = provider.upload(&view) {
                return Value::GpuTensor(handle);
            }
        }
    }
    tensor::tensor_into_value(tensor)
}

fn find_error(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message).with_builtin("find").build()
}

#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{CharArray, IntValue, Type};

    fn find_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::find_builtin(value, rest))
    }

    fn evaluate(value: Value, rest: &[Value]) -> crate::BuiltinResult<FindEval> {
        block_on(super::evaluate(value, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_linear_indices_basic() {
        let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0, 0.0, 9.0], vec![2, 3]).unwrap();
        let value = find_builtin(Value::Tensor(tensor), Vec::new()).expect("find");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 1]);
                assert_eq!(t.data, vec![2.0, 4.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn find_type_is_column_vector() {
        assert_eq!(
            find_type(
                &[Type::Tensor { shape: None }],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Tensor {
                shape: Some(vec![None, Some(1)])
            }
        );
    }
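
    // Supplementary checks (a minimal sketch): exercise the pure parsing
    // helpers directly, driving `parse_find_tokens` through the same
    // `tokens_from_values` path that `parse_options` uses.
    #[test]
    fn parse_find_tokens_maps_call_forms() {
        let no_args: Vec<Value> = Vec::new();
        let tokens = crate::builtins::common::arg_tokens::tokens_from_values(&no_args);
        let opts = parse_find_tokens(&tokens).expect("defaults");
        assert_eq!(opts.limit, None);
        assert_eq!(opts.direction, FindDirection::First);

        let args = vec![Value::from("last")];
        let tokens = crate::builtins::common::arg_tokens::tokens_from_values(&args);
        let opts = parse_find_tokens(&tokens).expect("bare 'last'");
        assert_eq!(opts.limit, Some(1));
        assert_eq!(opts.direction, FindDirection::Last);

        let args = vec![Value::Num(2.0), Value::from("last")];
        let tokens = crate::builtins::common::arg_tokens::tokens_from_values(&args);
        let opts = parse_find_tokens(&tokens).expect("K and 'last'");
        assert_eq!(opts.limit, Some(2));
        assert_eq!(opts.direction, FindDirection::Last);
    }

    #[test]
    fn parse_limit_scalar_validates_k() {
        assert_eq!(parse_limit_scalar(0.0).expect("zero"), 0);
        assert_eq!(parse_limit_scalar(3.0).expect("three"), 3);
        assert!(parse_limit_scalar(2.5).is_err());
        assert!(parse_limit_scalar(-1.0).is_err());
        assert!(parse_limit_scalar(f64::NAN).is_err());
        assert!(parse_limit_scalar(f64::INFINITY).is_err());
    }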

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_limited_first() {
        let tensor = Tensor::new(vec![0.0, 3.0, 5.0, 0.0, 8.0], vec![1, 5]).unwrap();
        let result =
            find_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(2))]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 3.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_last_single() {
        let tensor = Tensor::new(vec![1.0, 0.0, 0.0, 6.0, 0.0, 2.0], vec![1, 6]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), vec![Value::from("last")]).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 6.0),
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![6.0]);
            }
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_complex_values() {
        let tensor =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 2.0), (0.0, 0.0)], vec![3, 1]).unwrap();
        let eval = evaluate(Value::ComplexTensor(tensor), &[]).expect("find compute");
        let values = eval.values_value().expect("values");
        match values {
            Value::Complex(re, im) => {
                assert_eq!(re, 1.0);
                assert_eq!(im, 2.0);
            }
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 1]);
                assert_eq!(ct.data, vec![(1.0, 2.0)]);
            }
            other => panic!("expected complex result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 4.0, 0.0, 7.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = find_builtin(Value::GpuTensor(handle), Vec::new()).expect("find");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![2.0, 4.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_direction_error() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::from("invalid")],
        )
        .expect_err("expected error");
        assert!(err.to_string().contains("direction"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_multi_output_rows_cols_values() {
        let tensor = Tensor::new(vec![0.0, 2.0, 3.0, 0.0, 0.0, 6.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");

        let rows = test_support::gather(eval.row_value().expect("rows")).expect("gather rows");
        assert_eq!(rows.shape, vec![3, 1]);
        assert_eq!(rows.data, vec![2.0, 1.0, 2.0]);

        let cols = test_support::gather(eval.column_value().expect("cols")).expect("gather cols");
        assert_eq!(cols.shape, vec![3, 1]);
        assert_eq!(cols.data, vec![1.0, 2.0, 3.0]);

        let vals = test_support::gather(eval.values_value().expect("vals")).expect("gather vals");
        assert_eq!(vals.shape, vec![3, 1]);
        assert_eq!(vals.data, vec![2.0, 3.0, 6.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_last_order_descending() {
        let tensor = Tensor::new(vec![1.0, 0.0, 2.0, 3.0, 0.0], vec![1, 5]).unwrap();
        let result = find_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(2)), Value::from("last")],
        )
        .expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                assert_eq!(t.data, vec![4.0, 3.0]);
            }
            Value::Num(_) => panic!("expected column vector"),
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_limit_zero_returns_empty() {
        let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
        let result = find_builtin(Value::Tensor(tensor), vec![Value::Num(0.0)]).expect("find");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_char_array_supports_nonzero_codes() {
        let chars = CharArray::new(vec!['\0', 'A', '\0'], 1, 3).unwrap();
        let result = find_builtin(Value::CharArray(chars), Vec::new()).expect("find");
        match result {
            Value::Num(n) => assert_eq!(n, 2.0),
            Value::Tensor(t) => assert_eq!(t.data, vec![2.0]),
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn find_gpu_multi_outputs_return_gpu_handles() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 4.0, 5.0, 0.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval = evaluate(Value::GpuTensor(handle), &[]).expect("evaluate");

            let rows = eval.row_value().expect("rows");
            assert!(matches!(rows, Value::GpuTensor(_)));
            let rows_host = test_support::gather(rows).expect("gather rows");
            assert_eq!(rows_host.data, vec![2.0, 1.0]);

            let cols = eval.column_value().expect("cols");
            assert!(matches!(cols, Value::GpuTensor(_)));
            let cols_host = test_support::gather(cols).expect("gather cols");
            assert_eq!(cols_host.data, vec![1.0, 2.0]);

            let vals = eval.values_value().expect("vals");
            assert!(matches!(vals, Value::GpuTensor(_)));
            let vals_host = test_support::gather(vals).expect("gather vals");
            assert_eq!(vals_host.data, vec![4.0, 5.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn find_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![0.0, 2.0, 0.0, 3.0, 4.0, 0.0], vec![3, 2]).unwrap();
        let cpu_eval = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu evaluate");
        let cpu_linear =
            test_support::gather(cpu_eval.linear_value().expect("cpu linear")).expect("cpu gather");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_eval = evaluate(Value::GpuTensor(handle), &[]).expect("gpu evaluate");
        let gpu_linear =
            test_support::gather(gpu_eval.linear_value().expect("gpu linear")).expect("gpu gather");
        assert_eq!(gpu_linear.data, cpu_linear.data);
    }
}
763}