use core::convert::Infallible;
use gatopsktls::TlsError;
use gatopsktls::server::{TlsServerConfig, TlsServerSession};
use crate::transport::Transport;
use crate::transport::tls::TlsSession;
// Size in bytes of the record header that `fetch_record` reads before the body.
// Byte 0 is used as the content type and bytes 3..5 as the big-endian body length;
// bytes 1..3 are not inspected anywhere in this file.
const RECORD_HEADER_LEN: usize = 5;
// Record content-type values matched against byte 0 of a fetched record.
const TLS_CT_CHANGE_CIPHER_SPEC: u8 = 20;
const TLS_CT_ALERT: u8 = 21;
const TLS_CT_HANDSHAKE: u8 = 22;
const TLS_CT_APPLICATION_DATA: u8 = 23;
// Alert level / description values. In this file they are only referenced from
// `close`, which does not yet put them on the wire.
const ALERT_LEVEL_WARNING: u8 = 1;
const ALERT_DESCRIPTION_CLOSE_NOTIFY: u8 = 0;
/// Errors produced by [`EmbeddedTlsPskSession`].
///
/// `TransportErr` is the error type of the underlying transport.
#[derive(Debug)]
pub enum SessionError<TransportErr> {
/// The underlying transport reported an error during a read or write.
Transport(TransportErr),
/// The TLS engine rejected a handshake message or a record.
Tls(TlsError),
/// The transport returned zero bytes while a record was still being read.
Eof,
/// A record needs `advertised` bytes of buffer but only `capacity` are available.
RecordTooLarge { advertised: usize, capacity: usize },
/// A record arrived whose content-type byte is not accepted in the current state.
UnexpectedRecordType(u8),
/// `read` or `write` was called before the handshake completed.
HandshakeIncomplete,
/// The session has been closed, either locally or because an alert record arrived.
Closed,
}
impl<E> SessionError<E> {
    /// Wraps a TLS-engine error in [`SessionError::Tls`].
    ///
    /// Kept as a named function so it can be handed straight to `map_err`.
    /// No bound is placed on `E`: the conversion never touches the transport
    /// error type, so requiring `Debug` here would only restrict callers.
    fn from_tls(error: TlsError) -> Self {
        Self::Tls(error)
    }
}
/// Pre-shared-key credentials borrowed for the lifetime `'cfg`.
///
/// `Debug` is implemented by hand so that the key material in `secret` is
/// never written to logs or panic messages; only `identity` is printed.
#[derive(Clone, Copy)]
pub struct PskConfig<'cfg> {
    /// PSK identity bytes.
    pub identity: &'cfg [u8],
    /// PSK secret bytes. Redacted from `Debug` output.
    pub secret: &'cfg [u8],
}

