use std::collections::{HashMap, HashSet};
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
pub struct VaultState {
pub(in crate::app) cert_cache: HashMap<
String,
(
std::time::Instant,
crate::vault_ssh::CertStatus,
Option<std::time::SystemTime>,
),
>,
pub(in crate::app) cert_checks_in_flight: HashSet<String>,
pub(in crate::app) cleanup_warning: Option<String>,
pub(in crate::app) signing_cancel: Option<Arc<AtomicBool>>,
pub(in crate::app) sign_thread: Option<std::thread::JoinHandle<()>>,
pub(in crate::app) sign_in_flight: Arc<Mutex<HashSet<String>>>,
pub(in crate::app) pending_config_write: bool,
}
impl Default for VaultState {
    /// Starts with empty caches and sets, no signing run registered, no
    /// pending cleanup warning, and no pending config write.
    fn default() -> Self {
        // Fresh shared set; nothing else holds a clone of this `Arc` yet.
        let shared_sign_set = Arc::new(Mutex::new(HashSet::default()));
        Self {
            cert_cache: HashMap::default(),
            cert_checks_in_flight: HashSet::default(),
            cleanup_warning: Option::None,
            signing_cancel: Option::None,
            sign_thread: Option::None,
            sign_in_flight: shared_sign_set,
            pending_config_write: bool::default(),
        }
    }
}
/// Value stored per alias in the certificate cache.
///
/// Slots, in order: `.0` is the `Instant` at which the result was recorded
/// (`record_cert_check` uses `Instant::now()`), `.1` is the check's status,
/// and `.2` is the optional `SystemTime` the caller passed in as `mtime`.
type CertCacheEntry = (
    std::time::Instant,
    crate::vault_ssh::CertStatus,
    Option<std::time::SystemTime>,
);
impl VaultState {
pub fn cert_cache(&self) -> &HashMap<String, CertCacheEntry> {
&self.cert_cache
}
pub fn cert_entry(&self, alias: &str) -> Option<&CertCacheEntry> {
self.cert_cache.get(alias)
}
pub fn has_cert(&self, alias: &str) -> bool {
self.cert_cache.contains_key(alias)
}
pub fn insert_cert(&mut self, alias: String, entry: CertCacheEntry) {
self.cert_cache.insert(alias, entry);
}
pub fn remove_cert(&mut self, alias: &str) {
self.cert_cache.remove(alias);
}
pub fn clear_cert_cache(&mut self) {
self.cert_cache.clear();
}
pub fn is_cert_check_in_flight(&self, alias: &str) -> bool {
self.cert_checks_in_flight.contains(alias)
}
pub fn take_cleanup_warning(&mut self) -> Option<String> {
self.cleanup_warning.take()
}
pub fn signing_cancel(&self) -> Option<&Arc<AtomicBool>> {
self.signing_cancel.as_ref()
}
pub fn is_signing(&self) -> bool {
self.signing_cancel.is_some()
}
pub fn set_signing_cancel(&mut self, cancel: Arc<AtomicBool>) {
self.signing_cancel = Some(cancel);
}
pub fn clear_signing_cancel(&mut self) {
self.signing_cancel = None;
}
pub fn set_sign_thread(&mut self, handle: std::thread::JoinHandle<()>) {
self.sign_thread = Some(handle);
}
pub fn sign_in_flight(&self) -> &Arc<Mutex<HashSet<String>>> {
&self.sign_in_flight
}
pub fn pending_config_write(&self) -> bool {
self.pending_config_write
}
pub fn set_pending_config_write(&mut self, value: bool) {
self.pending_config_write = value;
}
pub(crate) fn mark_cert_check_started(&mut self, alias: String) {
self.cert_checks_in_flight.insert(alias);
}
pub(crate) fn record_cert_check(
&mut self,
alias: String,
status: crate::vault_ssh::CertStatus,
mtime: Option<std::time::SystemTime>,
) {
self.cert_checks_in_flight.remove(&alias);
self.cert_cache
.insert(alias, (std::time::Instant::now(), status, mtime));
}
pub(crate) fn cancel_signing_run(&mut self) -> Option<std::thread::JoinHandle<()>> {
if let Some(ref cancel) = self.signing_cancel {
cancel.store(true, std::sync::atomic::Ordering::Relaxed);
}
self.signing_cancel = None;
self.sign_thread.take()
}
pub(crate) fn finalize_signing_run(&mut self) -> Option<std::thread::JoinHandle<()>> {
self.signing_cancel = None;
self.sign_thread.take()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    #[test]
    fn mark_cert_check_started_inserts_alias() {
        let mut v = VaultState::default();
        v.mark_cert_check_started("web".to_string());
        assert!(v.cert_checks_in_flight.contains("web"));
    }

    #[test]
    fn mark_cert_check_started_is_idempotent() {
        let mut v = VaultState::default();
        v.mark_cert_check_started("web".to_string());
        v.mark_cert_check_started("web".to_string());
        assert_eq!(v.cert_checks_in_flight.len(), 1);
        assert!(v.cert_checks_in_flight.contains("web"));
    }

    #[test]
    fn record_cert_check_clears_in_flight_and_writes_cache() {
        let mut v = VaultState::default();
        v.mark_cert_check_started("web".to_string());
        v.record_cert_check(
            "web".to_string(),
            crate::vault_ssh::CertStatus::Missing,
            None,
        );
        assert!(!v.cert_checks_in_flight.contains("web"));
        assert!(v.cert_cache.contains_key("web"));
        let (_, status, mtime) = v.cert_cache.get("web").unwrap();
        assert!(matches!(status, crate::vault_ssh::CertStatus::Missing));
        assert!(mtime.is_none());
    }

    #[test]
    fn record_cert_check_caches_even_without_prior_start() {
        let mut v = VaultState::default();
        v.record_cert_check(
            "web".to_string(),
            crate::vault_ssh::CertStatus::Invalid("nope".to_string()),
            None,
        );
        assert!(v.cert_cache.contains_key("web"));
        assert!(v.cert_checks_in_flight.is_empty());
    }

    #[test]
    fn record_cert_check_overwrites_previous_entry() {
        let mut v = VaultState::default();
        v.record_cert_check(
            "web".to_string(),
            crate::vault_ssh::CertStatus::Missing,
            None,
        );
        v.record_cert_check(
            "web".to_string(),
            crate::vault_ssh::CertStatus::Invalid("nope".to_string()),
            None,
        );
        assert_eq!(v.cert_cache.len(), 1);
        let (_, status, _) = v.cert_cache.get("web").unwrap();
        assert!(matches!(status, crate::vault_ssh::CertStatus::Invalid(_)));
    }

    #[test]
    fn cancel_signing_run_with_no_active_run_returns_none() {
        let mut v = VaultState::default();
        let handle = v.cancel_signing_run();
        assert!(handle.is_none());
        assert!(v.signing_cancel.is_none());
        assert!(v.sign_thread.is_none());
    }

    #[test]
    fn cancel_signing_run_signals_cancel_and_clears_handle() {
        let mut v = VaultState::default();
        let cancel = Arc::new(AtomicBool::new(false));
        v.signing_cancel = Some(cancel.clone());
        v.sign_thread = Some(std::thread::spawn(|| {}));
        let handle = v
            .cancel_signing_run()
            .expect("returned thread handle for joining");
        let _ = handle.join();
        assert!(
            cancel.load(Ordering::Relaxed),
            "cancel must be signalled so a long-running worker exits"
        );
        assert!(v.signing_cancel.is_none());
        assert!(v.sign_thread.is_none());
    }

    #[test]
    fn cancel_signing_run_with_cancel_but_no_thread_signals_and_clears() {
        let mut v = VaultState::default();
        let cancel = Arc::new(AtomicBool::new(false));
        v.signing_cancel = Some(cancel.clone());
        assert!(v.cancel_signing_run().is_none());
        assert!(
            cancel.load(Ordering::Relaxed),
            "cancel must be signalled even when no thread handle was registered"
        );
        assert!(v.signing_cancel.is_none());
    }

    #[test]
    fn finalize_signing_run_does_not_signal_cancel() {
        let mut v = VaultState::default();
        let cancel = Arc::new(AtomicBool::new(false));
        v.signing_cancel = Some(cancel.clone());
        v.sign_thread = Some(std::thread::spawn(|| {}));
        let handle = v
            .finalize_signing_run()
            .expect("returned thread handle for joining");
        let _ = handle.join();
        assert!(
            !cancel.load(Ordering::Relaxed),
            "finalize must not signal cancel: a racing newer run's Arc could be hit"
        );
        assert!(v.signing_cancel.is_none());
        assert!(v.sign_thread.is_none());
    }

    #[test]
    fn finalize_signing_run_with_cancel_but_no_thread_clears_cancel() {
        let mut v = VaultState::default();
        let cancel = Arc::new(AtomicBool::new(false));
        v.signing_cancel = Some(cancel.clone());
        let handle = v.finalize_signing_run();
        assert!(handle.is_none());
        assert!(v.signing_cancel.is_none());
        assert!(!cancel.load(Ordering::Relaxed));
    }

    #[test]
    fn is_signing_follows_cancel_flag_lifecycle() {
        let mut v = VaultState::default();
        assert!(!v.is_signing());
        v.set_signing_cancel(Arc::new(AtomicBool::new(false)));
        assert!(v.is_signing());
        v.clear_signing_cancel();
        assert!(!v.is_signing());
    }

    #[test]
    fn take_cleanup_warning_yields_value_once() {
        let mut v = VaultState::default();
        v.cleanup_warning = Some("stale".to_string());
        assert_eq!(v.take_cleanup_warning().as_deref(), Some("stale"));
        assert!(v.take_cleanup_warning().is_none());
    }
}