corevpn_crypto/
cipher.rs

//! Symmetric cipher implementations for data channel encryption.
//!
//! Supports ChaCha20-Poly1305 (preferred) and AES-256-GCM (fallback).
//! Both provide authenticated encryption with associated data (AEAD).
//!
//! # Performance Optimizations
//! - Cipher instances are cached in `PacketCipher`
//! - Counter-based nonces avoid RNG syscalls
//! - Pre-allocated output buffers reduce allocations
//! - Inlined hot paths for better performance
11
12use aes_gcm::{Aes256Gcm, KeyInit};
13use chacha20poly1305::{ChaCha20Poly1305, aead::AeadCore};
14use zeroize::ZeroizeOnDrop;
15use serde::{Serialize, Deserialize};
16
17use crate::{CryptoError, Result};
18
19/// Supported cipher suites
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
21pub enum CipherSuite {
22    /// ChaCha20-Poly1305 - preferred for software implementations
23    #[default]
24    ChaCha20Poly1305,
25    /// AES-256-GCM - hardware accelerated on modern CPUs
26    Aes256Gcm,
27}
28
29impl CipherSuite {
30    /// Key size in bytes (256 bits for both suites)
31    pub const KEY_SIZE: usize = 32;
32    /// Nonce size in bytes (96 bits for both suites)
33    pub const NONCE_SIZE: usize = 12;
34    /// Authentication tag size in bytes (128 bits for both suites)
35    pub const TAG_SIZE: usize = 16;
36
37    /// Get the key size for this cipher suite
38    #[inline(always)]
39    pub const fn key_size(&self) -> usize {
40        Self::KEY_SIZE
41    }
42
43    /// Get the nonce size for this cipher suite
44    #[inline(always)]
45    pub const fn nonce_size(&self) -> usize {
46        Self::NONCE_SIZE
47    }
48
49    /// Get the tag size for this cipher suite
50    #[inline(always)]
51    pub const fn tag_size(&self) -> usize {
52        Self::TAG_SIZE
53    }
54}
55
56/// Data channel encryption key with secure memory handling
57pub struct DataChannelKey {
58    key: [u8; 32],
59    cipher_suite: CipherSuite,
60}
61
62impl DataChannelKey {
63    /// Create a new data channel key
64    pub fn new(key: [u8; 32], cipher_suite: CipherSuite) -> Self {
65        Self { key, cipher_suite }
66    }
67
68    /// Get the cipher suite
69    pub fn cipher_suite(&self) -> CipherSuite {
70        self.cipher_suite
71    }
72
73    /// Create a cipher instance
74    pub fn cipher(&self) -> Cipher {
75        Cipher::new(&self.key, self.cipher_suite)
76    }
77}
78
79impl Drop for DataChannelKey {
80    fn drop(&mut self) {
81        use zeroize::Zeroize;
82        self.key.zeroize();
83    }
84}
85
86impl ZeroizeOnDrop for DataChannelKey {}
87
88/// AEAD cipher for encrypting/decrypting data channel packets
89pub struct Cipher {
90    inner: CipherInner,
91    suite: CipherSuite,
92}
93
94enum CipherInner {
95    ChaCha(ChaCha20Poly1305),
96    Aes(Box<Aes256Gcm>),
97}
98
99impl Cipher {
100    /// Create a new cipher instance
101    #[inline]
102    pub fn new(key: &[u8; 32], suite: CipherSuite) -> Self {
103        let inner = match suite {
104            CipherSuite::ChaCha20Poly1305 => {
105                CipherInner::ChaCha(ChaCha20Poly1305::new(key.into()))
106            }
107            CipherSuite::Aes256Gcm => {
108                CipherInner::Aes(Box::new(Aes256Gcm::new(key.into())))
109            }
110        };
111        Self { inner, suite }
112    }
113
114    /// Encrypt plaintext with associated data
115    ///
116    /// Returns ciphertext with authentication tag appended.
117    #[inline]
118    pub fn encrypt(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
119        use chacha20poly1305::aead::Aead;
120        use aes_gcm::aead::Payload;
121
122        let payload = Payload { msg: plaintext, aad };
123
124        match &self.inner {
125            CipherInner::ChaCha(cipher) => {
126                cipher.encrypt(nonce.into(), payload)
127                    .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))
128            }
129            CipherInner::Aes(cipher) => {
130                cipher.encrypt(nonce.into(), payload)
131                    .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))
132            }
133        }
134    }
135
136    /// Encrypt plaintext into pre-allocated buffer
137    ///
138    /// Returns the number of bytes written.
139    /// Buffer must have capacity for plaintext + TAG_SIZE bytes.
140    #[inline]
141    pub fn encrypt_into(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8], out: &mut Vec<u8>) -> Result<usize> {
142        use chacha20poly1305::aead::Aead;
143        use aes_gcm::aead::Payload;
144
145        let payload = Payload { msg: plaintext, aad };
146        let start_len = out.len();
147
148        let ciphertext = match &self.inner {
149            CipherInner::ChaCha(cipher) => {
150                cipher.encrypt(nonce.into(), payload)
151                    .map_err(|_| CryptoError::EncryptionFailed("ChaCha20-Poly1305 encryption failed"))?
152            }
153            CipherInner::Aes(cipher) => {
154                cipher.encrypt(nonce.into(), payload)
155                    .map_err(|_| CryptoError::EncryptionFailed("AES-256-GCM encryption failed"))?
156            }
157        };
158
159        out.extend_from_slice(&ciphertext);
160        Ok(out.len() - start_len)
161    }
162
163    /// Decrypt ciphertext with associated data
164    ///
165    /// Verifies authentication tag and returns plaintext.
166    #[inline]
167    pub fn decrypt(&self, nonce: &[u8; 12], ciphertext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
168        use chacha20poly1305::aead::Aead;
169        use aes_gcm::aead::Payload;
170
171        let payload = Payload { msg: ciphertext, aad };
172
173        match &self.inner {
174            CipherInner::ChaCha(cipher) => {
175                cipher.decrypt(nonce.into(), payload)
176                    .map_err(|_| CryptoError::DecryptionFailed)
177            }
178            CipherInner::Aes(cipher) => {
179                cipher.decrypt(nonce.into(), payload)
180                    .map_err(|_| CryptoError::DecryptionFailed)
181            }
182        }
183    }
184
185    /// Generate a random nonce using OsRng
186    ///
187    /// Note: For high-throughput scenarios, consider using counter-based nonces
188    /// via PacketCipher which avoids syscall overhead.
189    #[inline]
190    pub fn generate_nonce(&self) -> [u8; 12] {
191        match &self.inner {
192            CipherInner::ChaCha(_) => {
193                ChaCha20Poly1305::generate_nonce(&mut rand::rngs::OsRng).into()
194            }
195            CipherInner::Aes(_) => {
196                Aes256Gcm::generate_nonce(&mut rand::rngs::OsRng).into()
197            }
198        }
199    }
200
201    /// Get the cipher suite
202    #[inline(always)]
203    pub fn suite(&self) -> CipherSuite {
204        self.suite
205    }
206}
207
208/// Packet encryptor with automatic nonce management and replay protection
209///
210/// # Performance
211/// - Uses counter-based nonces (no RNG syscalls)
212/// - Caches cipher instance for reuse
213/// - Pre-allocates output buffers with known capacity
214pub struct PacketCipher {
215    cipher: Cipher,
216    /// Outgoing packet counter (used as nonce)
217    tx_counter: u64,
218    /// Replay protection window
219    rx_window: ReplayWindow,
220}
221
222/// Packet header size (8-byte counter)
223const PACKET_HEADER_SIZE: usize = 8;
224
225impl PacketCipher {
226    /// Create a new packet cipher
227    #[inline]
228    pub fn new(key: DataChannelKey) -> Self {
229        Self {
230            cipher: key.cipher(),
231            tx_counter: 0,
232            rx_window: ReplayWindow::new(),
233        }
234    }
235
236    /// Encrypt a packet
237    ///
238    /// Returns: [8-byte packet_id | ciphertext | 16-byte tag]
239    #[inline]
240    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
241        // Increment counter (fail if overflow - extremely unlikely)
242        self.tx_counter = self.tx_counter.checked_add(1)
243            .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
244
245        // Build nonce from counter (padded to 12 bytes)
246        // Using a fixed-size array and copy is faster than iteration
247        let mut nonce = [0u8; 12];
248        let packet_id = self.tx_counter.to_be_bytes();
249        nonce[4..].copy_from_slice(&packet_id);
250
251        // Pre-allocate output with exact capacity
252        // Header (8) + plaintext + tag (16)
253        let output_len = PACKET_HEADER_SIZE + plaintext.len() + CipherSuite::TAG_SIZE;
254        let mut output = Vec::with_capacity(output_len);
255
256        // Write packet ID header
257        output.extend_from_slice(&packet_id);
258
259        // Encrypt directly into output buffer
260        self.cipher.encrypt_into(&nonce, plaintext, &packet_id, &mut output)?;
261
262        Ok(output)
263    }
264
265    /// Encrypt a packet into a pre-allocated buffer
266    ///
267    /// Returns the total bytes written (header + ciphertext + tag).
268    /// Buffer should be cleared before calling.
269    #[inline]
270    pub fn encrypt_into(&mut self, plaintext: &[u8], output: &mut Vec<u8>) -> Result<usize> {
271        self.tx_counter = self.tx_counter.checked_add(1)
272            .ok_or(CryptoError::EncryptionFailed("packet counter overflow"))?;
273
274        let mut nonce = [0u8; 12];
275        let packet_id = self.tx_counter.to_be_bytes();
276        nonce[4..].copy_from_slice(&packet_id);
277
278        output.extend_from_slice(&packet_id);
279        let cipher_bytes = self.cipher.encrypt_into(&nonce, plaintext, &packet_id, output)?;
280
281        Ok(PACKET_HEADER_SIZE + cipher_bytes)
282    }
283
284    /// Decrypt a packet with replay protection
285    #[inline]
286    pub fn decrypt(&mut self, packet: &[u8]) -> Result<Vec<u8>> {
287        const MIN_PACKET_SIZE: usize = PACKET_HEADER_SIZE + CipherSuite::TAG_SIZE;
288
289        if packet.len() < MIN_PACKET_SIZE {
290            return Err(CryptoError::DecryptionFailed);
291        }
292
293        // Extract packet ID using array pattern matching (faster than slice ops)
294        let packet_id: [u8; 8] = packet[..8].try_into().unwrap();
295        let counter = u64::from_be_bytes(packet_id);
296
297        // Check replay (inline for performance)
298        if !self.rx_window.check_and_update(counter) {
299            return Err(CryptoError::ReplayDetected);
300        }
301
302        // Build nonce from packet ID
303        let mut nonce = [0u8; 12];
304        nonce[4..].copy_from_slice(&packet_id);
305
306        // Decrypt
307        self.cipher.decrypt(&nonce, &packet[8..], &packet_id)
308    }
309
310    /// Get current TX counter (for debugging/stats)
311    #[inline(always)]
312    pub fn tx_counter(&self) -> u64 {
313        self.tx_counter
314    }
315}
316
/// Sliding-window replay filter for incoming packet IDs.
///
/// A 128-bit bitmap records which of the most recent 128 IDs (counting back
/// from the highest ID accepted so far) have already been seen, so every
/// check is O(1).
struct ReplayWindow {
    /// Largest packet ID accepted so far (0 before any packet arrives).
    highest: u64,
    /// Seen-flags relative to `highest`: bit `n` is set once packet ID
    /// `highest - n` has been accepted.
    bitmap: u128,
}

