moq_lite/model/
origin.rs

1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::Lock;
4
5use super::BroadcastConsumer;
6
7// If there are multiple broadcasts with the same path, we use the most recent one but keep the others around.
8struct BroadcastState {
9	active: BroadcastConsumer,
10	backup: Vec<BroadcastConsumer>,
11}
12
13#[derive(Default)]
14struct ProducerState {
15	active: HashMap<String, BroadcastState>,
16	consumers: Vec<ConsumerState>,
17}
18
19impl ProducerState {
20	// Returns true if this was a unique broadcast.
21	fn publish(&mut self, path: String, broadcast: BroadcastConsumer) -> bool {
22		let mut unique = true;
23
24		match self.active.entry(path.clone()) {
25			hash_map::Entry::Occupied(mut entry) => {
26				let state = entry.get_mut();
27				if state.active.is_clone(&broadcast) {
28					// If we're already publishing this broadcast, then don't do anything.
29					return false;
30				}
31
32				// Make the new broadcast the active one.
33				let old = state.active.clone();
34				state.active = broadcast.clone();
35
36				// Move the old broadcast to the backup list.
37				// But we need to replace any previous duplicates.
38				let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
39				if let Some(pos) = pos {
40					state.backup[pos] = old;
41
42					// We're already publishing this broadcast, so don't run the cleanup task.
43					unique = false;
44				} else {
45					state.backup.push(old);
46				}
47
48				// Reannounce the path to all consumers.
49				retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
50			}
51			hash_map::Entry::Vacant(entry) => {
52				entry.insert(BroadcastState {
53					active: broadcast.clone(),
54					backup: Vec::new(),
55				});
56			}
57		};
58
59		retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
60
61		unique
62	}
63
64	fn remove(&mut self, path: String, broadcast: BroadcastConsumer) {
65		let mut entry = match self.active.entry(path) {
66			hash_map::Entry::Occupied(entry) => entry,
67			hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
68		};
69
70		// See if we can remove the broadcast from the backup list.
71		let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
72		if let Some(pos) = pos {
73			entry.get_mut().backup.remove(pos);
74			// Nothing else to do
75			return;
76		}
77
78		// Okay so it must be the active broadcast or else we fucked up.
79		assert!(entry.get().active.is_clone(&broadcast));
80
81		retain_mut_unordered(&mut self.consumers, |c| c.remove(entry.key()));
82
83		// If there's a backup broadcast, then announce it.
84		if let Some(active) = entry.get_mut().backup.pop() {
85			entry.get_mut().active = active;
86			retain_mut_unordered(&mut self.consumers, |c| c.insert(entry.key(), &entry.get().active));
87		} else {
88			// No more backups, so remove the entry.
89			entry.remove();
90		}
91	}
92}
93
94impl Drop for ProducerState {
95	fn drop(&mut self) {
96		for (path, _) in self.active.drain() {
97			retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
98		}
99	}
100}
101
/// A faster version of `Vec::retain_mut` that doesn't maintain the order.
///
/// `f` is called exactly once for every element and may mutate it. Elements for which
/// it returns `false` are removed with `swap_remove`, which moves the last element into
/// the vacated slot instead of shifting the tail, so the order of the survivors is
/// unspecified.
///
/// The predicate is `FnMut`, like `Vec::retain_mut`, so it may carry mutable state.
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut f: F) {
	let mut i = 0;
	while let Some(item) = vec.get_mut(i) {
		if f(item) {
			i += 1;
		} else {
			// Don't advance: the element swapped into slot `i` still has to be checked.
			vec.swap_remove(i);
		}
	}
}
113
114/// A broadcast path and its associated broadcast, or None if closed.
115type ConsumerUpdate = (String, Option<BroadcastConsumer>);
116
117struct ConsumerState {
118	prefix: String,
119	exact: bool,
120	updates: mpsc::UnboundedSender<ConsumerUpdate>,
121}
122
123impl ConsumerState {
124	// Returns true if the consuemr is still alive.
125	pub fn insert(&mut self, path: &str, consumer: &BroadcastConsumer) -> bool {
126		if self.exact {
127			if path == self.prefix {
128				let update = ("".to_string(), Some(consumer.clone()));
129				return self.updates.send(update).is_ok();
130			}
131		} else if let Some(suffix) = path.strip_prefix(&self.prefix) {
132			let update = (suffix.to_string(), Some(consumer.clone()));
133			return self.updates.send(update).is_ok();
134		}
135
136		!self.updates.is_closed()
137	}
138
139	pub fn remove(&mut self, path: &str) -> bool {
140		if self.exact {
141			if path == self.prefix {
142				let update = ("".to_string(), None);
143				return self.updates.send(update).is_ok();
144			}
145		} else if let Some(suffix) = path.strip_prefix(&self.prefix) {
146			let update = (suffix.to_string(), None);
147			return self.updates.send(update).is_ok();
148		}
149
150		!self.updates.is_closed()
151	}
152}
153
154/// Announces broadcasts to consumers over the network.
155#[derive(Clone, Default)]
156pub struct OriginProducer {
157	state: Lock<ProducerState>,
158}
159
160impl OriginProducer {
161	pub fn new() -> Self {
162		Self {
163			state: Lock::new(ProducerState {
164				active: HashMap::new(),
165				consumers: Vec::new(),
166			}),
167		}
168	}
169
170	/// Publish a broadcast, announcing it to all consumers.
171	///
172	/// The broadcast will be unannounced when it is closed.
173	/// If there is already a broadcast with the same path, then it will be replaced and reannounced.
174	/// If the old broadcast is closed before the new one, then nothing will happen.
175	/// If the new broadcast is closed before the old one, then the old broadcast will be reannounced.
176	pub fn publish<S: ToString>(&mut self, path: S, broadcast: BroadcastConsumer) {
177		let path = path.to_string();
178
179		if !self.state.lock().publish(path.clone(), broadcast.clone()) {
180			// This is not a big deal, but we want to avoid spawning additional cleanup tasks.
181			tracing::warn!(?path, "duplicate publish");
182			return;
183		}
184
185		let state = self.state.clone().downgrade();
186
187		// TODO cancel this task when the producer is dropped.
188		web_async::spawn(async move {
189			broadcast.closed().await;
190			if let Some(state) = state.upgrade() {
191				state.lock().remove(path, broadcast);
192			}
193		});
194	}
195
196	/// Publish all broadcasts from the given origin.
197	pub fn publish_all(&mut self, broadcasts: OriginConsumer) {
198		self.publish_prefix("", broadcasts);
199	}
200
201	/// Publish all broadcasts from the given origin with an optional prefix.
202	pub fn publish_prefix(&mut self, prefix: &str, mut broadcasts: OriginConsumer) {
203		// Really gross that this just spawns a background task, but I want publishing to be sync.
204		let mut this = self.clone();
205
206		// Overkill to avoid allocating a string if the prefix is empty.
207		let prefix = match prefix {
208			"" => None,
209			prefix => Some(prefix.to_string()),
210		};
211
212		web_async::spawn(async move {
213			while let Some((suffix, broadcast)) = broadcasts.next().await {
214				let broadcast = match broadcast {
215					Some(broadcast) => broadcast,
216					// We don't need to worry about unannouncements here because our own OriginPublisher will handle it.
217					// Announcements are ordered so I don't think there's a race condition?
218					None => continue,
219				};
220
221				let path = match &prefix {
222					Some(prefix) => format!("{prefix}{suffix}"),
223					None => suffix,
224				};
225
226				this.publish(path, broadcast);
227			}
228		});
229	}
230
231	/// Get a specific broadcast by name.
232	///
233	/// The most recent, non-closed broadcast will be returned if there are duplicates.
234	pub fn consume(&self, path: &str) -> Option<BroadcastConsumer> {
235		self.state.lock().active.get(path).map(|b| b.active.clone())
236	}
237
238	/// Subscribe to all announced broadcasts.
239	pub fn consume_all(&self) -> OriginConsumer {
240		self.consume_prefix("")
241	}
242
243	/// Subscribe to all announced broadcasts matching the prefix.
244	pub fn consume_prefix<S: ToString>(&self, prefix: S) -> OriginConsumer {
245		let mut state = self.state.lock();
246
247		let (tx, rx) = mpsc::unbounded_channel();
248		let mut consumer = ConsumerState {
249			prefix: prefix.to_string(),
250			exact: false,
251			updates: tx,
252		};
253
254		for (prefix, broadcast) in &state.active {
255			consumer.insert(prefix, &broadcast.active);
256		}
257		state.consumers.push(consumer);
258
259		OriginConsumer::new(rx)
260	}
261
262	/// Wait for an exact broadcast to be announced.
263	pub fn consume_exact<S: ToString>(&self, path: S) -> OriginConsumer {
264		let mut state = self.state.lock();
265
266		let (tx, rx) = mpsc::unbounded_channel();
267		let mut consumer = ConsumerState {
268			prefix: path.to_string(),
269			exact: true,
270			updates: tx,
271		};
272
273		for (prefix, broadcast) in &state.active {
274			consumer.insert(prefix, &broadcast.active);
275		}
276		state.consumers.push(consumer);
277
278		OriginConsumer::new(rx)
279	}
280
281	/// Wait until all consumers have been dropped.
282	///
283	/// NOTE: subscribe can be called to unclose the producer.
284	pub async fn unused(&self) {
285		// Keep looping until all consumers are closed.
286		while let Some(notify) = self.unused_inner() {
287			notify.closed().await;
288		}
289	}
290
291	// Returns the closed notify of any consumer.
292	fn unused_inner(&self) -> Option<mpsc::UnboundedSender<ConsumerUpdate>> {
293		let mut state = self.state.lock();
294
295		while let Some(consumer) = state.consumers.last() {
296			if !consumer.updates.is_closed() {
297				return Some(consumer.updates.clone());
298			}
299
300			state.consumers.pop();
301		}
302
303		None
304	}
305}
306
307/// Consumes announced broadcasts matching against an optional prefix.
308pub struct OriginConsumer {
309	updates: mpsc::UnboundedReceiver<ConsumerUpdate>,
310}
311
312impl OriginConsumer {
313	fn new(updates: mpsc::UnboundedReceiver<ConsumerUpdate>) -> Self {
314		Self { updates }
315	}
316
317	/// Returns the next (un)announced broadcast and the path.
318	///
319	/// The broadcast will only be None if it was previously Some.
320	/// The same path won't be announced/unannounced twice, instead it will toggle.
321	pub async fn next(&mut self) -> Option<ConsumerUpdate> {
322		self.updates.recv().await
323	}
324}
325
326#[cfg(test)]
327use futures::FutureExt;
328
// Test-only helpers that poll `next()` once without awaiting.
#[cfg(test)]
impl OriginConsumer {
	/// Assert that `broadcast` is announced at `path` right now.
	pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
		let (actual, update) = self.next().now_or_never().expect("next blocked").expect("no next");
		assert_eq!(actual, path, "wrong path");
		assert!(update.unwrap().is_clone(broadcast), "should be the same broadcast");
	}

	/// Assert that `path` is unannounced right now.
	pub fn assert_next_none(&mut self, path: &str) {
		let (actual, update) = self.next().now_or_never().expect("next blocked").expect("no next");
		assert_eq!(actual, path, "wrong path");
		assert!(update.is_none(), "should be unannounced");
	}

	/// Assert that no update is available yet.
	pub fn assert_next_wait(&mut self) {
		let polled = self.next().now_or_never();
		assert!(polled.is_none(), "next should block");
	}

	/// Assert that the consumer is closed, with no updates left.
	pub fn assert_next_closed(&mut self) {
		let polled = self.next().now_or_never().expect("next blocked");
		assert!(polled.is_none(), "next should be closed");
	}
}
354
#[cfg(test)]
mod tests {
	use crate::BroadcastProducer;

