use crate::shapes::{GGML_IDS, TensorLayout};
use crate::tensor::{AsTensorRef, Tensor, TensorRef};
use bytemuck::NoUninit;
use khal::backend::{Buffer, DeviceValue, GpuBuffer, GpuBufferSlice, GpuBufferSliceMut};
/// A mutable tensor view: a [`TensorLayout`] paired with an exclusive borrow
/// of the [`GpuBuffer`] that holds the elements.
///
/// The view does not own the buffer; it borrows it mutably for `'a`.
pub struct TensorMut<'a, T: DeviceValue> {
// Layout (rank, sizes, strides, offset) used to interpret `buffer`.
layout: TensorLayout,
// Exclusive borrow of the whole underlying buffer; accessors apply
// `layout.offset` when handing out slices.
buffer: &'a mut GpuBuffer<T>,
}
/// Conversion into a mutable tensor view.
///
/// Implemented in this file for `Tensor<T>`, `&mut Tensor<T>`,
/// `TensorMut<'_, T>` and `&mut TensorMut<'_, T>`.
pub trait AsTensorMut<T: DeviceValue> {
/// Borrows `self` mutably and returns a [`TensorMut`] view that lives no
/// longer than that borrow.
fn as_tensor_mut(&mut self) -> TensorMut<'_, T>;
}
/// Converts an exclusive borrow of a [`Tensor`] into a [`TensorMut`] view
/// with the same lifetime.
impl<'a, T: DeviceValue> From<&'a mut Tensor<T>> for TensorMut<'a, T> {
    /// Returns the view produced by the tensor's `as_view_mut`.
    fn from(val: &'a mut Tensor<T>) -> Self {
        val.as_view_mut()
    }
}
impl<T: DeviceValue> AsTensorMut<T> for Tensor<T> {
#[inline]
fn as_tensor_mut(&mut self) -> TensorMut<'_, T> {
self.as_view_mut()
}
}
impl<T: DeviceValue> AsTensorMut<T> for &mut Tensor<T> {
#[inline]
fn as_tensor_mut(&mut self) -> TensorMut<'_, T> {
self.as_view_mut()
}
}
impl<'a, T: DeviceValue> AsTensorMut<T> for TensorMut<'a, T> {
#[inline]
fn as_tensor_mut(&mut self) -> TensorMut<'_, T> {
TensorMut {
layout: self.layout,
buffer: &mut *self.buffer,
}
}
}
/// A mutable reference to a mutable view can be re-borrowed as a view.
impl<'a, 'b, T: DeviceValue> AsTensorMut<T> for &'b mut TensorMut<'a, T> {
    /// Returns a new view with the same layout over a reborrow of the same
    /// buffer as the referenced view.
    #[inline]
    fn as_tensor_mut(&mut self) -> TensorMut<'_, T> {
        // Strip the outer reference and reuse the impl for `TensorMut`
        // itself; the fully-qualified call avoids re-selecting this impl.
        let view: &mut TensorMut<'a, T> = &mut **self;
        <TensorMut<'a, T> as AsTensorMut<T>>::as_tensor_mut(view)
    }
}
impl<'a, T: DeviceValue> AsTensorRef<T> for TensorMut<'a, T> {
#[inline]
fn as_tensor_ref(&self) -> TensorRef<'_, T> {
self.as_ref()
}
}
/// A shared reference to a mutable view can be borrowed as a read-only view.
impl<'a, 'b, T: DeviceValue> AsTensorRef<T> for &'b TensorMut<'a, T> {
    /// Returns the read-only view of the referenced [`TensorMut`].
    #[inline]
    fn as_tensor_ref(&self) -> TensorRef<'_, T> {
        // Dereference both layers so the inherent `TensorMut::as_ref` is the
        // method being called.
        (**self).as_ref()
    }
}
impl<'a, T: DeviceValue> TensorMut<'a, T> {
    /// Builds a mutable view over `buffer` interpreted through `layout`.
    pub(crate) fn new(layout: TensorLayout, buffer: &'a mut GpuBuffer<T>) -> Self {
        Self { layout, buffer }
    }

