1use crate::builtins::acceleration::gpu::type_resolvers::pagefun_type;
9use crate::builtins::common::spec::{
10 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
12};
13use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
14use runmat_accelerate_api::{GpuTensorHandle, HostTensorView, PagefunOp, PagefunRequest};
15use runmat_builtins::{ComplexTensor, Tensor, Value};
16use runmat_macros::runtime_builtin;
17
/// Gathered complex page payload: (interleaved `(re, im)` data, rows, cols).
type ComplexMatrixData = (Vec<(f64, f64)>, usize, usize);
19
/// GPU execution contract for `pagefun`.
///
/// Providers may register a custom "pagefun" hook (the WGPU backend
/// accelerates batched `@mtimes`); when no hook is available the runtime
/// gathers operands to the host and evaluates page-by-page.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::pagefun")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "pagefun",
    op_kind: GpuOpKind::Custom("pagefun"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("pagefun")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "WGPU provider accelerates batched @mtimes; runtimes gather to host when no provider hook is available.",
};
35
/// Fusion contract for `pagefun`: never fused, because the delegated operator
/// can be an arbitrary MATLAB builtin with its own semantics.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::acceleration::gpu::pagefun")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "pagefun",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a fusion barrier because pagefun can invoke arbitrary MATLAB operators.",
};
46
47fn pagefun_error(message: impl Into<String>) -> RuntimeError {
48 build_runtime_error(message).with_builtin("pagefun").build()
49}
50
/// Implementation of `pagefun(func, A, B, ...)`.
///
/// Resolves `func` to a supported page operation, attempts a provider-side
/// GPU fast path, and otherwise gathers every operand to the host and
/// evaluates the operator one broadcast page at a time, concatenating the
/// per-page matrices (column-major) into the final N-D result.
#[runtime_builtin(
    name = "pagefun",
    category = "acceleration/gpu",
    summary = "Apply MATLAB operators page-by-page across higher-dimensional arrays.",
    keywords = "pagefun,gpuArray,mtimes,pages,batch",
    accel = "custom",
    type_resolver(pagefun_type),
    builtin_path = "crate::builtins::acceleration::gpu::pagefun"
)]
async fn pagefun_builtin(
    func: Value,
    first: Value,
    rest: Vec<Value>,
) -> crate::BuiltinResult<Value> {
    let operation = PageOperation::from_callable(func)?;
    // `first` is a separate parameter so the signature itself requires at
    // least one array argument; fold it back into a single operand list.
    let mut operands = Vec::with_capacity(rest.len() + 1);
    operands.push(first);
    operands.extend(rest);
    if operands.is_empty() {
        return Err(pagefun_error("pagefun: requires at least one array input"));
    }

    operation.validate_arity(operands.len())?;

    // Fast path: every operand GPU-resident and the provider implements the
    // `pagefun` hook -> run the whole batch device-side.
    if let Some(value) = try_pagefun_gpu(&operation, &operands)? {
        return Ok(value);
    }

    // Remember whether the caller supplied gpuArrays so the host fallback can
    // upload the final result back to the device.
    let all_gpu = operands.iter().all(|v| matches!(v, Value::GpuTensor(_)));
    let mut host_values = Vec::with_capacity(operands.len());
    for value in operands {
        host_values.push(gather_if_needed_async(&value).await?);
    }

    let mut page_inputs = Vec::with_capacity(host_values.len());
    for value in host_values {
        page_inputs.push(PageInput::from_value(value)?);
    }

    // Broadcast rank = the largest number of page (3rd and later) dimensions.
    let rank = page_inputs
        .iter()
        .map(|input| input.page_dims.len())
        .max()
        .unwrap_or(0);

    let mut result_page_dims = if rank == 0 {
        Vec::new()
    } else {
        vec![1usize; rank]
    };

    // MATLAB-style broadcast over page dimensions: extents must match or be 1
    // (broadcast); an extent of 0 collapses the whole dimension to 0.
    for dim in 0..rank {
        let mut target = 1usize;
        for input in &page_inputs {
            let size = input.page_dims.get(dim).copied().unwrap_or(1);
            if size == 0 {
                target = 0;
                break;
            }
            if size != 1 {
                if target == 1 {
                    target = size;
                } else if target != size {
                    return Err(pagefun_error(format!(
                        "pagefun: page dimension {} mismatch ({} vs {})",
                        dim + 3,
                        target,
                        size
                    )));
                }
            }
        }
        if !result_page_dims.is_empty() {
            result_page_dims[dim] = target;
        }
    }

    let page_volume = if rank == 0 {
        1usize
    } else {
        result_page_dims.iter().copied().product()
    };

    let mut prepared_inputs = Vec::with_capacity(page_inputs.len());
    for input in page_inputs {
        prepared_inputs.push(PreparedInput::new(input, rank));
    }

    operation.validate_shapes(&prepared_inputs)?;
    let output_kind = operation.output_kind(&prepared_inputs);
    // Predicted per-page shape; the first evaluated page may overwrite it.
    let (mut result_rows, mut result_cols) =
        operation.output_matrix_shape(&prepared_inputs, output_kind)?;

    // Zero pages -> nothing to evaluate; emit an empty, correctly shaped result.
    if page_volume == 0 {
        return finalise_empty_output(
            &operation,
            &prepared_inputs,
            &result_page_dims,
            output_kind,
            all_gpu,
        );
    }

    let mut real_data: Option<Vec<f64>> = None;
    let mut complex_data: Option<Vec<(f64, f64)>> = None;
    let mut multi_index = vec![0usize; rank];

    // Walk the broadcast page index space in column-major order, evaluating
    // the operator once per page and appending each page's elements.
    let mut page_counter = 0usize;
    loop {
        let mut page_args = Vec::with_capacity(prepared_inputs.len());
        for input in &prepared_inputs {
            page_args.push(input.page_value(&multi_index)?);
        }

        let mut evaluated = operation.evaluate(&page_args).await?;
        // The delegated builtin may itself return a GPU handle; gather the
        // page back to the host before concatenation.
        evaluated = gather_if_needed_async(&evaluated).await?;
        match output_kind {
            OutputKind::Real => {
                let (data, rows, cols) = tensor_matrix_data(evaluated)?;
                if real_data.is_none() {
                    // First page fixes the page shape; later pages must match.
                    result_rows = rows;
                    result_cols = cols;
                    real_data = Some(Vec::with_capacity(rows * cols * page_volume));
                } else if rows != result_rows || cols != result_cols {
                    return Err(pagefun_error(
                        "pagefun: result matrices must be the same size",
                    ));
                }
                if let Some(vec) = real_data.as_mut() {
                    vec.extend(data);
                }
            }
            OutputKind::Complex => {
                let (data, rows, cols) = complex_matrix_data(evaluated)?;
                if complex_data.is_none() {
                    result_rows = rows;
                    result_cols = cols;
                    complex_data = Some(Vec::with_capacity(rows * cols * page_volume));
                } else if rows != result_rows || cols != result_cols {
                    return Err(pagefun_error(
                        "pagefun: result matrices must be the same size",
                    ));
                }
                if let Some(vec) = complex_data.as_mut() {
                    vec.extend(data);
                }
            }
        }

        page_counter += 1;
        if page_counter == page_volume {
            break;
        }
        increment_multi_index(&mut multi_index, &result_page_dims)?;
    }

    let final_shape = assemble_shape(result_rows, result_cols, &result_page_dims);
    let output = match output_kind {
        OutputKind::Real => {
            let data = real_data.unwrap_or_default();
            let tensor = Tensor::new(data, final_shape).map_err(|e| {
                pagefun_error(format!("pagefun: failed to construct result tensor ({e})"))
            })?;
            FinalOutput::Real(tensor)
        }
        OutputKind::Complex => {
            let data = complex_data.unwrap_or_default();
            let tensor = ComplexTensor::new(data, final_shape).map_err(|e| {
                pagefun_error(format!(
                    "pagefun: failed to construct complex result tensor ({e})"
                ))
            })?;
            FinalOutput::Complex(tensor)
        }
    };

    output.into_value(all_gpu)
}
229
/// Attempt to run the whole pagefun batch through the acceleration provider.
///
/// Returns `Ok(None)` (caller falls back to the host path) unless every
/// operand is a gpuArray, a provider is registered, and the operation maps to
/// a provider-supported `PagefunRequest`. Provider-side execution failures
/// are logged at debug level and demoted to the host fallback rather than
/// surfaced to the user.
fn try_pagefun_gpu(operation: &PageOperation, operands: &[Value]) -> BuiltinResult<Option<Value>> {
    if operands.is_empty() {
        return Ok(None);
    }
    if !operands
        .iter()
        .all(|value| matches!(value, Value::GpuTensor(_)))
    {
        return Ok(None);
    }

    // Test-only convenience: lazily register the WGPU provider when handles
    // reference a real device, so GPU tests need no global setup.
    #[cfg(all(test, feature = "wgpu"))]
    {
        if operands
            .iter()
            .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
        {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let Some(provider) = runmat_accelerate_api::provider() else {
        return Ok(None);
    };

    let handles: Vec<GpuTensorHandle> = operands
        .iter()
        .map(|value| match value {
            Value::GpuTensor(handle) => handle.clone(),
            // Guarded by the all-GpuTensor check above.
            _ => unreachable!(),
        })
        .collect();

    let request = match build_pagefun_request(operation, &handles)? {
        Some(request) => request,
        None => return Ok(None),
    };

    match provider.pagefun(&request) {
        Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
        Err(err) => {
            log::debug!("pagefun: provider hook unavailable, falling back to host: {err}");
            Ok(None)
        }
    }
}
278
/// Translate the operation and its GPU handles into a provider `PagefunRequest`.
///
/// Performs the same shape validation and MATLAB page-dimension broadcasting
/// as the host fallback so both paths agree on errors and output shapes.
/// Returns `Ok(None)` when the operation has no provider mapping.
fn build_pagefun_request(
    operation: &PageOperation,
    handles: &[GpuTensorHandle],
) -> BuiltinResult<Option<PagefunRequest>> {
    match operation {
        PageOperation::Mtimes => {
            if handles.len() != 2 {
                return Err(pagefun_error(
                    "pagefun: @mtimes requires exactly two array inputs",
                ));
            }

            let (lhs_rows, lhs_cols, lhs_pages) = handle_matrix_meta(&handles[0])?;
            let (rhs_rows, rhs_cols, rhs_pages) = handle_matrix_meta(&handles[1])?;
            if lhs_cols != rhs_rows {
                return Err(pagefun_error(format!(
                    "pagefun: inner matrix dimensions must agree ({}x{} * {}x{})",
                    lhs_rows, lhs_cols, rhs_rows, rhs_cols
                )));
            }

            // Broadcast the page dimensions exactly as the host path does:
            // extents must match or be 1; an extent of 0 collapses the dim.
            let rank = lhs_pages.len().max(rhs_pages.len());
            let mut result_page_dims = if rank == 0 {
                Vec::new()
            } else {
                vec![1usize; rank]
            };

            for dim in 0..rank {
                let mut target = 1usize;
                let dims_to_check = [
                    lhs_pages.get(dim).copied().unwrap_or(1),
                    rhs_pages.get(dim).copied().unwrap_or(1),
                ];
                for size in dims_to_check {
                    if size == 0 {
                        target = 0;
                        break;
                    }
                    if size != 1 {
                        if target == 1 {
                            target = size;
                        } else if target != size {
                            return Err(pagefun_error(format!(
                                "pagefun: page dimension {} mismatch ({} vs {})",
                                dim + 3,
                                target,
                                size
                            )));
                        }
                    }
                }
                if !result_page_dims.is_empty() {
                    result_page_dims[dim] = target;
                }
            }

            // Providers expect each operand's page dims padded to the common rank.
            let mut input_page_dims = Vec::with_capacity(2);
            let mut lhs_padded = lhs_pages.clone();
            lhs_padded.resize(rank, 1);
            let mut rhs_padded = rhs_pages.clone();
            rhs_padded.resize(rank, 1);
            input_page_dims.push(lhs_padded);
            input_page_dims.push(rhs_padded);

            let mut output_shape = vec![lhs_rows, rhs_cols];
            output_shape.extend_from_slice(&result_page_dims);

            Ok(Some(PagefunRequest {
                op: PagefunOp::Mtimes,
                inputs: handles.to_vec(),
                output_shape,
                page_dims: result_page_dims,
                input_page_dims,
            }))
        }
    }
}
357
358fn handle_matrix_meta(handle: &GpuTensorHandle) -> BuiltinResult<(usize, usize, Vec<usize>)> {
359 let canonical = canonical_matrix_shape(&handle.shape);
360 if canonical.len() < 2 {
361 return Err(pagefun_error("pagefun: gpu tensor must be at least 2-D"));
362 }
363 let rows = canonical[0];
364 let cols = canonical[1];
365 let pages = if canonical.len() > 2 {
366 canonical[2..].to_vec()
367 } else {
368 Vec::new()
369 };
370 Ok((rows, cols, pages))
371}
372
373fn finalise_empty_output(
374 operation: &PageOperation,
375 inputs: &[PreparedInput],
376 page_dims: &[usize],
377 output_kind: OutputKind,
378 wants_gpu: bool,
379) -> BuiltinResult<Value> {
380 let (rows, cols) = operation.output_matrix_shape(inputs, output_kind)?;
381 let final_shape = assemble_shape(rows, cols, page_dims);
382 let page_factor: usize = if page_dims.is_empty() {
383 1
384 } else {
385 page_dims.iter().copied().product()
386 };
387 let entries = rows
388 .checked_mul(cols)
389 .unwrap_or(0)
390 .checked_mul(page_factor)
391 .unwrap_or(0);
392 match output_kind {
393 OutputKind::Real => {
394 let tensor = Tensor::new(vec![0.0; entries], final_shape).map_err(|e| {
395 pagefun_error(format!("pagefun: failed to build empty tensor ({e})"))
396 })?;
397 FinalOutput::Real(tensor).into_value(wants_gpu)
398 }
399 OutputKind::Complex => {
400 let tensor =
401 ComplexTensor::new(vec![(0.0, 0.0); entries], final_shape).map_err(|e| {
402 pagefun_error(format!(
403 "pagefun: failed to build empty complex tensor ({e})"
404 ))
405 })?;
406 FinalOutput::Complex(tensor).into_value(false)
407 }
408 }
409}
410
/// Prepend the per-page matrix dimensions to the broadcast page dimensions,
/// producing the full N-D result shape `[rows, cols, pages...]`.
fn assemble_shape(rows: usize, cols: usize, page_dims: &[usize]) -> Vec<usize> {
    [rows, cols]
        .iter()
        .copied()
        .chain(page_dims.iter().copied())
        .collect()
}
416
417fn increment_multi_index(indices: &mut [usize], dims: &[usize]) -> BuiltinResult<()> {
418 if dims.contains(&0) {
419 return Ok(());
420 }
421 for (dim, &limit) in dims.iter().enumerate() {
422 if limit == 0 {
423 continue;
424 }
425 indices[dim] += 1;
426 if indices[dim] < limit {
427 return Ok(());
428 }
429 indices[dim] = 0;
430 if dim + 1 == dims.len() {
431 break;
432 }
433 }
434 Ok(())
435}
436
/// Numeric domain of the assembled pagefun result.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OutputKind {
    Real,
    Complex,
}

/// Fully assembled host-side result awaiting conversion into a runtime `Value`.
enum FinalOutput {
    Real(Tensor),
    Complex(ComplexTensor),
}
447
impl FinalOutput {
    /// Convert the assembled result into a runtime `Value`.
    ///
    /// When `wants_gpu` is set (all original operands were gpuArrays), real
    /// results are uploaded back to the provider; upload failures quietly
    /// degrade to a host tensor. Complex results always stay on the host —
    /// no complex upload path is exposed here.
    fn into_value(self, wants_gpu: bool) -> BuiltinResult<Value> {
        match self {
            FinalOutput::Real(tensor) => {
                if wants_gpu {
                    // Test-only: ensure a WGPU provider exists before uploading.
                    #[cfg(all(test, feature = "wgpu"))]
                    {
                        if runmat_accelerate_api::provider().is_none() {
                            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                            );
                        }
                    }
                    if let Some(provider) = runmat_accelerate_api::provider() {
                        let view = HostTensorView {
                            data: &tensor.data,
                            shape: &tensor.shape,
                        };
                        if let Ok(handle) = provider.upload(&view) {
                            return Ok(Value::GpuTensor(handle));
                        }
                    }
                }
                Ok(Value::Tensor(tensor))
            }
            FinalOutput::Complex(tensor) => Ok(Value::ComplexTensor(tensor)),
        }
    }
}
477
/// A gathered host operand split into its leading matrix dimensions and its
/// trailing page dimensions.
#[derive(Clone)]
struct PageInput {
    // Extents of the 3rd and later dimensions; empty for a plain matrix.
    page_dims: Vec<usize>,
    rows: usize,
    cols: usize,
    data: PageData,
}

