auth_framework/storage/
encryption.rs1use crate::errors::{AuthError, Result};
2use crate::storage::{AuthStorage, SessionData};
3use crate::tokens::AuthToken;
4use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce, aead::Aead};
5use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
6use rand::{RngCore, rngs::OsRng};
7use serde::{Deserialize, Serialize};
8use std::env;
9use std::time::Duration;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct EncryptedData {
14 pub data: String,
16 pub nonce: String,
18 pub algorithm: String,
20 pub key_derivation: String,
22}
23
24pub struct StorageEncryption {
26 cipher: Aes256Gcm,
27}
28
29impl StorageEncryption {
30 pub fn new() -> Result<Self> {
32 let key_data = env::var("AUTH_STORAGE_ENCRYPTION_KEY").map_err(|_| {
33 AuthError::config("AUTH_STORAGE_ENCRYPTION_KEY environment variable not set")
34 })?;
35
36 let key_bytes = BASE64
37 .decode(&key_data)
38 .map_err(|_| AuthError::config("Invalid base64 in AUTH_STORAGE_ENCRYPTION_KEY"))?;
39
40 if key_bytes.len() != 32 {
41 return Err(AuthError::config(
42 "Encryption key must be 32 bytes (256 bits)",
43 ));
44 }
45
46 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
47 let cipher = Aes256Gcm::new(key);
48
49 Ok(Self { cipher })
50 }
51
52 #[cfg(test)]
54 pub fn new_random() -> Self {
55 let key = Aes256Gcm::generate_key(&mut OsRng);
56 let cipher = Aes256Gcm::new(&key);
57 Self { cipher }
58 }
59
60 pub fn generate_key() -> String {
62 let mut key_bytes = [0u8; 32];
63 OsRng.fill_bytes(&mut key_bytes);
64 BASE64.encode(key_bytes)
65 }
66
67 pub fn encrypt(&self, plaintext: &str) -> Result<EncryptedData> {
69 let mut nonce_bytes = [0u8; 12]; OsRng.fill_bytes(&mut nonce_bytes);
72 let nonce = Nonce::from_slice(&nonce_bytes);
73
74 let ciphertext = self
76 .cipher
77 .encrypt(nonce, plaintext.as_bytes())
78 .map_err(|e| AuthError::internal(format!("Encryption failed: {}", e)))?;
79
80 Ok(EncryptedData {
81 data: BASE64.encode(&ciphertext),
82 nonce: BASE64.encode(nonce_bytes),
83 algorithm: "AES-256-GCM".to_string(),
84 key_derivation: "direct".to_string(),
85 })
86 }
87
88 pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<String> {
90 if encrypted.algorithm != "AES-256-GCM" {
92 return Err(AuthError::internal(format!(
93 "Unsupported encryption algorithm: {}",
94 encrypted.algorithm
95 )));
96 }
97
98 let ciphertext = BASE64
100 .decode(&encrypted.data)
101 .map_err(|_| AuthError::internal("Invalid base64 in encrypted data"))?;
102
103 let nonce_bytes = BASE64
104 .decode(&encrypted.nonce)
105 .map_err(|_| AuthError::internal("Invalid base64 in nonce"))?;
106
107 if nonce_bytes.len() != 12 {
108 return Err(AuthError::internal("Invalid nonce length"));
109 }
110
111 let nonce = Nonce::from_slice(&nonce_bytes);
112
113 let plaintext = self
115 .cipher
116 .decrypt(nonce, ciphertext.as_ref())
117 .map_err(|e| AuthError::internal(format!("Decryption failed: {}", e)))?;
118
119 String::from_utf8(plaintext)
120 .map_err(|_| AuthError::internal("Decrypted data is not valid UTF-8"))
121 }
122
123 pub fn encrypt_for_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
125 let plaintext = String::from_utf8(data.to_vec())
126 .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
127
128 let encrypted = self.encrypt(&plaintext)?;
129 let serialized = serde_json::to_string(&encrypted).map_err(|e| {
130 AuthError::internal(format!("Failed to serialize encrypted data: {}", e))
131 })?;
132
133 Ok(serialized.into_bytes())
134 }
135
136 pub fn decrypt_from_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
138 let serialized = String::from_utf8(data.to_vec())
139 .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
140
141 let encrypted: EncryptedData = serde_json::from_str(&serialized).map_err(|e| {
142 AuthError::internal(format!("Failed to deserialize encrypted data: {}", e))
143 })?;
144
145 let plaintext = self.decrypt(&encrypted)?;
146 Ok(plaintext.into_bytes())
147 }
148}
149
150pub struct EncryptedStorage<T> {
152 inner: T,
153 encryption: StorageEncryption,
154}
155
156impl<T> EncryptedStorage<T> {
157 pub fn new(storage: T, encryption: StorageEncryption) -> Self {
158 Self {
159 inner: storage,
160 encryption,
161 }
162 }
163
164 pub fn into_inner(self) -> T {
165 self.inner
166 }
167}
168
169#[async_trait::async_trait]
170impl<T> AuthStorage for EncryptedStorage<T>
171where
172 T: AuthStorage + Send + Sync,
173{
174 async fn store_token(&self, token: &AuthToken) -> Result<()> {
176 self.inner.store_token(token).await
177 }
178
179 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
180 self.inner.get_token(token_id).await
181 }
182
183 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
184 self.inner.get_token_by_access_token(access_token).await
185 }
186
187 async fn update_token(&self, token: &AuthToken) -> Result<()> {
188 self.inner.update_token(token).await
189 }
190
191 async fn delete_token(&self, token_id: &str) -> Result<()> {
192 self.inner.delete_token(token_id).await
193 }
194
195 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
196 self.inner.list_user_tokens(user_id).await
197 }
198
199 async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
201 self.inner.store_session(session_id, data).await
202 }
203
204 async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
205 self.inner.get_session(session_id).await
206 }
207
208 async fn delete_session(&self, session_id: &str) -> Result<()> {
209 self.inner.delete_session(session_id).await
210 }
211
212 async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
213 self.inner.list_user_sessions(user_id).await
214 }
215
216 async fn count_active_sessions(&self) -> Result<u64> {
217 self.inner.count_active_sessions().await
218 }
219
220 async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
222 let encrypted_value = self.encryption.encrypt_for_storage(value)?;
223 self.inner.store_kv(key, &encrypted_value, ttl).await
224 }
225
226 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
227 if let Some(encrypted_data) = self.inner.get_kv(key).await? {
228 let decrypted_data = self.encryption.decrypt_from_storage(&encrypted_data)?;
229 Ok(Some(decrypted_data))
230 } else {
231 Ok(None)
232 }
233 }
234
235 async fn delete_kv(&self, key: &str) -> Result<()> {
236 self.inner.delete_kv(key).await
237 }
238
239 async fn cleanup_expired(&self) -> Result<()> {
240 self.inner.cleanup_expired().await
241 }
242}
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated key is non-empty base64 that decodes to 32 bytes.
    #[test]
    fn test_key_generation() {
        let encoded = StorageEncryption::generate_key();
        assert!(!encoded.is_empty());

        let raw = BASE64.decode(&encoded).unwrap();
        assert_eq!(raw.len(), 32);
    }

    /// A string survives an encrypt/decrypt round trip and the envelope is
    /// labelled with the expected algorithm.
    #[test]
    fn test_encryption_roundtrip() {
        let crypto = StorageEncryption::new_random();
        let secret = "sensitive client secret data";

        let envelope = crypto.encrypt(secret).unwrap();
        assert_ne!(envelope.data, secret);
        assert_eq!(envelope.algorithm, "AES-256-GCM");

        let recovered = crypto.decrypt(&envelope).unwrap();
        assert_eq!(recovered, secret);
    }

    /// A byte payload survives the storage-level round trip and the stored
    /// form differs from the input.
    #[test]
    fn test_storage_encryption() {
        let crypto = StorageEncryption::new_random();
        let payload = b"sensitive authentication data";

        let stored = crypto.encrypt_for_storage(payload).unwrap();
        assert_ne!(stored, payload);

        let recovered = crypto.decrypt_from_storage(&stored).unwrap();
        assert_eq!(recovered, payload);
    }
}
282
283