use async_trait::async_trait;
/// Value returned by [`MessageMiddleware::on_message`].
///
/// The type is a plain two-variant, field-less enum that is cheap to copy
/// and can be compared for equality.
///
/// NOTE(review): judging by the variant names, the caller presumably keeps
/// processing a message on `Continue` and halts on `Stop`. The code that
/// consumes this value is not part of this file — confirm against it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MiddlewareResult {
    /// Presumably: let processing of the message proceed (confirm with caller).
    Continue,
    /// Presumably: halt further processing of the message (confirm with caller).
    Stop,
}
/// Asynchronous callbacks a message middleware provides.
///
/// Only [`on_message`](MessageMiddleware::on_message) is required; the
/// connect and disconnect callbacks have default implementations.
/// Implementors must be `Send + Sync`.
///
/// NOTE(review): the code that invokes these callbacks is not visible in this
/// file, so when each one fires and how return values are acted on should be
/// confirmed against the caller.
#[async_trait]
pub trait MessageMiddleware: Send + Sync {
    /// Callback for a peer identified by `_peer`.
    ///
    /// The default implementation ignores its argument and returns `true`.
    /// Presumably `true` admits the peer and `false` rejects it — confirm
    /// with the caller.
    async fn on_connect(&self, _peer: &str) -> bool {
        true
    }

    /// Callback receiving the raw bytes `data` associated with `peer`.
    ///
    /// Required method: every implementor decides which
    /// [`MiddlewareResult`] to return.
    async fn on_message(&self, peer: &str, data: &[u8]) -> MiddlewareResult;

    /// Callback for a peer identified by `_peer`.
    ///
    /// The default implementation does nothing.
    async fn on_disconnect(&self, _peer: &str) {
        // Intentionally empty: implementors override this only when needed.
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Peer identifier shared by the async tests.
    const PEER: &str = "peer1";
    /// Message payload shared by the async tests.
    const PAYLOAD: &[u8] = b"hello";

    /// Middleware that relies on the trait defaults and always answers
    /// `Continue` from `on_message`.
    struct AllowAllMiddleware;

    #[async_trait]
    impl MessageMiddleware for AllowAllMiddleware {
        async fn on_message(&self, _peer: &str, _data: &[u8]) -> MiddlewareResult {
            MiddlewareResult::Continue
        }
    }

    /// Middleware that overrides `on_connect` to return `false` and always
    /// answers `Stop` from `on_message`.
    struct RejectAllMiddleware;

    #[async_trait]
    impl MessageMiddleware for RejectAllMiddleware {
        async fn on_connect(&self, _peer: &str) -> bool {
            false
        }

        async fn on_message(&self, _peer: &str, _data: &[u8]) -> MiddlewareResult {
            MiddlewareResult::Stop
        }
    }

    /// The default `on_connect` returns `true`.
    #[tokio::test]
    async fn test_default_on_connect_allows() {
        let middleware = AllowAllMiddleware;
        let accepted = middleware.on_connect(PEER).await;
        assert!(accepted);
    }

    /// The default `on_disconnect` can be awaited and completes.
    #[tokio::test]
    async fn test_default_on_disconnect_is_noop() {
        let middleware = AllowAllMiddleware;
        middleware.on_disconnect(PEER).await;
    }

    /// `AllowAllMiddleware` returns `true` on connect and `Continue` on message.
    #[tokio::test]
    async fn test_allow_all_middleware() {
        let middleware = AllowAllMiddleware;

        let accepted = middleware.on_connect(PEER).await;
        assert!(accepted);

        let verdict = middleware.on_message(PEER, PAYLOAD).await;
        assert_eq!(verdict, MiddlewareResult::Continue);
    }

    /// `RejectAllMiddleware` returns `false` on connect and `Stop` on message.
    #[tokio::test]
    async fn test_reject_all_middleware() {
        let middleware = RejectAllMiddleware;

        let accepted = middleware.on_connect(PEER).await;
        assert!(!accepted);

        let verdict = middleware.on_message(PEER, PAYLOAD).await;
        assert_eq!(verdict, MiddlewareResult::Stop);
    }

    /// Each variant equals itself and differs from the other.
    #[test]
    fn test_middleware_result_equality() {
        let proceed = MiddlewareResult::Continue;
        let halt = MiddlewareResult::Stop;

        assert_eq!(proceed, MiddlewareResult::Continue);
        assert_eq!(halt, MiddlewareResult::Stop);
        assert_ne!(proceed, halt);
    }

    /// The enum is `Copy`: the source binding stays usable after assignment.
    #[test]
    fn test_middleware_result_copy() {
        let original = MiddlewareResult::Continue;
        let duplicate = original;
        assert_eq!(original, duplicate);
    }
}