// corevpn_crypto/cipher.rs

//! Symmetric cipher implementations for data channel encryption.
//!
//! Supports ChaCha20-Poly1305 (preferred) and AES-256-GCM (fallback).
//! Both provide authenticated encryption with associated data (AEAD).
//!
//! # Performance Optimizations
//! - Cipher instances are cached in `PacketCipher`
//! - Counter-based nonces avoid RNG syscalls
//! - Pre-allocated output buffers reduce allocations
//! - Inlined hot paths for better performance

12use aes_gcm::{Aes256Gcm, KeyInit};
13use chacha20poly1305::{ChaCha20Poly1305, aead::AeadCore};
14use zeroize::ZeroizeOnDrop;
15use serde::{Serialize, Deserialize};
16
17use crate::{CryptoError, Result};
18
19/// Supported cipher suites
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
21pub enum CipherSuite {
22    /// ChaCha20-Poly1305 - preferred for software implementations
23    #[default]
24    ChaCha20Poly1305,
25    /// AES-256-GCM - hardware accelerated on modern CPUs
26    Aes256Gcm,
27}
28
29impl CipherSuite {
30    /// Key size in bytes (256 bits for both suites)
31    pub const KEY_SIZE: usize = 32;
32    /// Nonce size in bytes (96 bits for both suites)
33    pub const NONCE_SIZE: usize = 12;
34    /// Authentication tag size in bytes (128 bits for both suites)
35    pub const TAG_SIZE: usize = 16;
36
37    /// Get the key size for this cipher suite
38    #[inline(always)]
39    pub const fn key_size(&self) -> usize {
40        Self::KEY_SIZE
41    }
42
43    /// Get the nonce size for this cipher suite
44    #[inline(always)]
45    pub const fn nonce_size(&self) -> usize {
46        Self::NONCE_SIZE
47    }
48
49    /// Get the tag size for this cipher suite
50    #[inline(always)]
51    pub const fn tag_size(&self) -> usize {
52        Self::TAG_SIZE
53    }
54}
55
56/// Data channel encryption key with secure memory handling
57pub struct DataChannelKey {
58    key: [u8; 32],
59    /// Implicit IV for AEAD nonce construction (XORed with packet counter)
60    implicit_iv: [u8; 12],
61    cipher_suite: CipherSuite,
62}
63
64impl DataChannelKey {
65    /// Create a new data channel key (with zero implicit IV for non-AEAD or tests)
66    pub fn new(key: [u8; 32], cipher_suite: CipherSuite) -> Self {
67        Self { key, implicit_iv: [0u8; 12], cipher_suite }
68    }
69
70    /// Create a new data channel key with implicit IV (for OpenVPN AEAD)
71    pub fn new_with_iv(key: [u8; 32], implicit_iv: [u8; 12], cipher_suite: CipherSuite) -> Self {
72        Self { key, implicit_iv, cipher_suite }
73    }
74
75    /// Get the cipher suite
76    pub fn cipher_suite(&self) -> CipherSuite {
77        self.cipher_suite
78    }
79
80    /// Get the implicit IV
81    pub fn implicit_iv(&self) -> &[u8; 12] {
82        &self.implicit_iv
83    }
84
85    /// Create a cipher instance
86    pub fn cipher(&self) -> Cipher {
87        Cipher::new(&self.key, self.cipher_suite)
88    }
89}
90
91impl Drop for DataChannelKey {
92    fn drop(&mut self) {
93        use zeroize::Zeroize;
94        self.key.zeroize();
95        self.implicit_iv.zeroize();
96    }
97}
98
99impl ZeroizeOnDrop for DataChannelKey {}
100
101/// AEAD cipher for encrypting/decrypting data channel packets
102pub struct Cipher {
103    inner: CipherInner,
104    suite: CipherSuite,
105}
106
107enum CipherInner {
108    ChaCha(ChaCha20Poly1305),
109    Aes(Box<Aes256Gcm>),
110}
111
112impl Cipher {
113    /// Create a new cipher instance
114    #[inline]
115    pub fn new(key: &[u8; 32], suite: CipherSuite) -> Self {
116        let inner = match suite {
117            CipherSuite::ChaCha20Poly1305 => {
118                CipherInner::ChaCha(ChaCha20Poly1305::new(key.into()))
119            }
120            CipherSuite::Aes256Gcm => {
121                CipherInner::Aes(Box::new(Aes256Gcm::new(key.into())))
122            }
123        };
124        Self { inner, suite }
125    }
126
127    /// Encrypt plaintext with associated data
128    ///
129    /// Returns ciphertext with authentication tag appended.
130    #[inline]
131    pub fn encrypt(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
132        use chacha20poly1305::aead::Aead;
133        use aes_gcm::aead::Payload;
134
135        let payload = Payload { msg: plaintext, aad };
136
137        match &self.inner {
138            CipherInner::ChaCha(cipher) => {
139                cipher.encrypt(nonce.into(), payload)
140                    .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))
141            }
142            CipherInner::Aes(cipher) => {
143                cipher.encrypt(nonce.into(), payload)
144                    .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))
145            }
146        }
147    }
148
149    /// Encrypt plaintext into pre-allocated buffer
150    ///
151    /// Returns the number of bytes written.
152    /// Buffer must have capacity for plaintext + TAG_SIZE bytes.
153    #[inline]
154    pub fn encrypt_into(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8], out: &mut Vec<u8>) -> Result<usize> {
155        use chacha20poly1305::aead::Aead;
156        use aes_gcm::aead::Payload;
157
158        let payload = Payload { msg: plaintext, aad };
159        let start_len = out.len();
160
161        let ciphertext = match &self.inner {
162            CipherInner::ChaCha(cipher) => {
163                cipher.encrypt(nonce.into(), payload)
164                    .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))?
165            }
166            CipherInner::Aes(cipher) => {
167                cipher.encrypt(nonce.into(), payload)
168                    .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))?
169            }
170        };
171
172        out.extend_from_slice(&ciphertext);
173        Ok(out.len() - start_len)
174    }
175
176    /// Decrypt ciphertext with associated data
177    ///
178    /// Verifies authentication tag and returns plaintext.
179    #[inline]
180    pub fn decrypt(&self, nonce: &[u8; 12], ciphertext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
181        use chacha20poly1305::aead::Aead;
182        use aes_gcm::aead::Payload;
183
184        let payload = Payload { msg: ciphertext, aad };
185
186        match &self.inner {
187            CipherInner::ChaCha(cipher) => {
188                cipher.decrypt(nonce.into(), payload)
189                    .map_err(|_| CryptoError::DecryptionFailed)
190            }
191            CipherInner::Aes(cipher) => {
192                cipher.decrypt(nonce.into(), payload)
193                    .map_err(|_| CryptoError::DecryptionFailed)
194            }
195        }
196    }
197
198    /// Generate a random nonce using OsRng
199    ///
200    /// Note: For high-throughput scenarios, consider using counter-based nonces
201    /// via PacketCipher which avoids syscall overhead.
202    #[inline]
203    pub fn generate_nonce(&self) -> [u8; 12] {
204        match &self.inner {
205            CipherInner::ChaCha(_) => {
206                ChaCha20Poly1305::generate_nonce(&mut rand::rngs::OsRng).into()
207            }
208            CipherInner::Aes(_) => {
209                Aes256Gcm::generate_nonce(&mut rand::rngs::OsRng).into()
210            }
211        }
212    }
213
214    /// Get the cipher suite
215    #[inline(always)]
216    pub fn suite(&self) -> CipherSuite {
217        self.suite
218    }
219}
220
221/// Packet encryptor with automatic nonce management and replay protection.
222///
223/// Implements the OpenVPN AEAD data channel format:
224/// - 4-byte packet ID (big-endian counter)
225/// - Nonce = implicit_iv XOR padded(packet_id)
226/// - On-wire: [packet_id(4)] [AEAD_tag(16)] [ciphertext]
227/// - AAD = packet_id bytes (4 bytes)
228///
229/// # Performance
230/// - Uses counter-based nonces (no RNG syscalls)
231/// - Caches cipher instance for reuse
232/// - Pre-allocates output buffers with known capacity
233pub struct PacketCipher {
234    cipher: Cipher,
235    /// Implicit IV from key derivation, XORed with packet counter to form nonce
236    implicit_iv: [u8; 12],
237    /// Outgoing packet counter (32-bit, matching OpenVPN packet_id_type)
238    tx_counter: u32,
239    /// Replay protection window
240    rx_window: ReplayWindow,
241}
242
243/// Packet ID header size (4-byte counter, matching OpenVPN)
244const PACKET_ID_SIZE: usize = 4;
245
246impl PacketCipher {
247    /// Create a new packet cipher
248    #[inline]
249    pub fn new(key: DataChannelKey) -> Self {
250        let implicit_iv = *key.implicit_iv();
251        Self {
252            cipher: key.cipher(),
253            implicit_iv,
254            tx_counter: 0,
255            rx_window: ReplayWindow::new(),
256        }
257    }
258
259    /// Build a 12-byte AEAD nonce from implicit IV and packet ID.
260    ///
261    /// nonce = implicit_iv XOR [packet_id_be(4) || 00000000(8)]
262    #[inline(always)]
263    fn build_nonce(&self, pid_bytes: &[u8; 4]) -> [u8; 12] {
264        let mut nonce = self.implicit_iv;
265        nonce[0] ^= pid_bytes[0];
266        nonce[1] ^= pid_bytes[1];
267        nonce[2] ^= pid_bytes[2];
268        nonce[3] ^= pid_bytes[3];
269        nonce
270    }
271
272    /// Encrypt a packet (OpenVPN AEAD tag-at-end format).
273    ///
274    /// `ad_prefix` is the header bytes (opcode + peer_id for V2) that precede
275    /// the packet ID in the on-wire format. OpenVPN authenticates these as part
276    /// of the AEAD AAD: AAD = [ad_prefix] [packet_id(4)].
277    ///
278    /// Returns: [packet_id(4)] [ciphertext] [AEAD_tag(16)]
279    ///
280    /// OpenVPN 2.6+ uses AEAD tag at the end by default for AEAD ciphers.
281    /// The AEAD library naturally produces ciphertext||tag, so this is the
282    /// standard format: [pid(4)] [ciphertext||tag].
283    #[inline]
284    pub fn encrypt(&mut self, plaintext: &[u8], ad_prefix: &[u8]) -> Result<Vec<u8>> {
285        self.tx_counter = self.tx_counter.checked_add(1)
286            .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
287
288        let pid_bytes = self.tx_counter.to_be_bytes();
289        let nonce = self.build_nonce(&pid_bytes);
290
291        // Build full AAD: [ad_prefix] [pid(4)]
292        let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
293        aad.extend_from_slice(ad_prefix);
294        aad.extend_from_slice(&pid_bytes);
295
296        // AEAD encrypt: produces ciphertext || tag (standard AEAD output)
297        let ct_tag = self.cipher.encrypt(&nonce, plaintext, &aad)?;
298
299        // Tag-at-end format: [pid(4)] [ciphertext||tag]
300        let mut output = Vec::with_capacity(PACKET_ID_SIZE + ct_tag.len());
301        output.extend_from_slice(&pid_bytes);
302        output.extend_from_slice(&ct_tag);
303
304        Ok(output)
305    }
306
307    /// Encrypt a packet into a pre-allocated buffer
308    ///
309    /// Returns the total bytes written.
310    /// Buffer should be cleared before calling.
311    #[inline]
312    pub fn encrypt_into(&mut self, plaintext: &[u8], ad_prefix: &[u8], output: &mut Vec<u8>) -> Result<usize> {
313        self.tx_counter = self.tx_counter.checked_add(1)
314            .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
315
316        let pid_bytes = self.tx_counter.to_be_bytes();
317        let nonce = self.build_nonce(&pid_bytes);
318
319        // Build full AAD: [ad_prefix] [pid(4)]
320        let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
321        aad.extend_from_slice(ad_prefix);
322        aad.extend_from_slice(&pid_bytes);
323
324        let ct_tag = self.cipher.encrypt(&nonce, plaintext, &aad)?;
325
326        let total = PACKET_ID_SIZE + ct_tag.len();
327        output.extend_from_slice(&pid_bytes);
328        output.extend_from_slice(&ct_tag);
329
330        Ok(total)
331    }
332
333    /// Decrypt a packet with replay protection (OpenVPN AEAD format).
334    ///
335    /// `ad_prefix` is the header bytes (opcode + peer_id for V2) that precede
336    /// the packet ID in the on-wire format. OpenVPN authenticates these as part
337    /// of the AEAD AAD: AAD = [ad_prefix] [packet_id(4)].
338    ///
339    /// Supports both tag-at-end and tag-before-ciphertext formats:
340    /// - Tag-at-end (OpenVPN 2.6+): [pid(4)] [ciphertext] [tag(16)]
341    /// - Tag-before (legacy):       [pid(4)] [tag(16)] [ciphertext]
342    ///
343    /// Tries tag-at-end first, falls back to tag-before if decryption fails.
344    #[inline]
345    pub fn decrypt(&mut self, packet: &[u8], ad_prefix: &[u8]) -> Result<Vec<u8>> {
346        const MIN_PACKET_SIZE: usize = PACKET_ID_SIZE + CipherSuite::TAG_SIZE;
347
348        if packet.len() < MIN_PACKET_SIZE {
349            return Err(CryptoError::DecryptionFailed);
350        }
351
352        // Extract 4-byte packet ID
353        let pid_bytes: [u8; 4] = packet[..4].try_into().unwrap();
354        let counter = u32::from_be_bytes(pid_bytes) as u64;
355
356        // Check replay
357        if !self.rx_window.check_and_update(counter) {
358            return Err(CryptoError::ReplayDetected);
359        }
360
361        let nonce = self.build_nonce(&pid_bytes);
362
363        // Build full AAD: [ad_prefix] [pid(4)]
364        let mut aad = Vec::with_capacity(ad_prefix.len() + PACKET_ID_SIZE);
365        aad.extend_from_slice(ad_prefix);
366        aad.extend_from_slice(&pid_bytes);
367
368        // Try tag-at-end first: [pid(4)] [ciphertext||tag]
369        // The AEAD library expects ciphertext||tag, which is exactly packet[4..]
370        let ct_tag = &packet[PACKET_ID_SIZE..];
371        if let Ok(plaintext) = self.cipher.decrypt(&nonce, ct_tag, &aad) {
372            return Ok(plaintext);
373        }
374
375        // Fall back to tag-before: [pid(4)] [tag(16)] [ciphertext]
376        // Reassemble as ciphertext||tag for the AEAD library
377        let tag = &packet[PACKET_ID_SIZE..PACKET_ID_SIZE + CipherSuite::TAG_SIZE];
378        let ct = &packet[PACKET_ID_SIZE + CipherSuite::TAG_SIZE..];
379        let mut ct_tag_reordered = Vec::with_capacity(ct.len() + CipherSuite::TAG_SIZE);
380        ct_tag_reordered.extend_from_slice(ct);
381        ct_tag_reordered.extend_from_slice(tag);
382
383        self.cipher.decrypt(&nonce, &ct_tag_reordered, &aad)
384    }
385
386    /// Get current TX counter (for debugging/stats)
387    #[inline(always)]
388    pub fn tx_counter(&self) -> u64 {
389        self.tx_counter as u64
390    }
391}
392
/// Sliding window for replay protection
///
/// Tracks the highest packet ID accepted so far plus a 128-bit bitmap
/// recording which of the 128 IDs ending at that highest one have been seen.
/// Both the check and the update are O(1).
struct ReplayWindow {
    /// Highest packet ID accepted so far (0 = nothing accepted yet)
    highest: u64,
    /// Bitmap of recently seen packets relative to `highest`:
    /// bit 0 = `highest`, bit N = `highest - N`
    bitmap: u128,
}

