burn_tensor/tensor/grid/meshgrid.rs
1use crate::backend::Backend;
2use crate::tensor::grid::{GridIndexing, GridOptions, GridSparsity, IndexPos};
3use crate::tensor::{BasicOps, Tensor};
4use alloc::vec::Vec;
5
6/// Return a collection of coordinate matrices for coordinate vectors.
7///
8/// Takes N 1D tensors and returns N tensors where each tensor represents the coordinates
9/// in one dimension across an N-dimensional grid.
10///
11/// Based upon `options.sparse`, the generated coordinate tensors can either be `Sparse` or `Dense`:
12/// * In `Sparse` mode, output tensors will have shape 1 everywhere except their cardinal dimension.
13/// * In `Dense` mode, output tensors will be expanded to the full grid shape.
14///
15/// Based upon `options.indexing`, the generated coordinate tensors will use either:
16/// * `Matrix` indexing, where dimensions are in the same order as their cardinality.
17/// * `Cartesian` indexing; where the first two dimensions are swapped.
18///
19/// See:
20/// - [numpy.meshgrid](https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html)
21/// - [torch.meshgrid](https://pytorch.org/docs/stable/generated/torch.meshgrid.html)
22///
23/// # Arguments
24///
25/// * `tensors` - A slice of 1D tensors
26/// * `options` - the options.
27///
28/// # Returns
29///
30/// A vector of N N-dimensional tensors representing the grid coordinates.
31pub fn meshgrid<B: Backend, const N: usize, K, O>(
32 tensors: &[Tensor<B, 1, K>; N],
33 options: O,
34) -> [Tensor<B, N, K>; N]
35where
36 K: BasicOps<B>,
37 O: Into<GridOptions>,
38{
39 let options = options.into();
40 let swap_dims = options.indexing == GridIndexing::Cartesian && N > 1;
41 let dense = options.sparsity == GridSparsity::Dense;
42
43 let grid_shape: [usize; N] = tensors
44 .iter()
45 .map(|t| t.dims()[0])
46 .collect::<Vec<_>>()
47 .try_into()
48 .unwrap();
49
50 tensors
51 .iter()
52 .enumerate()
53 .map(|(i, tensor)| {
54 let mut coord_tensor_shape = [1; N];
55 coord_tensor_shape[i] = grid_shape[i];
56
57 // Reshape the tensor to have singleton dimensions in all but the i-th dimension
58 let mut tensor = tensor.clone().reshape(coord_tensor_shape);
59
60 if dense {
61 tensor = tensor.expand(grid_shape);
62 }
63 if swap_dims {
64 tensor = tensor.swap_dims(0, 1);
65 }
66
67 tensor
68 })
69 .collect::<Vec<_>>()
70 .try_into()
71 .unwrap()
72}
73
74/// Return a coordinate matrix for a given set of 1D coordinate tensors.
75///
76/// Equivalent to stacking a dense matrix `meshgrid`,
77/// where the stack is along the first or last dimension.
78///
79/// # Arguments
80///
81/// * `tensors`: A slice of 1D tensors.
82/// * `index_pos`: The position of the index in the output tensor.
83///
84/// # Returns
85///
86/// A tensor of either ``(N, ..., |T[i]|, ...)`` or ``(..., |T[i]|, ..., N)``,
87/// of coordinates, indexed on the first or last dimension.
88pub fn meshgrid_stack<B: Backend, const D: usize, const D2: usize, K>(
89 tensors: &[Tensor<B, 1, K>; D],
90 index_pos: IndexPos,
91) -> Tensor<B, D2, K>
92where
93 K: BasicOps<B>,
94{
95 assert_eq!(D2, D + 1, "D2 ({D2}) != D ({D}) + 1");
96
97 let xs: Vec<Tensor<B, D, K>> = meshgrid(tensors, GridOptions::default())
98 .into_iter()
99 .collect();
100
101 let dim = match index_pos {
102 IndexPos::First => 0,
103 IndexPos::Last => D,
104 };
105
106 Tensor::stack(xs, dim)
107}