use std::borrow::Borrow;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use anyhow::Result;
use arc_swap::ArcSwapOption;
use bytes::{Bytes, BytesMut};
use indexmap::{IndexMap, IndexSet};
use parking_lot::{Mutex, RwLock, RwLockReadGuard};
use rand::Rng;
use tokio::sync::Notify;
use tycho_util::futures::BoxFutureOrNoop;
use tycho_util::{FastDashSet, FastHasherState};
use crate::dht::{PeerResolver, PeerResolverHandle};
use crate::network::Network;
use crate::overlay::OverlayId;
use crate::overlay::metrics::Metrics;
use crate::proto::overlay::{PublicEntry, PublicEntryToSign, rpc};
use crate::types::{BoxService, PeerId, Request, Response, Service, ServiceExt, ServiceRequest};
use crate::util::NetworkExt;
/// Configures and constructs a [`PublicOverlay`].
///
/// Obtained from [`PublicOverlay::builder`]; finished with
/// [`PublicOverlayBuilder::build`].
pub struct PublicOverlayBuilder {
    overlay_id: OverlayId,
    min_capacity: usize,
    entry_ttl: Duration,
    banned_peer_ids: FastDashSet<PeerId>,
    peer_resolver: Option<PeerResolver>,
    name: Option<&'static str>,
}

impl PublicOverlayBuilder {
    /// Sets the entry count up to which entries received from untrusted
    /// sources are still accepted.
    pub fn with_min_capacity(self, min_capacity: usize) -> Self {
        Self {
            min_capacity,
            ..self
        }
    }

    /// Sets the lifetime of overlay entries.
    ///
    /// The value is stored in whole seconds, saturating at `u32::MAX`.
    pub fn with_entry_ttl(self, entry_ttl: Duration) -> Self {
        Self { entry_ttl, ..self }
    }

    /// Adds the given peers to the initial ban list.
    pub fn with_banned_peers<I>(mut self, banned_peers: I) -> Self
    where
        I: IntoIterator,
        I::Item: Borrow<PeerId>,
    {
        let peer_ids = banned_peers.into_iter().map(|item| *item.borrow());
        self.banned_peer_ids.extend(peer_ids);
        self
    }

    /// Sets the resolver that newly added entries are registered with.
    pub fn with_peer_resolver(self, peer_resolver: PeerResolver) -> Self {
        Self {
            peer_resolver: Some(peer_resolver),
            ..self
        }
    }

    /// Sets the label for this overlay's `tycho_public_overlay` metrics.
    ///
    /// Unnamed overlays use `Metrics::default()`.
    pub fn named(self, name: &'static str) -> Self {
        Self {
            name: Some(name),
            ..self
        }
    }

