use slang_hal::backend::{Backend, Buffer, DeviceValue, EncaseType, Encoder, ShaderBinding};
use crate::shapes::{GGML_IDS, MatrixOrdering, ViewShape};
use bytemuck::Pod;
use encase::ShaderType;
use nalgebra::{Dim, IsContiguous, Matrix, Storage};
use std::sync::Arc;
use slang_hal::backend::WebGpu;
use wgpu::BufferUsages;
use slang_hal::ShaderArgs;
#[cfg(feature = "cuda")]
use crate::cuda::Cuda;
use slang_hal::shader::ShaderArgsError;
/// Builder describing a tensor with up to four dimensions before its buffer is allocated.
///
/// The `build_*` methods consume the builder and produce a [`GpuTensor`].
pub struct TensorBuilder {
/// Number of elements along each of the four dimensions.
shape: [u32; 4],
/// Usage flags forwarded to the backend when the buffer is created.
usage: BufferUsages,
/// Ordering recorded on the resulting tensor (changed through `ordering()`).
ordering: MatrixOrdering,
/// Optional label set through `label()`.
// NOTE(review): no build method visible in this file reads this field — confirm whether it
// is supposed to be forwarded to the backend.
label: Option<String>,
}
impl TensorBuilder {
    /// Starts a builder for a single-element tensor (shape `[1, 1, 1, 1]`).
    pub fn scalar(usage: BufferUsages) -> Self {
        Self::tensor([1, 1, 1, 1], usage)
    }

    /// Starts a builder for a vector of `dim` elements (shape `[dim, 1, 1, 1]`).
    pub fn vector(dim: u32, usage: BufferUsages) -> Self {
        Self::tensor([dim, 1, 1, 1], usage)
    }

    /// Starts a builder for a matrix with `nrows` rows and `ncols` columns
    /// (shape `[nrows, ncols, 1, 1]`).
    pub fn matrix(nrows: u32, ncols: u32, usage: BufferUsages) -> Self {
        Self::tensor([nrows, ncols, 1, 1], usage)
    }

    /// Starts a builder for a tensor with the given four-dimensional `shape`.
    ///
    /// The ordering starts as [`MatrixOrdering::ColumnMajor`] and no label is set.
    pub fn tensor(shape: [u32; 4], usage: BufferUsages) -> Self {
        Self {
            shape,
            usage,
            ordering: MatrixOrdering::ColumnMajor,
            label: None,
        }
    }

    /// Total number of elements described by the shape (product of the four sizes).
    fn len(&self) -> u64 {
        self.shape.into_iter().map(|s| s as u64).product()
    }

    /// Panics unless `provided` elements are enough to fill the tensor described by `self`.
    ///
    /// Shared by the initializing build methods so they report the exact same message.
    #[track_caller]
    fn assert_enough_elements(&self, provided: usize) {
        let expected = self.len();
        assert!(
            provided as u64 >= expected,
            // The explicit space before the line continuation matters: `\` strips the leading
            // whitespace of the next line, which previously glued the two sentences together.
            "Incorrect number of elements provided for initializing Tensor. \
             Expected at least {}, found {}",
            expected,
            provided
        );
    }

    /// Overrides the ordering recorded on the tensor being built.
    pub fn ordering(mut self, ordering: MatrixOrdering) -> Self {
        self.ordering = ordering;
        self
    }

    /// Stores a label on the builder.
    // NOTE(review): the build methods below do not read the label — confirm whether it should
    // be forwarded to the backend.
    pub fn label(mut self, label: String) -> Self {
        self.label = Some(label);
        self
    }

    /// Builds the tensor on top of a buffer obtained from `Backend::uninit_buffer`.
    ///
    /// # Errors
    ///
    /// Returns the backend error if the buffer creation fails.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `Backend::uninit_buffer`. Presumably the buffer
    /// content is uninitialized and must be written before it is read — confirm against the
    /// backend documentation.
    pub unsafe fn build_uninit<T: DeviceValue + Pod, B: Backend>(
        self,
        backend: &B,
    ) -> Result<GpuTensor<T, B>, B::Error> {
        // SAFETY: the requirements of `uninit_buffer` are forwarded to our own caller.
        let buffer = unsafe { backend.uninit_buffer(self.len() as usize, self.usage)? };
        Ok(GpuTensor {
            shape: self.shape,
            buffer,
            ordering: self.ordering,
        })
    }

    /// Same as [`Self::build_uninit`], but the buffer comes from
    /// `Backend::uninit_buffer_encased` and `T` only needs to be an [`EncaseType`].
    ///
    /// # Errors
    ///
    /// Returns the backend error if the buffer creation fails.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `Backend::uninit_buffer_encased`. Presumably the
    /// buffer content is uninitialized and must be written before it is read — confirm against
    /// the backend documentation.
    pub unsafe fn build_uninit_encased<T: DeviceValue + EncaseType, B: Backend>(
        self,
        backend: &B,
    ) -> Result<GpuTensor<T, B>, B::Error> {
        // SAFETY: the requirements of `uninit_buffer_encased` are forwarded to our own caller.
        let buffer = unsafe { backend.uninit_buffer_encased(self.len() as usize, self.usage)? };
        Ok(GpuTensor {
            shape: self.shape,
            buffer,
            ordering: self.ordering,
        })
    }

