moq_lite/model/
broadcast.rs1use std::{
2 collections::{HashMap, hash_map},
3 task::Poll,
4};
5
6use crate::{
7 Error, TrackConsumer, TrackProducer,
8 model::{
9 state::{Consumer, Producer},
10 track::TrackWeak,
11 waiter::{Waiter, waiter_fn},
12 },
13};
14
15use super::Track;
16
/// Field-less entry-point type for broadcasts.
///
/// It carries no data of its own; it only provides a namespace for
/// associated functions.
#[derive(Default, Clone)]
pub struct Broadcast {}
24
25impl Broadcast {
26 pub fn produce() -> BroadcastProducer {
27 BroadcastProducer::new()
28 }
29}
30
31#[derive(Default, Clone)]
32struct State {
33 tracks: HashMap<String, TrackWeak>,
35
36 requests: Vec<TrackProducer>,
38
39 dynamic: usize,
42}
43
44#[derive(Clone)]
49pub struct BroadcastProducer {
50 state: Producer<State>,
51}
52
53impl Default for BroadcastProducer {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl BroadcastProducer {
60 pub fn new() -> Self {
61 Self {
62 state: Default::default(),
63 }
64 }
65
66 pub fn insert_track(&mut self, track: &TrackProducer) -> Result<(), Error> {
70 let mut state = self.state.modify()?;
71
72 let hash_map::Entry::Vacant(entry) = state.tracks.entry(track.info.name.clone()) else {
73 return Err(Error::Duplicate);
74 };
75
76 entry.insert(track.weak());
77
78 Ok(())
79 }
80
81 pub fn remove_track(&mut self, name: &str) -> Result<(), Error> {
83 let mut state = self.state.modify()?;
84
85 state.tracks.remove(name).ok_or(Error::NotFound)?;
86
87 Ok(())
88 }
89
90 pub fn create_track(&mut self, track: Track) -> Result<TrackProducer, Error> {
92 let track = TrackProducer::new(track);
93 self.insert_track(&track)?;
94 Ok(track)
95 }
96
97 pub fn dynamic(&self) -> BroadcastDynamic {
99 BroadcastDynamic::new(self.state.clone())
100 }
101
102 pub fn consume(&self) -> BroadcastConsumer {
104 BroadcastConsumer {
105 state: self.state.consume(),
106 }
107 }
108
109 pub fn abort(&mut self, err: Error) -> Result<(), Error> {
111 let mut state = self.state.modify()?;
112
113 for weak in state.tracks.values() {
115 weak.abort(err.clone());
116 }
117
118 for mut request in state.requests.drain(..) {
120 request.abort(err.clone()).ok();
121 }
122
123 state.abort(err);
124 Ok(())
125 }
126
127 pub fn is_clone(&self, other: &Self) -> bool {
129 self.state.is_clone(&other.state)
130 }
131}
132
#[cfg(test)]
impl BroadcastProducer {
    /// Test helper: creates a track from a copy of `track`, panicking on error.
    pub fn assert_create_track(&mut self, track: &Track) -> TrackProducer {
        let outcome = self.create_track(track.clone());
        outcome.expect("should not have errored")
    }

    /// Test helper: inserts `track`, panicking on error.
    pub fn assert_insert_track(&mut self, track: &TrackProducer) {
        let outcome = self.insert_track(track);
        outcome.expect("should not have errored")
    }
}
143
144#[derive(Clone)]
150pub struct BroadcastDynamic {
151 state: Producer<State>,
152}
153
154impl BroadcastDynamic {
155 fn new(state: Producer<State>) -> Self {
156 if let Ok(mut state) = state.modify() {
157 state.dynamic += 1;
159 }
160
161 Self { state }
162 }
163
164 fn poll_requested_track(&self, waiter: &Waiter) -> Poll<Result<Option<TrackProducer>, Error>> {
165 self.state.poll_modify(waiter, |state| {
166 if state.requests.is_empty() {
167 return Poll::Pending;
168 }
169 Poll::Ready(state.requests.pop())
170 })
171 }
172
173 pub async fn requested_track(&mut self) -> Result<Option<TrackProducer>, Error> {
175 waiter_fn(move |waiter| self.poll_requested_track(waiter)).await
176 }
177
178 pub fn consume(&self) -> BroadcastConsumer {
180 BroadcastConsumer {
181 state: self.state.consume(),
182 }
183 }
184
185 pub fn abort(&mut self, err: Error) -> Result<(), Error> {
187 let mut state = self.state.modify()?;
188
189 for weak in state.tracks.values() {
191 weak.abort(err.clone());
192 }
193
194 for mut request in state.requests.drain(..) {
196 request.abort(err.clone()).ok();
197 }
198
199 state.abort(err);
200 Ok(())
201 }
202
203 pub fn is_clone(&self, other: &Self) -> bool {
205 self.state.is_clone(&other.state)
206 }
207}
208
209impl Drop for BroadcastDynamic {
210 fn drop(&mut self) {
211 if let Ok(mut state) = self.state.modify() {
212 state.dynamic = state.dynamic.saturating_sub(1);
214 if state.dynamic != 0 {
215 return;
216 }
217
218 for mut request in state.requests.drain(..) {
220 request.abort(Error::Cancel).ok();
221 }
222 }
223 }
224}
225
226#[cfg(test)]
227use futures::FutureExt;
228
#[cfg(test)]
impl BroadcastDynamic {
    /// Test helper: returns the request that must already be queued,
    /// panicking if the call would block, fails, or yields nothing.
    pub fn assert_request(&mut self) -> TrackProducer {
        let polled = self.requested_track().now_or_never();
        let outcome = polled.expect("should not have blocked");
        let request = outcome.expect("should not have errored");
        request.expect("should be a request")
    }

    /// Test helper: asserts that no request is ready right now.
    pub fn assert_no_request(&mut self) {
        let polled = self.requested_track().now_or_never();
        assert!(polled.is_none(), "should have blocked");
    }
}
243
244#[derive(Clone)]
246pub struct BroadcastConsumer {
247 state: Consumer<State>,
248}
249
250impl BroadcastConsumer {
251 pub fn subscribe_track(&self, track: &Track) -> Result<TrackConsumer, Error> {
252 let producer = self.state.produce()?;
254 let mut state = producer.modify()?;
255
256 if let Some(weak) = state.tracks.get(&track.name) {
257 if !weak.is_closed() {
258 return Ok(weak.consume());
259 }
260 state.tracks.remove(&track.name);
262 }
263
264 let producer = track.clone().produce();
266 let consumer = producer.consume();
267
268 if state.dynamic == 0 {
269 return Err(Error::NotFound);
270 }
271
272 let weak = producer.weak();
274 state.tracks.insert(producer.info.name.clone(), weak.clone());
275 state.requests.push(producer);
276
277 let consumer_state = self.state.clone();
279 web_async::spawn(async move {
280 let _ = weak.unused().await;
281 if let Ok(producer) = consumer_state.produce()
282 && let Ok(mut state) = producer.modify()
283 && let Some(current) = state.tracks.remove(&weak.info.name)
284 && !current.is_clone(&weak)
285 {
286 state.tracks.insert(current.info.name.clone(), current);
287 }
288 });
289
290 Ok(consumer)
291 }
292
293 pub async fn closed(&self) -> Error {
294 self.state.closed().await
295 }
296
297 pub fn is_clone(&self, other: &Self) -> bool {
299 self.state.is_clone(&other.state)
300 }
301}
302
#[cfg(test)]
impl BroadcastConsumer {
    /// Test helper: subscribes to `track`, panicking on error.
    pub fn assert_subscribe_track(&self, track: &Track) -> TrackConsumer {
        let outcome = self.subscribe_track(track);
        outcome.expect("should not have errored")
    }

    /// Test helper: asserts that `closed()` is still pending.
    pub fn assert_not_closed(&self) {
        let polled = self.closed().now_or_never();
        assert!(polled.is_none(), "should not be closed");
    }

    /// Test helper: asserts that `closed()` resolves immediately.
    pub fn assert_closed(&self) {
        let polled = self.closed().now_or_never();
        assert!(polled.is_some(), "should be closed");
    }
}
317
#[cfg(test)]
mod test {
    use super::*;

    /// Tracks inserted by the producer can be subscribed to by consumers,
    /// whether the consumer was created before or after the insertion.
    #[tokio::test]
    async fn insert() {
        let mut producer = BroadcastProducer::new();
        let mut track1 = Track::new("track1").produce();

        // Register the track, then publish a group on it.
        producer.assert_insert_track(&track1);
        track1.append_group().unwrap();

        let consumer = producer.consume();

        // The subscriber sees the group appended before it subscribed.
        let mut track1_sub = consumer.assert_subscribe_track(&Track::new("track1"));
        track1_sub.assert_group();

        let mut track2 = Track::new("track2").produce();
        producer.assert_insert_track(&track2);

        // A consumer created after the insert can subscribe as well.
        let consumer2 = producer.consume();
        let mut track2_consumer = consumer2.assert_subscribe_track(&Track::new("track2"));
        track2_consumer.assert_no_group();

        track2.append_group().unwrap();

        track2_consumer.assert_group();
    }

    /// Aborting the broadcast propagates the error to inserted tracks and to
    /// tracks that only exist as pending requests.
    #[tokio::test]
    async fn closed() {
        let mut producer = BroadcastProducer::new();
        // Keep a dynamic handle alive so that subscribing to an unknown track
        // queues a request instead of failing with `NotFound`.
        let _dynamic = producer.dynamic();

        let consumer = producer.consume();
        consumer.assert_not_closed();

        let track1 = producer.assert_create_track(&Track::new("track1"));
        // `track1c` consumes the inserted track; `track2` is a pending request.
        let track1c = consumer.assert_subscribe_track(&track1.info);
        let track2 = consumer.assert_subscribe_track(&Track::new("track2"));

        producer.abort(Error::Cancel).unwrap();

        // The pending request was aborted.
        track2.assert_error();

        // The inserted track was aborted through its weak handle.
        track1c.assert_error();

        // The producer side of the inserted track observes the closure too.
        assert!(track1.is_closed());
    }

    /// Subscribing to unknown tracks queues requests on the dynamic handle,
    /// and duplicate subscriptions share one track.
    #[tokio::test]
    async fn requests() {
        // Note: `producer` is a `BroadcastDynamic` here.
        let mut producer = BroadcastProducer::new().dynamic();

        let consumer = producer.consume();
        let consumer2 = consumer.clone();

        let mut track1 = consumer.assert_subscribe_track(&Track::new("track1"));
        track1.assert_not_closed();
        track1.assert_no_group();

        // A second subscription to the same name reuses the same track.
        let mut track2 = consumer2.assert_subscribe_track(&Track::new("track1"));
        track2.assert_is_clone(&track1);

        // Exactly one request was queued for the two subscriptions.
        let mut track3 = producer.assert_request();
        producer.assert_no_request();

        // The request is the producer side of the subscribed track.
        track3.consume().assert_is_clone(&track1);

        // Groups published on the request reach both subscribers.
        track3.append_group().unwrap();
        track1.assert_group();
        track2.assert_group();

        // Queue another request, then drop the only dynamic handle before it is served.
        let track4 = consumer.assert_subscribe_track(&Track::new("track2"));
        drop(producer);

        // Dropping the last dynamic handle aborts pending requests.
        track4.assert_error();

        // With no dynamic handle left, unknown tracks cannot be requested.
        let track5 = consumer2.subscribe_track(&Track::new("track3"));
        assert!(track5.is_err(), "should have errored");
    }

    /// A closed track is replaced by a fresh request when it is subscribed to again.
    #[tokio::test]
    async fn stale_producer() {
        let mut broadcast = Broadcast::produce().dynamic();
        let consumer = broadcast.consume();

        let track1 = consumer.assert_subscribe_track(&Track::new("track1"));

        // Serve the request, then finish and drop the producer.
        let mut producer1 = broadcast.assert_request();
        producer1.append_group().unwrap();
        producer1.finish().unwrap();
        drop(producer1);

        track1.assert_closed();

        // Subscribing again must not hand back the closed track.
        let mut track2 = consumer.assert_subscribe_track(&Track::new("track1"));
        track2.assert_not_closed();
        track2.assert_not_clone(&track1);

        // A new request was queued for the replacement track.
        let mut producer2 = broadcast.assert_request();
        producer2.append_group().unwrap();

        track2.assert_group();
    }

    /// A requested track becomes "unused" only once every consumer is gone,
    /// after which the same name can be requested again.
    #[tokio::test]
    async fn requested_unused() {
        let mut broadcast = Broadcast::produce().dynamic();

        let consumer1 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));

        let producer1 = broadcast.assert_request();

        // One consumer is alive, so `unused()` is still pending.
        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        // A second subscription shares the same track.
        let consumer2 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));
        consumer2.assert_is_clone(&consumer1);

        drop(consumer1);

        // `consumer2` still keeps the track in use.
        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        drop(consumer2);

        // No consumers remain.
        assert!(
            producer1.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );

        // Presumably gives the cleanup task spawned by `subscribe_track` a
        // chance to run and remove the map entry — TODO confirm.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        // Subscribing again queues a brand-new request.
        let consumer3 = broadcast.consume().subscribe_track(&Track::new("unknown_track"));
        let producer2 = broadcast.assert_request();

        drop(consumer3);
        assert!(
            producer2.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );
    }
}