use std::borrow::{Borrow, Cow};
use std::convert::TryFrom;
use std::net::SocketAddrV4;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use bytes::Bytes;
use futures_util::stream::FuturesUnordered;
use futures_util::StreamExt;
use rustc_hash::FxHashSet;
use smallvec::smallvec;
use tl_proto::{BoxedConstructor, BoxedWrapper, TlRead, TlWrite};
use super::buckets::Buckets;
use super::entry::Entry;
use super::futures::StoreValue;
use super::storage::{Storage, StorageOptions};
use super::{KEY_ADDRESS, KEY_NODES, MAX_DHT_PEERS};
use crate::adnl;
use crate::overlay;
use crate::proto;
use crate::subscriber::*;
use crate::util::*;
/// Configuration for a DHT [`Node`].
///
/// Deserialized with `#[serde(default)]`, so any subset of fields may be
/// supplied; missing ones fall back to the `Default` impl.
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct NodeOptions {
    /// Lifetime in seconds added to `now()` when building values to store
    /// (see `store_overlay_node`).
    pub value_ttl_sec: u32,
    /// Timeout in milliseconds passed to raw ADNL queries in `query_raw`.
    pub query_timeout_ms: u64,
    // NOTE(review): not read anywhere in this chunk — presumably consumed by
    // the `Entry` value-fetching machinery elsewhere; confirm before removing.
    pub default_value_batch_len: usize,
    /// Peers whose accumulated query-failure penalty exceeds this value are
    /// reported as bad by `Node::is_bad_peer`.
    pub bad_peer_threshold: usize,
    /// Upper bound accepted for the `k` parameter of incoming
    /// `dht.findValue` queries.
    pub max_allowed_k: u32,
    /// Forwarded to [`StorageOptions`]: maximum stored key name length.
    pub max_key_name_len: usize,
    /// Forwarded to [`StorageOptions`]: maximum allowed key index.
    pub max_key_index: u32,
    /// Interval in milliseconds between storage GC runs in the background
    /// task spawned by `Node::new`.
    pub storage_gc_interval_ms: u64,
}
impl Default for NodeOptions {
    /// Returns the default DHT node configuration.
    fn default() -> Self {
        Self {
            storage_gc_interval_ms: 10000,
            max_key_index: 15,
            max_key_name_len: 127,
            max_allowed_k: 20,
            bad_peer_threshold: 5,
            default_value_batch_len: 5,
            query_timeout_ms: 1000,
            value_ttl_sec: 3600,
        }
    }
}
/// A DHT node built on top of an ADNL transport node.
pub struct Node {
    // Underlying ADNL node used for all network queries.
    adnl: Arc<adnl::Node>,
    // Short id of the local DHT key; used as the source id for queries.
    local_id: adnl::NodeIdShort,
    // Pre-serialized `rpc.DhtQuery` containing our signed node description;
    // prepended to queries sent via `query_with_prefix`.
    query_prefix: Vec<u8>,
    // Immutable configuration captured at construction time.
    options: NodeOptions,
    // Shared state; also registered as an ADNL query subscriber.
    state: Arc<NodeState>,
}
impl Node {
pub fn new(adnl: Arc<adnl::Node>, key_tag: usize, options: NodeOptions) -> Result<Arc<Self>> {
let key = adnl.key_by_tag(key_tag)?.clone();
let buckets = Buckets::new(key.id());
let storage = Storage::new(StorageOptions {
max_key_name_len: options.max_key_name_len,
max_key_index: options.max_key_index,
});
let state = Arc::new(NodeState {
key: key.clone(),
known_peers: adnl::PeersSet::with_capacity(MAX_DHT_PEERS),
penalties: Default::default(),
buckets,
storage,
max_allowed_k: options.max_allowed_k,
});
adnl.add_query_subscriber(state.clone())?;
let query_prefix = tl_proto::serialize(proto::rpc::DhtQuery {
node: state
.sign_local_node(adnl.build_address_list())
.as_equivalent_ref(),
});
let dht_node = Arc::new(Self {
adnl,
local_id: *key.id(),
query_prefix,
options,
state,
});
let state = Arc::downgrade(&dht_node.state);
let interval = Duration::from_millis(dht_node.options.storage_gc_interval_ms);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
if let Some(state) = state.upgrade() {
state.storage.gc();
}
}
});
Ok(dht_node)
}
#[inline(always)]
pub fn options(&self) -> &NodeOptions {
&self.options
}
#[inline(always)]
pub fn metrics(&self) -> NodeMetrics {
self.state.metrics()
}
#[inline(always)]
pub fn adnl(&self) -> &Arc<adnl::Node> {
&self.adnl
}
#[inline(always)]
pub fn key(&self) -> &Arc<adnl::Key> {
&self.state.key
}
pub fn iter_known_peers(&self) -> impl Iterator<Item = &adnl::NodeIdShort> {
self.state.known_peers.iter()
}
pub fn add_dht_peer(&self, peer: proto::dht::NodeOwned) -> Result<Option<adnl::NodeIdShort>> {
self.state.add_dht_peer(&self.adnl, peer)
}
pub fn is_bad_peer(&self, peer: &adnl::NodeIdShort) -> bool {
matches!(
self.state.penalties.get(peer),
Some(penalty) if *penalty > self.options.bad_peer_threshold
)
}
pub async fn ping(&self, peer_id: &adnl::NodeIdShort) -> Result<bool> {
use rand::RngCore;
let random_id = fast_thread_rng().next_u64();
match self
.query(peer_id, proto::rpc::DhtPing { random_id })
.await?
{
Some(proto::dht::Pong { random_id: answer }) => Ok(answer == random_id),
None => Ok(false),
}
}
pub fn entry<'a, T>(self: &'a Arc<Self>, id: &'a T, name: &'a str) -> Entry<'a>
where
T: Borrow<[u8; 32]>,
{
Entry::new(self, id, name)
}
pub async fn query_dht_nodes(
&self,
peer_id: &adnl::NodeIdShort,
k: u32,
store_self: bool,
) -> Result<Vec<proto::dht::NodeOwned>> {
let query = proto::rpc::DhtFindNode {
key: self.local_id.as_slice(),
k,
};
let answer = match store_self {
true => self.query_with_prefix(peer_id, query).await,
false => self.query(peer_id, query).await,
}?;
Ok(match answer {
Some(BoxedWrapper(proto::dht::NodesOwned { nodes })) => nodes,
None => Vec::new(),
})
}
pub async fn find_more_dht_nodes(&self) -> Result<usize> {
let known_nodes = self.known_peers().clone_inner();
let mut tasks = futures_util::stream::FuturesUnordered::new();
for peer_id in known_nodes {
tasks.push(async move {
let res = self.query_dht_nodes(&peer_id, 10, false).await;
(peer_id, res)
});
}
let mut node_count = 0;
while let Some((peer_id, res)) = tasks.next().await {
match res {
Ok(nodes) => {
for node in nodes {
node_count += self.add_dht_peer(node)?.is_some() as usize;
}
}
Err(e) => {
tracing::warn!(%peer_id, "failed to get DHT nodes: {e:?}")
}
}
}
Ok(node_count)
}
pub async fn find_overlay_nodes(
self: &Arc<Self>,
overlay_id: &overlay::IdShort,
) -> Result<Vec<(SocketAddrV4, proto::overlay::NodeOwned)>> {
let mut result = Vec::new();
let mut nodes = Vec::new();
let mut cache = FxHashSet::default();
loop {
let received = self
.entry(overlay_id, KEY_NODES)
.values()
.use_new_peers(true)
.map(|(_, BoxedWrapper(proto::overlay::NodesOwned { nodes }))| nodes)
.collect::<Vec<_>>()
.await;
if received.is_empty() {
break;
}
let mut futures = FuturesUnordered::new();
for node in received
.into_iter()
.flatten()
.chain(std::mem::take(&mut nodes).into_iter())
{
let peer_id = match adnl::NodeIdFull::try_from(node.id.as_equivalent_ref())
.map(|full_id| full_id.compute_short_id())
{
Ok(peer_id) if cache.insert(peer_id) => peer_id,
_ => continue,
};
let dht = self.clone();
futures.push(async move {
match dht.find_address(&peer_id).await {
Ok((ip, _)) => (Some(ip), node),
Err(_) => (None, node),
}
});
}
while let Some((ip, node)) = futures.next().await {
match ip {
Some(ip) => result.push((ip, node)),
None if result.is_empty() => nodes.push(node),
_ => {}
}
}
if !result.is_empty() {
break;
}
}
Ok(result)
}
pub async fn find_address(
self: &Arc<Self>,
peer_id: &adnl::NodeIdShort,
) -> Result<(SocketAddrV4, adnl::NodeIdFull)> {
let mut values = self.entry(peer_id, KEY_ADDRESS).values();
while let Some((key, BoxedWrapper(value))) = values.next().await {
match (
parse_address_list(&value, self.adnl.options().clock_tolerance_sec),
adnl::NodeIdFull::try_from(key.id.as_equivalent_ref()),
) {
(Ok(addr), Ok(full_id)) => return Ok((addr, full_id)),
_ => continue,
}
}
Err(DhtNodeError::NoAddressFound.into())
}
pub fn store_value(self: &Arc<Self>, value: proto::dht::Value<'_>) -> Result<StoreValue> {
StoreValue::new(self.clone(), value)
}
pub async fn store_overlay_node(
self: &Arc<Self>,
overlay_id_full: &overlay::IdFull,
node: proto::overlay::Node<'_>,
) -> Result<bool> {
let overlay_id = overlay_id_full.compute_short_id();
overlay_id.verify_overlay_node(&node)?;
let value = tl_proto::serialize_as_boxed(proto::overlay::Nodes {
nodes: smallvec![node],
});
let value = proto::dht::Value {
key: proto::dht::KeyDescription {
key: proto::dht::Key {
id: overlay_id.as_slice(),
name: KEY_NODES.as_bytes(),
idx: 0,
},
id: everscale_crypto::tl::PublicKey::Overlay {
name: overlay_id_full.as_slice(),
},
update_rule: proto::dht::UpdateRule::OverlayNodes,
signature: Default::default(),
},
value: &value,
ttl: now() + self.options.value_ttl_sec,
signature: Default::default(),
};
self.store_value(value)?
.then_check(
move |_, BoxedWrapper(proto::overlay::NodesOwned { nodes })| {
for stored_node in &nodes {
if stored_node.as_equivalent_ref() == node {
return Ok(true);
}
}
Ok(false)
},
)
.check_all()
.await
}
pub async fn store_address(
self: &Arc<Self>,
key: &adnl::Key,
addr: SocketAddrV4,
) -> Result<bool> {
let clock_tolerance_sec = self.adnl.options().clock_tolerance_sec;
self.entry(key.id(), KEY_ADDRESS)
.with_data(
proto::adnl::AddressList {
address: Some(proto::adnl::Address::from(&addr)),
version: now(),
reinit_date: self.adnl.start_time(),
expire_at: 0,
}
.into_boxed(),
)
.sign_and_store(key)?
.then_check(move |_, BoxedWrapper(address_list)| {
match parse_address_list(&address_list, clock_tolerance_sec)? {
stored_addr if stored_addr == addr => Ok(true),
stored_addr => {
tracing::warn!(
stored = %stored_addr,
expected = %addr,
"stored address mismatch",
);
Ok(false)
}
}
})
.await
}
async fn query<Q, A>(&self, peer_id: &adnl::NodeIdShort, query: Q) -> Result<Option<A>>
where
Q: TlWrite,
for<'a> A: TlRead<'a, Repr = tl_proto::Boxed> + 'static,
{
let result = self.adnl.query(&self.local_id, peer_id, query, None).await;
self.state.update_peer_status(peer_id, result.is_ok());
result
}
pub(super) async fn query_raw(
&self,
peer_id: &adnl::NodeIdShort,
query: Bytes,
) -> Result<Option<Vec<u8>>> {
let result = self
.adnl
.query_raw(
&self.local_id,
peer_id,
query,
Some(self.options.query_timeout_ms),
)
.await;
self.state.update_peer_status(peer_id, result.is_ok());
result
}
async fn query_with_prefix<Q, A>(
&self,
peer_id: &adnl::NodeIdShort,
query: Q,
) -> Result<Option<A>>
where
Q: TlWrite,
for<'a> A: TlRead<'a, Repr = tl_proto::Boxed> + 'static,
{
let result = self
.adnl
.query_with_prefix::<Q, A>(&self.local_id, peer_id, &self.query_prefix, query, None)
.await;
self.state.update_peer_status(peer_id, result.is_ok());
result
}
pub(super) fn parse_value_result<T>(
&self,
result: &[u8],
) -> Result<Option<(proto::dht::KeyDescriptionOwned, T)>>
where
for<'a> T: TlRead<'a, Repr = tl_proto::Boxed> + 'static,
{
match tl_proto::deserialize::<proto::dht::ValueResult>(result)? {
proto::dht::ValueResult::ValueFound(BoxedWrapper(mut value)) => {
if value.key.update_rule == proto::dht::UpdateRule::Signature {
verify_signed_dht_value(&mut value)?;
}
let parsed = tl_proto::deserialize(value.value)?;
Ok(Some((value.key.as_equivalent_owned(), parsed)))
}
proto::dht::ValueResult::ValueNotFound(proto::dht::NodesOwned { nodes }) => {
for node in nodes {
if let Err(e) = self.add_dht_peer(node) {
tracing::warn!("failed to add DHT peer: {e:?}");
}
}
Ok(None)
}
}
}
#[inline(always)]
pub(super) fn known_peers(&self) -> &adnl::PeersSet {
&self.state.known_peers
}
#[inline(always)]
pub(super) fn storage(&self) -> &Storage {
&self.state.storage
}
}
/// Shared node state; also serves incoming DHT queries as an ADNL
/// query subscriber.
struct NodeState {
    // Local DHT key used for signing and as the bucket origin.
    key: Arc<adnl::Key>,
    // Set of all known DHT peer ids.
    known_peers: adnl::PeersSet,
    // Per-peer query-failure penalty counters.
    penalties: Penalties,
    // Routing table buckets keyed around the local id.
    buckets: Buckets,
    // Local key/value storage served to other peers.
    storage: Storage,
    // Upper bound for `k` in incoming `dht.findValue` queries.
    max_allowed_k: u32,
}
impl NodeState {
    /// Collects current counters into a [`NodeMetrics`] snapshot.
    fn metrics(&self) -> NodeMetrics {
        NodeMetrics {
            known_peers_len: self.known_peers.len(),
            bucket_peer_count: self.buckets.iter().map(|bucket| bucket.len()).sum(),
            storage_len: self.storage.len(),
            storage_total_size: self.storage.total_size(),
        }
    }

