// azoth_bus/notification.rs
use std::collections::HashMap;
use std::sync::{Arc, PoisonError, RwLock};
use std::time::Duration;
6
7#[derive(Clone)]
9pub enum WakeStrategy {
10 Poll { interval: Duration },
12
13 Notify(Arc<RwLock<HashMap<String, Arc<tokio::sync::Notify>>>>),
15}
16
17impl WakeStrategy {
18 pub fn poll(interval: Duration) -> Self {
20 Self::Poll { interval }
21 }
22
23 pub fn notify() -> Self {
25 Self::Notify(Arc::new(RwLock::new(HashMap::new())))
26 }
27
28 pub async fn wait(&self, stream: &str) {
30 match self {
31 WakeStrategy::Poll { interval } => {
32 tokio::time::sleep(*interval).await;
33 }
34 WakeStrategy::Notify(hub) => {
35 let notify = {
37 let mut map = hub.write().unwrap();
38 map.entry(stream.to_string())
39 .or_insert_with(|| Arc::new(tokio::sync::Notify::new()))
40 .clone()
41 };
42 notify.notified().await;
43 }
44 }
45 }
46
47 pub fn notify_stream(&self, stream: &str) {
51 if let WakeStrategy::Notify(hub) = self {
52 let map = hub.read().unwrap();
53 if let Some(notify) = map.get(stream) {
54 notify.notify_waiters();
55 }
56 }
57 }
58
59 pub fn notify_all(&self) {
63 if let WakeStrategy::Notify(hub) = self {
64 let map = hub.read().unwrap();
65 for notify in map.values() {
66 notify.notify_waiters();
67 }
68 }
69 }
70}
71
72impl Default for WakeStrategy {
73 fn default() -> Self {
74 Self::Poll {
76 interval: Duration::from_millis(10),
77 }
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns a task that waits on `stream` through its own clone of `strategy`.
    fn spawn_waiter(
        strategy: &WakeStrategy,
        stream: &'static str,
    ) -> tokio::task::JoinHandle<()> {
        let waiter = strategy.clone();
        tokio::spawn(async move {
            waiter.wait(stream).await;
        })
    }

    /// Awaits `handle` for at most 100 ms. Panics with `timeout_msg` if the
    /// task is still running, or with `panic_msg` if the task itself panicked.
    async fn await_waiter(
        handle: tokio::task::JoinHandle<()>,
        timeout_msg: &str,
        panic_msg: &str,
    ) {
        tokio::time::timeout(Duration::from_millis(100), handle)
            .await
            .expect(timeout_msg)
            .expect(panic_msg);
    }

    #[tokio::test]
    async fn test_poll_strategy() {
        let strategy = WakeStrategy::poll(Duration::from_millis(1));

        let start = std::time::Instant::now();
        strategy.wait("test").await;
        let elapsed = start.elapsed();

        assert!(elapsed >= Duration::from_millis(1));
        assert!(elapsed < Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_notify_strategy() {
        let strategy = WakeStrategy::notify();
        let handle = spawn_waiter(&strategy, "test");

        // Give the spawned task time to start waiting before notifying.
        tokio::time::sleep(Duration::from_millis(10)).await;
        strategy.notify_stream("test");

        await_waiter(handle, "Task should complete", "Task should not panic").await;
    }

    #[tokio::test]
    async fn test_notify_all() {
        let strategy = WakeStrategy::notify();
        let first = spawn_waiter(&strategy, "stream1");
        let second = spawn_waiter(&strategy, "stream2");

        // Give both spawned tasks time to start waiting before notifying.
        tokio::time::sleep(Duration::from_millis(10)).await;
        strategy.notify_all();

        await_waiter(first, "Task 1 should complete", "Task 1 should not panic").await;
        await_waiter(second, "Task 2 should complete", "Task 2 should not panic").await;
    }
}