use async_trait::async_trait;
use std::time::Duration;
use crate::error::{WorkerError, WorkerResult};
use crate::message::ReceivedMessage;
use crate::middleware::{MessageHandler, Middleware};
/// Middleware that puts an upper bound on how long the rest of the handler
/// chain may spend processing a single message.
#[derive(Debug, Clone)]
pub struct ProcessingTimeoutMiddleware {
    /// Per-message processing limit; `new` rejects a zero value.
    timeout: Duration,
}

impl ProcessingTimeoutMiddleware {
    /// Creates a middleware that enforces `timeout` on every message.
    ///
    /// # Panics
    ///
    /// Panics if `timeout` is zero.
    pub fn new(timeout: Duration) -> Self {
        // `Duration` cannot be negative, so "greater than zero" is the same
        // check as "not zero".
        assert!(
            timeout > Duration::ZERO,
            "Processing timeout must be greater than zero"
        );
        Self { timeout }
    }

    /// Returns the configured processing timeout.
    pub fn timeout(&self) -> Duration {
        self.timeout
    }
}
#[async_trait]
impl Middleware for ProcessingTimeoutMiddleware {
    /// Name this middleware reports for itself.
    fn name(&self) -> &str {
        "processing-timeout"
    }

    /// Runs `next` on `message`, giving up once the configured timeout elapses.
    ///
    /// * If the handler finishes in time, its result (success or error) is
    ///   returned unchanged and the message is not acked or nacked here.
    /// * If the timeout elapses first, the message is nacked with requeue and
    ///   `WorkerError::Timeout` is returned. A failure of the nack itself is
    ///   logged but does not change the returned error.
    async fn handle(
        &self,
        message: ReceivedMessage<serde_json::Value>,
        next: Box<dyn MessageHandler>,
    ) -> WorkerResult<()> {
        let limit = self.timeout;
        let message_id = message.message.id.clone();

        tracing::debug!(
            message_id = %message_id,
            timeout_ms = limit.as_millis(),
            "Starting message processing with timeout"
        );

        // The handler takes ownership of its argument, so it receives a clone;
        // the original stays here so it can still be nacked on timeout.
        let outcome = tokio::time::timeout(limit, next.handle(message.clone())).await;

        match outcome {
            // Handler finished in time and succeeded.
            Ok(Ok(())) => {
                tracing::debug!(
                    message_id = %message_id,
                    "Message processing completed successfully within timeout"
                );
                Ok(())
            }
            // Handler finished in time but reported an error: pass it through.
            Ok(Err(handler_err)) => {
                tracing::warn!(
                    message_id = %message_id,
                    error = %handler_err,
                    "Message processing failed (within timeout)"
                );
                Err(handler_err)
            }
            // The deadline passed before the handler completed.
            Err(_elapsed) => {
                tracing::warn!(
                    message_id = %message_id,
                    timeout_ms = limit.as_millis(),
                    "Message processing timed out - nacking with requeue"
                );

                // Best effort: a failed nack is reported, but the caller still
                // receives the timeout error below.
                if let Err(nack_err) = message.nack(true).await {
                    tracing::error!(
                        message_id = %message_id,
                        error = %nack_err,
                        "Failed to nack timed-out message"
                    );
                }

                Err(WorkerError::Timeout(format!(
                    "Message {} processing exceeded timeout of {:?}",
                    message_id, limit
                )))
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::{AckHandle, Message, MessageMetadata};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Test double for an ack handle: it records which acknowledgement calls were
/// made in shared atomic flags that the test can inspect afterwards.
#[derive(Debug)]
struct MockAckHandle {
    acked: Arc<AtomicBool>,
    nacked: Arc<AtomicBool>,
    requeued: Arc<AtomicBool>,
}

impl MockAckHandle {
    /// Builds a handle together with shared views of its `acked`, `nacked`
    /// and `requeued` flags (returned in that order), all initially `false`.
    fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>, Arc<AtomicBool>) {
        let fresh_flag = || Arc::new(AtomicBool::new(false));
        let (acked, nacked, requeued) = (fresh_flag(), fresh_flag(), fresh_flag());

        let handle = Self {
            acked: Arc::clone(&acked),
            nacked: Arc::clone(&nacked),
            requeued: Arc::clone(&requeued),
        };

        (handle, acked, nacked, requeued)
    }
}
#[async_trait]
impl AckHandle for MockAckHandle {
async fn ack(&self) -> WorkerResult<()> {
self.acked.store(true, Ordering::SeqCst);
Ok(())
}
async fn nack(&self, requeue: bool) -> WorkerResult<()> {
self.nacked.store(true, Ordering::SeqCst);
self.requeued.store(requeue, Ordering::SeqCst);
Ok(())
}
}
/// Handler that returns success immediately and ignores the message.
struct FastHandler;

#[async_trait]
impl MessageHandler for FastHandler {
    async fn handle(
        &self,
        _message: ReceivedMessage<serde_json::Value>,
    ) -> WorkerResult<()> {
        Ok(())
    }
}
/// Handler that sleeps for `delay` before reporting success.
struct SlowHandler {
    /// How long `handle` sleeps before returning.
    delay: Duration,
}

#[async_trait]
impl MessageHandler for SlowHandler {
    async fn handle(
        &self,
        _message: ReceivedMessage<serde_json::Value>,
    ) -> WorkerResult<()> {
        let pause = self.delay;
        tokio::time::sleep(pause).await;
        Ok(())
    }
}
/// Handler that always fails with `WorkerError::ProcessingFailed`.
struct FailingHandler;

#[async_trait]
impl MessageHandler for FailingHandler {
    async fn handle(
        &self,
        _message: ReceivedMessage<serde_json::Value>,
    ) -> WorkerResult<()> {
        let reason = String::from("intentional failure");
        Err(WorkerError::ProcessingFailed(reason))
    }
}
/// Builds a fixed test message backed by a `MockAckHandle`.
///
/// Returns the message followed by the mock's `acked`, `nacked` and
/// `requeued` flags, in that order.
fn create_test_message() -> (
    ReceivedMessage<serde_json::Value>,
    Arc<AtomicBool>,
    Arc<AtomicBool>,
    Arc<AtomicBool>,
) {
    let (handle, acked, nacked, requeued) = MockAckHandle::new();

    let received = ReceivedMessage::new(
        Message {
            id: "test-msg-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        },
        Arc::new(handle),
    );

    (received, acked, nacked, requeued)
}
/// A handler that finishes well inside the limit yields `Ok`, and the
/// middleware neither acks nor nacks the message.
#[tokio::test]
async fn test_fast_processing_completes() {
    let (message, acked, nacked, _) = create_test_message();
    let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));

    let outcome = middleware.handle(message, Box::new(FastHandler)).await;

    assert!(outcome.is_ok());
    // Acknowledgement is not this middleware's job on the success path.
    assert!(!acked.load(Ordering::SeqCst));
    assert!(!nacked.load(Ordering::SeqCst));
}
/// A handler that outlasts the limit produces `WorkerError::Timeout`, and the
/// message is nacked with requeue.
#[tokio::test]
async fn test_slow_processing_times_out() {
    let middleware = ProcessingTimeoutMiddleware::new(Duration::from_millis(100));
    let (message, _, nacked, requeued) = create_test_message();

    // One second of work against a 100 ms limit.
    let outcome = middleware
        .handle(
            message,
            Box::new(SlowHandler {
                delay: Duration::from_secs(1),
            }),
        )
        .await;

    assert!(outcome.is_err());
    assert!(matches!(outcome, Err(WorkerError::Timeout(_))));
    assert!(nacked.load(Ordering::SeqCst));
    assert!(requeued.load(Ordering::SeqCst));
}
#[tokio::test]
async fn test_processing_error_propagates() {
let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
let (message, _, _, _) = create_test_message();
let result = middleware.handle(message, Box::new(FailingHandler)).await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), WorkerError::ProcessingFailed(_)));
}
/// With a 50 ms limit and a 10 s handler, the middleware returns an error in
/// well under a second and nacks the message.
#[tokio::test]
async fn test_timeout_cancels_long_running_task() {
    let middleware = ProcessingTimeoutMiddleware::new(Duration::from_millis(50));
    let (message, _, nacked, _) = create_test_message();

    let started = std::time::Instant::now();
    let outcome = middleware
        .handle(
            message,
            Box::new(SlowHandler {
                delay: Duration::from_secs(10),
            }),
        )
        .await;
    let elapsed = started.elapsed();

    assert!(outcome.is_err());
    // The call must come back long before the handler's 10 s sleep would end.
    assert!(elapsed < Duration::from_secs(1));
    assert!(nacked.load(Ordering::SeqCst));
}
/// A handler that finishes shortly before the limit still succeeds.
///
/// Despite the name, the handler sleeps 80 ms against a 100 ms limit, so this
/// exercises "just under" the timeout rather than exactly at it.
#[tokio::test]
async fn test_boundary_condition_exactly_at_timeout() {
    let limit = Duration::from_millis(100);
    let handler_delay = Duration::from_millis(80);

    let middleware = ProcessingTimeoutMiddleware::new(limit);
    let (message, _, _, _) = create_test_message();

    let outcome = middleware
        .handle(
            message,
            Box::new(SlowHandler {
                delay: handler_delay,
            }),
        )
        .await;

    assert!(outcome.is_ok());
}
/// Constructing the middleware with a zero timeout must panic with the
/// documented message.
#[test]
#[should_panic(expected = "Processing timeout must be greater than zero")]
fn test_zero_timeout_panics() {
    let zero = Duration::ZERO;
    let _ = ProcessingTimeoutMiddleware::new(zero);
}
/// `timeout()` returns exactly the duration passed to `new`.
#[test]
fn test_timeout_getter() {
    let expected = Duration::from_secs(30);
    let actual = ProcessingTimeoutMiddleware::new(expected).timeout();
    assert_eq!(actual, expected);
}
}