use crate::DType;
use numr::error::Result;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Takes `len` consecutive entries from the last axis of `tensor`, beginning at
/// index `start`, and returns them as a new tensor.
///
/// The narrowed view is passed through `contiguous()` before it is returned.
/// Presumably this yields a densely packed tensor; confirm against the `numr` docs.
///
/// `_client` is not read by this implementation. It is accepted only so the
/// signature matches other client-driven ops.
///
/// # Errors
///
/// Propagates any error from `Tensor::narrow` or `Tensor::contiguous`. For
/// example, `start + len` exceeding the last axis presumably errors here; verify in `numr`.
pub fn slice_last_dim_impl<R, C>(
    _client: &C,
    tensor: &Tensor<R>,
    start: usize,
    len: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>,
{
    // `-1` addresses the last axis.
    let window = tensor.narrow(-1, start, len)?;
    window.contiguous()
}
/// Cuts a rectangular window out of the two trailing axes of `tensor`.
///
/// `start_h`/`len_h` select the range on the second-to-last axis (`-2`).
/// `start_w`/`len_w` select the range on the last axis (`-1`). The resulting
/// view is passed through `contiguous()` before being returned.
///
/// `_client` is not read by this implementation. It is accepted only so the
/// signature matches other client-driven ops.
///
/// # Errors
///
/// Propagates any error from either `Tensor::narrow` call or from
/// `Tensor::contiguous`. For example, an out-of-range window presumably errors
/// here; verify in `numr`.
pub fn slice_last_2d_impl<R, C>(
    _client: &C,
    tensor: &Tensor<R>,
    start_h: usize,
    len_h: usize,
    start_w: usize,
    len_w: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>,
{
    // Narrow axis -2 first, then axis -1, and materialize only once at the end.
    tensor
        .narrow(-2, start_h, len_h)?
        .narrow(-1, start_w, len_w)?
        .contiguous()
}