use crate::{BrokerError, MessageMiddleware, Result};
use async_trait::async_trait;
use celers_protocol::Message;
use std::collections::HashMap;
/// Middleware that advertises a batch-acknowledgement hint on published messages.
///
/// The hint consists of a batch size written under a configurable header name.
#[derive(Debug, Clone)]
pub struct BatchAckHintMiddleware {
    /// Batch size written into the hint header; never below 1.
    batch_size: usize,
    /// Name of the header that carries the batch size.
    hint_header: String,
}

impl BatchAckHintMiddleware {
    /// Batch size used by [`Default`].
    const DEFAULT_BATCH_SIZE: usize = 10;

    /// Creates a middleware that hints the given batch size.
    ///
    /// A `batch_size` of zero is raised to 1. The hint header defaults to
    /// `x-batch-ack-hint`.
    pub fn new(batch_size: usize) -> Self {
        let batch_size = if batch_size == 0 { 1 } else { batch_size };
        Self {
            batch_size,
            hint_header: String::from("x-batch-ack-hint"),
        }
    }

    /// Replaces the name of the header that carries the batch size.
    pub fn with_hint_header(self, header: impl Into<String>) -> Self {
        Self {
            hint_header: header.into(),
            ..self
        }
    }

    /// Returns the configured batch size.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }
}

impl Default for BatchAckHintMiddleware {
    /// Builds a middleware with a batch size of 10.
    fn default() -> Self {
        Self::new(Self::DEFAULT_BATCH_SIZE)
    }
}
#[async_trait]
impl MessageMiddleware for BatchAckHintMiddleware {
    /// Writes the batch size under the configured hint header and sets
    /// `x-batch-ack-recommended` to `true`. Always succeeds.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let extra = &mut message.headers.extra;
        extra.insert(self.hint_header.clone(), serde_json::json!(self.batch_size));
        extra.insert(
            String::from("x-batch-ack-recommended"),
            serde_json::json!(true),
        );
        Ok(())
    }

    /// No-op on the consume side.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "batch_ack_hint"
    }
}
/// Middleware that rejects low-priority messages while the recorded load is
/// above a threshold.
#[derive(Debug, Clone)]
pub struct LoadSheddingMiddleware {
    /// Load level (clamped to `0.0..=1.0`) above which shedding starts.
    load_threshold: f64,
    /// Messages with a priority strictly below this value may be shed.
    priority_cutoff: u8,
    /// Most recently reported load (clamped to `0.0..=1.0`).
    current_load: f64,
}

impl LoadSheddingMiddleware {
    /// Creates a middleware with the given threshold, a priority cutoff of 3
    /// and a current load of 0.
    pub fn new(load_threshold: f64) -> Self {
        let load_threshold = load_threshold.clamp(0.0, 1.0);
        Self {
            load_threshold,
            priority_cutoff: 3,
            current_load: 0.0,
        }
    }

    /// Sets the priority cutoff; values above 10 are lowered to 10.
    pub fn with_priority_cutoff(self, cutoff: u8) -> Self {
        Self {
            priority_cutoff: cutoff.min(10),
            ..self
        }
    }

    /// Records the current load, clamped to `0.0..=1.0`.
    pub fn update_load(&mut self, load: f64) {
        self.current_load = load.clamp(0.0, 1.0);
    }

    /// Returns the configured load threshold.
    pub fn threshold(&self) -> f64 {
        self.load_threshold
    }

    /// Returns `true` when the load is strictly above the threshold and
    /// `priority` is strictly below the cutoff.
    fn should_shed(&self, priority: u8) -> bool {
        let overloaded = self.current_load > self.load_threshold;
        let low_priority = priority < self.priority_cutoff;
        overloaded && low_priority
    }
}

impl Default for LoadSheddingMiddleware {
    /// Builds a middleware with a load threshold of 0.8.
    fn default() -> Self {
        Self::new(0.8)
    }
}
#[async_trait]
impl MessageMiddleware for LoadSheddingMiddleware {
    /// Rejects the message when it qualifies for load shedding.
    ///
    /// The priority is read from the `priority` header (default 5 when the
    /// header is missing or not an unsigned integer).
    ///
    /// # Errors
    ///
    /// Returns `BrokerError::OperationFailed` when the message is shed. Before
    /// returning, the headers `x-load-shed` and `x-current-load` are written
    /// to the message.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let priority = message
            .headers
            .extra
            .get("priority")
            .and_then(|v| v.as_u64())
            // Saturate rather than truncate: a plain `as u8` would turn 256
            // into 0 and make a high-priority message look sheddable.
            .map(|v| v.min(u64::from(u8::MAX)) as u8)
            .unwrap_or(5);

