use std::{
alloc::Layout,
mem::MaybeUninit,
ops::{Deref, DerefMut},
ptr::NonNull,
};
#[cfg(feature = "flatbuffers")]
use super::Allocator;
use super::{AllocatorCore, AllocatorError, GlobalAllocator, TryClone};
/// An owning pointer to a (possibly unsized) `T` whose storage comes from the allocator `A`.
///
/// Like `Box<T>`, dropping a `Poly` drops the pointee and hands its memory back to
/// `allocator`. Zero-sized pointees are never allocated or deallocated.
#[derive(Debug)]
#[repr(C)]
pub struct Poly<T: ?Sized, A: AllocatorCore = GlobalAllocator> {
    /// Pointer to the owned value. It may be dangling (but aligned) when the value is
    /// zero-sized.
    ptr: NonNull<T>,
    /// The allocator that produced `ptr` and that releases it on drop.
    allocator: A,
}
// SAFETY: a `Poly` uniquely owns its pointee and its allocator handle, so sending it to
// another thread moves both of them; that is sound whenever both are `Send`.
unsafe impl<T: ?Sized + Send, A: AllocatorCore + Send> Send for Poly<T, A> {}

// SAFETY: through a shared `&Poly` only `&T` (via `Deref`) and `&A` (via `allocator`) are
// reachable, so sharing a `Poly` across threads is sound whenever both are `Sync`.
unsafe impl<T: ?Sized + Sync, A: AllocatorCore + Sync> Sync for Poly<T, A> {}
/// Error from a constructor that can fail either while allocating storage or while
/// building the value itself (see `Poly::new_with`).
#[derive(Debug, Clone, Copy)]
pub enum CompoundError<E> {
/// The allocator could not provide storage.
Allocator(AllocatorError),
/// The user-supplied constructor returned an error of type `E`.
Constructor(E),
}
impl<T, A> Poly<T, A>
where
    A: AllocatorCore,
{
    /// Moves `value` into storage obtained from `allocator`.
    ///
    /// Zero-sized types never touch the allocator. The value lives behind a dangling,
    /// well-aligned pointer instead.
    ///
    /// # Errors
    ///
    /// Returns [`AllocatorError`] if `allocator` cannot provide storage for a `T`. In that case
    /// `value` is dropped before returning.
    pub fn new(value: T, allocator: A) -> Result<Self, AllocatorError> {
        let ptr: NonNull<T> = if std::mem::size_of::<T>() == 0 {
            NonNull::dangling()
        } else {
            allocator.allocate(Layout::new::<T>())?.cast::<T>()
        };
        // SAFETY: `ptr` is either a fresh allocation with `T`'s layout or, for a zero-sized
        // `T`, a non-null aligned dangling pointer, which is valid for zero-sized writes.
        //
        // The write is needed even for zero-sized types. It moves `value` into the `Poly`
        // instead of letting it drop when this function returns. Otherwise a zero-sized `T`
        // with a `Drop` impl would be dropped here *and* again by `Drop for Poly`.
        unsafe { ptr.as_ptr().write(value) };
        Ok(Self { ptr, allocator })
    }

    /// Allocates storage for a `T`, then fills it with the value returned by `f`.
    ///
    /// `allocator` is cloned. One copy backs the new `Poly`; the other is passed to `f`.
    ///
    /// # Errors
    ///
    /// * [`CompoundError::Allocator`] if the storage cannot be allocated (`f` is not called).
    /// * [`CompoundError::Constructor`] if `f` fails. The uninitialized storage is released.
    pub fn new_with<F, E>(f: F, allocator: A) -> Result<Self, CompoundError<E>>
    where
        F: FnOnce(A) -> Result<T, E>,
        A: Clone,
    {
        let mut this = Self::new_uninit(allocator.clone()).map_err(CompoundError::Allocator)?;
        this.write(f(allocator).map_err(CompoundError::Constructor)?);
        // SAFETY: the slot was initialized by `write` just above.
        Ok(unsafe { this.assume_init() })
    }

    /// Allocates uninitialized storage for a `T`.
    ///
    /// Initialize it through `DerefMut` (e.g. `MaybeUninit::write`), then call
    /// `assume_init`. Dropping the result without initializing it only frees the storage;
    /// no `T` destructor runs.
    ///
    /// # Errors
    ///
    /// Returns [`AllocatorError`] if `allocator` cannot provide storage for a `T`.
    pub fn new_uninit(allocator: A) -> Result<Poly<MaybeUninit<T>, A>, AllocatorError> {
        let ptr: NonNull<MaybeUninit<T>> = if std::mem::size_of::<T>() == 0 {
            NonNull::dangling()
        } else {
            allocator
                .allocate(Layout::new::<MaybeUninit<T>>())?
                .cast::<MaybeUninit<T>>()
        };
        Ok(Poly { ptr, allocator })
    }
}
impl<T: ?Sized, A: AllocatorCore> Poly<T, A> {
    /// Splits `this` into its raw pointer and allocator without dropping the pointee or
    /// freeing its storage.
    ///
    /// The caller becomes responsible for both. [`Poly::from_raw`] reassembles them.
    pub fn into_raw(this: Self) -> (NonNull<T>, A) {
        // Suppress `Drop for Poly` so the pointee and its allocation stay alive.
        let this = std::mem::ManuallyDrop::new(this);
        // SAFETY: `this` is never dropped, so the allocator is moved out exactly once.
        let allocator = unsafe { std::ptr::read(&this.allocator) };
        (this.ptr, allocator)
    }

