use alloc::vec::Vec;
use core::num::NonZeroUsize;
use core::{fmt, mem};
#[cfg(feature = "std")]
use std::error::Error as StdError;
use super::UnbufferedConnectionCommon;
use crate::client::ClientConnectionData;
use crate::msgs::deframer::buffers::{BufferProgress, DeframerSliceBuffer};
use crate::server::ServerConnectionData;
use crate::Error;
impl UnbufferedConnectionCommon<ClientConnectionData> {
    /// Processes TLS records buffered in `incoming_tls` and reports the state the
    /// client connection has reached.
    ///
    /// The client variant installs an early-data check that never yields a value,
    /// so the accompanying executor can never be invoked.
    pub fn process_tls_records<'c, 'i>(
        &'c mut self,
        incoming_tls: &'i mut [u8],
    ) -> UnbufferedStatus<'c, 'i, ClientConnectionData> {
        self.process_tls_records_common(
            incoming_tls,
            |_conn| None::<()>,
            |_conn, _incoming, ()| unreachable!(),
        )
    }
}
impl UnbufferedConnectionCommon<ServerConnectionData> {
pub fn process_tls_records<'c, 'i>(
&'c mut self,
incoming_tls: &'i mut [u8],
) -> UnbufferedStatus<'c, 'i, ServerConnectionData> {
self.process_tls_records_common(
incoming_tls,
|conn| conn.pop_early_data(),
|conn, incoming_tls, chunk| ReadEarlyData::new(conn, incoming_tls, chunk).into(),
)
}
}
impl<Data> UnbufferedConnectionCommon<Data> {
    /// Shared driver behind the client and server `process_tls_records`.
    ///
    /// Each pass of the loop tries, in order:
    /// 1. `check`, a caller-supplied hook. If it yields a value, `execute` turns it
    ///    into the returned state. The server passes `pop_early_data` here; the
    ///    client's hook always returns `None`.
    /// 2. A chunk popped from `received_plaintext`, returned as `ReadTraffic`.
    /// 3. A chunk popped from `sendable_tls`, returned as `EncodeTlsData`.
    /// 4. Deframing one message from `incoming_tls`. If a message is produced, it
    ///    is run through `process_msg` and the loop starts another pass.
    /// 5. Otherwise, the first matching condition wins: `wants_write` gives
    ///    `TransmitTlsData`, `has_received_close_notify` gives `Closed`,
    ///    `may_send_application_data` gives `WriteTraffic`, and anything else
    ///    gives `BlockedHandshake`.
    ///
    /// The returned `discard` is `buffer.pending_discard()` at the point processing
    /// stopped. This holds on the error paths too.
    ///
    /// # Errors
    /// `state` is `Err` in three cases:
    /// - `deframe` fails.
    /// - `self.core.state` already holds an error.
    /// - `process_msg` fails.
    ///
    /// In the last two cases the error is also (re)stored in `self.core.state`.
    fn process_tls_records_common<'c, 'i, T>(
        &'c mut self,
        incoming_tls: &'i mut [u8],
        mut check: impl FnMut(&mut Self) -> Option<T>,
        execute: impl FnOnce(&'c mut Self, &'i mut [u8], T) -> ConnectionState<'c, 'i, Data>,
    ) -> UnbufferedStatus<'c, 'i, Data> {
        // View over `incoming_tls` that accumulates bytes queued for discard.
        let mut buffer = DeframerSliceBuffer::new(incoming_tls);
        // Progress reported by `deframe`; `take_discard()` moves it into `buffer`.
        let mut buffer_progress = BufferProgress::default();
        let (discard, state) = loop {
            // Caller-specific work (server early data) takes priority over everything else.
            if let Some(value) = check(self) {
                break (buffer.pending_discard(), execute(self, incoming_tls, value));
            }
            if let Some(chunk) = self
                .core
                .common_state
                .received_plaintext
                .pop()
            {
                break (
                    buffer.pending_discard(),
                    ReadTraffic::new(self, incoming_tls, chunk).into(),
                );
            }
            if let Some(chunk) = self
                .core
                .common_state
                .sendable_tls
                .pop()
            {
                break (
                    buffer.pending_discard(),
                    EncodeTlsData::new(self, chunk).into(),
                );
            }
            let deframer_output =
                match self
                    .core
                    .deframe(None, buffer.filled_mut(), &mut buffer_progress)
                {
                    Err(err) => {
                        // Queue whatever the deframer reported as consumed before failing.
                        buffer.queue_discard(buffer_progress.take_discard());
                        return UnbufferedStatus {
                            discard: buffer.pending_discard(),
                            state: Err(err),
                        };
                    }
                    Ok(r) => r,
                };
            if let Some(msg) = deframer_output {
                // Move the state machine out of `self.core.state` (it is passed to
                // `process_msg` by value), leaving a placeholder error behind.
                let mut state =
                    match mem::replace(&mut self.core.state, Err(Error::HandshakeNotComplete)) {
                        Ok(state) => state,
                        Err(e) => {
                            // The connection had already failed: restore the stored
                            // error over the placeholder and report it again.
                            buffer.queue_discard(buffer_progress.take_discard());
                            self.core.state = Err(e.clone());
                            return UnbufferedStatus {
                                discard: buffer.pending_discard(),
                                state: Err(e),
                            };
                        }
                    };
                match self.core.process_msg(msg, state, None) {
                    Ok(new) => state = new,
                    Err(e) => {
                        // Persist the failure so a later call hits the `Err(e)` arm above.
                        buffer.queue_discard(buffer_progress.take_discard());
                        self.core.state = Err(e.clone());
                        return UnbufferedStatus {
                            discard: buffer.pending_discard(),
                            state: Err(e),
                        };
                    }
                }
                // Message handled: commit its bytes for discard, store the new state,
                // and go around again to look for produced plaintext/records.
                buffer.queue_discard(buffer_progress.take_discard());
                self.core.state = Ok(state);
            } else if self.wants_write {
                // No message was deframed from the current input. Report the first
                // pending condition in priority order.
                break (
                    buffer.pending_discard(),
                    TransmitTlsData { conn: self }.into(),
                );
            } else if self
                .core
                .common_state
                .has_received_close_notify
            {
                break (buffer.pending_discard(), ConnectionState::Closed);
            } else if self
                .core
                .common_state
                .may_send_application_data
            {
                break (
                    buffer.pending_discard(),
                    ConnectionState::WriteTraffic(WriteTraffic { conn: self }),
                );
            } else {
                // NOTE(review): presumably more incoming TLS data is needed to
                // advance the handshake; confirm against the deframer's contract.
                break (buffer.pending_discard(), ConnectionState::BlockedHandshake);
            }
        };
        UnbufferedStatus {
            discard,
            state: Ok(state),
        }
    }
}
/// Outcome of a single `process_tls_records` call.
#[derive(Debug)]
#[must_use]
pub struct UnbufferedStatus<'c, 'i, Data> {
    /// Byte count from the deframer buffer's `pending_discard` when processing stopped.
    ///
    /// NOTE(review): presumably the caller must remove this many bytes from the
    /// front of `incoming_tls` once `state` has been handled. Confirm against
    /// `DeframerSliceBuffer`.
    pub discard: usize,
    /// The state the connection reached, or the error that stopped processing.
    pub state: Result<ConnectionState<'c, 'i, Data>, Error>,
}
/// The state an unbuffered connection is in after `process_tls_records`.
#[non_exhaustive]
pub enum ConnectionState<'c, 'i, Data> {
    /// Plaintext received from the peer is available to read.
    ReadTraffic(ReadTraffic<'c, 'i, Data>),
    /// Reported once the peer's close_notify has been recorded.
    Closed,
    /// Early data popped from a server connection is available to read.
    ReadEarlyData(ReadEarlyData<'c, 'i, Data>),
    /// A queued TLS message must be encoded into a caller-provided buffer.
    EncodeTlsData(EncodeTlsData<'c, Data>),
    /// Encoded TLS data is pending transmission (`wants_write` is set).
    TransmitTlsData(TransmitTlsData<'c, Data>),
    /// No other state applies and application data may not be sent yet.
    BlockedHandshake,
    /// Application data may be encrypted and sent.
    WriteTraffic(WriteTraffic<'c, Data>),
}
impl<'c, 'i, Data> From<ReadTraffic<'c, 'i, Data>> for ConnectionState<'c, 'i, Data> {
fn from(v: ReadTraffic<'c, 'i, Data>) -> Self {
Self::ReadTraffic(v)
}
}
impl<'c, 'i, Data> From<ReadEarlyData<'c, 'i, Data>> for ConnectionState<'c, 'i, Data> {
fn from(v: ReadEarlyData<'c, 'i, Data>) -> Self {
Self::ReadEarlyData(v)
}
}
impl<'c, Data> From<EncodeTlsData<'c, Data>> for ConnectionState<'c, '_, Data> {
fn from(v: EncodeTlsData<'c, Data>) -> Self {
Self::EncodeTlsData(v)
}
}
impl<'c, Data> From<TransmitTlsData<'c, Data>> for ConnectionState<'c, '_, Data> {
fn from(v: TransmitTlsData<'c, Data>) -> Self {
Self::TransmitTlsData(v)
}
}
impl<Data> fmt::Debug for ConnectionState<'_, '_, Data> {
    /// Writes only the variant name; the borrowed connection and buffers carried
    /// by some variants are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant = match self {
            Self::ReadTraffic(_) => "ReadTraffic",
            Self::Closed => "Closed",
            Self::ReadEarlyData(_) => "ReadEarlyData",
            Self::EncodeTlsData(_) => "EncodeTlsData",
            Self::TransmitTlsData(_) => "TransmitTlsData",
            Self::BlockedHandshake => "BlockedHandshake",
            Self::WriteTraffic(_) => "WriteTraffic",
        };
        f.write_str(variant)
    }
}
/// Plaintext application data received from the peer, exposed as a single record.
///
/// The connection and the incoming TLS buffer stay mutably borrowed while this
/// value is alive; neither is otherwise accessed here.
pub struct ReadTraffic<'c, 'i, Data> {
    _conn: &'c mut UnbufferedConnectionCommon<Data>,
    _incoming_tls: &'i mut [u8],
    // Plaintext chunk popped from `received_plaintext`.
    chunk: Vec<u8>,
    // Becomes true once `next_record` has handed out `chunk`.
    taken: bool,
}

impl<'c, 'i, Data> ReadTraffic<'c, 'i, Data> {
    fn new(
        conn: &'c mut UnbufferedConnectionCommon<Data>,
        incoming_tls: &'i mut [u8],
        chunk: Vec<u8>,
    ) -> Self {
        Self {
            _conn: conn,
            _incoming_tls: incoming_tls,
            chunk,
            taken: false,
        }
    }

