use crate::rb_tree::{Color, Direction, RBTree, TreeBucket};
use std::cmp::Ordering;
use std::mem::size_of;
use std::ptr::NonNull;
/// Nullable intrusive link to a tree node; `None` encodes a null child.
type Link<T> = Option<NonNull<T>>;
use super::ALIGN;
/// Smallest region the cache can track: a free region must be able to hold
/// the intrusive `Bucket` header that links it into both trees.
pub const MIN_CACHE_SIZE: usize = size_of::<Bucket>();
/// Most-significant bit of a word.  Child pointers are stored shifted left
/// by one bit, so a pointer with this bit set would lose information; the
/// `debug_assert!`s below guard against that.
///
/// This const must NOT be gated on `#[cfg(debug_assertions)]`: the body of
/// `debug_assert!` expands to `if cfg!(debug_assertions) { assert!(…) }`,
/// so its condition is type-checked even in release builds, and a gated
/// const would make every `debug_assert!(ptr & MSB == 0)` fail to compile
/// with debug assertions disabled.
const MSB: usize = 1 << (size_of::<usize>() * 8 - 1);
/// Intrusive header written at the start of every cached free region.
///
/// Four machine words, each packing one child pointer (shifted left by one
/// bit) in the high bits plus four metadata bits in the low bits.  The
/// metadata carries both red-black colors and the 16-bit region size,
/// scattered across all four words (see `size`/`set_size`).
struct Bucket {
    // Order-tree left child + metadata: bit 0 = size bit 15,
    // bit 1 = size-tree color, bit 2 = order-tree color, bit 3 unused.
    left_order_: usize,
    // Order-tree right child + size bits 14..=11.
    right_order_: usize,
    // Size-tree left child + size bits 10..=7.
    left_size_: usize,
    // Size-tree right child + size bits 6..=3.
    right_size_: usize,
}
impl Bucket {
    /// Decode the child pointer packed into a word.
    ///
    /// Words hold the pointer shifted left by one bit with metadata in the
    /// low four bits; after shifting back, the low three bits (remnants of
    /// the metadata) are masked off.  A zero pointer decodes to `None`.
    fn decode_link(word: usize) -> Link<Self> {
        NonNull::new(((word >> 1) & !0x07) as *mut Self)
    }

    /// Encode `ptr` into `word`, preserving the low four metadata bits.
    fn encode_link(word: &mut usize, ptr: Link<Self>) {
        let ptr = ptr.map_or(0, |ptr| ptr.as_ptr() as usize);
        // The pointer must be at least 8-byte aligned, and its top bit must
        // be clear so the left shift below cannot lose information.
        debug_assert!(ptr & 0x07 == 0);
        debug_assert!(ptr & MSB == 0);
        *word &= 0x0f;
        *word |= ptr << 1;
    }

    /// Left child in the order (address) tree.
    fn left_order(&self) -> Link<Self> {
        Self::decode_link(self.left_order_)
    }
    fn set_left_order(&mut self, ptr: Link<Self>) {
        Self::encode_link(&mut self.left_order_, ptr);
    }

    /// Right child in the order (address) tree.
    fn right_order(&self) -> Link<Self> {
        Self::decode_link(self.right_order_)
    }
    fn set_right_order(&mut self, ptr: Link<Self>) {
        Self::encode_link(&mut self.right_order_, ptr);
    }

    /// Left child in the size tree.
    fn left_size(&self) -> Link<Self> {
        Self::decode_link(self.left_size_)
    }
    fn set_left_size(&mut self, ptr: Link<Self>) {
        Self::encode_link(&mut self.left_size_, ptr);
    }

    /// Right child in the size tree.
    fn right_size(&self) -> Link<Self> {
        Self::decode_link(self.right_size_)
    }
    fn set_right_size(&mut self, ptr: Link<Self>) {
        Self::encode_link(&mut self.right_size_, ptr);
    }

