async_minecraft_ping/
protocol.rs

1//! This module defines various methods to read and
2//! write packets in Minecraft's
3//! [ServerListPing](https://wiki.vg/Server_List_Ping)
4//! protocol.
5
6use std::io::Cursor;
7use std::time::Duration;
8
9use async_trait::async_trait;
10use thiserror::Error;
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12
/// ProtocolError covers every failure that the packet reading and
/// writing helpers in this module can report.
#[derive(Error, Debug)]
pub enum ProtocolError {
    /// An underlying read or write failed. Converted automatically
    /// from `std::io::Error` by `?`.
    #[error("error reading or writing data")]
    Io(#[from] std::io::Error),

    /// A packet's length prefix cannot describe a valid packet
    /// (for example, a length of zero).
    #[error("invalid packet length")]
    InvalidPacketLength,

    /// A varint did not terminate within the number of bytes the
    /// reader allows.
    #[error("invalid varint data")]
    InvalidVarInt,

    /// The packet ID read from the stream differs from the ID the
    /// requested packet type expects.
    #[error("invalid packet (expected ID {expected:?}, actual ID {actual:?})")]
    InvalidPacketId { expected: usize, actual: usize },

    /// The bytes of a string were not valid UTF-8.
    #[error("invalid ServerListPing response body (invalid UTF-8)")]
    InvalidResponseBody,

    /// A `*_with_timeout` operation did not finish in time. Converted
    /// automatically from tokio's `Elapsed` by `?`.
    #[error("connection timed out")]
    Timeout(#[from] tokio::time::error::Elapsed),
}
33
/// State represents the desired next state of the
/// exchange.
///
/// It's a bit silly now as there's only
/// one entry, but technically there is more than
/// one type that can be sent here.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum State {
    /// Ask the server to switch to the status exchange.
    Status,
}

impl From<State> for usize {
    /// Returns the numeric value that is written to the wire for
    /// the given state (`Status` is `1`).
    fn from(state: State) -> Self {
        // Deliberately exhaustive (no `_` arm) so that adding a
        // variant forces this mapping to be updated.
        match state {
            State::Status => 1,
        }
    }
}
52
/// RawPacket pairs a packet ID with the already-serialized bytes of
/// that packet. It is the unit that is framed and written to the
/// socket.
///
/// The usual write flow is:
/// 1. Build a concrete packet (a HandshakePacket, for instance).
/// 2. Serialize that packet's fields into a byte buffer.
/// 3. Wrap the ID and the byte buffer in a RawPacket.
/// 4. Write the RawPacket to the socket.
struct RawPacket {
    /// Numeric ID identifying the kind of packet.
    id: usize,
    /// Serialized packet fields, without the ID or length prefix.
    data: Box<[u8]>,
}

impl RawPacket {
    /// Bundles `id` and `data` into a RawPacket without copying.
    fn new(id: usize, data: Box<[u8]>) -> Self {
        Self { id, data }
    }
}
72
/// AsyncWireReadExt adds varint and varint-backed
/// string support to things that implement AsyncRead.
#[async_trait]
pub trait AsyncWireReadExt {
    /// Reads one varint from the stream and returns its decoded value.
    async fn read_varint(&mut self) -> Result<usize, ProtocolError>;
    /// Reads a varint length prefix followed by that many bytes and
    /// returns those bytes as a UTF-8 string.
    async fn read_string(&mut self) -> Result<String, ProtocolError>;
}
80
81#[async_trait]
82impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
83    async fn read_varint(&mut self) -> Result<usize, ProtocolError> {
84        let mut read = 0;
85        let mut result = 0;
86        loop {
87            let read_value = self.read_u8().await?;
88            let value = read_value & 0b0111_1111;
89            result |= (value as usize) << (7 * read);
90            read += 1;
91            if read > 5 {
92                return Err(ProtocolError::InvalidVarInt);
93            }
94            if (read_value & 0b1000_0000) == 0 {
95                return Ok(result);
96            }
97        }
98    }
99
100    async fn read_string(&mut self) -> Result<String, ProtocolError> {
101        let length = self.read_varint().await?;
102
103        let mut buffer = vec![0; length];
104        self.read_exact(&mut buffer).await?;
105
106        Ok(String::from_utf8(buffer).map_err(|_| ProtocolError::InvalidResponseBody)?)
107    }
108}
109
/// AsyncWireWriteExt adds varint and varint-backed
/// string support to things that implement AsyncWrite.
#[async_trait]
pub trait AsyncWireWriteExt {
    /// Writes `int` to the stream encoded as a varint.
    async fn write_varint(&mut self, int: usize) -> Result<(), ProtocolError>;
    /// Writes the byte length of `string` as a varint, followed by
    /// the string's UTF-8 bytes.
    async fn write_string(&mut self, string: &str) -> Result<(), ProtocolError>;
}
117
118#[async_trait]
119impl<W: AsyncWrite + Unpin + Send + Sync> AsyncWireWriteExt for W {
120    async fn write_varint(&mut self, int: usize) -> Result<(), ProtocolError> {
121        let mut int = (int as u64) & 0xFFFF_FFFF;
122        let mut written = 0;
123        let mut buffer = [0; 5];
124        loop {
125            let temp = (int & 0b0111_1111) as u8;
126            int >>= 7;
127            if int != 0 {
128                buffer[written] = temp | 0b1000_0000;
129            } else {
130                buffer[written] = temp;
131            }
132            written += 1;
133            if int == 0 {
134                break;
135            }
136        }
137        self.write(&buffer[0..written]).await?;
138
139        Ok(())
140    }
141
142    async fn write_string(&mut self, string: &str) -> Result<(), ProtocolError> {
143        self.write_varint(string.len()).await?;
144        self.write_all(string.as_bytes()).await?;
145
146        Ok(())
147    }
148}
149
/// PacketId is used to allow AsyncWriteRawPacket
/// to generically get a packet's ID.
pub trait PacketId {
    /// Returns the ID that is written in front of this packet's data.
    fn get_packet_id(&self) -> usize;
}
155
/// ExpectedPacketId is used to allow AsyncReadRawPacket
/// to generically get a packet's expected ID.
pub trait ExpectedPacketId {
    /// Returns the ID an incoming packet must carry to be parsed as
    /// this type. It takes no `self` because it is consulted before
    /// any value of the type exists.
    fn get_expected_packet_id() -> usize;
}
161
/// AsyncReadFromBuffer is used to allow
/// AsyncReadRawPacket to generically read a
/// packet's specific data from a buffer.
#[async_trait]
pub trait AsyncReadFromBuffer: Sized {
    /// Parses a packet from `buffer`, which holds the packet's data
    /// with the length prefix and packet ID already removed.
    async fn read_from_buffer(buffer: Vec<u8>) -> Result<Self, ProtocolError>;
}
169
/// AsyncWriteToBuffer is used to allow
/// AsyncWriteRawPacket to generically write a
/// packet's specific data into a buffer.
#[async_trait]
pub trait AsyncWriteToBuffer {
    /// Serializes this packet's fields into a fresh buffer. The
    /// length prefix and packet ID are not included; the packet
    /// writer adds them.
    async fn write_to_buffer(&self) -> Result<Vec<u8>, ProtocolError>;
}
177
/// AsyncReadRawPacket is the core piece of
/// the read side of the protocol. It allows
/// the user to construct a specific packet
/// from something that implements AsyncRead.
#[async_trait]
pub trait AsyncReadRawPacket {
    /// Reads one length-prefixed packet and parses it as `T`,
    /// failing if the packet ID is not the one `T` expects.
    async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
        &mut self,
    ) -> Result<T, ProtocolError>;

    /// Same as `read_packet`, but gives up with
    /// `ProtocolError::Timeout` if the read has not completed
    /// within `timeout`.
    async fn read_packet_with_timeout<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
        &mut self,
        timeout: Duration,
    ) -> Result<T, ProtocolError>;
}
193
194#[async_trait]
195impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
196    async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
197        &mut self,
198    ) -> Result<T, ProtocolError> {
199        let length = self.read_varint().await?;
200
201        if length == 0 {
202            return Err(ProtocolError::InvalidPacketLength);
203        }
204
205        let packet_id = self.read_varint().await?;
206
207        let expected_packet_id = T::get_expected_packet_id();
208
209        if packet_id != expected_packet_id {
210            return Err(ProtocolError::InvalidPacketId {
211                expected: expected_packet_id,
212                actual: packet_id,
213            });
214        }
215
216        let mut buffer = vec![0; length - 1];
217        self.read_exact(&mut buffer).await?;
218
219        T::read_from_buffer(buffer).await
220    }
221
222    async fn read_packet_with_timeout<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
223        &mut self,
224        timeout: Duration,
225    ) -> Result<T, ProtocolError> {
226        tokio::time::timeout(timeout, self.read_packet()).await?
227    }
228}
229
/// AsyncWriteRawPacket is the core piece of
/// the write side of the protocol. It allows
/// the user to write a specific packet to
/// something that implements AsyncWrite.
#[async_trait]
pub trait AsyncWriteRawPacket {
    /// Serializes `packet` and writes it to the stream with its
    /// length prefix and packet ID.
    async fn write_packet<T: PacketId + AsyncWriteToBuffer + Send + Sync>(
        &mut self,
        packet: T,
    ) -> Result<(), ProtocolError>;