impl ReplayWindow {
    /// Number of packet IDs tracked, including `highest` itself.
    const WINDOW_SIZE: u64 = 128;

    #[inline]
    fn new() -> Self {
        Self { bitmap: 0, highest: 0 }
    }

    /// Decide whether `packet_id` is fresh, recording it if so.
    ///
    /// Returns `true` when the packet should be processed, and `false` when
    /// the ID is 0, was already accepted, or is too far behind `highest` to
    /// be tracked.
    #[inline]
    fn check_and_update(&mut self, packet_id: u64) -> bool {
        // Senders start counting at 1, so an ID of 0 is never legitimate.
        if packet_id == 0 {
            return false;
        }

        if packet_id <= self.highest {
            let age = self.highest - packet_id;
            if age >= Self::WINDOW_SIZE {
                // Fell off the back of the window: freshness can't be proven.
                return false;
            }
            let bit = 1u128 << age;
            let fresh = self.bitmap & bit == 0;
            self.bitmap |= bit;
            return fresh;
        }

        // New high-water mark: slide the window forward by `advance` slots.
        let advance = packet_id - self.highest;
        self.bitmap = if advance < Self::WINDOW_SIZE {
            (self.bitmap << advance) | 1
        } else {
            // Nothing from the old window survives a jump this large.
            1
        };
        self.highest = packet_id;
        true
    }

