#![allow(clippy::unwrap_in_result)]
use std::{net::SocketAddr, sync::Arc};
use futures::{channel::mpsc, stream, Stream, StreamExt};
use proptest::{collection::vec, prelude::*};
use proptest_derive::Arbitrary;
use tokio::{
sync::{broadcast, watch},
task::JoinHandle,
};
use tower::{
discover::{Change, Discover},
BoxError,
};
use tracing::Span;
use zebra_chain::{
block,
chain_tip::ChainTip,
parameters::{Network, NetworkUpgrade},
};
use crate::{
address_book::AddressMetrics,
constants::DEFAULT_MAX_CONNS_PER_IP,
peer::{ClientTestHarness, LoadTrackedClient, MinimumPeerVersion},
peer_set::{set::MorePeers, InventoryChange, PeerSet},
protocol::external::types::Version,
AddressBook, Config, PeerSocketAddr,
};
#[cfg(test)]
mod prop;
#[cfg(test)]
mod vectors;
/// The maximum number of mock peers generated by [`PeerVersions`] (inclusive).
const MAX_PEERS: usize = 20;

/// A randomly generated list of peer protocol versions, used to create mock peers.
///
/// The generated list always contains between one and [`MAX_PEERS`] versions,
/// inclusive. (The previous `1..MAX_PEERS` strategy could never produce
/// `MAX_PEERS` peers.)
#[derive(Arbitrary, Debug)]
struct PeerVersions {
    /// The protocol version advertised by each mock peer, one entry per peer.
    #[proptest(strategy = "vec(any::<Version>(), 1..=MAX_PEERS)")]
    peer_versions: Vec<Version>,
}
impl PeerVersions {
    /// Builds one mock client for every generated version.
    ///
    /// Returns the clients (converted into [`LoadTrackedClient`]s) together with the
    /// harnesses that drive them. Both vectors follow the order of `peer_versions`,
    /// so the client and harness at the same index belong together.
    pub fn mock_peers(&self) -> (Vec<LoadTrackedClient>, Vec<ClientTestHarness>) {
        self.peer_versions
            .iter()
            .map(|&version| -> (LoadTrackedClient, ClientTestHarness) {
                let (mock_client, mock_harness) =
                    ClientTestHarness::build().with_version(version).finish();

                (mock_client.into(), mock_harness)
            })
            .unzip()
    }

    /// Builds a mock peer discovery stream for the generated peers.
    ///
    /// The stream yields one `Change::Insert` per mock client and then stays
    /// pending forever, so it never signals the end of discovery. Each client
    /// gets a distinct fake address on `127.0.0.1`. Ports count up from 1 in
    /// the same order as `peer_versions`.
    ///
    /// Returns the stream and the harnesses for the inserted clients.
    pub fn mock_peer_discovery(
        &self,
    ) -> (
        impl Stream<Item = Result<Change<PeerSocketAddr, LoadTrackedClient>, BoxError>>,
        Vec<ClientTestHarness>,
    ) {
        let (clients, harnesses) = self.mock_peers();

        let insertions = clients.into_iter().zip(1_u16..).map(|(client, port)| {
            let loopback = SocketAddr::new([127, 0, 0, 1].into(), port);
            let address: PeerSocketAddr = loopback.into();

            Ok(Change::Insert(address, client))
        });

        (
            stream::iter(insertions).chain(stream::pending()),
            harnesses,
        )
    }
}
/// A test builder for [`PeerSet`]s.
///
/// `discover` and `minimum_peer_version` must be set before calling `build`.
/// The remaining inputs are optional. For any that are missing, `build` does the
/// following:
/// - creates each missing channel and keeps its other end in the returned
///   [`PeerSetGuard`];
/// - uses `Config::default()` when no config was set;
/// - uses a fallback address book when none was set.
///
/// The type parameters are the discovery source type (`D`) and the chain tip
/// type (`C`). `new` starts both as `()`; the `with_*` methods replace them.
//
// Field order also determines drop order, so keep it stable.
#[derive(Default)]
struct PeerSetBuilder<D, C> {
    /// Peer set configuration; `Config::default()` is used by `build` when unset.
    config: Option<Config>,
    /// Source of peer changes for the peer set (required by `build`).
    discover: Option<D>,
    /// Sender for `MorePeers` demand signals.
    demand_signal: Option<mpsc::Sender<MorePeers>>,
    /// Receiver for the join handles of the peer set's background tasks.
    handle_rx: Option<tokio::sync::oneshot::Receiver<Vec<JoinHandle<Result<(), BoxError>>>>>,
    /// Receiver for inventory change notifications.
    inv_stream: Option<broadcast::Receiver<InventoryChange>>,
    /// Address book whose metrics watcher is handed to the peer set.
    address_book: Option<Arc<std::sync::Mutex<AddressBook>>>,
    /// Minimum peer version tracker (required by `build`).
    minimum_peer_version: Option<MinimumPeerVersion<C>>,
    /// Per-IP connection limit, passed to `PeerSet::new` as-is.
    max_conns_per_ip: Option<usize>,
}
impl PeerSetBuilder<(), ()> {
pub fn new() -> Self {
PeerSetBuilder::default()
}
}
impl<D, C> PeerSetBuilder<D, C> {
    /// Sets the peer discovery source, changing the builder's discovery type to `NewD`.
    ///
    /// Every other setting is carried over unchanged. A previously set discovery
    /// source is discarded.
    pub fn with_discover<NewD>(self, discover: NewD) -> PeerSetBuilder<NewD, C> {
        let PeerSetBuilder {
            config,
            discover: _,
            demand_signal,
            handle_rx,
            inv_stream,
            address_book,
            minimum_peer_version,
            max_conns_per_ip,
        } = self;

        PeerSetBuilder {
            config,
            discover: Some(discover),
            demand_signal,
            handle_rx,
            inv_stream,
            address_book,
            minimum_peer_version,
            max_conns_per_ip,
        }
    }

