moq_lite/model/
origin.rs

1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::Lock;
4
5use super::BroadcastConsumer;
6
7// If there are multiple broadcasts with the same path, we use the most recent one but keep the others around.
8struct BroadcastState {
9	active: BroadcastConsumer,
10	backup: Vec<BroadcastConsumer>,
11}
12
13#[derive(Default)]
14struct ProducerState {
15	active: HashMap<String, BroadcastState>,
16	consumers: Vec<ConsumerState>,
17}
18
19impl ProducerState {
20	// Returns true if this was a unique broadcast.
21	fn publish(&mut self, path: String, broadcast: BroadcastConsumer) -> bool {
22		let mut unique = true;
23
24		match self.active.entry(path.clone()) {
25			hash_map::Entry::Occupied(mut entry) => {
26				let state = entry.get_mut();
27				if state.active.is_clone(&broadcast) {
28					// If we're already publishing this broadcast, then don't do anything.
29					return false;
30				}
31
32				// Make the new broadcast the active one.
33				let old = state.active.clone();
34				state.active = broadcast.clone();
35
36				// Move the old broadcast to the backup list.
37				// But we need to replace any previous duplicates.
38				let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
39				if let Some(pos) = pos {
40					state.backup[pos] = old;
41
42					// We're already publishing this broadcast, so don't run the cleanup task.
43					unique = false;
44				} else {
45					state.backup.push(old);
46				}
47
48				// Reannounce the path to all consumers.
49				retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
50			}
51			hash_map::Entry::Vacant(entry) => {
52				entry.insert(BroadcastState {
53					active: broadcast.clone(),
54					backup: Vec::new(),
55				});
56			}
57		};
58
59		retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
60
61		unique
62	}
63
64	fn remove(&mut self, path: String, broadcast: BroadcastConsumer) {
65		let mut entry = match self.active.entry(path) {
66			hash_map::Entry::Occupied(entry) => entry,
67			hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
68		};
69
70		// See if we can remove the broadcast from the backup list.
71		let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
72		if let Some(pos) = pos {
73			entry.get_mut().backup.remove(pos);
74			// Nothing else to do
75			return;
76		}
77
78		// Okay so it must be the active broadcast or else we fucked up.
79		assert!(entry.get().active.is_clone(&broadcast));
80
81		retain_mut_unordered(&mut self.consumers, |c| c.remove(entry.key()));
82
83		// If there's a backup broadcast, then announce it.
84		if let Some(active) = entry.get_mut().backup.pop() {
85			entry.get_mut().active = active;
86			retain_mut_unordered(&mut self.consumers, |c| c.insert(entry.key(), &entry.get().active));
87		} else {
88			// No more backups, so remove the entry.
89			entry.remove();
90		}
91	}
92}
93
94impl Drop for ProducerState {
95	fn drop(&mut self) {
96		for (path, _) in self.active.drain() {
97			retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
98		}
99	}
100}
101
/// Keeps only the elements for which `f` returns `true`, without preserving order.
///
/// Rejected elements are removed with `swap_remove`, so each removal is O(1) but
/// the surviving elements may end up in a different order. The predicate receives
/// a mutable reference and may modify the element it inspects.
///
/// The predicate is `FnMut` (like `Vec::retain_mut`), so it may carry mutable
/// state such as a counter; plain `Fn` closures continue to work unchanged.
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut f: F) {
	let mut i = 0;
	while let Some(item) = vec.get_mut(i) {
		if f(item) {
			i += 1;
		} else {
			// The last element is swapped into slot `i`, so `i` must be re-examined.
			vec.swap_remove(i);
		}
	}
}
113
114/// A broadcast path and its associated broadcast, or None if closed.
115type ConsumerUpdate = (String, Option<BroadcastConsumer>);
116
117struct ConsumerState {
118	prefix: String,
119	updates: mpsc::UnboundedSender<ConsumerUpdate>,
120}
121
122impl ConsumerState {
123	// Returns true if the consuemr is still alive.
124	pub fn insert(&mut self, path: &str, consumer: &BroadcastConsumer) -> bool {
125		if let Some(suffix) = path.strip_prefix(&self.prefix) {
126			let update = (suffix.to_string(), Some(consumer.clone()));
127			self.updates.send(update).is_ok()
128		} else {
129			!self.updates.is_closed()
130		}
131	}
132
133	pub fn remove(&mut self, path: &str) -> bool {
134		if let Some(suffix) = path.strip_prefix(&self.prefix) {
135			let update = (suffix.to_string(), None);
136			self.updates.send(update).is_ok()
137		} else {
138			!self.updates.is_closed()
139		}
140	}
141}
142
143/// Announces broadcasts to consumers over the network.
144#[derive(Clone, Default)]
145pub struct OriginProducer {
146	state: Lock<ProducerState>,
147}
148
149impl OriginProducer {
150	pub fn new() -> Self {
151		Self {
152			state: Lock::new(ProducerState {
153				active: HashMap::new(),
154				consumers: Vec::new(),
155			}),
156		}
157	}
158
159	/// Publish a broadcast, announcing it to all consumers.
160	///
161	/// The broadcast will be unannounced when it is closed.
162	/// If there is already a broadcast with the same path, then it will be replaced and reannounced.
163	/// If the old broadcast is closed before the new one, then nothing will happen.
164	/// If the new broadcast is closed before the old one, then the old broadcast will be reannounced.
165	pub fn publish<S: ToString>(&mut self, path: S, broadcast: BroadcastConsumer) {
166		let path = path.to_string();
167
168		if !self.state.lock().publish(path.clone(), broadcast.clone()) {
169			// This is not a big deal, but we want to avoid spawning additional cleanup tasks.
170			tracing::warn!(?path, "duplicate publish");
171			return;
172		}
173
174		let state = self.state.clone().downgrade();
175
176		// TODO cancel this task when the producer is dropped.
177		web_async::spawn(async move {
178			broadcast.closed().await;
179			if let Some(state) = state.upgrade() {
180				state.lock().remove(path, broadcast);
181			}
182		});
183	}
184
185	/// Publish all broadcasts from the given origin.
186	pub fn publish_all(&mut self, broadcasts: OriginConsumer) {
187		self.publish_prefix("", broadcasts);
188	}
189
190	/// Publish all broadcasts from the given origin with an optional prefix.
191	pub fn publish_prefix(&mut self, prefix: &str, mut broadcasts: OriginConsumer) {
192		// Really gross that this just spawns a background task, but I want publishing to be sync.
193		let mut this = self.clone();
194
195		// Overkill to avoid allocating a string if the prefix is empty.
196		let prefix = match prefix {
197			"" => None,
198			prefix => Some(prefix.to_string()),
199		};
200
201		web_async::spawn(async move {
202			while let Some((suffix, broadcast)) = broadcasts.next().await {
203				let broadcast = match broadcast {
204					Some(broadcast) => broadcast,
205					// We don't need to worry about unannouncements here because our own OriginPublisher will handle it.
206					// Announcements are ordered so I don't think there's a race condition?
207					None => continue,
208				};
209
210				let path = match &prefix {
211					Some(prefix) => format!("{}{}", prefix, suffix),
212					None => suffix,
213				};
214
215				this.publish(path, broadcast);
216			}
217		});
218	}
219
220	/// Get a specific broadcast by name.
221	///
222	/// The most recent, non-closed broadcast will be returned if there are duplicates.
223	pub fn consume(&self, path: &str) -> Option<BroadcastConsumer> {
224		self.state.lock().active.get(path).map(|b| b.active.clone())
225	}
226
227	/// Subscribe to all announced broadcasts.
228	pub fn consume_all(&self) -> OriginConsumer {
229		self.consume_prefix("")
230	}
231
232	/// Subscribe to all announced broadcasts matching the prefix.
233	pub fn consume_prefix<S: ToString>(&self, prefix: S) -> OriginConsumer {
234		let mut state = self.state.lock();
235
236		let (tx, rx) = mpsc::unbounded_channel();
237		let mut consumer = ConsumerState {
238			prefix: prefix.to_string(),
239			updates: tx,
240		};
241
242		for (prefix, broadcast) in &state.active {
243			consumer.insert(prefix, &broadcast.active);
244		}
245		state.consumers.push(consumer);
246
247		OriginConsumer::new(rx)
248	}
249
250	/// Wait until all consumers have been dropped.
251	///
252	/// NOTE: subscribe can be called to unclose the producer.
253	pub async fn unused(&self) {
254		// Keep looping until all consumers are closed.
255		while let Some(notify) = self.unused_inner() {
256			notify.closed().await;
257		}
258	}
259
260	// Returns the closed notify of any consumer.
261	fn unused_inner(&self) -> Option<mpsc::UnboundedSender<ConsumerUpdate>> {
262		let mut state = self.state.lock();
263
264		while let Some(consumer) = state.consumers.last() {
265			if !consumer.updates.is_closed() {
266				return Some(consumer.updates.clone());
267			}
268
269			state.consumers.pop();
270		}
271
272		None
273	}
274}
275
276/// Consumes announced broadcasts matching against an optional prefix.
277pub struct OriginConsumer {
278	updates: mpsc::UnboundedReceiver<ConsumerUpdate>,
279}
280
281impl OriginConsumer {
282	fn new(updates: mpsc::UnboundedReceiver<ConsumerUpdate>) -> Self {
283		Self { updates }
284	}
285
286	/// Returns the next (un)announced broadcast and the path.
287	///
288	/// The broadcast will only be None if it was previously Some.
289	/// The same path won't be announced/unannounced twice, instead it will toggle.
290	pub async fn next(&mut self) -> Option<ConsumerUpdate> {
291		self.updates.recv().await
292	}
293}
294
295#[cfg(test)]
296use futures::FutureExt;
297
#[cfg(test)]
impl OriginConsumer {
	/// Asserts that an announcement of `broadcast` at `path` is ready right now.
	pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
		let (actual, announced) = self.next().now_or_never().expect("next blocked").expect("no next");
		assert_eq!(actual, path, "wrong path");
		assert!(announced.unwrap().is_clone(broadcast), "should be the same broadcast");
	}

	/// Asserts that an unannouncement of `path` is ready right now.
	pub fn assert_next_none(&mut self, path: &str) {
		let (actual, announced) = self.next().now_or_never().expect("next blocked").expect("no next");
		assert_eq!(actual, path, "wrong path");
		assert!(announced.is_none(), "should be unannounced");
	}

	/// Asserts that no update is ready yet.
	pub fn assert_next_wait(&mut self) {
		let polled = self.next().now_or_never();
		assert!(polled.is_none(), "next should block");
	}

	/// Asserts that the update stream has ended.
	pub fn assert_next_closed(&mut self) {
		let polled = self.next().now_or_never().expect("next blocked");
		assert!(polled.is_none(), "next should be closed");
	}
}
323
#[cfg(test)]
mod tests {
	use tokio::time::{sleep, Duration};

