1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::{Lock, LockWeak};
4
5use super::BroadcastConsumer;
6use crate::{Path, PathRef};
7
8struct BroadcastState {
10 active: BroadcastConsumer,
11 backup: Vec<BroadcastConsumer>,
12}
13
14#[derive(Default)]
15struct ProducerState {
16 active: HashMap<Path, BroadcastState>,
17 consumers: Vec<ConsumerState>,
18}
19
20impl ProducerState {
21 fn publish(&mut self, path: Path, broadcast: BroadcastConsumer) -> bool {
23 let mut unique = true;
24
25 match self.active.entry(path.clone()) {
26 hash_map::Entry::Occupied(mut entry) => {
27 let state = entry.get_mut();
28 if state.active.is_clone(&broadcast) {
29 return false;
31 }
32
33 let old = state.active.clone();
35 state.active = broadcast.clone();
36
37 let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
40 if let Some(pos) = pos {
41 state.backup[pos] = old;
42
43 unique = false;
45 } else {
46 state.backup.push(old);
47 }
48
49 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
51 }
52 hash_map::Entry::Vacant(entry) => {
53 entry.insert(BroadcastState {
54 active: broadcast.clone(),
55 backup: Vec::new(),
56 });
57 }
58 };
59
60 retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
61
62 unique
63 }
64
65 fn remove(&mut self, path: Path, broadcast: BroadcastConsumer) {
66 let mut entry = match self.active.entry(path.clone()) {
67 hash_map::Entry::Occupied(entry) => entry,
68 hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
69 };
70
71 let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
73 if let Some(pos) = pos {
74 entry.get_mut().backup.remove(pos);
75 return;
77 }
78
79 assert!(entry.get().active.is_clone(&broadcast));
81
82 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
83
84 if let Some(active) = entry.get_mut().backup.pop() {
86 entry.get_mut().active = active;
87 retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &entry.get().active));
88 } else {
89 entry.remove();
91 }
92 }
93}
94
95impl Drop for ProducerState {
96 fn drop(&mut self) {
97 for (path, _) in self.active.drain() {
98 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
99 }
100 }
101}
102
/// Keeps only the elements for which `f` returns `true`, without preserving order.
///
/// Rejected elements are removed with [`Vec::swap_remove`]. Each removal is O(1), but the last
/// element moves into the vacated slot. So, unlike [`Vec::retain_mut`], the relative order of
/// the kept elements may change.
///
/// `f` may mutate each element, and it may carry mutable state of its own (`FnMut`).
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut f: F) {
    let mut i = 0;
    while let Some(item) = vec.get_mut(i) {
        if f(item) {
            i += 1;
        } else {
            // The former last element now sits at `i`; examine it on the next iteration.
            vec.swap_remove(i);
        }
    }
}
114
115pub struct OriginUpdate {
118 pub suffix: Path,
119 pub active: Option<BroadcastConsumer>,
120}
121
122struct ConsumerState {
123 prefix: Path,
124 updates: mpsc::UnboundedSender<OriginUpdate>,
125}
126
127impl ConsumerState {
128 pub fn insert<'a>(&mut self, path: impl Into<PathRef<'a>>, consumer: &BroadcastConsumer) -> bool {
130 let path_ref = path.into();
131
132 if let Some(suffix) = path_ref.to_owned().strip_prefix(&self.prefix) {
133 let update = OriginUpdate {
135 suffix: suffix.into(),
136 active: Some(consumer.clone()),
137 };
138 return self.updates.send(update).is_ok();
139 }
140
141 !self.updates.is_closed()
142 }
143
144 pub fn remove<'a>(&mut self, path: impl Into<PathRef<'a>>) -> bool {
145 let path_ref = path.into();
146
147 if let Some(suffix) = path_ref.to_owned().strip_prefix(&self.prefix) {
148 let update = OriginUpdate {
150 suffix: suffix.into(),
151 active: None,
152 };
153 return self.updates.send(update).is_ok();
154 }
155
156 !self.updates.is_closed()
157 }
158}
159
160#[derive(Clone, Default)]
162pub struct OriginProducer {
163 root: Path,
165
166 prefix: Path,
170
171 state: Lock<ProducerState>,
172}
173
174impl OriginProducer {
175 pub fn new() -> Self {
176 Self::default()
177 }
178
179 pub fn publish<'a>(&mut self, path: impl Into<PathRef<'a>>, broadcast: BroadcastConsumer) {
186 let path = path.into();
187 let full = self.root.join(&self.prefix).join(&path);
188
189 if !self.state.lock().publish(full.clone(), broadcast.clone()) {
190 tracing::warn!(?path, "duplicate publish");
193 return;
194 }
195
196 let state = self.state.clone().downgrade();
197
198 web_async::spawn(async move {
200 broadcast.closed().await;
201 if let Some(state) = state.upgrade() {
202 state.lock().remove(full, broadcast);
203 }
204 });
205 }
206
207 pub fn publish_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> Self {
209 Self {
210 prefix: self.prefix.join(prefix),
211 state: self.state.clone(),
212 root: self.root.clone(),
213 }
214 }
215
216 pub fn consume<'a>(&self, path: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
220 let path = path.into();
221
222 let full = self.root.join(path);
223 self.state.lock().active.get(&full).map(|b| b.active.clone())
224 }
225
226 pub fn consume_all(&self) -> OriginConsumer {
228 self.consume_prefix("")
229 }
230
231 pub fn consume_prefix(&self, prefix: impl Into<Path>) -> OriginConsumer {
236 let prefix = prefix.into();
237 let full = self.root.join(&prefix);
238
239 let mut state = self.state.lock();
240
241 let (tx, rx) = mpsc::unbounded_channel();
242 let mut consumer = ConsumerState {
243 prefix: full,
244 updates: tx,
245 };
246
247 for (path, broadcast) in &state.active {
248 consumer.insert(path, &broadcast.active);
249 }
250 state.consumers.push(consumer);
251
252 OriginConsumer {
253 root: self.root.clone(),
254 prefix,
255 updates: rx,
256 producer: self.state.clone().downgrade(),
257 }
258 }
259
260 pub fn with_root(&self, root: impl Into<Path>) -> Self {
261 let root = root.into();
262
263 let prefix = match self.prefix.strip_prefix(&root) {
266 Some(prefix) => prefix.to_owned(),
267 None if self.prefix.is_empty() => Path::default(),
268 None => panic!("with_root doesn't match existing prefix"),
269 };
270
271 Self {
272 root: self.root.join(&root),
273 prefix,
274 state: self.state.clone(),
275 }
276 }
277
278 pub async fn unused(&self) {
282 while let Some(notify) = self.unused_inner() {
284 notify.closed().await;
285 }
286 }
287
288 fn unused_inner(&self) -> Option<mpsc::UnboundedSender<OriginUpdate>> {
290 let mut state = self.state.lock();
291
292 while let Some(consumer) = state.consumers.last() {
293 if !consumer.updates.is_closed() {
294 return Some(consumer.updates.clone());
295 }
296
297 state.consumers.pop();
298 }
299
300 None
301 }
302
303 pub fn root(&self) -> &Path {
304 &self.root
305 }
306
307 pub fn prefix(&self) -> &Path {
308 &self.prefix
309 }
310}
311
312pub struct OriginConsumer {
314 producer: LockWeak<ProducerState>,
316 updates: mpsc::UnboundedReceiver<OriginUpdate>,
317
318 root: Path,
320
321 prefix: Path,
323}
324
325impl OriginConsumer {
326 pub async fn next(&mut self) -> Option<OriginUpdate> {
333 self.updates.recv().await
334 }
335
336 pub fn consume<'a>(&self, path: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
341 let full = self.root.join(&self.prefix).join(path.into());
342
343 let state = self.producer.upgrade()?;
344 let state = state.lock();
345 state.active.get(&full).map(|b| b.active.clone())
346 }
347
348 pub fn consume_all(&self) -> OriginConsumer {
349 self.consume_prefix("")
350 }
351
352 pub fn consume_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> OriginConsumer {
353 let prefix = self.prefix.join(prefix);
355
356 let full = self.root.join(&prefix);
358
359 let (tx, rx) = mpsc::unbounded_channel();
360
361 let mut consumer = ConsumerState {
363 prefix: full,
364 updates: tx,
365 };
366
367 if let Some(state) = self.producer.upgrade() {
368 let mut state = state.lock();
369
370 for (path, broadcast) in &state.active {
371 consumer.insert(path, &broadcast.active);
372 }
373
374 state.consumers.push(consumer);
375 }
376
377 OriginConsumer {
378 root: self.root.clone(),
379 prefix,
380 updates: rx,
381 producer: self.producer.clone(),
382 }
383 }
384
385 pub fn root(&self) -> &Path {
386 &self.root
387 }
388
389 pub fn prefix(&self) -> &Path {
390 &self.prefix
391 }
392}
393
394impl Clone for OriginConsumer {
395 fn clone(&self) -> Self {
396 self.consume_all()
397 }
398}
399
400#[cfg(test)]
401use futures::FutureExt;
402
#[cfg(test)]
impl OriginConsumer {
    /// Returns the next queued update without waiting.
    ///
    /// Panics with "next blocked" if no update is ready, or with "no next" if the channel is
    /// closed.
    fn next_now(&mut self) -> OriginUpdate {
        self.next().now_or_never().expect("next blocked").expect("no next")
    }

