use std::{
cell::UnsafeCell,
fmt::{self, Debug},
future::Future,
marker::PhantomData,
mem,
pin::Pin,
ptr,
sync::{
Arc, Weak,
atomic::{
AtomicBool, AtomicPtr, AtomicU64,
Ordering::{self, AcqRel, Acquire, Relaxed, Release, SeqCst},
},
},
task::{Context, Poll},
};
use super::{Dequeue, Iter, IterMut, IterPinMut, IterPinRef, ReadyToRunQueue, Task, atomic_waker::AtomicWaker};
/// A set of futures polled via a shared ready-to-run queue.
#[must_use = "streams do nothing unless polled"]
pub struct FuturesUnordered<Fut> {
    // Shared MPSC queue of tasks whose wakers have fired and that should be
    // polled again; also owns the stub sentinel task (see `new`).
    ready_to_run_queue: Arc<ReadyToRunQueue<Fut>>,
    // Head of the intrusive doubly-linked list of *all* live tasks.
    // Null when the set is empty; the current head task caches the list
    // length in its `len_all` field (maintained by `link`/`unlink`).
    pub(super) head_all: AtomicPtr<Task<Fut>>,
    // Fairness round counter: `poll_next` defers any task whose
    // `last_polled` stamp is >= this value; `increment_counter` advances it.
    poll_counter: u64,
}
// SAFETY(review): the raw pointers and `UnsafeCell`s inside are only
// accessed under the synchronization protocols implemented in this module
// (atomic publication in `link`, exclusive `&mut self` in `unlink`, the
// `queued` flag in `release_task`). These bounds mirror the upstream
// futures-rs impls — confirm they still hold if the internals diverge.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<Fut: Send> Send for FuturesUnordered<Fut> {}
unsafe impl<Fut: Sync> Sync for FuturesUnordered<Fut> {}
// The set itself is always Unpin: futures live inside `Arc`'d tasks on the
// heap (see `push`), so moving the set never moves the futures it owns.
impl<Fut> Unpin for FuturesUnordered<Fut> {}
impl<Fut> Default for FuturesUnordered<Fut> {
fn default() -> Self {
Self::new()
}
}
impl<Fut> FuturesUnordered<Fut> {
    /// Constructs a new, empty set.
    ///
    /// The `stub` task never carries a future (`future` is `None`); it is
    /// the initial head/tail node of the ready-to-run MPSC queue, and its
    /// address doubles as the sentinel returned by `pending_next_all`.
    pub fn new() -> Self {
        let stub = Arc::new(Task {
            future: UnsafeCell::new(None),
            next_all: AtomicPtr::new(ptr::null_mut()),
            prev_all: UnsafeCell::new(ptr::null()),
            len_all: UnsafeCell::new(0),
            next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
            queued: AtomicBool::new(true),
            ready_to_run_queue: Weak::new(),
            // Real tasks start at 0 (see `push`); 1 marks the stub, which
            // is never actually polled.
            last_polled: AtomicU64::new(1),
        });
        let stub_ptr = Arc::as_ptr(&stub);
        let ready_to_run_queue = Arc::new(ReadyToRunQueue {
            waker: AtomicWaker::new(),
            head: AtomicPtr::new(stub_ptr as *mut _),
            tail: UnsafeCell::new(stub_ptr),
            stub,
        });
        Self {
            head_all: AtomicPtr::new(ptr::null_mut()),
            ready_to_run_queue,
            poll_counter: 0,
        }
    }

    /// Returns the number of futures currently in the set.
    pub fn len(&self) -> usize {
        let (_, len) = self.atomic_load_head_and_len_all();
        len
    }

    /// Returns `true` if the set contains no futures.
    pub fn is_empty(&self) -> bool {
        // Relaxed suffices: an emptiness check is inherently racy against
        // concurrent `push`, so stronger ordering buys nothing here.
        self.head_all.load(Relaxed).is_null()
    }

