entropy_protocol/protocol_transport/
mod.rsmod broadcaster;
pub mod errors;
pub mod noise;
mod subscribe_message;
use async_trait::async_trait;
pub use broadcaster::Broadcaster;
use errors::WsError;
#[cfg(any(feature = "server", feature = "wasm"))]
use futures::{SinkExt, StreamExt};
use noise::EncryptedWsConnection;
pub use subscribe_message::SubscribeMessage;
use tokio::sync::{broadcast, mpsc};
#[cfg(feature = "server")]
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
use crate::{PartyId, ProtocolMessage};
/// Local channel endpoints that connect one websocket peer to the running protocol.
///
/// [`ws_to_channels`] reads outgoing messages from `broadcast` and pushes incoming
/// messages into `tx`.
pub struct WsChannels {
    /// Outgoing protocol messages produced locally. Only those addressed (`to`) to the
    /// remote party of a given connection are forwarded over that connection.
    pub broadcast: broadcast::Receiver<ProtocolMessage>,
    /// Destination for protocol messages received from the remote party.
    pub tx: mpsc::Sender<ProtocolMessage>,
    /// Not read in this module. NOTE(review): presumably marks the last connection set up
    /// for a protocol session; confirm against where `WsChannels` is constructed.
    pub is_final: bool,
}
/// A websocket that exchanges whole binary messages, abstracted over the concrete
/// client/server implementation.
///
/// On `wasm32` the generated futures are not required to be `Send` (`async_trait(?Send)`).
/// Everywhere else they are.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait WsConnection {
    /// Waits for the next binary message from the peer and returns its payload.
    ///
    /// Implementations return a [`WsError`] if the connection has closed, the transport
    /// fails, or an unusable message arrives.
    async fn recv(&mut self) -> Result<Vec<u8>, WsError>;

    /// Sends `msg` to the peer as one binary message.
    async fn send(&mut self, msg: Vec<u8>) -> Result<(), WsError>;
}
// Browser-side websocket (`gloo_net`), available with the `wasm` feature.
#[cfg(feature = "wasm")]
#[async_trait(?Send)]
impl WsConnection for gloo_net::websocket::futures::WebSocket {
    /// Returns the payload of the next message, which must be binary.
    ///
    /// # Errors
    ///
    /// - [`WsError::ConnectionClosed`] if the message stream has ended.
    /// - [`WsError::ConnectionError`] if the socket yielded an error.
    /// - [`WsError::UnexpectedMessageType`] if a non-binary (text) message arrives.
    async fn recv(&mut self) -> Result<Vec<u8>, WsError> {
        if let gloo_net::websocket::Message::Bytes(msg) = self
            .next()
            .await
            .ok_or(WsError::ConnectionClosed)?
            .map_err(|e| WsError::ConnectionError(e.to_string()))?
        {
            Ok(msg)
        } else {
            Err(WsError::UnexpectedMessageType)
        }
    }

    /// Sends `msg` as a binary message.
    ///
    /// # Errors
    ///
    /// Any send failure is reported as [`WsError::ConnectionClosed`].
    async fn send(&mut self, msg: Vec<u8>) -> Result<(), WsError> {
        // `self` is already `&mut WebSocket`. Hand it to `SinkExt::send` directly instead of
        // borrowing the reference again (`&mut &mut WebSocket`). The fully qualified call
        // avoids resolving to `WsConnection::send` itself.
        SinkExt::send(self, gloo_net::websocket::Message::Bytes(msg))
            .await
            .map_err(|_| WsError::ConnectionClosed)
    }
}
// Server-side websocket accepted through `axum`, available with the `server` feature.
#[cfg(feature = "server")]
#[async_trait]
impl WsConnection for axum::extract::ws::WebSocket {
    /// Returns the payload of the next binary message, skipping Ping/Pong control frames.
    ///
    /// axum answers pings on its own but still hands Ping/Pong messages to the reader. A
    /// keep-alive from the peer or a proxy therefore must not abort the protocol.
    ///
    /// # Errors
    ///
    /// - [`WsError::ConnectionClosed`] if the stream has ended or the peer sent a Close frame.
    /// - [`WsError::ConnectionError`] if the socket yielded an error.
    /// - [`WsError::UnexpectedMessageType`] for any other non-binary message (e.g. text).
    async fn recv(&mut self) -> Result<Vec<u8>, WsError> {
        loop {
            // `WebSocket::recv` is an inherent method, so it takes precedence over
            // `WsConnection::recv` here; this is not a recursive call.
            let message = self
                .recv()
                .await
                .ok_or(WsError::ConnectionClosed)?
                .map_err(|e| WsError::ConnectionError(e.to_string()))?;
            match message {
                axum::extract::ws::Message::Binary(msg) => break Ok(msg),
                axum::extract::ws::Message::Ping(_) | axum::extract::ws::Message::Pong(_) => {}
                axum::extract::ws::Message::Close(_) => break Err(WsError::ConnectionClosed),
                _ => break Err(WsError::UnexpectedMessageType),
            }
        }
    }

