use std::{
future::Future,
io,
pin::Pin,
task::{Context, Poll},
};
use futures::ready;
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::crypto::{CipherCategory, CipherKind};
/// Reusable state for copying bytes from an `AsyncRead` into an `AsyncWrite`
/// through a fixed-size, heap-allocated buffer.
///
/// The bytes that have been read but not yet written are `buf[pos..cap]`.
#[derive(Debug)]
struct CopyBuffer {
    /// Set once the reader has reported EOF (a read that filled zero bytes).
    read_done: bool,
    /// Set after any successful write, cleared after the flush performed when the
    /// reader is pending; avoids flushing a writer that has nothing new in it.
    need_flush: bool,
    /// Start of the not-yet-written region of `buf`.
    pos: usize,
    /// End of the region of `buf` filled by the last read.
    cap: usize,
    /// Total number of bytes written to the writer so far.
    amt: u64,
    /// Backing storage; its length bounds how much is requested per `poll_read`.
    buf: Box<[u8]>,
}

impl CopyBuffer {
    /// Creates an empty copy buffer backed by `buffer_size` zeroed bytes.
    fn new(buffer_size: usize) -> Self {
        Self {
            read_done: false,
            need_flush: false,
            pos: 0,
            cap: 0,
            amt: 0,
            buf: vec![0; buffer_size].into_boxed_slice(),
        }
    }

