use std::{
pin::Pin,
task::{Context, Poll},
};
use futures::{FutureExt, Stream, StreamExt};
use indexmap::IndexMap;
use tokio::{
sync::broadcast,
time::{self, Instant},
};
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream, IntervalStream};
use zebra_chain::serialization::AtLeastOne;
use crate::{
constants::INVENTORY_ROTATION_INTERVAL,
protocol::{external::InventoryHash, internal::InventoryResponse},
BoxError, PeerSocketAddr,
};
use self::update::Update;
use InventoryResponse::*;
pub mod update;
#[cfg(test)]
mod tests;
/// The maximum number of inventory hashes held in the registry's current map.
///
/// This is also the maximum number of hashes kept from a single multi-hash
/// inventory change; any extra hashes are dropped.
pub const MAX_INV_PER_MAP: usize = 1_000;

/// The maximum number of peers whose status is held for each inventory hash.
pub const MAX_PEERS_PER_INV: usize = 70;
/// An inventory status whose available and missing variants carry the same
/// payload type `T`.
pub type InventoryStatus<T> = InventoryResponse<T, T>;
/// An inventory status change: a list of at least one inventory hash, plus the
/// address of the peer the status is recorded against.
pub type InventoryChange = InventoryStatus<(AtLeastOne<InventoryHash>, PeerSocketAddr)>;
/// An inventory status with no payload. It only records whether the status is
/// available or missing, and is the value stored per peer in the registry maps.
type InventoryMarker = InventoryStatus<()>;
/// A registry of the inventory statuses recorded for each peer, per inventory hash.
///
/// Statuses are kept for the current rotation interval and the previous one.
pub struct InventoryRegistry {
/// Statuses registered since the last rotation, keyed by inventory hash, then by peer.
current: IndexMap<InventoryHash, IndexMap<PeerSocketAddr, InventoryMarker>>,
/// Statuses from the previous rotation interval, with the same layout as `current`.
prev: IndexMap<InventoryHash, IndexMap<PeerSocketAddr, InventoryMarker>>,
/// The stream of inventory changes waiting to be registered.
inv_stream: Pin<
Box<dyn Stream<Item = Result<InventoryChange, BroadcastStreamRecvError>> + Send + 'static>,
>,
/// The timer whose ticks trigger rotation of `current` into `prev`.
interval: IntervalStream,
}
impl std::fmt::Debug for InventoryRegistry {
    /// Formats the registry's `current` and `prev` status maps.
    ///
    /// The inventory stream and the rotation timer are left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { current, prev, .. } = self;

        let mut builder = f.debug_struct("InventoryRegistry");
        builder.field("current", current);
        builder.field("prev", prev);
        builder.finish()
    }
}
impl InventoryChange {
pub fn new_available(hash: InventoryHash, peer: PeerSocketAddr) -> Self {
let bv = AtLeastOne::from_vec(vec![hash]).expect("bounded vec must fit");
InventoryStatus::Available((bv, peer))
}
#[allow(dead_code)]
pub fn new_missing(hash: InventoryHash, peer: PeerSocketAddr) -> Self {
let bv = AtLeastOne::from_vec(vec![hash]).expect("bounded vec must fit");
InventoryStatus::Missing((bv, peer))
}
pub fn new_available_multi<'a>(
hashes: impl IntoIterator<Item = &'a InventoryHash>,
peer: PeerSocketAddr,
) -> Option<Self> {
let mut hashes: Vec<InventoryHash> = hashes.into_iter().copied().collect();
hashes.truncate(MAX_INV_PER_MAP);
let hashes = hashes.try_into().ok();
hashes.map(|hashes| InventoryStatus::Available((hashes, peer)))
}
pub fn new_missing_multi<'a>(
hashes: impl IntoIterator<Item = &'a InventoryHash>,
peer: PeerSocketAddr,
) -> Option<Self> {
let mut hashes: Vec<InventoryHash> = hashes.into_iter().copied().collect();
hashes.truncate(MAX_INV_PER_MAP);
let hashes = hashes.try_into().ok();
hashes.map(|hashes| InventoryStatus::Missing((hashes, peer)))
}
}
impl<T> InventoryStatus<T> {
    /// Returns this status with its payload discarded.
    ///
    /// The result keeps only whether the status is available or missing.
    pub fn marker(&self) -> InventoryMarker {
        match self {
            Available(_) => Available(()),
            Missing(_) => Missing(()),
        }
    }

