formualizer_eval/engine/
epoch_tracker.rs1use std::sync::Arc;
2use std::sync::atomic::{AtomicU64, Ordering};
3
/// Number of reader slots owned by an `EpochTracker`.
///
/// Valid `thread_id`s passed to `begin_read` are `0..MAX_THREADS`; anything
/// at or above this bound is rejected with a panic.
pub const MAX_THREADS: usize = 1 << 8;
6
/// Wraps a value so that it starts on its own 64-byte boundary.
///
/// Storing per-reader atomics in a `Vec` of these keeps neighbouring slots
/// from sharing a cache line (assuming 64-byte lines — confirm for the
/// target hardware).
#[repr(align(64))]
struct CachePadded<T> {
    /// The wrapped value.
    value: T,
}

impl<T> CachePadded<T> {
    /// Wraps `value` without modifying it.
    fn new(value: T) -> Self {
        CachePadded { value }
    }
}
18
19pub struct EpochTracker {
25 current_epoch: AtomicU64,
27
28 reader_epochs: Arc<Vec<CachePadded<AtomicU64>>>,
31
32 safe_epoch: AtomicU64,
34}
35
36impl EpochTracker {
37 pub fn new() -> Self {
38 let mut reader_epochs = Vec::with_capacity(MAX_THREADS);
39 for _ in 0..MAX_THREADS {
40 reader_epochs.push(CachePadded::new(AtomicU64::new(u64::MAX)));
41 }
42
43 Self {
44 current_epoch: AtomicU64::new(0),
45 reader_epochs: Arc::new(reader_epochs),
46 safe_epoch: AtomicU64::new(0),
47 }
48 }
49
50 pub fn current_epoch(&self) -> u64 {
52 self.current_epoch.load(Ordering::Acquire)
53 }
54
55 pub fn safe_epoch(&self) -> u64 {
57 self.safe_epoch.load(Ordering::Acquire)
58 }
59
60 pub fn begin_write(&self) -> WriteGuard {
62 let epoch = self.current_epoch.fetch_add(1, Ordering::AcqRel) + 1;
63 WriteGuard {
64 tracker: self,
65 epoch,
66 committed: false,
67 }
68 }
69
70 pub fn begin_read(&self, thread_id: usize) -> ReadGuard {
72 assert!(
73 thread_id < MAX_THREADS,
74 "Thread ID {thread_id} exceeds MAX_THREADS"
75 );
76
77 let epoch = self.current_epoch.load(Ordering::Acquire);
78 self.reader_epochs[thread_id]
79 .value
80 .store(epoch, Ordering::Release);
81
82 ReadGuard {
83 tracker: self,
84 thread_id,
85 epoch,
86 }
87 }
88
89 fn update_safe_epoch(&self) {
91 let current = self.current_epoch.load(Ordering::Acquire);
92 let min_reader = self
93 .reader_epochs
94 .iter()
95 .map(|padded| padded.value.load(Ordering::Acquire))
96 .filter(|&epoch| epoch != u64::MAX) .min()
98 .unwrap_or(current); self.safe_epoch.store(min_reader, Ordering::Release);
101 }
102
103 pub fn wait_for_readers(&self, target_epoch: u64) {
105 loop {
106 self.update_safe_epoch();
107 if self.safe_epoch.load(Ordering::Acquire) > target_epoch {
108 break;
109 }
110 std::hint::spin_loop();
111 }
112 }
113}
114
115impl Default for EpochTracker {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
121pub struct WriteGuard<'a> {
123 tracker: &'a EpochTracker,
124 epoch: u64,
125 committed: bool,
126}
127
128impl<'a> WriteGuard<'a> {
129 pub fn epoch(&self) -> u64 {
131 self.epoch
132 }
133
134 pub fn commit(&mut self) {
136 self.committed = true;
137 }
138}
139
140impl<'a> Drop for WriteGuard<'a> {
141 fn drop(&mut self) {
142 self.tracker.update_safe_epoch();
143 }
144}
145
146pub struct ReadGuard<'a> {
148 tracker: &'a EpochTracker,
149 thread_id: usize,
150 epoch: u64,
151}
152
153impl<'a> ReadGuard<'a> {
154 pub fn epoch(&self) -> u64 {
156 self.epoch
157 }
158
159 pub fn is_current(&self) -> bool {
161 self.epoch == self.tracker.current_epoch()
162 }
163}
164
165impl<'a> Drop for ReadGuard<'a> {
166 fn drop(&mut self) {
167 self.tracker.reader_epochs[self.thread_id]
168 .value
169 .store(u64::MAX, Ordering::Release);
170 self.tracker.update_safe_epoch();
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_epoch_basic() {
        let tracker = EpochTracker::new();

        assert_eq!(tracker.current_epoch(), 0);

        let _write = tracker.begin_write();
        assert_eq!(tracker.current_epoch(), 1);
    }

    #[test]
    fn test_read_guard() {
        let tracker = EpochTracker::new();

        let read = tracker.begin_read(0);
        assert_eq!(read.epoch(), 0);
        assert!(read.is_current());

        assert_eq!(tracker.safe_epoch(), 0);
    }

    #[test]
    fn test_write_advances_epoch() {
        let tracker = EpochTracker::new();

        {
            let _write = tracker.begin_write();
            assert_eq!(tracker.current_epoch(), 1);
        }

        // Dropping the write guard recomputes the safe epoch; with no readers
        // it catches up to the current epoch.
        assert_eq!(tracker.safe_epoch(), 1);
    }

    #[test]
    fn test_concurrent_readers() {
        const READERS: usize = 4;
        let tracker = Arc::new(EpochTracker::new());
        // Two rendezvous points: "every reader is registered" and
        // "readers may leave". This replaces sleep-based timing.
        let barrier = Arc::new(Barrier::new(READERS + 1));

        let handles: Vec<_> = (0..READERS)
            .map(|i| {
                let t = Arc::clone(&tracker);
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    let _read = t.begin_read(i);
                    b.wait();
                    b.wait();
                })
            })
            .collect();

        barrier.wait();

        // Every reader is pinned at epoch 0, so a write cannot move the
        // safe epoch past them.
        drop(tracker.begin_write());
        assert_eq!(tracker.current_epoch(), 1);
        assert_eq!(tracker.safe_epoch(), 0);

        barrier.wait();
        for h in handles {
            h.join().unwrap();
        }

        tracker.update_safe_epoch();
        assert_eq!(tracker.safe_epoch(), 1);
    }

    #[test]
    fn test_write_waits_for_readers() {
        let tracker = Arc::new(EpochTracker::new());
        // Rendezvous twice: once the reader is registered, and again to let it
        // leave. The previous sleep-based version failed whenever the reader
        // thread took longer than 10 ms to start.
        let barrier = Arc::new(Barrier::new(2));

        let reader_tracker = Arc::clone(&tracker);
        let reader_barrier = Arc::clone(&barrier);
        let reader = thread::spawn(move || {
            let _read = reader_tracker.begin_read(0);
            reader_barrier.wait();
            reader_barrier.wait();
        });

        barrier.wait();

        let _write = tracker.begin_write();
        assert_eq!(tracker.current_epoch(), 1);

        tracker.update_safe_epoch();
        assert_eq!(tracker.safe_epoch(), 0);

        barrier.wait();
        reader.join().unwrap();

        tracker.update_safe_epoch();
        assert_eq!(tracker.safe_epoch(), 1);
    }

    #[test]
    #[should_panic(expected = "Thread ID 256 exceeds MAX_THREADS")]
    fn test_thread_id_overflow() {
        let tracker = EpochTracker::new();
        let _read = tracker.begin_read(MAX_THREADS);
    }

    #[test]
    fn test_multiple_write_guards() {
        let tracker = EpochTracker::new();

        let write1 = tracker.begin_write();
        assert_eq!(tracker.current_epoch(), 1);

        let write2 = tracker.begin_write();
        assert_eq!(tracker.current_epoch(), 2);

        drop(write1);
        drop(write2);

        assert_eq!(tracker.safe_epoch(), 2);
    }

    #[test]
    fn test_mvcc_with_vertex_store() {
        use crate::engine::packed_coord::PackedCoord;
        use crate::engine::vertex_store::VertexStore;

        let tracker = Arc::new(EpochTracker::new());
        let store = Arc::new(std::sync::Mutex::new(VertexStore::new()));

        let writer_tracker = Arc::clone(&tracker);
        let writer_store = Arc::clone(&store);
        let writer = thread::spawn(move || {
            let _write = writer_tracker.begin_write();
            let mut store = writer_store.lock().unwrap();

            for i in 0..5 {
                store.allocate(PackedCoord::new(i, i), 0, 0);
            }
        });

        let reader_tracker = Arc::clone(&tracker);
        let reader_store = Arc::clone(&store);
        let reader = thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            let _read = reader_tracker.begin_read(0);
            let store = reader_store.lock().unwrap();

            store.len()
        });

        writer.join().unwrap();
        let observed_len = reader.join().unwrap();

        // The mutex makes the writer's batch atomic from the reader's view.
        assert!(observed_len == 0 || observed_len == 5);
    }
}