    /// Reassembles a `Poly` from parts produced by [`Poly::into_raw`].
    ///
    /// # Safety
    ///
    /// * `ptr` must point to a valid, initialized `T` that the new `Poly` may own.
    /// * If the pointee is not zero-sized, it must have been allocated by `allocator` with
    ///   layout `Layout::for_value(&*ptr)`. That is the layout `Drop` passes to `deallocate`.
    pub unsafe fn from_raw(ptr: NonNull<T>, allocator: A) -> Self {
        Self { ptr, allocator }
    }

    /// Returns a read-only raw pointer to the pointee without giving up ownership.
    ///
    /// This is an associated function (`Poly::as_ptr(&p)`), so it cannot collide with a method
    /// of the same name on `T`.
    pub fn as_ptr(this: &Self) -> *const T {
        let raw: *mut T = this.ptr.as_ptr();
        raw.cast_const()
    }

    /// Returns the allocator that owns this `Poly`'s storage.
    pub fn allocator(&self) -> &A {
        &self.allocator
    }
}
impl<T, A: AllocatorCore> Poly<MaybeUninit<T>, A> {
    /// Reinterprets the storage as an initialized `T`, reusing the same allocation.
    ///
    /// # Safety
    ///
    /// The contained `MaybeUninit<T>` must be fully initialized. Otherwise using or dropping
    /// the returned `Poly<T, A>` is undefined behavior.
    pub unsafe fn assume_init(self) -> Poly<T, A> {
        let (raw, alloc) = Poly::into_raw(self);
        // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so the allocation's
        // layout is unchanged. The caller guarantees the value is initialized.
        unsafe { Poly::<T, A>::from_raw(raw.cast(), alloc) }
    }
}
impl<T, A> Poly<[T], A>
where
    A: AllocatorCore,
{
    /// Allocates storage for `len` elements of `T` without initializing any of them.
    ///
    /// When the slice occupies zero bytes (`len == 0` or a zero-sized `T`), nothing is
    /// allocated and a dangling, well-aligned pointer is used instead.
    ///
    /// # Errors
    ///
    /// Returns [`AllocatorError`] if `len` elements of `T` exceed the maximum layout size, or
    /// if the allocator cannot satisfy the request.
    pub fn new_uninit_slice(
        len: usize,
        allocator: A,
    ) -> Result<Poly<[MaybeUninit<T>], A>, AllocatorError> {
        let layout = Layout::array::<T>(len).map_err(|_| AllocatorError)?;
        let ptr = if layout.size() == 0 {
            // SAFETY: `NonNull::dangling()` is non-null, so the slice pointer is too.
            unsafe {
                NonNull::new_unchecked(std::ptr::slice_from_raw_parts_mut(
                    NonNull::dangling().as_ptr(),
                    len,
                ))
            }
        } else {
            let ptr = allocator.allocate(layout)?;
            debug_assert_eq!(ptr.len(), layout.size());
            // SAFETY: `ptr` comes from a `NonNull`, so the re-typed slice pointer is non-null.
            unsafe {
                NonNull::new_unchecked(std::ptr::slice_from_raw_parts_mut(
                    ptr.as_ptr().cast::<MaybeUninit<T>>(),
                    len,
                ))
            }
        };
        // SAFETY: `ptr` either came from `allocator` with `layout` (which equals the
        // `Layout::for_value` of the resulting slice) or is zero-sized and never deallocated.
        Ok(unsafe { Poly::from_raw(ptr, allocator) })
    }

    /// Collects exactly `iter.len()` items into a new slice allocated from `allocator`.
    ///
    /// If `iter` panics part-way through, the elements written so far are dropped and the
    /// allocation is released before the panic propagates.
    ///
    /// # Errors
    ///
    /// Returns [`AllocatorError`] if the backing slice cannot be allocated.
    ///
    /// # Panics
    ///
    /// Panics if `iter` yields fewer items than its `len()` reported, i.e. if its
    /// [`TrustedIter`] implementation is wrong. Without this check, the unwritten slots would
    /// be treated as initialized.
    pub fn from_iter<I>(iter: I, allocator: A) -> Result<Self, AllocatorError>
    where
        I: TrustedIter<Item = T>,
    {
        /// Drops the already-written prefix of `uninit` if construction unwinds.
        struct Guard<'a, T, A>
        where
            A: AllocatorCore,
        {
            uninit: &'a mut Poly<[MaybeUninit<T>], A>,
            /// Number of leading slots that hold an initialized `T`.
            initialized_to: usize,
        }

        impl<T, A> Drop for Guard<'_, T, A>
        where
            A: AllocatorCore,
        {
            fn drop(&mut self) {
                if std::mem::needs_drop::<T>() {
                    self.uninit
                        .iter_mut()
                        .take(self.initialized_to)
                        // SAFETY: exactly the first `initialized_to` slots were written.
                        .for_each(|u| unsafe { u.assume_init_drop() });
                }
            }
        }

        let mut uninit = Poly::<[T], A>::new_uninit_slice(iter.len(), allocator)?;
        let mut guard = Guard {
            uninit: &mut uninit,
            initialized_to: 0,
        };
        for (src, dst) in std::iter::zip(iter, guard.uninit.iter_mut()) {
            dst.write(src);
            guard.initialized_to += 1;
        }
        // Checked in release builds too: `assume_init` below would be undefined behavior if
        // any slot were left unwritten. `guard` is still live here, so a failing assertion
        // drops the written prefix and frees the allocation during unwinding.
        assert_eq!(
            guard.initialized_to,
            guard.uninit.len(),
            "an incorrect number of elements was initialized",
        );
        std::mem::forget(guard);
        // SAFETY: the assertion above proves every slot was written.
        Ok(unsafe { uninit.assume_init() })
    }
}
impl<T, A: AllocatorCore> Poly<[MaybeUninit<T>], A> {
    /// Reinterprets every slot as an initialized `T`, reusing the same allocation.
    ///
    /// # Safety
    ///
    /// Every element must be fully initialized. Otherwise using or dropping the returned
    /// `Poly<[T], A>` is undefined behavior.
    pub unsafe fn assume_init(self) -> Poly<[T], A> {
        let count = (*self).len();
        let (raw, alloc) = Poly::into_raw(self);
        let elements = std::ptr::slice_from_raw_parts_mut(raw.as_ptr().cast::<T>(), count);
        // SAFETY: `elements` keeps the non-null address of `raw`. `[MaybeUninit<T>]` and `[T]`
        // of the same length share a layout, so the allocation is reused as is. The caller
        // guarantees every element is initialized.
        unsafe { Poly::<[T], A>::from_raw(NonNull::new_unchecked(elements), alloc) }
    }
}
impl<T: Clone, A: AllocatorCore> Poly<[T], A> {
    /// Allocates a slice holding `len` clones of `value`.
    ///
    /// `value` is cloned once per element and dropped afterwards.
    ///
    /// # Errors
    ///
    /// Returns [`AllocatorError`] if the slice cannot be allocated.
    pub fn broadcast(value: T, len: usize, allocator: A) -> Result<Self, AllocatorError> {
        let clones = (0..len).map(|_| T::clone(&value));
        Self::from_iter(clones, allocator)
    }
}
impl<T: ?Sized, A: AllocatorCore> Drop for Poly<T, A> {
    /// Drops the pointee in place, then returns its storage to the allocator. Deallocation is
    /// skipped for zero-sized pointees, which were never allocated.
    fn drop(&mut self) {
        let raw = self.ptr.as_ptr();
        // `Layout::for_value` takes a reference, so compute it while the pointee is still valid.
        // SAFETY: `raw` points to a live `T` owned by this `Poly`.
        let layout = Layout::for_value(unsafe { &*raw });
        // SAFETY: this `Poly` exclusively owns the pointee and drops it exactly once.
        unsafe { std::ptr::drop_in_place(raw) };
        if layout.size() == 0 {
            return;
        }
        let bytes = std::ptr::slice_from_raw_parts_mut(raw.cast::<u8>(), layout.size());
        // SAFETY: `bytes` keeps the non-null address of `ptr`, and a non-zero-sized pointee was
        // allocated by `self.allocator` with exactly `layout`.
        unsafe { self.allocator.deallocate(NonNull::new_unchecked(bytes), layout) };
    }
}
impl<T: ?Sized, A: AllocatorCore> Deref for Poly<T, A> {
    type Target = T;

