1use byteorder::{LittleEndian, ReadBytesExt};
2use ordered_hash_map::OrderedHashMap;
3
4use super::lua::*;
5
6use std::error::Error;
7use std::ffi::CString;
8use std::fmt;
9use std::io;
10use std::string::FromUtf8Error;
11
/// Everything that can go wrong while reading a replay.
#[derive(Debug)]
pub enum ReplayReadError {
    /// A desync was detected; the payload is the tick reported in the message
    /// ("desynced at tick N").
    Desynced(u32),
    /// The data violated the expected format; the payload is a static
    /// description of the problem.
    Malformed(&'static str),
    /// Bytes that should have formed a string were not valid UTF-8.
    MalformedUtf8(FromUtf8Error),
    /// The underlying reader reported an I/O failure.
    IO(io::Error),
}

/// Shorthand for results whose error type is [`ReplayReadError`].
pub type ReplayResult<T> = Result<T, ReplayReadError>;
20
21impl Error for ReplayReadError {
22 fn source(&self) -> Option<&(dyn Error + 'static)> {
23 match self {
24 Self::IO(e) => Some(e),
25 Self::MalformedUtf8(e) => Some(e),
26 _ => None,
27 }
28 }
29}
30
31impl fmt::Display for ReplayReadError {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 use ReplayReadError::*;
34
35 match self {
36 Desynced(tick) => write!(f, "desynced at tick {}", tick),
37 Malformed(msg) => write!(f, "{}", msg),
38 MalformedUtf8(e) => write!(f, "{}", e),
39 IO(e) => write!(f, "{}", e),
40 }
41 }
42}
43
44impl From<FromUtf8Error> for ReplayReadError {
45 fn from(err: FromUtf8Error) -> ReplayReadError {
46 ReplayReadError::MalformedUtf8(err)
47 }
48}
49
50impl From<std::io::Error> for ReplayReadError {
51 fn from(err: std::io::Error) -> ReplayReadError {
52 ReplayReadError::IO(err)
53 }
54}
55
56impl From<LuaTypeError> for ReplayReadError {
57 fn from(_err: LuaTypeError) -> ReplayReadError {
58 ReplayReadError::Malformed("invalud lua type")
59 }
60}
61
62pub trait ReplayReadExt: io::Read {
64 fn vec_read_to_end(&mut self) -> io::Result<Vec<u8>> {
66 let mut res = Vec::new();
67 self.read_to_end(&mut res)?;
68 Ok(res)
69 }
70
71 fn vec_read_exact(&mut self, n: usize) -> io::Result<Vec<u8>> {
73 let capacity = std::cmp::min(n, crate::MIB);
75 let mut vec = vec![0; capacity];
78 self.read_exact(&mut vec)?;
79
80 Ok(vec)
81 }
82
83 fn read_exact_to_vec(&mut self, n: usize, buf: &mut Vec<u8>) -> io::Result<()> {
86 let mut remaining = n;
87 while remaining > 0 {
89 let chunk = std::cmp::min(remaining, crate::MIB);
90 let index = n - remaining;
91
92 buf.resize(chunk, 0);
93 self.read_exact(&mut buf[index..index + chunk])?;
94 remaining -= chunk;
95 }
96 Ok(())
97 }
98
99 fn skip(&mut self, n: usize) -> io::Result<()> {
101 self.vec_read_exact(n)?;
102 Ok(())
103 }
104
105 fn skip_buf(&mut self, n: usize, buf: &mut [u8]) -> io::Result<()> {
111 let mut remaining = n;
112 while remaining > 0 {
113 let to_read = std::cmp::min(remaining, buf.len());
114 unsafe {
115 match self.read(buf.get_unchecked_mut(..to_read)) {
116 Ok(0) => {
117 return Err(io::Error::new(
118 io::ErrorKind::UnexpectedEof,
119 "failed to skip desired bytes",
120 ));
121 }
122 Ok(n) => remaining -= n,
123 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
124 Err(e) => return Err(e),
125 }
126 }
127 }
128 Ok(())
129 }
130
131 fn read_i32_le(&mut self) -> io::Result<i32> {
132 self.read_i32::<LittleEndian>()
133 }
134
135 fn read_u32_le(&mut self) -> io::Result<u32> {
136 self.read_u32::<LittleEndian>()
137 }
138
139 fn read_u16_le(&mut self) -> io::Result<u16> {
140 self.read_u16::<LittleEndian>()
141 }
142
143 fn read_i16_le(&mut self) -> io::Result<i16> {
144 self.read_i16::<LittleEndian>()
145 }
146
147 fn read_f32_le(&mut self) -> io::Result<f32> {
148 self.read_f32::<LittleEndian>()
149 }
150
151 fn read_u8(&mut self) -> io::Result<u8> {
153 ReadBytesExt::read_u8(self)
154 }
155
156 fn read_bool(&mut self) -> io::Result<bool> {
157 Ok(self.read_u8()? == 1)
158 }
159}
160
161impl<R: io::Read + ?Sized> ReplayReadExt for R {}
163
164pub trait ReplayBufReadExt: io::BufRead + ReplayReadExt {
166 fn vec_read_until(&mut self, byte: u8) -> io::Result<Vec<u8>> {
168 let mut vec = Vec::new();
169 self.read_until(byte, &mut vec)?;
170
171 Ok(vec)
172 }
173
174 fn vec_read_null_term(&mut self) -> io::Result<Vec<u8>> {
176 let mut vec = Vec::new();
177 self.read_until(0x00, &mut vec)?;
178 vec.pop();
179
180 Ok(vec)
181 }
182
183 fn vec_read_null_term_with_capacity(&mut self, capacity: usize) -> io::Result<Vec<u8>> {
185 let mut vec = Vec::with_capacity(capacity);
186 self.read_until(0x00, &mut vec)?;
187 vec.pop();
188
189 Ok(vec)
190 }
191
192 fn read_string(&mut self) -> io::Result<String> {
195 String::from_utf8(self.vec_read_null_term()?).map_err(|_| {
196 io::Error::new(
197 io::ErrorKind::InvalidData,
198 "stream did not contain valid UTF-8",
199 )
200 })
201 }
202
203 fn read_string_with_capacity(&mut self, capacity: usize) -> io::Result<String> {
205 String::from_utf8(self.vec_read_null_term_with_capacity(capacity)?).map_err(|_| {
206 io::Error::new(
207 io::ErrorKind::InvalidData,
208 "stream did not contain valid UTF-8",
209 )
210 })
211 }
212
213 fn read_c_string(&mut self) -> io::Result<CString> {
215 unsafe { Ok(CString::from_vec_unchecked(self.vec_read_null_term()?)) }
216 }
217
218 fn read_lua_object(&mut self) -> ReplayResult<LuaObject> {
220 let lua_type = self.read_u8()?;
221 self.read_lua_object_as(lua_type)
222 }
223
224 fn read_lua_object_as(&mut self, lua_type: u8) -> ReplayResult<LuaObject> {
225 match lua_type {
226 LUA_FLOAT_MARKER => Ok(LuaObject::Float(self.read_f32_le()?)),
227 LUA_STRING_MARKER => Ok(LuaObject::String(self.read_c_string()?)),
228 LUA_NIL_MARKER => {
229 self.read_u8()?;
230 Ok(LuaObject::Nil)
231 }
232 LUA_BOOL_MARKER => Ok(LuaObject::Bool(self.read_bool()?)),
233 LUA_TABLE_MARKER => {
234 let mut res = OrderedHashMap::new();
235 loop {
236 match self.read_u8()? {
237 LUA_END_MARKER => break,
238 next_type => {
239 let mut key = self.read_lua_object_as(next_type)?;
240 if let LuaObject::String(s) = key {
241 key = LuaObject::from(
242 s.into_string().map_err(|e| LuaTypeError::from(e))?,
243 )
244 }
245 let value = self.read_lua_object()?;
246 res.insert(key, value);
247 }
248 }
249 }
250 Ok(LuaObject::Table(res))
251 }
252 _ => Err(LuaTypeError {})?,
253 }
254 }
255}
256
257impl<R: io::BufRead + ReplayReadExt + ?Sized> ReplayBufReadExt for R {}
259
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    use std::io::Cursor;

    /// Skipping more bytes than the scratch buffer holds advances the cursor
    /// by exactly the requested amount.
    #[test]
    fn test_skip_buf() {
        let mut cur = Cursor::new(vec![0u8; 100]);
        let mut buf = [0; 10];

        cur.skip_buf(17, &mut buf).unwrap();
        assert_eq!(cur.position(), 17);

        cur.skip_buf(25, &mut buf).unwrap();
        assert_eq!(cur.position(), 17 + 25);
    }

    #[test]
    fn read_empty_string() {
        let mut cur = Cursor::new(b"\x00");
        assert_eq!(cur.read_string().unwrap(), "")
    }

    #[test]
    fn read_string() {
        let data: &[u8] = b"This is a c-style string\x00";
        let mut cur = Cursor::new(data);
        assert_eq!(cur.read_string().unwrap(), "This is a c-style string")
    }

    /// Byte 169 on its own is not valid UTF-8, so `read_string` must fail.
    #[test]
    fn read_copyright_string() {
        let data: &[u8] = &[
            67, 111, 112, 121, 114, 105, 103, 104, 116, 32, 169, 32, 50, 48, 48, 54, 44, 32, 71,
            97, 115, 32, 80, 111, 119, 101, 114, 101, 100, 32, 71, 97, 109, 101, 115, 0,
        ];
        let mut cur = Cursor::new(data);
        cur.read_string().unwrap_err();
    }

    #[test]
    fn read_i32_le() {
        let mut data: &[u8] = b"\xef\xbe\xad\x7e";
        assert_eq!(data.read_i32_le().unwrap(), 0x7eadbeef)
    }

    /// Exercises the signed reader: the same two bytes that decode to 0xdead
    /// as `u16` decode to a negative `i16`.
    #[test]
    fn read_i16_le() {
        let mut data: &[u8] = b"\xad\xde";
        assert_eq!(
            data.read_i16_le().unwrap(),
            i16::from_le_bytes([0xad, 0xde])
        )
    }

    #[test]
    fn read_u16_le() {
        let mut data: &[u8] = b"\xad\xde";
        assert_eq!(data.read_u16_le().unwrap(), 0xdead)
    }

    #[test]
    fn vec_read_until() {
        let mut data: &[u8] = &[1, 2, 3, 4, 5];
        assert_eq!(data.vec_read_until(3).unwrap(), vec![1, 2, 3]);
    }

    /// A table marker immediately followed by an end marker is an empty table.
    #[test]
    fn read_lua_table() {
        let data: &[u8] = b"\x04\x05";
        let mut cur = Cursor::new(data);
        let table = cur.read_lua_object().unwrap();
        let hashmap = table.into_hashmap().unwrap();
        assert!(hashmap.is_empty())
    }

    #[test]
    fn test_replay_read_error_display() {
        use std::io::Read;

        assert_eq!(
            format!("{}", ReplayReadError::Desynced(10)),
            "desynced at tick 10"
        );
        assert_eq!(
            format!("{}", ReplayReadError::Malformed("bad something or other")),
            "bad something or other"
        );
        assert_eq!(
            format!(
                "{}",
                <ReplayReadError as From<FromUtf8Error>>::from(
                    String::from_utf8(b"\xFF".to_vec()).unwrap_err()
                )
            ),
            "invalid utf-8 sequence of 1 bytes from index 0"
        );

        // Reading from an empty slice fails, giving a real `io::Error` to wrap.
        let mut buf: &[u8] = b"";
        assert_eq!(
            format!(
                "{}",
                ReplayReadError::from(
                    (&mut buf)
                        .read_exact(vec![0; 5].as_mut_slice())
                        .unwrap_err()
                )
            ),
            "failed to fill whole buffer"
        );
    }
}