/// Column-major element buffer covering all pages of one operand.
#[derive(Clone)]
enum PageData {
    Real(Vec<f64>),
    Complex(Vec<(f64, f64)>),
}
491
492impl PageInput {
493 fn from_value(value: Value) -> BuiltinResult<Self> {
494 match value {
495 Value::Tensor(t) => Self::from_tensor(t),
496 Value::Num(n) => Self::from_tensor(
497 Tensor::new(vec![n], vec![1, 1])
498 .map_err(|e| pagefun_error(format!("pagefun: {e}")))?,
499 ),
500 Value::Int(i) => Self::from_tensor(
501 Tensor::new(vec![i.to_f64()], vec![1, 1])
502 .map_err(|e| pagefun_error(format!("pagefun: {e}")))?,
503 ),
504 Value::Bool(flag) => Self::from_tensor(
505 Tensor::new(vec![if flag { 1.0 } else { 0.0 }], vec![1, 1])
506 .map_err(|e| pagefun_error(format!("pagefun: {e}")))?,
507 ),
508 Value::Complex(re, im) => {
509 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
510 .map_err(|e| pagefun_error(format!("pagefun: {e}")))?;
511 Self::from_complex_tensor(tensor)
512 }
513 Value::ComplexTensor(t) => Self::from_complex_tensor(t),
514 other => Err(pagefun_error(format!(
515 "pagefun: unsupported input type {}",
516 other.type_name()
517 ))),
518 }
519 }
520
521 fn from_tensor(tensor: Tensor) -> BuiltinResult<Self> {
522 let shape = canonical_matrix_shape(&tensor.shape);
523 if tensor.data.len() != shape.iter().copied().product::<usize>() {
524 return Err(pagefun_error(
525 "pagefun: tensor data does not match its shape",
526 ));
527 }
528 let rows = shape[0];
529 let cols = shape[1];
530 let page_dims = if shape.len() > 2 {
531 shape[2..].to_vec()
532 } else {
533 Vec::new()
534 };
535 Ok(Self {
536 page_dims,
537 rows,
538 cols,
539 data: PageData::Real(tensor.data),
540 })
541 }
542
543 fn from_complex_tensor(tensor: ComplexTensor) -> BuiltinResult<Self> {
544 let shape = canonical_matrix_shape(&tensor.shape);
545 if tensor.data.len() != shape.iter().copied().product::<usize>() {
546 return Err(pagefun_error(
547 "pagefun: tensor data does not match its shape",
548 ));
549 }
550 let rows = shape[0];
551 let cols = shape[1];
552 let page_dims = if shape.len() > 2 {
553 shape[2..].to_vec()
554 } else {
555 Vec::new()
556 };
557 Ok(Self {
558 page_dims,
559 rows,
560 cols,
561 data: PageData::Complex(tensor.data),
562 })
563 }
564
565 fn page_size(&self) -> usize {
566 self.rows * self.cols
567 }
568
569 fn is_complex(&self) -> bool {
570 matches!(self.data, PageData::Complex(_))
571 }
572}
573
/// A `PageInput` padded to the common broadcast rank, with precomputed
/// per-dimension strides for translating a broadcast page index back onto
/// the source buffer.
struct PreparedInput {
    data: PageInput,
    // Page dims right-padded with 1s up to the broadcast rank.
    padded_dims: Vec<usize>,
    // Column-major strides over `padded_dims`, in units of whole pages.
    strides: Vec<usize>,
}
579
impl PreparedInput {
    /// Pad the input's page dims with trailing 1s up to the broadcast `rank`
    /// and precompute the page strides used by `page_value`.
    fn new(input: PageInput, rank: usize) -> Self {
        let mut padded = input.page_dims.clone();
        padded.resize(rank, 1);
        let strides = compute_strides(&padded);
        Self {
            data: input,
            padded_dims: padded,
            strides,
        }
    }

