Skip to main content

corevpn_crypto/
cipher.rs

1//! Symmetric cipher implementations for data channel encryption
2//!
3//! Supports ChaCha20-Poly1305 (preferred) and AES-256-GCM (fallback).
4//! Both provide authenticated encryption with associated data (AEAD).
5//!
6//! # Performance Optimizations
7//! - Cipher instances are cached in PacketCipher
8//! - Counter-based nonces avoid RNG syscalls
9//! - Pre-allocated output buffers reduce allocations
10//! - Inlined hot paths for better performance
11
12use aes_gcm::{Aes256Gcm, KeyInit};
13use chacha20poly1305::{ChaCha20Poly1305, aead::AeadCore};
14use zeroize::ZeroizeOnDrop;
15use serde::{Serialize, Deserialize};
16
17use crate::{CryptoError, Result};
18
19/// Supported cipher suites
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
21pub enum CipherSuite {
22    /// ChaCha20-Poly1305 - preferred for software implementations
23    #[default]
24    ChaCha20Poly1305,
25    /// AES-256-GCM - hardware accelerated on modern CPUs
26    Aes256Gcm,
27}
28
29impl CipherSuite {
30    /// Key size in bytes (256 bits for both suites)
31    pub const KEY_SIZE: usize = 32;
32    /// Nonce size in bytes (96 bits for both suites)
33    pub const NONCE_SIZE: usize = 12;
34    /// Authentication tag size in bytes (128 bits for both suites)
35    pub const TAG_SIZE: usize = 16;
36
37    /// Get the key size for this cipher suite
38    #[inline(always)]
39    pub const fn key_size(&self) -> usize {
40        Self::KEY_SIZE
41    }
42
43    /// Get the nonce size for this cipher suite
44    #[inline(always)]
45    pub const fn nonce_size(&self) -> usize {
46        Self::NONCE_SIZE
47    }
48
49    /// Get the tag size for this cipher suite
50    #[inline(always)]
51    pub const fn tag_size(&self) -> usize {
52        Self::TAG_SIZE
53    }
54}
55
56/// Data channel encryption key with secure memory handling
57pub struct DataChannelKey {
58    key: [u8; 32],
59    /// Implicit IV for AEAD nonce construction (12 bytes from hmac key; bytes 4..12 used as nonce tail)
60    implicit_iv: [u8; 12],
61    cipher_suite: CipherSuite,
62}
63
64impl DataChannelKey {
65    /// Create a new data channel key (with zero implicit IV for non-AEAD or tests)
66    pub fn new(key: [u8; 32], cipher_suite: CipherSuite) -> Self {
67        Self { key, implicit_iv: [0u8; 12], cipher_suite }
68    }
69
70    /// Create a new data channel key with implicit IV (for OpenVPN AEAD)
71    pub fn new_with_iv(key: [u8; 32], implicit_iv: [u8; 12], cipher_suite: CipherSuite) -> Self {
72        Self { key, implicit_iv, cipher_suite }
73    }
74
75    /// Get the cipher suite
76    pub fn cipher_suite(&self) -> CipherSuite {
77        self.cipher_suite
78    }
79
80    /// Get the raw key bytes (for debug logging)
81    pub fn key(&self) -> &[u8; 32] {
82        &self.key
83    }
84
85    /// Get the implicit IV
86    pub fn implicit_iv(&self) -> &[u8; 12] {
87        &self.implicit_iv
88    }
89
90    /// Create a cipher instance
91    pub fn cipher(&self) -> Cipher {
92        Cipher::new(&self.key, self.cipher_suite)
93    }
94}
95
96impl Drop for DataChannelKey {
97    fn drop(&mut self) {
98        use zeroize::Zeroize;
99        self.key.zeroize();
100        self.implicit_iv.zeroize();
101    }
102}
103
104impl ZeroizeOnDrop for DataChannelKey {}
105
106/// AEAD cipher for encrypting/decrypting data channel packets
107pub struct Cipher {
108    inner: CipherInner,
109    suite: CipherSuite,
110}
111
112enum CipherInner {
113    ChaCha(ChaCha20Poly1305),
114    Aes(Box<Aes256Gcm>),
115}
116
117impl Cipher {
118    /// Create a new cipher instance
119    #[inline]
120    pub fn new(key: &[u8; 32], suite: CipherSuite) -> Self {
121        let inner = match suite {
122            CipherSuite::ChaCha20Poly1305 => {
123                CipherInner::ChaCha(ChaCha20Poly1305::new(key.into()))
124            }
125            CipherSuite::Aes256Gcm => {
126                CipherInner::Aes(Box::new(Aes256Gcm::new(key.into())))
127            }
128        };
129        Self { inner, suite }
130    }
131
132    /// Encrypt plaintext with associated data
133    ///
134    /// Returns ciphertext with authentication tag appended.
135    #[inline]
136    pub fn encrypt(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
137        use chacha20poly1305::aead::Aead;
138        use aes_gcm::aead::Payload;
139
140        let payload = Payload { msg: plaintext, aad };
141
142        match &self.inner {
143            CipherInner::ChaCha(cipher) => {
144                cipher.encrypt(nonce.into(), payload)
145                    .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))
146            }
147            CipherInner::Aes(cipher) => {
148                cipher.encrypt(nonce.into(), payload)
149                    .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))
150            }
151        }
152    }
153
154    /// Encrypt plaintext into pre-allocated buffer
155    ///
156    /// Returns the number of bytes written.
157    /// Buffer must have capacity for plaintext + TAG_SIZE bytes.
158    #[inline]
159    pub fn encrypt_into(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8], out: &mut Vec<u8>) -> Result<usize> {
160        use chacha20poly1305::aead::Aead;
161        use aes_gcm::aead::Payload;
162
163        let payload = Payload { msg: plaintext, aad };
164        let start_len = out.len();
165
166        let ciphertext = match &self.inner {
167            CipherInner::ChaCha(cipher) => {
168                cipher.encrypt(nonce.into(), payload)
169                    .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))?
170            }
171            CipherInner::Aes(cipher) => {
172                cipher.encrypt(nonce.into(), payload)
173                    .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))?
174            }
175        };
176
177        out.extend_from_slice(&ciphertext);
178        Ok(out.len() - start_len)
179    }
180
181    /// Decrypt ciphertext with associated data
182    ///
183    /// Verifies authentication tag and returns plaintext.
184    #[inline]
185    pub fn decrypt(&self, nonce: &[u8; 12], ciphertext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
186        use chacha20poly1305::aead::Aead;
187        use aes_gcm::aead::Payload;
188
189        let payload = Payload { msg: ciphertext, aad };
190
191        match &self.inner {
192            CipherInner::ChaCha(cipher) => {
193                cipher.decrypt(nonce.into(), payload)
194                    .map_err(|_| CryptoError::DecryptionFailed)
195            }
196            CipherInner::Aes(cipher) => {
197                cipher.decrypt(nonce.into(), payload)
198                    .map_err(|_| CryptoError::DecryptionFailed)
199            }
200        }
201    }
202
203    /// Generate a random nonce using OsRng
204    ///
205    /// Note: For high-throughput scenarios, consider using counter-based nonces
206    /// via PacketCipher which avoids syscall overhead.
207    #[inline]
208    pub fn generate_nonce(&self) -> [u8; 12] {
209        match &self.inner {
210            CipherInner::ChaCha(_) => {
211                ChaCha20Poly1305::generate_nonce(&mut rand::rngs::OsRng).into()
212            }
213            CipherInner::Aes(_) => {
214                Aes256Gcm::generate_nonce(&mut rand::rngs::OsRng).into()
215            }
216        }
217    }
218
219    /// Get the cipher suite
220    #[inline(always)]
221    pub fn suite(&self) -> CipherSuite {
222        self.suite
223    }
224}
225
226/// Packet encryptor with automatic nonce management and replay protection.
227///
228/// Implements the OpenVPN AEAD data channel format:
229/// - 4-byte packet ID (big-endian counter)
230/// - Nonce = [packet_id_be(4)] || [implicit_iv[4..12]]
231/// - On-wire: [packet_id(4)] [AEAD_tag(16)] [ciphertext]
232/// - AAD = packet_id bytes (4 bytes)
233///
234/// # Performance
235/// - Uses counter-based nonces (no RNG syscalls)
236/// - Caches cipher instance for reuse
237/// - Pre-allocates output buffers with known capacity
238pub struct PacketCipher {
239    cipher: Cipher,
240    /// Implicit IV from key derivation (first 8 bytes used as nonce tail)
241    implicit_iv: [u8; 12],
242    /// Outgoing packet counter (32-bit, matching OpenVPN packet_id_type)
243    tx_counter: u32,
244    /// Replay protection window
245    rx_window: ReplayWindow,
246    /// Debug: raw key bytes (first 8 only, for logging)
247    debug_key_prefix: [u8; 8],
248}
249
250/// Packet ID header size (4-byte counter, matching OpenVPN)
251const PACKET_ID_SIZE: usize = 4;
252
253impl PacketCipher {
254    /// Create a new packet cipher
255    #[inline]
256    pub fn new(key: DataChannelKey) -> Self {
257        let implicit_iv = *key.implicit_iv();
258        let mut debug_key_prefix = [0u8; 8];
259        debug_key_prefix.copy_from_slice(&key.key()[..8]);
260        Self {
261            cipher: key.cipher(),
262            implicit_iv,
263            tx_counter: 0,
264            rx_window: ReplayWindow::new(),
265            debug_key_prefix,
266        }
267    }
268
269    /// Build a 12-byte AEAD nonce from implicit IV and packet ID.
270    ///
271    /// OpenVPN AEAD nonce construction (see openvpn_encrypt_aead in crypto.c):
272    ///
273    /// 1. Start with iv[0..12] = [packet_id_be(4), 0, 0, 0, 0, 0, 0, 0, 0]
274    /// 2. XOR the entire 12-byte iv with implicit_iv[0..12]:
275    ///      for i in 0..12: iv[i] ^= implicit_iv[i]
276    ///
277    /// Result: nonce[0..4] = packet_id XOR implicit_iv[0..4]
278    ///         nonce[4..12] = implicit_iv[4..12]
279    #[inline(always)]
280    fn build_nonce(&self, pid_bytes: &[u8; 4]) -> [u8; 12] {
281        let mut nonce = self.implicit_iv;
282        nonce[0] ^= pid_bytes[0];
283        nonce[1] ^= pid_bytes[1];
284        nonce[2] ^= pid_bytes[2];
285        nonce[3] ^= pid_bytes[3];
286        nonce
287    }
288
289    /// Encrypt a packet (OpenVPN AEAD tag-before-ciphertext format).
290    ///
291    /// `ad_prefix` is the header bytes (opcode + peer_id for V2) that precede
292    /// the packet ID in the on-wire format. OpenVPN authenticates these as part
293    /// of the AEAD AAD: AAD = [ad_prefix] [packet_id(4)].
294    ///
295    /// Returns: [packet_id(4)] [AEAD_tag(16)] [ciphertext]
296    ///
297    /// OpenVPN non-epoch AEAD format places the tag before the ciphertext:
298    ///   [packet_id(4)] [tag(16)] [ciphertext]
299    /// The AEAD library produces ciphertext||tag, so we reorder to tag||ciphertext.
300    #[inline]
301    pub fn encrypt(&mut self, plaintext: &[u8], ad_prefix: &[u8]) -> Result<Vec<u8>> {
302        self.tx_counter = self.tx_counter.checked_add(1)
303            .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
304
305        let pid_bytes = self.tx_counter.to_be_bytes();
306        let nonce = self.build_nonce(&pid_bytes);
307
308        // Build full AAD: [ad_prefix] [pid(4)]
309        let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
310        aad.extend_from_slice(ad_prefix);
311        aad.extend_from_slice(&pid_bytes);
312
313        // AEAD encrypt: produces ciphertext || tag (standard AEAD output)
314        let ct_tag = self.cipher.encrypt(&nonce, plaintext, &aad)?;
315
316        // Reorder to tag-before-ciphertext format for OpenVPN compatibility:
317        // AEAD output: [ciphertext(N)] [tag(16)]
318        // Wire format: [pid(4)] [tag(16)] [ciphertext(N)]
319        let ct_len = ct_tag.len() - CipherSuite::TAG_SIZE;
320        let ciphertext = &ct_tag[..ct_len];
321        let tag = &ct_tag[ct_len..];
322
323        let mut output = Vec::with_capacity(PACKET_ID_SIZE + ct_tag.len());
324        output.extend_from_slice(&pid_bytes);
325        output.extend_from_slice(tag);
326        output.extend_from_slice(ciphertext);
327
328        Ok(output)
329    }
330
331    /// Encrypt a packet into a pre-allocated buffer
332    ///
333    /// Returns the total bytes written.
334    /// Buffer should be cleared before calling.
335    #[inline]
336    pub fn encrypt_into(&mut self, plaintext: &[u8], ad_prefix: &[u8], output: &mut Vec<u8>) -> Result<usize> {
337        self.tx_counter = self.tx_counter.checked_add(1)
338            .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
339
340        let pid_bytes = self.tx_counter.to_be_bytes();
341        let nonce = self.build_nonce(&pid_bytes);
342
343        // Build full AAD: [ad_prefix] [pid(4)]
344        let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
345        aad.extend_from_slice(ad_prefix);
346        aad.extend_from_slice(&pid_bytes);
347
348        let ct_tag = self.cipher.encrypt(&nonce, plaintext, &aad)?;
349
350        // Reorder to tag-before-ciphertext format
351        let ct_len = ct_tag.len() - CipherSuite::TAG_SIZE;
352        let ciphertext = &ct_tag[..ct_len];
353        let tag = &ct_tag[ct_len..];
354
355        let total = PACKET_ID_SIZE + ct_tag.len();
356        output.extend_from_slice(&pid_bytes);
357        output.extend_from_slice(tag);
358        output.extend_from_slice(ciphertext);
359
360        Ok(total)
361    }
362
363    /// Decrypt a packet with replay protection (OpenVPN AEAD format).
364    ///
365    /// `ad_prefix` is the header bytes (opcode + peer_id for V2) that precede
366    /// the packet ID in the on-wire format. OpenVPN authenticates these as part
367    /// of the AEAD AAD: AAD = [ad_prefix] [packet_id(4)].
368    ///
369    /// Supports both tag-at-end and tag-before-ciphertext formats:
370    /// - Tag-at-end (OpenVPN 2.6+): [pid(4)] [ciphertext] [tag(16)]
371    /// - Tag-before (legacy):       [pid(4)] [tag(16)] [ciphertext]
372    ///
373    /// Tries tag-at-end first, falls back to tag-before if decryption fails.
374    #[inline]
375    pub fn decrypt(&mut self, packet: &[u8], ad_prefix: &[u8]) -> Result<Vec<u8>> {
376        const MIN_PACKET_SIZE: usize = PACKET_ID_SIZE + CipherSuite::TAG_SIZE;
377
378        if packet.len() < MIN_PACKET_SIZE {
379            return Err(CryptoError::DecryptionFailed);
380        }
381
382        // Extract 4-byte packet ID
383        let pid_bytes: [u8; 4] = packet[..4].try_into().unwrap();
384        let counter = u32::from_be_bytes(pid_bytes) as u64;
385
386        // Check replay
387        if !self.rx_window.check_and_update(counter) {
388            return Err(CryptoError::ReplayDetected);
389        }
390
391        let nonce = self.build_nonce(&pid_bytes);
392
393        // Build full AAD: [ad_prefix] [pid(4)]
394        let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
395        aad.extend_from_slice(ad_prefix);
396        aad.extend_from_slice(&pid_bytes);
397
398        // Diagnostic logging for first 3 packets
399        if counter <= 3 {
400            eprintln!("[DECRYPT] packet_id={} key_prefix={:02x?} iv={:02x?}",
401                counter, &self.debug_key_prefix, &self.implicit_iv);
402            eprintln!("[DECRYPT]   nonce={:02x?} aad={:02x?}", &nonce, &aad);
403            eprintln!("[DECRYPT]   packet[..20]={:02x?} total_len={}",
404                &packet[..std::cmp::min(20, packet.len())], packet.len());
405        }
406
407        // For tag-before (OpenVPN non-epoch): [pid(4)] [tag(16)] [ciphertext]
408        let tag = &packet[PACKET_ID_SIZE..PACKET_ID_SIZE + CipherSuite::TAG_SIZE];
409        let ct = &packet[PACKET_ID_SIZE + CipherSuite::TAG_SIZE..];
410        let mut ct_tag_reordered = Vec::with_capacity(ct.len() + CipherSuite::TAG_SIZE);
411        ct_tag_reordered.extend_from_slice(ct);
412        ct_tag_reordered.extend_from_slice(tag);
413
414        // For tag-at-end (epoch format): [pid(4)] [ciphertext] [tag(16)]
415        let ct_tag_end = &packet[PACKET_ID_SIZE..];
416
417        // Try tag-at-end first (OpenVPN 2.6+ default), then tag-before (legacy)
418        if let Ok(plaintext) = self.cipher.decrypt(&nonce, ct_tag_end, &aad) {
419            if counter <= 3 {
420                eprintln!("[DECRYPT]   SUCCESS (tag-at-end) plaintext_len={}", plaintext.len());
421            }
422            return Ok(plaintext);
423        }
424
425        match self.cipher.decrypt(&nonce, &ct_tag_reordered, &aad) {
426            Ok(plaintext) => {
427                if counter <= 3 {
428                    eprintln!("[DECRYPT]   SUCCESS (tag-before) plaintext_len={}", plaintext.len());
429                }
430                Ok(plaintext)
431            }
432            Err(e) => {
433                if counter <= 3 {
434                    eprintln!("[DECRYPT]   FAILED both formats");
435                }
436                Err(e)
437            }
438        }
439    }
440
441    /// Get current TX counter (for debugging/stats)
442    #[inline(always)]
443    pub fn tx_counter(&self) -> u64 {
444        self.tx_counter as u64
445    }
446}
447
/// Sliding-window replay filter for incoming packet IDs.
///
/// Remembers the highest packet ID accepted so far plus a 128-bit bitmap of
/// the IDs just below it, so every check is O(1) with no allocation.
struct ReplayWindow {
    /// Highest packet ID accepted so far (0 = nothing accepted yet).
    highest: u64,
    /// Seen-flags relative to `highest`: bit 0 is `highest` itself and
    /// bit N is `highest - N`.
    bitmap: u128,
}

