use crate::task::AtomicWaker;
use alloc::sync::{Arc, Weak};
use core::cell::UnsafeCell;
use core::fmt::{self, Debug};
use core::iter::FromIterator;
use core::marker::PhantomData;
use core::mem;
use core::pin::Pin;
use core::ptr;
use core::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release, SeqCst};
use core::sync::atomic::{AtomicBool, AtomicPtr};
use futures_core::future::Future;
use futures_core::stream::{FusedStream, Stream};
use futures_core::task::{Context, Poll};
use futures_task::{FutureObj, LocalFutureObj, LocalSpawn, Spawn, SpawnError};
mod abort;
mod iter;
pub use self::iter::{IntoIter, Iter, IterMut, IterPinMut, IterPinRef};
mod task;
use self::task::Task;
mod ready_to_run_queue;
use self::ready_to_run_queue::{Dequeue, ReadyToRunQueue};
/// A set of futures which may complete in any order.
///
/// Each pushed future lives in its own reference-counted `Task`. Every task is
/// threaded onto an intrusive doubly-linked "all tasks" list rooted at
/// `head_all`. Tasks that need polling are additionally placed on the shared
/// ready-to-run queue.
#[must_use = "streams do nothing unless polled"]
pub struct FuturesUnordered<Fut> {
    // Queue of tasks waiting to be polled. Each task keeps a `Weak` reference
    // back to it (see `push`), and `poll_next` registers the caller's waker
    // on it.
    ready_to_run_queue: Arc<ReadyToRunQueue<Fut>>,
    // Head of the all-tasks list; null when the set is empty. The head task's
    // `len_all` holds the current length of the list.
    head_all: AtomicPtr<Task<Fut>>,
    // Set once `poll_next` has returned `Ready(None)`; cleared by `push` and
    // `clear`.
    is_terminated: AtomicBool,
}
// SAFETY: the set owns every `Fut` it contains, so moving the whole set to
// another thread moves those futures with it; this requires `Fut: Send`.
unsafe impl<Fut: Send> Send for FuturesUnordered<Fut> {}

// SAFETY: sharing `&FuturesUnordered` across threads must require BOTH bounds.
// - `Fut: Sync`: `iter_pin_ref` hands out `Pin<&Fut>` through `&self`.
// - `Fut: Send`: `push(&self, fut)` accepts futures through a shared
//   reference. A future created on one thread can therefore end up polled and
//   dropped by whichever thread owns the set.
// With only `Fut: Sync`, a `Sync + !Send` future (e.g. one holding a
// `MutexGuard` across an await) could be moved between threads, which is
// unsound.
unsafe impl<Fut: Send + Sync> Sync for FuturesUnordered<Fut> {}

// Futures are heap-allocated inside `Arc<Task<_>>`. They are only pinned in
// place (`Pin::new_unchecked` in `poll_next`) and dropped in place (set to
// `None`), never moved, so the set itself can be `Unpin` regardless of `Fut`.
impl<Fut> Unpin for FuturesUnordered<Fut> {}
/// Lets a `FuturesUnordered` of boxed `()` futures act as a `Spawn` target.
impl Spawn for FuturesUnordered<FutureObj<'_, ()>> {
    /// Adds `future_obj` to the set. Spawning into the set cannot fail, so
    /// this always returns `Ok(())`.
    fn spawn_obj(&self, future_obj: FutureObj<'static, ()>) -> Result<(), SpawnError> {
        Self::push(self, future_obj);
        Result::Ok(())
    }
}
/// Lets a `FuturesUnordered` of local boxed `()` futures act as a `LocalSpawn`
/// target.
impl LocalSpawn for FuturesUnordered<LocalFutureObj<'_, ()>> {
    /// Adds `future_obj` to the set. Spawning into the set cannot fail, so
    /// this always returns `Ok(())`.
    fn spawn_local_obj(&self, future_obj: LocalFutureObj<'static, ()>) -> Result<(), SpawnError> {
        Self::push(self, future_obj);
        Result::Ok(())
    }
}
impl<Fut> Default for FuturesUnordered<Fut> {
fn default() -> Self {
Self::new()
}
}
impl<Fut> FuturesUnordered<Fut> {
    /// Constructs a new, empty `FuturesUnordered`.
    ///
    /// The ready-to-run queue starts with a single stub task (a task without a
    /// future) as both its head and tail. The stub's address also serves as
    /// the "pending" sentinel returned by `pending_next_all`.
    pub fn new() -> Self {
        let stub = Arc::new(Task {
            future: UnsafeCell::new(None),
            next_all: AtomicPtr::new(ptr::null_mut()),
            prev_all: UnsafeCell::new(ptr::null()),
            len_all: UnsafeCell::new(0),
            next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
            queued: AtomicBool::new(true),
            ready_to_run_queue: Weak::new(),
            woken: AtomicBool::new(false),
        });
        let stub_ptr = Arc::as_ptr(&stub);
        let ready_to_run_queue = Arc::new(ReadyToRunQueue {
            waker: AtomicWaker::new(),
            head: AtomicPtr::new(stub_ptr as *mut _),
            tail: UnsafeCell::new(stub_ptr),
            stub,
        });
        Self {
            head_all: AtomicPtr::new(ptr::null_mut()),
            ready_to_run_queue,
            is_terminated: AtomicBool::new(false),
        }
    }

