1use std::cmp::max;
4
5use log::warn;
6use runmat_accelerate_api::{GpuTensorHandle, HostTensorView};
7use runmat_builtins::{
8 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
9 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
10 ComplexTensor, ResolveContext, Tensor, Type, Value,
11};
12
13use crate::builtins::array::type_resolvers::size_vector_len;
14use runmat_macros::runtime_builtin;
15
16use crate::build_runtime_error;
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
19use crate::builtins::common::residency::{sequence_gpu_preference, SequenceIntent};
20use crate::builtins::common::spec::{
21 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
23};
24use crate::builtins::common::tensor;
25
/// GPU capability metadata for `meshgrid`, registered via `register_gpu_spec`.
///
/// Declares F32/F64 support, MATLAB broadcasting semantics, a custom `"meshgrid"`
/// provider hook, and that results are returned as new device handles.
///
/// NOTE(review): `notes` says grids are always built on the host and uploaded, but
/// `try_meshgrid_gpu_from_vector_axes` in this file composes them on-device with provider
/// `reshape`/`repmat` when every axis is GPU-resident — consider updating the text.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "meshgrid",
    op_kind: GpuOpKind::Custom("array_construct"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("meshgrid")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers may supply a dedicated meshgrid hook; until then the runtime builds grids on the host and uploads them when GPU residency is requested.",
};
41
42fn builtin_error(message: impl Into<String>) -> crate::RuntimeError {
43 build_runtime_error(message)
44 .with_builtin("meshgrid")
45 .build()
46}
47
/// Fusion metadata for `meshgrid`, registered via `register_fusion_spec`.
///
/// No elementwise or reduction template is provided (`elementwise`/`reduction` are
/// `None`), so the builtin is not fused with neighbouring operations.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "meshgrid",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "Meshgrid explicitly materialises dense coordinate arrays and therefore bypasses fusion.",
};
59
60fn meshgrid_type(args: &[Type], _context: &ResolveContext) -> Type {
61 if args.is_empty() {
62 return Type::Unknown;
63 }
64 let mut axis_count = args.len();
65 if axis_count >= 2 && matches!(args[axis_count - 2], Type::String) {
66 axis_count = axis_count.saturating_sub(2);
67 }
68 if axis_count == 0 {
69 return Type::Unknown;
70 }
71 let axis_args = &args[..axis_count];
72 let len_x = axis_args.get(0).and_then(size_vector_len);
73 let len_y = axis_args.get(1).and_then(size_vector_len).or(len_x);
74 let len_z = axis_args.get(2).and_then(size_vector_len);
75 let shape = if axis_count >= 3 {
76 vec![len_y, len_x, len_z]
77 } else {
78 vec![len_y, len_x]
79 };
80 Type::Tensor { shape: Some(shape) }
81}
82
/// Output descriptors for the two-output forms (`[X, Y] = meshgrid(...)`).
const MESHGRID_OUTPUT_XY: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along X-axis.",
    },
    BuiltinParamDescriptor {
        name: "Y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along Y-axis.",
    },
];

/// Output descriptors for the three-output forms (`[X, Y, Z] = meshgrid(...)`).
const MESHGRID_OUTPUT_XYZ: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along X-axis.",
    },
    BuiltinParamDescriptor {
        name: "Y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along Y-axis.",
    },
    BuiltinParamDescriptor {
        name: "Z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Grid coordinates along Z-axis.",
    },
];
123
/// Inputs for `meshgrid(x)`: a single axis reused for both X and Y.
const MESHGRID_SIG_X_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "x",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "X-axis vector.",
}];

/// Inputs for `meshgrid(x, y)`.
const MESHGRID_SIG_XY_INPUTS: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
];

/// Inputs for `meshgrid(x, y, z)`.
// NOTE(review): `z` is marked Optional although this signature's label lists it
// explicitly and `meshgrid(x, y)` has its own signature — confirm intent.
const MESHGRID_SIG_XYZ_INPUTS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Z-axis vector.",
    },
];

/// Inputs for `meshgrid(x, "like", prototype)`.
const MESHGRID_SIG_X_LIKE_INPUTS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];

/// Inputs for `meshgrid(x, y, "like", prototype)`.
const MESHGRID_SIG_XY_LIKE_INPUTS: [BuiltinParamDescriptor; 4] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];