    /// Sets the minimum peer version tracker, changing the builder's chain tip
    /// type to `NewC`.
    ///
    /// Every other setting is carried over unchanged. A previously set tracker
    /// is discarded.
    pub fn with_minimum_peer_version<NewC>(
        self,
        minimum_peer_version: MinimumPeerVersion<NewC>,
    ) -> PeerSetBuilder<D, NewC> {
        let PeerSetBuilder {
            config,
            discover,
            demand_signal,
            handle_rx,
            inv_stream,
            address_book,
            minimum_peer_version: _,
            max_conns_per_ip,
        } = self;

        PeerSetBuilder {
            config,
            discover,
            demand_signal,
            handle_rx,
            inv_stream,
            address_book,
            minimum_peer_version: Some(minimum_peer_version),
            max_conns_per_ip,
        }
    }

    /// Sets the per-IP connection limit that `build` passes to the peer set.
    ///
    /// # Panics
    ///
    /// If `max_conns_per_ip` is zero.
    pub fn max_conns_per_ip(self, max_conns_per_ip: usize) -> PeerSetBuilder<D, C> {
        assert!(
            max_conns_per_ip > 0,
            "max_conns_per_ip must be greater than zero"
        );

        PeerSetBuilder {
            max_conns_per_ip: Some(max_conns_per_ip),
            ..self
        }
    }
}
impl<D, C> PeerSetBuilder<D, C>
where
D: Discover<Key = PeerSocketAddr, Service = LoadTrackedClient> + Unpin,
D::Error: Into<BoxError>,
C: ChainTip,
{
pub fn build(self) -> (PeerSet<D, C>, PeerSetGuard) {
let mut guard = PeerSetGuard::new();
let config = self.config.unwrap_or_default();
let discover = self.discover.expect("`discover` must be set");
let minimum_peer_version = self
.minimum_peer_version
.expect("`minimum_peer_version` must be set");
let max_conns_per_ip = self.max_conns_per_ip;
let demand_signal = self
.demand_signal
.unwrap_or_else(|| guard.create_demand_sender());
let handle_rx = self
.handle_rx
.unwrap_or_else(|| guard.create_background_tasks_receiver());
let inv_stream = self
.inv_stream
.unwrap_or_else(|| guard.create_inventory_receiver());
let address_metrics = guard.prepare_address_book(self.address_book);
let (_bans_sender, bans_receiver) = tokio::sync::watch::channel(Default::default());
let peer_set = PeerSet::new(
&config,
discover,
demand_signal,
handle_rx,
inv_stream,
bans_receiver,
address_metrics,
minimum_peer_version,
max_conns_per_ip,
);
(peer_set, guard)
}
}
/// Holds the test-side ends of the channels and the address book created for a
/// [`PeerSet`], so they stay alive while the test runs.
///
/// Each field stays `None` until the matching `create_*` or
/// `prepare_address_book` method fills it in.
//
// Field order also determines drop order (and so channel closing order); keep it stable.
#[derive(Default)]
pub struct PeerSetGuard {
    /// Sender half of the channel carrying the peer set's background task handles.
    background_tasks_sender:
        Option<tokio::sync::oneshot::Sender<Vec<JoinHandle<Result<(), BoxError>>>>>,
    /// Receiver half of the `MorePeers` demand channel.
    demand_receiver: Option<mpsc::Receiver<MorePeers>>,
    /// Sender half of the inventory change broadcast channel.
    inventory_sender: Option<broadcast::Sender<InventoryChange>>,
    /// The address book whose metrics watcher was handed out.
    address_book: Option<Arc<std::sync::Mutex<AddressBook>>>,
}
impl PeerSetGuard {
pub fn new() -> Self {
PeerSetGuard::default()
}
#[allow(dead_code)]
pub fn background_tasks_sender(
&mut self,
) -> &mut Option<tokio::sync::oneshot::Sender<Vec<JoinHandle<Result<(), BoxError>>>>> {
&mut self.background_tasks_sender
}
#[allow(dead_code)]
pub fn demand_receiver(&mut self) -> &mut Option<mpsc::Receiver<MorePeers>> {
&mut self.demand_receiver
}
pub fn inventory_sender(&mut self) -> &mut Option<broadcast::Sender<InventoryChange>> {
&mut self.inventory_sender
}
#[allow(dead_code)]
pub fn address_book(&mut self) -> &mut Option<Arc<std::sync::Mutex<AddressBook>>> {
&mut self.address_book
}
pub fn create_background_tasks_receiver(
&mut self,
) -> tokio::sync::oneshot::Receiver<Vec<JoinHandle<Result<(), BoxError>>>> {
let (sender, receiver) = tokio::sync::oneshot::channel();
self.background_tasks_sender = Some(sender);
receiver
}
pub fn create_demand_sender(&mut self) -> mpsc::Sender<MorePeers> {
let (sender, receiver) = mpsc::channel(1);
self.demand_receiver = Some(receiver);
sender
}
pub fn create_inventory_receiver(&mut self) -> broadcast::Receiver<InventoryChange> {
let (sender, receiver) = broadcast::channel(1);
self.inventory_sender = Some(sender);
receiver
}
pub fn prepare_address_book(
&mut self,
maybe_address_book: Option<Arc<std::sync::Mutex<AddressBook>>>,
) -> watch::Receiver<AddressMetrics> {
let address_book = maybe_address_book.unwrap_or_else(Self::fallback_address_book);
let metrics_watcher = address_book
.lock()
.expect("unexpected panic in previous address book mutex guard")
.address_metrics_watcher();
self.address_book = Some(address_book);
metrics_watcher
}
fn fallback_address_book() -> Arc<std::sync::Mutex<AddressBook>> {
let local_listener = "127.0.0.1:1000"
.parse()
.expect("Invalid local listener address");
let address_book = AddressBook::new(
local_listener,
&Network::Mainnet,
DEFAULT_MAX_CONNS_PER_IP,
Span::none(),
);
Arc::new(std::sync::Mutex::new(address_book))
}
}
/// A pair of block heights on either side of a network upgrade's activation height.
///
/// When generated by its [`Arbitrary`] implementation:
/// - `before_upgrade` is strictly below the upgrade's activation height on `network`;
/// - `after_upgrade` is at or above that activation height, and never above
///   [`block::Height::MAX`].
#[derive(Clone, Debug)]
pub struct BlockHeightPairAcrossNetworkUpgrades {
    /// The network whose upgrade the two heights straddle.
    pub network: Network,
    /// A height before the upgrade activates.
    pub before_upgrade: block::Height,
    /// A height at or after the upgrade's activation height.
    pub after_upgrade: block::Height,
}

