use anyhow::{Result, anyhow};
use std::collections::HashMap;
use std::sync::Arc;
/// Result of handling a single packet.
///
/// `Ok(Some(bytes))` versus `Ok(None)` presumably distinguishes "send this reply
/// payload" from "no reply" — NOTE(review): confirm against the code that consumes
/// `ReceiveRpcProtocol::handle_packet`. `Err` reports a handling failure.
pub type PacketResult = Result<Option<Vec<u8>>>;
/// A handler for one protocol.
///
/// Handlers are stored type-erased (`Arc<dyn ProtocolHandler>`) in a
/// `ProtocolRegistry`, which calls them for packets carrying their protocol ID.
/// Implementors must be `Send + Sync`. The method returns a boxed future instead of
/// being an `async fn`, which keeps the trait usable as `dyn ProtocolHandler`.
pub trait ProtocolHandler: Send + Sync {
    /// Handles one packet `buf` received from `peer`.
    ///
    /// Because of lifetime elision, the `'_` on the returned future is tied to `&self`
    /// only. The future may borrow the handler but not `peer`, so copy whatever
    /// is needed from `peer` before constructing the future.
    fn handle_packet(
        &self,
        peer: &tokio_socket::SocketPeer,
        buf: Vec<u8>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = PacketResult> + Send + '_>>;
}
/// Table mapping protocol IDs to their handlers and display names.
///
/// Build one with `ProtocolRegistry::builder()`. Both maps sit behind `Arc`, so
/// `clone()` only bumps reference counts and every clone shares the same data.
#[derive(Clone)]
pub struct ProtocolRegistry {
    /// protocol_id -> handler; used by `handle_packet` to dispatch incoming packets.
    handlers: Arc<HashMap<u64, Arc<dyn ProtocolHandler>>>,
    /// protocol_id -> human-readable name; populated alongside `handlers` by the builder.
    names: Arc<HashMap<u64, String>>,
}
impl ProtocolRegistry {
    /// Starts building a registry; equivalent to `ProtocolRegistryBuilder::new()`.
    #[must_use]
    pub fn builder() -> ProtocolRegistryBuilder {
        ProtocolRegistryBuilder::new()
    }

