use crate::utils::{
CmprssInput, CmprssOutput, Compressor, ExtractedTarget, ReadWrapper, Result, WriteWrapper,
};
use anyhow::{anyhow, bail};
use std::io::{self, Read, Write};
use std::path::Path;
use std::sync::mpsc::{Receiver, Sender, channel};
use std::thread;
/// A chain of compressors that behaves as a single compound `Compressor`
/// (for example tar followed by gzip).
///
/// Stages are stored innermost first: when compressing, the first stage
/// receives the original input and the last stage writes the final output.
pub struct Pipeline {
/// Stages in compression order (innermost first, outermost last).
compressors: Vec<Box<dyn Compressor>>,
/// When set, used verbatim as the format chain instead of joining the
/// stages' extensions with `.`.
format_override: Option<String>,
}
impl Clone for Pipeline {
fn clone(&self) -> Self {
Pipeline {
compressors: self.compressors.iter().map(|c| c.clone_boxed()).collect(),
format_override: self.format_override.clone(),
}
}
}
/// Operation that every non-final stage performs inside `Pipeline::run_threaded`.
#[derive(Clone, Copy)]
enum StageAction {
/// Intermediate stages call `Compressor::compress`.
Compress,
/// Intermediate stages call `Compressor::extract`.
Extract,
}
impl Pipeline {
pub fn new(compressors: Vec<Box<dyn Compressor>>) -> Self {
Pipeline {
compressors,
format_override: None,
}
}
pub fn with_format(compressors: Vec<Box<dyn Compressor>>, format: String) -> Self {
Pipeline {
compressors,
format_override: Some(format),
}
}
fn format_chain(&self) -> String {
if let Some(ref f) = self.format_override {
return f.clone();
}
self.compressors
.iter()
.map(|c| c.extension())
.collect::<Vec<&str>>()
.join(".")
}
fn run_threaded<F>(
stages: Vec<Box<dyn Compressor>>,
initial_input: CmprssInput,
intermediate: StageAction,
finalize: F,
) -> Result
where
F: FnOnce(Box<dyn Compressor>, CmprssInput) -> Result,
{
debug_assert!(!stages.is_empty(), "pipeline is never empty");
let mut stages = stages;
let last = stages.pop().expect("pipeline is never empty");
let buffer_size = 64 * 1024;
let mut current_input = initial_input;
let mut handles = Vec::new();
for stage in stages {
let (sender, receiver) = channel::<Vec<u8>>();
let stage_output =
CmprssOutput::Writer(WriteWrapper(Box::new(PipeWriter::new(sender, buffer_size))));
let next_input = CmprssInput::Reader(ReadWrapper(Box::new(PipeReader::new(receiver))));
let stage_input = std::mem::replace(&mut current_input, next_input);
let handle = thread::spawn(move || match intermediate {
StageAction::Compress => stage.compress(stage_input, stage_output),
StageAction::Extract => stage.extract(stage_input, stage_output),
});
handles.push(handle);
}
finalize(last, current_input)?;
for handle in handles {
handle
.join()
.map_err(|_| anyhow!("Pipeline stage thread panicked"))??;
}
Ok(())
}
}
/// Read end of an in-memory pipe between two pipeline stages.
///
/// Data arrives as `Vec<u8>` chunks over an mpsc channel. An empty chunk or a
/// disconnected channel marks the end of the stream.
struct PipeReader {
    /// Channel delivering chunks from the writing stage.
    receiver: Receiver<Vec<u8>>,
    /// Chunk currently being handed out to the caller.
    buffer: Vec<u8>,
    /// Offset of the first byte of `buffer` that has not been returned yet.
    position: usize,
    /// Set once the end of the stream has been observed.
    eof: bool,
}

impl PipeReader {
    /// Wraps `receiver` in a reader with no buffered data.
    fn new(receiver: Receiver<Vec<u8>>) -> Self {
        PipeReader {
            receiver,
            buffer: Vec::new(),
            position: 0,
            eof: false,
        }
    }
}

impl Read for PipeReader {
    /// Copies as many buffered bytes as fit into `buf`, blocking for the next
    /// chunk when the current one is used up.
    ///
    /// Returns `Ok(0)` at end of stream, or immediately when `buf` is empty.
    /// This implementation never returns an error.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // A zero-length read can never transfer data, so it must not wait on
        // the channel for a chunk it cannot deliver.
        if buf.is_empty() {
            return Ok(0);
        }

