1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::Lock;
4
5use super::BroadcastConsumer;
6
/// Every broadcast currently published under one path.
struct BroadcastState {
    /// The broadcast announced to consumers for this path.
    active: BroadcastConsumer,
    /// Other broadcasts published under the same path that are not currently announced.
    /// When `active` is removed, the last entry here is promoted to `active`.
    backup: Vec<BroadcastConsumer>,
}
12
/// Shared state behind an [`OriginProducer`] (held in a `Lock`).
#[derive(Default)]
struct ProducerState {
    /// Published broadcasts, keyed by their full path.
    active: HashMap<String, BroadcastState>,
    /// Registered consumers, notified whenever a path is announced or unannounced.
    consumers: Vec<ConsumerState>,
}
18
19impl ProducerState {
20 fn publish(&mut self, path: String, broadcast: BroadcastConsumer) -> bool {
22 let mut unique = true;
23
24 match self.active.entry(path.clone()) {
25 hash_map::Entry::Occupied(mut entry) => {
26 let state = entry.get_mut();
27 if state.active.is_clone(&broadcast) {
28 return false;
30 }
31
32 let old = state.active.clone();
34 state.active = broadcast.clone();
35
36 let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
39 if let Some(pos) = pos {
40 state.backup[pos] = old;
41
42 unique = false;
44 } else {
45 state.backup.push(old);
46 }
47
48 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
50 }
51 hash_map::Entry::Vacant(entry) => {
52 entry.insert(BroadcastState {
53 active: broadcast.clone(),
54 backup: Vec::new(),
55 });
56 }
57 };
58
59 retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
60
61 unique
62 }
63
64 fn remove(&mut self, path: String, broadcast: BroadcastConsumer) {
65 let mut entry = match self.active.entry(path) {
66 hash_map::Entry::Occupied(entry) => entry,
67 hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
68 };
69
70 let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
72 if let Some(pos) = pos {
73 entry.get_mut().backup.remove(pos);
74 return;
76 }
77
78 assert!(entry.get().active.is_clone(&broadcast));
80
81 retain_mut_unordered(&mut self.consumers, |c| c.remove(entry.key()));
82
83 if let Some(active) = entry.get_mut().backup.pop() {
85 entry.get_mut().active = active;
86 retain_mut_unordered(&mut self.consumers, |c| c.insert(entry.key(), &entry.get().active));
87 } else {
88 entry.remove();
90 }
91 }
92}
93
94impl Drop for ProducerState {
95 fn drop(&mut self) {
96 for (path, _) in self.active.drain() {
97 retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
98 }
99 }
100}
101
/// Keeps only the elements for which `f` returns `true`, without preserving order.
///
/// Rejected elements are removed with [`Vec::swap_remove`], so each removal is O(1) and the
/// whole pass is O(n). In exchange, surviving elements may be reordered. `f` is called
/// exactly once per element and may mutate it. `f` may also carry mutable state (`FnMut`).
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut f: F) {
    let mut i = 0;
    while let Some(item) = vec.get_mut(i) {
        if f(item) {
            i += 1;
        } else {
            // The last element moves into slot `i`; don't advance so it is visited next.
            vec.swap_remove(i);
        }
    }
}
113
114type ConsumerUpdate = (String, Option<BroadcastConsumer>);
116
117struct ConsumerState {
118 prefix: String,
119 exact: bool,
120 updates: mpsc::UnboundedSender<ConsumerUpdate>,
121}
122
123impl ConsumerState {
124 pub fn insert(&mut self, path: &str, consumer: &BroadcastConsumer) -> bool {
126 if self.exact {
127 if path == self.prefix {
128 let update = ("".to_string(), Some(consumer.clone()));
129 return self.updates.send(update).is_ok();
130 }
131 } else if let Some(suffix) = path.strip_prefix(&self.prefix) {
132 let update = (suffix.to_string(), Some(consumer.clone()));
133 return self.updates.send(update).is_ok();
134 }
135
136 !self.updates.is_closed()
137 }
138
139 pub fn remove(&mut self, path: &str) -> bool {
140 if self.exact {
141 if path == self.prefix {
142 let update = ("".to_string(), None);
143 return self.updates.send(update).is_ok();
144 }
145 } else if let Some(suffix) = path.strip_prefix(&self.prefix) {
146 let update = (suffix.to_string(), None);
147 return self.updates.send(update).is_ok();
148 }
149
150 !self.updates.is_closed()
151 }
152}
153
154#[derive(Clone, Default)]
156pub struct OriginProducer {
157 state: Lock<ProducerState>,
158}
159
160impl OriginProducer {
161 pub fn new() -> Self {
162 Self {
163 state: Lock::new(ProducerState {
164 active: HashMap::new(),
165 consumers: Vec::new(),
166 }),
167 }
168 }
169
170 pub fn publish<S: ToString>(&mut self, path: S, broadcast: BroadcastConsumer) {
177 let path = path.to_string();
178
179 if !self.state.lock().publish(path.clone(), broadcast.clone()) {
180 tracing::warn!(?path, "duplicate publish");
182 return;
183 }
184
185 let state = self.state.clone().downgrade();
186
187 web_async::spawn(async move {
189 broadcast.closed().await;
190 if let Some(state) = state.upgrade() {
191 state.lock().remove(path, broadcast);
192 }
193 });
194 }
195
196 pub fn publish_all(&mut self, broadcasts: OriginConsumer) {
198 self.publish_prefix("", broadcasts);
199 }
200
201 pub fn publish_prefix(&mut self, prefix: &str, mut broadcasts: OriginConsumer) {
203 let mut this = self.clone();
205
206 let prefix = match prefix {
208 "" => None,
209 prefix => Some(prefix.to_string()),
210 };
211
212 web_async::spawn(async move {
213 while let Some((suffix, broadcast)) = broadcasts.next().await {
214 let broadcast = match broadcast {
215 Some(broadcast) => broadcast,
216 None => continue,
219 };
220
221 let path = match &prefix {
222 Some(prefix) => format!("{}{}", prefix, suffix),
223 None => suffix,
224 };
225
226 this.publish(path, broadcast);
227 }
228 });
229 }
230
231 pub fn consume(&self, path: &str) -> Option<BroadcastConsumer> {
235 self.state.lock().active.get(path).map(|b| b.active.clone())
236 }
237
238 pub fn consume_all(&self) -> OriginConsumer {
240 self.consume_prefix("")
241 }
242
243 pub fn consume_prefix<S: ToString>(&self, prefix: S) -> OriginConsumer {
245 let mut state = self.state.lock();
246
247 let (tx, rx) = mpsc::unbounded_channel();
248 let mut consumer = ConsumerState {
249 prefix: prefix.to_string(),
250 exact: false,
251 updates: tx,
252 };
253
254 for (prefix, broadcast) in &state.active {
255 consumer.insert(prefix, &broadcast.active);
256 }
257 state.consumers.push(consumer);
258
259 OriginConsumer::new(rx)
260 }
261
262 pub fn consume_exact<S: ToString>(&self, path: S) -> OriginConsumer {
264 let mut state = self.state.lock();
265
266 let (tx, rx) = mpsc::unbounded_channel();
267 let mut consumer = ConsumerState {
268 prefix: path.to_string(),
269 exact: true,
270 updates: tx,
271 };
272
273 for (prefix, broadcast) in &state.active {
274 consumer.insert(prefix, &broadcast.active);
275 }
276 state.consumers.push(consumer);
277
278 OriginConsumer::new(rx)
279 }
280
281 pub async fn unused(&self) {
285 while let Some(notify) = self.unused_inner() {
287 notify.closed().await;
288 }
289 }
290
291 fn unused_inner(&self) -> Option<mpsc::UnboundedSender<ConsumerUpdate>> {
293 let mut state = self.state.lock();
294
295 while let Some(consumer) = state.consumers.last() {
296 if !consumer.updates.is_closed() {
297 return Some(consumer.updates.clone());
298 }
299
300 state.consumers.pop();
301 }
302
303 None
304 }
305}
306
307pub struct OriginConsumer {
309 updates: mpsc::UnboundedReceiver<ConsumerUpdate>,
310}
311
312impl OriginConsumer {
313 fn new(updates: mpsc::UnboundedReceiver<ConsumerUpdate>) -> Self {
314 Self { updates }
315 }
316
317 pub async fn next(&mut self) -> Option<ConsumerUpdate> {
322 self.updates.recv().await
323 }
324}
325
326#[cfg(test)]
327use futures::FutureExt;
328
#[cfg(test)]
impl OriginConsumer {
    /// Polls `next()` once without waiting; panics with "next blocked" if nothing is ready.
    fn next_now(&mut self) -> Option<ConsumerUpdate> {
        self.next().now_or_never().expect("next blocked")
    }

