#[cfg(feature = "std")]
use alloc::collections::VecDeque;
use alloc::vec::Vec;
use core::fmt::Debug;
#[cfg(feature = "std")]
use std::sync::Mutex;
use crate::enums::CertificateCompressionAlgorithm;
use crate::msgs::base::{Payload, PayloadU24};
use crate::msgs::codec::Codec;
use crate::msgs::handshake::{CertificatePayloadTls13, CompressedCertificatePayload};
use crate::sync::Arc;
pub fn default_cert_decompressors() -> &'static [&'static dyn CertDecompressor] {
&[
#[cfg(feature = "brotli")]
BROTLI_DECOMPRESSOR,
#[cfg(feature = "zlib")]
ZLIB_DECOMPRESSOR,
]
}
/// A certificate decompression algorithm.
///
/// Implementors must be `Debug + Send + Sync` so they can be shared as
/// `&'static dyn CertDecompressor` values.
pub trait CertDecompressor: Debug + Send + Sync {
    /// The algorithm this decompressor implements.
    fn algorithm(&self) -> CertificateCompressionAlgorithm;

    /// Decompress `input` into `output`.
    ///
    /// The implementations in this module return [`DecompressionFailed`] unless the
    /// decompressed data is exactly `output.len()` bytes long. Presumably the caller
    /// sizes `output` to the advertised uncompressed length (TODO confirm at call sites).
    fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed>;
}
pub fn default_cert_compressors() -> &'static [&'static dyn CertCompressor] {
&[
#[cfg(feature = "brotli")]
BROTLI_COMPRESSOR,
#[cfg(feature = "zlib")]
ZLIB_COMPRESSOR,
]
}
/// A certificate compression algorithm.
///
/// Implementors must be `Debug + Send + Sync` so they can be shared as
/// `&'static dyn CertCompressor` values.
pub trait CertCompressor: Debug + Send + Sync {
    /// The algorithm this compressor implements.
    fn algorithm(&self) -> CertificateCompressionAlgorithm;

    /// Compress `input`, returning the compressed bytes.
    ///
    /// `level` signals how the result will be used. [`CompressionLevel::Amortized`] is
    /// requested for results that `CompressionCache` stores and reuses, and
    /// [`CompressionLevel::Interactive`] for one-off results.
    fn compress(
        &self,
        input: Vec<u8>,
        level: CompressionLevel,
    ) -> Result<Vec<u8>, CompressionFailed>;
}
/// How much effort a [`CertCompressor`] should spend on compression.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum CompressionLevel {
    /// Requested for results that are used once and not cached. The bundled
    /// compressors use zlib's default configuration and brotli quality 4.
    Interactive,
    /// Requested for results that `CompressionCache` stores and reuses. The bundled
    /// compressors use zlib's `best_compression` configuration and brotli quality 11.
    Amortized,
}
/// Returned by [`CertDecompressor::decompress`] when the input cannot be decompressed,
/// or does not decompress to exactly the expected length.
#[derive(Debug)]
pub struct DecompressionFailed;
/// Returned when certificate compression fails. `CompressionCache` also reports a
/// poisoned cache lock this way.
#[derive(Debug)]
pub struct CompressionFailed;
#[cfg(feature = "zlib")]
mod feat_zlib_rs {
    use zlib_rs::{
        DeflateConfig, InflateConfig, ReturnCode, compress_bound, compress_slice, decompress_slice,
    };

    use super::*;

    /// Zlib certificate decompressor backed by `zlib-rs`.
    pub const ZLIB_DECOMPRESSOR: &dyn CertDecompressor = &ZlibRsDecompressor;

    /// Zlib certificate compressor backed by `zlib-rs`.
    pub const ZLIB_COMPRESSOR: &dyn CertCompressor = &ZlibRsCompressor;

    #[derive(Debug)]
    struct ZlibRsDecompressor;

    impl CertDecompressor for ZlibRsDecompressor {
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            let expected_len = output.len();
            let (written, status) = decompress_slice(output, input, InflateConfig::default());
            // Reject inflate errors as well as streams that do not exactly fill `output`.
            if status == ReturnCode::Ok && written.len() == expected_len {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }

    #[derive(Debug)]
    struct ZlibRsCompressor;

    impl CertCompressor for ZlibRsCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let config = match level {
                CompressionLevel::Interactive => DeflateConfig::default(),
                CompressionLevel::Amortized => DeflateConfig::best_compression(),
            };

