use async_trait::async_trait;
use celers_protocol::Message;
use std::collections::HashMap;
use std::time::Duration;
use uuid::Uuid;
use crate::{BrokerError, MessageMiddleware, Result};
/// Tags messages that carry an `error` header with a coarse error class and a
/// retry recommendation.
///
/// Classification is a case-insensitive substring match of the error text
/// against two pattern lists. Permanent patterns are consulted first, so text
/// matching both lists is classified as [`ErrorClass::Permanent`].
pub struct ErrorClassificationMiddleware {
    /// Substrings identifying transient (retry-worthy) errors.
    transient_patterns: Vec<String>,
    /// Substrings identifying permanent errors; checked before transient ones.
    permanent_patterns: Vec<String>,
    /// Retry budget for transient and unclassified errors.
    max_transient_retries: u32,
    /// Retry budget for permanent errors.
    max_permanent_retries: u32,
}

impl ErrorClassificationMiddleware {
    /// Builds a classifier with the built-in pattern lists, a transient retry
    /// budget of 10 and a permanent retry budget of 1.
    pub fn new() -> Self {
        fn owned(words: &[&str]) -> Vec<String> {
            words.iter().map(|word| word.to_string()).collect()
        }
        Self {
            transient_patterns: owned(&["timeout", "connection", "network", "unavailable"]),
            permanent_patterns: owned(&["validation", "schema", "invalid", "forbidden"]),
            max_transient_retries: 10,
            max_permanent_retries: 1,
        }
    }

    /// Adds a substring that marks an error as transient.
    pub fn with_transient_pattern(mut self, pattern: &str) -> Self {
        self.transient_patterns.push(pattern.to_owned());
        self
    }

    /// Adds a substring that marks an error as permanent.
    pub fn with_permanent_pattern(mut self, pattern: &str) -> Self {
        self.permanent_patterns.push(pattern.to_owned());
        self
    }

    /// Sets the retry budget used for transient and unknown errors.
    pub fn with_max_transient_retries(mut self, max_retries: u32) -> Self {
        self.max_transient_retries = max_retries;
        self
    }

    /// Sets the retry budget used for permanent errors.
    pub fn with_max_permanent_retries(mut self, max_retries: u32) -> Self {
        self.max_permanent_retries = max_retries;
        self
    }

    /// Classifies `error_msg` by case-insensitive substring match; permanent
    /// patterns take precedence over transient ones.
    pub fn classify_error(&self, error_msg: &str) -> ErrorClass {
        let haystack = error_msg.to_lowercase();
        let hits = |patterns: &[String]| {
            patterns
                .iter()
                .any(|pattern| haystack.contains(pattern.to_lowercase().as_str()))
        };
        if hits(&self.permanent_patterns[..]) {
            ErrorClass::Permanent
        } else if hits(&self.transient_patterns[..]) {
            ErrorClass::Transient
        } else {
            ErrorClass::Unknown
        }
    }

    /// Returns `true` while `current_retries` is below the budget for the
    /// error's class. Unknown errors share the transient budget.
    pub fn should_retry(&self, error_msg: &str, current_retries: u32) -> bool {
        let budget = match self.classify_error(error_msg) {
            ErrorClass::Permanent => self.max_permanent_retries,
            ErrorClass::Transient | ErrorClass::Unknown => self.max_transient_retries,
        };
        current_retries < budget
    }
}

impl Default for ErrorClassificationMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

/// Coarse category assigned by [`ErrorClassificationMiddleware::classify_error`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorClass {
    /// Matched a transient pattern (and no permanent one).
    Transient,
    /// Matched a permanent pattern.
    Permanent,
    /// Matched neither list.
    Unknown,
}
#[async_trait]
impl MessageMiddleware for ErrorClassificationMiddleware {
async fn before_publish(&self, message: &mut Message) -> Result<()> {
if let Some(error_value) = message.headers.extra.get("error") {
if let Some(error_msg) = error_value.as_str() {
let error_class = self.classify_error(error_msg);
let should_retry =
self.should_retry(error_msg, message.headers.retries.unwrap_or(0));
message.headers.extra.insert(
"x-error-class".to_string(),
serde_json::json!(match error_class {
ErrorClass::Transient => "transient",
ErrorClass::Permanent => "permanent",
ErrorClass::Unknown => "unknown",
}),
);
message.headers.extra.insert(
"x-should-retry".to_string(),
serde_json::json!(should_retry),
);
if !should_retry {
message.headers.extra.insert(
"x-max-retries-exceeded".to_string(),
serde_json::json!(true),
);
}
}
}
Ok(())
}
async fn after_consume(&self, _message: &mut Message) -> Result<()> {
Ok(())
}
fn name(&self) -> &str {
"error_classification"
}
}
/// Ensures every message carries a correlation id header. An existing string
/// value is reused; otherwise a random UUID v4 is generated.
pub struct CorrelationMiddleware {
    /// Name of the header holding the correlation id.
    header_name: String,
}

