use std::sync::Arc;
use std::sync::atomic::Ordering;
use diatomic_waker::WakeSource;
use futures_task::{ArcWake, WakerRef, waker_ref};
use crate::loom_exports::sync::atomic::{AtomicU32, AtomicU64};
/// Value of `Task::next` for a task that is not currently scheduled.
const SLEEPING: u32 = u32::MAX;
/// Index sentinel marking the bottom of the scheduled list (or an empty list).
const EMPTY: u32 = SLEEPING - 1;
/// Selects the task-index half (low 32 bits) of the packed list head.
const INDEX_MASK: u64 = (1_u64 << u32::BITS) - 1;
/// Selects the notification-countdown half (high 32 bits) of the packed list head.
const COUNTDOWN_MASK: u64 = u64::MAX ^ INDEX_MASK;
/// One unit of countdown, i.e. the lowest bit of the countdown half.
const COUNTDOWN_ONE: u64 = INDEX_MASK + 1;
/// A resizable set of tasks that share one lock-free list of scheduled task
/// indices and one `WakeSource` notifier.
pub(crate) struct TaskSet {
    /// Every task created so far, indexed by `Task::idx`. Entries are only
    /// ever appended (by `with_len` / `resize`); nothing in this file removes them.
    tasks: Vec<Arc<Task>>,
    /// State shared with every task: the packed scheduled-list head and the notifier.
    shared: Arc<Shared>,
    /// Number of active tasks. Only indices below this value are yielded by
    /// `TaskIterator` and accepted by `waker_of`.
    task_count: usize,
}
impl TaskSet {
    /// Creates an empty set of tasks whose scheduling notifications are
    /// delivered through `notifier`.
    ///
    /// # Panics
    ///
    /// Panics on targets where `usize` is narrower than 32 bits, since task
    /// indices are stored as `u32` and converted to `usize` for indexing.
    // The assertion is on constants by design: it is a target-width check.
    #[allow(clippy::assertions_on_constants)]
    pub(crate) fn new(notifier: WakeSource) -> Self {
        assert!(usize::BITS >= u32::BITS);

        Self {
            tasks: Vec::new(),
            shared: Arc::new(Shared {
                head: AtomicU64::new(EMPTY as u64),
                notifier,
            }),
            task_count: 0,
        }
    }

    /// Creates a set of `len` active tasks whose scheduling notifications are
    /// delivered through `notifier`.
    ///
    /// Equivalent to `TaskSet::new(notifier)` followed by `resize(len)`.
    ///
    /// # Panics
    ///
    /// Panics if `usize` is narrower than 32 bits, or if `len` exceeds `EMPTY`
    /// (which would let a task index collide with the `EMPTY`/`SLEEPING`
    /// sentinels).
    pub(crate) fn with_len(notifier: WakeSource, len: usize) -> Self {
        let mut set = Self::new(notifier);
        set.resize(len);
        set
    }

    /// Detaches every currently scheduled task and returns an iterator over
    /// their active indices, or returns `None` if no task was scheduled.
    ///
    /// The scheduled list is empty right after this call. The two cases
    /// handle the notification countdown differently:
    ///
    /// * If some tasks were scheduled, the whole list is detached and any
    ///   armed notification countdown is cancelled. `notify_count` is ignored.
    /// * If no task was scheduled, a countdown of `notify_count` is armed.
    ///   The notifier is then notified once `notify_count` further tasks have
    ///   been pushed onto the list. A `notify_count` of zero requests no
    ///   notification.
    ///
    /// # Panics
    ///
    /// Panics if `notify_count` does not fit in a `u32`.
    pub(crate) fn take_scheduled(&self, notify_count: usize) -> Option<TaskIterator<'_>> {
        let countdown = u32::try_from(notify_count).expect("notify_count must fit in a u32");

