use std::net::SocketAddr;
use std::time::Instant;
use bytes::Bytes;
use tokio::net::TcpStream;
use super::bitfield::Bitfield;
use super::choking::ChokingState;
use super::error::PeerError;
use super::extension::ExtensionHandshake;
use super::message::{validate_hash_request, Handshake, Message};
use super::peer_id::PeerId;
use super::transport::PeerTransport;
/// Lifecycle stage of a peer connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PeerState {
    /// The connection is being established.
    Connecting,
    /// The handshake exchange is in progress.
    Handshaking,
    /// The handshake finished and the connection is usable.
    Connected,
    /// The connection has been torn down.
    Disconnected,
}
/// Protocol version a connection is operating in.
///
/// The default is [`ProtocolMode::V1`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ProtocolMode {
    /// Version 1 mode (the default).
    #[default]
    V1,
    /// Version 2 mode.
    V2,
}
/// State and transport for a single peer connection.
///
/// Instances are created by the `connect*` / `accept*` constructors once the
/// handshake exchange has succeeded.
pub struct PeerConnection {
/// Address of the remote peer (the address dialled, or the accepted
/// stream's `peer_addr()`).
pub addr: SocketAddr,
/// Result of `PeerId::from_bytes` applied to the peer id in the remote handshake.
pub peer_id: Option<PeerId>,
/// Lifecycle state; `Connected` after construction, `Disconnected` after
/// `disconnect()`.
pub state: PeerState,
/// Choke / interest flags; the peer-side flags are updated by `receive()`,
/// the local-side flags by `set_interested()` / `set_choking()`.
pub choking: ChokingState,
/// Initialised to `None` by the constructors; not written by any method of
/// this type's `impl` block.
pub bitfield: Option<Bitfield>,
/// Initialised to `None` by the constructors; not written by any method of
/// this type's `impl` block.
pub extension_handshake: Option<ExtensionHandshake>,
/// Value of `supports_fast_extension()` on the remote handshake.
pub supports_fast: bool,
/// Value of `supports_extension_protocol()` on the remote handshake.
pub supports_extension: bool,
/// Value of `supports_v2()` on the remote handshake.
pub supports_v2: bool,
/// Current protocol mode; starts as `V2` only when we advertised v2 and the
/// remote handshake reports v2 support.
pub protocol_mode: ProtocolMode,
/// Instant captured when the constructor finished the handshake exchange.
pub connected_at: Instant,
/// Instant of the most recent successful `receive()`; equals `connected_at`
/// until the first message arrives.
pub last_message_at: Instant,
/// Sum of the `data` lengths of all `Piece` messages returned by `receive()`.
pub bytes_downloaded: u64,
/// Sum of the `data` lengths of all `Piece` messages successfully passed to `send()`.
pub bytes_uploaded: u64,
/// Underlying transport; `None` once `disconnect()` has been called.
transport: Option<PeerTransport>,
}
impl PeerConnection {
/// Connects to `addr` and performs the handshake without advertising v2.
///
/// Shorthand for [`Self::connect_with_mode`] with `advertise_v2 = false`.
///
/// # Errors
///
/// Returns whatever error [`Self::connect_with_mode`] returns.
pub async fn connect(
    addr: SocketAddr,
    info_hash: [u8; 20],
    our_peer_id: [u8; 20],
) -> Result<Self, PeerError> {
    let advertise_v2 = false;
    Self::connect_with_mode(addr, info_hash, our_peer_id, advertise_v2).await
}
/// Opens an outgoing connection to `addr` and runs the handshake exchange.
///
/// Our handshake is sent first (built with `Handshake::new_v2` when
/// `advertise_v2` is set, otherwise with `Handshake::new`), then the peer's
/// handshake is read and its info hash compared against `info_hash`.
///
/// The resulting connection starts in [`ProtocolMode::V2`] only when we
/// advertised v2 *and* the peer's handshake reports v2 support.
///
/// # Errors
///
/// * [`PeerError::InfoHashMismatch`] if the peer answered with a different
///   info hash.
/// * Any error from the TCP connect or from sending/receiving the handshake,
///   propagated with `?`.
pub async fn connect_with_mode(
    addr: SocketAddr,
    info_hash: [u8; 20],
    our_peer_id: [u8; 20],
    advertise_v2: bool,
) -> Result<Self, PeerError> {
    let mut transport = PeerTransport::new(TcpStream::connect(addr).await?);

    // Outgoing side: we speak first.
    let ours = if advertise_v2 {
        Handshake::new_v2(info_hash, our_peer_id)
    } else {
        Handshake::new(info_hash, our_peer_id)
    };
    transport.send_handshake(&ours).await?;

    let theirs = transport.receive_handshake().await?;
    if theirs.info_hash != info_hash {
        return Err(PeerError::InfoHashMismatch);
    }

    let supports_v2 = theirs.supports_v2();
    let protocol_mode = match (advertise_v2, supports_v2) {
        (true, true) => ProtocolMode::V2,
        _ => ProtocolMode::V1,
    };

    let handshake_done = Instant::now();
    Ok(Self {
        addr,
        peer_id: PeerId::from_bytes(&theirs.peer_id),
        state: PeerState::Connected,
        choking: ChokingState::default(),
        bitfield: None,
        extension_handshake: None,
        supports_fast: theirs.supports_fast_extension(),
        supports_extension: theirs.supports_extension_protocol(),
        supports_v2,
        protocol_mode,
        connected_at: handshake_done,
        last_message_at: handshake_done,
        bytes_downloaded: 0,
        bytes_uploaded: 0,
        transport: Some(transport),
    })
}
/// Completes the handshake on an already-accepted `stream` without
/// advertising v2.
///
/// Shorthand for [`Self::accept_with_mode`] with `advertise_v2 = false`.
///
/// # Errors
///
/// Returns whatever error [`Self::accept_with_mode`] returns.
pub async fn accept(
    stream: TcpStream,
    info_hash: [u8; 20],
    our_peer_id: [u8; 20],
) -> Result<Self, PeerError> {
    let advertise_v2 = false;
    Self::accept_with_mode(stream, info_hash, our_peer_id, advertise_v2).await
}
/// Runs the handshake exchange on an incoming, already-accepted `stream`.
///
/// The peer's handshake is read first and its info hash compared against
/// `info_hash`; only then is our handshake sent (built with
/// `Handshake::new_v2` when `advertise_v2` is set, otherwise with
/// `Handshake::new`).
///
/// The resulting connection starts in [`ProtocolMode::V2`] only when we
/// advertised v2 *and* the peer's handshake reports v2 support.
///
/// # Errors
///
/// * [`PeerError::InfoHashMismatch`] if the peer's info hash differs; in that
///   case no handshake is sent back.
/// * Any error from `peer_addr()` or from receiving/sending the handshake,
///   propagated with `?`.
pub async fn accept_with_mode(
    stream: TcpStream,
    info_hash: [u8; 20],
    our_peer_id: [u8; 20],
    advertise_v2: bool,
) -> Result<Self, PeerError> {
    let addr = stream.peer_addr()?;
    let mut transport = PeerTransport::new(stream);

    // Incoming side: the remote peer speaks first.
    let theirs = transport.receive_handshake().await?;
    if theirs.info_hash != info_hash {
        return Err(PeerError::InfoHashMismatch);
    }

    let reply = if advertise_v2 {
        Handshake::new_v2(info_hash, our_peer_id)
    } else {
        Handshake::new(info_hash, our_peer_id)
    };
    transport.send_handshake(&reply).await?;

    let supports_v2 = theirs.supports_v2();
    let protocol_mode = match (advertise_v2, supports_v2) {
        (true, true) => ProtocolMode::V2,
        _ => ProtocolMode::V1,
    };

    let handshake_done = Instant::now();
    Ok(Self {
        addr,
        peer_id: PeerId::from_bytes(&theirs.peer_id),
        state: PeerState::Connected,
        choking: ChokingState::default(),
        bitfield: None,
        extension_handshake: None,
        supports_fast: theirs.supports_fast_extension(),
        supports_extension: theirs.supports_extension_protocol(),
        supports_v2,
        protocol_mode,
        connected_at: handshake_done,
        last_message_at: handshake_done,
        bytes_downloaded: 0,
        bytes_uploaded: 0,
        transport: Some(transport),
    })
}
pub async fn send(&mut self, message: Message) -> Result<(), PeerError> {
if let Some(ref mut transport) = self.transport {
transport.send_message(&message).await?;
if let Message::Piece { ref data, .. } = message {
self.bytes_uploaded += data.len() as u64;
}
Ok(())
} else {
Err(PeerError::ConnectionClosed)
}
}
pub async fn receive(&mut self) -> Result<Message, PeerError> {
if let Some(ref mut transport) = self.transport {
let message = transport.receive_message().await?;
self.last_message_at = Instant::now();
match &message {
Message::Choke => self.choking.peer_choking = true,
Message::Unchoke => self.choking.peer_choking = false,
Message::Interested => self.choking.peer_interested = true,
Message::NotInterested => self.choking.peer_interested = false,
Message::Piece { data, .. } => {
self.bytes_downloaded += data.len() as u64;
}
_ => {}
}
Ok(message)
} else {
Err(PeerError::ConnectionClosed)
}
}
/// Drops the transport (if any) and marks the connection as
/// [`PeerState::Disconnected`].
///
/// Subsequent calls to [`Self::send`] / [`Self::receive`] fail with
/// `PeerError::ConnectionClosed`. Calling this more than once is harmless.
pub fn disconnect(&mut self) {
    drop(self.transport.take());
    self.state = PeerState::Disconnected;
}
/// Returns `true` when the state is [`PeerState::Connected`] and a transport
/// is still held.
pub fn is_connected(&self) -> bool {
    let has_transport = self.transport.is_some();
    matches!(self.state, PeerState::Connected) && has_transport
}
/// Returns `true` when all of the following hold: the connection is live
/// (see [`Self::is_connected`]), the peer is not choking us, and we are
/// interested in the peer.
pub fn can_request(&self) -> bool {
    if !self.is_connected() {
        return false;
    }
    let flags = &self.choking;
    flags.am_interested && !flags.peer_choking
}
/// Records whether we are interested in this peer.
///
/// Only the local `choking.am_interested` flag is updated; no message is sent.
pub fn set_interested(&mut self, interested: bool) {
    let flags = &mut self.choking;
    flags.am_interested = interested;
}
/// Records whether we are choking this peer.
///
/// Only the local `choking.am_choking` flag is updated; no message is sent.
pub fn set_choking(&mut self, choking: bool) {
    let flags = &mut self.choking;
    flags.am_choking = choking;
}
/// Returns `true` when the connection is currently in [`ProtocolMode::V2`].
pub fn is_v2_mode(&self) -> bool {
    matches!(self.protocol_mode, ProtocolMode::V2)
}
/// Returns `true` when the connection is currently in [`ProtocolMode::V1`].
pub fn is_v1_mode(&self) -> bool {
    matches!(self.protocol_mode, ProtocolMode::V1)
}
/// Switches the connection to [`ProtocolMode::V2`] if the peer's handshake
/// reported v2 support; otherwise leaves the mode untouched.
pub fn upgrade_to_v2(&mut self) {
    if !self.supports_v2 {
        return;
    }
    self.protocol_mode = ProtocolMode::V2;
}
/// Unconditionally sets the protocol mode to `mode`.
///
/// No check against `supports_v2` is performed here.
pub fn set_protocol_mode(&mut self, mode: ProtocolMode) {
    let current = &mut self.protocol_mode;
    *current = mode;
}
/// Returns `true` when hash messages may be sent on this connection: the
/// peer's handshake reported v2 support *and* the connection is currently in
/// [`ProtocolMode::V2`].
pub fn can_use_hash_messages(&self) -> bool {
    let in_v2_mode = matches!(self.protocol_mode, ProtocolMode::V2);
    self.supports_v2 && in_v2_mode
}
/// Sends a `HashRequest` message built from the given fields.
///
/// `length` and `index` are checked with `validate_hash_request` before
/// anything is sent.
///
/// # Errors
///
/// * [`PeerError::Protocol`] if [`Self::can_use_hash_messages`] is `false`.
/// * [`PeerError::InvalidMessage`] carrying the value returned by
///   `validate_hash_request` when that check rejects the request.
/// * Any error returned by [`Self::send`].
pub async fn send_hash_request(
    &mut self,
    pieces_root: [u8; 32],
    base_layer: u32,
    index: u32,
    length: u32,
    proof_layers: u32,
) -> Result<(), PeerError> {
    if !self.can_use_hash_messages() {
        return Err(PeerError::Protocol("v2 hash messages not supported".into()));
    }

    match validate_hash_request(length, index) {
        Some(reason) => Err(PeerError::InvalidMessage(reason.into())),
        None => {
            let request = Message::HashRequest {
                pieces_root,
                base_layer,
                index,
                length,
                proof_layers,
            };
            self.send(request).await
        }
    }
}
/// Sends a `Hashes` message built from the given fields.
///
/// `hashes` must contain exactly `(length + proof_layers) * 32` bytes, i.e.
/// one 32-byte hash per requested entry plus one per proof layer.
///
/// # Errors
///
/// * [`PeerError::Protocol`] if [`Self::can_use_hash_messages`] is `false`.
/// * [`PeerError::InvalidMessage`] if `hashes.len()` does not match the
///   expected byte count.
/// * Any error returned by [`Self::send`].
pub async fn send_hashes(
    &mut self,
    pieces_root: [u8; 32],
    base_layer: u32,
    index: u32,
    length: u32,
    proof_layers: u32,
    hashes: Bytes,
) -> Result<(), PeerError> {
    if !self.can_use_hash_messages() {
        return Err(PeerError::Protocol("v2 hash messages not supported".into()));
    }

    // Widen to u64 *before* adding: `length + proof_layers` evaluated in u32
    // panics in debug builds and wraps in release builds for large inputs,
    // which would produce a bogus expected length. In u64 the sum is at most
    // 2^33 and the product at most 2^38, so neither step can overflow.
    let expected_len = (u64::from(length) + u64::from(proof_layers)) * 32;
    // usize -> u64 is lossless on every target Rust supports (usize <= 64 bits).
    let actual_len = hashes.len() as u64;
    if actual_len != expected_len {
        return Err(PeerError::InvalidMessage(format!(
            "hash data length {} doesn't match expected {}",
            actual_len, expected_len
        )));
    }

    let message = Message::Hashes {
        pieces_root,
        base_layer,
        index,
        length,
        proof_layers,
        hashes,
    };
    self.send(message).await
}
/// Sends a `HashReject` message built from the given fields.
///
/// The fields are forwarded as-is; no validation beyond the v2 capability
/// check is performed here.
///
/// # Errors
///
/// * [`PeerError::Protocol`] if [`Self::can_use_hash_messages`] is `false`.
/// * Any error returned by [`Self::send`].
pub async fn send_hash_reject(
    &mut self,
    pieces_root: [u8; 32],
    base_layer: u32,
    index: u32,
    length: u32,
    proof_layers: u32,
) -> Result<(), PeerError> {
    let hash_messages_allowed = self.can_use_hash_messages();
    if !hash_messages_allowed {
        return Err(PeerError::Protocol("v2 hash messages not supported".into()));
    }

    self.send(Message::HashReject {
        pieces_root,
        base_layer,
        index,
        length,
        proof_layers,
    })
    .await
}
/// Returns `true` if `message` is one of the hash messages: `HashRequest`,
/// `Hashes`, or `HashReject`.
pub fn is_hash_message(message: &Message) -> bool {
    matches!(
        message,
        Message::HashRequest { .. }
            | Message::Hashes { .. }
            | Message::HashReject { .. }
    )
}
}