use std::fmt;
use std::pin::Pin;
use async_std::io::{self, Read, Write};
use async_std::prelude::*;
use async_std::stream::Stream;
use async_std::sync::Arc;
use byte_pool::{Block, BytePool};
use futures::task::{Context, Poll};
use nom::Needed;
use crate::types::{Request, ResponseData};
lazy_static::lazy_static! {
/// Shared byte pool: every block used by `Buffer` (initial, swapped-in, and reset
/// blocks) is allocated from here.
pub(crate) static ref POOL: Arc<BytePool> = Arc::new(BytePool::new());
}
#[derive(Debug)]
pub struct ImapStream<R: Read + Write> {
pub(crate) inner: R,
decode_needs: Option<usize>,
buffer: Buffer,
closed: bool,
}
impl<R: Read + Write + Unpin> ImapStream<R> {
pub fn new(inner: R) -> Self {
ImapStream {
inner,
buffer: Buffer::new(),
decode_needs: None,
closed: false,
}
}
pub async fn encode(&mut self, msg: Request) -> Result<(), io::Error> {
if self.closed {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"inner stream closed",
));
}
log::trace!(
"encode: input: {:?}, {:?}",
msg.0,
std::str::from_utf8(&msg.1)
);
if let Some(tag) = msg.0 {
self.inner.write_all(tag.as_bytes()).await?;
self.inner.write(b" ").await?;
}
self.inner.write_all(&msg.1).await?;
self.inner.write_all(b"\r\n").await?;
Ok(())
}
pub fn into_inner(self) -> R {
self.inner
}
pub async fn flush(&mut self) -> Result<(), io::Error> {
self.inner.flush().await
}
pub fn as_mut(&mut self) -> &mut R {
&mut self.inner
}
fn stream_eof_value(&self) -> Option<io::Result<ResponseData>> {
match self.buffer.used() {
0 => None,
_ => Some(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"bytes remaining in stream",
))),
}
}
}
impl<R: Read + Write + Unpin> ImapStream<R> {
/// Parses one response from the buffer, but only if more bytes are buffered than the
/// size hint left by the last incomplete parse (with no hint: if any bytes are buffered).
/// Otherwise returns `Ok(None)` without attempting a parse.
fn maybe_decode(&mut self) -> io::Result<Option<ResponseData>> {
if self.buffer.used() > self.decode_needs.unwrap_or(0) {
self.decode()
} else {
Ok(None)
}
}
/// Parses one response from the buffered bytes.
///
/// The current block is moved into `ResponseData::try_new` (a rental wrapper), so the
/// parsed response can borrow from it, and the buffer receives a fresh block. On success
/// the unparsed tail is copied into a new buffer block. On an incomplete or failed parse,
/// the original block is handed back to the buffer with its data intact.
///
/// Returns `Ok(Some(_))` for a complete response, `Ok(None)` when more data is needed,
/// and `Err(_)` when the parser rejects the input.
fn decode(&mut self) -> io::Result<Option<ResponseData>> {
let block: Block<'static> = self.buffer.take_block();
let res = ResponseData::try_new(block, |buf| {
// `take_block` swapped the block but left `offset` untouched, so `used()` is
// still the number of valid bytes in the block being parsed.
let buf = &buf[..self.buffer.used()];
log::trace!("decode: input: {:?}", std::str::from_utf8(buf));
match imap_proto::parse_response(buf) {
Ok((remaining, response)) => {
self.decode_needs = None;
// Keep the bytes after this response (possibly further responses) for the next call.
self.buffer.reset_with_data(remaining);
Ok(response)
}
Err(nom::Err::Incomplete(Needed::Size(min))) => {
log::trace!("decode: incomplete data, need minimum {} bytes", min);
// Remembered so `maybe_decode` skips re-parsing until more than `min` bytes
// are buffered, and `ensure_capacity` grows the block to at least `min`.
self.decode_needs = Some(min);
Err(None)
}
Err(nom::Err::Incomplete(_)) => {
log::trace!("decode: incomplete data, need unknown number of bytes");
self.decode_needs = None;
Err(None)
}
Err(other) => {
self.decode_needs = None;
// NOTE(review): the rejected bytes are returned to the buffer below, so the
// next poll re-parses them and will likely fail the same way again. Also,
// formatting the whole `buf` with `{:?}` can produce a very large message.
Err(Some(io::Error::new(
io::ErrorKind::Other,
format!("{:?} during parsing of {:?}", other, buf),
)))
}
}
});
match res {
Ok(response) => Ok(Some(response)),
Err(rental::RentalError(err, block)) => {
// No response was produced: restore the original block and its data.
self.buffer.return_block(block);
match err {
Some(err) => Err(err),
// `None` means "incomplete input", which is not an error for the caller.
None => Ok(None),
}
}
}
}
}
/// Growable receive buffer backed by a block from [`POOL`].
///
/// The first `offset` bytes of `block` hold received data not yet consumed by the parser;
/// the rest of the block is free space for the next read.
struct Buffer {
    /// Pooled backing storage.
    block: Block<'static>,
    /// Number of bytes at the start of `block` holding unconsumed data.
    offset: usize,
}

