1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::Lock;
4
5use super::BroadcastConsumer;
6
/// Every broadcast currently published at a single path.
struct BroadcastState {
    /// The broadcast that consumers of this path are given.
    active: BroadcastConsumer,
    /// Other broadcasts published at the same path. When `active` is removed, the last
    /// entry of this list is popped and becomes the new `active`.
    backup: Vec<BroadcastConsumer>,
}
12
/// Mutable state shared by all handles of an `OriginProducer`.
#[derive(Default)]
struct ProducerState {
    /// Published broadcasts, keyed by their full path.
    active: HashMap<String, BroadcastState>,
    /// Subscribers registered by `consume_prefix`; each is told about every
    /// (un)announcement whose path matches its prefix.
    consumers: Vec<ConsumerState>,
}
18
19impl ProducerState {
20 fn publish(&mut self, path: String, broadcast: BroadcastConsumer) -> bool {
22 let mut unique = true;
23
24 match self.active.entry(path.clone()) {
25 hash_map::Entry::Occupied(mut entry) => {
26 let state = entry.get_mut();
27 if state.active.is_clone(&broadcast) {
28 return false;
30 }
31
32 let old = state.active.clone();
34 state.active = broadcast.clone();
35
36 let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
39 if let Some(pos) = pos {
40 state.backup[pos] = old;
41
42 unique = false;
44 } else {
45 state.backup.push(old);
46 }
47
48 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
50 }
51 hash_map::Entry::Vacant(entry) => {
52 entry.insert(BroadcastState {
53 active: broadcast.clone(),
54 backup: Vec::new(),
55 });
56 }
57 };
58
59 retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
60
61 unique
62 }
63
64 fn remove(&mut self, path: String, broadcast: BroadcastConsumer) {
65 let mut entry = match self.active.entry(path) {
66 hash_map::Entry::Occupied(entry) => entry,
67 hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
68 };
69
70 let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
72 if let Some(pos) = pos {
73 entry.get_mut().backup.remove(pos);
74 return;
76 }
77
78 assert!(entry.get().active.is_clone(&broadcast));
80
81 retain_mut_unordered(&mut self.consumers, |c| c.remove(entry.key()));
82
83 if let Some(active) = entry.get_mut().backup.pop() {
85 entry.get_mut().active = active;
86 retain_mut_unordered(&mut self.consumers, |c| c.insert(entry.key(), &entry.get().active));
87 } else {
88 entry.remove();
90 }
91 }
92}
93
94impl Drop for ProducerState {
95 fn drop(&mut self) {
96 for (path, _) in self.active.drain() {
97 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
98 }
99 }
100}
101
/// Keeps only the elements of `vec` for which `f` returns `true`, visiting each element
/// exactly once.
///
/// Unlike [`Vec::retain_mut`], rejected elements are removed with [`Vec::swap_remove`].
/// Each removal is O(1), but the relative order of the surviving elements is not
/// preserved. `f` may carry mutable state (`FnMut`).
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut f: F) {
    let mut i = 0;
    while let Some(item) = vec.get_mut(i) {
        if f(item) {
            i += 1;
        } else {
            // The former last element now sits at `i`; it is examined on the next pass.
            vec.swap_remove(i);
        }
    }
}
113
/// One update for a consumer: the path with the consumer's prefix stripped, paired with
/// the broadcast now announced there, or `None` when the path has been unannounced.
type ConsumerUpdate = (String, Option<BroadcastConsumer>);
116
/// Producer-side handle for one subscriber registered by `consume_prefix`.
struct ConsumerState {
    /// Only paths starting with this prefix are forwarded, and the prefix is stripped
    /// from the path before sending.
    prefix: String,
    /// Sending half of the channel read by the matching `OriginConsumer`.
    updates: mpsc::UnboundedSender<ConsumerUpdate>,
}
121
122impl ConsumerState {
123 pub fn insert(&mut self, path: &str, consumer: &BroadcastConsumer) -> bool {
125 if let Some(suffix) = path.strip_prefix(&self.prefix) {
126 let update = (suffix.to_string(), Some(consumer.clone()));
127 self.updates.send(update).is_ok()
128 } else {
129 !self.updates.is_closed()
130 }
131 }
132
133 pub fn remove(&mut self, path: &str) -> bool {
134 if let Some(suffix) = path.strip_prefix(&self.prefix) {
135 let update = (suffix.to_string(), None);
136 self.updates.send(update).is_ok()
137 } else {
138 !self.updates.is_closed()
139 }
140 }
141}
142
/// Publishes broadcasts under string paths and fans announcements and unannouncements
/// out to consumers created with `consume_*`.
///
/// NOTE(review): clones appear to share one state, since `Lock` is used with
/// `downgrade`/`upgrade`. Confirm that `Lock` is reference counted.
#[derive(Clone, Default)]
pub struct OriginProducer {
    state: Lock<ProducerState>,
}
148
149impl OriginProducer {
150 pub fn new() -> Self {
151 Self {
152 state: Lock::new(ProducerState {
153 active: HashMap::new(),
154 consumers: Vec::new(),
155 }),
156 }
157 }
158
159 pub fn publish<S: ToString>(&mut self, path: S, broadcast: BroadcastConsumer) {
166 let path = path.to_string();
167
168 if !self.state.lock().publish(path.clone(), broadcast.clone()) {
169 tracing::warn!(?path, "duplicate publish");
171 return;
172 }
173
174 let state = self.state.clone().downgrade();
175
176 web_async::spawn(async move {
178 broadcast.closed().await;
179 if let Some(state) = state.upgrade() {
180 state.lock().remove(path, broadcast);
181 }
182 });
183 }
184
185 pub fn publish_all(&mut self, broadcasts: OriginConsumer) {
187 self.publish_prefix("", broadcasts);
188 }
189
190 pub fn publish_prefix(&mut self, prefix: &str, mut broadcasts: OriginConsumer) {
192 let mut this = self.clone();
194
195 let prefix = match prefix {
197 "" => None,
198 prefix => Some(prefix.to_string()),
199 };
200
201 web_async::spawn(async move {
202 while let Some((suffix, broadcast)) = broadcasts.next().await {
203 let broadcast = match broadcast {
204 Some(broadcast) => broadcast,
205 None => continue,
208 };
209
210 let path = match &prefix {
211 Some(prefix) => format!("{}{}", prefix, suffix),
212 None => suffix,
213 };
214
215 this.publish(path, broadcast);
216 }
217 });
218 }
219
220 pub fn consume(&self, path: &str) -> Option<BroadcastConsumer> {
224 self.state.lock().active.get(path).map(|b| b.active.clone())
225 }
226
227 pub fn consume_all(&self) -> OriginConsumer {
229 self.consume_prefix("")
230 }
231
232 pub fn consume_prefix<S: ToString>(&self, prefix: S) -> OriginConsumer {
234 let mut state = self.state.lock();
235
236 let (tx, rx) = mpsc::unbounded_channel();
237 let mut consumer = ConsumerState {
238 prefix: prefix.to_string(),
239 updates: tx,
240 };
241
242 for (prefix, broadcast) in &state.active {
243 consumer.insert(prefix, &broadcast.active);
244 }
245 state.consumers.push(consumer);
246
247 OriginConsumer::new(rx)
248 }
249
250 pub async fn unused(&self) {
254 while let Some(notify) = self.unused_inner() {
256 notify.closed().await;
257 }
258 }
259
260 fn unused_inner(&self) -> Option<mpsc::UnboundedSender<ConsumerUpdate>> {
262 let mut state = self.state.lock();
263
264 while let Some(consumer) = state.consumers.last() {
265 if !consumer.updates.is_closed() {
266 return Some(consumer.updates.clone());
267 }
268
269 state.consumers.pop();
270 }
271
272 None
273 }
274}
275
/// Stream of `(path suffix, Option<broadcast>)` updates from an `OriginProducer`.
/// `Some` announces a broadcast at the path; `None` unannounces it.
pub struct OriginConsumer {
    updates: mpsc::UnboundedReceiver<ConsumerUpdate>,
}
280
281impl OriginConsumer {
282 fn new(updates: mpsc::UnboundedReceiver<ConsumerUpdate>) -> Self {
283 Self { updates }
284 }
285
286 pub async fn next(&mut self) -> Option<ConsumerUpdate> {
291 self.updates.recv().await
292 }
293}
294
295#[cfg(test)]
296use futures::FutureExt;
297
#[cfg(test)]
impl OriginConsumer {
    /// Asserts that an update is ready right now and that it announces `broadcast` at `path`.
    pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
        let next = self.next().now_or_never().expect("next blocked").expect("no next");
        assert_eq!(next.0, path, "wrong path");
        // Report which path was unexpectedly unannounced instead of a bare `unwrap` panic.
        let announced = next
            .1
            .unwrap_or_else(|| panic!("expected {path} to be announced, but it was unannounced"));
        assert!(announced.is_clone(broadcast), "should be the same broadcast");
    }

    /// Asserts that an update is ready right now and that it unannounces `path`.
    pub fn assert_next_none(&mut self, path: &str) {
        let next = self.next().now_or_never().expect("next blocked").expect("no next");
        assert_eq!(next.0, path, "wrong path");
        assert!(next.1.is_none(), "should be unannounced");
    }

    /// Asserts that no update is ready yet.
    pub fn assert_next_wait(&mut self) {
        assert!(self.next().now_or_never().is_none(), "next should block");
    }

    /// Asserts that the update stream has ended.
    pub fn assert_next_closed(&mut self) {
        assert!(
            self.next().now_or_never().expect("next blocked").is_none(),
            "next should be closed"
        );
    }
}
323
#[cfg(test)]
mod tests {
    use crate::BroadcastProducer;

