use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll}
};
use crate::err::Error;
use super::{Shared, State};
/// Waiting side of a shared result: lets the owner block ([`WaitCtx::wait`]),
/// check without blocking ([`WaitCtx::try_get`]), or await the data/error that
/// is published into the shared state.
///
/// `#[repr(transparent)]`: same layout as the wrapped `Arc`.
#[repr(transparent)]
pub struct WaitCtx<T, S, E>(pub(crate) Arc<Shared<T, S, E>>);
impl<T, S, E> WaitCtx<T, S, E> {
    /// Blocks the calling thread until the shared state holds a result, then
    /// takes it, consuming this context.
    ///
    /// The shared state is left as `State::Finalized`, so the result can only
    /// be taken once.
    ///
    /// # Errors
    ///
    /// Returns the error stored in the shared state when it is `State::Err`.
    ///
    /// # Panics
    ///
    /// Panics if the shared state is already `State::Finalized` (its result has
    /// already been taken).
    pub fn wait(self) -> Result<T, Error<S, E>> {
        let mut inner = self.0.inner.lock();
        // Re-check the state after every wake-up; `signal` is presumably a
        // condition variable, whose waits may return spuriously.
        while matches!(inner.state, State::Waiting) {
            self.0.signal.wait(&mut inner);
        }
        // Move the outcome out while holding the lock and mark it consumed.
        let outcome = std::mem::replace(&mut inner.state, State::Finalized);
        // Release the lock as soon as the state has been taken.
        drop(inner);
        match outcome {
            State::Data(data) => Ok(data),
            State::Err(err) => Err(err),
            // The loop above only exits once the state is no longer `Waiting`.
            State::Waiting => unreachable!("state is still Waiting after the wait loop"),
            State::Finalized => unimplemented!("Unexpected state"),
        }
    }

    /// Checks for a result without blocking.
    ///
    /// Delegates to the shared state's own `try_get`. `Ok(None)` means no
    /// result is available yet (the `Future` impls treat it as pending).
    ///
    /// # Errors
    ///
    /// Propagates the error reported by the shared state.
    pub fn try_get(&self) -> Result<Option<T>, Error<S, E>> {
        let mut inner = self.0.inner.lock();
        inner.try_get()
    }

    /// Returns a future that polls the shared state, borrowing `self` instead
    /// of consuming it.
    #[must_use]
    pub const fn wait_async(&self) -> WaitFuture<T, S, E> {
        WaitFuture(self)
    }
}
impl<T, S, E> Drop for WaitCtx<T, S, E> {
    /// Sets `wctx_dropped` in the shared state so the other side can see that
    /// this waiter is gone (how that flag is consumed is outside this chunk).
    fn drop(&mut self) {
        // The temporary lock guard is released at the end of this statement.
        self.0.inner.lock().wctx_dropped = true;
    }
}
impl<T, S, E> Future for WaitCtx<T, S, E> {
    type Output = Result<T, Error<S, E>>;

    /// Polls the shared state for a result.
    ///
    /// Delegates to [`WaitFuture`], so the logic (including registering the
    /// waker when no result is available yet) lives in exactly one place.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // `WaitFuture` only holds a shared reference, so it is `Unpin` and can
        // be pinned in place with `Pin::new`.
        Pin::new(&mut self.wait_async()).poll(ctx)
    }
}
/// Future returned by [`WaitCtx::wait_async`]; borrows the [`WaitCtx`] rather
/// than consuming it.
///
/// `#[repr(transparent)]`: same layout as the wrapped reference.
#[repr(transparent)]
pub struct WaitFuture<'wctx, T, S, E>(&'wctx WaitCtx<T, S, E>);
impl<'wctx, T, S, E> Future for WaitFuture<'wctx, T, S, E> {
    type Output = Result<T, Error<S, E>>;

    /// Resolves as soon as the shared state yields data or an error.
    ///
    /// Otherwise the current task's waker is stored in the shared state and
    /// `Poll::Pending` is returned.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let shared = &self.0 .0;
        let mut guard = shared.inner.lock();
        // `transpose` maps Ok(Some(v)) -> Some(Ok(v)), Err(e) -> Some(Err(e)),
        // and Ok(None) -> None ("nothing yet").
        if let Some(outcome) = guard.try_get().transpose() {
            return Poll::Ready(outcome);
        }
        guard.waker = Some(ctx.waker().clone());
        Poll::Pending
    }
}