Skip to main content

mssql_auth/
key_store.rs

1//! Key store providers and CEK caching for Always Encrypted.
2//!
3//! This module provides:
4//! - [`InMemoryKeyStore`]: A simple key store for testing and development
5//! - [`CekCache`]: A thread-safe cache for decrypted Column Encryption Keys
6//!
7//! ## Production Usage
8//!
9//! For production environments, implement the [`KeyStoreProvider`] trait
10//! with a secure key storage solution such as:
11//! - Azure Key Vault
12//! - Windows Certificate Store
13//! - Hardware Security Module (HSM)
14//!
15//! ## Example
16//!
17//! ```rust,ignore
18//! use mssql_auth::key_store::{InMemoryKeyStore, CekCache};
19//!
20//! // Create a key store with test keys
21//! let mut key_store = InMemoryKeyStore::new();
22//! key_store.add_key("TestKey", &private_key_pem)?;
23//!
24//! // Create a CEK cache for performance
25//! let cek_cache = CekCache::new();
26//! ```
27
28use std::collections::HashMap;
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31
32use parking_lot::RwLock;
33
34use crate::aead::AeadEncryptor;
35use crate::encryption::{EncryptionError, KeyStoreProvider};
36use crate::key_unwrap::RsaKeyUnwrapper;
37
38/// In-memory key store for testing and development.
39///
40/// **Security Warning**: This stores private keys in memory without hardware
41/// protection. Use only for testing or development environments.
42///
43/// For production, use Azure Key Vault, Windows Certificate Store, or an HSM.
44pub struct InMemoryKeyStore {
45    /// Map of key path to RSA key unwrapper.
46    keys: HashMap<String, RsaKeyUnwrapper>,
47}
48
49impl InMemoryKeyStore {
50    /// Create a new empty in-memory key store.
51    pub fn new() -> Self {
52        Self {
53            keys: HashMap::new(),
54        }
55    }
56
57    /// Add a key to the store from PEM-encoded private key.
58    ///
59    /// # Arguments
60    ///
61    /// * `key_path` - The identifier/path for this key
62    /// * `pem` - PEM-encoded RSA private key (PKCS#1 or PKCS#8)
63    ///
64    /// # Errors
65    ///
66    /// Returns an error if the PEM cannot be parsed.
67    pub fn add_key(&mut self, key_path: &str, pem: &str) -> Result<(), EncryptionError> {
68        let unwrapper = RsaKeyUnwrapper::from_pem(pem)?;
69        self.keys.insert(key_path.to_string(), unwrapper);
70        Ok(())
71    }
72
73    /// Add a key to the store from DER-encoded private key.
74    ///
75    /// # Arguments
76    ///
77    /// * `key_path` - The identifier/path for this key
78    /// * `der` - DER-encoded RSA private key
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if the DER cannot be parsed.
83    pub fn add_key_der(&mut self, key_path: &str, der: &[u8]) -> Result<(), EncryptionError> {
84        let unwrapper = RsaKeyUnwrapper::from_der(der)?;
85        self.keys.insert(key_path.to_string(), unwrapper);
86        Ok(())
87    }
88
89    /// Check if a key exists in the store.
90    pub fn has_key(&self, key_path: &str) -> bool {
91        self.keys.contains_key(key_path)
92    }
93
94    /// Remove a key from the store.
95    pub fn remove_key(&mut self, key_path: &str) -> bool {
96        self.keys.remove(key_path).is_some()
97    }
98
99    /// Get the number of keys in the store.
100    pub fn len(&self) -> usize {
101        self.keys.len()
102    }
103
104    /// Check if the store is empty.
105    pub fn is_empty(&self) -> bool {
106        self.keys.is_empty()
107    }
108}
109
110impl Default for InMemoryKeyStore {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116#[async_trait::async_trait]
117impl KeyStoreProvider for InMemoryKeyStore {
118    fn provider_name(&self) -> &str {
119        "IN_MEMORY_KEY_STORE"
120    }
121
122    async fn decrypt_cek(
123        &self,
124        cmk_path: &str,
125        _algorithm: &str,
126        encrypted_cek: &[u8],
127    ) -> Result<Vec<u8>, EncryptionError> {
128        let unwrapper = self.keys.get(cmk_path).ok_or_else(|| {
129            EncryptionError::KeyStoreNotFound(format!("Key not found: {cmk_path}"))
130        })?;
131
132        unwrapper.decrypt_cek(encrypted_cek)
133    }
134}
135
136/// Entry in the CEK cache.
137struct CekCacheEntry {
138    /// The decrypted CEK (stored for potential future use like re-keying).
139    #[allow(dead_code)]
140    cek: Vec<u8>,
141    /// AEAD encryptor instance (pre-derived keys).
142    encryptor: Arc<AeadEncryptor>,
143    /// When this entry was created.
144    created_at: Instant,
145}
146
147/// Thread-safe cache for decrypted Column Encryption Keys.
148///
149/// The cache stores decrypted CEKs and pre-computed AEAD encryptors
150/// to avoid repeated RSA decryption and key derivation operations.
151///
152/// ## Cache Key
153///
154/// Entries are keyed by: `(database_id, cek_id, cek_version)`
155///
156/// ## Expiration
157///
158/// Entries expire after a configurable TTL (default: 2 hours).
159/// Expired entries are lazily removed on access.
160pub struct CekCache {
161    /// Map of cache key to entry.
162    entries: RwLock<HashMap<CekCacheKey, CekCacheEntry>>,
163    /// Time-to-live for cache entries.
164    ttl: Duration,
165}
166
/// Identifies one cached Column Encryption Key.
///
/// Two keys compare equal only when all three components match, so each
/// version of a CEK gets its own cache slot.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CekCacheKey {
    /// Identifier of the database the CEK belongs to.
    pub database_id: u32,
    /// Identifier of the CEK within that database.
    pub cek_id: u32,
    /// Version of the CEK; distinguishes keys across rotation.
    pub cek_version: u32,
}
177
178impl CekCacheKey {
179    /// Create a new cache key.
180    pub fn new(database_id: u32, cek_id: u32, cek_version: u32) -> Self {
181        Self {
182            database_id,
183            cek_id,
184            cek_version,
185        }
186    }
187}
188
189impl CekCache {
190    /// Create a new CEK cache with default TTL (2 hours).
191    pub fn new() -> Self {
192        Self::with_ttl(Duration::from_secs(2 * 60 * 60))
193    }
194
195    /// Create a new CEK cache with custom TTL.
196    pub fn with_ttl(ttl: Duration) -> Self {
197        Self {
198            entries: RwLock::new(HashMap::new()),
199            ttl,
200        }
201    }
202
203    /// Get a cached encryptor for a CEK.
204    ///
205    /// Returns `None` if the entry doesn't exist or has expired.
206    pub fn get(&self, key: &CekCacheKey) -> Option<Arc<AeadEncryptor>> {
207        let entries = self.entries.read();
208        if let Some(entry) = entries.get(key) {
209            if entry.created_at.elapsed() < self.ttl {
210                return Some(Arc::clone(&entry.encryptor));
211            }
212        }
213        None
214    }
215
216    /// Insert a CEK into the cache.
217    ///
218    /// Creates an AEAD encryptor from the CEK for future use.
219    ///
220    /// # Arguments
221    ///
222    /// * `key` - The cache key
223    /// * `cek` - The decrypted Column Encryption Key
224    ///
225    /// # Returns
226    ///
227    /// The AEAD encryptor for the CEK.
228    pub fn insert(
229        &self,
230        key: CekCacheKey,
231        cek: Vec<u8>,
232    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
233        let encryptor = Arc::new(AeadEncryptor::new(&cek)?);
234
235        let entry = CekCacheEntry {
236            cek,
237            encryptor: Arc::clone(&encryptor),
238            created_at: Instant::now(),
239        };
240
241        let mut entries = self.entries.write();
242        entries.insert(key, entry);
243
244        Ok(encryptor)
245    }
246
247    /// Get or insert a CEK.
248    ///
249    /// If the CEK is cached, returns the cached encryptor.
250    /// Otherwise, calls the provided function to get the CEK
251    /// and caches it.
252    ///
253    /// # Arguments
254    ///
255    /// * `key` - The cache key
256    /// * `get_cek` - Function to get the CEK if not cached
257    pub async fn get_or_insert<F, Fut>(
258        &self,
259        key: CekCacheKey,
260        get_cek: F,
261    ) -> Result<Arc<AeadEncryptor>, EncryptionError>
262    where
263        F: FnOnce() -> Fut,
264        Fut: std::future::Future<Output = Result<Vec<u8>, EncryptionError>>,
265    {
266        // Try to get from cache first
267        if let Some(encryptor) = self.get(&key) {
268            return Ok(encryptor);
269        }
270
271        // Not in cache, fetch and insert
272        let cek = get_cek().await?;
273        self.insert(key, cek)
274    }
275
276    /// Remove a CEK from the cache.
277    ///
278    /// Call this when a CEK is rotated or invalidated.
279    pub fn remove(&self, key: &CekCacheKey) -> bool {
280        let mut entries = self.entries.write();
281        entries.remove(key).is_some()
282    }
283
284    /// Clear all expired entries from the cache.
285    pub fn cleanup_expired(&self) {
286        let mut entries = self.entries.write();
287        entries.retain(|_, entry| entry.created_at.elapsed() < self.ttl);
288    }
289
290    /// Clear all entries from the cache.
291    pub fn clear(&self) {
292        let mut entries = self.entries.write();
293        entries.clear();
294    }
295
296    /// Get the number of entries in the cache.
297    pub fn len(&self) -> usize {
298        self.entries.read().len()
299    }
300
301    /// Check if the cache is empty.
302    pub fn is_empty(&self) -> bool {
303        self.entries.read().is_empty()
304    }
305}
306
307impl Default for CekCache {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use rsa::{RsaPrivateKey, pkcs8::EncodePrivateKey};
    use std::future::Future;
    use std::task::{Context, Poll, Wake, Waker};

    fn generate_test_key_pem() -> String {
        let mut rng = rand::thread_rng();
        let key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
        key.to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap()
            .to_string()
    }

    /// Polls `fut` exactly once and returns its output.
    ///
    /// Every future driven through this helper consists of synchronous work
    /// plus immediately-ready inner futures, so it completes on the first
    /// poll. This keeps the tests independent of any async runtime.
    fn poll_ready<F: Future>(fut: F) -> F::Output {
        struct NoopWaker;

        impl Wake for NoopWaker {
            fn wake(self: Arc<Self>) {}
        }

        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("future was expected to complete on its first poll"),
        }
    }

    #[test]
    fn test_in_memory_key_store_new() {
        let store = InMemoryKeyStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_in_memory_key_store_default_is_empty() {
        let store = InMemoryKeyStore::default();
        assert!(store.is_empty());
    }

    #[test]
    fn test_in_memory_key_store_add_key() {
        let mut store = InMemoryKeyStore::new();
        let pem = generate_test_key_pem();

        store.add_key("TestKey", &pem).unwrap();
        assert!(store.has_key("TestKey"));
        assert!(!store.has_key("OtherKey"));
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_in_memory_key_store_add_key_replaces_existing() {
        let mut store = InMemoryKeyStore::new();
        let pem = generate_test_key_pem();

        store.add_key("TestKey", &pem).unwrap();
        store.add_key("TestKey", &pem).unwrap();
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_in_memory_key_store_add_key_rejects_invalid_pem() {
        let mut store = InMemoryKeyStore::new();
        assert!(store.add_key("BadKey", "not a pem").is_err());
        assert!(store.is_empty());
    }

    #[test]
    fn test_in_memory_key_store_remove_key() {
        let mut store = InMemoryKeyStore::new();
        let pem = generate_test_key_pem();

        store.add_key("TestKey", &pem).unwrap();
        assert!(store.remove_key("TestKey"));
        assert!(!store.has_key("TestKey"));
        assert!(!store.remove_key("TestKey"));
    }

    #[test]
    fn test_in_memory_key_store_provider_name() {
        let store = InMemoryKeyStore::new();
        assert_eq!(store.provider_name(), "IN_MEMORY_KEY_STORE");
    }

    #[test]
    fn test_in_memory_key_store_decrypt_unknown_key() {
        let store = InMemoryKeyStore::new();
        let err = poll_ready(store.decrypt_cek("MissingKey", "RSA_OAEP", &[0u8; 4]))
            .unwrap_err();
        assert!(matches!(err, EncryptionError::KeyStoreNotFound(_)));
    }

    #[test]
    fn test_cek_cache_key() {
        let key1 = CekCacheKey::new(1, 2, 3);
        let key2 = CekCacheKey::new(1, 2, 3);
        let key3 = CekCacheKey::new(1, 2, 4);

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_cek_cache_insert_and_get() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);
        let cek = vec![0x42u8; 32];

        // Insert
        let encryptor = cache.insert(key.clone(), cek).unwrap();
        assert_eq!(cache.len(), 1);

        // Get
        let retrieved = cache.get(&key);
        assert!(retrieved.is_some());
        assert!(Arc::ptr_eq(&encryptor, &retrieved.unwrap()));
    }

    #[test]
    fn test_cek_cache_insert_replaces_existing() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);

        let first = cache.insert(key.clone(), vec![0x42u8; 32]).unwrap();
        let second = cache.insert(key.clone(), vec![0x24u8; 32]).unwrap();
        assert_eq!(cache.len(), 1);

        let retrieved = cache.get(&key).unwrap();
        assert!(Arc::ptr_eq(&second, &retrieved));
        assert!(!Arc::ptr_eq(&first, &retrieved));
    }

    #[test]
    fn test_cek_cache_miss() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);

        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cek_cache_expiration() {
        // Only the negative case is asserted here. With a short TTL, an
        // immediate "still cached" check can race the clock on a loaded CI
        // runner; freshness is covered by the long-TTL tests. `sleep` only
        // ever overshoots, so sleeping far past the TTL guarantees expiry.
        let cache = CekCache::with_ttl(Duration::from_millis(20));
        let key = CekCacheKey::new(1, 1, 1);

        cache.insert(key.clone(), vec![0x42u8; 32]).unwrap();

        std::thread::sleep(Duration::from_millis(150));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cek_cache_get_or_insert_fetches_once() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);
        let fetches = std::cell::Cell::new(0u32);
        // Captures only `&Cell`, so the closure is `Copy` and can be reused.
        let fetch = || {
            fetches.set(fetches.get() + 1);
            async { Ok::<_, EncryptionError>(vec![0x42u8; 32]) }
        };

        let first = poll_ready(cache.get_or_insert(key.clone(), fetch)).unwrap();
        let second = poll_ready(cache.get_or_insert(key.clone(), fetch)).unwrap();

        assert_eq!(fetches.get(), 1);
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_cek_cache_get_or_insert_propagates_fetch_error() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);

        let result = poll_ready(cache.get_or_insert(key.clone(), || async {
            Err::<Vec<u8>, _>(EncryptionError::KeyStoreNotFound(
                "unavailable".to_string(),
            ))
        }));

        assert!(result.is_err());
        assert!(cache.is_empty());
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cek_cache_remove() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);
        let cek = vec![0x42u8; 32];

        cache.insert(key.clone(), cek).unwrap();
        assert!(cache.remove(&key));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cek_cache_remove_missing_returns_false() {
        let cache = CekCache::new();
        assert!(!cache.remove(&CekCacheKey::new(1, 1, 1)));
    }

    #[test]
    fn test_cek_cache_clear() {
        let cache = CekCache::new();

        for i in 0..5 {
            let key = CekCacheKey::new(i, 1, 1);
            let cek = vec![0x42u8; 32];
            cache.insert(key, cek).unwrap();
        }

        assert_eq!(cache.len(), 5);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cek_cache_cleanup_removes_expired() {
        // `sleep` only ever overshoots, so a wait far longer than the TTL
        // guarantees expiry regardless of scheduler jitter (the previous test
        // depended on 30ms windows and flaked on loaded macOS CI runners).
        let cache = CekCache::with_ttl(Duration::from_millis(20));
        cache
            .insert(CekCacheKey::new(1, 1, 1), vec![0x42u8; 32])
            .unwrap();
        assert_eq!(cache.len(), 1);

        std::thread::sleep(Duration::from_millis(150));
        cache.cleanup_expired();

        assert!(cache.is_empty());
    }

    #[test]
    fn test_cek_cache_cleanup_keeps_fresh() {
        // A TTL far longer than the test runtime guarantees nothing expires.
        let cache = CekCache::with_ttl(Duration::from_secs(3600));
        let key = CekCacheKey::new(1, 1, 1);
        cache.insert(key.clone(), vec![0x42u8; 32]).unwrap();

        cache.cleanup_expired();

        assert_eq!(cache.len(), 1);
        assert!(cache.get(&key).is_some());
    }
}