// GatoPSKTLS 0.1.0
//
// TLS 1.3 PSK client + server (no_std, no allocator). Fork of drogue-iot/embedded-tls extended with a server-mode PSK_KE handshake, for embedded MQTT brokers and similar peers.
// Documentation
use crate::{
    TlsError,
    buffer::CryptoBuffer,
    config::{TLS_RECORD_OVERHEAD, TlsCipherSuite},
    connection::encrypt,
    key_schedule::{ReadKeySchedule, WriteKeySchedule},
    record::{ClientRecord, ClientRecordHeader},
};

/// Staging area in which a single outgoing TLS record is assembled.
///
/// The backing storage is supplied by the caller. `pos` tracks how much of it
/// is in use, and `current_header` remembers the header of a record that has
/// been started but not yet closed.
pub struct WriteBuffer<'a> {
    /// Caller-provided backing storage for the record being built.
    buffer: &'a mut [u8],
    /// Number of bytes of `buffer` currently in use; also the next write offset.
    pos: usize,
    /// Header of the record in progress, or `None` when no record is open.
    current_header: Option<ClientRecordHeader>,
}

/// Read-only view over the three fields of a [`WriteBuffer`].
///
/// Created by `WriteBuffer::reborrow` and `WriteBufferBorrowMut::reborrow`;
/// it hosts the query methods (`is_full`, `len`, `space`, ...).
pub(crate) struct WriteBufferBorrow<'a> {
    /// Shared view of the backing storage.
    buffer: &'a [u8],
    /// Shared reference to the current fill position.
    pos: &'a usize,
    /// Shared reference to the header of the record in progress, if any.
    current_header: &'a Option<ClientRecordHeader>,
}

/// Mutable view over the three fields of a [`WriteBuffer`].
///
/// Created by `WriteBuffer::reborrow_mut`; it hosts the mutating operations
/// (`append`, `start_record`, `close_record`).
pub(crate) struct WriteBufferBorrowMut<'a> {
    /// Exclusive view of the backing storage.
    buffer: &'a mut [u8],
    /// Exclusive reference to the current fill position.
    pos: &'a mut usize,
    /// Exclusive reference to the header of the record in progress, if any.
    current_header: &'a mut Option<ClientRecordHeader>,
}

impl<'a> WriteBuffer<'a> {
    pub fn new(buffer: &'a mut [u8]) -> Self {
        debug_assert!(
            buffer.len() > TLS_RECORD_OVERHEAD,
            "The write buffer must be sufficiently large to include the tls record overhead"
        );
        Self {
            buffer,
            pos: 0,
            current_header: None,
        }
    }

    pub(crate) fn reborrow_mut(&mut self) -> WriteBufferBorrowMut<'_> {
        WriteBufferBorrowMut {
            buffer: self.buffer,
            pos: &mut self.pos,
            current_header: &mut self.current_header,
        }
    }

    pub(crate) fn reborrow(&self) -> WriteBufferBorrow<'_> {
        WriteBufferBorrow {
            buffer: self.buffer,
            pos: &self.pos,
            current_header: &self.current_header,
        }
    }

    pub fn is_full(&self) -> bool {
        self.reborrow().is_full()
    }

    pub fn append(&mut self, buf: &[u8]) -> usize {
        self.reborrow_mut().append(buf)
    }

    pub fn is_empty(&self) -> bool {
        self.reborrow().is_empty()
    }

    pub fn contains(&self, header: ClientRecordHeader) -> bool {
        self.reborrow().contains(header)
    }

    pub(crate) fn start_record(&mut self, header: ClientRecordHeader) -> Result<(), TlsError> {
        self.reborrow_mut().start_record(header)
    }

    pub(crate) fn close_record<CipherSuite>(
        &mut self,
        write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
    ) -> Result<&[u8], TlsError>
    where
        CipherSuite: TlsCipherSuite,
    {
        close_record(
            self.buffer,
            &mut self.pos,
            &mut self.current_header,
            write_key_schedule,
        )
    }

    pub fn write_record<CipherSuite>(
        &mut self,
        record: &ClientRecord<CipherSuite>,
        write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
        read_key_schedule: Option<&mut ReadKeySchedule<CipherSuite>>,
    ) -> Result<&[u8], TlsError>
    where
        CipherSuite: TlsCipherSuite,
    {
        write_record(
            self.buffer,
            &mut self.pos,
            &mut self.current_header,
            record,
            write_key_schedule,
            read_key_schedule,
        )
    }
}

