use alloc::vec::Vec;
use core::num::NonZeroUsize;
use core::{fmt, mem};
#[cfg(feature = "std")]
use std::error::Error as StdError;
use super::UnbufferedConnectionCommon;
use crate::Error;
use crate::client::ClientConnectionData;
use crate::msgs::deframer::buffers::DeframerSliceBuffer;
use crate::server::ServerConnectionData;
impl UnbufferedConnectionCommon<ClientConnectionData> {
/// Processes the TLS records in `incoming_tls` and reports the next connection state.
///
/// Client-side entry point: delegates to `process_tls_records_common` with an
/// early-data check that always answers `false`.
pub fn process_tls_records<'c, 'i>(
&'c mut self,
incoming_tls: &'i mut [u8],
) -> UnbufferedStatus<'c, 'i, ClientConnectionData> {
// The early-data state constructor is only invoked after the early-data check
// returned `true`, which never happens here, hence `unreachable!()`.
self.process_tls_records_common(incoming_tls, |_| false, |_, _| unreachable!())
}
}
impl UnbufferedConnectionCommon<ServerConnectionData> {
pub fn process_tls_records<'c, 'i>(
&'c mut self,
incoming_tls: &'i mut [u8],
) -> UnbufferedStatus<'c, 'i, ServerConnectionData> {
self.process_tls_records_common(
incoming_tls,
|conn| conn.peek_early_data().is_some(),
|conn, incoming_tls| ReadEarlyData::new(conn, incoming_tls).into(),
)
}
}
impl<Data> UnbufferedConnectionCommon<Data> {
/// Shared driver behind the client and server `process_tls_records` methods.
///
/// Loops until it can report a [`ConnectionState`], checking in this priority order:
/// early data (as decided by `early_data_available`), buffered received plaintext,
/// queued outgoing TLS chunks, and only then deframing and processing the next
/// message out of `incoming_tls`. When deframing yields no message, the state is
/// chosen from `TransmitTlsData`, `PeerClosed`, `Closed`, `WriteTraffic` and
/// `BlockedHandshake` based on the connection's flags.
///
/// * `incoming_tls` - caller-provided buffer that is wrapped in a
///   `DeframerSliceBuffer` and handed to the deframer.
/// * `early_data_available` - polled at the top of every loop iteration.
/// * `early_data_state` - called at most once, and only after
///   `early_data_available` returned `true`.
///
/// Returns an [`UnbufferedStatus`] whose `discard` is the buffer's
/// `pending_discard()` and whose `state` is either the selected state or the
/// error produced by deframing, by message processing, or already stored in
/// `self.core.state`.
fn process_tls_records_common<'c, 'i>(
&'c mut self,
incoming_tls: &'i mut [u8],
mut early_data_available: impl FnMut(&mut Self) -> bool,
early_data_state: impl FnOnce(&'c mut Self, &'i mut [u8]) -> ConnectionState<'c, 'i, Data>,
) -> UnbufferedStatus<'c, 'i, Data> {
let mut buffer = DeframerSliceBuffer::new(incoming_tls);
// `buffer_progress` is passed to `deframe` below; its `take_discard()` value is
// what gets queued on `buffer` after each deframing attempt.
let mut buffer_progress = self.core.hs_deframer.progress();
let (discard, state) = loop {
// Highest priority: early data, if the caller-supplied check reports any.
if early_data_available(self) {
break (
buffer.pending_discard(),
early_data_state(self, incoming_tls),
);
}
// Next: plaintext that is already queued in `received_plaintext`.
if !self
.core
.common_state
.received_plaintext
.is_empty()
{
break (
buffer.pending_discard(),
ReadTraffic::new(self, incoming_tls).into(),
);
}
// Next: one chunk popped from `sendable_tls` that the caller must encode.
if let Some(chunk) = self
.core
.common_state
.sendable_tls
.pop()
{
break (
buffer.pending_discard(),
EncodeTlsData::new(self, chunk).into(),
);
}
// Once close_notify has been received, no further input is deframed.
let deframer_output = if self
.core
.common_state
.has_received_close_notify
{
None
} else {
match self
.core
.deframe(None, buffer.filled_mut(), &mut buffer_progress)
{
Err(err) => {
// Queue the discard accumulated so far, then report the deframing error.
buffer.queue_discard(buffer_progress.take_discard());
return UnbufferedStatus {
discard: buffer.pending_discard(),
state: Err(err),
};
}
Ok(r) => r,
}
};
if let Some(msg) = deframer_output {
// Move the state out of `self.core.state`, leaving `HandshakeNotComplete`
// as a placeholder until a definitive value is written back below.
let mut state =
match mem::replace(&mut self.core.state, Err(Error::HandshakeNotComplete)) {
Ok(state) => state,
Err(e) => {
// The connection was already in an error state: restore that
// error and report it again.
buffer.queue_discard(buffer_progress.take_discard());
self.core.state = Err(e.clone());
return UnbufferedStatus {
discard: buffer.pending_discard(),
state: Err(e),
};
}
};
match self.core.process_msg(msg, state, None) {
Ok(new) => state = new,
Err(e) => {
// Store the error so later calls take the `Err` arm above.
buffer.queue_discard(buffer_progress.take_discard());
self.core.state = Err(e.clone());
return UnbufferedStatus {
discard: buffer.pending_discard(),
state: Err(e),
};
}
}
// Message handled: queue its discard, store the new state and loop again
// so the priority checks at the top are re-evaluated.
buffer.queue_discard(buffer_progress.take_discard());
self.core.state = Ok(state);
} else if self.wants_write {
// `wants_write` is set by `EncodeTlsData::encode` and cleared by
// `TransmitTlsData::done`.
break (
buffer.pending_discard(),
TransmitTlsData { conn: self }.into(),
);
} else if self
.core
.common_state
.has_received_close_notify
&& !self.emitted_peer_closed_state
{
// `PeerClosed` is reported only once; the flag suppresses repeats.
self.emitted_peer_closed_state = true;
break (buffer.pending_discard(), ConnectionState::PeerClosed);
} else if self
.core
.common_state
.has_received_close_notify
&& self
.core
.common_state
.has_sent_close_notify
{
// close_notify has been both received and sent.
break (buffer.pending_discard(), ConnectionState::Closed);
} else if self
.core
.common_state
.may_send_application_data
{
break (
buffer.pending_discard(),
ConnectionState::WriteTraffic(WriteTraffic { conn: self }),
);
} else {
// Nothing to read, encode or transmit, and application data may not be sent yet.
break (buffer.pending_discard(), ConnectionState::BlockedHandshake);
}
};
UnbufferedStatus {
discard,
state: Ok(state),
}
}
}
/// Outcome of a `process_tls_records` call: a discard count plus either the next
/// connection state or an error.
#[must_use]
#[derive(Debug)]
pub struct UnbufferedStatus<'c, 'i, Data> {
/// Value of the deframer buffer's `pending_discard()` when the call returned.
///
/// NOTE(review): presumably the number of leading bytes of `incoming_tls` the
/// caller should drop before the next call — confirm against the buffer type.
pub discard: usize,
/// The state selected by `process_tls_records`, or the error that stopped processing.
pub state: Result<ConnectionState<'c, 'i, Data>, Error>,
}
/// The states `process_tls_records` can report for an unbuffered connection.
#[non_exhaustive]
pub enum ConnectionState<'c, 'i, Data> {
    /// Received plaintext is queued and can be read through [`ReadTraffic`].
    ReadTraffic(ReadTraffic<'c, 'i, Data>),