    /// Builds the tensor on top of a buffer initialized with `data`.
    ///
    /// The whole `data` slice is handed to `Backend::init_buffer`, even when it holds more
    /// elements than the shape requires.
    ///
    /// # Errors
    ///
    /// Returns the backend error if the buffer creation fails.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds fewer elements than the product of the shape.
    pub fn build_init<T: DeviceValue + Pod, B: Backend>(
        self,
        backend: &B,
        data: &[T],
    ) -> Result<GpuTensor<T, B>, B::Error> {
        self.assert_enough_elements(data.len());
        let buffer = backend.init_buffer(data, self.usage)?;
        Ok(GpuTensor {
            shape: self.shape,
            buffer,
            ordering: self.ordering,
        })
    }

    /// Same as [`Self::build_init`], but the buffer comes from `Backend::init_buffer_encased`
    /// and `T` only needs to be an [`EncaseType`].
    ///
    /// # Errors
    ///
    /// Returns the backend error if the buffer creation fails.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds fewer elements than the product of the shape.
    pub fn build_encased<T: DeviceValue + EncaseType, B: Backend>(
        self,
        backend: &B,
        data: &[T],
    ) -> Result<GpuTensor<T, B>, B::Error> {
        self.assert_enough_elements(data.len());
        let buffer = backend.init_buffer_encased(data, self.usage)?;
        Ok(GpuTensor {
            shape: self.shape,
            buffer,
            ordering: self.ordering,
        })
    }
}
/// Alias of [`GpuTensor`] for tensors used as vectors; the alias itself enforces no shape.
pub type GpuVector<T, B> = GpuTensor<T, B>;
/// Alias of [`GpuTensor`] for tensors used as matrices; the alias itself enforces no shape.
pub type GpuMatrix<T, B> = GpuTensor<T, B>;
/// Alias of [`GpuTensor`] for tensors used as scalars; the alias itself enforces no shape.
pub type GpuScalar<T, B> = GpuTensor<T, B>;
/// A tensor with up to four dimensions whose elements live in a buffer of backend `B`.
pub struct GpuTensor<T: DeviceValue, B: Backend> {
/// Number of elements along each of the four dimensions.
shape: [u32; 4],
/// Backend buffer holding the elements.
buffer: B::Buffer<T>,
/// Ordering used to derive the strides of views taken over the whole tensor.
ordering: MatrixOrdering,
}
/// A [`GpuTensor`] bound to the [`WebGpu`] backend.
pub type WgpuTensor<T> = GpuTensor<T, WebGpu>;
/// A [`GpuTensor`] bound to the `Cuda` backend (only with the `cuda` feature enabled).
#[cfg(feature = "cuda")]
pub type CudaTensor<T> = GpuTensor<T, Cuda>;
impl<T: DeviceValue, B: Backend> GpuTensor<T, B> {
    /// The ordering currently recorded on this tensor.
    pub fn ordering(&self) -> MatrixOrdering {
        self.ordering
    }

    /// Consuming variant of [`Self::transpose`]: returns the tensor with its first two
    /// dimensions swapped and its ordering transposed.
    pub fn transposed(mut self) -> Self {
        let flipped = self.ordering.transpose();
        self.shape.swap(0, 1);
        self.ordering = flipped;
        self
    }

    /// Swaps the first two dimensions and transposes the ordering; the buffer is untouched.
    pub fn transpose(&mut self) {
        let flipped = self.ordering.transpose();
        self.shape.swap(0, 1);
        self.ordering = flipped;
    }

    /// `true` if the tensor holds no element.
    pub fn is_empty(&self) -> bool {
        let num_elts = self.len();
        num_elts == 0
    }

    /// Number of elements, i.e. the product of the four sizes.
    pub fn len(&self) -> u64 {
        self.shape.iter().map(|&dim| dim as u64).product()
    }

    /// Size along dimension `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= 4`.
    pub fn size(&self, i: usize) -> u32 {
        let dims = self.shape;
        dims[i]
    }

    /// Size along the dimension obtained by mapping `i` through `GGML_IDS`.
    pub fn size_ggml(&self, i: usize) -> u32 {
        let axis = GGML_IDS[i];
        self.size(axis)
    }

    /// Stride along dimension `i` of the view covering the whole tensor.
    pub fn stride(&self, i: usize) -> u32 {
        let whole = self.as_view();
        whole.view_shape.stride[i]
    }

    /// Stride along the dimension obtained by mapping `i` through `GGML_IDS`.
    pub fn stride_ggml(&self, i: usize) -> u32 {
        let axis = GGML_IDS[i];
        self.stride(axis)
    }

    /// Number of elements multiplied by `size_of::<T>()`.
    pub fn bytes_len(&self) -> u64
    where
        T: DeviceValue,
    {
        let elt_bytes = std::mem::size_of::<T>() as u64;
        elt_bytes * self.len()
    }

