use crate::broker::bridge::{
BridgeConfig, BridgeConnection, BridgeError, BridgeStats, LoopPrevention, Result,
};
use crate::broker::router::MessageRouter;
use crate::packet::publish::PublishPacket;
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::task::JoinHandle;
use tracing::{debug, error, info};
/// Owns the set of broker bridges, the tasks driving them, and the shared
/// loop-prevention state used when forwarding outgoing publish packets to them.
pub struct BridgeManager {
    /// Router handed to every `BridgeConnection` this manager creates.
    router: Arc<MessageRouter>,
    /// Registered bridges keyed by their configured name.
    bridges: Arc<RwLock<HashMap<String, Arc<BridgeConnection>>>>,
    /// Join handle of each bridge's spawned `run` task, keyed by the same name.
    tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
    /// Loop-prevention state shared by all bridges; `None` until set explicitly
    /// or initialized from the first bridge's configuration.
    loop_prevention: Arc<RwLock<Option<Arc<LoopPrevention>>>>,
    /// Runtime on which bridge tasks are spawned; when `None`, the runtime that is
    /// current at the time a bridge is added is used.
    runtime_handle: Option<tokio::runtime::Handle>,
}
impl BridgeManager {
#[allow(clippy::must_use_candidate)]
pub fn new(router: Arc<MessageRouter>) -> Self {
Self {
bridges: Arc::new(RwLock::new(HashMap::new())),
tasks: Arc::new(Mutex::new(HashMap::new())),
router,
loop_prevention: Arc::new(RwLock::new(None)),
runtime_handle: None,
}
}
#[allow(clippy::must_use_candidate)]
pub fn with_runtime(router: Arc<MessageRouter>, handle: tokio::runtime::Handle) -> Self {
Self {
bridges: Arc::new(RwLock::new(HashMap::new())),
tasks: Arc::new(Mutex::new(HashMap::new())),
router,
loop_prevention: Arc::new(RwLock::new(None)),
runtime_handle: Some(handle),
}
}
pub fn set_loop_prevention(&self, ttl: crate::time::Duration, cache_size: usize) {
let lp = Arc::new(LoopPrevention::new(ttl, cache_size));
*self.loop_prevention.write() = Some(lp);
info!(
ttl_secs = ttl.as_secs(),
cache_size = cache_size,
"Loop prevention configured at manager level"
);
}
fn get_or_init_loop_prevention(&self, config: &BridgeConfig) -> Arc<LoopPrevention> {
let mut guard = self.loop_prevention.write();
if let Some(ref lp) = *guard {
let existing_ttl = lp.ttl();
let existing_cache_size = lp.max_cache_size();
if config.loop_prevention_ttl != existing_ttl
|| config.loop_prevention_cache_size != existing_cache_size
{
tracing::warn!(
bridge = %config.name,
bridge_ttl_secs = config.loop_prevention_ttl.as_secs(),
bridge_cache_size = config.loop_prevention_cache_size,
active_ttl_secs = existing_ttl.as_secs(),
active_cache_size = existing_cache_size,
"Bridge has different loop prevention settings than active config; using active config"
);
}
return lp.clone();
}
let lp = Arc::new(LoopPrevention::new(
config.loop_prevention_ttl,
config.loop_prevention_cache_size,
));
*guard = Some(lp.clone());
info!(
ttl_secs = config.loop_prevention_ttl.as_secs(),
cache_size = config.loop_prevention_cache_size,
"Loop prevention initialized from bridge config"
);
lp
}
pub fn add_bridge(&self, config: BridgeConfig) -> Result<()> {
let name = config.name.clone();
if self.bridges.read().contains_key(&name) {
return Err(BridgeError::ConfigurationError(format!(
"Bridge '{name}' already exists"
)));
}
self.get_or_init_loop_prevention(&config);
let bridge = Arc::new(BridgeConnection::new(config, self.router.clone())?);
bridge.start();
let bridge_clone = bridge.clone();
let task = if let Some(ref handle) = self.runtime_handle {
handle.spawn(async move {
if let Err(e) = Box::pin(bridge_clone.run()).await {
error!("Bridge task error: {e}");
}
})
} else {
tokio::spawn(async move {
if let Err(e) = Box::pin(bridge_clone.run()).await {
error!("Bridge task error: {e}");
}
})
};
let task_name = name.clone();
self.bridges.write().insert(name, bridge);
self.tasks.lock().insert(task_name, task);
Ok(())
}
pub async fn remove_bridge(&self, name: &str) -> Result<()> {
let bridge = self.bridges.write().remove(name);
if let Some(bridge) = bridge {
bridge.stop().await?;
if let Some(task) = self.tasks.lock().remove(name) {
task.abort();
}
info!("Removed bridge '{}'", name);
Ok(())
} else {
Err(BridgeError::ConfigurationError(format!(
"Bridge '{name}' not found"
)))
}
}
pub async fn handle_outgoing(&self, packet: &PublishPacket) -> Result<()> {
debug!(
topic = %packet.topic_name,
payload_len = packet.payload.len(),
"handle_outgoing called"
);
if packet.topic_name.starts_with("$SYS/") {
debug!(topic = %packet.topic_name, "skipping $SYS topic");
return Ok(());
}
let loop_prevention = self.loop_prevention.read().clone();
if let Some(lp) = loop_prevention {
if !lp.check_message(packet).await {
debug!(
topic = %packet.topic_name,
"Message loop detected, not forwarding to bridges"
);
return Ok(());
}
}
let bridge_list: Vec<_> = {
let bridges = self.bridges.read();
bridges
.iter()
.map(|(name, bridge)| (name.clone(), bridge.clone()))
.collect()
};
if bridge_list.is_empty() {
debug!(topic = %packet.topic_name, "no bridges configured");
return Ok(());
}
debug!(
topic = %packet.topic_name,
bridge_count = bridge_list.len(),
"forwarding to bridges"
);
for (name, bridge) in bridge_list {
debug!(bridge = %name, topic = %packet.topic_name, "calling forward_message");
if let Err(e) = bridge.forward_message(packet).await {
error!("Bridge '{}' failed to forward message: {}", name, e);
}
}
Ok(())
}
pub async fn get_all_stats(&self) -> HashMap<String, BridgeStats> {
let bridge_list: Vec<_> = {
let bridges = self.bridges.read();
bridges
.iter()
.map(|(name, bridge)| (name.clone(), bridge.clone()))
.collect()
};
let mut stats = HashMap::new();
for (name, bridge) in bridge_list {
stats.insert(name, bridge.get_stats().await);
}
stats
}
pub async fn get_bridge_stats(&self, name: &str) -> Option<BridgeStats> {
let bridge = {
let bridges = self.bridges.read();
bridges.get(name).cloned()
};
if let Some(bridge) = bridge {
Some(bridge.get_stats().await)
} else {
None
}
}
#[must_use]
pub fn list_bridges(&self) -> Vec<String> {
self.bridges.read().keys().cloned().collect()
}
pub async fn stop_all(&self) -> Result<()> {
info!("Stopping all bridges");
let bridges: Vec<_> = self.bridges.read().values().cloned().collect();
for bridge in bridges {
if let Err(e) = bridge.stop().await {
error!("Failed to stop bridge: {e}");
}
}
let mut tasks = self.tasks.lock();
for (name, task) in tasks.drain() {
debug!("Cancelling task for bridge '{}'", name);
task.abort();
}
self.bridges.write().clear();
Ok(())
}
pub async fn reload_bridge(&self, config: BridgeConfig) -> Result<()> {
let name = config.name.clone();
if self.bridges.read().contains_key(&name) {
self.remove_bridge(&name).await?;
}
self.add_bridge(config)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::broker::bridge::BridgeDirection;
use crate::QoS;
/// Lifecycle of a single bridge: start an in-process broker on an ephemeral port,
/// register a bridge pointing at it, verify it is listed, verify a duplicate name is
/// rejected, then remove it and verify the list is empty.
#[tokio::test]
async fn test_bridge_manager_lifecycle() {
use crate::broker::config::{BrokerConfig, StorageBackend, StorageConfig};
use crate::broker::server::MqttBroker;
let router = Arc::new(MessageRouter::new());
let manager = BridgeManager::new(router);
// In-memory storage backend for the target broker.
let storage_config = StorageConfig {
backend: StorageBackend::Memory,
enable_persistence: true,
..Default::default()
};
// Port 0 lets the OS pick a free port; the real address is read back below.
let config = BrokerConfig::default()
.with_bind_address("127.0.0.1:0".parse::<std::net::SocketAddr>().unwrap())
.with_storage(storage_config);
let mut broker = MqttBroker::with_config(config)
.await
.expect("Failed to create broker");
let broker_addr = broker.local_addr().expect("Failed to get broker address");
let broker_handle = tokio::spawn(async move { broker.run().await });
// Presumably gives the broker task time to start running; timing-based — TODO confirm it is needed.
tokio::time::sleep(crate::time::Duration::from_millis(100)).await;
let config = BridgeConfig::new("test-bridge", format!("{broker_addr}")).add_topic(
"test/#",
BridgeDirection::Both,
QoS::AtMostOnce,
);
assert!(manager.add_bridge(config.clone()).is_ok());
let bridges = manager.list_bridges();
assert_eq!(bridges.len(), 1);
assert!(bridges.contains(&"test-bridge".to_string()));
// Re-adding the same name must be rejected.
assert!(manager.add_bridge(config).is_err());
// The broker is aborted before removal; removal is still expected to succeed.
broker_handle.abort();
assert!(manager.remove_bridge("test-bridge").await.is_ok());
let bridges = manager.list_bridges();
assert_eq!(bridges.len(), 0);
}
}