1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::{Lock, LockWeak};
4
5use super::BroadcastConsumer;
6use crate::{Path, PathRef};
7
8struct BroadcastState {
10 active: BroadcastConsumer,
11 backup: Vec<BroadcastConsumer>,
12}
13
14#[derive(Default)]
15struct ProducerState {
16 active: HashMap<Path, BroadcastState>,
17 consumers: Vec<ConsumerState>,
18}
19
20impl ProducerState {
21 fn publish(&mut self, path: Path, broadcast: BroadcastConsumer) -> bool {
23 let mut unique = true;
24
25 match self.active.entry(path.clone()) {
26 hash_map::Entry::Occupied(mut entry) => {
27 let state = entry.get_mut();
28 if state.active.is_clone(&broadcast) {
29 return false;
31 }
32
33 let old = state.active.clone();
35 state.active = broadcast.clone();
36
37 let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
40 if let Some(pos) = pos {
41 state.backup[pos] = old;
42
43 unique = false;
45 } else {
46 state.backup.push(old);
47 }
48
49 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
51 }
52 hash_map::Entry::Vacant(entry) => {
53 entry.insert(BroadcastState {
54 active: broadcast.clone(),
55 backup: Vec::new(),
56 });
57 }
58 };
59
60 retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
61
62 unique
63 }
64
65 fn remove(&mut self, path: Path, broadcast: BroadcastConsumer) {
66 let mut entry = match self.active.entry(path.clone()) {
67 hash_map::Entry::Occupied(entry) => entry,
68 hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
69 };
70
71 let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
73 if let Some(pos) = pos {
74 entry.get_mut().backup.remove(pos);
75 return;
77 }
78
79 assert!(entry.get().active.is_clone(&broadcast));
81
82 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
83
84 if let Some(active) = entry.get_mut().backup.pop() {
86 entry.get_mut().active = active;
87 retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &entry.get().active));
88 } else {
89 entry.remove();
91 }
92 }
93}
94
95impl Drop for ProducerState {
96 fn drop(&mut self) {
97 for (path, _) in self.active.drain() {
98 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
99 }
100 }
101}
102
/// Keeps only the elements of `vec` for which `f` returns `true`, in O(n) time.
///
/// Rejected elements are removed with [`Vec::swap_remove`], so unlike
/// [`Vec::retain_mut`] the relative order of the surviving elements is **not**
/// preserved. `f` is called exactly once for every element originally in `vec`
/// (an element swapped into a freed slot is examined next), and it may mutate
/// the element it is given.
///
/// `f` is `FnMut` so callers may pass stateful predicates; any `Fn` closure still works.
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut f: F) {
    let mut i = 0;
    while let Some(item) = vec.get_mut(i) {
        if f(item) {
            i += 1;
        } else {
            // Do not advance: the former last element now lives at `i` and still needs checking.
            vec.swap_remove(i);
        }
    }
}
114
115pub struct OriginUpdate {
118 pub suffix: Path,
119 pub active: Option<BroadcastConsumer>,
120}
121
122struct ConsumerState {
123 prefix: Path,
124 updates: mpsc::UnboundedSender<OriginUpdate>,
125}
126
127impl ConsumerState {
128 pub fn insert<'a>(&mut self, path: impl Into<PathRef<'a>>, consumer: &BroadcastConsumer) -> bool {
130 let path_ref = path.into();
131
132 if let Some(suffix) = path_ref.to_path().strip_prefix(&self.prefix) {
133 let update = OriginUpdate {
135 suffix: suffix.into(),
136 active: Some(consumer.clone()),
137 };
138 return self.updates.send(update).is_ok();
139 }
140
141 !self.updates.is_closed()
142 }
143
144 pub fn remove<'a>(&mut self, path: impl Into<PathRef<'a>>) -> bool {
145 let path_ref = path.into();
146
147 if let Some(suffix) = path_ref.to_path().strip_prefix(&self.prefix) {
148 let update = OriginUpdate {
150 suffix: suffix.into(),
151 active: None,
152 };
153 return self.updates.send(update).is_ok();
154 }
155
156 !self.updates.is_closed()
157 }
158}
159
160#[derive(Clone, Default)]
162pub struct OriginProducer {
163 prefix: Path,
164 state: Lock<ProducerState>,
165}
166
167impl OriginProducer {
168 pub fn new<T: Into<Path>>(prefix: T) -> Self {
170 let prefix = prefix.into();
171 Self {
172 state: Lock::new(ProducerState {
173 active: HashMap::new(),
174 consumers: Vec::new(),
175 }),
176 prefix,
177 }
178 }
179
180 pub fn publish<'a>(&mut self, suffix: impl Into<PathRef<'a>>, broadcast: BroadcastConsumer) {
187 let path = self.prefix.join(suffix.into());
188
189 if !self.state.lock().publish(path.clone(), broadcast.clone()) {
190 tracing::warn!(?path, "duplicate publish");
192 return;
193 }
194
195 let state = self.state.clone().downgrade();
196
197 web_async::spawn(async move {
199 broadcast.closed().await;
200 if let Some(state) = state.upgrade() {
201 state.lock().remove(path, broadcast);
202 }
203 });
204 }
205
206 pub fn publish_prefix<'a>(&mut self, prefix: impl Into<PathRef<'a>>) -> OriginProducer {
208 OriginProducer {
209 prefix: self.prefix.join(prefix.into()),
210 state: self.state.clone(),
211 }
212 }
213
214 pub fn consume<'a>(&self, suffix: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
218 let path = self.prefix.join(suffix.into());
219 self.state.lock().active.get(&path).map(|b| b.active.clone())
220 }
221
222 pub fn consume_all(&self) -> OriginConsumer {
224 self.consume_prefix("")
225 }
226
227 pub fn consume_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> OriginConsumer {
232 let prefix = self.prefix.join(prefix.into());
234
235 let mut state = self.state.lock();
236
237 let (tx, rx) = mpsc::unbounded_channel();
238 let mut consumer = ConsumerState {
239 prefix: prefix.clone(),
240 updates: tx,
241 };
242
243 for (path, broadcast) in &state.active {
244 consumer.insert(path, &broadcast.active);
245 }
246 state.consumers.push(consumer);
247
248 OriginConsumer {
249 prefix,
250 updates: rx,
251 producer: self.state.clone().downgrade(),
252 }
253 }
254
255 pub async fn unused(&self) {
259 while let Some(notify) = self.unused_inner() {
261 notify.closed().await;
262 }
263 }
264
265 fn unused_inner(&self) -> Option<mpsc::UnboundedSender<OriginUpdate>> {
267 let mut state = self.state.lock();
268
269 while let Some(consumer) = state.consumers.last() {
270 if !consumer.updates.is_closed() {
271 return Some(consumer.updates.clone());
272 }
273
274 state.consumers.pop();
275 }
276
277 None
278 }
279
280 pub fn prefix(&self) -> &Path {
281 &self.prefix
282 }
283}
284
285pub struct OriginConsumer {
287 producer: LockWeak<ProducerState>,
289 updates: mpsc::UnboundedReceiver<OriginUpdate>,
290 prefix: Path,
291}
292
293impl OriginConsumer {
294 pub async fn next(&mut self) -> Option<OriginUpdate> {
301 self.updates.recv().await
302 }
303
304 pub fn consume<'a>(&self, suffix: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
308 let path = self.prefix.join(suffix.into());
309
310 let state = self.producer.upgrade()?;
311 let state = state.lock();
312 state.active.get(&path).map(|b| b.active.clone())
313 }
314
315 pub fn consume_all(&self) -> OriginConsumer {
316 self.consume_prefix("")
317 }
318
319 pub fn consume_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> OriginConsumer {
320 let prefix = self.prefix.join(prefix.into());
322
323 let (tx, rx) = mpsc::unbounded_channel();
324
325 let mut consumer = ConsumerState {
327 prefix: prefix.clone(),
328 updates: tx,
329 };
330
331 if let Some(state) = self.producer.upgrade() {
332 let mut state = state.lock();
333
334 for (path, broadcast) in &state.active {
335 consumer.insert(path, &broadcast.active);
336 }
337
338 state.consumers.push(consumer);
339 }
340
341 OriginConsumer {
342 prefix,
343 updates: rx,
344 producer: self.producer.clone(),
345 }
346 }
347
348 pub fn prefix(&self) -> &Path {
349 &self.prefix
350 }
351}
352
353impl Clone for OriginConsumer {
354 fn clone(&self) -> Self {
355 self.consume_all()
356 }
357}
358
359#[cfg(test)]
360use futures::FutureExt;
361
#[cfg(test)]
impl OriginConsumer {
    /// Polls `next()` exactly once without waiting.
    fn poll_next_now(&mut self) -> Option<Option<OriginUpdate>> {
        self.next().now_or_never()
    }

