Skip to main content

factorio_rcon/
client.rs

1use crate::error::{RconError, Result};
2use crate::protocol::{read_packet, write_packet, Packet};
3use std::time::Duration;
4use tokio::net::TcpStream;
5use tokio::time::timeout;
6use tracing::{debug, info};
7
/// Time limit applied to the TCP connect and to authentication in
/// `RconClient::connect`; it is also the initial per-command limit used by
/// `RconClient::execute` until changed with `set_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
10
11/// Async RCON client for Factorio
12#[derive(Debug)]
13pub struct RconClient {
14    stream: TcpStream,
15    next_id: i32,
16    timeout_duration: Duration,
17}
18
19impl RconClient {
20    /// Connect to an RCON server and authenticate
21    ///
22    /// Uses the default 5-second timeout for both TCP connection and auth.
23    ///
24    /// # Arguments
25    /// * `addr` - Server address (e.g., "127.0.0.1:27015")
26    /// * `password` - RCON password
27    ///
28    /// # Example
29    /// ```no_run
30    /// # async fn example() -> factorio_rcon::Result<()> {
31    /// use factorio_rcon::RconClient;
32    ///
33    /// let mut client = RconClient::connect("127.0.0.1:27015", "password").await?;
34    /// # Ok(())
35    /// # }
36    /// ```
37    pub async fn connect(addr: impl AsRef<str>, password: &str) -> Result<Self> {
38        let addr = addr.as_ref();
39        info!("Connecting to RCON server at {}", addr);
40
41        let stream = timeout(DEFAULT_TIMEOUT, TcpStream::connect(addr))
42            .await
43            .map_err(|_| RconError::Timeout(DEFAULT_TIMEOUT.as_millis() as u64))?
44            .map_err(RconError::ConnectionFailed)?;
45
46        let mut client = Self {
47            stream,
48            next_id: 1,
49            timeout_duration: DEFAULT_TIMEOUT,
50        };
51
52        client.authenticate(password).await?;
53
54        info!("Successfully connected and authenticated to {}", addr);
55        Ok(client)
56    }
57
58    /// Execute an RCON command and return the response
59    ///
60    /// # Arguments
61    /// * `command` - Command to execute (e.g., "/version" or "/c game.tick")
62    ///
63    /// # Example
64    /// ```no_run
65    /// # async fn example() -> factorio_rcon::Result<()> {
66    /// # use factorio_rcon::RconClient;
67    /// # let mut client = RconClient::connect("127.0.0.1:27015", "password").await?;
68    /// let version = client.execute("/version").await?;
69    /// println!("Server version: {}", version);
70    /// # Ok(())
71    /// # }
72    /// ```
73    pub async fn execute(&mut self, command: &str) -> Result<String> {
74        self.execute_with_timeout(command, self.timeout_duration)
75            .await
76    }
77
78    /// Execute a command with a custom timeout
79    ///
80    /// The timeout covers the entire round-trip (send + receive).
81    ///
82    /// # Arguments
83    /// * `command` - Command to execute
84    /// * `timeout_duration` - Maximum time to wait for the complete operation
85    ///
86    /// # Example
87    /// ```no_run
88    /// # async fn example() -> factorio_rcon::Result<()> {
89    /// # use factorio_rcon::RconClient;
90    /// # use std::time::Duration;
91    /// # let mut client = RconClient::connect("127.0.0.1:27015", "password").await?;
92    /// let result = client.execute_with_timeout(
93    ///     "/c rcon.print(serpent.line(game.surfaces))",
94    ///     Duration::from_secs(10)
95    /// ).await?;
96    /// # Ok(())
97    /// # }
98    /// ```
99    pub async fn execute_with_timeout(
100        &mut self,
101        command: &str,
102        timeout_duration: Duration,
103    ) -> Result<String> {
104        let id = self.next_request_id();
105        debug!(id, command, "Executing command");
106
107        let result = timeout(timeout_duration, async {
108            let packet = Packet::command(id, command);
109            self.send_packet(&packet).await?;
110            self.receive_packet().await
111        })
112        .await
113        .map_err(|_| RconError::Timeout(timeout_duration.as_millis() as u64))??;
114
115        if result.id != id {
116            return Err(RconError::ProtocolError(format!(
117                "Response ID mismatch: expected {}, got {}",
118                id, result.id
119            )));
120        }
121
122        debug!(
123            id,
124            response_len = result.payload.len(),
125            "Command executed successfully"
126        );
127
128        Ok(result.payload)
129    }
130
131    /// Configure the default timeout for operations
132    ///
133    /// # Arguments
134    /// * `duration` - New timeout duration
135    pub fn set_timeout(&mut self, duration: Duration) {
136        self.timeout_duration = duration;
137        debug!(?duration, "Timeout updated");
138    }
139
140    /// Authenticate with the RCON server
141    async fn authenticate(&mut self, password: &str) -> Result<()> {
142        debug!("Authenticating");
143
144        let id = self.next_request_id();
145        let packet = Packet::auth(id, password);
146
147        let response = timeout(self.timeout_duration, async {
148            self.send_packet(&packet).await?;
149            self.receive_packet().await
150        })
151        .await
152        .map_err(|_| RconError::Timeout(self.timeout_duration.as_millis() as u64))??;
153
154        // Server returns ID=-1 on auth failure
155        if response.id == -1 {
156            return Err(RconError::AuthFailed);
157        }
158
159        debug!("Authentication successful");
160        Ok(())
161    }
162
163    /// Send a packet to the server
164    async fn send_packet(&mut self, packet: &Packet) -> Result<()> {
165        write_packet(&mut self.stream, packet)
166            .await
167            .map_err(|e| match e {
168                RconError::Io(io_err) => RconError::ConnectionLost(io_err),
169                other => other,
170            })
171    }
172
173    /// Receive a packet from the server
174    async fn receive_packet(&mut self) -> Result<Packet> {
175        read_packet(&mut self.stream).await
176    }
177
178    /// Get next request ID, wrapping to stay positive and avoid -1
179    fn next_request_id(&mut self) -> i32 {
180        let id = self.next_id;
181        // Wrap to 1 instead of overflowing or hitting -1 (auth failure sentinel)
182        self.next_id = if id == i32::MAX { 1 } else { id + 1 };
183        id
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;

    // --- Mock RCON server ---

    /// A received RCON packet (server-side view).
    struct RecvPacket {
        id: i32,
        packet_type: i32,
        payload: String,
    }

    /// Mock RCON server that reads and writes raw packets.
    /// Intentionally independent of our `Packet` type — tests the wire format.
    struct MockServer {
        stream: TcpStream,
    }

    impl MockServer {
        /// Read one packet sent by the client.
        ///
        /// Panics (failing the server task) on stream errors, on a declared body
        /// length too short for id + type + two NUL bytes, or if the body does not
        /// end with the two NUL terminators.
        async fn recv(&mut self) -> RecvPacket {
            let mut len_buf = [0u8; 4];
            self.stream.read_exact(&mut len_buf).await.unwrap();
            let len = i32::from_le_bytes(len_buf);
            // id (4) + type (4) + payload terminator (1) + empty-string terminator (1)
            assert!(len >= 10, "packet body too short: {len} bytes");
            let len = len as usize;

            let mut body = vec![0u8; len];
            self.stream.read_exact(&mut body).await.unwrap();

            let id = i32::from_le_bytes([body[0], body[1], body[2], body[3]]);
            let packet_type = i32::from_le_bytes([body[4], body[5], body[6], body[7]]);
            assert_eq!(
                &body[len - 2..],
                &[0u8, 0][..],
                "packet must end with two NUL bytes"
            );
            let payload = String::from_utf8_lossy(&body[8..len - 2]).into_owned();

            RecvPacket {
                id,
                packet_type,
                payload,
            }
        }

        /// Write one complete packet (length prefix, id, type, payload, two NULs)
        /// in a single write, then flush.
        async fn send(&mut self, id: i32, packet_type: i32, payload: &str) {
            let payload_bytes = payload.as_bytes();
            let body_len = (4 + 4 + payload_bytes.len() + 2) as i32;

            let mut frame = Vec::with_capacity(4 + body_len as usize);
            frame.extend_from_slice(&body_len.to_le_bytes());
            frame.extend_from_slice(&id.to_le_bytes());
            frame.extend_from_slice(&packet_type.to_le_bytes());
            frame.extend_from_slice(payload_bytes);
            frame.extend_from_slice(&[0, 0]);

            self.stream.write_all(&frame).await.unwrap();
            self.stream.flush().await.unwrap();
        }
    }

    /// Spawn a mock RCON server on an ephemeral localhost port.
    ///
    /// Returns the address to connect to and the server task's handle. The handler
    /// drives the server side of the conversation; pass the handle to
    /// `join_server` so a handler that panicked or never finished fails the test.
    async fn mock_rcon<F, Fut>(handler: F) -> (String, JoinHandle<()>)
    where
        F: FnOnce(MockServer) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send,
    {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            handler(MockServer { stream }).await;
        });
        (addr, server)
    }

    /// Wait for the mock server task, re-raising any panic from its handler.
    async fn join_server(server: JoinHandle<()>) {
        if let Err(err) = server.await {
            if err.is_panic() {
                std::panic::resume_unwind(err.into_panic());
            }
            panic!("mock server task failed: {err}");
        }
    }

    // --- Tests ---

    #[tokio::test]
    async fn auth_success() {
        let (addr, server) = mock_rcon(|mut s| async move {
            let req = s.recv().await;
            assert_eq!(req.packet_type, 3); // SERVERDATA_AUTH
            assert_eq!(req.payload, "secret");
            s.send(req.id, 2, "").await; // SERVERDATA_AUTH_RESPONSE
        })
        .await;

        let _client = RconClient::connect(&addr, "secret").await.unwrap();
        join_server(server).await;
    }

    #[tokio::test]
    async fn auth_failure() {
        let (addr, server) = mock_rcon(|mut s| async move {
            let _req = s.recv().await;
            s.send(-1, 2, "").await; // ID=-1 means auth failed
        })
        .await;

        let err = RconClient::connect(&addr, "wrong").await.unwrap_err();
        assert!(matches!(err, RconError::AuthFailed));
        join_server(server).await;
    }

    #[tokio::test]
    async fn execute_returns_payload() {
        let (addr, server) = mock_rcon(|mut s| async move {
            let req = s.recv().await;
            s.send(req.id, 2, "").await;

            let req = s.recv().await;
            assert_eq!(req.packet_type, 2); // SERVERDATA_EXECCOMMAND
            assert_eq!(req.payload, "/version");
            s.send(req.id, 0, "Factorio 2.0.28").await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let result = client.execute("/version").await.unwrap();
        assert_eq!(result, "Factorio 2.0.28");
        join_server(server).await;
    }

    #[tokio::test]
    async fn execute_empty_response() {
        let (addr, server) = mock_rcon(|mut s| async move {
            let req = s.recv().await;
            s.send(req.id, 2, "").await;

            let req = s.recv().await;
            s.send(req.id, 0, "").await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let result = client.execute("/noop").await.unwrap();
        assert_eq!(result, "");
        join_server(server).await;
    }

    #[tokio::test]
    async fn execute_timeout() {
        // The server never finishes, so its handle is deliberately not joined; the
        // task is dropped when the test runtime shuts down.
        let (addr, _server) = mock_rcon(|mut s| async move {
            let req = s.recv().await;
            s.send(req.id, 2, "").await;

            let _req = s.recv().await;
            // Never respond — hold the connection open indefinitely.
            std::future::pending::<()>().await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        client.set_timeout(Duration::from_millis(50));

        let err = client.execute("/slow").await.unwrap_err();
        assert!(matches!(err, RconError::Timeout(_)));
    }

    #[tokio::test]
    async fn connection_lost_on_read() {
        let (addr, server) = mock_rcon(|mut s| async move {
            let req = s.recv().await;
            s.send(req.id, 2, "").await;

            let _req = s.recv().await;
            drop(s); // Close connection before responding
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let err = client.execute("/test").await.unwrap_err();
        assert!(matches!(err, RconError::ConnectionLost(_)));
        join_server(server).await;
    }

    #[tokio::test]
    async fn multiple_sequential_commands() {
        let (addr, server) = mock_rcon(|mut s| async move {
            let req = s.recv().await;
            s.send(req.id, 2, "").await;

            for i in 1..=3 {
                let req = s.recv().await;
                assert_eq!(req.payload, format!("/cmd{i}"));
                s.send(req.id, 0, &format!("response {i}")).await;
            }
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        for i in 1..=3 {
            let result = client.execute(&format!("/cmd{i}")).await.unwrap();
            assert_eq!(result, format!("response {i}"));
        }
        join_server(server).await;
    }

    #[tokio::test]
    async fn response_id_mismatch() {
        let (addr, server) = mock_rcon(|mut s| async move {
            let req = s.recv().await;
            s.send(req.id, 2, "").await;

            let req = s.recv().await;
            s.send(req.id + 999, 0, "wrong").await; // Wrong ID
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        let err = client.execute("/test").await.unwrap_err();
        assert!(matches!(err, RconError::ProtocolError(_)));
        join_server(server).await;
    }

    #[tokio::test]
    async fn request_ids_increment() {
        let (addr, server) = mock_rcon(|mut s| async move {
            let req = s.recv().await;
            let auth_id = req.id;
            s.send(req.id, 2, "").await;

            // Each command should have an incrementing ID
            let req = s.recv().await;
            assert_eq!(req.id, auth_id + 1);
            s.send(req.id, 0, "").await;

            let req = s.recv().await;
            assert_eq!(req.id, auth_id + 2);
            s.send(req.id, 0, "").await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        client.execute("/a").await.unwrap();
        client.execute("/b").await.unwrap();
        join_server(server).await;
    }

    #[tokio::test]
    async fn request_id_wraps_at_i32_max() {
        let (addr, server) = mock_rcon(|mut s| async move {
            let req = s.recv().await;
            s.send(req.id, 2, "").await;

            // First command should use i32::MAX
            let req = s.recv().await;
            assert_eq!(req.id, i32::MAX);
            s.send(req.id, 0, "ok1").await;

            // Second command should wrap to 1, not overflow or hit -1
            let req = s.recv().await;
            assert_eq!(req.id, 1);
            s.send(req.id, 0, "ok2").await;
        })
        .await;

        let mut client = RconClient::connect(&addr, "pass").await.unwrap();
        client.next_id = i32::MAX;

        assert_eq!(client.execute("/a").await.unwrap(), "ok1");
        assert_eq!(client.execute("/b").await.unwrap(), "ok2");
        join_server(server).await;
    }
}