impl CorrelationMiddleware {
    /// Uses the default header name `x-correlation-id`.
    pub fn new() -> Self {
        Self::with_header_name("x-correlation-id")
    }

    /// Uses a custom header name. Despite the `with_` prefix this is a
    /// constructor, not a builder method.
    pub fn with_header_name(header_name: &str) -> Self {
        Self {
            header_name: header_name.to_owned(),
        }
    }

    /// Returns the header's string value, or a fresh UUID v4 when the header
    /// is absent or not a string.
    fn get_or_generate_correlation_id(&self, message: &Message) -> String {
        match message
            .headers
            .extra
            .get(&self.header_name)
            .and_then(|value| value.as_str())
        {
            Some(existing) => existing.to_owned(),
            None => Uuid::new_v4().to_string(),
        }
    }
}

impl Default for CorrelationMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl MessageMiddleware for CorrelationMiddleware {
    /// Writes the existing (or newly generated) correlation id into the header.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let id = self.get_or_generate_correlation_id(message);
        let header = self.header_name.clone();
        message.headers.extra.insert(header, serde_json::json!(id));
        Ok(())
    }

    /// Applies exactly the same stamping on the consume side.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        self.before_publish(message).await
    }

    fn name(&self) -> &str {
        "correlation"
    }
}
/// Token-bucket throttle. It annotates outgoing messages with the delay needed
/// to respect `max_rate` and flags backpressure when the bucket runs low.
///
/// The bucket refills continuously at `max_rate` tokens per second and holds at
/// most `burst_size` tokens.
pub struct ThrottlingMiddleware {
    /// Refill rate in tokens per second.
    pub(crate) max_rate: f64,
    /// Bucket capacity. `new` sets it to twice `max_rate`.
    pub(crate) burst_size: usize,
    /// Backpressure is flagged once the bucket is less than
    /// `1.0 - backpressure_threshold` full. The builder clamps it to `[0.0, 1.0]`.
    pub(crate) backpressure_threshold: f64,
    /// Instant of the most recent refill.
    last_refill: std::sync::Mutex<std::time::Instant>,
    /// Tokens currently in the bucket (fractional).
    available_tokens: std::sync::Mutex<f64>,
}

impl ThrottlingMiddleware {
    /// Largest delay `calculate_delay` reports. It stands in for "never" when no
    /// finite wait exists. `as_millis()` of this value still fits in a `u64`, so
    /// it stays serializable as a JSON number.
    const MAX_DELAY: Duration = Duration::from_millis(u64::MAX);

    /// Creates a throttle for `max_rate` messages/second, with a burst of
    /// `2 * max_rate`, a full bucket and a backpressure threshold of 0.8.
    pub fn new(max_rate: f64) -> Self {
        Self {
            max_rate,
            burst_size: (max_rate * 2.0) as usize,
            backpressure_threshold: 0.8,
            last_refill: std::sync::Mutex::new(std::time::Instant::now()),
            available_tokens: std::sync::Mutex::new(max_rate),
        }
    }

    /// Overrides the bucket capacity.
    pub fn with_burst_size(mut self, size: usize) -> Self {
        self.burst_size = size;
        self
    }

