use ::burn::tensor::{Tensor, backend::Backend};
/// Alias for a rank-3 `burn` tensor on backend `B`.
// NOTE(review): what each of the three axes represents is not established in this file —
// confirm against the call sites before relying on a particular layout.
pub type Targets<B> = Tensor<B, 3>;
/// Alias for a rank-1 `burn` tensor on backend `B`.
pub type Value<B> = Tensor<B, 1>;
/// Extracts column `index` of a rank-2 tensor as a rank-1 tensor.
///
/// The result has one element per row of `values`, in row order.
///
/// # Panics
///
/// Presumably panics inside `burn`'s `slice` when `index` is not a valid column of
/// `values` — TODO confirm against the `burn` version in use.
pub(crate) fn column<B: Backend>(values: Tensor<B, 2>, index: usize) -> Tensor<B, 1> {
    let [rows, _] = values.dims();
    // Keep every row, but only the single column at `index`: shape `[rows, 1]`.
    let picked: Tensor<B, 2> = values.slice([0..rows, index..index + 1]);
    // Remove the now size-one column axis, leaving shape `[rows]`.
    picked.squeeze_dim(1)
}
/// Turns a rank-1 tensor of length `n` into an `[n, 2]` tensor.
///
/// Both columns of the result are copies of `values`.
// NOTE(review): the column count is fixed at 2 here; why 2 is not established in this file.
pub(crate) fn expand<B: Backend>(values: Tensor<B, 1>) -> Tensor<B, 2> {
    // `[n]` -> `[n, 1]`: append a trailing axis of size one.
    let as_column: Tensor<B, 2> = values.unsqueeze_dim(1);
    // `[n, 1]` -> `[n, 2]`: repeat that single column twice along axis 1.
    as_column.repeat_dim(1, 2)
}