Skip to main content

runmat_runtime/builtins/acceleration/gpu/
arrayfun.rs

1//! MATLAB-compatible `arrayfun` builtin with GPU-aware semantics.
2//!
3//! This implementation supports applying a scalar MATLAB function to every element
4//! of one or more array inputs. When invoked with `gpuArray` inputs the builtin
5//! executes on the host today and uploads the uniform output back to the device so
6//! downstream code continues to see GPU residency. Future provider hooks can swap
7//! in a device kernel without affecting the public API.
8
9use crate::builtins::acceleration::gpu::type_resolvers::arrayfun_type;
10use crate::builtins::common::spec::{
11    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
13};
14use crate::{
15    build_runtime_error, gather_if_needed_async, make_cell_with_shape, user_functions,
16    BuiltinResult, RuntimeError,
17};
18use runmat_accelerate_api::{set_handle_logical, GpuTensorHandle, HostTensorView};
19use runmat_builtins::{
20    CharArray, Closure, ComplexTensor, LogicalArray, StringArray, Tensor, Value,
21};
22use runmat_macros::runtime_builtin;
23
/// Static GPU capability descriptor for `arrayfun`, registered with the
/// acceleration layer at startup via `register_gpu_spec`.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::arrayfun")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "arrayfun",
    // Declared elementwise so providers can map supported callbacks onto
    // existing elementwise kernels instead of round-tripping via the host.
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // Kernels a provider may implement; `try_gpu_fast_path` dispatches to the
    // corresponding provider method when the callback is one of these builtins.
    provider_hooks: &[
        ProviderHook::Unary { name: "unary_sin" },
        ProviderHook::Unary { name: "unary_cos" },
        ProviderHook::Unary { name: "unary_abs" },
        ProviderHook::Unary { name: "unary_exp" },
        ProviderHook::Unary { name: "unary_log" },
        ProviderHook::Unary { name: "unary_sqrt" },
        ProviderHook::Binary {
            name: "elem_add",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_sub",
            commutative: false,
        },
        ProviderHook::Binary {
            name: "elem_mul",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_div",
            commutative: false,
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Results are materialized into a fresh device handle rather than
    // mutating an input in place.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers that implement the listed kernels can run supported callbacks entirely on the GPU; unsupported callbacks fall back to the host path with re-upload.",
};
62
/// Fusion descriptor for `arrayfun`. With `elementwise: None` and
/// `reduction: None` the builtin participates in no fusion plans — the
/// callback may run arbitrary MATLAB code, so it must act as a barrier.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::acceleration::gpu::arrayfun"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "arrayfun",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a fusion barrier because the callback can run arbitrary MATLAB code.",
};
75
76fn arrayfun_error(message: impl Into<String>) -> RuntimeError {
77    build_runtime_error(message)
78        .with_builtin("arrayfun")
79        .build()
80}
81
82fn arrayfun_error_with_source(message: impl Into<String>, source: RuntimeError) -> RuntimeError {
83    let identifier = source.identifier().map(str::to_string);
84    let mut builder = build_runtime_error(message)
85        .with_builtin("arrayfun")
86        .with_source(source);
87    if let Some(identifier) = identifier {
88        builder = builder.with_identifier(identifier);
89    }
90    builder.build()
91}
92
/// Control-flow error constructor. Currently identical to `arrayfun_error`;
/// kept as a separate seam so "user-facing flow errors" can later diverge
/// from internal errors without touching call sites.
fn arrayfun_flow(message: impl Into<String>) -> RuntimeError {
    arrayfun_error(message)
}
96
/// Source-carrying variant of `arrayfun_flow`; see `arrayfun_flow` for why
/// this delegation exists.
fn arrayfun_flow_with_source(message: impl Into<String>, source: RuntimeError) -> RuntimeError {
    arrayfun_error_with_source(message, source)
}
100
101fn format_handler_error(err: &RuntimeError) -> String {
102    if let Some(identifier) = err.identifier() {
103        if err.message().is_empty() {
104            return identifier.to_string();
105        }
106        if err.message().starts_with(identifier) {
107            return err.message().to_string();
108        }
109        return format!("{identifier}: {}", err.message());
110    }
111    err.message().to_string()
112}
113
/// MATLAB-compatible `arrayfun` entry point.
///
/// `func` is the callback (handle, closure, or builtin name); `rest` holds
/// the input arrays followed by optional trailing name-value pairs
/// (`'UniformOutput'`, `'ErrorHandler'`). Returns either a uniform array
/// (default) or a cell array (`UniformOutput = false`). When every input is
/// a `gpuArray` and the callback maps onto a provider kernel, execution
/// stays on the device; otherwise elements are gathered and evaluated on
/// the host one at a time.
#[runtime_builtin(
    name = "arrayfun",
    category = "acceleration/gpu",
    summary = "Apply a function element-wise to array inputs.",
    keywords = "arrayfun,gpu,array,map,functional",
    accel = "host",
    type_resolver(arrayfun_type),
    builtin_path = "crate::builtins::acceleration::gpu::arrayfun"
)]
async fn arrayfun_builtin(func: Value, mut rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    let callable = Callable::from_function(func)?;

    // MATLAB defaults: uniform (numeric/logical/char) output, no handler.
    let mut uniform_output = true;
    let mut error_handler: Option<Callable> = None;

    // Peel name-value pairs off the END of the argument list. The loop stops
    // at the first trailing pair whose key is not string-like, leaving the
    // actual array inputs in `rest`.
    while rest.len() >= 2 {
        let key_candidate = rest[rest.len() - 2].clone();
        let Some(name) = extract_string(&key_candidate) else {
            break;
        };
        let value = rest.pop().expect("value present");
        rest.pop();
        match name.trim().to_ascii_lowercase().as_str() {
            "uniformoutput" => uniform_output = parse_uniform_output(value)?,
            "errorhandler" => error_handler = Some(Callable::from_function(value)?),
            other => {
                return Err(arrayfun_flow(format!(
                    "arrayfun: unknown name-value argument '{other}'"
                )))
            }
        }
    }

    if rest.is_empty() {
        return Err(arrayfun_flow("arrayfun: expected at least one input array"));
    }

    // Record GPU residency BEFORE gathering, so a uniform result can be
    // re-uploaded and downstream code continues to see a gpuArray.
    let inputs_snapshot = rest.clone();
    let has_gpu_input = inputs_snapshot
        .iter()
        .any(|value| matches!(value, Value::GpuTensor(_)));
    let gpu_device_id = inputs_snapshot.iter().find_map(|v| {
        if let Value::GpuTensor(h) = v {
            Some(h.device_id)
        } else {
            None
        }
    });

    // All-GPU inputs + recognized builtin callback may run entirely on the
    // device; only valid for uniform output (the fast path returns a handle).
    if uniform_output {
        if let Some(gpu_result) =
            try_gpu_fast_path(&callable, &inputs_snapshot, error_handler.as_ref()).await?
        {
            return Ok(gpu_result);
        }
    }

    let mut inputs: Vec<ArrayInput> = Vec::with_capacity(rest.len());
    // Shape/length of the first non-scalar input; scalars broadcast to it.
    let mut base_shape: Vec<usize> = Vec::new();
    let mut base_len: Option<usize> = None;

    for (idx, raw) in rest.into_iter().enumerate() {
        if matches!(raw, Value::Cell(_)) {
            return Err(arrayfun_flow(
                "arrayfun: cell inputs are not supported (use cellfun instead)",
            ));
        }
        if matches!(raw, Value::Struct(_)) {
            return Err(arrayfun_flow("arrayfun: struct inputs are not supported"));
        }

        // Gather gpuArray inputs to the host for element access.
        let host_value = gather_if_needed_async(&raw).await?;
        let data = ArrayData::from_value(host_value)?;
        let len = data.len();
        let is_scalar = len == 1;

        let mut input = ArrayInput { data, is_scalar };

        // Size compatibility: equal sizes must match shapes exactly; 1-element
        // inputs broadcast; if the running base was itself a scalar, a later
        // non-scalar input PROMOTES the base and all prior inputs are
        // re-validated against the new base shape.
        if let Some(current) = base_len {
            if current == len {
                if len > 1 {
                    let shape = input.shape_vec();
                    if shape != base_shape {
                        return Err(arrayfun_flow(format!(
                            "arrayfun: input {} does not match the size of the first array",
                            idx + 1
                        )));
                    }
                }
            } else if len == 1 {
                input.is_scalar = true;
            } else if current == 1 {
                base_len = Some(len);
                base_shape = input.shape_vec();
                for prior in &mut inputs {
                    let prior_len = prior.len();
                    if prior_len == len {
                        if prior.shape_vec() != base_shape {
                            // NOTE(review): this reports `idx` (0-based current
                            // input) although the mismatch is in a PRIOR input;
                            // other branches report `idx + 1`. Confirm intent.
                            return Err(arrayfun_flow(format!(
                                "arrayfun: input {} does not match the size of the first array",
                                idx
                            )));
                        }
                    } else if prior_len == 1 {
                        prior.is_scalar = true;
                    } else if prior_len == 0 && len == 0 {
                        continue;
                    } else {
                        return Err(arrayfun_flow(format!(
                            "arrayfun: input {} does not match the size of the first array",
                            idx
                        )));
                    }
                }
            } else if len == 0 && current == 0 {
                // Empty inputs must still agree on shape (e.g. 0x3 vs 3x0).
                let shape = input.shape_vec();
                if shape != base_shape {
                    return Err(arrayfun_flow(format!(
                        "arrayfun: input {} does not match the size of the first array",
                        idx + 1
                    )));
                }
            } else {
                return Err(arrayfun_flow(format!(
                    "arrayfun: input {} does not match the size of the first array",
                    idx + 1
                )));
            }
        } else {
            base_len = Some(len);
            base_shape = input.shape_vec();
        }

        inputs.push(input);
    }

    let total_len = base_len.unwrap_or(0);

    // Empty inputs short-circuit: an empty tensor or empty cell of base shape.
    if total_len == 0 {
        if uniform_output {
            return Ok(empty_uniform(&base_shape));
        } else {
            return make_cell_with_shape(Vec::new(), base_shape)
                .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")));
        }
    }

    // Uniform mode accumulates into a typed collector (with type promotion);
    // cell mode just gathers each per-element result.
    let mut collector = if uniform_output {
        Some(UniformCollector::Pending)
    } else {
        None
    };

    let mut cell_outputs: Vec<Value> = Vec::new();
    // Reused per-iteration argument buffer (one element per input array).
    let mut args: Vec<Value> = Vec::with_capacity(inputs.len());

    for idx in 0..total_len {
        args.clear();
        for input in &inputs {
            args.push(input.value_at(idx)?);
        }

        let result = match callable.call(&args).await {
            Ok(value) => value,
            Err(err) => {
                // With no ErrorHandler, a callback failure aborts arrayfun.
                let handler = match error_handler.as_ref() {
                    Some(handler) => handler,
                    None => {
                        return Err(arrayfun_flow_with_source(
                            format!("arrayfun: {}", err.message()),
                            err,
                        ))
                    }
                };
                // ErrorHandler receives an error struct plus the offending
                // element arguments; its return value replaces the output.
                let err_message = format_handler_error(&err);
                let err_value = make_error_struct(&err_message, idx, &base_shape)?;
                let mut handler_args = Vec::with_capacity(1 + args.len());
                handler_args.push(err_value);
                handler_args.extend(args.clone());
                handler.call(&handler_args).await?
            }
        };

        // Callback may itself return a gpuArray; normalize to host.
        let host_result = gather_if_needed_async(&result).await?;

        if let Some(collector) = collector.as_mut() {
            collector.push(&host_result)?;
        } else {
            cell_outputs.push(host_result);
        }
    }

    if let Some(collector) = collector {
        let uniform = collector.finish(&base_shape)?;
        // Restore GPU residency if any input lived on the device.
        maybe_upload_uniform(uniform, has_gpu_input, gpu_device_id)
    } else {
        make_cell_with_shape(cell_outputs, base_shape)
            .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))
    }
}
314
/// Re-uploads a host-computed uniform result to the GPU when at least one
/// original input was a gpuArray, so callers keep observing device residency.
///
/// Only tensor and logical results are uploaded (logicals as 0.0/1.0 with the
/// handle's logical flag set); any other value is returned unchanged. If no
/// provider is registered the host value is returned as-is.
fn maybe_upload_uniform(
    value: Value,
    has_gpu_input: bool,
    gpu_device_id: Option<u32>,
) -> BuiltinResult<Value> {
    if !has_gpu_input {
        return Ok(value);
    }
    // Test-only convenience: lazily register the wgpu provider when the
    // handle references a real device, so unit tests can exercise re-upload.
    #[cfg(all(test, feature = "wgpu"))]
    {
        if matches!(gpu_device_id, Some(id) if id != 0) {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let _ = gpu_device_id; // may be used only in cfg(test)
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(value),
    };

    match value {
        Value::Tensor(tensor) => {
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider
                .upload(&view)
                .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
            Ok(Value::GpuTensor(handle))
        }
        Value::LogicalArray(logical) => {
            // Providers store numeric data; widen bits to 0.0/1.0 and mark
            // the handle logical so downcasts round-trip the type.
            let data: Vec<f64> = logical
                .data
                .iter()
                .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                .collect();
            let tensor = Tensor::new(data, logical.shape.clone())
                .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider
                .upload(&view)
                .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
            set_handle_logical(&handle, true);
            Ok(Value::GpuTensor(handle))
        }
        // Complex/char/string results stay on the host.
        other => Ok(other),
    }
}
369
370fn empty_uniform(shape: &[usize]) -> Value {
371    if shape.is_empty() {
372        return Value::Tensor(Tensor::zeros(vec![0, 0]));
373    }
374    let total: usize = shape.iter().product();
375    let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
376        .unwrap_or_else(|_| Tensor::zeros(shape.to_vec()));
377    Value::Tensor(tensor)
378}
379
380fn parse_uniform_output(value: Value) -> BuiltinResult<bool> {
381    match value {
382        Value::Bool(b) => Ok(b),
383        Value::Num(n) => Ok(n != 0.0),
384        Value::Int(iv) => Ok(iv.to_f64() != 0.0),
385        Value::String(s) => parse_bool_string(&s)
386            .ok_or_else(|| arrayfun_flow("arrayfun: UniformOutput must be logical true or false")),
387        Value::CharArray(ca) if ca.rows == 1 => {
388            let text: String = ca.data.iter().collect();
389            parse_bool_string(&text).ok_or_else(|| {
390                arrayfun_flow("arrayfun: UniformOutput must be logical true or false")
391            })
392        }
393        other => Err(arrayfun_flow(format!(
394            "arrayfun: UniformOutput must be logical true or false, got {other:?}"
395        ))),
396    }
397}
398
/// Parses a MATLAB-style boolean word: "true"/"on" and "false"/"off"
/// (case-insensitive, surrounding whitespace ignored); anything else is None.
fn parse_bool_string(value: &str) -> Option<bool> {
    let normalized = value.trim().to_ascii_lowercase();
    if normalized == "true" || normalized == "on" {
        Some(true)
    } else if normalized == "false" || normalized == "off" {
        Some(false)
    } else {
        None
    }
}
406
407fn extract_string(value: &Value) -> Option<String> {
408    match value {
409        Value::String(s) => Some(s.clone()),
410        Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
411        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
412        _ => None,
413    }
414}
415
416struct ArrayInput {
417    data: ArrayData,
418    is_scalar: bool,
419}
420
421impl ArrayInput {
422    fn len(&self) -> usize {
423        self.data.len()
424    }
425
426    fn shape_vec(&self) -> Vec<usize> {
427        self.data.shape_vec()
428    }
429
430    fn value_at(&self, idx: usize) -> BuiltinResult<Value> {
431        if self.is_scalar {
432            self.data.value_at(0)
433        } else {
434            self.data.value_at(idx)
435        }
436    }
437}
438
/// Host-side storage for one arrayfun input, normalized by
/// `ArrayData::from_value`. Array-valued inputs keep their container;
/// scalar values (Num/Bool/Int/Complex/String) are stored as-is.
enum ArrayData {
    Tensor(Tensor),
    Logical(LogicalArray),
    Complex(ComplexTensor),
    Char(CharArray),
    String(StringArray),
    Scalar(Value),
}
447
impl ArrayData {
    /// Classifies a (host-resident) value into array storage. Cell and
    /// struct inputs are rejected earlier by the builtin; anything else
    /// unsupported errors here.
    fn from_value(value: Value) -> BuiltinResult<Self> {
        match value {
            Value::Tensor(t) => Ok(ArrayData::Tensor(t)),
            Value::LogicalArray(l) => Ok(ArrayData::Logical(l)),
            Value::ComplexTensor(c) => Ok(ArrayData::Complex(c)),
            Value::CharArray(ca) => Ok(ArrayData::Char(ca)),
            Value::StringArray(sa) => Ok(ArrayData::String(sa)),
            Value::Num(_)
            | Value::Bool(_)
            | Value::Int(_)
            | Value::Complex(_, _)
            | Value::String(_) => {
                Ok(ArrayData::Scalar(value))
            }
            other => Err(arrayfun_flow(format!(
                "arrayfun: unsupported input type {other:?} (expected numeric, logical, complex, char, or string arrays)"
            ))),
        }
    }

    /// Total element count (chars count as rows * cols; scalars are 1).
    fn len(&self) -> usize {
        match self {
            ArrayData::Tensor(t) => t.data.len(),
            ArrayData::Logical(l) => l.data.len(),
            ArrayData::Complex(c) => c.data.len(),
            ArrayData::Char(ca) => ca.rows * ca.cols,
            ArrayData::String(sa) => sa.data.len(),
            ArrayData::Scalar(_) => 1,
        }
    }

    /// Shape normalized for comparison: empty shape vectors become [1, 1]
    /// so scalars and shapeless containers compare consistently.
    fn shape_vec(&self) -> Vec<usize> {
        match self {
            ArrayData::Tensor(t) => {
                if t.shape.is_empty() {
                    vec![1, 1]
                } else {
                    t.shape.clone()
                }
            }
            ArrayData::Logical(l) => {
                if l.shape.is_empty() {
                    vec![1, 1]
                } else {
                    l.shape.clone()
                }
            }
            ArrayData::Complex(c) => {
                if c.shape.is_empty() {
                    vec![1, 1]
                } else {
                    c.shape.clone()
                }
            }
            ArrayData::Char(ca) => vec![ca.rows, ca.cols],
            ArrayData::String(sa) => {
                if sa.shape.is_empty() {
                    vec![1, 1]
                } else {
                    sa.shape.clone()
                }
            }
            ArrayData::Scalar(_) => vec![1, 1],
        }
    }

    /// Returns element `idx` (MATLAB column-major linear index) as a scalar
    /// Value; out-of-range indices error rather than panic.
    fn value_at(&self, idx: usize) -> BuiltinResult<Value> {
        match self {
            ArrayData::Tensor(t) => {
                Ok(Value::Num(*t.data.get(idx).ok_or_else(|| {
                    arrayfun_flow("arrayfun: index out of bounds")
                })?))
            }
            ArrayData::Logical(l) => Ok(Value::Bool(
                *l.data
                    .get(idx)
                    .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?
                    != 0,
            )),
            ArrayData::Complex(c) => {
                let (re, im) = c
                    .data
                    .get(idx)
                    .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?;
                Ok(Value::Complex(*re, *im))
            }
            ArrayData::Char(ca) => {
                // Degenerate char array: yield an empty 0x0 char array.
                if ca.rows == 0 || ca.cols == 0 {
                    return Ok(Value::CharArray(
                        CharArray::new(Vec::new(), 0, 0)
                            .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?,
                    ));
                }
                // Convert the column-major linear index to a row-major
                // storage index. NOTE(review): this assumes CharArray data
                // is stored row-major — confirm against CharArray's layout.
                let rows = ca.rows;
                let cols = ca.cols;
                let row = idx % rows;
                let col = idx / rows;
                let data_idx = row * cols + col;
                let ch = *ca
                    .data
                    .get(data_idx)
                    .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?;
                // Each element is handed to the callback as a 1x1 char array.
                let char_array = CharArray::new(vec![ch], 1, 1)
                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
                Ok(Value::CharArray(char_array))
            }
            ArrayData::String(sa) => {
                Ok(Value::String(sa.data.get(idx).cloned().ok_or_else(
                    || arrayfun_flow("arrayfun: index out of bounds"),
                )?))
            }
            ArrayData::Scalar(v) => Ok(v.clone()),
        }
    }
}
564
/// The arrayfun callback: either a name resolved at call time (builtin or
/// user function) or a closure with captured values.
#[derive(Clone)]
enum Callable {
    Builtin { name: String },
    Closure(Closure),
}
570
impl Callable {
    /// Coerces a callback argument into a `Callable`. Accepts function
    /// handles, closures, and string-like function names; scalars and other
    /// values are rejected with a descriptive error.
    fn from_function(value: Value) -> BuiltinResult<Self> {
        match value {
            Value::String(text) => Self::from_text(&text),
            Value::CharArray(ca) => {
                if ca.rows != 1 {
                    Err(arrayfun_flow(
                        "arrayfun: function name must be a character vector or string scalar",
                    ))
                } else {
                    let text: String = ca.data.iter().collect();
                    Self::from_text(&text)
                }
            }
            Value::StringArray(sa) if sa.data.len() == 1 => Self::from_text(&sa.data[0]),
            Value::FunctionHandle(name) => Self::from_text(&name),
            Value::Closure(closure) => Ok(Callable::Closure(closure)),
            Value::Num(_) | Value::Int(_) | Value::Bool(_) => Err(arrayfun_flow(
                "arrayfun: expected function handle or builtin name, not a scalar value",
            )),
            other => Err(arrayfun_flow(format!(
                "arrayfun: expected function handle or builtin name, got {other:?}"
            ))),
        }
    }

    /// Parses a textual function reference, with or without a leading '@'.
    fn from_text(text: &str) -> BuiltinResult<Self> {
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return Err(arrayfun_flow(
                "arrayfun: expected function handle or builtin name, got empty string",
            ));
        }
        if let Some(rest) = trimmed.strip_prefix('@') {
            let name = rest.trim();
            if name.is_empty() {
                Err(arrayfun_flow("arrayfun: empty function handle"))
            } else {
                // NOTE(review): the '@name' path keeps the original casing
                // while the bare-name path below lowercases — confirm this
                // asymmetry is intentional (e.g. case-sensitive user
                // functions vs. lowercase builtin lookup).
                Ok(Callable::Builtin {
                    name: name.to_string(),
                })
            }
        } else {
            Ok(Callable::Builtin {
                name: trimmed.to_ascii_lowercase(),
            })
        }
    }

    /// The callback's name when it is a plain named function; closures have
    /// no resolvable name (and so never take the GPU fast path).
    fn builtin_name(&self) -> Option<&str> {
        match self {
            Callable::Builtin { name } => Some(name.as_str()),
            Callable::Closure(_) => None,
        }
    }

    /// Invokes the callback on one element's worth of arguments.
    /// User-defined functions shadow builtins: the user-function table is
    /// consulted first, then the builtin dispatcher.
    async fn call(&self, args: &[Value]) -> crate::BuiltinResult<Value> {
        match self {
            Callable::Builtin { name } => {
                if let Some(result) = user_functions::try_call_user_function(name, args).await {
                    return result;
                }
                crate::call_builtin_async(name, args).await
            }
            Callable::Closure(c) => {
                // Captured values are prepended to the per-element arguments.
                let mut merged = c.captures.clone();
                merged.extend_from_slice(args);
                if let Some(result) =
                    user_functions::try_call_user_function(&c.function_name, &merged).await
                {
                    return result;
                }
                crate::call_builtin_async(&c.function_name, &merged).await
            }
        }
    }
}
648
/// Attempts to run the whole arrayfun on the device.
///
/// Returns `Ok(Some(_))` with the device result only when: all inputs are
/// gpuArrays, no ErrorHandler is set, a provider is registered, the callback
/// is a named builtin with a matching provider kernel, and all handle shapes
/// agree. Any other condition — including a provider kernel error — returns
/// `Ok(None)` so the caller falls back to the host path.
async fn try_gpu_fast_path(
    callable: &Callable,
    inputs: &[Value],
    error_handler: Option<&Callable>,
) -> BuiltinResult<Option<Value>> {
    // An ErrorHandler requires per-element host dispatch; bail out early.
    if inputs.is_empty() || error_handler.is_some() {
        return Ok(None);
    }
    if !inputs
        .iter()
        .all(|value| matches!(value, Value::GpuTensor(_)))
    {
        return Ok(None);
    }

    // Test-only convenience: lazily register the wgpu provider so unit tests
    // with real device handles can exercise this path.
    #[cfg(all(test, feature = "wgpu"))]
    {
        if inputs
            .iter()
            .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
        {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(None),
    };

    // Closures can run arbitrary code — only named builtins are eligible.
    let Some(name_raw) = callable.builtin_name() else {
        return Ok(None);
    };
    let name = name_raw.to_ascii_lowercase();

    let mut handles: Vec<GpuTensorHandle> = Vec::with_capacity(inputs.len());
    for value in inputs {
        if let Value::GpuTensor(handle) = value {
            handles.push(handle.clone());
        }
    }

    // Device kernels here do not broadcast: all shapes must match exactly.
    if handles.len() >= 2 {
        let base_shape = handles[0].shape.clone();
        if handles
            .iter()
            .skip(1)
            .any(|handle| handle.shape != base_shape)
        {
            return Ok(None);
        }
    }

    // Map the MATLAB builtin name + arity onto the provider kernel.
    let result = match name.as_str() {
        "sin" if handles.len() == 1 => provider.unary_sin(&handles[0]).await,
        "cos" if handles.len() == 1 => provider.unary_cos(&handles[0]).await,
        "abs" if handles.len() == 1 => provider.unary_abs(&handles[0]).await,
        "exp" if handles.len() == 1 => provider.unary_exp(&handles[0]).await,
        "log" if handles.len() == 1 => provider.unary_log(&handles[0]).await,
        "sqrt" if handles.len() == 1 => provider.unary_sqrt(&handles[0]).await,
        "plus" if handles.len() == 2 => provider.elem_add(&handles[0], &handles[1]).await,
        "minus" if handles.len() == 2 => provider.elem_sub(&handles[0], &handles[1]).await,
        "times" if handles.len() == 2 => provider.elem_mul(&handles[0], &handles[1]).await,
        "rdivide" if handles.len() == 2 => provider.elem_div(&handles[0], &handles[1]).await,
        // ldivide(a, b) = b ./ a, hence the swapped operands.
        "ldivide" if handles.len() == 2 => provider.elem_div(&handles[1], &handles[0]).await,
        _ => return Ok(None),
    };

    // Kernel failures are not fatal: fall back to the host path silently.
    match result {
        Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
        Err(_) => Ok(None),
    }
}
723
/// Accumulator for uniform-output mode. Starts `Pending` (type unknown) and
/// commits to a concrete element type on the first pushed value; later
/// values may promote the whole collection (e.g. logical -> double,
/// double -> complex, char -> double) to keep the output homogeneous.
enum UniformCollector {
    Pending,
    Double(Vec<f64>),
    Logical(Vec<u8>),
    Complex(Vec<(f64, f64)>),
    Char(Vec<char>),
}
731
732impl UniformCollector {
733    fn push(&mut self, value: &Value) -> BuiltinResult<()> {
734        match self {
735            UniformCollector::Pending => match classify_value(value)? {
736                ClassifiedValue::Logical(b) => {
737                    *self = UniformCollector::Logical(vec![b as u8]);
738                    Ok(())
739                }
740                ClassifiedValue::Double(d) => {
741                    *self = UniformCollector::Double(vec![d]);
742                    Ok(())
743                }
744                ClassifiedValue::Complex(c) => {
745                    *self = UniformCollector::Complex(vec![c]);
746                    Ok(())
747                }
748                ClassifiedValue::Char(ch) => {
749                    *self = UniformCollector::Char(vec![ch]);
750                    Ok(())
751                }
752            },
753            UniformCollector::Logical(bits) => match classify_value(value)? {
754                ClassifiedValue::Logical(b) => {
755                    bits.push(b as u8);
756                    Ok(())
757                }
758                ClassifiedValue::Double(d) => {
759                    let mut data: Vec<f64> = bits
760                        .iter()
761                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
762                        .collect();
763                    data.push(d);
764                    *self = UniformCollector::Double(data);
765                    Ok(())
766                }
767                ClassifiedValue::Complex(c) => {
768                    let mut data: Vec<(f64, f64)> = bits
769                        .iter()
770                        .map(|&bit| if bit != 0 { (1.0, 0.0) } else { (0.0, 0.0) })
771                        .collect();
772                    data.push(c);
773                    *self = UniformCollector::Complex(data);
774                    Ok(())
775                }
776                ClassifiedValue::Char(ch) => {
777                    let mut data: Vec<f64> = bits
778                        .iter()
779                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
780                        .collect();
781                    data.push(ch as u32 as f64);
782                    *self = UniformCollector::Double(data);
783                    Ok(())
784                }
785            },
786            UniformCollector::Double(data) => match classify_value(value)? {
787                ClassifiedValue::Logical(b) => {
788                    data.push(if b { 1.0 } else { 0.0 });
789                    Ok(())
790                }
791                ClassifiedValue::Double(d) => {
792                    data.push(d);
793                    Ok(())
794                }
795                ClassifiedValue::Complex(c) => {
796                    let promoted: Vec<(f64, f64)> = data.iter().map(|&v| (v, 0.0)).collect();
797                    let mut complex = promoted;
798                    complex.push(c);
799                    *self = UniformCollector::Complex(complex);
800                    Ok(())
801                }
802                ClassifiedValue::Char(ch) => {
803                    data.push(ch as u32 as f64);
804                    Ok(())
805                }
806            },
807            UniformCollector::Complex(data) => match classify_value(value)? {
808                ClassifiedValue::Logical(b) => {
809                    data.push((if b { 1.0 } else { 0.0 }, 0.0));
810                    Ok(())
811                }
812                ClassifiedValue::Double(d) => {
813                    data.push((d, 0.0));
814                    Ok(())
815                }
816                ClassifiedValue::Complex(c) => {
817                    data.push(c);
818                    Ok(())
819                }
820                ClassifiedValue::Char(ch) => {
821                    data.push((ch as u32 as f64, 0.0));
822                    Ok(())
823                }
824            },
825            UniformCollector::Char(chars) => match classify_value(value)? {
826                ClassifiedValue::Char(ch) => {
827                    chars.push(ch);
828                    Ok(())
829                }
830                ClassifiedValue::Logical(b) => {
831                    let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
832                    data.push(if b { 1.0 } else { 0.0 });
833                    *self = UniformCollector::Double(data);
834                    Ok(())
835                }
836                ClassifiedValue::Double(d) => {
837                    let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
838                    data.push(d);
839                    *self = UniformCollector::Double(data);
840                    Ok(())
841                }
842                ClassifiedValue::Complex(c) => {
843                    let mut promoted: Vec<(f64, f64)> =
844                        chars.iter().map(|&ch| (ch as u32 as f64, 0.0)).collect();
845                    promoted.push(c);
846                    *self = UniformCollector::Complex(promoted);
847                    Ok(())
848                }
849            },
850        }
851    }
852
853    fn finish(self, shape: &[usize]) -> BuiltinResult<Value> {
854        match self {
855            UniformCollector::Pending => {
856                let total = shape.iter().product();
857                let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
858                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
859                Ok(Value::Tensor(tensor))
860            }
861            UniformCollector::Double(data) => {
862                let tensor = Tensor::new(data, shape.to_vec())
863                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
864                Ok(Value::Tensor(tensor))
865            }
866            UniformCollector::Logical(bits) => {
867                let logical = LogicalArray::new(bits, shape.to_vec())
868                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
869                Ok(Value::LogicalArray(logical))
870            }
871            UniformCollector::Complex(entries) => {
872                let tensor = ComplexTensor::new(entries, shape.to_vec())
873                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
874                Ok(Value::ComplexTensor(tensor))
875            }
876            UniformCollector::Char(chars) => {
877                let normalized_shape = if shape.is_empty() {
878                    vec![1, 1]
879                } else {
880                    shape.to_vec()
881                };
882
883                if normalized_shape.len() > 2 {
884                    return Err(arrayfun_flow(
885                        "arrayfun: character outputs with UniformOutput=true must be 2-D",
886                    ));
887                }
888
889                let rows = normalized_shape.first().copied().unwrap_or(1);
890                let cols = normalized_shape.get(1).copied().unwrap_or(1);
891                let expected = rows.checked_mul(cols).ok_or_else(|| {
892                    arrayfun_flow("arrayfun: character output size exceeds platform limits")
893                })?;
894
895                if expected != chars.len() {
896                    return Err(arrayfun_flow(
897                        "arrayfun: callback returned the wrong number of characters",
898                    ));
899                }
900
901                let mut row_major = vec!['\0'; expected];
902                for col in 0..cols {
903                    for row in 0..rows {
904                        let col_major_idx = row + col * rows;
905                        let row_major_idx = row * cols + col;
906                        row_major[row_major_idx] = chars[col_major_idx];
907                    }
908                }
909
910                let array = CharArray::new(row_major, rows, cols)
911                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
912                Ok(Value::CharArray(array))
913            }
914        }
915    }
916}
917
/// Scalar classification of a single callback return value, used when
/// assembling uniform `arrayfun` outputs; `classify_value` produces these and
/// the uniform collector above promotes between them as needed.
enum ClassifiedValue {
    /// Scalar logical (`true`/`false`) result.
    Logical(bool),
    /// Real double-precision scalar.
    Double(f64),
    /// Complex scalar stored as `(real, imaginary)` parts.
    Complex((f64, f64)),
    /// Single character result.
    Char(char),
}
924
925fn classify_value(value: &Value) -> BuiltinResult<ClassifiedValue> {
926    match value {
927        Value::Bool(b) => Ok(ClassifiedValue::Logical(*b)),
928        Value::LogicalArray(la) if la.len() == 1 => Ok(ClassifiedValue::Logical(la.data[0] != 0)),
929        Value::Int(i) => Ok(ClassifiedValue::Double(i.to_f64())),
930        Value::Num(n) => Ok(ClassifiedValue::Double(*n)),
931        Value::Tensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Double(t.data[0])),
932        Value::Complex(re, im) => Ok(ClassifiedValue::Complex((*re, *im))),
933        Value::ComplexTensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Complex(t.data[0])),
934        Value::CharArray(ca) if ca.rows * ca.cols == 1 => {
935            let ch = ca.data.first().copied().unwrap_or('\0');
936            Ok(ClassifiedValue::Char(ch))
937        }
938        other => Err(arrayfun_flow(format!(
939            "arrayfun: callback must return scalar numeric, logical, character, or complex values for UniformOutput=true (got {other:?})"
940        ))),
941    }
942}
943
944fn make_error_struct(
945    raw_error: &str,
946    linear_index: usize,
947    shape: &[usize],
948) -> BuiltinResult<Value> {
949    let (identifier, message) = split_error_message(raw_error);
950    let mut st = runmat_builtins::StructValue::new();
951    st.fields
952        .insert("identifier".to_string(), Value::String(identifier));
953    st.fields
954        .insert("message".to_string(), Value::String(message));
955    st.fields
956        .insert("index".to_string(), Value::Num((linear_index + 1) as f64));
957    let subs = linear_to_indices(linear_index, shape);
958    let subs_tensor = dims_to_row_tensor(&subs)?;
959    st.fields
960        .insert("indices".to_string(), Value::Tensor(subs_tensor));
961    Ok(Value::Struct(st))
962}
963
/// Split a raw error string into `(identifier, message)` following the
/// `identifier: message` convention.
///
/// * Two or more colons: everything before the second colon is taken as the
///   identifier (it must itself contain a colon); the remainder is the
///   message (or the whole trimmed text when the remainder is empty).
/// * Exactly one colon: the whole string is kept as the identifier when it
///   starts with a `matlab:` / `runmat:` prefix (case-insensitive).
/// * Otherwise a generic `RunMat:arrayfun:FunctionError` identifier is used
///   and the full trimmed text becomes the message.
fn split_error_message(raw: &str) -> (String, String) {
    let trimmed = raw.trim();
    let mut colon_positions = trimmed.match_indices(':');
    if colon_positions.next().is_some() {
        if let Some((second_idx, _)) = colon_positions.next() {
            // `second_idx` is the byte offset of an ASCII ':' so these slices
            // always land on char boundaries.
            let identifier = trimmed[..second_idx].trim().to_string();
            let message = trimmed[second_idx + 1..].trim().to_string();
            if !identifier.is_empty() && identifier.contains(':') {
                return (
                    identifier,
                    if message.is_empty() {
                        trimmed.to_string()
                    } else {
                        message
                    },
                );
            }
        } else {
            // Compare the prefix on raw bytes: `trimmed[..7]` would panic when
            // byte 7 falls inside a multi-byte UTF-8 character.
            let bytes = trimmed.as_bytes();
            if bytes.len() >= 7
                && (bytes[..7].eq_ignore_ascii_case(b"matlab:")
                    || bytes[..7].eq_ignore_ascii_case(b"runmat:"))
            {
                return (trimmed.to_string(), String::new());
            }
        }
    }
    (
        "RunMat:arrayfun:FunctionError".to_string(),
        trimmed.to_string(),
    )
}
993
/// Convert a zero-based linear index into one-based column-major subscripts
/// for `shape`. An empty shape yields `[1]`; a zero-length dimension yields
/// subscript 1 without consuming any of the index.
fn linear_to_indices(mut index: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![1];
    }
    shape
        .iter()
        .map(|&extent| {
            if extent == 0 {
                return 1;
            }
            let subscript = index % extent + 1;
            index /= extent;
            subscript
        })
        .collect()
}
1010
1011fn dims_to_row_tensor(dims: &[usize]) -> BuiltinResult<Tensor> {
1012    let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
1013    Tensor::new(data, vec![1, dims.len()]).map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))
1014}
1015
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ResolveContext, Tensor, Type};

    /// Drive the async `arrayfun` builtin to completion on the current thread.
    fn call(func: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(arrayfun_builtin(func, rest))
    }

    /// `arrayfun(@sin, A)` applies the function elementwise and preserves shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_basic_sin() {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
        let result = call(
            Value::FunctionHandle("sin".to_string()),
            vec![Value::Tensor(tensor.clone())],
        )
        .expect("arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![2, 3]);
                assert_eq!(out.data, expected);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// Type inference: a numeric-returning function over a tensor yields a tensor type.
    #[test]
    fn arrayfun_type_tracks_function_returns() {
        let func = Type::Function {
            params: vec![Type::Num],
            returns: Box::new(Type::Num),
        };
        assert_eq!(
            arrayfun_type(&[func, Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    /// Type inference: a bool-returning function yields a logical array type.
    #[test]
    fn arrayfun_type_uses_logical_returns() {
        let func = Type::Function {
            params: vec![Type::Num],
            returns: Box::new(Type::Bool),
        };
        assert_eq!(
            arrayfun_type(&[func, Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::logical()
        );
    }

    /// Type inference: trailing string/bool option arguments make the result Unknown.
    #[test]
    fn arrayfun_type_with_text_args_stays_unknown() {
        let func = Type::Function {
            params: vec![Type::Num],
            returns: Box::new(Type::Num),
        };
        assert_eq!(
            arrayfun_type(
                &[func, Type::tensor(), Type::String, Type::Bool],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Unknown
        );
    }

    /// Extra scalar arguments (here the second operand of atan2) are broadcast
    /// across the array input.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_additional_scalar_argument() {
        let tensor = Tensor::new(vec![0.5, 1.0, -1.0], vec![3, 1]).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|&y| y.atan2(1.0)).collect();
        let result = call(
            Value::FunctionHandle("atan2".to_string()),
            vec![Value::Tensor(tensor), Value::Num(1.0)],
        )
        .expect("arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.data, expected);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// With `UniformOutput=false` the result is a cell array of per-element values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_uniform_false_returns_cell() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let expected: Vec<Value> = tensor.data.iter().map(|&x| Value::Num(x.sin())).collect();
        let result = call(
            Value::FunctionHandle("sin".to_string()),
            vec![
                Value::Tensor(tensor),
                Value::String("UniformOutput".into()),
                Value::Bool(false),
            ],
        )
        .expect("arrayfun");
        let Value::Cell(cell) = result else {
            panic!("expected cell, got something else");
        };
        assert_eq!(cell.rows, 2);
        assert_eq!(cell.cols, 1);
        for (row, value) in expected.iter().enumerate() {
            assert_eq!(cell.get(row, 0).unwrap(), *value);
        }
    }

    /// Mismatched input shapes must be rejected with a size-mismatch error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_size_mismatch_errors() {
        let taller = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let shorter = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let err = call(
            Value::FunctionHandle("sin".to_string()),
            vec![Value::Tensor(taller), Value::Tensor(shorter)],
        )
        .expect_err("expected size mismatch error");
        let err = err.to_string();
        assert!(
            err.contains("does not match"),
            "expected size mismatch error, got {err}"
        );
    }

    /// An `ErrorHandler` closure is invoked per failing element; the test
    /// handler (defined at the bottom of this module) returns its captured
    /// seed value, so every element becomes 42.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_error_handler_recovers() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let handler = Value::Closure(Closure {
            function_name: "__arrayfun_test_handler".into(),
            captures: vec![Value::Num(42.0)],
        });
        let result = call(
            Value::String("@nonexistent_builtin".into()),
            vec![
                Value::Tensor(tensor),
                Value::String("ErrorHandler".into()),
                handler,
            ],
        )
        .expect("arrayfun error handler");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![42.0, 42.0, 42.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// Without an `ErrorHandler`, the undefined-function error propagates with
    /// its identifier intact.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_error_without_handler_propagates_identifier() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = call(
            Value::String("@nonexistent_builtin".into()),
            vec![Value::Tensor(tensor)],
        )
        .expect_err("expected unresolved function error");
        assert_eq!(
            err.identifier(),
            Some("RunMat:UndefinedFunction"),
            "unexpected error: {}",
            err.message()
        );
    }

    /// Logical-returning callbacks (isfinite) produce a LogicalArray result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_uniform_logical_result() {
        let tensor = Tensor::new(vec![1.0, f64::NAN, 0.0, f64::INFINITY], vec![4, 1]).unwrap();
        let result = call(
            Value::FunctionHandle("isfinite".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("arrayfun isfinite");
        match result {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![4, 1]);
                assert_eq!(la.data, vec![1, 0, 1, 0]);
            }
            other => panic!("expected logical array, got {other:?}"),
        }
    }

    /// Char-returning callbacks produce a CharArray result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_uniform_character_result() {
        let tensor = Tensor::new(vec![65.0, 66.0, 67.0], vec![1, 3]).unwrap();
        let result = call(
            Value::FunctionHandle("char".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("arrayfun char");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 3);
                assert_eq!(ca.data, vec!['A', 'B', 'C']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    /// GPU input with `UniformOutput=false` still yields a host cell array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_uniform_false_gpu_returns_cell() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = call(
                Value::FunctionHandle("sin".to_string()),
                vec![
                    Value::GpuTensor(handle),
                    Value::String("UniformOutput".into()),
                    Value::Bool(false),
                ],
            )
            .expect("arrayfun");
            match result {
                Value::Cell(cell) => {
                    assert_eq!(cell.rows, 2);
                    assert_eq!(cell.cols, 1);
                    let first = cell.get(0, 0).expect("first cell");
                    let second = cell.get(1, 0).expect("second cell");
                    match (first, second) {
                        (Value::Num(a), Value::Num(b)) => {
                            assert!((a - 0.0f64.sin()).abs() < 1e-12);
                            assert!((b - 1.0f64.sin()).abs() < 1e-12);
                        }
                        other => panic!("expected numeric cells, got {other:?}"),
                    }
                }
                other => panic!("expected cell, got {other:?}"),
            }
        });
    }

    /// GPU input produces a GPU output handle whose gathered contents match
    /// the host computation.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = call(
                Value::FunctionHandle("sin".to_string()),
                vec![Value::GpuTensor(handle)],
            )
            .expect("arrayfun");
            match result {
                Value::GpuTensor(gpu) => {
                    let gathered = test_support::gather(Value::GpuTensor(gpu.clone())).unwrap();
                    let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
                    assert_eq!(gathered.data, expected);
                    let _ = provider.free(&gpu);
                }
                other => panic!("expected gpu tensor, got {other:?}"),
            }
        });
    }

    /// wgpu backend: unary sin on device matches CPU within the provider's
    /// precision-dependent tolerance.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn arrayfun_wgpu_sin_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result = call(
            Value::FunctionHandle("sin".into()),
            vec![Value::GpuTensor(handle.clone())],
        )
        .expect("arrayfun sin gpu");
        let Value::GpuTensor(out_handle) = result else {
            panic!("expected GPU tensor result");
        };
        let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|v| v.sin()).collect();
        assert_eq!(gathered.shape, tensor.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
            assert!(
                (actual - expect).abs() < tol,
                "expected {expect}, got {actual}"
            );
        }
        let _ = provider.free(&handle);
        let _ = provider.free(&out_handle);
    }

    /// wgpu backend: binary plus over two device tensors matches CPU results.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn arrayfun_wgpu_plus_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![4.0, 3.0, 2.0, 1.0], vec![2, 2]).unwrap();
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload a");
        let handle_b = provider.upload(&view_b).expect("upload b");
        let result = call(
            Value::FunctionHandle("plus".into()),
            vec![
                Value::GpuTensor(handle_a.clone()),
                Value::GpuTensor(handle_b.clone()),
            ],
        )
        .expect("arrayfun plus gpu");

        let Value::GpuTensor(out_handle) = result else {
            panic!("expected GPU tensor result");
        };
        let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
        let expected: Vec<f64> = a
            .data
            .iter()
            .zip(b.data.iter())
            .map(|(x, y)| x + y)
            .collect();
        assert_eq!(gathered.shape, a.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
            assert!(
                (actual - expect).abs() < tol,
                "expected {expect}, got {actual}"
            );
        }
        let _ = provider.free(&handle_a);
        let _ = provider.free(&handle_b);
        let _ = provider.free(&out_handle);
    }

    /// Test-only builtin used as the `ErrorHandler` fixture above: ignores the
    /// error struct and element arguments and returns its captured seed.
    #[runmat_macros::runtime_builtin(
        name = "__arrayfun_test_handler",
        type_resolver(arrayfun_type),
        builtin_path = "crate::builtins::acceleration::gpu::arrayfun::tests"
    )]
    async fn arrayfun_test_handler(
        seed: Value,
        _err: Value,
        rest: Vec<Value>,
    ) -> crate::BuiltinResult<Value> {
        let _ = rest;
        Ok(seed)
    }
}