use chacha20poly1305::{
aead::{
generic_array::ArrayLength,
stream::{Decryptor, NonceSize, StreamPrimitive},
},
AeadInPlace,
};
use pin_project_lite::pin_project;
use std::{ops::Sub, pin::Pin, task::ready};
use tokio::io::{AsyncBufRead, AsyncRead};
use crate::DEFAULT_BUFFER_SIZE;
/// Size in bytes of the little-endian `u32` length prefix that precedes every encrypted frame.
const FRAME_HEADER_LEN: usize = 4;

pin_project! {
    /// Reading half of an encrypted stream.
    ///
    /// The wrapped reader is expected to yield a sequence of frames, each made of a
    /// little-endian `u32` byte count followed by that many ciphertext bytes. When `U` is a
    /// [`Decryptor`], every frame is decrypted with `decrypt_next` and the resulting plaintext
    /// is exposed through [`AsyncRead`].
    pub struct ReadHalf<T, U> {
        // Underlying source of length-prefixed ciphertext frames.
        #[pin]
        inner: T,
        // Stateful stream decryptor applied to each frame in order.
        decryptor: U,
        // Ciphertext staging area; `buffer[pos..cap]` holds bytes received but not yet decrypted.
        buffer: Vec<u8>,
        // Offset in `buffer` of the first byte that has not been decrypted yet.
        pos: usize,
        // Offset in `buffer` one past the last byte received from `inner`.
        cap: usize,
        // Plaintext of the most recently decrypted frame.
        plain: Vec<u8>,
        // Offset in `plain` of the first byte not yet handed to the caller.
        plain_pos: usize,
    }
}

impl<T, A, S> ReadHalf<T, Decryptor<A, S>>
where
    S: StreamPrimitive<A>,
    A: AeadInPlace,
    A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    /// Creates a reader with a ciphertext buffer of `DEFAULT_BUFFER_SIZE` bytes.
    pub fn new(inner: T, decryptor: Decryptor<A, S>) -> Self {
        Self::with_capacity(inner, decryptor, DEFAULT_BUFFER_SIZE)
    }

    /// Creates a reader whose ciphertext buffer initially holds `size` bytes.
    ///
    /// The buffer grows on demand when a frame does not fit, so `size` is only a starting
    /// point (zero is accepted).
    pub fn with_capacity(inner: T, decryptor: Decryptor<A, S>, size: usize) -> Self {
        Self {
            inner,
            decryptor,
            buffer: vec![0u8; size],
            pos: 0,
            cap: 0,
            plain: Vec::new(),
            plain_pos: 0,
        }
    }

    /// Tries to decrypt the next frame from the buffered ciphertext.
    ///
    /// Returns `Ok(Some(plaintext))` when a complete frame was buffered and decrypted, and
    /// `Ok(None)` when more bytes are needed; in the latter case the buffer is guaranteed to
    /// have free space after `cap` so that the next fill can make progress.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when the frame fails to decrypt or when its declared length
    /// cannot be represented. On a decryption failure the frame is left in the buffer.
    fn produce(self: Pin<&mut Self>) -> std::io::Result<Option<Vec<u8>>> {
        let available = self.cap.saturating_sub(self.pos);
        // The length prefix itself may arrive in pieces; never parse a partial one.
        if available < FRAME_HEADER_LEN {
            self.adjust_buffer(FRAME_HEADER_LEN);
            return Ok(None);
        }

        let start = self.pos;
        let mut length_bytes = [0u8; FRAME_HEADER_LEN];
        length_bytes.copy_from_slice(&self.buffer[start..start + FRAME_HEADER_LEN]);
        // `u32` -> `usize` is lossless on the 32/64-bit targets this code runs on.
        let length = u32::from_le_bytes(length_bytes) as usize;
        let frame_len = length.checked_add(FRAME_HEADER_LEN).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "encrypted frame length overflows usize",
            )
        })?;

        if available < frame_len {
            self.adjust_buffer(frame_len);
            return Ok(None);
        }

        let me = self.project();
        let frame_end = start + frame_len;
        let decrypted = me
            .decryptor
            .decrypt_next(&me.buffer[start + FRAME_HEADER_LEN..frame_end])
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
        *me.pos = frame_end;
        if *me.pos == *me.cap {
            // Everything buffered has been consumed: rewind to reuse the whole buffer.
            *me.pos = 0;
            *me.cap = 0;
        }
        Ok(Some(decrypted))
    }

    /// Makes room for an incomplete frame of `frame_len` bytes (prefix included) that starts
    /// at `pos`.
    ///
    /// Pending bytes are moved to the front when the frame would not fit at its current
    /// offset. The buffer is enlarged only when it is completely full, by at most doubling
    /// (and never beyond what the frame needs), so memory use follows the amount of data
    /// actually received rather than the length a peer merely announced.
    fn adjust_buffer(self: Pin<&mut Self>, frame_len: usize) {
        let me = self.project();
        if *me.pos >= *me.cap {
            // Nothing pending: simply rewind.
            *me.pos = 0;
            *me.cap = 0;
        } else if *me.pos > 0 && (*me.pos).saturating_add(frame_len) > me.buffer.len() {
            me.buffer.copy_within(*me.pos..*me.cap, 0);
            *me.cap -= *me.pos;
            *me.pos = 0;
        }
        if *me.cap >= me.buffer.len() {
            let needed = (*me.pos).saturating_add(frame_len);
            // `max(1)` lets a zero-sized buffer grow as well.
            let doubled = me.buffer.len().saturating_mul(2).max(1);
            let target = needed.min(doubled);
            if target > me.buffer.len() {
                me.buffer.resize(target, 0);
            }
        }
    }

    /// Copies pending plaintext into `buf`, as much as fits.
    ///
    /// Returns `true` when there was pending plaintext (so the caller has data to report)
    /// and `false` when a new frame must be decrypted first.
    fn drain_plain(self: Pin<&mut Self>, buf: &mut tokio::io::ReadBuf<'_>) -> bool {
        let me = self.project();
        let pending = &me.plain[*me.plain_pos..];
        if pending.is_empty() {
            return false;
        }
        let count = pending.len().min(buf.remaining());
        buf.put_slice(&pending[..count]);
        *me.plain_pos += count;
        if *me.plain_pos == me.plain.len() {
            me.plain.clear();
            *me.plain_pos = 0;
        }
        true
    }

    /// Returns the ciphertext bytes received from the inner reader that have not been
    /// decrypted yet.
    pub fn buffer(&self) -> &[u8] {
        &self.buffer[self.pos..self.cap]
    }
}

