1use bytes::{Buf, BufMut, Bytes, BytesMut};
4use rand::random;
5use std::collections::HashMap;
6use std::time::Duration;
7use tokio::io;
8use tokio::net::UdpSocket;
9use tokio::time::timeout;
10
11use crate::errors::QueryProtocolError;
12
/// Magic value written at the very start of every request packet this module sends.
const QUERY_MAGIC: u16 = 0xFEFD;
/// Applied to a random `u32` to form a session id: keeps only the low 4 bits of each byte.
const SESSION_ID_MASK: u32 = 0x0F0F_0F0F;
15
/// Result of a "basic stat" query, as returned by [`stat_basic`].
///
/// Fields appear in the same order the server sends them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BasicStatResponse {
    /// Server message of the day.
    pub motd: String,

    /// Game type string reported by the server.
    pub game_type: String,

    /// Map name reported by the server.
    pub map: String,

    /// Player count reported by the server.
    pub num_players: usize,

    /// Maximum player count reported by the server.
    pub max_players: usize,

    /// Port reported by the server (sent as a little-endian `u16`).
    pub host_port: u16,

    /// Host address reported by the server, kept as the raw string.
    pub host_ip: String,
}
42
/// Result of a "full stat" query, as returned by [`stat_full`].
///
/// Each scalar field is taken from the named key of the response's
/// key/value section; `players` comes from the trailing player list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FullStatResponse {
    /// Server message of the day (`hostname` key).
    pub motd: String,

    /// Game type string (`gametype` key).
    pub game_type: String,

    /// Game identifier (`game_id` key).
    pub game_id: String,

    /// Server version string (`version` key).
    pub version: String,

    /// Raw plugin string exactly as sent (`plugins` key); not parsed further.
    pub plugins: String,

    /// Map name (`map` key).
    pub map: String,

    /// Player count (`numplayers` key).
    pub num_players: usize,

    /// Maximum player count (`maxplayers` key).
    pub max_players: usize,

    /// Port reported by the server (`hostport` key).
    pub host_port: u16,

    /// Host address reported by the server, kept as the raw string (`hostip` key).
    pub host_ip: String,

    /// Player names listed after the key/value section.
    pub players: Vec<String>,
}
84
85async fn stat_send(sock: &UdpSocket, bytes: &[u8]) -> io::Result<Bytes> {
86 sock.send(bytes).await?;
87 Box::pin(timeout(Duration::from_millis(250), recv_packet(sock))).await?
88}
89
90pub async fn stat_basic(host: &str, port: u16) -> io::Result<BasicStatResponse> {
116 let socket = UdpSocket::bind("0.0.0.0:0").await?;
117 socket.connect(format!("{host}:{port}")).await?;
118
119 let (token, session) = Box::pin(handshake(&socket)).await?;
120
121 let mut bytes = BytesMut::new();
122 bytes.put_u16(QUERY_MAGIC);
123 bytes.put_u8(0); bytes.put_i32(session);
125 bytes.put_i32(token);
126
127 let mut res = match stat_send(&socket, &bytes).await {
128 Ok(v) => v,
129 Err(_) => stat_send(&socket, &bytes).await?,
130 };
131
132 validate_packet(&mut res, 0, session)?;
133
134 let motd = get_string(&mut res)?;
135 let game_type = get_string(&mut res)?;
136 let map = get_string(&mut res)?;
137 let num_players = get_string(&mut res)?
138 .parse()
139 .map_err::<io::Error, _>(|_| QueryProtocolError::CannotParseInt.into())?;
140 let max_players = get_string(&mut res)?
141 .parse()
142 .map_err::<io::Error, _>(|_| QueryProtocolError::CannotParseInt.into())?;
143
144 let host_port = res.get_u16_le(); let host_ip = get_string(&mut res)?;
147
148 Ok(BasicStatResponse {
149 motd,
150 game_type,
151 map,
152 num_players,
153 max_players,
154 host_port,
155 host_ip,
156 })
157}
158
159pub async fn stat_full(host: &str, port: u16) -> io::Result<FullStatResponse> {
185 let socket = UdpSocket::bind("0.0.0.0:0").await?;
186 socket.connect(format!("{host}:{port}")).await?;
187
188 let (token, session) = Box::pin(handshake(&socket)).await?;
189
190 let mut bytes = BytesMut::new();
191 bytes.put_u16(QUERY_MAGIC);
192 bytes.put_u8(0); bytes.put_i32(session);
194 bytes.put_i32(token);
195 bytes.put_u32(0); let mut res = match stat_send(&socket, &bytes).await {
198 Ok(v) => v,
199 Err(_) => stat_send(&socket, &bytes).await?,
200 };
201
202 validate_packet(&mut res, 0, session)?;
203
204 res.advance(11);
206
207 let mut kv = HashMap::new();
209 loop {
210 let key = get_string(&mut res)?;
211 if key.is_empty() {
212 break;
213 }
214 let value = get_string(&mut res)?;
215 kv.insert(key, value);
216 }
217
218 let motd = kv
220 .remove("hostname")
221 .ok_or(QueryProtocolError::InvalidKeyValueSection)?;
222 let game_type = kv
223 .remove("gametype")
224 .ok_or(QueryProtocolError::InvalidKeyValueSection)?;
225 let game_id = kv
226 .remove("game_id")
227 .ok_or(QueryProtocolError::InvalidKeyValueSection)?;
228 let version = kv
229 .remove("version")
230 .ok_or(QueryProtocolError::InvalidKeyValueSection)?;
231 let plugins = kv
232 .remove("plugins")
233 .ok_or(QueryProtocolError::InvalidKeyValueSection)?;
234 let map = kv
235 .remove("map")
236 .ok_or(QueryProtocolError::InvalidKeyValueSection)?;
237 let num_players = kv
238 .remove("numplayers")
239 .ok_or(QueryProtocolError::InvalidKeyValueSection)?
240 .parse()
241 .map_err(|_| QueryProtocolError::CannotParseInt)?;
242 let max_players = kv
243 .remove("maxplayers")
244 .ok_or(QueryProtocolError::InvalidKeyValueSection)?
245 .parse()
246 .map_err(|_| QueryProtocolError::CannotParseInt)?;
247 let host_port = kv
248 .remove("hostport")
249 .ok_or(QueryProtocolError::InvalidKeyValueSection)?
250 .parse()
251 .map_err(|_| QueryProtocolError::CannotParseInt)?;
252 let host_ip = kv
253 .remove("hostip")
254 .ok_or(QueryProtocolError::InvalidKeyValueSection)?;
255
256 for _ in 0..10 {
258 res.get_u8();
259 }
260
261 let mut players = vec![];
263 loop {
264 let username = get_string(&mut res)?;
265 if username.is_empty() {
266 break;
267 }
268 players.push(username);
269 }
270
271 Ok(FullStatResponse {
272 motd,
273 game_type,
274 game_id,
275 version,
276 plugins,
277 map,
278 num_players,
279 max_players,
280 host_port,
281 host_ip,
282 players,
283 })
284}
285
286create_timeout!(stat_basic, BasicStatResponse);
287create_timeout!(stat_full, FullStatResponse);
288
289async fn handshake(socket: &UdpSocket) -> io::Result<(i32, i32)> {
297 #[allow(clippy::cast_possible_wrap)] let session_id = (random::<u32>() & SESSION_ID_MASK) as i32;
300
301 let mut req = BytesMut::with_capacity(7);
302 req.put_u16(QUERY_MAGIC);
303 req.put_u8(9); req.put_i32(session_id);
305 socket.send(&req).await?;
308
309 let mut response = Box::pin(recv_packet(socket)).await?;
310 validate_packet(&mut response, 9, session_id)?;
311
312 let token_str = get_string(&mut response)?;
313
314 token_str
315 .parse()
316 .map(|t| (t, session_id))
317 .map_err(|_| QueryProtocolError::CannotParseInt.into())
318}
319
320async fn recv_packet(socket: &UdpSocket) -> io::Result<Bytes> {
321 let mut buf = [0u8; 65536];
322 socket.recv(&mut buf).await?;
323
324 Ok(Bytes::copy_from_slice(&buf))
325}
326
327fn validate_packet(packet: &mut Bytes, expected_type: u8, expected_session: i32) -> io::Result<()> {
328 let recv_type = packet.get_u8();
329 if recv_type != expected_type {
330 return Err(QueryProtocolError::InvalidPacketType.into());
331 }
332
333 let recv_session = packet.get_i32();
334 if recv_session != expected_session {
335 return Err(QueryProtocolError::SessionIdMismatch.into());
336 }
337
338 Ok(())
339}
340
341fn get_string(bytes: &mut Bytes) -> io::Result<String> {
342 let mut buf = vec![];
343 loop {
344 let byte = bytes.get_u8();
345 if byte == 0 {
346 break;
347 }
348 buf.push(byte);
349 }
350
351 String::from_utf8(buf).map_err(|_| QueryProtocolError::InvalidUtf8.into())
352}
353
#[cfg(test)]
mod tests {
    use tokio::io;

    use super::{stat_basic, stat_full};

    // Live-network integration tests: they need a server answering the query
    // protocol on localhost:25565, so they are skipped by a plain `cargo test`.
    // Run them explicitly with `cargo test -- --ignored`.

    #[tokio::test]
    #[ignore = "requires a server answering query requests on localhost:25565"]
    async fn test_stat_basic() -> io::Result<()> {
        let response = stat_basic("localhost", 25565).await?;
        println!("{response:#?}");

        Ok(())
    }

    #[tokio::test]
    #[ignore = "requires a server answering query requests on localhost:25565"]
    async fn test_stat_full() -> io::Result<()> {
        let response = stat_full("localhost", 25565).await?;
        println!("{response:#?}");

        Ok(())
    }
}
375}