use std::time::Duration;
use crate::remote::WireEnvelope;
/// Verdict a [`WireInterceptor`] returns for a single received envelope.
///
/// Anything other than `Accept` stops the interceptor pipeline at the
/// interceptor that produced it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WireDisposition {
    /// Let the envelope continue to the next interceptor in the pipeline.
    Accept,
    /// Ask for the envelope to be held for the given duration. The pipeline
    /// only reports this; presumably the caller performs the wait — confirm
    /// at the call site.
    Delay(Duration),
    /// Refuse the envelope; the string is a human-readable reason.
    Reject(String),
    /// Discard the envelope without reporting an error.
    Drop,
}
/// A hook that inspects a received [`WireEnvelope`] and decides its fate.
///
/// The `Send + Sync + 'static` bounds allow implementations to be stored as
/// boxed trait objects and shared across threads.
pub trait WireInterceptor: Send + Sync + 'static {
    /// Examines `envelope` and returns the verdict for it.
    fn on_receive(&self, envelope: &WireEnvelope) -> WireDisposition;

    /// Stable identifier reported when this interceptor's verdict stops the
    /// pipeline.
    fn name(&self) -> &'static str;
}
/// An ordered chain of [`WireInterceptor`]s applied to received envelopes.
///
/// Interceptors run in insertion order. The first one that returns anything
/// other than [`WireDisposition::Accept`] short-circuits the chain, and later
/// interceptors are not consulted for that envelope.
pub struct WireInterceptorPipeline {
    interceptors: Vec<Box<dyn WireInterceptor>>,
}

impl WireInterceptorPipeline {
    /// Creates a pipeline with no interceptors; it accepts every envelope.
    pub fn new() -> Self {
        Self {
            interceptors: Vec::new(),
        }
    }

    /// Appends `interceptor` to the end of the chain.
    pub fn add(&mut self, interceptor: impl WireInterceptor) {
        self.interceptors.push(Box::new(interceptor));
    }

    /// Runs `envelope` through the chain.
    ///
    /// Returns the first non-`Accept` disposition together with the
    /// [`WireInterceptor::name`] of the interceptor that produced it. If every
    /// interceptor accepts (including when the pipeline is empty), returns
    /// `(WireDisposition::Accept, None)`.
    pub fn process(&self, envelope: &WireEnvelope) -> (WireDisposition, Option<&'static str>) {
        self.interceptors
            .iter()
            // `find_map` stops at the first `Some`, so interceptors after the
            // deciding one are never called (matters for stateful ones).
            .find_map(|interceptor| match interceptor.on_receive(envelope) {
                WireDisposition::Accept => None,
                verdict => Some((verdict, Some(interceptor.name()))),
            })
            .unwrap_or((WireDisposition::Accept, None))
    }

    /// Runs `envelope` through the chain and maps the verdict to a `Result`.
    ///
    /// `Accept`, `Delay` and `Drop` become the corresponding
    /// [`WireProcessResult`] variants.
    ///
    /// # Errors
    ///
    /// Returns [`WireRejectError`] when an interceptor rejects the envelope.
    pub fn process_envelope(
        &self,
        envelope: &WireEnvelope,
    ) -> Result<WireProcessResult, WireRejectError> {
        let (disposition, interceptor_name) = self.process(envelope);
        match disposition {
            WireDisposition::Accept => Ok(WireProcessResult::Accepted),
            WireDisposition::Delay(d) => Ok(WireProcessResult::Delayed(d)),
            WireDisposition::Drop => Ok(WireProcessResult::Dropped),
            WireDisposition::Reject(reason) => Err(WireRejectError {
                // `process` always pairs a non-Accept verdict with a name; the
                // fallback only guards against that invariant changing.
                interceptor: interceptor_name.unwrap_or("unknown").to_string(),
                reason,
            }),
        }
    }

    /// Number of interceptors in the chain.
    pub fn len(&self) -> usize {
        self.interceptors.len()
    }

    /// Returns `true` when the chain has no interceptors.
    pub fn is_empty(&self) -> bool {
        self.interceptors.is_empty()
    }
}