    /// Drives the copy from `reader` to `writer`.
    ///
    /// Resolves to `Ok(n)`, the total number of bytes written, once the reader has
    /// reached EOF, every buffered byte has been written and the writer has been
    /// flushed. Returns `Poll::Pending` whenever the reader or writer is not ready;
    /// the call may be repeated with the same state until it resolves.
    ///
    /// # Errors
    ///
    /// Propagates any error from the reader or writer. Fails with
    /// `io::ErrorKind::WriteZero` if the writer accepts zero bytes.
    fn poll_copy<R, W>(
        &mut self,
        cx: &mut Context<'_>,
        mut reader: Pin<&mut R>,
        mut writer: Pin<&mut W>,
    ) -> Poll<io::Result<u64>>
    where
        R: AsyncRead + ?Sized,
        W: AsyncWrite + ?Sized,
    {
        loop {
            // Everything buffered has been written: refill from the reader.
            if self.pos == self.cap && !self.read_done {
                let mut read_buf = ReadBuf::new(&mut self.buf);

                match reader.as_mut().poll_read(cx, &mut read_buf) {
                    Poll::Ready(Ok(())) => {}
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Pending => {
                        // The reader has nothing for now. Flush what was already written;
                        // otherwise a buffering writer may hold it indefinitely, and if the
                        // peer waits for that data before sending more, both sides stall.
                        if self.need_flush {
                            ready!(writer.as_mut().poll_flush(cx))?;
                            self.need_flush = false;
                        }
                        return Poll::Pending;
                    }
                }

                let n = read_buf.filled().len();
                if n == 0 {
                    self.read_done = true;
                } else {
                    self.pos = 0;
                    self.cap = n;
                }
            }

            // Write out everything that is currently buffered.
            while self.pos < self.cap {
                let i = ready!(writer.as_mut().poll_write(cx, &self.buf[self.pos..self.cap]))?;
                if i == 0 {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "write zero byte into writer",
                    )));
                }
                self.pos += i;
                self.amt += i as u64;
                self.need_flush = true;
            }

            // A writer reporting more bytes than it was given would push `pos` past
            // `cap`; the refill condition `pos == cap` would then never hold again and
            // this loop would spin forever.
            debug_assert!(
                self.pos <= self.cap,
                "writer returned length larger than input slice"
            );

            // EOF reached and everything written: flush and finish.
            if self.pos == self.cap && self.read_done {
                ready!(writer.as_mut().poll_flush(cx))?;
                return Poll::Ready(Ok(self.amt));
            }
        }
    }
}
/// Future that moves every byte from `reader` into `writer` through an owned
/// `CopyBuffer`, resolving to the number of bytes written.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Copy<'a, R, W>
where
    R: ?Sized,
    W: ?Sized,
{
    /// Source of the data.
    reader: &'a mut R,
    /// Destination of the data.
    writer: &'a mut W,
    /// Intermediate buffer and copy progress.
    buf: CopyBuffer,
}
impl<R, W> Future for Copy<'_, R, W>
where
    R: AsyncRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
{
    type Output = io::Result<u64>;

    /// Delegates to `CopyBuffer::poll_copy` with the borrowed reader and writer.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        // `Copy` holds only references and an owned buffer, so it is `Unpin` and may be
        // dereferenced mutably; destructuring lets each field be borrowed independently.
        let Self { reader, writer, buf } = &mut *self;
        let pinned_reader = Pin::new(&mut **reader);
        let pinned_writer = Pin::new(&mut **writer);
        buf.poll_copy(cx, pinned_reader, pinned_writer)
    }
}
/// Copies everything from `reader` to `writer` until `reader` reaches EOF, then
/// flushes `writer`.
///
/// Reads are done in chunks of `encrypted_read_buffer_size(method)` bytes. Per the
/// type-parameter naming, `reader` is presumably the encrypted side and `writer` the
/// plain side (assumption; confirm at call sites).
///
/// Returns the total number of bytes written to `writer`.
///
/// # Errors
///
/// Any I/O error from either side, or `io::ErrorKind::WriteZero` if `writer`
/// accepts zero bytes.
pub async fn copy_from_encrypted<ER, PW>(
    method: CipherKind,
    reader: &mut ER,
    writer: &mut PW,
) -> io::Result<u64>
where
    ER: AsyncRead + Unpin + ?Sized,
    PW: AsyncWrite + Unpin + ?Sized,
{
    let buf = CopyBuffer::new(encrypted_read_buffer_size(method));
    let copier = Copy { reader, writer, buf };
    copier.await
}
/// Copies everything from `reader` to `writer` until `reader` reaches EOF, then
/// flushes `writer`.
///
/// Reads are done in chunks of `plain_read_buffer_size(method)` bytes. Per the
/// type-parameter naming, `reader` is presumably the plain side and `writer` the
/// encrypted side (assumption; confirm at call sites).
///
/// Returns the total number of bytes written to `writer`.
///
/// # Errors
///
/// Any I/O error from either side, or `io::ErrorKind::WriteZero` if `writer`
/// accepts zero bytes.
pub async fn copy_to_encrypted<PR, EW>(
    method: CipherKind,
    reader: &mut PR,
    writer: &mut EW,
) -> io::Result<u64>
where
    PR: AsyncRead + Unpin + ?Sized,
    EW: AsyncWrite + Unpin + ?Sized,
{
    let chunk_size = plain_read_buffer_size(method);
    let copier = Copy {
        reader,
        writer,
        buf: CopyBuffer::new(chunk_size),
    };
    copier.await
}
/// Size in bytes of the buffer used when reading from the encrypted side.
///
/// AEAD categories use the cipher's maximum packet size plus its tag length.
/// Stream and no-cipher categories use a fixed 16 KiB.
fn encrypted_read_buffer_size(method: CipherKind) -> usize {
    // 16 KiB chunk used by the non-AEAD categories.
    const FIXED_CHUNK: usize = 1 << 14;

    match method.category() {
        CipherCategory::None => FIXED_CHUNK,
        #[cfg(feature = "stream-cipher")]
        CipherCategory::Stream => FIXED_CHUNK,
        CipherCategory::Aead => super::aead::MAX_PACKET_SIZE + method.tag_len(),
        #[cfg(feature = "aead-cipher-2022")]
        CipherCategory::Aead2022 => super::aead_2022::MAX_PACKET_SIZE + method.tag_len(),
    }
}
/// Size in bytes of the buffer used when reading from the plain side.
///
/// AEAD categories use the cipher's maximum packet size (no tag, since the data is
/// not yet sealed). Stream and no-cipher categories use a fixed 16 KiB.
fn plain_read_buffer_size(method: CipherKind) -> usize {
    // 16 KiB chunk used by the non-AEAD categories.
    const FIXED_CHUNK: usize = 1 << 14;

    match method.category() {
        CipherCategory::None => FIXED_CHUNK,
        #[cfg(feature = "stream-cipher")]
        CipherCategory::Stream => FIXED_CHUNK,
        CipherCategory::Aead => super::aead::MAX_PACKET_SIZE,
        #[cfg(feature = "aead-cipher-2022")]
        CipherCategory::Aead2022 => super::aead_2022::MAX_PACKET_SIZE,
    }
}
/// Allocates a zero-filled buffer whose length is the encrypted-side read size
/// for `method` (MAX_PACKET_SIZE + tag length for AEAD, 16 KiB otherwise).
#[inline]
pub fn alloc_encrypted_read_buffer(method: CipherKind) -> Box<[u8]> {
    let len = encrypted_read_buffer_size(method);
    vec![0u8; len].into_boxed_slice()
}
/// Allocates a zero-filled buffer whose length is the plain-side read size
/// for `method` (MAX_PACKET_SIZE for AEAD, 16 KiB otherwise).
#[inline]
pub fn alloc_plain_read_buffer(method: CipherKind) -> Box<[u8]> {
    let len = plain_read_buffer_size(method);
    vec![0u8; len].into_boxed_slice()
}
/// Progress of one direction of a bidirectional copy, advanced by
/// `transfer_one_direction` in the order `Running` -> `ShuttingDown` -> `Done`.
enum TransferState {
    /// Still copying from the reader to the writer through this buffer.
    Running(CopyBuffer),
    /// The reader hit EOF and everything was written and flushed; the writer is
    /// being shut down. Holds the number of bytes copied.
    ShuttingDown(u64),
    /// The writer has been shut down. Holds the final number of bytes copied.
    Done(u64),
}
/// Future that copies data in both directions between `a` and `b`, driving both
/// directions on every poll. It resolves to `(bytes a->b, bytes b->a)` once both
/// directions are done.
#[pin_project(project = CopyBidirectionalProj)]
struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
    // NOTE(review): these fields are `&mut` references (always `Unpin`), so `#[pin]`
    // only yields a `Pin<&mut &mut A>` projection rather than pinning `A` itself.
    #[pin]
    a: &'a mut A,
    #[pin]
    b: &'a mut B,
    /// State of the copy reading from `a` and writing to `b`.
    a_to_b: TransferState,
    /// State of the copy reading from `b` and writing to `a`.
    b_to_a: TransferState,
}
/// Advances one direction of a bidirectional copy.
///
/// Copies from `r` to `w` until `r` reaches EOF, then shuts `w` down. Resolves to the
/// number of bytes copied once `state` reaches `TransferState::Done`. Later calls keep
/// resolving to that same count.
fn transfer_one_direction<A, B>(
    cx: &mut Context<'_>,
    state: &mut TransferState,
    mut r: Pin<&mut A>,
    mut w: Pin<&mut B>,
) -> Poll<io::Result<u64>>
where
    A: AsyncRead + AsyncWrite + Unpin + ?Sized,
    B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    loop {
        // Work out the successor state first, then store it; the borrow of `state`
        // taken by the match ends before the assignment.
        let next = match state {
            TransferState::Running(copy_buf) => {
                let copied = ready!(copy_buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
                TransferState::ShuttingDown(copied)
            }
            TransferState::ShuttingDown(copied) => {
                ready!(w.as_mut().poll_shutdown(cx))?;
                TransferState::Done(*copied)
            }
            TransferState::Done(copied) => return Poll::Ready(Ok(*copied)),
        };
        *state = next;
    }
}
impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
where
    A: AsyncRead + AsyncWrite + Unpin + ?Sized,
    B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    type Output = io::Result<(u64, u64)>;

    /// Polls both directions; resolves with `(bytes a->b, bytes b->a)` when both are
    /// done, or with the first error either direction reports.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let CopyBidirectionalProj {
            a: mut side_a,
            b: mut side_b,
            a_to_b: forward_state,
            b_to_a: backward_state,
        } = self.project();

        // Both directions are polled on every call so neither one waits on the other.
        // `?` ends the whole future as soon as either direction reports an error.
        let forward = transfer_one_direction(cx, forward_state, side_a.as_mut(), side_b.as_mut())?;
        let backward = transfer_one_direction(cx, backward_state, side_b.as_mut(), side_a.as_mut())?;

        match (forward, backward) {
            (Poll::Ready(a_to_b_bytes), Poll::Ready(b_to_a_bytes)) => {
                Poll::Ready(Ok((a_to_b_bytes, b_to_a_bytes)))
            }
            _ => Poll::Pending,
        }
    }
}
/// Copies data in both directions between `encrypted` and `plain` until both
/// directions have reached EOF and shut down their respective writers.
///
/// Reads from `encrypted` use `encrypted_read_buffer_size(method)` byte chunks.
/// Reads from `plain` use `plain_read_buffer_size(method)` byte chunks.
///
/// Returns `(bytes copied encrypted -> plain, bytes copied plain -> encrypted)`.
///
/// # Errors
///
/// The first I/O error reported by either direction.
pub async fn copy_encrypted_bidirectional<E, P>(
    method: CipherKind,
    encrypted: &mut E,
    plain: &mut P,
) -> Result<(u64, u64), std::io::Error>
where
    E: AsyncRead + AsyncWrite + Unpin + ?Sized,
    P: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    let encrypted_to_plain = TransferState::Running(CopyBuffer::new(encrypted_read_buffer_size(method)));
    let plain_to_encrypted = TransferState::Running(CopyBuffer::new(plain_read_buffer_size(method)));

    let relay = CopyBidirectional {
        a: encrypted,
        b: plain,
        a_to_b: encrypted_to_plain,
        b_to_a: plain_to_encrypted,
    };
    relay.await
}