    /// Returns the buffered plaintext as an [`AppDataRecord`] on the first call,
    /// and `None` on every later call.
    ///
    /// The returned record's `discard` is always `0`.
    pub fn next_record(&mut self) -> Option<Result<AppDataRecord<'_>, Error>> {
        if self.taken {
            return None;
        }
        self.taken = true;
        let record = AppDataRecord {
            discard: 0,
            payload: &self.chunk,
        };
        Some(Ok(record))
    }

    /// Length of the record `next_record` would return, without consuming it.
    ///
    /// Returns `None` once the record has been taken, or when the plaintext chunk is empty.
    pub fn peek_len(&self) -> Option<NonZeroUsize> {
        (!self.taken)
            .then(|| self.chunk.len())
            .and_then(NonZeroUsize::new)
    }
}
pub struct ReadEarlyData<'c, 'i, Data> {
_conn: &'c mut UnbufferedConnectionCommon<Data>,
_incoming_tls: &'i mut [u8],
chunk: Vec<u8>,
taken: bool,
}
impl<'c, 'i, Data> ReadEarlyData<'c, 'i, Data> {
fn new(
_conn: &'c mut UnbufferedConnectionCommon<Data>,
_incoming_tls: &'i mut [u8],
chunk: Vec<u8>,
) -> Self {
Self {
_conn,
_incoming_tls,
chunk,
taken: false,
}
}
}
impl ReadEarlyData<'_, '_, ServerConnectionData> {
pub fn next_record(&mut self) -> Option<Result<AppDataRecord<'_>, Error>> {
if self.taken {
None
} else {
self.taken = true;
Some(Ok(AppDataRecord {
discard: 0,
payload: &self.chunk,
}))
}
}
pub fn peek_len(&self) -> Option<NonZeroUsize> {
if self.taken {
None
} else {
NonZeroUsize::new(self.chunk.len())
}
}
}
/// One application-data record handed out by `ReadTraffic` / `ReadEarlyData`.
pub struct AppDataRecord<'a> {
    /// Number of bytes associated with this record to discard.
    ///
    /// The readers in this module always set it to `0`.
    pub discard: usize,
    /// The record's plaintext contents.
    pub payload: &'a [u8],
}
/// Access to a connection that may currently send application data.
pub struct WriteTraffic<'c, Data> {
    conn: &'c mut UnbufferedConnectionCommon<Data>,
}

