use crate::traits::{MessagePublisher, PublisherError, Sent, SentBatch};
use crate::CanonicalMessage;
use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::warn;
/// Publisher that forwards each message to one of several downstream
/// publishers, chosen by the value of a single metadata entry.
pub struct SwitchPublisher {
    /// Name of the metadata entry whose value selects the route.
    metadata_key: String,
    /// Downstream publishers keyed by the metadata value they handle.
    cases: HashMap<String, Arc<dyn MessagePublisher>>,
    /// Fallback publisher used when the key is missing or its value matches no case.
    default: Option<Arc<dyn MessagePublisher>>,
}
impl SwitchPublisher {
    /// Creates a switch publisher.
    ///
    /// * `metadata_key` - metadata entry whose value selects the route.
    /// * `cases` - downstream publishers keyed by the metadata value they handle.
    /// * `default` - fallback used when the key is absent or its value has no
    ///   matching case; with `None`, such messages have no destination.
    pub fn new(
        metadata_key: String,
        cases: HashMap<String, Arc<dyn MessagePublisher>>,
        default: Option<Arc<dyn MessagePublisher>>,
    ) -> Self {
        Self {
            metadata_key,
            cases,
            default,
        }
    }

    /// Resolves the destination publisher for `message`.
    ///
    /// Reads the value stored under `metadata_key` and returns the matching
    /// case; if the key is missing or no case matches, falls back to `default`
    /// (which may itself be `None`).
    fn get_publisher(&self, message: &CanonicalMessage) -> Option<&Arc<dyn MessagePublisher>> {
        message
            .metadata
            .get(&self.metadata_key)
            .and_then(|route_value| self.cases.get(route_value))
            .or(self.default.as_ref())
    }
}
#[async_trait]
impl MessagePublisher for SwitchPublisher {
async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
if let Some(publisher) = self.get_publisher(&message) {
publisher.send(message).await
} else {
warn!(
"Switch publisher dropped message with id {:032x}: metadata key '{}' not found or no matching case/default.",
message.message_id, self.metadata_key
);
Ok(Sent::Ack)
}
}
async fn send_batch(
&self,
messages: Vec<CanonicalMessage>,
) -> Result<SentBatch, PublisherError> {
use futures::future::join_all;
use std::collections::HashMap;
if messages.is_empty() {
return Ok(SentBatch::Ack);
}
let mut grouped_messages: HashMap<
String,
(Arc<dyn MessagePublisher>, Vec<CanonicalMessage>),
> = HashMap::new();
for message in messages {
if let Some(publisher) = self.get_publisher(&message) {
grouped_messages
.entry(
message
.metadata
.get(&self.metadata_key)
.cloned()
.unwrap_or_default(),
)
.or_insert_with(|| (publisher.clone(), Vec::new()))
.1
.push(message);
} else {
warn!(
"Switch publisher dropped message with id {:032x}: metadata key '{}' not found or no matching case/default.",
message.message_id, self.metadata_key
);
}
}
let batch_sends = grouped_messages
.into_values()
.map(|(publisher, batch)| async move { publisher.send_batch(batch).await });
let results = join_all(batch_sends).await;
let mut all_responses = Vec::new();
let mut all_failed = Vec::new();
for result in results {
match result {
Ok(SentBatch::Ack) => {}
Ok(SentBatch::Partial { responses, failed }) => {
if let Some(resps) = responses {
all_responses.extend(resps);
}
all_failed.extend(failed);
}
Err(e) => {
return Err(e);
}
}
}
if all_failed.is_empty() && all_responses.is_empty() {
Ok(SentBatch::Ack)
} else {
Ok(SentBatch::Partial {
responses: if all_responses.is_empty() {
None
} else {
Some(all_responses)
},
failed: all_failed,
})
}
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::endpoints::memory::MemoryPublisher;
    use std::sync::Arc;

    #[tokio::test]
    async fn test_switch_publisher_routing() {
        let pub_a = MemoryPublisher::new_local("topic_a", 10);
        let pub_b = MemoryPublisher::new_local("topic_b", 10);
        let pub_default = MemoryPublisher::new_local("topic_default", 10);
        let chan_a = pub_a.channel();
        let chan_b = pub_b.channel();
        let chan_default = pub_default.channel();
        let mut cases = HashMap::new();
        cases.insert(
            "A".to_string(),
            Arc::new(pub_a) as Arc<dyn MessagePublisher>,
        );
        cases.insert(
            "B".to_string(),
            Arc::new(pub_b) as Arc<dyn MessagePublisher>,
        );
        let switch = SwitchPublisher::new(
            "route_key".to_string(),
            cases,
            Some(Arc::new(pub_default) as Arc<dyn MessagePublisher>),
        );
        let msg_a = CanonicalMessage::from("payload_a").with_metadata_kv("route_key", "A");
        switch.send(msg_a).await.unwrap();
        assert_eq!(chan_a.len(), 1);
        assert_eq!(chan_b.len(), 0);
        assert_eq!(chan_default.len(), 0);
        chan_a.drain_messages();
        let msg_b = CanonicalMessage::from("payload_b").with_metadata_kv("route_key", "B");
        switch.send(msg_b).await.unwrap();
        assert_eq!(chan_a.len(), 0);
        assert_eq!(chan_b.len(), 1);
        assert_eq!(chan_default.len(), 0);
        chan_b.drain_messages();
        let msg_c =
            CanonicalMessage::new(b"payload_c".to_vec(), None).with_metadata_kv("route_key", "C");
        switch.send(msg_c).await.unwrap();
        assert_eq!(chan_a.len(), 0);
        assert_eq!(chan_b.len(), 0);
        assert_eq!(chan_default.len(), 1);
        chan_default.drain_messages();
        let msg_d = CanonicalMessage::new(b"payload_d".to_vec(), None);
        switch.send(msg_d).await.unwrap();
        assert_eq!(chan_a.len(), 0);
        assert_eq!(chan_b.len(), 0);
        assert_eq!(chan_default.len(), 1);
    }

    /// `send_batch` must deliver each message to its case or the default.
    /// Assertions only check empty vs. non-empty channels, so they hold
    /// whether the memory channel counts individual messages or whole batches.
    #[tokio::test]
    async fn test_switch_publisher_batch_routing() {
        let pub_a = MemoryPublisher::new_local("batch_topic_a", 10);
        let pub_b = MemoryPublisher::new_local("batch_topic_b", 10);
        let pub_default = MemoryPublisher::new_local("batch_topic_default", 10);
        let chan_a = pub_a.channel();
        let chan_b = pub_b.channel();
        let chan_default = pub_default.channel();
        let mut cases = HashMap::new();
        cases.insert(
            "A".to_string(),
            Arc::new(pub_a) as Arc<dyn MessagePublisher>,
        );
        cases.insert(
            "B".to_string(),
            Arc::new(pub_b) as Arc<dyn MessagePublisher>,
        );
        let switch = SwitchPublisher::new(
            "route_key".to_string(),
            cases,
            Some(Arc::new(pub_default) as Arc<dyn MessagePublisher>),
        );
        let batch = vec![
            CanonicalMessage::from("a1").with_metadata_kv("route_key", "A"),
            CanonicalMessage::from("a2").with_metadata_kv("route_key", "A"),
            CanonicalMessage::from("unmatched").with_metadata_kv("route_key", "C"),
            CanonicalMessage::from("no_key"),
        ];
        switch.send_batch(batch).await.unwrap();
        assert_ne!(chan_a.len(), 0);
        assert_eq!(chan_b.len(), 0);
        assert_ne!(chan_default.len(), 0);
    }

    /// With no default, a batch containing only unroutable messages is
    /// dropped and acknowledged rather than failing.
    #[tokio::test]
    async fn test_switch_publisher_batch_drops_unroutable() {
        let switch = SwitchPublisher::new("route_key".to_string(), HashMap::new(), None);
        let orphan = CanonicalMessage::from("orphan").with_metadata_kv("route_key", "Z");
        let outcome = switch.send_batch(vec![orphan]).await.unwrap();
        assert!(matches!(outcome, SentBatch::Ack));
    }
}