1use bytes::{Bytes, BytesMut, BufMut};
4use corevpn_crypto::{DataChannelKey, PacketCipher};
5
6use crate::{KeyId, OpCode, ProtocolError, Result};
7
8#[derive(Debug, Clone)]
10pub struct DataPacket {
11 pub key_id: KeyId,
13 pub peer_id: Option<u32>,
15 pub payload: Bytes,
17}
18
19impl DataPacket {
20 pub fn new(key_id: KeyId, payload: Bytes) -> Self {
22 Self {
23 key_id,
24 peer_id: None,
25 payload,
26 }
27 }
28
29 pub fn new_v2(key_id: KeyId, peer_id: u32, payload: Bytes) -> Self {
31 Self {
32 key_id,
33 peer_id: Some(peer_id),
34 payload,
35 }
36 }
37
38 pub fn parse(data: &[u8]) -> Result<Self> {
40 if data.is_empty() {
41 return Err(ProtocolError::PacketTooShort {
42 expected: 1,
43 got: 0,
44 });
45 }
46
47 let opcode = OpCode::from_byte(data[0])?;
48 let key_id = KeyId::from_byte(data[0]);
49
50 let (peer_id, payload_start) = if opcode == OpCode::DataV2 {
51 if data.len() < 4 {
52 return Err(ProtocolError::PacketTooShort {
53 expected: 4,
54 got: data.len(),
55 });
56 }
57 let pid = ((data[1] as u32) << 16) | ((data[2] as u32) << 8) | (data[3] as u32);
58 (Some(pid), 4)
59 } else {
60 (None, 1)
61 };
62
63 Ok(Self {
64 key_id,
65 peer_id,
66 payload: Bytes::copy_from_slice(&data[payload_start..]),
67 })
68 }
69
70 pub fn serialize(&self) -> BytesMut {
72 let opcode = if self.peer_id.is_some() {
73 OpCode::DataV2
74 } else {
75 OpCode::DataV1
76 };
77
78 let mut buf = BytesMut::with_capacity(4 + self.payload.len());
79 buf.put_u8(opcode.to_byte(self.key_id));
80
81 if let Some(pid) = self.peer_id {
82 buf.put_u8((pid >> 16) as u8);
83 buf.put_u8((pid >> 8) as u8);
84 buf.put_u8(pid as u8);
85 }
86
87 buf.put_slice(&self.payload);
88 buf
89 }
90}
91
92pub struct DataChannel {
94 key_id: KeyId,
96 peer_id: Option<u32>,
98 encrypt_cipher: PacketCipher,
100 decrypt_cipher: PacketCipher,
102 use_v2: bool,
104 encrypt_ad_prefix: Vec<u8>,
106}
107
108impl DataChannel {
109 pub fn new(
111 key_id: KeyId,
112 encrypt_key: DataChannelKey,
113 decrypt_key: DataChannelKey,
114 use_v2: bool,
115 peer_id: Option<u32>,
116 ) -> Self {
117 let encrypt_ad_prefix = if use_v2 {
124 let opcode_byte = OpCode::DataV2.to_byte(key_id);
125 let pid = peer_id.unwrap_or(0);
126 vec![opcode_byte, (pid >> 16) as u8, (pid >> 8) as u8, pid as u8]
127 } else {
128 vec![]
130 };
131
132 Self {
133 key_id,
134 peer_id,
135 encrypt_cipher: PacketCipher::new(encrypt_key),
136 decrypt_cipher: PacketCipher::new(decrypt_key),
137 use_v2,
138 encrypt_ad_prefix,
139 }
140 }
141
142 pub fn key_id(&self) -> KeyId {
144 self.key_id
145 }
146
147 fn decrypt_ad_prefix(&self, packet: &DataPacket) -> Vec<u8> {
150 if let Some(pid) = packet.peer_id {
151 let opcode_byte = OpCode::DataV2.to_byte(packet.key_id);
152 vec![opcode_byte, (pid >> 16) as u8, (pid >> 8) as u8, pid as u8]
153 } else {
154 vec![]
156 }
157 }
158
159 pub fn encrypt(&mut self, ip_packet: &[u8]) -> Result<DataPacket> {
161 let encrypted = self.encrypt_cipher.encrypt(ip_packet, &self.encrypt_ad_prefix)?;
162
163 Ok(DataPacket {
164 key_id: self.key_id,
165 peer_id: if self.use_v2 { self.peer_id } else { None },
166 payload: Bytes::from(encrypted),
167 })
168 }
169
170 pub fn decrypt(&mut self, packet: &DataPacket) -> Result<Bytes> {
172 if packet.key_id != self.key_id {
173 return Err(ProtocolError::KeyNotAvailable(packet.key_id.0));
174 }
175
176 let ad_prefix = self.decrypt_ad_prefix(packet);
177 let decrypted = self.decrypt_cipher.decrypt(&packet.payload, &ad_prefix)?;
178 Ok(Bytes::from(decrypted))
179 }
180}
181
182#[derive(Debug, Clone, Copy, PartialEq, Eq)]
184pub enum Compression {
185 None,
187 LzoStub,
189 Lz4Stub,
191}
192
193impl Compression {
194 pub fn is_compressed(byte: u8) -> bool {
196 byte == 0xFA || byte == 0xFB
200 }
201
202 pub fn strip_header(data: &[u8]) -> Result<&[u8]> {
204 if data.is_empty() {
205 return Ok(data);
206 }
207
208 match data[0] {
209 0xFA | 0xFB => {
210 Err(ProtocolError::InvalidPacket(
213 "compressed data not supported".into(),
214 ))
215 }
216 0x00 => {
217 Ok(&data[1..])
219 }
220 _ => {
221 Ok(data)
223 }
224 }
225 }
226
227 pub fn add_header(data: &[u8], comp: Compression) -> Vec<u8> {
229 match comp {
230 Compression::None => data.to_vec(),
231 Compression::LzoStub | Compression::Lz4Stub => {
232 let mut out = Vec::with_capacity(1 + data.len());
233 out.push(0x00); out.extend_from_slice(data);
235 out
236 }
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use corevpn_crypto::CipherSuite;

    /// A V1 packet survives a serialize/parse round trip without a peer ID.
    #[test]
    fn test_data_packet_v1() {
        let packet = DataPacket::new(KeyId::new(1), Bytes::from_static(&[1, 2, 3, 4]));
        let serialized = packet.serialize();

        let parsed = DataPacket::parse(&serialized).unwrap();
        assert_eq!(parsed.key_id, KeyId::new(1));
        assert!(parsed.peer_id.is_none());
        assert_eq!(&parsed.payload[..], &[1, 2, 3, 4]);
    }

    /// A V2 packet survives a serialize/parse round trip with its peer ID.
    #[test]
    fn test_data_packet_v2() {
        let packet = DataPacket::new_v2(KeyId::new(2), 12345, Bytes::from_static(&[5, 6, 7, 8]));
        let serialized = packet.serialize();

        let parsed = DataPacket::parse(&serialized).unwrap();
        assert_eq!(parsed.key_id, KeyId::new(2));
        assert_eq!(parsed.peer_id, Some(12345));
        assert_eq!(&parsed.payload[..], &[5, 6, 7, 8]);
    }

    /// The largest peer ID that fits in the 3-byte wire field round-trips.
    #[test]
    fn test_data_packet_v2_max_peer_id() {
        let packet = DataPacket::new_v2(KeyId::new(0), 0x00FF_FFFF, Bytes::from_static(&[9]));
        let serialized = packet.serialize();

        // 1 header byte + 3 peer-ID bytes + 1 payload byte.
        assert_eq!(serialized.len(), 5);

        let parsed = DataPacket::parse(&serialized).unwrap();
        assert_eq!(parsed.peer_id, Some(0x00FF_FFFF));
        assert_eq!(&parsed.payload[..], &[9]);
    }

    /// Empty input is rejected rather than indexed.
    #[test]
    fn test_parse_rejects_empty_input() {
        assert!(DataPacket::parse(&[]).is_err());
    }

    /// A V2 header needs all four bytes; exactly four yields an empty payload.
    #[test]
    fn test_parse_v2_header_length() {
        let header = OpCode::DataV2.to_byte(KeyId::new(0));

        assert!(DataPacket::parse(&[header]).is_err());
        assert!(DataPacket::parse(&[header, 0, 0]).is_err());

        let parsed = DataPacket::parse(&[header, 0, 0, 0]).unwrap();
        assert_eq!(parsed.peer_id, Some(0));
        assert!(parsed.payload.is_empty());
    }

    /// Encrypting on a V1 channel yields a packet with the channel's key ID
    /// and no peer ID.
    #[test]
    fn test_data_channel() {
        let key1 = DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);
        let key2 = DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);
        let key3 = DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);
        let key4 = DataChannelKey::new([0x42u8; 32], CipherSuite::ChaCha20Poly1305);

        let mut client = DataChannel::new(KeyId::new(0), key1, key2, false, None);
        // Constructed only to check that a second channel can be built; the
        // decrypt direction is not asserted here.
        let _server = DataChannel::new(KeyId::new(0), key3, key4, false, None);

        let ip_packet = b"Hello, VPN!";
        let encrypted = client.encrypt(ip_packet).unwrap();

        assert_eq!(encrypted.key_id, KeyId::new(0));
        assert!(encrypted.peer_id.is_none());
    }

    /// `strip_header` passes plain data through, drops a `0x00` framing byte,
    /// rejects the compressed markers and tolerates empty input.
    #[test]
    fn test_compression_strip() {
        let data = [1, 2, 3, 4];
        assert_eq!(Compression::strip_header(&data).unwrap(), &[1, 2, 3, 4]);

        let data = [0x00, 1, 2, 3, 4];
        assert_eq!(Compression::strip_header(&data).unwrap(), &[1, 2, 3, 4]);

        let data = [0xFA, 1, 2, 3, 4];
        assert!(Compression::strip_header(&data).is_err());

        let data = [0xFB, 1, 2, 3, 4];
        assert!(Compression::strip_header(&data).is_err());

        assert!(Compression::strip_header(&[]).unwrap().is_empty());
    }

    /// `add_header` prepends `0x00` only for the stub modes.
    #[test]
    fn test_compression_add_header() {
        let data = [1u8, 2, 3];

        assert_eq!(Compression::add_header(&data, Compression::None), vec![1u8, 2, 3]);
        assert_eq!(
            Compression::add_header(&data, Compression::LzoStub),
            vec![0u8, 1, 2, 3]
        );
        assert_eq!(
            Compression::add_header(&data, Compression::Lz4Stub),
            vec![0u8, 1, 2, 3]
        );
    }
}