use crate::dtype::{DType, DataType, Element};
use crate::error::Result;
use crate::runtime::Runtime;
use std::sync::Arc;
/// Reference-counted handle to a typed buffer managed by the runtime `R`.
///
/// The handle only wraps an [`Arc`]; cloning it yields another handle to the
/// same underlying buffer rather than a copy of the data.
pub struct Storage<R: Runtime> {
/// Shared buffer state (address, element count, dtype, device, ownership flag).
inner: Arc<StorageInner<R>>,
}
/// Shared state behind every [`Storage`] handle.
///
/// Its `Drop` implementation calls `R::deallocate` when `owned` is `true`
/// and `ptr` is non-zero.
struct StorageInner<R: Runtime> {
/// Raw buffer address, either returned by `R::allocate` or supplied by the caller.
ptr: u64,
/// Number of elements (not bytes) in the buffer.
len: usize,
/// Element type; `dtype.storage_bytes(len)` is used as the buffer's byte size.
dtype: R::DType,
/// Device handle passed to every runtime call made for this buffer.
device: R::Device,
/// Whether dropping this value should release `ptr` through the runtime.
owned: bool,
}
impl<R: Runtime> Storage<R> {
/// Allocates an owned buffer for `len` elements of `dtype` on `device`.
///
/// The byte size requested from `R::allocate` is `dtype.storage_bytes(len)`.
/// This function does not write to the buffer after allocating it.
///
/// # Errors
///
/// Propagates any error returned by `R::allocate`.
pub fn new(len: usize, dtype: R::DType, device: &R::Device) -> Result<Self> {
    let ptr = R::allocate(dtype.storage_bytes(len), device)?;
    let inner = StorageInner {
        ptr,
        len,
        dtype,
        device: device.clone(),
        owned: true,
    };
    Ok(Self {
        inner: Arc::new(inner),
    })
}
/// Creates an owned storage by copying the raw bytes in `data` to `device`.
///
/// The element count is `data.len() / dtype.size_in_bytes()` (integer
/// division), so trailing bytes that do not form a whole element are copied
/// but are not reflected in `len()`.
///
/// # Errors
///
/// Propagates any error from `R::allocate` or `R::copy_to_device`. If the
/// copy fails, the freshly allocated buffer is released before returning so
/// the error path does not leak device memory.
///
/// # Panics
///
/// Panics if `dtype.size_in_bytes()` is zero (integer division by zero).
pub fn from_bytes(data: &[u8], dtype: R::DType, device: &R::Device) -> Result<Self> {
    let size_bytes = data.len();
    let len = size_bytes / dtype.size_in_bytes();
    let ptr = R::allocate(size_bytes, device)?;

    // No `StorageInner` exists yet, so nothing would free `ptr` if we simply
    // bailed out with `?`; release it by hand on failure. The non-zero check
    // mirrors the guard used by `Drop for StorageInner`.
    R::copy_to_device(data, ptr, device).map_err(|err| {
        if ptr != 0 {
            R::deallocate(ptr, size_bytes, device);
        }
        err
    })?;

    // NOTE(review): the buffer is allocated with `data.len()` bytes, while the
    // `Drop` impl releases `dtype.storage_bytes(len)` bytes. The two differ when
    // `data.len()` is not a whole number of elements — confirm the runtime
    // tolerates that, or reject such input.
    Ok(Self {
        inner: Arc::new(StorageInner {
            ptr,
            len,
            dtype,
            device: device.clone(),
            owned: true,
        }),
    })
}
pub unsafe fn from_ptr(ptr: u64, len: usize, dtype: R::DType, device: &R::Device) -> Self {
Self {
inner: Arc::new(StorageInner {
ptr,
len,
dtype,
device: device.clone(),
owned: false,
}),
}
}
pub unsafe fn from_ptr_owned(
ptr: u64,
len: usize,
dtype: R::DType,
device: &R::Device,
) -> Self {
Self {
inner: Arc::new(StorageInner {
ptr,
len,
dtype,
device: device.clone(),
owned: true,
}),
}
}
/// Wraps an existing buffer, with the ownership mode chosen by `owned`.
///
/// When `owned` is `true` the buffer is released through `R::deallocate`
/// once the last handle is dropped; when `false` it is never released by
/// this type.
///
/// # Safety
///
/// Nothing is validated here. The caller is presumably required to keep
/// `ptr` valid for `len` elements of `dtype` on `device` while the storage
/// is in use and, if `owned` is `true`, to pass a pointer that may be
/// released via `R::deallocate` — confirm against the runtime's contract.
pub unsafe fn from_raw(
    ptr: u64,
    len: usize,
    dtype: R::DType,
    device: &R::Device,
    owned: bool,
) -> Self {
    let inner = StorageInner {
        ptr,
        len,
        dtype,
        device: device.clone(),
        owned,
    };
    Self {
        inner: Arc::new(inner),
    }
}
/// Creates an owned storage by copying a typed host slice to `device`.
///
/// The dtype is taken from `T::DTYPE` and the element count from
/// `data.len()`; the slice is reinterpreted as bytes with
/// `bytemuck::cast_slice` before being copied.
///
/// # Errors
///
/// Propagates any error from `R::allocate` or `R::copy_to_device`. If the
/// copy fails, the freshly allocated buffer is released before returning so
/// the error path does not leak device memory.
pub fn from_slice<T: Element>(data: &[T], device: &R::Device) -> Result<Self>
where
    R: Runtime<DType = DType>,
{
    let bytes: &[u8] = bytemuck::cast_slice(data);
    let size_bytes = bytes.len();
    let ptr = R::allocate(size_bytes, device)?;

    // No `StorageInner` exists yet, so nothing would free `ptr` if we simply
    // bailed out with `?`; release it by hand on failure. The non-zero check
    // mirrors the guard used by `Drop for StorageInner`.
    R::copy_to_device(bytes, ptr, device).map_err(|err| {
        if ptr != 0 {
            R::deallocate(ptr, size_bytes, device);
        }
        err
    })?;

    Ok(Self {
        inner: Arc::new(StorageInner {
            ptr,
            len: data.len(),
            dtype: T::DTYPE,
            device: device.clone(),
            owned: true,
        }),
    })
}
#[inline]
pub fn ptr(&self) -> u64 {
self.inner.ptr
}
/// Returns the number of elements (not bytes) in the buffer.
#[inline]
pub fn len(&self) -> usize {
    (*self.inner).len
}
/// Returns `true` when the storage holds zero elements.
#[inline]
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
#[inline]
pub fn dtype(&self) -> R::DType {
self.inner.dtype
}
#[inline]
pub fn device(&self) -> &R::Device {
&self.inner.device
}
/// Returns the buffer size in bytes, computed as `dtype.storage_bytes(len)`.
#[inline]
pub fn size_in_bytes(&self) -> usize {
    let inner = &*self.inner;
    inner.dtype.storage_bytes(inner.len)
}
/// Returns how many `Storage` handles currently share this buffer
/// (the `Arc` strong count).
#[inline]
pub fn ref_count(&self) -> usize {
    let shared = &self.inner;
    Arc::strong_count(shared)
}
#[inline]
pub fn is_unique(&self) -> bool {
Arc::strong_count(&self.inner) == 1
}
#[inline]
pub fn is_owned(&self) -> bool {
self.inner.owned
}
#[inline]
pub fn as_raw(&self) -> RawBuffer
where
R: Runtime<DType = DType>,
{
RawBuffer {
ptr: self.inner.ptr,
len: self.inner.len,
dtype: self.inner.dtype,
}
}
/// Reinterprets the buffer as a host slice of `len()` values of `T`.
///
/// An empty storage yields an empty slice without touching the pointer.
/// `T` is not checked against the storage's dtype.
///
/// # Safety
///
/// For a non-empty storage the caller must guarantee that `ptr()` is a host
/// address that is non-null, aligned for `T`, valid for reads of `len()`
/// values of `T`, and not mutated while the returned slice is alive (the
/// requirements of [`std::slice::from_raw_parts`]).
#[inline]
pub unsafe fn as_host_slice<T: bytemuck::Pod>(&self) -> &[T] {
    let count = self.inner.len;
    if count != 0 {
        let base = self.inner.ptr as *const T;
        // SAFETY: upheld by the caller, see the `# Safety` section above.
        unsafe { std::slice::from_raw_parts(base, count) }
    } else {
        &[]
    }
}
/// Reinterprets the buffer as a mutable host slice of `len()` values of `T`.
///
/// An empty storage yields an empty slice without touching the pointer.
/// `T` is not checked against the storage's dtype.
///
/// # Safety
///
/// For a non-empty storage the caller must guarantee that `ptr()` is a host
/// address that is non-null, aligned for `T`, and valid for reads and writes
/// of `len()` values of `T` (the requirements of
/// [`std::slice::from_raw_parts_mut`]). The caller must also guarantee that
/// nothing else accesses the memory while the returned slice is alive:
/// `&mut self` does NOT prove exclusivity, because clones of this storage
/// share the same buffer.
#[inline]
pub unsafe fn as_host_slice_mut<T: bytemuck::Pod>(&mut self) -> &mut [T] {
    let count = self.inner.len;
    if count != 0 {
        let base = self.inner.ptr as *mut T;
        // SAFETY: upheld by the caller, see the `# Safety` section above.
        unsafe { std::slice::from_raw_parts_mut(base, count) }
    } else {
        &mut []
    }
}
/// Copies the buffer contents back to the host as a `Vec<T>` of `len()`
/// elements.
///
/// The bytes are fetched with `R::copy_from_device` into a zero-initialised
/// vector.
///
/// # Panics
///
/// * If `len() * size_of::<T>()` exceeds `size_in_bytes()`. `T` is chosen by
///   the caller independently of the storage's dtype, and without this check
///   a too-wide `T` would make this safe function read past the end of the
///   buffer.
/// * If `R::copy_from_device` returns an error.
pub fn to_vec<T: bytemuck::Pod>(&self) -> Vec<T> {
    let available = self.size_in_bytes();
    // Saturate so an overflowing product is reported by the assert below
    // instead of wrapping around to a small, seemingly valid size.
    let requested = self.inner.len.saturating_mul(std::mem::size_of::<T>());
    assert!(
        requested <= available,
        "to_vec(): element type needs {requested} bytes but storage holds only {available} bytes"
    );

    let mut result = vec![T::zeroed(); self.inner.len];
    let bytes: &mut [u8] = bytemuck::cast_slice_mut(&mut result);
    R::copy_from_device(self.inner.ptr, bytes, &self.inner.device)
        .expect("copy_from_device failed in to_vec()");
    result
}
}
/// Cloning produces another handle to the same buffer; no data is copied.
///
/// Written by hand because `#[derive(Clone)]` would add an `R: Clone` bound.
impl<R: Runtime> Clone for Storage<R> {
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
/// Releases the buffer through the runtime when the shared state is dropped.
impl<R: Runtime> Drop for StorageInner<R> {
    fn drop(&mut self) {
        // Borrowed buffers and zero addresses are never handed to the runtime.
        if !self.owned || self.ptr == 0 {
            return;
        }
        let size_bytes = self.dtype.storage_bytes(self.len);
        R::deallocate(self.ptr, size_bytes, &self.device);
    }
}
/// Debug output lists the address (as a hex string), element count, dtype,
/// ownership flag and current handle count. The device is not printed.
impl<R: Runtime> std::fmt::Debug for Storage<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner = &*self.inner;
        let address = format!("0x{:x}", inner.ptr);
        let refs = Arc::strong_count(&self.inner);

