1use std::{
2 cmp::Ordering,
3 io::{self, Read},
4};
5
6use bitstream_io::BitRead2;
7
8#[derive(Debug, thiserror::Error)]
9pub enum Error {
10 #[error(transparent)]
11 Io(#[from] io::Error),
12 #[error("Invalid LZW code encountered")]
13 InvalidCode,
14 #[error("Too many LZW codes encountered")]
15 TooManyCodes,
16}
17
18impl From<Error> for io::Error {
19 fn from(val: Error) -> Self {
20 match val {
21 Error::Io(error) => error,
22 Error::InvalidCode => io::Error::other("Invalid LZW code encountered"),
23 Error::TooManyCodes => io::Error::other("Too many LZW codes encountered"),
24 }
25 }
26}
27
/// LZW string table with room for `N` codes.
///
/// Each code is stored as a link to the code of its prefix string plus the final byte, so the
/// string for a code is recovered by following `parents` until `usize::MAX`.
pub struct LzwTree<const N: usize> {
    /// Prefix code of each entry; `usize::MAX` marks a root (single-byte literal).
    parents: [usize; N],
    /// Last byte of each entry's string.
    values: [u8; N],

    /// Next code to be assigned; codes at or above this are not yet defined.
    symbol_count: usize,
    /// Width in bits of the codes currently being read.
    symbol_size: u32,
    /// Most recently decoded code, or `usize::MAX` if none since the last reset.
    previous_symbol: usize,

    /// Scratch storage for the expansion of the latest code.
    buffer: Vec<u8>,
}
37
38impl<const N: usize> Default for LzwTree<N> {
39 fn default() -> Self {
40 if N < 256 {
41 panic!("Invalid LZWTree configuration");
42 }
43
44 let mut symbols = [0u8; N];
45 symbols
46 .iter_mut()
47 .take(256)
48 .enumerate()
49 .for_each(|(i, symbol)| *symbol = i as u8);
50
51 Self {
52 symbol_count: 256 + 1,
53 symbol_size: 9,
54 previous_symbol: usize::MAX,
55 parents: [usize::MAX; N],
56 buffer: Vec::with_capacity(1024),
57 values: symbols,
58 }
59 }
60}
61
62impl<const N: usize> LzwTree<N> {
63 fn reset(&mut self) {
64 self.symbol_count = 256 + 1;
65 self.previous_symbol = usize::MAX;
66 self.symbol_size = 9;
67 }
68
69 fn advance(&mut self, symbol: usize) -> Result<&[u8], Error> {
70 if self.previous_symbol == usize::MAX {
71 if symbol >= self.symbol_count {
72 return Err(Error::InvalidCode);
73 }
74
75 self.previous_symbol = symbol;
76 } else {
77 let value = match symbol.cmp(&self.symbol_count) {
78 Ordering::Less => self.find_first_byte(symbol),
79 Ordering::Equal => self.find_first_byte(self.previous_symbol),
80 Ordering::Greater => return Err(Error::InvalidCode),
81 };
82
83 let parent = self.previous_symbol;
84 self.previous_symbol = symbol;
85
86 if !self.full() {
87 self.parents[self.symbol_count] = parent;
88 self.values[self.symbol_count] = value;
89 self.symbol_count += 1;
90
91 if !self.full() && (self.symbol_count & (self.symbol_count - 1)) == 0 {
92 self.symbol_size += 1;
93 }
94 } else {
95 log::warn!("Ignore overflowing code table, hopefully the block ends soon…");
96 }
97 }
98
99 let n = self.output_len();
100 if n > self.buffer.len() {
101 self.buffer = vec![0u8; n];
102 }
103
104 let mut i = n;
105 let mut symbol = self.previous_symbol;
106 loop {
107 match symbol {
108 usize::MAX => return Ok(&self.buffer[0..n]),
109 _ => {
110 self.buffer[i - 1] = self.values[symbol];
111 symbol = self.parents[symbol];
112 i -= 1;
113 }
114 }
115 }
116 }
117
118 fn find_first_byte(&mut self, mut symbol: usize) -> u8 {
119 assert_ne!(symbol, usize::MAX);
120 loop {
121 match self.parents[symbol] {
122 usize::MAX => return self.values[symbol],
123 _ => symbol = self.parents[symbol],
124 }
125 }
126 }
127
128 fn full(&self) -> bool {
129 self.symbol_count == N
130 }
131
132 fn output_len(&self) -> usize {
133 let mut n = 0;
134 let mut symbol = self.previous_symbol;
135 loop {
136 match symbol {
137 usize::MAX => return n,
138 _ => {
139 n += 1;
140 symbol = self.parents[symbol]
141 }
142 }
143 }
144 }
145}
146
147pub struct LzwReader<R: io::Read> {
148 initialized: bool,
149 inner: bitstream_io::BitReader<R, bitstream_io::LittleEndian>,
150 tree: LzwTree<0x4000>,
151 symbol_counter: u32,
152 buffer: Vec<u8>,
153 buffer_pos: usize,
154
155 position: u64,
156 uncompressed_size: u64,
157}
158
159impl<R: io::Read> LzwReader<R> {
160 pub fn new(inner: R, uncompressed_size: u64) -> Self {
161 Self {
162 initialized: false,
163 inner: bitstream_io::BitReader::<_, bitstream_io::LittleEndian>::new(inner),
164 tree: Default::default(),
165 symbol_counter: 0,
166 buffer: Vec::new(),
167 buffer_pos: 0,
168
169 position: 0,
170 uncompressed_size,
171 }
172 }
173
174 pub fn into_inner(self) -> R {
175 self.inner.into_reader()
176 }
177
178 fn decode_chunk(&mut self) -> Result<&[u8], Error> {
179 loop {
180 self.symbol_counter += 1;
181 match self.inner.read(self.tree.symbol_size)? {
182 256u16 => {
183 log::info!("End of block found");
184 if !self.symbol_counter.is_multiple_of(8) {
185 self.inner
186 .skip(self.tree.symbol_size * (8 - (self.symbol_counter % 8)))?;
187 }
188 self.tree.reset();
189 self.symbol_counter = 0;
190 }
191 symbol => return self.tree.advance(symbol as usize),
192 }
193 }
194 }
195}
196
197impl<R: io::Read> io::Read for LzwReader<R> {
198 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
199 if self.position >= self.uncompressed_size {
200 return Ok(0);
201 }
202
203 if !self.initialized {
204 self.initialized = true;
205 self.buffer = self.decode_chunk()?.to_vec();
206 self.buffer_pos = 0;
207 }
208
209 for (idx, byte) in buf.iter_mut().enumerate() {
210 if self.buffer_pos < self.buffer.len() {
212 *byte = self.buffer[self.buffer_pos];
213 self.buffer_pos += 1;
214 continue;
215 }
216
217 self.buffer = match self.decode_chunk() {
219 Ok(buf) => buf.to_vec(),
220 Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(idx),
221 Err(e) => return Err(e.into()),
222 };
223 self.buffer_pos = 0;
224
225 if self.buffer_pos < self.buffer.len() {
227 *byte = self.buffer[self.buffer_pos];
228 self.buffer_pos += 1;
229 continue;
230 }
231
232 self.position += idx as u64;
233 return Ok(idx);
234 }
235
236 Ok(buf.len())
237 }
238}
239
240impl<R: io::Read> io::Seek for LzwReader<R> {
241 fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
242 Ok(match pos {
243 io::SeekFrom::Current(0) => todo!(),
244 io::SeekFrom::Current(n) if n < 0 => todo!(),
245 io::SeekFrom::Current(x) => {
246 let mut buf = vec![0u8; x as usize];
247 self.read(&mut buf)? as u64
248 }
249 io::SeekFrom::End(_) => todo!(),
250 io::SeekFrom::Start(n) if n > self.position => {
251 self.seek(io::SeekFrom::Current(n as i64 - self.position as i64))?
252 }
253 _ => todo!(),
254 })
255 }
256
257 #[inline]
258 fn stream_position(&mut self) -> io::Result<u64> {
259 Ok(self.position)
260 }
261
262 #[inline]
263 fn stream_len(&mut self) -> io::Result<u64> {
264 Ok(self.uncompressed_size)
265 }
266}