extern crate alloc;
use alloc::sync::{Arc, Weak};
use core::cell::UnsafeCell;
use core::fmt::{self, Debug};
use core::hash::Hash;
use core::marker::PhantomData;
use core::mem;
use core::pin::Pin;
use core::ptr;
use core::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release, SeqCst};
use core::sync::atomic::{AtomicBool, AtomicPtr};
use futures_core::future::Future;
use futures_core::stream::{FusedStream, Stream};
use futures_core::task::{Context, Poll};
use futures_util::task::AtomicWaker;
use std::collections::hash_map::RandomState;
use std::collections::HashSet;
use std::hash::BuildHasher;
use std::ops::{Deref, DerefMut};
mod abort;
mod iter;
use self::iter::Keys;
#[allow(unreachable_pub)] pub use self::iter::{IntoIter, Iter, IterMut, IterPinMut, IterPinRef};
mod bi_multi_map_futures;
pub use self::bi_multi_map_futures::BiMultiMapFutures;
mod mapped_streams;
pub use self::mapped_streams::{map_all, MappedStreams};
mod bi_multi_map_streams;
pub use self::bi_multi_map_streams::BiMultiMapStreams;
mod task;
use self::task::{HashTask, Task};
mod ready_to_run_queue;
use self::ready_to_run_queue::{Dequeue, ReadyToRunQueue};
/// A set of futures, each stored under a unique key, that yields
/// `(key, output)` pairs as its futures complete.
///
/// Futures can be looked up, mutated, cancelled or removed by key while they
/// are still pending.
#[must_use = "streams do nothing unless polled"]
pub struct MappedFutures<K: Hash + Eq, Fut, S: BuildHasher = RandomState> {
    // Key index: holds one entry per task currently linked into the set.
    hash_set: HashSet<HashTask<K, Fut>, S>,
    // Queue of tasks that have been scheduled for polling.
    ready_to_run_queue: Arc<ReadyToRunQueue<K, Fut>>,
    // Head of the intrusive list of all linked tasks; the list owns one strong
    // reference per task (taken with `Arc::into_raw` in `link`).
    head_all: AtomicPtr<Task<K, Fut>>,
    // Set once `poll_next` has returned `Ready(None)`.
    is_terminated: AtomicBool,
}
// SAFETY: futures are only polled or dropped through `&mut self`, so moving the
// set moves them too (`Fut: Send`). Keys are owned by the set's tasks and the
// hasher is owned by the `HashSet`, so both move along with the set and must be
// `Send` as well.
unsafe impl<K: Hash + Eq + Send, Fut: Send, S: BuildHasher + Send> Send
    for MappedFutures<K, Fut, S>
{
}
// SAFETY: `&self` only hands out shared access: `keys` yields `&K`, `hasher`
// yields `&S`, and `iter_pin_ref` yields pinned shared references to futures.
// Each of those therefore requires the corresponding `Sync` bound.
unsafe impl<K: Hash + Eq + Sync, Fut: Sync, S: BuildHasher + Sync> Sync
    for MappedFutures<K, Fut, S>
{
}
// Futures live inside heap-allocated tasks, so moving the set never moves them.
impl<K: Hash + Eq, Fut, S: BuildHasher> Unpin for MappedFutures<K, Fut, S> {}
impl<K: Hash + Eq, Fut> Default for MappedFutures<K, Fut, RandomState> {
fn default() -> Self {
Self::new()
}
}
impl<K: Hash + Eq, Fut> MappedFutures<K, Fut, RandomState> {
pub fn new() -> MappedFutures<K, Fut, RandomState> {
let stub = Arc::new(Task {
future: UnsafeCell::new(None),
next_all: AtomicPtr::new(ptr::null_mut()),
prev_all: UnsafeCell::new(ptr::null()),
len_all: UnsafeCell::new(0),
next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
queued: AtomicBool::new(true),
ready_to_run_queue: Weak::new(),
woken: AtomicBool::new(false),
key: UnsafeCell::new(None),
});
let stub_ptr = Arc::as_ptr(&stub);
let ready_to_run_queue = Arc::new(ReadyToRunQueue {
waker: AtomicWaker::new(),
head: AtomicPtr::new(stub_ptr as *mut _),
tail: UnsafeCell::new(stub_ptr),
stub,
});
Self {
hash_set: HashSet::new(),
head_all: AtomicPtr::new(ptr::null_mut()),
ready_to_run_queue,
is_terminated: AtomicBool::new(false),
}
}
}
impl<K: Hash + Eq, Fut, S: BuildHasher> MappedFutures<K, Fut, S> {
pub fn len(&self) -> usize {
let (_, len) = self.atomic_load_head_and_len_all();
len
}
pub fn is_empty(&self) -> bool {
self.head_all.load(Relaxed).is_null()
}
pub fn with_hasher(hasher: S) -> MappedFutures<K, Fut, S> {
let stub = Arc::new(Task {
future: UnsafeCell::new(None),
next_all: AtomicPtr::new(ptr::null_mut()),
prev_all: UnsafeCell::new(ptr::null()),
len_all: UnsafeCell::new(0),
next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
queued: AtomicBool::new(true),
ready_to_run_queue: Weak::new(),
woken: AtomicBool::new(false),
key: UnsafeCell::new(None),
});
let stub_ptr = Arc::as_ptr(&stub);
let ready_to_run_queue = Arc::new(ReadyToRunQueue {
waker: AtomicWaker::new(),
head: AtomicPtr::new(stub_ptr as *mut _),
tail: UnsafeCell::new(stub_ptr),
stub,
});
Self {
hash_set: HashSet::with_hasher(hasher),
head_all: AtomicPtr::new(ptr::null_mut()),
ready_to_run_queue,
is_terminated: AtomicBool::new(false),
}
}
pub fn hasher(&self) -> &S {
self.hash_set.hasher()
}
pub fn insert(&mut self, key: K, future: Fut) -> bool {
let replacing = self.cancel(&key);
let task = Arc::new(Task {
future: UnsafeCell::new(Some(future)),
next_all: AtomicPtr::new(self.pending_next_all()),
prev_all: UnsafeCell::new(ptr::null_mut()),
len_all: UnsafeCell::new(0),
next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
queued: AtomicBool::new(true),
ready_to_run_queue: Arc::downgrade(&self.ready_to_run_queue),
woken: AtomicBool::new(false),
key: UnsafeCell::new(Some(key)),
});
self.is_terminated.store(false, Relaxed);
let ptr = self.link(task);
self.ready_to_run_queue.enqueue(ptr);
!replacing
}
pub fn replace(&mut self, key: K, future: Fut) -> Option<Fut>
where
Fut: Unpin,
{
let replacing = self.remove(&key);
let task = Arc::new(Task {
future: UnsafeCell::new(Some(future)),
next_all: AtomicPtr::new(self.pending_next_all()),
prev_all: UnsafeCell::new(ptr::null_mut()),
len_all: UnsafeCell::new(0),
next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
queued: AtomicBool::new(true),
ready_to_run_queue: Arc::downgrade(&self.ready_to_run_queue),
woken: AtomicBool::new(false),
key: UnsafeCell::new(Some(key)),
});
self.is_terminated.store(false, Relaxed);
let ptr = self.link(task);
self.ready_to_run_queue.enqueue(ptr);
replacing
}
pub fn cancel(&mut self, key: &K) -> bool {
if let Some(task) = self.hash_set.get(key) {
unsafe {
if (*task.future.get()).is_some() {
let task_clone = task.inner.clone();
self.unlink(Arc::as_ptr(&task.inner));
self.release_task(task_clone);
return true;
}
}
}
false
}
pub fn remove(&mut self, key: &K) -> Option<Fut>
where
Fut: Unpin,
{
if let Some(task) = self.hash_set.get(key) {
unsafe {
let fut = (*task.future.get()).take().unwrap();
let task_clone = task.inner.clone();
self.unlink(Arc::as_ptr(&task.inner));
self.release_task(task_clone);
return Some(fut);
}
}
None
}
pub fn contains(&mut self, key: &K) -> bool {
self.hash_set.contains(key)
}
pub fn get_pin_mut<'a>(&'a mut self, key: &K) -> Option<Pin<FutMut<'a, K, Fut>>> {
if let Some(task_ref) = self.hash_set.get(key) {
unsafe {
return Some(Pin::new_unchecked(FutMut::new(&task_ref.inner)));
}
}
None
}
pub fn get_mut<'a>(&mut self, key: &K) -> Option<FutMut<K, Fut>>
where
Fut: Unpin,
{
if let Some(task_ref) = self.hash_set.get(key) {
return Some(FutMut::new(&task_ref.inner));
}
None
}
pub fn get<'a>(&mut self, key: &K) -> Option<&'a Fut>
where
Fut: Unpin,
{
if let Some(task_ref) = self.hash_set.get(key) {
unsafe {
if let Some(fut) = &*task_ref.inner.future.get() {
return Some(fut);
}
}
}
None
}
pub fn get_pin(&mut self, key: &K) -> Option<Pin<&Fut>> {
if let Some(task_ref) = self.hash_set.get(key) {
unsafe {
if let Some(fut) = &*task_ref.future.get() {
return Some(Pin::new_unchecked(fut));
}
}
}
None
}
pub fn keys(&self) -> Keys<'_, K, Fut> {
Keys {
inner: self.hash_set.iter(),
}
}
pub fn iter(&self) -> Iter<'_, K, Fut, S>
where
Fut: Unpin,
{
Iter(Pin::new(self).iter_pin_ref())
}
pub fn iter_pin_ref(self: Pin<&Self>) -> IterPinRef<'_, K, Fut, S> {
let (task, len) = self.atomic_load_head_and_len_all();
let pending_next_all = self.pending_next_all();
IterPinRef {
task,
len,
pending_next_all,
_marker: PhantomData,
}
}
pub fn iter_mut(&mut self) -> IterMut<'_, K, Fut, S>
where
Fut: Unpin,
{
IterMut(Pin::new(self).iter_pin_mut())
}
pub fn iter_pin_mut(mut self: Pin<&mut Self>) -> IterPinMut<'_, K, Fut, S> {
let task = *self.head_all.get_mut();
let len = if task.is_null() {
0
} else {
unsafe { *(*task).len_all.get() }
};
IterPinMut {
task,
len,
_marker: PhantomData,
}
}
fn atomic_load_head_and_len_all(&self) -> (*const Task<K, Fut>, usize) {
let task = self.head_all.load(Acquire);
let len = if task.is_null() {
0
} else {
unsafe {
(*task).spin_next_all(self.pending_next_all(), Acquire);
*(*task).len_all.get()
}
};
(task, len)
}
fn release_task(&mut self, task: Arc<Task<K, Fut>>) {
debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
unsafe {
debug_assert!((*task.prev_all.get()).is_null());
}
let prev = task.queued.swap(true, SeqCst);
unsafe {
*task.future.get() = None;
let key = &*task.key.get();
if let Some(key) = key {
self.hash_set.remove(key);
}
}
if prev {
mem::forget(task);
}
}
fn link(&mut self, task: Arc<Task<K, Fut>>) -> *const Task<K, Fut> {
debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
let hash_task = HashTask {
inner: task.clone(),
};
self.hash_set.insert(hash_task);
let ptr = Arc::into_raw(task);
let next = self.head_all.swap(ptr as *mut _, AcqRel);
unsafe {
let new_len = if next.is_null() {
1
} else {
(*next).spin_next_all(self.pending_next_all(), Acquire);
*(*next).len_all.get() + 1
};
*(*ptr).len_all.get() = new_len;
(*ptr).next_all.store(next, Release);
if !next.is_null() {
*(*next).prev_all.get() = ptr;
}
}
ptr
}
unsafe fn unlink(&mut self, task: *const Task<K, Fut>) -> Arc<Task<K, Fut>> {
let head = *self.head_all.get_mut();
debug_assert!(!head.is_null());
let new_len = *(*head).len_all.get() - 1;
if let Some(key) = (*task).key() {
self.hash_set.remove(key);
}
let task = Arc::from_raw(task);
let next = task.next_all.load(Relaxed);
let prev = *task.prev_all.get();
task.next_all.store(self.pending_next_all(), Relaxed);
*task.prev_all.get() = ptr::null_mut();
if !next.is_null() {
*(*next).prev_all.get() = prev;
}
if !prev.is_null() {
(*prev).next_all.store(next, Relaxed);
} else {
*self.head_all.get_mut() = next;
}
let head = *self.head_all.get_mut();
if !head.is_null() {
*(*head).len_all.get() = new_len;
}
task
}
fn pending_next_all(&self) -> *mut Task<K, Fut> {
Arc::as_ptr(&self.ready_to_run_queue.stub) as *mut _
}
}
impl<K: Hash + Eq, Fut: Future, S: BuildHasher> Stream for MappedFutures<K, Fut, S> {
    /// Each item pairs the key a future was inserted under with its output.
    type Item = (K, Fut::Output);