    /// Asserts the next ready update announces `broadcast` at `path`.
    pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
        let (got_path, got) = self.next_now().expect("no next");
        assert_eq!(got_path, path, "wrong path");

        let got = got.unwrap();
        assert!(got.is_clone(broadcast), "should be the same broadcast");
    }

    /// Asserts the next ready update unannounces `path`.
    pub fn assert_next_none(&mut self, path: &str) {
        let (got_path, got) = self.next_now().expect("no next");
        assert_eq!(got_path, path, "wrong path");
        assert!(got.is_none(), "should be unannounced");
    }

    /// Asserts that no update is ready yet.
    pub fn assert_next_wait(&mut self) {
        let polled = self.next().now_or_never();
        assert!(polled.is_none(), "next should block");
    }

    /// Asserts that the update channel is closed and drained.
    pub fn assert_next_closed(&mut self) {
        let update = self.next_now();
        assert!(update.is_none(), "next should be closed");
    }
}
354
#[cfg(test)]
mod tests {
    use crate::BroadcastProducer;

    use super::*;

    /// Consumers are caught up on existing paths, see new announcements, see unannouncements
    /// when a broadcast goes away, and finally see unannouncements followed by a closed
    /// channel once the producer is dropped.
    #[tokio::test]
    async fn test_announce() {
        let mut producer = OriginProducer::new();
        let broadcast1 = BroadcastProducer::new();
        let broadcast2 = BroadcastProducer::new();

        let mut consumer1 = producer.consume_all();
        // Nothing has been published yet.
        consumer1.assert_next_wait();

        producer.publish("test1", broadcast1.consume());

        consumer1.assert_next("test1", &broadcast1.consume());
        consumer1.assert_next_wait();

        // Created after "test1" was published, so it is announced "test1" on creation.
        let mut consumer2 = producer.consume_all();

        producer.publish("test2", broadcast2.consume());

        consumer1.assert_next("test2", &broadcast2.consume());
        consumer1.assert_next_wait();

        consumer2.assert_next("test1", &broadcast1.consume());
        consumer2.assert_next("test2", &broadcast2.consume());
        consumer2.assert_next_wait();

        // Presumably closes broadcast1's consumers, which wakes the close-watcher task.
        drop(broadcast1);

        // Give the spawned close-watcher task a chance to run and unpublish "test1".
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;

        consumer1.assert_next_none("test1");
        consumer2.assert_next_none("test1");
        consumer1.assert_next_wait();
        consumer2.assert_next_wait();

        // Only the still-published "test2" is replayed to a new consumer.
        let mut consumer3 = producer.consume_all();
        consumer3.assert_next("test2", &broadcast2.consume());
        consumer3.assert_next_wait();

        drop(producer);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;

        // Dropping the producer state unannounces every remaining path...
        consumer1.assert_next_none("test2");
        consumer2.assert_next_none("test2");
        consumer3.assert_next_none("test2");

        // ...and then drops every sender, closing the channels.
        consumer1.assert_next_closed();
        consumer2.assert_next_closed();
        consumer3.assert_next_closed();
    }