    /// Converts the payload with `f`, keeping the same available/missing variant.
    pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> InventoryStatus<U> {
        match self {
            Available(payload) => {
                let converted = f(payload);
                Available(converted)
            }
            Missing(payload) => {
                let converted = f(payload);
                Missing(converted)
            }
        }
    }
}
impl<T: Clone> InventoryStatus<T> {
    /// Returns a clone of the payload, whichever variant this status is.
    pub fn to_inner(&self) -> T {
        let payload = match self {
            Available(available) => available,
            Missing(missing) => missing,
        };

        payload.clone()
    }
}
impl InventoryRegistry {
    /// Returns a new inventory registry that registers the changes received on `inv_stream`.
    ///
    /// The registry rotates its inventory maps every [`INVENTORY_ROTATION_INTERVAL`].
    /// The first rotation happens one full interval after the registry is created.
    pub fn new(inv_stream: broadcast::Receiver<InventoryChange>) -> Self {
        let interval = INVENTORY_ROTATION_INTERVAL;

        // Schedule the first tick one interval from now, rather than immediately,
        // so the registry isn't rotated before it has had a chance to collect anything.
        let mut interval = tokio::time::interval_at(Instant::now() + interval, interval);
        interval.set_missed_tick_behavior(time::MissedTickBehavior::Burst);

        Self {
            current: Default::default(),
            prev: Default::default(),
            inv_stream: BroadcastStream::new(inv_stream).boxed(),
            interval: IntervalStream::new(interval),
        }
    }

    /// Returns the peers whose status for `hash` is available.
    ///
    /// This filters [`Self::status_peers`], so a current-interval status for a peer
    /// takes precedence over its previous-interval status.
    pub fn advertising_peers(&self, hash: InventoryHash) -> impl Iterator<Item = &PeerSocketAddr> {
        self.status_peers(hash)
            .filter_map(|addr_status| addr_status.available())
    }

    /// Returns the peers whose status for `hash` is missing.
    ///
    /// This filters [`Self::status_peers`], so a current-interval status for a peer
    /// takes precedence over its previous-interval status.
    #[allow(dead_code)]
    pub fn missing_peers(&self, hash: InventoryHash) -> impl Iterator<Item = &PeerSocketAddr> {
        self.status_peers(hash)
            .filter_map(|addr_status| addr_status.missing())
    }

    /// Returns the status of every peer known for `hash`, paired with the peer's address.
    ///
    /// Statuses from the current interval come first. They are followed by statuses from
    /// the previous interval, skipping any peer that also has a current status for `hash`.
    pub fn status_peers(
        &self,
        hash: InventoryHash,
    ) -> impl Iterator<Item = InventoryStatus<&PeerSocketAddr>> {
        let prev = self.prev.get(&hash);
        let current = self.current.get(&hash);

        // A peer's current status shadows its status from the previous interval.
        let prev = prev
            .into_iter()
            .flatten()
            .filter(move |(addr, _status)| !self.has_current_status(hash, **addr));
        let current = current.into_iter().flatten();

        current
            .chain(prev)
            .map(|(addr, status)| status.map(|()| addr))
    }

    /// Returns `true` if `addr` has a status for `hash` registered in the current interval.
    pub fn has_current_status(&self, hash: InventoryHash, addr: PeerSocketAddr) -> bool {
        self.current
            .get(&hash)
            .and_then(|current| current.get(&addr))
            .is_some()
    }

    /// Returns every hash with its per-peer statuses: the current interval's first,
    /// then the previous interval's.
    ///
    /// A hash that is present in both intervals is yielded twice.
    #[allow(dead_code)]
    pub fn status_hashes(
        &self,
    ) -> impl Iterator<Item = (&InventoryHash, &IndexMap<PeerSocketAddr, InventoryMarker>)> {
        self.current.iter().chain(self.prev.iter())
    }

