use std::collections::HashMap;
use std::collections::HashSet;
use std::net::SocketAddrV4;
use tracing::{debug, trace};
use super::{socket::KrpcSocket, ClosestNodes};
use crate::common::{FindNodeRequestArguments, GetPeersRequestArguments, GetValueRequestArguments};
use crate::{
common::{Id, Node, RequestSpecific, RequestTypeSpecific, MAX_BUCKET_SIZE_K},
rpc::Response,
};
/// State of one iterative DHT query: the request being sent, the candidate and responder
/// node sets, the transaction ids it issued, the addresses it already contacted, the
/// responses it collected, and tallied public-address votes.
#[derive(Debug)]
pub(crate) struct IterativeQuery {
/// Request sent (cloned) to every address this query visits.
pub request: RequestSpecific,
/// Candidate nodes; the first `MAX_BUCKET_SIZE_K` of `nodes()` are the ones visited.
// NOTE(review): presumably ordered by distance to the target — confirm in `ClosestNodes`.
closest: ClosestNodes,
/// Nodes registered as having responded to this query.
responders: ClosestNodes,
/// Transaction ids returned by the socket for every request this query sent.
inflight_requests: Vec<u32>,
/// Addresses this query has already sent requests to.
visited: HashSet<SocketAddrV4>,
/// Responses recorded for this query, in the order they were recorded.
responses: Vec<Response>,
/// Vote count per address; the most-voted address is reported as the best address.
public_address_votes: HashMap<SocketAddrV4, u16>,
}
/// The kinds of request an `IterativeQuery` can be driven by.
///
/// Each variant is converted into the `RequestTypeSpecific` variant of the same name when
/// the query is created.
#[derive(Debug)]
pub enum GetRequestSpecific {
/// Lookup keyed on `FindNodeRequestArguments::target`.
FindNode(FindNodeRequestArguments),
/// Lookup keyed on `GetPeersRequestArguments::info_hash`.
GetPeers(GetPeersRequestArguments),
/// Lookup keyed on `GetValueRequestArguments::target`.
GetValue(GetValueRequestArguments),
}
impl GetRequestSpecific {
pub fn target(&self) -> &Id {
match self {
GetRequestSpecific::FindNode(args) => &args.target,
GetRequestSpecific::GetPeers(args) => &args.info_hash,
GetRequestSpecific::GetValue(args) => &args.target,
}
}
}
impl IterativeQuery {
    /// Creates a query for `target` on behalf of `requester_id`.
    ///
    /// `request` is converted into the matching `RequestTypeSpecific` variant and stored as
    /// the request sent to every node this query visits. Both the candidate set and the
    /// responder set are created for `target`.
    pub fn new(requester_id: Id, target: Id, request: GetRequestSpecific) -> Self {
        let request_type = match request {
            GetRequestSpecific::FindNode(s) => RequestTypeSpecific::FindNode(s),
            GetRequestSpecific::GetPeers(s) => RequestTypeSpecific::GetPeers(s),
            GetRequestSpecific::GetValue(s) => RequestTypeSpecific::GetValue(s),
        };

        trace!(?target, ?request_type, "New Query");

        Self {
            request: RequestSpecific {
                requester_id,
                request_type,
            },
            closest: ClosestNodes::new(target),
            responders: ClosestNodes::new(target),
            inflight_requests: Vec::new(),
            visited: HashSet::new(),
            responses: Vec::new(),
            public_address_votes: HashMap::new(),
        }
    }

    /// Returns the id this query is looking up.
    pub fn target(&self) -> Id {
        self.responders.target()
    }

    /// Returns the candidate nodes collected so far.
    pub fn closest(&self) -> &ClosestNodes {
        &self.closest
    }

    /// Returns the nodes registered via [`add_responding_node`](Self::add_responding_node).
    pub fn responders(&self) -> &ClosestNodes {
        &self.responders
    }

    /// Returns every response recorded via [`response`](Self::response), in call order.
    pub fn responses(&self) -> &[Response] {
        &self.responses
    }

