1use std::collections::HashMap;
5use std::sync::RwLock;
6use zeroize::Zeroizing;
7
8#[cfg(feature = "native")]
9use std::sync::Arc;
10#[cfg(feature = "native")]
11use tokio::sync::OwnedRwLockWriteGuard;
12
13pub struct VaultKeyStore {
14 keys: RwLock<HashMap<String, Zeroizing<Vec<u8>>>>,
15 #[cfg(feature = "native")]
16 fences: RwLock<HashMap<String, Arc<tokio::sync::RwLock<()>>>>,
17}
18
19impl Default for VaultKeyStore {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl VaultKeyStore {
26 #[must_use]
27 pub fn new() -> Self {
28 Self {
29 keys: RwLock::new(HashMap::new()),
30 #[cfg(feature = "native")]
31 fences: RwLock::new(HashMap::new()),
32 }
33 }
34
35 pub fn set(&self, canonical_id: &str, key: Zeroizing<Vec<u8>>) {
36 if let Ok(mut map) = self.keys.write() {
37 map.insert(canonical_id.to_string(), key);
38 }
39 }
40
41 pub fn remove(&self, canonical_id: &str) {
42 if let Ok(mut map) = self.keys.write() {
43 map.remove(canonical_id);
44 }
45 #[cfg(feature = "native")]
46 if let Ok(mut map) = self.fences.write() {
47 map.remove(canonical_id);
48 }
49 }
50
51 #[must_use]
52 pub fn get(&self, canonical_id: &str) -> Option<Zeroizing<Vec<u8>>> {
53 let map = self.keys.read().ok()?;
54 map.get(canonical_id).cloned()
55 }
56
57 #[cfg(feature = "native")]
58 pub async fn acquire_fence(&self, canonical_id: &str) -> OwnedRwLockWriteGuard<()> {
59 let lock = {
60 let mut map = self
61 .fences
62 .write()
63 .unwrap_or_else(std::sync::PoisonError::into_inner);
64 map.entry(canonical_id.to_string())
65 .or_insert_with(|| Arc::new(tokio::sync::RwLock::new(())))
66 .clone()
67 };
68 lock.write_owned().await
69 }
70
71 #[cfg(feature = "native")]
72 pub async fn read_fence(&self, canonical_id: &str) {
73 let lock = {
74 let map = self
75 .fences
76 .read()
77 .unwrap_or_else(std::sync::PoisonError::into_inner);
78 map.get(canonical_id).cloned()
79 };
80 if let Some(lock) = lock {
81 let _guard = lock.read().await;
82 }
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 32-byte key where every byte equals `fill`.
    fn key_filled_with(fill: u8) -> Zeroizing<Vec<u8>> {
        Zeroizing::new(vec![fill; 32])
    }

    #[test]
    fn set_and_get_roundtrip() {
        let vault = VaultKeyStore::new();
        vault.set("user-1", key_filled_with(1));
        let fetched = vault.get("user-1").expect("key should exist");
        assert_eq!(*fetched, vec![1u8; 32]);
    }

    #[test]
    fn remove_clears_key() {
        let vault = VaultKeyStore::new();
        vault.set("user-2", key_filled_with(2));
        assert!(vault.get("user-2").is_some());

        vault.remove("user-2");
        assert!(vault.get("user-2").is_none());
    }

    #[test]
    fn get_nonexistent_returns_none() {
        assert!(VaultKeyStore::new().get("nonexistent").is_none());
    }

    #[cfg(feature = "native")]
    #[tokio::test]
    async fn fence_blocks_concurrent_reads() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::time::Duration;

        let vault = Arc::new(VaultKeyStore::new());
        let writer_guard = vault.acquire_fence("user-a").await;

        let finished = Arc::new(AtomicBool::new(false));
        let reader = {
            let vault = Arc::clone(&vault);
            let finished = Arc::clone(&finished);
            tokio::spawn(async move {
                vault.read_fence("user-a").await;
                finished.store(true, Ordering::SeqCst);
            })
        };

        // Give the reader a chance to run; it must still be parked on the fence.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(!finished.load(Ordering::SeqCst));

        drop(writer_guard);
        let _ = reader.await;
        assert!(finished.load(Ordering::SeqCst));
    }

    #[cfg(feature = "native")]
    #[tokio::test]
    async fn read_fence_noop_without_active_batch() {
        // No fence was ever acquired for this id, so this must return at once.
        VaultKeyStore::new().read_fence("no-fence-user").await;
    }
}