#![allow(clippy::module_name_repetitions)]
use core::fmt::Debug;
use alloc::string::String;
use embedded_io::{ErrorType, Read, Write};
use embedded_tls::{blocking::TlsConnection, Aes128GcmSha256, NoVerify, TlsConfig, TlsContext};
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use regex::Regex;
use crate::{
traits::io::{EasySocket, Open, OptionType},
types::TlsSocketOptions,
};
use super::{
state::{Connected, NotReady, Ready, SocketState},
tcp::TcpSocket,
};
lazy_static::lazy_static! {
    /// Matches every carriage return (`\r`) and NUL (`\0`) character.
    ///
    /// `read_string` deletes all matches from the text it receives.
    static ref REGEX: Regex = {
        let pattern = "\r|\0";
        Regex::new(pattern).unwrap()
    };
}
/// Length in bytes of the buffers returned by `TlsSocket::new_buffer`.
///
/// Equal to 2^14, the maximum TLS plaintext fragment size.
pub const MAX_FRAGMENT_LENGTH: u16 = 1 << 14;
/// TLS socket layered over a connected [`TcpSocket`], fixed to the
/// `Aes128GcmSha256` cipher suite.
///
/// The type parameter `S` tracks the handshake state at compile time. `new` produces
/// a `TlsSocket<NotReady>` (the default state). `Open::open` consumes it and returns a
/// `TlsSocket<Ready>`. In this file, reading and writing are only implemented for the
/// `Ready` state.
pub struct TlsSocket<'a, S: SocketState = NotReady> {
// Connection that takes ownership of the TCP socket and borrows the record buffers for `'a`.
tls_connection: TlsConnection<'a, TcpSocket<Connected>, Aes128GcmSha256>,
// Client configuration; starts as `TlsConfig::new()` and is filled in from the options when opened.
tls_config: TlsConfig<'a, Aes128GcmSha256>,
// Zero-sized marker that carries the state type `S`.
_marker: core::marker::PhantomData<S>,
}
impl Debug for TlsSocket<'_> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("TlsSocket").finish()
}
}
impl<'a> TlsSocket<'a, NotReady> {
    /// Wraps an already-connected TCP socket in a TLS connection whose handshake has
    /// not been performed yet.
    ///
    /// `record_read_buf` and `record_write_buf` are passed straight to
    /// [`TlsConnection::new`] and stay borrowed for the socket's lifetime `'a`.
    /// [`Self::new_buffer`] returns an array of [`MAX_FRAGMENT_LENGTH`] bytes that can
    /// back either buffer.
    ///
    /// The configuration starts out as `TlsConfig::new()`. Server name, certificates
    /// and other options are applied when the socket is opened.
    #[must_use]
    pub fn new(
        socket: TcpSocket<Connected>,
        record_read_buf: &'a mut [u8],
        record_write_buf: &'a mut [u8],
    ) -> TlsSocket<'a, NotReady> {
        let tls_config: TlsConfig<'a, Aes128GcmSha256> = TlsConfig::new();
        let tls_connection: TlsConnection<'a, TcpSocket<Connected>, Aes128GcmSha256> =
            TlsConnection::new(socket, record_read_buf, record_write_buf);
        TlsSocket {
            tls_connection,
            tls_config,
            _marker: core::marker::PhantomData,
        }
    }

    /// Returns a zero-filled array of [`MAX_FRAGMENT_LENGTH`] bytes.
    ///
    /// The array is suitable as a record buffer for [`Self::new`] or as scratch space
    /// for a single read.
    #[must_use]
    pub fn new_buffer() -> [u8; MAX_FRAGMENT_LENGTH as usize] {
        // Derive the length from the constant so the array literal and the return
        // type can never disagree.
        [0; MAX_FRAGMENT_LENGTH as usize]
    }
}
impl TlsSocket<'_, Ready> {
    /// Writes all of `buf` to the TLS connection.
    ///
    /// NOTE(review): this does not call `flush`. If the underlying connection buffers
    /// outgoing records until flushed, callers may need to flush explicitly. Confirm
    /// against the `embedded_tls` version in use.
    ///
    /// # Errors
    ///
    /// Returns the [`embedded_tls::TlsError`] reported by the underlying connection.
    pub fn write_all(&mut self, buf: &[u8]) -> Result<(), embedded_tls::TlsError> {
        self.tls_connection.write_all(buf)
    }

    /// Performs a single read of at most [`MAX_FRAGMENT_LENGTH`] bytes and returns the
    /// received data as text.
    ///
    /// Invalid UTF-8 is replaced with U+FFFD, and every `\r` and `\0` character is
    /// removed. Only one `read` call is made. Data larger than one read may therefore
    /// be returned partially. A multi-byte character split across reads is also
    /// decoded as U+FFFD.
    ///
    /// # Errors
    ///
    /// Returns the [`embedded_tls::TlsError`] from the underlying read.
    pub fn read_string(&mut self) -> Result<String, embedded_tls::TlsError> {
        let mut buf = TlsSocket::new_buffer();
        let filled = self.read(&mut buf)?;
        // Only the first `filled` bytes hold received data. Decoding the whole buffer
        // would convert, scan and allocate the full 16 KiB of zero padding on every
        // call. Clamp defensively in case a reader ever over-reports.
        let received = &buf[..filled.min(buf.len())];
        let text = String::from_utf8_lossy(received);
        Ok(REGEX.replace_all(&text, "").into_owned())
    }
}
/// In every state, `embedded_io` operations on a `TlsSocket` report
/// [`embedded_tls::TlsError`].
impl<S: SocketState> ErrorType for TlsSocket<'_, S> {
type Error = embedded_tls::TlsError;
}
/// Options for a `TlsSocket` (in any state) are a [`TlsSocketOptions`] carrying the
/// same lifetime parameter `'b`.
impl<S: SocketState> OptionType for TlsSocket<'_, S> {
type Options<'b> = TlsSocketOptions<'b>;
}
impl<'a, 'b> Open<'a, 'b> for TlsSocket<'b, NotReady>
where
    'a: 'b,
{
    type Return = TlsSocket<'a, Ready>;

    /// Applies `options` to the TLS configuration, performs the handshake, and returns
    /// the socket in the [`Ready`] state.
    ///
    /// The impl-level bound `'a: 'b` combined with the method-level bound `'b: 'a`
    /// forces the two lifetimes to be equal.
    ///
    /// # Errors
    ///
    /// Returns the [`embedded_tls::TlsError`] produced by the handshake.
    fn open(self, options: &'b Self::Options<'_>) -> Result<Self::Return, embedded_tls::TlsError>
    where
        'b: 'a,
    {
        // NOTE(review): the handshake RNG is seeded from the 64-bit `options.seed()`.
        // Its output is fully determined by that value. Confirm the caller supplies
        // real entropy.
        let mut rng = ChaCha20Rng::seed_from_u64(options.seed());

        // Build the complete client configuration before it moves into the socket.
        let mut config = self
            .tls_config
            .with_server_name(options.server_name());
        if options.rsa_signatures_enabled() {
            config = config.enable_rsa_signatures();
        }
        if options.reset_max_fragment_length() {
            config = config.reset_max_fragment_length();
        }
        if let Some(cert) = options.cert() {
            config = config.with_cert(cert.clone());
        }
        if let Some(ca) = options.ca() {
            config = config.with_ca(ca.clone());
        }

        let mut opened: TlsSocket<Ready> = TlsSocket {
            tls_connection: self.tls_connection,
            tls_config: config,
            _marker: core::marker::PhantomData,
        };

        // The context borrows the config stored inside `opened`, so it can only be
        // created once the config has been moved into place.
        let context = TlsContext::new(&opened.tls_config, &mut rng);
        // NOTE(review): `NoVerify` is the certificate verifier used here. Per the
        // embedded-tls docs it performs no verification, so the CA configured above is
        // likely not used to authenticate the server. Confirm this is intended.
        opened
            .tls_connection
            .open::<ChaCha20Rng, NoVerify>(context)?;
        Ok(opened)
    }
}
/// Reading is available only once the handshake has completed.
impl embedded_io::Read for TlsSocket<'_, Ready> {
    /// Reads application data by delegating to the wrapped TLS connection.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        let connection = &mut self.tls_connection;
        connection.read(buf)
    }
}
/// Writing is available only once the handshake has completed.
impl embedded_io::Write for TlsSocket<'_, Ready> {
    /// Hands `buf` to the wrapped TLS connection and returns how many bytes it accepted.
    fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        let connection = &mut self.tls_connection;
        connection.write(buf)
    }

    /// Flushes the wrapped TLS connection.
    fn flush(&mut self) -> Result<(), Self::Error> {
        let connection = &mut self.tls_connection;
        connection.flush()
    }
}
/// Marks a handshake-complete socket as an `EasySocket`. The impl body is empty, so
/// only the trait's provided items (if any) apply.
impl EasySocket for TlsSocket<'_, Ready> {}