1use crate::builtins::acceleration::gpu::type_resolvers::arrayfun_type;
10use crate::builtins::common::spec::{
11 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
13};
14use crate::{
15 build_runtime_error, gather_if_needed_async, make_cell_with_shape, user_functions,
16 BuiltinResult, RuntimeError,
17};
18use runmat_accelerate_api::{set_handle_logical, GpuTensorHandle, HostTensorView};
19use runmat_builtins::{
20 CharArray, Closure, ComplexTensor, LogicalArray, StringArray, Tensor, Value,
21};
22use runmat_macros::runtime_builtin;
23
/// GPU capability declaration for `arrayfun`. Providers that implement the
/// listed unary/binary kernels allow `try_gpu_fast_path` to run recognized
/// builtin callbacks entirely on the device; anything else falls back to the
/// host element loop (with re-upload of a uniform result).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::arrayfun")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "arrayfun",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // Hook names mirror the provider methods dispatched in `try_gpu_fast_path`.
    provider_hooks: &[
        ProviderHook::Unary { name: "unary_sin" },
        ProviderHook::Unary { name: "unary_cos" },
        ProviderHook::Unary { name: "unary_abs" },
        ProviderHook::Unary { name: "unary_exp" },
        ProviderHook::Unary { name: "unary_log" },
        ProviderHook::Unary { name: "unary_sqrt" },
        ProviderHook::Binary {
            name: "elem_add",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_sub",
            commutative: false,
        },
        ProviderHook::Binary {
            name: "elem_mul",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_div",
            commutative: false,
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // A fast-path result is always a freshly allocated device handle.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers that implement the listed kernels can run supported callbacks entirely on the GPU; unsupported callbacks fall back to the host path with re-upload.",
};
62
/// Fusion metadata for `arrayfun`. No elementwise or reduction template is
/// exposed, so the fusion planner treats the builtin as a barrier: the
/// callback may execute arbitrary MATLAB code and cannot be inlined.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::acceleration::gpu::arrayfun"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "arrayfun",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a fusion barrier because the callback can run arbitrary MATLAB code.",
};
75
76fn arrayfun_error(message: impl Into<String>) -> RuntimeError {
77 build_runtime_error(message)
78 .with_builtin("arrayfun")
79 .build()
80}
81
82fn arrayfun_error_with_source(message: impl Into<String>, source: RuntimeError) -> RuntimeError {
83 let identifier = source.identifier().map(str::to_string);
84 let mut builder = build_runtime_error(message)
85 .with_builtin("arrayfun")
86 .with_source(source);
87 if let Some(identifier) = identifier {
88 builder = builder.with_identifier(identifier);
89 }
90 builder.build()
91}
92
/// Error constructor used at argument-validation / control-flow call sites.
/// NOTE(review): currently a plain alias of `arrayfun_error`; presumably kept
/// separate so the two failure classes can diverge later — confirm before
/// collapsing.
fn arrayfun_flow(message: impl Into<String>) -> RuntimeError {
    arrayfun_error(message)
}
96
/// Source-carrying counterpart of `arrayfun_flow`; currently a plain alias of
/// `arrayfun_error_with_source` for control-flow call sites.
fn arrayfun_flow_with_source(message: impl Into<String>, source: RuntimeError) -> RuntimeError {
    arrayfun_error_with_source(message, source)
}
100
101fn format_handler_error(err: &RuntimeError) -> String {
102 if let Some(identifier) = err.identifier() {
103 if err.message().is_empty() {
104 return identifier.to_string();
105 }
106 if err.message().starts_with(identifier) {
107 return err.message().to_string();
108 }
109 return format!("{identifier}: {}", err.message());
110 }
111 err.message().to_string()
112}
113
/// `arrayfun(func, A1, A2, ..., 'Name', value, ...)` — apply `func` to each
/// element of the input arrays and collect the results.
///
/// Trailing `'UniformOutput'` / `'ErrorHandler'` name-value pairs are peeled
/// off the end of the argument list. With `UniformOutput=true` (the default)
/// every callback result must be a scalar and results are packed into a dense
/// array; with `false` they are returned as a cell array. Scalar inputs
/// broadcast against the common size of the non-scalar inputs. When all
/// inputs are GPU tensors and the callback is a recognized builtin, the whole
/// call is dispatched to the provider via `try_gpu_fast_path`.
#[runtime_builtin(
    name = "arrayfun",
    category = "acceleration/gpu",
    summary = "Apply a function element-wise to array inputs.",
    keywords = "arrayfun,gpu,array,map,functional",
    accel = "host",
    type_resolver(arrayfun_type),
    builtin_path = "crate::builtins::acceleration::gpu::arrayfun"
)]
async fn arrayfun_builtin(func: Value, mut rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    let callable = Callable::from_function(func)?;

    let mut uniform_output = true;
    let mut error_handler: Option<Callable> = None;

    // Peel name-value options off the tail. The first trailing pair whose key
    // is not string-like ends the scan; whatever remains is treated as data
    // inputs. NOTE(review): a trailing char/string *data* argument would be
    // consumed as an option key here and rejected as unknown — confirm this
    // matches the intended MATLAB-compatible behavior.
    while rest.len() >= 2 {
        let key_candidate = rest[rest.len() - 2].clone();
        let Some(name) = extract_string(&key_candidate) else {
            break;
        };
        let value = rest.pop().expect("value present");
        rest.pop(); // discard the matched key
        match name.trim().to_ascii_lowercase().as_str() {
            "uniformoutput" => uniform_output = parse_uniform_output(value)?,
            "errorhandler" => error_handler = Some(Callable::from_function(value)?),
            other => {
                return Err(arrayfun_flow(format!(
                    "arrayfun: unknown name-value argument '{other}'"
                )))
            }
        }
    }

    if rest.is_empty() {
        return Err(arrayfun_flow("arrayfun: expected at least one input array"));
    }

    // Record GPU residency of the inputs so a uniform host result can be
    // re-uploaded afterwards (first GPU input's device id wins).
    let inputs_snapshot = rest.clone();
    let has_gpu_input = inputs_snapshot
        .iter()
        .any(|value| matches!(value, Value::GpuTensor(_)));
    let gpu_device_id = inputs_snapshot.iter().find_map(|v| {
        if let Value::GpuTensor(h) = v {
            Some(h.device_id)
        } else {
            None
        }
    });

    // Whole-array device dispatch for recognized builtin callbacks.
    if uniform_output {
        if let Some(gpu_result) =
            try_gpu_fast_path(&callable, &inputs_snapshot, error_handler.as_ref()).await?
        {
            return Ok(gpu_result);
        }
    }

    // Host path: gather every input and validate a common broadcast size.
    let mut inputs: Vec<ArrayInput> = Vec::with_capacity(rest.len());
    let mut base_shape: Vec<usize> = Vec::new();
    let mut base_len: Option<usize> = None;

    for (idx, raw) in rest.into_iter().enumerate() {
        if matches!(raw, Value::Cell(_)) {
            return Err(arrayfun_flow(
                "arrayfun: cell inputs are not supported (use cellfun instead)",
            ));
        }
        if matches!(raw, Value::Struct(_)) {
            return Err(arrayfun_flow("arrayfun: struct inputs are not supported"));
        }

        let host_value = gather_if_needed_async(&raw).await?;
        let data = ArrayData::from_value(host_value)?;
        let len = data.len();
        let is_scalar = len == 1;

        let mut input = ArrayInput { data, is_scalar };

        if let Some(current) = base_len {
            if current == len {
                // Same element count: shapes must also agree (scalars exempt).
                if len > 1 {
                    let shape = input.shape_vec();
                    if shape != base_shape {
                        return Err(arrayfun_flow(format!(
                            "arrayfun: input {} does not match the size of the first array",
                            idx + 1
                        )));
                    }
                }
            } else if len == 1 {
                // One-element inputs broadcast against the established size.
                input.is_scalar = true;
            } else if current == 1 {
                // The base so far was scalar; this input establishes the real
                // size, so re-validate every previously accepted input against
                // the new base shape.
                base_len = Some(len);
                base_shape = input.shape_vec();
                for prior in &mut inputs {
                    let prior_len = prior.len();
                    if prior_len == len {
                        if prior.shape_vec() != base_shape {
                            return Err(arrayfun_flow(format!(
                                "arrayfun: input {} does not match the size of the first array",
                                idx
                            )));
                        }
                    } else if prior_len == 1 {
                        prior.is_scalar = true;
                    } else if prior_len == 0 && len == 0 {
                        continue;
                    } else {
                        return Err(arrayfun_flow(format!(
                            "arrayfun: input {} does not match the size of the first array",
                            idx
                        )));
                    }
                }
            } else if len == 0 && current == 0 {
                // Empty inputs must agree on the exact empty shape.
                let shape = input.shape_vec();
                if shape != base_shape {
                    return Err(arrayfun_flow(format!(
                        "arrayfun: input {} does not match the size of the first array",
                        idx + 1
                    )));
                }
            } else {
                return Err(arrayfun_flow(format!(
                    "arrayfun: input {} does not match the size of the first array",
                    idx + 1
                )));
            }
        } else {
            // First input establishes the base length and shape.
            base_len = Some(len);
            base_shape = input.shape_vec();
        }

        inputs.push(input);
    }

    let total_len = base_len.unwrap_or(0);

    // Empty inputs short-circuit to an empty result of the matching kind.
    if total_len == 0 {
        if uniform_output {
            return Ok(empty_uniform(&base_shape));
        } else {
            return make_cell_with_shape(Vec::new(), base_shape)
                .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")));
        }
    }

    let mut collector = if uniform_output {
        Some(UniformCollector::Pending)
    } else {
        None
    };

    let mut cell_outputs: Vec<Value> = Vec::new();
    // Argument buffer reused across iterations to avoid per-element Vecs.
    let mut args: Vec<Value> = Vec::with_capacity(inputs.len());

    // Element loop over MATLAB linear (column-major) indices: build the
    // argument tuple, invoke the callback, and route failures through the
    // optional error handler.
    for idx in 0..total_len {
        args.clear();
        for input in &inputs {
            args.push(input.value_at(idx)?);
        }

        let result = match callable.call(&args).await {
            Ok(value) => value,
            Err(err) => {
                let handler = match error_handler.as_ref() {
                    Some(handler) => handler,
                    None => {
                        return Err(arrayfun_flow_with_source(
                            format!("arrayfun: {}", err.message()),
                            err,
                        ))
                    }
                };
                // MATLAB-style handler: first an error struct, then the
                // arguments that triggered the failure.
                let err_message = format_handler_error(&err);
                let err_value = make_error_struct(&err_message, idx, &base_shape)?;
                let mut handler_args = Vec::with_capacity(1 + args.len());
                handler_args.push(err_value);
                handler_args.extend(args.clone());
                handler.call(&handler_args).await?
            }
        };

        // Callback results may themselves be GPU handles; gather before packing.
        let host_result = gather_if_needed_async(&result).await?;

        if let Some(collector) = collector.as_mut() {
            collector.push(&host_result)?;
        } else {
            cell_outputs.push(host_result);
        }
    }

    if let Some(collector) = collector {
        let uniform = collector.finish(&base_shape)?;
        // Keep results device-resident when any input came from the GPU.
        maybe_upload_uniform(uniform, has_gpu_input, gpu_device_id)
    } else {
        make_cell_with_shape(cell_outputs, base_shape)
            .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))
    }
}
314
315fn maybe_upload_uniform(
316 value: Value,
317 has_gpu_input: bool,
318 gpu_device_id: Option<u32>,
319) -> BuiltinResult<Value> {
320 if !has_gpu_input {
321 return Ok(value);
322 }
323 #[cfg(all(test, feature = "wgpu"))]
324 {
325 if matches!(gpu_device_id, Some(id) if id != 0) {
326 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
327 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
328 );
329 }
330 }
331 let _ = gpu_device_id; let provider = match runmat_accelerate_api::provider() {
333 Some(p) => p,
334 None => return Ok(value),
335 };
336
337 match value {
338 Value::Tensor(tensor) => {
339 let view = HostTensorView {
340 data: &tensor.data,
341 shape: &tensor.shape,
342 };
343 let handle = provider
344 .upload(&view)
345 .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
346 Ok(Value::GpuTensor(handle))
347 }
348 Value::LogicalArray(logical) => {
349 let data: Vec<f64> = logical
350 .data
351 .iter()
352 .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
353 .collect();
354 let tensor = Tensor::new(data, logical.shape.clone())
355 .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
356 let view = HostTensorView {
357 data: &tensor.data,
358 shape: &tensor.shape,
359 };
360 let handle = provider
361 .upload(&view)
362 .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
363 set_handle_logical(&handle, true);
364 Ok(Value::GpuTensor(handle))
365 }
366 other => Ok(other),
367 }
368}
369
370fn empty_uniform(shape: &[usize]) -> Value {
371 if shape.is_empty() {
372 return Value::Tensor(Tensor::zeros(vec![0, 0]));
373 }
374 let total: usize = shape.iter().product();
375 let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
376 .unwrap_or_else(|_| Tensor::zeros(shape.to_vec()));
377 Value::Tensor(tensor)
378}
379
380fn parse_uniform_output(value: Value) -> BuiltinResult<bool> {
381 match value {
382 Value::Bool(b) => Ok(b),
383 Value::Num(n) => Ok(n != 0.0),
384 Value::Int(iv) => Ok(iv.to_f64() != 0.0),
385 Value::String(s) => parse_bool_string(&s)
386 .ok_or_else(|| arrayfun_flow("arrayfun: UniformOutput must be logical true or false")),
387 Value::CharArray(ca) if ca.rows == 1 => {
388 let text: String = ca.data.iter().collect();
389 parse_bool_string(&text).ok_or_else(|| {
390 arrayfun_flow("arrayfun: UniformOutput must be logical true or false")
391 })
392 }
393 other => Err(arrayfun_flow(format!(
394 "arrayfun: UniformOutput must be logical true or false, got {other:?}"
395 ))),
396 }
397}
398
/// Parse a MATLAB-style boolean word: "true"/"on" and "false"/"off",
/// case-insensitively and ignoring surrounding whitespace.
fn parse_bool_string(value: &str) -> Option<bool> {
    let normalized = value.trim().to_ascii_lowercase();
    if normalized == "true" || normalized == "on" {
        Some(true)
    } else if normalized == "false" || normalized == "off" {
        Some(false)
    } else {
        None
    }
}
406
407fn extract_string(value: &Value) -> Option<String> {
408 match value {
409 Value::String(s) => Some(s.clone()),
410 Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
411 Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
412 _ => None,
413 }
414}
415
/// One validated `arrayfun` input: the gathered host array plus a flag that
/// marks it as a broadcasting scalar.
struct ArrayInput {
    data: ArrayData,
    // True when this input has exactly one element and should broadcast
    // (always indexed at element 0) against the common input size.
    is_scalar: bool,
}
420
421impl ArrayInput {
422 fn len(&self) -> usize {
423 self.data.len()
424 }
425
426 fn shape_vec(&self) -> Vec<usize> {
427 self.data.shape_vec()
428 }
429
430 fn value_at(&self, idx: usize) -> BuiltinResult<Value> {
431 if self.is_scalar {
432 self.data.value_at(0)
433 } else {
434 self.data.value_at(idx)
435 }
436 }
437}
438
/// Host-side storage for an `arrayfun` input, one variant per supported
/// MATLAB array class. `Scalar` holds any single scalar `Value` unchanged.
enum ArrayData {
    Tensor(Tensor),
    Logical(LogicalArray),
    Complex(ComplexTensor),
    Char(CharArray),
    String(StringArray),
    Scalar(Value),
}
447
impl ArrayData {
    /// Wrap a gathered host `Value` as iterable array data. Cell/struct
    /// inputs are rejected by the caller before this runs; any other
    /// unsupported variant hits the catch-all error below.
    fn from_value(value: Value) -> BuiltinResult<Self> {
        match value {
            Value::Tensor(t) => Ok(ArrayData::Tensor(t)),
            Value::LogicalArray(l) => Ok(ArrayData::Logical(l)),
            Value::ComplexTensor(c) => Ok(ArrayData::Complex(c)),
            Value::CharArray(ca) => Ok(ArrayData::Char(ca)),
            Value::StringArray(sa) => Ok(ArrayData::String(sa)),
            // Scalars are stored verbatim and replayed per element.
            Value::Num(_)
            | Value::Bool(_)
            | Value::Int(_)
            | Value::Complex(_, _)
            | Value::String(_) => {
                Ok(ArrayData::Scalar(value))
            }
            other => Err(arrayfun_flow(format!(
                "arrayfun: unsupported input type {other:?} (expected numeric, logical, complex, char, or string arrays)"
            ))),
        }
    }

