read_copy_update/
lib.rs

1#![doc = include_str!("../readme.md")]
2
3use std::{collections::VecDeque, ops::Deref, ptr};
4
5#[cfg(not(loom))]
6use std::{
7    cell::Cell,
8    sync::{
9        atomic::{AtomicPtr, AtomicU64, Ordering},
10        Arc, Mutex, MutexGuard,
11    },
12};
13
14#[cfg(loom)]
15use loom::{
16    cell::Cell,
17    sync::{
18        atomic::{AtomicPtr, AtomicU64, Ordering},
19        Arc, Mutex, MutexGuard,
20    },
21};
22
23/// Represents a read-copy-update for a specific value.
24///
25/// This is the write side, new read handles can be constructed by calling [`Rcu::reader`].
26#[derive(Debug)]
27pub struct Rcu<T> {
28    epoch: u64,
29    shared: Arc<Shared<T>>,
30}
31
32/// The reader handle for a value stored in an [`Rcu`].
33///
34/// Specific values can be read using [`Reader::read`]. Readers are `!Sync` and expected to be used
35/// only on a single thread.
36#[derive(Debug)]
37pub struct Reader<T: 'static> {
38    cache: Cell<&'static StampedValue<T>>,
39    refs: Cell<usize>,
40    state: ReaderState,
41    shared: Arc<Shared<T>>,
42}
43
44#[derive(Debug)]
45struct Shared<T> {
46    ptr: Pointer<T>,
47    reclaim: Mutex<VecDeque<Box<StampedValue<T>>>>,
48    readers: Mutex<Vec<ReaderState>>,
49}
50
/// A published value tagged with the epoch in which it was written.
#[derive(Debug)]
struct StampedValue<T> {
    /// The user value.
    value: T,
    /// Epoch assigned by the writer when this value was published.
    epoch: u64,
}
56
/// The epoch a reader last published, shared between that reader and the writer's registry.
#[derive(Debug, Clone)]
struct ReaderState(
    /// Shared epoch cell; cloning the state shares the same cell.
    Arc<AtomicU64>,
);
59
60#[derive(Debug)]
61struct Pointer<T>(AtomicPtr<StampedValue<T>>);
62
63#[derive(Debug)]
64pub struct Guard<'a, T: 'static> {
65    cache: &'a StampedValue<T>,
66    reader: &'a Reader<T>,
67}
68
69impl<T: 'static> Default for Rcu<T>
70where
71    T: Default,
72{
73    fn default() -> Self {
74        Self::new(T::default())
75    }
76}
77
78impl<T: 'static> Rcu<T> {
79    /// Constructs a new RCU with an initial value.
80    pub fn new(value: T) -> Self {
81        Self {
82            epoch: 1,
83            shared: Arc::new(Shared {
84                ptr: Pointer::new(StampedValue { value, epoch: 1 }),
85                reclaim: Mutex::new(VecDeque::new()),
86                readers: Mutex::new(Vec::new()),
87            }),
88        }
89    }
90
91    /// Registers a new [`Reader`] allowing values to be read.
92    pub fn reader(&mut self) -> Reader<T> {
93        Reader::new(self.shared.clone())
94    }
95
96    /// Write a new value making it available to all readers.
97    ///
98    /// Previously written values will be reclaimed when they are no longer accesed.
99    pub fn write(&mut self, value: T) {
100        // Records a new epoch associated with this value, not allowed to wrap around.
101        self.epoch += 1;
102
103        // Publish the value, we will attempt to reclaim the previous value.
104        let next = StampedValue {
105            epoch: self.epoch,
106            value,
107        };
108        let prev = self.shared.ptr.swap(next);
109        self.reclaim_queue().push_back(prev);
110
111        // Immediately attempt to reclaim.
112        self.try_reclaim();
113    }
114
115    /// Try and reclaim any values which are no longer in-use.
116    ///
117    /// Returns the number of values still waiting to be reclaimed.
118    pub fn try_reclaim(&mut self) -> usize {
119        let mut readers = self.shared.readers.lock().unwrap();
120
121        // Trim readers which have been removed.
122        readers.retain(|reader| reader.get() > ReaderState::NOT_IN_USE);
123
124        let mut reclaim = self.reclaim_queue();
125
126        // If there are no readers, we can reclaim everything.
127        if readers.is_empty() {
128            reclaim.clear();
129        }
130
131        // Check the minimum epoch across all active threads, removing records
132        // that are below the minimum epoch.
133        let min_epoch = readers
134            .iter()
135            .map(|r| r.get())
136            .min()
137            .unwrap_or(ReaderState::NOT_IN_USE);
138        while let Some(candidate) = reclaim.pop_front() {
139            if min_epoch > candidate.epoch {
140                drop(candidate);
141            } else {
142                reclaim.push_front(candidate);
143                // We short circuit, no point checking others with a higher epoch.
144                return reclaim.len();
145            }
146        }
147        0
148    }
149
150    fn reclaim_queue(&self) -> MutexGuard<'_, VecDeque<Box<StampedValue<T>>>> {
151        // The reclaimer must be shared so we can drop any remaining values when the 'Arc<Shared>'
152        // drops, but access to it should only be from this function. As a result we protect it with
153        // a mutex but only rely on try_lock().
154        self.shared
155            .reclaim
156            .try_lock()
157            .expect("invalid shared reclaimer access")
158    }
159}
160
161impl<T: 'static> Reader<T> {
162    fn new(shared: Arc<Shared<T>>) -> Self {
163        let value = shared.ptr.load();
164        let mut readers = shared.readers.lock().unwrap();
165
166        let state = ReaderState::new(value.epoch);
167        readers.push(state.clone());
168
169        Reader {
170            shared: shared.clone(),
171            refs: Cell::new(0),
172            cache: Cell::new(value),
173            state,
174        }
175    }
176
177    /// Reads the latest value guarded to ensure that the pointer will not be reclaimed while the
178    /// current reader has access.
179    pub fn read(&self) -> Guard<'_, T> {
180        // The read method provides a guard that allows deref access to one of the values written
181        // by the writer previously. The invariant this method maintains, using ref-counts, is that
182        // the epoch stamped on the current thread is always less than or equal to the epoch of the
183        // last used value. As soon as the reclaimer sees an epoch for a specific thread, it can be
184        // sure that no references with epochs 'below' the available epoch exist on that thread.
185
186        let cache = if self.refs.get() == 0 {
187            let value = self.shared.ptr.load();
188
189            // Update the epoch to note that we are currently using this value. This uses release
190            // ordering to ensure that loads when reclaiming will be ordered after this operation.
191            self.state.set(value.epoch);
192
193            // Cache the pointer in the current reader.
194            self.cache.set(value);
195            value
196        } else {
197            self.cache.get()
198        };
199
200        self.refs.set(self.refs.get() + 1);
201
202        Guard {
203            reader: self,
204            cache,
205        }
206    }
207}
208
209impl<'a, T> Deref for Guard<'a, T> {
210    type Target = T;
211
212    fn deref(&self) -> &Self::Target {
213        &self.cache.value
214    }
215}
216
217impl<'a, T> Drop for Guard<'a, T> {
218    fn drop(&mut self) {
219        self.reader.refs.replace(self.reader.refs.get() - 1);
220    }
221}
222
223impl<T> Drop for Reader<T> {
224    fn drop(&mut self) {
225        self.state.mark_dropped();
226    }
227}
228
229impl<T> Pointer<T> {
230    fn new(value: StampedValue<T>) -> Self {
231        Self(AtomicPtr::new(Box::leak(Box::new(value))))
232    }
233
234    fn swap(&self, value: StampedValue<T>) -> Box<StampedValue<T>> {
235        let ptr = Box::leak(Box::new(value));
236        let prev = self.0.swap(ptr, Ordering::AcqRel);
237        unsafe { Box::from_raw(prev) }
238    }
239
240    fn load(&self) -> &'static StampedValue<T> {
241        unsafe { &*self.0.load(Ordering::Relaxed) }
242    }
243}
244
245impl<T> Drop for Pointer<T> {
246    fn drop(&mut self) {
247        let prev = self.0.swap(ptr::null_mut(), Ordering::AcqRel);
248        let _ = unsafe { Box::from_raw(prev) };
249    }
250}
251
252impl ReaderState {
253    const NOT_IN_USE: u64 = 0;
254
255    fn new(epoch: u64) -> Self {
256        Self(Arc::new(AtomicU64::new(epoch)))
257    }
258
259    fn mark_dropped(&self) {
260        self.set(Self::NOT_IN_USE)
261    }
262
263    fn set(&self, epoch: u64) {
264        self.0.store(epoch, Ordering::Release)
265    }
266
267    fn get(&self) -> u64 {
268        self.0.load(Ordering::Acquire)
269    }
270}
271
/// Provides thread-local storage to read [`Rcu`] values.
///
/// The first time a thread calls [`ThreadLocal::get_or_init`], a new [`Reader`] is created and
/// stored in a slot for that thread. Later reads on the same thread reuse that reader, so access
/// stays cheap.
#[cfg(feature = "thread-local")]
pub struct ThreadLocal<T: Send + Sync + 'static> {
    /// State shared with the writer, used to register new per-thread readers.
    shared: Arc<Shared<T>>,
    /// One lazily created reader per thread.
    thread_local: thread_local::ThreadLocal<Reader<T>>,
}
281
#[cfg(feature = "thread-local")]
impl<T: Send + Sync + 'static> ThreadLocal<T> {
    /// Creates per-thread read access to the value managed by `rcu`.
    ///
    /// No reader is registered until a thread first calls [`ThreadLocal::get_or_init`].
    pub fn new(rcu: &Rcu<T>) -> Self {
        Self {
            shared: rcu.shared.clone(),
            thread_local: thread_local::ThreadLocal::new(),
        }
    }

    /// Reads the value through the current thread's reader, if one exists.
    ///
    /// Returns `None` if the current thread has not yet called [`ThreadLocal::get_or_init`].
    pub fn get(&self) -> Option<Guard<'_, T>> {
        self.thread_local.get().map(|reader| reader.read())
    }

    /// Reads the value through the current thread's reader, creating and registering that reader
    /// first if it does not exist yet.
    pub fn get_or_init(&self) -> Guard<'_, T> {
        self.thread_local
            .get_or(|| Reader::new(self.shared.clone()))
            .read()
    }
}
303
#[cfg(test)]
#[cfg(loom)]
mod loom_tests {
    use super::*;
    use loom::thread;

