1use super::AckHeader;
2use crate::{
3 buffer_pool::{BufHandle, BufPool},
4 cursor::{BufferLimitedWriter, CursorExtras},
5 prelude::PicklebackConfig,
6 PacketId, PicklebackError,
7};
8use byteorder::*;
9use std::{
10 io::{Cursor, Write},
11 net::SocketAddr,
12};
13
14use super::DisconnectReason;
15
16#[derive(Clone, Eq, PartialEq)]
17pub struct AddressedPacket {
18 pub address: SocketAddr,
19 pub packet: BufHandle,
20}
21
22#[derive(Debug)]
23pub(crate) enum ProtocolPacket {
24 ConnectionRequest(ConnectionRequestPacket),
26 ConnectionChallenge(ConnectionChallengePacket),
28 ConnectionChallengeResponse(ConnectionChallengeResponsePacket),
30 ConnectionDenied(ConnectionDeniedPacket),
32 Messages(MessagesPacket),
34 Disconnect(DisconnectPacket),
36 KeepAlive(KeepAlivePacket),
38}
39
/// Wire discriminant identifying each [`ProtocolPacket`] variant.
///
/// `ProtocolPacketHeader::write` stores this value in the low 7 bits of the prefix byte and
/// uses the high bit as the "ack header present" flag, so every discriminant must stay
/// below `0x80`.
#[derive(Clone, Copy, Debug)]
#[repr(u8)]
pub(crate) enum PacketType {
    ConnectionRequest = 1,
    ConnectionChallenge = 2,
    ConnectionChallengeResponse = 3,
    ConnectionDenied = 4,
    Messages = 5,
    Disconnect = 6,
    KeepAlive = 7,
}
51
52impl TryFrom<u8> for PacketType {
53 type Error = PicklebackError;
54
55 fn try_from(value: u8) -> Result<Self, Self::Error> {
56 match value {
57 1 => Ok(PacketType::ConnectionRequest),
58 2 => Ok(PacketType::ConnectionChallenge),
59 3 => Ok(PacketType::ConnectionChallengeResponse),
60 4 => Ok(PacketType::ConnectionDenied),
61 5 => Ok(PacketType::Messages),
62 6 => Ok(PacketType::Disconnect),
63 7 => Ok(PacketType::KeepAlive),
64 _ => Err(PicklebackError::InvalidPacket),
65 }
66 }
67}
68
69impl From<&ProtocolPacket> for PacketType {
70 fn from(val: &ProtocolPacket) -> Self {
71 match val {
72 ProtocolPacket::ConnectionRequest(_) => PacketType::ConnectionRequest,
73 ProtocolPacket::ConnectionChallenge(_) => PacketType::ConnectionChallenge,
74 ProtocolPacket::ConnectionChallengeResponse(_) => {
75 PacketType::ConnectionChallengeResponse
76 }
77 ProtocolPacket::ConnectionDenied(_) => PacketType::ConnectionDenied,
78 ProtocolPacket::Messages(_) => PacketType::Messages,
79 ProtocolPacket::Disconnect(_) => PacketType::Disconnect,
80 ProtocolPacket::KeepAlive(_) => PacketType::KeepAlive,
81 }
82 }
83}
84
85#[derive(Debug)]
86pub(crate) struct ProtocolPacketHeader {
87 pub(crate) packet_type: PacketType,
88 pub(crate) id: PacketId,
89 pub(crate) ack_header: Option<AckHeader>,
90}
91
92impl ProtocolPacketHeader {
93 pub(crate) fn new(
94 id: PacketId,
95 ack_iter: impl Iterator<Item = (u16, bool)>,
96 num_acks: u16,
97 packet_type: PacketType,
98 ) -> Result<Self, PicklebackError> {
99 if num_acks == 0 {
100 return Self::new_no_acks(id, packet_type);
101 }
102 let ack_header = AckHeader::from_ack_iter(num_acks, ack_iter)?;
103 Ok(Self {
104 packet_type,
105 id,
106 ack_header: Some(ack_header),
107 })
108 }
109 pub(crate) fn new_no_acks(
110 id: PacketId,
111 packet_type: PacketType,
112 ) -> Result<Self, PicklebackError> {
113 Ok(Self {
114 packet_type,
115 id,
116 ack_header: None,
117 })
118 }
119
120 pub(crate) fn id(&self) -> PacketId {
121 self.id
122 }
123 pub(crate) fn ack_id(&self) -> Option<PacketId> {
124 self.ack_header.map(|header| header.ack_id())
125 }
126 pub(crate) fn acks(&self) -> Option<impl Iterator<Item = (u16, bool)>> {
127 self.ack_header.map(|header| header.into_iter())
128 }
129
130 #[allow(unused)]
131 pub(crate) fn ack_header(&self) -> Option<&AckHeader> {
132 self.ack_header.as_ref()
133 }
134
135 pub(crate) fn size(&self) -> usize {
136 1 + 2 + self.ack_header.map_or(0, |header| header.size())
139 }
140
141 pub(crate) fn write(&self, mut writer: &mut impl Write) -> Result<(), PicklebackError> {
142 let mut prefix_byte = self.packet_type as u8;
143 if self.ack_header.is_some() {
144 prefix_byte |= 0b1000_0000;
146 }
147 writer.write_u8(prefix_byte)?;
148 writer.write_u16::<NetworkEndian>(self.id.0)?;
149 if let Some(ack_header) = self.ack_header {
150 ack_header.write(&mut writer)?;
151 }
152 Ok(())
153 }
154 pub(crate) fn parse(reader: &mut Cursor<&[u8]>) -> Result<Self, PicklebackError> {
155 let prefix_byte = reader.read_u8()?;
156 let ack_header_present = prefix_byte & 0b1000_0000 != 0;
157 let Ok(packet_type) = PacketType::try_from(prefix_byte & 0b0111_1111) else {
158 log::error!("prefix byte packet type invalid");
159 return Err(PicklebackError::InvalidPacket);
160 };
161 let id = PacketId(reader.read_u16::<NetworkEndian>()?);
162 let ack_header = if ack_header_present {
163 Some(AckHeader::parse(reader)?)
164 } else {
165 None
166 };
167 Ok(Self {
168 packet_type,
169 id,
170 ack_header,
171 })
172 }
173}
174
175#[derive(Debug)]
177pub(crate) struct ConnectionRequestPacket {
178 pub(crate) header: ProtocolPacketHeader,
179 pub(crate) client_salt: u64,
180 pub(crate) protocol_version: u64,
181 }
183
184#[derive(Debug)]
186pub(crate) struct ConnectionChallengePacket {
187 pub(crate) header: ProtocolPacketHeader,
188 pub(crate) client_salt: u64,
189 pub(crate) server_salt: u64,
190}
191
192#[derive(Debug)]
194pub(crate) struct ConnectionChallengeResponsePacket {
195 pub(crate) header: ProtocolPacketHeader,
196 pub(crate) xor_salt: u64,
197}
198
199#[derive(Debug)]
201pub(crate) struct ConnectionDeniedPacket {
202 pub(crate) header: ProtocolPacketHeader,
203 pub(crate) reason: DisconnectReason,
204}
205
206pub(crate) struct MessagesPacket {
208 pub(crate) header: ProtocolPacketHeader,
209 pub(crate) xor_salt: u64,
210 }
212
213impl std::fmt::Debug for MessagesPacket {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 write!(
216 f,
217 "MessagesPacket[header: {:?} xor_salt: {}]]",
218 self.header, self.xor_salt
219 )
220 }
221}
222
223#[derive(Debug)]
225pub(crate) struct DisconnectPacket {
226 pub(crate) header: ProtocolPacketHeader,
227 pub(crate) xor_salt: u64,
228}
229
230#[derive(Debug)]
232pub(crate) struct KeepAlivePacket {
233 pub(crate) header: ProtocolPacketHeader,
234 pub(crate) xor_salt: u64,
235 pub(crate) client_index: u32,
236}
237
/// Writes `num_bytes` zero bytes to `writer` through a single `write_all` call.
///
/// `write_packet` uses this to pad handshake packets. For `num_bytes == 0` it calls `write_all`
/// with an empty slice.
///
/// # Errors
/// Propagates any I/O error reported by `writer`.
pub(crate) fn write_zero_bytes<W: Write>(writer: &mut W, num_bytes: usize) -> std::io::Result<()> {
    writer.write_all(&vec![0_u8; num_bytes])
}
242
243pub(crate) fn write_packet(
244 pool: &BufPool,
245 config: &PicklebackConfig,
246 packet: ProtocolPacket,
247) -> Result<BufHandle, PicklebackError> {
248 let max_packet_size = config.max_packet_size;
249 let mut buffer = pool.get_buffer(max_packet_size);
250 let mut writer = BufferLimitedWriter::new(Cursor::new(&mut buffer), max_packet_size);
251
252 match packet {
253 ProtocolPacket::KeepAlive(KeepAlivePacket {
254 header,
255 xor_salt,
256 client_index,
257 }) => {
258 header.write(&mut writer)?;
259 writer.write_u64::<NetworkEndian>(xor_salt)?;
260 writer.write_u32::<NetworkEndian>(client_index)?;
261 }
262 ProtocolPacket::ConnectionRequest(ConnectionRequestPacket {
263 header,
264 client_salt,
265 protocol_version,
266 }) => {
267 header.write(&mut writer)?;
268 writer.write_u64::<NetworkEndian>(client_salt)?;
269 writer.write_u64::<NetworkEndian>(protocol_version)?;
270 write_zero_bytes(&mut writer, 500)?;
271 }
272 ProtocolPacket::ConnectionChallenge(ConnectionChallengePacket {
273 header,
274 client_salt,
275 server_salt,
276 }) => {
277 header.write(&mut writer)?;
278 writer.write_u64::<NetworkEndian>(client_salt)?;
279 writer.write_u64::<NetworkEndian>(server_salt)?;
280 write_zero_bytes(&mut writer, 500)?;
281 }
282 ProtocolPacket::ConnectionChallengeResponse(ConnectionChallengeResponsePacket {
283 header,
284 xor_salt,
285 }) => {
286 header.write(&mut writer)?;
287 writer.write_u64::<NetworkEndian>(xor_salt)?;
288 write_zero_bytes(&mut writer, 500)?;
289 }
290 ProtocolPacket::ConnectionDenied(ConnectionDeniedPacket { header, reason }) => {
291 header.write(&mut writer)?;
292 writer.write_u8(reason as u8)?;
293 }
294 ProtocolPacket::Disconnect(DisconnectPacket { header, xor_salt }) => {
295 header.write(&mut writer)?;
296 writer.write_u64::<NetworkEndian>(xor_salt)?;
297 }
298 ProtocolPacket::Messages(MessagesPacket { .. }) => {
299 panic!("written elsewhere");
301 }
302 }
303 Ok(buffer)
304}
305
306pub(crate) fn read_packet(reader: &mut Cursor<&[u8]>) -> Result<ProtocolPacket, PicklebackError> {
307 let header = ProtocolPacketHeader::parse(reader)?;
308 match header.packet_type {
309 PacketType::KeepAlive => {
310 let c = KeepAlivePacket {
311 header,
312 xor_salt: reader.read_u64::<NetworkEndian>()?,
313 client_index: reader.read_u32::<NetworkEndian>()?,
314 };
315 Ok(ProtocolPacket::KeepAlive(c))
316 }
317 PacketType::ConnectionRequest => {
318 let c = ConnectionRequestPacket {
319 header,
320 client_salt: reader.read_u64::<NetworkEndian>()?,
321 protocol_version: reader.read_u64::<NetworkEndian>()?,
322 };
323 if reader.remaining() != 500 {
324 log::warn!("Invalid remaining len for ConnectionRequestPacket");
325 return Err(PicklebackError::InvalidPacket);
326 }
327 Ok(ProtocolPacket::ConnectionRequest(c))
328 }
329 PacketType::ConnectionChallenge => {
330 let c = ConnectionChallengePacket {
331 header,
332 client_salt: reader.read_u64::<NetworkEndian>()?,
333 server_salt: reader.read_u64::<NetworkEndian>()?,
334 };
335 if reader.remaining() != 500 {
336 log::warn!("Invalid remaining len for ConnectionChallengePacket");
337 return Err(PicklebackError::InvalidPacket);
338 }
339 Ok(ProtocolPacket::ConnectionChallenge(c))
340 }
341 PacketType::ConnectionChallengeResponse => {
342 let c = ConnectionChallengeResponsePacket {
343 header,
344 xor_salt: reader.read_u64::<NetworkEndian>()?,
345 };
346 if reader.remaining() != 500 {
347 log::warn!("Invalid remaining len for ConnectionChallengeResponsePacket");
348 return Err(PicklebackError::InvalidPacket);
349 }
350 Ok(ProtocolPacket::ConnectionChallengeResponse(c))
351 }
352 PacketType::ConnectionDenied => {
353 let c = ConnectionDeniedPacket {
354 header,
355 reason: DisconnectReason::try_from(reader.read_u8()?)?,
356 };
357 Ok(ProtocolPacket::ConnectionDenied(c))
358 }
359 PacketType::Messages => {
360 let xor_salt = reader.read_u64::<NetworkEndian>()?;
361 let c = MessagesPacket { header, xor_salt };
369 Ok(ProtocolPacket::Messages(c))
370 }
371 PacketType::Disconnect => {
372 let c = DisconnectPacket {
373 header,
374 xor_salt: reader.read_u64::<NetworkEndian>()?,
375 };
376 Ok(ProtocolPacket::Disconnect(c))
377 }
378 }
379}
380
381