use anyhow::Result;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::path::Path;
use std::sync::Arc;
use tokio_socket::HandleSocket;
use lockshed::SharedMap;
use tokio_socket::PacketWriter;
use tokio_socket::SocketClient;
use tokio_socket::SocketPeer;
use tokio_socket::SocketServer;
use super::ReceiveRpcProtocol;
use super::RpcMessageState;
use super::RpcPacket;
use super::RpcPacketHandler;
use super::SendRpcProtocol;
/// An RPC-level handle to one connected socket peer.
///
/// Bundles the underlying [`SocketPeer`] with an [`RpcMessageState`]. Where this file
/// constructs an `RpcPeer`, the same state is also handed to that connection's
/// `RpcPacketHandler`, so outgoing requests and the incoming side share it.
#[derive(Clone)]
pub struct RpcPeer {
    /// Socket connection that packets are written to.
    net_peer: SocketPeer,
    /// Source of request ids and home of the response-promise store.
    state: RpcMessageState,
}

impl std::fmt::Debug for RpcPeer {
    /// Shows the socket peer only; the `state` field is not part of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RpcPeer");
        out.field("net_peer", &self.net_peer);
        out.finish()
    }
}
impl RpcPeer {
    /// Protocol id used by the convenience senders that take no explicit protocol.
    const DEFAULT_PROTOCOL_ID: u64 = 0;

    /// Creates a peer handle from a socket peer and its RPC message state.
    pub fn new(net_peer: SocketPeer, state: RpcMessageState) -> Self {
        Self { net_peer, state }
    }

    /// Sends `m` to `protocol_id` without waiting for a reply.
    ///
    /// The message is bincode-serialized and wrapped in a request packet carrying a
    /// fresh request id from the shared state.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails or the packet cannot be written.
    pub async fn send_message_to_protocol<M: Serialize>(
        &self,
        protocol_id: u64,
        m: &M,
    ) -> Result<()> {
        let message_buf = bincode::serialize(m)?;
        let request_id = self.state.next_request_id();
        let message = RpcPacket::new_request(request_id, protocol_id, message_buf);
        self.net_peer.write_packet(&message).await
    }

    /// Sends `m` to `protocol_id` and waits for the reply, decoding it as `R`.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails, the packet cannot be written, the
    /// response promise resolves to an error, or the reply cannot be decoded as `R`.
    pub async fn send_message_with_response_to_protocol<M: Serialize, R: DeserializeOwned>(
        &self,
        protocol_id: u64,
        m: &M,
    ) -> Result<R> {
        // Serialize before touching the promise store: an early return on a bad
        // message must not leave behind a promise that no reply will ever resolve.
        let message_buf = bincode::serialize(m)?;
        let request_id = self.state.next_request_id();
        // Register the promise before writing so a reply cannot race ahead of it
        // (assumes the packet handler resolves promises by request id — confirm).
        let response_promise = self
            .state
            .response_promise_store
            .new_promise(request_id)
            .await;
        let message = RpcPacket::new_request(request_id, protocol_id, message_buf);
        // NOTE(review): if this write fails, the promise registered above stays in the
        // store; removing it needs a store API that is not visible from here.
        self.net_peer.write_packet(&message).await?;
        let response_buf = response_promise.resolve().await?;
        let response: R = bincode::deserialize(&response_buf)?;
        Ok(response)
    }

    /// Sends `m` to the default protocol (`0`) without waiting for a reply.
    pub async fn send_message<M: Serialize>(&self, m: &M) -> Result<()> {
        self.send_message_to_protocol(Self::DEFAULT_PROTOCOL_ID, m)
            .await
    }

    /// Sends `m` to the default protocol (`0`) and waits for a reply decoded as `R`.
    pub async fn send_message_with_response<M: Serialize, R: DeserializeOwned>(
        &self,
        m: &M,
    ) -> Result<R> {
        self.send_message_with_response_to_protocol(Self::DEFAULT_PROTOCOL_ID, m)
            .await
    }