    /// Borrows the owned pointee.
    fn deref(&self) -> &T {
        // SAFETY: `ptr` refers to a valid `T` owned by this `Poly`, and the returned borrow is
        // tied to `&self`.
        unsafe { self.ptr.as_ref() }
    }
}
impl<T: ?Sized, A: AllocatorCore> DerefMut for Poly<T, A> {
    /// Mutably borrows the owned pointee.
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: `ptr` refers to a valid `T` uniquely owned by this `Poly`, and the returned
        // borrow is tied to `&mut self`, so no other reference to the pointee can exist.
        unsafe { self.ptr.as_mut() }
    }
}
/// An [`ExactSizeIterator`] whose `len()` is guaranteed to be exact.
///
/// # Safety
///
/// Implementors must yield exactly as many items as `len()` reports before iteration starts.
/// `Poly::from_iter` sizes its allocation from `len()` and treats every slot as initialized
/// once the iterator is exhausted.
pub unsafe trait TrustedIter: ExactSizeIterator {}

// SAFETY: the standard-library iterators below report exact lengths.
unsafe impl<T> TrustedIter for std::slice::Iter<'_, T> {}
unsafe impl<T> TrustedIter for std::vec::IntoIter<T> {}
unsafe impl TrustedIter for std::ops::Range<usize> {}
unsafe impl<T, const N: usize> TrustedIter for std::array::IntoIter<T, N> {}

