use crate::shapes::{GGML_IDS, MatrixOrdering, ViewShape};
use bytemuck::NoUninit;
use encase::ShaderType;
use nalgebra::{Dim, IsContiguous, Matrix, Storage};
use slang_hal::backend::{Backend, Buffer, DeviceValue, EncaseType, Encoder, ShaderBinding};
use std::ops::{Bound, RangeBounds};
use std::sync::Arc;
use slang_hal::backend::WebGpu;
#[cfg(feature = "cuda")]
use crate::cuda::Cuda;
use slang_hal::shader::ShaderArgsError;
use slang_hal::{BufferUsages, ShaderArgs};
/// Describes the shape, buffer usage and memory ordering of a [`GpuTensor`]
/// before its buffer is allocated on a backend.
///
/// Shapes are always stored with four dimensions; the lower-rank constructors
/// (`scalar`, `vector`, `matrix`) pad the trailing dimensions with `1`.
pub struct TensorBuilder {
    /// Size of each of the four dimensions.
    shape: [u32; 4],
    /// Usage flags passed to the backend when the buffer is allocated.
    usage: BufferUsages,
    /// Ordering recorded on the built tensor (`ColumnMajor` unless overridden).
    ordering: MatrixOrdering,
    /// Optional label.
    ///
    /// NOTE(review): none of the `build_*` methods currently forward this label
    /// to the backend, so setting it has no effect on the allocated buffer.
    label: Option<String>,
}

impl TensorBuilder {
    /// Builder for a single-element tensor (shape `[1, 1, 1, 1]`).
    pub fn scalar(usage: BufferUsages) -> Self {
        Self::tensor([1, 1, 1, 1], usage)
    }

    /// Builder for a vector of `dim` elements (shape `[dim, 1, 1, 1]`).
    pub fn vector(dim: u32, usage: BufferUsages) -> Self {
        Self::tensor([dim, 1, 1, 1], usage)
    }

    /// Builder for an `nrows × ncols` matrix (shape `[nrows, ncols, 1, 1]`).
    pub fn matrix(nrows: u32, ncols: u32, usage: BufferUsages) -> Self {
        Self::tensor([nrows, ncols, 1, 1], usage)
    }

    /// Builder for an arbitrary 4-dimensional tensor, column-major and unlabeled.
    pub fn tensor(shape: [u32; 4], usage: BufferUsages) -> Self {
        Self {
            shape,
            usage,
            ordering: MatrixOrdering::ColumnMajor,
            label: None,
        }
    }

    /// Number of elements described by the shape (product of all four dimensions).
    fn len(&self) -> u64 {
        self.shape.into_iter().map(|s| s as u64).product()
    }

    /// Sets the memory ordering recorded on the built tensor.
    pub fn ordering(mut self, ordering: MatrixOrdering) -> Self {
        self.ordering = ordering;
        self
    }

    /// Sets the builder's label (see the note on the `label` field).
    pub fn label(mut self, label: String) -> Self {
        self.label = Some(label);
        self
    }

    /// Panics unless `provided` elements are enough to fill the described shape.
    fn assert_enough_elements(&self, provided: usize) {
        // The trailing space before `\` is required: a string continuation
        // strips the leading whitespace of the next line.
        assert!(
            provided as u64 >= self.len(),
            "Incorrect number of elements provided for initializing Tensor. \
             Expected at least {}, found {}",
            self.len(),
            provided
        );
    }

    /// Allocates an uninitialized buffer sized for the shape.
    ///
    /// # Errors
    /// Propagates the backend's allocation error.
    pub fn build_uninit<T: DeviceValue + NoUninit, B: Backend>(
        self,
        backend: &B,
    ) -> Result<GpuTensor<T, B>, B::Error> {
        let buffer = backend.uninit_buffer(self.len() as usize, self.usage)?;
        Ok(GpuTensor {
            shape: self.shape,
            buffer,
            ordering: self.ordering,
        })
    }

    /// Like [`Self::build_uninit`], for element types handled through `EncaseType`.
    ///
    /// # Errors
    /// Propagates the backend's allocation error.
    pub fn build_uninit_encased<T: DeviceValue + EncaseType, B: Backend>(
        self,
        backend: &B,
    ) -> Result<GpuTensor<T, B>, B::Error> {
        let buffer = backend.uninit_buffer_encased(self.len() as usize, self.usage)?;
        Ok(GpuTensor {
            shape: self.shape,
            buffer,
            ordering: self.ordering,
        })
    }

    /// Allocates a buffer initialized from `data`.
    ///
    /// The whole of `data` is uploaded, so the buffer may be larger than the
    /// shape (the surplus becomes spare capacity).
    ///
    /// # Panics
    /// Panics if `data` has fewer elements than the shape describes.
    ///
    /// # Errors
    /// Propagates the backend's allocation/upload error.
    pub fn build_init<T: DeviceValue + NoUninit, B: Backend>(
        self,
        backend: &B,
        data: &[T],
    ) -> Result<GpuTensor<T, B>, B::Error> {
        self.assert_enough_elements(data.len());
        let buffer = backend.init_buffer(data, self.usage)?;
        Ok(GpuTensor {
            shape: self.shape,
            buffer,
            ordering: self.ordering,
        })
    }

