burn_tensor/tensor/api/
cartesian_grid.rs

1use crate::{Int, Shape, Tensor, backend::Backend};
2use alloc::vec::Vec;
3
4/// Generates a cartesian grid for the given tensor shape on the specified device.
5/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
6///
7/// # Arguments
8///
9/// * `shape` - The shape specifying the dimensions of the tensor.
10/// * `device` - The device to create the tensor on.
11///
12/// # Panics
13///
14/// Panics if `D2` is not equal to `D+1`.
15///
16/// # Examples
17///
18/// ```rust
19///    use burn_tensor::Int;
20///    use burn_tensor::{backend::Backend, Shape, Tensor};
21///    fn example<B: Backend>() {
22///        let device = Default::default();
23///        let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
24///        println!("{}", result);
25///    }
26/// ```
27pub fn cartesian_grid<B: Backend, S: Into<Shape>, const D: usize, const D2: usize>(
28    shape: S,
29    device: &B::Device,
30) -> Tensor<B, D2, Int> {
31    if D2 != D + 1 {
32        panic!("D2 must equal D + 1 for Tensor::cartesian_grid")
33    }
34
35    let dims = shape.into().dims;
36    let mut indices: Vec<Tensor<B, D, Int>> = Vec::new();
37
38    for dim in 0..D {
39        let dim_range: Tensor<B, 1, Int> = Tensor::arange(0..dims[dim] as i64, device);
40
41        let mut shape = [1; D];
42        shape[dim] = dims[dim];
43        let mut dim_range = dim_range.reshape(shape);
44
45        for (i, &item) in dims.iter().enumerate() {
46            if i == dim {
47                continue;
48            }
49            dim_range = dim_range.repeat_dim(i, item);
50        }
51
52        indices.push(dim_range);
53    }
54
55    Tensor::stack::<D2>(indices, D)
56}