impl WriteBufferBorrow<'_> {
    /// Largest value `pos` is allowed to reach through `append`: the buffer
    /// length minus `TLS_RECORD_OVERHEAD`.
    ///
    /// Saturates to zero instead of underflowing when the buffer is not
    /// longer than the overhead. The size check in `WriteBuffer::new` is only
    /// a `debug_assert!`, so release builds can get here with a short buffer.
    fn max_block_size(&self) -> usize {
        self.buffer.len().saturating_sub(TLS_RECORD_OVERHEAD)
    }

    /// Returns `true` when no more bytes can be appended.
    ///
    /// Uses `>=` rather than `==`: `start_record` advances `pos` without
    /// consulting `max_block_size`, so on a buffer only slightly larger than
    /// the overhead `pos` can already be past the limit.
    pub fn is_full(&self) -> bool {
        *self.pos >= self.max_block_size()
    }

    /// Number of bytes currently held in the buffer.
    pub fn len(&self) -> usize {
        *self.pos
    }

    /// Returns `true` when the buffer holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of bytes that can still be appended.
    ///
    /// Saturates to zero when `pos` is already at or beyond `max_block_size`,
    /// so callers never see a wrapped-around (huge) free-space value.
    pub fn space(&self) -> usize {
        self.max_block_size().saturating_sub(*self.pos)
    }

    /// Returns `true` when a record is in progress and its header equals `header`.
    pub fn contains(&self, header: ClientRecordHeader) -> bool {
        self.current_header.as_ref() == Some(&header)
    }
}

impl WriteBufferBorrowMut<'_> {
    /// Downgrades this mutable view to a read-only [`WriteBufferBorrow`].
    fn reborrow(&self) -> WriteBufferBorrow<'_> {
        WriteBufferBorrow {
            buffer: &*self.buffer,
            pos: &*self.pos,
            current_header: &*self.current_header,
        }
    }

    /// Returns `true` when the fill position has reached the usable capacity.
    pub fn is_full(&self) -> bool {
        let view = self.reborrow();
        view.is_full()
    }

    /// Returns `true` when the fill position is zero.
    pub fn is_empty(&self) -> bool {
        let view = self.reborrow();
        view.is_empty()
    }

    /// Returns `true` when a record is in progress and its header equals `header`.
    pub fn contains(&self, header: ClientRecordHeader) -> bool {
        let view = self.reborrow();
        view.contains(header)
    }

    /// Copies the longest prefix of `buf` that fits into the free space and
    /// advances the fill position past it.
    ///
    /// Returns the number of bytes copied, which is zero when either `buf` or
    /// the remaining space is empty.
    pub fn append(&mut self, buf: &[u8]) -> usize {
        let count = buf.len().min(self.reborrow().space());
        if count == 0 {
            return 0;
        }

        let start = *self.pos;
        let end = start + count;
        self.buffer[start..end].copy_from_slice(&buf[..count]);
        *self.pos = end;
        count
    }

    /// Opens a new record by writing `header` and a zero length placeholder.
    pub(crate) fn start_record(&mut self, header: ClientRecordHeader) -> Result<(), TlsError> {
        start_record(
            &mut *self.buffer,
            &mut *self.pos,
            &mut *self.current_header,
            header,
        )
    }

    /// Finishes the record in progress and returns its bytes; the fill
    /// position is reset to zero on success.
    pub fn close_record<CipherSuite: TlsCipherSuite>(
        &mut self,
        write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
    ) -> Result<&[u8], TlsError> {
        close_record(
            &mut *self.buffer,
            &mut *self.pos,
            &mut *self.current_header,
            write_key_schedule,
        )
    }
}

/// Opens a new record: remembers `header` as the record in progress and
/// writes its encoding, followed by a zeroed 16-bit field, at the current
/// position.
///
/// The zeroed field appears to be the length placeholder that `close_record`
/// later overwrites at bytes 3 and 4 of the buffer.
///
/// Debug builds assert that no record is already open. Note that
/// `current_header` is set before the header is encoded, so it stays set even
/// if encoding fails.
fn start_record(
    buffer: &mut [u8],
    pos: &mut usize,
    current_header: &mut Option<ClientRecordHeader>,
    header: ClientRecordHeader,
) -> Result<(), TlsError> {
    debug_assert!(current_header.is_none());

    debug!("start_record({:?})", header);
    *current_header = Some(header);

    with_buffer(buffer, pos, |mut record_buf| {
        header.encode(&mut record_buf)?;
        // Length is not known yet; reserve the two bytes with zeros.
        record_buf.push_u16(0)?;
        Ok(record_buf.rewind())
    })
}