    /// Returns the number of futures currently in the set.
    ///
    /// The count is read from the head task's `len_all` field.
    pub fn len(&self) -> usize {
        let (_, len) = self.atomic_load_head_and_len_all();
        len
    }

    /// Returns `true` if the set contains no futures.
    pub fn is_empty(&self) -> bool {
        // Relaxed is enough: only the nullness of the pointer is inspected,
        // nothing is read through it.
        self.head_all.load(Relaxed).is_null()
    }

    /// Adds a future to the set.
    ///
    /// The future is not polled here. It is linked into the all-tasks list
    /// and enqueued on the ready-to-run queue, so the next `poll_next` will
    /// poll it. Pushing also clears the terminated flag.
    pub fn push(&self, future: Fut) {
        let task = Arc::new(Task {
            future: UnsafeCell::new(Some(future)),
            next_all: AtomicPtr::new(self.pending_next_all()),
            prev_all: UnsafeCell::new(ptr::null_mut()),
            len_all: UnsafeCell::new(0),
            next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
            queued: AtomicBool::new(true),
            ready_to_run_queue: Arc::downgrade(&self.ready_to_run_queue),
            woken: AtomicBool::new(false),
        });
        self.is_terminated.store(false, Relaxed);
        // The list takes over the `Arc`'s reference count; it is reclaimed
        // later by `unlink`.
        let ptr = self.link(task);
        self.ready_to_run_queue.enqueue(ptr);
    }

    /// Returns an iterator over shared references to the contained futures.
    pub fn iter(&self) -> Iter<'_, Fut>
    where
        Fut: Unpin,
    {
        Iter(Pin::new(self).iter_pin_ref())
    }