    /// Red-black color of this node in the order tree
    /// (bit 2 of `left_order_`).
    fn order_color(&self) -> Color {
        if self.left_order_ & 0x04 == 0 {
            Color::Black
        } else {
            Color::Red
        }
    }
    fn set_order_color(&mut self, color: Color) {
        match color {
            Color::Black => self.left_order_ &= !0x04,
            Color::Red => self.left_order_ |= 0x04,
        }
    }

    /// Red-black color of this node in the size tree
    /// (bit 1 of `left_order_`).
    fn size_color(&self) -> Color {
        if self.left_order_ & 0x02 == 0 {
            Color::Black
        } else {
            Color::Red
        }
    }
    fn set_size_color(&mut self, color: Color) {
        match color {
            Color::Black => self.left_order_ &= !0x02,
            Color::Red => self.left_order_ |= 0x02,
        }
    }

    /// Reassemble the 16-bit region size scattered over the metadata bits:
    /// bit 15 lives in `left_order_`, bits 14..=11 in `right_order_`,
    /// bits 10..=7 in `left_size_`, and bits 6..=3 in `right_size_`.
    /// Bits 2..=0 are always zero because sizes are multiples of 8.
    fn size(&self) -> usize {
        ((self.left_order_ & 0x01) << 15)
            | ((self.right_order_ & 0x0f) << 11)
            | ((self.left_size_ & 0x0f) << 7)
            | ((self.right_size_ & 0x0f) << 3)
    }

    /// Scatter `size` over the metadata bits (inverse of [`Self::size`]),
    /// leaving the packed child pointers untouched.
    fn set_size(&mut self, size: usize) {
        debug_assert!(size <= u16::MAX as usize);
        debug_assert!(size & 0x07 == 0);
        self.left_order_ &= !0x01;
        self.left_order_ |= size >> 15;
        self.right_order_ &= !0x0f;
        self.right_order_ |= (size >> 11) & 0x0f;
        self.left_size_ &= !0x0f;
        self.left_size_ |= (size >> 7) & 0x0f;
        self.right_size_ &= !0x0f;
        self.right_size_ |= (size >> 3) & 0x0f;
    }
}
/// Size-ordered view of a free region's header: the node type of the
/// best-fit size tree (ordered by size, ties broken by address).
struct SizeBucket(Bucket);
impl SizeBucket {
    /// Record `size` in the bucket header at `ptr`.
    ///
    /// The caller must hand over at least `size_of::<Self>()` bytes of
    /// exclusively-owned, writable memory at `ptr`.
    pub fn init(ptr: NonNull<u8>, size: usize) {
        debug_assert!(size_of::<Self>() <= size);
        debug_assert!(size <= u16::MAX as usize);
        debug_assert!(size % ALIGN == 0);
        let bucket: &mut Self = unsafe { ptr.cast().as_mut() };
        bucket.0.set_size(size);
    }

