use super::config::{EchConfig, HpkeSymCipherSuite};
use super::extension::{EchExtension, decode_outer_position};
use super::hpke_setup::{map_sym_suite, setup_receiver, setup_sender};
use crate::hpke::{ReceiverContext, SenderContext};
use crate::rng::RngCore;
use crate::tls::Error;
use alloc::vec::Vec;
/// Expected number of bytes HPKE `seal` adds to a plaintext (the AEAD tag).
///
/// Used to size the zero-filled ECH payload placeholder and to check the ciphertext length
/// after sealing.
/// NOTE(review): 16 matches the AES-GCM / ChaCha20-Poly1305 tag size — confirm every AEAD
/// accepted by `map_sym_suite` really produces a 16-byte tag.
pub(crate) const HPKE_TAG_LEN: usize = 16;
/// Pads an encoded ClientHelloInner so its length depends less on the real inner SNI.
///
/// Padding is added in two steps:
/// 1. Reserve `maximum_name_length - inner_sni_len` extra bytes, saturating at zero.
/// 2. Round the total up to a multiple of 32 bytes, with 32 bytes as the minimum.
///
/// The input bytes come first and all padding bytes are zero.
///
/// NOTE(review): the ECH spec pads by `maximum_name_length + 9` when the inner ClientHello
/// has no `server_name` extension at all. This signature cannot express that case —
/// confirm how callers account for it.
pub(crate) fn pad_inner(
    encoded_inner: &[u8],
    inner_sni_len: usize,
    maximum_name_length: u8,
) -> Vec<u8> {
    const BLOCK: usize = 32;
    let name_slack = usize::from(maximum_name_length).saturating_sub(inner_sni_len);
    let rounded = (encoded_inner.len() + name_slack).next_multiple_of(BLOCK);
    // A zero-length result would leak that the input was empty; emit one full block instead.
    let padded_len = if rounded == 0 { BLOCK } else { rounded };
    let mut padded = Vec::with_capacity(padded_len);
    padded.extend_from_slice(encoded_inner);
    padded.resize(padded_len, 0u8);
    padded
}
/// Output of `seal_with`: the finished outer ClientHello plus the HPKE sender context.
pub(crate) struct SealedOuter {
/// Outer ClientHello handshake message whose ECH payload now holds the HPKE ciphertext.
pub outer_ch: Vec<u8>,
/// HPKE sender context used for the seal, returned so the caller can keep using it
/// (presumably for a ClientHello sent after HelloRetryRequest — confirm against callers).
pub sender: SenderContext,
}
pub(crate) fn seal_into_skeleton(
sender: &mut SenderContext,
outer_ch_skeleton: Vec<u8>,
padded_inner: &[u8],
) -> Result<Vec<u8>, Error> {
let (start, len) = locate_payload_in_handshake(&outer_ch_skeleton)?;
if len != padded_inner.len() + HPKE_TAG_LEN {
return Err(Error::EchDecodeError);
}
let aad = outer_ch_skeleton.clone();
let ciphertext = sender
.seal(&aad, padded_inner)
.map_err(|_| Error::EchDecryptionFailed)?;
if ciphertext.len() != len {
return Err(Error::EchDecryptionFailed);
}
let mut outer_ch = outer_ch_skeleton;
outer_ch[start..start + len].copy_from_slice(&ciphertext);
Ok(outer_ch)
}
/// Finds the ECH payload field inside a ClientHello handshake message.
///
/// `handshake_msg` must be a full handshake message: a 1-byte type equal to
/// `CLIENT_HELLO`, then a 3-byte body length that matches the remaining bytes exactly,
/// then the body. The body is walked past the fixed-size and length-prefixed fields up to
/// the extension list. That list must contain exactly one `encrypted_client_hello`
/// extension.
///
/// Returns `(offset, len)` of that extension's payload. `offset` is measured from the start
/// of `handshake_msg`, handshake header included. The span is checked to lie within the
/// extension body, so callers can slice `handshake_msg` with it without panicking.
///
/// # Errors
/// `EchDecodeError` in any of these cases:
/// - truncation or a length mismatch;
/// - malformed extension framing;
/// - a missing or duplicated ECH extension;
/// - a payload span reported outside the extension body.
///
/// Errors from `decode_outer_position` are propagated unchanged.
pub(crate) fn locate_payload_in_handshake(handshake_msg: &[u8]) -> Result<(usize, usize), Error> {
    if handshake_msg.len() < 4 || handshake_msg[0] != crate::tls::codec::hs_type::CLIENT_HELLO {
        return Err(Error::EchDecodeError);
    }
    let body_len = ((handshake_msg[1] as usize) << 16)
        | ((handshake_msg[2] as usize) << 8)
        | (handshake_msg[3] as usize);
    if 4 + body_len != handshake_msg.len() {
        return Err(Error::EchDecodeError);
    }
    let body = &handshake_msg[4..];
    let mut idx = 0usize;
    // Fails unless `n` more bytes are available in `body` starting at `idx`.
    let need = |idx: usize, n: usize| -> Result<(), Error> {
        if idx + n > body.len() {
            Err(Error::EchDecodeError)
        } else {
            Ok(())
        }
    };
    // legacy_version (2 bytes) + random (32 bytes).
    need(idx, 2 + 32)?;
    idx += 2 + 32;
    // legacy_session_id: 1-byte length prefix.
    need(idx, 1)?;
    let sid_len = body[idx] as usize;
    idx += 1;
    need(idx, sid_len)?;
    idx += sid_len;
    // cipher_suites: 2-byte length prefix.
    need(idx, 2)?;
    let cs_len = ((body[idx] as usize) << 8) | (body[idx + 1] as usize);
    idx += 2;
    need(idx, cs_len)?;
    idx += cs_len;
    // legacy_compression_methods: 1-byte length prefix.
    need(idx, 1)?;
    let cm_len = body[idx] as usize;
    idx += 1;
    need(idx, cm_len)?;
    idx += cm_len;
    // extensions: 2-byte length prefix.
    need(idx, 2)?;
    let ext_total = ((body[idx] as usize) << 8) | (body[idx + 1] as usize);
    idx += 2;
    need(idx, ext_total)?;
    let ext_end = idx + ext_total;
    let mut p = idx;
    let mut found: Option<(usize, usize)> = None;
    while p < ext_end {
        // Each extension is type (2) + length (2) + body.
        if p + 4 > ext_end {
            return Err(Error::EchDecodeError);
        }
        let ty = ((body[p] as u16) << 8) | (body[p + 1] as u16);
        let bl = ((body[p + 2] as usize) << 8) | (body[p + 3] as usize);
        let ext_body_start = p + 4;
        let ext_body_end = ext_body_start + bl;
        if ext_body_end > ext_end {
            return Err(Error::EchDecodeError);
        }
        if ty == crate::tls::codec::ExtensionType::ENCRYPTED_CLIENT_HELLO.0 {
            if found.is_some() {
                return Err(Error::EchDecodeError);
            }
            let ext_body = &body[ext_body_start..ext_body_end];
            let (pay_off, pay_len) = decode_outer_position(ext_body)?;
            // Never hand out a span that escapes the extension body. Callers slice
            // attacker-controlled messages with it, and a bad span would panic.
            let in_bounds = pay_off
                .checked_add(pay_len)
                .is_some_and(|end| end <= ext_body.len());
            if !in_bounds {
                return Err(Error::EchDecodeError);
            }
            // Convert the extension-body-relative offset to a message-relative one
            // (the +4 skips the handshake header).
            found = Some((4 + ext_body_start + pay_off, pay_len));
        }
        p = ext_body_end;
    }
    found.ok_or(Error::EchDecodeError)
}
/// Encodes the body of an outer `encrypted_client_hello` extension whose payload is a
/// zero-filled placeholder.
///
/// The placeholder is `padded_inner_len + HPKE_TAG_LEN` bytes long. That is the payload
/// length `seal_into_skeleton` requires for a padded inner of `padded_inner_len` bytes.
pub(crate) fn build_outer_ext_body(
    sym: HpkeSymCipherSuite,
    config_id: u8,
    enc: &[u8],
    padded_inner_len: usize,
) -> Vec<u8> {
    let placeholder = alloc::vec![0u8; padded_inner_len + HPKE_TAG_LEN];
    EchExtension::Outer {
        cipher_suite: sym,
        config_id,
        enc: Vec::from(enc),
        payload: placeholder,
    }
    .encode()
}
pub(crate) fn seal_with<R, F>(
config: &EchConfig,
sym: HpkeSymCipherSuite,
encoded_inner: &[u8],
inner_sni_len: usize,
rng: &mut R,
caller_build_outer_skeleton: F,
) -> Result<SealedOuter, Error>
where
R: RngCore,
F: FnOnce(&[u8], usize) -> Vec<u8>,
{
let contents = config.contents.as_ref().ok_or(Error::EchDecodeError)?;
let padded = pad_inner(encoded_inner, inner_sni_len, contents.maximum_name_length);
let (enc, mut sender, _suite) = setup_sender(rng, config, sym)?;
let skeleton = caller_build_outer_skeleton(&enc, padded.len());
let outer_ch = seal_into_skeleton(&mut sender, skeleton, &padded)?;
Ok(SealedOuter { outer_ch, sender })
}
/// Successful result of `try_decap_inner`. It is also the state `try_decap_inner_retry`
/// needs to open a later ClientHello.
pub(crate) struct DecappedInner {
/// Decrypted inner ClientHello handshake message with trailing zero padding removed.
pub inner_ch_bytes: Vec<u8>,
/// HPKE receiver context; `try_decap_inner_retry` reuses it instead of re-encapsulating.
pub receiver: ReceiverContext,
/// Symmetric cipher suite advertised in the outer ECH extension.
pub sym: HpkeSymCipherSuite,
/// ECH config id advertised in the outer ECH extension.
pub config_id: u8,
}
pub(crate) fn try_decap_inner(
handshake_msg: &[u8],
keys: &super::keys::EchKeyRing,
) -> Result<DecappedInner, Error> {
let (payload_off, payload_len) = locate_payload_in_handshake(handshake_msg)?;
let mut aad = handshake_msg.to_vec();
for b in aad[payload_off..payload_off + payload_len].iter_mut() {
*b = 0;
}
let ciphertext = handshake_msg[payload_off..payload_off + payload_len].to_vec();
let (sym, config_id, enc) = extract_outer_meta(handshake_msg)?;
let pair = keys
.find_by_config_id(config_id)
.ok_or(Error::EchDecryptionFailed)?;
let (kdf, aead) = map_sym_suite(sym)?;
if !pair.accepts(kdf, aead) {
return Err(Error::EchDecryptionFailed);
}
let (mut receiver, _suite) =
setup_receiver(pair.config(), pair.private_key_bytes(), &enc, sym)?;
let plaintext = receiver
.open(&aad, &ciphertext)
.map_err(|_| Error::EchDecryptionFailed)?;
let unpadded = strip_trailing_padding(&plaintext)?;
require_inner_marker(&unpadded)?;
Ok(DecappedInner {
inner_ch_bytes: unpadded,
receiver,
sym,
config_id,
})
}
pub(crate) fn try_decap_inner_retry(
handshake_msg: &[u8],
state: &mut DecappedInner,
) -> Result<Vec<u8>, Error> {
let (payload_off, payload_len) = locate_payload_in_handshake(handshake_msg)?;
let mut aad = handshake_msg.to_vec();
for b in aad[payload_off..payload_off + payload_len].iter_mut() {
*b = 0;
}
let ciphertext = handshake_msg[payload_off..payload_off + payload_len].to_vec();
let (sym, config_id, enc) = extract_outer_meta(handshake_msg)?;
if sym != state.sym || config_id != state.config_id || !enc.is_empty() {
return Err(Error::EchDecryptionFailed);
}
let plaintext = state
.receiver
.open(&aad, &ciphertext)
.map_err(|_| Error::EchDecryptionFailed)?;
let unpadded = strip_trailing_padding(&plaintext)?;
require_inner_marker(&unpadded)?;
Ok(unpadded)
}
fn extract_outer_meta(handshake_msg: &[u8]) -> Result<(HpkeSymCipherSuite, u8, Vec<u8>), Error> {
let body = handshake_msg.get(4..).ok_or(Error::EchDecodeError)?;
let mut idx = 0usize;
let need = |idx: usize, n: usize| -> Result<(), Error> {
if idx + n > body.len() {
Err(Error::EchDecodeError)
} else {
Ok(())
}
};
need(idx, 2 + 32 + 1)?;
idx += 2 + 32;
let sid_len = body[idx] as usize;
idx += 1;
need(idx, sid_len + 2)?;
idx += sid_len;
let cs_len = ((body[idx] as usize) << 8) | (body[idx + 1] as usize);
idx += 2;
need(idx, cs_len + 1)?;
idx += cs_len;
let cm_len = body[idx] as usize;
idx += 1;
need(idx, cm_len + 2)?;
idx += cm_len;
let ext_total = ((body[idx] as usize) << 8) | (body[idx + 1] as usize);
idx += 2;
need(idx, ext_total)?;
let ext_start = idx;
let ext_end = idx + ext_total;
let mut p = ext_start;
while p < ext_end {
if p + 4 > ext_end {
return Err(Error::EchDecodeError);
}
let ty = ((body[p] as u16) << 8) | (body[p + 1] as u16);
let bl = ((body[p + 2] as usize) << 8) | (body[p + 3] as usize);
let body_start = p + 4;
let body_end = body_start + bl;
if body_end > ext_end {
return Err(Error::EchDecodeError);
}
if ty == crate::tls::codec::ExtensionType::ENCRYPTED_CLIENT_HELLO.0 {
let ext_body = &body[body_start..body_end];
let ext = EchExtension::decode(ext_body)?;
match ext {
EchExtension::Outer {
cipher_suite,
config_id,
enc,
..
} => return Ok((cipher_suite, config_id, enc)),
EchExtension::Inner => return Err(Error::EchDecodeError),
}
}
p = body_end;
}
Err(Error::EchDecodeError)
}
fn require_inner_marker(inner_ch: &[u8]) -> Result<(), Error> {
if inner_ch.len() < 4 || inner_ch[0] != crate::tls::codec::hs_type::CLIENT_HELLO {
return Err(Error::EchDecodeError);
}
let body_len =
((inner_ch[1] as usize) << 16) | ((inner_ch[2] as usize) << 8) | (inner_ch[3] as usize);
if 4 + body_len != inner_ch.len() {
return Err(Error::EchDecodeError);
}
let body = &inner_ch[4..];
let mut idx = 0usize;
let need = |idx: usize, n: usize| -> Result<(), Error> {
if idx + n > body.len() {
Err(Error::EchDecodeError)
} else {
Ok(())
}
};
need(idx, 2 + 32 + 1)?;
idx += 2 + 32;
let sid_len = body[idx] as usize;
idx += 1;
need(idx, sid_len + 2)?;
idx += sid_len;
let cs_len = ((body[idx] as usize) << 8) | (body[idx + 1] as usize);
idx += 2;
need(idx, cs_len + 1)?;
idx += cs_len;
let cm_len = body[idx] as usize;
idx += 1;
need(idx, cm_len + 2)?;
idx += cm_len;
let ext_total = ((body[idx] as usize) << 8) | (body[idx + 1] as usize);
idx += 2;
need(idx, ext_total)?;
let ext_start = idx;
let ext_end = idx + ext_total;
let mut p = ext_start;
let mut found = false;
while p < ext_end {
if p + 4 > ext_end {
return Err(Error::EchDecodeError);
}
let ty = ((body[p] as u16) << 8) | (body[p + 1] as u16);
let bl = ((body[p + 2] as usize) << 8) | (body[p + 3] as usize);
let body_start = p + 4;
let body_end = body_start + bl;
if body_end > ext_end {
return Err(Error::EchDecodeError);
}
if ty == crate::tls::codec::ExtensionType::ENCRYPTED_CLIENT_HELLO.0 {
if found {
return Err(Error::EchDecodeError);
}
let ext_body = &body[body_start..body_end];
match EchExtension::decode(ext_body)? {
EchExtension::Inner => {}
EchExtension::Outer { .. } => return Err(Error::EchDecodeError),
}
found = true;
}
p = body_end;
}
if !found {
return Err(Error::EchDecodeError);
}
Ok(())
}
fn strip_trailing_padding(padded: &[u8]) -> Result<Vec<u8>, Error> {
if padded.len() < 4 || padded[0] != crate::tls::codec::hs_type::CLIENT_HELLO {
return Err(Error::EchDecodeError);
}
let body_len =
((padded[1] as usize) << 16) | ((padded[2] as usize) << 8) | (padded[3] as usize);
let total = 4 + body_len;
if total > padded.len() {
return Err(Error::EchDecodeError);
}
if padded[total..].iter().any(|b| *b != 0) {
return Err(Error::EchDecodeError);
}
Ok(padded[..total].to_vec())
}