1use aes_gcm::{Aes256Gcm, KeyInit};
13use chacha20poly1305::{ChaCha20Poly1305, aead::AeadCore};
14use zeroize::ZeroizeOnDrop;
15use serde::{Serialize, Deserialize};
16
17use crate::{CryptoError, Result};
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
21pub enum CipherSuite {
22 #[default]
24 ChaCha20Poly1305,
25 Aes256Gcm,
27}
28
29impl CipherSuite {
30 pub const KEY_SIZE: usize = 32;
32 pub const NONCE_SIZE: usize = 12;
34 pub const TAG_SIZE: usize = 16;
36
37 #[inline(always)]
39 pub const fn key_size(&self) -> usize {
40 Self::KEY_SIZE
41 }
42
43 #[inline(always)]
45 pub const fn nonce_size(&self) -> usize {
46 Self::NONCE_SIZE
47 }
48
49 #[inline(always)]
51 pub const fn tag_size(&self) -> usize {
52 Self::TAG_SIZE
53 }
54}
55
56pub struct DataChannelKey {
58 key: [u8; 32],
59 implicit_iv: [u8; 12],
61 cipher_suite: CipherSuite,
62}
63
64impl DataChannelKey {
65 pub fn new(key: [u8; 32], cipher_suite: CipherSuite) -> Self {
67 Self { key, implicit_iv: [0u8; 12], cipher_suite }
68 }
69
70 pub fn new_with_iv(key: [u8; 32], implicit_iv: [u8; 12], cipher_suite: CipherSuite) -> Self {
72 Self { key, implicit_iv, cipher_suite }
73 }
74
75 pub fn cipher_suite(&self) -> CipherSuite {
77 self.cipher_suite
78 }
79
80 pub fn implicit_iv(&self) -> &[u8; 12] {
82 &self.implicit_iv
83 }
84
85 pub fn cipher(&self) -> Cipher {
87 Cipher::new(&self.key, self.cipher_suite)
88 }
89}
90
91impl Drop for DataChannelKey {
92 fn drop(&mut self) {
93 use zeroize::Zeroize;
94 self.key.zeroize();
95 self.implicit_iv.zeroize();
96 }
97}
98
99impl ZeroizeOnDrop for DataChannelKey {}
100
101pub struct Cipher {
103 inner: CipherInner,
104 suite: CipherSuite,
105}
106
107enum CipherInner {
108 ChaCha(ChaCha20Poly1305),
109 Aes(Box<Aes256Gcm>),
110}
111
112impl Cipher {
113 #[inline]
115 pub fn new(key: &[u8; 32], suite: CipherSuite) -> Self {
116 let inner = match suite {
117 CipherSuite::ChaCha20Poly1305 => {
118 CipherInner::ChaCha(ChaCha20Poly1305::new(key.into()))
119 }
120 CipherSuite::Aes256Gcm => {
121 CipherInner::Aes(Box::new(Aes256Gcm::new(key.into())))
122 }
123 };
124 Self { inner, suite }
125 }
126
127 #[inline]
131 pub fn encrypt(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
132 use chacha20poly1305::aead::Aead;
133 use aes_gcm::aead::Payload;
134
135 let payload = Payload { msg: plaintext, aad };
136
137 match &self.inner {
138 CipherInner::ChaCha(cipher) => {
139 cipher.encrypt(nonce.into(), payload)
140 .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))
141 }
142 CipherInner::Aes(cipher) => {
143 cipher.encrypt(nonce.into(), payload)
144 .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))
145 }
146 }
147 }
148
149 #[inline]
154 pub fn encrypt_into(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8], out: &mut Vec<u8>) -> Result<usize> {
155 use chacha20poly1305::aead::Aead;
156 use aes_gcm::aead::Payload;
157
158 let payload = Payload { msg: plaintext, aad };
159 let start_len = out.len();
160
161 let ciphertext = match &self.inner {
162 CipherInner::ChaCha(cipher) => {
163 cipher.encrypt(nonce.into(), payload)
164 .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))?
165 }
166 CipherInner::Aes(cipher) => {
167 cipher.encrypt(nonce.into(), payload)
168 .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))?
169 }
170 };
171
172 out.extend_from_slice(&ciphertext);
173 Ok(out.len() - start_len)
174 }
175
176 #[inline]
180 pub fn decrypt(&self, nonce: &[u8; 12], ciphertext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
181 use chacha20poly1305::aead::Aead;
182 use aes_gcm::aead::Payload;
183
184 let payload = Payload { msg: ciphertext, aad };
185
186 match &self.inner {
187 CipherInner::ChaCha(cipher) => {
188 cipher.decrypt(nonce.into(), payload)
189 .map_err(|_| CryptoError::DecryptionFailed)
190 }
191 CipherInner::Aes(cipher) => {
192 cipher.decrypt(nonce.into(), payload)
193 .map_err(|_| CryptoError::DecryptionFailed)
194 }
195 }
196 }
197
198 #[inline]
203 pub fn generate_nonce(&self) -> [u8; 12] {
204 match &self.inner {
205 CipherInner::ChaCha(_) => {
206 ChaCha20Poly1305::generate_nonce(&mut rand::rngs::OsRng).into()
207 }
208 CipherInner::Aes(_) => {
209 Aes256Gcm::generate_nonce(&mut rand::rngs::OsRng).into()
210 }
211 }
212 }
213
214 #[inline(always)]
216 pub fn suite(&self) -> CipherSuite {
217 self.suite
218 }
219}
220
221pub struct PacketCipher {
234 cipher: Cipher,
235 implicit_iv: [u8; 12],
237 tx_counter: u32,
239 rx_window: ReplayWindow,
241}
242
/// Width in bytes of the big-endian packet id that prefixes each data packet
/// (the size of the `u32` transmit counter).
const PACKET_ID_SIZE: usize = std::mem::size_of::<u32>();
245
246impl PacketCipher {
247 #[inline]
249 pub fn new(key: DataChannelKey) -> Self {
250 let implicit_iv = *key.implicit_iv();
251 Self {
252 cipher: key.cipher(),
253 implicit_iv,
254 tx_counter: 0,
255 rx_window: ReplayWindow::new(),
256 }
257 }
258
259 #[inline(always)]
263 fn build_nonce(&self, pid_bytes: &[u8; 4]) -> [u8; 12] {
264 let mut nonce = self.implicit_iv;
265 nonce[0] ^= pid_bytes[0];
266 nonce[1] ^= pid_bytes[1];
267 nonce[2] ^= pid_bytes[2];
268 nonce[3] ^= pid_bytes[3];
269 nonce
270 }
271
272 #[inline]
284 pub fn encrypt(&mut self, plaintext: &[u8], ad_prefix: &[u8]) -> Result<Vec<u8>> {
285 self.tx_counter = self.tx_counter.checked_add(1)
286 .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
287
288 let pid_bytes = self.tx_counter.to_be_bytes();
289 let nonce = self.build_nonce(&pid_bytes);
290
291 let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
293 aad.extend_from_slice(ad_prefix);
294 aad.extend_from_slice(&pid_bytes);
295
296 let ct_tag = self.cipher.encrypt(&nonce, plaintext, &aad)?;
298
299 let mut output = Vec::with_capacity(PACKET_ID_SIZE + ct_tag.len());
301 output.extend_from_slice(&pid_bytes);
302 output.extend_from_slice(&ct_tag);
303
304 Ok(output)
305 }
306
307 #[inline]
312 pub fn encrypt_into(&mut self, plaintext: &[u8], ad_prefix: &[u8], output: &mut Vec<u8>) -> Result<usize> {
313 self.tx_counter = self.tx_counter.checked_add(1)
314 .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
315
316 let pid_bytes = self.tx_counter.to_be_bytes();
317 let nonce = self.build_nonce(&pid_bytes);
318
319 let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
321 aad.extend_from_slice(ad_prefix);
322 aad.extend_from_slice(&pid_bytes);
323
324 let ct_tag = self.cipher.encrypt(&nonce, plaintext, &aad)?;
325
326 let total = PACKET_ID_SIZE + ct_tag.len();
327 output.extend_from_slice(&pid_bytes);
328 output.extend_from_slice(&ct_tag);
329
330 Ok(total)
331 }
332
333 #[inline]
345 pub fn decrypt(&mut self, packet: &[u8], ad_prefix: &[u8]) -> Result<Vec<u8>> {
346 const MIN_PACKET_SIZE: usize = PACKET_ID_SIZE + CipherSuite::TAG_SIZE;
347
348 if packet.len() < MIN_PACKET_SIZE {
349 return Err(CryptoError::DecryptionFailed);
350 }
351
352 let pid_bytes: [u8; 4] = packet[..4].try_into().unwrap();
354 let counter = u32::from_be_bytes(pid_bytes) as u64;
355
356 if !self.rx_window.check_and_update(counter) {
358 return Err(CryptoError::ReplayDetected);
359 }
360
361 let nonce = self.build_nonce(&pid_bytes);
362
363 let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
365 aad.extend_from_slice(ad_prefix);
366 aad.extend_from_slice(&pid_bytes);
367
368 let ct_tag = &packet[PACKET_ID_SIZE..];
371 if let Ok(plaintext) = self.cipher.decrypt(&nonce, ct_tag, &aad) {
372 return Ok(plaintext);
373 }
374
375 let tag = &packet[PACKET_ID_SIZE..PACKET_ID_SIZE + CipherSuite::TAG_SIZE];
378 let ct = &packet[PACKET_ID_SIZE + CipherSuite::TAG_SIZE..];
379 let mut ct_tag_reordered = Vec::with_capacity(ct.len() + CipherSuite::TAG_SIZE);
380 ct_tag_reordered.extend_from_slice(ct);
381 ct_tag_reordered.extend_from_slice(tag);
382
383 self.cipher.decrypt(&nonce, &ct_tag_reordered, &aad)
384 }
385
386 #[inline(always)]
388 pub fn tx_counter(&self) -> u64 {
389 self.tx_counter as u64
390 }
391}
392
/// Sliding-window anti-replay filter over packet ids.
///
/// Remembers the highest id accepted so far, plus a bitmap recording which of
/// the ids up to `WINDOW_SIZE - 1` behind it have also been accepted.
struct ReplayWindow {
    /// Largest packet id accepted so far; 0 until the first accept.
    highest: u64,
    /// Bit `k` is set once id `highest - k` has been accepted (bit 0 is `highest`).
    bitmap: u128,
}

