1pub trait EncryptionProvider:
33 Send + Sync + std::panic::UnwindSafe + std::panic::RefUnwindSafe
34{
35 fn encrypt(&self, plaintext: &[u8]) -> crate::Result<Vec<u8>>;
44
45 fn max_overhead(&self) -> u32;
52
53 fn decrypt(&self, ciphertext: &[u8]) -> crate::Result<Vec<u8>>;
60
61 fn encrypt_vec(&self, plaintext: Vec<u8>) -> crate::Result<Vec<u8>> {
71 self.encrypt(&plaintext)
72 }
73
74 fn decrypt_vec(&self, ciphertext: Vec<u8>) -> crate::Result<Vec<u8>> {
85 self.decrypt(&ciphertext)
86 }
87}
88
/// AES-256-GCM implementation of [`EncryptionProvider`].
///
/// Ciphertexts produced by this provider are framed as
/// `nonce || encrypted payload || authentication tag`, using a random nonce
/// per message and no associated data.
#[cfg(feature = "encryption")]
pub struct Aes256GcmProvider {
    // Keyed AES-256-GCM cipher, built once in `Aes256GcmProvider::new`.
    cipher: aes_gcm::Aes256Gcm,
}
112
#[cfg(feature = "encryption")]
impl Aes256GcmProvider {
    /// Size in bytes of the random nonce prepended to every ciphertext
    /// (AES-GCM's 96-bit nonce).
    const NONCE_LEN: usize = 12;

    /// Size in bytes of the GCM authentication tag appended to every ciphertext.
    const TAG_LEN: usize = 16;

    /// Total number of bytes `encrypt` adds to a plaintext: nonce plus tag.
    pub const OVERHEAD: usize = Self::NONCE_LEN + Self::TAG_LEN;

    /// Builds a provider from a 32-byte AES-256 key.
    #[must_use]
    pub fn new(key: &[u8; 32]) -> Self {
        use aes_gcm::KeyInit;

        let cipher = <aes_gcm::Aes256Gcm as KeyInit>::new(key.into());
        Self { cipher }
    }

    /// Builds a provider from a key slice, validating its length first.
    ///
    /// # Errors
    /// Returns `crate::Error::Encrypt` when `key` is not exactly 32 bytes long.
    pub fn from_slice(key: &[u8]) -> crate::Result<Self> {
        <&[u8; 32]>::try_from(key)
            .map(Self::new)
            .map_err(|_| crate::Error::Encrypt("AES-256-GCM key must be exactly 32 bytes"))
    }
}
150
/// Creates a fresh ChaCha20 CSPRNG seeded from the operating-system RNG.
///
/// # Panics
/// Panics if `OsRng` cannot supply seed material; continuing without a
/// properly seeded generator is deliberately not an option.
#[cfg(feature = "encryption")]
fn new_chacha_rng() -> rand_chacha::ChaCha20Rng {
    use aes_gcm::aead::rand_core::{OsRng, SeedableRng};

    let seeded = <rand_chacha::ChaCha20Rng as SeedableRng>::from_rng(OsRng);

    #[expect(
        clippy::expect_used,
        reason = "intentionally panics if OsRng is unavailable"
    )]
    seeded.expect("OS RNG should be available for initial CSPRNG seed")
}
169
/// ChaCha20 CSPRNG state tagged with the process ID it was seeded in.
///
/// `ForkAwareRng::with_rng` compares the stored PID with the current one and
/// replaces the generator when they differ, presumably so a child process
/// that inherited this state across `fork` does not replay the parent's
/// random stream.
#[cfg(feature = "encryption")]
struct ForkAwareRng {
    // Process ID recorded at construction and updated on reseed.
    pid: std::cell::Cell<u32>,
    // The generator itself; `RefCell` gives `with_rng` mutable access via `&self`.
    rng: std::cell::RefCell<rand_chacha::ChaCha20Rng>,
}
180
#[cfg(feature = "encryption")]
impl ForkAwareRng {
    /// Creates generator state seeded from the OS and tagged with the current PID.
    ///
    /// # Panics
    /// Panics if the OS RNG is unavailable (see `new_chacha_rng`).
    fn new() -> Self {
        Self {
            pid: std::cell::Cell::new(std::process::id()),
            rng: std::cell::RefCell::new(new_chacha_rng()),
        }
    }

