// burn_tensor/tensor/grid/meshgrid.rs

1use crate::backend::Backend;
2use crate::tensor::grid::{GridIndexing, GridOptions, GridSparsity, IndexPos};
3use crate::tensor::{BasicOps, Tensor};
4use alloc::vec::Vec;
5
6/// Return a collection of coordinate matrices for coordinate vectors.
7///
8/// Takes N 1D tensors and returns N tensors where each tensor represents the coordinates
9/// in one dimension across an N-dimensional grid.
10///
11/// Based upon `options.sparse`, the generated coordinate tensors can either be `Sparse` or `Dense`:
12/// * In `Sparse` mode, output tensors will have shape 1 everywhere except their cardinal dimension.
13/// * In `Dense` mode, output tensors will be expanded to the full grid shape.
14///
15/// Based upon `options.indexing`, the generated coordinate tensors will use either:
16/// * `Matrix` indexing, where dimensions are in the same order as their cardinality.
17/// * `Cartesian` indexing; where the first two dimensions are swapped.
18///
19/// See:
20///  - [numpy.meshgrid](https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html)
21///  - [torch.meshgrid](https://pytorch.org/docs/stable/generated/torch.meshgrid.html)
22///
23/// # Arguments
24///
25/// * `tensors` - A slice of 1D tensors
26/// * `options` - the options.
27///
28/// # Returns
29///
30/// A vector of N N-dimensional tensors representing the grid coordinates.
31pub fn meshgrid<B: Backend, const N: usize, K, O>(
32    tensors: &[Tensor<B, 1, K>; N],
33    options: O,
34) -> [Tensor<B, N, K>; N]
35where
36    K: BasicOps<B>,
37    O: Into<GridOptions>,
38{
39    let options = options.into();
40    let swap_dims = options.indexing == GridIndexing::Cartesian && N > 1;
41    let dense = options.sparsity == GridSparsity::Dense;
42
43    let grid_shape: [usize; N] = tensors
44        .iter()
45        .map(|t| t.dims()[0])
46        .collect::<Vec<_>>()
47        .try_into()
48        .unwrap();
49
50    tensors
51        .iter()
52        .enumerate()
53        .map(|(i, tensor)| {
54            let mut coord_tensor_shape = [1; N];
55            coord_tensor_shape[i] = grid_shape[i];
56
57            // Reshape the tensor to have singleton dimensions in all but the i-th dimension
58            let mut tensor = tensor.clone().reshape(coord_tensor_shape);
59
60            if dense {
61                tensor = tensor.expand(grid_shape);
62            }
63            if swap_dims {
64                tensor = tensor.swap_dims(0, 1);
65            }
66
67            tensor
68        })
69        .collect::<Vec<_>>()
70        .try_into()
71        .unwrap()
72}
73
74/// Return a coordinate matrix for a given set of 1D coordinate tensors.
75///
76/// Equivalent to stacking a dense matrix `meshgrid`,
77/// where the stack is along the first or last dimension.
78///
79/// # Arguments
80///
81/// * `tensors`: A slice of 1D tensors.
82/// * `index_pos`: The position of the index in the output tensor.
83///
84/// # Returns
85///
86/// A tensor of either ``(N, ..., |T[i]|, ...)`` or ``(..., |T[i]|, ..., N)``,
87/// of coordinates, indexed on the first or last dimension.
88pub fn meshgrid_stack<B: Backend, const D: usize, const D2: usize, K>(
89    tensors: &[Tensor<B, 1, K>; D],
90    index_pos: IndexPos,
91) -> Tensor<B, D2, K>
92where
93    K: BasicOps<B>,
94{
95    assert_eq!(D2, D + 1, "D2 ({D2}) != D ({D}) + 1");
96
97    let xs: Vec<Tensor<B, D, K>> = meshgrid(tensors, GridOptions::default())
98        .into_iter()
99        .collect();
100
101    let dim = match index_pos {
102        IndexPos::First => 0,
103        IndexPos::Last => D,
104    };
105
106    Tensor::stack(xs, dim)
107}