impl ReplayWindow {
    /// Number of ids tracked: one per bit of `bitmap`.
    const WINDOW_SIZE: u64 = u128::BITS as u64;

    /// Creates a window that has accepted nothing yet.
    #[inline]
    fn new() -> Self {
        ReplayWindow {
            bitmap: 0,
            highest: 0,
        }
    }

    /// Returns `true` and records `packet_id` if it is fresh.
    ///
    /// Returns `false`, leaving the window untouched, if the id is 0, was
    /// already accepted, or lies `WINDOW_SIZE` or more ids behind `highest`.
    #[inline]
    fn check_and_update(&mut self, packet_id: u64) -> bool {
        // Id 0 is never accepted.
        if packet_id == 0 {
            return false;
        }

        match self.highest.checked_sub(packet_id) {
            // Newer than anything seen so far: slide the window forward.
            None => {
                let advance = packet_id - self.highest;
                self.bitmap = if advance < Self::WINDOW_SIZE {
                    (self.bitmap << advance) | 1
                } else {
                    // Jumped past the whole window; only the new id is remembered.
                    1
                };
                self.highest = packet_id;
                true
            }
            // Too old to be tracked any more.
            Some(age) if age >= Self::WINDOW_SIZE => false,
            // Inside the window: fresh only if its bit is still clear.
            Some(age) => {
                let bit = 1u128 << age;
                if self.bitmap & bit == 0 {
                    self.bitmap |= bit;
                    true
                } else {
                    false
                }
            }
        }
    }

