1use aes_gcm::{Aes256Gcm, KeyInit};
13use chacha20poly1305::{ChaCha20Poly1305, aead::AeadCore};
14use zeroize::ZeroizeOnDrop;
15use serde::{Serialize, Deserialize};
16
17use crate::{CryptoError, Result};
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
21pub enum CipherSuite {
22 #[default]
24 ChaCha20Poly1305,
25 Aes256Gcm,
27}
28
29impl CipherSuite {
30 pub const KEY_SIZE: usize = 32;
32 pub const NONCE_SIZE: usize = 12;
34 pub const TAG_SIZE: usize = 16;
36
37 #[inline(always)]
39 pub const fn key_size(&self) -> usize {
40 Self::KEY_SIZE
41 }
42
43 #[inline(always)]
45 pub const fn nonce_size(&self) -> usize {
46 Self::NONCE_SIZE
47 }
48
49 #[inline(always)]
51 pub const fn tag_size(&self) -> usize {
52 Self::TAG_SIZE
53 }
54}
55
56pub struct DataChannelKey {
58 key: [u8; 32],
59 implicit_iv: [u8; 12],
61 cipher_suite: CipherSuite,
62}
63
64impl DataChannelKey {
65 pub fn new(key: [u8; 32], cipher_suite: CipherSuite) -> Self {
67 Self { key, implicit_iv: [0u8; 12], cipher_suite }
68 }
69
70 pub fn new_with_iv(key: [u8; 32], implicit_iv: [u8; 12], cipher_suite: CipherSuite) -> Self {
72 Self { key, implicit_iv, cipher_suite }
73 }
74
75 pub fn cipher_suite(&self) -> CipherSuite {
77 self.cipher_suite
78 }
79
80 pub fn key(&self) -> &[u8; 32] {
82 &self.key
83 }
84
85 pub fn implicit_iv(&self) -> &[u8; 12] {
87 &self.implicit_iv
88 }
89
90 pub fn cipher(&self) -> Cipher {
92 Cipher::new(&self.key, self.cipher_suite)
93 }
94}
95
96impl Drop for DataChannelKey {
97 fn drop(&mut self) {
98 use zeroize::Zeroize;
99 self.key.zeroize();
100 self.implicit_iv.zeroize();
101 }
102}
103
104impl ZeroizeOnDrop for DataChannelKey {}
105
106pub struct Cipher {
108 inner: CipherInner,
109 suite: CipherSuite,
110}
111
112enum CipherInner {
113 ChaCha(ChaCha20Poly1305),
114 Aes(Box<Aes256Gcm>),
115}
116
117impl Cipher {
118 #[inline]
120 pub fn new(key: &[u8; 32], suite: CipherSuite) -> Self {
121 let inner = match suite {
122 CipherSuite::ChaCha20Poly1305 => {
123 CipherInner::ChaCha(ChaCha20Poly1305::new(key.into()))
124 }
125 CipherSuite::Aes256Gcm => {
126 CipherInner::Aes(Box::new(Aes256Gcm::new(key.into())))
127 }
128 };
129 Self { inner, suite }
130 }
131
132 #[inline]
136 pub fn encrypt(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
137 use chacha20poly1305::aead::Aead;
138 use aes_gcm::aead::Payload;
139
140 let payload = Payload { msg: plaintext, aad };
141
142 match &self.inner {
143 CipherInner::ChaCha(cipher) => {
144 cipher.encrypt(nonce.into(), payload)
145 .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))
146 }
147 CipherInner::Aes(cipher) => {
148 cipher.encrypt(nonce.into(), payload)
149 .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))
150 }
151 }
152 }
153
154 #[inline]
159 pub fn encrypt_into(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8], out: &mut Vec<u8>) -> Result<usize> {
160 use chacha20poly1305::aead::Aead;
161 use aes_gcm::aead::Payload;
162
163 let payload = Payload { msg: plaintext, aad };
164 let start_len = out.len();
165
166 let ciphertext = match &self.inner {
167 CipherInner::ChaCha(cipher) => {
168 cipher.encrypt(nonce.into(), payload)
169 .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))?
170 }
171 CipherInner::Aes(cipher) => {
172 cipher.encrypt(nonce.into(), payload)
173 .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))?
174 }
175 };
176
177 out.extend_from_slice(&ciphertext);
178 Ok(out.len() - start_len)
179 }
180
181 #[inline]
185 pub fn decrypt(&self, nonce: &[u8; 12], ciphertext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
186 use chacha20poly1305::aead::Aead;
187 use aes_gcm::aead::Payload;
188
189 let payload = Payload { msg: ciphertext, aad };
190
191 match &self.inner {
192 CipherInner::ChaCha(cipher) => {
193 cipher.decrypt(nonce.into(), payload)
194 .map_err(|_| CryptoError::DecryptionFailed)
195 }
196 CipherInner::Aes(cipher) => {
197 cipher.decrypt(nonce.into(), payload)
198 .map_err(|_| CryptoError::DecryptionFailed)
199 }
200 }
201 }
202
203 #[inline]
208 pub fn generate_nonce(&self) -> [u8; 12] {
209 match &self.inner {
210 CipherInner::ChaCha(_) => {
211 ChaCha20Poly1305::generate_nonce(&mut rand::rngs::OsRng).into()
212 }
213 CipherInner::Aes(_) => {
214 Aes256Gcm::generate_nonce(&mut rand::rngs::OsRng).into()
215 }
216 }
217 }
218
219 #[inline(always)]
221 pub fn suite(&self) -> CipherSuite {
222 self.suite
223 }
224}
225
226pub struct PacketCipher {
239 cipher: Cipher,
240 implicit_iv: [u8; 12],
242 tx_counter: u32,
244 rx_window: ReplayWindow,
246 debug_key_prefix: [u8; 8],
248}
249
250const PACKET_ID_SIZE: usize = 4;
252
253impl PacketCipher {
254 #[inline]
256 pub fn new(key: DataChannelKey) -> Self {
257 let implicit_iv = *key.implicit_iv();
258 let mut debug_key_prefix = [0u8; 8];
259 debug_key_prefix.copy_from_slice(&key.key()[..8]);
260 Self {
261 cipher: key.cipher(),
262 implicit_iv,
263 tx_counter: 0,
264 rx_window: ReplayWindow::new(),
265 debug_key_prefix,
266 }
267 }
268
269 #[inline(always)]
280 fn build_nonce(&self, pid_bytes: &[u8; 4]) -> [u8; 12] {
281 let mut nonce = self.implicit_iv;
282 nonce[0] ^= pid_bytes[0];
283 nonce[1] ^= pid_bytes[1];
284 nonce[2] ^= pid_bytes[2];
285 nonce[3] ^= pid_bytes[3];
286 nonce
287 }
288
289 #[inline]
301 pub fn encrypt(&mut self, plaintext: &[u8], ad_prefix: &[u8]) -> Result<Vec<u8>> {
302 self.tx_counter = self.tx_counter.checked_add(1)
303 .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
304
305 let pid_bytes = self.tx_counter.to_be_bytes();
306 let nonce = self.build_nonce(&pid_bytes);
307
308 let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
310 aad.extend_from_slice(ad_prefix);
311 aad.extend_from_slice(&pid_bytes);
312
313 let ct_tag = self.cipher.encrypt(&nonce, plaintext, &aad)?;
315
316 let ct_len = ct_tag.len() - CipherSuite::TAG_SIZE;
320 let ciphertext = &ct_tag[..ct_len];
321 let tag = &ct_tag[ct_len..];
322
323 let mut output = Vec::with_capacity(PACKET_ID_SIZE + ct_tag.len());
324 output.extend_from_slice(&pid_bytes);
325 output.extend_from_slice(tag);
326 output.extend_from_slice(ciphertext);
327
328 Ok(output)
329 }
330
331 #[inline]
336 pub fn encrypt_into(&mut self, plaintext: &[u8], ad_prefix: &[u8], output: &mut Vec<u8>) -> Result<usize> {
337 self.tx_counter = self.tx_counter.checked_add(1)
338 .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
339
340 let pid_bytes = self.tx_counter.to_be_bytes();
341 let nonce = self.build_nonce(&pid_bytes);
342
343 let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
345 aad.extend_from_slice(ad_prefix);
346 aad.extend_from_slice(&pid_bytes);
347
348 let ct_tag = self.cipher.encrypt(&nonce, plaintext, &aad)?;
349
350 let ct_len = ct_tag.len() - CipherSuite::TAG_SIZE;
352 let ciphertext = &ct_tag[..ct_len];
353 let tag = &ct_tag[ct_len..];
354
355 let total = PACKET_ID_SIZE + ct_tag.len();
356 output.extend_from_slice(&pid_bytes);
357 output.extend_from_slice(tag);
358 output.extend_from_slice(ciphertext);
359
360 Ok(total)
361 }
362
363 #[inline]
375 pub fn decrypt(&mut self, packet: &[u8], ad_prefix: &[u8]) -> Result<Vec<u8>> {
376 const MIN_PACKET_SIZE: usize = PACKET_ID_SIZE + CipherSuite::TAG_SIZE;
377
378 if packet.len() < MIN_PACKET_SIZE {
379 return Err(CryptoError::DecryptionFailed);
380 }
381
382 let pid_bytes: [u8; 4] = packet[..4].try_into().unwrap();
384 let counter = u32::from_be_bytes(pid_bytes) as u64;
385
386 if !self.rx_window.check_and_update(counter) {
388 return Err(CryptoError::ReplayDetected);
389 }
390
391 let nonce = self.build_nonce(&pid_bytes);
392
393 let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
395 aad.extend_from_slice(ad_prefix);
396 aad.extend_from_slice(&pid_bytes);
397
398 if counter <= 3 {
400 eprintln!("[DECRYPT] packet_id={} key_prefix={:02x?} iv={:02x?}",
401 counter, &self.debug_key_prefix, &self.implicit_iv);
402 eprintln!("[DECRYPT] nonce={:02x?} aad={:02x?}", &nonce, &aad);
403 eprintln!("[DECRYPT] packet[..20]={:02x?} total_len={}",
404 &packet[..std::cmp::min(20, packet.len())], packet.len());
405 }
406
407 let tag = &packet[PACKET_ID_SIZE..PACKET_ID_SIZE + CipherSuite::TAG_SIZE];
409 let ct = &packet[PACKET_ID_SIZE + CipherSuite::TAG_SIZE..];
410 let mut ct_tag_reordered = Vec::with_capacity(ct.len() + CipherSuite::TAG_SIZE);
411 ct_tag_reordered.extend_from_slice(ct);
412 ct_tag_reordered.extend_from_slice(tag);
413
414 let ct_tag_end = &packet[PACKET_ID_SIZE..];
416
417 if let Ok(plaintext) = self.cipher.decrypt(&nonce, ct_tag_end, &aad) {
419 if counter <= 3 {
420 eprintln!("[DECRYPT] SUCCESS (tag-at-end) plaintext_len={}", plaintext.len());
421 }
422 return Ok(plaintext);
423 }
424
425 match self.cipher.decrypt(&nonce, &ct_tag_reordered, &aad) {
426 Ok(plaintext) => {
427 if counter <= 3 {
428 eprintln!("[DECRYPT] SUCCESS (tag-before) plaintext_len={}", plaintext.len());
429 }
430 Ok(plaintext)
431 }
432 Err(e) => {
433 if counter <= 3 {
434 eprintln!("[DECRYPT] FAILED both formats");
435 }
436 Err(e)
437 }
438 }
439 }
440
441 #[inline(always)]
443 pub fn tx_counter(&self) -> u64 {
444 self.tx_counter as u64
445 }
446}
447
/// Sliding-window replay filter over received packet IDs.
struct ReplayWindow {
    /// Highest packet ID accepted so far (0 before any packet is accepted).
    highest: u64,
    /// Bit `i` is set when packet ID `highest - i` has been accepted.
    bitmap: u128,
}
459
460impl ReplayWindow {
461 const WINDOW_SIZE: u64 = 128;
463
464 #[inline]
465 fn new() -> Self {
466 Self {
467 highest: 0,
468 bitmap: 0,
469 }
470 }
471
472 #[inline]
477 fn check_and_update(&mut self, packet_id: u64) -> bool {
478 if packet_id == 0 {
480 return false;
481 }
482
483 if packet_id > self.highest {
484 let shift = packet_id - self.highest;
486
487 if shift >= Self::WINDOW_SIZE {
488 self.bitmap = 1; } else {
491 self.bitmap = (self.bitmap << shift) | 1;
494 }
495 self.highest = packet_id;
496 true
497 } else {
498 let diff = self.highest - packet_id;
500
501 if diff >= Self::WINDOW_SIZE {
503 return false; }
505
506 let mask = 1u128 << diff;
508 if self.bitmap & mask != 0 {
509 return false; }
511
512 self.bitmap |= mask;
514 true
515 }
516 }
517
518 #[allow(dead_code)]
520 #[inline]
521 pub fn reset(&mut self) {
522 self.highest = 0;
523 self.bitmap = 0;
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed key shared by every test.
    const KEY: [u8; 32] = [0x42; 32];

    #[test]
    fn test_encrypt_decrypt() {
        let plaintext: &[u8] = b"Hello, CoreVPN!";
        let aad: &[u8] = b"associated data";

        for suite in [CipherSuite::ChaCha20Poly1305, CipherSuite::Aes256Gcm] {
            let cipher = Cipher::new(&KEY, suite);
            let nonce = cipher.generate_nonce();

            let sealed = cipher.encrypt(&nonce, plaintext, aad).unwrap();
            let opened = cipher.decrypt(&nonce, &sealed, aad).unwrap();

            assert_eq!(opened.as_slice(), plaintext);
        }
    }

    #[test]
    fn test_authentication_failure() {
        let cipher = Cipher::new(&KEY, CipherSuite::ChaCha20Poly1305);
        let nonce = cipher.generate_nonce();

        let mut sealed = cipher.encrypt(&nonce, b"test", b"aad").unwrap();
        // Flip every bit of the first ciphertext byte.
        sealed[0] ^= 0xFF;

        assert!(cipher.decrypt(&nonce, &sealed, b"aad").is_err());
    }

    #[test]
    fn test_packet_cipher_replay_protection() {
        let iv = [0xABu8; 12];
        let make = || {
            PacketCipher::new(DataChannelKey::new_with_iv(KEY, iv, CipherSuite::ChaCha20Poly1305))
        };
        let mut encryptor = make();
        let mut decryptor = make();

        let ad: &[u8] = &[0x48, 0x00, 0x00, 0x01];
        let p1 = encryptor.encrypt(b"packet 1", ad).unwrap();
        let p2 = encryptor.encrypt(b"packet 2", ad).unwrap();
        let p3 = encryptor.encrypt(b"packet 3", ad).unwrap();

        assert!(decryptor.decrypt(&p1, ad).is_ok());
        assert!(decryptor.decrypt(&p2, ad).is_ok());
        assert!(decryptor.decrypt(&p1, ad).is_err(), "replayed packet 1 must be rejected");
        assert!(decryptor.decrypt(&p3, ad).is_ok());
        assert!(decryptor.decrypt(&p3, ad).is_err(), "replayed packet 3 must be rejected");
    }

    #[test]
    fn test_replay_window() {
        let mut window = ReplayWindow::new();

        // (packet id, expected acceptance), applied in order.
        let steps = [
            (1, true),
            (2, true),
            (1, false),   // duplicate
            (100, true),  // window jumps forward
            (1, false),   // still inside the window, already seen
            (99, true),   // inside the window, not yet seen
            (99, false),  // duplicate
        ];
        for (id, accepted) in steps {
            assert_eq!(window.check_and_update(id), accepted, "packet id {id}");
        }
    }
}