impl<Data> WriteTraffic<'_, Data> {
    /// Encrypts `application_data` into `outgoing_tls`.
    ///
    /// Calls `maybe_refresh_traffic_keys` on the connection first. The success
    /// value comes from `write_plaintext`. NOTE(review): it is presumably the
    /// number of bytes written into `outgoing_tls`; confirm.
    ///
    /// # Errors
    /// Propagates any [`EncryptError`] from `write_plaintext`.
    pub fn encrypt(
        &mut self,
        application_data: &[u8],
        outgoing_tls: &mut [u8],
    ) -> Result<usize, EncryptError> {
        let inner = &mut self.conn.core;
        inner.maybe_refresh_traffic_keys();
        inner
            .common_state
            .write_plaintext(application_data.into(), outgoing_tls)
    }

    /// Writes a close_notify into `outgoing_tls` via `eager_send_close_notify`.
    ///
    /// # Errors
    /// Propagates any [`EncryptError`] from `eager_send_close_notify`.
    pub fn queue_close_notify(&mut self, outgoing_tls: &mut [u8]) -> Result<usize, EncryptError> {
        let common = &mut self.conn.core.common_state;
        common.eager_send_close_notify(outgoing_tls)
    }

    /// Consumes this handle and delegates to the connection's `refresh_traffic_keys`.
    ///
    /// # Errors
    /// Propagates any [`Error`] returned by that call.
    pub fn refresh_traffic_keys(self) -> Result<(), Error> {
        let Self { conn } = self;
        conn.core.refresh_traffic_keys()
    }
}
/// A queued TLS message that must be copied into a caller-provided buffer.
pub struct EncodeTlsData<'c, Data> {
    conn: &'c mut UnbufferedConnectionCommon<Data>,
    // `None` once the message has been successfully encoded.
    chunk: Option<Vec<u8>>,
}

