use lapin::{
options::*,
types::FieldTable,
Connection, Channel, Queue, Consumer,
BasicProperties, publisher_confirm::Confirmation,
};
use std::sync::Arc;
use tokio::sync::{RwLock, Mutex};
use tracing::{info, warn, error, debug};
use crate::error::{RabbitMeshError, Result};
use std::time::Duration;
/// Settings that control how `ConnectionManager` connects to the broker.
#[derive(Debug, Clone)]
pub struct ConnectionConfig {
/// AMQP URL handed to the client library when connecting.
pub url: String,
/// Connection timeout, in milliseconds.
pub connection_timeout_ms: u64,
/// Heartbeat interval, in seconds.
/// NOTE(review): no code visible in this file applies this value — confirm where it is consumed.
pub heartbeat_seconds: u16,
/// Number of failed connection attempts after which `connect` gives up and returns the error.
pub max_retries: u32,
/// Pause between failed connection attempts, in milliseconds.
pub retry_delay_ms: u64,
/// Prefetch count passed to `basic_qos` on every newly created channel.
pub prefetch_count: u16,
}
/// Defaults: broker at `amqp://localhost:5672`, 10 s connection timeout,
/// 60 s heartbeat, 5 attempts spaced 1 s apart, prefetch of 10.
impl Default for ConnectionConfig {
    fn default() -> Self {
        let url = String::from("amqp://localhost:5672");
        Self {
            url,
            prefetch_count: 10,
            heartbeat_seconds: 60,
            connection_timeout_ms: 10_000,
            retry_delay_ms: 1_000,
            max_retries: 5,
        }
    }
}
/// Owns one shared AMQP connection plus a pool of idle channels.
pub struct ConnectionManager {
/// Settings used for connecting, retrying and channel QoS.
config: ConnectionConfig,
/// Current connection; `None` until `connect` stores one.
connection: Arc<RwLock<Option<Arc<Connection>>>>,
/// Idle channels: `get_channel` pops from here, `return_channel` pushes back.
channels: Arc<Mutex<Vec<Channel>>>,
}
impl ConnectionManager {
pub fn new(url: impl Into<String>) -> Self {
let mut config = ConnectionConfig::default();
config.url = url.into();
Self::with_config(config)
}
/// Builds a manager from an explicit configuration.
///
/// The manager starts with no connection and an empty channel pool.
pub fn with_config(config: ConnectionConfig) -> Self {
    let connection = Arc::new(RwLock::new(None));
    let channels = Arc::new(Mutex::new(Vec::new()));
    Self {
        config,
        connection,
        channels,
    }
}
pub async fn connect(&self) -> Result<()> {
let mut attempts = 0;
let max_retries = self.config.max_retries;
loop {
match self.try_connect().await {
Ok(connection) => {
info!("Successfully connected to RabbitMQ at {}", self.config.url);
*self.connection.write().await = Some(Arc::new(connection));
return Ok(());
}
Err(e) => {
attempts += 1;
if attempts >= max_retries {
error!("Failed to connect to RabbitMQ after {} attempts: {}", attempts, e);
return Err(e);
}
warn!(
"Connection attempt {} failed, retrying in {}ms: {}",
attempts, self.config.retry_delay_ms, e
);
tokio::time::sleep(Duration::from_millis(self.config.retry_delay_ms)).await;
}
}
}
}
async fn try_connect(&self) -> Result<Connection> {
debug!("Attempting to connect to {}", self.config.url);
let connection = Connection::connect(
&self.config.url,
lapin::ConnectionProperties::default()
.with_connection_name(format!("rabbitmesh-{}", uuid::Uuid::new_v4()).into())
).await?;
debug!("AMQP connection established");
Ok(connection)
}
/// Hands out a usable channel, preferring one from the idle pool.
///
/// Pooled channels that are no longer connected are discarded along the way.
/// If the pool holds no connected channel, a new one is created on the
/// current connection (connecting first if necessary) and configured with
/// `config.prefetch_count` via `basic_qos`.
///
/// # Errors
/// Fails if the connection cannot be established, or if creating or
/// configuring the new channel fails.
pub async fn get_channel(&self) -> Result<Channel> {
    {
        let mut pool = self.channels.lock().await;
        // Keep popping until a live channel turns up: a single stale entry
        // must not hide healthy channels that are still pooled, and stale
        // entries should not linger in the pool.
        while let Some(candidate) = pool.pop() {
            if candidate.status().connected() {
                debug!("Reusing existing channel");
                return Ok(candidate);
            }
            debug!("Discarding stale pooled channel");
        }
    }

    let connection = self.ensure_connected().await?;
    let channel = connection.create_channel().await?;
    channel
        .basic_qos(self.config.prefetch_count, BasicQosOptions::default())
        .await?;
    debug!("Created new channel");
    Ok(channel)
}
/// Gives a channel back to the idle pool for later reuse.
///
/// The channel is simply dropped when it is no longer connected or when the
/// pool already holds its maximum of 10 channels.
pub async fn return_channel(&self, channel: Channel) {
    /// Upper bound on the number of idle channels kept around.
    const MAX_POOLED_CHANNELS: usize = 10;

    if !channel.status().connected() {
        return;
    }

    let mut pool = self.channels.lock().await;
    if pool.len() < MAX_POOLED_CHANNELS {
        pool.push(channel);
        debug!("Returned channel to pool");
    }
}
async fn ensure_connected(&self) -> Result<Arc<Connection>> {
{
let connection_guard = self.connection.read().await;
if let Some(connection) = connection_guard.as_ref() {
if connection.status().connected() {
return Ok(connection.clone());
}
}
}
warn!("Connection lost, attempting to reconnect");
self.connect().await?;
let connection_guard = self.connection.read().await;
connection_guard
.as_ref()
.cloned()
.ok_or_else(|| RabbitMeshError::internal_error("Connection should exist after connect"))
}
/// Declares a durable, non-exclusive, non-auto-delete queue named
/// `queue_name` with no extra arguments.
///
/// The channel used for the declaration goes back to the pool on success;
/// on failure it is dropped.
///
/// # Errors
/// Fails if no channel can be obtained or the declaration itself fails.
pub async fn declare_queue(&self, queue_name: &str) -> Result<Queue> {
    let options = QueueDeclareOptions {
        durable: true,
        exclusive: false,
        auto_delete: false,
        ..Default::default()
    };

    let channel = self.get_channel().await?;
    let declared = channel
        .queue_declare(queue_name, options, FieldTable::default())
        .await?;
    self.return_channel(channel).await;

    debug!("Declared queue: {}", queue_name);
    Ok(declared)
}
/// Starts a non-exclusive consumer with manual acknowledgements
/// (`no_ack: false`) on `queue_name`, identified by `consumer_tag`.
///
/// The channel used for the consumer is not handed back to the pool.
///
/// # Errors
/// Fails if no channel can be obtained or the consume request fails.
pub async fn create_consumer(&self, queue_name: &str, consumer_tag: &str) -> Result<Consumer> {
    let options = BasicConsumeOptions {
        no_ack: false,
        exclusive: false,
        ..Default::default()
    };

    let channel = self.get_channel().await?;
    let consumer = channel
        .basic_consume(queue_name, consumer_tag, options, FieldTable::default())
        .await?;

    debug!("Created consumer for queue: {}", queue_name);
    Ok(consumer)
}
/// Publishes `payload` with `properties` to the exchange `""` using
/// `queue_name` as the routing key, then waits for the resulting
/// `Confirmation`.
///
/// The channel goes back to the pool on success; on failure it is dropped.
///
/// NOTE(review): no `confirm_select` call is visible in this file — verify
/// that publisher confirms are enabled if callers rely on an ack here.
///
/// # Errors
/// Fails if no channel can be obtained, the publish is rejected, or waiting
/// for the confirmation fails.
pub async fn publish(
    &self,
    queue_name: &str,
    payload: &[u8],
    properties: BasicProperties,
) -> Result<Confirmation> {
    let channel = self.get_channel().await?;

    // First await sends the frame, second await resolves the confirmation.
    let pending_confirm = channel
        .basic_publish(
            "",
            queue_name,
            BasicPublishOptions::default(),
            payload,
            properties,
        )
        .await?;
    let confirmation = pending_confirm.await?;

    self.return_channel(channel).await;
    debug!("Published message to queue: {}", queue_name);
    Ok(confirmation)
}
/// Reports whether a connection is stored and currently reports itself as
/// connected. Returns `false` when no connection has been stored.
pub async fn is_connected(&self) -> bool {
    let guard = self.connection.read().await;
    match guard.as_ref() {
        Some(conn) => conn.status().connected(),
        None => false,
    }
}
/// Takes a snapshot of the manager: connection state, number of idle
/// channels in the pool, and the configured URL.
pub async fn get_stats(&self) -> ConnectionStats {
    let is_connected = self.is_connected().await;
    let channel_pool_size = {
        let pool = self.channels.lock().await;
        pool.len()
    };
    let url = self.config.url.clone();

    ConnectionStats {
        is_connected,
        channel_pool_size,
        url,
    }
}
}
/// Snapshot of a `ConnectionManager`'s state, as produced by `get_stats`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ConnectionStats {
/// Whether a connection was stored and reported itself as connected.
pub is_connected: bool,
/// Number of idle channels held in the pool when the snapshot was taken.
pub channel_pool_size: usize,
/// Copy of the configured broker URL.
pub url: String,
}
/// Debug output shows only the configuration; the connection and the channel
/// pool are left out.
impl std::fmt::Debug for ConnectionManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ConnectionManager");
        out.field("config", &self.config);
        out.finish()
    }
}