use crate::qpp::{create_prng, QuantumPermutationPad};
use std::sync::Arc;
use std::sync::Mutex;
use tokio::io::{AsyncRead, AsyncWrite};
/// Read side of a QPP-encrypted stream.
///
/// Whatever `inner` yields is decrypted in place, using the shared
/// [`QuantumPermutationPad`] together with this half's own PRNG, before the
/// caller gets to see it.
pub struct QppReadHalf<S> {
    /// Underlying reader that produces ciphertext.
    inner: S,
    /// Permutation pad; may be shared with a write half (`wrap_with_qpp` does so).
    pad: Arc<QuantumPermutationPad>,
    /// Read-direction PRNG, handed mutably to `decrypt_with_prng` on every read
    /// that returns data.
    rprng: Arc<Mutex<crate::qpp::Rand>>,
}
/// Write side of a QPP-encrypted stream.
///
/// Bytes passed to `poll_write` are encrypted with the shared
/// [`QuantumPermutationPad`] and this half's PRNG, then forwarded to `inner`.
///
/// Encryption advances the PRNG, so every plaintext byte must be encrypted exactly
/// once and in order. `inner` may accept only part of a buffer, or none of it
/// (`Pending`). Ciphertext that has been produced but not yet accepted is therefore
/// kept in `pending` and written out before any new data is encrypted. As the
/// `AsyncWrite` contract requires, callers must `flush`/`shutdown` to guarantee that
/// everything reported as written has reached `inner`. Dropping the half with
/// unflushed data discards that data.
pub struct QppWriteHalf<S> {
    /// Underlying writer that receives ciphertext.
    inner: S,
    /// Permutation pad; may be shared with a read half (`wrap_with_qpp` does so).
    pad: Arc<QuantumPermutationPad>,
    /// Write-direction PRNG, handed mutably to `encrypt_with_prng`.
    wprng: Arc<Mutex<crate::qpp::Rand>>,
    /// Encrypted bytes not yet accepted by `inner`; the PRNG has already advanced
    /// past them, so they must be sent as-is and never re-encrypted.
    pending: Vec<u8>,
    /// Number of bytes at the front of `pending` already written to `inner`.
    pending_pos: usize,
}

impl<S: AsyncRead + Unpin> QppReadHalf<S> {
    /// Wraps `inner` so that bytes read from it are decrypted with `pad`.
    ///
    /// `seed` initialises this half's private PRNG via `create_prng`. Presumably the
    /// peer's encrypting PRNG must be seeded identically for decryption to succeed;
    /// confirm against `crate::qpp`.
    pub fn new(inner: S, pad: Arc<QuantumPermutationPad>, seed: &[u8]) -> Self {
        QppReadHalf {
            inner,
            pad,
            rprng: Arc::new(Mutex::new(create_prng(seed))),
        }
    }
}

impl<S: AsyncWrite + Unpin> QppWriteHalf<S> {
    /// Wraps `inner` so that bytes written through it are encrypted with `pad`.
    ///
    /// `seed` initialises this half's private PRNG via `create_prng`.
    pub fn new(inner: S, pad: Arc<QuantumPermutationPad>, seed: &[u8]) -> Self {
        QppWriteHalf {
            inner,
            pad,
            wprng: Arc::new(Mutex::new(create_prng(seed))),
            pending: Vec::new(),
            pending_pos: 0,
        }
    }

    /// Pushes buffered ciphertext into `inner` until all of it has been accepted.
    ///
    /// Returns `Ready(Ok(()))` once nothing is pending (immediately if the buffer is
    /// already empty). Returns `Pending` if `inner` applies back-pressure, or the first
    /// error reported by `inner`. A zero-length write from `inner` is reported as
    /// `ErrorKind::WriteZero` so the caller does not spin forever.
    fn poll_drain(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        while self.pending_pos < self.pending.len() {
            let unsent = &self.pending[self.pending_pos..];
            match std::pin::Pin::new(&mut self.inner).poll_write(cx, unsent) {
                std::task::Poll::Ready(Ok(0)) => {
                    return std::task::Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "failed to write buffered qpp ciphertext",
                    )));
                }
                std::task::Poll::Ready(Ok(n)) => self.pending_pos += n,
                std::task::Poll::Ready(Err(e)) => return std::task::Poll::Ready(Err(e)),
                std::task::Poll::Pending => return std::task::Poll::Pending,
            }
        }
        // Everything went out: reset but keep the allocation for the next write.
        self.pending.clear();
        self.pending_pos = 0;
        std::task::Poll::Ready(Ok(()))
    }
}

