1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::{Lock, LockWeak};
4
5use super::BroadcastConsumer;
6use crate::{Path, PathRef};
7
/// The broadcasts published under a single path.
///
/// Only `active` is announced to consumers; `backup` holds broadcasts that were displaced by a
/// later publish to the same path, one of which replaces `active` when it is removed.
struct BroadcastState {
    /// The broadcast currently announced for this path.
    active: BroadcastConsumer,
    /// Displaced broadcasts; the last entry is promoted when `active` is removed.
    backup: Vec<BroadcastConsumer>,
}
13
/// State shared by an [`OriginProducer`], its clones and the producers it derives.
#[derive(Default)]
struct ProducerState {
    /// Every published full path (producer prefix included) and its broadcasts.
    active: HashMap<Path, BroadcastState>,
    /// Live subscriptions; entries whose receiver has gone away are pruned lazily whenever a
    /// notification fails or `unused` inspects the list.
    consumers: Vec<ConsumerState>,
}
19
20impl ProducerState {
21 fn publish(&mut self, path: Path, broadcast: BroadcastConsumer) -> bool {
23 let mut unique = true;
24
25 match self.active.entry(path.clone()) {
26 hash_map::Entry::Occupied(mut entry) => {
27 let state = entry.get_mut();
28 if state.active.is_clone(&broadcast) {
29 return false;
31 }
32
33 let old = state.active.clone();
35 state.active = broadcast.clone();
36
37 let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
40 if let Some(pos) = pos {
41 state.backup[pos] = old;
42
43 unique = false;
45 } else {
46 state.backup.push(old);
47 }
48
49 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
51 }
52 hash_map::Entry::Vacant(entry) => {
53 entry.insert(BroadcastState {
54 active: broadcast.clone(),
55 backup: Vec::new(),
56 });
57 }
58 };
59
60 retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
61
62 unique
63 }
64
65 fn remove(&mut self, path: Path, broadcast: BroadcastConsumer) {
66 let mut entry = match self.active.entry(path.clone()) {
67 hash_map::Entry::Occupied(entry) => entry,
68 hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
69 };
70
71 let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
73 if let Some(pos) = pos {
74 entry.get_mut().backup.remove(pos);
75 return;
77 }
78
79 assert!(entry.get().active.is_clone(&broadcast));
81
82 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
83
84 if let Some(active) = entry.get_mut().backup.pop() {
86 entry.get_mut().active = active;
87 retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &entry.get().active));
88 } else {
89 entry.remove();
91 }
92 }
93}
94
95impl Drop for ProducerState {
96 fn drop(&mut self) {
97 for (path, _) in self.active.drain() {
98 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
99 }
100 }
101}
102
/// Keeps only the elements for which `f` returns `true`, giving `f` mutable access to each.
///
/// Unlike [`Vec::retain_mut`], this does not preserve element order: a rejected element is
/// replaced by the last element via [`Vec::swap_remove`], so each removal is O(1). The element
/// moved into slot `i` is visited next, so every element is still checked exactly once.
///
/// `f` only needs to be `FnMut`, so stateful predicates (e.g. counters) are accepted.
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut f: F) {
    let mut i = 0;
    while let Some(item) = vec.get_mut(i) {
        if f(item) {
            i += 1;
        } else {
            vec.swap_remove(i);
        }
    }
}
114
/// A change to the set of broadcasts visible to an [`OriginConsumer`].
pub struct OriginUpdate {
    /// The broadcast path with the consumer's prefix stripped.
    pub suffix: Path,
    /// `Some` when a broadcast is announced at `suffix`, `None` when it is unannounced.
    pub active: Option<BroadcastConsumer>,
}
121
/// The producer-side half of an [`OriginConsumer`] subscription.
struct ConsumerState {
    /// Only paths under this prefix are forwarded to the consumer.
    prefix: Path,
    /// Sending half of the channel read by the matching [`OriginConsumer`].
    updates: mpsc::UnboundedSender<OriginUpdate>,
}
126
127impl ConsumerState {
128 pub fn insert<'a>(&mut self, path: impl Into<PathRef<'a>>, consumer: &BroadcastConsumer) -> bool {
130 let path_ref = path.into();
131
132 if let Some(suffix) = path_ref.to_owned().strip_prefix(&self.prefix) {
133 let update = OriginUpdate {
135 suffix: suffix.into(),
136 active: Some(consumer.clone()),
137 };
138 return self.updates.send(update).is_ok();
139 }
140
141 !self.updates.is_closed()
142 }
143
144 pub fn remove<'a>(&mut self, path: impl Into<PathRef<'a>>) -> bool {
145 let path_ref = path.into();
146
147 if let Some(suffix) = path_ref.to_owned().strip_prefix(&self.prefix) {
148 let update = OriginUpdate {
150 suffix: suffix.into(),
151 active: None,
152 };
153 return self.updates.send(update).is_ok();
154 }
155
156 !self.updates.is_closed()
157 }
158}
159
/// Publishes broadcasts under paths and hands out consumers that are notified about them.
///
/// Every path is joined onto `prefix`. Producers returned by
/// [`OriginProducer::publish_prefix`] reuse the same `state`, so their broadcasts are visible
/// to each other's consumers.
#[derive(Clone, Default)]
pub struct OriginProducer {
    /// Prepended to every path passed to this producer.
    prefix: Path,
    /// Shared registry; consumers only hold a weak reference to it.
    state: Lock<ProducerState>,
}
166
167impl OriginProducer {
168 pub fn new<T: Into<Path>>(prefix: T) -> Self {
170 let prefix = prefix.into();
171 Self {
172 state: Lock::new(ProducerState {
173 active: HashMap::new(),
174 consumers: Vec::new(),
175 }),
176 prefix,
177 }
178 }
179
180 pub fn publish<'a>(&mut self, suffix: impl Into<PathRef<'a>>, broadcast: BroadcastConsumer) {
187 let path = self.prefix.join(suffix.into());
188
189 if !self.state.lock().publish(path.clone(), broadcast.clone()) {
190 tracing::warn!(?path, "duplicate publish");
192 return;
193 }
194
195 let state = self.state.clone().downgrade();
196
197 web_async::spawn(async move {
199 broadcast.closed().await;
200 if let Some(state) = state.upgrade() {
201 state.lock().remove(path, broadcast);
202 }
203 });
204 }
205
206 pub fn publish_prefix<'a>(&mut self, prefix: impl Into<PathRef<'a>>) -> OriginProducer {
208 OriginProducer {
209 prefix: self.prefix.join(prefix.into()),
210 state: self.state.clone(),
211 }
212 }
213
214 pub fn consume<'a>(&self, suffix: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
218 let path = self.prefix.join(suffix.into());
219 self.state.lock().active.get(&path).map(|b| b.active.clone())
220 }
221
222 pub fn consume_all(&self) -> OriginConsumer {
224 self.consume_prefix("")
225 }
226
227 pub fn consume_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> OriginConsumer {
232 let prefix = self.prefix.join(prefix.into());
234
235 let mut state = self.state.lock();
236
237 let (tx, rx) = mpsc::unbounded_channel();
238 let mut consumer = ConsumerState {
239 prefix: prefix.clone(),
240 updates: tx,
241 };
242
243 for (path, broadcast) in &state.active {
244 consumer.insert(path, &broadcast.active);
245 }
246 state.consumers.push(consumer);
247
248 OriginConsumer {
249 prefix,
250 updates: rx,
251 producer: self.state.clone().downgrade(),
252 }
253 }
254
255 pub async fn unused(&self) {
259 while let Some(notify) = self.unused_inner() {
261 notify.closed().await;
262 }
263 }
264
265 fn unused_inner(&self) -> Option<mpsc::UnboundedSender<OriginUpdate>> {
267 let mut state = self.state.lock();
268
269 while let Some(consumer) = state.consumers.last() {
270 if !consumer.updates.is_closed() {
271 return Some(consumer.updates.clone());
272 }
273
274 state.consumers.pop();
275 }
276
277 None
278 }
279
280 pub fn prefix(&self) -> &Path {
281 &self.prefix
282 }
283}
284
/// Receives announcements for broadcasts published under a prefix.
///
/// Holds only a weak reference to the producer state. Once that state is dropped, the
/// remaining unannouncements are delivered and then [`OriginConsumer::next`] returns `None`.
pub struct OriginConsumer {
    /// Weak handle used for lookups and for creating further consumers.
    producer: LockWeak<ProducerState>,
    /// Updates pushed by the producer for paths under `prefix`.
    updates: mpsc::UnboundedReceiver<OriginUpdate>,
    /// The full prefix (including the producer's prefix) this consumer covers.
    prefix: Path,
}
292
293impl OriginConsumer {
294 pub async fn next(&mut self) -> Option<OriginUpdate> {
301 self.updates.recv().await
302 }
303
304 pub fn consume<'a>(&self, suffix: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
309 let path = self.prefix.join(suffix.into());
310
311 let state = self.producer.upgrade()?;
312 let state = state.lock();
313 state.active.get(&path).map(|b| b.active.clone())
314 }
315
316 pub fn consume_all(&self) -> OriginConsumer {
317 self.consume_prefix("")
318 }
319
320 pub fn consume_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> OriginConsumer {
321 let prefix = self.prefix.join(prefix.into());
323
324 let (tx, rx) = mpsc::unbounded_channel();
325
326 let mut consumer = ConsumerState {
328 prefix: prefix.clone(),
329 updates: tx,
330 };
331
332 if let Some(state) = self.producer.upgrade() {
333 let mut state = state.lock();
334
335 for (path, broadcast) in &state.active {
336 consumer.insert(path, &broadcast.active);
337 }
338
339 state.consumers.push(consumer);
340 }
341
342 OriginConsumer {
343 prefix,
344 updates: rx,
345 producer: self.producer.clone(),
346 }
347 }
348
349 pub fn prefix(&self) -> &Path {
350 &self.prefix
351 }
352}
353
354impl Clone for OriginConsumer {
355 fn clone(&self) -> Self {
356 self.consume_all()
357 }
358}
359
360#[cfg(test)]
361use futures::FutureExt;
362
#[cfg(test)]
impl OriginConsumer {
    /// Asserts that the next queued update announces `broadcast` at `path`.
    ///
    /// # Panics
    ///
    /// Panics if no update is ready, the channel is closed, the path differs, the update is an
    /// unannounce, or a different broadcast was announced.
    pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
        let next = self.next().now_or_never().expect("next blocked").expect("no next");
        assert_eq!(next.suffix.as_str(), path, "wrong path");
        // An unannounce here is a test failure; say so instead of a bare `unwrap` panic.
        let active = next.active.expect("should be announced");
        assert!(active.is_clone(broadcast), "should be the same broadcast");
    }

    /// Asserts that the next queued update unannounces `path`.
    pub fn assert_next_none(&mut self, path: &str) {
        let next = self.next().now_or_never().expect("next blocked").expect("no next");
        assert_eq!(next.suffix.as_str(), path, "wrong path");
        assert!(next.active.is_none(), "should be unannounced");
    }

    /// Asserts that no update is ready yet and the channel is still open.
    pub fn assert_next_wait(&mut self) {
        assert!(self.next().now_or_never().is_none(), "next should block");
    }

    /// Asserts that the channel is closed and fully drained.
    pub fn assert_next_closed(&mut self) {
        assert!(
            self.next().now_or_never().expect("next blocked").is_none(),
            "next should be closed"
        );
    }
}
388
#[cfg(test)]
mod tests {
    use crate::BroadcastProducer;