    /// Sets the backpressure threshold, clamped to `[0.0, 1.0]`.
    pub fn with_backpressure_threshold(mut self, threshold: f64) -> Self {
        self.backpressure_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Adds the tokens accrued since the last refill, capped at `burst_size`.
    fn refill_tokens(&self) {
        let mut last_refill = self.last_refill.lock().unwrap();
        let mut tokens = self.available_tokens.lock().unwrap();
        let now = std::time::Instant::now();
        let elapsed = now.duration_since(*last_refill).as_secs_f64();
        let new_tokens = elapsed * self.max_rate;
        *tokens = (*tokens + new_tokens).min(self.burst_size as f64);
        *last_refill = now;
    }

    /// Returns how long a caller would wait for one token. The result is zero
    /// when a token is available now.
    ///
    /// When no finite, non-negative wait exists, returns `MAX_DELAY` instead of
    /// panicking. This covers a zero, negative or NaN `max_rate`.
    fn calculate_delay(&self) -> Duration {
        self.refill_tokens();
        // Copy the value out so the lock is released immediately.
        let tokens = *self.available_tokens.lock().unwrap();
        if tokens >= 1.0 {
            return Duration::from_millis(0);
        }
        let wait_secs = (1.0 - tokens) / self.max_rate;
        // `Duration::from_secs_f64` panics on negative, non-finite or overflowing
        // input. A zero rate yields +inf and a negative rate a negative wait.
        if wait_secs.is_finite() && wait_secs >= 0.0 && wait_secs < Self::MAX_DELAY.as_secs_f64() {
            Duration::from_secs_f64(wait_secs)
        } else {
            Self::MAX_DELAY
        }
    }

    /// Returns `true` when the bucket's fill ratio is below
    /// `1.0 - backpressure_threshold`.
    fn should_apply_backpressure(&self) -> bool {
        self.refill_tokens();
        let tokens = *self.available_tokens.lock().unwrap();
        // A zero-capacity bucket is always empty. Dividing by zero would give NaN,
        // and NaN compares false, which would silently disable backpressure.
        let fill_ratio = if self.burst_size == 0 {
            0.0
        } else {
            tokens / self.burst_size as f64
        };
        fill_ratio < 1.0 - self.backpressure_threshold
    }
}
#[async_trait]
impl MessageMiddleware for ThrottlingMiddleware {
    /// Annotates the outgoing message with throttle state, then consumes one
    /// token if available.
    ///
    /// The message is not delayed here. The delay is only reported via
    /// `x-throttle-delay-ms`, and low bucket levels via `x-backpressure-active`.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let delay = self.calculate_delay();
        let throttled = delay != Duration::from_millis(0);
        if throttled {
            let delay_ms = delay.as_millis();
            message
                .headers
                .extra
                .insert("x-throttle-delay-ms".to_string(), serde_json::json!(delay_ms));
        }

        let backpressure = self.should_apply_backpressure();
        if backpressure {
            message
                .headers
                .extra
                .insert("x-backpressure-active".to_string(), serde_json::json!(true));
        }

        {
            let mut bucket = self.available_tokens.lock().unwrap();
            if *bucket >= 1.0 {
                *bucket -= 1.0;
            }
        }
        Ok(())
    }

    /// Consumed messages are not modified.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    fn name(&self) -> &str {
        "throttling"
    }
}
/// Sliding-window circuit breaker. The circuit is open while at least
/// `failure_threshold` failures were recorded within the last `window`.
///
/// A threshold of 0 means the circuit is always open.
pub struct CircuitBreakerMiddleware {
    /// Failure count at which the circuit opens.
    pub(crate) failure_threshold: usize,
    /// How long a recorded failure keeps counting.
    window: Duration,
    /// Timestamps of recent failures; pruned on every access.
    failures: std::sync::Mutex<Vec<std::time::Instant>>,
}

impl CircuitBreakerMiddleware {
    /// Creates a closed breaker with an empty failure log.
    pub fn new(failure_threshold: usize, window: Duration) -> Self {
        let failures = std::sync::Mutex::new(Vec::new());
        Self {
            failure_threshold,
            window,
            failures,
        }
    }

    /// Locks the failure log and drops entries older than `window`. Then runs
    /// `action` on the pruned log with the timestamp used for pruning.
    fn with_pruned_failures<R>(
        &self,
        action: impl FnOnce(&mut Vec<std::time::Instant>, std::time::Instant) -> R,
    ) -> R {
        let mut log = self.failures.lock().unwrap();
        let now = std::time::Instant::now();
        let window = self.window;
        log.retain(|&stamp| now.duration_since(stamp) < window);
        action(&mut log, now)
    }

    /// Records a failure at the current instant.
    fn record_failure(&self) {
        self.with_pruned_failures(|log, now| log.push(now));
    }

    /// Returns `true` when the in-window failure count reaches the threshold.
    fn is_circuit_open(&self) -> bool {
        self.get_failure_count() >= self.failure_threshold
    }

