moq_lite/model/
origin.rs

1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::{Lock, LockWeak};
4
5use super::BroadcastConsumer;
6use crate::{Path, PathRef};
7
8// If there are multiple broadcasts with the same path, we use the most recent one but keep the others around.
9struct BroadcastState {
10	active: BroadcastConsumer,
11	backup: Vec<BroadcastConsumer>,
12}
13
14#[derive(Default)]
15struct ProducerState {
16	active: HashMap<Path, BroadcastState>,
17	consumers: Vec<ConsumerState>,
18}
19
20impl ProducerState {
21	// Returns true if this was a unique broadcast.
22	fn publish(&mut self, path: Path, broadcast: BroadcastConsumer) -> bool {
23		let mut unique = true;
24
25		match self.active.entry(path.clone()) {
26			hash_map::Entry::Occupied(mut entry) => {
27				let state = entry.get_mut();
28				if state.active.is_clone(&broadcast) {
29					// If we're already publishing this broadcast, then don't do anything.
30					return false;
31				}
32
33				// Make the new broadcast the active one.
34				let old = state.active.clone();
35				state.active = broadcast.clone();
36
37				// Move the old broadcast to the backup list.
38				// But we need to replace any previous duplicates.
39				let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
40				if let Some(pos) = pos {
41					state.backup[pos] = old;
42
43					// We're already publishing this broadcast, so don't run the cleanup task.
44					unique = false;
45				} else {
46					state.backup.push(old);
47				}
48
49				// Reannounce the path to all consumers.
50				retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
51			}
52			hash_map::Entry::Vacant(entry) => {
53				entry.insert(BroadcastState {
54					active: broadcast.clone(),
55					backup: Vec::new(),
56				});
57			}
58		};
59
60		retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
61
62		unique
63	}
64
65	fn remove(&mut self, path: Path, broadcast: BroadcastConsumer) {
66		let mut entry = match self.active.entry(path.clone()) {
67			hash_map::Entry::Occupied(entry) => entry,
68			hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
69		};
70
71		// See if we can remove the broadcast from the backup list.
72		let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
73		if let Some(pos) = pos {
74			entry.get_mut().backup.remove(pos);
75			// Nothing else to do
76			return;
77		}
78
79		// Okay so it must be the active broadcast or else we fucked up.
80		assert!(entry.get().active.is_clone(&broadcast));
81
82		retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
83
84		// If there's a backup broadcast, then announce it.
85		if let Some(active) = entry.get_mut().backup.pop() {
86			entry.get_mut().active = active;
87			retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &entry.get().active));
88		} else {
89			// No more backups, so remove the entry.
90			entry.remove();
91		}
92	}
93}
94
95impl Drop for ProducerState {
96	fn drop(&mut self) {
97		for (path, _) in self.active.drain() {
98			retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
99		}
100	}
101}
102
/// A faster alternative to `Vec::retain_mut` that does not preserve element order.
///
/// `f` is called once per element and may mutate it. Elements for which `f`
/// returns `false` are removed with `swap_remove`, so the last element takes the
/// removed element's place instead of everything after it shifting down.
///
/// The predicate is `FnMut`, matching `Vec::retain_mut`, so it may carry state.
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut f: F) {
	let mut i = 0;
	while let Some(item) = vec.get_mut(i) {
		if f(item) {
			i += 1;
		} else {
			// The element swapped into slot `i` has not been visited yet, so stay put.
			vec.swap_remove(i);
		}
	}
}
114
115/// A broadcast path and its associated consumer, or None if closed.
116/// The returned path is relative to the consumer's prefix.
117pub struct OriginUpdate {
118	pub suffix: Path,
119	pub active: Option<BroadcastConsumer>,
120}
121
122struct ConsumerState {
123	prefix: Path,
124	updates: mpsc::UnboundedSender<OriginUpdate>,
125}
126
127impl ConsumerState {
128	// Returns true if the consumer is still alive.
129	pub fn insert<'a>(&mut self, path: impl Into<PathRef<'a>>, consumer: &BroadcastConsumer) -> bool {
130		let path_ref = path.into();
131
132		if let Some(suffix) = path_ref.to_path().strip_prefix(&self.prefix) {
133			// Send the absolute path, not the relative suffix
134			let update = OriginUpdate {
135				suffix: suffix.into(),
136				active: Some(consumer.clone()),
137			};
138			return self.updates.send(update).is_ok();
139		}
140
141		!self.updates.is_closed()
142	}
143
144	pub fn remove<'a>(&mut self, path: impl Into<PathRef<'a>>) -> bool {
145		let path_ref = path.into();
146
147		if let Some(suffix) = path_ref.to_path().strip_prefix(&self.prefix) {
148			// Send the absolute path, not the relative suffix
149			let update = OriginUpdate {
150				suffix: suffix.into(),
151				active: None,
152			};
153			return self.updates.send(update).is_ok();
154		}
155
156		!self.updates.is_closed()
157	}
158}
159
160/// Announces broadcasts to consumers over the network.
161#[derive(Clone, Default)]
162pub struct OriginProducer {
163	prefix: Path,
164	state: Lock<ProducerState>,
165}
166
167impl OriginProducer {
168	/// Create a new origin producer with an optional prefix applied to all broadcasts.
169	pub fn new<T: Into<Path>>(prefix: T) -> Self {
170		let prefix = prefix.into();
171		Self {
172			state: Lock::new(ProducerState {
173				active: HashMap::new(),
174				consumers: Vec::new(),
175			}),
176			prefix,
177		}
178	}
179
180	/// Publish a broadcast, announcing it to all consumers.
181	///
182	/// The broadcast will be unannounced when it is closed.
183	/// If there is already a broadcast with the same path, then it will be replaced and reannounced.
184	/// If the old broadcast is closed before the new one, then nothing will happen.
185	/// If the new broadcast is closed before the old one, then the old broadcast will be reannounced.
186	pub fn publish<'a>(&mut self, suffix: impl Into<PathRef<'a>>, broadcast: BroadcastConsumer) {
187		let path = self.prefix.join(suffix.into());
188
189		if !self.state.lock().publish(path.clone(), broadcast.clone()) {
190			// This is not a big deal, but we want to avoid spawning additional cleanup tasks.
191			tracing::warn!(?path, "duplicate publish");
192			return;
193		}
194
195		let state = self.state.clone().downgrade();
196
197		// TODO cancel this task when the producer is dropped.
198		web_async::spawn(async move {
199			broadcast.closed().await;
200			if let Some(state) = state.upgrade() {
201				state.lock().remove(path, broadcast);
202			}
203		});
204	}
205
206	/// Create a new origin producer for the given sub-prefix.
207	pub fn publish_prefix<'a>(&mut self, prefix: impl Into<PathRef<'a>>) -> OriginProducer {
208		OriginProducer {
209			prefix: self.prefix.join(prefix.into()),
210			state: self.state.clone(),
211		}
212	}
213
214	/// Get a specific broadcast by path.
215	///
216	/// The most recent, non-closed broadcast will be returned if there are duplicates.
217	pub fn consume<'a>(&self, suffix: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
218		let path = self.prefix.join(suffix.into());
219		self.state.lock().active.get(&path).map(|b| b.active.clone())
220	}
221
222	/// Subscribe to all announced broadcasts.
223	pub fn consume_all(&self) -> OriginConsumer {
224		self.consume_prefix("")
225	}
226
227	/// Subscribe to all announced broadcasts matching the prefix.
228	///
229	/// NOTE: This takes a Suffix because it's appended to the existing prefix to get a new prefix.
230	/// Confusing I know, but it means that we don't have to return a Result.
231	pub fn consume_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> OriginConsumer {
232		// Combine the consumer's prefix with the publisher's prefix.
233		let prefix = self.prefix.join(prefix.into());
234
235		let mut state = self.state.lock();
236
237		let (tx, rx) = mpsc::unbounded_channel();
238		let mut consumer = ConsumerState {
239			prefix: prefix.clone(),
240			updates: tx,
241		};
242
243		for (path, broadcast) in &state.active {
244			consumer.insert(path, &broadcast.active);
245		}
246		state.consumers.push(consumer);
247
248		OriginConsumer {
249			prefix,
250			updates: rx,
251			producer: self.state.clone().downgrade(),
252		}
253	}
254
255	/// Wait until all consumers have been dropped.
256	///
257	/// NOTE: subscribe can be called to unclose the producer.
258	pub async fn unused(&self) {
259		// Keep looping until all consumers are closed.
260		while let Some(notify) = self.unused_inner() {
261			notify.closed().await;
262		}
263	}
264
265	// Returns the closed notify of any consumer.
266	fn unused_inner(&self) -> Option<mpsc::UnboundedSender<OriginUpdate>> {
267		let mut state = self.state.lock();
268
269		while let Some(consumer) = state.consumers.last() {
270			if !consumer.updates.is_closed() {
271				return Some(consumer.updates.clone());
272			}
273
274			state.consumers.pop();
275		}
276
277		None
278	}
279
280	pub fn prefix(&self) -> &Path {
281		&self.prefix
282	}
283}
284
285/// Consumes announced broadcasts matching against an optional prefix.
286pub struct OriginConsumer {
287	// We need a weak reference to the producer so that we can clone it.
288	producer: LockWeak<ProducerState>,
289	updates: mpsc::UnboundedReceiver<OriginUpdate>,
290	prefix: Path,
291}
292
293impl OriginConsumer {
294	/// Returns the next (un)announced broadcast and the absolute path.
295	///
296	/// The broadcast will only be None if it was previously Some.
297	/// The same path won't be announced/unannounced twice, instead it will toggle.
298	///
299	/// Note: The returned path is absolute and will always match this consumer's prefix.
300	pub async fn next(&mut self) -> Option<OriginUpdate> {
301		self.updates.recv().await
302	}
303
304	/// Get a specific broadcast by path.
305	///
306	/// This is relative to the consumer's prefix.
307	pub fn consume<'a>(&self, suffix: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
308		let path = self.prefix.join(suffix.into());
309
310		let state = self.producer.upgrade()?;
311		let state = state.lock();
312		state.active.get(&path).map(|b| b.active.clone())
313	}
314
315	pub fn consume_all(&self) -> OriginConsumer {
316		self.consume_prefix("")
317	}
318
319	pub fn consume_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> OriginConsumer {
320		// Combine the consumer's prefix with the existing consumer's prefix.
321		let prefix = self.prefix.join(prefix.into());
322
323		let (tx, rx) = mpsc::unbounded_channel();
324
325		// NOTE: consumer is immediately dropped, signalling FIN, if the producer can't be upgraded.
326		let mut consumer = ConsumerState {
327			prefix: prefix.clone(),
328			updates: tx,
329		};
330
331		if let Some(state) = self.producer.upgrade() {
332			let mut state = state.lock();
333
334			for (path, broadcast) in &state.active {
335				consumer.insert(path, &broadcast.active);
336			}
337
338			state.consumers.push(consumer);
339		}
340
341		OriginConsumer {
342			prefix,
343			updates: rx,
344			producer: self.producer.clone(),
345		}
346	}
347
348	pub fn prefix(&self) -> &Path {
349		&self.prefix
350	}
351}
352
353impl Clone for OriginConsumer {
354	fn clone(&self) -> Self {
355		self.consume_all()
356	}
357}
358
359#[cfg(test)]
360use futures::FutureExt;
361
#[cfg(test)]
impl OriginConsumer {
	/// Asserts that an announcement of `broadcast` at `path` is already queued.
	pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
		let update = self.next().now_or_never().expect("next blocked").expect("no next");
		assert_eq!(update.suffix.as_str(), path, "wrong path");