    /// Asserts that the next queued update announces `broadcast` at `path`.
    pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
        let update = self.poll_next_now().expect("next blocked").expect("no next");
        assert_eq!(update.suffix.as_str(), path, "wrong path");
        let active = update.active.unwrap();
        assert!(active.is_clone(broadcast), "should be the same broadcast");
    }

    /// Asserts that the next queued update unannounces `path`.
    pub fn assert_next_none(&mut self, path: &str) {
        let update = self.poll_next_now().expect("next blocked").expect("no next");
        assert_eq!(update.suffix.as_str(), path, "wrong path");
        assert!(update.active.is_none(), "should be unannounced");
    }

    /// Asserts that no update is ready yet.
    pub fn assert_next_wait(&mut self) {
        let polled = self.poll_next_now();
        assert!(polled.is_none(), "next should block");
    }

    /// Asserts that the update stream has ended.
    pub fn assert_next_closed(&mut self) {
        let polled = self.poll_next_now().expect("next blocked");
        assert!(polled.is_none(), "next should be closed");
    }
}
387
#[cfg(test)]
mod tests {
    use crate::BroadcastProducer;

    use super::*;

    /// Sleeps briefly so spawned cleanup tasks get a chance to run.
    async fn settle() {
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
    }

    #[tokio::test]
    async fn test_announce() {
        let mut origin = OriginProducer::default();
        let first = BroadcastProducer::new();
        let second = BroadcastProducer::new();

        let mut early = origin.consume_all();
        early.assert_next_wait();

        origin.publish("test1", first.consume());

        early.assert_next("test1", &first.consume());
        early.assert_next_wait();

        // A consumer created later starts by replaying what is already active.
        let mut late = origin.consume_all();

        origin.publish("test2", second.consume());

        early.assert_next("test2", &second.consume());
        early.assert_next_wait();

        late.assert_next("test1", &first.consume());
        late.assert_next("test2", &second.consume());
        late.assert_next_wait();

        drop(first);
        settle().await;

        early.assert_next_none("test1");
        late.assert_next_none("test1");
        early.assert_next_wait();
        late.assert_next_wait();

        let mut latest = origin.consume_all();
        latest.assert_next("test2", &second.consume());
        latest.assert_next_wait();

        // Dropping the last producer unannounces everything and closes every consumer.
        drop(origin);
        settle().await;

        for consumer in [&mut early, &mut late, &mut latest] {
            consumer.assert_next_none("test2");
        }
        for consumer in [&mut early, &mut late, &mut latest] {
            consumer.assert_next_closed();
        }
    }

