//! `slot.rs` — triple-buffer slot from the `lock_free_static` crate.
1use core::{
2    cell::UnsafeCell,
3    fmt::{self, Debug, Display, Formatter},
4    sync::atomic::{AtomicUsize, Ordering},
5};
6
/// Number of buffers in the triple-buffer scheme: one holding the most
/// recent value, one that may be locked by readers, one free for the writer.
const NUM_BUFFERS: usize = 3;
/// Bits needed to encode a packed buffer field, which holds values in
/// `0..=NUM_BUFFERS` (lock fields store `index + 1`, with 0 meaning "unlocked").
const BUFFER_INDEX_BITS: u32 = usize::BITS - NUM_BUFFERS.leading_zeros();
/// Mask selecting a single packed buffer field.
const BUFFER_INDEX_MASK: usize = (1 << BUFFER_INDEX_BITS) - 1;
10
/// Logical (unpacked) view of the slot state; stored packed in one `AtomicUsize`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct State {
    // Buffer currently locked for reading, if any.
    read_lock: Option<usize>,
    // Buffer currently locked for writing, if any.
    write_lock: Option<usize>,
    // Buffer holding the most recently stored value.
    most_recent: usize,
    // Number of readers sharing `read_lock` beyond the first one
    // (see `Slot::load`: the first reader keeps this at 0).
    readers_count: usize,
}
18
19impl State {
20    const READ_LOCK_OFFSET: u32 = 0;
21    const WRITE_LOCK_OFFSET: u32 = Self::READ_LOCK_OFFSET + BUFFER_INDEX_BITS;
22    const MOST_RECENT_OFFSET: u32 = Self::WRITE_LOCK_OFFSET + BUFFER_INDEX_BITS;
23    const READERS_COUNT_OFFSET: u32 = Self::MOST_RECENT_OFFSET + BUFFER_INDEX_BITS;
24
25    const MAX_READERS_COUNT: usize = usize::MAX >> Self::READERS_COUNT_OFFSET;
26
27    const fn new() -> Self {
28        State {
29            read_lock: None,
30            write_lock: None,
31            most_recent: 0,
32            readers_count: 0,
33        }
34    }
35    const fn unpack(value: usize) -> Self {
36        State {
37            read_lock: match (value >> Self::READ_LOCK_OFFSET) & BUFFER_INDEX_MASK {
38                0 => None,
39                x => Some(x - 1),
40            },
41            write_lock: match (value >> Self::WRITE_LOCK_OFFSET) & BUFFER_INDEX_MASK {
42                0 => None,
43                x => Some(x - 1),
44            },
45            most_recent: (value >> Self::MOST_RECENT_OFFSET) & BUFFER_INDEX_MASK,
46            readers_count: value >> Self::READERS_COUNT_OFFSET,
47        }
48    }
49
50    const fn pack(self) -> usize {
51        #[cfg(debug_assertions)]
52        self.assert();
53
54        (match self.read_lock {
55            Some(x) => x + 1,
56            None => 0,
57        } << Self::READ_LOCK_OFFSET)
58            | (match self.write_lock {
59                Some(x) => x + 1,
60                None => 0,
61            } << Self::WRITE_LOCK_OFFSET)
62            | (self.most_recent << Self::MOST_RECENT_OFFSET)
63            | (self.readers_count << Self::READERS_COUNT_OFFSET)
64    }
65
66    const fn assert(self) {
67        match self.read_lock {
68            None => (),
69            Some(0..NUM_BUFFERS) => (),
70            _ => panic!("Invalid read_lock state"),
71        }
72        match self.write_lock {
73            None => (),
74            Some(0..NUM_BUFFERS) => (),
75            _ => panic!("Invalid write_lock state"),
76        }
77        match self.most_recent {
78            0..NUM_BUFFERS => (),
79            _ => panic!("most_recent index is out of bounds"),
80        }
81        if self.readers_count > Self::MAX_READERS_COUNT {
82            panic!("reader_count overflow");
83        }
84    }
85}
86
/// A value that can be loaded and stored without locking.
///
/// Loads and a store can run concurrently.
/// There may be up to [`Self::MAX_READERS_COUNT`] concurrent loads, but only a
/// single concurrent store is allowed.
/// You can think of a slot as an SPMC atomic variable, but for non-atomic types.
///
/// Implemented using a triple-buffer.
pub struct Slot<T: Copy> {
    // The three buffers; which buffer is readable/writable is tracked by `state`.
    values: [UnsafeCell<T>; NUM_BUFFERS],
    // Packed `State` (see `State::pack`), updated via CAS loops.
    state: AtomicUsize,
}
98
// SAFETY: `T: Send`, and all access to the interior `UnsafeCell` buffers is
// coordinated through the atomic `state`, so ownership may move across threads.
unsafe impl<T: Copy + Send> Send for Slot<T> {}
// SAFETY: the packed atomic `state` ensures the writer never targets a buffer
// that is read-locked (see `store`), so shared access is data-race free.
unsafe impl<T: Copy + Send> Sync for Slot<T> {}
101
102impl<T: Copy + Default> Default for Slot<T> {
103    fn default() -> Self {
104        Self::new(T::default())
105    }
106}
107
impl<T: Copy> Slot<T> {
    /// Maximum number of `load` calls that may run concurrently.
    pub const MAX_READERS_COUNT: usize = State::MAX_READERS_COUNT;

