1use std::cmp::max;
4
5use log::warn;
6use runmat_accelerate_api::{GpuTensorHandle, HostTensorView};
7use runmat_builtins::{
8 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
9 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
10 ComplexTensor, ResolveContext, Tensor, Type, Value,
11};
12
13use crate::builtins::array::type_resolvers::size_vector_len;
14use runmat_macros::runtime_builtin;
15
16use crate::build_runtime_error;
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
19use crate::builtins::common::residency::{sequence_gpu_preference, SequenceIntent};
20use crate::builtins::common::spec::{
21 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
23};
24use crate::builtins::common::tensor;
25
// GPU execution metadata for the `meshgrid` builtin, registered under the builtin path named
// in the attribute below. The operation is declared as a custom "array_construct" kind with a
// custom "meshgrid" provider hook; results are flagged as new handles. See `notes` for the
// documented host-fallback behaviour.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "meshgrid",
    op_kind: GpuOpKind::Custom("array_construct"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("meshgrid")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers may supply a dedicated meshgrid hook; until then the runtime builds grids on the host and uploads them when GPU residency is requested.",
};
41
42fn builtin_error(message: impl Into<String>) -> crate::RuntimeError {
43 build_runtime_error(message)
44 .with_builtin("meshgrid")
45 .build()
46}
47
// Fusion metadata for the `meshgrid` builtin. No elementwise or reduction kernels are
// declared (`elementwise` and `reduction` are both `None`); `notes` states that meshgrid
// bypasses fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "meshgrid",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "Meshgrid explicitly materialises dense coordinate arrays and therefore bypasses fusion.",
};
59
60fn meshgrid_type(args: &[Type], _context: &ResolveContext) -> Type {
61 if args.is_empty() {
62 return Type::Unknown;
63 }
64 let mut axis_count = args.len();
65 if axis_count >= 2 && matches!(args[axis_count - 2], Type::String) {
66 axis_count = axis_count.saturating_sub(2);
67 }
68 if axis_count == 0 {
69 return Type::Unknown;
70 }
71 let axis_args = &args[..axis_count];
72 let len_x = axis_args.get(0).and_then(size_vector_len);
73 let len_y = axis_args.get(1).and_then(size_vector_len).or(len_x);
74 let len_z = axis_args.get(2).and_then(size_vector_len);
75 let shape = if axis_count >= 3 {
76 vec![len_y, len_x, len_z]
77 } else {
78 vec![len_y, len_x]
79 };
80 Type::Tensor { shape: Some(shape) }
81}
82
/// Output descriptors for the two-output forms: the `X` and `Y` coordinate grids.
const MESHGRID_OUTPUT_XY: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along X-axis.",
    },
    BuiltinParamDescriptor {
        name: "Y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along Y-axis.",
    },
];
99
/// Output descriptors for the three-output forms: `X`, `Y`, and an optional `Z` grid.
const MESHGRID_OUTPUT_XYZ: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along X-axis.",
    },
    BuiltinParamDescriptor {
        name: "Y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Grid coordinates along Y-axis.",
    },
    BuiltinParamDescriptor {
        name: "Z",
        ty: BuiltinParamType::NumericArray,
        // The third output is marked optional in the descriptor.
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Grid coordinates along Z-axis.",
    },
];
123
/// Input descriptors for `meshgrid(x)`.
const MESHGRID_SIG_X_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "x",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "X-axis vector.",
}];
131
/// Input descriptors for `meshgrid(x, y)`.
const MESHGRID_SIG_XY_INPUTS: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
];
148
/// Input descriptors for `meshgrid(x, y, z)`; the `z` entry is marked optional.
const MESHGRID_SIG_XYZ_INPUTS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Z-axis vector.",
    },
];
172
/// Input descriptors for `meshgrid(x, "like", prototype)`.
const MESHGRID_SIG_X_LIKE_INPUTS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        // The default is the quoted keyword literal itself.
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];
196
/// Input descriptors for `meshgrid(x, y, "like", prototype)`.
const MESHGRID_SIG_XY_LIKE_INPUTS: [BuiltinParamDescriptor; 4] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        // The default is the quoted keyword literal itself.
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];
227
/// Input descriptors for `meshgrid(x, y, z, "like", prototype)`; the `z` entry is marked
/// optional.
const MESHGRID_SIG_XYZ_LIKE_INPUTS: [BuiltinParamDescriptor; 5] = [
    BuiltinParamDescriptor {
        name: "x",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "X-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "y",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Y-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "z",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Z-axis vector.",
    },
    BuiltinParamDescriptor {
        name: "like_kw",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        // The default is the quoted keyword literal itself.
        default: Some("\"like\""),
        description: "Like keyword.",
    },
    BuiltinParamDescriptor {
        name: "prototype",
        ty: BuiltinParamType::LikePrototype,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Prototype controlling class/device residency.",
    },
];
265
/// The six documented call forms of `meshgrid`: one, two, or three axis vectors, each with
/// and without a trailing `"like", prototype` pair. Each entry pairs a display label with its
/// input and output descriptor tables.
const MESHGRID_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x)",
        inputs: &MESHGRID_SIG_X_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, y)",
        inputs: &MESHGRID_SIG_XY_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y,Z] = meshgrid(x, y, z)",
        inputs: &MESHGRID_SIG_XYZ_INPUTS,
        outputs: &MESHGRID_OUTPUT_XYZ,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, \"like\", prototype)",
        inputs: &MESHGRID_SIG_X_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y] = meshgrid(x, y, \"like\", prototype)",
        inputs: &MESHGRID_SIG_XY_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XY,
    },
    BuiltinSignatureDescriptor {
        label: "[X,Y,Z] = meshgrid(x, y, z, \"like\", prototype)",
        inputs: &MESHGRID_SIG_XYZ_LIKE_INPUTS,
        outputs: &MESHGRID_OUTPUT_XYZ,
    },
];
298
/// Catalogue of documented `meshgrid` error conditions, each with a stable code, the
/// triggering condition, and a representative message.
///
/// NOTE(review): some messages actually raised elsewhere in this file are worded differently
/// from the entries here (for example the output-count and prototype messages) — confirm
/// whether this table is meant to mirror the runtime strings exactly.
const MESHGRID_ERRORS: [BuiltinErrorDescriptor; 11] = [
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.MISSING_AXIS",
        identifier: None,
        when: "No axis vectors are provided.",
        message: "meshgrid: at least one input vector is required",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.TOO_MANY_AXES",
        identifier: None,
        when: "More than three axis vectors are provided.",
        message: "meshgrid: expected at most three input vectors",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.LIKE_EXPECTED_PROTOTYPE",
        identifier: None,
        when: "The 'like' keyword is provided without a prototype argument.",
        message: "meshgrid: expected prototype after 'like'",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.MULTIPLE_LIKE",
        identifier: None,
        when: "The 'like' keyword is provided multiple times.",
        message: "meshgrid: multiple 'like' specifications are not supported",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.LIKE_POSITION",
        identifier: None,
        when: "The 'like' keyword is in an invalid position or not final.",
        message: "meshgrid: 'like' must be the final argument",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.UNRECOGNIZED_OPTION",
        identifier: None,
        when: "A trailing option string is not recognized.",
        message: "meshgrid: unrecognised option",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.INVALID_AXIS_INPUT",
        identifier: None,
        when: "Axis inputs are non-numeric or non-vector shapes.",
        message: "meshgrid: input argument must be numeric vector data",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.INVALID_PROTOTYPE",
        identifier: None,
        when: "The 'like' prototype is unsupported.",
        message: "meshgrid: prototypes must be numeric arrays",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.OUTPUT_COUNT_EXCEEDED",
        identifier: None,
        when: "Requested outputs exceed available outputs for provided axes.",
        message:
            "meshgrid: supports at most two outputs for 2-axis inputs and three for 3-axis inputs",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.THIRD_OUTPUT_UNAVAILABLE",
        identifier: None,
        when: "A third output is requested without supplying a Z-axis vector.",
        message: "meshgrid: third output requested but no Z vector was supplied",
    },
    BuiltinErrorDescriptor {
        code: "RM.MESHGRID.COMPLEX_REAL_CONVERSION",
        identifier: None,
        when: "Complex axis values cannot be represented in requested real output class.",
        message: "meshgrid: cannot represent complex values in a real output",
    },
];
368
/// Public descriptor for `meshgrid`, bundling the signature table and error catalogue.
/// The output mode is keyed on the requested output count.
pub const MESHGRID_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &MESHGRID_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &MESHGRID_ERRORS,
};
375
376#[runtime_builtin(
377 name = "meshgrid",
378 category = "array/creation",
379 summary = "Generate coordinate matrices for 2-D and 3-D grids.",
380 keywords = "meshgrid,grid,gpu,like,3d",
381 accel = "array_construct",
382 type_resolver(meshgrid_type),
383 descriptor(crate::builtins::array::creation::meshgrid::MESHGRID_DESCRIPTOR),
384 builtin_path = "crate::builtins::array::creation::meshgrid"
385)]
386async fn meshgrid_builtin(rest: Vec<Value>) -> crate::BuiltinResult<Value> {
387 let eval = evaluate(&rest).await?;
388 if let Some(out_count) = crate::output_count::current_output_count() {
389 if out_count == 0 {
390 return Ok(Value::OutputList(Vec::new()));
391 }
392 let available = eval.output_count();
393 if out_count > available {
394 let msg = if available == 2 {
395 "meshgrid with two inputs supports at most two outputs"
396 } else {
397 "meshgrid supports at most three outputs"
398 };
399 return Err(builtin_error(msg));
400 }
401 let mut outputs = Vec::with_capacity(out_count);
402 let first = eval.first().await?;
403 outputs.push(first);
404 if out_count >= 2 {
405 outputs.push(eval.second().await?);
406 }
407 if out_count >= 3 {
408 outputs.push(eval.third().await?);
409 }
410 return Ok(Value::OutputList(outputs));
411 }
412 eval.first().await
413}
414
415pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
417 let parsed = ParsedMeshgrid::parse(args).await?;
418 let (x_axis, y_axis, z_axis) = normalise_axes(&parsed.axes);
419
420 let require_complex = parsed.axes.iter().any(|axis| axis.is_complex);
421
422 let target_class = match &parsed.template {
423 OutputTemplate::Default => {
424 if require_complex {
425 PrototypeClass::Complex
426 } else {
427 PrototypeClass::Real
428 }
429 }
430 OutputTemplate::Like(spec) => {
431 if require_complex {
432 PrototypeClass::Complex
433 } else {
434 spec.class
435 }
436 }
437 };
438
439 let target_residency = match &parsed.template {
440 OutputTemplate::Default => {
441 if parsed.prefer_gpu {
442 DevicePreference::Gpu
443 } else {
444 DevicePreference::Host
445 }
446 }
447 OutputTemplate::Like(spec) => spec.residency,
448 };
449
450 let axes_all_real = !require_complex;
451 let mut outputs: Vec<MeshgridOutput> = Vec::new();
452
453 if axes_all_real
454 && matches!(target_class, PrototypeClass::Real)
455 && matches!(target_residency, DevicePreference::Gpu)
456 {
457 if let Some(gpu) = try_meshgrid_gpu_from_vector_axes(&x_axis, &y_axis, z_axis.as_ref())? {
458 outputs = gpu;
459 }
460 }
461
462 if outputs.is_empty() {
463 let x_host = axis_to_host_async(&x_axis).await?;
465 let y_host = axis_to_host_async(&y_axis).await?;
466 let z_host = match z_axis.as_ref() {
467 Some(axis) => Some(axis_to_host_async(axis).await?),
468 None => None,
469 };
470 outputs = build_outputs(&x_host, &y_host, z_host.as_ref())
471 .into_iter()
472 .map(MeshgridOutput::Host)
473 .collect();
474 }
475
476 Ok(MeshgridEval {
477 outputs,
478 target_class,
479 target_residency,
480 })
481}
482
/// Result of parsing the raw `meshgrid` argument list.
#[derive(Clone)]
struct ParsedMeshgrid {
    /// Axis vectors in argument order (between one and three entries after parsing).
    axes: Vec<AxisData>,
    /// Output class/residency template: default, or derived from a `'like'` prototype.
    template: OutputTemplate,
    /// Whether the inputs (or the sequence-size heuristic) suggest producing GPU outputs.
    prefer_gpu: bool,
}
489
impl ParsedMeshgrid {
    /// Splits `args` into axis values and an optional trailing `'like', prototype` pair,
    /// converts each axis value into [`AxisData`], and decides the GPU preference.
    ///
    /// # Errors
    /// Returns a `meshgrid` error when no axis is supplied, more than three axes are
    /// supplied, `'like'` is repeated, precedes every axis, lacks a prototype, or is not
    /// final, an unknown keyword is encountered, an axis value cannot be converted, or the
    /// prototype is unsupported.
    async fn parse(args: &[Value]) -> crate::BuiltinResult<Self> {
        if args.is_empty() {
            return Err(builtin_error(
                "meshgrid: at least one input vector is required",
            ));
        }
        let mut axis_values: Vec<Value> = Vec::new();
        let mut like_proto: Option<Value> = None;
        let mut prefer_gpu = false;
        let mut idx = 0;
        while idx < args.len() {
            let value = args[idx].clone();
            // Any value recognised as a keyword is an option, never an axis.
            if let Some(keyword) = keyword_of(&value) {
                match keyword.as_str() {
                    "like" => {
                        if like_proto.is_some() {
                            return Err(builtin_error(
                                "meshgrid: multiple 'like' specifications are not supported",
                            ));
                        }
                        if axis_values.is_empty() {
                            return Err(builtin_error(
                                "meshgrid: 'like' must follow at least one input vector",
                            ));
                        }
                        let Some(proto) = args.get(idx + 1).cloned() else {
                            return Err(builtin_error("meshgrid: expected prototype after 'like'"));
                        };
                        like_proto = Some(proto);
                        // Skip the keyword and its prototype; nothing may follow them.
                        idx += 2;
                        if idx < args.len() {
                            return Err(builtin_error(
                                "meshgrid: 'like' must be the final argument",
                            ));
                        }
                        break;
                    }
                    other => {
                        return Err(builtin_error(format!(
                            "meshgrid: unrecognised option '{other}'"
                        )));
                    }
                }
            }

            // A GPU-resident axis argument requests GPU outputs by default.
            if let Value::GpuTensor(_) = value {
                prefer_gpu = true;
            }
            axis_values.push(value);
            idx += 1;
        }

        if axis_values.is_empty() {
            return Err(builtin_error(
                "meshgrid: at least one input vector is required",
            ));
        }
        if axis_values.len() > 3 {
            return Err(builtin_error(
                "meshgrid: expected at most three input vectors",
            ));
        }

        let mut axes = Vec::with_capacity(max(axis_values.len(), 2));
        for (i, value) in axis_values.into_iter().enumerate() {
            // `axis_from_value` sets this flag when it handled a vector-shaped GPU tensor.
            let mut consumed_gpu = false;
            let data = axis_from_value(value, i, &mut consumed_gpu).await?;
            if consumed_gpu {
                prefer_gpu = true;
            }
            axes.push(data);
        }

        // With host-only inputs, defer to the residency heuristic using the longest axis.
        if !prefer_gpu {
            if let Some(max_len) = axes.iter().map(|axis| axis.len).max() {
                if max_len > 0
                    && sequence_gpu_preference(max_len, SequenceIntent::MeshAxis, false).prefer_gpu
                {
                    prefer_gpu = true;
                }
            }
        }

        let template = if let Some(proto) = like_proto {
            OutputTemplate::Like(analyse_like_prototype(&proto)?)
        } else {
            OutputTemplate::Default
        };

        Ok(Self {
            axes,
            template,
            prefer_gpu,
        })
    }
}
587
/// How the output class and residency are chosen.
#[derive(Clone)]
enum OutputTemplate {
    /// No `'like'` prototype was supplied; class and residency follow the inputs.
    Default,
    /// A `'like'` prototype was supplied and analysed into a [`PrototypeSpec`].
    Like(PrototypeSpec),
}
593
/// Residency and numeric class extracted from a `'like'` prototype value.
#[derive(Clone)]
struct PrototypeSpec {
    /// Where the outputs should live (host or GPU).
    residency: DevicePreference,
    /// Whether the outputs should be real or complex.
    class: PrototypeClass,
}
599
/// Numeric class of the generated grids.
#[derive(Clone, Copy, PartialEq, Eq)]
enum PrototypeClass {
    /// Real-valued output.
    Real,
    /// Complex-valued output.
    Complex,
}
605
/// Requested residency of the generated grids.
#[derive(Clone, Copy)]
enum DevicePreference {
    /// Keep outputs in host memory.
    Host,
    /// Produce GPU-resident outputs where possible.
    Gpu,
}
611
612fn analyse_like_prototype(proto: &Value) -> crate::BuiltinResult<PrototypeSpec> {
613 match proto {
614 Value::GpuTensor(_) => Ok(PrototypeSpec {
615 residency: DevicePreference::Gpu,
616 class: PrototypeClass::Real,
617 }),
618 Value::ComplexTensor(_) | Value::Complex(_, _) => Ok(PrototypeSpec {
619 residency: DevicePreference::Host,
620 class: PrototypeClass::Complex,
621 }),
622 Value::Tensor(_)
623 | Value::SparseTensor(_)
624 | Value::Num(_)
625 | Value::Int(_)
626 | Value::Bool(_)
627 | Value::LogicalArray(_) => Ok(PrototypeSpec {
628 residency: DevicePreference::Host,
629 class: PrototypeClass::Real,
630 }),
631 Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(builtin_error(
632 "meshgrid: prototypes must be numeric or gpuArray values",
633 )),
634 Value::Symbolic(_) => Err(builtin_error(
635 "meshgrid: prototypes must be numeric or gpuArray values",
636 )),
637 Value::Cell(_)
638 | Value::Struct(_)
639 | Value::Object(_)
640 | Value::HandleObject(_)
641 | Value::Listener(_)
642 | Value::FunctionHandle(_)
643 | Value::ExternalFunctionHandle(_)
644 | Value::MethodFunctionHandle(_)
645 | Value::BoundFunctionHandle { .. }
646 | Value::Closure(_)
647 | Value::ClassRef(_)
648 | Value::MException(_)
649 | Value::OutputList(_) => Err(builtin_error("meshgrid: prototypes must be numeric arrays")),
650 }
651}
652
/// One axis vector of the grid.
#[derive(Clone)]
struct AxisData {
    /// Host copy of the axis as `(re, im)` pairs; left empty when the axis is still
    /// device-resident (see `gpu_real`).
    values: Vec<(f64, f64)>,
    /// Number of elements along the axis.
    len: usize,
    /// Whether the axis is treated as complex.
    is_complex: bool,
    /// Handle of the originating vector-shaped GPU tensor, when the axis was not gathered.
    gpu_real: Option<GpuTensorHandle>,
}
660
661async fn axis_from_value(
662 value: Value,
663 index: usize,
664 prefer_gpu: &mut bool,
665) -> crate::BuiltinResult<AxisData> {
666 match value {
667 Value::Tensor(tensor) => axis_from_tensor(tensor, index),
668 Value::LogicalArray(logical) => {
669 let tensor = tensor::logical_to_tensor(&logical)?;
670 axis_from_tensor(tensor, index)
671 }
672 Value::Num(n) => Ok(AxisData {
673 values: vec![(n, 0.0)],
674 len: 1,
675 is_complex: false,
676 gpu_real: None,
677 }),
678 Value::Int(i) => {
679 let val = i.to_f64();
680 Ok(AxisData {
681 values: vec![(val, 0.0)],
682 len: 1,
683 is_complex: false,
684 gpu_real: None,
685 })
686 }
687 Value::Bool(b) => Ok(AxisData {
688 values: vec![(if b { 1.0 } else { 0.0 }, 0.0)],
689 len: 1,
690 is_complex: false,
691 gpu_real: None,
692 }),
693 Value::Complex(re, im) => Ok(AxisData {
694 values: vec![(re, im)],
695 len: 1,
696 is_complex: im != 0.0,
697 gpu_real: None,
698 }),
699 Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, index),
700 Value::GpuTensor(handle) => {
701 if is_vector_shape(&handle.shape) {
704 *prefer_gpu = true;
705 return Ok(AxisData {
706 values: Vec::new(),
707 len: vector_len_from_shape(&handle.shape),
708 is_complex: false,
709 gpu_real: Some(handle),
710 });
711 }
712
713 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
715 if is_vector_shape(&tensor.shape) {
716 *prefer_gpu = true;
717 }
718 axis_from_tensor(tensor, index)
719 }
720 other => Err(builtin_error(format!(
721 "meshgrid: input argument {} must be numeric, got {other:?}",
722 index + 1
723 ))),
724 }
725}
726
727fn axis_from_tensor(tensor: Tensor, index: usize) -> crate::BuiltinResult<AxisData> {
728 if is_vector_shape(&tensor.shape) {
729 let mut values = Vec::with_capacity(tensor.data.len());
730 for &v in &tensor.data {
731 values.push((v, 0.0));
732 }
733 return Ok(AxisData {
734 len: values.len(),
735 values,
736 is_complex: false,
737 gpu_real: None,
738 });
739 }
740
741 if let Some(axis) = axis_from_meshgrid_matrix_real(&tensor, index)? {
747 return Ok(axis);
748 }
749
750 Err(builtin_error(format!(
751 "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
752 index + 1,
753 tensor.shape
754 )))
755}
756
757fn axis_from_complex_tensor(tensor: ComplexTensor, index: usize) -> crate::BuiltinResult<AxisData> {
758 if is_vector_shape(&tensor.shape) {
759 let is_complex = tensor
760 .data
761 .iter()
762 .any(|&(_, imag)| !imag.is_nan() && imag != 0.0);
763 return Ok(AxisData {
764 len: tensor.data.len(),
765 values: tensor.data,
766 is_complex,
767 gpu_real: None,
768 });
769 }
770
771 if let Some(axis) = axis_from_meshgrid_matrix_complex(&tensor, index)? {
772 return Ok(axis);
773 }
774
775 Err(builtin_error(format!(
776 "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
777 index + 1,
778 tensor.shape
779 )))
780}
781
782fn axis_from_meshgrid_matrix_real(
783 tensor: &Tensor,
784 index: usize,
785) -> crate::BuiltinResult<Option<AxisData>> {
786 let (rows, cols) = match tensor.shape.as_slice() {
787 [r, c] => (*r, *c),
788 _ => return Ok(None),
789 };
790 if rows <= 1 || cols <= 1 {
791 return Ok(None);
792 }
793
794 let expect_rows_constant = index == 0;
797
798 if expect_rows_constant {
799 if !matrix_rows_are_identical_real(tensor, rows, cols) {
800 return Ok(None);
801 }
802 let mut values = Vec::with_capacity(cols);
804 for col in 0..cols {
805 let idx = rows * col;
806 values.push((tensor.data[idx], 0.0));
807 }
808 return Ok(Some(AxisData {
809 len: values.len(),
810 values,
811 is_complex: false,
812 gpu_real: None,
813 }));
814 }
815
816 if !matrix_cols_are_identical_real(tensor, rows, cols) {
817 return Ok(None);
818 }
819 let mut values = Vec::with_capacity(rows);
821 for row in 0..rows {
822 values.push((tensor.data[row], 0.0));
823 }
824 Ok(Some(AxisData {
825 len: values.len(),
826 values,
827 is_complex: false,
828 gpu_real: None,
829 }))
830}
831
832fn axis_from_meshgrid_matrix_complex(
833 tensor: &ComplexTensor,
834 index: usize,
835) -> crate::BuiltinResult<Option<AxisData>> {
836 let (rows, cols) = match tensor.shape.as_slice() {
837 [r, c] => (*r, *c),
838 _ => return Ok(None),
839 };
840 if rows <= 1 || cols <= 1 {
841 return Ok(None);
842 }
843
844 let expect_rows_constant = index == 0;
845 if expect_rows_constant {
846 if !matrix_rows_are_identical_complex(tensor, rows, cols) {
847 return Ok(None);
848 }
849 let mut values = Vec::with_capacity(cols);
850 for col in 0..cols {
851 let idx = rows * col;
852 values.push(tensor.data[idx]);
853 }
854 let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
855 return Ok(Some(AxisData {
856 len: values.len(),
857 values,
858 is_complex,
859 gpu_real: None,
860 }));
861 }
862
863 if !matrix_cols_are_identical_complex(tensor, rows, cols) {
864 return Ok(None);
865 }
866 let mut values = Vec::with_capacity(rows);
867 for row in 0..rows {
868 values.push(tensor.data[row]);
869 }
870 let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
871 Ok(Some(AxisData {
872 len: values.len(),
873 values,
874 is_complex,
875 gpu_real: None,
876 }))
877}
878
879fn matrix_rows_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
880 for row in 1..rows {
881 for col in 0..cols {
882 let idx0 = rows * col;
883 let idx = row + rows * col;
884 if tensor.data[idx] != tensor.data[idx0] {
885 return false;
886 }
887 }
888 }
889 true
890}
891
892fn matrix_cols_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
893 for col in 1..cols {
894 for row in 0..rows {
895 let idx0 = row;
896 let idx = row + rows * col;
897 if tensor.data[idx] != tensor.data[idx0] {
898 return false;
899 }
900 }
901 }
902 true
903}
904
905fn matrix_rows_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
906 for row in 1..rows {
907 for col in 0..cols {
908 let idx0 = rows * col;
909 let idx = row + rows * col;
910 if tensor.data[idx] != tensor.data[idx0] {
911 return false;
912 }
913 }
914 }
915 true
916}
917
918fn matrix_cols_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
919 for col in 1..cols {
920 for row in 0..rows {
921 let idx0 = row;
922 let idx = row + rows * col;
923 if tensor.data[idx] != tensor.data[idx0] {
924 return false;
925 }
926 }
927 }
928 true
929}
930
/// Returns `true` when at most one dimension of `shape` exceeds one.
///
/// Empty shapes, all-ones shapes, and shapes containing zero-sized dimensions therefore
/// count as vectors, provided no second dimension is larger than one.
fn is_vector_shape(shape: &[usize]) -> bool {
    shape.iter().filter(|&&dim| dim > 1).count() <= 1
}
943
/// Returns the number of elements described by a vector shape.
///
/// The length is the product of the dimensions. For a vector shape (at most one dimension
/// greater than one) this is the non-singleton extent, `1` for an empty or all-ones shape,
/// and `0` as soon as any dimension is zero. Taking the largest dimension instead would
/// report an empty `1x0` or `0x5` vector as holding 1 or 5 elements.
fn vector_len_from_shape(shape: &[usize]) -> usize {
    // The product of an empty iterator is 1, which covers the scalar (empty-shape) case.
    shape.iter().product()
}
950
951async fn axis_to_host_async(axis: &AxisData) -> crate::BuiltinResult<AxisData> {
952 if axis.gpu_real.is_none() {
953 return Ok(axis.clone());
954 }
955 let handle = axis.gpu_real.as_ref().expect("checked gpu_real is_some");
956 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
957 axis_from_tensor(tensor, 0)
959}
960
/// Builds the grids on the device using the provider's `reshape` and `repmat` operations.
///
/// Returns `Ok(None)` when the fast path does not apply: an axis has no device handle, or
/// no provider can be resolved for one of the handles. The caller is then expected to fall
/// back to another strategy.
///
/// # Errors
/// Any provider `reshape`/`repmat` failure is converted into a `meshgrid` error.
fn try_meshgrid_gpu_from_vector_axes(
    x_axis: &AxisData,
    y_axis: &AxisData,
    z_axis: Option<&AxisData>,
) -> crate::BuiltinResult<Option<Vec<MeshgridOutput>>> {
    let Some(x_handle) = x_axis.gpu_real.as_ref() else {
        return Ok(None);
    };
    let Some(y_handle) = y_axis.gpu_real.as_ref() else {
        return Ok(None);
    };

    // A supplied Z axis must also be device-resident; otherwise decline the fast path.
    let z_handle = match z_axis {
        Some(axis) => match axis.gpu_real.as_ref() {
            Some(h) => Some(h),
            None => return Ok(None),
        },
        None => None,
    };

    // All operations go through the provider resolved for X; Y and Z only need to resolve.
    let Some(provider) = runmat_accelerate_api::provider_for_handle(x_handle) else {
        return Ok(None);
    };
    if runmat_accelerate_api::provider_for_handle(y_handle).is_none() {
        return Ok(None);
    }
    if let Some(z) = z_handle {
        if runmat_accelerate_api::provider_for_handle(z).is_none() {
            return Ok(None);
        }
    }

    let nx = x_axis.len;
    let ny = y_axis.len;
    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);

    // Normalise orientation: X as a 1 x nx row, Y as an ny x 1 column.
    let x_row = provider
        .reshape(x_handle, &[1, nx])
        .map_err(|e| builtin_error(format!("meshgrid: reshape X failed: {e}")))?;
    let y_col = provider
        .reshape(y_handle, &[ny, 1])
        .map_err(|e| builtin_error(format!("meshgrid: reshape Y failed: {e}")))?;

    let mut outputs = Vec::with_capacity(if z_handle.is_some() { 3 } else { 2 });
    if let Some(z) = z_handle {
        // 3-D case: place each axis along its own dimension, then tile to ny x nx x nz.
        let x_base = provider
            .reshape(&x_row, &[1, nx, 1])
            .map_err(|e| builtin_error(format!("meshgrid: reshape X(3d) failed: {e}")))?;
        let y_base = provider
            .reshape(&y_col, &[ny, 1, 1])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Y(3d) failed: {e}")))?;

        let x_grid = provider
            .repmat(&x_base, &[ny, 1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
        let y_grid = provider
            .repmat(&y_base, &[1, nx, nz])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;

        outputs.push(MeshgridOutput::GpuReal(x_grid));
        outputs.push(MeshgridOutput::GpuReal(y_grid));
        let z_axis_row = provider
            .reshape(z, &[1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Z failed: {e}")))?;
        let z_base = provider
            .reshape(&z_axis_row, &[1, 1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Z(3d) failed: {e}")))?;
        let z_grid = provider
            .repmat(&z_base, &[ny, nx, 1])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Z failed: {e}")))?;
        outputs.push(MeshgridOutput::GpuReal(z_grid));
    } else {
        // 2-D case: tile the X row down ny rows and the Y column across nx columns.
        let x_grid = provider
            .repmat(&x_row, &[ny, 1])
            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
        let y_grid = provider
            .repmat(&y_col, &[1, nx])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
        outputs.push(MeshgridOutput::GpuReal(x_grid));
        outputs.push(MeshgridOutput::GpuReal(y_grid));
    }

    Ok(Some(outputs))
}
1046
1047fn normalise_axes(axes: &[AxisData]) -> (AxisData, AxisData, Option<AxisData>) {
1048 match axes.len() {
1049 1 => {
1050 let x = axes[0].clone();
1051 (x.clone(), x, None)
1052 }
1053 2 => {
1054 let x = axes[0].clone();
1055 let y = axes[1].clone();
1056 (x, y, None)
1057 }
1058 3 => {
1059 let x = axes[0].clone();
1060 let y = axes[1].clone();
1061 let z = axes[2].clone();
1062 (x, y, Some(z))
1063 }
1064 _ => unreachable!(),
1065 }
1066}
1067
1068fn build_outputs(
1069 x_axis: &AxisData,
1070 y_axis: &AxisData,
1071 z_axis: Option<&AxisData>,
1072) -> Vec<GridOutput> {
1073 let nx = x_axis.len;
1074 let ny = y_axis.len;
1075 let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
1076 let total = nx * ny * nz;
1077 let mut x_data = Vec::with_capacity(total);
1078 let mut y_data = Vec::with_capacity(total);
1079 let mut z_data = z_axis.map(|_| Vec::with_capacity(total));
1080
1081 for k in 0..nz {
1082 let z_value = z_axis.map(|axis| axis.values[k]);
1083 for col in 0..nx {
1084 let x_value = x_axis.values[col];
1085 for row in 0..ny {
1086 x_data.push(x_value);
1087 y_data.push(y_axis.values[row]);
1088 if let Some(ref mut z_vec) = z_data {
1089 z_vec.push(z_value.unwrap());
1090 }
1091 }
1092 }
1093 }
1094
1095 let mut outputs = Vec::new();
1096 let base_shape = if nz == 1 {
1097 vec![ny, nx]
1098 } else {
1099 vec![ny, nx, nz]
1100 };
1101 outputs.push(GridOutput {
1102 shape: base_shape.clone(),
1103 data: x_data,
1104 });
1105 outputs.push(GridOutput {
1106 shape: base_shape.clone(),
1107 data: y_data,
1108 });
1109 if let Some(z_vec) = z_data {
1110 outputs.push(GridOutput {
1111 shape: base_shape,
1112 data: z_vec,
1113 });
1114 }
1115 outputs
1116}
1117
/// A host-side coordinate grid before conversion into a runtime `Value`.
struct GridOutput {
    /// Dimensions of the grid.
    shape: Vec<usize>,
    /// Grid elements as `(re, im)` pairs.
    data: Vec<(f64, f64)>,
}
1122
1123impl GridOutput {
1124 fn to_value(
1125 &self,
1126 class: PrototypeClass,
1127 residency: DevicePreference,
1128 ) -> crate::BuiltinResult<Value> {
1129 match class {
1130 PrototypeClass::Real => self.to_real_value(residency),
1131 PrototypeClass::Complex => self.to_complex_value(residency),
1132 }
1133 }
1134
1135 fn to_real_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1136 let mut real = Vec::with_capacity(self.data.len());
1137 for &(re, im) in &self.data {
1138 if im != 0.0 {
1139 return Err(builtin_error(
1140 "meshgrid: cannot represent complex values in a real output",
1141 ));
1142 }
1143 real.push(re);
1144 }
1145 let tensor = Tensor::new(real, self.shape.clone())
1146 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1147 match residency {
1148 DevicePreference::Host => Ok(tensor::tensor_into_value(tensor)),
1149 DevicePreference::Gpu => to_gpu_tensor_value(tensor),
1150 }
1151 }
1152
1153 fn to_complex_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1154 let tensor = ComplexTensor::new(self.data.clone(), self.shape.clone())
1155 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1156 match residency {
1157 DevicePreference::Host => Ok(complex_tensor_into_value(tensor)),
1158 DevicePreference::Gpu => {
1159 warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
1160 Ok(complex_tensor_into_value(tensor))
1161 }
1162 }
1163 }
1164}
1165
1166fn to_gpu_tensor_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1167 if let Some(provider) = runmat_accelerate_api::provider() {
1168 let view = HostTensorView {
1169 data: &tensor.data,
1170 shape: &tensor.shape,
1171 };
1172 match provider.upload(&view) {
1173 Ok(handle) => return Ok(Value::GpuTensor(handle)),
1174 Err(err) => {
1175 warn!("meshgrid: failed to upload tensor to GPU, returning host array: {err}")
1176 }
1177 }
1178 }
1179 Ok(tensor::tensor_into_value(tensor))
1180}
1181
1182fn tensor_to_complex_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1183 let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
1184 let complex = ComplexTensor::new(data, tensor.shape.clone())
1185 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1186 Ok(complex_tensor_into_value(complex))
1187}
1188
1189enum MeshgridOutput {
1190 Host(GridOutput),
1191 GpuReal(GpuTensorHandle),
1192}
1193
1194impl MeshgridOutput {
1195 async fn to_value(
1196 &self,
1197 class: PrototypeClass,
1198 residency: DevicePreference,
1199 ) -> crate::BuiltinResult<Value> {
1200 match self {
1201 MeshgridOutput::Host(host) => host.to_value(class, residency),
1202 MeshgridOutput::GpuReal(handle) => match (class, residency) {
1203 (PrototypeClass::Real, DevicePreference::Gpu) => {
1204 Ok(Value::GpuTensor(handle.clone()))
1205 }
1206 (PrototypeClass::Real, DevicePreference::Host) => {
1207 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1208 Ok(tensor::tensor_into_value(tensor))
1209 }
1210 (PrototypeClass::Complex, DevicePreference::Host) => {
1211 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1212 tensor_to_complex_value(tensor)
1213 }
1214 (PrototypeClass::Complex, DevicePreference::Gpu) => {
1215 warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
1216 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1217 tensor_to_complex_value(tensor)
1218 }
1219 },
1220 }
1221 }
1222}
1223
1224pub struct MeshgridEval {
1227 outputs: Vec<MeshgridOutput>,
1228 target_class: PrototypeClass,
1229 target_residency: DevicePreference,
1230}
1231
1232impl MeshgridEval {
1233 pub fn output_count(&self) -> usize {
1234 self.outputs.len()
1235 }
1236
1237 pub async fn first(&self) -> crate::BuiltinResult<Value> {
1238 self.outputs[0]
1239 .to_value(self.target_class, self.target_residency)
1240 .await
1241 }
1242
1243 pub async fn second(&self) -> crate::BuiltinResult<Value> {
1244 if self.outputs.len() < 2 {
1245 Err(builtin_error("meshgrid: second output unavailable"))
1246 } else {
1247 self.outputs[1]
1248 .to_value(self.target_class, self.target_residency)
1249 .await
1250 }
1251 }
1252
1253 pub async fn third(&self) -> crate::BuiltinResult<Value> {
1254 if self.outputs.len() < 3 {
1255 Err(builtin_error(
1256 "meshgrid: third output requested but no Z vector was supplied",
1257 ))
1258 } else {
1259 self.outputs[2]
1260 .to_value(self.target_class, self.target_residency)
1261 .await
1262 }
1263 }
1264}
1265
/// Unit tests for the `meshgrid` builtin: host grids, type inference, GPU
/// residency handling, complex inputs, and the `'like'` prototype syntax.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;

    use runmat_accelerate_api::HostTensorView;

    /// Drives the async evaluator to completion on the current thread.
    fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
        let pending = super::evaluate(args);
        block_on(pending)
    }

    /// Blocking wrapper around [`MeshgridEval::first`].
    fn eval_first(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        let pending = eval.first();
        block_on(pending)
    }

    /// Blocking wrapper around [`MeshgridEval::second`].
    fn eval_second(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        let pending = eval.second();
        block_on(pending)
    }

    /// Blocking wrapper around [`MeshgridEval::third`].
    fn eval_third(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        let pending = eval.third();
        block_on(pending)
    }

    /// Builds a `rows x cols` tensor from column-major data.
    fn tensor_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Tensor {
        let shape = vec![rows, cols];
        Tensor::new(data, shape).unwrap()
    }

    /// Borrows a host tensor as an upload view.
    fn host_view(tensor: &Tensor) -> HostTensorView<'_> {
        HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_single_input_duplicates_axis() {
        let axis = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
        let eval = evaluate(&[Value::Tensor(axis)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);

        let x_grid = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        let expected_x = vec![-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        assert_eq!(x_grid.shape, vec![3, 3]);
        assert_eq!(x_grid.data, expected_x);

        let y_grid = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        let expected_y = vec![-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0];
        assert_eq!(y_grid.shape, vec![3, 3]);
        assert_eq!(y_grid.data, expected_y);
    }

    #[test]
    fn meshgrid_type_infers_rank_from_axis_count() {
        let ctx = ResolveContext::new(Vec::new());

        let two_axes = meshgrid_type(&[Type::Num, Type::Num], &ctx);
        assert_eq!(
            two_axes,
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1)])
            }
        );

        let three_axes = meshgrid_type(&[Type::Num, Type::Num, Type::Num], &ctx);
        assert_eq!(
            three_axes,
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1), Some(1)])
            }
        );
    }

    #[test]
    fn meshgrid_type_uses_vector_lengths() {
        let ctx = ResolveContext::new(Vec::new());
        let x_type = Type::Tensor {
            shape: Some(vec![Some(1), Some(201)]),
        };
        let y_type = Type::Tensor {
            shape: Some(vec![Some(1), Some(101)]),
        };

        let inferred = meshgrid_type(&[x_type, y_type], &ctx);

        assert_eq!(
            inferred,
            Type::Tensor {
                shape: Some(vec![Some(101), Some(201)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_rectangular_inputs() {
        let x_axis = tensor_from_vec(vec![0.0, 0.5, 1.0], 1, 3);
        let y_axis = tensor_from_vec(vec![10.0, 20.0], 2, 1);
        let args = [Value::Tensor(x_axis), Value::Tensor(y_axis)];
        let eval = evaluate(&args).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);

        let x_grid = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_grid.shape, vec![2, 3]);
        assert_eq!(x_grid.data, vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0]);

        let y_grid = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_grid.shape, vec![2, 3]);
        assert_eq!(y_grid.data, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_three_inputs_volume() {
        let x_axis = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let y_axis = tensor_from_vec(vec![5.0, 6.0, 7.0], 3, 1);
        let z_axis = tensor_from_vec(vec![0.0, 1.0], 1, 2);
        let args = [
            Value::Tensor(x_axis),
            Value::Tensor(y_axis),
            Value::Tensor(z_axis),
        ];
        let eval = evaluate(&args).expect("meshgrid");
        assert_eq!(eval.output_count(), 3);

        let x_grid = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        let expected_x = vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        assert_eq!(x_grid.shape, vec![3, 2, 2]);
        assert_eq!(x_grid.data, expected_x);

        let z_grid = test_support::gather(eval_third(&eval).expect("Z")).expect("host");
        let expected_z = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        assert_eq!(z_grid.shape, vec![3, 2, 2]);
        assert_eq!(z_grid.data, expected_z);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_keeps_gpu_residency() {
        test_support::with_test_provider(|provider| {
            let x_axis = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
            let y_axis = tensor_from_vec(vec![2.0, 4.0], 2, 1);
            let prototype = tensor_from_vec(vec![0.0], 1, 1);
            let prototype_handle = provider.upload(&host_view(&prototype)).expect("upload");

            let args = [
                Value::Tensor(x_axis),
                Value::Tensor(y_axis),
                Value::from("like"),
                Value::GpuTensor(prototype_handle),
            ];
            let eval = evaluate(&args).expect("meshgrid");

            let x_value = eval_first(&eval).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let x_host = test_support::gather(x_value).expect("gather");
            assert_eq!(x_host.shape, vec![2, 3]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_gpu_inputs_roundtrip() {
        test_support::with_test_provider(|provider| {
            let x_axis = tensor_from_vec(vec![0.0, 0.5], 1, 2);
            let y_axis = tensor_from_vec(vec![1.0, 2.0], 2, 1);
            let x_handle = provider.upload(&host_view(&x_axis)).expect("upload");
            let y_handle = provider.upload(&host_view(&y_axis)).expect("upload");

            let args = [Value::GpuTensor(x_handle), Value::GpuTensor(y_handle)];
            let eval = evaluate(&args).expect("meshgrid");

            let x_value = eval_first(&eval).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let y_value = eval_second(&eval).expect("Y");
            assert!(matches!(y_value, Value::GpuTensor(_)));
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn meshgrid_wgpu_matches_cpu() {
        let provider = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .expect("wgpu provider");

        let x_axis = tensor_from_vec(vec![-1.0, 0.0, 1.0, 2.0], 1, 4);
        let y_axis = tensor_from_vec(vec![5.0, 6.0], 2, 1);

        // Reference result computed entirely from host inputs.
        let cpu_args = [Value::Tensor(x_axis.clone()), Value::Tensor(y_axis.clone())];
        let cpu_eval = evaluate(&cpu_args).expect("meshgrid cpu");
        let cpu_x =
            test_support::gather(eval_first(&cpu_eval).expect("X cpu")).expect("gather X cpu");
        let cpu_y =
            test_support::gather(eval_second(&cpu_eval).expect("Y cpu")).expect("gather Y cpu");

        // Same axes, but uploaded to the device first.
        let x_gpu = provider.upload(&host_view(&x_axis)).expect("upload x");
        let y_gpu = provider.upload(&host_view(&y_axis)).expect("upload y");
        let gpu_args = [Value::GpuTensor(x_gpu), Value::GpuTensor(y_gpu)];
        let gpu_eval = evaluate(&gpu_args).expect("meshgrid gpu");
        let gpu_x_value = eval_first(&gpu_eval).expect("X gpu");
        let gpu_y_value = eval_second(&gpu_eval).expect("Y gpu");

        assert!(matches!(gpu_x_value, Value::GpuTensor(_)));
        assert!(matches!(gpu_y_value, Value::GpuTensor(_)));

        let gpu_x = test_support::gather(gpu_x_value).expect("gather X gpu");
        let gpu_y = test_support::gather(gpu_y_value).expect("gather Y gpu");

        assert_eq!(gpu_x.shape, cpu_x.shape);
        assert_eq!(gpu_x.data, cpu_x.data);
        assert_eq!(gpu_y.shape, cpu_y.shape);
        assert_eq!(gpu_y.data, cpu_y.data);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_complex_inputs_produce_complex_outputs() {
        let axis = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
        let eval = evaluate(&[Value::ComplexTensor(axis)]).expect("meshgrid");

        match eval_first(&eval).expect("X") {
            Value::ComplexTensor(grid) => assert_eq!(grid.shape, vec![2, 2]),
            Value::Complex(_, _) => {}
            other => panic!("expected complex output, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_host_prototype() {
        let axis = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let args = [Value::Tensor(axis), Value::from("like"), Value::Num(0.0)];
        let eval = evaluate(&args).expect("meshgrid");

        let x_value = eval_first(&eval).expect("X");
        assert!(matches!(x_value, Value::Tensor(_) | Value::Num(_)));
    }
}