1use async_trait::async_trait;
6use std::{marker::PhantomData, sync::Arc};
7use tokio::sync::{
8 broadcast, mpsc,
9 watch::{self, error::RecvError},
10};
11
12#[derive(Clone)]
13pub struct Barrier<T>(watch::Receiver<Option<T>>)
14where
15 T: Clone;
16
17impl<T> Barrier<T>
18where
19 T: Clone,
20{
21 pub async fn wait(&mut self) -> Result<T, RecvError> {
23 loop {
24 self.0.changed().await?;
25
26 if let Some(v) = self.0.borrow().clone() {
27 return Ok(v);
28 }
29 }
30 }
31
32 pub fn is_open(&self) -> bool {
34 self.0.borrow().is_some()
35 }
36}
37
38#[async_trait]
39impl<T: Clone + Send + Sync> Receivable<T> for Barrier<T> {
40 async fn recv_msg(&mut self) -> Option<T> {
41 self.wait().await.ok()
42 }
43}
44
45#[derive(Clone)]
46pub struct BarrierOpener<T: Clone>(Arc<watch::Sender<Option<T>>>);
47
48impl<T: Clone> BarrierOpener<T> {
49 pub fn open(&self, value: T) {
51 self.0.send_if_modified(|v| {
52 if v.is_none() {
53 *v = Some(value);
54 true
55 } else {
56 false
57 }
58 });
59 }
60}
61
62pub fn new_barrier<T>() -> (Barrier<T>, BarrierOpener<T>)
65where
66 T: Clone,
67{
68 let (closed_tx, closed_rx) = watch::channel(None);
69 (Barrier(closed_rx), BarrierOpener(Arc::new(closed_tx)))
70}
71
/// Abstraction over message sources whose next item can be awaited.
#[async_trait]
pub trait Receivable<T> {
    /// Waits for the next message.
    ///
    /// The combinators in this file treat a `None` result as "this source is
    /// finished" and stop polling it afterwards.
    async fn recv_msg(&mut self) -> Option<T>;
}
77
78#[async_trait]
81impl<T: Clone + Send> Receivable<T> for broadcast::Receiver<T> {
82 async fn recv_msg(&mut self) -> Option<T> {
83 loop {
84 match self.recv().await {
85 Ok(v) => return Some(v),
86 Err(broadcast::error::RecvError::Lagged(_)) => continue,
87 Err(broadcast::error::RecvError::Closed) => return None,
88 }
89 }
90 }
91}
92
93#[async_trait]
94impl<T: Send> Receivable<T> for mpsc::UnboundedReceiver<T> {
95 async fn recv_msg(&mut self) -> Option<T> {
96 self.recv().await
97 }
98}
99
100#[async_trait]
101impl<T: Send> Receivable<T> for () {
102 async fn recv_msg(&mut self) -> Option<T> {
103 futures::future::pending().await
104 }
105}
106
107pub struct ConcatReceivable<T: Send, A: Receivable<T>, B: Receivable<T>> {
108 left: Option<A>,
109 right: B,
110 _marker: PhantomData<T>,
111}
112
113impl<T: Send, A: Receivable<T>, B: Receivable<T>> ConcatReceivable<T, A, B> {
114 pub fn new(left: A, right: B) -> Self {
115 Self {
116 left: Some(left),
117 right,
118 _marker: PhantomData,
119 }
120 }
121}
122
123#[async_trait]
124impl<T: Send, A: Send + Receivable<T>, B: Send + Receivable<T>> Receivable<T>
125 for ConcatReceivable<T, A, B>
126{
127 async fn recv_msg(&mut self) -> Option<T> {
128 if let Some(left) = &mut self.left {
129 match left.recv_msg().await {
130 Some(v) => return Some(v),
131 None => {
132 self.left = None;
133 }
134 }
135 }
136
137 return self.right.recv_msg().await;
138 }
139}
140
141pub struct MergedReceivable<T: Send, A: Receivable<T>, B: Receivable<T>> {
142 left: Option<A>,
143 right: Option<B>,
144 _marker: PhantomData<T>,
145}
146
147impl<T: Send, A: Receivable<T>, B: Receivable<T>> MergedReceivable<T, A, B> {
148 pub fn new(left: A, right: B) -> Self {
149 Self {
150 left: Some(left),
151 right: Some(right),
152 _marker: PhantomData,
153 }
154 }
155}
156
157#[async_trait]
158impl<T: Send, A: Send + Receivable<T>, B: Send + Receivable<T>> Receivable<T>
159 for MergedReceivable<T, A, B>
160{
161 async fn recv_msg(&mut self) -> Option<T> {
162 loop {
163 match (&mut self.left, &mut self.right) {
164 (Some(left), Some(right)) => {
165 tokio::select! {
166 left = left.recv_msg() => match left {
167 Some(v) => return Some(v),
168 None => { self.left = None; continue; },
169 },
170 right = right.recv_msg() => match right {
171 Some(v) => return Some(v),
172 None => { self.right = None; continue; },
173 },
174 }
175 }
176 (Some(a), None) => break a.recv_msg().await,
177 (None, Some(b)) => break b.recv_msg().await,
178 (None, None) => break None,
179 }
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// The waiter is spawned first; the barrier is opened afterwards.
    #[tokio::test]
    async fn test_barrier_close_after_spawn() {
        let (mut barrier, opener) = new_barrier::<u32>();
        let (result_tx, result_rx) = tokio::sync::oneshot::channel::<u32>();

        tokio::spawn(async move {
            let value = barrier.wait().await.unwrap();
            result_tx.send(value).unwrap();
        });

        opener.open(42);

        assert_eq!(result_rx.await.unwrap(), 42);
    }

    /// The barrier is opened first; two cloned waiters are spawned afterwards.
    #[tokio::test]
    async fn test_barrier_close_before_spawn() {
        let (barrier, opener) = new_barrier::<u32>();
        let (first_tx, first_rx) = tokio::sync::oneshot::channel::<u32>();
        let (second_tx, second_rx) = tokio::sync::oneshot::channel::<u32>();

        opener.open(42);

        let mut first_waiter = barrier.clone();
        tokio::spawn(async move {
            let value = first_waiter.wait().await.unwrap();
            first_tx.send(value).unwrap();
        });

        let mut second_waiter = barrier.clone();
        tokio::spawn(async move {
            let value = second_waiter.wait().await.unwrap();
            second_tx.send(value).unwrap();
        });

        assert_eq!(first_rx.await.unwrap(), 42);
        assert_eq!(second_rx.await.unwrap(), 42);
    }
}