impl Buffer {
    /// Allocation granularity: block sizes are rounded up to a multiple of this.
    const BLOCK_SIZE: usize = 1024 * 4;
    /// Largest block size `grow` will produce; larger requests are reported as errors.
    const MAX_CAPACITY: usize = 512 * 1024 * 1024;

    /// Creates an empty buffer with a single `BLOCK_SIZE` block from [`POOL`].
    fn new() -> Self {
        Self {
            block: POOL.alloc(Self::BLOCK_SIZE),
            offset: 0,
        }
    }

    /// Number of buffered bytes not yet consumed.
    fn used(&self) -> usize {
        self.offset
    }

    /// The free tail of the block, into which the next read should write.
    fn free_as_mut_slice(&mut self) -> &mut [u8] {
        &mut self.block[self.offset..]
    }

    /// Marks `num_bytes` more bytes as used, clamped to the block size.
    fn extend_used(&mut self, num_bytes: usize) {
        self.offset = self
            .offset
            .saturating_add(num_bytes)
            .min(self.block.size());
    }

    /// Ensures the block has at least one free byte and, when `required` is given, a total
    /// size of at least `required` bytes. Grows it by at least `BLOCK_SIZE` if needed.
    ///
    /// # Errors
    ///
    /// Returns an error if growing would exceed [`Buffer::MAX_CAPACITY`].
    fn ensure_capacity(&mut self, required: Option<usize>) -> io::Result<()> {
        let free_bytes: usize = self.block.size() - self.offset;
        // `required` is a minimum *total* block size, so compare against the whole block.
        let extra_bytes_needed: usize = required.unwrap_or(0).saturating_sub(self.block.size());
        if free_bytes == 0 || extra_bytes_needed > 0 {
            let increase = std::cmp::max(Self::BLOCK_SIZE, extra_bytes_needed);
            self.grow(increase)?;
        }
        Ok(())
    }

