1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::aes128::Aes128;
5use crate::aes256::Aes256;
6use crate::hmac::hmac_sha256;
7use crate::pkcs7;
8use crate::Rng;
9
/// Number of bytes a token adds on top of the padded ciphertext:
/// a 16-byte IV prefix plus a 32-byte HMAC-SHA256 suffix.
pub const TOKEN_OVERHEAD: usize = 16 + 32;

/// Errors returned when constructing a [`Token`] or verifying/decrypting a token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenError {
    /// The key passed to [`Token::new`] was neither 32 nor 64 bytes long.
    InvalidKeyLength,
    /// The token is too short to hold an IV, ciphertext and HMAC.
    InvalidToken,
    /// The HMAC did not authenticate the IV and ciphertext.
    HmacMismatch,
    /// The authenticated ciphertext could not be decrypted (e.g. bad PKCS#7 padding).
    DecryptionFailed,
}
19
20impl fmt::Display for TokenError {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 match self {
23 TokenError::InvalidKeyLength => write!(f, "Token key must be 32 or 64 bytes"),
24 TokenError::InvalidToken => write!(f, "Token too short"),
25 TokenError::HmacMismatch => write!(f, "Token HMAC was invalid"),
26 TokenError::DecryptionFailed => write!(f, "Could not decrypt token"),
27 }
28 }
29}
30
/// AES backend chosen by `Token::new` from the length of the supplied key.
enum AesMode {
    /// AES-128; selected when `Token::new` receives a 32-byte key.
    Aes128(Aes128),
    /// AES-256; selected when `Token::new` receives a 64-byte key.
    Aes256(Aes256),
}
35
/// Authenticated token cipher: AES-CBC with PKCS#7 padding, authenticated by
/// HMAC-SHA256 over the IV and ciphertext.
///
/// Tokens produced by `encrypt` have the layout `IV (16) || ciphertext || HMAC (32)`.
pub struct Token {
    /// HMAC-SHA256 key: the first half of the key given to `Token::new`.
    signing_key: Vec<u8>,
    /// AES cipher keyed with the second half of the key given to `Token::new`.
    mode: AesMode,
}
40
41impl Token {
42 pub fn new(key: &[u8]) -> Result<Self, TokenError> {
43 match key.len() {
44 32 => {
45 let signing_key = key[..16].to_vec();
46 let encryption_key: [u8; 16] = key[16..32].try_into().unwrap();
47 Ok(Token {
48 signing_key,
49 mode: AesMode::Aes128(Aes128::new(&encryption_key)),
50 })
51 }
52 64 => {
53 let signing_key = key[..32].to_vec();
54 let encryption_key: [u8; 32] = key[32..64].try_into().unwrap();
55 Ok(Token {
56 signing_key,
57 mode: AesMode::Aes256(Aes256::new(&encryption_key)),
58 })
59 }
60 _ => Err(TokenError::InvalidKeyLength),
61 }
62 }
63
64 pub fn encrypt(&self, plaintext: &[u8], rng: &mut dyn Rng) -> Vec<u8> {
65 let mut iv = [0u8; 16];
66 rng.fill_bytes(&mut iv);
67 self.encrypt_with_iv(plaintext, &iv)
68 }
69
70 pub fn encrypt_with_iv(&self, plaintext: &[u8], iv: &[u8; 16]) -> Vec<u8> {
71 let padded = pkcs7::pad(plaintext, 16);
72 let ciphertext = match &self.mode {
73 AesMode::Aes128(aes) => aes.encrypt_cbc(&padded, iv),
74 AesMode::Aes256(aes) => aes.encrypt_cbc(&padded, iv),
75 };
76
77 let mut signed_parts = Vec::with_capacity(16 + ciphertext.len());
78 signed_parts.extend_from_slice(iv);
79 signed_parts.extend_from_slice(&ciphertext);
80
81 let mac = hmac_sha256(&self.signing_key, &signed_parts);
82
83 let mut result = Vec::with_capacity(signed_parts.len() + 32);
84 result.extend_from_slice(&signed_parts);
85 result.extend_from_slice(&mac);
86 result
87 }
88
89 pub fn verify_hmac(&self, token: &[u8]) -> Result<bool, TokenError> {
90 if token.len() <= 32 {
91 return Err(TokenError::InvalidToken);
92 }
93 let received_hmac = &token[token.len() - 32..];
94 let expected_hmac = hmac_sha256(&self.signing_key, &token[..token.len() - 32]);
95 Ok(received_hmac == expected_hmac)
96 }
97
98 pub fn decrypt(&self, token: &[u8]) -> Result<Vec<u8>, TokenError> {
99 if token.len() <= TOKEN_OVERHEAD {
100 return Err(TokenError::InvalidToken);
101 }
102
103 if !self.verify_hmac(token).map_err(|_| TokenError::InvalidToken)? {
104 return Err(TokenError::HmacMismatch);
105 }
106
107 let iv: [u8; 16] = token[..16].try_into().unwrap();
108 let ciphertext = &token[16..token.len() - 32];
109
110 let decrypted = match &self.mode {
111 AesMode::Aes128(aes) => aes.decrypt_cbc(ciphertext, &iv),
112 AesMode::Aes256(aes) => aes.decrypt_cbc(ciphertext, &iv),
113 };
114
115 pkcs7::unpad(&decrypted, 16)
116 .map(|s| s.to_vec())
117 .map_err(|_| TokenError::DecryptionFailed)
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FixedRng;

    #[test]
    fn test_token_new_32byte_key() {
        let key = [0u8; 32];
        assert!(Token::new(&key).is_ok());
    }

    #[test]
    fn test_token_new_64byte_key() {
        let key = [0u8; 64];
        assert!(Token::new(&key).is_ok());
    }

    #[test]
    fn test_token_new_48byte_key() {
        let key = [0u8; 48];
        assert!(matches!(Token::new(&key), Err(TokenError::InvalidKeyLength)));
    }

    #[test]
    fn test_token_roundtrip_32() {
        let key = [0x42u8; 32];
        let token = Token::new(&key).unwrap();
        let mut rng = FixedRng::new(&[0xAA; 16]);
        let plaintext = b"Hello, Reticulum!";
        let encrypted = token.encrypt(plaintext, &mut rng);
        let decrypted = token.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_token_roundtrip_64() {
        let key = [0x42u8; 64];
        let token = Token::new(&key).unwrap();
        let mut rng = FixedRng::new(&[0xBB; 16]);
        let plaintext = b"Hello, Reticulum!";
        let encrypted = token.encrypt(plaintext, &mut rng);
        let decrypted = token.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_token_roundtrip_empty_plaintext() {
        let token = Token::new(&[0x42u8; 32]).unwrap();
        let mut rng = FixedRng::new(&[0xDD; 16]);
        let encrypted = token.encrypt(b"", &mut rng);
        // PKCS#7 always adds padding, so an empty message still yields one block.
        assert_eq!(encrypted.len(), TOKEN_OVERHEAD + 16);
        assert!(token.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_token_encrypt_with_iv_layout() {
        let token = Token::new(&[0x42u8; 32]).unwrap();
        let iv = [0x07u8; 16];
        let first = token.encrypt_with_iv(b"abc", &iv);
        let second = token.encrypt_with_iv(b"abc", &iv);
        assert_eq!(first, second);
        assert_eq!(first[..16], iv);
        assert_eq!(first.len(), TOKEN_OVERHEAD + 16);
        assert_eq!(token.decrypt(&first).unwrap(), b"abc");
    }

    #[test]
    fn test_token_hmac_reject_tampered() {
        let key = [0x42u8; 64];
        let token = Token::new(&key).unwrap();
        let mut rng = FixedRng::new(&[0xCC; 16]);
        let mut tampered = token.encrypt(b"test", &mut rng);
        // Byte 20 falls inside the ciphertext, just past the 16-byte IV.
        tampered[20] ^= 0xFF;
        assert_eq!(token.decrypt(&tampered), Err(TokenError::HmacMismatch));
    }

    #[test]
    fn test_token_hmac_reject_tampered_tag() {
        let token = Token::new(&[0x42u8; 64]).unwrap();
        let mut rng = FixedRng::new(&[0xCC; 16]);
        let mut tampered = token.encrypt(b"test", &mut rng);
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        assert_eq!(token.decrypt(&tampered), Err(TokenError::HmacMismatch));
    }

    #[test]
    fn test_token_wrong_key_rejected() {
        let sender = Token::new(&[0x42u8; 64]).unwrap();
        let receiver = Token::new(&[0x24u8; 64]).unwrap();
        let mut rng = FixedRng::new(&[0xEE; 16]);
        let encrypted = sender.encrypt(b"test", &mut rng);
        assert_eq!(receiver.decrypt(&encrypted), Err(TokenError::HmacMismatch));
    }

    #[test]
    fn test_token_verify_hmac() {
        let token = Token::new(&[0x42u8; 32]).unwrap();
        let mut rng = FixedRng::new(&[0x11; 16]);
        let mut encrypted = token.encrypt(b"test", &mut rng);
        assert_eq!(token.verify_hmac(&encrypted), Ok(true));
        encrypted[0] ^= 0x01;
        assert_eq!(token.verify_hmac(&encrypted), Ok(false));
        assert_eq!(token.verify_hmac(&[0u8; 32]), Err(TokenError::InvalidToken));
    }

    #[test]
    fn test_token_decrypt_truncated() {
        let key = [0x42u8; 64];
        let token = Token::new(&key).unwrap();
        assert!(matches!(token.decrypt(&[0u8; 10]), Err(TokenError::InvalidToken)));
        assert!(matches!(token.decrypt(&[0u8; TOKEN_OVERHEAD]), Err(TokenError::InvalidToken)));
    }

    #[test]
    fn test_token_overhead() {
        assert_eq!(TOKEN_OVERHEAD, 48);
    }
}