use std::fmt;
use std::io::{Read, Write};
/// One step of a transform chain: which algorithm was applied, and what kind of step it is.
///
/// Decoding treats a `cls` of `"encrypt"` specially (it needs a decryptor);
/// every other class is dispatched on `name`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Codec {
    /// Algorithm identifier, e.g. `"identity"`, `"gzip"`, `"zstd"`, `"zstd-rsyncable"`.
    pub name: String,
    /// Step category, e.g. `"encode"`, `"compress"`, or `"encrypt"`.
    pub cls: String,
}
/// Failure raised while encoding or decoding a codec chain.
#[derive(Debug)]
pub enum CodecError {
    /// The requested codec step cannot be performed at all.
    Unavailable {
        /// Short tag naming the cause, e.g. `"missing-key"` or `"unknown-codec"`.
        reason: &'static str,
        /// Human-readable explanation of the failure.
        detail: String,
    },
    /// The codec was recognized, but transforming the payload failed.
    Failed(String),
}

impl fmt::Display for CodecError {
    /// Renders `Unavailable` as `"<reason>: <detail>"` and `Failed` as its message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CodecError::Failed(message) => f.write_str(message),
            CodecError::Unavailable { reason, detail } => {
                f.write_str(reason)?;
                f.write_str(": ")?;
                f.write_str(detail)
            }
        }
    }
}

impl std::error::Error for CodecError {}
/// Safety cap on the output buffer used for a decoded zstd payload; a payload
/// that does not fit is rejected with [`CodecError::Failed`] instead of being
/// allocated.
const MAX_ZSTD_DECODED_SIZE: usize = 16 * 1024 * 1024;

/// Reverses a single non-encrypting codec step and returns the decoded bytes.
///
/// Supported codec names are `identity`, `gzip`, `zstd`, and `zstd-rsyncable`.
///
/// # Errors
///
/// * [`CodecError::Unavailable`] with reason `"missing-key"` when `codec.cls` is
///   `"encrypt"`, because no key is available at this level.
/// * [`CodecError::Unavailable`] with reason `"unknown-codec"` for any other name.
/// * [`CodecError::Failed`] when the payload is malformed or, for zstd, when the
///   decoded output would not fit within [`MAX_ZSTD_DECODED_SIZE`].
fn decode_one(codec: &Codec, data: &[u8]) -> Result<Vec<u8>, CodecError> {
    if codec.cls == "encrypt" {
        return Err(CodecError::Unavailable {
            reason: "missing-key",
            detail: format!("no key for encrypt codec '{}'", codec.name),
        });
    }
    match codec.name.as_str() {
        "identity" => Ok(data.to_vec()),
        "gzip" => {
            // A gzip stream may legally consist of several concatenated members
            // (RFC 1952 §2.2). `GzDecoder` stops after the first member and
            // silently drops the rest; `MultiGzDecoder` decodes every member,
            // matching how the zstd branch below accepts concatenated frames.
            let mut out = Vec::new();
            flate2::read::MultiGzDecoder::new(data)
                .read_to_end(&mut out)
                .map_err(|e| CodecError::Failed(format!("gzip decode failed: {e}")))?;
            Ok(out)
        }
        "zstd" | "zstd-rsyncable" => {
            let mut decoder = ruzstd::decoding::FrameDecoder::new();
            // Initial guess: 4x the compressed size, at least 4 KiB and at most
            // the cap. On `TargetTooSmall` the buffer is doubled (up to the cap).
            let mut capacity = data
                .len()
                .saturating_mul(4)
                .clamp(4096, MAX_ZSTD_DECODED_SIZE);
            loop {
                // Fresh buffer per attempt; the whole input is re-decoded from
                // the start on every retry.
                let mut out = Vec::with_capacity(capacity);
                match decoder.decode_all_to_vec(data, &mut out) {
                    Ok(()) => return Ok(out),
                    Err(ruzstd::decoding::errors::FrameDecoderError::TargetTooSmall) => {
                        if capacity >= MAX_ZSTD_DECODED_SIZE {
                            return Err(CodecError::Failed(
                                "zstd decode failed: decompressed size exceeds safety bound".into(),
                            ));
                        }
                        capacity = (capacity * 2).min(MAX_ZSTD_DECODED_SIZE);
                    }
                    Err(e) => return Err(CodecError::Failed(format!("zstd decode failed: {e}"))),
                }
            }
        }
        other => Err(CodecError::Unavailable {
            reason: "unknown-codec",
            detail: format!("unknown codec '{other}'"),
        }),
    }
}
/// Input block size for `zstd-rsyncable`: each slice of this many input bytes
/// is compressed into its own independent zstd frame.
const RSYNCABLE_BLOCK_SIZE: usize = 65_536;

/// Compresses `data` into a single zstd frame at ruzstd's `Fastest` level.
fn encode_zstd(data: &[u8]) -> Vec<u8> {
    ruzstd::encoding::compress_to_vec(data, ruzstd::encoding::CompressionLevel::Fastest)
}