            // Allocate the bound for the worst case, then trim to what was actually written.
            let mut buffer = alloc::vec![0u8; compress_bound(input.len())];
            let (written, status) = compress_slice(&mut buffer, &input, config);
            let written_len = written.len();
            match status {
                ReturnCode::Ok => {
                    buffer.truncate(written_len);
                    Ok(buffer)
                }
                _ => Err(CompressionFailed),
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }
}
#[cfg(feature = "zlib")]
pub use feat_zlib_rs::{ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR};
#[cfg(feature = "brotli")]
mod feat_brotli {
    use std::io::{Cursor, Write};

    use super::*;

    /// Buffer size passed to `brotli::CompressorWriter::new`.
    const BUFFER_SIZE: usize = 4096;
    /// Window parameter (`lgwin`) passed to `brotli::CompressorWriter::new`.
    const LGWIN: u32 = 22;
    /// Brotli quality used for [`CompressionLevel::Interactive`].
    const QUALITY_FAST: u32 = 4;
    /// Brotli quality used for [`CompressionLevel::Amortized`].
    const QUALITY_SLOW: u32 = 11;

    /// Brotli certificate decompressor.
    pub const BROTLI_DECOMPRESSOR: &dyn CertDecompressor = &BrotliDecompressor;

    /// Brotli certificate compressor.
    pub const BROTLI_COMPRESSOR: &dyn CertCompressor = &BrotliCompressor;

    #[derive(Debug)]
    struct BrotliDecompressor;

    impl CertDecompressor for BrotliDecompressor {
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            let mut source = Cursor::new(input);
            let mut sink = Cursor::new(output);
            if brotli::BrotliDecompress(&mut source, &mut sink).is_err() {
                return Err(DecompressionFailed);
            }

            // The decompressed data must exactly fill `output`.
            let written = sink.position() as usize;
            if written == sink.into_inner().len() {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    #[derive(Debug)]
    struct BrotliCompressor;

    impl CertCompressor for BrotliCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let quality = match level {
                CompressionLevel::Interactive => QUALITY_FAST,
                CompressionLevel::Amortized => QUALITY_SLOW,
            };

            let sink = Cursor::new(Vec::with_capacity(input.len() / 2));
            let mut writer = brotli::CompressorWriter::new(sink, BUFFER_SIZE, quality, LGWIN);
            if writer.write_all(&input).is_err() {
                return Err(CompressionFailed);
            }

            // `into_inner` hands back the sink holding the complete stream; the
            // round-trip tests rely on this yielding decodable output.
            let sink = writer.into_inner();
            Ok(sink.into_inner())
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }
}
#[cfg(feature = "brotli")]
pub use feat_brotli::{BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR};
/// Cache for compressed certificate payloads, or a marker that caching is off.
#[derive(Debug)]
pub enum CompressionCache {
    /// No caching: every request is compressed afresh at `CompressionLevel::Interactive`.
    Disabled,
    /// Caching enabled; only available with the `std` feature.
    #[cfg(feature = "std")]
    Enabled(CompressionCacheInner),
}
/// State of an enabled [`CompressionCache`]: a bounded least-recently-used list.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct CompressionCacheInner {
    /// Maximum number of entries retained.
    size: usize,
    /// Entries ordered from least recently used (front) to most recently used (back),
    /// guarded by a mutex.
    entries: Mutex<VecDeque<Arc<CompressionCacheEntry>>>,
}
impl CompressionCache {
#[cfg(feature = "std")]
pub fn new(size: usize) -> Self {
if size == 0 {
return Self::Disabled;
}
Self::Enabled(CompressionCacheInner {
size,
entries: Mutex::new(VecDeque::with_capacity(size)),
})
}
pub(crate) fn compression_for(
&self,
compressor: &dyn CertCompressor,
original: &CertificatePayloadTls13<'_>,
) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
match self {
Self::Disabled => Self::uncached_compression(compressor, original),
#[cfg(feature = "std")]
Self::Enabled(_) => self.compression_for_impl(compressor, original),
}
}
#[cfg(feature = "std")]
fn compression_for_impl(
&self,
compressor: &dyn CertCompressor,
original: &CertificatePayloadTls13<'_>,
) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
let (max_size, entries) = match self {
Self::Enabled(CompressionCacheInner { size, entries }) => (*size, entries),
_ => unreachable!(),
};
if !original.context.0.is_empty() {
return Self::uncached_compression(compressor, original);
}
let encoding = original.get_encoding();
let algorithm = compressor.algorithm();
let mut cache = entries
.lock()
.map_err(|_| CompressionFailed)?;
for (i, item) in cache.iter().enumerate() {
if item.algorithm == algorithm && item.original == encoding {
let item = cache.remove(i).unwrap();
cache.push_back(item.clone());
return Ok(item);
}
}
drop(cache);
let uncompressed_len = encoding.len() as u32;
let compressed = compressor.compress(encoding.clone(), CompressionLevel::Amortized)?;
let new_entry = Arc::new(CompressionCacheEntry {
algorithm,
original: encoding,
compressed: CompressedCertificatePayload {
alg: algorithm,
uncompressed_len,
compressed: PayloadU24(Payload::new(compressed)),
},
});
let mut cache = entries
.lock()
.map_err(|_| CompressionFailed)?;
if cache.len() == max_size {
cache.pop_front();
}
cache.push_back(new_entry.clone());
Ok(new_entry)
}
fn uncached_compression(
compressor: &dyn CertCompressor,
original: &CertificatePayloadTls13<'_>,
) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
let algorithm = compressor.algorithm();
let encoding = original.get_encoding();
let uncompressed_len = encoding.len() as u32;
let compressed = compressor.compress(encoding, CompressionLevel::Interactive)?;
Ok(Arc::new(CompressionCacheEntry {
algorithm,
original: Vec::new(),
compressed: CompressedCertificatePayload {
alg: algorithm,
uncompressed_len,
compressed: PayloadU24(Payload::new(compressed)),
},
}))
}
}
impl Default for CompressionCache {
    /// With `std`, a cache holding up to four entries.
    #[cfg(feature = "std")]
    fn default() -> Self {
        const DEFAULT_ENTRIES: usize = 4;
        Self::new(DEFAULT_ENTRIES)
    }