    fn rows(&self) -> usize {
        self.data.rows
    }

    fn cols(&self) -> usize {
        self.data.cols
    }

    fn is_complex(&self) -> bool {
        self.data.is_complex()
    }

    /// Extract the page addressed by the broadcast index `multi_index` as a
    /// standalone matrix `Value`, applying MATLAB broadcast semantics: any
    /// source dimension of extent 1 repeats its single page.
    fn page_value(&self, multi_index: &[usize]) -> BuiltinResult<Value> {
        // Translate the broadcast index into a linear page number within this
        // input's own (possibly smaller) page space.
        let mut linear_page = 0usize;
        for (dim, stride) in self.strides.iter().enumerate() {
            let source_extent = self.padded_dims.get(dim).copied().unwrap_or(1);
            let requested = multi_index.get(dim).copied().unwrap_or(0);
            if source_extent == 0 {
                // Zero-extent inputs are routed to the empty-output path
                // before evaluation starts; hitting this is an internal error.
                return Err(pagefun_error("pagefun: source page extent is zero"));
            }
            if source_extent != 1 && requested >= source_extent {
                return Err(pagefun_error("pagefun: page index out of bounds"));
            }
            // Broadcast: a dimension of extent 1 always maps to index 0.
            let actual = if source_extent == 1 { 0 } else { requested };
            linear_page += actual * stride;
        }

        let offset = linear_page * self.data.page_size();
        match &self.data.data {
            PageData::Real(buffer) => {
                let end = offset + self.data.page_size();
                let slice = buffer
                    .get(offset..end)
                    .ok_or_else(|| pagefun_error("pagefun: page slice out of bounds"))?;
                let tensor = Tensor::new(slice.to_vec(), vec![self.data.rows, self.data.cols])
                    .map_err(|e| pagefun_error(format!("pagefun: {e}")))?;
                Ok(Value::Tensor(tensor))
            }
            PageData::Complex(buffer) => {
                let end = offset + self.data.page_size();
                let slice = buffer
                    .get(offset..end)
                    .ok_or_else(|| pagefun_error("pagefun: page slice out of bounds"))?;
                let tensor =
                    ComplexTensor::new(slice.to_vec(), vec![self.data.rows, self.data.cols])
                        .map_err(|e| pagefun_error(format!("pagefun: {e}")))?;
                Ok(Value::ComplexTensor(tensor))
            }
        }
    }
}
643
/// Column-major strides for `dims`: element `i` is the product of all extents
/// before dimension `i`. Zero extents are treated as 1 so strides stay usable
/// for broadcast index translation; multiplication saturates on overflow.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    dims.iter()
        .scan(1usize, |stride, &dim| {
            let current = *stride;
            *stride = stride.saturating_mul(dim.max(1));
            Some(current)
        })
        .collect()
}
653
/// Flatten a real per-page result into `(data, rows, cols)`.
///
/// Accepts at-most-2-D tensors and real scalars; higher-dimensional results
/// are rejected so that pages stay concatenable into the final output.
fn tensor_matrix_data(value: Value) -> BuiltinResult<(Vec<f64>, usize, usize)> {
    match value {
        Value::Tensor(t) => {
            if t.shape.len() > 2 {
                return Err(pagefun_error(
                    "pagefun: operator returned an array with more than two dimensions",
                ));
            }
            let canonical = canonical_matrix_shape(&t.shape);
            let rows = canonical[0];
            let cols = canonical[1];
            if rows * cols != t.data.len() {
                return Err(pagefun_error("pagefun: result size mismatch"));
            }
            Ok((t.data, rows, cols))
        }
        // Scalars become 1x1 pages.
        Value::Num(n) => Ok((vec![n], 1, 1)),
        Value::Int(i) => Ok((vec![i.to_f64()], 1, 1)),
        other => Err(pagefun_error(format!(
            "pagefun: expected numeric matrix result, received {}",
            other.type_name()
        ))),
    }
}
678
/// Flatten a complex per-page result into `(data, rows, cols)`.
///
/// Mirrors `tensor_matrix_data` for the complex output path: at-most-2-D
/// complex tensors and complex scalars only.
fn complex_matrix_data(value: Value) -> BuiltinResult<ComplexMatrixData> {
    match value {
        Value::ComplexTensor(t) => {
            if t.shape.len() > 2 {
                return Err(pagefun_error(
                    "pagefun: operator returned an array with more than two dimensions",
                ));
            }
            let canonical = canonical_matrix_shape(&t.shape);
            let rows = canonical[0];
            let cols = canonical[1];
            if rows * cols != t.data.len() {
                return Err(pagefun_error("pagefun: result size mismatch"));
            }
            Ok((t.data, rows, cols))
        }
        // A complex scalar becomes a 1x1 page.
        Value::Complex(re, im) => Ok((vec![(re, im)], 1, 1)),
        other => Err(pagefun_error(format!(
            "pagefun: expected complex matrix result, received {}",
            other.type_name()
        ))),
    }
}
702
/// Normalise a tensor shape to at least two dimensions, MATLAB-style.
///
/// Scalars (`[]`) become `[1, 1]`, 1-D shapes become row vectors `[1, n]`,
/// and shapes with two or more dimensions are returned unchanged.
///
/// The previous version carried an unreachable `if out.len() == 1` fix-up
/// inside the fallthrough arm (that arm only matches lengths >= 2); slice
/// patterns make the three cases explicit.
fn canonical_matrix_shape(shape: &[usize]) -> Vec<usize> {
    match shape {
        [] => vec![1, 1],
        [n] => vec![1, *n],
        _ => shape.to_vec(),
    }
}
716
/// The MATLAB operator applied to each page; only `@mtimes` is implemented.
#[derive(Clone, Copy)]
enum PageOperation {
    Mtimes,
}
721
impl PageOperation {
    /// Parse the `func` argument (function handle, string, single-element
    /// string array, or single-row char array) into a supported operation.
    /// A leading `@` is optional and matching is case-insensitive.
    fn from_callable(value: Value) -> BuiltinResult<Self> {
        let raw = match value {
            Value::FunctionHandle(func) => func,
            Value::String(s) => s,
            Value::StringArray(sa) => {
                if sa.data.len() != 1 {
                    return Err(pagefun_error(
                        "pagefun: function string array must contain exactly one element",
                    ));
                }
                sa.data[0].clone()
            }
            Value::CharArray(chars) => {
                if chars.rows != 1 {
                    return Err(pagefun_error(
                        "pagefun: function char array must be a single row character vector",
                    ));
                }
                chars.data.iter().collect()
            }
            other => {
                return Err(pagefun_error(format!(
                    "pagefun: unsupported function handle type {}",
                    other.type_name()
                )))
            }
        };
        let trimmed = raw.trim();
        let lowered = trimmed.trim_start_matches('@').to_ascii_lowercase();
        match lowered.as_str() {
            "mtimes" => Ok(Self::Mtimes),
            _ => Err(pagefun_error(format!(
                "pagefun: unsupported function '{}'; currently only @mtimes is implemented",
                trimmed
            ))),
        }
    }

