moq_lite/model/origin.rs

1use std::collections::{hash_map, HashMap};
2use tokio::sync::mpsc;
3use web_async::{Lock, LockWeak};
4
5use super::BroadcastConsumer;
6use crate::{Path, PathRef};
7
// If there are multiple broadcasts with the same path, we use the most recent one but keep the others around.
struct BroadcastState {
	// The broadcast currently announced to consumers for this path (the most recently published one).
	active: BroadcastConsumer,
	// Previously active broadcasts for the same path. New entries are appended, and the last
	// entry is promoted to `active` when the active broadcast is removed.
	backup: Vec<BroadcastConsumer>,
}
13
// State shared by every OriginProducer handle (strongly) and OriginConsumer handle (weakly)
// derived from the same origin.
#[derive(Default)]
struct ProducerState {
	// Announced broadcasts, keyed by their full path (root + prefix + path).
	active: HashMap<Path, BroadcastState>,
	// Registered subscriptions; each receives an OriginUpdate when a path under its prefix changes.
	consumers: Vec<ConsumerState>,
}
19
impl ProducerState {
	// Announce `broadcast` under the full `path`, notifying every matching consumer.
	//
	// If the path already has an active broadcast, the new one takes over and the old one is
	// kept as a backup; consumers observe an unannounce followed by a reannounce.
	// Consumers whose receiver has been dropped are pruned along the way.
	//
	// Returns true if this was a unique broadcast.
	// Returns false when this exact broadcast is already active, or already in the backup list,
	// for this path; the caller then skips spawning another cleanup task for it.
	fn publish(&mut self, path: Path, broadcast: BroadcastConsumer) -> bool {
		let mut unique = true;

		match self.active.entry(path.clone()) {
			hash_map::Entry::Occupied(mut entry) => {
				let state = entry.get_mut();
				if state.active.is_clone(&broadcast) {
					// If we're already publishing this broadcast, then don't do anything.
					return false;
				}

				// Make the new broadcast the active one.
				let old = state.active.clone();
				state.active = broadcast.clone();

				// Move the old broadcast to the backup list.
				// But we need to replace any previous duplicates.
				let pos = state.backup.iter().position(|b| b.is_clone(&broadcast));
				if let Some(pos) = pos {
					// The new broadcast was a backup; swap it out for the old active one so
					// each broadcast appears at most once for this path.
					state.backup[pos] = old;

					// We're already publishing this broadcast, so don't run the cleanup task.
					unique = false;
				} else {
					state.backup.push(old);
				}

				// Reannounce the path to all consumers.
				// Step one: unannounce the replaced broadcast; the insert below announces the new one.
				retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
			}
			hash_map::Entry::Vacant(entry) => {
				entry.insert(BroadcastState {
					active: broadcast.clone(),
					backup: Vec::new(),
				});
			}
		};

		// Announce the (now active) broadcast; dead consumers are dropped from the list.
		retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &broadcast));

		unique
	}

	// Forget `broadcast` at the full `path`; called by the cleanup task once it has closed.
	//
	// A backup is removed silently. If it was the active broadcast, consumers are told the path
	// is gone, then the most recently added backup (if any) is promoted and reannounced;
	// otherwise the path entry is removed entirely.
	//
	// Panics if the path or broadcast is not tracked, which indicates a bookkeeping bug.
	fn remove(&mut self, path: Path, broadcast: BroadcastConsumer) {
		let mut entry = match self.active.entry(path.clone()) {
			hash_map::Entry::Occupied(entry) => entry,
			hash_map::Entry::Vacant(_) => panic!("broadcast not found"),
		};

		// See if we can remove the broadcast from the backup list.
		let pos = entry.get().backup.iter().position(|b| b.is_clone(&broadcast));
		if let Some(pos) = pos {
			// `remove` (not `swap_remove`) keeps the backups in insertion order for `pop` below.
			entry.get_mut().backup.remove(pos);
			// Nothing else to do
			return;
		}

		// Okay so it must be the active broadcast, or else our bookkeeping is broken.
		assert!(entry.get().active.is_clone(&broadcast));

		retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));

		// If there's a backup broadcast, then announce it.
		if let Some(active) = entry.get_mut().backup.pop() {
			entry.get_mut().active = active;
			retain_mut_unordered(&mut self.consumers, |c| c.insert(&path, &entry.get().active));
		} else {
			// No more backups, so remove the entry.
			entry.remove();
		}
	}
}
94
95impl Drop for ProducerState {
96	fn drop(&mut self) {
97		for (path, _) in self.active.drain() {
98			retain_mut_unordered(&mut self.consumers, |c| c.remove(&path));
99		}
100	}
101}
102
// Keep only the elements for which `f` returns true, like `Vec::retain_mut`, but without
// preserving order: each rejected element is removed with `swap_remove` (O(1)), which moves
// the last element into its slot, so survivors may be reordered.
//
// Every element is visited exactly once. `f` is `FnMut` (matching `Vec::retain_mut`), so it may
// carry mutable state between calls; any `Fn` closure still satisfies this bound.
fn retain_mut_unordered<T, F: FnMut(&mut T) -> bool>(vec: &mut Vec<T>, mut f: F) {
	let mut i = 0;
	while let Some(item) = vec.get_mut(i) {
		if f(item) {
			i += 1;
		} else {
			// The former last element now sits at `i`; don't advance so it is visited next.
			vec.swap_remove(i);
		}
	}
}
114
/// A broadcast path and its associated consumer, or None when the path was unannounced
/// (the broadcast closed, was replaced by a newer one, or the origin was dropped).
/// The returned path is relative to the consumer's prefix.
pub struct OriginUpdate {
	/// The announced path with the consumer's root and prefix stripped off.
	pub suffix: Path,
	/// `Some` when the path is (re)announced, `None` when it is unannounced.
	pub active: Option<BroadcastConsumer>,
}
121
// A single subscription registered with the shared producer state.
struct ConsumerState {
	// Full path (root joined with the subscriber's prefix); only paths under it are forwarded.
	prefix: Path,
	// Sending half of the channel whose receiver is owned by the matching OriginConsumer.
	updates: mpsc::UnboundedSender<OriginUpdate>,
}
126
127impl ConsumerState {
128	// Returns true if the consumer is still alive.
129	pub fn insert<'a>(&mut self, path: impl Into<PathRef<'a>>, consumer: &BroadcastConsumer) -> bool {
130		let path_ref = path.into();
131
132		if let Some(suffix) = path_ref.to_owned().strip_prefix(&self.prefix) {
133			// Send the absolute path, not the relative suffix
134			let update = OriginUpdate {
135				suffix: suffix.into(),
136				active: Some(consumer.clone()),
137			};
138			return self.updates.send(update).is_ok();
139		}
140
141		!self.updates.is_closed()
142	}
143
144	pub fn remove<'a>(&mut self, path: impl Into<PathRef<'a>>) -> bool {
145		let path_ref = path.into();
146
147		if let Some(suffix) = path_ref.to_owned().strip_prefix(&self.prefix) {
148			// Send the absolute path, not the relative suffix
149			let update = OriginUpdate {
150				suffix: suffix.into(),
151				active: None,
152			};
153			return self.updates.send(update).is_ok();
154		}
155
156		!self.updates.is_closed()
157	}
158}
159
/// Publishes broadcasts and announces them to every subscribed [`OriginConsumer`].
///
/// Clones, and handles returned by [`OriginProducer::publish_prefix`] / [`OriginProducer::with_root`],
/// share the same underlying state.
#[derive(Clone, Default)]
pub struct OriginProducer {
	/// All broadcasts are relative to this path.
	root: Path,

