use super::extension::EchExtension;
use crate::tls::Error;
use crate::tls::codec::{ExtensionType, RawExtension};
use alloc::vec::Vec;
/// Returns the encoded body of the ECH extension in its `Inner` form.
///
/// This is a thin convenience wrapper: it encodes `EchExtension::Inner` and
/// hands back whatever bytes that encoder produces.
pub fn inner_extension_body() -> Vec<u8> {
    let inner_marker = EchExtension::Inner;
    inner_marker.encode()
}
/// Encodes the body of an `ech_outer_extensions` extension.
///
/// The output is a one-byte length prefix (the number of bytes that follow)
/// and then every entry of `types`, in order, as a big-endian `u16`.
///
/// An empty `types` slice encodes to the single byte `0`. Note that
/// `decode_outer_extensions` rejects such a body, so callers are expected to
/// pass at least one type.
///
/// # Panics
///
/// Panics if `types` holds more than 127 entries. Such a list needs more than
/// 254 bytes, which the one-byte length prefix cannot describe; previously
/// the length was silently truncated and the resulting body was mis-framed.
pub(crate) fn encode_outer_extensions(types: &[ExtensionType]) -> Vec<u8> {
    // Each type takes two bytes and the prefix is a single byte.
    const MAX_TYPES: usize = (u8::MAX as usize) / 2;
    assert!(
        types.len() <= MAX_TYPES,
        "ech_outer_extensions cannot carry {} extension types (maximum is {})",
        types.len(),
        MAX_TYPES
    );

    let list_len = types.len() * 2;
    let mut body = Vec::with_capacity(1 + list_len);
    // Cannot truncate: `list_len` is at most 254 after the check above.
    body.push(list_len as u8);
    for t in types {
        body.extend_from_slice(&t.0.to_be_bytes());
    }
    body
}
/// Parses the body of an `ech_outer_extensions` extension into the list of
/// extension types it names.
///
/// The body must consist of one length byte followed by exactly that many
/// bytes, which are read as consecutive big-endian `u16` extension types.
///
/// # Errors
///
/// Returns `Error::EchDecodeError` when `body` is empty, when the declared
/// length is smaller than 2 or odd, or when it does not match the number of
/// bytes that actually follow the length byte.
pub(crate) fn decode_outer_extensions(body: &[u8]) -> Result<Vec<ExtensionType>, Error> {
    let Some((&len_byte, list)) = body.split_first() else {
        return Err(Error::EchDecodeError);
    };

    let declared = usize::from(len_byte);
    let well_formed = declared >= 2 && declared.is_multiple_of(2) && declared == list.len();
    if !well_formed {
        return Err(Error::EchDecodeError);
    }

    // `list.len()` is even here, so `chunks_exact` leaves no remainder.
    Ok(list
        .chunks_exact(2)
        .map(|pair| ExtensionType(u16::from_be_bytes([pair[0], pair[1]])))
        .collect())
}
/// Replaces a run of inner extensions with a single `ech_outer_extensions`
/// placeholder that refers to the corresponding outer extensions.
///
/// * `canonical_inner` - the uncompressed inner extension list.
/// * `outer` - the outer extension list the placeholder will refer to.
/// * `share_types` - the extension types to replace, in order.
///
/// When `share_types` is empty the inner list is returned unchanged.
/// Otherwise the types must occur as one contiguous, in-order run in both
/// `canonical_inner` and `outer`; that run in `canonical_inner` is replaced by
/// one `ExtensionType::ECH_OUTER_EXTENSIONS` entry whose body lists
/// `share_types`.
///
/// NOTE(review): only extension *types* are matched against `outer`; the
/// bodies are not compared. Presumably callers guarantee that the shared
/// extensions carry identical bodies in both lists - confirm against callers.
///
/// # Errors
///
/// Returns `Error::EchDecodeError` when `share_types` has more entries than
/// the one-byte length prefix of the placeholder body can describe (127),
/// when it fails `validate_share_types`, or when the run is not present in
/// `canonical_inner` or in `outer`.
pub(crate) fn compress_extensions(
    canonical_inner: &[RawExtension],
    outer: &[RawExtension],
    share_types: &[ExtensionType],
) -> Result<Vec<RawExtension>, Error> {
    // Two bytes per type behind a single length byte: at most 127 types fit.
    const MAX_SHARED_TYPES: usize = (u8::MAX as usize) / 2;

    if share_types.is_empty() {
        return Ok(canonical_inner.to_vec());
    }
    // Refuse lists the encoder could not frame correctly instead of emitting
    // a placeholder with a wrapped-around length byte.
    if share_types.len() > MAX_SHARED_TYPES {
        return Err(Error::EchDecodeError);
    }
    validate_share_types(share_types)?;

    let inner_start =
        find_subsequence(canonical_inner, share_types).ok_or(Error::EchDecodeError)?;
    if find_subsequence(outer, share_types).is_none() {
        return Err(Error::EchDecodeError);
    }
    let inner_end = inner_start + share_types.len();

    // The run shrinks to one placeholder entry; the subtraction cannot
    // underflow because the run was found inside `canonical_inner`.
    let mut out = Vec::with_capacity(canonical_inner.len() - share_types.len() + 1);
    out.extend_from_slice(&canonical_inner[..inner_start]);
    out.push((
        ExtensionType::ECH_OUTER_EXTENSIONS,
        encode_outer_extensions(share_types),
    ));
    out.extend_from_slice(&canonical_inner[inner_end..]);
    Ok(out)
}
/// Expands an `ech_outer_extensions` placeholder back into the extensions it
/// refers to, copying them from `outer`.
///
/// Every entry of `compressed_inner` other than the placeholder is copied to
/// the output as is. The placeholder's body is decoded, checked with
/// `validate_share_types`, resolved against `outer`, and replaced in place by
/// clones of the referenced outer extensions.
///
/// # Errors
///
/// Returns `Error::EchDecodeError` when more than one placeholder is present,
/// or propagates the error when decoding, validating or resolving the
/// placeholder's type list fails.
pub(crate) fn decompress_extensions(
    compressed_inner: &[RawExtension],
    outer: &[RawExtension],
) -> Result<Vec<RawExtension>, Error> {
    let mut expanded = Vec::with_capacity(compressed_inner.len());
    let mut placeholder_used = false;

    for (ty, body) in compressed_inner {
        if *ty == ExtensionType::ECH_OUTER_EXTENSIONS {
            // Only a single placeholder is accepted.
            if placeholder_used {
                return Err(Error::EchDecodeError);
            }
            placeholder_used = true;

            let referenced = decode_outer_extensions(body)?;
            validate_share_types(&referenced)?;
            let positions = resolve_outer_positions(outer, &referenced)?;

            // `positions` holds one index per referenced type, in the same order.
            for (expected, &idx) in referenced.iter().zip(&positions) {
                let (outer_ty, outer_body) = &outer[idx];
                debug_assert_eq!(*outer_ty, *expected);
                expanded.push((*outer_ty, outer_body.clone()));
            }
        } else {
            expanded.push((*ty, body.clone()));
        }
    }

    Ok(expanded)
}
fn validate_share_types(types: &[ExtensionType]) -> Result<(), Error> {
for (i, t) in types.iter().enumerate() {
if *t == ExtensionType::ECH_OUTER_EXTENSIONS || *t == ExtensionType::ENCRYPTED_CLIENT_HELLO
{
return Err(Error::EchDecodeError);
}
if types[..i].contains(t) {
return Err(Error::EchDecodeError);
}
}
Ok(())
}
/// Finds the first index in `haystack` at which the extension types of
/// consecutive entries equal `needle`, in order.
///
/// Only the type of each extension is compared; bodies are ignored. Returns
/// `None` when `needle` is empty, when it is longer than `haystack`, or when
/// no such contiguous run exists.
fn find_subsequence(haystack: &[RawExtension], needle: &[ExtensionType]) -> Option<usize> {
    // `windows` panics on a zero width, so the empty needle is handled first.
    if needle.is_empty() {
        return None;
    }

    // When the haystack is shorter than the needle there are no windows at
    // all and `position` yields `None`.
    haystack.windows(needle.len()).position(|run| {
        run.iter()
            .zip(needle)
            .all(|(ext, wanted)| ext.0 == *wanted)
    })
}
fn resolve_outer_positions(
outer: &[RawExtension],
types: &[ExtensionType],
) -> Result<Vec<usize>, Error> {
let mut positions = Vec::with_capacity(types.len());
let mut last = None::<usize>;
for t in types {
let mut found = None;
let start = last.map(|p| p + 1).unwrap_or(0);
for (i, (oty, _)) in outer.iter().enumerate().skip(start) {
if oty == t {
found = Some(i);
break;
}
}
let pos = found.ok_or(Error::EchDecodeError)?;
if *t == ExtensionType::ECH_OUTER_EXTENSIONS || *t == ExtensionType::ENCRYPTED_CLIENT_HELLO
{
return Err(Error::EchDecodeError);
}
positions.push(pos);
last = Some(pos);
}
Ok(positions)
}