		let active = update.active.unwrap();
		assert!(active.is_clone(broadcast), "should be the same broadcast");
	}

	/// Asserts that an unannouncement of `path` is already queued.
	pub fn assert_next_none(&mut self, path: &str) {
		let update = self.next().now_or_never().expect("next blocked").expect("no next");
		assert_eq!(update.suffix.as_str(), path, "wrong path");
		assert!(update.active.is_none(), "should be unannounced");
	}

	/// Asserts that no update is queued and the channel is still open.
	pub fn assert_next_wait(&mut self) {
		let polled = self.next().now_or_never();
		assert!(polled.is_none(), "next should block");
	}

	/// Asserts that the channel is closed and fully drained.
	pub fn assert_next_closed(&mut self) {
		let polled = self.next().now_or_never().expect("next blocked");
		assert!(polled.is_none(), "next should be closed");
	}
}
387
#[cfg(test)]
mod tests {
	use crate::BroadcastProducer;

	use super::*;

	/// Sleeps briefly so the spawned cleanup task gets a chance to run.
	async fn wait_for_cleanup() {
		tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
	}

	/// Announcements reach existing and late consumers, and closing a broadcast
	/// or dropping the producer unannounces it.
	#[tokio::test]
	async fn test_announce() {
		let mut origin = OriginProducer::default();
		let first = BroadcastProducer::new();
		let second = BroadcastProducer::new();

		// A consumer created before anything is published starts out empty.
		let mut early = origin.consume_all();
		early.assert_next_wait();

		// Publish the first broadcast.
		origin.publish("test1", first.consume());

		early.assert_next("test1", &first.consume());
		early.assert_next_wait();

		// A consumer created now should be told about the existing broadcast,
		// but we don't read from it yet.
		let mut late = origin.consume_all();

		// Publish the second broadcast.
		origin.publish("test2", second.consume());

		early.assert_next("test2", &second.consume());
		early.assert_next_wait();

		late.assert_next("test1", &first.consume());
		late.assert_next("test2", &second.consume());
		late.assert_next_wait();

		// Close the first broadcast.
		drop(first);
		wait_for_cleanup().await;

		// Every consumer should see the unannouncement.
		early.assert_next_none("test1");
		late.assert_next_none("test1");
		early.assert_next_wait();
		late.assert_next_wait();

		// A brand new consumer only hears about the remaining broadcast.
		let mut latest = origin.consume_all();
		latest.assert_next("test2", &second.consume());
		latest.assert_next_wait();

		// Dropping the producer should unannounce everything and close the channels.
		drop(origin);
		wait_for_cleanup().await;

		early.assert_next_none("test2");
		late.assert_next_none("test2");
		latest.assert_next_none("test2");

		early.assert_next_closed();
		late.assert_next_closed();
		latest.assert_next_closed();
	}

