1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
#![deny(unsafe_op_in_unsafe_fn)]

//! A simple atomic reference counter.

use std::{
    ops::{Deref, DerefMut},
    sync::atomic::{AtomicUsize, Ordering},
};

/// An atomic borrow counter that tracks shared and unique borrows in one word.
///
/// The most significant bit of the counter flags a unique borrow
/// ([`AtomicBorrow::UNIQUE_MASK`]); the remaining bits count shared borrows
/// ([`AtomicBorrow::SHARED_MASK`]). This type only does the bookkeeping, it does
/// not hold or protect any data by itself.
#[repr(transparent)]
#[derive(Debug, Default)]
pub struct AtomicBorrow {
    borrow: AtomicUsize,
}

impl AtomicBorrow {
    /// The mask for the shared borrow count (every bit except the top one).
    pub const SHARED_MASK: usize = usize::MAX >> 1;
    /// The mask for the unique borrow bit (the top bit).
    pub const UNIQUE_MASK: usize = !Self::SHARED_MASK;

    /// Number of busy-wait attempts the spin methods make before they start
    /// yielding the thread between attempts.
    const SPIN_COUNT: usize = 1 << 10;

    /// Creates a new, unborrowed `AtomicBorrow`.
    #[inline]
    pub const fn new() -> Self {
        Self {
            borrow: AtomicUsize::new(0),
        }
    }

    /// Returns the number of shared borrows.
    ///
    /// A failed [`borrow`](Self::borrow) attempt increments the counter before
    /// undoing it, so the value read here may briefly include such attempts.
    #[inline]
    pub fn shared_count(&self) -> usize {
        self.borrow.load(Ordering::Acquire) & Self::SHARED_MASK
    }

    /// Returns true if `self` is uniquely borrowed.
    #[inline]
    pub fn is_unique(&self) -> bool {
        // The unique bit is *set* while a unique borrow is held.
        self.borrow.load(Ordering::Acquire) & Self::UNIQUE_MASK != 0
    }

    /// Returns true if `self` is borrowed in any way.
    #[inline]
    pub fn is_borrowed(&self) -> bool {
        self.borrow.load(Ordering::Acquire) != 0
    }

    /// Tries to acquire a shared reference.
    ///
    /// Returns `true` if the reference was acquired, `false` if `self` is
    /// currently uniquely borrowed.
    ///
    /// # Panics
    /// * If the shared borrow count is already [`Self::SHARED_MASK`]. The counter
    ///   is restored to its previous value before panicking.
    #[inline]
    pub fn borrow(&self) -> bool {
        let prev = self.borrow.fetch_add(1, Ordering::Acquire);

        if prev & Self::SHARED_MASK == Self::SHARED_MASK {
            // The increment carried out of the shared bits into the unique bit;
            // take it back so the counter is not left corrupted by the panic.
            self.borrow.fetch_sub(1, Ordering::Release);
            panic!("borrow counter overflowed");
        }

        if prev & Self::UNIQUE_MASK != 0 {
            // we're already uniquely borrowed, so undo the increment and return false
            self.borrow.fetch_sub(1, Ordering::Release);
            false
        } else {
            true
        }
    }

