1use std::cmp::max;
4
5use log::warn;
6use runmat_accelerate_api::{GpuTensorHandle, HostTensorView};
7use runmat_builtins::{ComplexTensor, ResolveContext, Tensor, Type, Value};
8
9use crate::builtins::array::type_resolvers::size_vector_len;
10use runmat_macros::runtime_builtin;
11
12use crate::build_runtime_error;
13use crate::builtins::common::gpu_helpers;
14use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
15use crate::builtins::common::residency::{sequence_gpu_preference, SequenceIntent};
16use crate::builtins::common::spec::{
17 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
18 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
19};
20use crate::builtins::common::tensor;
21
// GPU capability descriptor for `meshgrid`, registered with the acceleration
// layer. No dedicated provider kernel exists yet; per `notes`, grids are built
// on the host and uploaded when GPU residency is requested (providers may
// still assemble them on-device via generic reshape/repmat, see
// `try_meshgrid_gpu_from_vector_axes`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "meshgrid",
    op_kind: GpuOpKind::Custom("array_construct"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("meshgrid")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers may supply a dedicated meshgrid hook; until then the runtime builds grids on the host and uploads them when GPU residency is requested.",
};
37
/// Build a `RuntimeError` carrying `message` and tagged with the `meshgrid`
/// builtin name, so diagnostics identify their origin.
fn builtin_error(message: impl Into<String>) -> crate::RuntimeError {
    build_runtime_error(message)
        .with_builtin("meshgrid")
        .build()
}
43
// Fusion descriptor: meshgrid materialises dense arrays and participates in
// neither elementwise nor reduction fusion (both hooks are `None`).
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::creation::meshgrid")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "meshgrid",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "Meshgrid explicitly materialises dense coordinate arrays and therefore bypasses fusion.",
};
55
56fn meshgrid_type(args: &[Type], _context: &ResolveContext) -> Type {
57 if args.is_empty() {
58 return Type::Unknown;
59 }
60 let mut axis_count = args.len();
61 if axis_count >= 2 && matches!(args[axis_count - 2], Type::String) {
62 axis_count = axis_count.saturating_sub(2);
63 }
64 if axis_count == 0 {
65 return Type::Unknown;
66 }
67 let axis_args = &args[..axis_count];
68 let len_x = axis_args.get(0).and_then(size_vector_len);
69 let len_y = axis_args.get(1).and_then(size_vector_len).or(len_x);
70 let len_z = axis_args.get(2).and_then(size_vector_len);
71 let shape = if axis_count >= 3 {
72 vec![len_y, len_x, len_z]
73 } else {
74 vec![len_y, len_x]
75 };
76 Type::Tensor { shape: Some(shape) }
77}
78
// Runtime entry point registered via `runtime_builtin`. Dispatches on the
// caller-requested output count (MATLAB nargout): with no output-count
// context the first grid is returned directly; otherwise up to `available`
// grids (2 for planar, 3 for volumetric inputs) are packed in an OutputList.
#[runtime_builtin(
    name = "meshgrid",
    category = "array/creation",
    summary = "Generate coordinate matrices for 2-D and 3-D grids.",
    keywords = "meshgrid,grid,gpu,like,3d",
    accel = "array_construct",
    type_resolver(meshgrid_type),
    builtin_path = "crate::builtins::array::creation::meshgrid"
)]
async fn meshgrid_builtin(rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    let eval = evaluate(&rest).await?;
    if let Some(out_count) = crate::output_count::current_output_count() {
        if out_count == 0 {
            // Zero requested outputs: evaluate for side effects only.
            return Ok(Value::OutputList(Vec::new()));
        }
        let available = eval.output_count();
        if out_count > available {
            // Tailor the error to whether a Z axis was supplied.
            let msg = if available == 2 {
                "meshgrid with two inputs supports at most two outputs"
            } else {
                "meshgrid supports at most three outputs"
            };
            return Err(builtin_error(msg));
        }
        let mut outputs = Vec::with_capacity(out_count);
        let first = eval.first().await?;
        outputs.push(first);
        if out_count >= 2 {
            outputs.push(eval.second().await?);
        }
        if out_count >= 3 {
            outputs.push(eval.third().await?);
        }
        return Ok(Value::OutputList(outputs));
    }
    // No nargout context: behave like a single-output call.
    eval.first().await
}
116
/// Parse `args` and construct the meshgrid outputs.
///
/// Output class (real/complex) and residency (host/GPU) are negotiated from
/// the inputs and any `'like'` prototype; complex axes always force complex
/// outputs. When every axis is a GPU-resident real vector and the negotiated
/// target is the GPU, grids are assembled on-device via provider
/// reshape/repmat; otherwise they are densely materialised on the host.
pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
    let parsed = ParsedMeshgrid::parse(args).await?;
    let (x_axis, y_axis, z_axis) = normalise_axes(&parsed.axes);

    let require_complex = parsed.axes.iter().any(|axis| axis.is_complex);

    // Complex inputs override the prototype class: a real `'like'` template
    // cannot represent complex coordinates.
    let target_class = match &parsed.template {
        OutputTemplate::Default => {
            if require_complex {
                PrototypeClass::Complex
            } else {
                PrototypeClass::Real
            }
        }
        OutputTemplate::Like(spec) => {
            if require_complex {
                PrototypeClass::Complex
            } else {
                spec.class
            }
        }
    };

    // Residency follows the prototype when given, otherwise the parse-time
    // GPU preference (device inputs or large-axis heuristics).
    let target_residency = match &parsed.template {
        OutputTemplate::Default => {
            if parsed.prefer_gpu {
                DevicePreference::Gpu
            } else {
                DevicePreference::Host
            }
        }
        OutputTemplate::Like(spec) => spec.residency,
    };

    let axes_all_real = !require_complex;
    let mut outputs: Vec<MeshgridOutput> = Vec::new();

    // Fast path: keep everything on the device when class and residency allow.
    if axes_all_real
        && matches!(target_class, PrototypeClass::Real)
        && matches!(target_residency, DevicePreference::Gpu)
    {
        if let Some(gpu) = try_meshgrid_gpu_from_vector_axes(&x_axis, &y_axis, z_axis.as_ref())? {
            outputs = gpu;
        }
    }

    // Fallback: gather any device-resident axes and build dense host grids.
    if outputs.is_empty() {
        let x_host = axis_to_host_async(&x_axis).await?;
        let y_host = axis_to_host_async(&y_axis).await?;
        let z_host = match z_axis.as_ref() {
            Some(axis) => Some(axis_to_host_async(axis).await?),
            None => None,
        };
        outputs = build_outputs(&x_host, &y_host, z_host.as_ref())
            .into_iter()
            .map(MeshgridOutput::Host)
            .collect();
    }

    Ok(MeshgridEval {
        outputs,
        target_class,
        target_residency,
    })
}
184
/// Result of argument parsing: the axis vectors, the output template derived
/// from an optional `'like'` prototype, and the GPU residency preference.
#[derive(Clone)]
struct ParsedMeshgrid {
    /// Axis data in call order (one to three entries).
    axes: Vec<AxisData>,
    /// Output class/residency template (`Default` or from a `'like'` prototype).
    template: OutputTemplate,
    /// True when any input was GPU-resident or the residency heuristic
    /// preferred the GPU for the largest axis.
    prefer_gpu: bool,
}
191
impl ParsedMeshgrid {
    /// Split `args` into axis values and an optional trailing
    /// `'like', prototype` pair, convert each axis to `AxisData`, and decide
    /// whether GPU residency should be preferred for the outputs.
    async fn parse(args: &[Value]) -> crate::BuiltinResult<Self> {
        if args.is_empty() {
            return Err(builtin_error(
                "meshgrid: at least one input vector is required",
            ));
        }
        let mut axis_values: Vec<Value> = Vec::new();
        let mut like_proto: Option<Value> = None;
        let mut prefer_gpu = false;
        let mut idx = 0;
        while idx < args.len() {
            let value = args[idx].clone();
            if let Some(keyword) = keyword_of(&value) {
                match keyword.as_str() {
                    "like" => {
                        // `'like'` must be unique, follow at least one axis,
                        // be followed by a prototype, and be terminal.
                        if like_proto.is_some() {
                            return Err(builtin_error(
                                "meshgrid: multiple 'like' specifications are not supported",
                            ));
                        }
                        if axis_values.is_empty() {
                            return Err(builtin_error(
                                "meshgrid: 'like' must follow at least one input vector",
                            ));
                        }
                        let Some(proto) = args.get(idx + 1).cloned() else {
                            return Err(builtin_error("meshgrid: expected prototype after 'like'"));
                        };
                        like_proto = Some(proto);
                        idx += 2;
                        if idx < args.len() {
                            return Err(builtin_error(
                                "meshgrid: 'like' must be the final argument",
                            ));
                        }
                        break;
                    }
                    other => {
                        return Err(builtin_error(format!(
                            "meshgrid: unrecognised option '{other}'"
                        )));
                    }
                }
            }

            // Any device-resident input biases outputs towards the GPU.
            if let Value::GpuTensor(_) = value {
                prefer_gpu = true;
            }
            axis_values.push(value);
            idx += 1;
        }

        if axis_values.is_empty() {
            return Err(builtin_error(
                "meshgrid: at least one input vector is required",
            ));
        }
        if axis_values.len() > 3 {
            return Err(builtin_error(
                "meshgrid: expected at most three input vectors",
            ));
        }

        let mut axes = Vec::with_capacity(max(axis_values.len(), 2));
        for (i, value) in axis_values.into_iter().enumerate() {
            let mut consumed_gpu = false;
            let data = axis_from_value(value, i, &mut consumed_gpu).await?;
            if consumed_gpu {
                prefer_gpu = true;
            }
            axes.push(data);
        }

        // Even for host inputs, large grids may be cheaper to build on the
        // GPU; consult the shared sequence-residency heuristic.
        if !prefer_gpu {
            if let Some(max_len) = axes.iter().map(|axis| axis.len).max() {
                if max_len > 0
                    && sequence_gpu_preference(max_len, SequenceIntent::MeshAxis, false).prefer_gpu
                {
                    prefer_gpu = true;
                }
            }
        }

        let template = if let Some(proto) = like_proto {
            OutputTemplate::Like(analyse_like_prototype(&proto)?)
        } else {
            OutputTemplate::Default
        };

        Ok(Self {
            axes,
            template,
            prefer_gpu,
        })
    }
}
289
/// How outputs should be typed and placed: runtime defaults, or mirroring a
/// `'like'` prototype supplied by the caller.
#[derive(Clone)]
enum OutputTemplate {
    Default,
    Like(PrototypeSpec),
}
295
/// Residency and numeric class extracted from a `'like'` prototype.
#[derive(Clone)]
struct PrototypeSpec {
    residency: DevicePreference,
    class: PrototypeClass,
}
301
/// Numeric class of the produced grids (real f64 vs complex pairs).
#[derive(Clone, Copy, PartialEq, Eq)]
enum PrototypeClass {
    Real,
    Complex,
}
307
/// Where the produced grids should live.
#[derive(Clone, Copy)]
enum DevicePreference {
    Host,
    Gpu,
}
313
/// Derive the output template from a `'like'` prototype value.
///
/// GPU handles request GPU residency with real class; complex host values
/// request complex host outputs; other numeric/logical values map to real
/// host outputs. Textual and container types are rejected with an error.
fn analyse_like_prototype(proto: &Value) -> crate::BuiltinResult<PrototypeSpec> {
    match proto {
        Value::GpuTensor(_) => Ok(PrototypeSpec {
            residency: DevicePreference::Gpu,
            class: PrototypeClass::Real,
        }),
        Value::ComplexTensor(_) | Value::Complex(_, _) => Ok(PrototypeSpec {
            residency: DevicePreference::Host,
            class: PrototypeClass::Complex,
        }),
        Value::Tensor(_)
        | Value::Num(_)
        | Value::Int(_)
        | Value::Bool(_)
        | Value::LogicalArray(_) => Ok(PrototypeSpec {
            residency: DevicePreference::Host,
            class: PrototypeClass::Real,
        }),
        Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => Err(builtin_error(
            "meshgrid: prototypes must be numeric or gpuArray values",
        )),
        Value::Cell(_)
        | Value::Struct(_)
        | Value::Object(_)
        | Value::HandleObject(_)
        | Value::Listener(_)
        | Value::FunctionHandle(_)
        | Value::Closure(_)
        | Value::ClassRef(_)
        | Value::MException(_)
        | Value::OutputList(_) => Err(builtin_error("meshgrid: prototypes must be numeric arrays")),
    }
}
347
/// One parsed input axis.
#[derive(Clone)]
struct AxisData {
    /// Axis samples as `(re, im)` pairs; empty when the axis is still
    /// GPU-resident (see `gpu_real`).
    values: Vec<(f64, f64)>,
    /// Number of samples; valid even when `values` is empty.
    len: usize,
    /// True when any sample has a non-zero imaginary part.
    is_complex: bool,
    /// Device handle when the axis arrived as a GPU-resident real vector.
    gpu_real: Option<GpuTensorHandle>,
}
355
/// Convert a single input argument (1-based position `index`) into `AxisData`.
///
/// Scalars become one-element axes. Tensors and logical arrays must be
/// vectors or recoverable meshgrid matrices (see `axis_from_tensor`).
/// GPU-resident real vectors keep their handle without gathering and set
/// `prefer_gpu`; non-vector GPU tensors are gathered to the host first.
async fn axis_from_value(
    value: Value,
    index: usize,
    prefer_gpu: &mut bool,
) -> crate::BuiltinResult<AxisData> {
    match value {
        Value::Tensor(tensor) => axis_from_tensor(tensor, index),
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(&logical)?;
            axis_from_tensor(tensor, index)
        }
        Value::Num(n) => Ok(AxisData {
            values: vec![(n, 0.0)],
            len: 1,
            is_complex: false,
            gpu_real: None,
        }),
        Value::Int(i) => {
            let val = i.to_f64();
            Ok(AxisData {
                values: vec![(val, 0.0)],
                len: 1,
                is_complex: false,
                gpu_real: None,
            })
        }
        Value::Bool(b) => Ok(AxisData {
            values: vec![(if b { 1.0 } else { 0.0 }, 0.0)],
            len: 1,
            is_complex: false,
            gpu_real: None,
        }),
        Value::Complex(re, im) => Ok(AxisData {
            values: vec![(re, im)],
            len: 1,
            // A zero imaginary part still counts as a real axis.
            is_complex: im != 0.0,
            gpu_real: None,
        }),
        Value::ComplexTensor(tensor) => axis_from_complex_tensor(tensor, index),
        Value::GpuTensor(handle) => {
            // Vector-shaped handles stay on the device: `values` is left
            // empty and only the length is recorded alongside the handle.
            if is_vector_shape(&handle.shape) {
                *prefer_gpu = true;
                return Ok(AxisData {
                    values: Vec::new(),
                    len: vector_len_from_shape(&handle.shape),
                    is_complex: false,
                    gpu_real: Some(handle),
                });
            }

            // Non-vector GPU input: gather and fall back to host parsing.
            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
            if is_vector_shape(&tensor.shape) {
                *prefer_gpu = true;
            }
            axis_from_tensor(tensor, index)
        }
        other => Err(builtin_error(format!(
            "meshgrid: input argument {} must be numeric, got {other:?}",
            index + 1
        ))),
    }
}
421
422fn axis_from_tensor(tensor: Tensor, index: usize) -> crate::BuiltinResult<AxisData> {
423 if is_vector_shape(&tensor.shape) {
424 let mut values = Vec::with_capacity(tensor.data.len());
425 for &v in &tensor.data {
426 values.push((v, 0.0));
427 }
428 return Ok(AxisData {
429 len: values.len(),
430 values,
431 is_complex: false,
432 gpu_real: None,
433 });
434 }
435
436 if let Some(axis) = axis_from_meshgrid_matrix_real(&tensor, index)? {
442 return Ok(axis);
443 }
444
445 Err(builtin_error(format!(
446 "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
447 index + 1,
448 tensor.shape
449 )))
450}
451
/// Convert a complex host tensor into axis data.
///
/// Vectors are taken directly; `is_complex` is set only when some imaginary
/// part is both non-NaN and non-zero. Non-vector inputs are accepted only
/// when they look like meshgrid-produced matrices.
fn axis_from_complex_tensor(tensor: ComplexTensor, index: usize) -> crate::BuiltinResult<AxisData> {
    if is_vector_shape(&tensor.shape) {
        let is_complex = tensor
            .data
            .iter()
            .any(|&(_, imag)| !imag.is_nan() && imag != 0.0);
        return Ok(AxisData {
            len: tensor.data.len(),
            values: tensor.data,
            is_complex,
            gpu_real: None,
        });
    }

    if let Some(axis) = axis_from_meshgrid_matrix_complex(&tensor, index)? {
        return Ok(axis);
    }

    Err(builtin_error(format!(
        "meshgrid: input argument {} must be a vector (1xN or Nx1), got shape {:?}",
        index + 1,
        tensor.shape
    )))
}
476
/// Recover an axis vector from a real matrix that itself looks like a
/// meshgrid output (MATLAB permits passing previous meshgrid results back in).
///
/// Argument 0 (X) must have identical rows and yields its first row; later
/// arguments must have identical columns and yield the first column. Returns
/// `Ok(None)` when the matrix does not match, letting the caller report a
/// shape error. Data is column-major: element (row, col) lives at
/// `data[row + rows * col]`.
fn axis_from_meshgrid_matrix_real(
    tensor: &Tensor,
    index: usize,
) -> crate::BuiltinResult<Option<AxisData>> {
    let (rows, cols) = match tensor.shape.as_slice() {
        [r, c] => (*r, *c),
        _ => return Ok(None),
    };
    if rows <= 1 || cols <= 1 {
        // Degenerate matrices are vectors and handled by the caller.
        return Ok(None);
    }

    // X grids are constant along rows; Y/Z grids along columns.
    let expect_rows_constant = index == 0;

    if expect_rows_constant {
        if !matrix_rows_are_identical_real(tensor, rows, cols) {
            return Ok(None);
        }
        let mut values = Vec::with_capacity(cols);
        for col in 0..cols {
            // Row 0 of each column in column-major storage.
            let idx = rows * col;
            values.push((tensor.data[idx], 0.0));
        }
        return Ok(Some(AxisData {
            len: values.len(),
            values,
            is_complex: false,
            gpu_real: None,
        }));
    }

    if !matrix_cols_are_identical_real(tensor, rows, cols) {
        return Ok(None);
    }
    let mut values = Vec::with_capacity(rows);
    for row in 0..rows {
        // The first column occupies the leading `rows` elements.
        values.push((tensor.data[row], 0.0));
    }
    Ok(Some(AxisData {
        len: values.len(),
        values,
        is_complex: false,
        gpu_real: None,
    }))
}
526
/// Complex counterpart of `axis_from_meshgrid_matrix_real`: recover an axis
/// vector from a complex matrix with constant rows (argument 0) or constant
/// columns (later arguments). Returns `Ok(None)` when the matrix does not
/// match the meshgrid pattern.
fn axis_from_meshgrid_matrix_complex(
    tensor: &ComplexTensor,
    index: usize,
) -> crate::BuiltinResult<Option<AxisData>> {
    let (rows, cols) = match tensor.shape.as_slice() {
        [r, c] => (*r, *c),
        _ => return Ok(None),
    };
    if rows <= 1 || cols <= 1 {
        return Ok(None);
    }

    let expect_rows_constant = index == 0;
    if expect_rows_constant {
        if !matrix_rows_are_identical_complex(tensor, rows, cols) {
            return Ok(None);
        }
        let mut values = Vec::with_capacity(cols);
        for col in 0..cols {
            // Row 0 of each column in column-major storage.
            let idx = rows * col;
            values.push(tensor.data[idx]);
        }
        // NaN imaginary parts do not make the axis complex.
        let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
        return Ok(Some(AxisData {
            len: values.len(),
            values,
            is_complex,
            gpu_real: None,
        }));
    }

    if !matrix_cols_are_identical_complex(tensor, rows, cols) {
        return Ok(None);
    }
    let mut values = Vec::with_capacity(rows);
    for row in 0..rows {
        values.push(tensor.data[row]);
    }
    let is_complex = values.iter().any(|&(_, im)| !im.is_nan() && im != 0.0);
    Ok(Some(AxisData {
        len: values.len(),
        values,
        is_complex,
        gpu_real: None,
    }))
}
573
574fn matrix_rows_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
575 for row in 1..rows {
576 for col in 0..cols {
577 let idx0 = rows * col;
578 let idx = row + rows * col;
579 if tensor.data[idx] != tensor.data[idx0] {
580 return false;
581 }
582 }
583 }
584 true
585}
586
587fn matrix_cols_are_identical_real(tensor: &Tensor, rows: usize, cols: usize) -> bool {
588 for col in 1..cols {
589 for row in 0..rows {
590 let idx0 = row;
591 let idx = row + rows * col;
592 if tensor.data[idx] != tensor.data[idx0] {
593 return false;
594 }
595 }
596 }
597 true
598}
599
600fn matrix_rows_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
601 for row in 1..rows {
602 for col in 0..cols {
603 let idx0 = rows * col;
604 let idx = row + rows * col;
605 if tensor.data[idx] != tensor.data[idx0] {
606 return false;
607 }
608 }
609 }
610 true
611}
612
613fn matrix_cols_are_identical_complex(tensor: &ComplexTensor, rows: usize, cols: usize) -> bool {
614 for col in 1..cols {
615 for row in 0..rows {
616 let idx0 = row;
617 let idx = row + rows * col;
618 if tensor.data[idx] != tensor.data[idx0] {
619 return false;
620 }
621 }
622 }
623 true
624}
625
/// True when `shape` describes a vector (or scalar): an empty shape, or one
/// with at most a single dimension larger than 1 — e.g. `[]`, `[n]`,
/// `[1, n]`, `[n, 1]`, `[1, 1]`.
///
/// The empty-shape case needs no special branch: zero non-singleton
/// dimensions satisfies the `<= 1` test.
fn is_vector_shape(shape: &[usize]) -> bool {
    shape.iter().filter(|&&dim| dim > 1).count() <= 1
}
638
/// Length of the vector described by `shape`: the largest dimension, or 1
/// for an empty (scalar) shape. The `unwrap_or(1)` covers exactly the empty
/// case, replacing the former explicit branch (whose `unwrap_or(0)` arm was
/// unreachable for non-empty shapes).
fn vector_len_from_shape(shape: &[usize]) -> usize {
    shape.iter().copied().max().unwrap_or(1)
}
645
646async fn axis_to_host_async(axis: &AxisData) -> crate::BuiltinResult<AxisData> {
647 if axis.gpu_real.is_none() {
648 return Ok(axis.clone());
649 }
650 let handle = axis.gpu_real.as_ref().expect("checked gpu_real is_some");
651 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
652 axis_from_tensor(tensor, 0)
654}
655
/// Attempt to assemble real meshgrid outputs entirely on the GPU.
///
/// Requires every axis to still hold a device handle whose provider is
/// registered; returns `Ok(None)` to request the host fallback otherwise.
/// Grids are formed by reshaping each axis into a broadcast-ready singleton
/// layout and tiling with `repmat` (MATLAB orientation: X varies along
/// columns, Y along rows, Z along pages).
fn try_meshgrid_gpu_from_vector_axes(
    x_axis: &AxisData,
    y_axis: &AxisData,
    z_axis: Option<&AxisData>,
) -> crate::BuiltinResult<Option<Vec<MeshgridOutput>>> {
    let Some(x_handle) = x_axis.gpu_real.as_ref() else {
        return Ok(None);
    };
    let Some(y_handle) = y_axis.gpu_real.as_ref() else {
        return Ok(None);
    };

    // Z is optional, but when present it must also be device-resident.
    let z_handle = match z_axis {
        Some(axis) => match axis.gpu_real.as_ref() {
            Some(h) => Some(h),
            None => return Ok(None),
        },
        None => None,
    };

    // Every handle must map to a live provider before issuing any ops.
    let Some(provider) = runmat_accelerate_api::provider_for_handle(x_handle) else {
        return Ok(None);
    };
    if runmat_accelerate_api::provider_for_handle(y_handle).is_none() {
        return Ok(None);
    }
    if let Some(z) = z_handle {
        if runmat_accelerate_api::provider_for_handle(z).is_none() {
            return Ok(None);
        }
    }

    let nx = x_axis.len;
    let ny = y_axis.len;
    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);

    // Canonicalise the axes: X as a row, Y as a column.
    let x_row = provider
        .reshape(x_handle, &[1, nx])
        .map_err(|e| builtin_error(format!("meshgrid: reshape X failed: {e}")))?;
    let y_col = provider
        .reshape(y_handle, &[ny, 1])
        .map_err(|e| builtin_error(format!("meshgrid: reshape Y failed: {e}")))?;

    let mut outputs = Vec::with_capacity(if z_handle.is_some() { 3 } else { 2 });
    if let Some(z) = z_handle {
        // 3-D: lift each axis into a singleton 3-D layout, then tile across
        // the other two dimensions.
        let x_base = provider
            .reshape(&x_row, &[1, nx, 1])
            .map_err(|e| builtin_error(format!("meshgrid: reshape X(3d) failed: {e}")))?;
        let y_base = provider
            .reshape(&y_col, &[ny, 1, 1])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Y(3d) failed: {e}")))?;

        let x_grid = provider
            .repmat(&x_base, &[ny, 1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
        let y_grid = provider
            .repmat(&y_base, &[1, nx, nz])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;

        outputs.push(MeshgridOutput::GpuReal(x_grid));
        outputs.push(MeshgridOutput::GpuReal(y_grid));
        let z_axis_row = provider
            .reshape(z, &[1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Z failed: {e}")))?;
        let z_base = provider
            .reshape(&z_axis_row, &[1, 1, nz])
            .map_err(|e| builtin_error(format!("meshgrid: reshape Z(3d) failed: {e}")))?;
        let z_grid = provider
            .repmat(&z_base, &[ny, nx, 1])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Z failed: {e}")))?;
        outputs.push(MeshgridOutput::GpuReal(z_grid));
    } else {
        // 2-D: tile the X row down the rows and the Y column across columns.
        let x_grid = provider
            .repmat(&x_row, &[ny, 1])
            .map_err(|e| builtin_error(format!("meshgrid: repmat X failed: {e}")))?;
        let y_grid = provider
            .repmat(&y_col, &[1, nx])
            .map_err(|e| builtin_error(format!("meshgrid: repmat Y failed: {e}")))?;
        outputs.push(MeshgridOutput::GpuReal(x_grid));
        outputs.push(MeshgridOutput::GpuReal(y_grid));
    }

    Ok(Some(outputs))
}
741
742fn normalise_axes(axes: &[AxisData]) -> (AxisData, AxisData, Option<AxisData>) {
743 match axes.len() {
744 1 => {
745 let x = axes[0].clone();
746 (x.clone(), x, None)
747 }
748 2 => {
749 let x = axes[0].clone();
750 let y = axes[1].clone();
751 (x, y, None)
752 }
753 3 => {
754 let x = axes[0].clone();
755 let y = axes[1].clone();
756 let z = axes[2].clone();
757 (x, y, Some(z))
758 }
759 _ => unreachable!(),
760 }
761}
762
/// Densely materialise host grids from host axis data.
///
/// Emits column-major data for shape `[ny, nx]` (or `[ny, nx, nz]`): the
/// innermost loop walks rows (Y), so X is constant within each column, the Y
/// vector repeats per column, and Z is constant across each page.
fn build_outputs(
    x_axis: &AxisData,
    y_axis: &AxisData,
    z_axis: Option<&AxisData>,
) -> Vec<GridOutput> {
    let nx = x_axis.len;
    let ny = y_axis.len;
    let nz = z_axis.map(|axis| axis.len).unwrap_or(1);
    let total = nx * ny * nz;
    let mut x_data = Vec::with_capacity(total);
    let mut y_data = Vec::with_capacity(total);
    // Only allocated when a Z axis exists.
    let mut z_data = z_axis.map(|_| Vec::with_capacity(total));

    for k in 0..nz {
        let z_value = z_axis.map(|axis| axis.values[k]);
        for col in 0..nx {
            let x_value = x_axis.values[col];
            for row in 0..ny {
                x_data.push(x_value);
                y_data.push(y_axis.values[row]);
                if let Some(ref mut z_vec) = z_data {
                    z_vec.push(z_value.unwrap());
                }
            }
        }
    }

    let mut outputs = Vec::new();
    // 2-D grids drop the trailing singleton page dimension.
    let base_shape = if nz == 1 {
        vec![ny, nx]
    } else {
        vec![ny, nx, nz]
    };
    outputs.push(GridOutput {
        shape: base_shape.clone(),
        data: x_data,
    });
    outputs.push(GridOutput {
        shape: base_shape.clone(),
        data: y_data,
    });
    if let Some(z_vec) = z_data {
        outputs.push(GridOutput {
            shape: base_shape,
            data: z_vec,
        });
    }
    outputs
}
812
/// A host-materialised grid: column-major `(re, im)` samples plus shape.
struct GridOutput {
    shape: Vec<usize>,
    data: Vec<(f64, f64)>,
}
817
impl GridOutput {
    /// Convert the grid into a runtime value with the requested class and
    /// residency.
    fn to_value(
        &self,
        class: PrototypeClass,
        residency: DevicePreference,
    ) -> crate::BuiltinResult<Value> {
        match class {
            PrototypeClass::Real => self.to_real_value(residency),
            PrototypeClass::Complex => self.to_complex_value(residency),
        }
    }

    /// Build a real tensor, rejecting any sample with a non-zero imaginary
    /// part, and upload it when GPU residency was requested.
    fn to_real_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
        let mut real = Vec::with_capacity(self.data.len());
        for &(re, im) in &self.data {
            if im != 0.0 {
                return Err(builtin_error(
                    "meshgrid: cannot represent complex values in a real output",
                ));
            }
            real.push(re);
        }
        let tensor = Tensor::new(real, self.shape.clone())
            .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
        match residency {
            DevicePreference::Host => Ok(tensor::tensor_into_value(tensor)),
            DevicePreference::Gpu => to_gpu_tensor_value(tensor),
        }
    }

    /// Build a complex tensor; complex GPU residency is not implemented, so
    /// a host array is returned (with a warning) in that case.
    fn to_complex_value(&self, residency: DevicePreference) -> crate::BuiltinResult<Value> {
        let tensor = ComplexTensor::new(self.data.clone(), self.shape.clone())
            .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
        match residency {
            DevicePreference::Host => Ok(complex_tensor_into_value(tensor)),
            DevicePreference::Gpu => {
                warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
                Ok(complex_tensor_into_value(tensor))
            }
        }
    }
}
860
/// Upload a host tensor to the active provider, degrading gracefully to a
/// host value (with a warning) when no provider is registered or the upload
/// fails.
fn to_gpu_tensor_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
    if let Some(provider) = runmat_accelerate_api::provider() {
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        match provider.upload(&view) {
            Ok(handle) => return Ok(Value::GpuTensor(handle)),
            Err(err) => {
                // Best-effort: fall through to the host value below.
                warn!("meshgrid: failed to upload tensor to GPU, returning host array: {err}")
            }
        }
    }
    Ok(tensor::tensor_into_value(tensor))
}
876
877fn tensor_to_complex_value(tensor: Tensor) -> crate::BuiltinResult<Value> {
878 let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
879 let complex = ComplexTensor::new(data, tensor.shape.clone())
880 .map_err(|e| builtin_error(format!("meshgrid: {e}")))?;
881 Ok(complex_tensor_into_value(complex))
882}
883
/// A single meshgrid result: host-materialised data or a device handle.
enum MeshgridOutput {
    Host(GridOutput),
    GpuReal(GpuTensorHandle),
}
888
impl MeshgridOutput {
    /// Materialise this output as a `Value` with the requested class and
    /// residency, gathering and/or converting as needed.
    async fn to_value(
        &self,
        class: PrototypeClass,
        residency: DevicePreference,
    ) -> crate::BuiltinResult<Value> {
        match self {
            MeshgridOutput::Host(host) => host.to_value(class, residency),
            MeshgridOutput::GpuReal(handle) => match (class, residency) {
                // Already in the requested form: hand out a handle clone.
                (PrototypeClass::Real, DevicePreference::Gpu) => {
                    Ok(Value::GpuTensor(handle.clone()))
                }
                (PrototypeClass::Real, DevicePreference::Host) => {
                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
                    Ok(tensor::tensor_into_value(tensor))
                }
                (PrototypeClass::Complex, DevicePreference::Host) => {
                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
                    tensor_to_complex_value(tensor)
                }
                // Complex GPU residency is unsupported; fall back to host.
                (PrototypeClass::Complex, DevicePreference::Gpu) => {
                    warn!("meshgrid: complex GPU outputs are not implemented; returning host complex array");
                    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
                    tensor_to_complex_value(tensor)
                }
            },
        }
    }
}
918
/// Lazily-converting result of `evaluate`: each output is materialised to a
/// `Value` on demand using the negotiated class and residency.
pub struct MeshgridEval {
    outputs: Vec<MeshgridOutput>,
    target_class: PrototypeClass,
    target_residency: DevicePreference,
}
926
927impl MeshgridEval {
928 pub fn output_count(&self) -> usize {
929 self.outputs.len()
930 }
931
932 pub async fn first(&self) -> crate::BuiltinResult<Value> {
933 self.outputs[0]
934 .to_value(self.target_class, self.target_residency)
935 .await
936 }
937
938 pub async fn second(&self) -> crate::BuiltinResult<Value> {
939 if self.outputs.len() < 2 {
940 Err(builtin_error("meshgrid: second output unavailable"))
941 } else {
942 self.outputs[1]
943 .to_value(self.target_class, self.target_residency)
944 .await
945 }
946 }
947
948 pub async fn third(&self) -> crate::BuiltinResult<Value> {
949 if self.outputs.len() < 3 {
950 Err(builtin_error(
951 "meshgrid: third output requested but no Z vector was supplied",
952 ))
953 } else {
954 self.outputs[2]
955 .to_value(self.target_class, self.target_residency)
956 .await
957 }
958 }
959}
960
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;

    use runmat_accelerate_api::HostTensorView;

    // Synchronous wrappers around the async evaluation API for test brevity.
    fn evaluate(args: &[Value]) -> crate::BuiltinResult<MeshgridEval> {
        block_on(super::evaluate(args))
    }

    fn eval_first(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.first())
    }

    fn eval_second(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.second())
    }

    fn eval_third(eval: &MeshgridEval) -> crate::BuiltinResult<Value> {
        block_on(eval.third())
    }

    // Convenience constructor for a rows x cols host tensor.
    fn tensor_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Tensor {
        Tensor::new(data, vec![rows, cols]).unwrap()
    }

    // meshgrid(x): the single axis is duplicated, yielding square X/Y grids
    // (column-major expectations below).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_single_input_duplicates_axis() {
        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
        let eval = evaluate(&[Value::Tensor(x)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![3, 3]);
        assert_eq!(
            x_out.data,
            vec![-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
        );
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![3, 3]);
        assert_eq!(
            y_out.data,
            vec![-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0]
        );
    }

    // Scalar axes produce 1x1 (2-D) or 1x1x1 (3-D) static shapes.
    #[test]
    fn meshgrid_type_infers_rank_from_axis_count() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            meshgrid_type(&[Type::Num, Type::Num], &ctx),
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1)])
            }
        );
        assert_eq!(
            meshgrid_type(&[Type::Num, Type::Num, Type::Num], &ctx),
            Type::Tensor {
                shape: Some(vec![Some(1), Some(1), Some(1)])
            }
        );
    }

    // Static shape is [len(y), len(x)] — Y rows, X columns.
    #[test]
    fn meshgrid_type_uses_vector_lengths() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            meshgrid_type(
                &[
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(201)]),
                    },
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(101)]),
                    },
                ],
                &ctx,
            ),
            Type::Tensor {
                shape: Some(vec![Some(101), Some(201)])
            }
        );
    }

    // Distinct X/Y lengths yield a [ny, nx] grid with column-major data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_rectangular_inputs() {
        let x = tensor_from_vec(vec![0.0, 0.5, 1.0], 1, 3);
        let y = tensor_from_vec(vec![10.0, 20.0], 2, 1);
        let eval = evaluate(&[Value::Tensor(x), Value::Tensor(y)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 2);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![2, 3]);
        assert_eq!(x_out.data, vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0]);
        let y_out = test_support::gather(eval_second(&eval).expect("Y")).expect("host");
        assert_eq!(y_out.shape, vec![2, 3]);
        assert_eq!(y_out.data, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
    }

    // 3-D form: X repeats per page, Z is constant within each page.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_three_inputs_volume() {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let y = tensor_from_vec(vec![5.0, 6.0, 7.0], 3, 1);
        let z = tensor_from_vec(vec![0.0, 1.0], 1, 2);
        let eval =
            evaluate(&[Value::Tensor(x), Value::Tensor(y), Value::Tensor(z)]).expect("meshgrid");
        assert_eq!(eval.output_count(), 3);
        let x_out = test_support::gather(eval_first(&eval).expect("X")).expect("host");
        assert_eq!(x_out.shape, vec![3, 2, 2]);
        assert_eq!(
            x_out.data,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
        );
        let z_out = test_support::gather(eval_third(&eval).expect("Z")).expect("host");
        assert_eq!(z_out.shape, vec![3, 2, 2]);
        assert_eq!(
            z_out.data,
            vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
        );
    }

    // A gpuArray 'like' prototype keeps outputs on the device.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_keeps_gpu_residency() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![-1.0, 0.0, 1.0], 1, 3);
            let y = tensor_from_vec(vec![2.0, 4.0], 2, 1);
            let proto = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let proto_view = HostTensorView {
                data: &proto.data,
                shape: &proto.shape,
            };
            let proto_handle = provider.upload(&proto_view).expect("upload");
            let eval = evaluate(&[
                Value::Tensor(x),
                Value::Tensor(y),
                Value::from("like"),
                Value::GpuTensor(proto_handle),
            ])
            .expect("meshgrid");
            let x_value = eval_first(&eval).expect("X");
            assert!(matches!(x_value, Value::GpuTensor(_)));
            let gathered = test_support::gather(x_value).expect("gather");
            assert_eq!(gathered.shape, vec![2, 3]);
        });
    }

    // GPU vector inputs stay GPU-resident through the device fast path.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_gpu_inputs_roundtrip() {
        test_support::with_test_provider(|provider| {
            let x = tensor_from_vec(vec![0.0, 0.5], 1, 2);
            let y = tensor_from_vec(vec![1.0, 2.0], 2, 1);
            let x_view = HostTensorView {
                data: &x.data,
                shape: &x.shape,
            };
            let y_view = HostTensorView {
                data: &y.data,
                shape: &y.shape,
            };
            let x_handle = provider.upload(&x_view).expect("upload");
            let y_handle = provider.upload(&y_view).expect("upload");
            let eval = evaluate(&[Value::GpuTensor(x_handle), Value::GpuTensor(y_handle)])
                .expect("meshgrid");
            assert!(matches!(eval_first(&eval).expect("X"), Value::GpuTensor(_)));
            assert!(matches!(
                eval_second(&eval).expect("Y"),
                Value::GpuTensor(_)
            ));
        });
    }

    // Device-built grids must match the host reference bit-for-bit.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn meshgrid_wgpu_matches_cpu() {
        let provider = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .expect("wgpu provider");

        let x = tensor_from_vec(vec![-1.0, 0.0, 1.0, 2.0], 1, 4);
        let y = tensor_from_vec(vec![5.0, 6.0], 2, 1);

        let cpu_eval =
            evaluate(&[Value::Tensor(x.clone()), Value::Tensor(y.clone())]).expect("meshgrid cpu");
        let cpu_x =
            test_support::gather(eval_first(&cpu_eval).expect("X cpu")).expect("gather X cpu");
        let cpu_y =
            test_support::gather(eval_second(&cpu_eval).expect("Y cpu")).expect("gather Y cpu");

        let x_view = HostTensorView {
            data: &x.data,
            shape: &x.shape,
        };
        let y_view = HostTensorView {
            data: &y.data,
            shape: &y.shape,
        };
        let x_gpu = provider.upload(&x_view).expect("upload x");
        let y_gpu = provider.upload(&y_view).expect("upload y");

        let gpu_eval =
            evaluate(&[Value::GpuTensor(x_gpu), Value::GpuTensor(y_gpu)]).expect("meshgrid gpu");
        let gpu_x_value = eval_first(&gpu_eval).expect("X gpu");
        let gpu_y_value = eval_second(&gpu_eval).expect("Y gpu");

        assert!(matches!(gpu_x_value, Value::GpuTensor(_)));
        assert!(matches!(gpu_y_value, Value::GpuTensor(_)));

        let gathered_x = test_support::gather(gpu_x_value).expect("gather X gpu");
        let gathered_y = test_support::gather(gpu_y_value).expect("gather Y gpu");

        assert_eq!(gathered_x.shape, cpu_x.shape);
        assert_eq!(gathered_x.data, cpu_x.data);
        assert_eq!(gathered_y.shape, cpu_y.shape);
        assert_eq!(gathered_y.data, cpu_y.data);
    }

    // Any complex axis forces complex outputs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_complex_inputs_produce_complex_outputs() {
        let complex = ComplexTensor::new(vec![(1.0, 1.0), (2.0, -1.0)], vec![1, 2]).unwrap();
        let eval = evaluate(&[Value::ComplexTensor(complex)]).expect("meshgrid");
        let x_value = eval_first(&eval).expect("X");
        match x_value {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 2]);
            }
            Value::Complex(_, _) => {}
            other => panic!("expected complex output, got {other:?}"),
        }
    }

    // A host scalar 'like' prototype yields host real outputs.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn meshgrid_like_host_prototype() {
        let x = tensor_from_vec(vec![1.0, 2.0], 1, 2);
        let eval =
            evaluate(&[Value::Tensor(x), Value::from("like"), Value::Num(0.0)]).expect("meshgrid");
        let x_out = eval_first(&eval).expect("X");
        assert!(matches!(x_out, Value::Tensor(_) | Value::Num(_)));
    }
}