    /// Returns an iterator over pinned shared references to the futures.
    ///
    /// The head pointer and length are snapshotted when this is called.
    pub fn iter_pin_ref(self: Pin<&Self>) -> IterPinRef<'_, Fut> {
        let (task, len) = self.atomic_load_head_and_len_all();
        let pending_next_all = self.pending_next_all();
        IterPinRef { task, len, pending_next_all, _marker: PhantomData }
    }

    /// Returns an iterator over mutable references to the contained futures.
    pub fn iter_mut(&mut self) -> IterMut<'_, Fut>
    where
        Fut: Unpin,
    {
        IterMut(Pin::new(self).iter_pin_mut())
    }

    /// Returns an iterator over pinned mutable references to the futures.
    pub fn iter_pin_mut(mut self: Pin<&mut Self>) -> IterPinMut<'_, Fut> {
        // Exclusive access: `head_all` can be read directly, without waiting
        // on `next_all`.
        let task = *self.head_all.get_mut();
        // SAFETY: a non-null head is a live task owned by the all-tasks list.
        let len = if task.is_null() { 0 } else { unsafe { *(*task).len_all.get() } };
        IterPinMut { task, len, _marker: PhantomData }
    }

    /// Loads the head of the all-tasks list and the length stored in it.
    ///
    /// `link` publishes a new head (swap on `head_all`) before it writes that
    /// task's `len_all` and `next_all`. Readers therefore wait, via
    /// `spin_next_all`, until the head's `next_all` no longer holds the
    /// pending sentinel, and only then read `len_all`.
    fn atomic_load_head_and_len_all(&self) -> (*const Task<Fut>, usize) {
        let task = self.head_all.load(Acquire);
        let len = if task.is_null() {
            0
        } else {
            // SAFETY: a non-null head is a live, list-owned task; `len_all`
            // is read only after `next_all` has been published (Release in
            // `link`).
            unsafe {
                (*task).spin_next_all(self.pending_next_all(), Acquire);
                *(*task).len_all.get()
            }
        };
        (task, len)
    }

    /// Drops the future of an already-unlinked task and settles the task's
    /// reference count.
    ///
    /// Setting `queued` to `true` stops wakers from enqueueing the task
    /// again. Two cases follow from the flag's previous value:
    /// - It was already `true`: the task is still in the ready-to-run queue.
    ///   Our reference count is handed to that queue entry, and `poll_next`
    ///   reclaims it when it dequeues a task whose future is `None`.
    /// - It was `false`: our `Arc` is the one to release.
    fn release_task(&mut self, task: Arc<Task<Fut>>) {
        // `release_task` must only be called on unlinked tasks.
        debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
        unsafe {
            debug_assert!((*task.prev_all.get()).is_null());
        }
        let prev = task.queued.swap(true, SeqCst);
        if prev {
            // Hand our reference count to the ready-to-run queue *before* the
            // future's destructor runs. If `Fut::drop` panics, unwinding must
            // not release a count the queue still relies on. Otherwise the
            // queue would later dereference a freed (or doubly released) task.
            let task = mem::ManuallyDrop::new(task);
            // SAFETY: `&mut self` guarantees nobody else touches the future.
            // Assigning `None` drops it in place without moving it.
            unsafe {
                *task.future.get() = None;
            }
        } else {
            // The queue holds no entry for this task, so our `Arc` is released
            // at the end of this scope (also during unwinding, which is
            // correct here).
            // SAFETY: as above.
            unsafe {
                *task.future.get() = None;
            }
        }
    }

    /// Inserts `task` at the head of the all-tasks list.
    ///
    /// The `Arc`'s reference count is transferred into the list (via
    /// `Arc::into_raw`), and the raw task pointer is returned. The task's
    /// `next_all` must still hold the pending sentinel.
    fn link(&self, task: Arc<Task<Fut>>) -> *const Task<Fut> {
        debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
        let ptr = Arc::into_raw(task);
        let next = self.head_all.swap(ptr as *mut _, AcqRel);
        // SAFETY: `ptr` comes from `Arc::into_raw` above. A non-null `next` is
        // a live, list-owned task.
        unsafe {
            let new_len = if next.is_null() {
                1
            } else {
                // Wait until `next`'s own link is published before trusting
                // its `len_all`.
                (*next).spin_next_all(self.pending_next_all(), Acquire);
                *(*next).len_all.get() + 1
            };
            *(*ptr).len_all.get() = new_len;
            // Publishing `next_all` (Release) signals that `len_all` is ready.
            (*ptr).next_all.store(next, Release);
            if !next.is_null() {
                *(*next).prev_all.get() = ptr;
            }
        }
        ptr
    }

    /// Removes `task` from the all-tasks list and returns the list's
    /// reference count as an `Arc`.
    ///
    /// The unlinked task's `next_all` is reset to the pending sentinel and its
    /// `prev_all` to null. The new length is stored in whichever task is the
    /// head afterwards.
    ///
    /// # Safety
    ///
    /// `task` must point to a task currently linked into this set's
    /// all-tasks list (i.e. previously returned by `link` and not unlinked
    /// since). The list must be non-empty.
    unsafe fn unlink(&mut self, task: *const Task<Fut>) -> Arc<Task<Fut>> {
        // Compute the new length up front: if `task` is the head, its
        // `len_all` becomes unreachable once it is unlinked.
        let head = *self.head_all.get_mut();
        debug_assert!(!head.is_null());
        let new_len = *(*head).len_all.get() - 1;
        let task = Arc::from_raw(task);
        let next = task.next_all.load(Relaxed);
        let prev = *task.prev_all.get();
        task.next_all.store(self.pending_next_all(), Relaxed);
        *task.prev_all.get() = ptr::null_mut();
        if !next.is_null() {
            *(*next).prev_all.get() = prev;
        }
        if !prev.is_null() {
            (*prev).next_all.store(next, Relaxed);
        } else {
            *self.head_all.get_mut() = next;
        }
        let head = *self.head_all.get_mut();
        if !head.is_null() {
            *(*head).len_all.get() = new_len;
        }
        task
    }

    /// Sentinel stored in `next_all` of tasks that are not (or not yet
    /// fully) linked. The stub task's address is used because the stub is
    /// never linked into the all-tasks list.
    fn pending_next_all(&self) -> *mut Task<Fut> {
        Arc::as_ptr(&self.ready_to_run_queue.stub) as *mut _
    }
}
impl<Fut: Future> Stream for FuturesUnordered<Fut> {
    type Item = Fut::Output;

