1use chacha20poly1305::{
11 aead::{Aead, KeyInit},
12 XChaCha20Poly1305, XNonce,
13};
14use crate::core::{CryptoError, AEAD_NONCE_SIZE, AEAD_TAG_SIZE, SESSION_ID_SIZE};
15use zeroize::Zeroize;
16
/// Length in bytes of the raw key held by `SessionKey`; this is the key
/// material handed to `XChaCha20Poly1305::new` by the encrypt/decrypt helpers.
pub const SESSION_KEY_SIZE: usize = 32;

/// Length in bytes of the associated data built by `construct_aad`:
/// frame type (1) + flags (1) + session id (6) + nonce counter (8).
pub const AAD_SIZE: usize = 16;
22
23#[derive(Clone)]
27pub struct SessionKey {
28 key: [u8; SESSION_KEY_SIZE],
29}
30
31impl SessionKey {
32 pub fn from_bytes(key: [u8; SESSION_KEY_SIZE]) -> Self {
34 Self { key }
35 }
36
37 pub fn as_bytes(&self) -> &[u8; SESSION_KEY_SIZE] {
42 &self.key
43 }
44}
45
46impl Drop for SessionKey {
47 fn drop(&mut self) {
48 self.key.zeroize();
49 }
50}
51
52pub fn construct_aad(
59 frame_type: u8,
60 flags: u8,
61 session_id: &[u8; SESSION_ID_SIZE],
62 nonce_counter: u64,
63) -> [u8; AAD_SIZE] {
64 let mut aad = [0u8; AAD_SIZE];
65
66 aad[0] = frame_type;
67 aad[1] = flags;
68 aad[2..8].copy_from_slice(session_id);
69 aad[8..16].copy_from_slice(&nonce_counter.to_le_bytes());
70
71 aad
72}
73
74pub fn encrypt(
85 key: &SessionKey,
86 nonce: &[u8; AEAD_NONCE_SIZE],
87 aad: &[u8],
88 plaintext: &[u8],
89) -> Result<Vec<u8>, CryptoError> {
90 let cipher = XChaCha20Poly1305::new(key.as_bytes().into());
91 let xnonce = XNonce::from_slice(nonce);
92
93 cipher
94 .encrypt(xnonce, chacha20poly1305::aead::Payload { msg: plaintext, aad })
95 .map_err(|_| CryptoError::EncryptionFailed)
96}
97
98pub fn decrypt(
109 key: &SessionKey,
110 nonce: &[u8; AEAD_NONCE_SIZE],
111 aad: &[u8],
112 ciphertext: &[u8],
113) -> Result<Vec<u8>, CryptoError> {
114 if ciphertext.len() < AEAD_TAG_SIZE {
115 return Err(CryptoError::DecryptionFailed);
116 }
117
118 let cipher = XChaCha20Poly1305::new(key.as_bytes().into());
119 let xnonce = XNonce::from_slice(nonce);
120
121 cipher
122 .decrypt(xnonce, chacha20poly1305::aead::Payload { msg: ciphertext, aad })
123 .map_err(|_| CryptoError::DecryptionFailed)
124}
125
126pub fn encrypt_in_place(
130 key: &SessionKey,
131 nonce: &[u8; AEAD_NONCE_SIZE],
132 aad: &[u8],
133 buffer: &mut Vec<u8>,
134) -> Result<(), CryptoError> {
135 let cipher = XChaCha20Poly1305::new(key.as_bytes().into());
136 let xnonce = XNonce::from_slice(nonce);
137
138 use chacha20poly1305::aead::AeadInPlace;
139 cipher
140 .encrypt_in_place(xnonce, aad, buffer)
141 .map_err(|_| CryptoError::EncryptionFailed)
142}
143
144pub fn decrypt_in_place(
146 key: &SessionKey,
147 nonce: &[u8; AEAD_NONCE_SIZE],
148 aad: &[u8],
149 buffer: &mut Vec<u8>,
150) -> Result<(), CryptoError> {
151 if buffer.len() < AEAD_TAG_SIZE {
152 return Err(CryptoError::DecryptionFailed);
153 }
154
155 let cipher = XChaCha20Poly1305::new(key.as_bytes().into());
156 let xnonce = XNonce::from_slice(nonce);
157
158 use chacha20poly1305::aead::AeadInPlace;
159 cipher
160 .decrypt_in_place(xnonce, aad, buffer)
161 .map_err(|_| CryptoError::DecryptionFailed)
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aad_construction() {
        let session_id = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
        let aad = construct_aad(0x03, 0x01, &session_id, 42);

        assert_eq!(aad.len(), AAD_SIZE);
        assert_eq!(aad[0], 0x03); // frame type
        assert_eq!(aad[1], 0x01); // flags
        assert_eq!(&aad[2..8], &session_id); // session id
        assert_eq!(&aad[8..16], &42u64.to_le_bytes()); // nonce counter, LE
    }

    #[test]
    fn test_session_key_as_bytes_returns_original_bytes() {
        let bytes = [0x5A; SESSION_KEY_SIZE];
        let key = SessionKey::from_bytes(bytes);

        assert_eq!(key.as_bytes(), &bytes);
        assert_eq!(key.clone().as_bytes(), &bytes);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];
        let plaintext = b"Hello, NOMAD!";

        let ciphertext = encrypt(&key, &nonce, &aad, plaintext).unwrap();
        assert_eq!(ciphertext.len(), plaintext.len() + AEAD_TAG_SIZE);

        let decrypted = decrypt(&key, &nonce, &aad, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_decrypt_wrong_key_fails() {
        let key1 = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let key2 = SessionKey::from_bytes([0x43; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];
        let plaintext = b"Secret message";

        let ciphertext = encrypt(&key1, &nonce, &aad, plaintext).unwrap();
        let result = decrypt(&key2, &nonce, &aad, &ciphertext);

        assert!(matches!(result, Err(CryptoError::DecryptionFailed)));
    }

    #[test]
    fn test_decrypt_wrong_nonce_fails() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce1 = [0x01; AEAD_NONCE_SIZE];
        let nonce2 = [0x02; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];
        let plaintext = b"Secret message";

        let ciphertext = encrypt(&key, &nonce1, &aad, plaintext).unwrap();
        let result = decrypt(&key, &nonce2, &aad, &ciphertext);

        assert!(matches!(result, Err(CryptoError::DecryptionFailed)));
    }

    #[test]
    fn test_decrypt_wrong_aad_fails() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad1 = [0x02; AAD_SIZE];
        let aad2 = [0x03; AAD_SIZE];
        let plaintext = b"Secret message";

        let ciphertext = encrypt(&key, &nonce, &aad1, plaintext).unwrap();
        let result = decrypt(&key, &nonce, &aad2, &ciphertext);

        assert!(matches!(result, Err(CryptoError::DecryptionFailed)));
    }

    #[test]
    fn test_constructed_aad_binds_nonce_counter() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let session_id = [0xAA; SESSION_ID_SIZE];
        let aad_sent = construct_aad(0x03, 0x00, &session_id, 7);
        let aad_other = construct_aad(0x03, 0x00, &session_id, 8);

        let ciphertext = encrypt(&key, &nonce, &aad_sent, b"frame").unwrap();
        assert_eq!(decrypt(&key, &nonce, &aad_sent, &ciphertext).unwrap(), b"frame");

        let result = decrypt(&key, &nonce, &aad_other, &ciphertext);
        assert!(matches!(result, Err(CryptoError::DecryptionFailed)));
    }

    #[test]
    fn test_decrypt_corrupted_ciphertext_fails() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];
        let plaintext = b"Secret message";

        let mut ciphertext = encrypt(&key, &nonce, &aad, plaintext).unwrap();
        ciphertext[0] ^= 0xFF; // flip bits in the first ciphertext byte
        let result = decrypt(&key, &nonce, &aad, &ciphertext);

        assert!(matches!(result, Err(CryptoError::DecryptionFailed)));
    }

    #[test]
    fn test_decrypt_truncated_ciphertext_fails() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];
        let plaintext = b"Secret message";

        let ciphertext = encrypt(&key, &nonce, &aad, plaintext).unwrap();
        let truncated = &ciphertext[..ciphertext.len() - 1];
        let result = decrypt(&key, &nonce, &aad, truncated);

        assert!(matches!(result, Err(CryptoError::DecryptionFailed)));
    }

    #[test]
    fn test_decrypt_shorter_than_tag_fails() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];

        let empty = decrypt(&key, &nonce, &aad, &[]);
        assert!(matches!(empty, Err(CryptoError::DecryptionFailed)));

        let short = [0u8; AEAD_TAG_SIZE - 1];
        let result = decrypt(&key, &nonce, &aad, &short);
        assert!(matches!(result, Err(CryptoError::DecryptionFailed)));
    }

    #[test]
    fn test_encrypt_decrypt_in_place() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];
        let plaintext = b"Hello, NOMAD!";

        let mut buffer = plaintext.to_vec();
        encrypt_in_place(&key, &nonce, &aad, &mut buffer).unwrap();
        assert_eq!(buffer.len(), plaintext.len() + AEAD_TAG_SIZE);

        decrypt_in_place(&key, &nonce, &aad, &mut buffer).unwrap();
        assert_eq!(buffer, plaintext);
    }

    #[test]
    fn test_in_place_matches_allocating_api() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];
        let plaintext = b"Hello, NOMAD!";

        let mut buffer = plaintext.to_vec();
        encrypt_in_place(&key, &nonce, &aad, &mut buffer).unwrap();

        assert_eq!(buffer, encrypt(&key, &nonce, &aad, plaintext).unwrap());
    }

    #[test]
    fn test_decrypt_in_place_tampered_fails() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];

        let mut buffer = b"in-place secret".to_vec();
        encrypt_in_place(&key, &nonce, &aad, &mut buffer).unwrap();
        let last = buffer.len() - 1;
        buffer[last] ^= 0x01; // corrupt the tag
        let result = decrypt_in_place(&key, &nonce, &aad, &mut buffer);

        assert!(matches!(result, Err(CryptoError::DecryptionFailed)));
    }

    #[test]
    fn test_decrypt_in_place_shorter_than_tag_fails() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];

        let mut buffer = vec![0u8; AEAD_TAG_SIZE - 1];
        let result = decrypt_in_place(&key, &nonce, &aad, &mut buffer);

        assert!(matches!(result, Err(CryptoError::DecryptionFailed)));
        // The early-return guard must leave the buffer untouched.
        assert_eq!(buffer.len(), AEAD_TAG_SIZE - 1);
    }

    #[test]
    fn test_empty_plaintext() {
        let key = SessionKey::from_bytes([0x42; SESSION_KEY_SIZE]);
        let nonce = [0x01; AEAD_NONCE_SIZE];
        let aad = [0x02; AAD_SIZE];
        let plaintext = b"";

        let ciphertext = encrypt(&key, &nonce, &aad, plaintext).unwrap();
        assert_eq!(ciphertext.len(), AEAD_TAG_SIZE); // tag only

        let decrypted = decrypt(&key, &nonce, &aad, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }
}