use crate::pubsub::PubSub;
use async_trait::async_trait;
use dashmap::DashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
/// Per-stream channel capacity used by `MemoryPubSub::new` / `Default`.
const DEFAULT_BUFFER_SIZE: usize = 1_024;

/// In-process [`PubSub`] implementation: one `tokio` broadcast channel per stream name.
pub struct MemoryPubSub {
    /// Stream name -> sending half of that stream's broadcast channel.
    /// Shared (`Arc`) so clones of the handles can observe the same channels.
    channels: Arc<DashMap<String, broadcast::Sender<Vec<u8>>>>,
    /// Names of subscribed streams; `listen` only attaches receivers for these.
    subscriptions: Arc<RwLock<HashSet<String>>>,
    /// Capacity passed to `broadcast::channel` when a stream's channel is created.
    buffer_size: usize,
}
impl MemoryPubSub {
    /// Creates a hub whose per-stream channels use `DEFAULT_BUFFER_SIZE` slots.
    pub fn new() -> Self {
        Self::with_buffer_size(DEFAULT_BUFFER_SIZE)
    }

    /// Creates a hub whose per-stream broadcast channels hold `buffer_size` messages.
    ///
    /// # Panics
    ///
    /// Panics if `buffer_size` is 0 or greater than `usize::MAX / 2`, which are the
    /// capacities `tokio::sync::broadcast::channel` rejects. Checking here reports the
    /// mistake at construction time instead of on the first call that lazily creates
    /// a channel.
    pub fn with_buffer_size(buffer_size: usize) -> Self {
        assert!(
            buffer_size > 0 && buffer_size <= usize::MAX / 2,
            "MemoryPubSub buffer_size must be in 1..=usize::MAX / 2, got {buffer_size}"
        );
        Self {
            channels: Arc::new(DashMap::new()),
            subscriptions: Arc::new(RwLock::new(HashSet::new())),
            buffer_size,
        }
    }

    /// Returns a sender for `stream`, creating its broadcast channel if it does not exist yet.
    fn get_or_create_channel(&self, stream: &str) -> broadcast::Sender<Vec<u8>> {
        // Fast path: an existing channel needs only a shard read lock and no key allocation.
        // The `Ref` guard is dropped at the end of this `if let`, before `entry` below
        // takes the shard write lock.
        if let Some(existing) = self.channels.get(stream) {
            return existing.value().clone();
        }
        // `entry` re-checks under the write lock, so a racing creator cannot produce two channels.
        self.channels
            .entry(stream.to_string())
            .or_insert_with(|| broadcast::channel(self.buffer_size).0)
            .value()
            .clone()
    }
}
impl Default for MemoryPubSub {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PubSub for MemoryPubSub {
async fn publish(&self, stream: &str, payload: &[u8]) -> anyhow::Result<()> {
let tx = self.get_or_create_channel(stream);
let _ = tx.send(payload.to_vec());
Ok(())
}
async fn subscribe(&self, stream: &str) -> anyhow::Result<()> {
self.get_or_create_channel(stream);
let mut subs = self.subscriptions.write().await;
subs.insert(stream.to_string());
Ok(())
}
async fn unsubscribe(&self, stream: &str) -> anyhow::Result<()> {
let mut subs = self.subscriptions.write().await;
subs.remove(stream);
if let Some(entry) = self.channels.get(stream)
&& entry.receiver_count() == 0
{
drop(entry);
self.channels.remove(stream);
}
Ok(())
}
async fn listen<F>(&self, callback: F) -> anyhow::Result<()>
where
F: Fn(String, Vec<u8>) + Send + Sync + 'static,
{
let callback = Arc::new(callback);
let channels = self.channels.clone();
let subscriptions = self.subscriptions.clone();
let mut receivers: std::collections::HashMap<String, broadcast::Receiver<Vec<u8>>> =
std::collections::HashMap::new();
loop {
{
let subs = subscriptions.read().await;
for stream in subs.iter() {
if !receivers.contains_key(stream)
&& let Some(tx) = channels.get(stream)
{
receivers.insert(stream.clone(), tx.subscribe());
}
}
receivers.retain(|stream, _| subs.contains(stream));
}
if receivers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
continue;
}
let mut received_any = false;
for (stream, rx) in receivers.iter_mut() {
match rx.try_recv() {
Ok(payload) => {
callback(stream.clone(), payload);
received_any = true;
}
Err(broadcast::error::TryRecvError::Lagged(n)) => {
tracing::warn!("Receiver lagged, lost {} messages on stream {}", n, stream);
}
Err(broadcast::error::TryRecvError::Empty) => {
}
Err(broadcast::error::TryRecvError::Closed) => {
}
}
}
if !received_any {
tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Re-evaluates `cond` every 2 ms until it holds or `timeout` elapses.
    /// Returns whether the condition was observed to hold.
    async fn wait_until(timeout: Duration, mut cond: impl FnMut() -> bool) -> bool {
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            if cond() {
                return true;
            }
            if tokio::time::Instant::now() >= deadline {
                return false;
            }
            tokio::time::sleep(Duration::from_millis(2)).await;
        }
    }

    #[tokio::test]
    async fn test_publish_subscribe() {
        let pubsub = MemoryPubSub::new();
        let received = Arc::new(AtomicUsize::new(0));
        let received_clone = received.clone();
        pubsub.subscribe("test").await.unwrap();
        let pubsub_clone = MemoryPubSub {
            channels: pubsub.channels.clone(),
            subscriptions: pubsub.subscriptions.clone(),
            buffer_size: pubsub.buffer_size,
        };
        let handle = tokio::spawn(async move {
            let _ = pubsub_clone
                .listen(move |_stream, _payload| {
                    received_clone.fetch_add(1, Ordering::SeqCst);
                })
                .await;
        });

        // Messages published before the listener attaches a receiver are dropped, so wait
        // for the attachment itself rather than for a fixed delay.
        let attached = wait_until(Duration::from_secs(2), || {
            pubsub
                .channels
                .get("test")
                .is_some_and(|tx| tx.receiver_count() > 0)
        })
        .await;

        if attached {
            pubsub.publish("test", b"hello").await.unwrap();
            pubsub.publish("test", b"world").await.unwrap();
        }
        let delivered = attached
            && wait_until(Duration::from_secs(2), || {
                received.load(Ordering::SeqCst) >= 2
            })
            .await;

        handle.abort();
        assert!(attached, "listener never attached to stream");
        assert!(delivered, "listener did not deliver both messages in time");
        assert_eq!(received.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let pubsub = MemoryPubSub::new();
        pubsub.subscribe("test").await.unwrap();
        assert!(pubsub.subscriptions.read().await.contains("test"));
        pubsub.unsubscribe("test").await.unwrap();
        assert!(!pubsub.subscriptions.read().await.contains("test"));
        // With no listener holding a receiver, the channel must be reclaimed.
        assert!(!pubsub.channels.contains_key("test"));
    }

    #[test]
    fn test_default() {
        let pubsub = MemoryPubSub::default();
        assert_eq!(pubsub.buffer_size, DEFAULT_BUFFER_SIZE);
    }
}