    /// The peer's close_notify has been received; reported once per connection.
    PeerClosed,

    /// close_notify has been both received and sent.
    Closed,

    /// Early data is available and can be read through [`ReadEarlyData`].
    ReadEarlyData(ReadEarlyData<'c, 'i, Data>),

    /// A chunk of outgoing TLS data is waiting to be copied out via [`EncodeTlsData`].
    EncodeTlsData(EncodeTlsData<'c, Data>),

    /// Encoded TLS data is pending; call [`TransmitTlsData::done`] to clear the flag.
    TransmitTlsData(TransmitTlsData<'c, Data>),

    /// Nothing to read, encode or transmit, and application data may not be sent yet.
    BlockedHandshake,

    /// Application data may be sent through [`WriteTraffic`].
    WriteTraffic(WriteTraffic<'c, Data>),
}
impl<'c, 'i, Data> From<ReadTraffic<'c, 'i, Data>> for ConnectionState<'c, 'i, Data> {
fn from(v: ReadTraffic<'c, 'i, Data>) -> Self {
Self::ReadTraffic(v)
}
}
impl<'c, 'i, Data> From<ReadEarlyData<'c, 'i, Data>> for ConnectionState<'c, 'i, Data> {
fn from(v: ReadEarlyData<'c, 'i, Data>) -> Self {
Self::ReadEarlyData(v)
}
}
impl<'c, Data> From<EncodeTlsData<'c, Data>> for ConnectionState<'c, '_, Data> {
fn from(v: EncodeTlsData<'c, Data>) -> Self {
Self::EncodeTlsData(v)
}
}
impl<'c, Data> From<TransmitTlsData<'c, Data>> for ConnectionState<'c, '_, Data> {
fn from(v: TransmitTlsData<'c, Data>) -> Self {
Self::TransmitTlsData(v)
}
}
/// Debug output is the bare variant name; payloads are never printed.
impl<Data> fmt::Debug for ConnectionState<'_, '_, Data> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant_name = match self {
            Self::ReadTraffic(..) => "ReadTraffic",
            Self::PeerClosed => "PeerClosed",
            Self::Closed => "Closed",
            Self::ReadEarlyData(..) => "ReadEarlyData",
            Self::EncodeTlsData(..) => "EncodeTlsData",
            Self::TransmitTlsData(..) => "TransmitTlsData",
            Self::BlockedHandshake => "BlockedHandshake",
            Self::WriteTraffic(..) => "WriteTraffic",
        };
        // Written verbatim: no padding and no field list, in plain and `{:#?}` mode alike.
        f.write_str(variant_name)
    }
}
/// State handed out when received plaintext is queued and ready to be read.
pub struct ReadTraffic<'c, 'i, Data> {
// Connection whose `received_plaintext` queue is consumed by `next_record`.
conn: &'c mut UnbufferedConnectionCommon<Data>,
// Stored by `new` but never read by the methods visible here.
_incoming_tls: &'i mut [u8],
// Owns the record most recently popped by `next_record`, so that the returned
// `AppDataRecord` can borrow its payload from it.
chunk: Option<Vec<u8>>,
}
impl<'c, 'i, Data> ReadTraffic<'c, 'i, Data> {
    /// Creates the state with no record popped yet.
    fn new(conn: &'c mut UnbufferedConnectionCommon<Data>, _incoming_tls: &'i mut [u8]) -> Self {
        let chunk = None;
        Self {
            conn,
            _incoming_tls,
            chunk,
        }
    }

