use std::marker::PhantomPinned;
use std::mem::ManuallyDrop;
use std::ptr::NonNull;
use std::task::{Context, Waker};
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::{Arc, Mutex};
use crate::util::linked_list::{self, Link};
use crate::util::{waker_ref, Wake};
/// The intrusive list type used for both the idle and the notified list.
///
/// It is `linked_list::LinkedList` parameterised by the `ListEntry<T>` link
/// type and by the `Target` associated type of its `linked_list::Link` impl.
type LinkedList<T> =
linked_list::LinkedList<ListEntry<T>, <ListEntry<T> as linked_list::Link>::Target>;
/// A collection of `T` values in which every entry sits on either an "idle"
/// list or a "notified" list.
///
/// New entries start out idle (`insert_idle`). Waking the waker associated
/// with an entry moves it to the notified list, and `pop_notified` takes an
/// entry off the notified list and puts it back on the idle list.
pub(crate) struct IdleNotifiedSet<T> {
/// The two lists plus the stored waker, protected by a mutex. Each entry
/// holds a clone of this `Arc` in its `parent` field.
lists: Arc<Lists<T>>,
/// Number of entries owned by the set: incremented by `insert_idle`,
/// decremented by `EntryInOneOfTheLists::remove`, and reset by `drain`.
length: usize,
}
/// A handle to a single entry of an `IdleNotifiedSet`.
///
/// The handle mutably borrows the set for `'a`, so no other operation on the
/// set can run while it is alive. `remove` treats an entry whose list marker
/// is `List::Neither` as unreachable, i.e. the entry is expected to be on the
/// idle or the notified list.
pub(crate) struct EntryInOneOfTheLists<'a, T> {
/// The entry this handle refers to.
entry: Arc<ListEntry<T>>,
/// The set that owns `entry`.
set: &'a mut IdleNotifiedSet<T>,
}
/// The mutex-protected state shared between the set and all of its entries.
type Lists<T> = Mutex<ListsInner<T>>;
/// The state guarded by the `Lists` mutex.
struct ListsInner<T> {
/// Entries whose waker has been woken since they were last made idle.
notified: LinkedList<T>,
/// Entries that have not been notified.
idle: LinkedList<T>,
/// Waker registered by `pop_notified`. It is taken and woken when an entry
/// moves from the idle list to the notified list.
waker: Option<Waker>,
}
/// Records which list an entry is currently linked into.
#[derive(Copy, Clone, Eq, PartialEq)]
enum List {
/// The entry is on the `notified` list.
Notified,
/// The entry is on the `idle` list.
Idle,
/// The entry is on neither list; set by `remove` and by `drain` (through
/// `move_to_new_list`) when the entry leaves the set.
Neither,
}
/// One node of the set. Nodes are shared through `Arc` and linked into the
/// lists via the intrusive `pointers` field.
struct ListEntry<T> {
/// Intrusive linked-list pointers for whichever list holds this entry.
pointers: linked_list::Pointers<ListEntry<T>>,
/// The shared lists this entry belongs to; locked by the `Wake` impl.
parent: Arc<Lists<T>>,
/// The stored value. Wrapped in `ManuallyDrop` because it is moved out
/// explicitly with `ManuallyDrop::take` (in `remove` and `drain`) rather
/// than being dropped together with the entry.
value: UnsafeCell<ManuallyDrop<T>>,
/// Which list this entry is on. Every write in this file happens while
/// the `parent` mutex is held.
my_list: UnsafeCell<List>,
/// Makes the entry `!Unpin`.
// NOTE(review): presumably required because of the intrusive `pointers`
// field — confirm against `linked_list::Pointers`.
_pin: PhantomPinned,
}
// Generates `ListEntry::addr_of_pointers`, which maps a pointer to an entry to
// a pointer to its `pointers` field; it is used by `Link::pointers` below.
// NOTE(review): the macro presumably performs the field projection without
// creating an intermediate reference — confirm against the macro definition.
generate_addr_of_methods! {
impl<T> ListEntry<T> {
unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<ListEntry<T>>> {
&self.pointers
}
}
}
// Mutable access to the set gives mutable access to the stored values
// (`for_each`, `with_value_and_context`), so sending the set requires
// `T: Send`.
unsafe impl<T: Send> Send for IdleNotifiedSet<T> {}
// Sharing the set across threads is only allowed when `T: Sync`.
unsafe impl<T: Sync> Sync for IdleNotifiedSet<T> {}
// `ListEntry` is `Send + Sync` for every `T`. The `Wake` impl in this file
// only touches `parent`, `my_list` and the list pointers, never `value`, and
// because `value` is `ManuallyDrop`, dropping an entry does not run `T`'s
// destructor.
// NOTE(review): soundness relies on no code reaching `value` through a shared
// `ListEntry` without going via the owning `IdleNotifiedSet` — confirm when
// adding new accessors.
unsafe impl<T> Send for ListEntry<T> {}
unsafe impl<T> Sync for ListEntry<T> {}
impl<T> IdleNotifiedSet<T> {
pub(crate) fn new() -> Self {
let lists = Mutex::new(ListsInner {
notified: LinkedList::new(),
idle: LinkedList::new(),
waker: None,
});
IdleNotifiedSet {
lists: Arc::new(lists),
length: 0,
}
}
/// Returns the number of entries currently in the set (idle and notified
/// combined).
pub(crate) fn len(&self) -> usize {
self.length
}
/// Returns `true` when the set holds no entries.
pub(crate) fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Stores `value` in a new entry, links that entry at the front of the idle
/// list, and returns a handle to it.
pub(crate) fn insert_idle(&mut self, value: T) -> EntryInOneOfTheLists<'_, T> {
    self.length += 1;

