async_session_types/
lib.rs

1// Enabled for the `repr_bound!` macro.
2#![feature(trait_alias)]
3#![feature(async_closure)]
4
5use std::{error::Error, marker, marker::PhantomData, mem::ManuallyDrop, thread, time::Duration};
6use tokio::time::timeout as timeout_after;
7
// Type aliases so we don't have to update these in so many places if we switch the implementation.

/// Receiving half of the underlying transport (currently tokio's unbounded MPSC receiver).
pub type Receiver<T> = tokio::sync::mpsc::UnboundedReceiver<T>;
/// Sending half of the underlying transport (currently tokio's unbounded MPSC sender).
pub type Sender<T> = tokio::sync::mpsc::UnboundedSender<T>;
12
13fn unbounded_channel<T>() -> (Sender<T>, Receiver<T>) {
14    tokio::sync::mpsc::unbounded_channel()
15}
16
/// Multiplexing multiple protocols over a single channel and dispatching to session handler instances.
///
/// Only compiled when the `mux` cargo feature is enabled.
#[cfg(feature = "mux")]
pub mod multiplexing;

// Message representations; `DynMessage` and `Repr` are re-exported as part of the public API.
mod repr;

pub use repr::{DynMessage, Repr};
24
25#[derive(Debug)]
26pub enum SessionError {
27    /// Wrong message type was sent.
28    UnexpectedMessage(DynMessage),
29    /// The other end of the channel is closed.
30    Disconnected,
31    /// Did not receive a message within the timeout.
32    Timeout,
33    /// Abort due to the a violation of some protocol constraints.
34    Abort(Box<dyn Error + marker::Send + 'static>),
35}
36
37impl std::fmt::Display for SessionError {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        write!(f, "{:?}", self)
40    }
41}
42
43impl Error for SessionError {}
44
/// Result type used throughout the session API, with `SessionError` as the error.
pub type SessionResult<T> = Result<T, SessionError>;
46
47pub fn ok<T>(value: T) -> SessionResult<T> {
48    Ok(value)
49}
50
51fn downcast<T, R: Repr<T>>(msg: R) -> SessionResult<T> {
52    msg.try_into()
53        .map_err(|msg| SessionError::UnexpectedMessage(Box::new(msg)))
54}
55
/// A session typed channel. `P` is the protocol and `E` is the environment,
/// containing potential recursion targets. `R` is the representation of
/// messages, which could be `DynMessage`, or perhaps something we know
/// statically how to turn into JSON or bytes.
///
/// The fields are wrapped in `ManuallyDrop` so that `cast` can move them into a
/// channel with a different protocol type, and `close_chan` can release them,
/// without running the panicking `Drop` impl.
pub struct Chan<P, E, R> {
    /// Outgoing half: messages we send to the peer.
    tx: ManuallyDrop<Sender<R>>,
    /// Incoming half: messages the peer sends to us.
    rx: ManuallyDrop<Receiver<R>>,
    /// A message already pulled from `rx` (by `offer`) but not yet consumed;
    /// the next read returns it before touching `rx`.
    stash: ManuallyDrop<Option<R>>,
    /// Carries the protocol and environment types; holds no runtime data.
    _phantom: PhantomData<(P, E)>,
}
66
67impl<P, E, R> Chan<P, E, R> {
68    fn new(tx: Sender<R>, rx: Receiver<R>) -> Chan<P, E, R> {
69        Chan {
70            tx: ManuallyDrop::new(tx),
71            rx: ManuallyDrop::new(rx),
72            stash: ManuallyDrop::new(None),
73            _phantom: PhantomData,
74        }
75    }
76}
77
78fn write_chan<T, P, E, R: Repr<T>>(chan: &Chan<P, E, R>, v: T) -> SessionResult<()> {
79    chan.tx
80        .send(Repr::from(v))
81        .map_err(|_| SessionError::Disconnected)
82}
83
84/// Read a message then cast it to the expected type.
85async fn read_chan<T, P, E, R: Repr<T>>(
86    chan: &mut Chan<P, E, R>,
87    timeout: Duration,
88) -> SessionResult<T> {
89    let msg = read_chan_dyn(chan, timeout).await?;
90    downcast(msg)
91}
92
93/// Try to read a dynamically typed message from the stash of from the channel.
94async fn read_chan_dyn<P, E, R>(chan: &mut Chan<P, E, R>, timeout: Duration) -> SessionResult<R> {
95    match chan.stash.take() {
96        Some(msg) => Ok(msg),
97        None if timeout == Duration::MAX => match chan.rx.recv().await {
98            Some(msg) => Ok(msg),
99            None => Err(SessionError::Disconnected),
100        },
101        None => match timeout_after(timeout, chan.rx.recv()).await {
102            Ok(Some(msg)) => Ok(msg),
103            Ok(None) => Err(SessionError::Disconnected),
104            Err(_) => Err(SessionError::Timeout),
105        },
106    }
107}
108
/// Close the channel.
///
/// Releases the sender, the receiver and any stashed message without running
/// `Chan`'s `Drop` impl, which would otherwise panic. Used by `close`, `abort`,
/// `abort_dyn` and the error paths of `send`, `recv` and `offer`.
fn close_chan<P, E, R>(chan: Chan<P, E, R>) {
    // This method cleans up the channel without running the panicky destructor.
    // In essence, it calls the drop glue bypassing the `Drop::drop` method.
    let mut this = ManuallyDrop::new(chan);
    // SAFETY: `this` is wrapped in `ManuallyDrop`, so neither `Chan::drop` nor the
    // automatic field drop glue runs afterwards. Each field is dropped exactly once
    // here and never accessed again; `_phantom` is a zero-sized `PhantomData`.
    unsafe {
        ManuallyDrop::drop(&mut this.tx);
        ManuallyDrop::drop(&mut this.rx);
        ManuallyDrop::drop(&mut this.stash);
    }
}
120
/// Peano numbers: Zero.
///
/// Used as the index of `Var<Z>`, which recurses to the innermost enclosing `Rec`.
pub struct Z;

