use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use tokio::sync::{broadcast, Mutex, Semaphore};
use tracing::{debug, info, warn};
use crate::amqp::{amqp_url_with_vhost, AmqpClient, ConfirmStats};
use crate::compression;
use crate::config::{
CompressionType, Config, PublishMode, QueueType, RestoreOptions, TargetConfig,
};
use crate::definitions::{DefinitionsImporter, ManagementClient};
use crate::error::{Error, Result};
use crate::manifest::{BackupManifest, BackupRecord};
use crate::offset_store::{QueueProgressUpdate, SqliteOffsetStore};
use crate::restore::publisher::Publisher;
use crate::segment::SegmentReader;
use crate::storage::{create_backend_from_config, StorageBackend};
/// Aggregate counters reported at the end of a restore run.
///
/// Each per-queue restore task adds its own totals into a shared instance,
/// and `RestoreEngine::run` returns a clone of the final values.
#[derive(Debug, Clone, Default)]
pub struct RestoreStats {
/// Messages counted as published (confirmed, when publisher confirms are enabled).
pub restored: u64,
/// Records excluded by the point-in-time window; a dry run reports the
/// manifest's total message count here instead.
pub skipped: u64,
/// Messages whose publish call failed or that the confirm stats reported as failed.
pub failed: u64,
/// Queues that finished processing, including queues skipped because a
/// checkpoint already marked them as restored.
pub queues_processed: usize,
}
/// Drives a restore of a stored backup into a target broker.
pub struct RestoreEngine {
// Full application configuration; the storage, backup id, restore options
// and target sections are read from it during `run`.
config: Config,
}
impl RestoreEngine {
pub fn new(config: Config) -> Self {
Self { config }
}
pub async fn run(&self, shutdown_rx: broadcast::Receiver<()>) -> Result<RestoreStats> {
let storage = create_backend_from_config(&self.config.storage)?;
info!("Storage backend initialized");
let manifest_key = format!("{}/manifest.json", self.config.backup_id);
let manifest_data = storage.get(&manifest_key).await?;
let manifest = BackupManifest::from_json(&manifest_data)?;
info!(
"Loaded manifest: {} queues, {} messages, {} segments",
manifest.queues.len(),
manifest.total_messages,
manifest.total_segments
);
let restore_opts = self.config.restore.clone().unwrap_or_default();
let checkpoint = if let Some(path) = restore_opts.checkpoint_state.as_ref() {
let store = SqliteOffsetStore::new(path).await?;
store
.set_job_status(&self.config.backup_id, "restore-running")
.await?;
Some(Arc::new(store))
} else {
None
};
if restore_opts.dry_run {
info!("DRY RUN — no messages will be published");
println!("Dry run summary:");
println!(" Backup ID: {}", manifest.backup_id);
println!(" Queues: {}", manifest.queues.len());
println!(" Messages: {}", manifest.total_messages);
println!(" Segments: {}", manifest.total_segments);
println!(" Size: {} bytes", manifest.total_bytes);
if let (Some(start), Some(end)) =
(restore_opts.time_window_start, restore_opts.time_window_end)
{
println!(" PITR window: {} - {} (epoch ms)", start, end);
}
return Ok(RestoreStats {
skipped: manifest.total_messages,
..Default::default()
});
}
if restore_opts.restore_definitions {
if let Some(ref defs) = manifest.definitions {
info!("Restoring definitions from {}", defs.key);
self.restore_definitions(&storage, defs, &restore_opts)
.await?;
info!("Definitions restored successfully");
}
}
if manifest.queues.is_empty() {
info!("No queues to restore");
return Ok(RestoreStats::default());
}
let stats = Arc::new(Mutex::new(RestoreStats::default()));
let semaphore = Arc::new(Semaphore::new(restore_opts.max_concurrent_queues.max(1)));
let mut handles = Vec::new();
let mut error_count = 0;
for queue_backup in &manifest.queues {
let sem = semaphore.clone();
let storage = storage.clone();
let stats = stats.clone();
let config = self.config.clone();
let queue_backup = queue_backup.clone();
let checkpoint = checkpoint.clone();
let shutdown_rx = shutdown_rx.resubscribe();
handles.push(tokio::spawn(async move {
let _permit = sem.acquire().await;
restore_single_queue(
queue_backup,
config,
storage,
stats,
checkpoint,
shutdown_rx,
)
.await
}));
}
let results = futures::future::join_all(handles).await;
for result in results {
match result {
Ok(Ok(())) => {}
Ok(Err(e)) => {
warn!("Queue restore error: {}", e);
error_count += 1;
}
Err(e) => {
warn!("Task join error: {}", e);
error_count += 1;
}
}
}
let stats = stats.lock().await.clone();
if error_count > 0 {
if let Some(checkpoint) = &checkpoint {
checkpoint
.set_job_status(&self.config.backup_id, "restore-failed")
.await?;
}
return Err(Error::Manifest(format!(
"{} queue(s) failed during restore",
error_count
)));
}
if let Some(checkpoint) = &checkpoint {
checkpoint
.set_job_status(&self.config.backup_id, "restore-completed")
.await?;
}
info!(
"Restore complete: {} restored, {} skipped, {} failed ({} queues)",
stats.restored, stats.skipped, stats.failed, stats.queues_processed
);
Ok(stats)
}
async fn restore_definitions(
&self,
storage: &Arc<dyn StorageBackend>,
defs: &crate::manifest::DefinitionsBackup,
restore_opts: &RestoreOptions,
) -> Result<()> {
let compressed = storage.get(&defs.key).await?;
let ext = &defs.key;
let compression_type = compression::detect_from_extension(ext);
let json = compression::decompress(&compressed, compression_type)?;
let definitions = DefinitionsImporter::validate_json(&json)?;
let definitions = DefinitionsImporter::filter_definitions(
&definitions,
&restore_opts.definitions_selection,
)?;
if !restore_opts.definitions_selection.is_empty() {
info!(
"Definitions selection applied: {} vhosts, {} queues, {} exchanges, {} bindings",
definitions.vhosts.len(),
definitions.queues.len(),
definitions.exchanges.len(),
definitions.bindings.len()
);
}
if restore_opts.definitions_dry_run {
info!("Definitions dry-run: would import {} bytes", json.len());
return Ok(());
}
let (mgmt_url, mgmt_user, mgmt_pass) = if let Some(ref target) = self.config.target {
(
target
.management_url
.as_deref()
.ok_or_else(|| Error::Config("Target management_url required".to_string()))?,
target.management_username.as_deref().ok_or_else(|| {
Error::Config("Target management_username required".to_string())
})?,
target.management_password.as_deref().ok_or_else(|| {
Error::Config("Target management_password required".to_string())
})?,
)
} else {
return Err(Error::Config(
"Target configuration required for restore".to_string(),
));
};
let mgmt = ManagementClient::new(mgmt_url, mgmt_user, mgmt_pass)?;
let rollback_key =
write_definitions_rollback(&mgmt, storage, &self.config.backup_id, compression_type)
.await?;
info!("Current target definitions rollback snapshot written to {rollback_key}");
let importer = DefinitionsImporter::new(mgmt);
importer.import(&definitions).await
}
}
/// Exports the target broker's current definitions and stores them as a
/// rollback snapshot.
///
/// The snapshot is serialized as pretty-printed JSON, compressed with
/// `compression_type`, and written under
/// `<backup_id>/definitions/rollback-before-import-<epoch millis><extension>`.
///
/// Returns the storage key of the written snapshot.
///
/// # Errors
///
/// Propagates failures from the management export, JSON serialization,
/// compression, and the storage write.
async fn write_definitions_rollback(
    mgmt: &ManagementClient,
    storage: &Arc<dyn StorageBackend>,
    backup_id: &str,
    compression_type: CompressionType,
) -> Result<String> {
    let current_definitions = mgmt.export_definitions(None).await?;
    let json = serde_json::to_vec_pretty(&current_definitions)?;
    let compressed = compression::compress(&json, compression_type)?;
    // The millisecond timestamp keeps snapshots from separate imports apart.
    let key = format!(
        "{}/definitions/rollback-before-import-{}{}",
        backup_id,
        chrono::Utc::now().timestamp_millis(),
        compression::extension(compression_type)
    );
    storage.put(&key, Bytes::from(compressed)).await?;
    Ok(key)
}
/// Restores one backed-up queue into the target broker.
///
/// Flow:
/// 1. Resolve the target queue and vhost through `restore.queue_mapping` and
///    `restore.vhost_mapping`, falling back to the source names.
/// 2. Consult the optional checkpoint: a queue already marked completed with a
///    matching message count is skipped; otherwise the first
///    `messages_backed_up` records are treated as already processed.
/// 3. Optionally create the target queue, then connect and open a publishing
///    channel (with publisher confirms when enabled).
/// 4. Read each segment from storage and publish its records in batches of
///    `restore.produce_batch_size` (minimum 1), skipping records outside the
///    point-in-time window.
/// 5. Add this queue's totals to the shared `stats`.
///
/// The checkpoint only ever moves to a record index at which every earlier
/// record has been published successfully or skipped: it is not advanced from
/// a skipped record while a batch is still waiting to be published, and it is
/// frozen for the rest of the run once any message has failed. A resume may
/// therefore republish some messages, but it never jumps over unpublished ones.
///
/// # Errors
///
/// Returns `Error::Config` when no target is configured, `Error::Amqp` when at
/// least one message failed, and propagates storage, segment, checkpoint and
/// broker errors.
async fn restore_single_queue(
    queue_backup: crate::manifest::QueueBackup,
    config: Config,
    storage: Arc<dyn StorageBackend>,
    stats: Arc<Mutex<RestoreStats>>,
    checkpoint: Option<Arc<SqliteOffsetStore>>,
    mut shutdown_rx: broadcast::Receiver<()>,
) -> Result<()> {
    let restore_opts = config.restore.clone().unwrap_or_default();
    let target = config
        .target
        .as_ref()
        .ok_or_else(|| Error::Config("Target configuration required for restore".to_string()))?;
    let target_queue = restore_opts
        .queue_mapping
        .get(&queue_backup.name)
        .cloned()
        .unwrap_or_else(|| queue_backup.name.clone());
    let target_vhost = restore_opts
        .vhost_mapping
        .get(&queue_backup.vhost)
        .cloned()
        .unwrap_or_else(|| queue_backup.vhost.clone());
    info!(
        "Restoring queue {}/{} → {}/{} ({} segments, {} messages)",
        queue_backup.vhost,
        queue_backup.name,
        target_vhost,
        target_queue,
        queue_backup.segments.len(),
        queue_backup.message_count
    );

    // Number of leading records (counted across all segments) that a previous
    // run already processed.
    let mut resume_after_records = 0u64;
    if let Some(checkpoint) = &checkpoint {
        if let Some(progress) = checkpoint
            .get_progress(&config.backup_id, &queue_backup.vhost, &queue_backup.name)
            .await?
        {
            if progress.completed && progress.target_message_count == queue_backup.message_count {
                info!(
                    "Skipping already-restored queue {}/{} from checkpoint",
                    queue_backup.vhost, queue_backup.name
                );
                let mut s = stats.lock().await;
                s.queues_processed += 1;
                return Ok(());
            }
            resume_after_records = progress.messages_backed_up.min(queue_backup.message_count);
            if resume_after_records > 0 {
                info!(
                    "Resuming restore for {}/{} after {} processed record(s)",
                    queue_backup.vhost, queue_backup.name, resume_after_records
                );
            }
        }
    }

    if restore_opts.create_missing_queues {
        ensure_target_queue_exists(
            target,
            &restore_opts,
            &target_vhost,
            &target_queue,
            queue_backup.queue_type,
        )
        .await?;
    }

    let target_amqp_url = amqp_url_with_vhost(&target.amqp_url, &target_vhost)?;
    let mut client = AmqpClient::connect(&target_amqp_url, None).await?;
    let channel_id = client.open_channel().await?;
    let mut publisher = Publisher::new(client, channel_id);
    if restore_opts.publisher_confirms {
        publisher.enable_confirms().await?;
    }

    let batch_size = restore_opts.produce_batch_size.max(1);
    let mut queue_restored = 0u64;
    let mut queue_skipped = 0u64;
    let mut queue_failed = 0u64;
    let mut raw_records_seen = 0u64;
    let mut stopped_early = false;

    for segment_meta in &queue_backup.segments {
        // Shutdown is only checked between segments; a segment that has been
        // started is always finished.
        if shutdown_rx.try_recv().is_ok() {
            info!(
                "Shutdown signal received during restore of {}",
                target_queue
            );
            stopped_early = true;
            break;
        }

        let segment_data = storage.get(&segment_meta.key).await?;
        let records = SegmentReader::read_records(&segment_data)?;
        debug!(
            "Read segment {} ({} records)",
            segment_meta.key,
            records.len()
        );

        let mut batch = Vec::with_capacity(batch_size);
        for (position, record) in records.iter().enumerate() {
            raw_records_seen += 1;

            if raw_records_seen > resume_after_records {
                if should_include(record, &restore_opts) {
                    batch.push((record, raw_records_seen));
                } else {
                    queue_skipped += 1;
                    // Checkpointing this index is only safe when nothing
                    // before it is still unpublished (empty batch) and nothing
                    // before it has failed.
                    if batch.is_empty() && queue_failed == 0 {
                        set_restore_progress(
                            &checkpoint,
                            &config.backup_id,
                            &queue_backup,
                            raw_records_seen,
                            segment_meta.sequence,
                            false,
                        )
                        .await?;
                    }
                }
            }

            // Flush a full batch, or whatever remains at the end of the
            // segment (batches never span segments).
            let segment_exhausted = position + 1 == records.len();
            let flush_due =
                !batch.is_empty() && (batch.len() >= batch_size || segment_exhausted);
            if !flush_due {
                continue;
            }

            let last_record_index = batch.last().map(|(_, index)| *index).unwrap_or(0);
            let (restored, failed) =
                publish_restore_batch(&mut publisher, &restore_opts, &target_queue, &batch)
                    .await?;
            queue_restored += restored;
            queue_failed += failed;
            // Use the cumulative failure count: once any message has failed,
            // the checkpoint must stay before it so a resume retries it.
            if queue_failed == 0 {
                set_restore_progress(
                    &checkpoint,
                    &config.backup_id,
                    &queue_backup,
                    last_record_index,
                    segment_meta.sequence,
                    false,
                )
                .await?;
            }
            batch.clear();
        }
    }

    if !stopped_early && queue_failed == 0 {
        let last_segment_sequence = queue_backup
            .segments
            .last()
            .map(|segment| segment.sequence)
            .unwrap_or(0);
        set_restore_progress(
            &checkpoint,
            &config.backup_id,
            &queue_backup,
            queue_backup.message_count,
            last_segment_sequence,
            true,
        )
        .await?;
    }

    publisher.close().await;
    {
        let mut s = stats.lock().await;
        s.restored += queue_restored;
        s.skipped += queue_skipped;
        s.failed += queue_failed;
        s.queues_processed += 1;
    }
    info!(
        "Queue {} restored: {} published, {} skipped (PITR), {} failed",
        target_queue, queue_restored, queue_skipped, queue_failed
    );

    if queue_failed > 0 {
        return Err(Error::Amqp(format!(
            "{} message(s) failed while restoring queue {}",
            queue_failed, target_queue
        )));
    }
    Ok(())
}
/// Makes sure `target_queue` exists in `target_vhost` before messages are
/// published to it.
///
/// Only acts in direct-to-queue publish mode; in any other mode it logs a
/// warning and returns without contacting the broker. When the queue is
/// missing it is declared through the management client with `queue_type`.
///
/// # Errors
///
/// Propagates errors from building the management client, the existence
/// check, and the queue declaration.
async fn ensure_target_queue_exists(
    target: &TargetConfig,
    restore_opts: &RestoreOptions,
    target_vhost: &str,
    target_queue: &str,
    queue_type: QueueType,
) -> Result<()> {
    let publishes_direct = restore_opts.publish_mode == PublishMode::DirectToQueue;
    if !publishes_direct {
        warn!(
            "restore.create_missing_queues is only applied for direct-to-queue publish mode; exchange mode requires target topology to exist"
        );
        return Ok(());
    }

    let management = target_management_client(target)?;
    let already_present = management.queue_exists(target_vhost, target_queue).await?;
    if !already_present {
        info!(
            "Target queue {}/{} is missing; creating it before restore",
            target_vhost, target_queue
        );
        management
            .declare_queue(target_vhost, target_queue, queue_type)
            .await
    } else {
        Ok(())
    }
}
fn target_management_client(target: &TargetConfig) -> Result<ManagementClient> {
let mgmt_url = target.management_url.as_deref().ok_or_else(|| {
Error::Config(
"target.management_url is required when restore.create_missing_queues is true"
.to_string(),
)
})?;
let mgmt_user = target.management_username.as_deref().ok_or_else(|| {
Error::Config(
"target.management_username is required when restore.create_missing_queues is true"
.to_string(),
)
})?;
let mgmt_password = target.management_password.as_deref().ok_or_else(|| {
Error::Config(
"target.management_password is required when restore.create_missing_queues is true"
.to_string(),
)
})?;
ManagementClient::new(mgmt_url, mgmt_user, mgmt_password)
}
/// Publishes one batch of records and reports `(restored, failed)` counts.
///
/// When `restore.rate_limit_messages_per_sec` is non-zero the function first
/// sleeps for `batch.len() / rate` seconds. Each record then goes either to
/// its (optionally remapped) original exchange with its original routing key,
/// or to the default exchange with `target_queue` as routing key, depending on
/// `restore.publish_mode`. A record whose publish call fails is logged and
/// counted as failed; the rest of the batch is still attempted.
///
/// With publisher confirms enabled and at least one successful publish, the
/// function waits for confirms up to the last returned tag and moves
/// confirm failures from the restored count to the failed count.
///
/// # Errors
///
/// Only a failure while waiting for confirms is returned as an error.
async fn publish_restore_batch(
    publisher: &mut Publisher,
    restore_opts: &RestoreOptions,
    target_queue: &str,
    batch: &[(&BackupRecord, u64)],
) -> Result<(u64, u64)> {
    let rate = restore_opts.rate_limit_messages_per_sec;
    if rate > 0 {
        let pause_secs = batch.len() as f64 / rate as f64;
        tokio::time::sleep(Duration::from_secs_f64(pause_secs)).await;
    }

    let mut published = 0u64;
    let mut publish_errors = 0u64;
    let mut newest_tag = 0u64;

    for (record, _) in batch {
        let (exchange, routing_key) = match restore_opts.publish_mode {
            PublishMode::DirectToQueue => (String::new(), target_queue.to_string()),
            PublishMode::Exchange => {
                let mapped_exchange = restore_opts
                    .exchange_mapping
                    .get(&record.exchange)
                    .cloned()
                    .unwrap_or_else(|| record.exchange.clone());
                (mapped_exchange, record.routing_key.clone())
            }
        };

        let outcome = publisher
            .publish_record(record, &exchange, &routing_key)
            .await;
        match outcome {
            Err(e) => {
                warn!("Failed to publish message: {}", e);
                publish_errors += 1;
            }
            Ok(tag) => {
                newest_tag = tag;
                published += 1;
            }
        }
    }

    // Without confirms (or with nothing sent) the publish results are final.
    let confirms_pending = restore_opts.publisher_confirms && published > 0;
    if !confirms_pending {
        return Ok((published, publish_errors));
    }

    let confirm_stats = publisher.wait_for_confirms(newest_tag).await?;
    let (confirmed, rejected) = confirmed_and_failed(published, confirm_stats);
    Ok((confirmed, publish_errors + rejected))
}
/// Records restore progress for one queue in the checkpoint store.
///
/// Does nothing when no checkpoint store is configured. Otherwise writes
/// `messages_processed` (stored in the `messages_backed_up` field), the last
/// segment sequence, the queue's expected message count, and the `completed`
/// flag.
///
/// # Errors
///
/// Propagates any error from the checkpoint store.
async fn set_restore_progress(
    checkpoint: &Option<Arc<SqliteOffsetStore>>,
    backup_id: &str,
    queue_backup: &crate::manifest::QueueBackup,
    messages_processed: u64,
    last_segment_sequence: u64,
    completed: bool,
) -> Result<()> {
    let Some(store) = checkpoint else {
        return Ok(());
    };

    let update = QueueProgressUpdate {
        backup_id,
        vhost: &queue_backup.vhost,
        queue_name: &queue_backup.name,
        messages_backed_up: messages_processed,
        last_segment_sequence,
        target_message_count: queue_backup.message_count,
        completed,
    };
    store.set_progress_state(update).await?;
    Ok(())
}
/// Reports whether `record` falls inside the configured point-in-time window.
///
/// Both bounds are inclusive and optional; a missing bound places no limit on
/// that side. The record's time comes from `pitr_timestamp_ms`.
fn should_include(record: &BackupRecord, opts: &RestoreOptions) -> bool {
    let record_ms = pitr_timestamp_ms(record);

    let too_early = opts
        .time_window_start
        .is_some_and(|lower| record_ms < lower);
    if too_early {
        return false;
    }

    let too_late = opts.time_window_end.is_some_and(|upper| record_ms > upper);
    !too_late
}
/// Returns the timestamp used for point-in-time filtering.
///
/// Prefers the message's own `properties.timestamp`. A value inside
/// `-10_000_000_000..10_000_000_000` is multiplied by 1000; any other value is
/// returned unchanged. Without a message timestamp the record's
/// `backed_up_at` value is used.
fn pitr_timestamp_ms(record: &BackupRecord) -> i64 {
    // NOTE(review): presumably values in this range are epoch seconds and
    // anything larger is already milliseconds — confirm against the producer.
    const SCALED_RANGE: std::ops::Range<i64> = -10_000_000_000..10_000_000_000;

    let Some(raw) = record.properties.timestamp else {
        return record.backed_up_at;
    };
    if SCALED_RANGE.contains(&raw) {
        raw * 1000
    } else {
        raw
    }
}
/// Splits `attempted` publishes into `(confirmed, failed)`.
///
/// The failed count comes from `confirm_stats.failed()`. The confirmed count
/// is clamped at zero if the broker reports more failures than attempts.
fn confirmed_and_failed(attempted: u64, confirm_stats: ConfirmStats) -> (u64, u64) {
    let rejected = confirm_stats.failed();
    match attempted.checked_sub(rejected) {
        Some(confirmed) => (confirmed, rejected),
        None => (0, rejected),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::manifest::BackupProperties;

    /// Builds a minimal record with default properties and `backed_up_at` set to `ts`.
    fn make_record(ts: i64) -> BackupRecord {
        BackupRecord {
            body: Some(b"test".to_vec()),
            properties: BackupProperties::default(),
            headers: vec![],
            exchange: "".to_string(),
            routing_key: "".to_string(),
            delivery_tag: 0,
            redelivered: false,
            backed_up_at: ts,
            source_queue: "q".to_string(),
            source_vhost: "/".to_string(),
        }
    }

    /// Restore options limited to the given point-in-time bounds.
    fn window(start: Option<i64>, end: Option<i64>) -> RestoreOptions {
        RestoreOptions {
            time_window_start: start,
            time_window_end: end,
            ..Default::default()
        }
    }

    /// Shorthand: is a record backed up at `ts` included under `opts`?
    fn included(ts: i64, opts: &RestoreOptions) -> bool {
        should_include(&make_record(ts), opts)
    }

    #[test]
    fn test_pitr_no_filter() {
        let opts = RestoreOptions::default();
        assert!(included(1000, &opts));
    }

    #[test]
    fn test_pitr_start_only() {
        let opts = window(Some(500), None);
        assert!(included(1000, &opts));
        assert!(!included(100, &opts));
    }

    #[test]
    fn test_pitr_end_only() {
        let opts = window(None, Some(500));
        assert!(included(100, &opts));
        assert!(!included(1000, &opts));
    }

    #[test]
    fn test_pitr_window() {
        let opts = window(Some(100), Some(500));
        assert!(included(300, &opts));
        assert!(!included(50, &opts));
        assert!(!included(600, &opts));
    }

    #[test]
    fn test_pitr_prefers_original_message_timestamp() {
        // The message timestamp (scaled by 1000) must win over `backed_up_at`.
        let mut record = make_record(2_000_000);
        record.properties.timestamp = Some(1_700_000_000);
        let opts = window(Some(1_699_999_999_000), Some(1_700_000_001_000));
        assert!(should_include(&record, &opts));
    }

    #[test]
    fn test_pitr_falls_back_to_backup_timestamp() {
        let record = make_record(2_000_000);
        let opts = window(Some(1_999_999), Some(2_000_001));
        assert!(should_include(&record, &opts));
    }

    #[test]
    fn test_confirm_failures_are_not_counted_as_restored() {
        let confirm_stats = ConfirmStats {
            nacked: 2,
            returned: 1,
        };
        let (confirmed, failed) = confirmed_and_failed(10, confirm_stats);
        assert_eq!(confirmed, 7);
        assert_eq!(failed, 3);
    }
}