    /// Sends `msg` as a binary message.
    ///
    /// # Errors
    ///
    /// Any send failure is reported as [`WsError::ConnectionClosed`].
    async fn send(&mut self, msg: Vec<u8>) -> Result<(), WsError> {
        // Inherent `WebSocket::send`, not a recursive call to `WsConnection::send`.
        self.send(axum::extract::ws::Message::Binary(msg))
            .await
            .map_err(|_| WsError::ConnectionClosed)
    }
}
// Client/server `tokio_tungstenite` stream over a possibly-TLS TCP socket (`server` feature).
#[cfg(feature = "server")]
#[async_trait]
impl WsConnection for tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>> {
    /// Returns the payload of the next binary message, skipping Ping/Pong control frames.
    ///
    /// tungstenite queues the Pong reply to a Ping itself but still yields Ping/Pong
    /// messages to the reader. A keep-alive therefore must not abort the protocol.
    ///
    /// # Errors
    ///
    /// - [`WsError::ConnectionClosed`] if the stream has ended or the peer sent a Close frame.
    /// - [`WsError::ConnectionError`] if the stream yielded an error.
    /// - [`WsError::UnexpectedMessageType`] for any other non-binary message (e.g. text).
    async fn recv(&mut self) -> Result<Vec<u8>, WsError> {
        loop {
            let message = self
                .next()
                .await
                .ok_or(WsError::ConnectionClosed)?
                .map_err(|e| WsError::ConnectionError(e.to_string()))?;
            match message {
                tungstenite::Message::Binary(msg) => break Ok(msg),
                tungstenite::Message::Ping(_) | tungstenite::Message::Pong(_) => {}
                tungstenite::Message::Close(_) => break Err(WsError::ConnectionClosed),
                _ => break Err(WsError::UnexpectedMessageType),
            }
        }
    }

    /// Sends `msg` as a binary message.
    ///
    /// # Errors
    ///
    /// Any send failure is reported as [`WsError::ConnectionClosed`].
    async fn send(&mut self, msg: Vec<u8>) -> Result<(), WsError> {
        // `self` is already `&mut WebSocketStream<_>`; no need for `&mut &mut`. The fully
        // qualified call avoids resolving to `WsConnection::send` itself.
        SinkExt::send(self, tungstenite::Message::Binary(msg))
            .await
            .map_err(|_| WsError::ConnectionClosed)
    }
}
// `tokio_tungstenite` stream over a plain (non-TLS) TCP socket (`server` feature).
#[cfg(feature = "server")]
#[async_trait]
impl WsConnection for tokio_tungstenite::WebSocketStream<tokio::net::TcpStream> {
    /// Returns the payload of the next binary message, skipping Ping/Pong control frames.
    ///
    /// tungstenite queues the Pong reply to a Ping itself but still yields Ping/Pong
    /// messages to the reader. A keep-alive therefore must not abort the protocol.
    ///
    /// # Errors
    ///
    /// - [`WsError::ConnectionClosed`] if the stream has ended or the peer sent a Close frame.
    /// - [`WsError::ConnectionError`] if the stream yielded an error.
    /// - [`WsError::UnexpectedMessageType`] for any other non-binary message (e.g. text).
    async fn recv(&mut self) -> Result<Vec<u8>, WsError> {
        loop {
            let message = self
                .next()
                .await
                .ok_or(WsError::ConnectionClosed)?
                .map_err(|e| WsError::ConnectionError(e.to_string()))?;
            match message {
                tungstenite::Message::Binary(msg) => break Ok(msg),
                tungstenite::Message::Ping(_) | tungstenite::Message::Pong(_) => {}
                tungstenite::Message::Close(_) => break Err(WsError::ConnectionClosed),
                _ => break Err(WsError::UnexpectedMessageType),
            }
        }
    }

