use crate::{
crypto::{CryptoCore, CryptoReader, CryptoWriter},
error::{ClavisError, ClavisResult, MessageError, StreamError},
PacketTrait,
};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tracing::warn;
/// Packet-level I/O over an encrypted transport.
///
/// Both methods return a `Send` future resolving to a [`ClavisResult`] and
/// require `Self: Sized`. Implementors in this file that support only one
/// direction return a [`StreamError::InvalidOperation`] error from the other.
pub trait EncryptedPacket {
/// Reads the next packet and decodes it as `P`.
///
/// # Errors
///
/// The returned future resolves to an error when no packet of type `P`
/// can be produced; the concrete failure modes are defined by each
/// implementor.
fn read_packet<P: PacketTrait>(
&mut self,
) -> impl std::future::Future<Output = ClavisResult<P>> + Send
where
Self: Sized;
/// Encodes `packet` and sends it.
///
/// # Errors
///
/// The returned future resolves to an error when the packet cannot be
/// sent; the concrete failure modes are defined by each implementor.
fn write_packet(
&mut self,
packet: &impl PacketTrait,
) -> impl std::future::Future<Output = ClavisResult<()>> + Send
where
Self: Sized;
}
/// Configuration used when establishing an [`EncryptedStream`].
///
/// `Debug` is implemented by hand rather than derived so that the pre-shared
/// key never ends up in logs or panic messages.
#[derive(Clone)]
pub struct EncryptedStreamOptions {
    /// Largest packet payload, in bytes, accepted when reading or writing.
    /// Larger payloads are rejected with `MessageError::MessageTooLarge`.
    pub max_packet_size: usize,
    /// Optional pre-shared key. When `None`, [`EncryptedStream::new`] logs a
    /// warning that the connection may be vulnerable to man-in-the-middle
    /// attacks.
    pub psk: Option<Vec<u8>>,
}

