Skip to main content

ipc_queue/
fifo.rs

/* Copyright (c) Fortanix, Inc.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7use std::cell::UnsafeCell;
8use std::marker::PhantomData;
9use std::mem;
10#[cfg(not(target_env = "sgx"))]
11use {
12    std::sync::atomic::AtomicU64,
13    std::sync::Arc,
14};
15use std::sync::atomic::{AtomicUsize, Ordering, Ordering::SeqCst};
16
17use fortanix_sgx_abi::{FifoDescriptor, WithId};
18
19#[cfg(target_env = "sgx")]
20use super::{Identified, Transmittable, TryRecvError, TrySendError, UserRef, UserSafeSized};
21
22#[cfg(not(target_env = "sgx"))]
23use super::{AsyncReceiver, AsyncSender, AsyncSynchronizer, DescriptorGuard, Identified, Receiver, Sender,
24    Synchronizer, Transmittable, TryRecvError, TrySendError};
25
// `fortanix_sgx_abi::WithId` cannot be `Copy` because its `id` field is an
// `AtomicU64`. `UserSafeWithId` mirrors that memory layout with a plain `u64`.
// That makes it `Copy` and lets it implement `UserSafeSized`, which
// `User::from_raw_parts()` (used when validating descriptors) requires.
#[cfg(target_env = "sgx")]
#[repr(C)]
#[derive(Clone, Copy, Default)]
struct UserSafeWithId<T> {
    pub id: u64,
    pub data: T,
}

// A `repr(C)` struct made of a `u64` and a `T: UserSafeSized`.
#[cfg(target_env = "sgx")]
unsafe impl<T: UserSafeSized> UserSafeSized for UserSafeWithId<T> {}

/// Compile-time check that `UserSafeWithId<()>` and `fortanix_sgx_abi::WithId<()>`
/// have the same size.
///
/// The annotated array type and the array value only unify when both sizes are
/// equal. The function exists only for this type check; nothing in this file
/// calls it. NOTE(review): only the `T = ()` instantiation is checked.
#[cfg(target_env = "sgx")]
unsafe fn _sanity_check_with_id() {
    type AbiWithId = fortanix_sgx_abi::WithId<()>;
    let _: [u8; mem::size_of::<AbiWithId>()] = [0u8; mem::size_of::<UserSafeWithId<()>>()];
}
45
// `usize` does not implement `UserSafeSized`, so it cannot be copied to or
// from userspace directly. This transparent wrapper can.
//
// `dead_code` is allowed because the inner field is never read. The type is
// only used through pointer casts (e.g. to and from `*const AtomicUsize`),
// where all that matters is that it has exactly the size of a `usize`.
#[cfg(target_env = "sgx")]
#[allow(dead_code)]
#[repr(transparent)]
#[derive(Clone, Copy)]
struct WrapUsize(usize);

#[cfg(target_env = "sgx")]
unsafe impl UserSafeSized for WrapUsize {}
56
57#[cfg(not(target_env = "sgx"))]
58pub fn bounded<T, S>(len: usize, s: S) -> (Sender<T, S>, Receiver<T, S>)
59where
60    T: Transmittable,
61    S: Synchronizer,
62{
63    let arc = Arc::new(FifoBuffer::new(len));
64    let inner = Fifo::from_arc(arc);
65    let tx = Sender { inner: inner.clone(), synchronizer: s.clone() };
66    let rx = Receiver { inner, synchronizer: s };
67    (tx, rx)
68}
69
70#[cfg(not(target_env = "sgx"))]
71pub fn bounded_async<T, S>(len: usize, s: S) -> (AsyncSender<T, S>, AsyncReceiver<T, S>)
72where
73    T: Transmittable,
74    S: AsyncSynchronizer,
75{
76    let arc = Arc::new(FifoBuffer::new(len));
77    let inner = Fifo::from_arc(arc);
78    let tx = AsyncSender { inner: inner.clone(), synchronizer: s.clone() };
79    let rx = AsyncReceiver { inner, synchronizer: s, read_epoch: Arc::new(AtomicU64::new(0)) };
80    (tx, rx)
81}
82
/// Test-only SGX version of `bounded`.
///
/// Allocates the slot array and the offsets word in userspace, then builds the
/// fifo through `Fifo::from_descriptor`.
///
/// WARNING: both userspace allocations are leaked. Use in tests only!
///
/// NOTE(review): `Sender`, `Receiver` and `Synchronizer` are only imported for
/// non-SGX targets at the top of this file. Confirm they are in scope for SGX
/// test builds.
#[cfg(all(test, target_env = "sgx"))]
pub(crate) fn bounded<T, S>(len: usize, s: S) -> (Sender<T, S>, Receiver<T, S>)
where
    T: Transmittable,
    S: Synchronizer,
{
    use std::ops::DerefMut;
    use std::os::fortanix_sgx::usercalls::alloc::User;

    // Userspace `[WithId<T>; len]`, every slot reset to the default (id 0).
    let mut slots = User::<[UserSafeWithId<T>]>::uninitialized(len);
    for slot in slots.deref_mut().iter_mut() {
        slot.copy_from_enclave(&UserSafeWithId::default());
    }

    // Userspace packed read/write offsets, starting at zero.
    let offsets_ptr =
        User::<WrapUsize>::new_from_enclave(&WrapUsize(0)).into_raw() as *const AtomicUsize;

    let descriptor = FifoDescriptor {
        data: slots.into_raw() as _,
        len,
        offsets: offsets_ptr,
    };

    let fifo = unsafe { Fifo::from_descriptor(descriptor) };
    let sender = Sender { inner: fifo.clone(), synchronizer: s.clone() };
    (sender, Receiver { inner: fifo, synchronizer: s })
}
112
113#[cfg(not(target_env = "sgx"))]
114pub(crate) struct FifoBuffer<T> {
115    data: Box<[WithId<T>]>,
116    offsets: Box<AtomicUsize>,
117}
118
119#[cfg(not(target_env = "sgx"))]
120impl<T: Transmittable> FifoBuffer<T> {
121    fn new(len: usize) -> Self {
122        assert!(
123            len.is_power_of_two(),
124            "Fifo len should be a power of two"
125        );
126        let mut data = Vec::with_capacity(len);
127        data.resize_with(len, || WithId { id: AtomicU64::new(0), data: T::default() });
128        Self {
129            data: data.into_boxed_slice(),
130            offsets: Box::new(AtomicUsize::new(0)),
131        }
132    }
133}
134
135enum Storage<T: 'static> {
136    #[cfg(not(target_env = "sgx"))]
137    Shared(Arc<FifoBuffer<T>>),
138    Static(PhantomData<&'static T>),
139}
140
141impl<T> Clone for Storage<T> {
142    fn clone(&self) -> Self {
143        match self {
144            #[cfg(not(target_env = "sgx"))]
145            Storage::Shared(arc) => Storage::Shared(arc.clone()),
146            Storage::Static(p) => Storage::Static(*p),
147        }
148    }
149}
150
151pub(crate) struct Fifo<T: 'static> {
152    data: &'static [UnsafeCell<WithId<T>>],
153    offsets: &'static AtomicUsize,
154    storage: Storage<T>,
155}
156
157impl<T> Clone for Fifo<T> {
158    fn clone(&self) -> Self {
159        Self {
160            data: self.data,
161            offsets: self.offsets,
162            storage: self.storage.clone(),
163        }
164    }
165}
166
167impl<T> Fifo<T> {
168    pub(crate) fn current_offsets(&self, ordering: Ordering) -> Offsets {
169        Offsets::new(self.offsets.load(ordering), self.data.len() as u32)
170    }
171}
172
173impl<T: Transmittable> Fifo<T> {
174    pub(crate) unsafe fn from_descriptor(descriptor: FifoDescriptor<T>) -> Self {
175        assert!(
176            descriptor.len.is_power_of_two(),
177            "Fifo len should be a power of two"
178        );
179        #[cfg(target_env = "sgx")] {
180            use std::os::fortanix_sgx::usercalls::alloc::User;
181
182            // check pointers are outside enclave range, etc.
183            let data = User::<[UserSafeWithId<T>]>::from_raw_parts(descriptor.data as _, descriptor.len);
184            mem::forget(data);
185            UserRef::from_ptr(descriptor.offsets as *const WrapUsize);
186
187        }
188        let data_slice = std::slice::from_raw_parts(descriptor.data, descriptor.len);
189        Self {
190            data: &*(data_slice as *const [WithId<T>] as *const [UnsafeCell<WithId<T>>]),
191            offsets: &*descriptor.offsets,
192            storage: Storage::Static(PhantomData::default()),
193        }
194    }
195
196    #[cfg(not(target_env = "sgx"))]
197    fn from_arc(fifo: Arc<FifoBuffer<T>>) -> Self {
198        unsafe {
199            Self {
200                data: &*(fifo.data.as_ref() as *const [WithId<T>] as *const [UnsafeCell<WithId<T>>]),
201                offsets: &*(fifo.offsets.as_ref() as *const AtomicUsize),
202                storage: Storage::Shared(fifo),
203            }
204        }
205    }
206
207    /// Consumes `self` and returns a DescriptorGuard.
208    /// Panics if `self` was created using `from_descriptor`.
209    #[cfg(not(target_env = "sgx"))]
210    pub(crate) fn into_descriptor_guard(self) -> DescriptorGuard<T> {
211        let arc = match self.storage {
212            Storage::Shared(arc) => arc,
213            Storage::Static(_) => panic!("Sender/Receiver created using `from_descriptor()` cannot be turned into DescriptorGuard."),
214        };
215        let descriptor = FifoDescriptor {
216            data: self.data.as_ptr() as _,
217            len: self.data.len(),
218            offsets: self.offsets,
219        };
220        DescriptorGuard { descriptor, _fifo: arc }
221    }
222
223    pub(crate) fn try_send_impl(&self, val: Identified<T>) -> Result</*wake up reader:*/ bool, TrySendError> {
224        let (new, was_empty) = loop {
225            // 1. Load the current offsets.
226            let current = self.current_offsets(Ordering::SeqCst);
227            let was_empty = current.is_empty();
228
229            // 2. If the queue is full, wait, then go to step 1.
230            if current.is_full() {
231                return Err(TrySendError::QueueFull);
232            }
233
234            // 3. Add 1 to the write offset and do an atomic compare-and-swap (CAS)
235            // with the current offsets. If the CAS was not successful, go to step 1.
236            let new = current.increment_write_offset();
237            let current = current.as_usize();
238            if self.offsets.compare_exchange(current, new.as_usize(), SeqCst, SeqCst).is_ok() {
239                break (new, was_empty);
240            }
241        };
242
243        // 4. Write the data, then the `id`.
244        unsafe {
245            let slot = &mut *self.data[new.write_offset()].get();
246            T::write(&mut slot.data, &val.data);
247            slot.id.store(val.id, SeqCst);
248        }
249
250        // 5. If the queue was empty in step 1, signal the reader to wake up.
251        Ok(was_empty)
252    }
253
254    pub(crate) fn try_recv_impl(&self) -> Result<(Identified<T>, /*wake up writer:*/ bool, /*read offset wrapped around:*/ bool), TryRecvError> {
255        // 1. Load the current offsets.
256        let current = self.current_offsets(Ordering::SeqCst);
257
258        // 2. If the queue is empty, wait, then go to step 1.
259        if current.is_empty() {
260            return Err(TryRecvError::QueueEmpty);
261        }
262
263        // 3. Add 1 to the read offset.
264        let new = current.increment_read_offset();
265
266        let (slot, id) = loop {
267            // 4. Read the `id` at the new read offset.
268            let slot = unsafe { &mut *self.data[new.read_offset()].get() };
269            let id = slot.id.load(SeqCst);
270
271            // 5. If `id` is `0`, go to step 4 (spin). Spinning is OK because data is
272            //    expected to be written imminently.
273            if id != 0 {
274                break (slot, id);
275            }
276        };
277
278        // 6. Read the data, then store `0` in the `id`.
279        let data = unsafe { T::read(&slot.data) };
280        let val = Identified { id, data };
281        slot.id.store(0, SeqCst);
282
283        // 7. Store the new read offset, retrieving the old offsets.
284        let before = fetch_adjust(
285            self.offsets,
286            new.read as isize - current.read as isize,
287            SeqCst,
288        );
289
290        // 8. If the queue was full before step 7, signal the writer to wake up.
291        let was_full = Offsets::new(before, self.data.len() as u32).is_full();
292        Ok((val, was_full, new.read_offset() == 0))
293    }
294}
295
/// Atomically adds the signed `delta` to `x` (wrapping) and returns the previous value.
///
/// Non-negative deltas use `fetch_add` and negative ones use `fetch_sub`, both
/// with ordering `ord`. The magnitude is taken with `unsigned_abs`, so
/// `isize::MIN` is handled without the overflow that `-delta` would cause.
pub(crate) fn fetch_adjust(x: &AtomicUsize, delta: isize, ord: Ordering) -> usize {
    let magnitude = delta.unsigned_abs();
    if delta >= 0 {
        x.fetch_add(magnitude, ord)
    } else {
        x.fetch_sub(magnitude, ord)
    }
}
302
/// Decoded form of a fifo's packed offsets word.
///
/// The packed `usize` holds the write counter in its upper 32 bits and the
/// read counter in its lower 32 bits. Both counters count modulo `2 * len`.
/// The slot index is `counter % len`, and the extra "high bit" tells a full
/// queue (same slot, different lap) apart from an empty one (identical counters).
#[derive(Clone, Copy)]
pub(crate) struct Offsets {
    write: u32,
    read: u32,
    len: u32,
}