    /// Number of failures recorded within the current window.
    fn get_failure_count(&self) -> usize {
        self.with_pruned_failures(|log, _| log.len())
    }
}
#[async_trait]
impl MessageMiddleware for CircuitBreakerMiddleware {
    /// Rejects publishing while the circuit is open. Before rejecting, it
    /// stamps `x-circuit-breaker-open` and the current failure count.
    ///
    /// # Errors
    /// `BrokerError::OperationFailed` when the circuit is open.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        if !self.is_circuit_open() {
            return Ok(());
        }
        let headers = &mut message.headers.extra;
        headers.insert("x-circuit-breaker-open".to_string(), serde_json::json!(true));
        headers.insert(
            "x-circuit-breaker-failures".to_string(),
            serde_json::json!(self.get_failure_count()),
        );
        Err(BrokerError::OperationFailed(
            "Circuit breaker is open".to_string(),
        ))
    }

    /// Counts a consumed message carrying an `error` header as a failure. Then
    /// stamps the in-window failure count on the message.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let failed = message.headers.extra.contains_key("error");
        if failed {
            self.record_failure();
        }
        let count = self.get_failure_count();
        message
            .headers
            .extra
            .insert("x-circuit-breaker-failures".to_string(), serde_json::json!(count));
        Ok(())
    }

    fn name(&self) -> &str {
        "circuit_breaker"
    }
}
/// Rejects messages whose headers or body violate simple structural limits.
/// Every violation is reported as `BrokerError::Configuration`.
pub struct SchemaValidationMiddleware {
    /// Keys that must be present in `headers.extra`.
    pub(crate) required_fields: Vec<String>,
    /// Upper bound on `headers.extra` entries. This middleware's own
    /// validation marker is not counted.
    pub(crate) max_field_count: Option<usize>,
    /// Minimum body length, as reported by `body.len()`.
    min_body_size: Option<usize>,
    /// Maximum body length, as reported by `body.len()`.
    pub(crate) max_body_size: Option<usize>,
}

impl SchemaValidationMiddleware {
    /// Header that `before_publish` stamps on messages that passed validation.
    ///
    /// It is excluded from the field count. Otherwise a message published with
    /// exactly `max_field_count` fields would fail validation in
    /// `after_consume`, solely because of the marker this middleware added.
    const VALIDATED_MARKER: &'static str = "x-schema-validated";

    /// Creates a validator with no constraints.
    pub fn new() -> Self {
        Self {
            required_fields: Vec::new(),
            max_field_count: None,
            min_body_size: None,
            max_body_size: None,
        }
    }

    /// Requires `field` to be present in `headers.extra`.
    pub fn with_required_field(mut self, field: impl Into<String>) -> Self {
        self.required_fields.push(field.into());
        self
    }

    /// Caps the number of `headers.extra` entries.
    pub fn with_max_field_count(mut self, count: usize) -> Self {
        self.max_field_count = Some(count);
        self
    }

    /// Requires the body to be at least `size` long.
    pub fn with_min_body_size(mut self, size: usize) -> Self {
        self.min_body_size = Some(size);
        self
    }

    /// Requires the body to be at most `size` long.
    pub fn with_max_body_size(mut self, size: usize) -> Self {
        self.max_body_size = Some(size);
        self
    }

    /// Checks the configured constraints: required fields, then field count,
    /// then body size bounds.
    ///
    /// # Errors
    /// `BrokerError::Configuration` describing the first violated constraint.
    fn validate_message(&self, message: &Message) -> Result<()> {
        for field in &self.required_fields {
            if !message.headers.extra.contains_key(field) {
                return Err(BrokerError::Configuration(format!(
                    "Missing required field: {}",
                    field
                )));
            }
        }
        if let Some(max) = self.max_field_count {
            let own_marker =
                usize::from(message.headers.extra.contains_key(Self::VALIDATED_MARKER));
            let field_count = message.headers.extra.len() - own_marker;
            if field_count > max {
                return Err(BrokerError::Configuration(format!(
                    "Too many fields: {} > {}",
                    field_count, max
                )));
            }
        }
        let body_len = message.body.len();
        if let Some(min) = self.min_body_size {
            if body_len < min {
                return Err(BrokerError::Configuration(format!(
                    "Body too small: {} < {}",
                    body_len, min
                )));
            }
        }
        if let Some(max) = self.max_body_size {
            if body_len > max {
                return Err(BrokerError::Configuration(format!(
                    "Body too large: {} > {}",
                    body_len, max
                )));
            }
        }
        Ok(())
    }
}

