use crate::MAX_FETCH_TIMEOUT_IN_MS;
use snarkos_node_bft_ledger_service::LedgerService;
use snarkvm::{
console::network::{Network, consensus_config_value},
prelude::Result,
};
use anyhow::anyhow;
#[cfg(feature = "locktick")]
use locktick::parking_lot::RwLock;
#[cfg(not(feature = "locktick"))]
use parking_lot::RwLock;
use std::{
collections::{HashMap, HashSet},
hash::Hash,
net::SocketAddr,
sync::Arc,
};
use time::OffsetDateTime;
use tokio::sync::oneshot;
/// The age, in whole seconds, beyond which a pending callback is treated as expired and pruned.
///
/// Derived from `MAX_FETCH_TIMEOUT_IN_MS`, rounded up to the next whole second. A callback is kept
/// while `now - registered_at <= CALLBACK_EXPIRATION_IN_SECS`.
pub(crate) const CALLBACK_EXPIRATION_IN_SECS: i64 = MAX_FETCH_TIMEOUT_IN_MS.div_ceil(1000) as i64;
/// Returns the maximum number of redundant requests allowed for `round`.
///
/// The result is `1 + n / 3` (integer division). `n` is the number of members in the lookback
/// committee for `round`. If that committee cannot be retrieved, `n` falls back to the
/// `MAX_CERTIFICATES` consensus value at the ledger's latest block height.
///
/// NOTE(review): presumably `1 + n / 3` is chosen so that at least one targeted peer is honest
/// under the usual `n >= 3f + 1` fault assumption. Confirm against the callers.
///
/// # Errors
///
/// Returns an error if the committee lookup fails and `MAX_CERTIFICATES` cannot be obtained.
pub fn max_redundant_requests<N: Network>(ledger: Arc<dyn LedgerService<N>>, round: u64) -> Result<usize> {
    let committee_size = match ledger.get_committee_lookback_for_round(round) {
        Ok(committee) => committee.num_members(),
        Err(_) => {
            // Fall back to the upper bound on the committee size.
            let max_certificates = consensus_config_value!(N, MAX_CERTIFICATES, ledger.latest_block_height())
                .ok_or_else(|| anyhow!("Couldn't obtain MAX_CERTIFICATES"))?;
            max_certificates as usize
        }
    };
    Ok(1 + committee_size.saturating_div(3))
}
#[derive(Debug)]
pub struct Pending<T: PartialEq + Eq + Hash, V: Clone> {
pending: RwLock<HashMap<T, HashMap<SocketAddr, Vec<(oneshot::Sender<V>, i64, bool)>>>>,
}
impl<T: Copy + Clone + PartialEq + Eq + Hash, V: Clone> Default for Pending<T, V> {
fn default() -> Self {
Self::new()
}
}
impl<T: Copy + Clone + PartialEq + Eq + Hash, V: Clone> Pending<T, V> {
pub fn new() -> Self {
Self { pending: Default::default() }
}
pub fn is_empty(&self) -> bool {
self.pending.read().is_empty()
}
pub fn len(&self) -> usize {
self.pending.read().len()
}
pub fn contains(&self, item: impl Into<T>) -> bool {
self.pending.read().contains_key(&item.into())
}
pub fn contains_peer(&self, item: impl Into<T>, peer_ip: SocketAddr) -> bool {
self.pending.read().get(&item.into()).is_some_and(|peer_ips| peer_ips.contains_key(&peer_ip))
}
pub fn contains_peer_with_sent_request(&self, item: impl Into<T>, peer_ip: SocketAddr) -> bool {
self.pending.read().get(&item.into()).is_some_and(|peer_ips| {
peer_ips
.get(&peer_ip)
.map(|callbacks| callbacks.iter().any(|(_, _, request_sent)| *request_sent))
.unwrap_or(false)
})
}
pub fn get_peers(&self, item: impl Into<T>) -> Option<HashSet<SocketAddr>> {
self.pending.read().get(&item.into()).map(|map| map.keys().cloned().collect())
}
pub fn num_callbacks(&self, item: impl Into<T>) -> usize {
let item = item.into();
let now = OffsetDateTime::now_utc().unix_timestamp();
self.clear_expired_callbacks_for_item(now, item);
self.pending.read().get(&item).map_or(0, |peers| peers.values().fold(0, |acc, v| acc.saturating_add(v.len())))
}
pub fn num_sent_requests(&self, item: impl Into<T>) -> usize {
let item = item.into();
let now = OffsetDateTime::now_utc().unix_timestamp();
self.clear_expired_callbacks_for_item(now, item);
self.pending
.read()
.get(&item)
.map_or(0, |peers| peers.values().flatten().filter(|(_, _, request_sent)| *request_sent).count())
}
pub fn insert(
&self,
item: impl Into<T>,
peer_ip: SocketAddr,
callback: Option<(oneshot::Sender<V>, bool)>,
) -> bool {
let item = item.into();
let now = OffsetDateTime::now_utc().unix_timestamp();
let result = {
let mut pending = self.pending.write();
let entry = pending.entry(item).or_default();
let is_new_peer = !entry.contains_key(&peer_ip);
let peer_entry = entry.entry(peer_ip).or_default();
if let Some((callback, request_sent)) = callback {
peer_entry.push((callback, now, request_sent));
}
is_new_peer
};
self.clear_expired_callbacks_for_item(now, item);
result
}
pub fn remove(&self, item: impl Into<T>, callback_value: Option<V>) -> Option<HashSet<SocketAddr>> {
let item = item.into();
match self.pending.write().remove(&item) {
Some(callbacks) => {
let peer_ips = callbacks.keys().copied().collect();
if let Some(callback_value) = callback_value {
for (callback, _, _) in callbacks.into_values().flat_map(|callbacks| callbacks.into_iter()) {
callback.send(callback_value.clone()).ok();
}
}
Some(peer_ips)
}
None => None,
}
}
pub fn clear_expired_callbacks_for_item(&self, now: i64, item: impl Into<T>) {
let item = item.into();
let mut pending = self.pending.write();
if let Some(peer_map) = pending.get_mut(&item) {
for (_, callbacks) in peer_map.iter_mut() {
callbacks.retain(|(_, timestamp, _)| now - *timestamp <= CALLBACK_EXPIRATION_IN_SECS);
}
peer_map.retain(|_, callbacks| !callbacks.is_empty());
if peer_map.is_empty() {
pending.remove(&item);
}
}
}
pub fn clear_expired_callbacks(&self) {
let now = OffsetDateTime::now_utc().unix_timestamp();
let mut pending = self.pending.write();
pending.retain(|_, peer_map| {
for (_, callbacks) in peer_map.iter_mut() {
callbacks.retain(|(_, timestamp, _)| now - *timestamp <= CALLBACK_EXPIRATION_IN_SECS);
}
peer_map.retain(|_, callbacks| !callbacks.is_empty());
!peer_map.is_empty()
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    use snarkvm::{
        ledger::narwhal::TransmissionID,
        prelude::{Rng, TestRng},
    };

    use std::{thread, time::Duration};

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    const ITERATIONS: usize = 100;

    /// Samples a random solution transmission ID.
    fn sample_solution_id(rng: &mut TestRng) -> TransmissionID<CurrentNetwork> {
        TransmissionID::Solution(
            rng.r#gen::<u64>().into(),
            rng.r#gen::<<CurrentNetwork as Network>::TransmissionChecksum>(),
        )
    }

    #[test]
    fn test_pending() {
        let rng = &mut TestRng::default();

        // Start from an empty pending set.
        let pending = Pending::<TransmissionID<CurrentNetwork>, ()>::new();
        assert!(pending.is_empty());
        assert_eq!(pending.len(), 0);

        // Sample the items.
        let solution_id_1 = sample_solution_id(rng);
        let solution_id_2 = sample_solution_id(rng);
        let solution_id_3 = sample_solution_id(rng);
        let solution_id_4 = sample_solution_id(rng);

        // Initialize the peers.
        let addr_1 = SocketAddr::from(([127, 0, 0, 1], 1234));
        let addr_2 = SocketAddr::from(([127, 0, 0, 1], 2345));
        let addr_3 = SocketAddr::from(([127, 0, 0, 1], 3456));
        let addr_4 = SocketAddr::from(([127, 0, 0, 1], 4567));

        // Initialize the callbacks.
        let (callback_sender_1, _) = oneshot::channel();
        let (callback_sender_2, _) = oneshot::channel();
        let (callback_sender_3, _) = oneshot::channel();
        let (callback_sender_4, _) = oneshot::channel();

        // Insert the items. Only the fourth is registered without a sent request.
        assert!(pending.insert(solution_id_1, addr_1, Some((callback_sender_1, true))));
        assert!(pending.insert(solution_id_2, addr_2, Some((callback_sender_2, true))));
        assert!(pending.insert(solution_id_3, addr_3, Some((callback_sender_3, true))));
        assert!(pending.insert(solution_id_4, addr_4, Some((callback_sender_4, false))));
        assert_eq!(pending.len(), 4);
        assert!(!pending.is_empty());

        // Check the items that were registered with a sent request.
        for (id, peer) in [solution_id_1, solution_id_2, solution_id_3].into_iter().zip([addr_1, addr_2, addr_3]) {
            assert!(pending.contains(id));
            assert!(pending.contains_peer(id, peer));
            assert!(pending.contains_peer_with_sent_request(id, peer));
        }
        // The fourth item is pending for its peer, but without a sent request.
        assert!(pending.contains_peer(solution_id_4, addr_4));
        assert!(!pending.contains_peer_with_sent_request(solution_id_4, addr_4));

        // An unknown item is not pending.
        let unknown_id = sample_solution_id(rng);
        assert!(!pending.contains(unknown_id));

        // Check the peers.
        assert_eq!(pending.get_peers(solution_id_1), Some(HashSet::from([addr_1])));
        assert_eq!(pending.get_peers(solution_id_2), Some(HashSet::from([addr_2])));
        assert_eq!(pending.get_peers(solution_id_3), Some(HashSet::from([addr_3])));
        assert_eq!(pending.get_peers(solution_id_4), Some(HashSet::from([addr_4])));
        assert_eq!(pending.get_peers(unknown_id), None);

        // Remove the items.
        assert!(pending.remove(solution_id_1, None).is_some());
        assert!(pending.remove(solution_id_2, None).is_some());
        assert!(pending.remove(solution_id_3, None).is_some());
        assert!(pending.remove(solution_id_4, None).is_some());
        assert!(pending.remove(unknown_id, None).is_none());
        assert!(pending.is_empty());
    }

    #[test]
    fn test_remove_notifies_callbacks() {
        let rng = &mut TestRng::default();
        let pending = Pending::<TransmissionID<CurrentNetwork>, ()>::new();
        let solution_id = sample_solution_id(rng);

        let addr_1 = SocketAddr::from(([127, 0, 0, 1], 1234));
        let addr_2 = SocketAddr::from(([127, 0, 0, 1], 2345));

        let (callback_sender_1, mut callback_receiver_1) = oneshot::channel();
        let (callback_sender_2, mut callback_receiver_2) = oneshot::channel();
        let (callback_sender_3, mut callback_receiver_3) = oneshot::channel();

        // The first callback for a peer reports a new peer; a later one for the same peer does not.
        assert!(pending.insert(solution_id, addr_1, Some((callback_sender_1, true))));
        assert!(pending.insert(solution_id, addr_2, Some((callback_sender_2, false))));
        assert!(!pending.insert(solution_id, addr_1, Some((callback_sender_3, false))));
        assert_eq!(pending.num_callbacks(solution_id), 3);
        assert_eq!(pending.num_sent_requests(solution_id), 1);

        // Removing with a value reports both peers and notifies every registered callback.
        assert_eq!(pending.remove(solution_id, Some(())), Some(HashSet::from([addr_1, addr_2])));
        assert!(callback_receiver_1.try_recv().is_ok());
        assert!(callback_receiver_2.try_recv().is_ok());
        assert!(callback_receiver_3.try_recv().is_ok());
        assert!(pending.is_empty());
    }

    #[test]
    fn test_expired_callbacks() {
        let rng = &mut TestRng::default();

        // Start from an empty pending set.
        let pending = Pending::<TransmissionID<CurrentNetwork>, ()>::new();
        assert!(pending.is_empty());
        assert_eq!(pending.len(), 0);

        let solution_id_1 = sample_solution_id(rng);

        // Initialize the peers.
        let addr_1 = SocketAddr::from(([127, 0, 0, 1], 1234));
        let addr_2 = SocketAddr::from(([127, 0, 0, 1], 2345));
        let addr_3 = SocketAddr::from(([127, 0, 0, 1], 3456));

        // Initialize the callbacks.
        let (callback_sender_1, _) = oneshot::channel();
        let (callback_sender_2, _) = oneshot::channel();
        let (callback_sender_3, _) = oneshot::channel();

        // Register two callbacks now and a third just before the first two expire.
        assert!(pending.insert(solution_id_1, addr_1, Some((callback_sender_1, true))));
        assert!(pending.insert(solution_id_1, addr_2, Some((callback_sender_2, true))));
        thread::sleep(Duration::from_secs(CALLBACK_EXPIRATION_IN_SECS as u64 - 1));
        assert!(pending.insert(solution_id_1, addr_3, Some((callback_sender_3, true))));
        assert_eq!(pending.num_callbacks(solution_id_1), 3);

        // The first two callbacks expire; the third remains.
        thread::sleep(Duration::from_secs(2));
        assert_eq!(pending.num_callbacks(solution_id_1), 1);

        // Eventually the third callback expires too.
        thread::sleep(Duration::from_secs(CALLBACK_EXPIRATION_IN_SECS as u64));
        assert_eq!(pending.num_callbacks(solution_id_1), 0);
    }

    #[test]
    fn test_num_sent_requests() {
        let rng = &mut TestRng::default();
        let pending = Pending::<TransmissionID<CurrentNetwork>, ()>::new();

        for _ in 0..ITERATIONS {
            let solution_id = sample_solution_id(rng);
            let mut expected_num_sent_requests = 0;

            // Register one callback per peer, each randomly flagged as sent or not.
            for i in 0..ITERATIONS {
                let addr = SocketAddr::from(([127, 0, 0, 1], i as u16));
                let (callback_sender, _) = oneshot::channel();
                let is_sent_request = rng.r#gen();
                if is_sent_request {
                    expected_num_sent_requests += 1;
                }
                assert!(pending.insert(solution_id, addr, Some((callback_sender, is_sent_request))));
            }

            assert_eq!(pending.num_sent_requests(solution_id), expected_num_sent_requests);
        }
    }

    #[test]
    fn test_expired_items() {
        let rng = &mut TestRng::default();

        // Start from an empty pending set.
        let pending = Pending::<TransmissionID<CurrentNetwork>, ()>::new();
        assert!(pending.is_empty());
        assert_eq!(pending.len(), 0);

        let solution_id_1 = sample_solution_id(rng);
        let solution_id_2 = sample_solution_id(rng);

        // Initialize the peers.
        let addr_1 = SocketAddr::from(([127, 0, 0, 1], 1234));
        let addr_2 = SocketAddr::from(([127, 0, 0, 1], 2345));
        let addr_3 = SocketAddr::from(([127, 0, 0, 1], 3456));

        // Initialize the callbacks.
        let (callback_sender_1, _) = oneshot::channel();
        let (callback_sender_2, _) = oneshot::channel();
        let (callback_sender_3, _) = oneshot::channel();

        assert!(pending.insert(solution_id_1, addr_1, Some((callback_sender_1, true))));
        assert!(pending.insert(solution_id_1, addr_2, Some((callback_sender_2, true))));
        assert!(pending.insert(solution_id_2, addr_3, Some((callback_sender_3, true))));
        assert_eq!(pending.num_callbacks(solution_id_1), 2);
        assert_eq!(pending.num_callbacks(solution_id_2), 1);
        assert_eq!(pending.len(), 2);

        // After the expiration window, a global sweep removes every item.
        thread::sleep(Duration::from_secs(CALLBACK_EXPIRATION_IN_SECS as u64 + 1));
        pending.clear_expired_callbacks();
        assert_eq!(pending.num_callbacks(solution_id_1), 0);
        assert_eq!(pending.num_callbacks(solution_id_2), 0);
        assert!(pending.is_empty());
    }
}
#[cfg(test)]
mod prop_tests {
    use super::*;

    use test_strategy::{Arbitrary, proptest};

    /// A minimal `Copy` key type for exercising [`Pending`] in property tests.
    #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
    pub struct Item {
        pub id: usize,
    }

    /// Property-test input: the number of distinct items to insert.
    #[derive(Arbitrary, Clone, Debug)]
    pub struct PendingInput {
        #[strategy(1..5_000usize)]
        pub count: usize,
    }

    /// Returns the loopback peer address paired with item `index`. The port is `index as u16`.
    fn peer_for(index: usize) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], index as u16))
    }

    impl PendingInput {
        /// Builds a [`Pending`] holding items `0..count`. Each item is pending for its own peer
        /// with a single callback flagged as sent.
        pub fn to_pending(&self) -> Pending<Item, ()> {
            let pending = Pending::<Item, ()>::new();
            (0..self.count).for_each(|id| {
                pending.insert(Item { id }, peer_for(id), Some((oneshot::channel().0, true)));
            });
            pending
        }
    }

    #[proptest]
    fn test_pending_proptest(input: PendingInput) {
        let pending = input.to_pending();
        assert_eq!(pending.len(), input.count);
        assert!(!pending.is_empty());

        // An item outside the inserted range is unknown.
        let missing = Item { id: input.count + 1 };
        assert!(!pending.contains(missing));
        assert_eq!(pending.get_peers(missing), None);
        assert!(pending.remove(missing, None).is_none());

        // Every inserted item is pending for exactly its own peer and can be removed.
        for id in 0..input.count {
            let item = Item { id };
            let peer_ip = peer_for(id);
            assert!(pending.contains(item));
            assert!(pending.contains_peer(item, peer_ip));
            assert_eq!(pending.get_peers(item), Some(HashSet::from([peer_ip])));
            assert!(pending.remove(item, None).is_some());
        }
        assert!(pending.is_empty());
    }
}