Skip to main content

corevpn_protocol/
data.rs

1//! Data Channel Packet Handling
2
3use bytes::{Bytes, BytesMut, BufMut};
4use corevpn_crypto::{DataChannelKey, PacketCipher};
5
6use crate::{KeyId, OpCode, ProtocolError, Result};
7
8/// Data channel packet
9#[derive(Debug, Clone)]
10pub struct DataPacket {
11    /// Key ID
12    pub key_id: KeyId,
13    /// Peer ID (for P_DATA_V2)
14    pub peer_id: Option<u32>,
15    /// Payload (IP packet)
16    pub payload: Bytes,
17}
18
19impl DataPacket {
20    /// Create a new data packet
21    pub fn new(key_id: KeyId, payload: Bytes) -> Self {
22        Self {
23            key_id,
24            peer_id: None,
25            payload,
26        }
27    }
28
29    /// Create a new data packet with peer ID (V2)
30    pub fn new_v2(key_id: KeyId, peer_id: u32, payload: Bytes) -> Self {
31        Self {
32            key_id,
33            peer_id: Some(peer_id),
34            payload,
35        }
36    }
37
38    /// Parse from raw encrypted packet
39    pub fn parse(data: &[u8]) -> Result<Self> {
40        if data.is_empty() {
41            return Err(ProtocolError::PacketTooShort {
42                expected: 1,
43                got: 0,
44            });
45        }
46
47        let opcode = OpCode::from_byte(data[0])?;
48        let key_id = KeyId::from_byte(data[0]);
49
50        let (peer_id, payload_start) = if opcode == OpCode::DataV2 {
51            if data.len() < 4 {
52                return Err(ProtocolError::PacketTooShort {
53                    expected: 4,
54                    got: data.len(),
55                });
56            }
57            let pid = ((data[1] as u32) << 16) | ((data[2] as u32) << 8) | (data[3] as u32);
58            (Some(pid), 4)
59        } else {
60            (None, 1)
61        };
62
63        Ok(Self {
64            key_id,
65            peer_id,
66            payload: Bytes::copy_from_slice(&data[payload_start..]),
67        })
68    }
69
70    /// Serialize to bytes (header + encrypted payload)
71    pub fn serialize(&self) -> BytesMut {
72        let opcode = if self.peer_id.is_some() {
73            OpCode::DataV2
74        } else {
75            OpCode::DataV1
76        };
77
78        let mut buf = BytesMut::with_capacity(4 + self.payload.len());
79        buf.put_u8(opcode.to_byte(self.key_id));
80
81        if let Some(pid) = self.peer_id {
82            buf.put_u8((pid >> 16) as u8);
83            buf.put_u8((pid >> 8) as u8);
84            buf.put_u8(pid as u8);
85        }
86
87        buf.put_slice(&self.payload);
88        buf
89    }
90}
91
92/// Data channel encryption/decryption handler
93pub struct DataChannel {
94    /// Key ID
95    key_id: KeyId,
96    /// Peer ID (for V2 protocol)
97    peer_id: Option<u32>,
98    /// Encrypt cipher (outgoing)
99    encrypt_cipher: PacketCipher,
100    /// Decrypt cipher (incoming)
101    decrypt_cipher: PacketCipher,
102    /// Whether to use V2 protocol
103    use_v2: bool,
104    /// Cached AAD prefix for encryption (opcode + peer_id header bytes)
105    encrypt_ad_prefix: Vec<u8>,
106}
107
108impl DataChannel {
109    /// Create a new data channel
110    pub fn new(
111        key_id: KeyId,
112        encrypt_key: DataChannelKey,
113        decrypt_key: DataChannelKey,
114        use_v2: bool,
115        peer_id: Option<u32>,
116    ) -> Self {
117        // Build the AAD prefix for encryption (header bytes that precede the packet ID).
118        // OpenVPN's encrypt_sign() in forward.c:
119        //   - P_DATA_V2: header (opcode+peer_id) is prepended to work buffer BEFORE
120        //     openvpn_encrypt(), so AAD = [opcode(1)][peer_id(3)][packet_id(4)]
121        //   - P_DATA_V1: opcode is prepended AFTER encryption (tls_prepend_opcode_v1),
122        //     so AAD = [packet_id(4)] only (no opcode in AAD!)
123        let encrypt_ad_prefix = if use_v2 {
124            let opcode_byte = OpCode::DataV2.to_byte(key_id);
125            let pid = peer_id.unwrap_or(0);
126            vec![opcode_byte, (pid >> 16) as u8, (pid >> 8) as u8, pid as u8]
127        } else {
128            // V1: no header bytes in AAD, just the packet ID (added by PacketCipher)
129            vec![]
130        };
131
132        Self {
133            key_id,
134            peer_id,
135            encrypt_cipher: PacketCipher::new(encrypt_key),
136            decrypt_cipher: PacketCipher::new(decrypt_key),
137            use_v2,
138            encrypt_ad_prefix,
139        }
140    }
141
142    /// Get the key ID
143    pub fn key_id(&self) -> KeyId {
144        self.key_id
145    }
146
147    /// Build AAD prefix from a received packet's header bytes.
148    /// For V2: [opcode_byte(1)] [peer_id(3)]; for V1: empty (no header in AAD)
149    fn decrypt_ad_prefix(&self, packet: &DataPacket) -> Vec<u8> {
150        if let Some(pid) = packet.peer_id {
151            let opcode_byte = OpCode::DataV2.to_byte(packet.key_id);
152            vec![opcode_byte, (pid >> 16) as u8, (pid >> 8) as u8, pid as u8]
153        } else {
154            // V1: no header bytes in AAD (opcode is added after encryption by OpenVPN)
155            vec![]
156        }
157    }
158
159    /// Encrypt an IP packet for transmission
160    pub fn encrypt(&mut self, ip_packet: &[u8]) -> Result<DataPacket> {
161        let encrypted = self.encrypt_cipher.encrypt(ip_packet, &self.encrypt_ad_prefix)?;
162
163        Ok(DataPacket {
164            key_id: self.key_id,
165            peer_id: if self.use_v2 { self.peer_id } else { None },
166            payload: Bytes::from(encrypted),
167        })
168    }
169
170    /// Decrypt a data packet
171    pub fn decrypt(&mut self, packet: &DataPacket) -> Result<Bytes> {
172        if packet.key_id != self.key_id {
173            return Err(ProtocolError::KeyNotAvailable(packet.key_id.0));
174        }
175
176        let ad_prefix = self.decrypt_ad_prefix(packet);
177        let decrypted = self.decrypt_cipher.decrypt(&packet.payload, &ad_prefix)?;
178        Ok(Bytes::from(decrypted))
179    }
180}
181
/// Compression setting for the data channel.
///
/// Real compression is disabled for security. The non-`None` variants are stubs
/// that accept the corresponding setting without ever compressing data.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Compression {
    /// Compression disabled entirely.
    None,
    /// LZO setting accepted as a stub; data is never LZO-compressed.
    LzoStub,
    /// LZ4 setting accepted as a stub; data is never LZ4-compressed.
    Lz4Stub,
}
192
193impl Compression {
194    /// Check if compression byte indicates compressed data
195    pub fn is_compressed(byte: u8) -> bool {
196        // OpenVPN compression prefixes
197        // 0xFA = LZO compressed
198        // 0xFB = LZ4 compressed
199        byte == 0xFA || byte == 0xFB
200    }
201
202    /// Strip compression header if present (stub mode)
203    pub fn strip_header(data: &[u8]) -> Result<&[u8]> {
204        if data.is_empty() {
205            return Ok(data);
206        }
207
208        match data[0] {
209            0xFA | 0xFB => {
210                // Compressed data - we don't support actual decompression
211                // for security (VORACLE attacks)
212                Err(ProtocolError::InvalidPacket(
213                    "compressed data not supported".into(),
214                ))
215            }
216            0x00 => {
217                // Uncompressed with compression header
218                Ok(&data[1..])
219            }
220            _ => {
221                // No compression header
222                Ok(data)
223            }
224        }
225    }
226
227    /// Add compression header (always uncompressed)
228    pub fn add_header(data: &[u8], comp: Compression) -> Vec<u8> {
229        match comp {
230            Compression::None => data.to_vec(),
231            Compression::LzoStub | Compression::Lz4Stub => {
232                let mut out = Vec::with_capacity(1 + data.len());
233                out.push(0x00); // Uncompressed marker
234                out.extend_from_slice(data);
235                out
236            }
237        }
238    }
239}
240
241#[cfg(test)]
242mod tests {
243    use super::*;
244    use corevpn_crypto::CipherSuite;
245
246    #[test]
247    fn test_data_packet_v1() {
248        let packet = DataPacket::new(KeyId::new(1), Bytes::from_static(&[1, 2, 3, 4]));
249        let serialized = packet.serialize();
250
251        let parsed = DataPacket::parse(&serialized).unwrap();
252        assert_eq!(parsed.key_id, KeyId::new(1));
253        assert!(parsed.peer_id.is_none());
254        assert_eq!(&parsed.payload[..], &[1, 2, 3, 4]);
255    }
256
257    #[test]
258    fn test_data_packet_v2() {
259        let packet = DataPacket::new_v2(KeyId::new(2), 12345, Bytes::from_static(&[5, 6, 7, 8]));
260        let serialized = packet.serialize();
261
262        let parsed = DataPacket::parse(&serialized).unwrap();
263        assert_eq!(parsed.key_id, KeyId::new(2));
264        assert_eq!(parsed.peer_id, Some(12345));
265        assert_eq!(&parsed.payload[..], &[5, 6, 7, 8]);
266    }
267
268    #[test]
269    fn test_data_channel() {
270        let key1 = DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);
271        let key2 = DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);
272        let key3 = DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);
273        let key4 = DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);
274
275        let mut client = DataChannel::new(KeyId::new(0), key1, key2, false, None);
276        let mut server = DataChannel::new(KeyId::new(0), key3, key4, false, None);
277
278        // Client encrypts
279        let ip_packet = b"Hello, VPN!";
280        let encrypted = client.encrypt(ip_packet).unwrap();
281
282        // Server decrypts (note: in real use, server would use client's key for decrypt)
283        // This test just verifies the packet format
284        assert_eq!(encrypted.key_id, KeyId::new(0));
285    }
286
287    #[test]
288    fn test_compression_strip() {
289        // No compression
290        let data = [1, 2, 3, 4];
291        assert_eq!(Compression::strip_header(&data).unwrap(), &[1, 2, 3, 4]);
292
293        // Uncompressed with header
294        let data = [0x00, 1, 2, 3, 4];
295        assert_eq!(Compression::strip_header(&data).unwrap(), &[1, 2, 3, 4]);
296
297        // Compressed (should error)
298        let data = [0xFA, 1, 2, 3, 4];
299        assert!(Compression::strip_header(&data).is_err());
300    }
301}