    /// Size in bytes of the free region this header describes.
    pub fn size(&self) -> usize {
        self.0.size()
    }
}
impl PartialEq<Self> for SizeBucket {
    /// Node identity: two `SizeBucket`s are equal only when they are the
    /// same header in memory (the trees never hold duplicates).
    fn eq(&self, other: &Self) -> bool {
        std::ptr::eq(self, other)
    }
}
// Pointer identity is reflexive, so `Eq` holds.
impl Eq for SizeBucket {}
impl PartialEq<usize> for SizeBucket {
    /// A bare `usize` key matches any bucket of exactly that size; used by
    /// the size tree's keyed lookups.
    fn eq(&self, other: &usize) -> bool {
        self.size() == *other
    }
}
#[cfg(test)]
impl PartialEq<Bucket> for SizeBucket {
    // Test-only convenience: compare against a raw header by identity.
    fn eq(&self, other: &Bucket) -> bool {
        // `SizeBucket` is a newtype over `Bucket`, so the reference
        // transmute is layout-compatible.
        unsafe { self == std::mem::transmute::<&Bucket, &Self>(other) }
    }
}
impl PartialOrd<Self> for SizeBucket {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
if self.0.size() == other.0.size() {
let this: *const SizeBucket = self;
let other: *const SizeBucket = other;
this.partial_cmp(&other)
} else {
self.0.size().partial_cmp(&other.0.size())
}
}
}
impl Ord for SizeBucket {
    /// Primary key: region size.  Ties between equal-sized buckets are
    /// broken by header address so distinct nodes never compare equal.
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs: *const SizeBucket = self;
        let rhs: *const SizeBucket = other;
        self.0.size().cmp(&other.0.size()).then_with(|| lhs.cmp(&rhs))
    }
}
impl PartialOrd<usize> for SizeBucket {
    /// Orders a bucket against a bare size key (used by lower-bound
    /// searches in the size tree).  `usize` ordering is total, so this is
    /// always `Some`.
    fn partial_cmp(&self, other: &usize) -> Option<Ordering> {
        Some(self.size().cmp(other))
    }
}
#[cfg(test)]
impl PartialOrd<Bucket> for SizeBucket {
    // Test-only: order against a raw header by reinterpreting it as the
    // layout-compatible newtype.
    fn partial_cmp(&self, other: &Bucket) -> Option<Ordering> {
        unsafe { self.partial_cmp(std::mem::transmute::<&Bucket, &SizeBucket>(other)) }
    }
}
impl TreeBucket for SizeBucket {
    /// Child links delegate to the size-tree halves of the packed header;
    /// `NonNull::cast` converts between `Bucket` links and this newtype
    /// view (same layout).
    fn child(&self, direction: Direction) -> Link<Self> {
        match direction {
            Direction::Left => self.0.left_size().map(NonNull::cast),
            Direction::Right => self.0.right_size().map(NonNull::cast),
        }
    }
    fn set_child(&mut self, child: Link<Self>, direction: Direction) {
        match direction {
            Direction::Left => self.0.set_left_size(child.map(NonNull::cast)),
            Direction::Right => self.0.set_right_size(child.map(NonNull::cast)),
        }
    }
    /// Red-black color bit reserved for the size tree.
    fn color(&self) -> Color {
        self.0.size_color()
    }
    fn set_color(&mut self, color: Color) {
        self.0.set_size_color(color)
    }
}
/// Address-ordered view of a free region's header: the node type of the
/// order (coalescing) tree.
struct OrderBucket(Bucket);
impl OrderBucket {
    /// Debug-only validation that the header at `ptr` already carries
    /// `_size` (written earlier via `SizeBucket::init`; both views share
    /// the same underlying `Bucket`).  In release builds this is a no-op.
    pub fn init(ptr: NonNull<u8>, _size: usize) {
        debug_assert!(size_of::<Self>() <= _size);
        debug_assert!(_size <= u16::MAX as usize);
        debug_assert!(_size % ALIGN == 0);
        let _this: &Self = unsafe { ptr.cast().as_ref() };
        debug_assert!(_this.size() == _size);
    }
    /// Size in bytes of the free region this header describes.
    pub fn size(&self) -> usize {
        self.0.size()
    }
}
impl PartialEq<Self> for OrderBucket {
    /// Node identity: equal only when both references name the same header
    /// in memory.
    fn eq(&self, other: &Self) -> bool {
        std::ptr::eq(self, other)
    }
}
impl PartialEq<NonNull<u8>> for OrderBucket {
fn eq(&self, other: &NonNull<u8>) -> bool {
let begin: *const u8 = (self as *const Self).cast();
let end: *const u8 = unsafe { begin.add(self.size()) };
let other: *const u8 = other.as_ptr();
begin <= other && other < end
}
}
#[cfg(test)]
impl PartialEq<Bucket> for OrderBucket {
    // Test-only: compare against a raw header via the range-containment
    // equality above (the transmute is layout-compatible: newtype).
    fn eq(&self, other: &Bucket) -> bool {
        unsafe { self == std::mem::transmute::<&Bucket, &Self>(other) }
    }
}
// `PartialEq<Self>` is pointer identity, which is reflexive, so `Eq` holds.
impl Eq for OrderBucket {}
impl PartialOrd<Self> for OrderBucket {
    /// Delegates to [`Ord::cmp`] so the two orderings can never diverge.
    /// Equivalent to the previous hand-rolled version: raw-pointer
    /// comparison between nodes is total, so this was always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl PartialOrd<NonNull<u8>> for OrderBucket {
    /// Orders the bucket against a raw address for tree search: any
    /// address inside the bucket's region compares `Equal` (matching the
    /// `PartialEq` impl above); otherwise the bucket's start address
    /// decides.
    fn partial_cmp(&self, other: &NonNull<u8>) -> Option<Ordering> {
        if self == other {
            Some(Ordering::Equal)
        } else {
            let this: *const u8 = (self as *const Self).cast();
            let other: *const u8 = other.as_ptr();
            this.partial_cmp(&other)
        }
    }
}
#[cfg(test)]
impl PartialOrd<Bucket> for OrderBucket {
    // Test-only: order against a raw header by reinterpreting it as the
    // layout-compatible newtype.
    fn partial_cmp(&self, other: &Bucket) -> Option<Ordering> {
        unsafe { self.partial_cmp(std::mem::transmute::<&Bucket, &OrderBucket>(other)) }
    }
}
impl Ord for OrderBucket {
    /// Total order by header address: the order tree is keyed purely on
    /// where each free region lives in memory.
    fn cmp(&self, other: &Self) -> Ordering {
        (self as *const Self).cmp(&(other as *const Self))
    }
}
impl TreeBucket for OrderBucket {
    /// Child links delegate to the order-tree halves of the packed header;
    /// `NonNull::cast` converts between `Bucket` links and this newtype
    /// view (same layout).
    fn child(&self, direction: Direction) -> Link<Self> {
        match direction {
            Direction::Left => self.0.left_order().map(NonNull::cast),
            Direction::Right => self.0.right_order().map(NonNull::cast),
        }
    }
    fn set_child(&mut self, child: Link<Self>, direction: Direction) {
        match direction {
            Direction::Left => self.0.set_left_order(child.map(NonNull::cast)),
            Direction::Right => self.0.set_right_order(child.map(NonNull::cast)),
        }
    }
    /// Red-black color bit reserved for the order tree.
    fn color(&self) -> Color {
        self.0.order_color()
    }
    fn set_color(&mut self, color: Color) {
        self.0.set_order_color(color)
    }
}
/// Cache of large free regions, indexed twice over the same intrusive
/// headers: by size (best-fit allocation) and by address (coalescing of
/// adjacent regions on deallocation).
pub struct LargeCache {
    // Best-fit index: ordered by (size, address).
    size_tree: RBTree<SizeBucket>,
    // Coalescing index: ordered by address.
    order_tree: RBTree<OrderBucket>,
}
impl LargeCache {
    /// Creates a cache with both trees empty.
    pub const fn new() -> Self {
        Self {
            size_tree: RBTree::new(),
            order_tree: RBTree::new(),
        }
    }