/// Compresses `data` as a concatenation of independent zstd frames, one per
/// [`RSYNCABLE_BLOCK_SIZE`] input block.
///
/// Block boundaries sit at fixed input offsets, so an edit that changes the
/// input length shifts every later block. NOTE(review): this is not
/// content-defined chunking like `gzip --rsyncable`; confirm that is acceptable.
///
/// Empty input yields one empty frame (same as [`encode_zstd`]), so the output is
/// always a well-formed zstd stream rather than zero bytes.
fn encode_zstd_rsyncable(data: &[u8]) -> Vec<u8> {
    if data.is_empty() {
        // `chunks` yields nothing for empty input, which would produce a
        // zero-byte "stream"; emit a real (empty) frame instead.
        return encode_zstd(data);
    }
    let mut out = Vec::new();
    for block in data.chunks(RSYNCABLE_BLOCK_SIZE) {
        out.extend_from_slice(&encode_zstd(block));
    }
    out
}
fn encode_one(name: &str, data: &[u8]) -> Result<Vec<u8>, CodecError> {
match name {
"identity" => Ok(data.to_vec()),
"gzip" => {
let mut encoder = flate2::GzBuilder::new()
.mtime(0)
.write(Vec::new(), flate2::Compression::default());
encoder
.write_all(data)
.map_err(|e| CodecError::Failed(format!("gzip encode failed: {e}")))?;
encoder
.finish()
.map_err(|e| CodecError::Failed(format!("gzip encode failed: {e}")))
}
"zstd" => Ok(encode_zstd(data)),
"zstd-rsyncable" => Ok(encode_zstd_rsyncable(data)),
other => Err(CodecError::Unavailable {
reason: "unknown-codec",
detail: format!("writer cannot encode with codec '{other}'"),
}),
}
}
/// Applies every codec named in `chain` to `data`, in order, and returns the result.
///
/// An empty chain returns an unchanged copy of `data`.
///
/// # Errors
///
/// Returns the first error produced by any individual codec step; later steps
/// are not attempted.
pub fn encode_chain(chain: &[String], data: &[u8]) -> Result<Vec<u8>, CodecError> {
    chain
        .iter()
        .try_fold(data.to_vec(), |current, name| encode_one(name, &current))
}
pub fn decode_chain(chain: &[Codec], data: &[u8]) -> Result<Vec<u8>, CodecError> {
decode_chain_with_decrypt(chain, data, None)
}
/// Callback that undoes an `"encrypt"`-class codec step: it receives the codec and
/// the current bytes and returns the decrypted bytes or a [`CodecError`].
pub type Decryptor<'a> = dyn Fn(&Codec, &[u8]) -> Result<Vec<u8>, CodecError> + 'a;

/// Decodes `data` by undoing every codec in `chain`, last-applied step first.
///
/// `chain` is in encoding order, so it is walked in reverse. Steps whose class is
/// `"encrypt"` are delegated to `decrypt`; all other steps use the built-in decoders.
///
/// # Errors
///
/// * [`CodecError::Unavailable`] with reason `"missing-key"` when an `"encrypt"`
///   step is reached and `decrypt` is `None`.
/// * Any error returned by `decrypt` or by a built-in decoder, propagated unchanged;
///   remaining steps are not attempted.
pub fn decode_chain_with_decrypt(
    chain: &[Codec],
    data: &[u8],
    decrypt: Option<&Decryptor<'_>>,
) -> Result<Vec<u8>, CodecError> {
    chain
        .iter()
        .rev()
        .try_fold(data.to_vec(), |current, codec| {
            if codec.cls != "encrypt" {
                return decode_one(codec, &current);
            }
            match decrypt {
                Some(decrypt) => decrypt(codec, &current),
                None => Err(CodecError::Unavailable {
                    reason: "missing-key",
                    detail: format!("no key for encrypt codec '{}'", codec.name),
                }),
            }
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chain entry from borrowed name/class strings.
    fn make_codec(name: &str, cls: &str) -> Codec {
        Codec {
            name: name.to_string(),
            cls: cls.to_string(),
        }
    }

    /// Every codec the writer can produce must decode back to the original bytes.
    #[test]
    fn encoded_core_codecs_round_trip() {
        let payload = b"stable payload for writer transform parity".repeat(8);
        for name in ["identity", "gzip", "zstd", "zstd-rsyncable"] {
            let cls = match name {
                "identity" => "encode",
                _ => "compress",
            };
            let encoded = encode_chain(&[name.to_string()], &payload).expect("encodes");
            let decoded = decode_chain(&[make_codec(name, cls)], &encoded).expect("decodes");
            assert_eq!(decoded, payload);
        }
    }

    /// Two gzip encodings of the same payload must be byte-identical.
    #[test]
    fn gzip_encoding_is_deterministic() {
        let payload = b"stable gzip payload".repeat(16);
        let chain = ["gzip".to_string()];
        let first = encode_chain(&chain, &payload).unwrap();
        let second = encode_chain(&chain, &payload).unwrap();
        assert_eq!(first, second);
    }

    /// Independently compressed frames placed back to back must decode as one payload.
    #[test]
    fn zstd_rsyncable_decodes_concatenated_frames() {
        let blocks: [&[u8]; 2] = [
            b"first block of rsyncable data ",
            b"second block of rsyncable data",
        ];
        let mut encoded = Vec::new();
        for block in blocks {
            encoded.extend(ruzstd::encoding::compress_to_vec(
                block,
                ruzstd::encoding::CompressionLevel::Uncompressed,
            ));
        }
        let decoded = decode_one(&make_codec("zstd-rsyncable", "compress"), &encoded)
            .expect("multi-frame zstd must decode");
        assert_eq!(decoded, blocks.concat());
    }

    /// A highly compressible 2 MiB payload is intended to force the zstd output
    /// buffer past its initial capacity guess while staying under the cap.
    #[test]
    fn zstd_decode_grows_until_safety_bound() {
        let payload = vec![b'x'; 2 * 1024 * 1024];
        let encoded = encode_chain(&["zstd".to_string()], &payload).expect("zstd encodes");
        let decoded = decode_chain(&[make_codec("zstd", "compress")], &encoded)
            .expect("zstd decoder grows beyond the initial output capacity");
        assert_eq!(decoded, payload);
    }
}