	/// All published broadcasts start with this prefix, relative to root.
	///
	/// NOTE: consumers are relative to the root, not to this prefix.
	prefix: Path,

	/// Announced broadcasts and registered consumers, shared between handles.
	state: Lock<ProducerState>,
}
173
174impl OriginProducer {
175	pub fn new() -> Self {
176		Self::default()
177	}
178
179	/// Publish a broadcast, announcing it to all consumers.
180	///
181	/// The broadcast will be unannounced when it is closed.
182	/// If there is already a broadcast with the same path, then it will be replaced and reannounced.
183	/// If the old broadcast is closed before the new one, then nothing will happen.
184	/// If the new broadcast is closed before the old one, then the old broadcast will be reannounced.
185	pub fn publish<'a>(&mut self, path: impl Into<PathRef<'a>>, broadcast: BroadcastConsumer) {
186		let path = path.into();
187		let full = self.root.join(&self.prefix).join(&path);
188
189		if !self.state.lock().publish(full.clone(), broadcast.clone()) {
190			// The exact same BroadcastConsumer was published with the same path twice.
191			// This is not a huge deal, but we break early to avoid redundant cleanup work.
192			tracing::warn!(?path, "duplicate publish");
193			return;
194		}
195
196		let state = self.state.clone().downgrade();
197
198		// TODO cancel this task when the producer is dropped.
199		web_async::spawn(async move {
200			broadcast.closed().await;
201			if let Some(state) = state.upgrade() {
202				state.lock().remove(full, broadcast);
203			}
204		});
205	}
206
207	/// Returns a new OriginProducer where all published broadcasts are relative to the prefix.
208	pub fn publish_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> Self {
209		Self {
210			prefix: self.prefix.join(prefix),
211			state: self.state.clone(),
212			root: self.root.clone(),
213		}
214	}
215
216	/// Get a specific broadcast by path.
217	///
218	/// The most recent, non-closed broadcast will be returned if there are duplicates.
219	pub fn consume<'a>(&self, path: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
220		let path = path.into();
221
222		let full = self.root.join(path);
223		self.state.lock().active.get(&full).map(|b| b.active.clone())
224	}
225
226	/// Subscribe to all announced broadcasts.
227	pub fn consume_all(&self) -> OriginConsumer {
228		self.consume_prefix("")
229	}
230
231	/// Subscribe to all announced broadcasts matching the prefix.
232	///
233	/// NOTE: This takes a Suffix because it's appended to the existing prefix to get a new prefix.
234	/// Confusing I know, but it means that we don't have to return a Result.
235	pub fn consume_prefix(&self, prefix: impl Into<Path>) -> OriginConsumer {
236		let prefix = prefix.into();
237		let full = self.root.join(&prefix);
238
239		let mut state = self.state.lock();
240
241		let (tx, rx) = mpsc::unbounded_channel();
242		let mut consumer = ConsumerState {
243			prefix: full,
244			updates: tx,
245		};
246
247		for (path, broadcast) in &state.active {
248			consumer.insert(path, &broadcast.active);
249		}
250		state.consumers.push(consumer);
251
252		OriginConsumer {
253			root: self.root.clone(),
254			prefix,
255			updates: rx,
256			producer: self.state.clone().downgrade(),
257		}
258	}
259
260	pub fn with_root(&self, root: impl Into<Path>) -> Self {
261		let root = root.into();
262
263		// Make sure the new root matches any existing configured prefix.
264		// ex. if you only allow publishing /foo, it's not legal to change the root to /bar
265		let prefix = match self.prefix.strip_prefix(&root) {
266			Some(prefix) => prefix.to_owned(),
267			None if self.prefix.is_empty() => Path::default(),
268			None => panic!("with_root doesn't match existing prefix"),
269		};
270
271		Self {
272			root: self.root.join(&root),
273			prefix,
274			state: self.state.clone(),
275		}
276	}
277
278	/// Wait until all consumers have been dropped.
279	///
280	/// NOTE: subscribe can be called to unclose the producer.
281	pub async fn unused(&self) {
282		// Keep looping until all consumers are closed.
283		while let Some(notify) = self.unused_inner() {
284			notify.closed().await;
285		}
286	}
287
288	// Returns the closed notify of any consumer.
289	fn unused_inner(&self) -> Option<mpsc::UnboundedSender<OriginUpdate>> {
290		let mut state = self.state.lock();
291
292		while let Some(consumer) = state.consumers.last() {
293			if !consumer.updates.is_closed() {
294				return Some(consumer.updates.clone());
295			}
296
297			state.consumers.pop();
298		}
299
300		None
301	}
302
303	pub fn root(&self) -> &Path {
304		&self.root
305	}
306
307	pub fn prefix(&self) -> &Path {
308		&self.prefix
309	}
310}
311
/// Consumes announced broadcasts matching against an optional prefix.
pub struct OriginConsumer {
	// We need a weak reference to the producer so that we can clone it.
	// Being weak, it does not keep the producer state alive: once that state is gone, lookups
	// return None and new subscriptions are created already closed.
	producer: LockWeak<ProducerState>,
	// Receiving half of the channel fed by this consumer's ConsumerState in the producer state.
	updates: mpsc::UnboundedReceiver<OriginUpdate>,

