// mkv_element/base.rs

1use crate::error::Error;
2use crate::functional::*;
3use crate::io::ReadExt;
4use crate::io::ReadFrom;
5use std::fmt::Display;
6use std::io::Read;
7use std::ops::Deref;
8
/// A variable-length integer (VINT) as defined by RFC 8794 (EBML).
///
/// The wrapped `u64` is the *data* value with the VINT marker already stripped;
/// `VInt64::as_encoded` / `VInt64::from_encoded` convert to and from the
/// marker-included, right-aligned form.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct VInt64(pub u64);
12
13impl Display for VInt64 {
14    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15        // write!(f, "{}", self.as_encoded())
16        let encoded = self.as_encoded();
17        if encoded <= 0xFF {
18            write!(f, "0x{:02X}", encoded)
19        } else if encoded <= 0xFFFF {
20            write!(f, "0x{:04X}", encoded)
21        } else if encoded <= 0xFFFFFF {
22            write!(f, "0x{:06X}", encoded)
23        } else if encoded <= 0xFFFFFFFF {
24            write!(f, "0x{:08X}", encoded)
25        } else if encoded <= 0xFFFFFFFFFF {
26            write!(f, "0x{:010X}", encoded)
27        } else if encoded <= 0xFFFFFFFFFFFF {
28            write!(f, "0x{:012X}", encoded)
29        } else if encoded <= 0xFFFFFFFFFFFFFF {
30            write!(f, "0x{:014X}", encoded)
31        } else {
32            write!(f, "0x{:016X}", encoded)
33        }
34    }
35}
36
37impl Deref for VInt64 {
38    type Target = u64;
39    fn deref(&self) -> &Self::Target {
40        &self.0
41    }
42}
43impl VInt64 {
44    /// Create a VInt64 from an already encoded u64 value.
45    pub const fn from_encoded(enc: u64) -> Self {
46        Self(enc & (u64::MAX >> (enc.leading_zeros() + 1)))
47    }
48
49    /// Create a VInt64 from an already encoded u64 value.
50    pub fn as_encoded(&self) -> u64 {
51        let size = VInt64::encode_size(self.0);
52        let mut sbuf = [0u8; 8];
53        let slice = &mut sbuf[8 - size..];
54        slice.copy_from_slice(&self.0.to_be_bytes()[8 - size..]);
55        slice[0] |= 1u8 << (8 - size);
56        u64::from_be_bytes(sbuf)
57    }
58
59    /// Get the size in bytes of the encoded representation of a u64 value.
60    pub const fn encode_size(value: u64) -> usize {
61        let leading_zeros = value.leading_zeros() as usize;
62        let total_bits = 64 - leading_zeros;
63        if total_bits == 0 {
64            1
65        } else {
66            (total_bits + 6).div_euclid(7)
67        }
68    }
69}
70
71impl ReadFrom for VInt64 {
72    fn read_from<R: std::io::Read>(r: &mut R) -> crate::Result<Self> {
73        let first_byte = r.read_u8()?;
74        let leading_zeros = first_byte.leading_zeros() as usize;
75        if leading_zeros >= 8 {
76            return Err(crate::error::Error::InvalidVInt);
77        }
78
79        if leading_zeros == 0 {
80            Ok(VInt64((first_byte & 0b0111_1111) as u64))
81        } else {
82            let mut buf = [0u8; 8];
83            let read_buf = &mut buf[8 - leading_zeros..];
84            r.read_exact(read_buf)?;
85            if leading_zeros != 7 {
86                buf[8 - leading_zeros - 1] = first_byte & (0xFF >> (leading_zeros + 1));
87            }
88            Ok(VInt64(u64::from_be_bytes(buf)))
89        }
90    }
91}
92
93impl Decode for VInt64 {
94    fn decode(buf: &mut &[u8]) -> crate::Result<Self> {
95        if !buf.has_remaining() {
96            return Err(Error::OutOfBounds);
97        }
98        let first_byte = u8::decode(buf)?;
99        if first_byte == 0 {
100            return Err(Error::InvalidVInt);
101        }
102        let leading_zeros = first_byte.leading_zeros() as usize;
103
104        if leading_zeros == 0 {
105            Ok(VInt64((first_byte & 0b0111_1111) as u64))
106        } else {
107            if buf.remaining() < leading_zeros {
108                return Err(Error::OutOfBounds);
109            }
110            let mut bytes = [0u8; 8];
111            let read_buf = &mut bytes[8 - leading_zeros..];
112            read_buf.copy_from_slice(buf.slice(leading_zeros));
113
114            if leading_zeros != 7 {
115                bytes[8 - leading_zeros - 1] = first_byte & (0xFF >> (leading_zeros + 1));
116            }
117            buf.advance(leading_zeros);
118            Ok(VInt64(u64::from_be_bytes(bytes)))
119        }
120    }
121}
122
123impl Encode for VInt64 {
124    fn encode<B: BufMut>(&self, buf: &mut B) -> crate::Result<()> {
125        let size = VInt64::encode_size(self.0);
126        let mut sbuf = [0u8; 8];
127        let slice = &mut sbuf[8 - size..];
128        slice.copy_from_slice(&self.0.to_be_bytes()[8 - size..]);
129        slice[0] |= 1u8 << (8 - size);
130        buf.append_slice(slice);
131        Ok(())
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::functional::{Decode, Encode};

    /// `(encoded bytes, decoded value)` pairs shared by every test below,
    /// spanning 1-, 2-, 3- and 8-byte encodings.
    fn vectors() -> Vec<(Vec<u8>, u64)> {
        vec![
            (vec![0b1000_0000], 0),
            (vec![0b1000_0001], 1),
            (vec![0b1111_1111], 0b0111_1111),
            (vec![0b0100_0000, 0xFF], 0xFF),
            (vec![0b0100_0001, 0xFF], 0b1_1111_1111),
            (vec![0b0111_1111, 0xFF], 0b11_1111_1111_1111),
            (vec![0b0010_0000, 0b0111_1111, 0xFF], 0b111_1111_1111_1111),
            (vec![0b0010_0000, 0xFF, 0xFF], 0xFFFF),
            (vec![0b0011_1111, 0xFF, 0xFF], 0b1_1111_1111_1111_1111_1111),
            (
                vec![1, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF],
                0xFF_FFFF_FFFF_FFFF,
            ),
        ]
    }

    /// Interprets `encoded` (at most 8 bytes) as the low bytes of a
    /// big-endian `u64`.
    fn right_aligned(encoded: &[u8]) -> u64 {
        let mut word = [0u8; 8];
        word[8 - encoded.len()..].copy_from_slice(encoded);
        u64::from_be_bytes(word)
    }

    #[test]
    fn test_encode_size() {
        for (encoded, val) in vectors() {
            assert_eq!(VInt64::encode_size(val), encoded.len());
        }
    }

    #[test]
    fn test_encode() {
        for (encoded, val) in vectors() {
            let vint = VInt64(val);

            let mut written: Vec<u8> = Vec::new();
            vint.encode(&mut written).unwrap();
            assert_eq!(written, encoded);

            assert_eq!(vint.as_encoded(), right_aligned(&encoded));
        }
    }

    #[test]
    fn test_decode() {
        for (encoded, val) in vectors() {
            // ReadFrom over an io::Read source.
            let mut cursor = std::io::Cursor::new(encoded.as_slice());
            assert_eq!(*VInt64::read_from(&mut cursor).unwrap(), val);

            // Decode over an in-memory slice.
            let mut remaining: &[u8] = &encoded;
            assert_eq!(*VInt64::decode(&mut remaining).unwrap(), val);

            // from_encoded on the right-aligned integer form.
            assert_eq!(*VInt64::from_encoded(right_aligned(&encoded)), val);
        }
    }
}
232
233/// EBML element header, consisting of an ID and a size.
234#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
235pub struct Header {
236    /// EBML ID of the element.
237    pub id: VInt64,
238    /// Size of the element's data, excluding the header itself.
239    pub size: VInt64,
240}
241
242impl Header {
243    pub(crate) fn read_body<R: Read>(&self, r: &mut R) -> crate::Result<Vec<u8>> {
244        // we allocate 4096 bytes upfront and grow as needed
245        let size = self.size.0;
246        let cap = size.min(4096) as usize;
247        let mut buf = Vec::with_capacity(cap);
248        let n = std::io::copy(&mut r.take(size), &mut buf)?;
249        if size != n {
250            return Err(Error::OutOfBounds);
251        }
252        Ok(buf)
253    }
254}
255
256impl ReadFrom for Header {
257    fn read_from<R: std::io::Read>(reader: &mut R) -> crate::Result<Self> {
258        let id = VInt64::read_from(reader)?;
259        let size = VInt64::read_from(reader)?;
260        Ok(Self { id, size })
261    }
262}
263
264impl Decode for Header {
265    fn decode(buf: &mut &[u8]) -> crate::Result<Self> {
266        let id = VInt64::decode(buf)?;
267        let size = VInt64::decode(buf)?;
268        Ok(Self { id, size })
269    }
270}
271
272impl Encode for Header {
273    fn encode<B: BufMut>(&self, buf: &mut B) -> crate::Result<()> {
274        self.id.encode(buf)?;
275        self.size.encode(buf)?;
276        Ok(())
277    }
278}