
runmat_runtime/builtins/acceleration/gpu/pagefun.rs

//! MATLAB-compatible `pagefun` builtin.
//!
//! The `pagefun` builtin applies a MATLAB operator to every 2-D page across the
//! trailing dimensions of the supplied inputs. This mirrors MathWorks MATLAB
//! semantics for GPU arrays while retaining host fallbacks when GPU providers
//! do not expose specialised kernels.
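//!
//! For the currently supported `@mtimes` operator, for example,
//! `C = pagefun(@mtimes, A, B)` computes `C(:,:,k) = A(:,:,k) * B(:,:,k)` for each
//! page `k`, expanding singleton page dimensions MATLAB-style when the trailing
//! dimensions of the inputs differ.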

use crate::builtins::acceleration::gpu::type_resolvers::pagefun_type;
use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
use runmat_accelerate_api::{GpuTensorHandle, HostTensorView, PagefunOp, PagefunRequest};
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;

type ComplexMatrixData = (Vec<(f64, f64)>, usize, usize);

#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::pagefun")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "pagefun",
    op_kind: GpuOpKind::Custom("pagefun"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("pagefun")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "WGPU provider accelerates batched @mtimes; runtimes gather to host when no provider hook is available.",
};

#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::acceleration::gpu::pagefun")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "pagefun",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a fusion barrier because pagefun can invoke arbitrary MATLAB operators.",
};

fn pagefun_error(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message).with_builtin("pagefun").build()
}

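// Evaluation strategy: try the provider's `pagefun` hook first; otherwise gather
// the operands to the host, broadcast the trailing page dimensions, evaluate the
// operator one page at a time, and reassemble the pages into the final N-D result.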
#[runtime_builtin(
    name = "pagefun",
    category = "acceleration/gpu",
    summary = "Apply MATLAB operators page-by-page across higher-dimensional arrays.",
    keywords = "pagefun,gpuArray,mtimes,pages,batch",
    accel = "custom",
    type_resolver(pagefun_type),
    builtin_path = "crate::builtins::acceleration::gpu::pagefun"
)]
async fn pagefun_builtin(
    func: Value,
    first: Value,
    rest: Vec<Value>,
) -> crate::BuiltinResult<Value> {
    let operation = PageOperation::from_callable(func)?;
    let mut operands = Vec::with_capacity(rest.len() + 1);
    operands.push(first);
    operands.extend(rest);
    if operands.is_empty() {
        return Err(pagefun_error("pagefun: requires at least one array input"));
    }

    operation.validate_arity(operands.len())?;

    if let Some(value) = try_pagefun_gpu(&operation, &operands)? {
        return Ok(value);
    }

    let all_gpu = operands.iter().all(|v| matches!(v, Value::GpuTensor(_)));
    let mut host_values = Vec::with_capacity(operands.len());
    for value in operands {
        host_values.push(gather_if_needed_async(&value).await?);
    }

    let mut page_inputs = Vec::with_capacity(host_values.len());
    for value in host_values {
        page_inputs.push(PageInput::from_value(value)?);
    }

    let rank = page_inputs
        .iter()
        .map(|input| input.page_dims.len())
        .max()
        .unwrap_or(0);

    let mut result_page_dims = if rank == 0 {
        Vec::new()
    } else {
        vec![1usize; rank]
    };

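    // MATLAB-style implicit expansion over page dimensions: a size of 1
    // broadcasts against any other size, a size of 0 produces an empty result,
    // and any other mismatch is an error.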
    for dim in 0..rank {
        let mut target = 1usize;
        for input in &page_inputs {
            let size = input.page_dims.get(dim).copied().unwrap_or(1);
            if size == 0 {
                target = 0;
                break;
            }
            if size != 1 {
                if target == 1 {
                    target = size;
                } else if target != size {
                    return Err(pagefun_error(format!(
                        "pagefun: page dimension {} mismatch ({} vs {})",
                        dim + 3,
                        target,
                        size
                    )));
                }
            }
        }
        if !result_page_dims.is_empty() {
            result_page_dims[dim] = target;
        }
    }

    let page_volume = if rank == 0 {
        1usize
    } else {
        result_page_dims.iter().copied().product()
    };

    let mut prepared_inputs = Vec::with_capacity(page_inputs.len());
    for input in page_inputs {
        prepared_inputs.push(PreparedInput::new(input, rank));
    }

    operation.validate_shapes(&prepared_inputs)?;
    let output_kind = operation.output_kind(&prepared_inputs);
    let (mut result_rows, mut result_cols) =
        operation.output_matrix_shape(&prepared_inputs, output_kind)?;

    if page_volume == 0 {
        return finalise_empty_output(
            &operation,
            &prepared_inputs,
            &result_page_dims,
            output_kind,
            all_gpu,
        );
    }

    let mut real_data: Option<Vec<f64>> = None;
    let mut complex_data: Option<Vec<(f64, f64)>> = None;
    let mut multi_index = vec![0usize; rank];

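    // Walk every output page in column-major order, slice the matching (or
    // broadcast) page out of each input, apply the operator, and append the
    // page to the flat output buffer.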
    let mut page_counter = 0usize;
    loop {
        let mut page_args = Vec::with_capacity(prepared_inputs.len());
        for input in &prepared_inputs {
            page_args.push(input.page_value(&multi_index)?);
        }

        let mut evaluated = operation.evaluate(&page_args).await?;
        evaluated = gather_if_needed_async(&evaluated).await?;
        match output_kind {
            OutputKind::Real => {
                let (data, rows, cols) = tensor_matrix_data(evaluated)?;
                if real_data.is_none() {
                    result_rows = rows;
                    result_cols = cols;
                    real_data = Some(Vec::with_capacity(rows * cols * page_volume));
                } else if rows != result_rows || cols != result_cols {
                    return Err(pagefun_error(
                        "pagefun: result matrices must be the same size",
                    ));
                }
                if let Some(vec) = real_data.as_mut() {
                    vec.extend(data);
                }
            }
            OutputKind::Complex => {
                let (data, rows, cols) = complex_matrix_data(evaluated)?;
                if complex_data.is_none() {
                    result_rows = rows;
                    result_cols = cols;
                    complex_data = Some(Vec::with_capacity(rows * cols * page_volume));
                } else if rows != result_rows || cols != result_cols {
                    return Err(pagefun_error(
                        "pagefun: result matrices must be the same size",
                    ));
                }
                if let Some(vec) = complex_data.as_mut() {
                    vec.extend(data);
                }
            }
        }

        page_counter += 1;
        if page_counter == page_volume {
            break;
        }
        increment_multi_index(&mut multi_index, &result_page_dims)?;
    }

    let final_shape = assemble_shape(result_rows, result_cols, &result_page_dims);
    let output = match output_kind {
        OutputKind::Real => {
            let data = real_data.unwrap_or_default();
            let tensor = Tensor::new(data, final_shape).map_err(|e| {
                pagefun_error(format!("pagefun: failed to construct result tensor ({e})"))
            })?;
            FinalOutput::Real(tensor)
        }
        OutputKind::Complex => {
            let data = complex_data.unwrap_or_default();
            let tensor = ComplexTensor::new(data, final_shape).map_err(|e| {
                pagefun_error(format!(
                    "pagefun: failed to construct complex result tensor ({e})"
                ))
            })?;
            FinalOutput::Complex(tensor)
        }
    };

    output.into_value(all_gpu)
}

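/// GPU fast path: when every operand is already a `gpuArray` and the active
/// provider exposes a `pagefun` hook, build a batched request and let the
/// provider execute it. Provider errors are logged and fall back to the host path.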
fn try_pagefun_gpu(operation: &PageOperation, operands: &[Value]) -> BuiltinResult<Option<Value>> {
    if operands.is_empty() {
        return Ok(None);
    }
    if !operands
        .iter()
        .all(|value| matches!(value, Value::GpuTensor(_)))
    {
        return Ok(None);
    }

    #[cfg(all(test, feature = "wgpu"))]
    {
        // Reassert WGPU provider only when operands are WGPU handles (device_id != 0).
        if operands
            .iter()
            .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
        {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let Some(provider) = runmat_accelerate_api::provider() else {
        return Ok(None);
    };

    let handles: Vec<GpuTensorHandle> = operands
        .iter()
        .map(|value| match value {
            Value::GpuTensor(handle) => handle.clone(),
            _ => unreachable!(),
        })
        .collect();

    let request = match build_pagefun_request(operation, &handles)? {
        Some(request) => request,
        None => return Ok(None),
    };

    match provider.pagefun(&request) {
        Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
        Err(err) => {
            log::debug!("pagefun: provider hook unavailable, falling back to host: {err}");
            Ok(None)
        }
    }
}

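/// Translate validated GPU operands into a provider-facing `PagefunRequest`,
/// including the broadcast page dimensions. Only `@mtimes` is currently mapped.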
fn build_pagefun_request(
    operation: &PageOperation,
    handles: &[GpuTensorHandle],
) -> BuiltinResult<Option<PagefunRequest>> {
    match operation {
        PageOperation::Mtimes => {
            if handles.len() != 2 {
                return Err(pagefun_error(
                    "pagefun: @mtimes requires exactly two array inputs",
                ));
            }

            let (lhs_rows, lhs_cols, lhs_pages) = handle_matrix_meta(&handles[0])?;
            let (rhs_rows, rhs_cols, rhs_pages) = handle_matrix_meta(&handles[1])?;
            if lhs_cols != rhs_rows {
                return Err(pagefun_error(format!(
                    "pagefun: inner matrix dimensions must agree ({}x{} * {}x{})",
                    lhs_rows, lhs_cols, rhs_rows, rhs_cols
                )));
            }

            let rank = lhs_pages.len().max(rhs_pages.len());
            let mut result_page_dims = if rank == 0 {
                Vec::new()
            } else {
                vec![1usize; rank]
            };

            for dim in 0..rank {
                let mut target = 1usize;
                let dims_to_check = [
                    lhs_pages.get(dim).copied().unwrap_or(1),
                    rhs_pages.get(dim).copied().unwrap_or(1),
                ];
                for size in dims_to_check {
                    if size == 0 {
                        target = 0;
                        break;
                    }
                    if size != 1 {
                        if target == 1 {
                            target = size;
                        } else if target != size {
                            return Err(pagefun_error(format!(
                                "pagefun: page dimension {} mismatch ({} vs {})",
                                dim + 3,
                                target,
                                size
                            )));
                        }
                    }
                }
                if !result_page_dims.is_empty() {
                    result_page_dims[dim] = target;
                }
            }

            let mut input_page_dims = Vec::with_capacity(2);
            let mut lhs_padded = lhs_pages.clone();
            lhs_padded.resize(rank, 1);
            let mut rhs_padded = rhs_pages.clone();
            rhs_padded.resize(rank, 1);
            input_page_dims.push(lhs_padded);
            input_page_dims.push(rhs_padded);

            let mut output_shape = vec![lhs_rows, rhs_cols];
            output_shape.extend_from_slice(&result_page_dims);

            Ok(Some(PagefunRequest {
                op: PagefunOp::Mtimes,
                inputs: handles.to_vec(),
                output_shape,
                page_dims: result_page_dims,
                input_page_dims,
            }))
        }
    }
}

fn handle_matrix_meta(handle: &GpuTensorHandle) -> BuiltinResult<(usize, usize, Vec<usize>)> {
    let canonical = canonical_matrix_shape(&handle.shape);
    if canonical.len() < 2 {
        return Err(pagefun_error("pagefun: gpu tensor must be at least 2-D"));
    }
    let rows = canonical[0];
    let cols = canonical[1];
    let pages = if canonical.len() > 2 {
        canonical[2..].to_vec()
    } else {
        Vec::new()
    };
    Ok((rows, cols, pages))
}

fn finalise_empty_output(
    operation: &PageOperation,
    inputs: &[PreparedInput],
    page_dims: &[usize],
    output_kind: OutputKind,
    wants_gpu: bool,
) -> BuiltinResult<Value> {
    let (rows, cols) = operation.output_matrix_shape(inputs, output_kind)?;
    let final_shape = assemble_shape(rows, cols, page_dims);
    let page_factor: usize = if page_dims.is_empty() {
        1
    } else {
        page_dims.iter().copied().product()
    };
    let entries = rows
        .checked_mul(cols)
        .unwrap_or(0)
        .checked_mul(page_factor)
        .unwrap_or(0);
    match output_kind {
        OutputKind::Real => {
            let tensor = Tensor::new(vec![0.0; entries], final_shape).map_err(|e| {
                pagefun_error(format!("pagefun: failed to build empty tensor ({e})"))
            })?;
            FinalOutput::Real(tensor).into_value(wants_gpu)
        }
        OutputKind::Complex => {
            let tensor =
                ComplexTensor::new(vec![(0.0, 0.0); entries], final_shape).map_err(|e| {
                    pagefun_error(format!(
                        "pagefun: failed to build empty complex tensor ({e})"
                    ))
                })?;
            FinalOutput::Complex(tensor).into_value(false)
        }
    }
}

fn assemble_shape(rows: usize, cols: usize, page_dims: &[usize]) -> Vec<usize> {
    let mut shape = vec![rows, cols];
    shape.extend_from_slice(page_dims);
    shape
}

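/// Advance `indices` one step through the page grid described by `dims`,
/// odometer-style: bump the fastest-varying dimension first and carry into
/// the next dimension on overflow.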
fn increment_multi_index(indices: &mut [usize], dims: &[usize]) -> BuiltinResult<()> {
    if dims.contains(&0) {
        return Ok(());
    }
    for (dim, &limit) in dims.iter().enumerate() {
        if limit == 0 {
            continue;
        }
        indices[dim] += 1;
        if indices[dim] < limit {
            return Ok(());
        }
        indices[dim] = 0;
        if dim + 1 == dims.len() {
            break;
        }
    }
    Ok(())
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OutputKind {
    Real,
    Complex,
}

enum FinalOutput {
    Real(Tensor),
    Complex(ComplexTensor),
}

impl FinalOutput {
    fn into_value(self, wants_gpu: bool) -> BuiltinResult<Value> {
        match self {
            FinalOutput::Real(tensor) => {
                if wants_gpu {
                    #[cfg(all(test, feature = "wgpu"))]
                    {
                        if runmat_accelerate_api::provider().is_none() {
                            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                            );
                        }
                    }
                    if let Some(provider) = runmat_accelerate_api::provider() {
                        let view = HostTensorView {
                            data: &tensor.data,
                            shape: &tensor.shape,
                        };
                        if let Ok(handle) = provider.upload(&view) {
                            return Ok(Value::GpuTensor(handle));
                        }
                    }
                }
                Ok(Value::Tensor(tensor))
            }
            FinalOutput::Complex(tensor) => Ok(Value::ComplexTensor(tensor)),
        }
    }
}

#[derive(Clone)]
struct PageInput {
    page_dims: Vec<usize>,
    rows: usize,
    cols: usize,
    data: PageData,
}

#[derive(Clone)]
enum PageData {
    Real(Vec<f64>),
    Complex(Vec<(f64, f64)>),
}

impl PageInput {
    fn from_value(value: Value) -> BuiltinResult<Self> {
        match value {
            Value::Tensor(t) => Self::from_tensor(t),
            Value::Num(n) => Self::from_tensor(
                Tensor::new(vec![n], vec![1, 1])
                    .map_err(|e| pagefun_error(format!("pagefun: {e}")))?,
            ),
            Value::Int(i) => Self::from_tensor(
                Tensor::new(vec![i.to_f64()], vec![1, 1])
                    .map_err(|e| pagefun_error(format!("pagefun: {e}")))?,
            ),
            Value::Bool(flag) => Self::from_tensor(
                Tensor::new(vec![if flag { 1.0 } else { 0.0 }], vec![1, 1])
                    .map_err(|e| pagefun_error(format!("pagefun: {e}")))?,
            ),
            Value::Complex(re, im) => {
                let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
                    .map_err(|e| pagefun_error(format!("pagefun: {e}")))?;
                Self::from_complex_tensor(tensor)
            }
            Value::ComplexTensor(t) => Self::from_complex_tensor(t),
            other => Err(pagefun_error(format!(
                "pagefun: unsupported input type {}",
                other.type_name()
            ))),
        }
    }

    fn from_tensor(tensor: Tensor) -> BuiltinResult<Self> {
        let shape = canonical_matrix_shape(&tensor.shape);
        if tensor.data.len() != shape.iter().copied().product::<usize>() {
            return Err(pagefun_error(
                "pagefun: tensor data does not match its shape",
            ));
        }
        let rows = shape[0];
        let cols = shape[1];
        let page_dims = if shape.len() > 2 {
            shape[2..].to_vec()
        } else {
            Vec::new()
        };
        Ok(Self {
            page_dims,
            rows,
            cols,
            data: PageData::Real(tensor.data),
        })
    }

    fn from_complex_tensor(tensor: ComplexTensor) -> BuiltinResult<Self> {
        let shape = canonical_matrix_shape(&tensor.shape);
        if tensor.data.len() != shape.iter().copied().product::<usize>() {
            return Err(pagefun_error(
                "pagefun: tensor data does not match its shape",
            ));
        }
        let rows = shape[0];
        let cols = shape[1];
        let page_dims = if shape.len() > 2 {
            shape[2..].to_vec()
        } else {
            Vec::new()
        };
        Ok(Self {
            page_dims,
            rows,
            cols,
            data: PageData::Complex(tensor.data),
        })
    }

    fn page_size(&self) -> usize {
        self.rows * self.cols
    }

    fn is_complex(&self) -> bool {
        matches!(self.data, PageData::Complex(_))
    }
}

struct PreparedInput {
    data: PageInput,
    padded_dims: Vec<usize>,
    strides: Vec<usize>,
}

impl PreparedInput {
    fn new(input: PageInput, rank: usize) -> Self {
        let mut padded = input.page_dims.clone();
        padded.resize(rank, 1);
        let strides = compute_strides(&padded);
        Self {
            data: input,
            padded_dims: padded,
            strides,
        }
    }

    fn rows(&self) -> usize {
        self.data.rows
    }

    fn cols(&self) -> usize {
        self.data.cols
    }

    fn is_complex(&self) -> bool {
        self.data.is_complex()
    }

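    /// Slice the page addressed by `multi_index` out of this input, reusing
    /// index 0 for broadcast (extent-1) page dimensions, and wrap it as a 2-D value.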
    fn page_value(&self, multi_index: &[usize]) -> BuiltinResult<Value> {
        let mut linear_page = 0usize;
        for (dim, stride) in self.strides.iter().enumerate() {
            let source_extent = self.padded_dims.get(dim).copied().unwrap_or(1);
            let requested = multi_index.get(dim).copied().unwrap_or(0);
            if source_extent == 0 {
                return Err(pagefun_error("pagefun: source page extent is zero"));
            }
            if source_extent != 1 && requested >= source_extent {
                return Err(pagefun_error("pagefun: page index out of bounds"));
            }
            let actual = if source_extent == 1 { 0 } else { requested };
            linear_page += actual * stride;
        }

        let offset = linear_page * self.data.page_size();
        match &self.data.data {
            PageData::Real(buffer) => {
                let end = offset + self.data.page_size();
                let slice = buffer
                    .get(offset..end)
                    .ok_or_else(|| pagefun_error("pagefun: page slice out of bounds"))?;
                let tensor = Tensor::new(slice.to_vec(), vec![self.data.rows, self.data.cols])
                    .map_err(|e| pagefun_error(format!("pagefun: {e}")))?;
                Ok(Value::Tensor(tensor))
            }
            PageData::Complex(buffer) => {
                let end = offset + self.data.page_size();
                let slice = buffer
                    .get(offset..end)
                    .ok_or_else(|| pagefun_error("pagefun: page slice out of bounds"))?;
                let tensor =
                    ComplexTensor::new(slice.to_vec(), vec![self.data.rows, self.data.cols])
                        .map_err(|e| pagefun_error(format!("pagefun: {e}")))?;
                Ok(Value::ComplexTensor(tensor))
            }
        }
    }
}

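/// Column-major strides for a page grid, e.g. dims `[2, 3]` yield strides `[1, 2]`.
/// Zero-sized dimensions are clamped to 1 so the strides stay usable.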
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(dims.len());
    let mut stride = 1usize;
    for &dim in dims {
        strides.push(stride);
        stride = stride.saturating_mul(dim.max(1));
    }
    strides
}

fn tensor_matrix_data(value: Value) -> BuiltinResult<(Vec<f64>, usize, usize)> {
    match value {
        Value::Tensor(t) => {
            if t.shape.len() > 2 {
                return Err(pagefun_error(
                    "pagefun: operator returned an array with more than two dimensions",
                ));
            }
            let canonical = canonical_matrix_shape(&t.shape);
            let rows = canonical[0];
            let cols = canonical[1];
            if rows * cols != t.data.len() {
                return Err(pagefun_error("pagefun: result size mismatch"));
            }
            Ok((t.data, rows, cols))
        }
        Value::Num(n) => Ok((vec![n], 1, 1)),
        Value::Int(i) => Ok((vec![i.to_f64()], 1, 1)),
        other => Err(pagefun_error(format!(
            "pagefun: expected numeric matrix result, received {}",
            other.type_name()
        ))),
    }
}

fn complex_matrix_data(value: Value) -> BuiltinResult<ComplexMatrixData> {
    match value {
        Value::ComplexTensor(t) => {
            if t.shape.len() > 2 {
                return Err(pagefun_error(
                    "pagefun: operator returned an array with more than two dimensions",
                ));
            }
            let canonical = canonical_matrix_shape(&t.shape);
            let rows = canonical[0];
            let cols = canonical[1];
            if rows * cols != t.data.len() {
                return Err(pagefun_error("pagefun: result size mismatch"));
            }
            Ok((t.data, rows, cols))
        }
        Value::Complex(re, im) => Ok((vec![(re, im)], 1, 1)),
        other => Err(pagefun_error(format!(
            "pagefun: expected complex matrix result, received {}",
            other.type_name()
        ))),
    }
}

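/// Normalise a shape to at least two dimensions: scalars become `1x1` and
/// one-dimensional shapes become `1xN` row vectors.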
fn canonical_matrix_shape(shape: &[usize]) -> Vec<usize> {
    match shape.len() {
        0 => vec![1, 1],
        1 => vec![1, shape[0]],
        _ => {
            let mut out = shape.to_vec();
            if out.len() == 1 {
                out.push(1);
            }
            out
        }
    }
}

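/// Operators that `pagefun` knows how to apply. Only `@mtimes` is implemented
/// today; `from_callable` rejects anything else with a descriptive error.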
#[derive(Clone, Copy)]
enum PageOperation {
    Mtimes,
}

impl PageOperation {
    fn from_callable(value: Value) -> BuiltinResult<Self> {
        let raw = match value {
            Value::FunctionHandle(func) => func,
            Value::String(s) => s,
            Value::StringArray(sa) => {
                if sa.data.len() != 1 {
                    return Err(pagefun_error(
                        "pagefun: function string array must contain exactly one element",
                    ));
                }
                sa.data[0].clone()
            }
            Value::CharArray(chars) => {
                if chars.rows != 1 {
                    return Err(pagefun_error(
                        "pagefun: function char array must be a single row character vector",
                    ));
                }
                chars.data.iter().collect()
            }
            other => {
                return Err(pagefun_error(format!(
                    "pagefun: unsupported function handle type {}",
                    other.type_name()
                )))
            }
        };
        let trimmed = raw.trim();
        let lowered = trimmed.trim_start_matches('@').to_ascii_lowercase();
        match lowered.as_str() {
            "mtimes" => Ok(Self::Mtimes),
            _ => Err(pagefun_error(format!(
                "pagefun: unsupported function '{}'; currently only @mtimes is implemented",
                trimmed
            ))),
        }
    }

    fn validate_arity(&self, arg_count: usize) -> BuiltinResult<()> {
        match self {
            Self::Mtimes => {
                if arg_count != 2 {
                    return Err(pagefun_error(
                        "pagefun: @mtimes requires exactly two array inputs",
                    ));
                }
                Ok(())
            }
        }
    }

    fn validate_shapes(&self, inputs: &[PreparedInput]) -> BuiltinResult<()> {
        match self {
            Self::Mtimes => {
                let lhs = &inputs[0];
                let rhs = &inputs[1];
                if lhs.cols() != rhs.rows() {
                    return Err(pagefun_error(format!(
                        "pagefun: inner matrix dimensions must agree ({}x{} * {}x{})",
                        lhs.rows(),
                        lhs.cols(),
                        rhs.rows(),
                        rhs.cols()
                    )));
                }
                Ok(())
            }
        }
    }

    async fn evaluate(&self, args: &[Value]) -> crate::BuiltinResult<Value> {
        match self {
            Self::Mtimes => crate::call_builtin_async("mtimes", args).await,
        }
    }

    fn output_kind(&self, inputs: &[PreparedInput]) -> OutputKind {
        match self {
            Self::Mtimes => {
                if inputs.iter().any(|input| input.is_complex()) {
                    OutputKind::Complex
                } else {
                    OutputKind::Real
                }
            }
        }
    }

    fn output_matrix_shape(
        &self,
        inputs: &[PreparedInput],
        kind: OutputKind,
    ) -> BuiltinResult<(usize, usize)> {
        match self {
            Self::Mtimes => {
                let lhs = &inputs[0];
                let rhs = &inputs[1];
                let rows = lhs.rows();
                let cols = rhs.cols();
                match kind {
                    OutputKind::Real | OutputKind::Complex => Ok((rows, cols)),
                }
            }
        }
    }
}

trait TypeName {
    fn type_name(&self) -> &'static str;
}

impl TypeName for Value {
    fn type_name(&self) -> &'static str {
        match self {
            Value::Int(_) => "int",
            Value::Num(_) => "double",
            Value::Complex(_, _) => "complex double",
            Value::Bool(_) => "logical",
            Value::LogicalArray(_) => "logical array",
            Value::String(_) => "string",
            Value::StringArray(_) => "string array",
            Value::CharArray(_) => "char array",
            Value::Tensor(_) => "double array",
            Value::ComplexTensor(_) => "complex double array",
            Value::Cell(_) => "cell array",
            Value::Struct(_) => "struct",
            Value::GpuTensor(_) => "gpuArray",
            Value::Object(_) => "object",
            Value::HandleObject(_) => "handle object",
            Value::Listener(_) => "listener",
            Value::FunctionHandle(_) => "function handle",
            Value::Closure(_) => "closure",
            Value::ClassRef(_) => "class reference",
            Value::MException(_) => "MException",
            Value::OutputList(_) => "output list",
        }
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{CharArray, ResolveContext, StringArray, Type};

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_single_page() {
        let lhs = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![5.0, 7.0, 6.0, 8.0], vec![2, 2]).unwrap();
        let result = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                assert_eq!(t.data, vec![19.0, 43.0, 22.0, 50.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_multiple_pages() {
        let lhs = Tensor::new(vec![1.0, 3.0, 2.0, 4.0, 2.0, 1.0, 0.0, 3.0], vec![2, 2, 2]).unwrap();
        let rhs = Tensor::new(vec![5.0, 7.0, 6.0, 8.0, 1.0, 0.0, 2.0, 1.0], vec![2, 2, 2]).unwrap();
        let result = pagefun_builtin(
            Value::from("@mtimes"),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2, 2]);
                assert_eq!(t.data, vec![19.0, 43.0, 22.0, 50.0, 2.0, 1.0, 4.0, 5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_broadcast_rhs() {
        let lhs = Tensor::new(vec![1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0], vec![2, 2, 2]).unwrap();
        let rhs = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let result = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2, 2]);
                assert_eq!(
                    t.data,
                    vec![1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0],
                    "broadcasted identity should preserve pages"
                );
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_empty_pages() {
        let lhs = Tensor::new(Vec::new(), vec![2, 2, 0]).unwrap();
        let rhs = Tensor::new(Vec::new(), vec![2, 2, 0]).unwrap();
        let result = pagefun_builtin(
            Value::from("@mtimes"),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2, 0]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_char_array_handle() {
        let lhs = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![5.0, 7.0, 6.0, 8.0], vec![2, 2]).unwrap();
        let func = CharArray::new("@mtimes".chars().collect(), 1, 7).unwrap();
        let result = pagefun_builtin(
            Value::CharArray(func),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun char array");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                assert_eq!(t.data, vec![19.0, 43.0, 22.0, 50.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_string_array_handle() {
        let lhs = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![5.0, 7.0, 6.0, 8.0], vec![2, 2]).unwrap();
        let strings = StringArray::new(vec!["@mtimes".to_string()], vec![1]).unwrap();
        let result = pagefun_builtin(
            Value::StringArray(strings),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun string array");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                assert_eq!(t.data, vec![19.0, 43.0, 22.0, 50.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_char_array_multirow_error() {
        let chars = CharArray::new("@mtimes@".chars().collect(), 2, 4).unwrap();
        let lhs = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let rhs = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = pagefun_builtin(
            Value::CharArray(chars),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let err = block_on(err).expect_err("expected multi-row char array error");
        assert!(
            err.contains("char array"),
            "unexpected error for multi-row char array: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_string_array_multi_value_error() {
        let strings =
            StringArray::new(vec!["@mtimes".to_string(), "@mtimes".to_string()], vec![2]).unwrap();
        let lhs = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let rhs = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = pagefun_builtin(
            Value::StringArray(strings),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let err = block_on(err).expect_err("expected multi-element string array error");
        assert!(
            err.contains("string array"),
            "unexpected error for string array: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_page_dimension_mismatch() {
        let lhs = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();
        let rhs = Tensor::new(
            vec![
                1.0, 5.0, 2.0, 6.0, 3.0, 7.0, 4.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            vec![2, 2, 3],
        )
        .unwrap();
        let err = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let err = block_on(err).expect_err("expected page dimension mismatch");
        assert!(
            err.contains("page dimension"),
            "unexpected mismatch error message: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_dim_mismatch() {
        let lhs = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let err = block_on(err).expect_err("expected dimension mismatch");
        assert!(
            err.contains("inner matrix dimensions"),
            "unexpected error message {err}"
        );
    }

    #[test]
    fn pagefun_type_is_tensor() {
        assert_eq!(
            pagefun_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_gpu_roundtrip_mtimes() {
        test_support::with_test_provider(|provider| {
            let tensor =
                Tensor::new(vec![1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0], vec![2, 2, 2]).unwrap();
            let identity = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();

            let view_lhs = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let view_rhs = HostTensorView {
                data: &identity.data,
                shape: &identity.shape,
            };
            let lhs = provider.upload(&view_lhs).expect("upload lhs");
            let rhs = provider.upload(&view_rhs).expect("upload rhs");

            let result = pagefun_builtin(
                Value::FunctionHandle("mtimes".into()),
                Value::GpuTensor(lhs),
                vec![Value::GpuTensor(rhs)],
            );
            let result = block_on(result).expect("pagefun");

            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 2, 2]);
            assert_eq!(
                gathered.data,
                vec![1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0],
                "GPU fallback should match identity broadcast"
            );
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn pagefun_wgpu_mtimes_batches() {
        use runmat_accelerate::backend::wgpu::provider::{
            register_wgpu_provider, WgpuProviderOptions,
        };

        let _ =
            register_wgpu_provider(WgpuProviderOptions::default()).expect("register wgpu provider");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let lhs = Tensor::new(
            vec![
                1.0, 4.0, 2.0, 5.0, //
                3.0, 6.0, 4.0, 7.0,
            ],
            vec![2, 2, 2],
        )
        .unwrap();
        let rhs = Tensor::new(
            vec![
                1.0, 0.0, 0.0, 1.0, //
                2.0, 1.0, 3.0, 2.0,
            ],
            vec![2, 2, 2],
        )
        .unwrap();

        let view_lhs = HostTensorView {
            data: &lhs.data,
            shape: &lhs.shape,
        };
        let view_rhs = HostTensorView {
            data: &rhs.data,
            shape: &rhs.shape,
        };

        let lhs_handle = provider.upload(&view_lhs).expect("upload lhs");
        let rhs_handle = provider.upload(&view_rhs).expect("upload rhs");

        let provider_handles = vec![lhs_handle.clone(), rhs_handle.clone()];
        let request = build_pagefun_request(&PageOperation::Mtimes, &provider_handles)
            .expect("build request")
            .expect("request available");

        let provider_result = provider.pagefun(&request).expect("wgpu pagefun execution");
        let provider_tensor =
            test_support::gather(Value::GpuTensor(provider_result)).expect("gather provider");

        let builtin_value = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::GpuTensor(lhs_handle.clone()),
            vec![Value::GpuTensor(rhs_handle.clone())],
        );
        let builtin_value = block_on(builtin_value).expect("pagefun builtin on GPU");
        let builtin_tensor = test_support::gather(builtin_value).expect("gather builtin");

        let expected_value = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::Tensor(lhs.clone()),
            vec![Value::Tensor(rhs.clone())],
        );
        let expected_value = block_on(expected_value).expect("pagefun host baseline");
        let expected_tensor = match expected_value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };

        assert_eq!(provider_tensor.shape, expected_tensor.shape);
        assert_eq!(provider_tensor.data, expected_tensor.data);
        assert_eq!(builtin_tensor.shape, expected_tensor.shape);
        assert_eq!(builtin_tensor.data, expected_tensor.data);
    }
}