use std::{
collections::{HashMap, HashSet},
convert,
fmt::Debug,
marker::PhantomData,
net::IpAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Instant,
};
use futures::{
channel::{mpsc, oneshot},
future::{FutureExt, TryFutureExt},
prelude::*,
stream::FuturesUnordered,
task::noop_waker,
};
use indexmap::IndexMap;
use itertools::Itertools;
use num_integer::div_ceil;
use tokio::{
sync::{broadcast, watch},
task::JoinHandle,
};
use tower::{
discover::{Change, Discover},
load::Load,
Service,
};
use zebra_chain::{chain_tip::ChainTip, parameters::Network};
use crate::{
address_book::AddressMetrics,
constants::MIN_PEER_SET_LOG_INTERVAL,
peer::{LoadTrackedClient, MinimumPeerVersion},
peer_set::{
unready_service::{Error as UnreadyError, UnreadyService},
InventoryChange, InventoryRegistry,
},
protocol::{
external::InventoryHash,
internal::{Request, Response},
},
BoxError, Config, PeerError, PeerSocketAddr, SharedPeerError,
};
#[cfg(test)]
mod tests;
/// A signal sent by the [`PeerSet`] over its demand channel when it has no
/// ready peers, asking the crawler to open more peer connections.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct MorePeers;
/// A signal sent over a peer's `oneshot` cancel handle to cancel the
/// in-flight work of an unready peer service.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct CancelClientWork;
/// A boxed, sendable future that resolves to a peer [`Response`] or an error.
type ResponseFuture = Pin<Box<dyn Future<Output = Result<Response, BoxError>> + Send + 'static>>;
/// A load-balanced set of peer connections.
///
/// Routes [`Request`]s to ready peers — using power-of-two-choices load
/// balancing, inventory-based routing, or broadcasts — tracks unready peers
/// until their in-flight request completes, and watches the peer set's
/// background tasks for early exits.
pub struct PeerSet<D, C>
where
    D: Discover<Key = PeerSocketAddr, Service = LoadTrackedClient> + Unpin,
    D::Error: Into<BoxError>,
    C: ChainTip,
{
    /// Provides new and removed peer [`Change`]s to the peer set.
    discover: D,
    /// Signals the crawler when the peer set needs more peers.
    demand_signal: mpsc::Sender<MorePeers>,
    /// The latest set of banned peer IPs and their ban times.
    bans_receiver: watch::Receiver<Arc<IndexMap<IpAddr, std::time::Instant>>>,
    /// Peers that are ready to receive a request, keyed by socket address.
    ready_services: HashMap<D::Key, D::Service>,
    /// Tracks which peers have advertised or are missing which inventory hashes.
    inventory_registry: InventoryRegistry,
    /// A queued `broadcast_all` request: the request itself, the channel for
    /// its response futures, and the unready peers it still has to reach.
    queued_broadcast_all: Option<(
        Request,
        tokio::sync::mpsc::Sender<ResponseFuture>,
        HashSet<D::Key>,
    )>,
    /// Peers that are processing a request, polled until they become ready again.
    unready_services: FuturesUnordered<UnreadyService<D::Key, D::Service, Request>>,
    /// Cancel handles for the in-flight work of each unready peer.
    cancel_handles: HashMap<D::Key, oneshot::Sender<CancelClientWork>>,
    /// The minimum supported peer protocol version, based on the chain tip.
    minimum_peer_version: MinimumPeerVersion<C>,
    /// The configured limit on the total number of peer connections.
    peerset_total_connection_limit: usize,
    /// Receives the peer set's background task handles, once, after startup.
    handle_rx: tokio::sync::oneshot::Receiver<Vec<JoinHandle<Result<(), BoxError>>>>,
    /// The received background task handles, polled for early exits or panics.
    guards: futures::stream::FuturesUnordered<JoinHandle<Result<(), BoxError>>>,
    /// Address book metrics, used in diagnostics logs.
    address_metrics: watch::Receiver<AddressMetrics>,
    /// When the peer set size was last logged, used to rate-limit logging.
    last_peer_log: Option<Instant>,
    /// The maximum number of connections allowed per peer IP address.
    max_conns_per_ip: usize,
    /// The configured network (used to special-case Regtest broadcasts).
    network: Network,
}
impl<D, C> Drop for PeerSet<D, C>
where
    D: Discover<Key = PeerSocketAddr, Service = LoadTrackedClient> + Unpin,
    D::Error: Into<BoxError>,
    C: ChainTip,
{
    fn drop(&mut self) {
        // Shut down background tasks and close channels on drop. The
        // shutdown path needs a polling `Context`, but nothing will ever
        // wake this again, so a no-op waker is enough.
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        self.shut_down_tasks_and_channels(&mut cx);
    }
}
impl<D, C> PeerSet<D, C>
where
D: Discover<Key = PeerSocketAddr, Service = LoadTrackedClient> + Unpin,
D::Error: Into<BoxError>,
C: ChainTip,
{
#[allow(clippy::too_many_arguments)]
pub fn new(
config: &Config,
discover: D,
demand_signal: mpsc::Sender<MorePeers>,
handle_rx: tokio::sync::oneshot::Receiver<Vec<JoinHandle<Result<(), BoxError>>>>,
inv_stream: broadcast::Receiver<InventoryChange>,
bans_receiver: watch::Receiver<Arc<IndexMap<IpAddr, std::time::Instant>>>,
address_metrics: watch::Receiver<AddressMetrics>,
minimum_peer_version: MinimumPeerVersion<C>,
max_conns_per_ip: Option<usize>,
) -> Self {
Self {
discover,
demand_signal,
bans_receiver,
ready_services: HashMap::new(),
inventory_registry: InventoryRegistry::new(inv_stream),
queued_broadcast_all: None,
unready_services: FuturesUnordered::new(),
cancel_handles: HashMap::new(),
minimum_peer_version,
peerset_total_connection_limit: config.peerset_total_connection_limit(),
handle_rx,
guards: futures::stream::FuturesUnordered::new(),
last_peer_log: None,
address_metrics,
max_conns_per_ip: max_conns_per_ip.unwrap_or(config.max_connections_per_ip),
network: config.network.clone(),
}
}
    /// Checks the peer set's background task handles.
    ///
    /// Returns an error — after shutting the peer set down — if any
    /// background task has exited (cleanly or not), or if the handles were
    /// never delivered. Returns `Pending` while all tasks are still running.
    fn poll_background_errors(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
        // Receive the task handles on first use; stays `Pending` until they arrive.
        futures::ready!(self.receive_tasks_if_needed(cx))?;

        match futures::ready!(Pin::new(&mut self.guards).poll_next(cx)) {
            Some(res) => {
                info!(
                    background_tasks = %self.guards.len(),
                    "a peer set background task exited, shutting down other peer set tasks"
                );

                self.shut_down_tasks_and_channels(cx);

                // Flatten the `Result<Result<(), BoxError>, JoinError>`:
                // propagate panics/join errors first, then the task's own error.
                res.map_err(Into::into)
                    .and_then(convert::identity)?;

                // Even a clean task exit shuts the peer set down.
                Poll::Ready(Err("a peer set background task exited".into()))
            }
            None => {
                self.shut_down_tasks_and_channels(cx);
                Poll::Ready(Err("all peer set background tasks have exited".into()))
            }
        }
    }
fn receive_tasks_if_needed(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
if self.guards.is_empty() {
let handles = futures::ready!(Pin::new(&mut self.handle_rx).poll(cx));
match handles {
Ok(handles) => {
assert!(
!handles.is_empty(),
"the peer set requires at least one background task"
);
self.guards.extend(handles);
Poll::Ready(Ok(()))
}
Err(_) => Poll::Ready(Err(
"sender did not send peer background tasks before it was dropped".into(),
)),
}
} else {
Poll::Ready(Ok(()))
}
}
    /// Shuts the peer set down: drops all peer services (closing their
    /// connections), cancels in-flight work, closes channels, and aborts
    /// background tasks.
    fn shut_down_tasks_and_channels(&mut self, cx: &mut Context<'_>) {
        // Drop all ready services.
        self.ready_services = HashMap::new();

        // Signal cancellation to all in-flight unready work, then drop the
        // unready services themselves.
        for (_peer_key, handle) in self.cancel_handles.drain() {
            let _ = handle.send(CancelClientWork);
        }
        self.unready_services = FuturesUnordered::new();

        // Stop asking the crawler for more peers.
        self.demand_signal.close_channel();

        // Receive any task handles that are still pending, so every
        // background task can be aborted. Errors don't matter here: we're
        // already shutting down.
        self.handle_rx.close();
        let _ = self.receive_tasks_if_needed(cx);
        for guard in self.guards.iter() {
            guard.abort();
        }
    }
    /// Polls unready peers, drops outdated peers, and re-checks ready peers.
    ///
    /// Returns `Ready` if there is at least one ready peer afterwards, and
    /// `Pending` otherwise.
    fn poll_peers(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
        // Move newly ready peers into the ready set; readiness is determined
        // by the final ready-peer check, not by this poll result.
        let _poll_pending_or_ready: Poll<Option<()>> = self.poll_unready(cx)?;

        // Drop ready peers below the minimum protocol version, if it changed.
        self.disconnect_from_outdated_peers();

        // Re-poll the ready peers, dropping any that failed or were banned.
        self.poll_ready_peer_errors(cx).map(Ok)
    }
    /// Checks the unready peer services for request completion or errors.
    ///
    /// Moves newly ready services into the ready set (unless their IP is
    /// banned), and drops cancelled or failed services.
    ///
    /// Returns `Ready(Ok(Some(())))` if at least one service became ready,
    /// `Ready(Ok(None))` if there are no unready services left, and
    /// `Pending` otherwise.
    fn poll_unready(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<()>, BoxError>> {
        let mut result = Poll::Pending;

        // Drain every unready service that has completed or failed.
        loop {
            let Poll::Ready(ready_peer) = Pin::new(&mut self.unready_services).poll_next(cx) else {
                break;
            };
            match ready_peer {
                // The unready stream is currently empty.
                None => {
                    // Don't overwrite a `Ready(Some)` from an earlier iteration.
                    if result.is_pending() {
                        result = Poll::Ready(Ok(None));
                    }
                    break;
                }

                // The service finished its request and is ready again.
                Some(Ok((key, svc))) => {
                    trace!(?key, "service became ready");

                    // Drop (disconnect) services whose IP is now banned,
                    // instead of returning them to the ready set.
                    if self.bans_receiver.borrow().contains_key(&key.ip()) {
                        warn!(?key, "service is banned, dropping service");
                        std::mem::drop(svc);

                        // Its cancel handle is no longer needed.
                        let cancel = self.cancel_handles.remove(&key);
                        debug_assert!(
                            cancel.is_some(),
                            "missing cancel handle for banned unready peer"
                        );

                        continue;
                    }

                    self.push_ready(true, key, svc);

                    result = Poll::Ready(Ok(Some(())));
                }

                // The service's in-flight work was cancelled. The
                // `duplicate_connection` field records whether another
                // connection with the same key already has a cancel handle.
                Some(Err((key, UnreadyError::Canceled))) => {
                    trace!(
                        ?key,
                        duplicate_connection = self.cancel_handles.contains_key(&key),
                        "service was canceled, dropping service"
                    );
                }
                Some(Err((key, UnreadyError::CancelHandleDropped(_)))) => {
                    trace!(
                        ?key,
                        duplicate_connection = self.cancel_handles.contains_key(&key),
                        "cancel handle was dropped, dropping service"
                    );
                }

                // The service failed while processing its request.
                Some(Err((key, UnreadyError::Inner(error)))) => {
                    debug!(%error, "service failed while unready, dropping service");

                    // An unready, uncancelled service must have a cancel handle.
                    let cancel = self.cancel_handles.remove(&key);
                    assert!(cancel.is_some(), "missing cancel handle");
                }
            }
        }

        result
    }
    /// Re-polls every ready peer service, dropping any that have failed or
    /// become banned, and keeping the rest in the ready set.
    ///
    /// Returns `Ready` if any ready peers remain afterwards, and `Pending`
    /// otherwise.
    fn poll_ready_peer_errors(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        // Take the ready set, so `push_ready` can re-insert the surviving
        // services as we go.
        let mut previous = HashMap::new();
        std::mem::swap(&mut previous, &mut self.ready_services);

        for (key, mut svc) in previous.drain() {
            // Services in the ready set already returned `Ready`, and tower
            // services must remain ready until they are called, so `Pending`
            // here is a contract violation.
            let Poll::Ready(peer_readiness) = Pin::new(&mut svc).poll_ready(cx) else {
                unreachable!(
                    "unexpected unready peer: peers must be put into the unready_peers list \
                     after sending them a request"
                );
            };
            match peer_readiness {
                // Still ready: keep it, unless its IP is now banned.
                Ok(()) => {
                    if self.bans_receiver.borrow().contains_key(&key.ip()) {
                        debug!(?key, "service ip is banned, dropping service");
                        std::mem::drop(svc);
                        continue;
                    }
                    self.push_ready(false, key, svc)
                }

                // Failed while idle: drop the service (disconnect).
                Err(error) => {
                    debug!(%error, "service failed while ready, dropping service");
                    std::mem::drop(svc);
                }
            }
        }

        if self.ready_services.is_empty() {
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
fn num_peers_with_ip(&self, ip: IpAddr) -> usize {
self.ready_services
.keys()
.chain(self.cancel_handles.keys())
.filter(|addr| addr.ip() == ip)
.count()
}
fn has_peer_with_addr(&self, addr: PeerSocketAddr) -> bool {
self.ready_services.contains_key(&addr) || self.cancel_handles.contains_key(&addr)
}
    /// Processes peer discovery events, adding new peers and removing
    /// deleted ones.
    ///
    /// Returns `Ready(Ok(()))` if any event was processed, `Pending` if none
    /// were available, and an error if the discovery stream failed or closed.
    fn poll_discover(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
        let mut result = Poll::Pending;

        // Drain all currently available discovery events.
        loop {
            let Poll::Ready(discovered) = Pin::new(&mut self.discover).poll_discover(cx) else {
                break;
            };

            // A closed or failed discovery stream is a fatal peer set error.
            let change = discovered
                .ok_or("discovery stream closed")?
                .map_err(Into::into)?;

            result = Poll::Ready(Ok(()));

            match change {
                Change::Remove(key) => {
                    trace!(?key, "got Change::Remove from Discover");
                    self.remove(&key);
                }
                Change::Insert(key, svc) => {
                    trace!(?key, "got Change::Insert from Discover");

                    // Reject duplicate connections to the same address;
                    // dropping the service disconnects it.
                    if self.has_peer_with_addr(key) {
                        std::mem::drop(svc);
                        continue;
                    }

                    // Enforce the per-IP connection limit.
                    if self.num_peers_with_ip(key.ip()) >= self.max_conns_per_ip {
                        std::mem::drop(svc);
                        continue;
                    }

                    // New peers start in the unready set.
                    self.push_unready(key, svc);
                }
            }
        }

        result
    }
fn disconnect_from_outdated_peers(&mut self) {
if let Some(minimum_version) = self.minimum_peer_version.changed() {
self.ready_services
.retain(|_address, peer| peer.remote_version() >= minimum_version);
}
}
fn take_ready_service(&mut self, key: &D::Key) -> Option<D::Service> {
if let Some(svc) = self.ready_services.remove(key) {
assert!(
!self.cancel_handles.contains_key(key),
"cancel handles are only used for unready service work"
);
Some(svc)
} else {
None
}
}
fn remove(&mut self, key: &D::Key) {
if let Some(ready_service) = self.take_ready_service(key) {
std::mem::drop(ready_service);
} else if let Some(handle) = self.cancel_handles.remove(key) {
let _ = handle.send(CancelClientWork);
}
}
fn push_ready(&mut self, was_unready: bool, key: D::Key, svc: D::Service) {
let cancel = self.cancel_handles.remove(&key);
assert_eq!(
cancel.is_some(),
was_unready,
"missing or unexpected cancel handle"
);
if svc.remote_version() >= self.minimum_peer_version.current() {
self.ready_services.insert(key, svc);
} else {
std::mem::drop(svc);
}
}
fn push_unready(&mut self, key: D::Key, svc: D::Service) {
let peer_version = svc.remote_version();
let (tx, rx) = oneshot::channel();
self.unready_services.push(UnreadyService {
key: Some(key),
service: Some(svc),
cancel: rx,
_req: PhantomData,
});
if peer_version >= self.minimum_peer_version.current() {
self.cancel_handles.insert(key, tx);
} else {
let _ = tx.send(CancelClientWork);
}
}
fn select_ready_p2c_peer(&self) -> Option<D::Key> {
self.select_p2c_peer_from_list(&self.ready_services.keys().copied().collect())
}
    /// Chooses a peer from `ready_service_list` using the power of two
    /// choices: sample two distinct peers at random, then pick the one with
    /// the lower load.
    ///
    /// Returns `None` if the list is empty. Every listed peer must be in the
    /// ready set, otherwise the load lookups panic.
    #[allow(clippy::unwrap_in_result)]
    fn select_p2c_peer_from_list(&self, ready_service_list: &HashSet<D::Key>) -> Option<D::Key> {
        match ready_service_list.len() {
            0 => None,
            1 => Some(
                *ready_service_list
                    .iter()
                    .next()
                    .expect("just checked there is one service"),
            ),
            len => {
                // Sample two distinct random indexes into the set, and look
                // up the peers at those positions.
                let (a, b) = {
                    let idxs = rand::seq::index::sample(&mut rand::thread_rng(), len, 2);
                    let a = idxs.index(0);
                    let b = idxs.index(1);

                    let a = *ready_service_list
                        .iter()
                        .nth(a)
                        .expect("sample returns valid indexes");

                    let b = *ready_service_list
                        .iter()
                        .nth(b)
                        .expect("sample returns valid indexes");

                    (a, b)
                };

                // Compare the two candidates' loads and keep the less loaded one.
                let a_load = self.query_load(&a).expect("supplied services are ready");
                let b_load = self.query_load(&b).expect("supplied services are ready");

                let selected = if a_load <= b_load { a } else { b };

                trace!(
                    a.key = ?a,
                    a.load = ?a_load,
                    b.key = ?b,
                    b.load = ?b_load,
                    selected = ?selected,
                    ?len,
                    "selected service by p2c"
                );

                Some(selected)
            }
        }
    }
fn select_random_ready_peers(&self, max_peers: usize) -> Vec<D::Key> {
use rand::seq::IteratorRandom;
self.ready_services
.keys()
.copied()
.choose_multiple(&mut rand::thread_rng(), max_peers)
}
fn query_load(&self, key: &D::Key) -> Option<<D::Service as Load>::Metric> {
let svc = self.ready_services.get(key);
svc.map(|svc| svc.load())
}
    /// Routes a request to one ready peer, chosen by the power of two choices.
    ///
    /// If there are no ready peers, returns a future that fails with
    /// [`PeerError::NoReadyPeers`].
    fn route_p2c(&mut self, req: Request) -> <Self as tower::Service<Request>>::Future {
        if let Some(p2c_key) = self.select_ready_p2c_peer() {
            tracing::trace!(?p2c_key, "routing based on p2c");

            let mut svc = self
                .take_ready_service(&p2c_key)
                .expect("selected peer must be ready");

            // Call the service, then move it to the unready set until its
            // request completes.
            let fut = svc.call(req);
            self.push_unready(p2c_key, svc);

            return fut.map_err(Into::into).boxed();
        }

        // Yield once before failing — NOTE(review): presumably to stop
        // immediately-retrying callers from starving the executor; confirm.
        async move {
            tokio::task::yield_now().await;
            Err(SharedPeerError::from(PeerError::NoReadyPeers))
        }
        .map_err(Into::into)
        .boxed()
    }
    /// Routes a request for a single inventory item using the inventory
    /// registry.
    ///
    /// Prefers a ready peer that advertised `hash`; otherwise falls back to a
    /// ready peer that is not known to be missing `hash`. If every ready peer
    /// is missing the inventory, returns a future that fails with
    /// [`PeerError::NotFoundRegistry`].
    fn route_inv(
        &mut self,
        req: Request,
        hash: InventoryHash,
    ) -> <Self as tower::Service<Request>>::Future {
        // Peers that advertised this inventory and are currently ready.
        let advertising_peer_list = self
            .inventory_registry
            .advertising_peers(hash)
            .filter(|&addr| self.ready_services.contains_key(addr))
            .copied()
            .collect();

        // Choose among the advertising peers by p2c.
        let peer = self.select_p2c_peer_from_list(&advertising_peer_list);
        if let Some(mut svc) = peer.and_then(|key| self.take_ready_service(&key)) {
            let peer = peer.expect("just checked peer is Some");
            tracing::trace!(?hash, ?peer, "routing to a peer which advertised inventory");

            // Call the service, then move it to the unready set until its
            // request completes.
            let fut = svc.call(req);
            self.push_unready(peer, svc);
            return fut.map_err(Into::into).boxed();
        }

        // Fall back to ready peers that haven't reported this inventory as
        // missing.
        let missing_peer_list: HashSet<PeerSocketAddr> = self
            .inventory_registry
            .missing_peers(hash)
            .copied()
            .collect();
        let maybe_peer_list = self
            .ready_services
            .keys()
            .filter(|addr| !missing_peer_list.contains(addr))
            .copied()
            .collect();

        let peer = self.select_p2c_peer_from_list(&maybe_peer_list);
        if let Some(mut svc) = peer.and_then(|key| self.take_ready_service(&key)) {
            let peer = peer.expect("just checked peer is Some");
            tracing::trace!(?hash, ?peer, "routing to a peer that might have inventory");

            let fut = svc.call(req);
            self.push_unready(peer, svc);
            return fut.map_err(Into::into).boxed();
        }

        tracing::debug!(
            ?hash,
            "all ready peers are missing inventory, failing request"
        );

        // Yield once before failing — NOTE(review): presumably to stop
        // immediately-retrying callers from starving the executor; confirm.
        async move {
            tokio::task::yield_now().await;
            Err(SharedPeerError::from(PeerError::NotFoundRegistry(vec![
                hash,
            ])))
        }
        .map_err(Into::into)
        .boxed()
    }
fn route_multiple(
&mut self,
req: Request,
max_peers: usize,
) -> <Self as tower::Service<Request>>::Future {
assert!(
max_peers > 0,
"requests must be routed to at least one peer"
);
assert!(
max_peers <= self.ready_services.len(),
"requests can only be routed to ready peers"
);
let selected_peers = self.select_random_ready_peers(max_peers);
self.send_multiple(req, selected_peers)
}
    /// Sends a copy of `req` to each of the given peers, which must all be in
    /// the ready set.
    ///
    /// Returns a future that waits for every send to finish, ignores
    /// individual peer errors (logging the ok/err counts), and resolves to
    /// [`Response::Nil`].
    fn send_multiple(
        &mut self,
        req: Request,
        peers: Vec<D::Key>,
    ) -> <Self as tower::Service<Request>>::Future {
        let futs = FuturesUnordered::new();
        for key in peers {
            let mut svc = self
                .take_ready_service(&key)
                .expect("selected peers are ready");
            // Broadcasts are best-effort: individual peer errors are discarded.
            futs.push(svc.call(req.clone()).map_err(|_| ()));
            // The peer stays unready until its request completes.
            self.push_unready(key, svc);
        }

        async move {
            let results = futs.collect::<Vec<Result<_, _>>>().await;
            tracing::debug!(
                ok.len = results.iter().filter(|r| r.is_ok()).count(),
                err.len = results.iter().filter(|r| r.is_err()).count(),
                "sent peer request to multiple peers"
            );
            Ok(Response::Nil)
        }
        .boxed()
    }
fn route_broadcast(&mut self, req: Request) -> <Self as tower::Service<Request>>::Future {
self.route_multiple(req, self.number_of_peers_to_broadcast())
}
fn broadcast_all(&mut self, req: Request) -> <Self as tower::Service<Request>>::Future {
let ready_peers = self.ready_services.keys().copied().collect();
let send_multiple_fut = self.send_multiple(req.clone(), ready_peers);
let Some(mut queued_broadcast_fut_receiver) = self.queue_broadcast_all_unready(&req) else {
return send_multiple_fut;
};
async move {
let _ = send_multiple_fut.await?;
while queued_broadcast_fut_receiver.recv().await.is_some() {}
Ok(Response::Nil)
}
.boxed()
}
fn queue_broadcast_all_unready(
&mut self,
req: &Request,
) -> Option<tokio::sync::mpsc::Receiver<ResponseFuture>> {
if !self.cancel_handles.is_empty() {
const QUEUED_BROADCAST_FUTS_CHANNEL_SIZE: usize = 3;
let (sender, receiver) = tokio::sync::mpsc::channel(QUEUED_BROADCAST_FUTS_CHANNEL_SIZE);
let unready_peers: HashSet<_> = self.cancel_handles.keys().cloned().collect();
let queued = (req.clone(), sender, unready_peers);
self.queued_broadcast_all = Some(queued);
Some(receiver)
} else {
None
}
}
    /// Sends the queued `broadcast_all` request to any of its remaining
    /// unready peers that have since become ready.
    ///
    /// Leaves the request queued while peers remain, or when the response
    /// channel has no free capacity right now.
    fn broadcast_all_queued(&mut self) {
        let Some((req, sender, mut remaining_peers)) = self.queued_broadcast_all.take() else {
            return;
        };

        // Skip peers that have been banned since the request was queued.
        let bans = self.bans_receiver.borrow().clone();
        remaining_peers.retain(|addr| !bans.contains_key(&addr.ip()));

        // If the receiver can't accept another response future yet, re-queue
        // everything and retry on a later call.
        let Ok(reserved_send_slot) = sender.try_reserve() else {
            self.queued_broadcast_all = Some((req, sender, remaining_peers));
            return;
        };

        // Send to the remaining peers that are now ready, removing them from
        // the remaining set as we go.
        let peers: Vec<_> = self
            .ready_services
            .keys()
            .filter(|ready_peer| remaining_peers.remove(ready_peer))
            .copied()
            .collect();

        reserved_send_slot.send(self.send_multiple(req.clone(), peers).boxed());

        // Stay queued until every remaining peer has been handled.
        if !remaining_peers.is_empty() {
            self.queued_broadcast_all = Some((req, sender, remaining_peers));
        }
    }
pub(crate) fn number_of_peers_to_broadcast(&self) -> usize {
if self.network.is_regtest() {
self.ready_services.len()
} else {
const PEER_FRACTION_TO_BROADCAST: usize = 3;
div_ceil(self.ready_services.len(), PEER_FRACTION_TO_BROADCAST)
}
}
fn peer_set_addresses(&self) -> Vec<PeerSocketAddr> {
self.ready_services
.keys()
.chain(self.cancel_handles.keys())
.cloned()
.collect()
}
    /// Logs diagnostics about the peer set size, rate-limited to one detailed
    /// log per [`MIN_PEER_SET_LOG_INTERVAL`].
    ///
    /// Warns about duplicate peer connections, logs duplicate IPs, and gives
    /// a hint when there are no peer connections at all.
    fn log_peer_set_size(&mut self) {
        let ready_services_len = self.ready_services.len();
        let unready_services_len = self.unready_services.len();
        trace!(ready_peers = ?ready_services_len, unready_peers = ?unready_services_len);

        // Rate-limit the detailed logging below.
        let now = Instant::now();
        if let Some(last_peer_log) = self.last_peer_log {
            if now.duration_since(last_peer_log) < MIN_PEER_SET_LOG_INTERVAL {
                return;
            }
        } else {
            // On the first call, just record the time and skip logging.
            self.last_peer_log = Some(now);
            return;
        }
        self.last_peer_log = Some(now);

        // Warn about duplicate connections to the same peer address.
        let peers = self.peer_set_addresses();
        let duplicates: Vec<PeerSocketAddr> = peers.iter().duplicates().cloned().collect();
        let mut peer_counts = peers.iter().counts();
        // Keep only the counts for the duplicated addresses.
        peer_counts.retain(|peer, _count| duplicates.contains(peer));
        if !peer_counts.is_empty() {
            let duplicate_connections: usize = peer_counts.values().sum();
            warn!(
                ?duplicate_connections,
                duplicated_peers = ?peer_counts.len(),
                peers = ?peers.len(),
                "duplicate peer connections in peer set"
            );
        }

        // The same check by IP only, logged at a lower level than exact
        // address duplicates.
        let peers: Vec<IpAddr> = peers.iter().map(|addr| addr.ip()).collect();
        let duplicates: Vec<IpAddr> = peers.iter().duplicates().cloned().collect();
        let mut peer_counts = peers.iter().counts();
        peer_counts.retain(|peer, _count| duplicates.contains(peer));
        if !peer_counts.is_empty() {
            let duplicate_connections: usize = peer_counts.values().sum();
            info!(
                ?duplicate_connections,
                duplicated_peers = ?peer_counts.len(),
                peers = ?peers.len(),
                "duplicate IP addresses in peer set"
            );
        }

        // The rest only applies when there are no ready peers.
        if ready_services_len > 0 {
            return;
        }

        let address_metrics = *self.address_metrics.borrow();
        if unready_services_len == 0 {
            warn!(
                ?address_metrics,
                "network request with no peer connections. Hint: check your network connection"
            );
        } else {
            info!(?address_metrics, "network request with no ready peers: finding more peers, waiting for {} peers to answer requests",
                  unready_services_len);
        }
    }
fn update_metrics(&self) {
let num_ready = self.ready_services.len();
let num_unready = self.unready_services.len();
let num_peers = num_ready + num_unready;
metrics::gauge!("pool.num_ready").set(num_ready as f64);
metrics::gauge!("pool.num_unready").set(num_unready as f64);
metrics::gauge!("zcash.net.peers").set(num_peers as f64);
if num_peers > self.peerset_total_connection_limit {
let address_metrics = *self.address_metrics.borrow();
panic!(
"unexpectedly exceeded configured peer set connection limit: \n\
peers: {num_peers:?}, ready: {num_ready:?}, unready: {num_unready:?}, \n\
address_metrics: {address_metrics:?}",
);
}
}
}
impl<D, C> Service<Request> for PeerSet<D, C>
where
    D: Discover<Key = PeerSocketAddr, Service = LoadTrackedClient> + Unpin,
    D::Error: Into<BoxError>,
    C: ChainTip,
{
    type Response = Response;
    type Error = BoxError;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Drives discovery, background tasks, inventory updates, and peer
    /// state, returning `Ready` when at least one peer can take a request.
    ///
    /// Sends a demand signal to the crawler when no peers are ready.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Update the peer set from the discovery stream. Readiness is
        // decided by the ready-peer check below, not by this poll result.
        let _poll_pending_or_ready: Poll<()> = self.poll_discover(cx)?;

        // Check the background tasks (this only returns `Ready` on error).
        let _poll_pending: Poll<()> = self.poll_background_errors(cx)?;

        // Update the inventory registry from the inventory stream.
        let _poll_pending_or_ready: Poll<()> = self.inventory_registry.poll_inventory(cx)?;

        // Move peers between the ready and unready sets.
        let ready_peers = self.poll_peers(cx)?;

        self.log_peer_set_size();
        self.update_metrics();

        if ready_peers.is_pending() {
            // No ready peers: ask the crawler for more (best-effort send).
            trace!("no ready services, sending demand signal");
            let _ = self.demand_signal.try_send(MorePeers);
            return Poll::Pending;
        }

        // Dispatch any queued broadcast to peers that just became ready.
        self.broadcast_all_queued();

        // `broadcast_all_queued` sends requests via `send_multiple`, which
        // moves peers to the unready set, so re-check readiness.
        if self.ready_services.is_empty() {
            self.poll_peers(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Routes a request to one or more ready peers, depending on its type.
    fn call(&mut self, req: Request) -> Self::Future {
        let fut = match req {
            // Single-hash lookups use inventory-based routing.
            Request::BlocksByHash(ref hashes) if hashes.len() == 1 => {
                let hash = InventoryHash::from(*hashes.iter().next().unwrap());
                self.route_inv(req, hash)
            }
            Request::TransactionsById(ref hashes) if hashes.len() == 1 => {
                let hash = InventoryHash::from(*hashes.iter().next().unwrap());
                self.route_inv(req, hash)
            }

            // Advertisements are broadcast to multiple peers.
            Request::AdvertiseTransactionIds(_) => self.route_broadcast(req),
            Request::AdvertiseBlock(_) => self.route_broadcast(req),
            Request::AdvertiseBlockToAll(_) => self.broadcast_all(req),

            // Everything else goes to one load-balanced (p2c) peer.
            _ => self.route_p2c(req),
        };
        self.update_metrics();

        fut
    }
}