use std::collections::HashSet;
use tracing::{info, warn};
use crate::config::DefinitionsSelection;
use crate::definitions::management::ManagementClient;
use crate::definitions::types::RabbitMqDefinitions;
use crate::error::{Error, Result};
/// Imports RabbitMQ definitions into a broker through a [`ManagementClient`].
pub struct DefinitionsImporter {
    client: ManagementClient,
}

impl DefinitionsImporter {
    /// Creates an importer that sends definitions through `client`.
    pub fn new(client: ManagementClient) -> Self {
        Self { client }
    }

    /// Imports `definitions` into the target broker.
    ///
    /// The definitions' `rabbit_version` is first checked against the target broker's version
    /// (same major version required); only then are the definitions uploaded.
    ///
    /// # Errors
    ///
    /// Fails when `rabbit_version` is missing or unparseable, when the major versions differ,
    /// or when a management API call returns an error.
    pub async fn import(&self, definitions: &RabbitMqDefinitions) -> Result<()> {
        self.validate_compatibility(definitions).await?;
        info!(
            "Importing definitions: {} users, {} vhosts, {} queues, {} exchanges, {} bindings",
            definitions.users.len(),
            definitions.vhosts.len(),
            definitions.queues.len(),
            definitions.exchanges.len(),
            definitions.bindings.len(),
        );
        self.client.import_definitions(definitions).await?;
        info!("Definitions imported successfully");
        Ok(())
    }

    /// Validates raw definitions JSON with [`Self::validate_json`], then [`Self::import`]s it.
    ///
    /// # Errors
    ///
    /// Propagates validation errors and any error returned by [`Self::import`].
    pub async fn import_json(&self, json: &[u8]) -> Result<()> {
        let definitions = Self::validate_json(json)?;
        self.import(&definitions).await
    }