	/// All broadcasts are relative to this root path.
	root: Path,

	/// Only fetch broadcasts matching this prefix (relative to `root`).
	prefix: Path,
}
324
325impl OriginConsumer {
326	/// Returns the next (un)announced broadcast and the absolute path.
327	///
328	/// The broadcast will only be None if it was previously Some.
329	/// The same path won't be announced/unannounced twice, instead it will toggle.
330	///
331	/// Note: The returned path is absolute and will always match this consumer's prefix.
332	pub async fn next(&mut self) -> Option<OriginUpdate> {
333		self.updates.recv().await
334	}
335
336	/// Get a specific broadcast by path.
337	///
338	/// This is relative to the consumer's prefix.
339	/// Returns None if the path hasn't been announced yet.
340	pub fn consume<'a>(&self, path: impl Into<PathRef<'a>>) -> Option<BroadcastConsumer> {
341		let full = self.root.join(&self.prefix).join(path.into());
342
343		let state = self.producer.upgrade()?;
344		let state = state.lock();
345		state.active.get(&full).map(|b| b.active.clone())
346	}
347
348	pub fn consume_all(&self) -> OriginConsumer {
349		self.consume_prefix("")
350	}
351
352	pub fn consume_prefix<'a>(&self, prefix: impl Into<PathRef<'a>>) -> OriginConsumer {
353		// The prefix is relative to the existing prefix.
354		let prefix = self.prefix.join(prefix);
355
356		// Combine the consumer's prefix with the existing consumer's prefix.
357		let full = self.root.join(&prefix);
358
359		let (tx, rx) = mpsc::unbounded_channel();
360
361		// NOTE: consumer is immediately dropped, signalling FIN, if the producer can't be upgraded.
362		let mut consumer = ConsumerState {
363			prefix: full,
364			updates: tx,
365		};
366
367		if let Some(state) = self.producer.upgrade() {
368			let mut state = state.lock();
369
370			for (path, broadcast) in &state.active {
371				consumer.insert(path, &broadcast.active);
372			}
373
374			state.consumers.push(consumer);
375		}
376
377		OriginConsumer {
378			root: self.root.clone(),
379			prefix,
380			updates: rx,
381			producer: self.producer.clone(),
382		}
383	}
384
385	pub fn root(&self) -> &Path {
386		&self.root
387	}
388
389	pub fn prefix(&self) -> &Path {
390		&self.prefix
391	}
392}
393
394impl Clone for OriginConsumer {
395	fn clone(&self) -> Self {
396		self.consume_all()
397	}
398}
399
400#[cfg(test)]
401use futures::FutureExt;
402
#[cfg(test)]
impl OriginConsumer {
	// Pop the next queued update without waiting; panics if nothing is ready or the channel closed.
	fn take_ready_update(&mut self) -> OriginUpdate {
		self.next().now_or_never().expect("next blocked").expect("no next")
	}