impl ReplayWindow {
    /// Number of packet IDs tracked (one per bit of `bitmap`, i.e. 128).
    const WINDOW_SIZE: u64 = u128::BITS as u64;

    #[inline]
    fn new() -> Self {
        Self { highest: 0, bitmap: 0 }
    }

    /// Decide whether `packet_id` is fresh and, if so, record it.
    ///
    /// Returns `true` when the packet should be processed, and `false` when it
    /// is a duplicate, older than the window, or the reserved ID 0.
    #[inline]
    fn check_and_update(&mut self, packet_id: u64) -> bool {
        // Transmit counters start at 1, so ID 0 is never legitimate.
        if packet_id == 0 {
            return false;
        }

        match packet_id.checked_sub(self.highest) {
            // Strictly newer than anything seen: slide the window forward.
            Some(advance) if advance > 0 => {
                // A jump of a full window or more leaves no old flags worth
                // keeping; otherwise `advance < 128`, so the shift is in range.
                self.bitmap = if advance >= Self::WINDOW_SIZE {
                    1
                } else {
                    (self.bitmap << advance) | 1
                };
                self.highest = packet_id;
                true
            }
            // Equal to or older than `highest`: consult the bitmap.
            _ => {
                let age = self.highest - packet_id;
                if age >= Self::WINDOW_SIZE {
                    return false;
                }
                let flag = 1u128 << age;
                let unseen = (self.bitmap & flag) == 0;
                self.bitmap |= flag;
                unseen
            }
        }
    }

