#[cfg(feature = "telemetry")]
use crate::helpers::Telemetry;
use crate::{
CONTEXT,
MAX_BATCH_DELAY_IN_MS,
MEMORY_POOL_PORT,
Worker,
events::{DisconnectReason, EventCodec, PrimaryPing},
helpers::{Cache, PrimarySender, Storage, SyncSender, WorkerSender, assign_to_worker},
spawn_blocking,
};
use smol_str::SmolStr;
use snarkos_account::Account;
use snarkos_node_bft_events::{
BlockRequest,
BlockResponse,
CertificateRequest,
CertificateResponse,
ChallengeRequest,
ChallengeResponse,
DataBlocks,
Event,
EventTrait,
TransmissionRequest,
TransmissionResponse,
ValidatorsRequest,
ValidatorsResponse,
};
use snarkos_node_bft_ledger_service::LedgerService;
use snarkos_node_network::{
ConnectionMode,
NodeType,
Peer,
PeerPoolHandling,
Resolver,
bootstrap_peers,
get_repo_commit_hash,
log_repo_sha_comparison,
shorten_snarkos_sha,
};
use snarkos_node_sync::{MAX_BLOCKS_BEHIND, communication_service::CommunicationService};
use snarkos_node_tcp::{
Config,
ConnectError,
Connection,
ConnectionSide,
P2P,
Tcp,
protocols::{Disconnect, Handshake, OnConnect, Reading, Writing},
};
use snarkos_utilities::NodeDataDir;
use snarkvm::{
console::prelude::*,
ledger::{
committee::Committee,
narwhal::{BatchHeader, Data},
},
prelude::{Address, Field},
utilities::flatten_error,
};
use colored::Colorize;
use futures::{SinkExt, future::join_all};
use indexmap::IndexMap;
#[cfg(feature = "locktick")]
use locktick::parking_lot::{Mutex, RwLock};
#[cfg(not(feature = "locktick"))]
use parking_lot::{Mutex, RwLock};
use rand::{
rngs::OsRng,
seq::{IteratorRandom, SliceRandom},
};
use std::{
collections::{HashMap, HashSet},
future::Future,
io,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
time::Duration,
};
use tokio::{
net::TcpStream,
sync::{OnceCell, oneshot},
task::{self, JoinHandle},
};
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;
/// The interval (in seconds) over which cached inbound/outbound events are counted.
const CACHE_EVENTS_INTERVAL: i64 = (MAX_BATCH_DELAY_IN_MS / 1000) as i64;
/// The interval (in seconds) over which cached inbound/outbound requests are counted.
const CACHE_REQUESTS_INTERVAL: i64 = (MAX_BATCH_DELAY_IN_MS / 1000) as i64;
/// The maximum number of connection attempts (non-test builds only).
#[cfg(not(test))]
const MAX_CONNECTION_ATTEMPTS: usize = 10;
/// The maximum number of validators to include in a `ValidatorsResponse`.
pub const MAX_VALIDATORS_TO_SEND: usize = 200;
/// The lookback window (in seconds) for counting connection attempts (non-test builds only).
#[cfg(not(test))]
const CONNECTION_ATTEMPTS_SINCE_SECS: i64 = 10;
/// The duration (in seconds) for which a banned IP remains banned.
const IP_BAN_TIME_IN_SECS: u64 = 300;
/// A transport abstraction for delivering events to peers, either one at a time or by broadcast.
#[async_trait]
pub trait Transport<N: Network>: Send + Sync {
    /// Sends the given event to the peer at `peer_ip`.
    /// Returns a receiver that resolves with the I/O result of the send,
    /// or `None` if the send could not be initiated.
    async fn send(&self, peer_ip: SocketAddr, event: Event<N>) -> Option<oneshot::Receiver<io::Result<()>>>;
    /// Broadcasts the given event to all connected peers.
    fn broadcast(&self, event: Event<N>);
}
/// The gateway for the validator's BFT memory pool.
/// Wraps the shared inner state in an `Arc`, so clones are cheap and share state.
#[derive(Clone)]
pub struct Gateway<N: Network>(Arc<InnerGateway<N>>);
impl<N: Network> Deref for Gateway<N> {
    type Target = Arc<InnerGateway<N>>;
    /// Dereferences to the shared inner gateway state.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
/// The shared state behind a [`Gateway`].
pub struct InnerGateway<N: Network> {
    /// The account (key material) of this node.
    account: Account<N>,
    /// The storage service of the BFT memory pool.
    storage: Storage<N>,
    /// The ledger service used to query committees and blocks.
    ledger: Arc<dyn LedgerService<N>>,
    /// The TCP stack handling the low-level connections.
    tcp: Tcp,
    /// The cache used for rate limiting and request/response bookkeeping.
    cache: Cache<N>,
    /// The resolver mapping connected addresses to listener addresses and Aleo addresses.
    resolver: RwLock<Resolver<N>>,
    /// The pool of candidate/connecting/connected peers, keyed by listener address.
    peer_pool: RwLock<HashMap<SocketAddr, Peer<N>>>,
    /// Telemetry on validator participation (only with the `telemetry` feature).
    #[cfg(feature = "telemetry")]
    validator_telemetry: Telemetry<N>,
    /// The sender to the primary; set exactly once in `run`.
    primary_sender: OnceCell<PrimarySender<N>>,
    /// The senders to the workers, keyed by worker ID; set exactly once in `run`.
    worker_senders: OnceCell<IndexMap<u8, WorkerSender<N>>>,
    /// The optional sender to the sync module; set at most once in `run`.
    sync_sender: OnceCell<SyncSender<N>>,
    /// Handles of spawned background tasks (aborted on shutdown).
    handles: Mutex<Vec<JoinHandle<()>>>,
    /// The node's data directory (used for the peer cache and the validator whitelist).
    node_data_dir: NodeDataDir,
    /// Whether the node only connects to explicitly trusted peers.
    trusted_peers_only: bool,
    /// The optional development-mode ID of the node (also used as a port offset).
    dev: Option<u16>,
}
impl<N: Network> PeerPoolHandling<N> for Gateway<N> {
    /// The maximum number of peers kept in the pool.
    const MAXIMUM_POOL_SIZE: usize = 200;
    /// The owner label used for this peer pool (the gateway's logging context).
    const OWNER: &str = CONTEXT;
    /// The offense-count threshold used by the peer-pool handler.
    const PEER_SLASHING_COUNT: usize = 20;

    /// Returns the shared peer pool.
    fn peer_pool(&self) -> &RwLock<HashMap<SocketAddr, Peer<N>>> {
        &self.peer_pool
    }

    /// Returns the shared address resolver.
    fn resolver(&self) -> &RwLock<Resolver<N>> {
        &self.resolver
    }

    /// Indicates whether the node runs in development mode.
    fn is_dev(&self) -> bool {
        matches!(self.dev, Some(_))
    }

    /// Indicates whether only trusted peers may be connected to.
    fn trusted_peers_only(&self) -> bool {
        self.trusted_peers_only
    }

    /// The gateway always operates as a validator.
    fn node_type(&self) -> NodeType {
        NodeType::Validator
    }
}
impl<N: Network> Gateway<N> {
/// Initializes a new gateway.
///
/// If no `ip` is given, the listening address is derived as follows:
/// localhost with the dev ID as a port offset in development mode, or the
/// unspecified address on the default memory pool port otherwise. The peer
/// pool is seeded from the on-disk peer cache (unless restricted to trusted
/// peers), and the trusted validators are inserted afterwards so they take
/// precedence over any cached entry for the same address.
#[allow(clippy::too_many_arguments)]
pub fn new(
    account: Account<N>,
    storage: Storage<N>,
    ledger: Arc<dyn LedgerService<N>>,
    ip: Option<SocketAddr>,
    trusted_validators: &[SocketAddr],
    trusted_peers_only: bool,
    node_data_dir: NodeDataDir,
    dev: Option<u16>,
) -> Result<Self> {
    // Determine the listening address; an explicit IP always wins.
    let ip = match (ip, dev) {
        (None, Some(dev)) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, MEMORY_POOL_PORT + dev)),
        (None, None) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, MEMORY_POOL_PORT)),
        (Some(ip), _) => ip,
    };
    // Size the TCP configuration relative to the maximum committee size.
    let tcp = Tcp::new(Config::new(ip, Committee::<N>::max_committee_size() * 10));
    // Seed the pool with cached peers as untrusted candidates (skipped when
    // restricted to trusted peers only).
    let mut initial_peers = HashMap::new();
    if !trusted_peers_only {
        let cached_peers = Self::load_cached_peers(&node_data_dir.gateway_peer_cache_path())?;
        for addr in cached_peers {
            initial_peers.insert(addr, Peer::new_candidate(addr, false));
        }
    }
    // Insert trusted validators last so they overwrite cached entries.
    initial_peers.extend(trusted_validators.iter().copied().map(|addr| (addr, Peer::new_candidate(addr, true))));
    Ok(Self(Arc::new(InnerGateway {
        account,
        storage,
        ledger,
        tcp,
        cache: Default::default(),
        resolver: Default::default(),
        peer_pool: RwLock::new(initial_peers),
        #[cfg(feature = "telemetry")]
        validator_telemetry: Default::default(),
        primary_sender: Default::default(),
        worker_senders: Default::default(),
        sync_sender: Default::default(),
        handles: Default::default(),
        node_data_dir,
        trusted_peers_only,
        dev,
    })))
}
/// Runs the gateway: wires up the channel senders, enables the TCP protocols,
/// starts the listener, and kicks off the heartbeat.
///
/// # Panics
/// Panics if any sender was already set, or if the TCP listener cannot be enabled.
pub async fn run(
    &self,
    primary_sender: PrimarySender<N>,
    worker_senders: IndexMap<u8, WorkerSender<N>>,
    sync_sender: Option<SyncSender<N>>,
) {
    debug!("Starting the gateway for the memory pool...");
    // Set the channel senders exactly once.
    self.primary_sender.set(primary_sender).expect("Primary sender already set in gateway");
    self.worker_senders.set(worker_senders).expect("The worker senders are already set");
    if let Some(sync_sender) = sync_sender {
        self.sync_sender.set(sync_sender).expect("Sync sender already set in gateway");
    }
    // Enable the TCP protocols before accepting any connections.
    self.enable_handshake().await;
    self.enable_reading().await;
    self.enable_writing().await;
    self.enable_disconnect().await;
    self.enable_on_connect().await;
    // With the `metrics` feature, refresh the connection gauges once per second.
    #[cfg(feature = "metrics")]
    {
        let gateway = self.clone();
        self.spawn(async move {
            loop {
                tokio::time::sleep(Duration::from_secs(1)).await;
                gateway.update_metrics();
            }
        });
    }
    // Start listening for inbound connections.
    let listen_addr = self.tcp.enable_listener().await.expect("Failed to enable the TCP listener");
    debug!("Listening for validator connections at address {listen_addr:?}");
    // Start the periodic maintenance heartbeat.
    self.initialize_heartbeat();
    info!("Started the gateway for the memory pool at '{}'", self.local_ip());
}
}
impl<N: Network> Gateway<N> {
    /// Returns the size of the current committee, falling back to the protocol's
    /// maximum committee size if the committee cannot be obtained.
    fn max_committee_size(&self) -> usize {
        match self.ledger.current_committee() {
            Ok(committee) => committee.num_members(),
            Err(_error) => Committee::<N>::max_committee_size() as usize,
        }
    }

