1use core::{
2 cell::UnsafeCell,
3 fmt::{self, Debug, Display, Formatter},
4 sync::atomic::{AtomicUsize, Ordering},
5};
6
/// Number of value buffers a `Slot` rotates between.
const NUM_BUFFERS: usize = 3;
/// Width of one packed index field. Lock fields store `index + 1` (0 meaning "unlocked"),
/// so the field must be able to hold values up to and including `NUM_BUFFERS`.
const BUFFER_INDEX_BITS: u32 = NUM_BUFFERS.ilog2() + 1;
/// Mask selecting the low `BUFFER_INDEX_BITS` bits of a word.
const BUFFER_INDEX_MASK: usize = !(usize::MAX << BUFFER_INDEX_BITS);
10
/// Unpacked view of the atomic word that coordinates readers and the writer of a `Slot`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct State {
    /// Buffer currently pinned by in-flight loads, if any.
    read_lock: Option<usize>,
    /// Buffer currently being filled by a store, if any.
    write_lock: Option<usize>,
    /// Buffer holding the latest completed store.
    most_recent: usize,
    /// Number of in-flight loads beyond the first one holding `read_lock`.
    readers_count: usize,
}
18
19impl State {
20 const READ_LOCK_OFFSET: u32 = 0;
21 const WRITE_LOCK_OFFSET: u32 = Self::READ_LOCK_OFFSET + BUFFER_INDEX_BITS;
22 const MOST_RECENT_OFFSET: u32 = Self::WRITE_LOCK_OFFSET + BUFFER_INDEX_BITS;
23 const READERS_COUNT_OFFSET: u32 = Self::MOST_RECENT_OFFSET + BUFFER_INDEX_BITS;
24
25 const MAX_READERS_COUNT: usize = usize::MAX >> Self::READERS_COUNT_OFFSET;
26
27 const fn new() -> Self {
28 State {
29 read_lock: None,
30 write_lock: None,
31 most_recent: 0,
32 readers_count: 0,
33 }
34 }
35 const fn unpack(value: usize) -> Self {
36 State {
37 read_lock: match (value >> Self::READ_LOCK_OFFSET) & BUFFER_INDEX_MASK {
38 0 => None,
39 x => Some(x - 1),
40 },
41 write_lock: match (value >> Self::WRITE_LOCK_OFFSET) & BUFFER_INDEX_MASK {
42 0 => None,
43 x => Some(x - 1),
44 },
45 most_recent: (value >> Self::MOST_RECENT_OFFSET) & BUFFER_INDEX_MASK,
46 readers_count: value >> Self::READERS_COUNT_OFFSET,
47 }
48 }
49
50 const fn pack(self) -> usize {
51 #[cfg(debug_assertions)]
52 self.assert();
53
54 (match self.read_lock {
55 Some(x) => x + 1,
56 None => 0,
57 } << Self::READ_LOCK_OFFSET)
58 | (match self.write_lock {
59 Some(x) => x + 1,
60 None => 0,
61 } << Self::WRITE_LOCK_OFFSET)
62 | (self.most_recent << Self::MOST_RECENT_OFFSET)
63 | (self.readers_count << Self::READERS_COUNT_OFFSET)
64 }
65
66 const fn assert(self) {
67 match self.read_lock {
68 None => (),
69 Some(0..NUM_BUFFERS) => (),
70 _ => panic!("Invalid read_lock state"),
71 }
72 match self.write_lock {
73 None => (),
74 Some(0..NUM_BUFFERS) => (),
75 _ => panic!("Invalid write_lock state"),
76 }
77 match self.most_recent {
78 0..NUM_BUFFERS => (),
79 _ => panic!("most_recent index is out of bounds"),
80 }
81 if self.readers_count > Self::MAX_READERS_COUNT {
82 panic!("reader_count overflow");
83 }
84 }
85}
86
87pub struct Slot<T: Copy> {
95 values: [UnsafeCell<T>; NUM_BUFFERS],
96 state: AtomicUsize,
97}
98
99unsafe impl<T: Copy + Send> Send for Slot<T> {}
100unsafe impl<T: Copy + Send> Sync for Slot<T> {}
101
102impl<T: Copy + Default> Default for Slot<T> {
103 fn default() -> Self {
104 Self::new(T::default())
105 }
106}
107
108impl<T: Copy> Slot<T> {
109 pub const MAX_READERS_COUNT: usize = State::MAX_READERS_COUNT;
110
111 pub const fn new(value: T) -> Self {
112 Self {
113 values: [
114 UnsafeCell::new(value),
115 UnsafeCell::new(value),
116 UnsafeCell::new(value),
117 ],
118 state: AtomicUsize::new(State::new().pack()),
119 }
120 }
121
122 pub fn load(&self) -> T {
126 let mut new_state = None;
127 self.state
128 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |packed| {
129 let mut state = State::unpack(packed);
130
131 match (state.read_lock, state.readers_count) {
132 (None, 0) => {
133 state.read_lock = Some(state.most_recent);
134 }
135 (Some(_), 0..State::MAX_READERS_COUNT) => {
136 state.readers_count = state.readers_count.wrapping_add(1);
137 }
138 (Some(_), State::MAX_READERS_COUNT..) => {
139 panic!("Maximum number of readers reached");
140 }
141 (None, 1..) => {
142 unreachable!()
143 }
144 }
145
146 new_state = Some(state);
147 Some(state.pack())
148 })
149 .unwrap();
150 let state = new_state.unwrap();
151
152 let read_index = state.read_lock.unwrap();
153 let value_ptr = self.values[read_index].get();
154 let value = unsafe { value_ptr.read() };
155
156 self.state
157 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |packed| {
158 let mut state = State::unpack(packed);
159
160 match (state.read_lock, state.readers_count) {
161 (Some(_), 0) => {
162 state.read_lock = None;
163 }
164 (Some(_), 1..) => {
165 state.readers_count = state.readers_count.wrapping_sub(1);
166 }
167 (None, ..) => {
168 unreachable!()
169 }
170 }
171 Some(state.pack())
172 })
173 .unwrap();
174
175 value
176 }
177
178 pub fn store(&self, value: T) -> Result<(), SlotStoreError> {
179 let mut new_state = None;
180 let update = self
181 .state
182 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |packed| {
183 let mut state = State::unpack(packed);
184
185 match state.write_lock {
186 None => {
187 let mut write_index = NUM_BUFFERS;
188 for i in 0..NUM_BUFFERS {
189 if i == state.most_recent {
190 continue;
191 }
192 if let Some(read_index) = state.read_lock
193 && i == read_index
194 {
195 continue;
196 }
197 write_index = i;
198 break;
199 }
200 state.write_lock = Some(write_index);
201 }
202 Some(_) => {
203 return None;
204 }
205 }
206
207 new_state = Some(state);
208 Some(state.pack())
209 });
210
211 if update.is_err() {
212 return Err(SlotStoreError::ConcurrentStore);
213 }
214 let state = new_state.unwrap();
215
216 let write_index = state.write_lock.unwrap();
217 let value_ptr = self.values[write_index].get();
218 unsafe { value_ptr.write(value) };
219
220 self.state
221 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |packed| {
222 let mut state = State::unpack(packed);
223
224 match state.write_lock {
225 Some(write_index) => {
226 state.write_lock = None;
227 state.most_recent = write_index;
228 }
229 None => {
230 unreachable!()
231 }
232 }
233 Some(state.pack())
234 })
235 .unwrap();
236
237 Ok(())
238 }
239}
240
/// Error returned by [`Slot::store`].
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum SlotStoreError {
    /// Another `store` on the same slot was in progress, so nothing was written.
    ConcurrentStore,
}