impl<'c, Data> EncodeTlsData<'c, Data> {
    fn new(conn: &'c mut UnbufferedConnectionCommon<Data>, chunk: Vec<u8>) -> Self {
        let chunk = Some(chunk);
        Self { conn, chunk }
    }

    /// Copies the queued message into the front of `outgoing_tls`.
    ///
    /// On success, returns the number of bytes written and sets the connection's
    /// `wants_write` flag.
    ///
    /// # Errors
    /// - [`EncodeError::AlreadyEncoded`] if an earlier call already wrote the message out.
    /// - [`EncodeError::InsufficientSize`] if `outgoing_tls` is shorter than the
    ///   message. The message is kept, so the call can be retried with a larger buffer.
    pub fn encode(&mut self, outgoing_tls: &mut [u8]) -> Result<usize, EncodeError> {
        let chunk = self
            .chunk
            .take()
            .ok_or(EncodeError::AlreadyEncoded)?;
        let len = chunk.len();
        match outgoing_tls.get_mut(..len) {
            Some(dest) => {
                dest.copy_from_slice(&chunk);
                self.conn.wants_write = true;
                Ok(len)
            }
            None => {
                // Put the message back so a retry with a larger buffer can succeed.
                self.chunk = Some(chunk);
                Err(EncodeError::InsufficientSize(InsufficientSizeError {
                    required_size: len,
                }))
            }
        }
    }
}
/// Encoded TLS data is pending transmission by the caller.
pub struct TransmitTlsData<'c, Data> {
    pub(crate) conn: &'c mut UnbufferedConnectionCommon<Data>,
}