impl<S: AsyncRead + Unpin> AsyncRead for QppReadHalf<S> {
    /// Reads ciphertext from `inner` and decrypts, in place, only the bytes this call
    /// appended to `buf`. Bytes already filled by the caller are left untouched.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let filled_before = buf.filled().len();
        match std::pin::Pin::new(&mut this.inner).poll_read(cx, buf) {
            std::task::Poll::Ready(Ok(())) => {
                let fresh = &mut buf.filled_mut()[filled_before..];
                // Only touch the PRNG when data actually arrived (an empty read is EOF).
                if !fresh.is_empty() {
                    let mut rprng = this.rprng.lock().unwrap();
                    this.pad.decrypt_with_prng(fresh, &mut rprng);
                }
                std::task::Poll::Ready(Ok(()))
            }
            other => other,
        }
    }
}

impl<S: AsyncWrite + Unpin> AsyncWrite for QppWriteHalf<S> {
    /// Encrypts `buf` and forwards it to `inner`.
    ///
    /// Returns `Pending` without touching the PRNG while earlier ciphertext is still
    /// waiting for `inner`. Otherwise all of `buf` is encrypted and `Ok(buf.len())` is
    /// returned. Any part `inner` did not accept yet stays buffered and is sent by a
    /// later `poll_write`, `poll_flush` or `poll_shutdown`.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        // Earlier ciphertext must reach `inner` before anything newer is produced;
        // otherwise the byte order on the wire would no longer match the keystream.
        match this.poll_drain(cx) {
            std::task::Poll::Ready(Ok(())) => {}
            std::task::Poll::Ready(Err(e)) => return std::task::Poll::Ready(Err(e)),
            std::task::Poll::Pending => return std::task::Poll::Pending,
        }
        if buf.is_empty() {
            return std::task::Poll::Ready(Ok(0));
        }
        // `pending` is empty after a successful drain; reuse its allocation.
        this.pending.extend_from_slice(buf);
        {
            let mut wprng = this.wprng.lock().unwrap();
            this.pad.encrypt_with_prng(&mut this.pending, &mut wprng);
        }
        // The PRNG has now advanced over all of `buf`, so these bytes are committed and
        // must be reported as written. A caller retrying them would encrypt them a second
        // time and desynchronise the peer.
        //
        // Try to send right away so the common case behaves like a direct write.
        // Anything `inner` does not take now stays in `pending` and is retried on the
        // next poll_write / poll_flush / poll_shutdown. That covers back-pressure,
        // partial writes and errors; an error surfaces on that retry.
        let _ = this.poll_drain(cx);
        std::task::Poll::Ready(Ok(buf.len()))
    }

    /// Sends any buffered ciphertext, then flushes `inner`.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this.poll_drain(cx) {
            std::task::Poll::Ready(Ok(())) => std::pin::Pin::new(&mut this.inner).poll_flush(cx),
            other => other,
        }
    }

    /// Sends any buffered ciphertext, then shuts down `inner`.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this.poll_drain(cx) {
            std::task::Poll::Ready(Ok(())) => {
                std::pin::Pin::new(&mut this.inner).poll_shutdown(cx)
            }
            other => other,
        }
    }
}
/// Wraps the read and write halves of a split stream in QPP encryption.
///
/// One [`QuantumPermutationPad`] with `qpp_count` pads is derived from `seed` and
/// shared by both halves. Each half also gets its own PRNG, seeded from the same `seed`.
///
/// NOTE(review): both PRNGs come from the same `seed`, so the read and write
/// directions presumably use identical keystreams. This cannot be changed here without
/// breaking compatibility with peers, but it should be confirmed as acceptable.
///
/// # Panics
///
/// Panics if `qpp_count` is greater than `u16::MAX`. Silently truncating it would
/// produce a different (possibly zero) pad count.
pub fn wrap_with_qpp<R, W>(
    read_half: R,
    write_half: W,
    seed: &[u8],
    qpp_count: u32,
) -> (QppReadHalf<R>, QppWriteHalf<W>)
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
{
    assert!(
        qpp_count <= u32::from(u16::MAX),
        "qpp_count must be at most {}, got {}",
        u16::MAX,
        qpp_count
    );
    // Lossless: the range was checked by the assert above.
    let pad = Arc::new(QuantumPermutationPad::new(seed, qpp_count as u16));
    (
        QppReadHalf::new(read_half, Arc::clone(&pad), seed),
        QppWriteHalf::new(write_half, pad, seed),
    )
}