use std::{
cell::UnsafeCell,
cmp,
fmt::{self, Debug},
future::Future,
iter::FromIterator,
mem,
pin::Pin,
ptr,
sync::{
atomic::{
AtomicBool, AtomicPtr,
Ordering::{AcqRel, Acquire, Relaxed, Release, SeqCst},
},
Arc, Weak,
},
task::{Context, Poll},
};
use super::{atomic_waker::AtomicWaker, Dequeue, ReadyToRunQueue, Task};
/// Upper bound on how many child polls that return `Pending` a single
/// `poll_next` call performs before it wakes itself and yields.
///
/// `poll_next` actually uses `min(len, YIELD_EVERY)`, so a small set's budget
/// is its number of futures. The cap keeps one call from monopolising the
/// executor when children keep being woken (and re-enqueued) but stay pending.
const YIELD_EVERY: usize = 32;
/// A set of futures that may complete in any order.
///
/// Each future lives in its own heap-allocated `Task`. Every task is linked
/// into an intrusive, doubly-linked "all tasks" list rooted at `head_all`.
/// Tasks whose futures need polling are additionally placed on the
/// ready-to-run queue, which `poll_next` drains.
#[must_use = "streams do nothing unless polled"]
pub struct FuturesUnordered<Fut> {
    /// Queue of tasks scheduled for polling. It also owns the stub task and
    /// holds the waker registered by `poll_next`.
    ready_to_run_queue: Arc<ReadyToRunQueue<Fut>>,
    /// Head of the list of every task in the set (null when empty). Only the
    /// head node's `len_all` holds the current list length.
    pub(super) head_all: AtomicPtr<Task<Fut>>,
    /// Set when `poll_next` returns `Ready(None)`; cleared by `push` and `clear`.
    is_terminated: AtomicBool,
}
// SAFETY: the ready-to-run queue stores a raw task pointer in an `UnsafeCell`
// (`tail`), which disables the auto `Send`/`Sync` impls. Stored futures are
// only polled or dropped through `&mut self` (`poll_next`, `release_task`,
// `clear`, `Drop`), so sending the whole set sends its futures with it —
// hence the `Fut: Send` bound. The clippy lint is allowed because this impl
// deliberately overrides the non-`Send` fields.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<Fut: Send> Send for FuturesUnordered<Fut> {}
// SAFETY: shared (`&self`) access never polls or borrows a stored future:
// `len`, `is_empty` and `push` only touch atomics, task links and the queue.
// NOTE(review): `push(&self, ..)` lets another thread move a *new* future into
// the set, and the owning thread later polls and drops it. With only
// `Fut: Sync`, a `Sync` but `!Send` future could cross threads this way;
// consider whether `Fut: Send + Sync` is required here.
unsafe impl<Fut: Sync> Sync for FuturesUnordered<Fut> {}
// Futures are pinned inside their heap-allocated tasks, never inline in the
// set, so moving the set itself never moves a future.
impl<Fut> Unpin for FuturesUnordered<Fut> {}
impl<Fut> Default for FuturesUnordered<Fut> {
fn default() -> Self {
Self::new()
}
}
impl<Fut> FuturesUnordered<Fut> {
    /// Constructs a new, empty set.
    ///
    /// A futureless "stub" task is allocated and handed to the ready-to-run
    /// queue, whose `head` and `tail` both start at the stub. The stub's
    /// address also serves as the sentinel returned by `pending_next_all`.
    pub fn new() -> Self {
        let stub = Arc::new(Task {
            future: UnsafeCell::new(None),
            next_all: AtomicPtr::new(ptr::null_mut()),
            prev_all: UnsafeCell::new(ptr::null()),
            len_all: UnsafeCell::new(0),
            next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
            // NOTE(review): presumably `true` so that wake-ups never try to
            // enqueue the stub — confirm against `Task`'s wake logic.
            queued: AtomicBool::new(true),
            // The stub keeps no back-reference to the queue.
            ready_to_run_queue: Weak::new(),
        });
        let stub_ptr = Arc::as_ptr(&stub);
        let ready_to_run_queue = Arc::new(ReadyToRunQueue {
            waker: AtomicWaker::new(),
            head: AtomicPtr::new(stub_ptr as *mut _),
            tail: UnsafeCell::new(stub_ptr),
            // The queue owns the stub's only strong reference.
            stub,
        });
        Self {
            // The all-tasks list starts empty.
            head_all: AtomicPtr::new(ptr::null_mut()),
            ready_to_run_queue,
            is_terminated: AtomicBool::new(false),
        }
    }