    /// Returns an [`Update`] that mutably borrows this registry.
    #[allow(dead_code)]
    pub fn update(&mut self) -> Update<'_> {
        Update::new(self)
    }

    /// Rotates the registry if the rotation timer has ticked, then registers every
    /// inventory change that is already waiting in the inventory stream.
    ///
    /// Returns `Poll::Ready(Ok(()))` if the registry was rotated or at least one change
    /// was registered. Otherwise, returns `Poll::Pending`; in that case both the timer
    /// and the stream have been polled with `cx`.
    ///
    /// Lagged (dropped) inventory changes are counted in metrics and logged, but are
    /// not treated as an error.
    ///
    /// # Errors
    ///
    /// Returns `Poll::Ready(Err(_))` if the inventory stream has ended, because every
    /// sender on the broadcast channel has been dropped.
    pub fn poll_inventory(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
        let mut result = Poll::Pending;

        // The timer is polled once per call, so at most one rotation happens here,
        // even if several ticks were missed.
        if Pin::new(&mut self.interval).poll_next(cx).is_ready() {
            self.rotate();
            result = Poll::Ready(Ok(()));
        }

        // Drain every change that is already buffered, until the stream returns `Pending`.
        loop {
            let channel_result = self.inv_stream.next().poll_unpin(cx);

            match channel_result {
                Poll::Ready(Some(Ok(change))) => {
                    self.register(change);
                    result = Poll::Ready(Ok(()));
                }
                Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(count)))) => {
                    // This receiver fell behind, and `count` changes were skipped.
                    // That loses some inventory information, but it isn't fatal.
                    metrics::counter!("pool.inventory.dropped").increment(1);
                    metrics::counter!("pool.inventory.dropped.messages").increment(count);
                    info!(count, "dropped lagged inventory advertisements");
                }
                Poll::Ready(None) => {
                    // All senders have been dropped, so no more changes can ever arrive.
                    //
                    // Return immediately: a closed broadcast stream returns `None` again
                    // every time it is polled, so staying in this loop would spin forever.
                    return Poll::Ready(Err(broadcast::error::RecvError::Closed.into()));
                }
                Poll::Pending => {
                    break;
                }
            }
        }

        result
    }

    /// Records `change` in the current interval's inventory map.
    ///
    /// For each hash in `change`, the peer's status is set to the change's status,
    /// except that an existing missing status is never replaced by an available one.
    /// Afterwards, the hash's peer map is limited to [`MAX_PEERS_PER_INV`] entries, and
    /// the current map to [`MAX_INV_PER_MAP`] hashes, by removing the earliest-inserted entry.
    ///
    /// # Panics
    ///
    /// If `change` contains an inventory hash that isn't a `Block`, `Tx`, or `Wtx` hash.
    fn register(&mut self, change: InventoryChange) {
        let new_status = change.marker();
        let (invs, addr) = change.to_inner();

        for inv in invs {
            use InventoryHash::*;
            assert!(
                matches!(inv, Block(_) | Tx(_) | Wtx(_)),
                "unexpected inventory type: {inv:?} from peer: {addr:?}",
            );

            let hash_peers = self.current.entry(inv).or_default();

            if let Some(old_status) = hash_peers.get(&addr) {
                // Keep a missing status rather than overwriting it with an available one.
                if old_status.is_missing() && new_status.is_available() {
                    debug!(?new_status, ?old_status, ?addr, ?inv, "skipping new status");
                    continue;
                }

                debug!(
                    ?new_status,
                    ?old_status,
                    ?addr,
                    ?inv,
                    "keeping both new and old status"
                );
            }

            // `IndexMap::insert` replaces this peer's status in `current`, keeping the
            // peer's original position in the insertion order.
            let replaced_status = hash_peers.insert(addr, new_status);
            debug!(
                ?new_status,
                ?replaced_status,
                ?addr,
                ?inv,
                "inserted new status"
            );

            // Index 0 is the earliest-inserted entry, because `IndexMap` preserves
            // insertion order and `shift_remove_index` keeps the remaining order intact.
            if hash_peers.len() > MAX_PEERS_PER_INV {
                hash_peers.shift_remove_index(0);
            }

            if self.current.len() > MAX_INV_PER_MAP {
                self.current.shift_remove_index(0);
            }
        }
    }

    /// Moves the current interval's statuses into `prev`, discarding the old `prev`,
    /// and starts a new, empty current interval.
    fn rotate(&mut self) {
        self.prev = std::mem::take(&mut self.current);
    }
}