    /// Pops the next queued plaintext record and lends out its payload.
    ///
    /// Returns `None` once the queue is empty. The previously returned record
    /// is dropped when this is called again. `discard` is always `0`.
    pub fn next_record(&mut self) -> Option<Result<AppDataRecord<'_>, Error>> {
        let plaintext_queue = &mut self.conn.core.common_state.received_plaintext;
        // Keep the popped bytes alive in `self.chunk` so the record can borrow them.
        self.chunk = plaintext_queue.pop();
        let payload = self.chunk.as_deref()?;
        Some(Ok(AppDataRecord {
            discard: 0,
            payload,
        }))
    }

    /// Length of the next queued record, without removing it.
    ///
    /// Returns `None` when the queue is empty or the next record has length zero.
    pub fn peek_len(&self) -> Option<NonZeroUsize> {
        let upcoming = self.conn.core.common_state.received_plaintext.peek()?;
        NonZeroUsize::new(upcoming.len())
    }
}
/// State handed out when early data is available to be read.
pub struct ReadEarlyData<'c, 'i, Data> {
// Connection whose early data is consumed by `next_record`.
conn: &'c mut UnbufferedConnectionCommon<Data>,
// Stored by `new` but never read by the methods visible here.
_incoming_tls: &'i mut [u8],
// Owns the early-data chunk most recently popped by `next_record`, so that the
// returned `AppDataRecord` can borrow its payload from it.
chunk: Option<Vec<u8>>,
}
impl<'c, 'i> ReadEarlyData<'c, 'i, ServerConnectionData> {
    /// Creates the state with no early-data chunk popped yet.
    fn new(
        conn: &'c mut UnbufferedConnectionCommon<ServerConnectionData>,
        _incoming_tls: &'i mut [u8],
    ) -> Self {
        let chunk = None;
        Self {
            conn,
            _incoming_tls,
            chunk,
        }
    }

