use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::Arc;
use lapin::options::ConfirmSelectOptions;
use lapin::{Channel, Connection, ConnectionProperties};
use tokio::time::{Duration, sleep, timeout};
use tokio_util::sync::CancellationToken;
use crate::SHUTDOWN_GRACE;
use crate::backends::rabbitmq::map_lapin_error;
use crate::error::{Result, ShoveError};
use crate::metrics;
use crate::retry::Backoff;
/// Connection settings for the RabbitMQ backend.
///
/// `Debug` is implemented by hand (rather than derived) so that the password
/// embedded in `uri` is not printed.
#[derive(Clone)]
pub struct RabbitMqConfig {
/// AMQP URI handed to `lapin` when dialing the broker; may embed credentials.
pub uri: String,
}
impl Debug for RabbitMqConfig {
    /// Formats the config with the password removed from the URI.
    ///
    /// Only the password component is stripped; the rest of the parsed URI
    /// (including any username) is printed. A URI that cannot be parsed is
    /// never echoed back — a fixed placeholder is printed instead.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let redacted_uri = match url::Url::parse(&self.uri) {
            Ok(mut parsed) => {
                // Best effort: a failure to clear the password is ignored.
                let _ = parsed.set_password(None);
                parsed.to_string()
            }
            Err(_) => "<unparseable>".to_string(),
        };

        let mut out = f.debug_struct("RabbitMqConfig");
        out.field("uri", &redacted_uri);
        out.finish()
    }
}
impl RabbitMqConfig {
    /// Creates a config from anything convertible into an owned URI string.
    pub fn new(uri: impl Into<String>) -> Self {
        let uri: String = uri.into();
        Self { uri }
    }

