// moq_lite/model/origin.rs

1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::Lock;
4
5use super::BroadcastConsumer;
6
7// If there are multiple broadcasts with the same path, we use the most recent one but keep the others around.
8struct BroadcastState {
9	active: BroadcastConsumer,
10	backup: Vec<BroadcastConsumer>,
11}
12
13#[derive(Default)]
14struct ProducerState {
15	active: HashMap<String, BroadcastState>,
16	consumers: Vec<ConsumerState>,
17}
18
19impl ProducerState {
20	fn publish(&mut self, path: String, broadcast: BroadcastConsumer) {
21		match self.active.entry(path.clone()) {
22			hash_map::Entry::Occupied(mut entry) => {
23				let state = entry.get_mut();
24				if state.active.is_clone(&broadcast) {
25					tracing::warn!(?path, "skipping duplicate publish");
26					return;
27				}
28
29				// Make the new broadcast the active one.
30				let old = state.active.clone();
31				state.active = broadcast.clone();
32
33				// Move the old broadcast to the backup list.
34				// But we need to replace any previous duplicates.
35				let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
36				if let Some(pos) = pos {
37					state.backup[pos] = old;
38				} else {
39					state.backup.push(old);
40				}
41
42				// Reannounce the path to all consumers.
43				retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
44			}
45			hash_map::Entry::Vacant(entry) => {
46				entry.insert(BroadcastState {
47					active: broadcast.clone(),
48					backup: Vec::new(),
49				});
50			}
51		};
52
53		retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
54	}
55
56	fn remove(&mut self, path: String, broadcast: BroadcastConsumer) {
57		let mut entry = match self.active.entry(path) {
58			hash_map::Entry::Occupied(entry) => entry,
59			hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
60		};
61
62		// See if we can remove the broadcast from the backup list.
63		let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
64		if let Some(pos) = pos {
65			entry.get_mut().backup.remove(pos);
66			// Nothing else to do
67			return;
68		}
69
70		// Okay so it must be the active broadcast or else we fucked up.
71		assert!(entry.get().active.is_clone(&broadcast));
72
73		retain_mut_unordered(&mut self.consumers, |c| c.remove(entry.key()));
74
75		// If there's a backup broadcast, then announce it.
76		if let Some(active) = entry.get_mut().backup.pop() {
77			entry.get_mut().active = active;
78			retain_mut_unordered(&mut self.consumers, |c| c.insert(entry.key(), &entry.get().active));
79		} else {
80			// No more backups, so remove the entry.
81			entry.remove();
82		}
83	}
84}
85
86impl Drop for ProducerState {
87	fn drop(&mut self) {
88		for (path, _) in self.active.drain() {
89			retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
90		}
91	}
92}
93
/// Keep only the elements for which `keep` returns `true`, without preserving their order.
///
/// Like `Vec::retain_mut`, but each rejected element is replaced by the last one via
/// `Vec::swap_remove` instead of shifting the tail. The swapped-in element is then checked in
/// turn, so every element is visited exactly once. `keep` may mutate the element it is given
/// and may carry mutable state of its own.
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut keep: F) {
	let mut i = 0;
	while let Some(item) = vec.get_mut(i) {
		if keep(item) {
			i += 1;
		} else {
			vec.swap_remove(i);
		}
	}
}

