Skip to main content

runmat_runtime/builtins/array/creation/
meshgrid.rs

1//! MATLAB-compatible `meshgrid` builtin with GPU-aware semantics.
2
3use std::cmp::max;
4
5use log::warn;
6use runmat_accelerate_api::{GpuTensorHandle, HostTensorView};
7use runmat_builtins::{
8    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
9    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
10    ComplexTensor, ResolveContext, Tensor, Type, Value,
11};
12
13use crate::builtins::array::type_resolvers::size_vector_len;
14use runmat_macros::runtime_builtin;
15
16use crate::build_runtime_error;
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
19use crate::builtins::common::residency::{sequence_gpu_preference, SequenceIntent};
20use crate::builtins::common::spec::{
21    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
23};
24use crate::builtins::common::tensor;
25
/// GPU capability descriptor for the `meshgrid` builtin.
///
/// Registered through the `register_gpu_spec` attribute. It declares a custom
/// `array_construct` operation kind, a custom `meshgrid` provider hook, F32/F64
/// precision support, and `NewHandle` residency for results. The `notes` field
/// records the intended fallback when a provider has no dedicated hook.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "meshgrid",
    op_kind: GpuOpKind::Custom("array_construct"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("meshgrid")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers may supply a dedicated meshgrid hook; until then the runtime builds grids on the host and uploads them when GPU residency is requested.",
};
41
42fn builtin_error(message: impl Into<String>) -> crate::RuntimeError {
43    build_runtime_error(message)
44        .with_builtin("meshgrid")
45        .build()
46}
47
/// Fusion descriptor for the `meshgrid` builtin.
///
/// Registered through the `register_fusion_spec` attribute. No elementwise or
/// reduction template is provided; per the `notes` field, `meshgrid`
/// materialises dense coordinate arrays and does not take part in fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "meshgrid",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "Meshgrid explicitly materialises dense coordinate arrays and therefore bypasses fusion.",
};
59
60fn meshgrid_type(args: &[Type], _context: &ResolveContext) -> Type {
61    if args.is_empty() {
62        return Type::Unknown;
63    }
64    let mut axis_count = args.len();
65    if axis_count >= 2 && matches!(args[axis_count - 2], Type::String) {
66        axis_count = axis_count.saturating_sub(2);
67    }
68    if axis_count == 0 {
69        return Type::Unknown;
70    }
71    let axis_args = &args[..axis_count];
72    let len_x = axis_args.get(0).and_then(size_vector_len);
73    let len_y = axis_args.get(1).and_then(size_vector_len).or(len_x);
74    let len_z = axis_args.get(2).and_then(size_vector_len);
75    let shape = if axis_count >= 3 {
76        vec![len_y, len_x, len_z]
77    } else {
78        vec![len_y, len_x]
79    };
80    Type::Tensor { shape: Some(shape) }
81}
82
/// Output descriptors for the two-output forms: the `X` and `Y` coordinate grids.
const MESHGRID_OUTPUT_XY: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along X-axis.",
    },
    BuiltinParamDescriptor {
        name: "Y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along Y-axis.",
    },
];
99
/// Output descriptors for the three-axis forms: `X`, `Y`, and an optional `Z` grid.
const MESHGRID_OUTPUT_XYZ: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along X-axis.",
    },
    BuiltinParamDescriptor {
        name: "Y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along Y-axis.",
    },
    BuiltinParamDescriptor {
        name: "Z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Grid coordinates along Z-axis.",
    },
];
123
/// Input descriptors for the single-vector form `meshgrid(x)`.
const MESHGRID_SIG_X_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "x",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "X-axis vector.",
}];
131
/// Input descriptors for the two-vector form `meshgrid(x, y)`.
const MESHGRID_SIG_XY_INPUTS: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
];
148
/// Input descriptors for the three-vector form `meshgrid(x, y, z)`.
///
/// The `z` entry is declared with optional arity.
const MESHGRID_SIG_XYZ_INPUTS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Z-axis vector.",
    },
];
172
/// Input descriptors for `meshgrid(x, "like", prototype)`: one axis vector
/// followed by the `"like"` keyword and its prototype value.
const MESHGRID_SIG_X_LIKE_INPUTS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];
196
/// Input descriptors for `meshgrid(x, y, "like", prototype)`: two axis vectors
/// followed by the `"like"` keyword and its prototype value.
const MESHGRID_SIG_XY_LIKE_INPUTS: [BuiltinParamDescriptor; 4] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];
227
/// Input descriptors for `meshgrid(x, y, z, "like", prototype)`: up to three
/// axis vectors (`z` is declared optional) followed by the `"like"` keyword
/// and its prototype value.
const MESHGRID_SIG_XYZ_LIKE_INPUTS: [BuiltinParamDescriptor; 5] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Z-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];
265
/// The six documented call forms of `meshgrid`: one, two, or three axis
/// vectors, each with and without a trailing `"like", prototype` pair.
/// Each entry pairs a display label with its input and output descriptor tables.
const MESHGRID_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x)",
        inputs: &MESHGRID_SIG_X_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, y)",
        inputs: &MESHGRID_SIG_XY_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y,Z] = meshgrid(x, y, z)",
        inputs: &MESHGRID_SIG_XYZ_INPUTS,
        outputs: &MESHGRID_OUTPUT_XYZ,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, \"like\", prototype)",
        inputs: &MESHGRID_SIG_X_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, y, \"like\", prototype)",
        inputs: &MESHGRID_SIG_XY_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y,Z] = meshgrid(x, y, z, \"like\", prototype)",
        inputs: &MESHGRID_SIG_XYZ_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XYZ,
    },
];
298
/// Catalogue of error conditions advertised for `meshgrid`, exposed through
/// `MESHGRID_DESCRIPTOR`.
///
/// NOTE(review): several messages actually raised in this file are worded
/// differently from the entries below (for example the output-count errors in
/// `meshgrid_builtin`, the per-argument messages in `axis_from_value` and
/// `axis_from_tensor`, and the "'like' must follow at least one input vector"
/// case in `ParsedMeshgrid::parse`). Confirm whether this table is meant to be
/// descriptive only or should match the runtime text exactly.
const MESHGRID_ERRORS: [BuiltinErrorDescriptor; 11] = [
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.MISSING_AXIS",
        identifier: None,
        when: "No axis vectors are provided.",
        message: "meshgrid: at least one input vector is required",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.TOO_MANY_AXES",
        identifier: None,
        when: "More than three axis vectors are provided.",
        message: "meshgrid: expected at most three input vectors",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.LIKE_EXPECTED_PROTOTYPE",
        identifier: None,
        when: "The 'like' keyword is provided without a prototype argument.",
        message: "meshgrid: expected prototype after 'like'",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.MULTIPLE_LIKE",
        identifier: None,
        when: "The 'like' keyword is provided multiple times.",
        message: "meshgrid: multiple 'like' specifications are not supported",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.LIKE_POSITION",
        identifier: None,
        when: "The 'like' keyword is in an invalid position or not final.",
        message: "meshgrid: 'like' must be the final argument",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.UNRECOGNIZED_OPTION",
        identifier: None,
        when: "A trailing option string is not recognized.",
        message: "meshgrid: unrecognised option",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.INVALID_AXIS_INPUT",
        identifier: None,
        when: "Axis inputs are non-numeric or non-vector shapes.",
        message: "meshgrid: input argument must be numeric vector data",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.INVALID_PROTOTYPE",
        identifier: None,
        when: "The 'like' prototype is unsupported.",
        message: "meshgrid: prototypes must be numeric arrays",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.OUTPUT_COUNT_EXCEEDED",
        identifier: None,
        when: "Requested outputs exceed available outputs for provided axes.",
        message:
            "meshgrid: supports at most two outputs for 2-axis inputs and three for 3-axis inputs",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.THIRD_OUTPUT_UNAVAILABLE",
        identifier: None,
        when: "A third output is requested without supplying a Z-axis vector.",
        message: "meshgrid: third output requested but no Z vector was supplied",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.COMPLEX_REAL_CONVERSION",
        identifier: None,
        when: "Complex axis values cannot be represented in requested real output class.",
        message: "meshgrid: cannot represent complex values in a real output",
    },
];
368
/// Public descriptor for the `meshgrid` builtin, referenced by the
/// `runtime_builtin` registration below.
///
/// Bundles the signature table and the error catalogue, and declares that the
/// number of returned values follows the requested output count.
pub const MESHGRID_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &MESHGRID_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &MESHGRID_ERRORS,
};
375
376#[runtime_builtin(
377    name = "meshgrid",
378    category = "array/creation",
379    summary = "Generate coordinate matrices for 2-D and 3-D grids.",
380    keywords = "meshgrid,grid,gpu,like,3d",
381    accel = "array_construct",
382    type_resolver(meshgrid_type),
383    descriptor(crate::builtins::array::creation::meshgrid::MESHGRID_DESCRIPTOR),
384    builtin_path = "crate::builtins::array::creation::meshgrid"
385)]
386async fn meshgrid_builtin(rest: Vec<Value>) -> crate::BuiltinResult<Value> {
387    let eval = evaluate(&rest).await?;
388    if let Some(out_count) = crate::output_count::current_output_count() {
389        if out_count == 0 {
390            return Ok(Value::OutputList(Vec::new()));
391        }
392        let available = eval.output_count();
393        if out_count > available {
394            let msg = if available == 2 {
395                "meshgrid with two inputs supports at most two outputs"
396            } else {
397                "meshgrid supports at most three outputs"
398            };
399            return Err(builtin_error(msg));
400        }
401        let mut outputs = Vec::with_capacity(out_count);
402        let first = eval.first().await?;
403        outputs.push(first);
404        if out_count >= 2 {
405            outputs.push(eval.second().await?);
406        }
407        if out_count >= 3 {
408            outputs.push(eval.third().await?);
409        }
410        return Ok(Value::OutputList(outputs));
411    }
412    eval.first().await
413}
414
415/// Evaluate the `meshgrid` builtin once and reuse the result for multiple outputs.
416pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
417    let parsed = ParsedMeshgrid::parse(args).await?;
418    let (x_axis, y_axis, z_axis) = normalise_axes(&parsed.axes);
419
420    let require_complex = parsed.axes.iter().any(|axis| axis.is_complex);
421
422    let target_class = match &parsed.template {
423        OutputTemplate::Default => {
424            if require_complex {
425                PrototypeClass::Complex
426            } else {
427                PrototypeClass::Real
428            }
429        }
430        OutputTemplate::Like(spec) => {
431            if require_complex {
432                PrototypeClass::Complex
433            } else {
434                spec.class
435            }
436        }
437    };
438
439    let target_residency = match &parsed.template {
440        OutputTemplate::Default => {
441            if parsed.prefer_gpu {
442                DevicePreference::Gpu
443            } else {
444                DevicePreference::Host
445            }
446        }
447        OutputTemplate::Like(spec) => spec.residency,
448    };
449
450    let axes_all_real = !require_complex;
451    let mut outputs: Vec<MeshgridOutput> = Vec::new();
452
453    if axes_all_real
454        && matches!(target_class, PrototypeClass::Real)
455        && matches!(target_residency, DevicePreference::Gpu)
456    {
457        if let Some(gpu) = try_meshgrid_gpu_from_vector_axes(&x_axis, &y_axis, z_axis.as_ref())? {
458            outputs = gpu;
459        }
460    }
461
462    if outputs.is_empty() {
463        // Host fallback: ensure we have host axis values materialized.
464        let x_host = axis_to_host_async(&x_axis).await?;
465        let y_host = axis_to_host_async(&y_axis).await?;
466        let z_host = match z_axis.as_ref() {
467            Some(axis) => Some(axis_to_host_async(axis).await?),
468            None => None,
469        };
470        outputs = build_outputs(&x_host, &y_host, z_host.as_ref())
471            .into_iter()
472            .map(MeshgridOutput::Host)
473            .collect();
474    }
475
476    Ok(MeshgridEval {
477        outputs,
478        target_class,
479        target_residency,
480    })
481}
482
/// Result of parsing the raw `meshgrid` argument list.
#[derive(Clone)]
struct ParsedMeshgrid {
    /// Axis inputs in the order they were supplied (one to three entries).
    axes: Vec<AxisData>,
    /// Output template: default behaviour or a `"like"` prototype specification.
    template: OutputTemplate,
    /// Whether GPU residency is preferred when no prototype dictates it; set
    /// when a gpuArray axis was supplied or the sequence heuristic asks for it.
    prefer_gpu: bool,
}
489
490impl ParsedMeshgrid {
491    async fn parse(args: &[Value]) -> crate::BuiltinResult<Self> {
492        if args.is_empty() {
493            return Err(builtin_error(
494                "meshgrid: at least one input vector is required",
495            ));
496        }
497        let mut axis_values: Vec<Value> = Vec::new();
498        let mut like_proto: Option<Value> = None;
499        let mut prefer_gpu = false;
500        let mut idx = 0;
501        while idx < args.len() {
502            let value = args[idx].clone();
503            if let Some(keyword) = keyword_of(&value) {
504                match keyword.as_str() {
505                    "like" => {
506                        if like_proto.is_some() {
507                            return Err(builtin_error(
508                                "meshgrid: multiple 'like' specifications are not supported",
509                            ));
510                        }
511                        if axis_values.is_empty() {
512                            return Err(builtin_error(
513                                "meshgrid: 'like' must follow at least one input vector",
514                            ));
515                        }
516                        let Some(proto) = args.get(idx + 1).cloned() else {
517                            return Err(builtin_error("meshgrid: expected prototype after 'like'"));
518                        };
519                        like_proto = Some(proto);
520                        idx += 2;
521                        if idx < args.len() {
522                            return Err(builtin_error(
523                                "meshgrid: 'like' must be the final argument",
524                            ));
525                        }
526                        break;
527                    }
528                    other => {
529                        return Err(builtin_error(format!(
530                            "meshgrid: unrecognised option '{other}'"
531                        )));
532                    }
533                }
534            }
535
536            if let Value::GpuTensor(_) = value {
537                prefer_gpu = true;
538            }
539            axis_values.push(value);
540            idx += 1;
541        }
542
543        if axis_values.is_empty() {
544            return Err(builtin_error(
545                "meshgrid: at least one input vector is required",
546            ));
547        }
548        if axis_values.len() > 3 {
549            return Err(builtin_error(
550                "meshgrid: expected at most three input vectors",
551            ));
552        }
553
554        let mut axes = Vec::with_capacity(max(axis_values.len(), 2));
555        for (i, value) in axis_values.into_iter().enumerate() {
556            let mut consumed_gpu = false;
557            let data = axis_from_value(value, i, &mut consumed_gpu).await?;
558            if consumed_gpu {
559                prefer_gpu = true;
560            }
561            axes.push(data);
562        }
563
564        if !prefer_gpu {
565            if let Some(max_len) = axes.iter().map(|axis| axis.len).max() {
566                if max_len > 0
567                    && sequence_gpu_preference(max_len, SequenceIntent::MeshAxis, false).prefer_gpu
568                {
569                    prefer_gpu = true;
570                }
571            }
572        }
573
574        let template = if let Some(proto) = like_proto {
575            OutputTemplate::Like(analyse_like_prototype(&proto)?)
576        } else {
577            OutputTemplate::Default
578        };
579
580        Ok(Self {
581            axes,
582            template,
583            prefer_gpu,
584        })
585    }
586}
587
/// How the class and residency of the outputs are chosen.
#[derive(Clone)]
enum OutputTemplate {
    /// No `"like"` prototype was supplied; class and residency are derived
    /// from the axis inputs.
    Default,
    /// A `"like"` prototype was supplied and analysed into a specification.
    Like(PrototypeSpec),
}
593
/// Properties extracted from a `"like"` prototype value.
#[derive(Clone)]
struct PrototypeSpec {
    /// Where the prototype lives (host or GPU).
    residency: DevicePreference,
    /// Whether the prototype is real or complex.
    class: PrototypeClass,
}
599
/// Numeric class requested for the outputs.
#[derive(Clone, Copy, PartialEq, Eq)]
enum PrototypeClass {
    /// Real-valued output.
    Real,
    /// Complex-valued output.
    Complex,
}
605
/// Requested residency for the outputs.
#[derive(Clone, Copy)]
enum DevicePreference {
    /// Keep the outputs in host memory.
    Host,
    /// Place the outputs on the GPU.
    Gpu,
}
611
612fn analyse_like_prototype(proto: &Value) -> crate::BuiltinResult<PrototypeSpec> {
613    match proto {
614        Value::GpuTensor(_) => Ok(PrototypeSpec {
615            residency: DevicePreference::Gpu,
616            class: PrototypeClass::Real,
617        }),
618        Value::ComplexTensor(_) | Value::Complex(_, _) => Ok(PrototypeSpec {
619            residency: DevicePreference::Host,
620            class: PrototypeClass::Complex,
621        }),
622        Value::Tensor(_)
623        | Value::SparseTensor(_)
624        | Value::Num(_)
625        | Value::Int(_)
626        | Value::Bool(_)
627        | Value::LogicalArray(_) => Ok(PrototypeSpec {
628            residency: DevicePreference::Host,
629            class: PrototypeClass::Real,
630        }),
631        Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(builtin_error(
632            "meshgrid: prototypes must be numeric or gpuArray values",
633        )),
634        Value::Cell(_)
635        | Value::Struct(_)
636        | Value::Object(_)
637        | Value::HandleObject(_)
638        | Value::Listener(_)
639        | Value::FunctionHandle(_)
640        | Value::ExternalFunctionHandle(_)
641        | Value::MethodFunctionHandle(_)
642        | Value::BoundFunctionHandle { .. }
643        | Value::Closure(_)
644        | Value::ClassRef(_)
645        | Value::MException(_)
646        | Value::OutputList(_) => Err(builtin_error("meshgrid: prototypes must be numeric arrays")),
647    }
648}
649
/// One axis input of `meshgrid` after conversion from a runtime value.
#[derive(Clone)]
struct AxisData {
    /// Axis samples as `(real, imaginary)` pairs. Left empty when the axis is
    /// kept on the device (see `gpu_real`).
    values: Vec<(f64, f64)>,
    /// Number of samples along the axis.
    len: usize,
    /// Whether the axis is treated as complex (some imaginary part is non-zero).
    is_complex: bool,
    /// Device handle for a real, vector-shaped gpuArray axis that has not been
    /// downloaded; `None` for host-resident axes.
    gpu_real: Option<GpuTensorHandle>,
}
657
658async fn axis_from_value(
659    value: Value,
660    index: usize,
661    prefer_gpu: &mut bool,
662) -> crate::BuiltinResult<AxisData> {
663    match value {
664        Value::Tensor(tensor) => axis_from_tensor(tensor, index),
665        Value::LogicalArray(logical) => {
666            let tensor = tensor::logical_to_tensor(&logical)?;
667            axis_from_tensor(tensor, index)
668        }
669        Value::Num(n) => Ok(AxisData {
670            values: vec![(n, 0.0)],
671            len: 1,
672            is_complex: false,
673            gpu_real: None,
674        }),
675        Value::Int(i) => {
676            let val = i.to_f64();
677            Ok(AxisData {
678                values: vec![(val, 0.0)],
679                len: 1,
680                is_complex: false,
681                gpu_real: None,
682            })
683        }
684        Value::Bool(b) => Ok(AxisData {
685            values: vec![(if b { 1.0 } else { 0.0 }, 0.0)],
686            len: 1,
687            is_complex: false,
688            gpu_real: None,
689        }),
690        Value::Complex(re, im) => Ok(AxisData {
691            values: vec![(re, im)],
692            len: 1,
693            is_complex: im != 0.0,
694            gpu_real: None,
695        }),
696        Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, index),
697        Value::GpuTensor(handle) => {
698            // Fast path: if the gpuArray is vector-like, keep it on-device and avoid a download.
699            // We'll validate any non-vector shapes by gathering below.
700            if is_vector_shape(&handle.shape) {
701                *prefer_gpu = true;
702                return Ok(AxisData {
703                    values: Vec::new(),
704                    len: vector_len_from_shape(&handle.shape),
705                    is_complex: false,
706                    gpu_real: Some(handle),
707                });
708            }
709
710            // Fallback: gather to validate / recover axes from meshgrid matrices.
711            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
712            if is_vector_shape(&tensor.shape) {
713                *prefer_gpu = true;
714            }
715            axis_from_tensor(tensor, index)
716        }
717        other => Err(builtin_error(format!(
718            "meshgrid: input argument {} must be numeric, got {other:?}",
719            index + 1
720        ))),
721    }
722}
723
724fn axis_from_tensor(tensor: Tensor, index: usize) -> crate::BuiltinResult<AxisData> {
725    if is_vector_shape(&tensor.shape) {
726        let mut values = Vec::with_capacity(tensor.data.len());
727        for &v in &tensor.data {
728            values.push((v, 0.0));
729        }
730        return Ok(AxisData {
731            len: values.len(),
732            values,
733            is_complex: false,
734            gpu_real: None,
735        });
736    }
737
738    // Be slightly more permissive than MATLAB: if the input is already a meshgrid-style
739    // coordinate matrix, accept it and recover the original axis vector.
740    //
741    // This is a pragmatic compatibility shim for cases where callers already have
742    // coordinate matrices (X/Y) and pass them through `meshgrid` again.
743    if let Some(axis) = axis_from_meshgrid_matrix_real(&tensor, index)? {
744        return Ok(axis);
745    }
746
747    Err(builtin_error(format!(
748        "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
749        index + 1,
750        tensor.shape
751    )))
752}
753
754fn axis_from_complex_tensor(tensor: ComplexTensor, index: usize) -> crate::BuiltinResult<AxisData> {
755    if is_vector_shape(&tensor.shape) {
756        let is_complex = tensor
757            .data
758            .iter()
759            .any(|&(_, imag)| !imag.is_nan() && imag != 0.0);
760        return Ok(AxisData {
761            len: tensor.data.len(),
762            values: tensor.data,
763            is_complex,
764            gpu_real: None,
765        });
766    }
767
768    if let Some(axis) = axis_from_meshgrid_matrix_complex(&tensor, index)? {
769        return Ok(axis);
770    }
771
772    Err(builtin_error(format!(
773        "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
774        index + 1,
775        tensor.shape
776    )))
777}
778
779fn axis_from_meshgrid_matrix_real(
780    tensor: &Tensor,
781    index: usize,
782) -> crate::BuiltinResult<Option<AxisData>> {
783    let (rows, cols) = match tensor.shape.as_slice() {
784        [r, c] => (*r, *c),
785        _ => return Ok(None),
786    };
787    if rows <= 1 || cols <= 1 {
788        return Ok(None);
789    }
790
791    // Index 0 is expected to be the X-axis: a meshgrid X matrix has identical rows.
792    // Index 1 is expected to be the Y-axis: a meshgrid Y matrix has identical columns.
793    let expect_rows_constant = index == 0;
794
795    if expect_rows_constant {
796        if !matrix_rows_are_identical_real(tensor, rows, cols) {
797            return Ok(None);
798        }
799        // Extract the first row as the axis vector (length = cols).
800        let mut values = Vec::with_capacity(cols);
801        for col in 0..cols {
802            let idx = rows * col;
803            values.push((tensor.data[idx], 0.0));
804        }
805        return Ok(Some(AxisData {
806            len: values.len(),
807            values,
808            is_complex: false,
809            gpu_real: None,
810        }));
811    }
812
813    if !matrix_cols_are_identical_real(tensor, rows, cols) {
814        return Ok(None);
815    }
816    // Extract the first column as the axis vector (length = rows).
817    let mut values = Vec::with_capacity(rows);
818    for row in 0..rows {
819        values.push((tensor.data[row], 0.0));
820    }
821    Ok(Some(AxisData {
822        len: values.len(),
823        values,
824        is_complex: false,
825        gpu_real: None,
826    }))
827}
828
829fn axis_from_meshgrid_matrix_complex(
830    tensor: &ComplexTensor,
831    index: usize,
832) -> crate::BuiltinResult<Option<AxisData>> {
833    let (rows, cols) = match tensor.shape.as_slice() {
834        [r, c] => (*r, *c),
835        _ => return Ok(None),
836    };
837    if rows <= 1 || cols <= 1 {
838        return Ok(None);
839    }
840
841    let expect_rows_constant = index == 0;
842    if expect_rows_constant {
843        if !matrix_rows_are_identical_complex(tensor, rows, cols) {
844            return Ok(None);
845        }
846        let mut values = Vec::with_capacity(cols);
847        for col in 0..cols {
848            let idx = rows * col;
849            values.push(tensor.data[idx]);
850        }
851        let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
852        return Ok(Some(AxisData {
853            len: values.len(),
854            values,
855            is_complex,
856            gpu_real: None,
857        }));
858    }
859
860    if !matrix_cols_are_identical_complex(tensor, rows, cols) {
861        return Ok(None);
862    }
863    let mut values = Vec::with_capacity(rows);
864    for row in 0..rows {
865        values.push(tensor.data[row]);
866    }
867    let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
868    Ok(Some(AxisData {
869        len: values.len(),
870        values,
871        is_complex,
872        gpu_real: None,
873    }))
874}
875
876fn matrix_rows_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
877    for row in 1..rows {
878        for col in 0..cols {
879            let idx0 = rows * col;
880            let idx = row + rows * col;
881            if tensor.data[idx] != tensor.data[idx0] {
882                return false;
883            }
884        }
885    }
886    true
887}
888
889fn matrix_cols_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
890    for col in 1..cols {
891        for row in 0..rows {
892            let idx0 = row;
893            let idx = row + rows * col;
894            if tensor.data[idx] != tensor.data[idx0] {
895                return false;
896            }
897        }
898    }
899    true
900}
901
902fn matrix_rows_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
903    for row in 1..rows {
904        for col in 0..cols {
905            let idx0 = rows * col;
906            let idx = row + rows * col;
907            if tensor.data[idx] != tensor.data[idx0] {
908                return false;
909            }
910        }
911    }
912    true
913}
914
915fn matrix_cols_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
916    for col in 1..cols {
917        for row in 0..rows {
918            let idx0 = row;
919            let idx = row + rows * col;
920            if tensor.data[idx] != tensor.data[idx0] {
921                return false;
922            }
923        }
924    }
925    true
926}
927
/// Returns `true` when `shape` has at most one dimension whose extent exceeds 1.
///
/// An empty shape slice qualifies, as do shapes made only of extents 0 and 1.
fn is_vector_shape(shape: &[usize]) -> bool {
    let wide_dims = shape.iter().filter(|&&extent| extent > 1).count();
    wide_dims <= 1
}
940
/// Derives the number of elements along a vector-like `shape`.
///
/// * An empty shape slice is treated as a scalar and yields `1`.
/// * A shape containing a zero extent (for example `[1, 0]` or `[0, 5]`) holds
///   no elements, so the result is `0`. Taking the plain maximum would report
///   `1` or `5` here, and that length is later used as a reshape/repmat
///   dimension for data that does not exist.
/// * Otherwise the largest extent is returned, which for a vector-like shape
///   (at most one extent above 1) is its element count.
fn vector_len_from_shape(shape: &[usize]) -> usize {
    if shape.is_empty() {
        return 1;
    }
    if shape.contains(&0) {
        // Any zero extent makes the whole array empty, whatever the other extents are.
        return 0;
    }
    shape.iter().copied().max().unwrap_or(0)
}
947
948async fn axis_to_host_async(axis: &AxisData) -> crate::BuiltinResult<AxisData> {
949    if axis.gpu_real.is_none() {
950        return Ok(axis.clone());
951    }
952    let handle = axis.gpu_real.as_ref().expect("checked gpu_real is_some");
953    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
954    // Index is only used for error messages; tensor came from a validated vector-like handle.
955    axis_from_tensor(tensor, 0)
956}
957
958fn try_meshgrid_gpu_from_vector_axes(
959    x_axis: &AxisData,
960    y_axis: &AxisData,
961    z_axis: Option<&AxisData>,
962) -> crate::BuiltinResult<Option<Vec<MeshgridOutput>>> {
963    let Some(x_handle) = x_axis.gpu_real.as_ref() else {
964        return Ok(None);
965    };
966    let Some(y_handle) = y_axis.gpu_real.as_ref() else {
967        return Ok(None);
968    };
969
970    let z_handle = match z_axis {
971        Some(axis) => match axis.gpu_real.as_ref() {
972            Some(h) => Some(h),
973            None => return Ok(None),
974        },
975        None => None,
976    };
977
978    let Some(provider) = runmat_accelerate_api::provider_for_handle(x_handle) else {
979        return Ok(None);
980    };
981    if runmat_accelerate_api::provider_for_handle(y_handle).is_none() {
982        return Ok(None);
983    }
984    if let Some(z) = z_handle {
985        if runmat_accelerate_api::provider_for_handle(z).is_none() {
986            return Ok(None);
987        }
988    }
989
990    let nx = x_axis.len;
991    let ny = y_axis.len;
992    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
993
994    // Reshape axis vectors (metadata-only) so repmat can build full grids on-device.
995    let x_row = provider
996        .reshape(x_handle, &[1, nx])
997        .map_err(|e| builtin_error(format!("meshgrid: reshape X failed: {e}")))?;
998    let y_col = provider
999        .reshape(y_handle, &[ny, 1])
1000        .map_err(|e| builtin_error(format!("meshgrid: reshape Y failed: {e}")))?;
1001
1002    let mut outputs = Vec::with_capacity(if z_handle.is_some() { 3 } else { 2 });
1003    if let Some(z) = z_handle {
1004        let x_base = provider
1005            .reshape(&x_row, &[1, nx, 1])
1006            .map_err(|e| builtin_error(format!("meshgrid: reshape X(3d) failed: {e}")))?;
1007        let y_base = provider
1008            .reshape(&y_col, &[ny, 1, 1])
1009            .map_err(|e| builtin_error(format!("meshgrid: reshape Y(3d) failed: {e}")))?;
1010
1011        let x_grid = provider
1012            .repmat(&x_base, &[ny, 1, nz])
1013            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
1014        let y_grid = provider
1015            .repmat(&y_base, &[1, nx, nz])
1016            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
1017
1018        outputs.push(MeshgridOutput::GpuReal(x_grid));
1019        outputs.push(MeshgridOutput::GpuReal(y_grid));
1020        let z_axis_row = provider
1021            .reshape(z, &[1, nz])
1022            .map_err(|e| builtin_error(format!("meshgrid: reshape Z failed: {e}")))?;
1023        let z_base = provider
1024            .reshape(&z_axis_row, &[1, 1, nz])
1025            .map_err(|e| builtin_error(format!("meshgrid: reshape Z(3d) failed: {e}")))?;
1026        let z_grid = provider
1027            .repmat(&z_base, &[ny, nx, 1])
1028            .map_err(|e| builtin_error(format!("meshgrid: repmat Z failed: {e}")))?;
1029        outputs.push(MeshgridOutput::GpuReal(z_grid));
1030    } else {
1031        let x_grid = provider
1032            .repmat(&x_row, &[ny, 1])
1033            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
1034        let y_grid = provider
1035            .repmat(&y_col, &[1, nx])
1036            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
1037        outputs.push(MeshgridOutput::GpuReal(x_grid));
1038        outputs.push(MeshgridOutput::GpuReal(y_grid));
1039    }
1040
1041    Ok(Some(outputs))
1042}
1043
1044fn normalise_axes(axes: &[AxisData]) -> (AxisData, AxisData, Option<AxisData>) {
1045    match axes.len() {
1046        1 => {
1047            let x = axes[0].clone();
1048            (x.clone(), x, None)
1049        }
1050        2 => {
1051            let x = axes[0].clone();
1052            let y = axes[1].clone();
1053            (x, y, None)
1054        }
1055        3 => {
1056            let x = axes[0].clone();
1057            let y = axes[1].clone();
1058            let z = axes[2].clone();
1059            (x, y, Some(z))
1060        }
1061        _ => unreachable!(),
1062    }
1063}
1064
1065fn build_outputs(
1066    x_axis: &AxisData,
1067    y_axis: &AxisData,
1068    z_axis: Option<&AxisData>,
1069) -> Vec<GridOutput> {
1070    let nx = x_axis.len;
1071    let ny = y_axis.len;
1072    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
1073    let total = nx * ny * nz;
1074    let mut x_data = Vec::with_capacity(total);
1075    let mut y_data = Vec::with_capacity(total);
1076    let mut z_data = z_axis.map(|_| Vec::with_capacity(total));
1077
1078    for k in 0..nz {
1079        let z_value = z_axis.map(|axis| axis.values[k]);
1080        for col in 0..nx {
1081            let x_value = x_axis.values[col];
1082            for row in 0..ny {
1083                x_data.push(x_value);
1084                y_data.push(y_axis.values[row]);
1085                if let Some(ref mut z_vec) = z_data {
1086                    z_vec.push(z_value.unwrap());
1087                }
1088            }
1089        }
1090    }
1091
1092    let mut outputs = Vec::new();
1093    let base_shape = if nz == 1 {
1094        vec![ny, nx]
1095    } else {
1096        vec![ny, nx, nz]
1097    };
1098    outputs.push(GridOutput {
1099        shape: base_shape.clone(),
1100        data: x_data,
1101    });
1102    outputs.push(GridOutput {
1103        shape: base_shape.clone(),
1104        data: y_data,
1105    });
1106    if let Some(z_vec) = z_data {
1107        outputs.push(GridOutput {
1108            shape: base_shape,
1109            data: z_vec,
1110        });
1111    }
1112    outputs
1113}
1114
/// One coordinate grid assembled on the host by `build_outputs`.
struct GridOutput {
    /// Dimensions handed to `Tensor::new` / `ComplexTensor::new` when the grid
    /// is converted into a `Value`.
    shape: Vec<usize>,
    /// Grid entries as `(re, im)` pairs, in the order `build_outputs` wrote them.
    data: Vec<(f64, f64)>,
}
1119
1120impl GridOutput {
1121    fn to_value(
1122        &self,
1123        class: PrototypeClass,
1124        residency: DevicePreference,
1125    ) -> crate::BuiltinResult<Value> {
1126        match class {
1127            PrototypeClass::Real => self.to_real_value(residency),
1128            PrototypeClass::Complex => self.to_complex_value(residency),
1129        }
1130    }
1131
1132    fn to_real_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1133        let mut real = Vec::with_capacity(self.data.len());
1134        for &(re, im) in &self.data {
1135            if im != 0.0 {
1136                return Err(builtin_error(
1137                    "meshgrid: cannot represent complex values in a real output",
1138                ));
1139            }
1140            real.push(re);
1141        }
1142        let tensor = Tensor::new(real, self.shape.clone())
1143            .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1144        match residency {
1145            DevicePreference::Host => Ok(tensor::tensor_into_value(tensor)),
1146            DevicePreference::Gpu => to_gpu_tensor_value(tensor),
1147        }
1148    }
1149
1150    fn to_complex_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1151        let tensor = ComplexTensor::new(self.data.clone(), self.shape.clone())
1152            .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1153        match residency {
1154            DevicePreference::Host => Ok(complex_tensor_into_value(tensor)),
1155            DevicePreference::Gpu => {
1156                warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
1157                Ok(complex_tensor_into_value(tensor))
1158            }
1159        }
1160    }
1161}
1162
1163fn to_gpu_tensor_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1164    if let Some(provider) = runmat_accelerate_api::provider() {
1165        let view = HostTensorView {
1166            data: &tensor.data,
1167            shape: &tensor.shape,
1168        };
1169        match provider.upload(&view) {
1170            Ok(handle) => return Ok(Value::GpuTensor(handle)),
1171            Err(err) => {
1172                warn!("meshgrid: failed to upload tensor to GPU, returning host array: {err}")
1173            }
1174        }
1175    }
1176    Ok(tensor::tensor_into_value(tensor))
1177}
1178
1179fn tensor_to_complex_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1180    let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
1181    let complex = ComplexTensor::new(data, tensor.shape.clone())
1182        .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1183    Ok(complex_tensor_into_value(complex))
1184}
1185
/// A single `meshgrid` output grid, held either on the host or as a GPU handle.
enum MeshgridOutput {
    /// Grid data stored on the host as a `GridOutput`.
    Host(GridOutput),
    /// Real-valued grid referenced through a GPU tensor handle.
    GpuReal(GpuTensorHandle),
}
1190
1191impl MeshgridOutput {
1192    async fn to_value(
1193        &self,
1194        class: PrototypeClass,
1195        residency: DevicePreference,
1196    ) -> crate::BuiltinResult<Value> {
1197        match self {
1198            MeshgridOutput::Host(host) => host.to_value(class, residency),
1199            MeshgridOutput::GpuReal(handle) => match (class, residency) {
1200                (PrototypeClass::Real, DevicePreference::Gpu) => {
1201                    Ok(Value::GpuTensor(handle.clone()))
1202                }
1203                (PrototypeClass::Real, DevicePreference::Host) => {
1204                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1205                    Ok(tensor::tensor_into_value(tensor))
1206                }
1207                (PrototypeClass::Complex, DevicePreference::Host) => {
1208                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1209                    tensor_to_complex_value(tensor)
1210                }
1211                (PrototypeClass::Complex, DevicePreference::Gpu) => {
1212                    warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
1213                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1214                    tensor_to_complex_value(tensor)
1215                }
1216            },
1217        }
1218    }
1219}
1220
/// Holds the results of a `meshgrid` evaluation so multiple outputs can be
/// materialised without recomputing the grid.
pub struct MeshgridEval {
    /// Computed grids; index 0, 1 and 2 back `first`, `second` and `third`.
    outputs: Vec<MeshgridOutput>,
    /// Class (real or complex) passed to every output conversion.
    target_class: PrototypeClass,
    /// Residency (host or GPU) passed to every output conversion.
    target_residency: DevicePreference,
}
1228
1229impl MeshgridEval {
1230    pub fn output_count(&self) -> usize {
1231        self.outputs.len()
1232    }
1233
1234    pub async fn first(&self) -> crate::BuiltinResult<Value> {
1235        self.outputs[0]
1236            .to_value(self.target_class, self.target_residency)
1237            .await
1238    }
1239
1240    pub async fn second(&self) -> crate::BuiltinResult<Value> {
1241        if self.outputs.len() < 2 {
1242            Err(builtin_error("meshgrid: second output unavailable"))
1243        } else {
1244            self.outputs[1]
1245                .to_value(self.target_class, self.target_residency)
1246                .await
1247        }
1248    }
1249
1250    pub async fn third(&self) -> crate::BuiltinResult<Value> {
1251        if self.outputs.len() < 3 {
1252            Err(builtin_error(
1253                "meshgrid: third output requested but no Z vector was supplied",
1254            ))
1255        } else {
1256            self.outputs[2]
1257                .to_value(self.target_class, self.target_residency)
1258                .await
1259        }
1260    }
1261}
1262
/// Unit tests for the `meshgrid` builtin: host grids, `'like'` prototypes,
/// GPU-resident inputs, complex inputs and static type inference.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;

    use runmat_accelerate_api::HostTensorView;

    /// Runs the async `evaluate` entry point to completion.
    fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
        block_on(super::evaluate(args))
    }

    /// Blocking wrapper around `MeshgridEval::first`.
    fn eval_first(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.first())
    }

    /// Blocking wrapper around `MeshgridEval::second`.
    fn eval_second(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.second())
    }

    /// Blocking wrapper around `MeshgridEval::third`.
    fn eval_third(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.third())
    }

    /// Builds a `rows` x `cols` tensor from `data`.
    fn tensor_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Tensor {
        Tensor::new(data, vec![rows, cols]).unwrap()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_single_input_duplicates_axis() {
        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
        let eval = evaluate(&[Value::Tensor(x)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![3, 3]);
        assert_eq!(
            x_out.data,
            vec![-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
        );
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![3, 3]);
        assert_eq!(
            y_out.data,
            vec![-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0]
        );
    }

    #[test]
    fn meshgrid_type_infers_rank_from_axis_count() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            meshgrid_type(&[Type::Num, Type::Num], &ctx),
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1)])
            }
        );
        assert_eq!(
            meshgrid_type(&[Type::Num, Type::Num, Type::Num], &ctx),
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1), Some(1)])
            }
        );
    }

    #[test]
    fn meshgrid_type_uses_vector_lengths() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            meshgrid_type(
                &[
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(201)]),
                    },
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(101)]),
                    },
                ],
                &ctx,
            ),
            Type::Tensor {
                shape: Some(vec![Some(101), Some(201)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_rectangular_inputs() {
        let x = tensor_from_vec(vec![0.0, 0.5, 1.0], 1, 3);
        let y = tensor_from_vec(vec![10.0, 20.0], 2, 1);
        let eval = evaluate(&[Value::Tensor(x), Value::Tensor(y)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![2, 3]);
        assert_eq!(x_out.data, vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0]);
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![2, 3]);
        assert_eq!(y_out.data, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_three_inputs_volume() {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let y = tensor_from_vec(vec![5.0, 6.0, 7.0], 3, 1);
        let z = tensor_from_vec(vec![0.0, 1.0], 1, 2);
        let eval =
            evaluate(&[Value::Tensor(x), Value::Tensor(y), Value::Tensor(z)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 3);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![3, 2, 2]);
        assert_eq!(
            x_out.data,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
        );
        // Y repeats the y vector down every column of every page.
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![3, 2, 2]);
        assert_eq!(
            y_out.data,
            vec![5.0, 6.0, 7.0, 5.0, 6.0, 7.0, 5.0, 6.0, 7.0, 5.0, 6.0, 7.0]
        );
        let z_out = test_support::gather(eval_third(&eval).expect("Z")).expect("host");
        assert_eq!(z_out.shape, vec![3, 2, 2]);
        assert_eq!(
            z_out.data,
            vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_keeps_gpu_residency() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
            let y = tensor_from_vec(vec![2.0, 4.0], 2, 1);
            let proto = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let proto_view = HostTensorView {
                data: &proto.data,
                shape: &proto.shape,
            };
            let proto_handle = provider.upload(&proto_view).expect("upload");
            let eval = evaluate(&[
                Value::Tensor(x),
                Value::Tensor(y),
                Value::from("like"),
                Value::GpuTensor(proto_handle),
            ])
            .expect("meshgrid");
            let x_value = eval_first(&eval).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let gathered = test_support::gather(x_value).expect("gather");
            assert_eq!(gathered.shape, vec![2, 3]);
            // Residency must not change the grid contents.
            assert_eq!(gathered.data, vec![-1.0, -1.0, 0.0, 0.0, 1.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_gpu_inputs_roundtrip() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![0.0, 0.5], 1, 2);
            let y = tensor_from_vec(vec![1.0, 2.0], 2, 1);
            let x_view = HostTensorView {
                data: &x.data,
                shape: &x.shape,
            };
            let y_view = HostTensorView {
                data: &y.data,
                shape: &y.shape,
            };
            let x_handle = provider.upload(&x_view).expect("upload");
            let y_handle = provider.upload(&y_view).expect("upload");
            let eval = evaluate(&[Value::GpuTensor(x_handle), Value::GpuTensor(y_handle)])
                .expect("meshgrid");
            let x_value = eval_first(&eval).expect("X");
            let y_value = eval_second(&eval).expect("Y");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            assert!(matches!(y_value, Value::GpuTensor(_)));

            // Complete the round trip: gather the grids and check their contents.
            let x_out = test_support::gather(x_value).expect("gather X");
            assert_eq!(x_out.shape, vec![2, 2]);
            assert_eq!(x_out.data, vec![0.0, 0.0, 0.5, 0.5]);
            let y_out = test_support::gather(y_value).expect("gather Y");
            assert_eq!(y_out.shape, vec![2, 2]);
            assert_eq!(y_out.data, vec![1.0, 2.0, 1.0, 2.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn meshgrid_wgpu_matches_cpu() {
        let provider = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .expect("wgpu provider");

        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0, 2.0], 1, 4);
        let y = tensor_from_vec(vec![5.0, 6.0], 2, 1);

        let cpu_eval =
            evaluate(&[Value::Tensor(x.clone()), Value::Tensor(y.clone())]).expect("meshgrid cpu");
        let cpu_x =
            test_support::gather(eval_first(&cpu_eval).expect("X cpu")).expect("gather X cpu");
        let cpu_y =
            test_support::gather(eval_second(&cpu_eval).expect("Y cpu")).expect("gather Y cpu");

        let x_view = HostTensorView {
            data: &x.data,
            shape: &x.shape,
        };
        let y_view = HostTensorView {
            data: &y.data,
            shape: &y.shape,
        };
        let x_gpu = provider.upload(&x_view).expect("upload x");
        let y_gpu = provider.upload(&y_view).expect("upload y");

        let gpu_eval =
            evaluate(&[Value::GpuTensor(x_gpu), Value::GpuTensor(y_gpu)]).expect("meshgrid gpu");
        let gpu_x_value = eval_first(&gpu_eval).expect("X gpu");
        let gpu_y_value = eval_second(&gpu_eval).expect("Y gpu");

        assert!(matches!(gpu_x_value, Value::GpuTensor(_)));
        assert!(matches!(gpu_y_value, Value::GpuTensor(_)));

        let gathered_x = test_support::gather(gpu_x_value).expect("gather X gpu");
        let gathered_y = test_support::gather(gpu_y_value).expect("gather Y gpu");

        assert_eq!(gathered_x.shape, cpu_x.shape);
        assert_eq!(gathered_x.data, cpu_x.data);
        assert_eq!(gathered_y.shape, cpu_y.shape);
        assert_eq!(gathered_y.data, cpu_y.data);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_complex_inputs_produce_complex_outputs() {
        let complex = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
        let eval = evaluate(&[Value::ComplexTensor(complex)]).expect("meshgrid");
        let x_value = eval_first(&eval).expect("X");
        match x_value {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 2]);
            }
            Value::Complex(_, _) => {}
            other => panic!("expected complex output, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_host_prototype() {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let eval =
            evaluate(&[Value::Tensor(x), Value::from("like"), Value::Num(0.0)]).expect("meshgrid");
        let x_out = eval_first(&eval).expect("X");
        assert!(matches!(x_out, Value::Tensor(_) | Value::Num(_)));
    }
}