ipc_queue/
interface_async.rs

1/* Copyright (c) Fortanix, Inc.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7use std::sync::atomic::Ordering;
8use crate::AsyncReceiver;
9use crate::AsyncSender;
10use crate::AsyncSynchronizer;
11#[cfg(not(target_env = "sgx"))]
12use crate::DescriptorGuard;
13use crate::Identified;
14use crate::QueueEvent;
15use crate::RecvError;
16use crate::SendError;
17use crate::SynchronizationError;
18use crate::Transmittable;
19use crate::TryRecvError;
20use crate::TrySendError;
21use crate::position::PositionMonitor;
22
23unsafe impl<T: Send, S: Send> Send for AsyncSender<T, S> {}
24unsafe impl<T: Send, S: Sync> Sync for AsyncSender<T, S> {}
25
26impl<T, S: Clone> Clone for AsyncSender<T, S> {
27    fn clone(&self) -> Self {
28        Self {
29            inner: self.inner.clone(),
30            synchronizer: self.synchronizer.clone(),
31        }
32    }
33}
34
35impl<T: Transmittable, S: AsyncSynchronizer> AsyncSender<T, S> {
36    pub async fn send(&self, val: Identified<T>) -> Result<(), SendError> {
37        loop {
38            match self.inner.try_send_impl(val) {
39                Ok(wake_receiver) => {
40                    if wake_receiver {
41                        self.synchronizer.notify(QueueEvent::NotEmpty);
42                    }
43                    return Ok(());
44                }
45                Err(TrySendError::QueueFull) => {
46                    self.synchronizer
47                        .wait(QueueEvent::NotFull).await
48                        .map_err(|SynchronizationError::ChannelClosed| SendError::Closed)?;
49                }
50                Err(TrySendError::Closed) => return Err(SendError::Closed),
51            };
52        }
53    }
54
55    /// Consumes `self` and returns a DescriptorGuard.
56    /// The returned guard can be used to make `FifoDescriptor`s that remain
57    /// valid as long as the guard is not dropped.
58    #[cfg(not(target_env = "sgx"))]
59    pub fn into_descriptor_guard(self) -> DescriptorGuard<T> {
60        self.inner.into_descriptor_guard()
61    }
62}
63
64unsafe impl<T: Send, S: Send> Send for AsyncReceiver<T, S> {}
65
66impl<T: Transmittable, S: AsyncSynchronizer> AsyncReceiver<T, S> {
67    pub async fn recv(&self) -> Result<Identified<T>, RecvError> {
68        loop {
69            match self.inner.try_recv_impl() {
70                Ok((val, wake_sender, read_wrapped_around)) => {
71                    if wake_sender {
72                        self.synchronizer.notify(QueueEvent::NotFull);
73                    }
74                    if read_wrapped_around {
75                        self.read_epoch.fetch_add(1, Ordering::Relaxed);
76                    }
77                    return Ok(val);
78                }
79                Err(TryRecvError::QueueEmpty) => {
80                    self.synchronizer
81                        .wait(QueueEvent::NotEmpty).await
82                        .map_err(|SynchronizationError::ChannelClosed| RecvError::Closed)?;
83                }
84                Err(TryRecvError::Closed) => return Err(RecvError::Closed),
85            }
86        }
87    }
88
89    pub fn position_monitor(&self) -> PositionMonitor<T> {
90        PositionMonitor::new(self.read_epoch.clone(), self.inner.clone())
91    }
92
93    /// Consumes `self` and returns a DescriptorGuard.
94    /// The returned guard can be used to make `FifoDescriptor`s that remain
95    /// valid as long as the guard is not dropped.
96    #[cfg(not(target_env = "sgx"))]
97    pub fn into_descriptor_guard(self) -> DescriptorGuard<T> {
98        self.inner.into_descriptor_guard()
99    }
100}
101
102#[cfg(not(target_env = "sgx"))]
103#[cfg(test)]
104mod tests {
105    use futures::future::FutureExt;
106    use futures::lock::Mutex;
107    use tokio::sync::broadcast;
108    use tokio::sync::broadcast::error::{SendError, RecvError};
109
110    use crate::*;
111    use crate::test_support::TestValue;
112
113    async fn do_single_sender(len: usize, n: u64) {
114        let s = TestAsyncSynchronizer::new();
115        let (tx, rx) = bounded_async(len, s);
116        let local = tokio::task::LocalSet::new();
117
118        let h1 = local.spawn_local(async move {
119            for i in 0..n {
120                tx.send(Identified { id: i + 1, data: TestValue(i) }).await.unwrap();
121            }
122        });
123
124        let h2 = local.spawn_local(async move {
125            for i in 0..n {
126                let v = rx.recv().await.unwrap();
127                assert_eq!(v.id, i + 1);
128                assert_eq!(v.data.0, i);
129            }
130        });
131
132        local.await;
133        h1.await.unwrap();
134        h2.await.unwrap();
135    }
136
137    #[tokio::test]
138    async fn single_sender() {
139        do_single_sender(4, 10).await;
140        do_single_sender(1, 10).await;
141        do_single_sender(32, 1024).await;
142        do_single_sender(1024, 32).await;
143    }
144
145    async fn do_multi_sender(len: usize, n: u64, senders: u64) {
146        let s = TestAsyncSynchronizer::new();
147        let (tx, rx) = bounded_async(len, s);
148        let mut handles = Vec::with_capacity(senders as _);
149        let local = tokio::task::LocalSet::new();
150
151        for t in 0..senders {
152            let tx = tx.clone();
153            handles.push(local.spawn_local(async move {
154                for i in 0..n {
155                    let id = t * n + i + 1;
156                    tx.send(Identified { id, data: TestValue(i) }).await.unwrap();
157                }
158            }));
159        }
160
161        handles.push(local.spawn_local(async move {
162            for _ in 0..(n * senders) {
163                rx.recv().await.unwrap();
164            }
165        }));
166
167        local.await;
168        for h in handles {
169            h.await.unwrap();
170        }
171    }
172
173    #[tokio::test]
174    async fn multi_sender() {
175        do_multi_sender(4, 10, 3).await;
176        do_multi_sender(4, 1, 100).await;
177        do_multi_sender(2, 10, 100).await;
178        do_multi_sender(1024, 30, 100).await;
179    }
180
181    #[tokio::test]
182    async fn positions() {
183        const LEN: usize = 16;
184        let s = TestAsyncSynchronizer::new();
185        let (tx, rx) = bounded_async(LEN, s);
186        let monitor = rx.position_monitor();
187        let mut id = 1;
188
189        let p0 = monitor.write_position();
190        tx.send(Identified { id, data: TestValue(1) }).await.unwrap();
191        let p1 = monitor.write_position();
192        tx.send(Identified { id: id + 1, data: TestValue(2) }).await.unwrap();
193        let p2 = monitor.write_position();
194        tx.send(Identified { id: id + 2, data: TestValue(3) }).await.unwrap();
195        let p3 = monitor.write_position();
196        id += 3;
197        assert!(monitor.read_position().is_past(&p0) == Some(false));
198        assert!(monitor.read_position().is_past(&p1) == Some(false));
199        assert!(monitor.read_position().is_past(&p2) == Some(false));
200        assert!(monitor.read_position().is_past(&p3) == Some(false));
201
202        rx.recv().await.unwrap();
203        assert!(monitor.read_position().is_past(&p0) == Some(true));
204        assert!(monitor.read_position().is_past(&p1) == Some(false));
205        assert!(monitor.read_position().is_past(&p2) == Some(false));
206        assert!(monitor.read_position().is_past(&p3) == Some(false));
207
208        rx.recv().await.unwrap();
209        assert!(monitor.read_position().is_past(&p0) == Some(true));
210        assert!(monitor.read_position().is_past(&p1) == Some(true));
211        assert!(monitor.read_position().is_past(&p2) == Some(false));
212        assert!(monitor.read_position().is_past(&p3) == Some(false));
213
214        rx.recv().await.unwrap();
215        assert!(monitor.read_position().is_past(&p0) == Some(true));
216        assert!(monitor.read_position().is_past(&p1) == Some(true));
217        assert!(monitor.read_position().is_past(&p2) == Some(true));
218        assert!(monitor.read_position().is_past(&p3) == Some(false));
219
220        for i in 0..1000 {
221            let n = 1 + (i % LEN);
222            let p4 = monitor.write_position();
223            for _ in 0..n {
224                tx.send(Identified { id, data: TestValue(id) }).await.unwrap();
225                id += 1;
226            }
227            let p5 = monitor.write_position();
228            for _ in 0..n {
229                rx.recv().await.unwrap();
230                assert!(monitor.read_position().is_past(&p0) == Some(true));
231                assert!(monitor.read_position().is_past(&p1) == Some(true));
232                assert!(monitor.read_position().is_past(&p2) == Some(true));
233                assert!(monitor.read_position().is_past(&p3) == Some(true));
234                assert!(monitor.read_position().is_past(&p4) == Some(true));
235                assert!(monitor.read_position().is_past(&p5) == Some(false));
236            }
237        }
238    }
239
    /// Test helper pairing both halves of a tokio broadcast channel so that a
    /// single value can both publish and receive events.
    struct Subscription<T> {
        /// Sending half of the broadcast channel.
        tx: broadcast::Sender<T>,
        /// Receiving half, behind an async mutex so that `recv` can take `&self`.
        rx: Mutex<broadcast::Receiver<T>>,
    }
244
245    impl<T: Clone> Subscription<T> {
246        fn new(capacity: usize) -> Self {
247            let (tx, rx) = broadcast::channel(capacity);
248            Self {
249                tx,
250                rx: Mutex::new(rx),
251            }
252        }
253
254        fn send(&self, val: T) -> Result<(), SendError<T>> {
255            self.tx.send(val).map(|_| ())
256        }
257
258        async fn recv(&self) -> Result<T, RecvError> {
259            let mut rx = self.rx.lock().await;
260            rx.recv().await
261        }
262    }
263
264    impl<T> Clone for Subscription<T> {
265        fn clone(&self) -> Self {
266            Self {
267                tx: self.tx.clone(),
268                rx: Mutex::new(self.tx.subscribe()),
269            }
270        }
271    }
272
    /// `AsyncSynchronizer` used by the tests, with one broadcast-backed
    /// `Subscription` per `QueueEvent` variant.
    #[derive(Clone)]
    struct TestAsyncSynchronizer {
        /// Carries `QueueEvent::NotEmpty` notifications.
        not_empty: Subscription<()>,
        /// Carries `QueueEvent::NotFull` notifications.
        not_full: Subscription<()>,
    }
278
279    impl TestAsyncSynchronizer {
280        fn new() -> Self {
281            Self {
282                not_empty: Subscription::new(128),
283                not_full: Subscription::new(128),
284            }
285        }
286    }
287
288    impl AsyncSynchronizer for TestAsyncSynchronizer {
289        fn wait(&self, event: QueueEvent) -> Pin<Box<dyn Future<Output=Result<(), SynchronizationError>> + '_>> {
290            async move {
291                match event {
292                    QueueEvent::NotEmpty => self.not_empty.recv().await,
293                    QueueEvent::NotFull => self.not_full.recv().await,
294                }.map_err(|_| SynchronizationError::ChannelClosed)
295            }.boxed()
296        }
297
298        fn notify(&self, event: QueueEvent) {
299            let _ = match event {
300                QueueEvent::NotEmpty => self.not_empty.send(()),
301                QueueEvent::NotFull => self.not_full.send(()),
302            };
303        }
304    }
305}