use bytes::Buf;
use core::{borrow::Borrow, fmt};
use crate::{
metainfo::InfoHash,
peer::{self, Id, InvalidInput},
piece::{Block, BlockBegin, BlockData, BlockLength, Index},
};
/// The 20 bytes that open every handshake: the length byte `0x13` (19)
/// followed by the 19-byte ASCII string `BitTorrent protocol`.
pub const PROTOCOL_STRING_BYTES: [u8; 20] = *b"\x13BitTorrent protocol";
/// The eight reserved bytes that follow the protocol string in a handshake.
///
/// The wrapper is a plain newtype over `[u8; 8]`; the inner array is public
/// and can be viewed as a byte slice through [`AsRef`] or [`Borrow`].
#[derive(Default, Clone, Copy, PartialEq, Eq)]
pub struct ReservedBytes(pub [u8; 8]);

impl AsRef<[u8]> for ReservedBytes {
    /// Views the reserved bytes as a slice.
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl Borrow<[u8]> for ReservedBytes {
    /// Borrows the reserved bytes as a slice.
    fn borrow(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl From<[u8; 8]> for ReservedBytes {
    /// Wraps an owned array.
    fn from(bytes: [u8; 8]) -> Self {
        Self(bytes)
    }
}

impl From<&[u8; 8]> for ReservedBytes {
    /// Copies the referenced array into a new value.
    fn from(bytes: &[u8; 8]) -> Self {
        Self::from(*bytes)
    }
}

// NOTE(review): `fmt_byte_array!` is defined outside this section; presumably
// it supplies the formatting impls for this type. Confirm against the macro.
fmt_byte_array!(ReservedBytes);
/// A single message of the peer wire protocol, borrowed from the receive
/// buffer it was parsed out of.
///
/// Variants that carry variable-length data hold slices with lifetime `'a`
/// pointing into that buffer; nothing is copied.
#[derive(Debug, PartialEq, Eq)]
pub enum Frame<'a> {
/// A frame whose length prefix is zero; it has no type byte and no payload.
KeepAlive,
/// `choke` message (type byte 0); no payload.
Choke,
/// `unchoke` message (type byte 1); no payload.
Unchoke,
/// `interested` message (type byte 2); no payload.
Interested,
/// `not interested` message (type byte 3); no payload.
NotInterested,
/// `have` message (type byte 4) carrying a piece index.
Have(HaveMsg),
/// `bitfield` message (type byte 5) carrying the raw bitfield bytes.
Bitfield(BitfieldMsg<'a>),
/// `request` message (type byte 6) identifying a block by index, begin and length.
Request(RequestMsg),
/// `piece` message (type byte 7) carrying a block's index, begin and data.
Piece(PieceMsg<'a>),
/// `cancel` message (type byte 8) identifying a block by index, begin and length.
Cancel(CancelMsg),
/// A message with an unrecognised type byte: the type byte itself and the
/// bytes that followed it inside the frame.
Unknown(u8, &'a [u8]),
}
/// Reasons a frame cannot be checked or parsed.
///
/// With the `std` feature enabled the `thiserror` derive is applied, using
/// the messages given in the `error(...)` attributes below.
#[cfg_attr(feature = "std", derive(thiserror::Error))]
#[derive(Debug)]
pub enum Error {
/// The buffer does not (yet) hold the 4-byte length prefix or the number
/// of bytes the prefix announces.
#[cfg_attr(feature = "std", error("incomplete frame"))]
IncompleteFrame,
/// The length prefix exceeds `MAX_EXPECTED_FRAME_LEN`; the payload is the
/// announced length.
#[cfg_attr(feature = "std", error("message length too large {0}"))]
MessageLengthTooLarge(usize),
/// The announced length does not match the length required by the
/// message type.
#[cfg_attr(feature = "std", error("invalid message length"))]
InvalidMessageLength,
}
/// Allows frame errors to be propagated from functions that return
/// `std::io::Error`.
#[cfg(feature = "std")]
impl From<Error> for std::io::Error {
    /// Wraps `error` in an I/O error of kind `InvalidInput`, keeping the
    /// original error as the inner payload.
    fn from(error: Error) -> Self {
        use std::io::ErrorKind;

        Self::new(ErrorKind::InvalidInput, error)
    }
}
/// Largest value of the 4-byte length prefix that `Frame::check` and
/// `Frame::parse` accept; anything larger yields
/// `Error::MessageLengthTooLarge`.
///
/// The sum mirrors the layout of a `piece` message: 1 type byte, 4 index
/// bytes, 4 begin bytes and 16384 data bytes.
// NOTE(review): 16384 is presumably the maximum block size this crate expects
// from peers; the same cap also applies to bitfield messages — confirm that is
// intended for torrents with very many pieces.
pub const MAX_EXPECTED_FRAME_LEN: usize = 1 + 4 + 4 + 16384;
impl<'a> Frame<'a> {
#[cfg(feature = "std")]
pub fn check<T: AsRef<[u8]>>(cursor: &mut std::io::Cursor<T>) -> Result<(), Error> {
if cursor.remaining() < 4 {
return Err(Error::IncompleteFrame);
}
let msg_len = cursor.get_u32();
if msg_len == 0 {
return Ok(());
}
let msg_len = usize::try_from(msg_len).unwrap();
if msg_len > MAX_EXPECTED_FRAME_LEN {
return Err(Error::MessageLengthTooLarge(msg_len));
}
if cursor.remaining() < msg_len {
return Err(Error::IncompleteFrame);
}
let ty = cursor.get_u8();
match ty {
ChokeMsg::TYPE => {
if msg_len != ChokeMsg::LEN {
return Err(Error::InvalidMessageLength);
}
Ok(())
}
UnchokeMsg::TYPE => {
if msg_len != UnchokeMsg::LEN {
return Err(Error::InvalidMessageLength);
}
Ok(())
}
InterestedMsg::TYPE => {
if msg_len != InterestedMsg::LEN {
return Err(Error::InvalidMessageLength);
}
Ok(())
}
NotInterestedMsg::TYPE => {
if msg_len != NotInterestedMsg::LEN {
return Err(Error::InvalidMessageLength);
}
Ok(())
}
HaveMsg::TYPE => {
if msg_len != HaveMsg::LEN {
return Err(Error::InvalidMessageLength);
}
cursor.advance(4);
Ok(())
}
RequestMsg::TYPE => {
if msg_len != RequestMsg::LEN {
return Err(Error::InvalidMessageLength);
}
cursor.advance(12);
Ok(())
}
CancelMsg::TYPE => {
if msg_len != CancelMsg::LEN {
return Err(Error::InvalidMessageLength);
}
cursor.advance(12);
Ok(())
}
_ => {
cursor.advance(msg_len - 1);
Ok(())
}
}
}
pub fn parse(buf: &'a [u8]) -> Result<Self, Error> {
if buf.remaining() < 4 {
return Err(Error::IncompleteFrame);
}
let msg_len = u32::from_be_bytes(<[u8; 4]>::try_from(&buf[..4]).unwrap());
if msg_len == 0 {
return Ok(Self::KeepAlive);
}
let msg_len = usize::try_from(msg_len).unwrap();
if msg_len > MAX_EXPECTED_FRAME_LEN {
return Err(Error::MessageLengthTooLarge(msg_len));
}
let buf = &buf[4..];
if buf.remaining() < msg_len {
return Err(Error::IncompleteFrame);
}
let ty = buf[0];
match ty {
ChokeMsg::TYPE => {
if msg_len != ChokeMsg::LEN {
return Err(Error::InvalidMessageLength);
}
Ok(Self::Choke)
}
UnchokeMsg::TYPE => {
if msg_len != UnchokeMsg::LEN {
return Err(Error::InvalidMessageLength);
}
Ok(Self::Unchoke)
}
InterestedMsg::TYPE => {
if msg_len != InterestedMsg::LEN {
return Err(Error::InvalidMessageLength);
}
Ok(Self::Interested)
}
NotInterestedMsg::TYPE => {
if msg_len != NotInterestedMsg::LEN {
return Err(Error::InvalidMessageLength);
}
Ok(Self::NotInterested)
}
HaveMsg::TYPE => {
if msg_len != HaveMsg::LEN {
return Err(Error::InvalidMessageLength);
}
let index =
Index::from(u32::from_be_bytes(<[u8; 4]>::try_from(&buf[1..5]).unwrap()));
Ok(Self::Have(HaveMsg(index)))
}
BitfieldMsg::TYPE => {
todo!()
}
RequestMsg::TYPE => {
if msg_len != RequestMsg::LEN {
return Err(Error::InvalidMessageLength);
}
let index =
Index::from(u32::from_be_bytes(<[u8; 4]>::try_from(&buf[1..5]).unwrap()));
let begin =
BlockBegin::from(u32::from_be_bytes(<[u8; 4]>::try_from(&buf[5..9]).unwrap()));
let length = BlockLength::from(u32::from_be_bytes(
<[u8; 4]>::try_from(&buf[9..13]).unwrap(),
));
Ok(Self::Request(RequestMsg(Block {
index,
begin,
length,
})))
}
PieceMsg::TYPE => {
let index =
Index::from(u32::from_be_bytes(<[u8; 4]>::try_from(&buf[1..5]).unwrap()));
let begin =
BlockBegin::from(u32::from_be_bytes(<[u8; 4]>::try_from(&buf[5..9]).unwrap()));
let length = msg_len - 1 - 4 - 4;
let data = &buf[9..9 + length];
Ok(Self::Piece(PieceMsg(BlockData { index, begin, data })))
}
CancelMsg::TYPE => {
if msg_len != CancelMsg::LEN {
return Err(Error::InvalidMessageLength);
}
let index =
Index::from(u32::from_be_bytes(<[u8; 4]>::try_from(&buf[1..5]).unwrap()));
let begin =
BlockBegin::from(u32::from_be_bytes(<[u8; 4]>::try_from(&buf[5..9]).unwrap()));
let length = BlockLength::from(u32::from_be_bytes(
<[u8; 4]>::try_from(&buf[9..13]).unwrap(),
));
Ok(Self::Cancel(CancelMsg(Block {
index,
begin,
length,
})))
}
ty => Ok(Self::Unknown(ty, &buf[1..1 + msg_len - 1])),
}
}
}
/// Marker for the keep-alive frame, which consists of a zero length prefix
/// and nothing else.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct KeepAliveMsg;

impl KeepAliveMsg {
    /// Value of the length prefix for a keep-alive: no type byte, no payload.
    #[must_use]
    pub const fn msg_len() -> u32 {
        0
    }
}
/// Marker for the `choke` message, which has a type byte and no payload.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ChokeMsg;

impl ChokeMsg {
    /// Message length in bytes as announced by the length prefix.
    pub const LEN: usize = 1;
    /// Type byte identifying this message on the wire.
    pub const TYPE: u8 = 0;

    /// Value of the length prefix for this message.
    #[must_use]
    pub const fn msg_len() -> u32 {
        // `LEN` is a small compile-time constant, so the cast is lossless.
        Self::LEN as u32
    }
}
/// Marker for the `unchoke` message, which has a type byte and no payload.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct UnchokeMsg;

impl UnchokeMsg {
    /// Message length in bytes as announced by the length prefix.
    pub const LEN: usize = 1;
    /// Type byte identifying this message on the wire.
    pub const TYPE: u8 = 1;

    /// Value of the length prefix for this message.
    #[must_use]
    pub const fn msg_len() -> u32 {
        // `LEN` is a small compile-time constant, so the cast is lossless.
        Self::LEN as u32
    }
}
/// Marker for the `interested` message, which has a type byte and no payload.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct InterestedMsg;

impl InterestedMsg {
    /// Message length in bytes as announced by the length prefix.
    pub const LEN: usize = 1;
    /// Type byte identifying this message on the wire.
    pub const TYPE: u8 = 2;

    /// Value of the length prefix for this message.
    #[must_use]
    pub const fn msg_len() -> u32 {
        // `LEN` is a small compile-time constant, so the cast is lossless.
        Self::LEN as u32
    }
}
/// Marker for the `not interested` message, which has a type byte and no
/// payload.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NotInterestedMsg;

impl NotInterestedMsg {
    /// Message length in bytes as announced by the length prefix.
    pub const LEN: usize = 1;
    /// Type byte identifying this message on the wire.
    pub const TYPE: u8 = 3;

    /// Value of the length prefix for this message.
    #[must_use]
    pub const fn msg_len() -> u32 {
        // `LEN` is a small compile-time constant, so the cast is lossless.
        Self::LEN as u32
    }
}
/// `have` message: a type byte followed by a 4-byte piece index.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct HaveMsg(pub Index);

impl HaveMsg {
    /// Message length in bytes as announced by the length prefix
    /// (type byte + index).
    pub const LEN: usize = 5;
    /// Type byte identifying this message on the wire.
    pub const TYPE: u8 = 4;

    /// Value of the length prefix for this message.
    #[must_use]
    pub const fn msg_len() -> u32 {
        // `LEN` is a small compile-time constant, so the cast is lossless.
        Self::LEN as u32
    }
}
/// `bitfield` message: a type byte followed by the raw bitfield bytes, which
/// are borrowed rather than copied.
#[derive(Clone, PartialEq, Eq)]
pub struct BitfieldMsg<'a>(pub &'a [u8]);

impl BitfieldMsg<'_> {
    /// Type byte identifying this message on the wire.
    pub const TYPE: u8 = 5;

    /// Value of the length prefix for this message: the type byte plus the
    /// bitfield bytes.
    ///
    /// # Panics
    ///
    /// Panics if the bitfield length does not fit in a `u32`.
    #[must_use]
    pub fn msg_len(&self) -> u32 {
        let payload_len = u32::try_from(self.0.len()).unwrap();
        1 + payload_len
    }
}

/// Prints the bitfield as one contiguous lowercase hex string instead of a
/// list of decimal bytes.
impl fmt::Debug for BitfieldMsg<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Adapter whose `Debug` output is the hex dump of the wrapped bytes.
        struct HexBytes<'b>(&'b [u8]);

        impl fmt::Debug for HexBytes<'_> {
            fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
                self.0.iter().try_for_each(|b| write!(out, "{b:02x}"))
            }
        }

        let hex = HexBytes(self.0);
        f.debug_tuple("BitfieldMsg").field(&hex).finish()
    }
}
/// `request` message: a type byte followed by three 4-byte fields (index,
/// begin, length) describing the requested block.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RequestMsg(pub Block);