// SAFETY: third-party iterators, trusted to report exact lengths.
// NOTE(review): re-verify these when upgrading `rand` / `flatbuffers`.
unsafe impl TrustedIter for rand::seq::index::IndexVecIntoIter {}
#[cfg(feature = "flatbuffers")]
unsafe impl<'a, T: flatbuffers::Follow<'a>> TrustedIter for flatbuffers::VectorIter<'a, T> {}

// SAFETY: these adapters yield exactly one item per item of their trusted input(s). `Zip`
// yields the shorter of its two exact lengths.
unsafe impl<I: TrustedIter, U, F: FnMut(I::Item) -> U> TrustedIter for std::iter::Map<I, F> {}
unsafe impl<I: TrustedIter> TrustedIter for std::iter::Enumerate<I> {}
unsafe impl<'a, I: TrustedIter<Item = &'a T>, T: 'a + Clone> TrustedIter
    for std::iter::Cloned<I>
{
}
unsafe impl<'a, I: TrustedIter<Item = &'a T>, T: 'a + Copy> TrustedIter
    for std::iter::Copied<I>
{
}
unsafe impl<T: TrustedIter, U: TrustedIter> TrustedIter for std::iter::Zip<T, U> {}
/// Builds a `Poly`, optionally coercing it to a trait object.
///
/// Supported forms:
/// * `poly!({ Trait + Bounds }, value, alloc)` / `poly!(Trait, value, alloc)`: allocates
///   `value` with `Poly::new` and coerces the result to `Poly<dyn Trait + Bounds, _>`.
///   Evaluates to a `Result`.
/// * `poly!({ Trait + Bounds }, poly)` / `poly!(Trait, poly)`: coerces an existing
///   `Poly<T, A>` into `Poly<dyn Trait + Bounds, A>` without reallocating.
/// * `poly!([a, b, c], alloc)`: collects the listed elements into a `Poly<[T], _>` via
///   `Poly::from_iter`. Evaluates to a `Result`.
#[macro_export]
macro_rules! poly {
    ({ $($traits:tt)+ }, $v:ident, $alloc:ident) => {{
        $crate::alloc::Poly::new($v, $alloc).map(|poly| {
            $crate::alloc::poly!({ $($traits)+ }, poly)
        })
    }};
    ($trait:path, $v:ident, $alloc:ident) => {{
        $crate::alloc::poly!({ $trait }, $v, $alloc)
    }};
    ({ $($traits:tt)+ }, $poly:ident) => {{
        let (ptr, alloc) = $crate::alloc::Poly::into_raw($poly);
        // The `NonNull<T>` -> `NonNull<dyn Trait>` unsizing coercion happens at this call.
        unsafe { $crate::alloc::Poly::<dyn $($traits)*, _>::from_raw(ptr, alloc) }
    }};
    ($trait:path, $poly:ident) => {{
        $crate::alloc::poly!({ $trait }, $poly)
    }};
    ([$($x:expr),* $(,)?], $alloc:ident) => {{
        // Fully qualified, like the other arms, so the exported macro works at call sites
        // that have not imported `Poly`.
        $crate::alloc::Poly::from_iter([$($x,)*].into_iter(), $alloc)
    }}
}
// Re-exported so the macro can be named as `$crate::alloc::poly!`, which the arms above
// use for recursion.
pub use poly;
impl<T: Clone, A: AllocatorCore + Clone> TryClone for Poly<T, A> {
    /// Clones the pointee into a fresh allocation from a clone of this `Poly`'s allocator.
    fn try_clone(&self) -> Result<Self, AllocatorError> {
        let value: T = T::clone(self);
        Poly::new(value, A::clone(self.allocator()))
    }
}
impl<T: Clone, A: AllocatorCore + Clone> TryClone for Poly<[T], A> {
    /// Clones every element into a fresh slice allocated from a clone of this `Poly`'s
    /// allocator.
    fn try_clone(&self) -> Result<Self, AllocatorError> {
        let elements = self.iter().cloned();
        Poly::from_iter(elements, A::clone(self.allocator()))
    }
}
impl<T: ?Sized, A: AllocatorCore> TryClone for Option<Poly<T, A>>
where
    Poly<T, A>: TryClone,
{
    /// Clones the contained `Poly`, if any. `None` clones to `None` without allocating.
    fn try_clone(&self) -> Result<Self, AllocatorError> {
        self.as_ref().map(|poly| poly.try_clone()).transpose()
    }
}
impl<T: ?Sized + PartialEq, A: AllocatorCore> PartialEq for Poly<T, A> {
    /// Compares the pointees by value. The allocators are not compared.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        **self == **other
    }
}
impl<T> From<Box<[T]>> for Poly<[T], GlobalAllocator> {
    /// Takes ownership of a boxed slice without copying its elements.
    fn from(boxed: Box<[T]>) -> Self {
        let ptr = NonNull::from(Box::leak(boxed));
        // SAFETY: `ptr` owns a valid `[T]` that `Box` allocated with `Layout::for_value` of the
        // slice (or did not allocate at all when it is zero-sized).
        // NOTE(review): this assumes `GlobalAllocator` deallocates through Rust's global
        // allocator, as `Box` does; confirm against its implementation.
        unsafe { Poly::from_raw(ptr, GlobalAllocator) }
    }
}
// SAFETY: `len` reports the length of the byte slice exposed through `Deref`/`DerefMut`.
// `grow_downwards` keeps the existing bytes at the *end* of the enlarged buffer and zero-fills
// the new front. NOTE(review): re-check against the `flatbuffers::Allocator` safety contract
// when upgrading the crate.
#[cfg(feature = "flatbuffers")]
unsafe impl<A: Allocator> flatbuffers::Allocator for Poly<[u8], A> {
    type Error = AllocatorError;