        let mut out = f.debug_struct("Storage");
        out.field("ptr", &address);
        out.field("len", &inner.len);
        out.field("dtype", &inner.dtype);
        out.field("owned", &inner.owned);
        out.field("refs", &refs);
        out.finish()
    }
}
/// Plain, copyable description of a buffer: address, element count and dtype.
///
/// Unlike `Storage`, this type carries no ownership and has no `Drop` logic.
#[derive(Copy, Clone, Debug)]
pub struct RawBuffer {
    /// Raw buffer address.
    pub ptr: u64,
    /// Number of elements (not bytes).
    pub len: usize,
    /// Element type.
    pub dtype: DType,
}

impl RawBuffer {
    /// Builds a descriptor from its three parts.
    #[inline]
    pub const fn new(ptr: u64, len: usize, dtype: DType) -> Self {
        RawBuffer { ptr, len, dtype }
    }

    /// Returns a descriptor with a zero address, zero length and `DType::F32`.
    #[inline]
    pub const fn empty() -> Self {
        Self::new(0, 0, DType::F32)
    }

    /// Returns `len` multiplied by the dtype's per-element size in bytes.
    #[inline]
    pub const fn size_in_bytes(&self) -> usize {
        let elem_size = self.dtype.size_in_bytes();
        self.len * elem_size
    }
}