use alloc::vec::Vec;
use core::{
array,
iter::{self, FusedIterator, TrustedLen},
marker::PhantomData,
mem::MaybeUninit,
ptr,
};
use crate::include::*;
// Largest element count accepted by `Inner::new`, which asserts that the
// place length does not exceed this value (`isize::MAX`).
const MAX_COUNT: usize = isize::MAX as _;
/// A single storage cell of a place.
///
/// The cell starts out empty (see the `Default` impl); `place` writes a value
/// into it and raises the `placed` flag, and `take` reads the value back out
/// when the flag is set.
#[derive(Debug)]
pub struct Element<T> {
// Possibly-uninitialized storage for the value.
storage: UnsafeCell<MaybeUninit<T>>,
// Whether `storage` currently holds a value written by `place`.
placed: AtomicBool,
}
impl<T> Default for Element<T> {
    /// Creates an empty element: uninitialized storage and a cleared
    /// `placed` flag.
    fn default() -> Self {
        let storage = UnsafeCell::new(MaybeUninit::uninit());
        let placed = AtomicBool::new(false);
        Self { storage, placed }
    }
}
impl<T> Element<T> {
    /// Creates a vector of `count` empty elements.
    pub fn vec(count: usize) -> Vec<Self> {
        iter::repeat_with(Default::default)
            .take(count)
            .collect::<Vec<_>>()
    }

    /// Creates an array of `N` empty elements.
    pub fn array<const N: usize>() -> [Self; N] {
        array::from_fn(|_| Default::default())
    }

    /// Writes `data` into the storage and marks the element as placed.
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to this element for the
    /// duration of the call. A value that is already stored is overwritten
    /// without being dropped.
    pub(crate) unsafe fn place(&self, data: T) {
        // SAFETY: exclusive access is guaranteed by the caller.
        unsafe { self.storage.with_mut(|ptr| (*ptr).write(data)) };
        self.placed.store(true, Relaxed);
    }

    /// Moves the stored value out, if there is one, and marks the element as
    /// empty again.
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to this element for the
    /// duration of the call, and the write performed by `place` must already
    /// be visible to the calling thread.
    pub(crate) unsafe fn take(&self) -> Option<T> {
        // The flag is cleared together with the read: the storage may outlive
        // a single slot (any `AsRef<[Element<T>]>` is accepted as a place, so
        // it can be shared or reused), and a stale `true` would make a later
        // `take` read the already moved-out value a second time.
        self.placed
            .swap(false, Relaxed)
            // SAFETY: the flag was set, so `place` initialized the storage,
            // and the caller guarantees exclusive access.
            .then(|| unsafe { self.storage.with_mut(|ptr| (*ptr).assume_init_read()) })
    }
}
/// Storage that can be viewed as a slice of [`Element`]s.
///
/// Implemented automatically for every type that implements
/// `AsRef<[Element<T>]>`.
pub trait Place<T>: AsRef<[Element<T>]> {}

impl<T, P: AsRef<[Element<T>]>> Place<T> for P {}
/// Heap-allocated state shared by every handle of one slot.
struct Inner<T, P: Place<T>> {
    /// Number of senders that have neither been sent nor dropped yet;
    /// initialized to the length of `place`.
    count: AtomicUsize,
    /// The element storage the senders write into.
    place: P,
    marker: PhantomData<[T]>,
}
impl<T, P: Place<T>> Inner<T, P> {
    /// Layout of the heap block that holds one `Inner`.
    const LAYOUT: Layout = Layout::new::<Self>();

    /// Moves `place` into a freshly allocated `Inner` whose counter starts at
    /// the number of elements in `place`.
    ///
    /// # Panics
    ///
    /// Panics if `place` is empty or holds more than `isize::MAX` elements.
    /// Allocation failure is reported through `handle_alloc_error`.
    fn new(place: P) -> NonNull<Self> {
        let len = place.as_ref().len();
        assert!(
            len <= MAX_COUNT,
            "the length of the slot must not exceed `isize::MAX`"
        );
        assert!(len > 0, "the slot must not be empty");

        let Ok(block) = Global.allocate(Self::LAYOUT) else {
            handle_alloc_error(Self::LAYOUT)
        };
        let this = block.cast::<Self>();
        // SAFETY: `this` was just allocated with the layout of `Self`, so it
        // is valid for writing one `Self`.
        unsafe {
            this.as_ptr().write(Self {
                count: AtomicUsize::new(len),
                place,
                marker: PhantomData,
            })
        }
        this
    }