    /// Runs `f` with exclusive access to the generator, reseeding it first if
    /// the current process ID differs from the one recorded at the last seed.
    ///
    /// # Panics
    /// - Panics if `f` re-entrantly calls `with_rng` on the same state (the
    ///   `RefCell` is already mutably borrowed).
    /// - Panics if a reseed is needed and the OS RNG is unavailable.
    fn with_rng<R>(&self, f: impl FnOnce(&mut rand_chacha::ChaCha20Rng) -> R) -> R {
        let mut rng_ref = self.rng.borrow_mut();
        let current_pid = std::process::id();
        if self.pid.get() != current_pid {
            // Reseed BEFORE recording the new PID. If `new_chacha_rng` panics
            // and the panic is caught (providers are required to be
            // `UnwindSafe`), the stored PID must still mismatch so the next
            // call retries the reseed. Otherwise the child would silently keep
            // using the stream inherited from the parent and could repeat
            // nonces the parent also generates.
            *rng_ref = new_chacha_rng();
            self.pid.set(current_pid);
        }

        // NOTE(review): a PID comparison cannot detect a descendant process
        // that happens to be assigned the same PID as the one recorded here.
        f(&mut rng_ref)
    }
}
206
// Per-thread generator used for AES-GCM nonces; each thread builds its own
// `ForkAwareRng` on first access, so nonce generation needs no locking.
#[cfg(feature = "encryption")]
thread_local! {
    static THREAD_RNG: ForkAwareRng = ForkAwareRng::new();
}
213
/// Runs `f` with the calling thread's fork-aware ChaCha20 generator.
///
/// # Panics
/// Panics if used while the thread-local is being destroyed (as
/// `LocalKey::with` does), or under the conditions documented on
/// `ForkAwareRng::with_rng`.
#[cfg(feature = "encryption")]
fn thread_local_rng<R>(f: impl FnOnce(&mut rand_chacha::ChaCha20Rng) -> R) -> R {
    THREAD_RNG.with(move |fork_aware| ForkAwareRng::with_rng(fork_aware, f))
}
226
/// AES-256-GCM [`EncryptionProvider`] implementation.
///
/// Frame layout for every ciphertext: `nonce (NONCE_LEN) || payload || tag (TAG_LEN)`,
/// encrypted with empty associated data. The `*_vec` overrides produce and
/// consume exactly the same frames as `encrypt` / `decrypt`, but work inside
/// the caller's allocation.
#[cfg(feature = "encryption")]
impl EncryptionProvider for Aes256GcmProvider {
    /// Returns `Aes256GcmProvider::OVERHEAD` (nonce + tag) as a `u32`.
    fn max_overhead(&self) -> u32 {
        #[expect(clippy::cast_possible_truncation, reason = "OVERHEAD is 28")]
        {
            Self::OVERHEAD as u32
        }
    }

