use crate::{ArrayData, Primitive};
use std::{
alloc::{self, Layout},
any,
borrow::Borrow,
fmt::{Debug, Formatter, Result},
iter::{Copied, FromIterator},
mem,
ops::Deref,
ptr::{self, NonNull},
slice::{self, Iter},
};
/// Alignment exponent with value 6. When passed as the `A` parameter of [`Buffer`] or
/// [`layout`], which align to `1 << A` bytes, it selects 64-byte alignment.
// NOTE(review): presumably the crate-wide default for `A`; its uses are outside this file — confirm.
pub(crate) const ALIGNMENT: usize = 6;
/// A contiguous sequence of `len` values of type `T` stored in memory obtained from the
/// global allocator.
///
/// Buffers built in this file allocate with `layout::<T, A>(len)`: the allocation is aligned
/// to `1 << A` bytes and its size is padded to a multiple of that alignment. A buffer with
/// `len == 0` is treated as owning no allocation (`Drop` skips deallocation), and the
/// `Default` buffer uses a dangling pointer.
pub struct Buffer<T, const A: usize>
where
    T: Primitive,
{
    /// Start of the values; dangling for the empty `Default` buffer.
    ptr: NonNull<T>,
    /// Number of initialized values starting at `ptr`.
    len: usize,
}
impl<T, const A: usize> Buffer<T, A>
where
    T: Primitive,
{
    /// Wraps a raw pointer and a length in a `Buffer` without any checks.
    ///
    /// # Safety
    ///
    /// - `ptr` must be non-null.
    /// - If `len != 0`, `ptr` must have been allocated by the global allocator with
    ///   `layout::<T, A>(len)` and must point to `len` initialized values of `T`. The returned
    ///   buffer takes ownership of that allocation and deallocates it when dropped.
    /// - If `len == 0`, the buffer takes no ownership and never deallocates `ptr`.
    pub(crate) unsafe fn new_unchecked(ptr: *mut T, len: usize) -> Self {
        Self {
            ptr: NonNull::new_unchecked(ptr),
            len,
        }
    }

    /// Returns the layout of this buffer's allocation.
    ///
    /// Only meaningful for non-empty buffers: `layout` rejects a length of zero.
    fn layout(&self) -> Layout {
        layout::<T, A>(self.len())
    }

    /// Copies `slice` into a newly allocated buffer aligned to `1 << A` bytes.
    ///
    /// An empty slice produces an empty buffer without allocating. Calling `layout` with a
    /// length of zero would panic, so without this check cloning an empty buffer panicked.
    fn from_slice(slice: &[T]) -> Self {
        if slice.is_empty() {
            return Self {
                ptr: NonNull::dangling(),
                len: 0,
            };
        }
        // SAFETY: `slice` is non-empty, so the requested layout is not zero-sized.
        let ptr = unsafe { alloc::<T, A>(layout::<T, A>(slice.len())) };
        // SAFETY: `ptr` is a fresh allocation with room for `slice.len()` values, so it is valid
        // for those writes and cannot overlap `slice`.
        unsafe { ptr::copy_nonoverlapping(slice.as_ptr(), ptr, slice.len()) };
        Self {
            // SAFETY: `alloc` panics instead of returning a null pointer.
            ptr: unsafe { NonNull::new_unchecked(ptr) },
            len: slice.len(),
        }
    }
}
/// Shared out-of-bounds panic for the index checks of `Buffer`'s `ArrayData` impl.
///
/// Marked cold and never inlined so the formatting code stays out of the callers' fast paths.
#[cold]
#[inline(never)]
fn index_out_of_bounds(method: &str, index: usize, len: usize) -> ! {
    panic!("{} index (is {}) should be < len (is {})", method, index, len);
}

