use crate::shapes::{GGML_IDS, TensorLayout};
use crate::tensor::Tensor;
use bytemuck::NoUninit;
use khal::backend::{Buffer, DeviceValue, GpuBuffer, GpuBufferSlice};
use std::sync::Arc;
/// A lightweight, borrowed view over a tensor stored in a [`GpuBuffer`].
///
/// The view pairs a [`TensorLayout`] (held by value) with a shared reference
/// to the buffer; it never owns the storage it describes.
pub struct TensorRef<'a, T>
where
    T: DeviceValue,
{
    /// Layout (sizes, strides, offset, rank) through which `buffer` is interpreted.
    layout: TensorLayout,
    /// Borrowed storage backing this view.
    buffer: &'a GpuBuffer<T>,
}
// `Clone` is written by hand rather than derived: `#[derive(Clone)]` would add
// a `T: Clone` bound, whereas the view only stores a layout and a shared
// reference and needs nothing beyond `T: DeviceValue`.
impl<T> Clone for TensorRef<'_, T>
where
    T: DeviceValue,
{
    /// Returns a copy of the view; the referenced buffer is shared, not duplicated.
    fn clone(&self) -> Self {
        *self
    }
}
// Manual impl: `#[derive(Copy)]` would demand `T: Copy`, but both fields (the
// layout value and the shared buffer reference) are `Copy` regardless of `T`.
impl<T> Copy for TensorRef<'_, T> where T: DeviceValue {}
/// Conversion of a value into a borrowed [`TensorRef`] view.
pub trait AsTensorRef<T>
where
    T: DeviceValue,
{
    /// Returns a view of `self`; the view cannot outlive the borrow of `self`.
    fn as_tensor_ref(&self) -> TensorRef<'_, T>;
}
/// Builds a view from a reference to an `Arc`-shared tensor by calling
/// `as_view` on it.
impl<'a, T> From<&'a Arc<Tensor<T>>> for TensorRef<'a, T>
where
    T: DeviceValue,
{
    fn from(shared: &'a Arc<Tensor<T>>) -> TensorRef<'a, T> {
        shared.as_view()
    }
}
/// Builds a view from a borrowed tensor by calling `as_view` on it.
impl<'a, T> From<&'a Tensor<T>> for TensorRef<'a, T>
where
    T: DeviceValue,
{
    fn from(tensor: &'a Tensor<T>) -> TensorRef<'a, T> {
        tensor.as_view()
    }
}
/// An owned tensor exposes itself as a view through `as_view`.
impl<T> AsTensorRef<T> for Tensor<T>
where
    T: DeviceValue,
{
    #[inline]
    fn as_tensor_ref(&self) -> TensorRef<'_, T> {
        self.as_view()
    }
}
/// A tensor reference exposes the referenced tensor as a view through
/// `as_view` (reached via auto-deref of `self`).
impl<T> AsTensorRef<T> for &Tensor<T>
where
    T: DeviceValue,
{
    #[inline]
    fn as_tensor_ref(&self) -> TensorRef<'_, T> {
        self.as_view()
    }
}
impl<'a, T: DeviceValue> AsTensorRef<T> for TensorRef<'a, T> {
#[inline]
fn as_tensor_ref(&self) -> TensorRef<'_, T> {
*self
}
}
/// A reference to a view converts by copying the referenced view.
impl<T> AsTensorRef<T> for &TensorRef<'_, T>
where
    T: DeviceValue,
{
    #[inline]
    fn as_tensor_ref(&self) -> TensorRef<'_, T> {
        // `self` is a reference to a reference; peel one layer, then copy.
        let view: &TensorRef<'_, T> = self;
        *view
    }
}
impl<'a, T: DeviceValue> TensorRef<'a, T> {
    /// Creates a view over `buffer` interpreted through `layout`.
    pub(crate) fn new(layout: TensorLayout, buffer: &'a GpuBuffer<T>) -> Self {
        Self { layout, buffer }
    }

    /// Creates a view over `buffer` using the layout returned by
    /// [`TensorLayout::contiguous`] for `dims`.
    pub(crate) fn contiguous(dims: &[u32], buffer: &'a GpuBuffer<T>) -> Self {
        Self::new(TensorLayout::contiguous(dims), buffer)
    }

    /// Returns `true` if the layout reports itself as contiguous.
    pub fn is_contiguous(&self) -> bool {
        self.layout.is_contiguous()
    }

    /// Returns `true` when this view covers the whole backing buffer: the
    /// buffer length equals the layout's element count, the layout offset is
    /// zero, and the layout is contiguous.
    pub fn is_entire_tensor(&self) -> bool
    where
        T: NoUninit,
    {
        // Compare in `u64`: widening the buffer length is lossless, whereas
        // narrowing the `u64` element count to `usize` would truncate on
        // 32-bit targets and could make an oversized layout compare equal.
        self.buffer.len() as u64 == self.len()
            && self.layout.offset == 0
            && self.is_contiguous()
    }

    /// Returns the rank stored in the layout.
    pub fn rank(&self) -> u32 {
        self.layout.rank
    }

    /// Returns a copy of the layout describing this view.
    pub fn layout(&self) -> TensorLayout {
        self.layout
    }

    /// Returns the backing buffer sliced from the layout's offset onward.
    ///
    /// The range is open-ended, so the slice runs to the end of the buffer
    /// rather than stopping at the extent of this view.
    pub fn buffer(&self) -> GpuBufferSlice<'_, T> {
        self.buffer.slice(self.layout.offset as usize..)
    }

    /// Returns the whole backing buffer, ignoring the layout's offset.
    pub fn raw_buffer(&self) -> &GpuBuffer<T> {
        self.buffer
    }

    /// Returns `true` if the layout's element count is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the element count reported by the layout.
    pub fn len(&self) -> u64 {
        self.layout.len()
    }

    /// Returns entry `i` of the layout's `size` table.
    ///
    /// Indexing is unchecked here, so an out-of-range `i` is expected to panic.
    pub fn size(&self, i: usize) -> u32 {
        self.layout.size[i]
    }

    /// Returns the `size` entry selected by translating `i` through `GGML_IDS`.
    ///
    /// NOTE(review): presumably `i` is a dimension index in GGML order — confirm
    /// against the definition of `GGML_IDS`.
    pub fn size_ggml(&self, i: usize) -> u32 {
        self.layout.size[GGML_IDS[i]]
    }

    /// Returns entry `i` of the layout's `stride` table.
    ///
    /// Indexing is unchecked here, so an out-of-range `i` is expected to panic.
    pub fn stride(&self, i: usize) -> u32 {
        self.layout.stride[i]
    }

    /// Returns the `stride` entry selected by translating `i` through `GGML_IDS`.
    ///
    /// NOTE(review): presumably `i` is a dimension index in GGML order — confirm
    /// against the definition of `GGML_IDS`.
    pub fn stride_ggml(&self, i: usize) -> u32 {
        self.layout.stride[GGML_IDS[i]]
    }

    /// Returns a view of the same buffer under a different `layout`.
    fn with_layout(&self, layout: TensorLayout) -> Self {
        Self::new(layout, self.buffer)
    }

    // Methods generated by the project macro. NOTE(review): the expansion is
    // not visible here; presumably it builds on `with_layout` — confirm.
    super::tensor_macro::impl_layout_modifiers! {}
}