use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::broadcast;
#[derive(Clone)]
pub struct Channels {
inner: Arc<ChannelsInner>,
}
/// State shared by every clone of a [`Channels`] handle.
struct ChannelsInner {
    /// Buffer size for newly created channels; `Channels::new` clamps it to `1..=16384`.
    capacity: usize,
    /// Channel name -> the broadcast sender registered for that channel.
    registry: std::sync::Mutex<
        HashMap<String, std::sync::Arc<broadcast::Sender<ChannelMessage>>>,
    >,
}
/// A single text payload carried over a named channel.
///
/// Thin newtype around `String`; build one from `&str` or `String` with `.into()`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChannelMessage(pub String);

impl From<String> for ChannelMessage {
    /// Wraps an owned string without copying it.
    fn from(text: String) -> Self {
        Self(text)
    }
}

impl From<&str> for ChannelMessage {
    /// Copies a borrowed string into a freshly allocated message.
    fn from(text: &str) -> Self {
        Self::from(String::from(text))
    }
}

impl ChannelMessage {
    /// Borrows the message text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the message and hands back the owned text.
    #[must_use]
    pub fn into_string(self) -> String {
        let Self(text) = self;
        text
    }
}

impl std::fmt::Display for ChannelMessage {
    /// Writes the raw message text, with no quoting or escaping.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Sending half of a named channel, obtained from [`Channels::sender`].
///
/// Cloning is cheap: every clone shares the same broadcast sender through an `Arc`.
#[derive(Clone)]
pub struct Sender {
    inner: Arc<broadcast::Sender<ChannelMessage>>,
}

impl Sender {
    /// Broadcasts `msg` to every subscriber currently attached to the channel.
    ///
    /// On success, returns the number of subscribers attached at the moment of sending. Per
    /// tokio's semantics, a subscriber may still miss the message if it drops or lags.
    ///
    /// # Errors
    ///
    /// Returns [`broadcast::error::SendError`], which carries the message back, when the
    /// channel has no subscribers.
    pub fn send(
        &self,
        msg: impl Into<ChannelMessage>,
    ) -> Result<usize, broadcast::error::SendError<ChannelMessage>> {
        let message: ChannelMessage = msg.into();
        self.inner.send(message)
    }

    /// Returns how many subscribers are currently attached to the channel.
    #[must_use]
    pub fn receiver_count(&self) -> usize {
        let shared: &broadcast::Sender<ChannelMessage> = &self.inner;
        shared.receiver_count()
    }
}
/// Receiving half of a named channel, obtained from [`Channels::subscribe`].
pub struct Subscriber {
    inner: broadcast::Receiver<ChannelMessage>,
}

impl Subscriber {
    /// Waits for the next message on the channel.
    ///
    /// # Errors
    ///
    /// Returns [`broadcast::error::RecvError::Lagged`] when older messages were overwritten
    /// before this subscriber read them; the next call resumes with the oldest message still
    /// buffered. Returns [`broadcast::error::RecvError::Closed`] once every sender of the
    /// underlying channel is gone. Note that the registry itself holds one sender while the
    /// channel is registered.
    pub async fn recv(&mut self) -> Result<ChannelMessage, broadcast::error::RecvError> {
        let receiver = &mut self.inner;
        receiver.recv().await
    }