    /// Returns the maximum number of events to cache.
    fn max_cache_events(&self) -> usize {
        self.max_cache_transmissions()
    }

    /// Returns the maximum number of certificates to cache.
    fn max_cache_certificates(&self) -> usize {
        2 * BatchHeader::<N>::MAX_GC_ROUNDS * self.max_committee_size()
    }

    /// Returns the maximum number of transmissions to cache.
    fn max_cache_transmissions(&self) -> usize {
        self.max_cache_certificates() * BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH
    }

    /// Returns the duplicate threshold used when rate-limiting repeated requests.
    fn max_cache_duplicates(&self) -> usize {
        self.max_committee_size().pow(2)
    }
}
/// Allows the sync module to use the gateway as its communication backend.
#[async_trait]
impl<N: Network> CommunicationService for Gateway<N> {
    /// The gateway exchanges BFT events.
    type Message = Event<N>;
    /// Prepares a block request event for the half-open range `start_height..end_height`.
    fn prepare_block_request(start_height: u32, end_height: u32) -> Self::Message {
        debug_assert!(start_height < end_height, "Invalid block request format");
        Event::BlockRequest(BlockRequest { start_height, end_height })
    }
    /// Sends the given message to the specified peer via the transport layer.
    async fn send(&self, peer_ip: SocketAddr, message: Self::Message) -> Option<oneshot::Receiver<io::Result<()>>> {
        Transport::send(self, peer_ip, message).await
    }
}
impl<N: Network> Gateway<N> {
/// Returns the account of the node.
pub fn account(&self) -> &Account<N> {
    &self.account
}
/// Returns the optional development-mode ID of the node.
pub fn dev(&self) -> Option<u16> {
    self.dev
}
/// Returns the address resolver.
pub fn resolver(&self) -> &RwLock<Resolver<N>> {
    &self.resolver
}
/// Resolves the given connected address to the peer's listener address, if known.
pub fn resolve_to_listener(&self, connected_addr: &SocketAddr) -> Option<SocketAddr> {
    self.resolver.read().get_listener(*connected_addr)
}
/// Returns the validator telemetry (only with the `telemetry` feature).
#[cfg(feature = "telemetry")]
pub fn validator_telemetry(&self) -> &Telemetry<N> {
    &self.validator_telemetry
}
/// Returns the sender to the primary.
/// Panics if `run` has not yet set the primary sender.
pub fn primary_sender(&self) -> &PrimarySender<N> {
    self.primary_sender.get().expect("Primary sender not set in gateway")
}
/// Returns the number of workers.
/// Panics if the worker senders are unset or their count exceeds `u8::MAX`.
pub fn num_workers(&self) -> u8 {
    u8::try_from(self.worker_senders.get().expect("Missing worker senders in gateway").len())
        .expect("Too many workers")
}
/// Returns the sender for the given worker ID, if it exists.
pub fn get_worker_sender(&self, worker_id: u8) -> Option<&WorkerSender<N>> {
    self.worker_senders.get().and_then(|senders| senders.get(&worker_id))
}
/// Returns `true` if the given IP belongs to an authorized validator.
///
/// Trusted peers are always authorized; otherwise the peer's resolved Aleo
/// address must be an authorized validator address.
pub fn is_authorized_validator_ip(&self, ip: SocketAddr) -> bool {
    // Trusted peers are authorized unconditionally.
    if self.trusted_peers().contains(&ip) {
        return true;
    }
    // Without a resolved Aleo address, authorization cannot be established.
    let Some(address) = self.resolve_to_aleo_addr(ip) else {
        warn!("{CONTEXT} Could not resolve the Aleo address for '{ip}'");
        return false;
    };
    self.is_authorized_validator_address(address)
}
/// Returns `true` if the given address is an authorized validator.
///
/// An address is authorized if it is a member of: the committee lookback for
/// the current storage round, the latest committee on the ledger, or any
/// committee lookback between the round of the block `MAX_BLOCKS_BEHIND`
/// blocks ago and the current storage round (scanned two rounds at a time).
pub fn is_authorized_validator_address(&self, validator_address: Address<N>) -> bool {
    // Check the committee lookback for the current round.
    if self
        .ledger
        .get_committee_lookback_for_round(self.storage.current_round())
        .is_ok_and(|committee| committee.is_committee_member(validator_address))
    {
        return true;
    }
    // Check the latest committee on the ledger.
    if self.ledger.current_committee().is_ok_and(|committee| committee.is_committee_member(validator_address)) {
        return true;
    }
    // Otherwise, scan the committee lookbacks over the recent window of rounds.
    let previous_block_height = self.ledger.latest_block_height().saturating_sub(MAX_BLOCKS_BEHIND);
    match self.ledger.get_block_round(previous_block_height) {
        Ok(block_round) => (block_round..self.storage.current_round()).step_by(2).any(|round| {
            self.ledger
                .get_committee_lookback_for_round(round)
                .is_ok_and(|committee| committee.is_committee_member(validator_address))
        }),
        Err(_) => false,
    }
}
/// Returns the set of Aleo addresses of all connected peers.
pub fn connected_addresses(&self) -> HashSet<Address<N>> {
    let mut addresses = HashSet::new();
    for peer in self.get_connected_peers() {
        addresses.insert(peer.aleo_addr);
    }
    addresses
}
/// Ensures the given listener address may connect; currently this only
/// rejects connections to the node's own IP.
fn ensure_peer_is_allowed(&self, listener_addr: SocketAddr) -> Result<(), DisconnectReason> {
    if self.is_local_ip(listener_addr) {
        Err(DisconnectReason::SelfConnect)
    } else {
        Ok(())
    }
}
/// Refreshes the connection gauges (only with the `metrics` feature).
#[cfg(feature = "metrics")]
fn update_metrics(&self) {
    metrics::gauge(metrics::bft::CONNECTED, self.number_of_connected_peers() as f64);
    metrics::gauge(metrics::bft::CONNECTING, self.number_of_connecting_peers() as f64);
}
/// Test helper: registers the given peer as fully connected, bypassing the handshake.
#[cfg(test)]
pub fn insert_connected_peer(&self, peer_ip: SocketAddr, peer_addr: SocketAddr, address: Address<N>) {
    // Register the address mappings with the resolver.
    self.resolver.write().insert_peer(peer_ip, peer_addr, Some(address));
    // Insert a connecting peer, then immediately upgrade it to connected.
    self.peer_pool.write().insert(peer_ip, Peer::new_connecting(peer_ip, false));
    if let Some(peer) = self.peer_pool.write().get_mut(&peer_ip) {
        peer.upgrade_to_connected(
            peer_addr,
            peer_ip.port(),
            address,
            NodeType::Validator,
            0,
            get_repo_commit_hash(),
            ConnectionMode::Gateway,
        );
    }
}
/// Sends the given event to the peer over the underlying writer.
///
/// Returns a receiver for the I/O result of the send, or `None` if the peer
/// cannot be resolved or the send fails (in which case the peer is disconnected).
fn send_inner(&self, peer_ip: SocketAddr, event: Event<N>) -> Option<oneshot::Receiver<io::Result<()>>> {
    // Resolve the listener address to the connected (ambiguous) address.
    let peer_addr = match self.resolve_to_ambiguous(peer_ip) {
        Some(addr) => addr,
        None => {
            warn!("Unable to resolve the listener IP address '{peer_ip}'");
            return None;
        }
    };
    let name = event.name();
    trace!("{CONTEXT} Sending '{name}' to '{peer_ip}'");
    // Attempt the send; on failure, drop the connection.
    match self.unicast(peer_addr, event) {
        Ok(receiver) => Some(receiver),
        Err(err) => {
            warn!("{CONTEXT} Failed to send '{name}' to '{peer_ip}': {err:?}");
            debug!("{CONTEXT} Disconnecting from '{peer_ip}' (unable to send)");
            self.disconnect(peer_ip);
            None
        }
    }
}
/// Handles an inbound `event` from the connection at `peer_addr`.
///
/// Returns `Ok(true)` when the event was handled (or deliberately ignored),
/// `Ok(false)` when the peer is disconnecting, and `Err` when the peer
/// misbehaved and should be disconnected by the caller.
async fn inbound(&self, peer_addr: SocketAddr, event: Event<N>) -> Result<bool> {
    // Resolve the connected address to the peer's listener address; failure
    // means the peer already disconnected.
    let Some(peer_ip) = self.resolver.read().get_listener(peer_addr) else {
        trace!("Dropping a {} from {peer_addr} - no longer connected.", event.name());
        return Ok(false);
    };
    // Only authorized validators, or connected bootstrap clients, may send events.
    if !(self.is_authorized_validator_ip(peer_ip)
        || self
            .get_connected_peer(peer_ip)
            .map(|peer| peer.node_type == NodeType::BootstrapClient)
            .unwrap_or(false))
    {
        bail!("{CONTEXT} Dropping '{}' from '{peer_ip}' (not authorized)", event.name())
    }
    // Rate-limit the total number of inbound events from this peer.
    let num_events = self.cache.insert_inbound_event(peer_ip, CACHE_EVENTS_INTERVAL);
    if num_events >= self.max_cache_events() {
        bail!("Dropping '{peer_ip}' for spamming events (num_events = {num_events})")
    }
    // Silently ignore duplicate certificate/transmission/block requests beyond
    // the duplicate threshold (returning `Ok(true)` keeps the peer connected).
    match event {
        Event::CertificateRequest(_) | Event::CertificateResponse(_) => {
            let certificate_id = match &event {
                Event::CertificateRequest(CertificateRequest { certificate_id }) => *certificate_id,
                Event::CertificateResponse(CertificateResponse { certificate }) => certificate.id(),
                _ => unreachable!(),
            };
            let num_events = self.cache.insert_inbound_certificate(certificate_id, CACHE_REQUESTS_INTERVAL);
            if num_events >= self.max_cache_duplicates() {
                return Ok(true);
            }
        }
        Event::TransmissionRequest(TransmissionRequest { transmission_id })
        | Event::TransmissionResponse(TransmissionResponse { transmission_id, .. }) => {
            let num_events = self.cache.insert_inbound_transmission(transmission_id, CACHE_REQUESTS_INTERVAL);
            if num_events >= self.max_cache_duplicates() {
                return Ok(true);
            }
        }
        Event::BlockRequest(_) => {
            let num_events = self.cache.insert_inbound_block_request(peer_ip, CACHE_REQUESTS_INTERVAL);
            if num_events >= self.max_cache_duplicates() {
                return Ok(true);
            }
        }
        _ => {}
    }
    trace!("{CONTEXT} Received '{}' from '{peer_ip}'", event.name());
    // Dispatch the event to the primary, the sync module, or a worker.
    match event {
        Event::BatchPropose(batch_propose) => {
            // Forward to the primary.
            let _ = self.primary_sender().tx_batch_propose.send((peer_ip, batch_propose)).await;
            Ok(true)
        }
        Event::BatchSignature(batch_signature) => {
            // Forward to the primary.
            let _ = self.primary_sender().tx_batch_signature.send((peer_ip, batch_signature)).await;
            Ok(true)
        }
        Event::BatchCertified(batch_certified) => {
            // Forward the certificate to the primary.
            let _ = self.primary_sender().tx_batch_certified.send((peer_ip, batch_certified.certificate)).await;
            Ok(true)
        }
        Event::BlockRequest(block_request) => {
            let BlockRequest { start_height, end_height } = block_request;
            // Validate the requested (half-open) range.
            if start_height >= end_height {
                bail!("Block request from '{peer_ip}' has an invalid range ({start_height}..{end_height})")
            }
            if end_height - start_height > DataBlocks::<N>::MAXIMUM_NUMBER_OF_BLOCKS as u32 {
                bail!("Block request from '{peer_ip}' has an excessive range ({start_height}..{end_height})")
            }
            // `end_height - 1` is the last requested block (the end is exclusive).
            let latest_consensus_version = N::CONSENSUS_VERSION(end_height - 1)?;
            // Fetch the blocks on a blocking thread, since ledger reads may be slow.
            let self_ = self.clone();
            let blocks = match task::spawn_blocking(move || {
                match self_.ledger.get_blocks(start_height..end_height) {
                    Ok(blocks) => Ok(DataBlocks(blocks)),
                    Err(error) => bail!("Missing blocks {start_height} to {end_height} from ledger - {error}"),
                }
            })
            .await
            {
                Ok(Ok(blocks)) => blocks,
                Ok(Err(error)) => return Err(error),
                Err(error) => return Err(anyhow!("[BlockRequest] {error}")),
            };
            // Send the response in the background.
            let self_ = self.clone();
            tokio::spawn(async move {
                let event =
                    Event::BlockResponse(BlockResponse::new(block_request, blocks, latest_consensus_version));
                Transport::send(&self_, peer_ip, event).await;
            });
            Ok(true)
        }
        Event::BlockResponse(BlockResponse { request, latest_consensus_version, blocks, .. }) => {
            if let Some(sync_sender) = self.sync_sender.get() {
                // The response must correspond to an outstanding request we sent.
                if !self.cache.remove_outbound_block_request(peer_ip, &request) {
                    bail!("Unsolicited block response from '{peer_ip}'")
                }
                // Deserialize the blocks on the rayon pool (CPU-bound work).
                let (send, recv) = tokio::sync::oneshot::channel();
                rayon::spawn_fifo(move || {
                    let blocks = blocks.deserialize_blocking().map_err(|error| anyhow!("[BlockResponse] {error}"));
                    let _ = send.send(blocks);
                });
                let blocks = match recv.await {
                    Ok(Ok(blocks)) => blocks,
                    Ok(Err(error)) => bail!("Peer '{peer_ip}' sent an invalid block response - {error}"),
                    Err(error) => bail!("Peer '{peer_ip}' sent an invalid block response - {error}"),
                };
                // Verify the response matches the requested range.
                blocks.ensure_response_is_well_formed(peer_ip, request.start_height, request.end_height)?;
                // Hand the blocks to the sync module, classifying failures.
                match sync_sender.insert_block_response(peer_ip, blocks.0, latest_consensus_version).await {
                    Ok(_) => Ok(true),
                    // Benign failures are logged and ignored.
                    Err(err) if err.is_benign() => {
                        let err: anyhow::Error = err.into();
                        let err = err.context(format!("Ignoring block response from peer '{peer_ip}'"));
                        debug!("{}", flatten_error(err));
                        Ok(true)
                    }
                    // An invalid consensus version gets the peer IP-banned.
                    Err(err) if err.is_invalid_consensus_version() => {
                        let err: anyhow::Error = err.into();
                        let err = err.context(format!("Peer sent an invalid block response '{peer_ip}'"));
                        let msg = flatten_error(&err);
                        error!("{msg}");
                        self.ip_ban_peer(peer_ip, Some(&msg));
                        Err(err)
                    }
                    // Other failures are logged but do not drop the peer.
                    Err(err) => {
                        let err: anyhow::Error = err.into();
                        let err = err.context(format!("Peer '{peer_ip}' sent an invalid block response"));
                        warn!("{}", flatten_error(err));
                        Ok(true)
                    }
                }
            } else {
                debug!("Ignoring block response from '{peer_ip}' - no sync sender");
                Ok(true)
            }
        }
        Event::CertificateRequest(certificate_request) => {
            // Forward to the sync module, if present.
            if let Some(sync_sender) = self.sync_sender.get() {
                let _ = sync_sender.tx_certificate_request.send((peer_ip, certificate_request)).await;
            }
            Ok(true)
        }
        Event::CertificateResponse(certificate_response) => {
            // Forward to the sync module, if present.
            if let Some(sync_sender) = self.sync_sender.get() {
                let _ = sync_sender.tx_certificate_response.send((peer_ip, certificate_response)).await;
            }
            Ok(true)
        }
        Event::ChallengeRequest(..) | Event::ChallengeResponse(..) => {
            // Challenge events are only valid during the handshake.
            bail!("{CONTEXT} Peer '{peer_ip}' is not following the protocol")
        }
        Event::Disconnect(message) => {
            debug!("Peer '{peer_ip}' decided to disconnect due to '{}'", message.reason);
            self.disconnect(peer_ip);
            Ok(false)
        }
        Event::PrimaryPing(ping) => {
            let PrimaryPing { version, block_locators, primary_certificate } = ping;
            // Reject peers running an outdated event protocol version.
            if version < Event::<N>::VERSION {
                bail!("Dropping '{peer_ip}' on event version {version} (outdated)");
            }
            debug!("Validator '{peer_ip}' is at height {}", block_locators.latest_locator_height());
            // Update the sync module with the peer's block locators.
            if let Some(sync_sender) = self.sync_sender.get() {
                if let Err(error) = sync_sender.update_peer_locators(peer_ip, block_locators).await {
                    bail!("Validator '{peer_ip}' sent invalid block locators - {error}");
                }
            }
            // Forward the certificate to the primary.
            let _ = self.primary_sender().tx_primary_ping.send((peer_ip, primary_certificate)).await;
            Ok(true)
        }
        Event::TransmissionRequest(request) => {
            // Route the request to the worker responsible for this transmission ID.
            let Ok(worker_id) = assign_to_worker(request.transmission_id, self.num_workers()) else {
                warn!("{CONTEXT} Unable to assign transmission ID '{}' to a worker", request.transmission_id);
                return Ok(true);
            };
            if let Some(sender) = self.get_worker_sender(worker_id) {
                let _ = sender.tx_transmission_request.send((peer_ip, request)).await;
            }
            Ok(true)
        }
        Event::TransmissionResponse(response) => {
            // Route the response to the worker responsible for this transmission ID.
            let Ok(worker_id) = assign_to_worker(response.transmission_id, self.num_workers()) else {
                warn!("{CONTEXT} Unable to assign transmission ID '{}' to a worker", response.transmission_id);
                return Ok(true);
            };
            if let Some(sender) = self.get_worker_sender(worker_id) {
                let _ = sender.tx_transmission_response.send((peer_ip, response)).await;
            }
            Ok(true)
        }
        Event::ValidatorsRequest(_) => {
            // Reply with a shuffled sample of our best-connected peers.
            let mut connected_peers = self.get_best_connected_peers(Some(MAX_VALIDATORS_TO_SEND));
            connected_peers.shuffle(&mut rand::thread_rng());
            let self_ = self.clone();
            tokio::spawn(async move {
                let mut validators = IndexMap::with_capacity(MAX_VALIDATORS_TO_SEND);
                for validator in connected_peers.into_iter() {
                    validators.insert(validator.listener_addr, validator.aleo_addr);
                }
                let event = Event::ValidatorsResponse(ValidatorsResponse { validators });
                Transport::send(&self_, peer_ip, event).await;
            });
            Ok(true)
        }
        Event::ValidatorsResponse(response) => {
            if self.trusted_peers_only {
                bail!("{CONTEXT} Not accepting validators response from '{peer_ip}' (trusted peers only)");
            }
            let ValidatorsResponse { validators } = response;
            ensure!(validators.len() <= MAX_VALIDATORS_TO_SEND, "{CONTEXT} Received too many validators");
            // The response must correspond to an outstanding request we sent.
            if !self.cache.contains_outbound_validators_request(peer_ip) {
                bail!("{CONTEXT} Received validators response from '{peer_ip}' without a validators request")
            }
            self.cache.decrement_outbound_validators_requests(peer_ip);
            // Keep only authorized validators we are not already connected to
            // (and never ourselves).
            let valid_addrs = validators
                .into_iter()
                .filter_map(|(listener_addr, aleo_addr)| {
                    (self.account.address() != aleo_addr
                        && !self.is_connected_address(aleo_addr)
                        && self.is_authorized_validator_address(aleo_addr))
                    .then_some((listener_addr, None))
                })
                .collect::<Vec<_>>();
            if !valid_addrs.is_empty() {
                self.insert_candidate_peers(valid_addrs);
            }
            Ok(true)
        }
        Event::WorkerPing(ping) => {
            ensure!(
                ping.transmission_ids.len() <= Worker::<N>::MAX_TRANSMISSIONS_PER_WORKER_PING,
                "{CONTEXT} Received too many transmissions"
            );
            // Route each transmission ID to its responsible worker.
            let num_workers = self.num_workers();
            for transmission_id in ping.transmission_ids.into_iter() {
                let Ok(worker_id) = assign_to_worker(transmission_id, num_workers) else {
                    warn!("{CONTEXT} Unable to assign transmission ID '{transmission_id}' to a worker");
                    continue;
                };
                if let Some(sender) = self.get_worker_sender(worker_id) {
                    let _ = sender.tx_worker_ping.send((peer_ip, transmission_id)).await;
                }
            }
            Ok(true)
        }
    }
}
/// Spawns the gateway heartbeat loop: after a one-second startup delay, the
/// heartbeat runs once every 15 seconds.
fn initialize_heartbeat(&self) {
    let gateway = self.clone();
    self.spawn(async move {
        // Give the node a moment to finish starting up.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        info!("Starting the heartbeat of the gateway...");
        loop {
            gateway.heartbeat().await;
            tokio::time::sleep(Duration::from_secs(15)).await;
        }
    });
}
/// Spawns the given future onto the runtime and retains its handle,
/// so the task can be aborted on shutdown.
#[allow(dead_code)]
fn spawn<T: Future<Output = ()> + Send + 'static>(&self, future: T) {
    self.handles.lock().push(tokio::spawn(future));
}
/// Shuts down the gateway: persists the best peers to disk, aborts all
/// background tasks, and shuts down the TCP stack.
pub async fn shut_down(&self) {
    info!("Shutting down the gateway...");
    // Persist the best-known peers so they can be reloaded on restart.
    if let Err(e) = self.save_best_peers(&self.node_data_dir.gateway_peer_cache_path(), None, true) {
        warn!("Failed to persist best validators to disk: {e}");
    }
    // Abort every spawned background task.
    self.handles.lock().iter().for_each(|handle| handle.abort());
    self.tcp.shut_down().await;
}
}
impl<N: Network> Gateway<N> {
/// The minimum time to wait between connection attempts to the same candidate peer.
const MINIMUM_TIME_BETWEEN_CONNECTION_ATTEMPTS: Duration = Duration::from_secs(10);
/// The startup grace period before missing validator connections are escalated in the logs.
const MISSING_VALIDATOR_CONNECTIONS_GRACE_PERIOD: Duration = Duration::from_secs(60);
/// Performs one round of gateway maintenance.
async fn heartbeat(&self) {
    // Log connection and stake status.
    self.log_connected_validators();
    #[cfg(feature = "telemetry")]
    self.log_participation_scores();
    // Keep trusted validators connected.
    self.handle_trusted_validators();
    // Maintain at most one bootstrap peer connection (when allowed).
    self.handle_bootstrap_peers().await;
    // Drop peers that are no longer authorized validators.
    self.handle_unauthorized_validators();
    // Top up connections when below the target, and discover more validators.
    self.handle_min_connected_validators().await;
    // Expire stale IP bans.
    self.handle_banned_ips();
    // Refresh the on-disk validator whitelist.
    self.update_validator_whitelist();
}
/// Logs the connected validators, the stake distribution by snarkOS commit
/// hash, the validators we are not connected to, and whether the reachable
/// validators reach the committee's quorum threshold.
fn log_connected_validators(&self) {
    let connected_validators = self.filter_connected_peers(|peer| peer.node_type == NodeType::Validator);
    let committee = match self.ledger.current_committee() {
        Ok(c) => c,
        Err(err) => {
            error!("Failed to get current committee: {err}");
            return;
        }
    };
    // Exclude this node from the bonded validator count.
    let validators_total = committee.num_members().saturating_sub(1);
    let total_validators = format!("(of {validators_total} bonded validators)").dimmed();
    let connections_msg = match connected_validators.len() {
        0 => "No connected validators".to_string(),
        num_connected => format!("Connected to {num_connected} validators {total_validators}"),
    };
    info!("{connections_msg}");
    // Aggregate stake per snarkOS commit hash, starting with our own node.
    let mut connected_validator_addresses = HashSet::with_capacity(connected_validators.len());
    let mut connected_validator_shas: HashMap<SmolStr, u64> = HashMap::with_capacity(connected_validators.len());
    let our_sha = shorten_snarkos_sha(&get_repo_commit_hash());
    let our_stake = committee.get_stake(self.account.address());
    connected_validator_shas.insert(our_sha.clone(), our_stake);
    connected_validator_addresses.insert(self.account.address());
    for peer in &connected_validators {
        let address = peer.aleo_addr;
        connected_validator_addresses.insert(address);
        let address_stake = committee.get_stake(address);
        let short_peer_sha = shorten_snarkos_sha(&peer.snarkos_sha);
        *connected_validator_shas.entry(short_peer_sha.clone()).or_default() += address_stake;
        debug!(
            "{}",
            format!(
                " Connected to: {} - {} (connection age {:?})",
                peer.listener_addr,
                peer.aleo_addr,
                peer.first_seen.elapsed()
            )
            .dimmed()
        );
    }
    // Report the share of total stake running the same commit as this node.
    if let Some(combined_stake) = connected_validator_shas.get(&our_sha) {
        let percentage = *combined_stake as f64 / committee.total_stake() as f64 * 100.0;
        debug!("{}", format!(" Combined stake @ {our_sha}: {percentage:.2}%").dimmed());
        #[cfg(feature = "metrics")]
        metrics::gauge(metrics::bft::CONNECTED_STAKE_WITH_MATCHING_SHA, percentage);
    }
    // After the startup grace period, report validators we are not connected to.
    let num_not_connected = validators_total.saturating_sub(connected_validators.len());
    if num_not_connected > 0 && self.tcp().uptime() > Self::MISSING_VALIDATOR_CONNECTIONS_GRACE_PERIOD {
        let total_stake = committee.total_stake();
        let total_stake_f64 = total_stake as f64;
        let committee_members: HashSet<_> =
            self.ledger.current_committee().map(|c| c.members().keys().copied().collect()).unwrap_or_default();
        // Sum the stake of the unreachable validators, logging each one.
        let not_connected_stake: u64 = committee_members
            .difference(&connected_validator_addresses)
            .map(|address| {
                let address_stake = committee.get_stake(*address);
                let address_stake_as_percentage =
                    if total_stake == 0 { 0.0 } else { address_stake as f64 / total_stake_f64 * 100.0 };
                debug!(
                    "{}",
                    format!(" Not connected to {address} ({address_stake_as_percentage:.2}% of total stake)")
                        .dimmed()
                );
                address_stake
            })
            .sum();
        let not_connected_stake_as_percentage =
            if total_stake == 0 { 0.0 } else { not_connected_stake as f64 / total_stake_f64 * 100.0 };
        warn!(
            "Not connected to {num_not_connected} validators {total_validators} ({not_connected_stake_as_percentage:.2}% of total stake not connected)"
        );
        #[cfg(feature = "metrics")]
        {
            let connected_stake_as_percentage = 100.0 - not_connected_stake_as_percentage;
            metrics::gauge(metrics::bft::CONNECTED_STAKE, connected_stake_as_percentage);
        }
    } else {
        #[cfg(feature = "metrics")]
        metrics::gauge(metrics::bft::CONNECTED_STAKE, 100.0);
    };
    // Check whether the reachable validators (including ourselves) reach quorum;
    // only escalate to an error after the grace period.
    if !committee.is_quorum_threshold_reached(&connected_validator_addresses) {
        if self.tcp().uptime() > Self::MISSING_VALIDATOR_CONNECTIONS_GRACE_PERIOD {
            error!("Not connected to a quorum of validators");
        } else {
            debug!("Not connected to a quorum of validators");
        }
    }
}
/// Logs the participation scores of the current committee members
/// (only with the `telemetry` feature).
#[cfg(feature = "telemetry")]
fn log_participation_scores(&self) {
    if let Ok(current_committee) = self.ledger.current_committee() {
        // Fix: the `&current_committee` borrow was corrupted into the mojibake
        // `¤t_committee` (the `&curren` prefix was mangled as an HTML
        // entity), which does not compile; restore the intended borrow.
        let participation_scores = self.validator_telemetry().get_participation_scores(&current_committee);
        debug!("Participation Scores (in the last {} rounds):", self.storage.max_gc_rounds());
        for (address, score) in participation_scores {
            debug!("{}", format!(" {address} - {score:.2}%").dimmed());
        }
    }
}
/// Attempts to (re)connect to every trusted validator that is not already
/// connected or connecting.
fn handle_trusted_validators(&self) {
    let trusted_peers = self.trusted_peers();
    // Initiate a connection to each trusted validator, ignoring benign
    // conditions (self-connect, already connected/connecting).
    let handles: Vec<JoinHandle<_>> = trusted_peers
        .iter()
        .filter_map(|validator_ip| {
            match self.connect(*validator_ip) {
                Ok(hdl) => Some(hdl),
                Err(ConnectError::SelfConnect { .. })
                | Err(ConnectError::AlreadyConnected { .. })
                | Err(ConnectError::AlreadyConnecting { .. }) => None,
                Err(err) => {
                    warn!("Could not initiate connection to trusted validator at '{validator_ip}' - {err}");
                    None
                }
            }
        })
        .collect();
    if !handles.is_empty() {
        // Fix: corrected the "Reconnnecting" typo in the log message.
        // Note: the attempts proceed in the background; the handles are not awaited here.
        info!("Reconnecting to {} out of {} trusted validators", handles.len(), trusted_peers.len());
    }
}
/// Maintains at most one connection to a bootstrap peer (a no-op when the
/// node is restricted to trusted peers): connects to one random bootstrap
/// candidate when none is connected, and disconnects any surplus beyond one.
async fn handle_bootstrap_peers(&self) {
    if self.trusted_peers_only {
        return;
    }
    // Collect the bootstrap peers we are not currently connected to.
    let mut candidate_bootstrap = Vec::new();
    let connected_bootstrap = self.filter_connected_peers(|peer| peer.node_type == NodeType::BootstrapClient);
    for bootstrap_ip in bootstrap_peers::<N>(self.is_dev()) {
        if !connected_bootstrap.iter().any(|peer| peer.listener_addr == bootstrap_ip) {
            candidate_bootstrap.push(bootstrap_ip);
        }
    }
    // If no bootstrap peer is connected, try one random candidate.
    if connected_bootstrap.is_empty() {
        let rng = &mut OsRng;
        if let Some(peer_ip) = candidate_bootstrap.into_iter().choose(rng) {
            match self.connect(peer_ip) {
                Ok(hdl) => {
                    let result = hdl.await;
                    if let Err(err) = result {
                        warn!("{CONTEXT} Failed to connect to bootstrap peer at '{peer_ip}' - {err}");
                    }
                }
                Err(ConnectError::AlreadyConnected { .. }) | Err(ConnectError::AlreadyConnecting { .. }) => {}
                Err(err) => {
                    warn!("{CONTEXT} Could not initiate connection to bootstrap peer at '{peer_ip}' - {err}")
                }
            }
        }
    }
    // Keep only one bootstrap connection; disconnect a random surplus.
    let num_surplus = connected_bootstrap.len().saturating_sub(1);
    if num_surplus > 0 {
        let rng = &mut OsRng;
        for peer in connected_bootstrap.into_iter().choose_multiple(rng, num_surplus) {
            info!("{CONTEXT} Disconnecting from '{}' (exceeded maximum bootstrap)", peer.listener_addr);
            <Self as Transport<N>>::send(
                self,
                peer.listener_addr,
                Event::Disconnect(DisconnectReason::NoReasonGiven.into()),
            )
            .await;
            self.disconnect(peer.listener_addr);
        }
    }
}
/// Disconnects any connected peer (other than bootstrap clients) that is no
/// longer an authorized validator. Runs asynchronously in a spawned task.
fn handle_unauthorized_validators(&self) {
    let gateway = self.clone();
    tokio::spawn(async move {
        for peer in gateway.get_connected_peers() {
            // Bootstrap clients are exempt from committee membership checks.
            if peer.node_type == NodeType::BootstrapClient {
                continue;
            }
            // Authorized validators stay connected.
            if gateway.is_authorized_validator_ip(peer.listener_addr) {
                continue;
            }
            warn!(
                "{CONTEXT} Disconnecting from '{}' - Validator is not in the current committee",
                peer.listener_addr
            );
            Transport::send(&gateway, peer.listener_addr, DisconnectReason::ProtocolViolation.into()).await;
            gateway.disconnect(peer.listener_addr);
        }
    });
}
/// When below the target number of connections, attempts to connect to
/// candidate peers (skipping trusted validators and recently attempted
/// candidates), then asks one random connected validator for more validators.
async fn handle_min_connected_validators(&self) {
    let trusted_validators = self.trusted_peers();
    if self.number_of_connected_peers() < N::LATEST_MAX_CERTIFICATES() as usize {
        // Initiate connection attempts to the eligible candidates.
        let (addrs, handles): (Vec<_>, Vec<_>) = self
            .get_candidate_peers()
            .iter()
            .filter_map(|peer| {
                // Trusted validators are handled by `handle_trusted_validators`.
                if trusted_validators.contains(&peer.listener_addr) {
                    return None;
                }
                // Respect the minimum delay between attempts to the same peer.
                if let Some(previous_attempt) = peer.last_connection_attempt
                    && previous_attempt.elapsed() < Self::MINIMUM_TIME_BETWEEN_CONNECTION_ATTEMPTS
                {
                    return None;
                }
                match self.connect(peer.listener_addr) {
                    Ok(hdl) => Some((peer.listener_addr, hdl)),
                    Err(ConnectError::AlreadyConnected { .. })
                    | Err(ConnectError::AlreadyConnecting { .. })
                    | Err(ConnectError::SelfConnect { .. }) => None,
                    Err(err) => {
                        warn!(
                            "{CONTEXT} Could not initiate connection to validator at '{}' - {err}",
                            peer.listener_addr
                        );
                        None
                    }
                }
            })
            .unzip();
        // Await all attempts and log any failures.
        for (addr, result) in addrs.into_iter().zip(join_all(handles).await) {
            if let Err(err) = result {
                warn!("{CONTEXT} Failed to connect to validator at '{addr}' - {err}");
            }
        }
        let validators = self.connected_peers();
        if validators.is_empty() {
            return;
        }
        // Ask one random connected validator for more validator addresses.
        if let Some(validator_ip) = validators.into_iter().choose(&mut rand::thread_rng()) {
            let self_ = self.clone();
            tokio::spawn(async move {
                // Track the outbound request so the response will be accepted.
                self_.cache.increment_outbound_validators_requests(validator_ip);
                let _ = Transport::send(&self_, validator_ip, Event::ValidatorsRequest(ValidatorsRequest)).await;
            });
        }
    }
}
/// Processes an inbound event; on error, disconnects the offending peer
/// (notifying it of a protocol violation) in a background task.
async fn process_message_inner(&self, peer_addr: SocketAddr, message: Event<N>) {
    if let Err(error) = self.inbound(peer_addr, message).await
        && let Some(peer_ip) = self.resolver.read().get_listener(peer_addr)
    {
        warn!("{CONTEXT} Disconnecting from '{peer_ip}' - {error}");
        let self_ = self.clone();
        tokio::spawn(async move {
            Transport::send(&self_, peer_ip, DisconnectReason::ProtocolViolation.into()).await;
            self_.disconnect(peer_ip);
        });
    }
}
/// Removes IP bans older than `IP_BAN_TIME_IN_SECS`.
fn handle_banned_ips(&self) {
    self.tcp.banned_peers().remove_old_bans(IP_BAN_TIME_IN_SECS);
}
/// Persists the current best peers as the on-disk validator whitelist.
fn update_validator_whitelist(&self) {
    let path = self.node_data_dir.validator_whitelist_path();
    if let Err(err) = self.save_best_peers(&path, Some(MAX_VALIDATORS_TO_SEND), false) {
        warn!("{CONTEXT} Could not update the validator whitelist: {err}");
    }
}
}
#[async_trait]
impl<N: Network> Transport<N> for Gateway<N> {
    /// Sends the given event to the peer, applying outbound rate limiting.
    ///
    /// The `send!` macro sleeps in 10ms increments while the relevant outbound
    /// cache counter exceeds its threshold, then performs the actual send.
    async fn send(&self, peer_ip: SocketAddr, event: Event<N>) -> Option<oneshot::Receiver<io::Result<()>>> {
        macro_rules! send {
            ($self:ident, $cache_map:ident, $interval:expr, $freq:ident) => {{
                // Back off while this event class is above its outbound threshold.
                while $self.cache.$cache_map(peer_ip, $interval) > $self.$freq() {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                $self.send_inner(peer_ip, event)
            }};
        }
        match event {
            Event::CertificateRequest(_) | Event::CertificateResponse(_) => {
                // Count the event, then throttle on the certificate counter.
                self.cache.insert_outbound_event(peer_ip, CACHE_EVENTS_INTERVAL);
                send!(self, insert_outbound_certificate, CACHE_REQUESTS_INTERVAL, max_cache_certificates)
            }
            Event::TransmissionRequest(_) | Event::TransmissionResponse(_) => {
                // Count the event, then throttle on the transmission counter.
                self.cache.insert_outbound_event(peer_ip, CACHE_EVENTS_INTERVAL);
                send!(self, insert_outbound_transmission, CACHE_REQUESTS_INTERVAL, max_cache_transmissions)
            }
            Event::BlockRequest(request) => {
                // Track the outbound request so the response can be validated.
                self.cache.insert_outbound_block_request(peer_ip, request);
                send!(self, insert_outbound_event, CACHE_EVENTS_INTERVAL, max_cache_events)
            }
            _ => {
                send!(self, insert_outbound_event, CACHE_EVENTS_INTERVAL, max_cache_events)
            }
        }
    }
    /// Broadcasts the given event to all connected peers via a background task.
    fn broadcast(&self, event: Event<N>) {
        if self.number_of_connected_peers() > 0 {
            let self_ = self.clone();
            let connected_peers = self.connected_peers();
            tokio::spawn(async move {
                for peer_ip in connected_peers {
                    let _ = Transport::send(&self_, peer_ip, event.clone()).await;
                }
            });
        }
    }
}
impl<N: Network> P2P for Gateway<N> {
    /// Returns a reference to the TCP instance.
    fn tcp(&self) -> &Tcp {
        &self.tcp
    }
}
#[async_trait]
impl<N: Network> Reading for Gateway<N> {
    type Codec = EventCodec<N>;
    type Message = Event<N>;
    /// Creates the (default) event codec for a new connection.
    fn codec(&self, _peer_addr: SocketAddr, _side: ConnectionSide) -> Self::Codec {
        Default::default()
    }
    /// Processes an inbound message. Block requests/responses are handled in
    /// a spawned task (their handling involves blocking ledger reads and
    /// deserialization); all other events are processed inline.
    async fn process_message(&self, peer_addr: SocketAddr, message: Self::Message) -> io::Result<()> {
        if matches!(message, Event::BlockRequest(_) | Event::BlockResponse(_)) {
            let self_ = self.clone();
            tokio::spawn(async move {
                self_.process_message_inner(peer_addr, message).await;
            });
        } else {
            self.process_message_inner(peer_addr, message).await;
        }
        Ok(())
    }
    /// The depth of the inbound message queue.
    fn message_queue_depth(&self) -> usize {
        2 * BatchHeader::<N>::MAX_GC_ROUNDS
            * N::LATEST_MAX_CERTIFICATES() as usize
            * BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH
    }
}
#[async_trait]
impl<N: Network> Writing for Gateway<N> {
    type Codec = EventCodec<N>;
    type Message = Event<N>;