    /// Tries to acquire a unique reference.
    ///
    /// Succeeds only when the counter is exactly zero, i.e. there is no shared or
    /// unique borrow (and no in-flight shared attempt).
    ///
    /// Returns `true` if the reference was acquired.
    #[inline]
    pub fn borrow_mut(&self) -> bool {
        self.borrow
            .compare_exchange(0, Self::UNIQUE_MASK, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
    }

    /// Releases a shared reference.
    ///
    /// # Panics
    /// * If `self` is not borrowed. Only with `debug_assertions` enabled.
    /// * If `self` is uniquely borrowed. Only with `debug_assertions` enabled.
    #[inline]
    pub fn release(&self) {
        let prev = self.borrow.fetch_sub(1, Ordering::Release);
        debug_assert_ne!(
            prev, 0,
            "borrow counter underflow, this means you released more times than you borrowed"
        );
        debug_assert_eq!(
            prev & Self::UNIQUE_MASK,
            0,
            "shared release of unique borrow"
        );
    }

    /// Releases a unique reference.
    ///
    /// Only the unique bit is cleared; the shared bits are left as they are, so a
    /// concurrent failed [`borrow`](Self::borrow) attempt can still undo its own
    /// increment afterwards.
    ///
    /// # Panics
    /// * If `self` is not uniquely borrowed. Only with `debug_assertions` enabled.
    #[inline]
    pub fn release_mut(&self) {
        let prev = self.borrow.fetch_and(!Self::UNIQUE_MASK, Ordering::Release);
        debug_assert_ne!(
            prev & Self::UNIQUE_MASK,
            0,
            "unique release of shared borrow"
        );
    }

    /// Spins until a shared reference can be acquired.
    #[inline]
    pub fn spin_borrow(&self) {
        Self::spin_until(|| self.borrow());
    }

    /// Spins until a unique reference can be acquired.
    #[inline]
    pub fn spin_borrow_mut(&self) {
        Self::spin_until(|| self.borrow_mut());
    }

    /// Calls `try_acquire` until it returns `true`.
    ///
    /// The first [`Self::SPIN_COUNT`] failed attempts are followed by a spin-loop
    /// hint; after that the thread is yielded between attempts.
    #[inline]
    fn spin_until(mut try_acquire: impl FnMut() -> bool) {
        for _ in 0..Self::SPIN_COUNT {
            if try_acquire() {
                return;
            }

            std::hint::spin_loop();
        }

        while !try_acquire() {
            std::thread::yield_now();
        }
    }
}

/// A guard over shared access to a `T` that releases one shared borrow on its
/// [`AtomicBorrow`] when dropped.
pub struct SharedGuard<'a, T> {
    data: *const T,
    borrow: &'a AtomicBorrow,
}

impl<'a, T> SharedGuard<'a, T> {
    /// Creates a new [`SharedGuard`] from an existing reference.
    ///
    /// This constructor does not call [`AtomicBorrow::borrow`], while dropping
    /// the guard still calls [`AtomicBorrow::release`].
    // NOTE(review): presumably the caller is expected to have acquired the shared
    // borrow on `borrow` beforehand; confirm against the call sites.
    #[inline]
    pub fn new(data: &'a T, borrow: &'a AtomicBorrow) -> Self {
        let data: *const T = data;
        Self { data, borrow }
    }

    /// Tries to borrow the data.
    ///
    /// Returns `None` when a shared borrow cannot be acquired from `borrow`.
    ///
    /// # Safety
    /// * Any borrows of `data` must be registered with `borrow`.
    /// * `data` must be a valid pointer for the entire lifetime of `self`.
    #[inline]
    pub unsafe fn try_new(data: *const T, borrow: &'a AtomicBorrow) -> Option<Self> {
        // The guard is built lazily: constructing it on the failure path would run
        // its `Drop` and release a borrow that was never acquired.
        borrow.borrow().then(|| Self { data, borrow })
    }

    /// Spins until the data can be borrowed.
    ///
    /// # Safety
    /// * Any borrows of `data` must be registered with `borrow`.
    /// * `data` must be a valid pointer for the entire lifetime of `self`.
    #[inline]
    pub unsafe fn spin(data: *const T, borrow: &'a AtomicBorrow) -> Self {
        borrow.spin_borrow();
        Self { data, borrow }
    }

    /// Returns the raw pointer to the guarded data.
    #[inline]
    pub fn ptr(&self) -> *const T {
        self.data
    }

    /// Consumes the guard without running its destructor and returns the raw
    /// pointer to the data.
    ///
    /// The shared borrow held by the guard is not released.
    #[inline]
    pub fn forget(self) -> *const T {
        let guard = std::mem::ManuallyDrop::new(self);
        guard.data
    }
}

impl<'a, T> Deref for SharedGuard<'a, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: `data` either came from a `&'a T` (`new`) or the caller of
        // `try_new`/`spin` promised it stays valid for the lifetime of the guard.
        unsafe { &*self.data }
    }
}

