use crate::{Int, Shape, Tensor, backend::Backend};
use alloc::vec::Vec;
/// Builds a tensor of integer coordinates covering every position of `shape`.
///
/// The result has one more dimension than `shape`. For an input shape
/// `[d_0, ..., d_{D-1}]` the output shape is `[d_0, ..., d_{D-1}, D]`. The
/// entry at `[i_0, ..., i_{D-1}, k]` equals `i_k`, i.e. the last axis holds the
/// coordinate tuple of the position it sits at.
///
/// # Panics
///
/// - If `D2 != D + 1`.
/// - If the number of dimensions in `shape` is not `D`.
pub fn cartesian_grid<B: Backend, S: Into<Shape>, const D: usize, const D2: usize>(
    shape: S,
    device: &B::Device,
) -> Tensor<B, D2, Int> {
    if D2 != D + 1 {
        panic!("D2 must equal D + 1 for Tensor::cartesian_grid")
    }

    let dims = shape.into();
    // Validate up front. Otherwise a shorter shape panics on `dims[dim]` and a
    // longer one panics inside `repeat_dim`, both with unhelpful messages.
    assert_eq!(
        dims.num_dims(),
        D,
        "shape rank must equal D for Tensor::cartesian_grid"
    );

    let mut indices: Vec<Tensor<B, D, Int>> = Vec::with_capacity(D);
    for dim in 0..D {
        let size = dims[dim];
        // 1-D ramp 0, 1, ..., size - 1: the coordinate values along `dim`.
        let ramp: Tensor<B, 1, Int> = Tensor::arange(0..size as i64, device);

        // Put the ramp on axis `dim`, with singleton axes everywhere else...
        let mut axis_shape = [1; D];
        axis_shape[dim] = size;
        let mut coord = ramp.reshape(axis_shape);

        // ...then tile it across every other axis to reach the full grid shape.
        for (i, &item) in dims.iter().enumerate() {
            if i != dim {
                coord = coord.repeat_dim(i, item);
            }
        }
        indices.push(coord);
    }

    // Stack the D per-axis coordinate tensors along a new trailing axis `D`.
    Tensor::stack::<D2>(indices, D)
}