use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_stream::{Stream, StreamExt};
use crate::Action;
/// Name under which a subscription is registered.
///
/// Keys compare, hash and clone by their string content.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct SubKey(String);

impl SubKey {
    /// Creates a key from anything convertible into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        Self(name.into())
    }

    /// Returns the key's name.
    pub fn name(&self) -> &str {
        &self.0
    }
}

// Accepts any borrowed string, not only `'static` literals; the content is
// copied into an owned `String`, so the borrow's lifetime does not matter.
impl From<&str> for SubKey {
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}

// Takes ownership of the `String` without re-allocating.
impl From<String> for SubKey {
    fn from(s: String) -> Self {
        Self(s)
    }
}
/// Cloneable handle to the pause flag of a subscription registry, obtained
/// from `Subscriptions::pause_handle`.
///
/// Every clone shares the same flag, so pausing or resuming through any clone
/// is observed by all of them.
#[derive(Clone, Debug)]
pub struct SubPauseHandle {
    /// Flag shared with the owning registry.
    paused: Arc<AtomicBool>,
}

impl SubPauseHandle {
    /// Sets the shared pause flag.
    pub fn pause(&self) {
        self.paused.store(true, Ordering::SeqCst);
    }

    /// Clears the shared pause flag.
    pub fn resume(&self) {
        self.paused.store(false, Ordering::SeqCst);
    }

    /// Returns `true` while the shared pause flag is set.
    pub fn is_paused(&self) -> bool {
        self.paused.load(Ordering::SeqCst)
    }
}
/// Registry of keyed background tasks ("subscriptions") that forward actions
/// into a single unbounded channel.
///
/// Each subscription runs in its own spawned `tokio` task. Registering a
/// subscription under a key that is already in use aborts the previous task
/// first, so at most one task is tracked per key. All subscriptions share one
/// pause flag: while it is set nothing is sent (interval callbacks are not
/// invoked and stream items are received but discarded).
pub struct Subscriptions<A> {
    /// Handles of spawned tasks by key. Handles of tasks that have already
    /// finished stay here until `cleanup` prunes them.
    handles: HashMap<SubKey, JoinHandle<()>>,
    /// Sender cloned into every task; all actions are delivered through it.
    action_tx: mpsc::UnboundedSender<A>,
    /// Pause flag shared with every spawned task and every `SubPauseHandle`.
    paused: Arc<AtomicBool>,
}