    /// Takes (and thereby drops) every value still stored at index `start`
    /// or later, then drops the `Inner` itself and frees its allocation.
    ///
    /// # Safety
    ///
    /// `this` must come from [`Inner::new`], must not have been released
    /// before, and must not be accessed by anyone afterwards.
    unsafe fn drop_in_place(this: NonNull<Self>, start: usize) {
        // SAFETY: the caller guarantees that `this` is still live.
        let slots = unsafe { this.as_ref() }.place.as_ref();
        // An out-of-range `start` simply means there is nothing left to take.
        if let Some(pending) = slots.get(start..) {
            for slot in pending {
                // SAFETY: the caller is the only remaining user of the slot.
                drop(unsafe { slot.take() });
            }
        }
        // SAFETY: the caller guarantees exclusive ownership; the block was
        // allocated in `new` with `Self::LAYOUT`.
        unsafe { ptr::drop_in_place(this.as_ptr()) };
        unsafe { Global.deallocate(this.cast(), Self::LAYOUT) };
    }
}
/// Handle that fills one element of a shared place.
///
/// Consuming the handle with [`Sender::send`] stores a value; dropping it
/// leaves the element empty. Either way the shared counter is decremented.
#[derive(Debug)]
pub struct Sender<T, P: Place<T>> {
    /// Shared state of the slot.
    inner: NonNull<Inner<T, P>>,
    /// Index of the element this sender writes to.
    index: usize,
}

// NOTE(review): only `T: Send` is required here. The shared `Inner` also
// holds the `P` value, which is accessed and eventually dropped by whichever
// thread releases the last handle — confirm whether `P: Send` (or more) should
// be required as well.
unsafe impl<T: Send, P: Place<T>> Send for Sender<T, P> {}
impl<T, P: Place<T>> Sender<T, P> {
    /// Wraps `inner` and `index` into a sender.
    ///
    /// # Safety
    ///
    /// `send` dereferences `inner` and indexes its place with `index` without
    /// a bounds check, so `inner` must be live and `index` must be in bounds.
    unsafe fn new(inner: NonNull<Inner<T, P>>, index: usize) -> Self {
        Self { inner, index }
    }

    /// Stores `value` in this sender's element and releases the handle.
    ///
    /// Returns `Ok(())` while other senders of the same slot are still
    /// outstanding. The sender that brings the counter to zero instead
    /// receives `Err` with a [`SenderIter`] over the stored values.
    pub fn send(self, value: T) -> Result<(), SenderIter<T, P>> {
        let shared = self.inner;
        let previous = {
            // SAFETY: the counter still includes this sender, so the shared
            // state has not been released yet.
            let state = unsafe { shared.as_ref() };
            // SAFETY: `index` is in bounds per the contract of `Sender::new`.
            let slot = unsafe { state.place.as_ref().get_unchecked(self.index) };
            // SAFETY: this sender is the only writer of its element.
            unsafe { slot.place(value) };
            // Publish the write above to whoever observes the final decrement.
            state.count.fetch_sub(1, Release)
        };
        // The decrement has been done by hand; `Drop` must not repeat it.
        mem::forget(self);

        if previous != 1 {
            return Ok(());
        }
        // Last sender: acquire the writes published by all the others.
        atomic::fence(Acquire);
        // SAFETY: no other sender exists any more, so the iterator becomes
        // the sole owner of the shared state.
        Err(unsafe { SenderIter::new(shared) })
    }
}
impl<T, P: Place<T>> Drop for Sender<T, P> {
fn drop(&mut self) {
let inner = unsafe { self.inner.as_ref() };
if inner.count.fetch_sub(1, Relaxed) == 1 {
atomic::fence(Acquire);
unsafe { Inner::drop_in_place(self.inner, 0) }
}
}
}
/// Iterator over the values stored in a slot, returned to the last sender by
/// [`Sender::send`].
///
/// Elements that were never filled are skipped. Dropping the iterator drops
/// the values not yet yielded and frees the shared state.
#[derive(Debug)]
pub struct SenderIter<T, P: Place<T>> {
    /// Shared state of the slot.
    inner: NonNull<Inner<T, P>>,
    /// Index of the next element to inspect.
    index: usize,
}