/// Peano numbers: Increment (the successor of `N`).
///
/// `Var<S<N>>` refers to a recursion frame one level further out than `Var<N>`;
/// `Chan::succ` peels off one such level.
pub struct S<N>(PhantomData<N>);
126
/// End of communication session (epsilon). A channel in this state can only be
/// closed with `Chan::close`.
pub struct Eps;

/// Receive `T`, then resume with protocol `P`.
pub struct Recv<T, P>(PhantomData<(T, P)>);

/// Send `T`, then resume with protocol `P`.
///
/// Note: within this module this name shadows the prelude's `Send` auto trait,
/// which is why the auto trait is spelled `marker::Send` elsewhere in the file.
pub struct Send<T, P>(PhantomData<(T, P)>);

/// Active choice between `P` and `Q`.
///
/// Both continuations must be `Outgoing`. Selecting a branch sends nothing by
/// itself; the peer learns the choice from the first message the branch sends.
pub struct Choose<P: Outgoing, Q: Outgoing>(PhantomData<(P, Q)>);

/// Passive choice (offer) between `P` and `Q`.
///
/// Both continuations must be `Incoming`. `Chan::offer` continues with `P` if
/// the next message converts to `P::Expected`, and otherwise with `Q`, so the two
/// arms should expect distinguishable message types.
pub struct Offer<P: Incoming, Q: Incoming>(PhantomData<(P, Q)>);

/// Enter a recursive environment. `Chan::enter` pushes the body `P` onto the
/// environment stack so that a later `Var` can jump back to it.
pub struct Rec<P>(PhantomData<P>);

