use core::{
clone::TrivialClone,
convert::Infallible,
marker::PhantomData,
mem::{self, MaybeUninit},
ptr,
};
use crate::{
init::{Init, InitError, InitPin, InitPinResult, InitResult, Initializer, IntoInitPin},
owned::Own,
pin::DropSlot,
uninit::Uninit,
};
/// Views an uninitialized array as a slice of individually uninitialized elements.
///
/// This lets the slice-oriented `MaybeUninit` helpers (`write_filled`,
/// `write_with`, `write_iter`, ...) operate on fixed-size array places.
#[inline]
fn maybe_uninit_slice<T, const N: usize>(m: &mut MaybeUninit<[T; N]>) -> &mut [MaybeUninit<T>] {
    let elements = m.as_mut_ptr().cast::<[MaybeUninit<T>; N]>();
    // SAFETY: `MaybeUninit<[T; N]>` and `[MaybeUninit<T>; N]` have identical size and
    // layout, and an array of `MaybeUninit<T>` has no validity requirements. The
    // returned borrow is tied to `m`, so nothing else can access the memory meanwhile.
    unsafe { &mut *elements }
}
#[derive(Debug, thiserror::Error, Clone, Copy, PartialEq)]
#[error("slice length mismatch")]
pub struct SliceError;
trait SpecInitSlice<T> {
fn init_slice(self, place: Uninit<'_, [T]>) -> InitResult<'_, [T], SliceError>;
fn init_array<const N: usize>(
self,
place: Uninit<'_, [T; N]>,
) -> InitResult<'_, [T; N], SliceError>;
}
impl<T: Clone> SpecInitSlice<T> for &[T] {
default fn init_slice(self, mut place: Uninit<'_, [T]>) -> InitResult<'_, [T], SliceError> {
if place.len() != self.len() {
return Err(InitError { error: SliceError, place });
}
place.write_clone_of_slice(self);
Ok(unsafe { place.assume_init() })
}
default fn init_array<const N: usize>(
self,
mut place: Uninit<'_, [T; N]>,
) -> InitResult<'_, [T; N], SliceError> {
if N != self.len() {
return Err(InitError { error: SliceError, place });
}
maybe_uninit_slice(&mut place).write_clone_of_slice(self);
Ok(unsafe { place.assume_init() })
}
}
impl<T: TrivialClone> SpecInitSlice<T> for &[T] {
fn init_slice(self, mut place: Uninit<'_, [T]>) -> InitResult<'_, [T], SliceError> {
if place.len() != self.len() {
return Err(InitError { error: SliceError, place });
}
unsafe { ptr::copy_nonoverlapping(self.as_ptr(), place.as_mut_ptr().cast(), self.len()) };
Ok(unsafe { place.assume_init() })
}
fn init_array<const N: usize>(
self,
mut place: Uninit<'_, [T; N]>,
) -> InitResult<'_, [T; N], SliceError> {
if N != self.len() {
return Err(InitError { error: SliceError, place });
}
unsafe { ptr::copy_nonoverlapping(self.as_ptr(), place.as_mut_ptr().cast(), self.len()) };
Ok(unsafe { place.assume_init() })
}
}
/// Initializer that clones the elements of a borrowed slice into a `[T]` or
/// `[T; N]` place.
///
/// Initialization fails with [`SliceError`], handing the place back, when the
/// source length differs from the place length.
#[derive(Debug, PartialEq)]
pub struct Slice<'a, T>(&'a [T]);

impl<T> Initializer for Slice<'_, T> {
    type Error = SliceError;
}

impl<T: Clone> InitPin<[T]> for Slice<'_, T> {
    #[inline]
    fn init_pin<'a, 'b>(
        self,
        place: Uninit<'a, [T]>,
        slot: DropSlot<'a, 'b, [T]>,
    ) -> InitPinResult<'a, 'b, [T], SliceError> {
        let Self(src) = self;
        match src.init_slice(place) {
            Err(failure) => Err(failure.into_pin(slot)),
            Ok(init) => Ok(Own::into_pin(init, slot)),
        }
    }
}

impl<T: Clone> Init<[T]> for Slice<'_, T> {
    #[inline]
    fn init(self, place: Uninit<'_, [T]>) -> InitResult<'_, [T], SliceError> {
        let Self(src) = self;
        src.init_slice(place)
    }
}