        if !self.should_shed(priority) {
            return Ok(());
        }

        message
            .headers
            .extra
            .insert("x-load-shed".to_string(), serde_json::json!(true));
        message.headers.extra.insert(
            "x-current-load".to_string(),
            serde_json::json!(self.current_load),
        );
        Err(BrokerError::OperationFailed(format!(
            "Load shedding: current load {:.2} exceeds threshold {:.2}",
            self.current_load, self.load_threshold
        )))
    }

    /// No-op on the consume side.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "load_shedding"
    }
}
/// Middleware that raises a message's priority based on its age and retry count.
#[derive(Debug, Clone)]
pub struct MessagePriorityEscalationMiddleware {
    /// Age interval, in seconds, that earns one escalation step.
    /// A value of 0 disables age-based escalation.
    age_threshold_secs: u64,
    /// Priority added per elapsed age interval; at least 1.
    escalation_step: u8,
    /// Upper bound applied to the computed priority; at most 10.
    max_priority: u8,
    /// Whether retried messages receive an extra boost.
    escalate_on_retry: bool,
}

impl MessagePriorityEscalationMiddleware {
    /// Creates a middleware with an escalation step of 1, a maximum priority
    /// of 10 and retry escalation enabled.
    pub fn new(age_threshold_secs: u64) -> Self {
        Self {
            age_threshold_secs,
            escalation_step: 1,
            max_priority: 10,
            escalate_on_retry: true,
        }
    }

    /// Sets the per-interval escalation step; 0 is raised to 1.
    pub fn with_escalation_step(mut self, step: u8) -> Self {
        self.escalation_step = step.max(1);
        self
    }

    /// Sets the maximum priority; values above 10 are lowered to 10.
    pub fn with_max_priority(mut self, max: u8) -> Self {
        self.max_priority = max.min(10);
        self
    }

    /// Enables or disables the retry-based boost.
    pub fn with_escalate_on_retry(mut self, enable: bool) -> Self {
        self.escalate_on_retry = enable;
        self
    }

    /// Returns the configured age threshold in seconds.
    pub fn age_threshold_secs(&self) -> u64 {
        self.age_threshold_secs
    }

    /// Computes the escalated priority.
    ///
    /// Adds `escalation_step` for every full `age_threshold_secs` interval in
    /// `age_secs`, then (when retry escalation is on) adds the retry count
    /// capped at 3, and finally caps the result at `max_priority`. All
    /// arithmetic saturates, so large ages, steps or retry counts cannot
    /// overflow or wrap around.
    fn calculate_priority(&self, base_priority: u8, age_secs: u64, retries: u32) -> u8 {
        let mut priority = base_priority;

        // `checked_div` returns `None` for a zero threshold, which skips the
        // age boost instead of panicking on a division by zero.
        if let Some(intervals) = age_secs.checked_div(self.age_threshold_secs) {
            let boost = intervals
                .saturating_mul(u64::from(self.escalation_step))
                .min(u64::from(u8::MAX)) as u8;
            priority = priority.saturating_add(boost);
        }

        if self.escalate_on_retry && retries > 0 {
            // Cap before narrowing so that e.g. 256 retries still count as 3.
            let retry_boost = retries.min(3) as u8;
            priority = priority.saturating_add(retry_boost);
        }

        priority.min(self.max_priority)
    }
}

impl Default for MessagePriorityEscalationMiddleware {
    /// Builds a middleware with an age threshold of 300 seconds.
    fn default() -> Self {
        Self::new(300)
    }
}
#[async_trait]
impl MessageMiddleware for MessagePriorityEscalationMiddleware {
    /// Recomputes the message priority and, when it changes, rewrites the
    /// `priority` header and records `x-priority-escalated` and
    /// `x-original-priority`.
    ///
    /// The base priority comes from the `priority` header (default 5 when the
    /// header is missing or not an unsigned integer). Always succeeds.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let base_priority = message
            .headers
            .extra
            .get("priority")
            .and_then(|v| v.as_u64())
            // Saturate rather than truncate: a plain `as u8` would turn 256
            // into priority 0.
            .map(|v| v.min(u64::from(u8::MAX)) as u8)
            .unwrap_or(5);

