use super::{ExactParallelSink, FromExactParallelSink};
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::Mutex;
impl<T: Send, const N: usize> FromExactParallelSink for [T; N] {
type Item = T;
type Sink = ArrayParallelSink<T, N>;
unsafe fn finalize(sink: Self::Sink) -> Self {
debug_assert!(sink.skipped.into_inner().unwrap().is_empty());
unsafe { sink.array.into_array() }
}
}
/// Sink that collects exactly `N` items, written by index, into a `[T; N]`.
///
/// Items are written directly into uninitialized array storage. Ranges that will never be
/// written are remembered so that a cancelled collection drops only the items actually stored.
#[must_use = "iterator adaptors are lazy"]
pub struct ArrayParallelSink<T: Send, const N: usize> {
    /// Index ranges declared as skipped. They are only recorded when `T` needs dropping.
    skipped: Mutex<Vec<std::ops::Range<usize>>>,
    /// Storage for the `N` output slots, initially uninitialized.
    array: ArrayWrapper<T, N>,
}
impl<T: Send, const N: usize> ExactParallelSink for ArrayParallelSink<T, N> {
    type Item = T;

    // Skipped ranges only need tracking, and cancellation only needs to drop anything,
    // when dropping a `T` actually does something.
    const NEEDS_CLEANUP: bool = std::mem::needs_drop::<T>();

    /// Creates an empty sink for collecting exactly `N` items.
    ///
    /// # Panics
    ///
    /// Panics if `len` differs from the array length `N`.
    fn new(len: usize) -> Self {
        assert_eq!(
            len, N,
            "tried to collect an iterator into an array of the wrong length"
        );
        Self {
            array: ArrayWrapper::new(),
            skipped: Mutex::new(Vec::new()),
        }
    }

    /// Moves `item` into slot `index` of the array.
    ///
    /// # Safety
    ///
    /// `index` must be less than `N` (checked in debug builds only). Each slot must be
    /// written at most once, and never concurrently with another access to the same slot.
    /// NOTE(review): presumably this mirrors the `ExactParallelSink::push_item` contract —
    /// confirm against the trait documentation.
    unsafe fn push_item(&self, index: usize, item: Self::Item) {
        debug_assert!(index < N);
        let base_ptr: *mut T = self.array.start();
        // SAFETY: `index < N` keeps the offset inside the `N`-element array.
        let item_ptr: *mut T = unsafe { base_ptr.add(index) };
        // SAFETY: the slot is in bounds and not accessed by anyone else. It has not been
        // initialized before, so overwriting it without dropping the old value is correct.
        unsafe { std::ptr::write(item_ptr, item) };
    }

    /// Records that the slots in `range` will never be written, so `cancel` must not drop
    /// them.
    ///
    /// # Safety
    ///
    /// `range` must satisfy `range.start <= range.end <= N` (checked in debug builds only).
    unsafe fn skip_item_range(&self, range: std::ops::Range<usize>) {
        if Self::NEEDS_CLEANUP {
            debug_assert!(range.start <= range.end);
            debug_assert!(range.start <= N);
            debug_assert!(range.end <= N);
            // An empty range covers no slot, so there is nothing for `cancel` to avoid.
            if !range.is_empty() {
                // The only work done under this lock is a `push`, so the vector is
                // consistent even if the mutex was poisoned.
                self.skipped
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner)
                    .push(range);
            }
        }
    }

    /// Abandons the collection. Every slot not covered by a recorded skipped range is
    /// dropped in place.
    ///
    /// # Safety
    ///
    /// Every index in `0..N` that was not passed to `skip_item_range` must have been
    /// written by `push_item`.
    unsafe fn cancel(self) {
        if !Self::NEEDS_CLEANUP {
            return;
        }
        let base_ptr: *mut T = self.array.start();
        let mut skipped = self
            .skipped
            .into_inner()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        skipped.sort_unstable_by_key(|range| range.start);

        // `cursor` is the first index that has not yet been dropped or identified as
        // skipped. Taking the `max` keeps it monotonic even when ranges share a start
        // (the sort is unstable) or overlap, so an inverted range is never cleaned up.
        let mut cursor: usize = 0;
        for range in skipped {
            if cursor < range.start {
                // SAFETY: `cursor..range.start` lies within `0..N`. It falls in a gap
                // between skipped ranges, so every slot in it was written by `push_item`.
                unsafe { Self::cleanup_item_range(base_ptr, cursor..range.start) };
            }
            cursor = cursor.max(range.end);
        }
        if cursor < N {
            // SAFETY: the tail after the last skipped range holds only pushed items.
            unsafe { Self::cleanup_item_range(base_ptr, cursor..N) };
        }
    }
}

