use super::service_base::{Service, ServiceMetrics};
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tracing::debug;
/// Envelope exchanged between services: addressing metadata plus a typed `payload`.
///
/// NOTE(review): the derived `Serialize` emits fields in declaration order, so the field
/// order below is part of the serialized form — keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceMessage<T> {
    /// Message identifier; the constructors in this module fill it with a random UUID v4 string.
    pub id: String,
    /// Creation time; the constructors in this module set it to `SystemTime::now()`.
    pub timestamp: std::time::SystemTime,
    /// Sender identifier; `reply` uses it as the destination of the response.
    pub source: String,
    /// Target identifier; `RouterService` looks handlers up by this exact string.
    pub destination: String,
    /// Application payload carried by the message.
    pub payload: T,
    /// `id` of the message this one answers; set by `reply`, `None` when built with `new`.
    pub correlation_id: Option<String>,
    /// Optional return address. NOTE(review): the visible code never sets or reads a
    /// non-`None` value here — confirm intended use.
    pub reply_to: Option<String>,
}
impl<T> ServiceMessage<T> {
pub fn new(source: String, destination: String, payload: T) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
timestamp: std::time::SystemTime::now(),
source,
destination,
payload,
correlation_id: None,
reply_to: None,
}
}
pub fn reply<R>(&self, payload: R) -> ServiceMessage<R> {
ServiceMessage {
id: uuid::Uuid::new_v4().to_string(),
timestamp: std::time::SystemTime::now(),
source: self.destination.clone(),
destination: self.source.clone(),
payload,
correlation_id: Some(self.id.clone()),
reply_to: None,
}
}
}
/// In-process, topic-based publish/subscribe hub.
///
/// Each call to `subscribe` creates its own `tokio::sync::broadcast` channel for the
/// topic, and `publish` clones the message into every sender registered for that topic.
///
/// Trait bounds live on the `impl` blocks rather than on the struct definition, so the
/// type can be named without repeating them.
pub struct PubSubService<T> {
    /// Topic name -> one broadcast sender per subscription to that topic.
    subscribers: Arc<RwLock<HashMap<String, Vec<broadcast::Sender<T>>>>>,
    /// Counters incremented by `publish`.
    metrics: Arc<RwLock<ServiceMetrics>>,
}
impl<T: Clone + Send> Default for PubSubService<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Clone + Send> PubSubService<T> {
    /// Buffer size of each per-subscription broadcast channel.
    ///
    /// When a receiver falls this many messages behind, tokio's broadcast channel drops the
    /// oldest values and that receiver observes a `Lagged` error on its next `recv`.
    const CHANNEL_CAPACITY: usize = 100;

    /// Creates a hub with no topics and `ServiceMetrics::default()` metrics.
    #[must_use]
    pub fn new() -> Self {
        Self {
            subscribers: Arc::new(RwLock::new(HashMap::new())),
            metrics: Arc::new(RwLock::new(ServiceMetrics::default())),
        }
    }

    /// Subscribes to `topic` and returns the receiving end of a fresh broadcast channel.
    ///
    /// Senders whose receivers have already been dropped are pruned from the topic here.
    /// Without this, the per-topic list would grow for as long as subscribers come and go.
    pub async fn subscribe(&self, topic: String) -> broadcast::Receiver<T> {
        let (tx, rx) = broadcast::channel(Self::CHANNEL_CAPACITY);
        let mut subs = self.subscribers.write().await;
        let senders = subs.entry(topic).or_default();
        senders.retain(|sender| sender.receiver_count() > 0);
        senders.push(tx);
        rx
    }

