1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
//! `aead` is a module that maintains the primary object for `AEAD` (Authenticated Encryption with Associated Data).
//!
//! This module provides an abstraction for `AEAD`, built on top of the `aead` crate from the `RustCrypto/traits` repository.
use rst_common::with_cryptography::chacha20poly1305::{
    aead::{Aead, AeadCore, KeyInit},
    XChaCha20Poly1305,
};

use rst_common::with_cryptography::rand::{rngs::adapter::ReseedingRng, SeedableRng};
use rst_common::with_cryptography::rand_chacha::{rand_core::OsRng as RandCoreOsRng, ChaCha20Core};

// `key` defines the `Key` type consumed by `AEAD::encrypt` and `AEAD::decrypt`.
// The module itself stays private; `Key` is re-exported so callers can name it
// directly from this module.
mod key;
pub use key::Key;

/// Error types produced by the `AEAD` operations of this module.
pub mod errors {
    use rst_common::with_errors::thiserror::{self, Error};

    /// `AeadError` is the error type returned by the `AEAD` cipher operations
    /// (`AEAD::encrypt` and `AEAD::decrypt`).
    #[derive(Debug, Error)]
    pub enum AeadError {
        /// The underlying cipher reported a failure while encrypting or decrypting.
        ///
        /// The wrapped `String` is the stringified error coming from the cipher.
        /// Note that the rendered message always starts with
        /// "aead: unable to parse bytes", regardless of which operation failed.
        #[error("aead: unable to parse bytes: `{0}`")]
        CipherGeneratorError(String),
    }
}

/// `AEAD` is the main entrypoint to encrypt and decrypt the given data (in bytes), and
/// also to generate a nonce (in bytes).
///
/// It is a unit struct that carries no state: every operation is exposed as an
/// associated function and the key material is supplied through a `Key` value.
pub struct AEAD;

impl AEAD {
    pub fn nonce() -> Vec<u8> {
        let prng = ChaCha20Core::from_entropy();
        let reseeding_rng = ReseedingRng::new(prng, 0, RandCoreOsRng);
        let nonce = XChaCha20Poly1305::generate_nonce(reseeding_rng);
        nonce.to_vec()
    }

    pub fn encrypt(key: &Key, message: &Vec<u8>) -> Result<Vec<u8>, errors::AeadError> {
        let cipher = XChaCha20Poly1305::new(&key.get_key().into());
        cipher
            .encrypt(&key.get_nonce().into(), message.as_slice())
            .map_err(|err| errors::AeadError::CipherGeneratorError(err.to_string()))
    }

    pub fn decrypt(key: &Key, encrypted: &Vec<u8>) -> Result<Vec<u8>, errors::AeadError> {
        let cipher = XChaCha20Poly1305::new(&key.get_key().into());
        cipher
            .decrypt(&key.get_nonce().into(), encrypted.as_ref())
            .map_err(|err| errors::AeadError::CipherGeneratorError(err.to_string()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ecdh::keypair::KeyPair;

    /// A generated nonce must convert into a 24-byte array.
    #[test]
    fn test_nonce() {
        let nonce_value: Result<[u8; 24], _> = AEAD::nonce().try_into();
        assert!(nonce_value.is_ok());
    }

    /// Round-trips a message through `encrypt` / `decrypt`, then checks that
    /// decrypting with a different nonce is rejected.
    #[test]
    fn test_encrypt_decrypt() {
        let keypair_alice = KeyPair::generate();
        let keypair_bob = KeyPair::generate();

        let pubkey_bob = keypair_bob.pub_key();
        let public_bob_hex = pubkey_bob.to_hex();

        // Derive the 32-byte symmetric key from the hashed shared secret.
        let secret_alice = keypair_alice.secret(&public_bob_hex);
        let shared_secret_alice_blake3 = secret_alice.to_blake3();
        let alice_key: [u8; 32] = shared_secret_alice_blake3
            .expect("hashing the shared secret must succeed")
            .as_bytes()[..32]
            .try_into()
            .expect("hashed shared secret must provide 32 bytes");

        let nonce_value: [u8; 24] = AEAD::nonce()
            .try_into()
            .expect("generated nonce must be 24 bytes long");
        let key = Key::generate(alice_key, nonce_value);

        let message = String::from("plaintext");
        let plaintext = message.as_bytes().to_vec();

        let encrypted =
            AEAD::encrypt(&key, &plaintext).expect("encryption with a valid key must succeed");
        let decrypted =
            AEAD::decrypt(&key, &encrypted).expect("decryption with the same key must succeed");
        assert_eq!(decrypted, plaintext);

        let result = String::from_utf8(decrypted).expect("decrypted bytes must be valid UTF-8");
        assert_eq!(result, message);

        // Same secret key but a different nonce: the ciphertext must be rejected.
        let nonce_missed_value: [u8; 24] = AEAD::nonce()
            .try_into()
            .expect("generated nonce must be 24 bytes long");
        let key_invalid = Key::generate(alice_key, nonce_missed_value);
        let decrypted_unmatched = AEAD::decrypt(&key_invalid, &encrypted);

        assert!(matches!(
            decrypted_unmatched,
            Err(errors::AeadError::CipherGeneratorError(_))
        ))
    }
}