    /// Polls scheduled futures until one completes, the set is exhausted, or
    /// the loop decides to yield back to the executor.
    ///
    /// Returns `Ready(Some((key, output)))` for a completed future, which is
    /// removed from the set. Returns `Ready(None)` once the set is empty, and
    /// `Pending` otherwise.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Task count at entry; limits how many polls happen in one call.
        let len = self.len();
        let mut polled = 0;
        // Number of polled futures whose `woken` flag was set during their poll.
        let mut yielded = 0;
        // Register the caller's waker with the ready-to-run queue.
        self.ready_to_run_queue.waker.register(cx.waker());
        loop {
            let task = match unsafe { self.ready_to_run_queue.dequeue() } {
                Dequeue::Empty => {
                    if self.is_empty() {
                        // Only mark as terminated when `None` is actually yielded.
                        *self.is_terminated.get_mut() = true;
                        return Poll::Ready(None);
                    }
                    return Poll::Pending;
                }
                Dequeue::Inconsistent => {
                    // Presumably a concurrent enqueue is mid-update; ask to be
                    // polled again rather than spinning here.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
                Dequeue::Data(task) => task,
            };
            debug_assert!(task != self.ready_to_run_queue.stub());
            let future = match unsafe { &mut *(*task).future.get() } {
                Some(future) => future,
                None => {
                    // `release_task` already ran for this task while it was still
                    // queued and left its reference to the queue. Reclaim that
                    // reference so it is dropped here.
                    let task = unsafe { Arc::from_raw(task) };
                    debug_assert_eq!(task.next_all.load(Relaxed), self.pending_next_all());
                    unsafe {
                        debug_assert!((*task.prev_all.get()).is_null());
                    }
                    continue;
                }
            };
            // Take the task out of the all-tasks list (and the key index) while
            // it is polled; it is relinked below if it stays pending.
            let task = unsafe { self.unlink(task) };
            // Clear `queued` before polling, so a wake-up during the poll can
            // schedule the task again.
            let prev = task.queued.swap(false, SeqCst);
            assert!(prev);
            // Guard: if `poll` unwinds, the task is released (its future is
            // dropped) instead of leaked.
            struct Bomb<'a, K: Hash + Eq, Fut, S: BuildHasher> {
                queue: &'a mut MappedFutures<K, Fut, S>,
                task: Option<Arc<Task<K, Fut>>>,
            }
            impl<K: Hash + Eq, Fut, S: BuildHasher> Drop for Bomb<'_, K, Fut, S> {
                fn drop(&mut self) {
                    if let Some(task) = self.task.take() {
                        self.queue.release_task(task);
                    }
                }
            }
            let mut bomb = Bomb {
                task: Some(task),
                queue: &mut *self,
            };
            let res = {
                let task = bomb.task.as_ref().unwrap();
                // Reset so that afterwards we can tell whether the task was woken
                // during this poll.
                task.woken.store(false, Relaxed);
                // Per-task waker; presumably waking it re-enqueues this task.
                let waker = unsafe { Task::waker_ref(task) };
                let mut cx = Context::from_waker(&waker);
                // The future stays in place inside its task, so pinning is sound.
                let future = unsafe { Pin::new_unchecked(future) };
                future.poll(&mut cx)
            };
            polled += 1;
            match res {
                Poll::Pending => {
                    let task = bomb.task.take().unwrap();
                    // A task woken during its own poll is treated as wanting to
                    // yield.
                    yielded += task.woken.load(Relaxed) as usize;
                    bomb.queue.link(task);
                    // Yield to the executor if futures keep waking themselves, or
                    // after as many polls as there were tasks at entry, so other
                    // work is not starved.
                    if yielded >= 2 || polled == len {
                        cx.waker().wake_by_ref();
                        return Poll::Pending;
                    }
                    continue;
                }
                Poll::Ready(output) => {
                    // Hand the key back with the output; the bomb then releases
                    // the finished task when it goes out of scope.
                    return Poll::Ready(Some((bomb.task.as_ref().unwrap().take_key(), output)));
                }
            }
        }
    }

    /// Reports the exact number of futures still in the set.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.len();
        (len, Some(len))
    }
}
impl<K: Hash + Eq, Fut, S: BuildHasher> Debug for MappedFutures<K, Fut, S> {
    /// Prints an opaque placeholder, since neither keys nor futures are
    /// required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("MappedFutures { ... }")
    }
}
impl<K: Hash + Eq, Fut, S: BuildHasher> MappedFutures<K, Fut, S> {
    /// Drops every stored future and resets the set to an empty,
    /// non-terminated state.
    pub fn clear(&mut self) {
        self.clear_head_all();
        self.hash_set.clear();
        // SAFETY: all tasks were just unlinked and released, and `&mut self`
        // rules out a concurrent `dequeue` from this side.
        unsafe { self.ready_to_run_queue.clear() };
        self.is_terminated.store(false, Relaxed);
    }