    /// Like [`Self::build_init`], for element types handled through `EncaseType`.
    ///
    /// # Panics
    /// Panics if `data` has fewer elements than the shape describes.
    ///
    /// # Errors
    /// Propagates the backend's allocation/upload error.
    pub fn build_encased<T: DeviceValue + EncaseType, B: Backend>(
        self,
        backend: &B,
        data: &[T],
    ) -> Result<GpuTensor<T, B>, B::Error> {
        self.assert_enough_elements(data.len());
        let buffer = backend.init_buffer_encased(data, self.usage)?;
        Ok(GpuTensor {
            shape: self.shape,
            buffer,
            ordering: self.ordering,
        })
    }
}
/// Alias of [`GpuTensor`] used to signal that a vector is expected; the shape is
/// not enforced by the type.
pub type GpuVector<T, B> = GpuTensor<T, B>;
/// Alias of [`GpuTensor`] used to signal that a matrix is expected; the shape is
/// not enforced by the type.
pub type GpuMatrix<T, B> = GpuTensor<T, B>;
/// Alias of [`GpuTensor`] used to signal that a single value is expected; the
/// shape is not enforced by the type.
pub type GpuScalar<T, B> = GpuTensor<T, B>;
/// A tensor of up to four dimensions backed by a single backend buffer of `T`.
///
/// The buffer may hold more elements than the shape describes (see `capacity`);
/// the `append*` methods grow it to the next power of two when it is too small.
pub struct GpuTensor<T: DeviceValue, B: Backend> {
    // Size of each of the four dimensions.
    shape: [u32; 4],
    // Backing storage; its length can exceed the product of `shape`.
    buffer: B::Buffer<T>,
    // Memory ordering used to derive contiguous strides for this tensor.
    ordering: MatrixOrdering,
}
/// [`GpuTensor`] on the WebGPU backend.
pub type WgpuTensor<T> = GpuTensor<T, WebGpu>;
/// [`GpuTensor`] on the CUDA backend (only with the `cuda` feature).
#[cfg(feature = "cuda")]
pub type CudaTensor<T> = GpuTensor<T, Cuda>;
impl<T: DeviceValue, B: Backend> GpuTensor<T, B> {
    /// Memory ordering of the tensor's buffer.
    pub fn ordering(&self) -> MatrixOrdering {
        self.ordering
    }

    /// Consumes the tensor and returns it transposed (see [`Self::transpose`]).
    pub fn transposed(self) -> Self {
        let mut result = self;
        result.transpose();
        result
    }

    /// Transposes in place by swapping the first two dimensions and replacing the
    /// ordering with `MatrixOrdering::transpose`; no buffer data is touched.
    pub fn transpose(&mut self) {
        self.ordering = self.ordering.transpose();
        self.shape.swap(0, 1);
    }

    /// `true` when [`Self::len`] is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of logical elements: the product of the four dimensions.
    pub fn len(&self) -> u64 {
        self.shape.iter().map(|&dim| u64::from(dim)).product()
    }

    /// Number of elements the underlying buffer can hold (may exceed `len`).
    pub fn capacity(&self) -> u64
    where
        T: NoUninit,
    {
        let elements = self.buffer.len();
        elements as u64
    }

    /// Encased counterpart of [`Self::capacity`].
    pub fn capacity_encased(&self) -> u64
    where
        T: EncaseType,
    {
        let elements = self.buffer.len_encased();
        elements as u64
    }

    /// Number of dimensions whose size is greater than one.
    pub fn order(&self) -> u8 {
        self.shape.iter().filter(|&&dim| dim > 1).count() as u8
    }

    /// Size of dimension `i`.
    pub fn size(&self, i: usize) -> u32 {
        self.shape[i]
    }

    /// Size of dimension `i`, with `i` remapped through `GGML_IDS`.
    pub fn size_ggml(&self, i: usize) -> u32 {
        let axis = GGML_IDS[i];
        self.size(axis)
    }

    /// Stride of dimension `i` for the contiguous layout implied by the shape and ordering.
    pub fn stride(&self, i: usize) -> u32 {
        let contiguous = ViewShape::contiguous(self.shape, self.ordering);
        contiguous.stride[i]
    }

    /// Stride of dimension `i`, with `i` remapped through `GGML_IDS`.
    pub fn stride_ggml(&self, i: usize) -> u32 {
        let axis = GGML_IDS[i];
        self.stride(axis)
    }

    /// Size in bytes of the logical elements (`len() * size_of::<T>()`).
    pub fn bytes_len(&self) -> u64
    where
        T: DeviceValue,
    {
        let elem_size = std::mem::size_of::<T>() as u64;
        elem_size * self.len()
    }

