#[burn_tensor_testgen::testgen(meshgrid)]
mod tests {
    use super::*;
    use burn_tensor::BasicOps;
    use burn_tensor::backend::Backend;
    use burn_tensor::grid::{
        GridIndexing, GridOptions, GridSparsity, IndexPos, meshgrid, meshgrid_stack,
    };
    use burn_tensor::{Tensor, TensorData};

    /// Asserts that every tensor in `actual` equals the tensor at the same
    /// position in `expected`.
    ///
    /// `N` is both the number of tensors and their rank, matching what
    /// `meshgrid` returns for `N` one-dimensional inputs. Each pair is compared
    /// through `TensorData::assert_eq` with the second argument set to `true`
    /// (NOTE(review): presumably strict, dtype-checking comparison; confirm
    /// against the burn docs).
    fn assert_tensors_equal<const N: usize, B: Backend, K>(
        actual: &[Tensor<B, N, K>; N],
        expected: &[Tensor<B, N, K>; N],
    ) where
        K: BasicOps<B>,
    {
        for (a, e) in actual.iter().zip(expected) {
            a.clone()
                .into_data()
                .assert_eq(&e.clone().into_data(), true);
        }
    }

    #[test]
    fn test_meshgrid() {
        let x = TestTensor::<1>::from([1, 2, 3, 4]);
        let y = TestTensor::<1>::from([5, 6]);
        let z = TestTensor::<1>::from([7, 8]);
        let inputs = [x.clone(), y.clone(), z.clone()];
        let grid_shape = [x.dims()[0], y.dims()[0], z.dims()[0]];

        // Matrix ("ij") layout: the i-th input varies along axis i and has
        // size 1 on every other axis. `-1` lets reshape infer each input's
        // length, so these expectations follow the inputs above instead of
        // repeating their sizes by hand.
        let sparse_ij = [
            x.reshape([-1, 1, 1]),
            y.reshape([1, -1, 1]),
            z.reshape([1, 1, -1]),
        ];
        // Dense grids broadcast every sparse tensor to the full grid shape.
        let dense_ij = sparse_ij.clone().map(|t| t.expand(grid_shape));
        // Cartesian ("xy") layout is the matrix layout with the first two
        // axes swapped.
        let sparse_xy = sparse_ij.clone().map(|t| t.swap_dims(0, 1));
        let dense_xy = dense_ij.clone().map(|t| t.swap_dims(0, 1));

        // Default options, and options that fix only one of the two settings,
        // resolve to the dense / matrix-indexed result.
        assert_tensors_equal(&meshgrid(&inputs, GridOptions::default()), &dense_ij);
        assert_tensors_equal(&meshgrid(&inputs, GridSparsity::Dense), &dense_ij);
        assert_tensors_equal(&meshgrid(&inputs, GridIndexing::Matrix), &dense_ij);

        // Sparse output with matrix indexing, given explicitly or implied.
        assert_tensors_equal(
            &meshgrid(
                &inputs,
                GridOptions {
                    indexing: GridIndexing::Matrix,
                    sparsity: GridSparsity::Sparse,
                },
            ),
            &sparse_ij,
        );
        assert_tensors_equal(&meshgrid(&inputs, GridSparsity::Sparse), &sparse_ij);

        // Cartesian indexing alone keeps the dense default.
        assert_tensors_equal(&meshgrid(&inputs, GridIndexing::Cartesian), &dense_xy);

        // Cartesian + sparse, built both via `GridOptions::new` and via the
        // struct literal so that both construction paths are covered.
        assert_tensors_equal(
            &meshgrid(
                &inputs,
                GridOptions::new(GridIndexing::Cartesian, GridSparsity::Sparse),
            ),
            &sparse_xy,
        );
        assert_tensors_equal(
            &meshgrid(
                &inputs,
                GridOptions {
                    indexing: GridIndexing::Cartesian,
                    sparsity: GridSparsity::Sparse,
                },
            ),
            &sparse_xy,
        );
    }

    #[test]
    fn test_meshgrid_stack() {
        let tensors = [
            TestTensor::from([0.5, 1.0, 2.5]),
            TestTensor::from([0.5, 1.0]),
        ];

        // `IndexPos::First`: the per-input grids are stacked along the leading
        // axis, giving shape [2, 3, 2] here.
        let result: Tensor<_, 3> = meshgrid_stack(&tensors, IndexPos::First);
        result.to_data().assert_eq(
            &TensorData::from([
                [[0.5, 0.5], [1.0, 1.0], [2.5, 2.5]],
                [[0.5, 1.0], [0.5, 1.0], [0.5, 1.0]],
            ]),
            false,
        );

        // `IndexPos::Last`: the grids are stacked along the trailing axis, so
        // each grid point holds its coordinate pair; shape [3, 2, 2] here.
        let result: Tensor<_, 3> = meshgrid_stack(&tensors, IndexPos::Last);
        result.to_data().assert_eq(
            &TensorData::from([
                [[0.5, 0.5], [0.5, 1.0]],
                [[1.0, 0.5], [1.0, 1.0]],
                [[2.5, 0.5], [2.5, 1.0]],
            ]),
            false,
        );
    }
}