impl Default for WireInterceptorPipeline {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for WireInterceptorPipeline {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn WireInterceptor` is not `Debug`, so report interceptors by name.
        let names: Vec<&'static str> = self.interceptors.iter().map(|i| i.name()).collect();
        f.debug_struct("WireInterceptorPipeline")
            .field("interceptors", &names)
            .finish()
    }
}
/// Non-error outcome of running an envelope through the interceptor pipeline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WireProcessResult {
    /// Every interceptor accepted the envelope.
    Accepted,
    /// An interceptor asked for the envelope to be delayed by this duration.
    Delayed(Duration),
    /// An interceptor asked for the envelope to be discarded.
    Dropped,
}
/// Error produced when a wire interceptor rejects an envelope.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WireRejectError {
    /// Name of the interceptor that rejected the envelope.
    pub interceptor: String,
    /// Human-readable reason supplied with the rejection.
    pub reason: String,
}

impl std::fmt::Display for WireRejectError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "wire interceptor '{}' rejected: {}",
            self.interceptor, self.reason
        )
    }
}

impl std::error::Error for WireRejectError {}
/// Rejects envelopes whose body is longer than a fixed length.
///
/// The limit is inclusive: a body of exactly `max_bytes` is accepted, and only
/// strictly longer bodies are rejected.
#[derive(Debug, Clone)]
pub struct MaxBodySizeInterceptor {
    max_bytes: usize,
}

impl MaxBodySizeInterceptor {
    /// Creates an interceptor that rejects bodies longer than `max_bytes`.
    pub fn new(max_bytes: usize) -> Self {
        Self { max_bytes }
    }
}

impl WireInterceptor for MaxBodySizeInterceptor {
    fn name(&self) -> &'static str {
        "max-body-size"
    }

    /// Accepts bodies up to the limit; otherwise rejects with a message that
    /// names both the actual size and the limit.
    fn on_receive(&self, envelope: &WireEnvelope) -> WireDisposition {
        let size = envelope.body.len();
        if size <= self.max_bytes {
            return WireDisposition::Accept;
        }
        WireDisposition::Reject(format!(
            "body size {} exceeds limit {}",
            size, self.max_bytes
        ))
    }
}
/// Fixed-window rate limiter for received envelopes.
///
/// At most `max_per_window` envelopes are accepted per `window`. When a call
/// arrives at least `window` after the current window started, the counter
/// resets and a new window begins. The envelope itself is not inspected, so
/// the limit is shared by all traffic through this interceptor instance.
#[derive(Debug)]
pub struct RateLimitWireInterceptor {
    max_per_window: u64,
    window: Duration,
    state: std::sync::Mutex<RateLimitState>,
}

/// Mutable counter state guarded by the interceptor's mutex.
#[derive(Debug)]
struct RateLimitState {
    /// Envelopes seen (accepted and rejected) since `window_start`.
    count: u64,
    /// When the current window began.
    window_start: std::time::Instant,
}

impl RateLimitWireInterceptor {
    /// Creates a limiter allowing `max_per_window` envelopes per `window`.
    ///
    /// The first window starts at construction time, not at the first
    /// received envelope.
    pub fn new(max_per_window: u64, window: Duration) -> Self {
        Self {
            max_per_window,
            window,
            state: std::sync::Mutex::new(RateLimitState {
                count: 0,
                window_start: std::time::Instant::now(),
            }),
        }
    }
}

impl WireInterceptor for RateLimitWireInterceptor {
    fn name(&self) -> &'static str {
        "rate-limit"
    }

    fn on_receive(&self, _envelope: &WireEnvelope) -> WireDisposition {
        // A poisoned lock only means another thread panicked while holding it;
        // the state is plain counters, so keep using it rather than panic.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let now = std::time::Instant::now();
        if now.duration_since(state.window_start) >= self.window {
            state.count = 0;
            state.window_start = now;
        }
        // Rejected envelopes are counted too. Saturate so a very long window
        // can never overflow the counter (which would panic in debug builds).
        state.count = state.count.saturating_add(1);
        if state.count > self.max_per_window {
            WireDisposition::Reject(format!(
                "rate limit exceeded: {} > {} per {:?}",
                state.count, self.max_per_window, self.window
            ))
        } else {
            WireDisposition::Accept
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::interceptor::SendMode;
    use crate::node::{ActorId, NodeId};
    use crate::remote::WireHeaders;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Builds a minimal `Tell` envelope whose body is `body_size` zero bytes.
    fn test_envelope(body_size: usize) -> WireEnvelope {
        WireEnvelope {
            target: ActorId {
                node: NodeId("n1".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::Msg".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: vec![0u8; body_size],
            request_id: None,
            version: None,
        }
    }

    struct AcceptAll;
    impl WireInterceptor for AcceptAll {
        fn name(&self) -> &'static str {
            "accept-all"
        }
        fn on_receive(&self, _: &WireEnvelope) -> WireDisposition {
            WireDisposition::Accept
        }
    }

    struct RejectAll;
    impl WireInterceptor for RejectAll {
        fn name(&self) -> &'static str {
            "reject-all"
        }
        fn on_receive(&self, _: &WireEnvelope) -> WireDisposition {
            WireDisposition::Reject("blocked".into())
        }
    }

    struct DropAll;
    impl WireInterceptor for DropAll {
        fn name(&self) -> &'static str {
            "drop-all"
        }
        fn on_receive(&self, _: &WireEnvelope) -> WireDisposition {
            WireDisposition::Drop
        }
    }

    struct DelayAll(Duration);
    impl WireInterceptor for DelayAll {
        fn name(&self) -> &'static str {
            "delay-all"
        }
        fn on_receive(&self, _: &WireEnvelope) -> WireDisposition {
            WireDisposition::Delay(self.0)
        }
    }

    /// Accepts everything but records how many times it was consulted.
    struct CountCalls(Arc<AtomicUsize>);
    impl WireInterceptor for CountCalls {
        fn name(&self) -> &'static str {
            "count-calls"
        }
        fn on_receive(&self, _: &WireEnvelope) -> WireDisposition {
            self.0.fetch_add(1, Ordering::SeqCst);
            WireDisposition::Accept
        }
    }

    #[test]
    fn empty_pipeline_accepts() {
        let pipeline = WireInterceptorPipeline::new();
        assert!(pipeline.is_empty());
        let (result, name) = pipeline.process(&test_envelope(10));
        assert!(matches!(result, WireDisposition::Accept));
        assert!(name.is_none());
    }

    #[test]
    fn pipeline_accept_all() {
        let mut pipeline = WireInterceptorPipeline::new();
        pipeline.add(AcceptAll);
        pipeline.add(AcceptAll);
        let (result, name) = pipeline.process(&test_envelope(10));
        assert!(matches!(result, WireDisposition::Accept));
        assert!(name.is_none());
    }

    #[test]
    fn pipeline_reject_short_circuits() {
        let mut pipeline = WireInterceptorPipeline::new();
        pipeline.add(AcceptAll);
        pipeline.add(RejectAll);
        pipeline.add(AcceptAll);
        let (result, name) = pipeline.process(&test_envelope(10));
        assert!(matches!(result, WireDisposition::Reject(_)));
        assert_eq!(name, Some("reject-all"));
    }

    #[test]
    fn pipeline_drop_short_circuits() {
        let mut pipeline = WireInterceptorPipeline::new();
        pipeline.add(DropAll);
        pipeline.add(AcceptAll);
        let (result, name) = pipeline.process(&test_envelope(10));
        assert!(matches!(result, WireDisposition::Drop));
        assert_eq!(name, Some("drop-all"));
    }

    #[test]
    fn pipeline_delay_short_circuits() {
        let mut pipeline = WireInterceptorPipeline::new();
        pipeline.add(AcceptAll);
        pipeline.add(DelayAll(Duration::from_millis(50)));
        pipeline.add(RejectAll);
        let (result, name) = pipeline.process(&test_envelope(10));
        match result {
            WireDisposition::Delay(d) => assert_eq!(d, Duration::from_millis(50)),
            other => panic!("expected Delay, got {other:?}"),
        }
        assert_eq!(name, Some("delay-all"));
    }

    #[test]
    fn pipeline_skips_interceptors_after_short_circuit() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut pipeline = WireInterceptorPipeline::new();
        pipeline.add(CountCalls(Arc::clone(&calls)));
        pipeline.add(DropAll);
        pipeline.add(CountCalls(Arc::clone(&calls)));
        let (result, _) = pipeline.process(&test_envelope(10));
        assert!(matches!(result, WireDisposition::Drop));
        // Only the interceptor before `DropAll` may have run.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn max_body_size_accepts_within_limit() {
        let interceptor = MaxBodySizeInterceptor::new(100);
        let result = interceptor.on_receive(&test_envelope(50));
        assert!(matches!(result, WireDisposition::Accept));
    }

    #[test]
    fn max_body_size_accepts_at_limit() {
        let interceptor = MaxBodySizeInterceptor::new(100);
        let result = interceptor.on_receive(&test_envelope(100));
        assert!(matches!(result, WireDisposition::Accept));
    }

    #[test]
    fn max_body_size_rejects_over_limit() {
        let interceptor = MaxBodySizeInterceptor::new(100);
        match interceptor.on_receive(&test_envelope(101)) {
            WireDisposition::Reject(reason) => {
                assert!(reason.contains("101"));
                assert!(reason.contains("100"));
            }
            other => panic!("expected Reject, got {other:?}"),
        }
    }

    #[test]
    fn max_body_size_zero_limit() {
        let interceptor = MaxBodySizeInterceptor::new(0);
        assert!(matches!(
            interceptor.on_receive(&test_envelope(0)),
            WireDisposition::Accept
        ));
        assert!(matches!(
            interceptor.on_receive(&test_envelope(1)),
            WireDisposition::Reject(_)
        ));
    }

    #[test]
    fn rate_limit_accepts_within_limit() {
        let rl = RateLimitWireInterceptor::new(3, Duration::from_secs(1));
        let envelope = test_envelope(10);
        for _ in 0..3 {
            assert!(matches!(rl.on_receive(&envelope), WireDisposition::Accept));
        }
    }

    #[test]
    fn rate_limit_rejects_over_limit() {
        let rl = RateLimitWireInterceptor::new(2, Duration::from_secs(1));
        let envelope = test_envelope(10);
        assert!(matches!(rl.on_receive(&envelope), WireDisposition::Accept));
        assert!(matches!(rl.on_receive(&envelope), WireDisposition::Accept));
        match rl.on_receive(&envelope) {
            WireDisposition::Reject(reason) => assert!(reason.contains("rate limit exceeded")),
            other => panic!("expected Reject, got {other:?}"),
        }
    }

    #[test]
    fn rate_limit_zero_limit_rejects_everything() {
        let rl = RateLimitWireInterceptor::new(0, Duration::from_secs(60));
        assert!(matches!(
            rl.on_receive(&test_envelope(10)),
            WireDisposition::Reject(_)
        ));
    }

    #[test]
    fn rate_limit_resets_when_window_elapsed() {
        // A zero-length window has always elapsed, so every call opens a fresh
        // window and a limit of one is never exceeded.
        let rl = RateLimitWireInterceptor::new(1, Duration::ZERO);
        let envelope = test_envelope(10);
        for _ in 0..5 {
            assert!(matches!(rl.on_receive(&envelope), WireDisposition::Accept));
        }
    }

    #[test]
    fn process_envelope_accept() {
        let pipeline = WireInterceptorPipeline::new();
        let result = pipeline.process_envelope(&test_envelope(10));
        assert!(matches!(result, Ok(WireProcessResult::Accepted)));
    }

    #[test]
    fn process_envelope_reject() {
        let mut pipeline = WireInterceptorPipeline::new();
        pipeline.add(RejectAll);
        let err = pipeline
            .process_envelope(&test_envelope(10))
            .expect_err("RejectAll must produce an error");
        assert_eq!(err.interceptor, "reject-all");
        assert!(err.reason.contains("blocked"));
    }

    #[test]
    fn process_envelope_drop() {
        let mut pipeline = WireInterceptorPipeline::new();
        pipeline.add(DropAll);
        let result = pipeline.process_envelope(&test_envelope(10));
        assert!(matches!(result, Ok(WireProcessResult::Dropped)));
    }

    #[test]
    fn process_envelope_delay() {
        let mut pipeline = WireInterceptorPipeline::new();
        pipeline.add(DelayAll(Duration::from_millis(50)));
        match pipeline.process_envelope(&test_envelope(10)) {
            Ok(WireProcessResult::Delayed(d)) => assert_eq!(d, Duration::from_millis(50)),
            other => panic!("expected Delayed, got {other:?}"),
        }
    }

    #[test]
    fn wire_reject_error_display() {
        let err = WireRejectError {
            interceptor: "max-body-size".into(),
            reason: "too large".into(),
        };
        assert_eq!(
            format!("{err}"),
            "wire interceptor 'max-body-size' rejected: too large"
        );
    }

    #[test]
    fn pipeline_len() {
        let mut pipeline = WireInterceptorPipeline::new();
        assert_eq!(pipeline.len(), 0);
        pipeline.add(AcceptAll);
        pipeline.add(RejectAll);
        assert_eq!(pipeline.len(), 2);
        assert!(!pipeline.is_empty());
    }
}