    /// Same as `write_packet`, but gives up with
    /// `ProtocolError::Timeout` if the write has not completed
    /// within `timeout`.
    async fn write_packet_with_timeout<T: PacketId + AsyncWriteToBuffer + Send + Sync>(
        &mut self,
        packet: T,
        timeout: Duration,
    ) -> Result<(), ProtocolError>;
}
247
248#[async_trait]
249impl<W: AsyncWrite + Unpin + Send + Sync> AsyncWriteRawPacket for W {
250    async fn write_packet<T: PacketId + AsyncWriteToBuffer + Send + Sync>(
251        &mut self,
252        packet: T,
253    ) -> Result<(), ProtocolError> {
254        let packet_buffer = packet.write_to_buffer().await?;
255
256        let raw_packet = RawPacket::new(packet.get_packet_id(), packet_buffer.into_boxed_slice());
257
258        let mut buffer: Cursor<Vec<u8>> = Cursor::new(Vec::new());
259
260        buffer.write_varint(raw_packet.id).await?;
261        buffer.write_all(&raw_packet.data).await?;
262
263        let inner = buffer.into_inner();
264        self.write_varint(inner.len()).await?;
265        self.write(&inner).await?;
266        Ok(())
267    }
268
269    async fn write_packet_with_timeout<T: PacketId + AsyncWriteToBuffer + Send + Sync>(
270        &mut self,
271        packet: T,
272        timeout: Duration,
273    ) -> Result<(), ProtocolError> {
274        tokio::time::timeout(timeout, self.write_packet(packet)).await?
275    }
276}
277
278/// HandshakePacket is the first of two packets
279/// to be sent during a status check for
280/// ServerListPing.
281pub struct HandshakePacket {
282    pub packet_id: usize,
283    pub protocol_version: usize,
284    pub server_address: String,
285    pub server_port: u16,
286    pub next_state: State,
287}
288
289impl HandshakePacket {
290    pub fn new(protocol_version: usize, server_address: String, server_port: u16) -> Self {
291        Self {
292            packet_id: 0,
293            protocol_version,
294            server_address,
295            server_port,
296            next_state: State::Status,
297        }
298    }
299}
300
301#[async_trait]
302impl AsyncWriteToBuffer for HandshakePacket {
303    async fn write_to_buffer(&self) -> Result<Vec<u8>, ProtocolError> {
304        let mut buffer = Cursor::new(Vec::<u8>::new());
305
306        buffer.write_varint(self.protocol_version).await?;
307        buffer.write_string(&self.server_address).await?;
308        buffer.write_u16(self.server_port).await?;
309        buffer.write_varint(self.next_state.into()).await?;
310
311        Ok(buffer.into_inner())
312    }
313}
314
impl PacketId for HandshakePacket {
    /// Returns the value of the `packet_id` field.
    fn get_packet_id(&self) -> usize {
        self.packet_id
    }
}
320
/// RequestPacket is the second of two packets
/// to be sent during a status check for
/// ServerListPing.
pub struct RequestPacket {
    /// Packet ID written in front of the (empty) data; `new` sets
    /// it to 0.
    pub packet_id: usize,
}

