use tracing::{debug, trace};
use crate::Instant;
use crate::connection::spaces::PacketSpace;
use crate::crypto::{HeaderKey, KeyPair, PacketKey};
use crate::packet::{Packet, PartialDecode, SpaceId};
use crate::token::ResetToken;
use crate::{RESET_TOKEN_SIZE, TransportError};
/// Removes header protection from a partially decoded packet and checks whether it is a
/// stateless reset.
///
/// The header key is picked by packet type:
/// - 0-RTT packets use `zero_rtt_crypto`; if it is `None` the packet is dropped.
/// - Packets that belong to a packet number space use that space's remote header key; if the
///   space has no keys installed the packet is dropped.
/// - Any other packet (`space()` returns `None`) is finished without a header key.
///
/// The packet is flagged as a stateless reset when it is at least `RESET_TOKEN_SIZE + 5` bytes
/// long and its trailing `RESET_TOKEN_SIZE` bytes equal `stateless_reset_token`. The token
/// comparison runs in (best-effort) constant time.
///
/// Returns `None` when the packet is dropped. This includes the case where decoding fails and
/// the packet is not a stateless reset. When decoding fails but the packet is a stateless
/// reset, returns a result with `packet: None` and `stateless_reset: true`.
pub(super) fn unprotect_header(
    partial_decode: PartialDecode,
    spaces: &[PacketSpace; 3],
    zero_rtt_crypto: Option<&ZeroRttCrypto>,
    stateless_reset_token: Option<ResetToken>,
) -> Option<UnprotectHeaderResult> {
    let header_crypto = if partial_decode.is_0rtt() {
        if let Some(crypto) = zero_rtt_crypto {
            Some(&*crypto.header)
        } else {
            debug!("dropping unexpected 0-RTT packet");
            return None;
        }
    } else if let Some(space) = partial_decode.space() {
        if let Some(ref crypto) = spaces[space].crypto {
            Some(&*crypto.header.remote)
        } else {
            debug!(
                "discarding unexpected {:?} packet ({} bytes)",
                space,
                partial_decode.len(),
            );
            return None;
        }
    } else {
        // No packet number space: finish decoding without a header key.
        None
    };

    let data = partial_decode.data();
    // RFC 9000 §10.3.1: the token comparison must not leak the token through timing, so avoid
    // slice `==`, which may return as soon as it finds a differing byte.
    let stateless_reset = data.len() >= RESET_TOKEN_SIZE + 5
        && stateless_reset_token.as_deref().is_some_and(|token| {
            constant_time_eq(token, &data[data.len() - RESET_TOKEN_SIZE..])
        });

    match partial_decode.finish(header_crypto) {
        Ok(packet) => Some(UnprotectHeaderResult {
            packet: Some(packet),
            stateless_reset,
        }),
        // A stateless reset need not be a decodable packet; still report it.
        Err(_) if stateless_reset => Some(UnprotectHeaderResult {
            packet: None,
            stateless_reset: true,
        }),
        Err(e) => {
            trace!("unable to complete packet decoding: {}", e);
            None
        }
    }
}