    /// Consumes the builder and creates the overlay served by `service`.
    pub fn build<S>(self, service: S) -> PublicOverlay
    where
        S: Send + Sync + 'static,
        S: Service<ServiceRequest, QueryResponse = Response>,
    {
        // Maximum number of not-yet-known senders buffered for resolution.
        const UNRESOLVED_QUEUE_CAPACITY: usize = 5;

        let Self {
            overlay_id,
            min_capacity,
            entry_ttl,
            banned_peer_ids,
            peer_resolver,
            name,
        } = self;

        // Serialized once and prepended to every outgoing request body.
        let request_prefix = tl_proto::serialize(rpc::Prefix {
            overlay_id: overlay_id.as_bytes(),
        });
        let entry_ttl_sec = u32::try_from(entry_ttl.as_secs()).unwrap_or(u32::MAX);
        let metrics = match name {
            Some(label) => Metrics::new("tycho_public_overlay", label),
            None => Metrics::default(),
        };

        let inner = Inner {
            overlay_id,
            min_capacity,
            entry_ttl_sec,
            peer_resolver,
            entries: RwLock::new(PublicOverlayEntries {
                items: Default::default(),
            }),
            entries_added: Notify::new(),
            entries_changed: Notify::new(),
            entries_removed: Notify::new(),
            entry_count: AtomicUsize::new(0),
            own_signed_entry: Default::default(),
            unknown_peers_queue: UnknownPeersQueue::with_capacity(UNRESOLVED_QUEUE_CAPACITY),
            banned_peer_ids,
            service: service.boxed(),
            request_prefix: request_prefix.into_boxed_slice(),
            metrics,
        };

        PublicOverlay {
            inner: Arc::new(inner),
        }
    }
}
#[derive(Clone)]
#[repr(transparent)]
pub struct PublicOverlay {
inner: Arc<Inner>,
}
impl PublicOverlay {
pub fn builder(overlay_id: OverlayId) -> PublicOverlayBuilder {
PublicOverlayBuilder {
overlay_id,
min_capacity: 100,
entry_ttl: Duration::from_secs(3600),
banned_peer_ids: Default::default(),
peer_resolver: None,
name: None,
}
}
#[inline]
pub fn overlay_id(&self) -> &OverlayId {
&self.inner.overlay_id
}
pub fn entry_ttl_sec(&self) -> u32 {
self.inner.entry_ttl_sec
}
pub fn peer_resolver(&self) -> &Option<PeerResolver> {
&self.inner.peer_resolver
}
pub fn unknown_peers_queue(&self) -> &UnknownPeersQueue {
&self.inner.unknown_peers_queue
}
pub async fn query(
&self,
network: &Network,
peer_id: &PeerId,
mut request: Request,
) -> Result<Response> {
self.inner.metrics.record_tx(request.body.len());
self.prepend_prefix_to_body(&mut request.body);
network.query(peer_id, request).await
}
pub async fn send(
&self,
network: &Network,
peer_id: &PeerId,
mut request: Request,
) -> Result<()> {
self.inner.metrics.record_tx(request.body.len());
self.prepend_prefix_to_body(&mut request.body);
network.send(peer_id, request).await
}
pub fn ban_peer(&self, peer_id: PeerId) -> bool {
self.inner.banned_peer_ids.insert(peer_id)
}
pub fn unban_peer(&self, peer_id: &PeerId) -> bool {
self.inner.banned_peer_ids.remove(peer_id).is_some()
}
pub fn read_entries(&self) -> PublicOverlayEntriesReadGuard<'_> {
PublicOverlayEntriesReadGuard {
entries: self.inner.entries.read(),
}
}
pub fn entires_added(&self) -> &Notify {
&self.inner.entries_added
}
pub fn entries_changed(&self) -> &Notify {
&self.inner.entries_changed
}
pub fn entries_removed(&self) -> &Notify {
&self.inner.entries_removed
}
pub fn own_signed_entry(&self) -> Option<Arc<PublicEntry>> {
self.inner.own_signed_entry.load_full()
}
pub(crate) fn set_own_signed_entry(&self, entry: Arc<PublicEntry>) {
self.inner.own_signed_entry.store(Some(entry));
}
pub(crate) fn handle_query(&self, req: ServiceRequest) -> BoxFutureOrNoop<Option<Response>> {
self.inner.metrics.record_rx(req.body.len());
if self.check_peer_id(&req.metadata.peer_id) {
BoxFutureOrNoop::future(self.inner.service.on_query(req))
} else {
BoxFutureOrNoop::Noop
}
}
pub(crate) fn handle_message(&self, req: ServiceRequest) -> BoxFutureOrNoop<()> {
self.inner.metrics.record_rx(req.body.len());
if self.check_peer_id(&req.metadata.peer_id) {
BoxFutureOrNoop::future(self.inner.service.on_message(req))
} else {
BoxFutureOrNoop::Noop
}
}
fn check_peer_id(&self, peer_id: &PeerId) -> bool {
if self.inner.banned_peer_ids.contains(peer_id) {
return false;
}
if !self.inner.unknown_peers_queue.is_full() && !self.inner.entries.read().contains(peer_id)
{
if self.inner.unknown_peers_queue.push(peer_id) {
tracing::debug!(%peer_id, "found new unknown peer to resolve");
}
}
true
}
pub(crate) fn add_untrusted_entries(
&self,
local_id: &PeerId,
entries: &[Arc<PublicEntry>],
now: u32,
) -> bool {
if entries.is_empty() {
return false;
}
let this = self.inner.as_ref();
let to_add = entries.len();
let mut entry_count = this.entry_count.load(Ordering::Acquire);
let to_add = loop {
let to_add = match this.min_capacity.checked_sub(entry_count) {
Some(capacity) if capacity > 0 => std::cmp::min(to_add, capacity),
_ => return false,
};
let res = this.entry_count.compare_exchange_weak(
entry_count,
entry_count + to_add,
Ordering::Release,
Ordering::Acquire,
);
match res {
Ok(_) => break to_add,
Err(n) => entry_count = n,
}
};
let mut is_valid = vec![false; entries.len()];
let mut has_valid = false;
for (entry, is_valid) in std::iter::zip(entries, is_valid.iter_mut()) {
if entry.is_expired(now, this.entry_ttl_sec)
|| self.inner.banned_peer_ids.contains(&entry.peer_id)
|| entry.peer_id == local_id
{
continue;
}
let Some(pubkey) = entry.peer_id.as_public_key() else {
continue;
};
if !pubkey.verify_tl(
PublicEntryToSign {
overlay_id: this.overlay_id.as_bytes(),
peer_id: &entry.peer_id,
created_at: entry.created_at,
},
&entry.signature,
) {
continue;
}
*is_valid = true;
has_valid = true;
}
let mut added = 0;
let mut changed = false;
if has_valid {
let mut stored = this.entries.write();
for (entry, is_valid) in std::iter::zip(entries, is_valid) {
if !is_valid {
continue;
}
let status = stored.insert(&this.peer_resolver, entry);
changed |= status.is_changed();
added += status.is_added() as usize;
if added >= to_add {
break;
}
}
}
if added < to_add {
this.entry_count
.fetch_sub(to_add - added, Ordering::Release);
}
if added > 0 {
this.entries_added.notify_waiters();
}
if changed {
this.entries_changed.notify_waiters();
}
changed || added > 0
}
pub(crate) fn remove_invalid_entries(&self, now: u32) {
let this = self.inner.as_ref();
let mut should_notify = false;
let mut entries = this.entries.write();
entries.retain(|item| {
let retain = !item.entry.is_expired(now, this.entry_ttl_sec)
&& !this.banned_peer_ids.contains(&item.entry.peer_id);
should_notify |= !retain;
retain
});
if should_notify {
self.inner.entries_removed.notify_waiters();
}
}
fn prepend_prefix_to_body(&self, body: &mut Bytes) {
let this = self.inner.as_ref();
let mut res = BytesMut::with_capacity(this.request_prefix.len() + body.len());
res.extend_from_slice(&this.request_prefix);
res.extend_from_slice(body);
*body = res.freeze();
}
}
impl std::fmt::Debug for PublicOverlay {
    /// Shows only the overlay id; the rest of the shared state is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let overlay_id = &self.inner.overlay_id;
        let mut out = f.debug_struct("PublicOverlay");
        out.field("overlay_id", overlay_id);
        out.finish()
    }
}
/// Shared state behind every [`PublicOverlay`] handle.
// NOTE: fields are dropped in declaration order, so keep the order stable.
struct Inner {
    /// Id of this overlay; also part of `request_prefix` and of the signed
    /// entry payload.
    overlay_id: OverlayId,
    /// Untrusted entries are accepted only while `entry_count` is below this value.
    min_capacity: usize,
    /// Entry lifetime in seconds, used in expiration checks.
    entry_ttl_sec: u32,
    /// Optional resolver that each newly inserted peer is registered with.
    peer_resolver: Option<PeerResolver>,
    /// Stored entries, keyed by peer id.
    entries: RwLock<PublicOverlayEntries>,
    /// Counter used to enforce `min_capacity`.
    ///
    /// `add_untrusted_entries` reserves slots here before inserting and
    /// releases the unused ones afterwards.
    /// NOTE(review): expected to equal `entries.len()` when no insertion is
    /// in flight (the tests assert this).
    entry_count: AtomicUsize,
    /// Notified when new peers are added.
    entries_added: Notify,
    /// Notified when entries are added or updated.
    entries_changed: Notify,
    /// Notified when expired or banned entries are removed.
    entries_removed: Notify,
    /// This node's own signed entry, if set.
    own_signed_entry: ArcSwapOption<PublicEntry>,
    /// Senders without a stored entry, waiting to be resolved.
    unknown_peers_queue: UnknownPeersQueue,
    /// Peers whose requests are ignored and whose entries are rejected or removed.
    banned_peer_ids: FastDashSet<PeerId>,
    /// Handler for incoming queries and messages.
    service: BoxService<ServiceRequest, Response>,
    /// TL-serialized `rpc::Prefix` prepended to outgoing request bodies.
    request_prefix: Box<[u8]>,
    /// Traffic metrics (labeled when the overlay is named).
    metrics: Metrics,
}
/// Stored entries of a public overlay, keyed by peer id.
///
/// Backed by an index map, so entries can also be addressed by position.
/// The random-selection helpers rely on this.
pub struct PublicOverlayEntries {
    items: OverlayItems,
}

