1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::Lock;
4
5use super::BroadcastConsumer;
6
/// Every broadcast published under a single path.
struct BroadcastState {
    /// The broadcast currently announced to consumers for this path.
    active: BroadcastConsumer,
    /// Previously active broadcasts for this path that have not been removed yet.
    /// When `active` is removed, the last entry (the most recently demoted) is promoted.
    backup: Vec<BroadcastConsumer>,
}
12
/// Shared state behind an [`OriginProducer`].
#[derive(Default)]
struct ProducerState {
    /// Every currently published path, keyed by its full path.
    active: HashMap<String, BroadcastState>,
    /// Subscriptions created by `consume_prefix`. Entries whose channel has closed are pruned
    /// lazily, when an update to them fails or in `unused_inner`.
    consumers: Vec<ConsumerState>,
}
18
19impl ProducerState {
20 fn publish(&mut self, path: String, broadcast: BroadcastConsumer) {
21 match self.active.entry(path.clone()) {
22 hash_map::Entry::Occupied(mut entry) => {
23 let state = entry.get_mut();
24 if state.active.is_clone(&broadcast) {
25 tracing::warn!(?path, "skipping duplicate publish");
26 return;
27 }
28
29 let old = state.active.clone();
31 state.active = broadcast.clone();
32
33 let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
36 if let Some(pos) = pos {
37 state.backup[pos] = old;
38 } else {
39 state.backup.push(old);
40 }
41
42 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
44 }
45 hash_map::Entry::Vacant(entry) => {
46 entry.insert(BroadcastState {
47 active: broadcast.clone(),
48 backup: Vec::new(),
49 });
50 }
51 };
52
53 retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
54 }
55
56 fn remove(&mut self, path: String, broadcast: BroadcastConsumer) {
57 let mut entry = match self.active.entry(path) {
58 hash_map::Entry::Occupied(entry) => entry,
59 hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
60 };
61
62 let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
64 if let Some(pos) = pos {
65 entry.get_mut().backup.remove(pos);
66 return;
68 }
69
70 assert!(entry.get().active.is_clone(&broadcast));
72
73 retain_mut_unordered(&mut self.consumers, |c| c.remove(entry.key()));
74
75 if let Some(active) = entry.get_mut().backup.pop() {
77 entry.get_mut().active = active;
78 retain_mut_unordered(&mut self.consumers, |c| c.insert(entry.key(), &entry.get().active));
79 } else {
80 entry.remove();
82 }
83 }
84}
85
86impl Drop for ProducerState {
87 fn drop(&mut self) {
88 for (path, _) in self.active.drain() {
89 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
90 }
91 }
92}
93
/// Keeps only the elements for which `keep` returns `true`, visiting each element exactly once.
///
/// Rejected elements are removed with [`Vec::swap_remove`] (O(1) each), so the relative order
/// of the surviving elements is NOT preserved. `keep` may mutate the element it is given. Since
/// it is `FnMut`, it may also carry state across calls (e.g. counting or collecting).
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut keep: F) {
    let mut i = 0;
    while let Some(item) = vec.get_mut(i) {
        if keep(item) {
            i += 1;
        } else {
            // The last element moves into slot `i`; don't advance so it is visited next.
            vec.swap_remove(i);
        }
    }
}
105
/// One announcement delivered to an [`OriginConsumer`]: the path with the consumer's prefix
/// stripped, and `Some(broadcast)` when it is (re)announced or `None` when it is removed.
type ConsumerUpdate = (String, Option<BroadcastConsumer>);
108
/// Producer-side handle for one subscription created by `consume_prefix`.
struct ConsumerState {
    /// Only paths starting with this prefix are forwarded, with the prefix stripped.
    prefix: String,
    /// Sending half of the channel read by the matching [`OriginConsumer`].
    updates: mpsc::UnboundedSender<ConsumerUpdate>,
}
113
114impl ConsumerState {
115 pub fn insert(&mut self, path: &str, consumer: &BroadcastConsumer) -> bool {
117 if let Some(suffix) = path.strip_prefix(&self.prefix) {
118 let update = (suffix.to_string(), Some(consumer.clone()));
119 self.updates.send(update).is_ok()
120 } else {
121 !self.updates.is_closed()
122 }
123 }
124
125 pub fn remove(&mut self, path: &str) -> bool {
126 if let Some(suffix) = path.strip_prefix(&self.prefix) {
127 let update = (suffix.to_string(), None);
128 self.updates.send(update).is_ok()
129 } else {
130 !self.updates.is_closed()
131 }
132 }
133}
134
/// Publishes broadcasts by path and announces them to any number of [`OriginConsumer`]s.
///
/// Clones share the same underlying state: `publish_prefix` publishes through a clone of `self`.
#[derive(Clone, Default)]
pub struct OriginProducer {
    state: Lock<ProducerState>,
}
140
141impl OriginProducer {
142 pub fn new() -> Self {
143 Self {
144 state: Lock::new(ProducerState {
145 active: HashMap::new(),
146 consumers: Vec::new(),
147 }),
148 }
149 }
150
151 pub fn publish<S: ToString>(&mut self, path: S, broadcast: BroadcastConsumer) {
158 let path = path.to_string();
159 self.state.lock().publish(path.clone(), broadcast.clone());
160
161 let state = self.state.clone().downgrade();
162
163 web_async::spawn(async move {
165 broadcast.closed().await;
166 if let Some(state) = state.upgrade() {
167 state.lock().remove(path, broadcast);
168 }
169 });
170 }
171
172 pub fn publish_all(&mut self, broadcasts: OriginConsumer) {
174 self.publish_prefix("", broadcasts);
175 }
176
177 pub fn publish_prefix(&mut self, prefix: &str, mut broadcasts: OriginConsumer) {
179 let mut this = self.clone();
181
182 let prefix = match prefix {
184 "" => None,
185 prefix => Some(prefix.to_string()),
186 };
187
188 web_async::spawn(async move {
189 while let Some((suffix, broadcast)) = broadcasts.next().await {
190 let broadcast = match broadcast {
191 Some(broadcast) => broadcast,
192 None => continue,
195 };
196
197 let path = match &prefix {
198 Some(prefix) => format!("{}{}", prefix, suffix),
199 None => suffix,
200 };
201
202 this.publish(path, broadcast);
203 }
204 });
205 }
206
207 pub fn consume(&self, path: &str) -> Option<BroadcastConsumer> {
211 self.state.lock().active.get(path).map(|b| b.active.clone())
212 }
213
214 pub fn consume_all(&self) -> OriginConsumer {
216 self.consume_prefix("")
217 }
218
219 pub fn consume_prefix<S: ToString>(&self, prefix: S) -> OriginConsumer {
221 let mut state = self.state.lock();
222
223 let (tx, rx) = mpsc::unbounded_channel();
224 let mut consumer = ConsumerState {
225 prefix: prefix.to_string(),
226 updates: tx,
227 };
228
229 for (prefix, broadcast) in &state.active {
230 consumer.insert(prefix, &broadcast.active);
231 }
232 state.consumers.push(consumer);
233
234 OriginConsumer::new(rx)
235 }
236
237 pub async fn unused(&self) {
241 while let Some(notify) = self.unused_inner() {
243 notify.closed().await;
244 }
245 }
246
247 fn unused_inner(&self) -> Option<mpsc::UnboundedSender<ConsumerUpdate>> {
249 let mut state = self.state.lock();
250
251 while let Some(consumer) = state.consumers.last() {
252 if !consumer.updates.is_closed() {
253 return Some(consumer.updates.clone());
254 }
255
256 state.consumers.pop();
257 }
258
259 None
260 }
261}
262
/// Receives path announcements from an [`OriginProducer`].
///
/// Created by `OriginProducer::consume_prefix` / `consume_all`.
pub struct OriginConsumer {
    /// Receiving half of the channel fed by the producer's `ConsumerState`.
    updates: mpsc::UnboundedReceiver<ConsumerUpdate>,
}
267
268impl OriginConsumer {
269 fn new(updates: mpsc::UnboundedReceiver<ConsumerUpdate>) -> Self {
270 Self { updates }
271 }
272
273 pub async fn next(&mut self) -> Option<ConsumerUpdate> {
278 self.updates.recv().await
279 }
280}
281
282#[cfg(test)]
283use futures::FutureExt;
284
#[cfg(test)]
impl OriginConsumer {
    /// Asserts that an announcement of `broadcast` under `path` is immediately available.
    pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
        let (next_path, next_broadcast) = self
            .next()
            .now_or_never()
            .expect("next blocked")
            .expect("no next");
        assert_eq!(next_path, path, "wrong path");
        let announced = next_broadcast.unwrap();
        assert!(announced.is_clone(broadcast), "should be the same broadcast");
    }

    /// Asserts that a removal of `path` is immediately available.
    pub fn assert_next_none(&mut self, path: &str) {
        let (next_path, next_broadcast) = self
            .next()
            .now_or_never()
            .expect("next blocked")
            .expect("no next");
        assert_eq!(next_path, path, "wrong path");
        assert!(next_broadcast.is_none(), "should be unannounced");
    }

    /// Asserts that no update is ready yet.
    pub fn assert_next_wait(&mut self) {
        let polled = self.next().now_or_never();
        assert!(polled.is_none(), "next should block");
    }

    /// Asserts that the channel has been closed and fully drained.
    pub fn assert_next_closed(&mut self) {
        let polled = self.next().now_or_never().expect("next blocked");
        assert!(polled.is_none(), "next should be closed");
    }
}
310
#[cfg(test)]
mod tests {
    use crate::BroadcastProducer;