    /// Forget all history, returning the window to its freshly-constructed
    /// state (e.g. for key renegotiation).
    // No caller in this file yet; kept for rekeying support.
    #[allow(dead_code)]
    #[inline]
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported suite; tests loop over this so both get coverage.
    const SUITES: [CipherSuite; 2] = [CipherSuite::ChaCha20Poly1305, CipherSuite::Aes256Gcm];

    #[test]
    fn test_encrypt_decrypt() {
        let key = [0x42u8; 32];

        for suite in SUITES {
            let cipher = Cipher::new(&key, suite);
            let nonce = cipher.generate_nonce();
            let plaintext = b"Hello, CoreVPN!";
            let aad = b"associated data";

            let ciphertext = cipher.encrypt(&nonce, plaintext, aad).unwrap();
            assert_eq!(ciphertext.len(), plaintext.len() + CipherSuite::TAG_SIZE);

            let decrypted = cipher.decrypt(&nonce, &ciphertext, aad).unwrap();
            assert_eq!(plaintext.as_slice(), decrypted.as_slice());
        }
    }

    #[test]
    fn test_encrypt_into_matches_encrypt() {
        let key = [0x42u8; 32];

        for suite in SUITES {
            let cipher = Cipher::new(&key, suite);
            let nonce = [7u8; 12];
            let plaintext = b"Hello, CoreVPN!";
            let aad = b"associated data";

            let expected = cipher.encrypt(&nonce, plaintext, aad).unwrap();

            // Pre-existing bytes in the buffer must be preserved.
            let mut out = vec![0xAAu8; 3];
            let written = cipher.encrypt_into(&nonce, plaintext, aad, &mut out).unwrap();

            assert_eq!(written, plaintext.len() + CipherSuite::TAG_SIZE);
            assert_eq!(out[..3], [0xAAu8; 3]);
            assert_eq!(&out[3..], expected.as_slice());
        }
    }