    use super::*;

    /// Sleeps briefly so the spawned close-watcher tasks get a chance to run.
    async fn settle() {
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
    }

    #[tokio::test]
    async fn test_announce() {
        let mut producer = OriginProducer::default();
        let broadcast1 = BroadcastProducer::new();
        let broadcast2 = BroadcastProducer::new();

        // Nothing has been published yet.
        let mut consumer1 = producer.consume_all();
        consumer1.assert_next_wait();

        producer.publish("test1", broadcast1.consume());

        consumer1.assert_next("test1", &broadcast1.consume());
        consumer1.assert_next_wait();

        // A late subscriber is replayed what already exists.
        let mut consumer2 = producer.consume_all();

        producer.publish("test2", broadcast2.consume());

        consumer1.assert_next("test2", &broadcast2.consume());
        consumer1.assert_next_wait();

        consumer2.assert_next("test1", &broadcast1.consume());
        consumer2.assert_next("test2", &broadcast2.consume());
        consumer2.assert_next_wait();

        // Closing a broadcast unannounces it for every subscriber.
        drop(broadcast1);
        settle().await;

        consumer1.assert_next_none("test1");
        consumer2.assert_next_none("test1");
        consumer1.assert_next_wait();
        consumer2.assert_next_wait();

        let mut consumer3 = producer.consume_all();
        consumer3.assert_next("test2", &broadcast2.consume());
        consumer3.assert_next_wait();

        // Dropping the producer unannounces what is left, then closes every consumer.
        drop(producer);
        settle().await;

        for consumer in [&mut consumer1, &mut consumer2, &mut consumer3] {
            consumer.assert_next_none("test2");
        }
        for consumer in [&mut consumer1, &mut consumer2, &mut consumer3] {
            consumer.assert_next_closed();
        }
    }

