#![doc = include_str!("../readme.md")]

use std::{collections::VecDeque, ops::Deref, ptr};

#[cfg(not(loom))]
use std::{
    cell::Cell,
    sync::{
        atomic::{AtomicPtr, AtomicU64, Ordering},
        Arc, Mutex, MutexGuard,
    },
};

#[cfg(loom)]
use loom::{
    cell::Cell,
    sync::{
        atomic::{AtomicPtr, AtomicU64, Ordering},
        Arc, Mutex, MutexGuard,
    },
};

#[derive(Debug)]
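/// Single-writer handle to a value protected by epoch-based RCU
/// (read-copy-update).
///
/// Each [`Rcu::write`] installs a new value under a fresh epoch; replaced
/// values are queued and freed once no [`Reader`] can still observe them.
/// A minimal usage sketch (marked `ignore` only because the doc test cannot
/// know this crate's name, which lives in the readme):
///
/// ```ignore
/// let mut rcu = Rcu::new(1);
/// let reader = rcu.reader();
/// rcu.write(2);
/// assert_eq!(2, *reader.read());
/// ```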
pub struct Rcu<T> {
    epoch: u64,
    shared: Arc<Shared<T>>,
}

#[derive(Debug)]
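/// Read-side handle registered with an [`Rcu`].
///
/// A `Reader` caches a reference to the value it is currently observing and
/// advertises that value's epoch through a shared `ReaderState`, which the
/// writer consults before reclaiming old values. A `Reader` is not `Sync`;
/// use one per thread (see `ThreadLocal` with the `thread-local` feature).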
pub struct Reader<T: 'static> {
    cache: Cell<&'static StampedValue<T>>,
    refs: Cell<usize>,
    state: ReaderState,
    shared: Arc<Shared<T>>,
}

#[derive(Debug)]
struct Shared<T> {
    ptr: Pointer<T>,
    reclaim: Mutex<VecDeque<Box<StampedValue<T>>>>,
    readers: Mutex<Vec<ReaderState>>,
}

#[derive(Debug)]
struct StampedValue<T> {
    value: T,
    epoch: u64,
}

#[derive(Debug, Clone)]
struct ReaderState(Arc<AtomicU64>);

#[derive(Debug)]
struct Pointer<T>(AtomicPtr<StampedValue<T>>);

#[derive(Debug)]
pub struct Guard<'a, T: 'static> {
    cache: &'a StampedValue<T>,
    reader: &'a Reader<T>,
}

impl<T: 'static> Default for Rcu<T>
where
    T: Default,
{
    fn default() -> Self {
        Self::new(T::default())
    }
}

impl<T: 'static> Rcu<T> {
    pub fn new(value: T) -> Self {
        Self {
            epoch: 1,
            shared: Arc::new(Shared {
                ptr: Pointer::new(StampedValue { value, epoch: 1 }),
                reclaim: Mutex::new(VecDeque::new()),
                readers: Mutex::new(Vec::new()),
            }),
        }
    }

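    /// Registers and returns a new [`Reader`] observing the current value.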
    pub fn reader(&mut self) -> Reader<T> {
        Reader::new(self.shared.clone())
    }

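    /// Publishes `value` under a new epoch, queues the previous value, and
    /// opportunistically reclaims queued values no reader can still see.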
    pub fn write(&mut self, value: T) {
        // Every write advances the global epoch.
        self.epoch += 1;

        let next = StampedValue {
            epoch: self.epoch,
            value,
        };
        // Swap in the new value and queue the old one; readers pinned to an
        // earlier epoch may still be referencing it.
        let prev = self.shared.ptr.swap(next);
        self.reclaim_queue().push_back(prev);

        self.try_reclaim();
    }

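    /// Frees queued values that every live reader has moved past, returning
    /// the number of values still waiting in the reclamation queue.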
    pub fn try_reclaim(&mut self) -> usize {
        let mut readers = self.shared.readers.lock().unwrap();

        // Forget readers that have been dropped.
        readers.retain(|reader| reader.get() > ReaderState::NOT_IN_USE);

        let mut reclaim = self.reclaim_queue();

        // With no readers left, everything queued is unobservable.
        if readers.is_empty() {
            reclaim.clear();
        }

        // A queued value is safe to drop once every live reader has moved
        // past its epoch.
        let min_epoch = readers
            .iter()
            .map(|r| r.get())
            .min()
            .unwrap_or(ReaderState::NOT_IN_USE);
        while let Some(candidate) = reclaim.pop_front() {
            if min_epoch > candidate.epoch {
                drop(candidate);
            } else {
                // Epochs are queued in increasing order, so the first
                // blocked candidate blocks everything behind it as well.
                reclaim.push_front(candidate);
                return reclaim.len();
            }
        }
        0
    }

    fn reclaim_queue(&self) -> MutexGuard<'_, VecDeque<Box<StampedValue<T>>>> {
        self.shared
            .reclaim
            .try_lock()
            .expect("invalid shared reclaimer access")
    }
}

impl<T: 'static> Reader<T> {
    fn new(shared: Arc<Shared<T>>) -> Self {
        let value = shared.ptr.load();
        let mut readers = shared.readers.lock().unwrap();

        let state = ReaderState::new(value.epoch);
        readers.push(state.clone());

        Reader {
            shared: shared.clone(),
            refs: Cell::new(0),
            cache: Cell::new(value),
            state,
        }
    }

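    /// Returns a [`Guard`] that dereferences to the current value.
    ///
    /// Nested guards share one snapshot: the cached value and published
    /// epoch are only refreshed when no other guard is outstanding.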
    pub fn read(&self) -> Guard<'_, T> {
        let cache = if self.refs.get() == 0 {
            // No outstanding guard: refresh the snapshot and publish its
            // epoch so the writer knows what we can still observe.
            let value = self.shared.ptr.load();

            self.state.set(value.epoch);

            self.cache.set(value);
            value
        } else {
            // Nested read: reuse the pinned snapshot.
            self.cache.get()
        };

        self.refs.set(self.refs.get() + 1);

        Guard {
            reader: self,
            cache,
        }
    }
}

impl<'a, T> Deref for Guard<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.cache.value
    }
}

impl<'a, T> Drop for Guard<'a, T> {
    fn drop(&mut self) {
        self.reader.refs.set(self.reader.refs.get() - 1);
    }
}

impl<T> Drop for Reader<T> {
    fn drop(&mut self) {
        self.state.mark_dropped();
    }
}

impl<T> Pointer<T> {
    fn new(value: StampedValue<T>) -> Self {
        Self(AtomicPtr::new(Box::leak(Box::new(value))))
    }

    fn swap(&self, value: StampedValue<T>) -> Box<StampedValue<T>> {
        let ptr = Box::leak(Box::new(value));
        let prev = self.0.swap(ptr, Ordering::AcqRel);
        unsafe { Box::from_raw(prev) }
    }

    fn load(&self) -> &'static StampedValue<T> {
        // Acquire pairs with the Release half of `swap`, ensuring the
        // pointed-to StampedValue is fully visible before we dereference it.
        unsafe { &*self.0.load(Ordering::Acquire) }
    }
}

impl<T> Drop for Pointer<T> {
    fn drop(&mut self) {
        let prev = self.0.swap(ptr::null_mut(), Ordering::AcqRel);
        let _ = unsafe { Box::from_raw(prev) };
    }
}

impl ReaderState {
    const NOT_IN_USE: u64 = 0;

    fn new(epoch: u64) -> Self {
        Self(Arc::new(AtomicU64::new(epoch)))
    }

    fn mark_dropped(&self) {
        self.set(Self::NOT_IN_USE)
    }

    fn set(&self, epoch: u64) {
        self.0.store(epoch, Ordering::Release)
    }

    fn get(&self) -> u64 {
        self.0.load(Ordering::Acquire)
    }
}

#[cfg(feature = "thread-local")]
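/// Lazily creates one [`Reader`] per thread over a shared [`Rcu`]
/// (available with the `thread-local` feature).
///
/// A minimal sketch of the intended use (`ignore`d for the same crate-name
/// reason as the [`Rcu`] example):
///
/// ```ignore
/// let mut rcu = Rcu::new(0);
/// let tls = ThreadLocal::new(&rcu);
/// std::thread::scope(|s| {
///     s.spawn(|| assert_eq!(0, *tls.get_or_init()));
/// });
/// ```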
pub struct ThreadLocal<T: Send + Sync + 'static> {
    shared: Arc<Shared<T>>,
    thread_local: thread_local::ThreadLocal<Reader<T>>,
}

#[cfg(feature = "thread-local")]
impl<T: Send + Sync + 'static> ThreadLocal<T> {
    pub fn new(rcu: &Rcu<T>) -> Self {
        Self {
            shared: rcu.shared.clone(),
            thread_local: thread_local::ThreadLocal::new(),
        }
    }

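    /// Reads through this thread's `Reader`, or `None` if the current
    /// thread has not initialized one yet.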
    pub fn get(&self) -> Option<Guard<'_, T>> {
        self.thread_local.get().map(|r| r.read())
    }

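    /// Reads through this thread's `Reader`, registering one first if the
    /// current thread has none.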
    pub fn get_or_init(&self) -> Guard<'_, T> {
        self.thread_local
            .get_or(|| Reader::new(self.shared.clone()))
            .read()
    }
}

#[cfg(test)]
#[cfg(loom)]
mod loom_tests {
    use loom::thread;

    use super::*;

    #[test]
    fn nested() {
        loom::model(|| {
            let mut rcu = Rcu::new(10);

            let rdr = rcu.reader();

            {
                let g = rdr.read();
                assert_eq!(10, *g);

                rcu.write(20);
                {
                    let g = rdr.read();
                    assert_eq!(10, *g);
                }
            }
        });
    }

    #[test]
    fn thread_nested() {
        loom::model(|| {
            let n = 2;
            let mut rcu = Rcu::new(0);
            let rdr = rcu.reader();
            let h = thread::spawn(move || {
                let v = rdr.read();
                assert!(*v < n);

                {
                    let g = rdr.read();
                    assert!(*g < n);
                }
            });
            for i in 0..n {
                rcu.write(i);
                loom::thread::yield_now();
            }
            h.join().unwrap();
        });
    }

    #[test]
    fn thread() {
        loom::model(|| {
            let n = 2;
            let mut rcu = Rcu::new(0);
            let rdr = rcu.reader();
            let h = thread::spawn(move || {
                for _ in 0..n {
                    let v = rdr.read();
                    assert!(*v < n);
                    loom::thread::yield_now();
                }
            });
            for i in 0..n {
                rcu.write(i);
                loom::thread::yield_now();
            }
            h.join().unwrap();
        });
    }

    #[test]
    fn thread_detached() {
        loom::model(|| {
            let n = 2;
            let mut rcu = Rcu::new(0);
            let rdr = rcu.reader();
            thread::spawn(move || {
                for _ in 0..n {
                    let v = rdr.read();
                    assert!(*v < n);
                    loom::thread::yield_now();
                }
            });
            for i in 0..n {
                rcu.write(i);
                loom::thread::yield_now();
            }
        });
    }
}

#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use std::{
        sync::{atomic::AtomicUsize, Condvar},
        thread,
        time::Duration,
    };

    use super::*;

    thread_local! {
        static REFS: AtomicUsize = AtomicUsize::new(0);
    }

    /// Asserts that no `RecordDrop` values are live when created and when
    /// dropped, bracketing a test body.
    struct RefsCheck;

    impl RefsCheck {
        fn new() -> Self {
            REFS.with(|refs| {
                assert_eq!(refs.load(Ordering::SeqCst), 0);
            });
            Self
        }
    }

    impl Drop for RefsCheck {
        fn drop(&mut self) {
            REFS.with(|refs| {
                assert_eq!(refs.load(Ordering::SeqCst), 0);
            });
        }
    }

    /// Counts live instances in the thread-local `REFS` counter so tests
    /// can verify that reclaimed values are actually dropped.
    #[derive(Debug)]
    struct RecordDrop(u32);

    impl RecordDrop {
        fn new(v: u32) -> Self {
            REFS.with(|refs| {
                refs.fetch_add(1, Ordering::SeqCst);
            });
            Self(v)
        }
    }

    impl Drop for RecordDrop {
        fn drop(&mut self) {
            REFS.with(|refs| {
                refs.fetch_sub(1, Ordering::SeqCst);
            });
        }
    }

    #[cfg(feature = "thread-local")]
    #[test]
    fn thread_local() {
        let mut rcu = Rcu::new(10);
        let tls = ThreadLocal::new(&rcu);

        thread::scope(|s| {
            s.spawn(|| {
                let _val = tls.get_or_init();
                assert!(tls.get().is_some());
            });
            s.spawn(|| {
                let _val = tls.get_or_init();
                assert!(tls.get().is_some());
            });
        });

        rcu.write(1);
    }

    #[test]
    fn send_check() {
        let mut rcu = Rcu::new(10);
        let rdr = rcu.reader();

        // Join so the assertion actually runs before the test ends.
        thread::spawn(move || {
            assert_eq!(10, *rdr.read());
        })
        .join()
        .unwrap();
    }

    #[test]
    fn single_value() {
        let _refs = RefsCheck::new();

        let mut rcu = Rcu::new(RecordDrop::new(10));
        let rdr = rcu.reader();
        assert_eq!(10, rdr.read().0);
    }

    #[test]
    fn old_value() {
        let _refs = RefsCheck::new();

        let mut rcu = Rcu::new(RecordDrop::new(10));
        let rdr1 = rcu.reader();
        assert_eq!(10, rdr1.read().0);

        let rdr2 = rcu.reader();
        assert_eq!(10, rdr2.read().0);

        for i in 11..=20 {
            rcu.write(RecordDrop::new(i));
            assert_eq!(i, rdr1.read().0);
        }
    }

    #[test]
    fn remove_readers() {
        let _refs = RefsCheck::new();

        let mut rcu = Rcu::new(RecordDrop::new(10));

        let rdr1 = rcu.reader();
        let rdr2 = rcu.reader();

        for i in 11..=20 {
            rcu.write(RecordDrop::new(i));
        }

        drop(rdr1);
        drop(rdr2);

        rcu.write(RecordDrop::new(30));
    }

    #[test]
    fn nested() {
        let _refs = RefsCheck::new();

        let mut rcu = Rcu::new(RecordDrop::new(10));

        let rdr = rcu.reader();

        {
            let handle = rdr.read();
            assert_eq!(10, handle.0);

            rcu.write(RecordDrop::new(20));

            {
                let handle = rdr.read();
                assert_eq!(10, handle.0);
            }

            let handle2 = rdr.read();
            assert_eq!(10, handle.0);
            assert_eq!(10, handle2.0);
        }

        assert_eq!(20, rdr.read().0);
    }

    #[test]
    fn nested_multi_threaded() {
        let _refs = RefsCheck::new();

        let notify = Arc::new((Mutex::new(false), Condvar::new()));

        let mut rcu = Rcu::new(RecordDrop::new(10));
        let rdr = rcu.reader();
        assert_eq!(10, rdr.read().0);

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let rdr = rcu.reader();
                let notify = notify.clone();
                std::thread::spawn(move || {
                    let _refs = RefsCheck::new();

                    assert_eq!(10, rdr.read().0);
                    {
                        let handle = rdr.read();
                        assert_eq!(10, handle.0);
                    }

                    let (lock, cvar) = &*notify;
                    let mut started = lock.lock().unwrap();
                    while !*started {
                        started = cvar.wait(started).unwrap();
                    }

                    assert_eq!(20, rdr.read().0);
                    {
                        let handle = rdr.read();
                        assert_eq!(20, handle.0);
                    }
                })
            })
            .collect();

        thread::sleep(Duration::from_millis(10));
        rcu.write(RecordDrop::new(20));

        {
            let (lock, cvar) = &*notify;
            *lock.lock().unwrap() = true;
            cvar.notify_all();
        }

        handles.into_iter().for_each(|h| h.join().unwrap());
    }
}