use super::super::codec::{ParsedRecord, is_legal_record_version, read_record, write_record};
use super::super::crypto::{RecordCrypter, Transcript};
use crate::tls::{Alert, AlertDescription, ContentType, Error, ProtocolVersion};
use alloc::vec::Vec;
/// Upper bound, in bytes, on handshake data buffered while waiting for a
/// complete handshake message to arrive (128 KiB). Exceeding it makes
/// reassembly fail with `Error::RecordOverflow`.
pub(crate) const MAX_HANDSHAKE_REASSEMBLY: usize = 1 << 17;
/// One unit of input produced by `ConnectionCore::next_message`.
pub(crate) enum Incoming {
    /// A complete handshake message, including its 4-byte message header.
    Handshake(Vec<u8>),
    /// Application data was received. The value is its plaintext length; the
    /// bytes themselves were appended to the connection's received-data (or
    /// early-data) buffer.
    ApplicationData(usize),
    /// An alert received from the peer.
    Alert(Alert),
}
/// Record-layer state for one connection: buffered transport bytes in both
/// directions, handshake-message reassembly, received application data, and
/// the currently installed record protection for each direction.
pub(crate) struct ConnectionCore {
    // Bytes received via `read_tls` that have not yet been parsed as records.
    inbuf: Vec<u8>,
    // Encoded records waiting to be handed out by `write_tls`.
    outbuf: Vec<u8>,
    // Handshake bytes awaiting reassembly into complete messages; growth is
    // capped at MAX_HANDSHAKE_REASSEMBLY bytes.
    hs_pending: Vec<u8>,
    // Decrypted application data, drained by `take_received`.
    app_in: Vec<u8>,
    // Decrypted application data received while `early_data_routing` is set,
    // drained by `take_early_data`.
    early_in: Vec<u8>,
    // When true, decrypted application data goes to `early_in` instead of `app_in`.
    early_data_routing: bool,
    // Protection for incoming records; `None` means nothing is decrypted and
    // only plaintext handshake/alert/CCS records are accepted.
    read: Option<RecordCrypter>,
    // Protection for outgoing records; `None` means records are written in plaintext.
    write: Option<RecordCrypter>,
    // Running handshake transcript. Outgoing messages are added by
    // `emit_handshake`; other bytes via `transcript_only` (received messages are
    // presumably added by the caller — TODO confirm).
    pub(crate) transcript: Transcript,
    // Set once a close_notify alert has been queued, so it is sent at most once.
    sent_close_notify: bool,
    // While true, a plaintext ChangeCipherSpec record with body [0x01] is
    // tolerated and discarded; afterwards any CCS is an error.
    ccs_window_open: bool,
    // Record size limit advertised by the peer; bounds outgoing record size.
    peer_record_size_limit: Option<u16>,
}
impl ConnectionCore {
    /// Creates a core with empty buffers, no record protection installed in
    /// either direction, and the ChangeCipherSpec window open.
    pub(crate) fn new() -> Self {
        ConnectionCore {
            inbuf: Vec::new(),
            outbuf: Vec::new(),
            hs_pending: Vec::new(),
            app_in: Vec::new(),
            early_in: Vec::new(),
            early_data_routing: false,
            read: None,
            write: None,
            transcript: Transcript::new(),
            sent_close_notify: false,
            ccs_window_open: true,
            peer_record_size_limit: None,
        }
    }

    /// Records the record size limit advertised by the peer. It bounds the
    /// plaintext size of every record later emitted by `emit_handshake` and
    /// `send_application_data`.
    pub(crate) fn set_peer_record_size_limit(&mut self, limit: u16) {
        self.peer_record_size_limit = Some(limit);
    }

    /// Stops tolerating ChangeCipherSpec records; any CCS received afterwards
    /// makes `next_message` fail with `Error::UnexpectedMessage`.
    pub(crate) fn close_ccs_window(&mut self) {
        self.ccs_window_open = false;
    }

    /// Buffers raw bytes received from the transport for `next_message`.
    pub(crate) fn read_tls(&mut self, bytes: &[u8]) {
        self.inbuf.extend_from_slice(bytes);
    }

    /// Returns all queued outgoing record bytes, leaving the queue empty.
    pub(crate) fn write_tls(&mut self) -> Vec<u8> {
        core::mem::take(&mut self.outbuf)
    }

    /// Whether any outgoing record bytes are waiting to be taken by `write_tls`.
    pub(crate) fn wants_write(&self) -> bool {
        !self.outbuf.is_empty()
    }

