a2s/
lib.rs

1pub mod errors;
2pub mod info;
3pub mod players;
4pub mod rules;
5
6use std::io::{Cursor, Read, Write};
7#[cfg(not(feature = "async"))]
8use std::net::{ToSocketAddrs, UdpSocket};
9use std::ops::Deref;
10use std::time::Duration;
11
12#[cfg(feature = "async")]
13use tokio::net::{ToSocketAddrs, UdpSocket};
14#[cfg(feature = "async")]
15use tokio::time;
16
17use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
18use bzip2::read::BzDecoder;
19use crc::crc32;
20
21use crate::errors::{Error, Result};
22
/// Header value (first 4 bytes) of a response carried in one datagram.
const SINGLE_PACKET: i32 = -1;
/// Header value (first 4 bytes) of a response split across several datagrams.
const MULTI_PACKET: i32 = -2;

// Byte offsets into a received datagram. Each field directly follows the
// previous one, so offsets are written relative to their predecessor.

/// 4-byte header (`SINGLE_PACKET` or `MULTI_PACKET`).
const OFS_HEADER: usize = 0;
/// Payload of a single-packet response, right after the header.
const OFS_SP_PAYLOAD: usize = OFS_HEADER + 4;
/// 4-byte response ID of a split datagram.
const OFS_MP_ID: usize = OFS_HEADER + 4;
/// 1-byte total number of datagrams in a split response.
const OFS_MP_SS_TOTAL: usize = OFS_MP_ID + 4;
/// 1-byte number of this datagram within the split response.
const OFS_MP_SS_NUMBER: usize = OFS_MP_SS_TOTAL + 1;
/// 2-byte switching size of a split response.
const OFS_MP_SS_SIZE: usize = OFS_MP_SS_NUMBER + 1;
/// 4-byte decompressed size of a bzip2-compressed split response.
const OFS_MP_SS_BZ2_SIZE: usize = OFS_MP_SS_SIZE + 2;
/// 4-byte CRC32 of a bzip2-compressed split response.
const OFS_MP_SS_BZ2_CRC: usize = OFS_MP_SS_BZ2_SIZE + 4;
/// Payload of an uncompressed split datagram (same position as the bz2 size field).
const OFS_MP_SS_PAYLOAD: usize = OFS_MP_SS_BZ2_SIZE;
/// Payload of a compressed split datagram that carries the size/CRC fields.
const OFS_MP_SS_PAYLOAD_BZ2: usize = OFS_MP_SS_BZ2_CRC + 4;
37
/// Decodes a little-endian integer of the named type from `$buf`, starting at
/// byte `$offset`, e.g. `read_buffer_offset!(&data, OFS_MP_ID, i32)`.
///
/// `$buf` and `$offset` are each evaluated exactly once. Indexing panics if
/// `$buf` holds fewer than `$offset + size_of::<type>()` bytes, so callers must
/// check the length first.
macro_rules! read_buffer_offset {
    // Internal rule: copy `$len` bytes starting at `$offset` and decode them as
    // a little-endian `$ty`.
    (@le $buf:expr, $offset:expr, $ty:ty, $len:expr) => {{
        let buf = &$buf;
        let offset: usize = $offset;
        let mut bytes = [0u8; $len];
        bytes.copy_from_slice(&buf[offset..offset + $len]);
        <$ty>::from_le_bytes(bytes)
    }};
    ($buf:expr, $offset:expr, i8) => {
        // Previously `u8.into()`, which could never produce a sign-interpreted i8.
        read_buffer_offset!(@le $buf, $offset, i8, 1)
    };
    ($buf:expr, $offset:expr, u8) => {
        read_buffer_offset!(@le $buf, $offset, u8, 1)
    };
    ($buf:expr, $offset:expr, i16) => {
        read_buffer_offset!(@le $buf, $offset, i16, 2)
    };
    ($buf:expr, $offset:expr, u16) => {
        read_buffer_offset!(@le $buf, $offset, u16, 2)
    };
    ($buf:expr, $offset:expr, i32) => {
        read_buffer_offset!(@le $buf, $offset, i32, 4)
    };
    ($buf:expr, $offset:expr, u32) => {
        read_buffer_offset!(@le $buf, $offset, u32, 4)
    };
    ($buf:expr, $offset:expr, i64) => {
        read_buffer_offset!(@le $buf, $offset, i64, 8)
    };
    ($buf:expr, $offset:expr, u64) => {
        read_buffer_offset!(@le $buf, $offset, u64, 8)
    };
}
92
/// One datagram's worth of a split (multi-packet) response.
#[derive(Debug)]
struct PacketFragment {
    /// Fragment number read from the datagram; used to restore the order.
    number: u8,
    /// Bytes from this datagram that belong to the reassembled response.
    payload: Vec<u8>,
}
98
/// UDP client that sends A2S queries and collects their (possibly split)
/// responses.
///
/// Configure it with the builder-style setters `max_size`, `app_id` and
/// `set_timeout`.
#[derive(Debug)]
pub struct A2SClient {
    /// Socket bound to an ephemeral local port when the client is created.
    socket: UdpSocket,
    /// Timeout applied to each send/receive future (async builds only; the
    /// blocking build configures timeouts on the socket itself instead).
    #[cfg(feature = "async")]
    timeout: Duration,
    /// Receive-buffer size, in bytes, for the first datagram of a response.
    max_size: usize,
    /// App ID stored on the client; not read in this file — presumably used by
    /// the query modules (TODO confirm).
    app_id: u16,
}
106
/// Awaits `$future` for at most `$timeout`.
///
/// Evaluates to the future's output. If the deadline passes first, it returns
/// `Err(Error::ErrTimeout)` from the *enclosing* function.
#[cfg(feature = "async")]
macro_rules! future_timeout {
    ($timeout:expr, $future:expr) => {
        match time::timeout($timeout, $future).await {
            Err(_elapsed) => return Err(Error::ErrTimeout),
            Ok(output) => output,
        }
    };
}
116
117impl A2SClient {
118    #[cfg(not(feature = "async"))]
119    pub fn new() -> Result<A2SClient> {
120        let socket = UdpSocket::bind("0.0.0.0:0")?;
121        let timeout = Duration::new(5, 0);
122
123        socket.set_read_timeout(Some(timeout))?;
124        socket.set_write_timeout(Some(timeout))?;
125
126        Ok(A2SClient {
127            socket,
128            max_size: 1400,
129            app_id: 0,
130        })
131    }
132
133    #[cfg(feature = "async")]
134    pub async fn new() -> Result<A2SClient> {
135        Ok(A2SClient {
136            socket : UdpSocket::bind("0.0.0.0:0").await?,
137            timeout: Duration::new(5, 0),
138            max_size: 1400,
139            app_id: 0,
140        })
141    }
142
143    pub fn max_size(&mut self, size: usize) -> &mut Self {
144        self.max_size = size;
145        self
146    }
147
148    pub fn app_id(&mut self, app_id: u16) -> &mut Self {
149        self.app_id = app_id;
150        self
151    }
152
153    #[cfg(not(feature = "async"))]
154    pub fn set_timeout(&mut self, timeout : Duration) -> Result<&mut Self> {
155        if timeout == Duration::ZERO {return Err(Error::Other("attempted to set timeout to 0"));}
156        self.socket.set_read_timeout(Some(timeout))?;
157        self.socket.set_write_timeout(Some(timeout))?;
158        Ok(self)
159    }
160
161    #[cfg(feature = "async")]
162    pub fn set_timeout(&mut self, timeout : Duration) -> Result<&mut Self> {
163        if timeout == Duration::ZERO {return Err(Error::Other("attempted to set timeout to 0"));}
164        self.timeout = timeout;
165        Ok(self)
166    }
167
168    #[cfg(feature = "async")]
169    async fn send<A: ToSocketAddrs>(&self, payload: &[u8], addr: A) -> Result<Vec<u8>> {
170        future_timeout!(self.timeout, self.socket.send_to(payload, addr))?;
171
172        let mut data = vec![0; self.max_size];
173
174        let read = future_timeout!(self.timeout, self.socket.recv(&mut data))?;
175        data.truncate(read);
176
177        // Header is a long (4 bytes)
178        if data.len() < 4 {
179            return Err(Error::InvalidResponse);
180        }
181
182        let header = read_buffer_offset!(&data, OFS_HEADER, i32);
183
184        if header == SINGLE_PACKET {
185            Ok(data[OFS_SP_PAYLOAD..].to_vec())
186        } else if header == MULTI_PACKET {
187            // ID - long (4 bytes)
188            // Total - byte (1 byte)
189            // Number - byte (1 byte)
190            // Size - short (2 bytes)
191
192            let id = read_buffer_offset!(&data, OFS_MP_ID, i32);
193            let total_packets: usize = data[OFS_MP_SS_TOTAL].into();
194            let switching_size: usize = read_buffer_offset!(&data, OFS_MP_SS_SIZE, u16).into();
195
196            // Sanity check
197            if (switching_size > self.max_size) || (total_packets > 32) {
198                return Err(Error::InvalidResponse);
199            }
200
201            let mut packets: Vec<PacketFragment> = Vec::with_capacity(0);
202            packets.try_reserve(total_packets)?;
203            packets.push(PacketFragment {
204                number: data[OFS_MP_SS_NUMBER],
205                // The first packet seems to include a single packet header (0xFFFFFFFF) for some
206                // reason, so we'd rather skip that (hence +4)
207                payload: Vec::from(&data[OFS_MP_SS_PAYLOAD + 4..]),
208            });
209
210            loop {
211                let mut data: Vec<u8> = Vec::with_capacity(0);
212                data.try_reserve(switching_size)?;
213                data.resize(switching_size, 0);
214
215                let read = future_timeout!(self.timeout, self.socket.recv(&mut data))?;
216                data.truncate(read);
217
218                if data.len() <= 9 {
219                    Err(Error::InvalidResponse)?
220                }
221
222                let packet_id = read_buffer_offset!(&data, OFS_MP_ID, i32);
223
224                if packet_id != id {
225                    return Err(Error::MismatchID);
226                }
227
228                if id as u32 & 0x80000000 == 0 {
229                    // Uncompressed packet
230                    packets.push(PacketFragment {
231                        number: data[OFS_MP_SS_NUMBER],
232                        payload: Vec::from(&data[OFS_MP_SS_PAYLOAD..]),
233                    });
234                } else {
235                    // BZip2 compressed packet
236                    packets.push(PacketFragment {
237                        number: data[OFS_MP_SS_NUMBER],
238                        payload: Vec::from(&data[OFS_MP_SS_PAYLOAD_BZ2..]),
239                    });
240                }
241
242                if packets.len() == total_packets {
243                    break;
244                }
245            }
246
247            packets.sort_by_key(|p| p.number);
248
249            let mut aggregation = Vec::with_capacity(0);
250            aggregation.try_reserve(total_packets * self.max_size)?;
251
252            for p in packets {
253                aggregation.extend(p.payload);
254            }
255
256            if id as u32 & 0x80000000 != 0 {
257                let decompressed_size = read_buffer_offset!(&data, OFS_MP_SS_BZ2_SIZE, u32);
258                let checksum = read_buffer_offset!(&data, OFS_MP_SS_BZ2_CRC, u32);
259
260                if decompressed_size > (1024 * 1024) {
261                    return Err(Error::InvalidBz2Size);
262                }
263
264                let mut decompressed = Vec::with_capacity(0);
265                decompressed.try_reserve(decompressed_size as usize)?;
266                decompressed.resize(decompressed_size as usize, 0);
267
268                BzDecoder::new(aggregation.deref()).read_exact(&mut decompressed)?;
269
270                if crc32::checksum_ieee(&decompressed) != checksum {
271                    return Err(Error::CheckSumMismatch);
272                }
273
274                Ok(decompressed)
275            } else {
276                Ok(aggregation)
277            }
278        } else {
279            Err(Error::InvalidResponse)
280        }
281    }
282
283    #[cfg(feature = "async")]
284    async fn do_challenge_request<A: ToSocketAddrs>(
285        &self,
286        addr: A,
287        header: &[u8],
288    ) -> Result<Vec<u8>> {
289        let packet = Vec::with_capacity(9);
290        let mut packet = Cursor::new(packet);
291
292        packet.write_all(header)?;
293        packet.write_i32::<LittleEndian>(-1)?;
294
295        let data = self.send(packet.get_ref(), &addr).await?;
296        let mut data = Cursor::new(data);
297
298        let header = data.read_u8()?;
299        if header != 'A' as u8 {
300            return Err(Error::InvalidResponse);
301        }
302
303        let challenge = data.read_i32::<LittleEndian>()?;
304
305        packet.set_position(5);
306        packet.write_i32::<LittleEndian>(challenge)?;
307        let data = self.send(packet.get_ref(), &addr).await?;
308
309        Ok(data)
310    }
311
312    #[cfg(not(feature = "async"))]
313    fn send<A: ToSocketAddrs>(&self, payload: &[u8], addr: A) -> Result<Vec<u8>> {
314        self.socket.send_to(payload, addr)?;
315
316        let mut data = vec![0; self.max_size];
317
318        let read = self.socket.recv(&mut data)?;
319        data.truncate(read);
320
321        let header = read_buffer_offset!(&data, OFS_HEADER, i32);
322
323        if header == SINGLE_PACKET {
324            Ok(data[OFS_SP_PAYLOAD..].to_vec())
325        } else if header == MULTI_PACKET {
326            // ID - long (4 bytes)
327            // Total - byte (1 byte)
328            // Number - byte (1 byte)
329            // Size - short (2 bytes)
330
331            let id = read_buffer_offset!(&data, OFS_MP_ID, i32);
332            let total_packets: usize = data[OFS_MP_SS_TOTAL].into();
333            let switching_size: usize = read_buffer_offset!(&data, OFS_MP_SS_SIZE, u16).into();
334
335            // Sanity check
336            if (switching_size > self.max_size) || (total_packets > 32) {
337                return Err(Error::InvalidResponse);
338            }
339
340            let mut packets: Vec<PacketFragment> = Vec::with_capacity(0);
341            packets.try_reserve(total_packets)?;
342            packets.push(PacketFragment {
343                number: data[OFS_MP_SS_NUMBER],
344                // The first packet seems to include a single packet header (0xFFFFFFFF) for some
345                // reason, so we'd rather skip that (hence +4)
346                payload: Vec::from(&data[OFS_MP_SS_PAYLOAD + 4..]),
347            });
348
349            loop {
350                let mut data: Vec<u8> = Vec::with_capacity(0);
351                data.try_reserve(switching_size)?;
352                data.resize(switching_size, 0);
353
354                let read = self.socket.recv(&mut data)?;
355                data.truncate(read);
356
357                if data.len() <= 9 {
358                    Err(Error::InvalidResponse)?
359                }
360
361                let packet_id = read_buffer_offset!(&data, OFS_MP_ID, i32);
362
363                if packet_id != id {
364                    return Err(Error::MismatchID);
365                }
366
367                if id as u32 & 0x80000000 == 0 {
368                    // Uncompressed packet
369                    packets.push(PacketFragment {
370                        number: data[OFS_MP_SS_NUMBER],
371                        payload: Vec::from(&data[OFS_MP_SS_PAYLOAD..]),
372                    });
373                } else {
374                    // BZip2 compressed packet
375                    packets.push(PacketFragment {
376                        number: data[OFS_MP_SS_NUMBER],
377                        payload: Vec::from(&data[OFS_MP_SS_PAYLOAD_BZ2..]),
378                    });
379                }
380
381                if packets.len() == total_packets {
382                    break;
383                }
384            }
385
386            packets.sort_by_key(|p| p.number);
387
388            let mut aggregation = Vec::with_capacity(0);
389            aggregation.try_reserve(total_packets * self.max_size)?;
390
391            for p in packets {
392                aggregation.extend(p.payload);
393            }
394
395            if id as u32 & 0x80000000 != 0 {
396                let decompressed_size = read_buffer_offset!(&data, OFS_MP_SS_BZ2_SIZE, u32);
397                let checksum = read_buffer_offset!(&data, OFS_MP_SS_BZ2_CRC, u32);
398
399                if decompressed_size > (1024 * 1024) {
400                    return Err(Error::InvalidBz2Size);
401                }
402
403                let mut decompressed = Vec::with_capacity(0);
404                decompressed.try_reserve(decompressed_size as usize)?;
405                decompressed.resize(decompressed_size as usize, 0);
406
407                BzDecoder::new(aggregation.deref()).read_exact(&mut decompressed)?;
408
409                if crc32::checksum_ieee(&decompressed) != checksum {
410                    return Err(Error::CheckSumMismatch);
411                }
412
413                Ok(decompressed)
414            } else {
415                Ok(aggregation)
416            }
417        } else {
418            Err(Error::InvalidResponse)
419        }
420    }
421
422    #[cfg(not(feature = "async"))]
423    fn do_challenge_request<A: ToSocketAddrs>(&self, addr: A, header: &[u8]) -> Result<Vec<u8>> {
424        let packet = Vec::with_capacity(9);
425        let mut packet = Cursor::new(packet);
426
427        packet.write_all(header)?;
428        packet.write_i32::<LittleEndian>(-1)?;
429
430        let data = self.send(packet.get_ref(), &addr)?;
431        let mut data = Cursor::new(data);
432
433        let header = data.read_u8()?;
434        if header != b'A' {
435            return Err(Error::InvalidResponse);
436        }
437
438        let challenge = data.read_i32::<LittleEndian>()?;
439
440        packet.set_position(5);
441        packet.write_i32::<LittleEndian>(challenge)?;
442        let data = self.send(packet.get_ref(), &addr)?;
443
444        Ok(data)
445    }
446}
447
448trait ReadCString {
449    fn read_cstring(&mut self) -> Result<String>;
450}
451
452impl ReadCString for Cursor<Vec<u8>> {
453    fn read_cstring(&mut self) -> Result<String> {
454        let end = self.get_ref().len() as u64;
455        let mut buf = [0; 1];
456        let mut str_vec = Vec::with_capacity(256);
457        while self.position() < end {
458            self.read_exact(&mut buf)?;
459            if buf[0] == 0 {
460                break;
461            } else {
462                str_vec.push(buf[0]);
463            }
464        }
465        Ok(String::from_utf8_lossy(&str_vec[..]).into_owned())
466    }
467}