Skip to main content

runmat_runtime/builtins/array/creation/
meshgrid.rs

1//! MATLAB-compatible `meshgrid` builtin with GPU-aware semantics.
2
3use std::cmp::max;
4
5use log::warn;
6use runmat_accelerate_api::{GpuTensorHandle, HostTensorView};
7use runmat_builtins::{
8    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
9    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
10    ComplexTensor, ResolveContext, Tensor, Type, Value,
11};
12
13use crate::builtins::array::type_resolvers::size_vector_len;
14use runmat_macros::runtime_builtin;
15
16use crate::build_runtime_error;
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
19use crate::builtins::common::residency::{sequence_gpu_preference, SequenceIntent};
20use crate::builtins::common::spec::{
21    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
23};
24use crate::builtins::common::tensor;
25
/// GPU dispatch metadata for `meshgrid`, registered through `register_gpu_spec`.
///
/// Declares a custom `"array_construct"` op with a provider hook named
/// `"meshgrid"`, F32/F64 precision support, and results placed in a new handle
/// (`ResidencyPolicy::NewHandle`). Per `notes`, grids are built on the host and
/// uploaded when no dedicated provider hook is available.
26#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
27pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
28    name: "meshgrid",
29    op_kind: GpuOpKind::Custom("array_construct"),
30    supported_precisions: &[ScalarType::F32, ScalarType::F64],
31    broadcast: BroadcastSemantics::Matlab,
32    provider_hooks: &[ProviderHook::Custom("meshgrid")],
33    constant_strategy: ConstantStrategy::InlineLiteral,
34    residency: ResidencyPolicy::NewHandle,
35    nan_mode: ReductionNaN::Include,
36    two_pass_threshold: None,
37    workgroup_size: None,
38    accepts_nan_mode: false,
39    notes: "Providers may supply a dedicated meshgrid hook; until then the runtime builds grids on the host and uploads them when GPU residency is requested.",
40};
41
42fn builtin_error(message: impl Into<String>) -> crate::RuntimeError {
43    build_runtime_error(message)
44        .with_builtin("meshgrid")
45        .build()
46}
47
/// Fusion metadata for `meshgrid`, registered through `register_fusion_spec`.
///
/// No elementwise or reduction fusion entry is provided: as stated in `notes`,
/// meshgrid materialises dense coordinate arrays and bypasses fusion.
48#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
49pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
50    name: "meshgrid",
51    shape: ShapeRequirements::Any,
52    constant_strategy: ConstantStrategy::InlineLiteral,
53    elementwise: None,
54    reduction: None,
55    emits_nan: false,
56    notes:
57        "Meshgrid explicitly materialises dense coordinate arrays and therefore bypasses fusion.",
58};
59
60fn meshgrid_type(args: &[Type], _context: &ResolveContext) -> Type {
61    if args.is_empty() {
62        return Type::Unknown;
63    }
64    let mut axis_count = args.len();
65    if axis_count >= 2 && matches!(args[axis_count - 2], Type::String) {
66        axis_count = axis_count.saturating_sub(2);
67    }
68    if axis_count == 0 {
69        return Type::Unknown;
70    }
71    let axis_args = &args[..axis_count];
72    let len_x = axis_args.get(0).and_then(size_vector_len);
73    let len_y = axis_args.get(1).and_then(size_vector_len).or(len_x);
74    let len_z = axis_args.get(2).and_then(size_vector_len);
75    let shape = if axis_count >= 3 {
76        vec![len_y, len_x, len_z]
77    } else {
78        vec![len_y, len_x]
79    };
80    Type::Tensor { shape: Some(shape) }
81}
82
/// Output descriptors (`X`, `Y`) shared by every two-output signature.
83const MESHGRID_OUTPUT_XY: [BuiltinParamDescriptor; 2] = [
84    BuiltinParamDescriptor {
85        name: "X",
86        ty: BuiltinParamType::NumericArray,
87        arity: BuiltinParamArity::Required,
88        default: None,
89        description: "Grid coordinates along X-axis.",
90    },
91    BuiltinParamDescriptor {
92        name: "Y",
93        ty: BuiltinParamType::NumericArray,
94        arity: BuiltinParamArity::Required,
95        default: None,
96        description: "Grid coordinates along Y-axis.",
97    },
98];
99
/// Output descriptors (`X`, `Y`, `Z`) for the three-axis signatures; `Z` is
/// declared with optional arity.
100const MESHGRID_OUTPUT_XYZ: [BuiltinParamDescriptor; 3] = [
101    BuiltinParamDescriptor {
102        name: "X",
103        ty: BuiltinParamType::NumericArray,
104        arity: BuiltinParamArity::Required,
105        default: None,
106        description: "Grid coordinates along X-axis.",
107    },
108    BuiltinParamDescriptor {
109        name: "Y",
110        ty: BuiltinParamType::NumericArray,
111        arity: BuiltinParamArity::Required,
112        default: None,
113        description: "Grid coordinates along Y-axis.",
114    },
115    BuiltinParamDescriptor {
116        name: "Z",
117        ty: BuiltinParamType::NumericArray,
118        arity: BuiltinParamArity::Optional,
119        default: None,
120        description: "Grid coordinates along Z-axis.",
121    },
122];
123
/// Input descriptors for `meshgrid(x)`.
124const MESHGRID_SIG_X_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
125    name: "x",
126    ty: BuiltinParamType::NumericArray,
127    arity: BuiltinParamArity::Required,
128    default: None,
129    description: "X-axis vector.",
130}];
131
/// Input descriptors for `meshgrid(x, y)`.
132const MESHGRID_SIG_XY_INPUTS: [BuiltinParamDescriptor; 2] = [
133    BuiltinParamDescriptor {
134        name: "x",
135        ty: BuiltinParamType::NumericArray,
136        arity: BuiltinParamArity::Required,
137        default: None,
138        description: "X-axis vector.",
139    },
140    BuiltinParamDescriptor {
141        name: "y",
142        ty: BuiltinParamType::NumericArray,
143        arity: BuiltinParamArity::Required,
144        default: None,
145        description: "Y-axis vector.",
146    },
147];
148
/// Input descriptors for `meshgrid(x, y, z)`; `z` is declared with optional arity.
149const MESHGRID_SIG_XYZ_INPUTS: [BuiltinParamDescriptor; 3] = [
150    BuiltinParamDescriptor {
151        name: "x",
152        ty: BuiltinParamType::NumericArray,
153        arity: BuiltinParamArity::Required,
154        default: None,
155        description: "X-axis vector.",
156    },
157    BuiltinParamDescriptor {
158        name: "y",
159        ty: BuiltinParamType::NumericArray,
160        arity: BuiltinParamArity::Required,
161        default: None,
162        description: "Y-axis vector.",
163    },
164    BuiltinParamDescriptor {
165        name: "z",
166        ty: BuiltinParamType::NumericArray,
167        arity: BuiltinParamArity::Optional,
168        default: None,
169        description: "Z-axis vector.",
170    },
171];
172
/// Input descriptors for `meshgrid(x, "like", prototype)`.
173const MESHGRID_SIG_X_LIKE_INPUTS: [BuiltinParamDescriptor; 3] = [
174    BuiltinParamDescriptor {
175        name: "x",
176        ty: BuiltinParamType::NumericArray,
177        arity: BuiltinParamArity::Required,
178        default: None,
179        description: "X-axis vector.",
180    },
181    BuiltinParamDescriptor {
182        name: "like_kw",
183        ty: BuiltinParamType::StringScalar,
184        arity: BuiltinParamArity::Required,
185        default: Some("\"like\""),
186        description: "Like keyword.",
187    },
188    BuiltinParamDescriptor {
189        name: "prototype",
190        ty: BuiltinParamType::LikePrototype,
191        arity: BuiltinParamArity::Required,
192        default: None,
193        description: "Prototype controlling class/device residency.",
194    },
195];
196
/// Input descriptors for `meshgrid(x, y, "like", prototype)`.
197const MESHGRID_SIG_XY_LIKE_INPUTS: [BuiltinParamDescriptor; 4] = [
198    BuiltinParamDescriptor {
199        name: "x",
200        ty: BuiltinParamType::NumericArray,
201        arity: BuiltinParamArity::Required,
202        default: None,
203        description: "X-axis vector.",
204    },
205    BuiltinParamDescriptor {
206        name: "y",
207        ty: BuiltinParamType::NumericArray,
208        arity: BuiltinParamArity::Required,
209        default: None,
210        description: "Y-axis vector.",
211    },
212    BuiltinParamDescriptor {
213        name: "like_kw",
214        ty: BuiltinParamType::StringScalar,
215        arity: BuiltinParamArity::Required,
216        default: Some("\"like\""),
217        description: "Like keyword.",
218    },
219    BuiltinParamDescriptor {
220        name: "prototype",
221        ty: BuiltinParamType::LikePrototype,
222        arity: BuiltinParamArity::Required,
223        default: None,
224        description: "Prototype controlling class/device residency.",
225    },
226];
227
/// Input descriptors for `meshgrid(x, y, z, "like", prototype)`; `z` is
/// declared with optional arity.
228const MESHGRID_SIG_XYZ_LIKE_INPUTS: [BuiltinParamDescriptor; 5] = [
229    BuiltinParamDescriptor {
230        name: "x",
231        ty: BuiltinParamType::NumericArray,
232        arity: BuiltinParamArity::Required,
233        default: None,
234        description: "X-axis vector.",
235    },
236    BuiltinParamDescriptor {
237        name: "y",
238        ty: BuiltinParamType::NumericArray,
239        arity: BuiltinParamArity::Required,
240        default: None,
241        description: "Y-axis vector.",
242    },
243    BuiltinParamDescriptor {
244        name: "z",
245        ty: BuiltinParamType::NumericArray,
246        arity: BuiltinParamArity::Optional,
247        default: None,
248        description: "Z-axis vector.",
249    },
250    BuiltinParamDescriptor {
251        name: "like_kw",
252        ty: BuiltinParamType::StringScalar,
253        arity: BuiltinParamArity::Required,
254        default: Some("\"like\""),
255        description: "Like keyword.",
256    },
257    BuiltinParamDescriptor {
258        name: "prototype",
259        ty: BuiltinParamType::LikePrototype,
260        arity: BuiltinParamArity::Required,
261        default: None,
262        description: "Prototype controlling class/device residency.",
263    },
264];
265
/// All documented call forms of `meshgrid`: one, two, or three axis vectors,
/// each with and without a trailing `"like", prototype` pair.
266const MESHGRID_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
267    BuiltinSignatureDescriptor {
268        label: "[X,Y] = meshgrid(x)",
269        inputs: &MESHGRID_SIG_X_INPUTS,
270        outputs: &MESHGRID_OUTPUT_XY,
271    },
272    BuiltinSignatureDescriptor {
273        label: "[X,Y] = meshgrid(x, y)",
274        inputs: &MESHGRID_SIG_XY_INPUTS,
275        outputs: &MESHGRID_OUTPUT_XY,
276    },
277    BuiltinSignatureDescriptor {
278        label: "[X,Y,Z] = meshgrid(x, y, z)",
279        inputs: &MESHGRID_SIG_XYZ_INPUTS,
280        outputs: &MESHGRID_OUTPUT_XYZ,
281    },
282    BuiltinSignatureDescriptor {
283        label: "[X,Y] = meshgrid(x, \"like\", prototype)",
284        inputs: &MESHGRID_SIG_X_LIKE_INPUTS,
285        outputs: &MESHGRID_OUTPUT_XY,
286    },
287    BuiltinSignatureDescriptor {
288        label: "[X,Y] = meshgrid(x, y, \"like\", prototype)",
289        inputs: &MESHGRID_SIG_XY_LIKE_INPUTS,
290        outputs: &MESHGRID_OUTPUT_XY,
291    },
292    BuiltinSignatureDescriptor {
293        label: "[X,Y,Z] = meshgrid(x, y, z, \"like\", prototype)",
294        inputs: &MESHGRID_SIG_XYZ_LIKE_INPUTS,
295        outputs: &MESHGRID_OUTPUT_XYZ,
296    },
297];
298
/// Catalogue of documented `meshgrid` error conditions (code, trigger, message).
///
/// NOTE(review): several messages raised at runtime in this file are worded
/// differently from the `message` entries here (for example the output-count
/// errors in `meshgrid_builtin` and the "'like' must follow at least one input
/// vector" error in `ParsedMeshgrid::parse`) — confirm whether they are meant
/// to match.
299const MESHGRID_ERRORS: [BuiltinErrorDescriptor; 11] = [
300    BuiltinErrorDescriptor {
301        code: "RM.MESHGRID.MISSING_AXIS",
302        identifier: None,
303        when: "No axis vectors are provided.",
304        message: "meshgrid: at least one input vector is required",
305    },
306    BuiltinErrorDescriptor {
307        code: "RM.MESHGRID.TOO_MANY_AXES",
308        identifier: None,
309        when: "More than three axis vectors are provided.",
310        message: "meshgrid: expected at most three input vectors",
311    },
312    BuiltinErrorDescriptor {
313        code: "RM.MESHGRID.LIKE_EXPECTED_PROTOTYPE",
314        identifier: None,
315        when: "The 'like' keyword is provided without a prototype argument.",
316        message: "meshgrid: expected prototype after 'like'",
317    },
318    BuiltinErrorDescriptor {
319        code: "RM.MESHGRID.MULTIPLE_LIKE",
320        identifier: None,
321        when: "The 'like' keyword is provided multiple times.",
322        message: "meshgrid: multiple 'like' specifications are not supported",
323    },
324    BuiltinErrorDescriptor {
325        code: "RM.MESHGRID.LIKE_POSITION",
326        identifier: None,
327        when: "The 'like' keyword is in an invalid position or not final.",
328        message: "meshgrid: 'like' must be the final argument",
329    },
330    BuiltinErrorDescriptor {
331        code: "RM.MESHGRID.UNRECOGNIZED_OPTION",
332        identifier: None,
333        when: "A trailing option string is not recognized.",
334        message: "meshgrid: unrecognised option",
335    },
336    BuiltinErrorDescriptor {
337        code: "RM.MESHGRID.INVALID_AXIS_INPUT",
338        identifier: None,
339        when: "Axis inputs are non-numeric or non-vector shapes.",
340        message: "meshgrid: input argument must be numeric vector data",
341    },
342    BuiltinErrorDescriptor {
343        code: "RM.MESHGRID.INVALID_PROTOTYPE",
344        identifier: None,
345        when: "The 'like' prototype is unsupported.",
346        message: "meshgrid: prototypes must be numeric arrays",
347    },
348    BuiltinErrorDescriptor {
349        code: "RM.MESHGRID.OUTPUT_COUNT_EXCEEDED",
350        identifier: None,
351        when: "Requested outputs exceed available outputs for provided axes.",
352        message:
353            "meshgrid: supports at most two outputs for 2-axis inputs and three for 3-axis inputs",
354    },
355    BuiltinErrorDescriptor {
356        code: "RM.MESHGRID.THIRD_OUTPUT_UNAVAILABLE",
357        identifier: None,
358        when: "A third output is requested without supplying a Z-axis vector.",
359        message: "meshgrid: third output requested but no Z vector was supplied",
360    },
361    BuiltinErrorDescriptor {
362        code: "RM.MESHGRID.COMPLEX_REAL_CONVERSION",
363        identifier: None,
364        when: "Complex axis values cannot be represented in requested real output class.",
365        message: "meshgrid: cannot represent complex values in a real output",
366    },
367];
368
/// Public descriptor bundling the `meshgrid` signatures and error catalogue.
///
/// Outputs are produced according to the number of outputs the caller
/// requested (`BuiltinOutputMode::ByRequestedOutputCount`).
369pub const MESHGRID_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
370    signatures: &MESHGRID_SIGNATURES,
371    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
372    completion_policy: BuiltinCompletionPolicy::Public,
373    errors: &MESHGRID_ERRORS,
374};
375
376#[runtime_builtin(
377    name = "meshgrid",
378    category = "array/creation",
379    summary = "Generate coordinate matrices for 2-D and 3-D grids.",
380    keywords = "meshgrid,grid,gpu,like,3d",
381    accel = "array_construct",
382    type_resolver(meshgrid_type),
383    descriptor(crate::builtins::array::creation::meshgrid::MESHGRID_DESCRIPTOR),
384    builtin_path = "crate::builtins::array::creation::meshgrid"
385)]
386async fn meshgrid_builtin(rest: Vec<Value>) -> crate::BuiltinResult<Value> {
387    let eval = evaluate(&rest).await?;
388    if let Some(out_count) = crate::output_count::current_output_count() {
389        if out_count == 0 {
390            return Ok(Value::OutputList(Vec::new()));
391        }
392        let available = eval.output_count();
393        if out_count > available {
394            let msg = if available == 2 {
395                "meshgrid with two inputs supports at most two outputs"
396            } else {
397                "meshgrid supports at most three outputs"
398            };
399            return Err(builtin_error(msg));
400        }
401        let mut outputs = Vec::with_capacity(out_count);
402        let first = eval.first().await?;
403        outputs.push(first);
404        if out_count >= 2 {
405            outputs.push(eval.second().await?);
406        }
407        if out_count >= 3 {
408            outputs.push(eval.third().await?);
409        }
410        return Ok(Value::OutputList(outputs));
411    }
412    eval.first().await
413}
414
415/// Evaluate the `meshgrid` builtin once and reuse the result for multiple outputs.
416pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
417    let parsed = ParsedMeshgrid::parse(args).await?;
418    let (x_axis, y_axis, z_axis) = normalise_axes(&parsed.axes);
419
420    let require_complex = parsed.axes.iter().any(|axis| axis.is_complex);
421
422    let target_class = match &parsed.template {
423        OutputTemplate::Default => {
424            if require_complex {
425                PrototypeClass::Complex
426            } else {
427                PrototypeClass::Real
428            }
429        }
430        OutputTemplate::Like(spec) => {
431            if require_complex {
432                PrototypeClass::Complex
433            } else {
434                spec.class
435            }
436        }
437    };
438
439    let target_residency = match &parsed.template {
440        OutputTemplate::Default => {
441            if parsed.prefer_gpu {
442                DevicePreference::Gpu
443            } else {
444                DevicePreference::Host
445            }
446        }
447        OutputTemplate::Like(spec) => spec.residency,
448    };
449
450    let axes_all_real = !require_complex;
451    let mut outputs: Vec<MeshgridOutput> = Vec::new();
452
453    if axes_all_real
454        && matches!(target_class, PrototypeClass::Real)
455        && matches!(target_residency, DevicePreference::Gpu)
456    {
457        if let Some(gpu) = try_meshgrid_gpu_from_vector_axes(&x_axis, &y_axis, z_axis.as_ref())? {
458            outputs = gpu;
459        }
460    }
461
462    if outputs.is_empty() {
463        // Host fallback: ensure we have host axis values materialized.
464        let x_host = axis_to_host_async(&x_axis).await?;
465        let y_host = axis_to_host_async(&y_axis).await?;
466        let z_host = match z_axis.as_ref() {
467            Some(axis) => Some(axis_to_host_async(axis).await?),
468            None => None,
469        };
470        outputs = build_outputs(&x_host, &y_host, z_host.as_ref())
471            .into_iter()
472            .map(MeshgridOutput::Host)
473            .collect();
474    }
475
476    Ok(MeshgridEval {
477        outputs,
478        target_class,
479        target_residency,
480    })
481}
482
/// Result of parsing the raw `meshgrid` argument list.
483#[derive(Clone)]
484struct ParsedMeshgrid {
    /// Axis inputs in argument order (one to three entries after parsing).
485    axes: Vec<AxisData>,
    /// Output template: default behaviour or a `"like"` prototype spec.
486    template: OutputTemplate,
    /// Whether GPU residency is preferred when no prototype dictates it.
487    prefer_gpu: bool,
488}
489
490impl ParsedMeshgrid {
491    async fn parse(args: &[Value]) -> crate::BuiltinResult<Self> {
492        if args.is_empty() {
493            return Err(builtin_error(
494                "meshgrid: at least one input vector is required",
495            ));
496        }
497        let mut axis_values: Vec<Value> = Vec::new();
498        let mut like_proto: Option<Value> = None;
499        let mut prefer_gpu = false;
500        let mut idx = 0;
501        while idx < args.len() {
502            let value = args[idx].clone();
503            if let Some(keyword) = keyword_of(&value) {
504                match keyword.as_str() {
505                    "like" => {
506                        if like_proto.is_some() {
507                            return Err(builtin_error(
508                                "meshgrid: multiple 'like' specifications are not supported",
509                            ));
510                        }
511                        if axis_values.is_empty() {
512                            return Err(builtin_error(
513                                "meshgrid: 'like' must follow at least one input vector",
514                            ));
515                        }
516                        let Some(proto) = args.get(idx + 1).cloned() else {
517                            return Err(builtin_error("meshgrid: expected prototype after 'like'"));
518                        };
519                        like_proto = Some(proto);
520                        idx += 2;
521                        if idx < args.len() {
522                            return Err(builtin_error(
523                                "meshgrid: 'like' must be the final argument",
524                            ));
525                        }
526                        break;
527                    }
528                    other => {
529                        return Err(builtin_error(format!(
530                            "meshgrid: unrecognised option '{other}'"
531                        )));
532                    }
533                }
534            }
535
536            if let Value::GpuTensor(_) = value {
537                prefer_gpu = true;
538            }
539            axis_values.push(value);
540            idx += 1;
541        }
542
543        if axis_values.is_empty() {
544            return Err(builtin_error(
545                "meshgrid: at least one input vector is required",
546            ));
547        }
548        if axis_values.len() > 3 {
549            return Err(builtin_error(
550                "meshgrid: expected at most three input vectors",
551            ));
552        }
553
554        let mut axes = Vec::with_capacity(max(axis_values.len(), 2));
555        for (i, value) in axis_values.into_iter().enumerate() {
556            let mut consumed_gpu = false;
557            let data = axis_from_value(value, i, &mut consumed_gpu).await?;
558            if consumed_gpu {
559                prefer_gpu = true;
560            }
561            axes.push(data);
562        }
563
564        if !prefer_gpu {
565            if let Some(max_len) = axes.iter().map(|axis| axis.len).max() {
566                if max_len > 0
567                    && sequence_gpu_preference(max_len, SequenceIntent::MeshAxis, false).prefer_gpu
568                {
569                    prefer_gpu = true;
570                }
571            }
572        }
573
574        let template = if let Some(proto) = like_proto {
575            OutputTemplate::Like(analyse_like_prototype(&proto)?)
576        } else {
577            OutputTemplate::Default
578        };
579
580        Ok(Self {
581            axes,
582            template,
583            prefer_gpu,
584        })
585    }
586}
587
/// How the output class and residency are chosen.
588#[derive(Clone)]
589enum OutputTemplate {
    /// No `"like"` clause: derive class and residency from the inputs.
590    Default,
    /// A `"like"` prototype was supplied; carries its analysed spec.
591    Like(PrototypeSpec),
592}
593
/// Properties extracted from a `"like"` prototype value.
594#[derive(Clone)]
595struct PrototypeSpec {
    /// Where outputs should live (host or GPU).
596    residency: DevicePreference,
    /// Whether outputs should be real or complex.
597    class: PrototypeClass,
598}
599
/// Numeric class of the generated grids.
600#[derive(Clone, Copy, PartialEq, Eq)]
601enum PrototypeClass {
    /// Real-valued output.
602    Real,
    /// Complex-valued output.
603    Complex,
604}
605
/// Requested residency of the generated grids.
606#[derive(Clone, Copy)]
607enum DevicePreference {
    /// Keep outputs on the host.
608    Host,
    /// Place outputs on the GPU.
609    Gpu,
610}
611
612fn analyse_like_prototype(proto: &Value) -> crate::BuiltinResult<PrototypeSpec> {
613    match proto {
614        Value::GpuTensor(_) => Ok(PrototypeSpec {
615            residency: DevicePreference::Gpu,
616            class: PrototypeClass::Real,
617        }),
618        Value::ComplexTensor(_) | Value::Complex(_, _) => Ok(PrototypeSpec {
619            residency: DevicePreference::Host,
620            class: PrototypeClass::Complex,
621        }),
622        Value::Tensor(_)
623        | Value::Num(_)
624        | Value::Int(_)
625        | Value::Bool(_)
626        | Value::LogicalArray(_) => Ok(PrototypeSpec {
627            residency: DevicePreference::Host,
628            class: PrototypeClass::Real,
629        }),
630        Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(builtin_error(
631            "meshgrid: prototypes must be numeric or gpuArray values",
632        )),
633        Value::Cell(_)
634        | Value::Struct(_)
635        | Value::Object(_)
636        | Value::HandleObject(_)
637        | Value::Listener(_)
638        | Value::FunctionHandle(_)
639        | Value::ExternalFunctionHandle(_)
640        | Value::MethodFunctionHandle(_)
641        | Value::BoundFunctionHandle { .. }
642        | Value::Closure(_)
643        | Value::ClassRef(_)
644        | Value::MException(_)
645        | Value::OutputList(_) => Err(builtin_error("meshgrid: prototypes must be numeric arrays")),
646    }
647}
648
/// One axis input of `meshgrid`, either materialised on the host or kept on
/// the device.
649#[derive(Clone)]
650struct AxisData {
    /// Axis samples as `(real, imaginary)` pairs; left empty when the axis is
    /// held only as a GPU handle.
651    values: Vec<(f64, f64)>,
    /// Number of samples along the axis.
652    len: usize,
    /// Whether the axis is treated as complex-valued.
653    is_complex: bool,
    /// Device handle for a vector-shaped `gpuArray` input, kept to avoid a
    /// download; `None` for host data.
654    gpu_real: Option<GpuTensorHandle>,
655}
656
657async fn axis_from_value(
658    value: Value,
659    index: usize,
660    prefer_gpu: &mut bool,
661) -> crate::BuiltinResult<AxisData> {
662    match value {
663        Value::Tensor(tensor) => axis_from_tensor(tensor, index),
664        Value::LogicalArray(logical) => {
665            let tensor = tensor::logical_to_tensor(&logical)?;
666            axis_from_tensor(tensor, index)
667        }
668        Value::Num(n) => Ok(AxisData {
669            values: vec![(n, 0.0)],
670            len: 1,
671            is_complex: false,
672            gpu_real: None,
673        }),
674        Value::Int(i) => {
675            let val = i.to_f64();
676            Ok(AxisData {
677                values: vec![(val, 0.0)],
678                len: 1,
679                is_complex: false,
680                gpu_real: None,
681            })
682        }
683        Value::Bool(b) => Ok(AxisData {
684            values: vec![(if b { 1.0 } else { 0.0 }, 0.0)],
685            len: 1,
686            is_complex: false,
687            gpu_real: None,
688        }),
689        Value::Complex(re, im) => Ok(AxisData {
690            values: vec![(re, im)],
691            len: 1,
692            is_complex: im != 0.0,
693            gpu_real: None,
694        }),
695        Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, index),
696        Value::GpuTensor(handle) => {
697            // Fast path: if the gpuArray is vector-like, keep it on-device and avoid a download.
698            // We'll validate any non-vector shapes by gathering below.
699            if is_vector_shape(&handle.shape) {
700                *prefer_gpu = true;
701                return Ok(AxisData {
702                    values: Vec::new(),
703                    len: vector_len_from_shape(&handle.shape),
704                    is_complex: false,
705                    gpu_real: Some(handle),
706                });
707            }
708
709            // Fallback: gather to validate / recover axes from meshgrid matrices.
710            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
711            if is_vector_shape(&tensor.shape) {
712                *prefer_gpu = true;
713            }
714            axis_from_tensor(tensor, index)
715        }
716        other => Err(builtin_error(format!(
717            "meshgrid: input argument {} must be numeric, got {other:?}",
718            index + 1
719        ))),
720    }
721}
722
723fn axis_from_tensor(tensor: Tensor, index: usize) -> crate::BuiltinResult<AxisData> {
724    if is_vector_shape(&tensor.shape) {
725        let mut values = Vec::with_capacity(tensor.data.len());
726        for &v in &tensor.data {
727            values.push((v, 0.0));
728        }
729        return Ok(AxisData {
730            len: values.len(),
731            values,
732            is_complex: false,
733            gpu_real: None,
734        });
735    }
736
737    // Be slightly more permissive than MATLAB: if the input is already a meshgrid-style
738    // coordinate matrix, accept it and recover the original axis vector.
739    //
740    // This is a pragmatic compatibility shim for cases where callers already have
741    // coordinate matrices (X/Y) and pass them through `meshgrid` again.
742    if let Some(axis) = axis_from_meshgrid_matrix_real(&tensor, index)? {
743        return Ok(axis);
744    }
745
746    Err(builtin_error(format!(
747        "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
748        index + 1,
749        tensor.shape
750    )))
751}
752
753fn axis_from_complex_tensor(tensor: ComplexTensor, index: usize) -> crate::BuiltinResult<AxisData> {
754    if is_vector_shape(&tensor.shape) {
755        let is_complex = tensor
756            .data
757            .iter()
758            .any(|&(_, imag)| !imag.is_nan() && imag != 0.0);
759        return Ok(AxisData {
760            len: tensor.data.len(),
761            values: tensor.data,
762            is_complex,
763            gpu_real: None,
764        });
765    }
766
767    if let Some(axis) = axis_from_meshgrid_matrix_complex(&tensor, index)? {
768        return Ok(axis);
769    }
770
771    Err(builtin_error(format!(
772        "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
773        index + 1,
774        tensor.shape
775    )))
776}
777
778fn axis_from_meshgrid_matrix_real(
779    tensor: &Tensor,
780    index: usize,
781) -> crate::BuiltinResult<Option<AxisData>> {
782    let (rows, cols) = match tensor.shape.as_slice() {
783        [r, c] => (*r, *c),
784        _ => return Ok(None),
785    };
786    if rows <= 1 || cols <= 1 {
787        return Ok(None);
788    }
789
790    // Index 0 is expected to be the X-axis: a meshgrid X matrix has identical rows.
791    // Index 1 is expected to be the Y-axis: a meshgrid Y matrix has identical columns.
792    let expect_rows_constant = index == 0;
793
794    if expect_rows_constant {
795        if !matrix_rows_are_identical_real(tensor, rows, cols) {
796            return Ok(None);
797        }
798        // Extract the first row as the axis vector (length = cols).
799        let mut values = Vec::with_capacity(cols);
800        for col in 0..cols {
801            let idx = rows * col;
802            values.push((tensor.data[idx], 0.0));
803        }
804        return Ok(Some(AxisData {
805            len: values.len(),
806            values,
807            is_complex: false,
808            gpu_real: None,
809        }));
810    }
811
812    if !matrix_cols_are_identical_real(tensor, rows, cols) {
813        return Ok(None);
814    }
815    // Extract the first column as the axis vector (length = rows).
816    let mut values = Vec::with_capacity(rows);
817    for row in 0..rows {
818        values.push((tensor.data[row], 0.0));
819    }
820    Ok(Some(AxisData {
821        len: values.len(),
822        values,
823        is_complex: false,
824        gpu_real: None,
825    }))
826}
827
828fn axis_from_meshgrid_matrix_complex(
829    tensor: &ComplexTensor,
830    index: usize,
831) -> crate::BuiltinResult<Option<AxisData>> {
832    let (rows, cols) = match tensor.shape.as_slice() {
833        [r, c] => (*r, *c),
834        _ => return Ok(None),
835    };
836    if rows <= 1 || cols <= 1 {
837        return Ok(None);
838    }
839
840    let expect_rows_constant = index == 0;
841    if expect_rows_constant {
842        if !matrix_rows_are_identical_complex(tensor, rows, cols) {
843            return Ok(None);
844        }
845        let mut values = Vec::with_capacity(cols);
846        for col in 0..cols {
847            let idx = rows * col;
848            values.push(tensor.data[idx]);
849        }
850        let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
851        return Ok(Some(AxisData {
852            len: values.len(),
853            values,
854            is_complex,
855            gpu_real: None,
856        }));
857    }
858
859    if !matrix_cols_are_identical_complex(tensor, rows, cols) {
860        return Ok(None);
861    }
862    let mut values = Vec::with_capacity(rows);
863    for row in 0..rows {
864        values.push(tensor.data[row]);
865    }
866    let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
867    Ok(Some(AxisData {
868        len: values.len(),
869        values,
870        is_complex,
871        gpu_real: None,
872    }))
873}
874
875fn matrix_rows_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
876    for row in 1..rows {
877        for col in 0..cols {
878            let idx0 = rows * col;
879            let idx = row + rows * col;
880            if tensor.data[idx] != tensor.data[idx0] {
881                return false;
882            }
883        }
884    }
885    true
886}
887
888fn matrix_cols_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
889    for col in 1..cols {
890        for row in 0..rows {
891            let idx0 = row;
892            let idx = row + rows * col;
893            if tensor.data[idx] != tensor.data[idx0] {
894                return false;
895            }
896        }
897    }
898    true
899}
900
901fn matrix_rows_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
902    for row in 1..rows {
903        for col in 0..cols {
904            let idx0 = rows * col;
905            let idx = row + rows * col;
906            if tensor.data[idx] != tensor.data[idx0] {
907                return false;
908            }
909        }
910    }
911    true
912}
913
914fn matrix_cols_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
915    for col in 1..cols {
916        for row in 0..rows {
917            let idx0 = row;
918            let idx = row + rows * col;
919            if tensor.data[idx] != tensor.data[idx0] {
920                return false;
921            }
922        }
923    }
924    true
925}
926
/// Returns `true` when `shape` has at most one dimension larger than 1.
///
/// An empty shape qualifies, as do shapes made up only of 0- and 1-sized
/// dimensions.
fn is_vector_shape(shape: &[usize]) -> bool {
    let non_singleton_dims = shape.iter().filter(|&&extent| extent > 1).count();
    non_singleton_dims <= 1
}
939
/// Returns the element count of a vector-like `shape`.
///
/// * An empty shape is treated as a scalar and has length 1.
/// * A shape containing a zero-sized dimension (for example `[1, 0]`) holds no
///   elements, so its length is 0. This mirrors MATLAB's `length`, which is 0
///   for empty arrays, and avoids reporting one element for an empty vector.
/// * Otherwise the length is the largest dimension.
fn vector_len_from_shape(shape: &[usize]) -> usize {
    if shape.is_empty() {
        return 1;
    }
    // Any zero extent means the array is empty, regardless of the other dims.
    if shape.contains(&0) {
        return 0;
    }
    shape.iter().copied().max().unwrap_or(0)
}
946
947async fn axis_to_host_async(axis: &AxisData) -> crate::BuiltinResult<AxisData> {
948    if axis.gpu_real.is_none() {
949        return Ok(axis.clone());
950    }
951    let handle = axis.gpu_real.as_ref().expect("checked gpu_real is_some");
952    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
953    // Index is only used for error messages; tensor came from a validated vector-like handle.
954    axis_from_tensor(tensor, 0)
955}
956
957fn try_meshgrid_gpu_from_vector_axes(
958    x_axis: &AxisData,
959    y_axis: &AxisData,
960    z_axis: Option<&AxisData>,
961) -> crate::BuiltinResult<Option<Vec<MeshgridOutput>>> {
962    let Some(x_handle) = x_axis.gpu_real.as_ref() else {
963        return Ok(None);
964    };
965    let Some(y_handle) = y_axis.gpu_real.as_ref() else {
966        return Ok(None);
967    };
968
969    let z_handle = match z_axis {
970        Some(axis) => match axis.gpu_real.as_ref() {
971            Some(h) => Some(h),
972            None => return Ok(None),
973        },
974        None => None,
975    };
976
977    let Some(provider) = runmat_accelerate_api::provider_for_handle(x_handle) else {
978        return Ok(None);
979    };
980    if runmat_accelerate_api::provider_for_handle(y_handle).is_none() {
981        return Ok(None);
982    }
983    if let Some(z) = z_handle {
984        if runmat_accelerate_api::provider_for_handle(z).is_none() {
985            return Ok(None);
986        }
987    }
988
989    let nx = x_axis.len;
990    let ny = y_axis.len;
991    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
992
993    // Reshape axis vectors (metadata-only) so repmat can build full grids on-device.
994    let x_row = provider
995        .reshape(x_handle, &[1, nx])
996        .map_err(|e| builtin_error(format!("meshgrid: reshape X failed: {e}")))?;
997    let y_col = provider
998        .reshape(y_handle, &[ny, 1])
999        .map_err(|e| builtin_error(format!("meshgrid: reshape Y failed: {e}")))?;
1000
1001    let mut outputs = Vec::with_capacity(if z_handle.is_some() { 3 } else { 2 });
1002    if let Some(z) = z_handle {
1003        let x_base = provider
1004            .reshape(&x_row, &[1, nx, 1])
1005            .map_err(|e| builtin_error(format!("meshgrid: reshape X(3d) failed: {e}")))?;
1006        let y_base = provider
1007            .reshape(&y_col, &[ny, 1, 1])
1008            .map_err(|e| builtin_error(format!("meshgrid: reshape Y(3d) failed: {e}")))?;
1009
1010        let x_grid = provider
1011            .repmat(&x_base, &[ny, 1, nz])
1012            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
1013        let y_grid = provider
1014            .repmat(&y_base, &[1, nx, nz])
1015            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
1016
1017        outputs.push(MeshgridOutput::GpuReal(x_grid));
1018        outputs.push(MeshgridOutput::GpuReal(y_grid));
1019        let z_axis_row = provider
1020            .reshape(z, &[1, nz])
1021            .map_err(|e| builtin_error(format!("meshgrid: reshape Z failed: {e}")))?;
1022        let z_base = provider
1023            .reshape(&z_axis_row, &[1, 1, nz])
1024            .map_err(|e| builtin_error(format!("meshgrid: reshape Z(3d) failed: {e}")))?;
1025        let z_grid = provider
1026            .repmat(&z_base, &[ny, nx, 1])
1027            .map_err(|e| builtin_error(format!("meshgrid: repmat Z failed: {e}")))?;
1028        outputs.push(MeshgridOutput::GpuReal(z_grid));
1029    } else {
1030        let x_grid = provider
1031            .repmat(&x_row, &[ny, 1])
1032            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
1033        let y_grid = provider
1034            .repmat(&y_col, &[1, nx])
1035            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
1036        outputs.push(MeshgridOutput::GpuReal(x_grid));
1037        outputs.push(MeshgridOutput::GpuReal(y_grid));
1038    }
1039
1040    Ok(Some(outputs))
1041}
1042
1043fn normalise_axes(axes: &[AxisData]) -> (AxisData, AxisData, Option<AxisData>) {
1044    match axes.len() {
1045        1 => {
1046            let x = axes[0].clone();
1047            (x.clone(), x, None)
1048        }
1049        2 => {
1050            let x = axes[0].clone();
1051            let y = axes[1].clone();
1052            (x, y, None)
1053        }
1054        3 => {
1055            let x = axes[0].clone();
1056            let y = axes[1].clone();
1057            let z = axes[2].clone();
1058            (x, y, Some(z))
1059        }
1060        _ => unreachable!(),
1061    }
1062}
1063
1064fn build_outputs(
1065    x_axis: &AxisData,
1066    y_axis: &AxisData,
1067    z_axis: Option<&AxisData>,
1068) -> Vec<GridOutput> {
1069    let nx = x_axis.len;
1070    let ny = y_axis.len;
1071    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
1072    let total = nx * ny * nz;
1073    let mut x_data = Vec::with_capacity(total);
1074    let mut y_data = Vec::with_capacity(total);
1075    let mut z_data = z_axis.map(|_| Vec::with_capacity(total));
1076
1077    for k in 0..nz {
1078        let z_value = z_axis.map(|axis| axis.values[k]);
1079        for col in 0..nx {
1080            let x_value = x_axis.values[col];
1081            for row in 0..ny {
1082                x_data.push(x_value);
1083                y_data.push(y_axis.values[row]);
1084                if let Some(ref mut z_vec) = z_data {
1085                    z_vec.push(z_value.unwrap());
1086                }
1087            }
1088        }
1089    }
1090
1091    let mut outputs = Vec::new();
1092    let base_shape = if nz == 1 {
1093        vec![ny, nx]
1094    } else {
1095        vec![ny, nx, nz]
1096    };
1097    outputs.push(GridOutput {
1098        shape: base_shape.clone(),
1099        data: x_data,
1100    });
1101    outputs.push(GridOutput {
1102        shape: base_shape.clone(),
1103        data: y_data,
1104    });
1105    if let Some(z_vec) = z_data {
1106        outputs.push(GridOutput {
1107            shape: base_shape,
1108            data: z_vec,
1109        });
1110    }
1111    outputs
1112}
1113
1114struct GridOutput {
1115    shape: Vec<usize>,
1116    data: Vec<(f64, f64)>,
1117}
1118
1119impl GridOutput {
1120    fn to_value(
1121        &self,
1122        class: PrototypeClass,
1123        residency: DevicePreference,
1124    ) -> crate::BuiltinResult<Value> {
1125        match class {
1126            PrototypeClass::Real => self.to_real_value(residency),
1127            PrototypeClass::Complex => self.to_complex_value(residency),
1128        }
1129    }
1130
1131    fn to_real_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1132        let mut real = Vec::with_capacity(self.data.len());
1133        for &(re, im) in &self.data {
1134            if im != 0.0 {
1135                return Err(builtin_error(
1136                    "meshgrid: cannot represent complex values in a real output",
1137                ));
1138            }
1139            real.push(re);
1140        }
1141        let tensor = Tensor::new(real, self.shape.clone())
1142            .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1143        match residency {
1144            DevicePreference::Host => Ok(tensor::tensor_into_value(tensor)),
1145            DevicePreference::Gpu => to_gpu_tensor_value(tensor),
1146        }
1147    }
1148
1149    fn to_complex_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1150        let tensor = ComplexTensor::new(self.data.clone(), self.shape.clone())
1151            .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1152        match residency {
1153            DevicePreference::Host => Ok(complex_tensor_into_value(tensor)),
1154            DevicePreference::Gpu => {
1155                warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
1156                Ok(complex_tensor_into_value(tensor))
1157            }
1158        }
1159    }
1160}
1161
1162fn to_gpu_tensor_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1163    if let Some(provider) = runmat_accelerate_api::provider() {
1164        let view = HostTensorView {
1165            data: &tensor.data,
1166            shape: &tensor.shape,
1167        };
1168        match provider.upload(&view) {
1169            Ok(handle) => return Ok(Value::GpuTensor(handle)),
1170            Err(err) => {
1171                warn!("meshgrid: failed to upload tensor to GPU, returning host array: {err}")
1172            }
1173        }
1174    }
1175    Ok(tensor::tensor_into_value(tensor))
1176}
1177
1178fn tensor_to_complex_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1179    let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
1180    let complex = ComplexTensor::new(data, tensor.shape.clone())
1181        .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1182    Ok(complex_tensor_into_value(complex))
1183}
1184
1185enum MeshgridOutput {
1186    Host(GridOutput),
1187    GpuReal(GpuTensorHandle),
1188}
1189
1190impl MeshgridOutput {
1191    async fn to_value(
1192        &self,
1193        class: PrototypeClass,
1194        residency: DevicePreference,
1195    ) -> crate::BuiltinResult<Value> {
1196        match self {
1197            MeshgridOutput::Host(host) => host.to_value(class, residency),
1198            MeshgridOutput::GpuReal(handle) => match (class, residency) {
1199                (PrototypeClass::Real, DevicePreference::Gpu) => {
1200                    Ok(Value::GpuTensor(handle.clone()))
1201                }
1202                (PrototypeClass::Real, DevicePreference::Host) => {
1203                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1204                    Ok(tensor::tensor_into_value(tensor))
1205                }
1206                (PrototypeClass::Complex, DevicePreference::Host) => {
1207                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1208                    tensor_to_complex_value(tensor)
1209                }
1210                (PrototypeClass::Complex, DevicePreference::Gpu) => {
1211                    warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
1212                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1213                    tensor_to_complex_value(tensor)
1214                }
1215            },
1216        }
1217    }
1218}
1219
1220/// Holds the results of a `meshgrid` evaluation so multiple outputs can be
1221/// materialised without recomputing the grid.
1222pub struct MeshgridEval {
1223    outputs: Vec<MeshgridOutput>,
1224    target_class: PrototypeClass,
1225    target_residency: DevicePreference,
1226}
1227
1228impl MeshgridEval {
1229    pub fn output_count(&self) -> usize {
1230        self.outputs.len()
1231    }
1232
1233    pub async fn first(&self) -> crate::BuiltinResult<Value> {
1234        self.outputs[0]
1235            .to_value(self.target_class, self.target_residency)
1236            .await
1237    }
1238
1239    pub async fn second(&self) -> crate::BuiltinResult<Value> {
1240        if self.outputs.len() < 2 {
1241            Err(builtin_error("meshgrid: second output unavailable"))
1242        } else {
1243            self.outputs[1]
1244                .to_value(self.target_class, self.target_residency)
1245                .await
1246        }
1247    }
1248
1249    pub async fn third(&self) -> crate::BuiltinResult<Value> {
1250        if self.outputs.len() < 3 {
1251            Err(builtin_error(
1252                "meshgrid: third output requested but no Z vector was supplied",
1253            ))
1254        } else {
1255            self.outputs[2]
1256                .to_value(self.target_class, self.target_residency)
1257                .await
1258        }
1259    }
1260}
1261
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;

    use runmat_accelerate_api::HostTensorView;

    /// Blocking wrapper around the async `super::evaluate` entry point.
    fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
        let pending = super::evaluate(args);
        block_on(pending)
    }

    /// Blocking accessor for the first grid.
    fn eval_first(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        let pending = eval.first();
        block_on(pending)
    }

    /// Blocking accessor for the second grid.
    fn eval_second(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        let pending = eval.second();
        block_on(pending)
    }

    /// Blocking accessor for the third grid.
    fn eval_third(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        let pending = eval.third();
        block_on(pending)
    }

    /// Builds a `rows` x `cols` tensor from column-major `data`.
    fn tensor_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Tensor {
        let shape = vec![rows, cols];
        Tensor::new(data, shape).unwrap()
    }

    /// A single axis is used for both X and Y, giving square grids.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_single_input_duplicates_axis() {
        let axis = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
        let eval = evaluate(&[Value::Tensor(axis)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);

        // X is constant down each column.
        let x_grid = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        let x_expected = vec![-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        assert_eq!(x_grid.shape, vec![3, 3]);
        assert_eq!(x_grid.data, x_expected);

        // Y repeats the axis in every column.
        let y_grid = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        let y_expected = vec![-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0];
        assert_eq!(y_grid.shape, vec![3, 3]);
        assert_eq!(y_grid.data, y_expected);
    }

    /// Scalar arguments resolve to all-ones shapes whose rank follows the axis count.
    #[test]
    fn meshgrid_type_infers_rank_from_axis_count() {
        let ctx = ResolveContext::new(Vec::new());

        let two_axes = [Type::Num, Type::Num];
        let rank_two = Type::Tensor {
            shape: Some(vec![Some(1), Some(1)]),
        };
        assert_eq!(meshgrid_type(&two_axes, &ctx), rank_two);

        let three_axes = [Type::Num, Type::Num, Type::Num];
        let rank_three = Type::Tensor {
            shape: Some(vec![Some(1), Some(1), Some(1)]),
        };
        assert_eq!(meshgrid_type(&three_axes, &ctx), rank_three);
    }

    /// Known vector lengths flow into the resolved `[len(y), len(x)]` shape.
    #[test]
    fn meshgrid_type_uses_vector_lengths() {
        let ctx = ResolveContext::new(Vec::new());
        let x_type = Type::Tensor {
            shape: Some(vec![Some(1), Some(201)]),
        };
        let y_type = Type::Tensor {
            shape: Some(vec![Some(1), Some(101)]),
        };
        let expected = Type::Tensor {
            shape: Some(vec![Some(101), Some(201)]),
        };
        assert_eq!(meshgrid_type(&[x_type, y_type], &ctx), expected);
    }

    /// Distinct X and Y lengths give `len(y)` x `len(x)` grids.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_rectangular_inputs() {
        let x_axis = tensor_from_vec(vec![0.0, 0.5, 1.0], 1, 3);
        let y_axis = tensor_from_vec(vec![10.0, 20.0], 2, 1);
        let eval = evaluate(&[Value::Tensor(x_axis), Value::Tensor(y_axis)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);

        let x_grid = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_grid.shape, vec![2, 3]);
        assert_eq!(x_grid.data, vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0]);

        let y_grid = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_grid.shape, vec![2, 3]);
        assert_eq!(y_grid.data, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
    }

    /// Three axes give volumes shaped `[len(y), len(x), len(z)]`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_three_inputs_volume() {
        let x_axis = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let y_axis = tensor_from_vec(vec![5.0, 6.0, 7.0], 3, 1);
        let z_axis = tensor_from_vec(vec![0.0, 1.0], 1, 2);
        let args = [
            Value::Tensor(x_axis),
            Value::Tensor(y_axis),
            Value::Tensor(z_axis),
        ];
        let eval = evaluate(&args).expect("meshgrid");
        assert_eq!(eval.output_count(), 3);

        let x_grid = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        let x_expected = vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        assert_eq!(x_grid.shape, vec![3, 2, 2]);
        assert_eq!(x_grid.data, x_expected);

        let z_grid = test_support::gather(eval_third(&eval).expect("Z")).expect("host");
        let z_expected = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        assert_eq!(z_grid.shape, vec![3, 2, 2]);
        assert_eq!(z_grid.data, z_expected);
    }

    /// A GPU `'like'` prototype keeps host-built grids on the device.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_keeps_gpu_residency() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
            let y = tensor_from_vec(vec![2.0, 4.0], 2, 1);
            let proto = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let proto_handle = provider
                .upload(&HostTensorView {
                    data: &proto.data,
                    shape: &proto.shape,
                })
                .expect("upload");

            let args = [
                Value::Tensor(x),
                Value::Tensor(y),
                Value::from("like"),
                Value::GpuTensor(proto_handle),
            ];
            let eval = evaluate(&args).expect("meshgrid");

            let x_value = eval_first(&eval).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let gathered = test_support::gather(x_value).expect("gather");
            assert_eq!(gathered.shape, vec![2, 3]);
        });
    }

    /// GPU-resident axes give GPU-resident grids.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_gpu_inputs_roundtrip() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![0.0, 0.5], 1, 2);
            let y = tensor_from_vec(vec![1.0, 2.0], 2, 1);
            let x_handle = provider
                .upload(&HostTensorView {
                    data: &x.data,
                    shape: &x.shape,
                })
                .expect("upload");
            let y_handle = provider
                .upload(&HostTensorView {
                    data: &y.data,
                    shape: &y.shape,
                })
                .expect("upload");

            let args = [Value::GpuTensor(x_handle), Value::GpuTensor(y_handle)];
            let eval = evaluate(&args).expect("meshgrid");
            assert!(matches!(eval_first(&eval).expect("X"), Value::GpuTensor(_)));
            assert!(matches!(
                eval_second(&eval).expect("Y"),
                Value::GpuTensor(_)
            ));
        });
    }

    /// The wgpu backend must reproduce the host grids exactly.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn meshgrid_wgpu_matches_cpu() {
        let provider = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .expect("wgpu provider");

        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0, 2.0], 1, 4);
        let y = tensor_from_vec(vec![5.0, 6.0], 2, 1);

        // Host reference.
        let cpu_args = [Value::Tensor(x.clone()), Value::Tensor(y.clone())];
        let cpu_eval = evaluate(&cpu_args).expect("meshgrid cpu");
        let cpu_x =
            test_support::gather(eval_first(&cpu_eval).expect("X cpu")).expect("gather X cpu");
        let cpu_y =
            test_support::gather(eval_second(&cpu_eval).expect("Y cpu")).expect("gather Y cpu");

        // Same axes, uploaded to the device.
        let x_gpu = provider
            .upload(&HostTensorView {
                data: &x.data,
                shape: &x.shape,
            })
            .expect("upload x");
        let y_gpu = provider
            .upload(&HostTensorView {
                data: &y.data,
                shape: &y.shape,
            })
            .expect("upload y");

        let gpu_args = [Value::GpuTensor(x_gpu), Value::GpuTensor(y_gpu)];
        let gpu_eval = evaluate(&gpu_args).expect("meshgrid gpu");
        let gpu_x_value = eval_first(&gpu_eval).expect("X gpu");
        let gpu_y_value = eval_second(&gpu_eval).expect("Y gpu");

        assert!(matches!(gpu_x_value, Value::GpuTensor(_)));
        assert!(matches!(gpu_y_value, Value::GpuTensor(_)));

        let gathered_x = test_support::gather(gpu_x_value).expect("gather X gpu");
        let gathered_y = test_support::gather(gpu_y_value).expect("gather Y gpu");

        assert_eq!(gathered_x.shape, cpu_x.shape);
        assert_eq!(gathered_x.data, cpu_x.data);
        assert_eq!(gathered_y.shape, cpu_y.shape);
        assert_eq!(gathered_y.data, cpu_y.data);
    }

    /// Complex axes give complex grids.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_complex_inputs_produce_complex_outputs() {
        let axis = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
        let eval = evaluate(&[Value::ComplexTensor(axis)]).expect("meshgrid");
        match eval_first(&eval).expect("X") {
            Value::Complex(_, _) => {}
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 2]);
            }
            other => panic!("expected complex output, got {other:?}"),
        }
    }

    /// A host numeric `'like'` prototype keeps the result on the host.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_host_prototype() {
        let axis = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let args = [Value::Tensor(axis), Value::from("like"), Value::Num(0.0)];
        let eval = evaluate(&args).expect("meshgrid");
        let x_out = eval_first(&eval).expect("X");
        assert!(matches!(x_out, Value::Tensor(_) | Value::Num(_)));
    }
}