Skip to main content

auth_framework/storage/
encryption.rs

1use crate::errors::{AuthError, Result};
2use crate::storage::{AuthStorage, SessionData};
3use crate::tokens::AuthToken;
4use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce, aead::Aead};
5use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
6use serde::{Deserialize, Serialize};
7use std::env;
8use std::time::Duration;
9
10/// Encrypted data container with metadata
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct EncryptedData {
13    /// Base64 encoded encrypted data
14    pub data: String,
15    /// Base64 encoded nonce/IV
16    pub nonce: String,
17    /// Algorithm identifier
18    pub algorithm: String,
19    /// Key derivation method (for future use)
20    pub key_derivation: String,
21}
22
23/// Storage encryption manager using AES-256-GCM
24pub struct StorageEncryption {
25    cipher: Aes256Gcm,
26}
27
28impl StorageEncryption {
29    /// Create new encryption manager from environment variable
30    pub fn new() -> Result<Self> {
31        let key_data = env::var("AUTH_STORAGE_ENCRYPTION_KEY").map_err(|_| {
32            AuthError::config("AUTH_STORAGE_ENCRYPTION_KEY environment variable not set")
33        })?;
34
35        let key_bytes = BASE64
36            .decode(&key_data)
37            .map_err(|_| AuthError::config("Invalid base64 in AUTH_STORAGE_ENCRYPTION_KEY"))?;
38
39        if key_bytes.len() != 32 {
40            return Err(AuthError::config(
41                "Encryption key must be 32 bytes (256 bits)",
42            ));
43        }
44
45        let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
46        let cipher = Aes256Gcm::new(key);
47
48        Ok(Self { cipher })
49    }
50
51    /// Create new encryption manager for testing with a random key
52    #[cfg(test)]
53    pub fn new_random() -> Self {
54        use rand::Rng;
55        let mut key_bytes = [0u8; 32];
56        rand::rng().fill_bytes(&mut key_bytes);
57        let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
58        let cipher = Aes256Gcm::new(key);
59        Self { cipher }
60    }
61
62    /// Generate a new 256-bit encryption key (base64 encoded)
63    pub fn generate_key() -> String {
64        use rand::Rng;
65        let mut key_bytes = [0u8; 32];
66        rand::rng().fill_bytes(&mut key_bytes);
67        BASE64.encode(key_bytes)
68    }
69
70    /// Encrypt sensitive data
71    pub fn encrypt(&self, plaintext: &str) -> Result<EncryptedData> {
72        use rand::Rng;
73        // Generate random nonce
74        let mut nonce_bytes = [0u8; 12]; // 96-bit nonce for GCM
75        rand::rng().fill_bytes(&mut nonce_bytes);
76        let nonce = Nonce::from_slice(&nonce_bytes);
77
78        // Encrypt the data
79        let ciphertext = self
80            .cipher
81            .encrypt(nonce, plaintext.as_bytes())
82            .map_err(|e| AuthError::internal(format!("Encryption failed: {}", e)))?;
83
84        Ok(EncryptedData {
85            data: BASE64.encode(&ciphertext),
86            nonce: BASE64.encode(nonce_bytes),
87            algorithm: "AES-256-GCM".to_string(),
88            key_derivation: "direct".to_string(),
89        })
90    }
91
92    /// Decrypt sensitive data
93    pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<String> {
94        // Validate algorithm
95        if encrypted.algorithm != "AES-256-GCM" {
96            return Err(AuthError::internal(format!(
97                "Unsupported encryption algorithm: {}",
98                encrypted.algorithm
99            )));
100        }
101
102        // Decode base64 data
103        let ciphertext = BASE64
104            .decode(&encrypted.data)
105            .map_err(|_| AuthError::internal("Invalid base64 in encrypted data"))?;
106
107        let nonce_bytes = BASE64
108            .decode(&encrypted.nonce)
109            .map_err(|_| AuthError::internal("Invalid base64 in nonce"))?;
110
111        if nonce_bytes.len() != 12 {
112            return Err(AuthError::internal("Invalid nonce length"));
113        }
114
115        let nonce = Nonce::from_slice(&nonce_bytes);
116
117        // Decrypt the data
118        let plaintext = self
119            .cipher
120            .decrypt(nonce, ciphertext.as_ref())
121            .map_err(|e| AuthError::internal(format!("Decryption failed: {}", e)))?;
122
123        String::from_utf8(plaintext)
124            .map_err(|_| AuthError::internal("Decrypted data is not valid UTF-8"))
125    }
126
127    /// Encrypt data for storage backend
128    pub fn encrypt_for_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
129        let plaintext = String::from_utf8(data.to_vec())
130            .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
131
132        let encrypted = self.encrypt(&plaintext)?;
133        let serialized = serde_json::to_string(&encrypted).map_err(|e| {
134            AuthError::internal(format!("Failed to serialize encrypted data: {}", e))
135        })?;
136
137        Ok(serialized.into_bytes())
138    }
139
140    /// Decrypt data from storage backend
141    pub fn decrypt_from_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
142        let serialized = String::from_utf8(data.to_vec())
143            .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
144
145        let encrypted: EncryptedData = serde_json::from_str(&serialized).map_err(|e| {
146            AuthError::internal(format!("Failed to deserialize encrypted data: {}", e))
147        })?;
148
149        let plaintext = self.decrypt(&encrypted)?;
150        Ok(plaintext.into_bytes())
151    }
152}
153
154/// Wrapper for storage backends that adds encryption at rest
155pub struct EncryptedStorage<T> {
156    inner: T,
157    encryption: StorageEncryption,
158}
159
160impl<T> EncryptedStorage<T> {
161    pub fn new(storage: T, encryption: StorageEncryption) -> Self {
162        Self {
163            inner: storage,
164            encryption,
165        }
166    }
167
168    pub fn into_inner(self) -> T {
169        self.inner
170    }
171}
172
173#[async_trait::async_trait]
174impl<T> AuthStorage for EncryptedStorage<T>
175where
176    T: AuthStorage + Send + Sync,
177{
178    // Token methods — delegate to inner storage for index maintenance.
179    // Encryption is applied at the KV layer, which all token-related lookups
180    // ultimately use. Storing a separate encrypted_token:* KV blob would create
181    // redundant data that is never read back.
182    async fn store_token(&self, token: &AuthToken) -> Result<()> {
183        self.inner.store_token(token).await
184    }
185
186    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
187        self.inner.get_token(token_id).await
188    }
189
190    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
191        self.inner.get_token_by_access_token(access_token).await
192    }
193
194    async fn update_token(&self, token: &AuthToken) -> Result<()> {
195        self.inner.update_token(token).await
196    }
197
198    async fn delete_token(&self, token_id: &str) -> Result<()> {
199        self.inner.delete_token(token_id).await
200    }
201
202    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
203        self.inner.list_user_tokens(user_id).await
204    }
205
206    // Session methods — delegate to inner storage.
207    // Encryption is applied at the KV layer.
208    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
209        self.inner.store_session(session_id, data).await
210    }
211
212    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
213        self.inner.get_session(session_id).await
214    }
215
216    async fn delete_session(&self, session_id: &str) -> Result<()> {
217        self.inner.delete_session(session_id).await
218    }
219
220    async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
221        self.inner.list_user_sessions(user_id).await
222    }
223
224    async fn count_active_sessions(&self) -> Result<u64> {
225        self.inner.count_active_sessions().await
226    }
227
228    // Key-value methods with encryption
229    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
230        let encrypted_value = self.encryption.encrypt_for_storage(value)?;
231        self.inner.store_kv(key, &encrypted_value, ttl).await
232    }
233
234    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
235        if let Some(encrypted_data) = self.inner.get_kv(key).await? {
236            let decrypted_data = self.encryption.decrypt_from_storage(&encrypted_data)?;
237            Ok(Some(decrypted_data))
238        } else {
239            Ok(None)
240        }
241    }
242
243    async fn delete_kv(&self, key: &str) -> Result<()> {
244        self.inner.delete_kv(key).await
245    }
246
247    async fn cleanup_expired(&self) -> Result<()> {
248        self.inner.cleanup_expired().await
249    }
250}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_generation() {
        let key = StorageEncryption::generate_key();
        assert!(!key.is_empty());

        // Should be valid base64
        let decoded = BASE64.decode(&key).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn test_encryption_roundtrip() {
        let encryption = StorageEncryption::new_random();
        let plaintext = "sensitive client secret data";

        let encrypted = encryption.encrypt(plaintext).unwrap();
        assert_ne!(encrypted.data, plaintext);
        assert_eq!(encrypted.algorithm, "AES-256-GCM");

        let decrypted = encryption.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_storage_encryption() {
        let encryption = StorageEncryption::new_random();
        let data = b"sensitive authentication data";

        let encrypted = encryption.encrypt_for_storage(data).unwrap();
        assert_ne!(encrypted, data);

        let decrypted = encryption.decrypt_from_storage(&encrypted).unwrap();
        assert_eq!(decrypted, data);
    }

    /// Each encryption must use a fresh nonce, so equal plaintexts differ.
    #[test]
    fn test_encryption_uses_fresh_nonce() {
        let encryption = StorageEncryption::new_random();
        let first = encryption.encrypt("same plaintext").unwrap();
        let second = encryption.encrypt("same plaintext").unwrap();
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.data, second.data);
    }

    /// GCM authentication must reject a modified ciphertext.
    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("payload").unwrap();

        let mut bytes = BASE64.decode(&encrypted.data).unwrap();
        bytes[0] ^= 0x01;
        encrypted.data = BASE64.encode(&bytes);

        assert!(encryption.decrypt(&encrypted).is_err());
    }

    /// Data encrypted under one key must not decrypt under another.
    #[test]
    fn test_wrong_key_is_rejected() {
        let writer = StorageEncryption::new_random();
        let reader = StorageEncryption::new_random();
        let encrypted = writer.encrypt("payload").unwrap();
        assert!(reader.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_unsupported_algorithm_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("payload").unwrap();
        encrypted.algorithm = "ROT13".to_string();
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    /// A nonce of the wrong size must produce an error, not a panic.
    #[test]
    fn test_invalid_nonce_length_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("payload").unwrap();
        encrypted.nonce = BASE64.encode([0u8; 8]);
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_corrupted_storage_blob_is_rejected() {
        let encryption = StorageEncryption::new_random();
        assert!(encryption.decrypt_from_storage(b"not json").is_err());
    }
}