    /// Returns the subset of `definitions` described by `selection`.
    ///
    /// An empty selection (per `DefinitionsSelection::is_empty`) returns a clone of
    /// `definitions`. Otherwise:
    ///
    /// - **vhost-only selection** (vhosts given, no queues or exchanges): every queue and
    ///   exchange in the selected vhosts is kept, together with those vhosts' policies,
    ///   parameters, permissions and topic permissions, and the users those permissions
    ///   reference.
    /// - **queue/exchange selection**: queues and exchanges whose names are selected are kept
    ///   (restricted to the selected vhosts when any are given), plus, transitively, every
    ///   exchange defined in `definitions` that is the source of a binding to a kept queue or
    ///   exchange. Policies, parameters, permissions and users are dropped.
    ///
    /// In both modes a binding is kept when its destination is kept and its source is either
    /// kept or a built-in `amq.*` exchange. Vhost entries are kept for the selected vhosts and
    /// for the vhosts of kept queues/exchanges. Global parameters are always dropped.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] when the selection names a vhost, queue or exchange that is
    /// not present in `definitions`.
    pub fn filter_definitions(
        definitions: &RabbitMqDefinitions,
        selection: &DefinitionsSelection,
    ) -> Result<RabbitMqDefinitions> {
        if selection.is_empty() {
            return Ok(definitions.clone());
        }
        let selected_vhosts: HashSet<&str> = selection.vhosts.iter().map(String::as_str).collect();
        let selected_queues: HashSet<&str> = selection.queues.iter().map(String::as_str).collect();
        let selected_exchanges: HashSet<&str> =
            selection.exchanges.iter().map(String::as_str).collect();
        validate_selection(
            definitions,
            &selected_vhosts,
            &selected_queues,
            &selected_exchanges,
        )?;

        // Selecting only vhosts means "everything inside these vhosts".
        let vhost_only = !selected_vhosts.is_empty()
            && selected_queues.is_empty()
            && selected_exchanges.is_empty();

        let mut kept_vhosts: HashSet<String> = selected_vhosts
            .iter()
            .map(|vhost| (*vhost).to_string())
            .collect();
        let mut kept_queues: HashSet<(String, String)> = HashSet::new();
        let mut kept_exchanges: HashSet<(String, String)> = HashSet::new();

        for queue in &definitions.queues {
            let Some(vhost) = object_string(queue, "vhost") else {
                continue;
            };
            let Some(name) = object_string(queue, "name") else {
                continue;
            };
            if !selected_vhosts.is_empty() && !selected_vhosts.contains(vhost.as_str()) {
                continue;
            }
            if vhost_only || selected_queues.contains(name.as_str()) {
                kept_vhosts.insert(vhost.clone());
                kept_queues.insert((vhost, name));
            }
        }
        for exchange in &definitions.exchanges {
            let Some(vhost) = object_string(exchange, "vhost") else {
                continue;
            };
            let Some(name) = object_string(exchange, "name") else {
                continue;
            };
            if !selected_vhosts.is_empty() && !selected_vhosts.contains(vhost.as_str()) {
                continue;
            }
            if vhost_only || selected_exchanges.contains(name.as_str()) {
                kept_vhosts.insert(vhost.clone());
                kept_exchanges.insert((vhost, name));
            }
        }

        // Exchanges defined in this document; only these can be pulled in as binding sources.
        let available_exchanges: HashSet<(String, String)> = definitions
            .exchanges
            .iter()
            .filter_map(|exchange| {
                Some((
                    object_string(exchange, "vhost")?,
                    object_string(exchange, "name")?,
                ))
            })
            .collect();

        // Pull in the source exchange of every binding whose destination is kept. Repeat until
        // nothing changes: a newly kept exchange may itself be the destination of an
        // exchange-to-exchange binding.
        let mut changed = true;
        while changed {
            changed = false;
            for binding in &definitions.bindings {
                let Some(vhost) = object_string(binding, "vhost") else {
                    continue;
                };
                if !kept_vhosts.contains(&vhost) {
                    continue;
                }
                let Some(source) = object_string(binding, "source") else {
                    continue;
                };
                let Some(destination) = object_string(binding, "destination") else {
                    continue;
                };
                let Some(destination_type) = object_string(binding, "destination_type") else {
                    continue;
                };
                let destination_kept = match destination_type.as_str() {
                    "queue" => kept_queues.contains(&(vhost.clone(), destination)),
                    "exchange" => kept_exchanges.contains(&(vhost.clone(), destination)),
                    _ => false,
                };
                let source_key = (vhost.clone(), source);
                if destination_kept
                    && available_exchanges.contains(&source_key)
                    && kept_exchanges.insert(source_key)
                {
                    changed = true;
                }
            }
        }

        // A binding survives when its destination is kept and its source is kept or is a
        // built-in `amq.*` exchange. Built-ins are normally absent from exported `exchanges`,
        // so they can never be "kept", yet they exist on every broker; requiring them to be
        // kept would silently drop e.g. `amq.topic -> queue` bindings.
        let bindings = definitions
            .bindings
            .iter()
            .filter(|binding| {
                let Some(vhost) = object_string(binding, "vhost") else {
                    return false;
                };
                let Some(source) = object_string(binding, "source") else {
                    return false;
                };
                let Some(destination) = object_string(binding, "destination") else {
                    return false;
                };
                let Some(destination_type) = object_string(binding, "destination_type") else {
                    return false;
                };
                let destination_kept = match destination_type.as_str() {
                    "queue" => kept_queues.contains(&(vhost.clone(), destination)),
                    "exchange" => kept_exchanges.contains(&(vhost.clone(), destination)),
                    _ => false,
                };
                destination_kept
                    && (Self::is_builtin_exchange(&source)
                        || kept_exchanges.contains(&(vhost, source)))
            })
            .cloned()
            .collect();

        let keep_vhost_scoped = |value: &serde_json::Value| {
            object_string(value, "vhost").is_some_and(|vhost| kept_vhosts.contains(&vhost))
        };
        // Vhost-scoped settings only make sense when whole vhosts are being copied.
        let permissions: Vec<_> = definitions
            .permissions
            .iter()
            .filter(|value| vhost_only && keep_vhost_scoped(value))
            .cloned()
            .collect();
        let topic_permissions: Vec<_> = definitions
            .topic_permissions
            .iter()
            .filter(|value| vhost_only && keep_vhost_scoped(value))
            .cloned()
            .collect();

        // Only carry over users that a surviving permission refers to; copying every user
        // would import accounts that belong to vhosts outside the selection.
        let referenced_users: HashSet<String> = permissions
            .iter()
            .chain(topic_permissions.iter())
            .filter_map(|value| object_string(value, "user"))
            .collect();
        let users: Vec<_> = definitions
            .users
            .iter()
            .filter(|user| {
                object_string(user, "name").is_some_and(|name| referenced_users.contains(&name))
            })
            .cloned()
            .collect();

        Ok(RabbitMqDefinitions {
            rabbit_version: definitions.rabbit_version.clone(),
            users,
            vhosts: definitions
                .vhosts
                .iter()
                .filter(|vhost| {
                    object_string(vhost, "name").is_some_and(|name| kept_vhosts.contains(&name))
                })
                .cloned()
                .collect(),
            queues: definitions
                .queues
                .iter()
                .filter(|queue| {
                    let Some(vhost) = object_string(queue, "vhost") else {
                        return false;
                    };
                    let Some(name) = object_string(queue, "name") else {
                        return false;
                    };
                    kept_queues.contains(&(vhost, name))
                })
                .cloned()
                .collect(),
            exchanges: definitions
                .exchanges
                .iter()
                .filter(|exchange| {
                    let Some(vhost) = object_string(exchange, "vhost") else {
                        return false;
                    };
                    let Some(name) = object_string(exchange, "name") else {
                        return false;
                    };
                    kept_exchanges.contains(&(vhost, name))
                })
                .cloned()
                .collect(),
            bindings,
            policies: definitions
                .policies
                .iter()
                .filter(|value| vhost_only && keep_vhost_scoped(value))
                .cloned()
                .collect(),
            parameters: definitions
                .parameters
                .iter()
                .filter(|value| vhost_only && keep_vhost_scoped(value))
                .cloned()
                .collect(),
            global_parameters: vec![],
            permissions,
            topic_permissions,
        })
    }