    /// Repeatedly unlinks and releases the head task until the all-tasks list
    /// is empty.
    fn clear_head_all(&mut self) {
        loop {
            let head = *self.head_all.get_mut();
            if head.is_null() {
                break;
            }
            // SAFETY: `head` is non-null, so it is the currently linked head task.
            let task = unsafe { self.unlink(head) };
            self.release_task(task);
        }
    }
}
impl<K: Hash + Eq, Fut, S: BuildHasher> Drop for MappedFutures<K, Fut, S> {
fn drop(&mut self) {
self.clear_head_all();
}
}
impl<'a, K: Hash + Eq, Fut: Unpin, S: BuildHasher> IntoIterator for &'a MappedFutures<K, Fut, S> {
type Item = &'a Fut;
type IntoIter = Iter<'a, K, Fut, S>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, K: Hash + Eq, Fut: Unpin, S: BuildHasher> IntoIterator
for &'a mut MappedFutures<K, Fut, S>
{
type Item = &'a mut Fut;
type IntoIter = IterMut<'a, K, Fut, S>;
fn into_iter(self) -> Self::IntoIter {
self.iter_mut()
}
}
impl<K: Hash + Eq, Fut: Unpin, S: BuildHasher> IntoIterator for MappedFutures<K, Fut, S> {
    type Item = Fut;
    type IntoIter = IntoIter<K, Fut, S>;

