1use runmat_builtins::{
4 CellArray, Closure, ComplexTensor, LogicalArray, StructValue, Tensor, Value,
5};
6use runmat_macros::runtime_builtin;
7
8use crate::builtins::cells::type_resolvers::cellfun_type;
9use crate::builtins::common::shape::{dims_to_row_tensor, value_numel};
10use crate::builtins::common::spec::{
11 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12 ReductionNaN, ResidencyPolicy, ShapeRequirements,
13};
14use crate::{
15 build_runtime_error, call_builtin_async, gather_if_needed_async, make_cell_with_shape,
16 user_functions, BuiltinResult, RuntimeError,
17};
18
// GPU metadata for `cellfun`. The builtin never executes on a device: callbacks
// run on the host (see `notes`), so no precisions or provider hooks are
// advertised and GPU-resident inputs are gathered before evaluation.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::cells::core::cellfun")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "cellfun",
    // Custom host-side op kind; there is no device kernel behind it.
    op_kind: GpuOpKind::Custom("host-cell-map"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Inputs are pulled back to the host as soon as the builtin sees them.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Executes on the host and gathers GPU-resident inputs before evaluating callbacks.",
};
34
// Fusion metadata for `cellfun`. Arbitrary host callbacks cannot be fused, so
// no elementwise or reduction template is provided and planners are told (via
// `notes`) to treat the call as a fusion barrier.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::cells::core::cellfun")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "cellfun",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "Callback execution happens on the host; fusion planners should treat cellfun as a fusion barrier.",
};
45
// Error identifiers attached to the errors this builtin raises.

