1use crate::decode::lzbuffer::{LzBuffer, LzCircularBuffer};
2use crate::decode::lzma::{DecoderState, LzmaParams};
3use crate::decode::rangecoder::RangeDecoder;
4use crate::decompress::Options;
5use crate::error::Error;
6use std::fmt::Debug;
7use std::io::{self, BufRead, Cursor, Read, Write};
8
/// Smallest possible LZMA header: the properties byte followed by the
/// 4-byte dictionary size.
const MIN_HEADER_LEN: usize = 1 + 4;

/// Largest possible LZMA header: the minimal header plus the 8-byte
/// unpacked size.
const MAX_HEADER_LEN: usize = MIN_HEADER_LEN + 8;

/// Bytes consumed when the range decoder is initialized right after the
/// header.
const START_BYTES: usize = 5;

/// Capacity of the scratch buffer that collects a header (and the range
/// decoder's start bytes) split across several `write` calls.
const MAX_TMP_LEN: usize = MAX_HEADER_LEN + START_BYTES;
25
26#[derive(Debug)]
29enum State<W>
30where
31 W: Write,
32{
33 Header(W),
35 Data(RunState<W>),
37}
38
39struct RunState<W>
41where
42 W: Write,
43{
44 decoder: DecoderState,
45 range: u32,
46 code: u32,
47 output: LzCircularBuffer<W>,
48}
49
50impl<W> Debug for RunState<W>
51where
52 W: Write,
53{
54 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
55 fmt.debug_struct("RunState")
56 .field("range", &self.range)
57 .field("code", &self.code)
58 .finish()
59 }
60}
61
62#[cfg_attr(docsrs, doc(cfg(stream)))]
65pub struct Stream<W>
66where
67 W: Write,
68{
69 tmp: Cursor<[u8; MAX_TMP_LEN]>,
71 state: Option<State<W>>,
74 options: Options,
76}
77
78impl<W> Stream<W>
79where
80 W: Write,
81{
82 pub fn new(output: W) -> Self {
85 Self::new_with_options(&Options::default(), output)
86 }
87
88 pub fn new_with_options(options: &Options, output: W) -> Self {
92 Self {
93 tmp: Cursor::new([0; MAX_TMP_LEN]),
94 state: Some(State::Header(output)),
95 options: *options,
96 }
97 }
98
99 pub fn get_output(&self) -> Option<&W> {
101 self.state.as_ref().map(|state| match state {
102 State::Header(output) => &output,
103 State::Data(state) => state.output.get_output(),
104 })
105 }
106
107 pub fn get_output_mut(&mut self) -> Option<&mut W> {
109 self.state.as_mut().map(|state| match state {
110 State::Header(output) => output,
111 State::Data(state) => state.output.get_output_mut(),
112 })
113 }
114
115 pub fn finish(mut self) -> crate::error::Result<W> {
118 if let Some(state) = self.state.take() {
119 match state {
120 State::Header(output) => {
121 if self.tmp.position() > 0 {
122 Err(Error::LzmaError("failed to read header".to_string()))
123 } else {
124 Ok(output)
125 }
126 }
127 State::Data(mut state) => {
128 if !self.options.allow_incomplete {
129 let mut stream =
132 Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]);
133 let mut range_decoder =
134 RangeDecoder::from_parts(&mut stream, state.range, state.code);
135 state
136 .decoder
137 .process(&mut state.output, &mut range_decoder)?;
138 }
139 let output = state.output.finish()?;
140 Ok(output)
141 }
142 }
143 } else {
144 Err(Error::LzmaError(
146 "can't finish stream because of previous write error".to_string(),
147 ))
148 }
149 }
150
151 fn read_header<R: BufRead>(
156 output: W,
157 mut input: &mut R,
158 options: &Options,
159 ) -> crate::error::Result<State<W>> {
160 match LzmaParams::read_header(&mut input, options) {
161 Ok(params) => {
162 let decoder = DecoderState::new(params.properties, params.unpacked_size);
163 let output = LzCircularBuffer::from_stream(
164 output,
165 params.dict_size as usize,
166 options.memlimit.unwrap_or(usize::MAX),
167 );
168 if let Ok(rangecoder) = RangeDecoder::new(&mut input) {
171 Ok(State::Data(RunState {
172 decoder,
173 output,
174 range: rangecoder.range,
175 code: rangecoder.code,
176 }))
177 } else {
178 Ok(State::Header(output.into_output()))
181 }
182 }
183 Err(Error::HeaderTooShort(_)) => Ok(State::Header(output)),
185 Err(e) => Err(e),
187 }
188 }
189
190 fn read_data<R: BufRead>(mut state: RunState<W>, mut input: &mut R) -> io::Result<RunState<W>> {
192 let mut rangecoder = RangeDecoder::from_parts(&mut input, state.range, state.code);
195
196 state
198 .decoder
199 .process_stream(&mut state.output, &mut rangecoder)
200 .map_err(|e| -> io::Error { e.into() })?;
201
202 Ok(RunState {
203 decoder: state.decoder,
204 output: state.output,
205 range: rangecoder.range,
206 code: rangecoder.code,
207 })
208 }
209}
210
211impl<W> Debug for Stream<W>
212where
213 W: Write + Debug,
214{
215 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
216 fmt.debug_struct("Stream")
217 .field("tmp", &self.tmp.position())
218 .field("state", &self.state)
219 .field("options", &self.options)
220 .finish()
221 }
222}
223
224impl<W> Write for Stream<W>
225where
226 W: Write,
227{
228 fn write(&mut self, data: &[u8]) -> io::Result<usize> {
229 let mut input = Cursor::new(data);
230
231 if let Some(state) = self.state.take() {
232 let state = match state {
233 State::Header(state) => {
235 let res = if self.tmp.position() > 0 {
236 let position = self.tmp.position();
238 let bytes_read =
239 input.read(&mut self.tmp.get_mut()[position as usize..])?;
240 let bytes_read = if bytes_read < std::u64::MAX as usize {
241 bytes_read as u64
242 } else {
243 return Err(io::Error::new(
244 io::ErrorKind::Other,
245 "Failed to convert integer to u64.",
246 ));
247 };
248 self.tmp.set_position(position + bytes_read);
249
250 let (position, res) = {
252 let mut tmp_input =
253 Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]);
254 let res = Stream::read_header(state, &mut tmp_input, &self.options);
255 (tmp_input.position(), res)
256 };
257
258 if let Ok(State::Data(_)) = &res {
261 let tmp = *self.tmp.get_ref();
262 let end = self.tmp.position();
263 let new_len = end - position;
264 (&mut self.tmp.get_mut()[0..new_len as usize])
265 .copy_from_slice(&tmp[position as usize..end as usize]);
266 self.tmp.set_position(new_len);
267 }
268 res
269 } else {
270 Stream::read_header(state, &mut input, &self.options)
271 };
272
273 match res {
274 Ok(State::Header(val)) => {
277 if self.tmp.position() == 0 {
278 input.set_position(0);
280 let bytes_read = input.read(&mut self.tmp.get_mut()[..])?;
281 let bytes_read = if bytes_read < std::u64::MAX as usize {
282 bytes_read as u64
283 } else {
284 return Err(io::Error::new(
285 io::ErrorKind::Other,
286 "Failed to convert integer to u64.",
287 ));
288 };
289 self.tmp.set_position(bytes_read);
290 }
291 State::Header(val)
292 }
293
294 Ok(State::Data(val)) => State::Data(val),
297
298 Err(e) => {
301 return Err(match e {
302 Error::IoError(e) | Error::HeaderTooShort(e) => e,
303 Error::LzmaError(e) | Error::XzError(e) => {
304 io::Error::new(io::ErrorKind::Other, e)
305 }
306 });
307 }
308 }
309 }
310
311 State::Data(state) => {
313 let state = if self.tmp.position() > 0 {
314 let mut tmp_input =
315 Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]);
316 let res = Stream::read_data(state, &mut tmp_input)?;
317 self.tmp.set_position(0);
318 res
319 } else {
320 state
321 };
322 State::Data(Stream::read_data(state, &mut input)?)
323 }
324 };
325 self.state.replace(state);
326 }
327 Ok(input.position() as usize)
328 }
329
330 fn flush(&mut self) -> io::Result<()> {
334 if let Some(ref mut state) = self.state {
335 match state {
336 State::Header(_) => Ok(()),
337 State::Data(state) => state.output.get_output_mut().flush(),
338 }
339 } else {
340 Ok(())
341 }
342 }
343}
344
345impl From<Error> for io::Error {
346 fn from(error: Error) -> io::Error {
347 io::Error::new(io::ErrorKind::Other, format!("{:?}", error))
348 }
349}
350
#[cfg(test)]
mod test {
    use super::*;