106/// A broadcast path and its associated broadcast, or None if closed.
107type ConsumerUpdate = (String, Option<BroadcastConsumer>);
108
109struct ConsumerState {
110	prefix: String,
111	updates: mpsc::UnboundedSender<ConsumerUpdate>,
112}
113
114impl ConsumerState {
115	// Returns true if the consuemr is still alive.
116	pub fn insert(&mut self, path: &str, consumer: &BroadcastConsumer) -> bool {
117		if let Some(suffix) = path.strip_prefix(&self.prefix) {
118			let update = (suffix.to_string(), Some(consumer.clone()));
119			self.updates.send(update).is_ok()
120		} else {
121			!self.updates.is_closed()
122		}
123	}
124
125	pub fn remove(&mut self, path: &str) -> bool {
126		if let Some(suffix) = path.strip_prefix(&self.prefix) {
127			let update = (suffix.to_string(), None);
128			self.updates.send(update).is_ok()
129		} else {
130			!self.updates.is_closed()
131		}
132	}
133}
134
135/// Announces broadcasts to consumers over the network.
136#[derive(Clone, Default)]
137pub struct OriginProducer {
138	state: Lock<ProducerState>,
139}
140
141impl OriginProducer {
142	pub fn new() -> Self {
143		Self {
144			state: Lock::new(ProducerState {
145				active: HashMap::new(),
146				consumers: Vec::new(),
147			}),
148		}
149	}
150
151	/// Publish a broadcast, announcing it to all consumers.
152	///
153	/// The broadcast will be unannounced when it is closed.
154	/// If there is already a broadcast with the same path, then it will be replaced and reannounced.
155	/// If the old broadcast is closed before the new one, then nothing will happen.
156	/// If the new broadcast is closed before the old one, then the old broadcast will be reannounced.
157	pub fn publish<S: ToString>(&mut self, path: S, broadcast: BroadcastConsumer) {
158		let path = path.to_string();
159		self.state.lock().publish(path.clone(), broadcast.clone());
160
161		let state = self.state.clone().downgrade();
162
163		// TODO cancel this task when the producer is dropped.
164		web_async::spawn(async move {
165			broadcast.closed().await;
166			if let Some(state) = state.upgrade() {
167				state.lock().remove(path, broadcast);
168			}
169		});
170	}
171
172	/// Publish all broadcasts from the given origin.
173	pub fn publish_all(&mut self, broadcasts: OriginConsumer) {
174		self.publish_prefix("", broadcasts);
175	}
176
177	/// Publish all broadcasts from the given origin with an optional prefix.
178	pub fn publish_prefix(&mut self, prefix: &str, mut broadcasts: OriginConsumer) {
179		// Really gross that this just spawns a background task, but I want publishing to be sync.
180		let mut this = self.clone();
181
182		// Overkill to avoid allocating a string if the prefix is empty.
183		let prefix = match prefix {
184			"" => None,
185			prefix => Some(prefix.to_string()),
186		};
187
188		web_async::spawn(async move {
189			while let Some((suffix, broadcast)) = broadcasts.next().await {
190				let broadcast = match broadcast {
191					Some(broadcast) => broadcast,
192					// We don't need to worry about unannouncements here because our own OriginPublisher will handle it.
193					// Announcements are ordered so I don't think there's a race condition?
194					None => continue,
195				};
196
197				let path = match &prefix {
198					Some(prefix) => format!("{}{}", prefix, suffix),
199					None => suffix,
200				};
201
202				this.publish(path, broadcast);
203			}
204		});
205	}
206
207	/// Get a specific broadcast by name.
208	///
209	/// The most recent, non-closed broadcast will be returned if there are duplicates.
210	pub fn consume(&self, path: &str) -> Option<BroadcastConsumer> {
211		self.state.lock().active.get(path).map(|b| b.active.clone())
212	}
213
214	/// Subscribe to all announced broadcasts.
215	pub fn consume_all(&self) -> OriginConsumer {
216		self.consume_prefix("")
217	}
218
219	/// Subscribe to all announced broadcasts matching the prefix.
220	pub fn consume_prefix<S: ToString>(&self, prefix: S) -> OriginConsumer {
221		let mut state = self.state.lock();
222
223		let (tx, rx) = mpsc::unbounded_channel();
224		let mut consumer = ConsumerState {
225			prefix: prefix.to_string(),
226			updates: tx,
227		};
228
229		for (prefix, broadcast) in &state.active {
230			consumer.insert(prefix, &broadcast.active);
231		}
232		state.consumers.push(consumer);
233
234		OriginConsumer::new(rx)
235	}
236
237	/// Wait until all consumers have been dropped.
238	///
239	/// NOTE: subscribe can be called to unclose the producer.
240	pub async fn unused(&self) {
241		// Keep looping until all consumers are closed.
242		while let Some(notify) = self.unused_inner() {
243			notify.closed().await;
244		}
245	}
246
247	// Returns the closed notify of any consumer.
248	fn unused_inner(&self) -> Option<mpsc::UnboundedSender<ConsumerUpdate>> {
249		let mut state = self.state.lock();
250
251		while let Some(consumer) = state.consumers.last() {
252			if !consumer.updates.is_closed() {
253				return Some(consumer.updates.clone());
254			}
255
256			state.consumers.pop();
257		}
258
259		None
260	}
261}
262
263/// Consumes announced broadcasts matching against an optional prefix.
264pub struct OriginConsumer {
265	updates: mpsc::UnboundedReceiver<ConsumerUpdate>,
266}
267
268impl OriginConsumer {
269	fn new(updates: mpsc::UnboundedReceiver<ConsumerUpdate>) -> Self {
270		Self { updates }
271	}
272
273	/// Returns the next (un)announced broadcast and the path.
274	///
275	/// The broadcast will only be None if it was previously Some.
276	/// The same path won't be announced/unannounced twice, instead it will toggle.
277	pub async fn next(&mut self) -> Option<ConsumerUpdate> {
278		self.updates.recv().await
279	}
280}
281
282#[cfg(test)]
283use futures::FutureExt;
284
#[cfg(test)]
impl OriginConsumer {
	// Pull the next update without waiting; panics if none is ready or the channel has closed.
	fn take_ready(&mut self) -> ConsumerUpdate {
		self.next().now_or_never().expect("next blocked").expect("no next")
	}