    #[test]
    fn test_authentication_failure() {
        let key = [0x42u8; 32];

        for suite in SUITES {
            let cipher = Cipher::new(&key, suite);
            let nonce = cipher.generate_nonce();

            let ciphertext = cipher.encrypt(&nonce, b"test", b"aad").unwrap();

            // Tampered ciphertext body.
            let mut tampered = ciphertext.clone();
            tampered[0] ^= 0xFF;
            assert!(cipher.decrypt(&nonce, &tampered, b"aad").is_err());

            // Tampered tag.
            let mut tampered_tag = ciphertext.clone();
            let last = tampered_tag.len() - 1;
            tampered_tag[last] ^= 0x01;
            assert!(cipher.decrypt(&nonce, &tampered_tag, b"aad").is_err());

            // Wrong associated data.
            assert!(cipher.decrypt(&nonce, &ciphertext, b"AAD").is_err());
        }
    }

    #[test]
    fn test_packet_cipher_replay_protection() {
        for suite in SUITES {
            let mut encryptor = PacketCipher::new(DataChannelKey::new([0x42u8; 32], suite));
            let mut decryptor = PacketCipher::new(DataChannelKey::new([0x42u8; 32], suite));

            let p1 = encryptor.encrypt(b"packet 1").unwrap();
            let p2 = encryptor.encrypt(b"packet 2").unwrap();
            let p3 = encryptor.encrypt(b"packet 3").unwrap();

            assert!(decryptor.decrypt(&p1).is_ok());

            // p3 before p2: genuinely out of order, but within the window.
            assert_eq!(decryptor.decrypt(&p3).unwrap(), b"packet 3");
            assert_eq!(decryptor.decrypt(&p2).unwrap(), b"packet 2");

            // Replays must be rejected.
            assert!(decryptor.decrypt(&p1).is_err());
            assert!(decryptor.decrypt(&p2).is_err());
            assert!(decryptor.decrypt(&p3).is_err());
        }
    }