impl RequestMsg {
    /// Message length in bytes as announced by the length prefix
    /// (type byte + index + begin + length).
    pub const LEN: usize = 13;
    /// Type byte identifying this message on the wire.
    pub const TYPE: u8 = 6;

    /// Value of the length prefix for this message.
    #[must_use]
    pub const fn msg_len() -> u32 {
        // `LEN` is a small compile-time constant, so the cast is lossless.
        Self::LEN as u32
    }
}
/// `piece` message: a type byte, a 4-byte index, a 4-byte begin offset and
/// the block data, which is borrowed rather than copied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PieceMsg<'a>(pub BlockData<'a>);

impl PieceMsg<'_> {
    /// Type byte identifying this message on the wire.
    pub const TYPE: u8 = 7;

    /// Value of the length prefix for this message: type byte, index and
    /// begin fields, then the data bytes.
    ///
    /// # Panics
    ///
    /// Panics if the data length does not fit in a `u32`.
    #[must_use]
    pub fn msg_len(&self) -> u32 {
        let data_len = u32::try_from(self.0.data.len()).unwrap();
        1 + 4 + 4 + data_len
    }
}
/// `cancel` message: a type byte followed by three 4-byte fields (index,
/// begin, length) describing the block being cancelled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CancelMsg(pub Block);