    /// Total element count (char arrays count rows*cols; scalars count 1).
    fn len(&self) -> usize {
        match self {
            ArrayData::Tensor(t) => t.data.len(),
            ArrayData::Logical(l) => l.data.len(),
            ArrayData::Complex(c) => c.data.len(),
            ArrayData::Char(ca) => ca.rows * ca.cols,
            ArrayData::String(sa) => sa.data.len(),
            ArrayData::Scalar(_) => 1,
        }
    }

    /// Shape as a dimension vector, normalizing an empty stored shape to the
    /// MATLAB scalar shape [1, 1].
    fn shape_vec(&self) -> Vec<usize> {
        match self {
            ArrayData::Tensor(t) => {
                if t.shape.is_empty() {
                    vec![1, 1]
                } else {
                    t.shape.clone()
                }
            }
            ArrayData::Logical(l) => {
                if l.shape.is_empty() {
                    vec![1, 1]
                } else {
                    l.shape.clone()
                }
            }
            ArrayData::Complex(c) => {
                if c.shape.is_empty() {
                    vec![1, 1]
                } else {
                    c.shape.clone()
                }
            }
            ArrayData::Char(ca) => vec![ca.rows, ca.cols],
            ArrayData::String(sa) => {
                if sa.shape.is_empty() {
                    vec![1, 1]
                } else {
                    sa.shape.clone()
                }
            }
            ArrayData::Scalar(_) => vec![1, 1],
        }
    }

