// rbit 0.2.2
//
// A BitTorrent library implementing BEP specifications
// Documentation
use std::time::Duration;

use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::timeout;

use super::error::PeerError;
use super::message::{Handshake, Message, HANDSHAKE_LEN};

/// Largest length prefix (payload size in bytes, excluding the 4-byte prefix
/// itself) that `receive_message` accepts: 16 MiB. Anything larger is rejected
/// with `PeerError::InvalidMessage` before the payload is read.
const MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;
/// Maximum time to wait for a single read from the socket to complete.
/// Applied to each read call individually, not to a whole handshake/message.
const READ_TIMEOUT: Duration = Duration::from_secs(120);
/// Maximum time allowed for a whole encoded handshake or message to be written.
const WRITE_TIMEOUT: Duration = Duration::from_secs(30);

/// Framed transport over a TCP connection to a single peer.
///
/// Owns the socket plus a buffer of bytes that have been read from it but not
/// yet consumed as a handshake or message.
pub struct PeerTransport {
    // Underlying connection; all reads and writes go through it.
    stream: TcpStream,
    // Bytes received from `stream` that have not yet been split off and decoded.
    read_buf: BytesMut,
}

impl PeerTransport {
    pub fn new(stream: TcpStream) -> Self {
        Self {
            stream,
            read_buf: BytesMut::with_capacity(32 * 1024),
        }
    }

    pub async fn send_handshake(&mut self, handshake: &Handshake) -> Result<(), PeerError> {
        let data = handshake.encode();
        timeout(WRITE_TIMEOUT, self.stream.write_all(&data))
            .await
            .map_err(|_| PeerError::Timeout)??;
        Ok(())
    }

    pub async fn receive_handshake(&mut self) -> Result<Handshake, PeerError> {
        while self.read_buf.len() < HANDSHAKE_LEN {
            let n = timeout(READ_TIMEOUT, self.stream.read_buf(&mut self.read_buf))
                .await
                .map_err(|_| PeerError::Timeout)??;

            if n == 0 {
                return Err(PeerError::ConnectionClosed);
            }
        }

        let data = self.read_buf.split_to(HANDSHAKE_LEN);
        Handshake::decode(&data)
    }

    pub async fn send_message(&mut self, message: &Message) -> Result<(), PeerError> {
        let data = message.encode();
        timeout(WRITE_TIMEOUT, self.stream.write_all(&data))
            .await
            .map_err(|_| PeerError::Timeout)??;
        Ok(())
    }

    pub async fn receive_message(&mut self) -> Result<Message, PeerError> {
        while self.read_buf.len() < 4 {
            let n = timeout(READ_TIMEOUT, self.stream.read_buf(&mut self.read_buf))
                .await
                .map_err(|_| PeerError::Timeout)??;

            if n == 0 {
                return Err(PeerError::ConnectionClosed);
            }
        }

        let length = u32::from_be_bytes([
            self.read_buf[0],
            self.read_buf[1],
            self.read_buf[2],
            self.read_buf[3],
        ]) as usize;

        if length > MAX_MESSAGE_SIZE {
            return Err(PeerError::InvalidMessage(format!(
                "message too large: {}",
                length
            )));
        }

        let total_len = 4 + length;
        while self.read_buf.len() < total_len {
            let n = timeout(READ_TIMEOUT, self.stream.read_buf(&mut self.read_buf))
                .await
                .map_err(|_| PeerError::Timeout)??;

            if n == 0 {
                return Err(PeerError::ConnectionClosed);
            }
        }

        let data = self.read_buf.split_to(total_len);
        Message::decode(data.freeze())
    }

    pub fn into_inner(self) -> TcpStream {
        self.stream
    }

    pub fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
        self.stream.peer_addr()
    }
}