use alloc::vec::Vec;
use core::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
/// A future that completes as soon as any one of a set of borrowed, pinned
/// futures completes, yielding that future's position together with its output.
///
/// The futures are only borrowed for `'f`; this type never owns them.
pub(super) struct SelectAll<'f, Fut> {
    /// Registered futures, in push order; completion indices refer to this order.
    futures: Vec<Pin<&'f mut Fut>>,
    /// Where polling begins (reduced modulo the number of futures); `None` means 0.
    start_index: Option<usize>,
}
impl<'f, Fut> SelectAll<'f, Fut>
where
    Fut: Future,
{
    /// Creates an empty `SelectAll`.
    ///
    /// `start_index` selects the first future to poll (taken modulo the number
    /// of registered futures at poll time); `None` starts at index 0.
    pub(super) fn new(start_index: Option<usize>) -> Self {
        Self {
            start_index,
            futures: Vec::new(),
        }
    }

    /// Registers a future that need not be `Unpin`, pinning it in place
    /// without any checks. Its index is the number of futures pushed before it.
    ///
    /// # Safety
    ///
    /// Same contract as [`Pin::new_unchecked`]: the caller must guarantee that
    /// `*fut` is never moved (and its storage never invalidated) from this call
    /// until it is dropped — including after the `'f` borrow has ended.
    pub(super) unsafe fn push_unchecked(&mut self, fut: &'f mut Fut) {
        // SAFETY: the caller upholds the pinning contract documented above.
        let pinned = unsafe { Pin::new_unchecked(fut) };
        self.futures.push(pinned);
    }
}
impl<'f, Fut> SelectAll<'f, Fut>
where
    Fut: Future + Unpin,
{
    /// Registers an `Unpin` future. No `unsafe` is required here, because
    /// `Pin::new` is always sound for `Unpin` targets. The future's index is
    /// the number of futures pushed before it.
    pub(super) fn push(&mut self, fut: &'f mut Fut) {
        let pinned = Pin::new(fut);
        self.futures.push(pinned)
    }
}
impl<Fut, Out> Future for SelectAll<'_, Fut>
where
    Fut: Future<Output = Out>,
{
    /// Index of the future that completed, paired with that future's output.
    type Output = (usize, Out);

    /// Polls every registered future at most once, beginning at `start_index`
    /// (modulo the number of futures, or 0 when `None`) and wrapping around,
    /// and returns the first one that is ready.
    ///
    /// With no futures registered this returns `Poll::Pending` without
    /// touching `cx`, so such a `SelectAll` never wakes itself.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `SelectAll` only holds `Pin<&mut _>` handles and an `Option<usize>`,
        // so it is always `Unpin` and can be unwrapped safely.
        let this = self.get_mut();
        let count = this.futures.len();
        if count == 0 {
            return Poll::Pending;
        }
        let first = match this.start_index {
            Some(offset) => offset % count,
            None => 0,
        };
        let futures = &mut this.futures;
        // Visit `first..count`, then wrap around to `0..first`.
        (first..count)
            .chain(0..first)
            .find_map(|idx| match futures[idx].as_mut().poll(cx) {
                Poll::Ready(out) => Some((idx, out)),
                Poll::Pending => None,
            })
            .map_or(Poll::Pending, Poll::Ready)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::{
        future::Future,
        pin::Pin,
        task::{Context, Poll, Waker},
    };

    /// Builds a `SelectAll` over `futures` (registered in slice order) that
    /// starts at `start_index`, polls it exactly once, and returns the result.
    fn poll_once(
        start_index: Option<usize>,
        futures: &mut [ControlledFuture],
    ) -> Poll<(usize, usize)> {
        let mut select_all = SelectAll::new(start_index);
        for future in futures.iter_mut() {
            select_all.push(future);
        }
        let waker = dummy_waker();
        let mut cx = Context::from_waker(&waker);
        Pin::new(&mut select_all).poll(&mut cx)
    }

    /// Returns how many times each future has been polled, in slice order.
    fn poll_counts(futures: &[ControlledFuture]) -> Vec<usize> {
        futures.iter().map(|future| future.polls).collect()
    }

    #[test]
    fn round_robin_fairness() {
        let mut futures = [
            ControlledFuture::new_ready(0),
            ControlledFuture::new_ready(1),
            ControlledFuture::new_ready(2),
        ];
        assert_eq!(poll_once(Some(0), &mut futures), Poll::Ready((0, 0)));
        // The scan stops at the first ready future; later ones are not polled.
        assert_eq!(poll_counts(&futures), [1, 0, 0]);
    }

    #[test]
    fn round_robin_with_start_index() {
        let mut futures = [
            ControlledFuture::new_ready(0),
            ControlledFuture::new_ready(1),
            ControlledFuture::new_ready(2),
        ];
        assert_eq!(poll_once(Some(1), &mut futures), Poll::Ready((1, 1)));
        assert_eq!(poll_counts(&futures), [0, 1, 0]);
    }

    #[test]
    fn round_robin_wrapping() {
        let mut futures = [ControlledFuture::new_ready(0), ControlledFuture::new(1)];
        assert_eq!(poll_once(Some(1), &mut futures), Poll::Ready((0, 0)));
        // Index 1 is tried first (pending), then the scan wraps to index 0.
        assert_eq!(poll_counts(&futures), [1, 1]);
    }

    #[test]
    fn wrapping_finds_ready_future_before_start() {
        let mut futures = [
            ControlledFuture::new(0),
            ControlledFuture::new_ready(1),
            ControlledFuture::new(2),
        ];
        // Visit order is 2, 0, 1.
        assert_eq!(poll_once(Some(2), &mut futures), Poll::Ready((1, 1)));
        assert_eq!(poll_counts(&futures), [1, 1, 1]);
    }

    #[test]
    fn start_index_larger_than_futures() {
        let mut futures = [ControlledFuture::new_ready(0), ControlledFuture::new_ready(1)];
        // 5 % 2 == 1, so polling starts at the second future.
        assert_eq!(poll_once(Some(5), &mut futures), Poll::Ready((1, 1)));
        assert_eq!(poll_counts(&futures), [0, 1]);
    }

    #[test]
    fn no_start_index_defaults_to_zero() {
        let mut futures = [
            ControlledFuture::new_ready(0),
            ControlledFuture::new_ready(1),
            ControlledFuture::new_ready(2),
        ];
        assert_eq!(poll_once(None, &mut futures), Poll::Ready((0, 0)));
        assert_eq!(poll_counts(&futures), [1, 0, 0]);
    }

    #[test]
    fn empty_select_all_returns_pending() {
        let mut futures: [ControlledFuture; 0] = [];
        assert_eq!(poll_once(Some(0), &mut futures), Poll::Pending);
    }

    #[test]
    fn all_futures_pending() {
        let mut futures = [ControlledFuture::new(0), ControlledFuture::new(1)];
        assert_eq!(poll_once(Some(1), &mut futures), Poll::Pending);
        // Every pending future is polled exactly once per `SelectAll::poll`.
        assert_eq!(poll_counts(&futures), [1, 1]);
    }

    #[test]
    fn push_unchecked_registers_futures() {
        let mut future0 = ControlledFuture::new(0);
        let mut future1 = ControlledFuture::new(1);
        future1.set_ready(true);
        let mut select_all = SelectAll::new(None);
        // SAFETY: both futures are locals that are never moved again before
        // they are dropped at the end of this test.
        unsafe {
            select_all.push_unchecked(&mut future0);
            select_all.push_unchecked(&mut future1);
        }
        let waker = dummy_waker();
        let mut cx = Context::from_waker(&waker);
        assert_eq!(Pin::new(&mut select_all).poll(&mut cx), Poll::Ready((1, 1)));
        assert_eq!((future0.polls, future1.polls), (1, 1));
    }

    /// Test future whose readiness is controlled by the test and which
    /// records how many times it has been polled.
    struct ControlledFuture {
        ready: bool,
        value: usize,
        polls: usize,
    }

    impl ControlledFuture {
        /// A pending future that will yield `value` once marked ready.
        fn new(value: usize) -> Self {
            Self {
                ready: false,
                value,
                polls: 0,
            }
        }

        /// A future that yields `value` on its first poll.
        fn new_ready(value: usize) -> Self {
            let mut future = Self::new(value);
            future.set_ready(true);
            future
        }

        fn set_ready(&mut self, ready: bool) {
            self.ready = ready;
        }
    }

    impl Future for ControlledFuture {
        type Output = usize;

        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.polls += 1;
            if self.ready {
                Poll::Ready(self.value)
            } else {
                Poll::Pending
            }
        }
    }

    /// A waker whose operations all do nothing; the tests poll synchronously.
    fn dummy_waker() -> Waker {
        use core::task::{RawWaker, RawWakerVTable};
        fn dummy_raw_waker() -> RawWaker {
            RawWaker::new(core::ptr::null(), &VTABLE)
        }
        const VTABLE: RawWakerVTable =
            RawWakerVTable::new(|_| dummy_raw_waker(), |_| {}, |_| {}, |_| {});
        unsafe { Waker::from_raw(dummy_raw_waker()) }
    }
}