Skip to main content

oxi/extensions/
stale.rs

1//! Stale detection for extension contexts.
2//!
3//! When a session switches, forks, or reloads, all extension contexts
4//! from the previous session become stale. The guard uses an atomic
5//! generation counter to detect this efficiently.
6
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
/// Handle to a shared, atomically updated generation counter.
///
/// Intended usage is one token per session: when the session changes,
/// call `invalidate()` to advance the generation. Cloning the token
/// clones the inner `Arc`, so every clone refers to the same counter.
#[derive(Clone, Debug)]
pub struct InvalidationToken {
    /// Reference-counted generation counter shared with guards and clones.
    inner: Arc<AtomicU64>,
}
17
18impl InvalidationToken {
19    /// Create a new invalidation token (generation 0).
20    pub fn new() -> Self {
21        Self {
22            inner: Arc::new(AtomicU64::new(0)),
23        }
24    }
25
26    /// Bump the generation, invalidating all existing guards.
27    pub fn invalidate(&self) {
28        self.inner.fetch_add(1, Ordering::Release);
29    }
30
31    /// Create a guard that tracks the current generation.
32    pub fn guard(&self) -> ContextGuard {
33        ContextGuard {
34            generation: self.inner.load(Ordering::Acquire),
35            token: self.inner.clone(),
36        }
37    }
38
39    /// Current generation value.
40    pub fn generation(&self) -> u64 {
41        self.inner.load(Ordering::Acquire)
42    }
43}
44
45impl Default for InvalidationToken {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
/// Snapshot of a generation paired with a handle to the live counter.
///
/// The guard can later compare its stored generation against the counter
/// to tell whether the context it was created for is still current.
#[derive(Clone, Debug)]
pub struct ContextGuard {
    /// Generation value stored when the guard was built.
    generation: u64,
    /// Shared counter the stored generation is compared against.
    token: Arc<AtomicU64>,
}
57
58impl ContextGuard {
59    /// Check if this guard's context is still valid.
60    pub fn is_valid(&self) -> bool {
61        self.generation == self.token.load(Ordering::Acquire)
62    }
63
64    /// Assert validity, returning an error if stale.
65    pub fn check_valid(&self) -> anyhow::Result<()> {
66        if self.is_valid() {
67            Ok(())
68        } else {
69            Err(anyhow::anyhow!(
70                "Extension context is stale — session has been switched or reloaded"
71            ))
72        }
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// A guard taken from a fresh token passes both validity checks.
    #[test]
    fn test_guard_starts_valid() {
        let source = InvalidationToken::new();
        let fresh = source.guard();

        assert!(fresh.is_valid());
        assert!(fresh.check_valid().is_ok());
    }

    /// Bumping the generation makes an existing guard fail both checks.
    #[test]
    fn test_invalidation_invalidates_guard() {
        let source = InvalidationToken::new();
        let old = source.guard();
        assert!(old.is_valid());

        source.invalidate();

        assert!(!old.is_valid());
        assert!(old.check_valid().is_err());
    }

    /// A guard created after an invalidation tracks the new generation.
    #[test]
    fn test_new_guard_after_invalidation_is_valid() {
        let source = InvalidationToken::new();
        source.invalidate();

        let replacement = source.guard();
        assert!(replacement.is_valid());
    }

    /// Only the guard from the latest generation remains valid.
    #[test]
    fn test_multiple_invalidations() {
        let source = InvalidationToken::new();

        let first = source.guard();
        source.invalidate();
        let second = source.guard();
        source.invalidate();
        let third = source.guard();

        assert!(!first.is_valid());
        assert!(!second.is_valid());
        assert!(third.is_valid());
    }
}