use async_trait::async_trait;
use celers_protocol::Message;
use std::time::Duration;
use uuid::Uuid;
use crate::{BrokerError, BrokerMetrics, MessageMiddleware, Result};
/// Middleware that checks a message's task name and body size.
///
/// A failed check is reported as [`BrokerError::Configuration`].
pub struct ValidationMiddleware {
    /// Largest accepted body length in bytes; `None` disables the size check.
    max_body_size: Option<usize>,
    /// When `true`, a message whose task name is empty fails validation.
    require_task_name: bool,
}

impl ValidationMiddleware {
    /// Creates a validator with a 10 MiB body limit that requires a task name.
    pub fn new() -> Self {
        let default_limit = 10 * 1024 * 1024;
        Self {
            require_task_name: true,
            max_body_size: Some(default_limit),
        }
    }

    /// Replaces the body size limit with `size` bytes.
    pub fn with_max_body_size(self, size: usize) -> Self {
        Self {
            max_body_size: Some(size),
            ..self
        }
    }

    /// Disables the body size check entirely.
    pub fn without_body_size_limit(self) -> Self {
        Self {
            max_body_size: None,
            ..self
        }
    }

    /// Sets whether an empty task name fails validation.
    pub fn with_require_task_name(self, require: bool) -> Self {
        Self {
            require_task_name: require,
            ..self
        }
    }

    /// Runs every enabled check against `message`.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::Configuration`] when the task name is required
    /// but empty, or when the body is longer than the configured limit.
    fn validate_message(&self, message: &Message) -> Result<()> {
        let name_missing = self.require_task_name && message.task_name().is_empty();
        if name_missing {
            return Err(BrokerError::Configuration(
                "Task name cannot be empty".to_string(),
            ));
        }

        match self.max_body_size {
            Some(limit) if message.body.len() > limit => {
                Err(BrokerError::Configuration(format!(
                    "Message body size {} exceeds maximum {}",
                    message.body.len(),
                    limit
                )))
            }
            _ => Ok(()),
        }
    }
}

impl Default for ValidationMiddleware {
    /// Same configuration as [`ValidationMiddleware::new`].
    fn default() -> Self {
        ValidationMiddleware::new()
    }
}
#[async_trait]
impl MessageMiddleware for ValidationMiddleware {
    /// Validates an outgoing message; the message itself is not modified.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        self.validate_message(&*message)
    }

    /// Validates an incoming message with the same checks as the publish path.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        self.validate_message(&*message)
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "validation"
    }
}
/// Middleware that writes one line to stderr for every published or consumed
/// message.
pub struct LoggingMiddleware {
    /// Text printed in square brackets at the start of every log line.
    prefix: String,
    /// When `true`, log lines also carry the body length in bytes.
    log_body: bool,
}

impl LoggingMiddleware {
    /// Creates a logger with the given line prefix; body-size logging is off.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        LoggingMiddleware {
            prefix,
            log_body: false,
        }
    }

    /// Turns on logging of the body size alongside task name and id.
    pub fn with_body_logging(self) -> Self {
        Self {
            log_body: true,
            ..self
        }
    }
}
#[async_trait]
impl MessageMiddleware for LoggingMiddleware {
    /// Logs the task name and id of an outgoing message to stderr, plus the
    /// body size when body logging is enabled. Never fails.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        if !self.log_body {
            eprintln!(
                "[{}] Publishing: task={}, id={}",
                self.prefix,
                message.task_name(),
                message.task_id()
            );
            return Ok(());
        }

        eprintln!(
            "[{}] Publishing: task={}, id={}, body_size={}",
            self.prefix,
            message.task_name(),
            message.task_id(),
            message.body.len()
        );
        Ok(())
    }

    /// Logs the task name and id of a consumed message to stderr, plus the
    /// body size when body logging is enabled. Never fails.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        if !self.log_body {
            eprintln!(
                "[{}] Consumed: task={}, id={}",
                self.prefix,
                message.task_name(),
                message.task_id()
            );
            return Ok(());
        }

        eprintln!(
            "[{}] Consumed: task={}, id={}, body_size={}",
            self.prefix,
            message.task_name(),
            message.task_id(),
            message.body.len()
        );
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "logging"
    }
}
pub struct MetricsMiddleware {
metrics: std::sync::Arc<std::sync::Mutex<BrokerMetrics>>,
}
impl MetricsMiddleware {
pub fn new(metrics: std::sync::Arc<std::sync::Mutex<BrokerMetrics>>) -> Self {
Self { metrics }
}
pub fn get_metrics(&self) -> BrokerMetrics {
self.metrics.lock().unwrap().clone()
}
}
#[async_trait]
impl MessageMiddleware for MetricsMiddleware {
    /// Increments the published counter; the message is not inspected.
    ///
    /// Panics if the metrics mutex is poisoned.
    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
        self.metrics.lock().unwrap().inc_published();
        Ok(())
    }

    /// Increments the consumed counter; the message is not inspected.
    ///
    /// Panics if the metrics mutex is poisoned.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        self.metrics.lock().unwrap().inc_consumed();
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "metrics"
    }
}
/// Middleware that fails consumed messages whose retry count is too high.
pub struct RetryLimitMiddleware {
    /// Highest retry count that is still accepted.
    max_retries: u32,
}