impl<T: Send, const N: usize> ArrayParallelSink<T, N> {
    /// Drops, in place, the items stored at indices `range` of the array that starts at
    /// `base_ptr`. Does nothing when `T` has no drop glue.
    ///
    /// # Safety
    ///
    /// - `base_ptr` must point to the first slot of an `N`-element array of `T`.
    /// - `range` must satisfy `range.start <= range.end <= N`.
    /// - Every slot in `range` must hold an initialized `T` that is neither read nor
    ///   dropped again afterwards.
    unsafe fn cleanup_item_range(base_ptr: *mut T, range: std::ops::Range<usize>) {
        if Self::NEEDS_CLEANUP {
            debug_assert!(range.start <= range.end);
            debug_assert!(range.start <= N);
            debug_assert!(range.end <= N);
            // SAFETY: `range.start <= N`, so the offset stays within (or one past) the array.
            let start_ptr: *mut T = unsafe { base_ptr.add(range.start) };
            // `Range::len` yields 0 for an inverted range instead of wrapping around.
            let slice: *mut [T] = std::ptr::slice_from_raw_parts_mut(start_ptr, range.len());
            // SAFETY: the caller guarantees that every slot in `range` holds an initialized
            // item that no one else will use.
            unsafe { std::ptr::drop_in_place(slice) };
        }
    }
}
/// Storage for `N` possibly-uninitialized slots of `T` that can be filled individually
/// through a raw pointer obtained from a shared reference (see [`ArrayWrapper::start`]).
struct ArrayWrapper<T, const N: usize>(UnsafeCell<[MaybeUninit<T>; N]>);

impl<T, const N: usize> ArrayWrapper<T, N> {
    /// Creates storage with all `N` slots uninitialized.
    fn new() -> Self {
        ArrayWrapper(UnsafeCell::new([const { MaybeUninit::uninit() }; N]))
    }

    /// Consumes the storage and returns the array of initialized items.
    ///
    /// # Safety
    ///
    /// All `N` slots must have been initialized.
    unsafe fn into_array(self) -> [T; N] {
        let array: [MaybeUninit<T>; N] = self.0.into_inner();
        // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so
        // `[MaybeUninit<T>; N]` and `[T; N]` have identical layouts. The caller guarantees
        // every element is initialized. `array` is a `[MaybeUninit<T>; N]`, which never
        // drops its contents, so moving the items out here cannot cause a double drop.
        // (This is the stable equivalent of the nightly-only `<[MaybeUninit<T>; N]>::transpose`.)
        unsafe { std::ptr::addr_of!(array).cast::<[T; N]>().read() }
    }

    /// Returns a pointer to the first slot. Slot `i` lives at `start().add(i)` for `i < N`.
    ///
    /// Writing through the pointer is permitted with only `&self` because the storage sits
    /// in an `UnsafeCell`.
    fn start(&self) -> *mut T {
        // An array starts at its first element, and `MaybeUninit<T>` is layout-compatible
        // with `T`, so the array pointer can be cast to an element pointer. This is the
        // stable equivalent of the nightly-only `<*mut [T; N]>::as_mut_ptr`.
        self.0.get().cast::<T>()
    }
}

// SAFETY: a shared `&ArrayWrapper` only exposes a raw pointer. The `unsafe` code that writes
// through it must keep accesses to each slot exclusive. `T: Send` is required because an
// item written on one thread may later be returned or dropped on another.
unsafe impl<T: Send, const N: usize> Sync for ArrayWrapper<T, N> {}