cmprss 0.2.0

A compression multi-tool for the command line.
use crate::{
    progress::{progress_bar, ProgressArgs},
    utils::*,
};
use bzip2::write::{BzDecoder, BzEncoder};
use bzip2::Compression;
use clap::Args;
use std::{
    fs::File,
    io::{self, Read, Write},
};

// Command-line arguments for the bzip2 compressor, parsed through clap's `Args` derive.
//
// NOTE(review): plain `//` comments are used for these notes on purpose; clap's derive turns
// `///` doc comments into user-visible help text, so the `///` lines below are runtime strings.
#[derive(Args, Debug)]
pub struct Bzip2Args {
    // Options declared by `CommonArgs`, flattened into this argument set.
    // (Presumably shared with the other compressors; confirm against `utils`.)
    #[clap(flatten)]
    pub common_args: CommonArgs,

    // Progress reporting options declared by `ProgressArgs`, flattened into this argument set.
    #[clap(flatten)]
    pub progress_args: ProgressArgs,

    // `--level` flag; the string default "9" is converted by `CompressionLevel`'s own parser.
    // `Bzip2::new` reads the numeric value from `level.level`.
    /// Level of compression.
    /// This is an int 1-9, with 1 being minimal compression and 9 being highest compression.
    /// Also supports 'fast', and 'best'.
    #[arg(long, default_value = "9")]
    pub level: CompressionLevel,
}

/// Compressor for the bz2 format; holds the settings used when compressing and extracting.
pub struct Bzip2 {
    /// Compression level passed to the bzip2 encoder. `compress` rejects values outside 1-9.
    pub level: u32, // 1-9
    /// Progress reporting options: `progress` is forwarded to `progress_bar` and
    /// `chunk_size` determines the size of the copy buffer used while a bar is shown.
    pub progress_args: ProgressArgs,
}

impl Default for Bzip2 {
    /// Builds a compressor at level 6 with the default progress settings.
    ///
    /// NOTE(review): the `--level` flag on `Bzip2Args` defaults to "9" while this default
    /// is 6 — confirm the difference is intentional.
    fn default() -> Self {
        let progress_args = ProgressArgs::default();
        Self {
            level: 6,
            progress_args,
        }
    }
}

impl Bzip2 {
    pub fn new(args: &Bzip2Args) -> Self {
        Bzip2 {
            level: args.level.level,
            progress_args: args.progress_args,
        }
    }
}

impl Compressor for Bzip2 {
    /// The standard extension for the bz2 format.
    fn extension(&self) -> &str {
        "bz2"
    }

    /// Full name for bz2.
    fn name(&self) -> &str {
        "bzip2"
    }

    /// Compress an input file or pipe to a bz2 archive.
    ///
    /// When a progress bar is available the input is copied in chunks of
    /// `progress_args.chunk_size` bytes so the bar can be updated; otherwise the
    /// data is streamed with `io::copy`.
    ///
    /// # Errors
    ///
    /// Returns an error if the level is outside 1-9, if the input is not exactly
    /// one path (or a pipe), or if any read, write, or the final stream
    /// finalisation fails.
    fn compress(&self, input: CmprssInput, output: CmprssOutput) -> Result<(), io::Error> {
        if !(1..=9).contains(&self.level) {
            return cmprss_error("Invalid compression level. Must be 1-9.");
        }
        let mut file_size = None;
        let mut input_stream: Box<dyn Read + Send> = match input {
            CmprssInput::Path(paths) => {
                // Exactly one path is supported; an empty list is an error rather than a panic.
                let path = match paths.as_slice() {
                    [single] => single,
                    [] => return cmprss_error("no input file was specified"),
                    _ => return cmprss_error("only 1 file can be compressed at a time"),
                };
                let file = File::open(path.as_path())?;
                // Get the file size for the progress bar
                file_size = file.metadata().ok().map(|metadata| metadata.len());
                Box::new(file)
            }
            CmprssInput::Pipe(pipe) => Box::new(pipe),
        };
        let output_stream: Box<dyn Write + Send> = match &output {
            CmprssOutput::Path(path) => Box::new(File::create(path)?),
            CmprssOutput::Pipe(pipe) => Box::new(pipe) as Box<dyn Write + Send>,
        };
        let mut encoder = BzEncoder::new(output_stream, Compression::new(self.level));
        let mut bar = progress_bar(file_size, self.progress_args.progress, &output);
        if let Some(progress) = &mut bar {
            // Copy the input to the output in chunks so that we can update the progress bar.
            // A zero-length buffer would make the first read return 0 and look like EOF.
            let mut buffer = vec![0u8; self.progress_args.chunk_size.size_in_bytes.max(1)];
            loop {
                let bytes_read = read_chunk(&mut input_stream, &mut buffer)?;
                if bytes_read == 0 {
                    break;
                }
                encoder.write_all(&buffer[..bytes_read])?;
                progress.update_input(encoder.total_in());
                progress.update_output(encoder.total_out());
            }
            // Finalise explicitly: dropping the encoder would also finish the stream,
            // but it discards any error from writing the last block.
            encoder.try_finish()?;
            progress.update_output(encoder.total_out());
            progress.finish();
        } else {
            io::copy(&mut input_stream, &mut encoder)?;
            // See above: surface errors from the final write instead of losing them on drop.
            encoder.try_finish()?;
        }
        Ok(())
    }