impl RetryLimitMiddleware {
    /// Creates a limiter that accepts messages with at most `max_retries`
    /// recorded retries.
    pub fn new(max_retries: u32) -> Self {
        RetryLimitMiddleware { max_retries }
    }
}
#[async_trait]
impl MessageMiddleware for RetryLimitMiddleware {
    /// No-op on the publish path.
    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Compares the message's retry header against the configured limit.
    ///
    /// A missing retry header counts as zero retries.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::Configuration`] when the retry count is
    /// strictly greater than the limit.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let attempts = message.headers.retries.unwrap_or(0);
        if attempts <= self.max_retries {
            return Ok(());
        }
        Err(BrokerError::Configuration(format!(
            "Message exceeded maximum retries: {} > {}",
            attempts, self.max_retries
        )))
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "retry_limit"
    }
}
/// Middleware that throttles publishing with a token bucket.
pub struct RateLimitingMiddleware {
    /// Configured sustained rate in messages per second.
    max_rate: f64,
    /// Shared bucket from which one token is taken per published message.
    tokens: std::sync::Arc<std::sync::Mutex<TokenBucket>>,
}

/// Token bucket that refills continuously based on elapsed time.
struct TokenBucket {
    /// Tokens currently available.
    tokens: f64,
    /// Upper bound on `tokens`.
    capacity: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: std::time::Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        Self {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Refills the bucket for the time elapsed since the previous call, then
    /// tries to take `tokens` from it.
    ///
    /// Returns `true` and deducts the tokens when enough are available;
    /// otherwise returns `false` and deducts nothing.
    fn try_consume(&mut self, tokens: f64) -> bool {
        let now = std::time::Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
        self.last_refill = now;

        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }
}

impl RateLimitingMiddleware {
    /// Creates a limiter that admits `max_rate` messages per second.
    ///
    /// The bucket holds `max_rate` tokens, but never fewer than one, and
    /// starts full.
    pub fn new(max_rate: f64) -> Self {
        // Each message costs one token. A bucket capped below 1.0 could never
        // hold a full token, so a rate under 1 msg/s would reject every
        // message forever. Keep room for at least one token.
        let capacity = max_rate.max(1.0);
        let bucket = TokenBucket::new(capacity, max_rate);
        Self {
            max_rate,
            tokens: std::sync::Arc::new(std::sync::Mutex::new(bucket)),
        }
    }
}
#[async_trait]
impl MessageMiddleware for RateLimitingMiddleware {
    /// Takes one token from the bucket for the outgoing message.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::OperationFailed`] when no token is available.
    ///
    /// # Panics
    ///
    /// Panics if the bucket mutex is poisoned.
    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
        let admitted = self.tokens.lock().unwrap().try_consume(1.0);
        if admitted {
            return Ok(());
        }
        Err(BrokerError::OperationFailed(format!(
            "Rate limit exceeded: {} messages/sec",
            self.max_rate
        )))
    }

