pub mod error;
mod interface;
pub mod prelude;
pub mod protocol;
pub mod result;
pub use super::error::*;
pub use crate::encoding::Encoding;
use crate::imports::*;
pub use interface::{Interface, Method, Notification};
pub use protocol::{BorshProtocol, JsonProtocol, ProtocolHandler};
pub use std::net::SocketAddr;
pub use tokio::sync::mpsc::UnboundedSender as TokioUnboundedSender;
pub use workflow_core::task::spawn;
pub use workflow_websocket::server::{
Error as WebSocketError, Message, Result as WebSocketResult, TcpListener, WebSocketConfig,
WebSocketCounters, WebSocketHandler, WebSocketReceiver, WebSocketSender, WebSocketServer,
WebSocketServerTrait, WebSocketSink,
};
/// Re-export of the WebSocket handshake helpers from
/// `workflow_websocket::server::handshake`, so users of this module can reach
/// them without depending on `workflow_websocket` directly.
pub mod handshake {
    pub use workflow_websocket::server::handshake::*;
}
use crate::server::result::Result;
pub use workflow_rpc_macros::server_method as method;
pub use workflow_rpc_macros::server_notification as notification;
/// Carries the socket address of an RPC peer.
#[derive(Debug, Clone)]
pub struct RpcContext {
    /// Remote address of the peer.
    pub peer: SocketAddr,
}
/// Application-side callbacks for the lifecycle of an RPC connection.
///
/// An `Arc<dyn RpcHandler>` is passed to [`RpcServer::new`] or
/// [`RpcServer::new_with_encoding`]. The WebSocket layer's `accept`, `connect`,
/// `handshake` and `disconnect` events are forwarded to the methods below.
#[async_trait]
pub trait RpcHandler: Send + Sync + 'static {
    /// Per-connection value produced by [`RpcHandler::handshake`] and handed
    /// back to [`RpcHandler::disconnect`].
    type Context: Send + Sync;

    /// Decides whether a connection from the given peer is accepted.
    ///
    /// The default implementation accepts every peer.
    fn accept(&self, _peer_addr: &SocketAddr) -> bool {
        true
    }

    /// Hook invoked for the WebSocket `connect` event.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    async fn connect(self: Arc<Self>, _peer_addr: &SocketAddr) -> WebSocketResult<()> {
        Ok(())
    }

    /// Performs the application-level handshake for a new connection and
    /// produces its [`RpcHandler::Context`].
    ///
    /// `messenger` is bound to this connection's sink and the server's
    /// encoding. It is an `Arc`, so it can be stored inside the returned
    /// context for later use.
    async fn handshake(
        self: Arc<Self>,
        peer: &SocketAddr,
        sender: &mut WebSocketSender,
        receiver: &mut WebSocketReceiver,
        messenger: Arc<Messenger>,
    ) -> WebSocketResult<Self::Context>;

    /// Hook invoked for the WebSocket `disconnect` event. It receives the
    /// connection's context and the result reported by the WebSocket layer.
    ///
    /// The default implementation ignores both arguments.
    async fn disconnect(self: Arc<Self>, _context: Self::Context, _outcome: WebSocketResult<()>) {}
}
/// Sends notifications and raw messages through a connection's
/// [`WebSocketSink`], serializing notifications with the connection's
/// [`Encoding`].
#[derive(Debug)]
pub struct Messenger {
    /// Wire encoding used by `notify` and `serialize_notification_message`.
    encoding: Encoding,
    /// Outgoing-message sink (a clone of the sink given to `Messenger::new`).
    sink: WebSocketSink,
}
impl Messenger {
pub fn new(encoding: Encoding, sink: &WebSocketSink) -> Self {
Self {
encoding,
sink: sink.clone(),
}
}
pub fn close(&self) -> Result<()> {
self.sink.send(Message::Close(None))?;
Ok(())
}
pub async fn notify<Ops, Msg>(&self, op: Ops, msg: Msg) -> Result<()>
where
Ops: OpsT,
Msg: BorshSerialize + BorshDeserialize + Serialize + Send + Sync + 'static,
{
match self.encoding {
Encoding::Borsh => {
self.sink
.send(protocol::borsh::create_serialized_notification_message(
op, msg,
)?)?;
}
Encoding::SerdeJson => {
self.sink
.send(protocol::serde_json::create_serialized_notification_message(op, msg)?)?;
}
}
Ok(())
}
pub fn serialize_notification_message<Ops, Msg>(
&self,
op: Ops,
msg: Msg,
) -> Result<tungstenite::Message>
where
Ops: OpsT,
Msg: MsgT,
{
match self.encoding {
Encoding::Borsh => Ok(protocol::borsh::create_serialized_notification_message(
op, msg,
)?),
Encoding::SerdeJson => {
Ok(protocol::serde_json::create_serialized_notification_message(op, msg)?)
}
}
}
pub fn send_raw_message(&self, msg: tungstenite::Message) -> Result<()> {
self.sink.send(msg)?;
Ok(())
}
pub fn sink(&self) -> &WebSocketSink {
&self.sink
}
pub fn encoding(&self) -> Encoding {
self.encoding
}
}
/// Adapter that implements [`WebSocketHandler`] on top of the RPC layer.
///
/// Lifecycle events (`accept`, `connect`, `handshake`, `disconnect`) are
/// forwarded to the user's [`RpcHandler`]. Incoming messages are passed to
/// `Protocol::handle_message`.
#[derive(Clone)]
struct RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
where
    Ops: OpsT,
    ServerContext: Clone + Send + Sync + 'static,
    ConnectionContext: Clone + Send + Sync + 'static,
    Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
{
    /// User callbacks for connection lifecycle events.
    rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
    /// Protocol that decodes and dispatches incoming messages.
    protocol: Arc<Protocol>,
    /// When `true`, each message is handled in a spawned task (see `message`).
    enable_async_handling: bool,
    // Marker fields: these type parameters are only used via `Protocol`/`Interface`.
    _server_ctx: PhantomData<ServerContext>,
    _ops: PhantomData<Ops>,
}
impl<ServerContext, ConnectionContext, Protocol, Ops>
RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
where
Ops: OpsT,
ServerContext: Clone + Send + Sync + 'static,
ConnectionContext: Clone + Send + Sync + 'static,
Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
{
pub fn new(
rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
enable_async_handling: bool,
) -> Self {
let protocol = Arc::new(Protocol::new(interface));
Self {
rpc_handler,
protocol,
enable_async_handling,
_server_ctx: PhantomData,
_ops: PhantomData,
}
}
}
#[async_trait]
impl<ServerContext, ConnectionContext, Protocol, Ops> WebSocketHandler
    for RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
