1use crate::crypto::dek_cache::{CacheKey, DekCache, DekMaterial};
2use crate::key_provider::KeyProvider;
3use crate::spec_compat::{
4 DecryptError, DecryptResult, EncryptionAlgorithm, Envelope, Error, Result, Scope, SecretMeta,
5 SecretRecord,
6};
7use base64::{engine::general_purpose::STANDARD, Engine};
8use greentic_secrets_support::{open_aead, seal_aead};
9use hkdf::Hkdf;
10use rand::RngCore;
11use sha2::Sha256;
12use std::env;
13
14#[cfg(feature = "xchacha")]
15use chacha20poly1305::{aead::Aead, KeyInit, XChaCha20Poly1305, XNonce};
16
/// Length in bytes of freshly generated data-encryption keys (see `generate_dek`).
const DEFAULT_DEK_LEN: usize = 32;
/// Length in bytes of the random HKDF salt generated for every encrypted record.
const HKDF_SALT_LEN: usize = 32;
/// XChaCha20-Poly1305 nonce length in bytes; used to build a fixed-size `XNonce`.
#[cfg(feature = "xchacha")]
const XCHACHA_NONCE_LEN: usize = 24;
/// Environment variable read by `EnvelopeService::from_env` to select the algorithm.
const ENC_ALGO_ENV: &str = "SECRETS_ENC_ALGO";
22
23pub struct EnvelopeService<P>
25where
26 P: KeyProvider,
27{
28 provider: P,
29 cache: DekCache,
30 algorithm: EncryptionAlgorithm,
31}
32
33impl<P> EnvelopeService<P>
34where
35 P: KeyProvider,
36{
37 pub fn new(provider: P, cache: DekCache, algorithm: EncryptionAlgorithm) -> Self {
39 Self {
40 provider,
41 cache,
42 algorithm,
43 }
44 }
45
46 pub fn from_env(provider: P) -> Result<Self> {
48 let algorithm = env::var(ENC_ALGO_ENV)
49 .ok()
50 .filter(|s| !s.trim().is_empty())
51 .map(|value| value.parse())
52 .transpose()?
53 .unwrap_or_default();
54
55 Ok(Self::new(provider, DekCache::from_env(), algorithm))
56 }
57
58 pub fn algorithm(&self) -> EncryptionAlgorithm {
60 self.algorithm
61 }
62
63 pub fn cache(&self) -> &DekCache {
65 &self.cache
66 }
67
68 pub fn cache_mut(&mut self) -> &mut DekCache {
70 &mut self.cache
71 }
72
73 pub fn encrypt_record(&mut self, meta: SecretMeta, plaintext: &[u8]) -> Result<SecretRecord> {
75 let cache_key = CacheKey::from_meta(&meta);
76 let scope = meta.scope().clone();
77 let info = meta.uri.to_string();
78
79 let (dek, wrapped) = self.obtain_dek(&cache_key, &scope)?;
80
81 let salt = random_bytes(HKDF_SALT_LEN);
82 let key = derive_key(&dek, &salt, info.as_bytes())?;
83 let (nonce, ciphertext) = encrypt_with_algorithm(self.algorithm, &key, plaintext)?;
84
85 let envelope = Envelope {
86 algorithm: self.algorithm,
87 nonce,
88 hkdf_salt: salt,
89 wrapped_dek: wrapped.clone(),
90 };
91
92 Ok(SecretRecord::new(meta, ciphertext, envelope))
93 }
94
95 fn obtain_dek(&mut self, cache_key: &CacheKey, scope: &Scope) -> Result<(Vec<u8>, Vec<u8>)> {
96 if let Some(material) = self.cache.get(cache_key) {
97 return Ok((material.dek, material.wrapped));
98 }
99
100 let dek = generate_dek();
101 let wrapped = self.provider.wrap_dek(scope, &dek)?;
102 self.cache
103 .insert(cache_key.clone(), dek.clone(), wrapped.clone());
104 Ok((dek, wrapped))
105 }
106
107 pub fn decrypt_record(&mut self, record: &SecretRecord) -> DecryptResult<Vec<u8>> {
109 let cache_key = CacheKey::from_meta(&record.meta);
110 let scope = record.meta.scope();
111 let algorithm = record.envelope.algorithm;
112 let info = record.meta.uri.to_string();
113
114 let material = match self.cache.get(&cache_key) {
115 Some(material) => material,
116 None => {
117 let dek = self
118 .provider
119 .unwrap_dek(scope, &record.envelope.wrapped_dek)
120 .map_err(|err| DecryptError::Provider(err.to_string()))?;
121 let material = DekMaterial {
122 dek: dek.clone(),
123 wrapped: record.envelope.wrapped_dek.clone(),
124 };
125 self.cache.insert(
126 cache_key.clone(),
127 material.dek.clone(),
128 material.wrapped.clone(),
129 );
130 material
131 }
132 };
133
134 let key = derive_key(&material.dek, &record.envelope.hkdf_salt, info.as_bytes())
135 .map_err(|err| DecryptError::Crypto(err.to_string()))?;
136 let plaintext =
137 decrypt_with_algorithm(algorithm, &key, &record.envelope.nonce, &record.value)?;
138
139 Ok(plaintext)
140 }
141}
142
143fn encrypt_with_algorithm(
144 algorithm: EncryptionAlgorithm,
145 key: &[u8; 32],
146 plaintext: &[u8],
147) -> Result<(Vec<u8>, Vec<u8>)> {
148 match algorithm {
149 EncryptionAlgorithm::Aes256Gcm => {
150 let sealed = seal_aead(key, plaintext).map_err(|err| Error::Crypto(err.to_string()))?;
151 let data = STANDARD
152 .decode(sealed)
153 .map_err(|err| Error::Crypto(err.to_string()))?;
154 let nonce_len = EncryptionAlgorithm::Aes256Gcm.nonce_len();
155 if data.len() < nonce_len {
156 return Err(Error::Crypto("ciphertext too short".into()));
157 }
158 let (nonce, ciphertext) = data.split_at(nonce_len);
159 Ok((nonce.to_vec(), ciphertext.to_vec()))
160 }
161 EncryptionAlgorithm::XChaCha20Poly1305 => {
162 #[cfg(feature = "xchacha")]
163 {
164 let cipher = XChaCha20Poly1305::new_from_slice(key)
165 .map_err(|_| Error::Crypto("invalid XChaCha key".into()))?;
166 let nonce_bytes = random_bytes(EncryptionAlgorithm::XChaCha20Poly1305.nonce_len());
167 let nonce_array: &[u8; XCHACHA_NONCE_LEN] = nonce_bytes
168 .as_slice()
169 .try_into()
170 .map_err(|_| Error::Crypto("invalid XChaCha nonce length".into()))?;
171 let nonce = XNonce::from(*nonce_array);
172 cipher
173 .encrypt(&nonce, plaintext)
174 .map(|ciphertext| (nonce_bytes, ciphertext))
175 .map_err(|_| Error::Crypto("failed to encrypt payload".into()))
176 }
177 #[cfg(not(feature = "xchacha"))]
178 {
179 Err(Error::AlgorithmFeatureUnavailable(
180 algorithm.as_str().to_string(),
181 ))
182 }
183 }
184 }
185}
186
187fn decrypt_with_algorithm(
188 algorithm: EncryptionAlgorithm,
189 key: &[u8; 32],
190 nonce: &[u8],
191 ciphertext: &[u8],
192) -> DecryptResult<Vec<u8>> {
193 match algorithm {
194 EncryptionAlgorithm::Aes256Gcm => {
195 let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len());
196 combined.extend_from_slice(nonce);
197 combined.extend_from_slice(ciphertext);
198 let encoded = STANDARD.encode(combined);
199 match open_aead(key, &encoded) {
200 Ok(bytes) => Ok(bytes),
201 Err(Error::Backend(message)) if message == "open failed" => {
202 Err(DecryptError::MacMismatch)
203 }
204 Err(err) => Err(DecryptError::Crypto(err.to_string())),
205 }
206 }
207 EncryptionAlgorithm::XChaCha20Poly1305 => {
208 #[cfg(feature = "xchacha")]
209 {
210 let cipher = XChaCha20Poly1305::new_from_slice(key)
211 .map_err(|_| DecryptError::Crypto("invalid XChaCha key".into()))?;
212 let nonce_array: &[u8; XCHACHA_NONCE_LEN] = nonce
213 .try_into()
214 .map_err(|_| DecryptError::Crypto("invalid XChaCha nonce length".into()))?;
215 let nonce = XNonce::from(*nonce_array);
216 cipher
217 .decrypt(&nonce, ciphertext)
218 .map_err(|_| DecryptError::MacMismatch)
219 }
220 #[cfg(not(feature = "xchacha"))]
221 {
222 Err(DecryptError::Crypto(format!(
223 "algorithm {algorithm} unavailable"
224 )))
225 }
226 }
227 }
228}
229
230fn derive_key(dek: &[u8], salt: &[u8], info: &[u8]) -> Result<[u8; 32]> {
231 let hkdf = Hkdf::<Sha256>::new(Some(salt), dek);
232 let mut okm = [0u8; 32];
233 hkdf.expand(info, &mut okm)
234 .map_err(|_| Error::Crypto("failed to derive key material".into()))?;
235 Ok(okm)
236}
237
238fn generate_dek() -> Vec<u8> {
239 random_bytes(DEFAULT_DEK_LEN)
240}
241
242fn random_bytes(len: usize) -> Vec<u8> {
243 let mut buffer = vec![0u8; len];
244 let mut rng = rand::rng();
245 rng.fill_bytes(&mut buffer);
246 buffer
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::dek_cache::DekCache;
    use crate::key_provider::KeyProvider;
    use crate::spec_compat::{ContentType, Scope, SecretMeta, Visibility};
    use crate::uri::SecretUri;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Key-provider test double.
    ///
    /// It "wraps" by XOR-ing every byte with `0xAA`, so unwrapping is the same
    /// operation. It counts calls to each method; clones share the counters.
    #[derive(Clone)]
    struct DummyProvider {
        wrap_calls: Arc<AtomicUsize>,
        unwrap_calls: Arc<AtomicUsize>,
    }

    impl DummyProvider {
        fn new() -> Self {
            Self {
                wrap_calls: Arc::new(AtomicUsize::new(0)),
                unwrap_calls: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Returns `(wrap_calls, unwrap_calls)` recorded so far.
        fn calls(&self) -> (usize, usize) {
            (
                self.wrap_calls.load(Ordering::SeqCst),
                self.unwrap_calls.load(Ordering::SeqCst),
            )
        }
    }

    /// XORs every byte with a fixed mask; applying it twice is the identity.
    fn xor_mask(bytes: &[u8]) -> Vec<u8> {
        bytes.iter().map(|byte| byte ^ 0xAA).collect()
    }

    impl KeyProvider for DummyProvider {
        fn wrap_dek(&self, _scope: &Scope, dek: &[u8]) -> Result<Vec<u8>> {
            self.wrap_calls.fetch_add(1, Ordering::SeqCst);
            Ok(xor_mask(dek))
        }

        fn unwrap_dek(&self, _scope: &Scope, wrapped: &[u8]) -> Result<Vec<u8>> {
            self.unwrap_calls.fetch_add(1, Ordering::SeqCst);
            Ok(xor_mask(wrapped))
        }
    }

    /// Metadata for secret `kv/api` at version `v1` in the `prod`/`acme` scope.
    fn sample_meta(team: Option<&str>) -> SecretMeta {
        let scope = Scope::new(
            String::from("prod"),
            String::from("acme"),
            team.map(String::from),
        )
        .unwrap();
        let uri = SecretUri::new(scope, "kv", "api")
            .unwrap()
            .with_version(Some("v1"))
            .unwrap();
        SecretMeta::new(uri, Visibility::Team, ContentType::Opaque)
    }

    /// AES-256-GCM service over `provider`, with a cache built by `DekCache::new(8, ttl)`.
    fn aes_service(provider: DummyProvider, ttl: Duration) -> EnvelopeService<DummyProvider> {
        EnvelopeService::new(provider, DekCache::new(8, ttl), EncryptionAlgorithm::Aes256Gcm)
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let mut service = aes_service(DummyProvider::new(), Duration::from_secs(300));
        let meta = sample_meta(Some("payments"));
        let plaintext = b"super-secret-data";

        let record = service
            .encrypt_record(meta.clone(), plaintext)
            .expect("encrypt");
        let recovered = service.decrypt_record(&record).expect("decrypt");

        assert_eq!(plaintext.to_vec(), recovered);
        assert_eq!(record.meta, meta);
    }

    #[test]
    fn tamper_detection() {
        let mut service = aes_service(DummyProvider::new(), Duration::from_secs(300));

        let mut record = service
            .encrypt_record(sample_meta(Some("payments")), b"critical")
            .expect("encrypt");
        // Flip every bit of the first ciphertext byte.
        record.value[0] ^= 0xFF;

        let err = service.decrypt_record(&record).unwrap_err();
        assert!(matches!(err, DecryptError::MacMismatch));
    }

    #[test]
    fn cache_hit_and_miss_behavior() {
        let provider = DummyProvider::new();
        let meta = sample_meta(Some("payments"));
        let plaintext = b"payload";

        let mut service = aes_service(provider.clone(), Duration::from_secs(300));
        service
            .encrypt_record(meta.clone(), plaintext)
            .expect("encrypt");
        assert_eq!(provider.calls().0, 1);

        service
            .encrypt_record(meta.clone(), plaintext)
            .expect("encrypt again");
        assert_eq!(provider.calls().0, 1, "expected cache hit to avoid wrapping");

        // A brand-new service starts with an empty cache, so it must wrap again.
        let (wraps_before_fresh_cache, _) = provider.calls();
        let mut fresh_service = aes_service(provider.clone(), Duration::from_secs(0));
        fresh_service
            .encrypt_record(meta, plaintext)
            .expect("encrypt with fresh cache");
        assert!(
            provider.calls().0 > wraps_before_fresh_cache,
            "expected miss to invoke wrap again"
        );
    }
}