#[cfg(loom)]
use loom as mystd;
#[cfg(not(loom))]
use std as mystd;

use mystd::{
    cell::UnsafeCell,
    sync::atomic::{AtomicUsize, Ordering},
};
use std::{mem, ops};

// The top bit of `state` marks an active exclusive borrow;
// the lower bits count the active shared borrows.
const WRITE_BIT: usize = 1 << (mem::size_of::<usize>() * 8 - 1);

/// A shared borrow of a `SynCell`, handed out by [`SynCell::borrow`].
pub struct SynRef<'a, T> {
    state: &'a AtomicUsize,
    value: &'a T,
}

impl<T> Drop for SynRef<'_, T> {
    fn drop(&mut self) {
        // Release one shared borrow.
        self.state.fetch_sub(1, Ordering::Release);
    }
}

impl<T> ops::Deref for SynRef<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        self.value
    }
}

/// An exclusive borrow of a `SynCell`, handed out by [`SynCell::borrow_mut`].
pub struct SynRefMut<'a, T> {
    state: &'a AtomicUsize,
    value: &'a mut T,
}

impl<T> Drop for SynRefMut<'_, T> {
    fn drop(&mut self) {
        // Clear the write bit to release the exclusive borrow.
        self.state.fetch_and(!WRITE_BIT, Ordering::Release);
    }
}

impl<T> ops::Deref for SynRefMut<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        self.value
    }
}

impl<T> ops::DerefMut for SynRefMut<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        self.value
    }
}

/// A `Sync` variant of `RefCell`: borrows are checked at run time with a
/// single atomic, and a conflicting borrow panics instead of blocking.
pub struct SynCell<T> {
    state: AtomicUsize,
    value: UnsafeCell<T>,
}

// Sharing a `SynCell` hands out both `&T` and `&mut T` across threads, so,
// as with `RwLock`, this requires `T: Send + Sync`. An unbounded impl would
// be unsound for types like `Rc` or `Cell`.
unsafe impl<T: Send + Sync> Sync for SynCell<T> {}

impl<T> SynCell<T> {
    /// Creates a new cell with no active borrows.
    pub fn new(value: T) -> Self {
        Self {
            state: AtomicUsize::new(0),
            value: UnsafeCell::new(value),
        }
    }

    /// Consumes the cell and returns the wrapped value.
    pub fn into_inner(self) -> T {
        debug_assert_eq!(self.state.load(Ordering::Acquire), 0);
        self.value.into_inner()
    }

    /// Returns a mutable reference to the value. `&mut self` already
    /// guarantees exclusive access, so no runtime check is needed.
    pub fn get_mut(&mut self) -> &mut T {
        debug_assert_eq!(self.state.load(Ordering::Acquire), 0);
        self.value.get_mut()
    }

    /// Takes a shared borrow. Panics if the cell is mutably borrowed.
    pub fn borrow(&self) -> SynRef<T> {
        // Optimistically register as a reader, then back out if a writer
        // currently holds the cell.
        let old = self.state.fetch_add(1, Ordering::AcqRel);
        if old & WRITE_BIT != 0 {
            self.state.fetch_sub(1, Ordering::Release);
            panic!("SynCell is mutably borrowed elsewhere!");
        }
        SynRef {
            state: &self.state,
            value: unsafe { &*self.value.get() },
        }
    }

    /// Takes an exclusive borrow. Panics if the cell is borrowed at all.
    pub fn borrow_mut(&self) -> SynRefMut<T> {
        // Optimistically set the write bit, then back out if any other
        // borrow (shared or exclusive) is active.
        let old = self.state.fetch_or(WRITE_BIT, Ordering::AcqRel);
        if old & WRITE_BIT != 0 {
            panic!("SynCell is mutably borrowed elsewhere!");
        } else if old != 0 {
            self.state.fetch_and(!WRITE_BIT, Ordering::Release);
            panic!("SynCell is immutably borrowed elsewhere!");
        }
        SynRefMut {
            state: &self.state,
            value: unsafe { &mut *self.value.get() },
        }
    }
}
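
// A minimal usage sketch, not part of the original crate: unlike `RefCell`,
// a `SynCell` can be shared across threads, so several threads may hold
// shared borrows at once. Uses only std; `thread::scope` keeps the borrows
// tied to the cell on the parent stack frame.
#[cfg(not(loom))]
#[test]
fn concurrent_reads() {
    let sc = SynCell::new(42u32);
    std::thread::scope(|s| {
        for _ in 0..4 {
            s.spawn(|| {
                // Any number of shared borrows may be active at the same time.
                assert_eq!(*sc.borrow(), 42);
            });
        }
    });
}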

#[test]
fn valid() {
    let sc = SynCell::new(0u8);
    {
        let mut bw = sc.borrow_mut();
        *bw += 1;
    }
    {
        let b1 = sc.borrow();
        let b2 = sc.borrow();
        assert_eq!(*b1 + *b2, 2);
    }
}

#[test]
#[should_panic]
fn bad_write_write() {
    let sc = SynCell::new(0u8);
    let _b1 = sc.borrow_mut();
    let _b2 = sc.borrow_mut();
}

#[test]
#[should_panic]
fn bad_read_write() {
    let sc = SynCell::new(0u8);
    let _b1 = sc.borrow();
    let _b2 = sc.borrow_mut();
}

#[test]
#[should_panic]
fn bad_write_read() {
    let sc = SynCell::new(0u8);
    let _b1 = sc.borrow_mut();
    let _b2 = sc.borrow();
}

#[test]
fn fight() {
    use mystd::{
        sync::{Arc, RwLock},
        thread,
    };
    const NUM_THREADS: usize = 3;
    const NUM_LOCKS: usize = if cfg!(miri) { 100 } else { 10000 };
    let value = Arc::new(RwLock::new(SynCell::new(0usize)));
    let sum = Arc::new(AtomicUsize::new(0));
    // Collect eagerly: `map` is lazy, so without the `collect` the loop
    // below would spawn and join each thread in turn, leaving nothing to
    // actually run concurrently.
    let join_handles = (0..NUM_THREADS)
        .map(|i| {
            let sum = Arc::clone(&sum);
            let value = Arc::clone(&value);
            thread::spawn(move || {
                for j in 0..NUM_LOCKS {
                    if (i + j) % NUM_THREADS == 0 {
                        let sc = value.write().unwrap();
                        *sc.borrow_mut() += 1;
                    } else {
                        let sc = value.read().unwrap();
                        let v = *sc.borrow();
                        sum.fetch_add(v, Ordering::Relaxed);
                    }
                }
            })
        })
        .collect::<Vec<_>>();
    for jh in join_handles {
        jh.join().unwrap();
    }
}
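
// A hedged sketch of a loom model test (the test name and value below are
// illustrative, not part of the crate). It assumes the module is built with
// RUSTFLAGS="--cfg loom" so that `mystd` resolves to `loom`; `loom::model`
// then re-executes the closure across the possible interleavings of the two
// threads, checking that a shared borrow on each side always coexists.
#[cfg(loom)]
#[test]
fn loom_concurrent_reads() {
    use mystd::{sync::Arc, thread};
    loom::model(|| {
        let sc = Arc::new(SynCell::new(7usize));
        let sc2 = Arc::clone(&sc);
        let jh = thread::spawn(move || {
            // Shared borrow on the spawned thread.
            assert_eq!(*sc2.borrow(), 7);
        });
        // Shared borrow on the main thread, possibly concurrent.
        assert_eq!(*sc.borrow(), 7);
        jh.join().unwrap();
    });
}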