#![cfg_attr(
feature = "ws",
doc = "- [`websocket::WebSocket`] - WebSocket connections (requires `ws` feature)"
)]
#![cfg_attr(
not(feature = "ws"),
doc = "- `websocket::WebSocket` - WebSocket connections (requires `ws` feature)"
)]
#[cfg(feature = "http")]
pub mod http;
pub mod mock;
pub mod signal;
pub mod terminal;
pub mod time;
#[cfg(feature = "ws")]
pub mod websocket;
use std::{
any::TypeId,
collections::{HashMap, HashSet},
hash::Hash,
};
use futures::{StreamExt, stream::BoxStream};
use tokio::{
sync::mpsc::{self},
task::JoinHandle,
};
/// A message source identified by a [`SubscriptionId`], paired with a deferred
/// factory that creates its stream.
///
/// The stream is not created when the `Subscription` is built; it is created only
/// when `spawn` is called (for example by [`SubscriptionManager::update`] when this
/// id is not already running).
pub struct Subscription<Msg: 'static> {
    /// Identity used to decide whether this subscription is already running.
    pub(super) id: SubscriptionId,
    /// One-shot factory that creates the message stream.
    pub(super) spawn: Box<dyn FnOnce() -> BoxStream<'static, Msg> + Send>,
}
impl<Msg: 'static> Subscription<Msg> {
    /// Wraps `source` in a subscription.
    ///
    /// The id is captured immediately. Stream creation is deferred until the
    /// subscription is spawned.
    #[must_use]
    pub fn new(source: impl SubscriptionSource<Output = Msg> + 'static) -> Self {
        let id = source.id();
        Self {
            id,
            // `stream()` already returns a `BoxStream`, so pass it through as-is
            // instead of boxing it a second time.
            spawn: Box::new(move || source.stream()),
        }
    }

    /// Transforms every message of this subscription with `f`, keeping the same id.
    ///
    /// Because the id is unchanged, the mapped subscription is keyed exactly like the
    /// original one. `f` may carry state (`FnMut`). It is only invoked on items once
    /// the stream has been spawned.
    #[must_use]
    pub fn map<F, NewMsg>(self, f: F) -> Subscription<NewMsg>
    where
        F: FnMut(Msg) -> NewMsg + Send + 'static,
        NewMsg: 'static,
    {
        let Self { id, spawn } = self;
        Subscription {
            id,
            spawn: Box::new(move || spawn().map(f).boxed()),
        }
    }
}
impl<A: SubscriptionSource<Output = Msg> + 'static, Msg> From<A> for Subscription<Msg> {
fn from(value: A) -> Self {
Self::new(value)
}
}
/// A type that can produce a stream of messages and be turned into a [`Subscription`].
///
/// Implementors must be `Send`: [`Subscription::new`] moves the source into the
/// subscription's stream factory, which is itself required to be `Send`.
pub trait SubscriptionSource: Send {
    /// The type of message yielded by [`SubscriptionSource::stream`].
    type Output;

    /// Identity of this source. Sources with equal ids are treated as the same
    /// subscription, so one that is already running is not restarted.
    fn id(&self) -> SubscriptionId;

    /// Creates the message stream. [`Subscription::new`] calls this at most once,
    /// when the resulting subscription is spawned.
    fn stream(&self) -> BoxStream<'static, Self::Output>;
}
/// Identity of a subscription: the source's Rust type plus a caller-supplied
/// discriminator.
///
/// Two ids are equal only if both the type and the `hash` value match.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct SubscriptionId {
    type_id: TypeId,
    hash: u64,
}