    /// Returns the id of the underlying socket peer.
    pub fn peer_id(&self) -> &String {
        self.net_peer.peer_id()
    }
}
/// An RPC server that accepts socket connections and serves them through an
/// [`RpcServerHandler`].
#[derive(Clone)]
pub struct RpcServer {
    /// Underlying socket server; [`RpcServer::shutdown`] cancels it via its `cancel` field.
    stream_server: SocketServer,
    /// Per-peer message state keyed by peer id, shared with the socket handler.
    // Not read through `RpcServer` yet; kept so the server owns a handle to the map
    // the connection handler populates.
    #[allow(unused)]
    message_states: SharedMap<String, RpcMessageState>,
}
/// Socket-level handler that wraps each new connection in RPC machinery and forwards
/// connection lifecycle events to the user's [`RpcServerHandler`].
pub struct RpcSocketHandler<H: RpcServerHandler> {
    /// User handler, shared across all connections.
    server_handler: Arc<H>,
    /// Per-peer message state keyed by peer id.
    message_states: SharedMap<String, RpcMessageState>,
}

// Implemented by hand: `#[derive(Clone)]` would require `H: Clone`, yet the handler is
// only ever shared through its `Arc` and never cloned itself.
impl<H: RpcServerHandler> Clone for RpcSocketHandler<H> {
    fn clone(&self) -> Self {
        Self {
            server_handler: Arc::clone(&self.server_handler),
            message_states: self.message_states.clone(),
        }
    }
}
impl<H: RpcServerHandler> HandleSocket for RpcSocketHandler<H> {
    type PacketHandler = RpcPacketHandler<H::ReceiveRpc>;

    /// Sets up RPC state for a newly connected socket peer.
    ///
    /// Records a fresh [`RpcMessageState`] under the peer's id, gives the user handler a
    /// sender built on that state, and returns a packet handler that shares the same
    /// state. Always returns `Some`.
    async fn on_socket_connect(&self, peer: &SocketPeer) -> Option<Self::PacketHandler> {
        let peer_id = peer.peer_id().clone();
        let state = RpcMessageState::new();
        self.message_states.insert(peer_id, state.clone()).await;
        let rpc_sender = H::SendRpc::new(RpcPeer::new(peer.clone(), state.clone()));
        let handler = self.server_handler.on_rpc_connect(&rpc_sender).await;
        // Last use of `state`: move it into the packet handler instead of cloning.
        Some(RpcPacketHandler::new(handler, state))
    }

    /// Forwards a socket disconnect to the user's handler.
    async fn on_socket_disconnect(&self, peer_id: &str) {
        // NOTE(review): the state inserted into `message_states` on connect is never
        // removed, so one entry accumulates per peer that ever connected; remove it here
        // if `SharedMap` provides a removal method.
        self.server_handler.on_rpc_disconnect(peer_id).await;
    }
}
/// Server-side callbacks for the RPC layer.
///
/// Implementors choose the per-connection receive handler (`ReceiveRpc`) and the typed
/// sender (`SendRpc`) built for each connected peer.
pub trait RpcServerHandler: Send + Sync + 'static {
    /// Handler for one connection's incoming RPC messages.
    type ReceiveRpc: ReceiveRpcProtocol;
    /// Typed sender constructed for each connected peer.
    type SendRpc: SendRpcProtocol;

    /// Called when a peer connects; returns the handler for that peer's messages.
    fn on_rpc_connect(
        &self,
        rpc_sender: &Self::SendRpc,
    ) -> impl std::future::Future<Output = Self::ReceiveRpc> + Send + Sync;

    /// Readiness hook (not invoked in this file); the default completes immediately.
    fn on_rpc_ready(&self) -> impl std::future::Future<Output = ()> + Send + Sync {
        std::future::ready(())
    }

    /// Called when a peer disconnects; the default completes immediately.
    fn on_rpc_disconnect(
        &self,
        _peer_id: &str,
    ) -> impl std::future::Future<Output = ()> + Send + Sync {
        std::future::ready(())
    }
}
impl RpcServer {
    /// Binds a socket server at `addr` whose connections are served by `server_handler`.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying socket server cannot be bound.
    pub fn bind_unix(
        addr: &crate::SocketAddr,
        server_handler: impl RpcServerHandler,
    ) -> Result<Self> {
        let message_states = SharedMap::new();
        let stream_server = SocketServer::bind_uds(
            addr.clone(),
            RpcSocketHandler {
                server_handler: Arc::new(server_handler),
                message_states: message_states.clone(),
            },
        )?;
        Ok(Self {
            stream_server,
            message_states,
        })
    }