impl PublicOverlayEntries {
    /// Returns `true` if there are no entries.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Returns the number of entries.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Returns `true` if an entry for `peer_id` is stored.
    pub fn contains(&self, peer_id: &PeerId) -> bool {
        self.items.contains_key(peer_id)
    }

    /// Iterates over all entries in map order.
    pub fn iter(&self) -> indexmap::map::Values<'_, PeerId, PublicOverlayEntryData> {
        self.items.values()
    }

    /// Returns a uniformly random entry, or `None` if there are no entries.
    pub fn choose<R>(&self, rng: &mut R) -> Option<&PublicOverlayEntryData>
    where
        R: Rng + ?Sized,
    {
        // `random_range` panics on an empty range, so handle the empty set
        // before sampling an index.
        if self.items.is_empty() {
            return None;
        }

        let index = rng.random_range(0..self.items.len());
        let (_, value) = self.items.get_index(index)?;
        Some(value)
    }

    /// Returns an iterator over `min(n, len)` distinct randomly chosen entries.
    pub fn choose_multiple<R>(
        &self,
        rng: &mut R,
        n: usize,
    ) -> ChooseMultiplePublicOverlayEntries<'_>
    where
        R: Rng + ?Sized,
    {
        let len = self.items.len();
        ChooseMultiplePublicOverlayEntries {
            items: &self.items,
            indices: rand::seq::index::sample(rng, len, n.min(len)).into_iter(),
        }
    }

    /// Returns an iterator over all entries, in the order produced by
    /// `rand::seq::index::sample`.
    pub fn choose_all<R>(&self, rng: &mut R) -> ChooseMultiplePublicOverlayEntries<'_>
    where
        R: Rng + ?Sized,
    {
        self.choose_multiple(rng, self.items.len())
    }

    /// Inserts `item`, or replaces the stored entry for the same peer if
    /// `item` has a strictly greater `created_at`.
    ///
    /// A new peer is registered with `peer_resolver` when one is configured.
    /// Otherwise it gets a no-op resolver handle.
    fn insert(&mut self, peer_resolver: &Option<PeerResolver>, item: &PublicEntry) -> UpdateStatus {
        match self.items.entry(item.peer_id) {
            indexmap::map::Entry::Vacant(entry) => {
                let resolver_handle = peer_resolver.as_ref().map_or_else(
                    || PeerResolverHandle::new_noop(&item.peer_id),
                    |resolver| resolver.insert(&item.peer_id, false),
                );
                entry.insert(PublicOverlayEntryData {
                    entry: Arc::new(item.clone()),
                    resolver_handle,
                });
                UpdateStatus::Added
            }
            indexmap::map::Entry::Occupied(mut entry) => {
                let existing = entry.get_mut();
                if existing.entry.created_at >= item.created_at {
                    return UpdateStatus::Skipped;
                }

                // Reuse the allocation when nobody else holds the Arc.
                match Arc::get_mut(&mut existing.entry) {
                    Some(existing) => existing.clone_from(item),
                    None => existing.entry = Arc::new(item.clone()),
                }
                UpdateStatus::Updated
            }
        }
    }

    /// Keeps only the entries for which `f` returns `true`.
    fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(&PublicOverlayEntryData) -> bool,
    {
        self.items.retain(|_, item| f(item));
    }
}
/// A stored overlay entry together with its resolver handle.
#[derive(Clone)]
pub struct PublicOverlayEntryData {
    /// The signed public entry.
    pub entry: Arc<PublicEntry>,
    /// Handle obtained from the peer resolver, or a no-op handle.
    pub resolver_handle: PeerResolverHandle,
}