        // The age is fixed at 0 here, so at publish time only the retry boost
        // and the maximum-priority cap can change the value.
        let age_secs = 0;
        let retries = message.headers.retries.unwrap_or(0);
        let new_priority = self.calculate_priority(base_priority, age_secs, retries);

        if new_priority != base_priority {
            message
                .headers
                .extra
                .insert("priority".to_string(), serde_json::json!(new_priority));
            message
                .headers
                .extra
                .insert("x-priority-escalated".to_string(), serde_json::json!(true));
            message.headers.extra.insert(
                "x-original-priority".to_string(),
                serde_json::json!(base_priority),
            );
        }
        Ok(())
    }

    /// No-op on the consume side.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "priority_escalation"
    }
}
/// Middleware that annotates published messages with observability headers.
#[derive(Debug, Clone)]
pub struct ObservabilityMiddleware {
    /// Service name written to the `x-service-name` header.
    service_name: String,
    /// Whether the `x-observability-enabled` header is written.
    enable_metrics: bool,
    /// Whether the `x-log-level` header is written.
    enable_logging: bool,
    /// Value written to the `x-log-level` header.
    log_level: String,
}

impl ObservabilityMiddleware {
    /// Creates a middleware for `service_name` with metrics and logging
    /// enabled and a log level of `info`.
    pub fn new(service_name: impl Into<String>) -> Self {
        Self {
            service_name: service_name.into(),
            enable_metrics: true,
            enable_logging: true,
            log_level: String::from("info"),
        }
    }

    /// Turns the metrics flag off.
    pub fn without_metrics(self) -> Self {
        Self {
            enable_metrics: false,
            ..self
        }
    }

    /// Turns the logging flag off.
    pub fn without_logging(self) -> Self {
        Self {
            enable_logging: false,
            ..self
        }
    }

    /// Replaces the log level string.
    pub fn with_log_level(self, level: impl Into<String>) -> Self {
        Self {
            log_level: level.into(),
            ..self
        }
    }

    /// Returns the configured service name.
    pub fn service_name(&self) -> &str {
        self.service_name.as_str()
    }
}

impl Default for ObservabilityMiddleware {
    /// Builds a middleware for the service name `unknown-service`.
    fn default() -> Self {
        Self::new("unknown-service")
    }
}
#[async_trait]
impl MessageMiddleware for ObservabilityMiddleware {
    /// Writes the observability headers.
    ///
    /// `x-observability-enabled` is written only when metrics are enabled,
    /// `x-log-level` only when logging is enabled, and `x-service-name`
    /// always. Always succeeds.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let extra = &mut message.headers.extra;
        if self.enable_metrics {
            extra.insert(
                String::from("x-observability-enabled"),
                serde_json::json!(true),
            );
        }
        if self.enable_logging {
            extra.insert(String::from("x-log-level"), serde_json::json!(self.log_level));
        }
        extra.insert(
            String::from("x-service-name"),
            serde_json::json!(self.service_name),
        );
        Ok(())
    }

    /// No-op on the consume side.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "observability"
    }
}
/// Middleware that tracks a health flag and stamps it onto published messages.
pub struct HealthCheckMiddleware {
    /// Unix time (seconds) of the last recorded check; starts at 0.
    last_check: std::sync::Arc<std::sync::Mutex<u64>>,
    /// Minimum number of seconds between recorded checks.
    check_interval_secs: u64,
    /// Current health flag; starts as healthy.
    is_healthy: std::sync::Arc<std::sync::Mutex<bool>>,
}

impl HealthCheckMiddleware {
    /// Creates a healthy middleware with a 60-second check interval.
    pub fn new() -> Self {
        Self {
            last_check: std::sync::Arc::new(std::sync::Mutex::new(0)),
            check_interval_secs: 60,
            is_healthy: std::sync::Arc::new(std::sync::Mutex::new(true)),
        }
    }

    /// Sets the minimum number of seconds between recorded checks.
    pub fn with_check_interval_secs(mut self, interval: u64) -> Self {
        self.check_interval_secs = interval;
        self
    }