        let mut head = self.shared.head.load(Ordering::Relaxed);
        loop {
            // An empty list arms the countdown. A non-empty list is detached
            // wholesale, which also clears any pending countdown.
            let new_head = if head & INDEX_MASK == EMPTY as u64 {
                (countdown as u64 * COUNTDOWN_ONE) | EMPTY as u64
            } else {
                EMPTY as u64
            };

            // Ordering: Acquire pairs with the Release compare-exchange on
            // `head` in `Task::wake_by_ref`. This makes the `next` links written
            // by wakers visible to whoever walks the detached list.
            match self.shared.head.compare_exchange_weak(
                head,
                new_head,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(h) => head = h,
            }
        }

        let index = (head & INDEX_MASK) as u32;
        if index == EMPTY {
            None
        } else {
            Some(TaskIterator {
                task_list: self,
                next_index: index,
            })
        }
    }

    /// Discards all scheduled tasks and cancels any armed notification
    /// countdown. Discarded tasks are returned to the sleeping state.
    ///
    /// When nothing is scheduled and no countdown is armed, this costs a
    /// single relaxed load.
    pub(crate) fn discard_scheduled(&self) {
        if self.shared.head.load(Ordering::Relaxed) != EMPTY as u64 {
            // With a count of zero, `take_scheduled` leaves `head` fully
            // cleared. Dropping the returned iterator immediately resets
            // every detached task to `SLEEPING`.
            let _ = self.take_scheduled(0);
        }
    }

    /// Sets the number of active tasks to `len`.
    ///
    /// Growing creates new sleeping tasks for the missing indices. Shrinking
    /// only lowers the active count: tasks at index `len` and above stay
    /// allocated. Wakers for those tasks may still push their index onto the
    /// scheduled list, and `TaskIterator` indexes `tasks` with that index
    /// before skipping inactive entries. Growing again later reuses them.
    ///
    /// # Panics
    ///
    /// Panics if `len` exceeds `EMPTY`, which would let a task index collide
    /// with the `EMPTY`/`SLEEPING` sentinels.
    pub(crate) fn resize(&mut self, len: usize) {
        assert!(len <= EMPTY as usize && len <= SLEEPING as usize);

        self.task_count = len;

        // `start..len` is an empty range when shrinking, so this only ever appends.
        let start = self.tasks.len();
        let shared = &self.shared;
        self.tasks.extend((start..len).map(|idx| {
            Arc::new(Task {
                // Cannot truncate: `idx < len <= EMPTY < u32::MAX`.
                idx: idx as u32,
                shared: Arc::clone(shared),
                next: AtomicU32::new(SLEEPING),
            })
        }));
    }

    /// Returns `true` if at least one task is currently on the scheduled list.
    ///
    /// This is a relaxed snapshot and may already be stale when it is used.
    pub(crate) fn has_scheduled(&self) -> bool {
        self.shared.head.load(Ordering::Relaxed) & INDEX_MASK != EMPTY as u64
    }

    /// Returns a waker that schedules the active task at index `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not below the current number of active tasks.
    pub(crate) fn waker_of(&self, idx: usize) -> WakerRef<'_> {
        assert!(idx < self.task_count);

        waker_ref(&self.tasks[idx])
    }
}
/// State shared between a `TaskSet` and all of its tasks.
struct Shared {
    /// Packed list head. The low 32 bits (`INDEX_MASK`) hold the index of the
    /// most recently scheduled task, or `EMPTY`. The high 32 bits
    /// (`COUNTDOWN_MASK`) hold how many more pushes must happen before
    /// `notifier` is notified; zero means no notification is requested.
    head: AtomicU64,
    /// Notified by `Task::wake_by_ref` when the countdown in `head` drops
    /// from one to zero.
    notifier: WakeSource,
}
/// A single member of a `TaskSet`, woken through `ArcWake`.
struct Task {
    /// Index of this task within `TaskSet::tasks`.
    idx: u32,
    /// `SLEEPING` when the task is not scheduled. Once a wake claims the task,
    /// this holds the index of the task below it in the scheduled list, or
    /// `EMPTY` if it is at the bottom.
    next: AtomicU32,
    /// State shared with the owning set and its other tasks.
    shared: Arc<Shared>,
}
impl ArcWake for Task {
    /// Schedules this task.
    ///
    /// If the task is not already scheduled, its index is pushed onto the
    /// lock-free list rooted at `shared.head`. The notification countdown
    /// stored in the head is decremented unless it is already zero. If that
    /// decrement takes the countdown from one to zero, `shared.notifier` is
    /// notified. If the task is already scheduled, nothing changes apart from
    /// memory synchronization.
    fn wake_by_ref(arc_self: &Arc<Self>) {
        let mut next = arc_self.next.load(Ordering::Relaxed);
        // Phase 1: either claim the task (SLEEPING -> tentative link to the
        // current head), or detect that it is already scheduled and return.
        // On a claim, this evaluates to the `head` value the link was taken from.
        let mut head = loop {
            if next == SLEEPING {
                // Not scheduled: link this task above the current top of the
                // list by storing the head's index in `next`. The exchange
                // fails if `next` changed concurrently (or spuriously, since
                // this is the weak variant), in which case we re-examine it.
                let head = arc_self.shared.head.load(Ordering::Relaxed);
                match arc_self.next.compare_exchange_weak(
                    SLEEPING,
                    (head & INDEX_MASK) as u32,
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => break head,
                    Err(n) => next = n,
                }
            } else {
                // Already scheduled: do a value-preserving compare-exchange
                // with Release ordering. The Acquire `swap` in
                // `TaskIterator::next` then synchronizes with this wake. A
                // failure means `next` changed (e.g. it was reset to SLEEPING)
                // or a spurious failure occurred, so re-examine it.
                match arc_self.next.compare_exchange_weak(
                    next,
                    next,
                    Ordering::Release,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(n) => next = n,
                }
            }
        };
        // Phase 2: push this task's index onto the list.
        loop {
            // Decrement the countdown by one unit, saturating at zero.
            let countdown = head & COUNTDOWN_MASK;
            let new_countdown = countdown.wrapping_sub((countdown != 0) as u64 * COUNTDOWN_ONE);
            let new_head = new_countdown | arc_self.idx as u64;
            // Ordering: Release pairs with the Acquire compare-exchange in
            // `TaskSet::take_scheduled`. This publishes this task's `next`
            // link to the thread that detaches the list.
            match arc_self.shared.head.compare_exchange_weak(
                head,
                new_head,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    // This push consumed the last unit of the countdown.
                    if countdown == COUNTDOWN_ONE {
                        arc_self.shared.notifier.notify();
                    }
                    return;
                }
                Err(h) => {
                    head = h;
                    // The head moved: re-link this task above the new top
                    // before retrying. The task is not reachable from `head`
                    // yet, so no iterator can be walking this link.
                    // NOTE(review): a `swap` rather than a plain `store` is
                    // presumably used to keep writes to `next` a chain of
                    // read-modify-writes (release sequence); confirm before
                    // changing it.
                    arc_self
                        .next
                        .swap((head & INDEX_MASK) as u32, Ordering::Relaxed);
                }
            }
        }
    }
}
/// Iterator over the tasks detached by `TaskSet::take_scheduled`, yielding
/// the indices of active tasks. Dropping it returns any unvisited tasks to
/// the sleeping state.
pub(crate) struct TaskIterator<'a> {
    /// The set whose `tasks` the detached list indexes into.
    task_list: &'a TaskSet,
    /// Index of the next task to visit, or `EMPTY` once the list is exhausted.
    next_index: u32,
}
impl Iterator for TaskIterator<'_> {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
while self.next_index != EMPTY {
let index = self.next_index as usize;
self.next_index = self.task_list.tasks[index]
.next
.swap(SLEEPING, Ordering::Acquire);
if index < self.task_list.task_count {
return Some(index);
}
}
None
}
}
impl Drop for TaskIterator<'_> {
fn drop(&mut self) {
while self.next_index != EMPTY {
let index = self.next_index as usize;
self.next_index = self.task_list.tasks[index].next.load(Ordering::Relaxed);
self.task_list.tasks[index]
.next
.store(SLEEPING, Ordering::Relaxed);
}
}
}