use core::ptr::NonNull;
use std::alloc::Layout;
use super::{BLOCK_CONTENT_MIN_ALIGN, BLOCK_CONTENT_MIN_SIZE, BLOCK_MIN_SIZE};
use crate::header::{Header, HEADER_SIZE};
use crate::util::find_aligned;
/// Rounds `x` up to the nearest multiple of `y`.
///
/// Returns `Some(x)` unchanged when `x` is already a multiple of `y`, and
/// `None` when the rounded-up value does not fit in a `usize`.
///
/// # Panics
///
/// Panics if `y` is zero (the remainder operation divides by `y`).
#[inline]
pub fn find_divisible(x: usize, y: usize) -> Option<usize> {
    match x % y {
        0 => Some(x),
        // `x + (y - rem)` is the next multiple of `y` above `x`; since
        // `rem < y` the subtraction cannot underflow, so only the final
        // addition can overflow.
        rem => x.checked_add(y - rem),
    }
}
/// Finds the address at which an object can be placed inside a block that
/// begins at `block_start`.
///
/// Candidates start at `block_start`. Each subsequent candidate is obtained
/// by advancing one byte and re-aligning to `obj_align` with [`find_aligned`].
/// The first candidate whose distance from `block_start` is either exactly
/// `HEADER_SIZE`, or at least `HEADER_SIZE + BLOCK_MIN_SIZE`, is returned.
/// Presumably this ensures that any gap left in front of the header is
/// either absent or large enough to form a block of its own (TODO confirm).
///
/// `block_start` is only used as an address and is never dereferenced here.
///
/// Returns `None` if the search would run off the end of the address space:
/// either a candidate reaches `usize::MAX`, or `find_aligned` fails.
///
/// # Panics
///
/// The tests expect a panic when `obj_align` is not a power of two (e.g. 7).
/// That check presumably lives in `find_aligned`.
pub fn find_place(block_start: *const u8, obj_align: usize) -> Option<NonNull<u8>> {
    let mut obj_start = block_start;
    loop {
        let dist = obj_start as usize - block_start as usize;
        if dist == HEADER_SIZE || dist >= HEADER_SIZE + BLOCK_MIN_SIZE {
            break;
        }
        if obj_start as usize == usize::MAX {
            return None;
        }
        // `wrapping_add` rather than `add`: `block_start` need not point into
        // a live allocation (callers and tests pass null or small integers),
        // and `add` on such a pointer is undefined behavior. The guard above
        // ensures this addition never actually wraps.
        obj_start = find_aligned(obj_start.wrapping_add(1).cast(), obj_align)?;
    }
    // `NonNull::new` rather than `new_unchecked`: a null result would mean the
    // address arithmetic wrapped around, which must surface as `None`, not UB.
    NonNull::new(obj_start as *mut u8)
}
/// Converts a requested content size into the size actually reserved.
///
/// The size is first raised to at least `BLOCK_CONTENT_MIN_SIZE`, then
/// rounded up to the next multiple of `HEADER_SIZE`.
///
/// # Errors
///
/// Returns `Err(())` if rounding overflows `usize`, or if the rounded size is
/// zero or larger than `isize::MAX`.
#[inline]
pub fn augment_size(size: usize) -> Result<usize, ()> {
    let at_least_min = size.max(BLOCK_CONTENT_MIN_SIZE);
    let rounded = find_divisible(at_least_min, HEADER_SIZE).ok_or(())?;
    // Accept exactly the sizes that are positive when reinterpreted as `isize`.
    if (1..=isize::MAX as usize).contains(&rounded) {
        Ok(rounded)
    } else {
        Err(())
    }
}
/// Adjusts a requested layout into the layout actually used for block content.
///
/// - Alignment: raised to at least `BLOCK_CONTENT_MIN_ALIGN`.
/// - Size: adjusted by [`augment_size`], i.e. raised to at least
///   `BLOCK_CONTENT_MIN_SIZE` and rounded up to a multiple of `HEADER_SIZE`.
///
/// # Errors
///
/// Returns `Err(())` if the adjusted size cannot be represented, or if the
/// adjusted size and alignment do not form a valid [`Layout`].
pub fn augment_layout(layout: Layout) -> Result<Layout, ()> {
    let obj_align = layout.align().max(BLOCK_CONTENT_MIN_ALIGN);
    let obj_size = augment_size(layout.size())?;
    // `augment_size` only bounds the size by `isize::MAX`. `Layout` also
    // requires the size rounded up to `obj_align` to stay within
    // `isize::MAX`, which is not guaranteed when `obj_align` is raised above
    // the caller's alignment. The checked constructor turns such a case into
    // `Err` instead of undefined behavior in release builds.
    Layout::from_size_align(obj_size, obj_align).map_err(|_| ())
}
/// Builds a slice pointer covering an object's content.
///
/// The block [`Header`] is read from the `HEADER_SIZE` bytes immediately
/// preceding `obj_start`. Its `__content_size` field becomes the length of
/// the returned slice. The content bytes themselves are not accessed.
///
/// # Safety
///
/// The caller must guarantee that:
/// - `obj_start.as_ptr().sub(HEADER_SIZE)` lies within the same allocation as
///   `obj_start`;
/// - it points to a valid, properly aligned, initialized [`Header`];
/// - that header is not tagged. This is only checked in debug builds.
#[inline]
pub unsafe fn to_nonnull_slice(obj_start: NonNull<u8>) -> NonNull<[u8]> {
    // SAFETY: the caller guarantees a `Header` lives `HEADER_SIZE` bytes
    // before `obj_start` in the same allocation, so this offset is in bounds.
    let block_header: *const Header = unsafe { obj_start.as_ptr().sub(HEADER_SIZE).cast() };
    // SAFETY: the caller guarantees `block_header` points to a valid `Header`.
    debug_assert!(!unsafe { (*block_header).is_tagged() });
    // SAFETY: same invariant as above; this only reads a field of the header.
    let obj_size = unsafe { (*block_header).__content_size };
    NonNull::slice_from_raw_parts(obj_start, obj_size)
}
#[cfg(test)]
mod tests {
// Unit tests for the placement and layout helpers in the parent module.
// Several tests pass fabricated pointers (null, small integers) to
// `find_place`; the tests themselves only inspect the results as addresses.
use super::*;
use crate::header::HEADER_ALIGN;
use core::ptr::null;
// Exact multiples are returned unchanged; other values round up to `y`.
#[test]
fn test_find_divisible_1() {
assert_eq!(find_divisible(5, 5).unwrap(), 5);
assert_eq!(find_divisible(5, 10).unwrap(), 10);
assert_eq!(find_divisible(0, 100).unwrap(), 0);
}
// A zero divisor is expected to panic (`x % 0` inside `find_divisible`).
#[test]
#[should_panic]
fn test_find_divisible_2() {
let _ = find_divisible(5, 0);
}
// Rounding `usize::MAX` up overflows and yields `None`; a multiple of 8 near
// the top of the range is returned unchanged.
#[test]
fn test_find_divisible_3() {
assert!(find_divisible(usize::MAX, 2).is_none());
assert_eq!(find_divisible(usize::MAX - 7, 8), Some(usize::MAX - 7));
}
// - From a null block start, the object lands exactly `HEADER_SIZE` bytes in.
// - From address 1, the result must be beyond
//   `BLOCK_MIN_SIZE + HEADER_SIZE`. Presumably this is because `1 + HEADER_SIZE`
//   is not aligned to `HEADER_ALIGN`.
// - A block start of `usize::MAX` leaves no room and yields `None`.
#[test]
fn test_find_place_1() {
assert_eq!(
find_place(null(), HEADER_ALIGN).unwrap().as_ptr() as usize,
HEADER_SIZE
);
assert!(
find_place(1 as *const u8, HEADER_ALIGN)
.unwrap()
.as_ptr() as usize
> BLOCK_MIN_SIZE + HEADER_SIZE
);
assert!(find_place(usize::MAX as *const u8, HEADER_ALIGN).is_none());
}
// An alignment of 7 (not a power of two) is expected to panic. The check is
// presumably inside `find_aligned` — TODO confirm.
#[test]
#[should_panic]
fn test_find_place_2() {
find_place(null(), 7);
}
// Sweeps block starts 4000..5000 and alignments 2^0..2^19. Each result must
// be aligned and must satisfy the same distance rule `find_place` enforces:
// exactly `HEADER_SIZE`, or at least `BLOCK_MIN_SIZE + HEADER_SIZE`.
#[test]
fn test_find_place_3() {
for i in 4000..5000 {
for j in 0..20 {
let tmp = find_place(i as *const u8, 1 << j).unwrap().as_ptr();
assert_eq!(tmp as usize % (1 << j), 0);
let diff = tmp as usize - i as usize;
assert!(diff == HEADER_SIZE || diff >= BLOCK_MIN_SIZE + HEADER_SIZE);
}
}
}
// Every size in (HEADER_SIZE, 2 * HEADER_SIZE], with any power-of-two
// alignment up to HEADER_ALIGN, must become the layout
// (2 * HEADER_SIZE, BLOCK_CONTENT_MIN_ALIGN).
// This encodes two assumptions about the constants that are not checked here:
// - BLOCK_CONTENT_MIN_SIZE <= 2 * HEADER_SIZE;
// - BLOCK_CONTENT_MIN_ALIGN >= HEADER_ALIGN.
#[test]
fn test_augment_layout_1() {
for size in HEADER_SIZE + 1..=2 * HEADER_SIZE {
for align in (0..=usize::ilog2(HEADER_ALIGN)).map(|i| 1 << i) {
let layout = Layout::from_size_align(size, align).unwrap();
let augmented = augment_layout(layout).unwrap();
assert_eq!(
augmented,
Layout::from_size_align(HEADER_SIZE * 2, BLOCK_CONTENT_MIN_ALIGN).unwrap()
);
}
}
}
// Setup: a local `Header` records a content size of 20, and the address
// `HEADER_SIZE` bytes past it is used as the object start.
// Check: the returned slice must start there and have that length. The
// content bytes are never read.
// Assumes `false` means "untagged", matching the `debug_assert` in
// `to_nonnull_slice` — TODO confirm against `Header::new_unchecked`.
#[test]
fn test_to_nonnull_slice() {
let obj_size = 20;
let mut header = unsafe { Header::new_unchecked(obj_size, false) };
let block_start: *mut u8 = (&mut header as *mut Header).cast();
let obj_start = unsafe { block_start.add(HEADER_SIZE) };
let obj_as_slice = unsafe { to_nonnull_slice(NonNull::new(obj_start).unwrap()) };
assert_eq!(obj_as_slice.as_ptr() as *mut u8, obj_start);
assert_eq!(obj_as_slice.len(), obj_size);
}
}