syncell/
lib.rs

//! Synchronized Cell
//!
//! Main principles:
//!   1. If you change the state and the access is allowed, the change is
//!      reverted on `drop()`.
//!   2. If you detect a conflict, still undo your own change first, and only
//!      then `panic!()`.
6
7#[cfg(loom)]
8use loom as mystd;
9#[cfg(not(loom))]
10use std as mystd;
11
12use mystd::{
13    cell::UnsafeCell,
14    sync::atomic::{AtomicUsize, Ordering},
15};
16use std::{mem, ops};
17
/// Width of the `usize` state word, in bits.
const STATE_BITS: usize = 8 * mem::size_of::<usize>();

/// Flag kept in the most significant bit of the state word; it is set while a
/// mutable borrow is active. The lower bits count the active shared borrows.
const WRITE_BIT: usize = 1usize << (STATE_BITS - 1);
19
/// A shared reference to `SynCell` data.
///
/// Produced by `SynCell::borrow`. While it is alive the cell's reader count
/// stays incremented; dropping it decrements the count again.
#[must_use = "if unused the shared borrow ends immediately"]
pub struct SynRef<'a, T> {
    state: &'a AtomicUsize,
    value: &'a T,
}

impl<T> Drop for SynRef<'_, T> {
    fn drop(&mut self) {
        // Give back our reader slot. `Release` pairs with the acquiring
        // `fetch_or` in `SynCell::borrow_mut`, so our reads happen-before
        // any later mutable borrow.
        self.state.fetch_sub(1, Ordering::Release);
    }
}

impl<T> ops::Deref for SynRef<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        self.value
    }
}
38
39/// A mutable reference to `SynCell` data.
40pub struct SynRefMut<'a, T> {
41    state: &'a AtomicUsize,
42    value: &'a mut T,
43}
44
45impl<T> Drop for SynRefMut<'_, T> {
46    fn drop(&mut self) {
47        self.state.fetch_and(!WRITE_BIT, Ordering::Release);
48    }
49}
50
51impl<T> ops::Deref for SynRefMut<'_, T> {
52    type Target = T;
53    fn deref(&self) -> &T {
54        self.value
55    }
56}
57
58impl<T> ops::DerefMut for SynRefMut<'_, T> {
59    fn deref_mut(&mut self) -> &mut T {
60        self.value
61    }
62}
63
/// A Sync cell. Stores a value of type `T` and allows
/// to access it behind a reference. `SynCell` follows Rust borrowing
/// rules but checks them at run time as opposed to compile time.
///
/// The `state` word counts active shared borrows in its low bits and uses its
/// highest bit as the "mutably borrowed" flag.
pub struct SynCell<T> {
    state: AtomicUsize,
    value: UnsafeCell<T>,
}