    /// Asserts that the next update announces `broadcast` at `path`.
    pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
        let next = self.next_now();
        assert_eq!(next.suffix.as_str(), path, "wrong path");
        // Use `expect` so an unexpected unannouncement fails with a clear message instead of a
        // generic `unwrap` panic.
        let active = next.active.expect("should be announced");
        assert!(active.is_clone(broadcast), "should be the same broadcast");
    }

    /// Asserts that the next update unannounces `path`.
    pub fn assert_next_none(&mut self, path: &str) {
        let next = self.next_now();
        assert_eq!(next.suffix.as_str(), path, "wrong path");
        assert!(next.active.is_none(), "should be unannounced");
    }

    /// Asserts that no update is ready yet and the channel is still open.
    pub fn assert_next_wait(&mut self) {
        assert!(self.next().now_or_never().is_none(), "next should block");
    }

    /// Asserts that the channel is closed and fully drained.
    pub fn assert_next_closed(&mut self) {
        assert!(
            self.next().now_or_never().expect("next blocked").is_none(),
            "next should be closed"
        );
    }
}
428
#[cfg(test)]
mod tests {
    use crate::BroadcastProducer;

    use super::*;

    /// Sleeps briefly so the spawned cleanup tasks can observe closed broadcasts.
    async fn settle() {
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
    }

