// authx_core/crypto/encryption.rs

1use aes_gcm::{
2    Aes256Gcm, Nonce,
3    aead::{Aead, KeyInit, OsRng, rand_core::RngCore},
4};
5use tracing::instrument;
6
7use crate::error::{AuthError, Result};
8
/// Size in bytes of the nonce that [`encrypt`] prepends to the ciphertext and
/// that [`decrypt`] splits off the front of the decoded blob.
const NONCE_LEN: usize = 12;
10
11/// Encrypts plaintext with AES-256-GCM. Returns `nonce || ciphertext` as hex.
12#[instrument(skip(key, plaintext))]
13pub fn encrypt(key: &[u8; 32], plaintext: &[u8]) -> Result<String> {
14    let cipher =
15        Aes256Gcm::new_from_slice(key).map_err(|e| AuthError::EncryptionError(e.to_string()))?;
16
17    let mut nonce_bytes = [0u8; NONCE_LEN];
18    OsRng.fill_bytes(&mut nonce_bytes);
19    let nonce = Nonce::from_slice(&nonce_bytes);
20
21    let ciphertext = cipher
22        .encrypt(nonce, plaintext)
23        .map_err(|e| AuthError::EncryptionError(e.to_string()))?;
24
25    let mut out = Vec::with_capacity(NONCE_LEN + ciphertext.len());
26    out.extend_from_slice(&nonce_bytes);
27    out.extend_from_slice(&ciphertext);
28
29    tracing::debug!("encrypted {} bytes", plaintext.len());
30    Ok(hex::encode(out))
31}
32
33/// Decrypts hex-encoded `nonce || ciphertext` produced by [`encrypt`].
34#[instrument(skip(key, hex_blob))]
35pub fn decrypt(key: &[u8; 32], hex_blob: &str) -> Result<Vec<u8>> {
36    let raw =
37        hex::decode(hex_blob).map_err(|_| AuthError::EncryptionError("invalid hex".into()))?;
38
39    if raw.len() < NONCE_LEN {
40        return Err(AuthError::EncryptionError("blob too short".into()));
41    }
42
43    let (nonce_bytes, ciphertext) = raw.split_at(NONCE_LEN);
44    let nonce = Nonce::from_slice(nonce_bytes);
45
46    let cipher =
47        Aes256Gcm::new_from_slice(key).map_err(|e| AuthError::EncryptionError(e.to_string()))?;
48
49    let plaintext = cipher
50        .decrypt(nonce, ciphertext)
51        .map_err(|_| AuthError::EncryptionError("decryption failed".into()))?;
52
53    tracing::debug!("decrypted {} bytes", plaintext.len());
54    Ok(plaintext)
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed key shared by the tests that do not depend on a particular key value.
    fn test_key() -> [u8; 32] {
        [0x42u8; 32]
    }

    #[test]
    fn round_trip_plaintext() {
        let key = test_key();
        let plaintext = b"oauth-access-token-secret";
        let blob = encrypt(&key, plaintext).unwrap();
        let recovered = decrypt(&key, &blob).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn round_trip_empty_plaintext() {
        let key = test_key();
        let blob = encrypt(&key, b"").unwrap();
        let recovered = decrypt(&key, &blob).unwrap();
        assert!(recovered.is_empty());
    }

    #[test]
    fn different_nonce_each_call() {
        let key = test_key();
        let a = encrypt(&key, b"same").unwrap();
        let b = encrypt(&key, b"same").unwrap();
        assert_ne!(a, b, "nonce must be randomised per call");
    }

    #[test]
    fn wrong_key_fails() {
        let key1 = [0x11u8; 32];
        let key2 = [0x22u8; 32];
        let blob = encrypt(&key1, b"secret").unwrap();
        assert!(decrypt(&key2, &blob).is_err());
    }

    #[test]
    fn tampered_blob_fails() {
        let key = test_key();
        let mut blob = encrypt(&key, b"secret").unwrap();
        // Swap the final hex digit for a different one so the decoded bytes
        // differ from what `encrypt` produced.
        let last = blob.pop().expect("encrypt output is never empty");
        blob.push(if last == '0' { '1' } else { '0' });
        assert!(decrypt(&key, &blob).is_err());
    }

    #[test]
    fn invalid_hex_fails() {
        let key = test_key();
        assert!(decrypt(&key, "not hex").is_err());
    }

    #[test]
    fn truncated_blob_fails() {
        // Four decoded bytes: shorter than the nonce itself.
        let key = test_key();
        assert!(decrypt(&key, "deadbeef").is_err());
    }

    #[test]
    fn nonce_only_blob_fails() {
        // Exactly `NONCE_LEN` decoded bytes: passes the length guard but leaves
        // no ciphertext for the cipher to accept.
        let key = test_key();
        let blob = "00".repeat(NONCE_LEN);
        assert!(decrypt(&key, &blob).is_err());
    }
}
95}