impl ReplayWindow {
    /// Window size in packets (one bitmap bit per tracked packet ID)
    const WINDOW_SIZE: u64 = 128;

    #[inline]
    fn new() -> Self {
        Self { highest: 0, bitmap: 0 }
    }

    /// Check whether `packet_id` is fresh and, if so, record it.
    ///
    /// Returns `true` if the packet should be processed; `false` for ID 0, for
    /// an ID already recorded, or for an ID 128 or more behind the highest seen.
    #[inline]
    fn check_and_update(&mut self, packet_id: u64) -> bool {
        // Packet ID 0 is invalid (counter starts at 1)
        if packet_id == 0 {
            return false;
        }

        if packet_id <= self.highest {
            // At or behind the current highest: consult its bit.
            let age = self.highest - packet_id;
            if age >= Self::WINDOW_SIZE {
                return false; // Fell off the back of the window
            }
            let bit = 1u128 << age;
            if self.bitmap & bit != 0 {
                return false; // Already seen
            }
            self.bitmap |= bit;
            return true;
        }

        // New highest: slide the window forward. In the shifting branch
        // `advance` is below 128, so the shift amount is always in range.
        let advance = packet_id - self.highest;
        self.bitmap = if advance < Self::WINDOW_SIZE {
            (self.bitmap << advance) | 1
        } else {
            1 // Jumped past the whole window; only the new packet is marked
        };
        self.highest = packet_id;
        true
    }

