moq_lite/model/
origin.rs

1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::{Lock, LockWeak};
4
5use super::BroadcastConsumer;
6use crate::{Path, PathRef};
7
// If there are multiple broadcasts with the same path, we use the most recent one but keep the others around.
// The others are kept so that one of them can be re-announced if the active broadcast is removed first.
struct BroadcastState {
	// The broadcast currently announced to consumers for this path.
	active: BroadcastConsumer,
	// Other broadcasts published under the same path that are still tracked.
	// The last element is the one promoted when `active` is removed.
	backup: Vec<BroadcastConsumer>,
}
13
// State shared by `OriginProducer` handles (strongly) and `OriginConsumer` handles (weakly).
#[derive(Default)]
struct ProducerState {
	// Every published path (producer prefix included) and the broadcasts tracked for it.
	active: HashMap<Path, BroadcastState>,
	// One entry per consumer subscription.
	// Entries are pruned when sending an update to them fails or their channel is found closed.
	consumers: Vec<ConsumerState>,
}
19
20impl ProducerState {
21	// Returns true if this was a unique broadcast.
22	fn publish(&mut self, path: Path, broadcast: BroadcastConsumer) -> bool {
23		let mut unique = true;
24
25		match self.active.entry(path.clone()) {
26			hash_map::Entry::Occupied(mut entry) => {
27				let state = entry.get_mut();
28				if state.active.is_clone(&broadcast) {
29					// If we're already publishing this broadcast, then don't do anything.
30					return false;
31				}
32
33				// Make the new broadcast the active one.
34				let old = state.active.clone();
35				state.active = broadcast.clone();
36
37				// Move the old broadcast to the backup list.
38				// But we need to replace any previous duplicates.
39				let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
40				if let Some(pos) = pos {
41					state.backup[pos] = old;
42
43					// We're already publishing this broadcast, so don't run the cleanup task.
44					unique = false;
45				} else {
46					state.backup.push(old);
47				}
48
49				// Reannounce the path to all consumers.
50				retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
51			}
52			hash_map::Entry::Vacant(entry) => {
53				entry.insert(BroadcastState {
54					active: broadcast.clone(),
55					backup: Vec::new(),
56				});
57			}
58		};
59
60		retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));
61
62		unique
63	}
64
65	fn remove(&mut self, path: Path, broadcast: BroadcastConsumer) {
66		let mut entry = match self.active.entry(path.clone()) {
67			hash_map::Entry::Occupied(entry) => entry,
68			hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
69		};
70
71		// See if we can remove the broadcast from the backup list.
72		let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
73		if let Some(pos) = pos {
74			entry.get_mut().backup.remove(pos);
75			// Nothing else to do
76			return;
77		}
78
79		// Okay so it must be the active broadcast or else we fucked up.
80		assert!(entry.get().active.is_clone(&broadcast));
81
82		retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
83
84		// If there's a backup broadcast, then announce it.
85		if let Some(active) = entry.get_mut().backup.pop() {
86			entry.get_mut().active = active;
87			retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &entry.get().active));
88		} else {
89			// No more backups, so remove the entry.
90			entry.remove();
91		}
92	}
93}
94
95impl Drop for ProducerState {
96	fn drop(&mut self) {
97		for (path, _) in self.active.drain() {
98			retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
99		}
100	}
101}
102
// A faster version of retain_mut that doesn't maintain the order.
//
// Calls `f` once per element and keeps the elements for which it returns true.
// Rejected elements are removed with `swap_remove`, so the surviving elements may be reordered.
// Like `Vec::retain_mut`, the predicate may be stateful (`FnMut`).
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut f: F) {
	let mut i = 0;
	while let Some(item) = vec.get_mut(i) {
		if f(item) {
			i += 1;
		} else {
			// The last element is moved into slot `i`, so check that slot again without advancing.
			vec.swap_remove(i);
		}
	}
}
114
/// A broadcast path and its associated consumer, or None if closed.
/// The returned path is relative to the consumer's prefix.
pub struct OriginUpdate {
	/// The broadcast path with the receiving consumer's prefix stripped.
	pub suffix: Path,
	/// `Some` when the path is announced, `None` when it is unannounced.
	pub active: Option<BroadcastConsumer>,
}
121
// The producer-side half of a consumer subscription.
struct ConsumerState {
	// Only paths under this prefix are forwarded; the prefix is stripped before sending.
	prefix: Path,
	// Sending half of the channel read by the matching `OriginConsumer`.
	updates: mpsc::UnboundedSender<OriginUpdate>,
}
126
127impl ConsumerState {
128	// Returns true if the consumer is still alive.
129	pub fn insert<'a>(&mut self, path: impl Into<PathRef<'a>>, consumer: &BroadcastConsumer) -> bool {
130		let path_ref = path.into();
131
132		if let Some(suffix) = path_ref.to_owned().strip_prefix(&self.prefix) {
133			// Send the absolute path, not the relative suffix
134			let update = OriginUpdate {
135				suffix: suffix.into(),
136				active: Some(consumer.clone()),
137			};
138			return self.updates.send(update).is_ok();
139		}
140
141		!self.updates.is_closed()
142	}
143
144	pub fn remove<'a>(&mut self, path: impl Into<PathRef<'a>>) -> bool {
145		let path_ref = path.into();
146
147		if let Some(suffix) = path_ref.to_owned().strip_prefix(&self.prefix) {
148			// Send the absolute path, not the relative suffix
149			let update = OriginUpdate {
150				suffix: suffix.into(),
151				active: None,
152			};
153			return self.updates.send(update).is_ok();
154		}
155
156		!self.updates.is_closed()
157	}
158}
159
/// Announces broadcasts to consumers over the network.
///
/// Cloning yields another handle to the same shared state.
#[derive(Clone, Default)]
pub struct OriginProducer {
	// Prepended to every suffix passed to this handle.
	prefix: Path,
	// Shared with every clone and sub-prefix handle; consumers hold a weak reference to it.
	state: Lock<ProducerState>,
}
166
167impl OriginProducer {
168	/// Create a new origin producer with an optional prefix applied to all broadcasts.
169	pub fn new<T: Into<Path>>(prefix: T) -> Self {
170		let prefix = prefix.into();
171		Self {
172			state: Lock::new(ProducerState {
173				active: HashMap::new(),
174				consumers: Vec::new(),
175			}),
176			prefix,
177		}
178	}
179
180	/// Publish a broadcast, announcing it to all consumers.
181	///
182	/// The broadcast will be unannounced when it is closed.
183	/// If there is already a broadcast with the same path, then it will be replaced and reannounced.
184	/// If the old broadcast is closed before the new one, then nothing will happen.
185	/// If the new broadcast is closed before the old one, then the old broadcast will be reannounced.
186	pub fn publish<'a>(&mut self, suffix: impl Into<PathRef<'a>>, broadcast: BroadcastConsumer) {
187		let path = self.prefix.join(suffix.into());
188
189		if !self.state.lock().publish(path.clone(), broadcast.clone()) {
190			// This is not a big deal, but we want to avoid spawning additional cleanup tasks.
191			tracing::warn!(?path, "duplicate publish");
192			return;
193		}
194
195		let state = self.state.clone().downgrade();
196
197		// TODO cancel this task when the producer is dropped.
198		web_async::spawn(async move {
199			broadcast.closed().await;
200			if let Some(state) = state.upgrade() {
201				state.lock().remove(path, broadcast);
202			}
203		});
204	}
205
206	/// Create a new origin producer for the given sub-prefix.
207	pub fn publish_prefix<'a>(&mut self, prefix: impl Into<PathRef<'a>>) -> OriginProducer {
208		OriginProducer {
209			prefix: self.prefix.join(prefix.into()),
210			state: self.state.clone(),
211		}
212	}
213
214	/// Get a specific broadcast by path.
215	///
216	/// The most recent, non-closed broadcast will be returned if there are duplicates.
217	pub fn consume<'a>(&self, suffix: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
218		let path = self.prefix.join(suffix.into());
219		self.state.lock().active.get(&path).map(|b| b.active.clone())
220	}
221
222	/// Subscribe to all announced broadcasts.
223	pub fn consume_all(&self) -> OriginConsumer {
224		self.consume_prefix("")
225	}
226
227	/// Subscribe to all announced broadcasts matching the prefix.
228	///
229	/// NOTE: This takes a Suffix because it's appended to the existing prefix to get a new prefix.
230	/// Confusing I know, but it means that we don't have to return a Result.
231	pub fn consume_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> OriginConsumer {
232		// Combine the consumer's prefix with the publisher's prefix.
233		let prefix = self.prefix.join(prefix.into());
234
235		let mut state = self.state.lock();
236
237		let (tx, rx) = mpsc::unbounded_channel();
238		let mut consumer = ConsumerState {
239			prefix: prefix.clone(),
240			updates: tx,
241		};
242
243		for (path, broadcast) in &state.active {
244			consumer.insert(path, &broadcast.active);
245		}
246		state.consumers.push(consumer);
247
248		OriginConsumer {
249			prefix,
250			updates: rx,
251			producer: self.state.clone().downgrade(),
252		}
253	}
254
255	/// Wait until all consumers have been dropped.
256	///
257	/// NOTE: subscribe can be called to unclose the producer.
258	pub async fn unused(&self) {
259		// Keep looping until all consumers are closed.
260		while let Some(notify) = self.unused_inner() {
261			notify.closed().await;
262		}
263	}
264
265	// Returns the closed notify of any consumer.
266	fn unused_inner(&self) -> Option<mpsc::UnboundedSender<OriginUpdate>> {
267		let mut state = self.state.lock();
268
269		while let Some(consumer) = state.consumers.last() {
270			if !consumer.updates.is_closed() {
271				return Some(consumer.updates.clone());
272			}
273
274			state.consumers.pop();
275		}
276
277		None
278	}
279
280	pub fn prefix(&self) -> &Path {
281		&self.prefix
282	}
283}
284
/// Consumes announced broadcasts matching against an optional prefix.
pub struct OriginConsumer {
	// We need a weak reference to the producer so that we can clone it.
	// It is also used to look up broadcasts and to create sub-prefix consumers.
	producer: LockWeak<ProducerState>,
	// Receiving half of the channel fed by the producer-side `ConsumerState`.
	updates: mpsc::UnboundedReceiver<OriginUpdate>,
	// Only broadcasts under this prefix are delivered; suffixes in updates are relative to it.
	prefix: Path,
}
292
293impl OriginConsumer {
294	/// Returns the next (un)announced broadcast and the absolute path.
295	///
296	/// The broadcast will only be None if it was previously Some.
297	/// The same path won't be announced/unannounced twice, instead it will toggle.
298	///
299	/// Note: The returned path is absolute and will always match this consumer's prefix.
300	pub async fn next(&mut self) -> Option<OriginUpdate> {
301		self.updates.recv().await
302	}
303
304	/// Get a specific broadcast by path.
305	///
306	/// This is relative to the consumer's prefix.
307	/// Returns None if the path hasn't been announced yet.
308	pub fn consume<'a>(&self, suffix: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
309		let path = self.prefix.join(suffix.into());
310
311		let state = self.producer.upgrade()?;
312		let state = state.lock();
313		state.active.get(&path).map(|b| b.active.clone())
314	}
315
316	pub fn consume_all(&self) -> OriginConsumer {
317		self.consume_prefix("")
318	}
319
320	pub fn consume_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> OriginConsumer {
321		// Combine the consumer's prefix with the existing consumer's prefix.
322		let prefix = self.prefix.join(prefix.into());
323
324		let (tx, rx) = mpsc::unbounded_channel();
325
326		// NOTE: consumer is immediately dropped, signalling FIN, if the producer can't be upgraded.
327		let mut consumer = ConsumerState {
328			prefix: prefix.clone(),
329			updates: tx,
330		};
331
332		if let Some(state) = self.producer.upgrade() {
333			let mut state = state.lock();
334
335			for (path, broadcast) in &state.active {
336				consumer.insert(path, &broadcast.active);
337			}
338
339			state.consumers.push(consumer);
340		}
341
342		OriginConsumer {
343			prefix,
344			updates: rx,
345			producer: self.producer.clone(),
346		}
347	}
348
349	pub fn prefix(&self) -> &Path {
350		&self.prefix
351	}
352}
353
354impl Clone for OriginConsumer {
355	fn clone(&self) -> Self {
356		self.consume_all()
357	}
358}
359
360#[cfg(test)]
361use futures::FutureExt;
362
#[cfg(test)]
impl OriginConsumer {
	// Asserts that an announce of `broadcast` at `path` is ready right now.
	pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
		let polled = self.next().now_or_never();
		let update = polled.expect("next blocked").expect("no next");