    /// Returns the address with the most votes, or `None` if no vote was recorded.
    ///
    /// Ties are resolved by `HashMap` iteration order, which is unspecified.
    pub fn best_address(&self) -> Option<SocketAddrV4> {
        let mut max = 0_u16;
        let mut best_addr = None;

        for (addr, count) in self.public_address_votes.iter() {
            // Strictly greater: among equal counts, the first one iterated wins.
            if *count > max {
                max = *count;
                best_addr = Some(*addr);
            }
        }

        best_addr
    }

    /// Starts the query by visiting the closest candidates that were not visited yet.
    pub fn start(&mut self, socket: &mut KrpcSocket) {
        self.visit_closest(socket);
    }

    /// Adds `node` to the candidate set considered by later visits.
    pub fn add_candidate(&mut self, node: Node) {
        self.closest.add(node);
    }

    /// Records one vote for `address`; see [`best_address`](Self::best_address).
    pub fn add_address_vote(&mut self, address: SocketAddrV4) {
        let votes = self.public_address_votes.entry(address).or_insert(0);
        // Saturate instead of `+= 1`: overflowing a `u16` panics in debug builds.
        *votes = votes.saturating_add(1);
    }

    /// Sends this query's request to `address` and marks the address as visited.
    ///
    /// A `Ping` with a random requester id is sent to the same address. Both transaction ids
    /// are tracked, so [`tick`](Self::tick) reports completion only after both have left the
    /// socket's inflight set.
    pub fn visit(&mut self, socket: &mut KrpcSocket, address: SocketAddrV4) {
        let tid = socket.request(address, self.request.clone());
        self.inflight_requests.push(tid);

        // NOTE(review): the purpose of the extra Ping is not visible here — presumably it
        // feeds `add_address_vote` via the response path; confirm with the caller.
        let ping = RequestSpecific {
            requester_id: Id::random(),
            request_type: RequestTypeSpecific::Ping,
        };
        let tid = socket.request(address, ping);
        self.inflight_requests.push(tid);

        self.visited.insert(address);
    }

    /// Returns `true` if `tid` is one of the transaction ids this query issued.
    ///
    /// Issued ids are never removed from the tracked list in this impl, so the check is a
    /// linear scan over every request the query has sent.
    pub fn inflight(&self, tid: u32) -> bool {
        self.inflight_requests.contains(&tid)
    }

    /// Records `node` as having responded to this query.
    pub fn add_responding_node(&mut self, node: Node) {
        self.responders.add(node);
    }

    /// Records a `response` received from `from`.
    pub fn response(&mut self, from: SocketAddrV4, response: Response) {
        let target = self.target();
        debug!(?target, ?response, ?from, "Query got response");
        // `response` is already owned, so it is moved in rather than cloned.
        self.responses.push(response);
    }

    /// Visits any new closest candidates, then reports whether the query has finished.
    ///
    /// Returns `true` when none of the transaction ids this query issued is still inflight
    /// on `socket`.
    pub fn tick(&mut self, socket: &mut KrpcSocket) -> bool {
        self.visit_closest(socket);

        let done = !self
            .inflight_requests
            .iter()
            .any(|&tid| socket.inflight(tid));

        if done {
            debug!(id=?self.target(), closest = ?self.closest.len(), visited = ?self.visited.len(), responders = ?self.responders.len(), "Done query");
        }

        done
    }

    /// Visits each of the first `MAX_BUCKET_SIZE_K` candidates whose address has not been
    /// visited yet.
    fn visit_closest(&mut self, socket: &mut KrpcSocket) {
        let candidates = self
            .closest
            .nodes()
            .iter()
            .take(MAX_BUCKET_SIZE_K)
            .map(|node| node.address())
            .collect::<Vec<_>>();

        for address in candidates {
            // Checked here rather than before collecting: `visit` inserts into `visited`,
            // so candidates sharing an address are only contacted once per batch.
            if !self.visited.contains(&address) {
                self.visit(socket, address);
            }
        }
    }
}