    /// Consumes the set, yielding the stored futures by value.
    fn into_iter(mut self) -> Self::IntoIter {
        let head = *self.head_all.get_mut();
        // SAFETY: a non-null head is a live, linked task whose `len_all` holds
        // the list length, and the set is owned here, so nothing races with us.
        let len = unsafe { head.as_ref() }.map_or(0, |task| unsafe { *task.len_all.get() });
        IntoIter { len, inner: self }
    }
}
impl<K: Hash + Eq, Fut> FromIterator<(K, Fut)> for MappedFutures<K, Fut, RandomState> {
    /// Builds a set by inserting each `(key, future)` pair in order; later
    /// pairs replace earlier ones with the same key.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = (K, Fut)>,
    {
        let mut futures = Self::new();
        futures.extend(iter);
        futures
    }
}
impl<K: Hash + Eq, Fut: Future, S: BuildHasher> FusedStream for MappedFutures<K, Fut, S> {
    /// `true` once `poll_next` has returned `Ready(None)`. The flag is reset
    /// when a future is inserted or the set is cleared.
    fn is_terminated(&self) -> bool {
        let terminated = &self.is_terminated;
        terminated.load(Relaxed)
    }
}
impl<K: Hash + Eq, Fut, S: BuildHasher> Extend<(K, Fut)> for MappedFutures<K, Fut, S> {
    /// Inserts every `(key, future)` pair in order, replacing any future
    /// already stored under the same key.
    fn extend<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = (K, Fut)>,
    {
        iter.into_iter().for_each(|(key, future)| {
            self.insert(key, future);
        });
    }
}
/// Handle to a single future stored in a [`MappedFutures`].
///
/// Dereferences to the stored future. If it is ever mutably dereferenced, the
/// task is woken when the handle is dropped so the future gets polled again.
pub struct FutMut<'a, K: Hash + Eq, Fut> {
    // Task whose future this handle exposes.
    inner: *const Task<K, Fut>,
    // Set by `deref_mut`; triggers a wake-up on drop.
    mutated: bool,
    // Ties the handle to a mutable borrow of lifetime `'a`.
    _marker: PhantomData<&'a mut Task<K, Fut>>,
}