    /// Returns the configured URI exactly as it was supplied.
    pub fn uri(&self) -> &str {
        self.uri.as_str()
    }
}
impl Default for RabbitMqConfig {
    /// Builds a config with the `guest:guest@localhost:5672` AMQP URI.
    fn default() -> Self {
        const LOCAL_BROKER_URI: &str = "amqp://guest:guest@localhost:5672";
        Self::new(LOCAL_BROKER_URI)
    }
}
/// Decides whether `e` should be treated as a connection-level failure.
///
/// Returns `true` for AMQP hard errors and for the invalid-connection-state,
/// I/O and missing-heartbeat error kinds; every other error returns `false`.
fn is_connection_dead(e: &lapin::Error) -> bool {
    if e.is_amqp_hard_error() {
        return true;
    }

    matches!(
        e.kind(),
        lapin::ErrorKind::InvalidConnectionState(_)
            | lapin::ErrorKind::IOError(_)
            | lapin::ErrorKind::MissingHeartbeatError
    )
}
/// Handle to a RabbitMQ connection.
///
/// All state lives behind an `Arc`, so clones share the same connection,
/// reconnect lock and shutdown token.
#[derive(Clone)]
pub struct RabbitMqClient {
/// Shared state; see `ClientInner`.
inner: Arc<ClientInner>,
}
/// State shared by every clone of `RabbitMqClient`.
struct ClientInner {
/// Current connection; `reconnect` stores a freshly dialed one here.
connection: arc_swap::ArcSwap<Connection>,
/// Settings used for the initial dial and for every re-dial.
config: RabbitMqConfig,
/// Held while reconnecting so that only one task dials at a time.
reconnect_lock: tokio::sync::Mutex<()>,
/// Cancelled by `shutdown`; once cancelled, new operations and reconnects are refused.
shutdown_token: CancellationToken,
}
impl RabbitMqClient {
/// Dials the broker once and wraps the resulting connection in a client.
///
/// The config is cloned into the client so it can be reused for re-dials.
///
/// # Errors
///
/// Returns whatever error the dial produced (timeout or connection failure).
pub async fn connect(config: &RabbitMqConfig) -> Result<Self> {
    let inner = ClientInner {
        // The dial runs first; nothing else is built if it fails.
        connection: arc_swap::ArcSwap::from_pointee(Self::dial(config).await?),
        config: config.clone(),
        reconnect_lock: tokio::sync::Mutex::new(()),
        shutdown_token: CancellationToken::new(),
    };

    Ok(Self {
        inner: Arc::new(inner),
    })
}
pub async fn connect_with_retry(config: &RabbitMqConfig, max_attempts: u32) -> Result<Self> {
let mut backoff = Backoff::default();
let mut last_err = None;
for attempt in 0..max_attempts {
match Self::connect(config).await {
Ok(client) => return Ok(client),
Err(e) => {
if attempt + 1 < max_attempts {
let delay = backoff.next().expect("backoff is infinite");
tracing::warn!(
attempt = attempt + 1,
max_attempts,
error = %e,
"RabbitMQ connection failed, retrying in {delay:?}"
);
tokio::time::sleep(delay).await;
}
last_err = Some(e);
}
}
}
Err(last_err.expect("loop ran at least once"))
}
/// Opens a new connection to the URI in `config`, bounded by a 5 second timeout.
///
/// The connection is named `shove-rs-<pid>` using the current process id.
///
/// # Errors
///
/// Returns `ShoveError::Connection` when the timeout elapses, or the mapped
/// `lapin` error when the connection attempt itself fails.
async fn dial(config: &RabbitMqConfig) -> Result<Connection> {
    let connection_name = format!("shove-rs-{}", std::process::id());
    let properties =
        ConnectionProperties::default().with_connection_name(connection_name.into());

    let attempt = Connection::connect(&config.uri, properties);
    match timeout(Duration::from_secs(5), attempt).await {
        Err(_elapsed) => Err(ShoveError::Connection(
            "timed out connecting to RabbitMQ".into(),
        )),
        Ok(Err(e)) => Err(map_lapin_error("failed to connect to RabbitMQ", e)),
        Ok(Ok(connection)) => Ok(connection),
    }
}
/// Returns an owned handle to the connection that is current right now.
///
/// A later reconnect may replace the stored connection; the returned `Arc`
/// keeps pointing at the one that was loaded.
fn snapshot(&self) -> Arc<Connection> {
    let slot = &self.inner.connection;
    slot.load_full()
}
/// Runs `op` against the current connection, reconnecting and retrying once
/// if the failure looks like a dead connection.
///
/// `op` may be invoked twice (once per connection), which is why it is `Fn`.
/// `op_name` is only used in log and error messages.
///
/// # Errors
///
/// - `ShoveError::Connection` (and a recorded backend error metric) when the
///   client is already shutting down.
/// - The mapped `lapin` error when `op` fails for a reason that is not a dead
///   connection, or when it fails again after a reconnect.
/// - Any error returned by the reconnect itself.
async fn with_reconnect<F, Fut, T>(&self, op_name: &'static str, op: F) -> Result<T>
where
    F: Fn(Arc<Connection>) -> Fut,
    Fut: Future<Output = std::result::Result<T, lapin::Error>>,
{
    if self.inner.shutdown_token.is_cancelled() {
        metrics::record_backend_error(
            metrics::BackendLabel::RabbitMq,
            metrics::BackendErrorKind::Connection,
        );
        return Err(ShoveError::Connection(format!(
            "cannot {op_name}: client is shutting down"
        )));
    }

    // Remember which connection the first attempt used so the reconnect can
    // tell whether it has already been replaced.
    let observed = self.snapshot();
    let first_failure = match op(Arc::clone(&observed)).await {
        Ok(value) => return Ok(value),
        Err(err) => err,
    };

    if !is_connection_dead(&first_failure) {
        return Err(map_lapin_error(
            &format!("{op_name} failed"),
            first_failure,
        ));
    }

    tracing::warn!(
        error = %first_failure,
        op = op_name,
        "RabbitMQ connection appears dead, reconnecting"
    );
    self.reconnect(&observed).await?;

    match op(self.snapshot()).await {
        Ok(value) => Ok(value),
        Err(err) => Err(map_lapin_error(
            &format!("{op_name} failed after reconnect"),
            err,
        )),
    }
}
/// Replaces the stored connection with a freshly dialed one, unless another
/// task has already done so.
///
/// `observed` is the connection the caller saw fail. Reconnects are
/// serialized by `reconnect_lock`; if the stored connection is no longer
/// `observed` once the lock is held, somebody else already reconnected and
/// this call returns `Ok(())` without dialing.
///
/// # Errors
///
/// Returns `ShoveError::Connection` when the client is shutting down, or the
/// dial error when the new connection cannot be established.
async fn reconnect(&self, observed: &Arc<Connection>) -> Result<()> {
    let shutting_down =
        || ShoveError::Connection("cannot reconnect: client is shutting down".into());

    // Cheap early exit so tasks do not queue on the lock during shutdown.
    if self.inner.shutdown_token.is_cancelled() {
        return Err(shutting_down());
    }

    let _guard = self.inner.reconnect_lock.lock().await;

    let current = self.inner.connection.load_full();
    if !Arc::ptr_eq(&current, observed) {
        // Another task swapped in a new connection while we waited.
        return Ok(());
    }

    // Shutdown may have started while we were waiting on the lock; do not
    // dial and install a new connection in that case.
    if self.inner.shutdown_token.is_cancelled() {
        return Err(shutting_down());
    }

    let new_conn = Self::dial(&self.inner.config).await?;
    self.inner.connection.store(Arc::new(new_conn));
    tracing::info!("RabbitMQ connection re-established");
    Ok(())
}
pub async fn create_channel(&self) -> Result<Channel> {
self.with_reconnect("create channel", |conn| async move {
conn.create_channel().await
})
.await
}
pub async fn create_confirm_channel(&self) -> Result<Channel> {
self.with_reconnect("create confirm channel", |conn| async move {
let channel = conn.create_channel().await?;
channel
.confirm_select(ConfirmSelectOptions::default())
.await?;
Ok(channel)
})
.await
}
/// Opens a new channel and switches it to transactional mode via `tx_select`,
/// reconnecting once if the connection is dead.
///
/// Only compiled with the `rabbitmq-transactional` feature.
///
/// # Errors
///
/// Returns the error produced by the reconnect-aware wrapper, whether the
/// channel creation or the `tx_select` call failed.
#[cfg(feature = "rabbitmq-transactional")]
pub async fn create_tx_channel(&self) -> Result<Channel> {
    let open = |conn: Arc<Connection>| async move {
        let channel = conn.create_channel().await?;
        let selected = channel.tx_select().await;
        selected.map(|_| channel)
    };
    self.with_reconnect("create tx channel", open).await
}
/// Returns a clone of the client's shutdown token.
///
/// The token is cancelled when `shutdown` is called on this client.
pub fn shutdown_token(&self) -> CancellationToken {
    let token = &self.inner.shutdown_token;
    token.clone()
}
/// Reports whether the current connection's status is "connected".
pub fn is_connected(&self) -> bool {
    let connection = self.snapshot();
    connection.status().connected()
}
/// Shuts the client down: cancels the shutdown token, waits for
/// `SHUTDOWN_GRACE`, then closes the current connection.
///
/// A failure to close is logged as a warning and otherwise ignored.
pub async fn shutdown(&self) {
    self.inner.shutdown_token.cancel();
    sleep(SHUTDOWN_GRACE).await;

    // Load the connection only after the grace period so that a connection
    // installed by a reconnect in the meantime is the one that gets closed.
    let connection = self.snapshot();
    let closed = connection.close(0, "shutdown".into()).await;
    if let Err(e) = closed {
        tracing::warn!("error while closing RabbitMQ connection: {e}");
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use lapin::ErrorKind;
    use lapin::protocol::{AMQPError, AMQPErrorKind, AMQPHardError, AMQPSoftError};

    /// Wraps `kind` in a `lapin::Error` and runs it through the classifier.
    fn classified_as_dead(kind: ErrorKind) -> bool {
        let err: lapin::Error = kind.into();
        is_connection_dead(&err)
    }

    // --- RabbitMqConfig -----------------------------------------------------

    #[test]
    fn config_debug_redacts_password_only() {
        let rendered = format!(
            "{:?}",
            RabbitMqConfig::new("amqp://admin:s3cret!@localhost:5672/%2F")
        );
        // The password must be gone while the username stays visible.
        assert!(!rendered.contains("s3cret!"));
        assert!(rendered.contains("admin@localhost"));
    }

    #[test]
    fn config_debug_no_creds_remains_clear() {
        let rendered = format!("{:?}", RabbitMqConfig::new("amqp://localhost:5672/%2F"));
        assert!(rendered.contains("amqp://localhost:5672/%2F"));
    }

    #[test]
    fn config_new_stores_uri() {
        let cfg = RabbitMqConfig::new("amqp://host:1234/%2F");
        assert_eq!(cfg.uri, "amqp://host:1234/%2F");
    }

    #[test]
    fn default_config_is_localhost() {
        let default_uri = RabbitMqConfig::default().uri().to_string();
        assert!(default_uri.contains("localhost:5672"));
    }

    // --- is_connection_dead: connection-level failures ----------------------

    #[test]
    fn invalid_connection_state_is_dead() {
        let kind = ErrorKind::InvalidConnectionState(lapin::ConnectionState::Closed);
        assert!(classified_as_dead(kind));
    }

    #[test]
    fn missing_heartbeat_is_dead() {
        assert!(classified_as_dead(ErrorKind::MissingHeartbeatError));
    }

    #[test]
    fn invalid_channel_state_is_not_connection_dead() {
        let kind = ErrorKind::InvalidChannelState(lapin::ChannelState::Closed, "test");
        assert!(!classified_as_dead(kind));
    }

    #[test]
    fn channels_limit_is_not_connection_dead() {
        assert!(!classified_as_dead(ErrorKind::ChannelsLimitReached));
    }

    #[test]
    fn io_error_is_dead() {
        let aborted = std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "broken pipe");
        assert!(classified_as_dead(ErrorKind::IOError(Arc::new(aborted))));
    }

    // --- is_connection_dead: AMQP protocol errors ---------------------------

    #[test]
    fn amqp_hard_error_is_dead() {
        let hard = AMQPError::new(
            AMQPErrorKind::Hard(AMQPHardError::CONNECTIONFORCED),
            "broker closed".into(),
        );
        assert!(classified_as_dead(ErrorKind::ProtocolError(hard)));
    }

    #[test]
    fn amqp_soft_error_is_not_connection_dead() {
        let soft = AMQPError::new(
            AMQPErrorKind::Soft(AMQPSoftError::ACCESSREFUSED),
            "denied".into(),
        );
        assert!(!classified_as_dead(ErrorKind::ProtocolError(soft)));
    }
}