use std::io::IoSliceMut;
use bytes::{Bytes, BytesMut};
use irontide_wire::Message;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::peer_shared::OutgoingMessage;
/// Capacity of the receive ring buffer (32 KiB).
const BUF_LEN: usize = 32 * 1024;
/// Fixed-capacity ring buffer of received bytes.
///
/// The readable region is `len` bytes starting at index `start`, wrapping past
/// the end of `buf` back to index 0.
pub(crate) struct ReadBuf {
/// Heap-allocated backing storage.
buf: Box<[u8; BUF_LEN]>,
/// Index of the first readable byte.
start: usize,
/// Number of readable (filled, not yet consumed) bytes.
len: usize,
}
impl ReadBuf {
    /// Creates an empty buffer backed by `BUF_LEN` zeroed bytes on the heap.
    pub(crate) fn new() -> Self {
        Self {
            buf: Box::new([0u8; BUF_LEN]),
            start: 0,
            len: 0,
        }
    }

    /// Number of readable (filled but not yet consumed) bytes.
    #[inline]
    pub(crate) fn len(&self) -> usize {
        self.len
    }

    /// `true` when no free space remains.
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn is_full(&self) -> bool {
        self.len == BUF_LEN
    }

    /// Returns `len` readable bytes beginning `offset` bytes past the read
    /// cursor, as two slices. The second slice is non-empty only when the
    /// range wraps past the end of the backing array.
    ///
    /// Requires `offset + len <= self.len()` (checked in debug builds).
    pub(crate) fn readable_slices_at(&self, offset: usize, len: usize) -> (&[u8], &[u8]) {
        debug_assert!(
            offset + len <= self.len,
            "readable_slices_at: offset={offset} len={len} buf_len={}",
            self.len,
        );
        if len == 0 {
            return (&[], &[]);
        }
        let start = (self.start + offset) % BUF_LEN;
        let end = start + len;
        if end <= BUF_LEN {
            (&self.buf[start..end], &[])
        } else {
            (&self.buf[start..], &self.buf[..end - BUF_LEN])
        }
    }

    /// Returns every readable byte as two slices (the second is non-empty only
    /// when the data wraps).
    #[allow(dead_code)]
    pub(crate) fn readable_slices(&self) -> (&[u8], &[u8]) {
        if self.len == 0 {
            return (&[], &[]);
        }
        let end = self.start + self.len;
        if end <= BUF_LEN {
            (&self.buf[self.start..end], &[])
        } else {
            (&self.buf[self.start..], &self.buf[..end - BUF_LEN])
        }
    }

    /// Returns the contiguous free region that begins at the write position
    /// (empty when full). Call `mark_filled` after writing into it.
    #[allow(dead_code)]
    pub(crate) fn unfilled_contiguous(&mut self) -> &mut [u8] {
        if self.len == BUF_LEN {
            return &mut [];
        }
        let write_pos = (self.start + self.len) % BUF_LEN;
        if write_pos >= self.start {
            &mut self.buf[write_pos..BUF_LEN]
        } else {
            &mut self.buf[write_pos..self.start]
        }
    }

    /// Returns all free space as two `IoSliceMut`s in fill order, for a
    /// vectored read. The first slice is empty only when the buffer is full.
    /// Call `mark_filled` with the number of bytes written.
    pub(crate) fn unfilled_ioslices(&mut self) -> [IoSliceMut<'_>; 2] {
        let available = BUF_LEN - self.len;
        if available == 0 {
            return [IoSliceMut::new(&mut []), IoSliceMut::new(&mut [])];
        }
        let write_pos = (self.start + self.len) % BUF_LEN;
        if write_pos >= self.start {
            // Free space is [write_pos, BUF_LEN) followed by [0, start).
            let (left, right) = self.buf.split_at_mut(write_pos);
            [
                IoSliceMut::new(right),
                IoSliceMut::new(&mut left[..self.start]),
            ]
        } else {
            // Data wraps, so the only free space is the gap before `start`.
            [
                IoSliceMut::new(&mut self.buf[write_pos..self.start]),
                IoSliceMut::new(&mut []),
            ]
        }
    }

    /// Marks `n` bytes after the readable region as filled.
    pub(crate) fn mark_filled(&mut self, n: usize) {
        debug_assert!(
            self.len + n <= BUF_LEN,
            "mark_filled overflow: len={} n={n} cap={BUF_LEN}",
            self.len,
        );
        self.len += n;
    }

    /// Discards the first `n` readable bytes.
    ///
    /// When the buffer becomes empty the cursor is reset to 0, so the next
    /// fill gets the whole array as one contiguous region.
    pub(crate) fn consume(&mut self, n: usize) {
        debug_assert!(n <= self.len, "consume underflow: n={n} len={}", self.len,);
        self.start = (self.start + n) % BUF_LEN;
        self.len -= n;
        if self.len == 0 {
            self.start = 0;
        }
    }

    /// Copies the first `dst.len()` readable bytes into `dst` and consumes them.
    pub(crate) fn consume_into(&mut self, dst: &mut [u8]) {
        let n = dst.len();
        debug_assert!(
            n <= self.len,
            "consume_into underflow: n={n} len={}",
            self.len,
        );
        let end = self.start + n;
        if end <= BUF_LEN {
            dst.copy_from_slice(&self.buf[self.start..end]);
        } else {
            let first = BUF_LEN - self.start;
            dst[..first].copy_from_slice(&self.buf[self.start..]);
            dst[first..].copy_from_slice(&self.buf[..n - first]);
        }
        self.consume(n);
    }

    /// Removes the first `n` readable bytes and returns them as owned `Bytes`.
    #[allow(dead_code)]
    pub(crate) fn consume_as_bytes(&mut self, n: usize) -> Bytes {
        // Copy straight out of the (possibly wrapped) readable region. The
        // previous `set_len` over uninitialised capacity violated
        // `Vec::set_len`'s requirement that the new elements be initialised.
        let (head, tail) = self.readable_slices_at(0, n);
        let mut vec = Vec::with_capacity(n);
        vec.extend_from_slice(head);
        vec.extend_from_slice(tail);
        self.consume(n);
        Bytes::from(vec)
    }

    /// Reads the byte `offset` positions past the read cursor without consuming it.
    #[inline]
    pub(crate) fn peek_byte_at(&self, offset: usize) -> u8 {
        debug_assert!(
            offset < self.len,
            "peek_byte_at: offset={offset} len={}",
            self.len
        );
        self.buf[(self.start + offset) % BUF_LEN]
    }

    /// Reads a big-endian `u32` at `offset` past the read cursor, handling wrap.
    #[inline]
    pub(crate) fn peek_u32_be_at(&self, offset: usize) -> u32 {
        debug_assert!(
            offset + 4 <= self.len,
            "peek_u32_be_at: offset={offset} len={}",
            self.len
        );
        let s = self.start + offset;
        u32::from_be_bytes([
            self.buf[s % BUF_LEN],
            self.buf[(s + 1) % BUF_LEN],
            self.buf[(s + 2) % BUF_LEN],
            self.buf[(s + 3) % BUF_LEN],
        ])
    }

