auth_framework/storage/encryption.rs

1use crate::errors::{AuthError, Result};
2use crate::storage::{AuthStorage, SessionData};
3use crate::tokens::AuthToken;
4use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce, aead::Aead};
5use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
6use rand::{RngCore, rngs::OsRng};
7use serde::{Deserialize, Serialize};
8use std::env;
9use std::time::Duration;
10
11/// Encrypted data container with metadata
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct EncryptedData {
14    /// Base64 encoded encrypted data
15    pub data: String,
16    /// Base64 encoded nonce/IV
17    pub nonce: String,
18    /// Algorithm identifier
19    pub algorithm: String,
20    /// Key derivation method (for future use)
21    pub key_derivation: String,
22}
23
24/// Storage encryption manager using AES-256-GCM
25pub struct StorageEncryption {
26    cipher: Aes256Gcm,
27}
28
29impl StorageEncryption {
30    /// Create new encryption manager from environment variable
31    pub fn new() -> Result<Self> {
32        let key_data = env::var("AUTH_STORAGE_ENCRYPTION_KEY").map_err(|_| {
33            AuthError::config("AUTH_STORAGE_ENCRYPTION_KEY environment variable not set")
34        })?;
35
36        let key_bytes = BASE64
37            .decode(&key_data)
38            .map_err(|_| AuthError::config("Invalid base64 in AUTH_STORAGE_ENCRYPTION_KEY"))?;
39
40        if key_bytes.len() != 32 {
41            return Err(AuthError::config(
42                "Encryption key must be 32 bytes (256 bits)",
43            ));
44        }
45
46        let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
47        let cipher = Aes256Gcm::new(key);
48
49        Ok(Self { cipher })
50    }
51
52    /// Create new encryption manager for testing with a random key
53    #[cfg(test)]
54    pub fn new_random() -> Self {
55        let key = Aes256Gcm::generate_key(&mut OsRng);
56        let cipher = Aes256Gcm::new(&key);
57        Self { cipher }
58    }
59
60    /// Generate a new 256-bit encryption key (base64 encoded)
61    pub fn generate_key() -> String {
62        let mut key_bytes = [0u8; 32];
63        OsRng.fill_bytes(&mut key_bytes);
64        BASE64.encode(key_bytes)
65    }
66
67    /// Encrypt sensitive data
68    pub fn encrypt(&self, plaintext: &str) -> Result<EncryptedData> {
69        // Generate random nonce
70        let mut nonce_bytes = [0u8; 12]; // 96-bit nonce for GCM
71        OsRng.fill_bytes(&mut nonce_bytes);
72        let nonce = Nonce::from_slice(&nonce_bytes);
73
74        // Encrypt the data
75        let ciphertext = self
76            .cipher
77            .encrypt(nonce, plaintext.as_bytes())
78            .map_err(|e| AuthError::internal(format!("Encryption failed: {}", e)))?;
79
80        Ok(EncryptedData {
81            data: BASE64.encode(&ciphertext),
82            nonce: BASE64.encode(nonce_bytes),
83            algorithm: "AES-256-GCM".to_string(),
84            key_derivation: "direct".to_string(),
85        })
86    }
87
88    /// Decrypt sensitive data
89    pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<String> {
90        // Validate algorithm
91        if encrypted.algorithm != "AES-256-GCM" {
92            return Err(AuthError::internal(format!(
93                "Unsupported encryption algorithm: {}",
94                encrypted.algorithm
95            )));
96        }
97
98        // Decode base64 data
99        let ciphertext = BASE64
100            .decode(&encrypted.data)
101            .map_err(|_| AuthError::internal("Invalid base64 in encrypted data"))?;
102
103        let nonce_bytes = BASE64
104            .decode(&encrypted.nonce)
105            .map_err(|_| AuthError::internal("Invalid base64 in nonce"))?;
106
107        if nonce_bytes.len() != 12 {
108            return Err(AuthError::internal("Invalid nonce length"));
109        }
110
111        let nonce = Nonce::from_slice(&nonce_bytes);
112
113        // Decrypt the data
114        let plaintext = self
115            .cipher
116            .decrypt(nonce, ciphertext.as_ref())
117            .map_err(|e| AuthError::internal(format!("Decryption failed: {}", e)))?;
118
119        String::from_utf8(plaintext)
120            .map_err(|_| AuthError::internal("Decrypted data is not valid UTF-8"))
121    }
122
123    /// Encrypt data for storage backend
124    pub fn encrypt_for_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
125        let plaintext = String::from_utf8(data.to_vec())
126            .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
127
128        let encrypted = self.encrypt(&plaintext)?;
129        let serialized = serde_json::to_string(&encrypted).map_err(|e| {
130            AuthError::internal(format!("Failed to serialize encrypted data: {}", e))
131        })?;
132
133        Ok(serialized.into_bytes())
134    }
135
136    /// Decrypt data from storage backend
137    pub fn decrypt_from_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
138        let serialized = String::from_utf8(data.to_vec())
139            .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
140
141        let encrypted: EncryptedData = serde_json::from_str(&serialized).map_err(|e| {
142            AuthError::internal(format!("Failed to deserialize encrypted data: {}", e))
143        })?;
144
145        let plaintext = self.decrypt(&encrypted)?;
146        Ok(plaintext.into_bytes())
147    }
148}
149
150/// Wrapper for storage backends that adds encryption at rest
151pub struct EncryptedStorage<T> {
152    inner: T,
153    encryption: StorageEncryption,
154}
155
156impl<T> EncryptedStorage<T> {
157    pub fn new(storage: T, encryption: StorageEncryption) -> Self {
158        Self {
159            inner: storage,
160            encryption,
161        }
162    }
163
164    pub fn into_inner(self) -> T {
165        self.inner
166    }
167}
168
169#[async_trait::async_trait]
170impl<T> AuthStorage for EncryptedStorage<T>
171where
172    T: AuthStorage + Send + Sync,
173{
174    // Token methods with encryption
175    async fn store_token(&self, token: &AuthToken) -> Result<()> {
176        self.inner.store_token(token).await
177    }
178
179    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
180        self.inner.get_token(token_id).await
181    }
182
183    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
184        self.inner.get_token_by_access_token(access_token).await
185    }
186
187    async fn update_token(&self, token: &AuthToken) -> Result<()> {
188        self.inner.update_token(token).await
189    }
190
191    async fn delete_token(&self, token_id: &str) -> Result<()> {
192        self.inner.delete_token(token_id).await
193    }
194
195    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
196        self.inner.list_user_tokens(user_id).await
197    }
198
199    // Session methods with encryption
200    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
201        self.inner.store_session(session_id, data).await
202    }
203
204    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
205        self.inner.get_session(session_id).await
206    }
207
208    async fn delete_session(&self, session_id: &str) -> Result<()> {
209        self.inner.delete_session(session_id).await
210    }
211
212    async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
213        self.inner.list_user_sessions(user_id).await
214    }
215
216    async fn count_active_sessions(&self) -> Result<u64> {
217        self.inner.count_active_sessions().await
218    }
219
220    // Key-value methods with encryption
221    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
222        let encrypted_value = self.encryption.encrypt_for_storage(value)?;
223        self.inner.store_kv(key, &encrypted_value, ttl).await
224    }
225
226    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
227        if let Some(encrypted_data) = self.inner.get_kv(key).await? {
228            let decrypted_data = self.encryption.decrypt_from_storage(&encrypted_data)?;
229            Ok(Some(decrypted_data))
230        } else {
231            Ok(None)
232        }
233    }
234
235    async fn delete_kv(&self, key: &str) -> Result<()> {
236        self.inner.delete_kv(key).await
237    }
238
239    async fn cleanup_expired(&self) -> Result<()> {
240        self.inner.cleanup_expired().await
241    }
242}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_generation() {
        let key = StorageEncryption::generate_key();
        assert!(!key.is_empty());

        // Should be valid base64
        let decoded = BASE64.decode(&key).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn test_generated_keys_are_unique() {
        assert_ne!(
            StorageEncryption::generate_key(),
            StorageEncryption::generate_key()
        );
    }

    #[test]
    fn test_encryption_roundtrip() {
        let encryption = StorageEncryption::new_random();
        let plaintext = "sensitive client secret data";

        let encrypted = encryption.encrypt(plaintext).unwrap();
        assert_ne!(encrypted.data, plaintext);
        assert_eq!(encrypted.algorithm, "AES-256-GCM");

        let decrypted = encryption.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        let encryption = StorageEncryption::new_random();
        let encrypted = encryption.encrypt("").unwrap();
        assert_eq!(encryption.decrypt(&encrypted).unwrap(), "");
    }

    #[test]
    fn test_nonce_is_fresh_per_encryption() {
        let encryption = StorageEncryption::new_random();
        let first = encryption.encrypt("same input").unwrap();
        let second = encryption.encrypt("same input").unwrap();

        // Identical plaintexts must not produce identical envelopes.
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.data, second.data);
    }

    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("integrity matters").unwrap();

        let mut raw = BASE64.decode(&encrypted.data).unwrap();
        raw[0] ^= 0x01;
        encrypted.data = BASE64.encode(&raw);

        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_wrong_key_is_rejected() {
        let writer = StorageEncryption::new_random();
        let reader = StorageEncryption::new_random();

        let encrypted = writer.encrypt("secret").unwrap();
        assert!(reader.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_unsupported_algorithm_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("secret").unwrap();
        encrypted.algorithm = "AES-128-CBC".to_string();

        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_invalid_nonce_length_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("secret").unwrap();
        encrypted.nonce = BASE64.encode([0u8; 8]);

        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_storage_encryption() {
        let encryption = StorageEncryption::new_random();
        let data = b"sensitive authentication data";

        let encrypted = encryption.encrypt_for_storage(data).unwrap();
        assert_ne!(encrypted, data);

        let decrypted = encryption.decrypt_from_storage(&encrypted).unwrap();
        assert_eq!(decrypted, data);
    }

    #[test]
    fn test_decrypt_from_storage_rejects_garbage() {
        let encryption = StorageEncryption::new_random();
        assert!(encryption.decrypt_from_storage(b"not an envelope").is_err());
        assert!(encryption.decrypt_from_storage(&[0xffu8, 0xfe]).is_err());
    }
}
282
283