use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::{
    gather_if_needed, make_cell_with_shape, register_builtin_fusion_spec, register_builtin_gpu_spec,
};
use runmat_accelerate_api::{set_handle_logical, GpuTensorHandle, HostTensorView};
use runmat_builtins::{CharArray, Closure, ComplexTensor, LogicalArray, Tensor, Value};
use runmat_macros::runtime_builtin;

#[cfg(feature = "doc_export")]
use crate::register_builtin_doc_text;

#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "arrayfun"
category: "acceleration/gpu"
keywords: ["arrayfun", "gpuArray", "elementwise map", "anonymous function", "uniformoutput"]
summary: "Apply a function to each element of array inputs, returning either a numeric array or a cell array."
references:
  - https://www.mathworks.com/help/parallel-computing/arrayfun.html
gpu_support:
  elementwise: true
  reduction: false
  precisions: ["f32", "f64"]
  broadcasting: "matlab"
  notes: "Executes directly on the GPU for supported builtin callbacks (sin, cos, abs, exp, log, sqrt, plus, minus, times, rdivide, ldivide) when all inputs are gpuArray values; falls back to host execution for closures, heterogeneous inputs, or unsupported callbacks. Uniform numeric/logical outputs computed on the host are re-uploaded to the GPU; complex/character outputs stay on the host."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 1
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::acceleration::gpu::arrayfun::tests"
  integration: "builtins::acceleration::gpu::arrayfun::tests::arrayfun_gpu_roundtrip"
  doc: "builtins::acceleration::gpu::arrayfun::tests::arrayfun_doc_examples_present"
---

# What does the `arrayfun` function do in MATLAB / RunMat?
`arrayfun(func, A1, A2, …)` evaluates `func` for every element (or element-wise combination)
of the supplied arrays. The builtin mirrors MATLAB's behaviour:

- Inputs must have the same size. Scalars participate by broadcasting their single value.
- The optional `'UniformOutput'` name-value flag controls whether results are collected into a
  numeric/complex/logical/character array (`true`, the default) or returned as a cell array (`false`).
- When `'ErrorHandler', handler` is supplied, the handler receives the error struct and the
  arguments that triggered the failure, letting you supply a fallback result.

## How does the `arrayfun` function behave in MATLAB / RunMat?
- Accepts function handles, builtin names (character vectors or string scalars), and closures.
- Supports additional scalar parameters: `arrayfun(@(x,c) x + c, A, 5)` passes `5` to every call.
- Honours the `'UniformOutput'` and `'ErrorHandler'` name-value pairs for MATLAB-compatible control flow.
- Handles numeric, logical, character, and complex arrays. Unsupported types raise a descriptive
  error instructing you to use `cellfun` when appropriate.
- Empty inputs return empty outputs whose shape matches the first array argument.
- When any input is a `gpuArray`, numeric or logical uniform outputs are uploaded back to the GPU
  so downstream code retains GPU residency. Complex or character uniform outputs remain on the host
  until providers add the appropriate buffer support. The host fallback path computes element-wise
  on the CPU and therefore inherits the host's floating-point behaviour.

## `arrayfun` Function GPU Execution Behaviour
When every input is a `gpuArray`, `'UniformOutput'` is `true`, and the callback resolves to one of
the supported builtins (`sin`, `cos`, `abs`, `exp`, `log`, `sqrt`, `plus`, `minus`, `times`,
`rdivide`, or `ldivide`), RunMat bypasses the host path and dispatches directly to the active
provider through the corresponding hooks (`unary_*` or `elem_*`). The builtin acts as a fusion
barrier: the fusion planner lowers upstream producers before invoking `arrayfun` because the
callback can evaluate arbitrary MATLAB code.

All other combinations (closures, callbacks with extra scalar parameters, mixed residency,
or `'UniformOutput', false`) gather inputs to the host, execute the callback element-wise, and then
upload numeric or logical uniform results back to the GPU so later code continues with device
residency. Complex and character uniform outputs remain host-resident until device representations
are available. Cell outputs are always host-resident.
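
For example (a sketch assuming an active provider; the exact dispatch depends on the registered
backend), a supported builtin callback stays on the device while a closure takes the host path:

```matlab
G = gpuArray([1 2 3 4]);
fast = arrayfun(@sqrt, G);       % resolves to unary_sqrt and runs on the device
slow = arrayfun(@(x) x + 1, G);  % closure: computed on the host, result re-uploaded
```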

## Examples of using the `arrayfun` function in MATLAB / RunMat

### Squaring every element of a matrix
```matlab
A = [1 2 3; 4 5 6];
B = arrayfun(@(x) x.^2, A);
```
Expected output:
```matlab
B =
     1     4     9
    16    25    36
```

### Passing additional scalar parameters
```matlab
A = [1 2 3];
offset = 10;
result = arrayfun(@(x, c) x + c, A, offset);
```
Expected output:
```matlab
result =
    11    12    13
```

### Returning cells with non-uniform outputs
```matlab
strings = ["Run" "Matlab" "GPU"];
chars = arrayfun(@(s) sprintf('%d', strlength(s)), strings, 'UniformOutput', false);
```
Expected output:
```matlab
chars =
  1×3 cell array
    {'3'}    {'6'}    {'3'}
```

### Handling errors with a custom error handler
```matlab
vals = [-1 0 1];
handler = @(err, x) err.identifier;
safe = arrayfun(@(x) sqrt(x), vals, 'ErrorHandler', handler, 'UniformOutput', false);
```
Expected output:
```matlab
safe =
  1×3 cell array
    {'MATLAB:arrayfun:FunctionError'}    {[0]}    {[1]}
```

### Working with `gpuArray` inputs
```matlab
G = gpuArray(linspace(0, pi, 5));
S = arrayfun(@sin, G);
H = gather(S);
```
Expected output:
```matlab
S =
  1×5 gpuArray
         0    0.7071    1.0000    0.7071         0
H =
         0    0.7071    1.0000    0.7071         0
```

## GPU residency in RunMat (Do I need `gpuArray`?)
No. RunMat's auto-offload logic moves tensors to the GPU when profitable. If you do call
`gpuArray`, `arrayfun` keeps the result on the GPU for uniform numeric or logical outputs so later
operations can continue without gathering. Non-uniform or complex/character results stay on the
host until GPU representations are available.

## FAQ

### Do I have to call `gpuArray` before using `arrayfun`?
No. `arrayfun` participates in the same planner as other builtins, so the runtime migrates data to
the GPU when it determines a benefit. Manual `gpuArray` calls remain useful for MATLAB
compatibility or to force residency for custom workflows.

### What happens when the callback returns mixed types?
Set `'UniformOutput', false` so the builtin returns a cell array, as sketched below. When
`'UniformOutput'` is `true`, every callback invocation must return a numeric, logical, character,
or complex scalar.
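
For instance, boxing non-scalar results in cells (a small sketch):

```matlab
out = arrayfun(@(x) x * [1 1], [1 2], 'UniformOutput', false);
% out = { [1 1], [2 2] }
```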

### Can `arrayfun` handle character inputs?
Yes. Each character element is passed to the callback as a single-character char array and the
output follows the normal uniform/non-uniform rules.
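
A quick sketch (assuming `upper` is available as a builtin in your build):

```matlab
C = arrayfun(@upper, 'runmat');   % each element arrives as a 1×1 char array
% C = 'RUNMAT'
```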

### Does `arrayfun` short-circuit on errors?
No. The builtin invokes the optional error handler when a callback fails. If no handler is
provided, the first error aborts the entire call with a MATLAB-compatible identifier/message pair.

### How are logical outputs represented on the GPU?
Logical results use 0.0/1.0 buffers on the device. When you gather them, RunMat converts the data
back into a logical array automatically.
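
A sketch of the round trip (exact behaviour depends on the active provider):

```matlab
mask = arrayfun(@(x) x > 1, gpuArray([0 2 3]));  % logical result re-uploaded to the device
tf = gather(mask);                               % converted back to a logical array: [0 1 1]
```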

## See Also
[cellfun](../../cells/core/cellfun), [gpuArray](./gpuarray), [gather](./gather)

## Source & Feedback
- Source code: [`crates/runmat-runtime/src/builtins/acceleration/gpu/arrayfun.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/acceleration/gpu/arrayfun.rs)
- Found an issue? Please [open a GitHub issue](https://github.com/runmat-org/runmat/issues/new/choose) with a repro.
"#;

pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "arrayfun",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        ProviderHook::Unary { name: "unary_sin" },
        ProviderHook::Unary { name: "unary_cos" },
        ProviderHook::Unary { name: "unary_abs" },
        ProviderHook::Unary { name: "unary_exp" },
        ProviderHook::Unary { name: "unary_log" },
        ProviderHook::Unary { name: "unary_sqrt" },
        ProviderHook::Binary {
            name: "elem_add",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_sub",
            commutative: false,
        },
        ProviderHook::Binary {
            name: "elem_mul",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_div",
            commutative: false,
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers that implement the listed kernels can run supported callbacks entirely on the GPU; unsupported callbacks fall back to the host path with re-upload.",
};

register_builtin_gpu_spec!(GPU_SPEC);

pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "arrayfun",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a fusion barrier because the callback can run arbitrary MATLAB code.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

#[cfg(feature = "doc_export")]
register_builtin_doc_text!("arrayfun", DOC_MD);

#[runtime_builtin(
    name = "arrayfun",
    category = "acceleration/gpu",
    summary = "Apply a function element-wise to array inputs.",
    keywords = "arrayfun,gpu,array,map,functional",
    accel = "host"
)]
fn arrayfun_builtin(func: Value, mut rest: Vec<Value>) -> Result<Value, String> {
    let callable = Callable::from_function(func)?;

    let mut uniform_output = true;
    let mut error_handler: Option<Callable> = None;

    while rest.len() >= 2 {
        let key_candidate = rest[rest.len() - 2].clone();
        let Some(name) = extract_string(&key_candidate) else {
            break;
        };
        let value = rest.pop().expect("value present");
        rest.pop();
        match name.trim().to_ascii_lowercase().as_str() {
            "uniformoutput" => uniform_output = parse_uniform_output(value)?,
            "errorhandler" => error_handler = Some(Callable::from_function(value)?),
            other => return Err(format!("arrayfun: unknown name-value argument '{other}'")),
        }
    }

    if rest.is_empty() {
        return Err("arrayfun: expected at least one input array".to_string());
    }

    let inputs_snapshot = rest.clone();
    let has_gpu_input = inputs_snapshot
        .iter()
        .any(|value| matches!(value, Value::GpuTensor(_)));
    let gpu_device_id = inputs_snapshot.iter().find_map(|v| {
        if let Value::GpuTensor(h) = v {
            Some(h.device_id)
        } else {
            None
        }
    });

    if uniform_output {
        if let Some(gpu_result) =
            try_gpu_fast_path(&callable, &inputs_snapshot, error_handler.as_ref())?
        {
            return Ok(gpu_result);
        }
    }

    let mut inputs: Vec<ArrayInput> = Vec::with_capacity(rest.len());
    let mut base_shape: Vec<usize> = Vec::new();
    let mut base_len: Option<usize> = None;

    for (idx, raw) in rest.into_iter().enumerate() {
        if matches!(raw, Value::Cell(_)) {
            return Err(
                "arrayfun: cell inputs are not supported (use cellfun instead)".to_string(),
            );
        }
        if matches!(raw, Value::Struct(_)) {
            return Err("arrayfun: struct inputs are not supported".to_string());
        }

        let host_value = gather_if_needed(&raw)?;
        let data = ArrayData::from_value(host_value)?;
        let len = data.len();
        let is_scalar = len == 1;

        let mut input = ArrayInput { data, is_scalar };

        if let Some(current) = base_len {
            if current == len {
                if len > 1 {
                    let shape = input.shape_vec();
                    if shape != base_shape {
                        return Err(format!(
                            "arrayfun: input {} does not match the size of the first array",
                            idx + 1
                        ));
                    }
                }
            } else if len == 1 {
                input.is_scalar = true;
            } else if current == 1 {
                base_len = Some(len);
                base_shape = input.shape_vec();
                for (prior_idx, prior) in inputs.iter_mut().enumerate() {
                    let prior_len = prior.len();
                    if prior_len == len {
                        if prior.shape_vec() != base_shape {
                            return Err(format!(
                                "arrayfun: input {} does not match the size of the first array",
                                prior_idx + 1
                            ));
                        }
                    } else if prior_len == 1 {
                        prior.is_scalar = true;
                    } else if prior_len == 0 && len == 0 {
                        continue;
                    } else {
                        return Err(format!(
                            "arrayfun: input {} does not match the size of the first array",
                            prior_idx + 1
                        ));
                    }
                }
            } else if len == 0 && current == 0 {
                let shape = input.shape_vec();
                if shape != base_shape {
                    return Err(format!(
                        "arrayfun: input {} does not match the size of the first array",
                        idx + 1
                    ));
                }
            } else {
                return Err(format!(
                    "arrayfun: input {} does not match the size of the first array",
                    idx + 1
                ));
            }
        } else {
            base_len = Some(len);
            base_shape = input.shape_vec();
        }

        inputs.push(input);
    }

    let total_len = base_len.unwrap_or(0);

    if total_len == 0 {
        if uniform_output {
            return Ok(empty_uniform(&base_shape));
        } else {
            return make_cell_with_shape(Vec::new(), base_shape)
                .map_err(|e| format!("arrayfun: {e}"));
        }
    }

    let mut collector = if uniform_output {
        Some(UniformCollector::Pending)
    } else {
        None
    };

    let mut cell_outputs: Vec<Value> = Vec::new();
    let mut args: Vec<Value> = Vec::with_capacity(inputs.len());

    for idx in 0..total_len {
        args.clear();
        for input in &inputs {
            args.push(input.value_at(idx)?);
        }

        let result = match callable.call(&args) {
            Ok(value) => value,
            Err(err) => {
                let handler = error_handler
                    .as_ref()
                    .ok_or_else(|| format!("arrayfun: {err}"))?;
                let err_value = make_error_struct(&err, idx, &base_shape)?;
                let mut handler_args = Vec::with_capacity(1 + args.len());
                handler_args.push(err_value);
                handler_args.extend(args.clone());
                handler.call(&handler_args)?
            }
        };

        let host_result = gather_if_needed(&result)?;

        if let Some(collector) = collector.as_mut() {
            collector.push(&host_result)?;
        } else {
            cell_outputs.push(host_result);
        }
    }

    if let Some(collector) = collector {
        let uniform = collector.finish(&base_shape)?;
        maybe_upload_uniform(uniform, has_gpu_input, gpu_device_id)
    } else {
        make_cell_with_shape(cell_outputs, base_shape).map_err(|e| format!("arrayfun: {e}"))
    }
}

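/// Re-uploads a uniform host result to the GPU when any input was device-resident.
/// Logical results travel as 0.0/1.0 buffers and are tagged via `set_handle_logical`.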
fn maybe_upload_uniform(
    value: Value,
    has_gpu_input: bool,
    gpu_device_id: Option<u32>,
) -> Result<Value, String> {
    if !has_gpu_input {
        return Ok(value);
    }
    #[cfg(all(test, feature = "wgpu"))]
    {
        if matches!(gpu_device_id, Some(id) if id != 0) {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let _ = gpu_device_id;
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(value),
    };

    match value {
        Value::Tensor(tensor) => {
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).map_err(|e| e.to_string())?;
            Ok(Value::GpuTensor(handle))
        }
        Value::LogicalArray(logical) => {
            let data: Vec<f64> = logical
                .data
                .iter()
                .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                .collect();
            let tensor =
                Tensor::new(data, logical.shape.clone()).map_err(|e| format!("arrayfun: {e}"))?;
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).map_err(|e| e.to_string())?;
            set_handle_logical(&handle, true);
            Ok(Value::GpuTensor(handle))
        }
        other => Ok(other),
    }
}

fn empty_uniform(shape: &[usize]) -> Value {
    if shape.is_empty() {
        return Value::Tensor(Tensor::zeros(vec![0, 0]));
    }
    let total: usize = shape.iter().product();
    let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
        .unwrap_or_else(|_| Tensor::zeros(shape.to_vec()));
    Value::Tensor(tensor)
}

fn parse_uniform_output(value: Value) -> Result<bool, String> {
    match value {
        Value::Bool(b) => Ok(b),
        Value::Num(n) => Ok(n != 0.0),
        Value::Int(iv) => Ok(iv.to_f64() != 0.0),
        Value::String(s) => parse_bool_string(&s)
            .ok_or_else(|| "arrayfun: UniformOutput must be logical true or false".to_string()),
        Value::CharArray(ca) if ca.rows == 1 => {
            let text: String = ca.data.iter().collect();
            parse_bool_string(&text)
                .ok_or_else(|| "arrayfun: UniformOutput must be logical true or false".to_string())
        }
        other => Err(format!(
            "arrayfun: UniformOutput must be logical true or false, got {other:?}"
        )),
    }
}

fn parse_bool_string(value: &str) -> Option<bool> {
    match value.trim().to_ascii_lowercase().as_str() {
        "true" | "on" => Some(true),
        "false" | "off" => Some(false),
        _ => None,
    }
}

fn extract_string(value: &Value) -> Option<String> {
    match value {
        Value::String(s) => Some(s.clone()),
        Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
        _ => None,
    }
}

struct ArrayInput {
    data: ArrayData,
    is_scalar: bool,
}

impl ArrayInput {
    fn len(&self) -> usize {
        self.data.len()
    }

    fn shape_vec(&self) -> Vec<usize> {
        self.data.shape_vec()
    }

    fn value_at(&self, idx: usize) -> Result<Value, String> {
        if self.is_scalar {
            self.data.value_at(0)
        } else {
            self.data.value_at(idx)
        }
    }
}

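/// Host-side storage for a single input argument, providing length, shape, and
/// per-element access during the element-wise loop.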
enum ArrayData {
    Tensor(Tensor),
    Logical(LogicalArray),
    Complex(ComplexTensor),
    Char(CharArray),
    Scalar(Value),
}

impl ArrayData {
    fn from_value(value: Value) -> Result<Self, String> {
        match value {
            Value::Tensor(t) => Ok(ArrayData::Tensor(t)),
            Value::LogicalArray(l) => Ok(ArrayData::Logical(l)),
            Value::ComplexTensor(c) => Ok(ArrayData::Complex(c)),
            Value::CharArray(ca) => Ok(ArrayData::Char(ca)),
            Value::Num(_) | Value::Bool(_) | Value::Int(_) | Value::Complex(_, _) => {
                Ok(ArrayData::Scalar(value))
            }
            other => Err(format!(
                "arrayfun: unsupported input type {other:?} (expected numeric, logical, complex, or char arrays)"
            )),
        }
    }

    fn len(&self) -> usize {
        match self {
            ArrayData::Tensor(t) => t.data.len(),
            ArrayData::Logical(l) => l.data.len(),
            ArrayData::Complex(c) => c.data.len(),
            ArrayData::Char(ca) => ca.rows * ca.cols,
            ArrayData::Scalar(_) => 1,
        }
    }

    fn shape_vec(&self) -> Vec<usize> {
        match self {
            ArrayData::Tensor(t) => {
                if t.shape.is_empty() {
                    vec![1, 1]
                } else {
                    t.shape.clone()
                }
            }
            ArrayData::Logical(l) => {
                if l.shape.is_empty() {
                    vec![1, 1]
                } else {
                    l.shape.clone()
                }
            }
            ArrayData::Complex(c) => {
                if c.shape.is_empty() {
                    vec![1, 1]
                } else {
                    c.shape.clone()
                }
            }
            ArrayData::Char(ca) => vec![ca.rows, ca.cols],
            ArrayData::Scalar(_) => vec![1, 1],
        }
    }

    fn value_at(&self, idx: usize) -> Result<Value, String> {
        match self {
            ArrayData::Tensor(t) => {
                Ok(Value::Num(*t.data.get(idx).ok_or_else(|| {
                    "arrayfun: index out of bounds".to_string()
                })?))
            }
            ArrayData::Logical(l) => Ok(Value::Bool(
                *l.data
                    .get(idx)
                    .ok_or_else(|| "arrayfun: index out of bounds".to_string())?
                    != 0,
            )),
            ArrayData::Complex(c) => {
                let (re, im) = c
                    .data
                    .get(idx)
                    .ok_or_else(|| "arrayfun: index out of bounds".to_string())?;
                Ok(Value::Complex(*re, *im))
            }
            ArrayData::Char(ca) => {
                if ca.rows == 0 || ca.cols == 0 {
                    return Ok(Value::CharArray(
                        CharArray::new(Vec::new(), 0, 0).map_err(|e| format!("arrayfun: {e}"))?,
                    ));
                }
                let rows = ca.rows;
                let cols = ca.cols;
                let row = idx % rows;
                let col = idx / rows;
                let data_idx = row * cols + col;
                let ch = *ca
                    .data
                    .get(data_idx)
                    .ok_or_else(|| "arrayfun: index out of bounds".to_string())?;
                let char_array =
                    CharArray::new(vec![ch], 1, 1).map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::CharArray(char_array))
            }
            ArrayData::Scalar(v) => Ok(v.clone()),
        }
    }
}

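/// A resolved callback target: either a named builtin or a closure with captures.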
#[derive(Clone)]
enum Callable {
    Builtin { name: String },
    Closure(Closure),
}

impl Callable {
    fn from_function(value: Value) -> Result<Self, String> {
        match value {
            Value::String(text) => Self::from_text(&text),
            Value::CharArray(ca) => {
                if ca.rows != 1 {
                    Err(
                        "arrayfun: function name must be a character vector or string scalar"
                            .to_string(),
                    )
                } else {
                    let text: String = ca.data.iter().collect();
                    Self::from_text(&text)
                }
            }
            Value::StringArray(sa) if sa.data.len() == 1 => Self::from_text(&sa.data[0]),
            Value::FunctionHandle(name) => Ok(Callable::Builtin { name }),
            Value::Closure(closure) => Ok(Callable::Closure(closure)),
            Value::Num(_) | Value::Int(_) | Value::Bool(_) => Err(
                "arrayfun: expected function handle or builtin name, not a scalar value"
                    .to_string(),
            ),
            other => Err(format!(
                "arrayfun: expected function handle or builtin name, got {other:?}"
            )),
        }
    }

    fn from_text(text: &str) -> Result<Self, String> {
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return Err(
                "arrayfun: expected function handle or builtin name, got empty string".to_string(),
            );
        }
        if let Some(rest) = trimmed.strip_prefix('@') {
            let name = rest.trim();
            if name.is_empty() {
                Err("arrayfun: empty function handle".to_string())
            } else {
                Ok(Callable::Builtin {
                    name: name.to_string(),
                })
            }
        } else {
            Ok(Callable::Builtin {
                name: trimmed.to_ascii_lowercase(),
            })
        }
    }

    fn builtin_name(&self) -> Option<&str> {
        match self {
            Callable::Builtin { name } => Some(name.as_str()),
            Callable::Closure(_) => None,
        }
    }

    fn call(&self, args: &[Value]) -> Result<Value, String> {
        match self {
            Callable::Builtin { name } => crate::call_builtin(name, args),
            Callable::Closure(c) => {
                let mut merged = c.captures.clone();
                merged.extend_from_slice(args);
                crate::call_builtin(&c.function_name, &merged)
            }
        }
    }
}

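/// Attempts to execute the callback entirely on the device. Returns `Ok(None)` whenever
/// the inputs, callback, or active provider do not qualify, signalling the caller to take
/// the host fallback path.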
fn try_gpu_fast_path(
    callable: &Callable,
    inputs: &[Value],
    error_handler: Option<&Callable>,
) -> Result<Option<Value>, String> {
    if inputs.is_empty() || error_handler.is_some() {
        return Ok(None);
    }
    if !inputs
        .iter()
        .all(|value| matches!(value, Value::GpuTensor(_)))
    {
        return Ok(None);
    }

    #[cfg(all(test, feature = "wgpu"))]
    {
        if inputs
            .iter()
            .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
        {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(None),
    };

    let Some(name_raw) = callable.builtin_name() else {
        return Ok(None);
    };
    let name = name_raw.to_ascii_lowercase();

    let mut handles: Vec<GpuTensorHandle> = Vec::with_capacity(inputs.len());
    for value in inputs {
        if let Value::GpuTensor(handle) = value {
            handles.push(handle.clone());
        }
    }

    if handles.len() >= 2 {
        let base_shape = handles[0].shape.clone();
        if handles
            .iter()
            .skip(1)
            .any(|handle| handle.shape != base_shape)
        {
            return Ok(None);
        }
    }

    let result = match name.as_str() {
        "sin" if handles.len() == 1 => provider.unary_sin(&handles[0]),
        "cos" if handles.len() == 1 => provider.unary_cos(&handles[0]),
        "abs" if handles.len() == 1 => provider.unary_abs(&handles[0]),
        "exp" if handles.len() == 1 => provider.unary_exp(&handles[0]),
        "log" if handles.len() == 1 => provider.unary_log(&handles[0]),
        "sqrt" if handles.len() == 1 => provider.unary_sqrt(&handles[0]),
        "plus" if handles.len() == 2 => provider.elem_add(&handles[0], &handles[1]),
        "minus" if handles.len() == 2 => provider.elem_sub(&handles[0], &handles[1]),
        "times" if handles.len() == 2 => provider.elem_mul(&handles[0], &handles[1]),
        "rdivide" if handles.len() == 2 => provider.elem_div(&handles[0], &handles[1]),
        "ldivide" if handles.len() == 2 => provider.elem_div(&handles[1], &handles[0]),
        _ => return Ok(None),
    };

    match result {
        Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
        Err(_) => Ok(None),
    }
}

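/// Accumulates uniform scalar results, promoting the collected element type
/// (logical -> double -> complex; char mixes convert to double or complex) as values arrive.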
enum UniformCollector {
    Pending,
    Double(Vec<f64>),
    Logical(Vec<u8>),
    Complex(Vec<(f64, f64)>),
    Char(Vec<char>),
}

impl UniformCollector {
    fn push(&mut self, value: &Value) -> Result<(), String> {
        match self {
            UniformCollector::Pending => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    *self = UniformCollector::Logical(vec![b as u8]);
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    *self = UniformCollector::Double(vec![d]);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    *self = UniformCollector::Complex(vec![c]);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    *self = UniformCollector::Char(vec![ch]);
                    Ok(())
                }
            },
            UniformCollector::Logical(bits) => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    bits.push(b as u8);
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    let mut data: Vec<f64> = bits
                        .iter()
                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                        .collect();
                    data.push(d);
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    let mut data: Vec<(f64, f64)> = bits
                        .iter()
                        .map(|&bit| if bit != 0 { (1.0, 0.0) } else { (0.0, 0.0) })
                        .collect();
                    data.push(c);
                    *self = UniformCollector::Complex(data);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    let mut data: Vec<f64> = bits
                        .iter()
                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                        .collect();
                    data.push(ch as u32 as f64);
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
            },
            UniformCollector::Double(data) => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    data.push(if b { 1.0 } else { 0.0 });
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    data.push(d);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    let promoted: Vec<(f64, f64)> = data.iter().map(|&v| (v, 0.0)).collect();
                    let mut complex = promoted;
                    complex.push(c);
                    *self = UniformCollector::Complex(complex);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    data.push(ch as u32 as f64);
                    Ok(())
                }
            },
            UniformCollector::Complex(data) => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    data.push((if b { 1.0 } else { 0.0 }, 0.0));
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    data.push((d, 0.0));
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    data.push(c);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    data.push((ch as u32 as f64, 0.0));
                    Ok(())
                }
            },
            UniformCollector::Char(chars) => match classify_value(value)? {
                ClassifiedValue::Char(ch) => {
                    chars.push(ch);
                    Ok(())
                }
                ClassifiedValue::Logical(b) => {
                    let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
                    data.push(if b { 1.0 } else { 0.0 });
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
                    data.push(d);
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    let mut promoted: Vec<(f64, f64)> =
                        chars.iter().map(|&ch| (ch as u32 as f64, 0.0)).collect();
                    promoted.push(c);
                    *self = UniformCollector::Complex(promoted);
                    Ok(())
                }
            },
        }
    }

    fn finish(self, shape: &[usize]) -> Result<Value, String> {
        match self {
            UniformCollector::Pending => {
                let total = shape.iter().product();
                let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
                    .map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::Tensor(tensor))
            }
            UniformCollector::Double(data) => {
                let tensor =
                    Tensor::new(data, shape.to_vec()).map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::Tensor(tensor))
            }
            UniformCollector::Logical(bits) => {
                let logical = LogicalArray::new(bits, shape.to_vec())
                    .map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::LogicalArray(logical))
            }
            UniformCollector::Complex(entries) => {
                let tensor = ComplexTensor::new(entries, shape.to_vec())
                    .map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::ComplexTensor(tensor))
            }
            UniformCollector::Char(chars) => {
                let normalized_shape = if shape.is_empty() {
                    vec![1, 1]
                } else {
                    shape.to_vec()
                };

                if normalized_shape.len() > 2 {
                    return Err(
                        "arrayfun: character outputs with UniformOutput=true must be 2-D"
                            .to_string(),
                    );
                }

                let rows = normalized_shape.first().copied().unwrap_or(1);
                let cols = normalized_shape.get(1).copied().unwrap_or(1);
                let expected = rows.checked_mul(cols).ok_or_else(|| {
                    "arrayfun: character output size exceeds platform limits".to_string()
                })?;

                if expected != chars.len() {
                    return Err(
                        "arrayfun: callback returned the wrong number of characters".to_string()
                    );
                }

                let mut row_major = vec!['\0'; expected];
                for col in 0..cols {
                    for row in 0..rows {
                        let col_major_idx = row + col * rows;
                        let row_major_idx = row * cols + col;
                        row_major[row_major_idx] = chars[col_major_idx];
                    }
                }

                let array =
                    CharArray::new(row_major, rows, cols).map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::CharArray(array))
            }
        }
    }
}

enum ClassifiedValue {
    Logical(bool),
    Double(f64),
    Complex((f64, f64)),
    Char(char),
}

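/// Classifies a scalar callback result for uniform collection; non-scalar values are
/// rejected with a descriptive error.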
fn classify_value(value: &Value) -> Result<ClassifiedValue, String> {
    match value {
        Value::Bool(b) => Ok(ClassifiedValue::Logical(*b)),
        Value::LogicalArray(la) if la.len() == 1 => Ok(ClassifiedValue::Logical(la.data[0] != 0)),
        Value::Int(i) => Ok(ClassifiedValue::Double(i.to_f64())),
        Value::Num(n) => Ok(ClassifiedValue::Double(*n)),
        Value::Tensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Double(t.data[0])),
        Value::Complex(re, im) => Ok(ClassifiedValue::Complex((*re, *im))),
        Value::ComplexTensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Complex(t.data[0])),
        Value::CharArray(ca) if ca.rows * ca.cols == 1 => {
            let ch = ca.data.first().copied().unwrap_or('\0');
            Ok(ClassifiedValue::Char(ch))
        }
        other => Err(format!(
            "arrayfun: callback must return scalar numeric, logical, character, or complex values for UniformOutput=true (got {other:?})"
        )),
    }
}

fn make_error_struct(
    raw_error: &str,
    linear_index: usize,
    shape: &[usize],
) -> Result<Value, String> {
    let (identifier, message) = split_error_message(raw_error);
    let mut st = runmat_builtins::StructValue::new();
    st.fields
        .insert("identifier".to_string(), Value::String(identifier));
    st.fields
        .insert("message".to_string(), Value::String(message));
    st.fields
        .insert("index".to_string(), Value::Num((linear_index + 1) as f64));
    let subs = linear_to_indices(linear_index, shape);
    let subs_tensor = dims_to_row_tensor(&subs)?;
    st.fields
        .insert("indices".to_string(), Value::Tensor(subs_tensor));
    Ok(Value::Struct(st))
}

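/// Splits a raw error string into a MATLAB-style `(identifier, message)` pair, falling
/// back to `MATLAB:arrayfun:FunctionError` when no identifier can be recovered.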
fn split_error_message(raw: &str) -> (String, String) {
    let trimmed = raw.trim();
    let mut indices = trimmed.match_indices(':');
    if let Some((_, _)) = indices.next() {
        if let Some((second_idx, _)) = indices.next() {
            let identifier = trimmed[..second_idx].trim().to_string();
            let message = trimmed[second_idx + 1..].trim().to_string();
            if !identifier.is_empty() && identifier.contains(':') {
                return (
                    identifier,
                    if message.is_empty() {
                        trimmed.to_string()
                    } else {
                        message
                    },
                );
            }
        } else if trimmed.len() >= 7
            && (trimmed[..7].eq_ignore_ascii_case("matlab:")
                || trimmed[..7].eq_ignore_ascii_case("runmat:"))
        {
            return (trimmed.to_string(), String::new());
        }
    }
    (
        "MATLAB:arrayfun:FunctionError".to_string(),
        trimmed.to_string(),
    )
}

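/// Converts a 0-based column-major linear index into 1-based subscripts for `shape`.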
fn linear_to_indices(mut index: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![1];
    }
    let mut subs = Vec::with_capacity(shape.len());
    for &dim in shape {
        if dim == 0 {
            subs.push(1);
            continue;
        }
        let coord = (index % dim) + 1;
        subs.push(coord);
        index /= dim;
    }
    subs
}

fn dims_to_row_tensor(dims: &[usize]) -> Result<Tensor, String> {
    let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
    Tensor::new(data, vec![1, dims.len()]).map_err(|e| format!("arrayfun: {e}"))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::Tensor;

    #[test]
    fn arrayfun_basic_sin() {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
        let result = arrayfun_builtin(
            Value::FunctionHandle("sin".to_string()),
            vec![Value::Tensor(tensor.clone())],
        )
        .expect("arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![2, 3]);
                assert_eq!(out.data, expected);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn arrayfun_additional_scalar_argument() {
        let tensor = Tensor::new(vec![0.5, 1.0, -1.0], vec![3, 1]).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|&y| y.atan2(1.0)).collect();
        let result = arrayfun_builtin(
            Value::FunctionHandle("atan2".to_string()),
            vec![Value::Tensor(tensor), Value::Num(1.0)],
        )
        .expect("arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.data, expected);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn arrayfun_uniform_false_returns_cell() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let expected: Vec<Value> = tensor.data.iter().map(|&x| Value::Num(x.sin())).collect();
        let result = arrayfun_builtin(
            Value::FunctionHandle("sin".to_string()),
            vec![
                Value::Tensor(tensor),
                Value::String("UniformOutput".into()),
                Value::Bool(false),
            ],
        )
        .expect("arrayfun");
        let Value::Cell(cell) = result else {
            panic!("expected cell, got something else");
        };
        assert_eq!(cell.rows, 2);
        assert_eq!(cell.cols, 1);
        for (row, value) in expected.iter().enumerate() {
            assert_eq!(cell.get(row, 0).unwrap(), *value);
        }
    }

    #[test]
    fn arrayfun_size_mismatch_errors() {
        let taller = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let shorter = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let err = arrayfun_builtin(
            Value::FunctionHandle("sin".to_string()),
            vec![Value::Tensor(taller), Value::Tensor(shorter)],
        )
        .expect_err("expected size mismatch error");
        assert!(
            err.contains("does not match"),
            "expected size mismatch error, got {err}"
        );
    }

    #[test]
    fn arrayfun_error_handler_recovers() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let handler = Value::Closure(Closure {
            function_name: "__arrayfun_test_handler".into(),
            captures: vec![Value::Num(42.0)],
        });
        let result = arrayfun_builtin(
            Value::String("@nonexistent_builtin".into()),
            vec![
                Value::Tensor(tensor),
                Value::String("ErrorHandler".into()),
                handler,
            ],
        )
        .expect("arrayfun error handler");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![42.0, 42.0, 42.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn arrayfun_error_without_handler_propagates_identifier() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = arrayfun_builtin(
            Value::String("@nonexistent_builtin".into()),
            vec![Value::Tensor(tensor)],
        )
        .expect_err("expected unresolved function error");
        assert!(
            err.contains("MATLAB:UndefinedFunction"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn arrayfun_uniform_logical_result() {
        let tensor = Tensor::new(vec![1.0, f64::NAN, 0.0, f64::INFINITY], vec![4, 1]).unwrap();
        let result = arrayfun_builtin(
            Value::FunctionHandle("isfinite".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("arrayfun isfinite");
        match result {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![4, 1]);
                assert_eq!(la.data, vec![1, 0, 1, 0]);
            }
            other => panic!("expected logical array, got {other:?}"),
        }
    }

    #[test]
    fn arrayfun_uniform_character_result() {
        let tensor = Tensor::new(vec![65.0, 66.0, 67.0], vec![1, 3]).unwrap();
        let result = arrayfun_builtin(
            Value::FunctionHandle("char".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("arrayfun char");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 3);
                assert_eq!(ca.data, vec!['A', 'B', 'C']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[test]
    fn arrayfun_uniform_false_gpu_returns_cell() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = arrayfun_builtin(
                Value::FunctionHandle("sin".to_string()),
                vec![
                    Value::GpuTensor(handle),
                    Value::String("UniformOutput".into()),
                    Value::Bool(false),
                ],
            )
            .expect("arrayfun");
            match result {
                Value::Cell(cell) => {
                    assert_eq!(cell.rows, 2);
                    assert_eq!(cell.cols, 1);
                    let first = cell.get(0, 0).expect("first cell");
                    let second = cell.get(1, 0).expect("second cell");
                    match (first, second) {
                        (Value::Num(a), Value::Num(b)) => {
                            assert!((a - 0.0f64.sin()).abs() < 1e-12);
                            assert!((b - 1.0f64.sin()).abs() < 1e-12);
                        }
                        other => panic!("expected numeric cells, got {other:?}"),
                    }
                }
                other => panic!("expected cell, got {other:?}"),
            }
        });
    }

    #[test]
    fn arrayfun_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = arrayfun_builtin(
                Value::FunctionHandle("sin".to_string()),
                vec![Value::GpuTensor(handle)],
            )
            .expect("arrayfun");
            match result {
                Value::GpuTensor(gpu) => {
                    let gathered = test_support::gather(Value::GpuTensor(gpu.clone())).unwrap();
                    let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
                    assert_eq!(gathered.data, expected);
                    let _ = provider.free(&gpu);
                }
                other => panic!("expected gpu tensor, got {other:?}"),
            }
        });
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn arrayfun_wgpu_sin_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result = arrayfun_builtin(
            Value::FunctionHandle("sin".into()),
            vec![Value::GpuTensor(handle.clone())],
        )
        .expect("arrayfun sin gpu");
        let Value::GpuTensor(out_handle) = result else {
            panic!("expected GPU tensor result");
        };
        let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|v| v.sin()).collect();
        assert_eq!(gathered.shape, tensor.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
            assert!(
                (actual - expect).abs() < tol,
                "expected {expect}, got {actual}"
            );
        }
        let _ = provider.free(&handle);
        let _ = provider.free(&out_handle);
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn arrayfun_wgpu_plus_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![4.0, 3.0, 2.0, 1.0], vec![2, 2]).unwrap();
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload a");
        let handle_b = provider.upload(&view_b).expect("upload b");
        let result = arrayfun_builtin(
            Value::FunctionHandle("plus".into()),
            vec![
                Value::GpuTensor(handle_a.clone()),
                Value::GpuTensor(handle_b.clone()),
            ],
        )
        .expect("arrayfun plus gpu");

        let Value::GpuTensor(out_handle) = result else {
            panic!("expected GPU tensor result");
        };
        let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
        let expected: Vec<f64> = a
            .data
            .iter()
            .zip(b.data.iter())
            .map(|(x, y)| x + y)
            .collect();
        assert_eq!(gathered.shape, a.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
            assert!(
                (actual - expect).abs() < tol,
                "expected {expect}, got {actual}"
            );
        }
        let _ = provider.free(&handle_a);
        let _ = provider.free(&handle_b);
        let _ = provider.free(&out_handle);
    }

    #[runmat_macros::runtime_builtin(name = "__arrayfun_test_handler")]
    fn arrayfun_test_handler(seed: Value, _err: Value, rest: Vec<Value>) -> Result<Value, String> {
        let _ = rest;
        Ok(seed)
    }

    #[cfg(feature = "doc_export")]
    #[test]
    fn arrayfun_doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(blocks.len() >= 5, "expected at least five doc examples");
    }
}