use std::io::{self, Read, BufRead};
use std::mem;
use std::time::Duration;
use futures::{Future, Poll, Async};
use tokio_core::reactor::{Handle, Timeout};
use tokio_io::io::{copy, Copy};
use tokio_io::{AsyncRead, AsyncWrite};
use bytes::{BufMut, BytesMut};
use super::BUFFER_SIZE;
use super::utils::{copy_timeout, copy_timeout_opt, CopyTimeout, CopyTimeoutOpt};
/// Zero-filled input of `BUFFER_SIZE` bytes. The constructors in this file pass
/// it to `EncryptedWrite::buffer_size` to choose the initial capacity of their
/// encryption buffers.
static DUMMY_BUFFER: [u8; BUFFER_SIZE] = [0u8; BUFFER_SIZE];
/// A buffered, asynchronous reader that can be drained into an `AsyncWrite`
/// through one of the provided copy adaptors.
pub trait DecryptedRead: BufRead + AsyncRead {
    /// Returns a buffer size for `data`; how the value is derived is up to the
    /// implementor.
    fn buffer_size(&self, data: &[u8]) -> usize;

    /// Consumes the reader and forwards it, together with `w`, to
    /// `tokio_io::io::copy`.
    fn copy<W: AsyncWrite>(self, w: W) -> Copy<Self, W>
    where
        Self: Sized,
    {
        copy(self, w)
    }

    /// Consumes the reader and forwards everything to
    /// `super::utils::copy_timeout`.
    fn copy_timeout<W: AsyncWrite>(
        self,
        w: W,
        timeout: Duration,
        handle: Handle,
    ) -> CopyTimeout<Self, W>
    where
        Self: Sized,
    {
        copy_timeout(self, w, timeout, handle)
    }

    /// Consumes the reader and forwards everything to
    /// `super::utils::copy_timeout_opt`; `timeout` is optional here.
    fn copy_timeout_opt<W: AsyncWrite>(
        self,
        w: W,
        timeout: Option<Duration>,
        handle: Handle,
    ) -> CopyTimeoutOpt<Self, W>
    where
        Self: Sized,
    {
        copy_timeout_opt(self, w, timeout, handle)
    }
}
/// A writer that encrypts data into a buffer and pushes the resulting bytes to
/// an underlying sink.
pub trait EncryptedWrite {
    /// Writes `data` to the sink and returns how many bytes were accepted.
    /// The futures in this file only call it with output of `encrypt`.
    fn write_raw(&mut self, data: &[u8]) -> io::Result<usize>;

    /// Flushes the sink.
    fn flush(&mut self) -> io::Result<()>;

    /// Encrypts `data`, placing the output in `buf`.
    fn encrypt<B: BufMut>(&mut self, data: &[u8], buf: &mut B) -> io::Result<()>;

    /// Returns the capacity callers should reserve in the output buffer before
    /// calling `encrypt` on `data`.
    fn buffer_size(&self, data: &[u8]) -> usize;

    /// Consumes the writer and returns a future that encrypts `buf` and writes
    /// all of the result.
    fn write_all<B>(self, buf: B) -> EncryptedWriteAll<Self, B>
    where
        B: AsRef<[u8]>,
        Self: Sized,
    {
        EncryptedWriteAll::new(self, buf)
    }

    /// Consumes the writer and returns a future that reads from `r`, encrypts
    /// what it reads and writes it out.
    fn copy<R>(self, r: R) -> EncryptedCopy<R, Self>
    where
        R: Read,
        Self: Sized,
    {
        EncryptedCopy::new(r, self)
    }

    /// Same as `copy`, but the returned future is additionally given `timeout`
    /// and a reactor `handle` for its timer.
    fn copy_timeout<R>(
        self,
        r: R,
        timeout: Duration,
        handle: Handle,
    ) -> EncryptedCopyTimeout<R, Self>
    where
        R: Read,
        Self: Sized,
    {
        EncryptedCopyTimeout::new(r, self, timeout, handle)
    }

