use crate::error::{RconError, Result};
use crate::protocol::{read_packet, write_packet, Packet};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tracing::{debug, info};
/// Timeout used for the TCP connect, the authentication exchange, and each
/// command until a client changes it with `RconClient::set_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// An authenticated RCON session over a single TCP stream.
///
/// Created with [`RconClient::connect`]. All request methods take `&mut self`,
/// so a client drives at most one request/response exchange at a time.
#[derive(Debug)]
pub struct RconClient {
    /// TCP connection to the RCON server.
    stream: TcpStream,
    /// ID assigned to the next outgoing request (starts at 1, never 0 or negative).
    next_id: i32,
    /// Timeout applied to authentication and to `execute`.
    timeout_duration: Duration,
}
impl RconClient {
pub async fn connect(addr: impl AsRef<str>, password: &str) -> Result<Self> {
let addr = addr.as_ref();
info!("Connecting to RCON server at {}", addr);
let stream = timeout(DEFAULT_TIMEOUT, TcpStream::connect(addr))
.await
.map_err(|_| RconError::Timeout(DEFAULT_TIMEOUT.as_millis() as u64))?
.map_err(RconError::ConnectionFailed)?;
let mut client = Self {
stream,
next_id: 1,
timeout_duration: DEFAULT_TIMEOUT,
};
client.authenticate(password).await?;
info!("Successfully connected and authenticated to {}", addr);
Ok(client)
}
pub async fn execute(&mut self, command: &str) -> Result<String> {
self.execute_with_timeout(command, self.timeout_duration)
.await
}
pub async fn execute_with_timeout(
&mut self,
command: &str,
timeout_duration: Duration,
) -> Result<String> {
let id = self.next_request_id();
debug!(id, command, "Executing command");
let result = timeout(timeout_duration, async {
let packet = Packet::command(id, command);
self.send_packet(&packet).await?;
self.receive_packet().await
})
.await
.map_err(|_| RconError::Timeout(timeout_duration.as_millis() as u64))??;
if result.id != id {
return Err(RconError::ProtocolError(format!(
"Response ID mismatch: expected {}, got {}",
id, result.id
)));
}
debug!(
id,
response_len = result.payload.len(),
"Command executed successfully"
);
Ok(result.payload)
}
pub fn set_timeout(&mut self, duration: Duration) {
self.timeout_duration = duration;
debug!(?duration, "Timeout updated");
}
async fn authenticate(&mut self, password: &str) -> Result<()> {
debug!("Authenticating");
let id = self.next_request_id();
let packet = Packet::auth(id, password);
let response = timeout(self.timeout_duration, async {
self.send_packet(&packet).await?;
self.receive_packet().await
})
.await
.map_err(|_| RconError::Timeout(self.timeout_duration.as_millis() as u64))??;
if response.id == -1 {
return Err(RconError::AuthFailed);
}
debug!("Authentication successful");
Ok(())
}
async fn send_packet(&mut self, packet: &Packet) -> Result<()> {
write_packet(&mut self.stream, packet)
.await
.map_err(|e| match e {
RconError::Io(io_err) => RconError::ConnectionLost(io_err),
other => other,
})
}
async fn receive_packet(&mut self) -> Result<Packet> {
read_packet(&mut self.stream).await
}
fn next_request_id(&mut self) -> i32 {
let id = self.next_id;
self.next_id = if id == i32::MAX { 1 } else { id + 1 };
id
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// A request packet as decoded by the mock server.
    struct RecvPacket {
        id: i32,
        packet_type: i32,
        payload: String,
    }

    /// Server end of one accepted test connection, speaking raw
    /// length-prefixed little-endian RCON framing.
    struct MockServer {
        stream: TcpStream,
    }

    impl MockServer {
        /// Reads one framed packet from the client; panics on I/O failure.
        async fn recv(&mut self) -> RecvPacket {
            let mut prefix = [0u8; 4];
            self.stream.read_exact(&mut prefix).await.unwrap();
            let body_len = i32::from_le_bytes(prefix) as usize;

            let mut body = vec![0u8; body_len];
            self.stream.read_exact(&mut body).await.unwrap();

            let le_i32 = |at: usize| {
                i32::from_le_bytes([body[at], body[at + 1], body[at + 2], body[at + 3]])
            };
            // Body layout: id (4) | type (4) | payload | two NUL terminators.
            RecvPacket {
                id: le_i32(0),
                packet_type: le_i32(4),
                payload: String::from_utf8_lossy(&body[8..body_len - 2]).into_owned(),
            }
        }

        /// Writes one framed packet to the client in a single buffer.
        async fn send(&mut self, id: i32, packet_type: i32, payload: &str) {
            let body_len = (4 + 4 + payload.len() + 2) as i32;
            let mut frame = Vec::with_capacity(4 + body_len as usize);
            frame.extend_from_slice(&body_len.to_le_bytes());
            frame.extend_from_slice(&id.to_le_bytes());
            frame.extend_from_slice(&packet_type.to_le_bytes());
            frame.extend_from_slice(payload.as_bytes());
            frame.extend_from_slice(&[0, 0]);
            self.stream.write_all(&frame).await.unwrap();
            self.stream.flush().await.unwrap();
        }

        /// Reads the auth request, answers it as successful, and returns its ID.
        async fn accept_auth(&mut self) -> i32 {
            let req = self.recv().await;
            self.send(req.id, 2, "").await;
            req.id
        }
    }

    /// Binds an ephemeral local port, runs `handler` on the first accepted
    /// connection in a background task, and returns the listening address.
    async fn mock_rcon<F, Fut>(handler: F) -> String
    where
        F: FnOnce(MockServer) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send,
    {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (stream, _peer) = listener.accept().await.unwrap();
            handler(MockServer { stream }).await;
        });
        local.to_string()
    }

    #[tokio::test]
    async fn auth_success() {
        let addr = mock_rcon(|mut server| async move {
            let req = server.recv().await;
            assert_eq!(req.packet_type, 3);
            assert_eq!(req.payload, "secret");
            server.send(req.id, 2, "").await;
        })
        .await;
        let _client = RconClient::connect(&addr, "secret").await.unwrap();
    }

    #[tokio::test]
    async fn auth_failure() {
        let addr = mock_rcon(|mut server| async move {
            let _auth = server.recv().await;
            server.send(-1, 2, "").await;
        })
        .await;
        let err = RconClient::connect(&addr, "wrong").await.unwrap_err();
        assert!(matches!(err, RconError::AuthFailed));
    }

    #[tokio::test]
    async fn execute_returns_payload() {
        let addr = mock_rcon(|mut server| async move {
            server.accept_auth().await;
            let req = server.recv().await;
            assert_eq!(req.packet_type, 2);
            assert_eq!(req.payload, "/version");
            server.send(req.id, 0, "Factorio 2.0.28").await;
        })
        .await;
        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let output = client.execute("/version").await.unwrap();
        assert_eq!(output, "Factorio 2.0.28");
    }

    #[tokio::test]
    async fn execute_empty_response() {
        let addr = mock_rcon(|mut server| async move {
            server.accept_auth().await;
            let req = server.recv().await;
            server.send(req.id, 0, "").await;
        })
        .await;
        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let output = client.execute("/noop").await.unwrap();
        assert_eq!(output, "");
    }

    #[tokio::test]
    async fn execute_timeout() {
        let addr = mock_rcon(|mut server| async move {
            server.accept_auth().await;
            let _slow = server.recv().await;
            tokio::time::sleep(Duration::from_secs(10)).await;
        })
        .await;
        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        client.set_timeout(Duration::from_millis(50));
        let err = client.execute("/slow").await.unwrap_err();
        assert!(matches!(err, RconError::Timeout(_)));
    }

    #[tokio::test]
    async fn connection_lost_on_read() {
        let addr = mock_rcon(|mut server| async move {
            server.accept_auth().await;
            let _cmd = server.recv().await;
            drop(server);
        })
        .await;
        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let err = client.execute("/test").await.unwrap_err();
        assert!(matches!(err, RconError::ConnectionLost(_)));
    }

    #[tokio::test]
    async fn multiple_sequential_commands() {
        let addr = mock_rcon(|mut server| async move {
            server.accept_auth().await;
            for n in 1..=3 {
                let req = server.recv().await;
                server.send(req.id, 0, &format!("response {n}")).await;
            }
        })
        .await;
        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        for n in 1..=3 {
            let output = client.execute(&format!("/cmd{n}")).await.unwrap();
            assert_eq!(output, format!("response {n}"));
        }
    }

    #[tokio::test]
    async fn response_id_mismatch() {
        let addr = mock_rcon(|mut server| async move {
            server.accept_auth().await;
            let req = server.recv().await;
            server.send(req.id + 999, 0, "wrong").await;
        })
        .await;
        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let err = client.execute("/test").await.unwrap_err();
        assert!(matches!(err, RconError::ProtocolError(_)));
    }

    #[tokio::test]
    async fn request_ids_increment() {
        let addr = mock_rcon(|mut server| async move {
            let auth_id = server.accept_auth().await;
            for offset in 1..=2 {
                let req = server.recv().await;
                assert_eq!(req.id, auth_id + offset);
                server.send(req.id, 0, "").await;
            }
        })
        .await;
        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        client.execute("/a").await.unwrap();
        client.execute("/b").await.unwrap();
    }

    #[tokio::test]
    async fn request_id_wraps_at_i32_max() {
        let addr = mock_rcon(|mut server| async move {
            server.accept_auth().await;
            let first = server.recv().await;
            assert_eq!(first.id, i32::MAX);
            server.send(first.id, 0, "ok1").await;
            let second = server.recv().await;
            assert_eq!(second.id, 1);
            server.send(second.id, 0, "ok2").await;
        })
        .await;
        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        client.next_id = i32::MAX;
        assert_eq!(client.execute("/a").await.unwrap(), "ok1");
        assert_eq!(client.execute("/b").await.unwrap(), "ok2");
    }
}