titan-api-codec 1.2.9

Helpers for encoding and decoding Titan API messages
Documentation
//! Defines [`BinaryTransform`] implementations that use [zstd] to compress and decompress data.
//!
//! [zstd]: https://github.com/facebook/zstd

use super::common::BinaryTransform;

use bytes::{Buf, Bytes};
use zstd::bulk::{compress, decompress, Compressor, Decompressor};
use zstd::stream::decode_all;

/// Transform that compresses its input with zstd.
///
/// Use [`ZstdCompressor::new`] to choose a compression level; the derived
/// [`Default`] uses level `0`.
#[derive(Default)]
pub struct ZstdCompressor {
    /// Level handed to zstd by the one-shot `transform` path.
    level: i32,
    /// Reusable zstd compression context driven by `transform_mut`.
    inner: Compressor<'static>,
}

impl ZstdCompressor {
    /// Creates a new compressor that compresses at the given zstd `level`.
    ///
    /// The same level is used by both `transform` (one-shot compression) and
    /// `transform_mut` (reusable context). A `level` of `0` matches what the
    /// derived [`Default`] produces.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying zstd compression context cannot be
    /// created or configured.
    pub fn new(level: i32) -> std::io::Result<Self> {
        // `Compressor::new` configures the level directly; no need to go
        // through `set_dictionary` with an empty dictionary.
        let inner = Compressor::new(level)?;
        Ok(Self { level, inner })
    }
}

impl BinaryTransform for ZstdCompressor {
    /// Compresses `data` at `self.level` through the free [`compress`]
    /// function, so no mutable state is required.
    fn transform(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
        let packed = compress(&data, self.level)?;
        Ok(Bytes::from(packed))
    }

    /// Compresses `data` with this transform's reusable zstd context.
    fn transform_mut(&mut self, data: Bytes) -> Result<Bytes, std::io::Error> {
        let packed = self.inner.compress(&data)?;
        Ok(Bytes::from(packed))
    }
}

/// Transform that restores zstd-compressed data to its original bytes.
#[derive(Default)]
pub struct ZstdDecompressor {
    /// Reusable zstd decompression context driven by `transform_mut`.
    inner: Decompressor<'static>,
}

impl ZstdCompressor {}

/// Largest output buffer, in bytes, that is preallocated up front from the
/// decompressed-size bound a zstd frame reports.
///
/// A frame header may declare an arbitrary content size, and one-shot
/// decompression preallocates the full capacity it is given. Trusting the
/// bound unconditionally would let a tiny crafted message request a
/// multi-gigabyte allocation. Inputs whose bound exceeds this limit are
/// stream-decoded instead, which does not size its output from the header.
const MAX_PREALLOCATED_CAPACITY: usize = 16 * 1024 * 1024;

/// Returns the capacity to preallocate for one-shot decompression of `data`.
///
/// Returns `None` when streaming decode should be used instead: either zstd
/// cannot bound the output size, or the bound exceeds
/// [`MAX_PREALLOCATED_CAPACITY`].
fn preallocation_capacity(data: &[u8]) -> Option<usize> {
    let bound = zstd_safe::decompress_bound(data).ok()?;
    if bound <= MAX_PREALLOCATED_CAPACITY as u64 {
        // Lossless: `bound` was just checked against a value that fits `usize`.
        Some(bound as usize)
    } else {
        None
    }
}

impl BinaryTransform for ZstdDecompressor {
    /// Decompresses `data` without using this transform's reusable context.
    fn transform(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
        match preallocation_capacity(&data) {
            Some(capacity) => decompress(&data, capacity).map(Bytes::from),
            // Size unknown or too large to trust: fall back to stream decoding.
            None => decode_all(data.reader()).map(Bytes::from),
        }
    }

    /// Decompresses `data` with this transform's reusable zstd context when
    /// the output size can be safely preallocated.
    fn transform_mut(&mut self, data: Bytes) -> Result<Bytes, std::io::Error> {
        match preallocation_capacity(&data) {
            Some(capacity) => self.inner.decompress(&data, capacity).map(Bytes::from),
            // Size unknown or too large to trust: fall back to stream decoding.
            None => decode_all(data.reader()).map(Bytes::from),
        }
    }
}

#[cfg(test)]
mod test {
    use super::{ZstdCompressor, ZstdDecompressor};
    use crate::transform::BinaryTransform;
    use bytes::Bytes;
    use lipsum::lipsum;

    /// Explicit compression levels exercised through `ZstdCompressor::new`.
    const LEVELS: &[i32] = &[1, 3, 9];

    /// Compressible lorem-ipsum payload shared by the tests.
    fn sample_data() -> Bytes {
        Bytes::from(lipsum(1000))
    }

    #[test]
    fn test_roundtrip_default() {
        let compressor = ZstdCompressor::default();
        let decompressor = ZstdDecompressor::default();

        let data = sample_data();

        let compressed = compressor
            .transform(data.clone())
            .expect("should compress via zstd");
        let uncompressed = decompressor
            .transform(compressed)
            .expect("should decompress from zstd");

        assert_eq!(data, uncompressed);
    }

    #[test]
    fn test_roundtrip_mut_default() {
        let mut compressor = ZstdCompressor::default();
        let mut decompressor = ZstdDecompressor::default();

        let data = sample_data();

        let compressed = compressor
            .transform_mut(data.clone())
            .expect("should compress via zstd");
        let uncompressed = decompressor
            .transform_mut(compressed)
            .expect("should decompress from zstd");

        assert_eq!(data, uncompressed);
    }

    #[test]
    fn test_roundtrip_with_levels() {
        let mut decompressor = ZstdDecompressor::default();
        let data = sample_data();

        for &level in LEVELS {
            let mut compressor =
                ZstdCompressor::new(level).expect("should create zstd compressor");

            // One-shot path uses the stored level.
            let compressed = compressor
                .transform(data.clone())
                .expect("should compress via zstd");
            let uncompressed = decompressor
                .transform(compressed)
                .expect("should decompress from zstd");
            assert_eq!(data, uncompressed, "transform roundtrip at level {}", level);

            // Reusable-context path uses the level configured on `inner`.
            let compressed = compressor
                .transform_mut(data.clone())
                .expect("should compress via zstd");
            let uncompressed = decompressor
                .transform_mut(compressed)
                .expect("should decompress from zstd");
            assert_eq!(
                data, uncompressed,
                "transform_mut roundtrip at level {}",
                level
            );
        }
    }

    #[test]
    fn test_compressor_adds_size() {
        let mut compressor = ZstdCompressor::default();

        let data = sample_data();

        let compressed = compressor
            .transform_mut(data.clone())
            .expect("should compress via zstd");

        let size_bound = zstd_safe::decompress_bound(&compressed)
            .expect("should be able to determine decompress bound");
        assert_eq!(size_bound as usize, data.len());
    }
}