impl std::fmt::Debug for EncryptedStreamOptions {
    /// Formats the options, showing only whether a PSK is present and never
    /// its bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EncryptedStreamOptions")
            .field("max_packet_size", &self.max_packet_size)
            // Secret material: keep the `Some`/`None` shape, hide the content.
            .field("psk", &self.psk.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
impl Default for EncryptedStreamOptions {
fn default() -> Self {
Self {
max_packet_size: 65536,
psk: None,
}
}
}
/// A bidirectional encrypted packet stream layered over a transport `S`.
///
/// Built by [`EncryptedStream::new`], which runs the handshake and stores the
/// resulting reader and writer crypto state alongside the transport.
pub struct EncryptedStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
/// Underlying transport that carries the encrypted bytes.
stream: S,
/// Crypto state used to read incoming data from `stream`.
crypto_reader: CryptoReader,
/// Crypto state used to write outgoing data to `stream`.
crypto_writer: CryptoWriter,
/// Options the stream was created with; `max_packet_size` is enforced on
/// every read and write.
options: EncryptedStreamOptions,
}
impl<S> EncryptedStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
pub async fn new(mut stream: S, options: Option<EncryptedStreamOptions>) -> ClavisResult<Self> {
let options = options.unwrap_or_default();
if options.psk.is_none() {
warn!("No pre-shared key is set, this connection may be vulnerable to man-in-the-middle attacks.");
}
let core = CryptoCore::establish(&mut stream, options.clone())
.await
.map_err(|e| {
ClavisError::crypto_failure(crate::error::CryptoOperation::Handshake, e.to_string())
})?;
let crypto_reader = core.reader;
let crypto_writer = core.writer;
Ok(Self {
stream,
crypto_reader,
crypto_writer,
options,
})
}
pub fn split(
self,
) -> ClavisResult<(EncryptedReader<ReadHalf<S>>, EncryptedWriter<WriteHalf<S>>)> {
let (read, write) = tokio::io::split(self.stream);
Ok((
EncryptedReader::new(read, self.crypto_reader, self.options.clone()),
EncryptedWriter::new(write, self.crypto_writer, self.options),
))
}
}
impl<S> EncryptedPacket for EncryptedStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Reads one message from the transport and deserializes it as `P`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the crypto reader, returns
    /// `MessageError::MessageTooLarge` when the payload exceeds
    /// `max_packet_size`, and returns a deserialization error when `P`
    /// cannot be decoded from the payload.
    async fn read_packet<P: PacketTrait>(&mut self) -> ClavisResult<P> {
        let payload = self.crypto_reader.read(&mut self.stream).await?;

        let (size, max_size) = (payload.len(), self.options.max_packet_size);
        if size > max_size {
            return Err(ClavisError::Message(MessageError::MessageTooLarge {
                size,
                max_size,
            }));
        }

        match P::deserialize(&payload) {
            Ok(packet) => Ok(packet),
            Err(err) => Err(ClavisError::deserialization_failed(err.to_string())),
        }
    }

    /// Serializes `packet` and writes it to the transport.
    ///
    /// # Errors
    ///
    /// Returns a serialization error when `packet` cannot be encoded,
    /// `MessageError::MessageTooLarge` when the encoded payload exceeds
    /// `max_packet_size`, and otherwise whatever the crypto writer returns.
    async fn write_packet(&mut self, packet: &impl PacketTrait) -> ClavisResult<()> {
        let payload = packet
            .serialize()
            .map_err(|err| ClavisError::serialization_failed(err.to_string()))?;

        let (size, max_size) = (payload.len(), self.options.max_packet_size);
        if size > max_size {
            return Err(ClavisError::Message(MessageError::MessageTooLarge {
                size,
                max_size,
            }));
        }

        self.crypto_writer.write(&mut self.stream, &payload).await
    }
}
/// Read-only half of an encrypted stream, as produced by
/// `EncryptedStream::split`.
pub struct EncryptedReader<R> {
/// Underlying source of encrypted bytes.
inner: R,
/// Crypto state used to read incoming data from `inner`.
crypto: CryptoReader,
/// Stream options; `max_packet_size` is enforced on every read.
options: EncryptedStreamOptions,
}
impl<R> EncryptedReader<R> {
fn new(inner: R, crypto: CryptoReader, options: EncryptedStreamOptions) -> Self {
Self {
inner,
crypto,
options,
}
}
}
/// Write-only half of an encrypted stream, as produced by
/// `EncryptedStream::split`.
pub struct EncryptedWriter<W> {
/// Underlying sink for encrypted bytes.
inner: W,
/// Crypto state used to write outgoing data to `inner`.
crypto: CryptoWriter,
/// Stream options; `max_packet_size` is enforced on every write.
options: EncryptedStreamOptions,
}
impl<W> EncryptedWriter<W> {
fn new(inner: W, crypto: CryptoWriter, options: EncryptedStreamOptions) -> Self {
Self {
inner,
crypto,
options,
}
}
}
impl<R: AsyncRead + Unpin + Send> EncryptedPacket for EncryptedReader<R> {
    /// Reads one message from the read half and deserializes it as `P`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the crypto reader, returns
    /// `MessageError::MessageTooLarge` when the payload exceeds
    /// `max_packet_size`, and returns a deserialization error when `P`
    /// cannot be decoded from the payload.
    async fn read_packet<P: PacketTrait>(&mut self) -> ClavisResult<P> {
        let payload = self.crypto.read(&mut self.inner).await?;

        let (size, max_size) = (payload.len(), self.options.max_packet_size);
        if size > max_size {
            return Err(ClavisError::Message(MessageError::MessageTooLarge {
                size,
                max_size,
            }));
        }

        match P::deserialize(&payload) {
            Ok(packet) => Ok(packet),
            Err(err) => Err(ClavisError::deserialization_failed(err.to_string())),
        }
    }

    /// Always fails: a reader half cannot send packets.
    ///
    /// # Errors
    ///
    /// Always returns `StreamError::InvalidOperation`.
    async fn write_packet(&mut self, _packet: &impl PacketTrait) -> ClavisResult<()> {
        let unsupported =
            StreamError::InvalidOperation("Cannot write to a read-only stream".into());
        Err(ClavisError::Stream(unsupported))
    }
}
impl<W: AsyncWrite + Unpin + Send> EncryptedPacket for EncryptedWriter<W> {
    /// Always fails: a writer half cannot receive packets.
    ///
    /// # Errors
    ///
    /// Always returns `StreamError::InvalidOperation`.
    async fn read_packet<P: PacketTrait>(&mut self) -> ClavisResult<P> {
        let unsupported =
            StreamError::InvalidOperation("Cannot read from a write-only stream".into());
        Err(ClavisError::Stream(unsupported))
    }

    /// Serializes `packet` and writes it to the write half.
    ///
    /// # Errors
    ///
    /// Returns a serialization error when `packet` cannot be encoded,
    /// `MessageError::MessageTooLarge` when the encoded payload exceeds
    /// `max_packet_size`, and otherwise whatever the crypto writer returns.
    async fn write_packet(&mut self, packet: &impl PacketTrait) -> ClavisResult<()> {
        let payload = packet
            .serialize()
            .map_err(|err| ClavisError::serialization_failed(err.to_string()))?;

        let (size, max_size) = (payload.len(), self.options.max_packet_size);
        if size > max_size {
            return Err(ClavisError::Message(MessageError::MessageTooLarge {
                size,
                max_size,
            }));
        }

        self.crypto.write(&mut self.inner, &payload).await
    }
}