use bytes::{Buf, BufMut, Bytes, BytesMut};
use rand::random;
use std::collections::HashMap;
use std::time::Duration;
use tokio::io;
use tokio::net::UdpSocket;
use tokio::time::timeout;
use crate::errors::QueryProtocolError;
/// Two-byte prefix written (big-endian) at the start of every request packet built in this file.
const QUERY_MAGIC: u16 = 0xFEFD;

/// Mask applied to the random session id in `handshake`; it keeps only the low nibble of each
/// byte, so the masked value always has its most significant bit clear.
const SESSION_ID_MASK: u32 = 0x0F0F_0F0F;
/// Parsed reply to a basic stat request, as produced by `stat_basic`.
///
/// The fields are listed in the order in which `stat_basic` reads them from the reply payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BasicStatResponse {
    /// First NUL-terminated string of the payload.
    pub motd: String,
    /// Second NUL-terminated string of the payload.
    pub game_type: String,
    /// Third NUL-terminated string of the payload.
    pub map: String,
    /// Player count; transmitted as a decimal string and parsed into an integer.
    pub num_players: usize,
    /// Player limit; transmitted as a decimal string and parsed into an integer.
    pub max_players: usize,
    /// Port reported by the server; read from the payload as a little-endian `u16`.
    pub host_port: u16,
    /// Final NUL-terminated string of the payload.
    pub host_ip: String,
}
/// Parsed reply to a full stat request, as produced by `stat_full`.
///
/// Every field except `players` is taken from the key/value section of the reply; the key it is
/// looked up under is given on each field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FullStatResponse {
    /// Value of the `hostname` key.
    pub motd: String,
    /// Value of the `gametype` key.
    pub game_type: String,
    /// Value of the `game_id` key.
    pub game_id: String,
    /// Value of the `version` key.
    pub version: String,
    /// Value of the `plugins` key, kept as the raw string sent by the server.
    pub plugins: String,
    /// Value of the `map` key.
    pub map: String,
    /// Value of the `numplayers` key, parsed into an integer.
    pub num_players: usize,
    /// Value of the `maxplayers` key, parsed into an integer.
    pub max_players: usize,
    /// Value of the `hostport` key, parsed into an integer.
    pub host_port: u16,
    /// Value of the `hostip` key.
    pub host_ip: String,
    /// Names read from the player section that follows the key/value section, in wire order.
    pub players: Vec<String>,
}
/// Sends `bytes` on the connected socket and waits for a single reply packet.
///
/// # Errors
///
/// Returns the I/O error from sending or receiving, or the timeout error converted into an
/// `io::Error` when no reply arrives within 250 ms.
async fn stat_send(sock: &UdpSocket, bytes: &[u8]) -> io::Result<Bytes> {
    const REPLY_DEADLINE: Duration = Duration::from_millis(250);

    sock.send(bytes).await?;

    // The receive future is boxed before being awaited, as in the rest of this file.
    let pending_reply = Box::pin(timeout(REPLY_DEADLINE, recv_packet(sock)));
    match pending_reply.await {
        Ok(reply) => reply,
        Err(elapsed) => Err(elapsed.into()),
    }
}
pub async fn stat_basic(host: &str, port: u16) -> io::Result<BasicStatResponse> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(format!("{host}:{port}")).await?;
let (token, session) = Box::pin(handshake(&socket)).await?;
let mut bytes = BytesMut::new();
bytes.put_u16(QUERY_MAGIC);
bytes.put_u8(0); bytes.put_i32(session);
bytes.put_i32(token);
let mut res = match stat_send(&socket, &bytes).await {
Ok(v) => v,
Err(_) => stat_send(&socket, &bytes).await?,
};
validate_packet(&mut res, 0, session)?;
let motd = get_string(&mut res)?;
let game_type = get_string(&mut res)?;
let map = get_string(&mut res)?;
let num_players = get_string(&mut res)?
.parse()
.map_err::<io::Error, _>(|_| QueryProtocolError::CannotParseInt.into())?;
let max_players = get_string(&mut res)?
.parse()
.map_err::<io::Error, _>(|_| QueryProtocolError::CannotParseInt.into())?;
let host_port = res.get_u16_le();
let host_ip = get_string(&mut res)?;
Ok(BasicStatResponse {
motd,
game_type,
map,
num_players,
max_players,
host_port,
host_ip,
})
}
/// Queries `host:port` over UDP and returns the server's full stat reply.
///
/// The request is the basic stat request followed by four zero bytes. The reply is parsed as a
/// fixed 11-byte section, a list of NUL-terminated key/value pairs ended by an empty key, a
/// fixed 10-byte section, and a list of NUL-terminated player names ended by an empty name.
/// If the first request attempt fails for any reason, it is sent once more.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound or connected, if the handshake fails, if both
/// request attempts fail, if the reply header does not match, if a string is not valid UTF-8,
/// if an expected key is missing, or if a numeric value cannot be parsed.
pub async fn stat_full(host: &str, port: u16) -> io::Result<FullStatResponse> {
    /// Parses a decimal string into the integer type the caller needs.
    fn parse_int<T: std::str::FromStr>(text: &str) -> io::Result<T> {
        text.parse()
            .map_err(|_| QueryProtocolError::CannotParseInt.into())
    }

    let socket = UdpSocket::bind("0.0.0.0:0").await?;
    socket.connect(format!("{host}:{port}")).await?;

    let (token, session) = Box::pin(handshake(&socket)).await?;

    // Request layout: magic, packet type 0, session id, challenge token, four zero bytes.
    let mut request = BytesMut::new();
    request.put_u16(QUERY_MAGIC);
    request.put_u8(0);
    request.put_i32(session);
    request.put_i32(token);
    request.put_u32(0);

    // One retry: whatever error the first attempt produced is discarded.
    let first_attempt = stat_send(&socket, &request).await;
    let mut payload = if let Ok(reply) = first_attempt {
        reply
    } else {
        stat_send(&socket, &request).await?
    };
    validate_packet(&mut payload, 0, session)?;

    // Skip the fixed-size section that precedes the key/value data.
    payload.advance(11);

    // Collect key/value pairs until an empty key marks the end of the section.
    let mut fields: HashMap<String, String> = HashMap::new();
    let mut key = get_string(&mut payload)?;
    while !key.is_empty() {
        let value = get_string(&mut payload)?;
        fields.insert(key, value);
        key = get_string(&mut payload)?;
    }

    // Removes a required key, failing when the server did not send it.
    let mut take = |name: &str| -> io::Result<String> {
        fields
            .remove(name)
            .ok_or_else(|| QueryProtocolError::InvalidKeyValueSection.into())
    };

    let motd = take("hostname")?;
    let game_type = take("gametype")?;
    let game_id = take("game_id")?;
    let version = take("version")?;
    let plugins = take("plugins")?;
    let map = take("map")?;
    let num_players = parse_int(&take("numplayers")?)?;
    let max_players = parse_int(&take("maxplayers")?)?;
    let host_port = parse_int(&take("hostport")?)?;
    let host_ip = take("hostip")?;

    // Skip the fixed-size section that precedes the player list.
    payload.advance(10);

    // Player names follow until an empty name terminates the list.
    let mut players = Vec::new();
    let mut username = get_string(&mut payload)?;
    while !username.is_empty() {
        players.push(username);
        username = get_string(&mut payload)?;
    }

    Ok(FullStatResponse {
        motd,
        game_type,
        game_id,
        version,
        plugins,
        map,
        num_players,
        max_players,
        host_port,
        host_ip,
        players,
    })
}
// NOTE(review): `create_timeout!` is defined outside this file. Going by its name and arguments
// it presumably generates timeout-wrapped variants of the two query functions that return the
// named response types — confirm against the macro definition.
create_timeout!(stat_basic, BasicStatResponse);
create_timeout!(stat_full, FullStatResponse);
/// Performs the challenge handshake on an already connected socket.
///
/// Sends a type-9 packet carrying a freshly generated session id, checks that the reply has the
/// same type and session id, and parses the NUL-terminated string that follows as the token.
///
/// Returns `(token, session_id)`.
///
/// # Errors
///
/// Returns an error if sending or receiving fails, if the reply header does not match, or if
/// the token string is not valid UTF-8 or not a valid `i32`.
async fn handshake(socket: &UdpSocket) -> io::Result<(i32, i32)> {
    // The mask clears the high nibble of every byte, so the sign bit is never set and the
    // cast to `i32` cannot wrap.
    #[allow(clippy::cast_possible_wrap)]
    let session_id = (random::<u32>() & SESSION_ID_MASK) as i32;

    // Request layout: magic (2 bytes), packet type 9 (1 byte), session id (4 bytes).
    let mut request = BytesMut::with_capacity(7);
    request.put_u16(QUERY_MAGIC);
    request.put_u8(9);
    request.put_i32(session_id);
    socket.send(&request).await?;

    // NOTE(review): unlike `stat_send`, this receive has no timeout of its own, so it waits
    // until a datagram arrives — confirm that callers always bound it externally.
    let mut reply = Box::pin(recv_packet(socket)).await?;
    validate_packet(&mut reply, 9, session_id)?;

    let token_text = get_string(&mut reply)?;
    match token_text.parse::<i32>() {
        Ok(token) => Ok((token, session_id)),
        Err(_) => Err(QueryProtocolError::CannotParseInt.into()),
    }
}
/// Receives one datagram from the connected `socket`.
///
/// The returned buffer is always 65536 bytes long: the datagram occupies the front and every
/// byte after it is zero. The received length is deliberately not used to truncate the buffer,
/// because callers in this file parse the result with `Buf` getters that panic on an exhausted
/// buffer (for example `get_u16_le` and `advance`); the zero padding keeps those reads in
/// bounds and acts as a string terminator.
///
/// # Errors
///
/// Returns the I/O error reported by the socket's `recv`.
async fn recv_packet(socket: &UdpSocket) -> io::Result<Bytes> {
    const MAX_DATAGRAM_LEN: usize = 65536;

    // Heap-allocate the buffer so the 64 KiB are not stored inline in this future, and so the
    // allocation can be handed to `Bytes` directly instead of being copied a second time.
    let mut buf = vec![0u8; MAX_DATAGRAM_LEN];
    socket.recv(&mut buf).await?;
    Ok(Bytes::from(buf))
}
/// Consumes and checks the 5-byte reply header at the front of `packet`.
///
/// The header is one type byte followed by a big-endian `i32` session id. On success the
/// header has been consumed and `packet` is positioned at the start of the payload.
///
/// # Errors
///
/// * `UnexpectedEof` if `packet` ends before the header is complete (previously a panic).
/// * `QueryProtocolError::InvalidPacketType` if the type byte differs from `expected_type`.
/// * `QueryProtocolError::SessionIdMismatch` if the session id differs from `expected_session`.
fn validate_packet(packet: &mut Bytes, expected_type: u8, expected_session: i32) -> io::Result<()> {
    const SESSION_ID_LEN: usize = 4;

    // `get_u8` panics on an empty buffer, so check before reading.
    if !packet.has_remaining() {
        return Err(io::ErrorKind::UnexpectedEof.into());
    }
    if packet.get_u8() != expected_type {
        return Err(QueryProtocolError::InvalidPacketType.into());
    }

    // `get_i32` panics when fewer than four bytes remain.
    if packet.remaining() < SESSION_ID_LEN {
        return Err(io::ErrorKind::UnexpectedEof.into());
    }
    if packet.get_i32() != expected_session {
        return Err(QueryProtocolError::SessionIdMismatch.into());
    }

    Ok(())
}
/// Reads one NUL-terminated string from the front of `bytes`.
///
/// On success the string's bytes and its terminator have been consumed; the terminator is not
/// part of the returned string. An immediate NUL byte yields an empty string.
///
/// # Errors
///
/// * `UnexpectedEof` if `bytes` contains no NUL byte (previously a panic); nothing is consumed.
/// * `QueryProtocolError::InvalidUtf8` if the bytes before the terminator are not valid UTF-8;
///   the string and its terminator are still consumed.
fn get_string(bytes: &mut Bytes) -> io::Result<String> {
    // Find the terminator first so a missing one becomes an error rather than a panic.
    let len = match bytes.iter().position(|&b| b == 0) {
        Some(len) => len,
        None => return Err(io::ErrorKind::UnexpectedEof.into()),
    };

    let raw = bytes.split_to(len);
    // Drop the NUL terminator itself.
    bytes.advance(1);

    String::from_utf8(raw.to_vec()).map_err(|_| QueryProtocolError::InvalidUtf8.into())
}
#[cfg(test)]
mod tests {
    use super::{stat_basic, stat_full};
    use tokio::io;

    /// Runs a basic stat query and prints the reply.
    ///
    /// Needs a server answering queries on UDP `localhost:25565`.
    #[tokio::test]
    async fn test_stat_basic() -> io::Result<()> {
        stat_basic("localhost", 25565)
            .await
            .map(|stats| println!("{stats:#?}"))
    }

    /// Runs a full stat query and prints the reply.
    ///
    /// Needs a server answering queries on UDP `localhost:25565`.
    #[tokio::test]
    async fn test_stat_full() -> io::Result<()> {
        stat_full("localhost", 25565)
            .await
            .map(|stats| println!("{stats:#?}"))
    }
}