use core::{
ops::{Deref, DerefMut},
mem::{MaybeUninit, transmute},
ptr::{drop_in_place, write, copy_nonoverlapping},
};
use alloc::vec::Vec;
/// Owns values of type `T` that were written into a borrowed slice of
/// `MaybeUninit<T>`, exposes them as `[T]`, and drops them when the guard
/// itself is dropped.
///
/// Invariant: every slot of `memory` holds an initialized `T` for as long as
/// the guard exists. The constructors only ever build a guard over slots they
/// have just written.
pub struct SliceMemoryGuard<'a, T> {
    memory: &'a mut [MaybeUninit<T>],
}

impl<'a, T> SliceMemoryGuard<'a, T> {
    /// Fills every slot of `memory` with `init(index)` and returns a guard
    /// that owns the resulting values.
    ///
    /// # Safety
    ///
    /// Each slot is overwritten without reading or dropping its previous
    /// contents, so any value the caller had stored there is leaked.
    /// NOTE(review): no other caller obligation is visible in this code;
    /// confirm the intended contract.
    ///
    /// # Panics
    ///
    /// Propagates any panic from `init`. The values written before the panic
    /// are leaked (not dropped), because no guard owns them yet.
    #[inline]
    pub unsafe fn new(memory: &'a mut [MaybeUninit<T>], mut init: impl FnMut(usize) -> T) -> Self {
        for (index, slot) in memory.iter_mut().enumerate() {
            // SAFETY: `slot` is a valid, exclusively borrowed `MaybeUninit<T>`.
            write(slot.as_mut_ptr(), init(index));
        }
        SliceMemoryGuard { memory }
    }

    /// Moves the items of `iter` into `memory`.
    ///
    /// - If `iter` yields at most `memory.len()` items, returns `Ok` with a
    ///   guard over the filled prefix of `memory`.
    /// - Otherwise returns `Err` with a `Vec` that holds every item in
    ///   iteration order. This includes the items already moved out of
    ///   `memory`, whose slots must then be treated as uninitialized again.
    ///
    /// When `memory` is exactly filled, `iter.next()` is called once more to
    /// check whether there are more items.
    ///
    /// # Safety
    ///
    /// Each slot is overwritten without reading or dropping its previous
    /// contents. NOTE(review): no other caller obligation is visible in this
    /// code; confirm the intended contract.
    ///
    /// # Panics
    ///
    /// Propagates any panic from `iter`. If it happens while `memory` is
    /// being filled, the items already written are leaked. Once the overflow
    /// `Vec` owns them, they are dropped normally.
    #[inline]
    pub unsafe fn new_from_iter(
        memory: &'a mut [MaybeUninit<T>],
        mut iter: impl Iterator<Item = T>,
    ) -> Result<Self, Vec<T>> {
        for (index, slot) in memory.iter_mut().enumerate() {
            match iter.next() {
                // SAFETY: `slot` is a valid, exclusively borrowed `MaybeUninit<T>`.
                Some(value) => write(slot.as_mut_ptr(), value),
                None => {
                    // The iterator ran out early: only the first `index` slots are initialized.
                    return Ok(SliceMemoryGuard {
                        memory: &mut memory[..index],
                    });
                }
            }
        }

        let next_item = match iter.next() {
            Some(item) => item,
            // Exact fit: every slot is initialized and nothing is left over.
            None => return Ok(SliceMemoryGuard { memory }),
        };

        // Overflow: move everything into a heap-allocated Vec instead.
        let len = memory.len();
        let mut vec = Vec::<T>::with_capacity(len + 1);
        // SAFETY: all `len` slots were initialized by the loop above, `vec` has
        // room for at least `len` elements, and the two buffers cannot overlap.
        // The values are moved bitwise, so `vec` now owns them.
        copy_nonoverlapping(memory.as_ptr() as *const T, vec.as_mut_ptr(), len);
        // `set_len` only after the elements are initialized, as its contract requires.
        vec.set_len(len);
        vec.push(next_item);
        vec.extend(iter);
        Err(vec)
    }
}

impl<'a, T> Deref for SliceMemoryGuard<'a, T> {
    type Target = [T];

    /// Views the guarded slots as initialized `T`s.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and every slot
        // is initialized (struct invariant).
        unsafe { transmute::<&[MaybeUninit<T>], &[T]>(&*self.memory) }
    }
}

impl<'a, T> DerefMut for SliceMemoryGuard<'a, T> {
    /// Mutably views the guarded slots as initialized `T`s.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same layout argument as `deref`, and the guard holds the only
        // mutable borrow of `memory`.
        unsafe { transmute::<&mut [MaybeUninit<T>], &mut [T]>(&mut *self.memory) }
    }
}

impl<'a, T> Drop for SliceMemoryGuard<'a, T> {
    /// Drops every guarded value exactly once.
    #[inline]
    fn drop(&mut self) {
        let elements: &mut [T] = &mut **self;
        // SAFETY: every element is initialized (struct invariant) and is never
        // used again after this point. Dropping the slice as a whole, rather
        // than element by element, means that if one destructor panics, the
        // remaining elements are still dropped during unwinding instead of
        // being leaked.
        unsafe { drop_in_place(elements as *mut [T]) };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new_from_iter` must return `Ok` for every length that fits in `memory`,
    /// including exactly `memory.len()`. For anything longer it must return
    /// `Err` holding every item, in order.
    #[test]
    fn new_from_iter_uses_exactly_same_count_that_collects() {
        let mut memory: [MaybeUninit<usize>; 128] = unsafe { MaybeUninit::uninit().assume_init() };
        // `0..=128` includes the exact-fit boundary, where the extra
        // `iter.next()` after filling `memory` returns `None`.
        for count in 0..=128 {
            let result = unsafe { SliceMemoryGuard::new_from_iter(&mut memory, 0..count) };
            assert!(result.is_ok());
            let guard = result.unwrap();
            assert_eq!(guard.len(), count);
            assert!(guard.iter().cloned().eq(0..count));
        }
        for count in [129, 200, 512].iter().cloned() {
            let result = unsafe { SliceMemoryGuard::new_from_iter(&mut memory, 0..count) };
            assert!(result.is_err());
            let vec = result.err().unwrap();
            assert_eq!(vec.len(), count);
            assert_eq!(vec, (0..count).collect::<Vec<_>>());
        }
    }

    /// A zero-length buffer yields an empty guard for an empty iterator and
    /// spills everything into the `Vec` otherwise.
    #[test]
    fn new_from_iter_with_empty_memory() {
        let mut memory: [MaybeUninit<usize>; 0] = [];
        let guard = unsafe { SliceMemoryGuard::new_from_iter(&mut memory, 0..0) }.unwrap();
        assert!(guard.is_empty());
        drop(guard);

        let spilled = unsafe { SliceMemoryGuard::new_from_iter(&mut memory, 0..3) }.err().unwrap();
        assert_eq!(spilled, (0..3).collect::<Vec<_>>());
    }
}