impl PublicOverlayEntryData {
    /// Delegates the expiration check to the wrapped entry.
    pub fn is_expired(&self, now: u32, ttl: u32) -> bool {
        let entry: &PublicEntry = &self.entry;
        entry.is_expired(now, ttl)
    }

    /// Returns `created_at + ttl`, saturating at `u32::MAX`.
    pub fn expires_at(&self, ttl: u32) -> u32 {
        let created_at = self.entry.created_at;
        created_at.saturating_add(ttl)
    }
}
/// Read-only view of the overlay entries.
///
/// Holds the entries read lock for as long as the guard is alive.
pub struct PublicOverlayEntriesReadGuard<'a> {
    entries: RwLockReadGuard<'a, PublicOverlayEntries>,
}

impl std::ops::Deref for PublicOverlayEntriesReadGuard<'_> {
    type Target = PublicOverlayEntries;

    /// Exposes the locked entries.
    #[inline]
    fn deref(&self) -> &PublicOverlayEntries {
        &self.entries
    }
}
/// Bounded, deduplicating set of peer ids that contacted the overlay but
/// have no stored entry.
///
/// The element count is mirrored in an atomic, so `len`, `is_empty` and
/// `is_full` do not need to lock the mutex.
pub struct UnknownPeersQueue {
    peer_ids: Mutex<IndexSet<PeerId, FastHasherState>>,
    peer_id_count: AtomicUsize,
    capacity: usize,
}

