use burn_core::prelude::*;
/// Container that can be pre-allocated with zeros, gathered along a
/// dimension, and overwritten one entry at a time.
///
/// `B` is the burn [`Backend`] whose device and tensor types the container uses.
pub trait SliceAccess<B>: Clone + Sized
where
    B: Backend,
{
    /// Builds a zero-filled value with room for `capacity` entries on `device`.
    ///
    /// `sample` is a template value (presumably consulted only for the layout
    /// of a single entry — confirm per implementation).
    fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self;

    /// Consumes `self` and returns the entries found at `indices` along `dim`.
    fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self;

    /// Overwrites the entry at `index` with `value`, updating `self` in place.
    fn slice_assign_inplace(&mut self, index: usize, value: Self);
}
impl<B: Backend> SliceAccess<B> for Tensor<B, 2> {
    /// Allocates a `[capacity, feature_dim]` zero tensor on `device`.
    ///
    /// `feature_dim` is the second dimension of `sample`; only its shape is
    /// read, not its values or device.
    fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self {
        let [_, feature_dim] = sample.dims();
        Tensor::zeros([capacity, feature_dim], device)
    }

    /// Gathers slices along `dim` at `indices` by forwarding to the inherent
    /// `Tensor::select`.
    fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {
        // Inherent associated functions take precedence over trait items in
        // path resolution, so this calls burn's `Tensor::select`, not itself.
        Tensor::select(self, dim, indices)
    }

    /// Overwrites row `index` of `self` with `value`.
    ///
    /// `value` is assumed to be a single row (`[1, feature_dim]`) whose width
    /// matches `self`; that shape is not checked here.
    ///
    /// # Panics
    /// Panics if `index` is not less than the number of rows in `self`.
    fn slice_assign_inplace(&mut self, index: usize, value: Self) {
        let rows = self.dims()[0];
        assert!(
            index < rows,
            "slice_assign_inplace: row index {index} out of bounds for tensor with {rows} rows"
        );
        // Pass ranges as an array (one entry per sliced dimension). Only
        // dimension 0 is restricted; the trailing feature dimension is taken in full.
        self.inplace(|t| t.slice_assign([index..index + 1], value));
    }
}