impl<T, A, S> AsyncRead for ReadHalf<T, Decryptor<A, S>>
where
    T: AsyncRead,
    S: StreamPrimitive<A>,
    A: AeadInPlace,
    A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    /// Reads decrypted bytes into `buf`.
    ///
    /// At most one frame's plaintext is delivered per call; plaintext that does not fit in
    /// `buf` is kept and returned by subsequent calls. Frames with an empty plaintext are
    /// skipped, because a zero-byte read would be mistaken for end of stream.
    ///
    /// # Errors
    ///
    /// Propagates errors of the inner reader, returns `InvalidData` when a frame fails to
    /// decrypt, and `UnexpectedEof` when the inner reader ends in the middle of a frame.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        if buf.remaining() == 0 {
            return std::task::Poll::Ready(Ok(()));
        }
        loop {
            if self.as_mut().drain_plain(buf) {
                return std::task::Poll::Ready(Ok(()));
            }

            let next_frame = self.as_mut().produce()?;
            if let Some(decrypted) = next_frame {
                let me = self.as_mut().project();
                *me.plain = decrypted;
                *me.plain_pos = 0;
                continue;
            }

            // `produce` left free space in the buffer, so a fill that adds no bytes can only
            // mean that the inner reader reached end of stream.
            let buffered_before = self.cap;
            ready!(self.as_mut().poll_fill_buf(cx))?;
            if self.cap == buffered_before {
                if self.cap > self.pos {
                    return std::task::Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "stream ended in the middle of an encrypted frame",
                    )));
                }
                return std::task::Poll::Ready(Ok(()));
            }
        }
    }
}
impl<R: AsyncRead, A, S> tokio::io::AsyncBufRead for ReadHalf<R, Decryptor<A, S>>
where
    S: StreamPrimitive<A>,
    A: AeadInPlace,
    A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    /// Performs one read from the inner reader into the free tail of the internal buffer
    /// (everything after `cap`) and returns the bytes currently held in `buffer[pos..cap]`.
    ///
    /// The returned bytes are the raw data received from the inner reader; this method does
    /// not decrypt anything.
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<&[u8]>> {
        let this = self.project();
        let received = {
            let mut spare = tokio::io::ReadBuf::new(&mut this.buffer[*this.cap..]);
            ready!(this.inner.poll_read(cx, &mut spare))?;
            spare.filled().len()
        };
        // Adding zero is a no-op, so no special case is needed for an empty read.
        *this.cap += received;
        std::task::Poll::Ready(Ok(&this.buffer[*this.pos..*this.cap]))
    }

    /// Marks `amt` buffered bytes as consumed; once the read position reaches or passes
    /// `cap`, both offsets are rewound to the start of the buffer.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.project();
        let advanced = *this.pos + amt;
        if advanced < *this.cap {
            *this.pos = advanced;
        } else {
            *this.pos = 0;
            *this.cap = 0;
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{assert_eq, time::Duration};
    use chacha20poly1305::{aead::stream::EncryptorLE31, KeyInit, XChaCha20Poly1305};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use crate::get_key;
    use super::*;

    /// Encrypts three messages as length-prefixed frames, feeds them to the reader in
    /// 10-byte pieces, and checks that the concatenated plaintext comes out.
    #[tokio::test]
    pub async fn test_crypto_stream_read_half() {
        let key: [u8; 32] = get_key("key", "group");
        let start_nonce = [0u8; 20];
        let (rx, mut tx) = tokio::io::duplex(100);
        let writer = tokio::spawn(async move {
            let encrypted_content = {
                let mut encryptor: EncryptorLE31<XChaCha20Poly1305> =
                    chacha20poly1305::aead::stream::EncryptorLE31::from_aead(
                        XChaCha20Poly1305::new(key.as_ref().into()),
                        start_nonce.as_ref().into(),
                    );
                let mut expected = Vec::new();
                for data in ["some content", "some other content", "even more content"] {
                    let mut encrypted = encryptor.encrypt_next(data.as_bytes()).unwrap();
                    expected.extend((encrypted.len() as u32).to_le_bytes());
                    expected.append(&mut encrypted);
                }
                expected
            };
            // Deliberately split frames across writes so the reader has to reassemble them.
            for chunk in encrypted_content.chunks(10) {
                // `write_all` (unlike `write`) cannot silently drop part of the chunk.
                tx.write_all(chunk)
                    .await
                    .expect("writing an encrypted chunk should succeed");
                tokio::time::sleep(Duration::from_millis(20)).await;
            }
        });
        tokio::time::sleep(Duration::from_millis(20)).await;
        let decryptor = chacha20poly1305::aead::stream::DecryptorLE31::from_aead(
            XChaCha20Poly1305::new(key.as_ref().into()),
            start_nonce.as_ref().into(),
        );
        let mut reader = ReadHalf::new(rx, decryptor);
        let mut plain_content = String::new();
        reader
            .read_to_string(&mut plain_content)
            .await
            .expect("reading the decrypted stream should succeed");
        writer.await.expect("writer task should not panic");
        assert_eq!(
            plain_content,
            "some contentsome other contenteven more content"
        );
    }

    /// A frame whose body is not valid ciphertext must make every read fail.
    #[tokio::test]
    pub async fn test_read_invalid_data() {
        let key: [u8; 32] = get_key("key", "group");
        let start_nonce = [0u8; 20];
        // `_tx` is kept alive so the reader never observes end of stream.
        let (rx, _tx) = tokio::io::duplex(100);
        let decryptor = chacha20poly1305::aead::stream::DecryptorLE31::from_aead(
            XChaCha20Poly1305::new(key.as_ref().into()),
            start_nonce.as_ref().into(),
        );
        let mut reader = ReadHalf::new(rx, decryptor);
        // Pre-load the buffer: a length prefix of 10 followed by 20 zero bytes.
        let mut reader_data = Vec::from_iter(10u32.to_le_bytes());
        reader_data.extend_from_slice(&[0u8; 20]);
        reader.cap = reader_data.len();
        reader.buffer = reader_data;
        let mut buf = [0u8; 1024];
        assert!(reader.read(&mut buf).await.is_err());
        // The bad frame is not skipped, so the error repeats.
        assert!(reader.read(&mut buf).await.is_err());
    }
}