    /// Returns the current health flag.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn is_healthy(&self) -> bool {
        *self.is_healthy.lock().unwrap()
    }

    /// Sets the health flag to unhealthy.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn mark_unhealthy(&self) {
        *self.is_healthy.lock().unwrap() = false;
    }

    /// Sets the health flag to healthy.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn mark_healthy(&self) {
        *self.is_healthy.lock().unwrap() = true;
    }

    /// Current Unix time in seconds.
    ///
    /// A system clock set before the Unix epoch yields 0 instead of a panic.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Returns `true` when at least `check_interval_secs` have elapsed since
    /// the last recorded check.
    fn should_check(&self) -> bool {
        let last = *self.last_check.lock().unwrap();
        // Saturate so that a stored timestamp ahead of the current clock
        // (e.g. after the clock stepped backwards) counts as zero elapsed
        // time rather than underflowing.
        Self::now_secs().saturating_sub(last) >= self.check_interval_secs
    }

    /// Records the current time as the time of the last check.
    fn update_check_time(&self) {
        *self.last_check.lock().unwrap() = Self::now_secs();
    }
}

impl Default for HealthCheckMiddleware {
    /// Equivalent to [`HealthCheckMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl MessageMiddleware for HealthCheckMiddleware {
    /// Refreshes the last-check time when a check is due, then writes
    /// `x-health-status` (`healthy` or `unhealthy`). Always succeeds.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let check_due = self.should_check();
        if check_due {
            self.update_check_time();
        }

        let health_status = match self.is_healthy() {
            true => "healthy",
            false => "unhealthy",
        };
        message.headers.extra.insert(
            String::from("x-health-status"),
            serde_json::json!(health_status),
        );
        Ok(())
    }

    /// No-op on the consume side.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "health_check"
    }
}
/// Middleware that tags published messages with an environment name and
/// custom key/value tags.
pub struct MessageTaggingMiddleware {
    /// Value written to the `x-environment` header.
    environment: String,
    /// Custom tags; each is written as an `x-tag-<key>` header.
    tags: HashMap<String, String>,
}

impl MessageTaggingMiddleware {
    /// Creates a middleware for `environment` with no custom tags.
    pub fn new(environment: impl Into<String>) -> Self {
        let environment = environment.into();
        Self {
            environment,
            tags: HashMap::new(),
        }
    }

    /// Replaces the whole tag map with `tags`.
    pub fn with_tags(self, tags: HashMap<String, String>) -> Self {
        Self { tags, ..self }
    }

    /// Adds one tag, overwriting any existing tag with the same key.
    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (tag_key, tag_value) = (key.into(), value.into());
        self.tags.insert(tag_key, tag_value);
        self
    }
}
#[async_trait]
impl MessageMiddleware for MessageTaggingMiddleware {
    /// Writes `x-environment`, one `x-tag-<key>` header per configured tag,
    /// and an `x-category` header derived from the task name.
    ///
    /// The category is chosen by the first matching substring of the task
    /// name: `email` → `communication`, `report` → `analytics`,
    /// `process` → `computation`, otherwise `general`. Always succeeds.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        message.headers.extra.insert(
            String::from("x-environment"),
            serde_json::json!(self.environment),
        );

        for (tag_key, tag_value) in self.tags.iter() {
            let header = format!("x-tag-{}", tag_key);
            message
                .headers
                .extra
                .insert(header, serde_json::json!(tag_value));
        }

        // Arms are tried top to bottom, so `email` wins over `report`, etc.
        let category = match message.task_name() {
            task if task.contains("email") => "communication",
            task if task.contains("report") => "analytics",
            task if task.contains("process") => "computation",
            _ => "general",
        };
        message
            .headers
            .extra
            .insert(String::from("x-category"), serde_json::json!(category));
        Ok(())
    }

    /// No-op on the consume side.
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "message_tagging"
    }
}
/// Middleware that attaches cost estimates to messages.
pub struct CostAttributionMiddleware {
    /// Flat cost charged per message.
    message_cost: f64,
    /// Cost per second between the publish and consume timestamps.
    compute_cost_per_sec: f64,
    /// Cost per mebibyte (1024 * 1024 bytes) of message body.
    storage_cost_per_mb: f64,
}

impl CostAttributionMiddleware {
    /// Creates a middleware with the given per-message cost and zero compute
    /// and storage rates.
    pub fn new(message_cost: f64) -> Self {
        Self {
            message_cost,
            compute_cost_per_sec: 0.0,
            storage_cost_per_mb: 0.0,
        }
    }