    /// Grows the block by at least `num_bytes`, rounding the new size up to a multiple of
    /// [`Buffer::BLOCK_SIZE`].
    ///
    /// # Errors
    ///
    /// Returns an error, leaving the block untouched, if the new size would exceed
    /// [`Buffer::MAX_CAPACITY`] or overflow `usize`.
    fn grow(&mut self, num_bytes: usize) -> io::Result<()> {
        // `num_bytes` can come from the parser's size hint on incoming data. Checked
        // arithmetic stops a huge value from wrapping around and shrinking the block.
        let new_size = self
            .block
            .size()
            .checked_add(num_bytes)
            .and_then(Self::round_up_to_block)
            .filter(|&size| size <= Self::MAX_CAPACITY)
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "incoming data too large"))?;
        self.block.realloc(new_size);
        Ok(())
    }

    /// Rounds `size` up to the next multiple of `BLOCK_SIZE`, or `None` on overflow.
    fn round_up_to_block(size: usize) -> Option<usize> {
        match size % Self::BLOCK_SIZE {
            0 => Some(size),
            rem => size.checked_add(Self::BLOCK_SIZE - rem),
        }
    }

    /// Moves the current block out, replacing it with a fresh `BLOCK_SIZE` block.
    /// `offset` is left unchanged.
    fn take_block(&mut self) -> Block<'static> {
        std::mem::replace(&mut self.block, POOL.alloc(Self::BLOCK_SIZE))
    }

    /// Replaces the block with a fresh one containing a copy of `data`, followed by at
    /// least one free byte. A block-aligned length gets a whole extra `BLOCK_SIZE`.
    fn reset_with_data(&mut self, data: &[u8]) {
        let min_size = data.len();
        let new_size = match min_size % Self::BLOCK_SIZE {
            0 => min_size + Self::BLOCK_SIZE,
            n => min_size + (Self::BLOCK_SIZE - n),
        };
        self.block = POOL.alloc(new_size);
        self.block[..data.len()].copy_from_slice(data);
        self.offset = data.len();
    }

    /// Reinstalls a block previously obtained from [`Buffer::take_block`].
    fn return_block(&mut self, block: Block<'static>) {
        self.block = block;
    }
}
impl fmt::Debug for Buffer {
    /// Summarises the buffer by fill level and total block size; the bytes themselves
    /// are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let used = self.used();
        let capacity = self.block.size();
        let mut summary = f.debug_struct("Buffer");
        summary.field("used", &used);
        summary.field("capacity", &capacity);
        summary.finish()
    }
}
impl<R: Read + Write + Unpin> Stream for ImapStream<R> {
    type Item = io::Result<ResponseData>;

    /// Produces the next server response.
    ///
    /// Each iteration first tries to parse what is already buffered. If that yields
    /// nothing and the peer has not closed the connection, it reads more bytes from
    /// `inner`. This repeats until a response parses, the reader is `Pending`, an error
    /// occurs, or a zero-length read marks end of stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = &mut *self;

        loop {
            // Buffered data may already contain one or more complete responses.
            if let Some(parsed) = stream.maybe_decode()? {
                return Poll::Ready(Some(Ok(parsed)));
            }
            if stream.closed {
                return Poll::Ready(stream.stream_eof_value());
            }

            // Guarantee room for the read and honour any size hint from the parser.
            stream.buffer.ensure_capacity(stream.decode_needs)?;

            let free_space = stream.buffer.free_as_mut_slice();
            let bytes_read = match Pin::new(&mut stream.inner).poll_read(cx, free_space) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(outcome) => outcome?,
            };

            if bytes_read == 0 {
                stream.closed = true;
                return Poll::Ready(stream.stream_eof_value());
            }

