1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
use crate::{backend::Backend, ops::IntTensor, Int, Shape, Tensor};
use alloc::vec::Vec;

/// Generates a cartesian grid for the given tensor shape on the specified device.
///
/// The result has one more dimension than `shape` (`D2 = D + 1`): its first `D`
/// dimensions match `shape`, and the trailing dimension has length `D`. Indexing the
/// trailing dimension at `k` yields the position of the element along axis `k`, so
/// every location of the grid holds its own integer coordinates.
///
/// # Arguments
///
/// * `shape` - The shape specifying the dimensions of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The integer tensor primitive of rank `D2` holding the grid coordinates.
///
/// # Panics
///
/// Panics if `D2` is not equal to `D+1`.
///
/// # Examples
///
/// ```rust
///    use burn_tensor::Int;
///    use burn_tensor::{backend::Backend, Shape, Tensor};
///    fn example<B: Backend>() {
///        let device = Default::default();
///        let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
///        println!("{}", result);
///    }
/// ```
pub fn cartesian_grid<B: Backend, S: Into<Shape<D>>, const D: usize, const D2: usize>(
    shape: S,
    device: &B::Device,
) -> IntTensor<B, D2> {
    // `D + 1` cannot be expressed in the signature with stable const generics,
    // so the relation between the two ranks is validated at runtime.
    if D2 != D + 1 {
        panic!("D2 must equal D + 1 for Tensor::cartesian_grid")
    }

    let dims = shape.into().dims;
    // Exactly one coordinate tensor is produced per axis.
    let mut indices: Vec<Tensor<B, D, Int>> = Vec::with_capacity(D);

    for (axis, &len) in dims.iter().enumerate() {
        // 1-D range `0..len` holding the coordinates along the current axis.
        let coords: Tensor<B, 1, Int> = Tensor::arange(0..len as i64, device);

        // Reshape to rank `D` with every axis of size 1 except the current one,
        // so the range lies along `axis`.
        let mut unit_shape = [1; D];
        unit_shape[axis] = len;
        let coords = coords.reshape(unit_shape);

        // Repeat along every other axis until the tensor covers the full `dims` shape.
        let coords = dims
            .iter()
            .enumerate()
            .filter(|&(other, _)| other != axis)
            .fold(coords, |acc, (other, &times)| acc.repeat_dim(other, times));

        indices.push(coords);
    }

    // Stack the per-axis coordinate tensors along a new trailing dimension (index `D`).
    Tensor::stack::<D2>(indices, D).into_primitive()
}