    /// Sets the per-second compute rate.
    pub fn with_compute_cost_per_sec(self, cost: f64) -> Self {
        Self {
            compute_cost_per_sec: cost,
            ..self
        }
    }

    /// Sets the per-mebibyte storage rate.
    pub fn with_storage_cost_per_mb(self, cost: f64) -> Self {
        Self {
            storage_cost_per_mb: cost,
            ..self
        }
    }

    /// Returns the flat message cost plus the storage cost of the body.
    fn calculate_cost(&self, message: &Message) -> f64 {
        let body_bytes = message.body.len() as f64;
        let size_mb = body_bytes / (1024.0 * 1024.0);
        self.message_cost + size_mb * self.storage_cost_per_mb
    }
}
#[async_trait]
impl MessageMiddleware for CostAttributionMiddleware {
    /// Writes the cost headers for an outgoing message.
    ///
    /// * `x-cost-estimate` – estimated cost formatted with six decimals.
    /// * `x-cost-tenant` – the string value of `x-tenant`, else of `tenant`,
    ///   else `default`.
    /// * `x-cost-timestamp` – current Unix time in seconds, as a string.
    ///
    /// Always succeeds.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let cost = self.calculate_cost(message);
        message.headers.extra.insert(
            "x-cost-estimate".to_string(),
            serde_json::json!(format!("{:.6}", cost)),
        );

        let tenant = message
            .headers
            .extra
            .get("x-tenant")
            .and_then(|v| v.as_str())
            .or_else(|| message.headers.extra.get("tenant").and_then(|v| v.as_str()))
            .unwrap_or("default")
            .to_string();
        message
            .headers
            .extra
            .insert("x-cost-tenant".to_string(), serde_json::json!(tenant));

        // A clock set before the Unix epoch yields 0 instead of a panic.
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        message.headers.extra.insert(
            "x-cost-timestamp".to_string(),
            serde_json::json!(timestamp.to_string()),
        );
        Ok(())
    }

    /// Writes `x-cost-actual`: the published estimate plus the compute cost
    /// for the seconds elapsed since `x-cost-timestamp`.
    ///
    /// Nothing is written unless both `x-cost-timestamp` and
    /// `x-cost-estimate` are present as strings that parse as `u64` and `f64`
    /// respectively. Always succeeds.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let extra = &message.headers.extra;
        let started_at = extra
            .get("x-cost-timestamp")
            .and_then(|v| v.as_str())
            .and_then(|s| s.parse::<u64>().ok());
        let base_cost = extra
            .get("x-cost-estimate")
            .and_then(|v| v.as_str())
            .and_then(|s| s.parse::<f64>().ok());

        if let (Some(started_at), Some(base_cost)) = (started_at, base_cost) {
            let now = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            // The timestamp comes from a message header and may be ahead of
            // the local clock; saturate so that counts as zero elapsed time
            // instead of underflowing.
            let duration_secs = now.saturating_sub(started_at) as f64;
            let total_cost = base_cost + duration_secs * self.compute_cost_per_sec;
            message.headers.extra.insert(
                "x-cost-actual".to_string(),
                serde_json::json!(format!("{:.6}", total_cost)),
            );
        }
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "cost_attribution"
    }
}
/// Middleware that records processing times and reports SLA compliance.
pub struct SLAMonitoringMiddleware {
    /// Processing-time target in milliseconds.
    target_ms: u64,
    /// Percentile setting (1–99). Stored by the builder; not read by any
    /// code visible in this section.
    percentile: u8,
    /// Compliance rate (0.0–1.0) below which an alert is raised.
    alert_threshold: f64,
    /// Recorded processing times in milliseconds.
    processing_times: std::sync::Arc<std::sync::Mutex<Vec<u64>>>,
}

impl SLAMonitoringMiddleware {
    /// Creates a monitor with the given target, percentile 95 and an alert
    /// threshold of 0.9.
    pub fn new(target_ms: u64) -> Self {
        let processing_times = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        Self {
            target_ms,
            percentile: 95,
            alert_threshold: 0.9,
            processing_times,
        }
    }

    /// Sets the percentile, clamped to `1..=99`.
    pub fn with_percentile(self, percentile: u8) -> Self {
        Self {
            percentile: percentile.clamp(1, 99),
            ..self
        }
    }

