Skip to main content

runmat_runtime/builtins/cells/core/
cellfun.rs

1//! MATLAB-compatible `cellfun` builtin with host execution semantics for RunMat.
2
3use runmat_builtins::{
4    CellArray, Closure, ComplexTensor, LogicalArray, StructValue, Tensor, Value,
5};
6use runmat_macros::runtime_builtin;
7
8use crate::builtins::cells::type_resolvers::cellfun_type;
9use crate::builtins::common::shape::{dims_to_row_tensor, value_numel};
10use crate::builtins::common::spec::{
11    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12    ReductionNaN, ResidencyPolicy, ShapeRequirements,
13};
14use crate::{
15    build_runtime_error, call_builtin_async, gather_if_needed_async, make_cell_with_shape,
16    user_functions, BuiltinResult, RuntimeError,
17};
18
// GPU dispatch metadata for `cellfun`, registered through `register_gpu_spec`.
// The spec declares a custom "host-cell-map" operation with no supported GPU
// precisions and no provider hooks; per the `notes` field the builtin runs on
// the host and gathers GPU-resident inputs before invoking callbacks.
19#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::cells::core::cellfun")]
20pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
21    name: "cellfun",
22    op_kind: GpuOpKind::Custom("host-cell-map"),
23    supported_precisions: &[],
24    broadcast: BroadcastSemantics::None,
25    provider_hooks: &[],
26    constant_strategy: ConstantStrategy::InlineLiteral,
27    residency: ResidencyPolicy::GatherImmediately,
28    nan_mode: ReductionNaN::Include,
29    two_pass_threshold: None,
30    workgroup_size: None,
31    accepts_nan_mode: false,
32    notes: "Executes on the host and gathers GPU-resident inputs before evaluating callbacks.",
33};
34
// Fusion metadata for `cellfun`, registered through `register_fusion_spec`.
// No elementwise or reduction kernel is provided; per the `notes` field,
// fusion planners are expected to treat `cellfun` as a fusion barrier.
35#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::cells::core::cellfun")]
36pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
37    name: "cellfun",
38    shape: ShapeRequirements::Any,
39    constant_strategy: ConstantStrategy::InlineLiteral,
40    elementwise: None,
41    reduction: None,
42    emits_nan: true,
43    notes: "Callback execution happens on the host; fusion planners should treat cellfun as a fusion barrier.",
44};
45
// Error identifiers attached to the errors raised in this file.
// Used for malformed arguments (bad callable, unknown option, mismatched cell sizes, ...).
46const IDENT_INVALID_INPUT: &str = "RunMat:cellfun:InvalidInput";
// Used for an invalid 'UniformOutput' value or a non-scalar callback result in uniform mode.
47const IDENT_UNIFORM_OUTPUT: &str = "RunMat:cellfun:UniformOutput";
// Fallback identifier reported to error handlers when a callback error carries none.
48const IDENT_FUNCTION_ERROR: &str = "RunMat:cellfun:FunctionError";
49
50fn cellfun_error(message: impl Into<String>) -> RuntimeError {
51    build_runtime_error(message).with_builtin("cellfun").build()
52}
53
54fn cellfun_error_with_identifier(message: impl Into<String>, identifier: &str) -> RuntimeError {
55    build_runtime_error(message)
56        .with_builtin("cellfun")
57        .with_identifier(identifier)
58        .build()
59}
60
61#[runtime_builtin(
62    name = "cellfun",
63    category = "cells/core",
64    summary = "Apply a function to the contents of each cell array element.",
65    keywords = "cellfun,cell,array,functional",
66    accel = "host",
67    type_resolver(cellfun_type),
68    builtin_path = "crate::builtins::cells::core::cellfun"
69)]
70async fn cellfun_builtin(func: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
71    let callable = Callable::from_function(func)?;
72    let mut args = rest;
73
74    let mut uniform_output = true;
75    let mut error_handler: Option<Callable> = None;
76
77    while args.len() >= 2 {
78        let name_candidate = args[args.len() - 2].clone();
79        let Some(name) = extract_string(&name_candidate) else {
80            break;
81        };
82        let value = args.pop().expect("value present");
83        args.pop();
84        match name.to_ascii_lowercase().as_str() {
85            "uniformoutput" => {
86                uniform_output = parse_uniform_output(value)?;
87            }
88            "errorhandler" => {
89                error_handler = Some(Callable::from_function(value)?);
90            }
91            unknown => {
92                return Err(cellfun_error_with_identifier(
93                    format!("cellfun: unknown name-value argument '{unknown}'"),
94                    IDENT_INVALID_INPUT,
95                ));
96            }
97        }
98    }
99
100    if args.is_empty() {
101        return Err(cellfun_error_with_identifier(
102            "cellfun: expected at least one cell array input",
103            IDENT_INVALID_INPUT,
104        ));
105    }
106
107    let mut cell_inputs: Vec<CellArray> = Vec::new();
108    let mut extra_args: Vec<Value> = Vec::new();
109    let mut seen_non_cell = false;
110
111    for value in args.into_iter() {
112        match value {
113            Value::Cell(ca) if !seen_non_cell => cell_inputs.push(ca),
114            Value::Cell(_) => {
115                return Err(cellfun_error_with_identifier(
116                    "cellfun: cell array inputs must precede extra arguments",
117                    IDENT_INVALID_INPUT,
118                ));
119            }
120            other => {
121                seen_non_cell = true;
122                extra_args.push(other);
123            }
124        }
125    }
126
127    if cell_inputs.is_empty() {
128        return Err(cellfun_error_with_identifier(
129            "cellfun: expected at least one cell array input",
130            IDENT_INVALID_INPUT,
131        ));
132    }
133
134    let reference_shape = cell_inputs[0].shape.clone();
135    for (idx, ca) in cell_inputs.iter().enumerate().skip(1) {
136        if ca.shape != reference_shape {
137            return Err(cellfun_error_with_identifier(
138                format!(
139                    "cellfun: cell array input {} does not match the size of the first input",
140                    idx + 1
141                ),
142                IDENT_INVALID_INPUT,
143            ));
144        }
145    }
146
147    if uniform_output {
148        execute_uniform(
149            &callable,
150            &cell_inputs,
151            &extra_args,
152            error_handler,
153            &reference_shape,
154        )
155        .await
156    } else {
157        execute_cell(
158            &callable,
159            &cell_inputs,
160            &extra_args,
161            error_handler,
162            &reference_shape,
163        )
164        .await
165    }
166}
167
168async fn execute_uniform(
169    callable: &Callable,
170    cell_inputs: &[CellArray],
171    extra_args: &[Value],
172    error_handler: Option<Callable>,
173    shape: &[usize],
174) -> BuiltinResult<Value> {
175    let element_count = total_len(shape).ok_or_else(|| {
176        cellfun_error_with_identifier(
177            "cellfun: cell array size exceeds platform limits",
178            IDENT_INVALID_INPUT,
179        )
180    })?;
181
182    let host_extra_args = prepare_extra_args(extra_args).await?;
183    let mut collector = UniformCollector::Pending;
184    let mut cell_values: Vec<Value> = Vec::with_capacity(cell_inputs.len());
185    let mut call_args: Vec<Value> = Vec::with_capacity(cell_inputs.len() + host_extra_args.len());
186
187    for linear_idx in 0..element_count {
188        cell_values.clear();
189        for cell in cell_inputs {
190            let raw = deref_cell_value(cell, linear_idx);
191            let host_value = gather_if_needed_async(&raw).await?;
192            cell_values.push(host_value);
193        }
194        call_args.clear();
195        call_args.extend(cell_values.iter().cloned());
196        call_args.extend(host_extra_args.iter().cloned());
197
198        let result = match callable.call(&call_args).await {
199            Ok(value) => value,
200            Err(err) => {
201                let Some(handler) = error_handler.as_ref() else {
202                    return Err(err);
203                };
204                let err_value = make_error_struct(&err, linear_idx, shape)?;
205                let mut handler_args =
206                    Vec::with_capacity(1 + cell_values.len() + host_extra_args.len());
207                handler_args.push(err_value);
208                handler_args.extend(cell_values.clone());
209                handler_args.extend(host_extra_args.iter().cloned());
210                handler.call(&handler_args).await?
211            }
212        };
213
214        let host_value = gather_if_needed_async(&result).await?;
215        collector.push(&host_value)?;
216    }
217
218    collector.finish(shape)
219}
220
221async fn execute_cell(
222    callable: &Callable,
223    cell_inputs: &[CellArray],
224    extra_args: &[Value],
225    error_handler: Option<Callable>,
226    shape: &[usize],
227) -> BuiltinResult<Value> {
228    let element_count = total_len(shape).ok_or_else(|| {
229        cellfun_error_with_identifier(
230            "cellfun: cell array size exceeds platform limits",
231            IDENT_INVALID_INPUT,
232        )
233    })?;
234    let host_extra_args = prepare_extra_args(extra_args).await?;
235    let mut outputs: Vec<Value> = Vec::with_capacity(element_count);
236    let mut cell_values: Vec<Value> = Vec::with_capacity(cell_inputs.len());
237    let mut call_args: Vec<Value> = Vec::with_capacity(cell_inputs.len() + host_extra_args.len());
238
239    for linear_idx in 0..element_count {
240        cell_values.clear();
241        for cell in cell_inputs {
242            let raw = deref_cell_value(cell, linear_idx);
243            let host_value = gather_if_needed_async(&raw).await?;
244            cell_values.push(host_value);
245        }
246        call_args.clear();
247        call_args.extend(cell_values.iter().cloned());
248        call_args.extend(host_extra_args.iter().cloned());
249
250        let result = match callable.call(&call_args).await {
251            Ok(value) => value,
252            Err(err) => {
253                let Some(handler) = error_handler.as_ref() else {
254                    return Err(err);
255                };
256                let err_value = make_error_struct(&err, linear_idx, shape)?;
257                let mut handler_args =
258                    Vec::with_capacity(1 + cell_values.len() + host_extra_args.len());
259                handler_args.push(err_value);
260                handler_args.extend(cell_values.clone());
261                handler_args.extend(host_extra_args.iter().cloned());
262                handler.call(&handler_args).await?
263            }
264        };
265
266        let host_value = gather_if_needed_async(&result).await?;
267        outputs.push(host_value);
268    }
269
270    make_cell_with_shape(outputs, shape.to_vec())
271        .map_err(|e| cellfun_error(format!("cellfun: {e}")))
272}
273
274fn deref_cell_value(cell: &CellArray, index: usize) -> Value {
275    cell.data
276        .get(index)
277        .map(|ptr| (**ptr).clone())
278        .unwrap_or(Value::Num(f64::NAN))
279}
280
/// Computes the number of elements described by `shape`.
///
/// An empty `shape` yields `Some(0)`; otherwise the result is the product of
/// all dimensions, or `None` if that product overflows `usize`.
fn total_len(shape: &[usize]) -> Option<usize> {
    if shape.is_empty() {
        return Some(0);
    }
    let mut product = 1usize;
    for &dim in shape {
        product = product.checked_mul(dim)?;
    }
    Some(product)
}
290
291fn extract_string(value: &Value) -> Option<String> {
292    match value {
293        Value::String(s) => Some(s.clone()),
294        Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
295        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
296        _ => None,
297    }
298}
299
300async fn prepare_extra_args(extra_args: &[Value]) -> BuiltinResult<Vec<Value>> {
301    let mut host_args = Vec::with_capacity(extra_args.len());
302    for arg in extra_args {
303        host_args.push(gather_if_needed_async(arg).await?);
304    }
305    Ok(host_args)
306}
307
308fn parse_uniform_output(value: Value) -> BuiltinResult<bool> {
309    match value {
310        Value::Bool(b) => Ok(b),
311        Value::Num(n) => Ok(n != 0.0),
312        Value::Int(iv) => Ok(iv.to_f64() != 0.0),
313        Value::String(s) => parse_bool_string(&s).ok_or_else(|| {
314            cellfun_error_with_identifier(
315                "cellfun: UniformOutput must be logical true or false",
316                IDENT_UNIFORM_OUTPUT,
317            )
318        }),
319        Value::CharArray(ca) if ca.rows == 1 => {
320            let s: String = ca.data.iter().collect();
321            parse_bool_string(&s).ok_or_else(|| {
322                cellfun_error_with_identifier(
323                    "cellfun: UniformOutput must be logical true or false",
324                    IDENT_UNIFORM_OUTPUT,
325                )
326            })
327        }
328        other => Err(cellfun_error_with_identifier(
329            format!("cellfun: UniformOutput must be logical true or false, got {other:?}"),
330            IDENT_UNIFORM_OUTPUT,
331        )),
332    }
333}
334
/// Parses a textual boolean, ignoring surrounding whitespace and ASCII case.
///
/// `"true"`/`"on"` map to `Some(true)`, `"false"`/`"off"` to `Some(false)`,
/// and anything else to `None`.
fn parse_bool_string(value: &str) -> Option<bool> {
    let token = value.trim();
    let is_one_of = |words: &[&str]| words.iter().any(|word| token.eq_ignore_ascii_case(word));

    if is_one_of(&["true", "on"]) {
        Some(true)
    } else if is_one_of(&["false", "off"]) {
        Some(false)
    } else {
        None
    }
}
342
343fn make_error_struct(
344    raw_error: &RuntimeError,
345    linear_index: usize,
346    shape: &[usize],
347) -> BuiltinResult<Value> {
348    let (identifier, message) = error_identifier_and_message(raw_error);
349    let mut st = StructValue::new();
350    st.fields
351        .insert("identifier".to_string(), Value::String(identifier));
352    st.fields
353        .insert("message".to_string(), Value::String(message));
354    st.fields
355        .insert("index".to_string(), Value::Num((linear_index + 1) as f64));
356    let subs = linear_to_indices(linear_index, shape);
357    let subs_tensor =
358        dims_to_row_tensor(&subs).map_err(|e| cellfun_error(format!("cellfun: {e}")))?;
359    st.fields
360        .insert("indices".to_string(), Value::Tensor(subs_tensor));
361    Ok(Value::Struct(st))
362}
363
364fn error_identifier_and_message(error: &RuntimeError) -> (String, String) {
365    if let Some(identifier) = error.identifier() {
366        return (identifier.to_string(), error.message().to_string());
367    }
368    split_error_message(error.message())
369}
370
371fn split_error_message(raw: &str) -> (String, String) {
372    let trimmed = raw.trim();
373    let mut indices = trimmed.match_indices(':');
374    if let Some((_, _)) = indices.next() {
375        if let Some((second_idx, _)) = indices.next() {
376            let identifier = trimmed[..second_idx].trim().to_string();
377            let message = trimmed[second_idx + 1..].trim().to_string();
378            if !identifier.is_empty() && identifier.contains(':') {
379                return (
380                    identifier,
381                    if message.is_empty() {
382                        trimmed.to_string()
383                    } else {
384                        message
385                    },
386                );
387            }
388        } else if trimmed.len() >= 7
389            && (trimmed[..7].eq_ignore_ascii_case("matlab:")
390                || trimmed[..7].eq_ignore_ascii_case("runmat:"))
391        {
392            return (trimmed.to_string(), String::new());
393        }
394    }
395    (IDENT_FUNCTION_ERROR.to_string(), trimmed.to_string())
396}
397
/// Converts a zero-based linear `index` into one-based subscripts for `shape`,
/// with the first dimension varying fastest.
///
/// An empty `shape` yields `[1]`. Zero-length dimensions contribute a
/// subscript of `1` and leave the remaining index untouched.
fn linear_to_indices(mut index: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![1];
    }
    shape
        .iter()
        .map(|&dim| {
            if dim == 0 {
                1
            } else {
                let coord = index % dim + 1;
                index /= dim;
                coord
            }
        })
        .collect()
}
414
415#[derive(Clone)]
416enum Callable {
417    Builtin { name: String },
418    Closure(Closure),
419    Special(SpecialCallable),
420}
421
422impl Callable {
423    fn from_function(value: Value) -> BuiltinResult<Self> {
424        match value {
425            Value::String(s) => Self::from_text(&s, true),
426            Value::CharArray(ca) => {
427                if ca.rows != 1 {
428                    Err(cellfun_error_with_identifier(
429                        "cellfun: function name must be a character vector or string scalar",
430                        IDENT_INVALID_INPUT,
431                    ))
432                } else {
433                    let text: String = ca.data.iter().collect();
434                    Self::from_text(&text, true)
435                }
436            }
437            Value::StringArray(sa) => {
438                if sa.data.len() == 1 {
439                    Self::from_text(&sa.data[0], true)
440                } else {
441                    Err(cellfun_error_with_identifier(
442                        "cellfun: function name must be a character vector or string scalar",
443                        IDENT_INVALID_INPUT,
444                    ))
445                }
446            }
447            Value::FunctionHandle(name) => Self::from_text(&name, true),
448            Value::Closure(c) => Ok(Callable::Closure(c)),
449            other => Err(cellfun_error_with_identifier(
450                format!("cellfun: expected function handle or builtin name, got {other:?}"),
451                IDENT_INVALID_INPUT,
452            )),
453        }
454    }
455
456    fn from_text(text: &str, fold_case: bool) -> BuiltinResult<Self> {
457        let trimmed = text.trim();
458        if trimmed.is_empty() {
459            return Err(cellfun_error_with_identifier(
460                "cellfun: expected function handle or builtin name, got empty string",
461                IDENT_INVALID_INPUT,
462            ));
463        }
464        if let Some(rest) = trimmed.strip_prefix('@') {
465            let name = rest.trim();
466            if name.is_empty() {
467                Err(cellfun_error_with_identifier(
468                    "cellfun: empty function handle",
469                    IDENT_INVALID_INPUT,
470                ))
471            } else {
472                Ok(Callable::Builtin {
473                    name: name.to_string(),
474                })
475            }
476        } else {
477            let lowered = trimmed.to_ascii_lowercase();
478            if fold_case && lowered == "isclass" {
479                Ok(Callable::Special(SpecialCallable::IsClass))
480            } else if fold_case && lowered == "prodofsize" {
481                Ok(Callable::Special(SpecialCallable::ProdOfSize))
482            } else {
483                let name = if fold_case {
484                    lowered
485                } else {
486                    trimmed.to_string()
487                };
488                Ok(Callable::Builtin { name })
489            }
490        }
491    }
492
493    async fn call(&self, args: &[Value]) -> BuiltinResult<Value> {
494        fn is_undefined_function(err: &RuntimeError) -> bool {
495            let identifier = err.identifier().unwrap_or("").to_ascii_lowercase();
496            let message = err.message().to_ascii_lowercase();
497            identifier.contains("undefinedfunction") || message.contains("undefined function")
498        }
499        match self {
500            Callable::Builtin { name } => {
501                if let Some(result) = user_functions::try_call_user_function(name, args).await {
502                    match result {
503                        Ok(value) => return Ok(value),
504                        Err(err) => {
505                            if !is_undefined_function(&err) {
506                                return Err(err);
507                            }
508                        }
509                    }
510                }
511                call_builtin_async(name, args).await
512            }
513            Callable::Closure(c) => {
514                let mut captures = c.captures.clone();
515                captures.extend_from_slice(args);
516                if let Some(result) =
517                    user_functions::try_call_user_function(&c.function_name, &captures).await
518                {
519                    match result {
520                        Ok(value) => return Ok(value),
521                        Err(err) => {
522                            if !is_undefined_function(&err) {
523                                return Err(err);
524                            }
525                        }
526                    }
527                }
528                call_builtin_async(&c.function_name, &captures).await
529            }
530            Callable::Special(special) => special.call(args).await,
531        }
532    }
533}
534
535#[derive(Clone)]
536enum SpecialCallable {
537    ProdOfSize,
538    IsClass,
539}
540
541impl SpecialCallable {
542    async fn call(&self, args: &[Value]) -> BuiltinResult<Value> {
543        match self {
544            SpecialCallable::ProdOfSize => {
545                let value = args.first().ok_or_else(|| {
546                    cellfun_error_with_identifier(
547                        "cellfun: prodofsize requires one input",
548                        IDENT_INVALID_INPUT,
549                    )
550                })?;
551                Ok(Value::Num(value_numel(value).await? as f64))
552            }
553            SpecialCallable::IsClass => {
554                if args.len() < 2 {
555                    return Err(cellfun_error_with_identifier(
556                        "cellfun: 'isclass' requires a class name argument",
557                        IDENT_INVALID_INPUT,
558                    ));
559                }
560                let left = args[0].clone();
561                let class_name = extract_string(&args[1]).ok_or_else(|| {
562                    cellfun_error_with_identifier(
563                        "cellfun: class name must be a string scalar",
564                        IDENT_INVALID_INPUT,
565                    )
566                })?;
567                let class_value = call_builtin_async("class", &[left]).await?;
568                let class_str = extract_string(&class_value).ok_or_else(|| {
569                    cellfun_error_with_identifier(
570                        "cellfun: failed to evaluate class name",
571                        IDENT_FUNCTION_ERROR,
572                    )
573                })?;
574                Ok(Value::Bool(
575                    class_str.eq_ignore_ascii_case(class_name.trim()),
576                ))
577            }
578        }
579    }
580}
581
582enum UniformCollector {
583    Pending,
584    Double(Vec<f64>),
585    Logical(Vec<u8>),
586    Complex(Vec<(f64, f64)>),
587}
588
589impl UniformCollector {
590    fn push(&mut self, value: &Value) -> BuiltinResult<()> {
591        match self {
592            UniformCollector::Pending => match classify_value(value)? {
593                ClassifiedValue::Logical(b) => {
594                    *self = UniformCollector::Logical(vec![b as u8]);
595                    Ok(())
596                }
597                ClassifiedValue::Double(d) => {
598                    *self = UniformCollector::Double(vec![d]);
599                    Ok(())
600                }
601                ClassifiedValue::Complex(c) => {
602                    *self = UniformCollector::Complex(vec![c]);
603                    Ok(())
604                }
605            },
606            UniformCollector::Logical(bits) => match classify_value(value)? {
607                ClassifiedValue::Logical(b) => {
608                    bits.push(b as u8);
609                    Ok(())
610                }
611                ClassifiedValue::Double(d) => {
612                    let mut data: Vec<f64> = bits
613                        .iter()
614                        .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
615                        .collect();
616                    data.push(d);
617                    *self = UniformCollector::Double(data);
618                    Ok(())
619                }
620                ClassifiedValue::Complex(c) => {
621                    let mut data: Vec<(f64, f64)> = bits
622                        .iter()
623                        .map(|&bit| if bit != 0 { (1.0, 0.0) } else { (0.0, 0.0) })
624                        .collect();
625                    data.push(c);
626                    *self = UniformCollector::Complex(data);
627                    Ok(())
628                }
629            },
630            UniformCollector::Double(data) => match classify_value(value)? {
631                ClassifiedValue::Logical(b) => {
632                    data.push(if b { 1.0 } else { 0.0 });
633                    Ok(())
634                }
635                ClassifiedValue::Double(d) => {
636                    data.push(d);
637                    Ok(())
638                }
639                ClassifiedValue::Complex(c) => {
640                    let promoted: Vec<(f64, f64)> = data.iter().map(|&v| (v, 0.0)).collect();
641                    let mut complex = promoted;
642                    complex.push(c);
643                    *self = UniformCollector::Complex(complex);
644                    Ok(())
645                }
646            },
647            UniformCollector::Complex(data) => match classify_value(value)? {
648                ClassifiedValue::Logical(b) => {
649                    data.push((if b { 1.0 } else { 0.0 }, 0.0));
650                    Ok(())
651                }
652                ClassifiedValue::Double(d) => {
653                    data.push((d, 0.0));
654                    Ok(())
655                }
656                ClassifiedValue::Complex(c) => {
657                    data.push(c);
658                    Ok(())
659                }
660            },
661        }
662    }
663
664    fn finish(self, shape: &[usize]) -> BuiltinResult<Value> {
665        match self {
666            UniformCollector::Pending => {
667                let total = total_len(shape).unwrap_or(0);
668                let data = vec![0.0; total];
669                let tensor = Tensor::new(data, shape.to_vec())
670                    .map_err(|e| cellfun_error(format!("cellfun: {e}")))?;
671                Ok(Value::Tensor(tensor))
672            }
673            UniformCollector::Double(data) => {
674                let tensor = Tensor::new(data, shape.to_vec())
675                    .map_err(|e| cellfun_error(format!("cellfun: {e}")))?;
676                Ok(Value::Tensor(tensor))
677            }
678            UniformCollector::Logical(bits) => {
679                let logical = LogicalArray::new(bits, shape.to_vec())
680                    .map_err(|e| cellfun_error(format!("cellfun: {e}")))?;
681                Ok(Value::LogicalArray(logical))
682            }
683            UniformCollector::Complex(data) => {
684                let complex = ComplexTensor::new(data, shape.to_vec())
685                    .map_err(|e| cellfun_error(format!("cellfun: {e}")))?;
686                Ok(Value::ComplexTensor(complex))
687            }
688        }
689    }
690}
691
692enum ClassifiedValue {
693    Logical(bool),
694    Double(f64),
695    Complex((f64, f64)),
696}
697
698fn classify_value(value: &Value) -> BuiltinResult<ClassifiedValue> {
699    match value {
700        Value::Bool(b) => Ok(ClassifiedValue::Logical(*b)),
701        Value::Num(n) => Ok(ClassifiedValue::Double(*n)),
702        Value::Int(iv) => Ok(ClassifiedValue::Double(iv.to_f64())),
703        Value::Complex(re, im) => Ok(ClassifiedValue::Complex((*re, *im))),
704        Value::Tensor(t) if t.data.len() == 1 => Ok(ClassifiedValue::Double(t.data[0])),
705        Value::LogicalArray(la) if la.data.len() == 1 => {
706            Ok(ClassifiedValue::Logical(la.data[0] != 0))
707        }
708        Value::ComplexTensor(ct) if ct.data.len() == 1 => Ok(ClassifiedValue::Complex(ct.data[0])),
709        _ => Err(cellfun_error_with_identifier(
710            "cellfun: callback must return scalar values when 'UniformOutput' is true",
711            IDENT_UNIFORM_OUTPUT,
712        )),
713    }
714}
715
716#[cfg(test)]
717pub(crate) mod tests {
718    use super::*;
719    use crate::builtins::common::test_support;
720    use futures::executor::block_on;
721    use runmat_accelerate_api::HostTensorView;
722    use runmat_builtins::{IntValue, StringArray};
723    use std::convert::TryInto;
724
725    fn cellfun_builtin(func: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
726        block_on(super::cellfun_builtin(func, rest))
727    }
728
729    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
730    #[test]
731    fn cellfun_length_uniform_default() {
732        let cell = crate::make_cell(
733            vec![
734                Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap()),
735                Value::Tensor(Tensor::new(vec![4.0, 5.0, 6.0, 7.0], vec![1, 4]).unwrap()),
736                Value::Tensor(Tensor::new(vec![8.0, 9.0], vec![1, 2]).unwrap()),
737            ],
738            1,
739            3,
740        )
741        .expect("cell");
742        let result =
743            cellfun_builtin(Value::String("@length".into()), vec![cell]).expect("cellfun length");
744        match result {
745            Value::Tensor(t) => {
746                assert_eq!(t.shape, vec![1, 3]);
747                assert_eq!(t.data, vec![3.0, 4.0, 2.0]);
748            }
749            other => panic!("expected tensor, got {other:?}"),
750        }
751    }
752
753    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
754    #[test]
755    fn cellfun_multiple_cells_plus() {
756        let left = crate::make_cell(
757            vec![Value::Num(1.0), Value::Num(2.0), Value::Num(3.0)],
758            1,
759            3,
760        )
761        .expect("cell");
762        let right = crate::make_cell(
763            vec![Value::Num(4.0), Value::Num(5.0), Value::Num(6.0)],
764            1,
765            3,
766        )
767        .expect("cell");
768        let result = cellfun_builtin(Value::String("@__cellfun_add".into()), vec![left, right])
769            .expect("cellfun add");
770        match result {
771            Value::Tensor(t) => {
772                assert_eq!(t.data, vec![5.0, 7.0, 9.0]);
773            }
774            other => panic!("expected tensor, got {other:?}"),
775        }
776    }
777
778    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
779    #[test]
780    fn cellfun_uniform_false_returns_cells() {
781        let cell = crate::make_cell(
782            vec![
783                Value::String("Ada".into()),
784                Value::String("Linus".into()),
785                Value::String("Katherine".into()),
786            ],
787            1,
788            3,
789        )
790        .expect("cell");
791        let result = cellfun_builtin(
792            Value::String("@upper".into()),
793            vec![
794                cell,
795                Value::String("UniformOutput".into()),
796                Value::Bool(false),
797            ],
798        )
799        .expect("cellfun upper");
800        match result {
801            Value::Cell(ca) => {
802                assert_eq!(ca.shape, vec![1, 3]);
803                let upper_a = (*ca.data[0]).clone();
804                let upper_b = (*ca.data[1]).clone();
805                let upper_c = (*ca.data[2]).clone();
806                assert_eq!(extract_string(&upper_a).unwrap(), "ADA");
807                assert_eq!(extract_string(&upper_b).unwrap(), "LINUS");
808                assert_eq!(extract_string(&upper_c).unwrap(), "KATHERINE");
809            }
810            other => panic!("expected cell array, got {other:?}"),
811        }
812    }
813
814    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
815    #[test]
816    fn cellfun_error_handler_recovers() {
817        let cells = crate::make_cell(
818            vec![Value::Num(1.0), Value::Num(2.0), Value::Num(3.0)],
819            1,
820            3,
821        )
822        .expect("cell");
823        let handler = Value::Closure(Closure {
824            function_name: "__cellfun_test_handler".into(),
825            captures: vec![Value::Num(0.0)],
826        });
827        let result = cellfun_builtin(
828            Value::String("@nonexistent_builtin".into()),
829            vec![cells, Value::String("ErrorHandler".into()), handler],
830        )
831        .expect("cellfun error handler");
832        match result {
833            Value::Tensor(t) => {
834                assert_eq!(t.data, vec![0.0, 0.0, 0.0]);
835            }
836            other => panic!("expected tensor, got {other:?}"),
837        }
838    }
839
840    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
841    #[test]
842    fn cellfun_string_identifier() {
843        let cells = crate::make_cell(
844            vec![
845                Value::CharArray(runmat_builtins::CharArray::new_row("")),
846                Value::CharArray(runmat_builtins::CharArray::new_row("abc")),
847                Value::CharArray(runmat_builtins::CharArray::new_row("")),
848            ],
849            1,
850            3,
851        )
852        .expect("cell");
853        let result = cellfun_builtin(
854            Value::CharArray(runmat_builtins::CharArray::new_row("isempty")),
855            vec![cells],
856        )
857        .expect("isempty");
858        match result {
859            Value::LogicalArray(la) => {
860                assert_eq!(la.shape, vec![1, 3]);
861                assert_eq!(la.data, vec![1, 0, 1]);
862            }
863            other => panic!("expected logical array, got {other:?}"),
864        }
865    }
866
867    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
868    #[test]
869    fn cellfun_string_array_identifier() {
870        let cells = crate::make_cell(
871            vec![Value::CharArray(runmat_builtins::CharArray::new_row(""))],
872            1,
873            1,
874        )
875        .expect("cell");
876        let sa = StringArray::new(vec!["isempty".into()], vec![1, 1]).unwrap();
877        let result =
878            cellfun_builtin(Value::StringArray(sa), vec![cells]).expect("cellfun string array");
879        match result {
880            Value::LogicalArray(la) => {
881                assert_eq!(la.shape, vec![1, 1]);
882                assert_eq!(la.data, vec![1]);
883            }
884            other => panic!("expected logical array, got {other:?}"),
885        }
886    }
887
888    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
889    #[test]
890    fn cellfun_uniform_true_non_scalar_errors() {
891        let cells = crate::make_cell(
892            vec![Value::Tensor(
893                Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap(),
894            )],
895            1,
896            1,
897        )
898        .expect("cell");
899        let err = cellfun_builtin(Value::String("@eye".into()), vec![cells])
900            .unwrap_err()
901            .to_string();
902        assert!(
903            err.to_ascii_lowercase().contains("uniformoutput"),
904            "unexpected error: {err}"
905        );
906    }
907
908    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
909    #[test]
910    fn cellfun_uniform_promotes_logical_to_double() {
911        let cells = crate::make_cell(vec![Value::Bool(true), Value::Num(2.5)], 1, 2).unwrap();
912        let result = cellfun_builtin(Value::String("@__cellfun_identity".into()), vec![cells])
913            .expect("cellfun identity");
914        match result {
915            Value::Tensor(t) => {
916                assert_eq!(t.shape, vec![1, 2]);
917                assert_eq!(t.data, vec![1.0, 2.5]);
918            }
919            other => panic!("expected tensor, got {other:?}"),
920        }
921    }
922
923    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
924    #[test]
925    fn cellfun_uniform_promotes_double_to_complex() {
926        let cells =
927            crate::make_cell(vec![Value::Num(2.0), Value::Complex(0.0, 1.0)], 1, 2).unwrap();
928        let result = cellfun_builtin(Value::String("@__cellfun_identity".into()), vec![cells])
929            .expect("cellfun identity");
930        match result {
931            Value::ComplexTensor(ct) => {
932                assert_eq!(ct.shape, vec![1, 2]);
933                assert_eq!(ct.data, vec![(2.0, 0.0), (0.0, 1.0)]);
934            }
935            other => panic!("expected complex tensor, got {other:?}"),
936        }
937    }
938
939    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
940    #[test]
941    fn cellfun_errors_on_mismatched_cell_sizes() {
942        let first = crate::make_cell(vec![Value::Num(1.0), Value::Num(2.0)], 1, 2).unwrap();
943        let second = crate::make_cell(vec![Value::Num(3.0)], 1, 1).unwrap();
944        let err = cellfun_builtin(
945            Value::String("@__cellfun_identity".into()),
946            vec![first, second],
947        )
948        .unwrap_err()
949        .to_string();
950        assert!(
951            err.to_ascii_lowercase().contains("size"),
952            "expected size mismatch error, got: {err}"
953        );
954    }
955
956    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
957    #[test]
958    fn cellfun_uniformoutput_accepts_char_flags() {
959        let strings =
960            crate::make_cell(vec![Value::String("Ada".into())], 1, 1).expect("cell creation");
961        let result = cellfun_builtin(
962            Value::String("@upper".into()),
963            vec![
964                strings,
965                Value::CharArray(runmat_builtins::CharArray::new_row("UniformOutput")),
966                Value::CharArray(runmat_builtins::CharArray::new_row("off")),
967            ],
968        )
969        .expect("cellfun upper char flag");
970        assert!(
971            matches!(result, Value::Cell(_)),
972            "expected cell array result when UniformOutput is 'off'"
973        );
974    }
975
976    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
977    #[test]
978    fn cellfun_isclass_special_case() {
979        let ints = crate::make_cell(
980            vec![
981                Value::Int(IntValue::I32(5)),
982                Value::Num(std::f64::consts::PI),
983                Value::Int(IntValue::I16(2)),
984            ],
985            1,
986            3,
987        )
988        .expect("cell");
989        let result = cellfun_builtin(
990            Value::String("isclass".into()),
991            vec![ints, Value::String("int32".into())],
992        )
993        .expect("cellfun isclass");
994        match result {
995            Value::LogicalArray(la) => {
996                assert_eq!(la.data, vec![1, 0, 0]);
997            }
998            other => panic!("expected logical array, got {other:?}"),
999        }
1000    }
1001
1002    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1003    #[test]
1004    fn cellfun_passes_additional_arguments() {
1005        let matrices = crate::make_cell(
1006            vec![
1007                Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap()),
1008                Value::Tensor(Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap()),
1009            ],
1010            1,
1011            2,
1012        )
1013        .expect("cell");
1014        let dimension = Value::Num(2.0);
1015        let result = cellfun_builtin(Value::String("size".into()), vec![matrices, dimension])
1016            .expect("cellfun size");
1017        match result {
1018            Value::Tensor(t) => {
1019                assert_eq!(t.data, vec![2.0, 2.0]);
1020            }
1021            other => panic!("expected tensor, got {other:?}"),
1022        }
1023    }
1024
1025    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1026    #[test]
1027    fn cellfun_handles_string_array_uniform_false() {
1028        let sa = StringArray::new(vec!["foo".into(), "bar".into()], vec![1, 2]).unwrap();
1029        let cell = crate::make_cell(vec![Value::StringArray(sa)], 1, 1).unwrap();
1030        let result = cellfun_builtin(
1031            Value::String("@strlength".into()),
1032            vec![
1033                cell,
1034                Value::String("UniformOutput".into()),
1035                Value::Bool(false),
1036            ],
1037        )
1038        .unwrap();
1039        match result {
1040            Value::Cell(ca) => {
1041                assert_eq!(ca.shape, vec![1, 1]);
1042                let inner = (*ca.data[0]).clone();
1043                match inner {
1044                    Value::Tensor(t) => assert_eq!(t.data, vec![3.0, 3.0]),
1045                    _ => panic!("expected tensor inside cell"),
1046                }
1047            }
1048            other => panic!("expected cell, got {other:?}"),
1049        }
1050    }
1051
1052    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1053    #[test]
1054    fn cellfun_gathers_gpu_inputs() {
1055        test_support::with_test_provider(|provider| {
1056            let angle = std::f64::consts::PI / 6.0;
1057            let tensor = Tensor::new(vec![angle], vec![1, 1]).unwrap();
1058            let view = HostTensorView {
1059                data: &tensor.data,
1060                shape: &tensor.shape,
1061            };
1062            let handle = provider.upload(&view).expect("upload");
1063            let cell = crate::make_cell(vec![Value::GpuTensor(handle)], 1, 1).expect("cell");
1064            let result =
1065                cellfun_builtin(Value::String("@sin".into()), vec![cell]).expect("cellfun sin");
1066            let gathered = test_support::gather(result).expect("gather");
1067            assert_eq!(gathered.shape, vec![1, 1]);
1068            let expected = angle.sin();
1069            assert!((gathered.data[0] - expected).abs() < 1e-12);
1070        });
1071    }
1072
1073    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1074    #[test]
1075    #[cfg(feature = "wgpu")]
1076    fn cellfun_with_wgpu_provider_handles_gpu_cells() {
1077        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1078            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1079        );
1080        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1081
1082        let value = Tensor::new(vec![0.25], vec![1, 1]).unwrap();
1083        let view = HostTensorView {
1084            data: &value.data,
1085            shape: &value.shape,
1086        };
1087        let handle = provider.upload(&view).expect("upload");
1088        let cell = crate::make_cell(vec![Value::GpuTensor(handle)], 1, 1).expect("cell");
1089
1090        let result =
1091            cellfun_builtin(Value::String("@sin".into()), vec![cell]).expect("cellfun sin");
1092        let gathered = test_support::gather(result).expect("gather");
1093        assert_eq!(gathered.shape, vec![1, 1]);
1094        let expected = value.data[0].sin();
1095        assert!((gathered.data[0] - expected).abs() < 1e-12);
1096    }
1097
1098    #[runmat_macros::runtime_builtin(
1099        name = "__cellfun_test_handler",
1100        type_resolver(cellfun_type),
1101        builtin_path = "crate::builtins::cells::core::cellfun::tests"
1102    )]
1103    fn cellfun_test_handler(
1104        seed: Value,
1105        _err: Value,
1106        rest: Vec<Value>,
1107    ) -> crate::BuiltinResult<Value> {
1108        // Return the captured seed regardless of the inputs; ensure rest is present for coverage.
1109        let _ = rest;
1110        Ok(seed)
1111    }
1112
1113    #[runmat_macros::runtime_builtin(
1114        name = "__cellfun_add",
1115        type_resolver(cellfun_type),
1116        builtin_path = "crate::builtins::cells::core::cellfun::tests"
1117    )]
1118    fn cellfun_add(lhs: Value, rhs: Value) -> crate::BuiltinResult<Value> {
1119        let a: f64 = (&lhs).try_into()?;
1120        let b: f64 = (&rhs).try_into()?;
1121        Ok(Value::Num(a + b))
1122    }
1123
1124    #[runmat_macros::runtime_builtin(
1125        name = "__cellfun_identity",
1126        type_resolver(cellfun_type),
1127        builtin_path = "crate::builtins::cells::core::cellfun::tests"
1128    )]
1129    fn cellfun_identity(value: Value) -> crate::BuiltinResult<Value> {
1130        Ok(value)
1131    }
1132}