    #[tokio::test]
    async fn test_duplicate() {
        let mut producer = OriginProducer::default();
        let broadcast1 = BroadcastProducer::new();
        let broadcast2 = BroadcastProducer::new();

        producer.publish("test", broadcast1.consume());
        producer.publish("test", broadcast2.consume());
        assert!(producer.consume("test").is_some());

        // The displaced broadcast closes first; the active one remains.
        drop(broadcast1);
        settle().await;
        assert!(producer.consume("test").is_some());

        drop(broadcast2);
        settle().await;
        assert!(producer.consume("test").is_none());
    }

    #[tokio::test]
    async fn test_duplicate_reverse() {
        let mut producer = OriginProducer::default();
        let broadcast1 = BroadcastProducer::new();
        let broadcast2 = BroadcastProducer::new();

        producer.publish("test", broadcast1.consume());
        producer.publish("test", broadcast2.consume());
        assert!(producer.consume("test").is_some());

        // The active broadcast closes first; the backup is promoted.
        drop(broadcast2);
        settle().await;
        assert!(producer.consume("test").is_some());

        drop(broadcast1);
        settle().await;
        assert!(producer.consume("test").is_none());
    }

    #[tokio::test]
    async fn test_double_publish() {
        let mut producer = OriginProducer::default();
        let broadcast = BroadcastProducer::new();

        // Publishing the same broadcast twice must not leave a stale entry behind.
        producer.publish("test", broadcast.consume());
        producer.publish("test", broadcast.consume());

        assert!(producer.consume("test").is_some());

        drop(broadcast);
        settle().await;
        assert!(producer.consume("test").is_none());
    }
}