Skip to main content

runmat_runtime/builtins/acceleration/gpu/
arrayfun.rs

//! MATLAB-compatible `arrayfun` builtin with GPU-aware semantics.
//!
//! This implementation applies a MATLAB callback to every element of one or more
//! array inputs. When `UniformOutput` is true, a GPU fast path is attempted first and,
//! if it produces a result, that result is returned directly; per `GPU_SPEC`, providers
//! that implement the listed kernels can run supported callbacks entirely on the device.
//! Otherwise every element is evaluated on the host, and when any input was a
//! `gpuArray` a uniform numeric or logical result is uploaded back to the device so
//! downstream code continues to see GPU residency.
8
9use crate::builtins::acceleration::gpu::type_resolvers::arrayfun_type;
10use crate::builtins::common::spec::{
11    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
13};
14use crate::{
15    build_runtime_error, gather_if_needed_async, make_cell_with_shape, user_functions,
16    BuiltinResult, RuntimeError,
17};
18use runmat_accelerate_api::{set_handle_logical, GpuTensorHandle, HostTensorView};
19use runmat_builtins::{
20    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
21    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
22    CharArray, Closure, ComplexTensor, LogicalArray, StringArray, Tensor, Value,
23};
24use runmat_macros::runtime_builtin;
25
/// GPU capability descriptor registered for `arrayfun`.
///
/// Declares an element-wise operation with MATLAB broadcasting over f32/f64 data. Per
/// `notes`, providers that implement every kernel in `provider_hooks` can run supported
/// callbacks on the device, while other callbacks take the host path and are re-uploaded.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::arrayfun")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "arrayfun",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // Unary and binary element-wise kernels a provider may expose for callbacks.
    provider_hooks: &[
        ProviderHook::Unary { name: "unary_sin" },
        ProviderHook::Unary { name: "unary_cos" },
        ProviderHook::Unary { name: "unary_abs" },
        ProviderHook::Unary { name: "unary_exp" },
        ProviderHook::Unary { name: "unary_log" },
        ProviderHook::Unary { name: "unary_sqrt" },
        ProviderHook::Binary {
            name: "elem_add",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_sub",
            commutative: false,
        },
        ProviderHook::Binary {
            name: "elem_mul",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_div",
            commutative: false,
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Results are written to a fresh device handle rather than reusing an input buffer.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers that implement the listed kernels can run supported callbacks entirely on the GPU; unsupported callbacks fall back to the host path with re-upload.",
};
64
/// Fusion descriptor registered for `arrayfun`.
///
/// No element-wise or reduction template is provided (`elementwise`/`reduction` are
/// `None`): per `notes`, the callback may run arbitrary MATLAB code, so `arrayfun`
/// acts as a fusion barrier.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::acceleration::gpu::arrayfun"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "arrayfun",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a fusion barrier because the callback can run arbitrary MATLAB code.",
};
77
78const BUILTIN_NAME: &str = "arrayfun";
79
80const ARRAYFUN_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
81    name: "B",
82    ty: BuiltinParamType::Any,
83    arity: BuiltinParamArity::Required,
84    default: None,
85    description: "Element-wise callback result (uniform array or cell array).",
86}];
87
88const ARRAYFUN_INPUTS_BASE: [BuiltinParamDescriptor; 3] = [
89    BuiltinParamDescriptor {
90        name: "func",
91        ty: BuiltinParamType::Any,
92        arity: BuiltinParamArity::Required,
93        default: None,
94        description: "Function handle or callable name.",
95    },
96    BuiltinParamDescriptor {
97        name: "A1",
98        ty: BuiltinParamType::Any,
99        arity: BuiltinParamArity::Required,
100        default: None,
101        description: "First input array.",
102    },
103    BuiltinParamDescriptor {
104        name: "An",
105        ty: BuiltinParamType::Any,
106        arity: BuiltinParamArity::Variadic,
107        default: None,
108        description: "Additional input arrays.",
109    },
110];
111
112const ARRAYFUN_INPUTS_UNIFORM: [BuiltinParamDescriptor; 5] = [
113    BuiltinParamDescriptor {
114        name: "func",
115        ty: BuiltinParamType::Any,
116        arity: BuiltinParamArity::Required,
117        default: None,
118        description: "Function handle or callable name.",
119    },
120    BuiltinParamDescriptor {
121        name: "A1",
122        ty: BuiltinParamType::Any,
123        arity: BuiltinParamArity::Required,
124        default: None,
125        description: "First input array.",
126    },
127    BuiltinParamDescriptor {
128        name: "An",
129        ty: BuiltinParamType::Any,
130        arity: BuiltinParamArity::Variadic,
131        default: None,
132        description: "Additional input arrays.",
133    },
134    BuiltinParamDescriptor {
135        name: "UniformOutput",
136        ty: BuiltinParamType::PropertyName,
137        arity: BuiltinParamArity::Required,
138        default: Some("\"UniformOutput\""),
139        description: "Name-value key that toggles uniform output collection.",
140    },
141    BuiltinParamDescriptor {
142        name: "tf",
143        ty: BuiltinParamType::Any,
144        arity: BuiltinParamArity::Required,
145        default: Some("true"),
146        description: "Logical true/false value for UniformOutput.",
147    },
148];
149
150const ARRAYFUN_INPUTS_HANDLER: [BuiltinParamDescriptor; 5] = [
151    BuiltinParamDescriptor {
152        name: "func",
153        ty: BuiltinParamType::Any,
154        arity: BuiltinParamArity::Required,
155        default: None,
156        description: "Function handle or callable name.",
157    },
158    BuiltinParamDescriptor {
159        name: "A1",
160        ty: BuiltinParamType::Any,
161        arity: BuiltinParamArity::Required,
162        default: None,
163        description: "First input array.",
164    },
165    BuiltinParamDescriptor {
166        name: "An",
167        ty: BuiltinParamType::Any,
168        arity: BuiltinParamArity::Variadic,
169        default: None,
170        description: "Additional input arrays.",
171    },
172    BuiltinParamDescriptor {
173        name: "ErrorHandler",
174        ty: BuiltinParamType::PropertyName,
175        arity: BuiltinParamArity::Required,
176        default: Some("\"ErrorHandler\""),
177        description: "Name-value key that provides fallback callback on per-element failures.",
178    },
179    BuiltinParamDescriptor {
180        name: "handler",
181        ty: BuiltinParamType::Any,
182        arity: BuiltinParamArity::Required,
183        default: None,
184        description: "Callback invoked with error struct and original scalar arguments.",
185    },
186];
187
188const ARRAYFUN_INPUTS_OPTIONS: [BuiltinParamDescriptor; 4] = [
189    BuiltinParamDescriptor {
190        name: "func",
191        ty: BuiltinParamType::Any,
192        arity: BuiltinParamArity::Required,
193        default: None,
194        description: "Function handle or callable name.",
195    },
196    BuiltinParamDescriptor {
197        name: "A1",
198        ty: BuiltinParamType::Any,
199        arity: BuiltinParamArity::Required,
200        default: None,
201        description: "First input array.",
202    },
203    BuiltinParamDescriptor {
204        name: "An",
205        ty: BuiltinParamType::Any,
206        arity: BuiltinParamArity::Variadic,
207        default: None,
208        description: "Additional input arrays.",
209    },
210    BuiltinParamDescriptor {
211        name: "nameValue",
212        ty: BuiltinParamType::Any,
213        arity: BuiltinParamArity::Variadic,
214        default: None,
215        description: "Name-value option pairs including UniformOutput and ErrorHandler.",
216    },
217];
218
219const ARRAYFUN_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
220    BuiltinSignatureDescriptor {
221        label: "B = arrayfun(func, A1, An...)",
222        inputs: &ARRAYFUN_INPUTS_BASE,
223        outputs: &ARRAYFUN_OUTPUT,
224    },
225    BuiltinSignatureDescriptor {
226        label: "B = arrayfun(func, A1, An..., \"UniformOutput\", tf)",
227        inputs: &ARRAYFUN_INPUTS_UNIFORM,
228        outputs: &ARRAYFUN_OUTPUT,
229    },
230    BuiltinSignatureDescriptor {
231        label: "B = arrayfun(func, A1, An..., \"ErrorHandler\", handler)",
232        inputs: &ARRAYFUN_INPUTS_HANDLER,
233        outputs: &ARRAYFUN_OUTPUT,
234    },
235    BuiltinSignatureDescriptor {
236        label: "B = arrayfun(func, A1, An..., nameValue...)",
237        inputs: &ARRAYFUN_INPUTS_OPTIONS,
238        outputs: &ARRAYFUN_OUTPUT,
239    },
240];
241
/// Argument validation failure: bad callable, malformed option tail, or bad inputs.
const ARRAYFUN_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.INVALID_INPUT",
    identifier: Some("RunMat:arrayfun:InvalidInput"),
    when: "Inputs, callable forms, or option tails violate arrayfun argument requirements.",
    message: "arrayfun: invalid input arguments",
};

/// The value given for `UniformOutput` cannot be read as true/false.
const ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.UNIFORM_OUTPUT_OPTION",
    identifier: Some("RunMat:arrayfun:UniformOutputOption"),
    when: "UniformOutput option value is not interpretable as logical true/false.",
    message: "arrayfun: UniformOutput must be logical true or false",
};

/// A callback failed and no `ErrorHandler` recovered the element.
const ARRAYFUN_ERROR_CALLBACK_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.CALLBACK_FAILED",
    identifier: Some("RunMat:arrayfun:CallbackFailed"),
    when: "Callback invocation fails and no ErrorHandler recovers the element.",
    message: "arrayfun: callback execution failed",
};

