use crossbeam_channel::{Receiver, Sender, bounded};
use lazy_static::lazy_static;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use uuid::Uuid;
lazy_static! {
/// Lazily initialised shared `PubSub`; `PubSub::instance` hands out clones of this `Arc`.
static ref PUBSUB: Arc<PubSub> = Arc::new(PubSub::new());
}
/// Per-subscription queue settings.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TopicConfig {
    /// Capacity of the subscriber's bounded queue (passed to `crossbeam_channel::bounded`;
    /// `0` yields a zero-capacity channel).
    queue_depth: usize,
    /// When `true`, publishers evict the oldest queued messages to make room, instead of
    /// waiting for space or dropping the new message.
    overwrite: bool,
}

impl TopicConfig {
    /// Creates a configuration with the given queue capacity and overflow policy.
    #[must_use]
    pub fn new(queue_depth: usize, overwrite: bool) -> Self {
        Self {
            queue_depth,
            overwrite,
        }
    }
}
/// Type-erased message payload; one wrapper is cloned out to every subscriber of a publish.
#[derive(Clone)]
struct MessageWrapper {
    /// The published value. Cloning the wrapper only bumps this reference count.
    data: Arc<dyn Any + Sync + Send>,
}
/// Per-subscriber delivery channel registered under a topic.
#[derive(Clone)]
struct ChannelPair {
    /// Publishing end of the subscriber's bounded queue.
    sender: Sender<MessageWrapper>,
    /// Receiving end; publishers use it to evict old messages when `config.overwrite` is set.
    receiver: Receiver<MessageWrapper>,
    /// Queue settings requested by the subscriber.
    config: TopicConfig,
    /// Id that `unsubscribe` matches on to remove this entry.
    subscriber_id: String,
}

