use async_trait::async_trait;
use futures::StreamExt;
use lapin::BasicProperties;
use lapin::{options::*, types::FieldTable, Connection, ConnectionProperties, Consumer};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::exchange::Exchange;
use crate::rabbit_result::RabbitResult;
/// An asynchronous message handler that can be registered on a [`RabbitRouter`].
///
/// The router calls `handle` with the raw body of a delivery and a handle to the
/// router's shared state. Returning `Some(result)` asks the router to publish
/// `result` (serialized as JSON) to the route's result exchange, if one was
/// configured; returning `None` publishes nothing.
///
/// Plain `async fn`s and closures with the matching signature implement this
/// trait through the blanket implementation that follows.
#[async_trait]
pub trait AsyncHandler<S>: Send + Sync {
    /// Processes one message body; see the trait docs for how the return value is used.
    async fn handle(&self, data: Vec<u8>, state: Arc<Mutex<S>>) -> Option<RabbitResult>;
}
/// Blanket implementation: any `Send + Sync` function or closure shaped like
/// `Fn(Vec<u8>, Arc<Mutex<S>>) -> impl Future<Output = Option<RabbitResult>>`
/// (for example a plain `async fn`) can be registered directly as a handler.
#[async_trait]
impl<S, Func, Ret> AsyncHandler<S> for Func
where
    S: Send + Sync + 'static,
    Func: Fn(Vec<u8>, Arc<Mutex<S>>) -> Ret + Send + Sync,
    Ret: std::future::Future<Output = Option<RabbitResult>> + Send + 'static,
{
    async fn handle(&self, data: Vec<u8>, state: Arc<Mutex<S>>) -> Option<RabbitResult> {
        // Invoke the wrapped callable to obtain its future, then drive it to completion.
        let pending = (self)(data, state);
        pending.await
    }
}
/// Routes RabbitMQ deliveries to registered [`AsyncHandler`]s.
///
/// All routes share one connection, one channel and one piece of mutable state.
pub struct RabbitRouter<S> {
    /// The AMQP connection opened when the router is constructed.
    pub connection: Arc<Connection>,
    /// The single channel used for declaring, consuming, publishing and acking.
    pub channel: Arc<lapin::Channel>,
    /// Handlers keyed by the name of the queue they consume from. Registering a
    /// handler for a queue name that is already present overwrites the stored
    /// handler.
    pub routes: Arc<Mutex<HashMap<String, Arc<dyn AsyncHandler<S>>>>>,
    /// State handed to every handler invocation.
    pub state: Arc<Mutex<S>>,
}
impl<S> RabbitRouter<S>
where
S: Send + Sync + 'static,
{
pub async fn new(uri: &str, initial_state: S) -> Self {
let connection = Connection::connect(uri, ConnectionProperties::default())
.await
.expect("Failed to connect to RabbitMQ");
let channel = connection
.create_channel()
.await
.expect("Failed to create channel");
RabbitRouter {
connection: Arc::new(connection),
channel: Arc::new(channel),
routes: Arc::new(Mutex::new(HashMap::new())),
state: Arc::new(Mutex::new(initial_state)),
}
}
pub async fn add_route_exchange<H>(
&self,
exchange: &str,
routing_key: &str,
result_exchange: Option<Exchange>,
handler: H,
) -> Result<(), Box<dyn std::error::Error>>
where
H: AsyncHandler<S> + 'static,
{
let exchange = exchange.to_string();
let routing_key = routing_key.to_string();
let handler = Arc::new(handler);
let channel = self.channel.clone();
let mut exchange_options = ExchangeDeclareOptions::default();
exchange_options.durable = true;
channel
.exchange_declare(
&exchange,
lapin::ExchangeKind::Topic,
exchange_options,
FieldTable::default(),
)
.await?;
let mut options = QueueDeclareOptions::default();
options.exclusive = true;
options.durable = true;
let result = channel
.queue_declare("", options, FieldTable::default())
.await?;
let queue_name = result.name().as_str().to_string();
let mut routes = self.routes.lock().await;
routes.insert(queue_name.clone(), handler);
let routes = self.routes.clone();
let state = self.state.clone();
channel
.queue_bind(
&queue_name,
&exchange,
&routing_key,
QueueBindOptions::default(),
FieldTable::default(),
)
.await?;
tokio::spawn(async move {
let consumer = channel
.basic_consume(
&queue_name,
"",
BasicConsumeOptions::default(),
FieldTable::default(),
)
.await
.expect("Failed to start consumer");
Self::consume_messages(
queue_name,
consumer,
channel,
result_exchange,
routes,
state,
)
.await;
});
Ok(())
}
pub async fn add_route_queue<H>(
&self,
queue_name: &str,
result_exchange: Option<Exchange>,
handler: H,
) -> Result<(), Box<dyn std::error::Error>>
where
H: AsyncHandler<S> + 'static,
{
let queue_name = queue_name.to_string();
let handler = Arc::new(handler);
let mut routes = self.routes.lock().await;
routes.insert(queue_name.clone(), handler);
let routes = self.routes.clone();
let state = self.state.clone();
let channel = self.channel.clone();
let _queue = channel
.queue_declare(
&queue_name,
QueueDeclareOptions::default(),
FieldTable::default(),
)
.await
.expect("Failed to declare queue");
tokio::spawn(async move {
let consumer = channel
.basic_consume(
&queue_name,
"",
BasicConsumeOptions::default(),
FieldTable::default(),
)
.await
.expect("Failed to start consumer");
Self::consume_messages(
queue_name,
consumer,
channel,
result_exchange,
routes,
state,
)
.await;
});
Ok(())
}
async fn consume_messages(
queue_name: String,
mut consumer: Consumer,
channel: Arc<lapin::Channel>,
result_exchange: Option<Exchange>,
routes: Arc<Mutex<HashMap<String, Arc<dyn AsyncHandler<S>>>>>,
state: Arc<Mutex<S>>,
) {
while let Some(delivery) = consumer.next().await {
match delivery {
Ok(delivery) => {
let data = delivery.data.clone();
let routes = routes.lock().await;
if let Some(handler) = routes.get(&queue_name) {
let result = handler.handle(data, state.clone()).await;
if let Some(result) = result {
if let Some(result_exchange) = &result_exchange {
let full_key=result.logging_level.to_string()+"."+&result.billing_type.to_string()+"."+result_exchange.routing_key.as_str();
let payload = serde_json::to_vec(&result).unwrap();
channel
.basic_publish(
&result_exchange.name,
&full_key,
BasicPublishOptions::default(),
&payload,
BasicProperties::default(),
)
.await
.expect("Failed to publish message");
}
}
}
delivery
.ack(BasicAckOptions::default())
.await
.expect("Failed to ack message");
}
Err(error) => eprintln!("Error receiving message: {:?}", error),
}
}
}
}
#[cfg(test)]
mod tests {
    // Integration tests. They need a RabbitMQ broker at 127.0.0.1:5672 and
    // wait a fixed two seconds for deliveries to be processed.
    use super::*;
    use tokio::time::{sleep, Duration};

