burn_tensor/tensor/api/cartesian_grid.rs
1use crate::{Int, Shape, Tensor, backend::Backend};
2use alloc::vec::Vec;
3
4/// Generates a cartesian grid for the given tensor shape on the specified device.
5/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
6///
7/// # Arguments
8///
9/// * `shape` - The shape specifying the dimensions of the tensor.
10/// * `device` - The device to create the tensor on.
11///
12/// # Panics
13///
14/// Panics if `D2` is not equal to `D+1`.
15///
16/// # Examples
17///
18/// ```rust
19/// use burn_tensor::Int;
20/// use burn_tensor::{backend::Backend, Shape, Tensor};
21/// fn example<B: Backend>() {
22/// let device = Default::default();
23/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
24/// println!("{}", result);
25/// }
26/// ```
27pub fn cartesian_grid<B: Backend, S: Into<Shape>, const D: usize, const D2: usize>(
28 shape: S,
29 device: &B::Device,
30) -> Tensor<B, D2, Int> {
31 if D2 != D + 1 {
32 panic!("D2 must equal D + 1 for Tensor::cartesian_grid")
33 }
34
35 let dims = shape.into().dims;
36 let mut indices: Vec<Tensor<B, D, Int>> = Vec::new();
37
38 for dim in 0..D {
39 let dim_range: Tensor<B, 1, Int> = Tensor::arange(0..dims[dim] as i64, device);
40
41 let mut shape = [1; D];
42 shape[dim] = dims[dim];
43 let mut dim_range = dim_range.reshape(shape);
44
45 for (i, &item) in dims.iter().enumerate() {
46 if i == dim {
47 continue;
48 }
49 dim_range = dim_range.repeat_dim(i, item);
50 }
51
52 indices.push(dim_range);
53 }
54
55 Tensor::stack::<D2>(indices, D)
56}