impl UnknownPeersQueue {
    /// Creates an empty queue that holds at most `capacity` peers.
    pub fn with_capacity(capacity: usize) -> Self {
        let set = IndexSet::with_capacity_and_hasher(capacity, Default::default());
        Self {
            peer_ids: Mutex::new(set),
            peer_id_count: AtomicUsize::new(0),
            capacity,
        }
    }

    /// Returns `true` if the mirrored count is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` if the mirrored count has reached the capacity.
    pub fn is_full(&self) -> bool {
        self.capacity <= self.len()
    }

    /// Returns the mirrored element count.
    pub fn len(&self) -> usize {
        self.peer_id_count.load(Ordering::Acquire)
    }

    /// Adds `peer_id` unless the queue is full.
    ///
    /// Returns `true` only if the peer was newly inserted.
    pub fn push(&self, peer_id: &PeerId) -> bool {
        let mut set = self.peer_ids.lock();
        let has_room = set.len() < self.capacity;
        if !has_room {
            return false;
        }

        let inserted = set.insert(*peer_id);
        self.peer_id_count.fetch_add(inserted as _, Ordering::Release);
        inserted
    }

    /// Takes every queued peer at once, leaving the queue empty.
    ///
    /// Returns `None` when nothing was queued.
    pub fn pop_multiple(&self) -> Option<IndexSet<PeerId, FastHasherState>> {
        if self.is_empty() {
            return None;
        }

        let drained = {
            let mut set = self.peer_ids.lock();
            self.peer_id_count.store(0, Ordering::Release);
            std::mem::take(&mut *set)
        };
        (!drained.is_empty()).then_some(drained)
    }
}
/// Outcome of inserting an entry into the overlay entry map.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UpdateStatus {
    /// The incoming entry was not newer than the stored one.
    Skipped,
    /// A stored entry was replaced by a newer one.
    Updated,
    /// A previously unknown peer was inserted.
    Added,
}

impl UpdateStatus {
    /// Returns `true` if the map was modified (added or updated).
    fn is_changed(self) -> bool {
        match self {
            Self::Skipped => false,
            Self::Updated | Self::Added => true,
        }
    }

    /// Returns `true` only for a newly inserted peer.
    fn is_added(self) -> bool {
        match self {
            Self::Added => true,
            Self::Skipped | Self::Updated => false,
        }
    }
}
/// Iterator over a random sample of overlay entries.
///
/// Produced by `PublicOverlayEntries::choose_multiple` and
/// `PublicOverlayEntries::choose_all`.
pub struct ChooseMultiplePublicOverlayEntries<'a> {
    items: &'a OverlayItems,
    indices: rand::seq::index::IndexVecIntoIter,
}

impl<'a> Iterator for ChooseMultiplePublicOverlayEntries<'a> {
    type Item = &'a PublicOverlayEntryData;

    /// Yields the entry at the next sampled index.
    fn next(&mut self) -> Option<Self::Item> {
        let index = self.indices.next()?;
        self.items.get_index(index).map(|(_, value)| value)
    }

    /// Exact: one item per remaining sampled index.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.indices.len();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for ChooseMultiplePublicOverlayEntries<'_> {
    fn len(&self) -> usize {
        self.indices.len()
    }
}

