1use std::io::Cursor;
7use std::time::Duration;
8
9use async_trait::async_trait;
10use thiserror::Error;
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12
/// Errors produced while encoding, decoding, sending or receiving protocol packets.
#[derive(Error, Debug)]
pub enum ProtocolError {
    /// An underlying read or write on the stream/buffer failed.
    #[error("error reading or writing data")]
    Io(#[from] std::io::Error),

    /// A packet frame announced a length the reader rejects (e.g. zero).
    #[error("invalid packet length")]
    InvalidPacketLength,

    /// A VarInt did not terminate within the number of bytes the reader accepts.
    #[error("invalid varint data")]
    InvalidVarInt,

    /// A received packet's ID did not match the ID of the type the caller asked to read.
    #[error("invalid packet (expected ID {expected:?}, actual ID {actual:?})")]
    InvalidPacketId { expected: usize, actual: usize },

    /// A length-prefixed string (read via `read_string`) was not valid UTF-8.
    #[error("invalid ServerListPing response body (invalid UTF-8)")]
    InvalidResponseBody,

    /// A `*_with_timeout` operation did not complete before its deadline.
    #[error("connection timed out")]
    Timeout(#[from] tokio::time::error::Elapsed),
}
33
/// The protocol state requested by a handshake's "next state" field.
///
/// Only the status (server list ping) state is modelled here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum State {
    /// Status query; encoded as `1` on the wire.
    Status,
}

impl From<State> for usize {
    /// Returns the numeric value written for `state` in the handshake packet.
    fn from(state: State) -> Self {
        match state {
            State::Status => 1,
        }
    }
}
52
/// An undecoded packet: its numeric ID plus the payload bytes that follow the ID.
struct RawPacket {
    /// Packet ID (encoded as a VarInt ahead of `data` when written).
    id: usize,
    /// Payload bytes, i.e. everything after the packet ID.
    data: Box<[u8]>,
}

impl RawPacket {
    /// Bundles a packet ID with its already-serialized payload.
    fn new(packet_id: usize, payload: Box<[u8]>) -> Self {
        Self {
            id: packet_id,
            data: payload,
        }
    }
}
72
73#[async_trait]
76pub trait AsyncWireReadExt {
77 async fn read_varint(&mut self) -> Result<usize, ProtocolError>;
78 async fn read_string(&mut self) -> Result<String, ProtocolError>;
79}
80
81#[async_trait]
82impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
83 async fn read_varint(&mut self) -> Result<usize, ProtocolError> {
84 let mut read = 0;
85 let mut result = 0;
86 loop {
87 let read_value = self.read_u8().await?;
88 let value = read_value & 0b0111_1111;
89 result |= (value as usize) << (7 * read);
90 read += 1;
91 if read > 5 {
92 return Err(ProtocolError::InvalidVarInt);
93 }
94 if (read_value & 0b1000_0000) == 0 {
95 return Ok(result);
96 }
97 }
98 }
99
100 async fn read_string(&mut self) -> Result<String, ProtocolError> {
101 let length = self.read_varint().await?;
102
103 let mut buffer = vec![0; length];
104 self.read_exact(&mut buffer).await?;
105
106 Ok(String::from_utf8(buffer).map_err(|_| ProtocolError::InvalidResponseBody)?)
107 }
108}
109
110#[async_trait]
113pub trait AsyncWireWriteExt {
114 async fn write_varint(&mut self, int: usize) -> Result<(), ProtocolError>;
115 async fn write_string(&mut self, string: &str) -> Result<(), ProtocolError>;
116}
117
118#[async_trait]
119impl<W: AsyncWrite + Unpin + Send + Sync> AsyncWireWriteExt for W {
120 async fn write_varint(&mut self, int: usize) -> Result<(), ProtocolError> {
121 let mut int = (int as u64) & 0xFFFF_FFFF;
122 let mut written = 0;
123 let mut buffer = [0; 5];
124 loop {
125 let temp = (int & 0b0111_1111) as u8;
126 int >>= 7;
127 if int != 0 {
128 buffer[written] = temp | 0b1000_0000;
129 } else {
130 buffer[written] = temp;
131 }
132 written += 1;
133 if int == 0 {
134 break;
135 }
136 }
137 self.write(&buffer[0..written]).await?;
138
139 Ok(())
140 }
141
142 async fn write_string(&mut self, string: &str) -> Result<(), ProtocolError> {
143 self.write_varint(string.len()).await?;
144 self.write_all(string.as_bytes()).await?;
145
146 Ok(())
147 }
148}
149
/// Types that know the numeric ID to write ahead of their payload.
pub trait PacketId {
    /// Returns the packet ID to encode before this packet's payload.
    fn get_packet_id(&self) -> usize;
}