    let new_entry = Arc::new(ListEntry {
        parent: Arc::clone(&self.lists),
        value: UnsafeCell::new(ManuallyDrop::new(value)),
        my_list: UnsafeCell::new(List::Idle),
        pointers: linked_list::Pointers::new(),
        _pin: PhantomPinned,
    });

    // The guard is a temporary, so the mutex is released as soon as the push
    // has completed.
    self.lists.lock().idle.push_front(Arc::clone(&new_entry));

    // The entry was just linked into the idle list, so the handle refers to
    // an entry that is in one of the lists.
    EntryInOneOfTheLists {
        entry: new_entry,
        set: self,
    }
}
pub(crate) fn pop_notified(&mut self, waker: &Waker) -> Option<EntryInOneOfTheLists<'_, T>> {
if self.length == 0 {
return None;
}
let mut lock = self.lists.lock();
let should_update_waker = match lock.waker.as_mut() {
Some(cur_waker) => !waker.will_wake(cur_waker),
None => true,
};
if should_update_waker {
lock.waker = Some(waker.clone());
}
let entry = lock.notified.pop_back()?;
lock.idle.push_front(entry.clone());
entry.my_list.with_mut(|ptr| unsafe {
*ptr = List::Idle;
});
drop(lock);
Some(EntryInOneOfTheLists { entry, set: self })
}
/// Calls `func` once with a mutable reference to every value in the set.
///
/// Raw pointers to all values are first collected under a single lock
/// acquisition (idle list first, then notified list); `func` is only invoked
/// after the mutex has been released.
pub(crate) fn for_each<F: FnMut(&mut T)>(&mut self, mut func: F) {
// Appends a raw pointer to the value of every entry in `list` to `ptrs`,
// starting at `list.last()` and following the `get_prev` links.
fn get_ptrs<T>(list: &mut LinkedList<T>, ptrs: &mut Vec<*mut T>) {
let mut node = list.last();
while let Some(entry) = node {
ptrs.push(entry.value.with_mut(|ptr| {
let ptr: *mut ManuallyDrop<T> = ptr;
// `ManuallyDrop<T>` is a transparent wrapper around `T`, so the pointer
// can be cast directly.
let ptr: *mut T = ptr.cast();
ptr
}));
let prev = entry.pointers.get_prev();
// NOTE(review): relies on `prev` pointing at a live entry of the same
// list; the only caller holds the mutex that guards the list.
node = prev.map(|prev| unsafe { &*prev.as_ptr() });
}
}
// Collect every pointer while holding the lock once, so an entry cannot
// move between the two lists part-way through and be recorded twice.
let mut ptrs = Vec::with_capacity(self.len());
{
let mut lock = self.lists.lock();
get_ptrs(&mut lock.idle, &mut ptrs);
get_ptrs(&mut lock.notified, &mut ptrs);
}
// Every entry counted by `len()` is expected to be on one of the two lists.
debug_assert_eq!(ptrs.len(), ptrs.capacity());
for ptr in ptrs {
// SAFETY: each pointer was taken from an entry that was on one of the
// lists. Values are only moved out by `remove` and `drain`, both of which
// need the `&mut IdleNotifiedSet` that this method holds for its whole
// duration. A waker may move an entry between lists meanwhile, but the
// `Wake` impl never touches `value`.
func(unsafe { &mut *ptr });
}
}
/// Removes every entry from the set, passing each stored value to `func`.
///
/// All entries are first moved, under a single lock acquisition, into a
/// private list; `func` then runs without the mutex held. If `func` panics,
/// the remaining values are still passed to `func` while unwinding.
pub(crate) fn drain<F: FnMut(T)>(&mut self, func: F) {
if self.length == 0 {
// Fast path: nothing to drain, so skip taking the lock.
return;
}
self.length = 0;
// Guard object that owns the detached entries. Entries in `all_entries`
// have `my_list == List::Neither` (set by `move_to_new_list`) and still
// hold their value. Its `Drop` impl finishes the drain, which is what keeps
// the remaining values from being leaked when `func` panics.
struct AllEntries<T, F: FnMut(T)> {
all_entries: LinkedList<T>,
func: F,
}
impl<T, F: FnMut(T)> AllEntries<T, F> {
// Pops one entry and hands its value to `func`. Returns `false` once the
// list is empty.
fn pop_next(&mut self) -> bool {
if let Some(entry) = self.all_entries.pop_back() {
// SAFETY: the entry was just unlinked from the private list, so nothing
// else will read its value; taking it out of the `ManuallyDrop` happens
// exactly once per entry.
entry
.value
.with_mut(|ptr| unsafe { (self.func)(ManuallyDrop::take(&mut *ptr)) });
true
} else {
false
}
}
}
impl<T, F: FnMut(T)> Drop for AllEntries<T, F> {
fn drop(&mut self) {
// Runs on normal exit (list already empty) and during unwinding after a
// panic in `func`, in which case it processes the remaining entries.
while self.pop_next() {}
}
}
let mut all_entries = AllEntries {
all_entries: LinkedList::new(),
func,
};
{
let mut lock = self.lists.lock();
// SAFETY: the mutex is held while the entries' `my_list` markers are
// rewritten, and the destination is a freshly created list.
unsafe {
move_to_new_list(&mut lock.idle, &mut all_entries.all_entries);
move_to_new_list(&mut lock.notified, &mut all_entries.all_entries);
}
}
// Process entries until the private list is empty. A panic here is picked
// up by `AllEntries::drop`, which continues with the remaining entries.
while all_entries.pop_next() {}
}
}
/// Moves every entry of `from` into `to`, marking each one as
/// `List::Neither` on the way.
///
/// Entries are popped from the back of `from` and pushed to the front of
/// `to`.
///
/// # Safety
///
/// This writes to each entry's `my_list` cell without synchronisation of its
/// own, so the caller must guarantee exclusive access to those cells. The
/// caller in this file (`drain`) does so by holding the `Lists` mutex.
unsafe fn move_to_new_list<T>(from: &mut LinkedList<T>, to: &mut LinkedList<T>) {
while let Some(entry) = from.pop_back() {
entry.my_list.with_mut(|ptr| {
*ptr = List::Neither;
});
to.push_front(entry);
}
}
impl<'a, T> EntryInOneOfTheLists<'a, T> {
/// Removes this entry from the set and returns the value it stored.
///
/// The set's length is decremented, the entry is unlinked from whichever list
/// it is on, and its list marker becomes `List::Neither`.
pub(crate) fn remove(self) -> T {
    self.set.length -= 1;

    {
        let mut guard = self.set.lists.lock();

        // SAFETY: the lists mutex is held, so nothing else accesses
        // `my_list`; the entry is unlinked below to match the new marker.
        let previous = self
            .entry
            .my_list
            .with_mut(|list| unsafe { std::mem::replace(&mut *list, List::Neither) });

        let raw = ListEntry::as_raw(&self.entry);

        // SAFETY: `previous` names the list the entry was recorded as being
        // on, and that is the list it is removed from.
        let unlinked = unsafe {
            match previous {
                List::Idle => guard.idle.remove(raw),
                List::Notified => guard.notified.remove(raw),
                // A handle always refers to an entry on one of the two lists.
                List::Neither => unreachable!(),
            }
        };
        unlinked.unwrap();
    }

    // SAFETY: the entry is no longer on either list and this handle holds the
    // `&mut IdleNotifiedSet` that owns it, so the value is taken exactly once.
    self.entry
        .value
        .with_mut(|slot| unsafe { ManuallyDrop::take(&mut *slot) })
}
/// Runs `func` with mutable access to the stored value and a `Context` whose
/// waker is backed by this entry, returning whatever `func` returns.
///
/// Waking that waker goes through the `Wake` impl of `ListEntry`.
pub(crate) fn with_value_and_context<F, U>(&mut self, func: F) -> U
where
    F: FnOnce(&mut T, &mut Context<'_>) -> U,
    T: 'static,
{
    let entry_waker = waker_ref(&self.entry);
    let mut cx = Context::from_waker(&entry_waker);

    self.entry.value.with_mut(|slot| {
        // SAFETY: this handle holds the `&mut IdleNotifiedSet` that owns the
        // entry, which is the only route to the value.
        let value: &mut T = unsafe { &mut *slot };
        func(value, &mut cx)
    })
}
}
impl<T> Drop for IdleNotifiedSet<T> {
    /// Drops every value still stored in the set.
    ///
    /// In debug builds, unless the thread is already panicking, it also checks
    /// that both lists ended up empty.
    fn drop(&mut self) {
        // Empties both lists, dropping each value.
        self.drain(drop);

        #[cfg(debug_assertions)]
        {
            // Skip the checks while unwinding so a failure here cannot turn
            // into a double panic.
            if std::thread::panicking() {
                return;
            }

            let lock = self.lists.lock();
            assert!(lock.idle.is_empty());
            assert!(lock.notified.is_empty());
        }
    }
}
impl<T: 'static> Wake for ListEntry<T> {
fn wake_by_ref(me: &Arc<Self>) {
let mut lock = me.parent.lock();
let old_my_list = me.my_list.with_mut(|ptr| unsafe {
let old_my_list = *ptr;
if old_my_list == List::Idle {
*ptr = List::Notified;
}
old_my_list
});
if old_my_list == List::Idle {
let me = unsafe {
lock.idle.remove(NonNull::from(&**me)).unwrap()
};
lock.notified.push_front(me);
if let Some(waker) = lock.waker.take() {
drop(lock);
waker.wake();
}
}
}
fn wake(me: Arc<Self>) {
Self::wake_by_ref(&me)
}
}
// Lets `ListEntry<T>` be stored in a `linked_list::LinkedList`, using
// `Arc<ListEntry<T>>` as the owning handle.
unsafe impl<T> linked_list::Link for ListEntry<T> {
    type Handle = Arc<ListEntry<T>>;
    type Target = ListEntry<T>;

    /// Returns a pointer to the entry behind `handle` without touching the
    /// reference count.
    fn as_raw(handle: &Self::Handle) -> NonNull<ListEntry<T>> {
        let raw = Arc::as_ptr(handle) as *mut ListEntry<T>;
        // SAFETY: `Arc::as_ptr` never returns a null pointer.
        unsafe { NonNull::new_unchecked(raw) }
    }

    /// Rebuilds the owning `Arc` from a raw entry pointer by forwarding it to
    /// `Arc::from_raw`; the caller must satisfy that function's requirements.
    unsafe fn from_raw(ptr: NonNull<ListEntry<T>>) -> Arc<ListEntry<T>> {
        let raw: *const ListEntry<T> = ptr.as_ptr();
        Arc::from_raw(raw)
    }

    /// Projects a pointer to an entry to a pointer to its `pointers` field.
    unsafe fn pointers(
        target: NonNull<ListEntry<T>>,
    ) -> NonNull<linked_list::Pointers<ListEntry<T>>> {
        Self::addr_of_pointers(target)
    }
}