cranpose_core/
snapshot_pinning.rs

//! Snapshot pinning system to prevent premature garbage collection of state records.
//!
//! This module implements a pinning table that tracks which snapshot IDs need to remain
//! alive. When a snapshot is created, it "pins" the lowest snapshot ID that it depends on,
//! preventing the state records that snapshot may still read from being garbage collected.
//!
//! The table is thread-local, so pins are tracked per thread. Pins are never released
//! automatically: every handle returned by `track_pinning` must be passed to
//! `release_pinning`.
//!
//! Uses `SnapshotDoubleIndexHeap` for O(log N) pin/unpin and O(1) lowest-ID queries.
//! Based on Jetpack Compose's pinning mechanism (Snapshot.kt:714-722, 1954).
9use crate::snapshot_double_index_heap::SnapshotDoubleIndexHeap;
10use crate::snapshot_id_set::{SnapshotId, SnapshotIdSet};
11use std::cell::RefCell;
12
/// An opaque handle identifying a single pin in the pinning table.
///
/// Handles are plain `Copy` values: dropping a handle does **not** release
/// the pin. Pass it to `release_pinning` exactly once when the owning
/// snapshot is disposed, otherwise the pinned snapshot ID stays pinned.
///
/// Internally the value is the heap handle handed out by the table's
/// `SnapshotDoubleIndexHeap` plus one, so that the raw value `0` can serve as
/// [`PinHandle::INVALID`] and heap removal stays O(log N).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PinHandle(usize);

impl PinHandle {
    /// Sentinel handle that refers to no pin (raw value `0`).
    ///
    /// Releasing it is a no-op.
    pub const INVALID: PinHandle = PinHandle(0);