/// Decodable packet types, declaring the ID a received packet must carry.
pub trait ExpectedPacketId {
    /// Returns the packet ID a received packet must have to be decoded as `Self`.
    fn get_expected_packet_id() -> usize;
}

/// Decoding of a packet payload (the bytes after the packet ID) into `Self`.
#[async_trait]
pub trait AsyncReadFromBuffer: Sized {
    /// Parses `buffer`, which holds only the payload; the packet ID has already been
    /// consumed by the caller.
    async fn read_from_buffer(buffer: Vec<u8>) -> Result<Self, ProtocolError>;
}

/// Encoding of `Self` into a packet payload (the bytes after the packet ID).
#[async_trait]
pub trait AsyncWriteToBuffer {
    /// Serializes the payload only; the caller adds the packet ID and length framing.
    async fn write_to_buffer(&self) -> Result<Vec<u8>, ProtocolError>;
}
177
178#[async_trait]
183pub trait AsyncReadRawPacket {
184 async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
185 &mut self,
186 ) -> Result<T, ProtocolError>;
187
188 async fn read_packet_with_timeout<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
189 &mut self,
190 timeout: Duration,
191 ) -> Result<T, ProtocolError>;
192}
193
194#[async_trait]
195impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
196 async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
197 &mut self,
198 ) -> Result<T, ProtocolError> {
199 let length = self.read_varint().await?;
200
201 if length == 0 {
202 return Err(ProtocolError::InvalidPacketLength);
203 }
204
205 let packet_id = self.read_varint().await?;
206
207 let expected_packet_id = T::get_expected_packet_id();
208
209 if packet_id != expected_packet_id {
210 return Err(ProtocolError::InvalidPacketId {
211 expected: expected_packet_id,
212 actual: packet_id,
213 });
214 }
215
216 let mut buffer = vec![0; length - 1];
217 self.read_exact(&mut buffer).await?;
218
219 T::read_from_buffer(buffer).await
220 }
221
222 async fn read_packet_with_timeout<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
223 &mut self,
224 timeout: Duration,
225 ) -> Result<T, ProtocolError> {
226 tokio::time::timeout(timeout, self.read_packet()).await?
227 }
228}
229
230#[async_trait]
235pub trait AsyncWriteRawPacket {
236 async fn write_packet<T: PacketId + AsyncWriteToBuffer + Send + Sync>(
237 &mut self,
238 packet: T,
239 ) -> Result<(), ProtocolError>;
240
241 async fn write_packet_with_timeout<T: PacketId + AsyncWriteToBuffer + Send + Sync>(
242 &mut self,
243 packet: T,
244 timeout: Duration,
245 ) -> Result<(), ProtocolError>;
246}
247
248#[async_trait]
249impl<W: AsyncWrite + Unpin + Send + Sync> AsyncWriteRawPacket for W {
250 async fn write_packet<T: PacketId + AsyncWriteToBuffer + Send + Sync>(
251 &mut self,
252 packet: T,
253 ) -> Result<(), ProtocolError> {
254 let packet_buffer = packet.write_to_buffer().await?;
255
256 let raw_packet = RawPacket::new(packet.get_packet_id(), packet_buffer.into_boxed_slice());
257
258 let mut buffer: Cursor<Vec<u8>> = Cursor::new(Vec::new());
259
260 buffer.write_varint(raw_packet.id).await?;
261 buffer.write_all(&raw_packet.data).await?;
262
263 let inner = buffer.into_inner();
264 self.write_varint(inner.len()).await?;
265 self.write(&inner).await?;
266 Ok(())
267 }
268
269 async fn write_packet_with_timeout<T: PacketId + AsyncWriteToBuffer + Send + Sync>(
270 &mut self,
271 packet: T,
272 timeout: Duration,
273 ) -> Result<(), ProtocolError> {
274 tokio::time::timeout(timeout, self.write_packet(packet)).await?
275 }
276}
277
278pub struct HandshakePacket {
282 pub packet_id: usize,
283 pub protocol_version: usize,
284 pub server_address: String,
285 pub server_port: u16,
286 pub next_state: State,
287}
288
289impl HandshakePacket {
290 pub fn new(protocol_version: usize, server_address: String, server_port: u16) -> Self {
291 Self {
292 packet_id: 0,
293 protocol_version,
294 server_address,
295 server_port,
296 next_state: State::Status,
297 }
298 }
299}
300
301#[async_trait]
302impl AsyncWriteToBuffer for HandshakePacket {
303 async fn write_to_buffer(&self) -> Result<Vec<u8>, ProtocolError> {
304 let mut buffer = Cursor::new(Vec::<u8>::new());
305
306 buffer.write_varint(self.protocol_version).await?;
307 buffer.write_string(&self.server_address).await?;
308 buffer.write_u16(self.server_port).await?;
309 buffer.write_varint(self.next_state.into()).await?;
310
311 Ok(buffer.into_inner())
312 }
313}
314
315impl PacketId for HandshakePacket {
316 fn get_packet_id(&self) -> usize {
317 self.packet_id
318 }
319}
320
321pub struct RequestPacket {
325 pub packet_id: usize,
326}
327
328impl RequestPacket {
329 pub fn new() -> Self {
330 Self { packet_id: 0 }
331 }
332}
333
334#[async_trait]
335impl AsyncWriteToBuffer for RequestPacket {
336 async fn write_to_buffer(&self) -> Result<Vec<u8>, ProtocolError> {
337 Ok(Vec::new())
338 }
339}
340
341impl PacketId for RequestPacket {
342 fn get_packet_id(&self) -> usize {
343 self.packet_id
344 }
345}
346
347pub struct ResponsePacket {
351 pub packet_id: usize,
352 pub body: String,
353}
354
355impl ExpectedPacketId for ResponsePacket {
356 fn get_expected_packet_id() -> usize {
357 0
358 }
359}
360
361#[async_trait]
362impl AsyncReadFromBuffer for ResponsePacket {
363 async fn read_from_buffer(buffer: Vec<u8>) -> Result<Self, ProtocolError> {
364 let mut reader = Cursor::new(buffer);
365
366 let body = reader.read_string().await?;
367
368 Ok(ResponsePacket { packet_id: 0, body })
369 }
370}
371
372pub struct PingPacket {
373 pub packet_id: usize,
374 pub payload: u64,
375}
376
377impl PingPacket {
378 pub fn new(payload: u64) -> Self {
379 Self {
380 packet_id: 1,
381 payload,
382 }
383 }
384}
385
386#[async_trait]
387impl AsyncWriteToBuffer for PingPacket {
388 async fn write_to_buffer(&self) -> Result<Vec<u8>, ProtocolError> {
389 let mut buffer = Cursor::new(Vec::<u8>::new());
390
391 buffer.write_u64(self.payload).await?;
392
393 Ok(buffer.into_inner())
394 }
395}
396
397impl PacketId for PingPacket {
398 fn get_packet_id(&self) -> usize {
399 self.packet_id
400 }
401}
402
403pub struct PongPacket {
404 pub packet_id: usize,
405 pub payload: u64,
406}
407
408impl ExpectedPacketId for PongPacket {
409 fn get_expected_packet_id() -> usize {
410 1
411 }
412}
413
414#[async_trait]
415impl AsyncReadFromBuffer for PongPacket {
416 async fn read_from_buffer(buffer: Vec<u8>) -> Result<Self, ProtocolError> {
417 let mut reader = Cursor::new(buffer);
418
419 let payload = reader.read_u64().await?;
420
421 Ok(PongPacket {
422 packet_id: 0,
423 payload,
424 })
425 }
426}