    /// A stream that never receives input finishes with empty output.
    #[test]
    fn test_stream_noop() {
        let stream = Stream::new(Vec::new());
        assert!(stream.get_output().unwrap().is_empty());

        let output = stream.finish().unwrap();
        assert!(output.is_empty());
    }

    /// Empty writes are accepted and leave the stream in its initial state.
    #[test]
    fn test_stream_zero() {
        let mut stream = Stream::new(Vec::new());

        stream.write_all(&[]).unwrap();
        stream.write_all(&[]).unwrap();

        let output = stream.finish().unwrap();

        assert!(output.is_empty());
    }

    /// An invalid properties byte is rejected while the header is parsed.
    #[test]
    #[should_panic(expected = "LZMA header invalid properties: 255 must be < 225")]
    fn test_bad_header() {
        let input = [255u8; 32];

        let mut stream = Stream::new(Vec::new());

        stream.write_all(&input[..]).unwrap();

        let output = stream.finish().unwrap();

        assert!(output.is_empty());
    }

    /// Truncated input is reported by `finish`, whether the cut falls inside
    /// the header or inside the compressed data.
    #[test]
    fn test_stream_incomplete() {
        let input = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\
                      \xfb\xff\xff\xc0\x00\x00\x00";
        let mut end = 1u64;

        // Cut inside the header or the range decoder's start bytes: every
        // byte is buffered in `tmp` and `finish` reports the missing header.
        while end < (MAX_HEADER_LEN + START_BYTES) as u64 {
            let mut stream = Stream::new(Vec::new());
            stream.write_all(&input[..end as usize]).unwrap();
            assert_eq!(stream.tmp.position(), end);

            let err = stream.finish().unwrap_err();
            assert!(
                err.to_string().contains("failed to read header"),
                "error was: {}",
                err
            );

            end += 1;
        }

        // Cut inside the compressed data: the header is parsed directly from
        // the written bytes, so nothing stays buffered in `tmp`, and `finish`
        // reports the truncation.
        while end < input.len() as u64 {
            let mut stream = Stream::new(Vec::new());
            stream.write_all(&input[..end as usize]).unwrap();
            assert_eq!(stream.tmp.position(), 0);

            let err = stream.finish().unwrap_err();
            assert!(
                err.to_string().contains("failed to fill whole buffer"),
                "error was: {}",
                err
            );

            end += 1;
        }
    }