/// A callback result cannot be collected into a uniform output.
const ARRAYFUN_ERROR_UNIFORM_OUTPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.UNIFORM_OUTPUT_TYPE",
    identifier: Some("RunMat:arrayfun:UniformOutputType"),
    when: "UniformOutput=true callback result is not a supported scalar type.",
    message: "arrayfun: callback must return scalar values for UniformOutput=true",
};

/// Failures of internal shape/index/materialization/upload steps (not caller errors).
const ARRAYFUN_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.INTERNAL",
    identifier: Some("RunMat:arrayfun:InternalError"),
    when: "Internal shape/index/materialization/upload path fails.",
    message: "arrayfun: internal error",
};

/// An external callable could not be resolved. This deliberately uses the shared
/// `RunMat:UndefinedFunction` identifier rather than an `arrayfun`-scoped one.
const ARRAYFUN_ERROR_UNDEFINED_FUNCTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.UNDEFINED_FUNCTION",
    identifier: Some("RunMat:UndefinedFunction"),
    when: "External callable identity cannot be resolved in semantic/runtime boundaries.",
    message: "arrayfun: undefined function",
};

/// All error descriptors advertised by `ARRAYFUN_DESCRIPTOR`.
const ARRAYFUN_ERRORS: [BuiltinErrorDescriptor; 6] = [
    ARRAYFUN_ERROR_INVALID_INPUT,
    ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION,
    ARRAYFUN_ERROR_CALLBACK_FAILED,
    ARRAYFUN_ERROR_UNIFORM_OUTPUT_TYPE,
    ARRAYFUN_ERROR_INTERNAL,
    ARRAYFUN_ERROR_UNDEFINED_FUNCTION,
];

