// pickleback/protocol/packets.rs

1use super::AckHeader;
2use crate::{
3    buffer_pool::{BufHandle, BufPool},
4    cursor::{BufferLimitedWriter, CursorExtras},
5    prelude::PicklebackConfig,
6    PacketId, PicklebackError,
7};
8use byteorder::*;
9use std::{
10    io::{Cursor, Write},
11    net::SocketAddr,
12};
13
14use super::DisconnectReason;
15
16#[derive(Clone, Eq, PartialEq)]
17pub struct AddressedPacket {
18    pub address: SocketAddr,
19    pub packet: BufHandle,
20}
21
22#[derive(Debug)]
23pub(crate) enum ProtocolPacket {
24    // 1 - C2S
25    ConnectionRequest(ConnectionRequestPacket),
26    // 2 - S2C
27    ConnectionChallenge(ConnectionChallengePacket),
28    // 3 - C2S
29    ConnectionChallengeResponse(ConnectionChallengeResponsePacket),
30    // 4 - S2C
31    ConnectionDenied(ConnectionDeniedPacket),
32    // 5 - Any
33    Messages(MessagesPacket),
34    // 6 - Any (server can kick you, or you can gracefully exit)
35    Disconnect(DisconnectPacket),
36    // 7 - Any
37    KeepAlive(KeepAlivePacket),
38}
39
/// Wire tag identifying a packet's kind.
///
/// The discriminant is the value stored in the low seven bits of the header's
/// prefix byte.
#[derive(Debug, Copy, Clone)]
#[repr(u8)]
pub(crate) enum PacketType {
    /// Client asks to connect.
    ConnectionRequest = 1,
    /// Server challenges a connecting client.
    ConnectionChallenge = 2,
    /// Client answers the server's challenge.
    ConnectionChallengeResponse = 3,
    /// Server refuses a connection.
    ConnectionDenied = 4,
    /// Message-carrying packet.
    Messages = 5,
    /// Connection teardown.
    Disconnect = 6,
    /// Keep-alive.
    KeepAlive = 7,
}
51
52impl TryFrom<u8> for PacketType {
53    type Error = PicklebackError;
54
55    fn try_from(value: u8) -> Result<Self, Self::Error> {
56        match value {
57            1 => Ok(PacketType::ConnectionRequest),
58            2 => Ok(PacketType::ConnectionChallenge),
59            3 => Ok(PacketType::ConnectionChallengeResponse),
60            4 => Ok(PacketType::ConnectionDenied),
61            5 => Ok(PacketType::Messages),
62            6 => Ok(PacketType::Disconnect),
63            7 => Ok(PacketType::KeepAlive),
64            _ => Err(PicklebackError::InvalidPacket),
65        }
66    }
67}
68
69impl From<&ProtocolPacket> for PacketType {
70    fn from(val: &ProtocolPacket) -> Self {
71        match val {
72            ProtocolPacket::ConnectionRequest(_) => PacketType::ConnectionRequest,
73            ProtocolPacket::ConnectionChallenge(_) => PacketType::ConnectionChallenge,
74            ProtocolPacket::ConnectionChallengeResponse(_) => {
75                PacketType::ConnectionChallengeResponse
76            }
77            ProtocolPacket::ConnectionDenied(_) => PacketType::ConnectionDenied,
78            ProtocolPacket::Messages(_) => PacketType::Messages,
79            ProtocolPacket::Disconnect(_) => PacketType::Disconnect,
80            ProtocolPacket::KeepAlive(_) => PacketType::KeepAlive,
81        }
82    }
83}
84
85#[derive(Debug)]
86pub(crate) struct ProtocolPacketHeader {
87    pub(crate) packet_type: PacketType,
88    pub(crate) id: PacketId,
89    pub(crate) ack_header: Option<AckHeader>,
90}
91
92impl ProtocolPacketHeader {
93    pub(crate) fn new(
94        id: PacketId,
95        ack_iter: impl Iterator<Item = (u16, bool)>,
96        num_acks: u16,
97        packet_type: PacketType,
98    ) -> Result<Self, PicklebackError> {
99        if num_acks == 0 {
100            return Self::new_no_acks(id, packet_type);
101        }
102        let ack_header = AckHeader::from_ack_iter(num_acks, ack_iter)?;
103        Ok(Self {
104            packet_type,
105            id,
106            ack_header: Some(ack_header),
107        })
108    }
109    pub(crate) fn new_no_acks(
110        id: PacketId,
111        packet_type: PacketType,
112    ) -> Result<Self, PicklebackError> {
113        Ok(Self {
114            packet_type,
115            id,
116            ack_header: None,
117        })
118    }
119
120    pub(crate) fn id(&self) -> PacketId {
121        self.id
122    }
123    pub(crate) fn ack_id(&self) -> Option<PacketId> {
124        self.ack_header.map(|header| header.ack_id())
125    }
126    pub(crate) fn acks(&self) -> Option<impl Iterator<Item = (u16, bool)>> {
127        self.ack_header.map(|header| header.into_iter())
128    }
129
130    #[allow(unused)]
131    pub(crate) fn ack_header(&self) -> Option<&AckHeader> {
132        self.ack_header.as_ref()
133    }
134
135    pub(crate) fn size(&self) -> usize {
136        1 + // prefix byte
137        2 + // packet sequence id
138        self.ack_header.map_or(0, |header| header.size())
139    }
140
141    pub(crate) fn write(&self, mut writer: &mut impl Write) -> Result<(), PicklebackError> {
142        let mut prefix_byte = self.packet_type as u8;
143        if self.ack_header.is_some() {
144            // highest bit denotes presence of ack header
145            prefix_byte |= 0b1000_0000;
146        }
147        writer.write_u8(prefix_byte)?;
148        writer.write_u16::<NetworkEndian>(self.id.0)?;
149        if let Some(ack_header) = self.ack_header {
150            ack_header.write(&mut writer)?;
151        }
152        Ok(())
153    }
154    pub(crate) fn parse(reader: &mut Cursor<&[u8]>) -> Result<Self, PicklebackError> {
155        let prefix_byte = reader.read_u8()?;
156        let ack_header_present = prefix_byte & 0b1000_0000 != 0;
157        let Ok(packet_type) = PacketType::try_from(prefix_byte & 0b0111_1111) else {
158            log::error!("prefix byte packet type invalid");
159            return Err(PicklebackError::InvalidPacket);
160        };
161        let id = PacketId(reader.read_u16::<NetworkEndian>()?);
162        let ack_header = if ack_header_present {
163            Some(AckHeader::parse(reader)?)
164        } else {
165            None
166        };
167        Ok(Self {
168            packet_type,
169            id,
170            ack_header,
171        })
172    }
173}
174
175// C2S
176#[derive(Debug)]
177pub(crate) struct ConnectionRequestPacket {
178    pub(crate) header: ProtocolPacketHeader,
179    pub(crate) client_salt: u64,
180    pub(crate) protocol_version: u64,
181    // TODO protocol version, so server can reject unsupported versions.
182}
183
184// S2C
185#[derive(Debug)]
186pub(crate) struct ConnectionChallengePacket {
187    pub(crate) header: ProtocolPacketHeader,
188    pub(crate) client_salt: u64,
189    pub(crate) server_salt: u64,
190}
191
192// C2S
193#[derive(Debug)]
194pub(crate) struct ConnectionChallengeResponsePacket {
195    pub(crate) header: ProtocolPacketHeader,
196    pub(crate) xor_salt: u64,
197}
198
199// S2C
200#[derive(Debug)]
201pub(crate) struct ConnectionDeniedPacket {
202    pub(crate) header: ProtocolPacketHeader,
203    pub(crate) reason: DisconnectReason,
204}
205
206// Bidirectional
207pub(crate) struct MessagesPacket {
208    pub(crate) header: ProtocolPacketHeader,
209    pub(crate) xor_salt: u64,
210    // pub(crate) messages: Vec<Message>, // Box<dyn Iterator<Item = Result<Message, PicklebackError>> + 'a>, //Vec<Message>,
211}
212
213impl std::fmt::Debug for MessagesPacket {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        write!(
216            f,
217            "MessagesPacket[header: {:?} xor_salt: {}]]",
218            self.header, self.xor_salt
219        )
220    }
221}
222
223// Bi?
224#[derive(Debug)]
225pub(crate) struct DisconnectPacket {
226    pub(crate) header: ProtocolPacketHeader,
227    pub(crate) xor_salt: u64,
228}
229
230// Bi?
231#[derive(Debug)]
232pub(crate) struct KeepAlivePacket {
233    pub(crate) header: ProtocolPacketHeader,
234    pub(crate) xor_salt: u64,
235    pub(crate) client_index: u32,
236}
237
/// Writes `num_bytes` zero bytes to `writer`.
///
/// Zeros are copied in chunks from a static buffer, so no heap allocation happens
/// whatever `num_bytes` is. Writing `0` bytes is a no-op.
///
/// # Errors
/// Returns the first error from `write_all`. Bytes written before the error stay written.
pub(crate) fn write_zero_bytes<W: Write>(writer: &mut W, num_bytes: usize) -> std::io::Result<()> {
    // A static (rather than a const) keeps the zeros in read-only memory instead of
    // materialising a temporary array on every call.
    static ZEROS: [u8; 512] = [0; 512];
    let mut remaining = num_bytes;
    while remaining > 0 {
        let chunk = remaining.min(ZEROS.len());
        writer.write_all(&ZEROS[..chunk])?;
        remaining -= chunk;
    }
    Ok(())
}
242
243pub(crate) fn write_packet(
244    pool: &BufPool,
245    config: &PicklebackConfig,
246    packet: ProtocolPacket,
247) -> Result<BufHandle, PicklebackError> {
248    let max_packet_size = config.max_packet_size;
249    let mut buffer = pool.get_buffer(max_packet_size);
250    let mut writer = BufferLimitedWriter::new(Cursor::new(&mut buffer), max_packet_size);
251
252    match packet {
253        ProtocolPacket::KeepAlive(KeepAlivePacket {
254            header,
255            xor_salt,
256            client_index,
257        }) => {
258            header.write(&mut writer)?;
259            writer.write_u64::<NetworkEndian>(xor_salt)?;
260            writer.write_u32::<NetworkEndian>(client_index)?;
261        }
262        ProtocolPacket::ConnectionRequest(ConnectionRequestPacket {
263            header,
264            client_salt,
265            protocol_version,
266        }) => {
267            header.write(&mut writer)?;
268            writer.write_u64::<NetworkEndian>(client_salt)?;
269            writer.write_u64::<NetworkEndian>(protocol_version)?;
270            write_zero_bytes(&mut writer, 500)?;
271        }
272        ProtocolPacket::ConnectionChallenge(ConnectionChallengePacket {
273            header,
274            client_salt,
275            server_salt,
276        }) => {
277            header.write(&mut writer)?;
278            writer.write_u64::<NetworkEndian>(client_salt)?;
279            writer.write_u64::<NetworkEndian>(server_salt)?;
280            write_zero_bytes(&mut writer, 500)?;
281        }
282        ProtocolPacket::ConnectionChallengeResponse(ConnectionChallengeResponsePacket {
283            header,
284            xor_salt,
285        }) => {
286            header.write(&mut writer)?;
287            writer.write_u64::<NetworkEndian>(xor_salt)?;
288            write_zero_bytes(&mut writer, 500)?;
289        }
290        ProtocolPacket::ConnectionDenied(ConnectionDeniedPacket { header, reason }) => {
291            header.write(&mut writer)?;
292            writer.write_u8(reason as u8)?;
293        }
294        ProtocolPacket::Disconnect(DisconnectPacket { header, xor_salt }) => {
295            header.write(&mut writer)?;
296            writer.write_u64::<NetworkEndian>(xor_salt)?;
297        }
298        ProtocolPacket::Messages(MessagesPacket { .. }) => {
299            // written in messages layer
300            panic!("written elsewhere");
301        }
302    }
303    Ok(buffer)
304}
305
306pub(crate) fn read_packet(reader: &mut Cursor<&[u8]>) -> Result<ProtocolPacket, PicklebackError> {
307    let header = ProtocolPacketHeader::parse(reader)?;
308    match header.packet_type {
309        PacketType::KeepAlive => {
310            let c = KeepAlivePacket {
311                header,
312                xor_salt: reader.read_u64::<NetworkEndian>()?,
313                client_index: reader.read_u32::<NetworkEndian>()?,
314            };
315            Ok(ProtocolPacket::KeepAlive(c))
316        }
317        PacketType::ConnectionRequest => {
318            let c = ConnectionRequestPacket {
319                header,
320                client_salt: reader.read_u64::<NetworkEndian>()?,
321                protocol_version: reader.read_u64::<NetworkEndian>()?,
322            };
323            if reader.remaining() != 500 {
324                log::warn!("Invalid remaining len for ConnectionRequestPacket");
325                return Err(PicklebackError::InvalidPacket);
326            }
327            Ok(ProtocolPacket::ConnectionRequest(c))
328        }
329        PacketType::ConnectionChallenge => {
330            let c = ConnectionChallengePacket {
331                header,
332                client_salt: reader.read_u64::<NetworkEndian>()?,
333                server_salt: reader.read_u64::<NetworkEndian>()?,
334            };
335            if reader.remaining() != 500 {
336                log::warn!("Invalid remaining len for ConnectionChallengePacket");
337                return Err(PicklebackError::InvalidPacket);
338            }
339            Ok(ProtocolPacket::ConnectionChallenge(c))
340        }
341        PacketType::ConnectionChallengeResponse => {
342            let c = ConnectionChallengeResponsePacket {
343                header,
344                xor_salt: reader.read_u64::<NetworkEndian>()?,
345            };
346            if reader.remaining() != 500 {
347                log::warn!("Invalid remaining len for ConnectionChallengeResponsePacket");
348                return Err(PicklebackError::InvalidPacket);
349            }
350            Ok(ProtocolPacket::ConnectionChallengeResponse(c))
351        }
352        PacketType::ConnectionDenied => {
353            let c = ConnectionDeniedPacket {
354                header,
355                reason: DisconnectReason::try_from(reader.read_u8()?)?,
356            };
357            Ok(ProtocolPacket::ConnectionDenied(c))
358        }
359        PacketType::Messages => {
360            let xor_salt = reader.read_u64::<NetworkEndian>()?;
361            // let mut messages = Vec::new();
362            // while reader.remaining() > 0 {
363            //     // as long as there are bytes left to read, we should only find whole messages
364            //     messages.push(Message::parse(pool, reader)?);
365            // }
366            // up to caller to parse payload from cursor and extract messages.
367            // TODO enclose with message iterator?
368            let c = MessagesPacket { header, xor_salt };
369            Ok(ProtocolPacket::Messages(c))
370        }
371        PacketType::Disconnect => {
372            let c = DisconnectPacket {
373                header,
374                xor_salt: reader.read_u64::<NetworkEndian>()?,
375            };
376            Ok(ProtocolPacket::Disconnect(c))
377        }
378    }
379}
380
381// struct MessageIterator<'a> {
382//     reader: &'a mut Cursor<&'a [u8]>,
383//     pool: &'a BufPool,
384// }
385
386// impl<'a> Iterator for MessageIterator<'a> {
387//     type Item = Result<Message, PicklebackError>;
388
389//     fn next(&mut self) -> Option<Self::Item> {
390//         if self.reader.remaining() > 0 {
391//             Some(Message::parse(self.pool, self.reader))
392//         } else {
393//             None
394//         }
395//     }
396// }