1use bytes::{Bytes, BytesMut, BufMut};
4use corevpn_crypto::{DataChannelKey, PacketCipher};
5
6use crate::{KeyId, OpCode, ProtocolError, Result};
7
8#[derive(Debug, Clone)]
10pub struct DataPacket {
11 pub key_id: KeyId,
13 pub peer_id: Option<u32>,
15 pub payload: Bytes,
17}
18
19impl DataPacket {
20 pub fn new(key_id: KeyId, payload: Bytes) -> Self {
22 Self {
23 key_id,
24 peer_id: None,
25 payload,
26 }
27 }
28
29 pub fn new_v2(key_id: KeyId, peer_id: u32, payload: Bytes) -> Self {
31 Self {
32 key_id,
33 peer_id: Some(peer_id),
34 payload,
35 }
36 }
37
38 pub fn parse(data: &[u8]) -> Result<Self> {
40 if data.is_empty() {
41 return Err(ProtocolError::PacketTooShort {
42 expected: 1,
43 got: 0,
44 });
45 }
46
47 let opcode = OpCode::from_byte(data[0])?;
48 let key_id = KeyId::from_byte(data[0]);
49
50 let (peer_id, payload_start) = if opcode == OpCode::DataV2 {
51 if data.len() < 4 {
52 return Err(ProtocolError::PacketTooShort {
53 expected: 4,
54 got: data.len(),
55 });
56 }
57 let pid = ((data[1] as u32) << 16) | ((data[2] as u32) << 8) | (data[3] as u32);
58 (Some(pid), 4)
59 } else {
60 (None, 1)
61 };
62
63 Ok(Self {
64 key_id,
65 peer_id,
66 payload: Bytes::copy_from_slice(&data[payload_start..]),
67 })
68 }
69
70 pub fn serialize(&self) -> BytesMut {
72 let opcode = if self.peer_id.is_some() {
73 OpCode::DataV2
74 } else {
75 OpCode::DataV1
76 };
77
78 let mut buf = BytesMut::with_capacity(4 + self.payload.len());
79 buf.put_u8(opcode.to_byte(self.key_id));
80
81 if let Some(pid) = self.peer_id {
82 buf.put_u8((pid >> 16) as u8);
83 buf.put_u8((pid >> 8) as u8);
84 buf.put_u8(pid as u8);
85 }
86
87 buf.put_slice(&self.payload);
88 buf
89 }
90}
91
92pub struct DataChannel {
94 key_id: KeyId,
96 peer_id: Option<u32>,
98 encrypt_cipher: PacketCipher,
100 decrypt_cipher: PacketCipher,
102 use_v2: bool,
104}
105
106impl DataChannel {
107 pub fn new(
109 key_id: KeyId,
110 encrypt_key: DataChannelKey,
111 decrypt_key: DataChannelKey,
112 use_v2: bool,
113 peer_id: Option<u32>,
114 ) -> Self {
115 Self {
116 key_id,
117 peer_id,
118 encrypt_cipher: PacketCipher::new(encrypt_key),
119 decrypt_cipher: PacketCipher::new(decrypt_key),
120 use_v2,
121 }
122 }
123
124 pub fn key_id(&self) -> KeyId {
126 self.key_id
127 }
128
129 pub fn encrypt(&mut self, ip_packet: &[u8]) -> Result<DataPacket> {
131 let encrypted = self.encrypt_cipher.encrypt(ip_packet)?;
132
133 Ok(DataPacket {
134 key_id: self.key_id,
135 peer_id: if self.use_v2 { self.peer_id } else { None },
136 payload: Bytes::from(encrypted),
137 })
138 }
139
140 pub fn decrypt(&mut self, packet: &DataPacket) -> Result<Bytes> {
142 if packet.key_id != self.key_id {
143 return Err(ProtocolError::KeyNotAvailable(packet.key_id.0));
144 }
145
146 let decrypted = self.decrypt_cipher.decrypt(&packet.payload)?;
147 Ok(Bytes::from(decrypted))
148 }
149}
150
/// Compression framing applied to data-channel payloads.
///
/// No compression algorithm is implemented here. The `*Stub` variants only
/// frame the payload with a one-byte header (see `Compression::add_header`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    /// No framing header is added.
    None,
    /// LZO stub: one header byte, payload left uncompressed.
    LzoStub,
    /// LZ4 stub: one header byte, payload left uncompressed.
    Lz4Stub,
}
161
162impl Compression {
163 pub fn is_compressed(byte: u8) -> bool {
165 byte == 0xFA || byte == 0xFB
169 }
170
171 pub fn strip_header(data: &[u8]) -> Result<&[u8]> {
173 if data.is_empty() {
174 return Ok(data);
175 }
176
177 match data[0] {
178 0xFA | 0xFB => {
179 Err(ProtocolError::InvalidPacket(
182 "compressed data not supported".into(),
183 ))
184 }
185 0x00 => {
186 Ok(&data[1..])
188 }
189 _ => {
190 Ok(data)
192 }
193 }
194 }
195
196 pub fn add_header(data: &[u8], comp: Compression) -> Vec<u8> {
198 match comp {
199 Compression::None => data.to_vec(),
200 Compression::LzoStub | Compression::Lz4Stub => {
201 let mut out = Vec::with_capacity(1 + data.len());
202 out.push(0x00); out.extend_from_slice(data);
204 out
205 }
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use corevpn_crypto::CipherSuite;

    #[test]
    fn test_data_packet_v1() {
        let packet = DataPacket::new(KeyId::new(1), Bytes::from_static(&[1, 2, 3, 4]));
        let serialized = packet.serialize();

        let parsed = DataPacket::parse(&serialized).unwrap();
        assert_eq!(parsed.key_id, KeyId::new(1));
        assert!(parsed.peer_id.is_none());
        assert_eq!(&parsed.payload[..], &[1, 2, 3, 4]);
    }

    #[test]
    fn test_data_packet_v2() {
        let packet = DataPacket::new_v2(KeyId::new(2), 12345, Bytes::from_static(&[5, 6, 7, 8]));
        let serialized = packet.serialize();

        let parsed = DataPacket::parse(&serialized).unwrap();
        assert_eq!(parsed.key_id, KeyId::new(2));
        assert_eq!(parsed.peer_id, Some(12345));
        assert_eq!(&parsed.payload[..], &[5, 6, 7, 8]);
    }

    #[test]
    fn test_data_packet_parse_short_input() {
        // An empty buffer has no opcode byte at all.
        assert!(DataPacket::parse(&[]).is_err());

        // A V2 header is exactly 4 bytes; an empty payload is still valid.
        let serialized = DataPacket::new_v2(KeyId::new(0), 7, Bytes::new()).serialize();
        assert_eq!(serialized.len(), 4);
        let parsed = DataPacket::parse(&serialized).unwrap();
        assert_eq!(parsed.peer_id, Some(7));
        assert!(parsed.payload.is_empty());

        // Truncating the V2 header must be rejected.
        assert!(DataPacket::parse(&serialized[..3]).is_err());
    }

    #[test]
    fn test_data_channel() {
        let key = || DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);

        let mut client = DataChannel::new(KeyId::new(0), key(), key(), false, None);
        let mut server = DataChannel::new(KeyId::new(0), key(), key(), false, None);

        let ip_packet = b"Hello, VPN!";
        let encrypted = client.encrypt(ip_packet).unwrap();

        assert_eq!(encrypted.key_id, KeyId::new(0));
        // Without V2 enabled, no peer id goes on the wire.
        assert!(encrypted.peer_id.is_none());

        // A packet for a key slot the server does not hold is rejected before decryption.
        let foreign = DataPacket::new(KeyId::new(1), encrypted.payload.clone());
        assert!(server.decrypt(&foreign).is_err());

        // With V2 enabled, the configured peer id is attached.
        let mut v2_client = DataChannel::new(KeyId::new(0), key(), key(), true, Some(9));
        let v2_packet = v2_client.encrypt(ip_packet).unwrap();
        assert_eq!(v2_packet.peer_id, Some(9));
    }

    #[test]
    fn test_compression_strip() {
        let data = [1, 2, 3, 4];
        assert_eq!(Compression::strip_header(&data).unwrap(), &[1, 2, 3, 4]);

        let data = [0x00, 1, 2, 3, 4];
        assert_eq!(Compression::strip_header(&data).unwrap(), &[1, 2, 3, 4]);

        let data = [0xFA, 1, 2, 3, 4];
        assert!(Compression::strip_header(&data).is_err());

        assert!(Compression::strip_header(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_compression_add_header() {
        assert_eq!(Compression::add_header(&[1, 2], Compression::None), vec![1, 2]);
        assert_eq!(
            Compression::add_header(&[1, 2], Compression::LzoStub),
            vec![0x00, 1, 2]
        );

        let framed = Compression::add_header(&[9], Compression::Lz4Stub);
        assert_eq!(Compression::strip_header(&framed).unwrap(), &[9]);
    }
}