    /// Returns every registered protocol as `(protocol_id, display_name)` pairs,
    /// sorted by ascending protocol ID.
    ///
    /// The backing `HashMap` uses the default randomized hasher, so its iteration
    /// order is unspecified and varies between runs. The list is therefore
    /// sorted, so logs, diagnostics and comparisons are reproducible.
    #[must_use]
    pub fn list_protocols(&self) -> Vec<(u64, String)> {
        let mut protocols: Vec<(u64, String)> = self
            .names
            .iter()
            .map(|(&id, name)| (id, name.clone()))
            .collect();
        // IDs are unique map keys, so an unstable sort still yields a deterministic order.
        protocols.sort_unstable_by_key(|&(id, _)| id);
        protocols
    }
}
impl super::ReceiveRpcProtocol for ProtocolRegistry {
    /// Dispatches `buf` from `peer` to the handler registered for `protocol_id`.
    ///
    /// # Errors
    ///
    /// Returns an `Unknown protocol ID: 0x…` error when nothing is registered under
    /// `protocol_id`. Otherwise this returns exactly what the matching handler returns.
    async fn handle_packet(
        &self,
        protocol_id: u64,
        peer: &tokio_socket::SocketPeer,
        buf: Vec<u8>,
    ) -> Result<Option<Vec<u8>>> {
        let Some(target) = self.handlers.get(&protocol_id) else {
            return Err(anyhow!("Unknown protocol ID: 0x{:016x}", protocol_id));
        };
        target.handle_packet(peer, buf).await
    }
}
/// Accumulates protocol handlers before freezing them into a `ProtocolRegistry`.
///
/// Create one with `ProtocolRegistryBuilder::new()`, `Default::default()` or
/// `ProtocolRegistry::builder()`. Add protocols with `register`, `register_auto` or
/// `with`, then call `build`.
pub struct ProtocolRegistryBuilder {
    /// protocol_id -> handler; moved into the registry by `build`.
    handlers: HashMap<u64, Arc<dyn ProtocolHandler>>,
    /// protocol_id -> display name; `register` also consults it to detect ID collisions.
    names: HashMap<u64, String>,
}
impl ProtocolRegistryBuilder {
    /// Creates an empty builder with no protocols registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            handlers: HashMap::new(),
            names: HashMap::new(),
        }
    }

    /// Registers `handler` under `protocol_id` with the display name `name`.
    ///
    /// The builder is consumed and returned, so the result must be used (or
    /// chained). Otherwise the registration is lost.
    ///
    /// # Panics
    ///
    /// Panics if a protocol is already registered under `protocol_id`. The
    /// message names both the existing and the new protocol.
    #[must_use]
    pub fn register<H>(mut self, protocol_id: u64, name: String, handler: H) -> Self
    where
        H: ProtocolHandler + 'static,
    {
        use std::collections::hash_map::Entry;

        // One lookup either claims the name slot or reports the collision.
        match self.names.entry(protocol_id) {
            Entry::Occupied(existing) => panic!(
                "Protocol ID collision! {} and {} both have ID 0x{:016x}",
                existing.get(),
                name,
                protocol_id
            ),
            Entry::Vacant(slot) => {
                slot.insert(name);
            }
        }
        // `names` guards against collisions, so this insert never overwrites a handler.
        self.handlers.insert(protocol_id, Arc::new(handler));
        self
    }

    /// Registers `handler` using the ID and metadata published by `P`.
    ///
    /// The display name is `"<PROTOCOL_CRATE>::<PROTOCOL_NAME> v<PROTOCOL_VERSION>"`.
    ///
    /// # Panics
    ///
    /// Panics if `P::PROTOCOL_ID` is already registered (see `register`).
    #[must_use]
    pub fn register_auto<P, H>(self, handler: H) -> Self
    where
        P: super::ProtocolMetadata,
        H: ProtocolHandler + 'static,
    {
        let name = format!(
            "{}::{} v{}",
            P::PROTOCOL_CRATE,
            P::PROTOCOL_NAME,
            P::PROTOCOL_VERSION
        );
        self.register(P::PROTOCOL_ID, name, handler)
    }

    /// Registers `handler` under `protocol_id` with the display name
    /// `"<protocol_crate>::<protocol_name>"`, which has no version suffix.
    ///
    /// # Panics
    ///
    /// Panics if `protocol_id` is already registered (see `register`).
    #[must_use]
    pub fn with(
        self,
        protocol_id: u64,
        protocol_crate: &str,
        protocol_name: &str,
        handler: impl ProtocolHandler + 'static,
    ) -> Self {
        let name = format!("{}::{}", protocol_crate, protocol_name);
        self.register(protocol_id, name, handler)
    }

    /// Freezes the collected handlers and names into a shareable `ProtocolRegistry`.
    ///
    /// The maps are moved into `Arc`s, not copied.
    #[must_use]
    pub fn build(self) -> ProtocolRegistry {
        ProtocolRegistry {
            handlers: Arc::new(self.handlers),
            names: Arc::new(self.names),
        }
    }
}
impl Default for ProtocolRegistryBuilder {
fn default() -> Self {
Self::new()
}
}
/// Builds a `ProtocolRegistry` from `protocol => handler` pairs.
///
/// Each `protocol` is an identifier in scope (typically a generated protocol module)
/// exposing `PROTOCOL_ID`, `PROTOCOL_CRATE`, `PROTOCOL_NAME` and a
/// `ReceiverWrapper::new` constructor. Each handler is wrapped with
/// `protocol::ReceiverWrapper::new` and registered under the display name
/// `"<PROTOCOL_CRATE>::<PROTOCOL_NAME>"`. This name has no version suffix,
/// unlike `ProtocolRegistryBuilder::register_auto`.
///
/// # Panics
///
/// Panics on duplicate protocol IDs, via `ProtocolRegistryBuilder::register`.
#[macro_export]
macro_rules! registry {
    ($($protocol:ident => $handler:expr),* $(,)?) => {
        $crate::ProtocolRegistryBuilder::new()
            $(
                .register(
                    $protocol::PROTOCOL_ID,
                    // Fully qualified: an unqualified `format!` resolves at the call
                    // site and could be shadowed by a caller's own `format!` macro.
                    ::std::format!("{}::{}", $protocol::PROTOCOL_CRATE, $protocol::PROTOCOL_NAME),
                    $protocol::ReceiverWrapper::new($handler),
                )
            )*
            .build()
    };
}
/// Builds a `ProtocolRegistry` that serves every listed protocol with one handler.
///
/// `$handler` is evaluated exactly once. It is then cloned once per listed protocol,
/// and each clone is wrapped with `protocol::ReceiverWrapper::new`. Each protocol is
/// registered under the display name `"<PROTOCOL_CRATE>::<PROTOCOL_NAME>"`, with no
/// version suffix.
///
/// # Panics
///
/// Panics on duplicate protocol IDs, via `ProtocolRegistryBuilder::register`.
#[macro_export]
macro_rules! registry_for {
    ($handler:expr, [$($protocol:ident),* $(,)?]) => {
        {
            let h = $handler;
            $crate::ProtocolRegistryBuilder::new()
                $(
                    .register(
                        $protocol::PROTOCOL_ID,
                        // Fully qualified: an unqualified `format!` resolves at the call
                        // site and could be shadowed by a caller's own `format!` macro.
                        ::std::format!("{}::{}", $protocol::PROTOCOL_CRATE, $protocol::PROTOCOL_NAME),
                        $protocol::ReceiverWrapper::new(h.clone()),
                    )
                )*
                .build()
        }
    };
}