    /// Creates a fresh (default) event codec for each peer connection.
    fn codec(&self, _peer_addr: SocketAddr, _side: ConnectionSide) -> Self::Codec {
        Self::Codec::default()
    }

    /// Bounds the outbound message queue to twice the worst-case number of
    /// transmissions across all certificates retained within the GC window.
    fn message_queue_depth(&self) -> usize {
        let certificates_per_round = N::LATEST_MAX_CERTIFICATES() as usize;
        let transmissions_per_batch = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH;
        2 * BatchHeader::<N>::MAX_GC_ROUNDS * certificates_per_round * transmissions_per_batch
    }
}
#[async_trait]
impl<N: Network> Disconnect for Gateway<N> {
    /// Cleans up gateway state when the connection behind `peer_addr` drops.
    async fn handle_disconnect(&self, peer_addr: SocketAddr) {
        // An address we cannot resolve to a listener needs no cleanup.
        let Some(peer_ip) = self.resolve_to_listener(&peer_addr) else {
            return;
        };
        self.downgrade_peer_to_candidate(peer_ip);
        // Notify the sync module (if it is running) that this peer is gone.
        if let Some(sync_sender) = self.sync_sender.get() {
            let remover = sync_sender.tx_block_sync_remove_peer.clone();
            tokio::spawn(async move {
                if let Err(err) = remover.send(peer_ip).await {
                    warn!("{CONTEXT} Unable to remove '{peer_ip}' from the sync module - {err}");
                }
            });
        }
        // Drop any outstanding outbound request tracking for this peer.
        self.cache.clear_outbound_validators_requests(peer_ip);
        self.cache.clear_outbound_block_requests(peer_ip);
    }
}
#[async_trait]
impl<N: Network> OnConnect for Gateway<N> {
    /// After a connection completes, queries bootstrap peers for their validator list.
    async fn on_connect(&self, peer_addr: SocketAddr) {
        // Resolve the connection address to the peer's listener address.
        let Some(listener_addr) = self.resolve_to_listener(&peer_addr) else {
            return;
        };
        // Only connected bootstrap peers are asked for validators.
        let Some(peer) = self.get_connected_peer(listener_addr) else {
            return;
        };
        if peer.node_type != NodeType::BootstrapClient {
            return;
        }
        self.cache.increment_outbound_validators_requests(listener_addr);
        let _ = <Self as Transport<N>>::send(self, listener_addr, Event::ValidatorsRequest(ValidatorsRequest)).await;
    }
}
#[async_trait]
impl<N: Network> Handshake for Gateway<N> {
    /// Performs the validator handshake on the given connection.
    ///
    /// On success the peer-pool entry is upgraded to connected; on failure a
    /// still-connecting entry is rolled back to candidate status and the error
    /// is propagated.
    async fn perform_handshake(&self, mut connection: Connection) -> Result<Connection, ConnectError> {
        let peer_addr = connection.addr();
        let peer_side = connection.side();
        // Outside of tests and dev mode, reject banned IPs and rate-limit inbound
        // connection attempts per IP, banning spammy addresses.
        #[cfg(not(test))]
        if self.dev().is_none() && peer_side == ConnectionSide::Initiator {
            if self.is_ip_banned(peer_addr.ip()) {
                trace!("{CONTEXT} Rejected a connection request from banned IP '{}'", peer_addr.ip());
                return Err(ConnectError::BannedIp { ip: peer_addr.ip() });
            }
            let num_attempts = self.cache.insert_inbound_connection(peer_addr.ip(), CONNECTION_ATTEMPTS_SINCE_SECS);
            debug!("Number of connection attempts from '{}': {}", peer_addr.ip(), num_attempts);
            if num_attempts > MAX_CONNECTION_ATTEMPTS {
                self.update_ip_ban(peer_addr.ip());
                trace!("{CONTEXT} Rejected a consecutive connection request from IP '{}'", peer_addr.ip());
                return Err(ConnectError::other(anyhow!("'{}' appears to be spamming connections", peer_addr.ip())));
            }
        }
        let stream = self.borrow_stream(&mut connection);
        // For inbound connections (peer initiated) the peer's listener address is
        // unknown until its challenge request arrives; for outbound connections it
        // is simply the address we dialed.
        let mut listener_addr = if peer_side == ConnectionSide::Initiator {
            debug!("{CONTEXT} Received a connection request from '{peer_addr}'");
            None
        } else {
            debug!("{CONTEXT} Shaking hands with {peer_addr}...");
            Some(peer_addr)
        };
        let restrictions_id = self.ledger.latest_restrictions_id();
        // Run the side-appropriate half of the challenge protocol.
        let handshake_result = if peer_side == ConnectionSide::Responder {
            self.handshake_inner_initiator(peer_addr, restrictions_id, stream).await
        } else {
            self.handshake_inner_responder(peer_addr, &mut listener_addr, restrictions_id, stream).await
        };
        if let Some(addr) = listener_addr {
            match handshake_result {
                Ok(ref cr) => {
                    // Bootstrap peers are tagged so they can later be polled for
                    // validator lists (see `OnConnect`).
                    let node_type = if bootstrap_peers::<N>(self.is_dev()).contains(&addr) {
                        NodeType::BootstrapClient
                    } else {
                        NodeType::Validator
                    };
                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
                        self.resolver.write().insert_peer(addr, peer_addr, Some(cr.address));
                        peer.upgrade_to_connected(
                            peer_addr,
                            cr.listener_port,
                            cr.address,
                            node_type,
                            cr.version,
                            cr.snarkos_sha,
                            ConnectionMode::Gateway,
                        );
                    }
                    info!("{CONTEXT} Connected to '{addr}'");
                }
                Err(error) => {
                    // Roll the peer back to candidate state if it was still connecting.
                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
                        if peer.is_connecting() {
                            peer.downgrade_to_candidate(addr);
                        }
                    }
                    return Err(error);
                }
            }
        }
        Ok(connection)
    }
}
/// Awaits the next event on `$framed` and returns its payload if it matches
/// `$event_ty`; otherwise aborts the handshake with a descriptive error.
macro_rules! expect_event {
    ($event_ty:path, $framed:expr, $peer_addr:expr) => {
        match $framed.try_next().await? {
            Some($event_ty(data)) => {
                trace!("{CONTEXT} Received '{}' from '{}'", data.name(), $peer_addr);
                data
            }
            // The peer deliberately disconnected mid-handshake.
            Some(Event::Disconnect($crate::events::Disconnect { reason })) => {
                return Err(ConnectError::other(format!("'{}' disconnected with reason \"{reason}\"", $peer_addr)));
            }
            // Any other event is a protocol violation.
            Some(ty) => {
                return Err(ConnectError::other(format!(
                    "'{}' did not follow the handshake protocol: received {:?} instead of {}",
                    $peer_addr,
                    ty.name(),
                    // Fix: the macro's bound metavariable is `$event_ty`; the previous
                    // `$msg_ty` does not exist and fails at macro expansion.
                    stringify!($event_ty),
                )));
            }
            // The stream ended without an event.
            None => return Err(ConnectError::IoError(io::ErrorKind::BrokenPipe.into())),
        }
    };
}
/// Sends a single handshake `event` to `peer_addr` over the framed codec,
/// logging it at trace level first.
async fn send_event<N: Network>(
    framed: &mut Framed<&mut TcpStream, EventCodec<N>>,
    peer_addr: SocketAddr,
    event: Event<N>,
) -> io::Result<()> {
    trace!("{CONTEXT} Sending '{}' to '{peer_addr}'", event.name());
    framed.send(event).await
}
impl<N: Network> Gateway<N> {
    /// Runs the initiator side of the challenge handshake:
    /// send our challenge request, verify the peer's response and counter-request,
    /// then answer the peer's challenge.
    ///
    /// Returns the peer's `ChallengeRequest` on success so the caller can read
    /// its listener port, address, version, and commit SHA.
    async fn handshake_inner_initiator<'a>(
        &'a self,
        peer_addr: SocketAddr,
        restrictions_id: Field<N>,
        stream: &'a mut TcpStream,
    ) -> Result<ChallengeRequest<N>, ConnectError> {
        self.add_connecting_peer(peer_addr)?;
        let mut framed = Framed::new(stream, EventCodec::<N>::handshake());
        let rng = &mut rand::rngs::OsRng;
        let our_nonce = rng.r#gen();
        let current_block_height = self.ledger.latest_block_height();
        // NOTE(review): assumes a consensus version is defined for the current height.
        let consensus_version = N::CONSENSUS_VERSION(current_block_height).unwrap();
        // Share our repo commit SHA in dev mode, or from consensus V12 onward,
        // provided the build knows its own commit hash.
        let snarkos_sha = match (self.is_dev(), consensus_version >= ConsensusVersion::V12, get_repo_commit_hash()) {
            (true, _, Some(sha)) => Some(sha),
            (_, true, Some(sha)) => Some(sha),
            _ => None,
        };
        let our_request = ChallengeRequest::new(self.local_ip().port(), self.account.address(), our_nonce, snarkos_sha);
        send_event(&mut framed, peer_addr, Event::ChallengeRequest(our_request)).await?;
        // The peer replies with its response to our challenge, then its own request.
        let peer_response = expect_event!(Event::ChallengeResponse, framed, peer_addr);
        let peer_request = expect_event!(Event::ChallengeRequest, framed, peer_addr);
        // Verify the peer's signature over our nonce before anything else.
        if let Some(reason) = self
            .verify_challenge_response(peer_addr, peer_request.address, peer_response, restrictions_id, our_nonce)
            .await
        {
            send_event(&mut framed, peer_addr, reason.into()).await?;
            return Err(ConnectError::application(reason));
        }
        // Then vet the peer's own challenge request (version, authorization, etc.).
        if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
            send_event(&mut framed, peer_addr, reason.into()).await?;
            return Err(reason.into_connect_error(peer_addr));
        }
        // Sign the peer's nonce (bound to a fresh response nonce) and reply.
        let response_nonce: u64 = rng.r#gen();
        let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
        let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
            return Err(ConnectError::other(anyhow!("Failed to sign the challenge request nonce")));
        };
        let our_response =
            ChallengeResponse { restrictions_id, signature: Data::Object(our_signature), nonce: response_nonce };
        send_event(&mut framed, peer_addr, Event::ChallengeResponse(our_response)).await?;
        Ok(peer_request)
    }

    /// Runs the responder side of the challenge handshake:
    /// receive and vet the peer's challenge request, answer it, then issue our own
    /// challenge and verify the peer's response.
    ///
    /// `peer_ip` is populated with the peer's listener address as soon as it is
    /// known, so the caller can roll back pool state even on failure.
    async fn handshake_inner_responder<'a>(
        &'a self,
        peer_addr: SocketAddr,
        peer_ip: &mut Option<SocketAddr>,
        restrictions_id: Field<N>,
        stream: &'a mut TcpStream,
    ) -> Result<ChallengeRequest<N>, ConnectError> {
        let mut framed = Framed::new(stream, EventCodec::<N>::handshake());
        let peer_request = expect_event!(Event::ChallengeRequest, framed, peer_addr);
        // Reject self-connections by Aleo address.
        if self.account.address() == peer_request.address {
            return Err(ConnectError::SelfConnect { address: peer_addr });
        }
        // Derive the peer's listener address from its advertised listener port.
        *peer_ip = Some(SocketAddr::new(peer_addr.ip(), peer_request.listener_port));
        let peer_ip = peer_ip.unwrap();
        if let Err(reason) = self.ensure_peer_is_allowed(peer_ip) {
            send_event(&mut framed, peer_addr, reason.into()).await?;
            return Err(reason.into_connect_error(peer_addr));
        }
        self.add_connecting_peer(peer_ip)?;
        if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
            send_event(&mut framed, peer_addr, reason.into()).await?;
            return Err(reason.into_connect_error(peer_addr));
        }
        // Sign the peer's nonce (bound to a fresh response nonce) and reply.
        let rng = &mut rand::rngs::OsRng;
        let response_nonce: u64 = rng.r#gen();
        let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
        let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
            return Err(ConnectError::other(anyhow!("Failed to sign the challenge request nonce")));
        };
        let our_response =
            ChallengeResponse { restrictions_id, signature: Data::Object(our_signature), nonce: response_nonce };
        send_event(&mut framed, peer_addr, Event::ChallengeResponse(our_response)).await?;
        // Now issue our own challenge request.
        let our_nonce = rng.r#gen();
        let current_block_height = self.ledger.latest_block_height();
        // NOTE(review): assumes a consensus version is defined for the current height.
        let consensus_version = N::CONSENSUS_VERSION(current_block_height).unwrap();
        // Share our repo commit SHA in dev mode, or from consensus V12 onward,
        // provided the build knows its own commit hash.
        let snarkos_sha = match (self.is_dev(), consensus_version >= ConsensusVersion::V12, get_repo_commit_hash()) {
            (true, _, Some(sha)) => Some(sha),
            (_, true, Some(sha)) => Some(sha),
            _ => None,
        };
        let our_request = ChallengeRequest::new(self.local_ip().port(), self.account.address(), our_nonce, snarkos_sha);
        send_event(&mut framed, peer_addr, Event::ChallengeRequest(our_request)).await?;
        let peer_response = expect_event!(Event::ChallengeResponse, framed, peer_addr);
        // Finally, verify the peer's signature over our nonce.
        if let Some(reason) = self
            .verify_challenge_response(peer_addr, peer_request.address, peer_response, restrictions_id, our_nonce)
            .await
        {
            send_event(&mut framed, peer_addr, reason.into()).await?;
            Err(reason.into_connect_error(peer_addr))
        } else {
            Ok(peer_request)
        }
    }

    /// Vets a peer's challenge request, returning the reason to disconnect if it
    /// fails any check (version, trust policy, authorization, duplicate address);
    /// `None` means the request is acceptable.
    #[must_use]
    fn verify_challenge_request(&self, peer_addr: SocketAddr, event: &ChallengeRequest<N>) -> Option<DisconnectReason> {
        let &ChallengeRequest { version, listener_port, address, nonce: _, ref snarkos_sha } = event;
        // Log how the peer's commit SHA compares to ours (diagnostics only).
        log_repo_sha_comparison(peer_addr, snarkos_sha, CONTEXT);
        let listener_addr = SocketAddr::new(peer_addr.ip(), listener_port);
        // Reject peers running an event protocol older than ours.
        if version < Event::<N>::VERSION {
            return Some(DisconnectReason::OutdatedClientVersion);
        }
        // Enforce the trusted-peers-only policy when enabled.
        if self.trusted_peers_only && !self.is_trusted(listener_addr) {
            warn!("{CONTEXT} Dropping '{peer_addr}' for being an untrusted validator ({address})");
            return Some(DisconnectReason::NoExternalPeersAllowed);
        }
        // Non-bootstrap peers must be authorized validators.
        if !bootstrap_peers::<N>(self.dev().is_some()).contains(&listener_addr) {
            if !self.is_authorized_validator_address(address) {
                return Some(DisconnectReason::UnauthorizedValidator);
            }
        }
        // Reject a second connection claiming an already-connected Aleo address.
        if self.is_connected_address(address) {
            return Some(DisconnectReason::AlreadyConnectedToAleoAddress);
        }
        None
    }

    /// Vets a peer's challenge response: the restrictions ID must match ours and
    /// the signature must verify over our nonce concatenated with the peer's
    /// response nonce. Returns the reason to disconnect on failure, `None` on success.
    #[must_use]
    async fn verify_challenge_response(
        &self,
        peer_addr: SocketAddr,
        peer_address: Address<N>,
        response: ChallengeResponse<N>,
        expected_restrictions_id: Field<N>,
        expected_nonce: u64,
    ) -> Option<DisconnectReason> {
        let ChallengeResponse { restrictions_id, signature, nonce } = response;
        if restrictions_id != expected_restrictions_id {
            warn!("{CONTEXT} Handshake with '{peer_addr}' failed (incorrect restrictions ID)");
            return Some(DisconnectReason::InvalidChallengeResponse);
        }
        // Deserialization is CPU-bound, so it runs on a blocking thread.
        let Ok(signature) = spawn_blocking!(signature.deserialize_blocking()) else {
            warn!("{CONTEXT} Handshake with '{peer_addr}' failed (cannot deserialize the signature)");
            return Some(DisconnectReason::InvalidChallengeResponse);
        };
        if !signature.verify_bytes(&peer_address, &[expected_nonce.to_le_bytes(), nonce.to_le_bytes()].concat()) {
            warn!("{CONTEXT} Handshake with '{peer_addr}' failed (invalid signature)");
            return Some(DisconnectReason::InvalidChallengeResponse);
        }
        None
    }
}
#[cfg(test)]
mod prop_tests {
    use crate::{
        Gateway,
        MAX_WORKERS,
        MEMORY_POOL_PORT,
        Worker,
        helpers::{Storage, init_primary_channels, init_worker_channels},
    };
    use snarkos_account::Account;
    use snarkos_node_bft_ledger_service::MockLedgerService;
    use snarkos_node_bft_storage_service::BFTMemoryService;
    use snarkos_node_network::PeerPoolHandling;
    use snarkos_node_tcp::P2P;
    use snarkos_utilities::NodeDataDir;
    use snarkvm::{
        ledger::{
            committee::{
                Committee,
                prop_tests::{CommitteeContext, ValidatorSet},
                test_helpers::sample_committee_for_round_and_members,
            },
            narwhal::{BatchHeader, batch_certificate::test_helpers::sample_batch_certificate_for_round},
        },
        prelude::{MainnetV0, PrivateKey},
        utilities::TestRng,
    };

