auth_framework/storage/
encryption.rs

1use crate::errors::{AuthError, Result};
2use crate::storage::{AuthStorage, SessionData};
3use crate::tokens::AuthToken;
4use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce, aead::Aead};
5use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
6use rand::RngCore;
7use serde::{Deserialize, Serialize};
8use std::env;
9use std::time::Duration;
10
11/// Encrypted data container with metadata
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct EncryptedData {
14    /// Base64 encoded encrypted data
15    pub data: String,
16    /// Base64 encoded nonce/IV
17    pub nonce: String,
18    /// Algorithm identifier
19    pub algorithm: String,
20    /// Key derivation method (for future use)
21    pub key_derivation: String,
22}
23
24/// Storage encryption manager using AES-256-GCM
25pub struct StorageEncryption {
26    cipher: Aes256Gcm,
27}
28
29impl StorageEncryption {
30    /// Create new encryption manager from environment variable
31    pub fn new() -> Result<Self> {
32        let key_data = env::var("AUTH_STORAGE_ENCRYPTION_KEY").map_err(|_| {
33            AuthError::config("AUTH_STORAGE_ENCRYPTION_KEY environment variable not set")
34        })?;
35
36        let key_bytes = BASE64
37            .decode(&key_data)
38            .map_err(|_| AuthError::config("Invalid base64 in AUTH_STORAGE_ENCRYPTION_KEY"))?;
39
40        if key_bytes.len() != 32 {
41            return Err(AuthError::config(
42                "Encryption key must be 32 bytes (256 bits)",
43            ));
44        }
45
46        let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
47        let cipher = Aes256Gcm::new(key);
48
49        Ok(Self { cipher })
50    }
51
52    /// Create new encryption manager for testing with a random key
53    #[cfg(test)]
54    pub fn new_random() -> Self {
55        use rand::RngCore;
56        let mut key_bytes = [0u8; 32];
57        rand::rng().fill_bytes(&mut key_bytes);
58        let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
59        let cipher = Aes256Gcm::new(key);
60        Self { cipher }
61    }
62
63    /// Generate a new 256-bit encryption key (base64 encoded)
64    pub fn generate_key() -> String {
65        let mut key_bytes = [0u8; 32];
66        rand::rng().fill_bytes(&mut key_bytes);
67        BASE64.encode(key_bytes)
68    }
69
70    /// Encrypt sensitive data
71    pub fn encrypt(&self, plaintext: &str) -> Result<EncryptedData> {
72        // Generate random nonce
73        let mut nonce_bytes = [0u8; 12]; // 96-bit nonce for GCM
74        rand::rng().fill_bytes(&mut nonce_bytes);
75        let nonce = Nonce::from_slice(&nonce_bytes);
76
77        // Encrypt the data
78        let ciphertext = self
79            .cipher
80            .encrypt(nonce, plaintext.as_bytes())
81            .map_err(|e| AuthError::internal(format!("Encryption failed: {}", e)))?;
82
83        Ok(EncryptedData {
84            data: BASE64.encode(&ciphertext),
85            nonce: BASE64.encode(nonce_bytes),
86            algorithm: "AES-256-GCM".to_string(),
87            key_derivation: "direct".to_string(),
88        })
89    }
90
91    /// Decrypt sensitive data
92    pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<String> {
93        // Validate algorithm
94        if encrypted.algorithm != "AES-256-GCM" {
95            return Err(AuthError::internal(format!(
96                "Unsupported encryption algorithm: {}",
97                encrypted.algorithm
98            )));
99        }
100
101        // Decode base64 data
102        let ciphertext = BASE64
103            .decode(&encrypted.data)
104            .map_err(|_| AuthError::internal("Invalid base64 in encrypted data"))?;
105
106        let nonce_bytes = BASE64
107            .decode(&encrypted.nonce)
108            .map_err(|_| AuthError::internal("Invalid base64 in nonce"))?;
109
110        if nonce_bytes.len() != 12 {
111            return Err(AuthError::internal("Invalid nonce length"));
112        }
113
114        let nonce = Nonce::from_slice(&nonce_bytes);
115
116        // Decrypt the data
117        let plaintext = self
118            .cipher
119            .decrypt(nonce, ciphertext.as_ref())
120            .map_err(|e| AuthError::internal(format!("Decryption failed: {}", e)))?;
121
122        String::from_utf8(plaintext)
123            .map_err(|_| AuthError::internal("Decrypted data is not valid UTF-8"))
124    }
125
126    /// Encrypt data for storage backend
127    pub fn encrypt_for_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
128        let plaintext = String::from_utf8(data.to_vec())
129            .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
130
131        let encrypted = self.encrypt(&plaintext)?;
132        let serialized = serde_json::to_string(&encrypted).map_err(|e| {
133            AuthError::internal(format!("Failed to serialize encrypted data: {}", e))
134        })?;
135
136        Ok(serialized.into_bytes())
137    }
138
139    /// Decrypt data from storage backend
140    pub fn decrypt_from_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
141        let serialized = String::from_utf8(data.to_vec())
142            .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
143
144        let encrypted: EncryptedData = serde_json::from_str(&serialized).map_err(|e| {
145            AuthError::internal(format!("Failed to deserialize encrypted data: {}", e))
146        })?;
147
148        let plaintext = self.decrypt(&encrypted)?;
149        Ok(plaintext.into_bytes())
150    }
151}
152
153/// Wrapper for storage backends that adds encryption at rest
154pub struct EncryptedStorage<T> {
155    inner: T,
156    encryption: StorageEncryption,
157}
158
159impl<T> EncryptedStorage<T> {
160    pub fn new(storage: T, encryption: StorageEncryption) -> Self {
161        Self {
162            inner: storage,
163            encryption,
164        }
165    }
166
167    pub fn into_inner(self) -> T {
168        self.inner
169    }
170}
171
172#[async_trait::async_trait]
173impl<T> AuthStorage for EncryptedStorage<T>
174where
175    T: AuthStorage + Send + Sync,
176{
177    // Token methods with encryption
178    async fn store_token(&self, token: &AuthToken) -> Result<()> {
179        self.inner.store_token(token).await
180    }
181
182    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
183        self.inner.get_token(token_id).await
184    }
185
186    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
187        self.inner.get_token_by_access_token(access_token).await
188    }
189
190    async fn update_token(&self, token: &AuthToken) -> Result<()> {
191        self.inner.update_token(token).await
192    }
193
194    async fn delete_token(&self, token_id: &str) -> Result<()> {
195        self.inner.delete_token(token_id).await
196    }
197
198    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
199        self.inner.list_user_tokens(user_id).await
200    }
201
202    // Session methods with encryption
203    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
204        self.inner.store_session(session_id, data).await
205    }
206
207    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
208        self.inner.get_session(session_id).await
209    }
210
211    async fn delete_session(&self, session_id: &str) -> Result<()> {
212        self.inner.delete_session(session_id).await
213    }
214
215    async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
216        self.inner.list_user_sessions(user_id).await
217    }
218
219    async fn count_active_sessions(&self) -> Result<u64> {
220        self.inner.count_active_sessions().await
221    }
222
223    // Key-value methods with encryption
224    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
225        let encrypted_value = self.encryption.encrypt_for_storage(value)?;
226        self.inner.store_kv(key, &encrypted_value, ttl).await
227    }
228
229    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
230        if let Some(encrypted_data) = self.inner.get_kv(key).await? {
231            let decrypted_data = self.encryption.decrypt_from_storage(&encrypted_data)?;
232            Ok(Some(decrypted_data))
233        } else {
234            Ok(None)
235        }
236    }
237
238    async fn delete_kv(&self, key: &str) -> Result<()> {
239        self.inner.delete_kv(key).await
240    }
241
242    async fn cleanup_expired(&self) -> Result<()> {
243        self.inner.cleanup_expired().await
244    }
245}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_generation() {
        let key = StorageEncryption::generate_key();
        assert!(!key.is_empty());

        // Should be valid base64
        let decoded = BASE64.decode(&key).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn test_encryption_roundtrip() {
        let encryption = StorageEncryption::new_random();
        let plaintext = "sensitive client secret data";

        let encrypted = encryption.encrypt(plaintext).unwrap();
        assert_ne!(encrypted.data, plaintext);
        assert_eq!(encrypted.algorithm, "AES-256-GCM");

        let decrypted = encryption.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_storage_encryption() {
        let encryption = StorageEncryption::new_random();
        let data = b"sensitive authentication data";

        let encrypted = encryption.encrypt_for_storage(data).unwrap();
        assert_ne!(encrypted, data);

        let decrypted = encryption.decrypt_from_storage(&encrypted).unwrap();
        assert_eq!(decrypted, data);
    }

    /// Encrypting the same input twice must use different nonces (and so differ).
    #[test]
    fn test_nonce_is_fresh_per_encryption() {
        let encryption = StorageEncryption::new_random();
        let first = encryption.encrypt("same input").unwrap();
        let second = encryption.encrypt("same input").unwrap();
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.data, second.data);
    }

    /// GCM authentication must reject a modified ciphertext.
    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("integrity matters").unwrap();
        let mut bytes = BASE64.decode(&encrypted.data).unwrap();
        bytes[0] ^= 0x01;
        encrypted.data = BASE64.encode(&bytes);
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    /// Data encrypted under one key must not decrypt under another.
    #[test]
    fn test_wrong_key_is_rejected() {
        let encrypted = StorageEncryption::new_random().encrypt("secret").unwrap();
        let other = StorageEncryption::new_random();
        assert!(other.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_unsupported_algorithm_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("secret").unwrap();
        encrypted.algorithm = "AES-128-CBC".to_string();
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_invalid_nonce_length_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("secret").unwrap();
        encrypted.nonce = BASE64.encode([0u8; 8]);
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_malformed_storage_envelope_is_rejected() {
        let encryption = StorageEncryption::new_random();
        assert!(encryption.decrypt_from_storage(b"not json").is_err());
    }
}