    /// Reads a big-endian `u16` at `offset` past the read cursor, handling wrap.
    #[inline]
    pub(crate) fn peek_u16_be_at(&self, offset: usize) -> u16 {
        debug_assert!(
            offset + 2 <= self.len,
            "peek_u16_be_at: offset={offset} len={}",
            self.len
        );
        let s = self.start + offset;
        u16::from_be_bytes([self.buf[s % BUF_LEN], self.buf[(s + 1) % BUF_LEN]])
    }

    /// Reads the big-endian `u32` at the read cursor (the frame length prefix).
    pub(crate) fn peek_u32_be(&self) -> u32 {
        debug_assert!(
            self.len >= 4,
            "peek_u32_be: need 4 bytes, have {}",
            self.len,
        );
        let b0 = self.buf[(self.start) % BUF_LEN];
        let b1 = self.buf[(self.start + 1) % BUF_LEN];
        let b2 = self.buf[(self.start + 2) % BUF_LEN];
        let b3 = self.buf[(self.start + 3) % BUF_LEN];
        u32::from_be_bytes([b0, b1, b2, b3])
    }

    /// Returns a cursor over all readable bytes.
    #[allow(dead_code)]
    pub(crate) fn as_double_buf(&self) -> DoubleBufHelper<'_> {
        let (a, b) = self.readable_slices();
        DoubleBufHelper::new(a, b)
    }
}
/// Read cursor over a byte sequence split across two slices, where `buf_1`
/// logically follows `buf_0` (e.g. the two halves returned by
/// `ReadBuf::readable_slices`).
#[allow(dead_code)] pub(crate) struct DoubleBufHelper<'a> {
/// First half of the sequence.
buf_0: &'a [u8],
/// Second half; read once `buf_0` is exhausted.
buf_1: &'a [u8],
/// Bytes consumed so far, counted from the start of `buf_0`.
pos: usize,
}
#[allow(dead_code)]
impl<'a> DoubleBufHelper<'a> {
    /// Creates a cursor positioned at the first byte of `buf_0`.
    pub(crate) fn new(buf_0: &'a [u8], buf_1: &'a [u8]) -> Self {
        Self {
            buf_0,
            buf_1,
            pos: 0,
        }
    }

    /// Number of bytes not yet consumed across both halves.
    #[inline]
    pub(crate) fn remaining(&self) -> usize {
        self.buf_0.len() + self.buf_1.len() - self.pos
    }

    /// Byte at logical position `pos`, where `buf_1` follows `buf_0`.
    #[inline]
    fn byte_at(&self, pos: usize) -> u8 {
        if pos < self.buf_0.len() {
            self.buf_0[pos]
        } else {
            self.buf_1[pos - self.buf_0.len()]
        }
    }

    /// Consumes the next `N` bytes, which may straddle the two halves.
    ///
    /// # Panics
    /// Panics (index out of bounds) if fewer than `N` bytes remain.
    pub(crate) fn consume<const N: usize>(&mut self) -> [u8; N] {
        let mut out = [0u8; N];
        for (i, slot) in out.iter_mut().enumerate() {
            *slot = self.byte_at(self.pos + i);
        }
        self.pos += N;
        out
    }

    /// Consumes a big-endian `u32`.
    #[inline]
    pub(crate) fn read_u32_be(&mut self) -> u32 {
        u32::from_be_bytes(self.consume::<4>())
    }

    /// Consumes a single byte.
    #[inline]
    pub(crate) fn read_u8(&mut self) -> u8 {
        self.consume::<1>()[0]
    }

    /// Returns up to `limit` unconsumed bytes as two `IoSlice`s, without
    /// advancing the cursor.
    pub(crate) fn as_ioslices(&self, limit: usize) -> [std::io::IoSlice<'_>; 2] {
        // `pos` counts from the start of `buf_0` and carries over into `buf_1`
        // once `buf_0` is exhausted. The unconsumed part of `buf_1` therefore
        // starts at `pos - buf_0.len()`, not at 0 (the previous code always
        // started at 0 and re-exposed consumed bytes).
        let rest_0 = &self.buf_0[self.pos.min(self.buf_0.len())..];
        let rest_1 = self
            .buf_1
            .get(self.pos.saturating_sub(self.buf_0.len())..)
            .unwrap_or(&[]);
        let first_len = rest_0.len().min(limit);
        let second_len = rest_1.len().min(limit - first_len);
        [
            std::io::IoSlice::new(&rest_0[..first_len]),
            std::io::IoSlice::new(&rest_1[..second_len]),
        ]
    }