impl Default for SchemaValidationMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl MessageMiddleware for SchemaValidationMiddleware {
    /// Validates the outgoing message. On success it stamps
    /// `x-schema-validated: true`.
    ///
    /// # Errors
    /// Propagates the validation error unchanged; no header is added then.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        self.validate_message(message).map(|()| {
            message
                .headers
                .extra
                .insert("x-schema-validated".to_string(), serde_json::json!(true));
        })
    }

    /// Re-validates the message on the consume side.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        self.validate_message(message)
    }

    fn name(&self) -> &str {
        "schema_validation"
    }
}
/// Stamps configured metadata onto outgoing messages as `x-enrichment-*` headers.
pub struct MessageEnrichmentMiddleware {
    /// Written as `x-enrichment-hostname` when set.
    pub(crate) hostname: Option<String>,
    /// Written as `x-enrichment-environment` when set.
    pub(crate) environment: Option<String>,
    /// Written as `x-enrichment-version` when set.
    pub(crate) version: Option<String>,
    /// When true, the current Unix time in seconds is written as
    /// `x-enrichment-timestamp`.
    pub(crate) add_timestamp: bool,
    /// Extra pairs, each written as `x-enrichment-{key}`.
    custom_metadata: HashMap<String, serde_json::Value>,
}

impl MessageEnrichmentMiddleware {
    /// Creates an enricher that adds nothing until configured.
    pub fn new() -> Self {
        Self {
            hostname: None,
            environment: None,
            version: None,
            add_timestamp: false,
            custom_metadata: HashMap::new(),
        }
    }

    /// Sets the hostname to stamp.
    pub fn with_hostname(mut self, hostname: impl Into<String>) -> Self {
        self.hostname = Some(hostname.into());
        self
    }

    /// Sets the environment name to stamp.
    pub fn with_environment(mut self, environment: impl Into<String>) -> Self {
        self.environment = Some(environment.into());
        self
    }

    /// Sets the version string to stamp.
    pub fn with_version(mut self, version: impl Into<String>) -> Self {
        self.version = Some(version.into());
        self
    }

    /// Enables or disables the Unix-seconds timestamp header.
    pub fn with_add_timestamp(mut self, add: bool) -> Self {
        self.add_timestamp = add;
        self
    }

    /// Adds a custom pair, stamped as `x-enrichment-{key}`.
    pub fn with_custom_metadata(
        mut self,
        key: impl Into<String>,
        value: serde_json::Value,
    ) -> Self {
        self.custom_metadata.insert(key.into(), value);
        self
    }

    /// Inserts the configured headers into `message.headers.extra`.
    ///
    /// Custom metadata is written last. A custom key such as `hostname`
    /// therefore overwrites the corresponding built-in header.
    fn enrich_message(&self, message: &mut Message) {
        if let Some(ref hostname) = self.hostname {
            message.headers.extra.insert(
                "x-enrichment-hostname".to_string(),
                serde_json::json!(hostname),
            );
        }
        if let Some(ref environment) = self.environment {
            message.headers.extra.insert(
                "x-enrichment-environment".to_string(),
                serde_json::json!(environment),
            );
        }
        if let Some(ref version) = self.version {
            message.headers.extra.insert(
                "x-enrichment-version".to_string(),
                serde_json::json!(version),
            );
        }
        if self.add_timestamp {
            // `duration_since` fails when the system clock is set before the Unix
            // epoch. Report 0 instead of panicking inside the publish path.
            let timestamp = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            message.headers.extra.insert(
                "x-enrichment-timestamp".to_string(),
                serde_json::json!(timestamp),
            );
        }
        for (key, value) in &self.custom_metadata {
            message
                .headers
                .extra
                .insert(format!("x-enrichment-{}", key), value.clone());
        }
    }
}

impl Default for MessageEnrichmentMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl MessageMiddleware for MessageEnrichmentMiddleware {
    /// Identifier reported for this middleware.
    fn name(&self) -> &str {
        "message_enrichment"
    }

    /// Stamps the configured metadata headers onto every outgoing message.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        self.enrich_message(message);
        Ok(())
    }

    /// Consumed messages are left unchanged.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }
}
/// Backoff curve used by [`RetryStrategyMiddleware`] to space out retries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryStrategy {
    /// `base * 2^retry`.
    Exponential,
    /// `base * (retry + 1)`.
    Linear,
    /// `base * fib(retry)`, with `fib(0) = fib(1) = 1`.
    Fibonacci,
    /// Always `base`.
    Fixed,
}

/// Computes a retry delay from a message's retry count and rejects messages
/// that have exhausted the retry budget.
pub struct RetryStrategyMiddleware {
    /// Backoff curve.
    strategy: RetryStrategy,
    /// Base delay in milliseconds (default 1000).
    base_delay_ms: u64,
    /// Upper bound on any computed delay, in milliseconds (default 300_000).
    max_delay_ms: u64,
    /// Retry budget (default 5).
    max_retries: u32,
}