    /// A nested read taken after a write still observes the value held by the outer guard.
    #[test]
    fn nested() {
        loom::model(|| {
            let mut writer = Rcu::new(10);

            let reader = writer.reader();

            {
                let outer = reader.read();
                assert_eq!(10, *outer);

                writer.write(20);
                {
                    let inner = reader.read();
                    assert_eq!(10, *inner);
                }
            }
        });
    }

    /// Nested reads on a spawned thread interleaved with writes on the main thread.
    #[test]
    fn thread_nested() {
        loom::model(|| {
            let limit = 2;
            let mut writer = Rcu::new(0);
            let reader = writer.reader();
            let handle = thread::spawn(move || {
                let outer = reader.read();
                assert!(*outer < limit);

                {
                    let inner = reader.read();
                    assert!(*inner < limit);
                }
            });
            for next in 0..limit {
                writer.write(next);
                thread::yield_now();
            }
            handle.join().unwrap();
        });
    }

    /// Repeated, non-nested reads on a spawned thread interleaved with writes.
    #[test]
    fn thread() {
        loom::model(|| {
            let limit = 2;
            let mut writer = Rcu::new(0);
            let reader = writer.reader();
            let handle = thread::spawn(move || {
                for _ in 0..limit {
                    let current = reader.read();
                    assert!(*current < limit);
                    thread::yield_now();
                }
            });
            for next in 0..limit {
                writer.write(next);
                thread::yield_now();
            }
            handle.join().unwrap();
        });
    }

