warcat 0.3.0

Command-line tool and library for handling Web ARChive (WARC) files
Documentation
use std::{fmt::Debug, io::Write};

use chunked::{ChunkedDecoder, ChunkedEncoder};
use compress::{CompressionDecoder, CompressionEncoder};

use crate::error::{GeneralError, ProtocolError, ProtocolErrorKind};

use super::header::{fields::FieldsExt, MessageHeader};

pub mod chunked;
pub mod compress;

/// A heap-allocated, dynamically dispatched [`Codec`].
pub type BoxedCodec = Box<dyn Codec>;

/// A streaming byte transformer that can be chained in a [`CodecPipeline`].
///
/// Input arrives in arbitrary chunks through [`Codec::transform`]. Once the
/// caller has no more input, [`Codec::finish_input`] gives the codec a chance
/// to emit anything it is still holding.
pub trait Codec: Debug {
    /// Processes one chunk of `input`, writing any produced bytes to `output`.
    fn transform(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<(), GeneralError>;

    /// Signals that no more input will arrive, letting the codec write any
    /// remaining bytes to `output`.
    ///
    /// The default implementation writes nothing and always succeeds.
    fn finish_input(&mut self, _output: &mut Vec<u8>) -> Result<(), GeneralError> {
        Ok(())
    }

    /// Reports whether the codec holds trailer bytes that can be retrieved with
    /// [`Codec::remaining_trailer`].
    ///
    /// The default implementation returns `false`.
    fn has_remaining_trailer(&self) -> bool {
        false
    }

    /// Hands any trailer bytes the codec holds to `trailer`.
    ///
    /// The default implementation leaves `trailer` untouched.
    fn remaining_trailer(&mut self, _trailer: &mut Vec<u8>) {}
}

#[derive(Debug, Default)]
pub struct IdentityCodec;

impl Codec for IdentityCodec {
    fn transform(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<(), GeneralError> {
        output.write_all(input)?;
        Ok(())
    }
}

/// Appends to `codecs` the decoders for the codings named in the message's
/// `Transfer-Encoding` and `Content-Encoding` fields.
///
/// # Errors
///
/// Returns a [`ProtocolError`] if a listed coding has no decoder.
pub fn build_decoders(
    header: &MessageHeader,
    codecs: &mut Vec<BoxedCodec>,
) -> Result<(), ProtocolError> {
    const ENCODE: bool = false;
    build_codecs(header, codecs, ENCODE)
}

/// Appends to `codecs` the encoders for the codings named in the message's
/// `Transfer-Encoding` and `Content-Encoding` fields.
///
/// # Errors
///
/// Returns a [`ProtocolError`] if a listed coding has no encoder.
pub fn build_encoders(
    header: &MessageHeader,
    codecs: &mut Vec<BoxedCodec>,
) -> Result<(), ProtocolError> {
    const ENCODE: bool = true;
    build_codecs(header, codecs, ENCODE)
}

fn build_codecs(
    header: &MessageHeader,
    codecs: &mut Vec<BoxedCodec>,
    encode: bool,
) -> Result<(), ProtocolError> {
    let mut te_names = header
        .fields
        .get_comma_list("transfer-encoding")
        .collect::<Vec<_>>();
    let ce_names = header.fields.get_comma_list("content-encoding");

    te_names.reverse();
    for name in te_names {
        if encode {
            if let Some(codec) = make_encoder(name.as_ref(), true) {
                codecs.push(codec);
                continue;
            }
        } else if let Some(codec) = make_decoder(name.as_ref(), true) {
            codecs.push(codec);
            continue;
        }
        return Err(ProtocolError::new(
            ProtocolErrorKind::UnsupportedTransferEncoding,
        ));
    }

    for name in ce_names {
        if name == "identity" {
            continue;
        }

        if encode {
            if let Some(codec) = make_encoder(name.as_ref(), false) {
                codecs.push(codec);
                continue;
            }
        } else if let Some(codec) = make_decoder(name.as_ref(), false) {
            codecs.push(codec);
            continue;
        }
        return Err(ProtocolError::new(
            ProtocolErrorKind::UnsupportedContentEncoding,
        ));
    }

    Ok(())
}

/// Returns an encoder for the coding `name`, or `None` if it is unsupported.
///
/// Compression codings are tried first. `chunked` is accepted only when
/// `transfer_encoding` is true, since it is not valid as a content coding.
fn make_encoder(name: &str, transfer_encoding: bool) -> Option<BoxedCodec> {
    match CompressionEncoder::try_of_name(name) {
        Ok(codec) => {
            tracing::trace!(name, "built compression encoder");
            Some(Box::new(codec))
        }
        Err(_) if transfer_encoding && name == "chunked" => {
            tracing::trace!(name, "built chunked encoder");
            Some(Box::new(ChunkedEncoder::new()))
        }
        Err(_) => None,
    }
}

/// Returns a decoder for the coding `name`, or `None` if it is unsupported.
///
/// Compression codings are tried first. `chunked` is accepted only when
/// `transfer_encoding` is true, since it is not valid as a content coding.
fn make_decoder(name: &str, transfer_encoding: bool) -> Option<BoxedCodec> {
    match CompressionDecoder::try_of_name(name) {
        Ok(codec) => {
            tracing::trace!(name, "built compression decoder");
            Some(Box::new(codec))
        }
        Err(_) if transfer_encoding && name == "chunked" => {
            tracing::trace!(name, "built chunked decoder");
            Some(Box::new(ChunkedDecoder::new()))
        }
        Err(_) => None,
    }
}

/// A chain of [`Codec`]s run in order: each codec's output is the next
/// codec's input, and the last codec's output goes to the caller's writer.
///
/// Two scratch buffers are reused across calls to avoid reallocating per chunk.
#[derive(Debug, Default)]
pub struct CodecPipeline {
    codecs: Vec<BoxedCodec>,
    // Data waiting to enter the next stage; after the last stage it holds the
    // pipeline's output.
    buf_in: Vec<u8>,
    // Scratch output for the stage currently running.
    buf_out: Vec<u8>,
}

impl CodecPipeline {
    /// Creates a pipeline that runs `codecs` in order.
    ///
    /// An empty list makes the pipeline a pass-through.
    pub fn new(codecs: Vec<BoxedCodec>) -> Self {
        Self {
            codecs,
            buf_in: Vec::new(),
            buf_out: Vec::new(),
        }
    }