    /// Forget all history (e.g., for key renegotiation).
    // Not called anywhere in this module yet; kept for renegotiation support.
    #[allow(dead_code)]
    #[inline]
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// V2 header used as AEAD associated-data prefix: opcode + peer_id.
    const V2_AD: [u8; 4] = [0x48, 0x00, 0x00, 0x01];

    #[test]
    fn test_encrypt_decrypt() {
        let key = [0x42u8; 32];

        for suite in [CipherSuite::ChaCha20Poly1305, CipherSuite::Aes256Gcm] {
            let cipher = Cipher::new(&key, suite);
            let nonce = cipher.generate_nonce();
            let plaintext = b"Hello, CoreVPN!";
            let aad = b"associated data";

            let ciphertext = cipher.encrypt(&nonce, plaintext, aad).unwrap();
            let decrypted = cipher.decrypt(&nonce, &ciphertext, aad).unwrap();

            assert_eq!(plaintext.as_slice(), decrypted.as_slice());
        }
    }

    #[test]
    fn test_authentication_failure() {
        let key = [0x42u8; 32];
        let cipher = Cipher::new(&key, CipherSuite::ChaCha20Poly1305);
        let nonce = cipher.generate_nonce();

        let ciphertext = cipher.encrypt(&nonce, b"test", b"aad").unwrap();

        // Tamper with ciphertext
        let mut tampered = ciphertext.clone();
        tampered[0] ^= 0xFF;

        assert!(cipher.decrypt(&nonce, &tampered, b"aad").is_err());
    }

