use bytes::{Buf, BufMut, BytesMut};
use chacha20poly1305::{
aead::{
generic_array::ArrayLength,
stream::{Encryptor, NonceSize, StreamPrimitive},
},
AeadInPlace,
};
use std::{
ops::Sub,
pin::Pin,
task::{ready, Poll},
};
use tokio::io::AsyncWrite;
use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_CHUNK_SIZE};
pin_project_lite::pin_project! {
    /// Write half of an encrypted stream.
    ///
    /// Input handed to `poll_write` is split into chunks of at most `chunk_size` bytes; each
    /// chunk is encrypted with `encryptor` and written to `inner` as a little-endian `u32`
    /// length prefix followed by the ciphertext. Ciphertext that `inner` could not accept yet
    /// is kept in `buffer` and written out before the next frame or on flush.
    pub struct WriteHalf<T, U> {
        // The wrapped writer; structurally pinned because it is polled through `Pin`.
        #[pin]
        inner: T,
        // Stream encryptor; `encrypt_next` is called on it once per chunk.
        encryptor: U,
        // Framed ciphertext already accepted from the caller but not yet written to `inner`.
        buffer: BytesMut,
        // Maximum number of plaintext bytes per encrypted frame.
        chunk_size: usize,
    }
}
impl<T, A, S> WriteHalf<T, Encryptor<A, S>>
where
    T: AsyncWrite,
    S: StreamPrimitive<A>,
    A: AeadInPlace,
    A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    /// Wraps `inner`, using the crate's default buffer capacity and chunk size.
    pub fn new(inner: T, encryptor: Encryptor<A, S>) -> Self {
        Self::with_capacity(inner, encryptor, DEFAULT_BUFFER_SIZE, DEFAULT_CHUNK_SIZE)
    }

    /// Wraps `inner`, pre-allocating `size` bytes for pending ciphertext and splitting
    /// every write into plaintext chunks of at most `chunk_size` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero. Writes are split with `slice::chunks`, which rejects a
    /// zero size; checking here reports the misconfiguration at construction instead of on
    /// the first write.
    pub fn with_capacity(
        inner: T,
        encryptor: Encryptor<A, S>,
        size: usize,
        chunk_size: usize,
    ) -> Self {
        assert!(chunk_size > 0, "chunk_size must be non-zero");
        Self {
            inner,
            encryptor,
            buffer: BytesMut::with_capacity(size),
            chunk_size,
        }
    }

    /// Encrypts `buf` as the next stream chunk and frames it as
    /// `[ciphertext length: u32 LE][ciphertext]`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if encryption fails, or if the ciphertext is too long for the
    /// `u32` length prefix (only reachable with a `chunk_size` close to `u32::MAX`). A
    /// truncating cast would silently corrupt the framing, so it is rejected instead.
    fn get_encrypted(&mut self, buf: &[u8]) -> std::io::Result<Vec<u8>> {
        let encrypted = self
            .encryptor
            .encrypt_next(buf)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
        let len = u32::try_from(encrypted.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "encrypted chunk is too large for a u32 length prefix",
            )
        })?;
        let mut framed = Vec::with_capacity(std::mem::size_of::<u32>() + encrypted.len());
        framed.extend_from_slice(&len.to_le_bytes());
        framed.extend_from_slice(&encrypted);
        Ok(framed)
    }

    /// Writes all pending ciphertext in `buffer` to `inner`.
    ///
    /// Returns `Pending` while `inner` is not ready; whatever has not been written stays
    /// buffered. Fails with `WriteZero` if `inner` accepts zero bytes, and propagates any
    /// error returned by `inner`.
    fn flush_buf(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut me = self.project();
        while me.buffer.has_remaining() {
            let written = ready!(me.inner.as_mut().poll_write(cx, &me.buffer[..]))?;
            if written == 0 {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write the buffered data",
                )));
            }
            me.buffer.advance(written);
        }
        Poll::Ready(Ok(()))
    }
}
impl<T, A, S> AsyncWrite for WriteHalf<T, Encryptor<A, S>>
where
    T: AsyncWrite + Unpin,
    S: StreamPrimitive<A>,
    A: AeadInPlace,
    A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    /// Encrypts `buf` chunk by chunk and writes the framed ciphertext to `inner`.
    ///
    /// Any ciphertext left over from a previous call is written out first. Returns the number
    /// of plaintext bytes accepted. A chunk whose ciphertext could not be fully written
    /// (partial write, `Pending`, or an error from `inner`) still counts as accepted: its
    /// remaining ciphertext is buffered and retried by the next write or flush.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        // Earlier frames must reach `inner` before new ones, or the frame order would break.
        if !self.buffer.is_empty() {
            ready!(self.as_mut().flush_buf(cx))?
        }
        let chunk_size = self.chunk_size;
        let mut total_written = 0;
        for chunk in buf.chunks(chunk_size) {
            let encrypted = match self.get_encrypted(chunk) {
                Ok(encrypted) => encrypted,
                // Earlier chunks of this call were already accepted, and an `Err` would tell
                // the caller that nothing was written (inviting a duplicate retry). Report
                // the progress; the caller's next call resubmits the remaining bytes and
                // will hit the error again.
                Err(_) if total_written > 0 => return Poll::Ready(Ok(total_written)),
                Err(err) => return Poll::Ready(Err(err)),
            };
            total_written += chunk.len();
            let me = self.as_mut().project();
            // The chunk is already encrypted, so it can no longer be handed back to the
            // caller: on `Pending` or an error it is buffered in full and retried later
            // (an error from `inner` then resurfaces from `flush_buf`).
            let written = match me.inner.poll_write(cx, &encrypted) {
                Poll::Ready(Ok(written)) => written,
                Poll::Pending | Poll::Ready(Err(_)) => 0,
            };
            if written < encrypted.len() {
                me.buffer.put_slice(&encrypted[written..]);
                return Poll::Ready(Ok(total_written));
            }
        }
        Poll::Ready(Ok(buf.len()))
    }

    /// Writes all buffered ciphertext to `inner`, then flushes `inner`.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        ready!(self.as_mut().flush_buf(cx))?;
        self.project().inner.poll_flush(cx)
    }

    /// Writes all buffered ciphertext to `inner`, then shuts `inner` down.
    ///
    /// Draining the buffer first ensures frames that `inner` has not accepted yet are not
    /// dropped when the stream is closed.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        ready!(self.as_mut().flush_buf(cx))?;
        self.project().inner.poll_shutdown(cx)
    }
}
#[cfg(test)]
mod tests {
    use chacha20poly1305::{aead::stream::EncryptorLE31, KeyInit, XChaCha20Poly1305};
    use tokio::io::AsyncWriteExt;

