auth_framework/storage/
encryption.rs1use crate::errors::{AuthError, Result};
2use crate::storage::{AuthStorage, SessionData};
3use crate::tokens::AuthToken;
4use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce, aead::Aead};
5use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
6use serde::{Deserialize, Serialize};
7use std::env;
8use std::time::Duration;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct EncryptedData {
13 pub data: String,
15 pub nonce: String,
17 pub algorithm: String,
19 pub key_derivation: String,
21}
22
23pub struct StorageEncryption {
25 cipher: Aes256Gcm,
26}
27
28impl StorageEncryption {
29 pub fn new() -> Result<Self> {
31 let key_data = env::var("AUTH_STORAGE_ENCRYPTION_KEY").map_err(|_| {
32 AuthError::config("AUTH_STORAGE_ENCRYPTION_KEY environment variable not set")
33 })?;
34
35 let key_bytes = BASE64
36 .decode(&key_data)
37 .map_err(|_| AuthError::config("Invalid base64 in AUTH_STORAGE_ENCRYPTION_KEY"))?;
38
39 if key_bytes.len() != 32 {
40 return Err(AuthError::config(
41 "Encryption key must be 32 bytes (256 bits)",
42 ));
43 }
44
45 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
46 let cipher = Aes256Gcm::new(key);
47
48 Ok(Self { cipher })
49 }
50
51 #[cfg(test)]
53 pub fn new_random() -> Self {
54 use rand::Rng;
55 let mut key_bytes = [0u8; 32];
56 rand::rng().fill_bytes(&mut key_bytes);
57 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
58 let cipher = Aes256Gcm::new(key);
59 Self { cipher }
60 }
61
62 pub fn generate_key() -> String {
64 use rand::Rng;
65 let mut key_bytes = [0u8; 32];
66 rand::rng().fill_bytes(&mut key_bytes);
67 BASE64.encode(key_bytes)
68 }
69
70 pub fn encrypt(&self, plaintext: &str) -> Result<EncryptedData> {
72 use rand::Rng;
73 let mut nonce_bytes = [0u8; 12]; rand::rng().fill_bytes(&mut nonce_bytes);
76 let nonce = Nonce::from_slice(&nonce_bytes);
77
78 let ciphertext = self
80 .cipher
81 .encrypt(nonce, plaintext.as_bytes())
82 .map_err(|e| AuthError::internal(format!("Encryption failed: {}", e)))?;
83
84 Ok(EncryptedData {
85 data: BASE64.encode(&ciphertext),
86 nonce: BASE64.encode(nonce_bytes),
87 algorithm: "AES-256-GCM".to_string(),
88 key_derivation: "direct".to_string(),
89 })
90 }
91
92 pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<String> {
94 if encrypted.algorithm != "AES-256-GCM" {
96 return Err(AuthError::internal(format!(
97 "Unsupported encryption algorithm: {}",
98 encrypted.algorithm
99 )));
100 }
101
102 let ciphertext = BASE64
104 .decode(&encrypted.data)
105 .map_err(|_| AuthError::internal("Invalid base64 in encrypted data"))?;
106
107 let nonce_bytes = BASE64
108 .decode(&encrypted.nonce)
109 .map_err(|_| AuthError::internal("Invalid base64 in nonce"))?;
110
111 if nonce_bytes.len() != 12 {
112 return Err(AuthError::internal("Invalid nonce length"));
113 }
114
115 let nonce = Nonce::from_slice(&nonce_bytes);
116
117 let plaintext = self
119 .cipher
120 .decrypt(nonce, ciphertext.as_ref())
121 .map_err(|e| AuthError::internal(format!("Decryption failed: {}", e)))?;
122
123 String::from_utf8(plaintext)
124 .map_err(|_| AuthError::internal("Decrypted data is not valid UTF-8"))
125 }
126
127 pub fn encrypt_for_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
129 let plaintext = String::from_utf8(data.to_vec())
130 .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
131
132 let encrypted = self.encrypt(&plaintext)?;
133 let serialized = serde_json::to_string(&encrypted).map_err(|e| {
134 AuthError::internal(format!("Failed to serialize encrypted data: {}", e))
135 })?;
136
137 Ok(serialized.into_bytes())
138 }
139
140 pub fn decrypt_from_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
142 let serialized = String::from_utf8(data.to_vec())
143 .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
144
145 let encrypted: EncryptedData = serde_json::from_str(&serialized).map_err(|e| {
146 AuthError::internal(format!("Failed to deserialize encrypted data: {}", e))
147 })?;
148
149 let plaintext = self.decrypt(&encrypted)?;
150 Ok(plaintext.into_bytes())
151 }
152}
153
154pub struct EncryptedStorage<T> {
156 inner: T,
157 encryption: StorageEncryption,
158}
159
160impl<T> EncryptedStorage<T> {
161 pub fn new(storage: T, encryption: StorageEncryption) -> Self {
162 Self {
163 inner: storage,
164 encryption,
165 }
166 }
167
168 pub fn into_inner(self) -> T {
169 self.inner
170 }
171}
172
173#[async_trait::async_trait]
174impl<T> AuthStorage for EncryptedStorage<T>
175where
176 T: AuthStorage + Send + Sync,
177{
178 async fn store_token(&self, token: &AuthToken) -> Result<()> {
183 self.inner.store_token(token).await
184 }
185
186 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
187 self.inner.get_token(token_id).await
188 }
189
190 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
191 self.inner.get_token_by_access_token(access_token).await
192 }
193
194 async fn update_token(&self, token: &AuthToken) -> Result<()> {
195 self.inner.update_token(token).await
196 }
197
198 async fn delete_token(&self, token_id: &str) -> Result<()> {
199 self.inner.delete_token(token_id).await
200 }
201
202 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
203 self.inner.list_user_tokens(user_id).await
204 }
205
206 async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
209 self.inner.store_session(session_id, data).await
210 }
211
212 async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
213 self.inner.get_session(session_id).await
214 }
215
216 async fn delete_session(&self, session_id: &str) -> Result<()> {
217 self.inner.delete_session(session_id).await
218 }
219
220 async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
221 self.inner.list_user_sessions(user_id).await
222 }
223
224 async fn count_active_sessions(&self) -> Result<u64> {
225 self.inner.count_active_sessions().await
226 }
227
228 async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
230 let encrypted_value = self.encryption.encrypt_for_storage(value)?;
231 self.inner.store_kv(key, &encrypted_value, ttl).await
232 }
233
234 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
235 if let Some(encrypted_data) = self.inner.get_kv(key).await? {
236 let decrypted_data = self.encryption.decrypt_from_storage(&encrypted_data)?;
237 Ok(Some(decrypted_data))
238 } else {
239 Ok(None)
240 }
241 }
242
243 async fn delete_kv(&self, key: &str) -> Result<()> {
244 self.inner.delete_kv(key).await
245 }
246
247 async fn cleanup_expired(&self) -> Result<()> {
248 self.inner.cleanup_expired().await
249 }
250}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_generation() {
        let key = StorageEncryption::generate_key();
        assert!(!key.is_empty());

        // The key must be base64 of exactly 32 bytes so that `new()` would accept it.
        let decoded = BASE64.decode(&key).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn test_generated_keys_differ() {
        assert_ne!(
            StorageEncryption::generate_key(),
            StorageEncryption::generate_key()
        );
    }

    #[test]
    fn test_encryption_roundtrip() {
        let encryption = StorageEncryption::new_random();
        let plaintext = "sensitive client secret data";

        let encrypted = encryption.encrypt(plaintext).unwrap();
        assert_ne!(encrypted.data, plaintext);
        assert_eq!(encrypted.algorithm, "AES-256-GCM");

        let decrypted = encryption.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        let encryption = StorageEncryption::new_random();
        let encrypted = encryption.encrypt("").unwrap();
        assert_eq!(encryption.decrypt(&encrypted).unwrap(), "");
    }

    #[test]
    fn test_nonce_is_fresh_per_encryption() {
        let encryption = StorageEncryption::new_random();
        let first = encryption.encrypt("same input").unwrap();
        let second = encryption.encrypt("same input").unwrap();

        // Identical plaintexts must not produce identical envelopes.
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.data, second.data);
    }

    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("integrity protected").unwrap();

        let mut raw = BASE64.decode(&encrypted.data).unwrap();
        raw[0] ^= 0x01;
        encrypted.data = BASE64.encode(&raw);

        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_wrong_key_is_rejected() {
        let encrypted = StorageEncryption::new_random().encrypt("secret").unwrap();
        assert!(StorageEncryption::new_random().decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_unsupported_algorithm_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("secret").unwrap();
        encrypted.algorithm = "ChaCha20-Poly1305".to_string();
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_invalid_nonce_length_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("secret").unwrap();
        encrypted.nonce = BASE64.encode([0u8; 8]);
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_storage_encryption() {
        let encryption = StorageEncryption::new_random();
        let data = b"sensitive authentication data";

        let encrypted = encryption.encrypt_for_storage(data).unwrap();
        assert_ne!(encrypted, data);

        let decrypted = encryption.decrypt_from_storage(&encrypted).unwrap();
        assert_eq!(decrypted, data);
    }

    #[test]
    fn test_storage_rejects_non_envelope_bytes() {
        let encryption = StorageEncryption::new_random();
        assert!(encryption.decrypt_from_storage(b"not json").is_err());
    }
}