    /// Consumes the next `n` bytes (possibly straddling both halves) into an
    /// owned `Bytes`.
    ///
    /// # Panics
    /// Panics (slice out of bounds) if fewer than `n` bytes remain.
    pub(crate) fn consume_variable(&mut self, n: usize) -> Bytes {
        let mut vec = vec![0u8; n];
        let mut written = 0usize;
        if self.pos < self.buf_0.len() {
            let avail = self.buf_0.len() - self.pos;
            let take = avail.min(n);
            vec[..take].copy_from_slice(&self.buf_0[self.pos..self.pos + take]);
            written = take;
        }
        if written < n {
            // Offset into `buf_1`: 0 unless the cursor is already past `buf_0`.
            let buf1_start = self.pos.saturating_sub(self.buf_0.len());
            let remaining = n - written;
            vec[written..].copy_from_slice(&self.buf_1[buf1_start..buf1_start + remaining]);
        }
        self.pos += n;
        Bytes::from(vec)
    }
}
/// Framed reader for length-prefixed peer-wire messages, buffering input in a
/// `ReadBuf` ring.
pub(crate) struct PeerReader<R> {
/// Underlying byte stream.
reader: R,
/// Received bytes not yet consumed.
buf: ReadBuf,
/// Largest accepted value of a frame's 4-byte length prefix.
max_message_size: usize,
/// Scratch space: holds frames too large for `buf`, copies of bitfield /
/// extended payloads that wrap around the ring boundary, and payloads
/// handed to the generic `Message::from_payload` fallback.
oversized_buf: Vec<u8>,
}
impl<R: crate::vectored_io::AsyncReadVectored> PeerReader<R> {
pub(crate) fn new(reader: R, max_message_size: usize) -> Self {
Self {
reader,
buf: ReadBuf::new(),
max_message_size,
oversized_buf: Vec::new(),
}
}
async fn fill(&mut self) -> std::io::Result<usize> {
let mut iov = self.buf.unfilled_ioslices();
if iov[0].is_empty() {
return Ok(0);
}
let n = crate::vectored_io::read_vectored(&mut self.reader, &mut iov).await?;
self.buf.mark_filled(n);
Ok(n)
}
pub(crate) async fn fill_message(&mut self) -> Result<FillStatus, irontide_wire::Error> {
loop {
if self.buf.len() >= 4 {
let length = self.buf.peek_u32_be() as usize;
if length > self.max_message_size {
return Err(irontide_wire::Error::MessageTooLarge {
size: length,
max: self.max_message_size,
});
}
if length == 0 {
return Ok(FillStatus::Ready);
}
if 4 + length > BUF_LEN {
return Ok(FillStatus::Oversized(length));
}
let total = 4 + length;
if self.buf.len() >= total {
return Ok(FillStatus::Ready);
}
}
let n = self.fill().await?;
if n == 0 {
if self.buf.len() == 0 {
return Ok(FillStatus::Eof);
}
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"unexpected EOF with partial message buffered",
)
.into());
}
}
}
pub(crate) fn try_decode(&mut self) -> Result<(Message<&[u8]>, usize), irontide_wire::Error> {
let length = self.buf.peek_u32_be() as usize;
let total = 4 + length;
if length == 0 {
return Ok((Message::KeepAlive, 4));
}
let id = self.buf.peek_byte_at(4);
decode_from_ring_borrowed(&self.buf, &mut self.oversized_buf, id, length, total)
}
#[inline]
pub(crate) fn advance(&mut self, n: usize) {
self.buf.consume(n);
}
#[allow(clippy::uninit_vec)]
pub(crate) async fn decode_oversized_into(
&mut self,
length: usize,
) -> Result<Message, irontide_wire::Error> {
self.buf.consume(4);
self.oversized_buf.clear();
self.oversized_buf.reserve(length);
unsafe { self.oversized_buf.set_len(length) };
let mut filled = 0;
let buffered = self.buf.len().min(length);
if buffered > 0 {
self.buf.consume_into(&mut self.oversized_buf[..buffered]);
filled = buffered;
}
if filled < length {
self.reader
.read_exact(&mut self.oversized_buf[filled..])
.await?;
}
Message::from_payload(Bytes::from(self.oversized_buf.split_off(0)))
}
#[allow(dead_code)] pub(crate) async fn next_message(&mut self) -> Result<Option<Message>, irontide_wire::Error> {
match self.fill_message().await? {
FillStatus::Eof => Ok(None),
FillStatus::Oversized(length) => self.decode_oversized_into(length).await.map(Some),
FillStatus::Ready => {
let (msg, consumed) = self.try_decode()?;
let owned = msg.to_owned_bytes();
self.advance(consumed);
Ok(Some(owned))
}
}
}
#[allow(dead_code)] async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
let mut offset = 0;
let from_ring = self.buf.len().min(buf.len());
if from_ring > 0 {
self.buf.consume_into(&mut buf[..from_ring]);
offset = from_ring;
}
if offset < buf.len() {
self.reader.read_exact(&mut buf[offset..]).await?;
}
Ok(())
}
}
/// Outcome of `PeerReader::fill_message`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FillStatus {
/// A complete frame (or keep-alive) is buffered; call `try_decode`.
Ready,
/// The stream ended with no partial frame buffered.
Eof,
/// The next frame, whose payload length is carried here, does not fit in
/// the ring buffer; call `decode_oversized_into`.
Oversized(usize),
}
/// Decodes the complete frame at the front of `buf` into a `Message`. Where
/// possible, the message borrows its payload from the ring buffer.
///
/// * `id` — the message id byte (ring offset 4).
/// * `length` — payload length from the 4-byte prefix, id byte included.
/// * `total` — bytes the frame occupies (`4 + length`). It is returned
///   unchanged as the amount the caller should consume.
///
/// Piece data may come back as two slices (`data_0`, `data_1`) when it wraps
/// around the ring boundary. Bitfield and extended payloads must be a single
/// slice, so when they wrap they are copied into `oversized_buf`. Ids without
/// a dedicated arm are copied into `oversized_buf` and decoded by
/// `Message::from_payload`.
///
/// The caller must ensure the whole frame is buffered, e.g. after
/// `FillStatus::Ready`.
///
/// # Errors
/// `MessageTooShort` when a fixed-layout message is shorter than its layout,
/// or any error from `Message::from_payload`.
fn decode_from_ring_borrowed<'a>(
    buf: &'a ReadBuf,
    oversized_buf: &'a mut Vec<u8>,
    id: u8,
    length: usize,
    total: usize,
) -> Result<(Message<&'a [u8]>, usize), irontide_wire::Error> {
    // Ring offsets: 0..4 length prefix, 4 id byte, 5.. message-specific fields.
    match id {
        MSG_CHOKE => Ok((Message::Choke, total)),
        MSG_UNCHOKE => Ok((Message::Unchoke, total)),
        MSG_INTERESTED => Ok((Message::Interested, total)),
        MSG_NOT_INTERESTED => Ok((Message::NotInterested, total)),
        MSG_HAVE => {
            ensure_msg_len(length, 5)?;
            let index = buf.peek_u32_be_at(5);
            Ok((Message::Have { index }, total))
        }
        MSG_REQUEST => {
            ensure_msg_len(length, 13)?;
            let index = buf.peek_u32_be_at(5);
            let begin = buf.peek_u32_be_at(9);
            let len = buf.peek_u32_be_at(13);
            Ok((
                Message::Request {
                    index,
                    begin,
                    length: len,
                },
                total,
            ))
        }
        MSG_CANCEL => {
            ensure_msg_len(length, 13)?;
            let index = buf.peek_u32_be_at(5);
            let begin = buf.peek_u32_be_at(9);
            let len = buf.peek_u32_be_at(13);
            Ok((
                Message::Cancel {
                    index,
                    begin,
                    length: len,
                },
                total,
            ))
        }
        MSG_PORT => {
            ensure_msg_len(length, 3)?;
            let port = buf.peek_u16_be_at(5);
            Ok((Message::Port(port), total))
        }
        MSG_SUGGEST_PIECE => {
            ensure_msg_len(length, 5)?;
            let index = buf.peek_u32_be_at(5);
            Ok((Message::SuggestPiece(index), total))
        }
        MSG_HAVE_ALL => Ok((Message::HaveAll, total)),
        MSG_HAVE_NONE => Ok((Message::HaveNone, total)),
        MSG_REJECT_REQUEST => {
            ensure_msg_len(length, 13)?;
            let index = buf.peek_u32_be_at(5);
            let begin = buf.peek_u32_be_at(9);
            let len = buf.peek_u32_be_at(13);
            Ok((
                Message::RejectRequest {
                    index,
                    begin,
                    length: len,
                },
                total,
            ))
        }
        MSG_ALLOWED_FAST => {
            ensure_msg_len(length, 5)?;
            let index = buf.peek_u32_be_at(5);
            Ok((Message::AllowedFast(index), total))
        }
        MSG_PIECE => {
            ensure_msg_len(length, 9)?;
            let index = buf.peek_u32_be_at(5);
            let begin = buf.peek_u32_be_at(9);
            let data_len = length - 9;
            // Zero-copy even when wrapped: Piece carries two data slices.
            let (data_0, data_1) = buf.readable_slices_at(13, data_len);
            Ok((
                Message::Piece {
                    index,
                    begin,
                    data_0,
                    data_1,
                },
                total,
            ))
        }
        MSG_BITFIELD => {
            // `length >= 1` here: zero-length frames are keep-alives and never reach this point.
            let data_len = length - 1;
            let (s0, s1) = buf.readable_slices_at(5, data_len);
            if s1.is_empty() {
                Ok((Message::Bitfield(s0), total))
            } else {
                oversized_buf.clear();
                oversized_buf.reserve(data_len);
                oversized_buf.extend_from_slice(s0);
                oversized_buf.extend_from_slice(s1);
                Ok((Message::Bitfield(oversized_buf.as_slice()), total))
            }
        }
        MSG_EXTENDED => {
            ensure_msg_len(length, 2)?;
            let ext_id = buf.peek_byte_at(5);
            let payload_len = length - 2;
            let (s0, s1) = buf.readable_slices_at(6, payload_len);
            if s1.is_empty() {
                Ok((
                    Message::Extended {
                        ext_id,
                        payload: s0,
                    },
                    total,
                ))
            } else {
                oversized_buf.clear();
                oversized_buf.reserve(payload_len);
                oversized_buf.extend_from_slice(s0);
                oversized_buf.extend_from_slice(s1);
                Ok((
                    Message::Extended {
                        ext_id,
                        payload: oversized_buf.as_slice(),
                    },
                    total,
                ))
            }
        }
        _ => {
            // Generic path: copy the payload (id byte onward) out of the ring
            // and let the owned decoder handle it.
            oversized_buf.clear();
            oversized_buf.reserve(length);
            let (s0, s1) = buf.readable_slices_at(4, length);
            oversized_buf.extend_from_slice(s0);
            oversized_buf.extend_from_slice(s1);
            let owned = Message::from_payload(Bytes::copy_from_slice(oversized_buf.as_slice()))?;
            // From here on the scratch buffer is only read, so reborrow it
            // immutably for `'a` and hand out sub-slices of it.
            let payload: &'a [u8] = oversized_buf.as_slice();
            let borrowed = match owned {
                Message::KeepAlive => Message::KeepAlive,
                Message::Choke => Message::Choke,
                Message::Unchoke => Message::Unchoke,
                Message::Interested => Message::Interested,
                Message::NotInterested => Message::NotInterested,
                Message::Have { index } => Message::Have { index },
                Message::Request {
                    index,
                    begin,
                    length: len,
                } => Message::Request {
                    index,
                    begin,
                    length: len,
                },
                Message::Cancel {
                    index,
                    begin,
                    length: len,
                } => Message::Cancel {
                    index,
                    begin,
                    length: len,
                },
                Message::Port(port) => Message::Port(port),
                Message::SuggestPiece(index) => Message::SuggestPiece(index),
                Message::HaveAll => Message::HaveAll,
                Message::HaveNone => Message::HaveNone,
                Message::RejectRequest {
                    index,
                    begin,
                    length: len,
                } => Message::RejectRequest {
                    index,
                    begin,
                    length: len,
                },
                Message::AllowedFast(index) => Message::AllowedFast(index),
                Message::HashRequest {
                    pieces_root,
                    base,
                    index,
                    count,
                    proof_layers,
                } => Message::HashRequest {
                    pieces_root,
                    base,
                    index,
                    count,
                    proof_layers,
                },
                Message::HashReject {
                    pieces_root,
                    base,
                    index,
                    count,
                    proof_layers,
                } => Message::HashReject {
                    pieces_root,
                    base,
                    index,
                    count,
                    proof_layers,
                },
                Message::Hashes {
                    pieces_root,
                    base,
                    index,
                    count,
                    proof_layers,
                    hashes,
                } => Message::Hashes {
                    pieces_root,
                    base,
                    index,
                    count,
                    proof_layers,
                    hashes,
                },
                // These ids have dedicated arms above, so this is defensive.
                // Variable-length data is always the tail of the payload, so
                // its borrowed form is the suffix of matching length. The
                // previous code returned `Bitfield` for all three and included
                // the id byte.
                Message::Bitfield(data) => Message::Bitfield(payload_suffix(payload, data.len())),
                Message::Piece {
                    index,
                    begin,
                    data_0,
                    data_1,
                } => {
                    let data = payload_suffix(payload, data_0.len() + data_1.len());
                    let (d0, d1) = data.split_at(data_0.len().min(data.len()));
                    Message::Piece {
                        index,
                        begin,
                        data_0: d0,
                        data_1: d1,
                    }
                }
                Message::Extended {
                    ext_id,
                    payload: ext_payload,
                } => Message::Extended {
                    ext_id,
                    payload: payload_suffix(payload, ext_payload.len()),
                },
            };
            Ok((borrowed, total))
        }
    }
}