    /// No-op on the consume path.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "rate_limit"
    }
}
/// Middleware that remembers task ids it has already seen on the consume path.
pub struct DeduplicationMiddleware {
    /// Ids of messages consumed so far (bounded by `max_cache_size`).
    seen_ids: std::sync::Arc<std::sync::Mutex<std::collections::HashSet<Uuid>>>,
    /// Number of ids the cache may hold before an entry is evicted.
    max_cache_size: usize,
}

impl DeduplicationMiddleware {
    /// Creates a deduplicator with an empty cache bounded to `max_cache_size` ids.
    pub fn new(max_cache_size: usize) -> Self {
        let empty_cache = std::collections::HashSet::new();
        let seen_ids = std::sync::Arc::new(std::sync::Mutex::new(empty_cache));
        DeduplicationMiddleware {
            seen_ids,
            max_cache_size,
        }
    }

    /// Creates a deduplicator with a cache bound of 10,000 ids.
    pub fn with_default_cache() -> Self {
        DeduplicationMiddleware::new(10_000)
    }
}

impl Default for DeduplicationMiddleware {
    /// Same as [`DeduplicationMiddleware::with_default_cache`].
    fn default() -> Self {
        DeduplicationMiddleware::with_default_cache()
    }
}
#[async_trait]
impl MessageMiddleware for DeduplicationMiddleware {
    /// No-op on the publish path.
    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Records the message's task id and fails if it was already recorded.
    ///
    /// When the cache grows past its bound, one arbitrary older entry is
    /// evicted. The id that was just recorded is kept whenever another entry
    /// exists, so an immediate redelivery of the same message is still caught.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::OperationFailed`] when the task id is already
    /// in the cache.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let msg_id = message.task_id();
        let mut seen = self.seen_ids.lock().unwrap();

        // `insert` returns `false` when the id was already present, which
        // replaces the separate `contains` lookup.
        if !seen.insert(msg_id) {
            return Err(BrokerError::OperationFailed(format!(
                "Duplicate message detected: {}",
                msg_id
            )));
        }

        if seen.len() > self.max_cache_size {
            // Set iteration order is arbitrary, so the first entry may be the
            // id just inserted. Prefer any other entry as the eviction victim.
            let victim = seen
                .iter()
                .copied()
                .find(|candidate| *candidate != msg_id)
                .unwrap_or(msg_id);
            seen.remove(&victim);
        }
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "deduplication"
    }
}
/// Middleware that compresses message bodies on publish.
#[cfg(feature = "compression")]
pub struct CompressionMiddleware {
    /// Compressor applied to outgoing bodies.
    compressor: celers_protocol::compression::Compressor,
    /// Bodies shorter than this many bytes are left uncompressed.
    min_compress_size: usize,
}

#[cfg(feature = "compression")]
impl CompressionMiddleware {
    /// Creates a middleware using the given compression type and a
    /// 1024-byte minimum body size.
    pub fn new(compression_type: celers_protocol::compression::CompressionType) -> Self {
        let compressor = celers_protocol::compression::Compressor::new(compression_type);
        Self {
            compressor,
            min_compress_size: 1024,
        }
    }

    /// Sets the minimum body size, in bytes, at which compression is attempted.
    pub fn with_min_size(self, size: usize) -> Self {
        Self {
            min_compress_size: size,
            ..self
        }
    }

    /// Reconfigures the underlying compressor with `level`.
    pub fn with_level(self, level: u32) -> Self {
        let Self {
            compressor,
            min_compress_size,
        } = self;
        Self {
            compressor: compressor.with_level(level),
            min_compress_size,
        }
    }
}
#[cfg(feature = "compression")]
#[async_trait]
impl MessageMiddleware for CompressionMiddleware {
    /// Compresses the body when it is at least the minimum size, and keeps
    /// the compressed form only if it is strictly smaller.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::Serialization`] when the compressor fails.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let original_len = message.body.len();
        if original_len < self.min_compress_size {
            return Ok(());
        }