    #[tokio::test]
    async fn test_duplicate() {
        let mut origin = OriginProducer::default();
        let older = BroadcastProducer::new();
        let newer = BroadcastProducer::new();

        origin.publish("test", older.consume());
        origin.publish("test", newer.consume());
        assert!(origin.consume("test").is_some());

        // Closing the backup leaves the active broadcast in place.
        drop(older);
        settle().await;
        assert!(origin.consume("test").is_some());

        drop(newer);
        settle().await;
        assert!(origin.consume("test").is_none());
    }

    #[tokio::test]
    async fn test_duplicate_reverse() {
        let mut origin = OriginProducer::default();
        let older = BroadcastProducer::new();
        let newer = BroadcastProducer::new();

        origin.publish("test", older.consume());
        origin.publish("test", newer.consume());
        assert!(origin.consume("test").is_some());

        // Closing the active broadcast promotes the backup.
        drop(newer);
        settle().await;
        assert!(origin.consume("test").is_some());

        drop(older);
        settle().await;
        assert!(origin.consume("test").is_none());
    }

    #[tokio::test]
    async fn test_double_publish() {
        let mut origin = OriginProducer::default();
        let only = BroadcastProducer::new();

        origin.publish("test", only.consume());
        origin.publish("test", only.consume());
        assert!(origin.consume("test").is_some());

        drop(only);
        settle().await;
        assert!(origin.consume("test").is_none());
    }
}