faf_replay_parser/
reader.rs

1use byteorder::{LittleEndian, ReadBytesExt};
2use ordered_hash_map::OrderedHashMap;
3
4use super::lua::*;
5
6use std::error::Error;
7use std::ffi::CString;
8use std::fmt;
9use std::io;
10use std::string::FromUtf8Error;
11
/// Everything that can go wrong while reading replay data.
#[derive(Debug)]
pub enum ReplayReadError {
    /// A desync was detected; the payload is the tick reported alongside it.
    Desynced(u32),
    /// The data did not have the expected structure; the payload describes the problem.
    Malformed(&'static str),
    /// Bytes that were required to be UTF-8 failed to decode.
    MalformedUtf8(FromUtf8Error),
    /// The underlying reader reported an error.
    IO(io::Error),
}

/// Shorthand for a [`Result`] whose error type is [`ReplayReadError`].
pub type ReplayResult<T> = Result<T, ReplayReadError>;
20
21impl Error for ReplayReadError {
22    fn source(&self) -> Option<&(dyn Error + 'static)> {
23        match self {
24            Self::IO(e) => Some(e),
25            Self::MalformedUtf8(e) => Some(e),
26            _ => None,
27        }
28    }
29}
30
31impl fmt::Display for ReplayReadError {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        use ReplayReadError::*;
34
35        match self {
36            Desynced(tick) => write!(f, "desynced at tick {}", tick),
37            Malformed(msg) => write!(f, "{}", msg),
38            MalformedUtf8(e) => write!(f, "{}", e),
39            IO(e) => write!(f, "{}", e),
40        }
41    }
42}
43
44impl From<FromUtf8Error> for ReplayReadError {
45    fn from(err: FromUtf8Error) -> ReplayReadError {
46        ReplayReadError::MalformedUtf8(err)
47    }
48}
49
50impl From<std::io::Error> for ReplayReadError {
51    fn from(err: std::io::Error) -> ReplayReadError {
52        ReplayReadError::IO(err)
53    }
54}
55
56impl From<LuaTypeError> for ReplayReadError {
57    fn from(_err: LuaTypeError) -> ReplayReadError {
58        ReplayReadError::Malformed("invalud lua type")
59    }
60}
61
62/// Extensions to [`std::io::Read`]
63pub trait ReplayReadExt: io::Read {
64    /// Convenience wrapper around [`std::io::Read::read_to_end`] for allocating a new `Vec`.
65    fn vec_read_to_end(&mut self) -> io::Result<Vec<u8>> {
66        let mut res = Vec::new();
67        self.read_to_end(&mut res)?;
68        Ok(res)
69    }
70
71    /// Convenience wrapper around [`std::io::Read::read_exact`] for allocating a new `Vec`.
72    fn vec_read_exact(&mut self, n: usize) -> io::Result<Vec<u8>> {
73        // Don't pre-allocate more than 1 MiB.
74        let capacity = std::cmp::min(n, crate::MIB);
75        // TODO: Use #![feature(read_initializer)] to skip initialization here. Since read requries
76        // that the buffer be initialized, we need to write 0's to it for now.
77        let mut vec = vec![0; capacity];
78        self.read_exact(&mut vec)?;
79
80        Ok(vec)
81    }
82
83    /// Convenience wrapper around [`std::io::Read::read_exact`] for reading data into a Vec,
84    /// resizing it if more space is needed.
85    fn read_exact_to_vec(&mut self, n: usize, buf: &mut Vec<u8>) -> io::Result<()> {
86        let mut remaining = n;
87        // Read in 1 MiB chunks in case the requested size is really big.
88        while remaining > 0 {
89            let chunk = std::cmp::min(remaining, crate::MIB);
90            let index = n - remaining;
91
92            buf.resize(chunk, 0);
93            self.read_exact(&mut buf[index..index + chunk])?;
94            remaining -= chunk;
95        }
96        Ok(())
97    }
98
99    /// Ignore the next `n` bytes.
100    fn skip(&mut self, n: usize) -> io::Result<()> {
101        self.vec_read_exact(n)?;
102        Ok(())
103    }
104
105    /// Ignore the next `n` bytes by reading them into the buffer. This may perform multiple calls
106    /// to `read` so make sure the buffer is sufficiently large.
107    ///
108    /// If frequent skips are needed, this can be more efficient than calling `skip` as the buffer
109    /// can be reused.
110    fn skip_buf(&mut self, n: usize, buf: &mut [u8]) -> io::Result<()> {
111        let mut remaining = n;
112        while remaining > 0 {
113            let to_read = std::cmp::min(remaining, buf.len());
114            unsafe {
115                match self.read(buf.get_unchecked_mut(..to_read)) {
116                    Ok(0) => {
117                        return Err(io::Error::new(
118                            io::ErrorKind::UnexpectedEof,
119                            "failed to skip desired bytes",
120                        ));
121                    }
122                    Ok(n) => remaining -= n,
123                    Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
124                    Err(e) => return Err(e),
125                }
126            }
127        }
128        Ok(())
129    }
130
131    fn read_i32_le(&mut self) -> io::Result<i32> {
132        self.read_i32::<LittleEndian>()
133    }
134
135    fn read_u32_le(&mut self) -> io::Result<u32> {
136        self.read_u32::<LittleEndian>()
137    }
138
139    fn read_u16_le(&mut self) -> io::Result<u16> {
140        self.read_u16::<LittleEndian>()
141    }
142
143    fn read_i16_le(&mut self) -> io::Result<i16> {
144        self.read_i16::<LittleEndian>()
145    }
146
147    fn read_f32_le(&mut self) -> io::Result<f32> {
148        self.read_f32::<LittleEndian>()
149    }
150
151    /// Re-export of `ReadBytesExt::read_u8`
152    fn read_u8(&mut self) -> io::Result<u8> {
153        ReadBytesExt::read_u8(self)
154    }
155
156    fn read_bool(&mut self) -> io::Result<bool> {
157        Ok(self.read_u8()? == 1)
158    }
159}
160
161/// Implement `ReplayReadExt` for all types that implement `Read`.
162impl<R: io::Read + ?Sized> ReplayReadExt for R {}
163
164/// Extensions to `std::io::BufRead`
165pub trait ReplayBufReadExt: io::BufRead + ReplayReadExt {
166    /// Convenience wrapper around `std::io::BufRead::read_until` for allocating a new `Vec`.
167    fn vec_read_until(&mut self, byte: u8) -> io::Result<Vec<u8>> {
168        let mut vec = Vec::new();
169        self.read_until(byte, &mut vec)?;
170
171        Ok(vec)
172    }
173
174    /// Read a null terminated string from the stream into a new `Vec`.
175    fn vec_read_null_term(&mut self) -> io::Result<Vec<u8>> {
176        let mut vec = Vec::new();
177        self.read_until(0x00, &mut vec)?;
178        vec.pop();
179
180        Ok(vec)
181    }
182
183    /// Like `vec_read_null_term` but initialize the `Vec` with given capacity.
184    fn vec_read_null_term_with_capacity(&mut self, capacity: usize) -> io::Result<Vec<u8>> {
185        let mut vec = Vec::with_capacity(capacity);
186        self.read_until(0x00, &mut vec)?;
187        vec.pop();
188
189        Ok(vec)
190    }
191
192    /// Read a null terminated string from the stream and decode it as UTF-8. If it does not
193    /// decode succesfully, a MalformedUtf8 error is returned.
194    fn read_string(&mut self) -> io::Result<String> {
195        String::from_utf8(self.vec_read_null_term()?).map_err(|_| {
196            io::Error::new(
197                io::ErrorKind::InvalidData,
198                "stream did not contain valid UTF-8",
199            )
200        })
201    }
202
203    /// Like `read_string` but initialize the buffer with given capacity.
204    fn read_string_with_capacity(&mut self, capacity: usize) -> io::Result<String> {
205        String::from_utf8(self.vec_read_null_term_with_capacity(capacity)?).map_err(|_| {
206            io::Error::new(
207                io::ErrorKind::InvalidData,
208                "stream did not contain valid UTF-8",
209            )
210        })
211    }
212
213    /// Read a null terminated string from the stream and wrap it in a CString.
214    fn read_c_string(&mut self) -> io::Result<CString> {
215        unsafe { Ok(CString::from_vec_unchecked(self.vec_read_null_term()?)) }
216    }
217
218    /// Read a lua object.
219    fn read_lua_object(&mut self) -> ReplayResult<LuaObject> {
220        let lua_type = self.read_u8()?;
221        self.read_lua_object_as(lua_type)
222    }
223
224    fn read_lua_object_as(&mut self, lua_type: u8) -> ReplayResult<LuaObject> {
225        match lua_type {
226            LUA_FLOAT_MARKER => Ok(LuaObject::Float(self.read_f32_le()?)),
227            LUA_STRING_MARKER => Ok(LuaObject::String(self.read_c_string()?)),
228            LUA_NIL_MARKER => {
229                self.read_u8()?;
230                Ok(LuaObject::Nil)
231            }
232            LUA_BOOL_MARKER => Ok(LuaObject::Bool(self.read_bool()?)),
233            LUA_TABLE_MARKER => {
234                let mut res = OrderedHashMap::new();
235                loop {
236                    match self.read_u8()? {
237                        LUA_END_MARKER => break,
238                        next_type => {
239                            let mut key = self.read_lua_object_as(next_type)?;
240                            if let LuaObject::String(s) = key {
241                                key = LuaObject::from(
242                                    s.into_string().map_err(|e| LuaTypeError::from(e))?,
243                                )
244                            }
245                            let value = self.read_lua_object()?;
246                            res.insert(key, value);
247                        }
248                    }
249                }
250                Ok(LuaObject::Table(res))
251            }
252            _ => Err(LuaTypeError {})?,
253        }
254    }
255}
256
257/// Implement `ReplayBufReadExt` for all types that implement `Read`.
258impl<R: io::BufRead + ReplayReadExt + ?Sized> ReplayBufReadExt for R {}
259
/// Unit tests for the reader extension traits and the error type's `Display` output.
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    use std::io::Cursor;

    #[test]
    fn test_skip_buf() {
        let mut cur = Cursor::new(vec![0u8; 100]);
        let mut buf = [0; 10];

        // Both skips are larger than the buffer, so several reads are needed for each.
        cur.skip_buf(17, &mut buf).unwrap();
        assert_eq!(cur.position(), 17);

        cur.skip_buf(25, &mut buf).unwrap();
        assert_eq!(cur.position(), 17 + 25);
    }

    #[test]
    fn read_empty_string() {
        let mut cur = Cursor::new(b"\x00");
        assert_eq!(cur.read_string().unwrap(), "")
    }

    #[test]
    fn read_string() {
        let data: &[u8] = b"This is a c-style string\x00";
        let mut cur = Cursor::new(data);
        assert_eq!(cur.read_string().unwrap(), "This is a c-style string")
    }

    #[test]
    fn read_copyright_string() {
        // Taken directly from replays using the "Total Mayhem" mod which was an official GPG mod.
        // "Copyright © 2006, Gas Powered Games" but as a single byte encoding
        // Strings like these should not be decoded by the parser, as they would generate an error
        let data: &[u8] = &[
            67, 111, 112, 121, 114, 105, 103, 104, 116, 32, 169, 32, 50, 48, 48, 54, 44, 32, 71,
            97, 115, 32, 80, 111, 119, 101, 114, 101, 100, 32, 71, 97, 109, 101, 115, 0,
        ];
        let mut cur = Cursor::new(data);
        cur.read_string().unwrap_err();
    }

    #[test]
    fn read_i32_le() {
        let mut data: &[u8] = b"\xef\xbe\xad\x7e";
        assert_eq!(data.read_i32_le().unwrap(), 0x7eadbeef)
    }

    #[test]
    fn read_u16_le() {
        let mut data: &[u8] = b"\xad\xde";
        assert_eq!(data.read_u16_le().unwrap(), 0xdead)
    }

    #[test]
    fn read_i16_le() {
        let mut data: &[u8] = b"\xad\xde";
        // 0xdead has its sign bit set, so as an i16 it is 0xdead - 0x10000 = -0x2153.
        assert_eq!(data.read_i16_le().unwrap(), -0x2153)
    }

    #[test]
    fn vec_read_until() {
        let mut data: &[u8] = &[1, 2, 3, 4, 5];
        assert_eq!(data.vec_read_until(3).unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn read_lua_table() {
        // Table marker immediately followed by the end marker: an empty table.
        let data: &[u8] = b"\x04\x05";
        let mut cur = Cursor::new(data);
        let table = cur.read_lua_object().unwrap();
        let hashmap = table.into_hashmap().unwrap();
        assert!(hashmap.is_empty())
    }

    #[test]
    fn test_replay_read_error_display() {
        use std::io::Read;

        assert_eq!(
            format!("{}", ReplayReadError::Desynced(10)),
            "desynced at tick 10"
        );
        assert_eq!(
            format!("{}", ReplayReadError::Malformed("bad something or other")),
            "bad something or other"
        );
        assert_eq!(
            format!(
                "{}",
                <ReplayReadError as From<FromUtf8Error>>::from(
                    String::from_utf8(b"\xFF".to_vec()).unwrap_err()
                )
            ),
            "invalid utf-8 sequence of 1 bytes from index 0"
        );

        // Reading from an empty slice produces a genuine `io::Error` to wrap.
        let mut buf: &[u8] = b"";
        assert_eq!(
            format!(
                "{}",
                ReplayReadError::from(
                    (&mut buf)
                        .read_exact(vec![0; 5].as_mut_slice())
                        .unwrap_err()
                )
            ),
            "failed to fill whole buffer"
        );
    }
}