        if self.position >= self.buffer.len() {
            if self.eof {
                return Ok(0);
            }
            match self.receiver.recv() {
                Ok(chunk) if !chunk.is_empty() => {
                    self.buffer = chunk;
                    self.position = 0;
                }
                // Empty chunk (explicit end marker) or sender gone: both end
                // the stream.
                _ => {
                    self.eof = true;
                    return Ok(0);
                }
            }
        }

        let pending = &self.buffer[self.position..];
        let count = pending.len().min(buf.len());
        buf[..count].copy_from_slice(&pending[..count]);
        self.position += count;
        Ok(count)
    }
}
/// Write end of an in-memory pipe between two pipeline stages.
///
/// Written bytes are copied into `Vec<u8>` chunks of at most `buffer_size`
/// bytes and sent over an mpsc channel. Dropping the writer sends an empty
/// chunk as the end-of-stream marker.
struct PipeWriter {
    /// Channel carrying chunks to the reading stage.
    sender: Sender<Vec<u8>>,
    /// Maximum size of a single chunk; always at least 1 (enforced by `new`).
    buffer_size: usize,
}

impl PipeWriter {
    /// Creates a writer that sends chunks of at most `buffer_size` bytes.
    ///
    /// A `buffer_size` of 0 is raised to 1: a zero chunk size could never make
    /// progress and would only emit empty chunks, which the reading side
    /// interprets as end of stream.
    fn new(sender: Sender<Vec<u8>>, buffer_size: usize) -> Self {
        PipeWriter {
            sender,
            buffer_size: buffer_size.max(1),
        }
    }
}

impl Write for PipeWriter {
    /// Sends all of `buf` down the channel, split into chunks of at most
    /// `buffer_size` bytes, and returns `buf.len()`.
    ///
    /// An empty `buf` sends nothing.
    ///
    /// # Errors
    ///
    /// Returns `BrokenPipe` as soon as a chunk cannot be sent because the
    /// receiving end has been dropped.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        for chunk in buf.chunks(self.buffer_size) {
            self.sender.send(chunk.to_vec()).map_err(|_| {
                io::Error::new(io::ErrorKind::BrokenPipe, "Pipe receiver has been closed")
            })?;
        }
        Ok(buf.len())
    }

    /// No-op: every write is handed to the channel immediately.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

