auth_framework/storage/
encryption.rs1use crate::errors::{AuthError, Result};
2use crate::storage::{AuthStorage, SessionData};
3use crate::tokens::AuthToken;
4use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce, aead::Aead};
5use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
6use rand::RngCore;
7use serde::{Deserialize, Serialize};
8use std::env;
9use std::time::Duration;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct EncryptedData {
14 pub data: String,
16 pub nonce: String,
18 pub algorithm: String,
20 pub key_derivation: String,
22}
23
24pub struct StorageEncryption {
26 cipher: Aes256Gcm,
27}
28
29impl StorageEncryption {
30 pub fn new() -> Result<Self> {
32 let key_data = env::var("AUTH_STORAGE_ENCRYPTION_KEY").map_err(|_| {
33 AuthError::config("AUTH_STORAGE_ENCRYPTION_KEY environment variable not set")
34 })?;
35
36 let key_bytes = BASE64
37 .decode(&key_data)
38 .map_err(|_| AuthError::config("Invalid base64 in AUTH_STORAGE_ENCRYPTION_KEY"))?;
39
40 if key_bytes.len() != 32 {
41 return Err(AuthError::config(
42 "Encryption key must be 32 bytes (256 bits)",
43 ));
44 }
45
46 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
47 let cipher = Aes256Gcm::new(key);
48
49 Ok(Self { cipher })
50 }
51
52 #[cfg(test)]
54 pub fn new_random() -> Self {
55 use rand::RngCore;
56 let mut key_bytes = [0u8; 32];
57 rand::rng().fill_bytes(&mut key_bytes);
58 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
59 let cipher = Aes256Gcm::new(key);
60 Self { cipher }
61 }
62
63 pub fn generate_key() -> String {
65 let mut key_bytes = [0u8; 32];
66 rand::rng().fill_bytes(&mut key_bytes);
67 BASE64.encode(key_bytes)
68 }
69
70 pub fn encrypt(&self, plaintext: &str) -> Result<EncryptedData> {
72 let mut nonce_bytes = [0u8; 12]; rand::rng().fill_bytes(&mut nonce_bytes);
75 let nonce = Nonce::from_slice(&nonce_bytes);
76
77 let ciphertext = self
79 .cipher
80 .encrypt(nonce, plaintext.as_bytes())
81 .map_err(|e| AuthError::internal(format!("Encryption failed: {}", e)))?;
82
83 Ok(EncryptedData {
84 data: BASE64.encode(&ciphertext),
85 nonce: BASE64.encode(nonce_bytes),
86 algorithm: "AES-256-GCM".to_string(),
87 key_derivation: "direct".to_string(),
88 })
89 }
90
91 pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<String> {
93 if encrypted.algorithm != "AES-256-GCM" {
95 return Err(AuthError::internal(format!(
96 "Unsupported encryption algorithm: {}",
97 encrypted.algorithm
98 )));
99 }
100
101 let ciphertext = BASE64
103 .decode(&encrypted.data)
104 .map_err(|_| AuthError::internal("Invalid base64 in encrypted data"))?;
105
106 let nonce_bytes = BASE64
107 .decode(&encrypted.nonce)
108 .map_err(|_| AuthError::internal("Invalid base64 in nonce"))?;
109
110 if nonce_bytes.len() != 12 {
111 return Err(AuthError::internal("Invalid nonce length"));
112 }
113
114 let nonce = Nonce::from_slice(&nonce_bytes);
115
116 let plaintext = self
118 .cipher
119 .decrypt(nonce, ciphertext.as_ref())
120 .map_err(|e| AuthError::internal(format!("Decryption failed: {}", e)))?;
121
122 String::from_utf8(plaintext)
123 .map_err(|_| AuthError::internal("Decrypted data is not valid UTF-8"))
124 }
125
126 pub fn encrypt_for_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
128 let plaintext = String::from_utf8(data.to_vec())
129 .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
130
131 let encrypted = self.encrypt(&plaintext)?;
132 let serialized = serde_json::to_string(&encrypted).map_err(|e| {
133 AuthError::internal(format!("Failed to serialize encrypted data: {}", e))
134 })?;
135
136 Ok(serialized.into_bytes())
137 }
138
139 pub fn decrypt_from_storage(&self, data: &[u8]) -> Result<Vec<u8>> {
141 let serialized = String::from_utf8(data.to_vec())
142 .map_err(|_| AuthError::internal("Storage data is not valid UTF-8"))?;
143
144 let encrypted: EncryptedData = serde_json::from_str(&serialized).map_err(|e| {
145 AuthError::internal(format!("Failed to deserialize encrypted data: {}", e))
146 })?;
147
148 let plaintext = self.decrypt(&encrypted)?;
149 Ok(plaintext.into_bytes())
150 }
151}
152
153pub struct EncryptedStorage<T> {
155 inner: T,
156 encryption: StorageEncryption,
157}
158
159impl<T> EncryptedStorage<T> {
160 pub fn new(storage: T, encryption: StorageEncryption) -> Self {
161 Self {
162 inner: storage,
163 encryption,
164 }
165 }
166
167 pub fn into_inner(self) -> T {
168 self.inner
169 }
170}
171
172#[async_trait::async_trait]
173impl<T> AuthStorage for EncryptedStorage<T>
174where
175 T: AuthStorage + Send + Sync,
176{
177 async fn store_token(&self, token: &AuthToken) -> Result<()> {
179 self.inner.store_token(token).await
180 }
181
182 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
183 self.inner.get_token(token_id).await
184 }
185
186 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
187 self.inner.get_token_by_access_token(access_token).await
188 }
189
190 async fn update_token(&self, token: &AuthToken) -> Result<()> {
191 self.inner.update_token(token).await
192 }
193
194 async fn delete_token(&self, token_id: &str) -> Result<()> {
195 self.inner.delete_token(token_id).await
196 }
197
198 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
199 self.inner.list_user_tokens(user_id).await
200 }
201
202 async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
204 self.inner.store_session(session_id, data).await
205 }
206
207 async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
208 self.inner.get_session(session_id).await
209 }
210
211 async fn delete_session(&self, session_id: &str) -> Result<()> {
212 self.inner.delete_session(session_id).await
213 }
214
215 async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
216 self.inner.list_user_sessions(user_id).await
217 }
218
219 async fn count_active_sessions(&self) -> Result<u64> {
220 self.inner.count_active_sessions().await
221 }
222
223 async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
225 let encrypted_value = self.encryption.encrypt_for_storage(value)?;
226 self.inner.store_kv(key, &encrypted_value, ttl).await
227 }
228
229 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
230 if let Some(encrypted_data) = self.inner.get_kv(key).await? {
231 let decrypted_data = self.encryption.decrypt_from_storage(&encrypted_data)?;
232 Ok(Some(decrypted_data))
233 } else {
234 Ok(None)
235 }
236 }
237
238 async fn delete_kv(&self, key: &str) -> Result<()> {
239 self.inner.delete_kv(key).await
240 }
241
242 async fn cleanup_expired(&self) -> Result<()> {
243 self.inner.cleanup_expired().await
244 }
245}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Key generation ----

    #[test]
    fn test_key_generation() {
        let key = StorageEncryption::generate_key();
        assert!(!key.is_empty());

        let decoded = BASE64.decode(&key).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn test_generated_keys_are_unique() {
        // Two independent 256-bit random keys colliding would indicate a broken RNG.
        assert_ne!(
            StorageEncryption::generate_key(),
            StorageEncryption::generate_key()
        );
    }

    // ---- String encryption ----

    #[test]
    fn test_encryption_roundtrip() {
        let encryption = StorageEncryption::new_random();
        let plaintext = "sensitive client secret data";

        let encrypted = encryption.encrypt(plaintext).unwrap();
        assert_ne!(encrypted.data, plaintext);
        assert_eq!(encrypted.algorithm, "AES-256-GCM");

        let decrypted = encryption.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_nonce_is_fresh_per_encryption() {
        let encryption = StorageEncryption::new_random();

        let first = encryption.encrypt("same input").unwrap();
        let second = encryption.encrypt("same input").unwrap();

        // Nonce reuse under one key would break GCM confidentiality and integrity.
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.data, second.data);
    }

    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let writer = StorageEncryption::new_random();
        let reader = StorageEncryption::new_random();

        let encrypted = writer.encrypt("secret").unwrap();
        assert!(reader.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("secret").unwrap();

        let mut bytes = BASE64.decode(&encrypted.data).unwrap();
        bytes[0] ^= 0x01;
        encrypted.data = BASE64.encode(&bytes);

        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_unsupported_algorithm_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("secret").unwrap();
        encrypted.algorithm = "ChaCha20-Poly1305".to_string();

        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_invalid_nonce_length_is_rejected() {
        let encryption = StorageEncryption::new_random();
        let mut encrypted = encryption.encrypt("secret").unwrap();
        encrypted.nonce = BASE64.encode([0u8; 8]);

        // Must surface as an error rather than a panic inside `Nonce::from_slice`.
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    // ---- Storage byte helpers ----

    #[test]
    fn test_storage_encryption() {
        let encryption = StorageEncryption::new_random();
        let data = b"sensitive authentication data";

        let encrypted = encryption.encrypt_for_storage(data).unwrap();
        assert_ne!(encrypted, data);

        let decrypted = encryption.decrypt_from_storage(&encrypted).unwrap();
        assert_eq!(decrypted, data);
    }

    #[test]
    fn test_decrypt_from_storage_rejects_garbage() {
        let encryption = StorageEncryption::new_random();
        assert!(encryption.decrypt_from_storage(b"not json").is_err());
    }
}