impl<'a, T> Drop for SharedGuard<'a, T> {
    /// Gives the shared borrow back to the counter.
    #[inline]
    fn drop(&mut self) {
        self.borrow.release();
    }
}

/// A guard over unique access to a `T` that releases the unique borrow on its
/// [`AtomicBorrow`] when dropped.
pub struct UniqueGuard<'a, T> {
    data: *mut T,
    borrow: &'a AtomicBorrow,
}

impl<'a, T> UniqueGuard<'a, T> {
    /// Creates a new [`UniqueGuard`] from an existing mutable reference.
    ///
    /// This constructor does not call [`AtomicBorrow::borrow_mut`], while dropping
    /// the guard still calls [`AtomicBorrow::release_mut`].
    // NOTE(review): presumably the caller is expected to have acquired the unique
    // borrow on `borrow` beforehand; confirm against the call sites.
    #[inline]
    pub fn new(data: &'a mut T, borrow: &'a AtomicBorrow) -> Self {
        let data: *mut T = data;
        Self { data, borrow }
    }

    /// Tries to borrow the data.
    ///
    /// Returns `None` when a unique borrow cannot be acquired from `borrow`.
    ///
    /// # Safety
    /// * Any borrows of `data` must be registered with `borrow`.
    /// * `data` must be a valid pointer for the entire lifetime of `self`.
    #[inline]
    pub unsafe fn try_new(data: *mut T, borrow: &'a AtomicBorrow) -> Option<Self> {
        // The guard is built lazily: constructing it on the failure path would run
        // its `Drop` and release a borrow that was never acquired.
        borrow.borrow_mut().then(|| Self { data, borrow })
    }

    /// Spins until the data can be borrowed.
    ///
    /// # Safety
    /// * Any borrows of `data` must be registered with `borrow`.
    /// * `data` must be a valid pointer for the entire lifetime of `self`.
    #[inline]
    pub unsafe fn spin(data: *mut T, borrow: &'a AtomicBorrow) -> Self {
        borrow.spin_borrow_mut();
        Self { data, borrow }
    }

    /// Returns the raw pointer to the guarded data.
    #[inline]
    pub fn ptr(&self) -> *mut T {
        self.data
    }

    /// Consumes the guard without running its destructor and returns the raw
    /// pointer to the data.
    ///
    /// The unique borrow held by the guard is not released.
    #[inline]
    pub fn forget(self) -> *mut T {
        let guard = std::mem::ManuallyDrop::new(self);
        guard.data
    }
}

impl<'a, T> Deref for UniqueGuard<'a, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: `data` either came from a `&'a mut T` (`new`) or the caller of
        // `try_new`/`spin` promised it stays valid for the lifetime of the guard.
        unsafe { &*self.data }
    }
}

impl<'a, T> DerefMut for UniqueGuard<'a, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: validity as in `deref`. Exclusivity relies on `new` receiving a
        // `&'a mut T`, or on the `try_new`/`spin` contract that every borrow of
        // `data` is registered with `borrow`, which granted the unique borrow.
        unsafe { &mut *self.data }
    }
}

impl<'a, T> Drop for UniqueGuard<'a, T> {
    /// Gives the unique borrow back to the counter.
    #[inline]
    fn drop(&mut self) {
        self.borrow.release_mut();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a single `AtomicBorrow` through shared and unique borrows and checks
    /// that each kind blocks the other while it is held.
    #[test]
    fn atomic_borrow() {
        let state = AtomicBorrow::new();

        // Several shared borrows may coexist...
        assert!(state.borrow());
        assert!(state.borrow());

        // ...but they block a unique borrow.
        assert!(!state.borrow_mut());

        // Give both shared borrows back.
        state.release();
        state.release();

        // With nothing held, the unique borrow succeeds...
        assert!(state.borrow_mut());

        // ...and in turn blocks shared borrows.
        assert!(!state.borrow());

        state.release_mut();
    }
}