use crate::device::Device;
use crate::dtype::TensorElement;
use crate::error::Result;
use std::sync::Arc;
/// A buffer of [`Storage::Elem`] values associated with a [`Storage::Device`].
///
/// The `Send + Sync + Debug + 'static` supertraits allow implementors to be shared
/// across threads behind an `Arc` (as `SharedStorage` does).
pub trait Storage: Send + Sync + std::fmt::Debug + 'static {
    /// Element type held by this storage.
    type Elem: TensorElement;
    /// Device type this storage is associated with.
    type Device: Device;

    /// Allocates storage for `size` elements on `device`.
    ///
    /// NOTE(review): the initial contents are decided by the implementor and are not
    /// visible here — confirm per backend whether the buffer is zeroed.
    ///
    /// # Errors
    /// Returns an error if the implementor fails to allocate.
    fn allocate(device: &Self::Device, size: usize) -> Result<Self>
    where
        Self: Sized;

    /// Produces a separate copy of this storage.
    ///
    /// `SharedStorage::make_mut` relies on this to un-share a buffer before mutating it.
    ///
    /// # Errors
    /// Returns an error if the implementor fails to copy.
    fn clone_storage(&self) -> Result<Self>
    where
        Self: Sized;

    /// Returns the device this storage is associated with.
    fn device(&self) -> &Self::Device;

    /// Returns the number of elements.
    fn len(&self) -> usize;

    /// Returns `true` when [`Storage::len`] is zero.
    fn is_empty(&self) -> bool {
        0 == self.len()
    }
}
/// Reference-counted handle to a [`Storage`] value.
///
/// Cloning a handle only adds another reference to the same `Arc`; the storage itself is
/// never copied by `clone`.
#[derive(Debug)]
pub struct SharedStorage<S: Storage> {
    inner: Arc<S>,
}

// Written by hand because `#[derive(Clone)]` would require `S: Clone`, which sharing an
// `Arc` never needs.
impl<S: Storage> Clone for SharedStorage<S> {
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.inner);
        Self { inner: shared }
    }
}
impl<S: Storage> SharedStorage<S> {
    /// Wraps `storage` in a new handle that is its only owner.
    pub fn new(storage: S) -> Self {
        SharedStorage {
            inner: Arc::new(storage),
        }
    }

    /// Borrows the underlying storage.
    pub fn get(&self) -> &S {
        &self.inner
    }

    /// Borrows the underlying `Arc`.
    pub fn inner_arc(&self) -> &Arc<S> {
        &self.inner
    }

    /// Number of strong `Arc` references to the storage (each handle holds one).
    pub fn strong_count(&self) -> usize {
        Arc::strong_count(&self.inner)
    }

    /// Returns a mutable reference if this handle is the sole owner, otherwise `None`.
    ///
    /// Follows [`Arc::get_mut`]: returns `None` while any other strong reference *or* any
    /// [`Weak`](std::sync::Weak) reference (e.g. from [`Self::downgrade`]) exists.
    pub fn get_mut(&mut self) -> Option<&mut S> {
        Arc::get_mut(&mut self.inner)
    }

    /// Returns a mutable reference, copying the storage first if this handle is not its
    /// sole owner.
    ///
    /// The storage counts as not exclusively owned when other strong handles exist or
    /// outstanding weak references exist. In that case it is duplicated with
    /// [`Storage::clone_storage`], and this handle is repointed at the private copy.
    /// Other strong handles keep the original allocation. Weak references keep pointing at the
    /// original allocation, so they stop upgrading once its last strong handle is gone.
    ///
    /// # Errors
    /// Propagates any error from [`Storage::clone_storage`]; `self` is left unchanged then.
    pub fn make_mut(&mut self) -> Result<&mut S> {
        // Probe with `Arc::get_mut` rather than `strong_count() > 1`: `get_mut` also fails
        // while `Weak` references exist. A strong-count check alone would skip the copy and
        // then panic on the `expect` below.
        if Arc::get_mut(&mut self.inner).is_none() {
            let cloned = self.inner.clone_storage()?;
            self.inner = Arc::new(cloned);
        }
        // Unreachable failure: either the probe proved exclusive ownership (and `&mut self`
        // prevents new references meanwhile), or `self.inner` is a brand-new `Arc`.
        Ok(Arc::get_mut(&mut self.inner)
            .expect("Arc::get_mut should succeed after ensuring unique reference"))
    }

    /// `true` when this is the only strong reference to the storage.
    ///
    /// Weak references are not counted: while any exist, [`Self::get_mut`] still returns
    /// `None` even if this returns `true`.
    pub fn is_unique(&self) -> bool {
        Arc::strong_count(&self.inner) == 1
    }

    /// Extracts the storage if this is the only strong reference; otherwise returns the
    /// handle unchanged in `Err`.
    pub fn try_unwrap(self) -> std::result::Result<S, Self> {
        Arc::try_unwrap(self.inner).map_err(|inner| SharedStorage { inner })
    }

    /// Creates a weak reference that does not keep the storage alive.
    pub fn downgrade(&self) -> std::sync::Weak<S> {
        Arc::downgrade(&self.inner)
    }

