Session Types
async-session-types is a library for expressing the interaction between two parties as message passing, where each party's adherence to the protocol is checked at compile time. Any violation of the protocol by the remote peer, such as an unexpected message type, results in the connection to that peer being closed.
This allows both parties to protect themselves from malicious counterparts by maintaining a session in which each carries a context of what it knows about the other and what it can rightfully expect next.
It could also protect against one party trying to overwhelm the other with messages by making back pressure an essential part of the protocols, although this hasn't been implemented yet (it would need bounded channels).
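To make "checked at compile time" concrete, here is a minimal, std-only sketch of the general session-type technique. It is not the crate's code. Each endpoint's type parameter records the rest of the protocol, so calling `send` or `recv` out of order fails to compile, while a message of the wrong type arriving at run time ends the session with an error instead of a panic. The names `Chan`, `Snd`, `Rcv`, `pair` and `round_trip` are hypothetical. The crate itself uses tokio channels and async receives, whereas this sketch uses `std::sync::mpsc` and a thread.

```rust
use std::any::Any;
use std::marker::PhantomData;
use std::sync::mpsc::{channel, Receiver, RecvTimeoutError, Sender};
use std::time::Duration;

// Messages travel type-erased; the protocol type decides what may be sent or received.
type Msg = Box<dyn Any + Send>;

// Hypothetical protocol steps (stand-ins for the crate's `Send`, `Recv` and `Eps`).
pub struct Snd<T, P>(PhantomData<(T, P)>);
pub struct Rcv<T, P>(PhantomData<(T, P)>);
pub struct Eps;

#[derive(Debug, PartialEq)]
pub enum SessionError {
    Disconnected,
    Timeout,
    UnexpectedMessage,
}

/// An endpoint whose type parameter `P` is the remaining protocol.
pub struct Chan<P> {
    tx: Sender<Msg>,
    rx: Receiver<Msg>,
    _protocol: PhantomData<P>,
}

impl<P> Chan<P> {
    // Same underlying channels, next protocol state.
    fn advance<Q>(self) -> Chan<Q> {
        Chan { tx: self.tx, rx: self.rx, _protocol: PhantomData }
    }
}

impl<T: Send + 'static, P> Chan<Snd<T, P>> {
    pub fn send(self, v: T) -> Result<Chan<P>, SessionError> {
        self.tx.send(Box::new(v)).map_err(|_| SessionError::Disconnected)?;
        Ok(self.advance())
    }
}

impl<T: 'static, P> Chan<Rcv<T, P>> {
    /// Every receive takes a timeout; a message of the wrong type ends the
    /// session with an error instead of a panic.
    pub fn recv(self, timeout: Duration) -> Result<(T, Chan<P>), SessionError> {
        let msg = self.rx.recv_timeout(timeout).map_err(|e| match e {
            RecvTimeoutError::Timeout => SessionError::Timeout,
            RecvTimeoutError::Disconnected => SessionError::Disconnected,
        })?;
        let v = msg.downcast::<T>().map_err(|_| SessionError::UnexpectedMessage)?;
        Ok((*v, self.advance()))
    }
}

impl Chan<Eps> {
    /// The protocol is finished; dropping the endpoint closes it.
    pub fn close(self) {}
}

/// Two connected endpoints. Here the caller spells out both protocol types.
pub fn pair<P, Q>() -> (Chan<P>, Chan<Q>) {
    let (tx_a, rx_b) = channel();
    let (tx_b, rx_a) = channel();
    (
        Chan { tx: tx_a, rx: rx_a, _protocol: PhantomData },
        Chan { tx: tx_b, rx: rx_b, _protocol: PhantomData },
    )
}

/// Usage: a one-shot "add one" exchange; the server's type mirrors the client's.
pub fn round_trip(n: u32) -> Result<u32, SessionError> {
    let (server, client) = pair::<Rcv<u32, Snd<u32, Eps>>, Snd<u32, Rcv<u32, Eps>>>();
    let worker = std::thread::spawn(move || -> Result<(), SessionError> {
        let (x, server) = server.recv(Duration::from_secs(1))?;
        server.send(x + 1)?.close();
        Ok(())
    });
    let (reply, client) = client.send(n)?.recv(Duration::from_secs(1))?;
    client.close();
    worker.join().expect("server thread panicked")?;
    Ok(reply)
}
```

Calling `client.recv(...)` before `client.send(...)` in `round_trip` would not compile, because `recv` is only defined on `Chan<Rcv<_, _>>`.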
Session Types are a way to achieve something similar to Cardano's Mini Protocols.
The library was inspired by the session-types crate but extends it in the following ways:
- To work with unreliable parties, it breaks up the session rather than panicking when an unexpected message arrives.
- Timeouts are required for every receive operation.
- A session can be aborted with an error if some business rule has been violated; these are rules that aren't expressed in the protocol.
- It uses tokio channels, so multiple session types can run without blocking whole threads while awaiting the next message.
- At any time, only one of the server or the client is in a position to send the next message.
- Every decision has to be represented as a message, unlike in the original, which sent binary flags to indicate choice.
- Added a Repr type to route messages to multiple session types using enum wrappers, instead of relying on dynamic casting.
- Added a pair of incoming/outgoing demultiplexer/multiplexer types to support dispatching to multiple session types over a single connection, such as a WebSocket.
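The `Repr` item above replaces dynamic casting with an enum wrapper. The following minimal sketch of that idea assumes a hypothetical `Repr<T>` trait with `from_msg` and `try_into_msg` methods; the crate's actual trait and its `repr_impl!` macro may differ. Every message type converts into one `Wrapper` enum. A receive that expects a particular type hands the whole wrapper back on a mismatch, so it can be logged rather than causing a panic, and it always takes a timeout.

```rust
use std::sync::mpsc::{Receiver, RecvTimeoutError};
use std::time::Duration;

/// Hypothetical trait in the spirit of `Repr`: wrap a concrete message and
/// try to unwrap it again, giving the wrapper back on a mismatch.
pub trait Repr<T>: Sized {
    fn from_msg(msg: T) -> Self;
    fn try_into_msg(self) -> Result<T, Self>;
}

#[derive(Debug, Clone, PartialEq)]
pub struct Ping(pub u32);
#[derive(Debug, Clone, PartialEq)]
pub struct Pong(pub u32);

/// One variant per message type; all session traffic travels as `Wrapper`.
#[derive(Debug, Clone, PartialEq)]
pub enum Wrapper {
    Ping(Ping),
    Pong(Pong),
}

impl Repr<Ping> for Wrapper {
    fn from_msg(msg: Ping) -> Self {
        Wrapper::Ping(msg)
    }
    fn try_into_msg(self) -> Result<Ping, Self> {
        match self {
            Wrapper::Ping(m) => Ok(m),
            other => Err(other),
        }
    }
}

impl Repr<Pong> for Wrapper {
    fn from_msg(msg: Pong) -> Self {
        Wrapper::Pong(msg)
    }
    fn try_into_msg(self) -> Result<Pong, Self> {
        match self {
            Wrapper::Pong(m) => Ok(m),
            other => Err(other),
        }
    }
}

#[derive(Debug, PartialEq)]
pub enum RecvFailure<R> {
    Timeout,
    Disconnected,
    Unexpected(R),
}

/// Receive the next wrapper and insist that it holds a `T`.
pub fn recv_as<T, R: Repr<T>>(rx: &Receiver<R>, timeout: Duration) -> Result<T, RecvFailure<R>> {
    match rx.recv_timeout(timeout) {
        Ok(wrapped) => wrapped.try_into_msg().map_err(RecvFailure::Unexpected),
        Err(RecvTimeoutError::Timeout) => Err(RecvFailure::Timeout),
        Err(RecvTimeoutError::Disconnected) => Err(RecvFailure::Disconnected),
    }
}
```

Because the unexpected value is returned intact, a caller can report exactly what arrived. The `SessionError::UnexpectedMessage` handling later in this README does the same by downcasting to `Wrapper`.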
Please have a look at the tests to see examples.
Example
The following snippets demonstrate how to hook up a protocol to a WebSocket using JSON as the message format.
Say we have a protocol to sync blocks from a blockchain:
```rust
pub mod messages {
    use core::property::{HasHash, HasParent};
    use async_session_types::{repr_bound, Repr};
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    pub struct FindIntersect<B: HasParent>(pub Vec<<B as HasHash>::Hash>);

    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    pub struct IntersectFound<B: HasParent>(pub <B as HasHash>::Hash);

    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    pub struct IntersectNotFound;

    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    pub struct RequestNext;

    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    pub struct AwaitReply;

    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    pub struct RollForward<B: HasParent>(pub B);

    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    pub struct RollBackward<B: HasParent>(pub <B as HasHash>::Hash);

    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    pub struct Done;

    repr_bound! {
        pub LocalBlockSyncReprs<B: HasParent> [
            FindIntersect<B>,
            IntersectFound<B>,
            IntersectNotFound,
            RequestNext,
            AwaitReply,
            RollForward<B>,
            RollBackward<B>,
            Done
        ]
    }
}

#[rustfmt::skip]
pub mod protocol {
    use async_session_types::*;
    use super::messages::*;

    pub type Intersect<B> =
        Recv<
            FindIntersect<B>,
            Choose<
                Send<IntersectFound<B>, Var<Z>>,
                Send<IntersectNotFound, Var<Z>>
            >,
        >;

    pub type Roll<B> =
        Choose<
            Send<RollForward<B>, Var<Z>>,
            Send<RollBackward<B>, Var<Z>>
        >;

    pub type Next<B> =
        Recv<
            RequestNext,
            Choose<
                Roll<B>,
                Send<AwaitReply, Roll<B>>
            >
        >;

    pub type Quit = Recv<Done, Eps>;

    pub type Server<B> =
        Offer<
            Intersect<B>,
            Offer<
                Next<B>,
                Quit
            >
        >;

    pub type Client<B> = <Server<B> as HasDual>::Dual;
}
```
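Only the server side is written out above; the client is derived with `HasDual`. The self-contained sketch below reproduces that idea with hypothetical, cut-down combinators. `Snd`/`Rcv` stand in for the crate's `Send`/`Recv`, and looping via `Rec`/`Var` is left out. Duality swaps sending with receiving and choosing with offering, and `TypeId` comparisons confirm that the derived client matches one written by hand.

```rust
use std::any::TypeId;
use std::marker::PhantomData;

// Hypothetical mini versions of the protocol combinators.
pub struct Eps;
pub struct Snd<T, P>(PhantomData<(T, P)>);
pub struct Rcv<T, P>(PhantomData<(T, P)>);
pub struct Choose<P, Q>(PhantomData<(P, Q)>);
pub struct Offer<P, Q>(PhantomData<(P, Q)>);

/// The other party's view of a protocol.
pub trait HasDual {
    type Dual;
}
impl HasDual for Eps {
    type Dual = Eps;
}
impl<T, P: HasDual> HasDual for Snd<T, P> {
    type Dual = Rcv<T, P::Dual>;
}
impl<T, P: HasDual> HasDual for Rcv<T, P> {
    type Dual = Snd<T, P::Dual>;
}
impl<P: HasDual, Q: HasDual> HasDual for Choose<P, Q> {
    type Dual = Offer<P::Dual, Q::Dual>;
}
impl<P: HasDual, Q: HasDual> HasDual for Offer<P, Q> {
    type Dual = Choose<P::Dual, Q::Dual>;
}

// Placeholder message types.
pub struct RequestNext;
pub struct RollForward;
pub struct RollBackward;
pub struct Done;

/// A cut-down server: either answer `RequestNext` with one of two replies, or accept `Done`.
pub type Server = Offer<
    Rcv<RequestNext, Choose<Snd<RollForward, Eps>, Snd<RollBackward, Eps>>>,
    Rcv<Done, Eps>,
>;

/// The client written out by hand, for comparison.
pub type ExpectedClient = Choose<
    Snd<RequestNext, Offer<Rcv<RollForward, Eps>, Rcv<RollBackward, Eps>>>,
    Snd<Done, Eps>,
>;

/// Run-time check that two types are identical.
pub fn same_type<A: 'static, B: 'static>() -> bool {
    TypeId::of::<A>() == TypeId::of::<B>()
}
```

Where the server offers a choice, the client makes it, and vice versa. This is why a single `Server<B>` definition is enough to keep both ends in lockstep.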
And a message wrapper that handles JSON tagging:
```rust
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)]
pub enum ProtocolId {
    LocalBlockSync,
}

pub type ProtocolMessage = MultiMessage<ProtocolId, Wrapper>;

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", content = "payload")]
pub enum LbsWrapper {
    FindIntersect(lbs::FindIntersect<WalletBlock>),
    IntersectFound(lbs::IntersectFound<WalletBlock>),
    IntersectNotFound,
    RequestNext,
    AwaitReply,
    RollForward(Box<lbs::RollForward<WalletBlock>>),
    RollBackward(lbs::RollBackward<WalletBlock>),
    Done,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "protocol")]
pub enum Wrapper {
    LocalBlockSync(LbsWrapper),
}

repr_impl! {
    Wrapper {
        lbs::FindIntersect<WalletBlock> : (
            |x| Wrapper::LocalBlockSync(LbsWrapper::FindIntersect(x)),
            Wrapper::LocalBlockSync(LbsWrapper::FindIntersect(x)) => x
        ),
        lbs::IntersectFound<WalletBlock> : (
            |x| Wrapper::LocalBlockSync(LbsWrapper::IntersectFound(x)),
            Wrapper::LocalBlockSync(LbsWrapper::IntersectFound(x)) => x
        ),
        lbs::IntersectNotFound : (
            |_| Wrapper::LocalBlockSync(LbsWrapper::IntersectNotFound),
            Wrapper::LocalBlockSync(LbsWrapper::IntersectNotFound) => lbs::IntersectNotFound
        ),
        lbs::RequestNext : (
            |_| Wrapper::LocalBlockSync(LbsWrapper::RequestNext),
            Wrapper::LocalBlockSync(LbsWrapper::RequestNext) => lbs::RequestNext
        ),
        lbs::AwaitReply : (
            |_| Wrapper::LocalBlockSync(LbsWrapper::AwaitReply),
            Wrapper::LocalBlockSync(LbsWrapper::AwaitReply) => lbs::AwaitReply
        ),
        lbs::RollForward<WalletBlock> : (
            |x| Wrapper::LocalBlockSync(LbsWrapper::RollForward(Box::new(x))),
            Wrapper::LocalBlockSync(LbsWrapper::RollForward(x)) => *x
        ),
        lbs::RollBackward<WalletBlock> : (
            |x| Wrapper::LocalBlockSync(LbsWrapper::RollBackward(x)),
            Wrapper::LocalBlockSync(LbsWrapper::RollBackward(x)) => x
        ),
        lbs::Done : (
            |_| Wrapper::LocalBlockSync(LbsWrapper::Done),
            Wrapper::LocalBlockSync(LbsWrapper::Done) => lbs::Done
        ),
    }
}
```
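`handle_connection` turns a decoded `Wrapper` into a `ProtocolMessage` with `wrapper.into()`. Each outer `Wrapper` variant corresponds to exactly one protocol, so the routing key can be derived mechanically from the payload. The sketch below makes several assumptions: a `MultiMessage` with hypothetical `protocol_id` and `payload` fields (only `payload` appears in this README), a hypothetical second protocol `Handshake`, and `String` payloads in place of `LbsWrapper`. It illustrates the idea and does not claim to show how the crate implements the conversion.

```rust
/// Hypothetical protocol identifiers; `Handshake` exists only for illustration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolId {
    LocalBlockSync,
    Handshake,
}

/// One outer variant per protocol (payloads simplified to `String`).
#[derive(Debug, Clone, PartialEq)]
pub enum Wrapper {
    LocalBlockSync(String),
    Handshake(String),
}

/// Hypothetical stand-in for the crate's `MultiMessage`.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiMessage<K, R> {
    pub protocol_id: K,
    pub payload: R,
}

impl Wrapper {
    /// The outer variant already names the protocol.
    pub fn protocol_id(&self) -> ProtocolId {
        match self {
            Wrapper::LocalBlockSync(_) => ProtocolId::LocalBlockSync,
            Wrapper::Handshake(_) => ProtocolId::Handshake,
        }
    }
}

impl From<Wrapper> for MultiMessage<ProtocolId, Wrapper> {
    fn from(payload: Wrapper) -> Self {
        MultiMessage { protocol_id: payload.protocol_id(), payload }
    }
}
```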
Then we can handle WebSocket connections by creating new channels and spawning a task with a multiplexer:
```rust
use std::{net::SocketAddr, sync::Arc};

use futures::{future, pin_mut, SinkExt, StreamExt, TryStreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite;

// Also assumes the logging macros and the types `Root`, `Wrapper`, `ProtocolMessage`,
// `ProtocolError` and `IncomingMultiChannel` are in scope.
async fn handle_connection(stream: TcpStream, addr: SocketAddr, root: Arc<Root>) {
    debug!("Incoming TCP connection from: {}", addr);

    let ws_stream = match tokio_tungstenite::accept_async(stream).await {
        Ok(s) => s,
        Err(e) => {
            error!("Error during handshake from {}: {}", addr, e);
            return;
        }
    };
    info!("WebSocket connection established: {}", addr);

    let (mut ws_out, ws_in) = ws_stream.split();

    // Channels between the socket and the sessions.
    let (tx_in, rx_in) = tokio::sync::mpsc::unbounded_channel::<ProtocolMessage>();
    let (tx_out, mut rx_out) = tokio::sync::mpsc::unbounded_channel::<ProtocolMessage>();

    // Dispatch incoming messages to sessions, initialising a server per protocol via `root`.
    let _ = tokio::spawn(async move {
        let imc = IncomingMultiChannel::new(tx_out, rx_in);
        imc.run(|protocol_id, errors| root.init_server(addr, protocol_id, errors))
            .await;
        debug!("IncomingMultiChannel for {} stopped.", addr);
    });

    // Socket -> sessions: undecodable messages are skipped; other errors close the connection.
    let forward_incoming = ws_in.try_for_each(|msg| {
        if msg.is_text() || msg.is_binary() {
            match Wrapper::try_from(msg) {
                Ok(wrapper) => {
                    if tx_in.send(wrapper.into()).is_err() {
                        return future::err(tungstenite::Error::ConnectionClosed);
                    }
                }
                Err(ProtocolError::SerdeJson(e)) => {
                    warn!("Could not deserialise message from {}: {}", addr, e);
                }
                Err(ProtocolError::Tungstenite(e)) => {
                    error!("Unexpected tungstenite error from {}: {}", addr, e);
                    return future::err(e);
                }
                Err(e) => {
                    error!("Unexpected error from {}: {:?}", addr, e);
                    return future::err(tungstenite::Error::ConnectionClosed);
                }
            }
        }
        future::ok(())
    });

    // Sessions -> socket: send each outgoing payload as a JSON text frame.
    let forward_outgoing = tokio::spawn(async move {
        while let Some(msg) = rx_out.recv().await {
            match serde_json::to_string(&msg.payload) {
                Ok(json_str) => {
                    let msg = tungstenite::Message::text(json_str);
                    if ws_out.send(msg).await.is_err() {
                        return;
                    }
                }
                Err(e) => {
                    warn!("Could not convert reply to JSON: {}", e);
                }
            }
        }
    });

    // Return as soon as either direction finishes.
    pin_mut!(forward_incoming, forward_outgoing);
    future::select(forward_incoming, forward_outgoing).await;
    info!("Disconnected from {}", addr);
}
```
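The incoming half of `handle_connection` follows a simple policy. Messages that fail to deserialise are logged and skipped, control frames are ignored, and only a closed session channel (or a transport error) ends the connection. The following is a synchronous, std-only sketch of that policy. It uses a hypothetical `Frame` type and a toy `decode` that parses integers in place of `Wrapper::try_from`.

```rust
use std::sync::mpsc::Sender;

/// Hypothetical transport frames: `Text` carries data, `Ping` stands in for control frames.
#[derive(Debug, PartialEq)]
pub enum Frame {
    Text(String),
    Ping,
}

#[derive(Debug, PartialEq)]
pub enum Stop {
    ChannelClosed,
}

/// Toy decoder standing in for `Wrapper::try_from` (which uses serde_json).
fn decode(text: &str) -> Result<u32, String> {
    text.trim().parse::<u32>().map_err(|e| e.to_string())
}

/// Forward decodable frames to the session side; return how many were forwarded.
pub fn forward_incoming<I>(frames: I, tx: &Sender<u32>) -> Result<usize, Stop>
where
    I: IntoIterator<Item = Frame>,
{
    let mut forwarded = 0;
    for frame in frames {
        // Only data frames are decoded, like the `is_text() || is_binary()` check.
        let text = match frame {
            Frame::Text(text) => text,
            Frame::Ping => continue,
        };
        match decode(&text) {
            Ok(msg) => {
                // The session side has gone away: end the connection.
                if tx.send(msg).is_err() {
                    return Err(Stop::ChannelClosed);
                }
                forwarded += 1;
            }
            // Malformed payload: log it and keep reading.
            Err(e) => eprintln!("could not deserialise {:?}: {}", text, e),
        }
    }
    Ok(forwarded)
}
```

Skipping malformed input keeps one bad message from tearing down every session that shares the connection. A protocol violation by a well-formed message is still caught by the session types themselves.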
And spawn a session handler when we first see a message for a new protocol:
```rust
type WrapperChan = (Sender<Wrapper>, Receiver<Wrapper>);

impl Root {
    /// Called the first time a message for `protocol_id` arrives on a connection.
    pub fn init_server(
        &self,
        addr: SocketAddr,
        protocol_id: ProtocolId,
        errors: Sender<SessionError>,
    ) -> WrapperChan {
        // Log how the session ended; aborts are also reported back through `errors`.
        // `ProtocolId` only derives `Debug`, hence `{:?}`.
        let handle_result = move |result: Result<(), SessionError>| {
            match result {
                Ok(()) => {
                    info!("Session {:?} from {} finished.", protocol_id, addr);
                }
                Err(SessionError::Disconnected) => {
                    info!("Session {:?} from {} disconnected.", protocol_id, addr);
                }
                Err(SessionError::Timeout) => {
                    warn!("Session {:?} from {} timed out.", protocol_id, addr);
                }
                Err(SessionError::UnexpectedMessage(msg)) => {
                    if let Some(msg) = msg.downcast_ref::<Wrapper>() {
                        error!(
                            "Session {:?} from {} ended because of an unexpected message: {:?}",
                            protocol_id, addr, msg
                        )
                    } else {
                        error!(
                            "Session {:?} from {} ended because of a completely unexpected message.",
                            protocol_id, addr
                        )
                    }
                }
                Err(SessionError::Abort(e)) => {
                    error!("Session {:?} from {} aborted: {}", protocol_id, addr, e);
                    let _ = errors.send(SessionError::Abort(e));
                }
            }
        };

        match protocol_id {
            ProtocolId::LocalBlockSync => self.init_local_block_sync_server(handle_result),
        }
    }

    fn init_local_block_sync_server<H>(&self, handle_result: H) -> WrapperChan
    where
        H: Send + 'static + FnOnce(Result<(), SessionError>),
    {
        // `chan` is the typed server endpoint; `(tx, rx)` carry `Wrapper` messages
        // and are returned to the multiplexer.
        let (chan, (tx, rx)) = session_channel_dyn::<Rec<Server<WalletBlock>>, Wrapper>();
        let mut server = Producer::new();
        // `async move` so the spawned task owns `server`, `chan` and `handle_result`,
        // as `tokio::spawn` requires a `'static` future.
        tokio::spawn(async move {
            handle_result(server.sync_chain(chan).await)
        });
        (tx, rx)
    }
}
```
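`init_server` is invoked lazily, the first time a message for a given `ProtocolId` arrives. The following std-only sketch shows that dispatch pattern. It uses a hypothetical `Demux` type with threads instead of tokio tasks and handles only the incoming direction, whereas the README's `init_server` returns a channel pair. The first message for a protocol creates its channel and spawns its handler; later messages reuse the route.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread::{self, JoinHandle};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProtocolId {
    LocalBlockSync,
    Handshake,
}

/// Hypothetical demultiplexer: one handler thread per protocol, created on first use.
pub struct Demux<M> {
    routes: HashMap<ProtocolId, Sender<M>>,
    handlers: Vec<JoinHandle<usize>>,
}

impl<M: Send + 'static> Demux<M> {
    pub fn new() -> Self {
        Demux { routes: HashMap::new(), handlers: Vec::new() }
    }

    /// Route `msg`; `init` is only called the first time `id` is seen.
    pub fn dispatch<F>(&mut self, id: ProtocolId, msg: M, init: F)
    where
        F: FnOnce(Receiver<M>) -> usize + Send + 'static,
    {
        let handlers = &mut self.handlers;
        let tx = self.routes.entry(id).or_insert_with(|| {
            let (tx, rx) = channel();
            handlers.push(thread::spawn(move || init(rx)));
            tx
        });
        // Send errors are ignored here; a real server would drop the route and report it.
        let _ = tx.send(msg);
    }

    /// Close every route and wait for the handlers, in creation order.
    pub fn shutdown(self) -> Vec<usize> {
        drop(self.routes);
        self.handlers.into_iter().map(|h| h.join().unwrap()).collect()
    }
}
```

Dropping the routes closes each handler's receiver, which is how this sketch signals the end of a session.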
Prerequisites
Install the following to be able to build the project:
```shell
curl https://sh.rustup.rs -sSf | sh
rustup toolchain install nightly
rustup default nightly
rustup update
```
License
This project is licensed under the MIT license.