use crate::tls::Error;
use crate::tls::codec::hs_type;
use crate::tls::codec::{ReadCursor, put_u16, with_len_u8, with_len_u24};
use alloc::vec::Vec;
/// Largest `uncompressed_length` (in bytes, 256 KiB) that
/// `decode_compressed_certificate` will accept from a peer.
///
/// The declared length is compared against this bound before the output buffer
/// is allocated or any decompression is attempted.
#[doc(hidden)]
pub const MAX_UNCOMPRESSED_BYTES: u32 = 1 << 18;
/// Certificate compression algorithm identifiers, as carried in the
/// `compress_certificate` extension and the CompressedCertificate message
/// (codepoints registered by RFC 8879).
pub(crate) mod algorithm {
    /// zlib: the only identifier this module can actually (de)compress.
    pub(crate) const ZLIB: u16 = 0x0001;

    /// brotli: recognised, but not implemented here.
    // Not referenced by non-test code in this file.
    #[allow(dead_code)]
    pub(crate) const BROTLI: u16 = 0x0002;

    /// zstd: recognised, but not implemented here.
    // Not referenced by non-test code in this file.
    #[allow(dead_code)]
    pub(crate) const ZSTD: u16 = 0x0003;
}
/// The algorithm list this endpoint enables when nothing else is configured.
///
/// It currently contains only zlib, the single algorithm `supports` accepts.
pub(crate) fn default_algorithms() -> Vec<u16> {
    [algorithm::ZLIB].to_vec()
}
/// Returns `true` when this module can compress and decompress with
/// `algorithm`. Only zlib is implemented.
pub(crate) fn supports(algorithm: u16) -> bool {
    matches!(algorithm, algorithm::ZLIB)
}
/// Chooses an algorithm from the peer's `offered` list.
///
/// `offered` is scanned in the order the peer sent it. The first entry that is
/// both implemented here (per `supports`) and enabled in `local` wins.
/// Returns `None` when no entry qualifies.
pub(crate) fn pick_from_lists(offered: &[u16], local: &[u16]) -> Option<u16> {
    for &candidate in offered {
        if supports(candidate) && local.contains(&candidate) {
            return Some(candidate);
        }
    }
    None
}
/// Encodes the body of the `compress_certificate` extension: a one-byte
/// length prefix followed by big-endian `u16` algorithm identifiers.
///
/// At most 127 identifiers are written, so the list's byte length (127 * 2 =
/// 254) stays an even value that fits the one-byte prefix. Extra entries are
/// silently dropped.
///
/// An empty `algorithms` slice yields the single byte `0x00`, which
/// `decode_extension` rejects as malformed.
pub(crate) fn encode_extension(algorithms: &[u16]) -> Vec<u8> {
    const MAX_ENTRIES: usize = 127;

    let mut encoded = Vec::with_capacity(1 + algorithms.len() * 2);
    with_len_u8(&mut encoded, |list| {
        for &alg in algorithms.iter().take(MAX_ENTRIES) {
            put_u16(list, alg);
        }
    });
    encoded
}
/// Decodes a `compress_certificate` extension body into its algorithm
/// identifiers, in the order the peer listed them.
///
/// Identifiers are read as big-endian `u16` values. Unknown identifiers are
/// returned as-is and are not filtered here.
///
/// # Errors
///
/// Returns [`Error::Decode`] if the list is empty or has an odd byte length.
/// Errors from reading the length-prefixed list, or from trailing bytes after
/// it, are propagated unchanged from the cursor.
#[doc(hidden)]
pub fn decode_extension(body: &[u8]) -> Result<Vec<u16>, Error> {
    let mut reader = ReadCursor::new(body);
    let raw = reader.vec_u8()?;
    reader.expect_empty()?;

    let well_formed = !raw.is_empty() && raw.len() % 2 == 0;
    if !well_formed {
        return Err(Error::Decode);
    }

    Ok(raw
        .chunks_exact(2)
        .map(|pair| u16::from_be_bytes([pair[0], pair[1]]))
        .collect())
}
/// Builds a complete CompressedCertificate handshake message from an
/// already-encoded Certificate message body.
///
/// The output is: handshake type byte, u24 body length, then the body. The
/// body is `algorithm` (u16), `uncompressed_length` (u24), and the compressed
/// payload with its own u24 length prefix.
///
/// # Errors
///
/// Returns [`Error::IllegalParameter`] in any of these cases:
/// - `algorithm` is not supported.
/// - The uncompressed body does not fit a u24 length.
/// - The resulting handshake body does not fit a u24 length.
/// - Compression fails.
pub(crate) fn encode_compressed_certificate(
    algorithm: u16,
    certificate_message_body: &[u8],
) -> Result<Vec<u8>, Error> {
    /// Largest value representable in a TLS `uint24` length field.
    const U24_MAX: usize = 0x00FF_FFFF;
    /// Handshake header: msg_type (1) + u24 body length (3).
    const HEADER_LEN: usize = 1 + 3;
    /// Fixed part of the body: algorithm (2) + uncompressed_length (3) +
    /// u24 length prefix of the compressed payload (3).
    const BODY_FIXED_LEN: usize = 2 + 3 + 3;

    if !supports(algorithm) {
        return Err(Error::IllegalParameter);
    }
    let uncompressed_length: u32 = certificate_message_body
        .len()
        .try_into()
        .map_err(|_| Error::IllegalParameter)?;
    if uncompressed_length > 0x00FF_FFFF {
        return Err(Error::IllegalParameter);
    }

    let compressed = zlib_compress(certificate_message_body)?;
    // zlib output can be slightly larger than its input. A body that passed
    // the check above can therefore still produce a handshake body too long
    // for the u24 length field. Reject it rather than emit a frame whose
    // length prefixes cannot describe its contents.
    if compressed.len() > U24_MAX - BODY_FIXED_LEN {
        return Err(Error::IllegalParameter);
    }

    let mut msg = Vec::with_capacity(HEADER_LEN + BODY_FIXED_LEN + compressed.len());
    msg.push(hs_type::COMPRESSED_CERTIFICATE);
    with_len_u24(&mut msg, |b| {
        put_u16(b, algorithm);
        // The low three bytes of the big-endian u32 are the u24 wire encoding.
        b.extend_from_slice(&uncompressed_length.to_be_bytes()[1..]);
        with_len_u24(b, |c| c.extend_from_slice(&compressed));
    });
    Ok(msg)
}
#[doc(hidden)]
pub fn decode_compressed_certificate(body: &[u8]) -> Result<Vec<u8>, Error> {
let mut c = ReadCursor::new(body);
let algorithm = c.u16().map_err(|_| Error::CertDecompressionFailed)?;
let uncompressed_length_u32 = c.u24().map_err(|_| Error::CertDecompressionFailed)? as u32;
let compressed = c.vec_u24().map_err(|_| Error::CertDecompressionFailed)?;
c.expect_empty()
.map_err(|_| Error::CertDecompressionFailed)?;
if !supports(algorithm) {
return Err(Error::CertDecompressionFailed);
}
if uncompressed_length_u32 > MAX_UNCOMPRESSED_BYTES {
return Err(Error::CertDecompressionFailed);
}
let out = zlib_decompress_capped(compressed, uncompressed_length_u32 as usize)?;
if out.len() != uncompressed_length_u32 as usize {
return Err(Error::CertDecompressionFailed);
}
Ok(out)
}
fn zlib_compress(input: &[u8]) -> Result<Vec<u8>, Error> {
compcol::vec::compress_to_vec::<compcol::zlib::Zlib>(input)
.map_err(|_| Error::IllegalParameter)
}
fn zlib_decompress_capped(input: &[u8], cap: usize) -> Result<Vec<u8>, Error> {
use compcol::limit::LimitedDecoder;
use compcol::{Decoder, Status};
let inner = compcol::zlib::Decoder::new();
let mut dec = LimitedDecoder::new(inner, cap as u64);
let mut out = alloc::vec![0u8; cap];
let mut input_pos = 0usize;
let mut output_pos = 0usize;
let mut input_drained = false;
loop {
if !input_drained {
let (progress, status) = dec
.decode(&input[input_pos..], &mut out[output_pos..])
.map_err(|_| Error::CertDecompressionFailed)?;
input_pos += progress.consumed;
output_pos += progress.written;
match status {
Status::StreamEnd => break,
Status::OutputFull => {
return Err(Error::CertDecompressionFailed);
}
Status::InputEmpty => {
input_drained = true;
}
}
} else {
let (progress, status) = dec
.finish(&mut out[output_pos..])
.map_err(|_| Error::CertDecompressionFailed)?;
output_pos += progress.written;
match status {
Status::StreamEnd => break,
Status::OutputFull => {
return Err(Error::CertDecompressionFailed);
}
Status::InputEmpty => {
return Err(Error::CertDecompressionFailed);
}
}
}
}
out.truncate(output_pos);
Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extension_codec_round_trip() {
        let advert = alloc::vec![algorithm::ZLIB];
        let body = encode_extension(&advert);
        assert_eq!(body, alloc::vec![0x02, 0x00, 0x01]);
        let decoded = decode_extension(&body).expect("decode");
        assert_eq!(decoded, advert);
    }

    #[test]
    fn extension_codec_multiple_algorithms() {
        let advert = alloc::vec![algorithm::ZLIB, algorithm::BROTLI, algorithm::ZSTD];
        let body = encode_extension(&advert);
        let decoded = decode_extension(&body).expect("decode");
        assert_eq!(decoded, advert);
    }

    #[test]
    fn extension_encode_caps_list_at_127_entries() {
        // The list's byte length must be an even value that fits the one-byte
        // prefix, so entries beyond the 127th are dropped.
        let advert: Vec<u16> = (0..200).collect();
        let body = encode_extension(&advert);
        assert_eq!(body.len(), 1 + 254);
        assert_eq!(body[0], 254);
        let decoded = decode_extension(&body).expect("decode");
        assert_eq!(decoded, &advert[..127]);
    }

    #[test]
    fn extension_decode_rejects_empty_list() {
        let body = alloc::vec![0x00];
        assert!(matches!(decode_extension(&body), Err(Error::Decode)));
    }

    #[test]
    fn extension_decode_rejects_odd_length() {
        let body = alloc::vec![0x03, 0x00, 0x01, 0x00];
        assert!(matches!(decode_extension(&body), Err(Error::Decode)));
    }

    #[test]
    fn extension_decode_rejects_trailing_garbage() {
        let body = alloc::vec![0x02, 0x00, 0x01, 0xAA];
        assert!(matches!(decode_extension(&body), Err(Error::Decode)));
    }

    #[test]
    fn pick_returns_zlib_when_offered() {
        let local = default_algorithms();
        assert_eq!(
            pick_from_lists(&[algorithm::ZLIB], &local),
            Some(algorithm::ZLIB)
        );
        assert_eq!(
            pick_from_lists(&[algorithm::BROTLI, algorithm::ZLIB], &local),
            Some(algorithm::ZLIB)
        );
    }

    #[test]
    fn pick_returns_none_with_no_overlap() {
        let local = default_algorithms();
        assert_eq!(
            pick_from_lists(&[algorithm::BROTLI, algorithm::ZSTD], &local),
            None
        );
        assert_eq!(pick_from_lists(&[], &local), None);
        assert_eq!(pick_from_lists(&[42, 9000], &local), None);
    }

    #[test]
    fn pick_returns_none_when_local_lacks_algorithm() {
        assert_eq!(pick_from_lists(&[algorithm::ZLIB], &[]), None);
        assert_eq!(
            pick_from_lists(&[algorithm::ZLIB], &[algorithm::BROTLI]),
            None
        );
    }

    #[test]
    fn compressed_certificate_round_trip() {
        // A Certificate message body with an empty request context and a
        // certificate_list holding one 1024-byte entry with no extensions.
        let mut cert_body = Vec::new();
        cert_body.push(0);
        with_len_u24(&mut cert_body, |list| {
            with_len_u24(list, |c| {
                for i in 0..1024 {
                    c.push((i % 251) as u8);
                }
            });
            list.extend_from_slice(&[0, 0]);
        });

        let msg = encode_compressed_certificate(algorithm::ZLIB, &cert_body).expect("encode");
        // 25 is the compressed_certificate handshake type.
        assert_eq!(msg[0], 25);
        let declared_msg_len =
            ((msg[1] as usize) << 16) | ((msg[2] as usize) << 8) | msg[3] as usize;
        assert_eq!(declared_msg_len, msg.len() - 4);

        let recovered = decode_compressed_certificate(&msg[4..]).expect("decode");
        assert_eq!(recovered, cert_body);
        assert!(
            msg.len() < cert_body.len(),
            "compressed wire ({}) should be smaller than cert body ({}) for repeating data",
            msg.len(),
            cert_body.len()
        );
    }

    #[test]
    fn decode_rejects_unsupported_algorithm() {
        let mut body = Vec::new();
        put_u16(&mut body, algorithm::BROTLI);
        // uncompressed_length = 8
        body.extend_from_slice(&[0x00, 0x00, 0x08]);
        with_len_u24(&mut body, |b| b.extend_from_slice(b"junk"));
        assert!(matches!(
            decode_compressed_certificate(&body),
            Err(Error::CertDecompressionFailed)
        ));
    }

    #[test]
    fn decode_rejects_uncompressed_length_over_cap() {
        let mut body = Vec::new();
        put_u16(&mut body, algorithm::ZLIB);
        let over = MAX_UNCOMPRESSED_BYTES + 1;
        body.extend_from_slice(&over.to_be_bytes()[1..]);
        with_len_u24(&mut body, |b| b.extend_from_slice(b"junk"));
        assert!(matches!(
            decode_compressed_certificate(&body),
            Err(Error::CertDecompressionFailed)
        ));
    }

    #[test]
    fn decode_rejects_length_mismatch() {
        let inner = b"abcdefgh";
        let compressed = zlib_compress(inner).expect("compress");
        let mut body = Vec::new();
        put_u16(&mut body, algorithm::ZLIB);
        // Declares 9 bytes although the stream inflates to 8.
        body.extend_from_slice(&[0x00, 0x00, 0x09]);
        with_len_u24(&mut body, |b| b.extend_from_slice(&compressed));
        assert!(matches!(
            decode_compressed_certificate(&body),
            Err(Error::CertDecompressionFailed)
        ));
    }

    #[test]
    fn decode_rejects_trailing_bytes_after_compressed_payload() {
        let inner = b"abcdefgh";
        let compressed = zlib_compress(inner).expect("compress");
        let mut body = Vec::new();
        put_u16(&mut body, algorithm::ZLIB);
        body.extend_from_slice(&(inner.len() as u32).to_be_bytes()[1..]);
        with_len_u24(&mut body, |b| b.extend_from_slice(&compressed));
        // The untampered body decodes.
        assert_eq!(decode_compressed_certificate(&body).expect("decode"), inner);

        body.push(0xAA);
        assert!(matches!(
            decode_compressed_certificate(&body),
            Err(Error::CertDecompressionFailed)
        ));
    }

    #[test]
    fn decode_rejects_truncated_compressed_stream() {
        let inner = b"the quick brown fox jumps over the lazy dog";
        let compressed = zlib_compress(inner).expect("compress");
        let truncated = &compressed[..compressed.len() / 2];
        let mut body = Vec::new();
        put_u16(&mut body, algorithm::ZLIB);
        body.extend_from_slice(&(inner.len() as u32).to_be_bytes()[1..]);
        with_len_u24(&mut body, |b| b.extend_from_slice(truncated));
        assert!(matches!(
            decode_compressed_certificate(&body),
            Err(Error::CertDecompressionFailed)
        ));
    }

    #[test]
    fn decode_rejects_zlib_bomb_attempting_to_exceed_cap() {
        let big = alloc::vec![0xABu8; 4096];
        let compressed = zlib_compress(&big).expect("compress");
        let mut body = Vec::new();
        put_u16(&mut body, algorithm::ZLIB);
        // Declares only 16 bytes for a stream that inflates to 4096.
        body.extend_from_slice(&[0x00, 0x00, 0x10]);
        with_len_u24(&mut body, |b| b.extend_from_slice(&compressed));
        assert!(matches!(
            decode_compressed_certificate(&body),
            Err(Error::CertDecompressionFailed)
        ));
    }

    #[test]
    fn encode_compressed_certificate_rejects_unsupported_algorithm() {
        assert!(matches!(
            encode_compressed_certificate(algorithm::BROTLI, b"hello"),
            Err(Error::IllegalParameter)
        ));
    }
}