// NOTE(review): only `T: Send` is required here; `P` is carried (and finally
// dropped) along with the shared state — confirm whether `P: Send` should be
// required as well.
unsafe impl<T: Send, P: Place<T>> Send for SenderIter<T, P> {}
impl<T, P: Place<T>> SenderIter<T, P> {
unsafe fn new(inner: NonNull<Inner<T, P>>) -> Self {
Self { inner, index: 0 }
}
}
impl<T, P: Place<T>> Iterator for SenderIter<T, P> {
    type Item = T;

    /// Yields the next stored value, skipping elements that were left empty.
    fn next(&mut self) -> Option<T> {
        // SAFETY: the iterator keeps the shared state alive until it is dropped.
        let state = unsafe { self.inner.as_ref() };
        loop {
            // Running past the end of the place ends the iteration.
            let slot = state.place.as_ref().get(self.index)?;
            self.index += 1;
            // SAFETY: this iterator is the only remaining user of the slot.
            match unsafe { slot.take() } {
                Some(value) => return Some(value),
                None => continue,
            }
        }
    }

    /// Reports no lower bound and the total place length as the upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // SAFETY: the iterator keeps the shared state alive until it is dropped.
        let state = unsafe { self.inner.as_ref() };
        (0, Some(state.place.as_ref().len()))
    }
}

impl<T, P: Place<T>> FusedIterator for SenderIter<T, P> {}
impl<T, P: Place<T>> Drop for SenderIter<T, P> {
    /// Drops the values that were not yielded yet and frees the shared state.
    fn drop(&mut self) {
        let (shared, first_pending) = (self.inner, self.index);
        // SAFETY: the iterator is the sole owner of the shared state; the
        // elements before `first_pending` have already been taken by `next`.
        unsafe { Inner::drop_in_place(shared, first_pending) }
    }
}
/// Iterator that hands out one [`Sender`] for every element of a place.
///
/// Senders that have not been handed out when the iterator is dropped are
/// dropped along with it.
#[derive(Debug)]
pub struct InitIter<T, P: Place<T>> {
    /// Shared state of the slot; only forwarded to the senders, never
    /// dereferenced after construction.
    inner: NonNull<Inner<T, P>>,
    /// Index of the next sender to hand out.
    index: usize,
    /// Total number of elements, cached at construction.
    ///
    /// The iterator itself is not counted in the shared counter, so the
    /// shared state may be freed as soon as the last handed-out sender is
    /// sent or dropped — possibly while this iterator is still alive.
    /// Reading the length through `inner` at that point would be a
    /// use-after-free.
    len: usize,
}

// NOTE(review): only `T: Send` is required here; `P` is carried along with the
// shared state — confirm whether `P: Send` should be required as well.
unsafe impl<T: Send, P: Place<T>> Send for InitIter<T, P> {}

impl<T, P: Place<T>> InitIter<T, P> {
    /// Creates an iterator that will hand out a sender for every element.
    ///
    /// # Safety
    ///
    /// `inner` must point to a live `Inner` for which no sender has been
    /// created yet.
    unsafe fn new(inner: NonNull<Inner<T, P>>) -> Self {
        // SAFETY: the caller guarantees that `inner` is live at this point.
        let len = unsafe { inner.as_ref() }.place.as_ref().len();
        InitIter {
            inner,
            index: 0,
            len,
        }
    }
}

impl<T, P: Place<T>> Iterator for InitIter<T, P> {
    type Item = Sender<T, P>;

