use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use sha2::{Digest, Sha256};
use tokio::sync::{broadcast, Mutex, Semaphore};
use tracing::{debug, info, warn};
use crate::amqp::{amqp_url_with_vhost, AmqpClient};
use crate::backup::queue_reader::MessageAssembler;
use crate::backup::stream_reader::{StreamCheckpoint, StreamReader};
use crate::compression;
use crate::config::{CompressionType, Config, OffsetStorageBackend, QueueType};
use crate::definitions::types::{QueueInfo, RabbitMqDefinitions};
use crate::definitions::{DefinitionsExporter, ManagementClient};
use crate::error::{Error, Result};
use crate::manifest::{BackupManifest, DefinitionsBackup, SegmentMetadata};
use crate::offset_store::{QueueProgressUpdate, SqliteOffsetStore};
use crate::segment::{SegmentReader, SegmentWriter};
use crate::storage::{create_backend_from_config, StorageBackend};
use crate::stream::StreamClient;
/// Shared handle to the SQLite checkpoint store used for job status and
/// per-queue progress.
///
/// Cloning is cheap: the store sits behind an `Arc`, and the optional key is
/// the only owned data.
#[derive(Clone)]
struct CheckpointHandle {
    /// Local SQLite store read and written by the backup tasks.
    store: Arc<SqliteOffsetStore>,
    /// Storage key the database is pushed to by `sync`, when configured.
    remote_key: Option<String>,
}

