// fire_stream/util/watch.rs

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};

use tokio::sync::Notify;

7pub fn channel<T>(data: T) -> (Sender<T>, Receiver<T>) {
8	let shared = Arc::new(Shared {
9		data: RwLock::new(data),
10		version: AtomicUsize::new(1),
11		tx_count: AtomicUsize::new(1),
12		notify: Notify::new()
13	});
14
15	(
16		Sender {
17			inner: shared.clone()
18		},
19		Receiver {
20			inner: shared,
21			version: 0
22		}
23	)
24}
25
26#[derive(Debug)]
27pub struct Sender<T> {
28	inner: Arc<Shared<T>>
29}
30
31impl<T> Clone for Sender<T> {
32	fn clone(&self) -> Self {
33		let inner = self.inner.clone();
34		// relaxed since this is only a counter
35		inner.tx_count.fetch_add(1, Ordering::Relaxed);
36
37		Self { inner }
38	}
39}
40
41impl<T> Drop for Sender<T> {
42	fn drop(&mut self) {
43		// relaxed since this is only a counter
44		let prev_count = self.inner.tx_count.fetch_sub(1, Ordering::Relaxed);
45		if prev_count == 1 {
46			// we are the last sender
47			// notify receivers
48			self.inner.notify.notify_waiters();
49		}
50	}
51}
52
53impl<T> Sender<T> {
54
55	/// It is possible that there are no receivers left.
56	/// 
57	/// This is not checked
58	pub fn send(&self, data: T) {
59		{
60			let mut lock = self.inner.data.write().unwrap();
61			*lock = data;
62			self.inner.version.fetch_add(1, Ordering::SeqCst);
63		}
64		self.inner.notify.notify_waiters();
65	}
66
67	pub fn newest(&self) -> T
68	where T: Clone {
69		self.inner.data.read().unwrap().clone()
70	}
71
72}
73
74#[derive(Debug)]
75pub struct Receiver<T> {
76	inner: Arc<Shared<T>>,
77	version: usize
78}
79
80impl<T> Clone for Receiver<T> {
81	fn clone(&self) -> Self {
82		Self {
83			inner: self.inner.clone(),
84			version: self.version
85		}
86	}
87}
88
89impl<T> Receiver<T> {
90	/// Returns None if there isn't any sender left.
91	pub async fn recv(&mut self) -> Option<T>
92	where T: Clone {
93		loop {
94
95			// let get the notification before we check if there exists a new
96			// version to not miss any notification that could be sent
97			// between our check.
98			let noti = self.inner.notify.notified();
99
100			let n_version = self.inner.version.load(Ordering::SeqCst);
101			if self.version != n_version {
102				self.version = n_version;
103				return Some(self.inner.data.read().unwrap().clone());
104			}
105
106			// todo: does this need to be SeqCst?
107			let tx_count = self.inner.tx_count.load(Ordering::SeqCst);
108			if tx_count == 0 {
109				return None
110			}
111
112			noti.await;
113
114		}
115	}
116
117	#[allow(dead_code)]
118	pub fn newest(&self) -> T
119	where T: Clone {
120		self.inner.data.read().unwrap().clone()
121	}
122}
123
124/// does not track if there are any receivers left
125#[derive(Debug)]
126pub struct Shared<T> {
127	data: RwLock<T>,
128	version: AtomicUsize,
129	tx_count: AtomicUsize,
130	notify: Notify
131}
132
133
#[cfg(test)]
mod tests {

	use super::*;
	use tokio::time::{sleep, Duration};

	#[tokio::test]
	async fn test_wakeup() {
		let (tx, mut rx) = channel(true);

		// Consume the initial value before spawning, so the outcome no longer
		// depends on whether the task got to run before `send` below.
		assert_eq!(rx.recv().await, Some(true));

		let task = tokio::spawn(async move {
			assert_eq!(rx.recv().await, Some(false));
			assert!(rx.recv().await.is_none())
		});

		// Give the task time to start waiting so the wake-up path is exercised;
		// the assertions hold even if it has not started yet.
		sleep(Duration::from_millis(100)).await;

		tx.send(false);

		drop(tx);

		task.await.unwrap();
	}

	#[tokio::test]
	async fn test_closed_without_send() {
		let (tx, mut rx) = channel(1u32);
		drop(tx);

		// the initial value is still delivered before the channel reports closed
		assert_eq!(rx.recv().await, Some(1));
		assert_eq!(rx.recv().await, None);
	}

	#[tokio::test]
	async fn test_only_newest_value_is_delivered() {
		let (tx, mut rx) = channel(0u32);
		tx.send(1);
		tx.send(2);
		drop(tx);

		assert_eq!(rx.recv().await, Some(2));
		assert_eq!(rx.recv().await, None);
	}

	#[tokio::test]
	async fn test_cloned_sender_keeps_channel_open() {
		let (tx, mut rx) = channel(0u32);
		let tx2 = tx.clone();
		drop(tx);

		assert_eq!(rx.recv().await, Some(0));
		tx2.send(5);
		assert_eq!(rx.recv().await, Some(5));
		assert_eq!(tx2.newest(), 5);

		drop(tx2);
		assert_eq!(rx.recv().await, None);
	}

}