    /// Returns `true` if this handle refers to a pin, i.e. it is not
    /// [`PinHandle::INVALID`].
    pub fn is_valid(&self) -> bool {
        self.0 != 0
    }
}
28
29/// The global pinning table that tracks pinned snapshots using a min-heap.
30struct PinningTable {
31    /// Min-heap of pinned snapshot IDs for O(1) lowest queries
32    heap: SnapshotDoubleIndexHeap,
33}
34
35impl PinningTable {
36    fn new() -> Self {
37        Self {
38            heap: SnapshotDoubleIndexHeap::new(),
39        }
40    }
41
42    /// Add a pin for the given snapshot ID, returning a handle.
43    ///
44    /// Time complexity: O(log N)
45    fn add(&mut self, snapshot_id: SnapshotId) -> PinHandle {
46        let heap_handle = self.heap.add(snapshot_id);
47        // Heap handles start at 0, but we reserve 0 as INVALID for PinHandle
48        // So we offset by 1: heap handle 0 → PinHandle(1), etc.
49        PinHandle(heap_handle + 1)
50    }
51
52    /// Remove a pin by handle.
53    ///
54    /// Time complexity: O(log N)
55    fn remove(&mut self, handle: PinHandle) -> bool {
56        if !handle.is_valid() {
57            return false;
58        }
59
60        // Convert PinHandle back to heap handle (subtract 1)
61        let heap_handle = handle.0 - 1;
62
63        // Verify handle is within bounds
64        if heap_handle < usize::MAX {
65            self.heap.remove(heap_handle);
66            true
67        } else {
68            false
69        }
70    }
71
72    /// Get the lowest pinned snapshot ID, or None if nothing is pinned.
73    ///
74    /// Time complexity: O(1)
75    fn lowest_pinned(&self) -> Option<SnapshotId> {
76        if self.heap.is_empty() {
77            None
78        } else {
79            // Use 0 as default (will never be returned since heap is non-empty)
80            Some(self.heap.lowest_or_default(0))
81        }
82    }
83
84    /// Get the count of pins (for testing).
85    #[cfg(test)]
86    fn pin_count(&self) -> usize {
87        self.heap.len()
88    }
89}
90
91thread_local! {
92    // Global pinning table protected by a mutex.
93    static PINNING_TABLE: RefCell<PinningTable> = RefCell::new(PinningTable::new());
94}
95
96/// Pin a snapshot and its invalid set, returning a handle.
97///
98/// This should be called when a snapshot is created to ensure that state records
99/// from the pinned snapshot and all its dependencies remain valid.
100///
101/// # Arguments
102/// * `snapshot_id` - The ID of the snapshot being created
103/// * `invalid` - The set of invalid snapshot IDs for this snapshot
104///
105/// # Returns
106/// A pin handle that should be released when the snapshot is disposed.
107///
108/// # Time Complexity
109/// O(log N) where N is the number of pinned snapshots
110pub fn track_pinning(snapshot_id: SnapshotId, invalid: &SnapshotIdSet) -> PinHandle {
111    // Pin the lowest snapshot ID that this snapshot depends on
112    let pinned_id = invalid.lowest(snapshot_id);
113
114    PINNING_TABLE.with(|cell| cell.borrow_mut().add(pinned_id))
115}
116
117/// Release a pinned snapshot.
118///
119/// # Arguments
120/// * `handle` - The pin handle returned by `track_pinning`
121///
122/// This must be called while holding the appropriate lock (sync).
123///
124/// # Time Complexity
125/// O(log N) where N is the number of pinned snapshots
126pub fn release_pinning(handle: PinHandle) {
127    if !handle.is_valid() {
128        return;
129    }
130
131    PINNING_TABLE.with(|cell| {
132        cell.borrow_mut().remove(handle);
133    });
134}
135
136/// Get the lowest currently pinned snapshot ID.
137///
138/// This is used to determine which state records can be safely garbage collected.
139/// Any state records from snapshots older than this ID are still potentially in use.
140///
141/// # Time Complexity
142/// O(1)
143pub fn lowest_pinned_snapshot() -> Option<SnapshotId> {
144    PINNING_TABLE.with(|cell| cell.borrow().lowest_pinned())
145}
146
/// Returns how many pins this thread's table currently holds (test-only).
#[cfg(test)]
pub fn pin_count() -> usize {
    PINNING_TABLE.with(|table| PinningTable::pin_count(&table.borrow()))
}
152
/// Discards every pin on this thread by swapping in a fresh, empty heap
/// (test-only).
#[cfg(test)]
pub fn reset_pinning_table() {
    PINNING_TABLE.with(|table| table.borrow_mut().heap = SnapshotDoubleIndexHeap::new());
}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Clears the thread-local pinning table so each test starts empty, even if
    /// the harness runs several tests on the same thread.
    fn setup() {
        reset_pinning_table();
    }

    #[test]
    fn test_invalid_handle() {
        let handle = PinHandle::INVALID;
        assert!(!handle.is_valid());
        assert_eq!(handle.0, 0);
    }

    #[test]
    fn test_valid_handle() {
        setup();
        let invalid = SnapshotIdSet::new().set(10);
        let handle = track_pinning(20, &invalid);
        assert!(handle.is_valid());
        assert!(handle.0 > 0);

        release_pinning(handle);
        assert_eq!(pin_count(), 0);
    }

    #[test]
    fn test_track_and_release() {
        setup();

        let invalid = SnapshotIdSet::new().set(10);
        let handle = track_pinning(20, &invalid);

        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(10));

        release_pinning(handle);
        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_multiple_pins() {
        setup();

        let invalid1 = SnapshotIdSet::new().set(10);
        let handle1 = track_pinning(20, &invalid1);

        let invalid2 = SnapshotIdSet::new().set(5).set(15);
        let handle2 = track_pinning(30, &invalid2);

        assert_eq!(pin_count(), 2);
        assert_eq!(lowest_pinned_snapshot(), Some(5));

        // Release first pin
        release_pinning(handle1);
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(5));

        // Release second pin
        release_pinning(handle2);
        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_duplicate_pins() {
        setup();

        // Pin the same snapshot ID twice
        let invalid = SnapshotIdSet::new().set(10);
        let handle1 = track_pinning(20, &invalid);
        let handle2 = track_pinning(25, &invalid);

        // Each pin gets its own handle even when the pinned ID is shared
        assert_ne!(handle1, handle2);
        assert_eq!(pin_count(), 2);
        assert_eq!(lowest_pinned_snapshot(), Some(10));

        // Releasing one doesn't unpin completely
        release_pinning(handle1);
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(10));

        // Releasing second one unpins completely
        release_pinning(handle2);
        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_pin_ordering() {
        setup();

        // Add pins in non-sorted order
        let invalid1 = SnapshotIdSet::new().set(30);
        let handle1 = track_pinning(40, &invalid1);

        let invalid2 = SnapshotIdSet::new().set(10);
        let handle2 = track_pinning(20, &invalid2);

        let invalid3 = SnapshotIdSet::new().set(20);
        let handle3 = track_pinning(30, &invalid3);

        // Lowest should be 10 regardless of insertion order
        assert_eq!(lowest_pinned_snapshot(), Some(10));

        // Releasing the lowest pin promotes the next-lowest one
        release_pinning(handle2);
        assert_eq!(lowest_pinned_snapshot(), Some(20));

        release_pinning(handle3);
        assert_eq!(lowest_pinned_snapshot(), Some(30));

        release_pinning(handle1);
        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_release_invalid_handle() {
        setup();

        // Releasing an invalid handle should not crash
        release_pinning(PinHandle::INVALID);
        assert_eq!(pin_count(), 0);

        // ...nor disturb pins that are currently held
        let invalid = SnapshotIdSet::new().set(10);
        let handle = track_pinning(20, &invalid);
        release_pinning(PinHandle::INVALID);
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(10));

        release_pinning(handle);
        assert_eq!(pin_count(), 0);
    }

    #[test]
    fn test_empty_invalid_set() {
        setup();

        // Empty invalid set means snapshot depends on nothing older
        let invalid = SnapshotIdSet::new();
        let handle = track_pinning(100, &invalid);

        // `lowest` falls back to the snapshot's own ID when the set is empty,
        // so snapshot 100 itself is pinned
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(100));

        release_pinning(handle);
        assert_eq!(pin_count(), 0);
    }

    #[test]
    fn test_lowest_from_invalid_set() {
        setup();

        // Create an invalid set with multiple IDs
        let invalid = SnapshotIdSet::new().set(5).set(10).set(15).set(20);
        let handle = track_pinning(25, &invalid);

        // Should pin the lowest ID from the invalid set
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(5));

        release_pinning(handle);
        assert_eq!(pin_count(), 0);
    }

    #[test]
    fn test_concurrent_snapshots() {
        setup();

        // Simulate multiple simultaneously-live snapshots
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let invalid = SnapshotIdSet::new().set(i * 10);
                track_pinning(i * 10 + 5, &invalid)
            })
            .collect();

        assert_eq!(pin_count(), 10);
        assert_eq!(lowest_pinned_snapshot(), Some(0));

        // Release all
        for handle in handles {
            release_pinning(handle);
        }

        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_heap_handle_based_removal() {
        setup();

        // Test that we can remove pins using just the handle, without knowing the snapshot ID
        let invalid1 = SnapshotIdSet::new().set(42);
        let invalid2 = SnapshotIdSet::new().set(17);
        let invalid3 = SnapshotIdSet::new().set(99);

        let h1 = track_pinning(50, &invalid1);
        let h2 = track_pinning(25, &invalid2);
        let h3 = track_pinning(100, &invalid3);

        assert_eq!(pin_count(), 3);
        assert_eq!(lowest_pinned_snapshot(), Some(17));

        // Remove middle value using only handle
        release_pinning(h1);
        assert_eq!(pin_count(), 2);
        assert_eq!(lowest_pinned_snapshot(), Some(17));

        // Remove lowest using only handle
        release_pinning(h2);
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(99));

        release_pinning(h3);
        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }
}
356        assert!(pin_count() == 0);
357    }
358}