    #[test]
    fn test_packet_cipher_replay_protection() {
        let iv = [0xABu8; 12];
        let key = DataChannelKey::new_with_iv([0x42u8; 32], iv, CipherSuite::ChaCha20Poly1305);
        let mut encryptor = PacketCipher::new(key);

        let key2 = DataChannelKey::new_with_iv([0x42u8; 32], iv, CipherSuite::ChaCha20Poly1305);
        let mut decryptor = PacketCipher::new(key2);

        let ad = &V2_AD;

        // Encrypt some packets
        let p1 = encryptor.encrypt(b"packet 1", ad).unwrap();
        let p2 = encryptor.encrypt(b"packet 2", ad).unwrap();
        let p3 = encryptor.encrypt(b"packet 3", ad).unwrap();

        // Decrypt in order - should work
        assert!(decryptor.decrypt(&p1, ad).is_ok());
        assert!(decryptor.decrypt(&p2, ad).is_ok());

        // Replay p1 - should fail
        assert!(decryptor.decrypt(&p1, ad).is_err());

        // p3 arrives after the replay attempt - should work
        assert!(decryptor.decrypt(&p3, ad).is_ok());

        // Replay p3 - should fail
        assert!(decryptor.decrypt(&p3, ad).is_err());
    }

    #[test]
    fn test_packet_cipher_encrypt_into_appends_wire_format() {
        for suite in [CipherSuite::ChaCha20Poly1305, CipherSuite::Aes256Gcm] {
            let iv = [0x11u8; 12];
            let mut tx = PacketCipher::new(DataChannelKey::new_with_iv([0x07u8; 32], iv, suite));
            let mut rx = PacketCipher::new(DataChannelKey::new_with_iv([0x07u8; 32], iv, suite));
            let plaintext = b"payload";

            // Bytes already in the buffer must be left untouched.
            let mut buf = vec![0xEE, 0xEE];
            let written = tx.encrypt_into(plaintext, &V2_AD, &mut buf).unwrap();

            assert_eq!(&buf[..2], &[0xEE, 0xEE]);
            assert_eq!(written, buf.len() - 2);
            assert_eq!(written, PACKET_ID_SIZE + CipherSuite::TAG_SIZE + plaintext.len());
            assert_eq!(&buf[2..2 + PACKET_ID_SIZE], &1u32.to_be_bytes());
            assert_eq!(rx.decrypt(&buf[2..], &V2_AD).unwrap(), plaintext);
            assert_eq!(tx.tx_counter(), 1);
        }
    }