impl core::fmt::Debug for PskConfig<'_> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("PskConfig")
            .field("identity", &self.identity)
            .field("secret", &format_args!("<redacted>"))
            .finish()
    }
}
/// Server-side TLS-PSK session layered over a transport `T`.
///
/// `BUF` is the size in bytes of each of the three internal buffers
/// (default 4096); a record whose header plus body exceeds it is rejected
/// with `SessionError::RecordTooLarge`.
pub struct EmbeddedTlsPskSession<'cfg, T, const BUF: usize = 4096> {
// Underlying byte transport that records are read from and written to.
transport: T,
// TLS engine that processes handshake messages and encrypts/decrypts records.
inner: TlsServerSession,
// PSK identity and secret passed to the engine during the handshake.
config: PskConfig<'cfg>,
// Server random passed to the engine during the handshake.
server_random: [u8; 32],
// Most recently fetched record: header followed by body.
record_buf: [u8; BUF],
// Number of valid bytes in `record_buf` (header + body).
record_len: usize,
// Scratch space for outgoing data: the handshake flight and encrypted records.
out_buf: [u8; BUF],
// Decrypted application data waiting to be handed out by `read`.
plain_buf: [u8; BUF],
// Number of valid plaintext bytes in `plain_buf`.
plain_len: usize,
// Index of the next unread byte in `plain_buf`.
plain_offset: usize,
// Set once `do_accept` has completed successfully.
handshake_done: bool,
// Set by `close` and when an alert record is received.
closed: bool,
}
impl<'cfg, T, const BUF: usize> EmbeddedTlsPskSession<'cfg, T, BUF>
where
T: Transport,
{
/// Creates a session that has not yet performed its handshake.
///
/// All buffers start zeroed and empty; nothing is read from or written to
/// `transport` until `accept` is called.
pub fn new(transport: T, config: PskConfig<'cfg>, server_random: [u8; 32]) -> Self {
    Self {
        // Collaborators and handshake inputs.
        transport,
        inner: TlsServerSession::new(),
        config,
        server_random,
        // Incoming record storage.
        record_buf: [0; BUF],
        record_len: 0,
        // Outgoing scratch space.
        out_buf: [0; BUF],
        // Decrypted data and the read cursor into it.
        plain_buf: [0; BUF],
        plain_len: 0,
        plain_offset: 0,
        // Lifecycle flags.
        handshake_done: false,
        closed: false,
    }
}
/// Reads exactly `len` bytes from the transport into `record_buf[start..start + len]`.
///
/// # Errors
/// * `RecordTooLarge` if the requested range does not fit in `record_buf`.
/// * `Transport` if the transport read fails.
/// * `Eof` if the transport returns zero bytes before the range is filled.
async fn read_exact_into(
    &mut self,
    start: usize,
    len: usize,
) -> Result<(), SessionError<T::Error>> {
    let capacity = self.record_buf.len();
    let end = start + len;
    if end > capacity {
        return Err(SessionError::RecordTooLarge {
            advertised: end,
            capacity,
        });
    }

    let mut cursor = start;
    while cursor < end {
        // Only offer the still-missing tail of the range to the transport.
        let window = &mut self.record_buf[cursor..end];
        let received = self
            .transport
            .read(window)
            .await
            .map_err(SessionError::Transport)?;
        match received {
            0 => return Err(SessionError::Eof),
            got => cursor += got,
        }
    }
    Ok(())
}
/// Reads one complete record (header, then body) into `record_buf`.
///
/// The body length is taken from header bytes 3 and 4 (big-endian). On
/// success `record_len` is set to header length plus body length.
///
/// # Errors
/// `RecordTooLarge` if header plus advertised body exceeds the buffer, plus
/// anything `read_exact_into` returns.
async fn fetch_record(&mut self) -> Result<(), SessionError<T::Error>> {
    self.read_exact_into(0, RECORD_HEADER_LEN).await?;

    let length_field = [self.record_buf[3], self.record_buf[4]];
    let body_len = usize::from(u16::from_be_bytes(length_field));
    let advertised = RECORD_HEADER_LEN + body_len;
    let capacity = self.record_buf.len();
    if advertised > capacity {
        return Err(SessionError::RecordTooLarge {
            advertised,
            capacity,
        });
    }

    self.read_exact_into(RECORD_HEADER_LEN, body_len).await?;
    self.record_len = advertised;
    Ok(())
}
/// Returns the content-type byte (byte 0) of the record currently held in `record_buf`.
fn record_type(&self) -> u8 {
self.record_buf[0]
}
async fn do_accept(&mut self) -> Result<(), SessionError<T::Error>> {
self.fetch_record().await?;
if self.record_type() != TLS_CT_HANDSHAKE {
return Err(SessionError::UnexpectedRecordType(self.record_type()));
}
let ch_handshake_len = self.record_len - RECORD_HEADER_LEN;
let flight_len = {
let ch_handshake =
&self.record_buf[RECORD_HEADER_LEN..RECORD_HEADER_LEN + ch_handshake_len];
let cfg = TlsServerConfig {
psk: (self.config.identity, self.config.secret),
server_random: self.server_random,
};
self.inner
.process_client_hello(ch_handshake, &cfg, &mut self.out_buf)
.map_err(SessionError::from_tls)?
.len()
};
self.transport
.write(&self.out_buf[..flight_len])
.await
.map_err(SessionError::Transport)?;
loop {
self.fetch_record().await?;
match self.record_type() {
TLS_CT_CHANGE_CIPHER_SPEC => continue,
TLS_CT_APPLICATION_DATA => break,
other => return Err(SessionError::UnexpectedRecordType(other)),
}
}
let record_len = self.record_len;
self.inner
.process_client_finished(&self.record_buf[..record_len])
.map_err(SessionError::from_tls)?;
self.handshake_done = true;
Ok(())
}
/// Fetches the next application-data record and decrypts it into `plain_buf`.
///
/// Change-cipher-spec records are skipped. On success `plain_len` holds the
/// number of decrypted bytes and `plain_offset` is reset to zero.
///
/// # Errors
/// * `Closed` if an alert record arrives. The session is marked closed and
///   the transport is closed at that point.
/// * `UnexpectedRecordType` for any other content type.
/// * `Tls` if decryption fails.
/// * Whatever `fetch_record` returns.
async fn refill_plaintext(&mut self) -> Result<(), SessionError<T::Error>> {
    loop {
        self.fetch_record().await?;
        match self.record_type() {
            TLS_CT_CHANGE_CIPHER_SPEC => continue,
            TLS_CT_APPLICATION_DATA => break,
            TLS_CT_ALERT => {
                // `close` returns early once `closed` is set, so the transport
                // must be released here; otherwise a peer-initiated shutdown
                // would leave it open for good.
                self.closed = true;
                self.transport.close().await;
                return Err(SessionError::Closed);
            }
            other => return Err(SessionError::UnexpectedRecordType(other)),
        }
    }

    // The engine is given the whole record, header included.
    let record_len = self.record_len;
    let len = self
        .inner
        .decrypt_app_data(&self.record_buf[..record_len], &mut self.plain_buf)
        .map_err(SessionError::from_tls)?
        .len();
    self.plain_len = len;
    self.plain_offset = 0;
    Ok(())
}
}
impl<'cfg, T, const BUF: usize> TlsSession for EmbeddedTlsPskSession<'cfg, T, BUF>
where
T: Transport,
{
type Error = SessionError<T::Error>;
/// Performs the handshake unless it has already completed.
///
/// Calling this again after a successful handshake is a no-op returning `Ok(())`.
async fn accept(&mut self) -> Result<(), Self::Error> {
    if !self.handshake_done {
        self.do_accept().await?;
    }
    Ok(())
}
/// Copies decrypted application data into `buf`, returning the byte count.
///
/// Plaintext left over from a previous record is served first; otherwise
/// records are fetched and decrypted until one yields at least one byte.
///
/// Returns `Ok(0)` only when the session is closed, an alert arrives, or the
/// transport reaches end-of-stream.
///
/// # Errors
/// `HandshakeIncomplete` if called before `accept` succeeded, otherwise any
/// error from fetching or decrypting a record.
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
    if !self.handshake_done {
        return Err(SessionError::HandshakeIncomplete);
    }
    if self.closed {
        return Ok(0);
    }

    // Loop instead of refilling once: a record may legitimately decrypt to
    // zero bytes, and returning `Ok(0)` for it would look like end-of-stream
    // to the caller.
    while self.plain_offset >= self.plain_len {
        match self.refill_plaintext().await {
            Ok(()) => {}
            Err(SessionError::Closed) | Err(SessionError::Eof) => return Ok(0),
            Err(other) => return Err(other),
        }
    }

    let pending = &self.plain_buf[self.plain_offset..self.plain_len];
    let n = core::cmp::min(buf.len(), pending.len());
    buf[..n].copy_from_slice(&pending[..n]);
    self.plain_offset += n;
    Ok(n)
}
/// Encrypts a prefix of `buf` as one record and writes it to the transport.
///
/// Returns how many bytes of `buf` were consumed, which may be fewer than
/// `buf.len()` when the data does not fit in a single record; callers should
/// loop. An empty `buf` returns `Ok(0)` without touching the transport.
///
/// # Errors
/// * `HandshakeIncomplete` before `accept` has succeeded.
/// * `Closed` once the session is closed.
/// * `Tls(InsufficientSpace)` if `BUF` is too small to carry any payload.
/// * `Tls` / `Transport` for encryption or transport failures.
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
    if !self.handshake_done {
        return Err(SessionError::HandshakeIncomplete);
    }
    if self.closed {
        return Err(SessionError::Closed);
    }
    // Writing nothing is a successful no-op, not a capacity problem.
    if buf.is_empty() {
        return Ok(0);
    }

    // Per-record overhead reserved in `out_buf`: the record header plus 1 + 16
    // bytes (presumably inner content-type byte and AEAD tag — confirm against
    // the TLS engine).
    let max_chunk = BUF.saturating_sub(RECORD_HEADER_LEN + 1 + 16);
    if max_chunk == 0 {
        return Err(SessionError::Tls(TlsError::InsufficientSpace));
    }
    let chunk = core::cmp::min(buf.len(), max_chunk);

    let record_len = self
        .inner
        .encrypt_app_data(&buf[..chunk], &mut self.out_buf)
        .map_err(SessionError::from_tls)?
        .len();
    // NOTE(review): the value returned by the transport write is discarded; this
    // assumes the transport writes the whole slice — confirm against `Transport`.
    self.transport
        .write(&self.out_buf[..record_len])
        .await
        .map_err(SessionError::Transport)?;
    Ok(chunk)
}
/// No-op that always succeeds: `write` hands each record to the transport
/// immediately, so this layer holds no pending output.
// NOTE(review): if the transport itself buffers, a flush may need forwarding — confirm.
async fn flush(&mut self) -> Result<(), Self::Error> {
Ok(())
}
/// Marks the session closed and closes the transport.
///
/// Does nothing if the session is already marked closed. No close-notify
/// alert is sent: the alert constants are only referenced so they are not
/// reported as unused.
async fn close(&mut self) {
    if self.closed {
        return;
    }
    self.closed = true;

    if self.handshake_done {
        // Placeholder for a future close-notify alert; nothing goes on the wire.
        let _ = (ALERT_LEVEL_WARNING, ALERT_DESCRIPTION_CLOSE_NOTIFY);
    }
    self.transport.close().await;
}
}
// Alias whose only visible purpose is to reference `Infallible`, so the import
// at the top of the file is used; the alias itself is unused, hence the allow.
#[allow(dead_code)]
type _PhantomInfallible = Infallible;
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::mock::MockTransport;
use core::future::Future;
use core::pin::{Pin, pin};
use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
/// Drives `future` with a single poll and returns its output.
///
/// Intended for futures that complete without ever yielding; panics with
/// "future unexpectedly pending" if the first poll returns `Pending`.
fn block_on<F: Future>(future: F) -> F::Output {
    fn noop(_: *const ()) {}
    fn clone_raw(_: *const ()) -> RawWaker {
        RawWaker::new(core::ptr::null(), &VTABLE)
    }
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, noop, noop, noop);

    // SAFETY: every vtable entry ignores the data pointer and does nothing
    // (clone just builds another identical waker), so a null data pointer is
    // never dereferenced and the RawWaker contract is upheld.
    let waker = unsafe { Waker::from_raw(RawWaker::new(core::ptr::null(), &VTABLE)) };
    let mut cx = Context::from_waker(&waker);
    let mut future: Pin<&mut F> = pin!(future);

    match future.as_mut().poll(&mut cx) {
        Poll::Ready(out) => out,
        Poll::Pending => panic!("future unexpectedly pending"),
    }
}
/// `write` must refuse to run until the handshake has completed.
#[test]
fn write_before_accept_errors() {
    let psk = PskConfig {
        identity: b"id",
        secret: &[0u8; 32],
    };
    let mut tls: EmbeddedTlsPskSession<'_, MockTransport, 1024> =
        EmbeddedTlsPskSession::new(MockTransport::new(), psk, [0u8; 32]);

    let outcome = block_on(tls.write(b"hello"));

    assert!(matches!(outcome, Err(SessionError::HandshakeIncomplete)));
}
/// `read` must refuse to run until the handshake has completed.
#[test]
fn read_before_accept_errors() {
    let psk = PskConfig {
        identity: b"id",
        secret: &[0u8; 32],
    };
    let mut tls: EmbeddedTlsPskSession<'_, MockTransport, 1024> =
        EmbeddedTlsPskSession::new(MockTransport::new(), psk, [0u8; 32]);
    let mut scratch = [0u8; 16];

    let outcome = block_on(tls.read(&mut scratch));

    assert!(matches!(outcome, Err(SessionError::HandshakeIncomplete)));
}
/// Closing before the handshake still flags the session and closes the transport.
#[test]
fn close_marks_session_closed_and_calls_transport_close() {
    let psk = PskConfig {
        identity: b"id",
        secret: &[0u8; 32],
    };
    let mut tls: EmbeddedTlsPskSession<'_, MockTransport, 1024> =
        EmbeddedTlsPskSession::new(MockTransport::new(), psk, [0u8; 32]);

    block_on(tls.close());

    // Check the session flag first, then that the mock saw the close call.
    let session_closed = tls.closed;
    let transport_closed = tls.transport.closed;
    assert!(session_closed);
    assert!(transport_closed);
}
/// A zero-byte transport read while fetching a record header surfaces as `Ok(0)`.
#[test]
fn read_returns_eof_when_transport_returns_zero_during_record_header() {
    let psk = PskConfig {
        identity: b"id",
        secret: &[0u8; 32],
    };
    // Presumably a fresh MockTransport has nothing queued, so its first read
    // returns 0 — verify against the mock if this test starts failing.
    let mut tls: EmbeddedTlsPskSession<'_, MockTransport, 1024> =
        EmbeddedTlsPskSession::new(MockTransport::new(), psk, [0u8; 32]);
    // Skip the real handshake; only the post-handshake read path is under test.
    tls.handshake_done = true;
    let mut scratch = [0u8; 16];

    let got = block_on(tls.read(&mut scratch)).unwrap();

    assert_eq!(got, 0, "EOF on socket -> read returns 0 bytes");
}
}