use crate::shapes::{Shape, Unit};
use crate::tensor::{cache::TensorCache, cpu::LendingIterator, storage_traits::*, Tensor};
use rand::{rngs::StdRng, Rng, SeedableRng};
use std::{sync::Arc, vec::Vec};
#[cfg(feature = "no-std")]
use spin::Mutex;
#[cfg(not(feature = "no-std"))]
use std::sync::Mutex;
/// Thin newtype around a raw, mutable byte pointer.
///
/// Raw pointers are neither `Send` nor `Sync` on their own; the manual impls
/// below opt this wrapper into both so that it can be stored inside shared,
/// thread-safe containers.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BytesPtr(pub(crate) *mut u8);

// SAFETY: NOTE(review) - nothing in this block constrains how the pointer is
// dereferenced; soundness relies on every user of `BytesPtr` treating the
// pointee as exclusively owned. Confirm against the cache implementation.
unsafe impl Sync for BytesPtr {}
unsafe impl Send for BytesPtr {}
/// CPU device handle bundling a seedable random number generator with an
/// allocation cache.
///
/// Both fields sit behind `Arc`, so the derived `Clone` produces a handle that
/// shares the same generator and the same cache as the original.
#[derive(Clone, Debug)]
pub struct Cpu {
/// Random number generator, guarded by a mutex so it can be used through `&self`.
pub(crate) rng: Arc<Mutex<StdRng>>,
/// Cache of raw allocations, stored as [`BytesPtr`] values.
pub(crate) cache: Arc<TensorCache<BytesPtr>>,
}
impl Default for Cpu {
fn default() -> Self {
Self {
rng: Arc::new(Mutex::new(StdRng::seed_from_u64(0))),
cache: Arc::new(Default::default()),
}
}
}
impl Cpu {
    /// Constructs a CPU device whose random number generator is seeded with `seed`.
    ///
    /// The allocation cache starts out default-constructed.
    pub fn seed_from_u64(seed: u64) -> Self {
        let rng = Arc::new(Mutex::new(StdRng::seed_from_u64(seed)));
        let cache = Arc::new(Default::default());
        Self { rng, cache }
    }
}
/// Error type used by the CPU device.
#[derive(Clone, Copy, Debug)]
pub enum CpuError {
    /// NOTE(review): presumably reported when an allocation cannot be
    /// satisfied; the sites that raise it are outside this block.
    OutOfMemory,
    /// NOTE(review): presumably reported when an element count does not match
    /// what was expected; the sites that raise it are outside this block.
    WrongNumElements,
}

impl std::fmt::Display for CpuError {
    /// Writes the variant name prefixed with `CpuError::`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let text = match *self {
            Self::OutOfMemory => "CpuError::OutOfMemory",
            Self::WrongNumElements => "CpuError::WrongNumElements",
        };
        f.write_str(text)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for CpuError {}
/// Declares [`CpuError`] as the error type for the [`Cpu`] device.
impl HasErr for Cpu {
type Err = CpuError;
}
/// A `Vec<E>` paired with a handle to the allocation cache.
///
/// When the cache is enabled, the `Drop` impl hands the buffer to the cache
/// instead of freeing it, and the `Clone` impl first tries to reuse a cached
/// buffer of the same element count.
#[derive(Debug)]
pub struct CachableVec<E> {
/// The element buffer.
pub(crate) data: Vec<E>,
/// Shared cache consulted on clone and drop.
pub(crate) cache: Arc<TensorCache<BytesPtr>>,
}
impl<E: Clone> Clone for CachableVec<E> {
    /// Clones the element buffer, reusing a cached allocation for the same
    /// element count when the cache can supply one.
    fn clone(&self) -> Self {
        let numel = self.data.len();
        let data = match self.cache.try_pop::<E>(numel) {
            // Nothing reusable in the cache: fall back to an ordinary clone.
            None => self.data.clone(),
            Some(allocation) => {
                assert!(numel < isize::MAX as usize);
                // SAFETY: NOTE(review) - this relies on the cache only handing
                // out pointers that came from a `Vec<E>`-compatible allocation
                // holding exactly `numel` initialized elements; that guarantee
                // lives in the cache implementation, outside this block.
                let mut reused =
                    unsafe { Vec::from_raw_parts(allocation.0 as *mut E, numel, numel) };
                // Overwrite the recycled contents with a copy of our elements.
                reused.clone_from(&self.data);
                reused
            }
        };
        Self {
            data,
            cache: self.cache.clone(),
        }
    }
}
impl<E> Drop for CachableVec<E> {
    /// Hands the buffer to the allocation cache when caching is enabled.
    ///
    /// When the cache is disabled nothing happens here and the inner `Vec` is
    /// dropped normally afterwards.
    fn drop(&mut self) {
        if self.cache.is_enabled() {
            // Leave an empty `Vec` behind so the field drop that follows is a no-op.
            let data = std::mem::take(&mut self.data);
            let numel = data.len();
            // Only `numel` is recorded, and buffers are later rebuilt with
            // `capacity == numel`. `shrink_to_fit` does not promise that the
            // capacity ends up equal to the length, whereas a boxed slice
            // always owns an allocation of exactly `numel` elements.
            // `Box::into_raw` releases ownership without running element
            // destructors, matching the previous `mem::forget` behavior.
            let ptr = Box::into_raw(data.into_boxed_slice()) as *mut u8;
            self.cache.insert::<E>(numel, BytesPtr(ptr));
        }
    }
}
/// Gives read-only access to the inner `Vec<E>`, so `Vec` and slice methods
/// can be called directly on a `CachableVec`.
impl<E> std::ops::Deref for CachableVec<E> {
type Target = Vec<E>;
fn deref(&self) -> &Self::Target {
&self.data
}
}
/// Gives mutable access to the inner `Vec<E>`.
impl<E> std::ops::DerefMut for CachableVec<E> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.data
}
}
impl RandomU64 for Cpu {
    /// Draws one `u64` from the device's shared random number generator.
    ///
    /// With the standard-library mutex this panics if the lock is poisoned.
    fn random_u64(&self) -> u64 {
        // Without `no-std` the imported `Mutex` is `std::sync::Mutex`, whose
        // `lock` returns a `Result`; the `spin` mutex returns the guard directly.
        #[cfg(not(feature = "no-std"))]
        let mut rng = self.rng.lock().unwrap();
        #[cfg(feature = "no-std")]
        let mut rng = self.rng.lock();
        rng.gen()
    }
}
impl<E: Unit> Storage<E> for Cpu {
    type Vec = CachableVec<E>;

