// rns_crypto/token.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::aes128::Aes128;
5use crate::aes256::Aes256;
6use crate::hmac::hmac_sha256;
7use crate::pkcs7;
8use crate::Rng;
9
/// Number of bytes a token adds around its ciphertext: a 16-byte IV prefix
/// plus a 32-byte HMAC-SHA256 tag suffix.
///
/// The PKCS#7 padding applied to the plaintext before encryption is not
/// counted here.
pub const TOKEN_OVERHEAD: usize = 16 + 32;
11
/// Errors produced while building a [`Token`] or verifying/decrypting a token.
///
/// The variants carry no data, so the type is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenError {
    /// The key given to [`Token::new`] was neither 32 nor 64 bytes long.
    InvalidKeyLength,
    /// The token is too short to hold its IV, ciphertext and HMAC tag.
    InvalidToken,
    /// The trailing HMAC tag did not match the one recomputed from the token.
    HmacMismatch,
    /// The ciphertext could not be turned back into a validly padded plaintext.
    DecryptionFailed,
}
19
20impl fmt::Display for TokenError {
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        match self {
23            TokenError::InvalidKeyLength => write!(f, "Token key must be 32 or 64 bytes"),
24            TokenError::InvalidToken => write!(f, "Token too short"),
25            TokenError::HmacMismatch => write!(f, "Token HMAC was invalid"),
26            TokenError::DecryptionFailed => write!(f, "Could not decrypt token"),
27        }
28    }
29}
30
31enum AesMode {
32    Aes128(Aes128),
33    Aes256(Aes256),
34}
35
36pub struct Token {
37    signing_key: Vec<u8>,
38    mode: AesMode,
39}
40
41impl Token {
42    pub fn new(key: &[u8]) -> Result<Self, TokenError> {
43        match key.len() {
44            32 => {
45                let signing_key = key[..16].to_vec();
46                let encryption_key: [u8; 16] = key[16..32].try_into().unwrap();
47                Ok(Token {
48                    signing_key,
49                    mode: AesMode::Aes128(Aes128::new(&encryption_key)),
50                })
51            }
52            64 => {
53                let signing_key = key[..32].to_vec();
54                let encryption_key: [u8; 32] = key[32..64].try_into().unwrap();
55                Ok(Token {
56                    signing_key,
57                    mode: AesMode::Aes256(Aes256::new(&encryption_key)),
58                })
59            }
60            _ => Err(TokenError::InvalidKeyLength),
61        }
62    }
63
64    pub fn encrypt(&self, plaintext: &[u8], rng: &mut dyn Rng) -> Vec<u8> {
65        let mut iv = [0u8; 16];
66        rng.fill_bytes(&mut iv);
67        self.encrypt_with_iv(plaintext, &iv)
68    }
69
70    pub fn encrypt_with_iv(&self, plaintext: &[u8], iv: &[u8; 16]) -> Vec<u8> {
71        let padded = pkcs7::pad(plaintext, 16);
72        let ciphertext = match &self.mode {
73            AesMode::Aes128(aes) => aes.encrypt_cbc(&padded, iv),
74            AesMode::Aes256(aes) => aes.encrypt_cbc(&padded, iv),
75        };
76
77        let mut signed_parts = Vec::with_capacity(16 + ciphertext.len());
78        signed_parts.extend_from_slice(iv);
79        signed_parts.extend_from_slice(&ciphertext);
80
81        let mac = hmac_sha256(&self.signing_key, &signed_parts);
82
83        let mut result = Vec::with_capacity(signed_parts.len() + 32);
84        result.extend_from_slice(&signed_parts);
85        result.extend_from_slice(&mac);
86        result
87    }
88
89    pub fn verify_hmac(&self, token: &[u8]) -> Result<bool, TokenError> {
90        if token.len() <= 32 {
91            return Err(TokenError::InvalidToken);
92        }
93        let received_hmac = &token[token.len() - 32..];
94        let expected_hmac = hmac_sha256(&self.signing_key, &token[..token.len() - 32]);
95        Ok(received_hmac == expected_hmac)
96    }
97
98    pub fn decrypt(&self, token: &[u8]) -> Result<Vec<u8>, TokenError> {
99        if token.len() <= TOKEN_OVERHEAD {
100            return Err(TokenError::InvalidToken);
101        }
102
103        if !self.verify_hmac(token).map_err(|_| TokenError::InvalidToken)? {
104            return Err(TokenError::HmacMismatch);
105        }
106
107        let iv: [u8; 16] = token[..16].try_into().unwrap();
108        let ciphertext = &token[16..token.len() - 32];
109
110        let decrypted = match &self.mode {
111            AesMode::Aes128(aes) => aes.decrypt_cbc(ciphertext, &iv),
112            AesMode::Aes256(aes) => aes.decrypt_cbc(ciphertext, &iv),
113        };
114
115        pkcs7::unpad(&decrypted, 16)
116            .map(|s| s.to_vec())
117            .map_err(|_| TokenError::DecryptionFailed)
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FixedRng;

    const PLAINTEXT: &[u8] = b"Hello, Reticulum!";

    #[test]
    fn test_token_new_32byte_key() {
        let key = [0u8; 32];
        assert!(Token::new(&key).is_ok());
    }

    #[test]
    fn test_token_new_64byte_key() {
        let key = [0u8; 64];
        assert!(Token::new(&key).is_ok());
    }

    #[test]
    fn test_token_new_48byte_key() {
        let key = [0u8; 48];
        assert!(matches!(Token::new(&key), Err(TokenError::InvalidKeyLength)));
    }

    #[test]
    fn test_token_new_rejects_other_lengths() {
        let key = [0u8; 128];
        for len in [0usize, 16, 31, 33, 63, 65, 128] {
            assert!(matches!(Token::new(&key[..len]), Err(TokenError::InvalidKeyLength)));
        }
    }

    #[test]
    fn test_token_roundtrip_32() {
        let key = [0x42u8; 32];
        let token = Token::new(&key).unwrap();
        let mut rng = FixedRng::new(&[0xAA; 16]);
        let encrypted = token.encrypt(PLAINTEXT, &mut rng);
        // 17 plaintext bytes pad to two 16-byte blocks.
        assert_eq!(encrypted.len(), TOKEN_OVERHEAD + 32);
        let decrypted = token.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, PLAINTEXT);
    }

    #[test]
    fn test_token_roundtrip_64() {
        let key = [0x42u8; 64];
        let token = Token::new(&key).unwrap();
        let mut rng = FixedRng::new(&[0xBB; 16]);
        let encrypted = token.encrypt(PLAINTEXT, &mut rng);
        assert_eq!(encrypted.len(), TOKEN_OVERHEAD + 32);
        let decrypted = token.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, PLAINTEXT);
    }

    #[test]
    fn test_token_roundtrip_empty_plaintext() {
        let token = Token::new(&[0x42u8; 32]).unwrap();
        let encrypted = token.encrypt_with_iv(b"", &[0x11; 16]);
        // PKCS#7 always adds padding, so even an empty message fills one block.
        assert_eq!(encrypted.len(), TOKEN_OVERHEAD + 16);
        assert_eq!(token.decrypt(&encrypted).unwrap(), b"");
    }

    #[test]
    fn test_token_encrypt_with_iv_layout() {
        let token = Token::new(&[0x42u8; 64]).unwrap();
        let iv = [0x5Au8; 16];
        let first = token.encrypt_with_iv(PLAINTEXT, &iv);
        let second = token.encrypt_with_iv(PLAINTEXT, &iv);
        assert_eq!(first, second, "same key, IV and plaintext must be deterministic");
        assert_eq!(&first[..16], &iv[..], "token must start with the IV");
        assert_eq!(token.verify_hmac(&first), Ok(true));
        let other = token.encrypt_with_iv(PLAINTEXT, &[0xA5u8; 16]);
        assert_ne!(first, other);
    }

    #[test]
    fn test_token_hmac_reject_tampered() {
        let key = [0x42u8; 64];
        let token = Token::new(&key).unwrap();
        let mut rng = FixedRng::new(&[0xCC; 16]);
        let encrypted = token.encrypt(b"test", &mut rng);
        // Corrupt one byte in each region: IV, ciphertext and HMAC tag.
        for index in [0, 20, encrypted.len() - 1] {
            let mut tampered = encrypted.clone();
            tampered[index] ^= 0xFF; // flip every bit of this byte
            assert!(matches!(token.decrypt(&tampered), Err(TokenError::HmacMismatch)));
        }
    }

    #[test]
    fn test_token_wrong_key_is_hmac_mismatch() {
        let sender = Token::new(&[0x42u8; 32]).unwrap();
        let receiver = Token::new(&[0x24u8; 32]).unwrap();
        let encrypted = sender.encrypt_with_iv(PLAINTEXT, &[0x01; 16]);
        assert!(matches!(receiver.decrypt(&encrypted), Err(TokenError::HmacMismatch)));
    }

    #[test]
    fn test_token_decrypt_truncated() {
        let key = [0x42u8; 64];
        let token = Token::new(&key).unwrap();
        assert!(matches!(token.decrypt(&[0u8; 10]), Err(TokenError::InvalidToken)));
        // Exactly TOKEN_OVERHEAD bytes leaves no room for any ciphertext.
        assert!(matches!(
            token.decrypt(&[0u8; TOKEN_OVERHEAD]),
            Err(TokenError::InvalidToken)
        ));
        // One byte past the overhead reaches HMAC verification, which fails.
        assert!(matches!(
            token.decrypt(&[0u8; TOKEN_OVERHEAD + 1]),
            Err(TokenError::HmacMismatch)
        ));
    }

    #[test]
    fn test_token_verify_hmac_requires_more_than_tag() {
        let token = Token::new(&[0x42u8; 32]).unwrap();
        assert_eq!(token.verify_hmac(&[0u8; 32]), Err(TokenError::InvalidToken));
        assert_eq!(token.verify_hmac(&[0u8; 0]), Err(TokenError::InvalidToken));
    }

    #[test]
    fn test_token_overhead() {
        assert_eq!(TOKEN_OVERHEAD, 48);
    }
}