    /// Returns the number of futures currently in the set.
    pub fn len(&self) -> usize {
        let (_, len) = self.atomic_load_head_and_len_all();
        len
    }

    /// Returns `true` if the set contains no futures.
    pub fn is_empty(&self) -> bool {
        // Relaxed is enough: only the pointer's nullness is inspected, the
        // node it points to is never read here.
        self.head_all.load(Relaxed).is_null()
    }

    /// Adds a future to the set.
    ///
    /// The future is not polled here. Its task is enqueued on the
    /// ready-to-run queue so that a later `poll_next` polls it. Takes `&self`;
    /// `link` updates `head_all` with an atomic swap, presumably so pushes
    /// through shared references can race safely.
    pub fn push(&self, future: Fut) {
        let task = Arc::new(Task {
            future: UnsafeCell::new(Some(future)),
            // "Not linked" sentinel; `link` debug-asserts this value.
            next_all: AtomicPtr::new(self.pending_next_all()),
            prev_all: UnsafeCell::new(ptr::null_mut()),
            len_all: UnsafeCell::new(0),
            next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
            // Marked queued because it is enqueued unconditionally below.
            queued: AtomicBool::new(true),
            ready_to_run_queue: Arc::downgrade(&self.ready_to_run_queue),
        });
        // A new future means the set can yield items again.
        self.is_terminated.store(false, Relaxed);
        // `link` turns the task's strong reference into a raw pointer owned
        // by the all-tasks list; `unlink` turns it back into an `Arc`.
        let ptr = self.link(task);
        self.ready_to_run_queue.enqueue(ptr);
    }

    /// Loads the head of the all-tasks list together with the list length
    /// stored in that head node (`0` when the list is empty).
    fn atomic_load_head_and_len_all(&self) -> (*const Task<Fut>, usize) {
        let task = self.head_all.load(Acquire);
        let len = if task.is_null() {
            0
        } else {
            unsafe {
                // `link` publishes a node as head before writing its
                // `len_all`, then stores `next_all` with Release. Waiting for
                // `next_all` to leave the pending sentinel (presumably what
                // `spin_next_all` does) makes the `len_all` read below safe.
                (*task).spin_next_all(self.pending_next_all(), Acquire);
                *(*task).len_all.get()
            }
        };
        (task, len)
    }

    /// Drops the future of an already-unlinked task and gives up this
    /// reference to the task.
    ///
    /// If the task was still in the ready-to-run queue (`queued` was already
    /// `true`), the `Arc` is leaked on purpose with `mem::forget` so the
    /// queue's entry stays valid. `poll_next` reclaims it when it dequeues a
    /// task whose future is `None`. Otherwise the `Arc` is dropped here.
    fn release_task(&mut self, task: Arc<Task<Fut>>) {
        // Only valid for tasks already removed by `unlink`.
        debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
        unsafe {
            debug_assert!((*task.prev_all.get()).is_null());
        }
        // Setting `queued` keeps later wake-ups from enqueueing the task
        // (NOTE(review): assumes `wake` only enqueues on a false -> true flip).
        let prev = task.queued.swap(true, SeqCst);
        unsafe {
            // Assign `None` rather than `take()` so the pinned future is
            // dropped in place instead of being moved out.
            *task.future.get() = None;
        }
        if prev {
            mem::forget(task);
        }
    }

    /// Pushes `task` onto the front of the all-tasks list and returns the raw
    /// pointer that now owns the task's strong reference.
    ///
    /// `task.next_all` must hold the pending sentinel on entry.
    fn link(&self, task: Arc<Task<Fut>>) -> *const Task<Fut> {
        debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
        let ptr = Arc::into_raw(task);
        // Atomically install the new head; the old head becomes `next`.
        let next = self.head_all.swap(ptr as *mut _, AcqRel);
        unsafe {
            let new_len = if next.is_null() {
                1
            } else {
                // The old head may still be mid-`link` on another caller;
                // wait until its `len_all` has been published.
                (*next).spin_next_all(self.pending_next_all(), Acquire);
                *(*next).len_all.get() + 1
            };
            *(*ptr).len_all.get() = new_len;
            // The Release store of `next_all` tells readers waiting in
            // `spin_next_all` that `len_all` above is ready.
            (*ptr).next_all.store(next, Release);
            // `prev_all` is written without synchronisation. In this file it
            // is only read by methods taking `&mut self`.
            if !next.is_null() {
                *(*next).prev_all.get() = ptr;
            }
        }
        ptr
    }

