1use crate::builtins::acceleration::gpu::type_resolvers::arrayfun_type;
10use crate::builtins::common::spec::{
11 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
13};
14use crate::{
15 build_runtime_error, gather_if_needed_async, make_cell_with_shape, user_functions,
16 BuiltinResult, RuntimeError,
17};
18use runmat_accelerate_api::{set_handle_logical, GpuTensorHandle, HostTensorView};
19use runmat_builtins::{
20 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
21 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
22 CharArray, Closure, ComplexTensor, LogicalArray, StringArray, Tensor, Value,
23};
24use runmat_macros::runtime_builtin;
25
// GPU metadata for `arrayfun`. The provider hooks listed below are exactly the
// kernels `try_gpu_fast_path` can dispatch to when every input is already a GPU
// tensor: the unary hooks serve callbacks named sin/cos/abs/exp/log/sqrt, and the
// binary hooks serve plus/minus/times/rdivide (plus ldivide, which reuses
// `elem_div` with swapped operands). Any other callback runs element-by-element
// on the host, after which `maybe_upload_uniform` re-uploads uniform results when
// an input was GPU-resident.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::arrayfun")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "arrayfun",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        ProviderHook::Unary { name: "unary_sin" },
        ProviderHook::Unary { name: "unary_cos" },
        ProviderHook::Unary { name: "unary_abs" },
        ProviderHook::Unary { name: "unary_exp" },
        ProviderHook::Unary { name: "unary_log" },
        ProviderHook::Unary { name: "unary_sqrt" },
        ProviderHook::Binary {
            name: "elem_add",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_sub",
            commutative: false,
        },
        ProviderHook::Binary {
            name: "elem_mul",
            commutative: true,
        },
        ProviderHook::Binary {
            name: "elem_div",
            commutative: false,
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers that implement the listed kernels can run supported callbacks entirely on the GPU; unsupported callbacks fall back to the host path with re-upload.",
};
64
// Fusion metadata for `arrayfun`. Both `elementwise` and `reduction` are `None`:
// the callback is arbitrary user code, so the fusion planner must not merge
// `arrayfun` with neighbouring operations.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::acceleration::gpu::arrayfun"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "arrayfun",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a fusion barrier because the callback can run arbitrary MATLAB code.",
};
77
/// Builtin name attached (via `with_builtin`) to every error this module constructs.
const BUILTIN_NAME: &str = "arrayfun";
79
80const ARRAYFUN_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
81 name: "B",
82 ty: BuiltinParamType::Any,
83 arity: BuiltinParamArity::Required,
84 default: None,
85 description: "Element-wise callback result (uniform array or cell array).",
86}];
87
88const ARRAYFUN_INPUTS_BASE: [BuiltinParamDescriptor; 3] = [
89 BuiltinParamDescriptor {
90 name: "func",
91 ty: BuiltinParamType::Any,
92 arity: BuiltinParamArity::Required,
93 default: None,
94 description: "Function handle or callable name.",
95 },
96 BuiltinParamDescriptor {
97 name: "A1",
98 ty: BuiltinParamType::Any,
99 arity: BuiltinParamArity::Required,
100 default: None,
101 description: "First input array.",
102 },
103 BuiltinParamDescriptor {
104 name: "An",
105 ty: BuiltinParamType::Any,
106 arity: BuiltinParamArity::Variadic,
107 default: None,
108 description: "Additional input arrays.",
109 },
110];
111
112const ARRAYFUN_INPUTS_UNIFORM: [BuiltinParamDescriptor; 5] = [
113 BuiltinParamDescriptor {
114 name: "func",
115 ty: BuiltinParamType::Any,
116 arity: BuiltinParamArity::Required,
117 default: None,
118 description: "Function handle or callable name.",
119 },
120 BuiltinParamDescriptor {
121 name: "A1",
122 ty: BuiltinParamType::Any,
123 arity: BuiltinParamArity::Required,
124 default: None,
125 description: "First input array.",
126 },
127 BuiltinParamDescriptor {
128 name: "An",
129 ty: BuiltinParamType::Any,
130 arity: BuiltinParamArity::Variadic,
131 default: None,
132 description: "Additional input arrays.",
133 },
134 BuiltinParamDescriptor {
135 name: "UniformOutput",
136 ty: BuiltinParamType::PropertyName,
137 arity: BuiltinParamArity::Required,
138 default: Some("\"UniformOutput\""),
139 description: "Name-value key that toggles uniform output collection.",
140 },
141 BuiltinParamDescriptor {
142 name: "tf",
143 ty: BuiltinParamType::Any,
144 arity: BuiltinParamArity::Required,
145 default: Some("true"),
146 description: "Logical true/false value for UniformOutput.",
147 },
148];
149
150const ARRAYFUN_INPUTS_HANDLER: [BuiltinParamDescriptor; 5] = [
151 BuiltinParamDescriptor {
152 name: "func",
153 ty: BuiltinParamType::Any,
154 arity: BuiltinParamArity::Required,
155 default: None,
156 description: "Function handle or callable name.",
157 },
158 BuiltinParamDescriptor {
159 name: "A1",
160 ty: BuiltinParamType::Any,
161 arity: BuiltinParamArity::Required,
162 default: None,
163 description: "First input array.",
164 },
165 BuiltinParamDescriptor {
166 name: "An",
167 ty: BuiltinParamType::Any,
168 arity: BuiltinParamArity::Variadic,
169 default: None,
170 description: "Additional input arrays.",
171 },
172 BuiltinParamDescriptor {
173 name: "ErrorHandler",
174 ty: BuiltinParamType::PropertyName,
175 arity: BuiltinParamArity::Required,
176 default: Some("\"ErrorHandler\""),
177 description: "Name-value key that provides fallback callback on per-element failures.",
178 },
179 BuiltinParamDescriptor {
180 name: "handler",
181 ty: BuiltinParamType::Any,
182 arity: BuiltinParamArity::Required,
183 default: None,
184 description: "Callback invoked with error struct and original scalar arguments.",
185 },
186];
187
188const ARRAYFUN_INPUTS_OPTIONS: [BuiltinParamDescriptor; 4] = [
189 BuiltinParamDescriptor {
190 name: "func",
191 ty: BuiltinParamType::Any,
192 arity: BuiltinParamArity::Required,
193 default: None,
194 description: "Function handle or callable name.",
195 },
196 BuiltinParamDescriptor {
197 name: "A1",
198 ty: BuiltinParamType::Any,
199 arity: BuiltinParamArity::Required,
200 default: None,
201 description: "First input array.",
202 },
203 BuiltinParamDescriptor {
204 name: "An",
205 ty: BuiltinParamType::Any,
206 arity: BuiltinParamArity::Variadic,
207 default: None,
208 description: "Additional input arrays.",
209 },
210 BuiltinParamDescriptor {
211 name: "nameValue",
212 ty: BuiltinParamType::Any,
213 arity: BuiltinParamArity::Variadic,
214 default: None,
215 description: "Name-value option pairs including UniformOutput and ErrorHandler.",
216 },
217];
218
219const ARRAYFUN_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
220 BuiltinSignatureDescriptor {
221 label: "B = arrayfun(func, A1, An...)",
222 inputs: &ARRAYFUN_INPUTS_BASE,
223 outputs: &ARRAYFUN_OUTPUT,
224 },
225 BuiltinSignatureDescriptor {
226 label: "B = arrayfun(func, A1, An..., \"UniformOutput\", tf)",
227 inputs: &ARRAYFUN_INPUTS_UNIFORM,
228 outputs: &ARRAYFUN_OUTPUT,
229 },
230 BuiltinSignatureDescriptor {
231 label: "B = arrayfun(func, A1, An..., \"ErrorHandler\", handler)",
232 inputs: &ARRAYFUN_INPUTS_HANDLER,
233 outputs: &ARRAYFUN_OUTPUT,
234 },
235 BuiltinSignatureDescriptor {
236 label: "B = arrayfun(func, A1, An..., nameValue...)",
237 inputs: &ARRAYFUN_INPUTS_OPTIONS,
238 outputs: &ARRAYFUN_OUTPUT,
239 },
240];
241
/// Malformed arguments: bad callable forms, unknown option names, unsupported
/// input kinds, and input size mismatches (raised through `arrayfun_flow`).
const ARRAYFUN_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.INVALID_INPUT",
    identifier: Some("RunMat:arrayfun:InvalidInput"),
    when: "Inputs, callable forms, or option tails violate arrayfun argument requirements.",
    message: "arrayfun: invalid input arguments",
};

/// Raised by `parse_uniform_output` when the option value has no logical meaning.
const ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.UNIFORM_OUTPUT_OPTION",
    identifier: Some("RunMat:arrayfun:UniformOutputOption"),
    when: "UniformOutput option value is not interpretable as logical true/false.",
    message: "arrayfun: UniformOutput must be logical true or false",
};

