use core::{
cell::Cell,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use pin_project::pin_project;
/// Outcome of a probe: either a value was produced (`Hit`) or not (`Miss`).
///
/// In this module probes usually travel inside a `Result`, as
/// `Result<Probe<T>, E>` (see `hit!`, `or_miss!` and `select_probe!`).
#[derive(Debug, PartialEq)]
pub enum Probe<T> {
    /// A value was produced.
    Hit(T),
    /// No value was produced.
    Miss,
}
impl<T> Probe<T> {
    /// Returns the `Hit` value.
    ///
    /// # Panics
    ///
    /// Panics if the probe is `Miss`. Like `Option::unwrap`, the panic is
    /// attributed to the caller's source location (`#[track_caller]`) rather
    /// than to this method.
    #[inline(always)]
    #[track_caller]
    pub fn unwrap(self) -> T {
        match self {
            Probe::Hit(v) => v,
            Probe::Miss => panic!("called Probe::unwrap() on a Miss value"),
        }
    }
    /// Converts into a `Result`: `Hit(v)` becomes `Ok(v)`, and `Miss` becomes
    /// `Err(on_miss())`. `on_miss` is only called for `Miss`.
    #[inline(always)]
    pub fn require<E>(self, on_miss: impl FnOnce() -> E) -> Result<T, E> {
        match self {
            Probe::Hit(v) => Ok(v),
            Probe::Miss => Err(on_miss()),
        }
    }
    /// Returns `true` if the probe is `Miss`.
    #[inline(always)]
    pub fn is_miss(&self) -> bool {
        matches!(self, Probe::Miss)
    }
    /// Returns `true` if the probe is `Hit`.
    #[inline(always)]
    pub fn is_hit(&self) -> bool {
        matches!(self, Probe::Hit(_))
    }
    /// Applies `f` to a `Hit` value. `Miss` stays `Miss`, and `f` is not called.
    #[inline(always)]
    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Probe<U> {
        match self {
            Probe::Hit(v) => Probe::Hit(f(v)),
            Probe::Miss => Probe::Miss,
        }
    }
}
/// Unwraps a `Result<Probe<T>, _>`, propagating both errors and misses.
///
/// The expression is evaluated and `?` is applied to it, so an `Err` is
/// returned from the enclosing function through the usual `?` conversion.
/// `Probe::Hit(v)` makes the macro evaluate to `v`. `Probe::Miss` returns
/// `Ok(Probe::Miss)` from the enclosing function, so that function must
/// return a `Result<Probe<_>, _>`.
#[macro_export]
macro_rules! hit {
($e:expr) => {
match $e? {
$crate::Probe::Hit(v) => v,
$crate::Probe::Miss => return ::core::result::Result::Ok($crate::Probe::Miss),
}
};
}
/// Unwraps an `Option`, turning `None` into a miss.
///
/// `Some(v)` makes the macro evaluate to `v`. `None` returns
/// `Ok(Probe::Miss)` from the enclosing function, so that function must
/// return a `Result<Probe<_>, _>`.
#[macro_export]
macro_rules! or_miss {
($e:expr) => {
match $e {
::core::option::Option::Some(v) => v,
::core::option::Option::None => return ::core::result::Result::Ok($crate::Probe::Miss),
}
};
}
/// Implementation detail of `select_probe!`.
///
/// Folds the arm expressions into a single value. It starts from
/// `SelectProbeBase` and adds each arm wrapped in `ProbeArm` with `+`. The
/// `Add` impls in this module turn that chain into `SP1`..`SP4`, then nest
/// `SelectProbeSlot` for any further arms. At least one arm is required.
#[doc(hidden)]
#[macro_export]
macro_rules! __select_probe_arms {
($($arm:expr),+) => {
$crate::probe::SelectProbeBase $(+ $crate::probe::ProbeArm($arm))+
};
}
/// Races several probe futures and returns the first decisive result.
///
/// Each arm is an expression whose value is a future with output
/// `Result<Probe<T>, E>`, and all arms share `T` and `E`. The expansion ends
/// in `.await`, so the macro must be used inside an async context.
///
/// * Arms are polled in the order written. The first arm to resolve to
///   `Ok(Probe::Hit(_))` or `Err(_)` decides the result.
/// * An arm that resolves to `Ok(Probe::Miss)` is dropped.
/// * Inside an arm, `kill!(i)` marks arm `i` as killed, where `i` is a
///   zero-based integer literal. A killed arm is dropped before any further
///   arm is polled.
/// * When every arm has missed or been killed, the result comes from the
///   optional `@miss => expr` fallback, or is `Ok(Probe::Miss)` if none is
///   given. The fallback may be a future or a `Result<Probe<T>, E>` value,
///   and it is evaluated before the arm expressions.
/// * With a leading `biased;`, when arm `i` hits, arms `0..i` are polled once
///   more, and a `Hit` or `Err` from one of them takes precedence.
///
/// At least one arm is required.
#[macro_export]
macro_rules! select_probe {
(biased; $($arm:expr),* $(, @miss => $miss:expr)? $(,)?) => {$crate::__select_probe_inner!(true; $($arm),* $(, @miss => $miss)?)};
($($arm:expr),* $(, @miss => $miss:expr)? $(,)?) => {$crate::__select_probe_inner!(false; $($arm),* $(, @miss => $miss)?)};
}
/// Implementation detail of `select_probe!`. `$bias` is `true` for the
/// `biased;` form.
///
/// Expands to a block that ends in `.await`, so it must be evaluated in an
/// async context.
#[doc(hidden)]
#[macro_export]
macro_rules! __select_probe_inner {
($bias:literal; $($arm:expr),* $(, @miss => $miss:expr)?) => {{
// Brings the `MissFallback::into_future` trait method into scope for the
// `miss` closure below. It is used when `@miss` is a plain value rather
// than a future.
#[allow(unused_imports)]
use $crate::probe::MissFallback as _;
// Per-arm kill flags, shared by reference with the arms and with `select_probe`.
let __probe_kills = <$crate::probe::KillManager<_>>::new();
let __probe_kills = &__probe_kills;
// `kill!(i)` marks arm `i` (zero-based) as killed. It is defined before
// the arm expressions are expanded below, so arms can use it.
#[allow(unused_macros)]
macro_rules! kill {
($i:literal) => { __probe_kills.mark($i) };
}
// Default fallback result. `_miss2` keeps a raw pointer to it only so
// that `align_types` can unify its `T`/`E` with the `miss` closure's output.
let _miss = ::core::result::Result::Ok($crate::Probe::Miss);
let _miss2: *const _ = &_miss;
// A user-supplied `@miss` expression shadows the default. It is evaluated
// here, before any arm is constructed or polled.
$(let _miss = $miss;)?
// Uses the inherent `MissWrapper::into_future` when `_miss` is a future.
// For a `Result<Probe<_>, _>` value it uses `MissFallback`, which wraps the
// value in a ready future.
let miss = || $crate::probe::MissWrapper(_miss).into_future();
$crate::probe::align_types(_miss2, &miss);
let __probe_futs = $crate::__select_probe_arms!($($arm),+);
$crate::probe::select_probe($bias, __probe_kills, __probe_futs, miss).await
}};
}
/// Holds the `@miss` fallback of `select_probe!` so that the right
/// `into_future` is picked for it.
///
/// When the wrapped value is a future, this inherent `into_future` hands it
/// back unchanged. Otherwise the `MissFallback` trait supplies `into_future`.
#[doc(hidden)]
pub struct MissWrapper<T>(pub T);
impl<T: Future> MissWrapper<T> {
    /// Returns the wrapped future unchanged.
    #[inline(always)]
    pub fn into_future(self) -> T {
        let MissWrapper(fut) = self;
        fut
    }
}
#[doc(hidden)]
pub trait MissFallback {
type Fut: Future;
fn into_future(self) -> Self::Fut;
}
impl<T, E> MissFallback for MissWrapper<Result<Probe<T>, E>> {
type Fut = core::future::Ready<Result<Probe<T>, E>>;
#[inline(always)]
fn into_future(self) -> Self::Fut {
core::future::ready(self.0)
}
}
/// Fixed-size set of per-arm "killed" flags used by `select_probe!`.
///
/// `SIZE` is the number of flags, and valid indices are `0..SIZE`.
#[doc(hidden)]
pub trait SelectProbeKillFlag {
    /// Number of flags held.
    const SIZE: usize;
    /// Creates a set with every flag cleared.
    fn new() -> Self;
    /// Reports whether flag `i` is set.
    fn is_marked(&self, i: usize) -> bool;
    /// Sets flag `i`.
    fn mark(&self, i: usize);
}
/// The empty set. It has no valid indices, so querying or marking panics.
impl SelectProbeKillFlag for () {
    const SIZE: usize = 0;
    #[inline(always)]
    fn new() -> Self {}
    #[inline(always)]
    fn is_marked(&self, _index: usize) -> bool {
        unreachable!()
    }
    #[inline(always)]
    fn mark(&self, _index: usize) {
        unreachable!()
    }
}
/// Appends one flag to a smaller set. The new flag takes the highest index,
/// `Self::SIZE - 1`, and lower indices are forwarded to the inner set.
impl<T: SelectProbeKillFlag> SelectProbeKillFlag for (T, Cell<bool>) {
    const SIZE: usize = T::SIZE + 1;
    #[inline(always)]
    fn new() -> Self {
        let inner = T::new();
        (inner, Cell::new(false))
    }
    #[inline(always)]
    fn is_marked(&self, i: usize) -> bool {
        debug_assert!(i < Self::SIZE);
        let (inner, last) = self;
        if i != Self::SIZE - 1 {
            inner.is_marked(i)
        } else {
            last.get()
        }
    }
    #[inline(always)]
    fn mark(&self, i: usize) {
        debug_assert!(i < Self::SIZE);
        let (inner, last) = self;
        if i != Self::SIZE - 1 {
            inner.mark(i);
        } else {
            last.set(true);
        }
    }
}
/// Kill state shared between the arms of a `select_probe!` and its driver.
/// It holds per-arm flags plus a bit recording whether any kill happened
/// since the driver last checked.
#[doc(hidden)]
pub struct KillManager<T: SelectProbeKillFlag> {
    /// Per-arm kill flags.
    flags: T,
    /// Set by `mark`, cleared by `take_new_kills`.
    new_kills: Cell<bool>,
}
impl<T: SelectProbeKillFlag> Default for KillManager<T> {
    #[inline(always)]
    fn default() -> Self {
        KillManager::new()
    }
}
impl<T: SelectProbeKillFlag> KillManager<T> {
    /// Creates a manager with no arm killed.
    #[inline(always)]
    pub fn new() -> Self {
        let flags = T::new();
        let new_kills = Cell::new(false);
        Self { flags, new_kills }
    }
    /// Marks arm `i` as killed and records that a new kill is pending.
    #[inline(always)]
    pub fn mark(&self, i: usize) {
        self.flags.mark(i);
        self.new_kills.set(true);
    }
    /// Returns whether any kill happened since the last call, and resets
    /// that indicator.
    #[inline(always)]
    fn take_new_kills(&self) -> bool {
        self.new_kills.take()
    }
}
/// Marks an expression as one arm while `__select_probe_arms!` folds the
/// arms together with `+`.
#[doc(hidden)]
pub struct ProbeArm<F>(pub F);
/// Starting value of the `+` chain built by `__select_probe_arms!`. Adding
/// the first arm produces an `SP1`.
#[doc(hidden)]
pub struct SelectProbeBase;
impl<F> core::ops::Add<ProbeArm<F>> for SelectProbeBase {
    type Output = SP1<F>;
    #[inline(always)]
    fn add(self, rhs: ProbeArm<F>) -> Self::Output {
        let ProbeArm(first) = rhs;
        SP1(Some(first))
    }
}
/// Arm storage for a `select_probe!` with one arm.
///
/// Each arm is kept in a pinned `Option` so it can be dropped in place, by
/// setting it to `None`, once it misses or is killed. `SP2`..`SP4` use the
/// same layout for two to four arms.
#[doc(hidden)]
#[pin_project]
pub struct SP1<A>(#[pin] Option<A>);
/// Arm storage for two arms (same layout as `SP1`).
#[doc(hidden)]
#[pin_project]
pub struct SP2<A, B>(#[pin] Option<A>, #[pin] Option<B>);
/// Arm storage for three arms (same layout as `SP1`).
#[doc(hidden)]
#[pin_project]
pub struct SP3<A, B, C>(#[pin] Option<A>, #[pin] Option<B>, #[pin] Option<C>);
/// Arm storage for four arms (same layout as `SP1`). A fifth arm wraps this
/// in a `SelectProbeSlot`.
#[doc(hidden)]
#[pin_project]
pub struct SP4<A, B, C, D>(
#[pin] Option<A>,
#[pin] Option<B>,
#[pin] Option<C>,
#[pin] Option<D>,
);
impl<A, B> core::ops::Add<ProbeArm<B>> for SP1<A> {
type Output = SP2<A, B>;
#[inline(always)]
fn add(self, rhs: ProbeArm<B>) -> SP2<A, B> {
SP2(self.0, Some(rhs.0))
}
}
impl<A, B, C> core::ops::Add<ProbeArm<C>> for SP2<A, B> {
type Output = SP3<A, B, C>;
#[inline(always)]
fn add(self, rhs: ProbeArm<C>) -> SP3<A, B, C> {
SP3(self.0, self.1, Some(rhs.0))
}
}
impl<A, B, C, D> core::ops::Add<ProbeArm<D>> for SP3<A, B, C> {
type Output = SP4<A, B, C, D>;
#[inline(always)]
fn add(self, rhs: ProbeArm<D>) -> SP4<A, B, C, D> {
SP4(self.0, self.1, self.2, Some(rhs.0))
}
}
impl<A, B, C, D, E> core::ops::Add<ProbeArm<E>> for SP4<A, B, C, D> {
type Output = SelectProbeSlot<SP4<A, B, C, D>, E>;
#[inline(always)]
fn add(self, rhs: ProbeArm<E>) -> SelectProbeSlot<SP4<A, B, C, D>, E> {
SelectProbeSlot {
rest: self,
fut: Some(rhs.0),
}
}
}
impl<L, F, G> core::ops::Add<ProbeArm<G>> for SelectProbeSlot<L, F> {
type Output = SelectProbeSlot<SelectProbeSlot<L, F>, G>;
#[inline(always)]
fn add(self, rhs: ProbeArm<G>) -> SelectProbeSlot<SelectProbeSlot<L, F>, G> {
SelectProbeSlot {
rest: self,
fut: Some(rhs.0),
}
}
}
/// Arm storage for the fifth and later arms of a `select_probe!`. It holds
/// the previously accumulated arms (`rest`) plus one more pinned arm (`fut`)
/// that can be dropped in place.
#[doc(hidden)]
#[pin_project]
pub struct SelectProbeSlot<Rest, F> {
// Arms `0..Rest::SIZE`.
#[pin]
rest: Rest,
// The last arm. It becomes `None` once it has missed or been killed.
#[pin]
fut: Option<F>,
}
/// Converts an accumulated arm chain into the storage that `select_probe`
/// drives.
///
/// For the types in this module the structure is kept as-is: `SP1`..`SP4`
/// convert to themselves, and `SelectProbeSlot` converts its `rest`
/// recursively.
#[doc(hidden)]
pub trait IntoSelectProbeSlots {
    /// Storage type produced by `into_slots`.
    type Slots;
    /// Performs the conversion.
    fn into_slots(self) -> Self::Slots;
}
// `SP1`..`SP4` already are slot storage, so each converts to itself.
macro_rules! impl_identity_into_slots {
    ($($name:ident<$($param:ident),+>),+ $(,)?) => {
        $(
            impl<$($param),+> IntoSelectProbeSlots for $name<$($param),+> {
                type Slots = Self;
                #[inline(always)]
                fn into_slots(self) -> Self::Slots {
                    self
                }
            }
        )+
    };
}
impl_identity_into_slots!(SP1<A>, SP2<A, B>, SP3<A, B, C>, SP4<A, B, C, D>);
impl<L: IntoSelectProbeSlots, F> IntoSelectProbeSlots for SelectProbeSlot<L, F> {
    type Slots = SelectProbeSlot<L::Slots, F>;
    #[inline(always)]
    fn into_slots(self) -> Self::Slots {
        let SelectProbeSlot { rest, fut } = self;
        SelectProbeSlot {
            rest: rest.into_slots(),
            fut,
        }
    }
}
/// The set of arm futures driven by `select_probe`, addressed by index.
///
/// `T` and `E` are the `Hit` type and error type shared by all arms.
#[doc(hidden)]
pub trait SelectProbeFutures<T, E> {
/// Number of arms.
const SIZE: usize;
/// Kill-flag storage with one flag per arm.
type KF: SelectProbeKillFlag;
/// Drops every arm that is still present and whose kill flag is set,
/// decrementing `active` once per dropped arm.
fn process_kills(self: Pin<&mut Self>, kill_flags: &Self::KF, active: &mut usize);
/// Polls arm `i` once.
///
/// In the implementations in this module, `Hit` and `Err` results are
/// returned as `Ready`. An empty slot, or an arm that resolves to
/// `Ok(Probe::Miss)`, yields `Pending`; in the `Miss` case the arm is
/// dropped and `active` is decremented.
fn poll_one(
self: Pin<&mut Self>,
i: usize,
active: &mut usize,
cx: &mut Context<'_>,
) -> Poll<Result<Probe<T>, E>>;
}
/// Polls the future stored in `slot`, if there is one.
///
/// Returns `Ready(Some(..))` for a `Hit` or an `Err`. In every other case it
/// yields `Pending`: an empty slot, a future that is still pending, or a
/// future that resolved to `Ok(Probe::Miss)`. In the `Miss` case the future
/// is dropped (the slot is set to `None`) and `active` is decremented.
#[inline(always)]
fn poll_slot<T, E, F: Future<Output = Result<Probe<T>, E>>>(
    mut slot: Pin<&mut Option<F>>,
    active: &mut usize,
    cx: &mut Context<'_>,
) -> Poll<Option<Result<Probe<T>, E>>> {
    let Some(fut) = slot.as_mut().as_pin_mut() else {
        return Poll::Pending;
    };
    match fut.poll(cx) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Ok(Probe::Miss)) => {
            slot.set(None);
            *active -= 1;
            Poll::Pending
        }
        Poll::Ready(decisive) => Poll::Ready(Some(decisive)),
    }
}
/// Kill flags for a fixed number `N` of arms, stored as one `Cell<bool>`
/// per arm.
#[doc(hidden)]
pub struct FlatKillFlags<const N: usize>([Cell<bool>; N]);
impl<const N: usize> FlatKillFlags<N> {
    /// Creates `N` cleared flags.
    #[inline(always)]
    fn new() -> Self {
        let cells = [(); N].map(|()| Cell::new(false));
        Self(cells)
    }
}
impl<const N: usize> SelectProbeKillFlag for FlatKillFlags<N> {
    const SIZE: usize = N;
    #[inline(always)]
    fn new() -> Self {
        // Resolves to the inherent constructor above, not to this trait method.
        Self::new()
    }
    /// Reads flag `i`. Panics if `i >= N`.
    #[inline(always)]
    fn is_marked(&self, i: usize) -> bool {
        let flag = &self.0[i];
        flag.get()
    }
    /// Sets flag `i`. Panics if `i >= N`.
    #[inline(always)]
    fn mark(&self, i: usize) {
        let flag = &self.0[i];
        flag.set(true);
    }
}
impl<T, E, A: Future<Output = Result<Probe<T>, E>>> SelectProbeFutures<T, E> for SP1<A> {
const SIZE: usize = 1;
type KF = FlatKillFlags<1>;
#[inline(always)]
fn process_kills(self: Pin<&mut Self>, kill_flags: &Self::KF, active: &mut usize) {
let mut slot = self.project().0;
if kill_flags.0[0].get() && slot.as_mut().as_pin_mut().is_some() {
slot.set(None);
*active -= 1;
}
}
#[inline(always)]
fn poll_one(
self: Pin<&mut Self>,
i: usize,
active: &mut usize,
cx: &mut Context<'_>,
) -> Poll<Result<Probe<T>, E>> {
debug_assert_eq!(i, 0);
match poll_slot(self.project().0, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
}
}
}
impl<T, E, A: Future<Output = Result<Probe<T>, E>>, B: Future<Output = Result<Probe<T>, E>>>
SelectProbeFutures<T, E> for SP2<A, B>
{
const SIZE: usize = 2;
type KF = FlatKillFlags<2>;
#[inline(always)]
fn process_kills(self: Pin<&mut Self>, kill_flags: &Self::KF, active: &mut usize) {
let mut p = self.project();
if kill_flags.0[0].get() && p.0.as_mut().as_pin_mut().is_some() {
p.0.set(None);
*active -= 1;
}
if kill_flags.0[1].get() && p.1.as_mut().as_pin_mut().is_some() {
p.1.set(None);
*active -= 1;
}
}
#[inline(always)]
fn poll_one(
self: Pin<&mut Self>,
i: usize,
active: &mut usize,
cx: &mut Context<'_>,
) -> Poll<Result<Probe<T>, E>> {
debug_assert!(i < 2);
let p = self.project();
match i {
0 => match poll_slot(p.0, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
},
_ => match poll_slot(p.1, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
},
}
}
}
impl<
T,
E,
A: Future<Output = Result<Probe<T>, E>>,
B: Future<Output = Result<Probe<T>, E>>,
C: Future<Output = Result<Probe<T>, E>>,
> SelectProbeFutures<T, E> for SP3<A, B, C>
{
const SIZE: usize = 3;
type KF = FlatKillFlags<3>;
#[inline(always)]
fn process_kills(self: Pin<&mut Self>, kill_flags: &Self::KF, active: &mut usize) {
let mut p = self.project();
if kill_flags.0[0].get() && p.0.as_mut().as_pin_mut().is_some() {
p.0.set(None);
*active -= 1;
}
if kill_flags.0[1].get() && p.1.as_mut().as_pin_mut().is_some() {
p.1.set(None);
*active -= 1;
}
if kill_flags.0[2].get() && p.2.as_mut().as_pin_mut().is_some() {
p.2.set(None);
*active -= 1;
}
}
#[inline(always)]
fn poll_one(
self: Pin<&mut Self>,
i: usize,
active: &mut usize,
cx: &mut Context<'_>,
) -> Poll<Result<Probe<T>, E>> {
debug_assert!(i < 3);
let p = self.project();
match i {
0 => match poll_slot(p.0, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
},
1 => match poll_slot(p.1, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
},
_ => match poll_slot(p.2, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
},
}
}
}
impl<
T,
E,
A: Future<Output = Result<Probe<T>, E>>,
B: Future<Output = Result<Probe<T>, E>>,
C: Future<Output = Result<Probe<T>, E>>,
D: Future<Output = Result<Probe<T>, E>>,
> SelectProbeFutures<T, E> for SP4<A, B, C, D>
{
const SIZE: usize = 4;
type KF = FlatKillFlags<4>;
#[inline(always)]
fn process_kills(self: Pin<&mut Self>, kill_flags: &Self::KF, active: &mut usize) {
let mut p = self.project();
if kill_flags.0[0].get() && p.0.as_mut().as_pin_mut().is_some() {
p.0.set(None);
*active -= 1;
}
if kill_flags.0[1].get() && p.1.as_mut().as_pin_mut().is_some() {
p.1.set(None);
*active -= 1;
}
if kill_flags.0[2].get() && p.2.as_mut().as_pin_mut().is_some() {
p.2.set(None);
*active -= 1;
}
if kill_flags.0[3].get() && p.3.as_mut().as_pin_mut().is_some() {
p.3.set(None);
*active -= 1;
}
}
#[inline(always)]
fn poll_one(
self: Pin<&mut Self>,
i: usize,
active: &mut usize,
cx: &mut Context<'_>,
) -> Poll<Result<Probe<T>, E>> {
debug_assert!(i < 4);
let p = self.project();
match i {
0 => match poll_slot(p.0, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
},
1 => match poll_slot(p.1, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
},
2 => match poll_slot(p.2, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
},
_ => match poll_slot(p.3, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
},
}
}
}
impl<T, E, F: Future<Output = Result<Probe<T>, E>>, Rest: SelectProbeFutures<T, E>>
SelectProbeFutures<T, E> for SelectProbeSlot<Rest, F>
{
const SIZE: usize = Rest::SIZE + 1;
type KF = (Rest::KF, Cell<bool>);
#[inline(always)]
fn process_kills(self: Pin<&mut Self>, kill_flags: &Self::KF, active: &mut usize) {
let mut this = self.project();
this.rest.process_kills(&kill_flags.0, active);
if kill_flags.1.get() && this.fut.as_mut().as_pin_mut().is_some() {
this.fut.set(None);
*active -= 1;
}
}
#[inline(always)]
fn poll_one(
self: Pin<&mut Self>,
i: usize,
active: &mut usize,
cx: &mut Context<'_>,
) -> Poll<Result<Probe<T>, E>> {
debug_assert!(i < Self::SIZE);
if i == Self::SIZE - 1 {
match poll_slot(self.project().fut, active, cx) {
Poll::Ready(Some(r)) => Poll::Ready(r),
_ => Poll::Pending,
}
} else {
self.project().rest.poll_one(i, active, cx)
}
}
}
/// Type-inference helper for `select_probe!`. It does nothing at runtime.
///
/// The macro passes a pointer to its default `Ok(Probe::Miss)` value and a
/// reference to its `miss` closure. This forces both to agree on `T` and `E`,
/// so the default value gets a concrete type even when no `@miss` fallback
/// is given.
#[doc(hidden)]
#[inline(always)]
pub const fn align_types<T, E, M, Fut>(_: *const Result<Probe<T>, E>, _: *const M)
where
M: FnOnce() -> Fut,
Fut: Future<Output = Result<Probe<T>, E>>,
{
}
/// Runtime half of `select_probe!`. Drives the arm futures until one of them
/// produces a decisive result, and falls back to `miss` when no arm is left.
///
/// On every poll, each remaining arm is polled once, in declaration order:
///
/// * An arm resolving to `Ok(Probe::Hit(_))` or `Err(_)` ends the select
///   with that result.
/// * An arm resolving to `Ok(Probe::Miss)` is dropped and never polled again.
/// * Arms marked through `kills` (the `kill!` helper inside the macro) are
///   dropped before any further arm is polled.
///
/// With `biased == true`, when arm `i` hits, arms `0..i` are polled once
/// more. A `Hit` or `Err` from one of them is returned instead, and the later
/// hit is dropped.
///
/// When every arm has missed or been killed, `miss` is called and its
/// future's output is returned.
#[inline(always)]
pub async fn select_probe<T, E, Raw, M, Fut>(
    biased: bool,
    kills: &KillManager<<Raw::Slots as SelectProbeFutures<T, E>>::KF>,
    futures: Raw,
    miss: M,
) -> Result<Probe<T>, E>
where
    Raw: IntoSelectProbeSlots,
    Raw::Slots: SelectProbeFutures<T, E>,
    M: FnOnce() -> Fut,
    Fut: Future<Output = Result<Probe<T>, E>>,
{
    let mut slots = core::pin::pin!(futures.into_slots());
    // Number of arms that have neither missed nor been killed yet.
    let mut active = <Raw::Slots as SelectProbeFutures<T, E>>::SIZE;
    let v = core::future::poll_fn(move |cx| {
        for i in 0..<Raw::Slots as SelectProbeFutures<T, E>>::SIZE {
            // Apply kills requested while polling earlier arms before polling arm `i`.
            if kills.take_new_kills() {
                slots.as_mut().process_kills(&kills.flags, &mut active);
            }
            match slots.as_mut().poll_one(i, &mut active, cx) {
                Poll::Ready(Ok(Probe::Hit(val))) if biased => {
                    // Give the earlier arms one more chance to take precedence.
                    for j in 0..i {
                        if kills.take_new_kills() {
                            slots.as_mut().process_kills(&kills.flags, &mut active);
                        }
                        match slots.as_mut().poll_one(j, &mut active, cx) {
                            Poll::Ready(Ok(Probe::Hit(earlier_val))) => {
                                drop(val);
                                return Poll::Ready(Some(Ok(Probe::Hit(earlier_val))));
                            }
                            // `poll_one` implementations in this module report a
                            // miss as `Pending`, so this arm is not expected.
                            Poll::Ready(Ok(Probe::Miss)) => {}
                            Poll::Ready(Err(e)) => {
                                drop(val);
                                return Poll::Ready(Some(Err(e)));
                            }
                            Poll::Pending => {}
                        }
                    }
                    return Poll::Ready(Some(Ok(Probe::Hit(val))));
                }
                Poll::Ready(result) => return Poll::Ready(Some(result)),
                Poll::Pending => {}
            }
        }
        // The check at the top of the loop only runs before each arm is
        // polled, so kills requested while polling the last arm are still
        // pending here. Apply them before deciding whether any arm is left.
        // Otherwise a round in which the remaining live arms were killed would
        // return `Pending` with possibly no live future left to wake this task,
        // and the select would never reach the `miss` fallback.
        if kills.take_new_kills() {
            slots.as_mut().process_kills(&kills.flags, &mut active);
        }
        if active == 0 {
            return Poll::Ready(None);
        }
        Poll::Pending
    })
    .await;
    match v {
        Some(result) => result,
        None => miss().await,
    }
}
/// One item of a chunked sequence: either a `Data` payload or a `Done`
/// payload.
///
/// `Debug` and `PartialEq` are derived, matching `Probe`, so chunks can be
/// printed and compared (for example in `assert_eq!`).
#[derive(Debug, PartialEq)]
pub enum Chunk<Data, Done> {
    /// A data payload.
    Data(Data),
    /// A `Done` payload.
    Done(Done),
}
impl<Data, Done> Chunk<Data, Done> {
    /// Returns the payload if this is `Data`. Otherwise returns `None` and
    /// drops the `Done` payload.
    #[inline(always)]
    pub fn data(self) -> Option<Data> {
        match self {
            Self::Data(d) => Some(d),
            Self::Done(_) => None,
        }
    }
    /// Returns the payload if this is `Done`. Otherwise returns `None` and
    /// drops the `Data` payload.
    #[inline(always)]
    pub fn done(self) -> Option<Done> {
        match self {
            Self::Done(d) => Some(d),
            Self::Data(_) => None,
        }
    }
    /// Applies `f` to a `Data` payload. A `Done` payload is passed through
    /// unchanged, and `f` is not called.
    #[inline(always)]
    pub fn map_data<Data2>(self, f: impl FnOnce(Data) -> Data2) -> Chunk<Data2, Done> {
        match self {
            Self::Data(d) => Chunk::Data(f(d)),
            Self::Done(d) => Chunk::Done(d),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::future::Future;
    use core::pin::Pin;
    use core::sync::atomic::{AtomicBool, Ordering};
    use core::task::{Context, Poll};

    /// Polls `f` a single time using a waker that ignores wake-ups.
    fn poll_once<F: Future>(f: Pin<&mut F>) -> Poll<F::Output> {
        let waker = strede_test_util::noop_waker();
        let mut cx = Context::from_waker(&waker);
        f.poll(&mut cx)
    }

    /// Arm 0 kills arm 1 and misses. By the time arm 2 runs, arm 1 (which owns
    /// a drop guard) must already have been dropped.
    #[test]
    fn kill_drops_future() {
        static DROPPED: AtomicBool = AtomicBool::new(false);

        struct SetOnDrop;

        impl Drop for SetOnDrop {
            fn drop(&mut self) {
                DROPPED.store(true, Ordering::SeqCst);
            }
        }

        let _guard = SetOnDrop;
        let fut = async {
            crate::select_probe! {
                async move {
                    kill!(1);
                    Ok(Probe::Miss)
                },
                async move {
                    let _guard = _guard;
                    core::future::pending::<Result<Probe<u32>, ()>>().await
                },
                async move {
                    assert!(DROPPED.load(Ordering::SeqCst), "arm 1 was not dropped");
                    Ok(Probe::Hit(42u32))
                },
            }
        };
        let mut fut = core::pin::pin!(fut);
        let result = loop {
            if let Poll::Ready(out) = poll_once(fut.as_mut()) {
                break out;
            }
        };
        assert_eq!(result, Ok(Probe::Hit(42u32)));
    }
}