use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use super::metrics::FoldStats;
use super::state::ApplyOutcome;
use super::wire::SignedAnnouncement;
use super::wire::WireError;
use super::{Fold, FoldKind};
use crate::adapter::net::identity::EntityId;
/// Subprotocol identifier (`0x1000`) associated with fold traffic.
///
/// NOTE(review): not referenced by the code shown here; presumably used by the transport
/// layer to tag/route fold-channel frames to a `FoldChannelRouter` — confirm at call sites.
pub const SUBPROTOCOL_FOLD: u16 = 0x1000;
pub trait FoldDispatch: Send + Sync {
fn kind_id(&self) -> u16;
fn dispatch(
&self,
bytes: &[u8],
publisher: &crate::adapter::net::identity::EntityId,
) -> Result<ApplyOutcome, WireError>;
fn stats(&self) -> FoldStats;
}
/// Wraps a concrete `Fold<K>` so it can be stored behind `Arc<dyn FoldDispatch>`
/// (this is what `FoldRegistry::register` inserts).
pub struct FoldDispatchAdapter<K>
where
    K: FoldKind,
{
    fold: Arc<Fold<K>>,
}

impl<K> FoldDispatchAdapter<K>
where
    K: FoldKind,
{
    /// Wraps the shared `fold` handle without cloning the fold itself.
    pub fn new(fold: Arc<Fold<K>>) -> Self {
        FoldDispatchAdapter { fold }
    }

    /// Borrows the wrapped fold handle; callers can `Arc::clone` it if they need ownership.
    pub fn fold(&self) -> &Arc<Fold<K>> {
        &self.fold
    }
}
impl<K: FoldKind> FoldDispatch for FoldDispatchAdapter<K> {
fn kind_id(&self) -> u16 {
K::KIND_ID
}
fn stats(&self) -> FoldStats {
self.fold.stats()
}
fn dispatch(
&self,
bytes: &[u8],
publisher: &crate::adapter::net::identity::EntityId,
) -> Result<ApplyOutcome, WireError> {
let ann = SignedAnnouncement::<K::Payload>::decode_and_verify(bytes, publisher)?;
if ann.kind != K::KIND_ID {
return Err(WireError::KindMismatch {
got: ann.kind,
expected: K::KIND_ID,
});
}
Ok(self.fold.apply(ann)?)
}
}
/// Registry of type-erased folds, keyed by their `u16` kind id.
///
/// Registration and removal take the write lock. Lookups take the read lock and hand out
/// cloned `Arc`s, so fold code (dispatch, stats) never runs while the lock is held.
pub struct FoldRegistry {
    /// Registered dispatchers, keyed by the `K::KIND_ID` of the fold each one wraps.
    folds: RwLock<HashMap<u16, Arc<dyn FoldDispatch>>>,
}