    /// Decoding gives the same result for every chunk size, from one byte at
    /// a time up to the whole input in a single write.
    #[test]
    fn test_stream_chunked() {
        let small_input = include_bytes!("../../tests/files/small.txt");

        let mut reader = io::Cursor::new(&small_input[..]);
        let mut small_input_compressed = Vec::new();
        crate::lzma_compress(&mut reader, &mut small_input_compressed).unwrap();

        let input: Vec<(&[u8], &[u8])> = vec![
            (
                b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\xfb\xff\xff\xc0\x00\x00\x00",
                b"",
            ),
            (&small_input_compressed[..], small_input),
        ];
        for (input, expected) in input {
            for chunk in 1..=input.len() {
                let mut consumed = 0;
                let mut stream = Stream::new(Vec::new());
                while consumed < input.len() {
                    let end = std::cmp::min(consumed + chunk, input.len());
                    stream.write_all(&input[consumed..end]).unwrap();
                    consumed = end;
                }
                let output = stream.finish().unwrap();
                assert_eq!(expected, &output[..]);
            }
        }
    }

    /// A decoding error poisons the stream: `finish` then reports the earlier
    /// write failure.
    #[test]
    fn test_stream_corrupted() {
        let mut stream = Stream::new(Vec::new());
        let err = stream
            .write_all(b"corrupted bytes here corrupted bytes here")
            .unwrap_err();
        assert!(err.to_string().contains("beyond output size"));
        let err = stream.finish().unwrap_err();
        assert!(err
            .to_string()
            .contains("can't finish stream because of previous write error"));
    }

    /// Truncated input is an error by default, but with `allow_incomplete`
    /// the bytes decoded so far are returned.
    #[test]
    fn test_allow_incomplete() {
        let input = include_bytes!("../../tests/files/small.txt");

        let mut reader = io::Cursor::new(&input[..]);
        let mut compressed = Vec::new();
        crate::lzma_compress(&mut reader, &mut compressed).unwrap();
        let compressed = &compressed[..compressed.len() / 2];

        let mut stream = Stream::new(Vec::new());
        stream.write_all(&compressed[..]).unwrap();
        stream.finish().unwrap_err();

        let mut stream = Stream::new_with_options(
            &Options {
                allow_incomplete: true,
                ..Default::default()
            },
            Vec::new(),
        );
        stream.write_all(&compressed[..]).unwrap();
        let output = stream.finish().unwrap();
        assert_eq!(output, &input[..26]);
    }
}