use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Record of a message whose processing failed, in the form stored on a dead-letter queue.
///
/// Serialized to and from JSON by [`DeadLetterMessage::to_json`] and
/// [`DeadLetterMessage::from_json`]. Serde writes fields in declaration order, so reordering the
/// fields below changes the JSON output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeadLetterMessage {
    /// Identifier of the original message.
    pub original_id: String,
    /// Payload of the original message, kept as arbitrary JSON.
    pub original_payload: serde_json::Value,
    /// Queue the original message came from, as supplied to `new`.
    pub source_queue: String,
    /// Number of processing attempts, as supplied to `new`.
    pub attempt_count: u32,
    /// Description of the failure, as supplied to `new`.
    pub error_message: String,
    /// Creation time of this record; `new` sets it to `Utc::now()`.
    pub dlq_timestamp: DateTime<Utc>,
    /// Worker that last handled the message, if known; `new` leaves it `None` and
    /// `with_worker_id` sets it.
    pub last_worker_id: Option<String>,
    /// Free-form diagnostic data. `new` initializes it to an empty JSON object and
    /// `with_context` adds keys to it. Being public and deserialized verbatim, it is not
    /// guaranteed to be an object.
    pub failure_context: serde_json::Value,
}
impl DeadLetterMessage {
    /// Builds a dead-letter record for a failed message.
    ///
    /// `dlq_timestamp` is set to the current UTC time, `last_worker_id` starts as `None`, and
    /// `failure_context` starts as an empty JSON object.
    pub fn new(
        original_id: String,
        original_payload: serde_json::Value,
        source_queue: String,
        attempt_count: u32,
        error_message: String,
    ) -> Self {
        Self {
            original_id,
            original_payload,
            source_queue,
            attempt_count,
            error_message,
            dlq_timestamp: Utc::now(),
            last_worker_id: None,
            failure_context: serde_json::json!({}),
        }
    }

    /// Records the id of the worker that last handled the message and returns the message.
    pub fn with_worker_id(mut self, worker_id: String) -> Self {
        self.last_worker_id = Some(worker_id);
        self
    }

    /// Inserts (or overwrites) `key` in `failure_context` and returns the message.
    ///
    /// `failure_context` is a public field and is restored verbatim by [`Self::from_json`], so
    /// it may hold something other than a JSON object. Such a value cannot carry keys. Rather
    /// than silently discarding the new entry, it is first turned into an object: `null`
    /// becomes an empty object, and any other value is preserved under the `"previous"` key.
    pub fn with_context(mut self, key: &str, value: serde_json::Value) -> Self {
        // `Value::default()` is `Null`, so `take` leaves a placeholder we overwrite below.
        let mut context = match std::mem::take(&mut self.failure_context) {
            serde_json::Value::Object(map) => map,
            serde_json::Value::Null => serde_json::Map::new(),
            other => std::iter::once(("previous".to_string(), other)).collect(),
        };
        context.insert(key.to_string(), value);
        self.failure_context = serde_json::Value::Object(context);
        self
    }

