use std::collections::HashSet;
use std::net::IpAddr;
use sha1::{Digest, Sha1};
/// Number of pieces requested from [`generate_allowed_fast_set`] by
/// [`FastExtension::compute_allowed_fast_set`].
const ALLOWED_FAST_SET_SIZE: usize = 10;
/// Per-peer bookkeeping for Fast Extension messages: allowed-fast pieces in
/// each direction, suggested pieces, and the have-all / have-none flags.
#[derive(Debug, Clone, Default)]
pub struct FastExtensionState {
    /// Pieces this side may request even while choked
    /// (see [`FastExtensionState::can_request_while_choked`]).
    pub allowed_fast_incoming: HashSet<u32>,
    /// Pieces whose requests are served even while the peer is choked
    /// (see [`FastExtensionState::should_serve_choked_request`]).
    pub allowed_fast_outgoing: HashSet<u32>,
    /// Suggested pieces in the order first received, without duplicates.
    pub suggested_pieces: Vec<u32>,
    /// Set by [`FastExtensionState::set_have_all`]; the setters keep this and
    /// `peer_has_none` from both being true.
    pub peer_has_all: bool,
    /// Set by [`FastExtensionState::set_have_none`]; the setters keep this and
    /// `peer_has_all` from both being true.
    pub peer_has_none: bool,
}

impl FastExtensionState {
    /// Returns an empty state: both allowed-fast sets and the suggestion list
    /// are empty, and neither have-all nor have-none is flagged.
    pub fn new() -> Self {
        Default::default()
    }

    /// Records `piece` as requestable while we are choked.
    pub fn add_allowed_fast_incoming(&mut self, piece: u32) {
        let _ = self.allowed_fast_incoming.insert(piece);
    }

    /// Records `piece` as servable while the peer is choked.
    pub fn add_allowed_fast_outgoing(&mut self, piece: u32) {
        let _ = self.allowed_fast_outgoing.insert(piece);
    }

    /// Whether `piece` is in the incoming allowed-fast set.
    pub fn can_request_while_choked(&self, piece: u32) -> bool {
        HashSet::contains(&self.allowed_fast_incoming, &piece)
    }

    /// Whether `piece` is in the outgoing allowed-fast set.
    pub fn should_serve_choked_request(&self, piece: u32) -> bool {
        HashSet::contains(&self.allowed_fast_outgoing, &piece)
    }

    /// Appends `piece` to the suggestion list unless it is already present.
    pub fn add_suggested(&mut self, piece: u32) {
        if self.suggested_pieces.contains(&piece) {
            return;
        }
        self.suggested_pieces.push(piece);
    }

    /// Flags the peer as having every piece, clearing have-none.
    pub fn set_have_all(&mut self) {
        self.peer_has_none = false;
        self.peer_has_all = true;
    }

    /// Flags the peer as having no pieces, clearing have-all.
    pub fn set_have_none(&mut self) {
        self.peer_has_all = false;
        self.peer_has_none = true;
    }