impl<T: Clone, const N: usize> InitPin<[T; N]> for Slice<'_, T> {
    #[inline]
    fn init_pin<'a, 'b>(
        self,
        place: Uninit<'a, [T; N]>,
        slot: DropSlot<'a, 'b, [T; N]>,
    ) -> InitPinResult<'a, 'b, [T; N], SliceError> {
        let Self(src) = self;
        match src.init_array(place) {
            Err(failure) => Err(failure.into_pin(slot)),
            Ok(init) => Ok(Own::into_pin(init, slot)),
        }
    }
}

impl<T: Clone, const N: usize> Init<[T; N]> for Slice<'_, T> {
    #[inline]
    fn init(self, place: Uninit<'_, [T; N]>) -> InitResult<'_, [T; N], SliceError> {
        let Self(src) = self;
        src.init_array(place)
    }
}

/// Wraps `s` in a [`Slice`] initializer.
#[inline]
pub const fn slice<T: Clone>(s: &[T]) -> Slice<'_, T> {
    Slice(s)
}

impl<'a, T: Clone> IntoInitPin<[T], Slice<'a, T>> for &'a [T] {
    type Init = Slice<'a, T>;
    type Error = SliceError;

    #[inline]
    fn into_init(self) -> Self::Init {
        slice(self)
    }
}

impl<'a, T: Clone, const N: usize> IntoInitPin<[T; N], Slice<'a, T>> for &'a [T] {
    type Init = Slice<'a, T>;
    type Error = SliceError;

    #[inline]
    fn into_init(self) -> Self::Init {
        slice(self)
    }
}
/// Initializer that copies a borrowed string into a `str` place.
///
/// The place must have exactly the same byte length as the source. Otherwise
/// initialization fails with [`SliceError`] and the place is handed back.
#[derive(Debug, PartialEq)]
pub struct Str<'a>(&'a str);

impl Initializer for Str<'_> {
    type Error = SliceError;
}

/// Copies the UTF-8 bytes of `src` into `dst`.
///
/// Shared by the `InitPin` and `Init` impls of [`Str`], which check the lengths
/// first.
///
/// # Panics
/// Panics if `dst.len() != src.len()`.
#[inline]
fn copy_str_bytes(dst: &mut [MaybeUninit<u8>], src: &str) {
    let bytes: *const [u8] = src.as_bytes();
    // SAFETY: `MaybeUninit<u8>` is layout-compatible with `u8` and has no validity
    // requirements, so viewing initialized bytes this way is sound. Unlike
    // `mem::transmute`, a pointer cast keeps the conversion visibly restricted to
    // the pointee type.
    let bytes = unsafe { &*(bytes as *const [MaybeUninit<u8>]) };
    dst.copy_from_slice(bytes);
}

impl InitPin<str> for Str<'_> {
    #[inline]
    fn init_pin<'a, 'b>(
        self,
        mut place: Uninit<'a, str>,
        slot: DropSlot<'a, 'b, str>,
    ) -> InitPinResult<'a, 'b, str, SliceError> {
        if place.len() != self.0.len() {
            return Err(InitError { error: SliceError, place }.into_pin(slot));
        }
        copy_str_bytes(&mut *place, self.0);
        // SAFETY: every byte of `place` was copied from a valid `str`.
        Ok(unsafe { place.assume_init_pin(slot) })
    }
}

impl Init<str> for Str<'_> {
    #[inline]
    fn init(self, mut place: Uninit<'_, str>) -> InitResult<'_, str, SliceError> {
        if place.len() != self.0.len() {
            return Err(InitError { error: SliceError, place });
        }
        copy_str_bytes(&mut *place, self.0);
        // SAFETY: every byte of `place` was copied from a valid `str`.
        Ok(unsafe { place.assume_init() })
    }
}

/// Wraps `s` in a [`Str`] initializer.
#[inline]
pub const fn str(s: &str) -> Str<'_> {
    Str(s)
}

impl<'b> IntoInitPin<str, Str<'b>> for &'b str {
    type Init = Str<'b>;
    type Error = SliceError;

    #[inline]
    fn into_init(self) -> Self::Init {
        Str(self)
    }
}
/// Initializer that fills every element of a `[T]` or `[T; N]` place from a
/// single value, using `write_filled` (clones of the value).
#[derive(Debug, PartialEq)]
pub struct Repeat<T>(T);

impl<T> Initializer for Repeat<T> {
    type Error = Infallible;
}

