1pub mod channel;
37
38pub use rumpsteak_types as types;
40pub use rumpsteak_types::{GlobalType, Label, LocalTypeR, PayloadSort};
41
42#[cfg(feature = "theory")]
44pub use rumpsteak_theory as theory;
45
46pub use rumpsteak_aura_macros::{session, Message, Role, Roles};
47
48pub mod prelude {
50 pub use super::{session, try_session};
51 pub use super::{
52 Branch, Choice, Choices, End, FromState, IntoSession, Message, Receive, ReceiveError, Role,
53 Route, Select, Send, Session, SessionError,
54 };
55 pub use rumpsteak_types::{GlobalType, Label, LocalTypeR, PayloadSort};
56}
57
58use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
59use std::{
60 any::Any,
61 convert::Infallible,
62 future::Future,
63 marker::{self, PhantomData},
64};
65use thiserror::Error;
66
/// A value that can be sealed and later asked whether it has been sealed.
pub trait Sealable {
    /// Reports whether `self` has been sealed.
    fn is_sealed(&self) -> bool;

    /// Seals `self`.
    fn seal(&mut self);
}
75
76#[derive(Debug, Error)]
78pub enum SessionError<E> {
79 #[error("session was used after being sealed")]
81 Sealed,
82 #[error(transparent)]
84 Channel(E),
85}
86
87pub type SendError<Q, R> =
89 SessionError<<<Q as Route<R>>::Route as Sink<<Q as Role>::Message>>::Error>;
90
91#[derive(Debug, Error)]
93pub enum ReceiveError {
94 #[error("receiver stream is empty")]
96 EmptyStream,
97 #[error("received message with an unexpected type")]
99 UnexpectedType,
100 #[error("session was used after being sealed")]
102 Sealed,
103}
104
/// Conversion between a role's message type (`Self`) and one label type `L`.
pub trait Message<L>: Sized {
    /// Wraps `label` into the message type.
    fn upcast(label: L) -> Self;

    /// Attempts to extract a label of type `L`.
    ///
    /// On mismatch the message is returned in `Err` so the caller keeps it.
    fn downcast(self) -> Result<L, Self>;
}
121
122impl<L: 'static> Message<L> for Box<dyn Any> {
123 fn upcast(label: L) -> Self {
124 Box::new(label)
125 }
126
127 fn downcast(self) -> Result<L, Self> {
128 self.downcast().map(|label| *label)
129 }
130}
131
132impl<L: marker::Send + 'static> Message<L> for Box<dyn Any + marker::Send> {
133 fn upcast(label: L) -> Self {
134 Box::new(label)
135 }
136
137 fn downcast(self) -> Result<L, Self> {
138 self.downcast().map(|label| *label)
139 }
140}
141
142impl<L: marker::Send + Sync + 'static> Message<L> for Box<dyn Any + marker::Send + Sync> {
143 fn upcast(label: L) -> Self {
144 Box::new(label)
145 }
146
147 fn downcast(self) -> Result<L, Self> {
148 self.downcast().map(|label| *label)
149 }
150}
151
/// A protocol participant.
///
/// Once a role is sealed, `Send::send`, `Receive::receive`, `Select::select`
/// and `Branch::branch` refuse to run and return a `Sealed` error.
pub trait Role {
    /// The message type carried on this role's routes.
    type Message;

    /// Seals the role.
    fn seal(&mut self);