impl<K: Hash + Eq, Fut> Deref for FutMut<'_, K, Fut> {
    type Target = Fut;

    fn deref(&self) -> &Self::Target {
        // SAFETY: the handle is bound to lifetime `'a` of a borrow of the owning
        // set, which keeps the task alive and its future in place.
        let slot = unsafe { &*(*self.inner).future.get() };
        slot.as_ref().unwrap()
    }
}

impl<K: Hash + Eq, Fut> DerefMut for FutMut<'_, K, Fut> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Remember the mutation so the task is rescheduled when we are dropped.
        self.mutated = true;
        // SAFETY: as for `deref`; `&mut self` makes this the only access.
        let slot = unsafe { &mut *(*self.inner).future.get() };
        slot.as_mut().unwrap()
    }
}

impl<K: Hash + Eq, Fut> Drop for FutMut<'_, K, Fut> {
    fn drop(&mut self) {
        if !self.mutated {
            return;
        }
        Task::wake_by_ptr(self.inner);
    }
}

impl<'a, K: Hash + Eq, Fut> FutMut<'a, K, Fut> {
    /// Creates an unmutated handle for the task behind `task`.
    fn new(task: &'a Arc<Task<K, Fut>>) -> Self {
        Self::new_from_ptr(Arc::as_ptr(task))
    }