    /// Forgets all history, returning to the state produced by `new`.
    // Kept for callers that rekey or reset; not currently called within this module.
    #[allow(dead_code)]
    #[inline]
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported suite; tests iterate over this so that both AEADs are covered.
    const ALL_SUITES: [CipherSuite; 2] = [CipherSuite::ChaCha20Poly1305, CipherSuite::Aes256Gcm];

    /// Builds an encryptor/decryptor pair sharing the same key and implicit IV.
    fn packet_pair(suite: CipherSuite) -> (PacketCipher, PacketCipher) {
        let iv = [0xABu8; 12];
        let tx = PacketCipher::new(DataChannelKey::new_with_iv([0x42u8; 32], iv, suite));
        let rx = PacketCipher::new(DataChannelKey::new_with_iv([0x42u8; 32], iv, suite));
        (tx, rx)
    }

    #[test]
    fn test_encrypt_decrypt() {
        let key = [0x42u8; 32];

        for suite in ALL_SUITES {
            let cipher = Cipher::new(&key, suite);
            assert_eq!(cipher.suite(), suite);
            let nonce = cipher.generate_nonce();
            let plaintext = b"Hello, CoreVPN!";
            let aad = b"associated data";

            let ciphertext = cipher.encrypt(&nonce, plaintext, aad).unwrap();
            assert_eq!(ciphertext.len(), plaintext.len() + CipherSuite::TAG_SIZE);
            let decrypted = cipher.decrypt(&nonce, &ciphertext, aad).unwrap();

            assert_eq!(plaintext.as_slice(), decrypted.as_slice());
        }
    }