    /// Parses raw definitions JSON after checking the keys the importer relies on.
    ///
    /// The document must be a JSON object with a string `rabbit_version` and array-valued
    /// `users`, `vhosts` and `queues`; it is then deserialized into [`RabbitMqDefinitions`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::Serialization`] for structural problems. JSON syntax and
    /// deserialization errors are propagated through the crate's `Error` conversion.
    pub fn validate_json(json: &[u8]) -> Result<RabbitMqDefinitions> {
        let value: serde_json::Value = serde_json::from_slice(json)?;
        let object = value.as_object().ok_or_else(|| {
            Error::Serialization("RabbitMQ definitions must be a JSON object".to_string())
        })?;
        let rabbit_version = object.get("rabbit_version").ok_or_else(|| {
            Error::Serialization("Definitions missing required key: rabbit_version".to_string())
        })?;
        if !rabbit_version.is_string() {
            return Err(Error::Serialization(
                "Definitions key rabbit_version must be a string".to_string(),
            ));
        }
        for key in ["users", "vhosts", "queues"] {
            let section = object.get(key).ok_or_else(|| {
                Error::Serialization(format!("Definitions missing required key: {key}"))
            })?;
            if !section.is_array() {
                return Err(Error::Serialization(format!(
                    "Definitions key {key} must be an array"
                )));
            }
        }
        Ok(serde_json::from_value(value)?)
    }

