use crate::msg::{Message, MsgHeader, MsgHeaderWrapper, Type};
use crate::mwc_core::global::header_size_bytes;
use crate::mwc_core::ser::{BufReader, ProtocolVersion, Readable};
use crate::types::{AttachmentMeta, AttachmentUpdate, Error};
use crate::{
msg::HeadersData,
mwc_core::core::block::{BlockHeader, UntrustedBlockHeader},
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use mwc_core::ser::Reader;
use std::cmp::min;
use std::io::Read;
use std::mem;
use std::net::TcpStream;
use std::sync::Arc;
use std::time::{Duration, Instant};
use MsgHeaderWrapper::*;
use State::*;
/// Socket read timeout used while waiting for the fixed-size header of the next message.
const HEADER_IO_TIMEOUT: Duration = Duration::from_secs(2);
/// Socket read timeout used while reading anything else (message bodies, header lists, attachments).
pub const BODY_IO_TIMEOUT: Duration = Duration::from_secs(60);
/// Largest number of block headers handed back in a single `Message::Headers` batch.
const HEADER_BATCH_SIZE: usize = 32;
/// What the codec expects to read next from the stream.
enum State {
    /// Idle: the next bytes on the wire are the fixed-size header of a new message.
    None,
    /// A message header has been read; its body is handled next (decoded, streamed as
    /// block headers, or skipped when the type is unknown).
    Header(MsgHeaderWrapper),
    /// Incrementally decoding the block headers carried by a `Headers` message.
    BlockHeaders {
        /// Body bytes of the message not yet consumed.
        bytes_left: usize,
        /// Headers still to decode, per the count prefix of the message.
        items_left: usize,
        /// Headers decoded but not yet returned to the caller.
        headers: Vec<BlockHeader>,
    },
    /// Receiving a raw attachment: bytes still expected, the attachment metadata, and the
    /// instant the last progress log was written (initially when the attachment started).
    Attachment(usize, Arc<AttachmentMeta>, Instant),
}

impl State {
    /// Returns `true` when no message is currently being read.
    fn is_none(&self) -> bool {
        matches!(self, State::None)
    }
}
pub struct Codec {
pub version: ProtocolVersion,
stream: TcpStream,
buffer: BytesMut,
state: State,
bytes_read: usize,
}
impl Codec {
pub fn new(version: ProtocolVersion, stream: TcpStream) -> Self {
Self {
version,
stream,
buffer: BytesMut::with_capacity(8 * 1024),
state: None,
bytes_read: 0,
}
}
pub fn stream(self) -> TcpStream {
self.stream
}
pub fn expect_attachment(&mut self, meta: Arc<AttachmentMeta>) {
debug_assert!(self.state.is_none());
self.state = Attachment(meta.size, meta, Instant::now());
}
fn next_len(&self) -> usize {
match &self.state {
None => MsgHeader::LEN,
Header(Known(h)) if h.msg_type == Type::Headers => {
min(h.msg_len as usize, 2)
}
Header(Known(header)) => header.msg_len as usize,
Header(Unknown(len, _)) => *len as usize,
BlockHeaders { bytes_left, .. } => {
min(*bytes_left, header_size_bytes(63))
}
Attachment(left, _, _) => min(*left, 48_000),
}
}
fn set_stream_timeout(&self) -> Result<(), Error> {
let timeout = match &self.state {
None => HEADER_IO_TIMEOUT,
_ => BODY_IO_TIMEOUT,
};
self.stream.set_read_timeout(Some(timeout))?;
Ok(())
}
fn read_inner(&mut self) -> Result<Message, Error> {
self.bytes_read = 0;
loop {
let next_len = self.next_len();
let pre_len = self.buffer.len();
let to_read = next_len.saturating_sub(pre_len);
if to_read > 0 {
self.buffer.reserve(to_read);
for _ in 0..to_read {
self.buffer.put_u8(0);
}
self.set_stream_timeout()?;
if let Err(e) = self.stream.read_exact(&mut self.buffer[pre_len..]) {
self.buffer.truncate(pre_len);
return Err(e.into());
}
self.bytes_read += to_read;
}
match &mut self.state {
None => {
let mut raw = self.buffer.split_to(next_len).freeze();
let mut reader = BufReader::new(&mut raw, self.version);
let header = MsgHeaderWrapper::read(&mut reader)?;
self.state = Header(header);
}
Header(Known(header)) => {
let mut raw = self.buffer.split_to(next_len).freeze();
if header.msg_type == Type::Headers {
let mut reader = BufReader::new(&mut raw, self.version);
let items_left = reader.read_u16()? as usize;
self.state = BlockHeaders {
bytes_left: header.msg_len as usize - 2,
items_left,
headers: Vec::with_capacity(min(HEADER_BATCH_SIZE, items_left)),
};
} else {
let msg = decode_message(header, &mut raw, self.version);
self.state = None;
return msg;
}
}
Header(Unknown(_, msg_type)) => {
let msg_type = *msg_type;
self.buffer.advance(next_len);
self.state = None;
return Ok(Message::Unknown(msg_type));
}
BlockHeaders {
bytes_left,
items_left,
headers,
} => {
if *bytes_left == 0 {
self.state = None;
return Err(Error::BadMessage);
}
let mut reader = BufReader::new(&mut self.buffer, self.version);
let header: UntrustedBlockHeader = reader.body()?;
let bytes_read = reader.bytes_read() as usize;
headers.push(header.into());
*bytes_left = bytes_left.saturating_sub(bytes_read);
*items_left -= 1;
let remaining = *items_left as u64;
if headers.len() == HEADER_BATCH_SIZE || remaining == 0 {
let mut h = Vec::with_capacity(min(HEADER_BATCH_SIZE, *items_left));
mem::swap(headers, &mut h);
if remaining == 0 {
let bytes_left = *bytes_left;
self.state = None;
if bytes_left > 0 {
return Err(Error::BadMessage);
}
}
return Ok(Message::Headers(HeadersData {
headers: h,
remaining,
}));
}
}
Attachment(left, meta, now) => {
let raw = self.buffer.split_to(next_len).freeze();
*left -= next_len;
if now.elapsed().as_secs() > 10 {
*now = Instant::now();
debug!("attachment: {}/{}", meta.size - *left, meta.size);
}
let update = AttachmentUpdate {
read: next_len,
left: *left,
meta: Arc::clone(meta),
};
if *left == 0 {
self.state = None;
debug!("attachment: DONE");
}
return Ok(Message::Attachment(update, Some(raw)));
}
}
}
}
pub fn read(&mut self) -> (Result<Message, Error>, u64) {
let msg = self.read_inner();
(msg, self.bytes_read as u64)
}
}
fn decode_message(
header: &MsgHeader,
body: &mut Bytes,
version: ProtocolVersion,
) -> Result<Message, Error> {
let mut msg = BufReader::new(body, version);
let c = match header.msg_type {
Type::Ping => Message::Ping(msg.body()?),
Type::Pong => Message::Pong(msg.body()?),
Type::BanReason => Message::BanReason(msg.body()?),
Type::TransactionKernel => Message::TransactionKernel(msg.body()?),
Type::GetTransaction => Message::GetTransaction(msg.body()?),
Type::Transaction => Message::Transaction(msg.body()?),
Type::StemTransaction => Message::StemTransaction(msg.body()?),
Type::GetBlock => Message::GetBlock(msg.body()?),
Type::Block => Message::Block(msg.body()?),
Type::GetCompactBlock => Message::GetCompactBlock(msg.body()?),
Type::CompactBlock => Message::CompactBlock(msg.body()?),
Type::GetHeaders => Message::GetHeaders(msg.body()?),
Type::Header => Message::Header(msg.body()?),
Type::GetPeerAddrs => Message::GetPeerAddrs(msg.body()?),
Type::PeerAddrs => Message::PeerAddrs(msg.body()?),
Type::TxHashSetRequest => Message::TxHashSetRequest(msg.body()?),
Type::TxHashSetArchive => Message::TxHashSetArchive(msg.body()?),
Type::GetHeadersHashesSegment => Message::GetHeadersHashesSegment(msg.body()?),
Type::OutputHeadersHashesSegment => Message::OutputHeadersHashesSegment(msg.body()?),
Type::GetOutputBitmapSegment => Message::GetOutputBitmapSegment(msg.body()?),
Type::OutputBitmapSegment => Message::OutputBitmapSegment(msg.body()?),
Type::StartPibdSyncRequest => Message::StartPibdSyncRequest(msg.body()?),
Type::StartHeadersHashRequest => Message::StartHeadersHashRequest(msg.body()?),
Type::StartHeadersHashResponse => Message::StartHeadersHashResponse(msg.body()?),
Type::PibdSyncState => Message::PibdSyncState(msg.body()?),
Type::GetOutputSegment => Message::GetOutputSegment(msg.body()?),
Type::OutputSegment => Message::OutputSegment(msg.body()?),
Type::GetRangeProofSegment => Message::GetRangeProofSegment(msg.body()?),
Type::RangeProofSegment => Message::RangeProofSegment(msg.body()?),
Type::GetKernelSegment => Message::GetKernelSegment(msg.body()?),
Type::KernelSegment => Message::KernelSegment(msg.body()?),
Type::HasAnotherArchiveHeader => Message::HasAnotherArchiveHeader(msg.body()?),
Type::Error | Type::Hand | Type::Shake | Type::Headers => {
return Err(Error::UnexpectedMessage(format!(
"get message with type {:?} (code {})",
header.msg_type, header.msg_type as u32
)))
}
Type::TorAddress => Message::TorAddress(msg.body()?),
};
Ok(c)
}