    /// Sets the alert threshold, clamped to `0.0..=1.0`.
    pub fn with_alert_threshold(self, threshold: f64) -> Self {
        Self {
            alert_threshold: threshold.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Returns the fraction of recorded samples at or below the target.
    ///
    /// Returns 1.0 when no samples have been recorded.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn compliance_rate(&self) -> f64 {
        let samples = self.processing_times.lock().unwrap();
        match samples.len() {
            0 => 1.0,
            total => {
                let met = samples
                    .iter()
                    .filter(|&&elapsed| elapsed <= self.target_ms)
                    .count();
                met as f64 / total as f64
            }
        }
    }

    /// Returns `true` when the compliance rate is strictly below the alert
    /// threshold.
    pub fn should_alert(&self) -> bool {
        let rate = self.compliance_rate();
        rate < self.alert_threshold
    }
}
#[async_trait]
impl MessageMiddleware for SLAMonitoringMiddleware {
    /// Writes `x-sla-target-ms` and `x-sla-start-ms` (current Unix time in
    /// milliseconds). Always succeeds.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        message.headers.extra.insert(
            "x-sla-target-ms".to_string(),
            serde_json::json!(self.target_ms),
        );
        // A clock set before the Unix epoch yields 0 instead of a panic.
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64;
        message
            .headers
            .extra
            .insert("x-sla-start-ms".to_string(), serde_json::json!(timestamp));
        Ok(())
    }

    /// Records the processing time and writes the SLA result headers.
    ///
    /// Does nothing unless `x-sla-start-ms` is present as an unsigned
    /// integer. Otherwise the elapsed milliseconds are appended to the sample
    /// list and written to `x-sla-processing-ms`, `x-sla-met` records whether
    /// the target was met, and `x-sla-alert` is set when the monitor reports
    /// that an alert is due. Always succeeds.
    ///
    /// NOTE(review): samples are appended on every call and nothing in this
    /// impl trims the list — confirm whether it is bounded elsewhere.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let start_ms = match message
            .headers
            .extra
            .get("x-sla-start-ms")
            .and_then(|v| v.as_u64())
        {
            Some(ms) => ms,
            None => return Ok(()),
        };

        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64;
        // The start time comes from a message header and may be ahead of the
        // local clock; saturate so that counts as zero elapsed time instead
        // of underflowing.
        let processing_time = now_ms.saturating_sub(start_ms);
        self.processing_times.lock().unwrap().push(processing_time);

        let within_sla = processing_time <= self.target_ms;
        message
            .headers
            .extra
            .insert("x-sla-met".to_string(), serde_json::json!(within_sla));
        message.headers.extra.insert(
            "x-sla-processing-ms".to_string(),
            serde_json::json!(processing_time),
        );
        if self.should_alert() {
            message
                .headers
                .extra
                .insert("x-sla-alert".to_string(), serde_json::json!(true));
        }
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "sla_monitoring"
    }
}
/// Middleware that stamps messages with a version and validates the version
/// of consumed messages.
pub struct MessageVersioningMiddleware {
    /// Version written to outgoing messages.
    current_version: String,
    /// Oldest accepted version; `None` accepts every version.
    min_supported_version: Option<String>,
    /// Whether consumed messages are re-stamped with the current version.
    auto_upgrade: bool,
}

impl MessageVersioningMiddleware {
    /// Creates a middleware with no minimum version and auto-upgrade off.
    pub fn new(current_version: impl Into<String>) -> Self {
        Self {
            current_version: current_version.into(),
            min_supported_version: None,
            auto_upgrade: false,
        }
    }

    /// Sets the oldest accepted version.
    pub fn with_min_supported_version(mut self, version: impl Into<String>) -> Self {
        self.min_supported_version = Some(version.into());
        self
    }

    /// Enables or disables auto-upgrade.
    pub fn with_auto_upgrade(mut self, enabled: bool) -> Self {
        self.auto_upgrade = enabled;
        self
    }

    /// Parses a dot-separated numeric version such as `1.12.0`.
    ///
    /// Returns `None` when any component is not an unsigned integer.
    fn parse_version(version: &str) -> Option<Vec<u64>> {
        version
            .split('.')
            .map(|component| component.parse::<u64>().ok())
            .collect()
    }