    /// Reset the replay window (e.g., for key renegotiation)
    // Not called within this file; allowed so the helper stays available.
    #[allow(dead_code)]
    #[inline]
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
471
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt() {
        let key = [0x42u8; 32];

        for suite in [CipherSuite::ChaCha20Poly1305, CipherSuite::Aes256Gcm] {
            let cipher = Cipher::new(&key, suite);
            let nonce = cipher.generate_nonce();
            let plaintext = b"Hello, CoreVPN!";
            let aad = b"associated data";

            let ciphertext = cipher.encrypt(&nonce, plaintext, aad).unwrap();
            let decrypted = cipher.decrypt(&nonce, &ciphertext, aad).unwrap();

            assert_eq!(plaintext.as_slice(), decrypted.as_slice());
        }
    }

    #[test]
    fn test_encrypt_into_matches_encrypt() {
        let key = [0x42u8; 32];
        let nonce = [0x07u8; 12];

        for suite in [CipherSuite::ChaCha20Poly1305, CipherSuite::Aes256Gcm] {
            let cipher = Cipher::new(&key, suite);
            let expected = cipher.encrypt(&nonce, b"payload", b"aad").unwrap();

            // Existing contents of the output buffer must be preserved.
            let mut out = vec![0xAAu8, 0xBB];
            let written = cipher.encrypt_into(&nonce, b"payload", b"aad", &mut out).unwrap();

            assert_eq!(written, b"payload".len() + CipherSuite::TAG_SIZE);
            assert_eq!(&out[..2], &[0xAAu8, 0xBB][..]);
            assert_eq!(&out[2..], expected.as_slice());
        }
    }

