// formualizer_eval/engine/epoch_tracker.rs

use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

/// Maximum number of concurrent reader threads supported.
pub const MAX_THREADS: usize = 256;

/// Cache-line-padded wrapper to prevent false sharing between adjacent
/// per-thread atomics (64 bytes is the cache-line size on most x86-64
/// and many AArch64 parts).
#[repr(align(64))]
struct CachePadded<T> {
    value: T,
}

impl<T> CachePadded<T> {
    fn new(value: T) -> Self {
        Self { value }
    }
}

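// Compile-time sanity check that the padding takes effect. The 64-byte
// figure is an assumption about this codebase's targets (the common
// cache-line size); adjust if a target differs.
const _: () = assert!(std::mem::align_of::<CachePadded<AtomicU64>>() == 64);
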
/// Epoch-based MVCC tracker for concurrent reads during writes.
///
/// Allows multiple readers to access consistent snapshots while
/// writers make changes. Tracks the minimum epoch across all active
/// readers to determine when old data can be safely reclaimed.
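///
/// # Example
///
/// A minimal usage sketch (the import path is an assumption; adjust it to
/// this crate's module layout):
///
/// ```ignore
/// use formualizer_eval::engine::epoch_tracker::EpochTracker;
///
/// let tracker = EpochTracker::new();
///
/// // A reader pins itself to the current epoch in slot 0...
/// let read = tracker.begin_read(0);
///
/// // ...so a concurrent write cannot make epoch-0 data reclaimable yet.
/// let write = tracker.begin_write();
/// assert_eq!(write.epoch(), 1);
///
/// drop(read);  // the reader leaves; the safe epoch may now advance
/// drop(write);
/// assert_eq!(tracker.safe_epoch(), 1);
/// ```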
pub struct EpochTracker {
    /// Current global epoch, incremented on each write
    current_epoch: AtomicU64,

    /// Per-thread reader epochs (u64::MAX = no active reader).
    /// Cache-padded to avoid false sharing between threads.
    reader_epochs: Arc<Vec<CachePadded<AtomicU64>>>,

    /// Lowest epoch any active reader may still observe; versions from
    /// epochs older than this are safe to reclaim.
    safe_epoch: AtomicU64,
}

impl EpochTracker {
    pub fn new() -> Self {
        let mut reader_epochs = Vec::with_capacity(MAX_THREADS);
        for _ in 0..MAX_THREADS {
            reader_epochs.push(CachePadded::new(AtomicU64::new(u64::MAX)));
        }

        Self {
            current_epoch: AtomicU64::new(0),
            reader_epochs: Arc::new(reader_epochs),
            safe_epoch: AtomicU64::new(0),
        }
    }

    /// Get the current epoch.
    pub fn current_epoch(&self) -> u64 {
        self.current_epoch.load(Ordering::Acquire)
    }

    /// Get the safe epoch (minimum across all active readers).
    pub fn safe_epoch(&self) -> u64 {
        self.safe_epoch.load(Ordering::Acquire)
    }

    /// Begin a write operation, incrementing the global epoch.
    pub fn begin_write(&self) -> WriteGuard {
        // `fetch_add` returns the previous value, so the guard operates in
        // the newly created epoch.
        let epoch = self.current_epoch.fetch_add(1, Ordering::AcqRel) + 1;
        WriteGuard {
            tracker: self,
            epoch,
            committed: false,
        }
    }

    /// Begin a read operation on the given thread.
    pub fn begin_read(&self, thread_id: usize) -> ReadGuard {
        assert!(
            thread_id < MAX_THREADS,
            "Thread ID {thread_id} exceeds MAX_THREADS"
        );

        // Publish-then-recheck: store the observed epoch, fence, and re-read
        // the global epoch. Without this, a writer could advance the epoch
        // and recompute the safe epoch in the window between our load and
        // our store, miss this reader entirely, and reclaim data it is about
        // to read. Retrying until the epoch is stable closes that window.
        let mut epoch = self.current_epoch.load(Ordering::Acquire);
        loop {
            self.reader_epochs[thread_id]
                .value
                .store(epoch, Ordering::Release);
            // Pairs with the SeqCst fence in `update_safe_epoch`.
            std::sync::atomic::fence(Ordering::SeqCst);
            let now = self.current_epoch.load(Ordering::Acquire);
            if now == epoch {
                break;
            }
            epoch = now;
        }

        ReadGuard {
            tracker: self,
            thread_id,
            epoch,
        }
    }

    /// Update the safe epoch based on current reader states.
    fn update_safe_epoch(&self) {
        // Pairs with the fence in `begin_read`: either a pinning reader's
        // slot store is visible to the scan below, or that reader observes
        // the advanced epoch and republishes.
        std::sync::atomic::fence(Ordering::SeqCst);

        let current = self.current_epoch.load(Ordering::Acquire);
        let min_reader = self
            .reader_epochs
            .iter()
            .map(|padded| padded.value.load(Ordering::Acquire))
            .filter(|&epoch| epoch != u64::MAX) // Ignore inactive readers
            .min()
            .unwrap_or(current); // No active readers: everything up to now is safe

        self.safe_epoch.store(min_reader, Ordering::Release);
    }

    /// Wait until all readers have advanced past the given epoch.
    ///
    /// Spin-waits, so it assumes read critical sections are short; a reader
    /// that never drops its guard will block this call indefinitely.
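    ///
    /// A minimal sketch (the import path is an assumption; adjust it to
    /// this crate's layout):
    ///
    /// ```ignore
    /// use formualizer_eval::engine::epoch_tracker::EpochTracker;
    ///
    /// let tracker = EpochTracker::new();
    /// let write = tracker.begin_write(); // advance to epoch 1
    /// drop(write);
    /// // No readers are pinned, so the safe epoch is already 1 and this
    /// // returns without spinning.
    /// tracker.wait_for_readers(0);
    /// ```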
    pub fn wait_for_readers(&self, target_epoch: u64) {
        loop {
            self.update_safe_epoch();
            if self.safe_epoch.load(Ordering::Acquire) > target_epoch {
                break;
            }
            std::hint::spin_loop();
        }
    }
}

impl Default for EpochTracker {
    fn default() -> Self {
        Self::new()
    }
}

/// Guard for write operations - updates the safe epoch on drop.
pub struct WriteGuard<'a> {
    tracker: &'a EpochTracker,
    epoch: u64,
    /// True once `commit` has been called; currently informational only,
    /// so it is allowed to be dead code.
    #[allow(dead_code)]
    committed: bool,
}

impl<'a> WriteGuard<'a> {
    /// Get the epoch this write is operating in.
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Mark this write as committed (for two-phase commit protocols).
    /// Dropping the guard updates the safe epoch either way.
    pub fn commit(&mut self) {
        self.committed = true;
    }
}

impl<'a> Drop for WriteGuard<'a> {
    fn drop(&mut self) {
        self.tracker.update_safe_epoch();
    }
}

/// Guard for read operations - clears the reader's epoch slot on drop.
pub struct ReadGuard<'a> {
    tracker: &'a EpochTracker,
    thread_id: usize,
    epoch: u64,
}

impl<'a> ReadGuard<'a> {
    /// Get the epoch this read is operating in.
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Check if this read's view is still current, i.e. no write has
    /// begun since this guard was created.
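    ///
    /// A minimal sketch (the import path is an assumption; adjust it to
    /// this crate's layout):
    ///
    /// ```ignore
    /// use formualizer_eval::engine::epoch_tracker::EpochTracker;
    ///
    /// let tracker = EpochTracker::new();
    /// let read = tracker.begin_read(0);
    /// assert!(read.is_current());         // no writes yet
    /// let _write = tracker.begin_write(); // epoch advances to 1
    /// assert!(!read.is_current());        // the epoch-0 view is now stale
    /// ```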
    pub fn is_current(&self) -> bool {
        self.epoch == self.tracker.current_epoch()
    }
}

impl<'a> Drop for ReadGuard<'a> {
    fn drop(&mut self) {
        // Clear this thread's slot, then recompute so the safe epoch can
        // advance past anything this reader was pinning.
        self.tracker.reader_epochs[self.thread_id]
            .value
            .store(u64::MAX, Ordering::Release);
        self.tracker.update_safe_epoch();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_epoch_basic() {
        let tracker = EpochTracker::new();

        // Initial epoch should be 0
        assert_eq!(tracker.current_epoch(), 0);

        // Begin write should increment epoch
        let _write = tracker.begin_write();
        assert_eq!(tracker.current_epoch(), 1);
    }

    #[test]
    fn test_read_guard() {
        let tracker = EpochTracker::new();

        // Begin read on thread 0
        let _read = tracker.begin_read(0);

        // Safe epoch should be 0 (the reader is pinned at epoch 0)
        assert_eq!(tracker.safe_epoch(), 0);
    }

    #[test]
    fn test_write_advances_epoch() {
        let tracker = EpochTracker::new();

        {
            let _write = tracker.begin_write();
            assert_eq!(tracker.current_epoch(), 1);
        }

        // After the write guard drops, the safe epoch updates
        assert_eq!(tracker.safe_epoch(), 1);
    }

    #[test]
    fn test_concurrent_readers() {
        let tracker = Arc::new(EpochTracker::new());

        // Start multiple readers
        let handles: Vec<_> = (0..4)
            .map(|i| {
                let t = Arc::clone(&tracker);
                thread::spawn(move || {
                    let _read = t.begin_read(i);
                    thread::sleep(Duration::from_millis(10));
                })
            })
            .collect();

        // Wait a bit for the readers to start
        thread::sleep(Duration::from_millis(5));

        // Safe epoch should still be 0 (readers are at epoch 0)
        assert_eq!(tracker.safe_epoch(), 0);

        // Wait for all readers to finish
        for h in handles {
            h.join().unwrap();
        }

        // Force an update after the readers finish
        tracker.update_safe_epoch();

        // Safe epoch should equal the current epoch (0) once all readers
        // are gone
        assert_eq!(tracker.safe_epoch(), 0);
    }

    #[test]
    fn test_write_waits_for_readers() {
        let tracker = Arc::new(EpochTracker::new());

        // Start a long-running reader at epoch 0
        let reader_tracker = Arc::clone(&tracker);
        let reader = thread::spawn(move || {
            let _read = reader_tracker.begin_read(0);
            thread::sleep(Duration::from_millis(50));
        });

        // Give the reader time to start
        thread::sleep(Duration::from_millis(10));

        // Create a write guard (advances to epoch 1)
        let _write = tracker.begin_write();
        assert_eq!(tracker.current_epoch(), 1);

        // Safe epoch should still be 0 (the reader is active at epoch 0)
        tracker.update_safe_epoch();
        assert_eq!(tracker.safe_epoch(), 0);

        // Wait for the reader to finish
        reader.join().unwrap();

        // Now the safe epoch can advance
        tracker.update_safe_epoch();
        assert_eq!(tracker.safe_epoch(), 1);
    }

    #[test]
    #[should_panic(expected = "Thread ID 256 exceeds MAX_THREADS")]
    fn test_thread_id_overflow() {
        let tracker = EpochTracker::new();
        tracker.begin_read(MAX_THREADS); // Should panic
    }

    #[test]
    fn test_multiple_write_guards() {
        let tracker = EpochTracker::new();

        let write1 = tracker.begin_write();
        assert_eq!(tracker.current_epoch(), 1);

        let write2 = tracker.begin_write();
        assert_eq!(tracker.current_epoch(), 2);

        drop(write1);
        drop(write2);

        assert_eq!(tracker.safe_epoch(), 2);
    }

    #[test]
    fn test_mvcc_with_vertex_store() {
        use crate::engine::packed_coord::PackedCoord;
        use crate::engine::vertex_store::VertexStore;

        let tracker = Arc::new(EpochTracker::new());
        let store = Arc::new(std::sync::Mutex::new(VertexStore::new()));

        // Writer adds vertices while holding the lock
        let writer_tracker = Arc::clone(&tracker);
        let writer_store = Arc::clone(&store);
        let writer = thread::spawn(move || {
            let _write = writer_tracker.begin_write();
            let mut store = writer_store.lock().unwrap();

            for i in 0..5 {
                store.allocate(PackedCoord::new(i, i), 0, 0);
            }
        });

        // Reader observes a consistent snapshot
        let reader_tracker = Arc::clone(&tracker);
        let reader_store = Arc::clone(&store);
        let reader = thread::spawn(move || {
            thread::sleep(Duration::from_millis(10)); // Let the writer start

            let _read = reader_tracker.begin_read(0);
            let store = reader_store.lock().unwrap();

            // The mutex serializes the two critical sections, so the reader
            // sees either none or all of the writer's insertions
            store.len()
        });

        writer.join().unwrap();
        let observed_len = reader.join().unwrap();

        // Reader either saw 0 (before the write) or 5 (after), never partial
        assert!(observed_len == 0 || observed_len == 5);
    }
}