/// Runs `op` on a `CryptoBuffer` wrapping `buffer` at the current position
/// and, if it succeeds, stores the returned buffer's length back into `pos`.
///
/// When `op` fails the error is passed through and `pos` is left unchanged
/// (bytes `op` may already have written into `buffer` are not undone).
fn with_buffer(
    buffer: &mut [u8],
    pos: &mut usize,
    op: impl FnOnce(CryptoBuffer) -> Result<CryptoBuffer, TlsError>,
) -> Result<(), TlsError> {
    let filled = op(CryptoBuffer::wrap_with_pos(buffer, *pos))?;
    *pos = filled.len();
    Ok(())
}

/// Finishes the record in progress and returns its complete bytes.
///
/// For encrypted headers the trailer content type byte is appended and
/// everything after the first `HEADER_SIZE` bytes is passed to `encrypt`.
/// The record length (total size minus `HEADER_SIZE`) is then written
/// big-endian into bytes 3 and 4 of the buffer. On success `pos` is reset to
/// zero and no record is left open.
///
/// # Errors
///
/// * `TlsError::InternalError` if no record is in progress (previously a panic).
/// * `TlsError::EncodeError` if the trailer byte cannot be pushed, or if the
///   record length does not fit the 16-bit length field (previously an
///   arithmetic underflow or a silently truncated length).
/// * Any error returned by `encrypt`.
///
/// On error the open header has already been cleared and `pos` is not reset.
fn close_record<'a, CipherSuite>(
    buffer: &'a mut [u8],
    pos: &mut usize,
    current_header: &mut Option<ClientRecordHeader>,
    write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
) -> Result<&'a [u8], TlsError>
where
    CipherSuite: TlsCipherSuite,
{
    const HEADER_SIZE: usize = 5;

    // `take()` also leaves `current_header` as `None`, marking the record closed.
    let header = current_header.take().ok_or(TlsError::InternalError)?;
    with_buffer(buffer, pos, |mut buf| {
        if !header.is_encrypted() {
            return Ok(buf);
        }

        buf.push(header.trailer_content_type() as u8)
            .map_err(|_| TlsError::EncodeError)?;

        let mut buf = buf.offset(HEADER_SIZE);
        encrypt(write_key_schedule, &mut buf)?;
        Ok(buf.rewind())
    })?;

    // Refuse to emit a record whose length field would be wrong: `pos` must
    // cover the header, and the remainder must fit in 16 bits.
    let payload_len = (*pos)
        .checked_sub(HEADER_SIZE)
        .filter(|len| *len <= usize::from(u16::MAX))
        .ok_or(TlsError::EncodeError)?;
    // Lossless: the range was checked just above.
    let [upper, lower] = (payload_len as u16).to_be_bytes();

    buffer[3] = upper;
    buffer[4] = lower;

    let slice = &buffer[..*pos];

    *pos = 0;

    Ok(slice)
}

/// Encodes `record` as one complete record in `buffer` and returns its bytes.
///
/// Steps, in order: write the record header, encode the payload, let the
/// record finalise itself against the transcript hash taken from
/// `read_key_schedule`, then close the record.
///
/// # Errors
///
/// * `TlsError::InternalError` if a record is already in progress, or if
///   `read_key_schedule` is `None` (checked only after the payload has been
///   encoded).
/// * Any error from encoding, finishing or closing the record.
fn write_record<'a, CipherSuite: TlsCipherSuite>(
    buffer: &'a mut [u8],
    pos: &mut usize,
    current_header: &mut Option<ClientRecordHeader>,
    record: &ClientRecord<CipherSuite>,
    write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
    read_key_schedule: Option<&mut ReadKeySchedule<CipherSuite>>,
) -> Result<&'a [u8], TlsError> {
    // Only one record may be under construction at a time.
    if current_header.is_some() {
        return Err(TlsError::InternalError);
    }

    start_record(buffer, pos, current_header, record.header())?;

    with_buffer(buffer, pos, |record_buf| {
        let mut payload = record_buf.forward();
        record.encode_payload(&mut payload)?;

        let schedule = read_key_schedule.ok_or(TlsError::InternalError)?;
        let transcript = schedule.transcript_hash();
        record.finish_record(&mut payload, transcript, write_key_schedule)?;

        Ok(payload.rewind())
    })?;

    close_record(buffer, pos, current_header, write_key_schedule)
}