1use aes_gcm::{Aes256Gcm, KeyInit};
13use chacha20poly1305::{ChaCha20Poly1305, aead::AeadCore};
14use zeroize::ZeroizeOnDrop;
15use serde::{Serialize, Deserialize};
16
17use crate::{CryptoError, Result};
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
21pub enum CipherSuite {
22 #[default]
24 ChaCha20Poly1305,
25 Aes256Gcm,
27}
28
29impl CipherSuite {
30 pub const KEY_SIZE: usize = 32;
32 pub const NONCE_SIZE: usize = 12;
34 pub const TAG_SIZE: usize = 16;
36
37 #[inline(always)]
39 pub const fn key_size(&self) -> usize {
40 Self::KEY_SIZE
41 }
42
43 #[inline(always)]
45 pub const fn nonce_size(&self) -> usize {
46 Self::NONCE_SIZE
47 }
48
49 #[inline(always)]
51 pub const fn tag_size(&self) -> usize {
52 Self::TAG_SIZE
53 }
54}
55
56pub struct DataChannelKey {
58 key: [u8; 32],
59 cipher_suite: CipherSuite,
60}
61
62impl DataChannelKey {
63 pub fn new(key: [u8; 32], cipher_suite: CipherSuite) -> Self {
65 Self { key, cipher_suite }
66 }
67
68 pub fn cipher_suite(&self) -> CipherSuite {
70 self.cipher_suite
71 }
72
73 pub fn cipher(&self) -> Cipher {
75 Cipher::new(&self.key, self.cipher_suite)
76 }
77}
78
79impl Drop for DataChannelKey {
80 fn drop(&mut self) {
81 use zeroize::Zeroize;
82 self.key.zeroize();
83 }
84}
85
86impl ZeroizeOnDrop for DataChannelKey {}
87
88pub struct Cipher {
90 inner: CipherInner,
91 suite: CipherSuite,
92}
93
94enum CipherInner {
95 ChaCha(ChaCha20Poly1305),
96 Aes(Box<Aes256Gcm>),
97}
98
99impl Cipher {
100 #[inline]
102 pub fn new(key: &[u8; 32], suite: CipherSuite) -> Self {
103 let inner = match suite {
104 CipherSuite::ChaCha20Poly1305 => {
105 CipherInner::ChaCha(ChaCha20Poly1305::new(key.into()))
106 }
107 CipherSuite::Aes256Gcm => {
108 CipherInner::Aes(Box::new(Aes256Gcm::new(key.into())))
109 }
110 };
111 Self { inner, suite }
112 }
113
114 #[inline]
118 pub fn encrypt(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
119 use chacha20poly1305::aead::Aead;
120 use aes_gcm::aead::Payload;
121
122 let payload = Payload { msg: plaintext, aad };
123
124 match &self.inner {
125 CipherInner::ChaCha(cipher) => {
126 cipher.encrypt(nonce.into(), payload)
127 .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))
128 }
129 CipherInner::Aes(cipher) => {
130 cipher.encrypt(nonce.into(), payload)
131 .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))
132 }
133 }
134 }
135
136 #[inline]
141 pub fn encrypt_into(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8], out: &mut Vec<u8>) -> Result<usize> {
142 use chacha20poly1305::aead::Aead;
143 use aes_gcm::aead::Payload;
144
145 let payload = Payload { msg: plaintext, aad };
146 let start_len = out.len();
147
148 let ciphertext = match &self.inner {
149 CipherInner::ChaCha(cipher) => {
150 cipher.encrypt(nonce.into(), payload)
151 .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))?
152 }
153 CipherInner::Aes(cipher) => {
154 cipher.encrypt(nonce.into(), payload)
155 .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))?
156 }
157 };
158
159 out.extend_from_slice(&ciphertext);
160 Ok(out.len() - start_len)
161 }
162
163 #[inline]
167 pub fn decrypt(&self, nonce: &[u8; 12], ciphertext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
168 use chacha20poly1305::aead::Aead;
169 use aes_gcm::aead::Payload;
170
171 let payload = Payload { msg: ciphertext, aad };
172
173 match &self.inner {
174 CipherInner::ChaCha(cipher) => {
175 cipher.decrypt(nonce.into(), payload)
176 .map_err(|_| CryptoError::DecryptionFailed)
177 }
178 CipherInner::Aes(cipher) => {
179 cipher.decrypt(nonce.into(), payload)
180 .map_err(|_| CryptoError::DecryptionFailed)
181 }
182 }
183 }
184
185 #[inline]
190 pub fn generate_nonce(&self) -> [u8; 12] {
191 match &self.inner {
192 CipherInner::ChaCha(_) => {
193 ChaCha20Poly1305::generate_nonce(&mut rand::rngs::OsRng).into()
194 }
195 CipherInner::Aes(_) => {
196 Aes256Gcm::generate_nonce(&mut rand::rngs::OsRng).into()
197 }
198 }
199 }
200
201 #[inline(always)]
203 pub fn suite(&self) -> CipherSuite {
204 self.suite
205 }
206}
207
208pub struct PacketCipher {
215 cipher: Cipher,
216 tx_counter: u64,
218 rx_window: ReplayWindow,
220}
221
222const PACKET_HEADER_SIZE: usize = 8;
224
225impl PacketCipher {
226 #[inline]
228 pub fn new(key: DataChannelKey) -> Self {
229 Self {
230 cipher: key.cipher(),
231 tx_counter: 0,
232 rx_window: ReplayWindow::new(),
233 }
234 }
235
236 #[inline]
240 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
241 self.tx_counter = self.tx_counter.checked_add(1)
243 .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
244
245 let mut nonce = [0u8; 12];
248 let packet_id = self.tx_counter.to_be_bytes();
249 nonce[4..].copy_from_slice(&packet_id);
250
251 let output_len = PACKET_HEADER_SIZE + plaintext.len() + CipherSuite::TAG_SIZE;
254 let mut output = Vec::with_capacity(output_len);
255
256 output.extend_from_slice(&packet_id);
258
259 self.cipher.encrypt_into(&nonce, plaintext, &packet_id, &mut output)?;
261
262 Ok(output)
263 }
264
265 #[inline]
270 pub fn encrypt_into(&mut self, plaintext: &[u8], output: &mut Vec<u8>) -> Result<usize> {
271 self.tx_counter = self.tx_counter.checked_add(1)
272 .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
273
274 let mut nonce = [0u8; 12];
275 let packet_id = self.tx_counter.to_be_bytes();
276 nonce[4..].copy_from_slice(&packet_id);
277
278 output.extend_from_slice(&packet_id);
279 let cipher_bytes = self.cipher.encrypt_into(&nonce, plaintext, &packet_id, output)?;
280
281 Ok(PACKET_HEADER_SIZE + cipher_bytes)
282 }
283
284 #[inline]
286 pub fn decrypt(&mut self, packet: &[u8]) -> Result<Vec<u8>> {
287 const MIN_PACKET_SIZE: usize = PACKET_HEADER_SIZE + CipherSuite::TAG_SIZE;
288
289 if packet.len() < MIN_PACKET_SIZE {
290 return Err(CryptoError::DecryptionFailed);
291 }
292
293 let packet_id: [u8; 8] = packet[..8].try_into().unwrap();
295 let counter = u64::from_be_bytes(packet_id);
296
297 if !self.rx_window.check_and_update(counter) {
299 return Err(CryptoError::ReplayDetected);
300 }
301
302 let mut nonce = [0u8; 12];
304 nonce[4..].copy_from_slice(&packet_id);
305
306 self.cipher.decrypt(&nonce, &packet[8..], &packet_id)
308 }
309
310 #[inline(always)]
312 pub fn tx_counter(&self) -> u64 {
313 self.tx_counter
314 }
315}
316
/// Sliding-window replay filter for 64-bit packet ids.
///
/// Bit `i` of `bitmap` is set once id `highest - i` has been accepted, so the
/// window tracks the `WINDOW_SIZE` most recent ids ending at `highest`.
struct ReplayWindow {
    /// Largest packet id accepted so far; 0 means nothing has been accepted yet.
    highest: u64,
    /// Acceptance bitmap; bit 0 corresponds to `highest` itself.
    bitmap: u128,
}