/// Returns the last `n` bytes of `payload`, or all of it if `n` exceeds its length.
fn payload_suffix(payload: &[u8], n: usize) -> &[u8] {
    &payload[payload.len().saturating_sub(n)..]
}
// Message id bytes: the first payload byte after the 4-byte length prefix.
// Ids 0-8 are the core protocol messages (BEP 3); 9 is the DHT port message (BEP 5).
const MSG_CHOKE: u8 = 0;
const MSG_UNCHOKE: u8 = 1;
const MSG_INTERESTED: u8 = 2;
const MSG_NOT_INTERESTED: u8 = 3;
const MSG_HAVE: u8 = 4;
const MSG_BITFIELD: u8 = 5;
const MSG_REQUEST: u8 = 6;
const MSG_PIECE: u8 = 7;
const MSG_CANCEL: u8 = 8;
const MSG_PORT: u8 = 9;
// Fast Extension message ids (BEP 6).
const MSG_SUGGEST_PIECE: u8 = 0x0D;
const MSG_HAVE_ALL: u8 = 0x0E;
const MSG_HAVE_NONE: u8 = 0x0F;
const MSG_REJECT_REQUEST: u8 = 0x10;
const MSG_ALLOWED_FAST: u8 = 0x11;
// Extension protocol message id (BEP 10); the next byte is the extended message id.
const MSG_EXTENDED: u8 = 20;
/// Checks that a payload of `length` bytes is at least `expected` bytes long.
///
/// # Errors
/// Returns `MessageTooShort { expected, got: length }` when it is shorter.
#[inline]
fn ensure_msg_len(length: usize, expected: usize) -> Result<(), irontide_wire::Error> {
    if length >= expected {
        return Ok(());
    }
    Err(irontide_wire::Error::MessageTooShort {
        expected,
        got: length,
    })
}
/// Size of the reusable encode buffer: 4-byte length prefix + 1-byte id +
/// 8 bytes (index, begin) + a 16 KiB piece block. Larger messages fall back
/// to a one-off allocation in `send`.
const MAX_MSG_LEN: usize = 4 + 1 + 8 + 16_384;
/// Writes encoded peer-wire messages to `W`, flushing after each one.
pub(crate) struct PeerWriter<W> {
/// Underlying byte sink.
writer: W,
/// Reusable encode buffer for messages up to `MAX_MSG_LEN` bytes on the wire.
buf: Box<[u8; MAX_MSG_LEN]>,
}
impl<W: AsyncWrite + Unpin> PeerWriter<W> {
    /// Wraps `writer`. The encode buffer is allocated once here and reused by
    /// every `send`.
    pub(crate) fn new(writer: W) -> Self {
        let buf = Box::new([0u8; MAX_MSG_LEN]);
        Self { writer, buf }
    }