impl<Data> TransmitTlsData<'_, Data> {
    /// Marks the pending TLS data as handled by clearing the connection's
    /// `wants_write` flag.
    pub fn done(self) {
        let Self { conn } = self;
        conn.wants_write = false;
    }

    /// Returns a [`WriteTraffic`] over the same connection when
    /// `may_send_application_data` is set, and `None` otherwise.
    pub fn may_encrypt_app_data(&mut self) -> Option<WriteTraffic<'_, Data>> {
        if !self.conn.core.common_state.may_send_application_data {
            return None;
        }
        Some(WriteTraffic { conn: self.conn })
    }
}
/// Errors returned by `EncodeTlsData::encode`.
#[derive(Debug)]
pub enum EncodeError {
    /// The output buffer was smaller than the message to encode.
    InsufficientSize(InsufficientSizeError),
    /// The message was already written out by an earlier call.
    AlreadyEncoded,
}

impl From<InsufficientSizeError> for EncodeError {
    fn from(err: InsufficientSizeError) -> Self {
        EncodeError::InsufficientSize(err)
    }
}

impl fmt::Display for EncodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::InsufficientSize(InsufficientSizeError { required_size }) => write!(
                f,
                "cannot encode due to insufficient size, {required_size} bytes are required"
            ),
            // `pad` is what `str`'s `Display` uses, so width/alignment flags still apply.
            Self::AlreadyEncoded => f.pad("cannot encode, data has already been encoded"),
        }
    }
}

#[cfg(feature = "std")]
impl StdError for EncodeError {}

/// Errors returned when encrypting into a caller-provided buffer.
#[derive(Debug)]
pub enum EncryptError {
    /// The output buffer was too small.
    InsufficientSize(InsufficientSizeError),
    /// The encrypter has been exhausted.
    EncryptExhausted,
}

impl From<InsufficientSizeError> for EncryptError {
    fn from(err: InsufficientSizeError) -> Self {
        EncryptError::InsufficientSize(err)
    }
}

impl fmt::Display for EncryptError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EncryptExhausted => f.write_str("encrypter has been exhausted"),
            Self::InsufficientSize(err) => write!(
                f,
                "cannot encrypt due to insufficient size, {} bytes are required",
                err.required_size
            ),
        }
    }
}

#[cfg(feature = "std")]
impl StdError for EncryptError {}

/// An output buffer was too small; carries the size that was needed.
#[derive(Clone, Copy, Debug)]
pub struct InsufficientSizeError {
    /// Number of bytes the output buffer needed to have room for.
    pub required_size: usize,
}