1use std::cell::UnsafeCell;
5use std::ops::{Deref, DerefMut};
6use std::sync::atomic::{fence, AtomicU64, Ordering};
7
/// Highest bit of the guard word; set while a writer holds the cell.
const LOCK_BIT: u64 = 0x8000_0000_0000_0000;
/// Every bit except `LOCK_BIT`; selects the timestamp part of the guard word.
const LOCK_MASK: u64 = !LOCK_BIT;
10
/// A value guarded by a single atomic word instead of a blocking lock.
///
/// `guard` packs a lock flag (`LOCK_BIT`) and a timestamp (`LOCK_MASK`) into
/// one `AtomicU64`; `state` holds the protected value. Readers inspect the
/// guard word before and after looking at `state` and retry when it changed;
/// writers set the lock flag to obtain exclusive access.
pub struct OptimisticCell<T> {
    guard: AtomicU64,
    state: UnsafeCell<T>,
}

// SAFETY: sending the cell to another thread sends the contained `T` with it,
// which is exactly what `T: Send` permits.
unsafe impl<T: Send> Send for OptimisticCell<T> {}

// SAFETY: a shared cell gives `&mut T` to whichever thread holds the write
// lock, so the value can effectively be moved between threads (`T: Send`), and
// it lets several threads look at `&T` at the same time (`T: Sync`). Requiring
// only `T: Sync` would let a `Sync`-but-not-`Send` value cross threads. These
// are the same bounds `std::sync::RwLock` uses.
unsafe impl<T: Send + Sync> Sync for OptimisticCell<T> {}
18
19pub struct OptimisticWriteGuard<'a, T> {
20 cell: &'a OptimisticCell<T>,
21 previous_unlocked_guard_state: u64,
22}
23
24impl<'a, T> Deref for OptimisticWriteGuard<'a, T> {
25 type Target = T;
26 fn deref(&self) -> &T {
27 unsafe { &*self.cell.state.get() }
28 }
29}
30
31impl<'a, T> DerefMut for OptimisticWriteGuard<'a, T> {
32 fn deref_mut(&mut self) -> &mut T {
33 unsafe { &mut *self.cell.state.get() }
34 }
35}
36
37impl<'a, T> Drop for OptimisticWriteGuard<'a, T> {
38 fn drop(&mut self) {
39 let new_guard_state = (self.previous_unlocked_guard_state + 1) ^ LOCK_MASK;
40 let old = self.cell.guard.swap(new_guard_state, Ordering::Release);
41 assert_eq!(old, self.previous_unlocked_guard_state ^ LOCK_BIT);
42 }
43}
44
45impl<T> OptimisticCell<T> {
46 #[inline]
47 fn status(&self) -> (bool, u64) {
48 let guard_value = self.guard.load(Ordering::Acquire);
49 let is_locked = guard_value & LOCK_BIT != 0;
50 let timestamp = guard_value & LOCK_MASK;
51
52 (is_locked, timestamp)
53 }
54
55 pub fn new(item: T) -> OptimisticCell<T> {
56 OptimisticCell {
57 guard: 0.into(),
58 state: UnsafeCell::new(item),
59 }
60 }
61
62 pub fn read(&self) -> T
63 where
64 T: Copy,
65 {
66 self.read_with(|item| *item)
67 }
68
69 pub fn read_with<R, F: Fn(&T) -> R>(&self, read_function: F) -> R
70 where
71 T: Copy,
72 {
73 loop {
74 let (before_is_locked, before_timestamp) = self.status();
75 if before_is_locked {
76 std::hint::spin_loop();
77 continue;
78 }
79
80 let state: &T = unsafe { &*self.state.get() };
81 let ret = read_function(state);
82
83 fence(Ordering::Release);
86
87 let (after_is_locked, after_timestamp) = self.status();
88
89 if after_is_locked || after_timestamp != before_timestamp {
90 std::hint::spin_loop();
91 continue;
92 }
93
94 return ret;
95 }
96 }
97
98 pub fn lock(&self) -> OptimisticWriteGuard<'_, T> {
99 loop {
100 let prev = self.guard.fetch_or(LOCK_BIT, Ordering::Acquire);
101 let already_locked = prev & LOCK_BIT != 0;
102
103 if !already_locked {
104 return OptimisticWriteGuard {
105 previous_unlocked_guard_state: prev,
106 cell: self,
107 };
108 }
109
110 std::hint::spin_loop();
111 }
112 }
113}
114
/// Stress test: every worker repeatedly reads, increments under the write lock,
/// and reads again, asserting that the two reads differ.
#[test]
fn concurrent_test() {
    let n: u32 = 128 * 1024 * 1024;
    let concurrency = 2;

    let cell = &OptimisticCell::new(0);
    let barrier = &std::sync::Barrier::new(concurrency as _);

    let before = std::time::Instant::now();
    std::thread::scope(|s| {
        // Spawn every worker first; `collect` forces all spawns before any join.
        let threads: Vec<_> = (0..concurrency)
            .map(|_| {
                s.spawn(move || {
                    // Line the workers up so they contend from the first iteration.
                    barrier.wait();

                    for _ in 0..n {
                        let read_1 = cell.read();

                        {
                            let mut lock = cell.lock();
                            *lock += 1;
                        }

                        let read_2 = cell.read();

                        assert_ne!(read_1, read_2);
                    }
                })
            })
            .collect();

        // Joining explicitly surfaces a worker's panic as a test failure here.
        threads
            .into_iter()
            .for_each(|thread| thread.join().unwrap());
    });
    dbg!(before.elapsed());
}