    /// Allocates storage for `len` elements by delegating to `try_alloc_zeros`.
    fn try_alloc_len(&self, len: usize) -> Result<Self::Vec, Self::Err> {
        self.try_alloc_zeros(len)
    }

    /// Returns the number of elements held by `v`.
    fn len(&self, v: &Self::Vec) -> usize {
        v.len()
    }

    /// Copies the elements yielded by `tensor.iter()` into a new `Vec`, in
    /// iteration order.
    fn tensor_to_vec<S: Shape, T>(&self, tensor: &Tensor<S, E, Self, T>) -> Vec<E> {
        let numel = tensor.shape.num_elements();
        let mut out = Vec::with_capacity(numel);
        // NOTE(review): `next` presumably comes from the imported
        // `LendingIterator` rather than `std::iter::Iterator`, which is why
        // this is a `while let` and not a `for` loop - confirm.
        let mut elems = tensor.iter();
        while let Some(elem) = elems.next() {
            out.push(*elem);
        }
        out
    }
}
impl Synchronize for Cpu {
/// Always returns `Ok(())` without doing any work.
fn try_synchronize(&self) -> Result<(), Self::Err> {
Ok(())
}
}
impl Cache for Cpu {
    /// Turns the allocation cache on.
    fn try_enable_cache(&self) -> Result<(), Self::Err> {
        self.cache.enable();
        Ok(())
    }

    /// Turns the allocation cache off and frees everything it currently holds.
    fn try_disable_cache(&self) -> Result<(), Self::Err> {
        self.cache.disable();
        self.try_empty_cache()
    }

    /// Frees every allocation held by the cache and removes all entries.
    ///
    /// # Panics
    ///
    /// Panics if a key's byte count is not a multiple of its element size, if
    /// the byte count is not below `isize::MAX`, if the key does not describe
    /// a valid layout, or (with the standard-library lock) if the lock is
    /// poisoned.
    fn try_empty_cache(&self) -> Result<(), Self::Err> {
        #[cfg(not(feature = "no-std"))]
        let mut cache = self.cache.allocations.write().unwrap();
        #[cfg(feature = "no-std")]
        let mut cache = self.cache.allocations.write();
        for (&key, allocations) in cache.iter_mut() {
            assert!(key.num_bytes % key.size == 0);
            assert!(key.num_bytes < isize::MAX as usize);
            // Free using the layout recorded in the key rather than rebuilding a
            // `Vec` of a same-alignment integer type. That approach only covered
            // alignments 1/2/4/8 (anything else hit `unreachable!()`) and
            // computed the wrong size when element size and alignment differed.
            let layout = std::alloc::Layout::from_size_align(key.num_bytes, key.alignment)
                .expect("cached allocation key must describe a valid layout");
            for alloc in allocations.drain(..) {
                // Zero-byte buffers were never really allocated (an empty `Vec`
                // holds a dangling pointer), so there is nothing to free.
                if layout.size() != 0 {
                    // SAFETY: NOTE(review) - assumes every cached pointer was
                    // allocated by the global allocator with exactly
                    // `key.num_bytes` bytes at `key.alignment`, as a
                    // `Vec<E>`/`Box<[E]>` of that many elements would be; the
                    // insertion path is outside this block - confirm.
                    unsafe { std::alloc::dealloc(alloc.0, layout) };
                }
            }
        }
        cache.clear();
        Ok(())
    }
}