orc_format/read/decompress/
mod.rs1use std::io::Read;
3
4use fallible_streaming_iterator::FallibleStreamingIterator;
5
6use crate::error::Error;
7use crate::proto::CompressionKind;
8
/// Decodes the 3-byte header that precedes a chunk.
///
/// The three bytes are read as a little-endian 24-bit integer. Its lowest bit
/// is the "original" flag and the remaining 23 bits are the length.
///
/// Returns `(is_original, length)`.
///
/// # Panics
///
/// Panics if `bytes` holds fewer than three bytes.
fn decode_header(bytes: &[u8]) -> (bool, usize) {
    // Only the first three bytes belong to the header; anything after is ignored.
    let head = &bytes[..3];
    let raw = u32::from_le_bytes([head[0], head[1], head[2], 0]);

    let is_original = raw & 1 == 1;
    let length = (raw >> 1) as usize;

    (is_original, length)
}
18
/// The chunk currently exposed by [`DecompressorIter`].
enum State<'a> {
    /// Bytes borrowed directly from the input stream.
    Original(&'a [u8]),
    /// Owned buffer holding the bytes produced by decompressing a chunk.
    Compressed(Vec<u8>),
}
23
24struct DecompressorIter<'a> {
25 stream: &'a [u8],
26 current: Option<State<'a>>, compression: CompressionKind,
28 scratch: Vec<u8>,
29}
30
31impl<'a> DecompressorIter<'a> {
32 pub fn new(stream: &'a [u8], compression: CompressionKind, scratch: Vec<u8>) -> Self {
33 Self {
34 stream,
35 current: None,
36 compression,
37 scratch,
38 }
39 }
40
41 pub fn into_inner(self) -> Vec<u8> {
42 match self.current {
43 Some(State::Compressed(some)) => some,
44 _ => self.scratch,
45 }
46 }
47}
48
49impl<'a> FallibleStreamingIterator for DecompressorIter<'a> {
50 type Item = [u8];
51
52 type Error = Error;
53
54 #[inline]
55 fn advance(&mut self) -> Result<(), Self::Error> {
56 if self.stream.is_empty() {
57 self.current = None;
58 return Ok(());
59 }
60 match self.compression {
61 CompressionKind::None => {
62 self.current = Some(State::Original(self.stream));
64 self.stream = &[];
65 }
66 CompressionKind::Zlib => {
67 let (is_original, length) = decode_header(self.stream);
69 self.stream = &self.stream[3..];
70 let (maybe_compressed, remaining) = self.stream.split_at(length);
71 self.stream = remaining;
72 if is_original {
73 self.current = Some(State::Original(maybe_compressed));
74 } else {
75 let mut gz = flate2::read::DeflateDecoder::new(maybe_compressed);
76 self.scratch.clear();
77 gz.read_to_end(&mut self.scratch)?;
78 self.current = Some(State::Compressed(std::mem::take(&mut self.scratch)));
79 }
80 }
81 _ => todo!(),
82 };
83 Ok(())
84 }
85
86 #[inline]
87 fn get(&self) -> Option<&Self::Item> {
88 self.current.as_ref().map(|x| match x {
89 State::Original(x) => *x,
90 State::Compressed(x) => x.as_ref(),
91 })
92 }
93}
94
95pub struct Decompressor<'a> {
97 decompressor: DecompressorIter<'a>,
98 offset: usize,
99 is_first: bool,
100}
101
102impl<'a> Decompressor<'a> {
103 pub fn new(stream: &'a [u8], compression: CompressionKind, scratch: Vec<u8>) -> Self {
105 Self {
106 decompressor: DecompressorIter::new(stream, compression, scratch),
107 offset: 0,
108 is_first: true,
109 }
110 }
111
112 pub fn into_inner(self) -> Vec<u8> {
114 self.decompressor.into_inner()
115 }
116}
117
118impl<'a> std::io::Read for Decompressor<'a> {
119 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
120 if self.is_first {
121 self.is_first = false;
122 self.decompressor.advance().unwrap();
123 }
124 let current = self.decompressor.get();
125 let current = if let Some(current) = current {
126 if current.len() == self.offset {
127 self.decompressor.advance().unwrap();
128 self.offset = 0;
129 let current = self.decompressor.get();
130 if let Some(current) = current {
131 current
132 } else {
133 return Ok(0);
134 }
135 } else {
136 ¤t[self.offset..]
137 }
138 } else {
139 return Ok(0);
140 };
141
142 if current.len() >= buf.len() {
143 buf.copy_from_slice(¤t[..buf.len()]);
144 self.offset += buf.len();
145 Ok(buf.len())
146 } else {
147 buf[..current.len()].copy_from_slice(current);
148 self.offset += current.len();
149 Ok(current.len())
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Low bit set: the chunk is flagged original; the remaining bits give length 5.
    #[test]
    fn decode_uncompressed() {
        let header: [u8; 4] = [0b1011, 0, 0, 0];

        assert_eq!(decode_header(&header), (true, 5));
    }

    /// Low bit clear: the chunk is not original; the remaining bits give length 100 000.
    #[test]
    fn decode_compressed() {
        let header: [u8; 4] = [0b01000000, 0b00001101, 0b00000011, 0];

        assert_eq!(decode_header(&header), (false, 100_000));
    }
}