    /// Reports whether the role has been sealed.
    fn is_sealed(&self) -> bool;
}
165
166pub trait Route<R>: Role + Sized {
170 type Route;
172
173 fn route(&mut self) -> &mut Self::Route;
175}
176
177pub struct State<'r, R: Role> {
183 role: &'r mut R,
184}
185
186impl<'r, R: Role> State<'r, R> {
187 #[inline]
188 fn new(role: &'r mut R) -> Self {
189 Self { role }
190 }
191}
192
193pub trait FromState<'r> {
197 type Role: Role;
199
200 fn from_state(state: State<'r, Self::Role>) -> Self;
202}
203
204pub trait Session<'r>: FromState<'r> + private::Session {}
206
207pub trait IntoSession<'r>: FromState<'r> {
209 type Session: Session<'r, Role = Self::Role>;
211
212 fn into_session(self) -> Self::Session;
214}
215
216pub struct End<'r, R: Role> {
218 state: State<'r, R>,
219}
220
221impl<'r, R: Role> FromState<'r> for End<'r, R> {
222 type Role = R;
223
224 #[inline]
225 fn from_state(state: State<'r, Self::Role>) -> Self {
226 Self { state }
227 }
228}
229
230impl<R: Role> End<'_, R> {
231 pub fn seal(self) {
233 self.state.role.seal();
234 }
235}
236
237impl<R: Role> Drop for End<'_, R> {
238 fn drop(&mut self) {
239 self.state.role.seal();
241 }
242}
243
244impl<R: Role> private::Session for End<'_, R> {}
245
246impl<'r, R: Role> Session<'r> for End<'r, R> {}
247
248pub struct Send<'q, Q: Role, R, L, S: FromState<'q, Role = Q>> {
250 state: State<'q, Q>,
251 phantom: PhantomData<(R, L, S)>,
252}
253
254impl<'q, Q: Role, R, L, S: FromState<'q, Role = Q>> FromState<'q> for Send<'q, Q, R, L, S> {
255 type Role = Q;
256
257 #[inline]
258 fn from_state(state: State<'q, Self::Role>) -> Self {
259 Self {
260 state,
261 phantom: PhantomData,
262 }
263 }
264}
265
266impl<'q, Q: Route<R>, R, L, S: FromState<'q, Role = Q>> Send<'q, Q, R, L, S>
267where
268 Q::Message: Message<L>,
269 Q::Route: Sink<Q::Message> + Unpin,
270{
271 #[inline]
278 pub async fn send(self, label: L) -> Result<S, SendError<Q, R>> {
279 if self.state.role.is_sealed() {
280 return Err(SessionError::Sealed);
281 }
282 self.state
283 .role
284 .route()
285 .send(Message::upcast(label))
286 .await
287 .map_err(SessionError::Channel)?;
288 Ok(FromState::from_state(self.state))
289 }
290}
291
292impl<'q, Q: Role, R, L, S: FromState<'q, Role = Q>> private::Session for Send<'q, Q, R, L, S> {}
293
294impl<'q, Q: Role, R, L, S: FromState<'q, Role = Q>> Session<'q> for Send<'q, Q, R, L, S> {}
295
296pub struct Receive<'q, Q: Role, R, L, S: FromState<'q, Role = Q>> {
298 state: State<'q, Q>,
299 phantom: PhantomData<(R, L, S)>,
300}
301
302impl<'q, Q: Role, R, L, S: FromState<'q, Role = Q>> FromState<'q> for Receive<'q, Q, R, L, S> {
303 type Role = Q;
304
305 #[inline]
306 fn from_state(state: State<'q, Self::Role>) -> Self {
307 Self {
308 state,
309 phantom: PhantomData,
310 }
311 }
312}
313
314impl<'q, Q: Route<R>, R, L, S: FromState<'q, Role = Q>> Receive<'q, Q, R, L, S>
315where
316 Q::Message: Message<L>,
317 Q::Route: Stream<Item = Q::Message> + Unpin,
318{
319 #[inline]
327 pub async fn receive(self) -> Result<(L, S), ReceiveError> {
328 if self.state.role.is_sealed() {
329 return Err(ReceiveError::Sealed);
330 }
331 let message = self.state.role.route().next().await;
332 let message = message.ok_or(ReceiveError::EmptyStream)?;
333 let label = message.downcast().or(Err(ReceiveError::UnexpectedType))?;
334 Ok((label, FromState::from_state(self.state)))
335 }
336}
337
338impl<'q, Q: Role, R, L, S: FromState<'q, Role = Q>> private::Session for Receive<'q, Q, R, L, S> {}
339
340impl<'q, Q: Role, R, L, S: FromState<'q, Role = Q>> Session<'q> for Receive<'q, Q, R, L, S> {}
341
342pub trait Choice<'r, L> {
344 type Session: FromState<'r>;
346}
347
348pub struct Select<'q, Q: Role, R, C> {
352 state: State<'q, Q>,
353 phantom: PhantomData<(R, C)>,
354}
355
356impl<'q, Q: Role, R, C> FromState<'q> for Select<'q, Q, R, C> {
357 type Role = Q;
358
359 #[inline]
360 fn from_state(state: State<'q, Self::Role>) -> Self {
361 Self {
362 state,
363 phantom: PhantomData,
364 }
365 }
366}
367
368impl<'q, Q: Route<R>, R, C> Select<'q, Q, R, C>
369where
370 Q::Route: Sink<Q::Message> + Unpin,
371{
372 #[inline]
379 pub async fn select<L>(self, label: L) -> Result<<C as Choice<'q, L>>::Session, SendError<Q, R>>
380 where
381 Q::Message: Message<L>,
382 C: Choice<'q, L>,
383 C::Session: FromState<'q, Role = Q>,
384 {
385 if self.state.role.is_sealed() {
386 return Err(SessionError::Sealed);
387 }
388 self.state
389 .role
390 .route()
391 .send(Message::upcast(label))
392 .await
393 .map_err(SessionError::Channel)?;
394 Ok(FromState::from_state(self.state))
395 }
396}
397
398impl<Q: Role, R, C> private::Session for Select<'_, Q, R, C> {}
399
400impl<'q, Q: Role, R, C> Session<'q> for Select<'q, Q, R, C> {}
401
402pub trait Choices<'r>: Sized {
406 type Role: Role;
408
409 fn downcast(
415 state: State<'r, Self::Role>,
416 message: <Self::Role as Role>::Message,
417 ) -> Result<Self, <Self::Role as Role>::Message>;
418}
419
420pub struct Branch<'q, Q: Role, R, C> {
424 state: State<'q, Q>,
425 phantom: PhantomData<(R, C)>,
426}
427
428impl<'q, Q: Role, R, C> FromState<'q> for Branch<'q, Q, R, C> {
429 type Role = Q;
430
431 #[inline]
432 fn from_state(state: State<'q, Self::Role>) -> Self {
433 Self {
434 state,
435 phantom: PhantomData,
436 }
437 }
438}
439
440impl<'q, Q: Route<R>, R, C: Choices<'q, Role = Q>> Branch<'q, Q, R, C>
441where
442 Q::Route: Stream<Item = Q::Message> + Unpin,
443{
444 #[inline]
452 pub async fn branch(self) -> Result<C, ReceiveError> {
453 if self.state.role.is_sealed() {
454 return Err(ReceiveError::Sealed);
455 }
456 let message = self.state.role.route().next().await;
457 let message = message.ok_or(ReceiveError::EmptyStream)?;
458 let choice = C::downcast(self.state, message);
459 choice.or(Err(ReceiveError::UnexpectedType))
460 }
461}
462
463impl<Q: Role, R, C> private::Session for Branch<'_, Q, R, C> {}
464
465impl<'q, Q: Role, R, C> Session<'q> for Branch<'q, Q, R, C> {}
466
/// Debug-build tripwire for sessions that are abandoned part-way.
///
/// Unless [`SessionGuard::mark_completed`] has been called, dropping the guard
/// fails a debug assertion. The check is skipped while a panic is already
/// unwinding. Release builds perform no check.
struct SessionGuard {
    completed: bool,
}