    #[test]
    fn test_encrypt_into_appends_to_buffer() {
        let key = [0x42u8; 32];

        for suite in ALL_SUITES {
            let cipher = Cipher::new(&key, suite);
            let nonce = cipher.generate_nonce();
            let mut out = vec![0xEEu8; 3];

            let written = cipher
                .encrypt_into(&nonce, b"payload", b"aad", &mut out)
                .unwrap();

            assert_eq!(written, b"payload".len() + CipherSuite::TAG_SIZE);
            assert_eq!(out.len(), 3 + written);
            assert_eq!(&out[..3], &[0xEEu8; 3][..]);
            assert_eq!(
                cipher.decrypt(&nonce, &out[3..], b"aad").unwrap(),
                b"payload".to_vec()
            );
        }
    }

    #[test]
    fn test_authentication_failure() {
        let key = [0x42u8; 32];

        for suite in ALL_SUITES {
            let cipher = Cipher::new(&key, suite);
            let nonce = cipher.generate_nonce();
            let ciphertext = cipher.encrypt(&nonce, b"test", b"aad").unwrap();

            // Flipped ciphertext byte.
            let mut tampered = ciphertext.clone();
            tampered[0] ^= 0xFF;
            assert!(cipher.decrypt(&nonce, &tampered, b"aad").is_err());

            // Flipped tag byte.
            let mut tampered_tag = ciphertext.clone();
            *tampered_tag.last_mut().unwrap() ^= 0x01;
            assert!(cipher.decrypt(&nonce, &tampered_tag, b"aad").is_err());

            // Wrong associated data.
            assert!(cipher.decrypt(&nonce, &ciphertext, b"other aad").is_err());
        }
    }

    #[test]
    fn test_packet_cipher_replay_protection() {
        let ad = &[0x48u8, 0x00, 0x00, 0x01];

        for suite in ALL_SUITES {
            let (mut encryptor, mut decryptor) = packet_pair(suite);

            let p1 = encryptor.encrypt(b"packet 1", ad).unwrap();
            let p2 = encryptor.encrypt(b"packet 2", ad).unwrap();
            let p3 = encryptor.encrypt(b"packet 3", ad).unwrap();

            assert!(decryptor.decrypt(&p1, ad).is_ok());
            assert!(decryptor.decrypt(&p2, ad).is_ok());

            // Replay of an already-accepted packet.
            assert!(decryptor.decrypt(&p1, ad).is_err());

            assert!(decryptor.decrypt(&p3, ad).is_ok());

            assert!(decryptor.decrypt(&p3, ad).is_err());
        }
    }