        let packed = self
            .compressor
            .compress(&message.body)
            .map_err(|err| BrokerError::Serialization(err.to_string()))?;
        if packed.len() < original_len {
            message.body = packed;
        }
        Ok(())
    }

    /// Leaves the consumed message untouched.
    ///
    /// NOTE(review): nothing here decompresses the body, and the publish path
    /// sets no marker saying whether compression was applied. Confirm that
    /// decompression happens elsewhere.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let _ = message;
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "compression"
    }
}
/// Middleware that runs a message signer over message bodies.
#[cfg(feature = "signing")]
pub struct SigningMiddleware {
    /// Signer built from the key passed to [`SigningMiddleware::new`].
    signer: celers_protocol::auth::MessageSigner,
}

#[cfg(feature = "signing")]
impl SigningMiddleware {
    /// Creates a signing middleware from raw key bytes.
    pub fn new(key: &[u8]) -> Self {
        let signer = celers_protocol::auth::MessageSigner::new(key);
        Self { signer }
    }
}
#[cfg(feature = "signing")]
#[async_trait]
impl MessageMiddleware for SigningMiddleware {
    /// Signs the outgoing body.
    ///
    /// NOTE(review): the signature is computed and then discarded. It is not
    /// attached to the message in this code, so confirm whether that is
    /// intentional.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::OperationFailed`] when signing fails.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        self.signer
            .sign(&message.body)
            .map(|_signature| ())
            .map_err(|err| BrokerError::OperationFailed(format!("signing failed: {}", err)))
    }

    /// Signs the consumed body and discards the result.
    ///
    /// NOTE(review): no comparison against a received signature happens here,
    /// so this does not verify anything. Confirm the intended behavior.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::OperationFailed`] when signing fails.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        self.signer
            .sign(&message.body)
            .map(|_signature| ())
            .map_err(|err| BrokerError::OperationFailed(format!("signing failed: {}", err)))
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "signing"
    }
}
/// Middleware that encrypts bodies on publish and decrypts them on consume.
#[cfg(feature = "encryption")]
pub struct EncryptionMiddleware {
    /// Encryptor built from the key passed to [`EncryptionMiddleware::new`].
    encryptor: celers_protocol::crypto::MessageEncryptor,
}

#[cfg(feature = "encryption")]
impl EncryptionMiddleware {
    /// Creates an encryption middleware from raw key bytes.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::Configuration`] when the encryptor rejects the key.
    pub fn new(key: &[u8]) -> Result<Self> {
        celers_protocol::crypto::MessageEncryptor::new(key)
            .map(|encryptor| Self { encryptor })
            .map_err(|err| BrokerError::Configuration(err.to_string()))
    }
}
#[cfg(feature = "encryption")]
#[async_trait]
impl MessageMiddleware for EncryptionMiddleware {
    /// Encrypts the body and replaces it with `nonce || ciphertext`.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::Serialization`] when encryption fails.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let sealed = self.encryptor.encrypt(&message.body);
        let (ciphertext, nonce) =
            sealed.map_err(|err| BrokerError::Serialization(err.to_string()))?;

        // Wire layout: the nonce comes first so the consumer can split it off.
        let mut framed = nonce.to_vec();
        framed.extend_from_slice(&ciphertext);
        message.body = framed;
        Ok(())
    }

    /// Splits the leading nonce off the body, decrypts the remainder and
    /// replaces the body with the plaintext.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::Serialization`] when the body is shorter than
    /// the nonce or when decryption fails.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let nonce_len = celers_protocol::crypto::NONCE_SIZE;
        if message.body.len() < nonce_len {
            return Err(BrokerError::Serialization(
                "Message too short to contain nonce".to_string(),
            ));
        }

        let (nonce_bytes, ciphertext) = message.body.split_at(nonce_len);
        let decrypted = self.encryptor.decrypt(ciphertext, nonce_bytes);
        let plaintext = decrypted.map_err(|err| BrokerError::Serialization(err.to_string()))?;
        message.body = plaintext;
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "encryption"
    }
}
/// Middleware that stamps outgoing messages with a timeout header.
pub struct TimeoutMiddleware {
    /// Timeout value written to each published message.
    timeout: Duration,
}

