use std::sync;
use anyhow::{Result, anyhow, bail};
use lockshed::{PromiseId, PromiseStore};
use serde::{Deserialize, Serialize};
/// Numeric handler identifier. NOTE(review): not referenced in the code shown here — confirm its consumers.
pub type HandlerId = u64;
/// Correlates a request with its response. This is the promise-store id type, so the
/// `packet_id` of an incoming response is looked up directly in the response promise store.
pub type RequestId = PromiseId;
/// Raw byte payload.
pub type Buffer = Vec<u8>;
/// Kind of an `RpcPacket`. The explicit discriminant is the type byte that starts the
/// wire format (see `to_u8` / `from_u8`), so the values must not change.
/// Avoid reordering the variants too: the derived serde impls pass the variant index,
/// which some serde formats encode.
#[derive(Copy, Clone, Serialize, Deserialize, PartialEq)]
#[repr(u8)]
pub enum RpcPacketType {
// Sent by the requester; carries the request payload.
Request = 1,
// Sent back to the requester; carries the reply payload and the request's id.
Response = 2,
}
impl RpcPacketType {
    /// Returns the one-byte wire tag (the enum discriminant) for this packet type.
    pub fn to_u8(&self) -> u8 {
        *self as u8
    }

    /// Decodes a wire tag produced by [`RpcPacketType::to_u8`].
    ///
    /// # Errors
    /// Fails when `val` is not the tag of any known variant.
    pub fn from_u8(val: u8) -> Result<Self> {
        // Const patterns tie the accepted tags to the declared discriminants.
        const REQUEST_TAG: u8 = RpcPacketType::Request as u8;
        const RESPONSE_TAG: u8 = RpcPacketType::Response as u8;
        match val {
            REQUEST_TAG => Ok(Self::Request),
            RESPONSE_TAG => Ok(Self::Response),
            _ => bail!("Unknown message type {val}"),
        }
    }

