ipc_queue/
interface_async.rs1use std::sync::atomic::Ordering;
8use crate::AsyncReceiver;
9use crate::AsyncSender;
10use crate::AsyncSynchronizer;
11#[cfg(not(target_env = "sgx"))]
12use crate::DescriptorGuard;
13use crate::Identified;
14use crate::QueueEvent;
15use crate::RecvError;
16use crate::SendError;
17use crate::SynchronizationError;
18use crate::Transmittable;
19use crate::TryRecvError;
20use crate::TrySendError;
21use crate::position::PositionMonitor;
22
23unsafe impl<T: Send, S: Send> Send for AsyncSender<T, S> {}
24unsafe impl<T: Send, S: Sync> Sync for AsyncSender<T, S> {}
25
26impl<T, S: Clone> Clone for AsyncSender<T, S> {
27 fn clone(&self) -> Self {
28 Self {
29 inner: self.inner.clone(),
30 synchronizer: self.synchronizer.clone(),
31 }
32 }
33}
34
35impl<T: Transmittable, S: AsyncSynchronizer> AsyncSender<T, S> {
36 pub async fn send(&self, val: Identified<T>) -> Result<(), SendError> {
37 loop {
38 match self.inner.try_send_impl(val) {
39 Ok(wake_receiver) => {
40 if wake_receiver {
41 self.synchronizer.notify(QueueEvent::NotEmpty);
42 }
43 return Ok(());
44 }
45 Err(TrySendError::QueueFull) => {
46 self.synchronizer
47 .wait(QueueEvent::NotFull).await
48 .map_err(|SynchronizationError::ChannelClosed| SendError::Closed)?;
49 }
50 Err(TrySendError::Closed) => return Err(SendError::Closed),
51 };
52 }
53 }
54
55 #[cfg(not(target_env = "sgx"))]
59 pub fn into_descriptor_guard(self) -> DescriptorGuard<T> {
60 self.inner.into_descriptor_guard()
61 }
62}
63
64unsafe impl<T: Send, S: Send> Send for AsyncReceiver<T, S> {}
65
66impl<T: Transmittable, S: AsyncSynchronizer> AsyncReceiver<T, S> {
67 pub async fn recv(&self) -> Result<Identified<T>, RecvError> {
68 loop {
69 match self.inner.try_recv_impl() {
70 Ok((val, wake_sender, read_wrapped_around)) => {
71 if wake_sender {
72 self.synchronizer.notify(QueueEvent::NotFull);
73 }
74 if read_wrapped_around {
75 self.read_epoch.fetch_add(1, Ordering::Relaxed);
76 }
77 return Ok(val);
78 }
79 Err(TryRecvError::QueueEmpty) => {
80 self.synchronizer
81 .wait(QueueEvent::NotEmpty).await
82 .map_err(|SynchronizationError::ChannelClosed| RecvError::Closed)?;
83 }
84 Err(TryRecvError::Closed) => return Err(RecvError::Closed),
85 }
86 }
87 }
88
89 pub fn position_monitor(&self) -> PositionMonitor<T> {
90 PositionMonitor::new(self.read_epoch.clone(), self.inner.clone())
91 }
92
93 #[cfg(not(target_env = "sgx"))]
97 pub fn into_descriptor_guard(self) -> DescriptorGuard<T> {
98 self.inner.into_descriptor_guard()
99 }
100}
101
102#[cfg(not(target_env = "sgx"))]
103#[cfg(test)]
104mod tests {
105 use futures::future::FutureExt;
106 use futures::lock::Mutex;
107 use tokio::sync::broadcast;
108 use tokio::sync::broadcast::error::{SendError, RecvError};
109
110 use crate::*;
111 use crate::test_support::TestValue;
112
113 async fn do_single_sender(len: usize, n: u64) {
114 let s = TestAsyncSynchronizer::new();
115 let (tx, rx) = bounded_async(len, s);
116 let local = tokio::task::LocalSet::new();
117
118 let h1 = local.spawn_local(async move {
119 for i in 0..n {
120 tx.send(Identified { id: i + 1, data: TestValue(i) }).await.unwrap();
121 }
122 });
123
124 let h2 = local.spawn_local(async move {
125 for i in 0..n {
126 let v = rx.recv().await.unwrap();
127 assert_eq!(v.id, i + 1);
128 assert_eq!(v.data.0, i);
129 }
130 });
131
132 local.await;
133 h1.await.unwrap();
134 h2.await.unwrap();
135 }
136
137 #[tokio::test]
138 async fn single_sender() {
139 do_single_sender(4, 10).await;
140 do_single_sender(1, 10).await;
141 do_single_sender(32, 1024).await;
142 do_single_sender(1024, 32).await;
143 }
144
145 async fn do_multi_sender(len: usize, n: u64, senders: u64) {
146 let s = TestAsyncSynchronizer::new();
147 let (tx, rx) = bounded_async(len, s);
148 let mut handles = Vec::with_capacity(senders as _);
149 let local = tokio::task::LocalSet::new();
150
151 for t in 0..senders {
152 let tx = tx.clone();
153 handles.push(local.spawn_local(async move {
154 for i in 0..n {
155 let id = t * n + i + 1;
156 tx.send(Identified { id, data: TestValue(i) }).await.unwrap();
157 }
158 }));
159 }
160
161 handles.push(local.spawn_local(async move {
162 for _ in 0..(n * senders) {
163 rx.recv().await.unwrap();
164 }
165 }));
166
167 local.await;
168 for h in handles {
169 h.await.unwrap();
170 }
171 }
172
173 #[tokio::test]
174 async fn multi_sender() {
175 do_multi_sender(4, 10, 3).await;
176 do_multi_sender(4, 1, 100).await;
177 do_multi_sender(2, 10, 100).await;
178 do_multi_sender(1024, 30, 100).await;
179 }
180
181 #[tokio::test]
182 async fn positions() {
183 const LEN: usize = 16;
184 let s = TestAsyncSynchronizer::new();
185 let (tx, rx) = bounded_async(LEN, s);
186 let monitor = rx.position_monitor();
187 let mut id = 1;
188
189 let p0 = monitor.write_position();
190 tx.send(Identified { id, data: TestValue(1) }).await.unwrap();
191 let p1 = monitor.write_position();
192 tx.send(Identified { id: id + 1, data: TestValue(2) }).await.unwrap();
193 let p2 = monitor.write_position();
194 tx.send(Identified { id: id + 2, data: TestValue(3) }).await.unwrap();
195 let p3 = monitor.write_position();
196 id += 3;
197 assert!(monitor.read_position().is_past(&p0) == Some(false));
198 assert!(monitor.read_position().is_past(&p1) == Some(false));
199 assert!(monitor.read_position().is_past(&p2) == Some(false));
200 assert!(monitor.read_position().is_past(&p3) == Some(false));
201
202 rx.recv().await.unwrap();
203 assert!(monitor.read_position().is_past(&p0) == Some(true));
204 assert!(monitor.read_position().is_past(&p1) == Some(false));
205 assert!(monitor.read_position().is_past(&p2) == Some(false));
206 assert!(monitor.read_position().is_past(&p3) == Some(false));
207
208 rx.recv().await.unwrap();
209 assert!(monitor.read_position().is_past(&p0) == Some(true));
210 assert!(monitor.read_position().is_past(&p1) == Some(true));
211 assert!(monitor.read_position().is_past(&p2) == Some(false));
212 assert!(monitor.read_position().is_past(&p3) == Some(false));
213
214 rx.recv().await.unwrap();
215 assert!(monitor.read_position().is_past(&p0) == Some(true));
216 assert!(monitor.read_position().is_past(&p1) == Some(true));
217 assert!(monitor.read_position().is_past(&p2) == Some(true));
218 assert!(monitor.read_position().is_past(&p3) == Some(false));
219
220 for i in 0..1000 {
221 let n = 1 + (i % LEN);
222 let p4 = monitor.write_position();
223 for _ in 0..n {
224 tx.send(Identified { id, data: TestValue(id) }).await.unwrap();
225 id += 1;
226 }
227 let p5 = monitor.write_position();
228 for _ in 0..n {
229 rx.recv().await.unwrap();
230 assert!(monitor.read_position().is_past(&p0) == Some(true));
231 assert!(monitor.read_position().is_past(&p1) == Some(true));
232 assert!(monitor.read_position().is_past(&p2) == Some(true));
233 assert!(monitor.read_position().is_past(&p3) == Some(true));
234 assert!(monitor.read_position().is_past(&p4) == Some(true));
235 assert!(monitor.read_position().is_past(&p5) == Some(false));
236 }
237 }
238 }
239
240 struct Subscription<T> {
241 tx: broadcast::Sender<T>,
242 rx: Mutex<broadcast::Receiver<T>>,
243 }
244
245 impl<T: Clone> Subscription<T> {
246 fn new(capacity: usize) -> Self {
247 let (tx, rx) = broadcast::channel(capacity);
248 Self {
249 tx,
250 rx: Mutex::new(rx),
251 }
252 }
253
254 fn send(&self, val: T) -> Result<(), SendError<T>> {
255 self.tx.send(val).map(|_| ())
256 }
257
258 async fn recv(&self) -> Result<T, RecvError> {
259 let mut rx = self.rx.lock().await;
260 rx.recv().await
261 }
262 }
263
264 impl<T> Clone for Subscription<T> {
265 fn clone(&self) -> Self {
266 Self {
267 tx: self.tx.clone(),
268 rx: Mutex::new(self.tx.subscribe()),
269 }
270 }
271 }
272
273 #[derive(Clone)]
274 struct TestAsyncSynchronizer {
275 not_empty: Subscription<()>,
276 not_full: Subscription<()>,
277 }
278
279 impl TestAsyncSynchronizer {
280 fn new() -> Self {
281 Self {
282 not_empty: Subscription::new(128),
283 not_full: Subscription::new(128),
284 }
285 }
286 }
287
288 impl AsyncSynchronizer for TestAsyncSynchronizer {
289 fn wait(&self, event: QueueEvent) -> Pin<Box<dyn Future<Output=Result<(), SynchronizationError>> + '_>> {
290 async move {
291 match event {
292 QueueEvent::NotEmpty => self.not_empty.recv().await,
293 QueueEvent::NotFull => self.not_full.recv().await,
294 }.map_err(|_| SynchronizationError::ChannelClosed)
295 }.boxed()
296 }
297
298 fn notify(&self, event: QueueEvent) {
299 let _ = match event {
300 QueueEvent::NotEmpty => self.not_empty.send(()),
301 QueueEvent::NotFull => self.not_full.send(()),
302 };
303 }
304 }
305}