ipc_queue/
interface_sync.rs

1/* Copyright (c) Fortanix, Inc.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7use fortanix_sgx_abi::FifoDescriptor;
8
9use super::{Fifo, Identified, QueueEvent, Receiver, RecvError, Sender, SendError, SynchronizationError, Synchronizer, Transmittable, TryRecvError, TrySendError};
10
11unsafe impl<T: Send, S: Send> Send for Sender<T, S> {}
12unsafe impl<T: Send, S: Sync> Sync for Sender<T, S> {}
13
14impl<T, S: Clone> Clone for Sender<T, S> {
15    fn clone(&self) -> Self {
16        Self {
17            inner: self.inner.clone(),
18            synchronizer: self.synchronizer.clone(),
19        }
20    }
21}
22
23impl<T: Transmittable, S: Synchronizer> Sender<T, S> {
24    /// Create a `Sender` from a `FifoDescriptor` and `Synchronizer`.
25    ///
26    /// # Safety
27    ///
28    /// The caller must ensure the following:
29    ///
30    /// * The `data` and `len` fields in `FifoDescriptor` must adhere to all
31    ///   safety requirements described in `std::slice::from_raw_parts_mut()`
32    ///
33    /// * The `offsets` field in `FifoDescriptor` must be non-null and point
34    ///   to a valid memory location holding an `AtomicUsize`.
35    ///
36    /// * The synchronizer must somehow know how to correctly synchronize with
37    ///   the other end of the channel.
38    pub unsafe fn from_descriptor(d: FifoDescriptor<T>, synchronizer: S) -> Self {
39        Self {
40            inner: Fifo::from_descriptor(d),
41            synchronizer,
42        }
43    }
44
45    pub fn try_send(&self, val: Identified<T>) -> Result<(), TrySendError> {
46        self.inner.try_send_impl(val).map(|wake_receiver| {
47            if wake_receiver {
48                self.synchronizer.notify(QueueEvent::NotEmpty);
49            }
50        })
51    }
52
53    /// Tries to send multiple values. Calling this function has the same
54    /// semantics as calling `try_send` for each item in order until an error
55    /// occurs, but it has the benefit of notifying the receiver at most once.
56    ///
57    /// Returns the number of successfully sent items if any item was
58    /// successfully sent, otherwise returns an error.
59    pub fn try_send_multiple(&self, values: &[Identified<T>]) -> Result<usize, TrySendError> {
60        let mut wake_receiver = false;
61        let mut sent = 0;
62        for val in values {
63            wake_receiver |= match self.inner.try_send_impl(*val) {
64                Ok(wake_receiver) => wake_receiver,
65                Err(e) if sent == 0 => return Err(e),
66                Err(_) => break,
67            };
68            sent += 1;
69        }
70        if wake_receiver {
71            self.synchronizer.notify(QueueEvent::NotEmpty);
72        }
73        Ok(sent)
74    }
75
76    pub fn send(&self, val: Identified<T>) -> Result<(), SendError> {
77        loop {
78            match self.inner.try_send_impl(val) {
79                Ok(wake_receiver) => {
80                    if wake_receiver {
81                        self.synchronizer.notify(QueueEvent::NotEmpty);
82                    }
83                    return Ok(());
84                }
85                Err(TrySendError::QueueFull) => {
86                    self.synchronizer
87                        .wait(QueueEvent::NotFull)
88                        .map_err(|SynchronizationError::ChannelClosed| SendError::Closed)?;
89                }
90                Err(TrySendError::Closed) => return Err(SendError::Closed),
91            };
92        }
93    }
94}
95
96unsafe impl<T: Send, S: Send> Send for Receiver<T, S> {}
97
98impl<T: Transmittable, S: Synchronizer> Receiver<T, S> {
99    /// Create a `Receiver` from a `FifoDescriptor` and `Synchronizer`.
100    ///
101    /// # Safety
102    ///
103    /// In addition to all requirements laid out in `Sender::from_descriptor`,
104    /// the caller must ensure the following additional requirements:
105    ///
106    /// * The caller must ensure that there is at most one `Receiver` for the queue.
107    pub unsafe fn from_descriptor(d: FifoDescriptor<T>, synchronizer: S) -> Self {
108        Self {
109            inner: Fifo::from_descriptor(d),
110            synchronizer,
111        }
112    }
113
114    pub fn try_recv(&self) -> Result<Identified<T>, TryRecvError> {
115        self.inner.try_recv_impl().map(|(val, wake_sender, _)| {
116            if wake_sender {
117                self.synchronizer.notify(QueueEvent::NotFull);
118            }
119            val
120        })
121    }
122
123    pub fn try_iter(&self) -> TryIter<'_, T, S> {
124        TryIter(self)
125    }
126
127    pub fn recv(&self) -> Result<Identified<T>, RecvError> {
128        loop {
129            match self.inner.try_recv_impl() {
130                Ok((val, wake_sender, _)) => {
131                    if wake_sender {
132                        self.synchronizer.notify(QueueEvent::NotFull);
133                    }
134                    return Ok(val);
135                }
136                Err(TryRecvError::QueueEmpty) => {
137                    self.synchronizer
138                        .wait(QueueEvent::NotEmpty)
139                        .map_err(|SynchronizationError::ChannelClosed| RecvError::Closed)?;
140                }
141                Err(TryRecvError::Closed) => return Err(RecvError::Closed),
142            }
143        }
144    }
145}
146
147pub struct TryIter<'r, T: 'static, S>(&'r Receiver<T, S>);
148
149impl<'r, T: Transmittable, S: Synchronizer> Iterator for TryIter<'r, T, S> {
150    type Item = Identified<T>;
151
152    fn next(&mut self) -> Option<Self::Item> {
153        self.0.try_recv().ok()
154    }
155}
156
#[cfg(test)]
mod tests {
    use crate::fifo::bounded;
    use crate::test_support::pubsub::{Channel, Subscription};
    use crate::test_support::TestValue;
    use crate::*;
    use std::thread;

    /// Sends `n` values from a single producer thread through a queue of
    /// capacity `len` and checks that they arrive in order and unmodified.
    fn do_single_sender(len: usize, n: u64) {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(len, s);

        let h = thread::spawn(move || {
            for i in 0..n {
                tx.send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
            }
        });

        for i in 0..n {
            let v = rx.recv().unwrap();
            assert_eq!(v.id, i + 1);
            assert_eq!(v.data.0, i);
        }

        h.join().unwrap();
    }

    #[test]
    fn single_sender() {
        do_single_sender(4, 10);
        do_single_sender(1, 10);
        do_single_sender(32, 1024);
        do_single_sender(1024, 32);
    }

    /// Runs `senders` producer threads, each sending `n` values with distinct
    /// ids through a queue of capacity `len`, and checks that every id in
    /// `1..=n * senders` is received exactly once.
    fn do_multi_sender(len: usize, n: u64, senders: u64) {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(len, s);
        let mut handles = Vec::with_capacity(senders as _);

        for t in 0..senders {
            let tx = tx.clone();
            handles.push(thread::spawn(move || {
                for i in 0..n {
                    let id = t * n + i + 1;
                    tx.send(Identified { id, data: TestValue(i) }).unwrap();
                }
            }));
        }

        let mut ids: Vec<u64> = (0..(n * senders)).map(|_| rx.recv().unwrap().id).collect();
        ids.sort_unstable();
        assert!(ids.iter().copied().eq(1..=(n * senders)));

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn multi_sender() {
        do_multi_sender(4, 10, 3);
        do_multi_sender(4, 1, 100);
        do_multi_sender(2, 10, 100);
        do_multi_sender(1024, 30, 100);
    }

    #[test]
    fn try_error() {
        const N: u64 = 8;
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(N as _, s);

        for i in 0..N {
            tx.send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
        }
        // Both ends are still alive, so the only possible failure is a full queue.
        assert!(matches!(
            tx.try_send(Identified { id: N + 1, data: TestValue(N) }),
            Err(TrySendError::QueueFull)
        ));

        for i in 0..N {
            let v = rx.recv().unwrap();
            assert_eq!(v.id, i + 1);
            assert_eq!(v.data.0, i);
        }
        assert!(matches!(rx.try_recv(), Err(TryRecvError::QueueEmpty)));
    }

    #[test]
    fn very_optimistic() {
        const N: u64 = 8;
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(N as _, s);

        for i in 0..N {
            tx.try_send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
        }

        for i in 0..N {
            let v = rx.try_recv().unwrap();
            assert_eq!(v.id, i + 1);
            assert_eq!(v.data.0, i);
        }
    }

    #[test]
    fn mixed_try_send() {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(8, s);

        let h = thread::spawn(move || {
            let mut sent_without_wait = 0;
            for _ in 0..7 {
                for i in 0..11 {
                    let v = Identified { id: i + 1, data: TestValue(i) };
                    // Fall back to the blocking path only when the fast path fails.
                    if tx.try_send(v).is_err() {
                        tx.send(v).unwrap();
                    } else {
                        sent_without_wait += 1;
                    }
                }
            }
            assert!(sent_without_wait > 0);
        });

        for _ in 0..7 {
            for i in 0..11 {
                let v = rx.recv().unwrap();
                assert_eq!(v.id, i + 1);
                assert_eq!(v.data.0, i);
            }
        }

        h.join().unwrap();
    }

    #[test]
    fn mixed_try_recv() {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(8, s);

        let h = thread::spawn(move || {
            for _ in 0..11 {
                for i in 0..13 {
                    tx.send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
                }
            }
        });

        for _ in 0..11 {
            for i in 0..13 {
                let v = match rx.try_recv() {
                    Ok(v) => v,
                    Err(_) => rx.recv().unwrap(),
                };
                assert_eq!(v.id, i + 1);
                assert_eq!(v.data.0, i);
            }
        }

        h.join().unwrap();
    }

    #[test]
    fn try_iter() {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(8, s);
        const N: u64 = 2048;

        let h = thread::spawn(move || {
            for i in 0..N {
                tx.send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
            }
        });

        let mut total = 0;
        while total < N {
            // Block for one value, then drain whatever else is already queued.
            for v in rx.recv().ok().into_iter().chain(rx.try_iter()) {
                assert_eq!(v.id, total + 1);
                assert_eq!(v.data.0, total);
                total += 1;
            }
        }

        h.join().unwrap();
    }

    #[test]
    fn try_send_multiple() {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(32, s);
        const SENDERS: usize = 4;
        const N: usize = 1024;
        let mut handles = Vec::with_capacity(SENDERS);

        for t in 0..SENDERS {
            let tx = tx.clone();
            handles.push(thread::spawn(move || {
                let mut to_send = Vec::with_capacity(N);
                for i in 0..N {
                    let id = (t * N + i + 1) as u64;
                    to_send.push(Identified { id, data: TestValue(i as u64) });
                }
                let mut sent = 0;
                while sent < to_send.len() {
                    match tx.try_send_multiple(&to_send[sent..]) {
                        Err(_) => thread::yield_now(),
                        Ok(n) => sent += n,
                    }
                }
            }));
        }

        let mut values = Vec::with_capacity(N * SENDERS);
        for _ in 0..(N * SENDERS) {
            values.push(rx.recv().unwrap());
        }
        values.sort_by_key(|v| v.id);
        assert!(values.windows(2).all(|w| w[0].id < w[1].id));

        for h in handles {
            h.join().unwrap();
        }
    }

    /// Test-only synchronizer backed by two pub/sub subscriptions, one per
    /// `QueueEvent` kind.
    #[derive(Clone)]
    pub struct TestSynchronizer {
        not_empty: Subscription<()>,
        not_full: Subscription<()>,
    }

    impl TestSynchronizer {
        pub fn new() -> Self {
            Self {
                not_empty: Channel::new().subscribe(),
                not_full: Channel::new().subscribe(),
            }
        }
    }

    impl Synchronizer for TestSynchronizer {
        fn wait(&self, event: QueueEvent) -> Result<(), SynchronizationError> {
            match event {
                QueueEvent::NotEmpty => self.not_empty.recv(),
                QueueEvent::NotFull => self.not_full.recv(),
            }.map_err(|_| SynchronizationError::ChannelClosed)
        }

        fn notify(&self, event: QueueEvent) {
            // Best-effort: a failed broadcast is deliberately ignored.
            let _ = match event {
                QueueEvent::NotEmpty => self.not_empty.broadcast(()),
                QueueEvent::NotFull => self.not_full.broadcast(()),
            };
        }
    }
}