impl CancelMsg {
    /// Message length in bytes as announced by the length prefix
    /// (type byte + index + begin + length).
    pub const LEN: usize = 13;
    /// Type byte identifying this message on the wire.
    pub const TYPE: u8 = 8;

    /// Value of the length prefix for this message.
    #[must_use]
    pub const fn msg_len() -> u32 {
        // `LEN` is a small compile-time constant, so the cast is lossless.
        Self::LEN as u32
    }
}
/// Progress of an incrementally received handshake, as driven by
/// `parse_handshake`.
///
/// The variants are visited in declaration order; each one carries the
/// fields decoded so far.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReceivedHandshakeState {
/// The protocol string is still being matched; the value is the number of
/// its bytes (out of 20) that have already been received and verified.
None(usize),
/// All 20 protocol-string bytes matched; the 8 reserved bytes are next.
ReceivedProtocol,
/// Reserved bytes received; the 20-byte info hash is next.
ReceivedReservedBytes(ReservedBytes),
/// Reserved bytes and info hash received; the 20-byte peer id is next.
ReceivedReservedBytesAndInfoHash(ReservedBytes, InfoHash),
/// The handshake is complete.
ReceivedHandshake(ReservedBytes, InfoHash, Id),
}
/// A fresh connection starts with none of the protocol string matched.
impl Default for ReceivedHandshakeState {
fn default() -> Self {
Self::None(0)
}
}
/// Advances handshake parsing by one step, consuming bytes from `buf`.
///
/// The function is meant to be called repeatedly, feeding each returned
/// state back in, as more bytes arrive.
///
/// # Returns
///
/// * `Ok(Some(next))` with the state to continue from.
/// * `Ok(None)` when `buf` does not yet hold the whole next field (nothing
///   is consumed in that case) or when `state` is already
///   [`ReceivedHandshakeState::ReceivedHandshake`].
///
/// While in [`ReceivedHandshakeState::None`] the protocol string is matched
/// byte by byte, so partial input is consumed and a state is always
/// returned, even when `buf` was empty and the offset did not change.
/// Callers that loop on `Some(..)` should therefore also stop once `buf` is
/// exhausted.
///
/// # Errors
///
/// Returns [`InvalidInput`] as soon as a received byte differs from
/// [`PROTOCOL_STRING_BYTES`].
pub fn parse_handshake<B>(
    buf: &mut B,
    state: &ReceivedHandshakeState,
) -> Result<Option<ReceivedHandshakeState>, InvalidInput>
where
    B: Buf,
{
    // Copies the next `N` bytes out of `src`, or returns `None` without
    // consuming anything when fewer than `N` bytes are available.
    fn take_array<Src: Buf, const N: usize>(src: &mut Src) -> Option<[u8; N]> {
        if src.remaining() < N {
            return None;
        }
        let mut bytes = [0_u8; N];
        src.copy_to_slice(&mut bytes);
        Some(bytes)
    }

    match state {
        ReceivedHandshakeState::None(matched) => {
            let mut handshake_offset = *matched;
            // Read only what is still missing from the protocol string and
            // actually available. The bound must be a byte *count*: using
            // the absolute offset as the start of the range would skip
            // available bytes whenever the string arrives in several chunks.
            let missing = PROTOCOL_STRING_BYTES
                .len()
                .saturating_sub(handshake_offset);
            for _ in 0..core::cmp::min(missing, buf.remaining()) {
                if PROTOCOL_STRING_BYTES[handshake_offset] != buf.get_u8() {
                    return Err(InvalidInput);
                }
                handshake_offset += 1;
            }
            debug_assert!(handshake_offset <= 20);
            if handshake_offset == 20 {
                Ok(Some(ReceivedHandshakeState::ReceivedProtocol))
            } else {
                Ok(Some(ReceivedHandshakeState::None(handshake_offset)))
            }
        }
        ReceivedHandshakeState::ReceivedProtocol => {
            Ok(take_array::<B, 8>(buf).map(|bytes| {
                ReceivedHandshakeState::ReceivedReservedBytes(ReservedBytes::from(bytes))
            }))
        }
        ReceivedHandshakeState::ReceivedReservedBytes(reserved_bytes) => {
            Ok(take_array::<B, 20>(buf).map(|bytes| {
                ReceivedHandshakeState::ReceivedReservedBytesAndInfoHash(
                    *reserved_bytes,
                    InfoHash::from(bytes),
                )
            }))
        }
        ReceivedHandshakeState::ReceivedReservedBytesAndInfoHash(reserved_bytes, info_hash) => {
            Ok(take_array::<B, 20>(buf).map(|bytes| {
                ReceivedHandshakeState::ReceivedHandshake(
                    *reserved_bytes,
                    *info_hash,
                    peer::Id::from(bytes),
                )
            }))
        }
        // Nothing left to parse once the handshake is complete.
        ReceivedHandshakeState::ReceivedHandshake(..) => Ok(None),
    }
}
/// Saturating counters of the messages and payload bytes seen in frames.
///
/// All counters start at zero (`Default`) and are only ever increased with
/// saturating arithmetic by the methods on this type.
#[derive(Debug, Default, Clone, Copy)]
pub struct Metrics {
/// Number of keep-alive frames.
pub keepalive_msgs: u64,
/// Number of `choke` messages.
pub choke_msgs: u64,
/// Number of `unchoke` messages.
pub unchoke_msgs: u64,
/// Number of `interested` messages.
pub interested_msgs: u64,
/// Number of `not interested` messages.
pub not_interested_msgs: u64,
/// Number of `have` messages.
pub have_msgs: u64,
/// Number of `bitfield` messages.
pub bitfield_msgs: u64,
/// Total bitfield bytes, excluding the type byte of each message.
pub bitfield_bytes: u64,
/// Number of `request` messages.
pub request_msgs: u64,
/// Number of `piece` messages.
pub piece_msgs: u64,
/// Total block data bytes carried by `piece` messages.
pub piece_bytes: u64,
// NOTE(review): no method in this section updates this counter; presumably
// callers add to it directly for piece data nobody asked for — confirm.
pub unrequested_piece_bytes: u64,
/// Number of `cancel` messages.
pub cancel_msgs: u64,
/// Total payload bytes of frames with an unrecognised type byte.
pub unknown_bytes: u64,
}
impl Metrics {
    /// Adds `amount` to `counter`, clamping at `u64::MAX` instead of wrapping.
    #[inline]
    fn bump(counter: &mut u64, amount: u64) {
        *counter = counter.saturating_add(amount);
    }

