runmat_runtime/builtins/array/creation/
meshgrid.rs

1//! MATLAB-compatible `meshgrid` builtin with GPU-aware semantics.
2
3use std::cmp::max;
4
5use log::warn;
6use runmat_accelerate_api::{GpuTensorHandle, HostTensorView, MeshgridAxisView};
7use runmat_builtins::{ComplexTensor, Tensor, Value};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::gpu_helpers;
11use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
12use crate::builtins::common::residency::{sequence_gpu_preference, SequenceIntent};
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18#[cfg(feature = "doc_export")]
19use crate::register_builtin_doc_text;
20use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
21
22#[cfg(feature = "doc_export")]
23pub const DOC_MD: &str = r#"---
24title: "meshgrid"
25category: "array/creation"
26keywords: ["meshgrid", "grid", "surface", "gpu", "like", "3d"]
27summary: "Generate MATLAB mesh grids for 2-D and 3-D coordinate vectors with optional GPU residency."
28references: []
29gpu_support:
30  elementwise: false
31  reduction: false
32  precisions: ["f32", "f64"]
33  broadcasting: "matlab"
34  notes: "RunMat materialises grids on the host and uploads them when GPU residency is requested. Providers may supply dedicated meshgrid hooks to avoid the round-trip."
35fusion:
36  elementwise: false
37  reduction: false
38  max_inputs: 3
39  constants: "inline"
40requires_feature: null
41tested:
42  unit: "builtins::array::creation::meshgrid::tests"
43  integration: "builtins::array::creation::meshgrid::tests::meshgrid_gpu_inputs_roundtrip"
44---
45
46# What does the `meshgrid` function do in MATLAB / RunMat?
`meshgrid` turns one, two, or three coordinate vectors into coordinate arrays that
48span a rectangular grid, mirroring MathWorks MATLAB behaviour exactly. The first output
49replicates the `x` vector across rows, the second replicates the `y` vector across columns,
50and the optional third output expands a `z` vector across pages for volumetric grids.
51
52## How does the `meshgrid` function behave in MATLAB / RunMat?
53- `meshgrid(x)` is shorthand for `[X, Y] = meshgrid(x, x)`. It produces square 2-D grids.
54- `meshgrid(x, y)` yields `X` of size `length(y) × length(x)` with rows copied from `x`,
55  and `Y` of the same size with columns copied from `y`.
56- `meshgrid(x, y, z)` returns three outputs sized `length(y) × length(x) × length(z)`,
57  enabling 3-D volume visualisation.
58- Input vectors may be row or column vectors (or even scalars). Empty vectors propagate to
59  empty grids of matching shape.
60- Complex inputs produce complex grids where each output shares the input’s complex values.
61- Supplying GPU vectors (or a `'like', gpuArray(...)` prototype) keeps the outputs on the GPU
62  when an acceleration provider is active. Without provider support, RunMat gathers inputs,
63  materialises the grid on the host, and uploads the result transparently.
64- `'like', prototype` matches both the residency (host or GPU) and numeric class (real or complex)
65  of the prototype. Integer prototypes are promoted to double precision, consistent with MATLAB.
66
67## `meshgrid` Function GPU Execution Behaviour
68- When the active acceleration provider implements the custom `meshgrid` hook, RunMat allocates
69  every coordinate tensor directly on the device so large grids avoid host round-trips entirely.
70- If the hook is missing (or errors), RunMat gathers the 1-D axes, materialises the grids once on
71  the host, and uploads the outputs whenever GPU residency is requested, preserving observable
72  semantics.
- Complex-valued grids always materialise on the host today; when GPU residency is requested the
  runtime logs a warning and returns host complex tensors so callers still receive correct
  MATLAB-compatible results.
