use std::{pin::Pin, task};
use futures::{
FutureExt as _,
future::{Either, FusedFuture, Then},
};
use pin_project::pin_project;
/// Returns a future that hands control back to the executor exactly once.
///
/// The first poll wakes the current task and returns [`task::Poll::Pending`];
/// every later poll returns [`task::Poll::Ready`].
pub(crate) const fn yield_now() -> YieldNow {
    YieldNow(false)
}

/// Future returned by [`yield_now`].
///
/// The wrapped `bool` records whether the future has already yielded once.
#[derive(Clone, Copy, Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct YieldNow(bool);

impl Future for YieldNow {
    type Output = ();

    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Self::Output> {
        if self.0 {
            task::Poll::Ready(())
        } else {
            self.0 = true;
            // Wake before returning `Pending` so the task is scheduled to be
            // polled again rather than left parked; the single `Pending` is
            // what gives control back to the executor.
            cx.waker().wake_by_ref();
            task::Poll::Pending
        }
    }
}
/// Future produced by [`FutureExt::then_yield`]: drives `Fut` to completion,
/// then yields once before resolving to its output `Out`.
type ThenYield<Fut, Out> = Then<Fut, YieldThenReturn<Out>, fn(Out) -> YieldThenReturn<Out>>;

/// Crate-local extension methods for futures.
pub(crate) trait FutureExt: Future + Sized {
    /// Once `self` completes, yields to the executor one time and then
    /// resolves to `self`'s output.
    fn then_yield(self) -> ThenYield<Self, Self::Output> {
        // Name the continuation as a plain fn pointer, matching the third
        // type parameter of `ThenYield`.
        let continuation: fn(Self::Output) -> YieldThenReturn<Self::Output> =
            YieldThenReturn::new;
        self.then(continuation)
    }
}

impl<Fut> FutureExt for Fut where Fut: Future {}
/// Future that yields to the executor once and then resolves to a value it
/// already holds.
///
/// [`FutureExt::then_yield`] uses it as the continuation that runs after the
/// wrapped future completes.
#[derive(Debug)]
#[pin_project]
pub(crate) struct YieldThenReturn<V> {
    /// Value to hand back; moved out when the future resolves.
    value: Option<V>,
    /// One-shot yield that must complete before `value` is returned.
    r#yield: YieldNow,
}

impl<V> YieldThenReturn<V> {
    /// Pairs `value` with a fresh [`yield_now`] future.
    const fn new(value: V) -> Self {
        let r#yield = yield_now();
        Self {
            value: Some(value),
            r#yield,
        }
    }
}
impl<V> Future for YieldThenReturn<V> {
    type Output = V;

    /// Yields to the executor once, then resolves to the stored value.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned
    /// [`task::Poll::Ready`].
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Self::Output> {
        let this = self.project();
        // Give control back to the executor once before producing the value.
        task::ready!(this.r#yield.poll_unpin(cx));
        // The value is moved out on the first `Ready`, so getting here with
        // `None` means the future was polled after completion. Fail loudly
        // rather than return `Pending` with no wake-up registered, which
        // would stall the caller forever.
        #[expect(clippy::expect_used, reason = "polling after completion is a caller bug")]
        let value = this
            .value
            .take()
            .expect("`YieldThenReturn` polled after completion");
        task::Poll::Ready(value)
    }
}
/// Races `biased` against `regular`, always polling `biased` first.
///
/// The output carries whichever future finished first together with the
/// other, still-unfinished future. When both are ready in the same poll,
/// `biased` wins. Both futures must be [`Unpin`].
pub(crate) const fn select_with_biased_first<A, B>(
    biased: A,
    regular: B,
) -> SelectWithBiasedFirst<A, B>
where
    A: Future + Unpin,
    B: Future + Unpin,
{
    let pair = (biased, regular);
    SelectWithBiasedFirst { inner: Some(pair) }
}
/// Future returned by [`select_with_biased_first`].
///
/// Resolves as soon as either inner future does, always polling the biased
/// future first so it wins when both are ready in the same poll. The future
/// that did not finish is returned alongside the output.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct SelectWithBiasedFirst<A, B> {
    /// The `(biased, regular)` pair; `None` once a result has been returned.
    inner: Option<(A, B)>,
}

impl<A, B> Future for SelectWithBiasedFirst<A, B>
where
    A: Future + Unpin,
    B: Future + Unpin,
{
    /// `Left` when the biased future finished (paired with the regular one),
    /// `Right` when the regular future finished (paired with the biased one).
    type Output = Either<(A::Output, B), (B::Output, A)>;

    /// # Panics
    ///
    /// Panics if polled again after it has already returned
    /// [`task::Poll::Ready`].
    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Self::Output> {
        // Move the pair out so the unfinished future can be returned by
        // value; it is put back below if neither side is ready yet.
        #[expect(clippy::expect_used, reason = "polling after completion is a caller bug")]
        let (mut a, mut b) = self
            .inner
            .take()
            .expect("`SelectWithBiasedFirst` polled after completion");
        // Poll the biased future first so it takes priority when both are ready.
        if let task::Poll::Ready(val) = a.poll_unpin(cx) {
            return task::Poll::Ready(Either::Left((val, b)));
        }
        if let task::Poll::Ready(val) = b.poll_unpin(cx) {
            return task::Poll::Ready(Either::Right((val, a)));
        }
        self.inner = Some((a, b));
        task::Poll::Pending
    }
}
impl<A, B> FusedFuture for SelectWithBiasedFirst<A, B>
where
    A: Future + Unpin,
    B: Future + Unpin,
{
    /// Reports whether this future has already produced its output.
    ///
    /// `poll` moves the inner pair out and only restores it when neither
    /// future is ready. An empty slot therefore means a result was returned,
    /// or an inner poll panicked mid-way.
    fn is_terminated(&self) -> bool {
        let still_pending = self.inner.is_some();
        !still_pending
    }
}