    /// Records a copy of `self.len()` consecutive elements, starting at the
    /// source view's offset, into the start of this tensor's buffer.
    ///
    /// The source view's shape and strides are not consulted.
    ///
    /// # Errors
    /// Propagates the encoder's error.
    pub fn copy_from_view<'a>(
        &mut self,
        encoder: &mut B::Encoder,
        source: impl Into<GpuTensorView<'a, T, B>>,
    ) -> Result<(), B::Error>
    where
        T: DeviceValue + NoUninit,
    {
        let src: GpuTensorView<'a, T, B> = source.into();
        let count = self.len() as usize;
        let src_offset = src.offset as usize;
        encoder.copy_buffer_to_buffer(src.buffer, src_offset, &mut self.buffer, 0, count)
    }

    /// Encased counterpart of [`Self::copy_from_view`].
    ///
    /// # Errors
    /// Propagates the encoder's error.
    pub fn copy_from_view_encased<'a>(
        &mut self,
        encoder: &mut B::Encoder,
        source: impl Into<GpuTensorView<'a, T, B>>,
    ) -> Result<(), B::Error>
    where
        T: DeviceValue + ShaderType,
    {
        let src: GpuTensorView<'a, T, B> = source.into();
        let count = self.len() as usize;
        let src_offset = src.offset as usize;
        encoder.copy_buffer_to_buffer_encased(src.buffer, src_offset, &mut self.buffer, 0, count)
    }

    /// The tensor's four dimension sizes.
    pub fn shape(&self) -> [u32; 4] {
        self.shape
    }

    /// Shared access to the underlying buffer.
    pub fn buffer(&self) -> &B::Buffer<T> {
        &self.buffer
    }

    /// Exclusive access to the underlying buffer.
    pub fn buffer_mut(&mut self) -> &mut B::Buffer<T> {
        &mut self.buffer
    }

    /// Consumes the tensor, returning its buffer.
    pub fn into_inner(self) -> B::Buffer<T> {
        self.buffer
    }

    /// Immutable contiguous view of the whole tensor, starting at offset 0.
    pub fn as_view(&self) -> GpuTensorView<'_, T, B> {
        let view_shape = ViewShape::contiguous(self.shape, self.ordering);
        GpuTensorView {
            view_shape,
            offset: 0,
            buffer: &self.buffer,
        }
    }

    /// Mutable contiguous view of the whole tensor, starting at offset 0.
    pub fn as_view_mut(&mut self) -> GpuTensorViewMut<'_, T, B> {
        let view_shape = ViewShape::contiguous(self.shape, self.ordering);
        GpuTensorViewMut {
            view_shape,
            offset: 0,
            buffer: &mut self.buffer,
        }
    }

    /// Returns the axis along which this vector tensor extends.
    ///
    /// That is axis 0 for column-major and axis 1 for row-major.
    ///
    /// # Panics
    /// Panics if any other dimension differs from 1.
    fn vector_dim(&self) -> usize {
        let axis = match self.ordering {
            MatrixOrdering::ColumnMajor => 0,
            MatrixOrdering::RowMajor => 1,
        };
        let expected: [u32; 4] =
            std::array::from_fn(|i| if i == axis { self.shape[axis] } else { 1 });
        assert_eq!(
            expected, self.shape,
            "Operation only supported on vector tensors."
        );
        axis
    }
}
impl<'a, T: DeviceValue, B: Backend> From<&'a Arc<GpuTensor<T, B>>> for GpuTensorView<'a, T, B> {
fn from(val: &'a Arc<GpuTensor<T, B>>) -> Self {
val.as_view()
}
}
impl<'a, T: DeviceValue, B: Backend> From<&'a GpuTensor<T, B>> for GpuTensorView<'a, T, B> {
fn from(val: &'a GpuTensor<T, B>) -> Self {
val.as_view()
}
}
impl<'a, T: DeviceValue, B: Backend> From<&'a mut GpuTensor<T, B>> for GpuTensorViewMut<'a, T, B> {
fn from(val: &'a mut GpuTensor<T, B>) -> Self {
val.as_view_mut()
}
}
/// Mutable, possibly strided view over (part of) a tensor's buffer.
pub struct GpuTensorViewMut<'a, T: DeviceValue, B: Backend> {
    // Sizes and strides of the viewed region.
    view_shape: ViewShape,
    // Exclusive borrow of the whole underlying buffer (not only the viewed part).
    buffer: &'a mut B::Buffer<T>,
    // Starting position of the view inside `buffer`; sub-views add
    // `stride * index` to it.
    offset: u32,
}
/// Immutable, possibly strided view over (part of) a tensor's buffer.
pub struct GpuTensorView<'a, T: DeviceValue, B: Backend> {
    // Sizes and strides of the viewed region.
    view_shape: ViewShape,
    // Shared borrow of the whole underlying buffer (not only the viewed part).
    buffer: &'a B::Buffer<T>,
    // Starting position of the view inside `buffer`; sub-views add
    // `stride * index` to it.
    offset: u32,
}
// `Clone`/`Copy` are implemented by hand rather than derived so that copying a
// view does not require `T` or `B` to be `Clone`/`Copy`: the view only holds a
// `ViewShape`, a shared reference and an offset.
impl<'a, T: DeviceValue, B: Backend> Clone for GpuTensorView<'a, T, B> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<'a, T: DeviceValue, B: Backend> Copy for GpuTensorView<'a, T, B> {}
impl<'a, T: DeviceValue, B: Backend> GpuTensorView<'a, T, B> {
    /// Ordering reported by the view's [`ViewShape`], if any.
    pub fn ordering(&self) -> Option<MatrixOrdering> {
        self.view_shape.ordering()
    }

    /// `Some(ordering)` when the view's [`ViewShape`] reports it as contiguous.
    pub fn is_contiguous(&self) -> Option<MatrixOrdering> {
        self.view_shape.is_contiguous()
    }

    /// Returns the contiguity ordering when this view covers the whole buffer.
    ///
    /// Covering the whole buffer means: offset 0, element count equal to the
    /// buffer length, and a contiguous layout. Returns `None` otherwise.
    pub fn is_entire_tensor(&self) -> Option<MatrixOrdering>
    where
        T: NoUninit,
    {
        if self.buffer.len() == self.len() as usize && self.offset == 0 {
            self.is_contiguous()
        } else {
            None
        }
    }