impl RetryStrategyMiddleware {
    /// Creates a middleware with a 1s base delay, a 5min cap and 5 retries.
    pub fn new(strategy: RetryStrategy) -> Self {
        Self {
            strategy,
            base_delay_ms: 1000,
            max_delay_ms: 300_000,
            max_retries: 5,
        }
    }

    /// Sets the base delay (whole milliseconds, saturating at `u64::MAX`).
    pub fn with_base_delay(mut self, delay: Duration) -> Self {
        self.base_delay_ms = Self::saturating_millis(delay);
        self
    }

    /// Sets the delay cap (whole milliseconds, saturating at `u64::MAX`).
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay_ms = Self::saturating_millis(delay);
        self
    }

    /// Sets the retry budget.
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Converts to whole milliseconds. Saturates at `u64::MAX` instead of
    /// truncating the `u128` returned by `as_millis`.
    fn saturating_millis(delay: Duration) -> u64 {
        delay.as_millis().min(u128::from(u64::MAX)) as u64
    }

    /// Delay in milliseconds before retry number `retry_count`, capped at
    /// `max_delay_ms`.
    ///
    /// All arithmetic saturates. Large retry counts or base delays would
    /// otherwise overflow: debug builds would panic, and release builds would
    /// wrap to a small bogus delay.
    fn calculate_delay(&self, retry_count: u32) -> u64 {
        let delay = match self.strategy {
            RetryStrategy::Exponential => self
                .base_delay_ms
                .saturating_mul(2_u64.saturating_pow(retry_count)),
            RetryStrategy::Linear => self
                .base_delay_ms
                .saturating_mul(u64::from(retry_count) + 1),
            RetryStrategy::Fibonacci => self
                .base_delay_ms
                .saturating_mul(self.fibonacci(retry_count as usize)),
            RetryStrategy::Fixed => self.base_delay_ms,
        };
        delay.min(self.max_delay_ms)
    }

    /// `fib(n)` with `fib(0) = fib(1) = 1`, saturating at `u64::MAX`.
    ///
    /// Stops early once saturated, so huge `n` does not loop billions of times.
    fn fibonacci(&self, n: usize) -> u64 {
        let (mut prev, mut curr) = (1u64, 1u64);
        for _ in 2..=n {
            let next = prev.saturating_add(curr);
            prev = curr;
            curr = next;
            if curr == u64::MAX {
                break;
            }
        }
        curr
    }
}

impl Default for RetryStrategyMiddleware {
    fn default() -> Self {
        Self::new(RetryStrategy::Exponential)
    }
}
#[async_trait]
impl MessageMiddleware for RetryStrategyMiddleware {
    /// Rejects the message once its `x-retry-count` header (missing or
    /// non-integer counts as 0) reaches `max_retries`. Otherwise it annotates
    /// the message with `x-retry-delay-ms` and `x-retry-strategy`.
    ///
    /// # Errors
    /// `BrokerError::OperationFailed` when the retry budget is exhausted.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        // Saturate rather than truncate. `as u32` would wrap e.g. 2^32 to 0 and
        // let an exhausted message slip past the limit check below.
        let retry_count = message
            .headers
            .extra
            .get("x-retry-count")
            .and_then(|v| v.as_u64())
            .map_or(0, |count| count.min(u64::from(u32::MAX)) as u32);
        if retry_count >= self.max_retries {
            return Err(BrokerError::OperationFailed(format!(
                "Max retries ({}) exceeded",
                self.max_retries
            )));
        }
        let delay_ms = self.calculate_delay(retry_count);
        message
            .headers
            .extra
            .insert("x-retry-delay-ms".to_string(), serde_json::json!(delay_ms));
        message.headers.extra.insert(
            "x-retry-strategy".to_string(),
            serde_json::json!(format!("{:?}", self.strategy)),
        );
        Ok(())
    }

    /// Consumed messages are not modified.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    fn name(&self) -> &str {
        "retry_strategy"
    }
}
/// Enforces presence of a tenant header and, optionally, an allow-list of
/// tenant ids.
pub struct TenantIsolationMiddleware {
    /// Whether a missing tenant id is rejected (default `true`).
    required: bool,
    /// Header holding the tenant id (default `x-tenant-id`).
    tenant_header: String,
    /// When set, only these tenant ids are accepted.
    allowed_tenants: Option<Vec<String>>,
}

