use std::time::Duration;
use serde::Deserialize;
use thiserror::Error;
use tokio::net::TcpStream;
use crate::protocol::{self, AsyncReadRawPacket, AsyncWriteRawPacket};
/// Errors returned by the connection and query functions in this module.
#[derive(Error, Debug)]
pub enum ServerError {
/// A packet read or write failed. Produced by converting a
/// `protocol::ProtocolError`; the underlying error's details are not kept.
#[error("error reading or writing data")]
ProtocolError,
/// Opening the TCP connection to the server failed.
#[error("failed to connect to server")]
FailedToConnect,
/// The status response body could not be deserialized into
/// `StatusResponse`. Carries the raw response body.
#[error("invalid JSON response: \"{0}\"")]
InvalidJson(String),
/// The payload received in the pong differed from the payload sent in the
/// ping.
#[error("mismatched pong payload (expected \"{expected}\", got \"{actual}\")")]
MismatchedPayload { expected: u64, actual: u64 },
}
/// Collapses every `protocol::ProtocolError` into the unit variant
/// `ServerError::ProtocolError`, which lets `?` be used on protocol calls in
/// functions returning `ServerError`. The source error itself is dropped.
impl From<protocol::ProtocolError> for ServerError {
    fn from(_: protocol::ProtocolError) -> Self {
        Self::ProtocolError
    }
}
/// The `version` object of a status response.
#[derive(Debug, Deserialize)]
pub struct ServerVersion {
/// Version name as sent by the server.
pub name: String,
/// Protocol number as sent by the server.
pub protocol: u32,
}
/// One entry of the player `sample` list in a status response.
#[derive(Debug, Deserialize)]
pub struct ServerPlayer {
/// Player name as sent by the server.
pub name: String,
/// Player identifier as sent by the server.
/// NOTE(review): presumably a UUID string — confirm against the protocol.
pub id: String,
}
/// The `players` object of a status response.
#[derive(Debug, Deserialize)]
pub struct ServerPlayers {
/// Value of the `max` field as sent by the server.
pub max: u32,
/// Value of the `online` field as sent by the server.
pub online: u32,
/// Optional list of players; `None` when the field is absent or null in
/// the response.
pub sample: Option<Vec<ServerPlayer>>,
}
/// The `description` field of a status response.
///
/// Deserialized as an untagged enum: serde tries the variants in declaration
/// order and keeps the first one that matches, so the variant order below is
/// significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ServerDescription {
/// The description was sent as a bare JSON string.
Plain(String),
/// The description was sent as a JSON object containing a `text` field.
Object { text: String },
}
/// The JSON body of a status response, as deserialized by
/// `StatusConnection::status`.
#[derive(Debug, Deserialize)]
pub struct StatusResponse {
/// Version information reported by the server.
pub version: ServerVersion,
/// Player counts and optional player sample reported by the server.
pub players: ServerPlayers,
/// Server description, either a plain string or an object with `text`.
pub description: ServerDescription,
/// Optional favicon string; `None` when the field is absent or null.
/// NOTE(review): encoding of this string is not visible here — confirm.
pub favicon: Option<String>,
}
/// Protocol version used by `ConnectionConfig::build` unless overridden with
/// `ConnectionConfig::with_protocol_version`.
/// NOTE(review): 578 is presumably the newest protocol version at the time
/// of writing — confirm and update as needed.
const LATEST_PROTOCOL_VERSION: usize = 578;
/// Port used by `ConnectionConfig::build` unless overridden with
/// `ConnectionConfig::with_port`.
const DEFAULT_PORT: u16 = 25565;
/// Timeout (2 seconds) used by `ConnectionConfig::build` unless overridden
/// with `ConnectionConfig::with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(2);
/// Builder-style configuration for opening a connection to a server.
///
/// Create one with `ConnectionConfig::build`, adjust it with the `with_*`
/// methods, then call `connect`.
pub struct ConnectionConfig {
/// Protocol version placed in the handshake packet.
protocol_version: usize,
/// Server address; used both as the TCP connect target and in the
/// handshake packet.
address: String,
/// Server port; used both as the TCP connect target and in the handshake
/// packet.
port: u16,
/// Timeout passed to each packet read and write. It is not applied to the
/// TCP connect itself.
timeout: Duration,
}
impl ConnectionConfig {
    /// Starts a configuration for `address`, filling every other setting
    /// from the module defaults (`LATEST_PROTOCOL_VERSION`, `DEFAULT_PORT`,
    /// `DEFAULT_TIMEOUT`).
    pub fn build<T: Into<String>>(address: T) -> Self {
        let address = address.into();
        Self {
            address,
            port: DEFAULT_PORT,
            timeout: DEFAULT_TIMEOUT,
            protocol_version: LATEST_PROTOCOL_VERSION,
        }
    }

    /// Returns the configuration with the handshake protocol version
    /// replaced by `protocol_version`.
    pub fn with_protocol_version(self, protocol_version: usize) -> Self {
        Self {
            protocol_version,
            ..self
        }
    }

