use futures::Future;
use pin_project::pin_project;
use std::{
any::TypeId,
convert::{TryFrom, TryInto},
marker::{self, PhantomData},
mem,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};
use crate::tuple::{HasLength, List, Tuple};
use crate::Unavailable;
use crate::{backend::*, IncompleteHalf, SessionIncomplete};
use crate::{prelude::*, types::*, unary::*};
/// A channel whose remaining protocol is tracked at the type level by the session type `S`.
///
/// It owns a transmitting half `Tx` and a receiving half `Rx`. Each protocol step (for example
/// `recv`, `send` or `choose`) consumes the channel and returns a new `Chan` typed with the
/// continuation session.
///
/// When a `Chan` is dropped while it still holds its halves, its `Drop` impl deposits them
/// into the shared `drop_tx` / `drop_rx` slots. Whoever holds the other end of those slots
/// can then reclaim them.
#[derive(Derivative)]
#[derivative(Debug)]
#[repr(C)]
#[must_use]
pub struct Chan<S, Tx, Rx>
where
    S: Session,
    Tx: marker::Send + 'static,
    Rx: marker::Send + 'static,
{
    /// Transmitting half; `None` only after it has been moved out internally.
    tx: Option<Tx>,
    /// Receiving half; `None` only after it has been moved out internally.
    rx: Option<Rx>,
    /// Shared slot the transmitting half is written into when this channel drops.
    drop_tx: Arc<Mutex<Result<Tx, IncompleteHalf<Tx>>>>,
    /// Shared slot the receiving half is written into when this channel drops.
    drop_rx: Arc<Mutex<Result<Rx, IncompleteHalf<Rx>>>>,
    /// Zero-sized marker carrying the session type.
    session: PhantomData<fn() -> S>,
}
impl<Tx, Rx, S> Drop for Chan<S, Tx, Rx>
where
    Tx: marker::Send + 'static,
    Rx: marker::Send + 'static,
    S: Session,
{
    /// Hand any transport halves still owned by this channel back through the drop slots.
    ///
    /// A half is stored as `Ok(half)` when the current session action is `Done`. Otherwise it
    /// is stored as `Err(IncompleteHalf::Unfinished(half))`. Halves that were already moved
    /// out (the fields are `None`) leave their slot untouched.
    fn drop(&mut self) {
        let finished = TypeId::of::<<S as Session>::Action>() == TypeId::of::<Done>();

        if let Some(sender) = self.tx.take() {
            let reclaimed = if finished {
                Ok(sender)
            } else {
                Err(IncompleteHalf::Unfinished(sender))
            };
            *self.drop_tx.lock().unwrap() = reclaimed;
        }

        if let Some(receiver) = self.rx.take() {
            let reclaimed = if finished {
                Ok(receiver)
            } else {
                Err(IncompleteHalf::Unfinished(receiver))
            };
            *self.drop_rx.lock().unwrap() = reclaimed;
        }
    }
}
impl<Tx, Rx, S> Chan<S, Tx, Rx>
where
S: Session,
Tx: marker::Send + 'static,
Rx: marker::Send + 'static,
{
pub fn close(self)
where
S: Session<Action = Done>,
{
drop(self)
}
pub async fn recv<T, P>(mut self) -> Result<(T, Chan<P, Tx, Rx>), Rx::Error>
where
S: Session<Action = Recv<T, P>>,
P: Session,
Rx: Receive<T>,
{
let result = self.rx.as_mut().unwrap().recv().await?;
Ok((result, self.unchecked_cast()))
}
pub async fn send<T, P>(mut self, message: T) -> Result<Chan<P, Tx, Rx>, Tx::Error>
where
S: Session<Action = Send<T, P>>,
P: Session,
Tx: Transmit<T>,
T: marker::Send,
{
self.tx.as_mut().unwrap().send(message).await?;
Ok(self.unchecked_cast())
}
pub async fn send_ref<T, P>(mut self, message: &T) -> Result<Chan<P, Tx, Rx>, Tx::Error>
where
S: Session<Action = Send<T, P>>,
P: Session,
Tx: Transmit<T, Ref>,
T: marker::Send,
{
self.tx.as_mut().unwrap().send(message).await?;
Ok(self.unchecked_cast())
}
pub async fn send_mut<T, P>(mut self, message: &mut T) -> Result<Chan<P, Tx, Rx>, Tx::Error>
where
S: Session<Action = Send<T, P>>,
P: Session,
Tx: Transmit<T, Mut>,
T: marker::Send,
{
self.tx.as_mut().unwrap().send(message).await?;
Ok(self.unchecked_cast())
}
}
impl<Tx, Rx, S, Choices, const LENGTH: usize> Chan<S, Tx, Rx>
where
    S: Session<Action = Choose<Choices>>,
    Choices: Tuple,
    Choices::AsList: HasLength,
    <Choices::AsList as HasLength>::Length: ToConstant<AsConstant = Number<LENGTH>>,
    Tx: Transmitter + marker::Send + 'static,
    Rx: marker::Send + 'static,
{
    /// Select branch `N` of a `Choose<Choices>` step and continue with that branch's session.
    ///
    /// The index `N` is sent as a `Choice<LENGTH>` over the transmitting half.
    ///
    /// # Panics
    ///
    /// Panics if `N` does not fit in a `u8`, or if it cannot be converted into a
    /// `Choice<LENGTH>`.
    ///
    /// # Errors
    ///
    /// Propagates the transmitting half's error from `send_choice`.
    pub async fn choose<const N: usize>(
        mut self,
    ) -> Result<
        Chan<<Choices::AsList as Select<<Number<N> as ToUnary>::AsUnary>>::Selected, Tx, Rx>,
        Tx::Error,
    >
    where
        Number<N>: ToUnary,
        Choices::AsList: Select<<Number<N> as ToUnary>::AsUnary>,
        <Choices::AsList as Select<<Number<N> as ToUnary>::AsUnary>>::Selected: Session,
    {
        let index = u8::try_from(N).expect("choices must fit into a byte");
        let choice: Choice<LENGTH> = index
            .try_into()
            .expect("type system prevents out of range choice in `choose`");
        let sender = self.tx.as_mut().unwrap();
        sender.send_choice(choice).await?;
        let next = self.unchecked_cast();
        Ok(next)
    }
}
impl<Tx, Rx, S, Choices, const LENGTH: usize> Chan<S, Tx, Rx>
where
S: Session<Action = Offer<Choices>>,
Choices: Tuple + 'static,
Choices::AsList: HasLength + EachScoped + EachHasDual,
<Choices::AsList as HasLength>::Length: ToConstant<AsConstant = Number<LENGTH>>,
Z: LessThan<<Choices::AsList as HasLength>::Length>,
Tx: marker::Send + 'static,
Rx: Receiver + marker::Send + 'static,
{
pub async fn offer(self) -> Result<Branches<Choices, Tx, Rx>, Rx::Error> {
let (tx, mut rx, drop_tx, drop_rx) = self.unwrap_contents();
let variant = rx.as_mut().unwrap().recv_choice::<LENGTH>().await?.into();
Ok(Branches {
variant,
tx,
rx,
drop_tx,
drop_rx,
protocols: PhantomData,
})
}
}
impl<Tx, Rx, S> Chan<S, Tx, Rx>
where
S: Session,
Tx: marker::Send + 'static,
Rx: marker::Send + 'static,
{
pub async fn call<T, E, P, Q, F, Fut>(
self,
first: F,
) -> Result<(T, Result<Chan<Q, Tx, Rx>, SessionIncomplete<Tx, Rx>>), E>
where
S: Session<Action = Call<P, Q>>,
P: Session,
Q: Session,
F: FnOnce(Chan<P, Tx, Rx>) -> Fut,
Fut: Future<Output = Result<T, E>>,
{
let (tx, rx, drop_tx, drop_rx) = self.unwrap_contents();
let (result, chan_result) = P::over(tx.unwrap(), rx.unwrap(), first).await;
Ok((
result?,
chan_result.map(|(tx, rx)| Chan {
tx: Some(tx),
rx: Some(rx),
drop_tx,
drop_rx,
session: PhantomData,
}),
))
}
pub async fn split<T, E, P, Q, R, F, Fut>(
self,
with_parts: F,
) -> Result<(T, Result<Chan<R, Tx, Rx>, SessionIncomplete<Tx, Rx>>), E>
where
S: Session<Action = Split<P, Q, R>>,
P: Session,
Q: Session,
R: Session,
F: FnOnce(Chan<P, Tx, Unavailable>, Chan<Q, Unavailable, Rx>) -> Fut,
Fut: Future<Output = Result<T, E>>,
{
use IncompleteHalf::*;
use SessionIncomplete::*;
let (tx, rx, drop_tx, drop_rx) = self.unwrap_contents();
let ((result, maybe_rx), maybe_tx) =
P::over(tx.unwrap(), Unavailable::default(), |tx_only| async move {
Q::over(Unavailable::default(), rx.unwrap(), |rx_only| async move {
with_parts(tx_only, rx_only).await
})
.await
})
.await;
let maybe_tx_rx: Result<(Tx, Rx), SessionIncomplete<Tx, Rx>> = match (
maybe_tx
.map(|(tx, _)| Ok(tx))
.unwrap_or_else(|incomplete| incomplete.into_halves().0),
maybe_rx
.map(|(_, rx)| Ok(rx))
.unwrap_or_else(|incomplete| incomplete.into_halves().1),
) {
(Ok(tx), Ok(rx)) => Ok((tx, rx)),
(Ok(tx), Err(Unclosed)) => Err(RxHalf { tx, rx: Unclosed }),
(Err(Unclosed), Ok(rx)) => Err(TxHalf { tx: Unclosed, rx }),
(Ok(tx), Err(Unfinished(rx))) => Err(RxHalf {
tx,
rx: Unfinished(rx),
}),
(Err(Unfinished(tx)), Ok(rx)) => Err(TxHalf {
tx: Unfinished(tx),
rx,
}),
(Err(Unfinished(tx)), Err(Unclosed)) => Err(BothHalves {
tx: Unfinished(tx),
rx: Unclosed,
}),
(Err(Unclosed), Err(Unfinished(rx))) => Err(BothHalves {
tx: Unclosed,
rx: Unfinished(rx),
}),
(Err(Unclosed), Err(Unclosed)) => Err(BothHalves {
tx: Unclosed,
rx: Unclosed,
}),
(Err(Unfinished(tx)), Err(Unfinished(rx))) => Err(BothHalves {
tx: Unfinished(tx),
rx: Unfinished(rx),
}),
};
Ok((
result?,
maybe_tx_rx.map(|(tx, rx)| Chan {
tx: Some(tx),
rx: Some(rx),
drop_tx,
drop_rx,
session: PhantomData,
}),
))
}
pub fn into_inner(self) -> (Tx, Rx) {
let (tx, rx, _, _) = self.unwrap_contents();
(tx.unwrap(), rx.unwrap())
}
fn unwrap_contents(
mut self,
) -> (
Option<Tx>,
Option<Rx>,
Arc<Mutex<Result<Tx, IncompleteHalf<Tx>>>>,
Arc<Mutex<Result<Rx, IncompleteHalf<Rx>>>>,
) {
let tx = self.tx.take();
let rx = self.rx.take();
let drop_tx = self.drop_tx.clone();
let drop_rx = self.drop_rx.clone();
(tx, rx, drop_tx, drop_rx)
}
fn unchecked_cast<Q>(mut self) -> Chan<Q, Tx, Rx>
where
Q: Session,
{
let new: *mut Chan<Q, _, _> = (&mut self as *mut Chan<_, _, _>).cast();
mem::forget(self);
unsafe { new.read() }
}
pub(crate) fn from_raw_unchecked(tx: Tx, rx: Rx) -> Chan<S, Tx, Rx> {
Chan {
tx: Some(tx),
rx: Some(rx),
drop_tx: Arc::new(Mutex::new(Err(IncompleteHalf::Unclosed))),
drop_rx: Arc::new(Mutex::new(Err(IncompleteHalf::Unclosed))),
session: PhantomData,
}
}
}
/// Start running `with_chan` on a fresh `Chan<P, Tx, Rx>` built from `tx` and `rx`.
///
/// The returned `Over` future shares the channel's two drop slots; both start as
/// `Err(IncompleteHalf::Unclosed)`. When the inner future completes, `Over` reports
/// whatever the channel deposited into those slots.
pub(crate) fn over<P, Tx, Rx, T, F, Fut>(tx: Tx, rx: Rx, with_chan: F) -> Over<Tx, Rx, T, Fut>
where
    P: Session,
    Tx: std::marker::Send + 'static,
    Rx: std::marker::Send + 'static,
    F: FnOnce(Chan<P, Tx, Rx>) -> Fut,
    Fut: Future<Output = T>,
{
    let reclaimed_tx = Arc::new(Mutex::new(Err(IncompleteHalf::Unclosed)));
    let reclaimed_rx = Arc::new(Mutex::new(Err(IncompleteHalf::Unclosed)));

    let chan: Chan<P, Tx, Rx> = Chan {
        tx: Some(tx),
        rx: Some(rx),
        drop_tx: Arc::clone(&reclaimed_tx),
        drop_rx: Arc::clone(&reclaimed_rx),
        session: PhantomData,
    };
    let future = with_chan(chan);

    Over {
        reclaimed_tx,
        reclaimed_rx,
        future,
    }
}
impl<Tx, Rx, T, Fut> Future for Over<Tx, Rx, T, Fut>
where
    Fut: Future<Output = T>,
{
    type Output = (T, Result<(Tx, Rx), SessionIncomplete<Tx, Rx>>);

    /// Poll the inner future. When it completes, drain both drop slots and pair its output
    /// with the reclaimed halves.
    ///
    /// Each slot is reset to `Err(IncompleteHalf::Unclosed)` as it is drained. Any half
    /// still missing is reported through the matching `SessionIncomplete` variant.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        use IncompleteHalf::*;
        use SessionIncomplete::*;
        let this = self.project();

        // Only touch the shared slots once the inner future is done. Using the projected
        // `&mut Arc` fields directly avoids cloning both `Arc`s on every (possibly `Pending`)
        // poll.
        let result = match this.future.poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(result) => result,
        };

        let tx = mem::replace(&mut *this.reclaimed_tx.lock().unwrap(), Err(Unclosed));
        let rx = mem::replace(&mut *this.reclaimed_rx.lock().unwrap(), Err(Unclosed));
        let chan = match (tx, rx) {
            (Ok(tx), Ok(rx)) => Ok((tx, rx)),
            (Err(tx), Ok(rx)) => Err(TxHalf { tx, rx }),
            (Ok(tx), Err(rx)) => Err(RxHalf { tx, rx }),
            (Err(tx), Err(rx)) => Err(BothHalves { tx, rx }),
        };
        Poll::Ready((result, chan))
    }
}
/// Future returned by `over`: drives a session future, then reclaims the channel halves.
///
/// The two slots are shared with the `Chan` handed to the session future. That channel's
/// `Drop` impl writes its halves into them, and `Over` drains them once `future` completes.
#[pin_project]
#[derive(Debug)]
pub struct Over<Tx, Rx, T, Fut: Future<Output = T>> {
    /// Slot receiving the transmitting half.
    reclaimed_tx: Arc<Mutex<Result<Tx, IncompleteHalf<Tx>>>>,
    /// Slot receiving the receiving half.
    reclaimed_rx: Arc<Mutex<Result<Rx, IncompleteHalf<Rx>>>>,
    /// The session future being driven; structurally pinned.
    #[pin]
    future: Fut,
}
/// The result of `Chan::offer`: the peer's selected branch plus both transport halves.
///
/// Use `case` to test for a specific branch index. Each failed test yields a smaller
/// `Branches` with that alternative removed.
#[derive(Derivative)]
#[derivative(Debug)]
#[must_use]
pub struct Branches<Choices, Tx: marker::Send + 'static, Rx: marker::Send + 'static>
where
    Choices: Tuple + 'static,
    Choices::AsList: EachScoped + EachHasDual + HasLength,
{
    /// Index of the selected branch within the current `Choices`.
    variant: u8,
    /// Transmitting half; moved out when a `case` call produces a new value.
    tx: Option<Tx>,
    /// Receiving half; moved out when a `case` call produces a new value.
    rx: Option<Rx>,
    /// Shared slot the transmitting half is written into on drop.
    drop_tx: Arc<Mutex<Result<Tx, IncompleteHalf<Tx>>>>,
    /// Shared slot the receiving half is written into on drop.
    drop_rx: Arc<Mutex<Result<Rx, IncompleteHalf<Rx>>>>,
    /// Zero-sized marker carrying the remaining choice list.
    protocols: PhantomData<fn() -> Choices>,
}
impl<Tx, Rx, Choices> Drop for Branches<Choices, Tx, Rx>
where
    Tx: marker::Send + 'static,
    Rx: marker::Send + 'static,
    Choices: Tuple + 'static,
    Choices::AsList: EachScoped + EachHasDual + HasLength,
{
    /// Deposit any halves still held as `IncompleteHalf::Unfinished`.
    ///
    /// An undispatched `Branches` never counts as a finished session.
    fn drop(&mut self) {
        let Self {
            tx,
            rx,
            drop_tx,
            drop_rx,
            ..
        } = self;

        if let Some(sender) = tx.take() {
            *drop_tx.lock().unwrap() = Err(IncompleteHalf::Unfinished(sender));
        }
        if let Some(receiver) = rx.take() {
            *drop_rx.lock().unwrap() = Err(IncompleteHalf::Unfinished(receiver));
        }
    }
}
impl<Tx, Rx, Choices, const LENGTH: usize> Branches<Choices, Tx, Rx>
where
Choices: Tuple + 'static,
Choices::AsList: EachScoped + EachHasDual + HasLength,
<Choices::AsList as HasLength>::Length: ToConstant<AsConstant = Number<LENGTH>>,
Tx: marker::Send + 'static,
Rx: marker::Send + 'static,
{
pub fn case<const N: usize>(
mut self,
) -> Result<
Chan<<Choices::AsList as Select<<Number<N> as ToUnary>::AsUnary>>::Selected, Tx, Rx>,
Branches<<<Choices::AsList as Select<<Number<N> as ToUnary>::AsUnary>>::Remainder as List>::AsTuple, Tx, Rx>,
>
where
Number<N>: ToUnary,
Choices::AsList: Select<<Number<N> as ToUnary>::AsUnary>,
<Choices::AsList as Select<<Number<N> as ToUnary>::AsUnary>>::Selected: Session,
<Choices::AsList as Select<<Number<N> as ToUnary>::AsUnary>>::Remainder: EachScoped + EachHasDual + HasLength + List,
{
let variant = self.variant;
let tx = self.tx.take();
let rx = self.rx.take();
let drop_tx = self.drop_tx.clone();
let drop_rx = self.drop_rx.clone();
let branch: u8 = N
.try_into()
.expect("branch discriminant exceeded u8::MAX in `case`");
if variant == branch {
Ok(Chan {
tx,
rx,
drop_tx,
drop_rx,
session: PhantomData,
})
} else {
Err(Branches {
variant: if variant > branch {
variant - 1
} else {
variant
},
tx,
rx,
drop_tx,
drop_rx,
protocols: PhantomData,
})
}
}
pub fn choice(&self) -> Choice<LENGTH> {
self.variant
.try_into()
.expect("internal variant for `Branches` exceeds number of choices")
}
}
impl<Tx, Rx> Branches<(), Tx, Rx>
where
    Tx: marker::Send + 'static,
    Rx: marker::Send + 'static,
{
    /// Eliminate a `Branches` with no remaining alternatives, producing any type `T`.
    ///
    /// This lets a chain of `case` calls be finished off in a type that unifies with every
    /// arm.
    ///
    /// # Panics
    ///
    /// Always panics if actually called. The panic message states that an empty `Branches`
    /// cannot be constructed.
    pub fn empty_case<T>(self) -> T {
        unreachable!("empty `Branches` cannot be constructed")
    }
}