    /// Hands out the sender for the next element, if any is left.
    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.len {
            return None;
        }
        // SAFETY: the sender for `self.index` has not been handed out yet, so
        // the shared counter is still non-zero and `inner` is live; the index
        // is below the length of the place.
        let sender = unsafe { Sender::new(self.inner, self.index) };
        self.index += 1;
        Some(sender)
    }

    /// Reports the exact number of senders that are still to be handed out.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `index` never exceeds `len`, see `next`.
        let remaining = self.len - self.index;
        (remaining, Some(remaining))
    }
}

impl<T, P: Place<T>> Drop for InitIter<T, P> {
    /// Drops every sender that has not been handed out, so that the shared
    /// counter can still reach zero.
    fn drop(&mut self) {
        self.for_each(drop)
    }
}

impl<T, P: Place<T>> ExactSizeIterator for InitIter<T, P> {}

impl<T, P: Place<T>> FusedIterator for InitIter<T, P> {}

// SAFETY: `size_hint` returns the exact number of remaining senders.
unsafe impl<T, P: Place<T>> TrustedLen for InitIter<T, P> {}
/// Builds a slot on top of `place` and returns an iterator that hands out one
/// [`Sender`] per element.
///
/// # Panics
///
/// Panics if `place` is empty or holds more than `isize::MAX` elements.
pub fn from_place<T, P: Place<T>>(place: P) -> InitIter<T, P> {
    let shared = Inner::new(place);
    // SAFETY: `shared` was just created and no sender exists for it yet.
    let senders = unsafe { InitIter::new(shared) };
    senders
}
/// Builds a slot backed by a `Vec` of `count` fresh elements.
///
/// # Panics
///
/// Panics if `count` is zero.
pub fn vec<T>(count: usize) -> InitIter<T, Vec<Element<T>>> {
    let elements = Element::vec(count);
    from_place(elements)
}
/// Builds a slot backed by an array of `N` fresh elements and returns all of
/// its senders at once.
///
/// # Panics
///
/// Panics if `N` is zero.
pub fn array<T, const N: usize>() -> [Sender<T, [Element<T>; N]>; N] {
    let shared = Inner::new(Element::array());
    // SAFETY: `shared` was just created, and `from_fn` passes every index in
    // `0..N` exactly once.
    array::from_fn(move |slot| unsafe { Sender::new(shared, slot) })
}
// Tests that send from several threads; each one runs directly in a normal
// build and inside `loom::model` when built with `cfg(loom)`.
#[cfg(test)]
mod tests {
use alloc::vec::Vec;
#[cfg(not(loom))]
use std::thread;
#[cfg(loom)]
use loom::thread;
use crate::array::{from_place, Element};
// Every sender sends its own index from a separate thread; exactly one of
// them must get the iterator back, and it must yield all values in element
// order.
#[test]
fn send() {
fn inner() {
let j = from_place(Element::array::<3>())
.enumerate()
.map(|(i, s)| thread::spawn(move || s.send(i)))
.collect::<Vec<_>>();
// Fold the thread results into the single `Err` that carries the iterator.
let iter = j
.into_iter()
.map(|j| j.join().unwrap())
.fold(Ok(()), Result::and)
.unwrap_err();
assert_eq!(iter.collect::<Vec<_>>(), [0, 1, 2]);
}
#[cfg(not(loom))]
inner();
#[cfg(loom)]
loom::model(inner);
}
// The sender with index 1 is dropped instead of sent, so its element stays
// empty and must be skipped by the iterator.
#[test]
fn drop_one() {
fn inner() {
let j = from_place(Element::vec(3))
.enumerate()
.map(|(i, s)| {
if i != 1 {
thread::spawn(move || s.send(i))
} else {
thread::spawn(move || {
drop(s);
Ok(())
})
}
})
.collect::<Vec<_>>();
let res = j
.into_iter()
.map(|j| j.join().unwrap())
.fold(Ok(()), Result::and);
// When the dropped sender happens to be the last handle, no thread receives
// the iterator, hence the conditional check.
if let Err(iter) = res {
assert_eq!(iter.collect::<Vec<_>>(), [0, 2]);
}
}
#[cfg(not(loom))]
inner();
#[cfg(loom)]
loom::model(inner);
}
}