use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::TcpStream;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::sync::Mutex;
use crate::error::{Error, ErrorKind};
use super::{Handshake, PeerId, PeerMessage, PeerState, decode, encode};
/// Upper bound applied separately to each step of `PeerConnection::connect`:
/// the TCP connect, the handshake write, the flush, and the handshake read.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);
/// Largest length prefix (in bytes, 2 MiB) that `PeerConnection::recv` accepts
/// before rejecting the message without reading its payload.
const MAX_MESSAGE_SIZE: u32 = 2 * 1024 * 1024;
/// Upper bound applied separately to reading the 4-byte length prefix and to
/// reading the message payload in `PeerConnection::recv`.
const MESSAGE_READ_TIMEOUT: Duration = Duration::from_secs(60);
/// Upper bound applied separately to the write and to the flush in
/// `PeerConnection::send`.
const MESSAGE_WRITE_TIMEOUT: Duration = Duration::from_secs(30);
/// A TCP connection to a remote peer on which the handshake has already been
/// exchanged.
///
/// The read and write halves sit behind separate async mutexes, so `send` and
/// `recv` take `&self` and do not block each other.
pub struct PeerConnection {
/// Buffered read half of the TCP stream; locked for the duration of `recv`.
reader: Mutex<BufReader<OwnedReadHalf>>,
/// Buffered write half of the TCP stream; locked for the duration of `send`.
writer: Mutex<BufWriter<OwnedWriteHalf>>,
/// State exposed through `state` / `set_state`; `connect` initialises it to
/// `PeerState::Init`.
state: PeerState,
/// Peer id taken from the remote handshake; `connect` always stores `Some`.
remote_peer_id: Option<PeerId>,
/// The eight reserved bytes copied from the remote handshake, queried by
/// `remote_has_extension`.
remote_reserved: [u8; 8],
}
impl PeerConnection {
pub async fn connect(
addr: SocketAddr, info_hash: [u8; 20], our_peer_id: PeerId,
) -> Result<Self, Error> {
tracing::debug!("connecting to peer {}", addr);
let mut raw_stream =
match tokio::time::timeout(HANDSHAKE_TIMEOUT, TcpStream::connect(addr)).await {
Ok(Ok(s)) => s,
_ => return Err(Error::new(ErrorKind::PeerConnectionClosed)),
};
let mut handshake = Handshake::with_extensions(info_hash, our_peer_id.0, &[63]);
handshake.set_reserved_byte(5, handshake.reserved[5] | 0x10);
let handshake_bytes = handshake.to_bytes();
if let Err(e) =
tokio::time::timeout(HANDSHAKE_TIMEOUT, raw_stream.write_all(&handshake_bytes)).await
{
return Err(Error::with_source(ErrorKind::PeerConnectionClosed, e));
}
if let Err(e) = tokio::time::timeout(HANDSHAKE_TIMEOUT, raw_stream.flush()).await {
return Err(Error::with_source(ErrorKind::PeerConnectionClosed, e));
}
let mut buf = [0u8; 68];
match tokio::time::timeout(HANDSHAKE_TIMEOUT, raw_stream.read_exact(&mut buf)).await {
Ok(Ok(_n)) => {}
_ => return Err(Error::new(ErrorKind::PeerConnectionClosed)),
};
let remote_handshake = Handshake::from_bytes(&buf)?;
if remote_handshake.info_hash != info_hash {
return Err(Error::new(ErrorKind::PeerInvalidHandshake));
}
let remote_reserved = remote_handshake.reserved;
let (read_half, write_half) = raw_stream.into_split();
tracing::info!("handshake complete with {}", addr);
Ok(PeerConnection {
reader: Mutex::new(BufReader::new(read_half)),
writer: Mutex::new(BufWriter::new(write_half)),
state: PeerState::Init,
remote_peer_id: Some(PeerId(remote_handshake.peer_id)),
remote_reserved,
})
}
pub async fn send(&self, msg: &PeerMessage) -> Result<(), Error> {
tracing::trace!("sending {:?} to peer", msg);
let data = encode(msg);
let mut writer = self.writer.lock().await;
tokio::time::timeout(MESSAGE_WRITE_TIMEOUT, writer.write_all(&data))
.await
.map_err(|_| Error::new(ErrorKind::PeerConnectionClosed))?
.map_err(|e| Error::with_source(ErrorKind::PeerConnectionClosed, e))?;
tokio::time::timeout(MESSAGE_WRITE_TIMEOUT, writer.flush())
.await
.map_err(|_| Error::new(ErrorKind::PeerConnectionClosed))?
.map_err(|e| Error::with_source(ErrorKind::PeerConnectionClosed, e))?;
Ok(())
}
/// Reads one length-prefixed message from the peer and decodes it.
///
/// A 4-byte big-endian length is read first. A length of zero is returned as
/// [`PeerMessage::KeepAlive`] without reading further. Otherwise the payload
/// is read and the complete frame (length prefix included) is passed to
/// `decode`. The prefix read and the payload read are each bounded by
/// `MESSAGE_READ_TIMEOUT`.
///
/// # Errors
///
/// Returns [`ErrorKind::PeerConnectionClosed`] if a read times out, a read
/// fails with an I/O error (attached as the source), or the announced length
/// exceeds `MAX_MESSAGE_SIZE`. Errors from `decode` are returned unchanged.
pub async fn recv(&self) -> Result<PeerMessage, Error> {
    let mut reader = self.reader.lock().await;

    let mut len_buf = [0u8; 4];
    tokio::time::timeout(MESSAGE_READ_TIMEOUT, reader.read_exact(&mut len_buf))
        .await
        .map_err(|_| Error::new(ErrorKind::PeerConnectionClosed))?
        .map_err(|e| Error::with_source(ErrorKind::PeerConnectionClosed, e))?;

    let len = u32::from_be_bytes(len_buf);
    if len == 0 {
        tracing::trace!("received KeepAlive from peer");
        return Ok(PeerMessage::KeepAlive);
    }
    // Checked before allocating so the remote cannot force a huge buffer.
    if len > MAX_MESSAGE_SIZE {
        return Err(Error::new(ErrorKind::PeerConnectionClosed));
    }

    // Allocate the whole frame once and read the payload directly into its
    // tail, instead of reading into a scratch buffer and copying it over.
    let prefix_len = len_buf.len();
    let mut full_msg = vec![0u8; prefix_len + len as usize];
    full_msg[..prefix_len].copy_from_slice(&len_buf);
    tokio::time::timeout(
        MESSAGE_READ_TIMEOUT,
        reader.read_exact(&mut full_msg[prefix_len..]),
    )
    .await
    .map_err(|_| Error::new(ErrorKind::PeerConnectionClosed))?
    .map_err(|e| Error::with_source(ErrorKind::PeerConnectionClosed, e))?;

    decode(&full_msg)
}
/// Returns a copy of the state currently recorded for this connection.
pub fn state(&self) -> PeerState {
self.state
}
/// Overwrites the recorded connection state with `state`.
pub fn set_state(&mut self, state: PeerState) {
self.state = state;
}
/// Returns the peer id recorded from the remote handshake, if any.
pub fn remote_peer_id(&self) -> Option<PeerId> {
self.remote_peer_id
}
/// Reports whether the remote handshake set reserved bit `bit`.
///
/// Bits are numbered most-significant first across the eight reserved bytes:
/// bit 0 is the top bit of byte 0 and bit 63 is the lowest bit of byte 7.
/// Any `bit` of 64 or more yields `false`.
pub fn remote_has_extension(&self, bit: usize) -> bool {
    // `get` returns `None` for byte indexes past the 8-byte array, which
    // covers every out-of-range `bit`.
    match self.remote_reserved.get(bit / 8) {
        Some(&octet) => octet & (0x80u8 >> (bit % 8)) != 0,
        None => false,
    }
}
/// Returns the eight reserved bytes copied from the remote peer's handshake.
pub fn remote_reserved(&self) -> &[u8; 8] {
&self.remote_reserved
}
}