    /// Converts this subscriber into a stream of messages.
    ///
    /// Lag errors are discarded. If the subscriber falls behind, the overwritten messages are
    /// skipped silently and the stream continues with whatever is still buffered.
    #[cfg(feature = "ws")]
    pub fn into_stream(self) -> impl tokio_stream::Stream<Item = ChannelMessage> {
        use tokio_stream::StreamExt;
        let messages = tokio_stream::wrappers::BroadcastStream::new(self.inner);
        messages.filter_map(std::result::Result::ok)
    }
}
impl Channels {
#[must_use]
pub fn new(capacity: usize) -> Self {
Self {
inner: Arc::new(ChannelsInner {
capacity: capacity.clamp(1, 16384),
registry: Mutex::new(HashMap::new()),
}),
}
}
#[must_use]
pub fn sender(&self, name: &str) -> Sender {
let mut registry = self.inner.registry.lock().expect("channels lock poisoned");
#[allow(clippy::option_if_let_else)]
let tx = if let Some(tx) = registry.get(name) {
Arc::clone(tx)
} else {
let capacity = std::cmp::max(1, self.inner.capacity);
let tx = Arc::new(broadcast::channel(capacity).0);
registry.insert(name.to_owned(), Arc::clone(&tx));
tx
};
let sender = Sender { inner: tx };
drop(registry);
sender
}
#[must_use]
pub fn subscribe(&self, name: &str) -> Subscriber {
let mut registry = self.inner.registry.lock().expect("channels lock poisoned");
#[allow(clippy::option_if_let_else)]
let tx = if let Some(tx) = registry.get(name) {
Arc::clone(tx)
} else {
let capacity = std::cmp::max(1, self.inner.capacity);
let tx = Arc::new(broadcast::channel(capacity).0);
registry.insert(name.to_owned(), Arc::clone(&tx));
tx
};
let subscriber = Subscriber {
inner: tx.subscribe(),
};
drop(registry);
subscriber
}
#[must_use]
pub fn channel_count(&self) -> usize {
let registry = self.inner.registry.lock().expect("channels lock poisoned");
registry.len()
}
pub fn gc(&self) {
let mut registry = self.inner.registry.lock().expect("channels lock poisoned");
registry.retain(|_, tx| tx.receiver_count() > 0 || Arc::strong_count(tx) > 1);
}
#[must_use]
pub fn snapshot(&self) -> HashMap<String, usize> {
let registry = self.inner.registry.lock().expect("channels lock poisoned");
registry
.iter()
.map(|(name, tx)| (name.clone(), tx.receiver_count()))
.collect()
}
#[cfg(feature = "ws")]
pub fn sse_stream(
&self,
name: &str,
) -> axum::response::sse::Sse<
impl tokio_stream::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
> {
use tokio_stream::StreamExt;
let rx = self.subscribe(name);
let stream = rx
.into_stream()
.map(|msg| Ok(axum::response::sse::Event::default().data(msg.into_string())));
axum::response::sse::Sse::new(stream).keep_alive(crate::sse::keep_alive())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_channels() {
        let channels = Channels::new(16);
        assert_eq!(channels.channel_count(), 0);
    }

    #[test]
    fn sender_creates_channel_lazily() {
        let channels = Channels::new(16);
        let _tx = channels.sender("test");
        assert_eq!(channels.channel_count(), 1);
    }

    #[test]
    fn subscribe_creates_channel_lazily() {
        let channels = Channels::new(16);
        let _rx = channels.subscribe("test");
        assert_eq!(channels.channel_count(), 1);
    }

    #[test]
    fn new_clamps_zero_capacity() {
        // tokio's broadcast::channel panics on capacity 0; `new` must clamp it to at least 1.
        let channels = Channels::new(0);
        let tx = channels.sender("zero");
        let _rx = channels.subscribe("zero");
        assert_eq!(tx.receiver_count(), 1);
    }

    #[tokio::test]
    async fn send_and_receive() -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let tx = channels.sender("chat");
        let mut rx = channels.subscribe("chat");
        tx.send("hello").expect("should send");
        let msg = rx.recv().await?;
        assert_eq!(msg.as_str(), "hello");
        Ok(())
    }

    #[test]
    fn send_without_subscribers_returns_message() {
        let channels = Channels::new(16);
        let tx = channels.sender("nobody");
        let err = tx.send("lost").expect_err("no subscribers to deliver to");
        assert_eq!(err.0.as_str(), "lost");
    }

    #[tokio::test]
    async fn multiple_subscribers() -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let tx = channels.sender("chat");
        let mut rx1 = channels.subscribe("chat");
        let mut rx2 = channels.subscribe("chat");
        tx.send("broadcast").expect("should send");
        let msg1 = rx1.recv().await?;
        let msg2 = rx2.recv().await?;
        assert_eq!(msg1.as_str(), "broadcast");
        assert_eq!(msg2.as_str(), "broadcast");
        Ok(())
    }