impl ChannelPair {
    /// Bundles the parts of one subscription's channel.
    fn new(
        sender: Sender<MessageWrapper>,
        receiver: Receiver<MessageWrapper>,
        config: TopicConfig,
        subscriber_id: String,
    ) -> Self {
        Self {
            subscriber_id,
            config,
            receiver,
            sender,
        }
    }
}
/// A registered topic and the delivery channels of its current subscribers.
struct TopicData {
    /// One entry per live subscription; publishers send through a cloned snapshot of this list.
    channel_pairs: Vec<ChannelPair>,
    // Stored at registration but never read in this file, hence the lint allowance.
    #[allow(dead_code)]
    name: String,
}
/// Bookkeeping for one subscription, keyed by subscriber id.
struct SubscriberData {
    // Stored for the lifetime of the subscription but never read in this file.
    #[allow(dead_code)]
    callback: Option<Arc<dyn Fn(&dyn Any) + Sync + Send>>,
    // Extra handle on the subscriber's queue; never read in this file.
    #[allow(dead_code)]
    receiver: Receiver<MessageWrapper>,
    /// Topic name; `unsubscribe` uses it to find the topic's channel list.
    topic: String,
}
/// Pull-based subscription handle returned by `PubSub::subscribe_manual`.
///
/// `T` is the type that received messages are downcast to; it exists only at the type level.
/// Clones share one underlying queue, so each message is taken by only one of them.
#[derive(Clone)]
pub struct ManualReceiver<T: 'static> {
    /// Id under which this subscription is registered.
    subscriber_id: String,
    /// Receiving end of this subscriber's bounded queue.
    receiver: Receiver<MessageWrapper>,
    /// Registry used by `unsubscribe`.
    pubsub: Arc<PubSub>,
    _marker: std::marker::PhantomData<T>,
}
impl<T: Clone + 'static> ManualReceiver<T> {
pub fn try_recv(&self) -> Option<T> {
let msg = self.receiver.try_recv().ok();
match msg {
Some(msg) => {
if let Some(data) = msg.downcast::<T>() {
return Some(data.to_owned());
}
None
}
None => None,
}
}
pub fn recv(&self) -> Option<T> {
self.recv_timeout(None)
}
pub fn recv_timeout(&self, timeout_ms: Option<u64>) -> Option<T> {
let msg = match timeout_ms {
Some(ms) => self.receiver.recv_timeout(Duration::from_millis(ms)).ok(),
None => self.receiver.recv().ok(),
};
match msg {
Some(msg) => {
if let Some(data) = msg.downcast::<T>() {
return Some(data.to_owned());
}
None
}
None => None,
}
}
pub fn unsubscribe(self) {
self.pubsub.unsubscribe(&self.subscriber_id);
}
}
impl MessageWrapper {
fn new<T: Send + Sync + Clone + 'static>(data: T) -> Self {
MessageWrapper {
data: Arc::new(data),
}
}
fn downcast<T: 'static>(&self) -> Option<&T> {
self.data.downcast_ref::<T>()
}
}
/// Publish/subscribe registry; the shared instance is obtained through `PubSub::instance`.
///
/// `create_publisher` locks `topic_map` before `topics`. Any code that holds both locks at
/// the same time must use that same order, or it can deadlock against it.
pub struct PubSub {
    /// Topic name -> index into `topics`.
    topic_map: Mutex<HashMap<String, usize>>,
    /// Registered topics, indexed by the id returned from `create_publisher`.
    topics: Mutex<Vec<TopicData>>,
    /// Subscriber id -> subscription bookkeeping.
    subscribers: Mutex<HashMap<String, SubscriberData>>,
}
impl PubSub {
fn new() -> Self {
PubSub {
topics: Mutex::new(Vec::new()),
topic_map: Mutex::new(HashMap::new()),
subscribers: Mutex::new(HashMap::new()),
}
}
pub fn instance() -> Arc<PubSub> {
PUBSUB.clone()
}
pub fn create_publisher(&self, topic: &str) -> usize {
let mut topic_map = self.topic_map.lock().unwrap();
if let Some(&index) = topic_map.get(topic) {
return index;
}
let mut topics = self.topics.lock().unwrap();
let new_index = topics.len();
topics.push(TopicData {
name: topic.to_string(),
channel_pairs: Vec::new(),
});
topic_map.insert(topic.to_string(), new_index);
new_index
}
pub fn subscribe_manual<T: Send + Sync + Clone + 'static>(
&self,
topic: &str,
config: TopicConfig,
) -> ManualReceiver<T>
where
T: 'static,
{
let subscriber_id = Uuid::new_v4().to_string();
let (tx, rx) = bounded(config.queue_depth);
let topic_str = topic.to_string();
let topic_index = self.create_publisher(topic);
{
let mut topics = self.topics.lock().unwrap();
topics[topic_index].channel_pairs.push(ChannelPair::new(
tx,
rx.clone(),
config.clone(),
subscriber_id.clone(),
));
}
{
self.subscribers.lock().unwrap().insert(
subscriber_id.clone(),
SubscriberData {
topic: topic_str.clone(),
receiver: rx.clone(),
callback: None,
},
);
}
ManualReceiver {
receiver: rx,
subscriber_id,
pubsub: PubSub::instance(),
_marker: std::marker::PhantomData,
}
}
pub fn subscribe<T, F>(&self, topic: &str, config: TopicConfig, callback: F) -> String
where
T: Send + Sync + Clone + 'static,
F: Fn(&T) + Send + Sync + 'static,
{
let subscriber_id = Uuid::new_v4().to_string();
let (tx, rx) = bounded(config.queue_depth);
let topic_str = topic.to_string();
let topic_index = self.create_publisher(topic);
{
let mut topics = self.topics.lock().unwrap();
topics[topic_index].channel_pairs.push(ChannelPair::new(
tx,
rx.clone(),
config.clone(),
subscriber_id.clone(),
));
}
let callback_wrapper: Arc<dyn Fn(&dyn Any) + Send + Sync> =
Arc::new(move |data: &dyn Any| {
if let Some(t) = data.downcast_ref::<T>() {
callback(t);
}
});
{
self.subscribers.lock().unwrap().insert(
subscriber_id.clone(),
SubscriberData {
topic: topic_str.clone(),
receiver: rx.clone(),
callback: Some(callback_wrapper.clone()),
},
);
}
let rx_clone = rx.clone();
let callback_for_thread = callback_wrapper.clone();
std::thread::spawn(move || {
while let Ok(msg) = rx_clone.recv() {
if let Some(data) = msg.downcast::<T>() {
callback_for_thread(data);
}
}
});
subscriber_id
}
pub fn try_publish<T: Send + Sync + Clone + 'static>(&self, topic_id: usize, message: T) {
let msg = MessageWrapper::new(message);
let channel_pairs = {
let topics = self.topics.lock().unwrap();
if topic_id >= topics.len() {
return;
}
if topics[topic_id].channel_pairs.is_empty() {
return;
}
topics[topic_id].channel_pairs.clone()
};
for pair in channel_pairs.iter() {
if pair.config.overwrite {
while pair.sender.is_full() {
let _ = pair.receiver.try_recv();
}
}
let _ = pair.sender.try_send(msg.clone());
}
}
pub fn publish<T: Send + Sync + Clone + 'static>(&self, topic_id: usize, message: T) {
self.publish_with_timeout(topic_id, message, None);
}
pub fn publish_with_timeout<T: Send + Sync + Clone + 'static>(
&self,
topic_id: usize,
message: T,
max_wait_ms: Option<u64>,
) {
let msg = MessageWrapper::new(message);
let channel_pairs = {
let topics = self.topics.lock().unwrap();
if topic_id >= topics.len() {
return;
}
if topics[topic_id].channel_pairs.is_empty() {
return;
}
topics[topic_id].channel_pairs.clone()
};
for pair in channel_pairs.iter() {
if pair.config.overwrite {
while pair.sender.is_full() {
let _ = pair.receiver.try_recv();
}
let _ = pair.sender.try_send(msg.clone());
} else {
match max_wait_ms {
Some(ms) => {
let _ = pair
.sender
.send_timeout(msg.clone(), Duration::from_millis(ms));
}
None => {
let _ = pair.sender.send(msg.clone());
}
}
}
}
}
pub fn unsubscribe(&self, subscriber_id: &str) {
let topic_opt = {
let mut subscribers = self.subscribers.lock().unwrap();
if let Some(data) = subscribers.remove(subscriber_id) {
Some(data.topic)
} else {
None
}
};
if let Some(topic) = topic_opt {
let topic_index_opt = {
let topic_map = self.topic_map.lock().unwrap();
topic_map.get(&topic).cloned()
};
if let Some(topic_index) = topic_index_opt {
let mut topics = self.topics.lock().unwrap();
if let Some(topic_data) = topics.get_mut(topic_index) {
topic_data
.channel_pairs
.retain(|pair| pair.subscriber_id != subscriber_id);
if topic_data.channel_pairs.is_empty() {
let mut topic_map = self.topic_map.lock().unwrap();
topic_map.remove(&topic);
}
}
}
}
}
}