where
    Ops: OpsT,
    ServerContext: Clone + Send + Sync + 'static,
    ConnectionContext: Clone + Send + Sync + 'static,
    Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
{
    type Context = ConnectionContext;

    /// Delegates the accept decision to the user's [`RpcHandler`].
    fn accept(&self, peer: &SocketAddr) -> bool {
        let handler = &self.rpc_handler;
        handler.accept(peer)
    }

    /// Forwards the connect event to the user's [`RpcHandler`].
    async fn connect(self: &Arc<Self>, peer: &SocketAddr) -> WebSocketResult<()> {
        let handler = Arc::clone(&self.rpc_handler);
        handler.connect(peer).await
    }

    /// Forwards the disconnect event, with the connection context and the
    /// connection's result, to the user's [`RpcHandler`].
    async fn disconnect(self: &Arc<Self>, ctx: Self::Context, result: WebSocketResult<()>) {
        let handler = Arc::clone(&self.rpc_handler);
        handler.disconnect(ctx, result).await
    }

    /// Builds a [`Messenger`] for this connection, using the protocol's
    /// encoding and the given sink. It then delegates the handshake to the
    /// user's [`RpcHandler`], which produces the connection context.
    async fn handshake(
        self: &Arc<Self>,
        peer: &SocketAddr,
        sender: &mut WebSocketSender,
        receiver: &mut WebSocketReceiver,
        sink: &WebSocketSink,
    ) -> WebSocketResult<Self::Context> {
        let encoding = self.protocol.encoding();
        let messenger = Arc::new(Messenger::new(encoding, sink));
        let handler = Arc::clone(&self.rpc_handler);
        handler.handshake(peer, sender, receiver, messenger).await
    }

    /// Passes an incoming message to the protocol.
    ///
    /// With async handling disabled, the message is handled inline and the
    /// protocol's result is returned. With it enabled, handling runs in a
    /// spawned task and this method returns `Ok(())` immediately.
    async fn message(
        self: &Arc<Self>,
        connection_ctx: &Self::Context,
        msg: Message,
        sink: &WebSocketSink,
    ) -> WebSocketResult<()> {
        let ctx: ConnectionContext = connection_ctx.clone();

        if !self.enable_async_handling {
            return self.protocol.handle_message(ctx, msg, sink).await;
        }

        // NOTE(review): the spawned task's `WebSocketResult` is not observed by
        // anything in this module. Handler errors in async mode therefore do not
        // reach the WebSocket layer the way they do in the inline path.
        let task_sink = sink.clone();
        let this = Arc::clone(self);
        spawn(async move { this.protocol.handle_message(ctx, msg, &task_sink).await });
        Ok(())
    }
}
/// RPC server front-end over a type-erased WebSocket server.
///
/// Cloning only clones the inner `Arc`, so clones share the same server.
#[derive(Clone)]
pub struct RpcServer {
    /// Underlying WebSocket server that every method delegates to.
    ws_server: Arc<dyn WebSocketServerTrait>,
}
impl RpcServer {
    /// Creates an RPC server whose wire protocol is the explicit `Protocol`
    /// type parameter.
    ///
    /// * `rpc_handler` – user callbacks for connection lifecycle events.
    /// * `interface` – method/notification table used to construct `Protocol`.
    /// * `counters` – optional counters forwarded to [`WebSocketServer::new`].
    /// * `enable_async_handling` – when `true`, each incoming message is handled
    ///   in a spawned task instead of inline.
    pub fn new<ServerContext, ConnectionContext, Protocol, Ops>(
        rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
        interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
        counters: Option<Arc<WebSocketCounters>>,
        enable_async_handling: bool,
    ) -> RpcServer
    where
        ServerContext: Clone + Send + Sync + 'static,
        ConnectionContext: Clone + Send + Sync + 'static,
        Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
        Ops: OpsT,
    {
        let ws_handler = Arc::new(RpcWebSocketHandler::<
            ServerContext,
            ConnectionContext,
            Protocol,
            Ops,
        >::new(rpc_handler, interface, enable_async_handling));
        let ws_server = WebSocketServer::new(ws_handler, counters);
        RpcServer { ws_server }
    }

