1use crate::crypto::dek_cache::{CacheKey, DekCache, DekMaterial};
2use crate::errors::{DecryptError, DecryptResult, Error, Result};
3use crate::key_provider::KeyProvider;
4use crate::types::{EncryptionAlgorithm, Envelope, Scope, SecretMeta, SecretRecord};
5#[allow(deprecated)]
6use aes_gcm::aead::generic_array::GenericArray;
7use aes_gcm::aead::{Aead, KeyInit};
8use aes_gcm::Aes256Gcm;
9use hkdf::Hkdf;
10use rand::RngCore;
11use sha2::Sha256;
12use std::env;
13
14#[cfg(feature = "xchacha")]
15use chacha20poly1305::{Nonce as XNonce, XChaCha20Poly1305};
16
/// Environment variable read by `EnvelopeService::from_env` to select the cipher.
const ENC_ALGO_ENV: &str = "SECRETS_ENC_ALGO";
/// Length in bytes of freshly generated data-encryption keys (256 bits).
const DEFAULT_DEK_LEN: usize = 256 / 8;
/// Length in bytes of the random HKDF salt generated for every encrypted record.
const HKDF_SALT_LEN: usize = 32;
20
21pub struct EnvelopeService<P>
23where
24 P: KeyProvider,
25{
26 provider: P,
27 cache: DekCache,
28 algorithm: EncryptionAlgorithm,
29}
30
31impl<P> EnvelopeService<P>
32where
33 P: KeyProvider,
34{
35 pub fn new(provider: P, cache: DekCache, algorithm: EncryptionAlgorithm) -> Self {
37 Self {
38 provider,
39 cache,
40 algorithm,
41 }
42 }
43
44 pub fn from_env(provider: P) -> Result<Self> {
46 let algorithm = env::var(ENC_ALGO_ENV)
47 .ok()
48 .filter(|s| !s.trim().is_empty())
49 .map(|value| value.parse())
50 .transpose()?
51 .unwrap_or_default();
52
53 Ok(Self::new(provider, DekCache::from_env(), algorithm))
54 }
55
56 pub fn algorithm(&self) -> EncryptionAlgorithm {
58 self.algorithm
59 }
60
61 pub fn cache(&self) -> &DekCache {
63 &self.cache
64 }
65
66 pub fn cache_mut(&mut self) -> &mut DekCache {
68 &mut self.cache
69 }
70
71 pub fn encrypt_record(&mut self, meta: SecretMeta, plaintext: &[u8]) -> Result<SecretRecord> {
73 let cache_key = CacheKey::from_meta(&meta);
74 let scope = meta.scope().clone();
75 let info = meta.uri.to_string();
76
77 let (dek, wrapped) = self.obtain_dek(&cache_key, &scope)?;
78
79 let salt = random_bytes(HKDF_SALT_LEN);
80 let key = derive_key(&dek, &salt, info.as_bytes())?;
81 let nonce = random_bytes(self.algorithm.nonce_len());
82 let ciphertext = encrypt_with_algorithm(self.algorithm, &key, &nonce, plaintext)?;
83
84 let envelope = Envelope {
85 algorithm: self.algorithm,
86 nonce,
87 hkdf_salt: salt,
88 wrapped_dek: wrapped.clone(),
89 };
90
91 Ok(SecretRecord::new(meta, ciphertext, envelope))
92 }
93
94 fn obtain_dek(&mut self, cache_key: &CacheKey, scope: &Scope) -> Result<(Vec<u8>, Vec<u8>)> {
95 if let Some(material) = self.cache.get(cache_key) {
96 return Ok((material.dek, material.wrapped));
97 }
98
99 let dek = generate_dek();
100 let wrapped = self.provider.wrap_dek(scope, &dek)?;
101 self.cache
102 .insert(cache_key.clone(), dek.clone(), wrapped.clone());
103 Ok((dek, wrapped))
104 }
105
106 pub fn decrypt_record(&mut self, record: &SecretRecord) -> DecryptResult<Vec<u8>> {
108 let cache_key = CacheKey::from_meta(&record.meta);
109 let scope = record.meta.scope();
110 let algorithm = record.envelope.algorithm;
111 let info = record.meta.uri.to_string();
112
113 let material = match self.cache.get(&cache_key) {
114 Some(material) => material,
115 None => {
116 let dek = self
117 .provider
118 .unwrap_dek(scope, &record.envelope.wrapped_dek)
119 .map_err(|err| DecryptError::Provider(err.to_string()))?;
120 let material = DekMaterial {
121 dek: dek.clone(),
122 wrapped: record.envelope.wrapped_dek.clone(),
123 };
124 self.cache.insert(
125 cache_key.clone(),
126 material.dek.clone(),
127 material.wrapped.clone(),
128 );
129 material
130 }
131 };
132
133 let key = derive_key(&material.dek, &record.envelope.hkdf_salt, info.as_bytes())
134 .map_err(|err| DecryptError::Crypto(err.to_string()))?;
135 let plaintext =
136 decrypt_with_algorithm(algorithm, &key, &record.envelope.nonce, &record.value)?;
137
138 Ok(plaintext)
139 }
140}
141
142#[allow(deprecated)]
143fn encrypt_with_algorithm(
144 algorithm: EncryptionAlgorithm,
145 key: &[u8; 32],
146 nonce: &[u8],
147 plaintext: &[u8],
148) -> Result<Vec<u8>> {
149 match algorithm {
150 EncryptionAlgorithm::Aes256Gcm => {
151 let cipher = Aes256Gcm::new_from_slice(key)
152 .map_err(|_| Error::Crypto("invalid AES key".into()))?;
153 let nonce = GenericArray::clone_from_slice(nonce);
154 cipher
155 .encrypt(&nonce, plaintext)
156 .map_err(|_| Error::Crypto("failed to encrypt payload".into()))
157 }
158 EncryptionAlgorithm::XChaCha20Poly1305 => {
159 #[cfg(feature = "xchacha")]
160 {
161 let cipher = XChaCha20Poly1305::new_from_slice(key)
162 .map_err(|_| Error::Crypto("invalid XChaCha key".into()))?;
163 let nonce = XNonce::from_slice(nonce);
164 cipher
165 .encrypt(nonce, plaintext)
166 .map_err(|_| Error::Crypto("failed to encrypt payload".into()))
167 }
168 #[cfg(not(feature = "xchacha"))]
169 {
170 Err(Error::AlgorithmFeatureUnavailable(
171 algorithm.as_str().to_string(),
172 ))
173 }
174 }
175 }
176}
177
178#[allow(deprecated)]
179fn decrypt_with_algorithm(
180 algorithm: EncryptionAlgorithm,
181 key: &[u8; 32],
182 nonce: &[u8],
183 ciphertext: &[u8],
184) -> DecryptResult<Vec<u8>> {
185 match algorithm {
186 EncryptionAlgorithm::Aes256Gcm => {
187 let cipher = Aes256Gcm::new_from_slice(key)
188 .map_err(|_| DecryptError::Crypto("invalid AES key".into()))?;
189 let nonce = GenericArray::clone_from_slice(nonce);
190 cipher
191 .decrypt(&nonce, ciphertext)
192 .map_err(|_| DecryptError::MacMismatch)
193 }
194 EncryptionAlgorithm::XChaCha20Poly1305 => {
195 #[cfg(feature = "xchacha")]
196 {
197 let cipher = XChaCha20Poly1305::new_from_slice(key)
198 .map_err(|_| DecryptError::Crypto("invalid XChaCha key".into()))?;
199 let nonce = XNonce::from_slice(nonce);
200 cipher
201 .decrypt(nonce, ciphertext)
202 .map_err(|_| DecryptError::MacMismatch)
203 }
204 #[cfg(not(feature = "xchacha"))]
205 {
206 Err(DecryptError::Crypto(format!(
207 "algorithm {algorithm} unavailable"
208 )))
209 }
210 }
211 }
212}
213
214fn derive_key(dek: &[u8], salt: &[u8], info: &[u8]) -> Result<[u8; 32]> {
215 let hkdf = Hkdf::<Sha256>::new(Some(salt), dek);
216 let mut okm = [0u8; 32];
217 hkdf.expand(info, &mut okm)
218 .map_err(|_| Error::Crypto("failed to derive key material".into()))?;
219 Ok(okm)
220}
221
222fn generate_dek() -> Vec<u8> {
223 random_bytes(DEFAULT_DEK_LEN)
224}
225
226fn random_bytes(len: usize) -> Vec<u8> {
227 let mut buffer = vec![0u8; len];
228 let mut rng = rand::rng();
229 rng.fill_bytes(&mut buffer);
230 buffer
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::dek_cache::DekCache;
    use crate::key_provider::KeyProvider;
    use crate::types::{ContentType, Scope, SecretMeta, Visibility};
    use crate::uri::SecretUri;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// Test key provider: "wraps" a DEK by XOR-ing each byte with 0xAA and counts calls.
    /// Clones share the same counters.
    #[derive(Clone)]
    struct DummyProvider {
        wrap_calls: Arc<Mutex<usize>>,
        unwrap_calls: Arc<Mutex<usize>>,
    }

    /// Increments a shared call counter by one.
    fn bump(counter: &Mutex<usize>) {
        let mut count = counter.lock().unwrap();
        *count += 1;
    }

    /// Applies the self-inverse XOR transform used by `DummyProvider`.
    fn xor_aa(bytes: &[u8]) -> Vec<u8> {
        bytes.iter().map(|b| b ^ 0xAA).collect()
    }

    impl DummyProvider {
        fn new() -> Self {
            let counter = || Arc::new(Mutex::new(0usize));
            Self {
                wrap_calls: counter(),
                unwrap_calls: counter(),
            }
        }

        /// Returns `(wrap_calls, unwrap_calls)` observed so far.
        fn calls(&self) -> (usize, usize) {
            let wraps = *self.wrap_calls.lock().unwrap();
            let unwraps = *self.unwrap_calls.lock().unwrap();
            (wraps, unwraps)
        }
    }

    impl KeyProvider for DummyProvider {
        fn wrap_dek(&self, _scope: &Scope, dek: &[u8]) -> Result<Vec<u8>> {
            bump(&self.wrap_calls);
            Ok(xor_aa(dek))
        }

        fn unwrap_dek(&self, _scope: &Scope, wrapped: &[u8]) -> Result<Vec<u8>> {
            bump(&self.unwrap_calls);
            Ok(xor_aa(wrapped))
        }
    }

    /// Builds an AES-256-GCM service over `provider` with an 8-entry cache and `ttl_secs` TTL.
    fn aes_service(provider: DummyProvider, ttl_secs: u64) -> EnvelopeService<DummyProvider> {
        EnvelopeService::new(
            provider,
            DekCache::new(8, Duration::from_secs(ttl_secs)),
            EncryptionAlgorithm::Aes256Gcm,
        )
    }

    /// Metadata for `kv/api` version `v1` in the `prod`/`acme` scope, optionally team-scoped.
    fn sample_meta(team: Option<&str>) -> SecretMeta {
        let scope = Scope::new(
            "prod".to_string(),
            "acme".to_string(),
            team.map(String::from),
        )
        .unwrap();
        let versioned = SecretUri::new(scope.clone(), "kv", "api")
            .unwrap()
            .with_version(Some("v1"));
        let uri = versioned.unwrap();
        SecretMeta::new(uri, Visibility::Team, ContentType::Opaque)
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let mut service = aes_service(DummyProvider::new(), 300);
        let meta = sample_meta(Some("payments"));
        let plaintext = b"super-secret-data";

        let record = service
            .encrypt_record(meta.clone(), plaintext)
            .expect("encrypt");
        let recovered = service.decrypt_record(&record).expect("decrypt");

        assert_eq!(plaintext.to_vec(), recovered);
        assert_eq!(record.meta, meta);
    }

    #[test]
    fn tamper_detection() {
        let mut service = aes_service(DummyProvider::new(), 300);
        let meta = sample_meta(Some("payments"));

        let mut record = service.encrypt_record(meta, b"critical").expect("encrypt");
        // Flip every bit of the first ciphertext byte; authentication must reject it.
        record.value[0] ^= 0xFF;

        let err = service.decrypt_record(&record).unwrap_err();
        assert!(matches!(err, DecryptError::MacMismatch));
    }

    #[test]
    fn cache_hit_and_miss_behavior() {
        let provider = DummyProvider::new();
        let mut service = aes_service(provider.clone(), 300);
        let meta = sample_meta(Some("payments"));
        let plaintext = b"payload";

        service
            .encrypt_record(meta.clone(), plaintext)
            .expect("encrypt");
        assert_eq!(provider.calls().0, 1);

        service
            .encrypt_record(meta.clone(), plaintext)
            .expect("encrypt again");
        assert_eq!(provider.calls().0, 1, "expected cache hit to avoid wrapping");

        let wraps_before = provider.calls().0;
        let mut fresh_service = aes_service(provider.clone(), 0);
        fresh_service
            .encrypt_record(meta, plaintext)
            .expect("encrypt with fresh cache");
        assert!(
            provider.calls().0 > wraps_before,
            "expected miss to invoke wrap again"
        );
    }
}
359}