// SAFETY: a shared `&SynCell<T>` lets several threads obtain `&T` at the same
// time, which requires `T: Sync`. It also lets one thread at a time obtain
// `&mut T` to a value that may have been created on another thread (and, e.g.,
// `mem::replace` it), which requires `T: Send`. These are the same bounds
// `std::sync::RwLock<T>` uses. Without them, e.g. `SynCell<Cell<_>>` would be
// `Sync` and allow data races.
unsafe impl<T: Send + Sync> Sync for SynCell<T> {}
73
74impl<T> SynCell<T> {
75    /// Create a new cell.
76    pub fn new(value: T) -> Self {
77        Self {
78            state: AtomicUsize::new(0),
79            value: UnsafeCell::new(value),
80        }
81    }
82
83    /// Convert into the value.
84    pub fn into_inner(self) -> T {
85        debug_assert_eq!(self.state.load(Ordering::Acquire), 0);
86        self.value.into_inner()
87    }
88
89    /// Get a direct mutable reference to the data.
90    pub fn get_mut(&mut self) -> &mut T {
91        debug_assert_eq!(self.state.load(Ordering::Acquire), 0);
92        self.value.get_mut()
93    }
94
95    /// Borrow immutably (can be shared).
96    ///
97    /// Panics if the value is already borrowed mutably.
98    pub fn borrow(&self) -> SynRef<T> {
99        let old = self.state.fetch_add(1, Ordering::AcqRel);
100        if old & WRITE_BIT != 0 {
101            self.state.fetch_sub(1, Ordering::Release);
102            panic!("SynCell is mutably borrowed elsewhere!");
103        }
104        SynRef {
105            state: &self.state,
106            value: unsafe { &*self.value.get() },
107        }
108    }
109
110    /// Borrow mutably (exclusive).
111    ///
112    /// Panics if the value is already borrowed in any way.
113    pub fn borrow_mut(&self) -> SynRefMut<T> {
114        let old = self.state.fetch_or(WRITE_BIT, Ordering::AcqRel);
115        if old & WRITE_BIT != 0 {
116            panic!("SynCell is mutably borrowed elsewhere!");
117        } else if old != 0 {
118            self.state.fetch_and(!WRITE_BIT, Ordering::Release);
119            panic!("SynCell is immutably borrowed elsewhere!");
120        }
121        SynRefMut {
122            state: &self.state,
123            value: unsafe { &mut *self.value.get() },
124        }
125    }
126}
127
#[test]
fn valid() {
    let cell = SynCell::new(0u8);
    // Exclusive borrow; the temporary guard is released at the end of the statement.
    *cell.borrow_mut() += 1;
    // Two overlapping shared borrows are permitted at the same time.
    let (first, second) = (cell.borrow(), cell.borrow());
    assert_eq!(*first + *second, 2);
}
141
/// A second mutable borrow while the first is alive must fail with the
/// "mutably borrowed" message (not just any panic).
#[test]
#[should_panic(expected = "SynCell is mutably borrowed elsewhere")]
fn bad_write_write() {
    let sc = SynCell::new(0u8);
    let _b1 = sc.borrow_mut();
    let _b2 = sc.borrow_mut();
}
149
/// A mutable borrow while a shared borrow is alive must fail with the
/// "immutably borrowed" message (not just any panic).
#[test]
#[should_panic(expected = "SynCell is immutably borrowed elsewhere")]
fn bad_read_write() {
    let sc = SynCell::new(0u8);
    let _b1 = sc.borrow();
    let _b2 = sc.borrow_mut();
}
157
/// A shared borrow while a mutable borrow is alive must fail with the
/// "mutably borrowed" message (not just any panic).
#[test]
#[should_panic(expected = "SynCell is mutably borrowed elsewhere")]
fn bad_write_read() {
    let sc = SynCell::new(0u8);
    let _b1 = sc.borrow_mut();
    let _b2 = sc.borrow();
}
165
#[test]
fn fight() {
    use mystd::{
        sync::{Arc, RwLock},
        thread,
    };
    const NUM_THREADS: usize = 3;
    const NUM_LOCKS: usize = if cfg!(miri) { 100 } else { 10000 };
    // Since `SynCell` is inside `RwLock`, it's guaranteed
    // that all the access is rightful, and no panic is expected.
    let value = Arc::new(RwLock::new(SynCell::new(0usize)));
    let sum = Arc::new(AtomicUsize::new(0));
    // Collect the handles eagerly: `map` is lazy, so without `collect` each
    // thread would be spawned only when the join loop reached it, and the
    // threads would run one after another instead of contending for the cell.
    let join_handles: Vec<_> = (0..NUM_THREADS)
        .map(|i| {
            let sum = Arc::clone(&sum);
            let value = Arc::clone(&value);
            thread::spawn(move || {
                for j in 0..NUM_LOCKS {
                    if (i + j) % NUM_THREADS == 0 {
                        let sc = value.write().unwrap();
                        *sc.borrow_mut() += 1;
                    } else {
                        let sc = value.read().unwrap();
                        let v = *sc.borrow();
                        sum.fetch_add(v, Ordering::Relaxed);
                    }
                }
            })
        })
        .collect();
    for jh in join_handles {
        jh.join().unwrap();
    }
    // For every `j` exactly one thread `i` satisfies `(i + j) % NUM_THREADS == 0`,
    // so exactly `NUM_LOCKS` increments must have happened in total.
    assert_eq!(*value.read().unwrap().borrow(), NUM_LOCKS);
}