    /// Test-only: both trees index the same set of free regions, so they
    /// must agree on emptiness (checked by the asserts).
    #[cfg(test)]
    pub fn is_empty(&self) -> bool {
        if self.size_tree.is_empty() {
            assert!(self.order_tree.is_empty());
            true
        } else {
            assert_eq!(self.order_tree.is_empty(), false);
            false
        }
    }

    /// Allocates at least `size` bytes from the cache.
    ///
    /// Returns the region start and its actual size, or `None` when no
    /// cached region is large enough.  `size` must be a positive multiple
    /// of `ALIGN`.
    pub fn alloc(&mut self, size: usize) -> Option<(NonNull<u8>, usize)> {
        debug_assert!(size % ALIGN == 0);
        debug_assert!(0 < size);
        unsafe {
            // Presumably the smallest region with at least `size` bytes
            // (best fit) — TODO confirm against rb_tree's
            // `remove_lower_bound` contract.
            let mut ptr = self.size_tree.remove_lower_bound(&size)?;
            let alloc_size = ptr.as_ref().size();
            let rest_size = alloc_size - size;
            if size_of::<Bucket>() <= rest_size {
                // Split: the remainder keeps the original address, so its
                // order-tree node stays valid; only its size (and hence
                // its size-tree position) changes.
                let size_bucket = ptr.as_mut();
                SizeBucket::init(ptr.cast(), rest_size);
                self.size_tree.insert(size_bucket);
                // The caller receives the tail of the region.
                let ret = ptr.as_ptr().cast::<u8>().add(rest_size);
                Some((NonNull::new_unchecked(ret), size))
            } else {
                // Remainder too small to track: hand out the whole region
                // and drop it from the order tree as well.
                let order_bucket: &mut OrderBucket = ptr.cast().as_mut();
                self.order_tree.remove(order_bucket);
                Some((ptr.cast(), alloc_size))
            }
        }
    }