impl TenantIsolationMiddleware {
    /// Requires the `x-tenant-id` header and accepts any tenant.
    pub fn new() -> Self {
        Self {
            required: true,
            tenant_header: String::from("x-tenant-id"),
            allowed_tenants: None,
        }
    }

    /// Sets whether a missing tenant id is an error.
    pub fn with_required_tenant(mut self, required: bool) -> Self {
        self.required = required;
        self
    }

    /// Reads the tenant id from a different header.
    pub fn with_tenant_header(mut self, header: impl Into<String>) -> Self {
        self.tenant_header = header.into();
        self
    }

    /// Restricts accepted tenant ids to `tenants`.
    pub fn with_allowed_tenants(mut self, tenants: Vec<String>) -> Self {
        self.allowed_tenants = Some(tenants);
        self
    }

    /// Checks a tenant id against the configuration.
    ///
    /// # Errors
    /// `BrokerError::Configuration` when the id is missing but required, or when
    /// an allow-list is set and the id is not on it.
    fn validate_tenant(&self, tenant_id: Option<&str>) -> Result<()> {
        match (tenant_id, &self.allowed_tenants) {
            (None, _) if self.required => Err(BrokerError::Configuration(format!(
                "Missing required tenant header: {}",
                self.tenant_header
            ))),
            (Some(tenant), Some(allowed)) if !allowed.iter().any(|known| known == tenant) => {
                Err(BrokerError::Configuration(format!(
                    "Tenant '{}' not in allowed list",
                    tenant
                )))
            }
            _ => Ok(()),
        }
    }
}

impl Default for TenantIsolationMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl MessageMiddleware for TenantIsolationMiddleware {
    /// Validates the tenant header, then marks the message with
    /// `x-tenant-validated: true`. A non-string header value is treated as
    /// missing.
    ///
    /// # Errors
    /// Propagates the tenant validation error; the message is not marked then.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let tenant = message
            .headers
            .extra
            .get(&self.tenant_header)
            .and_then(|value| value.as_str());
        self.validate_tenant(tenant)?;
        message
            .headers
            .extra
            .insert("x-tenant-validated".to_string(), serde_json::json!(true));
        Ok(())
    }

    /// Re-validates the tenant header on the consume side.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let tenant = message
            .headers
            .extra
            .get(&self.tenant_header)
            .and_then(|value| value.as_str());
        self.validate_tenant(tenant)
    }

    fn name(&self) -> &str {
        "tenant_isolation"
    }
}
/// Assigns each outgoing message to one of `partition_count` partitions by
/// hashing a key. The key is the string value of the configured key field when
/// present, otherwise the message id.
///
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable across
/// Rust releases, so partition assignments may change after a toolchain upgrade.
#[derive(Debug, Clone)]
pub struct PartitioningMiddleware {
    /// Number of partitions; `new` clamps it to at least 1.
    partition_count: usize,
    /// Header the partition id is written to (default `x-partition-id`).
    partition_header: String,
    /// Optional `headers.extra` field whose string value is the partition key.
    partition_key_fn: Option<String>,
}

impl PartitioningMiddleware {
    /// Creates a partitioner with `partition_count` partitions (minimum 1).
    pub fn new(partition_count: usize) -> Self {
        Self {
            partition_count: partition_count.max(1),
            partition_header: "x-partition-id".to_string(),
            partition_key_fn: None,
        }
    }

    /// Writes the partition id to a different header.
    pub fn with_partition_header(mut self, header: impl Into<String>) -> Self {
        self.partition_header = header.into();
        self
    }

    /// Hashes the string value of `field` instead of the message id.
    pub fn with_partition_key_field(mut self, field: impl Into<String>) -> Self {
        self.partition_key_fn = Some(field.into());
        self
    }

    /// Number of partitions.
    pub fn partition_count(&self) -> usize {
        self.partition_count
    }

    /// Hashes the partition key and reduces it modulo `partition_count`.
    ///
    /// Falls back to the message id when no key field is configured, or when
    /// the field is absent or not a string.
    fn calculate_partition(&self, message: &Message) -> usize {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let fallback = message.headers.id.to_string();
        let key: &str = self
            .partition_key_fn
            .as_ref()
            .and_then(|field| message.headers.extra.get(field))
            .and_then(|value| value.as_str())
            .unwrap_or(&fallback);

        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        (hasher.finish() % self.partition_count as u64) as usize
    }
}

