Skip to main content

runmat_runtime/builtins/array/creation/
meshgrid.rs

1//! MATLAB-compatible `meshgrid` builtin with GPU-aware semantics.
2
3use std::cmp::max;
4
5use log::warn;
6use runmat_accelerate_api::{GpuTensorHandle, GpuTensorStorage, HostTensorView};
7use runmat_builtins::{
8    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
9    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
10    ComplexTensor, ResolveContext, Tensor, Type, Value,
11};
12
13use crate::builtins::array::type_resolvers::size_vector_len;
14use runmat_macros::runtime_builtin;
15
16use crate::build_runtime_error;
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
19use crate::builtins::common::residency::{sequence_gpu_preference, SequenceIntent};
20use crate::builtins::common::spec::{
21    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
23};
24use crate::builtins::common::tensor;
25
/// GPU execution metadata registered for the `meshgrid` builtin.
///
/// Declares a custom `"array_construct"` operation with a custom `"meshgrid"` provider
/// hook, F32/F64 precision support, MATLAB broadcast semantics and a `NewHandle`
/// residency policy. The `notes` field records the current host-build-then-upload
/// fallback when no provider hook is available.
26#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
27pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
28    name: "meshgrid",
29    op_kind: GpuOpKind::Custom("array_construct"),
30    supported_precisions: &[ScalarType::F32, ScalarType::F64],
31    broadcast: BroadcastSemantics::Matlab,
32    provider_hooks: &[ProviderHook::Custom("meshgrid")],
33    constant_strategy: ConstantStrategy::InlineLiteral,
34    residency: ResidencyPolicy::NewHandle,
35    nan_mode: ReductionNaN::Include,
36    two_pass_threshold: None,
37    workgroup_size: None,
38    accepts_nan_mode: false,
39    notes: "Providers may supply a dedicated meshgrid hook; until then the runtime builds grids on the host and uploads them when GPU residency is requested.",
40};
41
42fn builtin_error(message: impl Into<String>) -> crate::RuntimeError {
43    build_runtime_error(message)
44        .with_builtin("meshgrid")
45        .build()
46}
47
/// Fusion metadata registered for the `meshgrid` builtin.
///
/// No elementwise or reduction fusion entry is provided (`None` for both); as the
/// `notes` field states, meshgrid materialises dense coordinate arrays and bypasses fusion.
48#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
49pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
50    name: "meshgrid",
51    shape: ShapeRequirements::Any,
52    constant_strategy: ConstantStrategy::InlineLiteral,
53    elementwise: None,
54    reduction: None,
55    emits_nan: false,
56    notes:
57        "Meshgrid explicitly materialises dense coordinate arrays and therefore bypasses fusion.",
58};
59
60fn meshgrid_type(args: &[Type], _context: &ResolveContext) -> Type {
61    if args.is_empty() {
62        return Type::Unknown;
63    }
64    let mut axis_count = args.len();
65    if axis_count >= 2 && matches!(args[axis_count - 2], Type::String) {
66        axis_count = axis_count.saturating_sub(2);
67    }
68    if axis_count == 0 {
69        return Type::Unknown;
70    }
71    let axis_args = &args[..axis_count];
72    let len_x = axis_args.get(0).and_then(size_vector_len);
73    let len_y = axis_args.get(1).and_then(size_vector_len).or(len_x);
74    let len_z = axis_args.get(2).and_then(size_vector_len);
75    let shape = if axis_count >= 3 {
76        vec![len_y, len_x, len_z]
77    } else {
78        vec![len_y, len_x]
79    };
80    Type::Tensor { shape: Some(shape) }
81}
82
/// Output descriptors for the two-output forms: the `X` and `Y` coordinate arrays.
83const MESHGRID_OUTPUT_XY: [BuiltinParamDescriptor; 2] = [
84    BuiltinParamDescriptor {
85        name: "X",
86        ty: BuiltinParamType::NumericArray,
87        arity: BuiltinParamArity::Required,
88        default: None,
89        description: "Grid coordinates along X-axis.",
90    },
91    BuiltinParamDescriptor {
92        name: "Y",
93        ty: BuiltinParamType::NumericArray,
94        arity: BuiltinParamArity::Required,
95        default: None,
96        description: "Grid coordinates along Y-axis.",
97    },
98];
99
/// Output descriptors for the three-output forms: `X`, `Y`, and an optional `Z`
/// coordinate array.
100const MESHGRID_OUTPUT_XYZ: [BuiltinParamDescriptor; 3] = [
101    BuiltinParamDescriptor {
102        name: "X",
103        ty: BuiltinParamType::NumericArray,
104        arity: BuiltinParamArity::Required,
105        default: None,
106        description: "Grid coordinates along X-axis.",
107    },
108    BuiltinParamDescriptor {
109        name: "Y",
110        ty: BuiltinParamType::NumericArray,
111        arity: BuiltinParamArity::Required,
112        default: None,
113        description: "Grid coordinates along Y-axis.",
114    },
115    BuiltinParamDescriptor {
116        name: "Z",
117        ty: BuiltinParamType::NumericArray,
118        arity: BuiltinParamArity::Optional,
119        default: None,
120        description: "Grid coordinates along Z-axis.",
121    },
122];
123
/// Input descriptors for `meshgrid(x)`: a single required X-axis vector.
124const MESHGRID_SIG_X_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
125    name: "x",
126    ty: BuiltinParamType::NumericArray,
127    arity: BuiltinParamArity::Required,
128    default: None,
129    description: "X-axis vector.",
130}];
131
/// Input descriptors for `meshgrid(x, y)`: required X-axis and Y-axis vectors.
132const MESHGRID_SIG_XY_INPUTS: [BuiltinParamDescriptor; 2] = [
133    BuiltinParamDescriptor {
134        name: "x",
135        ty: BuiltinParamType::NumericArray,
136        arity: BuiltinParamArity::Required,
137        default: None,
138        description: "X-axis vector.",
139    },
140    BuiltinParamDescriptor {
141        name: "y",
142        ty: BuiltinParamType::NumericArray,
143        arity: BuiltinParamArity::Required,
144        default: None,
145        description: "Y-axis vector.",
146    },
147];
148
/// Input descriptors for `meshgrid(x, y, z)`: required X and Y vectors plus a Z vector
/// that is marked optional.
149const MESHGRID_SIG_XYZ_INPUTS: [BuiltinParamDescriptor; 3] = [
150    BuiltinParamDescriptor {
151        name: "x",
152        ty: BuiltinParamType::NumericArray,
153        arity: BuiltinParamArity::Required,
154        default: None,
155        description: "X-axis vector.",
156    },
157    BuiltinParamDescriptor {
158        name: "y",
159        ty: BuiltinParamType::NumericArray,
160        arity: BuiltinParamArity::Required,
161        default: None,
162        description: "Y-axis vector.",
163    },
164    BuiltinParamDescriptor {
165        name: "z",
166        ty: BuiltinParamType::NumericArray,
167        arity: BuiltinParamArity::Optional,
168        default: None,
169        description: "Z-axis vector.",
170    },
171];
172
/// Input descriptors for `meshgrid(x, "like", prototype)`: one axis vector followed by
/// the `"like"` keyword and the prototype value.
173const MESHGRID_SIG_X_LIKE_INPUTS: [BuiltinParamDescriptor; 3] = [
174    BuiltinParamDescriptor {
175        name: "x",
176        ty: BuiltinParamType::NumericArray,
177        arity: BuiltinParamArity::Required,
178        default: None,
179        description: "X-axis vector.",
180    },
181    BuiltinParamDescriptor {
182        name: "like_kw",
183        ty: BuiltinParamType::StringScalar,
184        arity: BuiltinParamArity::Required,
185        default: Some("\"like\""),
186        description: "Like keyword.",
187    },
188    BuiltinParamDescriptor {
189        name: "prototype",
190        ty: BuiltinParamType::LikePrototype,
191        arity: BuiltinParamArity::Required,
192        default: None,
193        description: "Prototype controlling class/device residency.",
194    },
195];
196
/// Input descriptors for `meshgrid(x, y, "like", prototype)`: two axis vectors followed
/// by the `"like"` keyword and the prototype value.
197const MESHGRID_SIG_XY_LIKE_INPUTS: [BuiltinParamDescriptor; 4] = [
198    BuiltinParamDescriptor {
199        name: "x",
200        ty: BuiltinParamType::NumericArray,
201        arity: BuiltinParamArity::Required,
202        default: None,
203        description: "X-axis vector.",
204    },
205    BuiltinParamDescriptor {
206        name: "y",
207        ty: BuiltinParamType::NumericArray,
208        arity: BuiltinParamArity::Required,
209        default: None,
210        description: "Y-axis vector.",
211    },
212    BuiltinParamDescriptor {
213        name: "like_kw",
214        ty: BuiltinParamType::StringScalar,
215        arity: BuiltinParamArity::Required,
216        default: Some("\"like\""),
217        description: "Like keyword.",
218    },
219    BuiltinParamDescriptor {
220        name: "prototype",
221        ty: BuiltinParamType::LikePrototype,
222        arity: BuiltinParamArity::Required,
223        default: None,
224        description: "Prototype controlling class/device residency.",
225    },
226];
227
/// Input descriptors for `meshgrid(x, y, z, "like", prototype)`: X and Y vectors, a Z
/// vector marked optional, then the `"like"` keyword and the prototype value.
228const MESHGRID_SIG_XYZ_LIKE_INPUTS: [BuiltinParamDescriptor; 5] = [
229    BuiltinParamDescriptor {
230        name: "x",
231        ty: BuiltinParamType::NumericArray,
232        arity: BuiltinParamArity::Required,
233        default: None,
234        description: "X-axis vector.",
235    },
236    BuiltinParamDescriptor {
237        name: "y",
238        ty: BuiltinParamType::NumericArray,
239        arity: BuiltinParamArity::Required,
240        default: None,
241        description: "Y-axis vector.",
242    },
243    BuiltinParamDescriptor {
244        name: "z",
245        ty: BuiltinParamType::NumericArray,
246        arity: BuiltinParamArity::Optional,
247        default: None,
248        description: "Z-axis vector.",
249    },
250    BuiltinParamDescriptor {
251        name: "like_kw",
252        ty: BuiltinParamType::StringScalar,
253        arity: BuiltinParamArity::Required,
254        default: Some("\"like\""),
255        description: "Like keyword.",
256    },
257    BuiltinParamDescriptor {
258        name: "prototype",
259        ty: BuiltinParamType::LikePrototype,
260        arity: BuiltinParamArity::Required,
261        default: None,
262        description: "Prototype controlling class/device residency.",
263    },
264];
265
/// The six documented call signatures of `meshgrid`: one-, two- and three-axis forms,
/// each with and without a trailing `"like", prototype` pair. Every entry pairs a
/// display label with its input and output descriptor tables.
266const MESHGRID_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
267    BuiltinSignatureDescriptor {
268        label: "[X,Y] = meshgrid(x)",
269        inputs: &MESHGRID_SIG_X_INPUTS,
270        outputs: &MESHGRID_OUTPUT_XY,
271    },
272    BuiltinSignatureDescriptor {
273        label: "[X,Y] = meshgrid(x, y)",
274        inputs: &MESHGRID_SIG_XY_INPUTS,
275        outputs: &MESHGRID_OUTPUT_XY,
276    },
277    BuiltinSignatureDescriptor {
278        label: "[X,Y,Z] = meshgrid(x, y, z)",
279        inputs: &MESHGRID_SIG_XYZ_INPUTS,
280        outputs: &MESHGRID_OUTPUT_XYZ,
281    },
282    BuiltinSignatureDescriptor {
283        label: "[X,Y] = meshgrid(x, \"like\", prototype)",
284        inputs: &MESHGRID_SIG_X_LIKE_INPUTS,
285        outputs: &MESHGRID_OUTPUT_XY,
286    },
287    BuiltinSignatureDescriptor {
288        label: "[X,Y] = meshgrid(x, y, \"like\", prototype)",
289        inputs: &MESHGRID_SIG_XY_LIKE_INPUTS,
290        outputs: &MESHGRID_OUTPUT_XY,
291    },
292    BuiltinSignatureDescriptor {
293        label: "[X,Y,Z] = meshgrid(x, y, z, \"like\", prototype)",
294        inputs: &MESHGRID_SIG_XYZ_LIKE_INPUTS,
295        outputs: &MESHGRID_OUTPUT_XYZ,
296    },
297];
298
/// Catalogue of documented `meshgrid` error conditions. Each entry pairs a stable
/// `RM.MESHGRID.*` code with the triggering condition (`when`) and a canonical message.
///
/// NOTE(review): several messages built by the implementation in this file are worded
/// differently from the entries here (for example the output-count errors and the
/// prototype errors) — confirm whether the catalogue or the call sites should be aligned.
299const MESHGRID_ERRORS: [BuiltinErrorDescriptor; 11] = [
300    BuiltinErrorDescriptor {
301        code: "RM.MESHGRID.MISSING_AXIS",
302        identifier: None,
303        when: "No axis vectors are provided.",
304        message: "meshgrid: at least one input vector is required",
305    },
306    BuiltinErrorDescriptor {
307        code: "RM.MESHGRID.TOO_MANY_AXES",
308        identifier: None,
309        when: "More than three axis vectors are provided.",
310        message: "meshgrid: expected at most three input vectors",
311    },
312    BuiltinErrorDescriptor {
313        code: "RM.MESHGRID.LIKE_EXPECTED_PROTOTYPE",
314        identifier: None,
315        when: "The 'like' keyword is provided without a prototype argument.",
316        message: "meshgrid: expected prototype after 'like'",
317    },
318    BuiltinErrorDescriptor {
319        code: "RM.MESHGRID.MULTIPLE_LIKE",
320        identifier: None,
321        when: "The 'like' keyword is provided multiple times.",
322        message: "meshgrid: multiple 'like' specifications are not supported",
323    },
324    BuiltinErrorDescriptor {
325        code: "RM.MESHGRID.LIKE_POSITION",
326        identifier: None,
327        when: "The 'like' keyword is in an invalid position or not final.",
328        message: "meshgrid: 'like' must be the final argument",
329    },
330    BuiltinErrorDescriptor {
331        code: "RM.MESHGRID.UNRECOGNIZED_OPTION",
332        identifier: None,
333        when: "A trailing option string is not recognized.",
334        message: "meshgrid: unrecognised option",
335    },
336    BuiltinErrorDescriptor {
337        code: "RM.MESHGRID.INVALID_AXIS_INPUT",
338        identifier: None,
339        when: "Axis inputs are non-numeric or non-vector shapes.",
340        message: "meshgrid: input argument must be numeric vector data",
341    },
342    BuiltinErrorDescriptor {
343        code: "RM.MESHGRID.INVALID_PROTOTYPE",
344        identifier: None,
345        when: "The 'like' prototype is unsupported.",
346        message: "meshgrid: prototypes must be numeric arrays",
347    },
348    BuiltinErrorDescriptor {
349        code: "RM.MESHGRID.OUTPUT_COUNT_EXCEEDED",
350        identifier: None,
351        when: "Requested outputs exceed available outputs for provided axes.",
352        message:
353            "meshgrid: supports at most two outputs for 2-axis inputs and three for 3-axis inputs",
354    },
355    BuiltinErrorDescriptor {
356        code: "RM.MESHGRID.THIRD_OUTPUT_UNAVAILABLE",
357        identifier: None,
358        when: "A third output is requested without supplying a Z-axis vector.",
359        message: "meshgrid: third output requested but no Z vector was supplied",
360    },
361    BuiltinErrorDescriptor {
362        code: "RM.MESHGRID.COMPLEX_REAL_CONVERSION",
363        identifier: None,
364        when: "Complex axis values cannot be represented in requested real output class.",
365        message: "meshgrid: cannot represent complex values in a real output",
366    },
367];
368
/// Public descriptor for `meshgrid`, bundling the signature table and error catalogue.
/// The output mode is `ByRequestedOutputCount` and the completion policy is `Public`.
369pub const MESHGRID_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
370    signatures: &MESHGRID_SIGNATURES,
371    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
372    completion_policy: BuiltinCompletionPolicy::Public,
373    errors: &MESHGRID_ERRORS,
374};
375
376#[runtime_builtin(
377    name = "meshgrid",
378    category = "array/creation",
379    summary = "Generate coordinate matrices for 2-D and 3-D grids.",
380    keywords = "meshgrid,grid,gpu,like,3d",
381    accel = "array_construct",
382    type_resolver(meshgrid_type),
383    descriptor(crate::builtins::array::creation::meshgrid::MESHGRID_DESCRIPTOR),
384    builtin_path = "crate::builtins::array::creation::meshgrid"
385)]
386async fn meshgrid_builtin(rest: Vec<Value>) -> crate::BuiltinResult<Value> {
387    let eval = evaluate(&rest).await?;
388    if let Some(out_count) = crate::output_count::current_output_count() {
389        if out_count == 0 {
390            return Ok(Value::OutputList(Vec::new()));
391        }
392        let available = eval.output_count();
393        if out_count > available {
394            let msg = if available == 2 {
395                "meshgrid with two inputs supports at most two outputs"
396            } else {
397                "meshgrid supports at most three outputs"
398            };
399            return Err(builtin_error(msg));
400        }
401        let mut outputs = Vec::with_capacity(out_count);
402        let first = eval.first().await?;
403        outputs.push(first);
404        if out_count >= 2 {
405            outputs.push(eval.second().await?);
406        }
407        if out_count >= 3 {
408            outputs.push(eval.third().await?);
409        }
410        return Ok(Value::OutputList(outputs));
411    }
412    eval.first().await
413}
414
415/// Evaluate the `meshgrid` builtin once and reuse the result for multiple outputs.
416pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
417    let parsed = ParsedMeshgrid::parse(args).await?;
418    let (x_axis, y_axis, z_axis) = normalise_axes(&parsed.axes);
419
420    let require_complex = parsed.axes.iter().any(|axis| axis.is_complex);
421
422    let target_class = match &parsed.template {
423        OutputTemplate::Default => {
424            if require_complex {
425                PrototypeClass::Complex
426            } else {
427                PrototypeClass::Real
428            }
429        }
430        OutputTemplate::Like(spec) => {
431            if require_complex {
432                PrototypeClass::Complex
433            } else {
434                spec.class
435            }
436        }
437    };
438
439    let target_residency = match &parsed.template {
440        OutputTemplate::Default => {
441            if parsed.prefer_gpu {
442                DevicePreference::Gpu
443            } else {
444                DevicePreference::Host
445            }
446        }
447        OutputTemplate::Like(spec) => spec.residency,
448    };
449
450    let mut outputs: Vec<MeshgridOutput> = Vec::new();
451
452    if matches!(target_residency, DevicePreference::Gpu) {
453        if let Some(gpu) = try_meshgrid_gpu_from_vector_axes(&x_axis, &y_axis, z_axis.as_ref())? {
454            outputs = gpu;
455        }
456    }
457
458    if outputs.is_empty() {
459        // Host fallback: ensure we have host axis values materialized.
460        let x_host = axis_to_host_async(&x_axis).await?;
461        let y_host = axis_to_host_async(&y_axis).await?;
462        let z_host = match z_axis.as_ref() {
463            Some(axis) => Some(axis_to_host_async(axis).await?),
464            None => None,
465        };
466        outputs = build_outputs(&x_host, &y_host, z_host.as_ref())
467            .into_iter()
468            .map(MeshgridOutput::Host)
469            .collect();
470    }
471
472    Ok(MeshgridEval {
473        outputs,
474        target_class,
475        target_residency,
476    })
477}
478
/// Result of parsing a `meshgrid` argument list.
479#[derive(Clone)]
480struct ParsedMeshgrid {
    /// Axis inputs in argument order; a successful parse yields between one and three.
481    axes: Vec<AxisData>,
    /// `Like(..)` when a `'like', prototype` pair was supplied, otherwise `Default`.
482    template: OutputTemplate,
    /// Set when any axis was a gpuArray or the sequence-residency heuristic prefers GPU.
483    prefer_gpu: bool,
484}
485
486impl ParsedMeshgrid {
487    async fn parse(args: &[Value]) -> crate::BuiltinResult<Self> {
488        if args.is_empty() {
489            return Err(builtin_error(
490                "meshgrid: at least one input vector is required",
491            ));
492        }
493        let mut axis_values: Vec<Value> = Vec::new();
494        let mut like_proto: Option<Value> = None;
495        let mut prefer_gpu = false;
496        let mut idx = 0;
497        while idx < args.len() {
498            let value = args[idx].clone();
499            if let Some(keyword) = keyword_of(&value) {
500                match keyword.as_str() {
501                    "like" => {
502                        if like_proto.is_some() {
503                            return Err(builtin_error(
504                                "meshgrid: multiple 'like' specifications are not supported",
505                            ));
506                        }
507                        if axis_values.is_empty() {
508                            return Err(builtin_error(
509                                "meshgrid: 'like' must follow at least one input vector",
510                            ));
511                        }
512                        let Some(proto) = args.get(idx + 1).cloned() else {
513                            return Err(builtin_error("meshgrid: expected prototype after 'like'"));
514                        };
515                        like_proto = Some(proto);
516                        idx += 2;
517                        if idx < args.len() {
518                            return Err(builtin_error(
519                                "meshgrid: 'like' must be the final argument",
520                            ));
521                        }
522                        break;
523                    }
524                    other => {
525                        return Err(builtin_error(format!(
526                            "meshgrid: unrecognised option '{other}'"
527                        )));
528                    }
529                }
530            }
531
532            if let Value::GpuTensor(_) = value {
533                prefer_gpu = true;
534            }
535            axis_values.push(value);
536            idx += 1;
537        }
538
539        if axis_values.is_empty() {
540            return Err(builtin_error(
541                "meshgrid: at least one input vector is required",
542            ));
543        }
544        if axis_values.len() > 3 {
545            return Err(builtin_error(
546                "meshgrid: expected at most three input vectors",
547            ));
548        }
549
550        let mut axes = Vec::with_capacity(max(axis_values.len(), 2));
551        for (i, value) in axis_values.into_iter().enumerate() {
552            let mut consumed_gpu = false;
553            let data = axis_from_value(value, i, &mut consumed_gpu).await?;
554            if consumed_gpu {
555                prefer_gpu = true;
556            }
557            axes.push(data);
558        }
559
560        if !prefer_gpu {
561            if let Some(max_len) = axes.iter().map(|axis| axis.len).max() {
562                if max_len > 0
563                    && sequence_gpu_preference(max_len, SequenceIntent::MeshAxis, false).prefer_gpu
564                {
565                    prefer_gpu = true;
566                }
567            }
568        }
569
570        let template = if let Some(proto) = like_proto {
571            OutputTemplate::Like(analyse_like_prototype(&proto)?)
572        } else {
573            OutputTemplate::Default
574        };
575
576        Ok(Self {
577            axes,
578            template,
579            prefer_gpu,
580        })
581    }
582}
583
/// How the output class and residency are chosen.
584#[derive(Clone)]
585enum OutputTemplate {
    /// No `'like'` prototype was supplied.
586    Default,
    /// A `'like'` prototype was supplied; carries its analysed class and residency.
587    Like(PrototypeSpec),
588}
589
/// Class and residency extracted from a `'like'` prototype value.
590#[derive(Clone)]
591struct PrototypeSpec {
    /// Where the prototype lives (host value or gpuArray).
592    residency: DevicePreference,
    /// Whether the prototype is real or complex.
593    class: PrototypeClass,
594}
595
/// Numeric class of the generated grids: real or complex.
596#[derive(Clone, Copy, PartialEq, Eq)]
597enum PrototypeClass {
598    Real,
599    Complex,
600}
601
/// Requested residency for the generated grids: host memory or GPU.
602#[derive(Clone, Copy)]
603enum DevicePreference {
604    Host,
605    Gpu,
606}
607
608fn analyse_like_prototype(proto: &Value) -> crate::BuiltinResult<PrototypeSpec> {
609    match proto {
610        Value::GpuTensor(handle) => {
611            let class = if runmat_accelerate_api::handle_storage(handle)
612                == GpuTensorStorage::ComplexInterleaved
613            {
614                PrototypeClass::Complex
615            } else {
616                PrototypeClass::Real
617            };
618            Ok(PrototypeSpec {
619                residency: DevicePreference::Gpu,
620                class,
621            })
622        }
623        Value::ComplexTensor(_) | Value::Complex(_, _) => Ok(PrototypeSpec {
624            residency: DevicePreference::Host,
625            class: PrototypeClass::Complex,
626        }),
627        Value::Tensor(_)
628        | Value::SparseTensor(_)
629        | Value::Num(_)
630        | Value::Int(_)
631        | Value::Bool(_)
632        | Value::LogicalArray(_) => Ok(PrototypeSpec {
633            residency: DevicePreference::Host,
634            class: PrototypeClass::Real,
635        }),
636        Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(builtin_error(
637            "meshgrid: prototypes must be numeric or gpuArray values",
638        )),
639        Value::Symbolic(_) => Err(builtin_error(
640            "meshgrid: prototypes must be numeric or gpuArray values",
641        )),
642        Value::Cell(_)
643        | Value::Struct(_)
644        | Value::Object(_)
645        | Value::HandleObject(_)
646        | Value::Listener(_)
647        | Value::FunctionHandle(_)
648        | Value::ExternalFunctionHandle(_)
649        | Value::MethodFunctionHandle(_)
650        | Value::BoundFunctionHandle { .. }
651        | Value::Closure(_)
652        | Value::ClassRef(_)
653        | Value::MException(_)
654        | Value::OutputList(_) => Err(builtin_error("meshgrid: prototypes must be numeric arrays")),
655    }
656}
657
/// One parsed axis input of `meshgrid`.
658#[derive(Clone)]
659struct AxisData {
    /// Host-side samples as `(real, imaginary)` pairs; left empty when the axis is kept
    /// on the device via `gpu_real`.
660    values: Vec<(f64, f64)>,
    /// Number of samples along the axis.
661    len: usize,
    /// Whether the axis carries complex data.
662    is_complex: bool,
    /// Handle of a real, vector-shaped gpuArray input that was not downloaded.
663    gpu_real: Option<GpuTensorHandle>,
664}
665
666async fn axis_from_value(
667    value: Value,
668    index: usize,
669    prefer_gpu: &mut bool,
670) -> crate::BuiltinResult<AxisData> {
671    match value {
672        Value::Tensor(tensor) => axis_from_tensor(tensor, index),
673        Value::LogicalArray(logical) => {
674            let tensor = tensor::logical_to_tensor(&logical)?;
675            axis_from_tensor(tensor, index)
676        }
677        Value::Num(n) => Ok(AxisData {
678            values: vec![(n, 0.0)],
679            len: 1,
680            is_complex: false,
681            gpu_real: None,
682        }),
683        Value::Int(i) => {
684            let val = i.to_f64();
685            Ok(AxisData {
686                values: vec![(val, 0.0)],
687                len: 1,
688                is_complex: false,
689                gpu_real: None,
690            })
691        }
692        Value::Bool(b) => Ok(AxisData {
693            values: vec![(if b { 1.0 } else { 0.0 }, 0.0)],
694            len: 1,
695            is_complex: false,
696            gpu_real: None,
697        }),
698        Value::Complex(re, im) => Ok(AxisData {
699            values: vec![(re, im)],
700            len: 1,
701            is_complex: im != 0.0,
702            gpu_real: None,
703        }),
704        Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, index),
705        Value::GpuTensor(handle) => {
706            let is_complex = runmat_accelerate_api::handle_storage(&handle)
707                == GpuTensorStorage::ComplexInterleaved;
708            // Fast path: if the gpuArray is vector-like, keep it on-device and avoid a download.
709            // We'll validate any non-vector shapes by gathering below.
710            if is_vector_shape(&handle.shape) && !is_complex {
711                *prefer_gpu = true;
712                return Ok(AxisData {
713                    values: Vec::new(),
714                    len: vector_len_from_shape(&handle.shape),
715                    is_complex,
716                    gpu_real: Some(handle),
717                });
718            }
719
720            // Fallback: gather to validate / recover axes from meshgrid matrices.
721            *prefer_gpu = true;
722            let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
723            match gathered {
724                Value::Tensor(tensor) => {
725                    if is_vector_shape(&tensor.shape) {
726                        *prefer_gpu = true;
727                    }
728                    axis_from_tensor(tensor, index)
729                }
730                Value::ComplexTensor(tensor) => {
731                    if is_vector_shape(&tensor.shape) {
732                        *prefer_gpu = true;
733                    }
734                    axis_from_complex_tensor(tensor, index)
735                }
736                other => Err(builtin_error(format!(
737                    "meshgrid: input argument {} must be numeric, got {other:?}",
738                    index + 1
739                ))),
740            }
741        }
742        other => Err(builtin_error(format!(
743            "meshgrid: input argument {} must be numeric, got {other:?}",
744            index + 1
745        ))),
746    }
747}
748
749fn axis_from_tensor(tensor: Tensor, index: usize) -> crate::BuiltinResult<AxisData> {
750    if is_vector_shape(&tensor.shape) {
751        let mut values = Vec::with_capacity(tensor.data.len());
752        for &v in &tensor.data {
753            values.push((v, 0.0));
754        }
755        return Ok(AxisData {
756            len: values.len(),
757            values,
758            is_complex: false,
759            gpu_real: None,
760        });
761    }
762
763    // Be slightly more permissive than MATLAB: if the input is already a meshgrid-style
764    // coordinate matrix, accept it and recover the original axis vector.
765    //
766    // This is a pragmatic compatibility shim for cases where callers already have
767    // coordinate matrices (X/Y) and pass them through `meshgrid` again.
768    if let Some(axis) = axis_from_meshgrid_matrix_real(&tensor, index)? {
769        return Ok(axis);
770    }
771
772    Err(builtin_error(format!(
773        "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
774        index + 1,
775        tensor.shape
776    )))
777}
778
779fn axis_from_complex_tensor(tensor: ComplexTensor, index: usize) -> crate::BuiltinResult<AxisData> {
780    if is_vector_shape(&tensor.shape) {
781        let is_complex = tensor
782            .data
783            .iter()
784            .any(|&(_, imag)| !imag.is_nan() && imag != 0.0);
785        return Ok(AxisData {
786            len: tensor.data.len(),
787            values: tensor.data,
788            is_complex,
789            gpu_real: None,
790        });
791    }
792
793    if let Some(axis) = axis_from_meshgrid_matrix_complex(&tensor, index)? {
794        return Ok(axis);
795    }
796
797    Err(builtin_error(format!(
798        "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
799        index + 1,
800        tensor.shape
801    )))
802}
803
804fn axis_from_meshgrid_matrix_real(
805    tensor: &Tensor,
806    index: usize,
807) -> crate::BuiltinResult<Option<AxisData>> {
808    let (rows, cols) = match tensor.shape.as_slice() {
809        [r, c] => (*r, *c),
810        _ => return Ok(None),
811    };
812    if rows <= 1 || cols <= 1 {
813        return Ok(None);
814    }
815
816    // Index 0 is expected to be the X-axis: a meshgrid X matrix has identical rows.
817    // Index 1 is expected to be the Y-axis: a meshgrid Y matrix has identical columns.
818    let expect_rows_constant = index == 0;
819
820    if expect_rows_constant {
821        if !matrix_rows_are_identical_real(tensor, rows, cols) {
822            return Ok(None);
823        }
824        // Extract the first row as the axis vector (length = cols).
825        let mut values = Vec::with_capacity(cols);
826        for col in 0..cols {
827            let idx = rows * col;
828            values.push((tensor.data[idx], 0.0));
829        }
830        return Ok(Some(AxisData {
831            len: values.len(),
832            values,
833            is_complex: false,
834            gpu_real: None,
835        }));
836    }
837
838    if !matrix_cols_are_identical_real(tensor, rows, cols) {
839        return Ok(None);
840    }
841    // Extract the first column as the axis vector (length = rows).
842    let mut values = Vec::with_capacity(rows);
843    for row in 0..rows {
844        values.push((tensor.data[row], 0.0));
845    }
846    Ok(Some(AxisData {
847        len: values.len(),
848        values,
849        is_complex: false,
850        gpu_real: None,
851    }))
852}
853
854fn axis_from_meshgrid_matrix_complex(
855    tensor: &ComplexTensor,
856    index: usize,
857) -> crate::BuiltinResult<Option<AxisData>> {
858    let (rows, cols) = match tensor.shape.as_slice() {
859        [r, c] => (*r, *c),
860        _ => return Ok(None),
861    };
862    if rows <= 1 || cols <= 1 {
863        return Ok(None);
864    }
865
866    let expect_rows_constant = index == 0;
867    if expect_rows_constant {
868        if !matrix_rows_are_identical_complex(tensor, rows, cols) {
869            return Ok(None);
870        }
871        let mut values = Vec::with_capacity(cols);
872        for col in 0..cols {
873            let idx = rows * col;
874            values.push(tensor.data[idx]);
875        }
876        let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
877        return Ok(Some(AxisData {
878            len: values.len(),
879            values,
880            is_complex,
881            gpu_real: None,
882        }));
883    }
884
885    if !matrix_cols_are_identical_complex(tensor, rows, cols) {
886        return Ok(None);
887    }
888    let mut values = Vec::with_capacity(rows);
889    for row in 0..rows {
890        values.push(tensor.data[row]);
891    }
892    let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
893    Ok(Some(AxisData {
894        len: values.len(),
895        values,
896        is_complex,
897        gpu_real: None,
898    }))
899}
900
901fn matrix_rows_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
902    for row in 1..rows {
903        for col in 0..cols {
904            let idx0 = rows * col;
905            let idx = row + rows * col;
906            if tensor.data[idx] != tensor.data[idx0] {
907                return false;
908            }
909        }
910    }
911    true
912}
913
914fn matrix_cols_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
915    for col in 1..cols {
916        for row in 0..rows {
917            let idx0 = row;
918            let idx = row + rows * col;
919            if tensor.data[idx] != tensor.data[idx0] {
920                return false;
921            }
922        }
923    }
924    true
925}
926
927fn matrix_rows_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
928    for row in 1..rows {
929        for col in 0..cols {
930            let idx0 = rows * col;
931            let idx = row + rows * col;
932            if tensor.data[idx] != tensor.data[idx0] {
933                return false;
934            }
935        }
936    }
937    true
938}
939
940fn matrix_cols_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
941    for col in 1..cols {
942        for row in 0..rows {
943            let idx0 = row;
944            let idx = row + rows * col;
945            if tensor.data[idx] != tensor.data[idx0] {
946                return false;
947            }
948        }
949    }
950    true
951}
952
/// Reports whether `shape` describes a vector: at most one extent is greater
/// than 1. An empty shape (scalar) also qualifies.
fn is_vector_shape(shape: &[usize]) -> bool {
    let wide_dims = shape.iter().filter(|&&extent| extent > 1).count();
    wide_dims <= 1
}
965
/// Returns the number of elements in a vector-shaped array.
///
/// * An empty shape is treated as a scalar and yields `1`.
/// * If any extent is `0` the array holds no elements, so the length is `0`
///   (previously a shape such as `[1, 0]` reported `1` because only the largest
///   extent was considered).
/// * Otherwise the largest extent is returned, which for a vector shape is the
///   element count.
fn vector_len_from_shape(shape: &[usize]) -> usize {
    if shape.contains(&0) {
        return 0;
    }
    // No zero extents remain; an empty shape falls through to the scalar case.
    shape.iter().copied().max().unwrap_or(1)
}
972
973async fn axis_to_host_async(axis: &AxisData) -> crate::BuiltinResult<AxisData> {
974    if axis.gpu_real.is_none() {
975        return Ok(axis.clone());
976    }
977    let handle = axis.gpu_real.as_ref().expect("checked gpu_real is_some");
978    let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle.clone())).await?;
979    // Index is only used for error messages; value came from a validated vector-like handle.
980    match gathered {
981        Value::Tensor(tensor) => axis_from_tensor(tensor, 0),
982        Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, 0),
983        Value::Num(n) => Ok(AxisData {
984            values: vec![(n, 0.0)],
985            len: 1,
986            is_complex: false,
987            gpu_real: None,
988        }),
989        Value::Complex(re, im) => Ok(AxisData {
990            values: vec![(re, im)],
991            len: 1,
992            is_complex: im != 0.0,
993            gpu_real: None,
994        }),
995        other => Err(builtin_error(format!(
996            "meshgrid: expected numeric GPU axis, got {other:?}"
997        ))),
998    }
999}
1000
1001fn try_meshgrid_gpu_from_vector_axes(
1002    x_axis: &AxisData,
1003    y_axis: &AxisData,
1004    z_axis: Option<&AxisData>,
1005) -> crate::BuiltinResult<Option<Vec<MeshgridOutput>>> {
1006    let Some(x_handle) = x_axis.gpu_real.as_ref() else {
1007        return Ok(None);
1008    };
1009    let Some(y_handle) = y_axis.gpu_real.as_ref() else {
1010        return Ok(None);
1011    };
1012
1013    let z_handle = match z_axis {
1014        Some(axis) => match axis.gpu_real.as_ref() {
1015            Some(h) => Some(h),
1016            None => return Ok(None),
1017        },
1018        None => None,
1019    };
1020
1021    let Some(provider) = runmat_accelerate_api::provider_for_handle(x_handle) else {
1022        return Ok(None);
1023    };
1024    let Some(y_provider) = runmat_accelerate_api::provider_for_handle(y_handle) else {
1025        return Ok(None);
1026    };
1027    if y_provider.device_id() != provider.device_id() {
1028        return Ok(None);
1029    }
1030    if let Some(z) = z_handle {
1031        let Some(z_provider) = runmat_accelerate_api::provider_for_handle(z) else {
1032            return Ok(None);
1033        };
1034        if z_provider.device_id() != provider.device_id() {
1035            return Ok(None);
1036        }
1037    }
1038
1039    let nx = x_axis.len;
1040    let ny = y_axis.len;
1041    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
1042
1043    // Reshape axis vectors (metadata-only) so repmat can build full grids on-device.
1044    let x_row = provider
1045        .reshape(x_handle, &[1, nx])
1046        .map_err(|e| builtin_error(format!("meshgrid: reshape X failed: {e}")))?;
1047    let y_col = provider
1048        .reshape(y_handle, &[ny, 1])
1049        .map_err(|e| builtin_error(format!("meshgrid: reshape Y failed: {e}")))?;
1050
1051    let mut outputs = Vec::with_capacity(if z_handle.is_some() { 3 } else { 2 });
1052    if let Some(z) = z_handle {
1053        let x_base = provider
1054            .reshape(&x_row, &[1, nx, 1])
1055            .map_err(|e| builtin_error(format!("meshgrid: reshape X(3d) failed: {e}")))?;
1056        let y_base = provider
1057            .reshape(&y_col, &[ny, 1, 1])
1058            .map_err(|e| builtin_error(format!("meshgrid: reshape Y(3d) failed: {e}")))?;
1059
1060        let x_grid = provider
1061            .repmat(&x_base, &[ny, 1, nz])
1062            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
1063        let y_grid = provider
1064            .repmat(&y_base, &[1, nx, nz])
1065            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
1066
1067        outputs.push(MeshgridOutput::Gpu(x_grid));
1068        outputs.push(MeshgridOutput::Gpu(y_grid));
1069        let z_axis_row = provider
1070            .reshape(z, &[1, nz])
1071            .map_err(|e| builtin_error(format!("meshgrid: reshape Z failed: {e}")))?;
1072        let z_base = provider
1073            .reshape(&z_axis_row, &[1, 1, nz])
1074            .map_err(|e| builtin_error(format!("meshgrid: reshape Z(3d) failed: {e}")))?;
1075        let z_grid = provider
1076            .repmat(&z_base, &[ny, nx, 1])
1077            .map_err(|e| builtin_error(format!("meshgrid: repmat Z failed: {e}")))?;
1078        outputs.push(MeshgridOutput::Gpu(z_grid));
1079    } else {
1080        let x_grid = provider
1081            .repmat(&x_row, &[ny, 1])
1082            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
1083        let y_grid = provider
1084            .repmat(&y_col, &[1, nx])
1085            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
1086        outputs.push(MeshgridOutput::Gpu(x_grid));
1087        outputs.push(MeshgridOutput::Gpu(y_grid));
1088    }
1089
1090    Ok(Some(outputs))
1091}
1092
1093fn normalise_axes(axes: &[AxisData]) -> (AxisData, AxisData, Option<AxisData>) {
1094    match axes.len() {
1095        1 => {
1096            let x = axes[0].clone();
1097            (x.clone(), x, None)
1098        }
1099        2 => {
1100            let x = axes[0].clone();
1101            let y = axes[1].clone();
1102            (x, y, None)
1103        }
1104        3 => {
1105            let x = axes[0].clone();
1106            let y = axes[1].clone();
1107            let z = axes[2].clone();
1108            (x, y, Some(z))
1109        }
1110        _ => unreachable!(),
1111    }
1112}
1113
1114fn build_outputs(
1115    x_axis: &AxisData,
1116    y_axis: &AxisData,
1117    z_axis: Option<&AxisData>,
1118) -> Vec<GridOutput> {
1119    let nx = x_axis.len;
1120    let ny = y_axis.len;
1121    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
1122    let total = nx * ny * nz;
1123    let mut x_data = Vec::with_capacity(total);
1124    let mut y_data = Vec::with_capacity(total);
1125    let mut z_data = z_axis.map(|_| Vec::with_capacity(total));
1126
1127    for k in 0..nz {
1128        let z_value = z_axis.map(|axis| axis.values[k]);
1129        for col in 0..nx {
1130            let x_value = x_axis.values[col];
1131            for row in 0..ny {
1132                x_data.push(x_value);
1133                y_data.push(y_axis.values[row]);
1134                if let Some(ref mut z_vec) = z_data {
1135                    z_vec.push(z_value.unwrap());
1136                }
1137            }
1138        }
1139    }
1140
1141    let mut outputs = Vec::new();
1142    let base_shape = if nz == 1 {
1143        vec![ny, nx]
1144    } else {
1145        vec![ny, nx, nz]
1146    };
1147    outputs.push(GridOutput {
1148        shape: base_shape.clone(),
1149        data: x_data,
1150    });
1151    outputs.push(GridOutput {
1152        shape: base_shape.clone(),
1153        data: y_data,
1154    });
1155    if let Some(z_vec) = z_data {
1156        outputs.push(GridOutput {
1157            shape: base_shape,
1158            data: z_vec,
1159        });
1160    }
1161    outputs
1162}
1163
1164struct GridOutput {
1165    shape: Vec<usize>,
1166    data: Vec<(f64, f64)>,
1167}
1168
1169impl GridOutput {
1170    fn to_value(
1171        &self,
1172        class: PrototypeClass,
1173        residency: DevicePreference,
1174    ) -> crate::BuiltinResult<Value> {
1175        match class {
1176            PrototypeClass::Real => self.to_real_value(residency),
1177            PrototypeClass::Complex => self.to_complex_value(residency),
1178        }
1179    }
1180
1181    fn to_real_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1182        let mut real = Vec::with_capacity(self.data.len());
1183        for &(re, im) in &self.data {
1184            if im != 0.0 {
1185                return Err(builtin_error(
1186                    "meshgrid: cannot represent complex values in a real output",
1187                ));
1188            }
1189            real.push(re);
1190        }
1191        let tensor = Tensor::new(real, self.shape.clone())
1192            .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1193        match residency {
1194            DevicePreference::Host => Ok(tensor::tensor_into_value(tensor)),
1195            DevicePreference::Gpu => to_gpu_tensor_value(tensor),
1196        }
1197    }
1198
1199    fn to_complex_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1200        let tensor = ComplexTensor::new(self.data.clone(), self.shape.clone())
1201            .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1202        match residency {
1203            DevicePreference::Host => Ok(complex_tensor_into_value(tensor)),
1204            DevicePreference::Gpu => to_complex_gpu_tensor_value(tensor),
1205        }
1206    }
1207}
1208
1209fn to_gpu_tensor_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1210    if let Some(provider) = runmat_accelerate_api::provider() {
1211        let view = HostTensorView {
1212            data: &tensor.data,
1213            shape: &tensor.shape,
1214        };
1215        match provider.upload(&view) {
1216            Ok(handle) => return Ok(Value::GpuTensor(handle)),
1217            Err(err) => {
1218                warn!("meshgrid: failed to upload tensor to GPU, returning host array: {err}")
1219            }
1220        }
1221    }
1222    Ok(tensor::tensor_into_value(tensor))
1223}
1224
1225fn to_complex_gpu_tensor_value(tensor: ComplexTensor) -> crate::BuiltinResult<Value> {
1226    if let Some(provider) = runmat_accelerate_api::provider() {
1227        match gpu_helpers::upload_complex_tensor(provider, &tensor) {
1228            Ok(handle) => return Ok(gpu_helpers::complex_gpu_value(handle)),
1229            Err(err) => {
1230                warn!(
1231                    "meshgrid: failed to upload complex tensor to GPU, returning host array: {err}"
1232                )
1233            }
1234        }
1235    }
1236    Ok(complex_tensor_into_value(tensor))
1237}
1238
1239fn tensor_to_complex_tensor(tensor: Tensor) -> crate::BuiltinResult<ComplexTensor> {
1240    let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
1241    ComplexTensor::new(data, tensor.shape.clone())
1242        .map_err(|e| builtin_error(format!("meshgrid: {e}")))
1243}
1244
1245fn tensor_to_complex_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1246    let complex = tensor_to_complex_tensor(tensor)?;
1247    Ok(complex_tensor_into_value(complex))
1248}
1249
1250enum MeshgridOutput {
1251    Host(GridOutput),
1252    Gpu(GpuTensorHandle),
1253}
1254
1255impl MeshgridOutput {
1256    async fn to_value(
1257        &self,
1258        class: PrototypeClass,
1259        residency: DevicePreference,
1260    ) -> crate::BuiltinResult<Value> {
1261        match self {
1262            MeshgridOutput::Host(host) => host.to_value(class, residency),
1263            MeshgridOutput::Gpu(handle) => match (class, residency) {
1264                (PrototypeClass::Real, DevicePreference::Gpu) => {
1265                    Ok(Value::GpuTensor(handle.clone()))
1266                }
1267                (PrototypeClass::Real, DevicePreference::Host) => {
1268                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1269                    Ok(tensor::tensor_into_value(tensor))
1270                }
1271                (PrototypeClass::Complex, DevicePreference::Host) => {
1272                    match gpu_helpers::gather_value_async(&Value::GpuTensor(handle.clone())).await?
1273                    {
1274                        Value::ComplexTensor(tensor) => Ok(complex_tensor_into_value(tensor)),
1275                        Value::Complex(re, im) => Ok(Value::Complex(re, im)),
1276                        Value::Tensor(tensor) => tensor_to_complex_value(tensor),
1277                        Value::Num(n) => Ok(Value::Complex(n, 0.0)),
1278                        other => Err(builtin_error(format!(
1279                            "meshgrid: expected numeric GPU output, got {other:?}"
1280                        ))),
1281                    }
1282                }
1283                (PrototypeClass::Complex, DevicePreference::Gpu) => {
1284                    if runmat_accelerate_api::handle_storage(handle)
1285                        == GpuTensorStorage::ComplexInterleaved
1286                    {
1287                        Ok(gpu_helpers::complex_gpu_value(handle.clone()))
1288                    } else {
1289                        let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1290                        to_complex_gpu_tensor_value(tensor_to_complex_tensor(tensor)?)
1291                    }
1292                }
1293            },
1294        }
1295    }
1296}
1297
1298/// Holds the results of a `meshgrid` evaluation so multiple outputs can be
1299/// materialised without recomputing the grid.
1300pub struct MeshgridEval {
1301    outputs: Vec<MeshgridOutput>,
1302    target_class: PrototypeClass,
1303    target_residency: DevicePreference,
1304}
1305
1306impl MeshgridEval {
1307    pub fn output_count(&self) -> usize {
1308        self.outputs.len()
1309    }
1310
1311    pub async fn first(&self) -> crate::BuiltinResult<Value> {
1312        self.outputs[0]
1313            .to_value(self.target_class, self.target_residency)
1314            .await
1315    }
1316
1317    pub async fn second(&self) -> crate::BuiltinResult<Value> {
1318        if self.outputs.len() < 2 {
1319            Err(builtin_error("meshgrid: second output unavailable"))
1320        } else {
1321            self.outputs[1]
1322                .to_value(self.target_class, self.target_residency)
1323                .await
1324        }
1325    }
1326
1327    pub async fn third(&self) -> crate::BuiltinResult<Value> {
1328        if self.outputs.len() < 3 {
1329            Err(builtin_error(
1330                "meshgrid: third output requested but no Z vector was supplied",
1331            ))
1332        } else {
1333            self.outputs[2]
1334                .to_value(self.target_class, self.target_residency)
1335                .await
1336        }
1337    }
1338}
1339
// Unit tests for the `meshgrid` builtin. Async entry points are driven
// synchronously with `block_on` through the small wrappers below.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;

    use runmat_accelerate_api::HostTensorView;

    // Blocking wrapper around the async `evaluate` of the parent module.
    fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
        block_on(super::evaluate(args))
    }

    // Blocking wrapper around `MeshgridEval::first`.
    fn eval_first(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.first())
    }

    // Blocking wrapper around `MeshgridEval::second`.
    fn eval_second(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.second())
    }

    // Blocking wrapper around `MeshgridEval::third`.
    fn eval_third(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.third())
    }

    // Builds a `rows x cols` tensor from `data`; panics if construction fails.
    fn tensor_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Tensor {
        Tensor::new(data, vec![rows, cols]).unwrap()
    }

    // A single axis is used for both X and Y, giving two square grids.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_single_input_duplicates_axis() {
        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
        let eval = evaluate(&[Value::Tensor(x)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![3, 3]);
        assert_eq!(
            x_out.data,
            vec![-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
        );
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![3, 3]);
        assert_eq!(
            y_out.data,
            vec![-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0]
        );
    }

    // Scalar argument types yield an all-ones shape whose rank equals the
    // number of axes.
    #[test]
    fn meshgrid_type_infers_rank_from_axis_count() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            meshgrid_type(&[Type::Num, Type::Num], &ctx),
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1)])
            }
        );
        assert_eq!(
            meshgrid_type(&[Type::Num, Type::Num, Type::Num], &ctx),
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1), Some(1)])
            }
        );
    }

    // Known vector lengths propagate into the inferred shape as
    // [len(y), len(x)].
    #[test]
    fn meshgrid_type_uses_vector_lengths() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            meshgrid_type(
                &[
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(201)]),
                    },
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(101)]),
                    },
                ],
                &ctx,
            ),
            Type::Tensor {
                shape: Some(vec![Some(101), Some(201)])
            }
        );
    }

    // Different X and Y lengths give a [len(y), len(x)] grid.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_rectangular_inputs() {
        let x = tensor_from_vec(vec![0.0, 0.5, 1.0], 1, 3);
        let y = tensor_from_vec(vec![10.0, 20.0], 2, 1);
        let eval = evaluate(&[Value::Tensor(x), Value::Tensor(y)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![2, 3]);
        assert_eq!(x_out.data, vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0]);
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![2, 3]);
        assert_eq!(y_out.data, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
    }

    // Three axes give three outputs shaped [len(y), len(x), len(z)].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_three_inputs_volume() {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let y = tensor_from_vec(vec![5.0, 6.0, 7.0], 3, 1);
        let z = tensor_from_vec(vec![0.0, 1.0], 1, 2);
        let eval =
            evaluate(&[Value::Tensor(x), Value::Tensor(y), Value::Tensor(z)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 3);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![3, 2, 2]);
        assert_eq!(
            x_out.data,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
        );
        let z_out = test_support::gather(eval_third(&eval).expect("Z")).expect("host");
        assert_eq!(z_out.shape, vec![3, 2, 2]);
        assert_eq!(
            z_out.data,
            vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
        );
    }

    // Host axes combined with a GPU "like" prototype produce a GPU output.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_keeps_gpu_residency() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
            let y = tensor_from_vec(vec![2.0, 4.0], 2, 1);
            let proto = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let proto_view = HostTensorView {
                data: &proto.data,
                shape: &proto.shape,
            };
            let proto_handle = provider.upload(&proto_view).expect("upload");
            let eval = evaluate(&[
                Value::Tensor(x),
                Value::Tensor(y),
                Value::from("like"),
                Value::GpuTensor(proto_handle),
            ])
            .expect("meshgrid");
            let x_value = eval_first(&eval).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let gathered = test_support::gather(x_value).expect("gather");
            assert_eq!(gathered.shape, vec![2, 3]);
        });
    }

    // GPU-resident axes produce GPU-resident outputs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_gpu_inputs_roundtrip() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![0.0, 0.5], 1, 2);
            let y = tensor_from_vec(vec![1.0, 2.0], 2, 1);
            let x_view = HostTensorView {
                data: &x.data,
                shape: &x.shape,
            };
            let y_view = HostTensorView {
                data: &y.data,
                shape: &y.shape,
            };
            let x_handle = provider.upload(&x_view).expect("upload");
            let y_handle = provider.upload(&y_view).expect("upload");
            let eval = evaluate(&[Value::GpuTensor(x_handle), Value::GpuTensor(y_handle)])
                .expect("meshgrid");
            assert!(matches!(eval_first(&eval).expect("X"), Value::GpuTensor(_)));
            assert!(matches!(
                eval_second(&eval).expect("Y"),
                Value::GpuTensor(_)
            ));
        });
    }

    // Compares the wgpu provider's results with the host results. Returns
    // early (without asserting anything) if the wgpu provider cannot be
    // registered.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn meshgrid_wgpu_matches_cpu() {
        let _guard = test_support::accel_test_lock();
        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) else {
            return;
        };

        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0, 2.0], 1, 4);
        let y = tensor_from_vec(vec![5.0, 6.0], 2, 1);

        let cpu_eval =
            evaluate(&[Value::Tensor(x.clone()), Value::Tensor(y.clone())]).expect("meshgrid cpu");
        let cpu_x =
            test_support::gather(eval_first(&cpu_eval).expect("X cpu")).expect("gather X cpu");
        let cpu_y =
            test_support::gather(eval_second(&cpu_eval).expect("Y cpu")).expect("gather Y cpu");

        let x_view = HostTensorView {
            data: &x.data,
            shape: &x.shape,
        };
        let y_view = HostTensorView {
            data: &y.data,
            shape: &y.shape,
        };
        let x_gpu = provider.upload(&x_view).expect("upload x");
        let y_gpu = provider.upload(&y_view).expect("upload y");

        let gpu_eval =
            evaluate(&[Value::GpuTensor(x_gpu), Value::GpuTensor(y_gpu)]).expect("meshgrid gpu");
        let gpu_x_value = eval_first(&gpu_eval).expect("X gpu");
        let gpu_y_value = eval_second(&gpu_eval).expect("Y gpu");

        assert!(matches!(gpu_x_value, Value::GpuTensor(_)));
        assert!(matches!(gpu_y_value, Value::GpuTensor(_)));

        let gathered_x = test_support::gather(gpu_x_value).expect("gather X gpu");
        let gathered_y = test_support::gather(gpu_y_value).expect("gather Y gpu");

        assert_eq!(gathered_x.shape, cpu_x.shape);
        assert_eq!(gathered_x.data, cpu_x.data);
        assert_eq!(gathered_y.shape, cpu_y.shape);
        assert_eq!(gathered_y.data, cpu_y.data);
    }

    // A complex host axis yields a complex output.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_complex_inputs_produce_complex_outputs() {
        let complex = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
        let eval = evaluate(&[Value::ComplexTensor(complex)]).expect("meshgrid");
        let x_value = eval_first(&eval).expect("X");
        match x_value {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 2]);
            }
            Value::Complex(_, _) => {}
            other => panic!("expected complex output, got {other:?}"),
        }
    }

    // A real host axis with a complex GPU "like" prototype yields a
    // complex-interleaved GPU output with zero imaginary parts.
    #[test]
    fn meshgrid_like_complex_gpu_prototype_keeps_complex_residency() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
            let proto = ComplexTensor::new(vec![(0.0, 1.0)], vec![1, 1]).unwrap();
            let proto_handle =
                gpu_helpers::upload_complex_tensor(provider, &proto).expect("upload");

            let eval = evaluate(&[
                Value::Tensor(x),
                Value::from("like"),
                Value::GpuTensor(proto_handle),
            ])
            .expect("meshgrid");
            let x_value = eval_first(&eval).expect("X");
            let Value::GpuTensor(handle) = x_value else {
                panic!("expected complex gpu tensor");
            };
            assert_eq!(
                runmat_accelerate_api::handle_storage(&handle),
                GpuTensorStorage::ComplexInterleaved
            );
            let gathered = block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle)))
                .expect("gather");
            let Value::ComplexTensor(tensor) = gathered else {
                panic!("expected complex tensor");
            };
            assert_eq!(tensor.shape, vec![2, 2]);
            assert_eq!(
                tensor.data,
                vec![(1.0, 0.0), (1.0, 0.0), (2.0, 0.0), (2.0, 0.0)]
            );
        });
    }

    // A complex GPU axis yields a complex-interleaved GPU output that carries
    // the axis values.
    #[test]
    fn meshgrid_complex_gpu_axis_stays_resident() {
        test_support::with_test_provider(|provider| {
            let axis = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
            let axis_handle = gpu_helpers::upload_complex_tensor(provider, &axis).expect("upload");

            let eval = evaluate(&[Value::GpuTensor(axis_handle)]).expect("meshgrid");
            let x_value = eval_first(&eval).expect("X");
            let Value::GpuTensor(handle) = x_value else {
                panic!("expected complex gpu tensor");
            };
            assert_eq!(
                runmat_accelerate_api::handle_storage(&handle),
                GpuTensorStorage::ComplexInterleaved
            );
            let gathered = block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle)))
                .expect("gather");
            let Value::ComplexTensor(tensor) = gathered else {
                panic!("expected complex tensor");
            };
            assert_eq!(tensor.shape, vec![2, 2]);
            assert_eq!(
                tensor.data,
                vec![(1.0, 1.0), (1.0, 1.0), (2.0, -1.0), (2.0, -1.0)]
            );
        });
    }

    // Same check as above against the wgpu provider, compared with the host
    // result. Returns early if the wgpu provider cannot be registered.
    #[test]
    #[cfg(feature = "wgpu")]
    fn meshgrid_wgpu_complex_axis_matches_cpu_and_stays_resident() {
        let _guard = test_support::accel_test_lock();
        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) else {
            return;
        };

        let axis = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
        let cpu_eval = evaluate(&[Value::ComplexTensor(axis.clone())]).expect("meshgrid cpu");
        let cpu_x = match eval_first(&cpu_eval).expect("X cpu") {
            Value::ComplexTensor(tensor) => tensor,
            other => panic!("expected cpu complex tensor, got {other:?}"),
        };

        let axis_handle = gpu_helpers::upload_complex_tensor(provider, &axis).expect("upload");
        let gpu_eval = evaluate(&[Value::GpuTensor(axis_handle)]).expect("meshgrid gpu");
        let gpu_x = eval_first(&gpu_eval).expect("X gpu");
        let Value::GpuTensor(handle) = gpu_x else {
            panic!("expected complex gpu tensor");
        };
        assert_eq!(
            runmat_accelerate_api::handle_storage(&handle),
            GpuTensorStorage::ComplexInterleaved
        );
        let gathered =
            block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle))).expect("gather");
        let Value::ComplexTensor(gpu_tensor) = gathered else {
            panic!("expected complex tensor");
        };
        assert_eq!(gpu_tensor.shape, cpu_x.shape);
        assert_eq!(gpu_tensor.data, cpu_x.data);
    }

    // A host scalar "like" prototype keeps the output on the host.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_host_prototype() {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let eval =
            evaluate(&[Value::Tensor(x), Value::from("like"), Value::Num(0.0)]).expect("meshgrid");
        let x_out = eval_first(&eval).expect("X");
        assert!(matches!(x_out, Value::Tensor(_) | Value::Num(_)));
    }
}