	/// Asserts that the next queued update announces `broadcast` at `path`.
	pub fn assert_next(&mut self, path: &str, broadcast: &BroadcastConsumer) {
		let OriginUpdate { suffix, active } = self.take_ready_update();
		assert_eq!(suffix.as_str(), path, "wrong path");
		let announced = active.unwrap();
		assert!(announced.is_clone(broadcast), "should be the same broadcast");
	}

	/// Asserts that the next queued update unannounces `path`.
	pub fn assert_next_none(&mut self, path: &str) {
		let OriginUpdate { suffix, active } = self.take_ready_update();
		assert_eq!(suffix.as_str(), path, "wrong path");
		assert!(active.is_none(), "should be unannounced");
	}

	/// Asserts that no update is ready yet and the channel is still open.
	pub fn assert_next_wait(&mut self) {
		let ready = self.next().now_or_never();
		assert!(ready.is_none(), "next should block");
	}

	/// Asserts that the channel is closed and drained.
	pub fn assert_next_closed(&mut self) {
		let polled = self.next().now_or_never().expect("next blocked");
		assert!(polled.is_none(), "next should be closed");
	}
}
428
#[cfg(test)]
mod tests {
	use crate::BroadcastProducer;

	use super::*;

	// Give the spawned cleanup tasks a chance to run before inspecting state.
	async fn settle() {
		tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
	}