    #[test]
    fn test_packet_cipher_encrypt_into() {
        for suite in SUITES {
            let mut tx = PacketCipher::new(DataChannelKey::new([0x24u8; 32], suite));
            let mut rx = PacketCipher::new(DataChannelKey::new([0x24u8; 32], suite));

            let mut buf = Vec::new();
            let written = tx.encrypt_into(b"payload", &mut buf).unwrap();

            assert_eq!(written, buf.len());
            assert_eq!(written, PACKET_HEADER_SIZE + 7 + CipherSuite::TAG_SIZE);
            assert_eq!(buf[..PACKET_HEADER_SIZE], 1u64.to_be_bytes());
            assert_eq!(tx.tx_counter(), 1);
            assert_eq!(rx.decrypt(&buf).unwrap(), b"payload");
        }
    }

    #[test]
    fn test_packet_cipher_rejects_malformed() {
        for suite in SUITES {
            let mut encryptor = PacketCipher::new(DataChannelKey::new([0x42u8; 32], suite));
            let mut decryptor = PacketCipher::new(DataChannelKey::new([0x42u8; 32], suite));

            // Shorter than header + tag.
            assert!(decryptor.decrypt(&[0u8; 10]).is_err());

            // Tampered body must fail authentication.
            let mut packet = encryptor.encrypt(b"packet").unwrap();
            packet[PACKET_HEADER_SIZE] ^= 0x01;
            assert!(decryptor.decrypt(&packet).is_err());
        }
    }

    #[test]
    fn test_replay_window() {
        let mut window = ReplayWindow::new();

        assert!(!window.check_and_update(0)); // ID 0 is never valid
        assert!(window.check_and_update(1));
        assert!(window.check_and_update(2));
        assert!(!window.check_and_update(1)); // Replay
        assert!(window.check_and_update(100));
        assert!(!window.check_and_update(1)); // Too old
        assert!(window.check_and_update(99)); // In window
        assert!(!window.check_and_update(99)); // Replay
    }

    #[test]
    fn test_replay_window_edges() {
        let mut window = ReplayWindow::new();

        assert!(window.check_and_update(200));
        assert!(window.check_and_update(73)); // Oldest slot still tracked
        assert!(!window.check_and_update(72)); // Just outside the window
        assert!(window.check_and_update(328)); // Jump of exactly WINDOW_SIZE
        assert!(window.check_and_update(201)); // History was cleared by the jump
    }
}