/// Compares two byte strings without exiting early at the first mismatching byte.
///
/// The length check may return early. That is acceptable because the lengths compared here
/// are not secret. This is best-effort only: `black_box` discourages the optimizer from
/// reintroducing a data-dependent early exit, but does not guarantee it won't.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    std::hint::black_box(diff) == 0
}
/// Outcome of [`unprotect_header`] for a packet that was not dropped.
pub(super) struct UnprotectHeaderResult {
/// The packet with header protection removed, or `None` when decoding failed but the packet
/// was recognized as a stateless reset.
pub(super) packet: Option<Packet>,
/// Whether the packet's trailing bytes matched the `stateless_reset_token` given to
/// `unprotect_header`.
pub(super) stateless_reset: bool,
}
/// Removes packet protection from the payload of `packet` in place.
///
/// The packet key is chosen from the packet's space and key phase:
/// - 0-RTT packets use `zero_rtt_crypto`.
/// - Packets outside the Data space, or whose key phase equals `conn_key_phase`, use the
///   space's current remote keys.
/// - A Data-space key-phase mismatch uses `prev_crypto`, if present, when the previous phase
///   has no recorded end or the packet number is below it.
/// - Otherwise the packet is treated as a new peer-initiated key update and decrypted with
///   `next_crypto`.
///
/// Returns:
/// - `Ok(None)` if the header is not protected (no packet number, nothing to decrypt).
/// - `Ok(Some(..))` with the expanded packet number and key-update signals on success.
/// - `Err(None)` if the header has no packet number or decryption fails (no transport error).
/// - `Err(Some(e))` for `PROTOCOL_VIOLATION` (reserved bits set) or `KEY_UPDATE_ERROR`.
///
/// # Panics
///
/// Panics if the selected keys are absent: a 0-RTT packet without `zero_rtt_crypto`, a space
/// without installed keys, or an incoming key update without `next_crypto`. `unprotect_header`
/// drops packets of the first two kinds, so callers presumably run it first.
pub(super) fn decrypt_packet_body(
    packet: &mut Packet,
    spaces: &[PacketSpace; 3],
    zero_rtt_crypto: Option<&ZeroRttCrypto>,
    conn_key_phase: bool,
    prev_crypto: Option<&PrevCrypto>,
    next_crypto: Option<&KeyPair<Box<dyn PacketKey>>>,
) -> Result<Option<DecryptPacketResult>, Option<TransportError>> {
    if !packet.header.is_protected() {
        // Unprotected packets carry no packet number and no encrypted payload.
        return Ok(None);
    }
    let space = packet.header.space();
    let rx_packet = spaces[space].rx_packet;
    let number = packet.header.number().ok_or(None)?.expand(rx_packet + 1);
    let packet_key_phase = packet.header.key_phase();
    let is_0rtt = packet.header.is_0rtt();

    let mut crypto_update = false;
    let crypto = if is_0rtt {
        &zero_rtt_crypto
            .expect("0-RTT keys must be present to decrypt a 0-RTT packet")
            .packet
    } else if packet_key_phase == conn_key_phase || space != SpaceId::Data {
        &spaces[space]
            .crypto
            .as_ref()
            .expect("keys must be installed for the packet's space")
            .packet
            .remote
    } else if let Some(prev) = prev_crypto.filter(|prev| {
        // Packets numbered before the recorded end of the previous phase (or any packet
        // while no end is recorded) still use the previous keys.
        prev.end_packet.is_none_or(|(pn, _)| number < pn)
    }) {
        &prev.crypto.remote
    } else {
        // Data-space key phase mismatch that the previous keys don't account for: this must
        // be a new key update initiated by the peer.
        crypto_update = true;
        &next_crypto
            .expect("next 1-RTT keys must be present to decrypt an incoming key update")
            .remote
    };

    crypto
        .decrypt(number, &packet.header_data, &mut packet.payload)
        .map_err(|_| {
            trace!("decryption failed with packet number {}", number);
            None
        })?;

    if !packet.reserved_bits_valid() {
        return Err(Some(TransportError::PROTOCOL_VIOLATION(
            "reserved bits set",
        )));
    }

    // The key phase bit exists only in 1-RTT packets (RFC 9001 §6). 0-RTT and long-header
    // packets are never protected with updated keys, so only a 1-RTT packet decrypted with
    // the current keys can show the peer has moved to our new key phase.
    let decrypted_with_current_1rtt_keys =
        space == SpaceId::Data && !is_0rtt && packet_key_phase == conn_key_phase;
    let outgoing_key_update_acked = decrypted_with_current_1rtt_keys
        && prev_crypto.is_some_and(|prev| prev.end_packet.is_none());

    // A peer-initiated update must advance the packet number, and is rejected while
    // `prev_crypto.update_unacked` is set.
    if crypto_update
        && (number <= rx_packet || prev_crypto.is_some_and(|prev| prev.update_unacked))
    {
        return Err(Some(TransportError::KEY_UPDATE_ERROR("")));
    }

    Ok(Some(DecryptPacketResult {
        number,
        outgoing_key_update_acked,
        incoming_key_update: crypto_update,
    }))
}
/// Outcome of [`decrypt_packet_body`] for a protected packet that decrypted successfully.
pub(super) struct DecryptPacketResult {
/// Full packet number, expanded from the truncated header value with `rx_packet + 1` of the
/// packet's space as the expected next number.
pub(super) number: u64,
/// Whether `decrypt_packet_body` judged this packet to acknowledge an outgoing key update
/// still open in `prev_crypto` (no `end_packet` recorded yet).
pub(super) outgoing_key_update_acked: bool,
/// Whether the packet was decrypted with the next keys, i.e. it starts a key update
/// initiated by the peer.
pub(super) incoming_key_update: bool,
}
/// Packet keys of the previous key phase, kept so that late packets from that phase can
/// still be decrypted.
pub(super) struct PrevCrypto {
/// Previous-phase packet keys; `decrypt_packet_body` uses `remote` for Data-space packets
/// whose key phase differs from the connection's current one.
pub(super) crypto: KeyPair<Box<dyn PacketKey>>,
/// Packet number that ends the previous phase, plus an `Instant` (presumably when it was
/// recorded; confirm). While `None`, every key-phase-mismatched Data-space packet uses these
/// keys. Once set, only packets numbered below it do.
pub(super) end_packet: Option<(u64, Instant)>,
/// When `true`, an incoming peer-initiated key update is rejected with `KEY_UPDATE_ERROR`.
pub(super) update_unacked: bool,
}
/// Keys used to remove protection from incoming 0-RTT packets.
pub(super) struct ZeroRttCrypto {
/// Header protection key for 0-RTT packets.
pub(super) header: Box<dyn HeaderKey>,
/// Payload (packet) protection key for 0-RTT packets.
pub(super) packet: Box<dyn PacketKey>,
}
/// Unit tests for `unprotect_header` and `decrypt_packet_body`. They use no-op keys so that
/// only key selection, stateless-reset detection, and validation logic are exercised.
#[cfg(test)]
mod tests {
use super::*;
use bytes::{Bytes, BytesMut};
use crate::crypto::{CryptoError, Keys};
use crate::packet::{FixedLengthConnectionIdParser, Header, PacketNumber, SpaceId};
use crate::transport_error::Code;
use crate::{ConnectionId, Instant};
/// Header-protection sample size reported by `TestHeaderKey`. It matches the 16-byte sample
/// used by the header protection algorithms in RFC 9001 §5.4.
const REALISTIC_SAMPLE_SIZE: usize = 16;
/// Header key whose `encrypt`/`decrypt` do nothing, so header bytes pass through unchanged.
struct TestHeaderKey;
impl HeaderKey for TestHeaderKey {
fn decrypt(&self, _pn_offset: usize, _packet: &mut [u8]) {}
fn encrypt(&self, _pn_offset: usize, _packet: &mut [u8]) {}
fn sample_size(&self) -> usize {
REALISTIC_SAMPLE_SIZE
}
}
/// Packet key that leaves payloads untouched, always decrypts successfully, has a zero-length
/// tag, and reports `u64::MAX` confidentiality and integrity limits.
struct TestPacketKey;
impl PacketKey for TestPacketKey {
fn encrypt(&self, _packet: u64, _buf: &mut [u8], _header_len: usize) {}
fn decrypt(
&self,
_packet: u64,
_header: &[u8],
_payload: &mut BytesMut,
) -> Result<(), CryptoError> {
Ok(())
}
fn tag_len(&self) -> usize {
0
}
fn confidentiality_limit(&self) -> u64 {
u64::MAX
}
fn integrity_limit(&self) -> u64 {
u64::MAX
}
}
/// Local/remote pair of no-op packet keys.
fn test_packet_keys() -> KeyPair<Box<dyn PacketKey>> {
KeyPair {
local: Box::new(TestPacketKey),
remote: Box::new(TestPacketKey),
}
}
/// Full `Keys` set (header and packet keys, local and remote) built from the no-op test keys.
fn test_keys() -> Keys {
Keys {
header: KeyPair {
local: Box::new(TestHeaderKey),
remote: Box::new(TestHeaderKey),
},
packet: test_packet_keys(),
}
}
/// Three fresh packet spaces where only `SpaceId::Data` has keys installed.
fn spaces_with_crypto() -> [PacketSpace; 3] {
let now = Instant::now();
let mut spaces = [
PacketSpace::new(now),
PacketSpace::new(now),
PacketSpace::new(now),
];
spaces[SpaceId::Data].crypto = Some(test_keys());
spaces
}
/// Serializes a minimal short-header packet: `first_byte`, a single packet number byte, then
/// `payload`, zero-padded to at least `1 + 4 + REALISTIC_SAMPLE_SIZE` bytes.
fn short_packet_bytes(first_byte: u8, packet_number: u8, payload: &[u8]) -> BytesMut {
let mut bytes = Vec::with_capacity(2 + payload.len());
bytes.push(first_byte);
bytes.push(packet_number);
bytes.extend_from_slice(payload);
// Presumably the minimum for header protection to take its sample (first byte + up to
// 4 packet number bytes + sample); confirm against `PartialDecode::finish`.
let min_size = 1 + 4 + REALISTIC_SAMPLE_SIZE;
while bytes.len() < min_size {
bytes.push(0x00);
}
BytesMut::from(bytes.as_slice())
}
/// Parses `bytes` with zero-length connection IDs and the default supported versions.
/// Panics if parsing fails, and keeps only the first element of the returned tuple.
fn decode_short_packet(bytes: BytesMut) -> PartialDecode {
let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
// NOTE(review): the trailing `false` flag's meaning is not visible here; confirm
// against `PartialDecode::new`.
PartialDecode::new(
bytes,
&FixedLengthConnectionIdParser::new(0),
&supported_versions,
false,
)
.unwrap()
.0
}
/// Builds a short-header `Packet` directly (no parsing), with the given packet number and key
/// phase, `first_byte` as the raw `header_data`, and an 8-byte zero payload. Callers keep
/// `key_phase` and the key phase bit in `first_byte` consistent themselves.
fn short_packet(packet_number: u8, key_phase: bool, first_byte: u8) -> Packet {
Packet {
header: Header::Short {
spin: false,
key_phase,
dst_cid: ConnectionId::new(&[]),
number: PacketNumber::U8(packet_number),
},
header_data: Bytes::from(vec![first_byte]),
payload: BytesMut::from(&[0u8; 8][..]),
}
}
// A 21-byte packet whose last RESET_TOKEN_SIZE bytes equal the token. It still decodes, and
// is also flagged as a stateless reset.
#[test]
fn unprotect_header_sets_stateless_reset_for_matching_token() {
let token_bytes = [0xAB; RESET_TOKEN_SIZE];
let stateless_reset_token = Some(ResetToken::from(token_bytes));
let mut payload = vec![0u8; 3];
payload.extend_from_slice(&token_bytes);
let bytes = short_packet_bytes(0x40, 0x01, &payload);
let partial = decode_short_packet(bytes);
let spaces = spaces_with_crypto();
let result = unprotect_header(partial, &spaces, None, stateless_reset_token)
.expect("packet should be decoded");
assert!(result.packet.is_some());
assert!(result.stateless_reset);
}
// Same layout, but the trailing bytes differ from the configured token, so no reset is flagged.
#[test]
fn unprotect_header_ignores_non_matching_token() {
let token_bytes = [0xAB; RESET_TOKEN_SIZE];
let stateless_reset_token = Some(ResetToken::from([0xCD; RESET_TOKEN_SIZE]));
let mut payload = vec![0u8; 3];
payload.extend_from_slice(&token_bytes);
let bytes = short_packet_bytes(0x40, 0x01, &payload);
let partial = decode_short_packet(bytes);
let spaces = spaces_with_crypto();
let result = unprotect_header(partial, &spaces, None, stateless_reset_token)
.expect("packet should be decoded");
assert!(result.packet.is_some());
assert!(!result.stateless_reset);
}
// 0x58 = 0x40 | 0x18, and 0x18 are the short-header reserved bits (RFC 9000 §17.3.1). The
// no-op key always decrypts, so the error must come from the reserved-bits check.
#[test]
fn decrypt_packet_body_rejects_reserved_bits() {
let mut spaces = spaces_with_crypto();
spaces[SpaceId::Data].rx_packet = 0;
let mut packet = short_packet(1, false, 0x58);
let result = decrypt_packet_body(&mut packet, &spaces, None, false, None, None);
let err = result
.err()
.expect("should be error")
.expect("should have transport error");
assert_eq!(err.code, Code::PROTOCOL_VIOLATION);
}
// Both cases take the "new peer key update" path (key phase mismatch, no usable previous keys).
#[test]
fn decrypt_packet_body_reports_key_update_errors() {
// Case 1: no previous keys; packet number 10 does not exceed rx_packet 10, so the update
// is rejected.
let mut spaces = spaces_with_crypto();
spaces[SpaceId::Data].rx_packet = 10;
let mut packet = short_packet(10, true, 0x44);
let next_crypto = test_packet_keys();
let result =
decrypt_packet_body(&mut packet, &spaces, None, false, None, Some(&next_crypto));
let err = result
.err()
.expect("should be error")
.expect("should have transport error");
assert_eq!(err.code, Code::KEY_UPDATE_ERROR);
// Case 2: previous keys ended at packet 0, so packet 1 is a new update. It is rejected
// because `update_unacked` is set.
let mut spaces = spaces_with_crypto();
spaces[SpaceId::Data].rx_packet = 0;
let mut packet = short_packet(1, true, 0x44);
let prev_crypto = PrevCrypto {
crypto: test_packet_keys(),
end_packet: Some((0, Instant::now())),
update_unacked: true,
};
let next_crypto = test_packet_keys();
let result = decrypt_packet_body(
&mut packet,
&spaces,
None,
false,
Some(&prev_crypto),
Some(&next_crypto),
);
let err = result
.err()
.expect("should be error")
.expect("should have transport error");
assert_eq!(err.code, Code::KEY_UPDATE_ERROR);
}
// Matching key phase, clean reserved bits, and no previous/next keys: plain success with no
// key-update signals.
#[test]
fn decrypt_packet_body_returns_result_for_valid_packet() {
let mut spaces = spaces_with_crypto();
spaces[SpaceId::Data].rx_packet = 0;
let mut packet = short_packet(1, false, 0x40);
let result = decrypt_packet_body(&mut packet, &spaces, None, false, None, None)
.expect("decryption should succeed")
.expect("protected packet should return result");
assert_eq!(result.number, 1);
assert!(!result.outgoing_key_update_acked);
assert!(!result.incoming_key_update);
}
// 10 bytes is shorter than first byte + 4 packet number bytes + sample. If
// `PartialDecode::new` already rejects it, the test passes without calling `unprotect_header`.
#[test]
fn unprotect_header_rejects_too_short_packet() {
let spaces = spaces_with_crypto();
let too_short =
BytesMut::from(&[0x40, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00][..]);
let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
let partial_result = PartialDecode::new(
too_short,
&FixedLengthConnectionIdParser::new(0),
&supported_versions,
false,
);
if let Ok((partial, _)) = partial_result {
let result = unprotect_header(partial, &spaces, None, None);
assert!(
result.is_none(),
"Packet too short for header protection should be rejected during unprotect"
);
}
}
}