impl TimeoutMiddleware {
    /// Creates a middleware that advertises the given timeout.
    pub fn new(timeout: Duration) -> Self {
        TimeoutMiddleware { timeout }
    }

    /// Returns the configured timeout.
    pub fn timeout(&self) -> Duration {
        let Self { timeout } = self;
        *timeout
    }
}
#[async_trait]
impl MessageMiddleware for TimeoutMiddleware {
    /// Writes the timeout, in whole milliseconds, to the `x-timeout-ms`
    /// extra header, overwriting any existing value.
    ///
    /// Durations too large for a `u64` millisecond count are stored as
    /// `u64::MAX`.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        // `as_millis` returns a u128. Cap it before narrowing so an oversized
        // duration saturates rather than wrapping to an unrelated small value.
        let capped = self.timeout.as_millis().min(u128::from(u64::MAX));
        let timeout_ms = capped as u64;

        message.headers.extra.insert(
            "x-timeout-ms".to_string(),
            serde_json::Value::Number(timeout_ms.into()),
        );
        Ok(())
    }

    /// No-op on the consume path.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "timeout"
    }
}
/// Middleware that accepts or refuses consumed messages using a
/// caller-supplied predicate.
pub struct FilterMiddleware {
    /// Returns `true` for messages that should pass the filter.
    predicate: Box<dyn Fn(&Message) -> bool + Send + Sync>,
}

impl FilterMiddleware {
    /// Creates a filter from `predicate`; messages for which it returns
    /// `true` pass.
    pub fn new<F>(predicate: F) -> Self
    where
        F: Fn(&Message) -> bool + Send + Sync + 'static,
    {
        let predicate: Box<dyn Fn(&Message) -> bool + Send + Sync> = Box::new(predicate);
        FilterMiddleware { predicate }
    }