    use super::*;

    // Covers announcement fan-out to early and late consumers, unannouncement when a
    // broadcast is dropped, and teardown when the producer itself is dropped.
    #[tokio::test]
    async fn test_announce() {
        let mut producer = OriginProducer::new();
        let broadcast1 = BroadcastProducer::new();
        let broadcast2 = BroadcastProducer::new();

        // Nothing is published yet, so a fresh consumer has nothing to read.
        let mut consumer1 = producer.consume_all();
        consumer1.assert_next_wait();

        producer.publish("test1", broadcast1.consume());

        consumer1.assert_next("test1", &broadcast1.consume());
        consumer1.assert_next_wait();

        // consumer2 subscribes after test1 was published, so test1 is replayed to it.
        let mut consumer2 = producer.consume_all();

        producer.publish("test2", broadcast2.consume());

        consumer1.assert_next("test2", &broadcast2.consume());
        consumer1.assert_next_wait();

        consumer2.assert_next("test1", &broadcast1.consume());
        consumer2.assert_next("test2", &broadcast2.consume());
        consumer2.assert_next_wait();

        drop(broadcast1);

        // Yield so the watcher task spawned by `publish` can observe the close.
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;

        consumer1.assert_next_none("test1");
        consumer2.assert_next_none("test1");
        consumer1.assert_next_wait();
        consumer2.assert_next_wait();

        // Only test2 is still published, so that is all a new consumer gets replayed.
        let mut consumer3 = producer.consume_all();
        consumer3.assert_next("test2", &broadcast2.consume());
        consumer3.assert_next_wait();

        // The watchers hold only weak handles, so this drops the shared state. Its `Drop`
        // impl unannounces every path, and the consumers' senders go away with it.
        drop(producer);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;

        consumer1.assert_next_none("test2");
        consumer2.assert_next_none("test2");
        consumer3.assert_next_none("test2");

        consumer1.assert_next_closed();
        consumer2.assert_next_closed();
        consumer3.assert_next_closed();
    }