    /// Encrypts `plaintext` under a fresh random nonce.
    ///
    /// # Errors
    /// Returns `crate::Error::Encrypt` if the AEAD encryption call fails.
    fn encrypt(&self, plaintext: &[u8]) -> crate::Result<Vec<u8>> {
        use aes_gcm::AeadCore;
        use aes_gcm::AeadInPlace;

        // Nonce drawn from the per-thread, fork-aware CSPRNG.
        let nonce = thread_local_rng(|rng| aes_gcm::Aes256Gcm::generate_nonce(rng));

        // Allocate the final frame size up front so appending the tag never reallocates.
        let mut buf = Vec::with_capacity(Self::NONCE_LEN + plaintext.len() + Self::TAG_LEN);
        buf.extend_from_slice(&nonce);
        buf.extend_from_slice(plaintext);

        // Encrypt only the payload region (after the nonce) in place; the tag is
        // returned detached and appended below.
        #[expect(
            clippy::indexing_slicing,
            reason = "buf length = NONCE_LEN + plaintext.len()"
        )]
        let tag = self
            .cipher
            .encrypt_in_place_detached(&nonce, b"", &mut buf[Self::NONCE_LEN..])
            .map_err(|_| crate::Error::Encrypt("AES-256-GCM encryption failed"))?;

        buf.extend_from_slice(&tag);

        Ok(buf)
    }

    /// Authenticates and decrypts a `nonce || payload || tag` frame.
    ///
    /// # Errors
    /// Returns `crate::Error::Decrypt` if the frame is shorter than nonce + tag,
    /// or if authentication fails (wrong key or modified data).
    fn decrypt(&self, ciphertext: &[u8]) -> crate::Result<Vec<u8>> {
        use aes_gcm::aead::generic_array::GenericArray;
        use aes_gcm::AeadInPlace;

        // Reject frames that cannot even hold a nonce and a tag.
        let min_len = Self::NONCE_LEN + Self::TAG_LEN;
        if ciphertext.len() < min_len {
            return Err(crate::Error::Decrypt(
                "ciphertext too short for AES-256-GCM (need nonce + tag)",
            ));
        }

        #[expect(clippy::indexing_slicing, reason = "length checked above")]
        let nonce = GenericArray::from_slice(&ciphertext[..Self::NONCE_LEN]);

        let tag_start = ciphertext.len() - Self::TAG_LEN;

        #[expect(clippy::indexing_slicing, reason = "length checked above")]
        let tag = GenericArray::from_slice(&ciphertext[tag_start..]);

        // Copy only the encrypted payload; it is decrypted in place in this new buffer.
        #[expect(clippy::indexing_slicing, reason = "length checked above")]
        let mut buf = ciphertext[Self::NONCE_LEN..tag_start].to_vec();

        self.cipher
            .decrypt_in_place_detached(nonce, b"", &mut buf, tag)
            .map_err(|_| {
                crate::Error::Decrypt("AES-256-GCM decryption failed (bad key or tampered data)")
            })?;

        Ok(buf)
    }

    /// Same output as `encrypt`, but reuses the caller's `buf` as the frame.
    ///
    /// # Errors
    /// Returns `crate::Error::Encrypt` if the AEAD encryption call fails.
    fn encrypt_vec(&self, mut buf: Vec<u8>) -> crate::Result<Vec<u8>> {
        use aes_gcm::AeadCore;
        use aes_gcm::AeadInPlace;

        let nonce = thread_local_rng(|rng| aes_gcm::Aes256Gcm::generate_nonce(rng));

        let plaintext_len = buf.len();
        // Grow once for nonce + tag, then shift the plaintext right by NONCE_LEN
        // to open the nonce slot at the front (order of these steps matters).
        buf.reserve(Self::NONCE_LEN + Self::TAG_LEN);
        buf.resize(plaintext_len + Self::NONCE_LEN, 0);
        buf.copy_within(..plaintext_len, Self::NONCE_LEN);
        #[expect(
            clippy::indexing_slicing,
            reason = "buf was just resized to include NONCE_LEN"
        )]
        buf[..Self::NONCE_LEN].copy_from_slice(&nonce);

        #[expect(
            clippy::indexing_slicing,
            reason = "buf length ≥ NONCE_LEN after resize + copy_within"
        )]
        let tag = self
            .cipher
            .encrypt_in_place_detached(&nonce, b"", &mut buf[Self::NONCE_LEN..])
            .map_err(|_| crate::Error::Encrypt("AES-256-GCM encryption failed"))?;

        // Fits in the capacity reserved above.
        buf.extend_from_slice(&tag);

        Ok(buf)
    }

    /// Same result as `decrypt`, but decrypts inside the caller's `buf`.
    ///
    /// # Errors
    /// Returns `crate::Error::Decrypt` if the frame is shorter than nonce + tag,
    /// or if authentication fails (wrong key or modified data).
    fn decrypt_vec(&self, mut buf: Vec<u8>) -> crate::Result<Vec<u8>> {
        use aes_gcm::aead::generic_array::GenericArray;
        use aes_gcm::AeadInPlace;

        let min_len = Self::NONCE_LEN + Self::TAG_LEN;
        if buf.len() < min_len {
            return Err(crate::Error::Decrypt(
                "ciphertext too short for AES-256-GCM (need nonce + tag)",
            ));
        }

        // Copy nonce and tag out by value (deref) before the buffer is rearranged below.
        #[expect(clippy::indexing_slicing, reason = "length checked above")]
        let nonce = *GenericArray::from_slice(&buf[..Self::NONCE_LEN]);

        let tag_start = buf.len() - Self::TAG_LEN;
        #[expect(clippy::indexing_slicing, reason = "length checked above")]
        let tag = *GenericArray::from_slice(&buf[tag_start..]);

        // Move the payload to the front and drop the trailing bytes, leaving only
        // the still-encrypted payload in `buf`.
        buf.copy_within(Self::NONCE_LEN..tag_start, 0);
        buf.truncate(tag_start - Self::NONCE_LEN);

        self.cipher
            .decrypt_in_place_detached(&nonce, b"", &mut buf, &tag)
            .map_err(|_| {
                crate::Error::Decrypt("AES-256-GCM decryption failed (bad key or tampered data)")
            })?;

        Ok(buf)
    }
}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encryption_provider_trait_is_object_safe() {
        fn _assert_object_safe(_: &dyn EncryptionProvider) {}
    }

    /// Trivial, insecure provider used only to exercise the trait's default
    /// `encrypt_vec` / `decrypt_vec` implementations.
    struct XorProvider;

    impl std::panic::UnwindSafe for XorProvider {}
    impl std::panic::RefUnwindSafe for XorProvider {}

    impl EncryptionProvider for XorProvider {
        fn encrypt(&self, plaintext: &[u8]) -> crate::Result<Vec<u8>> {
            Ok(plaintext.iter().map(|b| b ^ 0xAA).collect())
        }

        fn max_overhead(&self) -> u32 {
            0
        }

        fn decrypt(&self, ciphertext: &[u8]) -> crate::Result<Vec<u8>> {
            Ok(ciphertext.iter().map(|b| b ^ 0xAA).collect())
        }
    }

    #[test]
    fn default_encrypt_vec_delegates_to_encrypt() -> crate::Result<()> {
        let provider = XorProvider;
        let plaintext = b"test default encrypt_vec";

        let via_encrypt = provider.encrypt(plaintext)?;
        let via_encrypt_vec = provider.encrypt_vec(plaintext.to_vec())?;
        assert_eq!(via_encrypt, via_encrypt_vec);

        let decrypted = provider.decrypt(&via_encrypt_vec)?;
        assert_eq!(decrypted, plaintext);
        Ok(())
    }

    #[test]
    fn default_decrypt_vec_delegates_to_decrypt() -> crate::Result<()> {
        let provider = XorProvider;
        let plaintext = b"test default decrypt_vec";

        let ciphertext = provider.encrypt(plaintext)?;

        let via_decrypt = provider.decrypt(&ciphertext)?;
        let via_decrypt_vec = provider.decrypt_vec(ciphertext.clone())?;
        assert_eq!(via_decrypt, via_decrypt_vec);
        assert_eq!(via_decrypt_vec, plaintext);
        Ok(())
    }

    #[cfg(feature = "encryption")]
    mod aes256gcm {
        use super::*;

        fn test_key() -> [u8; 32] {
            [0x42; 32]
        }

        #[test]
        fn roundtrip_basic() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"hello world, this is a block of data!";

            let ciphertext = provider.encrypt(plaintext)?;
            assert_ne!(&ciphertext[..], plaintext.as_slice());
            assert_eq!(
                ciphertext.len(),
                Aes256GcmProvider::NONCE_LEN + plaintext.len() + Aes256GcmProvider::TAG_LEN,
            );

            let decrypted = provider.decrypt(&ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn roundtrip_empty() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"";

            let ciphertext = provider.encrypt(plaintext)?;
            let decrypted = provider.decrypt(&ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn max_overhead_matches_ciphertext_growth() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"overhead probe";

            let ciphertext = provider.encrypt(plaintext)?;
            let growth = ciphertext.len() - plaintext.len();
            assert_eq!(growth, Aes256GcmProvider::OVERHEAD);
            assert_eq!(u32::try_from(growth).ok(), Some(provider.max_overhead()));
            Ok(())
        }

        #[test]
        fn different_nonces_produce_different_ciphertexts() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"deterministic input";

            let ct1 = provider.encrypt(plaintext)?;
            let ct2 = provider.encrypt(plaintext)?;
            assert_ne!(
                ct1, ct2,
                "random nonces should produce different ciphertexts"
            );

            assert_eq!(provider.decrypt(&ct1)?, provider.decrypt(&ct2)?);
            Ok(())
        }

        #[test]
        fn wrong_key_fails_decrypt() -> crate::Result<()> {
            let provider1 = Aes256GcmProvider::new(&[0x01; 32]);
            let provider2 = Aes256GcmProvider::new(&[0x02; 32]);

            let ciphertext = provider1.encrypt(b"secret")?;
            let result = provider2.decrypt(&ciphertext);
            assert!(result.is_err());
            Ok(())
        }

        #[test]
        fn tampered_ciphertext_fails_decrypt() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let mut ciphertext = provider.encrypt(b"data")?;

            // Flip a byte inside the encrypted payload (just past the nonce).
            let mid = Aes256GcmProvider::NONCE_LEN + 1;
            if mid < ciphertext.len() {
                #[expect(clippy::indexing_slicing)]
                {
                    ciphertext[mid] ^= 0xFF;
                }
            }

            let result = provider.decrypt(&ciphertext);
            assert!(result.is_err());
            Ok(())
        }

        #[test]
        fn truncated_ciphertext_fails_decrypt() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let result = provider.decrypt(&[0u8; 10]);
            assert!(result.is_err());
            Ok(())
        }

        #[test]
        fn minimum_length_forged_frame_fails_decrypt() {
            // Exactly nonce + tag bytes: passes the length check (empty payload)
            // but must still fail authentication.
            let provider = Aes256GcmProvider::new(&test_key());
            let forged = vec![0u8; Aes256GcmProvider::OVERHEAD];

            assert!(provider.decrypt(&forged).is_err());
            assert!(provider.decrypt_vec(forged).is_err());
        }

        #[test]
        fn from_slice_rejects_wrong_length() {
            assert!(Aes256GcmProvider::from_slice(&[0u8; 16]).is_err());
            assert!(Aes256GcmProvider::from_slice(&[0u8; 31]).is_err());
            assert!(Aes256GcmProvider::from_slice(&[0u8; 33]).is_err());
            assert!(Aes256GcmProvider::from_slice(&[0u8; 32]).is_ok());
        }

        #[test]
        fn roundtrip_large_payload() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = vec![0xAB_u8; 64 * 1024];

            let ciphertext = provider.encrypt(&plaintext)?;
            let decrypted = provider.decrypt(&ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn thread_local_rng_produces_unique_nonces() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"nonce uniqueness test";

            let mut nonces = std::collections::HashSet::new();
            for _ in 0..1000 {
                let ct = provider.encrypt(plaintext)?;

                #[expect(clippy::indexing_slicing, reason = "ct always >= NONCE_LEN")]
                #[expect(clippy::expect_used, reason = "test assertion")]
                let nonce: [u8; Aes256GcmProvider::NONCE_LEN] = ct[..Aes256GcmProvider::NONCE_LEN]
                    .try_into()
                    .expect("nonce has expected length");

                assert!(
                    nonces.insert(nonce),
                    "nonce collision detected — CSPRNG produced duplicate nonce"
                );
            }
            Ok(())
        }

        #[test]
        fn fork_aware_rng_keeps_stream_when_pid_unchanged() {
            use aes_gcm::aead::rand_core::RngCore;

            let rng = ForkAwareRng::new();
            // Snapshot the generator; with an unchanged PID, `with_rng` must
            // continue this exact stream rather than reseeding.
            let mut expected = rng.rng.borrow().clone();

            let got = rng.with_rng(|r| r.next_u64());
            assert_eq!(got, expected.next_u64());
        }

        #[test]
        fn fork_aware_rng_reseeds_on_pid_change() {
            use aes_gcm::aead::rand_core::RngCore;

            let rng = ForkAwareRng::new();

            let _ = rng.with_rng(|r| r.next_u64());

            // Snapshot the pre-"fork" stream so we can tell whether the next
            // call continues it (no reseed) or starts a fresh one.
            let mut stale = rng.rng.borrow().clone();

            // Simulate running in a different process than the one that seeded the RNG.
            let current_pid = std::process::id();
            let fake_pid = current_pid ^ 1;
            rng.pid.set(fake_pid);
            assert_eq!(rng.pid.get(), fake_pid, "PID should be set to fake value");

            let after = rng.with_rng(|r| r.next_u64());

            assert_eq!(
                rng.pid.get(),
                std::process::id(),
                "PID should be restored to real process ID after reseed"
            );
            assert_ne!(
                after,
                stale.next_u64(),
                "generator should have been reseeded instead of continuing the inherited stream"
            );
        }

        #[test]
        fn encrypt_vec_roundtrip() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"block data for encrypt_vec test";

            let ciphertext = provider.encrypt_vec(plaintext.to_vec())?;
            assert_eq!(
                ciphertext.len(),
                Aes256GcmProvider::NONCE_LEN + plaintext.len() + Aes256GcmProvider::TAG_LEN,
            );

            // Frames from `encrypt_vec` must be readable by the slice-based `decrypt`.
            let decrypted = provider.decrypt(&ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn decrypt_vec_roundtrip() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"block data for decrypt_vec test";

            // Frames from the slice-based `encrypt` must be readable by `decrypt_vec`.
            let ciphertext = provider.encrypt(plaintext)?;
            let decrypted = provider.decrypt_vec(ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn encrypt_vec_decrypt_vec_roundtrip() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = vec![0xCD_u8; 16 * 1024];

            let ciphertext = provider.encrypt_vec(plaintext.clone())?;
            let decrypted = provider.decrypt_vec(ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn encrypt_vec_empty() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());

            let ciphertext = provider.encrypt_vec(vec![])?;
            let decrypted = provider.decrypt_vec(ciphertext)?;
            assert!(decrypted.is_empty());
            Ok(())
        }

        #[test]
        fn decrypt_vec_truncated_fails() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let result = provider.decrypt_vec(vec![0u8; 10]);
            assert!(result.is_err());
            Ok(())
        }

        #[test]
        fn decrypt_vec_tampered_fails() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let mut ciphertext = provider.encrypt_vec(b"data".to_vec())?;

            let mid = Aes256GcmProvider::NONCE_LEN + 1;
            if mid < ciphertext.len() {
                #[expect(clippy::indexing_slicing)]
                {
                    ciphertext[mid] ^= 0xFF;
                }
            }

            let result = provider.decrypt_vec(ciphertext);
            assert!(result.is_err());
            Ok(())
        }
    }
}