    /// Returns the configuration with the server port replaced by `port`.
    pub fn with_port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Returns the configuration with the packet read/write timeout replaced
    /// by `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Opens a TCP connection to `address:port` and hands the stream,
    /// together with all settings, to a new `StatusConnection`.
    ///
    /// # Errors
    ///
    /// Returns `ServerError::FailedToConnect` if the TCP connection cannot be
    /// established; the underlying I/O error is discarded.
    pub async fn connect(self) -> Result<StatusConnection, ServerError> {
        let Self {
            protocol_version,
            address,
            port,
            timeout,
        } = self;

        let target = format!("{}:{}", address, port);
        let stream = match TcpStream::connect(target).await {
            Ok(stream) => stream,
            Err(_) => return Err(ServerError::FailedToConnect),
        };

        Ok(StatusConnection {
            stream,
            protocol_version,
            address,
            port,
            timeout,
        })
    }
}
/// Connects to `address` using the default port, timeout and protocol
/// version.
///
/// Shorthand for `ConnectionConfig::build(address).connect().await`.
///
/// # Errors
///
/// Returns whatever `ConnectionConfig::connect` returns.
pub async fn connect(address: String) -> Result<StatusConnection, ServerError> {
    let config = ConnectionConfig::build(address);
    config.connect().await
}
/// An open TCP connection to a server, produced by
/// `ConnectionConfig::connect`, on which `status` can be called once.
pub struct StatusConnection {
/// The connected TCP stream.
stream: TcpStream,
/// Protocol version placed in the handshake packet.
protocol_version: usize,
/// Server address placed in the handshake packet.
address: String,
/// Server port placed in the handshake packet.
port: u16,
/// Timeout passed to each packet read and write.
timeout: Duration,
}
impl StatusConnection {
    /// Sends the handshake and status request, reads the response and
    /// deserializes its JSON body.
    ///
    /// Consumes the connection and returns a `PingConnection` that owns the
    /// same stream and exposes the parsed response as its `status` field.
    ///
    /// # Errors
    ///
    /// * `ServerError::ProtocolError` if writing either packet or reading
    ///   the response fails (converted from `protocol::ProtocolError`).
    /// * `ServerError::InvalidJson` carrying the raw response body if the
    ///   body cannot be deserialized into `StatusResponse`.
    pub async fn status(mut self) -> Result<PingConnection, ServerError> {
        // `Duration` is `Copy`; one plain copy serves every call below, so
        // no `.clone()` is needed.
        let timeout = self.timeout;

        // The address is still needed for the returned `PingConnection`, so
        // the handshake gets its own copy.
        let handshake = protocol::HandshakePacket::new(
            self.protocol_version,
            self.address.clone(),
            self.port,
        );
        self.stream
            .write_packet_with_timeout(handshake, timeout)
            .await?;
        self.stream
            .write_packet_with_timeout(protocol::RequestPacket::new(), timeout)
            .await?;

        let response: protocol::ResponsePacket =
            self.stream.read_packet_with_timeout(timeout).await?;

        // On failure the raw body is moved into the error so callers can
        // inspect what the server actually sent.
        let status: StatusResponse = serde_json::from_str(&response.body)
            .map_err(|_| ServerError::InvalidJson(response.body))?;

        Ok(PingConnection {
            stream: self.stream,
            protocol_version: self.protocol_version,
            address: self.address,
            port: self.port,
            status,
            timeout,
        })
    }
}
/// A connection on which the status exchange has completed, produced by
/// `StatusConnection::status`. Holds the parsed status and allows one `ping`.
pub struct PingConnection {
/// The connected TCP stream, carried over from the `StatusConnection`.
stream: TcpStream,
/// Protocol version carried over from the `StatusConnection`.
protocol_version: usize,
/// Server address carried over from the `StatusConnection`.
address: String,
/// Server port carried over from the `StatusConnection`.
port: u16,
/// Timeout passed to each packet read and write.
timeout: Duration,
/// The status response received from the server.
pub status: StatusResponse,
}
impl PingConnection {
    /// Sends a ping carrying `payload` and checks that the pong echoes the
    /// same value.
    ///
    /// Consumes the connection.
    ///
    /// # Errors
    ///
    /// * `ServerError::ProtocolError` if writing the ping or reading the
    ///   pong fails (converted from `protocol::ProtocolError`).
    /// * `ServerError::MismatchedPayload` if the pong's payload differs from
    ///   `payload`.
    pub async fn ping(mut self, payload: u64) -> Result<(), ServerError> {
        // `Duration` is `Copy`; no `.clone()` is needed to reuse it.
        let timeout = self.timeout;

        let ping = protocol::PingPacket::new(payload);
        self.stream.write_packet_with_timeout(ping, timeout).await?;

        let pong: protocol::PongPacket = self.stream.read_packet_with_timeout(timeout).await?;

        if pong.payload != payload {
            // The value is already a `ServerError`, so no `.into()`
            // conversion is required.
            return Err(ServerError::MismatchedPayload {
                expected: payload,
                actual: pong.payload,
            });
        }
        Ok(())
    }
}