    /// Extract a bz2 archive to a file or pipe.
    ///
    /// Mirrors `compress`: chunked copy with progress updates when a bar is
    /// available, `io::copy` otherwise. The compression level is not consulted.
    ///
    /// # Errors
    ///
    /// Returns an error if the input is not exactly one path (or a pipe), or if
    /// any read, decode, write, or the final stream finalisation fails.
    fn extract(&self, input: CmprssInput, output: CmprssOutput) -> Result<(), io::Error> {
        let mut file_size = None;
        let mut input_stream: Box<dyn Read + Send> = match input {
            CmprssInput::Path(paths) => {
                // Exactly one path is supported; an empty list is an error rather than a panic.
                let path = match paths.as_slice() {
                    [single] => single,
                    [] => return cmprss_error("no input file was specified"),
                    _ => return cmprss_error("only 1 file can be extracted at a time"),
                };
                let file = File::open(path.as_path())?;
                // Get the file size for the progress bar
                file_size = file.metadata().ok().map(|metadata| metadata.len());
                Box::new(file)
            }
            CmprssInput::Pipe(pipe) => Box::new(pipe),
        };
        let output_stream: Box<dyn Write + Send> = match &output {
            CmprssOutput::Path(path) => Box::new(File::create(path)?),
            CmprssOutput::Pipe(pipe) => Box::new(pipe) as Box<dyn Write + Send>,
        };
        let mut decoder = BzDecoder::new(output_stream);
        let mut bar = progress_bar(file_size, self.progress_args.progress, &output);
        if let Some(progress) = &mut bar {
            // Copy the input to the output in chunks so that we can update the progress bar.
            // A zero-length buffer would make the first read return 0 and look like EOF.
            let mut buffer = vec![0u8; self.progress_args.chunk_size.size_in_bytes.max(1)];
            loop {
                let bytes_read = read_chunk(&mut input_stream, &mut buffer)?;
                if bytes_read == 0 {
                    break;
                }
                decoder.write_all(&buffer[..bytes_read])?;
                progress.update_input(decoder.total_in());
                progress.update_output(decoder.total_out());
            }
            // Finalise explicitly: dropping the decoder would also finish the stream,
            // but it discards any error from writing the remaining decoded bytes.
            decoder.try_finish()?;
            progress.update_output(decoder.total_out());
            progress.finish();
        } else {
            io::copy(&mut input_stream, &mut decoder)?;
            // See above: surface errors from the final write instead of losing them on drop.
            decoder.try_finish()?;
        }
        Ok(())
    }
}

/// Reads the next chunk into `buffer`, retrying when the read is interrupted
/// (the same policy `io::copy` applies), and returns the number of bytes read.
fn read_chunk<R: Read>(reader: &mut R, buffer: &mut [u8]) -> io::Result<usize> {
    loop {
        match reader.read(buffer) {
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            result => return result,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use assert_fs::prelude::*;
    use predicates::prelude::*;

    /// Runs `compress` with the given level and asserts that the call is rejected.
    fn assert_level_rejected(level: u32) {
        let compressor = Bzip2 {
            level,
            ..Bzip2::default()
        };
        let input = assert_fs::NamedTempFile::new("test.txt").unwrap();
        let scratch = assert_fs::TempDir::new().unwrap();
        let archive = scratch.child(format!("archive.{}", compressor.extension()));
        let outcome = compressor.compress(
            CmprssInput::Path(vec![input.path().to_path_buf()]),
            CmprssOutput::Path(archive.path().to_path_buf()),
        );
        assert!(outcome.is_err());
    }

    /// Compressing a file and extracting the archive yields the original contents.
    #[test]
    fn roundtrip() -> Result<(), Box<dyn std::error::Error>> {
        let compressor = Bzip2::default();

        let original = assert_fs::NamedTempFile::new("test.txt")?;
        original.write_str("garbage data for testing")?;
        let scratch = assert_fs::TempDir::new()?;
        let archive = scratch.child(format!("archive.{}", compressor.extension()));
        archive.assert(predicate::path::missing());

        // Compress into the archive, which must now exist
        let source_paths = vec![original.path().to_path_buf()];
        let archive_path = archive.path().to_path_buf();
        compressor.compress(
            CmprssInput::Path(source_paths),
            CmprssOutput::Path(archive_path.clone()),
        )?;
        archive.assert(predicate::path::is_file());

        // Extract it again next to the archive
        let extracted = scratch.child("test.txt");
        compressor.extract(
            CmprssInput::Path(vec![archive_path]),
            CmprssOutput::Path(extracted.path().to_path_buf()),
        )?;

        // The extracted file must match the original byte for byte
        extracted.assert(predicate::path::eq_file(original.path()));

        Ok(())
    }

    // Fail with a compression level of 0
    #[test]
    fn invalid_compression_level_0() {
        assert_level_rejected(0);
    }

    // Fail with a compression level of 10
    #[test]
    fn invalid_compression_level_10() {
        assert_level_rejected(10);
    }
}