    /// Installs protection for incoming records; from now on records with
    /// outer type application_data are decrypted with it.
    pub(crate) fn set_read(&mut self, crypter: RecordCrypter) {
        self.read = Some(crypter);
    }

    /// Installs protection for outgoing records; every record later produced
    /// by `emit_record` is encrypted with it.
    pub(crate) fn set_write(&mut self, crypter: RecordCrypter) {
        self.write = Some(crypter);
    }

    /// Takes all application data received so far (outside early-data routing).
    pub(crate) fn take_received(&mut self) -> Vec<u8> {
        core::mem::take(&mut self.app_in)
    }

    /// Takes all application data received while early-data routing was on.
    pub(crate) fn take_early_data(&mut self) -> Vec<u8> {
        core::mem::take(&mut self.early_in)
    }

    /// Selects where decrypted application data is buffered: the early-data
    /// buffer when `enabled`, the regular received-data buffer otherwise.
    pub(crate) fn set_early_data_routing(&mut self, enabled: bool) {
        self.early_data_routing = enabled;
    }

    /// Adds `message` (a complete encoded handshake message) to the transcript
    /// and queues it as one or more Handshake records.
    ///
    /// The message is fragmented so that no record exceeds the outgoing
    /// plaintext cap (2^14 bytes, or the peer's record size limit). Handshake
    /// messages may legally span several records.
    pub(crate) fn emit_handshake(&mut self, message: Vec<u8>) {
        self.transcript.update(&message);
        let cap = self.outgoing_fragment_cap();
        for fragment in Self::split_fragments(&message, cap) {
            self.emit_record(ContentType::Handshake, fragment);
        }
    }

    /// Adds `message` to the transcript without sending anything.
    // Not referenced by every build of this crate; NOTE(review): confirm which
    // front-end uses it before removing the allow.
    #[allow(dead_code)]
    pub(crate) fn transcript_only(&mut self, message: &[u8]) {
        self.transcript.update(message);
    }

    /// Feeds handshake bytes that arrive outside TLS records into reassembly
    /// (presumably from QUIC CRYPTO frames, per the name — TODO confirm).
    ///
    /// # Errors
    /// `Error::RecordOverflow` if reassembly would exceed `MAX_HANDSHAKE_REASSEMBLY`.
    // Only exercised by some configurations (and the unit tests).
    #[allow(dead_code)]
    pub(crate) fn quic_feed_handshake(&mut self, bytes: &[u8]) -> Result<(), Error> {
        self.append_handshake_bytes(bytes)
    }

    /// Appends `bytes` to the handshake reassembly buffer.
    ///
    /// # Errors
    /// `Error::RecordOverflow` if the buffer would grow beyond
    /// `MAX_HANDSHAKE_REASSEMBLY`; the buffer is left unchanged in that case.
    fn append_handshake_bytes(&mut self, bytes: &[u8]) -> Result<(), Error> {
        if self.hs_pending.len().saturating_add(bytes.len()) > MAX_HANDSHAKE_REASSEMBLY {
            return Err(Error::RecordOverflow);
        }
        self.hs_pending.extend_from_slice(bytes);
        Ok(())
    }

    /// Queues a ChangeCipherSpec record (body `0x01`, record version TLS 1.2).
    /// It is written directly with `write_record` and never encrypted.
    pub(crate) fn emit_ccs(&mut self) {
        write_record(
            &mut self.outbuf,
            ContentType::ChangeCipherSpec,
            ProtocolVersion::TLSv1_2,
            &[1],
        );
    }

    /// Queues `data` as one or more ApplicationData records, each no larger
    /// than the outgoing plaintext cap. Empty `data` produces a single empty record.
    pub(crate) fn send_application_data(&mut self, data: &[u8]) {
        let cap = self.outgoing_fragment_cap();
        for fragment in Self::split_fragments(data, cap) {
            self.emit_record(ContentType::ApplicationData, fragment);
        }
    }

    /// Queues an alert record with level 2 (fatal) and the given description.
    pub(crate) fn send_alert(&mut self, description: AlertDescription) {
        let body = [2, description.as_u8()];
        self.emit_record(ContentType::Alert, &body);
    }

    /// Queues a warning-level (1) close_notify alert. Repeated calls are
    /// no-ops, so at most one close_notify is ever sent.
    pub(crate) fn send_close_notify(&mut self) {
        if !self.sent_close_notify {
            self.sent_close_notify = true;
            let body = [1, AlertDescription::CloseNotify.as_u8()];
            self.emit_record(ContentType::Alert, &body);
        }
    }