    /// Stops the server by cancelling the underlying socket server.
    pub fn shutdown(&self) {
        let token = &self.stream_server.cancel;
        token.cancel();
    }
}
impl std::fmt::Debug for RpcServer {
    /// Prints only the type name; no fields are included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RpcServer");
        builder.finish()
    }
}
/// Client-side lifecycle callbacks for an RPC connection.
///
/// NOTE(review): `RpcClient::connect*` in this file take a `ReceiveRpcProtocol`
/// directly and do not use this trait.
pub trait HandleRpcClient: Send + Sync + 'static {
    /// Handler for the connection's incoming RPC messages.
    type ReceiveRpc: ReceiveRpcProtocol;

    /// Produces the handler for incoming messages on connect.
    fn on_rpc_connect(&self) -> impl std::future::Future<Output = Self::ReceiveRpc> + Send + Sync;

    /// Readiness hook; the default completes immediately.
    fn on_rpc_ready(&self) -> impl std::future::Future<Output = ()> + Send + Sync {
        std::future::ready(())
    }

    /// Disconnect hook; the default completes immediately.
    fn on_rpc_disconnect(&self) -> impl std::future::Future<Output = ()> + Send + Sync {
        std::future::ready(())
    }
}
/// A connected RPC client: a typed sender plus the underlying socket peer.
pub struct RpcClient<S: SendRpcProtocol> {
    /// Typed sender for outgoing RPC messages, shared through an `Arc`.
    pub sender: Arc<S>,
    /// Socket peer for this client's connection.
    pub peer: SocketPeer,
}

// Implemented by hand: `#[derive(Clone)]` would add an `S: Clone` bound, although
// `Arc<S>` is cloneable for every `S`.
impl<S: SendRpcProtocol> Clone for RpcClient<S> {
    fn clone(&self) -> Self {
        Self {
            sender: Arc::clone(&self.sender),
            peer: self.peer.clone(),
        }
    }
}
impl<S: SendRpcProtocol> RpcClient<S> {
pub async fn connect_unix(
path: impl AsRef<Path>,
client_handler: impl ReceiveRpcProtocol,
) -> Result<Self> {
let state = RpcMessageState::new();
let packet_handler = RpcPacketHandler::new(client_handler, state.clone());
let socket_client = SocketClient::connect_file(path, packet_handler).await?;
let peer = socket_client.peer;
let rpc_peer = RpcPeer::new(peer.clone(), state);
Ok(Self {
sender: Arc::new(S::new(rpc_peer)),
peer,
})
}
pub async fn connect(
addr: crate::SocketAddr,
client_handler: impl ReceiveRpcProtocol,
) -> Result<Self> {
let state = RpcMessageState::new();
let packet_handler = RpcPacketHandler::new(client_handler, state.clone());
let socket_client = SocketClient::connect(addr, packet_handler).await?;
let peer = socket_client.peer;
let rpc_peer = RpcPeer::new(peer.clone(), state);
Ok(Self {
sender: Arc::new(S::new(rpc_peer)),
peer,
})
}
pub fn clone_sender(&self) -> Arc<S> {
self.sender.clone()
}
}