    /// Without `std` there is no cache implementation, so caching is disabled.
    #[cfg(not(feature = "std"))]
    fn default() -> Self {
        Self::Disabled
    }
}
/// A compressed certificate payload together with the key it is cached under.
// `algorithm` and `original` are only read by the cache lookup, which exists only
// with the `std` feature; hence the dead-code allowance otherwise.
#[cfg_attr(not(feature = "std"), allow(dead_code))]
#[derive(Debug)]
pub(crate) struct CompressionCacheEntry {
    /// Algorithm used to produce `compressed`; part of the cache key.
    algorithm: CertificateCompressionAlgorithm,
    /// Uncompressed encoding of the payload; part of the cache key. Left empty for
    /// entries that are never inserted into the cache.
    original: Vec<u8>,
    /// The compressed payload.
    compressed: CompressedCertificatePayload<'static>,
}
impl CompressionCacheEntry {
    /// A view of the stored compressed payload that borrows from this entry.
    pub(crate) fn compressed_cert_payload(&self) -> CompressedCertificatePayload<'_> {
        let Self { compressed, .. } = self;
        compressed.as_borrowed()
    }
}
#[cfg(all(test, any(feature = "brotli", feature = "zlib")))]
mod tests {
    use std::{println, vec};

    use super::*;

