Skip to main content

veilid_tools/
startup_lock.rs

1use super::*;
2
3#[derive(ThisError, Debug, Copy, Clone, PartialEq, Eq)]
4#[error("Already started")]
5pub struct StartupLockAlreadyStartedError;
6
7#[derive(ThisError, Debug, Copy, Clone, PartialEq, Eq)]
8#[error("Already shut down")]
9pub struct StartupLockAlreadyShutDownError;
10
11#[derive(ThisError, Debug, Copy, Clone, PartialEq, Eq)]
12#[error("Not started")]
13pub struct StartupLockNotStartedError;
14
15/// RAII-style lock for startup and shutdown operations
16/// Must call 'success()' on this lock to report a successful startup or shutdown
17/// Dropping this lock without calling 'success()' first indicates a failed
18/// startup or shutdown operation
19#[derive(Debug)]
20pub struct StartupLockGuard<'a> {
21    guard: AsyncRwLockWriteGuard<'a, bool>,
22    success_value: bool,
23}
24
25impl StartupLockGuard<'_> {
26    /// Call this function at the end of a successful startup or shutdown
27    /// operation to switch the state of the StartupLock.
28    pub fn success(mut self) {
29        *self.guard = self.success_value;
30    }
31}
32
33/// RAII-style lock for entry operations on a started-up region of code.
34#[derive(Debug)]
35pub struct StartupLockEnterGuard<'a> {
36    _guard: AsyncRwLockReadGuard<'a, bool>,
37    #[cfg(feature = "debug-locks")]
38    id: usize,
39    #[cfg(feature = "debug-locks")]
40    active_guards: Arc<Mutex<HashMap<usize, backtrace::Backtrace>>>,
41}
42
/// Removes this guard's entry from the debug registry of live enter guards, so
/// only guards that are still alive are reported by a stalled shutdown.
#[cfg(feature = "debug-locks")]
impl Drop for StartupLockEnterGuard<'_> {
    fn drop(&mut self) {
        self.active_guards.lock().remove(&self.id);
    }
}
49
50/// RAII-style lock for entry operations on a started-up region of code.
51#[derive(Debug)]
52pub struct StartupLockEnterGuardArc {
53    _guard: AsyncRwLockReadGuardArc<bool>,
54    #[cfg(feature = "debug-locks")]
55    id: usize,
56    #[cfg(feature = "debug-locks")]
57    active_guards: Arc<Mutex<HashMap<usize, backtrace::Backtrace>>>,
58}
59
/// Unregisters this owned guard from the debug registry of live enter guards.
#[cfg(feature = "debug-locks")]
impl Drop for StartupLockEnterGuardArc {
    fn drop(&mut self) {
        let mut registry = self.active_guards.lock();
        registry.remove(&self.id);
    }
}
66
/// Process-wide counter handing out the ids under which enter guards are
/// registered in a `StartupLock`'s `active_guards` map (debug-locks builds only).
#[cfg(feature = "debug-locks")]
static GUARD_ID: AtomicUsize = AtomicUsize::new(0_usize);
69
70/// Synchronization mechanism that tracks the startup and shutdown of a region of code.
71/// Guarantees that some code can only be started up once and shut down only if it is
72/// already started.
73/// Also tracks if the code is in-use and will wait for all 'entered' code to finish
74/// before shutting down. Once a shutdown is requested, future calls to 'enter' will
75/// fail, ensuring that nothing is 'entered' at the time of shutdown. This allows an
76/// asynchronous shutdown to wait for operations to finish before proceeding.
77#[derive(Debug)]
78pub struct StartupLock {
79    startup_state: Arc<AsyncRwLock<bool>>,
80    stop_source: Mutex<Option<StopSource>>,
81    #[cfg(feature = "debug-locks")]
82    active_guards: Arc<Mutex<HashMap<usize, backtrace::Backtrace>>>,
83}
84
85impl StartupLock {
86    #[must_use]
87    pub fn new() -> Self {
88        Self {
89            startup_state: Arc::new(AsyncRwLock::new(false)),
90            stop_source: Mutex::new(None),
91            #[cfg(feature = "debug-locks")]
92            active_guards: Arc::new(Mutex::new(HashMap::new())),
93        }
94    }
95
96    /// Start up if things are not already started up
97    /// One must call 'success()' on the returned startup lock guard if startup was successful
98    /// otherwise the startup lock will not shift to the 'started' state.
99    pub fn startup(&self) -> Result<StartupLockGuard<'_>, StartupLockAlreadyStartedError> {
100        let guard = self
101            .startup_state
102            .try_write()
103            .ok_or(StartupLockAlreadyStartedError)?;
104        if *guard {
105            return Err(StartupLockAlreadyStartedError);
106        }
107        *self.stop_source.lock() = Some(StopSource::new());
108
109        Ok(StartupLockGuard {
110            guard,
111            success_value: true,
112        })
113    }
114
115    /// Get a stop token for this lock
116    /// One can wait on this to timeout operations when a shutdown is requested
117    pub fn stop_token(&self) -> Option<StopToken> {
118        self.stop_source.lock().as_ref().map(|ss| ss.token())
119    }
120
121    /// Check if this StartupLock is currently in a started state
122    /// Returns false is the state is in transition
123    pub fn is_started(&self) -> bool {
124        let Some(guard) = self.startup_state.try_read() else {
125            return false;
126        };
127        *guard
128    }
129
130    /// Check if this StartupLock is currently in a shut down state
131    /// Returns false is the state is in transition
132    pub fn is_shut_down(&self) -> bool {
133        let Some(guard) = self.startup_state.try_read() else {
134            return false;
135        };
136        !*guard
137    }
138
139    /// Wait for all 'entered' operations to finish before shutting down
140    /// One must call 'success()' on the returned startup lock guard if shutdown was successful
141    /// otherwise the startup lock will not shift to the 'stopped' state.
142    pub async fn shutdown(&self) -> Result<StartupLockGuard<'_>, StartupLockAlreadyShutDownError> {
143        // Drop the stop source to ensure we can detect shutdown has been requested
144        *self.stop_source.lock() = None;
145
146        cfg_if! {
147            if #[cfg(feature = "debug-locks")] {
148                let guard = match timeout(30000, self.startup_state.write()).await {
149                    Ok(v) => v,
150                    Err(_) => {
151                        eprintln!("active guards: {:#?}", self.active_guards.lock().values().collect::<Vec<_>>());
152                        panic!("shutdown deadlock");
153                    }
154                };
155            } else {
156                let guard = self.startup_state.write().await;
157            }
158        }
159        if !*guard {
160            return Err(StartupLockAlreadyShutDownError);
161        }
162        Ok(StartupLockGuard {
163            guard,
164            success_value: false,
165        })
166    }
167
168    /// Enter an operation in a started-up module.
169    /// If this module has not yet started up or is in the process of startup or shutdown
170    /// this will fail.
171    pub fn enter(&self) -> Result<StartupLockEnterGuard<'_>, StartupLockNotStartedError> {
172        let guard = self
173            .startup_state
174            .try_read()
175            .ok_or(StartupLockNotStartedError)?;
176        if !*guard {
177            return Err(StartupLockNotStartedError);
178        }
179        let out = StartupLockEnterGuard {
180            _guard: guard,
181            #[cfg(feature = "debug-locks")]
182            id: GUARD_ID.fetch_add(1, Ordering::AcqRel),
183            #[cfg(feature = "debug-locks")]
184            active_guards: self.active_guards.clone(),
185        };
186
187        #[cfg(feature = "debug-locks")]
188        self.active_guards
189            .lock()
190            .insert(out.id, backtrace::Backtrace::new());
191
192        Ok(out)
193    }
194
195    /// Enter an operation in a started-up module, using an owned lock.
196    /// If this module has not yet started up or is in the process of startup or shutdown
197    /// this will fail.
198    pub fn enter_arc(&self) -> Result<StartupLockEnterGuardArc, StartupLockNotStartedError> {
199        let guard = self
200            .startup_state
201            .try_read_arc()
202            .ok_or(StartupLockNotStartedError)?;
203        if !*guard {
204            return Err(StartupLockNotStartedError);
205        }
206        let out = StartupLockEnterGuardArc {
207            _guard: guard,
208            #[cfg(feature = "debug-locks")]
209            id: GUARD_ID.fetch_add(1, Ordering::AcqRel),
210            #[cfg(feature = "debug-locks")]
211            active_guards: self.active_guards.clone(),
212        };
213
214        #[cfg(feature = "debug-locks")]
215        self.active_guards
216            .lock()
217            .insert(out.id, backtrace::Backtrace::new());
218
219        Ok(out)
220    }
221}
222
223impl Default for StartupLock {
224    fn default() -> Self {
225        Self::new()
226    }
227}