use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::task::{Context, Poll};
use crate::cipher::Cipher;
use crate::util::eof;
use awak::io::AsyncRead;
/// Future returned by [`read_exact`].
///
/// When polled it fills `buf` completely from `reader` and then decrypts the
/// buffer in place with the shared `cipher`. If the cipher has no decryptor
/// yet (`cipher.dec` is `None`), the cipher's IV is read from `reader` first
/// and the decryptor is initialised before any payload bytes are read.
///
/// Like every future, this type is lazy: no bytes are read until it is polled.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct DecryptReadExact<'a, A: ?Sized> {
    /// Cipher state shared with other users; locked for the duration of each poll.
    cipher: Arc<Mutex<Cipher>>,
    /// Source the IV and payload bytes are read from.
    reader: &'a mut A,
    /// Destination buffer; filled entirely, then decrypted in place.
    buf: &'a mut [u8],
    /// Bytes read so far in the current phase (IV first, then payload).
    /// Reset to zero when the IV phase finishes.
    pos: usize,
}
/// Builds a [`DecryptReadExact`] future that fills `buf` from `reader` and
/// decrypts it in place using `cipher`.
///
/// This function only assembles the future; no I/O happens and the cipher is
/// not locked until the returned value is polled.
pub fn read_exact<'a, A>(
    cipher: Arc<Mutex<Cipher>>,
    reader: &'a mut A,
    buf: &'a mut [u8],
) -> DecryptReadExact<'a, A>
where
    A: AsyncRead + Unpin + ?Sized,
{
    // Progress always starts at zero; the same counter is reused for the IV
    // phase and the payload phase by the future's `poll`.
    let pos = 0;
    DecryptReadExact {
        pos,
        buf,
        reader,
        cipher,
    }
}
impl<A> Future for DecryptReadExact<'_, A>
where
    A: AsyncRead + Unpin + ?Sized,
{
    type Output = io::Result<usize>;

    /// Drives the read-then-decrypt operation.
    ///
    /// 1. If the cipher has no decryptor yet (`dec` is `None`), reads bytes
    ///    from the reader into the cipher's `iv` until it is full, then calls
    ///    `init_decrypt`.
    /// 2. Reads from the reader until `buf` is completely filled.
    /// 3. Decrypts `buf` in place and resolves to the number of bytes read
    ///    into it.
    ///
    /// Progress is kept in `pos`, so a `Poll::Pending` from the reader resumes
    /// where the previous poll stopped.
    ///
    /// # Errors
    ///
    /// Propagates any error from the reader, and yields the crate's `eof()`
    /// error if the reader reports a zero-length read before the current
    /// phase has been filled.
    ///
    /// # Panics
    ///
    /// Panics if the cipher mutex is poisoned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;
        // The guard lives only for this poll call; it is dropped on every
        // return path, including `Poll::Pending`.
        let mut state = this.cipher.lock().unwrap();

        // Phase 1: fetch the IV when no decryptor has been set up yet.
        if state.dec.is_none() {
            while state.iv.len() > this.pos {
                let dst = &mut state.iv[this.pos..];
                let got = ready!(Pin::new(&mut *this.reader).poll_read(cx, dst))?;
                if got == 0 {
                    // Stream ended before the IV was complete.
                    return Poll::Ready(Err(eof()));
                }
                this.pos += got;
            }
            // Reuse the counter for the payload phase.
            this.pos = 0;
            state.init_decrypt();
        }

        // Phase 2: fill the caller's buffer completely.
        while this.buf.len() > this.pos {
            let dst = &mut this.buf[this.pos..];
            let got = ready!(Pin::new(&mut *this.reader).poll_read(cx, dst))?;
            if got == 0 {
                // Stream ended before the buffer was full.
                return Poll::Ready(Err(eof()));
            }
            this.pos += got;
        }

        // Phase 3: decrypt the whole buffer in place.
        state.decrypt(&mut this.buf[..]);
        Poll::Ready(Ok(this.pos))
    }
}