use bytes::{Buf, BufMut, Bytes, BytesMut};
use irontide_core::Id20;
use crate::error::{Error, Result};
/// Protocol identifier string that follows the one-byte length prefix in every handshake.
const PROTOCOL: &[u8] = b"BitTorrent protocol";
/// Total wire size of a handshake, in bytes:
/// `pstrlen` (1) + protocol string (19) + reserved flags (8) + info-hash (20) + peer ID (20).
pub const HANDSHAKE_SIZE: usize = 1 + PROTOCOL.len() + 8 + 20 + 20;
/// A BitTorrent peer-wire handshake message.
///
/// On the wire it is `HANDSHAKE_SIZE` bytes long: a length-prefixed protocol
/// string, followed by the three fields below in declaration order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Handshake {
/// Eight reserved bytes whose individual bits advertise optional protocol
/// extensions (BEP 10 extension protocol, BEP 5 DHT, BEP 6 Fast extension).
pub reserved: [u8; 8],
/// 20-byte info-hash identifying the torrent this connection is for.
pub info_hash: Id20,
/// 20-byte ID of the peer sending this handshake.
pub peer_id: Id20,
}
impl Handshake {
#[must_use]
pub fn new(info_hash: Id20, peer_id: Id20) -> Self {
let mut reserved = [0u8; 8];
reserved[5] |= 0x10;
Self {
reserved,
info_hash,
peer_id,
}
}
#[must_use]
pub fn supports_extensions(&self) -> bool {
self.reserved[5] & 0x10 != 0
}
#[must_use]
pub fn supports_dht(&self) -> bool {
self.reserved[7] & 0x01 != 0
}
#[must_use]
pub fn with_dht(mut self) -> Self {
self.reserved[7] |= 0x01;
self
}
#[must_use]
pub fn supports_fast(&self) -> bool {
self.reserved[7] & 0x04 != 0
}
#[must_use]
pub fn with_fast(mut self) -> Self {
self.reserved[7] |= 0x04;
self
}
#[must_use]
pub fn to_bytes(&self) -> Bytes {
let mut buf = BytesMut::with_capacity(HANDSHAKE_SIZE);
buf.put_u8(19);
buf.put_slice(PROTOCOL);
buf.put_slice(&self.reserved);
buf.put_slice(self.info_hash.as_bytes());
buf.put_slice(self.peer_id.as_bytes());
buf.freeze()
}
pub fn from_bytes(mut data: &[u8]) -> Result<Self> {
if data.len() < HANDSHAKE_SIZE {
return Err(Error::InvalidHandshake(format!(
"need {} bytes, got {}",
HANDSHAKE_SIZE,
data.len()
)));
}
let pstrlen = data.get_u8();
if pstrlen != 19 {
return Err(Error::InvalidHandshake(format!(
"pstrlen {pstrlen}, expected 19"
)));
}
let pstr = &data[..19];
if pstr != PROTOCOL {
return Err(Error::InvalidHandshake("wrong protocol string".into()));
}
data.advance(19);
let mut reserved = [0u8; 8];
reserved.copy_from_slice(&data[..8]);
data.advance(8);
let info_hash =
Id20::from_bytes(&data[..20]).map_err(|e| Error::InvalidHandshake(e.to_string()))?;
data.advance(20);
let peer_id =
Id20::from_bytes(&data[..20]).map_err(|e| Error::InvalidHandshake(e.to_string()))?;
Ok(Self {
reserved,
info_hash,
peer_id,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default handshake and returns its wire bytes as a mutable buffer.
    fn raw_handshake() -> Vec<u8> {
        Handshake::new(Id20::ZERO, Id20::ZERO).to_bytes().to_vec()
    }

    #[test]
    fn handshake_round_trip() {
        let info_hash = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
        let peer_id = Id20::from_hex("0102030405060708091011121314151617181920").unwrap();
        let hs = Handshake::new(info_hash, peer_id);
        assert!(hs.supports_extensions());
        let bytes = hs.to_bytes();
        assert_eq!(bytes.len(), HANDSHAKE_SIZE);
        let parsed = Handshake::from_bytes(&bytes).unwrap();
        assert_eq!(hs, parsed);
    }

    #[test]
    fn handshake_dht_flag() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_dht();
        assert!(hs.supports_dht());
        assert!(hs.supports_extensions());
        let parsed = Handshake::from_bytes(&hs.to_bytes()).unwrap();
        assert!(parsed.supports_dht());
    }

    #[test]
    fn ext_handshake_reserved_bit_position() {
        let mut expected_reserved = [0u8; 8];
        expected_reserved[5] = 0x10;
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO);
        assert_eq!(
            hs.reserved, expected_reserved,
            "Handshake::new() reserved field must match BEP 10 spec"
        );
        assert_eq!(
            hs.reserved[5] & 0x10,
            0x10,
            "BEP 10 extension bit must be at reserved[5] & 0x10"
        );
        assert!(hs.supports_extensions());
        for (i, &byte) in hs.reserved.iter().enumerate() {
            if i == 5 {
                assert_eq!(byte, 0x10, "byte 5 should be exactly 0x10");
            } else {
                assert_eq!(byte, 0, "byte {i} should be zero when only BEP 10 is set");
            }
        }
        let mut hs_no_ext = hs;
        hs_no_ext.reserved[5] &= !0x10;
        assert!(!hs_no_ext.supports_extensions());

        // BEP 10 names "the 20th bit from the right (counting from 0)": that is
        // byte 5 counted from the left, bit 4 within that byte, i.e. mask 0x10.
        let bit_index = 20;
        let byte_index_from_right = bit_index / 8;
        let byte_index = 7 - byte_index_from_right;
        let bit_within_byte = bit_index % 8;
        assert_eq!(byte_index, 5);
        assert_eq!(bit_within_byte, 4);
        assert_eq!(1u8 << bit_within_byte, 0x10);
    }

    #[test]
    fn handshake_too_short() {
        assert!(Handshake::from_bytes(&[0u8; 10]).is_err());
    }

    #[test]
    fn handshake_one_byte_short() {
        let raw = raw_handshake();
        assert!(Handshake::from_bytes(&raw[..HANDSHAKE_SIZE - 1]).is_err());
    }

    #[test]
    fn handshake_bad_pstrlen() {
        let mut raw = raw_handshake();
        raw[0] = 18;
        assert!(Handshake::from_bytes(&raw).is_err());
    }

    #[test]
    fn handshake_bad_protocol_string() {
        let mut raw = raw_handshake();
        // Lower-case the leading 'B' of "BitTorrent protocol".
        raw[1] = b'b';
        assert!(Handshake::from_bytes(&raw).is_err());
    }

    #[test]
    fn handshake_ignores_trailing_bytes() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_fast();
        let mut raw = hs.to_bytes().to_vec();
        raw.extend_from_slice(&[0xff; 4]);
        assert_eq!(Handshake::from_bytes(&raw).unwrap(), hs);
    }

    #[test]
    fn handshake_fast_flag() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_fast();
        assert!(hs.supports_fast());
        assert!(hs.supports_extensions());
        let hs2 = hs.with_dht();
        assert!(hs2.supports_fast());
        assert!(hs2.supports_dht());
        let parsed = Handshake::from_bytes(&hs2.to_bytes()).unwrap();
        assert!(parsed.supports_fast());
        assert!(parsed.supports_dht());
    }
}