    /// Feeds `input` through every codec and writes what the last codec
    /// produced to `output`.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by a codec or by writing to `output`.
    pub fn transform<W: Write>(&mut self, input: &[u8], mut output: W) -> Result<(), GeneralError> {
        if self.codecs.is_empty() {
            output.write_all(input)?;
            return Ok(());
        }

        self.reset_buffers();
        self.buf_in.extend_from_slice(input);
        self.run_stages(false)?;
        self.drain_output(&mut output)
    }

    /// Signals end of input to every codec and writes the flushed bytes to
    /// `output`.
    ///
    /// Bytes flushed by a codec are passed through every later codec before
    /// that codec is itself finished. For example, the final bytes of a
    /// compression encoder still go through a chunked encoder placed after it.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by a codec or by writing to `output`.
    pub fn finish_input<W: Write>(&mut self, mut output: W) -> Result<(), GeneralError> {
        if self.codecs.is_empty() {
            return Ok(());
        }

        self.reset_buffers();
        self.run_stages(true)?;
        self.drain_output(&mut output)
    }

    /// Returns `true` if any codec reports remaining trailer bytes.
    pub fn has_remaining_trailer(&self) -> bool {
        self.codecs
            .iter()
            .any(|codec| codec.has_remaining_trailer())
    }

    /// Collects every codec's remaining trailer bytes into `trailer`, visiting
    /// codecs in pipeline order.
    pub fn remaining_trailer(&mut self, trailer: &mut Vec<u8>) {
        for codec in &mut self.codecs {
            codec.remaining_trailer(trailer);
        }
    }

    /// Discards data left behind by an earlier call that failed part-way, so
    /// stale partial output is never mixed into the next chunk.
    fn reset_buffers(&mut self) {
        self.buf_in.clear();
        self.buf_out.clear();
    }

    /// Runs `buf_in` through each codec in turn.
    ///
    /// On success, `buf_in` holds the last codec's output and `buf_out` is
    /// empty. When `finish` is true, each codec first consumes whatever the
    /// upstream codecs flushed and is then told that input has ended.
    fn run_stages(&mut self, finish: bool) -> Result<(), GeneralError> {
        for codec in &mut self.codecs {
            // While finishing, skip `transform` when upstream flushed nothing,
            // so codecs never receive an extra empty chunk at end of input.
            if !finish || !self.buf_in.is_empty() {
                codec.transform(&self.buf_in, &mut self.buf_out)?;
            }
            if finish {
                codec.finish_input(&mut self.buf_out)?;
            }

            self.buf_in.clear();
            std::mem::swap(&mut self.buf_in, &mut self.buf_out);
        }

        Ok(())
    }

    /// Writes the pending pipeline output to `output` and empties the buffer,
    /// even if the write fails.
    fn drain_output<W: Write>(&mut self, output: &mut W) -> Result<(), GeneralError> {
        let result = output.write_all(&self.buf_in);
        self.buf_in.clear();
        result?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pipeline made of `count` identity codecs.
    fn identity_pipeline(count: usize) -> CodecPipeline {
        CodecPipeline::new(
            (0..count)
                .map(|_| Box::new(IdentityCodec) as BoxedCodec)
                .collect(),
        )
    }

    #[test]
    fn test_codec_pipeline() {
        let mut pipeline = identity_pipeline(3);
        let mut output = Vec::new();

        for chunk in [b"a", b"b", b"c"] {
            pipeline.transform(chunk, &mut output).unwrap();
        }
        pipeline.finish_input(&mut output).unwrap();

        assert_eq!(&output, b"abc");
        assert!(!pipeline.has_remaining_trailer());
    }

    #[test]
    fn test_empty_pipeline_passes_through() {
        let mut pipeline = CodecPipeline::default();
        let mut output = Vec::new();

        pipeline.transform(b"abc", &mut output).unwrap();
        pipeline.finish_input(&mut output).unwrap();

        assert_eq!(&output, b"abc");
    }

    #[test]
    fn test_identity_pipeline_has_no_trailer() {
        let mut pipeline = identity_pipeline(2);
        let mut trailer = b"keep".to_vec();

        pipeline.remaining_trailer(&mut trailer);

        assert_eq!(&trailer, b"keep");
    }
}