    /// Serializes the message to a compact JSON string.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if serialization fails.
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string(self)
    }

    /// Parses a message previously produced by [`Self::to_json`].
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `json` is malformed or does not match the struct.
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str(json)
    }
}
/// Thresholds used by `PoisonPillTracker` to decide when a message is a poison pill.
#[derive(Debug, Clone)]
pub struct PoisonPillConfig {
    /// Number of failures inside `time_window` at which a message is reported as a poison
    /// pill (the comparison is `>=`, and the failure being recorded counts).
    pub max_failures: u32,
    /// Length of the sliding window; failures older than this are discarded when a new
    /// failure is recorded.
    pub time_window: std::time::Duration,
    /// Not read anywhere in this file. Presumably tells callers to dead-letter a detected
    /// poison pill immediately rather than retrying — TODO confirm against callers.
    pub immediate_dlq: bool,
}
impl Default for PoisonPillConfig {
    /// Ten failures within a one-hour window mark a poison pill, with `immediate_dlq` enabled.
    fn default() -> Self {
        const ONE_HOUR: std::time::Duration = std::time::Duration::from_secs(60 * 60);
        Self {
            immediate_dlq: true,
            max_failures: 10,
            time_window: ONE_HOUR,
        }
    }
}
/// Tracks recent failure timestamps per message id to detect poison pills.
///
/// State lives behind a `Mutex`, so every method takes `&self` and the tracker can be shared
/// between threads (e.g. behind an `Arc`).
#[derive(Debug)]
pub struct PoisonPillTracker {
    /// Thresholds applied by the tracker's methods.
    config: PoisonPillConfig,
    /// Message id -> timestamps of its recorded failures. Old timestamps are only dropped
    /// when a tracker method prunes them, so a stored list may include failures that are
    /// already outside `config.time_window`.
    failure_counts: std::sync::Mutex<std::collections::HashMap<String, Vec<DateTime<Utc>>>>,
}
impl PoisonPillTracker {
pub fn new(config: PoisonPillConfig) -> Self {
Self {
config,
failure_counts: std::sync::Mutex::new(std::collections::HashMap::new()),
}
}
pub fn record_failure(&self, message_id: &str) -> bool {
let mut counts = self.failure_counts.lock().unwrap();
let now = Utc::now();
let failures = counts.entry(message_id.to_string()).or_default();
failures.push(now);
let cutoff = now - chrono::Duration::from_std(self.config.time_window).unwrap_or_default();
failures.retain(|&t| t > cutoff);
let is_poison_pill = failures.len() >= self.config.max_failures as usize;
if is_poison_pill {
tracing::warn!(
"Poison pill detected for message {}: {} failures in {:?}",
message_id,
failures.len(),
self.config.time_window
);
}
is_poison_pill
}
pub fn get_failure_count(&self, message_id: &str) -> usize {
let counts = self.failure_counts.lock().unwrap();
counts.get(message_id).map(|v| v.len()).unwrap_or(0)
}
pub fn clear(&self, message_id: &str) {
let mut counts = self.failure_counts.lock().unwrap();
counts.remove(message_id);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Shared fixture: the message used by the `DeadLetterMessage` tests.
    fn sample_message() -> DeadLetterMessage {
        DeadLetterMessage::new(
            "msg-123".to_string(),
            serde_json::json!({"data": "test"}),
            "my-queue".to_string(),
            5,
            "Processing failed".to_string(),
        )
    }

    #[test]
    fn test_dead_letter_message_creation() {
        let dlq_msg = sample_message();
        assert_eq!(dlq_msg.original_id, "msg-123");
        assert_eq!(dlq_msg.attempt_count, 5);
        assert_eq!(dlq_msg.source_queue, "my-queue");
        assert_eq!(dlq_msg.last_worker_id, None);
        assert_eq!(dlq_msg.failure_context, serde_json::json!({}));
    }

    #[test]
    fn test_dead_letter_message_serialization() {
        let dlq_msg = sample_message();
        let json = dlq_msg.to_json().unwrap();
        let parsed = DeadLetterMessage::from_json(&json).unwrap();
        assert_eq!(parsed.original_id, dlq_msg.original_id);
        assert_eq!(parsed.attempt_count, dlq_msg.attempt_count);
    }

    #[test]
    fn test_builder_methods_survive_round_trip() {
        let dlq_msg = sample_message()
            .with_worker_id("worker-7".to_string())
            .with_context("stage", serde_json::json!("decode"))
            .with_context("retryable", serde_json::json!(false));
        assert_eq!(dlq_msg.last_worker_id.as_deref(), Some("worker-7"));
        assert_eq!(
            dlq_msg.failure_context,
            serde_json::json!({"stage": "decode", "retryable": false})
        );

        let parsed = DeadLetterMessage::from_json(&dlq_msg.to_json().unwrap()).unwrap();
        assert_eq!(parsed.last_worker_id, dlq_msg.last_worker_id);
        assert_eq!(parsed.failure_context, dlq_msg.failure_context);
        assert_eq!(parsed.dlq_timestamp, dlq_msg.dlq_timestamp);
    }

    #[test]
    fn test_default_config() {
        let config = PoisonPillConfig::default();
        assert_eq!(config.max_failures, 10);
        assert_eq!(config.time_window, Duration::from_secs(3600));
        assert!(config.immediate_dlq);
    }

    #[test]
    fn test_poison_pill_detection() {
        let config = PoisonPillConfig {
            max_failures: 3,
            time_window: Duration::from_secs(60),
            immediate_dlq: true,
        };
        let tracker = PoisonPillTracker::new(config);
        assert!(!tracker.record_failure("msg-1"));
        assert!(!tracker.record_failure("msg-1"));
        assert!(tracker.record_failure("msg-1"));
        // Stays a poison pill once the threshold has been crossed.
        assert!(tracker.record_failure("msg-1"));
    }

    #[test]
    fn test_failure_count_is_per_message_and_clearable() {
        let tracker = PoisonPillTracker::new(PoisonPillConfig {
            max_failures: 5,
            time_window: Duration::from_secs(60),
            immediate_dlq: false,
        });
        assert_eq!(tracker.get_failure_count("msg-1"), 0);

        tracker.record_failure("msg-1");
        tracker.record_failure("msg-1");
        tracker.record_failure("msg-2");
        assert_eq!(tracker.get_failure_count("msg-1"), 2);
        assert_eq!(tracker.get_failure_count("msg-2"), 1);

        tracker.clear("msg-1");
        assert_eq!(tracker.get_failure_count("msg-1"), 0);
        assert_eq!(tracker.get_failure_count("msg-2"), 1);
    }
}