impl SubscriptionId {
    /// Builds the id for source type `T`, distinguished by `hash`.
    #[must_use]
    pub fn of<T: 'static>(hash: u64) -> Self {
        let type_id = TypeId::of::<T>();
        Self { type_id, hash }
    }
}
/// Bookkeeping for a subscription whose forwarding task has been spawned.
struct RunningSubscription {
    /// Handle to the Tokio task that forwards the stream; used to abort it.
    handle: JoinHandle<()>,
}
/// Starts and stops subscription tasks so that the running set matches the
/// subscriptions most recently passed to `update`.
pub struct SubscriptionManager<Msg> {
    /// Currently spawned subscriptions, keyed by their id.
    running: HashMap<SubscriptionId, RunningSubscription>,
    /// Channel that every subscription task forwards its messages into.
    msg_sender: mpsc::UnboundedSender<Msg>,
}
impl<Msg: Send + 'static> SubscriptionManager<Msg> {
    /// Creates a manager with no running subscriptions.
    ///
    /// Every message produced by a subscription started later is forwarded into
    /// `msg_sender`.
    #[must_use]
    pub fn new(msg_sender: mpsc::UnboundedSender<Msg>) -> Self {
        Self {
            running: HashMap::new(),
            msg_sender,
        }
    }

    /// Reconciles the running subscriptions with the desired set `subscriptions`.
    ///
    /// Subscriptions are compared by [`SubscriptionId`]:
    /// - ids that are running but absent from `subscriptions` have their task aborted;
    /// - ids that are present but not running have their stream created and spawned;
    /// - ids present in both keep their existing task, and the new `Subscription` is
    ///   dropped without its stream ever being created.
    ///
    /// If `subscriptions` contains the same id more than once, the last one wins.
    ///
    /// # Panics
    ///
    /// New subscriptions are started with `tokio::spawn`, which panics when called
    /// outside a Tokio runtime.
    pub fn update<I>(&mut self, subscriptions: I)
    where
        I: IntoIterator<Item = Subscription<Msg>>,
    {
        let mut new_subs: HashMap<_, _> = subscriptions
            .into_iter()
            .map(|sub| (sub.id, sub.spawn))
            .collect();
        let new_ids: HashSet<_> = new_subs.keys().copied().collect();
        let current_ids: HashSet<_> = self.running.keys().copied().collect();

        // Both id sets are independent snapshots, so `self.running` and `new_subs`
        // can be mutated while walking the set differences.
        for id in current_ids.difference(&new_ids) {
            if let Some(running) = self.running.remove(id) {
                running.handle.abort();
            }
        }
        for id in new_ids.difference(&current_ids) {
            if let Some(spawn) = new_subs.remove(id) {
                let handle = self.spawn_subscription(spawn());
                self.running.insert(*id, RunningSubscription { handle });
            }
        }
    }

    /// Spawns a task that forwards every item of `stream` into the message channel.
    ///
    /// The task stops when the stream ends or when sending fails, because the
    /// receiving side of the channel is gone.
    fn spawn_subscription(&self, mut stream: BoxStream<'static, Msg>) -> JoinHandle<()> {
        let sender = self.msg_sender.clone();
        tokio::spawn(async move {
            while let Some(msg) = stream.next().await {
                if sender.send(msg).is_err() {
                    // No receiver left; further messages could never be observed.
                    break;
                }
            }
        })
    }

    /// Aborts every running subscription task and forgets it.
    ///
    /// The manager stays usable: a later `update` starts subscriptions afresh.
    pub fn shutdown(&mut self) {
        self.abort_all();
    }
}

impl<Msg> SubscriptionManager<Msg> {
    /// Aborts and removes every running subscription task.
    ///
    /// Kept free of bounds on `Msg` so that `Drop` can use it.
    fn abort_all(&mut self) {
        for (_, running) in self.running.drain() {
            running.handle.abort();
        }
    }
}