	use super::*;

	/// Sleep briefly so the spawned cleanup tasks get a chance to run.
	async fn settle() {
		tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
	}

	/// Announcements reach existing and late consumers, and closing unannounces them.
	#[tokio::test]
	async fn test_announce() {
		let mut origin = OriginProducer::new();
		let first = BroadcastProducer::new();
		let second = BroadcastProducer::new();

		// A consumer registered before anything is published starts out empty.
		let mut early = origin.consume_all();
		early.assert_next_wait();

		// Publish the first broadcast.
		origin.publish("test1", first.consume());

		early.assert_next("test1", &first.consume());
		early.assert_next_wait();

		// A consumer registered now should get the existing broadcast.
		// But we don't consume it yet.
		let mut late = origin.consume_all();

		// Publish the second broadcast.
		origin.publish("test2", second.consume());

		early.assert_next("test2", &second.consume());
		early.assert_next_wait();

		late.assert_next("test1", &first.consume());
		late.assert_next("test2", &second.consume());
		late.assert_next_wait();

		// Close the first broadcast.
		drop(first);
		settle().await;

		// All consumers should get a None now.
		early.assert_next_none("test1");
		late.assert_next_none("test1");
		early.assert_next_wait();
		late.assert_next_wait();

		// And a new consumer only gets the last broadcast.
		let mut latest = origin.consume_all();
		latest.assert_next("test2", &second.consume());
		latest.assert_next_wait();

		// Close the producer and make sure it cleans up.
		drop(origin);
		settle().await;

		early.assert_next_none("test2");
		late.assert_next_none("test2");
		latest.assert_next_none("test2");

		early.assert_next_closed();
		late.assert_next_closed();
		latest.assert_next_closed();
	}