impl Default for PartitioningMiddleware {
    /// Four partitions.
    fn default() -> Self {
        Self::new(4)
    }
}
#[async_trait]
impl MessageMiddleware for PartitioningMiddleware {
    /// Writes the computed partition id to the configured header and the
    /// partition count to `x-partition-count`.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let partition_id = self.calculate_partition(message);
        let extra = &mut message.headers.extra;
        extra.insert(self.partition_header.clone(), serde_json::json!(partition_id));
        extra.insert(
            "x-partition-count".to_string(),
            serde_json::json!(self.partition_count),
        );
        Ok(())
    }

    /// Consumed messages are not modified.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    fn name(&self) -> &str {
        "partitioning"
    }
}
/// Publishes a per-message timeout derived from observed durations.
///
/// With no samples, the base timeout is used as-is. Otherwise the configured
/// percentile of the samples is padded by 20% and bounded by `min_timeout` and
/// `max_timeout`.
#[derive(Debug, Clone)]
pub struct AdaptiveTimeoutMiddleware {
    /// Timeout reported while no samples exist.
    base_timeout: Duration,
    /// Lower bound for the adaptive timeout (default 1s).
    min_timeout: Duration,
    /// Upper bound for the adaptive timeout (default 5x `base_timeout`).
    max_timeout: Duration,
    /// Observed durations in milliseconds.
    samples: Vec<u64>,
    // NOTE(review): not read by any code visible here (hence the allow);
    // presumably the intended cap on `samples` — confirm.
    #[allow(dead_code)]
    max_samples: usize,
    /// Percentile in `[0.0, 1.0]` (default 0.95).
    percentile: f64,
}

impl AdaptiveTimeoutMiddleware {
    /// Creates a middleware with a 1s minimum, a `5 * base_timeout` maximum
    /// and the 95th percentile.
    pub fn new(base_timeout: Duration) -> Self {
        Self {
            base_timeout,
            min_timeout: Duration::from_secs(1),
            max_timeout: base_timeout.mul_f64(5.0),
            samples: Vec::new(),
            max_samples: 100,
            percentile: 0.95,
        }
    }

    /// Sets the lower bound.
    pub fn with_min_timeout(mut self, timeout: Duration) -> Self {
        self.min_timeout = timeout;
        self
    }

    /// Sets the upper bound.
    pub fn with_max_timeout(mut self, timeout: Duration) -> Self {
        self.max_timeout = timeout;
        self
    }

    /// Sets the percentile, clamped to `[0.0, 1.0]`.
    pub fn with_percentile(mut self, percentile: f64) -> Self {
        self.percentile = percentile.clamp(0.0, 1.0);
        self
    }

    /// Whether any duration samples have been recorded.
    pub fn has_samples(&self) -> bool {
        !self.samples.is_empty()
    }

    /// Returns `base_timeout` when there are no samples. Otherwise returns the
    /// percentile sample padded by 20%, bounded by `min_timeout` and
    /// `max_timeout`. If the bounds conflict (`min > max`), `max_timeout` wins.
    pub fn calculate_adaptive_timeout(&self) -> Duration {
        if self.samples.is_empty() {
            return self.base_timeout;
        }
        let mut sorted_samples = self.samples.clone();
        sorted_samples.sort_unstable();
        let index = ((sorted_samples.len() as f64 * self.percentile) as usize)
            .min(sorted_samples.len() - 1);
        let buffered_ms = (sorted_samples[index] as f64 * 1.2) as u64;
        // `Duration::clamp` panics when min > max. The builders allow that,
        // e.g. a small base timeout combined with a larger `with_min_timeout`.
        // Applying the bounds one at a time never panics.
        Duration::from_millis(buffered_ms)
            .max(self.min_timeout)
            .min(self.max_timeout)
    }
}

impl Default for AdaptiveTimeoutMiddleware {
    /// 30s base timeout.
    fn default() -> Self {
        Self::new(Duration::from_secs(30))
    }
}
#[async_trait]
impl MessageMiddleware for AdaptiveTimeoutMiddleware {
    /// Writes the adaptive timeout in milliseconds to `x-adaptive-timeout`, and
    /// the configured percentile to `x-timeout-percentile`.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let timeout_ms = self.calculate_adaptive_timeout().as_millis() as u64;
        let extra = &mut message.headers.extra;
        extra.insert("x-adaptive-timeout".to_string(), serde_json::json!(timeout_ms));
        extra.insert(
            "x-timeout-percentile".to_string(),
            serde_json::json!(self.percentile),
        );
        Ok(())
    }

    /// Consumed messages are not modified.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    fn name(&self) -> &str {
        "adaptive_timeout"
    }
}