	#[tokio::test]
	async fn test_announce() {
		let mut producer = OriginProducer::default();
		let first = BroadcastProducer::new();
		let second = BroadcastProducer::new();

		// A consumer created up front sees every later announcement.
		let mut early = producer.consume_all();
		early.assert_next_wait();

		producer.publish("test1", first.consume());

		early.assert_next("test1", &first.consume());
		early.assert_next_wait();

		// A consumer created now is seeded with the existing broadcast; leave it unread for now.
		let mut late = producer.consume_all();

		producer.publish("test2", second.consume());

		early.assert_next("test2", &second.consume());
		early.assert_next_wait();

		late.assert_next("test1", &first.consume());
		late.assert_next("test2", &second.consume());
		late.assert_next_wait();

		// Closing the first broadcast unannounces it for every consumer.
		drop(first);
		settle().await;

		early.assert_next_none("test1");
		late.assert_next_none("test1");
		early.assert_next_wait();
		late.assert_next_wait();

		// A brand-new consumer only sees what is still active.
		let mut fresh = producer.consume_all();
		fresh.assert_next("test2", &second.consume());
		fresh.assert_next_wait();

		// Dropping the producer unannounces everything and then closes every consumer.
		drop(producer);
		settle().await;

		early.assert_next_none("test2");
		late.assert_next_none("test2");
		fresh.assert_next_none("test2");

		early.assert_next_closed();
		late.assert_next_closed();
		fresh.assert_next_closed();
	}

	// Publish two broadcasts at one path, then close them in the requested order.
	// The path must stay resolvable until the second of the two has closed.
	async fn check_duplicate(close_newest_first: bool) {
		let mut producer = OriginProducer::default();
		let oldest = BroadcastProducer::new();
		let newest = BroadcastProducer::new();

		producer.publish("test", oldest.consume());
		producer.publish("test", newest.consume());
		assert!(producer.consume("test").is_some());

		let (closed_first, closed_last) = if close_newest_first {
			(newest, oldest)
		} else {
			(oldest, newest)
		};

		drop(closed_first);
		settle().await;
		assert!(producer.consume("test").is_some());

		drop(closed_last);
		settle().await;
		assert!(producer.consume("test").is_none());
	}

	#[tokio::test]
	async fn test_duplicate() {
		check_duplicate(false).await;
	}

	#[tokio::test]
	async fn test_duplicate_reverse() {
		// The harder case: the newest (active) broadcast closes first.
		check_duplicate(true).await;
	}

	#[tokio::test]
	async fn test_double_publish() {
		let mut producer = OriginProducer::default();
		let broadcast = BroadcastProducer::new();

		// Publishing the same broadcast twice must not crash.
		producer.publish("test", broadcast.consume());
		producer.publish("test", broadcast.consume());

		assert!(producer.consume("test").is_some());

		drop(broadcast);
		settle().await;
		assert!(producer.consume("test").is_none());
	}
}