impl CheckpointHandle {
    /// Makes the current checkpoint state durable.
    ///
    /// Without a `remote_key` the store is only checkpointed locally; with one,
    /// the database is synced to `storage` under that key.
    async fn sync(&self, storage: &Arc<dyn StorageBackend>) -> Result<()> {
        match &self.remote_key {
            None => {
                self.store.checkpoint().await?;
            }
            Some(key) => {
                self.store.sync_to_storage(storage, key).await?;
            }
        }
        Ok(())
    }
}
/// Drives one complete backup job for the configured source cluster.
///
/// The engine only holds configuration; storage backends, management/AMQP/stream
/// connections and checkpoint state are all created inside [`BackupEngine::run`].
pub struct BackupEngine {
    config: Config,
}
impl BackupEngine {
    /// Creates an engine for `config`. Performs no I/O.
    pub fn new(config: Config) -> Self {
        Self { config }
    }
    /// Executes the backup job and returns the finalized manifest.
    ///
    /// Order of work:
    /// 1. Build the storage backend and, if `offset_storage` is configured, the
    ///    checkpoint store; the job status is recorded as `"running"`.
    /// 2. Fetch the cluster overview and, when `include_definitions` is set,
    ///    export definitions. Both are best-effort: failures are logged and the
    ///    run continues.
    /// 3. Discover queues and partition them into stream queues and all others
    ///    (backed up over AMQP). Stream queues are only processed when
    ///    `stream_enabled` is set.
    /// 4. Spawn one task per queue. A shared semaphore sized by
    ///    `max_concurrent_queues` (minimum 1) limits concurrency for AMQP and
    ///    stream tasks together.
    /// 5. Wait for every task, finalize the manifest, upload it to
    ///    `{backup_id}/manifest.json`, and record the job status as
    ///    `"completed"` or `"failed"`.
    ///
    /// # Errors
    /// Returns an error if storage or checkpoint initialisation, the source
    /// configuration lookup, queue discovery, the manifest upload, or a
    /// checkpoint status update fails. If any queue task fails or panics, the
    /// manifest is still uploaded (for inspection), the job is marked
    /// `"failed"`, and an `Error::Manifest` with the failure count is returned.
    pub async fn run(&self, shutdown_rx: broadcast::Receiver<()>) -> Result<BackupManifest> {
        let storage = create_backend_from_config(&self.config.storage)?;
        info!("Storage backend initialized");
        // `None` unless `offset_storage` is configured; every checkpoint
        // update below is skipped in that case.
        let checkpoint = init_checkpoint(&self.config, &storage).await?;
        if let Some(checkpoint) = &checkpoint {
            checkpoint
                .store
                .set_job_status(&self.config.backup_id, "running")
                .await?;
            checkpoint.sync(&storage).await?;
        }
        let mut manifest = BackupManifest::new(
            self.config.backup_id.clone(),
            env!("CARGO_PKG_VERSION").to_string(),
        );
        // NOTE(review): this check runs after the job was marked "running", so a
        // missing source leaves that status behind in the checkpoint store.
        let source =
            self.config.source.as_ref().ok_or_else(|| {
                Error::Config("Source configuration required for backup".to_string())
            })?;
        let mgmt = ManagementClient::from_config(source)?;
        // Cluster name/version are informational; failure does not abort the run.
        match mgmt.get_overview().await {
            Ok(overview) => {
                manifest.set_source_info(overview.cluster_name, overview.rabbitmq_version);
                info!("Connected to RabbitMQ cluster");
            }
            Err(e) => {
                warn!("Failed to fetch cluster overview (continuing): {}", e);
            }
        }
        let backup_opts = self.config.backup.clone().unwrap_or_default();
        // Best-effort: on failure the manifest simply carries no definitions entry.
        if backup_opts.include_definitions {
            info!("Exporting definitions...");
            match export_definitions(source, &storage, &self.config.backup_id, &backup_opts).await {
                Ok(defs_backup) => {
                    manifest.set_definitions(defs_backup);
                    info!("Definitions exported successfully");
                }
                Err(e) => {
                    warn!("Failed to export definitions (continuing): {}", e);
                }
            }
        }
        let selection = source.queues.clone().unwrap_or_default();
        let queues = mgmt.discover_queues(&selection).await?;
        // Streams are read through the stream client; every other queue type over AMQP.
        let (stream_queues, amqp_queues): (Vec<_>, Vec<_>) =
            queues.into_iter().partition(|q| q.queue_type == "stream");
        if !stream_queues.is_empty() && backup_opts.stream_enabled {
            info!("{} stream queue(s) found", stream_queues.len());
        }
        // Nothing to do (stream queues are ignored when streams are disabled):
        // still publish an empty manifest and mark the job completed.
        if amqp_queues.is_empty() && (stream_queues.is_empty() || !backup_opts.stream_enabled) {
            info!("No queues to back up");
            manifest.finalize();
            let manifest_json = manifest.to_json()?;
            storage
                .put(
                    &format!("{}/manifest.json", self.config.backup_id),
                    manifest_json,
                )
                .await?;
            if let Some(checkpoint) = &checkpoint {
                checkpoint
                    .store
                    .set_job_status(&self.config.backup_id, "completed")
                    .await?;
                checkpoint.sync(&storage).await?;
            }
            return Ok(manifest);
        }
        // NOTE(review): the count logged here covers AMQP queues only; stream
        // queues scheduled further down are not included.
        info!(
            "Backing up {} queue(s) (max {} concurrent)",
            amqp_queues.len(),
            backup_opts.max_concurrent_queues
        );
        // All tasks append their segments to this single shared manifest.
        let manifest = Arc::new(Mutex::new(manifest));
        // `.max(1)` keeps a configured 0 from leaving every task waiting forever.
        let semaphore = Arc::new(Semaphore::new(backup_opts.max_concurrent_queues.max(1)));
        let mut handles = Vec::new();
        for queue in amqp_queues {
            let sem = semaphore.clone();
            let storage = storage.clone();
            let manifest = manifest.clone();
            let config = self.config.clone();
            // Per tokio's `Receiver::resubscribe`, the new receiver only observes
            // shutdown signals sent after this point.
            let shutdown_rx = shutdown_rx.resubscribe();
            let checkpoint = checkpoint.clone();
            handles.push(tokio::spawn(async move {
                // The permit (kept as a `Result`) is held for the whole queue
                // backup; the semaphore is never closed in this function.
                let _permit = sem.acquire().await;
                backup_single_queue(queue, config, storage, manifest, checkpoint, shutdown_rx).await
            }));
        }
        if backup_opts.stream_enabled && !stream_queues.is_empty() {
            let source = source.clone();
            let stream_port = source.stream_port;
            // The stream endpoint is assumed to live on the AMQP host;
            // an unparsable URL falls back to "localhost".
            let stream_host = source
                .amqp_url
                .parse::<amq_protocol::uri::AMQPUri>()
                .map(|u| u.authority.host)
                .unwrap_or_else(|_| "localhost".to_string());
            for queue in stream_queues {
                let sem = semaphore.clone();
                let storage = storage.clone();
                let manifest = manifest.clone();
                let backup_opts = backup_opts.clone();
                let backup_id = self.config.backup_id.clone();
                let host = stream_host.clone();
                // The management credentials are reused for the stream connection.
                let username = source.management_username.clone();
                let password = source.management_password.clone();
                let checkpoint = checkpoint.clone();
                let shutdown_rx = shutdown_rx.resubscribe();
                handles.push(tokio::spawn(async move {
                    let _permit = sem.acquire().await;
                    // StreamReader takes its own checkpoint type built from the same store/key.
                    let checkpoint_for_stream = checkpoint.as_ref().map(|checkpoint| {
                        StreamCheckpoint {
                            store: checkpoint.store.clone(),
                            remote_key: checkpoint.remote_key.clone(),
                        }
                    });
                    let mut start_offset = rabbitmq_stream_client::types::OffsetSpecification::First;
                    if let Some(checkpoint) = &checkpoint {
                        if let Some(progress) = checkpoint
                            .store
                            .get_progress(&backup_id, &queue.vhost, &queue.name)
                            .await?
                        {
                            // Same depth as when the stream completed: try to re-register
                            // the existing segments instead of reading again.
                            if progress.completed && progress.target_message_count == queue.messages {
                                if reuse_completed_queue_segments(
                                    &backup_id,
                                    &queue,
                                    QueueType::Stream,
                                    backup_opts.compression,
                                    progress.last_segment_sequence,
                                    &storage,
                                    &manifest,
                                )
                                .await?
                                {
                                    info!(
                                        "Skipping already-completed stream {}/{} from checkpoint",
                                        queue.vhost, queue.name
                                    );
                                    return Ok(());
                                }
                                warn!(
                                    "Checkpoint for stream {}/{} was complete, but segment reuse failed; restarting from first offset",
                                    queue.vhost, queue.name
                                );
                            } else if progress.last_segment_sequence > 0
                                || progress.messages_backed_up > 0
                            {
                                // NOTE(review): resumes at `messages_backed_up + 1`, which assumes
                                // that value maps onto stream offsets. Confirm against what
                                // StreamReader stores; RabbitMQ stream offsets start at 0.
                                start_offset =
                                    rabbitmq_stream_client::types::OffsetSpecification::Offset(
                                        progress.messages_backed_up.saturating_add(1),
                                    );
                                warn!(
                                    "Resuming stream {}/{} from offset {}",
                                    queue.vhost,
                                    queue.name,
                                    progress.messages_backed_up.saturating_add(1)
                                );
                            }
                        }
                    }
                    let client = StreamClient::connect_with_vhost(
                        &host,
                        stream_port,
                        &username,
                        &password,
                        &queue.vhost,
                    )
                    .await?;
                    StreamReader::backup_stream(
                        &client,
                        &queue.name,
                        &queue.vhost,
                        start_offset,
                        &backup_id,
                        &backup_opts,
                        &storage,
                        &manifest,
                        checkpoint_for_stream,
                        queue.messages,
                        shutdown_rx,
                    )
                    .await?;
                    Ok(())
                }));
            }
        }
        // Wait for every task; a panicked task surfaces as a join error and is
        // counted the same way as a task that returned an error.
        let results = futures::future::join_all(handles).await;
        let mut error_count = 0;
        for result in results {
            match result {
                Ok(Ok(())) => {}
                Ok(Err(e)) => {
                    warn!("Queue backup error: {}", e);
                    error_count += 1;
                }
                Err(e) => {
                    warn!("Task join error: {}", e);
                    error_count += 1;
                }
            }
        }
        // Every task has finished, so this lock is uncontended.
        let mut manifest = manifest.lock().await;
        manifest.finalize();
        let manifest_json = manifest.to_json()?;
        // Written even when some queues failed, so the partial result can be inspected.
        storage
            .put(
                &format!("{}/manifest.json", self.config.backup_id),
                manifest_json,
            )
            .await?;
        if error_count > 0 {
            warn!("{} queue(s) had errors during backup", error_count);
            if let Some(checkpoint) = &checkpoint {
                checkpoint
                    .store
                    .set_job_status(&self.config.backup_id, "failed")
                    .await?;
                checkpoint.sync(&storage).await?;
            }
            return Err(Error::Manifest(format!(
                "{} queue(s) failed during backup; manifest was written for inspection",
                error_count
            )));
        }
        if let Some(checkpoint) = &checkpoint {
            checkpoint
                .store
                .set_job_status(&self.config.backup_id, "completed")
                .await?;
            checkpoint.sync(&storage).await?;
        }
        info!(
            "Backup complete: {} queues, {} messages, {} segments, {} bytes",
            manifest.queues.len(),
            manifest.total_messages,
            manifest.total_segments,
            manifest.total_bytes
        );
        Ok(manifest.clone())
    }
}
/// Opens the checkpoint store described by `config.offset_storage`.
///
/// Returns `Ok(None)` when no offset storage is configured. Only the SQLite
/// backend is supported; the in-memory backend is rejected with a config error.
/// When an `s3_key` is configured and no local database file exists yet,
/// `SqliteOffsetStore::try_load_from_storage` is asked to fetch it from `storage`
/// before the store is opened.
///
/// # Errors
/// Fails on an unsupported backend, if the local DB path cannot be checked,
/// or if loading/opening the SQLite store fails.
async fn init_checkpoint(
    config: &Config,
    storage: &Arc<dyn StorageBackend>,
) -> Result<Option<CheckpointHandle>> {
    let offset_config = match &config.offset_storage {
        Some(offset_config) => offset_config,
        None => return Ok(None),
    };

    match offset_config.backend {
        OffsetStorageBackend::Memory => {
            return Err(Error::Config(
                "memory offset storage is not wired into backup checkpointing".to_string(),
            ));
        }
        OffsetStorageBackend::Sqlite => {}
    }

    if let Some(remote_key) = offset_config.s3_key.as_ref() {
        let has_local_db = tokio::fs::try_exists(&offset_config.db_path)
            .await
            .map_err(|e| Error::Checkpoint(format!("Failed to check offset DB path: {}", e)))?;
        // A local file always wins; the remote copy only seeds a fresh machine.
        if !has_local_db {
            SqliteOffsetStore::try_load_from_storage(storage, remote_key, &offset_config.db_path)
                .await?;
        }
    }

    let store = SqliteOffsetStore::new(&offset_config.db_path).await?;
    Ok(Some(CheckpointHandle {
        store: Arc::new(store),
        remote_key: offset_config.s3_key.clone(),
    }))
}
/// Exports the cluster definitions through the management API and uploads them
/// to `{backup_id}/definitions/definitions.json` plus the compression extension.
///
/// The export is parsed *before* anything is uploaded. An export that is not
/// valid definitions JSON therefore fails without leaving an orphaned object in
/// storage.
///
/// # Returns
/// A [`DefinitionsBackup`] holding the storage key, the number of vhosts, queues,
/// exchanges and users in the export, and `size_bytes`. `size_bytes` is the size
/// of the *uncompressed* JSON.
///
/// # Errors
/// Fails if the management client cannot be built, the export request fails,
/// the export cannot be parsed, compression fails, or the upload fails.
async fn export_definitions(
    source: &crate::config::SourceConfig,
    storage: &Arc<dyn StorageBackend>,
    backup_id: &str,
    backup_opts: &crate::config::BackupOptions,
) -> Result<DefinitionsBackup> {
    let mgmt = ManagementClient::from_config(source)?;
    let exporter = DefinitionsExporter::new(mgmt);
    let defs_json = exporter.export_json(None).await?;
    // Validate first so a malformed export is never persisted.
    let defs: RabbitMqDefinitions = serde_json::from_slice(&defs_json)?;
    let compressed = compression::compress(&defs_json, backup_opts.compression)?;
    let ext = compression::extension(backup_opts.compression);
    let defs_key = format!("{}/definitions/definitions.json{}", backup_id, ext);
    storage.put(&defs_key, Bytes::from(compressed)).await?;
    Ok(DefinitionsBackup {
        key: defs_key,
        vhost_count: defs.vhosts.len(),
        queue_count: defs.queues.len(),
        exchange_count: defs.exchanges.len(),
        user_count: defs.users.len(),
        size_bytes: defs_json.len() as u64,
    })
}
async fn backup_single_queue(
mut queue: QueueInfo,
config: Config,
storage: Arc<dyn StorageBackend>,
manifest: Arc<Mutex<BackupManifest>>,
checkpoint: Option<CheckpointHandle>,
mut shutdown_rx: broadcast::Receiver<()>,
) -> Result<()> {
let source = config
.source
.as_ref()
.ok_or_else(|| Error::Config("Source configuration required".to_string()))?;
let backup_opts = config.backup.clone().unwrap_or_default();
queue = ManagementClient::from_config(source)?
.get_queue(&queue.vhost, &queue.name)
.await?;
let queue_type = match queue.queue_type.as_str() {
"quorum" => QueueType::Quorum,
"stream" => QueueType::Stream,
_ => QueueType::Classic,
};
info!(
"Backing up queue {} (vhost={}, type={}, messages={})",
queue.name, queue.vhost, queue.queue_type, queue.messages
);
let initial_depth = queue.messages;
if let Some(checkpoint) = &checkpoint {
if let Some(progress) = checkpoint
.store
.get_progress(&config.backup_id, &queue.vhost, &queue.name)
.await?
{
if progress.completed && progress.target_message_count == initial_depth {
if reuse_completed_queue_segments(
&config.backup_id,
&queue,
queue_type,
backup_opts.compression,
progress.last_segment_sequence,
&storage,
&manifest,
)
.await?
{
info!(
"Queue {} reused from checkpoint: {} messages in {} segment(s)",
queue.name, progress.messages_backed_up, progress.last_segment_sequence
);
return Ok(());
}
warn!(
"Queue {} checkpoint was complete but segment metadata could not be rebuilt; restarting queue backup",
queue.name
);
} else if progress.messages_backed_up > 0 || progress.last_segment_sequence > 0 {
warn!(
"Queue {} has partial checkpoint progress ({} messages, {} segments); restarting from queue head because AMQP delivery tags are not durable offsets",
queue.name, progress.messages_backed_up, progress.last_segment_sequence
);
}
}
checkpoint
.store
.set_progress_state(QueueProgressUpdate {
backup_id: &config.backup_id,
vhost: &queue.vhost,
queue_name: &queue.name,
messages_backed_up: 0,
last_segment_sequence: 0,
target_message_count: initial_depth,
completed: false,
})
.await?;
checkpoint.sync(&storage).await?;
}
if queue.messages == 0 {
debug!("Queue {} is empty, skipping", queue.name);
if let Some(checkpoint) = &checkpoint {
checkpoint
.store
.set_progress_state(QueueProgressUpdate {
backup_id: &config.backup_id,
vhost: &queue.vhost,
queue_name: &queue.name,
messages_backed_up: 0,
last_segment_sequence: 0,
target_message_count: initial_depth,
completed: true,
})
.await?;
checkpoint.sync(&storage).await?;
}
return Ok(());
}
let source_amqp_url = amqp_url_with_vhost(&source.amqp_url, &queue.vhost)?;
let mut client = AmqpClient::connect(&source_amqp_url, source.tls.as_ref()).await?;
let channel_id = client.open_channel().await?;
client.basic_qos(channel_id, 0).await?;
let consumer_tag = format!("rmq-backup-{}", uuid::Uuid::new_v4());
client
.basic_consume(channel_id, &queue.name, &consumer_tag)
.await?;
let mut segment_writer = SegmentWriter::new(1);
let mut assembler = MessageAssembler::new();
let mut received_count = 0u64;
let mut last_segment_sequence = 0u64;
let read_timeout = Duration::from_secs(10);
let mut shutdown_closed = false;
loop {
tokio::select! {
shutdown = shutdown_rx.recv(), if !shutdown_closed => {
match shutdown {
Ok(()) => {
info!("Shutdown signal received for queue {}", queue.name);
break;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
shutdown_closed = true;
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
info!("Shutdown signal received for queue {}", queue.name);
break;
}
}
}
frame_result = client.read_frame_timeout(read_timeout) => {
match frame_result? {
Some(frame) => {
if let Some(record) = assembler.process_frame(
frame, &queue.name, &queue.vhost
)? {
segment_writer.add_record(&record)?;
received_count += 1;
if received_count.is_multiple_of(1000) {
debug!(
"Queue {}: {} messages received",
queue.name, received_count
);
}
if segment_writer.should_rotate(
backup_opts.segment_max_bytes,
backup_opts.segment_max_interval_ms,
) {
let key = segment_key(
&config.backup_id,
&queue.vhost,
&queue.name,
segment_writer.sequence(),
backup_opts.compression,
);
let finalized = segment_writer.finalize(
backup_opts.compression,
backup_opts.compression_level,
key,
)?;
storage
.put(&finalized.metadata.key, finalized.data)
.await?;
last_segment_sequence = finalized.metadata.sequence;
manifest.lock().await.add_segment(
&queue.vhost,
&queue.name,
queue_type,
finalized.metadata,
);
if let Some(checkpoint) = &checkpoint {
checkpoint
.store
.set_progress_state(QueueProgressUpdate {
backup_id: &config.backup_id,
vhost: &queue.vhost,
queue_name: &queue.name,
messages_backed_up: received_count,
last_segment_sequence,
target_message_count: initial_depth,
completed: false,
})
.await?;
checkpoint.sync(&storage).await?;
}
}
if backup_opts.stop_at_current_depth
&& received_count >= initial_depth
{
debug!(
"Queue {}: reached target depth {}",
queue.name, initial_depth
);
break;
}
}
}
None => {
if received_count > 0 {
debug!(
"Queue {}: no more messages after {}",
queue.name, received_count
);
break;
}
}
}
}
}
}
if let Err(e) = client.basic_cancel(channel_id, &consumer_tag).await {
warn!(
"Failed to cancel consumer on {} (messages still requeued on disconnect): {}",
queue.name, e
);
}
if segment_writer.has_records() {
let key = segment_key(
&config.backup_id,
&queue.vhost,
&queue.name,
segment_writer.sequence(),
backup_opts.compression,
);
let finalized =
segment_writer.finalize(backup_opts.compression, backup_opts.compression_level, key)?;
storage.put(&finalized.metadata.key, finalized.data).await?;
last_segment_sequence = finalized.metadata.sequence;
manifest.lock().await.add_segment(
&queue.vhost,
&queue.name,
queue_type,
finalized.metadata,
);
}
if let Some(checkpoint) = &checkpoint {
checkpoint
.store
.set_progress_state(QueueProgressUpdate {
backup_id: &config.backup_id,
vhost: &queue.vhost,
queue_name: &queue.name,
messages_backed_up: received_count,
last_segment_sequence,
target_message_count: initial_depth,
completed: true,
})
.await?;
checkpoint.sync(&storage).await?;
}
client.close_channel(channel_id).await.ok();
client.close().await.ok();
info!(
"Queue {} backed up: {} messages in {} segment(s)",
queue.name,
received_count,
segment_writer.sequence() - 1
);
Ok(())
}
async fn reuse_completed_queue_segments(
backup_id: &str,
queue: &QueueInfo,
queue_type: QueueType,
compression: CompressionType,
last_segment_sequence: u64,
storage: &Arc<dyn StorageBackend>,
manifest: &Arc<Mutex<BackupManifest>>,
) -> Result<bool> {
for seq in 1..=last_segment_sequence {
let key = segment_key(backup_id, &queue.vhost, &queue.name, seq, compression);
if !storage.exists(&key).await? {
return Ok(false);
}
let data = storage.get(&key).await?;
SegmentReader::verify_integrity(&data)?;
let header = SegmentReader::read_header(&data)?;
let records = SegmentReader::read_records(&data)?;
let mut hasher = Sha256::new();
hasher.update(&data);
let checksum = format!("{:x}", hasher.finalize());
let uncompressed_bytes = records
.iter()
.map(|record| serde_json::to_vec(record).map(|json| json.len() as u64 + 4))
.collect::<std::result::Result<Vec<_>, _>>()?
.into_iter()
.sum();
let metadata = SegmentMetadata {
key,
sequence: seq,
record_count: header.record_count,
size_bytes: data.len() as u64,
uncompressed_bytes,
first_timestamp: (header.first_timestamp != 0).then_some(header.first_timestamp),
last_timestamp: (header.last_timestamp != 0).then_some(header.last_timestamp),
checksum,
};
manifest
.lock()
.await
.add_segment(&queue.vhost, &queue.name, queue_type, metadata);
}
Ok(true)
}
/// Builds the storage key of one queue segment:
/// `{backup_id}/queues/{vhost}/{queue}/segment-{seq}{ext}`.
///
/// The default vhost `/` is written as `_default` so it does not produce an
/// empty path component. `seq` is zero-padded to at least four digits, and
/// `ext` comes from `compression::extension`; the tests show it is empty for
/// `CompressionType::None`.
///
/// NOTE(review): all other vhost/queue names are used verbatim. A vhost
/// literally named `_default` shares keys with `/`, and names containing `/`
/// add path levels. Changing this would break the keys of existing backups,
/// so it is only flagged here.
fn segment_key(
    backup_id: &str,
    vhost: &str,
    queue: &str,
    seq: u64,
    compression: CompressionType,
) -> String {
    let extension = compression::extension(compression);
    let vhost_dir = match vhost {
        "/" => "_default",
        named => named,
    };
    format!(
        "{}/queues/{}/{}/segment-{:04}{}",
        backup_id, vhost_dir, queue, seq, extension
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default vhost `/` is stored under `_default`; zstd adds `.zst`.
    #[test]
    fn test_segment_key_default_vhost() {
        let expected = "backup-001/queues/_default/orders/segment-0001.zst";
        assert_eq!(
            segment_key("backup-001", "/", "orders", 1, CompressionType::Zstd),
            expected
        );
    }

    /// Named vhosts are used as-is; the sequence is zero-padded to four digits.
    #[test]
    fn test_segment_key_custom_vhost() {
        let expected = "backup-001/queues/production/payments/segment-0042.lz4";
        assert_eq!(
            segment_key("backup-001", "production", "payments", 42, CompressionType::Lz4),
            expected
        );
    }

    /// Uncompressed segments carry no file extension.
    #[test]
    fn test_segment_key_no_compression() {
        let expected = "backup-001/queues/_default/test/segment-0001";
        assert_eq!(
            segment_key("backup-001", "/", "test", 1, CompressionType::None),
            expected
        );
    }
}