use crate::backend::Backend;
use crate::tensor::grid::{GridIndexing, GridOptions, GridSparsity, IndexPos};
use crate::tensor::{BasicOps, Tensor};
use alloc::vec::Vec;
/// Generates coordinate grids from a fixed-size array of 1-D tensors.
///
/// For each input `tensors[i]` of length `n_i`, the result contains one rank-`N`
/// tensor. That tensor holds the values of `tensors[i]` laid out along dimension `i`,
/// with every other dimension having size 1. Then:
///
/// - if the options select [`GridSparsity::Dense`], the tensor is expanded to the full
///   grid shape `[n_0, n_1, ..., n_{N-1}]`;
/// - if the options select [`GridIndexing::Cartesian`] and `N > 1`, dimensions 0 and 1
///   are swapped.
///
/// # Arguments
/// * `tensors` - the 1-D coordinate vectors, one per output dimension.
/// * `options` - anything convertible into [`GridOptions`] (indexing and sparsity).
///
/// # Returns
/// An array of `N` tensors, one per input tensor, each of rank `N`.
pub fn meshgrid<B: Backend, const N: usize, K, O>(
    tensors: &[Tensor<B, 1, K>; N],
    options: O,
) -> [Tensor<B, N, K>; N]
where
    K: BasicOps<B>,
    O: Into<GridOptions>,
{
    let options = options.into();
    // Swapping the first two axes only makes sense with at least two dimensions.
    let swap_dims = options.indexing == GridIndexing::Cartesian && N > 1;
    let dense = options.sparsity == GridSparsity::Dense;

    // `N` is a compile-time constant, so the fixed-size arrays are built directly
    // instead of collecting into a `Vec` and converting back with `try_into().unwrap()`.
    let grid_shape: [usize; N] = core::array::from_fn(|i| tensors[i].dims()[0]);

    core::array::from_fn(|i| {
        // Size 1 in every dimension except `i`, which carries this input's coordinates.
        let mut coord_tensor_shape = [1; N];
        coord_tensor_shape[i] = grid_shape[i];

        let mut tensor = tensors[i].clone().reshape(coord_tensor_shape);
        if dense {
            tensor = tensor.expand(grid_shape);
        }
        if swap_dims {
            tensor = tensor.swap_dims(0, 1);
        }
        tensor
    })
}
/// Builds the meshgrid of `tensors` using [`GridOptions::default()`] and stacks the
/// `D` resulting coordinate tensors into a single tensor of rank `D2`.
///
/// The new stacking axis is axis 0 for `IndexPos::First` and axis `D` (the last axis
/// of the result) for `IndexPos::Last`.
///
/// NOTE(review): stacking presumably requires the default options to produce
/// same-shaped (dense) grids — confirm against `GridOptions::default()`.
///
/// # Panics
/// Panics if `D2 != D + 1`.
pub fn meshgrid_stack<B: Backend, const D: usize, const D2: usize, K>(
    tensors: &[Tensor<B, 1, K>; D],
    index_pos: IndexPos,
) -> Tensor<B, D2, K>
where
    K: BasicOps<B>,
{
    // Const generics cannot express `D2 == D + 1` on stable Rust, so check it at runtime.
    assert_eq!(D2, D + 1, "D2 ({D2}) != D ({D}) + 1");

    let stack_axis = match index_pos {
        IndexPos::First => 0,
        IndexPos::Last => D,
    };

    let coordinate_grids: Vec<Tensor<B, D, K>> =
        meshgrid(tensors, GridOptions::default()).into();

    Tensor::stack(coordinate_grids, stack_axis)
}