    #[test]
    fn test_authentication_failure() {
        let key = [0x42u8; 32];
        let cipher = Cipher::new(&key, CipherSuite::ChaCha20Poly1305);
        let nonce = cipher.generate_nonce();

        let ciphertext = cipher.encrypt(&nonce, b"test", b"aad").unwrap();

        // Tamper with ciphertext
        let mut tampered = ciphertext.clone();
        tampered[0] ^= 0xFF;

        assert!(cipher.decrypt(&nonce, &tampered, b"aad").is_err());
    }

    #[test]
    fn test_packet_cipher_replay_protection() {
        let iv = [0xABu8; 12];
        let key = DataChannelKey::new_with_iv([0x42u8; 32], iv, CipherSuite::ChaCha20Poly1305);
        let mut encryptor = PacketCipher::new(key);

        let key2 = DataChannelKey::new_with_iv([0x42u8; 32], iv, CipherSuite::ChaCha20Poly1305);
        let mut decryptor = PacketCipher::new(key2);

        let ad = &[0x48u8, 0x00, 0x00, 0x01]; // V2 header: opcode + peer_id

        // Encrypt some packets
        let p1 = encryptor.encrypt(b"packet 1", ad).unwrap();
        let p2 = encryptor.encrypt(b"packet 2", ad).unwrap();
        let p3 = encryptor.encrypt(b"packet 3", ad).unwrap();

        // Decrypt in order - should work
        assert!(decryptor.decrypt(&p1, ad).is_ok());
        assert!(decryptor.decrypt(&p2, ad).is_ok());

        // Replay p1 - should fail
        assert!(decryptor.decrypt(&p1, ad).is_err());

        // p3 out of order - should work
        assert!(decryptor.decrypt(&p3, ad).is_ok());

        // Replay p3 - should fail
        assert!(decryptor.decrypt(&p3, ad).is_err());
    }