    use super::*;
    use crate::get_key;

    /// Builds an LE31 stream encryptor from the test key and an all-zero start nonce.
    fn test_encryptor() -> EncryptorLE31<XChaCha20Poly1305> {
        let key: [u8; 32] = get_key("key", "group");
        let start_nonce = [0u8; 20];
        EncryptorLE31::from_aead(
            XChaCha20Poly1305::new(key.as_ref().into()),
            start_nonce.as_ref().into(),
        )
    }

    /// Encrypts `chunks` with a fresh encryptor and frames each one as
    /// `[ciphertext length: u32 LE][ciphertext]`, the layout `WriteHalf` must produce.
    fn expected_frames(chunks: &[&[u8]]) -> Vec<u8> {
        let mut encryptor = test_encryptor();
        let mut expected = Vec::new();
        for chunk in chunks {
            let encrypted = encryptor.encrypt_next(*chunk).unwrap();
            expected.extend((encrypted.len() as u32).to_le_bytes());
            expected.extend(encrypted);
        }
        expected
    }

    #[tokio::test]
    async fn test_crypto_stream_write_half() {
        let content: &[u8] = b"some content";
        let expected = expected_frames(&[content]);
        let mut writer = WriteHalf::new(tokio::io::BufWriter::new(Vec::new()), test_encryptor());
        assert_eq!(writer.write(content).await.unwrap(), content.len());
        assert_eq!(expected, writer.inner.buffer())
    }

    #[tokio::test]
    async fn test_crypto_stream_write_half_splits_into_chunks() {
        let content: &[u8] = b"some content";
        // A chunk size of 4 splits the 12-byte input into three frames.
        let expected = expected_frames(&[&b"some"[..], &b" con"[..], &b"tent"[..]]);
        let mut writer = WriteHalf::with_capacity(
            tokio::io::BufWriter::new(Vec::new()),
            test_encryptor(),
            64,
            4,
        );
        assert_eq!(writer.write(content).await.unwrap(), content.len());
        assert_eq!(expected, writer.inner.buffer())
    }
}