    /// Creates an unmutated handle from a raw task pointer.
    fn new_from_ptr(task: *const Task<K, Fut>) -> Self {
        FutMut {
            inner: task,
            mutated: false,
            _marker: PhantomData,
        }
    }
}
#[cfg(test)]
pub mod tests {
    use crate::*;
    use futures::executor::block_on;
    use futures::future::LocalBoxFuture;
    use futures_timer::Delay;
    use futures_util::StreamExt;
    use std::time::Duration;

    /// Inserts a `Delay` of `millis` ms under `key` and returns `insert`'s result.
    fn insert_millis(futs: &mut MappedFutures<u32, Delay>, key: u32, millis: u64) -> bool {
        let delay = Delay::new(Duration::from_millis(millis));
        futs.insert(key, delay)
    }

    /// Same as `insert_millis`, but stores the delay as a boxed, pinned future.
    fn insert_millis_pinned(
        futs: &mut MappedFutures<u32, LocalBoxFuture<'static, ()>>,
        key: u32,
        millis: u64,
    ) {
        let delay: LocalBoxFuture<'static, ()> =
            Box::pin(Delay::new(Duration::from_millis(millis)));
        futs.insert(key, delay);
    }

    /// Blocks until the next future completes and returns its key.
    fn next_key<F: Future>(futs: &mut MappedFutures<u32, F>) -> u32 {
        block_on(futs.next()).unwrap().0
    }

