use std::cell::{Cell, UnsafeCell};
use std::io;
use std::rc::{Rc, Weak};
use std::time::{Duration, Instant};
use crate::clock::INFINITY;
use crate::error::Error;
use crate::fiber::Cond;
use crate::network::protocol::codec::Consumer;
use crate::tuple::Decode;
use crate::Result;
use super::inner::ConnInner;
/// The standard two-parameter `Result`, spelled out because `Result` in this module is the
/// crate's single-parameter alias.
type StdResult<T, E> = ::std::result::Result<T, E>;
/// Owning handle to a result that will be delivered later through a shared [`InnerPromise`].
///
/// The result can be polled with [`Promise::state`] / [`Promise::try_get`] or waited for with
/// [`Promise::wait`] / [`Promise::wait_timeout`].
pub struct Promise<T> {
    /// State shared with the weak handles produced by `Promise::downgrade`.
    inner: Rc<InnerPromise<T>>,
}
impl<T> Promise<T> {
    /// Creates a promise that has not received a result yet.
    ///
    /// `conn` is held weakly and consulted to tell "still waiting" apart from
    /// "the connection is gone".
    #[inline]
    pub(crate) fn new(conn: Weak<ConnInner>) -> Self {
        Self {
            inner: Rc::new(InnerPromise {
                conn,
                cond: UnsafeCell::default(),
                data: Cell::new(None),
            }),
        }
    }

    /// Returns a weak handle to the shared state through which the result is delivered.
    #[inline]
    pub(crate) fn downgrade(&self) -> Weak<InnerPromise<T>> {
        Rc::downgrade(&self.inner)
    }

    /// Returns `true` if the connection still exists and reports itself as connected.
    #[inline]
    fn is_connected(&self) -> bool {
        self.inner
            .conn
            .upgrade()
            .map(|c| c.is_connected())
            .unwrap_or(false)
    }

    /// Returns `Ok(())` when connected, otherwise an `io::ErrorKind::NotConnected` error.
    #[inline]
    fn check_connection(&self) -> Result<()> {
        if self.is_connected() {
            Ok(())
        } else {
            Err(io::Error::from(io::ErrorKind::NotConnected).into())
        }
    }

    /// Reports the promise's current state without consuming it or blocking.
    ///
    /// A received result (success or error) takes precedence over the connection status.
    #[inline]
    pub fn state(&self) -> State {
        // `Cell` only lets us move the value out, so take it, inspect it and put it back.
        if let Some(res) = self.inner.data.take() {
            let is_ok = res.is_ok();
            self.inner.data.set(Some(res));
            if is_ok {
                State::Kept
            } else {
                State::ReceivedError
            }
        } else if self.is_connected() {
            State::Pending
        } else {
            State::Disconnected
        }
    }

    /// Takes the result if one has arrived, without blocking.
    ///
    /// Returns `TryGet::Ok`/`TryGet::Err` for a received result. Returns `TryGet::Err` with a
    /// `NotConnected` error if nothing arrived and the connection is gone. Otherwise returns
    /// `TryGet::Pending(self)` so the caller can retry later.
    #[inline]
    pub fn try_get(self) -> TryGet<T, Error> {
        // Check the data first so no connection error is built when a result is present.
        if let Some(res) = self.inner.data.take() {
            return res.into();
        }
        match self.check_connection() {
            Ok(()) => TryGet::Pending(self),
            Err(e) => TryGet::Err(e),
        }
    }

    /// Waits until a result arrives or the connection is lost.
    ///
    /// # Errors
    /// Returns the received error, or a `NotConnected` error if the connection is gone before
    /// a result arrives.
    ///
    /// # Panics
    /// Panics only if the wait with `INFINITY` somehow times out.
    #[inline]
    pub fn wait(self) -> Result<T> {
        match self.wait_timeout(INFINITY) {
            TryGet::Ok(v) => Ok(v),
            TryGet::Err(e) => Err(e),
            TryGet::Pending(_) => unreachable!("100 years have passed, wake up"),
        }
    }

    /// Waits at most `timeout` for a result.
    ///
    /// The outcome depends on what happens first:
    /// - A result arrives: returns `TryGet::Ok`/`TryGet::Err` with it.
    /// - The connection is (or becomes) unavailable: returns `TryGet::Err` with a
    ///   `NotConnected` error.
    /// - The timeout expires: returns `TryGet::Pending(self)`.
    ///
    /// A zero timeout still performs one (zero-length) wait on the condition variable before
    /// giving up.
    pub fn wait_timeout(self, mut timeout: Duration) -> TryGet<T, Error> {
        if let Some(res) = self.inner.data.take() {
            return res.into();
        }
        loop {
            if let Err(e) = self.check_connection() {
                break TryGet::Err(e);
            }
            let wait_started = Instant::now();
            // SAFETY: `cond` is only overwritten by `replace_cond`, which requires
            // `&mut Promise`. This method owns `self`, so no such call can happen while this
            // shared borrow is alive. Other accesses (`InnerPromise::signal`) are shared borrows.
            unsafe { &*self.inner.cond.get() }.wait_timeout(timeout);
            if let Some(res) = self.inner.data.take() {
                break res.into();
            }
            timeout = timeout.saturating_sub(wait_started.elapsed());
            if timeout.is_zero() {
                // Out of time. Re-check the connection so that a promise whose connection
                // dropped during the last wait is reported as an error, not as pending
                // (consistent with `try_get`).
                break match self.check_connection() {
                    Ok(()) => TryGet::Pending(self),
                    Err(e) => TryGet::Err(e),
                };
            }
        }
    }