    /// Adds a future to the set.
    ///
    /// The new task is linked at the head of the all-tasks list and
    /// immediately enqueued on the ready-to-run queue, so it will be polled
    /// by the next `poll_next` call.
    pub fn push(&self, future: Fut) {
        let task = Arc::new(Task {
            future: UnsafeCell::new(Some(future)),
            // Sentinel meaning "next pointer not yet published"; `link`
            // stores the real value (readers spin on this, see
            // `atomic_load_head_and_len_all`).
            next_all: AtomicPtr::new(self.pending_next_all()),
            prev_all: UnsafeCell::new(ptr::null_mut()),
            len_all: UnsafeCell::new(0),
            next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
            // Born "queued": it is enqueued just below, so a wake arriving
            // before its first poll must not enqueue it a second time.
            queued: AtomicBool::new(true),
            ready_to_run_queue: Arc::downgrade(&self.ready_to_run_queue),
            last_polled: AtomicU64::new(0),
        });
        let ptr = self.link(task);
        self.ready_to_run_queue.enqueue(ptr);
    }

    /// Returns an iterator of shared references to the contained futures.
    pub fn iter(&self) -> Iter<'_, Fut>
    where
        Fut: Unpin,
    {
        Iter(Pin::new(self).iter_pin_ref())
    }

    /// Returns an iterator of pinned shared references to the futures.
    pub fn iter_pin_ref(self: Pin<&Self>) -> IterPinRef<'_, Fut> {
        let (task, len) = self.atomic_load_head_and_len_all();
        let pending_next_all = self.pending_next_all();
        IterPinRef {
            task,
            len,
            pending_next_all,
            _marker: PhantomData,
        }
    }

    /// Returns an iterator of mutable references to the contained futures.
    pub fn iter_mut(&mut self) -> IterMut<'_, Fut>
    where
        Fut: Unpin,
    {
        IterMut(Pin::new(self).iter_pin_mut())
    }

    /// Returns an iterator of pinned mutable references to the futures.
    pub fn iter_pin_mut(mut self: Pin<&mut Self>) -> IterPinMut<'_, Fut> {
        // `&mut self` gives exclusive access, so plain (non-atomic) reads
        // of the head pointer and its cached length are sound here.
        let task = *self.head_all.get_mut();
        let len = if task.is_null() {
            0
        } else {
            unsafe { *(*task).len_all.get() }
        };
        IterPinMut {
            task,
            len,
            _marker: PhantomData,
        }
    }

    /// Atomically loads the head of the all-tasks list together with the
    /// list length cached in the head task's `len_all` field.
    fn atomic_load_head_and_len_all(&self) -> (*const Task<Fut>, usize) {
        let task = self.head_all.load(Acquire);
        let len = if task.is_null() {
            0
        } else {
            unsafe {
                // `link` writes `len_all` *before* publishing `next_all`
                // with Release; spin until that publication is visible so
                // the `len_all` read below cannot race a concurrent link.
                (*task).spin_next_all(self.pending_next_all(), Acquire);
                *(*task).len_all.get()
            }
        };
        (task, len)
    }

    /// Releases a task that has already been unlinked from the all-tasks
    /// list: drops its future and, if a ready-to-run queue entry still
    /// references the task, leaks the `Arc` so that entry stays valid.
    fn release_task(&mut self, task: Arc<Task<Fut>>) {
        // Callers must have unlinked the task first: `next_all` holds the
        // sentinel and `prev_all` is null.
        debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
        unsafe {
            debug_assert!((*task.prev_all.get()).is_null());
        }
        // Setting `queued` prevents any later wake from re-enqueueing the
        // task; `prev` tells us whether it is currently sitting in the
        // ready-to-run queue.
        let prev = task.queued.swap(true, SeqCst);
        unsafe {
            *task.future.get() = None;
        }
        if prev {
            // The ready-to-run queue still holds this task's pointer:
            // forget the Arc so the reference count stays alive until the
            // entry is dequeued and reclaimed via `Arc::from_raw` in
            // `poll_next` (the `future == None` branch).
            mem::forget(task);
        }
    }

    /// Pushes `task` onto the front of the all-tasks list, returning the
    /// raw pointer whose reference count the list now owns.
    fn link(&self, task: Arc<Task<Fut>>) -> *const Task<Fut> {
        // Only tasks not currently linked (next_all == sentinel) may be
        // (re-)linked.
        debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
        let ptr = Arc::into_raw(task);
        let next = self.head_all.swap(ptr as *mut _, AcqRel);
        unsafe {
            // Compute and store the cached length *before* the Release
            // store of `next_all`: concurrent readers spin on `next_all`
            // before trusting `len_all` (see `atomic_load_head_and_len_all`).
            let new_len = if next.is_null() {
                1
            } else {
                (*next).spin_next_all(self.pending_next_all(), Acquire);
                *(*next).len_all.get() + 1
            };
            *(*ptr).len_all.get() = new_len;
            (*ptr).next_all.store(next, Release);
            if !next.is_null() {
                *(*next).prev_all.get() = ptr;
            }
        }
        ptr
    }

    /// Removes `task` from the all-tasks list, returning ownership of its
    /// `Arc` (reclaiming the count leaked by `link`'s `Arc::into_raw`).
    ///
    /// Safety: `task` must currently be linked in this set's list, and
    /// `&mut self` guarantees no concurrent list mutation.
    unsafe fn unlink(&mut self, task: *const Task<Fut>) -> Arc<Task<Fut>> {
        let head = *self.head_all.get_mut();
        debug_assert!(!head.is_null());
        unsafe {
            // The old head caches the pre-removal length.
            let new_len = *(*head).len_all.get() - 1;
            let task = Arc::from_raw(task);
            let next = task.next_all.load(Relaxed);
            let prev = *task.prev_all.get();
            // Reset to the sentinel so the task is recognisably "not
            // linked" (asserted by `release_task` and `link`).
            task.next_all.store(self.pending_next_all(), Relaxed);
            *task.prev_all.get() = ptr::null_mut();
            if !next.is_null() {
                *(*next).prev_all.get() = prev;
            }
            if !prev.is_null() {
                (*prev).next_all.store(next, Relaxed);
            } else {
                // Removing the current head: successor becomes the head.
                *self.head_all.get_mut() = next;
            }
            // Re-cache the decremented length on the (possibly new) head.
            let head = *self.head_all.get_mut();
            if !head.is_null() {
                *(*head).len_all.get() = new_len;
            }
            task
        }
    }

    /// Sentinel pointer (the stub task's address) stored in `next_all` to
    /// mean "next pointer not yet published / task not linked".
    fn pending_next_all(&self) -> *mut Task<Fut> {
        Arc::as_ptr(&self.ready_to_run_queue.stub) as *mut _
    }
}
impl<Fut: Future<Output = ()>> FuturesUnordered<Fut> {
    /// Advances the fairness round counter and returns the new value.
    ///
    /// `poll_next` defers any task whose `last_polled` stamp is >= the
    /// current counter, so callers bump this to begin a round in which
    /// each task may be polled once more.
    pub fn increment_counter(&mut self) -> u64 {
        self.poll_counter += 1;
        self.poll_counter
    }