    #[tokio::test]
    async fn lagging_subscriber_skips_overwritten_messages() {
        let channels = Channels::new(2);
        let tx = channels.sender("lag");
        let mut rx = channels.subscribe("lag");
        tx.send("one").expect("subscriber is attached");
        tx.send("two").expect("subscriber is attached");
        tx.send("three").expect("subscriber is attached");
        // With room for two messages, "one" is overwritten and reported as a lag of 1.
        assert!(matches!(
            rx.recv().await,
            Err(broadcast::error::RecvError::Lagged(1))
        ));
        assert_eq!(rx.recv().await.expect("buffered").as_str(), "two");
        assert_eq!(rx.recv().await.expect("buffered").as_str(), "three");
    }

    #[test]
    fn sender_receiver_count() {
        let channels = Channels::new(16);
        let tx = channels.sender("chat");
        assert_eq!(tx.receiver_count(), 0);
        let _rx1 = channels.subscribe("chat");
        assert_eq!(tx.receiver_count(), 1);
        let _rx2 = channels.subscribe("chat");
        assert_eq!(tx.receiver_count(), 2);
    }

    #[test]
    fn channel_message_conversions() {
        let msg: ChannelMessage = "hello".into();
        assert_eq!(msg.as_str(), "hello");
        assert_eq!(msg.to_string(), "hello");
        let msg2: ChannelMessage = String::from("world").into();
        assert_eq!(msg2.into_string(), "world");
    }

    #[test]
    #[allow(clippy::redundant_clone)]
    fn channels_is_clone() {
        let channels = Channels::new(16);
        let _cloned = channels.clone();
    }

    #[test]
    fn snapshot_returns_counts() {
        let channels = Channels::new(16);
        let _tx = channels.sender("empty");
        let _tx2 = channels.sender("one");
        let _rx_one = channels.subscribe("one");
        let _tx3 = channels.sender("two");
        let _rx_two_1 = channels.subscribe("two");
        let _rx_two_2 = channels.subscribe("two");
        let snap = channels.snapshot();
        assert_eq!(snap.get("empty"), Some(&0));
        assert_eq!(snap.get("one"), Some(&1));
        assert_eq!(snap.get("two"), Some(&2));
        assert_eq!(snap.len(), 3);
    }

    #[test]
    fn gc_removes_dead_channels() {
        let channels = Channels::new(16);
        let _tx = channels.sender("alive");
        {
            let _tx = channels.sender("dead");
        }
        assert_eq!(channels.channel_count(), 2);
        channels.gc();
        assert_eq!(channels.channel_count(), 1);
    }

    #[test]
    fn gc_keeps_channels_with_subscribers() {
        let channels = Channels::new(16);
        let _rx = channels.subscribe("watched");
        {
            let _rx = channels.subscribe("abandoned");
        }
        assert_eq!(channels.channel_count(), 2);
        channels.gc();
        assert_eq!(channels.channel_count(), 1);
        assert!(channels.snapshot().contains_key("watched"));
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn subscriber_into_stream() {
        use tokio_stream::StreamExt;
        let channels = Channels::new(16);
        let tx = channels.sender("test_stream");
        let rx = channels.subscribe("test_stream");
        tx.send("message 1").unwrap();
        tx.send("message 2").unwrap();
        let mut stream = rx.into_stream();
        let msg1 = stream.next().await.unwrap();
        assert_eq!(msg1.as_str(), "message 1");
        let msg2 = stream.next().await.unwrap();
        assert_eq!(msg2.as_str(), "message 2");
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn channels_sse_stream() {
        let channels = Channels::new(16);
        let tx = channels.sender("test_sse");
        let sse = channels.sse_stream("test_sse");
        tx.send("sse message").unwrap();
        let _stream = sse;
    }
}