    /// Element at MATLAB linear (column-major) index `idx`, wrapped as the
    /// scalar `Value` handed to the callback. Chars become 1x1 char arrays.
    fn value_at(&self, idx: usize) -> BuiltinResult<Value> {
        match self {
            ArrayData::Tensor(t) => {
                Ok(Value::Num(*t.data.get(idx).ok_or_else(|| {
                    arrayfun_flow("arrayfun: index out of bounds")
                })?))
            }
            ArrayData::Logical(l) => Ok(Value::Bool(
                *l.data
                    .get(idx)
                    .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?
                    != 0,
            )),
            ArrayData::Complex(c) => {
                let (re, im) = c
                    .data
                    .get(idx)
                    .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?;
                Ok(Value::Complex(*re, *im))
            }
            ArrayData::Char(ca) => {
                if ca.rows == 0 || ca.cols == 0 {
                    return Ok(Value::CharArray(
                        CharArray::new(Vec::new(), 0, 0)
                            .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?,
                    ));
                }
                // Linear index is column-major, but `CharArray` data is laid
                // out row-major (see `UniformCollector::finish`), so convert
                // (row, col) into the row-major storage index.
                let rows = ca.rows;
                let cols = ca.cols;
                let row = idx % rows;
                let col = idx / rows;
                let data_idx = row * cols + col;
                let ch = *ca
                    .data
                    .get(data_idx)
                    .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?;
                let char_array = CharArray::new(vec![ch], 1, 1)
                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
                Ok(Value::CharArray(char_array))
            }
            ArrayData::String(sa) => {
                Ok(Value::String(sa.data.get(idx).cloned().ok_or_else(
                    || arrayfun_flow("arrayfun: index out of bounds"),
                )?))
            }
            // Broadcasting scalar: same value for every index.
            ArrayData::Scalar(v) => Ok(v.clone()),
        }
    }
}
564
/// A resolved `arrayfun` callback: either a builtin/user function addressed
/// by name, or a closure carrying captured values.
#[derive(Clone)]
enum Callable {
    Builtin { name: String },
    Closure(Closure),
}
570
571impl Callable {
572 fn from_function(value: Value) -> BuiltinResult<Self> {
573 match value {
574 Value::String(text) => Self::from_text(&text),
575 Value::CharArray(ca) => {
576 if ca.rows != 1 {
577 Err(arrayfun_flow(
578 "arrayfun: function name must be a character vector or string scalar",
579 ))
580 } else {
581 let text: String = ca.data.iter().collect();
582 Self::from_text(&text)
583 }
584 }
585 Value::StringArray(sa) if sa.data.len() == 1 => Self::from_text(&sa.data[0]),
586 Value::FunctionHandle(name) => Self::from_text(&name),
587 Value::Closure(closure) => Ok(Callable::Closure(closure)),
588 Value::Num(_) | Value::Int(_) | Value::Bool(_) => Err(arrayfun_flow(
589 "arrayfun: expected function handle or builtin name, not a scalar value",
590 )),
591 other => Err(arrayfun_flow(format!(
592 "arrayfun: expected function handle or builtin name, got {other:?}"
593 ))),
594 }
595 }
596
597 fn from_text(text: &str) -> BuiltinResult<Self> {
598 let trimmed = text.trim();
599 if trimmed.is_empty() {
600 return Err(arrayfun_flow(
601 "arrayfun: expected function handle or builtin name, got empty string",
602 ));
603 }
604 if let Some(rest) = trimmed.strip_prefix('@') {
605 let name = rest.trim();
606 if name.is_empty() {
607 Err(arrayfun_flow("arrayfun: empty function handle"))
608 } else {
609 Ok(Callable::Builtin {
610 name: name.to_string(),
611 })
612 }
613 } else {
614 Ok(Callable::Builtin {
615 name: trimmed.to_ascii_lowercase(),
616 })
617 }
618 }
619
620 fn builtin_name(&self) -> Option<&str> {
621 match self {
622 Callable::Builtin { name } => Some(name.as_str()),
623 Callable::Closure(_) => None,
624 }
625 }
626
627 async fn call(&self, args: &[Value]) -> crate::BuiltinResult<Value> {
628 match self {
629 Callable::Builtin { name } => {
630 if let Some(result) = user_functions::try_call_user_function(name, args).await {
631 return result;
632 }
633 crate::call_builtin_async(name, args).await
634 }
635 Callable::Closure(c) => {
636 let mut merged = c.captures.clone();
637 merged.extend_from_slice(args);
638 if let Some(result) =
639 user_functions::try_call_user_function(&c.function_name, &merged).await
640 {
641 return result;
642 }
643 crate::call_builtin_async(&c.function_name, &merged).await
644 }
645 }
646 }
647}
648
/// Attempt to run the whole `arrayfun` on the GPU without gathering.
///
/// Requirements: every input is already a `GpuTensor`, no `ErrorHandler` is
/// installed (provider kernels cannot invoke MATLAB callbacks), the callback
/// is a recognized builtin name, and all handle shapes match exactly (this
/// path performs no broadcasting). Returns `Ok(None)` to request the host
/// fallback; provider errors also fall back instead of surfacing.
async fn try_gpu_fast_path(
    callable: &Callable,
    inputs: &[Value],
    error_handler: Option<&Callable>,
) -> BuiltinResult<Option<Value>> {
    if inputs.is_empty() || error_handler.is_some() {
        return Ok(None);
    }
    if !inputs
        .iter()
        .all(|value| matches!(value, Value::GpuTensor(_)))
    {
        return Ok(None);
    }

    // Test-only convenience: lazily register the wgpu provider when a
    // non-default device handle shows up.
    #[cfg(all(test, feature = "wgpu"))]
    {
        if inputs
            .iter()
            .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
        {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(None),
    };

    let Some(name_raw) = callable.builtin_name() else {
        return Ok(None);
    };
    // `@name` handles keep their original case in `Callable::from_text`, so
    // normalize here before matching.
    let name = name_raw.to_ascii_lowercase();

    let mut handles: Vec<GpuTensorHandle> = Vec::with_capacity(inputs.len());
    for value in inputs {
        if let Value::GpuTensor(handle) = value {
            handles.push(handle.clone());
        }
    }

    // All operands must share one exact shape; scalar expansion is host-only.
    if handles.len() >= 2 {
        let base_shape = handles[0].shape.clone();
        if handles
            .iter()
            .skip(1)
            .any(|handle| handle.shape != base_shape)
        {
            return Ok(None);
        }
    }

    // Dispatch table mirroring the hooks advertised in `GPU_SPEC`.
    let result = match name.as_str() {
        "sin" if handles.len() == 1 => provider.unary_sin(&handles[0]).await,
        "cos" if handles.len() == 1 => provider.unary_cos(&handles[0]).await,
        "abs" if handles.len() == 1 => provider.unary_abs(&handles[0]).await,
        "exp" if handles.len() == 1 => provider.unary_exp(&handles[0]).await,
        "log" if handles.len() == 1 => provider.unary_log(&handles[0]).await,
        "sqrt" if handles.len() == 1 => provider.unary_sqrt(&handles[0]).await,
        "plus" if handles.len() == 2 => provider.elem_add(&handles[0], &handles[1]).await,
        "minus" if handles.len() == 2 => provider.elem_sub(&handles[0], &handles[1]).await,
        "times" if handles.len() == 2 => provider.elem_mul(&handles[0], &handles[1]).await,
        "rdivide" if handles.len() == 2 => provider.elem_div(&handles[0], &handles[1]).await,
        // ldivide: a .\ b == b ./ a, hence the swapped operands.
        "ldivide" if handles.len() == 2 => provider.elem_div(&handles[1], &handles[0]).await,
        _ => return Ok(None),
    };

    match result {
        Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
        // Provider-side failure is not fatal: fall back to the host loop.
        Err(_) => Ok(None),
    }
}
723
/// Accumulator for `UniformOutput=true` results. Starts `Pending` (no element
/// seen yet), adopts the class of the first element, then promotes as later
/// elements require: logical → double → complex, and char → double (via the
/// character's code point) as soon as non-char results are mixed in.
enum UniformCollector {
    Pending,
    Double(Vec<f64>),
    Logical(Vec<u8>),
    Complex(Vec<(f64, f64)>),
    Char(Vec<char>),
}
731
impl UniformCollector {
    /// Fold one scalar callback result into the accumulator, promoting the
    /// stored element type when classes are mixed. Non-scalar or unsupported
    /// results are rejected by `classify_value`.
    fn push(&mut self, value: &Value) -> BuiltinResult<()> {
        match self {
            // First element decides the initial accumulator type.
            UniformCollector::Pending => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    *self = UniformCollector::Logical(vec![b as u8]);
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    *self = UniformCollector::Double(vec![d]);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    *self = UniformCollector::Complex(vec![c]);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    *self = UniformCollector::Char(vec![ch]);
                    Ok(())
                }
            },
            // Logical accumulator: promote to double/complex on mixing; a
            // char result also forces double (char → code point).
            UniformCollector::Logical(bits) => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    bits.push(b as u8);
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    let mut data: Vec<f64> = bits
                        .iter()
                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                        .collect();
                    data.push(d);
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    let mut data: Vec<(f64, f64)> = bits
                        .iter()
                        .map(|&bit| if bit != 0 { (1.0, 0.0) } else { (0.0, 0.0) })
                        .collect();
                    data.push(c);
                    *self = UniformCollector::Complex(data);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    let mut data: Vec<f64> = bits
                        .iter()
                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                        .collect();
                    data.push(ch as u32 as f64);
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
            },
            // Double accumulator: absorbs logical/char, promotes to complex.
            UniformCollector::Double(data) => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    data.push(if b { 1.0 } else { 0.0 });
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    data.push(d);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    // Widen all prior reals to (re, 0.0) pairs.
                    let promoted: Vec<(f64, f64)> = data.iter().map(|&v| (v, 0.0)).collect();
                    let mut complex = promoted;
                    complex.push(c);
                    *self = UniformCollector::Complex(complex);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    data.push(ch as u32 as f64);
                    Ok(())
                }
            },
            // Complex accumulator: terminal type; everything widens into it.
            UniformCollector::Complex(data) => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    data.push((if b { 1.0 } else { 0.0 }, 0.0));
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    data.push((d, 0.0));
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    data.push(c);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    data.push((ch as u32 as f64, 0.0));
                    Ok(())
                }
            },
            // Char accumulator: stays char only while every result is char;
            // any numeric/logical/complex result converts chars to code points.
            UniformCollector::Char(chars) => match classify_value(value)? {
                ClassifiedValue::Char(ch) => {
                    chars.push(ch);
                    Ok(())
                }
                ClassifiedValue::Logical(b) => {
                    let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
                    data.push(if b { 1.0 } else { 0.0 });
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
                    data.push(d);
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    let mut promoted: Vec<(f64, f64)> =
                        chars.iter().map(|&ch| (ch as u32 as f64, 0.0)).collect();
                    promoted.push(c);
                    *self = UniformCollector::Complex(promoted);
                    Ok(())
                }
            },
        }
    }

    /// Materialize the accumulated elements as a dense array of `shape`.
    /// Elements were pushed in MATLAB linear (column-major) order.
    fn finish(self, shape: &[usize]) -> BuiltinResult<Value> {
        match self {
            // Defensive: callers only build a collector when at least one
            // element will be pushed, but produce a zero-filled double array
            // if nothing arrived.
            UniformCollector::Pending => {
                let total = shape.iter().product();
                let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
                Ok(Value::Tensor(tensor))
            }
            UniformCollector::Double(data) => {
                let tensor = Tensor::new(data, shape.to_vec())
                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
                Ok(Value::Tensor(tensor))
            }
            UniformCollector::Logical(bits) => {
                let logical = LogicalArray::new(bits, shape.to_vec())
                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
                Ok(Value::LogicalArray(logical))
            }
            UniformCollector::Complex(entries) => {
                let tensor = ComplexTensor::new(entries, shape.to_vec())
                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
                Ok(Value::ComplexTensor(tensor))
            }
            // Char output: only 2-D shapes are representable; the elements
            // were collected column-major and must be reordered into the
            // row-major layout `CharArray` stores.
            UniformCollector::Char(chars) => {
                let normalized_shape = if shape.is_empty() {
                    vec![1, 1]
                } else {
                    shape.to_vec()
                };

                if normalized_shape.len() > 2 {
                    return Err(arrayfun_flow(
                        "arrayfun: character outputs with UniformOutput=true must be 2-D",
                    ));
                }

                let rows = normalized_shape.first().copied().unwrap_or(1);
                let cols = normalized_shape.get(1).copied().unwrap_or(1);
                let expected = rows.checked_mul(cols).ok_or_else(|| {
                    arrayfun_flow("arrayfun: character output size exceeds platform limits")
                })?;

                if expected != chars.len() {
                    return Err(arrayfun_flow(
                        "arrayfun: callback returned the wrong number of characters",
                    ));
                }

                // Transpose from column-major collection order to row-major.
                let mut row_major = vec!['\0'; expected];
                for col in 0..cols {
                    for row in 0..rows {
                        let col_major_idx = row + col * rows;
                        let row_major_idx = row * cols + col;
                        row_major[row_major_idx] = chars[col_major_idx];
                    }
                }

                let array = CharArray::new(row_major, rows, cols)
                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
                Ok(Value::CharArray(array))
            }
        }
    }
}
917
/// Scalar classification of a single callback result, used to drive the
/// `UniformCollector` type-promotion rules.
enum ClassifiedValue {
    Logical(bool),
    Double(f64),
    Complex((f64, f64)),
    Char(char),
}
924
925fn classify_value(value: &Value) -> BuiltinResult<ClassifiedValue> {
926 match value {
927 Value::Bool(b) => Ok(ClassifiedValue::Logical(*b)),
928 Value::LogicalArray(la) if la.len() == 1 => Ok(ClassifiedValue::Logical(la.data[0] != 0)),
929 Value::Int(i) => Ok(ClassifiedValue::Double(i.to_f64())),
930 Value::Num(n) => Ok(ClassifiedValue::Double(*n)),
931 Value::Tensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Double(t.data[0])),
932 Value::Complex(re, im) => Ok(ClassifiedValue::Complex((*re, *im))),
933 Value::ComplexTensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Complex(t.data[0])),
934 Value::CharArray(ca) if ca.rows * ca.cols == 1 => {
935 let ch = ca.data.first().copied().unwrap_or('\0');
936 Ok(ClassifiedValue::Char(ch))
937 }
938 other => Err(arrayfun_flow(format!(
939 "arrayfun: callback must return scalar numeric, logical, character, or complex values for UniformOutput=true (got {other:?})"
940 ))),
941 }
942}
943
944fn make_error_struct(
945 raw_error: &str,
946 linear_index: usize,
947 shape: &[usize],
948) -> BuiltinResult<Value> {
949 let (identifier, message) = split_error_message(raw_error);
950 let mut st = runmat_builtins::StructValue::new();
951 st.fields
952 .insert("identifier".to_string(), Value::String(identifier));
953 st.fields
954 .insert("message".to_string(), Value::String(message));
955 st.fields
956 .insert("index".to_string(), Value::Num((linear_index + 1) as f64));
957 let subs = linear_to_indices(linear_index, shape);
958 let subs_tensor = dims_to_row_tensor(&subs)?;
959 st.fields
960 .insert("indices".to_string(), Value::Tensor(subs_tensor));
961 Ok(Value::Struct(st))
962}
963
/// Split a raw error string into `(identifier, message)`.
///
/// Heuristics: with two or more colons, everything up to the second colon is
/// treated as a `Namespace:component`-style identifier and the remainder as
/// the message; with exactly one colon the whole string is kept as an
/// identifier when it starts with `matlab:`/`runmat:` (case-insensitive).
/// Otherwise a generic `RunMat:arrayfun:FunctionError` identifier is used and
/// the full text becomes the message.
///
/// Bug fix: the prefix test previously used `trimmed[..7]`, which panics when
/// byte 7 is not a UTF-8 char boundary (e.g. a multibyte character straddling
/// index 7). `str::get(..7)` returns `None` in that case instead.
fn split_error_message(raw: &str) -> (String, String) {
    let trimmed = raw.trim();
    let mut indices = trimmed.match_indices(':');
    if let Some((_, _)) = indices.next() {
        if let Some((second_idx, _)) = indices.next() {
            let identifier = trimmed[..second_idx].trim().to_string();
            let message = trimmed[second_idx + 1..].trim().to_string();
            if !identifier.is_empty() && identifier.contains(':') {
                return (
                    identifier,
                    if message.is_empty() {
                        // Keep the full text when nothing follows the colon.
                        trimmed.to_string()
                    } else {
                        message
                    },
                );
            }
        } else if trimmed.get(..7).is_some_and(|prefix| {
            prefix.eq_ignore_ascii_case("matlab:") || prefix.eq_ignore_ascii_case("runmat:")
        }) {
            return (trimmed.to_string(), String::new());
        }
    }
    (
        "RunMat:arrayfun:FunctionError".to_string(),
        trimmed.to_string(),
    )
}
993
/// Convert a 0-based MATLAB linear (column-major) index into 1-based
/// subscripts for `shape`. Zero-sized dimensions yield subscript 1 without
/// consuming any of the index; an empty shape yields `[1]`.
fn linear_to_indices(mut index: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![1];
    }
    shape
        .iter()
        .map(|&dim| {
            if dim == 0 {
                1
            } else {
                let coord = (index % dim) + 1;
                index /= dim;
                coord
            }
        })
        .collect()
}
1010
1011fn dims_to_row_tensor(dims: &[usize]) -> BuiltinResult<Tensor> {
1012 let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
1013 Tensor::new(data, vec![1, dims.len()]).map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))
1014}
1015
1016#[cfg(test)]
1017pub(crate) mod tests {
1018 use super::*;
1019 use crate::builtins::common::test_support;
1020 use futures::executor::block_on;
1021 use runmat_accelerate_api::HostTensorView;
1022 use runmat_builtins::{ResolveContext, Tensor, Type};
1023
1024 fn call(func: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
1025 block_on(arrayfun_builtin(func, rest))
1026 }
1027
1028 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1029 #[test]
1030 fn arrayfun_basic_sin() {
1031 let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
1032 let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
1033 let result = call(
1034 Value::FunctionHandle("sin".to_string()),
1035 vec![Value::Tensor(tensor.clone())],
1036 )
1037 .expect("arrayfun");
1038 match result {
1039 Value::Tensor(out) => {
1040 assert_eq!(out.shape, vec![2, 3]);
1041 assert_eq!(out.data, expected);
1042 }
1043 other => panic!("expected tensor, got {other:?}"),
1044 }
1045 }
1046
1047 #[test]
1048 fn arrayfun_type_tracks_function_returns() {
1049 let func = Type::Function {
1050 params: vec![Type::Num],
1051 returns: Box::new(Type::Num),
1052 };
1053 assert_eq!(
1054 arrayfun_type(&[func, Type::tensor()], &ResolveContext::new(Vec::new())),
1055 Type::tensor()
1056 );
1057 }
1058
1059 #[test]
1060 fn arrayfun_type_uses_logical_returns() {
1061 let func = Type::Function {
1062 params: vec![Type::Num],
1063 returns: Box::new(Type::Bool),
1064 };
1065 assert_eq!(
1066 arrayfun_type(&[func, Type::tensor()], &ResolveContext::new(Vec::new())),
1067 Type::logical()
1068 );
1069 }
1070
1071 #[test]
1072 fn arrayfun_type_with_text_args_stays_unknown() {
1073 let func = Type::Function {
1074 params: vec![Type::Num],
1075 returns: Box::new(Type::Num),
1076 };
1077 assert_eq!(
1078 arrayfun_type(
1079 &[func, Type::tensor(), Type::String, Type::Bool],
1080 &ResolveContext::new(Vec::new()),
1081 ),
1082 Type::Unknown
1083 );
1084 }
1085
1086 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1087 #[test]
1088 fn arrayfun_additional_scalar_argument() {
1089 let tensor = Tensor::new(vec![0.5, 1.0, -1.0], vec![3, 1]).unwrap();
1090 let expected: Vec<f64> = tensor.data.iter().map(|&y| y.atan2(1.0)).collect();
1091 let result = call(
1092 Value::FunctionHandle("atan2".to_string()),
1093 vec![Value::Tensor(tensor), Value::Num(1.0)],
1094 )
1095 .expect("arrayfun");
1096 match result {
1097 Value::Tensor(out) => {
1098 assert_eq!(out.data, expected);
1099 }
1100 other => panic!("expected tensor, got {other:?}"),
1101 }
1102 }
1103
1104 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1105 #[test]
1106 fn arrayfun_uniform_false_returns_cell() {
1107 let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1108 let expected: Vec<Value> = tensor.data.iter().map(|&x| Value::Num(x.sin())).collect();
1109 let result = call(
1110 Value::FunctionHandle("sin".to_string()),
1111 vec![
1112 Value::Tensor(tensor),
1113 Value::String("UniformOutput".into()),
1114 Value::Bool(false),
1115 ],
1116 )
1117 .expect("arrayfun");
1118 let Value::Cell(cell) = result else {
1119 panic!("expected cell, got something else");
1120 };
1121 assert_eq!(cell.rows, 2);
1122 assert_eq!(cell.cols, 1);
1123 for (row, value) in expected.iter().enumerate() {
1124 assert_eq!(cell.get(row, 0).unwrap(), *value);
1125 }
1126 }
1127
1128 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1129 #[test]
1130 fn arrayfun_size_mismatch_errors() {
1131 let taller = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1132 let shorter = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1133 let err = call(
1134 Value::FunctionHandle("sin".to_string()),
1135 vec![Value::Tensor(taller), Value::Tensor(shorter)],
1136 )
1137 .expect_err("expected size mismatch error");
1138 let err = err.to_string();
1139 assert!(
1140 err.contains("does not match"),
1141 "expected size mismatch error, got {err}"
1142 );
1143 }
1144
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    // When the callback fails for every element, the 'ErrorHandler' option supplies the
    // replacement values. The handler is `__arrayfun_test_handler` (registered at the
    // bottom of this module), which returns its first argument — here, presumably the
    // closure's captured 42.0 — so every output element becomes 42.0.
    fn arrayfun_error_handler_recovers() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        // Closure whose single capture seeds the handler's return value.
        let handler = Value::Closure(Closure {
            function_name: "__arrayfun_test_handler".into(),
            captures: vec![Value::Num(42.0)],
        });
        let result = call(
            // Deliberately unresolvable callback so the handler fires for each element.
            Value::String("@nonexistent_builtin".into()),
            vec![
                Value::Tensor(tensor),
                Value::String("ErrorHandler".into()),
                handler,
            ],
        )
        .expect("arrayfun error handler");
        match result {
            Value::Tensor(out) => {
                // Shape is preserved and every slot holds the handler's seed.
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![42.0, 42.0, 42.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }
1170
1171 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1172 #[test]
1173 fn arrayfun_error_without_handler_propagates_identifier() {
1174 let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
1175 let err = call(
1176 Value::String("@nonexistent_builtin".into()),
1177 vec![Value::Tensor(tensor)],
1178 )
1179 .expect_err("expected unresolved function error");
1180 assert_eq!(
1181 err.identifier(),
1182 Some("RunMat:UndefinedFunction"),
1183 "unexpected error: {}",
1184 err.message()
1185 );
1186 }
1187
1188 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1189 #[test]
1190 fn arrayfun_uniform_logical_result() {
1191 let tensor = Tensor::new(vec![1.0, f64::NAN, 0.0, f64::INFINITY], vec![4, 1]).unwrap();
1192 let result = call(
1193 Value::FunctionHandle("isfinite".to_string()),
1194 vec![Value::Tensor(tensor)],
1195 )
1196 .expect("arrayfun isfinite");
1197 match result {
1198 Value::LogicalArray(la) => {
1199 assert_eq!(la.shape, vec![4, 1]);
1200 assert_eq!(la.data, vec![1, 0, 1, 0]);
1201 }
1202 other => panic!("expected logical array, got {other:?}"),
1203 }
1204 }
1205
1206 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1207 #[test]
1208 fn arrayfun_uniform_character_result() {
1209 let tensor = Tensor::new(vec![65.0, 66.0, 67.0], vec![1, 3]).unwrap();
1210 let result = call(
1211 Value::FunctionHandle("char".to_string()),
1212 vec![Value::Tensor(tensor)],
1213 )
1214 .expect("arrayfun char");
1215 match result {
1216 Value::CharArray(ca) => {
1217 assert_eq!(ca.rows, 1);
1218 assert_eq!(ca.cols, 3);
1219 assert_eq!(ca.data, vec!['A', 'B', 'C']);
1220 }
1221 other => panic!("expected char array, got {other:?}"),
1222 }
1223 }
1224
1225 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1226 #[test]
1227 fn arrayfun_uniform_false_gpu_returns_cell() {
1228 test_support::with_test_provider(|provider| {
1229 let tensor = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
1230 let view = HostTensorView {
1231 data: &tensor.data,
1232 shape: &tensor.shape,
1233 };
1234 let handle = provider.upload(&view).expect("upload");
1235 let result = call(
1236 Value::FunctionHandle("sin".to_string()),
1237 vec![
1238 Value::GpuTensor(handle),
1239 Value::String("UniformOutput".into()),
1240 Value::Bool(false),
1241 ],
1242 )
1243 .expect("arrayfun");
1244 match result {
1245 Value::Cell(cell) => {
1246 assert_eq!(cell.rows, 2);
1247 assert_eq!(cell.cols, 1);
1248 let first = cell.get(0, 0).expect("first cell");
1249 let second = cell.get(1, 0).expect("second cell");
1250 match (first, second) {
1251 (Value::Num(a), Value::Num(b)) => {
1252 assert!((a - 0.0f64.sin()).abs() < 1e-12);
1253 assert!((b - 1.0f64.sin()).abs() < 1e-12);
1254 }
1255 other => panic!("expected numeric cells, got {other:?}"),
1256 }
1257 }
1258 other => panic!("expected cell, got {other:?}"),
1259 }
1260 });
1261 }
1262
1263 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1264 #[test]
1265 fn arrayfun_gpu_roundtrip() {
1266 test_support::with_test_provider(|provider| {
1267 let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
1268 let view = HostTensorView {
1269 data: &tensor.data,
1270 shape: &tensor.shape,
1271 };
1272 let handle = provider.upload(&view).expect("upload");
1273 let result = call(
1274 Value::FunctionHandle("sin".to_string()),
1275 vec![Value::GpuTensor(handle)],
1276 )
1277 .expect("arrayfun");
1278 match result {
1279 Value::GpuTensor(gpu) => {
1280 let gathered = test_support::gather(Value::GpuTensor(gpu.clone())).unwrap();
1281 let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
1282 assert_eq!(gathered.data, expected);
1283 let _ = provider.free(&gpu);
1284 }
1285 other => panic!("expected gpu tensor, got {other:?}"),
1286 }
1287 });
1288 }
1289
1290 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1291 #[test]
1292 #[cfg(feature = "wgpu")]
1293 fn arrayfun_wgpu_sin_matches_cpu() {
1294 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1295 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1296 );
1297 let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1298
1299 let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
1300 let view = HostTensorView {
1301 data: &tensor.data,
1302 shape: &tensor.shape,
1303 };
1304 let handle = provider.upload(&view).expect("upload");
1305 let result = call(
1306 Value::FunctionHandle("sin".into()),
1307 vec![Value::GpuTensor(handle.clone())],
1308 )
1309 .expect("arrayfun sin gpu");
1310 let Value::GpuTensor(out_handle) = result else {
1311 panic!("expected GPU tensor result");
1312 };
1313 let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
1314 let expected: Vec<f64> = tensor.data.iter().map(|v| v.sin()).collect();
1315 assert_eq!(gathered.shape, tensor.shape);
1316 let tol = match provider.precision() {
1317 runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
1318 runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
1319 };
1320 for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
1321 assert!(
1322 (actual - expect).abs() < tol,
1323 "expected {expect}, got {actual}"
1324 );
1325 }
1326 let _ = provider.free(&handle);
1327 let _ = provider.free(&out_handle);
1328 }
1329
1330 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1331 #[test]
1332 #[cfg(feature = "wgpu")]
1333 fn arrayfun_wgpu_plus_matches_cpu() {
1334 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1335 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1336 );
1337 let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1338
1339 let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1340 let b = Tensor::new(vec![4.0, 3.0, 2.0, 1.0], vec![2, 2]).unwrap();
1341 let view_a = HostTensorView {
1342 data: &a.data,
1343 shape: &a.shape,
1344 };
1345 let view_b = HostTensorView {
1346 data: &b.data,
1347 shape: &b.shape,
1348 };
1349 let handle_a = provider.upload(&view_a).expect("upload a");
1350 let handle_b = provider.upload(&view_b).expect("upload b");
1351 let result = call(
1352 Value::FunctionHandle("plus".into()),
1353 vec![
1354 Value::GpuTensor(handle_a.clone()),
1355 Value::GpuTensor(handle_b.clone()),
1356 ],
1357 )
1358 .expect("arrayfun plus gpu");
1359
1360 let Value::GpuTensor(out_handle) = result else {
1361 panic!("expected GPU tensor result");
1362 };
1363 let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
1364 let expected: Vec<f64> = a
1365 .data
1366 .iter()
1367 .zip(b.data.iter())
1368 .map(|(x, y)| x + y)
1369 .collect();
1370 assert_eq!(gathered.shape, a.shape);
1371 let tol = match provider.precision() {
1372 runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
1373 runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
1374 };
1375 for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
1376 assert!(
1377 (actual - expect).abs() < tol,
1378 "expected {expect}, got {actual}"
1379 );
1380 }
1381 let _ = provider.free(&handle_a);
1382 let _ = provider.free(&handle_b);
1383 let _ = provider.free(&out_handle);
1384 }
1385
    /// Minimal error-handler builtin used by `arrayfun_error_handler_recovers`.
    ///
    /// Arguments mirror what `arrayfun` passes to an `ErrorHandler`: `seed` —
    /// presumably the closure's first capture (42.0 in the test; verify against the
    /// arrayfun dispatch code) — then the error payload (`_err`, ignored) and the
    /// remaining per-element values (`rest`, ignored). Always returns `seed`, so the
    /// test can assert a constant fill value.
    #[runmat_macros::runtime_builtin(
        name = "__arrayfun_test_handler",
        type_resolver(arrayfun_type),
        builtin_path = "crate::builtins::acceleration::gpu::arrayfun::tests"
    )]
    async fn arrayfun_test_handler(
        seed: Value,
        _err: Value,
        rest: Vec<Value>,
    ) -> crate::BuiltinResult<Value> {
        // NOTE(review): `rest` is discarded explicitly rather than renamed `_rest`,
        // in case the `runtime_builtin` macro inspects parameter names — confirm.
        let _ = rest;
        Ok(seed)
    }
1399}