Skip to main content

corevpn_protocol/
packet.rs

1//! OpenVPN Packet Parsing and Serialization
2//!
3//! # Performance Optimizations
//! - Single-copy parsing: header fields are decoded straight from the input slice and each payload is copied once into an owned `Bytes`
5//! - Inlined hot path functions
6//! - Pre-allocated serialization buffers
7
8use bytes::{BufMut, Bytes, BytesMut};
9
10use crate::{OpCode, KeyId, ProtocolError, Result};
11
/// Session ID (8 bytes).
///
/// Carried as the local session ID in every control-packet header, and also
/// used for the remote session ID that follows a non-empty ACK array.
pub type SessionId = [u8; 8];

/// Packet ID (4 bytes, big-endian on the wire).
///
/// Used for tls-auth replay protection, for the entries of a control packet's
/// ACK array, and for a control packet's message packet ID.
pub type PacketId = u32;
17
/// OpenVPN packet header, as decoded by [`PacketHeader::parse`].
///
/// For data packets only `opcode` and `key_id` are populated (everything else
/// is `None`). For control packets `session_id` is always `Some`, and the
/// three tls-auth fields (`hmac`, `packet_id`, `timestamp`) are `Some` only
/// when the header was parsed with tls-auth enabled.
#[derive(Debug, Clone)]
pub struct PacketHeader {
    /// Packet opcode, decoded from the first byte of the packet.
    pub opcode: OpCode,
    /// Key ID, decoded from the same first byte as `opcode`.
    pub key_id: KeyId,
    /// Local session ID (for control channel); `None` for data packets.
    pub session_id: Option<SessionId>,
    /// HMAC (if tls-auth enabled). Fixed at 32 bytes (HMAC-SHA256); the
    /// parser does not handle other digest sizes.
    pub hmac: Option<[u8; 32]>,
    /// Packet ID (for replay protection with tls-auth), big-endian on the wire.
    pub packet_id: Option<PacketId>,
    /// Timestamp (for tls-auth), big-endian `u32` on the wire.
    pub timestamp: Option<u32>,
}
34
35impl PacketHeader {
36    /// Minimum header size (opcode only)
37    pub const MIN_SIZE: usize = 1;
38
39    /// Control channel header size (without HMAC)
40    pub const CONTROL_HEADER_SIZE: usize = 1 + 8; // opcode + session_id
41
42    /// Parse packet header from bytes
43    #[inline]
44    pub fn parse(data: &[u8], has_tls_auth: bool) -> Result<(Self, usize)> {
45        if data.is_empty() {
46            return Err(ProtocolError::PacketTooShort {
47                expected: 1,
48                got: 0,
49            });
50        }
51
52        let opcode = OpCode::from_byte(data[0])?;
53        let key_id = KeyId::from_byte(data[0]);
54
55        if opcode.is_data() {
56            // Data packets: just opcode + key_id, then encrypted payload
57            return Ok((
58                Self {
59                    opcode,
60                    key_id,
61                    session_id: None,
62                    hmac: None,
63                    packet_id: None,
64                    timestamp: None,
65                },
66                1,
67            ));
68        }
69
70        // Control packets have more header fields
71        // OpenVPN wire format: [opcode(1)] [session_id(8)] [HMAC(32)] [pid(4)] [time(4)] [rest...]
72        // Session ID always comes right after opcode, BEFORE tls-auth fields
73        let mut offset = 1;
74        let mut hmac = None;
75        let mut packet_id = None;
76        let mut timestamp = None;
77
78        // Session ID (8 bytes) - always right after opcode
79        if data.len() < offset + 8 {
80            return Err(ProtocolError::PacketTooShort {
81                expected: offset + 8,
82                got: data.len(),
83            });
84        }
85        let mut session_id = [0u8; 8];
86        session_id.copy_from_slice(&data[offset..offset + 8]);
87        offset += 8;
88
89        // Parse HMAC + replay packet ID if tls-auth is enabled
90        if has_tls_auth {
91            // HMAC (32 bytes for SHA256)
92            if data.len() < offset + 32 {
93                return Err(ProtocolError::PacketTooShort {
94                    expected: offset + 32,
95                    got: data.len(),
96                });
97            }
98            let mut h = [0u8; 32];
99            h.copy_from_slice(&data[offset..offset + 32]);
100            hmac = Some(h);
101            offset += 32;
102
103            // Packet ID (4 bytes)
104            if data.len() < offset + 4 {
105                return Err(ProtocolError::PacketTooShort {
106                    expected: offset + 4,
107                    got: data.len(),
108                });
109            }
110            packet_id = Some(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
111            offset += 4;
112
113            // Timestamp (4 bytes)
114            if data.len() < offset + 4 {
115                return Err(ProtocolError::PacketTooShort {
116                    expected: offset + 4,
117                    got: data.len(),
118                });
119            }
120            timestamp = Some(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
121            offset += 4;
122        }
123
124        Ok((
125            Self {
126                opcode,
127                key_id,
128                session_id: Some(session_id),
129                hmac,
130                packet_id,
131                timestamp,
132            },
133            offset,
134        ))
135    }
136
137    /// Serialize header to bytes
138    /// Wire format: [opcode(1)] [session_id(8)] [HMAC(32)] [pid(4)] [time(4)] [rest...]
139    #[inline]
140    pub fn serialize(&self, buf: &mut BytesMut) {
141        buf.put_u8(self.opcode.to_byte(self.key_id));
142
143        // Session ID comes right after opcode (before tls-auth fields)
144        if let Some(session_id) = &self.session_id {
145            buf.put_slice(session_id);
146        }
147
148        if let Some(hmac) = &self.hmac {
149            buf.put_slice(hmac);
150        }
151
152        if let Some(packet_id) = self.packet_id {
153            buf.put_u32(packet_id);
154        }
155
156        if let Some(timestamp) = self.timestamp {
157            buf.put_u32(timestamp);
158        }
159    }
160}
161
/// Parsed OpenVPN packet.
///
/// The variant is chosen by `OpCode::is_data` on the header's opcode.
#[derive(Debug, Clone)]
pub enum Packet {
    /// Control channel packet (ACKs, reliability packet ID, TLS records)
    Control(ControlPacketData),
    /// Data channel packet (optional peer ID plus encrypted payload)
    Data(DataPacketData),
}
170
/// Control channel packet data.
///
/// Wire order after the header:
/// - ACK count (1 byte)
/// - `acks` (4 bytes each)
/// - `remote_session_id` (8 bytes, only when there is at least one ACK)
/// - `message_packet_id` (4 bytes, omitted for `OpCode::AckV1`)
/// - `payload`
#[derive(Debug, Clone)]
pub struct ControlPacketData {
    /// Packet header
    pub header: PacketHeader,
    /// Remote (peer) session ID. It is only on the wire when `acks` is
    /// non-empty; `Packet::serialize` writes it only in that case.
    pub remote_session_id: Option<SessionId>,
    /// Packet IDs being acknowledged (big-endian `u32` each on the wire).
    pub acks: Vec<PacketId>,
    /// Message packet ID (for reliability).
    ///
    /// `None` for `OpCode::AckV1`. `Packet::parse` also yields `None`, rather
    /// than an error, when fewer than 4 bytes remain after the ACK section.
    pub message_packet_id: Option<PacketId>,
    /// Payload (TLS records): every byte after the fixed fields.
    pub payload: Bytes,
}
185
/// Data channel packet data.
#[derive(Debug, Clone)]
pub struct DataPacketData {
    /// Packet header
    pub header: PacketHeader,
    /// Peer ID (for P_DATA_V2): a 24-bit value stored as 3 big-endian bytes
    /// right after the opcode byte. `Packet::parse` sets it only for
    /// `OpCode::DataV2`.
    pub peer_id: Option<u32>,
    /// Encrypted payload: every byte after the header and peer ID.
    pub payload: Bytes,
}
196
197impl Packet {
198    /// Parse a packet from raw bytes
199    #[inline]
200    pub fn parse(data: &[u8], has_tls_auth: bool) -> Result<Self> {
201        let (header, mut offset) = PacketHeader::parse(data, has_tls_auth)?;
202
203        if header.opcode.is_data() {
204            // Data packet
205            let peer_id = if header.opcode == OpCode::DataV2 {
206                if data.len() < offset + 3 {
207                    return Err(ProtocolError::PacketTooShort {
208                        expected: offset + 3,
209                        got: data.len(),
210                    });
211                }
212                // Peer ID is 24 bits
213                let pid = ((data[offset] as u32) << 16)
214                    | ((data[offset + 1] as u32) << 8)
215                    | (data[offset + 2] as u32);
216                offset += 3;
217                Some(pid)
218            } else {
219                None
220            };
221
222            // Bounds check before slicing
223            if offset > data.len() {
224                return Err(ProtocolError::PacketTooShort {
225                    expected: offset,
226                    got: data.len(),
227                });
228            }
229            return Ok(Packet::Data(DataPacketData {
230                header,
231                peer_id,
232                payload: Bytes::copy_from_slice(&data[offset..]),
233            }));
234        }
235
236        // Control packet - parse additional fields
237        let mut remote_session_id = None;
238        let mut acks = Vec::new();
239
240        // Parse ACK array length
241        if data.len() < offset + 1 {
242            return Err(ProtocolError::PacketTooShort {
243                expected: offset + 1,
244                got: data.len(),
245            });
246        }
247        const MAX_ACK_COUNT: usize = 16; // Reasonable limit to prevent DoS
248        let ack_count = data[offset] as usize;
249        if ack_count > MAX_ACK_COUNT {
250            return Err(ProtocolError::InvalidPacket(
251                format!("ACK count {} exceeds maximum {}", ack_count, MAX_ACK_COUNT).into(),
252            ));
253        }
254        offset += 1;
255
256        // Parse ACKs
257        if ack_count > 0 {
258            // Parse ACK packet IDs
259            for _ in 0..ack_count {
260                if data.len() < offset + 4 {
261                    return Err(ProtocolError::PacketTooShort {
262                        expected: offset + 4,
263                        got: data.len(),
264                    });
265                }
266                acks.push(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
267                offset += 4;
268            }
269
270            // Parse remote session ID
271            if data.len() < offset + 8 {
272                return Err(ProtocolError::PacketTooShort {
273                    expected: offset + 8,
274                    got: data.len(),
275                });
276            }
277            let mut rsid = [0u8; 8];
278            rsid.copy_from_slice(&data[offset..offset + 8]);
279            remote_session_id = Some(rsid);
280            offset += 8;
281        }
282
283        // Parse message packet ID (if not ACK-only)
284        let message_packet_id = if header.opcode != OpCode::AckV1 && data.len() >= offset + 4 {
285            let id = u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap());
286            offset += 4;
287            Some(id)
288        } else {
289            None
290        };
291
292        // Remaining is payload
293        let payload = if offset < data.len() {
294            Bytes::copy_from_slice(&data[offset..])
295        } else {
296            Bytes::new()
297        };
298
299        Ok(Packet::Control(ControlPacketData {
300            header,
301            remote_session_id,
302            acks,
303            message_packet_id,
304            payload,
305        }))
306    }
307
308    /// Serialize packet to bytes
309    #[inline]
310    pub fn serialize(&self) -> BytesMut {
311        // Pre-allocate typical MTU size to avoid reallocations
312        let mut buf = BytesMut::with_capacity(1500);
313
314        match self {
315            Packet::Control(ctrl) => {
316                ctrl.header.serialize(&mut buf);
317
318                // ACK count
319                buf.put_u8(ctrl.acks.len() as u8);
320
321                // ACKs
322                for ack in &ctrl.acks {
323                    buf.put_u32(*ack);
324                }
325
326                // Remote session ID (if we have ACKs)
327                if !ctrl.acks.is_empty() {
328                    if let Some(rsid) = &ctrl.remote_session_id {
329                        buf.put_slice(rsid);
330                    }
331                }
332
333                // Message packet ID
334                if let Some(mpid) = ctrl.message_packet_id {
335                    buf.put_u32(mpid);
336                }
337
338                // Payload
339                buf.put_slice(&ctrl.payload);
340            }
341            Packet::Data(data) => {
342                data.header.serialize(&mut buf);
343
344                // Peer ID for V2
345                if let Some(pid) = data.peer_id {
346                    buf.put_u8((pid >> 16) as u8);
347                    buf.put_u8((pid >> 8) as u8);
348                    buf.put_u8(pid as u8);
349                }
350
351                // Payload
352                buf.put_slice(&data.payload);
353            }
354        }
355
356        buf
357    }
358
359    /// Get the opcode
360    pub fn opcode(&self) -> OpCode {
361        match self {
362            Packet::Control(c) => c.header.opcode,
363            Packet::Data(d) => d.header.opcode,
364        }
365    }
366
367    /// Get the key ID
368    pub fn key_id(&self) -> KeyId {
369        match self {
370            Packet::Control(c) => c.header.key_id,
371            Packet::Data(d) => d.header.key_id,
372        }
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hard_reset_parse() {
        // P_CONTROL_HARD_RESET_CLIENT_V2 with session ID
        let data = [
            0x38, // opcode=7, key_id=0
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // session_id
            0x00, // ack_count = 0
        ];

        let packet = Packet::parse(&data, false).unwrap();
        if let Packet::Control(ctrl) = packet {
            assert_eq!(ctrl.header.opcode, OpCode::HardResetClientV2);
            assert_eq!(ctrl.header.key_id, KeyId::new(0));
            assert_eq!(
                ctrl.header.session_id,
                Some([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08])
            );
            assert!(ctrl.acks.is_empty());
        } else {
            panic!("Expected control packet");
        }
    }

    #[test]
    fn test_data_packet_v2() {
        let data = [
            0x48, // opcode=9 (DataV2), key_id=0
            0x00, 0x00, 0x01, // peer_id = 1
            0xDE, 0xAD, 0xBE, 0xEF, // payload
        ];

        let packet = Packet::parse(&data, false).unwrap();
        if let Packet::Data(d) = packet {
            assert_eq!(d.header.opcode, OpCode::DataV2);
            assert_eq!(d.peer_id, Some(1));
            assert_eq!(&d.payload[..], &[0xDE, 0xAD, 0xBE, 0xEF]);
        } else {
            panic!("Expected data packet");
        }
    }

    /// An empty buffer is reported as too short for the opcode byte.
    #[test]
    fn test_empty_input_rejected() {
        let err = Packet::parse(&[], false).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::PacketTooShort { expected: 1, got: 0 }
        ));
    }

    /// A DataV2 packet cut off inside its 3-byte peer ID must be rejected.
    #[test]
    fn test_data_packet_v2_truncated_peer_id() {
        let data = [0x48, 0x00, 0x00]; // DataV2, only 2 of 3 peer-ID bytes
        let err = Packet::parse(&data, false).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::PacketTooShort { expected: 4, got: 3 }
        ));
    }

    /// ACK counts above the parser's limit (16) are rejected up front.
    #[test]
    fn test_ack_count_over_limit_rejected() {
        let data = [
            0x38, // HardResetClientV2, key_id=0
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // session_id
            17,   // ack_count, one over the limit
        ];
        let err = Packet::parse(&data, false).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidPacket(_)));
    }

    /// An ACK packet survives serialize -> parse, including the remote session ID.
    #[test]
    fn test_ack_roundtrip() {
        let packet = Packet::Control(ControlPacketData {
            header: PacketHeader {
                opcode: OpCode::AckV1,
                key_id: KeyId::new(0),
                session_id: Some([1, 2, 3, 4, 5, 6, 7, 8]),
                hmac: None,
                packet_id: None,
                timestamp: None,
            },
            remote_session_id: Some([9, 10, 11, 12, 13, 14, 15, 16]),
            acks: vec![1, 0xDEAD_BEEF],
            message_packet_id: None,
            payload: Bytes::new(),
        });

        let wire = packet.serialize();
        match Packet::parse(&wire, false).unwrap() {
            Packet::Control(ctrl) => {
                assert_eq!(ctrl.header.opcode, OpCode::AckV1);
                assert_eq!(ctrl.header.session_id, Some([1, 2, 3, 4, 5, 6, 7, 8]));
                assert_eq!(ctrl.acks, vec![1, 0xDEAD_BEEF]);
                assert_eq!(
                    ctrl.remote_session_id,
                    Some([9, 10, 11, 12, 13, 14, 15, 16])
                );
                assert_eq!(ctrl.message_packet_id, None);
                assert!(ctrl.payload.is_empty());
            }
            Packet::Data(_) => panic!("Expected control packet"),
        }
    }

    /// tls-auth header fields and the message packet ID round-trip in wire order.
    #[test]
    fn test_tls_auth_control_roundtrip() {
        let packet = Packet::Control(ControlPacketData {
            header: PacketHeader {
                opcode: OpCode::HardResetClientV2,
                key_id: KeyId::new(0),
                session_id: Some([8, 7, 6, 5, 4, 3, 2, 1]),
                hmac: Some([0xAA; 32]),
                packet_id: Some(1),
                timestamp: Some(0x5F00_0000),
            },
            remote_session_id: None,
            acks: Vec::new(),
            message_packet_id: Some(0),
            payload: Bytes::from_static(&[0x16, 0x03, 0x01]),
        });

        let wire = packet.serialize();
        // opcode + session_id + HMAC + pid + time + ack_count + message_packet_id + payload
        assert_eq!(wire.len(), 1 + 8 + 32 + 4 + 4 + 1 + 4 + 3);

        match Packet::parse(&wire, true).unwrap() {
            Packet::Control(ctrl) => {
                assert_eq!(ctrl.header.opcode, OpCode::HardResetClientV2);
                assert_eq!(ctrl.header.session_id, Some([8, 7, 6, 5, 4, 3, 2, 1]));
                assert_eq!(ctrl.header.hmac, Some([0xAA; 32]));
                assert_eq!(ctrl.header.packet_id, Some(1));
                assert_eq!(ctrl.header.timestamp, Some(0x5F00_0000));
                assert!(ctrl.acks.is_empty());
                assert_eq!(ctrl.remote_session_id, None);
                assert_eq!(ctrl.message_packet_id, Some(0));
                assert_eq!(&ctrl.payload[..], &[0x16, 0x03, 0x01]);
            }
            Packet::Data(_) => panic!("Expected control packet"),
        }
    }
}