            stream.buffer.extend_used(bytes_read);
        }
    }
}
// Unit tests for the `Buffer` receive buffer.
#[cfg(test)]
mod tests {
use super::*;
// Brings `write_all` on `&mut [u8]` into scope for `test_buffer_write_read`.
use std::io::Write;
// A fresh buffer is empty and its whole single block is free.
#[test]
fn test_buffer_empty() {
let buf = Buffer::new();
assert_eq!(buf.used(), 0);
let mut buf = Buffer::new();
let slice: &[u8] = buf.free_as_mut_slice();
assert_eq!(slice.len(), Buffer::BLOCK_SIZE);
assert_eq!(slice.len(), buf.block.size());
}
// `extend_used` shrinks the free tail and clamps at the block size.
#[test]
fn test_buffer_extend_use() {
let mut buf = Buffer::new();
buf.extend_used(3);
assert_eq!(buf.used(), 3);
let slice = buf.free_as_mut_slice();
assert_eq!(slice.len(), Buffer::BLOCK_SIZE - 3);
// Over-extending is clamped to the block size rather than overflowing the block.
buf.extend_used(Buffer::BLOCK_SIZE);
assert_eq!(buf.used(), Buffer::BLOCK_SIZE);
assert_eq!(buf.offset, Buffer::BLOCK_SIZE);
assert_eq!(buf.block.len(), buf.offset);
let slice = buf.free_as_mut_slice();
assert_eq!(slice.len(), 0);
}
// Bytes written into the free slice become visible once marked as used.
#[test]
fn test_buffer_write_read() {
let mut buf = Buffer::new();
let mut slice = buf.free_as_mut_slice();
slice.write_all(b"hello").unwrap();
buf.extend_used(b"hello".len());
let slice = &buf.block[..buf.used()];
assert_eq!(slice, b"hello");
assert_eq!(buf.free_as_mut_slice().len(), buf.block.size() - buf.offset);
}
// `grow` rounds up to whole blocks and refuses to exceed `MAX_CAPACITY`.
#[test]
fn test_buffer_grow() {
let mut buf = Buffer::new();
assert_eq!(buf.block.size(), Buffer::BLOCK_SIZE);
buf.grow(1).unwrap();
assert_eq!(buf.block.size(), 2 * Buffer::BLOCK_SIZE);
buf.grow(Buffer::BLOCK_SIZE + 1).unwrap();
assert_eq!(buf.block.size(), 4 * Buffer::BLOCK_SIZE);
let ret = buf.grow(Buffer::MAX_CAPACITY);
assert!(ret.is_err());
}
// `ensure_capacity` grows only when the block is full or smaller than the required total size.
#[test]
fn test_buffer_ensure_capacity() {
let mut buf = Buffer::new();
buf.extend_used(Buffer::BLOCK_SIZE - 1);
assert_eq!(buf.free_as_mut_slice().len(), 1);
assert_eq!(buf.block.size(), Buffer::BLOCK_SIZE);
// One free byte is enough: no growth.
buf.ensure_capacity(None).unwrap();
assert_eq!(buf.free_as_mut_slice().len(), 1);
assert_eq!(buf.block.size(), Buffer::BLOCK_SIZE);
buf.extend_used(1);
assert_eq!(buf.free_as_mut_slice().len(), 0);
assert_eq!(buf.block.size(), Buffer::BLOCK_SIZE);
// Full block: grows by one BLOCK_SIZE.
buf.ensure_capacity(None).unwrap();
assert_eq!(buf.free_as_mut_slice().len(), Buffer::BLOCK_SIZE);
assert_eq!(buf.block.size(), 2 * Buffer::BLOCK_SIZE);
buf.extend_used(5);
assert_eq!(buf.offset, Buffer::BLOCK_SIZE + 5);
// Required total size exceeds the block: grows to cover it (rounded to a block).
buf.ensure_capacity(Some(3 * Buffer::BLOCK_SIZE - 6))
.unwrap();
assert_eq!(buf.free_as_mut_slice().len(), 2 * Buffer::BLOCK_SIZE - 5);
assert_eq!(buf.block.size(), 3 * Buffer::BLOCK_SIZE);
}
// `take_block` swaps in a fresh default-size block; `return_block` restores the original.
#[test]
fn test_buffer_take_and_return_block() {
let mut buf = Buffer::new();
buf.grow(1).unwrap();
let block_size = buf.block.size();
let block = buf.take_block();
assert_eq!(block.size(), block_size);
assert_ne!(buf.block.size(), block_size);
buf.return_block(block);
assert_eq!(buf.block.size(), block_size);
}
// `reset_with_data` always leaves free space, adding a whole block for aligned lengths.
#[test]
fn test_buffer_reset_with_data() {
let data: [u8; 2 * Buffer::BLOCK_SIZE] = [b'a'; 2 * Buffer::BLOCK_SIZE];
let mut buf = Buffer::new();
let block_size = buf.block.size();
assert_eq!(block_size, Buffer::BLOCK_SIZE);
buf.reset_with_data(&data);
assert_ne!(buf.block.size(), block_size);
assert_eq!(buf.block.size(), 3 * Buffer::BLOCK_SIZE);
assert!(!buf.free_as_mut_slice().is_empty());
// Empty data still yields a one-block buffer.
let data: [u8; 0] = [];
let mut buf = Buffer::new();
buf.reset_with_data(&data);
assert_eq!(buf.block.size(), Buffer::BLOCK_SIZE);
}
// The Debug output shows only `used` and `capacity`.
#[test]
fn test_buffer_debug() {
assert_eq!(
format!("{:?}", Buffer::new()),
format!(r#"Buffer {{ used: 0, capacity: {} }}"#, Buffer::BLOCK_SIZE)
);
}
}