    use indexmap::{IndexMap, IndexSet};
    use proptest::{
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any, any_with},
        sample::Selector,
    };
    use std::{
        fmt::{Debug, Formatter},
        net::{IpAddr, Ipv4Addr, SocketAddr},
        sync::Arc,
    };
    use test_strategy::proptest;

    type CurrentNetwork = MainnetV0;

    /// Minimal `Debug` so proptest can report failing `Gateway` inputs.
    impl Debug for Gateway<CurrentNetwork> {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            f.debug_tuple("Gateway").field(&self.account.address()).field(&self.tcp.config()).finish()
        }
    }

    /// Arbitrary gateway addressing mode: a dev port offset or a prod socket address.
    #[derive(Debug, test_strategy::Arbitrary)]
    enum GatewayAddress {
        Dev(u8),
        Prod(Option<SocketAddr>),
    }

    impl GatewayAddress {
        /// Returns the socket address for prod mode, `None` in dev mode.
        fn ip(&self) -> Option<SocketAddr> {
            if let GatewayAddress::Prod(ip) = self {
                return *ip;
            }
            None
        }

        /// Returns the port offset for dev mode, `None` in prod mode.
        fn port(&self) -> Option<u16> {
            if let GatewayAddress::Dev(port) = self {
                return Some(*port as u16);
            }
            None
        }
    }

    impl Arbitrary for Gateway<CurrentNetwork> {
        type Parameters = ();
        type Strategy = BoxedStrategy<Gateway<CurrentNetwork>>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_valid_dev_gateway()
                .prop_map(|(storage, _, private_key, address)| {
                    Gateway::new(
                        Account::try_from(private_key).unwrap(),
                        storage.clone(),
                        storage.ledger().clone(),
                        address.ip(),
                        &[],
                        false,
                        NodeDataDir::new_test(None),
                        address.port(),
                    )
                    .unwrap()
                })
                .boxed()
        }
    }

    type GatewayInput = (Storage<CurrentNetwork>, CommitteeContext, PrivateKey<CurrentNetwork>, GatewayAddress);

    /// Strategy producing storage, committee, a committee member's key, and a dev address.
    fn any_valid_dev_gateway() -> BoxedStrategy<GatewayInput> {
        (any::<CommitteeContext>(), any::<Selector>())
            .prop_flat_map(|(context, account_selector)| {
                let CommitteeContext(_, ValidatorSet(validators)) = context.clone();
                (
                    any_with::<Storage<CurrentNetwork>>(context.clone()),
                    Just(context),
                    Just(account_selector.select(validators)),
                    0u8..,
                )
                    .prop_map(|(a, b, c, d)| (a, b, c.private_key, GatewayAddress::Dev(d)))
            })
            .boxed()
    }

    /// Strategy producing storage, committee, a committee member's key, and a prod address.
    fn any_valid_prod_gateway() -> BoxedStrategy<GatewayInput> {
        (any::<CommitteeContext>(), any::<Selector>())
            .prop_flat_map(|(context, account_selector)| {
                let CommitteeContext(_, ValidatorSet(validators)) = context.clone();
                (
                    any_with::<Storage<CurrentNetwork>>(context.clone()),
                    Just(context),
                    Just(account_selector.select(validators)),
                    any::<Option<SocketAddr>>(),
                )
                    .prop_map(|(a, b, c, d)| (a, b, c.private_key, GatewayAddress::Prod(d)))
            })
            .boxed()
    }

    #[proptest]
    fn gateway_dev_initialization(#[strategy(any_valid_dev_gateway())] input: GatewayInput) {
        let (storage, _, private_key, dev) = input;
        let account = Account::try_from(private_key).unwrap();
        let gateway = Gateway::new(
            account.clone(),
            storage.clone(),
            storage.ledger().clone(),
            dev.ip(),
            &[],
            false,
            NodeDataDir::new_test(None),
            dev.port(),
        )
        .unwrap();
        // Dev gateways listen on localhost at MEMORY_POOL_PORT plus the dev offset.
        let tcp_config = gateway.tcp().config();
        assert_eq!(tcp_config.listener_ip, Some(IpAddr::V4(Ipv4Addr::LOCALHOST)));
        assert_eq!(tcp_config.desired_listening_port, Some(MEMORY_POOL_PORT + dev.port().unwrap()));

        let tcp_config = gateway.tcp().config();
        assert_eq!(tcp_config.max_connections, Committee::<CurrentNetwork>::max_committee_size() * 10);
        assert_eq!(gateway.account().address(), account.address());
    }

    #[proptest]
    fn gateway_prod_initialization(#[strategy(any_valid_prod_gateway())] input: GatewayInput) {
        let (storage, _, private_key, dev) = input;
        let account = Account::try_from(private_key).unwrap();
        let gateway = Gateway::new(
            account.clone(),
            storage.clone(),
            storage.ledger().clone(),
            dev.ip(),
            &[],
            false,
            NodeDataDir::new_test(None),
            dev.port(),
        )
        .unwrap();
        // Prod gateways honor an explicit socket address, else bind the defaults.
        let tcp_config = gateway.tcp().config();
        if let Some(socket_addr) = dev.ip() {
            assert_eq!(tcp_config.listener_ip, Some(socket_addr.ip()));
            assert_eq!(tcp_config.desired_listening_port, Some(socket_addr.port()));
        } else {
            assert_eq!(tcp_config.listener_ip, Some(IpAddr::V4(Ipv4Addr::UNSPECIFIED)));
            assert_eq!(tcp_config.desired_listening_port, Some(MEMORY_POOL_PORT));
        }

        let tcp_config = gateway.tcp().config();
        assert_eq!(tcp_config.max_connections, Committee::<CurrentNetwork>::max_committee_size() * 10);
        assert_eq!(gateway.account().address(), account.address());
    }

    #[proptest(async = "tokio")]
    async fn gateway_start(
        #[strategy(any_valid_dev_gateway())] input: GatewayInput,
        #[strategy(0..MAX_WORKERS)] workers_count: u8,
    ) {
        let (storage, committee, private_key, dev) = input;
        let committee = committee.0;
        let worker_storage = storage.clone();
        let account = Account::try_from(private_key).unwrap();
        let gateway = Gateway::new(
            account,
            storage.clone(),
            storage.ledger().clone(),
            dev.ip(),
            &[],
            false,
            NodeDataDir::new_test(None),
            dev.port(),
        )
        .unwrap();

        let (primary_sender, _) = init_primary_channels();
        // Spin up `workers_count` workers, each with its own channel and mock ledger.
        let (workers, worker_senders) = {
            let mut tx_workers = IndexMap::new();
            let mut workers = IndexMap::new();
            for id in 0..workers_count {
                let (tx_worker, rx_worker) = init_worker_channels();
                let ledger = Arc::new(MockLedgerService::new(committee.clone()));
                let worker =
                    Worker::new(id, Arc::new(gateway.clone()), worker_storage.clone(), ledger, Default::default())
                        .unwrap();
                worker.run(rx_worker);
                workers.insert(id, worker);
                tx_workers.insert(id, tx_worker);
            }
            (workers, tx_workers)
        };
        gateway.run(primary_sender, worker_senders, None).await;

        assert_eq!(
            gateway.local_ip(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), MEMORY_POOL_PORT + dev.port().unwrap())
        );
        assert_eq!(gateway.num_workers(), workers.len() as u8);
    }

    #[proptest]
    fn test_is_authorized_validator(#[strategy(any_valid_dev_gateway())] input: GatewayInput) {
        let rng = &mut TestRng::default();
        let current_round = 2;
        let committee_size = 4;
        let max_gc_rounds = BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64;
        let (_, _, private_key, dev) = input;
        let account = Account::try_from(private_key).unwrap();

        // The first `committee_size` certificates come from committee members.
        let mut certificates = IndexSet::new();
        for _ in 0..committee_size {
            certificates.insert(sample_batch_certificate_for_round(current_round, rng));
        }
        let addresses: Vec<_> = certificates.iter().map(|certificate| certificate.author()).collect();
        let committee = sample_committee_for_round_and_members(current_round, addresses, rng);
        // The next `committee_size` certificates are from non-members.
        for _ in 0..committee_size {
            certificates.insert(sample_batch_certificate_for_round(current_round, rng));
        }

        let ledger = Arc::new(MockLedgerService::new(committee.clone()));
        let storage = Storage::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds);
        let gateway = Gateway::new(
            account.clone(),
            storage.clone(),
            ledger.clone(),
            dev.ip(),
            &[],
            false,
            NodeDataDir::new_test(None),
            dev.port(),
        )
        .unwrap();
        for certificate in certificates.iter() {
            storage.testing_only_insert_certificate_testing_only(certificate.clone());
        }
        // Iterate in insertion order (IndexSet preserves it) instead of cloning the
        // whole set just to read its length and indexing each element.
        for (i, certificate) in certificates.iter().enumerate() {
            let is_authorized = gateway.is_authorized_validator_address(certificate.author());
            if i < committee_size {
                assert!(is_authorized);
            } else {
                assert!(!is_authorized);
            }
        }
    }
}