76
77## Examples of using the `meshgrid` function in MATLAB / RunMat
78
79### Generating a square 2-D grid from one vector
80
81```matlab
82x = -2:2;
83[X, Y] = meshgrid(x);
84```
85
86Expected output (`X` shown; `Y` mirrors the row/column relationship):
87
88```matlab
89X =
90    -2    -1     0     1     2
91    -2    -1     0     1     2
92    -2    -1     0     1     2
93    -2    -1     0     1     2
94    -2    -1     0     1     2
95```
96
97### Building a rectangular grid from two different vectors
98
99```matlab
100x = [0 0.5 1.0];
101y = [10 20];
102[X, Y] = meshgrid(x, y);
103```
104
105Expected output:
106
107```matlab
108X =
109         0    0.5000    1.0000
110         0    0.5000    1.0000
111
112Y =
113    10    10    10
114    20    20    20
115```
116
117### Creating a volumetric grid for 3-D plotting
118
119```matlab
120u = -1:1;
121v = 2:4;
122w = linspace(0, 1, 5);
123[U, V, W] = meshgrid(u, v, w);
124```
125
126Expected output shapes:
127
128```matlab
129size(U) == [3 3 5]
130size(V) == [3 3 5]
131size(W) == [3 3 5]
132```
133
134### Matching an existing GPU prototype
135
136```matlab
137gx = gpuArray(single(linspace(-pi, pi, 4)));
138gy = gpuArray(single([-1 0 1]));
139[Xg, Yg] = meshgrid(gx, gy);
140```
141
142`Xg` and `Yg` remain `gpuArray` values with `single` precision. Gathering them produces the same
143numeric data as the host result.
144
145### Using `'like'` to copy residency from another array
146
147```matlab
148proto = gpuArray.zeros(1, 1, 'double');
149angles = linspace(0, 2*pi, 8);
150radius = [0 1 2];
151[X, Y] = meshgrid(angles, radius, 'like', proto);
152```
153
154Both `X` and `Y` stay on the GPU because the prototype is a `gpuArray`.
155
156### Complex inputs produce complex grids automatically
157
158```matlab
159z = [1+1i, 2+4i];
160[Zx, Zy] = meshgrid(z);
161```
162
163`Zx` and `Zy` are complex arrays whose imaginary parts match the source vector.
164
165## GPU residency in RunMat (Do I need `gpuArray`?)
166
167You usually do **not** need to wrap vectors with `gpuArray` manually. When the active acceleration
168provider supports uploads, RunMat automatically constructs the grid on the host and keeps the
169outputs on the GPU. Supplying a `'like', gpuArray(...)` prototype produces GPU outputs even when
170all inputs are host arrays. Until native provider hooks land, complex-valued grids remain host-side
171and emit a warning when GPU residency is requested.
172
173## FAQ
174
175### How many inputs can `meshgrid` accept?
176
177One, two, or three numeric vectors. Use three inputs when you need volumetric (3-D) grids.
178
179### Can I request three outputs with only one or two inputs?
180
181No. RunMat follows MATLAB and requires three input vectors when three outputs are requested.
182
183### Do row or column vectors behave differently?
184
185No. Any vector shape (row, column, or scalar) is accepted. RunMat treats the linearised data
186identically and replicates it along the appropriate axes.
187
188### What happens with empty vectors?
189
Empty inputs propagate to empty outputs. For example, `meshgrid([], 1:3)` returns `3×0`
grids (`length(y) × length(x)`) for both outputs.
192
193### Can I use integer vectors?
194
195Yes. Inputs are promoted to double precision internally so the outputs represent the exact same
196values as MATLAB.
197
198### Does `meshgrid` support complex numbers?
199
200Absolutely. Any imaginary components propagate into the outputs. Complex grids currently stay
201on the host even if GPU residency is requested.
202
203### What does `'like'` do?
204
205It matches the numeric class and residency (host or GPU) of the prototype array. Supply a
206`gpuArray` prototype to keep the resulting grids on the GPU.
207
208### How can providers avoid the host fall-back?
209
210Implement the `meshgrid` custom hook in the acceleration provider. RunMat will automatically
211dispatch to it once available.
212
213### Is the output always dense?
214
215Yes. `meshgrid` produces dense arrays. Use `ndgrid` when you need permuted axes or higher-dimensional
216grids beyond three inputs.
217
218### What error do I get if I omit all inputs?
219
220RunMat raises the MATLAB-compatible error `meshgrid: at least one input vector is required`.
221
222## See Also
223[linspace](./linspace), [zeros](./zeros), [ones](./ones), [gpuArray](../../acceleration/gpu/gpuArray), [gather](../../acceleration/gpu/gather)
224
225## Source & Feedback
226- The full source code for the implementation of the `meshgrid` function is available at: [`crates/runmat-runtime/src/builtins/array/creation/meshgrid.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/array/creation/meshgrid.rs)
227- Found a bug or behavioural difference? Please [open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
228"#;
229
230pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
231    name: "meshgrid",
232    op_kind: GpuOpKind::Custom("array_construct"),
233    supported_precisions: &[ScalarType::F32, ScalarType::F64],
234    broadcast: BroadcastSemantics::Matlab,
235    provider_hooks: &[ProviderHook::Custom("meshgrid")],
236    constant_strategy: ConstantStrategy::InlineLiteral,
237    residency: ResidencyPolicy::NewHandle,
238    nan_mode: ReductionNaN::Include,
239    two_pass_threshold: None,
240    workgroup_size: None,
241    accepts_nan_mode: false,
242    notes: "Providers may supply a dedicated meshgrid hook; until then the runtime builds grids on the host and uploads them when GPU residency is requested.",
243};
244
245register_builtin_gpu_spec!(GPU_SPEC);
246
247pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
248    name: "meshgrid",
249    shape: ShapeRequirements::Any,
250    constant_strategy: ConstantStrategy::InlineLiteral,
251    elementwise: None,
252    reduction: None,
253    emits_nan: false,
254    notes:
255        "Meshgrid explicitly materialises dense coordinate arrays and therefore bypasses fusion.",
256};
257
258register_builtin_fusion_spec!(FUSION_SPEC);
259
260#[cfg(feature = "doc_export")]
261register_builtin_doc_text!("meshgrid", DOC_MD);
262
263#[runtime_builtin(
264    name = "meshgrid",
265    category = "array/creation",
266    summary = "Generate coordinate matrices for 2-D and 3-D grids.",
267    keywords = "meshgrid,grid,gpu,like,3d",
268    accel = "array_construct"
269)]
270fn meshgrid_builtin(rest: Vec<Value>) -> Result<Value, String> {
271    let eval = evaluate(&rest)?;
272    eval.first()
273}
274
275/// Evaluate the `meshgrid` builtin once and reuse the result for multiple outputs.
276pub fn evaluate(args: &[Value]) -> Result<MeshgridEval, String> {
277    let parsed = ParsedMeshgrid::parse(args)?;
278    let (x_axis, y_axis, z_axis) = normalise_axes(&parsed.axes);
279
280    let require_complex = parsed.axes.iter().any(|axis| axis.is_complex);
281
282    let target_class = match &parsed.template {
283        OutputTemplate::Default => {
284            if require_complex {
285                PrototypeClass::Complex
286            } else {
287                PrototypeClass::Real
288            }
289        }
290        OutputTemplate::Like(spec) => {
291            if require_complex {
292                PrototypeClass::Complex
293            } else {
294                spec.class
295            }
296        }
297    };
298
299    let target_residency = match &parsed.template {
300        OutputTemplate::Default => {
301            if parsed.prefer_gpu {
302                DevicePreference::Gpu
303            } else {
304                DevicePreference::Host
305            }
306        }
307        OutputTemplate::Like(spec) => spec.residency,
308    };
309
310    let mut gpu_outputs: Option<Vec<MeshgridOutput>> = None;
311    let axes_all_real = !require_complex;
312
313    if axes_all_real
314        && matches!(target_class, PrototypeClass::Real)
315        && matches!(target_residency, DevicePreference::Gpu)
316    {
317        if let Some(provider) = runmat_accelerate_api::provider() {
318            let x_real = axis_real_values(&x_axis);
319            let y_real = axis_real_values(&y_axis);
320            let z_real = z_axis.as_ref().map(axis_real_values);
321            let mut axis_views: Vec<MeshgridAxisView<'_>> =
322                Vec::with_capacity(if z_real.is_some() { 3 } else { 2 });
323            axis_views.push(MeshgridAxisView { data: &x_real });
324            axis_views.push(MeshgridAxisView { data: &y_real });
325            if let Some(ref data) = z_real {
326                axis_views.push(MeshgridAxisView { data });
327            }
328            match provider.meshgrid(&axis_views) {
329                Ok(result) => {
330                    let expected = if z_axis.is_some() { 3 } else { 2 };
331                    let outputs: Vec<MeshgridOutput> = result
332                        .outputs
333                        .into_iter()
334                        .map(MeshgridOutput::GpuReal)
335                        .collect();
336                    if outputs.len() == expected {
337                        gpu_outputs = Some(outputs);
338                    } else {
339                        warn!(
340                            "meshgrid: provider returned {}/{} outputs; falling back to host",
341                            outputs.len(),
342                            expected
343                        );
344                    }
345                }
346                Err(err) => {
347                    warn!("meshgrid: provider meshgrid hook failed, falling back to host: {err}")
348                }
349            }
350        }
351    }
352
353    let outputs = gpu_outputs.unwrap_or_else(|| {
354        build_outputs(&x_axis, &y_axis, z_axis.as_ref())
355            .into_iter()
356            .map(MeshgridOutput::Host)
357            .collect()
358    });
359
360    Ok(MeshgridEval {
361        outputs,
362        target_class,
363        target_residency,
364    })
365}
366
367#[derive(Clone)]
368struct ParsedMeshgrid {
369    axes: Vec<AxisData>,
370    template: OutputTemplate,
371    prefer_gpu: bool,
372}
373
374impl ParsedMeshgrid {
375    fn parse(args: &[Value]) -> Result<Self, String> {
376        if args.is_empty() {
377            return Err("meshgrid: at least one input vector is required".to_string());
378        }
379        let mut axis_values: Vec<Value> = Vec::new();
380        let mut like_proto: Option<Value> = None;
381        let mut prefer_gpu = false;
382        let mut idx = 0;
383        while idx < args.len() {
384            let value = args[idx].clone();
385            if let Some(keyword) = keyword_of(&value) {
386                match keyword.as_str() {
387                    "like" => {
388                        if like_proto.is_some() {
389                            return Err(
390                                "meshgrid: multiple 'like' specifications are not supported"
391                                    .to_string(),
392                            );
393                        }
394                        if axis_values.is_empty() {
395                            return Err("meshgrid: 'like' must follow at least one input vector"
396                                .to_string());
397                        }
398                        let Some(proto) = args.get(idx + 1).cloned() else {
399                            return Err("meshgrid: expected prototype after 'like'".to_string());
400                        };
401                        like_proto = Some(proto);
402                        idx += 2;
403                        if idx < args.len() {
404                            return Err("meshgrid: 'like' must be the final argument".to_string());
405                        }
406                        break;
407                    }
408                    other => {
409                        return Err(format!("meshgrid: unrecognised option '{other}'"));
410                    }
411                }
412            }
413
414            if let Value::GpuTensor(_) = value {
415                prefer_gpu = true;
416            }
417            axis_values.push(value);
418            idx += 1;
419        }
420
421        if axis_values.is_empty() {
422            return Err("meshgrid: at least one input vector is required".to_string());
423        }
424        if axis_values.len() > 3 {
425            return Err("meshgrid: expected at most three input vectors".to_string());
426        }
427
428        let mut axes = Vec::with_capacity(max(axis_values.len(), 2));
429        for (i, value) in axis_values.into_iter().enumerate() {
430            let mut consumed_gpu = false;
431            let data = axis_from_value(value, i, &mut consumed_gpu)?;
432            if consumed_gpu {
433                prefer_gpu = true;
434            }
435            axes.push(data);
436        }
437
438        if !prefer_gpu {
439            if let Some(max_len) = axes.iter().map(|axis| axis.len).max() {
440                if max_len > 0
441                    && sequence_gpu_preference(max_len, SequenceIntent::MeshAxis, false).prefer_gpu
442                {
443                    prefer_gpu = true;
444                }
445            }
446        }
447
448        let template = if let Some(proto) = like_proto {
449            OutputTemplate::Like(analyse_like_prototype(&proto)?)
450        } else {
451            OutputTemplate::Default
452        };
453
454        Ok(Self {
455            axes,
456            template,
457            prefer_gpu,
458        })
459    }
460}
461
/// How the class and residency of the outputs are chosen.
#[derive(Debug, Clone, PartialEq, Eq)]
enum OutputTemplate {
    /// No prototype was supplied; class and residency follow the inputs.
    Default,
    /// A `'like', prototype` pair was supplied; outputs follow the prototype.
    Like(PrototypeSpec),
}

/// Residency and numeric class extracted from a `'like'` prototype.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PrototypeSpec {
    /// Where the prototype lives (host or GPU).
    residency: DevicePreference,
    /// Whether the prototype is real or complex.
    class: PrototypeClass,
}

/// Numeric class of the generated grids.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PrototypeClass {
    Real,
    Complex,
}

/// Requested residency of the generated grids.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DevicePreference {
    Host,
    Gpu,
}
485
486fn analyse_like_prototype(proto: &Value) -> Result<PrototypeSpec, String> {
487    match proto {
488        Value::GpuTensor(_) => Ok(PrototypeSpec {
489            residency: DevicePreference::Gpu,
490            class: PrototypeClass::Real,
491        }),
492        Value::ComplexTensor(_) | Value::Complex(_, _) => Ok(PrototypeSpec {
493            residency: DevicePreference::Host,
494            class: PrototypeClass::Complex,
495        }),
496        Value::Tensor(_)
497        | Value::Num(_)
498        | Value::Int(_)
499        | Value::Bool(_)
500        | Value::LogicalArray(_) => Ok(PrototypeSpec {
501            residency: DevicePreference::Host,
502            class: PrototypeClass::Real,
503        }),
504        Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => {
505            Err("meshgrid: prototypes must be numeric or gpuArray values".to_string())
506        }
507        Value::Cell(_)
508        | Value::Struct(_)
509        | Value::Object(_)
510        | Value::HandleObject(_)
511        | Value::Listener(_)
512        | Value::FunctionHandle(_)
513        | Value::Closure(_)
514        | Value::ClassRef(_)
515        | Value::MException(_) => Err("meshgrid: prototypes must be numeric arrays".to_string()),
516    }
517}
518
/// One coordinate vector, linearised and stored as `(real, imaginary)` pairs.
#[derive(Debug, Clone, PartialEq)]
struct AxisData {
    /// Axis samples in input order; real inputs carry a zero imaginary part.
    values: Vec<(f64, f64)>,
    /// Number of samples along the axis.
    len: usize,
    /// Whether the axis forces complex outputs.
    is_complex: bool,
}
525
526fn axis_from_value(value: Value, index: usize, prefer_gpu: &mut bool) -> Result<AxisData, String> {
527    match value {
528        Value::Tensor(tensor) => axis_from_tensor(tensor),
529        Value::LogicalArray(logical) => {
530            let tensor = tensor::logical_to_tensor(&logical)?;
531            axis_from_tensor(tensor)
532        }
533        Value::Num(n) => Ok(AxisData {
534            values: vec![(n, 0.0)],
535            len: 1,
536            is_complex: false,
537        }),
538        Value::Int(i) => {
539            let val = i.to_f64();
540            Ok(AxisData {
541                values: vec![(val, 0.0)],
542                len: 1,
543                is_complex: false,
544            })
545        }
546        Value::Bool(b) => Ok(AxisData {
547            values: vec![(if b { 1.0 } else { 0.0 }, 0.0)],
548            len: 1,
549            is_complex: false,
550        }),
551        Value::Complex(re, im) => Ok(AxisData {
552            values: vec![(re, im)],
553            len: 1,
554            is_complex: im != 0.0,
555        }),
556        Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor),
557        Value::GpuTensor(handle) => {
558            *prefer_gpu = true;
559            let tensor = gpu_helpers::gather_tensor(&handle)?;
560            axis_from_tensor(tensor)
561        }
562        other => Err(format!(
563            "meshgrid: input argument {} must be numeric, got {other:?}",
564            index + 1
565        )),
566    }
567}
568
569fn axis_from_tensor(tensor: Tensor) -> Result<AxisData, String> {
570    if !is_vector_shape(&tensor.shape) {
571        return Err("meshgrid: input vectors must be one-dimensional".to_string());
572    }
573    let mut values = Vec::with_capacity(tensor.data.len());
574    for &v in &tensor.data {
575        values.push((v, 0.0));
576    }
577    Ok(AxisData {
578        len: values.len(),
579        values,
580        is_complex: false,
581    })
582}
583
584fn axis_from_complex_tensor(tensor: ComplexTensor) -> Result<AxisData, String> {
585    if !is_vector_shape(&tensor.shape) {
586        return Err("meshgrid: input vectors must be one-dimensional".to_string());
587    }
588    let is_complex = tensor
589        .data
590        .iter()
591        .any(|&(_, imag)| !imag.is_nan() && imag != 0.0);
592    Ok(AxisData {
593        len: tensor.data.len(),
594        values: tensor.data,
595        is_complex,
596    })
597}
598
599fn axis_real_values(axis: &AxisData) -> Vec<f64> {
600    axis.values.iter().map(|(re, _)| *re).collect()
601}
602
/// Reports whether `shape` describes a vector: at most one dimension may exceed 1.
///
/// Empty shapes, scalars, and shapes containing zero-length dimensions all qualify as
/// long as no second dimension is larger than one.
fn is_vector_shape(shape: &[usize]) -> bool {
    shape.iter().filter(|&&dim| dim > 1).count() <= 1
}
615
616fn normalise_axes(axes: &[AxisData]) -> (AxisData, AxisData, Option<AxisData>) {
617    match axes.len() {
618        1 => {
619            let x = axes[0].clone();
620            (x.clone(), x, None)
621        }
622        2 => {
623            let x = axes[0].clone();
624            let y = axes[1].clone();
625            (x, y, None)
626        }
627        3 => {
628            let x = axes[0].clone();
629            let y = axes[1].clone();
630            let z = axes[2].clone();
631            (x, y, Some(z))
632        }
633        _ => unreachable!(),
634    }
635}
636
637fn build_outputs(
638    x_axis: &AxisData,
639    y_axis: &AxisData,
640    z_axis: Option<&AxisData>,
641) -> Vec<GridOutput> {
642    let nx = x_axis.len;
643    let ny = y_axis.len;
644    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
645    let total = nx * ny * nz;
646    let mut x_data = Vec::with_capacity(total);
647    let mut y_data = Vec::with_capacity(total);
648    let mut z_data = z_axis.map(|_| Vec::with_capacity(total));
649
650    for k in 0..nz {
651        let z_value = z_axis.map(|axis| axis.values[k]);
652        for col in 0..nx {
653            let x_value = x_axis.values[col];
654            for row in 0..ny {
655                x_data.push(x_value);
656                y_data.push(y_axis.values[row]);
657                if let Some(ref mut z_vec) = z_data {
658                    z_vec.push(z_value.unwrap());
659                }
660            }
661        }
662    }
663
664    let mut outputs = Vec::new();
665    let base_shape = if nz == 1 {
666        vec![ny, nx]
667    } else {
668        vec![ny, nx, nz]
669    };
670    outputs.push(GridOutput {
671        shape: base_shape.clone(),
672        data: x_data,
673    });
674    outputs.push(GridOutput {
675        shape: base_shape.clone(),
676        data: y_data,
677    });
678    if let Some(z_vec) = z_data {
679        outputs.push(GridOutput {
680            shape: base_shape,
681            data: z_vec,
682        });
683    }
684    outputs
685}
686
687struct GridOutput {
688    shape: Vec<usize>,
689    data: Vec<(f64, f64)>,
690}
691
692impl GridOutput {
693    fn to_value(
694        &self,
695        class: PrototypeClass,
696        residency: DevicePreference,
697    ) -> Result<Value, String> {
698        match class {
699            PrototypeClass::Real => self.to_real_value(residency),
700            PrototypeClass::Complex => self.to_complex_value(residency),
701        }
702    }
703
704    fn to_real_value(&self, residency: DevicePreference) -> Result<Value, String> {
705        let mut real = Vec::with_capacity(self.data.len());
706        for &(re, im) in &self.data {
707            if im != 0.0 {
708                return Err(
709                    "meshgrid: cannot represent complex values in a real output".to_string()
710                );
711            }
712            real.push(re);
713        }
714        let tensor = Tensor::new(real, self.shape.clone()).map_err(|e| format!("meshgrid: {e}"))?;
715        match residency {
716            DevicePreference::Host => Ok(tensor::tensor_into_value(tensor)),
717            DevicePreference::Gpu => to_gpu_tensor_value(tensor),
718        }
719    }
720
721    fn to_complex_value(&self, residency: DevicePreference) -> Result<Value, String> {
722        let tensor = ComplexTensor::new(self.data.clone(), self.shape.clone())
723            .map_err(|e| format!("meshgrid: {e}"))?;
724        match residency {
725            DevicePreference::Host => Ok(complex_tensor_into_value(tensor)),
726            DevicePreference::Gpu => {
727                warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
728                Ok(complex_tensor_into_value(tensor))
729            }
730        }
731    }
732}
733
734fn to_gpu_tensor_value(tensor: Tensor) -> Result<Value, String> {
735    if let Some(provider) = runmat_accelerate_api::provider() {
736        let view = HostTensorView {
737            data: &tensor.data,
738            shape: &tensor.shape,
739        };
740        match provider.upload(&view) {
741            Ok(handle) => return Ok(Value::GpuTensor(handle)),
742            Err(err) => {
743                warn!("meshgrid: failed to upload tensor to GPU, returning host array: {err}")
744            }
745        }
746    }
747    Ok(tensor::tensor_into_value(tensor))
748}
749
750fn tensor_to_complex_value(tensor: Tensor) -> Result<Value, String> {
751    let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
752    let complex =
753        ComplexTensor::new(data, tensor.shape.clone()).map_err(|e| format!("meshgrid: {e}"))?;
754    Ok(complex_tensor_into_value(complex))
755}
756
757enum MeshgridOutput {
758    Host(GridOutput),
759    GpuReal(GpuTensorHandle),
760}
761
762impl MeshgridOutput {
763    fn to_value(
764        &self,
765        class: PrototypeClass,
766        residency: DevicePreference,
767    ) -> Result<Value, String> {
768        match self {
769            MeshgridOutput::Host(host) => host.to_value(class, residency),
770            MeshgridOutput::GpuReal(handle) => match (class, residency) {
771                (PrototypeClass::Real, DevicePreference::Gpu) => {
772                    Ok(Value::GpuTensor(handle.clone()))
773                }
774                (PrototypeClass::Real, DevicePreference::Host) => {
775                    let tensor = gpu_helpers::gather_tensor(handle)?;
776                    Ok(tensor::tensor_into_value(tensor))
777                }
778                (PrototypeClass::Complex, DevicePreference::Host) => {
779                    let tensor = gpu_helpers::gather_tensor(handle)?;
780                    tensor_to_complex_value(tensor)
781                }
782                (PrototypeClass::Complex, DevicePreference::Gpu) => {
783                    warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
784                    let tensor = gpu_helpers::gather_tensor(handle)?;
785                    tensor_to_complex_value(tensor)
786                }
787            },
788        }
789    }
790}
791
792/// Holds the results of a `meshgrid` evaluation so multiple outputs can be
793/// materialised without recomputing the grid.
794pub struct MeshgridEval {
795    outputs: Vec<MeshgridOutput>,
796    target_class: PrototypeClass,
797    target_residency: DevicePreference,
798}
799
800impl MeshgridEval {
801    pub fn output_count(&self) -> usize {
802        self.outputs.len()
803    }
804
805    pub fn first(&self) -> Result<Value, String> {
806        self.outputs[0].to_value(self.target_class, self.target_residency)
807    }
808
809    pub fn second(&self) -> Result<Value, String> {
810        if self.outputs.len() < 2 {
811            Err("meshgrid: second output unavailable".to_string())
812        } else {
813            self.outputs[1].to_value(self.target_class, self.target_residency)
814        }
815    }
816
817    pub fn third(&self) -> Result<Value, String> {
818        if self.outputs.len() < 3 {
819            Err("meshgrid: third output requested but no Z vector was supplied".to_string())
820        } else {
821            self.outputs[2].to_value(self.target_class, self.target_residency)
822        }
823    }
824}
825
826#[cfg(test)]
827mod tests {
828    use super::*;
829    use crate::builtins::common::test_support;
830    #[cfg(feature = "wgpu")]
831    use runmat_accelerate_api::AccelProvider;
832    use runmat_accelerate_api::HostTensorView;
833
834    fn tensor_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Tensor {
835        Tensor::new(data, vec![rows, cols]).unwrap()
836    }
837
838    #[test]
839    fn meshgrid_single_input_duplicates_axis() {
840        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
841        let eval = evaluate(&[Value::Tensor(x)]).expect("meshgrid");
842        assert_eq!(eval.output_count(), 2);
843        let x_out = test_support::gather(eval.first().expect("X")).expect("host");
844        assert_eq!(x_out.shape, vec![3, 3]);
845        assert_eq!(
846            x_out.data,
847            vec![-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
848        );
849        let y_out = test_support::gather(eval.second().expect("Y")).expect("host");
850        assert_eq!(y_out.shape, vec![3, 3]);
851        assert_eq!(
852            y_out.data,
853            vec![-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0]
854        );
855    }
856
857    #[test]
858    fn meshgrid_rectangular_inputs() {
859        let x = tensor_from_vec(vec![0.0, 0.5, 1.0], 1, 3);
860        let y = tensor_from_vec(vec![10.0, 20.0], 2, 1);
861        let eval = evaluate(&[Value::Tensor(x), Value::Tensor(y)]).expect("meshgrid");
862        assert_eq!(eval.output_count(), 2);
863        let x_out = test_support::gather(eval.first().expect("X")).expect("host");
864        assert_eq!(x_out.shape, vec![2, 3]);
865        assert_eq!(x_out.data, vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0]);
866        let y_out = test_support::gather(eval.second().expect("Y")).expect("host");
867        assert_eq!(y_out.shape, vec![2, 3]);
868        assert_eq!(y_out.data, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
869    }
870
871    #[test]
872    fn meshgrid_three_inputs_volume() {
873        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
874        let y = tensor_from_vec(vec![5.0, 6.0, 7.0], 3, 1);
875        let z = tensor_from_vec(vec![0.0, 1.0], 1, 2);
876        let eval =
877            evaluate(&[Value::Tensor(x), Value::Tensor(y), Value::Tensor(z)]).expect("meshgrid");
878        assert_eq!(eval.output_count(), 3);
879        let x_out = test_support::gather(eval.first().expect("X")).expect("host");
880        assert_eq!(x_out.shape, vec![3, 2, 2]);
881        assert_eq!(
882            x_out.data,
883            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
884        );
885        let z_out = test_support::gather(eval.third().expect("Z")).expect("host");
886        assert_eq!(z_out.shape, vec![3, 2, 2]);
887        assert_eq!(
888            z_out.data,
889            vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
890        );
891    }
892
893    #[test]
894    fn meshgrid_like_keeps_gpu_residency() {
895        test_support::with_test_provider(|provider| {
896            let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
897            let y = tensor_from_vec(vec![2.0, 4.0], 2, 1);
898            let proto = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
899            let proto_view = HostTensorView {
900                data: &proto.data,
901                shape: &proto.shape,
902            };
903            let proto_handle = provider.upload(&proto_view).expect("upload");
904            let eval = evaluate(&[
905                Value::Tensor(x),
906                Value::Tensor(y),
907                Value::from("like"),
908                Value::GpuTensor(proto_handle),
909            ])
910            .expect("meshgrid");
911            let x_value = eval.first().expect("X");
912            assert!(matches!(x_value, Value::GpuTensor(_)));
913            let gathered = test_support::gather(x_value).expect("gather");
914            assert_eq!(gathered.shape, vec![2, 3]);
915        });
916    }
917
918    #[test]
919    fn meshgrid_gpu_inputs_roundtrip() {
920        test_support::with_test_provider(|provider| {
921            let x = tensor_from_vec(vec![0.0, 0.5], 1, 2);
922            let y = tensor_from_vec(vec![1.0, 2.0], 2, 1);
923            let x_view = HostTensorView {
924                data: &x.data,
925                shape: &x.shape,
926            };
927            let y_view = HostTensorView {
928                data: &y.data,
929                shape: &y.shape,
930            };
931            let x_handle = provider.upload(&x_view).expect("upload");
932            let y_handle = provider.upload(&y_view).expect("upload");
933            let eval = evaluate(&[Value::GpuTensor(x_handle), Value::GpuTensor(y_handle)])
934                .expect("meshgrid");
935            assert!(matches!(eval.first().expect("X"), Value::GpuTensor(_)));
936            assert!(matches!(eval.second().expect("Y"), Value::GpuTensor(_)));
937        });
938    }
939
940    #[test]
941    #[cfg(feature = "wgpu")]
942    fn meshgrid_wgpu_matches_cpu() {
943        let provider = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
944            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
945        )
946        .expect("wgpu provider");
947
948        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0, 2.0], 1, 4);
949        let y = tensor_from_vec(vec![5.0, 6.0], 2, 1);
950
951        let cpu_eval =
952            evaluate(&[Value::Tensor(x.clone()), Value::Tensor(y.clone())]).expect("meshgrid cpu");
953        let cpu_x = test_support::gather(cpu_eval.first().expect("X cpu")).expect("gather X cpu");
954        let cpu_y = test_support::gather(cpu_eval.second().expect("Y cpu")).expect("gather Y cpu");
955
956        let x_view = HostTensorView {
957            data: &x.data,
958            shape: &x.shape,
959        };
960        let y_view = HostTensorView {
961            data: &y.data,
962            shape: &y.shape,
963        };
964        let x_gpu = provider.upload(&x_view).expect("upload x");
965        let y_gpu = provider.upload(&y_view).expect("upload y");
966
967        let gpu_eval =
968            evaluate(&[Value::GpuTensor(x_gpu), Value::GpuTensor(y_gpu)]).expect("meshgrid gpu");
969        let gpu_x_value = gpu_eval.first().expect("X gpu");
970        let gpu_y_value = gpu_eval.second().expect("Y gpu");
971
972        assert!(matches!(gpu_x_value, Value::GpuTensor(_)));
973        assert!(matches!(gpu_y_value, Value::GpuTensor(_)));
974
975        let gathered_x = test_support::gather(gpu_x_value).expect("gather X gpu");
976        let gathered_y = test_support::gather(gpu_y_value).expect("gather Y gpu");
977
978        assert_eq!(gathered_x.shape, cpu_x.shape);
979        assert_eq!(gathered_x.data, cpu_x.data);
980        assert_eq!(gathered_y.shape, cpu_y.shape);
981        assert_eq!(gathered_y.data, cpu_y.data);
982    }
983
984    #[test]
985    fn meshgrid_complex_inputs_produce_complex_outputs() {
986        let complex = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
987        let eval = evaluate(&[Value::ComplexTensor(complex)]).expect("meshgrid");
988        let x_value = eval.first().expect("X");
989        match x_value {
990            Value::ComplexTensor(ct) => {
991                assert_eq!(ct.shape, vec![2, 2]);
992            }
993            Value::Complex(_, _) => {}
994            other => panic!("expected complex output, got {other:?}"),
995        }
996    }
997
998    #[test]
999    fn meshgrid_like_host_prototype() {
1000        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
1001        let eval =
1002            evaluate(&[Value::Tensor(x), Value::from("like"), Value::Num(0.0)]).expect("meshgrid");
1003        let x_out = eval.first().expect("X");
1004        assert!(matches!(x_out, Value::Tensor(_) | Value::Num(_)));
1005    }
1006
1007    #[test]
1008    #[cfg(feature = "doc_export")]
1009    fn doc_examples_present() {
1010        let blocks = test_support::doc_examples(DOC_MD);
1011        assert!(!blocks.is_empty());
1012    }
1013}