		assert_eq!(update.suffix.as_str(), path, "wrong path");

		let active = update.active.unwrap();
		assert!(active.is_clone(broadcast), "should be the same broadcast");
	}

	// Asserts that an unannounce of `path` is ready right now.
	pub fn assert_next_none(&mut self, path: &str) {
		let polled = self.next().now_or_never();
		let update = polled.expect("next blocked").expect("no next");

		assert_eq!(update.suffix.as_str(), path, "wrong path");
		assert!(update.active.is_none(), "should be unannounced");
	}

	// Asserts that no update is ready and the channel is still open.
	pub fn assert_next_wait(&mut self) {
		let polled = self.next().now_or_never();
		assert!(polled.is_none(), "next should block");
	}

	// Asserts that the channel is closed with nothing left to read.
	pub fn assert_next_closed(&mut self) {
		let polled = self.next().now_or_never().expect("next blocked");
		assert!(polled.is_none(), "next should be closed");
	}
}
388
#[cfg(test)]
mod tests {
	use super::*;
	use crate::BroadcastProducer;

	// Yield long enough for the spawned cleanup tasks to run.
	async fn settle() {
		tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
	}

	#[tokio::test]
	async fn test_announce() {
		let mut origin = OriginProducer::default();
		let first = BroadcastProducer::new();
		let second = BroadcastProducer::new();

		// A consumer created before anything is published has nothing to read.
		let mut early = origin.consume_all();
		early.assert_next_wait();

		// Publish the first broadcast.
		origin.publish("test1", first.consume());

		early.assert_next("test1", &first.consume());
		early.assert_next_wait();

		// A consumer created now should be handed the existing broadcast.
		// We deliberately don't read from it yet.
		let mut late = origin.consume_all();

		// Publish the second broadcast.
		origin.publish("test2", second.consume());

		early.assert_next("test2", &second.consume());
		early.assert_next_wait();

		late.assert_next("test1", &first.consume());
		late.assert_next("test2", &second.consume());
		late.assert_next_wait();

		// Close the first broadcast and let the cleanup task run.
		drop(first);
		settle().await;

		// Every consumer should see the unannounce.
		early.assert_next_none("test1");
		late.assert_next_none("test1");
		early.assert_next_wait();
		late.assert_next_wait();

		// A brand new consumer only sees the broadcast that is still alive.
		let mut newest = origin.consume_all();
		newest.assert_next("test2", &second.consume());
		newest.assert_next_wait();

		// Dropping the producer should unannounce everything and close the consumers.
		drop(origin);
		settle().await;

		early.assert_next_none("test2");
		late.assert_next_none("test2");
		newest.assert_next_none("test2");

		early.assert_next_closed();
		late.assert_next_closed();
		newest.assert_next_closed();
	}

