use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use super::cipher::Rc4;
/// Capacity, in bytes, pre-allocated for the outgoing (encrypted) write buffer of a stream
/// built with `MseStream::encrypted`.
const WRITE_BUF_CAPACITY: usize = 16 * 1024;
/// Wrapper around a byte stream that can apply an RC4 cipher to traffic in each direction.
///
/// Built with `MseStream::plaintext` (no ciphers, bytes pass through unchanged) or
/// `MseStream::encrypted` (independent read and write ciphers, plus optional bytes to
/// return before reading from the inner stream).
pub struct MseStream<S> {
    /// Underlying transport that reads and writes are delegated to.
    inner: S,
    /// When `Some`, applied in place to bytes read from `inner` (never to `initial_read`).
    read_cipher: Option<Rc4>,
    /// When `Some`, applied to outgoing bytes before they are handed to `inner`.
    write_cipher: Option<Rc4>,
    /// Holds encrypted outgoing bytes; pre-allocated to `WRITE_BUF_CAPACITY` by `encrypted`.
    write_buf: Vec<u8>,
    /// Bytes handed out by `poll_read` before anything is read from `inner`, copied verbatim
    /// with no cipher applied. NOTE(review): presumably data already read (and decrypted)
    /// while setting up the stream — confirm with the caller of `encrypted`.
    initial_read: Vec<u8>,
}
impl<S> MseStream<S> {
pub fn plaintext(inner: S) -> Self {
Self {
inner,
read_cipher: None,
write_cipher: None,
write_buf: Vec::new(),
initial_read: Vec::new(),
}
}
pub(crate) fn encrypted(
inner: S,
read_cipher: Rc4,
write_cipher: Rc4,
initial_read: Vec<u8>,
) -> Self {
Self {
inner,
read_cipher: Some(read_cipher),
write_cipher: Some(write_cipher),
write_buf: Vec::with_capacity(WRITE_BUF_CAPACITY),
initial_read,
}
}
}
impl<S: AsyncRead + Unpin> AsyncRead for MseStream<S> {
    /// Reads into `buf`.
    ///
    /// Bytes still held in `initial_read` are served first: they are copied out unchanged and
    /// `inner` is not polled. After they are exhausted, data comes from `inner`. If a read
    /// cipher is configured, it is applied in place to exactly the bytes this call added.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let me = self.get_mut();

        let buffered = me.initial_read.len();
        if buffered > 0 {
            let take = std::cmp::min(buffered, buf.remaining());
            buf.put_slice(&me.initial_read[..take]);
            me.initial_read.drain(..take);
            return Poll::Ready(Ok(()));
        }

        // Remember where this read starts so only freshly filled bytes go through the cipher.
        let start = buf.filled().len();
        let polled = Pin::new(&mut me.inner).poll_read(cx, buf);
        if matches!(polled, Poll::Ready(Ok(()))) {
            if let Some(rc4) = me.read_cipher.as_mut() {
                rc4.apply(&mut buf.filled_mut()[start..]);
            }
        }
        polled
    }
}
impl<S: AsyncWrite + Unpin> AsyncWrite for MseStream<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
if let Some(cipher) = &mut this.write_cipher {
this.write_buf.clear();
this.write_buf.extend_from_slice(buf);
cipher.apply(&mut this.write_buf);
Pin::new(&mut this.inner).poll_write(cx, &this.write_buf)
} else {
Pin::new(&mut this.inner).poll_write(cx, buf)
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Without ciphers, bytes arrive unchanged.
    #[tokio::test]
    async fn plaintext_passthrough() {
        let (client, server) = tokio::io::duplex(1024);
        let mut client = MseStream::plaintext(client);
        let mut server = MseStream::plaintext(server);
        client.write_all(b"hello").await.unwrap();
        client.flush().await.unwrap();
        let mut buf = [0u8; 5];
        server.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello");
    }

    /// Each side's write cipher matches the peer's read cipher, so data round-trips.
    #[tokio::test]
    async fn encrypted_roundtrip() {
        let key_a = b"key for direction A!";
        let key_b = b"key for direction B!";
        let (raw_client, raw_server) = tokio::io::duplex(1024);
        let mut client = MseStream::encrypted(
            raw_client,
            Rc4::new(key_b), Rc4::new(key_a), Vec::new(),
        );
        let mut server = MseStream::encrypted(
            raw_server,
            Rc4::new(key_a), Rc4::new(key_b), Vec::new(),
        );
        client.write_all(b"client to server").await.unwrap();
        client.flush().await.unwrap();
        let mut buf = [0u8; 16];
        server.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"client to server");
        server.write_all(b"server to client").await.unwrap();
        server.flush().await.unwrap();
        let mut buf = [0u8; 16];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"server to client");
    }

    #[test]
    fn encrypted_write_buf_pre_allocated() {
        let (raw, _) = tokio::io::duplex(1024);
        let stream = MseStream::encrypted(raw, Rc4::new(b"r"), Rc4::new(b"w"), Vec::new());
        assert_eq!(stream.write_buf.capacity(), WRITE_BUF_CAPACITY);
    }

    #[tokio::test]
    async fn encrypted_no_realloc_on_chunk_write() {
        let (raw_client, _raw_server) = tokio::io::duplex(32768);
        let mut client =
            MseStream::encrypted(raw_client, Rc4::new(b"r"), Rc4::new(b"w"), Vec::new());
        let data = vec![0xABu8; WRITE_BUF_CAPACITY];
        client.write_all(&data).await.unwrap();
        assert_eq!(client.write_buf.capacity(), WRITE_BUF_CAPACITY);
    }

    /// `initial_read` is returned verbatim first; later bytes from `inner` go through the
    /// read cipher.
    #[tokio::test]
    async fn initial_read_drains_before_inner() {
        let (raw_client, mut raw_server) = tokio::io::duplex(1024);
        let initial = b"overflow".to_vec();
        let mut client = MseStream::encrypted(raw_client, Rc4::new(b"r"), Rc4::new(b"w"), initial);
        raw_server.write_all(b"inner").await.unwrap();
        raw_server.flush().await.unwrap();
        let mut buf = [0u8; 8];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"overflow");
        let mut buf = [0u8; 5];
        client.read_exact(&mut buf).await.unwrap();
        // A fresh cipher with the same key reproduces the transformation the stream applied,
        // pinning the exact bytes instead of only the (fixed) array length.
        let mut expected = *b"inner";
        Rc4::new(b"r").apply(&mut expected[..]);
        assert_eq!(buf, expected);
    }
}