impl<A> Subscriptions<A>
where
    A: Action,
{
    /// Creates an empty, unpaused registry that delivers actions into `action_tx`.
    pub fn new(action_tx: mpsc::UnboundedSender<A>) -> Self {
        Self {
            handles: HashMap::new(),
            action_tx,
            paused: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Suspends delivery for every subscription, including ones registered later.
    pub fn pause(&self) {
        self.paused.store(true, Ordering::SeqCst);
    }

    /// Re-enables delivery after `pause`.
    pub fn resume(&self) {
        self.paused.store(false, Ordering::SeqCst);
    }

    /// Returns `true` while delivery is suspended.
    pub fn is_paused(&self) -> bool {
        self.paused.load(Ordering::SeqCst)
    }

    /// Returns a cloneable handle that toggles this registry's pause flag
    /// without needing access to the registry itself.
    pub fn pause_handle(&self) -> SubPauseHandle {
        SubPauseHandle {
            paused: Arc::clone(&self.paused),
        }
    }

    /// Drops the handles of tasks that have already finished (for example a
    /// finite stream that ran out, or a task whose receiver went away).
    pub fn cleanup(&mut self) {
        self.handles.retain(|_, handle| !handle.is_finished());
    }

    /// Emits `action_fn()` every `duration`, starting one full period after
    /// registration. Replaces any subscription already registered under `key`.
    ///
    /// `tokio::time::interval` panics on a zero period. The timer is created
    /// inside the spawned task, so that panic surfaces in the task (which then
    /// finishes without emitting) rather than at this call site.
    pub fn interval<F>(
        &mut self,
        key: impl Into<SubKey>,
        duration: Duration,
        action_fn: F,
    ) -> &mut Self
    where
        F: Fn() -> A + Send + 'static,
    {
        self.spawn_interval(key.into(), duration, action_fn, true)
    }

    /// Like `interval`, but the first action is emitted right away instead of
    /// after one period.
    pub fn interval_immediate<F>(
        &mut self,
        key: impl Into<SubKey>,
        duration: Duration,
        action_fn: F,
    ) -> &mut Self
    where
        F: Fn() -> A + Send + 'static,
    {
        self.spawn_interval(key.into(), duration, action_fn, false)
    }

    /// Forwards every item produced by `stream` until the stream ends or the
    /// receiver is gone. Items produced while paused are discarded, not buffered.
    pub fn stream<S>(&mut self, key: impl Into<SubKey>, stream: S) -> &mut Self
    where
        S: Stream<Item = A> + Send + 'static,
    {
        // An already-built stream is the degenerate case of an async
        // constructor that resolves immediately.
        self.stream_async(key, async move { stream })
    }

    /// Awaits `stream_fn` inside the spawned task to obtain a stream, then
    /// forwards its items until it ends or the receiver is gone. Items produced
    /// while paused are discarded, not buffered.
    pub fn stream_async<F, S>(&mut self, key: impl Into<SubKey>, stream_fn: F) -> &mut Self
    where
        F: Future<Output = S> + Send + 'static,
        S: Stream<Item = A> + Send + 'static,
    {
        let tx = self.action_tx.clone();
        let paused = Arc::clone(&self.paused);
        self.spawn_keyed(key.into(), async move {
            let stream = stream_fn.await;
            tokio::pin!(stream);
            while let Some(action) = stream.next().await {
                if !Self::deliver(&tx, &paused, || action) {
                    break;
                }
            }
        })
    }

    /// Aborts and forgets the subscription registered under `key`, if any.
    pub fn cancel(&mut self, key: &SubKey) {
        if let Some(handle) = self.handles.remove(key) {
            handle.abort();
        }
    }

    /// Aborts and forgets every registered subscription.
    pub fn cancel_all(&mut self) {
        for (_, handle) in self.handles.drain() {
            handle.abort();
        }
    }

    /// Returns `true` if a task is registered under `key` and has not finished.
    pub fn is_active(&self, key: &SubKey) -> bool {
        matches!(self.handles.get(key), Some(handle) if !handle.is_finished())
    }

    /// Number of registered subscriptions whose tasks have not finished.
    pub fn len(&self) -> usize {
        self.active_keys().count()
    }

    /// Returns `true` when no registered subscription is still running.
    pub fn is_empty(&self) -> bool {
        self.active_keys().next().is_none()
    }

    /// Iterates over the keys of subscriptions whose tasks have not finished,
    /// in unspecified (hash map) order.
    pub fn active_keys(&self) -> impl Iterator<Item = &SubKey> {
        self.handles
            .iter()
            .filter(|(_, handle)| !handle.is_finished())
            .map(|(key, _)| key)
    }

    /// Shared implementation of `interval` and `interval_immediate`.
    fn spawn_interval<F>(
        &mut self,
        key: SubKey,
        duration: Duration,
        action_fn: F,
        skip_first_tick: bool,
    ) -> &mut Self
    where
        F: Fn() -> A + Send + 'static,
    {
        let tx = self.action_tx.clone();
        let paused = Arc::clone(&self.paused);
        self.spawn_keyed(key, async move {
            let mut ticker = tokio::time::interval(duration);
            if skip_first_tick {
                // A tokio interval's first tick completes immediately; consuming
                // it up front delays the first action by one full period.
                ticker.tick().await;
            }
            loop {
                ticker.tick().await;
                if !Self::deliver(&tx, &paused, &action_fn) {
                    break;
                }
            }
        })
    }

    /// Replaces whatever task is registered under `key` with a newly spawned
    /// `task`, pruning finished handles first so the map does not accumulate
    /// completed subscriptions.
    fn spawn_keyed<T>(&mut self, key: SubKey, task: T) -> &mut Self
    where
        T: Future<Output = ()> + Send + 'static,
    {
        self.cleanup();
        self.cancel(&key);
        self.handles.insert(key, tokio::spawn(task));
        self
    }

    /// Sends `make()` unless paused. `make` is not invoked while paused, so
    /// side-effecting action constructors are skipped too.
    ///
    /// Returns `false` once the receiving side of the channel is gone, which
    /// tells the calling task to exit.
    fn deliver(
        tx: &mpsc::UnboundedSender<A>,
        paused: &AtomicBool,
        make: impl FnOnce() -> A,
    ) -> bool {
        if paused.load(Ordering::SeqCst) {
            // Nothing is sent while paused, so `send` can never report a closed
            // channel here. Check explicitly so a paused task does not outlive
            // its receiver (and keep counting as active) forever.
            return !tx.is_closed();
        }
        tx.send(make()).is_ok()
    }
}
impl<A> Drop for Subscriptions<A> {
    /// Requests cancellation (`JoinHandle::abort`) of every registered task,
    /// so subscriptions do not keep running once the registry is dropped.
    fn drop(&mut self) {
        self.handles
            .drain()
            .for_each(|(_key, handle)| handle.abort());
    }
}
#[cfg(test)]
mod tests {
    // These tests spawn real tokio tasks and rely on short wall-clock
    // intervals combined with `tokio::time::timeout`.
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Minimal action type used to drive the subscriptions under test.
    #[derive(Clone, Debug)]
    enum TestAction {
        Tick,
        Value(usize),
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            match self {
                TestAction::Tick => "Tick",
                TestAction::Value(_) => "Value",
            }
        }
    }

    #[test]
    fn test_sub_key() {
        let k1 = SubKey::new("test");
        let k2 = SubKey::from("test");
        let k3: SubKey = "test".into();
        let k4 = SubKey::from(String::from("test"));
        assert_eq!(k1, k2);
        assert_eq!(k2, k3);
        assert_eq!(k3, k4);
        assert_eq!(k1.name(), "test");
    }

    #[tokio::test]
    async fn test_interval_emits_actions() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut subs = Subscriptions::new(tx);
        subs.interval("tick", Duration::from_millis(20), || TestAction::Tick);
        let action = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert!(matches!(action, TestAction::Tick));
        let action2 = tokio::time::timeout(Duration::from_millis(50), rx.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert!(matches!(action2, TestAction::Tick));
    }

    #[tokio::test]
    async fn test_interval_immediate() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut subs = Subscriptions::new(tx);
        subs.interval_immediate("tick", Duration::from_millis(100), || TestAction::Tick);
        let action = tokio::time::timeout(Duration::from_millis(20), rx.recv())
            .await
            .expect("should receive immediately")
            .expect("channel closed");
        assert!(matches!(action, TestAction::Tick));
    }

    #[tokio::test]
    async fn test_stream_forwards_items() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut subs = Subscriptions::new(tx);
        let stream = tokio_stream::iter(vec![
            TestAction::Value(1),
            TestAction::Value(2),
            TestAction::Value(3),
        ]);
        subs.stream("test", stream);
        let mut values = vec![];
        for _ in 0..3 {
            let action = tokio::time::timeout(Duration::from_millis(100), rx.recv())
                .await
                .expect("timeout")
                .expect("channel closed");
            if let TestAction::Value(v) = action {
                values.push(v);
            }
        }
        assert_eq!(values, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_stream_async_forwards_items() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut subs = Subscriptions::new(tx);
        subs.stream_async("deferred", async {
            tokio_stream::iter(vec![TestAction::Value(7)])
        });
        let action = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert!(matches!(action, TestAction::Value(7)));
    }

    #[tokio::test]
    async fn test_cancel_stops_subscription() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut subs = Subscriptions::new(tx);
        subs.interval("tick", Duration::from_millis(10), || TestAction::Tick);
        assert!(subs.is_active(&SubKey::new("tick")));
        let _ = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        subs.cancel(&SubKey::new("tick"));
        assert!(!subs.is_active(&SubKey::new("tick")));
        while rx.try_recv().is_ok() {}
        let result = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        assert!(result.is_err(), "should timeout - no more ticks");
    }

    #[tokio::test]
    async fn test_cancel_all() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let mut subs = Subscriptions::new(tx);
        subs.interval("a", Duration::from_secs(10), || TestAction::Tick);
        subs.interval("b", Duration::from_secs(10), || TestAction::Tick);
        assert_eq!(subs.len(), 2);
        subs.cancel_all();
        assert!(subs.is_empty());
    }

    #[tokio::test]
    async fn test_replace_existing_subscription() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut subs = Subscriptions::new(tx);
        let counter = Arc::new(AtomicUsize::new(0));
        let c1 = counter.clone();
        subs.interval("test", Duration::from_millis(10), move || {
            c1.fetch_add(1, Ordering::SeqCst);
            TestAction::Value(1)
        });
        let c2 = counter.clone();
        subs.interval("test", Duration::from_millis(10), move || {
            c2.fetch_add(100, Ordering::SeqCst);
            TestAction::Value(2)
        });
        assert_eq!(subs.len(), 1);
        tokio::time::sleep(Duration::from_millis(50)).await;
        let mut got_two = false;
        while let Ok(action) = rx.try_recv() {
            if let TestAction::Value(v) = action {
                assert_eq!(v, 2);
                got_two = true;
            }
        }
        assert!(got_two, "should have received Value(2)");
        // The replaced closure adds 1 per call and the replacement adds 100,
        // so any call of the replaced closure leaves a non-zero remainder.
        let total = counter.load(Ordering::SeqCst);
        assert!(total >= 100, "replacement closure should have run");
        assert_eq!(total % 100, 0, "replaced closure must never run");
    }

    #[tokio::test]
    async fn test_active_keys() {
        let (tx, _rx) = mpsc::unbounded_channel::<TestAction>();
        let mut subs = Subscriptions::new(tx);
        assert!(subs.is_empty());
        assert_eq!(subs.len(), 0);
        assert_eq!(subs.active_keys().count(), 0);

        subs.interval("a", Duration::from_secs(10), || TestAction::Tick);
        subs.interval("b", Duration::from_secs(10), || TestAction::Tick);
        let mut keys: Vec<&str> = subs.active_keys().map(SubKey::name).collect();
        keys.sort_unstable();
        assert_eq!(keys, vec!["a", "b"]);

        subs.cancel(&SubKey::new("a"));
        let keys: Vec<&str> = subs.active_keys().map(SubKey::name).collect();
        assert_eq!(keys, vec!["b"]);
    }

    #[test]
    fn test_pause_handle_basic() {
        let (tx, _rx) = mpsc::unbounded_channel::<TestAction>();
        let subs = Subscriptions::new(tx);
        let handle = subs.pause_handle();
        assert!(!handle.is_paused());
        handle.pause();
        assert!(handle.is_paused());
        handle.resume();
        assert!(!handle.is_paused());
    }

    #[tokio::test]
    async fn test_pause_suppresses_interval() {
        let (tx, mut rx) = mpsc::unbounded_channel::<TestAction>();
        let mut subs = Subscriptions::new(tx);
        let handle = subs.pause_handle();
        subs.interval("tick", Duration::from_millis(10), || TestAction::Tick);
        let _ = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        handle.pause();
        tokio::time::sleep(Duration::from_millis(20)).await;
        while rx.try_recv().is_ok() {}
        let result = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        assert!(result.is_err(), "should timeout - subscription is paused");
        handle.resume();
        let result = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        assert!(result.is_ok(), "should receive tick after resume");
    }

    #[test]
    fn test_pause_handle_clone() {
        let (tx, _rx) = mpsc::unbounded_channel::<TestAction>();
        let subs = Subscriptions::new(tx);
        let handle1 = subs.pause_handle();
        let handle2 = handle1.clone();
        handle1.pause();
        assert!(handle2.is_paused());
        handle2.resume();
        assert!(!handle1.is_paused());
    }

    #[tokio::test]
    async fn test_finished_stream_cleaned_up() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut subs = Subscriptions::new(tx);
        let stream = tokio_stream::iter(vec![
            TestAction::Value(1),
            TestAction::Value(2),
            TestAction::Value(3),
        ]);
        subs.stream("finite", stream);
        for _ in 0..3 {
            let _ = tokio::time::timeout(Duration::from_millis(100), rx.recv())
                .await
                .expect("timeout");
        }
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(!subs.is_active(&SubKey::new("finite")));
        assert_eq!(subs.len(), 0);
        // `len`/`is_active` merely filter finished handles; `cleanup` must
        // actually remove them from the registry.
        subs.cleanup();
        assert!(subs.handles.is_empty(), "cleanup should drop finished handles");
    }

    #[tokio::test]
    async fn test_task_exits_when_receiver_dropped() {
        let (tx, rx) = mpsc::unbounded_channel();
        let mut subs = Subscriptions::new(tx);
        subs.interval("tick", Duration::from_millis(10), || TestAction::Tick);
        drop(rx);
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(!subs.is_active(&SubKey::new("tick")));
        assert!(subs.is_empty());
    }

    #[tokio::test]
    async fn test_is_active_accurate_for_running_interval() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut subs = Subscriptions::new(tx);
        subs.interval("tick", Duration::from_millis(20), || TestAction::Tick);
        assert!(subs.is_active(&SubKey::new("tick")));
        assert_eq!(subs.len(), 1);
        let _ = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout");
        assert!(subs.is_active(&SubKey::new("tick")));
        assert_eq!(subs.len(), 1);
        subs.cancel(&SubKey::new("tick"));
        assert!(!subs.is_active(&SubKey::new("tick")));
        assert_eq!(subs.len(), 0);
    }
}