    /// Returns `true` if at least one counter is non-zero.
    #[must_use]
    pub fn is_any_nonzero(&self) -> bool {
        // Destructuring makes the compiler flag this function when a field
        // is added to the struct.
        let Self {
            keepalive_msgs,
            choke_msgs,
            unchoke_msgs,
            interested_msgs,
            not_interested_msgs,
            have_msgs,
            bitfield_msgs,
            bitfield_bytes,
            request_msgs,
            piece_msgs,
            piece_bytes,
            unrequested_piece_bytes,
            cancel_msgs,
            unknown_bytes,
        } = *self;

        [
            keepalive_msgs,
            choke_msgs,
            unchoke_msgs,
            interested_msgs,
            not_interested_msgs,
            have_msgs,
            bitfield_bytes,
            bitfield_msgs,
            request_msgs,
            piece_msgs,
            piece_bytes,
            unrequested_piece_bytes,
            cancel_msgs,
            unknown_bytes,
        ]
        .iter()
        .any(|&count| count != 0)
    }

    /// Counts one `request` message.
    #[inline]
    pub fn add_request(&mut self) {
        Self::bump(&mut self.request_msgs, 1);
    }

    /// Counts one `piece` message and the length of its block data.
    #[inline]
    pub fn add_piece(&mut self, piece_msg: &PieceMsg<'_>) {
        let data_len = piece_msg.0.data.len() as u64;
        Self::bump(&mut self.piece_bytes, data_len);
        Self::bump(&mut self.piece_msgs, 1);
    }