    #[test]
    fn test_packet_cipher_out_of_order_delivery() {
        let ad = b"hdr!";

        for suite in ALL_SUITES {
            let (mut tx, mut rx) = packet_pair(suite);
            let packets: Vec<Vec<u8>> =
                (0u8..3).map(|i| tx.encrypt(&[i], ad).unwrap()).collect();

            for &i in &[2usize, 0, 1] {
                assert_eq!(rx.decrypt(&packets[i], ad).unwrap(), vec![i as u8]);
            }
            assert!(rx.decrypt(&packets[1], ad).is_err());
        }
    }

    #[test]
    fn test_packet_cipher_wire_format() {
        let ad = b"hdr!";

        for suite in ALL_SUITES {
            let (mut tx, mut rx) = packet_pair(suite);
            assert_eq!(tx.tx_counter(), 0);

            let packet = tx.encrypt(b"hi", ad).unwrap();
            assert_eq!(tx.tx_counter(), 1);
            assert_eq!(&packet[..PACKET_ID_SIZE], &1u32.to_be_bytes()[..]);
            assert_eq!(packet.len(), PACKET_ID_SIZE + 2 + CipherSuite::TAG_SIZE);

            let mut buf = vec![0xAAu8];
            let written = tx.encrypt_into(b"hi", ad, &mut buf).unwrap();
            assert_eq!(written, packet.len());
            assert_eq!(buf[0], 0xAA);
            assert_eq!(&buf[1..1 + PACKET_ID_SIZE], &2u32.to_be_bytes()[..]);
            assert_eq!(rx.decrypt(&buf[1..], ad).unwrap(), b"hi".to_vec());
        }
    }

    #[test]
    fn test_packet_cipher_accepts_tag_first_layout() {
        let ad = b"hdr!";

        for suite in ALL_SUITES {
            let (mut tx, mut rx) = packet_pair(suite);
            let packet = tx.encrypt(b"tag-first payload", ad).unwrap();

            let (pid, body) = packet.split_at(PACKET_ID_SIZE);
            let (ct, tag) = body.split_at(body.len() - CipherSuite::TAG_SIZE);
            let mut tag_first = pid.to_vec();
            tag_first.extend_from_slice(tag);
            tag_first.extend_from_slice(ct);

            assert_eq!(
                rx.decrypt(&tag_first, ad).unwrap(),
                b"tag-first payload".to_vec()
            );
        }
    }

    #[test]
    fn test_packet_cipher_rejects_short_and_zero_id_packets() {
        let ad = b"hdr!";

        for suite in ALL_SUITES {
            let (_, mut rx) = packet_pair(suite);

            let short = [0u8; PACKET_ID_SIZE + CipherSuite::TAG_SIZE - 1];
            assert!(rx.decrypt(&short, ad).is_err());

            let zero_id = [0u8; PACKET_ID_SIZE + CipherSuite::TAG_SIZE];
            assert!(matches!(
                rx.decrypt(&zero_id, ad),
                Err(CryptoError::ReplayDetected)
            ));
        }
    }

    #[test]
    fn test_replay_window() {
        let mut window = ReplayWindow::new();

        assert!(window.check_and_update(1));
        assert!(window.check_and_update(2));
        assert!(!window.check_and_update(1));
        assert!(window.check_and_update(100));
        assert!(!window.check_and_update(1));
        assert!(window.check_and_update(99));
        assert!(!window.check_and_update(99));
    }

    #[test]
    fn test_replay_window_boundaries() {
        let mut window = ReplayWindow::new();

        assert!(!window.check_and_update(0));
        assert!(window.check_and_update(200));
        // 127 behind the highest: still tracked.
        assert!(window.check_and_update(73));
        // 128 behind the highest: outside the window.
        assert!(!window.check_and_update(72));
        // Jump of exactly the window size clears history.
        assert!(window.check_and_update(328));
        assert!(window.check_and_update(201));
    }
}