use super::{ContentType, RecordLayer};
use crate::aead::TlsAead;
use crate::alert::TurtlsAlert;
use crate::error::FullError;
use crate::TurtlsError;
use crylib::aead::TAG_SIZE;
/// Progress of the record-reading state machine (see `RecordLayer::get_raw`).
///
/// Each variant's `usize` payload is a byte counter whose meaning depends on
/// the phase.
enum ReadStatus {
    /// Still reading the record header; payload is the number of header bytes
    /// received so far.
    NeedsHeader(usize),
    /// Header complete; payload is the number of record-body bytes received
    /// so far.
    NeedsData(usize),
    /// A full record is buffered; payload is the number of body bytes the
    /// caller has already consumed via `read_remaining`.
    Moving(usize),
}
impl ReadStatus {
pub const fn new() -> Self {
Self::NeedsHeader(0)
}
}
/// Fixed-size buffer holding the record currently being read, plus the
/// state machine tracking how much of it has arrived and been consumed.
pub(super) struct ReadBuf {
    // Raw record bytes: header at the front, body immediately after.
    buf: [u8; RecordLayer::BUF_SIZE],
    // Length of the current record's body (excludes the header).
    len: usize,
    // Current phase of reading/consuming the record.
    status: ReadStatus,
}
impl ReadBuf {
pub(super) const fn new() -> Self {
Self {
buf: [0; RecordLayer::BUF_SIZE],
len: 0,
status: ReadStatus::new(),
}
}
}
impl RecordLayer {
pub(crate) fn msg_type(&self) -> u8 {
self.rbuf.buf[0]
}
pub(crate) fn get_raw(&mut self) -> Result<(), FullError> {
loop {
match self.rbuf.status {
ReadStatus::NeedsHeader(ref mut bytes_read) => {
while *bytes_read < Self::HEADER_SIZE {
let Some(new_bytes) = self
.io
.read(&mut self.rbuf.buf[*bytes_read..Self::HEADER_SIZE])
else {
return Err(FullError::error(TurtlsError::WantRead));
};
*bytes_read += new_bytes;
}
self.rbuf.len = u16::from_be_bytes(
self.rbuf.buf[Self::HEADER_SIZE - Self::LEN_SIZE..Self::HEADER_SIZE]
.try_into()
.unwrap(),
) as usize;
if self.rbuf.len > Self::MAX_LEN {
return Err(FullError::sending_alert(TurtlsAlert::RecordOverflow));
}
self.rbuf.status = ReadStatus::NeedsData(0);
},
ReadStatus::NeedsData(ref mut bytes_read) => {
while *bytes_read < self.rbuf.len {
let Some(new_bytes) = self.io.read(
&mut self.rbuf.buf[Self::HEADER_SIZE + *bytes_read
..Self::HEADER_SIZE + self.rbuf.len],
) else {
return Err(FullError::error(TurtlsError::WantRead));
};
*bytes_read += new_bytes;
}
self.rbuf.status = ReadStatus::Moving(0);
},
ReadStatus::Moving(bytes_moved) => {
if bytes_moved == self.rbuf.len {
self.rbuf.status = ReadStatus::NeedsHeader(0);
} else {
return Ok(());
}
},
}
}
}
pub(crate) fn get(&mut self, aead: &mut TlsAead) -> Result<(), FullError> {
self.get_raw()?;
self.deprotect(aead)
.map_err(|alert| FullError::sending_alert(alert))
}
fn deprotect(&mut self, aead: &mut TlsAead) -> Result<(), TurtlsAlert> {
if self.msg_type() != ContentType::ApplicationData.to_byte() {
return Ok(());
}
if self.rbuf.len < Self::MIN_PROT_LEN {
return Err(TurtlsAlert::DecodeError);
}
let (header, msg) = self.rbuf.buf[..Self::HEADER_SIZE + self.rbuf.len]
.split_at_mut(RecordLayer::HEADER_SIZE);
let (msg, tag) = msg.split_at_mut(msg.len() - TAG_SIZE);
let tag = (tag as &[u8]).try_into().unwrap();
if aead.decrypt_inline(msg, header, tag).is_err() {
return Err(TurtlsAlert::BadRecordMac);
}
self.rbuf.len -= TAG_SIZE;
let Some(padding) = self.rbuf.buf[Self::HEADER_SIZE..][..self.rbuf.len]
.iter()
.rev()
.position(|&x| x != 0)
else {
return Err(TurtlsAlert::UnexpectedMessage);
};
self.rbuf.len -= padding;
self.rbuf.buf[0] = self.rbuf.buf[Self::HEADER_SIZE + self.rbuf.len - 1];
self.rbuf.len -= 1;
Ok(())
}
pub(crate) fn read_remaining(&mut self, buf: &mut [u8]) -> usize {
let ReadStatus::Moving(ref mut bytes_read) = self.rbuf.status else {
panic!("The record cannot be read because it is currently being retrieved");
};
let new_bytes = std::cmp::min(self.rbuf.len - *bytes_read, buf.len());
buf[..new_bytes]
.copy_from_slice(&self.rbuf.buf[Self::HEADER_SIZE + *bytes_read..][..new_bytes]);
*bytes_read += new_bytes;
new_bytes
}
pub(crate) fn read_raw(&mut self, buf: &mut [u8]) -> Result<usize, FullError> {
self.get_raw()?;
Ok(self.read_remaining(buf))
}
pub(crate) fn read(&mut self, buf: &mut [u8], aead: &mut TlsAead) -> Result<usize, FullError> {
self.get(aead)?;
Ok(self.read_remaining(buf))
}
pub(crate) fn discard(&mut self) {
self.rbuf.status = ReadStatus::new();
}
pub(crate) fn check_alert(&self) -> Result<(), TurtlsAlert> {
if self.msg_type() == ContentType::Alert.to_byte() {
Err(TurtlsAlert::from_byte(self.rbuf.buf[Self::HEADER_SIZE + 1]))
} else {
Ok(())
}
}
}