    /// Creates an RPC server that uses the built-in protocol for `encoding`:
    /// [`BorshProtocol`] for [`Encoding::Borsh`] and [`JsonProtocol`] for
    /// [`Encoding::SerdeJson`].
    ///
    /// `Id` is forwarded as the protocol's id type parameter (presumably the
    /// request identifier type; confirm against the protocol module). The
    /// remaining arguments are as for [`RpcServer::new`].
    pub fn new_with_encoding<ServerContext, ConnectionContext, Ops, Id>(
        encoding: Encoding,
        rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
        interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
        counters: Option<Arc<WebSocketCounters>>,
        enable_async_handling: bool,
    ) -> RpcServer
    where
        ServerContext: Clone + Send + Sync + 'static,
        ConnectionContext: Clone + Send + Sync + 'static,
        Ops: OpsT,
        Id: IdT,
    {
        match encoding {
            Encoding::Borsh => RpcServer::new::<
                ServerContext,
                ConnectionContext,
                BorshProtocol<ServerContext, ConnectionContext, Ops, Id>,
                Ops,
            >(rpc_handler, interface, counters, enable_async_handling),
            Encoding::SerdeJson => RpcServer::new::<
                ServerContext,
                ConnectionContext,
                JsonProtocol<ServerContext, ConnectionContext, Ops, Id>,
                Ops,
            >(rpc_handler, interface, counters, enable_async_handling),
        }
    }

    /// Binds a TCP listener to `addr`.
    ///
    /// An optional leading `wrpc://` scheme is removed first, so
    /// `wrpc://127.0.0.1:8080` and `127.0.0.1:8080` bind the same address.
    ///
    /// # Errors
    /// Returns any error reported by the underlying WebSocket server's `bind`.
    pub async fn bind(&self, addr: &str) -> WebSocketResult<TcpListener> {
        // Strip only a *leading* scheme. `str::replace` would also delete
        // "wrpc://" appearing anywhere else in the string, and it allocates.
        let addr = addr.strip_prefix("wrpc://").unwrap_or(addr);
        self.ws_server.clone().bind(addr).await
    }

    /// Accepts and serves connections from `listener`, with an optional
    /// WebSocket configuration.
    ///
    /// # Errors
    /// Returns any error reported by the underlying WebSocket server's `listen`.
    pub async fn listen(
        &self,
        listener: TcpListener,
        config: Option<WebSocketConfig>,
    ) -> WebSocketResult<()> {
        self.ws_server.clone().listen(listener, config).await
    }

    /// Requests the underlying WebSocket server to stop.
    pub fn stop(&self) -> WebSocketResult<()> {
        self.ws_server.stop()
    }

    /// Waits on the underlying WebSocket server's `join`.
    pub async fn join(&self) -> WebSocketResult<()> {
        self.ws_server.join().await
    }

    /// Delegates to the underlying WebSocket server's `stop_and_join`.
    pub async fn stop_and_join(&self) -> WebSocketResult<()> {
        self.ws_server.stop_and_join().await
    }
}