    /// Creates a slot with all three buffers initialized to `value`.
    pub const fn new(value: T) -> Self {
        Self {
            values: [
                UnsafeCell::new(value),
                UnsafeCell::new(value),
                UnsafeCell::new(value),
            ],
            state: AtomicUsize::new(State::new().pack()),
        }
    }

    /// On success returns the stored value.
    /// + If no concurrent load takes place then the most recent value is returned.
    /// + If there are concurrent loads then the same value is returned by all of them.
    ///
    /// # Panics
    ///
    /// Panics if more than [`Self::MAX_READERS_COUNT`] loads run concurrently.
    pub fn load(&self) -> T {
        let mut new_state = None;
        // Acquire (or join) the read lock with a CAS loop.
        self.state
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |packed| {
                let mut state = State::unpack(packed);

                match (state.read_lock, state.readers_count) {
                    // First reader: lock the most recent buffer for reading.
                    (None, 0) => {
                        state.read_lock = Some(state.most_recent);
                    }
                    // Additional reader: share the already-locked buffer — this
                    // is why all concurrent loads observe the same value.
                    (Some(_), 0..State::MAX_READERS_COUNT) => {
                        state.readers_count = state.readers_count.wrapping_add(1);
                    }
                    (Some(_), State::MAX_READERS_COUNT..) => {
                        panic!("Maximum number of readers reached");
                    }
                    // A non-zero reader count without a read lock cannot occur.
                    (None, 1..) => {
                        unreachable!()
                    }
                }

                new_state = Some(state);
                Some(state.pack())
            })
            .unwrap();
        let state = new_state.unwrap();

        // Read the locked buffer; `store` never writes to a read-locked buffer.
        let read_index = state.read_lock.unwrap();
        let value_ptr = self.values[read_index].get();
        let value = unsafe { value_ptr.read() };

        // Release: the last reader clears the lock, others decrement the count.
        self.state
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |packed| {
                let mut state = State::unpack(packed);

                match (state.read_lock, state.readers_count) {
                    (Some(_), 0) => {
                        state.read_lock = None;
                    }
                    (Some(_), 1..) => {
                        state.readers_count = state.readers_count.wrapping_sub(1);
                    }
                    // We hold the read lock, so it cannot be `None` here.
                    (None, ..) => {
                        unreachable!()
                    }
                }
                Some(state.pack())
            })
            .unwrap();

        value
    }

    /// Stores `value`, making it the new most recent one.
    ///
    /// # Errors
    ///
    /// Returns [`SlotStoreError::ConcurrentStore`] if another store is already
    /// in progress.
    pub fn store(&self, value: T) -> Result<(), SlotStoreError> {
        let mut new_state = None;
        let update = self
            .state
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |packed| {
                let mut state = State::unpack(packed);

                match state.write_lock {
                    None => {
                        // Pick a buffer that is neither the most recent one (a
                        // new reader may lock it at any moment) nor the one
                        // currently read-locked; with three buffers and at most
                        // two exclusions such an index always exists.
                        let mut write_index = NUM_BUFFERS;
                        for i in 0..NUM_BUFFERS {
                            if i == state.most_recent {
                                continue;
                            }
                            if let Some(read_index) = state.read_lock
                                && i == read_index
                            {
                                continue;
                            }
                            write_index = i;
                            break;
                        }
                        state.write_lock = Some(write_index);
                    }
                    // Another store holds the write lock; abort the CAS loop.
                    Some(_) => {
                        return None;
                    }
                }

                new_state = Some(state);
                Some(state.pack())
            });

        if update.is_err() {
            return Err(SlotStoreError::ConcurrentStore);
        }
        let state = new_state.unwrap();

        // Exclusive access: no reader holds this buffer and no other writer runs.
        let write_index = state.write_lock.unwrap();
        let value_ptr = self.values[write_index].get();
        unsafe { value_ptr.write(value) };

        // Publish: release the write lock and mark the buffer as most recent.
        self.state
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |packed| {
                let mut state = State::unpack(packed);

                match state.write_lock {
                    Some(write_index) => {
                        state.write_lock = None;
                        state.most_recent = write_index;
                    }
                    // We hold the write lock, so it cannot be `None` here.
                    None => {
                        unreachable!()
                    }
                }
                Some(state.pack())
            })
            .unwrap();

        Ok(())
    }
}
240
/// Error returned by [`Slot::store`].
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum SlotStoreError {
    /// Another store was already in progress.
    ConcurrentStore,
}
245
246impl Display for SlotStoreError {
247    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
248        match self {
249            Self::ConcurrentStore => write!(f, "Concurrent store is taking place already"),
250        }
251    }
252}
253
254impl<T: Copy + Debug> Debug for Slot<T> {
255    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
256        self.load().fmt(f)
257    }
258}
259
#[cfg(test)]
mod tests {
    extern crate std;