    /// Doubles the buffer (growing an empty one to a single byte), moving the current
    /// contents to the tail of the new allocation.
    fn grow_downwards(&mut self) -> Result<(), Self::Error> {
        let old_len = (**self).len();
        let new_len = (old_len * 2).max(1);
        let mut grown = Poly::broadcast(0u8, new_len, self.allocator().clone())?;
        let (_, tail) = grown.split_at_mut(new_len - old_len);
        tail.copy_from_slice(self);
        *self = grown;
        Ok(())
    }

    /// Length of the underlying byte slice.
    fn len(&self) -> usize {
        <[u8]>::len(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::AlwaysFails;

    /// Owns heap data (`s`) and mixes field sizes, so it exercises destructors as well as
    /// plain-data fields.
    struct HasHoles {
        s: String,
        a: u32,
        b: u8,
    }

    impl HasHoles {
        fn new(s: String, a: u32, b: u8) -> Self {
            Self { s, a, b }
        }
    }

    /// Compile-time check that `T: Send`.
    fn assert_is_send<T>(_: &T)
    where
        T: Send,
    {
    }

    #[test]
    fn size_check() {
        // `GlobalAllocator` is expected to be zero-sized, so `Poly` is a single non-null
        // pointer and `Option<Poly<_>>` reuses its null niche. Compare against the target's
        // pointer width rather than a hard-coded 8 so the test also holds on 32-bit targets.
        let pointer_size = std::mem::size_of::<usize>();
        assert_eq!(std::mem::size_of::<Poly<usize>>(), pointer_size);
        assert_eq!(std::mem::size_of::<Option<Poly<usize>>>(), pointer_size);
    }

    #[test]
    fn basic_test_copy() {
        let x = 10usize;
        let poly = Poly::new(x, GlobalAllocator).unwrap();
        assert_eq!(*poly, 10);
    }

    #[test]
    fn basic_test_borrow() {
        let x = &10usize;
        let poly = Poly::<&usize>::new(x, GlobalAllocator).unwrap();
        assert_eq!(**poly, 10);
    }

    #[test]
    fn test_with_drop() {
        let poly = Poly::<String>::new("hello world".to_string(), GlobalAllocator).unwrap();
        assert_eq!(&**poly, "hello world");
    }

    #[test]
    fn test_mutate() {
        let mut poly = Poly::<String>::new("foo".to_string(), GlobalAllocator).unwrap();
        assert_eq!(&**poly, "foo");
        *poly = "bar".to_string();
        assert_eq!(&**poly, "bar");
    }

    #[test]
    fn zero_sized() {
        let _ = Poly::<()>::new((), GlobalAllocator).unwrap();
    }

    #[test]
    fn zero_sized_raw() {
        let x = Poly::<()>::new((), GlobalAllocator).unwrap();
        let (ptr, alloc) = Poly::into_raw(x);
        let _ = unsafe { Poly::from_raw(ptr, alloc) };
    }

    #[test]
    fn zero_sized_uninit() {
        let _ = Poly::<()>::new_uninit(GlobalAllocator).unwrap();
    }

    #[test]
    fn zero_sized_uninit_to_init() {
        let x = Poly::<()>::new_uninit(GlobalAllocator).unwrap();
        let _ = unsafe { x.assume_init() };
    }

    #[test]
    fn zero_sized_slice() {
        let x = Poly::<[()]>::from_iter((0..0).map(|_| ()), GlobalAllocator).unwrap();
        assert!(x.is_empty());
        let x = Poly::<[()]>::from_iter((0..10).map(|_| ()), GlobalAllocator).unwrap();
        assert_eq!(x.len(), 10);
        let x = Poly::<[usize]>::from_iter(0..0, GlobalAllocator).unwrap();
        assert!(x.is_empty());
        let x =
            Poly::<[String]>::from_iter((0..0).map(|i| i.to_string()), GlobalAllocator).unwrap();
        assert!(x.is_empty());
    }

    #[test]
    fn dropping_uninit_is_okay() {
        let _ = Poly::<HasHoles>::new_uninit(GlobalAllocator).unwrap();
    }

    #[test]
    fn test_assume_init() {
        let mut poly = Poly::<HasHoles>::new_uninit(GlobalAllocator).unwrap();
        poly.write(HasHoles::new("hello world".into(), 10, 5));
        let poly: Poly<HasHoles> = unsafe { poly.assume_init() };
        assert_eq!(poly.s, "hello world");
        assert_eq!(poly.a, 10);
        assert_eq!(poly.b, 5);
    }

    #[test]
    fn test_assume_init_slice_copy() {
        let mut poly = Poly::<[usize]>::new_uninit_slice(10, GlobalAllocator).unwrap();
        assert_eq!(poly.len(), 10);
        for (i, v) in poly.iter_mut().enumerate() {
            v.write(i);
        }
        let poly: Poly<[usize]> = unsafe { poly.assume_init() };
        for (i, v) in poly.iter().enumerate() {
            assert_eq!(*v, i);
        }
    }

    #[test]
    fn test_assume_init_slice_drop() {
        let mut poly = Poly::<[HasHoles]>::new_uninit_slice(10, GlobalAllocator).unwrap();
        assert_eq!(poly.len(), 10);
        for (i, v) in poly.iter_mut().enumerate() {
            v.write(HasHoles::new(
                i.to_string(),
                i.try_into().unwrap(),
                i.try_into().unwrap(),
            ));
        }
        let poly: Poly<[HasHoles]> = unsafe { poly.assume_init() };
        for (i, v) in poly.iter().enumerate() {
            assert_eq!(v.s, i.to_string());
            assert_eq!(v.a as usize, i);
            assert_eq!(v.b as usize, i);
        }
    }

    #[test]
    fn from_iter_strings() {
        let p =
            Poly::<[String], _>::from_iter((0..5).map(|i| i.to_string()), GlobalAllocator).unwrap();
        assert_eq!(&*p, &["0", "1", "2", "3", "4"])
    }

    #[test]
    #[should_panic(expected = "first")]
    fn from_iter_cleanup_first() {
        Poly::<[String], _>::from_iter((0..5).map(|_| panic!("first")), GlobalAllocator).unwrap();
    }

    #[test]
    #[should_panic(expected = "middle")]
    fn from_iter_cleanup_middle() {
        Poly::<[String], _>::from_iter(
            (0..5).map(|i| {
                if i == 3 {
                    panic!("middle");
                } else {
                    i.to_string()
                }
            }),
            GlobalAllocator,
        )
        .unwrap();
    }

    #[test]
    #[should_panic(expected = "last")]
    fn from_iter_cleanup_last() {
        Poly::<[String], _>::from_iter(
            (0..5).map(|i| {
                let string = i.to_string();
                if i == 4 {
                    panic!("last");
                }
                string
            }),
            GlobalAllocator,
        )
        .unwrap();
    }

    #[test]
    fn new_error() {
        let _ = Poly::new(10usize, AlwaysFails).unwrap_err();
    }

    #[test]
    fn new_with_error() {
        let err = Poly::new_with(
            |_| -> Result<u8, std::convert::Infallible> { Ok(0) },
            AlwaysFails,
        )
        .unwrap_err();
        assert!(matches!(err, CompoundError::Allocator(_)));
        let err = Poly::new_with(
            |_| -> Result<u8, std::num::TryFromIntError> {
                let x: u8 = (1000usize).try_into()?;
                Ok(x)
            },
            GlobalAllocator,
        )
        .unwrap_err();
        assert!(matches!(
            err,
            CompoundError::Constructor(std::num::TryFromIntError { .. })
        ));
    }

    #[test]
    fn new_uninit_error() {
        let _ = Poly::<String, _>::new_uninit(AlwaysFails).unwrap_err();
    }

    #[test]
    fn new_uninit_slice_error() {
        let _ = Poly::<[usize], _>::new_uninit_slice(10, AlwaysFails).unwrap_err();
    }

    #[test]
    fn new_from_iter_error() {
        let _ = Poly::<[usize], _>::from_iter(0..10, AlwaysFails).unwrap_err();
    }

    trait Describe {
        fn describe(&self) -> String;
        fn describe_mut(&mut self) -> String;
    }

    struct ImplsDescribe;

    impl Describe for ImplsDescribe {
        fn describe(&self) -> String {
            "describe const".to_string()
        }
        fn describe_mut(&mut self) -> String {
            "describe mut".to_string()
        }
    }

    struct AlsoImplsDescribe(String);

    impl Describe for AlsoImplsDescribe {
        fn describe(&self) -> String {
            format!("describe const: {}", self.0)
        }
        fn describe_mut(&mut self) -> String {
            format!("describe mut: {}", self.0)
        }
    }

    struct DescribeLifetime<'a>(&'a str);

    impl Describe for DescribeLifetime<'_> {
        fn describe(&self) -> String {
            format!("describe const: {}", self.0)
        }
        fn describe_mut(&mut self) -> String {
            format!("describe mut: {}", self.0)
        }
    }

