oxirs_core/concurrent/
epoch.rs1use crossbeam_epoch::{self as epoch, Atomic, Guard, Owned, Shared};
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9
10type CompareExchangeResult<'g, T> =
12 Result<Shared<'g, VersionedNode<T>>, (Shared<'g, VersionedNode<T>>, Owned<VersionedNode<T>>)>;
13
/// Convenience front-end for `crossbeam_epoch` pinning that additionally keeps
/// its own shared counter.
///
/// The counter is only written by `advance` and only read by `current_epoch`;
/// `pin`, `defer` and `flush` delegate to `crossbeam_epoch` and never consult
/// it.
#[derive(Debug)]
pub struct EpochManager {
    /// Counter incremented by `advance` and reported by `current_epoch`.
    global_epoch: Arc<AtomicUsize>,
    /// Zero-sized marker; stores no data.
    _phantom: std::marker::PhantomData<()>,
}
21
22impl EpochManager {
23 pub fn new() -> Self {
25 Self {
26 global_epoch: Arc::new(AtomicUsize::new(0)),
27 _phantom: std::marker::PhantomData,
28 }
29 }
30
31 pub fn pin(&self) -> Guard {
33 epoch::pin()
34 }
35
36 pub fn advance(&self) {
38 self.global_epoch.fetch_add(1, Ordering::Release);
39 }
40
41 pub fn current_epoch(&self) -> usize {
43 self.global_epoch.load(Ordering::Acquire)
44 }
45
46 pub fn defer<F>(&self, guard: &Guard, f: F)
48 where
49 F: FnOnce() + Send + 'static,
50 {
51 guard.defer(f);
52 }
53
54 pub fn flush(&self, guard: &Guard) {
56 guard.flush();
57 }
58}
59
60impl Default for EpochManager {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66pub struct VersionedPointer<T> {
68 ptr: Atomic<VersionedNode<T>>,
69}
70
/// Heap node stored behind a `VersionedPointer`: a payload plus the version
/// it was published with.
#[derive(Debug)]
pub struct VersionedNode<T> {
    /// The payload handed out to readers.
    data: T,
    /// Version supplied when this node was created.
    version: usize,
}
76
77impl<T> VersionedPointer<T> {
78 pub fn new(data: T) -> Self {
80 let node = VersionedNode { data, version: 0 };
81 Self {
82 ptr: Atomic::new(node),
83 }
84 }
85
86 pub fn load<'g>(&self, guard: &'g Guard) -> Option<&'g T> {
88 let shared = self.ptr.load(Ordering::Acquire, guard);
89 unsafe { shared.as_ref().map(|node| &node.data) }
90 }
91
92 pub fn compare_and_swap<'g>(
94 &self,
95 current: Shared<'g, VersionedNode<T>>,
96 new: Owned<VersionedNode<T>>,
97 guard: &'g Guard,
98 ) -> CompareExchangeResult<'g, T> {
99 match self
100 .ptr
101 .compare_exchange(current, new, Ordering::Release, Ordering::Acquire, guard)
102 {
103 Ok(shared) => Ok(shared),
104 Err(e) => Err((e.current, e.new)),
105 }
106 }
107
108 pub fn update(&self, data: T, version: usize, guard: &Guard) -> bool {
110 let current = self.ptr.load(Ordering::Acquire, guard);
111
112 if let Some(current_node) = unsafe { current.as_ref() } {
114 if current_node.version >= version {
115 return false;
117 }
118 }
119
120 let new_node = VersionedNode { data, version };
121 let new = Owned::new(new_node);
122
123 match self.compare_and_swap(current, new, guard) {
124 Ok(_) => {
125 if !current.is_null() {
127 unsafe {
128 guard.defer_destroy(current);
129 }
130 }
131 true
132 }
133 Err((_, returned)) => {
134 drop(returned);
136 false
137 }
138 }
139 }
140}
141
142pub struct HazardPointer<T> {
144 inner: Atomic<T>,
145}
146
147impl<T> HazardPointer<T> {
148 pub fn new(data: T) -> Self {
150 Self {
151 inner: Atomic::new(data),
152 }
153 }
154
155 pub fn load<'g>(&self, guard: &'g Guard) -> Shared<'g, T> {
157 self.inner.load(Ordering::Acquire, guard)
158 }
159
160 pub fn store(&self, new: Owned<T>, guard: &Guard) {
162 let old = self.inner.swap(new, Ordering::Release, guard);
163 if !old.is_null() {
164 unsafe {
165 guard.defer_destroy(old);
166 }
167 }
168 }
169
170 pub fn compare_and_swap<'g>(
172 &self,
173 current: Shared<'g, T>,
174 new: Owned<T>,
175 guard: &'g Guard,
176 ) -> Result<Shared<'g, T>, (Shared<'g, T>, Owned<T>)> {
177 match self
178 .inner
179 .compare_exchange(current, new, Ordering::Release, Ordering::Acquire, guard)
180 {
181 Ok(shared) => Ok(shared),
182 Err(e) => Err((e.current, e.new)),
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// `advance` bumps the counter by one and `pin` yields a guard that can be
    /// released again.
    #[test]
    fn test_epoch_manager() {
        let manager = Arc::new(EpochManager::new());
        let before = manager.current_epoch();

        manager.advance();
        let after = manager.current_epoch();
        assert_eq!(after, before + 1);

        // Pin, then release right away.
        drop(manager.pin());
    }

    /// A newer version replaces the payload; an older one is rejected.
    #[test]
    fn test_versioned_pointer() {
        let versioned = Arc::new(VersionedPointer::new(42));
        let guard = epoch::pin();

        assert_eq!(versioned.load(&guard), Some(&42));

        let accepted = versioned.update(100, 1, &guard);
        assert!(accepted);
        assert_eq!(versioned.load(&guard), Some(&100));

        let stale_accepted = versioned.update(50, 0, &guard);
        assert!(!stale_accepted, "Update with outdated version should fail");
        assert_eq!(versioned.load(&guard), Some(&100));
    }

    /// Several threads race to publish increasing versions; the pointer must
    /// stay readable afterwards.
    #[test]
    fn test_concurrent_updates() {
        const NUM_THREADS: usize = 4;
        const UPDATES_PER_THREAD: usize = 1000;

        let shared = Arc::new(VersionedPointer::new(0));

        let mut workers = Vec::with_capacity(NUM_THREADS);
        for thread_idx in 0..NUM_THREADS {
            let shared = Arc::clone(&shared);
            workers.push(thread::spawn(move || {
                let guard = epoch::pin();
                let first = thread_idx * UPDATES_PER_THREAD;
                for version in first..first + UPDATES_PER_THREAD {
                    // Losing a race or being stale is fine; the result is ignored.
                    shared.update(version as i32, version, &guard);
                }
            }));
        }

        workers
            .into_iter()
            .for_each(|worker| worker.join().unwrap());

        let guard = epoch::pin();
        let last = *shared.load(&guard).unwrap();
        assert!(last >= 0);
    }

    /// `store` makes the new value visible to a subsequent `load`.
    #[test]
    fn test_hazard_pointer() {
        let hp = Arc::new(HazardPointer::new("initial"));
        let guard = epoch::pin();

        hp.store(Owned::new("updated"), &guard);

        let loaded = hp.load(&guard);
        // SAFETY: `guard` is still pinned, so the value just stored cannot
        // have been reclaimed.
        let seen = unsafe { loaded.as_ref() };
        assert_eq!(seen.unwrap(), &"updated");
    }
}