    /// Same as `thread`, but the reader thread is never joined.
    #[test]
    fn thread_detached() {
        loom::model(|| {
            let limit = 2;
            let mut writer = Rcu::new(0);
            let reader = writer.reader();
            thread::spawn(move || {
                for _ in 0..limit {
                    let current = reader.read();
                    assert!(*current < limit);
                    thread::yield_now();
                }
            });
            for next in 0..limit {
                writer.write(next);
                thread::yield_now();
            }
        });
    }
}
395
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use std::{
        sync::{atomic::AtomicUsize, Barrier},
        thread,
    };

    use super::*;

    thread_local! {
        static REFS: AtomicUsize = AtomicUsize::new(0);
    }

    /// Asserts that the current thread's `REFS` counter is zero when created and again when
    /// dropped.
    struct RefsCheck;

    impl RefsCheck {
        fn new() -> Self {
            REFS.with(|refs| {
                assert_eq!(refs.load(Ordering::SeqCst), 0);
            });
            Self
        }
    }

    impl Drop for RefsCheck {
        fn drop(&mut self) {
            REFS.with(|refs| {
                assert_eq!(refs.load(Ordering::SeqCst), 0);
            });
        }
    }

    /// Increments `REFS` on the thread that creates it and decrements it on the thread that
    /// drops it.
    #[derive(Debug)]
    struct RecordDrop(u32);

    impl RecordDrop {
        fn new(v: u32) -> Self {
            REFS.with(|refs| {
                refs.fetch_add(1, Ordering::SeqCst);
            });
            Self(v)
        }
    }

    impl Drop for RecordDrop {
        fn drop(&mut self) {
            REFS.with(|refs| {
                refs.fetch_sub(1, Ordering::SeqCst);
            });
        }
    }

    #[cfg(feature = "thread-local")]
    #[test]
    fn thread_local() {
        let mut rcu = Rcu::new(10);
        let tls = ThreadLocal::new(&rcu);

        thread::scope(|s| {
            s.spawn(|| {
                let _val = tls.get_or_init();
                assert!(tls.get().is_some());
            });
            s.spawn(|| {
                let _val = tls.get_or_init();
                assert!(tls.get().is_some());
            });
        });

        rcu.write(1);
    }

    #[test]
    fn send_check() {
        let mut rcu = Rcu::new(10);
        let rdr = rcu.reader();

        // Join the thread so a failed assertion inside it actually fails this test.
        thread::spawn(move || {
            assert_eq!(10, *rdr.read());
        })
        .join()
        .unwrap();
    }

    #[test]
    fn single_value() {
        let _refs = RefsCheck::new();

        let mut rcu = Rcu::new(RecordDrop::new(10));
        let rdr = rcu.reader();
        assert_eq!(10, rdr.read().0);
    }

    #[test]
    fn old_value() {
        let _refs = RefsCheck::new();

        let mut rcu = Rcu::new(RecordDrop::new(10));
        let rdr1 = rcu.reader();
        assert_eq!(10, rdr1.read().0);

        let rdr2 = rcu.reader();
        assert_eq!(10, rdr2.read().0);

        for i in 11..=20 {
            rcu.write(RecordDrop::new(i));
            assert_eq!(i, rdr1.read().0);
        }

        // `rdr2` never reads again, so it stays at epoch 1 and none of the superseded values can
        // be reclaimed; they are only dropped along with the shared state at the end of this test.
    }

    #[test]
    fn remove_readers() {
        let _refs = RefsCheck::new();

        let mut rcu = Rcu::new(RecordDrop::new(10));

        let rdr1 = rcu.reader();
        let rdr2 = rcu.reader();

        for i in 11..=20 {
            rcu.write(RecordDrop::new(i));
        }

        drop(rdr1);
        drop(rdr2);

        rcu.write(RecordDrop::new(30));
    }

    #[test]
    fn nested() {
        let _refs = RefsCheck::new();

        let mut rcu = Rcu::new(RecordDrop::new(10));

        let rdr = rcu.reader();

        {
            let handle = rdr.read();
            assert_eq!(10, handle.0);

            rcu.write(RecordDrop::new(20));

            {
                let handle = rdr.read();
                assert_eq!(10, handle.0);
            }

            let handle2 = rdr.read();
            assert_eq!(10, handle.0);
            assert_eq!(10, handle2.0);
        }

        assert_eq!(20, rdr.read().0);
    }

    #[test]
    fn nested_multi_threaded() {
        const THREADS: usize = 10;

        let _refs = RefsCheck::new();

        // Rendezvous shared by every reader thread plus this one, used twice. The first wait
        // ensures every reader has observed the initial value before the write. The second wait
        // releases the readers once the new value is published. This replaces a timing-based
        // sleep, which let a slow thread miss the initial value and fail spuriously.
        let barrier = Arc::new(Barrier::new(THREADS + 1));

        let mut rcu = Rcu::new(RecordDrop::new(10));
        let rdr = rcu.reader();
        assert_eq!(10, rdr.read().0);

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let rdr = rcu.reader();
                let barrier = barrier.clone();
                thread::spawn(move || {
                    let _refs = RefsCheck::new();

                    let initial = rdr.read().0;
                    let initial_scoped = {
                        let handle = rdr.read();
                        handle.0
                    };

                    // Report the initial reads, then wait for the writer to publish `20`.
                    // Assertions are deferred until after both waits so that a failure cannot
                    // leave the other participants blocked on the barrier.
                    barrier.wait();
                    barrier.wait();

                    let updated = rdr.read().0;
                    let updated_scoped = {
                        let handle = rdr.read();
                        handle.0
                    };

                    assert_eq!(
                        (10, 10, 20, 20),
                        (initial, initial_scoped, updated, updated_scoped)
                    );
                })
            })
            .collect();

        barrier.wait();
        rcu.write(RecordDrop::new(20));
        barrier.wait();

        handles.into_iter().for_each(|h| h.join().unwrap());
    }
}
604}