    /// Builds a mutable view over `buffer` using the layout that
    /// `TensorLayout::contiguous` produces for `dims`.
    pub(crate) fn contiguous(dims: &[u32], buffer: &'a mut GpuBuffer<T>) -> Self {
        let layout = TensorLayout::contiguous(dims);
        Self { layout, buffer }
    }

    /// Returns a read-only view with the same layout over the same buffer.
    pub fn as_ref(&self) -> TensorRef<'_, T> {
        let shared: &GpuBuffer<T> = &*self.buffer;
        TensorRef::new(self.layout, shared)
    }

    /// Returns whatever `TensorLayout::is_contiguous` reports for this view's
    /// layout.
    pub fn is_contiguous(&self) -> bool {
        let layout = &self.layout;
        layout.is_contiguous()
    }

    /// Returns the result of `is_entire_tensor` on the read-only form of this
    /// view.
    pub fn is_entire_tensor(&self) -> bool
    where
        T: NoUninit,
    {
        let view = self.as_ref();
        view.is_entire_tensor()
    }

    /// Returns the `rank` stored in the layout.
    pub fn rank(&self) -> u32 {
        self.layout.rank
    }

    /// Returns a copy of the layout describing this view.
    pub fn layout(&self) -> TensorLayout {
        self.layout
    }

    /// Returns a read-only slice of the underlying buffer that starts at the
    /// layout's `offset` and runs to the end of the buffer.
    pub fn buffer(&self) -> GpuBufferSlice<'_, T> {
        let start = self.layout.offset as usize;
        self.buffer.slice(start..)
    }

    /// Returns a mutable slice of the underlying buffer that starts at the
    /// layout's `offset` and runs to the end of the buffer.
    pub fn buffer_mut(&mut self) -> GpuBufferSliceMut<'_, T> {
        let start = self.layout.offset as usize;
        self.buffer.slice_mut(start..)
    }

    /// Returns the whole underlying buffer, without applying the layout's
    /// `offset`.
    pub fn raw_buffer(&mut self) -> &mut GpuBuffer<T> {
        &mut *self.buffer
    }

    /// Returns whatever `TensorLayout::is_empty` reports for this view's
    /// layout.
    pub fn is_empty(&self) -> bool {
        let layout = &self.layout;
        layout.is_empty()
    }

    /// Returns whatever `TensorLayout::len` reports for this view's layout.
    pub fn len(&self) -> u64 {
        let layout = &self.layout;
        layout.len()
    }

    /// Returns entry `i` of the layout's `size` table.
    ///
    /// An out-of-range `i` fails at the indexing operation (a panic, assuming
    /// `size` is an array or slice).
    pub fn size(&self, i: usize) -> u32 {
        let sizes = &self.layout.size;
        sizes[i]
    }

    /// Returns the layout's `size` entry at index `GGML_IDS[i]`.
    ///
    /// NOTE(review): `GGML_IDS` presumably maps GGML dimension order onto
    /// this crate's order; confirm in `shapes`.
    pub fn size_ggml(&self, i: usize) -> u32 {
        let axis = GGML_IDS[i];
        self.layout.size[axis]
    }

    /// Returns entry `i` of the layout's `stride` table.
    ///
    /// An out-of-range `i` fails at the indexing operation (a panic, assuming
    /// `stride` is an array or slice).
    pub fn stride(&self, i: usize) -> u32 {
        let strides = &self.layout.stride;
        strides[i]
    }

    /// Returns the layout's `stride` entry at index `GGML_IDS[i]`.
    ///
    /// NOTE(review): `GGML_IDS` presumably maps GGML dimension order onto
    /// this crate's order; confirm in `shapes`.
    pub fn stride_ggml(&self, i: usize) -> u32 {
        let axis = GGML_IDS[i];
        self.layout.stride[axis]
    }

    /// Consumes the view and returns one over the same buffer with `layout`
    /// in place of the previous layout.
    fn with_layout(mut self, layout: TensorLayout) -> Self {
        self.layout = layout;
        self
    }

    // Further methods are generated by this macro. NOTE(review): presumably
    // layout-transforming helpers built on `with_layout`; confirm in
    // `tensor_macro`.
    super::tensor_macro::impl_layout_modifiers! {}
}