impl Arbitrary for BlockHeightPairAcrossNetworkUpgrades {
    type Parameters = ();

    /// Picks a random network and network upgrade, then samples one height
    /// before and one height at-or-after that upgrade's activation height.
    fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
        any::<(Network, NetworkUpgrade)>()
            .prop_filter("no block height before genesis", |(_, upgrade)| {
                !matches!(upgrade, NetworkUpgrade::Genesis)
            })
            .prop_filter_map(
                "missing activation height for network upgrade",
                |(network, upgrade)| {
                    upgrade
                        .activation_height(&network)
                        .map(|height| (network, height))
                },
            )
            // An upgrade activating at height 0 has no height before it. Its
            // `0..0` "before" range would be empty, and sampling from an empty
            // range strategy panics.
            .prop_filter(
                "no block height before an activation height of zero",
                |(_, activation_height)| activation_height.0 > 0,
            )
            .prop_flat_map(|(network, activation_height)| {
                let before_upgrade_strategy = 0..activation_height.0;
                // Bound the range at the highest valid height. An open-ended `u32`
                // range would also generate heights above `Height::MAX`.
                let after_upgrade_strategy = activation_height.0..=block::Height::MAX.0;

                (
                    Just(network),
                    before_upgrade_strategy,
                    after_upgrade_strategy,
                )
            })
            .prop_map(|(network, before_upgrade_height, after_upgrade_height)| {
                BlockHeightPairAcrossNetworkUpgrades {
                    network,
                    before_upgrade: block::Height(before_upgrade_height),
                    after_upgrade: block::Height(after_upgrade_height),
                }
            })
            .boxed()
    }

    type Strategy = BoxedStrategy<Self>;
}