    /// Queues a single record of type `ct` carrying `payload`. The record is
    /// encrypted if write protection is installed, otherwise written in
    /// plaintext with record version TLS 1.2.
    pub(crate) fn emit_record(&mut self, ct: ContentType, payload: &[u8]) {
        match &mut self.write {
            Some(crypter) => {
                // NOTE(review): an encryption failure is dropped silently. The
                // record is not queued and the caller gets no signal. Confirm
                // this best-effort behaviour is intended.
                if let Ok(rec) = crypter.encrypt(ct, payload) {
                    self.outbuf.extend_from_slice(&rec);
                }
            }
            None => write_record(&mut self.outbuf, ct, ProtocolVersion::TLSv1_2, payload),
        }
    }

    /// Largest plaintext to place in one outgoing record.
    ///
    /// This is 2^14, or one less than the peer's record size limit when that
    /// is smaller. The reserved byte is presumably the TLS 1.3 inner content
    /// type (RFC 8449 §4). The result is floored at 1, so a bogus limit of 0
    /// or 1 can neither underflow nor make `chunks(0)` panic.
    fn outgoing_fragment_cap(&self) -> usize {
        const MAX_PLAINTEXT: usize = 1 << 14;
        self.peer_record_size_limit
            .map_or(MAX_PLAINTEXT, |limit| usize::from(limit).saturating_sub(1))
            .clamp(1, MAX_PLAINTEXT)
    }

    /// Splits `payload` into slices of at most `cap` bytes (`cap` must be
    /// non-zero). An empty payload yields exactly one empty slice, so callers
    /// still emit one record for it.
    fn split_fragments(payload: &[u8], cap: usize) -> impl Iterator<Item = &[u8]> {
        // `chunks` yields nothing for an empty slice.
        let empty = payload.is_empty().then_some(payload);
        payload.chunks(cap).chain(empty)
    }

    /// Returns the next complete input item, or `Ok(None)` if more transport
    /// bytes are needed.
    ///
    /// Already-reassembled handshake messages are returned first. Otherwise
    /// buffered records are consumed one at a time. Tolerated
    /// ChangeCipherSpec records, and handshake fragments that do not yet
    /// complete a message, are absorbed without producing an item.
    ///
    /// # Errors
    /// - `UnsupportedVersion` if a record's version fails `is_legal_record_version`.
    /// - `UnexpectedMessage` in these cases:
    ///   - a CCS outside the window, or with a body other than `[0x01]`;
    ///   - a plaintext handshake or alert record after read protection is installed;
    ///   - an empty handshake fragment;
    ///   - any other unexpected content type.
    /// - `RecordOverflow` if handshake reassembly would exceed `MAX_HANDSHAKE_REASSEMBLY`.
    /// - Errors from `read_record`, decryption and alert parsing are propagated.
    pub(crate) fn next_message(&mut self) -> Result<Option<Incoming>, Error> {
        loop {
            if let Some(msg) = self.pop_handshake() {
                return Ok(Some(Incoming::Handshake(msg)));
            }
            let Some(ParsedRecord {
                content_type,
                version,
                fragment,
                len,
            }) = read_record(&self.inbuf)?
            else {
                return Ok(None);
            };
            if !is_legal_record_version(version) {
                return Err(Error::UnsupportedVersion);
            }
            // `fragment` borrows `inbuf`, so copy it out before draining the record.
            let fragment = fragment.to_vec();
            self.inbuf.drain(..len);
            match content_type {
                ContentType::ChangeCipherSpec => {
                    if !self.ccs_window_open || fragment.as_slice() != [0x01] {
                        return Err(Error::UnexpectedMessage);
                    }
                    // Tolerated CCS: discard and read the next record.
                }
                ContentType::ApplicationData if self.read.is_some() => {
                    let (inner_ct, content) = self.decrypt(&fragment)?;
                    if let Some(msg) = self.dispatch_inner(inner_ct, content)? {
                        return Ok(Some(msg));
                    }
                }
                ContentType::Handshake => {
                    // After read protection is installed, handshake data must
                    // arrive encrypted. Zero-length handshake fragments are
                    // forbidden (RFC 8446 §5.1), matching the check on the
                    // encrypted path in `dispatch_inner`.
                    if self.read.is_some() || fragment.is_empty() {
                        return Err(Error::UnexpectedMessage);
                    }
                    self.append_handshake_bytes(&fragment)?;
                }
                ContentType::Alert => {
                    if self.read.is_some() {
                        return Err(Error::UnexpectedMessage);
                    }
                    return Ok(Some(parse_alert(&fragment)?));
                }
                _ => return Err(Error::UnexpectedMessage),
            }
        }
    }