/// Recurse. N indicates how many layers of the recursive environment we recurse out of.
pub struct Var<N>(PhantomData<N>);
147
148/// Indicate that a protocol will receive a message, and specify what type it is,
149/// so we can decide in an offer which arm we got a message for.
150pub trait Incoming {
151    type Expected;
152}
153
154impl<T, P> Incoming for Recv<T, P> {
155    type Expected = T;
156}
157impl<P: Incoming, Q: Incoming> Incoming for Offer<P, Q> {
158    type Expected = P::Expected;
159}
160
161/// Indicate that a protocol will send a message.
162pub trait Outgoing {}
163
164impl<T, P> Outgoing for Send<T, P> {}
165impl<P: Outgoing, Q: Outgoing> Outgoing for Choose<P, Q> {}
166
167/// The HasDual trait defines the dual relationship between protocols.
168///
169/// Any valid protocol has a corresponding dual.
170pub trait HasDual {
171    type Dual;
172}
173
174impl HasDual for Eps {
175    type Dual = Eps;
176}
177
178impl<A, P: HasDual> HasDual for Send<A, P> {
179    type Dual = Recv<A, P::Dual>;
180}
181
182impl<A, P: HasDual> HasDual for Recv<A, P> {
183    type Dual = Send<A, P::Dual>;
184}
185
186impl<P: HasDual, Q: HasDual> HasDual for Choose<P, Q>
187where
188    P: Outgoing,
189    Q: Outgoing,
190    P::Dual: Incoming,
191    Q::Dual: Incoming,
192{
193    type Dual = Offer<P::Dual, Q::Dual>;
194}
195
196impl<P: HasDual, Q: HasDual> HasDual for Offer<P, Q>
197where
198    P: Incoming,
199    Q: Incoming,
200    P::Dual: Outgoing,
201    Q::Dual: Outgoing,
202{
203    type Dual = Choose<P::Dual, Q::Dual>;
204}
205
206impl HasDual for Var<Z> {
207    type Dual = Var<Z>;
208}
209
210impl<N> HasDual for Var<S<N>> {
211    type Dual = Var<S<N>>;
212}
213
214impl<P: HasDual> HasDual for Rec<P> {
215    type Dual = Rec<P::Dual>;
216}
217
/// Indicate whether the left or right choice was chosen in an `Offer`.
///
/// The standard traits are derived so the type can be debugged and compared;
/// each derive only applies when both `L` and `R` implement the trait.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Branch<L, R> {
    /// The offering side continues with the left protocol.
    Left(L),
    /// The offering side continues with the right protocol.
    Right(R),
}
223
224/// A sanity check destructor that kicks in if we abandon the channel by
225/// returning `Ok(_)` without closing it first.
226impl<P, E, R> Drop for Chan<P, E, R> {
227    fn drop(&mut self) {
228        if !thread::panicking() {
229            panic!("Session channel prematurely dropped. Must call `.close()`.");
230        }
231    }
232}
233
234impl<E, R> Chan<Eps, E, R> {
235    /// Close a channel. Should always be used at the end of your program.
236    pub fn close(self) -> SessionResult<()> {
237        close_chan(self);
238        Ok(())
239    }
240}
241
impl<P, E, R> Chan<P, E, R> {
    /// Reinterpret the channel under a new protocol `P2` and environment `E2`,
    /// moving the sender, receiver and stash over unchanged. Every protocol step
    /// uses this to advance the type-level state; no message is exchanged here.
    fn cast<P2, E2>(self) -> Chan<P2, E2, R> {
        // Wrapping `self` in `ManuallyDrop` suppresses the panicking `Drop` impl.
        let mut this = ManuallyDrop::new(self);
        // SAFETY: each field of `this` is taken exactly once and `this` is never
        // dropped or used again, so no value is duplicated or dropped twice.
        unsafe {
            Chan {
                tx: ManuallyDrop::new(ManuallyDrop::take(&mut this.tx)),
                rx: ManuallyDrop::new(ManuallyDrop::take(&mut this.rx)),
                stash: ManuallyDrop::new(ManuallyDrop::take(&mut this.stash)),
                _phantom: PhantomData,
            }
        }
    }

