use std::collections::HashMap;
use std::env;
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use crate::storage::StorageBackendConfig;
/// Builds a [`Config`] from YAML text, expanding `${VAR}` placeholders first.
///
/// The raw text is run through [`interpolate_env_vars`] and the expanded text
/// is then deserialized with `serde_yaml`.
///
/// # Errors
///
/// Returns the error from [`interpolate_env_vars`] unchanged, or the
/// `serde_yaml` error converted into the crate error type by `?`.
pub fn config_from_yaml_str(input: &str) -> crate::Result<Config> {
    let expanded = interpolate_env_vars(input)?;
    let config: Config = serde_yaml::from_str(&expanded)?;
    Ok(config)
}
pub fn interpolate_env_vars(input: &str) -> crate::Result<String> {
let re = regex::Regex::new(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
.map_err(|e| crate::Error::Config(format!("Invalid env interpolation regex: {e}")))?;
let mut output = String::with_capacity(input.len());
let mut last = 0;
for captures in re.captures_iter(input) {
let whole = captures
.get(0)
.expect("regex capture 0 should always exist");
let var = captures
.get(1)
.expect("regex capture 1 should always exist")
.as_str();
output.push_str(&input[last..whole.start()]);
let value = env::var(var)
.map_err(|_| crate::Error::Config(format!("Environment variable {var} is not set")))?;
output.push_str(&value);
last = whole.end();
}
output.push_str(&input[last..]);
Ok(output)
}
/// Top-level configuration, deserialized from YAML (see `config_from_yaml_str`).
///
/// `mode`, `backup_id` and `storage` carry no `#[serde(default)]`; every other
/// section is an `Option` that becomes `None` when the key is omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Selects between the backup and restore operation.
pub mode: Mode,
/// Backup identifier string.
pub backup_id: String,
/// Source connection settings; `None` when omitted.
#[serde(default)]
pub source: Option<SourceConfig>,
/// Target connection settings; `None` when omitted.
#[serde(default)]
pub target: Option<TargetConfig>,
/// Storage backend settings (type defined in `crate::storage`).
pub storage: StorageBackendConfig,
/// Backup tuning options; `None` when omitted.
#[serde(default)]
pub backup: Option<BackupOptions>,
/// Restore tuning options; `None` when omitted.
#[serde(default)]
pub restore: Option<RestoreOptions>,
/// Offset storage settings; `None` when omitted.
#[serde(default)]
pub offset_storage: Option<OffsetStorageConfig>,
/// Metrics endpoint settings; `None` when omitted.
#[serde(default)]
pub metrics: Option<MetricsConfig>,
}
/// Operation selected by [`Config::mode`].
///
/// Serialized in lowercase: `backup` or `restore`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Mode {
/// Written as `backup`.
Backup,
/// Written as `restore`.
Restore,
}
/// Connection settings for the source side (the `source` section of [`Config`]).
///
/// NOTE(review): the field names suggest a RabbitMQ broker and its HTTP
/// management API — confirm against the code that consumes this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceConfig {
/// AMQP connection URL.
pub amqp_url: String,
/// URL of the management endpoint.
pub management_url: String,
/// Username for the management endpoint.
pub management_username: String,
/// Password for the management endpoint.
pub management_password: String,
/// Stream port; 5552 when omitted.
#[serde(default = "default_stream_port")]
pub stream_port: u16,
/// TLS settings; `None` when omitted.
#[serde(default)]
pub tls: Option<TlsConfig>,
/// Queue selection filters; `None` when omitted.
#[serde(default)]
pub queues: Option<QueueSelection>,
}
/// Connection settings for the target side (the `target` section of [`Config`]).
///
/// Only `amqp_url` lacks a serde default; the management fields are optional
/// and become `None` when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TargetConfig {
/// AMQP connection URL.
pub amqp_url: String,
/// URL of the management endpoint, if any.
#[serde(default)]
pub management_url: Option<String>,
/// Username for the management endpoint, if any.
#[serde(default)]
pub management_username: Option<String>,
/// Password for the management endpoint, if any.
#[serde(default)]
pub management_password: Option<String>,
}
/// TLS settings (the `tls` key of [`SourceConfig`]).
///
/// All fields are optional in the input: `enabled` falls back to `false` and
/// the paths to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
/// TLS on/off flag; `false` when omitted.
#[serde(default)]
pub enabled: bool,
/// Path to a CA certificate file, if any.
#[serde(default)]
pub ca_cert: Option<PathBuf>,
/// Path to a client certificate file, if any.
#[serde(default)]
pub client_cert: Option<PathBuf>,
/// Path to a client key file, if any.
#[serde(default)]
pub client_key: Option<PathBuf>,
}
/// Filters describing which queues are selected (the `queues` key of
/// [`SourceConfig`]).
///
/// Every field may be omitted from the input: the container-level
/// `#[serde(default)]` fills missing fields from the derived `Default`, which
/// yields empty lists and `0`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct QueueSelection {
    /// Queue names to include.
    /// NOTE(review): the tests use values such as `orders-*`, so these are
    /// presumably patterns — confirm with the matching code.
    pub include: Vec<String>,
    /// Queue names to exclude (same format as `include`).
    pub exclude: Vec<String>,
    /// Virtual hosts to select from.
    pub vhosts: Vec<String>,
    /// Queue types to select.
    pub types: Vec<QueueType>,
    /// Message-count threshold; `0` when omitted.
    pub min_messages: u64,
}
/// Queue type used in [`QueueSelection::types`].
///
/// Serialized in lowercase: `classic`, `quorum` or `stream`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum QueueType {
/// Written as `classic`.
Classic,
/// Written as `quorum`.
Quorum,
/// Written as `stream`.
Stream,
}
/// Tuning options for backup runs (the `backup` section of [`Config`]).
///
/// Every field has a serde default, so any subset may be given in the input.
/// The `Default` impl for this struct produces the same values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackupOptions {
/// Segment size limit in bytes; 134_217_728 (128 MiB) when omitted.
#[serde(default = "default_segment_max_bytes")]
pub segment_max_bytes: u64,
/// Segment time limit in milliseconds; 60_000 when omitted.
#[serde(default = "default_segment_max_interval_ms")]
pub segment_max_interval_ms: u64,
/// Compression codec; `CompressionType::Zstd` when omitted.
#[serde(default)]
pub compression: CompressionType,
/// Compression level; 3 when omitted.
#[serde(default = "default_compression_level")]
pub compression_level: i32,
/// Prefetch count; 100 when omitted.
#[serde(default = "default_prefetch_count")]
pub prefetch_count: u16,
/// Requeue strategy; `RequeueStrategy::Cancel` when omitted.
#[serde(default)]
pub requeue_strategy: RequeueStrategy,
/// Concurrent queue limit; 4 when omitted.
#[serde(default = "default_max_concurrent_queues")]
pub max_concurrent_queues: usize,
/// Checkpoint interval in seconds; 5 when omitted.
#[serde(default = "default_checkpoint_interval_secs")]
pub checkpoint_interval_secs: u64,
/// Sync interval in seconds; 30 when omitted.
#[serde(default = "default_sync_interval_secs")]
pub sync_interval_secs: u64,
/// Definitions flag; `true` when omitted.
#[serde(default = "default_true")]
pub include_definitions: bool,
/// `true` when omitted.
/// NOTE(review): presumably ends the backup once the queue depth seen at
/// start has been consumed — confirm with the backup engine.
#[serde(default = "default_true")]
pub stop_at_current_depth: bool,
/// Stream flag; `false` when omitted.
#[serde(default)]
pub stream_enabled: bool,
}
impl Default for BackupOptions {
fn default() -> Self {
Self {
segment_max_bytes: default_segment_max_bytes(),
segment_max_interval_ms: default_segment_max_interval_ms(),
compression: CompressionType::default(),
compression_level: default_compression_level(),
prefetch_count: default_prefetch_count(),
requeue_strategy: RequeueStrategy::default(),
max_concurrent_queues: default_max_concurrent_queues(),
checkpoint_interval_secs: default_checkpoint_interval_secs(),
sync_interval_secs: default_sync_interval_secs(),
include_definitions: true,
stop_at_current_depth: true,
stream_enabled: false,
}
}
}
/// Compression codec selected by [`BackupOptions::compression`].
///
/// Serialized in lowercase (`zstd`, `lz4`, `none`); `Zstd` is the default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CompressionType {
    /// Written as `zstd`; the default variant.
    #[default]
    Zstd,
    /// Written as `lz4`.
    Lz4,
    /// Written as `none`.
    None,
}
/// Strategy selected by [`BackupOptions::requeue_strategy`].
///
/// Serialized in lowercase (`nack`, `reject`, `cancel`, `get`); `Cancel` is
/// the default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RequeueStrategy {
    /// Written as `nack`.
    Nack,
    /// Written as `reject`.
    Reject,
    /// Written as `cancel`; the default variant.
    #[default]
    Cancel,
    /// Written as `get`.
    Get,
}
/// Tuning options for restore runs (the `restore` section of [`Config`]).
///
/// Every field has a serde default, so any subset may be given in the input.
/// The `Default` impl for this struct produces the same values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RestoreOptions {
/// Start of the time window, if any.
/// NOTE(review): unit and epoch of the `i64` are not visible here — confirm.
#[serde(default)]
pub time_window_start: Option<i64>,
/// End of the time window, if any (same representation as the start).
#[serde(default)]
pub time_window_end: Option<i64>,
/// Queue name mapping; empty when omitted.
#[serde(default)]
pub queue_mapping: HashMap<String, String>,
/// Virtual host name mapping; empty when omitted.
#[serde(default)]
pub vhost_mapping: HashMap<String, String>,
/// Exchange name mapping; empty when omitted.
#[serde(default)]
pub exchange_mapping: HashMap<String, String>,
/// Publish mode; `PublishMode::Exchange` when omitted.
#[serde(default)]
pub publish_mode: PublishMode,
/// Publisher confirms flag; `true` when omitted.
#[serde(default = "default_true")]
pub publisher_confirms: bool,
/// Concurrent queue limit; 4 when omitted.
#[serde(default = "default_max_concurrent_queues")]
pub max_concurrent_queues: usize,
/// Produce batch size; 100 when omitted.
#[serde(default = "default_produce_batch_size")]
pub produce_batch_size: usize,
/// Rate limit in messages per second; `0` when omitted.
/// NOTE(review): `0` presumably means "no limit" — confirm with the caller.
#[serde(default)]
pub rate_limit_messages_per_sec: u64,
/// Definitions restore flag; `true` when omitted.
#[serde(default = "default_true")]
pub restore_definitions: bool,
/// Dry-run flag for definitions; `false` when omitted.
#[serde(default)]
pub definitions_dry_run: bool,
/// Definitions filter; all lists empty when omitted.
#[serde(default)]
pub definitions_selection: DefinitionsSelection,
/// Create-missing-queues flag; `false` when omitted.
#[serde(default)]
pub create_missing_queues: bool,
/// Path for checkpoint state, if any.
#[serde(default)]
pub checkpoint_state: Option<PathBuf>,
/// Checkpoint interval in seconds; 60 when omitted.
#[serde(default = "default_restore_checkpoint_interval")]
pub checkpoint_interval_secs: u64,
/// Dry-run flag; `false` when omitted.
#[serde(default)]
pub dry_run: bool,
}
impl Default for RestoreOptions {
fn default() -> Self {
Self {
time_window_start: None,
time_window_end: None,
queue_mapping: HashMap::new(),
vhost_mapping: HashMap::new(),
exchange_mapping: HashMap::new(),
publish_mode: PublishMode::default(),
publisher_confirms: true,
max_concurrent_queues: default_max_concurrent_queues(),
produce_batch_size: default_produce_batch_size(),
rate_limit_messages_per_sec: 0,
restore_definitions: true,
definitions_dry_run: false,
definitions_selection: DefinitionsSelection::default(),
create_missing_queues: false,
checkpoint_state: None,
checkpoint_interval_secs: default_restore_checkpoint_interval(),
dry_run: false,
}
}
}
/// Name lists used by [`RestoreOptions::definitions_selection`].
///
/// Every list may be omitted from the input: the container-level
/// `#[serde(default)]` fills missing fields from the derived `Default`, which
/// yields empty lists.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct DefinitionsSelection {
    /// Virtual host names.
    pub vhosts: Vec<String>,
    /// Queue names.
    pub queues: Vec<String>,
    /// Exchange names.
    pub exchanges: Vec<String>,
}
impl DefinitionsSelection {
    /// Returns `true` when `vhosts`, `queues` and `exchanges` are all empty.
    pub fn is_empty(&self) -> bool {
        [&self.vhosts, &self.queues, &self.exchanges]
            .iter()
            .all(|names| names.is_empty())
    }
}
/// Publish mode selected by [`RestoreOptions::publish_mode`].
///
/// Serialized in kebab-case (`exchange`, `direct-to-queue`); `Exchange` is the
/// default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum PublishMode {
    /// Written as `exchange`; the default variant.
    #[default]
    Exchange,
    /// Written as `direct-to-queue`.
    DirectToQueue,
}
/// Offset storage settings (the `offset_storage` section of [`Config`]).
///
/// Every field has a serde default; the `Default` impl for this struct
/// produces the same values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OffsetStorageConfig {
/// Backend kind; `OffsetStorageBackend::Sqlite` when omitted.
#[serde(default)]
pub backend: OffsetStorageBackend,
/// Database path; `./offsets.db` when omitted.
#[serde(default = "default_db_path")]
pub db_path: PathBuf,
/// S3 key, if any.
/// NOTE(review): how this key is used is not visible in this file — confirm.
#[serde(default)]
pub s3_key: Option<String>,
/// Sync interval in seconds; 30 when omitted.
#[serde(default = "default_sync_interval_secs")]
pub sync_interval_secs: u64,
}
impl Default for OffsetStorageConfig {
    /// Produces the same values serde fills in for each field of
    /// `OffsetStorageConfig` when that field is missing from the input.
    fn default() -> Self {
        let backend = OffsetStorageBackend::default();
        let db_path = default_db_path();
        let sync_interval_secs = default_sync_interval_secs();

        Self {
            backend,
            db_path,
            s3_key: None,
            sync_interval_secs,
        }
    }
}
/// Backend kind selected by [`OffsetStorageConfig::backend`].
///
/// Serialized in lowercase (`sqlite`, `memory`); `Sqlite` is the default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OffsetStorageBackend {
    /// Written as `sqlite`; the default variant.
    #[default]
    Sqlite,
    /// Written as `memory`.
    Memory,
}
/// Metrics endpoint settings (the `metrics` section of [`Config`]).
///
/// Every field has a serde default; the `Default` impl for this struct
/// produces the same values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsConfig {
/// Metrics on/off flag; `true` when omitted.
#[serde(default = "default_true")]
pub enabled: bool,
/// Port number; 8080 when omitted.
#[serde(default = "default_metrics_port")]
pub port: u16,
/// Bind address; `0.0.0.0` when omitted.
#[serde(default = "default_bind_address")]
pub bind_address: String,
/// Path string; `/metrics` when omitted.
#[serde(default = "default_metrics_path")]
pub path: String,
}
impl Default for MetricsConfig {
fn default() -> Self {
Self {
enabled: true,
port: default_metrics_port(),
bind_address: default_bind_address(),
path: default_metrics_path(),
}
}
}
/// Serde default for `SourceConfig::stream_port`.
fn default_stream_port() -> u16 {
    const STREAM_PORT: u16 = 5552;
    STREAM_PORT
}
/// Serde default for `BackupOptions::segment_max_bytes`: 128 MiB.
fn default_segment_max_bytes() -> u64 {
    const MIB: u64 = 1024 * 1024;
    128 * MIB
}
/// Serde default for `BackupOptions::segment_max_interval_ms`: 60 seconds
/// expressed in milliseconds.
fn default_segment_max_interval_ms() -> u64 {
    const MILLIS_PER_SEC: u64 = 1_000;
    60 * MILLIS_PER_SEC
}
/// Serde default for `BackupOptions::compression_level`.
fn default_compression_level() -> i32 {
    const LEVEL: i32 = 3;
    LEVEL
}
/// Serde default for `BackupOptions::prefetch_count`.
fn default_prefetch_count() -> u16 {
    const PREFETCH: u16 = 100;
    PREFETCH
}
/// Serde default for `max_concurrent_queues` in both `BackupOptions` and
/// `RestoreOptions`.
fn default_max_concurrent_queues() -> usize {
    const QUEUES: usize = 4;
    QUEUES
}
/// Serde default for `BackupOptions::checkpoint_interval_secs`.
fn default_checkpoint_interval_secs() -> u64 {
    const SECS: u64 = 5;
    SECS
}
/// Serde default for `sync_interval_secs` in both `BackupOptions` and
/// `OffsetStorageConfig`.
fn default_sync_interval_secs() -> u64 {
    const SECS: u64 = 30;
    SECS
}
/// Serde default for `RestoreOptions::produce_batch_size`.
fn default_produce_batch_size() -> usize {
    const BATCH: usize = 100;
    BATCH
}
/// Serde default for `RestoreOptions::checkpoint_interval_secs`.
fn default_restore_checkpoint_interval() -> u64 {
    const SECS: u64 = 60;
    SECS
}
/// Serde default for `OffsetStorageConfig::db_path`: the relative path
/// `./offsets.db`.
fn default_db_path() -> PathBuf {
    "./offsets.db".into()
}
/// Serde default for `MetricsConfig::port`.
fn default_metrics_port() -> u16 {
    const METRICS_PORT: u16 = 8080;
    METRICS_PORT
}
/// Serde default for `MetricsConfig::bind_address`: the IPv4 unspecified
/// address `0.0.0.0`.
fn default_bind_address() -> String {
    String::from("0.0.0.0")
}
/// Serde default for `MetricsConfig::path`.
fn default_metrics_path() -> String {
    String::from("/metrics")
}
/// Serde default for boolean fields that should be `true` when omitted
/// (a plain `#[serde(default)]` on a `bool` would give `false`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `BackupOptions::default()` yields the documented default values.
    #[test]
    fn test_backup_options_defaults() {
        let opts = BackupOptions::default();
        assert_eq!(opts.segment_max_bytes, 134_217_728);
        assert_eq!(opts.compression, CompressionType::Zstd);
        assert_eq!(opts.compression_level, 3);
        assert_eq!(opts.prefetch_count, 100);
        assert_eq!(opts.requeue_strategy, RequeueStrategy::Cancel);
        assert_eq!(opts.max_concurrent_queues, 4);
        assert!(opts.include_definitions);
        assert!(opts.stop_at_current_depth);
    }

    /// `RestoreOptions::default()` yields the documented default values.
    #[test]
    fn test_restore_options_defaults() {
        let opts = RestoreOptions::default();
        assert_eq!(opts.publish_mode, PublishMode::Exchange);
        assert!(opts.publisher_confirms);
        assert!(!opts.dry_run);
        assert!(!opts.create_missing_queues);
    }

    /// A backup configuration with nested sections deserializes into `Config`.
    #[test]
    fn test_config_yaml_deserialization() {
        // Nested keys must be indented under their parent; a flat document
        // would parse `source`, `storage` and `backup` as null and treat
        // their children as unrelated top-level keys.
        let yaml = r#"
mode: backup
backup_id: "test-backup-001"
source:
  amqp_url: "amqp://guest:guest@localhost:5672/%2f"
  management_url: "http://localhost:15672"
  management_username: guest
  management_password: guest
  queues:
    include:
      - "orders-*"
    exclude:
      - "*-dead-letter"
storage:
  backend: filesystem
  path: /tmp/backups
backup:
  compression: zstd
  prefetch_count: 50
  requeue_strategy: cancel
"#;
        let config: Config = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.mode, Mode::Backup);
        assert_eq!(config.backup_id, "test-backup-001");
        assert!(config.source.is_some());
        assert!(config.target.is_none());
        let backup = config.backup.unwrap();
        assert_eq!(backup.compression, CompressionType::Zstd);
        assert_eq!(backup.prefetch_count, 50);
        assert_eq!(backup.requeue_strategy, RequeueStrategy::Cancel);
    }

    /// `${VAR}` placeholders are expanded both inside a quoted string and as
    /// a bare scalar before the YAML is parsed.
    #[test]
    fn test_env_interpolation() {
        // The variable name is unique to this test, so parallel tests do not
        // observe each other's environment changes.
        let var_name = "RABBITMQ_BACKUP_TEST_INTERPOLATION";
        std::env::set_var(var_name, "interpolated-secret");
        let yaml = r#"
mode: backup
backup_id: "test-backup-001"
source:
  amqp_url: "amqp://guest:${RABBITMQ_BACKUP_TEST_INTERPOLATION}@localhost:5672/%2f"
  management_url: "http://localhost:15672"
  management_username: guest
  management_password: ${RABBITMQ_BACKUP_TEST_INTERPOLATION}
storage:
  backend: filesystem
  path: /tmp/backups
"#;
        let config = config_from_yaml_str(yaml).unwrap();
        let source = config.source.unwrap();
        assert_eq!(source.management_password, "interpolated-secret");
        assert_eq!(
            source.amqp_url,
            "amqp://guest:interpolated-secret@localhost:5672/%2f"
        );
        std::env::remove_var(var_name);
    }

    /// A placeholder naming an unset variable produces an error that mentions
    /// the variable name.
    #[test]
    fn test_env_interpolation_missing_var_errors() {
        std::env::remove_var("RABBITMQ_BACKUP_TEST_MISSING_INTERPOLATION");
        let err = interpolate_env_vars("${RABBITMQ_BACKUP_TEST_MISSING_INTERPOLATION}")
            .unwrap_err()
            .to_string();
        assert!(err.contains("RABBITMQ_BACKUP_TEST_MISSING_INTERPOLATION"));
    }

    /// Text without a well-formed `${NAME}` placeholder is returned unchanged.
    #[test]
    fn test_env_interpolation_without_placeholders_is_identity() {
        let text = "plain $text with {braces}, a lone $ and an open ${ brace";
        assert_eq!(interpolate_env_vars(text).unwrap(), text);
    }
}