impl Offsets {
    // This implementation only works on 64-bit platforms: the array below only
    // type-checks when `usize` is exactly 8 bytes.
    fn _assert_usize_is_eight_bytes() -> [u8; 8] {
        [0u8; mem::size_of::<usize>()]
    }

    /// Unpacks `offsets` for a fifo of `len` slots.
    ///
    /// `len` must be a power of two; this is checked in debug builds only.
    pub(crate) fn new(offsets: usize, len: u32) -> Self {
        debug_assert!(len.is_power_of_two());
        Self {
            write: (offsets >> 32) as u32,
            read: offsets as u32,
            len,
        }
    }

    /// Packs the counters back into the layout accepted by `new`.
    pub(crate) fn as_usize(&self) -> usize {
        ((self.write as usize) << 32) | (self.read as usize)
    }

    /// Empty: both counters are identical (same slot, same lap).
    pub(crate) fn is_empty(&self) -> bool {
        self.read == self.write
    }

    /// Full: both counters select the same slot but are on different laps.
    pub(crate) fn is_full(&self) -> bool {
        self.read != self.write && self.read_offset() == self.write_offset()
    }

    /// Slot index selected by the read counter.
    pub(crate) fn read_offset(&self) -> usize {
        (self.read & (self.len - 1)) as usize
    }