/// Malformed arguments: bad function spec, missing or mismatched cell inputs,
/// unknown name-value options, bad `isclass`/`prodofsize` arguments.
const IDENT_INVALID_INPUT: &str = "RunMat:cellfun:InvalidInput";
/// `UniformOutput` problems: an invalid flag value, or a callback returning a
/// non-scalar result while `UniformOutput` is true.
const IDENT_UNIFORM_OUTPUT: &str = "RunMat:cellfun:UniformOutput";
/// Fallback identifier reported to error handlers when a callback error carries
/// no recognisable identifier; also used when `isclass` cannot read a class name.
const IDENT_FUNCTION_ERROR: &str = "RunMat:cellfun:FunctionError";
49
50fn cellfun_error(message: impl Into<String>) -> RuntimeError {
51 build_runtime_error(message).with_builtin("cellfun").build()
52}
53
54fn cellfun_error_with_identifier(message: impl Into<String>, identifier: &str) -> RuntimeError {
55 build_runtime_error(message)
56 .with_builtin("cellfun")
57 .with_identifier(identifier)
58 .build()
59}
60
61#[runtime_builtin(
62 name = "cellfun",
63 category = "cells/core",
64 summary = "Apply a function to the contents of each cell array element.",
65 keywords = "cellfun,cell,array,functional",
66 accel = "host",
67 type_resolver(cellfun_type),
68 builtin_path = "crate::builtins::cells::core::cellfun"
69)]
70async fn cellfun_builtin(func: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
71 let callable = Callable::from_function(func)?;
72 let mut args = rest;
73
74 let mut uniform_output = true;
75 let mut error_handler: Option<Callable> = None;
76
77 while args.len() >= 2 {
78 let name_candidate = args[args.len() - 2].clone();
79 let Some(name) = extract_string(&name_candidate) else {
80 break;
81 };
82 let value = args.pop().expect("value present");
83 args.pop();
84 match name.to_ascii_lowercase().as_str() {
85 "uniformoutput" => {
86 uniform_output = parse_uniform_output(value)?;
87 }
88 "errorhandler" => {
89 error_handler = Some(Callable::from_function(value)?);
90 }
91 unknown => {
92 return Err(cellfun_error_with_identifier(
93 format!("cellfun: unknown name-value argument '{unknown}'"),
94 IDENT_INVALID_INPUT,
95 ));
96 }
97 }
98 }
99
100 if args.is_empty() {
101 return Err(cellfun_error_with_identifier(
102 "cellfun: expected at least one cell array input",
103 IDENT_INVALID_INPUT,
104 ));
105 }
106
107 let mut cell_inputs: Vec<CellArray> = Vec::new();
108 let mut extra_args: Vec<Value> = Vec::new();
109 let mut seen_non_cell = false;
110
111 for value in args.into_iter() {
112 match value {
113 Value::Cell(ca) if !seen_non_cell => cell_inputs.push(ca),
114 Value::Cell(_) => {
115 return Err(cellfun_error_with_identifier(
116 "cellfun: cell array inputs must precede extra arguments",
117 IDENT_INVALID_INPUT,
118 ));
119 }
120 other => {
121 seen_non_cell = true;
122 extra_args.push(other);
123 }
124 }
125 }
126
127 if cell_inputs.is_empty() {
128 return Err(cellfun_error_with_identifier(
129 "cellfun: expected at least one cell array input",
130 IDENT_INVALID_INPUT,
131 ));
132 }
133
134 let reference_shape = cell_inputs[0].shape.clone();
135 for (idx, ca) in cell_inputs.iter().enumerate().skip(1) {
136 if ca.shape != reference_shape {
137 return Err(cellfun_error_with_identifier(
138 format!(
139 "cellfun: cell array input {} does not match the size of the first input",
140 idx + 1
141 ),
142 IDENT_INVALID_INPUT,
143 ));
144 }
145 }
146
147 if uniform_output {
148 execute_uniform(
149 &callable,
150 &cell_inputs,
151 &extra_args,
152 error_handler,
153 &reference_shape,
154 )
155 .await
156 } else {
157 execute_cell(
158 &callable,
159 &cell_inputs,
160 &extra_args,
161 error_handler,
162 &reference_shape,
163 )
164 .await
165 }
166}
167
168async fn execute_uniform(
169 callable: &Callable,
170 cell_inputs: &[CellArray],
171 extra_args: &[Value],
172 error_handler: Option<Callable>,
173 shape: &[usize],
174) -> BuiltinResult<Value> {
175 let element_count = total_len(shape).ok_or_else(|| {
176 cellfun_error_with_identifier(
177 "cellfun: cell array size exceeds platform limits",
178 IDENT_INVALID_INPUT,
179 )
180 })?;
181
182 let host_extra_args = prepare_extra_args(extra_args).await?;
183 let mut collector = UniformCollector::Pending;
184 let mut cell_values: Vec<Value> = Vec::with_capacity(cell_inputs.len());
185 let mut call_args: Vec<Value> = Vec::with_capacity(cell_inputs.len() + host_extra_args.len());
186
187 for linear_idx in 0..element_count {
188 cell_values.clear();
189 for cell in cell_inputs {
190 let raw = deref_cell_value(cell, linear_idx);
191 let host_value = gather_if_needed_async(&raw).await?;
192 cell_values.push(host_value);
193 }
194 call_args.clear();
195 call_args.extend(cell_values.iter().cloned());
196 call_args.extend(host_extra_args.iter().cloned());
197
198 let result = match callable.call(&call_args).await {
199 Ok(value) => value,
200 Err(err) => {
201 let Some(handler) = error_handler.as_ref() else {
202 return Err(err);
203 };
204 let err_value = make_error_struct(&err, linear_idx, shape)?;
205 let mut handler_args =
206 Vec::with_capacity(1 + cell_values.len() + host_extra_args.len());
207 handler_args.push(err_value);
208 handler_args.extend(cell_values.clone());
209 handler_args.extend(host_extra_args.iter().cloned());
210 handler.call(&handler_args).await?
211 }
212 };
213
214 let host_value = gather_if_needed_async(&result).await?;
215 collector.push(&host_value)?;
216 }
217
218 collector.finish(shape)
219}
220
221async fn execute_cell(
222 callable: &Callable,
223 cell_inputs: &[CellArray],
224 extra_args: &[Value],
225 error_handler: Option<Callable>,
226 shape: &[usize],
227) -> BuiltinResult<Value> {
228 let element_count = total_len(shape).ok_or_else(|| {
229 cellfun_error_with_identifier(
230 "cellfun: cell array size exceeds platform limits",
231 IDENT_INVALID_INPUT,
232 )
233 })?;
234 let host_extra_args = prepare_extra_args(extra_args).await?;
235 let mut outputs: Vec<Value> = Vec::with_capacity(element_count);
236 let mut cell_values: Vec<Value> = Vec::with_capacity(cell_inputs.len());
237 let mut call_args: Vec<Value> = Vec::with_capacity(cell_inputs.len() + host_extra_args.len());
238
239 for linear_idx in 0..element_count {
240 cell_values.clear();
241 for cell in cell_inputs {
242 let raw = deref_cell_value(cell, linear_idx);
243 let host_value = gather_if_needed_async(&raw).await?;
244 cell_values.push(host_value);
245 }
246 call_args.clear();
247 call_args.extend(cell_values.iter().cloned());
248 call_args.extend(host_extra_args.iter().cloned());
249
250 let result = match callable.call(&call_args).await {
251 Ok(value) => value,
252 Err(err) => {
253 let Some(handler) = error_handler.as_ref() else {
254 return Err(err);
255 };
256 let err_value = make_error_struct(&err, linear_idx, shape)?;
257 let mut handler_args =
258 Vec::with_capacity(1 + cell_values.len() + host_extra_args.len());
259 handler_args.push(err_value);
260 handler_args.extend(cell_values.clone());
261 handler_args.extend(host_extra_args.iter().cloned());
262 handler.call(&handler_args).await?
263 }
264 };
265
266 let host_value = gather_if_needed_async(&result).await?;
267 outputs.push(host_value);
268 }
269
270 make_cell_with_shape(outputs, shape.to_vec())
271 .map_err(|e| cellfun_error(format!("cellfun: {e}")))
272}
273
274fn deref_cell_value(cell: &CellArray, index: usize) -> Value {
275 cell.data
276 .get(index)
277 .map(|ptr| (**ptr).clone())
278 .unwrap_or(Value::Num(f64::NAN))
279}
280
/// Product of all dimensions in `shape`, or `None` when it overflows `usize`.
/// An empty shape is treated as having no elements.
fn total_len(shape: &[usize]) -> Option<usize> {
    match shape {
        [] => Some(0),
        dims => dims
            .iter()
            .try_fold(1usize, |product, &extent| product.checked_mul(extent)),
    }
}
291fn extract_string(value: &Value) -> Option<String> {
292 match value {
293 Value::String(s) => Some(s.clone()),
294 Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
295 Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
296 _ => None,
297 }
298}
299
300async fn prepare_extra_args(extra_args: &[Value]) -> BuiltinResult<Vec<Value>> {
301 let mut host_args = Vec::with_capacity(extra_args.len());
302 for arg in extra_args {
303 host_args.push(gather_if_needed_async(arg).await?);
304 }
305 Ok(host_args)
306}
307
308fn parse_uniform_output(value: Value) -> BuiltinResult<bool> {
309 match value {
310 Value::Bool(b) => Ok(b),
311 Value::Num(n) => Ok(n != 0.0),
312 Value::Int(iv) => Ok(iv.to_f64() != 0.0),
313 Value::String(s) => parse_bool_string(&s).ok_or_else(|| {
314 cellfun_error_with_identifier(
315 "cellfun: UniformOutput must be logical true or false",
316 IDENT_UNIFORM_OUTPUT,
317 )
318 }),
319 Value::CharArray(ca) if ca.rows == 1 => {
320 let s: String = ca.data.iter().collect();
321 parse_bool_string(&s).ok_or_else(|| {
322 cellfun_error_with_identifier(
323 "cellfun: UniformOutput must be logical true or false",
324 IDENT_UNIFORM_OUTPUT,
325 )
326 })
327 }
328 other => Err(cellfun_error_with_identifier(
329 format!("cellfun: UniformOutput must be logical true or false, got {other:?}"),
330 IDENT_UNIFORM_OUTPUT,
331 )),
332 }
333}
334
335fn parse_bool_string(value: &str) -> Option<bool> {
336 match value.trim().to_ascii_lowercase().as_str() {
337 "true" | "on" => Some(true),
338 "false" | "off" => Some(false),
339 _ => None,
340 }
341}
342
343fn make_error_struct(
344 raw_error: &RuntimeError,
345 linear_index: usize,
346 shape: &[usize],
347) -> BuiltinResult<Value> {
348 let (identifier, message) = error_identifier_and_message(raw_error);
349 let mut st = StructValue::new();
350 st.fields
351 .insert("identifier".to_string(), Value::String(identifier));
352 st.fields
353 .insert("message".to_string(), Value::String(message));
354 st.fields
355 .insert("index".to_string(), Value::Num((linear_index + 1) as f64));
356 let subs = linear_to_indices(linear_index, shape);
357 let subs_tensor =
358 dims_to_row_tensor(&subs).map_err(|e| cellfun_error(format!("cellfun: {e}")))?;
359 st.fields
360 .insert("indices".to_string(), Value::Tensor(subs_tensor));
361 Ok(Value::Struct(st))
362}
363
364fn error_identifier_and_message(error: &RuntimeError) -> (String, String) {
365 if let Some(identifier) = error.identifier() {
366 return (identifier.to_string(), error.message().to_string());
367 }
368 split_error_message(error.message())
369}
370
371fn split_error_message(raw: &str) -> (String, String) {
372 let trimmed = raw.trim();
373 let mut indices = trimmed.match_indices(':');
374 if let Some((_, _)) = indices.next() {
375 if let Some((second_idx, _)) = indices.next() {
376 let identifier = trimmed[..second_idx].trim().to_string();
377 let message = trimmed[second_idx + 1..].trim().to_string();
378 if !identifier.is_empty() && identifier.contains(':') {
379 return (
380 identifier,
381 if message.is_empty() {
382 trimmed.to_string()
383 } else {
384 message
385 },
386 );
387 }
388 } else if trimmed.len() >= 7
389 && (trimmed[..7].eq_ignore_ascii_case("matlab:")
390 || trimmed[..7].eq_ignore_ascii_case("runmat:"))
391 {
392 return (trimmed.to_string(), String::new());
393 }
394 }
395 (IDENT_FUNCTION_ERROR.to_string(), trimmed.to_string())
396}
397
/// Converts a 0-based column-major linear index into 1-based subscripts, one
/// per dimension of `shape`. Zero-length dimensions report subscript 1 without
/// consuming any of the index; an empty shape yields `[1]`.
fn linear_to_indices(mut index: usize, shape: &[usize]) -> Vec<usize> {
    match shape {
        [] => vec![1],
        dims => dims
            .iter()
            .map(|&extent| {
                if extent == 0 {
                    return 1;
                }
                let subscript = index % extent + 1;
                index /= extent;
                subscript
            })
            .collect(),
    }
}
414
/// A resolved `cellfun` callback.
#[derive(Clone)]
enum Callable {
    /// A function referenced by name; `Callable::call` tries a user-defined
    /// function of that name before falling back to the builtin registry.
    Builtin { name: String },
    /// A closure; its captured values are passed ahead of the per-element arguments.
    Closure(Closure),
    /// One of the legacy string callbacks (`'isclass'`, `'prodofsize'`) evaluated in-place.
    Special(SpecialCallable),
}
421
422impl Callable {
423 fn from_function(value: Value) -> BuiltinResult<Self> {
424 match value {
425 Value::String(s) => Self::from_text(&s, true),
426 Value::CharArray(ca) => {
427 if ca.rows != 1 {
428 Err(cellfun_error_with_identifier(
429 "cellfun: function name must be a character vector or string scalar",
430 IDENT_INVALID_INPUT,
431 ))
432 } else {
433 let text: String = ca.data.iter().collect();
434 Self::from_text(&text, true)
435 }
436 }
437 Value::StringArray(sa) => {
438 if sa.data.len() == 1 {
439 Self::from_text(&sa.data[0], true)
440 } else {
441 Err(cellfun_error_with_identifier(
442 "cellfun: function name must be a character vector or string scalar",
443 IDENT_INVALID_INPUT,
444 ))
445 }
446 }
447 Value::FunctionHandle(name) => Self::from_text(&name, true),
448 Value::Closure(c) => Ok(Callable::Closure(c)),
449 other => Err(cellfun_error_with_identifier(
450 format!("cellfun: expected function handle or builtin name, got {other:?}"),
451 IDENT_INVALID_INPUT,
452 )),
453 }
454 }
455
456 fn from_text(text: &str, fold_case: bool) -> BuiltinResult<Self> {
457 let trimmed = text.trim();
458 if trimmed.is_empty() {
459 return Err(cellfun_error_with_identifier(
460 "cellfun: expected function handle or builtin name, got empty string",
461 IDENT_INVALID_INPUT,
462 ));
463 }
464 if let Some(rest) = trimmed.strip_prefix('@') {
465 let name = rest.trim();
466 if name.is_empty() {
467 Err(cellfun_error_with_identifier(
468 "cellfun: empty function handle",
469 IDENT_INVALID_INPUT,
470 ))
471 } else {
472 Ok(Callable::Builtin {
473 name: name.to_string(),
474 })
475 }
476 } else {
477 let lowered = trimmed.to_ascii_lowercase();
478 if fold_case && lowered == "isclass" {
479 Ok(Callable::Special(SpecialCallable::IsClass))
480 } else if fold_case && lowered == "prodofsize" {
481 Ok(Callable::Special(SpecialCallable::ProdOfSize))
482 } else {
483 let name = if fold_case {
484 lowered
485 } else {
486 trimmed.to_string()
487 };
488 Ok(Callable::Builtin { name })
489 }
490 }
491 }
492
493 async fn call(&self, args: &[Value]) -> BuiltinResult<Value> {
494 fn is_undefined_function(err: &RuntimeError) -> bool {
495 let identifier = err.identifier().unwrap_or("").to_ascii_lowercase();
496 let message = err.message().to_ascii_lowercase();
497 identifier.contains("undefinedfunction") || message.contains("undefined function")
498 }
499 match self {
500 Callable::Builtin { name } => {
501 if let Some(result) = user_functions::try_call_user_function(name, args).await {
502 match result {
503 Ok(value) => return Ok(value),
504 Err(err) => {
505 if !is_undefined_function(&err) {
506 return Err(err);
507 }
508 }
509 }
510 }
511 call_builtin_async(name, args).await
512 }
513 Callable::Closure(c) => {
514 let mut captures = c.captures.clone();
515 captures.extend_from_slice(args);
516 if let Some(result) =
517 user_functions::try_call_user_function(&c.function_name, &captures).await
518 {
519 match result {
520 Ok(value) => return Ok(value),
521 Err(err) => {
522 if !is_undefined_function(&err) {
523 return Err(err);
524 }
525 }
526 }
527 }
528 call_builtin_async(&c.function_name, &captures).await
529 }
530 Callable::Special(special) => special.call(args).await,
531 }
532 }
533}
534
/// Legacy `cellfun` string callbacks evaluated directly instead of being
/// dispatched by name.
#[derive(Clone)]
enum SpecialCallable {
    /// `'prodofsize'`: the number of elements in each cell's contents.
    ProdOfSize,
    /// `'isclass'`: whether `class(x)` equals the class-name extra argument
    /// (ASCII case-insensitive comparison).
    IsClass,
}
540
541impl SpecialCallable {
542 async fn call(&self, args: &[Value]) -> BuiltinResult<Value> {
543 match self {
544 SpecialCallable::ProdOfSize => {
545 let value = args.first().ok_or_else(|| {
546 cellfun_error_with_identifier(
547 "cellfun: prodofsize requires one input",
548 IDENT_INVALID_INPUT,
549 )
550 })?;
551 Ok(Value::Num(value_numel(value).await? as f64))
552 }
553 SpecialCallable::IsClass => {
554 if args.len() < 2 {
555 return Err(cellfun_error_with_identifier(
556 "cellfun: 'isclass' requires a class name argument",
557 IDENT_INVALID_INPUT,
558 ));
559 }
560 let left = args[0].clone();
561 let class_name = extract_string(&args[1]).ok_or_else(|| {
562 cellfun_error_with_identifier(
563 "cellfun: class name must be a string scalar",
564 IDENT_INVALID_INPUT,
565 )
566 })?;
567 let class_value = call_builtin_async("class", &[left]).await?;
568 let class_str = extract_string(&class_value).ok_or_else(|| {
569 cellfun_error_with_identifier(
570 "cellfun: failed to evaluate class name",
571 IDENT_FUNCTION_ERROR,
572 )
573 })?;
574 Ok(Value::Bool(
575 class_str.eq_ignore_ascii_case(class_name.trim()),
576 ))
577 }
578 }
579 }
580}
581
/// Accumulates scalar callback results while `UniformOutput` is true.
///
/// Storage is promoted when a wider value arrives: logical results can become
/// double or complex, and double results can become complex.
enum UniformCollector {
    /// No result has been pushed yet.
    Pending,
    /// Real double results.
    Double(Vec<f64>),
    /// Logical results stored as 0/1 bytes.
    Logical(Vec<u8>),
    /// Complex results as `(re, im)` pairs.
    Complex(Vec<(f64, f64)>),
}
588
589impl UniformCollector {
590 fn push(&mut self, value: &Value) -> BuiltinResult<()> {
591 match self {
592 UniformCollector::Pending => match classify_value(value)? {
593 ClassifiedValue::Logical(b) => {
594 *self = UniformCollector::Logical(vec![b as u8]);
595 Ok(())
596 }
597 ClassifiedValue::Double(d) => {
598 *self = UniformCollector::Double(vec![d]);
599 Ok(())
600 }
601 ClassifiedValue::Complex(c) => {
602 *self = UniformCollector::Complex(vec![c]);
603 Ok(())
604 }
605 },
606 UniformCollector::Logical(bits) => match classify_value(value)? {
607 ClassifiedValue::Logical(b) => {
608 bits.push(b as u8);
609 Ok(())
610 }
611 ClassifiedValue::Double(d) => {
612 let mut data: Vec<f64> = bits
613 .iter()
614 .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
615 .collect();
616 data.push(d);
617 *self = UniformCollector::Double(data);
618 Ok(())
619 }
620 ClassifiedValue::Complex(c) => {
621 let mut data: Vec<(f64, f64)> = bits
622 .iter()
623 .map(|&bit| if bit != 0 { (1.0, 0.0) } else { (0.0, 0.0) })
624 .collect();
625 data.push(c);
626 *self = UniformCollector::Complex(data);
627 Ok(())
628 }
629 },
630 UniformCollector::Double(data) => match classify_value(value)? {
631 ClassifiedValue::Logical(b) => {
632 data.push(if b { 1.0 } else { 0.0 });
633 Ok(())
634 }
635 ClassifiedValue::Double(d) => {
636 data.push(d);
637 Ok(())
638 }
639 ClassifiedValue::Complex(c) => {
640 let promoted: Vec<(f64, f64)> = data.iter().map(|&v| (v, 0.0)).collect();
641 let mut complex = promoted;
642 complex.push(c);
643 *self = UniformCollector::Complex(complex);
644 Ok(())
645 }
646 },
647 UniformCollector::Complex(data) => match classify_value(value)? {
648 ClassifiedValue::Logical(b) => {
649 data.push((if b { 1.0 } else { 0.0 }, 0.0));
650 Ok(())
651 }
652 ClassifiedValue::Double(d) => {
653 data.push((d, 0.0));
654 Ok(())
655 }
656 ClassifiedValue::Complex(c) => {
657 data.push(c);
658 Ok(())
659 }
660 },
661 }
662 }
663
664 fn finish(self, shape: &[usize]) -> BuiltinResult<Value> {
665 match self {
666 UniformCollector::Pending => {
667 let total = total_len(shape).unwrap_or(0);
668 let data = vec![0.0; total];
669 let tensor = Tensor::new(data, shape.to_vec())
670 .map_err(|e| cellfun_error(format!("cellfun: {e}")))?;
671 Ok(Value::Tensor(tensor))
672 }
673 UniformCollector::Double(data) => {
674 let tensor = Tensor::new(data, shape.to_vec())
675 .map_err(|e| cellfun_error(format!("cellfun: {e}")))?;
676 Ok(Value::Tensor(tensor))
677 }
678 UniformCollector::Logical(bits) => {
679 let logical = LogicalArray::new(bits, shape.to_vec())
680 .map_err(|e| cellfun_error(format!("cellfun: {e}")))?;
681 Ok(Value::LogicalArray(logical))
682 }
683 UniformCollector::Complex(data) => {
684 let complex = ComplexTensor::new(data, shape.to_vec())
685 .map_err(|e| cellfun_error(format!("cellfun: {e}")))?;
686 Ok(Value::ComplexTensor(complex))
687 }
688 }
689 }
690}
691
/// A scalar callback result reduced to the class that determines how the
/// `UniformCollector` stores it.
enum ClassifiedValue {
    /// Logical scalar.
    Logical(bool),
    /// Real scalar (doubles and integer values converted to `f64`).
    Double(f64),
    /// Complex scalar as `(re, im)`.
    Complex((f64, f64)),
}
697
698fn classify_value(value: &Value) -> BuiltinResult<ClassifiedValue> {
699 match value {
700 Value::Bool(b) => Ok(ClassifiedValue::Logical(*b)),
701 Value::Num(n) => Ok(ClassifiedValue::Double(*n)),
702 Value::Int(iv) => Ok(ClassifiedValue::Double(iv.to_f64())),
703 Value::Complex(re, im) => Ok(ClassifiedValue::Complex((*re, *im))),
704 Value::Tensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Double(t.data[0])),
705 Value::LogicalArray(la) if la.data.len() == 1 => {
706 Ok(ClassifiedValue::Logical(la.data[0] != 0))
707 }
708 Value::ComplexTensor(ct) if ct.data.len() == 1 => Ok(ClassifiedValue::Complex(ct.data[0])),
709 _ => Err(cellfun_error_with_identifier(
710 "cellfun: callback must return scalar values when 'UniformOutput' is true",
711 IDENT_UNIFORM_OUTPUT,
712 )),
713 }
714}
715
// Unit tests for `cellfun`: uniform and cell outputs, multiple cell inputs,
// name-value options, error handlers, legacy string callbacks, extra arguments,
// and GPU-resident cell contents.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, StringArray};
    use std::convert::TryInto;

    // Synchronous wrapper around the async builtin so tests can call it directly.
    fn cellfun_builtin(func: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::cellfun_builtin(func, rest))
    }

    // Default UniformOutput: `@length` over row vectors yields a numeric row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_length_uniform_default() {
        let cell = crate::make_cell(
            vec![
                Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap()),
                Value::Tensor(Tensor::new(vec![4.0, 5.0, 6.0, 7.0], vec![1, 4]).unwrap()),
                Value::Tensor(Tensor::new(vec![8.0, 9.0], vec![1, 2]).unwrap()),
            ],
            1,
            3,
        )
        .expect("cell");
        let result =
            cellfun_builtin(Value::String("@length".into()), vec![cell]).expect("cellfun length");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 3]);
                assert_eq!(t.data, vec![3.0, 4.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Two cell inputs are zipped element-wise into a binary callback.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_multiple_cells_plus() {
        let left = crate::make_cell(
            vec![Value::Num(1.0), Value::Num(2.0), Value::Num(3.0)],
            1,
            3,
        )
        .expect("cell");
        let right = crate::make_cell(
            vec![Value::Num(4.0), Value::Num(5.0), Value::Num(6.0)],
            1,
            3,
        )
        .expect("cell");
        let result = cellfun_builtin(Value::String("@__cellfun_add".into()), vec![left, right])
            .expect("cellfun add");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![5.0, 7.0, 9.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // UniformOutput=false returns a cell of the same shape holding raw results.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_uniform_false_returns_cells() {
        let cell = crate::make_cell(
            vec![
                Value::String("Ada".into()),
                Value::String("Linus".into()),
                Value::String("Katherine".into()),
            ],
            1,
            3,
        )
        .expect("cell");
        let result = cellfun_builtin(
            Value::String("@upper".into()),
            vec![
                cell,
                Value::String("UniformOutput".into()),
                Value::Bool(false),
            ],
        )
        .expect("cellfun upper");
        match result {
            Value::Cell(ca) => {
                assert_eq!(ca.shape, vec![1, 3]);
                let upper_a = (*ca.data[0]).clone();
                let upper_b = (*ca.data[1]).clone();
                let upper_c = (*ca.data[2]).clone();
                assert_eq!(extract_string(&upper_a).unwrap(), "ADA");
                assert_eq!(extract_string(&upper_b).unwrap(), "LINUS");
                assert_eq!(extract_string(&upper_c).unwrap(), "KATHERINE");
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A failing callback is replaced by the ErrorHandler's result; the handler
    // closure captures a seed value it returns for every element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_error_handler_recovers() {
        let cells = crate::make_cell(
            vec![Value::Num(1.0), Value::Num(2.0), Value::Num(3.0)],
            1,
            3,
        )
        .expect("cell");
        let handler = Value::Closure(Closure {
            function_name: "__cellfun_test_handler".into(),
            captures: vec![Value::Num(0.0)],
        });
        let result = cellfun_builtin(
            Value::String("@nonexistent_builtin".into()),
            vec![cells, Value::String("ErrorHandler".into()), handler],
        )
        .expect("cellfun error handler");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![0.0, 0.0, 0.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // A char-array function name (no '@') resolves to the builtin and logical
    // results produce a logical array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_string_identifier() {
        let cells = crate::make_cell(
            vec![
                Value::CharArray(runmat_builtins::CharArray::new_row("")),
                Value::CharArray(runmat_builtins::CharArray::new_row("abc")),
                Value::CharArray(runmat_builtins::CharArray::new_row("")),
            ],
            1,
            3,
        )
        .expect("cell");
        let result = cellfun_builtin(
            Value::CharArray(runmat_builtins::CharArray::new_row("isempty")),
            vec![cells],
        )
        .expect("isempty");
        match result {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![1, 3]);
                assert_eq!(la.data, vec![1, 0, 1]);
            }
            other => panic!("expected logical array, got {other:?}"),
        }
    }

    // A one-element string array is accepted as the function name.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_string_array_identifier() {
        let cells = crate::make_cell(
            vec![Value::CharArray(runmat_builtins::CharArray::new_row(""))],
            1,
            1,
        )
        .expect("cell");
        let sa = StringArray::new(vec!["isempty".into()], vec![1, 1]).unwrap();
        let result =
            cellfun_builtin(Value::StringArray(sa), vec![cells]).expect("cellfun string array");
        match result {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![1, 1]);
                assert_eq!(la.data, vec![1]);
            }
            other => panic!("expected logical array, got {other:?}"),
        }
    }

    // Non-scalar callback results are rejected while UniformOutput is true.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_uniform_true_non_scalar_errors() {
        let cells = crate::make_cell(
            vec![Value::Tensor(
                Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap(),
            )],
            1,
            1,
        )
        .expect("cell");
        let err = cellfun_builtin(Value::String("@eye".into()), vec![cells])
            .unwrap_err()
            .to_string();
        assert!(
            err.to_ascii_lowercase().contains("uniformoutput"),
            "unexpected error: {err}"
        );
    }

    // Mixed logical/double results are promoted to a double tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_uniform_promotes_logical_to_double() {
        let cells = crate::make_cell(vec![Value::Bool(true), Value::Num(2.5)], 1, 2).unwrap();
        let result = cellfun_builtin(Value::String("@__cellfun_identity".into()), vec![cells])
            .expect("cellfun identity");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![1.0, 2.5]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Mixed double/complex results are promoted to a complex tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_uniform_promotes_double_to_complex() {
        let cells =
            crate::make_cell(vec![Value::Num(2.0), Value::Complex(0.0, 1.0)], 1, 2).unwrap();
        let result = cellfun_builtin(Value::String("@__cellfun_identity".into()), vec![cells])
            .expect("cellfun identity");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 2]);
                assert_eq!(ct.data, vec![(2.0, 0.0), (0.0, 1.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Cell inputs of different shapes are rejected before any callback runs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_errors_on_mismatched_cell_sizes() {
        let first = crate::make_cell(vec![Value::Num(1.0), Value::Num(2.0)], 1, 2).unwrap();
        let second = crate::make_cell(vec![Value::Num(3.0)], 1, 1).unwrap();
        let err = cellfun_builtin(
            Value::String("@__cellfun_identity".into()),
            vec![first, second],
        )
        .unwrap_err()
        .to_string();
        assert!(
            err.to_ascii_lowercase().contains("size"),
            "expected size mismatch error, got: {err}"
        );
    }

    // Char-array option names and 'on'/'off' flag values are accepted.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_uniformoutput_accepts_char_flags() {
        let strings =
            crate::make_cell(vec![Value::String("Ada".into())], 1, 1).expect("cell creation");
        let result = cellfun_builtin(
            Value::String("@upper".into()),
            vec![
                strings,
                Value::CharArray(runmat_builtins::CharArray::new_row("UniformOutput")),
                Value::CharArray(runmat_builtins::CharArray::new_row("off")),
            ],
        )
        .expect("cellfun upper char flag");
        assert!(
            matches!(result, Value::Cell(_)),
            "expected cell array result when UniformOutput is 'off'"
        );
    }

    // Legacy 'isclass' callback compares class(x) against the extra argument.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_isclass_special_case() {
        let ints = crate::make_cell(
            vec![
                Value::Int(IntValue::I32(5)),
                Value::Num(std::f64::consts::PI),
                Value::Int(IntValue::I16(2)),
            ],
            1,
            3,
        )
        .expect("cell");
        let result = cellfun_builtin(
            Value::String("isclass".into()),
            vec![ints, Value::String("int32".into())],
        )
        .expect("cellfun isclass");
        match result {
            Value::LogicalArray(la) => {
                assert_eq!(la.data, vec![1, 0, 0]);
            }
            other => panic!("expected logical array, got {other:?}"),
        }
    }

    // Non-cell trailing arguments are forwarded to each call (here size(x, 2)).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_passes_additional_arguments() {
        let matrices = crate::make_cell(
            vec![
                Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap()),
                Value::Tensor(Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap()),
            ],
            1,
            2,
        )
        .expect("cell");
        let dimension = Value::Num(2.0);
        let result = cellfun_builtin(Value::String("size".into()), vec![matrices, dimension])
            .expect("cellfun size");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Non-scalar results (string array lengths) are fine with UniformOutput=false.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_handles_string_array_uniform_false() {
        let sa = StringArray::new(vec!["foo".into(), "bar".into()], vec![1, 2]).unwrap();
        let cell = crate::make_cell(vec![Value::StringArray(sa)], 1, 1).unwrap();
        let result = cellfun_builtin(
            Value::String("@strlength".into()),
            vec![
                cell,
                Value::String("UniformOutput".into()),
                Value::Bool(false),
            ],
        )
        .unwrap();
        match result {
            Value::Cell(ca) => {
                assert_eq!(ca.shape, vec![1, 1]);
                let inner = (*ca.data[0]).clone();
                match inner {
                    Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 3.0]),
                    _ => panic!("expected tensor inside cell"),
                }
            }
            other => panic!("expected cell, got {other:?}"),
        }
    }

    // GPU-resident cell contents are gathered before the callback runs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn cellfun_gathers_gpu_inputs() {
        test_support::with_test_provider(|provider| {
            let angle = std::f64::consts::PI / 6.0;
            let tensor = Tensor::new(vec![angle], vec![1, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let cell = crate::make_cell(vec![Value::GpuTensor(handle)], 1, 1).expect("cell");
            let result =
                cellfun_builtin(Value::String("@sin".into()), vec![cell]).expect("cellfun sin");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![1, 1]);
            let expected = angle.sin();
            assert!((gathered.data[0] - expected).abs() < 1e-12);
        });
    }

    // Same as above against the real wgpu provider (only with the `wgpu` feature).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn cellfun_with_wgpu_provider_handles_gpu_cells() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let value = Tensor::new(vec![0.25], vec![1, 1]).unwrap();
        let view = HostTensorView {
            data: &value.data,
            shape: &value.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let cell = crate::make_cell(vec![Value::GpuTensor(handle)], 1, 1).expect("cell");

        let result =
            cellfun_builtin(Value::String("@sin".into()), vec![cell]).expect("cellfun sin");
        let gathered = test_support::gather(result).expect("gather");
        assert_eq!(gathered.shape, vec![1, 1]);
        let expected = value.data[0].sin();
        assert!((gathered.data[0] - expected).abs() < 1e-12);
    }

    // Test-only ErrorHandler target: receives the closure's captured seed, then
    // the error struct, then the failing element's arguments; returns the seed.
    #[runmat_macros::runtime_builtin(
        name = "__cellfun_test_handler",
        type_resolver(cellfun_type),
        builtin_path = "crate::builtins::cells::core::cellfun::tests"
    )]
    fn cellfun_test_handler(
        seed: Value,
        _err: Value,
        rest: Vec<Value>,
    ) -> crate::BuiltinResult<Value> {
        let _ = rest;
        Ok(seed)
    }

    // Test-only binary callback: adds two scalars.
    #[runmat_macros::runtime_builtin(
        name = "__cellfun_add",
        type_resolver(cellfun_type),
        builtin_path = "crate::builtins::cells::core::cellfun::tests"
    )]
    fn cellfun_add(lhs: Value, rhs: Value) -> crate::BuiltinResult<Value> {
        let a: f64 = (&lhs).try_into()?;
        let b: f64 = (&rhs).try_into()?;
        Ok(Value::Num(a + b))
    }

    // Test-only unary callback: returns its argument unchanged.
    #[runmat_macros::runtime_builtin(
        name = "__cellfun_identity",
        type_resolver(cellfun_type),
        builtin_path = "crate::builtins::cells::core::cellfun::tests"
    )]
    fn cellfun_identity(value: Value) -> crate::BuiltinResult<Value> {
        Ok(value)
    }
}