use crate::constants::{MAGIC_BYTES, VERSION};
use crate::crypto;
use crate::crypto::decrypt_stream;
use crate::flags::{is_encrypted, is_stored_raw};
use crate::headers::Headers;
use crate::io::BitReader;
use std::io::{Read, Seek, Write}; use tempfile::tempfile;
/// Metadata about a decoded file, returned by `decode` on success.
///
/// Every field is copied verbatim from the file header by `decode`; `decode` itself does not
/// verify `checksum` against the decoded output, so callers that care must do so themselves.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodeInfo {
    /// File name recorded in the header for the original (pre-encoding) file.
    pub original_file_name: String,
    /// Checksum recorded in the header (algorithm not visible here — TODO confirm with encoder).
    pub checksum: u32,
    /// Size in bytes of the original file, as recorded in the header.
    pub original_size: u64,
}
/// Decodes a `.small` payload read from `reader` and writes the result to `writer`.
///
/// Steps:
/// 1. Validate `header.magic_bytes` against `MAGIC_BYTES` and `header.version` against `VERSION`.
/// 2. Read at most `header.compressed_size` bytes of payload from `reader`.
/// 3. If the header flags mark the payload as encrypted, derive a key from `decrypt_password`
///    and `header.salt`, decrypt the payload (with `header.iv`) into a temporary file, and
///    continue from at most `header.payload_actual_size` bytes of that file.
/// 4. If the header flags mark the payload as stored raw, copy at most
///    `header.payload_actual_size` bytes to `writer`. Otherwise, Huffman-decode it by walking
///    `header.tree` (bit 0 → left child, bit 1 → right child) until `header.original_size`
///    bytes have been emitted.
///
/// `chunk_size` is only forwarded to `decrypt_stream`; it has no effect on unencrypted input.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the magic bytes or version do not match;
/// - the payload is encrypted and `decrypt_password` is `None`;
/// - key derivation, decryption, or any I/O fails (including flushing the output);
/// - the Huffman bitstream is corrupt: a tree node is missing, or the data ends before
///   `original_size` bytes were produced.
///
/// On success, returns the file name, checksum and original size recorded in the header. The
/// checksum is *not* verified here.
pub fn decode<R: Read + Seek, W: Write>(
    header: Headers,
    reader: &mut R,
    decrypt_password: Option<&str>,
    writer: &mut W,
    chunk_size: usize,
) -> Result<DecodeInfo, Box<dyn std::error::Error>> {
    if header.magic_bytes != MAGIC_BYTES {
        return Err("Error: Not a valid .small file".into());
    }
    if header.version != VERSION {
        return Err("Error: Incorrect version".into());
    }

    // Bound the read to the compressed payload so trailing bytes in `reader` are never consumed.
    let mut payload_reader: Box<dyn Read> = Box::new(reader.take(header.compressed_size));

    if is_encrypted(header.flags) {
        let password =
            decrypt_password.ok_or("Error: File is encrypted, password required.")?;
        let key = crypto::derive_key(password.as_bytes(), &header.salt);
        // NOTE(review): the decrypted plaintext is spooled to a `tempfile()` file rather than kept
        // in memory — confirm this is acceptable for sensitive inputs.
        let mut decrypted_temp_file = tempfile()?;
        decrypt_stream(
            &mut payload_reader,
            &mut decrypted_temp_file,
            &key,
            &header.iv,
            &[],
            chunk_size,
        )?;
        // Rewind so the decoding stage below reads the plaintext from the beginning.
        decrypted_temp_file.seek(std::io::SeekFrom::Start(0))?;
        payload_reader = Box::new(decrypted_temp_file.take(header.payload_actual_size));
    }

    if is_stored_raw(header.flags) {
        std::io::copy(&mut payload_reader.take(header.payload_actual_size), writer)?;
    } else {
        // Number of meaningful bits in the payload: all payload bits minus `padding_bits`
        // (presumably filler in the final byte — confirm against the encoder). Saturating
        // arithmetic keeps a corrupt header from panicking (debug) or wrapping to an enormous
        // budget (release); an impossible budget is clamped to 0 and reported below as
        // "Unexpected end of compressed data" as soon as a bit is needed.
        let total_data_bits = header
            .payload_actual_size
            .saturating_mul(8)
            .saturating_sub(header.padding_bits as u64);

        // One output byte is produced per decoded symbol; buffer so a raw `File` writer is not
        // hit with one syscall per byte.
        let mut out = std::io::BufWriter::new(&mut *writer);
        let mut bit_reader = BitReader::new(header.padding_bits as usize, payload_reader);
        let root = &header.tree;
        let mut current = root;
        let mut decoded_bytes_count: u64 = 0;
        let mut bits_read_count: u64 = 0;

        while decoded_bytes_count < header.original_size {
            // A node carrying a symbol is a leaf: emit it and restart from the root. Checking
            // before reading a bit also handles a single-leaf tree, which needs no bits at all.
            if let Some(byte) = current.symbol {
                out.write_all(&[byte])?;
                decoded_bytes_count += 1;
                current = root;
                continue;
            }
            if bits_read_count >= total_data_bits {
                return Err(
                    "Corrupted: Unexpected end of compressed data (read past padding)".into(),
                );
            }
            match bit_reader.read_bit()? {
                Some(0) => {
                    current = current
                        .left
                        .as_ref()
                        .ok_or("Corrupted: Missing left node")?;
                }
                Some(1) => {
                    current = current
                        .right
                        .as_ref()
                        .ok_or("Corrupted: Missing right node")?;
                }
                None => {
                    return Err(
                        "Corrupted: Unexpected end of file (BitReader returned None)".into(),
                    );
                }
                _ => unreachable!("BitReader::read_bit must only yield 0 or 1"),
            }
            // Only reached after a successful left/right step; the other arms diverge.
            bits_read_count += 1;
        }

        // Flush explicitly: BufWriter's Drop would silently discard a flush error.
        out.flush()?;
    }

    Ok(DecodeInfo {
        original_file_name: header.original_file_name,
        checksum: header.checksum,
        original_size: header.original_size,
    })
}