    /// Checks that `definitions` target the same RabbitMQ major version as the broker.
    ///
    /// The check is skipped (with a warning) when the broker does not report a usable version.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Serialization`] when `rabbit_version` is missing or has no major
    /// component, [`Error::ManagementApi`] when the majors differ, and propagates failures of
    /// the overview request.
    async fn validate_compatibility(&self, definitions: &RabbitMqDefinitions) -> Result<()> {
        let Some(source_version) = definitions.rabbit_version.as_deref() else {
            return Err(Error::Serialization(
                "Definitions missing rabbit_version".to_string(),
            ));
        };
        // Parse both majors explicitly: comparing the raw `Option`s would treat two
        // unparseable versions (`None == None`) as compatible.
        let Some(source_major) = major_version(source_version) else {
            return Err(Error::Serialization(format!(
                "Definitions rabbit_version {source_version:?} has no major version"
            )));
        };
        let overview = self.client.get_overview().await?;
        let Some(target_version) = overview.rabbitmq_version.as_deref() else {
            warn!("Target RabbitMQ version unavailable; skipping definitions version check");
            return Ok(());
        };
        let Some(target_major) = major_version(target_version) else {
            warn!(
                "Target RabbitMQ version {:?} is unparseable; skipping definitions version check",
                target_version
            );
            return Ok(());
        };
        if source_major != target_major {
            return Err(Error::ManagementApi(format!(
                "Definitions RabbitMQ major version {source_version} is incompatible with target {target_version}"
            )));
        }
        Ok(())
    }