    /// Replaces the condition variable that is signalled when this promise's result arrives
    /// or its connection drops, and returns the previous one.
    ///
    /// NOTE(review): presumably this lets several promises share one `Cond` so a single fiber
    /// can wait on all of them; confirm against callers.
    pub fn replace_cond(&mut self, cond: Rc<Cond>) -> Rc<Cond> {
        // SAFETY: `&mut self` rules out any other `Promise` borrow of `cond`. The only other
        // access, `InnerPromise::signal`, takes a short-lived shared borrow. `Rc` confines this
        // state to one thread, so that borrow cannot overlap this write (assuming
        // `Cond::signal` does not yield).
        let slot = unsafe { &mut *self.inner.cond.get() };
        std::mem::replace(slot, cond)
    }
}
/// Snapshot of a promise's progress, as reported by `Promise::state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum State {
    /// A successful result has arrived and has not been taken yet.
    Kept,
    /// An error result has arrived and has not been taken yet.
    ReceivedError,
    /// No result yet, and the connection is still up.
    Pending,
    /// No result, and the connection is gone (or no longer exists).
    Disconnected,
}
/// Outcome of a non-blocking or time-limited attempt to obtain a promise's value.
#[derive(Debug)]
pub enum TryGet<T, E> {
    /// The value is available.
    Ok(T),
    /// The attempt failed. As produced by `Promise`, this includes a lost connection.
    Err(E),
    /// Nothing yet; the promise is handed back so the caller can try again.
    Pending(Promise<T>),
}
impl<T, E> TryGet<T, E> {
    /// Keeps only a successful value; an error or a pending promise becomes `None`.
    pub fn ok(self) -> Option<T> {
        if let Self::Ok(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Keeps only an error; a value or a pending promise becomes `None`.
    pub fn err(self) -> Option<E> {
        if let Self::Err(error) = self {
            Some(error)
        } else {
            None
        }
    }

    /// Keeps only the still-pending promise; a value or an error becomes `None`.
    pub fn pending(self) -> Option<Promise<T>> {
        if let Self::Pending(promise) = self {
            Some(promise)
        } else {
            None
        }
    }

    /// Splits into "finished" and "not finished yet".
    ///
    /// A finished outcome becomes `Ok` holding the final `Ok`/`Err`. A pending one becomes
    /// `Err` holding the promise to retry with.
    #[inline(always)]
    pub fn into_res(self) -> StdResult<StdResult<T, E>, Promise<T>> {
        match self {
            Self::Pending(promise) => Err(promise),
            Self::Err(error) => Ok(Err(error)),
            Self::Ok(value) => Ok(Ok(value)),
        }
    }
}
impl<T, E> From<StdResult<T, E>> for TryGet<T, E> {
    /// Maps `Ok` to `TryGet::Ok` and `Err` to `TryGet::Err`; never produces `Pending`.
    fn from(result: StdResult<T, E>) -> Self {
        result.map_or_else(Self::Err, Self::Ok)
    }
}
impl<T, E> From<TryGet<T, E>> for StdResult<StdResult<T, E>, Promise<T>> {
#[inline(always)]
fn from(r: TryGet<T, E>) -> Self {
r.into_res()
}
}
impl<T> std::fmt::Debug for Promise<T> {
    /// Formats as `Promise { state: <State>, .. }`; the payload itself is never printed, so
    /// `T` needs no `Debug` bound.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let current = self.state();
        let mut out = f.debug_struct("Promise");
        out.field("state", &current);
        out.finish_non_exhaustive()
    }
}
/// State shared between a [`Promise`] and the code that delivers its result through the
/// `Consumer` implementation.
pub struct InnerPromise<T> {
    /// Connection the promise depends on; held weakly, so it does not keep the connection alive.
    conn: Weak<ConnInner>,
    /// Signalled whenever a result is stored or a disconnect is reported; replaceable through
    /// `Promise::replace_cond`.
    cond: UnsafeCell<Rc<Cond>>,
    /// The delivered result, if any; moved out by whoever consumes the promise.
    data: Cell<Option<Result<T>>>,
}
impl<T> InnerPromise<T> {
    /// Wakes whoever is waiting on this promise's condition variable.
    fn signal(&self) {
        // SAFETY: the only write to `cond` is `Promise::replace_cond`, which needs
        // `&mut Promise`. `Rc` keeps this state on one thread, so this brief shared borrow
        // cannot overlap that write.
        let cond = unsafe { &*self.cond.get() };
        cond.signal();
    }
}
impl<T> Consumer for InnerPromise<T>
where
    T: for<'de> Decode<'de>,
{
    /// Records `error` as the outcome, replacing any earlier one, and wakes the waiter.
    fn handle_error(&self, error: Error) {
        let outcome = Err(error);
        self.data.set(Some(outcome));
        self.signal();
    }

    /// Only wakes the waiter. No outcome is stored; the waiting side detects the lost
    /// connection by re-checking it after waking.
    fn handle_disconnect(&self) {
        self.signal();
    }

    /// Decodes `data` into `T` and stores the decode result (success or failure) as the
    /// outcome, then wakes the waiter.
    fn consume_data(&self, data: &[u8]) {
        let decoded = T::decode(data);
        self.data.set(Some(decoded));
        self.signal();
    }
}