    /// Sends `msg` as a binary message.
    ///
    /// # Errors
    ///
    /// Any send failure is reported as [`WsError::ConnectionClosed`].
    async fn send(&mut self, msg: Vec<u8>) -> Result<(), WsError> {
        // `self` is already `&mut WebSocketStream<_>`; no need for `&mut &mut`. The fully
        // qualified call avoids resolving to `WsConnection::send` itself.
        SinkExt::send(self, tungstenite::Message::Binary(msg))
            .await
            .map_err(|_| WsError::ConnectionClosed)
    }
}
/// Relays protocol messages between one remote party's encrypted websocket and the local
/// protocol channels until the local broadcast channel is closed.
///
/// Incoming messages from `connection` are decoded into [`ProtocolMessage`]s and pushed
/// into `ws_channels.tx`. Outgoing messages from `ws_channels.broadcast` are `bincode`-
/// serialized and sent over `connection`, but only if their `to` field equals
/// `remote_party_id`. All other outgoing messages are skipped.
///
/// Returns `Ok(())` once all senders of the broadcast channel have been dropped.
///
/// # Errors
///
/// - [`WsError::EncryptedConnection`] if receiving from or sending over `connection` fails.
/// - The converted decode error if an incoming payload is not a valid `ProtocolMessage`.
/// - [`WsError::MessageAfterProtocolFinish`] if a message arrives after the receiving end of
///   `ws_channels.tx` has been dropped.
/// - [`WsError::ConnectionError`] if this task fell behind on the broadcast channel and
///   outgoing messages were dropped (`RecvError::Lagged`).
/// - The converted `bincode` error if an outgoing message fails to serialize.
pub async fn ws_to_channels<T: WsConnection>(
    mut connection: EncryptedWsConnection<T>,
    mut ws_channels: WsChannels,
    remote_party_id: PartyId,
) -> Result<(), WsError> {
    loop {
        tokio::select! {
            signing_message_result = connection.recv() => {
                let serialized_signing_message = signing_message_result
                    .map_err(|e| WsError::EncryptedConnection(e.to_string()))?;
                let msg = ProtocolMessage::try_from(&serialized_signing_message[..])?;
                ws_channels
                    .tx
                    .send(msg)
                    .await
                    .map_err(|_| WsError::MessageAfterProtocolFinish)?;
            }
            msg_result = ws_channels.broadcast.recv() => {
                let msg = match msg_result {
                    Ok(msg) => msg,
                    // Every sender is gone: the local protocol has finished.
                    Err(broadcast::error::RecvError::Closed) => return Ok(()),
                    // Messages were dropped before we could forward them. The remote party
                    // would never receive them, so fail loudly instead of reporting success.
                    Err(e @ broadcast::error::RecvError::Lagged(_)) => {
                        return Err(WsError::ConnectionError(e.to_string()));
                    }
                };
                if msg.to != remote_party_id {
                    continue;
                }
                let message_vec = bincode::serialize(&msg)?;
                connection
                    .send(message_vec)
                    .await
                    .map_err(|e| WsError::EncryptedConnection(e.to_string()))?;
            }
        }
    }
}
/// A [`WsConnection`] that is `Send + 'static` (server build).
///
/// NOTE(review): presumably required so a connection can be owned by a spawned task;
/// confirm at use sites. This definition and the `wasm` one below are each gated only on
/// their own feature, so enabling `server` and `wasm` together defines the trait twice.
#[cfg(feature = "server")]
pub trait ThreadSafeWsConnection: WsConnection + Send + 'static {}

/// A [`WsConnection`] that is `'static` (wasm build; no `Send` requirement).
#[cfg(feature = "wasm")]
pub trait ThreadSafeWsConnection: WsConnection + 'static {}

#[cfg(feature = "server")]
impl ThreadSafeWsConnection for WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>> {}

#[cfg(feature = "wasm")]
impl ThreadSafeWsConnection for gloo_net::websocket::futures::WebSocket {}