    /// Counts one keep-alive frame.
    #[inline]
    pub fn add_keepalive(&mut self) {
        Self::bump(&mut self.keepalive_msgs, 1);
    }

    /// Counts one `have` message.
    #[inline]
    pub fn add_have(&mut self) {
        Self::bump(&mut self.have_msgs, 1);
    }

    /// Counts one `choke` message.
    #[inline]
    pub fn add_choke(&mut self) {
        Self::bump(&mut self.choke_msgs, 1);
    }

    /// Counts one `unchoke` message.
    #[inline]
    pub fn add_unchoke(&mut self) {
        Self::bump(&mut self.unchoke_msgs, 1);
    }

    /// Counts one `interested` message.
    #[inline]
    pub fn add_interested(&mut self) {
        Self::bump(&mut self.interested_msgs, 1);
    }

    /// Counts one `not interested` message.
    #[inline]
    pub fn add_not_interested(&mut self) {
        Self::bump(&mut self.not_interested_msgs, 1);
    }

    /// Counts one `cancel` message.
    #[inline]
    pub fn add_cancel(&mut self) {
        Self::bump(&mut self.cancel_msgs, 1);
    }

    /// Counts one `bitfield` message and its bitfield bytes (the message
    /// length minus the type byte).
    #[inline]
    pub fn add_bitfield(&mut self, bitfield_msg: &BitfieldMsg<'_>) {
        let payload_len = u64::from(bitfield_msg.msg_len() - 1);
        Self::bump(&mut self.bitfield_msgs, 1);
        Self::bump(&mut self.bitfield_bytes, payload_len);
    }

