use super::{ContentType, RecordLayer};
use crate::aead::TlsAead;
use crate::alert::{AlertMsg, TurtlsAlert};
use crate::extensions::versions::LEGACY_PROTO_VERS;
use crate::TurtlsError;
use crylib::aead::TAG_SIZE;
pub(super) struct WriteBuf {
buf: [u8; RecordLayer::BUF_SIZE],
len: usize,
record_bytes: usize,
total_bytes: usize,
}
impl WriteBuf {
pub const fn new() -> Self {
let mut buf = Self {
buf: [0; RecordLayer::BUF_SIZE],
len: 0,
record_bytes: 0,
total_bytes: 0,
};
[buf.buf[1], buf.buf[2]] = LEGACY_PROTO_VERS.to_be_bytes();
buf
}
}
impl RecordLayer {
fn set_msg_type(&mut self, msg_type: ContentType) {
self.wbuf.buf[0] = msg_type.to_byte();
}
pub(crate) fn write_raw(
&mut self,
buf: &[u8],
msg_type: ContentType,
) -> Result<(), TurtlsError> {
self.set_msg_type(msg_type);
while self.wbuf.total_bytes < buf.len() {
let record_size = std::cmp::min(buf.len() - self.wbuf.total_bytes, Self::MAX_LEN);
self.wbuf.buf[Self::HEADER_SIZE..][..record_size]
.copy_from_slice(&buf[self.wbuf.total_bytes..][..record_size]);
self.wbuf.len = record_size;
self.finish_raw()?;
}
self.wbuf.total_bytes = 0;
Ok(())
}
pub(crate) fn write(
&mut self,
buf: &[u8],
msg_type: ContentType,
aead: &mut TlsAead,
) -> Result<(), TurtlsError> {
self.set_msg_type(ContentType::ApplicationData);
for record in buf[self.wbuf.total_bytes..].chunks(Self::MAX_LEN) {
self.wbuf.buf[Self::HEADER_SIZE..][..record.len()].copy_from_slice(record);
self.wbuf.len = record.len();
self.finish(msg_type, aead)?;
}
Ok(())
}
fn encode_len(&mut self) {
let len = (self.wbuf.len as u16).to_be_bytes();
self.wbuf.buf[Self::HEADER_SIZE - Self::LEN_SIZE..Self::HEADER_SIZE].copy_from_slice(&len);
}
fn finish_raw(&mut self) -> Result<(), TurtlsError> {
self.encode_len();
self.write_record()
}
fn finish(&mut self, msg_type: ContentType, aead: &mut TlsAead) -> Result<(), TurtlsError> {
self.wbuf.buf[Self::HEADER_SIZE + self.wbuf.len] = msg_type.to_byte();
self.wbuf.len += size_of::<ContentType>();
self.wbuf.len += TAG_SIZE;
self.encode_len();
self.protect(aead);
self.write_record()
}
fn write_record(&mut self) -> Result<(), TurtlsError> {
while self.wbuf.record_bytes < self.wbuf.len + Self::HEADER_SIZE {
let bytes_written = self
.io
.write(&self.wbuf.buf[self.wbuf.record_bytes..Self::HEADER_SIZE + self.wbuf.len])
.ok_or(TurtlsError::WantWrite)?;
self.wbuf.record_bytes += bytes_written as usize;
}
self.wbuf.total_bytes += self.wbuf.record_bytes;
self.wbuf.record_bytes = 0;
Ok(())
}
fn protect(&mut self, aead: &mut TlsAead) {
let (header, msg) =
self.wbuf.buf[..Self::HEADER_SIZE + self.wbuf.len].split_at_mut(Self::HEADER_SIZE);
let (msg, tag) = msg.split_at_mut(self.wbuf.len - TAG_SIZE);
let tag: &mut [u8; TAG_SIZE] = tag.try_into().unwrap();
*tag = aead.encrypt_inline(msg, header);
}
pub(crate) fn close_raw(&mut self, alert: TurtlsAlert) {
let _ = self.write_raw(&AlertMsg::new(alert).to_be_bytes(), ContentType::Alert);
self.io.close();
}
pub(crate) fn close(&mut self, alert: TurtlsAlert, aead: &mut TlsAead) {
let _ = self.write(
&AlertMsg::new(alert).to_be_bytes(),
ContentType::Alert,
aead,
);
self.io.close();
}
}