use crate::crypto::DecryptCipher;
use futures_lite::io::AsyncRead;
use futures_timer::Delay;
use std::future::Future;
use std::io::{Error, ErrorKind, Result};
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::constants::{DEFAULT_TIMEOUT, MAX_MESSAGE_SIZE};
use crate::message::{Frame, FrameType};
use crate::util::stat_uint24_le;
use std::time::Duration;
const TIMEOUT: Duration = Duration::from_secs(DEFAULT_TIMEOUT as u64);
const READ_BUF_INITIAL_SIZE: usize = 1024 * 128;
#[derive(Debug)]
pub(crate) struct ReadState {
buf: Vec<u8>,
start: usize,
end: usize,
step: Step,
timeout: Delay,
cipher: Option<DecryptCipher>,
frame_type: FrameType,
}
impl ReadState {
pub(crate) fn new() -> ReadState {
ReadState {
buf: vec![0u8; READ_BUF_INITIAL_SIZE],
start: 0,
end: 0,
step: Step::Header,
timeout: Delay::new(TIMEOUT),
cipher: None,
frame_type: FrameType::Raw,
}
}
}
#[derive(Debug)]
enum Step {
Header,
Body {
header_len: usize,
body_len: usize,
},
Batch,
}
impl ReadState {
pub(crate) fn upgrade_with_decrypt_cipher(&mut self, decrypt_cipher: DecryptCipher) {
self.cipher = Some(decrypt_cipher);
}
pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> Result<Vec<u8>> {
if let Some(cipher) = self.cipher.as_mut() {
Ok(cipher.decrypt_buf(buf)?.0)
} else {
Ok(buf.to_vec())
}
}
pub(crate) fn set_frame_type(&mut self, frame_type: FrameType) {
self.frame_type = frame_type;
}
pub(crate) fn poll_reader<R>(
&mut self,
cx: &mut Context<'_>,
mut reader: &mut R,
) -> Poll<Result<Frame>>
where
R: AsyncRead + Unpin,
{
let mut incomplete = true;
loop {
if !incomplete {
if let Some(result) = self.process() {
return Poll::Ready(result);
}
} else {
incomplete = false;
}
let n = match Pin::new(&mut reader).poll_read(cx, &mut self.buf[self.end..]) {
Poll::Ready(Ok(n)) if n > 0 => n,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending | Poll::Ready(Ok(_)) => {
return Pin::new(&mut self.timeout)
.poll(cx)
.map(|()| Err(Error::new(ErrorKind::TimedOut, "Remote timed out")));
}
};
let end = self.end + n;
let (success, segments) = create_segments(&self.buf[self.start..end])?;
if success {
if let Some(ref mut cipher) = self.cipher {
let mut dec_end = self.start;
for (index, header_len, body_len) in segments {
let de = cipher.decrypt(
&mut self.buf[self.start + index..end],
header_len,
body_len,
)?;
dec_end = self.start + index + de;
}
self.end = dec_end;
} else {
self.end = end;
}
} else {
self.cycle_buf_and_resize_if_needed(segments[segments.len() - 1]);
incomplete = true;
}
self.timeout.reset(TIMEOUT);
}
}
fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) {
let (last_index, last_header_len, last_body_len) = last_segment;
let total_incoming_length = last_index + last_header_len + last_body_len;
if self.buf.len() < total_incoming_length {
self.buf.resize(total_incoming_length, 0u8);
}
let temp = self.buf[self.start..].to_vec();
let len = temp.len();
self.buf[..len].copy_from_slice(&temp[..]);
self.end = len;
self.start = 0;
}
fn process(&mut self) -> Option<Result<Frame>> {
loop {
match self.step {
Step::Header => {
let stat = stat_uint24_le(&self.buf[self.start..self.end]);
if let Some((header_len, body_len)) = stat {
if body_len == 0 {
self.start += header_len;
return None;
} else if (self.start + header_len + body_len as usize) < self.end {
self.step = Step::Batch;
} else {
let body_len = body_len as usize;
if body_len > MAX_MESSAGE_SIZE as usize {
return Some(Err(Error::new(
ErrorKind::InvalidData,
"Message length above max allowed size",
)));
}
self.step = Step::Body {
header_len,
body_len,
};
}
} else {
return Some(Err(Error::new(ErrorKind::InvalidData, "Invalid header")));
}
}
Step::Body {
header_len,
body_len,
} => {
let message_len = header_len + body_len;
let range = self.start + header_len..self.start + message_len;
let frame = Frame::decode(&self.buf[range], &self.frame_type);
self.start += message_len;
self.step = Step::Header;
return Some(frame);
}
Step::Batch => {
let frame =
Frame::decode_multiple(&self.buf[self.start..self.end], &self.frame_type);
self.start = self.end;
self.step = Step::Header;
return Some(frame);
}
}
}
}
}
#[allow(clippy::type_complexity)]
fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> {
let mut index: usize = 0;
let len = buf.len();
let mut segments: Vec<(usize, usize, usize)> = vec![];
while index < len {
if let Some((header_len, body_len)) = stat_uint24_le(&buf[index..]) {
let body_len = body_len as usize;
segments.push((index, header_len, body_len));
if len < index + header_len + body_len {
return Ok((false, segments));
}
index += header_len + body_len;
} else {
return Err(Error::new(
ErrorKind::InvalidData,
"Could not read header while decrypting",
));
}
}
Ok((true, segments))
}