    /// Returns `true` when `version` is at least the minimum supported
    /// version, or when no minimum is configured.
    ///
    /// Dot-separated numeric versions are compared component by component as
    /// numbers, with missing trailing components treated as 0, so `10.0` is
    /// newer than `9.0` and `1.0` equals `1.0.0`. When either side is not
    /// purely numeric (e.g. `legacy`), plain string ordering is used.
    fn is_version_supported(&self, version: &str) -> bool {
        let min_version = match self.min_supported_version.as_deref() {
            Some(min_version) => min_version,
            None => return true,
        };

        match (
            Self::parse_version(version),
            Self::parse_version(min_version),
        ) {
            (Some(actual), Some(minimum)) => {
                let width = actual.len().max(minimum.len());
                (0..width)
                    .map(|i| {
                        (
                            actual.get(i).copied().unwrap_or(0),
                            minimum.get(i).copied().unwrap_or(0),
                        )
                    })
                    // The first differing component decides; all equal means
                    // the versions match.
                    .find(|(have, need)| have != need)
                    .map_or(true, |(have, need)| have > need)
            }
            _ => version >= min_version,
        }
    }
}
#[async_trait]
impl MessageMiddleware for MessageVersioningMiddleware {
    /// Writes the current version to `x-message-version` and
    /// `x-schema-version`. Always succeeds.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let extra = &mut message.headers.extra;
        extra.insert(
            String::from("x-message-version"),
            serde_json::json!(self.current_version),
        );
        extra.insert(
            String::from("x-schema-version"),
            serde_json::json!(self.current_version),
        );
        Ok(())
    }

    /// Validates the version of a consumed message and optionally upgrades it.
    ///
    /// * String `x-message-version`: rejected when unsupported; when
    ///   auto-upgrade is on and it differs from the current version, the old
    ///   value moves to `x-upgraded-from` and the header is set to the
    ///   current version.
    /// * Non-string `x-message-version`: left untouched.
    /// * Missing header: stamped `legacy`, or — with auto-upgrade on — stamped
    ///   with the current version and `x-upgraded-from` set to `legacy`.
    ///
    /// # Errors
    ///
    /// Returns `BrokerError::Configuration` when the message version is not
    /// supported.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        if let Some(header) = message.headers.extra.get("x-message-version") {
            let version_str = match header.as_str() {
                Some(text) => text,
                None => return Ok(()),
            };

            if !self.is_version_supported(version_str) {
                return Err(BrokerError::Configuration(format!(
                    "Unsupported message version: {}. Minimum supported: {:?}",
                    version_str, self.min_supported_version
                )));
            }

            let needs_upgrade = self.auto_upgrade && version_str != self.current_version;
            if needs_upgrade {
                message.headers.extra.insert(
                    String::from("x-upgraded-from"),
                    serde_json::json!(version_str),
                );
                message.headers.extra.insert(
                    String::from("x-message-version"),
                    serde_json::json!(self.current_version),
                );
            }
            return Ok(());
        }

        // No version header: write the final version value first, then the
        // upgrade marker, matching the original insertion order.
        let stamped_version = if self.auto_upgrade {
            serde_json::json!(self.current_version)
        } else {
            serde_json::json!("legacy")
        };
        message
            .headers
            .extra
            .insert(String::from("x-message-version"), stamped_version);
        if self.auto_upgrade {
            message
                .headers
                .extra
                .insert(String::from("x-upgraded-from"), serde_json::json!("legacy"));
        }
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "message_versioning"
    }
}
/// Shared per-consumer usage table: consumer id →
/// (message count, byte count, Unix time in seconds of the window start).
type ResourceUsageMap = std::sync::Arc<std::sync::Mutex<HashMap<String, (usize, usize, u64)>>>;

/// Middleware that enforces per-consumer message and byte quotas within a
/// time window.
pub struct ResourceQuotaMiddleware {
    /// Maximum number of messages per consumer per window.
    max_messages: usize,
    /// Maximum number of body bytes per consumer per window;
    /// `usize::MAX` by default.
    max_size_bytes: usize,
    /// Length of the quota window in seconds.
    time_window_secs: u64,
    /// Usage counters keyed by consumer id.
    usage: ResourceUsageMap,
}

impl ResourceQuotaMiddleware {
    /// Creates a quota of `max_messages` per 3600-second window with a byte
    /// limit of `usize::MAX`.
    pub fn new(max_messages: usize) -> Self {
        Self {
            max_messages,
            max_size_bytes: usize::MAX,
            time_window_secs: 3600,
            usage: std::sync::Arc::new(std::sync::Mutex::new(HashMap::new())),
        }
    }