	use crate::BroadcastProducer;

	use super::*;

	/// Gives the spawned cleanup task a moment to run.
	async fn settle() {
		sleep(Duration::from_millis(1)).await;
	}

	#[tokio::test]
	async fn test_announce() {
		let mut producer = OriginProducer::new();
		let broadcast1 = BroadcastProducer::new();
		let broadcast2 = BroadcastProducer::new();

		// A consumer created before anything is published starts out empty.
		let mut consumer1 = producer.consume_all();
		consumer1.assert_next_wait();

		// First broadcast goes out.
		producer.publish("test1", broadcast1.consume());

		consumer1.assert_next("test1", &broadcast1.consume());
		consumer1.assert_next_wait();

		// A late consumer should be handed the existing broadcast,
		// but it is not read until further down.
		let mut consumer2 = producer.consume_all();

		// Second broadcast goes out.
		producer.publish("test2", broadcast2.consume());

		consumer1.assert_next("test2", &broadcast2.consume());
		consumer1.assert_next_wait();

		consumer2.assert_next("test1", &broadcast1.consume());
		consumer2.assert_next("test2", &broadcast2.consume());
		consumer2.assert_next_wait();

		// Closing the first broadcast triggers its cleanup task.
		drop(broadcast1);
		settle().await;

		// Every consumer is told that the path went away.
		consumer1.assert_next_none("test1");
		consumer2.assert_next_none("test1");
		consumer1.assert_next_wait();
		consumer2.assert_next_wait();

		// A brand-new consumer only sees what is still published.
		let mut consumer3 = producer.consume_all();
		consumer3.assert_next("test2", &broadcast2.consume());
		consumer3.assert_next_wait();

		// Dropping the producer should unannounce everything and close the streams.
		drop(producer);
		settle().await;

		consumer1.assert_next_none("test2");
		consumer2.assert_next_none("test2");
		consumer3.assert_next_none("test2");

		consumer1.assert_next_closed();
		consumer2.assert_next_closed();
		consumer3.assert_next_closed();
	}