    /// Removes `task` from the all-tasks list and reclaims the strong
    /// reference that `link` stored as a raw pointer.
    ///
    /// On return, `next_all` holds the pending sentinel (checked by `link`)
    /// and `prev_all` is null (checked, together with `next_all`, by
    /// `release_task`).
    ///
    /// # Safety
    ///
    /// `task` must be a pointer returned by `link` that is still in this
    /// set's all-tasks list (so the list is non-empty). It must not be
    /// unlinked twice.
    unsafe fn unlink(&mut self, task: *const Task<Fut>) -> Arc<Task<Fut>> {
        // Read the length from the current head first: if `task` is the
        // head, that cached length would be lost once it is removed.
        let head = *self.head_all.get_mut();
        debug_assert!(!head.is_null());
        let new_len = *(*head).len_all.get() - 1;
        let task = Arc::from_raw(task);
        let next = task.next_all.load(Relaxed);
        let prev = *task.prev_all.get();
        // Reset the node to the "unlinked" state.
        task.next_all.store(self.pending_next_all(), Relaxed);
        *task.prev_all.get() = ptr::null_mut();
        // Splice the neighbours together.
        if !next.is_null() {
            *(*next).prev_all.get() = prev;
        }
        if !prev.is_null() {
            (*prev).next_all.store(next, Relaxed);
        } else {
            // No predecessor: `task` was the head.
            *self.head_all.get_mut() = next;
        }
        // The length is cached only in the (possibly new) head node.
        let head = *self.head_all.get_mut();
        if !head.is_null() {
            *(*head).len_all.get() = new_len;
        }
        task
    }