    /// Evaluates the predicate against `message`.
    pub fn matches(&self, message: &Message) -> bool {
        let accept = &self.predicate;
        accept(message)
    }
}
#[async_trait]
impl MessageMiddleware for FilterMiddleware {
    /// No-op on the publish path.
    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Passes the message through when the predicate accepts it.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::Configuration`] when the predicate returns `false`.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        if self.matches(message) {
            Ok(())
        } else {
            Err(BrokerError::Configuration(
                "Message filtered out by predicate".to_string(),
            ))
        }
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "filter"
    }
}
/// Middleware that lets only a fixed fraction of consumed messages through.
///
/// Sampling is deterministic and evenly spaced: out of the first `n`
/// messages, `floor(n * sample_rate)` are sampled.
pub struct SamplingMiddleware {
    /// Fraction of messages to keep, clamped to `0.0..=1.0`.
    sample_rate: f64,
    /// Number of sampling decisions made so far.
    counter: std::sync::atomic::AtomicU64,
}

impl SamplingMiddleware {
    /// Creates a sampler; `sample_rate` is clamped to the range `0.0..=1.0`.
    pub fn new(sample_rate: f64) -> Self {
        Self {
            sample_rate: sample_rate.clamp(0.0, 1.0),
            counter: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Returns the clamped sampling rate.
    pub fn sample_rate(&self) -> f64 {
        self.sample_rate
    }

    /// Decides whether the next message is sampled and advances the counter.
    ///
    /// The previous version compared the raw counter against
    /// `u64::MAX * sample_rate`. A counter that starts at zero never reaches
    /// that threshold in practice, so every message was sampled for any rate
    /// above zero. This version samples message `k` exactly when
    /// `floor((k + 1) * rate)` exceeds `floor(k * rate)`.
    fn should_sample(&self) -> bool {
        // Fold the counter into a window well inside f64's exact integer
        // range so the products below stay precise however long the process
        // runs.
        const PERIOD: u64 = 1 << 32;

        let count = self
            .counter
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let position = (count % PERIOD) as f64;

        let kept_before = (position * self.sample_rate).floor();
        let kept_after = ((position + 1.0) * self.sample_rate).floor();
        kept_after > kept_before
    }
}
#[async_trait]
impl MessageMiddleware for SamplingMiddleware {
    /// No-op on the publish path.
    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Passes the message through when the sampler selects it.
    ///
    /// # Errors
    ///
    /// Returns [`BrokerError::Configuration`] for messages that are not sampled.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        if self.should_sample() {
            return Ok(());
        }
        Err(BrokerError::Configuration(
            "Message filtered out by sampling".to_string(),
        ))
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "sampling"
    }
}
/// Middleware that rewrites message bodies with a caller-supplied function.
pub struct TransformationMiddleware {
    /// Maps an owned body to its replacement.
    transform_fn: Box<dyn Fn(Vec<u8>) -> Vec<u8> + Send + Sync>,
}

impl TransformationMiddleware {
    /// Creates a middleware that applies `transform_fn` to message bodies.
    pub fn new<F>(transform_fn: F) -> Self
    where
        F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
    {
        let transform_fn: Box<dyn Fn(Vec<u8>) -> Vec<u8> + Send + Sync> = Box::new(transform_fn);
        TransformationMiddleware { transform_fn }
    }

    /// Applies the stored function to `body` and returns its result.
    fn transform(&self, body: Vec<u8>) -> Vec<u8> {
        let apply = &self.transform_fn;
        apply(body)
    }
}
#[async_trait]
impl MessageMiddleware for TransformationMiddleware {
    /// Replaces the outgoing body with the transformed body.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        // The old body is discarded once transformed, so move it out instead
        // of cloning the whole buffer.
        let body = std::mem::take(&mut message.body);
        message.body = self.transform(body);
        Ok(())
    }

    /// Replaces the consumed body with the transformed body.
    ///
    /// The same function runs on both paths, so a round trip restores the
    /// original body only if the function is its own inverse.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let body = std::mem::take(&mut message.body);
        message.body = self.transform(body);
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "transformation"
    }
}
/// Middleware that annotates message headers with tracing metadata.
#[derive(Debug, Clone)]
pub struct TracingMiddleware {
    /// Service name written into the tracing headers.
    service_name: String,
}

impl TracingMiddleware {
    /// Creates a tracing middleware that reports under `service_name`.
    pub fn new(service_name: impl Into<String>) -> Self {
        let service_name: String = service_name.into();
        TracingMiddleware { service_name }
    }
}
#[async_trait]
impl MessageMiddleware for TracingMiddleware {
    /// Adds tracing headers to an outgoing message.
    ///
    /// - `trace-id`: set to a fresh v4 UUID only when not already present.
    /// - `service-name`: always set to this middleware's service name.
    /// - `span-id`: always set to a fresh v4 UUID.
    /// - `trace-timestamp`: milliseconds since the Unix epoch. It is skipped
    ///   when the system clock reports a time before the epoch, instead of
    ///   panicking in the publish path.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let extra = &mut message.headers.extra;

        // Keep an existing trace id so a trace can span several hops.
        if !extra.contains_key("trace-id") {
            let trace_id = uuid::Uuid::new_v4().to_string();
            extra.insert("trace-id".to_string(), serde_json::json!(trace_id));
        }

        extra.insert(
            "service-name".to_string(),
            serde_json::Value::String(self.service_name.clone()),
        );

        let span_id = uuid::Uuid::new_v4().to_string();
        extra.insert("span-id".to_string(), serde_json::json!(span_id));

        // `duration_since` fails when the clock is set before the epoch. The
        // timestamp is diagnostic only, so omit it rather than abort.
        if let Ok(since_epoch) =
            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)
        {
            extra.insert(
                "trace-timestamp".to_string(),
                serde_json::json!(since_epoch.as_millis()),
            );
        }
        Ok(())
    }

    /// Marks a consumed message that carries a `trace-id` header.
    ///
    /// Sets `consumer-service` to this middleware's service name and copies
    /// the trace id into `trace-id-consumed`. Messages without a `trace-id`
    /// are left unchanged.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        if let Some(trace_id) = message.headers.extra.get("trace-id").cloned() {
            let extra = &mut message.headers.extra;
            extra.insert(
                "consumer-service".to_string(),
                serde_json::Value::String(self.service_name.clone()),
            );
            extra.insert("trace-id-consumed".to_string(), trace_id);
        }
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "tracing"
    }
}