/// Inputs for `meshgrid(x, y, z, "like", prototype)`.
// NOTE(review): as in MESHGRID_SIG_XYZ_INPUTS, `z` is marked Optional here — confirm intent.
const MESHGRID_SIG_XYZ_LIKE_INPUTS: [BuiltinParamDescriptor; 5] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Z-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];
265
/// Public call signatures for `meshgrid`, pairing each input table with its output table.
const MESHGRID_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x)",
        inputs: &MESHGRID_SIG_X_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, y)",
        inputs: &MESHGRID_SIG_XY_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y,Z] = meshgrid(x, y, z)",
        inputs: &MESHGRID_SIG_XYZ_INPUTS,
        outputs: &MESHGRID_OUTPUT_XYZ,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, \"like\", prototype)",
        inputs: &MESHGRID_SIG_X_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, y, \"like\", prototype)",
        inputs: &MESHGRID_SIG_XY_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y,Z] = meshgrid(x, y, z, \"like\", prototype)",
        inputs: &MESHGRID_SIG_XYZ_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XYZ,
    },
];
298
/// Documented error conditions for `meshgrid`.
// NOTE(review): several runtime messages in this file are not byte-identical to the
// entries below — e.g. `meshgrid_builtin` emits "meshgrid with two inputs supports at
// most two outputs", the option parser appends the offending keyword, and `parse` can also
// emit "meshgrid: 'like' must follow at least one input vector". Confirm whether this
// table should be kept in sync with those strings.
const MESHGRID_ERRORS: [BuiltinErrorDescriptor; 11] = [
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.MISSING_AXIS",
        identifier: None,
        when: "No axis vectors are provided.",
        message: "meshgrid: at least one input vector is required",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.TOO_MANY_AXES",
        identifier: None,
        when: "More than three axis vectors are provided.",
        message: "meshgrid: expected at most three input vectors",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.LIKE_EXPECTED_PROTOTYPE",
        identifier: None,
        when: "The 'like' keyword is provided without a prototype argument.",
        message: "meshgrid: expected prototype after 'like'",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.MULTIPLE_LIKE",
        identifier: None,
        when: "The 'like' keyword is provided multiple times.",
        message: "meshgrid: multiple 'like' specifications are not supported",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.LIKE_POSITION",
        identifier: None,
        when: "The 'like' keyword is in an invalid position or not final.",
        message: "meshgrid: 'like' must be the final argument",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.UNRECOGNIZED_OPTION",
        identifier: None,
        when: "A trailing option string is not recognized.",
        message: "meshgrid: unrecognised option",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.INVALID_AXIS_INPUT",
        identifier: None,
        when: "Axis inputs are non-numeric or non-vector shapes.",
        message: "meshgrid: input argument must be numeric vector data",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.INVALID_PROTOTYPE",
        identifier: None,
        when: "The 'like' prototype is unsupported.",
        message: "meshgrid: prototypes must be numeric arrays",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.OUTPUT_COUNT_EXCEEDED",
        identifier: None,
        when: "Requested outputs exceed available outputs for provided axes.",
        message:
            "meshgrid: supports at most two outputs for 2-axis inputs and three for 3-axis inputs",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.THIRD_OUTPUT_UNAVAILABLE",
        identifier: None,
        when: "A third output is requested without supplying a Z-axis vector.",
        message: "meshgrid: third output requested but no Z vector was supplied",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.COMPLEX_REAL_CONVERSION",
        identifier: None,
        when: "Complex axis values cannot be represented in requested real output class.",
        message: "meshgrid: cannot represent complex values in a real output",
    },
];
368
/// Descriptor tying together `meshgrid`'s signatures and documented errors.
///
/// The output mode is `ByRequestedOutputCount`: the number of returned grids follows the
/// caller's requested output count (see `meshgrid_builtin`).
pub const MESHGRID_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &MESHGRID_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &MESHGRID_ERRORS,
};
375
376#[runtime_builtin(
377 name = "meshgrid",
378 category = "array/creation",
379 summary = "Generate coordinate matrices for 2-D and 3-D grids.",
380 keywords = "meshgrid,grid,gpu,like,3d",
381 accel = "array_construct",
382 type_resolver(meshgrid_type),
383 descriptor(crate::builtins::array::creation::meshgrid::MESHGRID_DESCRIPTOR),
384 builtin_path = "crate::builtins::array::creation::meshgrid"
385)]
386async fn meshgrid_builtin(rest: Vec<Value>) -> crate::BuiltinResult<Value> {
387 let eval = evaluate(&rest).await?;
388 if let Some(out_count) = crate::output_count::current_output_count() {
389 if out_count == 0 {
390 return Ok(Value::OutputList(Vec::new()));
391 }
392 let available = eval.output_count();
393 if out_count > available {
394 let msg = if available == 2 {
395 "meshgrid with two inputs supports at most two outputs"
396 } else {
397 "meshgrid supports at most three outputs"
398 };
399 return Err(builtin_error(msg));
400 }
401 let mut outputs = Vec::with_capacity(out_count);
402 let first = eval.first().await?;
403 outputs.push(first);
404 if out_count >= 2 {
405 outputs.push(eval.second().await?);
406 }
407 if out_count >= 3 {
408 outputs.push(eval.third().await?);
409 }
410 return Ok(Value::OutputList(outputs));
411 }
412 eval.first().await
413}
414
415pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
417 let parsed = ParsedMeshgrid::parse(args).await?;
418 let (x_axis, y_axis, z_axis) = normalise_axes(&parsed.axes);
419
420 let require_complex = parsed.axes.iter().any(|axis| axis.is_complex);
421
422 let target_class = match &parsed.template {
423 OutputTemplate::Default => {
424 if require_complex {
425 PrototypeClass::Complex
426 } else {
427 PrototypeClass::Real
428 }
429 }
430 OutputTemplate::Like(spec) => {
431 if require_complex {
432 PrototypeClass::Complex
433 } else {
434 spec.class
435 }
436 }
437 };
438
439 let target_residency = match &parsed.template {
440 OutputTemplate::Default => {
441 if parsed.prefer_gpu {
442 DevicePreference::Gpu
443 } else {
444 DevicePreference::Host
445 }
446 }
447 OutputTemplate::Like(spec) => spec.residency,
448 };
449
450 let axes_all_real = !require_complex;
451 let mut outputs: Vec<MeshgridOutput> = Vec::new();
452
453 if axes_all_real
454 && matches!(target_class, PrototypeClass::Real)
455 && matches!(target_residency, DevicePreference::Gpu)
456 {
457 if let Some(gpu) = try_meshgrid_gpu_from_vector_axes(&x_axis, &y_axis, z_axis.as_ref())? {
458 outputs = gpu;
459 }
460 }
461
462 if outputs.is_empty() {
463 let x_host = axis_to_host_async(&x_axis).await?;
465 let y_host = axis_to_host_async(&y_axis).await?;
466 let z_host = match z_axis.as_ref() {
467 Some(axis) => Some(axis_to_host_async(axis).await?),
468 None => None,
469 };
470 outputs = build_outputs(&x_host, &y_host, z_host.as_ref())
471 .into_iter()
472 .map(MeshgridOutput::Host)
473 .collect();
474 }
475
476 Ok(MeshgridEval {
477 outputs,
478 target_class,
479 target_residency,
480 })
481}
482
483#[derive(Clone)]
484struct ParsedMeshgrid {
485 axes: Vec<AxisData>,
486 template: OutputTemplate,
487 prefer_gpu: bool,
488}
489
impl ParsedMeshgrid {
    /// Parses raw builtin arguments into axis data, an output template and a GPU preference.
    ///
    /// Accepted forms are one to three axis values, optionally followed by a final
    /// `'like', prototype` pair. Any other keyword argument is rejected.
    ///
    /// # Errors
    /// Returns a `meshgrid:` error when no axes are given, more than three axes are given,
    /// `'like'` is repeated, precedes every axis, lacks a prototype or is not final, an
    /// unknown option keyword appears, an axis cannot be converted by `axis_from_value`, or
    /// the prototype is rejected by `analyse_like_prototype`.
    async fn parse(args: &[Value]) -> crate::BuiltinResult<Self> {
        if args.is_empty() {
            return Err(builtin_error(
                "meshgrid: at least one input vector is required",
            ));
        }
        let mut axis_values: Vec<Value> = Vec::new();
        let mut like_proto: Option<Value> = None;
        let mut prefer_gpu = false;
        let mut idx = 0;
        while idx < args.len() {
            let value = args[idx].clone();
            // NOTE(review): keyword_of presumably recognises string-like values as option
            // keywords — confirm its exact matching rules in common::random_args.
            if let Some(keyword) = keyword_of(&value) {
                match keyword.as_str() {
                    "like" => {
                        // Defensive: the loop breaks right after the first 'like', so this
                        // branch is not expected to see an existing prototype.
                        if like_proto.is_some() {
                            return Err(builtin_error(
                                "meshgrid: multiple 'like' specifications are not supported",
                            ));
                        }
                        if axis_values.is_empty() {
                            return Err(builtin_error(
                                "meshgrid: 'like' must follow at least one input vector",
                            ));
                        }
                        let Some(proto) = args.get(idx + 1).cloned() else {
                            return Err(builtin_error("meshgrid: expected prototype after 'like'"));
                        };
                        like_proto = Some(proto);
                        // Skip the keyword and its prototype; nothing may follow them.
                        idx += 2;
                        if idx < args.len() {
                            return Err(builtin_error(
                                "meshgrid: 'like' must be the final argument",
                            ));
                        }
                        break;
                    }
                    other => {
                        return Err(builtin_error(format!(
                            "meshgrid: unrecognised option '{other}'"
                        )));
                    }
                }
            }

            // Any device-resident axis argument biases the default output residency to GPU.
            if let Value::GpuTensor(_) = value {
                prefer_gpu = true;
            }
            axis_values.push(value);
            idx += 1;
        }

        // Defensive: a leading 'like' already errors above, so at least one axis exists here.
        if axis_values.is_empty() {
            return Err(builtin_error(
                "meshgrid: at least one input vector is required",
            ));
        }
        if axis_values.len() > 3 {
            return Err(builtin_error(
                "meshgrid: expected at most three input vectors",
            ));
        }

        // Capacity of at least two: a single axis is later duplicated into X and Y.
        let mut axes = Vec::with_capacity(max(axis_values.len(), 2));
        for (i, value) in axis_values.into_iter().enumerate() {
            let mut consumed_gpu = false;
            let data = axis_from_value(value, i, &mut consumed_gpu).await?;
            if consumed_gpu {
                prefer_gpu = true;
            }
            axes.push(data);
        }

        // With no GPU input, ask the residency heuristic about the longest non-empty axis.
        if !prefer_gpu {
            if let Some(max_len) = axes.iter().map(|axis| axis.len).max() {
                if max_len > 0
                    && sequence_gpu_preference(max_len, SequenceIntent::MeshAxis, false).prefer_gpu
                {
                    prefer_gpu = true;
                }
            }
        }

        let template = if let Some(proto) = like_proto {
            OutputTemplate::Like(analyse_like_prototype(&proto)?)
        } else {
            OutputTemplate::Default
        };

        Ok(Self {
            axes,
            template,
            prefer_gpu,
        })
    }
}
587
/// How output class and residency are chosen.
#[derive(Clone)]
enum OutputTemplate {
    /// No `'like'` prototype: class follows the axes, residency follows the GPU preference.
    Default,
    /// A `'like'` prototype was supplied and analysed into a spec.
    Like(PrototypeSpec),
}
593
594#[derive(Clone)]
595struct PrototypeSpec {
596 residency: DevicePreference,
597 class: PrototypeClass,
598}
599
/// Numeric class of the produced grids.
#[derive(Clone, Copy, PartialEq, Eq)]
enum PrototypeClass {
    /// Real-valued output; complex samples are rejected when converting.
    Real,
    /// Complex-valued output.
    Complex,
}
605
/// Where the produced grids should live.
#[derive(Clone, Copy)]
enum DevicePreference {
    /// Host memory.
    Host,
    /// GPU memory, when a provider is available (upload failures fall back to host).
    Gpu,
}
611
612fn analyse_like_prototype(proto: &Value) -> crate::BuiltinResult<PrototypeSpec> {
613 match proto {
614 Value::GpuTensor(_) => Ok(PrototypeSpec {
615 residency: DevicePreference::Gpu,
616 class: PrototypeClass::Real,
617 }),
618 Value::ComplexTensor(_) | Value::Complex(_, _) => Ok(PrototypeSpec {
619 residency: DevicePreference::Host,
620 class: PrototypeClass::Complex,
621 }),
622 Value::Tensor(_)
623 | Value::SparseTensor(_)
624 | Value::Num(_)
625 | Value::Int(_)
626 | Value::Bool(_)
627 | Value::LogicalArray(_) => Ok(PrototypeSpec {
628 residency: DevicePreference::Host,
629 class: PrototypeClass::Real,
630 }),
631 Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(builtin_error(
632 "meshgrid: prototypes must be numeric or gpuArray values",
633 )),
634 Value::Cell(_)
635 | Value::Struct(_)
636 | Value::Object(_)
637 | Value::HandleObject(_)
638 | Value::Listener(_)
639 | Value::FunctionHandle(_)
640 | Value::ExternalFunctionHandle(_)
641 | Value::MethodFunctionHandle(_)
642 | Value::BoundFunctionHandle { .. }
643 | Value::Closure(_)
644 | Value::ClassRef(_)
645 | Value::MException(_)
646 | Value::OutputList(_) => Err(builtin_error("meshgrid: prototypes must be numeric arrays")),
647 }
648}
649
650#[derive(Clone)]
651struct AxisData {
652 values: Vec<(f64, f64)>,
653 len: usize,
654 is_complex: bool,
655 gpu_real: Option<GpuTensorHandle>,
656}
657
658async fn axis_from_value(
659 value: Value,
660 index: usize,
661 prefer_gpu: &mut bool,
662) -> crate::BuiltinResult<AxisData> {
663 match value {
664 Value::Tensor(tensor) => axis_from_tensor(tensor, index),
665 Value::LogicalArray(logical) => {
666 let tensor = tensor::logical_to_tensor(&logical)?;
667 axis_from_tensor(tensor, index)
668 }
669 Value::Num(n) => Ok(AxisData {
670 values: vec![(n, 0.0)],
671 len: 1,
672 is_complex: false,
673 gpu_real: None,
674 }),
675 Value::Int(i) => {
676 let val = i.to_f64();
677 Ok(AxisData {
678 values: vec![(val, 0.0)],
679 len: 1,
680 is_complex: false,
681 gpu_real: None,
682 })
683 }
684 Value::Bool(b) => Ok(AxisData {
685 values: vec![(if b { 1.0 } else { 0.0 }, 0.0)],
686 len: 1,
687 is_complex: false,
688 gpu_real: None,
689 }),
690 Value::Complex(re, im) => Ok(AxisData {
691 values: vec![(re, im)],
692 len: 1,
693 is_complex: im != 0.0,
694 gpu_real: None,
695 }),
696 Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, index),
697 Value::GpuTensor(handle) => {
698 if is_vector_shape(&handle.shape) {
701 *prefer_gpu = true;
702 return Ok(AxisData {
703 values: Vec::new(),
704 len: vector_len_from_shape(&handle.shape),
705 is_complex: false,
706 gpu_real: Some(handle),
707 });
708 }
709
710 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
712 if is_vector_shape(&tensor.shape) {
713 *prefer_gpu = true;
714 }
715 axis_from_tensor(tensor, index)
716 }
717 other => Err(builtin_error(format!(
718 "meshgrid: input argument {} must be numeric, got {other:?}",
719 index + 1
720 ))),
721 }
722}
723
724fn axis_from_tensor(tensor: Tensor, index: usize) -> crate::BuiltinResult<AxisData> {
725 if is_vector_shape(&tensor.shape) {
726 let mut values = Vec::with_capacity(tensor.data.len());
727 for &v in &tensor.data {
728 values.push((v, 0.0));
729 }
730 return Ok(AxisData {
731 len: values.len(),
732 values,
733 is_complex: false,
734 gpu_real: None,
735 });
736 }
737
738 if let Some(axis) = axis_from_meshgrid_matrix_real(&tensor, index)? {
744 return Ok(axis);
745 }
746
747 Err(builtin_error(format!(
748 "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
749 index + 1,
750 tensor.shape
751 )))
752}
753
754fn axis_from_complex_tensor(tensor: ComplexTensor, index: usize) -> crate::BuiltinResult<AxisData> {
755 if is_vector_shape(&tensor.shape) {
756 let is_complex = tensor
757 .data
758 .iter()
759 .any(|&(_, imag)| !imag.is_nan() && imag != 0.0);
760 return Ok(AxisData {
761 len: tensor.data.len(),
762 values: tensor.data,
763 is_complex,
764 gpu_real: None,
765 });
766 }
767
768 if let Some(axis) = axis_from_meshgrid_matrix_complex(&tensor, index)? {
769 return Ok(axis);
770 }
771
772 Err(builtin_error(format!(
773 "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
774 index + 1,
775 tensor.shape
776 )))
777}
778
779fn axis_from_meshgrid_matrix_real(
780 tensor: &Tensor,
781 index: usize,
782) -> crate::BuiltinResult<Option<AxisData>> {
783 let (rows, cols) = match tensor.shape.as_slice() {
784 [r, c] => (*r, *c),
785 _ => return Ok(None),
786 };
787 if rows <= 1 || cols <= 1 {
788 return Ok(None);
789 }
790
791 let expect_rows_constant = index == 0;
794
795 if expect_rows_constant {
796 if !matrix_rows_are_identical_real(tensor, rows, cols) {
797 return Ok(None);
798 }
799 let mut values = Vec::with_capacity(cols);
801 for col in 0..cols {
802 let idx = rows * col;
803 values.push((tensor.data[idx], 0.0));
804 }
805 return Ok(Some(AxisData {
806 len: values.len(),
807 values,
808 is_complex: false,
809 gpu_real: None,
810 }));
811 }
812
813 if !matrix_cols_are_identical_real(tensor, rows, cols) {
814 return Ok(None);
815 }
816 let mut values = Vec::with_capacity(rows);
818 for row in 0..rows {
819 values.push((tensor.data[row], 0.0));
820 }
821 Ok(Some(AxisData {
822 len: values.len(),
823 values,
824 is_complex: false,
825 gpu_real: None,
826 }))
827}
828
829fn axis_from_meshgrid_matrix_complex(
830 tensor: &ComplexTensor,
831 index: usize,
832) -> crate::BuiltinResult<Option<AxisData>> {
833 let (rows, cols) = match tensor.shape.as_slice() {
834 [r, c] => (*r, *c),
835 _ => return Ok(None),
836 };
837 if rows <= 1 || cols <= 1 {
838 return Ok(None);
839 }
840
841 let expect_rows_constant = index == 0;
842 if expect_rows_constant {
843 if !matrix_rows_are_identical_complex(tensor, rows, cols) {
844 return Ok(None);
845 }
846 let mut values = Vec::with_capacity(cols);
847 for col in 0..cols {
848 let idx = rows * col;
849 values.push(tensor.data[idx]);
850 }
851 let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
852 return Ok(Some(AxisData {
853 len: values.len(),
854 values,
855 is_complex,
856 gpu_real: None,
857 }));
858 }
859
860 if !matrix_cols_are_identical_complex(tensor, rows, cols) {
861 return Ok(None);
862 }
863 let mut values = Vec::with_capacity(rows);
864 for row in 0..rows {
865 values.push(tensor.data[row]);
866 }
867 let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
868 Ok(Some(AxisData {
869 len: values.len(),
870 values,
871 is_complex,
872 gpu_real: None,
873 }))
874}
875
876fn matrix_rows_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
877 for row in 1..rows {
878 for col in 0..cols {
879 let idx0 = rows * col;
880 let idx = row + rows * col;
881 if tensor.data[idx] != tensor.data[idx0] {
882 return false;
883 }
884 }
885 }
886 true
887}
888
889fn matrix_cols_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
890 for col in 1..cols {
891 for row in 0..rows {
892 let idx0 = row;
893 let idx = row + rows * col;
894 if tensor.data[idx] != tensor.data[idx0] {
895 return false;
896 }
897 }
898 }
899 true
900}
901
902fn matrix_rows_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
903 for row in 1..rows {
904 for col in 0..cols {
905 let idx0 = rows * col;
906 let idx = row + rows * col;
907 if tensor.data[idx] != tensor.data[idx0] {
908 return false;
909 }
910 }
911 }
912 true
913}
914
915fn matrix_cols_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
916 for col in 1..cols {
917 for row in 0..rows {
918 let idx0 = row;
919 let idx = row + rows * col;
920 if tensor.data[idx] != tensor.data[idx0] {
921 return false;
922 }
923 }
924 }
925 true
926}
927
/// Returns true when at most one dimension exceeds one (scalars, empty shapes and N-by-1 /
/// 1-by-N style vectors all qualify).
fn is_vector_shape(shape: &[usize]) -> bool {
    let long_dims = shape.iter().filter(|&&dim| dim > 1).count();
    long_dims < 2
}
940
/// Number of elements in a vector-shaped tensor (a shape accepted by `is_vector_shape`).
///
/// An empty shape denotes a scalar and yields `1`. Because at most one dimension exceeds
/// one, the element count is the product of the dimensions. Using the product rather than
/// the largest dimension keeps empty vectors such as `0x1` or `1x0` at length `0`; with the
/// maximum they would wrongly report length `1`.
fn vector_len_from_shape(shape: &[usize]) -> usize {
    shape.iter().product()
}
947
948async fn axis_to_host_async(axis: &AxisData) -> crate::BuiltinResult<AxisData> {
949 if axis.gpu_real.is_none() {
950 return Ok(axis.clone());
951 }
952 let handle = axis.gpu_real.as_ref().expect("checked gpu_real is_some");
953 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
954 axis_from_tensor(tensor, 0)
956}
957
958fn try_meshgrid_gpu_from_vector_axes(
959 x_axis: &AxisData,
960 y_axis: &AxisData,
961 z_axis: Option<&AxisData>,
962) -> crate::BuiltinResult<Option<Vec<MeshgridOutput>>> {
963 let Some(x_handle) = x_axis.gpu_real.as_ref() else {
964 return Ok(None);
965 };
966 let Some(y_handle) = y_axis.gpu_real.as_ref() else {
967 return Ok(None);
968 };
969
970 let z_handle = match z_axis {
971 Some(axis) => match axis.gpu_real.as_ref() {
972 Some(h) => Some(h),
973 None => return Ok(None),
974 },
975 None => None,
976 };
977
978 let Some(provider) = runmat_accelerate_api::provider_for_handle(x_handle) else {
979 return Ok(None);
980 };
981 if runmat_accelerate_api::provider_for_handle(y_handle).is_none() {
982 return Ok(None);
983 }
984 if let Some(z) = z_handle {
985 if runmat_accelerate_api::provider_for_handle(z).is_none() {
986 return Ok(None);
987 }
988 }
989
990 let nx = x_axis.len;
991 let ny = y_axis.len;
992 let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
993
994 let x_row = provider
996 .reshape(x_handle, &[1, nx])
997 .map_err(|e| builtin_error(format!("meshgrid: reshape X failed: {e}")))?;
998 let y_col = provider
999 .reshape(y_handle, &[ny, 1])
1000 .map_err(|e| builtin_error(format!("meshgrid: reshape Y failed: {e}")))?;
1001
1002 let mut outputs = Vec::with_capacity(if z_handle.is_some() { 3 } else { 2 });
1003 if let Some(z) = z_handle {
1004 let x_base = provider
1005 .reshape(&x_row, &[1, nx, 1])
1006 .map_err(|e| builtin_error(format!("meshgrid: reshape X(3d) failed: {e}")))?;
1007 let y_base = provider
1008 .reshape(&y_col, &[ny, 1, 1])
1009 .map_err(|e| builtin_error(format!("meshgrid: reshape Y(3d) failed: {e}")))?;
1010
1011 let x_grid = provider
1012 .repmat(&x_base, &[ny, 1, nz])
1013 .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
1014 let y_grid = provider
1015 .repmat(&y_base, &[1, nx, nz])
1016 .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
1017
1018 outputs.push(MeshgridOutput::GpuReal(x_grid));
1019 outputs.push(MeshgridOutput::GpuReal(y_grid));
1020 let z_axis_row = provider
1021 .reshape(z, &[1, nz])
1022 .map_err(|e| builtin_error(format!("meshgrid: reshape Z failed: {e}")))?;
1023 let z_base = provider
1024 .reshape(&z_axis_row, &[1, 1, nz])
1025 .map_err(|e| builtin_error(format!("meshgrid: reshape Z(3d) failed: {e}")))?;
1026 let z_grid = provider
1027 .repmat(&z_base, &[ny, nx, 1])
1028 .map_err(|e| builtin_error(format!("meshgrid: repmat Z failed: {e}")))?;
1029 outputs.push(MeshgridOutput::GpuReal(z_grid));
1030 } else {
1031 let x_grid = provider
1032 .repmat(&x_row, &[ny, 1])
1033 .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
1034 let y_grid = provider
1035 .repmat(&y_col, &[1, nx])
1036 .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
1037 outputs.push(MeshgridOutput::GpuReal(x_grid));
1038 outputs.push(MeshgridOutput::GpuReal(y_grid));
1039 }
1040
1041 Ok(Some(outputs))
1042}
1043
1044fn normalise_axes(axes: &[AxisData]) -> (AxisData, AxisData, Option<AxisData>) {
1045 match axes.len() {
1046 1 => {
1047 let x = axes[0].clone();
1048 (x.clone(), x, None)
1049 }
1050 2 => {
1051 let x = axes[0].clone();
1052 let y = axes[1].clone();
1053 (x, y, None)
1054 }
1055 3 => {
1056 let x = axes[0].clone();
1057 let y = axes[1].clone();
1058 let z = axes[2].clone();
1059 (x, y, Some(z))
1060 }
1061 _ => unreachable!(),
1062 }
1063}
1064
1065fn build_outputs(
1066 x_axis: &AxisData,
1067 y_axis: &AxisData,
1068 z_axis: Option<&AxisData>,
1069) -> Vec<GridOutput> {
1070 let nx = x_axis.len;
1071 let ny = y_axis.len;
1072 let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
1073 let total = nx * ny * nz;
1074 let mut x_data = Vec::with_capacity(total);
1075 let mut y_data = Vec::with_capacity(total);
1076 let mut z_data = z_axis.map(|_| Vec::with_capacity(total));
1077
1078 for k in 0..nz {
1079 let z_value = z_axis.map(|axis| axis.values[k]);
1080 for col in 0..nx {
1081 let x_value = x_axis.values[col];
1082 for row in 0..ny {
1083 x_data.push(x_value);
1084 y_data.push(y_axis.values[row]);
1085 if let Some(ref mut z_vec) = z_data {
1086 z_vec.push(z_value.unwrap());
1087 }
1088 }
1089 }
1090 }
1091
1092 let mut outputs = Vec::new();
1093 let base_shape = if nz == 1 {
1094 vec![ny, nx]
1095 } else {
1096 vec![ny, nx, nz]
1097 };
1098 outputs.push(GridOutput {
1099 shape: base_shape.clone(),
1100 data: x_data,
1101 });
1102 outputs.push(GridOutput {
1103 shape: base_shape.clone(),
1104 data: y_data,
1105 });
1106 if let Some(z_vec) = z_data {
1107 outputs.push(GridOutput {
1108 shape: base_shape,
1109 data: z_vec,
1110 });
1111 }
1112 outputs
1113}
1114
/// A host-side grid: `(real, imaginary)` samples together with their tensor shape.
struct GridOutput {
    /// Samples in the order produced by `build_outputs` (rows vary fastest).
    data: Vec<(f64, f64)>,
    /// Tensor dimensions, e.g. `[ny, nx]` or `[ny, nx, nz]`.
    shape: Vec<usize>,
}
1119
1120impl GridOutput {
1121 fn to_value(
1122 &self,
1123 class: PrototypeClass,
1124 residency: DevicePreference,
1125 ) -> crate::BuiltinResult<Value> {
1126 match class {
1127 PrototypeClass::Real => self.to_real_value(residency),
1128 PrototypeClass::Complex => self.to_complex_value(residency),
1129 }
1130 }
1131
1132 fn to_real_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1133 let mut real = Vec::with_capacity(self.data.len());
1134 for &(re, im) in &self.data {
1135 if im != 0.0 {
1136 return Err(builtin_error(
1137 "meshgrid: cannot represent complex values in a real output",
1138 ));
1139 }
1140 real.push(re);
1141 }
1142 let tensor = Tensor::new(real, self.shape.clone())
1143 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1144 match residency {
1145 DevicePreference::Host => Ok(tensor::tensor_into_value(tensor)),
1146 DevicePreference::Gpu => to_gpu_tensor_value(tensor),
1147 }
1148 }
1149
1150 fn to_complex_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1151 let tensor = ComplexTensor::new(self.data.clone(), self.shape.clone())
1152 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1153 match residency {
1154 DevicePreference::Host => Ok(complex_tensor_into_value(tensor)),
1155 DevicePreference::Gpu => {
1156 warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
1157 Ok(complex_tensor_into_value(tensor))
1158 }
1159 }
1160 }
1161}
1162
1163fn to_gpu_tensor_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1164 if let Some(provider) = runmat_accelerate_api::provider() {
1165 let view = HostTensorView {
1166 data: &tensor.data,
1167 shape: &tensor.shape,
1168 };
1169 match provider.upload(&view) {
1170 Ok(handle) => return Ok(Value::GpuTensor(handle)),
1171 Err(err) => {
1172 warn!("meshgrid: failed to upload tensor to GPU, returning host array: {err}")
1173 }
1174 }
1175 }
1176 Ok(tensor::tensor_into_value(tensor))
1177}
1178
1179fn tensor_to_complex_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1180 let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
1181 let complex = ComplexTensor::new(data, tensor.shape.clone())
1182 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1183 Ok(complex_tensor_into_value(complex))
1184}
1185
1186enum MeshgridOutput {
1187 Host(GridOutput),
1188 GpuReal(GpuTensorHandle),
1189}
1190
1191impl MeshgridOutput {
1192 async fn to_value(
1193 &self,
1194 class: PrototypeClass,
1195 residency: DevicePreference,
1196 ) -> crate::BuiltinResult<Value> {
1197 match self {
1198 MeshgridOutput::Host(host) => host.to_value(class, residency),
1199 MeshgridOutput::GpuReal(handle) => match (class, residency) {
1200 (PrototypeClass::Real, DevicePreference::Gpu) => {
1201 Ok(Value::GpuTensor(handle.clone()))
1202 }
1203 (PrototypeClass::Real, DevicePreference::Host) => {
1204 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1205 Ok(tensor::tensor_into_value(tensor))
1206 }
1207 (PrototypeClass::Complex, DevicePreference::Host) => {
1208 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1209 tensor_to_complex_value(tensor)
1210 }
1211 (PrototypeClass::Complex, DevicePreference::Gpu) => {
1212 warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
1213 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1214 tensor_to_complex_value(tensor)
1215 }
1216 },
1217 }
1218 }
1219}
1220
1221pub struct MeshgridEval {
1224 outputs: Vec<MeshgridOutput>,
1225 target_class: PrototypeClass,
1226 target_residency: DevicePreference,
1227}
1228
1229impl MeshgridEval {
1230 pub fn output_count(&self) -> usize {
1231 self.outputs.len()
1232 }
1233
1234 pub async fn first(&self) -> crate::BuiltinResult<Value> {
1235 self.outputs[0]
1236 .to_value(self.target_class, self.target_residency)
1237 .await
1238 }
1239
1240 pub async fn second(&self) -> crate::BuiltinResult<Value> {
1241 if self.outputs.len() < 2 {
1242 Err(builtin_error("meshgrid: second output unavailable"))
1243 } else {
1244 self.outputs[1]
1245 .to_value(self.target_class, self.target_residency)
1246 .await
1247 }
1248 }
1249
1250 pub async fn third(&self) -> crate::BuiltinResult<Value> {
1251 if self.outputs.len() < 3 {
1252 Err(builtin_error(
1253 "meshgrid: third output requested but no Z vector was supplied",
1254 ))
1255 } else {
1256 self.outputs[2]
1257 .to_value(self.target_class, self.target_residency)
1258 .await
1259 }
1260 }
1261}
1262
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;

    use runmat_accelerate_api::HostTensorView;

    /// Synchronously drives the async `evaluate` entry point.
    fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
        block_on(super::evaluate(args))
    }

    fn eval_first(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.first())
    }

    fn eval_second(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.second())
    }

    fn eval_third(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.third())
    }

    /// Builds a `rows x cols` real tensor from column-major `data`.
    fn tensor_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Tensor {
        Tensor::new(data, vec![rows, cols]).unwrap()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_single_input_duplicates_axis() {
        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
        let eval = evaluate(&[Value::Tensor(x)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![3, 3]);
        assert_eq!(
            x_out.data,
            vec![-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
        );
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![3, 3]);
        assert_eq!(
            y_out.data,
            vec![-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0]
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_type_infers_rank_from_axis_count() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            meshgrid_type(&[Type::Num, Type::Num], &ctx),
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1)])
            }
        );
        assert_eq!(
            meshgrid_type(&[Type::Num, Type::Num, Type::Num], &ctx),
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1), Some(1)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_type_uses_vector_lengths() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            meshgrid_type(
                &[
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(201)]),
                    },
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(101)]),
                    },
                ],
                &ctx,
            ),
            Type::Tensor {
                shape: Some(vec![Some(101), Some(201)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_rectangular_inputs() {
        let x = tensor_from_vec(vec![0.0, 0.5, 1.0], 1, 3);
        let y = tensor_from_vec(vec![10.0, 20.0], 2, 1);
        let eval = evaluate(&[Value::Tensor(x), Value::Tensor(y)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![2, 3]);
        assert_eq!(x_out.data, vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0]);
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![2, 3]);
        assert_eq!(y_out.data, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_three_inputs_volume() {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let y = tensor_from_vec(vec![5.0, 6.0, 7.0], 3, 1);
        let z = tensor_from_vec(vec![0.0, 1.0], 1, 2);
        let eval =
            evaluate(&[Value::Tensor(x), Value::Tensor(y), Value::Tensor(z)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 3);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![3, 2, 2]);
        assert_eq!(
            x_out.data,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
        );
        // Y varies down the rows and repeats across columns and pages.
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![3, 2, 2]);
        assert_eq!(
            y_out.data,
            vec![5.0, 6.0, 7.0, 5.0, 6.0, 7.0, 5.0, 6.0, 7.0, 5.0, 6.0, 7.0]
        );
        let z_out = test_support::gather(eval_third(&eval).expect("Z")).expect("host");
        assert_eq!(z_out.shape, vec![3, 2, 2]);
        assert_eq!(
            z_out.data,
            vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_keeps_gpu_residency() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
            let y = tensor_from_vec(vec![2.0, 4.0], 2, 1);
            let proto = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let proto_view = HostTensorView {
                data: &proto.data,
                shape: &proto.shape,
            };
            let proto_handle = provider.upload(&proto_view).expect("upload");
            let eval = evaluate(&[
                Value::Tensor(x),
                Value::Tensor(y),
                Value::from("like"),
                Value::GpuTensor(proto_handle),
            ])
            .expect("meshgrid");
            let x_value = eval_first(&eval).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let gathered = test_support::gather(x_value).expect("gather");
            assert_eq!(gathered.shape, vec![2, 3]);
            assert_eq!(gathered.data, vec![-1.0, -1.0, 0.0, 0.0, 1.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_gpu_inputs_roundtrip() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![0.0, 0.5], 1, 2);
            let y = tensor_from_vec(vec![1.0, 2.0], 2, 1);
            let x_view = HostTensorView {
                data: &x.data,
                shape: &x.shape,
            };
            let y_view = HostTensorView {
                data: &y.data,
                shape: &y.shape,
            };
            let x_handle = provider.upload(&x_view).expect("upload");
            let y_handle = provider.upload(&y_view).expect("upload");
            let eval = evaluate(&[Value::GpuTensor(x_handle), Value::GpuTensor(y_handle)])
                .expect("meshgrid");
            let x_value = eval_first(&eval).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let y_value = eval_second(&eval).expect("Y");
            assert!(matches!(y_value, Value::GpuTensor(_)));
            // Residency alone is not enough: the device grids must hold the right values.
            let gathered_x = test_support::gather(x_value).expect("gather X");
            assert_eq!(gathered_x.shape, vec![2, 2]);
            assert_eq!(gathered_x.data, vec![0.0, 0.0, 0.5, 0.5]);
            let gathered_y = test_support::gather(y_value).expect("gather Y");
            assert_eq!(gathered_y.shape, vec![2, 2]);
            assert_eq!(gathered_y.data, vec![1.0, 2.0, 1.0, 2.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn meshgrid_wgpu_matches_cpu() {
        let provider = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .expect("wgpu provider");

        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0, 2.0], 1, 4);
        let y = tensor_from_vec(vec![5.0, 6.0], 2, 1);

        let cpu_eval =
            evaluate(&[Value::Tensor(x.clone()), Value::Tensor(y.clone())]).expect("meshgrid cpu");
        let cpu_x =
            test_support::gather(eval_first(&cpu_eval).expect("X cpu")).expect("gather X cpu");
        let cpu_y =
            test_support::gather(eval_second(&cpu_eval).expect("Y cpu")).expect("gather Y cpu");

        let x_view = HostTensorView {
            data: &x.data,
            shape: &x.shape,
        };
        let y_view = HostTensorView {
            data: &y.data,
            shape: &y.shape,
        };
        let x_gpu = provider.upload(&x_view).expect("upload x");
        let y_gpu = provider.upload(&y_view).expect("upload y");

        let gpu_eval =
            evaluate(&[Value::GpuTensor(x_gpu), Value::GpuTensor(y_gpu)]).expect("meshgrid gpu");
        let gpu_x_value = eval_first(&gpu_eval).expect("X gpu");
        let gpu_y_value = eval_second(&gpu_eval).expect("Y gpu");

        assert!(matches!(gpu_x_value, Value::GpuTensor(_)));
        assert!(matches!(gpu_y_value, Value::GpuTensor(_)));

        let gathered_x = test_support::gather(gpu_x_value).expect("gather X gpu");
        let gathered_y = test_support::gather(gpu_y_value).expect("gather Y gpu");

        assert_eq!(gathered_x.shape, cpu_x.shape);
        assert_eq!(gathered_x.data, cpu_x.data);
        assert_eq!(gathered_y.shape, cpu_y.shape);
        assert_eq!(gathered_y.data, cpu_y.data);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_complex_inputs_produce_complex_outputs() {
        let complex = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
        let eval = evaluate(&[Value::ComplexTensor(complex)]).expect("meshgrid");
        // A 2x2 grid can never collapse to a scalar, so only a complex tensor is acceptable.
        match eval_first(&eval).expect("X") {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 2]);
                assert_eq!(
                    ct.data,
                    vec![(1.0, 1.0), (1.0, 1.0), (2.0, -1.0), (2.0, -1.0)]
                );
            }
            other => panic!("expected complex X output, got {other:?}"),
        }
        match eval_second(&eval).expect("Y") {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 2]);
                assert_eq!(
                    ct.data,
                    vec![(1.0, 1.0), (2.0, -1.0), (1.0, 1.0), (2.0, -1.0)]
                );
            }
            other => panic!("expected complex Y output, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_host_prototype() {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let eval =
            evaluate(&[Value::Tensor(x), Value::from("like"), Value::Num(0.0)]).expect("meshgrid");
        let x_out = eval_first(&eval).expect("X");
        assert!(matches!(x_out, Value::Tensor(_) | Value::Num(_)));
        let gathered = test_support::gather(x_out).expect("host");
        assert_eq!(gathered.shape, vec![2, 2]);
        assert_eq!(gathered.data, vec![1.0, 1.0, 2.0, 2.0]);
    }
}