    /// Test handler: increments `shared_counter` in the nested state map and
    /// never returns a result to publish.
    async fn test_handler(
        data: Vec<u8>,
        state: Arc<Mutex<Arc<Mutex<HashMap<String, i32>>>>>,
    ) -> Option<RabbitResult> {
        // The router's state is itself an `Arc<Mutex<HashMap>>`, hence two locks.
        let state = state.lock().await;
        let mut inner_state = state.lock().await;
        let count = inner_state.entry("shared_counter".to_string()).or_insert(0);
        *count += 1;
        println!(
            "Handler for queue_1 received: {:?}, updated shared_counter: {}",
            String::from_utf8_lossy(&data),
            *count
        );
        None
    }

    /// Returns the counter maintained by `test_handler`, or `None` if it never ran.
    async fn shared_counter(state: &Mutex<HashMap<String, i32>>) -> Option<i32> {
        state.lock().await.get("shared_counter").copied()
    }

    #[tokio::test]
    async fn test_rabbit_router() {
        let initial_state = Arc::new(Mutex::new(HashMap::new()));
        let router = RabbitRouter::new("amqp://127.0.0.1:5672", initial_state.clone()).await;
        router
            .add_route_queue("test_queue", None, test_handler)
            .await
            .expect("Failed to add route for test_queue");

        router
            .channel
            .basic_publish(
                "",
                "test_queue",
                BasicPublishOptions::default(),
                b"Test message",
                BasicProperties::default(),
            )
            .await
            .expect("Failed to publish message");

        sleep(Duration::from_secs(2)).await;
        println!("State: {:?}", initial_state.lock().await);
        assert_eq!(shared_counter(&initial_state).await, Some(1));
    }

    #[tokio::test]
    async fn test_multiple_routes() {
        let initial_state = Arc::new(Mutex::new(HashMap::new()));
        let router = RabbitRouter::new("amqp://127.0.0.1:5672/%2f", initial_state.clone()).await;
        router
            .add_route_queue("queue_1", None, test_handler)
            .await
            .expect("Failed to add route for queue_1");
        router
            .add_route_queue("queue_2", None, test_handler)
            .await
            .expect("Failed to add route for queue_2");

        let channel = &router.channel;
        channel
            .basic_publish(
                "",
                "queue_1",
                BasicPublishOptions::default(),
                b"Message for queue_1",
                BasicProperties::default(),
            )
            .await
            .expect("Failed to publish message to queue_1");
        channel
            .basic_publish(
                "",
                "queue_2",
                BasicPublishOptions::default(),
                b"Message for queue_2",
                BasicProperties::default(),
            )
            .await
            .expect("Failed to publish message to queue_2");

        sleep(Duration::from_secs(2)).await;
        assert_eq!(shared_counter(&initial_state).await, Some(2));
    }

    #[tokio::test]
    async fn test_exchange_route() {
        let initial_state = Arc::new(Mutex::new(HashMap::new()));
        let router = RabbitRouter::new("amqp://127.0.0.1:5672", initial_state.clone()).await;
        // `add_route_exchange` declares the exchange itself (durable, topic).
        // Pre-declaring it here with different options would make the broker
        // reject the router's declaration and close the shared channel.
        router
            .add_route_exchange("test_exchange", "test.routing.key", None, test_handler)
            .await
            .expect("Failed to add route for test_exchange");

        router
            .channel
            .basic_publish(
                "test_exchange",
                "test.routing.key",
                BasicPublishOptions::default(),
                b"Test exchange message",
                BasicProperties::default(),
            )
            .await
            .expect("Failed to publish message to exchange");
        println!("Message published to exchange");

        sleep(Duration::from_secs(2)).await;
        assert_eq!(shared_counter(&initial_state).await, Some(1));
    }
}