/// Callback failures with no ErrorHandler (via `arrayfun_flow_with_source`, which
/// prefers the callback's own identifier) and unavailable semantic closures.
const ARRAYFUN_ERROR_CALLBACK_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.CALLBACK_FAILED",
    identifier: Some("RunMat:arrayfun:CallbackFailed"),
    when: "Callback invocation fails and no ErrorHandler recovers the element.",
    message: "arrayfun: callback execution failed",
};

/// Non-scalar or unsupported callback results under UniformOutput=true.
/// NOTE(review): not raised in the visible part of this file; presumably used by
/// the uniform collector / value classification further down — confirm.
const ARRAYFUN_ERROR_UNIFORM_OUTPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.UNIFORM_OUTPUT_TYPE",
    identifier: Some("RunMat:arrayfun:UniformOutputType"),
    when: "UniformOutput=true callback result is not a supported scalar type.",
    message: "arrayfun: callback must return scalar values for UniformOutput=true",
};

/// Failures that indicate a bug rather than bad user input (raised through
/// `arrayfun_internal`).
const ARRAYFUN_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.INTERNAL",
    identifier: Some("RunMat:arrayfun:InternalError"),
    when: "Internal shape/index/materialization/upload path fails.",
    message: "arrayfun: internal error",
};

/// Raised when an `ExternalName` callable cannot be dispatched. Uses the generic
/// `RunMat:UndefinedFunction` identifier rather than an arrayfun-scoped one.
const ARRAYFUN_ERROR_UNDEFINED_FUNCTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ARRAYFUN.UNDEFINED_FUNCTION",
    identifier: Some("RunMat:UndefinedFunction"),
    when: "External callable identity cannot be resolved in semantic/runtime boundaries.",
    message: "arrayfun: undefined function",
};

/// Catalogue of every error this builtin documents, published via `ARRAYFUN_DESCRIPTOR`.
const ARRAYFUN_ERRORS: [BuiltinErrorDescriptor; 6] = [
    ARRAYFUN_ERROR_INVALID_INPUT,
    ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION,
    ARRAYFUN_ERROR_CALLBACK_FAILED,
    ARRAYFUN_ERROR_UNIFORM_OUTPUT_TYPE,
    ARRAYFUN_ERROR_INTERNAL,
    ARRAYFUN_ERROR_UNDEFINED_FUNCTION,
];