    // Two broadcasts on one path: broadcast2 becomes active and broadcast1 a backup.
    // Closing the backup first must leave the path published.
    #[tokio::test]
    async fn test_duplicate() {
        let mut producer = OriginProducer::new();
        let broadcast1 = BroadcastProducer::new();
        let broadcast2 = BroadcastProducer::new();

        producer.publish("test", broadcast1.consume());
        producer.publish("test", broadcast2.consume());
        assert!(producer.consume("test").is_some());

        drop(broadcast1);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert!(producer.consume("test").is_some());

        drop(broadcast2);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert!(producer.consume("test").is_none());
    }

    // Same setup, but the active broadcast closes first, so the backup is promoted.
    #[tokio::test]
    async fn test_duplicate_reverse() {
        let mut producer = OriginProducer::new();
        let broadcast1 = BroadcastProducer::new();
        let broadcast2 = BroadcastProducer::new();

        producer.publish("test", broadcast1.consume());
        producer.publish("test", broadcast2.consume());
        assert!(producer.consume("test").is_some());

        drop(broadcast2);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert!(producer.consume("test").is_some());

        drop(broadcast1);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert!(producer.consume("test").is_none());
    }

    // Publishing the same broadcast twice at one path is treated as a duplicate.
    // Closing it once must fully unpublish the path.
    #[tokio::test]
    async fn test_double_publish() {
        let mut producer = OriginProducer::new();
        let broadcast = BroadcastProducer::new();

        producer.publish("test", broadcast.consume());
        producer.publish("test", broadcast.consume());

        assert!(producer.consume("test").is_some());

        drop(broadcast);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert!(producer.consume("test").is_none());
    }
}