ebml_iterable/
tools.rs

1//! 
2//! Contains a number of tools that are useful when working with EBML encoded files.
3//! 
4
5use std::convert::TryInto;
6
7use super::errors::tool::ToolError;
8
9///
10/// Trait to enable easy serialization to a vint.
11/// 
12/// This is only available for types that can be cast as `u64`.
13/// 
14pub trait Vint: Into<u64> + Copy {
15    ///
16    /// Returns a representation of the current value as a vint array.
17    /// 
18    /// # Errors
19    ///
20    /// This can return an error if the value is too large to be representable as a vint.
21    /// 
22    fn as_vint(self) -> Result<Vec<u8>, ToolError> {
23        let val: u64 = self.into();
24        check_size_u64(val, 8)?;
25
26        if val < (1 << 7) {
27            Ok(as_vint_no_check_u64::<1>(val).to_vec())
28        } else if val < (1 << (7 * 2)) {
29            Ok(as_vint_no_check_u64::<2>(val).to_vec())
30        } else if val < (1 << (7 * 3)) {
31            Ok(as_vint_no_check_u64::<3>(val).to_vec())
32        } else if val < (1 << (7 * 4)) {
33            Ok(as_vint_no_check_u64::<4>(val).to_vec())
34        } else if val < (1 << (7 * 5)) {
35            Ok(as_vint_no_check_u64::<5>(val).to_vec())
36        } else if val < (1 << (7 * 6)) {
37            Ok(as_vint_no_check_u64::<6>(val).to_vec())
38        } else if val < (1 << (7 * 7)) {
39            Ok(as_vint_no_check_u64::<7>(val).to_vec())
40        } else {
41            Ok(as_vint_no_check_u64::<8>(val).to_vec())
42        }
43    }
44
45    ///
46    /// Returns a representation of the current value as a vint array with a specified length.
47    /// 
48    /// # Errors
49    ///
50    /// This can return an error if the value is too large to be representable as a vint.
51    /// 
52    fn as_vint_with_length<const LENGTH: usize>(&self) -> Result<[u8; LENGTH], ToolError> {
53        let val: u64 = (*self).into();
54        check_size_u64(val, LENGTH)?;
55        Ok(as_vint_no_check_u64::<LENGTH>(val))
56    }
57}
58
// Opt the unsigned primitive integers into `Vint`.  The trait's provided methods supply the
// whole implementation, so the impl bodies are intentionally empty.
impl Vint for u64 { }
impl Vint for u32 { }
impl Vint for u16 { }
impl Vint for u8 { }
63
64#[inline]
65fn check_size_u64(val: u64, max_length: usize) -> Result<(), ToolError> {
66    if val >= 1 << (max_length * 7) {
67        Err(ToolError::WriteVintOverflow(val))
68    } else {
69        Ok(())
70    }
71}
72
/// Encodes `val` as a vint of exactly `LENGTH` bytes without validating that it fits.
///
/// Callers are expected to have verified that `LENGTH` is between 1 and 8 and that `val` needs
/// no more than `7 * LENGTH` bits; otherwise this panics or produces bytes that do not encode `val`.
#[inline]
fn as_vint_no_check_u64<const LENGTH: usize>(val: u64) -> [u8; LENGTH] {
    // Number of leading big-endian bytes to drop.  The same number is the bit position of the
    // length marker inside the first byte that is kept.
    let skipped = 8 - LENGTH;

    let mut wide = val.to_be_bytes();
    wide[skipped] |= 1 << skipped;

    let tail: &[u8] = &wide[skipped..];
    tail.try_into().expect("8 - (8-length) != length !?!?")
}
79
80/// 
81/// Reads a vint from the beginning of the input array slice.
82/// 
83/// This method returns an option with the `None` variant used to indicate there was not enough data in the buffer to completely read a vint.
84/// 
85/// The returned tuple contains the value of the vint (`u64`) and the length of the vint (`usize`).  The length will be less than or equal to the length of the input slice.
86/// 
87/// # Errors
88///
89/// This method can return a `ToolError` if the input array cannot be read as a vint.
90/// 
91pub fn read_vint(buffer: &[u8]) -> Result<Option<(u64, usize)>, ToolError> {
92    if buffer.is_empty() {
93        return Ok(None);
94    }
95
96    if buffer[0] == 0 {
97        return Err(ToolError::ReadVintOverflow)
98    }
99
100    let length = 8 - buffer[0].ilog2() as usize;
101
102    if length > buffer.len() {
103        // Not enough data in the buffer to read out the vint value
104        return Ok(None);
105    }
106
107    let mut value = buffer[0] as u64;
108    value -= 1 << (8 - length);
109
110    for item in buffer.iter().take(length).skip(1) {
111        value <<= 8;
112        value += *item as u64;
113    }
114
115    Ok(Some((value, length)))
116}
117
///
/// Checks whether a number, taken as the big-endian bytes of a vint **including** its length marker, has the shape of a valid vint.
///
/// This is the form in which EBML element ids are usually written (for example `0x1A45DFA3`).  A vint that is `n` bytes wide (1 to 8) has its marker as the highest set bit, at bit position `7 * n`.  The function therefore returns `true` exactly when the highest set bit of `val` sits at position 7, 14, 21, 28, 35, 42, 49 or 56.
///
/// Zero (no marker at all) is never a valid vint.
///
pub fn is_vint(val: u64) -> bool {
    match val.checked_ilog2() {
        // Bit 0 would be a "zero byte" vint and bit 63 a nine byte one - both are multiples
        // of 7 but neither is a legal width, so the range has to be checked as well.
        Some(marker_bit) => marker_bit % 7 == 0 && (7..=56).contains(&marker_bit),
        None => false,
    }
}
125
126///
127/// Trait to enable easy serialization to a signed vint.
128/// 
129/// This is only available for types that can be cast as `i64`.  A signed vint can be written as a variable number of bytes just like a regular vint, but the value portion of the vint is expressed in two's complement notation.
130/// 
131/// For example, the decimal number "-33" would be written as [0xDF = 1101 1111].  This value is determined by first taking the two's complement of 33 [0x21 = 0010 0001] **but only using the bits available for the vint value**.  In this case, that is 7 bits (because the vint marker takes up the 8th bit).  The two's complement is [101 1111]. A handy calculator for two's complement can be found [here](https://www.omnicalculator.com/math/twos-complement).  Once the two's complement has been found, simply prepend the vint marker as usual to get [1101 1111 = 0xDF].
132/// 
133/// Some more examples:
134/// ```
135/// use ebml_iterable::tools::SignedVint;
136///
137/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
138/// assert_eq!(vec![0xDF], (-33i64).as_signed_vint().unwrap());
139/// assert_eq!(vec![0x40, 0xC8], (200i64).as_signed_vint().unwrap());
140/// assert_eq!(vec![0x7F, 0x38], (-200i64).as_signed_vint().unwrap());
141/// assert_eq!(vec![0xFF], (-1i64).as_signed_vint().unwrap());
142/// # Ok(())
143/// # }
144/// ```
145pub trait SignedVint: Into<i64> + Copy {
146    ///
147    /// Returns a representation of the current value as a vint array.
148    /// 
149    /// # Errors
150    ///
151    /// This can return an error if the value is outside of the range that can be represented as a vint.
152    /// 
153    fn as_signed_vint(&self) -> Result<Vec<u8>, ToolError> {
154        let val: i64 = (*self).into();
155        check_size_i64(val, 8)?;
156        let mut length = 1;
157        while length <= 8 {
158            if val >= -(1 << (7 * length - 1)) && val < (1 << (7 * length - 1)) {
159                break;
160            }
161            length += 1;
162        }
163
164        Ok(as_vint_no_check_i64(val, length))
165    }
166
167    ///
168    /// Returns a representation of the current value as a vint array with a specified length.
169    /// 
170    /// # Errors
171    ///
172    /// This can return an error if the value is outside of the range that can be represented as a vint.
173    /// 
174    fn as_signed_vint_with_length(&self, length: usize) -> Result<Vec<u8>, ToolError> {
175        let val: i64 = (*self).into();
176        check_size_i64(val, length)?;
177        Ok(as_vint_no_check_i64(val, length))
178    }
179}
180
// Opt the signed primitive integers into `SignedVint`.  The trait's provided methods supply
// the whole implementation, so the impl bodies are intentionally empty.
impl SignedVint for i64 { }
impl SignedVint for i32 { }
impl SignedVint for i16 { }
impl SignedVint for i8 { }
185
186#[inline]
187fn check_size_i64(val: i64, max_length: usize) -> Result<(), ToolError> {
188    if val <= -(1 << (max_length * 7 - 1)) || val >= (1 << (max_length * 7 - 1)) {
189        Err(ToolError::WriteSignedVintOverflow(val))
190    } else {
191        Ok(())
192    }
193}
194
/// Encodes `val` as a signed vint of exactly `length` bytes without validating that it fits.
///
/// The low `length` bytes of the two's complement representation are kept and the length
/// marker is forced into the first byte.  Callers are expected to have verified that `length`
/// is between 1 and 8 and that `val` is within range for that width.
#[inline]
fn as_vint_no_check_i64(val: i64, length: usize) -> Vec<u8> {
    let mut encoded = val.to_be_bytes()[(8 - length)..].to_vec();

    if val >= 0 {
        // The bits above the payload are all zero, so the marker has to be switched on.
        encoded[0] |= 1 << (8 - length);
    } else {
        // Sign extension already set the marker bit; clear the bits above it.
        encoded[0] &= 0xFF >> (length - 1);
    }

    encoded
}
206
207/// 
208/// Reads a signed vint from the beginning of the input array slice.
209/// 
210/// This method returns an option with the `None` variant used to indicate there was not enough data in the buffer to completely read a vint.
211/// 
212/// The returned tuple contains the value of the vint (`i64`) and the length of the vint (`usize`).  The length will be less than or equal to the length of the input slice.
213/// 
214/// # Errors
215///
216/// This method can return a `ToolError` if the input array cannot be read as a vint.
217/// 
218pub fn read_signed_vint(buffer: &[u8]) -> Result<Option<(i64, usize)>, ToolError> {
219    if buffer.is_empty() {
220        return Ok(None);
221    }
222
223    if buffer[0] == 0 {
224        return Err(ToolError::ReadVintOverflow)
225    }
226
227    let length = 8 - buffer[0].ilog2() as usize;
228
229    if length > buffer.len() {
230        // Not enough data in the buffer to read out the vint value
231        return Ok(None);
232    }
233
234    let is_negative = if length == 8 {
235        buffer[1] & 0x80
236    } else {
237        buffer[0] & (0x80 >> length)
238    } > 0;
239
240    let mut value = if is_negative {
241        (buffer[0] as i64) | (!0i64 << (8 - length))
242    } else {
243        (buffer[0] & (0xFF >> length)) as i64
244    };
245
246    for item in buffer.iter().take(length).skip(1) {
247        value <<= 8;
248        value += *item as i64;
249    }
250
251    Ok(Some((value, length)))
252}
253
254///
255/// Reads a `u64` value from any length array slice.
256/// 
257/// Rather than forcing the input to be a `[u8; 8]` like standard library methods, this can interpret a `u64` from a slice of any length < 8.  Bytes are assumed to be least significant when reading the value - i.e. an array of `[4, 0]` would return a value of `1024`.  
258///
259/// # Errors
260///
261/// This method will return an error if the input slice has a length > 8.
262/// 
263/// ## Example
264/// 
265/// ```
266/// # use ebml_iterable::tools::arr_to_u64;
267/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
268/// let result = arr_to_u64(&[16,0])?;
269/// assert_eq!(result, 4096);
270/// # Ok(())
271/// # }
272/// ```
273/// 
274pub fn arr_to_u64(arr: &[u8]) -> Result<u64, ToolError> {
275    if arr.len() > 8 {
276        return Err(ToolError::ReadU64Overflow(Vec::from(arr)));
277    }
278
279    let mut val = 0u64;
280    for byte in arr {
281        val *= 256;
282        val += *byte as u64;
283    }
284    Ok(val)
285}
286
287///
288/// Reads an `i64` value from any length array slice.
289/// 
290/// Rather than forcing the input to be a `[u8; 8]` like standard library methods, this can interpret an `i64` from a slice of any length < 8.  Bytes are assumed to be least significant when reading the value - i.e. an array of `[4, 0]` would return a value of `1024`.  
291///
292/// # Errors
293///
294/// This method will return an error if the input slice has a length > 8.
295/// 
296/// ## Example
297/// 
298/// ```
299/// # use ebml_iterable::tools::arr_to_i64;
300/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
301/// let result = arr_to_i64(&[4,0])?;
302/// assert_eq!(result, 1024);
303/// # Ok(())
304/// # }
305/// ```
306///
307pub fn arr_to_i64(arr: &[u8]) -> Result<i64, ToolError> {
308    if arr.len() > 8 {
309        return Err(ToolError::ReadI64Overflow(Vec::from(arr)));
310    }
311
312    if arr[0] > 127 {
313        if arr.len() == 8 {
314            Ok(i64::from_be_bytes(arr.try_into().expect("[u8;8] should be convertible to i64")))
315        } else {
316            Ok(-((1 << (arr.len() * 8)) - (arr_to_u64(arr).expect("arr_to_u64 shouldn't error if length is <= 8") as i64)))
317        }
318    } else {
319        Ok(arr_to_u64(arr).expect("arr_to_u64 shouldn't error if length is <= 8") as i64)
320    }
321}
322
323///
324/// Reads an `f64` value from an array slice of length 4 or 8.
325/// 
326/// This method wraps `f32` and `f64` conversions from big endian byte arrays and casts the result as an `f64`.  
327///
328/// # Errors
329///
330/// This method will throw an error if the input slice length is not 4 or 8.
331/// 
332pub fn arr_to_f64(arr: &[u8]) -> Result<f64, ToolError> {
333    if arr.len() == 4 {
334        Ok(f32::from_be_bytes(arr.try_into().expect("arr should be [u8;4]")) as f64)
335    } else if arr.len() == 8 {
336        Ok(f64::from_be_bytes(arr.try_into().expect("arr should be [u8;8]")))
337    } else {
338        Err(ToolError::ReadF64Mismatch(Vec::from(arr)))
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    // --- unsigned vints: fixed examples -------------------------------------------------

    #[test]
    fn read_vint_sixteen() {
        let (value, length) = read_vint(&[144u8]).unwrap().expect("Reading vint failed");

        assert_eq!(value, 16);
        assert_eq!(length, 1);
    }

    #[test]
    fn write_vint_sixteen() {
        let encoded = 16u64.as_vint().expect("Writing vint failed");
        assert_eq!(encoded, [144u8]);
    }

    #[test]
    fn read_vint_one_twenty_seven() {
        let (value, length) = read_vint(&[255u8]).unwrap().expect("Reading vint failed");

        assert_eq!(value, 127);
        assert_eq!(length, 1);
    }

    #[test]
    fn write_vint_one_twenty_seven() {
        let encoded = 127u64.as_vint().expect("Writing vint failed");
        assert_eq!(encoded, [255u8]);
    }

    #[test]
    fn read_vint_two_hundred() {
        let (value, length) = read_vint(&[64u8, 200]).unwrap().expect("Reading vint failed");

        assert_eq!(value, 200);
        assert_eq!(length, 2);
    }

    #[test]
    fn write_vint_two_hundred() {
        let encoded = 200u64.as_vint().expect("Writing vint failed");
        assert_eq!(encoded, [64u8, 200u8]);
    }

    #[test]
    fn read_vint_for_ebml_tag() {
        let tag_bytes = [0x1au8, 0x45, 0xdf, 0xa3];
        let (value, length) = read_vint(&tag_bytes).unwrap().expect("Reading vint failed");

        assert_eq!(value, 0x0a45dfa3);
        assert_eq!(length, 4);
    }

    #[test]
    fn read_vint_very_long() {
        let bytes = [1u8, 0, 0, 0, 0, 0, 0, 1];
        let (value, length) = read_vint(&bytes).unwrap().expect("Reading vint failed");

        assert_eq!(value, 1);
        assert_eq!(length, 8);
    }

    #[test]
    fn write_vint_very_long() {
        let encoded = 1u64.as_vint_with_length::<8>().expect("Writing vint failed");
        assert_eq!(encoded, [1u8, 0, 0, 0, 0, 0, 0, 1]);
    }

    #[test]
    fn read_vint_overflow() {
        // The marker announces 8 bytes but only 4 are available.
        let outcome = read_vint(&[1u8, 0, 0, 0]).expect("Reading vint failed");

        assert!(outcome.is_none());
    }

    #[test]
    #[should_panic]
    fn too_big_for_vint() {
        (1u64 << 56).as_vint().expect("Writing vint failed");
    }

    // --- round trips ---------------------------------------------------------------------

    #[test]
    fn vint_encode_decode_range() {
        for val in 0..500_000u64 {
            let encoded = val.as_vint().unwrap();
            let (decoded, _) = read_vint(&encoded).unwrap().unwrap();
            assert_eq!(val, decoded);
        }
    }

    #[test]
    fn signed_vint_encode_decode_range() {
        for val in -500_000i64..500_000 {
            let encoded = val.as_signed_vint().unwrap();
            let (decoded, _) = read_signed_vint(&encoded).unwrap().unwrap();
            assert_eq!(val, decoded);
        }
    }

    // --- fixed width integer readers -----------------------------------------------------

    #[test]
    fn read_u64_values() {
        let mut expected = 0u64;
        for len in 1..=8usize {
            expected = (expected << 8) + 0x25;
            let buffer = vec![0x25u8; len];

            assert_eq!(expected, arr_to_u64(&buffer).unwrap());
        }
    }

    #[test]
    fn read_i64_values() {
        let mut expected = 0i64;
        for len in 1..=8usize {
            expected = (expected << 8) + 0x0a;

            let buffer = vec![0x0au8; len];
            assert_eq!(expected, arr_to_i64(&buffer).unwrap());

            // Flipping every bit gives the one's complement, so adding one negates the value.
            let inverted: Vec<u8> = buffer.iter().map(|byte| !byte).collect();
            assert_eq!(-expected, arr_to_i64(&inverted).unwrap() + 1);
        }
    }

    // --- vint shape detection ------------------------------------------------------------

    #[test]
    fn valid_vints() {
        const VALID: [u64; 35] = [
            0x1F43B675, 0xA0, 0xA1, 0x75A1, 0xA6, 0xEE, 0xA5, 0x9B, 0xA2, 0xA4, 0x75A2, 0xFB,
            0xC8, 0xC9, 0xCA, 0xFA, 0xFD, 0x8E, 0xE8, 0xCB, 0xCE, 0xCD, 0xCC, 0xCF, 0xAF, 0xA7,
            0xAB, 0x5854, 0x58D7, 0xA3, 0xE7, 0x3E83BB, 0x3EB923, 0x3C83AB, 0x3CB923,
        ];
        const INVALID: [u64; 5] = [1234, 0x11, 0x7a, 0xfa4c, 0x1a5d];

        for &id in VALID.iter() {
            assert!(is_vint(id), "{:#x} should be a valid vint", id);
        }

        for &id in INVALID.iter() {
            assert!(!is_vint(id), "{:#x} should not be a valid vint", id);
        }
    }
}