use super::{PubSubBackend, Subscription};
use crate::error::PubSubError;
use async_trait::async_trait;
use std::{
collections::HashMap,
future::Future,
sync::{Arc, Mutex},
};
use tokio::sync::{broadcast, oneshot};
/// Process-local pub/sub backend built on `tokio::sync::broadcast` channels.
///
/// Each topic is backed by its own broadcast channel, created on demand. Cloning the
/// backend is cheap: all clones share the same topic table (it lives behind an `Arc`),
/// so a message published through one clone reaches subscribers registered through
/// another.
#[derive(Clone, Debug)]
pub struct InMemoryPubSub {
    /// Topic name -> sending half of that topic's broadcast channel.
    topics: Arc<Mutex<HashMap<String, broadcast::Sender<Vec<u8>>>>>,
    /// Capacity used when creating a new topic channel.
    buffer: usize,
}
impl InMemoryPubSub {
    /// Creates a backend whose per-topic channels buffer up to 64 messages.
    pub fn new() -> Self {
        Self::with_buffer(64)
    }

    /// Creates a backend whose per-topic channels buffer up to `buffer` messages.
    ///
    /// A `buffer` of `0` is raised to `1`, because `tokio::sync::broadcast::channel`
    /// panics when given a zero capacity.
    pub fn with_buffer(buffer: usize) -> Self {
        Self {
            topics: Arc::new(Mutex::new(HashMap::new())),
            buffer: buffer.max(1),
        }
    }

    /// Returns the broadcast sender for `topic`, creating the channel on first use.
    fn get_or_create_sender(&self, topic: &str) -> broadcast::Sender<Vec<u8>> {
        // The map is only ever modified by inserting complete entries, so a panic in
        // another thread cannot leave it half-updated. Recover from poisoning instead
        // of turning one panic into a panic on every later publish/subscribe.
        let mut map = self
            .topics
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Fast path: existing topic, no owned key allocation needed.
        if let Some(sender) = map.get(topic) {
            return sender.clone();
        }

        // The initial receiver is dropped immediately; subscribers create their own
        // via `Sender::subscribe`.
        let (sender, _rx) = broadcast::channel(self.buffer);
        map.insert(topic.to_owned(), sender.clone());
        sender
    }
}
impl Default for InMemoryPubSub {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PubSubBackend for InMemoryPubSub {
    /// Broadcasts `payload` to every current subscriber of `topic`.
    ///
    /// Delivery is fire-and-forget: if nobody is subscribed, the message is dropped
    /// and `Ok(())` is still returned. This implementation never returns an error.
    async fn publish_bytes(&self, topic: &str, payload: Vec<u8>) -> Result<(), PubSubError> {
        // Look the topic up without creating it. A broadcast channel never replays
        // messages to receivers that subscribe later, so creating a channel here for a
        // topic nobody listens to would still drop the message. It would also leave a
        // map entry (and its preallocated buffer) behind for every such topic.
        //
        // The map is only modified by whole-entry inserts, so a poisoned lock still
        // guards consistent data and is safe to recover.
        let sender = self
            .topics
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(topic)
            .cloned();

        if let Some(sender) = sender {
            // `send` fails only when the topic has no live receivers. Dropping the
            // message in that case is the intended fire-and-forget behaviour.
            let _ = sender.send(payload);
        }
        Ok(())
    }

    /// Subscribes `handler` to `topic` and runs it on a spawned Tokio task for each
    /// message published after this call.
    ///
    /// The broadcast receiver is registered before this method returns, so messages
    /// published afterwards are buffered even if the task has not been polled yet.
    /// Messages are handled one at a time. If the task falls further behind than the
    /// channel buffer, the oldest unseen messages are skipped.
    ///
    /// The task exits once the stop sender handed to `Subscription::new` fires or is
    /// dropped. A stop request is only observed between messages, not while a
    /// handler future is running.
    async fn subscribe_bytes(
        &self,
        topic: &str,
        handler: Box<dyn Fn(Vec<u8>) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send>> + Send>,
    ) -> Result<Subscription, PubSubError> {
        let sender = self.get_or_create_sender(topic);
        let mut rx = sender.subscribe();
        let (stop_tx, mut stop_rx) = oneshot::channel::<()>();
        let handle = tokio::spawn(async move {
            loop {
                tokio::select! {
                    res = rx.recv() => {
                        match res {
                            Ok(msg) => handler(msg).await,
                            // This receiver fell behind and the oldest messages were
                            // overwritten; keep going from the oldest one still buffered.
                            Err(broadcast::error::RecvError::Lagged(_)) => continue,
                            // All senders are gone; nothing more can arrive.
                            Err(broadcast::error::RecvError::Closed) => break,
                        }
                    }
                    // Matches both an explicit stop signal and a dropped stop sender.
                    _ = &mut stop_rx => break,
                }
            }
        });
        Ok(Subscription::new(handle, stop_tx))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pubsub::PubSubExt;
    use serde::{Deserialize, Serialize};
    use std::time::Duration;

    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
    struct TestMsg {
        val: String,
    }

    #[tokio::test]
    async fn in_memory_pubsub_works() {
        let backend = InMemoryPubSub::new();
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        // `InMemoryPubSub::subscribe_bytes` registers its broadcast receiver before
        // returning, so no handshake or sleep is needed before publishing (assuming
        // `PubSubExt::subscribe` awaits `subscribe_bytes`).
        let _sub = backend
            .subscribe("topic", move |data| {
                let tx = tx.clone();
                async move {
                    let msg: TestMsg = serde_json::from_slice(&data).unwrap();
                    tx.send(msg).await.ok();
                }
            })
            .await
            .unwrap();

        backend
            .publish(
                "topic",
                &TestMsg {
                    val: "hello".into(),
                },
            )
            .await
            .unwrap();

        let received = tokio::time::timeout(Duration::from_secs(1), rx.recv())
            .await
            .expect("timed out waiting for the published message")
            .expect("handler channel closed before a message arrived");
        assert_eq!(
            received,
            TestMsg {
                val: "hello".into()
            }
        );
    }

    #[tokio::test]
    async fn publish_without_subscribers_succeeds() {
        let backend = InMemoryPubSub::new();
        backend
            .publish(
                "nobody-listening",
                &TestMsg {
                    val: "dropped".into(),
                },
            )
            .await
            .unwrap();
    }

    #[test]
    fn default_impl_works() {
        let backend = InMemoryPubSub::default();
        assert_eq!(backend.buffer, 64);
    }

    #[test]
    fn with_buffer_clamps_zero_to_one() {
        assert_eq!(InMemoryPubSub::with_buffer(0).buffer, 1);
        assert_eq!(InMemoryPubSub::with_buffer(8).buffer, 8);
    }
}