use super::{
packet::{RconPacket, RconPacketType},
MAX_LEN_CLIENTBOUND,
};
use crate::errors::{timeout_err, RconProtocolError};
use bytes::{BufMut, BytesMut};
use std::time::Duration;
use tokio::{
io::{self, AsyncReadExt, AsyncWriteExt, Error},
net::TcpStream,
time::timeout,
};
// clippy::module_name_repetitions is allowed because the type name deliberately
// contains the name of its module.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
/// An asynchronous RCON client that talks to a server over one TCP connection.
///
/// Construct it with `RconClient::new` or `RconClient::with_timeout`, call
/// `authenticate`, then issue commands with `run_command`.
pub struct RconClient {
    /// The TCP connection that all packets are written to and read from.
    socket: TcpStream,
    /// Optional limit applied to each `authenticate` / `run_command` call;
    /// `None` means those calls wait indefinitely.
    timeout: Option<Duration>,
}
impl RconClient {
    /// Opens a TCP connection to the RCON server at `host`:`port`.
    ///
    /// `host` may be a hostname, an IPv4 address, or an IPv6 address written with
    /// or without surrounding brackets (`::1` or `[::1]`). No timeout is applied to
    /// the connection attempt itself.
    ///
    /// # Errors
    ///
    /// Returns any error from address resolution or from establishing the TCP
    /// connection.
    pub async fn new(host: &str, port: u16) -> io::Result<Self> {
        // Pass host and port separately rather than formatting "host:port", which
        // cannot be parsed when `host` is a bare IPv6 literal. Brackets are stripped
        // so callers that already pass "[::1]" keep working.
        let host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);
        let socket = TcpStream::connect((host, port)).await?;
        Ok(Self {
            socket,
            timeout: None,
        })
    }

    /// Connects like [`RconClient::new`] and then sets `timeout` as the
    /// per-operation limit for [`authenticate`](Self::authenticate) and
    /// [`run_command`](Self::run_command).
    ///
    /// The connection attempt itself is not bounded by `timeout`.
    ///
    /// # Errors
    ///
    /// Same as [`RconClient::new`].
    pub async fn with_timeout(host: &str, port: u16, timeout: Duration) -> io::Result<Self> {
        let mut client = Self::new(host, port).await?;
        client.set_timeout(Some(timeout));
        Ok(client)
    }

    /// Sets (or with `None`, clears) the limit applied to each subsequent
    /// `authenticate` / `run_command` call.
    pub fn set_timeout(&mut self, timeout: Option<Duration>) {
        self.timeout = timeout;
    }

    /// Consumes the client and shuts the socket down via `AsyncWriteExt::shutdown`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by the shutdown.
    pub async fn disconnect(mut self) -> io::Result<()> {
        self.socket.shutdown().await
    }

    /// Logs in with `password`, honouring the configured timeout.
    ///
    /// # Errors
    ///
    /// - the error produced by `timeout_err` if a timeout is set and elapses;
    /// - `RconProtocolError::AuthFailed` (as an `io::Error`) when the server answers
    ///   with request id `-1`;
    /// - `RconProtocolError::InvalidPacketType` / `RequestIdMismatch` on an
    ///   unexpected response;
    /// - any underlying I/O or packet-encoding error.
    pub async fn authenticate(&mut self, password: &str) -> io::Result<()> {
        let limit = self.timeout;
        Self::apply_timeout(limit, self.authenticate_raw(password)).await
    }

    /// Runs `command` on the server and returns the (possibly multi-packet)
    /// response text, honouring the configured timeout.
    ///
    /// # Errors
    ///
    /// - the error produced by `timeout_err` if a timeout is set and elapses;
    /// - `RconProtocolError::AuthFailed` when the server answers with request id
    ///   `-1` (e.g. the client has not authenticated);
    /// - `RconProtocolError::RequestIdMismatch` on any other unexpected id;
    /// - any underlying I/O or packet-encoding error.
    pub async fn run_command(&mut self, command: &str) -> io::Result<String> {
        let limit = self.timeout;
        Self::apply_timeout(limit, self.run_command_raw(command)).await
    }

    /// Awaits `fut`, bounded by `limit` when one is set.
    ///
    /// NOTE: on timeout the in-flight exchange is dropped part-way, so the stream
    /// may be left mid-packet and later calls on the same client can read stale data.
    async fn apply_timeout<T>(
        limit: Option<Duration>,
        fut: impl std::future::Future<Output = io::Result<T>>,
    ) -> io::Result<T> {
        match limit {
            None => fut.await,
            // Build the timeout error lazily: only on the elapsed path.
            Some(limit) => timeout(limit, fut)
                .await
                .unwrap_or_else(|_| timeout_err()),
        }
    }

    /// Sends a login packet (request id 1) and validates the server's reply.
    async fn authenticate_raw(&mut self, password: &str) -> io::Result<()> {
        let packet =
            RconPacket::new(1, RconPacketType::Login, password.to_string()).map_err(Error::from)?;
        self.write_packet(packet).await?;
        let packet = self.read_packet().await?;
        // The login reply is expected to carry the same type value as RunCommand.
        if !matches!(packet.packet_type, RconPacketType::RunCommand) {
            return Err(RconProtocolError::InvalidPacketType.into());
        }
        // Request id -1 signals a rejected password; anything else but our id is a mismatch.
        if packet.request_id == -1 {
            return Err(RconProtocolError::AuthFailed.into());
        } else if packet.request_id != 1 {
            return Err(RconProtocolError::RequestIdMismatch.into());
        }
        Ok(())
    }

    /// Sends `command` (request id 1) and concatenates response packets.
    ///
    /// A payload at least `MAX_LEN_CLIENTBOUND` bytes long is treated as a
    /// fragment and another packet is read. NOTE: if the final fragment is exactly
    /// that long, this waits for a packet that never comes, which is bounded only by
    /// the configured timeout.
    async fn run_command_raw(&mut self, command: &str) -> io::Result<String> {
        let packet = RconPacket::new(1, RconPacketType::RunCommand, command.to_string())
            .map_err(Error::from)?;
        self.write_packet(packet).await?;
        let mut full_payload = String::new();
        loop {
            let received = self.read_packet().await?;
            if received.request_id == -1 {
                return Err(RconProtocolError::AuthFailed.into());
            } else if received.request_id != 1 {
                return Err(RconProtocolError::RequestIdMismatch.into());
            }
            full_payload.push_str(&received.payload);
            if received.payload.len() < MAX_LEN_CLIENTBOUND {
                break;
            }
        }
        Ok(full_payload)
    }

    /// Reads one length-prefixed packet and decodes it with `RconPacket::try_from`.
    ///
    /// The bytes handed to the decoder include the 4-byte little-endian length prefix.
    async fn read_packet(&mut self) -> io::Result<RconPacket> {
        let len = self.socket.read_i32_le().await?;
        // A negative length can only come from a corrupt or non-RCON stream.
        let body_len = u64::try_from(len).map_err(|_| {
            Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid RCON packet length: {len}"),
            )
        })?;
        // Read through a length-limited reader: data arrives in chunks instead of
        // one syscall per byte, and the buffer only grows as bytes actually arrive
        // (no up-front allocation sized by an untrusted length prefix).
        let mut body = Vec::new();
        let read = (&mut self.socket)
            .take(body_len)
            .read_to_end(&mut body)
            .await?;
        if read as u64 != body_len {
            // The peer closed the connection before sending the whole packet.
            return Err(io::ErrorKind::UnexpectedEof.into());
        }
        let mut bytes = BytesMut::with_capacity(4 + body.len());
        bytes.put_i32_le(len);
        bytes.put_slice(&body);
        RconPacket::try_from(bytes.freeze()).map_err(Error::from)
    }

    /// Serializes `packet` and writes all of its bytes to the socket.
    async fn write_packet(&mut self, packet: RconPacket) -> io::Result<()> {
        let bytes = packet.bytes();
        self.socket.write_all(&bytes).await
    }
}
#[cfg(test)]
mod tests {
    use super::RconClient;
    use tokio::io;

    /// Host of the RCON server these integration tests expect to be running.
    const HOST: &str = "localhost";
    /// RCON port of that server.
    const PORT: u16 = 25575;

    /// Opens a fresh, unauthenticated client against the test server.
    async fn connect() -> io::Result<RconClient> {
        RconClient::new(HOST, PORT).await
    }

    /// Happy path: log in with the test password and run a command.
    #[tokio::test]
    async fn test_rcon_command() -> io::Result<()> {
        let mut rcon = connect().await?;
        rcon.authenticate("mc-query-test").await?;
        let reply = rcon.run_command("time set day").await?;
        println!("recieved response: {reply}");
        Ok(())
    }

    /// A command sent before authenticating must be rejected.
    #[tokio::test]
    async fn test_rcon_unauthenticated() -> io::Result<()> {
        let mut rcon = connect().await?;
        let outcome = rcon.run_command("time set day").await;
        assert!(outcome.is_err());
        Ok(())
    }

    /// Logging in with the wrong password must fail.
    #[tokio::test]
    async fn test_rcon_incorrect_password() -> io::Result<()> {
        let mut rcon = connect().await?;
        let outcome = rcon.authenticate("incorrect").await;
        assert!(outcome.is_err());
        Ok(())
    }
}