Skip to main content

runmat_runtime/builtins/array/creation/
meshgrid.rs

1//! MATLAB-compatible `meshgrid` builtin with GPU-aware semantics.
2
3use std::cmp::max;
4
5use log::warn;
6use runmat_accelerate_api::{GpuTensorHandle, HostTensorView};
7use runmat_builtins::{ComplexTensor, ResolveContext, Tensor, Type, Value};
8
9use crate::builtins::array::type_resolvers::size_vector_len;
10use runmat_macros::runtime_builtin;
11
12use crate::build_runtime_error;
13use crate::builtins::common::gpu_helpers;
14use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
15use crate::builtins::common::residency::{sequence_gpu_preference, SequenceIntent};
16use crate::builtins::common::spec::{
17    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
18    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
19};
20use crate::builtins::common::tensor;
21
22#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
23pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
24    name: "meshgrid",
25    op_kind: GpuOpKind::Custom("array_construct"),
26    supported_precisions: &[ScalarType::F32, ScalarType::F64],
27    broadcast: BroadcastSemantics::Matlab,
28    provider_hooks: &[ProviderHook::Custom("meshgrid")],
29    constant_strategy: ConstantStrategy::InlineLiteral,
30    residency: ResidencyPolicy::NewHandle,
31    nan_mode: ReductionNaN::Include,
32    two_pass_threshold: None,
33    workgroup_size: None,
34    accepts_nan_mode: false,
35    notes: "Providers may supply a dedicated meshgrid hook; until then the runtime builds grids on the host and uploads them when GPU residency is requested.",
36};
37
38fn builtin_error(message: impl Into<String>) -> crate::RuntimeError {
39    build_runtime_error(message)
40        .with_builtin("meshgrid")
41        .build()
42}
43
44#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
45pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
46    name: "meshgrid",
47    shape: ShapeRequirements::Any,
48    constant_strategy: ConstantStrategy::InlineLiteral,
49    elementwise: None,
50    reduction: None,
51    emits_nan: false,
52    notes:
53        "Meshgrid explicitly materialises dense coordinate arrays and therefore bypasses fusion.",
54};
55
56fn meshgrid_type(args: &[Type], _context: &ResolveContext) -> Type {
57    if args.is_empty() {
58        return Type::Unknown;
59    }
60    let mut axis_count = args.len();
61    if axis_count >= 2 && matches!(args[axis_count - 2], Type::String) {
62        axis_count = axis_count.saturating_sub(2);
63    }
64    if axis_count == 0 {
65        return Type::Unknown;
66    }
67    let axis_args = &args[..axis_count];
68    let len_x = axis_args.get(0).and_then(size_vector_len);
69    let len_y = axis_args.get(1).and_then(size_vector_len).or(len_x);
70    let len_z = axis_args.get(2).and_then(size_vector_len);
71    let shape = if axis_count >= 3 {
72        vec![len_y, len_x, len_z]
73    } else {
74        vec![len_y, len_x]
75    };
76    Type::Tensor { shape: Some(shape) }
77}
78
79#[runtime_builtin(
80    name = "meshgrid",
81    category = "array/creation",
82    summary = "Generate coordinate matrices for 2-D and 3-D grids.",
83    keywords = "meshgrid,grid,gpu,like,3d",
84    accel = "array_construct",
85    type_resolver(meshgrid_type),
86    builtin_path = "crate::builtins::array::creation::meshgrid"
87)]
88async fn meshgrid_builtin(rest: Vec<Value>) -> crate::BuiltinResult<Value> {
89    let eval = evaluate(&rest).await?;
90    if let Some(out_count) = crate::output_count::current_output_count() {
91        if out_count == 0 {
92            return Ok(Value::OutputList(Vec::new()));
93        }
94        let available = eval.output_count();
95        if out_count > available {
96            let msg = if available == 2 {
97                "meshgrid with two inputs supports at most two outputs"
98            } else {
99                "meshgrid supports at most three outputs"
100            };
101            return Err(builtin_error(msg));
102        }
103        let mut outputs = Vec::with_capacity(out_count);
104        let first = eval.first().await?;
105        outputs.push(first);
106        if out_count >= 2 {
107            outputs.push(eval.second().await?);
108        }
109        if out_count >= 3 {
110            outputs.push(eval.third().await?);
111        }
112        return Ok(Value::OutputList(outputs));
113    }
114    eval.first().await
115}
116
117/// Evaluate the `meshgrid` builtin once and reuse the result for multiple outputs.
118pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
119    let parsed = ParsedMeshgrid::parse(args).await?;
120    let (x_axis, y_axis, z_axis) = normalise_axes(&parsed.axes);
121
122    let require_complex = parsed.axes.iter().any(|axis| axis.is_complex);
123
124    let target_class = match &parsed.template {
125        OutputTemplate::Default => {
126            if require_complex {
127                PrototypeClass::Complex
128            } else {
129                PrototypeClass::Real
130            }
131        }
132        OutputTemplate::Like(spec) => {
133            if require_complex {
134                PrototypeClass::Complex
135            } else {
136                spec.class
137            }
138        }
139    };
140
141    let target_residency = match &parsed.template {
142        OutputTemplate::Default => {
143            if parsed.prefer_gpu {
144                DevicePreference::Gpu
145            } else {
146                DevicePreference::Host
147            }
148        }
149        OutputTemplate::Like(spec) => spec.residency,
150    };
151
152    let axes_all_real = !require_complex;
153    let mut outputs: Vec<MeshgridOutput> = Vec::new();
154
155    if axes_all_real
156        && matches!(target_class, PrototypeClass::Real)
157        && matches!(target_residency, DevicePreference::Gpu)
158    {
159        if let Some(gpu) = try_meshgrid_gpu_from_vector_axes(&x_axis, &y_axis, z_axis.as_ref())? {
160            outputs = gpu;
161        }
162    }
163
164    if outputs.is_empty() {
165        // Host fallback: ensure we have host axis values materialized.
166        let x_host = axis_to_host_async(&x_axis).await?;
167        let y_host = axis_to_host_async(&y_axis).await?;
168        let z_host = match z_axis.as_ref() {
169            Some(axis) => Some(axis_to_host_async(axis).await?),
170            None => None,
171        };
172        outputs = build_outputs(&x_host, &y_host, z_host.as_ref())
173            .into_iter()
174            .map(MeshgridOutput::Host)
175            .collect();
176    }
177
178    Ok(MeshgridEval {
179        outputs,
180        target_class,
181        target_residency,
182    })
183}
184
/// Validated arguments of a `meshgrid` call, produced by `ParsedMeshgrid::parse`.
#[derive(Clone)]
struct ParsedMeshgrid {
    /// One to three axes in argument order (X, then Y, then Z).
    axes: Vec<AxisData>,
    /// Output class/residency request: the default, or one derived from a `'like'` prototype.
    template: OutputTemplate,
    /// Whether GPU residency is preferred when no `'like'` prototype is given. `parse` sets it
    /// when any input is a gpuArray, or when the sequence heuristic favours the device.
    prefer_gpu: bool,
}
191
192impl ParsedMeshgrid {
193    async fn parse(args: &[Value]) -> crate::BuiltinResult<Self> {
194        if args.is_empty() {
195            return Err(builtin_error(
196                "meshgrid: at least one input vector is required",
197            ));
198        }
199        let mut axis_values: Vec<Value> = Vec::new();
200        let mut like_proto: Option<Value> = None;
201        let mut prefer_gpu = false;
202        let mut idx = 0;
203        while idx < args.len() {
204            let value = args[idx].clone();
205            if let Some(keyword) = keyword_of(&value) {
206                match keyword.as_str() {
207                    "like" => {
208                        if like_proto.is_some() {
209                            return Err(builtin_error(
210                                "meshgrid: multiple 'like' specifications are not supported",
211                            ));
212                        }
213                        if axis_values.is_empty() {
214                            return Err(builtin_error(
215                                "meshgrid: 'like' must follow at least one input vector",
216                            ));
217                        }
218                        let Some(proto) = args.get(idx + 1).cloned() else {
219                            return Err(builtin_error("meshgrid: expected prototype after 'like'"));
220                        };
221                        like_proto = Some(proto);
222                        idx += 2;
223                        if idx < args.len() {
224                            return Err(builtin_error(
225                                "meshgrid: 'like' must be the final argument",
226                            ));
227                        }
228                        break;
229                    }
230                    other => {
231                        return Err(builtin_error(format!(
232                            "meshgrid: unrecognised option '{other}'"
233                        )));
234                    }
235                }
236            }
237
238            if let Value::GpuTensor(_) = value {
239                prefer_gpu = true;
240            }
241            axis_values.push(value);
242            idx += 1;
243        }
244
245        if axis_values.is_empty() {
246            return Err(builtin_error(
247                "meshgrid: at least one input vector is required",
248            ));
249        }
250        if axis_values.len() > 3 {
251            return Err(builtin_error(
252                "meshgrid: expected at most three input vectors",
253            ));
254        }
255
256        let mut axes = Vec::with_capacity(max(axis_values.len(), 2));
257        for (i, value) in axis_values.into_iter().enumerate() {
258            let mut consumed_gpu = false;
259            let data = axis_from_value(value, i, &mut consumed_gpu).await?;
260            if consumed_gpu {
261                prefer_gpu = true;
262            }
263            axes.push(data);
264        }
265
266        if !prefer_gpu {
267            if let Some(max_len) = axes.iter().map(|axis| axis.len).max() {
268                if max_len > 0
269                    && sequence_gpu_preference(max_len, SequenceIntent::MeshAxis, false).prefer_gpu
270                {
271                    prefer_gpu = true;
272                }
273            }
274        }
275
276        let template = if let Some(proto) = like_proto {
277            OutputTemplate::Like(analyse_like_prototype(&proto)?)
278        } else {
279            OutputTemplate::Default
280        };
281
282        Ok(Self {
283            axes,
284            template,
285            prefer_gpu,
286        })
287    }
288}
289
/// How the class and residency of the outputs are chosen.
#[derive(Clone)]
enum OutputTemplate {
    /// No prototype. Output is real unless an axis is complex; residency follows the parsed
    /// GPU preference.
    Default,
    /// `meshgrid(..., 'like', proto)`. Class and residency come from the prototype, but a
    /// complex axis still forces complex output.
    Like(PrototypeSpec),
}
295
/// Output characteristics derived from a `'like'` prototype by `analyse_like_prototype`.
#[derive(Clone)]
struct PrototypeSpec {
    /// `Gpu` for gpuArray prototypes, `Host` otherwise.
    residency: DevicePreference,
    /// `Complex` for complex prototypes, `Real` otherwise.
    class: PrototypeClass,
}
301
/// Numeric class of the produced grids.
#[derive(Clone, Copy, PartialEq, Eq)]
enum PrototypeClass {
    /// Real (double) arrays.
    Real,
    /// Complex arrays.
    Complex,
}
307
/// Where the produced grids should live.
#[derive(Clone, Copy)]
enum DevicePreference {
    /// Host memory (regular runtime tensors).
    Host,
    /// Device memory via the acceleration provider (gpuArray). Complex output is not supported
    /// on the device, and the conversion code falls back to the host for it.
    Gpu,
}
313
314fn analyse_like_prototype(proto: &Value) -> crate::BuiltinResult<PrototypeSpec> {
315    match proto {
316        Value::GpuTensor(_) => Ok(PrototypeSpec {
317            residency: DevicePreference::Gpu,
318            class: PrototypeClass::Real,
319        }),
320        Value::ComplexTensor(_) | Value::Complex(_, _) => Ok(PrototypeSpec {
321            residency: DevicePreference::Host,
322            class: PrototypeClass::Complex,
323        }),
324        Value::Tensor(_)
325        | Value::Num(_)
326        | Value::Int(_)
327        | Value::Bool(_)
328        | Value::LogicalArray(_) => Ok(PrototypeSpec {
329            residency: DevicePreference::Host,
330            class: PrototypeClass::Real,
331        }),
332        Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(builtin_error(
333            "meshgrid: prototypes must be numeric or gpuArray values",
334        )),
335        Value::Cell(_)
336        | Value::Struct(_)
337        | Value::Object(_)
338        | Value::HandleObject(_)
339        | Value::Listener(_)
340        | Value::FunctionHandle(_)
341        | Value::Closure(_)
342        | Value::ClassRef(_)
343        | Value::MException(_)
344        | Value::OutputList(_) => Err(builtin_error("meshgrid: prototypes must be numeric arrays")),
345    }
346}
347
/// One coordinate axis of the grid, either materialised on the host or still on the device.
#[derive(Clone)]
struct AxisData {
    /// Host `(re, im)` pairs. Left empty when the axis is still device-resident (`gpu_real`).
    values: Vec<(f64, f64)>,
    /// Number of elements on the axis. Equals `values.len()` for host axes.
    len: usize,
    /// Whether the axis must be treated as complex (decided by the constructor that built it).
    is_complex: bool,
    /// Handle of a real, vector-shaped gpuArray axis that has not been downloaded yet.
    gpu_real: Option<GpuTensorHandle>,
}
355
356async fn axis_from_value(
357    value: Value,
358    index: usize,
359    prefer_gpu: &mut bool,
360) -> crate::BuiltinResult<AxisData> {
361    match value {
362        Value::Tensor(tensor) => axis_from_tensor(tensor, index),
363        Value::LogicalArray(logical) => {
364            let tensor = tensor::logical_to_tensor(&logical)?;
365            axis_from_tensor(tensor, index)
366        }
367        Value::Num(n) => Ok(AxisData {
368            values: vec![(n, 0.0)],
369            len: 1,
370            is_complex: false,
371            gpu_real: None,
372        }),
373        Value::Int(i) => {
374            let val = i.to_f64();
375            Ok(AxisData {
376                values: vec![(val, 0.0)],
377                len: 1,
378                is_complex: false,
379                gpu_real: None,
380            })
381        }
382        Value::Bool(b) => Ok(AxisData {
383            values: vec![(if b { 1.0 } else { 0.0 }, 0.0)],
384            len: 1,
385            is_complex: false,
386            gpu_real: None,
387        }),
388        Value::Complex(re, im) => Ok(AxisData {
389            values: vec![(re, im)],
390            len: 1,
391            is_complex: im != 0.0,
392            gpu_real: None,
393        }),
394        Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, index),
395        Value::GpuTensor(handle) => {
396            // Fast path: if the gpuArray is vector-like, keep it on-device and avoid a download.
397            // We'll validate any non-vector shapes by gathering below.
398            if is_vector_shape(&handle.shape) {
399                *prefer_gpu = true;
400                return Ok(AxisData {
401                    values: Vec::new(),
402                    len: vector_len_from_shape(&handle.shape),
403                    is_complex: false,
404                    gpu_real: Some(handle),
405                });
406            }
407
408            // Fallback: gather to validate / recover axes from meshgrid matrices.
409            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
410            if is_vector_shape(&tensor.shape) {
411                *prefer_gpu = true;
412            }
413            axis_from_tensor(tensor, index)
414        }
415        other => Err(builtin_error(format!(
416            "meshgrid: input argument {} must be numeric, got {other:?}",
417            index + 1
418        ))),
419    }
420}
421
422fn axis_from_tensor(tensor: Tensor, index: usize) -> crate::BuiltinResult<AxisData> {
423    if is_vector_shape(&tensor.shape) {
424        let mut values = Vec::with_capacity(tensor.data.len());
425        for &v in &tensor.data {
426            values.push((v, 0.0));
427        }
428        return Ok(AxisData {
429            len: values.len(),
430            values,
431            is_complex: false,
432            gpu_real: None,
433        });
434    }
435
436    // Be slightly more permissive than MATLAB: if the input is already a meshgrid-style
437    // coordinate matrix, accept it and recover the original axis vector.
438    //
439    // This is a pragmatic compatibility shim for cases where callers already have
440    // coordinate matrices (X/Y) and pass them through `meshgrid` again.
441    if let Some(axis) = axis_from_meshgrid_matrix_real(&tensor, index)? {
442        return Ok(axis);
443    }
444
445    Err(builtin_error(format!(
446        "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
447        index + 1,
448        tensor.shape
449    )))
450}
451
452fn axis_from_complex_tensor(tensor: ComplexTensor, index: usize) -> crate::BuiltinResult<AxisData> {
453    if is_vector_shape(&tensor.shape) {
454        let is_complex = tensor
455            .data
456            .iter()
457            .any(|&(_, imag)| !imag.is_nan() && imag != 0.0);
458        return Ok(AxisData {
459            len: tensor.data.len(),
460            values: tensor.data,
461            is_complex,
462            gpu_real: None,
463        });
464    }
465
466    if let Some(axis) = axis_from_meshgrid_matrix_complex(&tensor, index)? {
467        return Ok(axis);
468    }
469
470    Err(builtin_error(format!(
471        "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
472        index + 1,
473        tensor.shape
474    )))
475}
476
477fn axis_from_meshgrid_matrix_real(
478    tensor: &Tensor,
479    index: usize,
480) -> crate::BuiltinResult<Option<AxisData>> {
481    let (rows, cols) = match tensor.shape.as_slice() {
482        [r, c] => (*r, *c),
483        _ => return Ok(None),
484    };
485    if rows <= 1 || cols <= 1 {
486        return Ok(None);
487    }
488
489    // Index 0 is expected to be the X-axis: a meshgrid X matrix has identical rows.
490    // Index 1 is expected to be the Y-axis: a meshgrid Y matrix has identical columns.
491    let expect_rows_constant = index == 0;
492
493    if expect_rows_constant {
494        if !matrix_rows_are_identical_real(tensor, rows, cols) {
495            return Ok(None);
496        }
497        // Extract the first row as the axis vector (length = cols).
498        let mut values = Vec::with_capacity(cols);
499        for col in 0..cols {
500            let idx = rows * col;
501            values.push((tensor.data[idx], 0.0));
502        }
503        return Ok(Some(AxisData {
504            len: values.len(),
505            values,
506            is_complex: false,
507            gpu_real: None,
508        }));
509    }
510
511    if !matrix_cols_are_identical_real(tensor, rows, cols) {
512        return Ok(None);
513    }
514    // Extract the first column as the axis vector (length = rows).
515    let mut values = Vec::with_capacity(rows);
516    for row in 0..rows {
517        values.push((tensor.data[row], 0.0));
518    }
519    Ok(Some(AxisData {
520        len: values.len(),
521        values,
522        is_complex: false,
523        gpu_real: None,
524    }))
525}
526
527fn axis_from_meshgrid_matrix_complex(
528    tensor: &ComplexTensor,
529    index: usize,
530) -> crate::BuiltinResult<Option<AxisData>> {
531    let (rows, cols) = match tensor.shape.as_slice() {
532        [r, c] => (*r, *c),
533        _ => return Ok(None),
534    };
535    if rows <= 1 || cols <= 1 {
536        return Ok(None);
537    }
538
539    let expect_rows_constant = index == 0;
540    if expect_rows_constant {
541        if !matrix_rows_are_identical_complex(tensor, rows, cols) {
542            return Ok(None);
543        }
544        let mut values = Vec::with_capacity(cols);
545        for col in 0..cols {
546            let idx = rows * col;
547            values.push(tensor.data[idx]);
548        }
549        let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
550        return Ok(Some(AxisData {
551            len: values.len(),
552            values,
553            is_complex,
554            gpu_real: None,
555        }));
556    }
557
558    if !matrix_cols_are_identical_complex(tensor, rows, cols) {
559        return Ok(None);
560    }
561    let mut values = Vec::with_capacity(rows);
562    for row in 0..rows {
563        values.push(tensor.data[row]);
564    }
565    let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
566    Ok(Some(AxisData {
567        len: values.len(),
568        values,
569        is_complex,
570        gpu_real: None,
571    }))
572}
573
574fn matrix_rows_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
575    for row in 1..rows {
576        for col in 0..cols {
577            let idx0 = rows * col;
578            let idx = row + rows * col;
579            if tensor.data[idx] != tensor.data[idx0] {
580                return false;
581            }
582        }
583    }
584    true
585}
586
587fn matrix_cols_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
588    for col in 1..cols {
589        for row in 0..rows {
590            let idx0 = row;
591            let idx = row + rows * col;
592            if tensor.data[idx] != tensor.data[idx0] {
593                return false;
594            }
595        }
596    }
597    true
598}
599
600fn matrix_rows_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
601    for row in 1..rows {
602        for col in 0..cols {
603            let idx0 = rows * col;
604            let idx = row + rows * col;
605            if tensor.data[idx] != tensor.data[idx0] {
606                return false;
607            }
608        }
609    }
610    true
611}
612
613fn matrix_cols_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
614    for col in 1..cols {
615        for row in 0..rows {
616            let idx0 = row;
617            let idx = row + rows * col;
618            if tensor.data[idx] != tensor.data[idx0] {
619                return false;
620            }
621        }
622    }
623    true
624}
625
/// True when `shape` has at most one dimension larger than 1.
///
/// Scalars (`[]`, `[1, 1]`), row/column vectors, and empty shapes such as `[0, 0]` all count
/// as vectors.
fn is_vector_shape(shape: &[usize]) -> bool {
    shape.iter().filter(|&&dim| dim > 1).count() <= 1
}
638
/// Number of elements in a vector-shaped array (see `is_vector_shape`).
///
/// For vector shapes the element count is the product of the dimensions:
/// - a scalar shape `[]` gives 1;
/// - `[1, n]` and `[n, 1]` give `n`;
/// - any shape containing a zero gives 0.
///
/// Taking the maximum dimension, as before, reported 1 for `[1, 0]` and 5 for `[0, 5]`. Those
/// wrong lengths were then used to reshape empty gpuArray axes.
fn vector_len_from_shape(shape: &[usize]) -> usize {
    // `product` of an empty iterator is 1, which covers the scalar `[]` case.
    shape.iter().product()
}
645
646async fn axis_to_host_async(axis: &AxisData) -> crate::BuiltinResult<AxisData> {
647    if axis.gpu_real.is_none() {
648        return Ok(axis.clone());
649    }
650    let handle = axis.gpu_real.as_ref().expect("checked gpu_real is_some");
651    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
652    // Index is only used for error messages; tensor came from a validated vector-like handle.
653    axis_from_tensor(tensor, 0)
654}
655
/// Attempt to build the coordinate grids entirely on the device.
///
/// Returns `Ok(None)`, so the caller falls back to the host path, when:
/// - any supplied axis is not a device-resident real vector (`gpu_real` is `None`), or
/// - no acceleration provider is found for one of the handles.
///
/// On success it returns the X, Y and (when `z_axis` is given) Z grids as
/// `MeshgridOutput::GpuReal` handles, built with provider `reshape` and `repmat`.
///
/// # Errors
/// A failing provider `reshape`/`repmat` call is returned as a `meshgrid:` error rather than
/// falling back to the host path.
fn try_meshgrid_gpu_from_vector_axes(
    x_axis: &AxisData,
    y_axis: &AxisData,
    z_axis: Option<&AxisData>,
) -> crate::BuiltinResult<Option<Vec<MeshgridOutput>>> {
    // Every axis must already live on the device as a real vector.
    let Some(x_handle) = x_axis.gpu_real.as_ref() else {
        return Ok(None);
    };
    let Some(y_handle) = y_axis.gpu_real.as_ref() else {
        return Ok(None);
    };

    let z_handle = match z_axis {
        Some(axis) => match axis.gpu_real.as_ref() {
            Some(h) => Some(h),
            None => return Ok(None),
        },
        None => None,
    };

    // All work is issued through the provider owning the X handle; the Y/Z handles are only
    // checked for having *some* provider.
    // NOTE(review): confirm mixed providers across axes cannot occur here.
    let Some(provider) = runmat_accelerate_api::provider_for_handle(x_handle) else {
        return Ok(None);
    };
    if runmat_accelerate_api::provider_for_handle(y_handle).is_none() {
        return Ok(None);
    }
    if let Some(z) = z_handle {
        if runmat_accelerate_api::provider_for_handle(z).is_none() {
            return Ok(None);
        }
    }

    let nx = x_axis.len;
    let ny = y_axis.len;
    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);

    // Reshape axis vectors (metadata-only) so repmat can build full grids on-device.
    let x_row = provider
        .reshape(x_handle, &[1, nx])
        .map_err(|e| builtin_error(format!("meshgrid: reshape X failed: {e}")))?;
    let y_col = provider
        .reshape(y_handle, &[ny, 1])
        .map_err(|e| builtin_error(format!("meshgrid: reshape Y failed: {e}")))?;

    let mut outputs = Vec::with_capacity(if z_handle.is_some() { 3 } else { 2 });
    if let Some(z) = z_handle {
        // 3-D: X varies along dim 2, Y along dim 1, Z along dim 3; each base is tiled over
        // the other two dimensions.
        // NOTE(review): with a length-1 Z axis the host path returns a 2-D [ny, nx] grid,
        // while this path requests [ny, nx, 1]; confirm the provider drops the trailing
        // singleton. Also confirm the intermediate x_row/y_col handles need no explicit release.
        let x_base = provider
            .reshape(&x_row, &[1, nx, 1])
            .map_err(|e| builtin_error(format!("meshgrid: reshape X(3d) failed: {e}")))?;
        let y_base = provider
            .reshape(&y_col, &[ny, 1, 1])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Y(3d) failed: {e}")))?;

        let x_grid = provider
            .repmat(&x_base, &[ny, 1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
        let y_grid = provider
            .repmat(&y_base, &[1, nx, nz])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;

        outputs.push(MeshgridOutput::GpuReal(x_grid));
        outputs.push(MeshgridOutput::GpuReal(y_grid));
        let z_axis_row = provider
            .reshape(z, &[1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Z failed: {e}")))?;
        let z_base = provider
            .reshape(&z_axis_row, &[1, 1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Z(3d) failed: {e}")))?;
        let z_grid = provider
            .repmat(&z_base, &[ny, nx, 1])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Z failed: {e}")))?;
        outputs.push(MeshgridOutput::GpuReal(z_grid));
    } else {
        // 2-D: tile the X row down ny rows and the Y column across nx columns.
        let x_grid = provider
            .repmat(&x_row, &[ny, 1])
            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
        let y_grid = provider
            .repmat(&y_col, &[1, nx])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
        outputs.push(MeshgridOutput::GpuReal(x_grid));
        outputs.push(MeshgridOutput::GpuReal(y_grid));
    }

    Ok(Some(outputs))
}
741
742fn normalise_axes(axes: &[AxisData]) -> (AxisData, AxisData, Option<AxisData>) {
743    match axes.len() {
744        1 => {
745            let x = axes[0].clone();
746            (x.clone(), x, None)
747        }
748        2 => {
749            let x = axes[0].clone();
750            let y = axes[1].clone();
751            (x, y, None)
752        }
753        3 => {
754            let x = axes[0].clone();
755            let y = axes[1].clone();
756            let z = axes[2].clone();
757            (x, y, Some(z))
758        }
759        _ => unreachable!(),
760    }
761}
762
763fn build_outputs(
764    x_axis: &AxisData,
765    y_axis: &AxisData,
766    z_axis: Option<&AxisData>,
767) -> Vec<GridOutput> {
768    let nx = x_axis.len;
769    let ny = y_axis.len;
770    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
771    let total = nx * ny * nz;
772    let mut x_data = Vec::with_capacity(total);
773    let mut y_data = Vec::with_capacity(total);
774    let mut z_data = z_axis.map(|_| Vec::with_capacity(total));
775
776    for k in 0..nz {
777        let z_value = z_axis.map(|axis| axis.values[k]);
778        for col in 0..nx {
779            let x_value = x_axis.values[col];
780            for row in 0..ny {
781                x_data.push(x_value);
782                y_data.push(y_axis.values[row]);
783                if let Some(ref mut z_vec) = z_data {
784                    z_vec.push(z_value.unwrap());
785                }
786            }
787        }
788    }
789
790    let mut outputs = Vec::new();
791    let base_shape = if nz == 1 {
792        vec![ny, nx]
793    } else {
794        vec![ny, nx, nz]
795    };
796    outputs.push(GridOutput {
797        shape: base_shape.clone(),
798        data: x_data,
799    });
800    outputs.push(GridOutput {
801        shape: base_shape.clone(),
802        data: y_data,
803    });
804    if let Some(z_vec) = z_data {
805        outputs.push(GridOutput {
806            shape: base_shape,
807            data: z_vec,
808        });
809    }
810    outputs
811}
812
/// A host-side coordinate grid, stored column-major as `(re, im)` pairs.
struct GridOutput {
    /// Grid dimensions: `[ny, nx]`, or `[ny, nx, nz]` for 3-D grids with `nz != 1`.
    shape: Vec<usize>,
    /// Element values; the imaginary part is 0.0 for values that came from real axes.
    data: Vec<(f64, f64)>,
}
817
818impl GridOutput {
819    fn to_value(
820        &self,
821        class: PrototypeClass,
822        residency: DevicePreference,
823    ) -> crate::BuiltinResult<Value> {
824        match class {
825            PrototypeClass::Real => self.to_real_value(residency),
826            PrototypeClass::Complex => self.to_complex_value(residency),
827        }
828    }
829
830    fn to_real_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
831        let mut real = Vec::with_capacity(self.data.len());
832        for &(re, im) in &self.data {
833            if im != 0.0 {
834                return Err(builtin_error(
835                    "meshgrid: cannot represent complex values in a real output",
836                ));
837            }
838            real.push(re);
839        }
840        let tensor = Tensor::new(real, self.shape.clone())
841            .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
842        match residency {
843            DevicePreference::Host => Ok(tensor::tensor_into_value(tensor)),
844            DevicePreference::Gpu => to_gpu_tensor_value(tensor),
845        }
846    }
847
848    fn to_complex_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
849        let tensor = ComplexTensor::new(self.data.clone(), self.shape.clone())
850            .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
851        match residency {
852            DevicePreference::Host => Ok(complex_tensor_into_value(tensor)),
853            DevicePreference::Gpu => {
854                warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
855                Ok(complex_tensor_into_value(tensor))
856            }
857        }
858    }
859}
860
861fn to_gpu_tensor_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
862    if let Some(provider) = runmat_accelerate_api::provider() {
863        let view = HostTensorView {
864            data: &tensor.data,
865            shape: &tensor.shape,
866        };
867        match provider.upload(&view) {
868            Ok(handle) => return Ok(Value::GpuTensor(handle)),
869            Err(err) => {
870                warn!("meshgrid: failed to upload tensor to GPU, returning host array: {err}")
871            }
872        }
873    }
874    Ok(tensor::tensor_into_value(tensor))
875}
876
877fn tensor_to_complex_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
878    let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
879    let complex = ComplexTensor::new(data, tensor.shape.clone())
880        .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
881    Ok(complex_tensor_into_value(complex))
882}
883
/// One grid produced by `evaluate`, either already on the host or still on the device.
enum MeshgridOutput {
    /// Grid built on the host by `build_outputs`.
    Host(GridOutput),
    /// Real grid built on the device by `try_meshgrid_gpu_from_vector_axes`.
    GpuReal(GpuTensorHandle),
}
888
889impl MeshgridOutput {
890    async fn to_value(
891        &self,
892        class: PrototypeClass,
893        residency: DevicePreference,
894    ) -> crate::BuiltinResult<Value> {
895        match self {
896            MeshgridOutput::Host(host) => host.to_value(class, residency),
897            MeshgridOutput::GpuReal(handle) => match (class, residency) {
898                (PrototypeClass::Real, DevicePreference::Gpu) => {
899                    Ok(Value::GpuTensor(handle.clone()))
900                }
901                (PrototypeClass::Real, DevicePreference::Host) => {
902                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
903                    Ok(tensor::tensor_into_value(tensor))
904                }
905                (PrototypeClass::Complex, DevicePreference::Host) => {
906                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
907                    tensor_to_complex_value(tensor)
908                }
909                (PrototypeClass::Complex, DevicePreference::Gpu) => {
910                    warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
911                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
912                    tensor_to_complex_value(tensor)
913                }
914            },
915        }
916    }
917}
918
/// Holds the results of a `meshgrid` evaluation so multiple outputs can be
/// materialised without recomputing the grid.
///
/// Each accessor converts the requested output on demand, using the shared
/// target class and residency.
pub struct MeshgridEval {
    // Computed grids in X, Y[, Z] order; index 0 is the first output (X).
    outputs: Vec<MeshgridOutput>,
    // Numeric class (real or complex) every output is converted to.
    target_class: PrototypeClass,
    // Device (host or GPU) every output is delivered on.
    target_residency: DevicePreference,
}
926
927impl MeshgridEval {
928    pub fn output_count(&self) -> usize {
929        self.outputs.len()
930    }
931
932    pub async fn first(&self) -> crate::BuiltinResult<Value> {
933        self.outputs[0]
934            .to_value(self.target_class, self.target_residency)
935            .await
936    }
937
938    pub async fn second(&self) -> crate::BuiltinResult<Value> {
939        if self.outputs.len() < 2 {
940            Err(builtin_error("meshgrid: second output unavailable"))
941        } else {
942            self.outputs[1]
943                .to_value(self.target_class, self.target_residency)
944                .await
945        }
946    }
947
948    pub async fn third(&self) -> crate::BuiltinResult<Value> {
949        if self.outputs.len() < 3 {
950            Err(builtin_error(
951                "meshgrid: third output requested but no Z vector was supplied",
952            ))
953        } else {
954            self.outputs[2]
955                .to_value(self.target_class, self.target_residency)
956                .await
957        }
958    }
959}
960
/// Unit tests for `meshgrid`: grid values and shapes, static type inference,
/// `'like'` prototype handling, GPU residency, and complex inputs.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    // Brings the `upload` trait method into scope for the wgpu-only test.
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;

    use runmat_accelerate_api::HostTensorView;

    /// Synchronous wrapper: drives the async `super::evaluate` with `block_on`.
    fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
        block_on(super::evaluate(args))
    }

    /// Blocks on `MeshgridEval::first` (the X output).
    fn eval_first(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.first())
    }

    /// Blocks on `MeshgridEval::second` (the Y output).
    fn eval_second(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.second())
    }

    /// Blocks on `MeshgridEval::third` (the Z output).
    fn eval_third(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.third())
    }

    /// Builds a `rows x cols` tensor from `data`; panics if the sizes disagree.
    fn tensor_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Tensor {
        Tensor::new(data, vec![rows, cols]).unwrap()
    }

    // A single vector serves as both axes, giving a square 3x3 grid. The
    // expected data is listed column by column: each column of X is constant.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_single_input_duplicates_axis() {
        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
        let eval = evaluate(&[Value::Tensor(x)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![3, 3]);
        assert_eq!(
            x_out.data,
            vec![-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
        );
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![3, 3]);
        assert_eq!(
            y_out.data,
            vec![-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0]
        );
    }

    // Scalar axes produce a 1x1 type for two axes and 1x1x1 for three.
    #[test]
    fn meshgrid_type_infers_rank_from_axis_count() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            meshgrid_type(&[Type::Num, Type::Num], &ctx),
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1)])
            }
        );
        assert_eq!(
            meshgrid_type(&[Type::Num, Type::Num, Type::Num], &ctx),
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1), Some(1)])
            }
        );
    }

    // The inferred shape is [len(y), len(x)]: rows come from the second axis,
    // columns from the first.
    #[test]
    fn meshgrid_type_uses_vector_lengths() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            meshgrid_type(
                &[
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(201)]),
                    },
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(101)]),
                    },
                ],
                &ctx,
            ),
            Type::Tensor {
                shape: Some(vec![Some(101), Some(201)])
            }
        );
    }

    // Three x values and two y values give 2x3 grids.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_rectangular_inputs() {
        let x = tensor_from_vec(vec![0.0, 0.5, 1.0], 1, 3);
        let y = tensor_from_vec(vec![10.0, 20.0], 2, 1);
        let eval = evaluate(&[Value::Tensor(x), Value::Tensor(y)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![2, 3]);
        assert_eq!(x_out.data, vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0]);
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![2, 3]);
        assert_eq!(y_out.data, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
    }

    // Three axis vectors give three outputs of shape [len(y), len(x), len(z)].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_three_inputs_volume() {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let y = tensor_from_vec(vec![5.0, 6.0, 7.0], 3, 1);
        let z = tensor_from_vec(vec![0.0, 1.0], 1, 2);
        let eval =
            evaluate(&[Value::Tensor(x), Value::Tensor(y), Value::Tensor(z)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 3);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![3, 2, 2]);
        assert_eq!(
            x_out.data,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
        );
        let z_out = test_support::gather(eval_third(&eval).expect("Z")).expect("host");
        assert_eq!(z_out.shape, vec![3, 2, 2]);
        assert_eq!(
            z_out.data,
            vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
        );
    }

    // Host inputs with a GPU-resident `'like'` prototype should yield a
    // GPU-resident output. `with_test_provider` presumably installs a test
    // acceleration provider for the closure's duration (confirm in test_support).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_keeps_gpu_residency() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
            let y = tensor_from_vec(vec![2.0, 4.0], 2, 1);
            let proto = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let proto_view = HostTensorView {
                data: &proto.data,
                shape: &proto.shape,
            };
            let proto_handle = provider.upload(&proto_view).expect("upload");
            let eval = evaluate(&[
                Value::Tensor(x),
                Value::Tensor(y),
                Value::from("like"),
                Value::GpuTensor(proto_handle),
            ])
            .expect("meshgrid");
            let x_value = eval_first(&eval).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let gathered = test_support::gather(x_value).expect("gather");
            assert_eq!(gathered.shape, vec![2, 3]);
        });
    }

    // GPU-resident axis vectors should produce GPU-resident X and Y outputs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_gpu_inputs_roundtrip() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![0.0, 0.5], 1, 2);
            let y = tensor_from_vec(vec![1.0, 2.0], 2, 1);
            let x_view = HostTensorView {
                data: &x.data,
                shape: &x.shape,
            };
            let y_view = HostTensorView {
                data: &y.data,
                shape: &y.shape,
            };
            let x_handle = provider.upload(&x_view).expect("upload");
            let y_handle = provider.upload(&y_view).expect("upload");
            let eval = evaluate(&[Value::GpuTensor(x_handle), Value::GpuTensor(y_handle)])
                .expect("meshgrid");
            assert!(matches!(eval_first(&eval).expect("X"), Value::GpuTensor(_)));
            assert!(matches!(
                eval_second(&eval).expect("Y"),
                Value::GpuTensor(_)
            ));
        });
    }

    // With the real wgpu backend, the GPU path must match the host computation
    // exactly in both shape and data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn meshgrid_wgpu_matches_cpu() {
        let provider = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .expect("wgpu provider");

        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0, 2.0], 1, 4);
        let y = tensor_from_vec(vec![5.0, 6.0], 2, 1);

        // Host reference results.
        let cpu_eval =
            evaluate(&[Value::Tensor(x.clone()), Value::Tensor(y.clone())]).expect("meshgrid cpu");
        let cpu_x =
            test_support::gather(eval_first(&cpu_eval).expect("X cpu")).expect("gather X cpu");
        let cpu_y =
            test_support::gather(eval_second(&cpu_eval).expect("Y cpu")).expect("gather Y cpu");

        let x_view = HostTensorView {
            data: &x.data,
            shape: &x.shape,
        };
        let y_view = HostTensorView {
            data: &y.data,
            shape: &y.shape,
        };
        let x_gpu = provider.upload(&x_view).expect("upload x");
        let y_gpu = provider.upload(&y_view).expect("upload y");

        // Same inputs, but GPU-resident.
        let gpu_eval =
            evaluate(&[Value::GpuTensor(x_gpu), Value::GpuTensor(y_gpu)]).expect("meshgrid gpu");
        let gpu_x_value = eval_first(&gpu_eval).expect("X gpu");
        let gpu_y_value = eval_second(&gpu_eval).expect("Y gpu");

        assert!(matches!(gpu_x_value, Value::GpuTensor(_)));
        assert!(matches!(gpu_y_value, Value::GpuTensor(_)));

        let gathered_x = test_support::gather(gpu_x_value).expect("gather X gpu");
        let gathered_y = test_support::gather(gpu_y_value).expect("gather Y gpu");

        assert_eq!(gathered_x.shape, cpu_x.shape);
        assert_eq!(gathered_x.data, cpu_x.data);
        assert_eq!(gathered_y.shape, cpu_y.shape);
        assert_eq!(gathered_y.data, cpu_y.data);
    }

    // Complex input must give complex output. A complex scalar value is also
    // accepted as a valid result variant.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_complex_inputs_produce_complex_outputs() {
        let complex = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
        let eval = evaluate(&[Value::ComplexTensor(complex)]).expect("meshgrid");
        let x_value = eval_first(&eval).expect("X");
        match x_value {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 2]);
            }
            Value::Complex(_, _) => {}
            other => panic!("expected complex output, got {other:?}"),
        }
    }

    // A host numeric `'like'` prototype keeps the output as a host real value.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_host_prototype() {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let eval =
            evaluate(&[Value::Tensor(x), Value::from("like"), Value::Num(0.0)]).expect("meshgrid");
        let x_out = eval_first(&eval).expect("X");
        assert!(matches!(x_out, Value::Tensor(_) | Value::Num(_)));
    }
}