/// Signature, output-mode, and error metadata referenced by the
/// `#[runtime_builtin(descriptor(...))]` registration of `arrayfun_builtin`.
pub const ARRAYFUN_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &ARRAYFUN_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &ARRAYFUN_ERRORS,
};
299
300fn arrayfun_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
301 arrayfun_error_with_message(error.message, error)
302}
303
304fn arrayfun_error_with_message(
305 message: impl Into<String>,
306 error: &'static BuiltinErrorDescriptor,
307) -> RuntimeError {
308 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
309 if let Some(identifier) = error.identifier {
310 builder = builder.with_identifier(identifier);
311 }
312 builder.build()
313}
314
315fn arrayfun_error_with_detail(
316 error: &'static BuiltinErrorDescriptor,
317 detail: impl AsRef<str>,
318) -> RuntimeError {
319 arrayfun_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
320}
321
322fn arrayfun_error_with_source(
323 message: impl Into<String>,
324 error: &'static BuiltinErrorDescriptor,
325 source: RuntimeError,
326) -> RuntimeError {
327 let identifier = source.identifier().map(str::to_string);
328 let mut builder = build_runtime_error(message.into())
329 .with_builtin(BUILTIN_NAME)
330 .with_source(source);
331 if let Some(identifier) = identifier.as_deref().or(error.identifier) {
332 builder = builder.with_identifier(identifier);
333 }
334 builder.build()
335}
336
337fn arrayfun_flow(message: impl Into<String>) -> RuntimeError {
338 arrayfun_error_with_message(message, &ARRAYFUN_ERROR_INVALID_INPUT)
339}
340
341fn arrayfun_internal(message: impl Into<String>) -> RuntimeError {
342 arrayfun_error_with_message(message, &ARRAYFUN_ERROR_INTERNAL)
343}
344
345fn arrayfun_flow_with_source(message: impl Into<String>, source: RuntimeError) -> RuntimeError {
346 arrayfun_error_with_source(message, &ARRAYFUN_ERROR_CALLBACK_FAILED, source)
347}
348
349fn format_handler_error(err: &RuntimeError) -> String {
350 if let Some(identifier) = err.identifier() {
351 if err.message().is_empty() {
352 return identifier.to_string();
353 }
354 if err.message().starts_with(identifier) {
355 return err.message().to_string();
356 }
357 return format!("{identifier}: {}", err.message());
358 }
359 err.message().to_string()
360}
361
362#[runtime_builtin(
363 name = "arrayfun",
364 category = "acceleration/gpu",
365 summary = "Apply a function element-wise across array inputs.",
366 keywords = "arrayfun,gpu,array,map,functional",
367 accel = "host",
368 type_resolver(arrayfun_type),
369 descriptor(crate::builtins::acceleration::gpu::arrayfun::ARRAYFUN_DESCRIPTOR),
370 builtin_path = "crate::builtins::acceleration::gpu::arrayfun"
371)]
372async fn arrayfun_builtin(func: Value, mut rest: Vec<Value>) -> crate::BuiltinResult<Value> {
373 let callable = Callable::from_function(func)?;
374
375 let mut uniform_output = true;
376 let mut error_handler: Option<Callable> = None;
377
378 while rest.len() >= 2 {
379 let key_candidate = rest[rest.len() - 2].clone();
380 let Some(name) = extract_string(&key_candidate) else {
381 break;
382 };
383 let value = rest.pop().expect("value present");
384 rest.pop();
385 match name.trim().to_ascii_lowercase().as_str() {
386 "uniformoutput" => uniform_output = parse_uniform_output(value)?,
387 "errorhandler" => error_handler = Some(Callable::from_function(value)?),
388 other => {
389 return Err(arrayfun_flow(format!(
390 "arrayfun: unknown name-value argument '{other}'"
391 )))
392 }
393 }
394 }
395
396 if rest.is_empty() {
397 return Err(arrayfun_flow("arrayfun: expected at least one input array"));
398 }
399
400 let inputs_snapshot = rest.clone();
401 let has_gpu_input = inputs_snapshot
402 .iter()
403 .any(|value| matches!(value, Value::GpuTensor(_)));
404 let gpu_device_id = inputs_snapshot.iter().find_map(|v| {
405 if let Value::GpuTensor(h) = v {
406 Some(h.device_id)
407 } else {
408 None
409 }
410 });
411
412 if uniform_output {
413 if let Some(gpu_result) =
414 try_gpu_fast_path(&callable, &inputs_snapshot, error_handler.as_ref()).await?
415 {
416 return Ok(gpu_result);
417 }
418 }
419
420 let mut inputs: Vec<ArrayInput> = Vec::with_capacity(rest.len());
421 let mut base_shape: Vec<usize> = Vec::new();
422 let mut base_len: Option<usize> = None;
423
424 for (idx, raw) in rest.into_iter().enumerate() {
425 if matches!(raw, Value::Cell(_)) {
426 return Err(arrayfun_flow(
427 "arrayfun: cell inputs are not supported (use cellfun instead)",
428 ));
429 }
430 if matches!(raw, Value::Struct(_)) {
431 return Err(arrayfun_flow("arrayfun: struct inputs are not supported"));
432 }
433
434 let host_value = gather_if_needed_async(&raw).await?;
435 let data = ArrayData::from_value(host_value)?;
436 let len = data.len();
437 let is_scalar = len == 1;
438
439 let mut input = ArrayInput { data, is_scalar };
440
441 if let Some(current) = base_len {
442 if current == len {
443 if len > 1 {
444 let shape = input.shape_vec();
445 if shape != base_shape {
446 return Err(arrayfun_flow(format!(
447 "arrayfun: input {} does not match the size of the first array",
448 idx + 1
449 )));
450 }
451 }
452 } else if len == 1 {
453 input.is_scalar = true;
454 } else if current == 1 {
455 base_len = Some(len);
456 base_shape = input.shape_vec();
457 for prior in &mut inputs {
458 let prior_len = prior.len();
459 if prior_len == len {
460 if prior.shape_vec() != base_shape {
461 return Err(arrayfun_flow(format!(
462 "arrayfun: input {} does not match the size of the first array",
463 idx
464 )));
465 }
466 } else if prior_len == 1 {
467 prior.is_scalar = true;
468 } else if prior_len == 0 && len == 0 {
469 continue;
470 } else {
471 return Err(arrayfun_flow(format!(
472 "arrayfun: input {} does not match the size of the first array",
473 idx
474 )));
475 }
476 }
477 } else if len == 0 && current == 0 {
478 let shape = input.shape_vec();
479 if shape != base_shape {
480 return Err(arrayfun_flow(format!(
481 "arrayfun: input {} does not match the size of the first array",
482 idx + 1
483 )));
484 }
485 } else {
486 return Err(arrayfun_flow(format!(
487 "arrayfun: input {} does not match the size of the first array",
488 idx + 1
489 )));
490 }
491 } else {
492 base_len = Some(len);
493 base_shape = input.shape_vec();
494 }
495
496 inputs.push(input);
497 }
498
499 let total_len = base_len.unwrap_or(0);
500
501 if total_len == 0 {
502 if uniform_output {
503 return Ok(empty_uniform(&base_shape));
504 } else {
505 return make_cell_with_shape(Vec::new(), base_shape)
506 .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")));
507 }
508 }
509
510 let mut collector = if uniform_output {
511 Some(UniformCollector::Pending)
512 } else {
513 None
514 };
515
516 let mut cell_outputs: Vec<Value> = Vec::new();
517 let mut args: Vec<Value> = Vec::with_capacity(inputs.len());
518
519 for idx in 0..total_len {
520 args.clear();
521 for input in &inputs {
522 args.push(input.value_at(idx)?);
523 }
524
525 let result = match callable.call(&args).await {
526 Ok(value) => value,
527 Err(err) => {
528 let handler = match error_handler.as_ref() {
529 Some(handler) => handler,
530 None => {
531 return Err(arrayfun_flow_with_source(
532 format!("arrayfun: {}", err.message()),
533 err,
534 ))
535 }
536 };
537 let err_message = format_handler_error(&err);
538 let err_value = make_error_struct(&err_message, idx, &base_shape)?;
539 let mut handler_args = Vec::with_capacity(1 + args.len());
540 handler_args.push(err_value);
541 handler_args.extend(args.clone());
542 handler.call(&handler_args).await?
543 }
544 };
545
546 let host_result = gather_if_needed_async(&result).await?;
547
548 if let Some(collector) = collector.as_mut() {
549 collector.push(&host_result)?;
550 } else {
551 cell_outputs.push(host_result);
552 }
553 }
554
555 if let Some(collector) = collector {
556 let uniform = collector.finish(&base_shape)?;
557 maybe_upload_uniform(uniform, has_gpu_input, gpu_device_id)
558 } else {
559 make_cell_with_shape(cell_outputs, base_shape)
560 .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))
561 }
562}
563
564fn maybe_upload_uniform(
565 value: Value,
566 has_gpu_input: bool,
567 gpu_device_id: Option<u32>,
568) -> BuiltinResult<Value> {
569 if !has_gpu_input {
570 return Ok(value);
571 }
572 #[cfg(all(test, feature = "wgpu"))]
573 {
574 if matches!(gpu_device_id, Some(id) if id != 0) {
575 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
576 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
577 );
578 }
579 }
580 let _ = gpu_device_id; let provider = match runmat_accelerate_api::provider() {
582 Some(p) => p,
583 None => return Ok(value),
584 };
585
586 match value {
587 Value::Tensor(tensor) => {
588 let view = HostTensorView {
589 data: &tensor.data,
590 shape: &tensor.shape,
591 };
592 let handle = provider
593 .upload(&view)
594 .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
595 Ok(Value::GpuTensor(handle))
596 }
597 Value::LogicalArray(logical) => {
598 let data: Vec<f64> = logical
599 .data
600 .iter()
601 .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
602 .collect();
603 let tensor = Tensor::new(data, logical.shape.clone())
604 .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
605 let view = HostTensorView {
606 data: &tensor.data,
607 shape: &tensor.shape,
608 };
609 let handle = provider
610 .upload(&view)
611 .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
612 set_handle_logical(&handle, true);
613 Ok(Value::GpuTensor(handle))
614 }
615 other => Ok(other),
616 }
617}
618
619fn empty_uniform(shape: &[usize]) -> Value {
620 if shape.is_empty() {
621 return Value::Tensor(Tensor::zeros(vec![0, 0]));
622 }
623 let total: usize = shape.iter().product();
624 let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
625 .unwrap_or_else(|_| Tensor::zeros(shape.to_vec()));
626 Value::Tensor(tensor)
627}
628
629fn parse_uniform_output(value: Value) -> BuiltinResult<bool> {
630 match value {
631 Value::Bool(b) => Ok(b),
632 Value::Num(n) => Ok(n != 0.0),
633 Value::Int(iv) => Ok(iv.to_f64() != 0.0),
634 Value::String(s) => parse_bool_string(&s)
635 .ok_or_else(|| arrayfun_error(&ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION)),
636 Value::CharArray(ca) if ca.rows == 1 => {
637 let text: String = ca.data.iter().collect();
638 parse_bool_string(&text)
639 .ok_or_else(|| arrayfun_error(&ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION))
640 }
641 other => Err(arrayfun_error_with_detail(
642 &ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION,
643 format!("got {other:?}"),
644 )),
645 }
646}
647
/// Parses `true`/`on` and `false`/`off`, ignoring case and surrounding whitespace.
/// Any other text yields `None`.
fn parse_bool_string(value: &str) -> Option<bool> {
    let normalized = value.trim().to_ascii_lowercase();
    if normalized == "true" || normalized == "on" {
        Some(true)
    } else if normalized == "false" || normalized == "off" {
        Some(false)
    } else {
        None
    }
}
655
656fn extract_string(value: &Value) -> Option<String> {
657 match value {
658 Value::String(s) => Some(s.clone()),
659 Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
660 Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
661 _ => None,
662 }
663}
664
665struct ArrayInput {
666 data: ArrayData,
667 is_scalar: bool,
668}
669
670impl ArrayInput {
671 fn len(&self) -> usize {
672 self.data.len()
673 }
674
675 fn shape_vec(&self) -> Vec<usize> {
676 self.data.shape_vec()
677 }
678
679 fn value_at(&self, idx: usize) -> BuiltinResult<Value> {
680 if self.is_scalar {
681 self.data.value_at(0)
682 } else {
683 self.data.value_at(idx)
684 }
685 }
686}
687
688enum ArrayData {
689 Tensor(Tensor),
690 Logical(LogicalArray),
691 Complex(ComplexTensor),
692 Char(CharArray),
693 String(StringArray),
694 Scalar(Value),
695}
696
697impl ArrayData {
698 fn from_value(value: Value) -> BuiltinResult<Self> {
699 match value {
700 Value::Tensor(t) => Ok(ArrayData::Tensor(t)),
701 Value::LogicalArray(l) => Ok(ArrayData::Logical(l)),
702 Value::ComplexTensor(c) => Ok(ArrayData::Complex(c)),
703 Value::CharArray(ca) => Ok(ArrayData::Char(ca)),
704 Value::StringArray(sa) => Ok(ArrayData::String(sa)),
705 Value::Num(_)
706 | Value::Bool(_)
707 | Value::Int(_)
708 | Value::Complex(_, _)
709 | Value::String(_) => {
710 Ok(ArrayData::Scalar(value))
711 }
712 other => Err(arrayfun_flow(format!(
713 "arrayfun: unsupported input type {other:?} (expected numeric, logical, complex, char, or string arrays)"
714 ))),
715 }
716 }
717
718 fn len(&self) -> usize {
719 match self {
720 ArrayData::Tensor(t) => t.data.len(),
721 ArrayData::Logical(l) => l.data.len(),
722 ArrayData::Complex(c) => c.data.len(),
723 ArrayData::Char(ca) => ca.rows * ca.cols,
724 ArrayData::String(sa) => sa.data.len(),
725 ArrayData::Scalar(_) => 1,
726 }
727 }
728
729 fn shape_vec(&self) -> Vec<usize> {
730 match self {
731 ArrayData::Tensor(t) => {
732 if t.shape.is_empty() {
733 vec![1, 1]
734 } else {
735 t.shape.clone()
736 }
737 }
738 ArrayData::Logical(l) => {
739 if l.shape.is_empty() {
740 vec![1, 1]
741 } else {
742 l.shape.clone()
743 }
744 }
745 ArrayData::Complex(c) => {
746 if c.shape.is_empty() {
747 vec![1, 1]
748 } else {
749 c.shape.clone()
750 }
751 }
752 ArrayData::Char(ca) => vec![ca.rows, ca.cols],
753 ArrayData::String(sa) => {
754 if sa.shape.is_empty() {
755 vec![1, 1]
756 } else {
757 sa.shape.clone()
758 }
759 }
760 ArrayData::Scalar(_) => vec![1, 1],
761 }
762 }
763
764 fn value_at(&self, idx: usize) -> BuiltinResult<Value> {
765 match self {
766 ArrayData::Tensor(t) => {
767 Ok(Value::Num(*t.data.get(idx).ok_or_else(|| {
768 arrayfun_flow("arrayfun: index out of bounds")
769 })?))
770 }
771 ArrayData::Logical(l) => Ok(Value::Bool(
772 *l.data
773 .get(idx)
774 .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?
775 != 0,
776 )),
777 ArrayData::Complex(c) => {
778 let (re, im) = c
779 .data
780 .get(idx)
781 .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?;
782 Ok(Value::Complex(*re, *im))
783 }
784 ArrayData::Char(ca) => {
785 if ca.rows == 0 || ca.cols == 0 {
786 return Ok(Value::CharArray(
787 CharArray::new(Vec::new(), 0, 0)
788 .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?,
789 ));
790 }
791 let rows = ca.rows;
792 let cols = ca.cols;
793 let row = idx % rows;
794 let col = idx / rows;
795 let data_idx = row * cols + col;
796 let ch = *ca
797 .data
798 .get(data_idx)
799 .ok_or_else(|| arrayfun_flow("arrayfun: index out of bounds"))?;
800 let char_array = CharArray::new(vec![ch], 1, 1)
801 .map_err(|e| arrayfun_flow(format!("arrayfun: {e}")))?;
802 Ok(Value::CharArray(char_array))
803 }
804 ArrayData::String(sa) => {
805 Ok(Value::String(sa.data.get(idx).cloned().ok_or_else(
806 || arrayfun_flow("arrayfun: index out of bounds"),
807 )?))
808 }
809 ArrayData::Scalar(v) => Ok(v.clone()),
810 }
811 }
812}
813
/// Resolved form of the `func` argument (and of an `ErrorHandler` value).
///
/// The variants are produced by `from_function` / `from_text`:
/// - `Builtin`: a name with no semantically resolved user function that is not a
///   well-formed qualified name. It is dispatched by name at call time.
/// - `ExternalName`: a well-formed qualified name, dispatched only through the
///   external callable boundary.
/// - `Closure`: a closure or bound handle, optionally carrying a resolved
///   semantic function and captured values.
#[derive(Clone)]
enum Callable {
    Builtin { name: String },
    ExternalName { name: String },
    Closure(Closure),
}