    /// Builds a description of the local node with the given address list
    /// and signs it with the local key.
    fn sign_local_node(&self, addr_list: proto::adnl::AddressList) -> proto::dht::NodeOwned {
        let mut node = proto::dht::NodeOwned {
            id: self.key.full_id().as_tl().as_equivalent_owned(),
            addr_list,
            version: addr_list.version,
            signature: Default::default(),
        };
        // The signature covers the node serialized with an empty signature
        // field, so it is filled in only afterwards.
        node.signature = self.key.sign(node.as_boxed()).to_vec().into();
        node
    }

    /// Verifies a remote node description and registers it as an ADNL peer
    /// and (if new) as a DHT peer.
    ///
    /// Returns `Ok(Some(peer_id))` for a peer that was new to the ADNL node,
    /// `Ok(None)` otherwise — including an invalid signature, which is only
    /// logged, not propagated as an error.
    fn add_dht_peer(
        &self,
        adnl: &adnl::Node,
        mut peer: proto::dht::NodeOwned,
    ) -> Result<Option<adnl::NodeIdShort>> {
        let peer_id_full = adnl::NodeIdFull::try_from(peer.id.as_equivalent_ref())?;

        // The signature was computed over the node with an empty signature
        // field: take it out for verification, restore it afterwards.
        let signature = std::mem::take(&mut peer.signature);
        if peer_id_full.verify(peer.as_boxed(), &signature).is_err() {
            tracing::warn!("invalid DHT peer signature");
            return Ok(None);
        }
        peer.signature = signature;

        let peer_id = peer_id_full.compute_short_id();
        let peer_addr = parse_address_list(&peer.addr_list, adnl.options().clock_tolerance_sec)?;

        let is_new_peer = adnl.add_peer(
            adnl::NewPeerContext::Dht,
            self.key.id(),
            &peer_id,
            peer_addr,
            peer_id_full,
        )?;
        if !is_new_peer {
            return Ok(None);
        }

        if self.known_peers.insert(peer_id) {
            self.buckets.insert(&peer_id, peer);
        } else {
            // Already tracked as a DHT peer: treat this as a sign of life.
            self.set_good_peer(&peer_id);
        }

        Ok(Some(peer_id))
    }