    /// Sends a clone of `message` to every subscription registered for `topic`.
    ///
    /// The request and success counters are incremented only when `topic` has ever been
    /// subscribed to; publishing to an unknown topic is a silent no-op.
    ///
    /// # Errors
    ///
    /// Currently never fails; always returns `Ok(())`.
    pub async fn publish(&self, topic: String, message: T) -> Result<()> {
        // Scope the read guard so it is released before the metrics lock is taken.
        let topic_known = {
            let subs = self.subscribers.read().await;
            match subs.get(&topic) {
                Some(senders) => {
                    for tx in senders {
                        // A send error only means this subscription's receiver is gone;
                        // the remaining subscribers must still get the message.
                        let _ = tx.send(message.clone());
                    }
                    true
                }
                None => false,
            }
        };
        if topic_known {
            let mut metrics = self.metrics.write().await;
            metrics.request_count += 1;
            metrics.success_count += 1;
        }
        Ok(())
    }

    /// Returns the number of live subscriptions to `topic`.
    ///
    /// Only subscriptions whose receivers are still alive are counted; 0 is returned for
    /// unknown topics.
    pub async fn subscriber_count(&self, topic: &str) -> usize {
        let subs = self.subscribers.read().await;
        subs.get(topic).map_or(0, |senders| {
            senders
                .iter()
                .filter(|sender| sender.receiver_count() > 0)
                .count()
        })
    }
}
/// Shared, type-erased handler for byte-payload messages, as stored by `RouterService`.
type MessageService = Arc<
    dyn Service<
            Input = ServiceMessage<Vec<u8>>,
            Output = ServiceMessage<Vec<u8>>,
            Error = anyhow::Error,
        > + Sync
        + Send,
>;
/// Route table: exact destination string -> handler, behind a tokio `RwLock`.
type RouteMap = Arc<RwLock<HashMap<String, MessageService>>>;
/// Dispatches `ServiceMessage<Vec<u8>>` values to handlers keyed by the message's
/// `destination`, falling back to an optional default handler.
pub struct RouterService {
    /// Handlers keyed by exact destination string.
    routes: RouteMap,
    /// Handler used when no entry in `routes` matches; when `None`, unmatched messages
    /// produce an error.
    default_handler: Option<MessageService>,
    /// NOTE(review): initialised to `ServiceMetrics::default()` and never updated by the
    /// visible code, so `metrics()` always reports defaults — confirm whether routing
    /// should record here.
    metrics: ServiceMetrics,
}
impl Default for RouterService {
fn default() -> Self {
Self::new()
}
}
impl RouterService {
    /// Creates a router with an empty route table, no default handler and
    /// `ServiceMetrics::default()` metrics.
    #[must_use]
    pub fn new() -> Self {
        Self {
            routes: Arc::new(RwLock::new(HashMap::new())),
            default_handler: None,
            metrics: ServiceMetrics::default(),
        }
    }

    /// Registers `handler` for messages whose `destination` equals `pattern`.
    ///
    /// Matching is exact string equality, because the table is a `HashMap` keyed by
    /// destination. No wildcard or prefix matching is performed, despite the parameter name.
    /// Registering a key that is already present replaces the previous handler.
    pub async fn add_route<S>(&mut self, pattern: String, handler: S)
    where
        S: Service<
                Input = ServiceMessage<Vec<u8>>,
                Output = ServiceMessage<Vec<u8>>,
                Error = anyhow::Error,
            > + Send
            + Sync
            + 'static,
    {
        let mut routes = self.routes.write().await;
        routes.insert(pattern, Arc::new(handler));
    }

    /// Sets (or replaces) the handler used when no route matches a message's destination.
    pub fn set_default<S>(&mut self, handler: S)
    where
        S: Service<
                Input = ServiceMessage<Vec<u8>>,
                Output = ServiceMessage<Vec<u8>>,
                Error = anyhow::Error,
            > + Send
            + Sync
            + 'static,
    {
        self.default_handler = Some(Arc::new(handler));
    }