impl Callable {
    /// Wraps the semantically resolved user function `name` as a capture-free
    /// closure. Returns `None` when no such function resolves.
    fn resolved_semantic_handle(name: &str) -> Option<Self> {
        let function = user_functions::resolve_semantic_function_by_name(name)?;
        Some(Callable::Closure(Closure {
            function_name: name.to_string(),
            bound_function: Some(function),
            captures: Vec::new(),
        }))
    }

    /// Converts the user-supplied callable argument into a `Callable`.
    ///
    /// Accepts text (a string, a single-row char array, or a 1-element string
    /// array), function handles, external and bound handles, and closures.
    /// Numeric/logical scalars and every other value are rejected with an
    /// InvalidInput error.
    fn from_function(value: Value) -> BuiltinResult<Self> {
        match value {
            Value::String(text) => Self::from_text(&text),
            Value::CharArray(ca) => {
                if ca.rows != 1 {
                    Err(arrayfun_flow(
                        "arrayfun: function name must be a character vector or string scalar",
                    ))
                } else {
                    let text: String = ca.data.iter().collect();
                    Self::from_text(&text)
                }
            }
            Value::StringArray(sa) if sa.data.len() == 1 => Self::from_text(&sa.data[0]),
            Value::FunctionHandle(name) => Self::from_text(&name),
            Value::ExternalFunctionHandle(name) => {
                // Resolution order: semantic user function, then qualified
                // external name, then plain builtin name.
                if let Some(callable) = Self::resolved_semantic_handle(&name) {
                    Ok(callable)
                } else if crate::is_well_formed_qualified_name(&name) {
                    Ok(Callable::ExternalName { name })
                } else {
                    Ok(Callable::Builtin { name })
                }
            }
            Value::BoundFunctionHandle { name, function } => Ok(Callable::Closure(Closure {
                function_name: name,
                bound_function: Some(function),
                captures: Vec::new(),
            })),
            Value::Closure(mut closure) => {
                // Late-bind closures that arrive without a resolved function.
                if closure.bound_function.is_none() {
                    if let Some(function) =
                        user_functions::resolve_semantic_function_by_name(&closure.function_name)
                    {
                        closure.bound_function = Some(function);
                    }
                }
                Ok(Callable::Closure(closure))
            }
            Value::Num(_) | Value::Int(_) | Value::Bool(_) => Err(arrayfun_flow(
                "arrayfun: expected function handle or builtin name, not a scalar value",
            )),
            other => Err(arrayfun_flow(format!(
                "arrayfun: expected function handle or builtin name, got {other:?}"
            ))),
        }
    }

