Skip to main content

cranpose_core/
snapshot_pinning.rs

//! Snapshot pinning system to prevent premature garbage collection of state records.
//!
//! This module implements a pinning table that tracks which snapshot IDs need to remain
//! alive. When a snapshot is created, it "pins" the lowest snapshot ID that it depends on,
//! preventing state records from those snapshots from being garbage collected.
//! The pinning table is thread-local: each thread tracks its own pins.
//!
//! Uses `SnapshotDoubleIndexHeap` for O(log N) pin/unpin and O(1) lowest queries.
//! Based on Jetpack Compose's pinning mechanism (Snapshot.kt:714-722, 1954).
9use crate::snapshot_double_index_heap::SnapshotDoubleIndexHeap;
10use crate::snapshot_double_index_heap::SnapshotDoubleIndexHeapDebugStats;
11use crate::snapshot_id_set::{SnapshotId, SnapshotIdSet};
12use std::cell::RefCell;
13
/// An opaque handle to a single pin in the thread-local pinning table.
///
/// The pin stays active until the handle is passed to [`release_pinning`]. The
/// handle is `Copy` and has no `Drop` impl, so simply dropping it does *not*
/// release the pin.
///
/// Internally the handle stores the heap handle returned when the pin was added,
/// offset by one so that `0` can serve as [`PinHandle::INVALID`]. Keeping the heap
/// handle allows the pin to be removed in O(log N) without knowing its snapshot ID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PinHandle(usize);

impl PinHandle {
    /// Sentinel handle that refers to no pin (`0` is reserved for this purpose).
    pub const INVALID: PinHandle = PinHandle(0);