    /// Sizes and strides of this view.
    pub fn shape(&self) -> ViewShape {
        self.view_shape
    }

    /// Slice of the underlying buffer starting at this view's offset.
    pub fn buffer(&self) -> B::BufferSlice<'_, T> {
        self.buffer.slice(self.offset as usize..)
    }

    /// The whole underlying buffer, ignoring this view's offset and shape.
    pub fn raw_buffer(&self) -> &B::Buffer<T> {
        self.buffer
    }

    /// `true` when the view has no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of elements described by the view's shape.
    pub fn len(&self) -> u64 {
        self.view_shape.len()
    }

    /// Size of dimension `i`.
    pub fn size(&self, i: usize) -> u32 {
        self.view_shape.size[i]
    }

    /// Size of dimension `i`, with `i` remapped through `GGML_IDS`.
    pub fn size_ggml(&self, i: usize) -> u32 {
        self.view_shape.size[GGML_IDS[i]]
    }

    /// Stride of dimension `i`.
    pub fn stride(&self, i: usize) -> u32 {
        self.view_shape.stride[i]
    }

    /// Stride of dimension `i`, with `i` remapped through `GGML_IDS`.
    pub fn stride_ggml(&self, i: usize) -> u32 {
        self.view_shape.stride[GGML_IDS[i]]
    }

    /// View with the first two dimensions swapped (no data is moved).
    pub fn transposed(&self) -> Self {
        self.permute([1, 0, 2, 3])
    }

    /// View with dimensions reordered by `permutations` (see `ViewShape::permute`).
    pub fn permute(&self, permutations: [usize; 4]) -> Self {
        Self {
            view_shape: self.view_shape.permute(permutations),
            offset: self.offset,
            buffer: self.buffer,
        }
    }

    /// ggml-indexed counterpart of [`Self::permute`] (see `ViewShape::permute_ggml`).
    pub fn permute_ggml(&self, permutations: [usize; 4]) -> Self {
        Self {
            view_shape: self.view_shape.permute_ggml(permutations),
            offset: self.offset,
            buffer: self.buffer,
        }
    }

    /// Reshapes to `shape` with the contiguous strides of `ordering`.
    ///
    /// `shape` is padded to four dimensions with 1s.
    ///
    /// # Panics
    /// Panics if `DIM2 > 4` or if this view has too few elements (see [`Self::view`]).
    pub fn reshape_with_ordering<const DIM2: usize>(
        &self,
        shape: [u32; DIM2],
        ordering: MatrixOrdering,
    ) -> Self {
        assert!(DIM2 <= 4);
        let mut shape4 = [1; 4];
        shape4[..DIM2].copy_from_slice(&shape);
        let view_shape = ViewShape::contiguous(shape4, ordering);
        self.view(0, shape4, view_shape.stride.map(Some))
    }

    /// Reshapes to `shape`, leaving every stride for `ViewShape::view` to derive.
    pub fn reshape<const DIM2: usize>(&self, shape: [u32; DIM2]) -> Self {
        self.view(0, shape, [None; DIM2])
    }

    /// Reshapes with `shape` given in ggml order.
    ///
    /// The first two entries are swapped first. When this view's first two
    /// dimensions are both 1, a row-major contiguous layout is used; otherwise
    /// strides are derived as in [`Self::reshape`].
    ///
    /// # Panics
    /// Panics if `DIM2 < 2` (the swap needs two entries).
    pub fn reshape_ggml<const DIM2: usize>(&self, mut shape: [u32; DIM2]) -> Self {
        shape.swap(0, 1);
        if self.view_shape.size[0] == 1 && self.view_shape.size[1] == 1 {
            self.reshape_with_ordering(shape, MatrixOrdering::RowMajor)
        } else {
            self.reshape(shape)
        }
    }

    /// ggml-ordered counterpart of [`Self::reshape_with_ordering`].
    ///
    /// # Panics
    /// Panics if `DIM2 < 2`.
    pub fn reshape_ggml_with_ordering<const DIM2: usize>(
        &self,
        mut shape: [u32; DIM2],
        ordering: MatrixOrdering,
    ) -> Self {
        shape.swap(0, 1);
        self.reshape_with_ordering(shape, ordering)
    }

    /// Sub-view of `shape` starting `offset` past this view's own offset.
    ///
    /// `None` strides are resolved by `ViewShape::view`.
    ///
    /// # Panics
    /// Panics if `offset` plus the element count of `shape` exceeds this view's
    /// element count.
    pub fn view<const DIM2: usize>(
        &self,
        mut offset: u32,
        shape: [u32; DIM2],
        stride: [Option<u32>; DIM2],
    ) -> Self {
        let available_elts = self.view_shape.size.iter().product::<u32>();
        let needed_elts = shape.iter().product::<u32>() + offset;
        assert!(
            needed_elts <= available_elts,
            "Source tensor is too small for reshaping. Expected at least {needed_elts} elements (shape: {shape:?}, offset: {offset}), found {available_elts} (shape: {:?})",
            self.view_shape.size
        );
        // The bounds check above is relative to this view; the stored offset is
        // relative to the underlying buffer.
        offset += self.offset;
        GpuTensorView {
            view_shape: self.view_shape.view(shape, stride),
            offset,
            buffer: self.buffer,
        }
    }

    /// ggml-ordered counterpart of [`Self::view`].
    ///
    /// The first two entries of `shape` and `stride` are swapped.
    pub fn view_ggml<const DIM2: usize>(
        &self,
        offset: u32,
        mut shape: [u32; DIM2],
        mut stride: [Option<u32>; DIM2],
    ) -> Self {
        shape.swap(0, 1);
        stride.swap(0, 1);
        self.view(offset, shape, stride)
    }

    /// Selects matrix `matrix_id` along dimension 2 (that dimension becomes 1).
    ///
    /// # Panics
    /// Panics if `matrix_id` is out of range.
    pub fn matrix(&self, matrix_id: u32) -> Self {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        assert!(matrix_id < nmats);
        GpuTensorView {
            view_shape: ViewShape {
                size: [nrows, ncols, 1, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[2] * matrix_id,
            buffer: self.buffer,
        }
    }

    /// Selects `new_ncols` columns starting at `first_col`.
    ///
    /// # Panics
    /// Panics if `first_col + new_ncols` exceeds the number of columns.
    pub fn columns(&self, first_col: u32, new_ncols: u32) -> Self {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        // `<=`: a range ending exactly at `ncols` (e.g. the last column) is valid.
        assert!(
            matches!(first_col.checked_add(new_ncols), Some(end) if end <= ncols),
            "Column range out of bounds."
        );
        GpuTensorView {
            view_shape: ViewShape {
                size: [nrows, new_ncols, nmats, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[1] * first_col,
            buffer: self.buffer,
        }
    }

    /// Selects the single column `col`.
    pub fn column(&self, col: u32) -> Self {
        self.columns(col, 1)
    }

    /// Selects `new_nrows` rows starting at `first_row`.
    ///
    /// # Panics
    /// Panics if `first_row + new_nrows` exceeds the number of rows.
    pub fn rows(&self, first_row: u32, new_nrows: u32) -> Self {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        // `<=`: a range ending exactly at `nrows` (e.g. the last row) is valid.
        assert!(
            matches!(first_row.checked_add(new_nrows), Some(end) if end <= nrows),
            "Row range out of bounds."
        );
        GpuTensorView {
            view_shape: ViewShape {
                size: [new_nrows, ncols, nmats, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[0] * first_row,
            buffer: self.buffer,
        }
    }

    /// Selects the single row `row`.
    pub fn row(&self, row: u32) -> Self {
        self.rows(row, 1)
    }
}
impl<'a, T: DeviceValue, B: Backend> GpuTensorViewMut<'a, T, B> {
    /// Immutable view with the same shape and offset, reborrowing the buffer.
    pub fn as_ref(&self) -> GpuTensorView<'_, T, B> {
        GpuTensorView {
            view_shape: self.view_shape,
            buffer: &*self.buffer,
            offset: self.offset,
        }
    }

    /// See `GpuTensorView::is_contiguous`.
    pub fn is_contiguous(&self) -> Option<MatrixOrdering> {
        self.as_ref().is_contiguous()
    }

    /// See `GpuTensorView::is_entire_tensor`.
    pub fn is_entire_tensor(&self) -> Option<MatrixOrdering>
    where
        T: NoUninit,
    {
        self.as_ref().is_entire_tensor()
    }

    /// Sizes and strides of this view.
    pub fn shape(&self) -> ViewShape {
        self.view_shape
    }

    /// The whole underlying buffer, ignoring this view's offset and shape.
    pub fn raw_buffer(&mut self) -> &mut B::Buffer<T> {
        self.buffer
    }

    /// `true` when the view has no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of elements described by the view's shape.
    pub fn len(&self) -> u64 {
        self.view_shape.len()
    }

    /// View with the first two dimensions swapped (no data is moved).
    pub fn transposed(&mut self) -> GpuTensorViewMut<'_, T, B> {
        self.permute([1, 0, 2, 3])
    }

    /// View with dimensions reordered by `permutations` (see `ViewShape::permute`).
    pub fn permute(&mut self, permutations: [usize; 4]) -> GpuTensorViewMut<'_, T, B> {
        GpuTensorViewMut {
            view_shape: self.view_shape.permute(permutations),
            offset: self.offset,
            buffer: self.buffer,
        }
    }

    /// Reshapes to `shape`, leaving every stride for `ViewShape::view` to derive.
    pub fn reshape<const DIM2: usize>(&mut self, shape: [u32; DIM2]) -> GpuTensorViewMut<'_, T, B> {
        self.view(0, shape, [None; DIM2])
    }

    /// Sub-view of `shape` starting `offset` past this view's own offset.
    ///
    /// `None` strides are resolved by `ViewShape::view`.
    ///
    /// # Panics
    /// Panics if `offset` plus the element count of `shape` exceeds this view's
    /// element count.
    pub fn view<const DIM2: usize>(
        &mut self,
        offset: u32,
        shape: [u32; DIM2],
        stride: [Option<u32>; DIM2],
    ) -> GpuTensorViewMut<'_, T, B> {
        // The bounds check must use the offset relative to this view, like
        // `GpuTensorView::view`. Adding `self.offset` first made every sub-view
        // with a nonzero offset (e.g. from `matrix(1)`) fail spuriously.
        let available_elts = self.view_shape.size.iter().product::<u32>();
        let needed_elts = shape.iter().product::<u32>() + offset;
        assert!(
            needed_elts <= available_elts,
            "Source tensor is too small for reshaping. Expected at least {needed_elts} elements (shape: {shape:?}, offset: {offset}), found {available_elts} (shape: {:?})",
            self.view_shape.size
        );
        GpuTensorViewMut {
            view_shape: self.view_shape.view(shape, stride),
            offset: self.offset + offset,
            buffer: self.buffer,
        }
    }

    /// Selects matrix `matrix_id` along dimension 2 (that dimension becomes 1).
    ///
    /// # Panics
    /// Panics if `matrix_id` is out of range.
    pub fn matrix(&mut self, matrix_id: u32) -> GpuTensorViewMut<'_, T, B> {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        assert!(matrix_id < nmats);
        GpuTensorViewMut {
            view_shape: ViewShape {
                size: [nrows, ncols, 1, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[2] * matrix_id,
            buffer: self.buffer,
        }
    }

    /// Selects `new_ncols` columns starting at `first_col`.
    ///
    /// # Panics
    /// Panics if `first_col + new_ncols` exceeds the number of columns.
    pub fn columns(&mut self, first_col: u32, new_ncols: u32) -> GpuTensorViewMut<'_, T, B> {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        // `<=`: a range ending exactly at `ncols` (e.g. the last column) is valid.
        assert!(
            matches!(first_col.checked_add(new_ncols), Some(end) if end <= ncols),
            "Column range out of bounds."
        );
        GpuTensorViewMut {
            view_shape: ViewShape {
                size: [nrows, new_ncols, nmats, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[1] * first_col,
            buffer: self.buffer,
        }
    }

    /// Selects the single column `col`.
    pub fn column(&mut self, col: u32) -> GpuTensorViewMut<'_, T, B> {
        self.columns(col, 1)
    }

    /// Selects `new_nrows` rows starting at `first_row`.
    ///
    /// # Panics
    /// Panics if `first_row + new_nrows` exceeds the number of rows.
    pub fn rows(&mut self, first_row: u32, new_nrows: u32) -> GpuTensorViewMut<'_, T, B> {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        // `<=`: a range ending exactly at `nrows` (e.g. the last row) is valid.
        assert!(
            matches!(first_row.checked_add(new_nrows), Some(end) if end <= nrows),
            "Row range out of bounds."
        );
        GpuTensorViewMut {
            view_shape: ViewShape {
                size: [new_nrows, ncols, nmats, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[0] * first_row,
            buffer: self.buffer,
        }
    }

    /// Selects the single row `row`.
    pub fn row(&mut self, row: u32) -> GpuTensorViewMut<'_, T, B> {
        self.rows(row, 1)
    }
}
impl<T: DeviceValue, B: Backend> GpuTensor<T, B> {
    /// Contiguous reshape of the whole tensor, keeping this tensor's ordering.
    pub fn reshape<const DIM2: usize>(&self, shape: [u32; DIM2]) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.reshape_with_ordering(shape, self.ordering)
    }

    /// ggml-ordered contiguous reshape of the whole tensor, keeping this tensor's ordering.
    pub fn reshape_ggml<const DIM2: usize>(&self, shape: [u32; DIM2]) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.reshape_ggml_with_ordering(shape, self.ordering)
    }

    /// Whole-tensor view with dimensions reordered by `permutations`.
    pub fn permute(&self, permutations: [usize; 4]) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.permute(permutations)
    }

    /// ggml-indexed counterpart of [`Self::permute`].
    pub fn permute_ggml(&self, permutations: [usize; 4]) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.permute_ggml(permutations)
    }

    /// Sub-view of the whole tensor (see `GpuTensorView::view`).
    pub fn view<const DIM2: usize>(
        &self,
        offset: u32,
        shape: [u32; DIM2],
        stride: [Option<u32>; DIM2],
    ) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.view(offset, shape, stride)
    }

    /// ggml-ordered sub-view of the whole tensor (see `GpuTensorView::view_ggml`).
    pub fn view_ggml<const DIM2: usize>(
        &self,
        offset: u32,
        shape: [u32; DIM2],
        stride: [Option<u32>; DIM2],
    ) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.view_ggml(offset, shape, stride)
    }

    /// View of column `i`.
    pub fn column(&self, i: u32) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.column(i)
    }

    /// View of `ncols` columns starting at `first_col`.
    pub fn columns(&self, first_col: u32, ncols: u32) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.columns(first_col, ncols)
    }

    /// View of row `i`.
    pub fn row(&self, i: u32) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.row(i)
    }

    /// View of `nrows` rows starting at `first_row`.
    pub fn rows(&self, first_row: u32, nrows: u32) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.rows(first_row, nrows)
    }
}
impl<T: DeviceValue + NoUninit, B: Backend> GpuTensor<T, B> {
    /// Allocates an uninitialized `nrows × ncols` column-major matrix.
    ///
    /// # Errors
    /// Propagates the backend's allocation error.
    pub fn matrix_uninit(
        backend: &B,
        nrows: u32,
        ncols: u32,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue,
    {
        let builder = TensorBuilder::matrix(nrows, ncols, usage);
        builder.build_uninit(backend)
    }

    /// Uploads a contiguous nalgebra matrix using `MatrixOrdering::default()`.
    ///
    /// # Errors
    /// Propagates the backend's allocation/upload error.
    pub fn matrix<R: Dim, C: Dim, S: Storage<T, R, C> + IsContiguous>(
        backend: &B,
        matrix: &Matrix<T, R, C, S>,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + nalgebra::Scalar,
    {
        let ordering = MatrixOrdering::default();
        Self::matrix_with_ordering(backend, matrix, ordering, usage)
    }

    /// Uploads a contiguous nalgebra matrix, tagging it with `ordering`.
    ///
    /// The data is taken from `matrix.as_slice()` without reordering.
    /// NOTE(review): nalgebra storage is column-major, so `RowMajor` here labels
    /// column-major data as row-major — confirm callers intend that.
    ///
    /// # Errors
    /// Propagates the backend's allocation/upload error.
    pub fn matrix_with_ordering<R: Dim, C: Dim, S: Storage<T, R, C> + IsContiguous>(
        backend: &B,
        matrix: &Matrix<T, R, C, S>,
        ordering: MatrixOrdering,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + nalgebra::Scalar,
    {
        let (nrows, ncols) = matrix.shape();
        let data = matrix.as_slice();
        TensorBuilder::matrix(nrows as u32, ncols as u32, usage)
            .ordering(ordering)
            .build_init(backend, data)
    }
}
impl<T: DeviceValue, B: Backend> GpuTensor<T, B> {
    /// Allocates an uninitialized vector of `len` elements.
    ///
    /// # Errors
    /// Propagates the backend's allocation error.
    pub fn vector_uninit(backend: &B, len: u32, usage: BufferUsages) -> Result<Self, B::Error>
    where
        T: DeviceValue + NoUninit,
    {
        let builder = TensorBuilder::vector(len, usage);
        builder.build_uninit(backend)
    }

    /// Uploads `vector` as a new vector tensor.
    ///
    /// # Errors
    /// Propagates the backend's allocation/upload error.
    pub fn vector(
        backend: &B,
        vector: impl AsRef<[T]>,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + NoUninit,
    {
        let elements: &[T] = vector.as_ref();
        let len = elements.len() as u32;
        TensorBuilder::vector(len, usage).build_init(backend, elements)
    }

    /// Encased counterpart of [`Self::vector_uninit`].
    ///
    /// # Errors
    /// Propagates the backend's allocation error.
    pub fn vector_uninit_encased(
        backend: &B,
        len: u32,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + EncaseType,
    {
        let builder = TensorBuilder::vector(len, usage);
        builder.build_uninit_encased(backend)
    }

    /// Encased counterpart of [`Self::vector`].
    ///
    /// # Errors
    /// Propagates the backend's allocation/upload error.
    pub fn vector_encased(
        backend: &B,
        vector: impl AsRef<[T]>,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + EncaseType,
    {
        let elements: &[T] = vector.as_ref();
        let len = elements.len() as u32;
        TensorBuilder::vector(len, usage).build_encased(backend, elements)
    }
}
impl<T: DeviceValue, B: Backend> GpuTensor<T, B> {
    /// Allocates an uninitialized single-element tensor.
    ///
    /// # Errors
    /// Propagates the backend's allocation error.
    pub fn scalar_uninit(backend: &B, usage: BufferUsages) -> Result<Self, B::Error>
    where
        T: DeviceValue + NoUninit,
    {
        let builder = TensorBuilder::scalar(usage);
        builder.build_uninit(backend)
    }

    /// Encased counterpart of [`Self::scalar_uninit`].
    ///
    /// # Errors
    /// Propagates the backend's allocation error.
    pub fn scalar_uninit_encased(backend: &B, usage: BufferUsages) -> Result<Self, B::Error>
    where
        T: DeviceValue + EncaseType,
    {
        let builder = TensorBuilder::scalar(usage);
        builder.build_uninit_encased(backend)
    }

    /// Uploads `value` as a single-element tensor.
    ///
    /// # Errors
    /// Propagates the backend's allocation/upload error.
    pub fn scalar(backend: &B, value: T, usage: BufferUsages) -> Result<Self, B::Error>
    where
        T: DeviceValue + NoUninit,
    {
        let one = std::slice::from_ref(&value);
        TensorBuilder::scalar(usage).build_init(backend, one)
    }

    /// Encased counterpart of [`Self::scalar`].
    ///
    /// # Errors
    /// Propagates the backend's allocation/upload error.
    pub fn scalar_encased(backend: &B, value: T, usage: BufferUsages) -> Result<Self, B::Error>
    where
        T: DeviceValue + EncaseType,
    {
        let one = std::slice::from_ref(&value);
        TensorBuilder::scalar(usage).build_encased(backend, one)
    }
}
/// Lets a tensor be passed directly as a shader argument by forwarding to its buffer.
impl<'b, B: Backend, T: DeviceValue> ShaderArgs<'b, B> for GpuTensor<T, B> {
    /// Binds the tensor's underlying buffer under `name`.
    ///
    /// Only the buffer is written; the tensor's shape and ordering are not passed
    /// to the shader by this call.
    fn write_arg<'a>(
        &'b self,
        binding: ShaderBinding,
        name: &str,
        dispatch: &mut B::Dispatch<'a>,
    ) -> Result<(), ShaderArgsError>
    where
        'b: 'a,
    {
        self.buffer.write_arg(binding, name, dispatch)
    }
}
// Generates an `append`-style and a `shift_remove`-style method on `GpuTensor`.
//
// The identifiers select the trait bound on `T` and the matching backend
// capacity / copy / allocation / write functions. This lets the plain and the
// encased variants share one body.
macro_rules! append_and_remove(
    ($append: ident, $shift_remove: ident, $TraitBound: ident, $capacity: ident, $copy_buffer_to_buffer: ident, $uninit_buffer: ident, $write_buffer: ident) => {
        /// Appends `data` at the end of this vector tensor.
        ///
        /// When the new length exceeds the buffer capacity, a buffer whose
        /// capacity is the next power of two is allocated and the existing
        /// elements are copied into it before `data` is written.
        ///
        /// # Panics
        /// Panics if the tensor is not a vector (see `vector_dim`) or if the new
        /// length does not fit in a `u32`.
        ///
        /// # Errors
        /// Propagates backend errors from allocation, copy, write or submission.
        pub fn $append(&mut self, backend: &B, data: &[T]) -> Result<(), B::Error>
        where
            T: $TraitBound,
        {
            let dim_to_grow = self.vector_dim();
            if data.is_empty() {
                // Nothing to write: skip allocation and the encoder round-trip.
                return Ok(());
            }
            let curr_len = self.shape[dim_to_grow];
            let num_added = data.len();
            assert!(
                num_added <= (u32::MAX - curr_len) as usize,
                "Appending {} elements to a tensor of length {} overflows u32.",
                num_added,
                curr_len
            );
            let new_len = curr_len + num_added as u32;
            let mut encoder = backend.begin_encoding();
            // Grow only when the data does not fit; an exact fit
            // (`new_len == capacity`) is written in place without reallocating.
            if new_len as u64 > self.$capacity() {
                // Fall back to the exact length when no larger power of two fits in u32.
                let new_capacity = new_len.checked_next_power_of_two().unwrap_or(new_len);
                let mut new_buffer = backend.$uninit_buffer(
                    new_capacity as usize,
                    self.buffer().usage() | BufferUsages::COPY_DST,
                )?;
                encoder.$copy_buffer_to_buffer(
                    &self.buffer,
                    0,
                    &mut new_buffer,
                    0,
                    curr_len as usize,
                )?;
                // NOTE(review): the old buffer is dropped here, before `encoder`
                // (which reads from it) is submitted; confirm every backend keeps
                // it alive until the recorded copy executes.
                self.buffer = new_buffer;
            }
            backend.$write_buffer(&mut self.buffer, curr_len as u64, data)?;
            backend.submit(encoder)?;
            self.shape[dim_to_grow] = new_len;
            Ok(())
        }

        /// Removes the elements in `range` from this vector tensor and returns how
        /// many were removed.
        ///
        /// The elements after the range are shifted down to close the gap. The
        /// buffer capacity is unchanged. An empty or inverted range is a no-op
        /// returning `Ok(0)`.
        ///
        /// # Panics
        /// Panics if the tensor is not a vector or if the range ends past its length.
        ///
        /// # Errors
        /// Propagates backend errors from allocation, copy or submission.
        pub fn $shift_remove(
            &mut self,
            backend: &B,
            range: impl RangeBounds<usize>,
        ) -> Result<usize, B::Error>
        where
            T: $TraitBound,
        {
            let dim_to_shrink = self.vector_dim();
            let curr_len = self.shape[dim_to_shrink] as usize;
            let range_start = match range.start_bound() {
                Bound::Included(i) => *i,
                Bound::Excluded(i) => *i + 1,
                Bound::Unbounded => 0,
            };
            let range_end = match range.end_bound() {
                Bound::Included(i) => *i + 1,
                Bound::Excluded(i) => *i,
                Bound::Unbounded => curr_len,
            };
            if range_end <= range_start {
                return Ok(0);
            }
            assert!(range_end <= curr_len, "Range index out of bounds.");
            let num_elements_to_move = curr_len - range_end;
            if num_elements_to_move > 0 {
                // The tail goes through a separate staging buffer, presumably because
                // an overlapping copy within one buffer is not allowed on every backend.
                let mut staging = backend.$uninit_buffer(
                    num_elements_to_move,
                    BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
                )?;
                let mut encoder = backend.begin_encoding();
                encoder.$copy_buffer_to_buffer(
                    &self.buffer,
                    range_end,
                    &mut staging,
                    0,
                    num_elements_to_move,
                )?;
                encoder.$copy_buffer_to_buffer(
                    &staging,
                    0,
                    &mut self.buffer,
                    range_start,
                    num_elements_to_move,
                )?;
                backend.submit(encoder)?;
            }
            let num_removed = range_end - range_start;
            self.shape[dim_to_shrink] -= num_removed as u32;
            Ok(num_removed)
        }
    }
);
// Growth/removal methods generated by `append_and_remove!`; the arguments must
// follow the macro's parameter order (append name, remove name, trait bound on
// `T`, then the capacity / copy / allocation / write functions to use).
impl<T: DeviceValue, B: Backend> GpuTensor<T, B> {
    // `append` / `shift_remove` for element types implementing `NoUninit`.
    append_and_remove!(
        append,
        shift_remove,
        NoUninit,
        capacity,
        copy_buffer_to_buffer,
        uninit_buffer,
        write_buffer
    );
    // `append_encased` / `shift_remove_encased` for element types implementing
    // `EncaseType`, using the `*_encased` backend functions.
    append_and_remove!(
        append_encased,
        shift_remove_encased,
        EncaseType,
        capacity_encased,
        copy_buffer_to_buffer_encased,
        uninit_buffer_encased,
        write_buffer_encased
    );
}