    /// Slot index selected by the write counter.
    pub(crate) fn write_offset(&self) -> usize {
        (self.write & (self.len - 1)) as usize
    }

    /// Returns a copy with the read counter advanced by one, modulo `2 * len`.
    pub(crate) fn increment_read_offset(&self) -> Self {
        Self {
            read: self.read.wrapping_add(1) & self.counter_mask(),
            ..*self
        }
    }

    /// Returns a copy with the write counter advanced by one, modulo `2 * len`.
    pub(crate) fn increment_write_offset(&self) -> Self {
        Self {
            write: self.write.wrapping_add(1) & self.counter_mask(),
            ..*self
        }
    }

    #[allow(unused)]
    pub(crate) fn read_high_bit(&self) -> bool {
        (self.read & self.len) == self.len
    }

    #[allow(unused)]
    pub(crate) fn write_high_bit(&self) -> bool {
        (self.write & self.len) == self.len
    }

    /// Mask reducing a counter modulo `2 * len`.
    ///
    /// Uses wrapping arithmetic so that `len == 1 << 31` yields `u32::MAX`
    /// (counting modulo 2^32) instead of overflowing `len * 2`.
    fn counter_mask(&self) -> u32 {
        self.len.wrapping_mul(2).wrapping_sub(1)
    }
}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{NoopSynchronizer, TestValue};
    use std::sync::mpsc;
    use std::thread;

    /// Pulls the shared `Fifo` out of a sender so tests can drive it directly.
    fn inner<T, S>(tx: Sender<T, S>) -> Fifo<T> {
        tx.inner
    }

    #[test]
    fn basic1() {
        let (tx, _rx) = bounded(32, NoopSynchronizer);
        let fifo = inner(tx);
        assert!(fifo.try_recv_impl().is_err());

        // Only the first send into the empty queue asks to wake the reader.
        for i in 1..=7 {
            let wake_reader = fifo.try_send_impl(Identified { id: i, data: TestValue(i) }).unwrap();
            assert_eq!(wake_reader, i == 1);
        }

        // The queue never filled up, so no receive asks to wake a writer.
        for i in 1..=7 {
            let (received, wake_writer, _) = fifo.try_recv_impl().unwrap();
            assert!(!wake_writer);
            assert_eq!(received.id, i);
            assert_eq!(received.data.0, i);
        }
        assert!(fifo.try_recv_impl().is_err());
    }

    #[test]
    fn basic2() {
        let (tx, _rx) = bounded(8, NoopSynchronizer);
        let fifo = inner(tx);
        for _round in 0..3 {
            // Fill the queue completely; one more send must be rejected.
            for i in 1..=8 {
                fifo.try_send_impl(Identified { id: i, data: TestValue(i) }).unwrap();
            }
            assert!(fifo.try_send_impl(Identified { id: 9, data: TestValue(9) }).is_err());

            // Draining a full queue: only the first receive asks to wake a writer.
            for i in 1..=8 {
                let (received, wake_writer, _) = fifo.try_recv_impl().unwrap();
                assert_eq!(wake_writer, i == 1);
                assert_eq!(received.id, i);
                assert_eq!(received.data.0, i);
            }
            assert!(fifo.try_recv_impl().is_err());
        }
    }

    #[test]
    fn multi_threaded() {
        let (tx, rx) = bounded(32, NoopSynchronizer);
        assert!(rx.try_recv().is_err());

        let (batch_ready_tx, batch_ready_rx) = mpsc::channel();

        let producer = thread::spawn(move || {
            for _batch in 0..4 {
                for i in 0..7 {
                    tx.try_send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
                }
                batch_ready_tx.send(()).unwrap();
            }
        });

        for _batch in 0..4 {
            batch_ready_rx.recv().unwrap();
            for i in 0..7 {
                let received = rx.try_recv().unwrap();
                assert_eq!(received.id, i + 1);
                assert_eq!(received.data.0, i);
            }
        }
        assert!(rx.try_recv().is_err());
        producer.join().unwrap();
    }

    #[test]
    fn fetch_adjust_correctness() {
        let counter = AtomicUsize::new(0);
        fetch_adjust(&counter, 5, SeqCst);
        assert_eq!(counter.load(SeqCst), 5);
        fetch_adjust(&counter, -3, SeqCst);
        assert_eq!(counter.load(SeqCst), 2);
    }

    #[test]
    fn offsets() {
        let mut offsets = Offsets::new(/*offsets:*/ 0, /*len:*/ 4);
        assert!(offsets.is_empty());
        assert!(!offsets.is_full());

        for _lap in 0..10 {
            for i in 0..4 {
                offsets = offsets.increment_write_offset();
                assert!(!offsets.is_empty());
                assert_eq!(offsets.is_full(), i == 3);
            }

            assert!(!offsets.is_empty());
            assert!(offsets.is_full());

            for i in 0..4 {
                offsets = offsets.increment_read_offset();
                assert!(!offsets.is_full());
                assert_eq!(offsets.is_empty(), i == 3);
            }
        }
    }
}