    #[test]
    fn map_futures() {
        let mut futures = MappedFutures::new();
        for (key, millis) in [(1, 50), (2, 50), (3, 150), (4, 200)] {
            insert_millis(&mut futures, key, millis);
        }
        assert_eq!(next_key(&mut futures), 1);
        assert!(futures.cancel(&3));
        assert_eq!(next_key(&mut futures), 2);
        assert_eq!(next_key(&mut futures), 4);
        assert_eq!(block_on(futures.next()), None);
        assert!(futures.is_empty());
    }

    #[test]
    fn add_duplicate() {
        let mut futures = MappedFutures::new();
        assert!(insert_millis(&mut futures, 1, 50));
        assert!(!insert_millis(&mut futures, 1, 50));
    }

    #[test]
    fn remove_unpinned() {
        let mut futures = MappedFutures::new();
        insert_millis(&mut futures, 1, 50);
        let removed = futures.remove(&1).unwrap();
        let () = block_on(removed);
    }

    #[test]
    fn remove_pinned() {
        let mut futures = MappedFutures::new();
        for (key, millis) in [(1, 50), (3, 150), (4, 200)] {
            insert_millis_pinned(&mut futures, key, millis);
        }
        assert_eq!(next_key(&mut futures), 1);
        let () = block_on(futures.remove(&3).unwrap());
        insert_millis_pinned(&mut futures, 2, 60);
        assert_eq!(next_key(&mut futures), 4);
        assert_eq!(next_key(&mut futures), 2);
        assert_eq!(block_on(futures.next()), None);
    }

    #[test]
    fn mutate() {
        let mut futures = MappedFutures::new();
        for (key, millis) in [(1, 500), (2, 1000), (3, 1500), (4, 2000)] {
            insert_millis(&mut futures, key, millis);
        }
        assert_eq!(next_key(&mut futures), 1);
        futures.get_mut(&3).unwrap().reset(Duration::from_millis(300));
        for expected in [3, 2, 4] {
            assert_eq!(next_key(&mut futures), expected);
        }
        assert_eq!(block_on(futures.next()), None);
    }

    /// Polls the inner set once, flips future `1` to ready through `get_mut`,
    /// and resolves to whether the following poll yields an item.
    struct WakeMutateTest {
        inner: MappedFutures<u32, BoolFut>,
    }

    impl Future for WakeMutateTest {
        type Output = bool;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.inner.poll_next_unpin(cx).is_pending() {
                self.inner.get_mut(&1).unwrap().0 = true;
                return Poll::Ready(self.inner.poll_next_unpin(cx).is_ready());
            }
            Poll::Ready(false)
        }
    }

    /// Future that is ready exactly when its flag is `true`.
    struct BoolFut(bool);

    impl Future for BoolFut {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.0 {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        }
    }

    #[test]
    fn wake_mutate() {
        let mut inner = MappedFutures::new();
        inner.insert(1, BoolFut(false));
        assert!(block_on(WakeMutateTest { inner }));
    }
}