/// Public builtin descriptor, referenced by the `descriptor(...)` argument of the
/// `#[runtime_builtin]` attribute on `arrayfun_builtin`.
pub const ARRAYFUN_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &ARRAYFUN_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &ARRAYFUN_ERRORS,
};
299
300fn arrayfun_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
301    arrayfun_error_with_message(error.message, error)
302}
303
304fn arrayfun_error_with_message(
305    message: impl Into<String>,
306    error: &'static BuiltinErrorDescriptor,
307) -> RuntimeError {
308    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
309    if let Some(identifier) = error.identifier {
310        builder = builder.with_identifier(identifier);
311    }
312    builder.build()
313}
314
315fn arrayfun_error_with_detail(
316    error: &'static BuiltinErrorDescriptor,
317    detail: impl AsRef<str>,
318) -> RuntimeError {
319    arrayfun_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
320}
321
322fn arrayfun_error_with_source(
323    message: impl Into<String>,
324    error: &'static BuiltinErrorDescriptor,
325    source: RuntimeError,
326) -> RuntimeError {
327    let identifier = source.identifier().map(str::to_string);
328    let mut builder = build_runtime_error(message.into())
329        .with_builtin(BUILTIN_NAME)
330        .with_source(source);
331    if let Some(identifier) = identifier.as_deref().or(error.identifier) {
332        builder = builder.with_identifier(identifier);
333    }
334    builder.build()
335}
336
337fn arrayfun_flow(message: impl Into<String>) -> RuntimeError {
338    arrayfun_error_with_message(message, &ARRAYFUN_ERROR_INVALID_INPUT)
339}
340
341fn arrayfun_internal(message: impl Into<String>) -> RuntimeError {
342    arrayfun_error_with_message(message, &ARRAYFUN_ERROR_INTERNAL)
343}
344
345fn arrayfun_flow_with_source(message: impl Into<String>, source: RuntimeError) -> RuntimeError {
346    arrayfun_error_with_source(message, &ARRAYFUN_ERROR_CALLBACK_FAILED, source)
347}
348
349fn format_handler_error(err: &RuntimeError) -> String {
350    if let Some(identifier) = err.identifier() {
351        if err.message().is_empty() {
352            return identifier.to_string();
353        }
354        if err.message().starts_with(identifier) {
355            return err.message().to_string();
356        }
357        return format!("{identifier}: {}", err.message());
358    }
359    err.message().to_string()
360}
361
362#[runtime_builtin(
363    name = "arrayfun",
364    category = "acceleration/gpu",
365    summary = "Apply a function element-wise across array inputs.",
366    keywords = "arrayfun,gpu,array,map,functional",
367    accel = "host",
368    type_resolver(arrayfun_type),
369    descriptor(crate::builtins::acceleration::gpu::arrayfun::ARRAYFUN_DESCRIPTOR),
370    builtin_path = "crate::builtins::acceleration::gpu::arrayfun"
371)]
372async fn arrayfun_builtin(func: Value, mut rest: Vec<Value>) -> crate::BuiltinResult<Value> {
373    let callable = Callable::from_function(func)?;
374
375    let mut uniform_output = true;
376    let mut error_handler: Option<Callable> = None;
377
378    while rest.len() >= 2 {
379        let key_candidate = rest[rest.len() - 2].clone();
380        let Some(name) = extract_string(&key_candidate) else {
381            break;
382        };
383        let value = rest.pop().expect("value present");
384        rest.pop();
385        match name.trim().to_ascii_lowercase().as_str() {
386            "uniformoutput" => uniform_output = parse_uniform_output(value)?,
387            "errorhandler" => error_handler = Some(Callable::from_function(value)?),
388            other => {
389                return Err(arrayfun_flow(format!(
390                    "arrayfun: unknown name-value argument '{other}'"
391                )))
392            }
393        }
394    }
395
396    if rest.is_empty() {
397        return Err(arrayfun_flow("arrayfun: expected at least one input array"));
398    }
399
400    let inputs_snapshot = rest.clone();
401    let has_gpu_input = inputs_snapshot
402        .iter()
403        .any(|value| matches!(value, Value::GpuTensor(_)));
404    let gpu_device_id = inputs_snapshot.iter().find_map(|v| {
405        if let Value::GpuTensor(h) = v {
406            Some(h.device_id)
407        } else {
408            None
409        }
410    });
411
412    if uniform_output {
413        if let Some(gpu_result) =
414            try_gpu_fast_path(&callable, &inputs_snapshot, error_handler.as_ref()).await?
415        {
416            return Ok(gpu_result);
417        }
418    }
419
420    let mut inputs: Vec<ArrayInput> = Vec::with_capacity(rest.len());
421    let mut base_shape: Vec<usize> = Vec::new();
422    let mut base_len: Option<usize> = None;
423
424    for (idx, raw) in rest.into_iter().enumerate() {
425        if matches!(raw, Value::Cell(_)) {
426            return Err(arrayfun_flow(
427                "arrayfun: cell inputs are not supported (use cellfun instead)",
428            ));
429        }
430        if matches!(raw, Value::Struct(_)) {
431            return Err(arrayfun_flow("arrayfun: struct inputs are not supported"));
432        }
433
434        let host_value = gather_if_needed_async(&raw).await?;
435        let data = ArrayData::from_value(host_value)?;
436        let len = data.len();
437        let is_scalar = len == 1;
438
439        let mut input = ArrayInput { data, is_scalar };
440
441        if let Some(current) = base_len {
442            if current == len {
443                if len > 1 {
444                    let shape = input.shape_vec();
445                    if shape != base_shape {
446                        return Err(arrayfun_flow(format!(
447                            "arrayfun: input {} does not match the size of the first array",
448                            idx + 1
449                        )));
450                    }
451                }
452            } else if len == 1 {
453                input.is_scalar = true;
454            } else if current == 1 {
455                base_len = Some(len);
456                base_shape = input.shape_vec();
457                for prior in &mut inputs {
458                    let prior_len = prior.len();
459                    if prior_len == len {
460                        if prior.shape_vec() != base_shape {
461                            return Err(arrayfun_flow(format!(
462                                "arrayfun: input {} does not match the size of the first array",
463                                idx
464                            )));
465                        }
466                    } else if prior_len == 1 {
467                        prior.is_scalar = true;
468                    } else if prior_len == 0 && len == 0 {
469                        continue;
470                    } else {
471                        return Err(arrayfun_flow(format!(
472                            "arrayfun: input {} does not match the size of the first array",
473                            idx
474                        )));
475                    }
476                }
477            } else if len == 0 && current == 0 {
478                let shape = input.shape_vec();
479                if shape != base_shape {
480                    return Err(arrayfun_flow(format!(
481                        "arrayfun: input {} does not match the size of the first array",
482                        idx + 1
483                    )));
484                }
485            } else {
486                return Err(arrayfun_flow(format!(
487                    "arrayfun: input {} does not match the size of the first array",
488                    idx + 1
489                )));
490            }
491        } else {
492            base_len = Some(len);
493            base_shape = input.shape_vec();
494        }
495
496        inputs.push(input);
497    }
498
499    let total_len = base_len.unwrap_or(0);
500
501    if total_len == 0 {
502        if uniform_output {
503            return Ok(empty_uniform(&base_shape));
504        } else {
505            return make_cell_with_shape(Vec::new(), base_shape)
506                .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")));
507        }
508    }
509
510    let mut collector = if uniform_output {
511        Some(UniformCollector::Pending)
512    } else {
513        None
514    };
515
516    let mut cell_outputs: Vec<Value> = Vec::new();
517    let mut args: Vec<Value> = Vec::with_capacity(inputs.len());
518
519    for idx in 0..total_len {
520        args.clear();
521        for input in &inputs {
522            args.push(input.value_at(idx)?);
523        }
524
525        let result = match callable.call(&args).await {
526            Ok(value) => value,
527            Err(err) => {
528                let handler = match error_handler.as_ref() {
529                    Some(handler) => handler,
530                    None => {
531                        return Err(arrayfun_flow_with_source(
532                            format!("arrayfun: {}", err.message()),
533                            err,
534                        ))
535                    }
536                };
537                let err_message = format_handler_error(&err);
538                let err_value = make_error_struct(&err_message, idx, &base_shape)?;
539                let mut handler_args = Vec::with_capacity(1 + args.len());
540                handler_args.push(err_value);
541                handler_args.extend(args.clone());
542                handler.call(&handler_args).await?
543            }
544        };
545
546        let host_result = gather_if_needed_async(&result).await?;
547
548        if let Some(collector) = collector.as_mut() {
549            collector.push(&host_result)?;
550        } else {
551            cell_outputs.push(host_result);
552        }
553    }
554
555    if let Some(collector) = collector {
556        let uniform = collector.finish(&base_shape)?;
557        maybe_upload_uniform(uniform, has_gpu_input, gpu_device_id)
558    } else {
559        make_cell_with_shape(cell_outputs, base_shape)
560            .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))
561    }
562}
563
564fn maybe_upload_uniform(
565    value: Value,
566    has_gpu_input: bool,
567    gpu_device_id: Option<u32>,
568) -> BuiltinResult<Value> {
569    if !has_gpu_input {
570        return Ok(value);
571    }
572    #[cfg(all(test, feature = "wgpu"))]
573    {
574        if matches!(gpu_device_id, Some(id) if id != 0) {
575            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
576                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
577            );
578        }
579    }
580    let _ = gpu_device_id; // may be used only in cfg(test)
581    let provider = match runmat_accelerate_api::provider() {
582        Some(p) => p,
583        None => return Ok(value),
584    };
585
586    match value {
587        Value::Tensor(tensor) => {
588            let view = HostTensorView {
589                data: &tensor.data,
590                shape: &tensor.shape,
591            };
592            let handle = provider
593                .upload(&view)
594                .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
595            Ok(Value::GpuTensor(handle))
596        }
597        Value::LogicalArray(logical) => {
598            let data: Vec<f64> = logical
599                .data
600                .iter()
601                .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
602                .collect();
603            let tensor = Tensor::new(data, logical.shape.clone())
604                .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
605            let view = HostTensorView {
606                data: &tensor.data,
607                shape: &tensor.shape,
608            };
609            let handle = provider
610                .upload(&view)
611                .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
612            set_handle_logical(&handle, true);
613            Ok(Value::GpuTensor(handle))
614        }
615        other => Ok(other),
616    }
617}
618
619fn empty_uniform(shape: &[usize]) -> Value {
620    if shape.is_empty() {
621        return Value::Tensor(Tensor::zeros(vec![0, 0]));
622    }
623    let total: usize = shape.iter().product();
624    let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
625        .unwrap_or_else(|_| Tensor::zeros(shape.to_vec()));
626    Value::Tensor(tensor)
627}
628
629fn parse_uniform_output(value: Value) -> BuiltinResult<bool> {
630    match value {
631        Value::Bool(b) => Ok(b),
632        Value::Num(n) => Ok(n != 0.0),
633        Value::Int(iv) => Ok(iv.to_f64() != 0.0),
634        Value::String(s) => parse_bool_string(&s)
635            .ok_or_else(|| arrayfun_error(&ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION)),
636        Value::CharArray(ca) if ca.rows == 1 => {
637            let text: String = ca.data.iter().collect();
638            parse_bool_string(&text)
639                .ok_or_else(|| arrayfun_error(&ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION))
640        }
641        other => Err(arrayfun_error_with_detail(
642            &ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION,
643            format!("got {other:?}"),
644        )),
645    }
646}
647
/// Maps MATLAB-style textual booleans to `bool`.
///
/// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
/// `true`/`on` give `Some(true)`, `false`/`off` give `Some(false)`, and anything
/// else gives `None`.
fn parse_bool_string(value: &str) -> Option<bool> {
    const TRUTHY: [&str; 2] = ["true", "on"];
    const FALSY: [&str; 2] = ["false", "off"];
    let word = value.trim();
    let matches_any = |options: &[&str]| {
        options
            .iter()
            .any(|candidate| word.eq_ignore_ascii_case(candidate))
    };
    if matches_any(&TRUTHY) {
        Some(true)
    } else if matches_any(&FALSY) {
        Some(false)
    } else {
        None
    }
}
655
656fn extract_string(value: &Value) -> Option<String> {
657    match value {
658        Value::String(s) => Some(s.clone()),
659        Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
660        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
661        _ => None,
662    }
663}
664
665struct ArrayInput {
666    data: ArrayData,
667    is_scalar: bool,
668}
669
670impl ArrayInput {
671    fn len(&self) -> usize {
672        self.data.len()
673    }
674
675    fn shape_vec(&self) -> Vec<usize> {
676        self.data.shape_vec()
677    }
678
679    fn value_at(&self, idx: usize) -> BuiltinResult<Value> {
680        if self.is_scalar {
681            self.data.value_at(0)
682        } else {
683            self.data.value_at(idx)
684        }
685    }
686}
687
688enum ArrayData {
689    Tensor(Tensor),
690    Logical(LogicalArray),
691    Complex(ComplexTensor),
692    Char(CharArray),
693    String(StringArray),
694    Scalar(Value),
695}
696
697impl ArrayData {
698    fn from_value(value: Value) -> BuiltinResult<Self> {
699        match value {
700            Value::Tensor(t) => Ok(ArrayData::Tensor(t)),
701            Value::LogicalArray(l) => Ok(ArrayData::Logical(l)),
702            Value::ComplexTensor(c) => Ok(ArrayData::Complex(c)),
703            Value::CharArray(ca) => Ok(ArrayData::Char(ca)),
704            Value::StringArray(sa) => Ok(ArrayData::String(sa)),
705            Value::Num(_)
706            | Value::Bool(_)
707            | Value::Int(_)
708            | Value::Complex(_, _)
709            | Value::String(_) => {
710                Ok(ArrayData::Scalar(value))
711            }
712            other => Err(arrayfun_flow(format!(
713                "arrayfun: unsupported input type {other:?} (expected numeric, logical, complex, char, or string arrays)"
714            ))),
715        }
716    }
717
718    fn len(&self) -> usize {
719        match self {
720            ArrayData::Tensor(t) => t.data.len(),
721            ArrayData::Logical(l) => l.data.len(),
722            ArrayData::Complex(c) => c.data.len(),
723            ArrayData::Char(ca) => ca.rows * ca.cols,
724            ArrayData::String(sa) => sa.data.len(),
725            ArrayData::Scalar(_) => 1,
726        }
727    }
728
729    fn shape_vec(&self) -> Vec<usize> {
730        match self {
731            ArrayData::Tensor(t) => {
732                if t.shape.is_empty() {
733                    vec![1, 1]
734                } else {
735                    t.shape.clone()
736                }
737            }
738            ArrayData::Logical(l) => {
739                if l.shape.is_empty() {
740                    vec![1, 1]
741                } else {
742                    l.shape.clone()
743                }
744            }
745            ArrayData::Complex(c) => {
746                if c.shape.is_empty() {
747                    vec![1, 1]
748                } else {
749                    c.shape.clone()
750                }
751            }
752            ArrayData::Char(ca) => vec![ca.rows, ca.cols],
753            ArrayData::String(sa) => {
754                if sa.shape.is_empty() {
755                    vec![1, 1]
756                } else {
757                    sa.shape.clone()
758                }
759            }
760            ArrayData::Scalar(_) => vec![1, 1],
761        }
762    }
763
764    fn value_at(&self, idx: usize) -> BuiltinResult<Value> {
765        match self {
766            ArrayData::Tensor(t) => {
767                Ok(Value::Num(*t.data.get(idx).ok_or_else(|| {
768                    arrayfun_flow("arrayfun: index out of bounds")
769                })?))
770            }
771            ArrayData::Logical(l) => Ok(Value::Bool(
772                *l.data
773                    .get(idx)
774                    .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?
775                    != 0,
776            )),
777            ArrayData::Complex(c) => {
778                let (re, im) = c
779                    .data
780                    .get(idx)
781                    .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?;
782                Ok(Value::Complex(*re, *im))
783            }
784            ArrayData::Char(ca) => {
785                if ca.rows == 0 || ca.cols == 0 {
786                    return Ok(Value::CharArray(
787                        CharArray::new(Vec::new(), 0, 0)
788                            .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?,
789                    ));
790                }
791                let rows = ca.rows;
792                let cols = ca.cols;
793                let row = idx % rows;
794                let col = idx / rows;
795                let data_idx = row * cols + col;
796                let ch = *ca
797                    .data
798                    .get(data_idx)
799                    .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?;
800                let char_array = CharArray::new(vec![ch], 1, 1)
801                    .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
802                Ok(Value::CharArray(char_array))
803            }
804            ArrayData::String(sa) => {
805                Ok(Value::String(sa.data.get(idx).cloned().ok_or_else(
806                    || arrayfun_flow("arrayfun: index out of bounds"),
807                )?))
808            }
809            ArrayData::Scalar(v) => Ok(v.clone()),
810        }
811    }
812}
813
/// An `arrayfun` callback (or `ErrorHandler`) after argument resolution.
#[derive(Clone)]
enum Callable {
    /// A callable kept by plain name. For external handles this is chosen when no
    /// semantic function resolves and the name is not a well-formed qualified name.
    Builtin { name: String },
    /// An external handle whose name is a well-formed qualified name.
    ExternalName { name: String },
    /// A closure, possibly bound to a resolved semantic function.
    Closure(Closure),
}
820
821impl Callable {
822    fn resolved_semantic_handle(name: &str) -> Option<Self> {
823        let function = user_functions::resolve_semantic_function_by_name(name)?;
824        Some(Callable::Closure(Closure {
825            function_name: name.to_string(),
826            bound_function: Some(function),
827            captures: Vec::new(),
828        }))
829    }
830
831    fn from_function(value: Value) -> BuiltinResult<Self> {
832        match value {
833            Value::String(text) => Self::from_text(&text),
834            Value::CharArray(ca) => {
835                if ca.rows != 1 {
836                    Err(arrayfun_flow(
837                        "arrayfun: function name must be a character vector or string scalar",
838                    ))
839                } else {
840                    let text: String = ca.data.iter().collect();
841                    Self::from_text(&text)
842                }
843            }
844            Value::StringArray(sa) if sa.data.len() == 1 => Self::from_text(&sa.data[0]),
845            Value::FunctionHandle(name) => Self::from_text(&name),
846            Value::ExternalFunctionHandle(name) => {
847                if let Some(callable) = Self::resolved_semantic_handle(&name) {
848                    Ok(callable)
849                } else if crate::is_well_formed_qualified_name(&name) {
850                    Ok(Callable::ExternalName { name })
851                } else {
852                    Ok(Callable::Builtin { name })
853                }
854            }
855            Value::BoundFunctionHandle { name, function } => Ok(Callable::Closure(Closure {
856                function_name: name,
857                bound_function: Some(function),
858                captures: Vec::new(),
859            })),
860            Value::Closure(mut closure) => {
861                if closure.bound_function.is_none() {
862                    if let Some(function) =
863                        user_functions::resolve_semantic_function_by_name(&closure.function_name)
864                    {
865                        closure.bound_function = Some(function);
866                    }
867                }
868                Ok(Callable::Closure(closure))
869            }
870            Value::Num(_) | Value::Int(_) | Value::Bool(_) => Err(arrayfun_flow(
871                "arrayfun: expected function handle or builtin name, not a scalar value",
872            )),
873            other => Err(arrayfun_flow(format!(
874                "arrayfun: expected function handle or builtin name, got {other:?}"
875            ))),
876        }
877    }
878
879    fn from_text(text: &str) -> BuiltinResult<Self> {
880        let trimmed = text.trim();
881        if trimmed.is_empty() {
882            return Err(arrayfun_flow(
883                "arrayfun: expected function handle or builtin name, got empty string",
884            ));
885        }
886        if let Some(rest) = trimmed.strip_prefix('@') {
887            let name = rest.trim();
888            if name.is_empty() {
889                Err(arrayfun_flow("arrayfun: empty function handle"))
890            } else {
891                if let Some(callable) = Self::resolved_semantic_handle(name) {
892                    return Ok(callable);
893                }
894                if crate::is_well_formed_qualified_name(name) {
895                    return Ok(Callable::ExternalName {
896                        name: name.to_string(),
897                    });
898                }
899                Ok(Callable::Builtin {
900                    name: name.to_string(),
901                })
902            }
903        } else {
904            let name = trimmed.to_ascii_lowercase();
905            if let Some(callable) = Self::resolved_semantic_handle(&name) {
906                return Ok(callable);
907            }
908            if crate::is_well_formed_qualified_name(&name) {
909                return Ok(Callable::ExternalName { name });
910            }
911            Ok(Callable::Builtin { name })
912        }
913    }
914
915    fn builtin_name(&self) -> Option<&str> {
916        match self {
917            Callable::Builtin { name } => Some(name.as_str()),
918            Callable::ExternalName { .. } | Callable::Closure(_) => None,
919        }
920    }
921
922    async fn call(&self, args: &[Value]) -> crate::BuiltinResult<Value> {
923        match self {
924            Callable::Builtin { name } => {
925                let request = user_functions::CallableRequest::resolved(
926                    runmat_hir::CallableIdentity::DynamicName(runmat_hir::SymbolName(name.clone())),
927                    runmat_hir::CallableFallbackPolicy::RuntimeNameResolution,
928                    args.to_vec(),
929                    1,
930                );
931                if let Some(result) = user_functions::try_call_semantic_descriptor(request).await {
932                    return result;
933                }
934                crate::call_builtin_async(name, args).await
935            }
936            Callable::ExternalName { name } => {
937                let identity = crate::external_callable_identity_for_name(name);
938                let request = user_functions::CallableRequest::resolved(
939                    identity.clone(),
940                    runmat_hir::CallableFallbackPolicy::ExternalBoundary,
941                    args.to_vec(),
942                    1,
943                );
944                if let Some(result) = user_functions::try_call_semantic_descriptor(request).await {
945                    return result;
946                }
947                Err(arrayfun_error_with_message(
948                    format!("Undefined function for callable identity {identity:?}"),
949                    &ARRAYFUN_ERROR_UNDEFINED_FUNCTION,
950                ))
951            }
952            Callable::Closure(c) => {
953                let mut merged = c.captures.clone();
954                merged.extend_from_slice(args);
955                if let Some(function) = c.bound_function {
956                    let request =
957                        user_functions::CallableRequest::semantic(function, merged.clone(), 1);
958                    if let Some(result) =
959                        user_functions::try_call_semantic_descriptor(request).await
960                    {
961                        return result;
962                    }
963                    return Err(arrayfun_error_with_detail(
964                        &ARRAYFUN_ERROR_CALLBACK_FAILED,
965                        format!(
966                            "semantic closure '{}' ({function}) is unavailable",
967                            c.function_name
968                        ),
969                    ));
970                }
971                if let Some(function) =
972                    user_functions::resolve_semantic_function_by_name(&c.function_name)
973                {
974                    let request =
975                        user_functions::CallableRequest::semantic(function, merged.clone(), 1);
976                    if let Some(result) =
977                        user_functions::try_call_semantic_descriptor(request).await
978                    {
979                        return result;
980                    }
981                }
982                crate::call_builtin_async(&c.function_name, &merged).await
983            }
984        }
985    }
986}
987
988async fn try_gpu_fast_path(
989    callable: &Callable,
990    inputs: &[Value],
991    error_handler: Option<&Callable>,
992) -> BuiltinResult<Option<Value>> {
993    if inputs.is_empty() || error_handler.is_some() {
994        return Ok(None);
995    }
996    if !inputs
997        .iter()
998        .all(|value| matches!(value, Value::GpuTensor(_)))
999    {
1000        return Ok(None);
1001    }
1002
1003    #[cfg(all(test, feature = "wgpu"))]
1004    {
1005        if inputs
1006            .iter()
1007            .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
1008        {
1009            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1010                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1011            );
1012        }
1013    }
1014    let provider = match runmat_accelerate_api::provider() {
1015        Some(p) => p,
1016        None => return Ok(None),
1017    };
1018
1019    let Some(name_raw) = callable.builtin_name() else {
1020        return Ok(None);
1021    };
1022    let name = name_raw.to_ascii_lowercase();
1023
1024    let mut handles: Vec<GpuTensorHandle> = Vec::with_capacity(inputs.len());
1025    for value in inputs {
1026        if let Value::GpuTensor(handle) = value {
1027            handles.push(handle.clone());
1028        }
1029    }
1030
1031    if handles.len() >= 2 {
1032        let base_shape = handles[0].shape.clone();
1033        if handles
1034            .iter()
1035            .skip(1)
1036            .any(|handle| handle.shape != base_shape)
1037        {
1038            return Ok(None);
1039        }
1040    }
1041
1042    let result = match name.as_str() {
1043        "sin" if handles.len() == 1 => provider.unary_sin(&handles[0]).await,
1044        "cos" if handles.len() == 1 => provider.unary_cos(&handles[0]).await,
1045        "abs" if handles.len() == 1 => provider.unary_abs(&handles[0]).await,
1046        "exp" if handles.len() == 1 => provider.unary_exp(&handles[0]).await,
1047        "log" if handles.len() == 1 => provider.unary_log(&handles[0]).await,
1048        "sqrt" if handles.len() == 1 => provider.unary_sqrt(&handles[0]).await,
1049        "plus" if handles.len() == 2 => provider.elem_add(&handles[0], &handles[1]).await,
1050        "minus" if handles.len() == 2 => provider.elem_sub(&handles[0], &handles[1]).await,
1051        "times" if handles.len() == 2 => provider.elem_mul(&handles[0], &handles[1]).await,
1052        "rdivide" if handles.len() == 2 => provider.elem_div(&handles[0], &handles[1]).await,
1053        "ldivide" if handles.len() == 2 => provider.elem_div(&handles[1], &handles[0]).await,
1054        _ => return Ok(None),
1055    };
1056
1057    match result {
1058        Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
1059        Err(_) => Ok(None),
1060    }
1061}
1062
/// Accumulates scalar callback results for `UniformOutput=true`.
///
/// The active variant records the widest element class seen so far: `push` promotes
/// the stored data when a wider class arrives, and `finish` shapes it into the output.
enum UniformCollector {
    /// No value has been pushed yet.
    Pending,
    /// Real values; also the target when logical, double and char results are mixed.
    Double(Vec<f64>),
    /// Logical values stored as 0/1 bytes.
    Logical(Vec<u8>),
    /// Complex values as `(re, im)` pairs.
    Complex(Vec<(f64, f64)>),
    /// Character results, used while every pushed value has been a char.
    Char(Vec<char>),
}
1070
1071impl UniformCollector {
1072    fn push(&mut self, value: &Value) -> BuiltinResult<()> {
1073        match self {
1074            UniformCollector::Pending => match classify_value(value)? {
1075                ClassifiedValue::Logical(b) => {
1076                    *self = UniformCollector::Logical(vec![b as u8]);
1077                    Ok(())
1078                }
1079                ClassifiedValue::Double(d) => {
1080                    *self = UniformCollector::Double(vec![d]);
1081                    Ok(())
1082                }
1083                ClassifiedValue::Complex(c) => {
1084                    *self = UniformCollector::Complex(vec![c]);
1085                    Ok(())
1086                }
1087                ClassifiedValue::Char(ch) => {
1088                    *self = UniformCollector::Char(vec![ch]);
1089                    Ok(())
1090                }
1091            },
1092            UniformCollector::Logical(bits) => match classify_value(value)? {
1093                ClassifiedValue::Logical(b) => {
1094                    bits.push(b as u8);
1095                    Ok(())
1096                }
1097                ClassifiedValue::Double(d) => {
1098                    let mut data: Vec<f64> = bits
1099                        .iter()
1100                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
1101                        .collect();
1102                    data.push(d);
1103                    *self = UniformCollector::Double(data);
1104                    Ok(())
1105                }
1106                ClassifiedValue::Complex(c) => {
1107                    let mut data: Vec<(f64, f64)> = bits
1108                        .iter()
1109                        .map(|&bit| if bit != 0 { (1.0, 0.0) } else { (0.0, 0.0) })
1110                        .collect();
1111                    data.push(c);
1112                    *self = UniformCollector::Complex(data);
1113                    Ok(())
1114                }
1115                ClassifiedValue::Char(ch) => {
1116                    let mut data: Vec<f64> = bits
1117                        .iter()
1118                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
1119                        .collect();
1120                    data.push(ch as u32 as f64);
1121                    *self = UniformCollector::Double(data);
1122                    Ok(())
1123                }
1124            },
1125            UniformCollector::Double(data) => match classify_value(value)? {
1126                ClassifiedValue::Logical(b) => {
1127                    data.push(if b { 1.0 } else { 0.0 });
1128                    Ok(())
1129                }
1130                ClassifiedValue::Double(d) => {
1131                    data.push(d);
1132                    Ok(())
1133                }
1134                ClassifiedValue::Complex(c) => {
1135                    let promoted: Vec<(f64, f64)> = data.iter().map(|&v| (v, 0.0)).collect();
1136                    let mut complex = promoted;
1137                    complex.push(c);
1138                    *self = UniformCollector::Complex(complex);
1139                    Ok(())
1140                }
1141                ClassifiedValue::Char(ch) => {
1142                    data.push(ch as u32 as f64);
1143                    Ok(())
1144                }
1145            },
1146            UniformCollector::Complex(data) => match classify_value(value)? {
1147                ClassifiedValue::Logical(b) => {
1148                    data.push((if b { 1.0 } else { 0.0 }, 0.0));
1149                    Ok(())
1150                }
1151                ClassifiedValue::Double(d) => {
1152                    data.push((d, 0.0));
1153                    Ok(())
1154                }
1155                ClassifiedValue::Complex(c) => {
1156                    data.push(c);
1157                    Ok(())
1158                }
1159                ClassifiedValue::Char(ch) => {
1160                    data.push((ch as u32 as f64, 0.0));
1161                    Ok(())
1162                }
1163            },
1164            UniformCollector::Char(chars) => match classify_value(value)? {
1165                ClassifiedValue::Char(ch) => {
1166                    chars.push(ch);
1167                    Ok(())
1168                }
1169                ClassifiedValue::Logical(b) => {
1170                    let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
1171                    data.push(if b { 1.0 } else { 0.0 });
1172                    *self = UniformCollector::Double(data);
1173                    Ok(())
1174                }
1175                ClassifiedValue::Double(d) => {
1176                    let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
1177                    data.push(d);
1178                    *self = UniformCollector::Double(data);
1179                    Ok(())
1180                }
1181                ClassifiedValue::Complex(c) => {
1182                    let mut promoted: Vec<(f64, f64)> =
1183                        chars.iter().map(|&ch| (ch as u32 as f64, 0.0)).collect();
1184                    promoted.push(c);
1185                    *self = UniformCollector::Complex(promoted);
1186                    Ok(())
1187                }
1188            },
1189        }
1190    }
1191
1192    fn finish(self, shape: &[usize]) -> BuiltinResult<Value> {
1193        match self {
1194            UniformCollector::Pending => {
1195                let total = shape.iter().product();
1196                let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
1197                    .map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))?;
1198                Ok(Value::Tensor(tensor))
1199            }
1200            UniformCollector::Double(data) => {
1201                let tensor = Tensor::new(data, shape.to_vec())
1202                    .map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))?;
1203                Ok(Value::Tensor(tensor))
1204            }
1205            UniformCollector::Logical(bits) => {
1206                let logical = LogicalArray::new(bits, shape.to_vec())
1207                    .map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))?;
1208                Ok(Value::LogicalArray(logical))
1209            }
1210            UniformCollector::Complex(entries) => {
1211                let tensor = ComplexTensor::new(entries, shape.to_vec())
1212                    .map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))?;
1213                Ok(Value::ComplexTensor(tensor))
1214            }
1215            UniformCollector::Char(chars) => {
1216                let normalized_shape = if shape.is_empty() {
1217                    vec![1, 1]
1218                } else {
1219                    shape.to_vec()
1220                };
1221
1222                if normalized_shape.len() > 2 {
1223                    return Err(arrayfun_error_with_detail(
1224                        &ARRAYFUN_ERROR_UNIFORM_OUTPUT_TYPE,
1225                        "character outputs with UniformOutput=true must be 2-D",
1226                    ));
1227                }
1228
1229                let rows = normalized_shape.first().copied().unwrap_or(1);
1230                let cols = normalized_shape.get(1).copied().unwrap_or(1);
1231                let expected = rows.checked_mul(cols).ok_or_else(|| {
1232                    arrayfun_internal("arrayfun: character output size exceeds platform limits")
1233                })?;
1234
1235                if expected != chars.len() {
1236                    return Err(arrayfun_error_with_detail(
1237                        &ARRAYFUN_ERROR_UNIFORM_OUTPUT_TYPE,
1238                        "callback returned the wrong number of characters",
1239                    ));
1240                }
1241
1242                let mut row_major = vec!['\0'; expected];
1243                for col in 0..cols {
1244                    for row in 0..rows {
1245                        let col_major_idx = row + col * rows;
1246                        let row_major_idx = row * cols + col;
1247                        row_major[row_major_idx] = chars[col_major_idx];
1248                    }
1249                }
1250
1251                let array = CharArray::new(row_major, rows, cols)
1252                    .map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))?;
1253                Ok(Value::CharArray(array))
1254            }
1255        }
1256    }
1257}
1258
/// A single callback result reduced to a scalar, ready for uniform-output collection.
///
/// The derives make the value printable in diagnostics and comparable/copyable in
/// tests; every payload is a plain `Copy` scalar.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ClassifiedValue {
    /// Logical scalar.
    Logical(bool),
    /// Real scalar (doubles, and integers widened to double).
    Double(f64),
    /// Complex scalar as `(re, im)`.
    Complex((f64, f64)),
    /// Single character.
    Char(char),
}
1265
1266fn classify_value(value: &Value) -> BuiltinResult<ClassifiedValue> {
1267    match value {
1268        Value::Bool(b) => Ok(ClassifiedValue::Logical(*b)),
1269        Value::LogicalArray(la) if la.len() == 1 => Ok(ClassifiedValue::Logical(la.data[0] != 0)),
1270        Value::Int(i) => Ok(ClassifiedValue::Double(i.to_f64())),
1271        Value::Num(n) => Ok(ClassifiedValue::Double(*n)),
1272        Value::Tensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Double(t.data[0])),
1273        Value::Complex(re, im) => Ok(ClassifiedValue::Complex((*re, *im))),
1274        Value::ComplexTensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Complex(t.data[0])),
1275        Value::CharArray(ca) if ca.rows * ca.cols == 1 => {
1276            let ch = ca.data.first().copied().unwrap_or('\0');
1277            Ok(ClassifiedValue::Char(ch))
1278        }
1279        other => Err(arrayfun_error_with_detail(
1280            &ARRAYFUN_ERROR_UNIFORM_OUTPUT_TYPE,
1281            format!(
1282                "callback must return scalar numeric, logical, character, or complex values for UniformOutput=true (got {other:?})"
1283            ),
1284        )),
1285    }
1286}
1287
1288fn make_error_struct(
1289    raw_error: &str,
1290    linear_index: usize,
1291    shape: &[usize],
1292) -> BuiltinResult<Value> {
1293    let (identifier, message) = split_error_message(raw_error);
1294    let mut st = runmat_builtins::StructValue::new();
1295    st.fields
1296        .insert("identifier".to_string(), Value::String(identifier));
1297    st.fields
1298        .insert("message".to_string(), Value::String(message));
1299    st.fields
1300        .insert("index".to_string(), Value::Num((linear_index + 1) as f64));
1301    let subs = linear_to_indices(linear_index, shape);
1302    let subs_tensor = dims_to_row_tensor(&subs)?;
1303    st.fields
1304        .insert("indices".to_string(), Value::Tensor(subs_tensor));
1305    Ok(Value::Struct(st))
1306}
1307
/// Splits a raw callback error string into `(identifier, message)`.
///
/// Recognised forms, checked on the trimmed text:
/// - `"<identifier>: <message>"` with at least two colons: the identifier is the
///   longest whitespace-free prefix that ends just before the second or a later colon.
///   Multi-component identifiers such as `MATLAB:narginchk:notEnoughInputs` therefore
///   stay whole, while colons inside the message are ignored. If the message part is
///   empty, the whole trimmed text is used as the message.
/// - A single-colon text starting with `matlab:` or `runmat:` (ASCII case-insensitive)
///   is treated as a bare identifier with an empty message.
///
/// Anything else is reported as `RunMat:arrayfun:FunctionError` with the trimmed text
/// as the message.
fn split_error_message(raw: &str) -> (String, String) {
    const FALLBACK_IDENTIFIER: &str = "RunMat:arrayfun:FunctionError";

    let trimmed = raw.trim();
    let colons: Vec<usize> = trimmed.match_indices(':').map(|(idx, _)| idx).collect();

    match colons.as_slice() {
        [] => {}
        [_] => {
            // Compare bytes: slicing the `str` at byte 7 panics when that offset falls
            // inside a multi-byte character (e.g. "abcde:é").
            let bytes = trimmed.as_bytes();
            if bytes.len() >= 7
                && (bytes[..7].eq_ignore_ascii_case(b"matlab:")
                    || bytes[..7].eq_ignore_ascii_case(b"runmat:"))
            {
                return (trimmed.to_string(), String::new());
            }
        }
        [_, candidates @ ..] => {
            // Each candidate colon ends a longer prefix; once a prefix contains whitespace
            // it cannot be an identifier, and neither can any longer prefix.
            let split = candidates
                .iter()
                .copied()
                .take_while(|&idx| !trimmed[..idx].trim().contains(char::is_whitespace))
                .last();
            if let Some(idx) = split {
                // ':' is ASCII, so `idx` and `idx + 1` are valid char boundaries.
                let identifier = trimmed[..idx].trim().to_string();
                let message = trimmed[idx + 1..].trim();
                let message = if message.is_empty() {
                    trimmed.to_string()
                } else {
                    message.to_string()
                };
                return (identifier, message);
            }
        }
    }

    (FALLBACK_IDENTIFIER.to_string(), trimmed.to_string())
}
1337
/// Converts a zero-based column-major linear `index` into one-based subscripts.
///
/// An empty `shape` yields `[1]`. A zero-length dimension reports subscript `1`
/// and consumes none of the index.
fn linear_to_indices(index: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![1];
    }
    let mut remainder = index;
    shape
        .iter()
        .map(|&extent| match extent {
            0 => 1,
            _ => {
                let subscript = remainder % extent + 1;
                remainder /= extent;
                subscript
            }
        })
        .collect()
}
1354
1355fn dims_to_row_tensor(dims: &[usize]) -> BuiltinResult<Tensor> {
1356    let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
1357    Tensor::new(data, vec![1, dims.len()]).map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))
1358}
1359
1360#[cfg(test)]
1361pub(crate) mod tests {
1362    use super::*;
1363    use crate::builtins::common::test_support;
1364    use futures::executor::block_on;
1365    use runmat_accelerate_api::HostTensorView;
1366    use runmat_builtins::{ResolveContext, Tensor, Type};
1367    use std::sync::Arc;
1368
1369    fn call(func: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
1370        block_on(arrayfun_builtin(func, rest))
1371    }
1372
1373    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1374    #[test]
1375    fn arrayfun_basic_sin() {
1376        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
1377        let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
1378        let result = call(
1379            Value::FunctionHandle("sin".to_string()),
1380            vec![Value::Tensor(tensor.clone())],
1381        )
1382        .expect("arrayfun");
1383        match result {
1384            Value::Tensor(out) => {
1385                assert_eq!(out.shape, vec![2, 3]);
1386                assert_eq!(out.data, expected);
1387            }
1388            other => panic!("expected tensor, got {other:?}"),
1389        }
1390    }
1391
1392    #[test]
1393    fn arrayfun_semantic_function_handle_uses_semantic_invoker() {
1394        let _guard = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
1395            |function, args, requested_outputs| {
1396                assert_eq!(function, 78);
1397                assert_eq!(requested_outputs, 1);
1398                let [Value::Num(value)] = args else {
1399                    panic!("expected scalar numeric argument, got {args:?}");
1400                };
1401                let value = *value;
1402                Box::pin(async move { Ok(Value::Num(value + 10.0)) })
1403            },
1404        )));
1405        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).expect("tensor");
1406        let handle = Value::BoundFunctionHandle {
1407            name: "arrayfun_target".to_string(),
1408            function: 78,
1409        };
1410
1411        let result = call(handle, vec![Value::Tensor(tensor)]).expect("semantic arrayfun");
1412        match result {
1413            Value::Tensor(out) => {
1414                assert_eq!(out.shape, vec![1, 2]);
1415                assert_eq!(out.data, vec![11.0, 12.0]);
1416            }
1417            other => panic!("expected tensor, got {other:?}"),
1418        }
1419    }
1420
1421    #[test]
1422    fn arrayfun_name_only_callback_uses_semantic_resolver() {
1423        let _resolver_guard =
1424            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
1425                (name == "resolved_arrayfun_target").then_some(80)
1426            })));
1427        let _invoker_guard = crate::user_functions::install_semantic_function_invoker(Some(
1428            Arc::new(|function, args, requested_outputs| {
1429                assert_eq!(function, 80);
1430                assert_eq!(requested_outputs, 1);
1431                let [Value::Num(value)] = args else {
1432                    panic!("expected scalar numeric argument, got {args:?}");
1433                };
1434                let value = *value;
1435                Box::pin(async move { Ok(Value::Num(value + 20.0)) })
1436            }),
1437        ));
1438        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).expect("tensor");
1439
1440        let result = call(
1441            Value::String("resolved_arrayfun_target".to_string()),
1442            vec![Value::Tensor(tensor)],
1443        )
1444        .expect("resolved name-only arrayfun");
1445        match result {
1446            Value::Tensor(out) => {
1447                assert_eq!(out.shape, vec![1, 2]);
1448                assert_eq!(out.data, vec![21.0, 22.0]);
1449            }
1450            other => panic!("expected tensor, got {other:?}"),
1451        }
1452    }
1453
1454    #[test]
1455    fn arrayfun_qualified_text_callback_classifies_as_external_name() {
1456        let callable =
1457            Callable::from_text("pkg.callback").expect("qualified arrayfun callback should parse");
1458        assert!(matches!(
1459            callable,
1460            Callable::ExternalName { name } if name == "pkg.callback"
1461        ));
1462    }
1463
1464    #[test]
1465    fn arrayfun_external_handle_uses_semantic_resolver() {
1466        let _resolver_guard =
1467            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
1468                (name == "pkg.callback").then_some(87)
1469            })));
1470        let _invoker_guard = crate::user_functions::install_semantic_function_invoker(Some(
1471            Arc::new(|function, args, requested_outputs| {
1472                assert_eq!(function, 87);
1473                assert_eq!(requested_outputs, 1);
1474                let [Value::Num(value)] = args else {
1475                    panic!("expected scalar numeric argument, got {args:?}");
1476                };
1477                let value = *value;
1478                Box::pin(async move { Ok(Value::Num(value + 30.0)) })
1479            }),
1480        ));
1481        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).expect("tensor");
1482
1483        let result = call(
1484            Value::ExternalFunctionHandle("pkg.callback".to_string()),
1485            vec![Value::Tensor(tensor)],
1486        )
1487        .expect("resolved external-handle arrayfun");
1488        match result {
1489            Value::Tensor(out) => {
1490                assert_eq!(out.shape, vec![1, 2]);
1491                assert_eq!(out.data, vec![31.0, 32.0]);
1492            }
1493            other => panic!("expected tensor, got {other:?}"),
1494        }
1495    }
1496
1497    #[test]
1498    fn arrayfun_single_segment_external_handle_uses_runtime_name_resolution() {
1499        let _resolver_guard =
1500            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
1501                (name == "callback").then_some(887)
1502            })));
1503        let _invoker_guard = crate::user_functions::install_semantic_function_invoker(Some(
1504            Arc::new(|function, args, requested_outputs| {
1505                assert_eq!(function, 887);
1506                assert_eq!(requested_outputs, 1);
1507                let [Value::Num(value)] = args else {
1508                    panic!("expected scalar numeric argument, got {args:?}");
1509                };
1510                let value = *value;
1511                Box::pin(async move { Ok(Value::Num(value + 40.0)) })
1512            }),
1513        ));
1514        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).expect("tensor");
1515
1516        let result = call(
1517            Value::ExternalFunctionHandle("callback".to_string()),
1518            vec![Value::Tensor(tensor)],
1519        )
1520        .expect("single-segment external-handle arrayfun should resolve via runtime-name policy");
1521        match result {
1522            Value::Tensor(out) => {
1523                assert_eq!(out.shape, vec![1, 2]);
1524                assert_eq!(out.data, vec![41.0, 42.0]);
1525            }
1526            other => panic!("expected tensor, got {other:?}"),
1527        }
1528    }
1529
1530    #[test]
1531    fn arrayfun_external_handle_prefers_semantic_handle_binding_when_resolved() {
1532        let _resolver_guard =
1533            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
1534                (name == "pkg.callback").then_some(87)
1535            })));
1536        let callable =
1537            Callable::from_function(Value::ExternalFunctionHandle("pkg.callback".to_string()))
1538                .expect("external handle should parse");
1539        assert!(matches!(
1540            callable,
1541            Callable::Closure(Closure {
1542                function_name,
1543                bound_function: Some(87),
1544                ..
1545            }) if function_name == "pkg.callback"
1546        ));
1547    }
1548
1549    #[test]
1550    fn arrayfun_name_only_closure_prefers_semantic_handle_binding_when_resolved() {
1551        let _resolver_guard =
1552            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
1553                (name == "pkg.callback").then_some(187)
1554            })));
1555        let callable = Callable::from_function(Value::Closure(Closure {
1556            function_name: "pkg.callback".to_string(),
1557            bound_function: None,
1558            captures: vec![Value::Num(5.0)],
1559        }))
1560        .expect("closure callback should parse");
1561        assert!(matches!(
1562            callable,
1563            Callable::Closure(Closure {
1564                function_name,
1565                bound_function: Some(187),
1566                captures
1567            }) if function_name == "pkg.callback" && captures == vec![Value::Num(5.0)]
1568        ));
1569    }
1570
1571    #[test]
1572    fn arrayfun_name_only_closure_call_uses_semantic_resolver_when_unbound() {
1573        let _resolver_guard =
1574            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
1575                (name == "pkg.callback").then_some(287)
1576            })));
1577        let _invoker_guard = crate::user_functions::install_semantic_function_invoker(Some(
1578            Arc::new(|function, args, requested_outputs| {
1579                assert_eq!(function, 287);
1580                assert_eq!(requested_outputs, 1);
1581                assert_eq!(args, &[Value::Num(5.0), Value::Num(4.0)]);
1582                Box::pin(async { Ok(Value::Num(9.0)) })
1583            }),
1584        ));
1585        let callable = Callable::Closure(Closure {
1586            function_name: "pkg.callback".to_string(),
1587            bound_function: None,
1588            captures: vec![Value::Num(5.0)],
1589        });
1590        let value = block_on(callable.call(&[Value::Num(4.0)])).expect("closure call");
1591        assert_eq!(value, Value::Num(9.0));
1592    }
1593
1594    #[test]
1595    fn arrayfun_external_handle_errors_as_undefined_when_unresolved() {
1596        let _resolver_guard =
1597            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_| None)));
1598        let tensor = Tensor::new(vec![1.0], vec![1, 1]).expect("tensor");
1599
1600        let err = call(
1601            Value::ExternalFunctionHandle("pkg.callback".to_string()),
1602            vec![Value::Tensor(tensor)],
1603        )
1604        .expect_err("unresolved external callback should error");
1605        assert_eq!(
1606            err.identifier(),
1607            ARRAYFUN_ERROR_UNDEFINED_FUNCTION.identifier,
1608            "unexpected error: {}",
1609            err.message()
1610        );
1611        assert!(
1612            err.message().contains("ExternalName(QualifiedName"),
1613            "unexpected error: {err:?}"
1614        );
1615        assert!(
1616            !err.message().contains("Undefined function 'pkg.callback'"),
1617            "well-formed external callback should report typed identity: {err:?}"
1618        );
1619    }
1620
1621    #[test]
1622    fn arrayfun_malformed_external_handle_errors_as_undefined_when_unresolved() {
1623        let _resolver_guard =
1624            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_| None)));
1625        let tensor = Tensor::new(vec![1.0], vec![1, 1]).expect("tensor");
1626
1627        let err = call(
1628            Value::ExternalFunctionHandle("pkg..callback".to_string()),
1629            vec![Value::Tensor(tensor)],
1630        )
1631        .expect_err("malformed unresolved external callback should error");
1632        assert_eq!(
1633            err.identifier(),
1634            ARRAYFUN_ERROR_UNDEFINED_FUNCTION.identifier,
1635            "unexpected error: {}",
1636            err.message()
1637        );
1638        assert!(
1639            err.message().contains("pkg..callback"),
1640            "unexpected error: {err:?}"
1641        );
1642    }
1643
1644    #[test]
1645    fn arrayfun_type_tracks_function_returns() {
1646        let func = Type::Function {
1647            params: vec![Type::Num],
1648            returns: Box::new(Type::Num),
1649        };
1650        assert_eq!(
1651            arrayfun_type(&[func, Type::tensor()], &ResolveContext::new(Vec::new())),
1652            Type::tensor()
1653        );
1654    }
1655
1656    #[test]
1657    fn arrayfun_type_uses_logical_returns() {
1658        let func = Type::Function {
1659            params: vec![Type::Num],
1660            returns: Box::new(Type::Bool),
1661        };
1662        assert_eq!(
1663            arrayfun_type(&[func, Type::tensor()], &ResolveContext::new(Vec::new())),
1664            Type::logical()
1665        );
1666    }
1667
1668    #[test]
1669    fn arrayfun_type_with_text_args_stays_unknown() {
1670        let func = Type::Function {
1671            params: vec![Type::Num],
1672            returns: Box::new(Type::Num),
1673        };
1674        assert_eq!(
1675            arrayfun_type(
1676                &[func, Type::tensor(), Type::String, Type::Bool],
1677                &ResolveContext::new(Vec::new()),
1678            ),
1679            Type::Unknown
1680        );
1681    }
1682
1683    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1684    #[test]
1685    fn arrayfun_additional_scalar_argument() {
1686        let tensor = Tensor::new(vec![0.5, 1.0, -1.0], vec![3, 1]).unwrap();
1687        let expected: Vec<f64> = tensor.data.iter().map(|&y| y.atan2(1.0)).collect();
1688        let result = call(
1689            Value::FunctionHandle("atan2".to_string()),
1690            vec![Value::Tensor(tensor), Value::Num(1.0)],
1691        )
1692        .expect("arrayfun");
1693        match result {
1694            Value::Tensor(out) => {
1695                assert_eq!(out.data, expected);
1696            }
1697            other => panic!("expected tensor, got {other:?}"),
1698        }
1699    }
1700
1701    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1702    #[test]
1703    fn arrayfun_uniform_false_returns_cell() {
1704        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1705        let expected: Vec<Value> = tensor.data.iter().map(|&x| Value::Num(x.sin())).collect();
1706        let result = call(
1707            Value::FunctionHandle("sin".to_string()),
1708            vec![
1709                Value::Tensor(tensor),
1710                Value::String("UniformOutput".into()),
1711                Value::Bool(false),
1712            ],
1713        )
1714        .expect("arrayfun");
1715        let Value::Cell(cell) = result else {
1716            panic!("expected cell, got something else");
1717        };
1718        assert_eq!(cell.rows, 2);
1719        assert_eq!(cell.cols, 1);
1720        for (row, value) in expected.iter().enumerate() {
1721            assert_eq!(cell.get(row, 0).unwrap(), *value);
1722        }
1723    }
1724
1725    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1726    #[test]
1727    fn arrayfun_uniform_output_option_identifier() {
1728        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
1729        let err = call(
1730            Value::FunctionHandle("sin".to_string()),
1731            vec![
1732                Value::Tensor(tensor),
1733                Value::String("UniformOutput".into()),
1734                Value::String("maybe".into()),
1735            ],
1736        )
1737        .expect_err("expected invalid uniform output option");
1738        assert_eq!(
1739            err.identifier(),
1740            ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION.identifier
1741        );
1742    }
1743
1744    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1745    #[test]
1746    fn arrayfun_unknown_name_value_identifier() {
1747        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
1748        let err = call(
1749            Value::FunctionHandle("sin".to_string()),
1750            vec![
1751                Value::Tensor(tensor),
1752                Value::String("MysteryFlag".into()),
1753                Value::Bool(true),
1754            ],
1755        )
1756        .expect_err("expected unknown name-value error");
1757        assert_eq!(err.identifier(), ARRAYFUN_ERROR_INVALID_INPUT.identifier);
1758    }
1759
1760    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1761    #[test]
1762    fn arrayfun_size_mismatch_errors() {
1763        let taller = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1764        let shorter = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1765        let err = call(
1766            Value::FunctionHandle("sin".to_string()),
1767            vec![Value::Tensor(taller), Value::Tensor(shorter)],
1768        )
1769        .expect_err("expected size mismatch error");
1770        let err = err.to_string();
1771        assert!(
1772            err.contains("does not match"),
1773            "expected size mismatch error, got {err}"
1774        );
1775    }
1776
1777    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1778    #[test]
1779    fn arrayfun_error_handler_recovers() {
1780        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1781        let handler = Value::Closure(Closure {
1782            function_name: "__arrayfun_test_handler".into(),
1783            bound_function: None,
1784            captures: vec![Value::Num(42.0)],
1785        });
1786        let result = call(
1787            Value::String("@nonexistent_builtin".into()),
1788            vec![
1789                Value::Tensor(tensor),
1790                Value::String("ErrorHandler".into()),
1791                handler,
1792            ],
1793        )
1794        .expect("arrayfun error handler");
1795        match result {
1796            Value::Tensor(out) => {
1797                assert_eq!(out.shape, vec![3, 1]);
1798                assert_eq!(out.data, vec![42.0, 42.0, 42.0]);
1799            }
1800            other => panic!("expected tensor, got {other:?}"),
1801        }
1802    }
1803
1804    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1805    #[test]
1806    fn arrayfun_error_without_handler_propagates_identifier() {
1807        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
1808        let err = call(
1809            Value::String("@nonexistent_builtin".into()),
1810            vec![Value::Tensor(tensor)],
1811        )
1812        .expect_err("expected unresolved function error");
1813        assert_eq!(
1814            err.identifier(),
1815            ARRAYFUN_ERROR_UNDEFINED_FUNCTION.identifier,
1816            "unexpected error: {}",
1817            err.message()
1818        );
1819    }
1820
1821    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1822    #[test]
1823    fn arrayfun_uniform_logical_result() {
1824        let tensor = Tensor::new(vec![1.0, f64::NAN, 0.0, f64::INFINITY], vec![4, 1]).unwrap();
1825        let result = call(
1826            Value::FunctionHandle("isfinite".to_string()),
1827            vec![Value::Tensor(tensor)],
1828        )
1829        .expect("arrayfun isfinite");
1830        match result {
1831            Value::LogicalArray(la) => {
1832                assert_eq!(la.shape, vec![4, 1]);
1833                assert_eq!(la.data, vec![1, 0, 1, 0]);
1834            }
1835            other => panic!("expected logical array, got {other:?}"),
1836        }
1837    }
1838
1839    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1840    #[test]
1841    fn arrayfun_uniform_character_result() {
1842        let tensor = Tensor::new(vec![65.0, 66.0, 67.0], vec![1, 3]).unwrap();
1843        let result = call(
1844            Value::FunctionHandle("char".to_string()),
1845            vec![Value::Tensor(tensor)],
1846        )
1847        .expect("arrayfun char");
1848        match result {
1849            Value::CharArray(ca) => {
1850                assert_eq!(ca.rows, 1);
1851                assert_eq!(ca.cols, 3);
1852                assert_eq!(ca.data, vec!['A', 'B', 'C']);
1853            }
1854            other => panic!("expected char array, got {other:?}"),
1855        }
1856    }
1857
1858    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1859    #[test]
1860    fn arrayfun_uniform_false_gpu_returns_cell() {
1861        test_support::with_test_provider(|provider| {
1862            let tensor = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
1863            let view = HostTensorView {
1864                data: &tensor.data,
1865                shape: &tensor.shape,
1866            };
1867            let handle = provider.upload(&view).expect("upload");
1868            let result = call(
1869                Value::FunctionHandle("sin".to_string()),
1870                vec![
1871                    Value::GpuTensor(handle),
1872                    Value::String("UniformOutput".into()),
1873                    Value::Bool(false),
1874                ],
1875            )
1876            .expect("arrayfun");
1877            match result {
1878                Value::Cell(cell) => {
1879                    assert_eq!(cell.rows, 2);
1880                    assert_eq!(cell.cols, 1);
1881                    let first = cell.get(0, 0).expect("first cell");
1882                    let second = cell.get(1, 0).expect("second cell");
1883                    match (first, second) {
1884                        (Value::Num(a), Value::Num(b)) => {
1885                            assert!((a - 0.0f64.sin()).abs() < 1e-12);
1886                            assert!((b - 1.0f64.sin()).abs() < 1e-12);
1887                        }
1888                        other => panic!("expected numeric cells, got {other:?}"),
1889                    }
1890                }
1891                other => panic!("expected cell, got {other:?}"),
1892            }
1893        });
1894    }
1895
1896    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1897    #[test]
1898    fn arrayfun_gpu_roundtrip() {
1899        test_support::with_test_provider(|provider| {
1900            let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
1901            let view = HostTensorView {
1902                data: &tensor.data,
1903                shape: &tensor.shape,
1904            };
1905            let handle = provider.upload(&view).expect("upload");
1906            let result = call(
1907                Value::FunctionHandle("sin".to_string()),
1908                vec![Value::GpuTensor(handle)],
1909            )
1910            .expect("arrayfun");
1911            match result {
1912                Value::GpuTensor(gpu) => {
1913                    let gathered = test_support::gather(Value::GpuTensor(gpu.clone())).unwrap();
1914                    let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
1915                    assert_eq!(gathered.data, expected);
1916                    let _ = provider.free(&gpu);
1917                }
1918                other => panic!("expected gpu tensor, got {other:?}"),
1919            }
1920        });
1921    }
1922
1923    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1924    #[test]
1925    #[cfg(feature = "wgpu")]
1926    fn arrayfun_wgpu_sin_matches_cpu() {
1927        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1928            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1929        );
1930        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1931
1932        let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
1933        let view = HostTensorView {
1934            data: &tensor.data,
1935            shape: &tensor.shape,
1936        };
1937        let handle = provider.upload(&view).expect("upload");
1938        let result = call(
1939            Value::FunctionHandle("sin".into()),
1940            vec![Value::GpuTensor(handle.clone())],
1941        )
1942        .expect("arrayfun sin gpu");
1943        let Value::GpuTensor(out_handle) = result else {
1944            panic!("expected GPU tensor result");
1945        };
1946        let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
1947        let expected: Vec<f64> = tensor.data.iter().map(|v| v.sin()).collect();
1948        assert_eq!(gathered.shape, tensor.shape);
1949        let tol = match provider.precision() {
1950            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
1951            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
1952        };
1953        for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
1954            assert!(
1955                (actual - expect).abs() < tol,
1956                "expected {expect}, got {actual}"
1957            );
1958        }
1959        let _ = provider.free(&handle);
1960        let _ = provider.free(&out_handle);
1961    }
1962
1963    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1964    #[test]
1965    #[cfg(feature = "wgpu")]
1966    fn arrayfun_wgpu_plus_matches_cpu() {
1967        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1968            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1969        );
1970        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1971
1972        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1973        let b = Tensor::new(vec![4.0, 3.0, 2.0, 1.0], vec![2, 2]).unwrap();
1974        let view_a = HostTensorView {
1975            data: &a.data,
1976            shape: &a.shape,
1977        };
1978        let view_b = HostTensorView {
1979            data: &b.data,
1980            shape: &b.shape,
1981        };
1982        let handle_a = provider.upload(&view_a).expect("upload a");
1983        let handle_b = provider.upload(&view_b).expect("upload b");
1984        let result = call(
1985            Value::FunctionHandle("plus".into()),
1986            vec![
1987                Value::GpuTensor(handle_a.clone()),
1988                Value::GpuTensor(handle_b.clone()),
1989            ],
1990        )
1991        .expect("arrayfun plus gpu");
1992
1993        let Value::GpuTensor(out_handle) = result else {
1994            panic!("expected GPU tensor result");
1995        };
1996        let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
1997        let expected: Vec<f64> = a
1998            .data
1999            .iter()
2000            .zip(b.data.iter())
2001            .map(|(x, y)| x + y)
2002            .collect();
2003        assert_eq!(gathered.shape, a.shape);
2004        let tol = match provider.precision() {
2005            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
2006            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
2007        };
2008        for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
2009            assert!(
2010                (actual - expect).abs() < tol,
2011                "expected {expect}, got {actual}"
2012            );
2013        }
2014        let _ = provider.free(&handle_a);
2015        let _ = provider.free(&handle_b);
2016        let _ = provider.free(&out_handle);
2017    }
2018
2019    const ARRAYFUN_TEST_HELPER_ERRORS: [BuiltinErrorDescriptor; 0] = [];
2020    const ARRAYFUN_TEST_HELPER_OUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
2021        name: "out",
2022        ty: BuiltinParamType::Any,
2023        arity: BuiltinParamArity::Required,
2024        default: None,
2025        description: "Helper output value.",
2026    }];
2027    const ARRAYFUN_TEST_HANDLER_INPUTS: [BuiltinParamDescriptor; 3] = [
2028        BuiltinParamDescriptor {
2029            name: "seed",
2030            ty: BuiltinParamType::Any,
2031            arity: BuiltinParamArity::Required,
2032            default: None,
2033            description: "Seed value.",
2034        },
2035        BuiltinParamDescriptor {
2036            name: "err",
2037            ty: BuiltinParamType::Any,
2038            arity: BuiltinParamArity::Required,
2039            default: None,
2040            description: "Error context placeholder.",
2041        },
2042        BuiltinParamDescriptor {
2043            name: "rest",
2044            ty: BuiltinParamType::Any,
2045            arity: BuiltinParamArity::Variadic,
2046            default: None,
2047            description: "Additional values.",
2048        },
2049    ];
2050    const ARRAYFUN_TEST_HANDLER_SIGNATURES: [BuiltinSignatureDescriptor; 1] =
2051        [BuiltinSignatureDescriptor {
2052            label: "out = __arrayfun_test_handler(seed, err, ...)",
2053            inputs: &ARRAYFUN_TEST_HANDLER_INPUTS,
2054            outputs: &ARRAYFUN_TEST_HELPER_OUT,
2055        }];
2056    const ARRAYFUN_TEST_HANDLER_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
2057        signatures: &ARRAYFUN_TEST_HANDLER_SIGNATURES,
2058        output_mode: BuiltinOutputMode::Fixed,
2059        completion_policy: BuiltinCompletionPolicy::HiddenInternal,
2060        errors: &ARRAYFUN_TEST_HELPER_ERRORS,
2061    };
2062
2063    #[runmat_macros::runtime_builtin(
2064        name = "__arrayfun_test_handler",
2065        descriptor(
2066            crate::builtins::acceleration::gpu::arrayfun::tests::ARRAYFUN_TEST_HANDLER_DESCRIPTOR
2067        ),
2068        type_resolver(arrayfun_type),
2069        builtin_path = "crate::builtins::acceleration::gpu::arrayfun::tests"
2070    )]
2071    async fn arrayfun_test_handler(
2072        seed: Value,
2073        _err: Value,
2074        rest: Vec<Value>,
2075    ) -> crate::BuiltinResult<Value> {
2076        let _ = rest;
2077        Ok(seed)
2078    }
2079}