    use super::*;

    /// Sleeps briefly so spawned close-watcher tasks get a chance to run.
    async fn settle() {
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
    }

    #[tokio::test]
    async fn test_announce() {
        let mut origin = OriginProducer::new();
        let first = BroadcastProducer::new();
        let second = BroadcastProducer::new();

        // Subscribed before anything is published: nothing to report yet.
        let mut early = origin.consume_all();
        early.assert_next_wait();

        origin.publish("test1", first.consume());

        early.assert_next("test1", &first.consume());
        early.assert_next_wait();

        // Subscribed after "test1": it is replayed when the subscription is created.
        let mut middle = origin.consume_all();

        origin.publish("test2", second.consume());

        early.assert_next("test2", &second.consume());
        early.assert_next_wait();

        middle.assert_next("test1", &first.consume());
        middle.assert_next("test2", &second.consume());
        middle.assert_next_wait();

        // Dropping a broadcast's producer unannounces its path for every subscriber.
        drop(first);
        settle().await;

        early.assert_next_none("test1");
        middle.assert_next_none("test1");
        early.assert_next_wait();
        middle.assert_next_wait();

        let mut late = origin.consume_all();
        late.assert_next("test2", &second.consume());
        late.assert_next_wait();

        // Dropping the origin unannounces everything and then closes every subscription.
        drop(origin);
        settle().await;

        for consumer in [&mut early, &mut middle, &mut late] {
            consumer.assert_next_none("test2");
        }
        for consumer in [&mut early, &mut middle, &mut late] {
            consumer.assert_next_closed();
        }
    }

    #[tokio::test]
    async fn test_duplicate() {
        let mut origin = OriginProducer::new();
        let older = BroadcastProducer::new();
        let newer = BroadcastProducer::new();

        origin.publish("test", older.consume());
        origin.publish("test", newer.consume());
        assert!(origin.consume("test").is_some());

        // Closing the backup leaves the path published.
        drop(older);
        settle().await;
        assert!(origin.consume("test").is_some());

        // Closing the last broadcast unpublishes the path.
        drop(newer);
        settle().await;
        assert!(origin.consume("test").is_none());
    }

    #[tokio::test]
    async fn test_duplicate_reverse() {
        let mut origin = OriginProducer::new();
        let older = BroadcastProducer::new();
        let newer = BroadcastProducer::new();

        origin.publish("test", older.consume());
        origin.publish("test", newer.consume());
        assert!(origin.consume("test").is_some());

        // Closing the active broadcast promotes the backup.
        drop(newer);
        settle().await;
        assert!(origin.consume("test").is_some());

        // Closing the promoted broadcast unpublishes the path.
        drop(older);
        settle().await;
        assert!(origin.consume("test").is_none());
    }
}
425}