	/// Two broadcasts on one path: closing the older one first keeps the path alive.
	#[tokio::test]
	async fn test_duplicate() {
		let mut origin = OriginProducer::new();
		let older = BroadcastProducer::new();
		let newer = BroadcastProducer::new();

		origin.publish("test", older.consume());
		origin.publish("test", newer.consume());
		assert!(origin.consume("test").is_some());

		drop(older);
		settle().await;
		assert!(origin.consume("test").is_some());

		drop(newer);
		settle().await;
		assert!(origin.consume("test").is_none());
	}

	/// Two broadcasts on one path: closing the newer one first falls back to the older one.
	#[tokio::test]
	async fn test_duplicate_reverse() {
		let mut origin = OriginProducer::new();
		let older = BroadcastProducer::new();
		let newer = BroadcastProducer::new();

		origin.publish("test", older.consume());
		origin.publish("test", newer.consume());
		assert!(origin.consume("test").is_some());

		// This is harder, dropping the new broadcast first.
		drop(newer);
		settle().await;
		assert!(origin.consume("test").is_some());

		drop(older);
		settle().await;
		assert!(origin.consume("test").is_none());
	}

	/// Publishing the same broadcast twice must not crash or leave a stale entry.
	#[tokio::test]
	async fn test_double_publish() {
		let mut origin = OriginProducer::new();
		let broadcast = BroadcastProducer::new();

		// Ensure it doesn't crash.
		origin.publish("test", broadcast.consume());
		origin.publish("test", broadcast.consume());

		assert!(origin.consume("test").is_some());

		drop(broadcast);
		settle().await;
		assert!(origin.consume("test").is_none());
	}
}
487}