use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::TcpStream;
use crate::error::{Error, ErrorKind};
use crate::peer::{Handshake, PeerId, PeerMessage, PeerState, decode, encode};
/// Upper bound applied separately to each timed handshake step: opening the
/// TCP connection and reading the remote peer's 68-byte handshake.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// A TCP connection to a single remote peer that has completed the
/// BitTorrent handshake.
///
/// Instances are produced by `PeerConnection::connect`; afterwards messages are
/// exchanged with `send` and `recv`.
pub struct PeerConnection {
    /// Socket to the peer. Reads go through the outer `BufReader`; writes go
    /// through the inner `BufWriter` (reached with `get_mut`) and are flushed
    /// explicitly after each logical write.
    stream: BufReader<BufWriter<TcpStream>>,
    /// Current protocol state, exposed through `state` / `set_state`.
    state: PeerState,
    /// Peer id announced by the remote side in its handshake; `None` only while
    /// the connection is still being set up.
    remote_peer_id: Option<PeerId>,
    /// Info hash of the torrent this connection was opened for.
    // Never read in this file after construction; NOTE(review): presumably kept
    // for future use — confirm before relying on it or removing it.
    #[allow(dead_code)]
    info_hash: [u8; 20],
    /// Our own peer id, as sent in the outgoing handshake.
    // Never read in this file after construction; NOTE(review): presumably kept
    // for future use — confirm before relying on it or removing it.
    #[allow(dead_code)]
    our_peer_id: PeerId,
}
impl PeerConnection {
pub async fn connect(
addr: SocketAddr,
info_hash: [u8; 20],
our_peer_id: PeerId,
) -> Result<Self, Error> {
tracing::debug!("connecting to peer {}", addr);
let raw_stream =
match tokio::time::timeout(HANDSHAKE_TIMEOUT, TcpStream::connect(addr)).await {
Ok(Ok(s)) => s,
_ => return Err(Error::new(ErrorKind::PeerConnectionClosed)),
};
let stream = BufReader::new(BufWriter::new(raw_stream));
let mut conn = PeerConnection {
stream,
state: PeerState::Handshake,
info_hash,
our_peer_id,
remote_peer_id: None,
};
let handshake = Handshake::new(info_hash, our_peer_id.0);
let handshake_bytes = handshake.to_bytes();
if let Err(e) = conn.stream.get_mut().write_all(&handshake_bytes).await {
return Err(Error::with_source(ErrorKind::PeerConnectionClosed, e));
}
if let Err(e) = conn.stream.get_mut().flush().await {
return Err(Error::with_source(ErrorKind::PeerConnectionClosed, e));
}
let mut buf = [0u8; 68];
match tokio::time::timeout(HANDSHAKE_TIMEOUT, read_exact(&mut conn, &mut buf)).await {
Ok(Ok(())) => {}
_ => return Err(Error::new(ErrorKind::PeerConnectionClosed)),
};
let remote_handshake = Handshake::from_bytes(&buf)?;
if remote_handshake.info_hash != info_hash {
return Err(Error::new(ErrorKind::PeerInvalidHandshake));
}
conn.remote_peer_id = Some(PeerId(remote_handshake.peer_id));
conn.state = PeerState::Init;
tracing::info!("handshake complete with {}", addr);
Ok(conn)
}
pub async fn send(&mut self, msg: &PeerMessage) -> Result<(), Error> {
tracing::trace!("sending {:?} to peer", msg);
let data = encode(msg);
if let Err(e) = self.stream.get_mut().write_all(&data).await {
return Err(Error::with_source(ErrorKind::PeerConnectionClosed, e));
}
if let Err(e) = self.stream.get_mut().flush().await {
return Err(Error::with_source(ErrorKind::PeerConnectionClosed, e));
}
Ok(())
}
pub async fn recv(&mut self) -> Result<PeerMessage, Error> {
let mut len_buf = [0u8; 4];
read_exact(self, &mut len_buf).await?;
let len = u32::from_be_bytes(len_buf);
if len == 0 {
tracing::trace!("received KeepAlive from peer");
return Ok(PeerMessage::KeepAlive);
}
let mut msg_buf = vec![0u8; len as usize];
read_exact(self, &mut msg_buf).await?;
let mut full_msg = len_buf.to_vec();
full_msg.extend_from_slice(&msg_buf);
decode(&full_msg)
}
pub fn state(&self) -> PeerState {
self.state
}
pub fn set_state(&mut self, state: PeerState) {
self.state = state;
}
pub fn remote_peer_id(&self) -> Option<PeerId> {
self.remote_peer_id
}
}
async fn read_exact(conn: &mut PeerConnection, buf: &mut [u8]) -> Result<(), Error> {
if let Err(e) = conn.stream.read_exact(buf).await {
return Err(Error::with_source(ErrorKind::PeerConnectionClosed, e));
}
Ok(())
}