    trait Foo<T> {
        fn foo(&self, v: T) -> T;
    }

    impl Foo<f32> for f32 {
        fn foo(&self, v: f32) -> f32 {
            *self + v
        }
    }

    #[test]
    fn test_dyn_trait() {
        {
            let mut poly0 = poly!(Describe, ImplsDescribe, GlobalAllocator).unwrap();
            let also = AlsoImplsDescribe("foo".to_string());
            let mut poly1 = poly!({ Describe + Send }, also, GlobalAllocator).unwrap();
            assert_is_send::<Poly<dyn Describe + Send, _>>(&poly1);
            assert_eq!(poly1.describe(), "describe const: foo");
            assert_eq!(poly1.describe_mut(), "describe mut: foo");
            assert_eq!(poly0.describe(), "describe const");
            assert_eq!(poly0.describe_mut(), "describe mut");
        }
        {
            let mut poly =
                Poly::new(AlsoImplsDescribe("bar".to_string()), GlobalAllocator).unwrap();
            assert_is_send::<Poly<AlsoImplsDescribe>>(&poly);
            assert_eq!(poly.describe(), "describe const: bar");
            assert_eq!(poly.describe_mut(), "describe mut: bar");
            let mut poly = poly!({ Describe + Send }, poly);
            assert_is_send::<Poly<dyn Describe + Send>>(&poly);
            assert_eq!(poly.describe(), "describe const: bar");
            assert_eq!(poly.describe_mut(), "describe mut: bar");
        }
        {
            let f = 1.0f32;
            let poly = poly!({ Foo<f32> }, f, GlobalAllocator).unwrap();
            assert_eq!(poly.foo(2.0), 3.0);
        }
        {
            let poly = Poly::new(1.0f32, GlobalAllocator).unwrap();
            let poly = poly!({ Foo<f32> + Send }, poly);
            assert_is_send::<Poly<dyn Foo<f32> + Send>>(&poly);
            assert_eq!(poly.foo(2.0), 3.0);
        }
        fn test<'a, T>(x: T) -> Poly<dyn Foo<T> + 'a>
        where
            T: Foo<T> + 'a,
        {
            poly!({ Foo<T> }, x, GlobalAllocator).unwrap()
        }
        {
            let x = test(1.0f32);
            assert_eq!(x.foo(2.0), 3.0);
        }
    }

