// optimistic_cell/lib.rs

1//! A lock-like structure that allows concurrent access to
2//! the contents with very efficient cache coherency behavior.
3
4use std::cell::UnsafeCell;
5use std::ops::{Deref, DerefMut};
6use std::sync::atomic::{fence, AtomicU64, Ordering};
7
/// High bit of the guard word: set while a writer holds the cell locked.
const LOCK_BIT: u64 = 1 << 63;
/// Mask selecting the low 63 bits of the guard word: the version timestamp
/// that readers compare before/after an optimistic read.
const LOCK_MASK: u64 = u64::MAX ^ LOCK_BIT;
10
/// A cell protected by a seqlock-style guard word: writers take an exclusive
/// lock, while readers access the contents optimistically and retry if a
/// concurrent write is detected.
pub struct OptimisticCell<T> {
    // Bit 63 = writer-locked flag; low 63 bits = version timestamp updated
    // on every unlock. See `LOCK_BIT` / `LOCK_MASK`.
    guard: AtomicU64,
    // The protected value; only accessed under the protocol above.
    state: UnsafeCell<T>,
}
15
16unsafe impl<T: Send> Send for OptimisticCell<T> {}
17unsafe impl<T: Sync> Sync for OptimisticCell<T> {}
18
/// Guard providing exclusive write access to an `OptimisticCell`.
///
/// While the guard is alive the cell's lock bit is set; dropping it clears
/// the bit and updates the timestamp so in-flight optimistic readers retry.
pub struct OptimisticWriteGuard<'a, T> {
    cell: &'a OptimisticCell<T>,
    // Guard-word value observed when the lock was taken (lock bit clear);
    // used to compute the post-unlock state and sanity-check it on drop.
    previous_unlocked_guard_state: u64,
}
23
impl<'a, T> Deref for OptimisticWriteGuard<'a, T> {
    type Target = T;
    // Shared access through the guard. Sound because the exclusive lock is
    // held for the guard's entire lifetime, so no writer can race with us.
    fn deref(&self) -> &T {
        unsafe { &*self.cell.state.get() }
    }
}
30
impl<'a, T> DerefMut for OptimisticWriteGuard<'a, T> {
    // Exclusive access through the guard. Sound for the same reason as
    // `Deref`: the lock bit excludes all other writers while we are alive.
    // NOTE(review): optimistic readers may still race with writes made
    // through this reference; they discard torn results via the timestamp
    // check in `read_with`.
    fn deref_mut(&mut self) -> &mut T {
        unsafe { &mut *self.cell.state.get() }
    }
}
36
37impl<'a, T> Drop for OptimisticWriteGuard<'a, T> {
38    fn drop(&mut self) {
39        let new_guard_state = (self.previous_unlocked_guard_state + 1) ^ LOCK_MASK;
40        let old = self.cell.guard.swap(new_guard_state, Ordering::Release);
41        assert_eq!(old, self.previous_unlocked_guard_state ^ LOCK_BIT);
42    }
43}
44
45impl<T> OptimisticCell<T> {
46    #[inline]
47    fn status(&self) -> (bool, u64) {
48        let guard_value = self.guard.load(Ordering::Acquire);
49        let is_locked = guard_value & LOCK_BIT != 0;
50        let timestamp = guard_value & LOCK_MASK;
51
52        (is_locked, timestamp)
53    }
54
55    pub fn new(item: T) -> OptimisticCell<T> {
56        OptimisticCell {
57            guard: 0.into(),
58            state: UnsafeCell::new(item),
59        }
60    }
61
62    pub fn read(&self) -> T
63    where
64        T: Copy,
65    {
66        self.read_with(|item| *item)
67    }
68
69    pub fn read_with<R, F: Fn(&T) -> R>(&self, read_function: F) -> R
70    where
71        T: Copy,
72    {
73        loop {
74            let (before_is_locked, before_timestamp) = self.status();
75            if before_is_locked {
76                std::hint::spin_loop();
77                continue;
78            }
79
80            let state: &T = unsafe { &*self.state.get() };
81            let ret = read_function(state);
82
83            // NB: a Release fence is important for keeping the above read from
84            // being reordered below this validation check.
85            fence(Ordering::Release);
86
87            let (after_is_locked, after_timestamp) = self.status();
88
89            if after_is_locked || after_timestamp != before_timestamp {
90                std::hint::spin_loop();
91                continue;
92            }
93
94            return ret;
95        }
96    }
97
98    pub fn lock(&self) -> OptimisticWriteGuard<'_, T> {
99        loop {
100            let prev = self.guard.fetch_or(LOCK_BIT, Ordering::Acquire);
101            let already_locked = prev & LOCK_BIT != 0;
102
103            if !already_locked {
104                return OptimisticWriteGuard {
105                    previous_unlocked_guard_state: prev,
106                    cell: self,
107                };
108            }
109
110            std::hint::spin_loop();
111        }
112    }
113}
114
#[test]
fn concurrent_test() {
    // Hammer the cell from several threads: each iteration does an
    // optimistic read, a locked increment, then a second read, and checks
    // that the monotonically-increasing counter visibly advanced.
    let n: u32 = 128 * 1024 * 1024;
    let concurrency = 2;

    let cell = &OptimisticCell::new(0);
    let barrier = &std::sync::Barrier::new(concurrency as _);

    let before = std::time::Instant::now();
    std::thread::scope(|s| {
        let handles: Vec<_> = (0..concurrency)
            .map(|_| {
                s.spawn(move || {
                    // Start all workers at once to maximize contention.
                    barrier.wait();

                    for _ in 0..n {
                        let read_1 = cell.read();

                        *cell.lock() += 1;

                        let read_2 = cell.read();
                        assert_ne!(read_1, read_2);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    });
    dbg!(before.elapsed());
}