    /// Close the channel and return an error due to some business logic violation.
    ///
    /// Always returns `Err(SessionError::Abort(_))` wrapping `e`.
    pub fn abort<T, F: Error + marker::Send + 'static>(self, e: F) -> SessionResult<T> {
        close_chan(self);
        Err(SessionError::Abort(Box::new(e)))
    }

    /// Like `abort`, but for an error that is already boxed.
    ///
    /// Always returns `Err(SessionError::Abort(e))`.
    pub fn abort_dyn<T>(self, e: Box<dyn Error + marker::Send>) -> SessionResult<T> {
        close_chan(self);
        Err(SessionError::Abort(e))
    }
}
266
267impl<P, E, T, R: Repr<T>> Chan<Send<T, P>, E, R> {
268    /// Send a value of type `T` over the channel. Returns a channel with protocol `P`.
269    pub fn send(self, v: T) -> SessionResult<Chan<P, E, R>> {
270        match write_chan(&self, v) {
271            Ok(()) => Ok(self.cast()),
272            Err(e) => {
273                close_chan(self);
274                Err(e)
275            }
276        }
277    }
278}
279
280impl<P, E, T, R: Repr<T>> Chan<Recv<T, P>, E, R> {
281    /// Receives a value of type `T` from the channel. Returns a tuple
282    /// containing the resulting channel and the received value.
283    pub async fn recv(mut self, timeout: Duration) -> SessionResult<(Chan<P, E, R>, T)> {
284        match read_chan(&mut self, timeout).await {
285            Ok(v) => Ok((self.cast(), v)),
286            Err(e) => {
287                close_chan(self);
288                Err(e)
289            }
290        }
291    }
292}
293
294impl<P: Outgoing, Q: Outgoing, E, R> Chan<Choose<P, Q>, E, R> {
295    /// Perform an active choice, selecting protocol `P`.
296    /// We haven't sent any value yet, so the agency stays on our side.
297    pub fn sel1(self) -> Chan<P, E, R> {
298        self.cast()
299    }
300
301    /// Perform an active choice, selecting protocol `Q`.
302    /// We haven't sent any value yet, so the agency stays on our side.
303    pub fn sel2(self) -> Chan<Q, E, R> {
304        self.cast()
305    }
306}
307
/// Branches offered between protocols `P` and `Q`: `Left` continues with `P`,
/// `Right` with `Q`, and both share environment `E` and representation `R`.
type OfferBranch<P, Q, E, R> = Branch<Chan<P, E, R>, Chan<Q, E, R>>;
310
311impl<P: Incoming, Q: Incoming, E, R> Chan<Offer<P, Q>, E, R>
312where
313    P::Expected: 'static,
314    R: Repr<P::Expected>,
315{
316    /// Put the value we pulled from the channel back,
317    /// so the next protocol step can read it and use it.
318    fn stash(mut self, msg: R) -> Self {
319        self.stash = ManuallyDrop::new(Some(msg));
320        self
321    }
322
323    /// Passive choice. This allows the other end of the channel to select one
324    /// of two options for continuing the protocol: either `P` or `Q`.
325    /// Both options mean they will have to send a message to us,
326    /// the agency is on their side.
327    pub async fn offer(mut self, t: Duration) -> SessionResult<OfferBranch<P, Q, E, R>> {
328        // The next message we read from the channel decides
329        // which protocol we go with.
330        let msg = match read_chan_dyn(&mut self, t).await {
331            Ok(msg) => msg,
332            Err(e) => {
333                close_chan(self);
334                return Err(e);
335            }
336        };
337
338        // This variant casts then re-wraps.
339        // match Repr::<P::Expected>::try_cast(msg) {
340        //     Ok(exp) => Ok(Left(self.stash(exp.to_repr()).cast())),
341        //     Err(msg) => Ok(Right(self.stash(msg).cast())),
342        // }
343
344        // This variant just checks, to avoid unwrapping and re-wrapping.
345        if Repr::<P::Expected>::can_into(&msg) {
346            Ok(Branch::Left(self.stash(msg).cast()))
347        } else {
348            Ok(Branch::Right(self.stash(msg).cast()))
349        }
350    }
351}
352
353impl<P, E, R> Chan<Rec<P>, E, R> {
354    /// Enter a recursive environment, putting the current environment on the
355    /// top of the environment stack.
356    pub fn enter(self) -> Chan<P, (P, E), R> {
357        self.cast()
358    }
359}
360
361impl<P, E, R> Chan<Var<Z>, (P, E), R> {
362    /// Recurse to the environment on the top of the environment stack.
363    /// The agency must be kept, since there's no message exchange here,
364    /// we just start from the top as a continuation of where we are.
365    pub fn zero(self) -> SessionResult<Chan<P, (P, E), R>> {
366        Ok(self.cast())
367    }
368}
369
370impl<P, E, N, R> Chan<Var<S<N>>, (P, E), R> {
371    /// Pop the top environment from the environment stack.
372    pub fn succ(self) -> Chan<Var<N>, E, R> {
373        self.cast()
374    }
375}
376
/// A channel following `P` paired with one following its dual, both starting
/// with an empty environment.
type ChanPair<P, R> = (Chan<P, (), R>, Chan<<P as HasDual>::Dual, (), R>);
/// A channel following `P` paired with the raw `(Sender, Receiver)` ends of its peer.
type ChanDynPair<P, R> = (Chan<P, (), R>, (Sender<R>, Receiver<R>));
379
380/// Create a pair of server and client channels for a given protocol `P`.
381pub fn session_channel<P: HasDual, R>() -> ChanPair<P, R> {
382    let (tx1, rx1) = unbounded_channel();
383    let (tx2, rx2) = unbounded_channel();
384
385    let c1 = Chan::new(tx1, rx2);
386    let c2 = Chan::new(tx2, rx1);
387
388    (c1, c2)
389}
390
391/// Similar to `session_channel`; create a typed channel for a protocol `P`,
392/// but instead of creating a channel for its dual, return the raw sender
393/// and receiver that can be used to communicate with the channel created.
394///
395/// These can be used in multiplexers to dispatch messages to/from the network.
396pub fn session_channel_dyn<P, R>() -> ChanDynPair<P, R> {
397    let (tx1, rx1) = unbounded_channel();
398    let (tx2, rx2) = unbounded_channel();
399
400    let c = Chan::new(tx1, rx2);
401
402    (c, (tx2, rx1))
403}
404
/// This macro is convenient for server-like protocols of the form:
///
/// `Offer<A, Offer<B, Offer<C, ... >>>`
///
/// # Examples
///
/// Assume we have a protocol `Offer<Recv<u64, Eps>, Offer<Recv<String, Eps>, Recv<Bye, Eps>>>`;
/// we can use the `offer!` macro as follows:
///
/// ```rust
/// use async_session_types::offer;
/// use async_session_types::*;
/// use std::time::Duration;
///
/// struct Bye;
///
/// async fn srv(c: Chan<Offer<Recv<u64, Eps>, Offer<Recv<String, Eps>, Recv<Bye, Eps>>>, (), DynMessage>) -> SessionResult<()> {
///     let t = Duration::from_secs(1);
///     offer! { c, t,
///         Number => {
///             let (c, n) = c.recv(t).await?;
///             assert_eq!(42, n);
///             c.close()
///         },
///         String => {
///             c.recv(t).await?.0.close()
///         },
///         Quit => {
///             c.recv(t).await?.0.close()
///         }
///     }
/// }
///
/// async fn cli(c: Chan<Choose<Send<u64, Eps>, Choose<Send<String, Eps>, Send<Bye, Eps>>>, (), DynMessage>) -> SessionResult<()>{
///     c.sel1().send(42)?.close()
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let (s, c) = session_channel();
///     tokio::spawn(cli(c));
///     srv(s).await.unwrap();
/// }
/// ```
///
/// The identifiers on the left-hand side of the arrows have no semantic
/// meaning, they only provide a meaningful name for the reader.
///
/// Every arm except the last calls `offer` on the channel and runs its code on
/// `Branch::Left`. The last arm runs unconditionally on whatever channel is left.
/// A trailing comma after the last arm is accepted.
///
/// The expansion contains `.await?`, so the macro must be used inside an `async`
/// block or function whose error type `?` can produce from `SessionError`. The
/// timeout expression is evaluated at every nested `offer` call that is reached,
/// but only the first call waits on the channel; later calls read the message it
/// stashed.
#[macro_export]
macro_rules! offer {
    (
        $id:ident, $timeout:expr, $branch:ident => $code:expr, $($t:tt)+
    ) => (
        match $id.offer($timeout).await? {
            $crate::Branch::Left($id) => $code,
            // Recurse through `$crate` so this also works when callers invoke the
            // macro by path without importing `offer` into scope.
            $crate::Branch::Right($id) => $crate::offer!{ $id, $timeout, $($t)+ }
        }
    );
    (
        $id:ident, $timeout:expr, $branch:ident => $code:expr $(,)?
    ) => (
        $code
    )
}
468
469#[cfg(test)]
470mod test;