    /// Records, on `encoder`, a copy of `self.len()` elements from `source` (starting at the
    /// view's offset) to the start of this tensor's buffer.
    ///
    /// The shape and strides of `source` are not consulted.
    pub fn copy_from_view<'a>(
        &mut self,
        encoder: &mut B::Encoder,
        source: impl Into<GpuTensorView<'a, T, B>>,
    ) -> Result<(), B::Error>
    where
        T: DeviceValue + Pod,
    {
        let src: GpuTensorView<'a, T, B> = source.into();
        let num_elts = self.len() as usize;
        let src_start = src.offset as usize;
        encoder.copy_buffer_to_buffer(src.buffer, src_start, &mut self.buffer, 0, num_elts)
    }

    /// Same as [`Self::copy_from_view`], going through
    /// `Encoder::copy_buffer_to_buffer_encased` for element types that are [`ShaderType`].
    pub fn copy_from_view_encased<'a>(
        &mut self,
        encoder: &mut B::Encoder,
        source: impl Into<GpuTensorView<'a, T, B>>,
    ) -> Result<(), B::Error>
    where
        T: DeviceValue + ShaderType,
    {
        let src: GpuTensorView<'a, T, B> = source.into();
        let num_elts = self.len() as usize;
        let src_start = src.offset as usize;
        encoder.copy_buffer_to_buffer_encased(src.buffer, src_start, &mut self.buffer, 0, num_elts)
    }

    /// The four sizes of this tensor.
    pub fn shape(&self) -> [u32; 4] {
        self.shape
    }

    /// Shared access to the underlying backend buffer.
    pub fn buffer(&self) -> &B::Buffer<T> {
        &self.buffer
    }

    /// Exclusive access to the underlying backend buffer.
    pub fn buffer_mut(&mut self) -> &mut B::Buffer<T> {
        &mut self.buffer
    }

    /// Consumes the tensor and returns its backend buffer.
    pub fn into_inner(self) -> B::Buffer<T> {
        self.buffer
    }

    /// Read-only view over the whole tensor (offset 0, strides from `ViewShape::contiguous`).
    pub fn as_view(&self) -> GpuTensorView<'_, T, B> {
        let view_shape = ViewShape::contiguous(self.shape, self.ordering);
        GpuTensorView {
            view_shape,
            buffer: &self.buffer,
            offset: 0,
        }
    }

    /// Mutable view over the whole tensor (offset 0, strides from `ViewShape::contiguous`).
    pub fn as_view_mut(&mut self) -> GpuTensorViewMut<'_, T, B> {
        let view_shape = ViewShape::contiguous(self.shape, self.ordering);
        GpuTensorViewMut {
            view_shape,
            buffer: &mut self.buffer,
            offset: 0,
        }
    }
}
/// A shared, reference-counted tensor converts to a view over the whole tensor.
impl<'a, T: DeviceValue, B: Backend> From<&'a Arc<GpuTensor<T, B>>> for GpuTensorView<'a, T, B> {
    fn from(val: &'a Arc<GpuTensor<T, B>>) -> Self {
        // Look through the `Arc` first, then take the full view of the tensor it points to.
        let tensor: &'a GpuTensor<T, B> = val;
        tensor.as_view()
    }
}
impl<'a, T: DeviceValue, B: Backend> From<&'a GpuTensor<T, B>> for GpuTensorView<'a, T, B> {
fn from(val: &'a GpuTensor<T, B>) -> Self {
val.as_view()
}
}
impl<'a, T: DeviceValue, B: Backend> From<&'a mut GpuTensor<T, B>> for GpuTensorViewMut<'a, T, B> {
fn from(val: &'a mut GpuTensor<T, B>) -> Self {
val.as_view_mut()
}
}
/// A mutable view over (part of) the buffer of a [`GpuTensor`].
pub struct GpuTensorViewMut<'a, T: DeviceValue, B: Backend> {
/// Sizes and strides of the viewed region.
view_shape: ViewShape,
/// Exclusive borrow of the buffer being viewed.
buffer: &'a mut B::Buffer<T>,
/// Start of the view inside `buffer` (presumably counted in elements — TODO confirm).
offset: u32,
}
/// A read-only view over (part of) the buffer of a [`GpuTensor`].
pub struct GpuTensorView<'a, T: DeviceValue, B: Backend> {
/// Sizes and strides of the viewed region.
view_shape: ViewShape,
/// Shared borrow of the buffer being viewed.
buffer: &'a B::Buffer<T>,
/// Start of the view inside `buffer` (presumably counted in elements — TODO confirm).
offset: u32,
}
// Written by hand rather than derived: `#[derive(Clone)]` would add `T: Clone` and `B: Clone`
// bounds, which this impl does not require (presumably the reason — confirm).
impl<'a, T: DeviceValue, B: Backend> Clone for GpuTensorView<'a, T, B> {
fn clone(&self) -> Self {
// The view is `Copy`, so cloning is a plain bitwise copy.
*self
}
}
// The view only holds a `ViewShape`, a shared reference and an offset, so it can be `Copy`.
impl<'a, T: DeviceValue, B: Backend> Copy for GpuTensorView<'a, T, B> {}
impl<'a, T: DeviceValue, B: Backend> GpuTensorView<'a, T, B> {
    /// Ordering reported by the view's `ViewShape`, if any.
    pub fn ordering(&self) -> Option<MatrixOrdering> {
        self.view_shape.ordering()
    }