    /// Polls queued futures until one completes, the queue runs dry, or a
    /// yield is forced.
    ///
    /// Returns:
    /// - `Ready(Some(output))` as soon as a polled future completes. That
    ///   future is not re-linked, and the `Bomb` guard releases it when it
    ///   goes out of scope.
    /// - `Ready(None)` when the ready-to-run queue is empty and the set holds
    ///   no futures. This also sets the terminated flag.
    /// - `Pending` when the queue is empty but futures remain.
    /// - `Pending`, after waking the caller's waker, when the queue reports
    ///   `Inconsistent`, when as many futures have been polled as the set held
    ///   at entry, or when two polled futures woke themselves during their own
    ///   poll.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Set size at entry; bounds how many polls happen in one call.
        let len = self.len();
        let mut polled = 0;
        // Count of polled futures whose `woken` flag was set during their poll.
        let mut yielded = 0;
        // Register the caller's waker with the ready-to-run queue.
        // NOTE(review): presumably woken when a child task is enqueued — confirm in ready_to_run_queue.
        self.ready_to_run_queue.waker.register(cx.waker());
        loop {
            // SAFETY (review note): `&mut self` provides the exclusive access
            // `dequeue` is assumed to require — confirm against `ReadyToRunQueue::dequeue`.
            let task = match unsafe { self.ready_to_run_queue.dequeue() } {
                Dequeue::Empty => {
                    if self.is_empty() {
                        *self.is_terminated.get_mut() = true;
                        return Poll::Ready(None);
                    } else {
                        return Poll::Pending;
                    }
                }
                Dequeue::Inconsistent => {
                    // The queue could not be read consistently right now
                    // (presumably a concurrent enqueue in progress); ask to be
                    // polled again rather than spinning.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
                Dequeue::Data(task) => task,
            };
            debug_assert!(task != self.ready_to_run_queue.stub());
            let future = match unsafe { &mut *(*task).future.get() } {
                Some(future) => future,
                None => {
                    // The future was already dropped by `release_task` while
                    // the task was still queued. That function forgot its
                    // `Arc`, leaving the count with this queue entry. Reclaim
                    // it here and let it drop.
                    let task = unsafe { Arc::from_raw(task) };
                    debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
                    unsafe {
                        debug_assert!((*task.prev_all.get()).is_null());
                    }
                    continue;
                }
            };
            // Unlink before polling, so that if `poll` panics the task is
            // already out of the all-tasks list and the guard can release it.
            let task = unsafe { self.unlink(task) };
            // Clear `queued` before polling so a wake-up that happens during
            // `poll` can enqueue the task again. (Assumes wakers enqueue only
            // when `queued` was false — TODO confirm in the `task` module.)
            let prev = task.queued.swap(false, SeqCst);
            assert!(prev);
            // Drop guard: if `poll` panics, `drop` calls `release_task`, which
            // drops the future on this thread and settles the reference count.
            struct Bomb<'a, Fut> {
                queue: &'a mut FuturesUnordered<Fut>,
                task: Option<Arc<Task<Fut>>>,
            }
            impl<Fut> Drop for Bomb<'_, Fut> {
                fn drop(&mut self) {
                    if let Some(task) = self.task.take() {
                        self.queue.release_task(task);
                    }
                }
            }
            let mut bomb = Bomb { task: Some(task), queue: &mut *self };
            let res = {
                let task = bomb.task.as_ref().unwrap();
                // Reset so we can detect whether the future wakes itself
                // during this poll.
                task.woken.store(false, Relaxed);
                let waker = Task::waker_ref(task);
                let mut cx = Context::from_waker(&waker);
                // SAFETY: the future lives inside the `Arc<Task<_>>` and is
                // never moved out of it (it is only dropped in place).
                let future = unsafe { Pin::new_unchecked(future) };
                future.poll(&mut cx)
            };
            polled += 1;
            match res {
                Poll::Pending => {
                    // Defuse the guard and put the task back into the list.
                    let task = bomb.task.take().unwrap();
                    yielded += task.woken.load(Relaxed) as usize;
                    bomb.queue.link(task);
                    // Yield to the executor if futures keep waking themselves,
                    // or once the entry-time count of futures has been polled,
                    // so one call cannot loop indefinitely.
                    if yielded >= 2 || polled == len {
                        cx.waker().wake_by_ref();
                        return Poll::Pending;
                    }
                    continue;
                }
                // The guard drops here and releases the completed task.
                Poll::Ready(output) => return Poll::Ready(Some(output)),
            }
        }
    }

    /// Both bounds equal the current number of futures in the set.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.len();
        (len, Some(len))
    }
}
impl<Fut> Debug for FuturesUnordered<Fut> {
    /// Prints an opaque placeholder; the contained futures are not required
    /// to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("FuturesUnordered { ... }")
    }
}
impl<Fut> FuturesUnordered<Fut> {
pub fn clear(&mut self) {
self.clear_head_all();
unsafe { self.ready_to_run_queue.clear() };
self.is_terminated.store(false, Relaxed);
}
fn clear_head_all(&mut self) {
while !self.head_all.get_mut().is_null() {
let head = *self.head_all.get_mut();
let task = unsafe { self.unlink(head) };
self.release_task(task);
}
}
}
impl<Fut> Drop for FuturesUnordered<Fut> {
    /// Unlinks every task and drops its future on the current thread.
    ///
    /// Tasks still sitting in the ready-to-run queue at this point have
    /// already had their futures replaced by `None`.
    fn drop(&mut self) {
        Self::clear_head_all(self);
    }
}
impl<'a, Fut: Unpin> IntoIterator for &'a FuturesUnordered<Fut> {
type Item = &'a Fut;
type IntoIter = Iter<'a, Fut>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, Fut: Unpin> IntoIterator for &'a mut FuturesUnordered<Fut> {
type Item = &'a mut Fut;
type IntoIter = IterMut<'a, Fut>;
fn into_iter(self) -> Self::IntoIter {
self.iter_mut()
}
}
/// Consumes the set, yielding each contained future by value.
impl<Fut: Unpin> IntoIterator for FuturesUnordered<Fut> {
    type Item = Fut;
    type IntoIter = IntoIter<Fut>;

    fn into_iter(mut self) -> Self::IntoIter {
        // We own the set, so `head_all` is read directly with no atomics and
        // no waiting on `next_all`.
        let head = *self.head_all.get_mut();
        // SAFETY: `head` is either null or a live task owned by the list.
        let len = match unsafe { head.as_ref() } {
            Some(task) => unsafe { *task.len_all.get() },
            None => 0,
        };
        IntoIter { inner: self, len }
    }
}
impl<Fut> FromIterator<Fut> for FuturesUnordered<Fut> {
    /// Builds a set by pushing every future from `iter`, in iteration order.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = Fut>,
    {
        let set = Self::new();
        iter.into_iter().for_each(|future| set.push(future));
        set
    }
}
impl<Fut: Future> FusedStream for FuturesUnordered<Fut> {
fn is_terminated(&self) -> bool {
self.is_terminated.load(Relaxed)
}
}
impl<Fut> Extend<Fut> for FuturesUnordered<Fut> {
    /// Pushes every future produced by `iter` into the set, in order.
    fn extend<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = Fut>,
    {
        iter.into_iter().for_each(|future| self.push(future));
    }
}