	/// Two broadcasts on one path: closing the older one first keeps the path alive.
	#[tokio::test]
	async fn test_duplicate() {
		let mut origin = OriginProducer::default();
		let older = BroadcastProducer::new();
		let newer = BroadcastProducer::new();

		origin.publish("test", older.consume());
		origin.publish("test", newer.consume());
		assert!(origin.consume("test").is_some());

		drop(older);
		wait_for_cleanup().await;
		assert!(origin.consume("test").is_some());

		drop(newer);
		wait_for_cleanup().await;
		assert!(origin.consume("test").is_none());
	}

	/// Two broadcasts on one path: closing the newer one first falls back to the older one.
	#[tokio::test]
	async fn test_duplicate_reverse() {
		let mut origin = OriginProducer::default();
		let older = BroadcastProducer::new();
		let newer = BroadcastProducer::new();

		origin.publish("test", older.consume());
		origin.publish("test", newer.consume());
		assert!(origin.consume("test").is_some());

		// The harder case: the active broadcast goes away first.
		drop(newer);
		wait_for_cleanup().await;
		assert!(origin.consume("test").is_some());

		drop(older);
		wait_for_cleanup().await;
		assert!(origin.consume("test").is_none());
	}

	/// Publishing the same broadcast twice must not crash or leave a stale entry behind.
	#[tokio::test]
	async fn test_double_publish() {
		let mut origin = OriginProducer::default();
		let broadcast = BroadcastProducer::new();

		// Ensure it doesn't crash.
		origin.publish("test", broadcast.consume());
		origin.publish("test", broadcast.consume());

		assert!(origin.consume("test").is_some());

		drop(broadcast);
		wait_for_cleanup().await;
		assert!(origin.consume("test").is_none());
	}
}