    #[tokio::test]
    async fn test_announce() {
        let mut producer = OriginProducer::default();
        let first = BroadcastProducer::new();
        let second = BroadcastProducer::new();

        // A subscription opened before anything is published starts out empty.
        let mut early = producer.consume_all();
        early.assert_next_wait();

        producer.publish("test1", first.consume());

        early.assert_next("test1", &first.consume());
        early.assert_next_wait();

        // A later subscription replays what is already active, then sees live updates.
        let mut late = producer.consume_all();

        producer.publish("test2", second.consume());

        early.assert_next("test2", &second.consume());
        early.assert_next_wait();

        late.assert_next("test1", &first.consume());
        late.assert_next("test2", &second.consume());
        late.assert_next_wait();

        // Closing a broadcast unannounces it for every subscription.
        drop(first);
        settle().await;

        early.assert_next_none("test1");
        late.assert_next_none("test1");
        early.assert_next_wait();
        late.assert_next_wait();

        let mut latest = producer.consume_all();
        latest.assert_next("test2", &second.consume());
        latest.assert_next_wait();

        // Dropping the producer unannounces everything and then closes every subscription.
        drop(producer);
        settle().await;

        for consumer in [&mut early, &mut late, &mut latest] {
            consumer.assert_next_none("test2");
        }
        for consumer in [&mut early, &mut late, &mut latest] {
            consumer.assert_next_closed();
        }
    }

    #[tokio::test]
    async fn test_duplicate() {
        let mut producer = OriginProducer::default();
        let older = BroadcastProducer::new();
        let newer = BroadcastProducer::new();

        producer.publish("test", older.consume());
        producer.publish("test", newer.consume());
        assert!(producer.consume("test").is_some());

        // The older broadcast is only a backup, so closing it keeps the path published.
        drop(older);
        settle().await;
        assert!(producer.consume("test").is_some());

        drop(newer);
        settle().await;
        assert!(producer.consume("test").is_none());
    }

    #[tokio::test]
    async fn test_duplicate_reverse() {
        let mut producer = OriginProducer::default();
        let older = BroadcastProducer::new();
        let newer = BroadcastProducer::new();

        producer.publish("test", older.consume());
        producer.publish("test", newer.consume());
        assert!(producer.consume("test").is_some());

        // Closing the active broadcast promotes the older backup.
        drop(newer);
        settle().await;
        assert!(producer.consume("test").is_some());

        drop(older);
        settle().await;
        assert!(producer.consume("test").is_none());
    }

    #[tokio::test]
    async fn test_double_publish() {
        let mut producer = OriginProducer::default();
        let broadcast = BroadcastProducer::new();

        // Publishing the same broadcast twice must not leave a stale entry behind.
        producer.publish("test", broadcast.consume());
        producer.publish("test", broadcast.consume());

        assert!(producer.consume("test").is_some());

        drop(broadcast);
        settle().await;
        assert!(producer.consume("test").is_none());
    }
}