1use crate::crypto::dek_cache::{CacheKey, DekCache, DekMaterial};
2use crate::key_provider::KeyProvider;
3use crate::spec_compat::{
4 DecryptError, DecryptResult, EncryptionAlgorithm, Envelope, Error, Result, Scope, SecretMeta,
5 SecretRecord,
6};
7use base64::{engine::general_purpose::STANDARD, Engine};
8use hkdf::Hkdf;
9use rand::RngCore;
10#[cfg(feature = "crypto-ring")]
11use ring::{
12 aead,
13 rand::{SecureRandom, SystemRandom},
14};
15use sha2::Sha256;
16use std::env;
17
18#[cfg(feature = "xchacha")]
19use chacha20poly1305::{aead::Aead, KeyInit, XChaCha20Poly1305, XNonce};
20
/// Length in bytes of a freshly generated data-encryption key (used by `generate_dek`).
const DEFAULT_DEK_LEN: usize = 32;
/// Length in bytes of the random HKDF salt drawn for every encrypted record.
const HKDF_SALT_LEN: usize = 32;
/// Nonce length in bytes required by the XChaCha20-Poly1305 code paths.
#[cfg(feature = "xchacha")]
const XCHACHA_NONCE_LEN: usize = 24;
/// Nonce length in bytes used by the ring-backed AES-256-GCM `seal_aead` / `open_aead`.
#[cfg(feature = "crypto-ring")]
const NONCE_LEN: usize = 12;
/// Extra capacity in bytes reserved for the authentication tag appended by `seal_aead`.
#[cfg(feature = "crypto-ring")]
const TAG_LEN: usize = 16;
/// Environment variable consulted by `EnvelopeService::from_env` to pick the algorithm.
const ENC_ALGO_ENV: &str = "SECRETS_ENC_ALGO";
30
// Fail the build early when neither AEAD backend feature is selected; without one of
// them there is no `seal_aead` / `open_aead` implementation in this file.
#[cfg(not(any(feature = "crypto-ring", feature = "crypto-none")))]
compile_error!("Enable either the `crypto-ring` or `crypto-none` feature for envelope encryption");
33
34pub struct EnvelopeService<P>
36where
37 P: KeyProvider,
38{
39 provider: P,
40 cache: DekCache,
41 algorithm: EncryptionAlgorithm,
42}
43
44impl<P> EnvelopeService<P>
45where
46 P: KeyProvider,
47{
48 pub fn new(provider: P, cache: DekCache, algorithm: EncryptionAlgorithm) -> Self {
50 Self {
51 provider,
52 cache,
53 algorithm,
54 }
55 }
56
57 pub fn from_env(provider: P) -> Result<Self> {
59 let algorithm = env::var(ENC_ALGO_ENV)
60 .ok()
61 .filter(|s| !s.trim().is_empty())
62 .map(|value| value.parse())
63 .transpose()?
64 .unwrap_or_default();
65
66 Ok(Self::new(provider, DekCache::from_env(), algorithm))
67 }
68
69 pub fn algorithm(&self) -> EncryptionAlgorithm {
71 self.algorithm
72 }
73
74 pub fn cache(&self) -> &DekCache {
76 &self.cache
77 }
78
79 pub fn cache_mut(&mut self) -> &mut DekCache {
81 &mut self.cache
82 }
83
84 pub fn encrypt_record(&mut self, meta: SecretMeta, plaintext: &[u8]) -> Result<SecretRecord> {
86 let cache_key = CacheKey::from_meta(&meta);
87 let scope = meta.scope().clone();
88 let info = meta.uri.to_string();
89
90 let (dek, wrapped) = self.obtain_dek(&cache_key, &scope)?;
91
92 let salt = random_bytes(HKDF_SALT_LEN);
93 let key = derive_key(&dek, &salt, info.as_bytes())?;
94 let (nonce, ciphertext) = encrypt_with_algorithm(self.algorithm, &key, plaintext)?;
95
96 let envelope = Envelope {
97 algorithm: self.algorithm,
98 nonce,
99 hkdf_salt: salt,
100 wrapped_dek: wrapped.clone(),
101 };
102
103 Ok(SecretRecord::new(meta, ciphertext, envelope))
104 }
105
106 fn obtain_dek(&mut self, cache_key: &CacheKey, scope: &Scope) -> Result<(Vec<u8>, Vec<u8>)> {
107 if let Some(material) = self.cache.get(cache_key) {
108 return Ok((material.dek, material.wrapped));
109 }
110
111 let dek = generate_dek();
112 let wrapped = self.provider.wrap_dek(scope, &dek)?;
113 self.cache
114 .insert(cache_key.clone(), dek.clone(), wrapped.clone());
115 Ok((dek, wrapped))
116 }
117
118 pub fn decrypt_record(&mut self, record: &SecretRecord) -> DecryptResult<Vec<u8>> {
120 let cache_key = CacheKey::from_meta(&record.meta);
121 let scope = record.meta.scope();
122 let algorithm = record.envelope.algorithm;
123 let info = record.meta.uri.to_string();
124
125 let material = match self.cache.get(&cache_key) {
126 Some(material) => material,
127 None => {
128 let dek = self
129 .provider
130 .unwrap_dek(scope, &record.envelope.wrapped_dek)
131 .map_err(|err| DecryptError::Provider(err.to_string()))?;
132 let material = DekMaterial {
133 dek: dek.clone(),
134 wrapped: record.envelope.wrapped_dek.clone(),
135 };
136 self.cache.insert(
137 cache_key.clone(),
138 material.dek.clone(),
139 material.wrapped.clone(),
140 );
141 material
142 }
143 };
144
145 let key = derive_key(&material.dek, &record.envelope.hkdf_salt, info.as_bytes())
146 .map_err(|err| DecryptError::Crypto(err.to_string()))?;
147 let plaintext =
148 decrypt_with_algorithm(algorithm, &key, &record.envelope.nonce, &record.value)?;
149
150 Ok(plaintext)
151 }
152}
153
154fn encrypt_with_algorithm(
155 algorithm: EncryptionAlgorithm,
156 key: &[u8; 32],
157 plaintext: &[u8],
158) -> Result<(Vec<u8>, Vec<u8>)> {
159 match algorithm {
160 EncryptionAlgorithm::Aes256Gcm => {
161 let sealed = seal_aead(key, plaintext).map_err(|err| Error::Crypto(err.to_string()))?;
162 let data = STANDARD
163 .decode(sealed)
164 .map_err(|err| Error::Crypto(err.to_string()))?;
165 let nonce_len = EncryptionAlgorithm::Aes256Gcm.nonce_len();
166 if data.len() < nonce_len {
167 return Err(Error::Crypto("ciphertext too short".into()));
168 }
169 let (nonce, ciphertext) = data.split_at(nonce_len);
170 Ok((nonce.to_vec(), ciphertext.to_vec()))
171 }
172 EncryptionAlgorithm::XChaCha20Poly1305 => {
173 #[cfg(feature = "xchacha")]
174 {
175 let cipher = XChaCha20Poly1305::new_from_slice(key)
176 .map_err(|_| Error::Crypto("invalid XChaCha key".into()))?;
177 let nonce_bytes = random_bytes(EncryptionAlgorithm::XChaCha20Poly1305.nonce_len());
178 let nonce_array: &[u8; XCHACHA_NONCE_LEN] = nonce_bytes
179 .as_slice()
180 .try_into()
181 .map_err(|_| Error::Crypto("invalid XChaCha nonce length".into()))?;
182 let nonce = XNonce::from(*nonce_array);
183 cipher
184 .encrypt(&nonce, plaintext)
185 .map(|ciphertext| (nonce_bytes, ciphertext))
186 .map_err(|_| Error::Crypto("failed to encrypt payload".into()))
187 }
188 #[cfg(not(feature = "xchacha"))]
189 {
190 Err(Error::AlgorithmFeatureUnavailable(
191 algorithm.as_str().to_string(),
192 ))
193 }
194 }
195 }
196}
197
198fn decrypt_with_algorithm(
199 algorithm: EncryptionAlgorithm,
200 key: &[u8; 32],
201 nonce: &[u8],
202 ciphertext: &[u8],
203) -> DecryptResult<Vec<u8>> {
204 match algorithm {
205 EncryptionAlgorithm::Aes256Gcm => {
206 let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len());
207 combined.extend_from_slice(nonce);
208 combined.extend_from_slice(ciphertext);
209 let encoded = STANDARD.encode(combined);
210 match open_aead(key, &encoded) {
211 Ok(bytes) => Ok(bytes),
212 Err(Error::Backend(message)) if message == "open failed" => {
213 Err(DecryptError::MacMismatch)
214 }
215 Err(err) => Err(DecryptError::Crypto(err.to_string())),
216 }
217 }
218 EncryptionAlgorithm::XChaCha20Poly1305 => {
219 #[cfg(feature = "xchacha")]
220 {
221 let cipher = XChaCha20Poly1305::new_from_slice(key)
222 .map_err(|_| DecryptError::Crypto("invalid XChaCha key".into()))?;
223 let nonce_array: &[u8; XCHACHA_NONCE_LEN] = nonce
224 .try_into()
225 .map_err(|_| DecryptError::Crypto("invalid XChaCha nonce length".into()))?;
226 let nonce = XNonce::from(*nonce_array);
227 cipher
228 .decrypt(&nonce, ciphertext)
229 .map_err(|_| DecryptError::MacMismatch)
230 }
231 #[cfg(not(feature = "xchacha"))]
232 {
233 Err(DecryptError::Crypto(format!(
234 "algorithm {algorithm} unavailable"
235 )))
236 }
237 }
238 }
239}
240
/// Seals `plaintext` with AES-256-GCM (ring backend) under `key_bytes`.
///
/// A random `NONCE_LEN`-byte nonce is drawn from `SystemRandom`, no additional
/// authenticated data is supplied, and the result is the base64 encoding of
/// `nonce || ciphertext || tag`.
///
/// # Errors
///
/// Returns `Error::Backend` when the RNG fails, the key is rejected, or sealing fails.
#[cfg(feature = "crypto-ring")]
fn seal_aead(key_bytes: &[u8], plaintext: &[u8]) -> Result<String> {
    let mut nonce = [0u8; NONCE_LEN];
    SystemRandom::new()
        .fill(&mut nonce)
        .map_err(|err| Error::Backend(format!("rng: {err:?}")))?;

    let sealing_key = aead::UnboundKey::new(&aead::AES_256_GCM, key_bytes)
        .map(aead::LessSafeKey::new)
        .map_err(|_| Error::Backend("invalid key".into()))?;

    // Size the buffer up front so appending the tag does not need to reallocate.
    let mut sealed = Vec::with_capacity(plaintext.len() + TAG_LEN);
    sealed.extend_from_slice(plaintext);
    sealing_key
        .seal_in_place_append_tag(
            aead::Nonce::assume_unique_for_key(nonce),
            aead::Aad::empty(),
            &mut sealed,
        )
        .map_err(|_| Error::Backend("seal failed".into()))?;

    Ok(STANDARD.encode([nonce.as_slice(), sealed.as_slice()].concat()))
}
266
/// Opens a base64 `nonce || ciphertext || tag` blob produced by `seal_aead` (ring backend).
///
/// # Errors
///
/// * `Error::Invalid` when `b64` is not valid base64, is shorter than `NONCE_LEN` once
///   decoded, or the nonce is rejected.
/// * `Error::Backend("invalid key")` when `key_bytes` is not accepted for AES-256-GCM.
/// * `Error::Backend("open failed")` when authentication/decryption fails.
#[cfg(feature = "crypto-ring")]
fn open_aead(key_bytes: &[u8], b64: &str) -> Result<Vec<u8>> {
    let mut nonce_bytes = STANDARD
        .decode(b64)
        .map_err(|_| Error::Invalid("ciphertext".into(), "b64".into()))?;
    if nonce_bytes.len() < NONCE_LEN {
        return Err(Error::Invalid("ciphertext".into(), "too short".into()));
    }
    // After the split, `nonce_bytes` holds only the nonce and `buffer` the sealed payload.
    let mut buffer = nonce_bytes.split_off(NONCE_LEN);

    let opening_key = aead::UnboundKey::new(&aead::AES_256_GCM, key_bytes)
        .map(aead::LessSafeKey::new)
        .map_err(|_| Error::Backend("invalid key".into()))?;

    let nonce = aead::Nonce::try_assume_unique_for_key(&nonce_bytes)
        .map_err(|_| Error::Invalid("nonce".into(), "invalid length".into()))?;

    // The returned plaintext slice is the front of `buffer`, so only its length is
    // needed to drop the trailing tag bytes.
    let plaintext_len = opening_key
        .open_in_place(nonce, aead::Aad::empty(), &mut buffer)
        .map_err(|_| Error::Backend("open failed".into()))?
        .len();
    buffer.truncate(plaintext_len);

    Ok(buffer)
}
293
/// No-op stand-in for `seal_aead` when only `crypto-none` is enabled.
///
/// Performs no encryption: the key is ignored and the plaintext is simply
/// base64-encoded, with no nonce or tag added.
///
/// NOTE(review): `encrypt_with_algorithm` splits `nonce_len()` bytes off the front of the
/// decoded output and rejects shorter data with "ciphertext too short", so in this mode
/// plaintexts shorter than that length appear to be rejected — confirm this is intended.
#[cfg(all(feature = "crypto-none", not(feature = "crypto-ring")))]
fn seal_aead(_key_bytes: &[u8], plaintext: &[u8]) -> Result<String> {
    let encoded = STANDARD.encode(plaintext);
    Ok(encoded)
}
298
/// No-op stand-in for `open_aead` when only `crypto-none` is enabled.
///
/// Ignores the key and returns the base64-decoded bytes unchanged.
///
/// # Errors
///
/// Returns `Error::Invalid` when `b64` is not valid base64.
#[cfg(all(feature = "crypto-none", not(feature = "crypto-ring")))]
fn open_aead(_key_bytes: &[u8], b64: &str) -> Result<Vec<u8>> {
    match STANDARD.decode(b64) {
        Ok(bytes) => Ok(bytes),
        Err(_) => Err(Error::Invalid("ciphertext".into(), "b64".into())),
    }
}
305
306fn derive_key(dek: &[u8], salt: &[u8], info: &[u8]) -> Result<[u8; 32]> {
307 let hkdf = Hkdf::<Sha256>::new(Some(salt), dek);
308 let mut okm = [0u8; 32];
309 hkdf.expand(info, &mut okm)
310 .map_err(|_| Error::Crypto("failed to derive key material".into()))?;
311 Ok(okm)
312}
313
314fn generate_dek() -> Vec<u8> {
315 random_bytes(DEFAULT_DEK_LEN)
316}
317
318fn random_bytes(len: usize) -> Vec<u8> {
319 let mut buffer = vec![0u8; len];
320 let mut rng = rand::rng();
321 rng.fill_bytes(&mut buffer);
322 buffer
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::dek_cache::DekCache;
    use crate::key_provider::KeyProvider;
    use crate::spec_compat::{ContentType, Scope, SecretMeta, Visibility};
    use crate::uri::SecretUri;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// Byte mask the dummy provider XORs over a DEK to "wrap" and "unwrap" it.
    const XOR_MASK: u8 = 0xAA;

    /// Key provider stub that XOR-masks DEKs and counts how often it is called.
    #[derive(Clone)]
    struct DummyProvider {
        wrap_calls: Arc<Mutex<usize>>,
        unwrap_calls: Arc<Mutex<usize>>,
    }

    impl DummyProvider {
        fn new() -> Self {
            Self {
                wrap_calls: Arc::default(),
                unwrap_calls: Arc::default(),
            }
        }

        /// Returns `(wrap_calls, unwrap_calls)` observed so far.
        fn calls(&self) -> (usize, usize) {
            let wraps = *self.wrap_calls.lock().unwrap();
            let unwraps = *self.unwrap_calls.lock().unwrap();
            (wraps, unwraps)
        }
    }

    /// Applies the XOR mask to every byte; applying it twice restores the input.
    fn xor_mask(bytes: &[u8]) -> Vec<u8> {
        bytes.iter().map(|byte| byte ^ XOR_MASK).collect()
    }

    impl KeyProvider for DummyProvider {
        fn wrap_dek(&self, _scope: &Scope, dek: &[u8]) -> Result<Vec<u8>> {
            *self.wrap_calls.lock().unwrap() += 1;
            Ok(xor_mask(dek))
        }

        fn unwrap_dek(&self, _scope: &Scope, wrapped: &[u8]) -> Result<Vec<u8>> {
            *self.unwrap_calls.lock().unwrap() += 1;
            Ok(xor_mask(wrapped))
        }
    }

    /// Builds metadata for the `kv/api` secret, version `v1`, in the `prod`/`acme` scope.
    fn sample_meta(team: Option<&str>) -> SecretMeta {
        let scope = Scope::new(
            String::from("prod"),
            String::from("acme"),
            team.map(String::from),
        )
        .unwrap();
        let uri = SecretUri::new(scope, "kv", "api")
            .unwrap()
            .with_version(Some("v1"))
            .unwrap();
        SecretMeta::new(uri, Visibility::Team, ContentType::Opaque)
    }

    /// Builds an AES-256-GCM service with an 8-entry cache and the given TTL.
    fn aes_service(provider: DummyProvider, ttl: Duration) -> EnvelopeService<DummyProvider> {
        EnvelopeService::new(
            provider,
            DekCache::new(8, ttl),
            EncryptionAlgorithm::Aes256Gcm,
        )
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let mut service = aes_service(DummyProvider::new(), Duration::from_secs(300));
        let meta = sample_meta(Some("payments"));
        let plaintext = b"super-secret-data";

        let record = service
            .encrypt_record(meta.clone(), plaintext)
            .expect("encrypt");
        let recovered = service.decrypt_record(&record).expect("decrypt");

        assert_eq!(plaintext.to_vec(), recovered);
        assert_eq!(record.meta, meta);
    }

    #[test]
    fn tamper_detection() {
        let mut service = aes_service(DummyProvider::new(), Duration::from_secs(300));

        let mut record = service
            .encrypt_record(sample_meta(Some("payments")), b"critical")
            .expect("encrypt");
        // Flip every bit of the first ciphertext byte.
        record.value[0] ^= 0xFF;

        let err = service.decrypt_record(&record).unwrap_err();
        assert!(matches!(err, DecryptError::MacMismatch));
    }

    #[test]
    fn cache_hit_and_miss_behavior() {
        let provider = DummyProvider::new();
        let mut service = aes_service(provider.clone(), Duration::from_secs(300));
        let meta = sample_meta(Some("payments"));
        let plaintext = b"payload";

        // First encryption has nothing cached, so the provider wraps once.
        service
            .encrypt_record(meta.clone(), plaintext)
            .expect("encrypt");
        assert_eq!(provider.calls().0, 1);

        // Second encryption for the same metadata reuses the cached DEK.
        service
            .encrypt_record(meta.clone(), plaintext)
            .expect("encrypt again");
        assert_eq!(provider.calls().0, 1, "expected cache hit to avoid wrapping");

        // A different service with its own zero-TTL cache must wrap again.
        let wrap_calls_before = provider.calls().0;
        let mut service = aes_service(provider.clone(), Duration::from_secs(0));
        service
            .encrypt_record(meta, plaintext)
            .expect("encrypt with fresh cache");
        assert!(
            provider.calls().0 > wrap_calls_before,
            "expected miss to invoke wrap again"
        );
    }
}