1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
use crate::{backend::Backend, ops::IntTensor, Int, Shape, Tensor};
use alloc::vec::Vec;
/// Generates a cartesian grid for the given tensor shape on the specified device.
///
/// The generated tensor has dimension `D2 = D + 1`. Its first `D` dimensions match `shape`.
/// Its last dimension has size `D` and holds the grid coordinates of the corresponding
/// element, i.e. `result[i_0, ..., i_{D-1}, k] == i_k`.
///
/// # Arguments
///
/// * `shape` - The shape specifying the dimensions of the grid.
/// * `device` - The device to create the tensor on.
///
/// # Panics
///
/// Panics if `D2` is not equal to `D + 1`.
///
/// # Examples
///
/// ```rust
/// use burn_tensor::Int;
/// use burn_tensor::{backend::Backend, Shape, Tensor};
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
///     println!("{}", result);
/// }
/// ```
pub fn cartesian_grid<B: Backend, S: Into<Shape<D>>, const D: usize, const D2: usize>(
    shape: S,
    device: &B::Device,
) -> IntTensor<B, D2> {
    // `D + 1` cannot be expressed as a const-generic argument on stable Rust,
    // so the relationship between `D` and `D2` is enforced at runtime.
    if D2 != D + 1 {
        panic!("D2 must equal D + 1 for Tensor::cartesian_grid")
    }

    let dims = shape.into().dims;
    // One coordinate tensor per axis, each expanded to the full grid shape.
    let mut indices: Vec<Tensor<B, D, Int>> = Vec::with_capacity(D);
    for dim in 0..D {
        // 0, 1, ..., dims[dim] - 1 laid out along axis `dim` ...
        let dim_range: Tensor<B, 1, Int> = Tensor::arange(0..dims[dim] as i64, device);
        let mut range_shape = [1; D];
        range_shape[dim] = dims[dim];
        let mut dim_range = dim_range.reshape(range_shape);
        // ... then tiled along every other axis so it covers the whole grid.
        for (i, &item) in dims.iter().enumerate() {
            if i != dim {
                dim_range = dim_range.repeat_dim(i, item);
            }
        }
        indices.push(dim_range);
    }
    // Stacking along the new trailing axis `D` gives a tensor of shape `[..dims, D]`.
    Tensor::stack::<D2>(indices, D).into_primitive()
}