Skip to main content

mqdb_core/
vault_keys.rs

1// Copyright 2025-2026 LabOverWire. All rights reserved.
2// SPDX-License-Identifier: AGPL-3.0-only
3
4use std::collections::HashMap;
5use std::sync::RwLock;
6use zeroize::Zeroizing;
7
8#[cfg(feature = "native")]
9use std::sync::Arc;
10#[cfg(feature = "native")]
11use tokio::sync::OwnedRwLockWriteGuard;
12
13pub struct VaultKeyStore {
14    keys: RwLock<HashMap<String, Zeroizing<Vec<u8>>>>,
15    #[cfg(feature = "native")]
16    fences: RwLock<HashMap<String, Arc<tokio::sync::RwLock<()>>>>,
17}
18
19impl Default for VaultKeyStore {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl VaultKeyStore {
26    #[must_use]
27    pub fn new() -> Self {
28        Self {
29            keys: RwLock::new(HashMap::new()),
30            #[cfg(feature = "native")]
31            fences: RwLock::new(HashMap::new()),
32        }
33    }
34
35    pub fn set(&self, canonical_id: &str, key: Zeroizing<Vec<u8>>) {
36        if let Ok(mut map) = self.keys.write() {
37            map.insert(canonical_id.to_string(), key);
38        }
39    }
40
41    pub fn remove(&self, canonical_id: &str) {
42        if let Ok(mut map) = self.keys.write() {
43            map.remove(canonical_id);
44        }
45        #[cfg(feature = "native")]
46        if let Ok(mut map) = self.fences.write() {
47            map.remove(canonical_id);
48        }
49    }
50
51    #[must_use]
52    pub fn get(&self, canonical_id: &str) -> Option<Zeroizing<Vec<u8>>> {
53        let map = self.keys.read().ok()?;
54        map.get(canonical_id).cloned()
55    }
56
57    #[cfg(feature = "native")]
58    pub async fn acquire_fence(&self, canonical_id: &str) -> OwnedRwLockWriteGuard<()> {
59        let lock = {
60            let mut map = self
61                .fences
62                .write()
63                .unwrap_or_else(std::sync::PoisonError::into_inner);
64            map.entry(canonical_id.to_string())
65                .or_insert_with(|| Arc::new(tokio::sync::RwLock::new(())))
66                .clone()
67        };
68        lock.write_owned().await
69    }
70
71    #[cfg(feature = "native")]
72    pub async fn read_fence(&self, canonical_id: &str) {
73        let lock = {
74            let map = self
75                .fences
76                .read()
77                .unwrap_or_else(std::sync::PoisonError::into_inner);
78            map.get(canonical_id).cloned()
79        };
80        if let Some(lock) = lock {
81            let _guard = lock.read().await;
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn set_and_get_roundtrip() {
        let store = VaultKeyStore::new();
        let key = vec![1u8; 32];
        store.set("user-1", Zeroizing::new(key.clone()));
        let retrieved = store.get("user-1").expect("key should exist");
        assert_eq!(&*retrieved, &key);
    }

    #[test]
    fn set_replaces_existing_key() {
        let store = VaultKeyStore::new();
        store.set("user-3", Zeroizing::new(vec![3u8; 32]));
        store.set("user-3", Zeroizing::new(vec![4u8; 16]));
        let retrieved = store.get("user-3").expect("key should exist");
        assert_eq!(&*retrieved, &vec![4u8; 16]);
    }

    #[test]
    fn remove_clears_key() {
        let store = VaultKeyStore::new();
        store.set("user-2", Zeroizing::new(vec![2u8; 32]));
        assert!(store.get("user-2").is_some());
        store.remove("user-2");
        assert!(store.get("user-2").is_none());
    }

    #[test]
    fn remove_nonexistent_leaves_other_keys() {
        let store = VaultKeyStore::new();
        store.set("user-4", Zeroizing::new(vec![5u8; 32]));
        store.remove("missing");
        assert!(store.get("user-4").is_some());
    }

    #[test]
    fn get_nonexistent_returns_none() {
        let store = VaultKeyStore::new();
        assert!(store.get("nonexistent").is_none());
    }

    #[cfg(feature = "native")]
    #[tokio::test]
    async fn fence_blocks_concurrent_reads() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::time::Duration;

        let store = Arc::new(VaultKeyStore::new());
        let fence_guard = store.acquire_fence("user-a").await;

        let read_completed = Arc::new(AtomicBool::new(false));
        let read_completed_clone = read_completed.clone();
        let store_clone = store.clone();

        let handle = tokio::spawn(async move {
            store_clone.read_fence("user-a").await;
            read_completed_clone.store(true, Ordering::SeqCst);
        });

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(!read_completed.load(Ordering::SeqCst));

        drop(fence_guard);
        // Bound the wait so a fence that is never released fails the test instead of
        // hanging it, and surface a panic in the reader task instead of discarding it.
        tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("read_fence did not finish after the fence was released")
            .expect("read_fence task panicked");
        assert!(read_completed.load(Ordering::SeqCst));
    }

    #[cfg(feature = "native")]
    #[tokio::test]
    async fn read_fence_noop_without_active_batch() {
        let store = VaultKeyStore::new();
        store.read_fence("no-fence-user").await;
    }
}