use async_trait::async_trait;
use lapin::options::{
BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicQosOptions, QueueDeclareOptions,
};
use std::sync::Arc;
use tokio::sync::{Mutex, Notify};
use crate::backends::ReceiveResult;
use crate::backends::contract::MessageBackend;
use crate::error::{WorkerError, WorkerResult};
use crate::message::{AckHandle, Message, MessageMetadata, ReceivedMessage};
#[derive(Debug)]
pub struct RabbitMqAckHandle {
delivery_tag: u64,
ack_channel: Arc<Mutex<lapin::Channel>>,
}
#[async_trait]
impl AckHandle for RabbitMqAckHandle {
async fn ack(&self) -> WorkerResult<()> {
tracing::debug!("Attempting to ack delivery tag {}", self.delivery_tag);
let channel = self.ack_channel.lock().await;
match channel
.basic_ack(
self.delivery_tag,
BasicAckOptions {
multiple: false, },
)
.await
{
Ok(_) => {
tracing::debug!("Successfully acked delivery tag {}", self.delivery_tag);
Ok(())
}
Err(e) => {
tracing::error!("Failed to ack delivery tag {}: {}", self.delivery_tag, e);
Err(WorkerError::BackendError(format!(
"Failed to ack message: {}",
e
)))
}
}
}
async fn nack(&self, requeue: bool) -> WorkerResult<()> {
tracing::debug!(
"Attempting to nack delivery tag {} (requeue={})",
self.delivery_tag,
requeue
);
let channel = self.ack_channel.lock().await;
channel
.basic_nack(
self.delivery_tag,
lapin::options::BasicNackOptions {
multiple: false, requeue,
},
)
.await
.map_err(|e| {
tracing::error!("Failed to nack delivery tag {}: {}", self.delivery_tag, e);
WorkerError::BackendError(format!("Failed to nack message: {}", e))
})?;
Ok(())
}
}
/// Settings controlling how the RabbitMQ backend declares and consumes its queue.
#[derive(Debug, Clone)]
pub struct RabbitMqConsumerConfig {
    /// Queue that is declared (durable) and consumed from.
    pub queue_name: String,
    /// Consumer tag passed to `basic.consume`.
    pub consumer_tag: String,
    /// Forwarded as `no_ack` to `basic.consume`. When `true`, deliveries are
    /// considered acknowledged by the broker as soon as they are sent.
    pub auto_ack: bool,
    /// Prefetch limit applied with `basic.qos` (`global: false`) before consuming.
    pub prefetch_count: u16,
    /// NOTE(review): not read anywhere in this file. Confirm whether callers
    /// consult it when choosing the `requeue` flag for nacks, or whether it
    /// is dead configuration.
    pub requeue_on_nack: bool,
}

impl Default for RabbitMqConsumerConfig {
    /// Queue `worker_queue`, consumer tag `foxtive-worker`, manual
    /// acknowledgements, prefetch of 10, and `requeue_on_nack` enabled.
    fn default() -> Self {
        let queue_name = String::from("worker_queue");
        let consumer_tag = String::from("foxtive-worker");
        Self {
            queue_name,
            consumer_tag,
            auto_ack: false,
            prefetch_count: 10,
            requeue_on_nack: true,
        }
    }
}
/// Hand-off record passed from the background consumer task to `receive()`.
struct MessageEnvelope {
    /// Broker delivery tag, needed later to ack/nack the message.
    delivery_tag: u64,
    /// Decoded message carrying a JSON payload.
    message: Message<serde_json::Value>,
}

/// Message backend that consumes from a single RabbitMQ queue.
///
/// A background task reads deliveries from the broker and forwards them over
/// a bounded in-process channel; `receive()` pulls from that channel.
pub struct RabbitMqBackend {
    /// Receiving half of the channel fed by the consumer task.
    message_rx: Arc<Mutex<tokio::sync::mpsc::Receiver<MessageEnvelope>>>,
    /// Connection pool; `health_check` draws a connection from it.
    pool: deadpool_lapin::Pool,
    /// Channel the consumer was started on; also used for acks and nacks.
    consume_channel: Arc<Mutex<lapin::Channel>>,
    /// Configuration the backend was created with.
    config: RabbitMqConsumerConfig,
    /// Signalled by `shutdown()` to stop the consumer task.
    shutdown_notify: Arc<Notify>,
    /// Handle of the spawned consumer task. Dropping a tokio `JoinHandle`
    /// detaches the task; it does not abort it.
    _consumer_handle: tokio::task::JoinHandle<()>,
}