	#[tokio::test]
	async fn test_duplicate() {
		let mut origin = OriginProducer::default();
		let older = BroadcastProducer::new();
		let newer = BroadcastProducer::new();

		origin.publish("test", older.consume());
		origin.publish("test", newer.consume());
		assert!(origin.consume("test").is_some());

		// Closing the older (backup) broadcast leaves the newer one in place.
		drop(older);
		settle().await;
		assert!(origin.consume("test").is_some());

		// Closing the newer one removes the path entirely.
		drop(newer);
		settle().await;
		assert!(origin.consume("test").is_none());
	}

	#[tokio::test]
	async fn test_duplicate_reverse() {
		let mut origin = OriginProducer::default();
		let older = BroadcastProducer::new();
		let newer = BroadcastProducer::new();

		origin.publish("test", older.consume());
		origin.publish("test", newer.consume());
		assert!(origin.consume("test").is_some());

		// This is the harder case: the newer (active) broadcast goes away first,
		// so the older one has to take over.
		drop(newer);
		settle().await;
		assert!(origin.consume("test").is_some());

		drop(older);
		settle().await;
		assert!(origin.consume("test").is_none());
	}

	#[tokio::test]
	async fn test_double_publish() {
		let mut origin = OriginProducer::default();
		let broadcast = BroadcastProducer::new();

		// Publishing the same broadcast twice must not crash.
		origin.publish("test", broadcast.consume());
		origin.publish("test", broadcast.consume());

		assert!(origin.consume("test").is_some());

		drop(broadcast);
		settle().await;
		assert!(origin.consume("test").is_none());
	}
}