// provenant/lib.rs

1use rand::Rng;
2use std::ops::Deref;
3use std::ptr;
4use std::sync::atomic::{compiler_fence, AtomicUsize, Ordering};
5
/// An atomically reference counted shared pointer
///
/// See the documentation for [`Arc`](std::sync::Arc) in the standard library.
/// This one has different weak pointers.
pub struct Arc<T: ?Sized> {
    // Raw pointer to the shared heap allocation; the strong count inside
    // `Inner` is what keeps the allocation alive.
    ptr: *const Inner<T>,
}
13
/// A weak pointer to an atomically reference counted shared pointer
///
/// Can be upgraded to an [`Arc`], and will usually do the right thing.
/// Does not prevent the pointed-to memory from being dropped or deallocated.
#[derive(Copy, Clone)]
pub struct Weak<T: ?Sized> {
    // Snapshot of `Inner::provenance` (with the low/lock bit cleared) taken
    // when this weak pointer was created; `upgrade` succeeds only while the
    // allocation still holds this exact value.
    provenance: usize,
    // May dangle once the last strong reference is dropped. NOTE(review):
    // `upgrade` still dereferences it, relying on the provenance check to
    // catch reuse — this is probabilistic, not sound; confirm that is the
    // intended trade-off.
    ptr: *const Inner<T>,
}
23
// Heap payload shared by every `Arc` and `Weak` that points at one value.
struct Inner<T: ?Sized> {
    // the low bit is used for locking, the rest are random provenance id;
    // zeroed when the value is dropped so stale `Weak` snapshots stop matching
    provenance: AtomicUsize,

    // reference count of Arcs. Weak refs are uncounted
    ref_count: AtomicUsize,