    /// Pops the next early-data chunk and lends out its payload.
    ///
    /// Returns `None` once no early data is left. The previously returned record
    /// is dropped when this is called again. `discard` is always `0`.
    pub fn next_record(&mut self) -> Option<Result<AppDataRecord<'_>, Error>> {
        // Keep the popped bytes alive in `self.chunk` so the record can borrow them.
        self.chunk = self.conn.pop_early_data();
        let payload = self.chunk.as_deref()?;
        Some(Ok(AppDataRecord {
            discard: 0,
            payload,
        }))
    }

    /// Length of the next early-data chunk, without removing it.
    ///
    /// Returns `None` when there is no early data or the next chunk has length zero.
    pub fn peek_len(&self) -> Option<NonZeroUsize> {
        let upcoming = self.conn.peek_early_data()?;
        NonZeroUsize::new(upcoming.len())
    }
}
/// One record of application data (or early data) lent out by `next_record`.
pub struct AppDataRecord<'i> {
/// Every constructor visible in this file sets this to `0`.
///
/// NOTE(review): presumably a byte count to discard from the incoming buffer
/// after consuming the record — confirm the intended contract.
pub discard: usize,
/// The record's bytes, borrowed from the state that produced this record.
pub payload: &'i [u8],
}
/// State handed out when application data may be sent on the connection.
pub struct WriteTraffic<'c, Data> {
// Connection that the encrypt / close_notify / key-refresh calls are forwarded to.
conn: &'c mut UnbufferedConnectionCommon<Data>,
}
impl<Data> WriteTraffic<'_, Data> {
    /// Encrypts `application_data` into `outgoing_tls`.
    ///
    /// Calls `maybe_refresh_traffic_keys` on the connection core first, then
    /// forwards to `write_plaintext` and returns its result unchanged
    /// (presumably the number of bytes written into `outgoing_tls` — confirm).
    pub fn encrypt(
        &mut self,
        application_data: &[u8],
        outgoing_tls: &mut [u8],
    ) -> Result<usize, EncryptError> {
        let core = &mut self.conn.core;
        core.maybe_refresh_traffic_keys();
        core.common_state
            .write_plaintext(application_data.into(), outgoing_tls)
    }

    /// Forwards to `eager_send_close_notify` with `outgoing_tls` as the destination
    /// and returns its result unchanged.
    pub fn queue_close_notify(&mut self, outgoing_tls: &mut [u8]) -> Result<usize, EncryptError> {
        let common_state = &mut self.conn.core.common_state;
        common_state.eager_send_close_notify(outgoing_tls)
    }

    /// Consumes this state and forwards to `refresh_traffic_keys` on the connection core.
    pub fn refresh_traffic_keys(self) -> Result<(), Error> {
        let core = &mut self.conn.core;
        core.refresh_traffic_keys()
    }
}
/// State handed out when a chunk of outgoing TLS data is waiting to be copied out.
pub struct EncodeTlsData<'c, Data> {
// Connection whose `wants_write` flag is set once the chunk has been encoded.
conn: &'c mut UnbufferedConnectionCommon<Data>,
// The pending bytes; becomes `None` after a successful `encode`.
chunk: Option<Vec<u8>>,
}
impl<'c, Data> EncodeTlsData<'c, Data> {
    /// Creates the state holding `chunk` as the bytes still to be encoded.
    fn new(conn: &'c mut UnbufferedConnectionCommon<Data>, chunk: Vec<u8>) -> Self {
        let chunk = Some(chunk);
        Self { conn, chunk }
    }