    /// Caches `[ptr, ptr + size)` without trying to coalesce it with
    /// adjacent free regions.  Returns `false` (region not tracked) when
    /// `size` cannot hold a bucket header.
    pub fn dealloc_without_merge(&mut self, ptr: NonNull<u8>, size: usize) -> bool {
        debug_assert!(ptr.as_ptr() as usize % ALIGN == 0);
        debug_assert!(size % ALIGN == 0);
        if size < size_of::<Bucket>() {
            false
        } else {
            unsafe {
                // `SizeBucket::init` writes the size; `OrderBucket::init`
                // only debug-checks it (both views share one header).
                SizeBucket::init(ptr, size);
                self.size_tree.insert(ptr.cast().as_mut());
                OrderBucket::init(ptr, size);
                self.order_tree.insert(ptr.cast().as_mut());
            }
            true
        }
    }

    /// Caches `[ptr, ptr + size)`, coalescing with the free regions that
    /// immediately follow and/or precede it.  Returns `false` (region not
    /// tracked) only when the region is isolated and too small to hold a
    /// bucket header.
    pub fn dealloc(&mut self, ptr: NonNull<u8>, size: usize) -> bool {
        debug_assert!(ptr.as_ptr() as usize % ALIGN == 0);
        debug_assert!(size % ALIGN == 0);
        // Phase 1: merge with the free region that starts exactly at our
        // end, if any (order-tree equality is range containment, and
        // `ptr + size` can only be contained as a region's first byte).
        let size = unsafe {
            let next_ptr = NonNull::new_unchecked(ptr.as_ptr().add(size));
            match self.order_tree.remove(&next_ptr) {
                None => size,
                Some(ptr) => {
                    // Absorb the neighbor: drop it from both trees and
                    // grow our size.  NOTE(review): the merged size is not
                    // checked against the u16::MAX limit asserted by
                    // `SizeBucket::init` — confirm callers bound region
                    // sizes.
                    let size_bucket: &SizeBucket = ptr.cast().as_ref();
                    self.size_tree.remove(size_bucket);
                    size + size_bucket.size()
                }
            }
        };
        // Phase 2: merge into the free region that ends exactly at `ptr`.
        unsafe {
            // `ptr - 1` lies inside the preceding region exactly when that
            // region is free and ends at `ptr`.
            let prev_ptr = NonNull::new_unchecked(ptr.as_ptr().offset(-1));
            match self.order_tree.find(&prev_ptr) {
                None => {
                    // No predecessor: track the region on its own, unless
                    // it cannot hold a header.
                    if size < size_of::<Bucket>() {
                        return false;
                    }
                    SizeBucket::init(ptr, size);
                    self.size_tree.insert(ptr.cast().as_mut());
                    OrderBucket::init(ptr, size);
                    self.order_tree.insert(ptr.cast().as_mut());
                }
                Some(prev_ptr) => {
                    // Grow the predecessor in place: its address (the
                    // order-tree key) is unchanged, so only its size-tree
                    // node must be re-inserted under the new size.
                    let order_bucket = prev_ptr.as_ref();
                    let size = size + order_bucket.size();
                    let size_bucket: &mut SizeBucket = prev_ptr.cast().as_mut();
                    self.size_tree.remove(size_bucket);
                    SizeBucket::init(prev_ptr.cast(), size);
                    self.size_tree.insert(size_bucket);
                }
            }
        }
        true
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_alloc() {
    let mut cache = LargeCache::new();
    type Block = [usize; 16];
    // Zero-initialized backing storage.  (Replaces `with_capacity` +
    // `set_len`, which exposed uninitialized memory — undefined behavior.)
    let mut blocks: Vec<Block> = vec![[0; 16]; 1024];
    // Hand every other block to the cache so no two free regions are
    // adjacent and no coalescing can occur.
    for block in blocks.iter_mut().step_by(2) {
        let size = size_of::<Block>();
        let ptr = NonNull::from(block);
        cache.dealloc(ptr.cast(), size);
    }
    for _ in 0..8 {
        let mut pointers = Vec::new();
        // 512 allocations with sizes cycling over ALIGN..=block size.
        for size in (ALIGN..=size_of::<Block>()).cycle().take(512) {
            let size = size - (size % ALIGN);
            let allocated = cache.alloc(size);
            assert!(allocated.is_some());
            let (ptr, s) = allocated.unwrap();
            assert!(ptr.as_ptr() as usize % ALIGN == 0);
            assert!(size <= s);
            // Best fit never over-allocates by a full header's worth.
            assert!(s < size + size_of::<Bucket>());
            // Scribble over the whole allocation to catch overlap bugs.
            unsafe { ptr.as_ptr().write_bytes(0xff, s) };
            pointers.push((ptr, s));
        }
        for (ptr, size) in pointers {
            cache.dealloc(ptr, size);
        }
    }
}
#[test]
fn test_alloc_fraction() {
    // Zero-initialized, contiguous headers; `dealloc` overwrites them.
    // (Replaces `with_capacity` + `set_len`, which exposed uninitialized
    // memory — undefined behavior.)
    let new_bucket = || Bucket {
        left_order_: 0,
        right_order_: 0,
        left_size_: 0,
        right_size_: 0,
    };
    let mut buckets: Vec<Bucket> = (0..3).map(|_| new_bucket()).collect();
    {
        // A sub-header request against a minimum-sized bucket must hand
        // out the whole bucket (the remainder is too small to keep).
        let mut cache = LargeCache::new();
        {
            let ptr = NonNull::from(&mut buckets[0]);
            let size = size_of::<Bucket>();
            cache.dealloc(ptr.cast(), size);
        }
        {
            let size = ALIGN;
            let allocated = cache.alloc(size);
            assert!(allocated.is_some());
            let (ptr, s) = allocated.unwrap();
            assert!(size <= s);
            unsafe { ptr.as_ptr().write_bytes(0xff, s) };
        }
        assert!(cache.is_empty());
    }
    {
        // Small request splits a double-sized bucket; the remainder is
        // then allocatable as a full header-sized region.
        let mut cache = LargeCache::new();
        {
            let size = 2 * size_of::<Bucket>();
            let ptr = NonNull::from(&mut buckets[0]);
            cache.dealloc(ptr.cast(), size);
        }
        {
            let size = ALIGN;
            let allocated = cache.alloc(size);
            assert!(allocated.is_some());
            let (ptr, s) = allocated.unwrap();
            assert!(size <= s);
            unsafe { ptr.as_ptr().write_bytes(0xff, s) };
        }
        {
            let size = size_of::<Bucket>();
            let allocated = cache.alloc(size);
            assert!(allocated.is_some());
            let (ptr, s) = allocated.unwrap();
            assert!(size <= s);
            unsafe { ptr.as_ptr().write_bytes(0xff, s) };
        }
        assert!(cache.is_empty());
    }
    {
        // Same as above with the allocation order reversed.
        let mut cache = LargeCache::new();
        {
            let size = 2 * size_of::<Bucket>();
            let ptr = NonNull::from(&mut buckets[0]);
            cache.dealloc(ptr.cast(), size);
        }
        {
            let size = size_of::<Bucket>();
            let allocated = cache.alloc(size);
            assert!(allocated.is_some());
            let (ptr, s) = allocated.unwrap();
            assert!(size <= s);
            unsafe { ptr.as_ptr().write_bytes(0xff, s) };
        }
        {
            let size = ALIGN;
            let allocated = cache.alloc(size);
            assert!(allocated.is_some());
            let (ptr, s) = allocated.unwrap();
            assert!(size <= s);
            unsafe { ptr.as_ptr().write_bytes(0xff, s) };
        }
        assert!(cache.is_empty());
    }
    {
        // Two non-adjacent minimum-sized buckets each satisfy one small
        // request (whole-bucket handout, no splitting possible).
        let mut cache = LargeCache::new();
        {
            let size = size_of::<Bucket>();
            let ptr = NonNull::from(&mut buckets[0]);
            cache.dealloc(ptr.cast(), size);
            let ptr = NonNull::from(&mut buckets[2]);
            cache.dealloc(ptr.cast(), size);
        }
        for _ in 0..2 {
            let size = ALIGN;
            let allocated = cache.alloc(size);
            assert!(allocated.is_some());
            let (ptr, s) = allocated.unwrap();
            assert!(size <= s);
            unsafe { ptr.as_ptr().write_bytes(0xff, s) };
        }
        assert!(cache.is_empty());
    }
}
#[test]
fn test_dealloc_merge() {
    unsafe {
        // Five zero-initialized, contiguous headers.  (Replaces
        // `with_capacity` + `set_len`, which exposed uninitialized memory
        // — undefined behavior.)
        let mut buckets: Vec<Bucket> = (0..5)
            .map(|_| Bucket {
                left_order_: 0,
                right_order_: 0,
                left_size_: 0,
                right_size_: 0,
            })
            .collect();
        let mut cache = LargeCache::new();
        let size = size_of::<Bucket>();
        {
            // Free [0]: tracked alone.
            cache.dealloc(NonNull::from(&mut buckets[0]).cast(), size);
            let size_ptr = cache.size_tree.find(&buckets[0]);
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == size_of::<Bucket>());
            let order_ptr = cache.order_tree.find(&buckets[0]);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == size_of::<Bucket>());
            for i in 1..5 {
                assert!(cache.size_tree.find(&buckets[i]).is_none());
                assert!(cache.order_tree.find(&buckets[i]).is_none());
            }
        }
        {
            // Free [4]: not adjacent to [0] — two separate regions.
            cache.dealloc(NonNull::from(&mut buckets[4]).cast(), size);
            let size_ptr = cache.size_tree.find(&buckets[0]);
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == size_of::<Bucket>());
            let order_ptr = cache.order_tree.find(&buckets[0]);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == size_of::<Bucket>());
            let size_ptr = cache.size_tree.find(&buckets[4]);
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == size_of::<Bucket>());
            let order_ptr = cache.order_tree.find(&buckets[4]);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == size_of::<Bucket>());
            for i in 1..4 {
                assert!(cache.size_tree.find(&buckets[i]).is_none());
                assert!(cache.order_tree.find(&buckets[i]).is_none());
            }
        }
        {
            // Free [1]: merges backward into [0] -> region [0..2).
            cache.dealloc(NonNull::from(&mut buckets[1]).cast(), size);
            let size_ptr = cache.size_tree.find(&buckets[0]);
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == 2 * size_of::<Bucket>());
            let order_ptr = cache.order_tree.find(&buckets[0]);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == 2 * size_of::<Bucket>());
            let size_ptr = cache.size_tree.find(&buckets[4]);
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == size_of::<Bucket>());
            let order_ptr = cache.order_tree.find(&buckets[4]);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == size_of::<Bucket>());
            for i in 2..4 {
                assert!(cache.size_tree.find(&buckets[i]).is_none());
                assert!(cache.order_tree.find(&buckets[i]).is_none());
            }
        }
        {
            // Free [3]: merges forward with [4] -> region [3..5).
            cache.dealloc(NonNull::from(&mut buckets[3]).cast(), size);
            let size_ptr = cache.size_tree.find(&buckets[0]);
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == 2 * size_of::<Bucket>());
            let order_ptr = cache.order_tree.find(&buckets[0]);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == 2 * size_of::<Bucket>());
            let size_ptr = cache.size_tree.find(&buckets[3]);
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == 2 * size_of::<Bucket>());
            let order_ptr = cache.order_tree.find(&buckets[3]);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == 2 * size_of::<Bucket>());
            for i in 2..3 {
                assert!(cache.size_tree.find(&buckets[i]).is_none());
                assert!(cache.order_tree.find(&buckets[i]).is_none());
            }
        }
        {
            // Free [2]: bridges both neighbors -> one region [0..5).
            cache.dealloc(NonNull::from(&mut buckets[2]).cast(), size);
            let size_ptr = cache.size_tree.find(&buckets[0]);
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == 5 * size_of::<Bucket>());
            let order_ptr = cache.order_tree.find(&buckets[0]);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == 5 * size_of::<Bucket>());
            for i in 1..5 {
                assert!(cache.size_tree.find(&buckets[i]).is_none());
                assert!(cache.order_tree.find(&buckets[i]).is_none());
            }
        }
    }
}
#[test]
fn test_dealloc_small_merge() {
    unsafe {
        // Three zero-initialized, contiguous headers.  (Replaces
        // `with_capacity` + `set_len`, which exposed uninitialized memory
        // — undefined behavior.)
        let mut buckets: Vec<Bucket> = (0..3)
            .map(|_| Bucket {
                left_order_: 0,
                right_order_: 0,
                left_size_: 0,
                right_size_: 0,
            })
            .collect();
        let mut cache = LargeCache::new();
        let size = size_of::<Bucket>();
        {
            // Free [1]: tracked alone.
            assert!(cache.dealloc(NonNull::from(&mut buckets[1]).cast(), size));
            let size_ptr = cache.size_tree.find(&buckets[1]);
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == size_of::<Bucket>());
            let order_ptr = cache.order_tree.find(&buckets[1]);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == size_of::<Bucket>());
            for i in [0, 2] {
                assert!(cache.size_tree.find(&buckets[i]).is_none());
                assert!(cache.order_tree.find(&buckets[i]).is_none());
            }
        }
        {
            // An isolated sub-header fragment cannot be tracked.
            assert!(cache.dealloc(NonNull::from(&mut buckets[0]).cast(), ALIGN) == false);
            for i in [0, 2] {
                assert!(cache.size_tree.find(&buckets[i]).is_none());
                assert!(cache.order_tree.find(&buckets[i]).is_none());
            }
        }
        {
            // A fragment ending exactly at the free region merges forward.
            let ptr: *mut u8 = (&mut buckets[1] as *mut Bucket).cast();
            let ptr = NonNull::new(ptr.sub(ALIGN)).unwrap();
            assert!(cache.dealloc(ptr, ALIGN) == true);
            let size_ptr = cache.size_tree.find(&(size_of::<Bucket>() + ALIGN));
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == size_of::<Bucket>() + ALIGN);
            let order_ptr = cache.order_tree.find(&ptr);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == size_of::<Bucket>() + ALIGN);
            for i in [0, 2] {
                assert!(cache.size_tree.find(&buckets[i]).is_none());
                assert!(cache.order_tree.find(&buckets[i]).is_none());
            }
        }
        {
            // A fragment starting exactly at the region's end merges
            // backward.
            assert!(cache.dealloc(NonNull::from(&mut buckets[2]).cast(), ALIGN) == true);
            let size_ptr = cache.size_tree.find(&(size_of::<Bucket>() + 2 * ALIGN));
            assert!(size_ptr.is_some());
            assert!(size_ptr.unwrap().as_ref().size() == size_of::<Bucket>() + 2 * ALIGN);
            let ptr: *mut u8 = (&mut buckets[1] as *mut Bucket).cast();
            let ptr = NonNull::new(ptr.sub(ALIGN)).unwrap();
            let order_ptr = cache.order_tree.find(&ptr);
            assert!(order_ptr.is_some());
            assert!(order_ptr.unwrap().as_ref().size() == size_of::<Bucket>() + 2 * ALIGN);
        }
    }
}
}