    /// Rebuilds a handle from `weak`, or returns `None` if the storage has been dropped.
    pub fn upgrade_from_weak(weak: &std::sync::Weak<S>) -> Option<Self> {
        weak.upgrade().map(|inner| SharedStorage { inner })
    }
}
/// Convenience constructors and queries available on every [`Storage`] type.
///
/// Every `T: Storage` receives these methods through the blanket implementation at the end
/// of this block. Coherence therefore forbids any other `impl StorageExt`, so backends
/// cannot override these defaults.
pub trait StorageExt: Storage {
    /// Allocates `size` elements on `device`; nominally filled with `value`.
    ///
    /// NOTE(review): `value` is currently ignored. This only forwards to
    /// [`Storage::allocate`], so the contents are whatever `allocate` produces. Because
    /// of the blanket impl, backends cannot override this. `Storage` would need a fill
    /// primitive before `value` can be honored.
    ///
    /// # Errors
    /// Propagates any error from [`Storage::allocate`].
    fn allocate_with_value(device: &Self::Device, size: usize, value: Self::Elem) -> Result<Self>
    where
        Self: Sized,
    {
        // Intentionally unused until a fill primitive exists (see the NOTE above).
        let _ = value;
        Self::allocate(device, size)
    }

    /// Calls [`StorageExt::allocate_with_value`] with `Elem::zero()`.
    ///
    /// NOTE(review): because the value is ignored, the buffer is zeroed only if
    /// [`Storage::allocate`] itself zeroes it.
    ///
    /// # Errors
    /// Propagates any error from [`Storage::allocate`].
    fn allocate_zeros(device: &Self::Device, size: usize) -> Result<Self>
    where
        Self: Sized,
    {
        let zero = Self::Elem::zero();
        Self::allocate_with_value(device, size, zero)
    }

    /// Calls [`StorageExt::allocate_with_value`] with `Elem::one()`.
    ///
    /// NOTE(review): because the value is ignored, the buffer is NOT filled with ones
    /// unless [`Storage::allocate`] happens to produce them.
    ///
    /// # Errors
    /// Propagates any error from [`Storage::allocate`].
    fn allocate_ones(device: &Self::Device, size: usize) -> Result<Self>
    where
        Self: Sized,
    {
        let one = Self::Elem::one();
        Self::allocate_with_value(device, size, one)
    }

    /// Bytes spanned by the elements, computed as `len() * size_of::<Elem>()`.
    ///
    /// Backend-specific padding or bookkeeping is not included.
    fn memory_usage(&self) -> usize {
        let elem_bytes = std::mem::size_of::<Self::Elem>();
        elem_bytes * self.len()
    }

    /// `true` when `self` and `other` report equal devices.
    fn is_compatible_with<Other: Storage>(&self, other: &Other) -> bool
    where
        Self::Device: PartialEq<Other::Device>,
    {
        let (mine, theirs) = (self.device(), other.device());
        mine == theirs
    }
}

impl<T> StorageExt for T where T: Storage {}
/// Construction hooks producing a concrete [`Storage`] type `S`.
pub trait StorageFactory<S: Storage> {
    /// Creates storage for `size` elements on `device` without a caller-supplied config.
    ///
    /// # Errors
    /// Returns an error if the implementor fails to create the storage.
    fn create_default(device: &S::Device, size: usize) -> Result<S>;

    /// Creates storage for `size` elements on `device` using the options in `config`.
    ///
    /// NOTE(review): `StorageConfig::initial_value` is an `f32` whatever `S::Elem` is.
    /// How implementors convert or apply it is not visible here.
    ///
    /// # Errors
    /// Returns an error if the implementor fails to create the storage.
    fn create_with_config(
        device: &S::Device,
        size: usize,
        config: &StorageConfig,
    ) -> Result<S>;
}
/// Options passed to [`StorageFactory::create_with_config`].
///
/// Start from [`StorageConfig::new`] (equivalently `Default`) and chain the `with_*`
/// setters. NOTE(review): how each option is honored depends on the factory
/// implementation, which is not visible here. The field notes below are read from the
/// field names and should be confirmed.
#[derive(Debug, Clone)]
pub struct StorageConfig {
    /// Requested initial element value; `None` (the default) requests no explicit fill.
    /// Stored as `f32` regardless of the storage's element type.
    pub initial_value: Option<f32>,
    /// Requested alignment (presumably in bytes — confirm); `None` by default.
    pub alignment: Option<usize>,
    /// Pooling preference; `true` by default.
    pub use_pooling: bool,
    /// Clear-on-deallocation preference; `false` by default.
    pub clear_on_dealloc: bool,
    /// Memory-mapping preference; `false` by default.
    pub prefer_mmap: bool,
}

impl Default for StorageConfig {
    /// Pooling enabled; every other option off or unset.
    fn default() -> Self {
        StorageConfig {
            use_pooling: true,
            clear_on_dealloc: false,
            prefer_mmap: false,
            initial_value: None,
            alignment: None,
        }
    }
}

impl StorageConfig {
    /// Equivalent to [`StorageConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the config with `initial_value` set to `Some(value)`.
    pub fn with_initial_value(self, value: f32) -> Self {
        Self {
            initial_value: Some(value),
            ..self
        }
    }

    /// Returns the config with `alignment` set to `Some(alignment)`.
    pub fn with_alignment(self, alignment: usize) -> Self {
        Self {
            alignment: Some(alignment),
            ..self
        }
    }

    /// Returns the config with `use_pooling` replaced.
    pub fn with_pooling(self, use_pooling: bool) -> Self {
        Self { use_pooling, ..self }
    }

    /// Returns the config with `clear_on_dealloc` replaced.
    pub fn with_clear_on_dealloc(self, clear: bool) -> Self {
        Self {
            clear_on_dealloc: clear,
            ..self
        }
    }

    /// Returns the config with `prefer_mmap` replaced.
    pub fn with_mmap_preference(self, prefer: bool) -> Self {
        Self {
            prefer_mmap: prefer,
            ..self
        }
    }
}