    /// Copies the pending chunk to the front of `outgoing_tls`.
    ///
    /// Returns the number of bytes written and sets the connection's `wants_write` flag.
    ///
    /// # Errors
    ///
    /// * [`EncodeError::AlreadyEncoded`] if the chunk was already written by an
    ///   earlier successful call.
    /// * [`EncodeError::InsufficientSize`] if `outgoing_tls` is shorter than the
    ///   chunk; the chunk is kept so the call can be retried with a larger buffer,
    ///   and `outgoing_tls` is left untouched.
    pub fn encode(&mut self, outgoing_tls: &mut [u8]) -> Result<usize, EncodeError> {
        let pending = match self.chunk.take() {
            Some(bytes) => bytes,
            None => return Err(EncodeError::AlreadyEncoded),
        };
        let required_size = pending.len();

        // `get_mut(..n)` is `None` exactly when the destination is shorter than `n`.
        match outgoing_tls.get_mut(..required_size) {
            Some(destination) => {
                destination.copy_from_slice(&pending);
                self.conn.wants_write = true;
                Ok(required_size)
            }
            None => {
                // Put the chunk back so a retry can succeed.
                self.chunk = Some(pending);
                Err(EncodeError::from(InsufficientSizeError { required_size }))
            }
        }
    }
}
/// State handed out while the connection's `wants_write` flag is set.
pub struct TransmitTlsData<'c, Data> {
// Connection whose `wants_write` flag is cleared by `done`.
pub(crate) conn: &'c mut UnbufferedConnectionCommon<Data>,
}
impl<Data> TransmitTlsData<'_, Data> {
    /// Consumes this state and clears the connection's `wants_write` flag.
    pub fn done(self) {
        self.conn.wants_write = false;
    }

    /// Returns a [`WriteTraffic`] handle if application data may already be sent.
    ///
    /// Yields `None` while the connection's `may_send_application_data` flag is unset.
    pub fn may_encrypt_app_data(&mut self) -> Option<WriteTraffic<'_, Data>> {
        let can_send = self.conn.core.common_state.may_send_application_data;
        if !can_send {
            return None;
        }
        Some(WriteTraffic {
            conn: &mut *self.conn,
        })
    }
}
/// Errors returned by `EncodeTlsData::encode`.
#[derive(Debug)]
pub enum EncodeError {
/// The output buffer is shorter than the pending chunk; carries the required size.
InsufficientSize(InsufficientSizeError),
/// The chunk was already written out by an earlier successful `encode` call.
AlreadyEncoded,
}
impl From<InsufficientSizeError> for EncodeError {
fn from(v: InsufficientSizeError) -> Self {
Self::InsufficientSize(v)
}
}
impl fmt::Display for EncodeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InsufficientSize(InsufficientSizeError { required_size }) => write!(
f,
"cannot encode due to insufficient size, {required_size} bytes are required"
),
Self::AlreadyEncoded => "cannot encode, data has already been encoded".fmt(f),
}
}
}
// Only compiled with the `std` feature; relies entirely on the trait's default methods.
#[cfg(feature = "std")]
impl StdError for EncodeError {}
/// Errors returned by the encrypting methods of `WriteTraffic`.
#[derive(Debug)]
pub enum EncryptError {
/// The output buffer is too small; carries the required size.
InsufficientSize(InsufficientSizeError),
/// Displayed as "encrypter has been exhausted".
///
/// NOTE(review): not constructed in this file — confirm the exact trigger
/// where `write_plaintext` / `eager_send_close_notify` are defined.
EncryptExhausted,
}
impl From<InsufficientSizeError> for EncryptError {
fn from(v: InsufficientSizeError) -> Self {
Self::InsufficientSize(v)
}
}
impl fmt::Display for EncryptError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InsufficientSize(InsufficientSizeError { required_size }) => write!(
f,
"cannot encrypt due to insufficient size, {required_size} bytes are required"
),
Self::EncryptExhausted => f.write_str("encrypter has been exhausted"),
}
}
}
// Only compiled with the `std` feature; relies entirely on the trait's default methods.
#[cfg(feature = "std")]
impl StdError for EncryptError {}
/// Reports that a caller-provided output buffer was too small.
#[derive(Clone, Copy, Debug)]
pub struct InsufficientSizeError {
/// Number of bytes the operation requires; `EncodeTlsData::encode` sets this to
/// the length of the pending chunk.
pub required_size: usize,
}