impl RequestPacket {
    /// Creates a request packet with packet ID 0.
    pub fn new() -> Self {
        Self { packet_id: 0 }
    }
}

impl Default for RequestPacket {
    /// Equivalent to `RequestPacket::new()`; provided because `new`
    /// takes no arguments.
    fn default() -> Self {
        Self::new()
    }
}
333
#[async_trait]
impl AsyncWriteToBuffer for RequestPacket {
    /// Returns an empty buffer: a request packet has no fields
    /// besides its packet ID.
    async fn write_to_buffer(&self) -> Result<Vec<u8>, ProtocolError> {
        Ok(Vec::new())
    }
}
340
impl PacketId for RequestPacket {
    /// Returns the value of the `packet_id` field.
    fn get_packet_id(&self) -> usize {
        self.packet_id
    }
}
346
/// ResponsePacket is the response from the
/// server to a status check for
/// ServerListPing.
pub struct ResponsePacket {
    /// Packet ID of the response (set to 0 when read from a buffer).
    pub packet_id: usize,
    /// The string carried in the packet's data.
    pub body: String,
}
354
impl ExpectedPacketId for ResponsePacket {
    /// A response packet is expected to carry packet ID 0.
    fn get_expected_packet_id() -> usize {
        0
    }
}
360
361#[async_trait]
362impl AsyncReadFromBuffer for ResponsePacket {
363    async fn read_from_buffer(buffer: Vec<u8>) -> Result<Self, ProtocolError> {
364        let mut reader = Cursor::new(buffer);
365
366        let body = reader.read_string().await?;
367
368        Ok(ResponsePacket { packet_id: 0, body })
369    }
370}
371
/// PingPacket carries a single 64-bit payload and is written with
/// packet ID 1.
pub struct PingPacket {
    /// Packet ID written in front of the data (`new` sets it to 1).
    pub packet_id: usize,
    /// Value sent as the packet's data.
    pub payload: u64,
}