impl<Msg> Drop for SubscriptionManager<Msg> {
    /// Aborts the remaining subscription tasks so they do not outlive the manager.
    ///
    /// Without this, dropping the `JoinHandle`s would only detach the tasks, and they
    /// would keep running in the background.
    fn drop(&mut self) {
        self.abort_all();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::subscription::mock::MockSource;
    use color_eyre::eyre::Result;
    use tokio::time::{Duration, sleep, timeout};

    /// Upper bound for any single wait on a stream or channel, so a broken
    /// subscription fails the test instead of hanging it.
    const WAIT: Duration = Duration::from_millis(100);

    #[test]
    fn test_subscription_new() {
        let mock = MockSource::<i32>::new();
        let sub = Subscription::new(mock);
        assert_eq!(sub.id.type_id, TypeId::of::<MockSource<i32>>());
    }

    #[tokio::test]
    async fn test_subscription_map() -> Result<()> {
        let mock = MockSource::new();
        let sub = Subscription::new(mock.clone()).map(|x: i32| x * 2);
        let stream = (sub.spawn)();
        mock.emit(1)?;
        mock.emit(2)?;
        mock.emit(3)?;
        let results = timeout(WAIT, stream.take(3).collect::<Vec<_>>()).await?;
        assert_eq!(results, vec![2, 4, 6]);
        Ok(())
    }

    #[tokio::test]
    async fn test_subscription_map_type_conversion() -> Result<()> {
        #[derive(Debug, PartialEq)]
        enum Message {
            Number(i32),
        }

        let mock = MockSource::new();
        let sub = Subscription::new(mock.clone()).map(Message::Number);
        let stream = (sub.spawn)();
        mock.emit(1)?;
        mock.emit(2)?;
        mock.emit(3)?;
        let results = timeout(WAIT, stream.take(3).collect::<Vec<_>>()).await?;
        assert_eq!(
            results,
            vec![Message::Number(1), Message::Number(2), Message::Number(3)]
        );
        Ok(())
    }

    #[test]
    fn test_subscription_id_of() {
        let id1 = SubscriptionId::of::<i32>(12345);
        let id2 = SubscriptionId::of::<i32>(12345);
        let id3 = SubscriptionId::of::<i32>(67890);
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_subscription_id_different_types() {
        let id_i32 = SubscriptionId::of::<i32>(12345);
        let id_u64 = SubscriptionId::of::<u64>(12345);
        let id_string = SubscriptionId::of::<String>(12345);
        assert_ne!(id_i32, id_u64);
        assert_ne!(id_i32, id_string);
        assert_ne!(id_u64, id_string);
    }

    #[tokio::test]
    async fn test_subscription_manager_basic_update() -> Result<()> {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut manager = SubscriptionManager::new(tx);
        let mock = MockSource::new();
        manager.update(vec![Subscription::new(mock.clone())]);
        sleep(Duration::from_millis(10)).await;
        mock.emit(10)?;
        mock.emit(20)?;
        assert_eq!(timeout(WAIT, rx.recv()).await?, Some(10));
        assert_eq!(timeout(WAIT, rx.recv()).await?, Some(20));
        Ok(())
    }

    #[tokio::test]
    async fn test_subscription_manager_shutdown() {
        use futures::stream;

        // A source that yields an increasing counter every 10 ms, forever.
        struct InfiniteSub;
        impl SubscriptionSource for InfiniteSub {
            type Output = i32;
            fn stream(&self) -> BoxStream<'static, Self::Output> {
                stream::unfold(0, |state| async move {
                    sleep(Duration::from_millis(10)).await;
                    Some((state, state + 1))
                })
                .boxed()
            }
            fn id(&self) -> SubscriptionId {
                SubscriptionId::of::<Self>(999)
            }
        }

        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut manager = SubscriptionManager::new(tx);
        manager.update(vec![Subscription::new(InfiniteSub)]);
        let _ = timeout(WAIT, rx.recv()).await;

        manager.shutdown();
        assert!(manager.running.is_empty());
        // Discard anything forwarded before the abort took effect...
        while rx.try_recv().is_ok() {}
        sleep(Duration::from_millis(50)).await;
        // ...after which the aborted task must not produce anything new.
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_subscription_manager_multiple_subscriptions() -> Result<()> {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut manager = SubscriptionManager::new(tx);
        let mock1 = MockSource::new();
        let mock2 = MockSource::new();
        manager.update(vec![
            Subscription::new(mock1.clone()),
            Subscription::new(mock2.clone()),
        ]);
        sleep(Duration::from_millis(10)).await;
        mock1.emit(1)?;
        mock2.emit(2)?;
        let mut results = vec![];
        for _ in 0..2 {
            if let Ok(Some(msg)) = timeout(WAIT, rx.recv()).await {
                results.push(msg);
            }
        }
        results.sort_unstable();
        assert_eq!(results, vec![1, 2]);
        Ok(())
    }

    #[tokio::test]
    async fn test_subscription_manager_subscription_starts_when_enabled() -> Result<()> {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut manager = SubscriptionManager::new(tx);
        let mock = MockSource::new();
        manager.update(Vec::<Subscription<i32>>::new());
        sleep(Duration::from_millis(10)).await;
        manager.update(vec![Subscription::new(mock.clone())]);
        sleep(Duration::from_millis(10)).await;
        mock.emit(42)?;
        assert_eq!(timeout(WAIT, rx.recv()).await?, Some(42));
        Ok(())
    }

    #[tokio::test]
    async fn test_subscription_manager_subscription_stops_when_disabled() -> Result<()> {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut manager = SubscriptionManager::new(tx);
        let mock = MockSource::new();
        manager.update(vec![Subscription::new(mock.clone())]);
        sleep(Duration::from_millis(10)).await;
        mock.emit(1)?;
        assert_eq!(timeout(WAIT, rx.recv()).await?, Some(1));

        manager.update(Vec::<Subscription<i32>>::new());
        sleep(Duration::from_millis(10)).await;
        // The emit may fail once nothing listens to the mock any more; that is fine.
        let _ = mock.emit(2);
        sleep(Duration::from_millis(10)).await;
        assert!(rx.try_recv().is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_subscription_manager_subscription_changes_based_on_state() -> Result<()> {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut manager = SubscriptionManager::new(tx);
        let mock1 = MockSource::new();
        let mock2 = MockSource::new();
        manager.update(vec![Subscription::new(mock1.clone())]);
        sleep(Duration::from_millis(10)).await;
        mock1.emit(100)?;
        assert_eq!(timeout(WAIT, rx.recv()).await?, Some(100));

        manager.update(vec![Subscription::new(mock2.clone())]);
        sleep(Duration::from_millis(10)).await;
        let _ = mock1.emit(200);
        mock2.emit(300)?;
        assert_eq!(timeout(WAIT, rx.recv()).await?, Some(300));
        Ok(())
    }

    #[tokio::test]
    async fn test_subscription_manager_subscription_multiple_changes() -> Result<()> {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut manager = SubscriptionManager::new(tx);
        let mock = MockSource::new();
        manager.update(vec![Subscription::new(mock.clone())]);
        sleep(Duration::from_millis(10)).await;
        mock.emit(1)?;
        assert_eq!(timeout(WAIT, rx.recv()).await?, Some(1));

        manager.update(Vec::<Subscription<i32>>::new());
        sleep(Duration::from_millis(10)).await;
        manager.update(vec![Subscription::new(mock.clone())]);
        sleep(Duration::from_millis(10)).await;
        mock.emit(2)?;
        assert_eq!(timeout(WAIT, rx.recv()).await?, Some(2));

        manager.update(Vec::<Subscription<i32>>::new());
        sleep(Duration::from_millis(10)).await;
        manager.update(vec![Subscription::new(mock.clone())]);
        sleep(Duration::from_millis(10)).await;
        mock.emit(3)?;
        assert_eq!(timeout(WAIT, rx.recv()).await?, Some(3));
        Ok(())
    }
}