    /// Encodes `msg`, writes all of it, then flushes the writer.
    ///
    /// Messages up to `MAX_MSG_LEN` bytes on the wire are encoded into the
    /// reusable buffer. Anything larger goes through a temporary `BytesMut`.
    ///
    /// # Errors
    /// Any I/O error from writing or flushing.
    pub(crate) async fn send(&mut self, msg: &Message) -> std::io::Result<()> {
        let wire_len = msg.wire_len();
        if wire_len > MAX_MSG_LEN {
            let mut spill = BytesMut::with_capacity(wire_len);
            msg.encode_into(&mut spill);
            self.writer.write_all(&spill).await?;
        } else {
            let encoded = msg.encode_to_slice(&mut *self.buf);
            self.writer.write_all(&self.buf[..encoded]).await?;
        }
        self.writer.flush().await
    }

    /// Converts an `OutgoingMessage` to its wire `Message` and sends it with `send`.
    pub(crate) async fn send_outgoing(&mut self, msg: &OutgoingMessage) -> std::io::Result<()> {
        let wire: Message = match msg {
            // Already a wire message: forward it without conversion.
            OutgoingMessage::Wire(wire_msg) => return self.send(wire_msg).await,
            OutgoingMessage::Request {
                index,
                begin,
                length,
            } => Message::Request {
                index: *index,
                begin: *begin,
                length: *length,
            },
            OutgoingMessage::Have { index } => Message::Have { index: *index },
            OutgoingMessage::Cancel {
                index,
                begin,
                length,
            } => Message::Cancel {
                index: *index,
                begin: *begin,
                length: *length,
            },
            OutgoingMessage::Keepalive => Message::KeepAlive,
            OutgoingMessage::Interested => Message::Interested,
            OutgoingMessage::NotInterested => Message::NotInterested,
            OutgoingMessage::Unchoke => Message::Unchoke,
            OutgoingMessage::Choke => Message::Choke,
        };
        self.send(&wire).await
    }