    /// Drains the ready-to-run queue, polling each dequeued future once
    /// per round.
    ///
    /// Returns `Ready(())` when the set has no futures left, `Pending`
    /// when no further progress can be made right now. Completed futures
    /// are released; still-pending ones are re-linked into the set.
    pub fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // Register before dequeueing so wakes racing with the loop below
        // are not lost.
        self.ready_to_run_queue.waker.register(cx.waker());
        // Identity of the first task deferred this round; dequeueing it a
        // second time means we cycled the whole queue without progress.
        let mut initial_task_id = None;
        loop {
            let task_ptr = match unsafe { self.ready_to_run_queue.dequeue() } {
                Dequeue::Empty => {
                    if self.is_empty() {
                        return Poll::Ready(());
                    } else {
                        // Remaining tasks are all waiting on their wakers.
                        return Poll::Pending;
                    }
                }
                Dequeue::Inconsistent => {
                    // A concurrent enqueue is mid-flight: yield and let the
                    // next poll retry.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
                Dequeue::Data(task) => task,
            };
            debug_assert!(task_ptr != self.ready_to_run_queue.stub());
            let future = match unsafe { &mut *(*task_ptr).future.get() } {
                Some(future) => future,
                None => {
                    // The future was already dropped by `release_task`, which
                    // forgot its Arc so this queue entry stayed valid; reclaim
                    // (and drop) that reference count now.
                    let task = unsafe { Arc::from_raw(task_ptr) };
                    debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
                    unsafe {
                        debug_assert!((*task.prev_all.get()).is_null());
                    }
                    continue;
                }
            };
            // Temporarily remove the task from the all-tasks list while it
            // is being polled.
            let task = unsafe { self.unlink(task_ptr) };
            // The task came off the ready queue, so `queued` must be set.
            let prev = task.queued.swap(false, SeqCst);
            assert!(prev);
            if task.last_polled.load(Ordering::Relaxed) >= self.poll_counter {
                // Already polled in this round: mark queued again, put it
                // back in both the list and the ready queue, and defer it.
                task.queued.swap(true, SeqCst);
                let task_ptr = self.link(task);
                self.ready_to_run_queue.enqueue(task_ptr);
                match initial_task_id {
                    Some(iti) => {
                        if iti == task_ptr as usize {
                            // Full cycle, nothing pollable: stall until the
                            // counter is incremented.
                            return Poll::Pending;
                        } else {
                            continue;
                        }
                    }
                    None => {
                        initial_task_id = Some(task_ptr as usize);
                        continue;
                    }
                }
            }
            // Stamp the task as polled in the current round.
            task.last_polled.store(self.poll_counter, Ordering::Relaxed);
            // Panic guard: if polling the future unwinds, `release_task`
            // still runs so the future is dropped and nothing leaks.
            struct Bomb<'a, Fut> {
                queue: &'a mut FuturesUnordered<Fut>,
                task: Option<Arc<Task<Fut>>>,
            }
            impl<Fut> Drop for Bomb<'_, Fut> {
                fn drop(&mut self) {
                    if let Some(task) = self.task.take() {
                        self.queue.release_task(task);
                    }
                }
            }
            let mut bomb = Bomb {
                task: Some(task),
                queue: &mut *self,
            };
            let res = {
                // Poll with the task's own waker so a wake re-enqueues
                // exactly this task on the ready-to-run queue.
                let waker = Task::waker_ref(bomb.task.as_ref().unwrap());
                let mut cx = Context::from_waker(&waker);
                // SAFETY(review): relies on the future never being moved
                // out of its heap-allocated task until dropped in place —
                // this module only ever accesses it through the Arc'd Task.
                let future = unsafe { Pin::new_unchecked(future) };
                future.poll(&mut cx)
            };
            match res {
                Poll::Pending => {
                    // Not done: disarm the bomb and re-link the task so it
                    // remains in the set for future wakes.
                    let task = bomb.task.take().unwrap();
                    let task_ptr = bomb.queue.link(task);
                    if initial_task_id.is_none() {
                        initial_task_id = Some(task_ptr as usize);
                    }
                }
                // Completed: the bomb drops right here and releases the task.
                Poll::Ready(()) => {}
            };
        }
    }
}
impl<Fut> Debug for FuturesUnordered<Fut> {
    /// Writes an opaque placeholder; the contained futures are not
    /// enumerated (and `Fut` need not be `Debug`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("FuturesUnordered { ... }")
    }
}
impl<Fut> FuturesUnordered<Fut> {
    /// Removes and drops every future, leaving the set empty but reusable.
    #[allow(dead_code)] pub fn clear(&mut self) {
        self.clear_head_all();
        // SAFETY(review): `&mut self` gives exclusive access; otherwise
        // this relies on `ReadyToRunQueue::clear`'s own contract, which is
        // defined elsewhere — confirm it tolerates being called here.
        unsafe { self.ready_to_run_queue.clear() };
    }

    /// Unlinks and releases tasks from the head of the all-tasks list
    /// until the list is empty.
    fn clear_head_all(&mut self) {
        while !self.head_all.get_mut().is_null() {
            let head = *self.head_all.get_mut();
            // SAFETY: `head` is non-null (loop condition) and is linked in
            // this set's list, satisfying `unlink`'s contract.
            let task = unsafe { self.unlink(head) };
            self.release_task(task);
        }
    }
}
impl<Fut> Drop for FuturesUnordered<Fut> {
    /// Releases every remaining task (dropping its future).
    ///
    /// NOTE(review): references `release_task` forgot for tasks still in
    /// the ready-to-run queue are presumably reclaimed when
    /// `ready_to_run_queue` itself drops — its Drop impl is not visible
    /// here; confirm in the sibling module.
    fn drop(&mut self) {
        self.clear_head_all();
    }
}