    /// Returns the variant name as a static string.
    pub fn to_string(&self) -> &'static str {
        match *self {
            Self::Request => "Request",
            Self::Response => "Response",
        }
    }
}
impl std::fmt::Debug for RpcPacketType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.to_string())
}
}
/// Promises for responses that have not arrived yet. A resolver is fetched by request id
/// and resolved with the raw response payload bytes.
pub type ResponsePromiseStore = PromiseStore<Vec<u8>>;
/// One framed RPC message: a request or a response, its routing ids, and its payload.
#[derive(Clone, Debug)]
pub struct RpcPacket {
    /// Whether this packet is a request or a response.
    pub packet_type: RpcPacketType,
    /// Request id; a response carries the id of the request it answers.
    pub packet_id: RequestId,
    /// Identifier passed through to the receiving handler to select the protocol.
    pub protocol_id: u64,
    /// Opaque payload bytes.
    pub buf: Vec<u8>,
}
impl RpcPacket {
fn new(packet_type: RpcPacketType, packet_id: RequestId, protocol_id: u64, buf: Vec<u8>) -> Self {
Self {
packet_type,
packet_id,
protocol_id,
buf,
}
}
pub fn new_request(request_id: RequestId, protocol_id: u64, buf: Vec<u8>) -> Self {
Self::new(RpcPacketType::Request, request_id, protocol_id, buf)
}
pub fn new_response(request_id: RequestId, protocol_id: u64, buf: Vec<u8>) -> Self {
Self::new(RpcPacketType::Response, request_id, protocol_id, buf)
}
}
impl tokio_socket::PacketProtocol for RpcPacket {
fn to_bytes(&self) -> Result<Vec<u8>> {
let buf_len = self.buf.len();
let mut bytes = Vec::with_capacity(
1 + 8 + 8 + 8 + buf_len,
);
bytes.push(self.packet_type.to_u8());
bytes.extend(&self.packet_id.to_be_bytes());
bytes.extend(&self.protocol_id.to_be_bytes());
bytes.extend(&(buf_len as u64).to_be_bytes());
bytes.extend(&self.buf);
Ok(bytes)
}
fn from_bytes(bytes: &[u8]) -> Result<Self> {
if bytes.len() < 25 {
bail!("Packet too small: {} bytes", bytes.len());
}
let packet_kind = RpcPacketType::from_u8(bytes[0])
.map_err(|e| anyhow!("Invalid message type: {e}"))?;
let packet_id = u64::from_be_bytes(
bytes[1..9]
.try_into()
.map_err(|e| anyhow!("Invalid message_id: {e}"))?
);
let protocol_id = u64::from_be_bytes(
bytes[9..17]
.try_into()
.map_err(|e| anyhow!("Invalid protocol_id: {e}"))?
);
let buf_len = u64::from_be_bytes(
bytes[17..25]
.try_into()
.map_err(|e| anyhow!("Invalid buf_len: {e}"))?
) as usize;
if bytes.len() < 25 + buf_len {
bail!("Packet truncated: expected {} bytes, got {}", 25 + buf_len, bytes.len());
}
let buf = bytes[25..25 + buf_len].to_vec();
Ok(RpcPacket::new(packet_kind, packet_id, protocol_id, buf))
}
}
/// Handle that exposes the `crate::RpcPeer` it is bound to.
/// NOTE(review): presumably the client-side handle used to issue RPCs — confirm against implementors.
pub trait RpcProtocolSender: Clone + Send + Sync + 'static {
/// Returns the peer this handle is bound to.
fn peer(&self) -> &crate::RpcPeer;
}
/// Handler for incoming RPC request payloads.
///
/// `RpcPacketHandler` calls [`handle_packet`](Self::handle_packet) for every `Request`
/// packet it receives.
pub trait ReceiveRpcProtocol: Clone + Send + Sync + 'static {
/// Handles one request payload `buf` received from `peer` for `protocol_id`.
///
/// When driven by `RpcPacketHandler`:
/// - `Ok(Some(bytes))` is written back to `peer` as a `Response` packet carrying the
///   original request id.
/// - `Ok(None)` sends nothing.
/// - `Err` is propagated out of `on_packet` and no response is sent.
///
/// The returned future must be `Send`.
fn handle_packet(
&self,
protocol_id: u64,
peer: &tokio_socket::SocketPeer,
buf: Vec<u8>,
) -> impl std::future::Future<Output = Result<Option<Vec<u8>>>> + Send;
}
/// Protocol type that can be constructed from a `crate::RpcPeer`.
pub trait SendRpcProtocol: Send + Sync + 'static {
/// Builds a protocol instance from `peer`.
///
/// The `Self: Sized` bound keeps this constructor out of the vtable, so the trait can
/// still be used as a trait object.
fn new(peer: crate::RpcPeer) -> Self
where
Self: Sized;
}
#[derive(Clone)]
pub struct RpcMessageState {
pub response_promise_store: ResponsePromiseStore,
request_id_counter: sync::Arc<sync::atomic::AtomicU64>,
}
impl RpcMessageState {
    /// Creates fresh state: a new promise store and a request counter starting at 0.
    pub fn new() -> Self {
        let request_id_counter = sync::Arc::new(sync::atomic::AtomicU64::new(0));
        let response_promise_store = ResponsePromiseStore::new();
        Self {
            response_promise_store,
            request_id_counter,
        }
    }
}
impl Default for RpcMessageState {
fn default() -> Self {
Self::new()
}
}
impl RpcMessageState {
pub fn next_request_id(&self) -> u64 {
self.request_id_counter
.fetch_add(1, sync::atomic::Ordering::SeqCst)
}
}
/// Socket-level packet handler that dispatches RPC requests to a [`ReceiveRpcProtocol`]
/// and resolves pending promises when responses arrive.
#[derive(Clone)]
pub struct RpcPacketHandler<H>
where
    H: ReceiveRpcProtocol,
{
    /// Application handler invoked for each incoming request.
    packet_handler: H,
    /// State whose promise store is used to resolve incoming responses.
    state: RpcMessageState,
}
impl<H> std::fmt::Debug for RpcPacketHandler<H>
where
    H: ReceiveRpcProtocol,
{
    /// Prints only the type name, so neither field needs to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("RpcPacketHandler")
    }
}
impl<H: ReceiveRpcProtocol> RpcPacketHandler<H> {
pub fn new(packet_handler: H, state: RpcMessageState) -> Self {
Self {
packet_handler,
state,
}
}
}
impl<H: ReceiveRpcProtocol + 'static> tokio_socket::HandlePacket for RpcPacketHandler<H> {
type Packet = RpcPacket;
async fn on_packet(&self, peer: tokio_socket::SocketPeer, packet: Self::Packet) -> Result<()> {
match packet.packet_type {
RpcPacketType::Request => {
tracing::trace!("recv request from {peer} protocol={} {packet:?}", packet.protocol_id);
let response = self.packet_handler
.handle_packet(packet.protocol_id, &peer, packet.buf)
.await?;
tracing::trace!("send response to {peer} {response:?}");
if let Some(buf) = response {
let response = RpcPacket::new_response(packet.packet_id, packet.protocol_id, buf);
tokio_socket::PacketWriter::write_packet(&peer, &response).await
} else {
Ok(())
}
}
RpcPacketType::Response => {
tracing::trace!("recv response from {peer} {packet:?}");
let resolver = self
.state
.response_promise_store
.get_resolver(packet.packet_id)
.await?;
tracing::trace!("found response resolver for request from {peer} {packet:?}");
resolver.resolve(packet.buf).await
}
}
}
}