    /// Flushes the underlying writer.
    #[allow(dead_code)]
    pub(crate) async fn flush(&mut self) -> std::io::Result<()> {
        self.writer.flush().await
    }
}
// Unit tests for ReadBuf (ring buffer), DoubleBufHelper (two-slice cursor),
// PeerReader (framing/decoding) and PeerWriter (encoding/writing).
#[cfg(test)]
mod tests {
use super::*;
use crate::vectored_io::VectoredCompat;
use tokio::io::duplex;
// A fresh ReadBuf has no readable bytes.
#[test]
fn readbuf_empty() {
let rb = ReadBuf::new();
assert_eq!(rb.len(), 0);
let (a, b) = rb.readable_slices();
assert!(a.is_empty());
assert!(b.is_empty());
}
// Bytes written via unfilled_contiguous + mark_filled become readable in order.
#[test]
fn readbuf_readable_after_fill() {
let mut rb = ReadBuf::new();
let slice = rb.unfilled_contiguous();
slice[0] = 0xDE;
slice[1] = 0xAD;
slice[2] = 0xBE;
slice[3] = 0xEF;
rb.mark_filled(4);
assert_eq!(rb.len(), 4);
let (a, b) = rb.readable_slices();
assert_eq!(a, &[0xDE, 0xAD, 0xBE, 0xEF]);
assert!(b.is_empty());
}
// Fill completely, consume all but 4 bytes, then write 8 more: the new bytes
// land at the front of the array, so the readable data splits 4 + 8.
#[test]
fn readbuf_advance_wraps() {
let mut rb = ReadBuf::new();
let n = BUF_LEN;
let slice = rb.unfilled_contiguous();
assert_eq!(slice.len(), n);
for (i, byte) in slice.iter_mut().enumerate() {
*byte = (i & 0xFF) as u8;
}
rb.mark_filled(n);
assert_eq!(rb.len(), n);
rb.consume(n - 4);
assert_eq!(rb.len(), 4);
assert_eq!(rb.start, n - 4);
let slice = rb.unfilled_contiguous();
assert!(slice.len() >= 8);
for (i, byte) in slice[..8].iter_mut().enumerate() {
*byte = (0xA0 + i) as u8;
}
rb.mark_filled(8);
assert_eq!(rb.len(), 12);
let (a, b) = rb.readable_slices();
assert_eq!(a.len(), 4);
assert_eq!(b.len(), 8);
for (i, byte) in b.iter().enumerate() {
assert_eq!(*byte, (0xA0 + i) as u8);
}
}
// With data at the front of the array, all free space is one contiguous slice.
#[test]
fn readbuf_unfilled_ioslices_contiguous() {
let mut rb = ReadBuf::new();
let slice = rb.unfilled_contiguous();
for byte in slice[..100].iter_mut() {
*byte = 0xFF;
}
rb.mark_filled(100);
let [a, b] = rb.unfilled_ioslices();
assert_eq!(a.len(), BUF_LEN - 100);
assert!(b.is_empty());
}
// Data in the middle of the array: free space splits into the tail and the
// head. NOTE(review): the setup on `rb` below is never asserted on; only
// `rb2` (start = 5, len = 3) is checked.
#[test]
fn readbuf_unfilled_ioslices_wrapped() {
let mut rb = ReadBuf::new();
let slice = rb.unfilled_contiguous();
for byte in slice.iter_mut() {
*byte = 0x00;
}
rb.mark_filled(BUF_LEN);
rb.consume(BUF_LEN - 10);
rb.consume(6);
let mut rb2 = ReadBuf::new();
rb2.start = 5;
rb2.len = 3;
let [a, b] = rb2.unfilled_ioslices();
assert_eq!(a.len(), BUF_LEN - 8);
assert_eq!(b.len(), 5);
assert_eq!(a.len() + b.len(), BUF_LEN - 3);
}
// is_full flips only when the last byte is filled; no writable space remains after.
#[test]
fn readbuf_full_buffer() {
let mut rb = ReadBuf::new();
let slice = rb.unfilled_contiguous();
assert_eq!(slice.len(), BUF_LEN);
for byte in slice[..BUF_LEN - 1].iter_mut() {
*byte = 0xAA;
}
rb.mark_filled(BUF_LEN - 1);
let uf = rb.unfilled_contiguous();
assert_eq!(uf.len(), 1);
assert!(!rb.is_full());
rb.mark_filled(1);
assert!(rb.is_full());
assert_eq!(rb.unfilled_contiguous().len(), 0);
}
// DoubleBufHelper: fixed-size consume within the first slice.
#[test]
fn helper_consume_contiguous() {
let data = [1u8, 2, 3, 4, 5];
let mut h = DoubleBufHelper::new(&data, &[]);
let got = h.consume::<4>();
assert_eq!(got, [1, 2, 3, 4]);
assert_eq!(h.remaining(), 1);
}
// DoubleBufHelper: fixed-size consume straddling both slices.
#[test]
fn helper_consume_across_boundary() {
let a = [1u8, 2];
let b = [3u8, 4, 5];
let mut h = DoubleBufHelper::new(&a, &b);
let got = h.consume::<4>();
assert_eq!(got, [1, 2, 3, 4]);
assert_eq!(h.remaining(), 1);
}
// Big-endian u32 split 2 + 2 across the slices.
#[test]
fn helper_read_u32_be_split() {
let a = [0x01u8, 0x02];
let b = [0x03u8, 0x04];
let mut h = DoubleBufHelper::new(&a, &b);
let val = h.read_u32_be();
assert_eq!(val, 0x0102_0304);
}
// Variable-length consume starting mid-first-slice and ending in the second.
#[test]
fn helper_consume_variable_wrapped() {
let a = [10u8, 20, 30];
let b = [40u8, 50, 60, 70];
let mut h = DoubleBufHelper::new(&a, &b);
let _ = h.read_u8();
let bytes = h.consume_variable(5);
assert_eq!(&bytes[..], &[20, 30, 40, 50, 60]);
assert_eq!(h.remaining(), 1);
}
// remaining() tracks both fixed-size and variable-length consumption.
#[test]
fn helper_remaining() {
let a = [0u8; 10];
let b = [0u8; 5];
let mut h = DoubleBufHelper::new(&a, &b);
assert_eq!(h.remaining(), 15);
let _ = h.consume::<4>();
assert_eq!(h.remaining(), 11);
let _ = h.consume_variable(3);
assert_eq!(h.remaining(), 8);
}
// Encodes `msg` (length prefix included) with the wire crate's encoder.
fn encode_message(msg: &Message) -> Vec<u8> {
let mut buf = BytesMut::new();
msg.encode_into(&mut buf);
buf.to_vec()
}
// End-to-end: one message over a duplex stream, followed by a clean EOF.
#[tokio::test]
async fn reader_decode_contiguous_message() {
let msg = Message::Have { index: 42 };
let wire = encode_message(&msg);
let (client, mut server) = duplex(64 * 1024);
tokio::spawn(async move {
server.write_all(&wire).await.unwrap();
server.shutdown().await.unwrap();
});
let mut reader = PeerReader::new(VectoredCompat(client), 1 << 20);
let decoded = reader
.next_message()
.await
.expect("decode error")
.expect("expected message, got None");
assert_eq!(decoded, msg);
let eof = reader.next_message().await.expect("decode error");
assert!(eof.is_none());
}
// Frame preloaded so it wraps around the end of the ring array.
#[tokio::test]
async fn reader_decode_wrapped_message() {
let msg = Message::Have { index: 99 };
let wire = encode_message(&msg);
let (client, _server) = duplex(1024);
let mut reader = PeerReader::new(VectoredCompat(client), 1 << 20);
let start = BUF_LEN - 5; reader.buf.start = start;
reader.buf.len = wire.len();
for (i, &byte) in wire.iter().enumerate() {
reader.buf.buf[(start + i) % BUF_LEN] = byte;
}
let decoded = reader
.next_message()
.await
.expect("decode error")
.expect("expected message");
assert_eq!(decoded, msg);
}
// The 4-byte length prefix itself straddles the wrap point.
#[tokio::test]
async fn reader_decode_wrapped_length_prefix() {
let msg = Message::Interested; let wire = encode_message(&msg);
let (client, _server) = duplex(1024);
let mut reader = PeerReader::new(VectoredCompat(client), 1 << 20);
let start = BUF_LEN - 2;
reader.buf.start = start;
reader.buf.len = wire.len();
for (i, &byte) in wire.iter().enumerate() {
reader.buf.buf[(start + i) % BUF_LEN] = byte;
}
let decoded = reader
.next_message()
.await
.expect("decode error")
.expect("expected message");
assert_eq!(decoded, msg);
}
// A zero length prefix decodes as KeepAlive.
#[tokio::test]
async fn reader_decode_keepalive() {
let wire = [0u8; 4];
let (client, mut server) = duplex(1024);
tokio::spawn(async move {
server.write_all(&wire).await.unwrap();
server.shutdown().await.unwrap();
});
let mut reader = PeerReader::new(VectoredCompat(client), 1 << 20);
let decoded = reader
.next_message()
.await
.expect("decode error")
.expect("expected KeepAlive");
assert_eq!(decoded, Message::KeepAlive);
let eof = reader.next_message().await.expect("decode error");
assert!(eof.is_none());
}
// A 16 KiB piece block survives the round trip byte-for-byte.
#[tokio::test]
async fn reader_decode_piece_data_integrity() {
let mut piece_data = vec![0u8; 16_384];
for (i, byte) in piece_data.iter_mut().enumerate() {
*byte = (i % 251) as u8; }
let msg = Message::Piece {
index: 7,
begin: 0,
data_0: Bytes::from(piece_data.clone()),
data_1: Bytes::new(),
};
let wire = encode_message(&msg);
let (client, mut server) = duplex(256 * 1024);
tokio::spawn(async move {
server.write_all(&wire).await.unwrap();
server.shutdown().await.unwrap();
});
let mut reader = PeerReader::new(VectoredCompat(client), 1 << 20);
let decoded = reader
.next_message()
.await
.expect("decode error")
.expect("expected Piece");
if let Message::Piece {
index,
begin,
data_0,
data_1,
} = decoded
{
assert_eq!(index, 7);
assert_eq!(begin, 0);
assert_eq!(data_0.len() + data_1.len(), 16_384);
assert_eq!(&data_0[..], &piece_data[..]);
} else {
panic!("expected Piece message, got {decoded:?}");
}
}
// A length prefix above max_message_size yields MessageTooLarge without
// waiting for the payload.
#[tokio::test]
async fn reader_reject_oversized() {
let bad_len: u32 = 500_000;
let wire = bad_len.to_be_bytes();
let (client, mut server) = duplex(1024);
tokio::spawn(async move {
server.write_all(&wire).await.unwrap();
let _ = tokio::time::sleep(std::time::Duration::from_secs(10)).await;
});
let max = 100_000;
let mut reader = PeerReader::new(VectoredCompat(client), max);
let err = reader
.next_message()
.await
.expect_err("should reject oversized message");
match err {
irontide_wire::Error::MessageTooLarge { size, max: m } => {
assert_eq!(size, bad_len as usize);
assert_eq!(m, max);
}
other => panic!("expected MessageTooLarge, got {other:?}"),
}
}
// Several back-to-back frames decode in order, then EOF.
#[tokio::test]
async fn reader_multiple_messages_sequence() {
let messages = [
Message::Choke,
Message::Unchoke,
Message::Have { index: 1 },
Message::Interested,
];
let mut wire = Vec::new();
for msg in &messages {
let mut buf = BytesMut::new();
msg.encode_into(&mut buf);
wire.extend_from_slice(&buf);
}
let (client, mut server) = duplex(64 * 1024);
tokio::spawn(async move {
server.write_all(&wire).await.unwrap();
server.shutdown().await.unwrap();
});
let mut reader = PeerReader::new(VectoredCompat(client), 1 << 20);
for expected in &messages {
let decoded = reader
.next_message()
.await
.expect("decode error")
.expect("expected message");
assert_eq!(&decoded, expected);
}
let eof = reader.next_message().await.expect("decode error");
assert!(eof.is_none());
}
// A frame larger than the ring buffer goes through the Oversized path.
#[tokio::test]
async fn reader_oversized_fallback() {
let bitfield_len = BUF_LEN + 1000; let bitfield_data = vec![0xAA; bitfield_len];
let msg: Message = Message::Bitfield(Bytes::from(bitfield_data.clone()));
let wire = encode_message(&msg);
let max = BUF_LEN * 4; let (client, mut server) = duplex(256 * 1024);
tokio::spawn(async move {
server.write_all(&wire).await.unwrap();
server.shutdown().await.unwrap();
});
let mut reader = PeerReader::new(VectoredCompat(client), max);
let decoded = reader
.next_message()
.await
.expect("decode error")
.expect("expected Bitfield");
if let Message::Bitfield(data) = decoded {
assert_eq!(data.len(), bitfield_len);
assert!(data.iter().all(|&b| b == 0xAA));
} else {
panic!("expected Bitfield, got {decoded:?}");
}
}
// Writer output matches the reference encoding.
#[tokio::test]
async fn writer_send_small_message() {
let msg = Message::Have { index: 42 };
let expected = encode_message(&msg);
let (client, mut server) = duplex(64 * 1024);
let mut writer = PeerWriter::new(client);
writer.send(&msg).await.expect("send failed");
drop(writer);
let mut received = Vec::new();
server
.read_to_end(&mut received)
.await
.expect("read failed");
assert_eq!(received, expected);
}
// A full 16 KiB piece written by PeerWriter decodes back to the same data.
#[tokio::test]
async fn writer_send_piece_message() {
let mut piece_data = vec![0u8; 16_384];
for (i, byte) in piece_data.iter_mut().enumerate() {
*byte = (i % 199) as u8; }
let msg = Message::Piece {
index: 3,
begin: 16_384,
data_0: Bytes::from(piece_data.clone()),
data_1: Bytes::new(),
};
let expected = encode_message(&msg);
let (client, mut server) = duplex(256 * 1024);
let mut writer = PeerWriter::new(client);
writer.send(&msg).await.expect("send failed");
drop(writer);
let mut received = Vec::new();
server
.read_to_end(&mut received)
.await
.expect("read failed");
assert_eq!(received, expected);
let decoded =
Message::from_payload(Bytes::from(received[4..].to_vec())).expect("decode failed");
if let Message::Piece {
index,
begin,
data_0,
data_1,
} = decoded
{
assert_eq!(index, 3);
assert_eq!(begin, 16_384);
let _ = &data_1; assert_eq!(&data_0[..], &piece_data[..]);
} else {
panic!("expected Piece, got {decoded:?}");
}
}
// The boxed encode buffer keeps the same address across many sends.
#[tokio::test]
async fn write_buffer_no_reallocation() {
let (client, _server) = duplex(256 * 1024);
let mut writer = PeerWriter::new(client);
let ptr_before = writer.buf.as_ptr();
for i in 0..100u32 {
let msg = Message::Have { index: i };
writer.send(&msg).await.expect("send failed");
}
let ptr_after = writer.buf.as_ptr();
assert_eq!(ptr_before, ptr_after, "Box buffer should never reallocate");
}
// A 16 KiB piece is exactly MAX_MSG_LEN bytes on the wire.
#[tokio::test]
async fn write_buffer_max_size_message() {
let piece_data = vec![0xABu8; 16_384];
let msg = Message::Piece {
index: 0,
begin: 0,
data_0: Bytes::from(piece_data),
data_1: Bytes::new(),
};
let (client, mut server) = duplex(256 * 1024);
let mut writer = PeerWriter::new(client);
writer.send(&msg).await.expect("send max-size message");
drop(writer);
let mut received = Vec::new();
server
.read_to_end(&mut received)
.await
.expect("read failed");
assert_eq!(received.len(), MAX_MSG_LEN);
}
// Builds a reader whose ring already holds `wire`, starting at index `start`
// (used to force wrap-around layouts without any socket I/O).
fn preload_reader(
wire: &[u8],
start: usize,
) -> PeerReader<VectoredCompat<tokio::io::DuplexStream>> {
let (client, _server) = duplex(1024);
let mut reader = PeerReader::new(VectoredCompat(client), 1 << 20);
reader.buf.start = start;
reader.buf.len = wire.len();
for (i, &byte) in wire.iter().enumerate() {
reader.buf.buf[(start + i) % BUF_LEN] = byte;
}
reader
}
// Borrowed decode of a contiguous piece: data_0 holds everything, data_1 is empty.
#[tokio::test]
async fn reader_decode_piece_borrowed() {
let data_len = 100;
let mut piece_data = vec![0u8; data_len];
for (i, byte) in piece_data.iter_mut().enumerate() {
*byte = (i % 197) as u8;
}
let msg = Message::Piece {
index: 42,
begin: 8192,
data_0: Bytes::from(piece_data.clone()),
data_1: Bytes::new(),
};
let wire = encode_message(&msg);
let mut reader = preload_reader(&wire, 0);
let status = reader.fill_message().await.expect("fill error");
assert_eq!(status, FillStatus::Ready);
let (decoded, consumed) = reader.try_decode().expect("decode error");
assert_eq!(consumed, wire.len());
if let Message::Piece {
index,
begin,
data_0,
data_1,
} = decoded
{
assert_eq!(index, 42);
assert_eq!(begin, 8192);
assert_eq!(data_0.len(), data_len);
assert!(data_1.is_empty());
assert_eq!(data_0, &piece_data[..]);
} else {
panic!("expected Piece, got {decoded:?}");
}
reader.advance(consumed);
assert_eq!(reader.buf.len(), 0);
}
// Borrowed decode of a wrapped piece: data splits across data_0 and data_1.
#[tokio::test]
async fn reader_decode_piece_wrapped_borrowed() {
let data_len = 100;
let mut piece_data = vec![0u8; data_len];
for (i, byte) in piece_data.iter_mut().enumerate() {
*byte = (i % 179) as u8;
}
let msg = Message::Piece {
index: 10,
begin: 0,
data_0: Bytes::from(piece_data.clone()),
data_1: Bytes::new(),
};
let wire = encode_message(&msg);
let start = BUF_LEN - 50;
let mut reader = preload_reader(&wire, start);
let status = reader.fill_message().await.expect("fill error");
assert_eq!(status, FillStatus::Ready);
let (decoded, consumed) = reader.try_decode().expect("decode error");
assert_eq!(consumed, wire.len());
if let Message::Piece {
index,
begin,
data_0,
data_1,
} = decoded
{
assert_eq!(index, 10);
assert_eq!(begin, 0);
assert!(!data_1.is_empty(), "data should wrap");
assert_eq!(data_0.len() + data_1.len(), data_len);
let mut combined = Vec::new();
combined.extend_from_slice(data_0);
combined.extend_from_slice(data_1);
assert_eq!(combined, piece_data);
} else {
panic!("expected Piece, got {decoded:?}");
}
reader.advance(consumed);
assert_eq!(reader.buf.len(), 0);
}
// Borrowed decode of a contiguous bitfield.
#[tokio::test]
async fn reader_decode_bitfield_borrowed() {
let bitfield_data = vec![0xFF; 64];
let msg = Message::Bitfield(Bytes::from(bitfield_data.clone()));
let wire = encode_message(&msg);
let mut reader = preload_reader(&wire, 0);
let status = reader.fill_message().await.expect("fill error");
assert_eq!(status, FillStatus::Ready);
let (decoded, consumed) = reader.try_decode().expect("decode error");
assert_eq!(consumed, wire.len());
if let Message::Bitfield(data) = decoded {
assert_eq!(data, &bitfield_data[..]);
} else {
panic!("expected Bitfield, got {decoded:?}");
}
reader.advance(consumed);
assert_eq!(reader.buf.len(), 0);
}
// Borrowed decode of a contiguous extended message.
#[tokio::test]
async fn reader_decode_extended_borrowed() {
let payload_data = b"d1:pi12345ee"; let msg = Message::Extended {
ext_id: 1,
payload: Bytes::from_static(payload_data),
};
let wire = encode_message(&msg);
let mut reader = preload_reader(&wire, 0);
let status = reader.fill_message().await.expect("fill error");
assert_eq!(status, FillStatus::Ready);
let (decoded, consumed) = reader.try_decode().expect("decode error");
assert_eq!(consumed, wire.len());
if let Message::Extended { ext_id, payload } = decoded {
assert_eq!(ext_id, 1);
assert_eq!(payload, payload_data);
} else {
panic!("expected Extended, got {decoded:?}");
}
reader.advance(consumed);
assert_eq!(reader.buf.len(), 0);
}
// A wrapped bitfield is copied into oversized_buf to form one contiguous slice.
#[tokio::test]
async fn reader_bitfield_wrap_copies_to_oversized() {
let bitfield_data = vec![0xAA; 200];
let msg = Message::Bitfield(Bytes::from(bitfield_data.clone()));
let wire = encode_message(&msg);
let start = BUF_LEN - 100;
let mut reader = preload_reader(&wire, start);
let status = reader.fill_message().await.expect("fill error");
assert_eq!(status, FillStatus::Ready);
let (decoded, consumed) = reader.try_decode().expect("decode error");
assert_eq!(consumed, wire.len());
if let Message::Bitfield(data) = decoded {
assert_eq!(data, &bitfield_data[..]);
assert_eq!(reader.oversized_buf.len(), bitfield_data.len());
} else {
panic!("expected Bitfield, got {decoded:?}");
}
reader.advance(consumed);
assert_eq!(reader.buf.len(), 0);
}
// try_decode does not consume; advance() moves to the next buffered frame.
#[tokio::test]
async fn reader_advance_after_message() {
let msg1 = Message::Have { index: 1 };
let msg2 = Message::Have { index: 2 };
let wire1 = encode_message(&msg1);
let wire2 = encode_message(&msg2);
let mut combined = Vec::new();
combined.extend_from_slice(&wire1);
combined.extend_from_slice(&wire2);
let mut reader = preload_reader(&combined, 0);
let status = reader.fill_message().await.expect("fill error");
assert_eq!(status, FillStatus::Ready);
let (decoded1, consumed1) = reader.try_decode().expect("decode error");
assert_eq!(decoded1, Message::Have::<&[u8]> { index: 1 });
assert_eq!(consumed1, wire1.len());
assert_eq!(reader.buf.len(), combined.len());
reader.advance(consumed1);
assert_eq!(reader.buf.len(), wire2.len());
let status = reader.fill_message().await.expect("fill error");
assert_eq!(status, FillStatus::Ready);
let (decoded2, consumed2) = reader.try_decode().expect("decode error");
assert_eq!(decoded2, Message::Have::<&[u8]> { index: 2 });
assert_eq!(consumed2, wire2.len());
reader.advance(consumed2);
assert_eq!(reader.buf.len(), 0);
}
// readable_slices_at with an offset inside contiguous data.
#[test]
fn readbuf_readable_slices_at_contiguous() {
let mut rb = ReadBuf::new();
let slice = rb.unfilled_contiguous();
for (i, byte) in slice[..20].iter_mut().enumerate() {
*byte = i as u8;
}
rb.mark_filled(20);
let (a, b) = rb.readable_slices_at(3, 5);
assert_eq!(a, &[3, 4, 5, 6, 7]);
assert!(b.is_empty());
}
// readable_slices_at with a range that crosses the wrap point.
#[test]
fn readbuf_readable_slices_at_wrapping() {
let mut rb = ReadBuf::new();
let slice = rb.unfilled_contiguous();
for (i, byte) in slice.iter_mut().enumerate() {
*byte = (i & 0xFF) as u8;
}
rb.mark_filled(BUF_LEN);
rb.consume(BUF_LEN - 4);
let slice = rb.unfilled_contiguous();
for (i, byte) in slice[..8].iter_mut().enumerate() {
*byte = (0xB0 + i) as u8;
}
rb.mark_filled(8);
let (a, b) = rb.readable_slices_at(2, 6);
assert_eq!(a.len(), 2);
assert_eq!(b.len(), 4);
assert_eq!(b, &[0xB0, 0xB1, 0xB2, 0xB3]);
}
// A message larger than MAX_MSG_LEN takes the BytesMut fallback in send().
#[tokio::test]
async fn write_buffer_oversized_bitfield() {
let bitfield_len = 20_000; let bitfield_data = vec![0xCC; bitfield_len];
let msg = Message::Bitfield(Bytes::from(bitfield_data.clone()));
let expected = encode_message(&msg);
let (client, mut server) = duplex(256 * 1024);
let mut writer = PeerWriter::new(client);
writer.send(&msg).await.expect("send oversized bitfield");
drop(writer);
let mut received = Vec::new();
server
.read_to_end(&mut received)
.await
.expect("read failed");
assert_eq!(received, expected);
assert_eq!(received.len(), 5 + bitfield_len);
let decoded =
Message::from_payload(Bytes::from(received[4..].to_vec())).expect("decode failed");
if let Message::Bitfield(data) = decoded {
assert_eq!(data.len(), bitfield_len);
assert!(data.iter().all(|&b| b == 0xCC));
} else {
panic!("expected Bitfield, got {decoded:?}");
}
}
}