    /// Dispatches `message` to the handler registered for its `destination`, falling back
    /// to the default handler.
    ///
    /// # Errors
    ///
    /// Returns the selected handler's error. Returns `No route found for <destination>`
    /// when no route matches and no default handler is set.
    pub async fn route(&self, message: ServiceMessage<Vec<u8>>) -> Result<ServiceMessage<Vec<u8>>> {
        // Clone the handler's `Arc` out of the table so the read guard (a temporary) is
        // dropped at the end of this statement. The route table is then not kept locked
        // while arbitrary handler code, routed or default, is awaited.
        let matched = self.routes.read().await.get(&message.destination).cloned();
        if let Some(handler) = matched {
            debug!("Routing to {}", message.destination);
            return handler.process(message).await;
        }
        if let Some(default) = &self.default_handler {
            debug!("Using default handler for {}", message.destination);
            return default.process(message).await;
        }
        Err(anyhow::anyhow!(
            "No route found for {}",
            message.destination
        ))
    }

    /// Returns the router's metrics.
    ///
    /// NOTE(review): nothing in this impl updates `self.metrics`, so this currently always
    /// reflects `ServiceMetrics::default()` — confirm whether routing should record counts.
    #[must_use]
    pub fn metrics(&self) -> &ServiceMetrics {
        &self.metrics
    }
}
#[async_trait]
impl Service for RouterService {
type Input = ServiceMessage<Vec<u8>>;
type Output = ServiceMessage<Vec<u8>>;
type Error = anyhow::Error;
async fn process(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
self.route(input).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores endpoints and payload verbatim; `reply` swaps the endpoints and
    /// correlates back to the original message's id.
    #[tokio::test]
    async fn test_service_message() {
        let (from, to) = ("service-a".to_string(), "service-b".to_string());
        let original = ServiceMessage::new(from, to, "Hello".to_string());

        assert_eq!(original.source, "service-a");
        assert_eq!(original.destination, "service-b");
        assert_eq!(original.payload, "Hello");

        let response = original.reply("World".to_string());
        assert_eq!(response.source, "service-b");
        assert_eq!(response.destination, "service-a");
        assert_eq!(response.correlation_id, Some(original.id));
    }

    /// Every subscriber of a topic receives each published message.
    #[tokio::test]
    async fn test_pub_sub() {
        let topic = "topic1";
        let hub = PubSubService::<String>::new();
        let mut first = hub.subscribe(topic.to_string()).await;
        let mut second = hub.subscribe(topic.to_string()).await;

        hub.publish(topic.to_string(), "Message 1".to_string())
            .await
            .unwrap();

        let got_first = first.recv().await.unwrap();
        let got_second = second.recv().await.unwrap();
        assert_eq!(got_first, "Message 1");
        assert_eq!(got_second, "Message 1");
        assert_eq!(hub.subscriber_count(topic).await, 2);
    }
}
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // `new` stores its arguments verbatim and leaves both link fields empty.
        #[test]
        fn new_message_preserves_fields(
            source in ".*",
            destination in ".*",
            payload in any::<u32>(),
        ) {
            let msg = ServiceMessage::new(source.clone(), destination.clone(), payload);
            prop_assert_eq!(msg.source, source);
            prop_assert_eq!(msg.destination, destination);
            prop_assert_eq!(msg.payload, payload);
            prop_assert!(msg.correlation_id.is_none());
            prop_assert!(msg.reply_to.is_none());
        }

        // `reply` swaps the endpoints and correlates to the original message's id,
        // whatever the endpoint strings are.
        #[test]
        fn reply_swaps_endpoints_and_correlates(
            source in ".*",
            destination in ".*",
            payload in ".*",
        ) {
            let msg = ServiceMessage::new(source.clone(), destination.clone(), payload);
            let response = msg.reply(());
            prop_assert_eq!(&response.source, &destination);
            prop_assert_eq!(&response.destination, &source);
            prop_assert_eq!(response.correlation_id.as_deref(), Some(msg.id.as_str()));
            prop_assert!(response.reply_to.is_none());
        }
    }
}