impl FoldRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            folds: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `fold` under `K::KIND_ID`, wrapping it in a [`FoldDispatchAdapter`].
    ///
    /// If a fold was already registered for that kind, it is replaced and the previous
    /// dispatcher is returned.
    pub fn register<K: FoldKind>(&self, fold: Arc<Fold<K>>) -> Option<Arc<dyn FoldDispatch>> {
        let adapter = Arc::new(FoldDispatchAdapter::new(fold));
        self.folds
            .write()
            .insert(K::KIND_ID, adapter as Arc<dyn FoldDispatch>)
    }

    /// Removes and returns the dispatcher registered for `kind`, if any.
    pub fn deregister(&self, kind: u16) -> Option<Arc<dyn FoldDispatch>> {
        self.folds.write().remove(&kind)
    }

    /// Number of registered kinds.
    pub fn len(&self) -> usize {
        self.folds.read().len()
    }

    /// `true` when no fold is registered.
    pub fn is_empty(&self) -> bool {
        self.folds.read().is_empty()
    }

    /// Returns a clone of the dispatcher registered for `kind`, if any.
    ///
    /// The read lock is released before this returns.
    pub fn get(&self, kind: u16) -> Option<Arc<dyn FoldDispatch>> {
        self.folds.read().get(&kind).cloned()
    }

    /// Collects a stats snapshot from every registered fold, in ascending kind order.
    pub fn stats(&self) -> Vec<FoldStats> {
        // Snapshot the dispatchers and drop the read lock *before* calling into them,
        // matching `dispatch` (which goes through `get`). parking_lot's RwLock is
        // task-fair: a `stats()` impl that re-entered this registry could deadlock
        // behind a queued writer, and holding the lock would also stall (de)registration.
        let mut snapshot: Vec<(u16, Arc<dyn FoldDispatch>)> = self
            .folds
            .read()
            .iter()
            .map(|(&kind, adapter)| (kind, Arc::clone(adapter)))
            .collect();
        // HashMap iteration order is arbitrary; sort so callers get a stable order.
        snapshot.sort_unstable_by_key(|&(kind, _)| kind);
        snapshot.iter().map(|(_, adapter)| adapter.stats()).collect()
    }

    /// Routes one encoded announcement to the fold registered for its kind.
    ///
    /// The kind is peeked from the front of `bytes`. The matching dispatcher then
    /// decodes the full announcement, verifies it against `publisher`, and applies it.
    ///
    /// # Errors
    /// - [`DispatchError::Truncated`] if the leading kind cannot be decoded;
    /// - [`DispatchError::UnknownKind`] if no fold is registered for that kind;
    /// - [`DispatchError::Wire`] for any failure reported by the dispatcher.
    pub fn dispatch(
        &self,
        bytes: &[u8],
        publisher: &crate::adapter::net::identity::EntityId,
    ) -> Result<ApplyOutcome, DispatchError> {
        let kind = peek_kind(bytes).ok_or(DispatchError::Truncated)?;
        // `get` clones the Arc and releases the lock, so dispatch runs lock-free.
        let adapter = self.get(kind).ok_or(DispatchError::UnknownKind(kind))?;
        adapter
            .dispatch(bytes, publisher)
            .map_err(DispatchError::Wire)
    }
}
impl Default for FoldRegistry {
fn default() -> Self {
Self::new()
}
}
/// Routes raw fold-channel bytes from a publisher to a fold and reports fold metrics.
///
/// Implemented by `FoldRegistry`.
pub trait FoldChannelRouter: Send + Sync {
    /// Per-fold metrics for every fold reachable through this router.
    fn stats(&self) -> Vec<FoldStats>;

    /// Attempts to route one encoded announcement from `publisher`.
    ///
    /// # Errors
    /// Returns a [`DispatchError`] when the bytes cannot be routed or applied.
    fn try_route(
        &self,
        publisher: &EntityId,
        bytes: &[u8],
    ) -> Result<ApplyOutcome, DispatchError>;
}
impl FoldChannelRouter for FoldRegistry {
    fn stats(&self) -> Vec<FoldStats> {
        // Explicit type path so this calls the inherent method, not this trait method.
        FoldRegistry::stats(self)
    }

    fn try_route(
        &self,
        publisher: &EntityId,
        bytes: &[u8],
    ) -> Result<ApplyOutcome, DispatchError> {
        // The trait takes (publisher, bytes); the inherent method takes (bytes, publisher).
        FoldRegistry::dispatch(self, bytes, publisher)
    }
}
/// Failure modes of `FoldRegistry::dispatch` / `FoldChannelRouter::try_route`.
#[derive(Debug, thiserror::Error)]
pub enum DispatchError {
    /// `peek_kind` could not decode a leading `u16` kind from the envelope.
    ///
    /// NOTE(review): `peek_kind` discards the decode error, so this variant also covers a
    /// malformed/overlong kind encoding, not only truncation. The message may overstate.
    #[error("envelope truncated before kind varint completes")]
    Truncated,
    /// No fold is registered for the peeked kind id.
    /// The Display form prints the id as 0x-prefixed, zero-padded hex (e.g. `0x1000`).
    #[error("no fold registered for kind {0:#06x}")]
    UnknownKind(u16),
    /// The dispatcher for the kind reported an error (decode, verification, kind check, or
    /// apply). `#[from]` lets `?` convert a `WireError` directly.
    #[error("wire / verify / apply failed: {0}")]
    Wire(#[from] WireError),
}
/// Reads the fold kind from the front of `bytes` without decoding the rest.
///
/// Returns `None` on any postcard decode failure of the leading `u16`. The remaining
/// bytes are ignored; the full envelope is decoded later by the matching dispatcher.
fn peek_kind(bytes: &[u8]) -> Option<u16> {
    postcard::take_from_bytes::<u16>(bytes)
        .ok()
        .map(|(kind, _remainder)| kind)
}