1#![feature(trait_alias)]
3#![feature(async_closure)]
4
5use std::{error::Error, marker, marker::PhantomData, mem::ManuallyDrop, thread, time::Duration};
6use tokio::time::timeout as timeout_after;
7
8pub type Receiver<T> = tokio::sync::mpsc::UnboundedReceiver<T>;
11pub type Sender<T> = tokio::sync::mpsc::UnboundedSender<T>;
12
13fn unbounded_channel<T>() -> (Sender<T>, Receiver<T>) {
14 tokio::sync::mpsc::unbounded_channel()
15}
16
17#[cfg(feature = "mux")]
19pub mod multiplexing;
20
21mod repr;
22
23pub use repr::{DynMessage, Repr};
24
25#[derive(Debug)]
26pub enum SessionError {
27 UnexpectedMessage(DynMessage),
29 Disconnected,
31 Timeout,
33 Abort(Box<dyn Error + marker::Send + 'static>),
35}
36
37impl std::fmt::Display for SessionError {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 write!(f, "{:?}", self)
40 }
41}
42
43impl Error for SessionError {}
44
45pub type SessionResult<T> = Result<T, SessionError>;
46
47pub fn ok<T>(value: T) -> SessionResult<T> {
48 Ok(value)
49}
50
51fn downcast<T, R: Repr<T>>(msg: R) -> SessionResult<T> {
52 msg.try_into()
53 .map_err(|msg| SessionError::UnexpectedMessage(Box::new(msg)))
54}
55
56pub struct Chan<P, E, R> {
61 tx: ManuallyDrop<Sender<R>>,
62 rx: ManuallyDrop<Receiver<R>>,
63 stash: ManuallyDrop<Option<R>>,
64 _phantom: PhantomData<(P, E)>,
65}
66
67impl<P, E, R> Chan<P, E, R> {
68 fn new(tx: Sender<R>, rx: Receiver<R>) -> Chan<P, E, R> {
69 Chan {
70 tx: ManuallyDrop::new(tx),
71 rx: ManuallyDrop::new(rx),
72 stash: ManuallyDrop::new(None),
73 _phantom: PhantomData,
74 }
75 }
76}
77
78fn write_chan<T, P, E, R: Repr<T>>(chan: &Chan<P, E, R>, v: T) -> SessionResult<()> {
79 chan.tx
80 .send(Repr::from(v))
81 .map_err(|_| SessionError::Disconnected)
82}
83
84async fn read_chan<T, P, E, R: Repr<T>>(
86 chan: &mut Chan<P, E, R>,
87 timeout: Duration,
88) -> SessionResult<T> {
89 let msg = read_chan_dyn(chan, timeout).await?;
90 downcast(msg)
91}
92
93async fn read_chan_dyn<P, E, R>(chan: &mut Chan<P, E, R>, timeout: Duration) -> SessionResult<R> {
95 match chan.stash.take() {
96 Some(msg) => Ok(msg),
97 None if timeout == Duration::MAX => match chan.rx.recv().await {
98 Some(msg) => Ok(msg),
99 None => Err(SessionError::Disconnected),
100 },
101 None => match timeout_after(timeout, chan.rx.recv()).await {
102 Ok(Some(msg)) => Ok(msg),
103 Ok(None) => Err(SessionError::Disconnected),
104 Err(_) => Err(SessionError::Timeout),
105 },
106 }
107}
108
109fn close_chan<P, E, R>(chan: Chan<P, E, R>) {
111 let mut this = ManuallyDrop::new(chan);
114 unsafe {
115 ManuallyDrop::drop(&mut this.tx);
116 ManuallyDrop::drop(&mut this.rx);
117 ManuallyDrop::drop(&mut this.stash);
118 }
119}
120
121pub struct Z;
123
124pub struct S<N>(PhantomData<N>);
126
127pub struct Eps;
129
130pub struct Recv<T, P>(PhantomData<(T, P)>);
132
133pub struct Send<T, P>(PhantomData<(T, P)>);
135
136pub struct Choose<P: Outgoing, Q: Outgoing>(PhantomData<(P, Q)>);
138
139pub struct Offer<P: Incoming, Q: Incoming>(PhantomData<(P, Q)>);
141
142pub struct Rec<P>(PhantomData<P>);
144
145pub struct Var<N>(PhantomData<N>);
147
148pub trait Incoming {
151 type Expected;
152}
153
154impl<T, P> Incoming for Recv<T, P> {
155 type Expected = T;
156}
157impl<P: Incoming, Q: Incoming> Incoming for Offer<P, Q> {
158 type Expected = P::Expected;
159}
160
161pub trait Outgoing {}
163
164impl<T, P> Outgoing for Send<T, P> {}
165impl<P: Outgoing, Q: Outgoing> Outgoing for Choose<P, Q> {}
166
167pub trait HasDual {
171 type Dual;
172}
173
174impl HasDual for Eps {
175 type Dual = Eps;
176}
177
178impl<A, P: HasDual> HasDual for Send<A, P> {
179 type Dual = Recv<A, P::Dual>;
180}
181
182impl<A, P: HasDual> HasDual for Recv<A, P> {
183 type Dual = Send<A, P::Dual>;
184}
185
186impl<P: HasDual, Q: HasDual> HasDual for Choose<P, Q>
187where
188 P: Outgoing,
189 Q: Outgoing,
190 P::Dual: Incoming,
191 Q::Dual: Incoming,
192{
193 type Dual = Offer<P::Dual, Q::Dual>;
194}
195
196impl<P: HasDual, Q: HasDual> HasDual for Offer<P, Q>
197where
198 P: Incoming,
199 Q: Incoming,
200 P::Dual: Outgoing,
201 Q::Dual: Outgoing,
202{
203 type Dual = Choose<P::Dual, Q::Dual>;
204}
205
206impl HasDual for Var<Z> {
207 type Dual = Var<Z>;
208}
209
210impl<N> HasDual for Var<S<N>> {
211 type Dual = Var<S<N>>;
212}
213
214impl<P: HasDual> HasDual for Rec<P> {
215 type Dual = Rec<P::Dual>;
216}
217
/// Outcome of `offer`: the endpoint advanced into the left or the right branch.
pub enum Branch<L, R> {
    /// The peer's message matched the left branch.
    Left(L),
    /// Any other message: continue with the right branch.
    Right(R),
}
223
224impl<P, E, R> Drop for Chan<P, E, R> {
227 fn drop(&mut self) {
228 if !thread::panicking() {
229 panic!("Session channel prematurely dropped. Must call `.close()`.");
230 }
231 }
232}
233
234impl<E, R> Chan<Eps, E, R> {
235 pub fn close(self) -> SessionResult<()> {
237 close_chan(self);
238 Ok(())
239 }
240}
241
242impl<P, E, R> Chan<P, E, R> {
243 fn cast<P2, E2>(self) -> Chan<P2, E2, R> {
244 let mut this = ManuallyDrop::new(self);
245 unsafe {
246 Chan {
247 tx: ManuallyDrop::new(ManuallyDrop::take(&mut this.tx)),
248 rx: ManuallyDrop::new(ManuallyDrop::take(&mut this.rx)),
249 stash: ManuallyDrop::new(ManuallyDrop::take(&mut this.stash)),
250 _phantom: PhantomData,
251 }
252 }
253 }
254
255 pub fn abort<T, F: Error + marker::Send + 'static>(self, e: F) -> SessionResult<T> {
257 close_chan(self);
258 Err(SessionError::Abort(Box::new(e)))
259 }
260
261 pub fn abort_dyn<T>(self, e: Box<dyn Error + marker::Send>) -> SessionResult<T> {
262 close_chan(self);
263 Err(SessionError::Abort(e))
264 }
265}
266
267impl<P, E, T, R: Repr<T>> Chan<Send<T, P>, E, R> {
268 pub fn send(self, v: T) -> SessionResult<Chan<P, E, R>> {
270 match write_chan(&self, v) {
271 Ok(()) => Ok(self.cast()),
272 Err(e) => {
273 close_chan(self);
274 Err(e)
275 }
276 }
277 }
278}
279
280impl<P, E, T, R: Repr<T>> Chan<Recv<T, P>, E, R> {
281 pub async fn recv(mut self, timeout: Duration) -> SessionResult<(Chan<P, E, R>, T)> {
284 match read_chan(&mut self, timeout).await {
285 Ok(v) => Ok((self.cast(), v)),
286 Err(e) => {
287 close_chan(self);
288 Err(e)
289 }
290 }
291 }
292}
293
294impl<P: Outgoing, Q: Outgoing, E, R> Chan<Choose<P, Q>, E, R> {
295 pub fn sel1(self) -> Chan<P, E, R> {
298 self.cast()
299 }
300
301 pub fn sel2(self) -> Chan<Q, E, R> {
304 self.cast()
305 }
306}
307
308type OfferBranch<P, Q, E, R> = Branch<Chan<P, E, R>, Chan<Q, E, R>>;
310
311impl<P: Incoming, Q: Incoming, E, R> Chan<Offer<P, Q>, E, R>
312where
313 P::Expected: 'static,
314 R: Repr<P::Expected>,
315{
316 fn stash(mut self, msg: R) -> Self {
319 self.stash = ManuallyDrop::new(Some(msg));
320 self
321 }
322
323 pub async fn offer(mut self, t: Duration) -> SessionResult<OfferBranch<P, Q, E, R>> {
328 let msg = match read_chan_dyn(&mut self, t).await {
331 Ok(msg) => msg,
332 Err(e) => {
333 close_chan(self);
334 return Err(e);
335 }
336 };
337
338 if Repr::<P::Expected>::can_into(&msg) {
346 Ok(Branch::Left(self.stash(msg).cast()))
347 } else {
348 Ok(Branch::Right(self.stash(msg).cast()))
349 }
350 }
351}
352
353impl<P, E, R> Chan<Rec<P>, E, R> {
354 pub fn enter(self) -> Chan<P, (P, E), R> {
357 self.cast()
358 }
359}
360
361impl<P, E, R> Chan<Var<Z>, (P, E), R> {
362 pub fn zero(self) -> SessionResult<Chan<P, (P, E), R>> {
366 Ok(self.cast())
367 }
368}
369
370impl<P, E, N, R> Chan<Var<S<N>>, (P, E), R> {
371 pub fn succ(self) -> Chan<Var<N>, E, R> {
373 self.cast()
374 }
375}
376
377type ChanPair<P, R> = (Chan<P, (), R>, Chan<<P as HasDual>::Dual, (), R>);
378type ChanDynPair<P, R> = (Chan<P, (), R>, (Sender<R>, Receiver<R>));
379
380pub fn session_channel<P: HasDual, R>() -> ChanPair<P, R> {
382 let (tx1, rx1) = unbounded_channel();
383 let (tx2, rx2) = unbounded_channel();
384
385 let c1 = Chan::new(tx1, rx2);
386 let c2 = Chan::new(tx2, rx1);
387
388 (c1, c2)
389}
390
391pub fn session_channel_dyn<P, R>() -> ChanDynPair<P, R> {
397 let (tx1, rx1) = unbounded_channel();
398 let (tx2, rx2) = unbounded_channel();
399
400 let c = Chan::new(tx1, rx2);
401
402 (c, (tx2, rx1))
403}
404
/// Dispatches on the branch chosen by the peer of an `Offer` channel.
///
/// `offer!{chan, timeout, A => expr_a, B => expr_b, ...}` calls
/// `chan.offer(timeout).await?` once per `Offer` level:
/// - On `Left`, the advanced endpoint is rebound to the same identifier
///   `chan` and `expr_a` is evaluated.
/// - On `Right`, the remaining branches are handled recursively.
///
/// The last branch is evaluated directly, without another `offer` call. The
/// branch labels (`A`, `B`, ...) are purely descriptive.
///
/// The macro must be used in an `async` context where `?` can propagate a
/// `SessionError`. A trailing comma after the last branch is accepted.
#[macro_export]
macro_rules! offer {
    (
        $id:ident, $timeout:expr, $branch:ident => $code:expr, $($t:tt)+
    ) => (
        match $id.offer($timeout).await? {
            $crate::Branch::Left($id) => $code,
            // Recurse through `$crate` so this resolves even when the caller
            // invoked the macro by path instead of importing it by name.
            $crate::Branch::Right($id) => $crate::offer!{ $id, $timeout, $($t)+ }
        }
    );
    (
        $id:ident, $timeout:expr, $branch:ident => $code:expr $(,)?
    ) => (
        $code
    );
}
468
469#[cfg(test)]
470mod test;