    /// Resolves a textual callable: `"@name"` or a bare `"name"`.
    ///
    /// Both forms try, in order: a semantic user function, a well-formed qualified
    /// (external) name, and finally a builtin name. Empty text and an empty `@`
    /// handle are rejected with InvalidInput errors.
    fn from_text(text: &str) -> BuiltinResult<Self> {
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return Err(arrayfun_flow(
                "arrayfun: expected function handle or builtin name, got empty string",
            ));
        }
        if let Some(rest) = trimmed.strip_prefix('@') {
            let name = rest.trim();
            if name.is_empty() {
                Err(arrayfun_flow("arrayfun: empty function handle"))
            } else {
                if let Some(callable) = Self::resolved_semantic_handle(name) {
                    return Ok(callable);
                }
                if crate::is_well_formed_qualified_name(name) {
                    return Ok(Callable::ExternalName {
                        name: name.to_string(),
                    });
                }
                Ok(Callable::Builtin {
                    name: name.to_string(),
                })
            }
        } else {
            // NOTE(review): bare names are lowercased before resolution, while
            // `@name` text keeps its case. If user-function resolution is
            // case-sensitive, mixed-case names given as bare text may not
            // resolve — confirm this asymmetry is intended.
            let name = trimmed.to_ascii_lowercase();
            if let Some(callable) = Self::resolved_semantic_handle(&name) {
                return Ok(callable);
            }
            if crate::is_well_formed_qualified_name(&name) {
                return Ok(Callable::ExternalName { name });
            }
            Ok(Callable::Builtin { name })
        }
    }

    /// Name of a plain builtin callable; `None` for external names and closures.
    /// Only `Builtin` callables are eligible for the GPU fast path.
    fn builtin_name(&self) -> Option<&str> {
        match self {
            Callable::Builtin { name } => Some(name.as_str()),
            Callable::ExternalName { .. } | Callable::Closure(_) => None,
        }
    }

    /// Invokes the callable with `args`. The trailing `1` passed to each
    /// `CallableRequest` presumably requests a single output — confirm against
    /// `user_functions`.
    ///
    /// - `Builtin`: tries semantic dispatch with runtime name resolution, then
    ///   falls back to `crate::call_builtin_async`.
    /// - `ExternalName`: tries the external boundary only; otherwise fails with
    ///   `RunMat:UndefinedFunction`.
    /// - `Closure`: prepends captured values to `args`. A bound function must
    ///   dispatch semantically (with no builtin fallback). An unbound closure is
    ///   resolved by name and, failing that, called as the builtin of that name.
    async fn call(&self, args: &[Value]) -> crate::BuiltinResult<Value> {
        match self {
            Callable::Builtin { name } => {
                let request = user_functions::CallableRequest::resolved(
                    runmat_hir::CallableIdentity::DynamicName(runmat_hir::SymbolName(name.clone())),
                    runmat_hir::CallableFallbackPolicy::RuntimeNameResolution,
                    args.to_vec(),
                    1,
                );
                if let Some(result) = user_functions::try_call_semantic_descriptor(request).await {
                    return result;
                }
                crate::call_builtin_async(name, args).await
            }
            Callable::ExternalName { name } => {
                let identity = crate::external_callable_identity_for_name(name);
                let request = user_functions::CallableRequest::resolved(
                    identity.clone(),
                    runmat_hir::CallableFallbackPolicy::ExternalBoundary,
                    args.to_vec(),
                    1,
                );
                if let Some(result) = user_functions::try_call_semantic_descriptor(request).await {
                    return result;
                }
                Err(arrayfun_error_with_message(
                    format!("Undefined function for callable identity {identity:?}"),
                    &ARRAYFUN_ERROR_UNDEFINED_FUNCTION,
                ))
            }
            Callable::Closure(c) => {
                // Captured values come first, followed by the per-element args.
                let mut merged = c.captures.clone();
                merged.extend_from_slice(args);
                if let Some(function) = c.bound_function {
                    let request =
                        user_functions::CallableRequest::semantic(function, merged.clone(), 1);
                    if let Some(result) =
                        user_functions::try_call_semantic_descriptor(request).await
                    {
                        return result;
                    }
                    return Err(arrayfun_error_with_detail(
                        &ARRAYFUN_ERROR_CALLBACK_FAILED,
                        format!(
                            "semantic closure '{}' ({function}) is unavailable",
                            c.function_name
                        ),
                    ));
                }
                if let Some(function) =
                    user_functions::resolve_semantic_function_by_name(&c.function_name)
                {
                    let request =
                        user_functions::CallableRequest::semantic(function, merged.clone(), 1);
                    if let Some(result) =
                        user_functions::try_call_semantic_descriptor(request).await
                    {
                        return result;
                    }
                }
                crate::call_builtin_async(&c.function_name, &merged).await
            }
        }
    }
}
987
988async fn try_gpu_fast_path(
989 callable: &Callable,
990 inputs: &[Value],
991 error_handler: Option<&Callable>,
992) -> BuiltinResult<Option<Value>> {
993 if inputs.is_empty() || error_handler.is_some() {
994 return Ok(None);
995 }
996 if !inputs
997 .iter()
998 .all(|value| matches!(value, Value::GpuTensor(_)))
999 {
1000 return Ok(None);
1001 }
1002
1003 #[cfg(all(test, feature = "wgpu"))]
1004 {
1005 if inputs
1006 .iter()
1007 .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
1008 {
1009 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1010 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1011 );
1012 }
1013 }
1014 let provider = match runmat_accelerate_api::provider() {
1015 Some(p) => p,
1016 None => return Ok(None),
1017 };
1018
1019 let Some(name_raw) = callable.builtin_name() else {
1020 return Ok(None);
1021 };
1022 let name = name_raw.to_ascii_lowercase();
1023
1024 let mut handles: Vec<GpuTensorHandle> = Vec::with_capacity(inputs.len());
1025 for value in inputs {
1026 if let Value::GpuTensor(handle) = value {
1027 handles.push(handle.clone());
1028 }
1029 }
1030
1031 if handles.len() >= 2 {
1032 let base_shape = handles[0].shape.clone();
1033 if handles
1034 .iter()
1035 .skip(1)
1036 .any(|handle| handle.shape != base_shape)
1037 {
1038 return Ok(None);
1039 }
1040 }
1041
1042 let result = match name.as_str() {
1043 "sin" if handles.len() == 1 => provider.unary_sin(&handles[0]).await,
1044 "cos" if handles.len() == 1 => provider.unary_cos(&handles[0]).await,
1045 "abs" if handles.len() == 1 => provider.unary_abs(&handles[0]).await,
1046 "exp" if handles.len() == 1 => provider.unary_exp(&handles[0]).await,
1047 "log" if handles.len() == 1 => provider.unary_log(&handles[0]).await,
1048 "sqrt" if handles.len() == 1 => provider.unary_sqrt(&handles[0]).await,
1049 "plus" if handles.len() == 2 => provider.elem_add(&handles[0], &handles[1]).await,
1050 "minus" if handles.len() == 2 => provider.elem_sub(&handles[0], &handles[1]).await,
1051 "times" if handles.len() == 2 => provider.elem_mul(&handles[0], &handles[1]).await,
1052 "rdivide" if handles.len() == 2 => provider.elem_div(&handles[0], &handles[1]).await,
1053 "ldivide" if handles.len() == 2 => provider.elem_div(&handles[1], &handles[0]).await,
1054 _ => return Ok(None),
1055 };
1056
1057 match result {
1058 Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
1059 Err(_) => Ok(None),
1060 }
1061}
1062
/// Accumulator for `UniformOutput=true` callback results.
///
/// It starts as `Pending` and switches variant as values are pushed so that every
/// value seen so far fits (see `UniformCollector::push` for the promotion rules).
enum UniformCollector {
    /// No value has been pushed yet.
    Pending,
    /// Real results stored as `f64`.
    Double(Vec<f64>),
    /// Logical results stored as one byte per element (0 or 1).
    Logical(Vec<u8>),
    /// Complex results stored as `(re, im)` pairs.
    Complex(Vec<(f64, f64)>),
    /// Character results in push order (`finish` treats this order as column-major).
    Char(Vec<char>),
}
1070
1071impl UniformCollector {
1072 fn push(&mut self, value: &Value) -> BuiltinResult<()> {
1073 match self {
1074 UniformCollector::Pending => match classify_value(value)? {
1075 ClassifiedValue::Logical(b) => {
1076 *self = UniformCollector::Logical(vec![b as u8]);
1077 Ok(())
1078 }
1079 ClassifiedValue::Double(d) => {
1080 *self = UniformCollector::Double(vec![d]);
1081 Ok(())
1082 }
1083 ClassifiedValue::Complex(c) => {
1084 *self = UniformCollector::Complex(vec![c]);
1085 Ok(())
1086 }
1087 ClassifiedValue::Char(ch) => {
1088 *self = UniformCollector::Char(vec![ch]);
1089 Ok(())
1090 }
1091 },
1092 UniformCollector::Logical(bits) => match classify_value(value)? {
1093 ClassifiedValue::Logical(b) => {
1094 bits.push(b as u8);
1095 Ok(())
1096 }
1097 ClassifiedValue::Double(d) => {
1098 let mut data: Vec<f64> = bits
1099 .iter()
1100 .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
1101 .collect();
1102 data.push(d);
1103 *self = UniformCollector::Double(data);
1104 Ok(())
1105 }
1106 ClassifiedValue::Complex(c) => {
1107 let mut data: Vec<(f64, f64)> = bits
1108 .iter()
1109 .map(|&bit| if bit != 0 { (1.0, 0.0) } else { (0.0, 0.0) })
1110 .collect();
1111 data.push(c);
1112 *self = UniformCollector::Complex(data);
1113 Ok(())
1114 }
1115 ClassifiedValue::Char(ch) => {
1116 let mut data: Vec<f64> = bits
1117 .iter()
1118 .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
1119 .collect();
1120 data.push(ch as u32 as f64);
1121 *self = UniformCollector::Double(data);
1122 Ok(())
1123 }
1124 },
1125 UniformCollector::Double(data) => match classify_value(value)? {
1126 ClassifiedValue::Logical(b) => {
1127 data.push(if b { 1.0 } else { 0.0 });
1128 Ok(())
1129 }
1130 ClassifiedValue::Double(d) => {
1131 data.push(d);
1132 Ok(())
1133 }
1134 ClassifiedValue::Complex(c) => {
1135 let promoted: Vec<(f64, f64)> = data.iter().map(|&v| (v, 0.0)).collect();
1136 let mut complex = promoted;
1137 complex.push(c);
1138 *self = UniformCollector::Complex(complex);
1139 Ok(())
1140 }
1141 ClassifiedValue::Char(ch) => {
1142 data.push(ch as u32 as f64);
1143 Ok(())
1144 }
1145 },
1146 UniformCollector::Complex(data) => match classify_value(value)? {
1147 ClassifiedValue::Logical(b) => {
1148 data.push((if b { 1.0 } else { 0.0 }, 0.0));
1149 Ok(())
1150 }
1151 ClassifiedValue::Double(d) => {
1152 data.push((d, 0.0));
1153 Ok(())
1154 }
1155 ClassifiedValue::Complex(c) => {
1156 data.push(c);
1157 Ok(())
1158 }
1159 ClassifiedValue::Char(ch) => {
1160 data.push((ch as u32 as f64, 0.0));
1161 Ok(())
1162 }
1163 },
1164 UniformCollector::Char(chars) => match classify_value(value)? {
1165 ClassifiedValue::Char(ch) => {
1166 chars.push(ch);
1167 Ok(())
1168 }
1169 ClassifiedValue::Logical(b) => {
1170 let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
1171 data.push(if b { 1.0 } else { 0.0 });
1172 *self = UniformCollector::Double(data);
1173 Ok(())
1174 }
1175 ClassifiedValue::Double(d) => {
1176 let mut data: Vec<f64> = chars.iter().map(|&ch| ch as u32 as f64).collect();
1177 data.push(d);
1178 *self = UniformCollector::Double(data);
1179 Ok(())
1180 }
1181 ClassifiedValue::Complex(c) => {
1182 let mut promoted: Vec<(f64, f64)> =
1183 chars.iter().map(|&ch| (ch as u32 as f64, 0.0)).collect();
1184 promoted.push(c);
1185 *self = UniformCollector::Complex(promoted);
1186 Ok(())
1187 }
1188 },
1189 }
1190 }
1191
1192 fn finish(self, shape: &[usize]) -> BuiltinResult<Value> {
1193 match self {
1194 UniformCollector::Pending => {
1195 let total = shape.iter().product();
1196 let tensor = Tensor::new(vec![0.0; total], shape.to_vec())
1197 .map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))?;
1198 Ok(Value::Tensor(tensor))
1199 }
1200 UniformCollector::Double(data) => {
1201 let tensor = Tensor::new(data, shape.to_vec())
1202 .map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))?;
1203 Ok(Value::Tensor(tensor))
1204 }
1205 UniformCollector::Logical(bits) => {
1206 let logical = LogicalArray::new(bits, shape.to_vec())
1207 .map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))?;
1208 Ok(Value::LogicalArray(logical))
1209 }
1210 UniformCollector::Complex(entries) => {
1211 let tensor = ComplexTensor::new(entries, shape.to_vec())
1212 .map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))?;
1213 Ok(Value::ComplexTensor(tensor))
1214 }
1215 UniformCollector::Char(chars) => {
1216 let normalized_shape = if shape.is_empty() {
1217 vec![1, 1]
1218 } else {
1219 shape.to_vec()
1220 };
1221
1222 if normalized_shape.len() > 2 {
1223 return Err(arrayfun_error_with_detail(
1224 &ARRAYFUN_ERROR_UNIFORM_OUTPUT_TYPE,
1225 "character outputs with UniformOutput=true must be 2-D",
1226 ));
1227 }
1228
1229 let rows = normalized_shape.first().copied().unwrap_or(1);
1230 let cols = normalized_shape.get(1).copied().unwrap_or(1);
1231 let expected = rows.checked_mul(cols).ok_or_else(|| {
1232 arrayfun_internal("arrayfun: character output size exceeds platform limits")
1233 })?;
1234
1235 if expected != chars.len() {
1236 return Err(arrayfun_error_with_detail(
1237 &ARRAYFUN_ERROR_UNIFORM_OUTPUT_TYPE,
1238 "callback returned the wrong number of characters",
1239 ));
1240 }
1241
1242 let mut row_major = vec!['\0'; expected];
1243 for col in 0..cols {
1244 for row in 0..rows {
1245 let col_major_idx = row + col * rows;
1246 let row_major_idx = row * cols + col;
1247 row_major[row_major_idx] = chars[col_major_idx];
1248 }
1249 }
1250
1251 let array = CharArray::new(row_major, rows, cols)
1252 .map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))?;
1253 Ok(Value::CharArray(array))
1254 }
1255 }
1256 }
1257}
1258
/// A scalar callback result reduced to one of the classes `UniformCollector` can
/// store (produced by `classify_value`).
enum ClassifiedValue {
    /// Logical scalar.
    Logical(bool),
    /// Real scalar; `classify_value` also maps integer values here.
    Double(f64),
    /// Complex scalar as `(re, im)`.
    Complex((f64, f64)),
    /// Single character.
    Char(char),
}
1265
1266fn classify_value(value: &Value) -> BuiltinResult<ClassifiedValue> {
1267 match value {
1268 Value::Bool(b) => Ok(ClassifiedValue::Logical(*b)),
1269 Value::LogicalArray(la) if la.len() == 1 => Ok(ClassifiedValue::Logical(la.data[0] != 0)),
1270 Value::Int(i) => Ok(ClassifiedValue::Double(i.to_f64())),
1271 Value::Num(n) => Ok(ClassifiedValue::Double(*n)),
1272 Value::Tensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Double(t.data[0])),
1273 Value::Complex(re, im) => Ok(ClassifiedValue::Complex((*re, *im))),
1274 Value::ComplexTensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Complex(t.data[0])),
1275 Value::CharArray(ca) if ca.rows * ca.cols == 1 => {
1276 let ch = ca.data.first().copied().unwrap_or('\0');
1277 Ok(ClassifiedValue::Char(ch))
1278 }
1279 other => Err(arrayfun_error_with_detail(
1280 &ARRAYFUN_ERROR_UNIFORM_OUTPUT_TYPE,
1281 format!(
1282 "callback must return scalar numeric, logical, character, or complex values for UniformOutput=true (got {other:?})"
1283 ),
1284 )),
1285 }
1286}
1287
1288fn make_error_struct(
1289 raw_error: &str,
1290 linear_index: usize,
1291 shape: &[usize],
1292) -> BuiltinResult<Value> {
1293 let (identifier, message) = split_error_message(raw_error);
1294 let mut st = runmat_builtins::StructValue::new();
1295 st.fields
1296 .insert("identifier".to_string(), Value::String(identifier));
1297 st.fields
1298 .insert("message".to_string(), Value::String(message));
1299 st.fields
1300 .insert("index".to_string(), Value::Num((linear_index + 1) as f64));
1301 let subs = linear_to_indices(linear_index, shape);
1302 let subs_tensor = dims_to_row_tensor(&subs)?;
1303 st.fields
1304 .insert("indices".to_string(), Value::Tensor(subs_tensor));
1305 Ok(Value::Struct(st))
1306}
1307
/// Splits a raw callback error string into an `(identifier, message)` pair.
///
/// The input is trimmed first, then handled as follows:
/// - With two or more colons, the text before the second colon is the identifier
///   and the trimmed remainder is the message. If that remainder is empty, the
///   whole trimmed string is used as the message.
/// - A single-colon string that starts with `MATLAB:` or `RunMat:` (ASCII
///   case-insensitive) is returned whole as the identifier, with an empty message.
/// - Anything else is reported as `RunMat:arrayfun:FunctionError`, with the
///   trimmed string as the message.
fn split_error_message(raw: &str) -> (String, String) {
    let trimmed = raw.trim();
    // `str::get` instead of `trimmed[..n]`: the prefix window may end inside a
    // multi-byte character, where byte slicing would panic.
    let has_prefix = |prefix: &str| {
        trimmed
            .get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
    };

    let mut colons = trimmed.match_indices(':').map(|(idx, _)| idx);
    if colons.next().is_some() {
        if let Some(second_idx) = colons.next() {
            // The identifier always contains the first colon, so it is never empty.
            // `second_idx + 1` is a char boundary because ':' is a single byte.
            let identifier = trimmed[..second_idx].trim().to_string();
            let message = trimmed[second_idx + 1..].trim();
            let message = if message.is_empty() {
                trimmed.to_string()
            } else {
                message.to_string()
            };
            return (identifier, message);
        }
        if has_prefix("matlab:") || has_prefix("runmat:") {
            return (trimmed.to_string(), String::new());
        }
    }
    (
        "RunMat:arrayfun:FunctionError".to_string(),
        trimmed.to_string(),
    )
}
1337
/// Converts a 0-based column-major linear index into 1-based subscripts, one per
/// entry of `shape`.
///
/// An empty `shape` yields `[1]`. A zero-length dimension yields subscript 1 and
/// does not consume any of the index.
fn linear_to_indices(mut index: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![1];
    }
    shape
        .iter()
        .map(|&extent| {
            if extent == 0 {
                return 1;
            }
            let subscript = index % extent + 1;
            index /= extent;
            subscript
        })
        .collect()
}
1354
1355fn dims_to_row_tensor(dims: &[usize]) -> BuiltinResult<Tensor> {
1356 let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
1357 Tensor::new(data, vec![1, dims.len()]).map_err(|e| arrayfun_internal(format!("arrayfun: {e}")))
1358}
1359
// Unit tests for `arrayfun`. They cover:
// - host evaluation;
// - semantic resolver/invoker routing for handles, closures and text callbacks;
// - static type resolution and option parsing;
// - error handling and uniform-output class collection;
// - GPU paths through the shared test provider and, behind the `wgpu` feature,
//   a real wgpu provider.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ResolveContext, Tensor, Type};
    use std::sync::Arc;

    // Synchronously drives the async `arrayfun_builtin` entry point.
    fn call(func: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(arrayfun_builtin(func, rest))
    }

    // Element-wise host evaluation keeps the input shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_basic_sin() {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
        let result = call(
            Value::FunctionHandle("sin".to_string()),
            vec![Value::Tensor(tensor.clone())],
        )
        .expect("arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![2, 3]);
                assert_eq!(out.data, expected);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // A bound function handle is dispatched through the installed semantic invoker.
    #[test]
    fn arrayfun_semantic_function_handle_uses_semantic_invoker() {
        let _guard = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |function, args, requested_outputs| {
                assert_eq!(function, 78);
                assert_eq!(requested_outputs, 1);
                let [Value::Num(value)] = args else {
                    panic!("expected scalar numeric argument, got {args:?}");
                };
                let value = *value;
                Box::pin(async move { Ok(Value::Num(value + 10.0)) })
            },
        )));
        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).expect("tensor");
        let handle = Value::BoundFunctionHandle {
            name: "arrayfun_target".to_string(),
            function: 78,
        };

        let result = call(handle, vec![Value::Tensor(tensor)]).expect("semantic arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![11.0, 12.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // A plain string callback name is resolved via the semantic resolver, then invoked.
    #[test]
    fn arrayfun_name_only_callback_uses_semantic_resolver() {
        let _resolver_guard =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
                (name == "resolved_arrayfun_target").then_some(80)
            })));
        let _invoker_guard = crate::user_functions::install_semantic_function_invoker(Some(
            Arc::new(|function, args, requested_outputs| {
                assert_eq!(function, 80);
                assert_eq!(requested_outputs, 1);
                let [Value::Num(value)] = args else {
                    panic!("expected scalar numeric argument, got {args:?}");
                };
                let value = *value;
                Box::pin(async move { Ok(Value::Num(value + 20.0)) })
            }),
        ));
        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).expect("tensor");

        let result = call(
            Value::String("resolved_arrayfun_target".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("resolved name-only arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![21.0, 22.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Dotted text callbacks parse as `Callable::ExternalName`.
    #[test]
    fn arrayfun_qualified_text_callback_classifies_as_external_name() {
        let callable =
            Callable::from_text("pkg.callback").expect("qualified arrayfun callback should parse");
        assert!(matches!(
            callable,
            Callable::ExternalName { name } if name == "pkg.callback"
        ));
    }

    // External handles resolve through the semantic resolver before invocation.
    #[test]
    fn arrayfun_external_handle_uses_semantic_resolver() {
        let _resolver_guard =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
                (name == "pkg.callback").then_some(87)
            })));
        let _invoker_guard = crate::user_functions::install_semantic_function_invoker(Some(
            Arc::new(|function, args, requested_outputs| {
                assert_eq!(function, 87);
                assert_eq!(requested_outputs, 1);
                let [Value::Num(value)] = args else {
                    panic!("expected scalar numeric argument, got {args:?}");
                };
                let value = *value;
                Box::pin(async move { Ok(Value::Num(value + 30.0)) })
            }),
        ));
        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).expect("tensor");

        let result = call(
            Value::ExternalFunctionHandle("pkg.callback".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("resolved external-handle arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![31.0, 32.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Undotted external handles also go through the resolver.
    #[test]
    fn arrayfun_single_segment_external_handle_uses_runtime_name_resolution() {
        let _resolver_guard =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
                (name == "callback").then_some(887)
            })));
        let _invoker_guard = crate::user_functions::install_semantic_function_invoker(Some(
            Arc::new(|function, args, requested_outputs| {
                assert_eq!(function, 887);
                assert_eq!(requested_outputs, 1);
                let [Value::Num(value)] = args else {
                    panic!("expected scalar numeric argument, got {args:?}");
                };
                let value = *value;
                Box::pin(async move { Ok(Value::Num(value + 40.0)) })
            }),
        ));
        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).expect("tensor");

        let result = call(
            Value::ExternalFunctionHandle("callback".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("single-segment external-handle arrayfun should resolve via runtime-name policy");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![41.0, 42.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // When the resolver knows the name, an external handle parses to a bound closure.
    #[test]
    fn arrayfun_external_handle_prefers_semantic_handle_binding_when_resolved() {
        let _resolver_guard =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
                (name == "pkg.callback").then_some(87)
            })));
        let callable =
            Callable::from_function(Value::ExternalFunctionHandle("pkg.callback".to_string()))
                .expect("external handle should parse");
        assert!(matches!(
            callable,
            Callable::Closure(Closure {
                function_name,
                bound_function: Some(87),
                ..
            }) if function_name == "pkg.callback"
        ));
    }

    // Unbound closures get bound when parsed and keep their captures.
    #[test]
    fn arrayfun_name_only_closure_prefers_semantic_handle_binding_when_resolved() {
        let _resolver_guard =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
                (name == "pkg.callback").then_some(187)
            })));
        let callable = Callable::from_function(Value::Closure(Closure {
            function_name: "pkg.callback".to_string(),
            bound_function: None,
            captures: vec![Value::Num(5.0)],
        }))
        .expect("closure callback should parse");
        assert!(matches!(
            callable,
            Callable::Closure(Closure {
                function_name,
                bound_function: Some(187),
                captures
            }) if function_name == "pkg.callback" && captures == vec![Value::Num(5.0)]
        ));
    }

    // Calling an unbound closure resolves it lazily; captures precede call arguments.
    #[test]
    fn arrayfun_name_only_closure_call_uses_semantic_resolver_when_unbound() {
        let _resolver_guard =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
                (name == "pkg.callback").then_some(287)
            })));
        let _invoker_guard = crate::user_functions::install_semantic_function_invoker(Some(
            Arc::new(|function, args, requested_outputs| {
                assert_eq!(function, 287);
                assert_eq!(requested_outputs, 1);
                assert_eq!(args, &[Value::Num(5.0), Value::Num(4.0)]);
                Box::pin(async { Ok(Value::Num(9.0)) })
            }),
        ));
        let callable = Callable::Closure(Closure {
            function_name: "pkg.callback".to_string(),
            bound_function: None,
            captures: vec![Value::Num(5.0)],
        });
        let value = block_on(callable.call(&[Value::Num(4.0)])).expect("closure call");
        assert_eq!(value, Value::Num(9.0));
    }

    // Unresolved qualified names fail as undefined and report the typed callable identity.
    #[test]
    fn arrayfun_external_handle_errors_as_undefined_when_unresolved() {
        let _resolver_guard =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_| None)));
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).expect("tensor");

        let err = call(
            Value::ExternalFunctionHandle("pkg.callback".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect_err("unresolved external callback should error");
        assert_eq!(
            err.identifier(),
            ARRAYFUN_ERROR_UNDEFINED_FUNCTION.identifier,
            "unexpected error: {}",
            err.message()
        );
        assert!(
            err.message().contains("ExternalName(QualifiedName"),
            "unexpected error: {err:?}"
        );
        assert!(
            !err.message().contains("Undefined function 'pkg.callback'"),
            "well-formed external callback should report typed identity: {err:?}"
        );
    }

    // Malformed names still fail as undefined and echo the raw text.
    #[test]
    fn arrayfun_malformed_external_handle_errors_as_undefined_when_unresolved() {
        let _resolver_guard =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_| None)));
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).expect("tensor");

        let err = call(
            Value::ExternalFunctionHandle("pkg..callback".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect_err("malformed unresolved external callback should error");
        assert_eq!(
            err.identifier(),
            ARRAYFUN_ERROR_UNDEFINED_FUNCTION.identifier,
            "unexpected error: {}",
            err.message()
        );
        assert!(
            err.message().contains("pkg..callback"),
            "unexpected error: {err:?}"
        );
    }

    // Static type resolution: numeric callback returns give a tensor type.
    #[test]
    fn arrayfun_type_tracks_function_returns() {
        let func = Type::Function {
            params: vec![Type::Num],
            returns: Box::new(Type::Num),
        };
        assert_eq!(
            arrayfun_type(&[func, Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    // Static type resolution: logical callback returns give a logical type.
    #[test]
    fn arrayfun_type_uses_logical_returns() {
        let func = Type::Function {
            params: vec![Type::Num],
            returns: Box::new(Type::Bool),
        };
        assert_eq!(
            arrayfun_type(&[func, Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::logical()
        );
    }

    // Trailing text arguments (possible name-value options) leave the type unknown.
    #[test]
    fn arrayfun_type_with_text_args_stays_unknown() {
        let func = Type::Function {
            params: vec![Type::Num],
            returns: Box::new(Type::Num),
        };
        assert_eq!(
            arrayfun_type(
                &[func, Type::tensor(), Type::String, Type::Bool],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Unknown
        );
    }

    // A scalar extra argument is passed to every callback invocation.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_additional_scalar_argument() {
        let tensor = Tensor::new(vec![0.5, 1.0, -1.0], vec![3, 1]).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|&y| y.atan2(1.0)).collect();
        let result = call(
            Value::FunctionHandle("atan2".to_string()),
            vec![Value::Tensor(tensor), Value::Num(1.0)],
        )
        .expect("arrayfun");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.data, expected);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // UniformOutput=false returns a cell array shaped like the input.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_uniform_false_returns_cell() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let expected: Vec<Value> = tensor.data.iter().map(|&x| Value::Num(x.sin())).collect();
        let result = call(
            Value::FunctionHandle("sin".to_string()),
            vec![
                Value::Tensor(tensor),
                Value::String("UniformOutput".into()),
                Value::Bool(false),
            ],
        )
        .expect("arrayfun");
        let Value::Cell(cell) = result else {
            panic!("expected cell, got something else");
        };
        assert_eq!(cell.rows, 2);
        assert_eq!(cell.cols, 1);
        for (row, value) in expected.iter().enumerate() {
            assert_eq!(cell.get(row, 0).unwrap(), *value);
        }
    }

    // A non-boolean UniformOutput value reports the option-specific identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_uniform_output_option_identifier() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = call(
            Value::FunctionHandle("sin".to_string()),
            vec![
                Value::Tensor(tensor),
                Value::String("UniformOutput".into()),
                Value::String("maybe".into()),
            ],
        )
        .expect_err("expected invalid uniform output option");
        assert_eq!(
            err.identifier(),
            ARRAYFUN_ERROR_UNIFORM_OUTPUT_OPTION.identifier
        );
    }

    // Unknown name-value options report the invalid-input identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_unknown_name_value_identifier() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = call(
            Value::FunctionHandle("sin".to_string()),
            vec![
                Value::Tensor(tensor),
                Value::String("MysteryFlag".into()),
                Value::Bool(true),
            ],
        )
        .expect_err("expected unknown name-value error");
        assert_eq!(err.identifier(), ARRAYFUN_ERROR_INVALID_INPUT.identifier);
    }

    // Array inputs of different sizes are rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_size_mismatch_errors() {
        let taller = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let shorter = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let err = call(
            Value::FunctionHandle("sin".to_string()),
            vec![Value::Tensor(taller), Value::Tensor(shorter)],
        )
        .expect_err("expected size mismatch error");
        let err = err.to_string();
        assert!(
            err.contains("does not match"),
            "expected size mismatch error, got {err}"
        );
    }

    // ErrorHandler is called for each failing element. The handler closure captures 42.0,
    // and `__arrayfun_test_handler` returns its first argument, so every element becomes 42.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_error_handler_recovers() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let handler = Value::Closure(Closure {
            function_name: "__arrayfun_test_handler".into(),
            bound_function: None,
            captures: vec![Value::Num(42.0)],
        });
        let result = call(
            Value::String("@nonexistent_builtin".into()),
            vec![
                Value::Tensor(tensor),
                Value::String("ErrorHandler".into()),
                handler,
            ],
        )
        .expect("arrayfun error handler");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![42.0, 42.0, 42.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Without an ErrorHandler, the undefined-function identifier propagates.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_error_without_handler_propagates_identifier() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = call(
            Value::String("@nonexistent_builtin".into()),
            vec![Value::Tensor(tensor)],
        )
        .expect_err("expected unresolved function error");
        assert_eq!(
            err.identifier(),
            ARRAYFUN_ERROR_UNDEFINED_FUNCTION.identifier,
            "unexpected error: {}",
            err.message()
        );
    }

    // Logical callback results collect into a logical array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_uniform_logical_result() {
        let tensor = Tensor::new(vec![1.0, f64::NAN, 0.0, f64::INFINITY], vec![4, 1]).unwrap();
        let result = call(
            Value::FunctionHandle("isfinite".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("arrayfun isfinite");
        match result {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![4, 1]);
                assert_eq!(la.data, vec![1, 0, 1, 0]);
            }
            other => panic!("expected logical array, got {other:?}"),
        }
    }

    // Character callback results collect into a char array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_uniform_character_result() {
        let tensor = Tensor::new(vec![65.0, 66.0, 67.0], vec![1, 3]).unwrap();
        let result = call(
            Value::FunctionHandle("char".to_string()),
            vec![Value::Tensor(tensor)],
        )
        .expect("arrayfun char");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 3);
                assert_eq!(ca.data, vec!['A', 'B', 'C']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // GPU input with UniformOutput=false yields a host cell of host scalars.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_uniform_false_gpu_returns_cell() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = call(
                Value::FunctionHandle("sin".to_string()),
                vec![
                    Value::GpuTensor(handle),
                    Value::String("UniformOutput".into()),
                    Value::Bool(false),
                ],
            )
            .expect("arrayfun");
            match result {
                Value::Cell(cell) => {
                    assert_eq!(cell.rows, 2);
                    assert_eq!(cell.cols, 1);
                    let first = cell.get(0, 0).expect("first cell");
                    let second = cell.get(1, 0).expect("second cell");
                    match (first, second) {
                        (Value::Num(a), Value::Num(b)) => {
                            assert!((a - 0.0f64.sin()).abs() < 1e-12);
                            assert!((b - 1.0f64.sin()).abs() < 1e-12);
                        }
                        other => panic!("expected numeric cells, got {other:?}"),
                    }
                }
                other => panic!("expected cell, got {other:?}"),
            }
        });
    }

    // GPU input with a supported builtin stays resident and matches host results.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn arrayfun_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = call(
                Value::FunctionHandle("sin".to_string()),
                vec![Value::GpuTensor(handle)],
            )
            .expect("arrayfun");
            match result {
                Value::GpuTensor(gpu) => {
                    let gathered = test_support::gather(Value::GpuTensor(gpu.clone())).unwrap();
                    let expected: Vec<f64> = tensor.data.iter().map(|&x| x.sin()).collect();
                    assert_eq!(gathered.data, expected);
                    let _ = provider.free(&gpu);
                }
                other => panic!("expected gpu tensor, got {other:?}"),
            }
        });
    }

    // Real wgpu provider: unary `sin` matches the CPU within a precision-dependent tolerance.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn arrayfun_wgpu_sin_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result = call(
            Value::FunctionHandle("sin".into()),
            vec![Value::GpuTensor(handle.clone())],
        )
        .expect("arrayfun sin gpu");
        let Value::GpuTensor(out_handle) = result else {
            panic!("expected GPU tensor result");
        };
        let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
        let expected: Vec<f64> = tensor.data.iter().map(|v| v.sin()).collect();
        assert_eq!(gathered.shape, tensor.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
            assert!(
                (actual - expect).abs() < tol,
                "expected {expect}, got {actual}"
            );
        }
        let _ = provider.free(&handle);
        let _ = provider.free(&out_handle);
    }

    // Real wgpu provider: binary `plus` matches the CPU within a precision-dependent tolerance.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn arrayfun_wgpu_plus_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![4.0, 3.0, 2.0, 1.0], vec![2, 2]).unwrap();
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload a");
        let handle_b = provider.upload(&view_b).expect("upload b");
        let result = call(
            Value::FunctionHandle("plus".into()),
            vec![
                Value::GpuTensor(handle_a.clone()),
                Value::GpuTensor(handle_b.clone()),
            ],
        )
        .expect("arrayfun plus gpu");

        let Value::GpuTensor(out_handle) = result else {
            panic!("expected GPU tensor result");
        };
        let gathered = test_support::gather(Value::GpuTensor(out_handle.clone())).unwrap();
        let expected: Vec<f64> = a
            .data
            .iter()
            .zip(b.data.iter())
            .map(|(x, y)| x + y)
            .collect();
        assert_eq!(gathered.shape, a.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (actual, expect) in gathered.data.iter().zip(expected.iter()) {
            assert!(
                (actual - expect).abs() < tol,
                "expected {expect}, got {actual}"
            );
        }
        let _ = provider.free(&handle_a);
        let _ = provider.free(&handle_b);
        let _ = provider.free(&out_handle);
    }

    // Descriptor metadata for the hidden `__arrayfun_test_handler` builtin used by
    // `arrayfun_error_handler_recovers`.
    const ARRAYFUN_TEST_HELPER_ERRORS: [BuiltinErrorDescriptor; 0] = [];
    const ARRAYFUN_TEST_HELPER_OUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
        name: "out",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Helper output value.",
    }];
    const ARRAYFUN_TEST_HANDLER_INPUTS: [BuiltinParamDescriptor; 3] = [
        BuiltinParamDescriptor {
            name: "seed",
            ty: BuiltinParamType::Any,
            arity: BuiltinParamArity::Required,
            default: None,
            description: "Seed value.",
        },
        BuiltinParamDescriptor {
            name: "err",
            ty: BuiltinParamType::Any,
            arity: BuiltinParamArity::Required,
            default: None,
            description: "Error context placeholder.",
        },
        BuiltinParamDescriptor {
            name: "rest",
            ty: BuiltinParamType::Any,
            arity: BuiltinParamArity::Variadic,
            default: None,
            description: "Additional values.",
        },
    ];
    const ARRAYFUN_TEST_HANDLER_SIGNATURES: [BuiltinSignatureDescriptor; 1] =
        [BuiltinSignatureDescriptor {
            label: "out = __arrayfun_test_handler(seed, err, ...)",
            inputs: &ARRAYFUN_TEST_HANDLER_INPUTS,
            outputs: &ARRAYFUN_TEST_HELPER_OUT,
        }];
    const ARRAYFUN_TEST_HANDLER_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
        signatures: &ARRAYFUN_TEST_HANDLER_SIGNATURES,
        output_mode: BuiltinOutputMode::Fixed,
        completion_policy: BuiltinCompletionPolicy::HiddenInternal,
        errors: &ARRAYFUN_TEST_HELPER_ERRORS,
    };

    // Test-only builtin that ignores the error context and extra arguments and
    // returns its first argument (`seed`) unchanged.
    #[runmat_macros::runtime_builtin(
        name = "__arrayfun_test_handler",
        descriptor(
            crate::builtins::acceleration::gpu::arrayfun::tests::ARRAYFUN_TEST_HANDLER_DESCRIPTOR
        ),
        type_resolver(arrayfun_type),
        builtin_path = "crate::builtins::acceleration::gpu::arrayfun::tests"
    )]
    async fn arrayfun_test_handler(
        seed: Value,
        _err: Value,
        rest: Vec<Value>,
    ) -> crate::BuiltinResult<Value> {
        let _ = rest;
        Ok(seed)
    }
}