    /// Resets every field to its empty/false value, keeping collection capacity.
    pub fn clear(&mut self) {
        self.peer_has_none = false;
        self.peer_has_all = false;
        self.suggested_pieces.clear();
        self.allowed_fast_outgoing.clear();
        self.allowed_fast_incoming.clear();
    }
}
/// Generates the BEP 6 "allowed fast" piece set for a peer.
///
/// The set is derived deterministically from the peer's address and the
/// torrent's info-hash:
///
/// 1. `x` starts as a four-byte address prefix followed by `info_hash`.
/// 2. `x` is repeatedly replaced by `SHA1(x)`.
/// 3. Each 20-byte digest is read as five big-endian `u32`s, and each is
///    reduced modulo `num_pieces`.
/// 4. Indices not seen before are collected in order until enough distinct
///    pieces have been chosen.
///
/// Address prefix:
/// - IPv4: the last octet is zeroed (`ip & 0xFFFFFF00`).
/// - IPv4-mapped IPv6 (`::ffff:a.b.c.d`): treated exactly like the embedded
///   IPv4 address.
/// - Any other IPv6 address: its first four bytes are used, which BEP 6
///   does not specify.
///
/// # Arguments
/// * `info_hash` - the torrent's 20-byte info-hash.
/// * `peer_ip` - address of the peer the set is generated for.
/// * `num_pieces` - number of pieces in the torrent; indices are `< num_pieces`.
/// * `set_size` - requested number of distinct indices.
///
/// # Returns
/// The piece indices in generation order, with no duplicates.
/// - Empty when `num_pieces` or `set_size` is 0.
/// - Otherwise exactly `min(set_size, num_pieces)` entries. A torrent with
///   fewer pieces than `set_size` cannot supply that many distinct indices,
///   so the request is capped.
pub fn generate_allowed_fast_set(
    info_hash: &[u8; 20],
    peer_ip: IpAddr,
    num_pieces: u32,
    set_size: usize,
) -> Vec<u32> {
    // Only `num_pieces` distinct indices exist. Asking for more can never be
    // satisfied, and the loop below would spin forever.
    // (`u32 -> usize` is lossless on 32- and 64-bit targets.)
    let target = set_size.min(num_pieces as usize);
    if target == 0 {
        // Also guards the `% num_pieces` below against division by zero.
        return Vec::new();
    }

    let ip_bytes: [u8; 4] = match peer_ip {
        IpAddr::V4(ip) => {
            // BEP 6: x = 0xFFFFFF00 & ip, i.e. keep only the /24 network prefix.
            let [a, b, c, _] = ip.octets();
            [a, b, c, 0]
        }
        IpAddr::V6(ip) => {
            let o = ip.octets();
            let is_ipv4_mapped =
                o[..10].iter().all(|&byte| byte == 0) && o[10] == 0xff && o[11] == 0xff;
            if is_ipv4_mapped {
                // An IPv4 peer seen through a dual-stack socket: hash it like the
                // plain IPv4 address instead of collapsing every such peer to 0.0.0.0.
                [o[12], o[13], o[14], 0]
            } else {
                // Native IPv6 is not covered by BEP 6; use the first four bytes.
                [o[0], o[1], o[2], o[3]]
            }
        }
    };

    let mut x = Vec::with_capacity(ip_bytes.len() + info_hash.len());
    x.extend_from_slice(&ip_bytes);
    x.extend_from_slice(info_hash);

    let mut allowed_set = Vec::with_capacity(target);
    while allowed_set.len() < target {
        let mut hasher = Sha1::new();
        hasher.update(&x);
        let hash = hasher.finalize();
        // A 20-byte SHA-1 digest yields exactly five 4-byte candidates.
        for chunk in hash.chunks_exact(4) {
            if allowed_set.len() >= target {
                break;
            }
            let index = u32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) % num_pieces;
            if !allowed_set.contains(&index) {
                allowed_set.push(index);
            }
        }
        // BEP 6 chains the hash: the next round hashes the previous digest.
        x.clear();
        x.extend_from_slice(&hash);
    }
    allowed_set
}
/// Holds the allowed-fast piece set produced by
/// [`FastExtension::compute_allowed_fast_set`].
#[derive(Debug, Clone)]
pub struct FastExtension {
    /// Piece indices in the order returned by [`generate_allowed_fast_set`].
    allowed_fast_set: Vec<u32>,
}

impl FastExtension {
    /// Creates an extension with an empty allowed-fast set.
    pub fn new() -> Self {
        Self {
            allowed_fast_set: Vec::new(),
        }
    }

    /// Recomputes the allowed-fast set for `peer_ip`, replacing any previous set.
    ///
    /// Delegates to [`generate_allowed_fast_set`], requesting up to
    /// `ALLOWED_FAST_SET_SIZE` pieces. The result is empty when `num_pieces` is 0.
    pub fn compute_allowed_fast_set(
        &mut self,
        info_hash: &[u8; 20],
        peer_ip: IpAddr,
        num_pieces: u32,
    ) {
        self.allowed_fast_set =
            generate_allowed_fast_set(info_hash, peer_ip, num_pieces, ALLOWED_FAST_SET_SIZE);
    }

    /// Returns the most recently computed allowed-fast set (empty if never computed).
    pub fn allowed_fast_set(&self) -> &[u32] {
        &self.allowed_fast_set
    }

    /// Whether `piece` is in the current allowed-fast set.
    pub fn is_allowed_fast(&self, piece: u32) -> bool {
        self.allowed_fast_set.contains(&piece)
    }
}

impl Default for FastExtension {
    /// Equivalent to [`FastExtension::new`].
    fn default() -> Self {
        Self::new()
    }
}