/// Storage for overlay entries: an index map keyed by peer id.
type OverlayItems = IndexMap<PeerId, PublicOverlayEntryData, FastHasherState>;
#[cfg(test)]
mod tests {
    use tycho_crypto::ed25519;
    use tycho_util::time::now_sec;

    use super::*;

    /// Creates an entry for `overlay`, signed with a fresh random keypair.
    fn generate_public_entry(overlay: &PublicOverlay, now: u32) -> Arc<PublicEntry> {
        let keypair = rand::random::<ed25519::KeyPair>();
        let peer_id: PeerId = keypair.public_key.into();
        let signature = keypair.sign_tl(crate::proto::overlay::PublicEntryToSign {
            overlay_id: overlay.overlay_id().as_bytes(),
            peer_id: &peer_id,
            created_at: now,
        });
        Arc::new(PublicEntry {
            peer_id,
            created_at: now,
            signature: Box::new(signature),
        })
    }

    /// Creates an entry with an all-zero signature.
    ///
    /// The tests below expect these entries to be rejected.
    fn generate_invalid_public_entry(now: u32) -> Arc<PublicEntry> {
        let keypair = rand::random::<ed25519::KeyPair>();
        let peer_id: PeerId = keypair.public_key.into();
        Arc::new(PublicEntry {
            peer_id,
            created_at: now,
            signature: Box::new([0; 64]),
        })
    }

    /// Creates `n` correctly signed entries for `overlay`.
    fn generate_public_entries(
        overlay: &PublicOverlay,
        now: u32,
        n: usize,
    ) -> Vec<Arc<PublicEntry>> {
        (0..n)
            .map(|_| generate_public_entry(overlay, now))
            .collect()
    }

    /// Returns the tracked entry count, asserting that it matches the number
    /// of stored entries.
    fn count_entries(overlay: &PublicOverlay) -> usize {
        let tracked_count = overlay.inner.entry_count.load(Ordering::Acquire);
        let guard = overlay.read_entries();
        assert_eq!(guard.entries.items.len(), tracked_count);
        tracked_count
    }

    /// Builds an overlay with a random id, the given `min_capacity`, and a
    /// query service that always answers `None`.
    fn make_overlay_with_min_capacity(min_capacity: usize) -> PublicOverlay {
        PublicOverlay::builder(rand::random())
            .with_min_capacity(min_capacity)
            .build(crate::service_query_fn(|_| {
                futures_util::future::ready(None)
            }))
    }

    #[test]
    fn min_capacity_works_with_single_thread() {
        let now = now_sec();
        let local_id: PeerId = rand::random();

        // Two batches of 5 fill a capacity of 10 exactly.
        {
            let overlay = make_overlay_with_min_capacity(10);
            let entries = generate_public_entries(&overlay, now, 10);

            overlay.add_untrusted_entries(&local_id, &entries[..5], now);
            assert_eq!(count_entries(&overlay), 5);

            overlay.add_untrusted_entries(&local_id, &entries[5..], now);
            assert_eq!(count_entries(&overlay), 10);
        }

        // A single batch equal to the capacity.
        {
            let overlay = make_overlay_with_min_capacity(10);
            let entries = generate_public_entries(&overlay, now, 10);
            overlay.add_untrusted_entries(&local_id, &entries, now);
            assert_eq!(count_entries(&overlay), 10);
        }

        // A batch larger than the capacity is truncated to the capacity.
        {
            let overlay = make_overlay_with_min_capacity(10);
            let entries = generate_public_entries(&overlay, now, 20);
            overlay.add_untrusted_entries(&local_id, &entries, now);
            assert_eq!(count_entries(&overlay), 10);
        }

        // Zero capacity accepts nothing.
        {
            let overlay = make_overlay_with_min_capacity(0);
            let entries = generate_public_entries(&overlay, now, 10);
            overlay.add_untrusted_entries(&local_id, &entries, now);
            assert_eq!(count_entries(&overlay), 0);
        }

        // Entries with bad signatures are all rejected.
        {
            let overlay = make_overlay_with_min_capacity(10);
            let entries = (0..10)
                .map(|_| generate_invalid_public_entry(now))
                .collect::<Vec<_>>();
            overlay.add_untrusted_entries(&local_id, &entries, now);
            assert_eq!(count_entries(&overlay), 0);
        }

        // Only the valid half of an interleaved batch is stored.
        {
            let overlay = make_overlay_with_min_capacity(10);
            let entries = [
                generate_invalid_public_entry(now),
                generate_public_entry(&overlay, now),
                generate_invalid_public_entry(now),
                generate_public_entry(&overlay, now),
                generate_invalid_public_entry(now),
                generate_public_entry(&overlay, now),
                generate_invalid_public_entry(now),
                generate_public_entry(&overlay, now),
                generate_invalid_public_entry(now),
                generate_public_entry(&overlay, now),
            ];
            overlay.add_untrusted_entries(&local_id, &entries, now);
            assert_eq!(count_entries(&overlay), 5);
        }

        // Invalid entries do not consume reserved slots: all 3 slots are
        // filled by valid entries that come later in the batch.
        {
            let overlay = make_overlay_with_min_capacity(3);
            let entries = [
                generate_invalid_public_entry(now),
                generate_invalid_public_entry(now),
                generate_invalid_public_entry(now),
                generate_invalid_public_entry(now),
                generate_invalid_public_entry(now),
                generate_public_entry(&overlay, now),
                generate_public_entry(&overlay, now),
                generate_public_entry(&overlay, now),
                generate_public_entry(&overlay, now),
                generate_public_entry(&overlay, now),
            ];
            overlay.add_untrusted_entries(&local_id, &entries, now);
            assert_eq!(count_entries(&overlay), 3);
        }
    }