    #[test]
    fn test_packet_cipher_accepts_tag_at_end() {
        let key = [0x42u8; 32];
        let iv = [0x5Au8; 12];
        let mut rx = PacketCipher::new(DataChannelKey::new_with_iv(
            key,
            iv,
            CipherSuite::ChaCha20Poly1305,
        ));

        // Build [pid(4)] [ciphertext] [tag(16)] by hand with the raw AEAD.
        let pid = 1u32.to_be_bytes();
        let mut nonce = iv;
        for (n, p) in nonce.iter_mut().zip(pid) {
            *n ^= p;
        }
        let mut aad = V2_AD.to_vec();
        aad.extend_from_slice(&pid);
        let ct_tag = Cipher::new(&key, CipherSuite::ChaCha20Poly1305)
            .encrypt(&nonce, b"tag at end", &aad)
            .unwrap();

        let mut packet = pid.to_vec();
        packet.extend_from_slice(&ct_tag);

        assert_eq!(rx.decrypt(&packet, &V2_AD).unwrap(), b"tag at end");
    }

    #[test]
    fn test_packet_cipher_rejects_malformed_packets() {
        let iv = [0x33u8; 12];
        let mut tx = PacketCipher::new(DataChannelKey::new_with_iv([0x24u8; 32], iv, CipherSuite::Aes256Gcm));
        let mut rx = PacketCipher::new(DataChannelKey::new_with_iv([0x24u8; 32], iv, CipherSuite::Aes256Gcm));

        // Shorter than packet ID + tag.
        let short = [0u8; PACKET_ID_SIZE + CipherSuite::TAG_SIZE - 1];
        assert!(rx.decrypt(&short, &V2_AD).is_err());

        // A flipped tag bit must not authenticate.
        let mut tampered = tx.encrypt(b"data", &V2_AD).unwrap();
        tampered[PACKET_ID_SIZE] ^= 0x01;
        assert!(rx.decrypt(&tampered, &V2_AD).is_err());

        // A different header (ad_prefix) must not authenticate.
        let other = tx.encrypt(b"data", &V2_AD).unwrap();
        assert!(rx.decrypt(&other, &[0x48, 0x00, 0x00, 0x02]).is_err());
    }

    #[test]
    fn test_replay_window() {
        let mut window = ReplayWindow::new();

        assert!(window.check_and_update(1));
        assert!(window.check_and_update(2));
        assert!(!window.check_and_update(1)); // Replay
        assert!(window.check_and_update(100));
        assert!(!window.check_and_update(1)); // Replay (ID 1 is still inside the 128-ID window)
        assert!(window.check_and_update(99)); // In window
        assert!(!window.check_and_update(99)); // Replay

        // Jumping exactly one window ahead drops all earlier history.
        assert!(window.check_and_update(228));
        assert!(!window.check_and_update(100)); // Too old (age 128)
        assert!(window.check_and_update(101)); // Oldest tracked slot (age 127)
    }
}