    /// Adds `len` bytes to the unknown-message byte counter.
    #[inline]
    pub fn add_unknown(&mut self, len: u64) {
        Self::bump(&mut self.unknown_bytes, len);
    }

    /// Updates the counters that correspond to `frame`.
    pub fn add_frame(&mut self, frame: &Frame<'_>) {
        match frame {
            Frame::KeepAlive => self.add_keepalive(),
            Frame::Choke => self.add_choke(),
            Frame::Unchoke => self.add_unchoke(),
            Frame::Interested => self.add_interested(),
            Frame::NotInterested => self.add_not_interested(),
            Frame::Have(_) => self.add_have(),
            Frame::Bitfield(bitfield_msg) => self.add_bitfield(bitfield_msg),
            Frame::Request(_) => self.add_request(),
            Frame::Piece(piece_msg) => self.add_piece(piece_msg),
            Frame::Cancel(_) => self.add_cancel(),
            Frame::Unknown(_, data) => {
                let len = u64::try_from(data.len()).unwrap_or(u64::MAX);
                self.add_unknown(len);
            }
        }
    }
}
/// Field-wise saturating sum of two metric sets.
impl core::ops::Add for Metrics {
    type Output = Metrics;

    fn add(self, rhs: Metrics) -> Metrics {
        // Delegate to `AddAssign` so both operators share one implementation.
        let mut total = self;
        total += rhs;
        total
    }
}
impl core::ops::AddAssign for Metrics {
fn add_assign(&mut self, rhs: Metrics) {
self.keepalive_msgs = self.keepalive_msgs.saturating_add(rhs.keepalive_msgs);
self.choke_msgs = self.choke_msgs.saturating_add(rhs.choke_msgs);
self.unchoke_msgs = self.unchoke_msgs.saturating_add(rhs.unchoke_msgs);
self.interested_msgs = self.interested_msgs.saturating_add(rhs.interested_msgs);
self.not_interested_msgs = self
.not_interested_msgs
.saturating_add(rhs.not_interested_msgs);
self.have_msgs = self.have_msgs.saturating_add(rhs.have_msgs);
self.bitfield_msgs = self.bitfield_msgs.saturating_add(rhs.bitfield_msgs);
self.bitfield_bytes = self.bitfield_bytes.saturating_add(rhs.bitfield_bytes);
self.request_msgs = self.request_msgs.saturating_add(rhs.request_msgs);
self.piece_msgs = self.piece_msgs.saturating_add(rhs.piece_msgs);
self.piece_bytes = self.piece_bytes.saturating_add(rhs.piece_bytes);
self.unrequested_piece_bytes = self
.unrequested_piece_bytes
.saturating_add(rhs.unrequested_piece_bytes);
self.cancel_msgs = self.cancel_msgs.saturating_add(rhs.cancel_msgs);
self.unknown_bytes = self.unknown_bytes.saturating_add(rhs.unknown_bytes);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guards against accidental growth of `Metrics`: fourteen `u64`
    /// counters occupy 112 bytes.
    #[test]
    fn test_metrics_size() {
        let actual = core::mem::size_of::<Metrics>();
        assert_eq!(actual, 112);
    }
}