    /// Returns `true` if this handle refers to a pin, i.e. it is not
    /// [`PinHandle::INVALID`].
    pub const fn is_valid(&self) -> bool {
        self.0 != 0
    }
}
29
30/// The global pinning table that tracks pinned snapshots using a min-heap.
31struct PinningTable {
32    /// Min-heap of pinned snapshot IDs for O(1) lowest queries
33    heap: SnapshotDoubleIndexHeap,
34}
35
36impl PinningTable {
37    fn new() -> Self {
38        Self {
39            heap: SnapshotDoubleIndexHeap::new(),
40        }
41    }
42
43    /// Add a pin for the given snapshot ID, returning a handle.
44    ///
45    /// Time complexity: O(log N)
46    fn add(&mut self, snapshot_id: SnapshotId) -> PinHandle {
47        let heap_handle = self.heap.add(snapshot_id);
48        // Heap handles start at 0, but we reserve 0 as INVALID for PinHandle
49        // So we offset by 1: heap handle 0 → PinHandle(1), etc.
50        PinHandle(heap_handle + 1)
51    }
52
53    /// Remove a pin by handle.
54    ///
55    /// Time complexity: O(log N)
56    fn remove(&mut self, handle: PinHandle) -> bool {
57        if !handle.is_valid() {
58            return false;
59        }
60
61        // Convert PinHandle back to heap handle (subtract 1)
62        let heap_handle = handle.0 - 1;
63
64        // Verify handle is within bounds
65        if heap_handle < usize::MAX {
66            self.heap.remove(heap_handle);
67            true
68        } else {
69            false
70        }
71    }
72
73    /// Get the lowest pinned snapshot ID, or None if nothing is pinned.
74    ///
75    /// Time complexity: O(1)
76    fn lowest_pinned(&self) -> Option<SnapshotId> {
77        if self.heap.is_empty() {
78            None
79        } else {
80            // Use 0 as default (will never be returned since heap is non-empty)
81            Some(self.heap.lowest_or_default(0))
82        }
83    }
84
85    /// Get the count of pins (for testing/debugging).
86    fn pin_count(&self) -> usize {
87        self.heap.len()
88    }
89
90    fn debug_stats(&self) -> SnapshotPinningDebugStats {
91        SnapshotPinningDebugStats {
92            pin_count: self.pin_count(),
93            lowest_pinned_snapshot: self.lowest_pinned(),
94            heap: self.heap.debug_stats(),
95        }
96    }
97}
98
99thread_local! {
100    // Global pinning table protected by a mutex.
101    static PINNING_TABLE: RefCell<PinningTable> = RefCell::new(PinningTable::new());
102}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
105pub struct SnapshotPinningDebugStats {
106    pub pin_count: usize,
107    pub lowest_pinned_snapshot: Option<SnapshotId>,
108    pub heap: SnapshotDoubleIndexHeapDebugStats,
109}
110
111/// Pin a snapshot and its invalid set, returning a handle.
112///
113/// This should be called when a snapshot is created to ensure that state records
114/// from the pinned snapshot and all its dependencies remain valid.
115///
116/// # Arguments
117/// * `snapshot_id` - The ID of the snapshot being created
118/// * `invalid` - The set of invalid snapshot IDs for this snapshot
119///
120/// # Returns
121/// A pin handle that should be released when the snapshot is disposed.
122///
123/// # Time Complexity
124/// O(log N) where N is the number of pinned snapshots
125pub fn track_pinning(snapshot_id: SnapshotId, invalid: &SnapshotIdSet) -> PinHandle {
126    // Pin the lowest snapshot ID that this snapshot depends on
127    let pinned_id = invalid.lowest(snapshot_id);
128
129    PINNING_TABLE.with(|cell| cell.borrow_mut().add(pinned_id))
130}
131
132/// Release a pinned snapshot.
133///
134/// # Arguments
135/// * `handle` - The pin handle returned by `track_pinning`
136///
137/// This must be called while holding the appropriate lock (sync).
138///
139/// # Time Complexity
140/// O(log N) where N is the number of pinned snapshots
141pub fn release_pinning(handle: PinHandle) {
142    if !handle.is_valid() {
143        return;
144    }
145
146    PINNING_TABLE.with(|cell| {
147        cell.borrow_mut().remove(handle);
148    });
149}
150
151/// Get the lowest currently pinned snapshot ID.
152///
153/// This is used to determine which state records can be safely garbage collected.
154/// Any state records from snapshots older than this ID are still potentially in use.
155///
156/// # Time Complexity
157/// O(1)
158pub fn lowest_pinned_snapshot() -> Option<SnapshotId> {
159    PINNING_TABLE.with(|cell| cell.borrow().lowest_pinned())
160}
161
162/// Get the current count of pinned snapshots (for testing).
163/// Get the current count of pinned snapshots (for testing/debugging).
164pub fn pin_count() -> usize {
165    PINNING_TABLE.with(|cell| cell.borrow().pin_count())
166}
167
168pub fn debug_snapshot_pinning_stats() -> SnapshotPinningDebugStats {
169    PINNING_TABLE.with(|cell| cell.borrow().debug_stats())
170}
171
/// Test-only helper: discards every pin held in the current thread's table by
/// swapping in a fresh, empty heap.
#[cfg(test)]
pub fn reset_pinning_table() {
    PINNING_TABLE.with(|cell| cell.borrow_mut().heap = SnapshotDoubleIndexHeap::new());
}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Resets this thread's pinning table so each test starts from a clean state.
    fn setup() {
        reset_pinning_table();
    }

    #[test]
    fn test_invalid_handle() {
        let handle = PinHandle::INVALID;
        assert!(!handle.is_valid());
        assert_eq!(handle.0, 0);
    }

    #[test]
    fn test_valid_handle() {
        setup();
        let invalid = SnapshotIdSet::new().set(10);
        let handle = track_pinning(20, &invalid);
        assert!(handle.is_valid());
        assert!(handle.0 > 0);

        release_pinning(handle);
        assert_eq!(pin_count(), 0);
    }

    #[test]
    fn test_track_and_release() {
        setup();

        let invalid = SnapshotIdSet::new().set(10);
        let handle = track_pinning(20, &invalid);

        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(10));

        release_pinning(handle);
        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_multiple_pins() {
        setup();

        let invalid1 = SnapshotIdSet::new().set(10);
        let handle1 = track_pinning(20, &invalid1);

        let invalid2 = SnapshotIdSet::new().set(5).set(15);
        let handle2 = track_pinning(30, &invalid2);

        assert_eq!(pin_count(), 2);
        assert_eq!(lowest_pinned_snapshot(), Some(5));

        // Release first pin
        release_pinning(handle1);
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(5));

        // Release second pin
        release_pinning(handle2);
        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_duplicate_pins() {
        setup();

        // Pin the same snapshot ID twice
        let invalid = SnapshotIdSet::new().set(10);
        let handle1 = track_pinning(20, &invalid);
        let handle2 = track_pinning(25, &invalid);

        assert_eq!(pin_count(), 2);
        assert_eq!(lowest_pinned_snapshot(), Some(10));

        // Releasing one doesn't unpin completely
        release_pinning(handle1);
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(10));

        // Releasing second one unpins completely
        release_pinning(handle2);
        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_pin_ordering() {
        setup();

        // Add pins in non-sorted order
        let invalid1 = SnapshotIdSet::new().set(30);
        let handle1 = track_pinning(40, &invalid1);

        let invalid2 = SnapshotIdSet::new().set(10);
        let handle2 = track_pinning(20, &invalid2);

        let invalid3 = SnapshotIdSet::new().set(20);
        let handle3 = track_pinning(30, &invalid3);

        // Lowest should still be 10
        assert_eq!(lowest_pinned_snapshot(), Some(10));

        // Releasing pins must keep the reported minimum up to date.
        release_pinning(handle2);
        assert_eq!(lowest_pinned_snapshot(), Some(20));

        release_pinning(handle3);
        assert_eq!(lowest_pinned_snapshot(), Some(30));

        release_pinning(handle1);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_release_invalid_handle() {
        setup();

        // Releasing an invalid handle should not crash
        release_pinning(PinHandle::INVALID);
        assert_eq!(pin_count(), 0);

        // ...and must not disturb pins that are already present.
        let invalid = SnapshotIdSet::new().set(4);
        let handle = track_pinning(9, &invalid);
        release_pinning(PinHandle::INVALID);
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(4));

        release_pinning(handle);
        assert_eq!(pin_count(), 0);
    }

    #[test]
    fn test_empty_invalid_set() {
        setup();

        // Empty invalid set means snapshot depends on nothing older
        let invalid = SnapshotIdSet::new();
        let handle = track_pinning(100, &invalid);

        // Should pin snapshot 100 itself (`lowest` falls back to its argument when the set is empty)
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(100));

        release_pinning(handle);
        assert_eq!(pin_count(), 0);
    }

    #[test]
    fn test_lowest_from_invalid_set() {
        setup();

        // Create an invalid set with multiple IDs
        let invalid = SnapshotIdSet::new().set(5).set(10).set(15).set(20);
        let handle = track_pinning(25, &invalid);

        // Should pin the lowest ID from the invalid set
        assert_eq!(lowest_pinned_snapshot(), Some(5));

        release_pinning(handle);
        assert_eq!(pin_count(), 0);
    }

    #[test]
    fn test_concurrent_snapshots() {
        setup();

        // Simulate multiple simultaneously-live snapshots
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let invalid = SnapshotIdSet::new().set(i * 10);
                track_pinning(i * 10 + 5, &invalid)
            })
            .collect();

        assert_eq!(pin_count(), 10);
        assert_eq!(lowest_pinned_snapshot(), Some(0));

        // Release all
        for handle in handles {
            release_pinning(handle);
        }

        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_heap_handle_based_removal() {
        setup();

        // Test that we can remove pins using just the handle, without knowing the snapshot ID
        let invalid1 = SnapshotIdSet::new().set(42);
        let invalid2 = SnapshotIdSet::new().set(17);
        let invalid3 = SnapshotIdSet::new().set(99);

        let h1 = track_pinning(50, &invalid1);
        let h2 = track_pinning(25, &invalid2);
        let h3 = track_pinning(100, &invalid3);

        assert_eq!(pin_count(), 3);
        assert_eq!(lowest_pinned_snapshot(), Some(17));

        // Remove middle value using only handle
        release_pinning(h1);
        assert_eq!(pin_count(), 2);
        assert_eq!(lowest_pinned_snapshot(), Some(17));

        // Remove lowest using only handle
        release_pinning(h2);
        assert_eq!(pin_count(), 1);
        assert_eq!(lowest_pinned_snapshot(), Some(99));

        release_pinning(h3);
        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }

    #[test]
    fn test_debug_stats_match_queries() {
        setup();

        let empty = debug_snapshot_pinning_stats();
        assert_eq!(empty.pin_count, 0);
        assert_eq!(empty.lowest_pinned_snapshot, None);

        let invalid = SnapshotIdSet::new().set(7);
        let handle = track_pinning(12, &invalid);

        let stats = debug_snapshot_pinning_stats();
        assert_eq!(stats.pin_count, pin_count());
        assert_eq!(stats.lowest_pinned_snapshot, lowest_pinned_snapshot());
        assert_eq!(stats.lowest_pinned_snapshot, Some(7));

        release_pinning(handle);
        assert_eq!(debug_snapshot_pinning_stats().pin_count, 0);
    }

    #[test]
    fn test_reset_clears_pins() {
        setup();

        let invalid = SnapshotIdSet::new().set(3);
        let _handle = track_pinning(8, &invalid);
        assert_eq!(pin_count(), 1);

        reset_pinning_table();
        assert_eq!(pin_count(), 0);
        assert_eq!(lowest_pinned_snapshot(), None);
    }
}