    /// Sets the per-window byte limit.
    pub fn with_max_size_bytes(mut self, max_bytes: usize) -> Self {
        self.max_size_bytes = max_bytes;
        self
    }

    /// Sets the window length in seconds.
    pub fn with_time_window_secs(mut self, seconds: u64) -> Self {
        self.time_window_secs = seconds;
        self
    }

    /// Returns `(messages, bytes)` used by `consumer_id`, or `(0, 0)` when
    /// the consumer has no entry.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn get_usage(&self, consumer_id: &str) -> (usize, usize) {
        let usage = self.usage.lock().unwrap();
        usage
            .get(consumer_id)
            .map(|(msgs, bytes, _)| (*msgs, *bytes))
            .unwrap_or((0, 0))
    }

    /// Removes the usage entry of `consumer_id`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn reset_quota(&self, consumer_id: &str) {
        let mut usage = self.usage.lock().unwrap();
        usage.remove(consumer_id);
    }

    /// Charges one message of `message_size` bytes to `consumer_id`.
    ///
    /// The counters are reset first when the current window has expired.
    /// Nothing is charged when either limit would be exceeded.
    ///
    /// # Errors
    ///
    /// Returns `BrokerError::Configuration` when the message-count or byte
    /// quota is exceeded.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn check_and_update_quota(&self, consumer_id: &str, message_size: usize) -> Result<()> {
        // Read the clock before taking the lock; a clock set before the Unix
        // epoch yields 0 instead of a panic.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        let mut usage = self.usage.lock().unwrap();
        let (msg_count, byte_count, last_reset) =
            usage.entry(consumer_id.to_string()).or_insert((0, 0, now));

        // Saturate so that a window start ahead of the current clock counts
        // as zero elapsed time instead of underflowing.
        if now.saturating_sub(*last_reset) >= self.time_window_secs {
            *msg_count = 0;
            *byte_count = 0;
            *last_reset = now;
        }

        if *msg_count >= self.max_messages {
            return Err(BrokerError::Configuration(format!(
                "Message quota exceeded for consumer {}: {}/{}",
                consumer_id, msg_count, self.max_messages
            )));
        }

        // Saturating add: the byte total cannot wrap past `usize::MAX`.
        let projected_bytes = (*byte_count).saturating_add(message_size);
        if projected_bytes > self.max_size_bytes {
            return Err(BrokerError::Configuration(format!(
                "Size quota exceeded for consumer {}: {}/{}",
                consumer_id, byte_count, self.max_size_bytes
            )));
        }

        *msg_count += 1;
        *byte_count = projected_bytes;
        Ok(())
    }
}
#[async_trait]
impl MessageMiddleware for ResourceQuotaMiddleware {
    /// Writes `x-quota-max-messages`, and `x-quota-max-bytes` when a byte
    /// limit other than `usize::MAX` is configured. Always succeeds.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let extra = &mut message.headers.extra;
        extra.insert(
            String::from("x-quota-max-messages"),
            serde_json::json!(self.max_messages),
        );
        let byte_limit_configured = self.max_size_bytes != usize::MAX;
        if byte_limit_configured {
            extra.insert(
                String::from("x-quota-max-bytes"),
                serde_json::json!(self.max_size_bytes),
            );
        }
        Ok(())
    }

    /// Charges the message to its consumer and writes the usage headers
    /// `x-quota-used-messages` and `x-quota-used-bytes`.
    ///
    /// The consumer id is the string value of `x-consumer-id`, or `default`
    /// when that header is missing or not a string.
    ///
    /// # Errors
    ///
    /// Propagates the error from the quota check when a quota is exceeded.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let consumer_id = match message
            .headers
            .extra
            .get("x-consumer-id")
            .and_then(|v| v.as_str())
        {
            Some(id) => id.to_owned(),
            None => String::from("default"),
        };

        self.check_and_update_quota(&consumer_id, message.body.len())?;

        let (used_messages, used_bytes) = self.get_usage(&consumer_id);
        message.headers.extra.insert(
            String::from("x-quota-used-messages"),
            serde_json::json!(used_messages),
        );
        message.headers.extra.insert(
            String::from("x-quota-used-bytes"),
            serde_json::json!(used_bytes),
        );
        Ok(())
    }

    /// Identifier of this middleware.
    fn name(&self) -> &str {
        "resource_quota"
    }
}