impl SessionGuard {
    /// Creates a guard that has not yet been marked as completed.
    fn new() -> Self {
        SessionGuard { completed: false }
    }

    /// Records completion, disarming the drop-time check.
    fn mark_completed(&mut self) {
        self.completed = true;
    }
}

impl Drop for SessionGuard {
    fn drop(&mut self) {
        // Tolerate an in-progress panic: panicking again inside `drop` would abort.
        debug_assert!(
            self.completed || std::thread::panicking(),
            "Session dropped without completing! This indicates a protocol violation."
        );
    }
}
496
497#[inline]
501pub async fn session<'r, R: Role, S: FromState<'r, Role = R>, T, F>(
502 role: &'r mut R,
503 f: impl FnOnce(S) -> F,
504) -> T
505where
506 F: Future<Output = (T, End<'r, R>)>,
507{
508 let output = try_session(role, |s| f(s).map(Ok)).await;
509 output.unwrap_or_else(|infallible: Infallible| match infallible {})
510}
511
512#[inline]
518pub async fn try_session<'r, R: Role, S: FromState<'r, Role = R>, T, E, F>(
519 role: &'r mut R,
520 f: impl FnOnce(S) -> F,
521) -> Result<T, E>
522where
523 F: Future<Output = Result<(T, End<'r, R>), E>>,
524{
525 let mut guard = SessionGuard::new();
526 let session = FromState::from_state(State::new(role));
527 let result = f(session).await;
528
529 if result.is_ok() {
530 guard.mark_completed();
531 }
532
533 result.map(|(output, _)| output)
535}
536
// Seals `super::Session`: this module is private, so its trait cannot be
// named (and therefore not implemented) outside this crate.
mod private {
    /// Supertrait of the public `Session` trait.
    pub trait Session {}
}
540
541#[cfg(test)]
542mod channel_test;