#![warn(missing_docs)]
use std::io::ErrorKind as IOErrorKind;
use std::sync::Arc;
use bytes::Bytes;
use rings_core::message::MessagePayload;
use rings_rpc::protos::rings_node::SendBackendMessageRequest;
use serde::Deserialize;
use serde::Serialize;
use crate::error::Error;
use crate::provider::Provider;
/// SNARK task messages carried by [`BackendMessage::SNARKTaskMessage`];
/// compiled only when the `snark` feature is enabled.
#[cfg(feature = "snark")]
pub mod snark;
/// Identifier of a tunnel referenced by the TCP variants of [`ServiceMessage`]
/// (an alias for [`uuid::Uuid`]).
pub type TunnelId = uuid::Uuid;
/// Top-level message dispatched to [`MessageHandler`] implementations.
///
/// The type is serde-(de)serializable, and
/// [`BackendMessage::into_send_backend_message_request`] encodes it as JSON.
/// Because the variant names and order make up the serialized form, do not
/// rename or reorder variants casually.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum BackendMessage {
    /// Raw bytes for an extension; the payload is not interpreted here.
    Extension(Bytes),
    /// Service tunnelling / proxying message (TCP tunnel or HTTP).
    ServiceMessage(ServiceMessage),
    /// A plain text message.
    PlainText(String),
    /// SNARK task message; this variant only exists with the `snark` feature.
    #[cfg(feature = "snark")]
    SNARKTaskMessage(snark::SNARKTaskMessage),
}
/// Messages for tunnelling TCP streams and proxying HTTP requests to a
/// named service.
///
/// This is a serde wire type; variant names/order and field order are part
/// of the serialized form.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum ServiceMessage {
    /// Request to open TCP tunnel `tid` to the named service.
    TcpDial {
        /// Identifier of the tunnel being opened.
        tid: TunnelId,
        /// Name of the service to dial.
        service: String,
    },
    /// Closes TCP tunnel `tid`.
    TcpClose {
        /// Identifier of the tunnel being closed.
        tid: TunnelId,
        /// Why the tunnel is being closed.
        reason: TunnelDefeat,
    },
    /// A chunk of data belonging to TCP tunnel `tid`.
    TcpPackage {
        /// Identifier of the tunnel the data belongs to.
        tid: TunnelId,
        /// Raw bytes of the chunk.
        body: Bytes,
    },
    /// An HTTP request addressed to a service.
    HttpRequest(HttpRequest),
    /// An HTTP response, presumably answering an [`HttpRequest`] (matched via
    /// `rid` — confirm with the handlers).
    HttpResponse(HttpResponse),
}
/// Reason a tunnel was closed, carried in [`ServiceMessage::TcpClose`].
///
/// Variants have explicit `u8` discriminants (`#[repr(u8)]`), and `Unknown`
/// sits at `u8::MAX`. The derived serde impls do not use these discriminant
/// values.
///
/// The enum is `#[non_exhaustive]`, so new reasons may be added.
#[derive(Deserialize, Serialize, Debug, Clone, Copy)]
#[repr(u8)]
#[non_exhaustive]
pub enum TunnelDefeat {
    /// Sending over the WebRTC data channel failed.
    WebrtcDatachannelSendFailed = 1,
    /// The connection timed out.
    ConnectionTimeout = 2,
    /// The connection was refused by the remote side.
    ConnectionRefused = 3,
    /// The connection was aborted.
    ConnectionAborted = 4,
    /// The connection was reset.
    ConnectionReset = 5,
    /// The socket was not connected.
    NotConnected = 6,
    /// The connection was closed.
    ConnectionClosed = 7,
    /// Any other cause, including I/O error kinds with no dedicated variant.
    Unknown = u8::MAX,
}
/// An HTTP request to be delivered to a named service.
///
/// This is a serde wire type; field order is part of the serialized form.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HttpRequest {
    /// Optional request id; presumably echoed back in
    /// [`HttpResponse::rid`] so responses can be matched — confirm.
    pub rid: Option<String>,
    /// Name of the target service.
    pub service: String,
    /// HTTP method, as a string (e.g. `"GET"`).
    pub method: String,
    /// Request path.
    pub path: String,
    /// Header pairs, presumably `(name, value)`.
    pub headers: Vec<(String, String)>,
    /// Optional request body bytes.
    pub body: Option<Vec<u8>>,
}
/// An HTTP response returned from a service.
///
/// This is a serde wire type; field order is part of the serialized form.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HttpResponse {
    /// Optional request id; presumably copied from the originating
    /// [`HttpRequest::rid`] — confirm.
    pub rid: Option<String>,
    /// HTTP status code.
    pub status: u16,
    /// Header pairs, presumably `(name, value)`.
    pub headers: Vec<(String, String)>,
    /// Optional response body bytes.
    pub body: Option<Bytes>,
}
/// Asynchronous handler for messages of type `T`.
///
/// On `target_arch = "wasm32"` the trait is generated with
/// `async_trait(?Send)`, so the returned future need not be `Send`. On all
/// other targets it must be `Send`. Tuple implementations for
/// `MessageHandler<BackendMessage>` must use the same predicate.
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
pub trait MessageHandler<T> {
    /// Handles one message.
    ///
    /// * `provider` – shared [`Provider`] handle.
    /// * `ctx` – the [`MessagePayload`] associated with the message
    ///   (presumably the envelope it arrived in — confirm with callers).
    /// * `data` – the message to handle.
    ///
    /// # Errors
    ///
    /// Returns whatever boxed error the implementation reports.
    async fn handle_message(
        &self,
        provider: Arc<Provider>,
        ctx: &MessagePayload,
        data: &T,
    ) -> Result<(), Box<dyn std::error::Error>>;
}
impl From<ServiceMessage> for BackendMessage {
fn from(val: ServiceMessage) -> Self {
BackendMessage::ServiceMessage(val)
}
}
/// Maps a [`std::io::ErrorKind`] to the closest [`TunnelDefeat`] reason.
///
/// Kinds with a dedicated variant map one-to-one, including
/// `TimedOut` → [`TunnelDefeat::ConnectionTimeout`]. Every other kind becomes
/// [`TunnelDefeat::Unknown`].
impl From<IOErrorKind> for TunnelDefeat {
    fn from(kind: IOErrorKind) -> TunnelDefeat {
        match kind {
            IOErrorKind::TimedOut => TunnelDefeat::ConnectionTimeout,
            IOErrorKind::ConnectionRefused => TunnelDefeat::ConnectionRefused,
            IOErrorKind::ConnectionAborted => TunnelDefeat::ConnectionAborted,
            IOErrorKind::ConnectionReset => TunnelDefeat::ConnectionReset,
            IOErrorKind::NotConnected => TunnelDefeat::NotConnected,
            // `io::ErrorKind` is `#[non_exhaustive]`, so a catch-all arm is
            // mandatory; unmapped kinds are reported as `Unknown`.
            _ => TunnelDefeat::Unknown,
        }
    }
}
/// Implements `MessageHandler<BackendMessage>` for a tuple of handlers.
///
/// The generated `handle_message` passes the same message to every tuple
/// element in index order. It awaits each handler before calling the next and
/// returns early with the first error, so later handlers are skipped.
///
/// Usage: `impl_message_handler_for_tuple!(T1, T2, ..; 0, 1, ..; flavour)`.
/// The second list gives the tuple index for each type parameter. `flavour` is
/// one of:
/// * `wasm` – emits an `async_trait(?Send)` impl with no extra bounds;
/// * `non_wasm` – emits a plain `async_trait` impl and requires every element
///   to be `Send + Sync`.
macro_rules! impl_message_handler_for_tuple {
    ($($Handler:ident),+; $($idx:tt),+; wasm) => {
        #[async_trait::async_trait(?Send)]
        impl<$($Handler),+> MessageHandler<BackendMessage> for ($($Handler),+)
        where
            $($Handler: MessageHandler<BackendMessage>),+
        {
            async fn handle_message(
                &self,
                provider: Arc<Provider>,
                ctx: &MessagePayload,
                message: &BackendMessage,
            ) -> std::result::Result<(), Box<dyn std::error::Error>> {
                // Every element gets its own handle to the shared provider;
                // `?` short-circuits on the first failing handler.
                $(
                    self.$idx
                        .handle_message(Arc::clone(&provider), ctx, message)
                        .await?;
                )+
                Ok(())
            }
        }
    };
    ($($Handler:ident),+; $($idx:tt),+; non_wasm) => {
        #[async_trait::async_trait]
        impl<$($Handler),+> MessageHandler<BackendMessage> for ($($Handler),+)
        where
            $($Handler: MessageHandler<BackendMessage> + Send + Sync),+
        {
            async fn handle_message(
                &self,
                provider: Arc<Provider>,
                ctx: &MessagePayload,
                message: &BackendMessage,
            ) -> std::result::Result<(), Box<dyn std::error::Error>> {
                // Same sequential fan-out as the `wasm` arm; the `Send + Sync`
                // bounds let the tuple be shared by the `Send` future.
                $(
                    self.$idx
                        .handle_message(Arc::clone(&provider), ctx, message)
                        .await?;
                )+
                Ok(())
            }
        }
    };
}
// Tuple impls of `MessageHandler<BackendMessage>` for 2 to 5 handlers.
//
// These cfg predicates must be exactly the ones `MessageHandler` uses to choose
// between `async_trait` and `async_trait(?Send)` (`target_arch = "wasm32"`).
// A broader predicate such as `target_family = "wasm"` also matches wasm64,
// where the trait expects `Send` futures. The `?Send` tuple impls would then
// no longer match the trait and the crate would fail to compile.
#[cfg(not(target_arch = "wasm32"))]
impl_message_handler_for_tuple!(T1, T2; 0, 1; non_wasm);
#[cfg(not(target_arch = "wasm32"))]
impl_message_handler_for_tuple!(T1, T2, T3; 0, 1, 2; non_wasm);
#[cfg(not(target_arch = "wasm32"))]
impl_message_handler_for_tuple!(T1, T2, T3, T4; 0, 1, 2, 3; non_wasm);
#[cfg(not(target_arch = "wasm32"))]
impl_message_handler_for_tuple!(T1, T2, T3, T4, T5; 0, 1, 2, 3, 4; non_wasm);
#[cfg(target_arch = "wasm32")]
impl_message_handler_for_tuple!(T1, T2; 0, 1; wasm);
#[cfg(target_arch = "wasm32")]
impl_message_handler_for_tuple!(T1, T2, T3; 0, 1, 2; wasm);
#[cfg(target_arch = "wasm32")]
impl_message_handler_for_tuple!(T1, T2, T3, T4; 0, 1, 2, 3; wasm);
#[cfg(target_arch = "wasm32")]
impl_message_handler_for_tuple!(T1, T2, T3, T4, T5; 0, 1, 2, 3, 4; wasm);
impl BackendMessage {
    /// Packs this message into a [`SendBackendMessageRequest`] addressed to
    /// `destination_did`.
    ///
    /// The destination is stringified with [`ToString`]. The message is
    /// serialized to a JSON string and stored in the request's `data` field.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] (converted from `serde_json`'s error) if the
    /// message cannot be serialized to JSON.
    pub fn into_send_backend_message_request(
        self,
        destination_did: impl ToString,
    ) -> Result<SendBackendMessageRequest, Error> {
        let destination_did = destination_did.to_string();
        let data = serde_json::to_string(&self)?;
        Ok(SendBackendMessageRequest {
            destination_did,
            data,
        })
    }
}