    use crate::{Slot, SlotStoreError};
    use std::{
        sync::{Arc, Barrier},
        thread,
        vec::Vec,
    };

    // Large array payload so a torn (partially written) value would show up
    // as a mix of different item values within one loaded array.
    const NUM_ITEMS: usize = 1000;

    /// Many threads hammering `store` must produce at least one
    /// `ConcurrentStore` error, proving mutual exclusion of writers.
    #[test]
    #[ignore = "heavy"]
    fn concurrent_stores() {
        let slot = Arc::new(Slot::new([usize::MAX; NUM_ITEMS]));

        const NUM_THREADS: usize = 16;
        const NUM_ATTEMPTS: usize = 1000;

        let mut handles = Vec::with_capacity(NUM_THREADS);
        // Barrier maximizes overlap so the threads actually contend.
        let barrier = Arc::new(Barrier::new(NUM_THREADS));
        for k in 0..NUM_THREADS {
            let slot = slot.clone();
            let barrier = barrier.clone();
            handles.push(thread::spawn(move || {
                barrier.wait();
                for i in 0..NUM_ATTEMPTS {
                    // `k * NUM_ATTEMPTS + i` makes every stored value unique.
                    if let Err(SlotStoreError::ConcurrentStore) =
                        slot.store([k * NUM_ATTEMPTS + i; NUM_ITEMS])
                    {
                        return true;
                    }
                }
                false
            }));
        }

        // NOTE(review): `any` short-circuits, so handles after the first `true`
        // are dropped without being joined — harmless here, but worth confirming.
        assert!(handles.into_iter().any(|h| h.join().unwrap()));
    }

    /// One writer plus many readers: every loaded array must be internally
    /// consistent (all items equal), i.e. loads never observe a torn write.
    #[test]
    #[ignore = "heavy"]
    fn concurrent_loads() {
        let slot = Arc::new(Slot::new([usize::MAX; NUM_ITEMS]));

        const NUM_THREADS: usize = 16;
        const NUM_ATTEMPTS: usize = 1000;

        let mut load_handles = Vec::with_capacity(NUM_THREADS);
        // +1 slot in the barrier for the single writer thread.
        let barrier = Arc::new(Barrier::new(NUM_THREADS + 1));
        let store_handle = {
            let slot = slot.clone();
            let barrier = barrier.clone();
            thread::spawn(move || {
                barrier.wait();
                for i in 0..NUM_ATTEMPTS {
                    // Single writer, so no store may fail here.
                    slot.store([i; NUM_ITEMS]).unwrap();
                }
                false
            })
        };
        for _ in 0..NUM_THREADS {
            let slot = slot.clone();
            let barrier = barrier.clone();
            load_handles.push(thread::spawn(move || {
                barrier.wait();
                for _ in 0..NUM_ATTEMPTS {
                    let value = slot.load();
                    // All items of one stored array are equal; a mismatch
                    // would indicate a torn read.
                    let mut iter = value.into_iter();
                    let item = iter.next().unwrap();
                    assert!(iter.all(|x| x == item));
                }
                false
            }));
        }

        store_handle.join().unwrap();
        for h in load_handles {
            h.join().unwrap();
        }
    }

    /// Basic single-threaded round-trip: each store is visible to the next load.
    #[test]
    fn load_store() {
        let slot = Slot::new(0);
        assert_eq!(slot.load(), 0);

        slot.store(1).unwrap();
        assert_eq!(slot.load(), 1);

        slot.store(2).unwrap();
        assert_eq!(slot.load(), 2);
    }
}
355}