1use crate::error::{RconError, Result};
2use crate::protocol::{read_packet, write_packet, Packet};
3use std::time::Duration;
4use tokio::net::TcpStream;
5use tokio::time::timeout;
6use tracing::{debug, info};
7
/// Time limit applied to the initial TCP connect and used as the starting
/// per-request timeout of every new [`RconClient`].
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
10
11#[derive(Debug)]
13pub struct RconClient {
14 stream: TcpStream,
15 next_id: i32,
16 timeout_duration: Duration,
17}
18
19impl RconClient {
20 pub async fn connect(addr: impl AsRef<str>, password: &str) -> Result<Self> {
38 let addr = addr.as_ref();
39 info!("Connecting to RCON server at {}", addr);
40
41 let stream = timeout(DEFAULT_TIMEOUT, TcpStream::connect(addr))
42 .await
43 .map_err(|_| RconError::Timeout(DEFAULT_TIMEOUT.as_millis() as u64))?
44 .map_err(RconError::ConnectionFailed)?;
45
46 let mut client = Self {
47 stream,
48 next_id: 1,
49 timeout_duration: DEFAULT_TIMEOUT,
50 };
51
52 client.authenticate(password).await?;
53
54 info!("Successfully connected and authenticated to {}", addr);
55 Ok(client)
56 }
57
58 pub async fn execute(&mut self, command: &str) -> Result<String> {
74 self.execute_with_timeout(command, self.timeout_duration)
75 .await
76 }
77
78 pub async fn execute_with_timeout(
100 &mut self,
101 command: &str,
102 timeout_duration: Duration,
103 ) -> Result<String> {
104 let id = self.next_request_id();
105 debug!(id, command, "Executing command");
106
107 let result = timeout(timeout_duration, async {
108 let packet = Packet::command(id, command);
109 self.send_packet(&packet).await?;
110 self.receive_packet().await
111 })
112 .await
113 .map_err(|_| RconError::Timeout(timeout_duration.as_millis() as u64))??;
114
115 if result.id != id {
116 return Err(RconError::ProtocolError(format!(
117 "Response ID mismatch: expected {}, got {}",
118 id, result.id
119 )));
120 }
121
122 debug!(
123 id,
124 response_len = result.payload.len(),
125 "Command executed successfully"
126 );
127
128 Ok(result.payload)
129 }
130
131 pub fn set_timeout(&mut self, duration: Duration) {
136 self.timeout_duration = duration;
137 debug!(?duration, "Timeout updated");
138 }
139
140 async fn authenticate(&mut self, password: &str) -> Result<()> {
142 debug!("Authenticating");
143
144 let id = self.next_request_id();
145 let packet = Packet::auth(id, password);
146
147 let response = timeout(self.timeout_duration, async {
148 self.send_packet(&packet).await?;
149 self.receive_packet().await
150 })
151 .await
152 .map_err(|_| RconError::Timeout(self.timeout_duration.as_millis() as u64))??;
153
154 if response.id == -1 {
156 return Err(RconError::AuthFailed);
157 }
158
159 debug!("Authentication successful");
160 Ok(())
161 }
162
163 async fn send_packet(&mut self, packet: &Packet) -> Result<()> {
165 write_packet(&mut self.stream, packet)
166 .await
167 .map_err(|e| match e {
168 RconError::Io(io_err) => RconError::ConnectionLost(io_err),
169 other => other,
170 })
171 }
172
173 async fn receive_packet(&mut self) -> Result<Packet> {
175 read_packet(&mut self.stream).await
176 }
177
178 fn next_request_id(&mut self) -> i32 {
180 let id = self.next_id;
181 self.next_id = if id == i32::MAX { 1 } else { id + 1 };
183 id
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    // Packet type codes exchanged with the client in these tests.
    /// Type the client is expected to use for its auth request.
    const TYPE_AUTH: i32 = 3;
    /// Type the mock server uses when answering an auth request.
    const TYPE_AUTH_RESPONSE: i32 = 2;
    /// Type the client is expected to use for a command request.
    const TYPE_EXEC_COMMAND: i32 = 2;
    /// Type the mock server uses when answering a command.
    const TYPE_RESPONSE_VALUE: i32 = 0;

    /// A request decoded by the mock server.
    struct RecvPacket {
        id: i32,
        packet_type: i32,
        payload: String,
    }

    /// Server side of one accepted connection, speaking raw packet frames.
    struct MockServer {
        stream: TcpStream,
    }

    impl MockServer {
        /// Reads one frame: a little-endian `i32` length, then `id`, `type`,
        /// the payload and two trailing bytes. Panics on any I/O error or on
        /// a frame too short to hold the fixed fields.
        async fn recv(&mut self) -> RecvPacket {
            let mut size = [0u8; 4];
            self.stream.read_exact(&mut size).await.unwrap();
            let len = i32::from_le_bytes(size) as usize;

            let mut body = vec![0u8; len];
            self.stream.read_exact(&mut body).await.unwrap();

            // Decodes the little-endian i32 that starts at byte offset `at`.
            let le_i32 = |at: usize| {
                i32::from_le_bytes([body[at], body[at + 1], body[at + 2], body[at + 3]])
            };

            let id = le_i32(0);
            let packet_type = le_i32(4);
            // Everything between the two header fields and the two trailing bytes.
            let payload = String::from_utf8_lossy(&body[8..len - 2]).into_owned();

            RecvPacket {
                id,
                packet_type,
                payload,
            }
        }

        /// Writes one frame with the given header fields and payload,
        /// followed by two zero bytes, then flushes. Panics on I/O errors.
        async fn send(&mut self, id: i32, packet_type: i32, payload: &str) {
            let text = payload.as_bytes();
            // id (4) + type (4) + payload + two terminating zero bytes.
            let body_len = (4 + 4 + text.len() + 2) as i32;

            let len_bytes = body_len.to_le_bytes();
            let id_bytes = id.to_le_bytes();
            let type_bytes = packet_type.to_le_bytes();
            let terminator = [0u8, 0u8];

            let pieces: [&[u8]; 5] = [
                &len_bytes[..],
                &id_bytes[..],
                &type_bytes[..],
                text,
                &terminator[..],
            ];
            for piece in pieces {
                self.stream.write_all(piece).await.unwrap();
            }
            self.stream.flush().await.unwrap();
        }

        /// Reads the next request and answers it as a successful auth reply
        /// (same ID, empty payload). Returns the ID that was echoed.
        async fn accept_auth(&mut self) -> i32 {
            let req = self.recv().await;
            self.send(req.id, TYPE_AUTH_RESPONSE, "").await;
            req.id
        }
    }

    /// Binds a listener on an ephemeral localhost port, spawns a task that
    /// accepts exactly one connection and hands it to `handler`, and returns
    /// the address to connect to.
    async fn mock_rcon<F, Fut>(handler: F) -> String
    where
        F: FnOnce(MockServer) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send,
    {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();

        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let server = MockServer { stream };
            handler(server).await;
        });

        addr
    }

    #[tokio::test]
    async fn auth_success() {
        let addr = mock_rcon(|mut s| async move {
            let req = s.recv().await;
            assert_eq!(req.packet_type, TYPE_AUTH);
            assert_eq!(req.payload, "secret");
            s.send(req.id, TYPE_AUTH_RESPONSE, "").await;
        })
        .await;

        let _client = RconClient::connect(&addr, "secret").await.unwrap();
    }

    #[tokio::test]
    async fn auth_failure() {
        let addr = mock_rcon(|mut s| async move {
            let _req = s.recv().await;
            // ID -1 is how the server signals a rejected password.
            s.send(-1, TYPE_AUTH_RESPONSE, "").await;
        })
        .await;

        let err = RconClient::connect(&addr, "wrong").await.unwrap_err();
        assert!(matches!(err, RconError::AuthFailed));
    }

    #[tokio::test]
    async fn execute_returns_payload() {
        let addr = mock_rcon(|mut s| async move {
            s.accept_auth().await;

            let req = s.recv().await;
            assert_eq!(req.packet_type, TYPE_EXEC_COMMAND);
            assert_eq!(req.payload, "/version");
            s.send(req.id, TYPE_RESPONSE_VALUE, "Factorio 2.0.28").await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let result = client.execute("/version").await.unwrap();
        assert_eq!(result, "Factorio 2.0.28");
    }

    #[tokio::test]
    async fn execute_empty_response() {
        let addr = mock_rcon(|mut s| async move {
            s.accept_auth().await;

            let req = s.recv().await;
            s.send(req.id, TYPE_RESPONSE_VALUE, "").await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let result = client.execute("/noop").await.unwrap();
        assert_eq!(result, "");
    }

    #[tokio::test]
    async fn execute_timeout() {
        let addr = mock_rcon(|mut s| async move {
            s.accept_auth().await;

            // Swallow the command and never answer within the client's limit.
            let _req = s.recv().await;
            tokio::time::sleep(Duration::from_secs(10)).await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        client.set_timeout(Duration::from_millis(50));

        let err = client.execute("/slow").await.unwrap_err();
        assert!(matches!(err, RconError::Timeout(_)));
    }

    #[tokio::test]
    async fn connection_lost_on_read() {
        let addr = mock_rcon(|mut s| async move {
            s.accept_auth().await;

            // Close the socket after receiving the command, before replying.
            let _req = s.recv().await;
            drop(s);
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let err = client.execute("/test").await.unwrap_err();
        assert!(matches!(err, RconError::ConnectionLost(_)));
    }

    #[tokio::test]
    async fn multiple_sequential_commands() {
        let addr = mock_rcon(|mut s| async move {
            s.accept_auth().await;

            for n in 1..=3 {
                let req = s.recv().await;
                let reply = format!("response {n}");
                s.send(req.id, TYPE_RESPONSE_VALUE, &reply).await;
            }
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        for n in 1..=3 {
            let command = format!("/cmd{n}");
            let result = client.execute(&command).await.unwrap();
            assert_eq!(result, format!("response {n}"));
        }
    }

    #[tokio::test]
    async fn response_id_mismatch() {
        let addr = mock_rcon(|mut s| async move {
            s.accept_auth().await;

            let req = s.recv().await;
            // Deliberately answer with an ID the client never sent.
            s.send(req.id + 999, TYPE_RESPONSE_VALUE, "wrong").await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let err = client.execute("/test").await.unwrap_err();
        assert!(matches!(err, RconError::ProtocolError(_)));
    }

    #[tokio::test]
    async fn request_ids_increment() {
        let addr = mock_rcon(|mut s| async move {
            let auth_id = s.accept_auth().await;

            let first = s.recv().await;
            assert_eq!(first.id, auth_id + 1);
            s.send(first.id, TYPE_RESPONSE_VALUE, "").await;

            let second = s.recv().await;
            assert_eq!(second.id, auth_id + 2);
            s.send(second.id, TYPE_RESPONSE_VALUE, "").await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        client.execute("/a").await.unwrap();
        client.execute("/b").await.unwrap();
    }

    #[tokio::test]
    async fn request_id_wraps_at_i32_max() {
        let addr = mock_rcon(|mut s| async move {
            s.accept_auth().await;

            let first = s.recv().await;
            assert_eq!(first.id, i32::MAX);
            s.send(first.id, TYPE_RESPONSE_VALUE, "ok1").await;

            // The counter must restart at 1 rather than overflow.
            let second = s.recv().await;
            assert_eq!(second.id, 1);
            s.send(second.id, TYPE_RESPONSE_VALUE, "ok2").await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        // Jump straight to the last ID before the wrap.
        client.next_id = i32::MAX;

        assert_eq!(client.execute("/a").await.unwrap(), "ok1");
        assert_eq!(client.execute("/b").await.unwrap(), "ok2");
    }
}