    /// 10 threads each add three batches of 7 entries (210 in total)
    /// concurrently; the count must stop exactly at the capacity of 201.
    #[test]
    fn min_capacity_works_with_multi_thread() {
        let now = now_sec();
        let local_id: PeerId = rand::random();

        let overlay = make_overlay_with_min_capacity(201);
        let entries = generate_public_entries(&overlay, now, 7 * 3 * 10);

        std::thread::scope(|s| {
            for entries in entries.chunks_exact(7 * 3) {
                s.spawn(|| {
                    for entries in entries.chunks_exact(7) {
                        overlay.add_untrusted_entries(&local_id, entries, now);
                    }
                });
            }
        });

        assert_eq!(count_entries(&overlay), 201);
    }

    /// Exercises push deduplication, the capacity limit, and draining via
    /// `pop_multiple`.
    #[test]
    fn unknown_peers_queue() {
        let queue = UnknownPeersQueue::with_capacity(5);
        assert!(queue.is_empty());
        assert!(!queue.is_full());

        let added = queue.push(&PeerId([0; 32]));
        assert!(added);
        assert_eq!(queue.len(), 1);
        assert!(!queue.is_empty());
        assert!(!queue.is_full());

        // Pushing the same peer again is a no-op.
        let added = queue.push(&PeerId([0; 32]));
        assert!(!added);
        assert_eq!(queue.len(), 1);

        for i in 1..=3 {
            let added = queue.push(&PeerId([i; 32]));
            assert!(added);
            assert_eq!(queue.len(), i as usize + 1);
            assert!(!queue.is_empty());
            assert!(!queue.is_full());
        }

        // The fifth distinct peer fills the queue.
        let added = queue.push(&PeerId([4; 32]));
        assert!(added);
        assert_eq!(queue.len(), 5);
        assert!(queue.is_full());

        // A full queue rejects new peers.
        let added = queue.push(&PeerId([5; 32]));
        assert!(!added);
        assert_eq!(queue.len(), 5);
        assert!(queue.is_full());

        // Draining returns everything and resets the queue.
        let items = queue.pop_multiple().unwrap();
        assert!(queue.is_empty());
        assert!(!queue.is_full());
        assert_eq!(items.len(), 5);
        for i in 0..5 {
            assert!(items.contains(&PeerId([i; 32])));
        }

        // Draining an empty queue yields `None`.
        let items = queue.pop_multiple();
        assert!(items.is_none());

        // The queue is reusable after draining.
        let added = queue.push(&PeerId([0; 32]));
        assert!(added);
        assert_eq!(queue.len(), 1);
        assert!(!queue.is_empty());
        assert!(!queue.is_full());

        let items = queue.pop_multiple().unwrap();
        assert!(queue.is_empty());
        assert!(!queue.is_full());
        assert_eq!(items.len(), 1);
        assert!(items.contains(&PeerId([0; 32])));
    }
}