	#[tokio::test]
	async fn test_duplicate() {
		let mut producer = OriginProducer::new();
		let broadcast1 = BroadcastProducer::new();
		let broadcast2 = BroadcastProducer::new();

		producer.publish("test", broadcast1.consume());
		producer.publish("test", broadcast2.consume());
		assert!(producer.consume("test").is_some());

		// Closing the older broadcast leaves the newer one in place.
		drop(broadcast1);
		settle().await;
		assert!(producer.consume("test").is_some());

		// Closing the newer one as well empties the path.
		drop(broadcast2);
		settle().await;
		assert!(producer.consume("test").is_none());
	}

	#[tokio::test]
	async fn test_duplicate_reverse() {
		let mut producer = OriginProducer::new();
		let broadcast1 = BroadcastProducer::new();
		let broadcast2 = BroadcastProducer::new();

		producer.publish("test", broadcast1.consume());
		producer.publish("test", broadcast2.consume());
		assert!(producer.consume("test").is_some());

		// The harder order: the newer broadcast closes first, so the older one must take over.
		drop(broadcast2);
		settle().await;
		assert!(producer.consume("test").is_some());

		drop(broadcast1);
		settle().await;
		assert!(producer.consume("test").is_none());
	}

	#[tokio::test]
	async fn test_double_publish() {
		let mut producer = OriginProducer::new();
		let broadcast = BroadcastProducer::new();

		// Publishing the same broadcast twice must not crash.
		producer.publish("test", broadcast.consume());
		producer.publish("test", broadcast.consume());

		assert!(producer.consume("test").is_some());

		drop(broadcast);
		settle().await;
		assert!(producer.consume("test").is_none());
	}
}