impl Drop for PipeWriter {
    /// Sends the empty end-of-stream chunk; a failed send (receiver already
    /// gone) is deliberately ignored.
    fn drop(&mut self) {
        let _ = self.sender.send(Vec::new());
    }
}
impl Compressor for Pipeline {
fn name(&self) -> &str {
self.compressors
.last()
.expect("pipeline is never empty")
.name()
}
fn extension(&self) -> &str {
self.compressors
.last()
.expect("pipeline is never empty")
.extension()
}
fn default_extracted_target(&self) -> ExtractedTarget {
self.compressors
.first()
.expect("pipeline is never empty")
.default_extracted_target()
}
fn default_compressed_filename(&self, in_path: &Path) -> String {
let base = in_path
.file_name()
.map(|n| n.to_string_lossy().into_owned())
.unwrap_or_else(|| "archive".to_string());
format!("{}.{}", base, self.format_chain())
}
fn default_extracted_filename(&self, in_path: &Path) -> String {
if self.default_extracted_target() == ExtractedTarget::Directory {
return ".".to_string();
}
let mut name = in_path
.file_name()
.map(|n| n.to_string_lossy().into_owned())
.unwrap_or_else(|| "archive".to_string());
for comp in self.compressors.iter().rev() {
let ext = format!(".{}", comp.extension());
if let Some(stripped) = name.strip_suffix(&ext) {
name = stripped.to_string();
}
}
name
}
fn is_archive(&self, in_path: &Path) -> bool {
let file_name = match in_path.file_name().and_then(|f| f.to_str()) {
Some(f) => f,
None => return false,
};
file_name.ends_with(&format!(".{}", self.format_chain()))
}
fn compress(&self, input: CmprssInput, output: CmprssOutput) -> Result {
debug_assert!(!self.compressors.is_empty(), "pipeline is never empty");
if self.compressors.len() == 1 {
return self.compressors[0].compress(input, output);
}
let stages = self.compressors.iter().map(|c| c.clone_boxed()).collect();
Self::run_threaded(stages, input, StageAction::Compress, |last, input| {
last.compress(input, output)
})
}
fn extract(&self, input: CmprssInput, output: CmprssOutput) -> Result {
debug_assert!(!self.compressors.is_empty(), "pipeline is never empty");
if self.compressors.len() == 1 {
return self.compressors[0].extract(input, output);
}
let stages = self
.compressors
.iter()
.rev()
.map(|c| c.clone_boxed())
.collect();
Self::run_threaded(stages, input, StageAction::Extract, |last, input| {
let final_output = match output {
CmprssOutput::Path(ref p) => {
if last.default_extracted_target() == ExtractedTarget::Directory && !p.exists()
{
std::fs::create_dir_all(p)?;
}
CmprssOutput::Path(p.clone())
}
CmprssOutput::Pipe(_) | CmprssOutput::Writer(_) => output,
};
last.extract(input, final_output)
})
}
fn append(&self, input: CmprssInput, output: CmprssOutput) -> Result {
debug_assert!(!self.compressors.is_empty(), "pipeline is never empty");
if self.compressors.len() == 1 {
return self.compressors[0].append(input, output);
}
bail!(
"cannot --append to a compound archive ({}); it would require decompressing and recompressing the whole archive",
self.format_chain()
)
}
fn list(&self, input: CmprssInput) -> Result {
debug_assert!(!self.compressors.is_empty(), "pipeline is never empty");
if self.compressors.len() == 1 {
return self.compressors[0].list(input);
}
let stages = self
.compressors
.iter()
.rev()
.map(|c| c.clone_boxed())
.collect();
Self::run_threaded(stages, input, StageAction::Extract, |innermost, input| {
innermost.list(input)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// A tar+gzip pipeline must round-trip a single file through `compress`
    /// followed by `extract`.
    #[test]
    fn test_pipeline_compression() -> Result {
        let scratch = tempdir()?;
        let root = scratch.path();
        let payload = "This is a test file for pipeline compression";

        let source_path = root.join("test.txt");
        fs::write(&source_path, payload)?;

        let pipeline = Pipeline::new(vec![
            Box::new(crate::backends::Tar::default()),
            Box::new(crate::backends::Gzip::default()),
        ]);

        // Compress into an archive inside the scratch directory.
        let archive_path = root.join("test.tar.gz");
        let compress_input = CmprssInput::Path(vec![source_path.clone()]);
        let compress_output = CmprssOutput::Path(archive_path.clone());
        pipeline.compress(compress_input, compress_output)?;
        assert!(archive_path.exists());

        // Extract into a fresh directory and compare the contents.
        let output_dir = root.join("extracted");
        fs::create_dir(&output_dir)?;
        let extract_input = CmprssInput::Path(vec![archive_path.clone()]);
        let extract_output = CmprssOutput::Path(output_dir.clone());
        pipeline.extract(extract_input, extract_output)?;

        let restored_path = output_dir.join("test.txt");
        assert!(restored_path.exists());
        let restored = fs::read_to_string(restored_path)?;
        assert_eq!(restored, payload);
        Ok(())
    }

    /// The compression level configured on a stage must still be in effect
    /// when that stage runs inside a pipeline.
    #[test]
    fn test_pipeline_preserves_stage_config() -> Result {
        use crate::progress::ProgressArgs;

        let scratch = tempdir()?;
        let input = scratch.path().join("input.txt");
        fs::write(&input, "0123456789abcdef".repeat(1024))?;

        // Compresses `input` with the given gzip level and reports the archive size.
        let compressed_size = |level: i32, suffix: &str| -> Result<u64> {
            let pipeline = Pipeline::new(vec![
                Box::new(crate::backends::Tar::default()),
                Box::new(crate::backends::Gzip {
                    compression_level: level,
                    progress_args: ProgressArgs::default(),
                }),
            ]);
            let archive = scratch.path().join(format!("out.{suffix}.tar.gz"));
            let source = CmprssInput::Path(vec![input.clone()]);
            let target = CmprssOutput::Path(archive.clone());
            pipeline.compress(source, target)?;
            let size = fs::metadata(&archive)?.len();
            Ok(size)
        };

        let fast_size = compressed_size(1, "fast")?;
        let best_size = compressed_size(9, "best")?;
        assert!(
            best_size < fast_size,
            "expected best (level 9) to be smaller than fast (level 1), got {best_size} >= {fast_size}",
        );
        Ok(())
    }
}