    /// Records the outcome of a query to `peer`: success decreases the
    /// penalty, failure increases it.
    fn update_peer_status(&self, peer: &adnl::NodeIdShort, is_good: bool) {
        use dashmap::mapref::entry::Entry;
        if is_good {
            self.set_good_peer(peer);
        } else {
            match self.penalties.entry(*peer) {
                Entry::Occupied(mut entry) => {
                    *entry.get_mut() += 2;
                }
                Entry::Vacant(entry) => {
                    // NOTE(review): the first recorded failure starts the
                    // counter at 0 (not 2), so a peer needs several failures
                    // before exceeding `bad_peer_threshold` — confirm this
                    // grace period is intended.
                    entry.insert(0);
                }
            }
        }
    }

    /// Decrements the peer's penalty (saturating at zero), if one exists.
    fn set_good_peer(&self, peer: &adnl::NodeIdShort) {
        if let Some(mut count) = self.penalties.get_mut(peer) {
            *count.value_mut() = count.saturating_sub(1);
        }
    }

    /// Handles `dht.findNode`: delegates the lookup to the routing buckets.
    fn process_find_node(&self, query: proto::rpc::DhtFindNode<'_>) -> proto::dht::NodesOwned {
        self.buckets.find(query.key, query.k)
    }

    /// Handles `dht.findValue`: returns the stored value when present,
    /// otherwise up to `k` nodes gathered from the buckets.
    fn process_find_value(
        &self,
        query: proto::rpc::DhtFindValue<'_>,
    ) -> Result<proto::dht::ValueResultOwned> {
        // Reject out-of-range node count limits.
        if query.k == 0 || query.k > self.max_allowed_k {
            return Err(DhtNodeError::InvalidNodeCountLimit.into());
        }

        Ok(if let Some(value) = self.storage.get_ref(query.key) {
            proto::dht::ValueResultOwned::ValueFound(value.clone().into_boxed())
        } else {
            let mut nodes = Vec::with_capacity(query.k as usize);
            'outer: for bucket in &self.buckets {
                for peer in bucket {
                    nodes.push(peer.clone());
                    if nodes.len() >= query.k as usize {
                        break 'outer;
                    }
                }
            }
            proto::dht::ValueResultOwned::ValueNotFound(proto::dht::NodesOwned { nodes })
        })
    }

    /// Handles `dht.store`: inserts the value into local storage.
    fn process_store(&self, query: proto::rpc::DhtStore<'_>) -> Result<proto::dht::Stored> {
        self.storage.insert(query.value)?;
        Ok(proto::dht::Stored)
    }
}
#[async_trait::async_trait]
impl QuerySubscriber for NodeState {
    /// Dispatches an incoming ADNL query by its TL constructor id; unknown
    /// constructors are rejected so other subscribers may handle them.
    async fn try_consume_query<'a>(
        &self,
        ctx: SubscriberContext<'a>,
        constructor: u32,
        query: Cow<'a, [u8]>,
    ) -> Result<QueryConsumingResult<'a>> {
        match constructor {
            proto::rpc::DhtPing::TL_ID => {
                let proto::rpc::DhtPing { random_id } = tl_proto::deserialize(&query)?;
                // Echo the random id back.
                QueryConsumingResult::consume(proto::dht::Pong { random_id })
            }
            proto::rpc::DhtFindNode::TL_ID => {
                let query = tl_proto::deserialize(&query)?;
                QueryConsumingResult::consume(self.process_find_node(query).into_boxed())
            }
            proto::rpc::DhtFindValue::TL_ID => {
                let query = tl_proto::deserialize(&query)?;
                QueryConsumingResult::consume(self.process_find_value(query)?)
            }
            proto::rpc::DhtGetSignedAddressList::TL_ID => QueryConsumingResult::consume(
                self.sign_local_node(ctx.adnl.build_address_list())
                    .into_boxed(),
            ),
            proto::rpc::DhtStore::TL_ID => {
                let query = tl_proto::deserialize(&query)?;
                QueryConsumingResult::consume(self.process_store(query)?)
            }
            proto::rpc::DhtQuery::TL_ID => {
                // Combined query: the sender's node description followed by
                // an inner query.
                let mut offset = 0;
                let proto::rpc::DhtQuery { node } = <_>::read_from(&query, &mut offset)?;
                // Peek at the inner constructor id without advancing `offset`
                // (`identity(offset)` makes a temporary copy), so the slice
                // passed on below still includes the constructor bytes.
                let constructor = u32::read_from(&query, &mut std::convert::identity(offset))?;
                // There must be inner-query bytes after the node description.
                if offset >= query.len() {
                    return Err(DhtNodeError::UnexpectedQuery.into());
                }
                // Register the sender as a peer before handling its query.
                self.add_dht_peer(ctx.adnl, node.as_equivalent_owned())?;
                match self
                    .try_consume_query(ctx, constructor, Cow::Borrowed(&query[offset..]))
                    .await?
                {
                    QueryConsumingResult::Consumed(answer) => {
                        Ok(QueryConsumingResult::Consumed(answer))
                    }
                    // An inner query we cannot handle is an error, not a
                    // pass-through rejection.
                    QueryConsumingResult::Rejected(_) => Err(DhtNodeError::UnexpectedQuery.into()),
                }
            }
            _ => Ok(QueryConsumingResult::Rejected(query)),
        }
    }
}
/// Verifies a DHT value whose update rule is `Signature`:
/// checks that the short key id matches the hash of the full public key,
/// then validates both the key signature and the value signature.
fn verify_signed_dht_value(value: &mut proto::dht::Value<'_>) -> Result<()> {
    // The short key id must equal the TL hash of the full public key.
    if value.key.key.id != &tl_proto::hash(value.key.id) {
        return Err(DhtNodeError::InvalidValueKey.into());
    }
    let signer = adnl::NodeIdFull::try_from(value.key.id)?;

    // Signatures are computed over the structures serialized with empty
    // signature fields, so each signature is taken out for verification
    // and restored on success.
    let taken_key_sig = std::mem::take(&mut value.key.signature);
    signer.verify(value.key.as_boxed(), taken_key_sig)?;
    value.key.signature = taken_key_sig;

    let taken_value_sig = std::mem::take(&mut value.signature);
    signer.verify(value.as_boxed(), taken_value_sig)?;
    value.signature = taken_value_sig;

    Ok(())
}
/// Snapshot of simple DHT node counters (see `Node::metrics`).
#[derive(Debug, Copy, Clone)]
pub struct NodeMetrics {
    /// Number of known DHT peers.
    pub known_peers_len: usize,
    /// Total number of peers across all routing buckets.
    pub bucket_peer_count: usize,
    /// Number of values in local storage.
    pub storage_len: usize,
    /// Total size reported by local storage.
    pub storage_total_size: usize,
}
// Per-peer query-failure penalty counters (concurrent map).
type Penalties = FxDashMap<adnl::NodeIdShort, usize>;
/// Errors produced by the DHT node.
#[derive(thiserror::Error, Debug)]
enum DhtNodeError {
    // `find_address` exhausted all values without a usable address.
    #[error("No address found")]
    NoAddressFound,
    // Malformed or unhandled inner `dht.query` payload.
    #[error("Unexpected DHT query")]
    UnexpectedQuery,
    // `k` in `dht.findValue` was zero or above `max_allowed_k`.
    #[error("Invalid node count limit")]
    InvalidNodeCountLimit,
    // Short key id did not match the hash of the full public key.
    #[error("Invalid value key")]
    InvalidValueKey,
}