use std::{alloc::Layout, collections::BTreeMap, vec::Vec};
#[cfg(not(feature = "no-std"))]
use std::sync::RwLock;
#[cfg(feature = "no-std")]
use spin::RwLock;
/// Key used to bucket cached allocations.
///
/// An allocation is only reusable when its total byte count, element size,
/// and element alignment all match, so all three fields participate in
/// `Eq`/`Ord`/`Hash` for the `BTreeMap` lookup.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) struct AllocationKey {
    // Total bytes of the allocation: element count * `size_of::<E>()`.
    pub num_bytes: usize,
    // Size in bytes of a single element (`Layout::size`).
    pub size: usize,
    // Alignment of the element type (`Layout::align`).
    pub alignment: usize,
}
/// A thread-safe cache of reusable allocations, bucketed by
/// [`AllocationKey`].
///
/// The cache starts out disabled (see the `Default` impl): `try_pop` returns
/// `None` and `insert` panics until `enable` is called.
#[derive(Debug)]
pub(crate) struct TensorCache<Ptr> {
    // Buckets of reusable pointers, most-recently-inserted last.
    // Invariant: a key is removed as soon as its `Vec` becomes empty,
    // so no stored `Vec` is ever empty.
    pub(crate) allocations: RwLock<BTreeMap<AllocationKey, Vec<Ptr>>>,
    // Whether the cache is active; toggled via `enable`/`disable`.
    pub(crate) enabled: RwLock<bool>,
}
impl<Ptr> Default for TensorCache<Ptr> {
fn default() -> Self {
Self {
allocations: Default::default(),
enabled: RwLock::new(false),
}
}
}
impl<Ptr> TensorCache<Ptr> {
    /// Returns the number of distinct allocation keys currently cached.
    #[allow(unused)]
    pub(crate) fn len(&self) -> usize {
        #[cfg(not(feature = "no-std"))]
        {
            self.allocations.read().unwrap().len()
        }
        #[cfg(feature = "no-std")]
        {
            self.allocations.read().len()
        }
    }

    /// Returns `true` if the cache is currently accepting inserts and
    /// serving pops.
    pub(crate) fn is_enabled(&self) -> bool {
        #[cfg(not(feature = "no-std"))]
        {
            *self.enabled.read().unwrap()
        }
        #[cfg(feature = "no-std")]
        {
            *self.enabled.read()
        }
    }

    /// Enables the cache: `insert` and `try_pop` become active.
    pub(crate) fn enable(&self) {
        #[cfg(not(feature = "no-std"))]
        {
            *self.enabled.write().unwrap() = true;
        }
        #[cfg(feature = "no-std")]
        {
            *self.enabled.write() = true;
        }
    }

    /// Disables the cache. Note: existing cached allocations are kept,
    /// not drained.
    pub(crate) fn disable(&self) {
        #[cfg(not(feature = "no-std"))]
        {
            *self.enabled.write().unwrap() = false;
        }
        #[cfg(feature = "no-std")]
        {
            *self.enabled.write() = false;
        }
    }

    /// Pops the most recently inserted allocation matching `len` elements
    /// of type `E`, or `None` if the cache is disabled or has no match.
    pub(crate) fn try_pop<E>(&self, len: usize) -> Option<Ptr> {
        if !self.is_enabled() {
            return None;
        }
        let layout = Layout::new::<E>();
        let num_bytes = len * std::mem::size_of::<E>();
        let key = AllocationKey {
            num_bytes,
            size: layout.size(),
            alignment: layout.align(),
        };
        // Take the write lock up front and perform the whole pop under it.
        // The previous check-under-read-lock / pop-under-write-lock split
        // was a check-then-act race: a concurrent `try_pop` could drain and
        // remove the key between the two lock acquisitions, and the
        // subsequent `get_mut(..).unwrap()` / `pop().unwrap()` would panic.
        #[cfg(not(feature = "no-std"))]
        let mut cache = self.allocations.write().unwrap();
        #[cfg(feature = "no-std")]
        let mut cache = self.allocations.write();
        let items = cache.get_mut(&key)?;
        let allocation = items.pop();
        // Maintain the invariant that stored vectors are never empty.
        if items.is_empty() {
            cache.remove(&key);
        }
        allocation
    }

    /// Inserts `allocation` into the bucket keyed by `len` elements of
    /// type `E`.
    ///
    /// # Panics
    /// Panics if the cache is disabled — callers must check `is_enabled()`.
    pub(crate) fn insert<E>(&self, len: usize, allocation: Ptr) {
        if !self.is_enabled() {
            panic!("Tried to insert into a disabled cache.");
        }
        let layout = Layout::new::<E>();
        let num_bytes = len * std::mem::size_of::<E>();
        let key = AllocationKey {
            num_bytes,
            size: layout.size(),
            alignment: layout.align(),
        };
        #[cfg(not(feature = "no-std"))]
        let mut cache = self.allocations.write().unwrap();
        #[cfg(feature = "no-std")]
        let mut cache = self.allocations.write();
        // Entry API: a single map lookup, and `or_default` removes the need
        // for the old Vacant/else double lookup and the std vs no-std
        // `vec!` split.
        cache.entry(key).or_default().push(allocation);
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[should_panic(expected = "Tried to insert into a disabled cache.")]
    fn test_insert_on_disabled_cache() {
        let cache: TensorCache<usize> = Default::default();
        cache.insert::<f32>(1, 0);
    }

    #[test]
    fn test_try_pop_on_disabled_cache() {
        let cache: TensorCache<usize> = Default::default();
        cache.enable();
        assert!(cache.is_enabled());
        cache.disable();
        assert!(!cache.is_enabled());
        // Pops on a disabled cache always miss.
        for _ in 0..2 {
            assert_eq!(cache.try_pop::<f32>(1), None);
        }
    }

    #[test]
    fn test_try_pop_on_empty_cache() {
        let cache: TensorCache<usize> = Default::default();
        cache.enable();
        // Repeated pops on an empty (but enabled) cache always miss.
        for _ in 0..2 {
            assert_eq!(cache.try_pop::<f32>(1), None);
        }
    }

    #[test]
    fn test_try_pop_on_cache_with_multiple_sizes_and_alignment() {
        let cache: TensorCache<usize> = Default::default();
        cache.enable();
        // Three pointers per (type, len) bucket, values 0..12 in
        // insertion order.
        for ptr in 0..3 {
            cache.insert::<f32>(1, ptr);
        }
        for ptr in 3..6 {
            cache.insert::<f32>(2, ptr);
        }
        for ptr in 6..9 {
            cache.insert::<f64>(1, ptr);
        }
        for ptr in 9..12 {
            cache.insert::<f64>(2, ptr);
        }
        // Each bucket pops in LIFO order, then reports empty.
        for ptr in (0..3).rev() {
            assert_eq!(cache.try_pop::<f32>(1), Some(ptr));
        }
        assert_eq!(cache.try_pop::<f32>(1), None);
        for ptr in (3..6).rev() {
            assert_eq!(cache.try_pop::<f32>(2), Some(ptr));
        }
        assert_eq!(cache.try_pop::<f32>(2), None);
        for ptr in (6..9).rev() {
            assert_eq!(cache.try_pop::<f64>(1), Some(ptr));
        }
        assert_eq!(cache.try_pop::<f64>(1), None);
        for ptr in (9..12).rev() {
            assert_eq!(cache.try_pop::<f64>(2), Some(ptr));
        }
        assert_eq!(cache.try_pop::<f64>(2), None);
    }
}