impl Display for SlotStoreError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::ConcurrentStore => f.write_str("Concurrent store is taking place already"),
        }
    }
}

// Lets callers box the error as `dyn Error` or propagate it with `?` into error-trait objects.
impl core::error::Error for SlotStoreError {}
253
254impl<T: Copy + Debug> Debug for Slot<T> {
255 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
256 self.load().fmt(f)
257 }
258}
259
#[cfg(test)]
mod tests {
    extern crate std;

    use crate::{Slot, SlotStoreError};
    use std::{
        sync::{Arc, Barrier},
        thread,
        vec::Vec,
    };

    /// Length of the arrays stored by the concurrency tests; large enough that a torn copy
    /// would show up as a non-uniform array.
    const NUM_ITEMS: usize = 1000;

    /// Many writers race on one slot; at least one of them must be rejected.
    #[test]
    #[ignore = "heavy"]
    fn concurrent_stores() {
        const NUM_THREADS: usize = 16;
        const NUM_ATTEMPTS: usize = 1000;

        let slot = Arc::new(Slot::new([usize::MAX; NUM_ITEMS]));
        let barrier = Arc::new(Barrier::new(NUM_THREADS));

        // Each writer reports whether any of its stores hit `ConcurrentStore`.
        let handles: Vec<_> = (0..NUM_THREADS)
            .map(|k| {
                let slot = Arc::clone(&slot);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    (0..NUM_ATTEMPTS).any(|i| {
                        matches!(
                            slot.store([k * NUM_ATTEMPTS + i; NUM_ITEMS]),
                            Err(SlotStoreError::ConcurrentStore)
                        )
                    })
                })
            })
            .collect();

        assert!(handles.into_iter().any(|h| h.join().unwrap()));
    }

    /// One writer publishes uniform arrays while readers check every loaded copy is uniform.
    #[test]
    #[ignore = "heavy"]
    fn concurrent_loads() {
        const NUM_THREADS: usize = 16;
        const NUM_ATTEMPTS: usize = 1000;

        let slot = Arc::new(Slot::new([usize::MAX; NUM_ITEMS]));
        // All readers plus the single writer start together.
        let barrier = Arc::new(Barrier::new(NUM_THREADS + 1));

        let writer = {
            let slot = Arc::clone(&slot);
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                barrier.wait();
                for i in 0..NUM_ATTEMPTS {
                    slot.store([i; NUM_ITEMS]).unwrap();
                }
            })
        };

        let readers: Vec<_> = (0..NUM_THREADS)
            .map(|_| {
                let slot = Arc::clone(&slot);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    for _ in 0..NUM_ATTEMPTS {
                        let value = slot.load();
                        let mut iter = value.into_iter();
                        let item = iter.next().unwrap();
                        assert!(iter.all(|x| x == item));
                    }
                })
            })
            .collect();

        writer.join().unwrap();
        for reader in readers {
            reader.join().unwrap();
        }
    }

    /// Sequential stores are visible to the next load.
    #[test]
    fn load_store() {
        let slot = Slot::new(0);
        assert_eq!(slot.load(), 0);

        for next in 1..=2 {
            slot.store(next).unwrap();
            assert_eq!(slot.load(), next);
        }
    }
}