impl ReplayWindow {
    /// Number of ids the window can track; equals the bit width of `bitmap`.
    const WINDOW_SIZE: u64 = 128;

    /// Creates an empty window with no ids accepted.
    #[inline]
    fn new() -> Self {
        Self { highest: 0, bitmap: 0 }
    }

    /// Records `packet_id` and returns `true` if it has not been seen before.
    ///
    /// Returns `false`, leaving the window untouched, for id 0, for ids
    /// `WINDOW_SIZE` or more behind `highest`, and for ids already accepted.
    #[inline]
    fn check_and_update(&mut self, packet_id: u64) -> bool {
        // Valid ids start at 1 (the sender increments its counter before first use).
        if packet_id == 0 {
            return false;
        }

        if packet_id <= self.highest {
            // At or behind the head: must be recent enough and not yet recorded.
            let age = self.highest - packet_id;
            if age >= Self::WINDOW_SIZE {
                return false;
            }
            let bit = 1u128 << age;
            if self.bitmap & bit != 0 {
                return false;
            }
            self.bitmap |= bit;
            return true;
        }

        // New head: slide the window forward. A jump of WINDOW_SIZE or more shares
        // nothing with the old window, so only the new id stays marked.
        let advance = packet_id - self.highest;
        self.bitmap = if advance < Self::WINDOW_SIZE {
            (self.bitmap << advance) | 1
        } else {
            1
        };
        self.highest = packet_id;
        true
    }

    /// Forgets every accepted id, returning the window to its initial state.
    // Not called from the visible non-test code; the allow keeps it warning-free.
    #[allow(dead_code)]
    #[inline]
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Every suite the tests should cover.
    const SUITES: [CipherSuite; 2] = [CipherSuite::ChaCha20Poly1305, CipherSuite::Aes256Gcm];

