1use alloc::vec::Vec;
81
82#[cfg_attr(
83 any(feature = "aead-chacha20", feature = "aead-aes-gcm"),
84 allow(unused_imports)
85)]
86use crate::error::{Error, Result};
87
88#[cfg(feature = "aead-aes-gcm")]
89mod aes_gcm;
90#[cfg(feature = "aead-chacha20")]
91mod chacha20;
92
/// Length in bytes of the ChaCha20-Poly1305 nonce (96 bits).
pub const CHACHA20_NONCE_LEN: usize = 96 / 8;

/// Length in bytes of the ChaCha20-Poly1305 authentication tag (128 bits).
pub const CHACHA20_TAG_LEN: usize = 128 / 8;

/// Length in bytes of the AES-256-GCM nonce (96 bits).
pub const AES_GCM_NONCE_LEN: usize = 96 / 8;

/// Length in bytes of the AES-256-GCM authentication tag (128 bits).
pub const AES_GCM_TAG_LEN: usize = 128 / 8;

/// Key length in bytes shared by every supported algorithm (256 bits).
pub const KEY_LEN: usize = 256 / 8;
109
/// Authenticated-encryption (AEAD) algorithms that a `Crypt` can dispatch to.
///
/// The enum is `#[non_exhaustive]`, so more algorithms can be added later
/// without a breaking change. Code outside this crate must keep a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum Algorithm {
    /// ChaCha20-Poly1305, the default selection. Encrypting or decrypting
    /// with it requires the `aead-chacha20` feature.
    #[default]
    ChaCha20Poly1305,
    /// AES-256 in Galois/Counter Mode. Encrypting or decrypting with it
    /// requires the `aead-aes-gcm` feature.
    Aes256Gcm,
}
132
133impl Algorithm {
134 #[must_use]
136 pub const fn name(self) -> &'static str {
137 match self {
138 Self::ChaCha20Poly1305 => "ChaCha20-Poly1305",
139 Self::Aes256Gcm => "AES-256-GCM",
140 }
141 }
142
143 #[must_use]
145 pub const fn key_len(self) -> usize {
146 match self {
147 Self::ChaCha20Poly1305 | Self::Aes256Gcm => KEY_LEN,
148 }
149 }
150
151 #[must_use]
153 pub const fn nonce_len(self) -> usize {
154 match self {
155 Self::ChaCha20Poly1305 => CHACHA20_NONCE_LEN,
156 Self::Aes256Gcm => AES_GCM_NONCE_LEN,
157 }
158 }
159
160 #[must_use]
162 pub const fn tag_len(self) -> usize {
163 match self {
164 Self::ChaCha20Poly1305 => CHACHA20_TAG_LEN,
165 Self::Aes256Gcm => AES_GCM_TAG_LEN,
166 }
167 }
168}
169
170#[derive(Debug, Clone, Copy, PartialEq, Eq)]
183pub struct Crypt {
184 algorithm: Algorithm,
185}
186
187impl Crypt {
188 #[must_use]
191 pub const fn new() -> Self {
192 Self {
193 algorithm: Algorithm::ChaCha20Poly1305,
194 }
195 }
196
197 #[must_use]
199 pub const fn with_algorithm(algorithm: Algorithm) -> Self {
200 Self { algorithm }
201 }
202
203 #[cfg(feature = "aead-aes-gcm")]
212 #[must_use]
213 pub const fn aes_256_gcm() -> Self {
214 Self {
215 algorithm: Algorithm::Aes256Gcm,
216 }
217 }
218
219 #[must_use]
221 pub const fn algorithm(&self) -> Algorithm {
222 self.algorithm
223 }
224
225 pub fn encrypt(&self, key: &[u8], plaintext: &[u8]) -> Result<Vec<u8>> {
252 self.encrypt_with_aad(key, plaintext, &[])
253 }
254
255 pub fn encrypt_with_aad(&self, key: &[u8], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
269 match self.algorithm {
270 Algorithm::ChaCha20Poly1305 => {
271 #[cfg(feature = "aead-chacha20")]
272 {
273 chacha20::encrypt(key, plaintext, aad)
274 }
275 #[cfg(not(feature = "aead-chacha20"))]
276 {
277 let _ = (key, plaintext, aad);
278 Err(Error::AlgorithmNotEnabled("aead-chacha20"))
279 }
280 }
281 Algorithm::Aes256Gcm => {
282 #[cfg(feature = "aead-aes-gcm")]
283 {
284 aes_gcm::encrypt(key, plaintext, aad)
285 }
286 #[cfg(not(feature = "aead-aes-gcm"))]
287 {
288 let _ = (key, plaintext, aad);
289 Err(Error::AlgorithmNotEnabled("aead-aes-gcm"))
290 }
291 }
292 }
293 }
294
295 pub fn decrypt(&self, key: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> {
332 self.decrypt_with_aad(key, ciphertext, &[])
333 }
334
335 pub fn decrypt_with_aad(&self, key: &[u8], ciphertext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
342 match self.algorithm {
343 Algorithm::ChaCha20Poly1305 => {
344 #[cfg(feature = "aead-chacha20")]
345 {
346 chacha20::decrypt(key, ciphertext, aad)
347 }
348 #[cfg(not(feature = "aead-chacha20"))]
349 {
350 let _ = (key, ciphertext, aad);
351 Err(Error::AlgorithmNotEnabled("aead-chacha20"))
352 }
353 }
354 Algorithm::Aes256Gcm => {
355 #[cfg(feature = "aead-aes-gcm")]
356 {
357 aes_gcm::decrypt(key, ciphertext, aad)
358 }
359 #[cfg(not(feature = "aead-aes-gcm"))]
360 {
361 let _ = (key, ciphertext, aad);
362 Err(Error::AlgorithmNotEnabled("aead-aes-gcm"))
363 }
364 }
365 }
366 }
367}
368
369impl Default for Crypt {
370 fn default() -> Self {
371 Self::new()
372 }
373}
374
// Behavioral tests for the default ChaCha20-Poly1305 backend. Only compiled when
// the `aead-chacha20` feature is enabled, because every round trip needs it.
#[cfg(all(test, feature = "aead-chacha20"))]
// Panicking on an unexpected result is the intended failure mode in tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn algorithm_metadata_matches_constants() {
        let a = Algorithm::default();
        assert_eq!(a, Algorithm::ChaCha20Poly1305);
        assert_eq!(a.key_len(), KEY_LEN);
        assert_eq!(a.nonce_len(), CHACHA20_NONCE_LEN);
        assert_eq!(a.tag_len(), CHACHA20_TAG_LEN);
        assert_eq!(a.name(), "ChaCha20-Poly1305");
    }

    #[test]
    fn crypt_defaults_to_chacha20() {
        let c = Crypt::new();
        assert_eq!(c.algorithm(), Algorithm::ChaCha20Poly1305);
        let d = Crypt::default();
        assert_eq!(d.algorithm(), Algorithm::ChaCha20Poly1305);
        // `new`, `Default`, and an explicit selection must all produce the same value.
        assert_eq!(c, d);
        assert_eq!(c, Crypt::with_algorithm(Algorithm::ChaCha20Poly1305));
    }

    #[test]
    fn round_trip_empty_plaintext() {
        let crypt = Crypt::new();
        let key = [0x11u8; 32];
        let ciphertext = crypt.encrypt(&key, b"").unwrap();
        // An empty message still carries a full nonce and tag.
        assert_eq!(ciphertext.len(), CHACHA20_NONCE_LEN + CHACHA20_TAG_LEN);
        let recovered = crypt.decrypt(&key, &ciphertext).unwrap();
        assert!(recovered.is_empty());
    }

    #[test]
    fn round_trip_short_plaintext() {
        let crypt = Crypt::new();
        let key = [0x22u8; 32];
        let plaintext = b"hello, world!";
        let ciphertext = crypt.encrypt(&key, plaintext).unwrap();
        let recovered = crypt.decrypt(&key, &ciphertext).unwrap();
        assert_eq!(&*recovered, plaintext);
    }

    #[test]
    fn round_trip_one_megabyte() {
        let crypt = Crypt::new();
        let key = [0x33u8; 32];
        let plaintext = vec![0xa5u8; 1024 * 1024];
        let ciphertext = crypt.encrypt(&key, &plaintext).unwrap();
        let recovered = crypt.decrypt(&key, &ciphertext).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn two_encryptions_of_same_plaintext_differ() {
        let crypt = Crypt::new();
        let key = [0u8; 32];
        let plaintext = b"deterministic? no.";
        let a = crypt.encrypt(&key, plaintext).unwrap();
        let b = crypt.encrypt(&key, plaintext).unwrap();
        assert_ne!(a, b, "nonce-prepended outputs must differ across calls");
    }

    #[test]
    fn wrong_key_fails_authentication() {
        let crypt = Crypt::new();
        let key = [0x44u8; 32];
        let wrong = [0x55u8; 32];
        let ciphertext = crypt.encrypt(&key, b"secret").unwrap();
        let err = crypt.decrypt(&wrong, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn tampered_ciphertext_fails_authentication() {
        let crypt = Crypt::new();
        let key = [0x66u8; 32];
        let mut ciphertext = crypt.encrypt(&key, b"hands off").unwrap();
        // Flip one bit in the middle of the sealed output.
        let i = ciphertext.len() / 2;
        ciphertext[i] ^= 0x01;
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn tampered_tag_fails_authentication() {
        let crypt = Crypt::new();
        let key = [0x77u8; 32];
        let mut ciphertext = crypt.encrypt(&key, b"sign me").unwrap();
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 0xff;
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn tampered_nonce_fails_authentication() {
        let crypt = Crypt::new();
        let key = [0xbbu8; 32];
        let mut ciphertext = crypt.encrypt(&key, b"nonce bound").unwrap();
        // The nonce is prepended, so byte 0 belongs to it. A different nonce
        // must not authenticate.
        ciphertext[0] ^= 0x01;
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn truncated_ciphertext_is_rejected() {
        let crypt = Crypt::new();
        let key = [0u8; 32];
        // No input shorter than nonce + tag can be a valid message.
        for len in 0..(CHACHA20_NONCE_LEN + CHACHA20_TAG_LEN) {
            let err = crypt.decrypt(&key, &vec![0u8; len]).unwrap_err();
            assert!(
                matches!(err, Error::InvalidCiphertext(_)),
                "len={len} should error"
            );
        }
    }

    #[test]
    fn aad_round_trip() {
        let crypt = Crypt::new();
        let key = [0x88u8; 32];
        let plaintext = b"plaintext";
        let aad = b"associated";
        let ciphertext = crypt.encrypt_with_aad(&key, plaintext, aad).unwrap();
        let recovered = crypt.decrypt_with_aad(&key, &ciphertext, aad).unwrap();
        assert_eq!(&*recovered, plaintext);
    }

    #[test]
    fn aad_mismatch_fails_authentication() {
        let crypt = Crypt::new();
        let key = [0x99u8; 32];
        let ciphertext = crypt
            .encrypt_with_aad(&key, b"body", b"original-aad")
            .unwrap();
        let err = crypt
            .decrypt_with_aad(&key, &ciphertext, b"tampered-aad")
            .unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn encrypt_with_aad_then_decrypt_without_aad_fails() {
        let crypt = Crypt::new();
        let key = [0xaau8; 32];
        let ciphertext = crypt.encrypt_with_aad(&key, b"body", b"required").unwrap();
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn invalid_key_length_rejected_on_encrypt() {
        let crypt = Crypt::new();
        let err = crypt.encrypt(&[0u8; 16], b"x").unwrap_err();
        assert_eq!(
            err,
            Error::InvalidKey {
                expected: 32,
                actual: 16
            }
        );
    }

    #[test]
    fn invalid_key_length_rejected_on_decrypt() {
        let crypt = Crypt::new();
        // Use a well-formed ciphertext so the key length is the only thing
        // wrong with the call.
        let ciphertext = crypt.encrypt(&[0u8; 32], b"x").unwrap();
        let err = crypt.decrypt(&[0u8; 16], &ciphertext).unwrap_err();
        assert_eq!(
            err,
            Error::InvalidKey {
                expected: 32,
                actual: 16
            }
        );
    }
}
549
// Behavioral tests for the AES-256-GCM backend. They mirror the
// ChaCha20-Poly1305 suite so that both algorithms are held to the same contract.
#[cfg(all(test, feature = "aead-aes-gcm"))]
// Panicking on an unexpected result is the intended failure mode in tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod aes_gcm_tests {
    use super::*;
    use alloc::vec;

    /// Shorthand for an AES-256-GCM `Crypt`.
    fn aes() -> Crypt {
        Crypt::aes_256_gcm()
    }

    #[test]
    fn algorithm_metadata_matches_constants() {
        let a = Algorithm::Aes256Gcm;
        assert_eq!(a.key_len(), KEY_LEN);
        assert_eq!(a.nonce_len(), AES_GCM_NONCE_LEN);
        assert_eq!(a.tag_len(), AES_GCM_TAG_LEN);
        assert_eq!(a.name(), "AES-256-GCM");
    }

    #[test]
    fn aes_256_gcm_constructor_selects_algorithm() {
        let c = aes();
        assert_eq!(c.algorithm(), Algorithm::Aes256Gcm);
        let alt = Crypt::with_algorithm(Algorithm::Aes256Gcm);
        assert_eq!(c, alt);
    }

    #[test]
    fn round_trip_empty_plaintext() {
        let crypt = aes();
        let key = [0x11u8; 32];
        let ciphertext = crypt.encrypt(&key, b"").unwrap();
        // An empty message still carries a full nonce and tag.
        assert_eq!(ciphertext.len(), AES_GCM_NONCE_LEN + AES_GCM_TAG_LEN);
        let recovered = crypt.decrypt(&key, &ciphertext).unwrap();
        assert!(recovered.is_empty());
    }

    #[test]
    fn round_trip_short_plaintext() {
        let crypt = aes();
        let key = [0x22u8; 32];
        let plaintext = b"hello, world!";
        let ciphertext = crypt.encrypt(&key, plaintext).unwrap();
        let recovered = crypt.decrypt(&key, &ciphertext).unwrap();
        assert_eq!(&*recovered, plaintext);
    }

    #[test]
    fn round_trip_one_megabyte() {
        let crypt = aes();
        let key = [0x33u8; 32];
        let plaintext = vec![0xa5u8; 1024 * 1024];
        let ciphertext = crypt.encrypt(&key, &plaintext).unwrap();
        let recovered = crypt.decrypt(&key, &ciphertext).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn two_encryptions_of_same_plaintext_differ() {
        let crypt = aes();
        let key = [0u8; 32];
        let plaintext = b"deterministic? no.";
        let a = crypt.encrypt(&key, plaintext).unwrap();
        let b = crypt.encrypt(&key, plaintext).unwrap();
        assert_ne!(a, b, "nonce-prepended outputs must differ across calls");
    }

    #[test]
    fn wrong_key_fails_authentication() {
        let crypt = aes();
        let key = [0x44u8; 32];
        let wrong = [0x55u8; 32];
        let ciphertext = crypt.encrypt(&key, b"secret").unwrap();
        let err = crypt.decrypt(&wrong, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn tampered_ciphertext_fails_authentication() {
        let crypt = aes();
        let key = [0x66u8; 32];
        let mut ciphertext = crypt.encrypt(&key, b"hands off").unwrap();
        // Flip one bit in the middle of the sealed output.
        let i = ciphertext.len() / 2;
        ciphertext[i] ^= 0x01;
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn tampered_tag_fails_authentication() {
        let crypt = aes();
        let key = [0x77u8; 32];
        let mut ciphertext = crypt.encrypt(&key, b"sign me").unwrap();
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 0xff;
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn tampered_nonce_fails_authentication() {
        let crypt = aes();
        let key = [0xbbu8; 32];
        let mut ciphertext = crypt.encrypt(&key, b"nonce bound").unwrap();
        // The nonce is prepended, so byte 0 belongs to it. A different nonce
        // must not authenticate.
        ciphertext[0] ^= 0x01;
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn truncated_ciphertext_is_rejected() {
        let crypt = aes();
        let key = [0u8; 32];
        // No input shorter than nonce + tag can be a valid message.
        for len in 0..(AES_GCM_NONCE_LEN + AES_GCM_TAG_LEN) {
            let err = crypt.decrypt(&key, &vec![0u8; len]).unwrap_err();
            assert!(
                matches!(err, Error::InvalidCiphertext(_)),
                "len={len} should error"
            );
        }
    }

    #[test]
    fn aad_round_trip() {
        let crypt = aes();
        let key = [0x88u8; 32];
        let plaintext = b"plaintext";
        let aad = b"associated";
        let ciphertext = crypt.encrypt_with_aad(&key, plaintext, aad).unwrap();
        let recovered = crypt.decrypt_with_aad(&key, &ciphertext, aad).unwrap();
        assert_eq!(&*recovered, plaintext);
    }

    #[test]
    fn aad_mismatch_fails_authentication() {
        let crypt = aes();
        let key = [0x99u8; 32];
        let ciphertext = crypt
            .encrypt_with_aad(&key, b"body", b"original-aad")
            .unwrap();
        let err = crypt
            .decrypt_with_aad(&key, &ciphertext, b"tampered-aad")
            .unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn encrypt_with_aad_then_decrypt_without_aad_fails() {
        let crypt = aes();
        let key = [0xaau8; 32];
        let ciphertext = crypt.encrypt_with_aad(&key, b"body", b"required").unwrap();
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn invalid_key_length_rejected_on_encrypt() {
        let crypt = aes();
        let err = crypt.encrypt(&[0u8; 16], b"x").unwrap_err();
        assert_eq!(
            err,
            Error::InvalidKey {
                expected: 32,
                actual: 16
            }
        );
    }

    #[test]
    fn invalid_key_length_rejected_on_decrypt() {
        let crypt = aes();
        // Use a well-formed ciphertext so the key length is the only thing
        // wrong with the call.
        let ciphertext = crypt.encrypt(&[0u8; 32], b"x").unwrap();
        let err = crypt.decrypt(&[0u8; 16], &ciphertext).unwrap_err();
        assert_eq!(
            err,
            Error::InvalidKey {
                expected: 32,
                actual: 16
            }
        );
    }
}
702
// Tests that need both backends compiled in. Ciphertexts must not be
// interchangeable between algorithms, even though both use 12-byte nonces and
// 16-byte tags.
#[cfg(all(test, feature = "aead-chacha20", feature = "aead-aes-gcm"))]
// Panicking on an unexpected result is the intended failure mode in tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod cross_algorithm_tests {
    use super::*;

    #[test]
    fn chacha_ciphertext_does_not_decrypt_as_aes() {
        let key = [0xcdu8; 32];
        let sealed = Crypt::new().encrypt(&key, b"message").unwrap();
        let opener = Crypt::aes_256_gcm();
        assert_eq!(
            opener.decrypt(&key, &sealed).unwrap_err(),
            Error::AuthenticationFailed
        );
    }

    #[test]
    fn aes_ciphertext_does_not_decrypt_as_chacha() {
        let key = [0xefu8; 32];
        let sealed = Crypt::aes_256_gcm().encrypt(&key, b"message").unwrap();
        let opener = Crypt::new();
        assert_eq!(
            opener.decrypt(&key, &sealed).unwrap_err(),
            Error::AuthenticationFailed
        );
    }

    #[test]
    fn algorithm_name_table_is_unique() {
        let names = [
            Algorithm::ChaCha20Poly1305.name(),
            Algorithm::Aes256Gcm.name(),
        ];
        // Compare every unordered pair exactly once.
        for (idx, first) in names.iter().enumerate() {
            for second in &names[idx + 1..] {
                assert_ne!(first, second, "algorithm names must be distinct");
            }
        }
    }
}