real_time/
reader.rs

1use {
2    crate::{
3        backoff::Backoff,
4        sync::{
5            atomic::{AtomicPtr, Ordering},
6            Arc,
7        },
8        PhantomUnsync,
9    },
10    crossbeam_utils::CachePadded,
11    std::{cell::UnsafeCell, marker::PhantomData, ops::Deref, ptr::null_mut},
12};
13
14/// A shared value that can be read on the real-time thread without blocking.
15pub struct RealtimeReader<T> {
16    shared: Arc<Shared<T>>,
17    _marker: PhantomUnsync,
18}
19
20/// A shared value that can be mutated on a non-real-time thread.
21pub struct LockingWriter<T> {
22    shared: Arc<Shared<T>>,
23    _marker: PhantomUnsync,
24}
25
26/// A guard that allows reading the shared value on the real-time thread.
27pub struct RealtimeReadGuard<'a, T> {
28    shared: &'a Shared<T>,
29    value: *const T,
30}
31
32/// Creates a shared value that can be read on the real-time thread without
33/// blocking.
34pub fn readable<T>(value: T) -> (LockingWriter<T>, RealtimeReader<T>)
35where
36    T: Send,
37{
38    let storage = Box::into_raw(Box::new(value));
39
40    let shared = Arc::new(Shared {
41        live: CachePadded::new(AtomicPtr::new(storage)),
42        storage: CachePadded::new(UnsafeCell::new(storage)),
43    });
44
45    (
46        LockingWriter {
47            shared: Arc::clone(&shared),
48            _marker: PhantomData,
49        },
50        RealtimeReader {
51            shared,
52            _marker: PhantomData,
53        },
54    )
55}
56
57struct Shared<T> {
58    storage: CachePadded<UnsafeCell<*mut T>>,
59    live: CachePadded<AtomicPtr<T>>,
60}
61
62impl<T> Drop for Shared<T> {
63    fn drop(&mut self) {
64        // SAFETY: Returned pointer of `get` call is never null.
65        let value = unsafe { *self.storage.get() };
66
67        assert!(!value.is_null());
68
69        // SAFETY: No other references to `value` exist, so it's safe to drop.
70        let _ = unsafe { Box::from_raw(value) };
71    }
72}
73
74unsafe impl<T> Sync for Shared<T> {}
75unsafe impl<T> Send for Shared<T> {}
76
77impl<T> RealtimeReader<T> {
78    fn lock(&self) -> RealtimeReadGuard<'_, T> {
79        let value = self.shared.live.swap(null_mut(), Ordering::Acquire);
80        debug_assert!(!value.is_null());
81
82        RealtimeReadGuard {
83            shared: &self.shared,
84            value,
85        }
86    }
87
88    /// Read the shared value on the real-time thread.
89    pub fn read(&mut self) -> RealtimeReadGuard<'_, T> {
90        self.lock()
91    }
92
93    /// Copy the shared value and return it.
94    pub fn get(&self) -> T
95    where
96        T: Copy,
97    {
98        *self.lock()
99    }
100}
101
102impl<T> Drop for RealtimeReadGuard<'_, T> {
103    fn drop(&mut self) {
104        self.shared
105            .live
106            .store(self.value.cast_mut(), Ordering::Release);
107    }
108}
109
110impl<T> Deref for RealtimeReadGuard<'_, T> {
111    type Target = T;
112
113    fn deref(&self) -> &Self::Target {
114        // SAFETY: `self.value` is a valid pointer for the lifetime of `self`.
115        unsafe { &*self.value }
116    }
117}
118
119impl<T> LockingWriter<T> {
120    /// Set the value.
121    pub fn set<V>(&self, value: V)
122    where
123        V: Into<T>,
124        T: Send,
125    {
126        let _ = self.swap(Box::new(value.into()));
127    }
128
129    /// Update the value and return the previous value.
130    pub fn swap(&self, value: Box<T>) -> Box<T>
131    where
132        T: Send,
133    {
134        let new = Box::into_raw(value);
135
136        // SAFETY: Both pointers are valid and aligned as they come from calling
137        // `Box::into_raw`.
138        let old = unsafe { self.shared.storage.get().replace(new) };
139
140        let backoff = Backoff::default();
141        while self
142            .shared
143            .live
144            .compare_exchange_weak(old, new, Ordering::AcqRel, Ordering::Relaxed)
145            .is_err()
146        {
147            backoff.spin();
148        }
149
150        // SAFETY: No other references to `old` now exist, so we can reconstruct the
151        // box.
152        unsafe { Box::from_raw(old) }
153    }
154}
155
#[cfg(test)]
mod test {
    use {
        super::*,
        static_assertions::{assert_impl_all, assert_not_impl_any},
        std::thread,
    };

    // Each half may be moved to another thread, but must never be shared between
    // threads or duplicated.
    assert_impl_all!(RealtimeReader<i32>: Send);
    assert_not_impl_any!(RealtimeReader<i32>: Sync, Copy, Clone);

    assert_impl_all!(LockingWriter<i32>: Send);
    assert_not_impl_any!(LockingWriter<i32>: Sync, Copy, Clone);

    #[test]
    fn setting_and_getting_the_shared_value() {
        let (writer, reader) = readable(0);
        assert_eq!(reader.get(), 0);

        writer.set(1);
        assert_eq!(reader.get(), 1);

        writer.set(2);
        assert_eq!(reader.get(), 2);
    }

    #[test]
    fn reading_and_writing_simultaneously_from_different_threads() {
        // Keep the run short under Miri, which executes far more slowly.
        const NUM_WRITES: usize = if cfg!(miri) { 10 } else { 1_000_000 };

        let (writer, mut reader) = readable(0);

        let producer = thread::spawn(move || {
            for next in 1..=NUM_WRITES {
                writer.set(next);
            }
        });

        // The writer publishes strictly increasing values, so the reader must
        // never observe a value going backwards or exceeding the final one.
        let mut highest_seen = 0;
        while !producer.is_finished() {
            let observed = reader.read();
            assert!(*observed >= highest_seen);
            assert!(*observed <= NUM_WRITES);
            highest_seen = *observed;
        }

        assert_eq!(reader.get(), NUM_WRITES);
    }

    #[test]
    fn swapping_the_value() {
        use std::ptr::addr_of;

        let first = Box::new(1);
        let first_addr = addr_of!(*first);

        let (writer, reader) = readable(0);

        let mut initial = writer.swap(first);
        assert_eq!(reader.get(), 1);
        *initial = 2;

        // Swapping again must hand back the very allocation passed in earlier.
        let returned = writer.swap(initial);
        assert_eq!(reader.get(), 2);
        assert_eq!(addr_of!(*returned), first_addr);
    }

    #[test]
    fn can_set_anything_convertible_to_value() {
        let (writer, reader) = readable(0_i64);

        writer.set(42_i16);

        assert_eq!(reader.get(), 42);
    }
}