    #[test]
    fn test_encrypt_decrypt() {
        let key = [0x42u8; 32];

        for suite in SUITES {
            let cipher = Cipher::new(&key, suite);
            let nonce = cipher.generate_nonce();
            let plaintext = b"Hello, CoreVPN!";
            let aad = b"associated data";

            let ciphertext = cipher.encrypt(&nonce, plaintext, aad).unwrap();
            assert_eq!(ciphertext.len(), plaintext.len() + CipherSuite::TAG_SIZE);

            let decrypted = cipher.decrypt(&nonce, &ciphertext, aad).unwrap();
            assert_eq!(plaintext.as_slice(), decrypted.as_slice());
        }
    }

    #[test]
    fn test_encrypt_into_appends_same_bytes_as_encrypt() {
        let key = [0x07u8; 32];

        for suite in SUITES {
            let cipher = Cipher::new(&key, suite);
            let nonce = cipher.generate_nonce();
            let expected = cipher.encrypt(&nonce, b"payload", b"aad").unwrap();

            let mut out = b"prefix".to_vec();
            let written = cipher.encrypt_into(&nonce, b"payload", b"aad", &mut out).unwrap();

            assert_eq!(written, expected.len());
            assert_eq!(&out[..6], b"prefix");
            assert_eq!(&out[6..], expected.as_slice());
        }
    }

    #[test]
    fn test_authentication_failure() {
        let key = [0x42u8; 32];

        for suite in SUITES {
            let cipher = Cipher::new(&key, suite);
            let nonce = cipher.generate_nonce();
            let ciphertext = cipher.encrypt(&nonce, b"test", b"aad").unwrap();

            // Flipping ciphertext bits must break authentication.
            let mut tampered = ciphertext.clone();
            tampered[0] ^= 0xFF;
            assert!(cipher.decrypt(&nonce, &tampered, b"aad").is_err());

            // So must presenting different associated data.
            assert!(cipher.decrypt(&nonce, &ciphertext, b"other").is_err());
        }
    }

    #[test]
    fn test_packet_cipher_replay_protection() {
        let key = DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);
        let mut encryptor = PacketCipher::new(key);

        let key2 = DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);
        let mut decryptor = PacketCipher::new(key2);

        let p1 = encryptor.encrypt(b"packet 1").unwrap();
        let p2 = encryptor.encrypt(b"packet 2").unwrap();
        let p3 = encryptor.encrypt(b"packet 3").unwrap();

        assert!(decryptor.decrypt(&p1).is_ok());
        assert!(decryptor.decrypt(&p2).is_ok());

        // Replay of an already-accepted packet.
        assert!(decryptor.decrypt(&p1).is_err());

        assert!(decryptor.decrypt(&p3).is_ok());
        assert!(decryptor.decrypt(&p3).is_err());
    }

    #[test]
    fn test_packet_cipher_rejects_short_and_tampered_packets() {
        for suite in SUITES {
            let mut encryptor = PacketCipher::new(DataChannelKey::new([0x42u8; 32], suite));
            let mut decryptor = PacketCipher::new(DataChannelKey::new([0x42u8; 32], suite));

            let too_short = [0u8; PACKET_HEADER_SIZE + CipherSuite::TAG_SIZE - 1];
            assert!(decryptor.decrypt(&too_short).is_err());

            let mut packet = encryptor.encrypt(b"data").unwrap();
            let last = packet.len() - 1;
            packet[last] ^= 0x01;
            assert!(decryptor.decrypt(&packet).is_err());
        }
    }

    #[test]
    fn test_packet_cipher_encrypt_into() {
        let mut encryptor =
            PacketCipher::new(DataChannelKey::new([0x11u8; 32], CipherSuite::Aes256Gcm));
        let mut decryptor =
            PacketCipher::new(DataChannelKey::new([0x11u8; 32], CipherSuite::Aes256Gcm));

        let mut out = vec![0xAAu8; 3];
        let written = encryptor.encrypt_into(b"hello", &mut out).unwrap();

        assert_eq!(written, PACKET_HEADER_SIZE + 5 + CipherSuite::TAG_SIZE);
        assert_eq!(out.len(), 3 + written);
        assert_eq!(&out[..3], &[0xAA, 0xAA, 0xAA]);
        assert_eq!(encryptor.tx_counter(), 1);
        assert_eq!(decryptor.decrypt(&out[3..]).unwrap(), b"hello");
    }

    #[test]
    fn test_replay_window() {
        let mut window = ReplayWindow::new();

        assert!(!window.check_and_update(0));
        assert!(window.check_and_update(1));
        assert!(window.check_and_update(2));
        assert!(!window.check_and_update(1));
        assert!(window.check_and_update(100));
        assert!(!window.check_and_update(1));
        assert!(window.check_and_update(99));
        assert!(!window.check_and_update(99));

        // A jump beyond the window clears it; ids exactly WINDOW_SIZE behind are too old.
        assert!(window.check_and_update(300));
        assert!(!window.check_and_update(172));
        assert!(window.check_and_update(173));
        assert!(!window.check_and_update(173));
    }
}