runmat_runtime/builtins/acceleration/gpu/arrayfun.rs

//! MATLAB-compatible `arrayfun` builtin with GPU-aware semantics.
//!
//! This implementation applies a scalar MATLAB function to every element of one
//! or more array inputs. When invoked with `gpuArray` inputs the builtin
//! dispatches a small set of supported callbacks directly to the active
//! provider; all other callbacks execute on the host, and uniform outputs are
//! uploaded back to the device so downstream code continues to see GPU
//! residency. Additional provider hooks can widen the device path without
//! affecting the public API.

use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::{
    gather_if_needed, make_cell_with_shape, register_builtin_fusion_spec, register_builtin_gpu_spec,
};
use runmat_accelerate_api::{set_handle_logical, GpuTensorHandle, HostTensorView};
use runmat_builtins::{CharArray, Closure, ComplexTensor, LogicalArray, Tensor, Value};
use runmat_macros::runtime_builtin;

#[cfg(feature = "doc_export")]
use crate::register_builtin_doc_text;

#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "arrayfun"
category: "acceleration/gpu"
keywords: ["arrayfun", "gpuArray", "elementwise map", "anonymous function", "uniformoutput"]
summary: "Apply a function to each element of array inputs, returning either a numeric array or a cell array."
references:
  - https://www.mathworks.com/help/parallel-computing/arrayfun.html
gpu_support:
  elementwise: true
  reduction: false
  precisions: ["f32", "f64"]
  broadcasting: "matlab"
  notes: "Executes directly on the GPU for supported builtin callbacks (sin, cos, abs, exp, log, sqrt, plus, minus, times, rdivide, ldivide) when all inputs are gpuArray values; falls back to host execution for closures, heterogeneous inputs, or unsupported callbacks. On the host path, uniform numeric/logical outputs are re-uploaded to the GPU; complex/character outputs stay on the host."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 1
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::acceleration::gpu::arrayfun::tests"
  integration: "builtins::acceleration::gpu::arrayfun::tests::arrayfun_gpu_roundtrip"
  doc: "builtins::acceleration::gpu::arrayfun::tests::arrayfun_doc_examples_present"
---

# What does the `arrayfun` function do in MATLAB / RunMat?
`arrayfun(func, A1, A2, …)` evaluates `func` for every element (or element-wise combination)
of the supplied arrays. The builtin mirrors MATLAB's behaviour:

- Inputs must have the same size. Scalars participate by broadcasting their single value.
- The optional `'UniformOutput'` name-value flag controls whether results are collected into a
  numeric/complex/logical/character array (`true`, the default) or returned as a cell array (`false`).
- When `'ErrorHandler', handler` is supplied the handler receives the error struct and the
  arguments that triggered the failure, letting you supply a fallback result.

## How does the `arrayfun` function behave in MATLAB / RunMat?
- Accepts function handles, builtin names (character vectors or string scalars), and closures.
- Supports additional scalar parameters: `arrayfun(@(x,c) x + c, A, 5)` passes `5` to every call.
- Honors the `'UniformOutput'` and `'ErrorHandler'` name-value pairs for MATLAB-compatible control flow.
- Handles numeric, logical, character, and complex arrays. Unsupported types raise a descriptive
  error instructing you to use `cellfun` when appropriate.
- Empty inputs return empty outputs whose shape matches the first array argument (see the sketch
  after this list).
- When any input is a `gpuArray`, numeric or logical uniform outputs are uploaded back to the GPU
  so downstream code retains GPU residency. Complex or character uniform outputs remain on the host
  until providers add the appropriate buffer support. Outside the GPU fast path the callback runs
  on the host and therefore inherits the host's floating-point behaviour.

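As a quick sketch of the empty-input rule (a hedged example; display formatting may differ):

```matlab
E = zeros(0, 3);
B = arrayfun(@(x) x.^2, E);
size(B)   % returns [0 3]; the empty shape of the first input is preserved
```
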
## `arrayfun` Function GPU Execution Behaviour
When every input is a `gpuArray`, `'UniformOutput'` is `true`, and the callback resolves to one of
the supported builtins (`sin`, `cos`, `abs`, `exp`, `log`, `sqrt`, `plus`, `minus`, `times`,
`rdivide`, or `ldivide`), RunMat bypasses the host path and dispatches directly to the active
provider through the corresponding hooks (`unary_*` or `elem_*`). The builtin acts as a fusion
barrier: the fusion planner lowers upstream producers before invoking `arrayfun` because the
callback can evaluate arbitrary MATLAB code.

All other combinations (closures, callbacks with extra scalar parameters, mixed residency, or
`'UniformOutput', false`) gather inputs to the host, execute the callback element-wise, and then
upload numeric or logical uniform results back to the GPU so later code continues with device
residency. Complex and character uniform outputs remain host-resident until device representations
are available. Cell outputs are always host-resident.

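For instance, a minimal sketch contrasting the two paths (`sqrt` is on the supported list; the
anonymous closure is not):

```matlab
G = gpuArray([1 4 9]);
direct = arrayfun(@sqrt, G);         % dispatched to the provider's unary_sqrt hook
fallback = arrayfun(@(x) x + 1, G);  % closure: gathered, evaluated on the host, re-uploaded
% Both results are still gpuArray values; only the execution site differs.
```
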
## Examples of using the `arrayfun` function in MATLAB / RunMat

### Squaring every element of a matrix
```matlab
A = [1 2 3; 4 5 6];
B = arrayfun(@(x) x.^2, A);
```
Expected output:
```matlab
B =
     1     4     9
    16    25    36
```

### Passing additional scalar parameters
```matlab
A = [1 2 3];
offset = 10;
result = arrayfun(@(x, c) x + c, A, offset);
```
Expected output:
```matlab
result =
    11    12    13
```

### Returning cells with non-uniform outputs
```matlab
strings = ["Run" "Matlab" "GPU"];
chars = arrayfun(@(s) sprintf("%d", strlength(s)), strings, 'UniformOutput', false);
```
Expected output:
```matlab
chars =
  1×3 cell array
    {'3'}    {'6'}    {'3'}
```

### Handling errors with a custom error handler
```matlab
vals = [-1 0 1];
handler = @(err, x) err.identifier;
safe = arrayfun(@(x) sqrt(x), vals, 'ErrorHandler', handler, 'UniformOutput', false);
```
Expected output:
```matlab
safe =
  1×3 cell array
    {'MATLAB:arrayfun:FunctionError'}    {[0]}    {[1]}
```

### Working with `gpuArray` inputs
```matlab
G = gpuArray(linspace(0, pi, 5));
S = arrayfun(@sin, G);
H = gather(S);
```
Expected output:
```matlab
S =
  1×5 gpuArray
         0    0.7071    1.0000    0.7071         0
H =
         0    0.7071    1.0000    0.7071         0
```

## GPU residency in RunMat (Do I need `gpuArray`?)
No. RunMat's auto-offload logic moves tensors to the GPU when profitable. If you do call
`gpuArray`, `arrayfun` keeps the result on the GPU for uniform numeric or logical outputs so later
operations can continue without gathering. Non-uniform or complex/character results stay on the
host until GPU representations are available.

## FAQ

### Do I have to call `gpuArray` before using `arrayfun`?
No. `arrayfun` participates in the same planner as other builtins, so the runtime migrates data to
the GPU when it determines a benefit. Manual `gpuArray` calls remain useful for MATLAB
compatibility or to force residency for custom workflows.

### What happens when the callback returns mixed types?
Set `'UniformOutput', false` so the builtin returns a cell array. When `'UniformOutput'` is `true`
every callback invocation must return a scalar numeric, logical, complex, or character value.

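A minimal sketch of the cell-array escape hatch:

```matlab
A = [1 2 3];
out = arrayfun(@(x) x * 2, A, 'UniformOutput', false);
% out is a 1×3 cell array: {[2]}    {[4]}    {[6]}
```
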
### Can `arrayfun` handle character inputs?
Yes. Each character element is passed to the callback as a single-character char array and the
output follows the normal uniform/non-uniform rules.

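For example, a short sketch that reads character codes (assuming the standard `double` builtin):

```matlab
codes = arrayfun(@double, 'abc');
% codes = [97 98 99]; each element reached the callback as a 1×1 char array
```
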
### Does `arrayfun` short-circuit on errors?
No. The builtin invokes the optional error handler when a callback fails. If no handler is
provided, the first error aborts the entire call with a MATLAB-compatible identifier/message pair.

### How are logical outputs represented on the GPU?
Logical results use 0.0/1.0 buffers on the device. When you gather them, RunMat converts the data
back into a logical array automatically.

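A short sketch of the round trip (`isfinite` uses the host path, and its logical result is
re-uploaded to the device):

```matlab
G = gpuArray([1 NaN 0 Inf]);
mask = arrayfun(@isfinite, G);  % device buffer holds 0.0/1.0 values
L = gather(mask);               % L is logical: [1 0 1 0]
```
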
## See Also
[cellfun](../../cells/core/cellfun), [gpuArray](./gpuarray), [gather](./gather)

## Source & Feedback
- Source code: [`crates/runmat-runtime/src/builtins/acceleration/gpu/arrayfun.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/acceleration/gpu/arrayfun.rs)
- Found an issue? Please [open a GitHub issue](https://github.com/runmat-org/runmat/issues/new/choose) with a repro.
"#;

pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "arrayfun",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        ProviderHook::Unary { name: "unary_sin" },
        ProviderHook::Unary { name: "unary_cos" },
        ProviderHook::Unary { name: "unary_abs" },
        ProviderHook::Unary { name: "unary_exp" },
        ProviderHook::Unary { name: "unary_log" },
        ProviderHook::Unary { name: "unary_sqrt" },
        ProviderHook::Binary {
            name: "elem_add",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_sub",
            commutative: false,
        },
        ProviderHook::Binary {
            name: "elem_mul",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_div",
            commutative: false,
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers that implement the listed kernels can run supported callbacks entirely on the GPU; unsupported callbacks fall back to the host path with re-upload.",
};

register_builtin_gpu_spec!(GPU_SPEC);

pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "arrayfun",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a fusion barrier because the callback can run arbitrary MATLAB code.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

#[cfg(feature = "doc_export")]
register_builtin_doc_text!("arrayfun", DOC_MD);

#[runtime_builtin(
    name = "arrayfun",
    category = "acceleration/gpu",
    summary = "Apply a function element-wise to array inputs.",
    keywords = "arrayfun,gpu,array,map,functional",
    accel = "host"
)]
fn arrayfun_builtin(func: Value, mut rest: Vec<Value>) -> Result<Value, String> {
    let callable = Callable::from_function(func)?;

    let mut uniform_output = true;
    let mut error_handler: Option<Callable> = None;

    // Peel `'Name', value` pairs off the tail of the argument list; stop at the
    // first candidate key that is not a string-like value.
    while rest.len() >= 2 {
        let key_candidate = rest[rest.len() - 2].clone();
        let Some(name) = extract_string(&key_candidate) else {
            break;
        };
        let value = rest.pop().expect("value present");
        rest.pop();
        match name.trim().to_ascii_lowercase().as_str() {
            "uniformoutput" => uniform_output = parse_uniform_output(value)?,
            "errorhandler" => error_handler = Some(Callable::from_function(value)?),
            other => return Err(format!("arrayfun: unknown name-value argument '{other}'")),
        }
    }

    if rest.is_empty() {
        return Err("arrayfun: expected at least one input array".to_string());
    }

    let inputs_snapshot = rest.clone();
    let has_gpu_input = inputs_snapshot
        .iter()
        .any(|value| matches!(value, Value::GpuTensor(_)));
    let gpu_device_id = inputs_snapshot.iter().find_map(|v| {
        if let Value::GpuTensor(h) = v {
            Some(h.device_id)
        } else {
            None
        }
    });

    if uniform_output {
        if let Some(gpu_result) =
            try_gpu_fast_path(&callable, &inputs_snapshot, error_handler.as_ref())?
        {
            return Ok(gpu_result);
        }
    }

    let mut inputs: Vec<ArrayInput> = Vec::with_capacity(rest.len());
    let mut base_shape: Vec<usize> = Vec::new();
    let mut base_len: Option<usize> = None;

    for (idx, raw) in rest.into_iter().enumerate() {
        if matches!(raw, Value::Cell(_)) {
            return Err(
                "arrayfun: cell inputs are not supported (use cellfun instead)".to_string(),
            );
        }
        if matches!(raw, Value::Struct(_)) {
            return Err("arrayfun: struct inputs are not supported".to_string());
        }

        let host_value = gather_if_needed(&raw)?;
        let data = ArrayData::from_value(host_value)?;
        let len = data.len();
        let is_scalar = len == 1;

        let mut input = ArrayInput { data, is_scalar };

        if let Some(current) = base_len {
            if current == len {
                if len > 1 {
                    let shape = input.shape_vec();
                    if shape != base_shape {
                        return Err(format!(
                            "arrayfun: input {} does not match the size of the first array",
                            idx + 1
                        ));
                    }
                }
            } else if len == 1 {
                input.is_scalar = true;
            } else if current == 1 {
                // Every input so far was scalar; adopt this input's shape as the
                // base and re-validate the earlier inputs against it.
                base_len = Some(len);
                base_shape = input.shape_vec();
                for (prior_idx, prior) in inputs.iter_mut().enumerate() {
                    let prior_len = prior.len();
                    if prior_len == len {
                        if prior.shape_vec() != base_shape {
                            return Err(format!(
                                "arrayfun: input {} does not match the size of the first array",
                                prior_idx + 1
                            ));
                        }
                    } else if prior_len == 1 {
                        prior.is_scalar = true;
                    } else if prior_len == 0 && len == 0 {
                        continue;
                    } else {
                        return Err(format!(
                            "arrayfun: input {} does not match the size of the first array",
                            prior_idx + 1
                        ));
                    }
                }
            } else if len == 0 && current == 0 {
                let shape = input.shape_vec();
                if shape != base_shape {
                    return Err(format!(
                        "arrayfun: input {} does not match the size of the first array",
                        idx + 1
                    ));
                }
            } else {
                return Err(format!(
                    "arrayfun: input {} does not match the size of the first array",
                    idx + 1
                ));
            }
        } else {
            base_len = Some(len);
            base_shape = input.shape_vec();
        }

        inputs.push(input);
    }

    let total_len = base_len.unwrap_or(0);

    if total_len == 0 {
        if uniform_output {
            return Ok(empty_uniform(&base_shape));
        } else {
            return make_cell_with_shape(Vec::new(), base_shape)
                .map_err(|e| format!("arrayfun: {e}"));
        }
    }

    let mut collector = if uniform_output {
        Some(UniformCollector::Pending)
    } else {
        None
    };

    let mut cell_outputs: Vec<Value> = Vec::new();
    let mut args: Vec<Value> = Vec::with_capacity(inputs.len());

    for idx in 0..total_len {
        args.clear();
        for input in &inputs {
            args.push(input.value_at(idx)?);
        }

        let result = match callable.call(&args) {
            Ok(value) => value,
            Err(err) => {
                // MATLAB semantics: the error handler receives the error struct
                // first, followed by the arguments of the failing call.
                let handler = error_handler
                    .as_ref()
                    .ok_or_else(|| format!("arrayfun: {err}"))?;
                let err_value = make_error_struct(&err, idx, &base_shape)?;
                let mut handler_args = Vec::with_capacity(1 + args.len());
                handler_args.push(err_value);
                handler_args.extend(args.clone());
                handler.call(&handler_args)?
            }
        };

        let host_result = gather_if_needed(&result)?;

        if let Some(collector) = collector.as_mut() {
            collector.push(&host_result)?;
        } else {
            cell_outputs.push(host_result);
        }
    }

    if let Some(collector) = collector {
        let uniform = collector.finish(&base_shape)?;
        maybe_upload_uniform(uniform, has_gpu_input, gpu_device_id)
    } else {
        make_cell_with_shape(cell_outputs, base_shape).map_err(|e| format!("arrayfun: {e}"))
    }
}

fn maybe_upload_uniform(
    value: Value,
    has_gpu_input: bool,
    gpu_device_id: Option<u32>,
) -> Result<Value, String> {
    if !has_gpu_input {
        return Ok(value);
    }
    #[cfg(all(test, feature = "wgpu"))]
    {
        if matches!(gpu_device_id, Some(id) if id != 0) {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let _ = gpu_device_id; // may be used only in cfg(test)
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(value),
    };

    match value {
        Value::Tensor(tensor) => {
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).map_err(|e| e.to_string())?;
            Ok(Value::GpuTensor(handle))
        }
        Value::LogicalArray(logical) => {
            // Logical results live on the device as 0.0/1.0 buffers; the handle
            // is tagged so `gather` converts them back into a logical array.
            let data: Vec<f64> = logical
                .data
                .iter()
                .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                .collect();
            let tensor =
                Tensor::new(data, logical.shape.clone()).map_err(|e| format!("arrayfun: {e}"))?;
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).map_err(|e| e.to_string())?;
            set_handle_logical(&handle, true);
            Ok(Value::GpuTensor(handle))
        }
        other => Ok(other),
    }
}

fn empty_uniform(shape: &[usize]) -> Value {
    if shape.is_empty() {
        return Value::Tensor(Tensor::zeros(vec![0, 0]));
    }
    let total: usize = shape.iter().product();
    let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
        .unwrap_or_else(|_| Tensor::zeros(shape.to_vec()));
    Value::Tensor(tensor)
}

fn parse_uniform_output(value: Value) -> Result<bool, String> {
    match value {
        Value::Bool(b) => Ok(b),
        Value::Num(n) => Ok(n != 0.0),
        Value::Int(iv) => Ok(iv.to_f64() != 0.0),
        Value::String(s) => parse_bool_string(&s)
            .ok_or_else(|| "arrayfun: UniformOutput must be logical true or false".to_string()),
        Value::CharArray(ca) if ca.rows == 1 => {
            let text: String = ca.data.iter().collect();
            parse_bool_string(&text)
                .ok_or_else(|| "arrayfun: UniformOutput must be logical true or false".to_string())
        }
        other => Err(format!(
            "arrayfun: UniformOutput must be logical true or false, got {other:?}"
        )),
    }
}

fn parse_bool_string(value: &str) -> Option<bool> {
    match value.trim().to_ascii_lowercase().as_str() {
        "true" | "on" => Some(true),
        "false" | "off" => Some(false),
        _ => None,
    }
}

fn extract_string(value: &Value) -> Option<String> {
    match value {
        Value::String(s) => Some(s.clone()),
        Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
        _ => None,
    }
}

struct ArrayInput {
    data: ArrayData,
    is_scalar: bool,
}

impl ArrayInput {
    fn len(&self) -> usize {
        self.data.len()
    }

    fn shape_vec(&self) -> Vec<usize> {
        self.data.shape_vec()
    }

    fn value_at(&self, idx: usize) -> Result<Value, String> {
        if self.is_scalar {
            self.data.value_at(0)
        } else {
            self.data.value_at(idx)
        }
    }
}

enum ArrayData {
    Tensor(Tensor),
    Logical(LogicalArray),
    Complex(ComplexTensor),
    Char(CharArray),
    Scalar(Value),
}

impl ArrayData {
    fn from_value(value: Value) -> Result<Self, String> {
        match value {
            Value::Tensor(t) => Ok(ArrayData::Tensor(t)),
            Value::LogicalArray(l) => Ok(ArrayData::Logical(l)),
            Value::ComplexTensor(c) => Ok(ArrayData::Complex(c)),
            Value::CharArray(ca) => Ok(ArrayData::Char(ca)),
            Value::Num(_) | Value::Bool(_) | Value::Int(_) | Value::Complex(_, _) => {
                Ok(ArrayData::Scalar(value))
            }
            other => Err(format!(
                "arrayfun: unsupported input type {other:?} (expected numeric, logical, complex, or char arrays)"
            )),
        }
    }

    fn len(&self) -> usize {
        match self {
            ArrayData::Tensor(t) => t.data.len(),
            ArrayData::Logical(l) => l.data.len(),
            ArrayData::Complex(c) => c.data.len(),
            ArrayData::Char(ca) => ca.rows * ca.cols,
            ArrayData::Scalar(_) => 1,
        }
    }

    fn shape_vec(&self) -> Vec<usize> {
        match self {
            ArrayData::Tensor(t) => {
                if t.shape.is_empty() {
                    vec![1, 1]
                } else {
                    t.shape.clone()
                }
            }
            ArrayData::Logical(l) => {
                if l.shape.is_empty() {
                    vec![1, 1]
                } else {
                    l.shape.clone()
                }
            }
            ArrayData::Complex(c) => {
                if c.shape.is_empty() {
                    vec![1, 1]
                } else {
                    c.shape.clone()
                }
            }
            ArrayData::Char(ca) => vec![ca.rows, ca.cols],
            ArrayData::Scalar(_) => vec![1, 1],
        }
    }

    fn value_at(&self, idx: usize) -> Result<Value, String> {
        match self {
            ArrayData::Tensor(t) => {
                Ok(Value::Num(*t.data.get(idx).ok_or_else(|| {
                    "arrayfun: index out of bounds".to_string()
                })?))
            }
            ArrayData::Logical(l) => Ok(Value::Bool(
                *l.data
                    .get(idx)
                    .ok_or_else(|| "arrayfun: index out of bounds".to_string())?
                    != 0,
            )),
            ArrayData::Complex(c) => {
                let (re, im) = c
                    .data
                    .get(idx)
                    .ok_or_else(|| "arrayfun: index out of bounds".to_string())?;
                Ok(Value::Complex(*re, *im))
            }
            ArrayData::Char(ca) => {
                if ca.rows == 0 || ca.cols == 0 {
                    return Ok(Value::CharArray(
                        CharArray::new(Vec::new(), 0, 0).map_err(|e| format!("arrayfun: {e}"))?,
                    ));
                }
                // `idx` is a column-major linear index; CharArray stores its
                // data row-major, so translate before indexing.
                let rows = ca.rows;
                let cols = ca.cols;
                let row = idx % rows;
                let col = idx / rows;
                let data_idx = row * cols + col;
                let ch = *ca
                    .data
                    .get(data_idx)
                    .ok_or_else(|| "arrayfun: index out of bounds".to_string())?;
                let char_array =
                    CharArray::new(vec![ch], 1, 1).map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::CharArray(char_array))
            }
            ArrayData::Scalar(v) => Ok(v.clone()),
        }
    }
}

#[derive(Clone)]
enum Callable {
    Builtin { name: String },
    Closure(Closure),
}

impl Callable {
    fn from_function(value: Value) -> Result<Self, String> {
        match value {
            Value::String(text) => Self::from_text(&text),
            Value::CharArray(ca) => {
                if ca.rows != 1 {
                    Err(
                        "arrayfun: function name must be a character vector or string scalar"
                            .to_string(),
                    )
                } else {
                    let text: String = ca.data.iter().collect();
                    Self::from_text(&text)
                }
            }
            Value::StringArray(sa) if sa.data.len() == 1 => Self::from_text(&sa.data[0]),
            Value::FunctionHandle(name) => Ok(Callable::Builtin { name }),
            Value::Closure(closure) => Ok(Callable::Closure(closure)),
            Value::Num(_) | Value::Int(_) | Value::Bool(_) => Err(
                "arrayfun: expected function handle or builtin name, not a scalar value"
                    .to_string(),
            ),
            other => Err(format!(
                "arrayfun: expected function handle or builtin name, got {other:?}"
            )),
        }
    }

    fn from_text(text: &str) -> Result<Self, String> {
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return Err(
                "arrayfun: expected function handle or builtin name, got empty string".to_string(),
            );
        }
        if let Some(rest) = trimmed.strip_prefix('@') {
            let name = rest.trim();
            if name.is_empty() {
                Err("arrayfun: empty function handle".to_string())
            } else {
                Ok(Callable::Builtin {
                    name: name.to_string(),
                })
            }
        } else {
            Ok(Callable::Builtin {
                name: trimmed.to_ascii_lowercase(),
            })
        }
    }

    fn builtin_name(&self) -> Option<&str> {
        match self {
            Callable::Builtin { name } => Some(name.as_str()),
            Callable::Closure(_) => None,
        }
    }

    fn call(&self, args: &[Value]) -> Result<Value, String> {
        match self {
            Callable::Builtin { name } => crate::call_builtin(name, args),
            Callable::Closure(c) => {
                // Captured values are prepended before the call arguments.
                let mut merged = c.captures.clone();
                merged.extend_from_slice(args);
                crate::call_builtin(&c.function_name, &merged)
            }
        }
    }
}

fn try_gpu_fast_path(
    callable: &Callable,
    inputs: &[Value],
    error_handler: Option<&Callable>,
) -> Result<Option<Value>, String> {
    if inputs.is_empty() || error_handler.is_some() {
        return Ok(None);
    }
    if !inputs
        .iter()
        .all(|value| matches!(value, Value::GpuTensor(_)))
    {
        return Ok(None);
    }

    #[cfg(all(test, feature = "wgpu"))]
    {
        if inputs
            .iter()
            .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
        {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(None),
    };

    let Some(name_raw) = callable.builtin_name() else {
        return Ok(None);
    };
    let name = name_raw.to_ascii_lowercase();

    let mut handles: Vec<GpuTensorHandle> = Vec::with_capacity(inputs.len());
    for value in inputs {
        if let Value::GpuTensor(handle) = value {
            handles.push(handle.clone());
        }
    }

    // The fast path only covers same-shape operands; broadcasting falls back
    // to the host path.
    if handles.len() >= 2 {
        let base_shape = handles[0].shape.clone();
        if handles
            .iter()
            .skip(1)
            .any(|handle| handle.shape != base_shape)
        {
            return Ok(None);
        }
    }

    let result = match name.as_str() {
        "sin" if handles.len() == 1 => provider.unary_sin(&handles[0]),
        "cos" if handles.len() == 1 => provider.unary_cos(&handles[0]),
        "abs" if handles.len() == 1 => provider.unary_abs(&handles[0]),
        "exp" if handles.len() == 1 => provider.unary_exp(&handles[0]),
        "log" if handles.len() == 1 => provider.unary_log(&handles[0]),
        "sqrt" if handles.len() == 1 => provider.unary_sqrt(&handles[0]),
        "plus" if handles.len() == 2 => provider.elem_add(&handles[0], &handles[1]),
        "minus" if handles.len() == 2 => provider.elem_sub(&handles[0], &handles[1]),
        "times" if handles.len() == 2 => provider.elem_mul(&handles[0], &handles[1]),
        "rdivide" if handles.len() == 2 => provider.elem_div(&handles[0], &handles[1]),
        // ldivide(a, b) == b ./ a, so the operands are swapped.
        "ldivide" if handles.len() == 2 => provider.elem_div(&handles[1], &handles[0]),
        _ => return Ok(None),
    };

    // Any provider error simply falls back to the host path.
    match result {
        Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
        Err(_) => Ok(None),
    }
}

/// Accumulates scalar callback results for `UniformOutput=true`. The collector
/// starts `Pending` and promotes its element type as mixed scalar types arrive
/// (logical widens to double, double to complex; chars mix into double).
enum UniformCollector {
    Pending,
    Double(Vec<f64>),
    Logical(Vec<u8>),
    Complex(Vec<(f64, f64)>),
    Char(Vec<char>),
}

impl UniformCollector {
    fn push(&mut self, value: &Value) -> Result<(), String> {
        match self {
            UniformCollector::Pending => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    *self = UniformCollector::Logical(vec![b as u8]);
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    *self = UniformCollector::Double(vec![d]);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    *self = UniformCollector::Complex(vec![c]);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    *self = UniformCollector::Char(vec![ch]);
                    Ok(())
                }
            },
            UniformCollector::Logical(bits) => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    bits.push(b as u8);
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    let mut data: Vec<f64> = bits
                        .iter()
                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                        .collect();
                    data.push(d);
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    let mut data: Vec<(f64, f64)> = bits
                        .iter()
                        .map(|&bit| if bit != 0 { (1.0, 0.0) } else { (0.0, 0.0) })
                        .collect();
                    data.push(c);
                    *self = UniformCollector::Complex(data);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    let mut data: Vec<f64> = bits
                        .iter()
                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                        .collect();
                    data.push(ch as u32 as f64);
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
            },
            UniformCollector::Double(data) => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    data.push(if b { 1.0 } else { 0.0 });
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    data.push(d);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    let promoted: Vec<(f64, f64)> = data.iter().map(|&v| (v, 0.0)).collect();
                    let mut complex = promoted;
                    complex.push(c);
                    *self = UniformCollector::Complex(complex);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    data.push(ch as u32 as f64);
                    Ok(())
                }
            },
            UniformCollector::Complex(data) => match classify_value(value)? {
                ClassifiedValue::Logical(b) => {
                    data.push((if b { 1.0 } else { 0.0 }, 0.0));
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    data.push((d, 0.0));
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    data.push(c);
                    Ok(())
                }
                ClassifiedValue::Char(ch) => {
                    data.push((ch as u32 as f64, 0.0));
                    Ok(())
                }
            },
            UniformCollector::Char(chars) => match classify_value(value)? {
                ClassifiedValue::Char(ch) => {
                    chars.push(ch);
                    Ok(())
                }
                ClassifiedValue::Logical(b) => {
                    let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
                    data.push(if b { 1.0 } else { 0.0 });
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
                ClassifiedValue::Double(d) => {
                    let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
                    data.push(d);
                    *self = UniformCollector::Double(data);
                    Ok(())
                }
                ClassifiedValue::Complex(c) => {
                    let mut promoted: Vec<(f64, f64)> =
                        chars.iter().map(|&ch| (ch as u32 as f64, 0.0)).collect();
                    promoted.push(c);
                    *self = UniformCollector::Complex(promoted);
                    Ok(())
                }
            },
        }
    }

    fn finish(self, shape: &[usize]) -> Result<Value, String> {
        match self {
            UniformCollector::Pending => {
                let total = shape.iter().product();
                let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
                    .map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::Tensor(tensor))
            }
            UniformCollector::Double(data) => {
                let tensor =
                    Tensor::new(data, shape.to_vec()).map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::Tensor(tensor))
            }
            UniformCollector::Logical(bits) => {
                let logical = LogicalArray::new(bits, shape.to_vec())
                    .map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::LogicalArray(logical))
            }
            UniformCollector::Complex(entries) => {
                let tensor = ComplexTensor::new(entries, shape.to_vec())
                    .map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::ComplexTensor(tensor))
            }
            UniformCollector::Char(chars) => {
                let normalized_shape = if shape.is_empty() {
                    vec![1, 1]
                } else {
                    shape.to_vec()
                };

                if normalized_shape.len() > 2 {
                    return Err(
                        "arrayfun: character outputs with UniformOutput=true must be 2-D"
                            .to_string(),
                    );
                }

                let rows = normalized_shape.first().copied().unwrap_or(1);
                let cols = normalized_shape.get(1).copied().unwrap_or(1);
                let expected = rows.checked_mul(cols).ok_or_else(|| {
                    "arrayfun: character output size exceeds platform limits".to_string()
                })?;

                if expected != chars.len() {
                    return Err(
                        "arrayfun: callback returned the wrong number of characters".to_string()
                    );
                }

                // Results were collected in column-major element order; CharArray
                // stores row-major, so transpose the layout here.
                let mut row_major = vec!['\0'; expected];
                for col in 0..cols {
                    for row in 0..rows {
                        let col_major_idx = row + col * rows;
                        let row_major_idx = row * cols + col;
                        row_major[row_major_idx] = chars[col_major_idx];
                    }
                }

                let array =
                    CharArray::new(row_major, rows, cols).map_err(|e| format!("arrayfun: {e}"))?;
                Ok(Value::CharArray(array))
            }
        }
    }
}

enum ClassifiedValue {
    Logical(bool),
    Double(f64),
    Complex((f64, f64)),
    Char(char),
}

fn classify_value(value: &Value) -> Result<ClassifiedValue, String> {
    match value {
        Value::Bool(b) => Ok(ClassifiedValue::Logical(*b)),
        Value::LogicalArray(la) if la.len() == 1 => Ok(ClassifiedValue::Logical(la.data[0] != 0)),
        Value::Int(i) => Ok(ClassifiedValue::Double(i.to_f64())),
        Value::Num(n) => Ok(ClassifiedValue::Double(*n)),
        Value::Tensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Double(t.data[0])),
        Value::Complex(re, im) => Ok(ClassifiedValue::Complex((*re, *im))),
        Value::ComplexTensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Complex(t.data[0])),
        Value::CharArray(ca) if ca.rows * ca.cols == 1 => {
            let ch = ca.data.first().copied().unwrap_or('\0');
            Ok(ClassifiedValue::Char(ch))
        }
        other => Err(format!(
            "arrayfun: callback must return scalar numeric, logical, character, or complex values for UniformOutput=true (got {other:?})"
        )),
    }
}

fn make_error_struct(
    raw_error: &str,
    linear_index: usize,
    shape: &[usize],
) -> Result<Value, String> {
    let (identifier, message) = split_error_message(raw_error);
    let mut st = runmat_builtins::StructValue::new();
    st.fields
        .insert("identifier".to_string(), Value::String(identifier));
    st.fields
        .insert("message".to_string(), Value::String(message));
    st.fields
        .insert("index".to_string(), Value::Num((linear_index + 1) as f64));
    let subs = linear_to_indices(linear_index, shape);
    let subs_tensor = dims_to_row_tensor(&subs)?;
    st.fields
        .insert("indices".to_string(), Value::Tensor(subs_tensor));
    Ok(Value::Struct(st))
}

fn split_error_message(raw: &str) -> (String, String) {
    let trimmed = raw.trim();
    let mut indices = trimmed.match_indices(':');
    if let Some((_, _)) = indices.next() {
        if let Some((second_idx, _)) = indices.next() {
            let identifier = trimmed[..second_idx].trim().to_string();
            let message = trimmed[second_idx + 1..].trim().to_string();
            if !identifier.is_empty() && identifier.contains(':') {
                return (
                    identifier,
                    if message.is_empty() {
                        trimmed.to_string()
                    } else {
                        message
                    },
                );
            }
        } else if trimmed.len() >= 7
            && (trimmed[..7].eq_ignore_ascii_case("matlab:")
                || trimmed[..7].eq_ignore_ascii_case("runmat:"))
        {
            return (trimmed.to_string(), String::new());
        }
    }
    (
        "MATLAB:arrayfun:FunctionError".to_string(),
        trimmed.to_string(),
    )
}

/// Convert a zero-based column-major linear index into one-based subscripts.
fn linear_to_indices(mut index: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![1];
    }
    let mut subs = Vec::with_capacity(shape.len());
    for &dim in shape {
        if dim == 0 {
            subs.push(1);
            continue;
        }
        let coord = (index % dim) + 1;
        subs.push(coord);
        index /= dim;
    }
    subs
}

fn dims_to_row_tensor(dims: &[usize]) -> Result<Tensor, String> {
    let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
    Tensor::new(data, vec![1, dims.len()]).map_err(|e| format!("arrayfun: {e}"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::Tensor;

    #[test]
    fn arrayfun_basic_sin() {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
        let result = arrayfun_builtin(
            Value::FunctionHandle("sin".to_string()),
            vec![Value::Tensor(tensor.clone())],
        )
        .expect("arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![2, 3]);
                assert_eq!(out.data, expected);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn arrayfun_additional_scalar_argument() {
        let tensor = Tensor::new(vec![0.5, 1.0, -1.0], vec![3, 1]).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|&y| y.atan2(1.0)).collect();
        let result = arrayfun_builtin(
            Value::FunctionHandle("atan2".to_string()),
            vec![Value::Tensor(tensor), Value::Num(1.0)],
        )
        .expect("arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.data, expected);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn arrayfun_uniform_false_returns_cell() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let expected: Vec<Value> = tensor.data.iter().map(|&x| Value::Num(x.sin())).collect();
        let result = arrayfun_builtin(
            Value::FunctionHandle("sin".to_string()),
            vec![
                Value::Tensor(tensor),
                Value::String("UniformOutput".into()),
                Value::Bool(false),
            ],
        )
        .expect("arrayfun");
        let Value::Cell(cell) = result else {
            panic!("expected cell, got something else");
        };
        assert_eq!(cell.rows, 2);
        assert_eq!(cell.cols, 1);
        for (row, value) in expected.iter().enumerate() {
            assert_eq!(cell.get(row, 0).unwrap(), *value);
        }
    }

    #[test]
    fn arrayfun_size_mismatch_errors() {
        let taller = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let shorter = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let err = arrayfun_builtin(
            Value::FunctionHandle("sin".to_string()),
            vec![Value::Tensor(taller), Value::Tensor(shorter)],
        )
        .expect_err("expected size mismatch error");
        assert!(
            err.contains("does not match"),
            "expected size mismatch error, got {err}"
        );
    }

    #[test]
    fn arrayfun_error_handler_recovers() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let handler = Value::Closure(Closure {
            function_name: "__arrayfun_test_handler".into(),
            captures: vec![Value::Num(42.0)],
        });
        let result = arrayfun_builtin(
            Value::String("@nonexistent_builtin".into()),
            vec![
                Value::Tensor(tensor),
                Value::String("ErrorHandler".into()),
                handler,
            ],
        )
        .expect("arrayfun error handler");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![42.0, 42.0, 42.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn arrayfun_error_without_handler_propagates_identifier() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = arrayfun_builtin(
            Value::String("@nonexistent_builtin".into()),
            vec![Value::Tensor(tensor)],
        )
        .expect_err("expected unresolved function error");
        assert!(
            err.contains("MATLAB:UndefinedFunction"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn arrayfun_uniform_logical_result() {
        let tensor = Tensor::new(vec![1.0, f64::NAN, 0.0, f64::INFINITY], vec![4, 1]).unwrap();
        let result = arrayfun_builtin(
            Value::FunctionHandle("isfinite".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("arrayfun isfinite");
        match result {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![4, 1]);
                assert_eq!(la.data, vec![1, 0, 1, 0]);
            }
            other => panic!("expected logical array, got {other:?}"),
        }
    }

    #[test]
    fn arrayfun_uniform_character_result() {
        let tensor = Tensor::new(vec![65.0, 66.0, 67.0], vec![1, 3]).unwrap();
        let result = arrayfun_builtin(
            Value::FunctionHandle("char".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("arrayfun char");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 3);
                assert_eq!(ca.data, vec!['A', 'B', 'C']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[test]
    fn arrayfun_uniform_false_gpu_returns_cell() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = arrayfun_builtin(
                Value::FunctionHandle("sin".to_string()),
                vec![
                    Value::GpuTensor(handle),
                    Value::String("UniformOutput".into()),
                    Value::Bool(false),
                ],
            )
            .expect("arrayfun");
            match result {
                Value::Cell(cell) => {
                    assert_eq!(cell.rows, 2);
                    assert_eq!(cell.cols, 1);
                    let first = cell.get(0, 0).expect("first cell");
                    let second = cell.get(1, 0).expect("second cell");
                    match (first, second) {
                        (Value::Num(a), Value::Num(b)) => {
                            assert!((a - 0.0f64.sin()).abs() < 1e-12);
                            assert!((b - 1.0f64.sin()).abs() < 1e-12);
                        }
                        other => panic!("expected numeric cells, got {other:?}"),
                    }
                }
                other => panic!("expected cell, got {other:?}"),
            }
        });
    }

    #[test]
    fn arrayfun_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = arrayfun_builtin(
                Value::FunctionHandle("sin".to_string()),
                vec![Value::GpuTensor(handle)],
            )
            .expect("arrayfun");
            match result {
                Value::GpuTensor(gpu) => {
                    let gathered = test_support::gather(Value::GpuTensor(gpu.clone())).unwrap();
                    let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
                    assert_eq!(gathered.data, expected);
                    let _ = provider.free(&gpu);
                }
                other => panic!("expected gpu tensor, got {other:?}"),
            }
        });
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn arrayfun_wgpu_sin_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result = arrayfun_builtin(
            Value::FunctionHandle("sin".into()),
            vec![Value::GpuTensor(handle.clone())],
        )
        .expect("arrayfun sin gpu");
        let Value::GpuTensor(out_handle) = result else {
            panic!("expected GPU tensor result");
        };
        let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|v| v.sin()).collect();
        assert_eq!(gathered.shape, tensor.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
            assert!(
                (actual - expect).abs() < tol,
                "expected {expect}, got {actual}"
            );
        }
        let _ = provider.free(&handle);
        let _ = provider.free(&out_handle);
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn arrayfun_wgpu_plus_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![4.0, 3.0, 2.0, 1.0], vec![2, 2]).unwrap();
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload a");
        let handle_b = provider.upload(&view_b).expect("upload b");
        let result = arrayfun_builtin(
            Value::FunctionHandle("plus".into()),
            vec![
                Value::GpuTensor(handle_a.clone()),
                Value::GpuTensor(handle_b.clone()),
            ],
        )
        .expect("arrayfun plus gpu");

        let Value::GpuTensor(out_handle) = result else {
            panic!("expected GPU tensor result");
        };
        let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
        let expected: Vec<f64> = a
            .data
            .iter()
            .zip(b.data.iter())
            .map(|(x, y)| x + y)
            .collect();
        assert_eq!(gathered.shape, a.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
            assert!(
                (actual - expect).abs() < tol,
                "expected {expect}, got {actual}"
            );
        }
        let _ = provider.free(&handle_a);
        let _ = provider.free(&handle_b);
        let _ = provider.free(&out_handle);
    }

    #[runmat_macros::runtime_builtin(name = "__arrayfun_test_handler")]
    fn arrayfun_test_handler(seed: Value, _err: Value, rest: Vec<Value>) -> Result<Value, String> {
        let _ = rest;
        Ok(seed)
    }

    #[cfg(feature = "doc_export")]
    #[test]
    fn arrayfun_doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(blocks.len() >= 5, "expected at least five doc examples");
    }
}