    /// Check that the number of array operands matches the operation's arity.
    fn validate_arity(&self, arg_count: usize) -> BuiltinResult<()> {
        match self {
            Self::Mtimes => {
                if arg_count != 2 {
                    return Err(pagefun_error(
                        "pagefun: @mtimes requires exactly two array inputs",
                    ));
                }
                Ok(())
            }
        }
    }

    /// Validate the per-page matrix shapes before evaluating any page.
    fn validate_shapes(&self, inputs: &[PreparedInput]) -> BuiltinResult<()> {
        match self {
            Self::Mtimes => {
                let lhs = &inputs[0];
                let rhs = &inputs[1];
                if lhs.cols() != rhs.rows() {
                    return Err(pagefun_error(format!(
                        "pagefun: inner matrix dimensions must agree ({}x{} * {}x{})",
                        lhs.rows(),
                        lhs.cols(),
                        rhs.rows(),
                        rhs.cols()
                    )));
                }
                Ok(())
            }
        }
    }

    /// Evaluate the operation for one page by delegating to the regular
    /// builtin, so real/complex promotion rules match plain `mtimes`.
    async fn evaluate(&self, args: &[Value]) -> crate::BuiltinResult<Value> {
        match self {
            Self::Mtimes => crate::call_builtin_async("mtimes", args).await,
        }
    }