impl std::fmt::Debug for RabbitMqBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only identifying configuration is shown; channels and the pool are omitted.
        let mut out = f.debug_struct("RabbitMqBackend");
        out.field("queue", &self.config.queue_name);
        out.field("consumer_tag", &self.config.consumer_tag);
        out.finish()
    }
}
impl RabbitMqBackend {
pub async fn new(
amqp_url: impl Into<String>,
config: RabbitMqConsumerConfig,
) -> WorkerResult<Self> {
let manager =
deadpool_lapin::Manager::new(amqp_url.into(), lapin::ConnectionProperties::default());
let pool = deadpool_lapin::Pool::builder(manager)
.build()
.map_err(|e| {
WorkerError::BackendError(format!("Failed to create connection pool: {}", e))
})?;
let conn = pool
.get()
.await
.map_err(|e| WorkerError::BackendError(format!("Failed to get connection: {}", e)))?;
let consume_channel = conn
.create_channel()
.await
.map_err(|e| WorkerError::BackendError(format!("Failed to create channel: {}", e)))?;
consume_channel
.basic_qos(config.prefetch_count, BasicQosOptions { global: false })
.await
.map_err(|e| WorkerError::BackendError(format!("Failed to set QoS: {}", e)))?;
consume_channel
.queue_declare(
&config.queue_name,
QueueDeclareOptions {
durable: true,
..Default::default()
},
lapin::types::FieldTable::default(),
)
.await
.map_err(|e| WorkerError::BackendError(format!("Failed to declare queue: {}", e)))?;
let (tx, rx) = tokio::sync::mpsc::channel(500);
let shutdown_notify = Arc::new(Notify::new());
let consumer_tag = config.consumer_tag.clone();
let queue_name = config.queue_name.clone();
let mut lapin_consumer = consume_channel
.basic_consume(
&queue_name,
&consumer_tag,
BasicConsumeOptions {
no_ack: config.auto_ack,
..Default::default()
},
lapin::types::FieldTable::default(),
)
.await
.map_err(|e| WorkerError::BackendError(format!("Failed to start consumer: {}", e)))?;
let notify_clone = shutdown_notify.clone();
let consumer_handle = tokio::spawn(async move {
use futures_util::StreamExt;
loop {
tokio::select! {
_ = notify_clone.notified() => {
tracing::debug!("[{}] Consumer shutting down", consumer_tag);
break;
}
delivery = lapin_consumer.next() => {
match delivery {
Some(Ok(delivery)) => {
let delivery_tag = delivery.delivery_tag;
let payload: serde_json::Value = match serde_json::from_slice(&delivery.data) {
Ok(p) => p,
Err(e) => {
tracing::error!(
"Failed to deserialize message payload: {} (message_id: {:?}, data length: {})",
e,
delivery.properties.message_id(),
delivery.data.len()
);
if let Err(nack_err) = delivery.nack(BasicNackOptions::default()).await {
tracing::error!("Failed to nack malformed message: {:?}", nack_err);
}
continue; }
};
let message_id = delivery.properties.message_id()
.as_ref()
.map(|v| v.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let routing_key = delivery.routing_key.clone();
let metadata = MessageMetadata::new(&queue_name)
.with_routing_key(routing_key);
let message = Message {
id: message_id,
payload,
metadata,
};
let envelope = MessageEnvelope {
delivery_tag,
message,
};
if tx.send(envelope).await.is_err() {
tracing::debug!("[{}] Receiver dropped, stopping consumer", consumer_tag);
break;
}
}
Some(Err(e)) => {
tracing::error!("[{}] Consumer error: {:?}", consumer_tag, e);
}
None => {
tracing::warn!("[{}] Consumer stream ended", consumer_tag);
break;
}
}
}
}
}
});
Ok(Self {
message_rx: Arc::new(Mutex::new(rx)),
pool,
consume_channel: Arc::new(Mutex::new(consume_channel)),
config,
shutdown_notify,
_consumer_handle: consumer_handle,
})
}
pub async fn with_defaults(amqp_url: &str) -> WorkerResult<Self> {
Self::new(amqp_url, RabbitMqConsumerConfig::default()).await
}
pub fn queue_name(&self) -> &str {
&self.config.queue_name
}
pub async fn batch_ack(&self, delivery_tag: u64) -> WorkerResult<()> {
let channel = self.consume_channel.lock().await;
channel
.basic_ack(
delivery_tag,
BasicAckOptions {
multiple: true, },
)
.await
.map_err(|e| {
tracing::error!(
"Failed to batch ack up to delivery tag {}: {}",
delivery_tag,
e
);
WorkerError::BackendError(format!("Failed to batch ack messages: {}", e))
})?;
Ok(())
}
pub async fn adjust_prefetch(&self, prefetch_count: u16) -> WorkerResult<()> {
let channel = self.consume_channel.lock().await;
channel
.basic_qos(prefetch_count, BasicQosOptions { global: false })
.await
.map_err(|e| {
tracing::error!("Failed to adjust prefetch to {}: {}", prefetch_count, e);
WorkerError::BackendError(format!("Failed to adjust prefetch: {}", e))
})?;
tracing::info!("Adjusted prefetch count to {}", prefetch_count);
Ok(())
}
}
#[async_trait]
impl MessageBackend for RabbitMqBackend {
async fn receive(&self) -> WorkerResult<ReceiveResult<serde_json::Value>> {
let mut rx = self.message_rx.lock().await;
match rx.recv().await {
Some(envelope) => {
let ack_handle = Arc::new(RabbitMqAckHandle {
delivery_tag: envelope.delivery_tag,
ack_channel: self.consume_channel.clone(),
});
let message = ReceivedMessage::new(envelope.message, ack_handle);
Ok(ReceiveResult::Message(message))
}
None => {
Ok(ReceiveResult::ConnectionLost {
reason: "Consumer stream ended unexpectedly".to_string(),
})
}
}
}
async fn ack(&self, _message_id: &str) -> WorkerResult<()> {
Err(WorkerError::BackendError(
"Direct ack by ID not supported for RabbitMQ. Use AckHandle from receive()."
.to_string(),
))
}
async fn nack(&self, _message_id: &str, _requeue: bool) -> WorkerResult<()> {
Err(WorkerError::BackendError(
"Direct nack by ID not supported for RabbitMQ. Use AckHandle from receive()."
.to_string(),
))
}
async fn health_check(&self) -> WorkerResult<()> {
let _ = self.pool.get().await.map_err(|e| {
WorkerError::BackendError(format!("RabbitMQ health check failed: {}", e))
})?;
Ok(())
}
async fn shutdown(&self) -> WorkerResult<()> {
self.shutdown_notify.notify_one();
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // These tests need a RabbitMQ broker reachable at this URL, hence
    // `#[ignore]`; run them with `cargo test -- --ignored`.
    const BROKER_URL: &str = "amqp://localhost";

    /// Builds a backend with the default configuration, panicking on failure.
    async fn connect() -> RabbitMqBackend {
        RabbitMqBackend::with_defaults(BROKER_URL).await.unwrap()
    }

    #[tokio::test]
    #[ignore]
    async fn test_connect_and_health() {
        let backend = connect().await;
        let health = backend.health_check().await;
        assert!(health.is_ok());
    }

    /// Assumes the default queue is empty, so `receive` is still pending when
    /// the 100 ms timeout elapses.
    #[tokio::test]
    #[ignore]
    async fn test_receive_timeout() {
        let backend = connect().await;
        let outcome = tokio::time::timeout(Duration::from_millis(100), backend.receive()).await;
        assert!(outcome.is_err());
    }
}