    /// Returns `true` for exchange names in the reserved `amq.` namespace.
    ///
    /// Clients may not declare exchanges with this prefix, so such names refer to the
    /// broker's pre-declared exchanges (e.g. `amq.topic`), which exist in every vhost.
    fn is_builtin_exchange(name: &str) -> bool {
        name.starts_with("amq.")
    }
}
/// Returns the leading dot-separated component of `version`, e.g. `"4"` for `"4.0.5"`.
///
/// Yields `None` when that component is empty, as for `""` or `".1"`.
fn major_version(version: &str) -> Option<&str> {
    let leading = version
        .split_once('.')
        .map_or(version, |(major, _rest)| major);
    (!leading.is_empty()).then_some(leading)
}
fn validate_selection(
definitions: &RabbitMqDefinitions,
selected_vhosts: &HashSet<&str>,
selected_queues: &HashSet<&str>,
selected_exchanges: &HashSet<&str>,
) -> Result<()> {
for vhost in selected_vhosts {
if !definitions
.vhosts
.iter()
.any(|candidate| object_string(candidate, "name").as_deref() == Some(*vhost))
{
return Err(Error::Config(format!(
"definitions_selection references missing vhost {vhost}"
)));
}
}
for queue in selected_queues {
if !definitions.queues.iter().any(|candidate| {
object_string(candidate, "name").as_deref() == Some(*queue)
&& (selected_vhosts.is_empty()
|| object_string(candidate, "vhost")
.as_deref()
.is_some_and(|vhost| selected_vhosts.contains(vhost)))
}) {
return Err(Error::Config(format!(
"definitions_selection references missing queue {queue}"
)));
}
}
for exchange in selected_exchanges {
if !definitions.exchanges.iter().any(|candidate| {
object_string(candidate, "name").as_deref() == Some(*exchange)
&& (selected_vhosts.is_empty()
|| object_string(candidate, "vhost")
.as_deref()
.is_some_and(|vhost| selected_vhosts.contains(vhost)))
}) {
return Err(Error::Config(format!(
"definitions_selection references missing exchange {exchange}"
)));
}
}
Ok(())
}
fn object_string(value: &serde_json::Value, key: &str) -> Option<String> {
value.get(key)?.as_str().map(ToString::to_string)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Two-vhost fixture: `/` holds `orders` and `payments`, each bound from its own direct
    /// exchange; `other` holds `other-q` bound from `other-ex`.
    fn sample_definitions() -> RabbitMqDefinitions {
        let queue_binding = |source: &str, vhost: &str, destination: &str, routing_key: &str| {
            json!({
                "source": source,
                "vhost": vhost,
                "destination": destination,
                "destination_type": "queue",
                "routing_key": routing_key
            })
        };
        RabbitMqDefinitions {
            rabbit_version: Some("4.0.0".to_string()),
            users: vec![json!({"name": "guest"})],
            vhosts: vec![json!({"name": "/"}), json!({"name": "other"})],
            queues: vec![
                json!({"name": "orders", "vhost": "/", "durable": true}),
                json!({"name": "payments", "vhost": "/", "durable": true}),
                json!({"name": "other-q", "vhost": "other", "durable": true}),
            ],
            exchanges: vec![
                json!({"name": "orders-ex", "vhost": "/", "type": "direct"}),
                json!({"name": "payments-ex", "vhost": "/", "type": "direct"}),
                json!({"name": "other-ex", "vhost": "other", "type": "direct"}),
            ],
            bindings: vec![
                queue_binding("orders-ex", "/", "orders", "orders"),
                queue_binding("payments-ex", "/", "payments", "payments"),
                queue_binding("other-ex", "other", "other-q", "other"),
            ],
            policies: vec![json!({"vhost": "/", "name": "ha"})],
            parameters: vec![json!({"vhost": "/", "name": "param"})],
            global_parameters: vec![json!({"name": "cluster_name"})],
            permissions: vec![json!({"user": "guest", "vhost": "/"})],
            topic_permissions: vec![],
        }
    }

    /// Runs `filter_definitions` against the fixture with the given selection.
    fn select(selection: DefinitionsSelection) -> Result<RabbitMqDefinitions> {
        DefinitionsImporter::filter_definitions(&sample_definitions(), &selection)
    }

    #[test]
    fn test_validate_json_requires_object() {
        let message = DefinitionsImporter::validate_json(b"[]")
            .unwrap_err()
            .to_string();
        assert!(message.contains("JSON object"));
    }

    #[test]
    fn test_validate_json_requires_rabbit_version() {
        let message =
            DefinitionsImporter::validate_json(br#"{"users":[],"vhosts":[],"queues":[]}"#)
                .unwrap_err()
                .to_string();
        assert!(message.contains("rabbit_version"));
    }

    #[test]
    fn test_validate_json_accepts_required_keys() {
        let parsed = DefinitionsImporter::validate_json(
            br#"{"rabbit_version":"4.0.0","users":[],"vhosts":[],"queues":[]}"#,
        )
        .unwrap();
        assert_eq!(parsed.rabbit_version.as_deref(), Some("4.0.0"));
    }

    #[test]
    fn test_major_version() {
        assert_eq!(major_version("4.0.5"), Some("4"));
        assert_eq!(major_version(""), None);
    }

    #[test]
    fn test_filter_definitions_queue_selection_keeps_binding_source_exchange() {
        let filtered = select(DefinitionsSelection {
            queues: vec!["orders".to_string()],
            ..Default::default()
        })
        .unwrap();

        assert_eq!(filtered.vhosts.len(), 1);
        assert_eq!(filtered.queues.len(), 1);
        assert_eq!(filtered.queues[0]["name"], "orders");
        assert_eq!(filtered.exchanges.len(), 1);
        assert_eq!(filtered.exchanges[0]["name"], "orders-ex");
        assert_eq!(filtered.bindings.len(), 1);
        assert!(filtered.permissions.is_empty());
        assert!(filtered.global_parameters.is_empty());
    }

    #[test]
    fn test_filter_definitions_vhost_selection_keeps_vhost_topology() {
        let filtered = select(DefinitionsSelection {
            vhosts: vec!["/".to_string()],
            ..Default::default()
        })
        .unwrap();

        assert_eq!(filtered.vhosts.len(), 1);
        assert_eq!(filtered.queues.len(), 2);
        assert_eq!(filtered.exchanges.len(), 2);
        assert_eq!(filtered.bindings.len(), 2);
        assert_eq!(filtered.permissions.len(), 1);
        assert_eq!(filtered.users.len(), 1);
    }

    #[test]
    fn test_filter_definitions_rejects_missing_queue() {
        let message = select(DefinitionsSelection {
            queues: vec!["missing".to_string()],
            ..Default::default()
        })
        .unwrap_err()
        .to_string();
        assert!(message.contains("missing queue"));
    }
}