    /// The result is complex as soon as any operand carries complex data.
    fn output_kind(&self, inputs: &[PreparedInput]) -> OutputKind {
        match self {
            Self::Mtimes => {
                if inputs.iter().any(|input| input.is_complex()) {
                    OutputKind::Complex
                } else {
                    OutputKind::Real
                }
            }
        }
    }

    /// Predict each result page's rows/cols from the operand shapes.
    fn output_matrix_shape(
        &self,
        inputs: &[PreparedInput],
        kind: OutputKind,
    ) -> BuiltinResult<(usize, usize)> {
        match self {
            Self::Mtimes => {
                let lhs = &inputs[0];
                let rhs = &inputs[1];
                let rows = lhs.rows();
                let cols = rhs.cols();
                // Shape is the same for real and complex results.
                match kind {
                    OutputKind::Real | OutputKind::Complex => Ok((rows, cols)),
                }
            }
        }
    }
}
829
/// MATLAB-flavoured, human-readable type names used in pagefun error messages.
trait TypeName {
    fn type_name(&self) -> &'static str;
}

impl TypeName for Value {
    fn type_name(&self) -> &'static str {
        // Exhaustive on purpose: adding a `Value` variant forces an update here.
        match self {
            Value::Int(_) => "int",
            Value::Num(_) => "double",
            Value::Complex(_, _) => "complex double",
            Value::Bool(_) => "logical",
            Value::LogicalArray(_) => "logical array",
            Value::String(_) => "string",
            Value::StringArray(_) => "string array",
            Value::CharArray(_) => "char array",
            Value::Tensor(_) => "double array",
            Value::ComplexTensor(_) => "complex double array",
            Value::Cell(_) => "cell array",
            Value::Struct(_) => "struct",
            Value::GpuTensor(_) => "gpuArray",
            Value::Object(_) => "object",
            Value::HandleObject(_) => "handle object",
            Value::Listener(_) => "listener",
            Value::FunctionHandle(_) => "function handle",
            Value::Closure(_) => "closure",
            Value::ClassRef(_) => "class reference",
            Value::MException(_) => "MException",
            Value::OutputList(_) => "output list",
        }
    }
}
861
// Unit tests covering the host fallback, argument-parsing errors, broadcast
// semantics, empty-page handling, and (feature-gated) the WGPU provider path.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{CharArray, ResolveContext, StringArray, Type};

    // A single 2x2 page: plain host matrix multiply.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_single_page() {
        let lhs = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![5.0, 7.0, 6.0, 8.0], vec![2, 2]).unwrap();
        let result = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                assert_eq!(t.data, vec![19.0, 43.0, 22.0, 50.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Two pages on each side; also exercises the "@mtimes" string spelling.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_multiple_pages() {
        let lhs = Tensor::new(vec![1.0, 3.0, 2.0, 4.0, 2.0, 1.0, 0.0, 3.0], vec![2, 2, 2]).unwrap();
        let rhs = Tensor::new(vec![5.0, 7.0, 6.0, 8.0, 1.0, 0.0, 2.0, 1.0], vec![2, 2, 2]).unwrap();
        let result = pagefun_builtin(
            Value::from("@mtimes"),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2, 2]);
                assert_eq!(t.data, vec![19.0, 43.0, 22.0, 50.0, 2.0, 1.0, 4.0, 5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // A 2-D rhs must broadcast across the lhs's page dimension.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_broadcast_rhs() {
        let lhs = Tensor::new(vec![1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0], vec![2, 2, 2]).unwrap();
        let rhs = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let result = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2, 2]);
                assert_eq!(
                    t.data,
                    vec![1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0],
                    "broadcasted identity should preserve pages"
                );
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Zero-extent page dimension: result has the right shape and no data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_empty_pages() {
        let lhs = Tensor::new(Vec::new(), vec![2, 2, 0]).unwrap();
        let rhs = Tensor::new(Vec::new(), vec![2, 2, 0]).unwrap();
        let result = pagefun_builtin(
            Value::from("@mtimes"),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2, 0]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Char-row-vector function spelling ('@mtimes') is accepted.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_char_array_handle() {
        let lhs = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![5.0, 7.0, 6.0, 8.0], vec![2, 2]).unwrap();
        let func = CharArray::new("@mtimes".chars().collect(), 1, 7).unwrap();
        let result = pagefun_builtin(
            Value::CharArray(func),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun char array");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                assert_eq!(t.data, vec![19.0, 43.0, 22.0, 50.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Single-element string-array function spelling is accepted.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_string_array_handle() {
        let lhs = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![5.0, 7.0, 6.0, 8.0], vec![2, 2]).unwrap();
        let strings = StringArray::new(vec!["@mtimes".to_string()], vec![1]).unwrap();
        let result = pagefun_builtin(
            Value::StringArray(strings),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let result = block_on(result).expect("pagefun string array");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                assert_eq!(t.data, vec![19.0, 43.0, 22.0, 50.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // A multi-row char array is not a valid function spelling.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_char_array_multirow_error() {
        let chars = CharArray::new("@mtimes@".chars().collect(), 2, 4).unwrap();
        let lhs = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let rhs = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = pagefun_builtin(
            Value::CharArray(chars),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let err = block_on(err).expect_err("expected multi-row char array error");
        assert!(
            err.contains("char array"),
            "unexpected error for multi-row char array: {err}"
        );
    }

    // A multi-element string array is not a valid function spelling.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_string_array_multi_value_error() {
        let strings =
            StringArray::new(vec!["@mtimes".to_string(), "@mtimes".to_string()], vec![2]).unwrap();
        let lhs = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let rhs = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err = pagefun_builtin(
            Value::StringArray(strings),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let err = block_on(err).expect_err("expected multi-element string array error");
        assert!(
            err.contains("string array"),
            "unexpected error for string array: {err}"
        );
    }

    // Non-broadcastable page extents (2 vs 3) must be rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_page_dimension_mismatch() {
        let lhs = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();
        let rhs = Tensor::new(
            vec![
                1.0, 5.0, 2.0, 6.0, 3.0, 7.0, 4.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            vec![2, 2, 3],
        )
        .unwrap();
        let err = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let err = block_on(err).expect_err("expected page dimension mismatch");
        assert!(
            err.contains("page dimension"),
            "unexpected mismatch error message: {err}"
        );
    }

    // Incompatible inner matrix dimensions must be rejected up front.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_mtimes_dim_mismatch() {
        let lhs = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::Tensor(lhs),
            vec![Value::Tensor(rhs)],
        );
        let err = block_on(err).expect_err("expected dimension mismatch");
        assert!(
            err.contains("inner matrix dimensions"),
            "unexpected error message {err}"
        );
    }

    // The registered type resolver reports a tensor result.
    #[test]
    fn pagefun_type_is_tensor() {
        assert_eq!(
            pagefun_type(&[Type::tensor()], &ResolveContext::new(Vec::new())),
            Type::tensor()
        );
    }

    // GPU operands gathered through the test provider round-trip correctly.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pagefun_gpu_roundtrip_mtimes() {
        test_support::with_test_provider(|provider| {
            let tensor =
                Tensor::new(vec![1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0], vec![2, 2, 2]).unwrap();
            let identity = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();

            let view_lhs = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let view_rhs = HostTensorView {
                data: &identity.data,
                shape: &identity.shape,
            };
            let lhs = provider.upload(&view_lhs).expect("upload lhs");
            let rhs = provider.upload(&view_rhs).expect("upload rhs");

            let result = pagefun_builtin(
                Value::FunctionHandle("mtimes".into()),
                Value::GpuTensor(lhs),
                vec![Value::GpuTensor(rhs)],
            );
            let result = block_on(result).expect("pagefun");

            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 2, 2]);
            assert_eq!(
                gathered.data,
                vec![1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0],
                "GPU fallback should match identity broadcast"
            );
        });
    }

    // The WGPU provider hook, the builtin's GPU path, and the host baseline
    // must all agree on batched @mtimes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn pagefun_wgpu_mtimes_batches() {
        use runmat_accelerate::backend::wgpu::provider::{
            register_wgpu_provider, WgpuProviderOptions,
        };

        let _ =
            register_wgpu_provider(WgpuProviderOptions::default()).expect("register wgpu provider");
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");

        let lhs = Tensor::new(
            vec![
                1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 4.0, 7.0,
            ],
            vec![2, 2, 2],
        )
        .unwrap();
        let rhs = Tensor::new(
            vec![
                1.0, 0.0, 0.0, 1.0, 2.0, 1.0, 3.0, 2.0,
            ],
            vec![2, 2, 2],
        )
        .unwrap();

        let view_lhs = HostTensorView {
            data: &lhs.data,
            shape: &lhs.shape,
        };
        let view_rhs = HostTensorView {
            data: &rhs.data,
            shape: &rhs.shape,
        };

        let lhs_handle = provider.upload(&view_lhs).expect("upload lhs");
        let rhs_handle = provider.upload(&view_rhs).expect("upload rhs");

        // Path 1: drive the provider hook directly via a hand-built request.
        let provider_handles = vec![lhs_handle.clone(), rhs_handle.clone()];
        let request = build_pagefun_request(&PageOperation::Mtimes, &provider_handles)
            .expect("build request")
            .expect("request available");

        let provider_result = provider.pagefun(&request).expect("wgpu pagefun execution");
        let provider_tensor =
            test_support::gather(Value::GpuTensor(provider_result)).expect("gather provider");

        // Path 2: the builtin with GPU operands.
        let builtin_value = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::GpuTensor(lhs_handle.clone()),
            vec![Value::GpuTensor(rhs_handle.clone())],
        );
        let builtin_value = block_on(builtin_value).expect("pagefun builtin on GPU");
        let builtin_tensor = test_support::gather(builtin_value).expect("gather builtin");

        // Path 3: the host baseline.
        let expected_value = pagefun_builtin(
            Value::FunctionHandle("mtimes".into()),
            Value::Tensor(lhs.clone()),
            vec![Value::Tensor(rhs.clone())],
        );
        let expected_value = block_on(expected_value).expect("pagefun host baseline");
        let expected_tensor = match expected_value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };

        assert_eq!(provider_tensor.shape, expected_tensor.shape);
        assert_eq!(provider_tensor.data, expected_tensor.data);
        assert_eq!(builtin_tensor.shape, expected_tensor.shape);
        assert_eq!(builtin_tensor.data, expected_tensor.data);
    }
}