impl PingPacket {
    /// Creates a ping packet with packet ID 1 carrying `payload`.
    pub fn new(payload: u64) -> Self {
        PingPacket { packet_id: 1, payload }
    }
}
385
386#[async_trait]
387impl AsyncWriteToBuffer for PingPacket {
388    async fn write_to_buffer(&self) -> Result<Vec<u8>, ProtocolError> {
389        let mut buffer = Cursor::new(Vec::<u8>::new());
390
391        buffer.write_u64(self.payload).await?;
392
393        Ok(buffer.into_inner())
394    }
395}
396
impl PacketId for PingPacket {
    /// Returns the value of the `packet_id` field.
    fn get_packet_id(&self) -> usize {
        self.packet_id
    }
}
402
/// PongPacket is a packet read from the server whose data is a
/// single 64-bit payload.
pub struct PongPacket {
    /// Packet ID recorded for the pong.
    pub packet_id: usize,
    /// The 64-bit value read from the packet's data.
    pub payload: u64,
}
407
impl ExpectedPacketId for PongPacket {
    /// A pong packet is expected to carry packet ID 1.
    fn get_expected_packet_id() -> usize {
        1
    }
}
413
414#[async_trait]
415impl AsyncReadFromBuffer for PongPacket {
416    async fn read_from_buffer(buffer: Vec<u8>) -> Result<Self, ProtocolError> {
417        let mut reader = Cursor::new(buffer);
418
419        let payload = reader.read_u64().await?;
420
421        Ok(PongPacket {
422            packet_id: 0,
423            payload,
424        })
425    }
426}