    /// Sentinel stored in `next_all` while a task is not (fully) linked.
    ///
    /// The stub's address is used because the stub is never passed to `link`
    /// in this file, so it can never be a real list successor. The queue owns
    /// the stub, so the address stays valid for the set's lifetime.
    fn pending_next_all(&self) -> *mut Task<Fut> {
        Arc::as_ptr(&self.ready_to_run_queue.stub) as *mut _
    }
}
impl<Fut: Future> FuturesUnordered<Fut> {
    /// Polls the set for the next completed future.
    ///
    /// Returns:
    /// - `Poll::Ready(Some(output))` as soon as one child future completes.
    ///   That future is then dropped and is no longer in the set.
    /// - `Poll::Ready(None)` when the ready-to-run queue is empty and no
    ///   futures remain; `is_terminated` is set at that point.
    /// - `Poll::Pending` otherwise. `cx`'s waker has been registered with the
    ///   ready-to-run queue first.
    ///
    /// At most `min(len, YIELD_EVERY)` child polls returning `Pending` happen
    /// per call. When that budget is used up, the current task wakes itself
    /// and returns `Pending` so other executor tasks get to run.
    pub fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Fut::Output>> {
        // Budget fixed at entry from the number of linked futures.
        let yield_every = cmp::min(self.len(), YIELD_EVERY);
        let mut polled = 0;
        // Register the outer waker before draining the queue.
        self.ready_to_run_queue.waker.register(cx.waker());
        loop {
            // SAFETY: `Pin<&mut Self>` gives this call exclusive access, so it
            // is the queue's only consumer (presumably what `dequeue`
            // requires — confirm in the parent module).
            let task = match unsafe { self.ready_to_run_queue.dequeue() } {
                Dequeue::Empty => {
                    if self.is_empty() {
                        *self.is_terminated.get_mut() = true;
                        return Poll::Ready(None);
                    } else {
                        return Poll::Pending;
                    }
                }
                Dequeue::Inconsistent => {
                    // The queue reported a transient inconsistent state
                    // (presumably a concurrent enqueue in flight). Instead of
                    // spinning, ask to be polled again and yield.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
                Dequeue::Data(task) => task,
            };
            debug_assert!(task != self.ready_to_run_queue.stub());
            // SAFETY: `task` is a live task from the queue. Only this
            // exclusive consumer accesses its `future` cell.
            let future = match unsafe { &mut *(*task).future.get() } {
                Some(future) => future,
                None => {
                    // `release_task` already dropped this future and forgot
                    // its `Arc` because the task was still queued. Reclaim
                    // that reference here; it is dropped at `continue`.
                    let task = unsafe { Arc::from_raw(task) };
                    // `release_task` is only called on unlinked tasks.
                    debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
                    unsafe {
                        debug_assert!((*task.prev_all.get()).is_null());
                    }
                    continue;
                }
            };
            // Unlink up front. If `poll` panics, the bomb below releases the
            // task; if `poll` returns `Pending`, the task is re-linked.
            let task = unsafe { self.unlink(task) };
            // Clear `queued` *before* polling so a wake-up issued during
            // `poll` can enqueue the task again (presumably `wake` enqueues
            // only when it flips this flag from false — confirm in `Task`).
            let prev = task.queued.swap(false, SeqCst);
            assert!(prev);
            // Drop guard. If `poll` unwinds, `Bomb::drop` calls
            // `release_task`, which drops the future on this thread and
            // settles the task reference. On `Pending` the task is taken back
            // out first; on `Ready` it is released when `bomb` goes out of
            // scope.
            struct Bomb<'a, Fut> {
                queue: &'a mut FuturesUnordered<Fut>,
                task: Option<Arc<Task<Fut>>>,
            }
            impl<Fut> Drop for Bomb<'_, Fut> {
                fn drop(&mut self) {
                    if let Some(task) = self.task.take() {
                        self.queue.release_task(task);
                    }
                }
            }
            let mut bomb = Bomb {
                task: Some(task),
                queue: &mut *self,
            };
            let res = {
                // The child is polled with a waker built from its own task,
                // not with the outer `cx`.
                let waker = Task::waker_ref(bomb.task.as_ref().unwrap());
                let mut cx = Context::from_waker(&waker);
                // SAFETY: the future lives inside its heap-allocated task and
                // is never moved; `release_task` only drops it in place.
                let future = unsafe { Pin::new_unchecked(future) };
                future.poll(&mut cx)
            };
            polled += 1;
            match res {
                Poll::Pending => {
                    // Hand the task back to the all-tasks list.
                    let task = bomb.task.take().unwrap();
                    bomb.queue.link(task);
                    if polled == yield_every {
                        // Budget exhausted: reschedule ourselves and yield.
                        cx.waker().wake_by_ref();
                        return Poll::Pending;
                    }
                    continue;
                }
                // `bomb` still holds the task here, so dropping it calls
                // `release_task`, which drops the completed future.
                Poll::Ready(output) => return Poll::Ready(Some(output)),
            }
        }
    }
}
impl<Fut> Debug for FuturesUnordered<Fut> {
    /// Writes the opaque placeholder `FuturesUnordered { ... }`.
    ///
    /// Stored futures are never inspected, which is why no `Fut: Debug`
    /// bound is needed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("FuturesUnordered { ... }")
    }
}
impl<Fut> FuturesUnordered<Fut> {
#[allow(dead_code)] pub fn clear(&mut self) {
self.clear_head_all();
unsafe { self.ready_to_run_queue.clear() };
self.is_terminated.store(false, Relaxed);
}
fn clear_head_all(&mut self) {
while !self.head_all.get_mut().is_null() {
let head = *self.head_all.get_mut();
let task = unsafe { self.unlink(head) };
self.release_task(task);
}
}
}
impl<Fut> Drop for FuturesUnordered<Fut> {
    /// Drops every future still owned by the set.
    ///
    /// Each linked task goes through `release_task`, which drops its future.
    /// A task that the ready-to-run queue still references has its `Arc`
    /// forgotten rather than freed, so that allocation is not reclaimed here.
    fn drop(&mut self) {
        Self::clear_head_all(self);
    }
}
impl<Fut> FromIterator<Fut> for FuturesUnordered<Fut> {
    /// Builds a set containing every future yielded by `iter`, pushed in
    /// iteration order.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = Fut>,
    {
        let set = Self::new();
        for fut in iter {
            set.push(fut);
        }
        set
    }
}