impl<T, const A: usize> ArrayData for Buffer<T, A>
where
    T: Primitive,
{
    fn len(&self) -> usize {
        self.len
    }

    /// A `Buffer` stores no validity information, so every in-bounds value is non-null.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    fn is_null(&self, index: usize) -> bool {
        let len = self.len();
        if index < len {
            false
        } else {
            index_out_of_bounds("is_null", index, len)
        }
    }

    /// Always zero: a `Buffer` never contains nulls.
    fn null_count(&self) -> usize {
        0
    }

    /// Every in-bounds value of a `Buffer` is valid.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    fn is_valid(&self, index: usize) -> bool {
        let len = self.len();
        if index < len {
            true
        } else {
            index_out_of_bounds("is_valid", index, len)
        }
    }

    /// Every value is valid, so this equals the length.
    fn valid_count(&self) -> usize {
        self.len
    }
}
impl<T, const A: usize> AsRef<Buffer<T, A>> for Buffer<T, A>
where
    T: Primitive,
{
    /// Identity conversion, so APIs generic over `AsRef<Buffer<T, A>>` accept a `Buffer` itself.
    fn as_ref(&self) -> &Self {
        self
    }
}
impl<T, const A: usize> AsRef<[u8]> for Buffer<T, A>
where
    T: Primitive,
{
    /// Views the buffer's values as their raw in-memory bytes (`len * size_of::<T>()` bytes).
    ///
    /// Padding bytes added to the allocation by `layout` are not included.
    fn as_ref(&self) -> &[u8] {
        let values: &[T] = self;
        // SAFETY: `values` spans exactly `size_of_val(values)` bytes starting at its pointer, and
        // `u8` has alignment 1. NOTE(review): this assumes `Primitive` types have no padding
        // (uninitialized) bytes — confirm in the definition of `Primitive`.
        unsafe { slice::from_raw_parts(values.as_ptr().cast::<u8>(), mem::size_of_val(values)) }
    }
}
impl<T, const A: usize> Borrow<[T]> for Buffer<T, A>
where
    T: Primitive,
{
    /// Borrows the values as a slice, the same view `Deref` provides.
    fn borrow(&self) -> &[T] {
        &self[..]
    }
}
impl<T, const A: usize> Clone for Buffer<T, A>
where
    T: Primitive,
{
    /// Deep-copies the values into a new allocation with the same alignment parameter `A`.
    fn clone(&self) -> Self {
        let values: &[T] = self;
        Self::from_slice(values)
    }
}
impl<T, const A: usize> Debug for Buffer<T, A>
where
    T: Primitive + Debug,
{
    /// Formats as `Buffer<T, A> { values: [..] }`.
    ///
    /// Both the element type name and the alignment exponent are included in the struct name.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
        let name = format!("Buffer<{}, {}>", any::type_name::<T>(), A);
        let values: &[T] = self;
        let mut builder = f.debug_struct(&name);
        builder.field("values", &values);
        builder.finish()
    }
}
impl<T, const A: usize> Default for Buffer<T, A>
where
    T: Primitive,
{
    /// Creates an empty buffer without allocating.
    ///
    /// The pointer is `NonNull::dangling()`. Because `len` is zero, `Drop` never tries to
    /// free it.
    fn default() -> Self {
        let ptr = NonNull::dangling();
        Self { len: 0, ptr }
    }
}
impl<T, const A: usize> Deref for Buffer<T, A>
where
    T: Primitive,
{
    type Target = [T];

    /// Views the buffer's `len` values as a slice.
    fn deref(&self) -> &Self::Target {
        let (start, count) = (self.ptr.as_ptr(), self.len);
        // SAFETY: every constructor in this file leaves `ptr` pointing at `len` initialized
        // values (or dangling with `len == 0`, which is a valid empty slice). The memory stays
        // alive while `self` is borrowed.
        unsafe { slice::from_raw_parts(start, count) }
    }
}
impl<T, const A: usize> Drop for Buffer<T, A>
where
    T: Primitive,
{
    /// Returns the allocation to the global allocator.
    ///
    /// Empty buffers own no allocation and are skipped.
    fn drop(&mut self) {
        if self.len == 0 {
            return;
        }
        let layout = self.layout();
        // SAFETY: a non-empty buffer's pointer was obtained from the global allocator with
        // `layout::<T, A>(self.len)`, which is exactly `self.layout()`.
        unsafe { alloc::dealloc(self.ptr.as_ptr().cast::<u8>(), layout) }
    }
}
/// Marks buffer equality as a total equivalence relation.
///
/// This requires `T: Eq` rather than only `T: PartialEq`. For floating-point element types
/// `NaN != NaN`, so a buffer containing `NaN` is not equal to itself and must not claim `Eq`.
/// `T: Eq` also satisfies the `PartialEq` impl's `for<'a> &'a [T]: PartialEq` bound.
impl<T, const A: usize> Eq for Buffer<T, A>
where
    T: Primitive + Eq,
{
}
impl<T, const A: usize> FromIterator<T> for Buffer<T, A>
where
T: Primitive,
{
fn from_iter<I>(iter: I) -> Self
where
I: IntoIterator<Item = T>,
{
let mut iter = iter.into_iter();
match iter.next() {
Some(value) => {
let (lower_bound, _) = iter.size_hint();
let mut ptr = unsafe { alloc::<T, A>(layout::<T, A>(lower_bound + 1)) };
unsafe { ptr.write(value) };
let mut len = 1;
while len < lower_bound {
unsafe {
ptr.add(len).write(
iter.next()
.expect("reported lower bound of size hint incorrect"),
);
}
len += 1;
}
for value in iter {
ptr = unsafe { realloc::<T, A, A>(ptr, len, len + 1) };
unsafe { ptr.add(len).write(value) };
len += 1;
}
Self {
ptr: unsafe { NonNull::new_unchecked(ptr) },
len,
}
}
None => Self::default(),
}
}
}
impl<'a, T, const A: usize> IntoIterator for &'a Buffer<T, A>
where
    T: Primitive,
{
    type Item = T;
    type IntoIter = Copied<Iter<'a, T>>;

    /// Iterates over the buffer's values by copy, in order.
    fn into_iter(self) -> Self::IntoIter {
        let values: &'a [T] = self;
        values.iter().copied()
    }
}
impl<T, const A: usize> PartialEq for Buffer<T, A>
where
    T: Primitive,
    for<'a> &'a [T]: PartialEq,
{
    /// Two buffers are equal when they hold the same values in the same order.
    ///
    /// Buffers of the same type and length always have the same allocation layout. Comparing
    /// layouts is therefore redundant, and comparing the value slices already checks the
    /// lengths. This avoids recomputing a layout on every comparison.
    fn eq(&self, other: &Self) -> bool {
        let lhs: &[T] = self;
        let rhs: &[T] = other;
        lhs == rhs
    }
}
// SAFETY: a `Buffer` uniquely owns its allocation: `Clone` deep-copies it and `Drop` frees it.
// This file exposes no mutable or interior-mutable access to the values, so sending or sharing
// a `Buffer` across threads is as safe as doing so with the values themselves.
// NOTE(review): this assumes every `Primitive` type is itself `Send + Sync` — confirm in the
// definition of `Primitive`.
unsafe impl<T, const A: usize> Send for Buffer<T, A> where T: Primitive {}
unsafe impl<T, const A: usize> Sync for Buffer<T, A> where T: Primitive {}
/// Returns the layout for an allocation holding `length` values of `T`.
///
/// The layout is aligned to `1 << A` bytes, and its size is rounded up to a multiple of that
/// alignment.
///
/// # Panics
///
/// - `"Zero-sized layouts are not supported"` if `length` is zero or `T` is zero-sized. The
///   global allocator must never be asked for zero bytes.
/// - `"Alignment exponent `A` is too large"` if `1 << A` does not fit in a `usize`.
/// - `"Alignment `A` must be a multiple of ..."` if `1 << A` is not a multiple of
///   `align_of::<T>()`.
/// - `"Allocation size overflow"` if `length * size_of::<T>()` overflows, or if the padded
///   size exceeds `isize::MAX`, the limit that [`Layout`] requires.
pub(crate) fn layout<T, const A: usize>(length: usize) -> Layout
where
    T: Primitive,
{
    assert!(
        length != 0 && mem::size_of::<T>() != 0,
        "Zero-sized layouts are not supported"
    );
    assert!(
        A < mem::size_of::<usize>() * 8,
        "Alignment exponent `A` is too large"
    );
    let align = 1usize << A;
    assert!(
        align % mem::align_of::<T>() == 0,
        "Alignment `A` must be a multiple of the ABI-required minimum alignment of type `T`"
    );
    let size = length
        .checked_mul(mem::size_of::<T>())
        .expect("Allocation size overflow");
    // `from_size_align` validates what `from_size_align_unchecked` would silently assume:
    // a power-of-two alignment, and a size that stays within `isize::MAX` after padding.
    Layout::from_size_align(size, align)
        .expect("Allocation size overflow")
        .pad_to_align()
}
/// Allocates uninitialized memory described by `layout` from the global allocator and returns
/// it as a `*mut T`.
///
/// The `A` parameter is not used by this function.
///
/// # Safety
///
/// `layout` must have a non-zero size, as required by `std::alloc::alloc`. The returned memory
/// is uninitialized and must be released with `std::alloc::dealloc` using the same `layout`.
///
/// # Panics
///
/// Panics with `"Allocation failed"` if the allocator returns a null pointer.
pub(crate) unsafe fn alloc<T, const A: usize>(layout: Layout) -> *mut T
where
    T: Primitive,
{
    let raw: *mut T = alloc::alloc(layout).cast();
    if raw.is_null() {
        panic!("Allocation failed");
    }
    raw
}
/// Moves values from an allocation made with `layout::<T, A>(old_length)` into one made with
/// `layout::<T, B>(new_length)`.
///
/// If both layouts are identical, `ptr` is returned unchanged. Otherwise:
/// - a new allocation is made;
/// - the first `min(old_length, new_length)` values are copied into it;
/// - the old allocation is freed.
///
/// # Safety
///
/// - `ptr` must have been allocated by the global allocator with `layout::<T, A>(old_length)`.
/// - Both lengths must be non-zero.
/// - After the call, `ptr` must only be used if it is the pointer that was returned.
///
/// # Panics
///
/// Panics if either layout cannot be computed (see `layout`) or if allocation fails.
pub(crate) unsafe fn realloc<T, const A: usize, const B: usize>(
    ptr: *mut T,
    old_length: usize,
    new_length: usize,
) -> *mut T
where
    T: Primitive,
{
    let old_layout = layout::<T, A>(old_length);
    let new_layout = layout::<T, B>(new_length);
    if old_layout == new_layout {
        return ptr;
    }
    let new_ptr = alloc::<T, B>(new_layout);
    // Only copy what fits: when shrinking, copying `old_length` values would write past the end
    // of the new allocation.
    ptr::copy_nonoverlapping(ptr as *const T, new_ptr, old_length.min(new_length));
    alloc::dealloc(ptr as *mut u8, old_layout);
    new_ptr
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic(expected = "Zero-sized layouts are not supported")]
    fn layout_zero_sized() {
        layout::<u8, 0>(0).size();
    }

    #[test]
    #[should_panic(expected = "Allocation size overflow")]
    fn layout_overflow() {
        layout::<u8, 6>(usize::MAX - 62);
    }

    #[test]
    #[should_panic(
        expected = "Alignment `A` must be a multiple of the ABI-required minimum alignment of type `T`"
    )]
    fn layout_bad_align() {
        layout::<u64, 0>(1234);
    }

    #[test]
    fn layout_size() {
        assert_eq!(layout::<u8, 0>(1).size(), 1);
        assert_eq!(layout::<u8, 5>(1).size(), 32);
        assert_eq!(layout::<u8, 5>(32).size(), 32);
        assert_eq!(layout::<u8, 5>(33).size(), 64);
        assert_eq!(layout::<u8, 6>(1).size(), 64);
        assert_eq!(layout::<u8, 6>(64).size(), 64);
        assert_eq!(layout::<u8, 6>(65).size(), 128);
        assert_eq!(layout::<u32, 6>(5).size(), 64);
        assert_eq!(layout::<f64, 6>(8).size(), 64);
        assert_eq!(layout::<f64, 6>(9).size(), 128);
    }

    #[test]
    fn as_ref() {
        let buffer: Buffer<_, 7> = [1u32, 2, 3, 4].iter().copied().collect();
        let x: &Buffer<_, 7> = buffer.as_ref();
        assert_eq!(x.len(), 4);
        let x: &[u8] = buffer.as_ref();
        assert_eq!(x.len(), 4 * 4);
    }

    #[test]
    fn as_ref_u8() {
        let vec = vec![42u32, u32::MAX, 0xc0fefe];
        // `AsRef<[u8]>` exposes the values' in-memory (native-endian) bytes, so derive the
        // expectation from `to_ne_bytes` instead of assuming a little-endian target.
        let mut expected = Vec::new();
        for value in &vec {
            expected.extend_from_slice(&value.to_ne_bytes());
        }
        let buffer: Buffer<_, 7> = vec.into_iter().collect();
        assert_eq!(AsRef::<[u8]>::as_ref(&buffer), &expected[..]);
        if cfg!(target_endian = "little") {
            assert_eq!(
                AsRef::<[u8]>::as_ref(&buffer),
                &[42u8, 0, 0, 0, 255, 255, 255, 255, 254, 254, 192, 0]
            );
        }
    }

    #[test]
    fn borrow() {
        let buffer: Buffer<_, 7> = [1u32, 2, 3, 4].iter().copied().collect();
        fn borrow_u32<T: Borrow<[u32]>>(input: T) {
            assert_eq!(input.borrow(), &[1, 2, 3, 4]);
        }
        borrow_u32(buffer);
    }

    #[test]
    fn deref() {
        let buffer: Buffer<_, 3> = [1u32, 2, 3, 4].iter().copied().collect();
        assert_eq!(buffer.len(), 4);
        assert_eq!(&buffer[2..], &[3, 4]);
    }

    #[test]
    fn default() {
        let buffer: Buffer<u8, 6> = Buffer::default();
        assert!(buffer.is_empty());
        assert_eq!(buffer.len(), 0);
        assert!(buffer.first().is_none());
        assert!(buffer.get(1234).is_none());
        assert!(buffer.iter().next().is_none());
        let slice = &buffer[..];
        let empty_slice: &[u8] = &[];
        assert_eq!(slice, empty_slice);
        let bytes: &[u8] = buffer.as_ref();
        assert_eq!(bytes, empty_slice);
    }

    #[test]
    fn array_data() {
        let buffer: Buffer<_, 6> = [1u32, 2, 3].iter().copied().collect();
        assert!(!buffer.is_null(0));
        assert!(buffer.is_valid(2));
        assert_eq!(buffer.null_count(), 0);
        assert_eq!(buffer.valid_count(), 3);
    }

    #[test]
    #[should_panic(expected = "is_null index (is 3) should be < len (is 3)")]
    fn is_null_out_of_bounds() {
        let buffer: Buffer<_, 6> = [1u32, 2, 3].iter().copied().collect();
        buffer.is_null(3);
    }

    #[test]
    #[should_panic(expected = "is_valid index (is 3) should be < len (is 3)")]
    fn is_valid_out_of_bounds() {
        let buffer: Buffer<_, 6> = [1u32, 2, 3].iter().copied().collect();
        buffer.is_valid(3);
    }

    #[test]
    fn clone_and_eq() {
        let buffer: Buffer<_, 6> = [1u32, 2, 3].iter().copied().collect();
        let cloned = buffer.clone();
        assert_eq!(buffer, cloned);
        let other: Buffer<_, 6> = [1u32, 2, 4].iter().copied().collect();
        assert_ne!(buffer, other);
        let shorter: Buffer<_, 6> = [1u32, 2].iter().copied().collect();
        assert_ne!(buffer, shorter);
    }

    #[test]
    fn from_iter() {
        let vec = vec![1u32, 2, 3, 4];
        let buffer = vec.clone().into_iter().collect::<Buffer<_, 6>>();
        assert_eq!(buffer.len(), 4);
        assert_eq!(&vec[..], &buffer[..]);
    }

    #[test]
    fn from_iter_ref() {
        let vec = vec![1u32, 2, 3, 4];
        let buffer = vec.iter().copied().collect::<Buffer<_, 4>>();
        assert_eq!(buffer.len(), 4);
        assert_eq!(&vec[..], &buffer[..]);
    }

    #[test]
    fn from_iter_unknown_size() {
        // `filter` reports a lower bound of zero, exercising the growth path.
        let expected: Vec<u32> = (0..100u32).filter(|x| x % 3 == 0).collect();
        let buffer: Buffer<u32, 3> = (0..100u32).filter(|x| x % 3 == 0).collect();
        assert_eq!(buffer.len(), expected.len());
        assert_eq!(&buffer[..], &expected[..]);
    }

    #[test]
    fn into_iter() {
        let vec = vec![1u32, 2, 3, 4];
        let other = vec
            .iter()
            .copied()
            .collect::<Buffer<_, 5>>()
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(vec, other);
    }
}