impl<T: Clone> InitPin<[T]> for Repeat<T> {
    fn init_pin<'a, 'b>(
        self,
        mut place: Uninit<'a, [T]>,
        slot: DropSlot<'a, 'b, [T]>,
    ) -> InitPinResult<'a, 'b, [T], Infallible> {
        let Self(value) = self;
        place.write_filled(value);
        // SAFETY: `write_filled` initialized every element of `place`.
        Ok(unsafe { place.assume_init_pin(slot) })
    }
}

impl<T: Clone> Init<[T]> for Repeat<T> {
    fn init(self, mut place: Uninit<'_, [T]>) -> InitResult<'_, [T], Infallible> {
        let Self(value) = self;
        place.write_filled(value);
        // SAFETY: `write_filled` initialized every element of `place`.
        Ok(unsafe { place.assume_init() })
    }
}

impl<T: Clone, const N: usize> InitPin<[T; N]> for Repeat<T> {
    fn init_pin<'a, 'b>(
        self,
        mut place: Uninit<'a, [T; N]>,
        slot: DropSlot<'a, 'b, [T; N]>,
    ) -> InitPinResult<'a, 'b, [T; N], Infallible> {
        let Self(value) = self;
        let elements = maybe_uninit_slice(&mut place);
        elements.write_filled(value);
        // SAFETY: `write_filled` initialized all `N` elements of the array.
        Ok(unsafe { place.assume_init_pin(slot) })
    }
}

impl<T: Clone, const N: usize> Init<[T; N]> for Repeat<T> {
    fn init(self, mut place: Uninit<'_, [T; N]>) -> InitResult<'_, [T; N], Infallible> {
        let Self(value) = self;
        let elements = maybe_uninit_slice(&mut place);
        elements.write_filled(value);
        // SAFETY: `write_filled` initialized all `N` elements of the array.
        Ok(unsafe { place.assume_init() })
    }
}

/// Creates a [`Repeat`] initializer that fills a place with `value`.
#[inline]
pub const fn repeat<T: Clone>(value: T) -> Repeat<T> {
    Repeat(value)
}
/// Initializer that computes each element of a `[T]` or `[T; N]` place by
/// calling a closure with that element's index (via `write_with`).
///
/// The closure only needs to be `FnMut`, so it may carry state between calls
/// (for example, a counter). This matches what `write_with` accepts and what
/// [`incremental`] requires. Every `Fn` closure still qualifies.
#[derive(Debug, PartialEq)]
pub struct RepeatWith<F>(F);

impl<F> Initializer for RepeatWith<F> {
    type Error = Infallible;
}

impl<T, F> InitPin<[T]> for RepeatWith<F>
where
    F: FnMut(usize) -> T,
{
    fn init_pin<'a, 'b>(
        self,
        mut place: Uninit<'a, [T]>,
        slot: DropSlot<'a, 'b, [T]>,
    ) -> InitPinResult<'a, 'b, [T], Infallible> {
        place.write_with(self.0);
        // SAFETY: `write_with` initialized every element of `place`.
        Ok(unsafe { place.assume_init_pin(slot) })
    }
}

impl<T, F> Init<[T]> for RepeatWith<F>
where
    F: FnMut(usize) -> T,
{
    fn init(self, mut place: Uninit<'_, [T]>) -> InitResult<'_, [T], Infallible> {
        place.write_with(self.0);
        // SAFETY: `write_with` initialized every element of `place`.
        Ok(unsafe { place.assume_init() })
    }
}

impl<T, F, const N: usize> InitPin<[T; N]> for RepeatWith<F>
where
    F: FnMut(usize) -> T,
{
    fn init_pin<'a, 'b>(
        self,
        mut place: Uninit<'a, [T; N]>,
        slot: DropSlot<'a, 'b, [T; N]>,
    ) -> InitPinResult<'a, 'b, [T; N], Infallible> {
        maybe_uninit_slice(&mut place).write_with(self.0);
        // SAFETY: `write_with` initialized all `N` elements of the array.
        Ok(unsafe { place.assume_init_pin(slot) })
    }
}

impl<T, F, const N: usize> Init<[T; N]> for RepeatWith<F>
where
    F: FnMut(usize) -> T,
{
    fn init(self, mut place: Uninit<'_, [T; N]>) -> InitResult<'_, [T; N], Infallible> {
        maybe_uninit_slice(&mut place).write_with(self.0);
        // SAFETY: `write_with` initialized all `N` elements of the array.
        Ok(unsafe { place.assume_init() })
    }
}

/// Creates a [`RepeatWith`] initializer from an index-to-element closure.
#[inline]
pub const fn repeat_with<T, F>(f: F) -> RepeatWith<F>
where
    F: FnMut(usize) -> T,
{
    RepeatWith(f)
}
/// Initializer that fills a place from an iterator; built by [`from_iter`].
///
/// `T` is the iterator's item type, recorded only at the type level.
#[derive(Debug, PartialEq)]
pub struct FromIter<I, T>(I, PhantomData<fn() -> T>);

impl<I, T> Initializer for FromIter<I, T> {
    type Error = FromIterError;
}

/// Error returned by [`FromIter`] when the iterator cannot fill the place.
///
/// Either the iterator ran out of items early, or, for `str` places, the
/// final item would not end on a character boundary at the end of the place.
#[derive(Debug, thiserror::Error)]
#[error("iterator initialization failed")]
pub struct FromIterError(());
/// Fills `uninit` with items from `iter`.
///
/// Succeeds when the iterator supplies at least `uninit.len()` items; items
/// beyond that are not written. If the iterator runs out first, the elements
/// already written are dropped and [`FromIterError`] is returned, leaving
/// nothing in `uninit` that needs dropping.
#[inline]
fn collect_iter_slice<T, I>(uninit: &mut [MaybeUninit<T>], iter: I) -> Result<(), FromIterError>
where
    I: IntoIterator<Item = T>,
{
    let unfilled = uninit.write_iter(iter).1.len();
    if unfilled == 0 {
        return Ok(());
    }
    let filled = uninit.len() - unfilled;
    // SAFETY: `write_iter` initialized exactly the first `filled` elements.
    unsafe { uninit[..filled].assume_init_drop() };
    Err(FromIterError(()))
}
/// Array counterpart of [`collect_iter_slice`], with the same success and
/// failure rules.
#[inline]
fn collect_iter_array<T, const N: usize>(
    uninit: &mut MaybeUninit<[T; N]>,
    iter: impl IntoIterator<Item = T>,
) -> Result<(), FromIterError> {
    let elements = maybe_uninit_slice(uninit);
    collect_iter_slice(elements, iter)
}
/// Fills the byte buffer `uninit` by concatenating the strings yielded by `iter`.
///
/// The last piece may be truncated to fit, as long as the cut falls on a UTF-8
/// character boundary. Once the buffer is full, no further items are pulled. An
/// empty buffer succeeds immediately without consuming anything, consistent with
/// `collect_iter_slice`.
///
/// # Errors
/// Returns [`FromIterError`] if the iterator ends before the buffer is full, or
/// if the buffer would end in the middle of a multi-byte character.
fn concat_str<'a, I>(uninit: &mut [MaybeUninit<u8>], iter: I) -> Result<(), FromIterError>
where
    I: IntoIterator<Item = &'a str>,
{
    let mut remaining = uninit.len();
    let mut dst = uninit.as_mut_ptr().cast::<u8>();
    let mut iter = iter.into_iter();
    // Check for a full buffer *before* pulling the next item. Previously, an empty
    // place paired with an empty iterator fell through to the error path.
    while remaining > 0 {
        let Some(s) = iter.next() else {
            return Err(FromIterError(()));
        };
        let len = remaining.min(s.len());
        // A truncated final piece must still end on a character boundary so the
        // buffer holds valid UTF-8.
        if !s.is_char_boundary(len) {
            return Err(FromIterError(()));
        }
        // SAFETY: `dst` points at the first of `remaining` unwritten bytes and
        // `len <= remaining`. `s` cannot overlap `uninit` while `uninit` is
        // mutably borrowed.
        unsafe {
            ptr::copy_nonoverlapping(s.as_ptr(), dst, len);
            dst = dst.add(len);
        }
        remaining -= len;
    }
    Ok(())
}
/// Fills the byte buffer `uninit` with the UTF-8 encoding of the chars yielded
/// by `iter`.
///
/// Once the buffer is full, no further items are pulled. An empty buffer
/// succeeds immediately without consuming anything, consistent with
/// `collect_iter_slice`.
///
/// # Errors
/// Returns [`FromIterError`] if the iterator ends before the buffer is full, or
/// if the next char's encoding does not fit in the remaining space.
fn collect_chars<I>(uninit: &mut [MaybeUninit<u8>], iter: I) -> Result<(), FromIterError>
where
    I: IntoIterator<Item = char>,
{
    let mut remaining = uninit.len();
    let mut dst = uninit.as_mut_ptr().cast::<u8>();
    let mut iter = iter.into_iter();
    // Check for a full buffer *before* pulling the next char. Previously, an empty
    // place always failed: either the loop never ran, or the first char "did not fit".
    while remaining > 0 {
        let Some(c) = iter.next() else {
            return Err(FromIterError(()));
        };
        if remaining < c.len_utf8() {
            return Err(FromIterError(()));
        }
        let mut buf = [0; 4];
        let bytes = c.encode_utf8(&mut buf).as_bytes();
        // SAFETY: `dst` points at the first of `remaining` unwritten bytes, and
        // `bytes.len() <= remaining` was checked above. `buf` is a local buffer.
        unsafe {
            ptr::copy_nonoverlapping(bytes.as_ptr(), dst, bytes.len());
            dst = dst.add(bytes.len());
        }
        remaining -= bytes.len();
    }
    Ok(())
}
/// Generates `InitPin<$ty>` and `Init<$ty>` impls for `FromIter<__I, $item>`.
///
/// Each entry has the form `@[generics]: item => place = helper`, where the
/// `@[...]:` prefix is optional and supplies extra impl generics. Both generated
/// methods call `helper(&mut *place, iterator)`. On `Ok(())` the place is
/// assumed fully initialized; on `Err` the place is returned inside an
/// `InitError`.
/// NOTE(review): presumably the helper must leave nothing to drop on failure
/// (`collect_iter_slice` drops its written prefix). Confirm against the
/// `InitError` contract.
macro_rules! derive_from_iter {
($($(@[$($g:tt)*]:)? $item:ty => $ty:ty = $imp:ident),* $(,)?) => {$(
// `__I` is the iterator type; any extra generics come from `@[...]`.
impl<$($($g)*,)? __I> InitPin<$ty> for FromIter<__I, $item>
where
__I: IntoIterator<Item = $item>,
{
fn init_pin<'a, 'b>(
self,
mut place: Uninit<'a, $ty>,
slot: DropSlot<'a, 'b, $ty>,
) -> InitPinResult<'a, 'b, $ty, FromIterError> {
match $imp(&mut *place, self.0) {
// SAFETY: the helper returned `Ok`, i.e. it wrote the whole place.
Ok(()) => Ok(unsafe { place.assume_init_pin(slot) }),
Err(err) => Err(InitError { error: err, place }.into_pin(slot)),
}
}
}
impl<$($($g)*,)? __I> Init<$ty> for FromIter<__I, $item>
where
__I: IntoIterator<Item = $item>,
{
fn init(self, mut place: Uninit<'_, $ty>) -> InitResult<'_, $ty, FromIterError> {
match $imp(&mut *place, self.0) {
// SAFETY: the helper returned `Ok`, i.e. it wrote the whole place.
Ok(()) => Ok(unsafe { place.assume_init() }),
Err(err) => Err(InitError { error: err, place }),
}
}
}
)*};
}
derive_from_iter! {
// `[T]` and `[T; N]` places filled from items of type `T`.
@[T]: T => [T] = collect_iter_slice,
@[T, const N: usize]: T => [T; N] = collect_iter_array,
// `str` places filled by concatenating `&str` items or encoding `char` items.
@['t]: &'t str => str = concat_str,
char => str = collect_chars,
}
/// Creates a [`FromIter`] initializer that fills a place with the items of `iter`.
///
/// The supported places are listed in the `derive_from_iter!` invocation:
/// - `[T]` and `[T; N]` from items of type `T`;
/// - `str` from `&str` or `char` items.
///
/// Initialization fails with [`FromIterError`] if the iterator cannot fill the
/// place.
#[inline]
pub const fn from_iter<I, T>(iter: I) -> FromIter<I, T>
where
    I: IntoIterator<Item = T>,
{
    let marker = PhantomData;
    FromIter(iter, marker)
}
/// Initializer that builds a place from its own partially-initialized prefix;
/// built by [`incremental`].
///
/// `A` is the view of the prefix handed to the closure and `T` is the closure's
/// output. Both are only type-level markers here.
#[derive(Debug, PartialEq)]
pub struct Incremental<F, A: ?Sized, T>(F, PhantomData<fn(&mut A) -> T>);

impl<F, A: ?Sized, T> Initializer for Incremental<F, A, T> {
    type Error = Infallible;
}
/// Initializes `uninit` front to back.
///
/// Each element is produced by calling `f` with the already-initialized prefix
/// as `&mut [T]`. If `f` panics, the elements written so far are dropped.
fn write_inc_slice<T, F>(uninit: &mut [MaybeUninit<T>], mut f: F)
where
    F: FnMut(&mut [T]) -> T,
{
    // Drop guard: tracks how many leading elements are initialized so that an
    // unwinding panic from `f` drops exactly those.
    struct Prefix<'s, T> {
        buf: &'s mut [MaybeUninit<T>],
        len: usize,
    }

    impl<T> Prefix<'_, T> {
        fn done(&mut self) -> &mut [T] {
            // SAFETY: the first `len` elements were written by `push`.
            unsafe { self.buf[..self.len].assume_init_mut() }
        }

        fn push(&mut self, value: T) {
            self.buf[self.len].write(value);
            self.len += 1;
        }
    }

    impl<T> Drop for Prefix<'_, T> {
        fn drop(&mut self) {
            // SAFETY: the first `len` elements are initialized and owned by us.
            unsafe { self.buf[..self.len].assume_init_drop() }
        }
    }

    let total = uninit.len();
    let mut prefix = Prefix { buf: uninit, len: 0 };
    while prefix.len < total {
        let value = f(prefix.done());
        prefix.push(value);
    }
    // Fully initialized: ownership passes to the caller, so skip the guard's drop.
    mem::forget(prefix);
}
/// Array counterpart of [`write_inc_slice`]: the closure sees the already-built
/// prefix of the array as a `&mut [T]`.
#[inline]
fn write_inc_array<T, F, const N: usize>(uninit: &mut MaybeUninit<[T; N]>, f: F)
where
    F: FnMut(&mut [T]) -> T,
{
    let elements = maybe_uninit_slice(uninit);
    write_inc_slice(elements, f);
}
/// Fills the byte buffer `uninit` by repeatedly asking `f` for the next piece of text.
///
/// Each call receives the text written so far as `&mut str` and returns the
/// next string to append. The last piece is truncated to fit the buffer. For an
/// empty buffer, `f` is not called at all, matching `write_inc_slice`.
///
/// # Panics
/// Panics if a truncated piece would be cut in the middle of a UTF-8 character.
/// NOTE(review): if `f` keeps returning `""`, this loops forever.
fn write_inc_str<'t, F>(uninit: &mut [MaybeUninit<u8>], mut f: F)
where
    F: FnMut(&mut str) -> &'t str,
{
    let total = uninit.len();
    let dst = uninit.as_mut_ptr().cast::<u8>();
    let mut initialized = 0;
    // Test before calling `f`. The old `loop`/`break` form called `f` once even
    // for an empty place and discarded its result.
    while initialized < total {
        // SAFETY: the first `initialized` bytes were copied from `str`s, and every
        // truncation was checked to fall on a char boundary, so they are valid UTF-8.
        let written = unsafe {
            let init = core::slice::from_raw_parts_mut(dst, initialized);
            core::str::from_utf8_unchecked_mut(init)
        };
        let next = f(written);
        let len = (total - initialized).min(next.len());
        assert!(
            next.is_char_boundary(len),
            "invalid UTF-8 boundary in incremental initialization"
        );
        // SAFETY: `initialized + len <= total`. The destination bytes lie past the
        // prefix handed to `f`, and `next` cannot alias `uninit` while it is
        // mutably borrowed.
        unsafe { ptr::copy_nonoverlapping(next.as_ptr(), dst.add(initialized), len) };
        initialized += len;
    }
}
/// Fills the byte buffer `uninit` one `char` at a time.
///
/// Each call to `f` receives the text written so far as `&mut str` and returns
/// the next char to append. For an empty buffer, `f` is not called at all,
/// matching `write_inc_slice`. Previously this case panicked.
///
/// # Panics
/// Panics if the UTF-8 encoding of a returned char does not fit in the
/// remaining space.
fn write_inc_chars<F>(uninit: &mut [MaybeUninit<u8>], mut f: F)
where
    F: FnMut(&mut str) -> char,
{
    let total = uninit.len();
    let dst = uninit.as_mut_ptr().cast::<u8>();
    let mut initialized = 0;
    // Test before calling `f`. With the old `loop`/`break` form, an empty place
    // still called `f`, and the space assertion below then always failed.
    while initialized < total {
        // SAFETY: the first `initialized` bytes are whole UTF-8 encodings of chars.
        let written = unsafe {
            let init = core::slice::from_raw_parts_mut(dst, initialized);
            core::str::from_utf8_unchecked_mut(init)
        };
        let next = f(written);
        assert!(
            initialized + next.len_utf8() <= total,
            "not enough space for next char in incremental initialization"
        );
        let mut buf = [0; 4];
        let bytes = next.encode_utf8(&mut buf).as_bytes();
        // SAFETY: the assertion above guarantees `initialized + bytes.len() <= total`.
        // The source is a local buffer.
        unsafe { ptr::copy_nonoverlapping(bytes.as_ptr(), dst.add(initialized), bytes.len()) };
        initialized += bytes.len();
    }
}
/// Generates `InitPin<$ty>` and `Init<$ty>` impls for
/// `Incremental<__F, A, $item>`, where `A` is the prefix view handed to the
/// closure.
///
/// Entry form: `@[generics]: item => place |[coerced] = helper`.
/// - The `@[...]:` prefix is optional and supplies extra impl generics.
/// - The optional `|[coerced]` names the type the closure sees instead of
///   `place`; for example, an array place is exposed to the closure as `[T]`.
/// - The internal `@COERCED` rules pick `coerced` when present and `place`
///   otherwise.
///
/// The helper is called as `helper(&mut *place, closure)` and its return value
/// is not inspected. The place is then assumed fully initialized, so the helper
/// must either initialize all of it or panic.
macro_rules! derive_incremental {
// Internal rules: resolve the closure's argument type.
(@COERCED $ty:ty |[$coerced:ty]) => { $coerced };
(@COERCED $ty:ty) => { $ty };
($($(@[$($g:tt)*]:)? $item:ty => $ty:ty $(|[$coerced:ty])? = $imp:ident),* $(,)?) => {$(
impl<$($($g)*,)? __F> InitPin<$ty> for Incremental<
__F,
derive_incremental!(@COERCED $ty $(|[$coerced])?),
$item,
>
where
__F: FnMut(
&mut derive_incremental!(@COERCED $ty $(|[$coerced])?)
) -> $item,
{
fn init_pin<'a, 'b>(
self,
mut place: Uninit<'a, $ty>,
slot: DropSlot<'a, 'b, $ty>,
) -> InitPinResult<'a, 'b, $ty, Infallible> {
$imp(&mut *place, self.0);
// SAFETY: the helper returned normally, so it initialized the whole place.
Ok(unsafe { place.assume_init_pin(slot) })
}
}
impl<$($($g)*,)? __F> Init<$ty> for Incremental<
__F,
derive_incremental!(@COERCED $ty $(|[$coerced])?),
$item,
>
where
__F: FnMut(
&mut derive_incremental!(@COERCED $ty $(|[$coerced])?)
) -> $item,
{
fn init(self, mut place: Uninit<'_, $ty>) -> InitResult<'_, $ty, Infallible> {
$imp(&mut *place, self.0);
// SAFETY: the helper returned normally, so it initialized the whole place.
Ok(unsafe { place.assume_init() })
}
}
)*};
}
derive_incremental! {
// Slices and arrays: the closure gets the initialized prefix as `&mut [T]`
// and returns the next element.
@[T]: T => [T] = write_inc_slice,
@[T, const N: usize]: T => [T; N] |[[T]] = write_inc_array,
// Strings: the closure gets the text written so far as `&mut str` and
// returns either the next `&str` piece or the next `char`.
@['t]: &'t str => str = write_inc_str,
char => str = write_inc_chars,
}
/// Creates an [`Incremental`] initializer that builds a place front to back.
///
/// On each step, the closure receives the already-initialized prefix and
/// returns the next piece. The supported combinations are listed in the
/// `derive_incremental!` invocation:
/// - `[T]` / `[T; N]` places with `F: FnMut(&mut [T]) -> T`;
/// - `str` places with `F: FnMut(&mut str) -> &str` or `F: FnMut(&mut str) -> char`.
///
/// # Panics
/// For `str` places, initialization panics if a returned piece cannot be
/// placed without splitting a character or overflowing the place.
#[inline]
pub const fn incremental<A: ?Sized, T, F>(f: F) -> Incremental<F, A, T>
where
    F: FnMut(&mut A) -> T,
{
    let marker = PhantomData;
    Incremental(f, marker)
}