oxirs_core/concurrent/
epoch.rs

//! Epoch-based memory reclamation for lock-free data structures
//!
//! This module provides a safe memory reclamation scheme for concurrent
//! data structures. It uses epochs to track when memory can be safely freed.

6use crossbeam_epoch::{self as epoch, Atomic, Guard, Owned, Shared};
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9
10/// Type alias for compare-exchange result
11type CompareExchangeResult<'g, T> =
12    Result<Shared<'g, VersionedNode<T>>, (Shared<'g, VersionedNode<T>>, Owned<VersionedNode<T>>)>;
13
/// A thread-local epoch tracker for safe memory reclamation.
pub struct EpochManager {
    /// Global epoch counter maintained by this manager.
    ///
    /// NOTE(review): this counter is independent of crossbeam's internal
    /// epoch — `advance()` bumps it without affecting `epoch::pin()`.
    global_epoch: Arc<AtomicUsize>,
    /// Zero-sized marker field; carries no data.
    _phantom: std::marker::PhantomData<()>,
}
21
22impl EpochManager {
23    /// Create a new epoch manager
24    pub fn new() -> Self {
25        Self {
26            global_epoch: Arc::new(AtomicUsize::new(0)),
27            _phantom: std::marker::PhantomData,
28        }
29    }
30
31    /// Pin the current thread to an epoch
32    pub fn pin(&self) -> Guard {
33        epoch::pin()
34    }
35
36    /// Advance the global epoch
37    pub fn advance(&self) {
38        self.global_epoch.fetch_add(1, Ordering::Release);
39    }
40
41    /// Get the current global epoch
42    pub fn current_epoch(&self) -> usize {
43        self.global_epoch.load(Ordering::Acquire)
44    }
45
46    /// Defer a closure until it's safe to execute
47    pub fn defer<F>(&self, guard: &Guard, f: F)
48    where
49        F: FnOnce() + Send + 'static,
50    {
51        guard.defer(f);
52    }
53
54    /// Flush all deferred operations
55    pub fn flush(&self, guard: &Guard) {
56        guard.flush();
57    }
58}
59
60impl Default for EpochManager {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66/// A versioned pointer for lock-free updates
67pub struct VersionedPointer<T> {
68    ptr: Atomic<VersionedNode<T>>,
69}
70
/// A heap node pairing a value with the version it was written at.
pub struct VersionedNode<T> {
    /// The stored payload.
    data: T,
    /// Monotonically increasing version used to reject stale updates.
    version: usize,
}
76
77impl<T> VersionedPointer<T> {
78    /// Create a new versioned pointer
79    pub fn new(data: T) -> Self {
80        let node = VersionedNode { data, version: 0 };
81        Self {
82            ptr: Atomic::new(node),
83        }
84    }
85
86    /// Load the current value
87    pub fn load<'g>(&self, guard: &'g Guard) -> Option<&'g T> {
88        let shared = self.ptr.load(Ordering::Acquire, guard);
89        unsafe { shared.as_ref().map(|node| &node.data) }
90    }
91
92    /// Compare and swap the pointer
93    pub fn compare_and_swap<'g>(
94        &self,
95        current: Shared<'g, VersionedNode<T>>,
96        new: Owned<VersionedNode<T>>,
97        guard: &'g Guard,
98    ) -> CompareExchangeResult<'g, T> {
99        match self
100            .ptr
101            .compare_exchange(current, new, Ordering::Release, Ordering::Acquire, guard)
102        {
103            Ok(shared) => Ok(shared),
104            Err(e) => Err((e.current, e.new)),
105        }
106    }
107
108    /// Update the value with a new version
109    pub fn update(&self, data: T, version: usize, guard: &Guard) -> bool {
110        let current = self.ptr.load(Ordering::Acquire, guard);
111
112        // Check version before attempting swap
113        if let Some(current_node) = unsafe { current.as_ref() } {
114            if current_node.version >= version {
115                // Our version is outdated
116                return false;
117            }
118        }
119
120        let new_node = VersionedNode { data, version };
121        let new = Owned::new(new_node);
122
123        match self.compare_and_swap(current, new, guard) {
124            Ok(_) => {
125                // Defer cleanup of old node
126                if !current.is_null() {
127                    unsafe {
128                        guard.defer_destroy(current);
129                    }
130                }
131                true
132            }
133            Err((_, returned)) => {
134                // Someone else updated between our load and CAS
135                drop(returned);
136                false
137            }
138        }
139    }
140}
141
142/// Hazard pointer wrapper for additional safety
143pub struct HazardPointer<T> {
144    inner: Atomic<T>,
145}
146
147impl<T> HazardPointer<T> {
148    /// Create a new hazard pointer
149    pub fn new(data: T) -> Self {
150        Self {
151            inner: Atomic::new(data),
152        }
153    }
154
155    /// Load with hazard pointer protection
156    pub fn load<'g>(&self, guard: &'g Guard) -> Shared<'g, T> {
157        self.inner.load(Ordering::Acquire, guard)
158    }
159
160    /// Store a new value
161    pub fn store(&self, new: Owned<T>, guard: &Guard) {
162        let old = self.inner.swap(new, Ordering::Release, guard);
163        if !old.is_null() {
164            unsafe {
165                guard.defer_destroy(old);
166            }
167        }
168    }
169
170    /// Compare and swap
171    pub fn compare_and_swap<'g>(
172        &self,
173        current: Shared<'g, T>,
174        new: Owned<T>,
175        guard: &'g Guard,
176    ) -> Result<Shared<'g, T>, (Shared<'g, T>, Owned<T>)> {
177        match self
178            .inner
179            .compare_exchange(current, new, Ordering::Release, Ordering::Acquire, guard)
180        {
181            Ok(shared) => Ok(shared),
182            Err(e) => Err((e.current, e.new)),
183        }
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_epoch_manager() {
        let manager = Arc::new(EpochManager::new());
        let before = manager.current_epoch();

        // Advancing bumps the counter by exactly one.
        manager.advance();
        assert_eq!(manager.current_epoch(), before + 1);

        // Pinning and unpinning must not panic.
        drop(manager.pin());
    }

    #[test]
    fn test_versioned_pointer() {
        let ptr = Arc::new(VersionedPointer::new(42));
        let guard = epoch::pin();

        // Initial value is visible.
        assert_eq!(ptr.load(&guard), Some(&42));

        // A strictly newer version wins.
        assert!(ptr.update(100, 1, &guard));
        assert_eq!(ptr.load(&guard), Some(&100));

        // A stale version (0 <= current 1) is rejected and the value is kept.
        let accepted = ptr.update(50, 0, &guard);
        assert!(!accepted, "Update with outdated version should fail");
        assert_eq!(ptr.load(&guard), Some(&100));
    }

    #[test]
    fn test_concurrent_updates() {
        let ptr = Arc::new(VersionedPointer::new(0));
        let thread_count = 4;
        let per_thread = 1000;

        let workers: Vec<_> = (0..thread_count)
            .map(|t| {
                let ptr = Arc::clone(&ptr);
                thread::spawn(move || {
                    let guard = epoch::pin();
                    for step in 0..per_thread {
                        let version = t * per_thread + step;
                        ptr.update(version as i32, version, &guard);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // Whatever update won last, the stored value is a valid version.
        let guard = epoch::pin();
        assert!(*ptr.load(&guard).unwrap() >= 0);
    }

    #[test]
    fn test_hazard_pointer() {
        let hp = Arc::new(HazardPointer::new("initial"));
        let guard = epoch::pin();

        // Replace the initial value, then read it back.
        hp.store(Owned::new("updated"), &guard);

        let loaded = hp.load(&guard);
        // SAFETY: `guard` keeps the loaded node alive for this read.
        unsafe {
            assert_eq!(loaded.as_ref().unwrap(), &"updated");
        }
    }
}