    /// Forwards to `ViewShape::is_contiguous`.
    ///
    /// Presumably `Some(ordering)` when the viewed elements are contiguous — confirm against
    /// `ViewShape`.
    pub fn is_contiguous(&self) -> Option<MatrixOrdering> {
        self.view_shape.is_contiguous()
    }

    /// Returns the result of [`Self::is_contiguous`] if this view starts at offset 0 and has as
    /// many elements as the underlying buffer's `len()`; `None` otherwise.
    pub fn is_entire_tensor(&self) -> Option<MatrixOrdering> {
        if self.buffer.len() == self.len() as usize && self.offset == 0 {
            self.is_contiguous()
        } else {
            None
        }
    }

    /// Sizes and strides of this view.
    pub fn shape(&self) -> ViewShape {
        self.view_shape
    }

    /// Slice of the underlying buffer going from this view's offset to the end of the buffer.
    pub fn buffer(&self) -> B::BufferSlice<'_, T> {
        self.buffer.slice(self.offset as usize..)
    }

    /// The whole underlying buffer, ignoring this view's offset.
    pub fn raw_buffer(&self) -> &B::Buffer<T> {
        self.buffer
    }

    /// `true` if the view holds no element.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of elements, as reported by `ViewShape::len`.
    pub fn len(&self) -> u64 {
        self.view_shape.len()
    }

    /// Size along dimension `i`.
    pub fn size(&self, i: usize) -> u32 {
        self.view_shape.size[i]
    }

    /// Size along the dimension obtained by mapping `i` through `GGML_IDS`.
    pub fn size_ggml(&self, i: usize) -> u32 {
        self.view_shape.size[GGML_IDS[i]]
    }

    /// Stride along dimension `i`.
    pub fn stride(&self, i: usize) -> u32 {
        self.view_shape.stride[i]
    }

    /// Stride along the dimension obtained by mapping `i` through `GGML_IDS`.
    pub fn stride_ggml(&self, i: usize) -> u32 {
        self.view_shape.stride[GGML_IDS[i]]
    }

    /// View with the first two dimensions exchanged (permutation `[1, 0, 2, 3]`).
    pub fn transposed(&self) -> Self {
        self.permute([1, 0, 2, 3])
    }

    /// View over the same buffer and offset with its shape permuted by `ViewShape::permute`.
    pub fn permute(&self, permutations: [usize; 4]) -> Self {
        Self {
            view_shape: self.view_shape.permute(permutations),
            offset: self.offset,
            buffer: self.buffer,
        }
    }

    /// Same as [`Self::permute`], but the shape is permuted by `ViewShape::permute_ggml`.
    pub fn permute_ggml(&self, permutations: [usize; 4]) -> Self {
        Self {
            view_shape: self.view_shape.permute_ggml(permutations),
            offset: self.offset,
            buffer: self.buffer,
        }
    }

    /// Reinterprets this view with a new `shape` whose strides are those of a contiguous
    /// tensor with the given `ordering`.
    ///
    /// Missing trailing dimensions (when `DIM2 < 4`) are set to 1.
    ///
    /// # Panics
    ///
    /// Panics if `DIM2 > 4`, or if the new shape needs more elements than this view has.
    pub fn reshape_with_ordering<const DIM2: usize>(
        &self,
        shape: [u32; DIM2],
        ordering: MatrixOrdering,
    ) -> Self {
        assert!(DIM2 <= 4);
        let mut shape4 = [1; 4];
        shape4[..DIM2].copy_from_slice(&shape);
        let view_shape = ViewShape::contiguous(shape4, ordering);
        self.view(0, shape4, view_shape.stride.map(Some))
    }

    /// Reinterprets this view with a new `shape`, leaving every stride unspecified (`None`) so
    /// that `ViewShape::view` picks them.
    ///
    /// # Panics
    ///
    /// Panics if the new shape needs more elements than this view has.
    pub fn reshape<const DIM2: usize>(&self, shape: [u32; DIM2]) -> Self {
        self.view(0, shape, [None; DIM2])
    }

    /// Variant of [`Self::reshape`] taking a shape whose first two entries are swapped before
    /// use.
    ///
    /// When the first two sizes of this view are both 1, the reshape is done with
    /// `MatrixOrdering::RowMajor`; otherwise it goes through [`Self::reshape`].
    ///
    /// # Panics
    ///
    /// Panics if `DIM2 < 2` (the swap indexes entries 0 and 1).
    pub fn reshape_ggml<const DIM2: usize>(&self, mut shape: [u32; DIM2]) -> Self {
        shape.swap(0, 1);
        if self.view_shape.size[0] == 1 && self.view_shape.size[1] == 1 {
            self.reshape_with_ordering(shape, MatrixOrdering::RowMajor)
        } else {
            self.reshape(shape)
        }
    }

    /// Variant of [`Self::reshape_with_ordering`] taking a shape whose first two entries are
    /// swapped before use.
    ///
    /// # Panics
    ///
    /// Panics if `DIM2 < 2` (the swap indexes entries 0 and 1).
    pub fn reshape_ggml_with_ordering<const DIM2: usize>(
        &self,
        mut shape: [u32; DIM2],
        ordering: MatrixOrdering,
    ) -> Self {
        shape.swap(0, 1);
        self.reshape_with_ordering(shape, ordering)
    }

    /// Sub-view starting `offset` past the start of this view, with the given `shape` and
    /// `stride` (resolved by `ViewShape::view`).
    ///
    /// # Panics
    ///
    /// Panics if `offset` plus the product of `shape` exceeds the product of this view's sizes.
    pub fn view<const DIM2: usize>(
        &self,
        mut offset: u32,
        shape: [u32; DIM2],
        stride: [Option<u32>; DIM2],
    ) -> Self {
        let available_elts = self.view_shape.size.iter().product::<u32>();
        let needed_elts = shape.iter().product::<u32>() + offset;
        assert!(
            needed_elts <= available_elts,
            "Source tensor is too small for reshaping. Expected at least {needed_elts} elements (shape: {shape:?}, offset: {offset}), found {available_elts} (shape: {:?})",
            self.view_shape.size
        );
        // The check above is relative to this view; only now switch to a buffer-wide offset.
        offset += self.offset;
        GpuTensorView {
            view_shape: self.view_shape.view(shape, stride),
            offset,
            buffer: self.buffer,
        }
    }

    /// Variant of [`Self::view`] where the first two entries of both `shape` and `stride` are
    /// swapped before use.
    ///
    /// # Panics
    ///
    /// Panics if `DIM2 < 2` (the swaps index entries 0 and 1).
    pub fn view_ggml<const DIM2: usize>(
        &self,
        offset: u32,
        mut shape: [u32; DIM2],
        mut stride: [Option<u32>; DIM2],
    ) -> Self {
        shape.swap(0, 1);
        stride.swap(0, 1);
        self.view(offset, shape, stride)
    }

    /// Sub-view restricted to index `matrix_id` of the third dimension.
    ///
    /// # Panics
    ///
    /// Panics if `matrix_id` is not smaller than the size of the third dimension.
    pub fn matrix(&self, matrix_id: u32) -> Self {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        assert!(matrix_id < nmats);
        GpuTensorView {
            view_shape: ViewShape {
                size: [nrows, ncols, 1, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[2] * matrix_id,
            buffer: self.buffer,
        }
    }

    /// Sub-view made of `new_ncols` consecutive columns starting at `first_col`.
    ///
    /// # Panics
    ///
    /// Panics if `first_col + new_ncols` exceeds the number of columns.
    pub fn columns(&self, first_col: u32, new_ncols: u32) -> Self {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        // The end of the range is exclusive, so it may equal `ncols` (this is what lets the
        // last column be selected). Summing in `u64` keeps the check immune to `u32` wrap-around.
        assert!(
            u64::from(first_col) + u64::from(new_ncols) <= u64::from(ncols),
            "Column range out of bounds: first_col = {first_col}, new_ncols = {new_ncols}, ncols = {ncols}"
        );
        GpuTensorView {
            view_shape: ViewShape {
                size: [nrows, new_ncols, nmats, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[1] * first_col,
            buffer: self.buffer,
        }
    }

    /// Sub-view made of the single column `col`.
    ///
    /// # Panics
    ///
    /// Panics if `col` is not smaller than the number of columns.
    pub fn column(&self, col: u32) -> Self {
        self.columns(col, 1)
    }

    /// Sub-view made of `new_nrows` consecutive rows starting at `first_row`.
    ///
    /// # Panics
    ///
    /// Panics if `first_row + new_nrows` exceeds the number of rows.
    pub fn rows(&self, first_row: u32, new_nrows: u32) -> Self {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        // Exclusive end of range: equality with `nrows` is valid (selects up to the last row).
        assert!(
            u64::from(first_row) + u64::from(new_nrows) <= u64::from(nrows),
            "Row range out of bounds: first_row = {first_row}, new_nrows = {new_nrows}, nrows = {nrows}"
        );
        GpuTensorView {
            view_shape: ViewShape {
                size: [new_nrows, ncols, nmats, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[0] * first_row,
            buffer: self.buffer,
        }
    }

    /// Sub-view made of the single row `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row` is not smaller than the number of rows.
    pub fn row(&self, row: u32) -> Self {
        self.rows(row, 1)
    }
}
impl<'a, T: DeviceValue, B: Backend> GpuTensorViewMut<'a, T, B> {
    /// Read-only view with the same shape, buffer and offset as this mutable view.
    pub fn as_ref(&self) -> GpuTensorView<'_, T, B> {
        GpuTensorView {
            view_shape: self.view_shape,
            buffer: &*self.buffer,
            offset: self.offset,
        }
    }

    /// Same as `GpuTensorView::is_contiguous` on the read-only counterpart of this view.
    pub fn is_contiguous(&self) -> Option<MatrixOrdering> {
        self.as_ref().is_contiguous()
    }

    /// Same as `GpuTensorView::is_entire_tensor` on the read-only counterpart of this view.
    pub fn is_entire_tensor(&self) -> Option<MatrixOrdering> {
        self.as_ref().is_entire_tensor()
    }

    /// Sizes and strides of this view.
    pub fn shape(&self) -> ViewShape {
        self.view_shape
    }

    /// The whole underlying buffer, ignoring this view's offset.
    pub fn raw_buffer(&mut self) -> &mut B::Buffer<T> {
        self.buffer
    }

    /// `true` if the view holds no element.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of elements, as reported by `ViewShape::len`.
    pub fn len(&self) -> u64 {
        self.view_shape.len()
    }

    /// Mutable view with the first two dimensions exchanged (permutation `[1, 0, 2, 3]`).
    pub fn transposed(&mut self) -> GpuTensorViewMut<'_, T, B> {
        self.permute([1, 0, 2, 3])
    }

    /// Mutable view over the same buffer and offset with its shape permuted by
    /// `ViewShape::permute`.
    pub fn permute(&mut self, permutations: [usize; 4]) -> GpuTensorViewMut<'_, T, B> {
        GpuTensorViewMut {
            view_shape: self.view_shape.permute(permutations),
            offset: self.offset,
            buffer: self.buffer,
        }
    }

    /// Reinterprets this view with a new `shape`, leaving every stride unspecified (`None`) so
    /// that `ViewShape::view` picks them.
    ///
    /// # Panics
    ///
    /// Panics if the new shape needs more elements than this view has.
    pub fn reshape<const DIM2: usize>(&mut self, shape: [u32; DIM2]) -> GpuTensorViewMut<'_, T, B> {
        self.view(0, shape, [None; DIM2])
    }

    /// Mutable sub-view starting `offset` past the start of this view, with the given `shape`
    /// and `stride` (resolved by `ViewShape::view`).
    ///
    /// # Panics
    ///
    /// Panics if `offset` plus the product of `shape` exceeds the product of this view's sizes.
    pub fn view<const DIM2: usize>(
        &mut self,
        mut offset: u32,
        shape: [u32; DIM2],
        stride: [Option<u32>; DIM2],
    ) -> GpuTensorViewMut<'_, T, B> {
        // `offset` is relative to this view, and so is `available_elts`: the bounds check must
        // therefore run before the view's own buffer offset is folded in. Checking afterwards
        // rejected valid sub-views of any view that does not start at offset 0.
        let available_elts = self.view_shape.size.iter().product::<u32>();
        let needed_elts = shape.iter().product::<u32>() + offset;
        assert!(
            needed_elts <= available_elts,
            "Source tensor is too small for reshaping. Expected at least {needed_elts} elements (shape: {shape:?}, offset: {offset}), found {available_elts} (shape: {:?})",
            self.view_shape.size
        );
        offset += self.offset;
        GpuTensorViewMut {
            view_shape: self.view_shape.view(shape, stride),
            offset,
            buffer: self.buffer,
        }
    }

    /// Mutable sub-view restricted to index `matrix_id` of the third dimension.
    ///
    /// # Panics
    ///
    /// Panics if `matrix_id` is not smaller than the size of the third dimension.
    pub fn matrix(&mut self, matrix_id: u32) -> GpuTensorViewMut<'_, T, B> {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        assert!(matrix_id < nmats);
        GpuTensorViewMut {
            view_shape: ViewShape {
                size: [nrows, ncols, 1, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[2] * matrix_id,
            buffer: self.buffer,
        }
    }

    /// Mutable sub-view made of `new_ncols` consecutive columns starting at `first_col`.
    ///
    /// # Panics
    ///
    /// Panics if `first_col + new_ncols` exceeds the number of columns.
    pub fn columns(&mut self, first_col: u32, new_ncols: u32) -> GpuTensorViewMut<'_, T, B> {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        // The end of the range is exclusive, so it may equal `ncols` (this is what lets the
        // last column be selected). Summing in `u64` keeps the check immune to `u32` wrap-around.
        assert!(
            u64::from(first_col) + u64::from(new_ncols) <= u64::from(ncols),
            "Column range out of bounds: first_col = {first_col}, new_ncols = {new_ncols}, ncols = {ncols}"
        );
        GpuTensorViewMut {
            view_shape: ViewShape {
                size: [nrows, new_ncols, nmats, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[1] * first_col,
            buffer: self.buffer,
        }
    }

    /// Mutable sub-view made of the single column `col`.
    ///
    /// # Panics
    ///
    /// Panics if `col` is not smaller than the number of columns.
    pub fn column(&mut self, col: u32) -> GpuTensorViewMut<'_, T, B> {
        self.columns(col, 1)
    }

    /// Mutable sub-view made of `new_nrows` consecutive rows starting at `first_row`.
    ///
    /// # Panics
    ///
    /// Panics if `first_row + new_nrows` exceeds the number of rows.
    pub fn rows(&mut self, first_row: u32, new_nrows: u32) -> GpuTensorViewMut<'_, T, B> {
        let [nrows, ncols, nmats, ncubes] = self.view_shape.size;
        // Exclusive end of range: equality with `nrows` is valid (selects up to the last row).
        assert!(
            u64::from(first_row) + u64::from(new_nrows) <= u64::from(nrows),
            "Row range out of bounds: first_row = {first_row}, new_nrows = {new_nrows}, nrows = {nrows}"
        );
        GpuTensorViewMut {
            view_shape: ViewShape {
                size: [new_nrows, ncols, nmats, ncubes],
                stride: self.view_shape.stride,
            },
            offset: self.offset + self.view_shape.stride[0] * first_row,
            buffer: self.buffer,
        }
    }

    /// Mutable sub-view made of the single row `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row` is not smaller than the number of rows.
    pub fn row(&mut self, row: u32) -> GpuTensorViewMut<'_, T, B> {
        self.rows(row, 1)
    }
}
impl<T: DeviceValue, B: Backend> GpuTensor<T, B> {
    /// View over this tensor reshaped to `shape`, with contiguous strides following the
    /// tensor's own ordering.
    pub fn reshape<const DIM2: usize>(&self, shape: [u32; DIM2]) -> GpuTensorView<'_, T, B> {
        let ordering = self.ordering;
        let whole = self.as_view();
        whole.reshape_with_ordering(shape, ordering)
    }

    /// Same as [`Self::reshape`], going through `GpuTensorView::reshape_ggml_with_ordering`.
    pub fn reshape_ggml<const DIM2: usize>(&self, shape: [u32; DIM2]) -> GpuTensorView<'_, T, B> {
        let ordering = self.ordering;
        let whole = self.as_view();
        whole.reshape_ggml_with_ordering(shape, ordering)
    }

    /// Shortcut for `self.as_view().permute(permutations)`.
    pub fn permute(&self, permutations: [usize; 4]) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.permute(permutations)
    }

    /// Shortcut for `self.as_view().permute_ggml(permutations)`.
    pub fn permute_ggml(&self, permutations: [usize; 4]) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.permute_ggml(permutations)
    }

    /// Shortcut for `self.as_view().view(offset, shape, stride)`.
    pub fn view<const DIM2: usize>(
        &self,
        offset: u32,
        shape: [u32; DIM2],
        stride: [Option<u32>; DIM2],
    ) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.view(offset, shape, stride)
    }

    /// Shortcut for `self.as_view().view_ggml(offset, shape, stride)`.
    pub fn view_ggml<const DIM2: usize>(
        &self,
        offset: u32,
        shape: [u32; DIM2],
        stride: [Option<u32>; DIM2],
    ) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.view_ggml(offset, shape, stride)
    }

    /// View over column `i` of this tensor.
    pub fn column(&self, i: u32) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.column(i)
    }

    /// View over `ncols` consecutive columns starting at `first_col`.
    pub fn columns(&self, first_col: u32, ncols: u32) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.columns(first_col, ncols)
    }

    /// View over row `i` of this tensor.
    pub fn row(&self, i: u32) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.row(i)
    }

    /// View over `nrows` consecutive rows starting at `first_row`.
    pub fn rows(&self, first_row: u32, nrows: u32) -> GpuTensorView<'_, T, B> {
        let whole = self.as_view();
        whole.rows(first_row, nrows)
    }
}
impl<T: DeviceValue + Pod, B: Backend> GpuTensor<T, B> {
    /// Creates a `nrows × ncols` matrix through [`TensorBuilder::build_uninit`].
    ///
    /// # Safety
    ///
    /// Same contract as [`TensorBuilder::build_uninit`].
    pub unsafe fn matrix_uninit(
        backend: &B,
        nrows: u32,
        ncols: u32,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue,
    {
        let builder = TensorBuilder::matrix(nrows, ncols, usage);
        // SAFETY: the contract of `build_uninit` is forwarded to our own caller.
        unsafe { builder.build_uninit(backend) }
    }

    /// Creates a matrix initialized from a contiguous nalgebra `matrix`, recorded with
    /// `MatrixOrdering::default()`.
    pub fn matrix<R: Dim, C: Dim, S: Storage<T, R, C> + IsContiguous>(
        backend: &B,
        matrix: &Matrix<T, R, C, S>,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + nalgebra::Scalar,
    {
        let ordering = MatrixOrdering::default();
        Self::matrix_with_ordering(backend, matrix, ordering, usage)
    }

    /// Creates a matrix initialized from a contiguous nalgebra `matrix`, recorded with the
    /// given `ordering`.
    ///
    /// The data of `matrix.as_slice()` is uploaded as-is; `ordering` does not reorder it.
    pub fn matrix_with_ordering<R: Dim, C: Dim, S: Storage<T, R, C> + IsContiguous>(
        backend: &B,
        matrix: &Matrix<T, R, C, S>,
        ordering: MatrixOrdering,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + nalgebra::Scalar,
    {
        let nrows = matrix.nrows() as u32;
        let ncols = matrix.ncols() as u32;
        let builder = TensorBuilder::matrix(nrows, ncols, usage).ordering(ordering);
        builder.build_init(backend, matrix.as_slice())
    }
}
impl<T: DeviceValue, B: Backend> GpuTensor<T, B> {
    /// Creates a vector of `len` elements through [`TensorBuilder::build_uninit`].
    ///
    /// # Safety
    ///
    /// Same contract as [`TensorBuilder::build_uninit`].
    pub unsafe fn vector_uninit(
        backend: &B,
        len: u32,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + Pod,
    {
        let builder = TensorBuilder::vector(len, usage);
        // SAFETY: the contract of `build_uninit` is forwarded to our own caller.
        unsafe { builder.build_uninit(backend) }
    }

    /// Creates a vector initialized with the elements of `vector`.
    pub fn vector(
        backend: &B,
        vector: impl AsRef<[T]>,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + Pod,
    {
        let data: &[T] = vector.as_ref();
        let builder = TensorBuilder::vector(data.len() as u32, usage);
        builder.build_init(backend, data)
    }

    /// Creates a vector of `len` elements through [`TensorBuilder::build_uninit_encased`].
    ///
    /// # Safety
    ///
    /// Same contract as [`TensorBuilder::build_uninit_encased`].
    pub unsafe fn vector_uninit_encased(
        backend: &B,
        len: u32,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + EncaseType,
    {
        let builder = TensorBuilder::vector(len, usage);
        // SAFETY: the contract of `build_uninit_encased` is forwarded to our own caller.
        unsafe { builder.build_uninit_encased(backend) }
    }

    /// Creates a vector initialized with the elements of `vector`, through
    /// [`TensorBuilder::build_encased`].
    pub fn vector_encased(
        backend: &B,
        vector: impl AsRef<[T]>,
        usage: BufferUsages,
    ) -> Result<Self, B::Error>
    where
        T: DeviceValue + EncaseType,
    {
        let data: &[T] = vector.as_ref();
        let builder = TensorBuilder::vector(data.len() as u32, usage);
        builder.build_encased(backend, data)
    }
}
impl<T: DeviceValue, B: Backend> GpuTensor<T, B> {
    /// Creates a single-element tensor through [`TensorBuilder::build_uninit`].
    ///
    /// # Safety
    ///
    /// Same contract as [`TensorBuilder::build_uninit`].
    pub unsafe fn scalar_uninit(backend: &B, usage: BufferUsages) -> Result<Self, B::Error>
    where
        T: DeviceValue + Pod,
    {
        let builder = TensorBuilder::scalar(usage);
        // SAFETY: the contract of `build_uninit` is forwarded to our own caller.
        unsafe { builder.build_uninit(backend) }
    }

    /// Creates a single-element tensor through [`TensorBuilder::build_uninit_encased`].
    ///
    /// # Safety
    ///
    /// Same contract as [`TensorBuilder::build_uninit_encased`].
    pub unsafe fn scalar_uninit_encased(backend: &B, usage: BufferUsages) -> Result<Self, B::Error>
    where
        T: DeviceValue + EncaseType,
    {
        let builder = TensorBuilder::scalar(usage);
        // SAFETY: the contract of `build_uninit_encased` is forwarded to our own caller.
        unsafe { builder.build_uninit_encased(backend) }
    }

    /// Creates a single-element tensor initialized with `value`.
    pub fn scalar(backend: &B, value: T, usage: BufferUsages) -> Result<Self, B::Error>
    where
        T: DeviceValue + Pod,
    {
        let data = [value];
        let builder = TensorBuilder::scalar(usage);
        builder.build_init(backend, &data)
    }

    /// Creates a single-element tensor initialized with `value`, through
    /// [`TensorBuilder::build_encased`].
    pub fn scalar_encased(backend: &B, value: T, usage: BufferUsages) -> Result<Self, B::Error>
    where
        T: DeviceValue + EncaseType,
    {
        let data = [value];
        let builder = TensorBuilder::scalar(usage);
        builder.build_encased(backend, &data)
    }
}
/// A tensor is bound to a shader by binding its underlying buffer; shape and ordering are not
/// passed along.
impl<'b, B: Backend, T: DeviceValue> ShaderArgs<'b, B> for GpuTensor<T, B> {
    fn write_arg<'a>(
        &'b self,
        binding: ShaderBinding,
        name: &str,
        dispatch: &mut B::Dispatch<'a>,
    ) -> Result<(), ShaderArgsError>
    where
        'b: 'a,
    {
        // Delegate to the buffer's own `write_arg`.
        let buffer: &'b B::Buffer<T> = &self.buffer;
        buffer.write_arg(binding, name, dispatch)
    }
}