    /// Picks `copy_timeout` when `timeout` is `Some`, plain `copy` otherwise.
    fn copy_timeout_opt<R>(
        self,
        r: R,
        timeout: Option<Duration>,
        handle: Handle,
    ) -> EncryptedCopyOpt<R, Self>
    where
        R: Read,
        Self: Sized,
    {
        if let Some(t) = timeout {
            EncryptedCopyOpt::CopyTimeout(self.copy_timeout(r, t, handle))
        } else {
            EncryptedCopyOpt::Copy(self.copy(r))
        }
    }
}
/// Future that encrypts a caller-supplied buffer once and then writes the
/// encrypted bytes until none are left, resolving to `(writer, buf)`.
pub enum EncryptedWriteAll<W: EncryptedWrite, B: AsRef<[u8]>> {
    /// The write is still in progress.
    Writing {
        /// Destination of the encrypted bytes.
        writer: W,
        /// Input supplied by the caller.
        buf: B,
        /// Number of bytes of `enc_buf` already accepted by the writer.
        pos: usize,
        /// Encrypted form of `buf`, filled on the first poll.
        enc_buf: BytesMut,
        /// Whether `buf` has already been encrypted into `enc_buf`.
        encrypted: bool,
    },
    /// State left behind once the future has resolved; polling it panics.
    Empty,
}
impl<W: EncryptedWrite, B: AsRef<[u8]>> EncryptedWriteAll<W, B> {
    /// Builds the future in its `Writing` state with nothing encrypted yet.
    ///
    /// The encryption buffer's initial capacity is whatever the writer reports
    /// for a `BUFFER_SIZE`-byte input.
    fn new(w: W, buf: B) -> EncryptedWriteAll<W, B> {
        let enc_buf = BytesMut::with_capacity(w.buffer_size(&DUMMY_BUFFER));
        EncryptedWriteAll::Writing {
            writer: w,
            buf: buf,
            pos: 0,
            enc_buf: enc_buf,
            encrypted: false,
        }
    }
}
impl<W, B> Future for EncryptedWriteAll<W, B>
where W: EncryptedWrite,
B: AsRef<[u8]>
{
type Item = (W, B);
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match *self {
EncryptedWriteAll::Empty => panic!("poll after EncryptedWriteAll finished"),
EncryptedWriteAll::Writing { ref mut writer, ref buf, ref mut pos, ref mut enc_buf, ref mut encrypted } => {
if !*encrypted {
*encrypted = true;
let buffer_len = writer.buffer_size(buf.as_ref());
enc_buf.reserve(buffer_len);
writer.encrypt(buf.as_ref(), enc_buf)?;
}
while *pos < enc_buf.len() {
let n = try_nb!(writer.write_raw(&enc_buf[*pos..]));
*pos += n;
if n == 0 {
let err = io::Error::new(io::ErrorKind::Other, "zero-length write");
return Err(err);
}
}
}
}
match mem::replace(self, EncryptedWriteAll::Empty) {
EncryptedWriteAll::Writing { writer, buf, .. } => Ok((writer, buf).into()),
EncryptedWriteAll::Empty => unreachable!(),
}
}
}
/// Future that reads from `R`, encrypts each chunk and writes it to `W`,
/// resolving to `(bytes written, reader, writer)` once the reader is exhausted.
pub struct EncryptedCopy<R: Read, W: EncryptedWrite> {
    /// Source of the data; taken out when the future resolves.
    reader: Option<R>,
    /// Destination of the encrypted bytes; taken out when the future resolves.
    writer: Option<W>,
    /// Set once the reader has returned a zero-length read.
    read_done: bool,
    /// Running total of bytes accepted by `write_raw`.
    amt: u64,
    /// Start of the part of `buf` that still has to be written.
    pos: usize,
    /// End of the valid data in `buf`.
    cap: usize,
    /// Encrypted output of the most recent read.
    buf: BytesMut,
}
impl<R: Read, W: EncryptedWrite> EncryptedCopy<R, W> {
    /// Creates a copy future with empty counters and an encryption buffer
    /// whose capacity is what the writer reports for a `BUFFER_SIZE`-byte
    /// input.
    fn new(r: R, w: W) -> EncryptedCopy<R, W> {
        let initial_capacity = w.buffer_size(&DUMMY_BUFFER);
        let buf = BytesMut::with_capacity(initial_capacity);
        EncryptedCopy {
            reader: Some(r),
            writer: Some(w),
            read_done: false,
            amt: 0,
            pos: 0,
            cap: 0,
            buf: buf,
        }
    }
}
impl<R, W> Future for EncryptedCopy<R, W>
where
    R: Read,
    W: EncryptedWrite,
{
    type Item = (u64, R, W);
    type Error = io::Error;

    /// Drives the copy: read a chunk, encrypt it, write the encrypted bytes,
    /// repeat until the reader reports end of input, then flush.
    ///
    /// Resolves to `(bytes written, reader, writer)`, where the count is the
    /// number of bytes accepted by `write_raw`. Returns `NotReady` whenever
    /// the reader, the writer or the flush would block.
    ///
    /// # Errors
    ///
    /// Any I/O or encryption error is returned as-is. A writer that accepts
    /// zero bytes while data is still pending yields `UnexpectedEof`, the same
    /// error `EncryptedCopyTimeout` reports for that situation.
    ///
    /// # Panics
    ///
    /// Panics when polled again after it has resolved.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        let mut local_buf = [0u8; BUFFER_SIZE];
        loop {
            // Only read more input once the previous chunk is fully written.
            if self.pos == self.cap && !self.read_done {
                let n = try_nb!(self.reader.as_mut().unwrap().read(&mut local_buf[..]));
                self.buf.clear();
                if n == 0 {
                    self.read_done = true;
                } else {
                    let data = &local_buf[..n];
                    let buffer_len = self.writer.as_mut().unwrap().buffer_size(data);
                    self.buf.reserve(buffer_len);
                    self.writer.as_mut().unwrap().encrypt(data, &mut self.buf)?;
                }
                self.pos = 0;
                self.cap = self.buf.len();
            }

            while self.pos < self.cap {
                let i = try_nb!(self.writer
                    .as_mut()
                    .unwrap()
                    .write_raw(&self.buf[self.pos..self.cap]));
                if i == 0 {
                    // Without this guard `pos` would never advance and the
                    // loop would spin forever on a writer that takes nothing.
                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "early eof");
                    return Err(err);
                }
                self.pos += i;
                self.amt += i as u64;
            }

            if self.pos == self.cap && self.read_done {
                try_nb!(self.writer.as_mut().unwrap().flush());
                let reader = self.reader.take().unwrap();
                let writer = self.writer.take().unwrap();
                return Ok((self.amt, reader, writer).into());
            }
        }
    }
}
/// Variant of the encrypted copy future that also carries a timeout: a timer
/// is created whenever a read or a write reports `WouldBlock`.
pub struct EncryptedCopyTimeout<R: Read, W: EncryptedWrite> {
    /// Source of the data; taken out when the future resolves.
    reader: Option<R>,
    /// Destination of the encrypted bytes; taken out when the future resolves.
    writer: Option<W>,
    /// Set once the reader has returned a zero-length read.
    read_done: bool,
    /// Running total of bytes accepted by `write_raw`.
    amt: u64,
    /// Start of the part of `write_buf` that still has to be written.
    pos: usize,
    /// End of the valid data in `write_buf`.
    cap: usize,
    /// Duration given to every newly created timer.
    timeout: Duration,
    /// Reactor handle used to create timers.
    handle: Handle,
    /// Timer created after the most recent `WouldBlock`, if any.
    timer: Option<Timeout>,
    /// Scratch space the reader fills.
    read_buf: [u8; BUFFER_SIZE],
    /// Encrypted output of the most recent read.
    write_buf: BytesMut,
}
impl<R, W> EncryptedCopyTimeout<R, W>
where R: Read,
W: EncryptedWrite
{
fn new(r: R, w: W, dur: Duration, handle: Handle) -> EncryptedCopyTimeout<R, W> {
let buffer_size = w.buffer_size(&DUMMY_BUFFER);
EncryptedCopyTimeout {
reader: Some(r),
writer: Some(w),
read_done: false,
amt: 0,
pos: 0,
cap: 0,
timeout: dur,
handle: handle,
timer: None,
read_buf: [0u8; BUFFER_SIZE],
write_buf: BytesMut::with_capacity(buffer_size),
}
}
fn try_poll_timeout(&mut self) -> io::Result<()> {
match self.timer.as_mut() {
None => Ok(()),
Some(t) => {
match t.poll() {
Err(err) => Err(err),
Ok(Async::Ready(..)) => Err(io::Error::new(io::ErrorKind::TimedOut, "timeout")),
Ok(Async::NotReady) => Ok(()),
}
}
}
}
fn clear_timer(&mut self) {
let _ = self.timer.take();
}
fn read_or_set_timeout(&mut self) -> io::Result<usize> {
self.try_poll_timeout()?;
self.clear_timer();
self.write_buf.clear();
match self.reader.as_mut().unwrap().read(&mut self.read_buf) {
Ok(0) => {
self.cap = 0;
self.pos = 0;
Ok(0)
}
Ok(n) => {
let data = &self.read_buf[..n];
let buffer_len = self.writer.as_mut().unwrap().buffer_size(data);
self.write_buf.reserve(buffer_len);
self.writer.as_mut().unwrap().encrypt(data, &mut self.write_buf)?;
self.cap = self.write_buf.len();
self.pos = 0;
Ok(n)
}
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
self.timer = Some(Timeout::new(self.timeout, &self.handle).unwrap());
}
Err(e)
}
}
}
fn write_or_set_timeout(&mut self) -> io::Result<usize> {
self.try_poll_timeout()?;
self.clear_timer();
match self.writer.as_mut().unwrap().write_raw(&self.write_buf[self.pos..self.cap]) {
Ok(n) => {
self.pos += n;
Ok(n)
}
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
self.timer = Some(Timeout::new(self.timeout, &self.handle).unwrap());
}
Err(e)
}
}
}
}
impl<R, W> Future for EncryptedCopyTimeout<R, W>
where R: Read,
W: EncryptedWrite
{
type Item = (u64, R, W);
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
if self.pos == self.cap && !self.read_done {
let n = try_nb!(self.read_or_set_timeout());
if n == 0 {
self.read_done = true;
}
}
while self.pos < self.cap {
let i = try_nb!(self.write_or_set_timeout());
if i == 0 {
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "early eof");
return Err(err);
}
self.amt += i as u64;
}
if self.pos == self.cap && self.read_done {
try_nb!(self.writer.as_mut().unwrap().flush());
return Ok((self.amt, self.reader.take().unwrap(), self.writer.take().unwrap()).into());
}
}
}
}
/// Either of the two encrypted copy futures, chosen by whether a timeout was
/// supplied.
pub enum EncryptedCopyOpt<R: Read, W: EncryptedWrite> {
    /// Copy without a timeout.
    Copy(EncryptedCopy<R, W>),
    /// Copy with a timeout.
    CopyTimeout(EncryptedCopyTimeout<R, W>),
}
impl<R: Read, W: EncryptedWrite> Future for EncryptedCopyOpt<R, W> {
    type Item = (u64, R, W);
    type Error = io::Error;

    /// Delegates to whichever copy future this value wraps.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        match self {
            &mut EncryptedCopyOpt::Copy(ref mut plain) => plain.poll(),
            &mut EncryptedCopyOpt::CopyTimeout(ref mut timed) => timed.poll(),
        }
    }
}