Skip to main content

corevpn_protocol/
packet.rs

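//! OpenVPN Packet Parsing and Serialization
//!
//! # Performance Optimizations
//! - Bounds-checked parsing that copies each payload once into an owned `Bytes`
//!   (the parser takes `&[u8]`, so it cannot hand out zero-copy slices)
//! - `#[inline]` hints on hot-path functions
//! - Pre-allocated serialization buffers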
7
8use bytes::{BufMut, Bytes, BytesMut};
9
10use crate::{OpCode, KeyId, ProtocolError, Result};
11
/// Session ID (8 bytes), as carried in control-channel headers
pub type SessionId = [u8; 8];

/// Packet ID (4 bytes) for replay protection; also the type of ACK entries and
/// message packet IDs, which are encoded big-endian on the wire
pub type PacketId = u32;
17
/// OpenVPN packet header
///
/// As produced by [`PacketHeader::parse`]:
/// - Data-channel headers populate only `opcode` and `key_id`; every `Option` field is `None`.
/// - Control-channel headers always carry `session_id`.
/// - Control-channel headers carry `hmac`, `packet_id` and `timestamp` exactly when parsed with
///   tls-auth enabled.
#[derive(Debug, Clone)]
pub struct PacketHeader {
    /// Packet opcode (decoded from the first byte)
    pub opcode: OpCode,
    /// Key ID (decoded from the same first byte as `opcode`)
    pub key_id: KeyId,
    /// Local session ID (for control channel)
    pub session_id: Option<SessionId>,
    /// HMAC (if tls-auth enabled); always exactly 32 bytes.
    ///
    /// NOTE(review): the tls-auth HMAC length follows the configured digest
    /// (e.g. 20 bytes for SHA-1), so other digest sizes cannot be represented
    /// here — confirm which `auth` setting this crate supports.
    pub hmac: Option<[u8; 32]>,
    /// Packet ID (for replay protection with tls-auth)
    pub packet_id: Option<PacketId>,
    /// Timestamp (for tls-auth)
    pub timestamp: Option<u32>,
}
34
35impl PacketHeader {
36    /// Minimum header size (opcode only)
37    pub const MIN_SIZE: usize = 1;
38
39    /// Control channel header size (without HMAC)
40    pub const CONTROL_HEADER_SIZE: usize = 1 + 8; // opcode + session_id
41
42    /// Parse packet header from bytes
43    #[inline]
44    pub fn parse(data: &[u8], has_tls_auth: bool) -> Result<(Self, usize)> {
45        if data.is_empty() {
46            return Err(ProtocolError::PacketTooShort {
47                expected: 1,
48                got: 0,
49            });
50        }
51
52        let opcode = OpCode::from_byte(data[0])?;
53        let key_id = KeyId::from_byte(data[0]);
54
55        if opcode.is_data() {
56            // Data packets: just opcode + key_id, then encrypted payload
57            return Ok((
58                Self {
59                    opcode,
60                    key_id,
61                    session_id: None,
62                    hmac: None,
63                    packet_id: None,
64                    timestamp: None,
65                },
66                1,
67            ));
68        }
69
70        // Control packets have more header fields
71        let mut offset = 1;
72        let mut hmac = None;
73        let mut packet_id = None;
74        let mut timestamp = None;
75
76        // Parse HMAC if tls-auth is enabled
77        if has_tls_auth {
78            if data.len() < offset + 32 {
79                return Err(ProtocolError::PacketTooShort {
80                    expected: offset + 32,
81                    got: data.len(),
82                });
83            }
84            let mut h = [0u8; 32];
85            h.copy_from_slice(&data[offset..offset + 32]);
86            hmac = Some(h);
87            offset += 32;
88
89            // Packet ID (4 bytes)
90            if data.len() < offset + 4 {
91                return Err(ProtocolError::PacketTooShort {
92                    expected: offset + 4,
93                    got: data.len(),
94                });
95            }
96            packet_id = Some(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
97            offset += 4;
98
99            // Timestamp (4 bytes)
100            if data.len() < offset + 4 {
101                return Err(ProtocolError::PacketTooShort {
102                    expected: offset + 4,
103                    got: data.len(),
104                });
105            }
106            timestamp = Some(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
107            offset += 4;
108        }
109
110        // Session ID (8 bytes)
111        if data.len() < offset + 8 {
112            return Err(ProtocolError::PacketTooShort {
113                expected: offset + 8,
114                got: data.len(),
115            });
116        }
117        let mut session_id = [0u8; 8];
118        session_id.copy_from_slice(&data[offset..offset + 8]);
119        offset += 8;
120
121        Ok((
122            Self {
123                opcode,
124                key_id,
125                session_id: Some(session_id),
126                hmac,
127                packet_id,
128                timestamp,
129            },
130            offset,
131        ))
132    }
133
134    /// Serialize header to bytes
135    #[inline]
136    pub fn serialize(&self, buf: &mut BytesMut) {
137        buf.put_u8(self.opcode.to_byte(self.key_id));
138
139        if let Some(hmac) = &self.hmac {
140            buf.put_slice(hmac);
141        }
142
143        if let Some(packet_id) = self.packet_id {
144            buf.put_u32(packet_id);
145        }
146
147        if let Some(timestamp) = self.timestamp {
148            buf.put_u32(timestamp);
149        }
150
151        if let Some(session_id) = &self.session_id {
152            buf.put_slice(session_id);
153        }
154    }
155}
156
/// Parsed OpenVPN packet
///
/// The variant is selected by the header opcode: data-channel opcodes
/// (`OpCode::is_data`) yield `Data`, every other opcode yields `Control`.
#[derive(Debug, Clone)]
pub enum Packet {
    /// Control channel packet
    Control(ControlPacketData),
    /// Data channel packet
    Data(DataPacketData),
}
165
/// Control channel packet data
#[derive(Debug, Clone)]
pub struct ControlPacketData {
    /// Packet header
    pub header: PacketHeader,
    /// Remote session ID (for ACK packets); on the wire it is present only
    /// when `acks` is non-empty
    pub remote_session_id: Option<SessionId>,
    /// Acknowledged packet IDs (the parser accepts at most 16 per packet)
    pub acks: Vec<PacketId>,
    /// Message packet ID (for reliability); `None` for `AckV1` packets and
    /// when a parsed packet ended before this 4-byte field
    pub message_packet_id: Option<PacketId>,
    /// Payload (TLS records): every byte after the last header field
    pub payload: Bytes,
}
180
/// Data channel packet data
#[derive(Debug, Clone)]
pub struct DataPacketData {
    /// Packet header
    pub header: PacketHeader,
    /// Peer ID (for P_DATA_V2): a 24-bit value held in the low bits of the
    /// `u32`; `None` for other data opcodes
    pub peer_id: Option<u32>,
    /// Encrypted payload: every byte after the opcode byte and optional peer ID
    pub payload: Bytes,
}
191
192impl Packet {
193    /// Parse a packet from raw bytes
194    #[inline]
195    pub fn parse(data: &[u8], has_tls_auth: bool) -> Result<Self> {
196        let (header, mut offset) = PacketHeader::parse(data, has_tls_auth)?;
197
198        if header.opcode.is_data() {
199            // Data packet
200            let peer_id = if header.opcode == OpCode::DataV2 {
201                if data.len() < offset + 3 {
202                    return Err(ProtocolError::PacketTooShort {
203                        expected: offset + 3,
204                        got: data.len(),
205                    });
206                }
207                // Peer ID is 24 bits
208                let pid = ((data[offset] as u32) << 16)
209                    | ((data[offset + 1] as u32) << 8)
210                    | (data[offset + 2] as u32);
211                offset += 3;
212                Some(pid)
213            } else {
214                None
215            };
216
217            // Bounds check before slicing
218            if offset > data.len() {
219                return Err(ProtocolError::PacketTooShort {
220                    expected: offset,
221                    got: data.len(),
222                });
223            }
224            return Ok(Packet::Data(DataPacketData {
225                header,
226                peer_id,
227                payload: Bytes::copy_from_slice(&data[offset..]),
228            }));
229        }
230
231        // Control packet - parse additional fields
232        let mut remote_session_id = None;
233        let mut acks = Vec::new();
234
235        // Parse ACK array length
236        if data.len() < offset + 1 {
237            return Err(ProtocolError::PacketTooShort {
238                expected: offset + 1,
239                got: data.len(),
240            });
241        }
242        const MAX_ACK_COUNT: usize = 16; // Reasonable limit to prevent DoS
243        let ack_count = data[offset] as usize;
244        if ack_count > MAX_ACK_COUNT {
245            return Err(ProtocolError::InvalidPacket(
246                format!("ACK count {} exceeds maximum {}", ack_count, MAX_ACK_COUNT).into(),
247            ));
248        }
249        offset += 1;
250
251        // Parse ACKs
252        if ack_count > 0 {
253            // Parse ACK packet IDs
254            for _ in 0..ack_count {
255                if data.len() < offset + 4 {
256                    return Err(ProtocolError::PacketTooShort {
257                        expected: offset + 4,
258                        got: data.len(),
259                    });
260                }
261                acks.push(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
262                offset += 4;
263            }
264
265            // Parse remote session ID
266            if data.len() < offset + 8 {
267                return Err(ProtocolError::PacketTooShort {
268                    expected: offset + 8,
269                    got: data.len(),
270                });
271            }
272            let mut rsid = [0u8; 8];
273            rsid.copy_from_slice(&data[offset..offset + 8]);
274            remote_session_id = Some(rsid);
275            offset += 8;
276        }
277
278        // Parse message packet ID (if not ACK-only)
279        let message_packet_id = if header.opcode != OpCode::AckV1 && data.len() >= offset + 4 {
280            let id = u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap());
281            offset += 4;
282            Some(id)
283        } else {
284            None
285        };
286
287        // Remaining is payload
288        let payload = if offset < data.len() {
289            Bytes::copy_from_slice(&data[offset..])
290        } else {
291            Bytes::new()
292        };
293
294        Ok(Packet::Control(ControlPacketData {
295            header,
296            remote_session_id,
297            acks,
298            message_packet_id,
299            payload,
300        }))
301    }
302
303    /// Serialize packet to bytes
304    #[inline]
305    pub fn serialize(&self) -> BytesMut {
306        // Pre-allocate typical MTU size to avoid reallocations
307        let mut buf = BytesMut::with_capacity(1500);
308
309        match self {
310            Packet::Control(ctrl) => {
311                ctrl.header.serialize(&mut buf);
312
313                // ACK count
314                buf.put_u8(ctrl.acks.len() as u8);
315
316                // ACKs
317                for ack in &ctrl.acks {
318                    buf.put_u32(*ack);
319                }
320
321                // Remote session ID (if we have ACKs)
322                if !ctrl.acks.is_empty() {
323                    if let Some(rsid) = &ctrl.remote_session_id {
324                        buf.put_slice(rsid);
325                    }
326                }
327
328                // Message packet ID
329                if let Some(mpid) = ctrl.message_packet_id {
330                    buf.put_u32(mpid);
331                }
332
333                // Payload
334                buf.put_slice(&ctrl.payload);
335            }
336            Packet::Data(data) => {
337                data.header.serialize(&mut buf);
338
339                // Peer ID for V2
340                if let Some(pid) = data.peer_id {
341                    buf.put_u8((pid >> 16) as u8);
342                    buf.put_u8((pid >> 8) as u8);
343                    buf.put_u8(pid as u8);
344                }
345
346                // Payload
347                buf.put_slice(&data.payload);
348            }
349        }
350
351        buf
352    }
353
354    /// Get the opcode
355    pub fn opcode(&self) -> OpCode {
356        match self {
357            Packet::Control(c) => c.header.opcode,
358            Packet::Data(d) => d.header.opcode,
359        }
360    }
361
362    /// Get the key ID
363    pub fn key_id(&self) -> KeyId {
364        match self {
365            Packet::Control(c) => c.header.key_id,
366            Packet::Data(d) => d.header.key_id,
367        }
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hard_reset_parse() {
        // P_CONTROL_HARD_RESET_CLIENT_V2 with session ID
        let data = [
            0x38, // opcode=7, key_id=0
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // session_id
            0x00, // ack_count = 0
        ];

        match Packet::parse(&data, false).unwrap() {
            Packet::Control(ctrl) => {
                assert_eq!(ctrl.header.opcode, OpCode::HardResetClientV2);
                assert_eq!(ctrl.header.key_id, KeyId::new(0));
                assert_eq!(
                    ctrl.header.session_id,
                    Some([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08])
                );
                assert!(ctrl.acks.is_empty());
                assert_eq!(ctrl.remote_session_id, None);
                // Packet ends after the ACK count: message ID is read leniently.
                assert_eq!(ctrl.message_packet_id, None);
                assert!(ctrl.payload.is_empty());
            }
            Packet::Data(_) => panic!("Expected control packet"),
        }
    }

    #[test]
    fn test_data_packet_v2() {
        let data = [
            0x48, // opcode=9 (DataV2), key_id=0
            0x00, 0x00, 0x01, // peer_id = 1
            0xDE, 0xAD, 0xBE, 0xEF, // payload
        ];

        match Packet::parse(&data, false).unwrap() {
            Packet::Data(d) => {
                assert_eq!(d.header.opcode, OpCode::DataV2);
                assert_eq!(d.peer_id, Some(1));
                assert_eq!(&d.payload[..], &[0xDE, 0xAD, 0xBE, 0xEF]);
            }
            Packet::Control(_) => panic!("Expected data packet"),
        }
    }

    #[test]
    fn test_empty_input_rejected() {
        assert!(Packet::parse(&[], false).is_err());
    }

    #[test]
    fn test_truncated_peer_id_rejected() {
        // DataV2 needs three peer-ID bytes after the opcode; only two present.
        assert!(Packet::parse(&[0x48, 0x00, 0x01], false).is_err());
    }

    #[test]
    fn test_ack_count_limit_rejected() {
        let mut data = vec![0x38]; // HardResetClientV2, key_id=0
        data.extend_from_slice(&[0x01; 8]); // session_id
        data.push(17); // one more than the accepted maximum
        assert!(Packet::parse(&data, false).is_err());
    }

    #[test]
    fn test_data_v2_roundtrip() {
        let original = Packet::Data(DataPacketData {
            header: PacketHeader {
                opcode: OpCode::DataV2,
                key_id: KeyId::new(0),
                session_id: None,
                hmac: None,
                packet_id: None,
                timestamp: None,
            },
            peer_id: Some(0x0A_0B_0C),
            payload: Bytes::from_static(&[1, 2, 3]),
        });

        let wire = original.serialize();
        assert_eq!(wire.len(), 1 + 3 + 3);

        match Packet::parse(&wire, false).unwrap() {
            Packet::Data(d) => {
                assert_eq!(d.header.opcode, OpCode::DataV2);
                assert_eq!(d.peer_id, Some(0x0A_0B_0C));
                assert_eq!(&d.payload[..], &[1, 2, 3]);
            }
            Packet::Control(_) => panic!("Expected data packet"),
        }
    }

    #[test]
    fn test_ack_roundtrip() {
        let original = Packet::Control(ControlPacketData {
            header: PacketHeader {
                opcode: OpCode::AckV1,
                key_id: KeyId::new(0),
                session_id: Some([7; 8]),
                hmac: None,
                packet_id: None,
                timestamp: None,
            },
            remote_session_id: Some([9; 8]),
            acks: vec![10, 11],
            message_packet_id: None,
            payload: Bytes::new(),
        });

        let wire = original.serialize();

        match Packet::parse(&wire, false).unwrap() {
            Packet::Control(ctrl) => {
                assert_eq!(ctrl.header.opcode, OpCode::AckV1);
                assert_eq!(ctrl.header.session_id, Some([7; 8]));
                assert_eq!(ctrl.acks, [10, 11]);
                assert_eq!(ctrl.remote_session_id, Some([9; 8]));
                assert_eq!(ctrl.message_packet_id, None);
                assert!(ctrl.payload.is_empty());
            }
            Packet::Data(_) => panic!("Expected control packet"),
        }
    }

    #[test]
    fn test_tls_auth_control_roundtrip() {
        let original = Packet::Control(ControlPacketData {
            header: PacketHeader {
                opcode: OpCode::HardResetClientV2,
                key_id: KeyId::new(0),
                session_id: Some([1, 2, 3, 4, 5, 6, 7, 8]),
                hmac: Some([0xAA; 32]),
                packet_id: Some(0x0102_0304),
                timestamp: Some(0x0506_0708),
            },
            remote_session_id: None,
            acks: Vec::new(),
            message_packet_id: Some(0),
            payload: Bytes::new(),
        });

        let wire = original.serialize();
        // opcode + session_id + hmac + packet_id + timestamp + ack_count + message_packet_id
        assert_eq!(wire.len(), 1 + 8 + 32 + 4 + 4 + 1 + 4);

        match Packet::parse(&wire, true).unwrap() {
            Packet::Control(ctrl) => {
                assert_eq!(ctrl.header.opcode, OpCode::HardResetClientV2);
                assert_eq!(ctrl.header.session_id, Some([1, 2, 3, 4, 5, 6, 7, 8]));
                assert_eq!(ctrl.header.hmac, Some([0xAA; 32]));
                assert_eq!(ctrl.header.packet_id, Some(0x0102_0304));
                assert_eq!(ctrl.header.timestamp, Some(0x0506_0708));
                assert!(ctrl.acks.is_empty());
                assert_eq!(ctrl.message_packet_id, Some(0));
                assert!(ctrl.payload.is_empty());
            }
            Packet::Data(_) => panic!("Expected control packet"),
        }
    }
}