    #[test]
    #[cfg(feature = "zlib")]
    fn test_zlib() {
        test_compressor(ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR);
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli() {
        test_compressor(BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR);
    }

    /// Check that `comp` and `decomp` agree on their algorithm and round-trip a range
    /// of sizes, and that `decomp` rejects wrongly-sized outputs and garbage input.
    fn test_compressor(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        const SIZES: [usize; 6] = [16, 64, 512, 2048, 8192, 16384];

        assert_eq!(comp.algorithm(), decomp.algorithm());
        SIZES
            .iter()
            .for_each(|&len| test_trivial_pairwise(comp, decomp, len));
        test_decompress_wrong_len(comp, decomp);
        test_decompress_garbage(decomp);
    }

    /// Round-trip `plain_len` zero bytes through `comp` and `decomp` at every
    /// compression level, asserting the data comes back unchanged.
    fn test_trivial_pairwise(
        comp: &dyn CertCompressor,
        decomp: &dyn CertDecompressor,
        plain_len: usize,
    ) {
        let plain = vec![0u8; plain_len];
        let levels = [CompressionLevel::Interactive, CompressionLevel::Amortized];
        for level in levels {
            let packed = comp
                .compress(plain.clone(), level)
                .unwrap();
            println!(
                "{:?} compressed trivial {} -> {} using {:?} level",
                comp.algorithm(),
                plain.len(),
                packed.len(),
                level
            );

            // Pre-fill with non-zero bytes so a decompressor that writes nothing is caught.
            let mut unpacked = vec![0xffu8; plain_len];
            decomp
                .decompress(&packed, &mut unpacked)
                .unwrap();
            assert_eq!(plain, unpacked);
        }
    }

    /// Decompressing into a buffer one byte too long, or one byte too short, must fail.
    fn test_decompress_wrong_len(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        let plain = vec![0u8; 2048];
        let compressed = comp
            .compress(plain.clone(), CompressionLevel::Interactive)
            .unwrap();
        println!("{compressed:?}");

        for wrong_len in [plain.len() + 1, plain.len() - 1] {
            let mut unpacked = vec![0xffu8; wrong_len];
            decomp
                .decompress(&compressed, &mut unpacked)
                .unwrap_err();
        }
    }

    /// Input that is not a valid compressed stream must be rejected.
    fn test_decompress_garbage(decomp: &dyn CertDecompressor) {
        let junk = [0u8; 1024];
        let mut recovered = vec![0u8; 512];
        let outcome = decomp.decompress(&junk, &mut recovered);
        assert!(outcome.is_err());
    }

    #[test]
    #[cfg(all(feature = "brotli", feature = "zlib"))]
    fn test_cache_evicts_lru() {
        use core::sync::atomic::{AtomicBool, Ordering};

        use pki_types::CertificateDer;

        /// Wraps a compressor. `.1` records whether `compress` ran, and on drop this
        /// asserts that it matches the expectation in `.2`.
        #[derive(Debug)]
        struct RequireCompress(&'static dyn CertCompressor, AtomicBool, bool);

        impl CertCompressor for RequireCompress {
            fn compress(
                &self,
                input: Vec<u8>,
                level: CompressionLevel,
            ) -> Result<Vec<u8>, CompressionFailed> {
                self.1.store(true, Ordering::SeqCst);
                self.0.compress(input, level)
            }

            fn algorithm(&self) -> CertificateCompressionAlgorithm {
                self.0.algorithm()
            }
        }

        impl Drop for RequireCompress {
            fn drop(&mut self) {
                assert_eq!(self.1.load(Ordering::SeqCst), self.2);
            }
        }

        /// Ask `cache` for `cert` compressed by `compressor`, asserting that the real
        /// compressor runs exactly when `expect_compress` is true (a cache miss).
        fn lookup(
            cache: &CompressionCache,
            compressor: &'static dyn CertCompressor,
            cert: &CertificatePayloadTls13<'_>,
            expect_compress: bool,
        ) {
            let probe = RequireCompress(compressor, AtomicBool::default(), expect_compress);
            cache
                .compression_for(&probe, cert)
                .unwrap();
        }

        let cache = CompressionCache::default();

        let cert = CertificateDer::from(vec![1]);
        let cert1 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"1"));
        let cert2 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"2"));
        let cert3 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"3"));
        let cert4 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"4"));

        // Cache contents are shown as [least recently used .. most recently used].
        // Fill the default four-entry cache with zlib entries.
        lookup(&cache, ZLIB_COMPRESSOR, &cert1, true); // [z1]
        lookup(&cache, ZLIB_COMPRESSOR, &cert2, true); // [z1 z2]
        lookup(&cache, ZLIB_COMPRESSOR, &cert3, true); // [z1 z2 z3]
        lookup(&cache, ZLIB_COMPRESSOR, &cert4, true); // [z1 z2 z3 z4]
        // The algorithm is part of the key: brotli misses and evicts z1.
        lookup(&cache, BROTLI_COMPRESSOR, &cert4, true); // [z2 z3 z4 b4]
        // Hits reorder the cache without compressing.
        lookup(&cache, ZLIB_COMPRESSOR, &cert2, false); // [z3 z4 b4 z2]
        lookup(&cache, ZLIB_COMPRESSOR, &cert3, false); // [z4 b4 z2 z3]
        lookup(&cache, ZLIB_COMPRESSOR, &cert4, false); // [b4 z2 z3 z4]
        lookup(&cache, BROTLI_COMPRESSOR, &cert4, false); // [z2 z3 z4 b4]
        // z1 was evicted earlier, so it is recompressed, evicting z2.
        lookup(&cache, ZLIB_COMPRESSOR, &cert1, true); // [z3 z4 b4 z1]
        lookup(&cache, ZLIB_COMPRESSOR, &cert4, false); // [z3 b4 z1 z4]
        lookup(&cache, ZLIB_COMPRESSOR, &cert3, false); // [b4 z1 z4 z3]
        lookup(&cache, ZLIB_COMPRESSOR, &cert1, false); // [b4 z4 z3 z1]
        // b4 is now least recently used and is evicted to make room for b1.
        lookup(&cache, BROTLI_COMPRESSOR, &cert1, true); // [z4 z3 z1 b1]
        lookup(&cache, BROTLI_COMPRESSOR, &cert4, true); // [z3 z1 b1 b4]
    }
}