	/// Assert that the next update announces `broadcast` at `path`.
	pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
		let (got_path, got) = self.take_ready();
		assert_eq!(got_path, path, "wrong path");
		assert!(got.unwrap().is_clone(broadcast), "should be the same broadcast");
	}

	/// Assert that the next update unannounces `path`.
	pub fn assert_next_none(&mut self, path: &str) {
		let (got_path, got) = self.take_ready();
		assert_eq!(got_path, path, "wrong path");
		assert!(got.is_none(), "should be unannounced");
	}

	/// Assert that no update is ready yet.
	pub fn assert_next_wait(&mut self) {
		let pending = self.next().now_or_never().is_none();
		assert!(pending, "next should block");
	}

	/// Assert that the channel has closed and every update has been read.
	pub fn assert_next_closed(&mut self) {
		let ready = self.next().now_or_never().expect("next blocked");
		assert!(ready.is_none(), "next should be closed");
	}
}

#[cfg(test)]
mod tests {
	use crate::BroadcastProducer;

	use super::*;

	/// Give the spawned close-watcher tasks a chance to run.
	async fn settle() {
		tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
	}

	#[tokio::test]
	async fn test_announce() {
		let mut origin = OriginProducer::new();
		let first = BroadcastProducer::new();
		let second = BroadcastProducer::new();

		// A consumer created up front sees nothing yet...
		let mut early = origin.consume_all();
		early.assert_next_wait();

		// ...and sees the first broadcast as soon as it is published.
		origin.publish("test1", first.consume());
		early.assert_next("test1", &first.consume());
		early.assert_next_wait();

		// This consumer already has the first broadcast queued, but isn't read until later.
		let mut late = origin.consume_all();

		origin.publish("test2", second.consume());
		early.assert_next("test2", &second.consume());
		early.assert_next_wait();

		// The late consumer gets the pre-existing broadcast first, then the new one.
		late.assert_next("test1", &first.consume());
		late.assert_next("test2", &second.consume());
		late.assert_next_wait();

		// Closing the first broadcast unannounces it for everyone.
		drop(first);
		settle().await;
		early.assert_next_none("test1");
		late.assert_next_none("test1");
		early.assert_next_wait();
		late.assert_next_wait();

		// A brand-new consumer only learns about the broadcast that is still open.
		let mut fresh = origin.consume_all();
		fresh.assert_next("test2", &second.consume());
		fresh.assert_next_wait();

		// Dropping the origin unannounces what is left and then closes every consumer.
		drop(origin);
		settle().await;
		early.assert_next_none("test2");
		late.assert_next_none("test2");
		fresh.assert_next_none("test2");

		early.assert_next_closed();
		late.assert_next_closed();
		fresh.assert_next_closed();
	}

	#[tokio::test]
	async fn test_duplicate() {
		let mut origin = OriginProducer::new();
		let older = BroadcastProducer::new();
		let newer = BroadcastProducer::new();

		origin.publish("test", older.consume());
		origin.publish("test", newer.consume());
		assert!(origin.consume("test").is_some());

		// Closing the older (backup) broadcast keeps the path published.
		drop(older);
		settle().await;
		assert!(origin.consume("test").is_some());

		// Closing the remaining broadcast removes the path.
		drop(newer);
		settle().await;
		assert!(origin.consume("test").is_none());
	}

	#[tokio::test]
	async fn test_duplicate_reverse() {
		let mut origin = OriginProducer::new();
		let older = BroadcastProducer::new();
		let newer = BroadcastProducer::new();

		origin.publish("test", older.consume());
		origin.publish("test", newer.consume());
		assert!(origin.consume("test").is_some());

		// The harder case: closing the active (newer) broadcast must fall back to the backup.
		drop(newer);
		settle().await;
		assert!(origin.consume("test").is_some());

		drop(older);
		settle().await;
		assert!(origin.consume("test").is_none());
	}
}