    /// Decrypts one protected record body with the installed read crypter and
    /// returns its inner content type and plaintext.
    ///
    /// The 5-byte record header handed to the crypter is rebuilt here with
    /// type application_data and version 0x0303, not copied from the wire.
    /// Presumably it is used as AEAD additional data — TODO confirm.
    ///
    /// # Panics
    /// If no read crypter is installed; callers check `self.read.is_some()` first.
    fn decrypt(&mut self, fragment: &[u8]) -> Result<(ContentType, Vec<u8>), Error> {
        let mut header = [0u8; 5];
        header[0] = ContentType::ApplicationData.as_u8();
        header[1] = 0x03;
        header[2] = 0x03;
        // `fragment` came from a record with a 2-byte length field, so this
        // cast cannot truncate.
        header[3..5].copy_from_slice(&(fragment.len() as u16).to_be_bytes());
        let crypter = self.read.as_mut().expect("read keys present");
        crypter.decrypt(&header, fragment)
    }

    /// Routes decrypted record content by its inner content type.
    ///
    /// Handshake bytes go to reassembly (returns `Ok(None)`). Application data
    /// is buffered and reported by length. Alerts are parsed. Empty handshake
    /// or alert content, and any other inner type, yield `UnexpectedMessage`.
    fn dispatch_inner(
        &mut self,
        inner_ct: ContentType,
        content: Vec<u8>,
    ) -> Result<Option<Incoming>, Error> {
        match inner_ct {
            ContentType::Handshake => {
                if content.is_empty() {
                    return Err(Error::UnexpectedMessage);
                }
                self.append_handshake_bytes(&content)?;
                Ok(None)
            }
            ContentType::ApplicationData => {
                let plaintext_len = content.len();
                let sink = if self.early_data_routing {
                    &mut self.early_in
                } else {
                    &mut self.app_in
                };
                sink.extend_from_slice(&content);
                Ok(Some(Incoming::ApplicationData(plaintext_len)))
            }
            ContentType::Alert => {
                if content.is_empty() {
                    return Err(Error::UnexpectedMessage);
                }
                Ok(Some(parse_alert(&content)?))
            }
            _ => Err(Error::UnexpectedMessage),
        }
    }

    /// Removes and returns the first complete handshake message (4-byte
    /// header with a 24-bit big-endian body length, plus the body) from the
    /// reassembly buffer, or `None` if it is not complete yet.
    fn pop_handshake(&mut self) -> Option<Vec<u8>> {
        let &[_msg_type, len_hi, len_mid, len_lo, ..] = self.hs_pending.as_slice() else {
            return None;
        };
        let body_len =
            (usize::from(len_hi) << 16) | (usize::from(len_mid) << 8) | usize::from(len_lo);
        let total = 4 + body_len;
        if self.hs_pending.len() < total {
            return None;
        }
        Some(self.hs_pending.drain(..total).collect())
    }
}
/// Decodes a two-byte alert body (level, description) into `Incoming::Alert`.
/// A level byte of 2 marks the alert as fatal; any other level is reported as
/// non-fatal.
///
/// # Errors
/// `Error::Decode` if `body` is not exactly two bytes long.
fn parse_alert(body: &[u8]) -> Result<Incoming, Error> {
    let &[level, code] = body else {
        return Err(Error::Decode);
    };
    let fatal = level == 2;
    let description = AlertDescription::from_u8(code);
    Ok(Incoming::Alert(Alert { fatal, description }))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Filling the reassembly buffer exactly to the ceiling succeeds; one
    /// more chunk is rejected with `RecordOverflow`.
    #[test]
    fn handshake_reassembly_bound_enforces_ceiling() {
        const CHUNK_LEN: usize = 16 * 1024;
        let mut core = ConnectionCore::new();
        let chunk = alloc::vec![0u8; CHUNK_LEN];
        (0..MAX_HANDSHAKE_REASSEMBLY / CHUNK_LEN)
            .for_each(|_| core.quic_feed_handshake(&chunk).unwrap());
        let overflow = core.quic_feed_handshake(&chunk);
        assert!(matches!(overflow, Err(Error::RecordOverflow)));
    }

    /// A single fragment larger than the ceiling is rejected outright.
    #[test]
    fn handshake_reassembly_bound_rejects_oversize_fragment() {
        let oversized = alloc::vec![0u8; MAX_HANDSHAKE_REASSEMBLY + 1];
        let result = ConnectionCore::new().quic_feed_handshake(&oversized);
        assert!(matches!(result, Err(Error::RecordOverflow)));
    }
}