cli/util/
sync.rs

1/*---------------------------------------------------------------------------------------------
2 *  Copyright (c) Microsoft Corporation. All rights reserved.
3 *  Licensed under the MIT License. See License.txt in the project root for license information.
4 *--------------------------------------------------------------------------------------------*/
5use async_trait::async_trait;
6use std::{marker::PhantomData, sync::Arc};
7use tokio::sync::{
8	broadcast, mpsc,
9	watch::{self, error::RecvError},
10};
11
12#[derive(Clone)]
13pub struct Barrier<T>(watch::Receiver<Option<T>>)
14where
15	T: Clone;
16
17impl<T> Barrier<T>
18where
19	T: Clone,
20{
21	/// Waits for the barrier to be closed, returning a value if one was sent.
22	pub async fn wait(&mut self) -> Result<T, RecvError> {
23		loop {
24			self.0.changed().await?;
25
26			if let Some(v) = self.0.borrow().clone() {
27				return Ok(v);
28			}
29		}
30	}
31
32	/// Gets whether the barrier is currently open
33	pub fn is_open(&self) -> bool {
34		self.0.borrow().is_some()
35	}
36}
37
38#[async_trait]
39impl<T: Clone + Send + Sync> Receivable<T> for Barrier<T> {
40	async fn recv_msg(&mut self) -> Option<T> {
41		self.wait().await.ok()
42	}
43}
44
45#[derive(Clone)]
46pub struct BarrierOpener<T: Clone>(Arc<watch::Sender<Option<T>>>);
47
48impl<T: Clone> BarrierOpener<T> {
49	/// Opens the barrier.
50	pub fn open(&self, value: T) {
51		self.0.send_if_modified(|v| {
52			if v.is_none() {
53				*v = Some(value);
54				true
55			} else {
56				false
57			}
58		});
59	}
60}
61
62/// The Barrier is something that can be opened once from one side,
63/// and is thereafter permanently closed. It can contain a value.
64pub fn new_barrier<T>() -> (Barrier<T>, BarrierOpener<T>)
65where
66	T: Clone,
67{
68	let (closed_tx, closed_rx) = watch::channel(None);
69	(Barrier(closed_rx), BarrierOpener(Arc::new(closed_tx)))
70}
71
72/// Type that can receive messages in an async way.
73#[async_trait]
74pub trait Receivable<T> {
75	async fn recv_msg(&mut self) -> Option<T>;
76}
77
78// todo: ideally we would use an Arc in the broadcast::Receiver to avoid having
79// to clone bytes everywhere, requires updating rpc consumers as well.
80#[async_trait]
81impl<T: Clone + Send> Receivable<T> for broadcast::Receiver<T> {
82	async fn recv_msg(&mut self) -> Option<T> {
83		loop {
84			match self.recv().await {
85				Ok(v) => return Some(v),
86				Err(broadcast::error::RecvError::Lagged(_)) => continue,
87				Err(broadcast::error::RecvError::Closed) => return None,
88			}
89		}
90	}
91}
92
93#[async_trait]
94impl<T: Send> Receivable<T> for mpsc::UnboundedReceiver<T> {
95	async fn recv_msg(&mut self) -> Option<T> {
96		self.recv().await
97	}
98}
99
100#[async_trait]
101impl<T: Send> Receivable<T> for () {
102	async fn recv_msg(&mut self) -> Option<T> {
103		futures::future::pending().await
104	}
105}
106
107pub struct ConcatReceivable<T: Send, A: Receivable<T>, B: Receivable<T>> {
108	left: Option<A>,
109	right: B,
110	_marker: PhantomData<T>,
111}
112
113impl<T: Send, A: Receivable<T>, B: Receivable<T>> ConcatReceivable<T, A, B> {
114	pub fn new(left: A, right: B) -> Self {
115		Self {
116			left: Some(left),
117			right,
118			_marker: PhantomData,
119		}
120	}
121}
122
123#[async_trait]
124impl<T: Send, A: Send + Receivable<T>, B: Send + Receivable<T>> Receivable<T>
125	for ConcatReceivable<T, A, B>
126{
127	async fn recv_msg(&mut self) -> Option<T> {
128		if let Some(left) = &mut self.left {
129			match left.recv_msg().await {
130				Some(v) => return Some(v),
131				None => {
132					self.left = None;
133				}
134			}
135		}
136
137		return self.right.recv_msg().await;
138	}
139}
140
141pub struct MergedReceivable<T: Send, A: Receivable<T>, B: Receivable<T>> {
142	left: Option<A>,
143	right: Option<B>,
144	_marker: PhantomData<T>,
145}
146
147impl<T: Send, A: Receivable<T>, B: Receivable<T>> MergedReceivable<T, A, B> {
148	pub fn new(left: A, right: B) -> Self {
149		Self {
150			left: Some(left),
151			right: Some(right),
152			_marker: PhantomData,
153		}
154	}
155}
156
157#[async_trait]
158impl<T: Send, A: Send + Receivable<T>, B: Send + Receivable<T>> Receivable<T>
159	for MergedReceivable<T, A, B>
160{
161	async fn recv_msg(&mut self) -> Option<T> {
162		loop {
163			match (&mut self.left, &mut self.right) {
164				(Some(left), Some(right)) => {
165					tokio::select! {
166						left = left.recv_msg() => match left {
167							Some(v) => return Some(v),
168							None => { self.left = None; continue; },
169						},
170						right = right.recv_msg() => match right {
171							Some(v) => return Some(v),
172							None => { self.right = None; continue; },
173						},
174					}
175				}
176				(Some(a), None) => break a.recv_msg().await,
177				(None, Some(b)) => break b.recv_msg().await,
178				(None, None) => break None,
179			}
180		}
181	}
182}
183
#[cfg(test)]
mod tests {
	use super::*;

	#[tokio::test]
	async fn test_barrier_close_after_spawn() {
		let (mut barrier, opener) = new_barrier::<u32>();
		let (result_tx, result_rx) = tokio::sync::oneshot::channel::<u32>();

		// The waiter may or may not have been polled before `open` runs;
		// either ordering must deliver the value.
		tokio::spawn(async move {
			let value = barrier.wait().await.unwrap();
			result_tx.send(value).unwrap();
		});

		opener.open(42);

		assert_eq!(result_rx.await.unwrap(), 42);
	}

	#[tokio::test]
	async fn test_barrier_close_before_spawn() {
		let (barrier, opener) = new_barrier::<u32>();
		opener.open(42);

		// Clones taken after opening must still observe the value.
		let mut results = Vec::new();
		for _ in 0..2 {
			let mut waiter = barrier.clone();
			let (tx, rx) = tokio::sync::oneshot::channel::<u32>();
			tokio::spawn(async move {
				tx.send(waiter.wait().await.unwrap()).unwrap();
			});
			results.push(rx);
		}

		for rx in results {
			assert_eq!(rx.await.unwrap(), 42);
		}
	}
}