    #[test]
    fn test_dyn_trait_with_lifetime() {
        let base: String = "foo".into();
        let describe = DescribeLifetime(&base);
        let mut poly: Poly<dyn Describe> = poly!({ Describe }, describe, GlobalAllocator).unwrap();
        assert_eq!(poly.describe(), "describe const: foo");
        assert_eq!(poly.describe_mut(), "describe mut: foo");
    }

    #[test]
    fn test_try_clone_item() {
        let x = Poly::<String>::new("hello".to_string(), GlobalAllocator).unwrap();
        let y = x.try_clone().unwrap();
        assert_eq!(x, y);
    }

    #[test]
    fn test_try_clone_slice() {
        let x = Poly::<[String]>::from_iter(
            ["foo".to_string(), "bar".to_string(), "baz".to_string()].into_iter(),
            GlobalAllocator,
        )
        .unwrap();
        let y = x.try_clone().unwrap();
        assert_eq!(x, y);
    }

    #[test]
    fn test_try_clone_option() {
        let mut x = Some(Poly::<String>::new("hello".to_string(), GlobalAllocator).unwrap());
        let y = x.try_clone().unwrap();
        assert_eq!(x, y);
        x = None;
        let y = x.try_clone().unwrap();
        assert_eq!(x, y);
    }

    #[cfg(feature = "flatbuffers")]
    #[test]
    fn test_grow_downwards() {
        let mut x = Poly::from_iter([1u8, 2u8, 3u8].into_iter(), GlobalAllocator).unwrap();
        <_ as flatbuffers::Allocator>::grow_downwards(&mut x).unwrap();
        assert_eq!(&*x, &[0, 0, 0, 1, 2, 3]);
        <_ as flatbuffers::Allocator>::grow_downwards(&mut x).unwrap();
        assert_eq!(&*x, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3]);
    }
}