1use std::cmp::max;
4
5use log::warn;
6use runmat_accelerate_api::{GpuTensorHandle, GpuTensorStorage, HostTensorView};
7use runmat_builtins::{
8 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
9 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
10 ComplexTensor, ResolveContext, Tensor, Type, Value,
11};
12
13use crate::builtins::array::type_resolvers::size_vector_len;
14use runmat_macros::runtime_builtin;
15
16use crate::build_runtime_error;
17use crate::builtins::common::gpu_helpers;
18use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
19use crate::builtins::common::residency::{sequence_gpu_preference, SequenceIntent};
20use crate::builtins::common::spec::{
21 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
23};
24use crate::builtins::common::tensor;
25
26#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
27pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
28 name: "meshgrid",
29 op_kind: GpuOpKind::Custom("array_construct"),
30 supported_precisions: &[ScalarType::F32, ScalarType::F64],
31 broadcast: BroadcastSemantics::Matlab,
32 provider_hooks: &[ProviderHook::Custom("meshgrid")],
33 constant_strategy: ConstantStrategy::InlineLiteral,
34 residency: ResidencyPolicy::NewHandle,
35 nan_mode: ReductionNaN::Include,
36 two_pass_threshold: None,
37 workgroup_size: None,
38 accepts_nan_mode: false,
39 notes: "Providers may supply a dedicated meshgrid hook; until then the runtime builds grids on the host and uploads them when GPU residency is requested.",
40};
41
42fn builtin_error(message: impl Into<String>) -> crate::RuntimeError {
43 build_runtime_error(message)
44 .with_builtin("meshgrid")
45 .build()
46}
47
48#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
49pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
50 name: "meshgrid",
51 shape: ShapeRequirements::Any,
52 constant_strategy: ConstantStrategy::InlineLiteral,
53 elementwise: None,
54 reduction: None,
55 emits_nan: false,
56 notes:
57 "Meshgrid explicitly materialises dense coordinate arrays and therefore bypasses fusion.",
58};
59
60fn meshgrid_type(args: &[Type], _context: &ResolveContext) -> Type {
61 if args.is_empty() {
62 return Type::Unknown;
63 }
64 let mut axis_count = args.len();
65 if axis_count >= 2 && matches!(args[axis_count - 2], Type::String) {
66 axis_count = axis_count.saturating_sub(2);
67 }
68 if axis_count == 0 {
69 return Type::Unknown;
70 }
71 let axis_args = &args[..axis_count];
72 let len_x = axis_args.get(0).and_then(size_vector_len);
73 let len_y = axis_args.get(1).and_then(size_vector_len).or(len_x);
74 let len_z = axis_args.get(2).and_then(size_vector_len);
75 let shape = if axis_count >= 3 {
76 vec![len_y, len_x, len_z]
77 } else {
78 vec![len_y, len_x]
79 };
80 Type::Tensor { shape: Some(shape) }
81}
82
83const MESHGRID_OUTPUT_XY: [BuiltinParamDescriptor; 2] = [
84 BuiltinParamDescriptor {
85 name: "X",
86 ty: BuiltinParamType::NumericArray,
87 arity: BuiltinParamArity::Required,
88 default: None,
89 description: "Grid coordinates along X-axis.",
90 },
91 BuiltinParamDescriptor {
92 name: "Y",
93 ty: BuiltinParamType::NumericArray,
94 arity: BuiltinParamArity::Required,
95 default: None,
96 description: "Grid coordinates along Y-axis.",
97 },
98];
99
100const MESHGRID_OUTPUT_XYZ: [BuiltinParamDescriptor; 3] = [
101 BuiltinParamDescriptor {
102 name: "X",
103 ty: BuiltinParamType::NumericArray,
104 arity: BuiltinParamArity::Required,
105 default: None,
106 description: "Grid coordinates along X-axis.",
107 },
108 BuiltinParamDescriptor {
109 name: "Y",
110 ty: BuiltinParamType::NumericArray,
111 arity: BuiltinParamArity::Required,
112 default: None,
113 description: "Grid coordinates along Y-axis.",
114 },
115 BuiltinParamDescriptor {
116 name: "Z",
117 ty: BuiltinParamType::NumericArray,
118 arity: BuiltinParamArity::Optional,
119 default: None,
120 description: "Grid coordinates along Z-axis.",
121 },
122];
123
124const MESHGRID_SIG_X_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
125 name: "x",
126 ty: BuiltinParamType::NumericArray,
127 arity: BuiltinParamArity::Required,
128 default: None,
129 description: "X-axis vector.",
130}];
131
132const MESHGRID_SIG_XY_INPUTS: [BuiltinParamDescriptor; 2] = [
133 BuiltinParamDescriptor {
134 name: "x",
135 ty: BuiltinParamType::NumericArray,
136 arity: BuiltinParamArity::Required,
137 default: None,
138 description: "X-axis vector.",
139 },
140 BuiltinParamDescriptor {
141 name: "y",
142 ty: BuiltinParamType::NumericArray,
143 arity: BuiltinParamArity::Required,
144 default: None,
145 description: "Y-axis vector.",
146 },
147];
148
149const MESHGRID_SIG_XYZ_INPUTS: [BuiltinParamDescriptor; 3] = [
150 BuiltinParamDescriptor {
151 name: "x",
152 ty: BuiltinParamType::NumericArray,
153 arity: BuiltinParamArity::Required,
154 default: None,
155 description: "X-axis vector.",
156 },
157 BuiltinParamDescriptor {
158 name: "y",
159 ty: BuiltinParamType::NumericArray,
160 arity: BuiltinParamArity::Required,
161 default: None,
162 description: "Y-axis vector.",
163 },
164 BuiltinParamDescriptor {
165 name: "z",
166 ty: BuiltinParamType::NumericArray,
167 arity: BuiltinParamArity::Optional,
168 default: None,
169 description: "Z-axis vector.",
170 },
171];
172
173const MESHGRID_SIG_X_LIKE_INPUTS: [BuiltinParamDescriptor; 3] = [
174 BuiltinParamDescriptor {
175 name: "x",
176 ty: BuiltinParamType::NumericArray,
177 arity: BuiltinParamArity::Required,
178 default: None,
179 description: "X-axis vector.",
180 },
181 BuiltinParamDescriptor {
182 name: "like_kw",
183 ty: BuiltinParamType::StringScalar,
184 arity: BuiltinParamArity::Required,
185 default: Some("\"like\""),
186 description: "Like keyword.",
187 },
188 BuiltinParamDescriptor {
189 name: "prototype",
190 ty: BuiltinParamType::LikePrototype,
191 arity: BuiltinParamArity::Required,
192 default: None,
193 description: "Prototype controlling class/device residency.",
194 },
195];
196
197const MESHGRID_SIG_XY_LIKE_INPUTS: [BuiltinParamDescriptor; 4] = [
198 BuiltinParamDescriptor {
199 name: "x",
200 ty: BuiltinParamType::NumericArray,
201 arity: BuiltinParamArity::Required,
202 default: None,
203 description: "X-axis vector.",
204 },
205 BuiltinParamDescriptor {
206 name: "y",
207 ty: BuiltinParamType::NumericArray,
208 arity: BuiltinParamArity::Required,
209 default: None,
210 description: "Y-axis vector.",
211 },
212 BuiltinParamDescriptor {
213 name: "like_kw",
214 ty: BuiltinParamType::StringScalar,
215 arity: BuiltinParamArity::Required,
216 default: Some("\"like\""),
217 description: "Like keyword.",
218 },
219 BuiltinParamDescriptor {
220 name: "prototype",
221 ty: BuiltinParamType::LikePrototype,
222 arity: BuiltinParamArity::Required,
223 default: None,
224 description: "Prototype controlling class/device residency.",
225 },
226];
227
228const MESHGRID_SIG_XYZ_LIKE_INPUTS: [BuiltinParamDescriptor; 5] = [
229 BuiltinParamDescriptor {
230 name: "x",
231 ty: BuiltinParamType::NumericArray,
232 arity: BuiltinParamArity::Required,
233 default: None,
234 description: "X-axis vector.",
235 },
236 BuiltinParamDescriptor {
237 name: "y",
238 ty: BuiltinParamType::NumericArray,
239 arity: BuiltinParamArity::Required,
240 default: None,
241 description: "Y-axis vector.",
242 },
243 BuiltinParamDescriptor {
244 name: "z",
245 ty: BuiltinParamType::NumericArray,
246 arity: BuiltinParamArity::Optional,
247 default: None,
248 description: "Z-axis vector.",
249 },
250 BuiltinParamDescriptor {
251 name: "like_kw",
252 ty: BuiltinParamType::StringScalar,
253 arity: BuiltinParamArity::Required,
254 default: Some("\"like\""),
255 description: "Like keyword.",
256 },
257 BuiltinParamDescriptor {
258 name: "prototype",
259 ty: BuiltinParamType::LikePrototype,
260 arity: BuiltinParamArity::Required,
261 default: None,
262 description: "Prototype controlling class/device residency.",
263 },
264];
265
266const MESHGRID_SIGNATURES: [BuiltinSignatureDescriptor; 6] = [
267 BuiltinSignatureDescriptor {
268 label: "[X,Y] = meshgrid(x)",
269 inputs: &MESHGRID_SIG_X_INPUTS,
270 outputs: &MESHGRID_OUTPUT_XY,
271 },
272 BuiltinSignatureDescriptor {
273 label: "[X,Y] = meshgrid(x, y)",
274 inputs: &MESHGRID_SIG_XY_INPUTS,
275 outputs: &MESHGRID_OUTPUT_XY,
276 },
277 BuiltinSignatureDescriptor {
278 label: "[X,Y,Z] = meshgrid(x, y, z)",
279 inputs: &MESHGRID_SIG_XYZ_INPUTS,
280 outputs: &MESHGRID_OUTPUT_XYZ,
281 },
282 BuiltinSignatureDescriptor {
283 label: "[X,Y] = meshgrid(x, \"like\", prototype)",
284 inputs: &MESHGRID_SIG_X_LIKE_INPUTS,
285 outputs: &MESHGRID_OUTPUT_XY,
286 },
287 BuiltinSignatureDescriptor {
288 label: "[X,Y] = meshgrid(x, y, \"like\", prototype)",
289 inputs: &MESHGRID_SIG_XY_LIKE_INPUTS,
290 outputs: &MESHGRID_OUTPUT_XY,
291 },
292 BuiltinSignatureDescriptor {
293 label: "[X,Y,Z] = meshgrid(x, y, z, \"like\", prototype)",
294 inputs: &MESHGRID_SIG_XYZ_LIKE_INPUTS,
295 outputs: &MESHGRID_OUTPUT_XYZ,
296 },
297];
298
299const MESHGRID_ERRORS: [BuiltinErrorDescriptor; 11] = [
300 BuiltinErrorDescriptor {
301 code: "RM.MESHGRID.MISSING_AXIS",
302 identifier: None,
303 when: "No axis vectors are provided.",
304 message: "meshgrid: at least one input vector is required",
305 },
306 BuiltinErrorDescriptor {
307 code: "RM.MESHGRID.TOO_MANY_AXES",
308 identifier: None,
309 when: "More than three axis vectors are provided.",
310 message: "meshgrid: expected at most three input vectors",
311 },
312 BuiltinErrorDescriptor {
313 code: "RM.MESHGRID.LIKE_EXPECTED_PROTOTYPE",
314 identifier: None,
315 when: "The 'like' keyword is provided without a prototype argument.",
316 message: "meshgrid: expected prototype after 'like'",
317 },
318 BuiltinErrorDescriptor {
319 code: "RM.MESHGRID.MULTIPLE_LIKE",
320 identifier: None,
321 when: "The 'like' keyword is provided multiple times.",
322 message: "meshgrid: multiple 'like' specifications are not supported",
323 },
324 BuiltinErrorDescriptor {
325 code: "RM.MESHGRID.LIKE_POSITION",
326 identifier: None,
327 when: "The 'like' keyword is in an invalid position or not final.",
328 message: "meshgrid: 'like' must be the final argument",
329 },
330 BuiltinErrorDescriptor {
331 code: "RM.MESHGRID.UNRECOGNIZED_OPTION",
332 identifier: None,
333 when: "A trailing option string is not recognized.",
334 message: "meshgrid: unrecognised option",
335 },
336 BuiltinErrorDescriptor {
337 code: "RM.MESHGRID.INVALID_AXIS_INPUT",
338 identifier: None,
339 when: "Axis inputs are non-numeric or non-vector shapes.",
340 message: "meshgrid: input argument must be numeric vector data",
341 },
342 BuiltinErrorDescriptor {
343 code: "RM.MESHGRID.INVALID_PROTOTYPE",
344 identifier: None,
345 when: "The 'like' prototype is unsupported.",
346 message: "meshgrid: prototypes must be numeric arrays",
347 },
348 BuiltinErrorDescriptor {
349 code: "RM.MESHGRID.OUTPUT_COUNT_EXCEEDED",
350 identifier: None,
351 when: "Requested outputs exceed available outputs for provided axes.",
352 message:
353 "meshgrid: supports at most two outputs for 2-axis inputs and three for 3-axis inputs",
354 },
355 BuiltinErrorDescriptor {
356 code: "RM.MESHGRID.THIRD_OUTPUT_UNAVAILABLE",
357 identifier: None,
358 when: "A third output is requested without supplying a Z-axis vector.",
359 message: "meshgrid: third output requested but no Z vector was supplied",
360 },
361 BuiltinErrorDescriptor {
362 code: "RM.MESHGRID.COMPLEX_REAL_CONVERSION",
363 identifier: None,
364 when: "Complex axis values cannot be represented in requested real output class.",
365 message: "meshgrid: cannot represent complex values in a real output",
366 },
367];
368
369pub const MESHGRID_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
370 signatures: &MESHGRID_SIGNATURES,
371 output_mode: BuiltinOutputMode::ByRequestedOutputCount,
372 completion_policy: BuiltinCompletionPolicy::Public,
373 errors: &MESHGRID_ERRORS,
374};
375
376#[runtime_builtin(
377 name = "meshgrid",
378 category = "array/creation",
379 summary = "Generate coordinate matrices for 2-D and 3-D grids.",
380 keywords = "meshgrid,grid,gpu,like,3d",
381 accel = "array_construct",
382 type_resolver(meshgrid_type),
383 descriptor(crate::builtins::array::creation::meshgrid::MESHGRID_DESCRIPTOR),
384 builtin_path = "crate::builtins::array::creation::meshgrid"
385)]
386async fn meshgrid_builtin(rest: Vec<Value>) -> crate::BuiltinResult<Value> {
387 let eval = evaluate(&rest).await?;
388 if let Some(out_count) = crate::output_count::current_output_count() {
389 if out_count == 0 {
390 return Ok(Value::OutputList(Vec::new()));
391 }
392 let available = eval.output_count();
393 if out_count > available {
394 let msg = if available == 2 {
395 "meshgrid with two inputs supports at most two outputs"
396 } else {
397 "meshgrid supports at most three outputs"
398 };
399 return Err(builtin_error(msg));
400 }
401 let mut outputs = Vec::with_capacity(out_count);
402 let first = eval.first().await?;
403 outputs.push(first);
404 if out_count >= 2 {
405 outputs.push(eval.second().await?);
406 }
407 if out_count >= 3 {
408 outputs.push(eval.third().await?);
409 }
410 return Ok(Value::OutputList(outputs));
411 }
412 eval.first().await
413}
414
415pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
417 let parsed = ParsedMeshgrid::parse(args).await?;
418 let (x_axis, y_axis, z_axis) = normalise_axes(&parsed.axes);
419
420 let require_complex = parsed.axes.iter().any(|axis| axis.is_complex);
421
422 let target_class = match &parsed.template {
423 OutputTemplate::Default => {
424 if require_complex {
425 PrototypeClass::Complex
426 } else {
427 PrototypeClass::Real
428 }
429 }
430 OutputTemplate::Like(spec) => {
431 if require_complex {
432 PrototypeClass::Complex
433 } else {
434 spec.class
435 }
436 }
437 };
438
439 let target_residency = match &parsed.template {
440 OutputTemplate::Default => {
441 if parsed.prefer_gpu {
442 DevicePreference::Gpu
443 } else {
444 DevicePreference::Host
445 }
446 }
447 OutputTemplate::Like(spec) => spec.residency,
448 };
449
450 let mut outputs: Vec<MeshgridOutput> = Vec::new();
451
452 if matches!(target_residency, DevicePreference::Gpu) {
453 if let Some(gpu) = try_meshgrid_gpu_from_vector_axes(&x_axis, &y_axis, z_axis.as_ref())? {
454 outputs = gpu;
455 }
456 }
457
458 if outputs.is_empty() {
459 let x_host = axis_to_host_async(&x_axis).await?;
461 let y_host = axis_to_host_async(&y_axis).await?;
462 let z_host = match z_axis.as_ref() {
463 Some(axis) => Some(axis_to_host_async(axis).await?),
464 None => None,
465 };
466 outputs = build_outputs(&x_host, &y_host, z_host.as_ref())
467 .into_iter()
468 .map(MeshgridOutput::Host)
469 .collect();
470 }
471
472 Ok(MeshgridEval {
473 outputs,
474 target_class,
475 target_residency,
476 })
477}
478
479#[derive(Clone)]
480struct ParsedMeshgrid {
481 axes: Vec<AxisData>,
482 template: OutputTemplate,
483 prefer_gpu: bool,
484}
485
486impl ParsedMeshgrid {
487 async fn parse(args: &[Value]) -> crate::BuiltinResult<Self> {
488 if args.is_empty() {
489 return Err(builtin_error(
490 "meshgrid: at least one input vector is required",
491 ));
492 }
493 let mut axis_values: Vec<Value> = Vec::new();
494 let mut like_proto: Option<Value> = None;
495 let mut prefer_gpu = false;
496 let mut idx = 0;
497 while idx < args.len() {
498 let value = args[idx].clone();
499 if let Some(keyword) = keyword_of(&value) {
500 match keyword.as_str() {
501 "like" => {
502 if like_proto.is_some() {
503 return Err(builtin_error(
504 "meshgrid: multiple 'like' specifications are not supported",
505 ));
506 }
507 if axis_values.is_empty() {
508 return Err(builtin_error(
509 "meshgrid: 'like' must follow at least one input vector",
510 ));
511 }
512 let Some(proto) = args.get(idx + 1).cloned() else {
513 return Err(builtin_error("meshgrid: expected prototype after 'like'"));
514 };
515 like_proto = Some(proto);
516 idx += 2;
517 if idx < args.len() {
518 return Err(builtin_error(
519 "meshgrid: 'like' must be the final argument",
520 ));
521 }
522 break;
523 }
524 other => {
525 return Err(builtin_error(format!(
526 "meshgrid: unrecognised option '{other}'"
527 )));
528 }
529 }
530 }
531
532 if let Value::GpuTensor(_) = value {
533 prefer_gpu = true;
534 }
535 axis_values.push(value);
536 idx += 1;
537 }
538
539 if axis_values.is_empty() {
540 return Err(builtin_error(
541 "meshgrid: at least one input vector is required",
542 ));
543 }
544 if axis_values.len() > 3 {
545 return Err(builtin_error(
546 "meshgrid: expected at most three input vectors",
547 ));
548 }
549
550 let mut axes = Vec::with_capacity(max(axis_values.len(), 2));
551 for (i, value) in axis_values.into_iter().enumerate() {
552 let mut consumed_gpu = false;
553 let data = axis_from_value(value, i, &mut consumed_gpu).await?;
554 if consumed_gpu {
555 prefer_gpu = true;
556 }
557 axes.push(data);
558 }
559
560 if !prefer_gpu {
561 if let Some(max_len) = axes.iter().map(|axis| axis.len).max() {
562 if max_len > 0
563 && sequence_gpu_preference(max_len, SequenceIntent::MeshAxis, false).prefer_gpu
564 {
565 prefer_gpu = true;
566 }
567 }
568 }
569
570 let template = if let Some(proto) = like_proto {
571 OutputTemplate::Like(analyse_like_prototype(&proto)?)
572 } else {
573 OutputTemplate::Default
574 };
575
576 Ok(Self {
577 axes,
578 template,
579 prefer_gpu,
580 })
581 }
582}
583
584#[derive(Clone)]
585enum OutputTemplate {
586 Default,
587 Like(PrototypeSpec),
588}
589
590#[derive(Clone)]
591struct PrototypeSpec {
592 residency: DevicePreference,
593 class: PrototypeClass,
594}
595
/// Numeric class of the generated grids.
#[derive(Copy, Clone, Eq, PartialEq)]
enum PrototypeClass {
    /// Real-valued output.
    Real,
    /// Complex-valued output.
    Complex,
}
601
/// Target residency of the generated grids.
#[derive(Copy, Clone)]
enum DevicePreference {
    /// Keep outputs in host memory.
    Host,
    /// Place outputs on the GPU.
    Gpu,
}
607
608fn analyse_like_prototype(proto: &Value) -> crate::BuiltinResult<PrototypeSpec> {
609 match proto {
610 Value::GpuTensor(handle) => {
611 let class = if runmat_accelerate_api::handle_storage(handle)
612 == GpuTensorStorage::ComplexInterleaved
613 {
614 PrototypeClass::Complex
615 } else {
616 PrototypeClass::Real
617 };
618 Ok(PrototypeSpec {
619 residency: DevicePreference::Gpu,
620 class,
621 })
622 }
623 Value::ComplexTensor(_) | Value::Complex(_, _) => Ok(PrototypeSpec {
624 residency: DevicePreference::Host,
625 class: PrototypeClass::Complex,
626 }),
627 Value::Tensor(_)
628 | Value::SparseTensor(_)
629 | Value::Num(_)
630 | Value::Int(_)
631 | Value::Bool(_)
632 | Value::LogicalArray(_) => Ok(PrototypeSpec {
633 residency: DevicePreference::Host,
634 class: PrototypeClass::Real,
635 }),
636 Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(builtin_error(
637 "meshgrid: prototypes must be numeric or gpuArray values",
638 )),
639 Value::Symbolic(_) => Err(builtin_error(
640 "meshgrid: prototypes must be numeric or gpuArray values",
641 )),
642 Value::Cell(_)
643 | Value::Struct(_)
644 | Value::Object(_)
645 | Value::HandleObject(_)
646 | Value::Listener(_)
647 | Value::FunctionHandle(_)
648 | Value::ExternalFunctionHandle(_)
649 | Value::MethodFunctionHandle(_)
650 | Value::BoundFunctionHandle { .. }
651 | Value::Closure(_)
652 | Value::ClassRef(_)
653 | Value::MException(_)
654 | Value::OutputList(_) => Err(builtin_error("meshgrid: prototypes must be numeric arrays")),
655 }
656}
657
658#[derive(Clone)]
659struct AxisData {
660 values: Vec<(f64, f64)>,
661 len: usize,
662 is_complex: bool,
663 gpu_real: Option<GpuTensorHandle>,
664}
665
666async fn axis_from_value(
667 value: Value,
668 index: usize,
669 prefer_gpu: &mut bool,
670) -> crate::BuiltinResult<AxisData> {
671 match value {
672 Value::Tensor(tensor) => axis_from_tensor(tensor, index),
673 Value::LogicalArray(logical) => {
674 let tensor = tensor::logical_to_tensor(&logical)?;
675 axis_from_tensor(tensor, index)
676 }
677 Value::Num(n) => Ok(AxisData {
678 values: vec![(n, 0.0)],
679 len: 1,
680 is_complex: false,
681 gpu_real: None,
682 }),
683 Value::Int(i) => {
684 let val = i.to_f64();
685 Ok(AxisData {
686 values: vec![(val, 0.0)],
687 len: 1,
688 is_complex: false,
689 gpu_real: None,
690 })
691 }
692 Value::Bool(b) => Ok(AxisData {
693 values: vec![(if b { 1.0 } else { 0.0 }, 0.0)],
694 len: 1,
695 is_complex: false,
696 gpu_real: None,
697 }),
698 Value::Complex(re, im) => Ok(AxisData {
699 values: vec![(re, im)],
700 len: 1,
701 is_complex: im != 0.0,
702 gpu_real: None,
703 }),
704 Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, index),
705 Value::GpuTensor(handle) => {
706 let is_complex = runmat_accelerate_api::handle_storage(&handle)
707 == GpuTensorStorage::ComplexInterleaved;
708 if is_vector_shape(&handle.shape) && !is_complex {
711 *prefer_gpu = true;
712 return Ok(AxisData {
713 values: Vec::new(),
714 len: vector_len_from_shape(&handle.shape),
715 is_complex,
716 gpu_real: Some(handle),
717 });
718 }
719
720 *prefer_gpu = true;
722 let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
723 match gathered {
724 Value::Tensor(tensor) => {
725 if is_vector_shape(&tensor.shape) {
726 *prefer_gpu = true;
727 }
728 axis_from_tensor(tensor, index)
729 }
730 Value::ComplexTensor(tensor) => {
731 if is_vector_shape(&tensor.shape) {
732 *prefer_gpu = true;
733 }
734 axis_from_complex_tensor(tensor, index)
735 }
736 other => Err(builtin_error(format!(
737 "meshgrid: input argument {} must be numeric, got {other:?}",
738 index + 1
739 ))),
740 }
741 }
742 other => Err(builtin_error(format!(
743 "meshgrid: input argument {} must be numeric, got {other:?}",
744 index + 1
745 ))),
746 }
747}
748
749fn axis_from_tensor(tensor: Tensor, index: usize) -> crate::BuiltinResult<AxisData> {
750 if is_vector_shape(&tensor.shape) {
751 let mut values = Vec::with_capacity(tensor.data.len());
752 for &v in &tensor.data {
753 values.push((v, 0.0));
754 }
755 return Ok(AxisData {
756 len: values.len(),
757 values,
758 is_complex: false,
759 gpu_real: None,
760 });
761 }
762
763 if let Some(axis) = axis_from_meshgrid_matrix_real(&tensor, index)? {
769 return Ok(axis);
770 }
771
772 Err(builtin_error(format!(
773 "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
774 index + 1,
775 tensor.shape
776 )))
777}
778
779fn axis_from_complex_tensor(tensor: ComplexTensor, index: usize) -> crate::BuiltinResult<AxisData> {
780 if is_vector_shape(&tensor.shape) {
781 let is_complex = tensor
782 .data
783 .iter()
784 .any(|&(_, imag)| !imag.is_nan() && imag != 0.0);
785 return Ok(AxisData {
786 len: tensor.data.len(),
787 values: tensor.data,
788 is_complex,
789 gpu_real: None,
790 });
791 }
792
793 if let Some(axis) = axis_from_meshgrid_matrix_complex(&tensor, index)? {
794 return Ok(axis);
795 }
796
797 Err(builtin_error(format!(
798 "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
799 index + 1,
800 tensor.shape
801 )))
802}
803
804fn axis_from_meshgrid_matrix_real(
805 tensor: &Tensor,
806 index: usize,
807) -> crate::BuiltinResult<Option<AxisData>> {
808 let (rows, cols) = match tensor.shape.as_slice() {
809 [r, c] => (*r, *c),
810 _ => return Ok(None),
811 };
812 if rows <= 1 || cols <= 1 {
813 return Ok(None);
814 }
815
816 let expect_rows_constant = index == 0;
819
820 if expect_rows_constant {
821 if !matrix_rows_are_identical_real(tensor, rows, cols) {
822 return Ok(None);
823 }
824 let mut values = Vec::with_capacity(cols);
826 for col in 0..cols {
827 let idx = rows * col;
828 values.push((tensor.data[idx], 0.0));
829 }
830 return Ok(Some(AxisData {
831 len: values.len(),
832 values,
833 is_complex: false,
834 gpu_real: None,
835 }));
836 }
837
838 if !matrix_cols_are_identical_real(tensor, rows, cols) {
839 return Ok(None);
840 }
841 let mut values = Vec::with_capacity(rows);
843 for row in 0..rows {
844 values.push((tensor.data[row], 0.0));
845 }
846 Ok(Some(AxisData {
847 len: values.len(),
848 values,
849 is_complex: false,
850 gpu_real: None,
851 }))
852}
853
854fn axis_from_meshgrid_matrix_complex(
855 tensor: &ComplexTensor,
856 index: usize,
857) -> crate::BuiltinResult<Option<AxisData>> {
858 let (rows, cols) = match tensor.shape.as_slice() {
859 [r, c] => (*r, *c),
860 _ => return Ok(None),
861 };
862 if rows <= 1 || cols <= 1 {
863 return Ok(None);
864 }
865
866 let expect_rows_constant = index == 0;
867 if expect_rows_constant {
868 if !matrix_rows_are_identical_complex(tensor, rows, cols) {
869 return Ok(None);
870 }
871 let mut values = Vec::with_capacity(cols);
872 for col in 0..cols {
873 let idx = rows * col;
874 values.push(tensor.data[idx]);
875 }
876 let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
877 return Ok(Some(AxisData {
878 len: values.len(),
879 values,
880 is_complex,
881 gpu_real: None,
882 }));
883 }
884
885 if !matrix_cols_are_identical_complex(tensor, rows, cols) {
886 return Ok(None);
887 }
888 let mut values = Vec::with_capacity(rows);
889 for row in 0..rows {
890 values.push(tensor.data[row]);
891 }
892 let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
893 Ok(Some(AxisData {
894 len: values.len(),
895 values,
896 is_complex,
897 gpu_real: None,
898 }))
899}
900
901fn matrix_rows_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
902 for row in 1..rows {
903 for col in 0..cols {
904 let idx0 = rows * col;
905 let idx = row + rows * col;
906 if tensor.data[idx] != tensor.data[idx0] {
907 return false;
908 }
909 }
910 }
911 true
912}
913
914fn matrix_cols_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
915 for col in 1..cols {
916 for row in 0..rows {
917 let idx0 = row;
918 let idx = row + rows * col;
919 if tensor.data[idx] != tensor.data[idx0] {
920 return false;
921 }
922 }
923 }
924 true
925}
926
927fn matrix_rows_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
928 for row in 1..rows {
929 for col in 0..cols {
930 let idx0 = rows * col;
931 let idx = row + rows * col;
932 if tensor.data[idx] != tensor.data[idx0] {
933 return false;
934 }
935 }
936 }
937 true
938}
939
940fn matrix_cols_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
941 for col in 1..cols {
942 for row in 0..rows {
943 let idx0 = row;
944 let idx = row + rows * col;
945 if tensor.data[idx] != tensor.data[idx0] {
946 return false;
947 }
948 }
949 }
950 true
951}
952
/// Returns `true` when `shape` has at most one dimension larger than one.
///
/// An empty shape counts as a vector, as do shapes whose dimensions are all
/// zero or one.
fn is_vector_shape(shape: &[usize]) -> bool {
    let wide_dims = shape.iter().filter(|&&dim| dim > 1).count();
    wide_dims <= 1
}
965
/// Returns the number of elements described by a vector-like `shape`.
///
/// The length is the product of all dimensions. An empty shape yields `1`
/// (the empty product). For a shape with a single non-unit dimension this is
/// that dimension. A shape containing a zero extent, such as `[0, 5]`, yields
/// `0`: taking the largest dimension instead would report 5 samples for a
/// tensor that holds none.
fn vector_len_from_shape(shape: &[usize]) -> usize {
    shape.iter().product()
}
972
973async fn axis_to_host_async(axis: &AxisData) -> crate::BuiltinResult<AxisData> {
974 if axis.gpu_real.is_none() {
975 return Ok(axis.clone());
976 }
977 let handle = axis.gpu_real.as_ref().expect("checked gpu_real is_some");
978 let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle.clone())).await?;
979 match gathered {
981 Value::Tensor(tensor) => axis_from_tensor(tensor, 0),
982 Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, 0),
983 Value::Num(n) => Ok(AxisData {
984 values: vec![(n, 0.0)],
985 len: 1,
986 is_complex: false,
987 gpu_real: None,
988 }),
989 Value::Complex(re, im) => Ok(AxisData {
990 values: vec![(re, im)],
991 len: 1,
992 is_complex: im != 0.0,
993 gpu_real: None,
994 }),
995 other => Err(builtin_error(format!(
996 "meshgrid: expected numeric GPU axis, got {other:?}"
997 ))),
998 }
999}
1000
1001fn try_meshgrid_gpu_from_vector_axes(
1002 x_axis: &AxisData,
1003 y_axis: &AxisData,
1004 z_axis: Option<&AxisData>,
1005) -> crate::BuiltinResult<Option<Vec<MeshgridOutput>>> {
1006 let Some(x_handle) = x_axis.gpu_real.as_ref() else {
1007 return Ok(None);
1008 };
1009 let Some(y_handle) = y_axis.gpu_real.as_ref() else {
1010 return Ok(None);
1011 };
1012
1013 let z_handle = match z_axis {
1014 Some(axis) => match axis.gpu_real.as_ref() {
1015 Some(h) => Some(h),
1016 None => return Ok(None),
1017 },
1018 None => None,
1019 };
1020
1021 let Some(provider) = runmat_accelerate_api::provider_for_handle(x_handle) else {
1022 return Ok(None);
1023 };
1024 let Some(y_provider) = runmat_accelerate_api::provider_for_handle(y_handle) else {
1025 return Ok(None);
1026 };
1027 if y_provider.device_id() != provider.device_id() {
1028 return Ok(None);
1029 }
1030 if let Some(z) = z_handle {
1031 let Some(z_provider) = runmat_accelerate_api::provider_for_handle(z) else {
1032 return Ok(None);
1033 };
1034 if z_provider.device_id() != provider.device_id() {
1035 return Ok(None);
1036 }
1037 }
1038
1039 let nx = x_axis.len;
1040 let ny = y_axis.len;
1041 let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
1042
1043 let x_row = provider
1045 .reshape(x_handle, &[1, nx])
1046 .map_err(|e| builtin_error(format!("meshgrid: reshape X failed: {e}")))?;
1047 let y_col = provider
1048 .reshape(y_handle, &[ny, 1])
1049 .map_err(|e| builtin_error(format!("meshgrid: reshape Y failed: {e}")))?;
1050
1051 let mut outputs = Vec::with_capacity(if z_handle.is_some() { 3 } else { 2 });
1052 if let Some(z) = z_handle {
1053 let x_base = provider
1054 .reshape(&x_row, &[1, nx, 1])
1055 .map_err(|e| builtin_error(format!("meshgrid: reshape X(3d) failed: {e}")))?;
1056 let y_base = provider
1057 .reshape(&y_col, &[ny, 1, 1])
1058 .map_err(|e| builtin_error(format!("meshgrid: reshape Y(3d) failed: {e}")))?;
1059
1060 let x_grid = provider
1061 .repmat(&x_base, &[ny, 1, nz])
1062 .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
1063 let y_grid = provider
1064 .repmat(&y_base, &[1, nx, nz])
1065 .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
1066
1067 outputs.push(MeshgridOutput::Gpu(x_grid));
1068 outputs.push(MeshgridOutput::Gpu(y_grid));
1069 let z_axis_row = provider
1070 .reshape(z, &[1, nz])
1071 .map_err(|e| builtin_error(format!("meshgrid: reshape Z failed: {e}")))?;
1072 let z_base = provider
1073 .reshape(&z_axis_row, &[1, 1, nz])
1074 .map_err(|e| builtin_error(format!("meshgrid: reshape Z(3d) failed: {e}")))?;
1075 let z_grid = provider
1076 .repmat(&z_base, &[ny, nx, 1])
1077 .map_err(|e| builtin_error(format!("meshgrid: repmat Z failed: {e}")))?;
1078 outputs.push(MeshgridOutput::Gpu(z_grid));
1079 } else {
1080 let x_grid = provider
1081 .repmat(&x_row, &[ny, 1])
1082 .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
1083 let y_grid = provider
1084 .repmat(&y_col, &[1, nx])
1085 .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
1086 outputs.push(MeshgridOutput::Gpu(x_grid));
1087 outputs.push(MeshgridOutput::Gpu(y_grid));
1088 }
1089
1090 Ok(Some(outputs))
1091}
1092
1093fn normalise_axes(axes: &[AxisData]) -> (AxisData, AxisData, Option<AxisData>) {
1094 match axes.len() {
1095 1 => {
1096 let x = axes[0].clone();
1097 (x.clone(), x, None)
1098 }
1099 2 => {
1100 let x = axes[0].clone();
1101 let y = axes[1].clone();
1102 (x, y, None)
1103 }
1104 3 => {
1105 let x = axes[0].clone();
1106 let y = axes[1].clone();
1107 let z = axes[2].clone();
1108 (x, y, Some(z))
1109 }
1110 _ => unreachable!(),
1111 }
1112}
1113
1114fn build_outputs(
1115 x_axis: &AxisData,
1116 y_axis: &AxisData,
1117 z_axis: Option<&AxisData>,
1118) -> Vec<GridOutput> {
1119 let nx = x_axis.len;
1120 let ny = y_axis.len;
1121 let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
1122 let total = nx * ny * nz;
1123 let mut x_data = Vec::with_capacity(total);
1124 let mut y_data = Vec::with_capacity(total);
1125 let mut z_data = z_axis.map(|_| Vec::with_capacity(total));
1126
1127 for k in 0..nz {
1128 let z_value = z_axis.map(|axis| axis.values[k]);
1129 for col in 0..nx {
1130 let x_value = x_axis.values[col];
1131 for row in 0..ny {
1132 x_data.push(x_value);
1133 y_data.push(y_axis.values[row]);
1134 if let Some(ref mut z_vec) = z_data {
1135 z_vec.push(z_value.unwrap());
1136 }
1137 }
1138 }
1139 }
1140
1141 let mut outputs = Vec::new();
1142 let base_shape = if nz == 1 {
1143 vec![ny, nx]
1144 } else {
1145 vec![ny, nx, nz]
1146 };
1147 outputs.push(GridOutput {
1148 shape: base_shape.clone(),
1149 data: x_data,
1150 });
1151 outputs.push(GridOutput {
1152 shape: base_shape.clone(),
1153 data: y_data,
1154 });
1155 if let Some(z_vec) = z_data {
1156 outputs.push(GridOutput {
1157 shape: base_shape,
1158 data: z_vec,
1159 });
1160 }
1161 outputs
1162}
1163
/// A single coordinate grid materialised on the host.
#[derive(Debug, Clone)]
struct GridOutput {
    /// Grid dimensions: `[ny, nx]` for planar grids, `[ny, nx, nz]` for volumes.
    shape: Vec<usize>,
    /// Grid elements as `(re, im)` pairs, stored with the row index varying fastest.
    data: Vec<(f64, f64)>,
}
1168
1169impl GridOutput {
1170 fn to_value(
1171 &self,
1172 class: PrototypeClass,
1173 residency: DevicePreference,
1174 ) -> crate::BuiltinResult<Value> {
1175 match class {
1176 PrototypeClass::Real => self.to_real_value(residency),
1177 PrototypeClass::Complex => self.to_complex_value(residency),
1178 }
1179 }
1180
1181 fn to_real_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1182 let mut real = Vec::with_capacity(self.data.len());
1183 for &(re, im) in &self.data {
1184 if im != 0.0 {
1185 return Err(builtin_error(
1186 "meshgrid: cannot represent complex values in a real output",
1187 ));
1188 }
1189 real.push(re);
1190 }
1191 let tensor = Tensor::new(real, self.shape.clone())
1192 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1193 match residency {
1194 DevicePreference::Host => Ok(tensor::tensor_into_value(tensor)),
1195 DevicePreference::Gpu => to_gpu_tensor_value(tensor),
1196 }
1197 }
1198
1199 fn to_complex_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
1200 let tensor = ComplexTensor::new(self.data.clone(), self.shape.clone())
1201 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
1202 match residency {
1203 DevicePreference::Host => Ok(complex_tensor_into_value(tensor)),
1204 DevicePreference::Gpu => to_complex_gpu_tensor_value(tensor),
1205 }
1206 }
1207}
1208
1209fn to_gpu_tensor_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1210 if let Some(provider) = runmat_accelerate_api::provider() {
1211 let view = HostTensorView {
1212 data: &tensor.data,
1213 shape: &tensor.shape,
1214 };
1215 match provider.upload(&view) {
1216 Ok(handle) => return Ok(Value::GpuTensor(handle)),
1217 Err(err) => {
1218 warn!("meshgrid: failed to upload tensor to GPU, returning host array: {err}")
1219 }
1220 }
1221 }
1222 Ok(tensor::tensor_into_value(tensor))
1223}
1224
1225fn to_complex_gpu_tensor_value(tensor: ComplexTensor) -> crate::BuiltinResult<Value> {
1226 if let Some(provider) = runmat_accelerate_api::provider() {
1227 match gpu_helpers::upload_complex_tensor(provider, &tensor) {
1228 Ok(handle) => return Ok(gpu_helpers::complex_gpu_value(handle)),
1229 Err(err) => {
1230 warn!(
1231 "meshgrid: failed to upload complex tensor to GPU, returning host array: {err}"
1232 )
1233 }
1234 }
1235 }
1236 Ok(complex_tensor_into_value(tensor))
1237}
1238
1239fn tensor_to_complex_tensor(tensor: Tensor) -> crate::BuiltinResult<ComplexTensor> {
1240 let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
1241 ComplexTensor::new(data, tensor.shape.clone())
1242 .map_err(|e| builtin_error(format!("meshgrid: {e}")))
1243}
1244
1245fn tensor_to_complex_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
1246 let complex = tensor_to_complex_tensor(tensor)?;
1247 Ok(complex_tensor_into_value(complex))
1248}
1249
1250enum MeshgridOutput {
1251 Host(GridOutput),
1252 Gpu(GpuTensorHandle),
1253}
1254
1255impl MeshgridOutput {
1256 async fn to_value(
1257 &self,
1258 class: PrototypeClass,
1259 residency: DevicePreference,
1260 ) -> crate::BuiltinResult<Value> {
1261 match self {
1262 MeshgridOutput::Host(host) => host.to_value(class, residency),
1263 MeshgridOutput::Gpu(handle) => match (class, residency) {
1264 (PrototypeClass::Real, DevicePreference::Gpu) => {
1265 Ok(Value::GpuTensor(handle.clone()))
1266 }
1267 (PrototypeClass::Real, DevicePreference::Host) => {
1268 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1269 Ok(tensor::tensor_into_value(tensor))
1270 }
1271 (PrototypeClass::Complex, DevicePreference::Host) => {
1272 match gpu_helpers::gather_value_async(&Value::GpuTensor(handle.clone())).await?
1273 {
1274 Value::ComplexTensor(tensor) => Ok(complex_tensor_into_value(tensor)),
1275 Value::Complex(re, im) => Ok(Value::Complex(re, im)),
1276 Value::Tensor(tensor) => tensor_to_complex_value(tensor),
1277 Value::Num(n) => Ok(Value::Complex(n, 0.0)),
1278 other => Err(builtin_error(format!(
1279 "meshgrid: expected numeric GPU output, got {other:?}"
1280 ))),
1281 }
1282 }
1283 (PrototypeClass::Complex, DevicePreference::Gpu) => {
1284 if runmat_accelerate_api::handle_storage(handle)
1285 == GpuTensorStorage::ComplexInterleaved
1286 {
1287 Ok(gpu_helpers::complex_gpu_value(handle.clone()))
1288 } else {
1289 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
1290 to_complex_gpu_tensor_value(tensor_to_complex_tensor(tensor)?)
1291 }
1292 }
1293 },
1294 }
1295 }
1296}
1297
1298pub struct MeshgridEval {
1301 outputs: Vec<MeshgridOutput>,
1302 target_class: PrototypeClass,
1303 target_residency: DevicePreference,
1304}
1305
1306impl MeshgridEval {
1307 pub fn output_count(&self) -> usize {
1308 self.outputs.len()
1309 }
1310
1311 pub async fn first(&self) -> crate::BuiltinResult<Value> {
1312 self.outputs[0]
1313 .to_value(self.target_class, self.target_residency)
1314 .await
1315 }
1316
1317 pub async fn second(&self) -> crate::BuiltinResult<Value> {
1318 if self.outputs.len() < 2 {
1319 Err(builtin_error("meshgrid: second output unavailable"))
1320 } else {
1321 self.outputs[1]
1322 .to_value(self.target_class, self.target_residency)
1323 .await
1324 }
1325 }
1326
1327 pub async fn third(&self) -> crate::BuiltinResult<Value> {
1328 if self.outputs.len() < 3 {
1329 Err(builtin_error(
1330 "meshgrid: third output requested but no Z vector was supplied",
1331 ))
1332 } else {
1333 self.outputs[2]
1334 .to_value(self.target_class, self.target_residency)
1335 .await
1336 }
1337 }
1338}
1339
1340#[cfg(test)]
1341pub(crate) mod tests {
1342 use super::*;
1343 use crate::builtins::common::test_support;
1344 use futures::executor::block_on;
1345 #[cfg(feature = "wgpu")]
1346 use runmat_accelerate_api::AccelProvider;
1347
1348 use runmat_accelerate_api::HostTensorView;
1349
1350 fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
1351 block_on(super::evaluate(args))
1352 }
1353
1354 fn eval_first(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
1355 block_on(eval.first())
1356 }
1357
1358 fn eval_second(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
1359 block_on(eval.second())
1360 }
1361
1362 fn eval_third(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
1363 block_on(eval.third())
1364 }
1365
1366 fn tensor_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Tensor {
1367 Tensor::new(data, vec![rows, cols]).unwrap()
1368 }
1369
/// A single axis is used for both X and Y, giving a square grid.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn meshgrid_single_input_duplicates_axis() {
    let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
    let eval = evaluate(&[Value::Tensor(x)]).expect("meshgrid");
    assert_eq!(eval.output_count(), 2);

    let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
    assert_eq!(x_out.shape, vec![3, 3]);
    assert_eq!(
        x_out.data,
        vec![-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
    );

    let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
    assert_eq!(y_out.shape, vec![3, 3]);
    assert_eq!(
        y_out.data,
        vec![-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0]
    );
}
1389
/// Scalar axes resolve to an all-ones shape whose rank follows the axis count.
#[test]
fn meshgrid_type_infers_rank_from_axis_count() {
    let ctx = ResolveContext::new(Vec::new());
    assert_eq!(
        meshgrid_type(&[Type::Num, Type::Num], &ctx),
        Type::Tensor {
            shape: Some(vec![Some(1), Some(1)])
        }
    );
    assert_eq!(
        meshgrid_type(&[Type::Num, Type::Num, Type::Num], &ctx),
        Type::Tensor {
            shape: Some(vec![Some(1), Some(1), Some(1)])
        }
    );
}
1406
/// Known vector lengths resolve to a `[len(y), len(x)]` output shape.
#[test]
fn meshgrid_type_uses_vector_lengths() {
    let ctx = ResolveContext::new(Vec::new());
    let x_type = Type::Tensor {
        shape: Some(vec![Some(1), Some(201)]),
    };
    let y_type = Type::Tensor {
        shape: Some(vec![Some(1), Some(101)]),
    };
    assert_eq!(
        meshgrid_type(&[x_type, y_type], &ctx),
        Type::Tensor {
            shape: Some(vec![Some(101), Some(201)])
        }
    );
}
1427
/// Axes of different lengths give `len(y) x len(x)` grids.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn meshgrid_rectangular_inputs() {
    let x = tensor_from_vec(vec![0.0, 0.5, 1.0], 1, 3);
    let y = tensor_from_vec(vec![10.0, 20.0], 2, 1);
    let eval = evaluate(&[Value::Tensor(x), Value::Tensor(y)]).expect("meshgrid");
    assert_eq!(eval.output_count(), 2);

    let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
    assert_eq!(x_out.shape, vec![2, 3]);
    assert_eq!(x_out.data, vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0]);

    let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
    assert_eq!(y_out.shape, vec![2, 3]);
    assert_eq!(y_out.data, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
}
1442
/// Three axes give three `len(y) x len(x) x len(z)` volumes.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn meshgrid_three_inputs_volume() {
    let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
    let y = tensor_from_vec(vec![5.0, 6.0, 7.0], 3, 1);
    let z = tensor_from_vec(vec![0.0, 1.0], 1, 2);
    let eval =
        evaluate(&[Value::Tensor(x), Value::Tensor(y), Value::Tensor(z)]).expect("meshgrid");
    assert_eq!(eval.output_count(), 3);

    let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
    assert_eq!(x_out.shape, vec![3, 2, 2]);
    assert_eq!(
        x_out.data,
        vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
    );

    // Y varies along rows and repeats for every column and page.
    let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
    assert_eq!(y_out.shape, vec![3, 2, 2]);
    assert_eq!(
        y_out.data,
        vec![5.0, 6.0, 7.0, 5.0, 6.0, 7.0, 5.0, 6.0, 7.0, 5.0, 6.0, 7.0]
    );

    let z_out = test_support::gather(eval_third(&eval).expect("Z")).expect("host");
    assert_eq!(z_out.shape, vec![3, 2, 2]);
    assert_eq!(
        z_out.data,
        vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
    );
}
1465
/// A GPU `'like'` prototype makes host inputs produce GPU-resident grids.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn meshgrid_like_keeps_gpu_residency() {
    test_support::with_test_provider(|provider| {
        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
        let y = tensor_from_vec(vec![2.0, 4.0], 2, 1);
        let proto = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let proto_view = HostTensorView {
            data: &proto.data,
            shape: &proto.shape,
        };
        let proto_handle = provider.upload(&proto_view).expect("upload");

        let eval = evaluate(&[
            Value::Tensor(x),
            Value::Tensor(y),
            Value::from("like"),
            Value::GpuTensor(proto_handle),
        ])
        .expect("meshgrid");

        let x_value = eval_first(&eval).expect("X");
        assert!(matches!(x_value, Value::GpuTensor(_)));
        let gathered = test_support::gather(x_value).expect("gather");
        assert_eq!(gathered.shape, vec![2, 3]);
        assert_eq!(gathered.data, vec![-1.0, -1.0, 0.0, 0.0, 1.0, 1.0]);
    });
}
1491
/// GPU-resident axes give GPU-resident grids whose gathered contents match the
/// expected host layout.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn meshgrid_gpu_inputs_roundtrip() {
    test_support::with_test_provider(|provider| {
        let x = tensor_from_vec(vec![0.0, 0.5], 1, 2);
        let y = tensor_from_vec(vec![1.0, 2.0], 2, 1);
        let x_view = HostTensorView {
            data: &x.data,
            shape: &x.shape,
        };
        let y_view = HostTensorView {
            data: &y.data,
            shape: &y.shape,
        };
        let x_handle = provider.upload(&x_view).expect("upload");
        let y_handle = provider.upload(&y_view).expect("upload");

        let eval = evaluate(&[Value::GpuTensor(x_handle), Value::GpuTensor(y_handle)])
            .expect("meshgrid");

        let x_value = eval_first(&eval).expect("X");
        let y_value = eval_second(&eval).expect("Y");
        assert!(matches!(x_value, Value::GpuTensor(_)));
        assert!(matches!(y_value, Value::GpuTensor(_)));

        // Complete the round trip: bring the grids back and check their contents.
        let x_host = test_support::gather(x_value).expect("gather X");
        assert_eq!(x_host.shape, vec![2, 2]);
        assert_eq!(x_host.data, vec![0.0, 0.0, 0.5, 0.5]);
        let y_host = test_support::gather(y_value).expect("gather Y");
        assert_eq!(y_host.shape, vec![2, 2]);
        assert_eq!(y_host.data, vec![1.0, 2.0, 1.0, 2.0]);
    });
}
1517
/// Grids built from wgpu-resident axes match the host-built grids exactly.
///
/// The test returns early (passing) when the wgpu provider cannot be registered.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "wgpu")]
fn meshgrid_wgpu_matches_cpu() {
    let _guard = test_support::accel_test_lock();
    let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    ) else {
        return;
    };

    let x = tensor_from_vec(vec![-1.0, 0.0, 1.0, 2.0], 1, 4);
    let y = tensor_from_vec(vec![5.0, 6.0], 2, 1);

    let cpu_eval =
        evaluate(&[Value::Tensor(x.clone()), Value::Tensor(y.clone())]).expect("meshgrid cpu");
    let cpu_x =
        test_support::gather(eval_first(&cpu_eval).expect("X cpu")).expect("gather X cpu");
    let cpu_y =
        test_support::gather(eval_second(&cpu_eval).expect("Y cpu")).expect("gather Y cpu");

    let x_view = HostTensorView {
        data: &x.data,
        shape: &x.shape,
    };
    let y_view = HostTensorView {
        data: &y.data,
        shape: &y.shape,
    };
    let x_gpu = provider.upload(&x_view).expect("upload x");
    let y_gpu = provider.upload(&y_view).expect("upload y");

    let gpu_eval =
        evaluate(&[Value::GpuTensor(x_gpu), Value::GpuTensor(y_gpu)]).expect("meshgrid gpu");
    let gpu_x_value = eval_first(&gpu_eval).expect("X gpu");
    let gpu_y_value = eval_second(&gpu_eval).expect("Y gpu");

    assert!(matches!(gpu_x_value, Value::GpuTensor(_)));
    assert!(matches!(gpu_y_value, Value::GpuTensor(_)));

    let gathered_x = test_support::gather(gpu_x_value).expect("gather X gpu");
    let gathered_y = test_support::gather(gpu_y_value).expect("gather Y gpu");

    assert_eq!(gathered_x.shape, cpu_x.shape);
    assert_eq!(gathered_x.data, cpu_x.data);
    assert_eq!(gathered_y.shape, cpu_y.shape);
    assert_eq!(gathered_y.data, cpu_y.data);
}
1566
/// A complex axis gives a complex grid with the imaginary parts preserved.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn meshgrid_complex_inputs_produce_complex_outputs() {
    let complex = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
    let eval = evaluate(&[Value::ComplexTensor(complex)]).expect("meshgrid");
    // A two-element axis always gives a 2x2 grid, so a scalar complex result
    // would be a bug and is rejected.
    match eval_first(&eval).expect("X") {
        Value::ComplexTensor(ct) => {
            assert_eq!(ct.shape, vec![2, 2]);
            assert_eq!(
                ct.data,
                vec![(1.0, 1.0), (1.0, 1.0), (2.0, -1.0), (2.0, -1.0)]
            );
        }
        other => panic!("expected complex output, got {other:?}"),
    }
}
1581
/// A complex GPU `'like'` prototype turns real host input into an interleaved
/// complex GPU grid with zero imaginary parts.
#[test]
fn meshgrid_like_complex_gpu_prototype_keeps_complex_residency() {
    test_support::with_test_provider(|provider| {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let proto = ComplexTensor::new(vec![(0.0, 1.0)], vec![1, 1]).unwrap();
        let proto_handle =
            gpu_helpers::upload_complex_tensor(provider, &proto).expect("upload");

        let eval = evaluate(&[
            Value::Tensor(x),
            Value::from("like"),
            Value::GpuTensor(proto_handle),
        ])
        .expect("meshgrid");

        let x_value = eval_first(&eval).expect("X");
        let Value::GpuTensor(handle) = x_value else {
            panic!("expected complex gpu tensor");
        };
        assert_eq!(
            runmat_accelerate_api::handle_storage(&handle),
            GpuTensorStorage::ComplexInterleaved
        );

        let gathered = block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle)))
            .expect("gather");
        let Value::ComplexTensor(tensor) = gathered else {
            panic!("expected complex tensor");
        };
        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(
            tensor.data,
            vec![(1.0, 0.0), (1.0, 0.0), (2.0, 0.0), (2.0, 0.0)]
        );
    });
}
1616
/// A complex GPU axis gives an interleaved complex GPU grid with its values
/// intact.
#[test]
fn meshgrid_complex_gpu_axis_stays_resident() {
    test_support::with_test_provider(|provider| {
        let axis = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
        let axis_handle = gpu_helpers::upload_complex_tensor(provider, &axis).expect("upload");

        let eval = evaluate(&[Value::GpuTensor(axis_handle)]).expect("meshgrid");
        let x_value = eval_first(&eval).expect("X");
        let Value::GpuTensor(handle) = x_value else {
            panic!("expected complex gpu tensor");
        };
        assert_eq!(
            runmat_accelerate_api::handle_storage(&handle),
            GpuTensorStorage::ComplexInterleaved
        );

        let gathered = block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle)))
            .expect("gather");
        let Value::ComplexTensor(tensor) = gathered else {
            panic!("expected complex tensor");
        };
        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(
            tensor.data,
            vec![(1.0, 1.0), (1.0, 1.0), (2.0, -1.0), (2.0, -1.0)]
        );
    });
}
1644
/// A complex axis on the wgpu provider gives an interleaved complex GPU grid
/// equal to the host-built grid.
///
/// The test returns early (passing) when the wgpu provider cannot be registered.
#[test]
#[cfg(feature = "wgpu")]
fn meshgrid_wgpu_complex_axis_matches_cpu_and_stays_resident() {
    let _guard = test_support::accel_test_lock();
    let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    ) else {
        return;
    };

    let axis = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
    let cpu_eval = evaluate(&[Value::ComplexTensor(axis.clone())]).expect("meshgrid cpu");
    let cpu_x = match eval_first(&cpu_eval).expect("X cpu") {
        Value::ComplexTensor(tensor) => tensor,
        other => panic!("expected cpu complex tensor, got {other:?}"),
    };

    let axis_handle = gpu_helpers::upload_complex_tensor(provider, &axis).expect("upload");
    let gpu_eval = evaluate(&[Value::GpuTensor(axis_handle)]).expect("meshgrid gpu");
    let gpu_x = eval_first(&gpu_eval).expect("X gpu");
    let Value::GpuTensor(handle) = gpu_x else {
        panic!("expected complex gpu tensor");
    };
    assert_eq!(
        runmat_accelerate_api::handle_storage(&handle),
        GpuTensorStorage::ComplexInterleaved
    );

    let gathered =
        block_on(gpu_helpers::gather_value_async(&Value::GpuTensor(handle))).expect("gather");
    let Value::ComplexTensor(gpu_tensor) = gathered else {
        panic!("expected complex tensor");
    };
    assert_eq!(gpu_tensor.shape, cpu_x.shape);
    assert_eq!(gpu_tensor.data, cpu_x.data);
}
1680
/// A host scalar `'like'` prototype keeps the result on the host as a real value.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn meshgrid_like_host_prototype() {
    let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
    let eval =
        evaluate(&[Value::Tensor(x), Value::from("like"), Value::Num(0.0)]).expect("meshgrid");
    let x_out = eval_first(&eval).expect("X");
    assert!(matches!(x_out, Value::Tensor(_) | Value::Num(_)));
}
1690}