    /// Two broadcasts at one path: the older one closing first leaves the newer one active,
    /// and the path disappears only after both have closed.
    #[tokio::test]
    async fn test_duplicate() {
        let mut producer = OriginProducer::new();
        let broadcast1 = BroadcastProducer::new();
        let broadcast2 = BroadcastProducer::new();

        producer.publish("test", broadcast1.consume());
        producer.publish("test", broadcast2.consume());
        assert!(producer.consume("test").is_some());

        // broadcast1 is the backup at this point.
        drop(broadcast1);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert!(producer.consume("test").is_some());

        drop(broadcast2);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert!(producer.consume("test").is_none());
    }

    /// Two broadcasts at one path: the newer (active) one closing first promotes the older
    /// backup, and the path disappears only after both have closed.
    #[tokio::test]
    async fn test_duplicate_reverse() {
        let mut producer = OriginProducer::new();
        let broadcast1 = BroadcastProducer::new();
        let broadcast2 = BroadcastProducer::new();

        producer.publish("test", broadcast1.consume());
        producer.publish("test", broadcast2.consume());
        assert!(producer.consume("test").is_some());

        // broadcast2 is the active one at this point.
        drop(broadcast2);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert!(producer.consume("test").is_some());

        drop(broadcast1);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert!(producer.consume("test").is_none());
    }

    /// Publishing the same broadcast twice at one path registers it once, so a single close
    /// removes the path.
    #[tokio::test]
    async fn test_double_publish() {
        let mut producer = OriginProducer::new();
        let broadcast = BroadcastProducer::new();

        producer.publish("test", broadcast.consume());
        // Same broadcast again: ignored as a duplicate (logged, no extra watcher task).
        producer.publish("test", broadcast.consume());

        assert!(producer.consume("test").is_some());

        drop(broadcast);

        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert!(producer.consume("test").is_none());
    }
}