1use std::collections::{HashMap, HashSet};
2use std::sync::atomic::AtomicBool;
3use std::sync::{Arc, Mutex};
4
5pub struct VaultState {
7 pub(in crate::app) cert_cache: HashMap<
10 String,
11 (
12 std::time::Instant,
13 crate::vault_ssh::CertStatus,
14 Option<std::time::SystemTime>,
15 ),
16 >,
17 pub(in crate::app) cert_checks_in_flight: HashSet<String>,
19 pub(in crate::app) cleanup_warning: Option<String>,
21 pub(in crate::app) signing_cancel: Option<Arc<AtomicBool>>,
23 pub(in crate::app) sign_thread: Option<std::thread::JoinHandle<()>>,
25 pub(in crate::app) sign_in_flight: Arc<Mutex<HashSet<String>>>,
27 pub(in crate::app) pending_config_write: bool,
29}
30
31impl Default for VaultState {
32 fn default() -> Self {
33 Self {
34 cert_cache: HashMap::new(),
35 cert_checks_in_flight: HashSet::new(),
36 cleanup_warning: None,
37 signing_cancel: None,
38 sign_thread: None,
39 sign_in_flight: Arc::new(Mutex::new(HashSet::new())),
40 pending_config_write: false,
41 }
42 }
43}
44
45type CertCacheEntry = (
46 std::time::Instant,
47 crate::vault_ssh::CertStatus,
48 Option<std::time::SystemTime>,
49);
50
51impl VaultState {
52 pub fn cert_cache(&self) -> &HashMap<String, CertCacheEntry> {
53 &self.cert_cache
54 }
55
56 pub fn cert_entry(&self, alias: &str) -> Option<&CertCacheEntry> {
57 self.cert_cache.get(alias)
58 }
59
60 pub fn has_cert(&self, alias: &str) -> bool {
61 self.cert_cache.contains_key(alias)
62 }
63
64 pub fn insert_cert(&mut self, alias: String, entry: CertCacheEntry) {
65 self.cert_cache.insert(alias, entry);
66 }
67
68 pub fn remove_cert(&mut self, alias: &str) {
69 self.cert_cache.remove(alias);
70 }
71
72 pub fn clear_cert_cache(&mut self) {
73 self.cert_cache.clear();
74 }
75
76 pub fn is_cert_check_in_flight(&self, alias: &str) -> bool {
77 self.cert_checks_in_flight.contains(alias)
78 }
79
80 pub fn take_cleanup_warning(&mut self) -> Option<String> {
81 self.cleanup_warning.take()
82 }
83
84 pub fn signing_cancel(&self) -> Option<&Arc<AtomicBool>> {
85 self.signing_cancel.as_ref()
86 }
87
88 pub fn is_signing(&self) -> bool {
89 self.signing_cancel.is_some()
90 }
91
92 pub fn set_signing_cancel(&mut self, cancel: Arc<AtomicBool>) {
93 self.signing_cancel = Some(cancel);
94 }
95
96 pub fn clear_signing_cancel(&mut self) {
97 self.signing_cancel = None;
98 }
99
100 pub fn set_sign_thread(&mut self, handle: std::thread::JoinHandle<()>) {
101 self.sign_thread = Some(handle);
102 }
103
104 pub fn sign_in_flight(&self) -> &Arc<Mutex<HashSet<String>>> {
105 &self.sign_in_flight
106 }
107
108 pub fn pending_config_write(&self) -> bool {
109 self.pending_config_write
110 }
111
112 pub fn set_pending_config_write(&mut self, value: bool) {
113 self.pending_config_write = value;
114 }
115
116 pub(crate) fn mark_cert_check_started(&mut self, alias: String) {
120 self.cert_checks_in_flight.insert(alias);
121 }
122
123 pub(crate) fn record_cert_check(
128 &mut self,
129 alias: String,
130 status: crate::vault_ssh::CertStatus,
131 mtime: Option<std::time::SystemTime>,
132 ) {
133 self.cert_checks_in_flight.remove(&alias);
134 self.cert_cache
135 .insert(alias, (std::time::Instant::now(), status, mtime));
136 }
137
138 pub(crate) fn cancel_signing_run(&mut self) -> Option<std::thread::JoinHandle<()>> {
143 if let Some(ref cancel) = self.signing_cancel {
144 cancel.store(true, std::sync::atomic::Ordering::Relaxed);
145 }
146 self.signing_cancel = None;
147 self.sign_thread.take()
148 }
149
150 pub(crate) fn finalize_signing_run(&mut self) -> Option<std::thread::JoinHandle<()>> {
157 self.signing_cancel = None;
158 self.sign_thread.take()
159 }
160
161 pub fn prune_orphans(&mut self, valid_aliases: &HashSet<&str>) {
167 let pre_cert = self.cert_cache.len();
168 let pre_checks = self.cert_checks_in_flight.len();
169 self.cert_cache
170 .retain(|alias, _| valid_aliases.contains(alias.as_str()));
171 self.cert_checks_in_flight
172 .retain(|alias| valid_aliases.contains(alias.as_str()));
173 let dropped_cert = pre_cert.saturating_sub(self.cert_cache.len());
174 if dropped_cert > 0 {
175 log::debug!(
176 "[purple] reload_hosts: dropped {dropped_cert} orphan cert_cache entrie(s)"
177 );
178 }
179 let dropped_checks = pre_checks.saturating_sub(self.cert_checks_in_flight.len());
180 if dropped_checks > 0 {
181 log::debug!(
182 "[purple] reload_hosts: dropped {dropped_checks} orphan cert_checks_in_flight alias(es)"
183 );
184 }
185
186 let mut sign = match self.sign_in_flight.lock() {
187 Ok(g) => g,
188 Err(p) => p.into_inner(),
189 };
190 let pre = sign.len();
191 sign.retain(|alias| valid_aliases.contains(alias.as_str()));
192 let dropped = pre.saturating_sub(sign.len());
193 if dropped > 0 {
194 log::debug!("[purple] reload_hosts: dropped {dropped} orphan sign_in_flight alias(es)");
195 }
196 }
197
198 pub fn migrate_alias(&mut self, old: &str, new: &str) {
205 if old == new {
206 return;
207 }
208 if self.cert_checks_in_flight.remove(old) {
209 self.cert_checks_in_flight.insert(new.to_string());
210 }
211 let mut sign = match self.sign_in_flight.lock() {
212 Ok(g) => g,
213 Err(p) => p.into_inner(),
214 };
215 if sign.remove(old) {
216 sign.insert(new.to_string());
217 }
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Builds a `Missing` cache entry stamped with the current instant.
    fn missing_entry() -> CertCacheEntry {
        (
            std::time::Instant::now(),
            crate::vault_ssh::CertStatus::Missing,
            None,
        )
    }

    #[test]
    fn mark_cert_check_started_inserts_alias() {
        let mut v = VaultState::default();
        v.mark_cert_check_started("web".to_string());
        assert!(v.cert_checks_in_flight.contains("web"));
    }

    #[test]
    fn mark_cert_check_started_is_idempotent() {
        let mut v = VaultState::default();
        // Marking the same alias twice must not create duplicate entries.
        v.mark_cert_check_started("web".to_string());
        v.mark_cert_check_started("web".to_string());
        assert_eq!(v.cert_checks_in_flight.len(), 1);
        assert!(v.cert_checks_in_flight.contains("web"));
    }

    #[test]
    fn record_cert_check_clears_in_flight_and_writes_cache() {
        let mut v = VaultState::default();
        v.mark_cert_check_started("web".to_string());
        v.record_cert_check(
            "web".to_string(),
            crate::vault_ssh::CertStatus::Missing,
            None,
        );
        assert!(!v.cert_checks_in_flight.contains("web"));
        assert!(v.cert_cache.contains_key("web"));
        let (_, status, mtime) = v.cert_cache.get("web").unwrap();
        assert!(matches!(status, crate::vault_ssh::CertStatus::Missing));
        assert!(mtime.is_none());
    }

    #[test]
    fn record_cert_check_caches_even_without_prior_start() {
        let mut v = VaultState::default();
        // A result arriving without a matching start must still be cached.
        v.record_cert_check(
            "web".to_string(),
            crate::vault_ssh::CertStatus::Invalid("nope".to_string()),
            None,
        );
        assert!(v.cert_cache.contains_key("web"));
        assert!(v.cert_checks_in_flight.is_empty());
    }

    #[test]
    fn cancel_signing_run_with_no_active_run_returns_none() {
        let mut v = VaultState::default();
        let handle = v.cancel_signing_run();
        assert!(handle.is_none());
        assert!(v.signing_cancel.is_none());
        assert!(v.sign_thread.is_none());
    }

    #[test]
    fn cancel_signing_run_signals_cancel_and_clears_handle() {
        let mut v = VaultState::default();
        // Simulate an active run: a registered cancel flag plus a worker handle.
        let cancel = Arc::new(AtomicBool::new(false));
        v.signing_cancel = Some(cancel.clone());
        v.sign_thread = Some(std::thread::spawn(|| {}));

        let handle = v
            .cancel_signing_run()
            .expect("returned thread handle for joining");
        let _ = handle.join();

        assert!(
            cancel.load(Ordering::Relaxed),
            "cancel must be signalled so a long-running worker exits"
        );
        assert!(v.signing_cancel.is_none());
        assert!(v.sign_thread.is_none());
    }

    #[test]
    fn finalize_signing_run_does_not_signal_cancel() {
        let mut v = VaultState::default();
        // The flag held here could belong to a newer run, so finalize must
        // clear it without storing `true`.
        let cancel = Arc::new(AtomicBool::new(false));
        v.signing_cancel = Some(cancel.clone());
        v.sign_thread = Some(std::thread::spawn(|| {}));

        let handle = v
            .finalize_signing_run()
            .expect("returned thread handle for joining");
        let _ = handle.join();

        assert!(
            !cancel.load(Ordering::Relaxed),
            "finalize must not signal cancel: a racing newer run's Arc could be hit"
        );
        assert!(v.signing_cancel.is_none());
        assert!(v.sign_thread.is_none());
    }

    #[test]
    fn finalize_signing_run_with_cancel_but_no_thread_clears_cancel() {
        let mut v = VaultState::default();
        // Cancel flag registered, but no worker handle stored.
        let cancel = Arc::new(AtomicBool::new(false));
        v.signing_cancel = Some(cancel.clone());

        let handle = v.finalize_signing_run();
        assert!(handle.is_none());
        assert!(v.signing_cancel.is_none());
        assert!(!cancel.load(Ordering::Relaxed));
    }

    #[test]
    fn prune_orphans_drops_unknown_aliases_across_cert_and_sign_state() {
        let mut v = VaultState::default();
        v.cert_cache.insert("keep".to_string(), missing_entry());
        v.cert_cache.insert("drop".to_string(), missing_entry());
        v.cert_checks_in_flight.insert("keep".to_string());
        v.cert_checks_in_flight.insert("drop".to_string());
        v.sign_in_flight.lock().unwrap().insert("keep".to_string());
        v.sign_in_flight.lock().unwrap().insert("drop".to_string());

        let valid: HashSet<&str> = ["keep"].into_iter().collect();
        v.prune_orphans(&valid);

        assert!(v.cert_cache.contains_key("keep"));
        assert!(!v.cert_cache.contains_key("drop"));
        assert!(v.cert_checks_in_flight.contains("keep"));
        assert!(!v.cert_checks_in_flight.contains("drop"));
        let sign = v.sign_in_flight.lock().unwrap();
        assert!(sign.contains("keep"));
        assert!(!sign.contains("drop"));
    }

    #[test]
    fn prune_orphans_recovers_from_poisoned_sign_lock() {
        let mut v = VaultState::default();
        v.sign_in_flight.lock().unwrap().insert("keep".to_string());
        v.sign_in_flight.lock().unwrap().insert("drop".to_string());

        // Poison the mutex by panicking while holding its guard.
        let shared = Arc::clone(&v.sign_in_flight);
        let _ = std::thread::spawn(move || {
            let _guard = shared.lock().unwrap();
            panic!("poison sign_in_flight for the test");
        })
        .join();
        assert!(v.sign_in_flight.is_poisoned());

        let valid: HashSet<&str> = ["keep"].into_iter().collect();
        v.prune_orphans(&valid);

        let sign = v
            .sign_in_flight
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        assert!(sign.contains("keep"));
        assert!(!sign.contains("drop"));
    }

    #[test]
    fn migrate_alias_moves_checks_and_sign_but_not_cert_cache() {
        let mut v = VaultState::default();
        v.cert_cache.insert("old".to_string(), missing_entry());
        v.cert_checks_in_flight.insert("old".to_string());
        v.sign_in_flight.lock().unwrap().insert("old".to_string());

        v.migrate_alias("old", "new");

        // The cert cache deliberately stays keyed by the old alias.
        assert!(v.cert_cache.contains_key("old"));
        assert!(!v.cert_cache.contains_key("new"));

        assert!(!v.cert_checks_in_flight.contains("old"));
        assert!(v.cert_checks_in_flight.contains("new"));

        let sign = v.sign_in_flight.lock().unwrap();
        assert!(!sign.contains("old"));
        assert!(sign.contains("new"));
    }

    #[test]
    fn migrate_alias_same_alias_is_noop() {
        let mut v = VaultState::default();
        v.cert_checks_in_flight.insert("web".to_string());
        v.sign_in_flight.lock().unwrap().insert("web".to_string());

        v.migrate_alias("web", "web");

        assert!(v.cert_checks_in_flight.contains("web"));
        assert_eq!(v.cert_checks_in_flight.len(), 1);
        assert!(v.sign_in_flight.lock().unwrap().contains("web"));
    }

    #[test]
    fn migrate_alias_without_in_flight_state_adds_nothing() {
        let mut v = VaultState::default();

        v.migrate_alias("old", "new");

        assert!(v.cert_checks_in_flight.is_empty());
        assert!(v.sign_in_flight.lock().unwrap().is_empty());
    }
}