1use std::cmp::max;
4
5use log::warn;
6use runmat_accelerate_api::{GpuTensorHandle, HostTensorView};
7use runmat_builtins::{
8 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
9 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
10 ComplexTensor, ResolveContext, Tensor, Type, Value,
11};
12
13use crate::builtins::array::type_resolvers::size_vector_len;
14use runmat_macros::runtime_builtin;
15
16use crate::build_runtime_error;
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
19use crate::builtins::common::residency::{sequence_gpu_preference, SequenceIntent};
20use crate::builtins::common::spec::{
21 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
23};
24use crate::builtins::common::tensor;
25
// GPU execution metadata for `meshgrid`, registered via `register_gpu_spec`.
// It declares a custom "array_construct" op kind and a custom provider hook
// named "meshgrid"; the `notes` field records the host-build-then-upload
// fallback used when no dedicated hook is available.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "meshgrid",
    op_kind: GpuOpKind::Custom("array_construct"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("meshgrid")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers may supply a dedicated meshgrid hook; until then the runtime builds grids on the host and uploads them when GPU residency is requested.",
};
41
42fn builtin_error(message: impl Into<String>) -> crate::RuntimeError {
43 build_runtime_error(message)
44 .with_builtin("meshgrid")
45 .build()
46}
47
// Fusion metadata for `meshgrid`, registered via `register_fusion_spec`.
// No elementwise or reduction template is supplied; per `notes`, the builtin
// materialises dense coordinate arrays and bypasses fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "meshgrid",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "Meshgrid explicitly materialises dense coordinate arrays and therefore bypasses fusion.",
};
59
60fn meshgrid_type(args: &[Type], _context: &ResolveContext) -> Type {
61 if args.is_empty() {
62 return Type::Unknown;
63 }
64 let mut axis_count = args.len();
65 if axis_count >= 2 && matches!(args[axis_count - 2], Type::String) {
66 axis_count = axis_count.saturating_sub(2);
67 }
68 if axis_count == 0 {
69 return Type::Unknown;
70 }
71 let axis_args = &args[..axis_count];
72 let len_x = axis_args.get(0).and_then(size_vector_len);
73 let len_y = axis_args.get(1).and_then(size_vector_len).or(len_x);
74 let len_z = axis_args.get(2).and_then(size_vector_len);
75 let shape = if axis_count >= 3 {
76 vec![len_y, len_x, len_z]
77 } else {
78 vec![len_y, len_x]
79 };
80 Type::Tensor { shape: Some(shape) }
81}
82
/// Output descriptors (`X`, `Y`) shared by the two-output signatures.
const MESHGRID_OUTPUT_XY: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along X-axis.",
    },
    BuiltinParamDescriptor {
        name: "Y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along Y-axis.",
    },
];
99
/// Output descriptors (`X`, `Y`, optional `Z`) used by the three-axis signatures.
const MESHGRID_OUTPUT_XYZ: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along X-axis.",
    },
    BuiltinParamDescriptor {
        name: "Y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along Y-axis.",
    },
    BuiltinParamDescriptor {
        name: "Z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Grid coordinates along Z-axis.",
    },
];
123
/// Input descriptors for the single-axis form `meshgrid(x)`.
const MESHGRID_SIG_X_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "x",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "X-axis vector.",
}];
131
/// Input descriptors for the two-axis form `meshgrid(x, y)`.
const MESHGRID_SIG_XY_INPUTS: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
];
148
/// Input descriptors for the three-axis form `meshgrid(x, y, z)`; `z` is
/// marked optional in the descriptor.
const MESHGRID_SIG_XYZ_INPUTS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Z-axis vector.",
    },
];
172
/// Input descriptors for `meshgrid(x, "like", prototype)`.
const MESHGRID_SIG_X_LIKE_INPUTS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];
196
/// Input descriptors for `meshgrid(x, y, "like", prototype)`.
const MESHGRID_SIG_XY_LIKE_INPUTS: [BuiltinParamDescriptor; 4] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];
227
/// Input descriptors for `meshgrid(x, y, z, "like", prototype)`.
const MESHGRID_SIG_XYZ_LIKE_INPUTS: [BuiltinParamDescriptor; 5] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Z-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];
265
/// The six documented call forms of `meshgrid`: one, two or three axis
/// vectors, each with and without a trailing `"like", prototype` pair.
const MESHGRID_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x)",
        inputs: &MESHGRID_SIG_X_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, y)",
        inputs: &MESHGRID_SIG_XY_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y,Z] = meshgrid(x, y, z)",
        inputs: &MESHGRID_SIG_XYZ_INPUTS,
        outputs: &MESHGRID_OUTPUT_XYZ,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, \"like\", prototype)",
        inputs: &MESHGRID_SIG_X_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, y, \"like\", prototype)",
        inputs: &MESHGRID_SIG_XY_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y,Z] = meshgrid(x, y, z, \"like\", prototype)",
        inputs: &MESHGRID_SIG_XYZ_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XYZ,
    },
];
298
/// Catalogue of error conditions advertised for `meshgrid`.
///
/// NOTE(review): several messages raised in this file are worded differently
/// from the `message` fields here (for example the output-count errors and the
/// "unrecognised option '<name>'" text) — confirm whether these entries are
/// meant to be exact matches or summaries.
const MESHGRID_ERRORS: [BuiltinErrorDescriptor; 11] = [
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.MISSING_AXIS",
        identifier: None,
        when: "No axis vectors are provided.",
        message: "meshgrid: at least one input vector is required",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.TOO_MANY_AXES",
        identifier: None,
        when: "More than three axis vectors are provided.",
        message: "meshgrid: expected at most three input vectors",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.LIKE_EXPECTED_PROTOTYPE",
        identifier: None,
        when: "The 'like' keyword is provided without a prototype argument.",
        message: "meshgrid: expected prototype after 'like'",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.MULTIPLE_LIKE",
        identifier: None,
        when: "The 'like' keyword is provided multiple times.",
        message: "meshgrid: multiple 'like' specifications are not supported",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.LIKE_POSITION",
        identifier: None,
        when: "The 'like' keyword is in an invalid position or not final.",
        message: "meshgrid: 'like' must be the final argument",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.UNRECOGNIZED_OPTION",
        identifier: None,
        when: "A trailing option string is not recognized.",
        message: "meshgrid: unrecognised option",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.INVALID_AXIS_INPUT",
        identifier: None,
        when: "Axis inputs are non-numeric or non-vector shapes.",
        message: "meshgrid: input argument must be numeric vector data",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.INVALID_PROTOTYPE",
        identifier: None,
        when: "The 'like' prototype is unsupported.",
        message: "meshgrid: prototypes must be numeric arrays",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.OUTPUT_COUNT_EXCEEDED",
        identifier: None,
        when: "Requested outputs exceed available outputs for provided axes.",
        message:
            "meshgrid: supports at most two outputs for 2-axis inputs and three for 3-axis inputs",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.THIRD_OUTPUT_UNAVAILABLE",
        identifier: None,
        when: "A third output is requested without supplying a Z-axis vector.",
        message: "meshgrid: third output requested but no Z vector was supplied",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.COMPLEX_REAL_CONVERSION",
        identifier: None,
        when: "Complex axis values cannot be represented in requested real output class.",
        message: "meshgrid: cannot represent complex values in a real output",
    },
];
368
/// Public descriptor for `meshgrid`, bundling its signatures and error
/// catalogue; referenced by the `descriptor(...)` argument of the
/// `runtime_builtin` attribute on `meshgrid_builtin`.
pub const MESHGRID_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &MESHGRID_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &MESHGRID_ERRORS,
};
375
376#[runtime_builtin(
377 name = "meshgrid",
378 category = "array/creation",
379 summary = "Generate coordinate matrices for 2-D and 3-D grids.",
380 keywords = "meshgrid,grid,gpu,like,3d",
381 accel = "array_construct",
382 type_resolver(meshgrid_type),
383 descriptor(crate::builtins::array::creation::meshgrid::MESHGRID_DESCRIPTOR),
384 builtin_path = "crate::builtins::array::creation::meshgrid"
385)]
386async fn meshgrid_builtin(rest: Vec<Value>) -> crate::BuiltinResult<Value> {
387 let eval = evaluate(&rest).await?;
388 if let Some(out_count) = crate::output_count::current_output_count() {
389 if out_count == 0 {
390 return Ok(Value::OutputList(Vec::new()));
391 }
392 let available = eval.output_count();
393 if out_count > available {
394 let msg = if available == 2 {
395 "meshgrid with two inputs supports at most two outputs"
396 } else {
397 "meshgrid supports at most three outputs"
398 };
399 return Err(builtin_error(msg));
400 }
401 let mut outputs = Vec::with_capacity(out_count);
402 let first = eval.first().await?;
403 outputs.push(first);
404 if out_count >= 2 {
405 outputs.push(eval.second().await?);
406 }
407 if out_count >= 3 {
408 outputs.push(eval.third().await?);
409 }
410 return Ok(Value::OutputList(outputs));
411 }
412 eval.first().await
413}
414
415pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
417 let parsed = ParsedMeshgrid::parse(args).await?;
418 let (x_axis, y_axis, z_axis) = normalise_axes(&parsed.axes);
419
420 let require_complex = parsed.axes.iter().any(|axis| axis.is_complex);
421
422 let target_class = match &parsed.template {
423 OutputTemplate::Default => {
424 if require_complex {
425 PrototypeClass::Complex
426 } else {
427 PrototypeClass::Real
428 }
429 }
430 OutputTemplate::Like(spec) => {
431 if require_complex {
432 PrototypeClass::Complex
433 } else {
434 spec.class
435 }
436 }
437 };
438
439 let target_residency = match &parsed.template {
440 OutputTemplate::Default => {
441 if parsed.prefer_gpu {
442 DevicePreference::Gpu
443 } else {
444 DevicePreference::Host
445 }
446 }
447 OutputTemplate::Like(spec) => spec.residency,
448 };
449
450 let axes_all_real = !require_complex;
451 let mut outputs: Vec<MeshgridOutput> = Vec::new();
452
453 if axes_all_real
454 && matches!(target_class, PrototypeClass::Real)
455 && matches!(target_residency, DevicePreference::Gpu)
456 {
457 if let Some(gpu) = try_meshgrid_gpu_from_vector_axes(&x_axis, &y_axis, z_axis.as_ref())? {
458 outputs = gpu;
459 }
460 }
461
462 if outputs.is_empty() {
463 let x_host = axis_to_host_async(&x_axis).await?;
465 let y_host = axis_to_host_async(&y_axis).await?;
466 let z_host = match z_axis.as_ref() {
467 Some(axis) => Some(axis_to_host_async(axis).await?),
468 None => None,
469 };
470 outputs = build_outputs(&x_host, &y_host, z_host.as_ref())
471 .into_iter()
472 .map(MeshgridOutput::Host)
473 .collect();
474 }
475
476 Ok(MeshgridEval {
477 outputs,
478 target_class,
479 target_residency,
480 })
481}
482
/// Result of parsing the `meshgrid` argument list.
#[derive(Clone)]
struct ParsedMeshgrid {
    /// Axis inputs in call order (between one and three entries after parsing).
    axes: Vec<AxisData>,
    /// Output template: default, or derived from a `'like'` prototype.
    template: OutputTemplate,
    /// Whether GPU residency is preferred when no prototype dictates it.
    prefer_gpu: bool,
}
489
490impl ParsedMeshgrid {
491 async fn parse(args: &[Value]) -> crate::BuiltinResult<Self> {
492 if args.is_empty() {
493 return Err(builtin_error(
494 "meshgrid: at least one input vector is required",
495 ));
496 }
497 let mut axis_values: Vec<Value> = Vec::new();
498 let mut like_proto: Option<Value> = None;
499 let mut prefer_gpu = false;
500 let mut idx = 0;
501 while idx < args.len() {
502 let value = args[idx].clone();
503 if let Some(keyword) = keyword_of(&value) {
504 match keyword.as_str() {
505 "like" => {
506 if like_proto.is_some() {
507 return Err(builtin_error(
508 "meshgrid: multiple 'like' specifications are not supported",
509 ));
510 }
511 if axis_values.is_empty() {
512 return Err(builtin_error(
513 "meshgrid: 'like' must follow at least one input vector",
514 ));
515 }
516 let Some(proto) = args.get(idx + 1).cloned() else {
517 return Err(builtin_error("meshgrid: expected prototype after 'like'"));
518 };
519 like_proto = Some(proto);
520 idx += 2;
521 if idx < args.len() {
522 return Err(builtin_error(
523 "meshgrid: 'like' must be the final argument",
524 ));
525 }
526 break;
527 }
528 other => {
529 return Err(builtin_error(format!(
530 "meshgrid: unrecognised option '{other}'"
531 )));
532 }
533 }
534 }
535
536 if let Value::GpuTensor(_) = value {
537 prefer_gpu = true;
538 }
539 axis_values.push(value);
540 idx += 1;
541 }
542
543 if axis_values.is_empty() {
544 return Err(builtin_error(
545 "meshgrid: at least one input vector is required",
546 ));
547 }
548 if axis_values.len() > 3 {
549 return Err(builtin_error(
550 "meshgrid: expected at most three input vectors",
551 ));
552 }
553
554 let mut axes = Vec::with_capacity(max(axis_values.len(), 2));
555 for (i, value) in axis_values.into_iter().enumerate() {
556 let mut consumed_gpu = false;
557 let data = axis_from_value(value, i, &mut consumed_gpu).await?;
558 if consumed_gpu {
559 prefer_gpu = true;
560 }
561 axes.push(data);
562 }
563
564 if !prefer_gpu {
565 if let Some(max_len) = axes.iter().map(|axis| axis.len).max() {
566 if max_len > 0
567 && sequence_gpu_preference(max_len, SequenceIntent::MeshAxis, false).prefer_gpu
568 {
569 prefer_gpu = true;
570 }
571 }
572 }
573
574 let template = if let Some(proto) = like_proto {
575 OutputTemplate::Like(analyse_like_prototype(&proto)?)
576 } else {
577 OutputTemplate::Default
578 };
579
580 Ok(Self {
581 axes,
582 template,
583 prefer_gpu,
584 })
585 }
586}
587
/// How the output class and residency are chosen.
#[derive(Clone)]
enum OutputTemplate {
    /// No `'like'` prototype was supplied.
    Default,
    /// Class and residency come from a `'like'` prototype.
    Like(PrototypeSpec),
}
593
/// Properties extracted from a `'like'` prototype value.
#[derive(Clone)]
struct PrototypeSpec {
    /// Where outputs should live (host or GPU).
    residency: DevicePreference,
    /// Whether outputs should be real or complex.
    class: PrototypeClass,
}
599
/// Numeric class of the generated grids.
#[derive(Clone, Copy, PartialEq, Eq)]
enum PrototypeClass {
    /// Real-valued output.
    Real,
    /// Complex-valued output.
    Complex,
}
605
/// Requested residency of the generated grids.
#[derive(Clone, Copy)]
enum DevicePreference {
    /// Keep outputs in host memory.
    Host,
    /// Place outputs on the GPU when possible.
    Gpu,
}
611
612fn analyse_like_prototype(proto: &Value) -> crate::BuiltinResult<PrototypeSpec> {
613 match proto {
614 Value::GpuTensor(_) => Ok(PrototypeSpec {
615 residency: DevicePreference::Gpu,
616 class: PrototypeClass::Real,
617 }),
618 Value::ComplexTensor(_) | Value::Complex(_, _) => Ok(PrototypeSpec {
619 residency: DevicePreference::Host,
620 class: PrototypeClass::Complex,
621 }),
622 Value::Tensor(_)
623 | Value::Num(_)
624 | Value::Int(_)
625 | Value::Bool(_)
626 | Value::LogicalArray(_) => Ok(PrototypeSpec {
627 residency: DevicePreference::Host,
628 class: PrototypeClass::Real,
629 }),
630 Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(builtin_error(
631 "meshgrid: prototypes must be numeric or gpuArray values",
632 )),
633 Value::Cell(_)
634 | Value::Struct(_)
635 | Value::Object(_)
636 | Value::HandleObject(_)
637 | Value::Listener(_)
638 | Value::FunctionHandle(_)
639 | Value::ExternalFunctionHandle(_)
640 | Value::MethodFunctionHandle(_)
641 | Value::BoundFunctionHandle { .. }
642 | Value::Closure(_)
643 | Value::ClassRef(_)
644 | Value::MException(_)
645 | Value::OutputList(_) => Err(builtin_error("meshgrid: prototypes must be numeric arrays")),
646 }
647}
648
/// One axis input of `meshgrid`.
#[derive(Clone)]
struct AxisData {
    /// Axis samples as `(real, imaginary)` pairs; left empty when the axis is
    /// only held as a GPU handle.
    values: Vec<(f64, f64)>,
    /// Number of samples along the axis.
    len: usize,
    /// Whether any sample was classified as having an imaginary part.
    is_complex: bool,
    /// Handle of the original GPU tensor when the axis was supplied as a
    /// vector-shaped GPU tensor.
    gpu_real: Option<GpuTensorHandle>,
}
656
657async fn axis_from_value(
658 value: Value,
659 index: usize,
660 prefer_gpu: &mut bool,
661) -> crate::BuiltinResult<AxisData> {
662 match value {
663 Value::Tensor(tensor) => axis_from_tensor(tensor, index),
664 Value::LogicalArray(logical) => {
665 let tensor = tensor::logical_to_tensor(&logical)?;
666 axis_from_tensor(tensor, index)
667 }
668 Value::Num(n) => Ok(AxisData {
669 values: vec![(n, 0.0)],
670 len: 1,
671 is_complex: false,
672 gpu_real: None,
673 }),
674 Value::Int(i) => {
675 let val = i.to_f64();
676 Ok(AxisData {
677 values: vec![(val, 0.0)],
678 len: 1,
679 is_complex: false,
680 gpu_real: None,
681 })
682 }
683 Value::Bool(b) => Ok(AxisData {
684 values: vec![(if b { 1.0 } else { 0.0 }, 0.0)],
685 len: 1,
686 is_complex: false,
687 gpu_real: None,
688 }),
689 Value::Complex(re, im) => Ok(AxisData {
690 values: vec![(re, im)],
691 len: 1,
692 is_complex: im != 0.0,
693 gpu_real: None,
694 }),
695 Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, index),
696 Value::GpuTensor(handle) => {
697 if is_vector_shape(&handle.shape) {
700 *prefer_gpu = true;
701 return Ok(AxisData {
702 values: Vec::new(),
703 len: vector_len_from_shape(&handle.shape),
704 is_complex: false,
705 gpu_real: Some(handle),
706 });
707 }
708
709 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
711 if is_vector_shape(&tensor.shape) {
712 *prefer_gpu = true;
713 }
714 axis_from_tensor(tensor, index)
715 }
716 other => Err(builtin_error(format!(
717 "meshgrid: input argument {} must be numeric, got {other:?}",
718 index + 1
719 ))),
720 }
721}
722
723fn axis_from_tensor(tensor: Tensor, index: usize) -> crate::BuiltinResult<AxisData> {
724 if is_vector_shape(&tensor.shape) {
725 let mut values = Vec::with_capacity(tensor.data.len());
726 for &v in &tensor.data {
727 values.push((v, 0.0));
728 }
729 return Ok(AxisData {
730 len: values.len(),
731 values,
732 is_complex: false,
733 gpu_real: None,
734 });
735 }
736
737 if let Some(axis) = axis_from_meshgrid_matrix_real(&tensor, index)? {
743 return Ok(axis);
744 }
745
746 Err(builtin_error(format!(
747 "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
748 index + 1,
749 tensor.shape
750 )))
751}
752
753fn axis_from_complex_tensor(tensor: ComplexTensor, index: usize) -> crate::BuiltinResult<AxisData> {
754 if is_vector_shape(&tensor.shape) {
755 let is_complex = tensor
756 .data
757 .iter()
758 .any(|&(_, imag)| !imag.is_nan() && imag != 0.0);
759 return Ok(AxisData {
760 len: tensor.data.len(),
761 values: tensor.data,
762 is_complex,
763 gpu_real: None,
764 });
765 }
766
767 if let Some(axis) = axis_from_meshgrid_matrix_complex(&tensor, index)? {
768 return Ok(axis);
769 }
770
771 Err(builtin_error(format!(
772 "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
773 index + 1,
774 tensor.shape
775 )))
776}
777
778fn axis_from_meshgrid_matrix_real(
779 tensor: &Tensor,
780 index: usize,
781) -> crate::BuiltinResult<Option<AxisData>> {
782 let (rows, cols) = match tensor.shape.as_slice() {
783 [r, c] => (*r, *c),
784 _ => return Ok(None),
785 };
786 if rows <= 1 || cols <= 1 {
787 return Ok(None);
788 }
789
790 let expect_rows_constant = index == 0;
793
794 if expect_rows_constant {
795 if !matrix_rows_are_identical_real(tensor, rows, cols) {
796 return Ok(None);
797 }
798 let mut values = Vec::with_capacity(cols);
800 for col in 0..cols {
801 let idx = rows * col;
802 values.push((tensor.data[idx], 0.0));
803 }
804 return Ok(Some(AxisData {
805 len: values.len(),
806 values,
807 is_complex: false,
808 gpu_real: None,
809 }));
810 }
811
812 if !matrix_cols_are_identical_real(tensor, rows, cols) {
813 return Ok(None);
814 }
815 let mut values = Vec::with_capacity(rows);
817 for row in 0..rows {
818 values.push((tensor.data[row], 0.0));
819 }
820 Ok(Some(AxisData {
821 len: values.len(),
822 values,
823 is_complex: false,
824 gpu_real: None,
825 }))
826}
827
828fn axis_from_meshgrid_matrix_complex(
829 tensor: &ComplexTensor,
830 index: usize,
831) -> crate::BuiltinResult<Option<AxisData>> {
832 let (rows, cols) = match tensor.shape.as_slice() {
833 [r, c] => (*r, *c),
834 _ => return Ok(None),
835 };
836 if rows <= 1 || cols <= 1 {
837 return Ok(None);
838 }
839
840 let expect_rows_constant = index == 0;
841 if expect_rows_constant {
842 if !matrix_rows_are_identical_complex(tensor, rows, cols) {
843 return Ok(None);
844 }
845 let mut values = Vec::with_capacity(cols);
846 for col in 0..cols {
847 let idx = rows * col;
848 values.push(tensor.data[idx]);
849 }
850 let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
851 return Ok(Some(AxisData {
852 len: values.len(),
853 values,
854 is_complex,
855 gpu_real: None,
856 }));
857 }
858
859 if !matrix_cols_are_identical_complex(tensor, rows, cols) {
860 return Ok(None);
861 }
862 let mut values = Vec::with_capacity(rows);
863 for row in 0..rows {
864 values.push(tensor.data[row]);
865 }
866 let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
867 Ok(Some(AxisData {
868 len: values.len(),
869 values,
870 is_complex,
871 gpu_real: None,
872 }))
873}
874
875fn matrix_rows_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
876 for row in 1..rows {
877 for col in 0..cols {
878 let idx0 = rows * col;
879 let idx = row + rows * col;
880 if tensor.data[idx] != tensor.data[idx0] {
881 return false;
882 }
883 }
884 }
885 true
886}
887
888fn matrix_cols_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
889 for col in 1..cols {
890 for row in 0..rows {
891 let idx0 = row;
892 let idx = row + rows * col;
893 if tensor.data[idx] != tensor.data[idx0] {
894 return false;
895 }
896 }
897 }
898 true
899}
900
901fn matrix_rows_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
902 for row in 1..rows {
903 for col in 0..cols {
904 let idx0 = rows * col;
905 let idx = row + rows * col;
906 if tensor.data[idx] != tensor.data[idx0] {
907 return false;
908 }
909 }
910 }
911 true
912}
913
914fn matrix_cols_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
915 for col in 1..cols {
916 for row in 0..rows {
917 let idx0 = row;
918 let idx = row + rows * col;
919 if tensor.data[idx] != tensor.data[idx0] {
920 return false;
921 }
922 }
923 }
924 true
925}
926
/// Returns `true` when `shape` has at most one dimension larger than 1.
/// An empty shape counts as a vector.
fn is_vector_shape(shape: &[usize]) -> bool {
    shape.iter().filter(|&&dim| dim > 1).count() <= 1
}
939
/// Returns the number of elements described by a vector-like `shape`.
///
/// The length is the product of all dimensions, so it always matches the
/// element count of the underlying buffer: an empty shape yields 1, a shape
/// with a single non-singleton dimension yields that dimension, and any
/// zero-sized dimension yields 0. Taking the largest dimension instead would
/// report `[0, 3]` as length 3 even though such a tensor holds no elements.
///
/// Callers are expected to pass shapes accepted by `is_vector_shape`, for
/// which at most one factor exceeds 1.
fn vector_len_from_shape(shape: &[usize]) -> usize {
    shape.iter().product()
}
946
947async fn axis_to_host_async(axis: &AxisData) -> crate::BuiltinResult<AxisData> {
948 if axis.gpu_real.is_none() {
949 return Ok(axis.clone());
950 }
951 let handle = axis.gpu_real.as_ref().expect("checked gpu_real is_some");
952 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
953 axis_from_tensor(tensor, 0)
955}
956
/// Attempts to build the grids directly through the acceleration provider.
///
/// Returns `Ok(None)` when any required axis lacks a GPU handle or when no
/// provider can be resolved for one of the handles; the caller then falls back
/// to the host path. On success the result holds two grids (`X`, `Y`) or three
/// (`X`, `Y`, `Z`), in that order, each produced with provider `reshape` and
/// `repmat` calls.
///
/// # Errors
/// Any failing provider `reshape`/`repmat` call is turned into a builtin error
/// naming the step that failed.
fn try_meshgrid_gpu_from_vector_axes(
    x_axis: &AxisData,
    y_axis: &AxisData,
    z_axis: Option<&AxisData>,
) -> crate::BuiltinResult<Option<Vec<MeshgridOutput>>> {
    // Both X and Y must be backed by GPU handles.
    let Some(x_handle) = x_axis.gpu_real.as_ref() else {
        return Ok(None);
    };
    let Some(y_handle) = y_axis.gpu_real.as_ref() else {
        return Ok(None);
    };

    // A Z axis is optional, but when present it must also be GPU-backed.
    let z_handle = match z_axis {
        Some(axis) => match axis.gpu_real.as_ref() {
            Some(h) => Some(h),
            None => return Ok(None),
        },
        None => None,
    };

    // All operations below go through the provider resolved for the X handle.
    // NOTE(review): Y and Z are only checked for having *a* provider, not for
    // sharing X's provider — confirm that mixed providers cannot occur.
    let Some(provider) = runmat_accelerate_api::provider_for_handle(x_handle) else {
        return Ok(None);
    };
    if runmat_accelerate_api::provider_for_handle(y_handle).is_none() {
        return Ok(None);
    }
    if let Some(z) = z_handle {
        if runmat_accelerate_api::provider_for_handle(z).is_none() {
            return Ok(None);
        }
    }

    let nx = x_axis.len;
    let ny = y_axis.len;
    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);

    // Orient X as a 1 x nx row and Y as an ny x 1 column before tiling.
    let x_row = provider
        .reshape(x_handle, &[1, nx])
        .map_err(|e| builtin_error(format!("meshgrid: reshape X failed: {e}")))?;
    let y_col = provider
        .reshape(y_handle, &[ny, 1])
        .map_err(|e| builtin_error(format!("meshgrid: reshape Y failed: {e}")))?;

    let mut outputs = Vec::with_capacity(if z_handle.is_some() { 3 } else { 2 });
    if let Some(z) = z_handle {
        // 3-D case: lift each axis to rank 3, then tile to ny x nx x nz.
        let x_base = provider
            .reshape(&x_row, &[1, nx, 1])
            .map_err(|e| builtin_error(format!("meshgrid: reshape X(3d) failed: {e}")))?;
        let y_base = provider
            .reshape(&y_col, &[ny, 1, 1])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Y(3d) failed: {e}")))?;

        let x_grid = provider
            .repmat(&x_base, &[ny, 1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
        let y_grid = provider
            .repmat(&y_base, &[1, nx, nz])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;

        outputs.push(MeshgridOutput::GpuReal(x_grid));
        outputs.push(MeshgridOutput::GpuReal(y_grid));
        let z_axis_row = provider
            .reshape(z, &[1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Z failed: {e}")))?;
        let z_base = provider
            .reshape(&z_axis_row, &[1, 1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Z(3d) failed: {e}")))?;
        let z_grid = provider
            .repmat(&z_base, &[ny, nx, 1])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Z failed: {e}")))?;
        outputs.push(MeshgridOutput::GpuReal(z_grid));
    } else {
        // 2-D case: tile the row ny times down and the column nx times across.
        let x_grid = provider
            .repmat(&x_row, &[ny, 1])
            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
        let y_grid = provider
            .repmat(&y_col, &[1, nx])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
        outputs.push(MeshgridOutput::GpuReal(x_grid));
        outputs.push(MeshgridOutput::GpuReal(y_grid));
    }

    Ok(Some(outputs))
}
1042
1043fn normalise_axes(axes: &[AxisData]) -> (AxisData, AxisData, Option<AxisData>) {
1044 match axes.len() {
1045 1 => {
1046 let x = axes[0].clone();
1047 (x.clone(), x, None)
1048 }
1049 2 => {
1050 let x = axes[0].clone();
1051 let y = axes[1].clone();
1052 (x, y, None)
1053 }
1054 3 => {
1055 let x = axes[0].clone();
1056 let y = axes[1].clone();
1057 let z = axes[2].clone();
1058 (x, y, Some(z))
1059 }
1060 _ => unreachable!(),
1061 }
1062}
1063
1064fn build_outputs(
1065 x_axis: &AxisData,
1066 y_axis: &AxisData,
1067 z_axis: Option<&AxisData>,
1068) -> Vec<GridOutput> {
1069 let nx = x_axis.len;
1070 let ny = y_axis.len;
1071 let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
1072 let total = nx * ny * nz;
1073 let mut x_data = Vec::with_capacity(total);
1074 let mut y_data = Vec::with_capacity(total);
1075 let mut z_data = z_axis.map(|_| Vec::with_capacity(total));
1076
1077 for k in 0..nz {
1078 let z_value = z_axis.map(|axis| axis.values[k]);
1079 for col in 0..nx {
1080 let x_value = x_axis.values[col];
1081 for row in 0..ny {
1082 x_data.push(x_value);
1083 y_data.push(y_axis.values[row]);
1084 if let Some(ref mut z_vec) = z_data {
1085 z_vec.push(z_value.unwrap());
1086 }
1087 }
1088 }
1089 }
1090
1091 let mut outputs = Vec::new();
1092 let base_shape = if nz == 1 {
1093 vec![ny, nx]
1094 } else {
1095 vec![ny, nx, nz]
1096 };
1097 outputs.push(GridOutput {
1098 shape: base_shape.clone(),
1099 data: x_data,
1100 });
1101 outputs.push(GridOutput {
1102 shape: base_shape.clone(),
1103 data: y_data,
1104 });
1105 if let Some(z_vec) = z_data {
1106 outputs.push(GridOutput {
1107 shape: base_shape,
1108 data: z_vec,
1109 });
1110 }
1111 outputs
1112}
1113
/// A host-side grid before conversion into a runtime `Value`.
struct GridOutput {
    /// Dimensions of the grid.
    shape: Vec<usize>,
    /// Grid elements as `(real, imaginary)` pairs.
    data: Vec<(f64, f64)>,
}
1118
1119impl GridOutput {
1120 fn to_value(
1121 &self,
1122 class: PrototypeClass,
1123 residency: DevicePreference,
1124 ) -> crate::BuiltinResult<Value> {
1125 match class {
1126 PrototypeClass::Real => self.to_real_value(residency),
1127 PrototypeClass::Complex => self.to_complex_value(residency),
1128 }
1129 }
1130
1131 fn to_real_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1132 let mut real = Vec::with_capacity(self.data.len());
1133 for &(re, im) in &self.data {
1134 if im != 0.0 {
1135 return Err(builtin_error(
1136 "meshgrid: cannot represent complex values in a real output",
1137 ));
1138 }
1139 real.push(re);
1140 }
1141 let tensor = Tensor::new(real, self.shape.clone())
1142 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1143 match residency {
1144 DevicePreference::Host => Ok(tensor::tensor_into_value(tensor)),
1145 DevicePreference::Gpu => to_gpu_tensor_value(tensor),
1146 }
1147 }
1148
1149 fn to_complex_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1150 let tensor = ComplexTensor::new(self.data.clone(), self.shape.clone())
1151 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1152 match residency {
1153 DevicePreference::Host => Ok(complex_tensor_into_value(tensor)),
1154 DevicePreference::Gpu => {
1155 warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
1156 Ok(complex_tensor_into_value(tensor))
1157 }
1158 }
1159 }
1160}
1161
1162fn to_gpu_tensor_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1163 if let Some(provider) = runmat_accelerate_api::provider() {
1164 let view = HostTensorView {
1165 data: &tensor.data,
1166 shape: &tensor.shape,
1167 };
1168 match provider.upload(&view) {
1169 Ok(handle) => return Ok(Value::GpuTensor(handle)),
1170 Err(err) => {
1171 warn!("meshgrid: failed to upload tensor to GPU, returning host array: {err}")
1172 }
1173 }
1174 }
1175 Ok(tensor::tensor_into_value(tensor))
1176}
1177
1178fn tensor_to_complex_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1179 let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
1180 let complex = ComplexTensor::new(data, tensor.shape.clone())
1181 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1182 Ok(complex_tensor_into_value(complex))
1183}
1184
1185enum MeshgridOutput {
1186 Host(GridOutput),
1187 GpuReal(GpuTensorHandle),
1188}
1189
1190impl MeshgridOutput {
1191 async fn to_value(
1192 &self,
1193 class: PrototypeClass,
1194 residency: DevicePreference,
1195 ) -> crate::BuiltinResult<Value> {
1196 match self {
1197 MeshgridOutput::Host(host) => host.to_value(class, residency),
1198 MeshgridOutput::GpuReal(handle) => match (class, residency) {
1199 (PrototypeClass::Real, DevicePreference::Gpu) => {
1200 Ok(Value::GpuTensor(handle.clone()))
1201 }
1202 (PrototypeClass::Real, DevicePreference::Host) => {
1203 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1204 Ok(tensor::tensor_into_value(tensor))
1205 }
1206 (PrototypeClass::Complex, DevicePreference::Host) => {
1207 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1208 tensor_to_complex_value(tensor)
1209 }
1210 (PrototypeClass::Complex, DevicePreference::Gpu) => {
1211 warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
1212 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1213 tensor_to_complex_value(tensor)
1214 }
1215 },
1216 }
1217 }
1218}
1219
1220pub struct MeshgridEval {
1223 outputs: Vec<MeshgridOutput>,
1224 target_class: PrototypeClass,
1225 target_residency: DevicePreference,
1226}
1227
1228impl MeshgridEval {
1229 pub fn output_count(&self) -> usize {
1230 self.outputs.len()
1231 }
1232
1233 pub async fn first(&self) -> crate::BuiltinResult<Value> {
1234 self.outputs[0]
1235 .to_value(self.target_class, self.target_residency)
1236 .await
1237 }
1238
1239 pub async fn second(&self) -> crate::BuiltinResult<Value> {
1240 if self.outputs.len() < 2 {
1241 Err(builtin_error("meshgrid: second output unavailable"))
1242 } else {
1243 self.outputs[1]
1244 .to_value(self.target_class, self.target_residency)
1245 .await
1246 }
1247 }
1248
1249 pub async fn third(&self) -> crate::BuiltinResult<Value> {
1250 if self.outputs.len() < 3 {
1251 Err(builtin_error(
1252 "meshgrid: third output requested but no Z vector was supplied",
1253 ))
1254 } else {
1255 self.outputs[2]
1256 .to_value(self.target_class, self.target_residency)
1257 .await
1258 }
1259 }
1260}
1261
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;

    use runmat_accelerate_api::HostTensorView;

    /// Drives the async `evaluate` entry point to completion.
    fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
        block_on(super::evaluate(args))
    }

    /// Blocking wrapper around `MeshgridEval::first`.
    fn eval_first(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.first())
    }

    /// Blocking wrapper around `MeshgridEval::second`.
    fn eval_second(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.second())
    }

    /// Blocking wrapper around `MeshgridEval::third`.
    fn eval_third(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.third())
    }

    /// Builds a `rows x cols` tensor from `data`.
    fn tensor_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Tensor {
        let shape = vec![rows, cols];
        Tensor::new(data, shape).unwrap()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_single_input_duplicates_axis() {
        let axis = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
        let grids = evaluate(&[Value::Tensor(axis)]).expect("meshgrid");
        assert_eq!(grids.output_count(), 2);

        let x_grid = test_support::gather(eval_first(&grids).expect("X")).expect("host");
        assert_eq!(x_grid.shape, vec![3, 3]);
        let expected_x = vec![-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        assert_eq!(x_grid.data, expected_x);

        let y_grid = test_support::gather(eval_second(&grids).expect("Y")).expect("host");
        assert_eq!(y_grid.shape, vec![3, 3]);
        let expected_y = vec![-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0];
        assert_eq!(y_grid.data, expected_y);
    }

    #[test]
    fn meshgrid_type_infers_rank_from_axis_count() {
        let ctx = ResolveContext::new(Vec::new());

        let two_axes = meshgrid_type(&[Type::Num, Type::Num], &ctx);
        assert_eq!(
            two_axes,
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1)])
            }
        );

        let three_axes = meshgrid_type(&[Type::Num, Type::Num, Type::Num], &ctx);
        assert_eq!(
            three_axes,
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1), Some(1)])
            }
        );
    }

    #[test]
    fn meshgrid_type_uses_vector_lengths() {
        let ctx = ResolveContext::new(Vec::new());
        let x_type = Type::Tensor {
            shape: Some(vec![Some(1), Some(201)]),
        };
        let y_type = Type::Tensor {
            shape: Some(vec![Some(1), Some(101)]),
        };

        let resolved = meshgrid_type(&[x_type, y_type], &ctx);

        assert_eq!(
            resolved,
            Type::Tensor {
                shape: Some(vec![Some(101), Some(201)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_rectangular_inputs() {
        let x_axis = tensor_from_vec(vec![0.0, 0.5, 1.0], 1, 3);
        let y_axis = tensor_from_vec(vec![10.0, 20.0], 2, 1);
        let grids = evaluate(&[Value::Tensor(x_axis), Value::Tensor(y_axis)]).expect("meshgrid");
        assert_eq!(grids.output_count(), 2);

        let x_grid = test_support::gather(eval_first(&grids).expect("X")).expect("host");
        assert_eq!(x_grid.shape, vec![2, 3]);
        assert_eq!(x_grid.data, vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0]);

        let y_grid = test_support::gather(eval_second(&grids).expect("Y")).expect("host");
        assert_eq!(y_grid.shape, vec![2, 3]);
        assert_eq!(y_grid.data, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_three_inputs_volume() {
        let inputs = [
            Value::Tensor(tensor_from_vec(vec![1.0, 2.0], 1, 2)),
            Value::Tensor(tensor_from_vec(vec![5.0, 6.0, 7.0], 3, 1)),
            Value::Tensor(tensor_from_vec(vec![0.0, 1.0], 1, 2)),
        ];
        let grids = evaluate(&inputs).expect("meshgrid");
        assert_eq!(grids.output_count(), 3);

        let x_grid = test_support::gather(eval_first(&grids).expect("X")).expect("host");
        assert_eq!(x_grid.shape, vec![3, 2, 2]);
        let expected_x = vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        assert_eq!(x_grid.data, expected_x);

        let z_grid = test_support::gather(eval_third(&grids).expect("Z")).expect("host");
        assert_eq!(z_grid.shape, vec![3, 2, 2]);
        let expected_z = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        assert_eq!(z_grid.data, expected_z);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_keeps_gpu_residency() {
        test_support::with_test_provider(|provider| {
            let x_axis = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
            let y_axis = tensor_from_vec(vec![2.0, 4.0], 2, 1);

            // A device-resident prototype should pull the outputs onto the GPU.
            let prototype = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let prototype_handle = provider
                .upload(&HostTensorView {
                    data: &prototype.data,
                    shape: &prototype.shape,
                })
                .expect("upload");

            let args = [
                Value::Tensor(x_axis),
                Value::Tensor(y_axis),
                Value::from("like"),
                Value::GpuTensor(prototype_handle),
            ];
            let grids = evaluate(&args).expect("meshgrid");

            let x_value = eval_first(&grids).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let on_host = test_support::gather(x_value).expect("gather");
            assert_eq!(on_host.shape, vec![2, 3]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_gpu_inputs_roundtrip() {
        test_support::with_test_provider(|provider| {
            let upload = |host: &Tensor| {
                provider
                    .upload(&HostTensorView {
                        data: &host.data,
                        shape: &host.shape,
                    })
                    .expect("upload")
            };

            let x_axis = tensor_from_vec(vec![0.0, 0.5], 1, 2);
            let y_axis = tensor_from_vec(vec![1.0, 2.0], 2, 1);
            let x_device = upload(&x_axis);
            let y_device = upload(&y_axis);

            let grids = evaluate(&[Value::GpuTensor(x_device), Value::GpuTensor(y_device)])
                .expect("meshgrid");

            let x_value = eval_first(&grids).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let y_value = eval_second(&grids).expect("Y");
            assert!(matches!(y_value, Value::GpuTensor(_)));
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn meshgrid_wgpu_matches_cpu() {
        let provider = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .expect("wgpu provider");

        let x_axis = tensor_from_vec(vec![-1.0, 0.0, 1.0, 2.0], 1, 4);
        let y_axis = tensor_from_vec(vec![5.0, 6.0], 2, 1);

        // Reference result computed entirely on the host.
        let host_eval = evaluate(&[
            Value::Tensor(x_axis.clone()),
            Value::Tensor(y_axis.clone()),
        ])
        .expect("meshgrid cpu");
        let host_x =
            test_support::gather(eval_first(&host_eval).expect("X cpu")).expect("gather X cpu");
        let host_y =
            test_support::gather(eval_second(&host_eval).expect("Y cpu")).expect("gather Y cpu");

        // Same axes, uploaded to the device first.
        let x_device = provider
            .upload(&HostTensorView {
                data: &x_axis.data,
                shape: &x_axis.shape,
            })
            .expect("upload x");
        let y_device = provider
            .upload(&HostTensorView {
                data: &y_axis.data,
                shape: &y_axis.shape,
            })
            .expect("upload y");

        let device_eval = evaluate(&[Value::GpuTensor(x_device), Value::GpuTensor(y_device)])
            .expect("meshgrid gpu");
        let device_x_value = eval_first(&device_eval).expect("X gpu");
        let device_y_value = eval_second(&device_eval).expect("Y gpu");

        assert!(matches!(device_x_value, Value::GpuTensor(_)));
        assert!(matches!(device_y_value, Value::GpuTensor(_)));

        let device_x = test_support::gather(device_x_value).expect("gather X gpu");
        let device_y = test_support::gather(device_y_value).expect("gather Y gpu");

        assert_eq!(device_x.shape, host_x.shape);
        assert_eq!(device_x.data, host_x.data);
        assert_eq!(device_y.shape, host_y.shape);
        assert_eq!(device_y.data, host_y.data);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_complex_inputs_produce_complex_outputs() {
        let axis = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
        let grids = evaluate(&[Value::ComplexTensor(axis)]).expect("meshgrid");

        match eval_first(&grids).expect("X") {
            Value::Complex(_, _) => {}
            Value::ComplexTensor(ct) => assert_eq!(ct.shape, vec![2, 2]),
            other => panic!("expected complex output, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_host_prototype() {
        let args = [
            Value::Tensor(tensor_from_vec(vec![1.0, 2.0], 1, 2)),
            Value::from("like"),
            Value::Num(0.0),
        ];
        let grids = evaluate(&args).expect("meshgrid");

        let x_value = eval_first(&grids).expect("X");
        assert!(matches!(x_value, Value::Tensor(_) | Value::Num(_)));
    }
}