    #[test]
    fn test_packet_cipher_encrypt_into_appends() {
        let iv = [0x5Au8; 12];
        let mut encryptor = PacketCipher::new(DataChannelKey::new_with_iv([0x24u8; 32], iv, CipherSuite::Aes256Gcm));
        let mut decryptor = PacketCipher::new(DataChannelKey::new_with_iv([0x24u8; 32], iv, CipherSuite::Aes256Gcm));
        let ad = &[0x48u8, 0x00, 0x00, 0x01];

        let mut buf = vec![0xEEu8];
        let written = encryptor.encrypt_into(b"abc", ad, &mut buf).unwrap();

        assert_eq!(written, PACKET_ID_SIZE + 3 + CipherSuite::TAG_SIZE);
        assert_eq!(buf.len(), 1 + written);
        assert_eq!(buf[0], 0xEE);
        // First packet ID is 1, big-endian.
        assert_eq!(&buf[1..1 + PACKET_ID_SIZE], &1u32.to_be_bytes()[..]);

        let plaintext = decryptor.decrypt(&buf[1..], ad).unwrap();
        assert_eq!(plaintext.as_slice(), b"abc".as_slice());
    }

    #[test]
    fn test_packet_cipher_accepts_tag_before_layout() {
        let iv = [0x11u8; 12];
        let mut encryptor = PacketCipher::new(DataChannelKey::new_with_iv([0x07u8; 32], iv, CipherSuite::Aes256Gcm));
        let mut decryptor = PacketCipher::new(DataChannelKey::new_with_iv([0x07u8; 32], iv, CipherSuite::Aes256Gcm));
        let ad = &[0x48u8, 0x00, 0x00, 0x01];

        // encrypt() emits [pid][ciphertext][tag]; rearrange into legacy [pid][tag][ciphertext].
        let wire = encryptor.encrypt(b"legacy layout", ad).unwrap();
        let tag_start = wire.len() - CipherSuite::TAG_SIZE;
        let mut legacy = Vec::with_capacity(wire.len());
        legacy.extend_from_slice(&wire[..PACKET_ID_SIZE]);
        legacy.extend_from_slice(&wire[tag_start..]);
        legacy.extend_from_slice(&wire[PACKET_ID_SIZE..tag_start]);

        let plaintext = decryptor.decrypt(&legacy, ad).unwrap();
        assert_eq!(plaintext.as_slice(), b"legacy layout".as_slice());
    }

    #[test]
    fn test_replay_window() {
        let mut window = ReplayWindow::new();

        assert!(window.check_and_update(1));
        assert!(window.check_and_update(2));
        assert!(!window.check_and_update(1)); // Replay
        assert!(window.check_and_update(100));
        assert!(!window.check_and_update(1)); // Replay (ID 1 is still inside the 128-packet window)
        assert!(window.check_and_update(99)); // In window
        assert!(!window.check_and_update(99)); // Replay
    }

    #[test]
    fn test_replay_window_boundaries() {
        let mut window = ReplayWindow::new();

        assert!(!window.check_and_update(0)); // ID 0 is never valid
        assert!(window.check_and_update(300)); // Jump past a whole window
        assert!(!window.check_and_update(172)); // Age 128: too old
        assert!(window.check_and_update(173)); // Age 127: oldest slot still tracked
        assert!(!window.check_and_update(173)); // Replay
        assert!(window.check_and_update(301));
        assert!(!window.check_and_update(300)); // Replay after the window slid by one
    }
}