    // the shared value itself; must be the last field because `T: ?Sized`
    data: T,
}
33
impl<T: ?Sized> Drop for Inner<T> {
    // Zero the provenance word as the allocation dies, so that stale
    // `Weak::provenance` snapshots can no longer match it.
    fn drop(&mut self) {
        // using a volatile write followed by a fence should actually zero the memory
        // and not get optimized out
        // NOTE(review): `compiler_fence` only restricts compiler reordering,
        // not other CPUs, and a volatile write is not a synchronized atomic
        // store — confirm this is strong enough for the cross-thread guarantee
        // this design relies on.
        unsafe {
            ptr::write_volatile(&mut self.provenance, AtomicUsize::default());
        }
        compiler_fence(Ordering::SeqCst);
    }
}
44
45impl<T: ?Sized> Inner<T> {
46    fn weak(&self) -> Weak<T> {
47        let provenance = self.provenance.load(Ordering::Relaxed);
48        let provenance = provenance ^ (provenance & 1); //clear low bit
49        Weak {
50            provenance: provenance,
51            ptr: self as *const Inner<T>,
52        }
53    }
54
55    fn lock(&self, exp: usize) -> bool {
56        loop {
57            match self
58                .provenance
59                .compare_exchange(exp, exp | 1, Ordering::SeqCst, Ordering::SeqCst)
60            {
61                Ok(_) => return true,
62                Err(v) if v == exp | 1 => continue,
63                Err(_) => return false,
64            }
65        }
66    }
67}
68
impl<T: ?Sized> Weak<T> {
    /// Attempts to get a strong reference to the pointed-to memory. Will probably fail and return None
    /// if there are no strong pointers left.
    pub fn upgrade(&self) -> Option<Arc<T>> {
        // The provenance this weak pointer was created with (lock bit clear).
        let exp = self.provenance;

        // NOTE(review): this dereferences the allocation *before* checking
        // whether it is still alive; if it was already freed this is
        // use-after-free UB that the provenance check below only catches
        // probabilistically.
        let inner = unsafe { &(*self.ptr) };

        // Locking only succeeds while the stored provenance still equals our
        // snapshot; after drop the word is zeroed, so this fails and we bail.
        if !inner.lock(exp) {
            return None;
        }

        // increment ref count
        inner.ref_count.fetch_add(1, Ordering::SeqCst);

        // release the lock (store the unlocked provenance value back)
        inner.provenance.store(exp, Ordering::SeqCst);

        Some(Arc { ptr: self.ptr })
    }
}
90
impl<T: ?Sized> Drop for Arc<T> {
    // Drops one strong reference; the thread that observes the count reach
    // zero (and wins the provenance lock) zeroes the word and frees the box.
    fn drop(&mut self) {
        {
            let inner = unsafe { &(*self.ptr) };

            // we need to load provenance before decrementing ref count.
            // otherwise, another thread could deallocate before the load happens
            let exp = inner.provenance.load(Ordering::SeqCst);
            let exp = exp ^ (exp & 1); // clear the lock bit: we want the unlocked value

            // Not the last strong reference — nothing more to do.
            if inner.ref_count.fetch_sub(1, Ordering::SeqCst) > 1 {
                return;
            }

            // if the lock fails, another thread must have dropped Inner already
            // that can happen if this gets interrupted while a weak pointer
            // upgrades and then drops (hitting 0 again)
            if !inner.lock(exp) {
                return;
            }

            // if the ref count isn't 0, a weak pointer managed to upgrade.
            // it can deal with deallocating when it hits 0 again.
            if inner.ref_count.load(Ordering::SeqCst) != 0 {
                inner.provenance.store(exp, Ordering::SeqCst);
                return;
            }

            // setting provenance to 0 isn't strictly necessary here, since Inner::drop does it
            inner.provenance.store(0, Ordering::SeqCst);
        }

        // Rebuild the Box so its destructor runs `Inner::drop` and frees the
        // allocation.
        unsafe {
            Box::from_raw(self.ptr as *mut Inner<T>);
        }
    }
}
128
129impl<T> Arc<T> {
130    /// Create a new shared reference
131    pub fn new(val: T) -> Self {
132        let mut rng = rand::thread_rng();
133        let provenance: usize = rng.gen();
134        let provenance = provenance ^ (provenance & 1);
135        let inner = Box::new(Inner {
136            provenance: AtomicUsize::new(provenance),
137            ref_count: AtomicUsize::new(1),
138            data: val,
139        });
140
141        let inner = Box::into_raw(inner) as *const Inner<T>;
142        Arc { ptr: inner }
143    }
144}
145
146impl<T: ?Sized> Arc<T> {
147    /// Gets a weak reference to the same memory
148    pub fn downgrade(this: &Self) -> Weak<T> {
149        let inner = unsafe { &(*this.ptr) };
150
151        inner.weak()
152    }
153}
154
155impl<T: ?Sized> Deref for Arc<T> {
156    type Target = T;
157    fn deref(&self) -> &Self::Target {
158        let inner = unsafe { &(*self.ptr) };
159
160        &inner.data
161    }
162}
163
164impl<T: ?Sized> Clone for Arc<T> {
165    fn clone(&self) -> Self {
166        let inner = unsafe { &(*self.ptr) };
167
168        inner.ref_count.fetch_add(1, Ordering::SeqCst);
169
170        Arc { ptr: self.ptr }
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// A weak pointer must stop upgrading once the last strong ref is gone.
    #[test]
    fn use_after_free() {
        let strong = Arc::new(50);
        let weak = Arc::downgrade(&strong);

        assert_eq!(50, *weak.upgrade().unwrap());

        drop(strong);

        assert!(weak.upgrade().is_none());
    }

    /// Cloning keeps the allocation alive until every strong ref is dropped.
    #[test]
    fn use_after_clone() {
        let original = Arc::new(55);
        let weak = Arc::downgrade(&original);
        let copy = original.clone();

        assert_eq!(55, *weak.upgrade().unwrap());

        drop(original);
        assert_eq!(55, *weak.upgrade().unwrap());

        drop(copy);
        assert!(weak.upgrade().is_none());
    }

    /// Upgrading before the last drop "revives" the value and keeps it alive.
    #[test]
    fn revive() {
        let original = Arc::new(21);
        let weak = Arc::downgrade(&original);
        let revived = weak.upgrade().unwrap();

        drop(original);
        assert_eq!(21, *weak.upgrade().unwrap());

        drop(revived);
        assert!(weak.upgrade().is_none());
    }
}