use std::net::SocketAddrV4;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use bytes::Bytes;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use tl_proto::{TlRead, TlWrite};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use self::receiver::*;
use self::sender::*;
use super::channel::{AdnlChannelId, Channel};
use super::keystore::{Key, Keystore, KeystoreError};
use super::node_id::{NodeIdFull, NodeIdShort};
use super::peer::{NewPeerContext, Peer, PeerFilter, Peers};
use super::ping_subscriber::PingSubscriber;
use super::queries_cache::{QueriesCache, QueryId};
use super::socket::make_udp_socket;
use super::transfer::*;
use crate::proto;
use crate::subscriber::*;
use crate::util::*;
mod receiver;
mod sender;
/// Tunable parameters of an ADNL [`Node`].
///
/// Because of `#[serde(default)]`, any field missing from deserialized input is taken
/// from [`NodeOptions::default`].
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct NodeOptions {
    /// Lower bound (milliseconds) applied by `Node::compute_query_timeout` to both explicit
    /// roundtrip estimates and the default timeout.
    pub query_min_timeout_ms: u64,
    /// Query timeout (milliseconds) used by `Node::query_raw` and
    /// `Node::compute_query_timeout` when no explicit value is supplied.
    pub query_default_timeout_ms: u64,
    /// Transfer timeout in seconds (per the field name).
    /// NOTE(review): not read in this file; presumably bounds how long an incomplete
    /// incoming transfer is kept — confirm in the receiver.
    pub transfer_timeout_sec: u64,
    /// Clock tolerance in seconds (per the field name).
    /// NOTE(review): not read in this file; presumably the allowed skew for packet
    /// timestamps — confirm.
    pub clock_tolerance_sec: u32,
    /// Seconds, passed to `Channel::update_drop_timeout` when a query sent while a channel
    /// exists goes unanswered; if that call returns `true` the peer pair is reset.
    pub channel_reset_timeout_sec: u32,
    /// Address list timeout in seconds (per the field name).
    /// NOTE(review): not read in this file — confirm where and how it is applied.
    pub address_list_timeout_sec: u32,
    /// NOTE(review): not read in this file; presumably enables tracking of already-seen
    /// packets — confirm.
    pub packet_history_enabled: bool,
    /// NOTE(review): not read in this file; presumably rejects incoming packets without a
    /// valid signature — confirm.
    pub packet_signature_required: bool,
    /// Passed as the final boolean argument to `send_message` by `Node::send_custom_message`
    /// (`Node::query_raw` always passes `true` there); presumably selects priority channels.
    pub force_use_priority_channels: bool,
    /// NOTE(review): not read in this file; presumably makes neighbours be contacted over
    /// the loopback interface — confirm.
    pub use_loopback_for_neighbours: bool,
    /// Optional protocol version.
    /// NOTE(review): not read in this file — confirm semantics of `None` vs `Some`.
    pub version: Option<u16>,
}
impl Default for NodeOptions {
    /// Baseline configuration; also the source of any field omitted from deserialized input.
    fn default() -> Self {
        // Query timing, in milliseconds.
        let query_min_timeout_ms = 500;
        let query_default_timeout_ms = 5_000;

        // Timeouts and tolerances, in seconds.
        let transfer_timeout_sec = 3;
        let clock_tolerance_sec = 60;
        let channel_reset_timeout_sec = 30;
        let address_list_timeout_sec = 1_000;

        Self {
            query_min_timeout_ms,
            query_default_timeout_ms,
            transfer_timeout_sec,
            clock_tolerance_sec,
            channel_reset_timeout_sec,
            address_list_timeout_sec,
            packet_history_enabled: false,
            packet_signature_required: true,
            force_use_priority_channels: true,
            use_loopback_for_neighbours: false,
            version: None,
        }
    }
}
/// An ADNL node: a local UDP endpoint plus per-local-id peer tables, channel indices,
/// incoming transfers and pending queries.
///
/// Constructed with `Node::new` (which returns an `Arc<Node>`); the sender and receiver are
/// started by `Node::start`.
pub struct Node {
    /// Local UDP address; when port `0` was requested, holds the port the socket actually got.
    socket_addr: SocketAddrV4,
    /// Local keys; `Node::new` creates one peer table per key.
    keystore: Keystore,
    /// Options fixed at construction.
    options: NodeOptions,
    /// Optional hook consulted by `add_peer` before a peer is accepted.
    peer_filter: Option<Arc<dyn PeerFilter>>,
    /// Peer tables keyed by local node id; the key set is populated only in `Node::new`.
    peers: FastHashMap<NodeIdShort, Peers>,
    /// Channel receivers keyed by incoming channel id. When a channel is dropped, both its
    /// ordinary and priority incoming ids are removed from here.
    channels_by_id: FastDashMap<AdnlChannelId, ChannelReceiver>,
    /// Established channels keyed by remote peer id (not by local id).
    channels_by_peers: FastDashMap<NodeIdShort, Arc<Channel>>,
    /// Incoming transfers keyed by transfer id.
    /// NOTE(review): presumably filled by the receiver while reassembling split messages — confirm.
    incoming_transfers: Arc<FastDashMap<TransferId, Arc<Transfer>>>,
    /// Outgoing queries that are waiting for an answer.
    queries: Arc<QueriesCache>,
    /// Sending half of the sender queue; the receiving half sits in `init_state` until `start`.
    sender_queue_tx: SenderQueueTx,
    /// Resources consumed by `start`; `Some` before the node is started, `None` afterwards.
    init_state: Mutex<Option<InitializationState>>,
    /// `now()` captured at construction. It is used as `reinit_date` in the address list and
    /// passed to `Peer::new` for new peers. NOTE(review): presumably Unix seconds — confirm.
    start_time: u32,
    /// Cancelled by `shutdown` (which `Drop` also calls).
    cancellation_token: CancellationToken,
}
impl Node {
    /// Creates a node bound to a new UDP socket on `socket_addr`'s port.
    ///
    /// When the requested port is `0`, the port actually assigned to the socket is read back
    /// and stored, so [`Node::socket_addr`] reports the real port. An empty peer table is
    /// created for every key in `keystore`. The sender and receiver are only started by
    /// [`Node::start`].
    ///
    /// # Errors
    ///
    /// Fails if the UDP socket cannot be created or its local address cannot be read.
    pub fn new(
        mut socket_addr: SocketAddrV4,
        keystore: Keystore,
        options: NodeOptions,
        peer_filter: Option<Arc<dyn PeerFilter>>,
    ) -> Result<Arc<Self>> {
        let socket = make_udp_socket(socket_addr.port())?;
        if socket_addr.port() == 0 {
            let local_addr = socket.local_addr().context("Failed to select UDP port")?;
            socket_addr.set_port(local_addr.port());
        }

        let (sender_queue_tx, sender_queue_rx) = mpsc::unbounded_channel();

        // One peer table per local key; no other code path adds local ids afterwards.
        let mut peers =
            FastHashMap::with_capacity_and_hasher(keystore.keys().len(), Default::default());
        for key in keystore.keys().keys() {
            peers.insert(*key, Peers::default());
        }

        Ok(Arc::new(Self {
            socket_addr,
            keystore,
            options,
            peer_filter,
            peers,
            channels_by_id: Default::default(),
            channels_by_peers: Default::default(),
            incoming_transfers: Default::default(),
            queries: Default::default(),
            sender_queue_tx,
            init_state: Mutex::new(Some(InitializationState {
                socket,
                sender_queue_rx,
                message_subscribers: Default::default(),
                query_subscribers: Default::default(),
            })),
            start_time: now(),
            cancellation_token: Default::default(),
        }))
    }

    /// Returns the options this node was created with.
    #[inline(always)]
    pub fn options(&self) -> &NodeOptions {
        &self.options
    }

    /// Takes a snapshot of the sizes of the node's internal tables.
    ///
    /// `peer_count` is summed over the peer tables of all local ids.
    pub fn metrics(&self) -> NodeMetrics {
        NodeMetrics {
            peer_count: self.peers.values().map(|peers| peers.len()).sum(),
            channels_by_id_len: self.channels_by_id.len(),
            channels_by_peers_len: self.channels_by_peers.len(),
            incoming_transfers_len: self.incoming_transfers.len(),
            query_count: self.queries.len(),
        }
    }

    /// Registers a message subscriber that is handed to the receiver on [`Node::start`].
    ///
    /// # Errors
    ///
    /// Returns an "already running" error once the node has been started.
    pub fn add_message_subscriber(
        &self,
        message_subscriber: Arc<dyn MessageSubscriber>,
    ) -> Result<()> {
        self.with_init_state(|init| init.message_subscribers.push(message_subscriber))
    }

    /// Registers a query subscriber that is handed to the receiver on [`Node::start`].
    ///
    /// # Errors
    ///
    /// Returns an "already running" error once the node has been started.
    pub fn add_query_subscriber(&self, query_subscriber: Arc<dyn QuerySubscriber>) -> Result<()> {
        self.with_init_state(|init| init.query_subscribers.push(query_subscriber))
    }

    /// Starts the sender and receiver using the resources prepared by [`Node::new`].
    ///
    /// A `PingSubscriber` is appended after all user-registered query subscribers.
    ///
    /// # Errors
    ///
    /// Returns an "already running" error if the node has already been started.
    pub fn start(self: &Arc<Self>) -> Result<()> {
        // `take` leaves `None` behind, which is what marks the node as started.
        let mut init = self
            .init_state
            .lock()
            .take()
            .ok_or(NodeError::AlreadyRunning)?;

        init.query_subscribers.push(Arc::new(PingSubscriber));

        self.start_sender(init.socket.clone(), init.sender_queue_rx);
        self.start_receiver(
            init.socket,
            init.message_subscribers,
            init.query_subscribers,
        );
        Ok(())
    }

    /// Cancels the node's cancellation token.
    ///
    /// NOTE(review): presumably observed by the sender/receiver started in [`Node::start`]
    /// to stop them — confirm in those modules.
    pub fn shutdown(&self) {
        self.cancellation_token.cancel();
    }

    /// Chooses a query timeout in milliseconds.
    ///
    /// Uses `roundtrip` when known, otherwise `options.query_default_timeout_ms`. Either way
    /// the result is never below `options.query_min_timeout_ms`.
    pub fn compute_query_timeout(&self, roundtrip: Option<u64>) -> u64 {
        roundtrip
            .unwrap_or(self.options.query_default_timeout_ms)
            .max(self.options.query_min_timeout_ms)
    }

    /// Returns the local socket address (with the actually assigned port if `0` was requested).
    #[inline(always)]
    pub fn socket_addr(&self) -> SocketAddrV4 {
        self.socket_addr
    }

    /// Returns the `now()` value captured when the node was created.
    #[inline(always)]
    pub fn start_time(&self) -> u32 {
        self.start_time
    }

    /// Builds an address list advertising this node's socket address.
    ///
    /// The list is versioned with the current `now()`, uses the node's start time as
    /// `reinit_date`, and sets `expire_at` to `0`.
    pub fn build_address_list(&self) -> proto::adnl::AddressList {
        proto::adnl::AddressList {
            address: Some(proto::adnl::Address::from(&self.socket_addr)),
            version: now(),
            reinit_date: self.start_time,
            expire_at: 0,
        }
    }

    /// Looks up a local key by its short id.
    pub fn key_by_id(&self, id: &NodeIdShort) -> Result<&Arc<Key>, KeystoreError> {
        self.keystore.key_by_id(id)
    }

    /// Looks up a local key by its tag.
    pub fn key_by_tag(&self, tag: usize) -> Result<&Arc<Key>, KeystoreError> {
        self.keystore.key_by_tag(tag)
    }

    /// Adds `peer_id` to `local_id`'s peer table, or updates its address if already known.
    ///
    /// Returns `Ok(false)` without touching the table in three cases:
    /// - the peer is the local id itself;
    /// - `addr` equals this node's own socket address;
    /// - the configured peer filter rejects the peer.
    ///
    /// Otherwise returns `Ok(true)`.
    ///
    /// # Errors
    ///
    /// Fails if `local_id` has no peer table, i.e. it is not a key this node was created with.
    pub fn add_peer(
        &self,
        ctx: NewPeerContext,
        local_id: &NodeIdShort,
        peer_id: &NodeIdShort,
        addr: SocketAddrV4,
        peer_id_full: NodeIdFull,
    ) -> Result<bool> {
        use dashmap::mapref::entry::Entry;

        // Never register ourselves as a peer.
        if peer_id == local_id || addr == self.socket_addr {
            return Ok(false);
        }

        if let Some(filter) = &self.peer_filter {
            if !filter.check(ctx, addr, peer_id) {
                return Ok(false);
            }
        }

        match self.get_peers(local_id)?.entry(*peer_id) {
            Entry::Occupied(entry) => entry.get().set_addr(addr),
            Entry::Vacant(entry) => {
                entry.insert(Peer::new(self.start_time, addr, peer_id_full));
                tracing::trace!(%local_id, %peer_id, %addr, "added ADNL peer");
            }
        }
        Ok(true)
    }

    /// Removes `peer_id` from `local_id`'s peer table together with any channel to it.
    ///
    /// Channels are indexed by remote peer id only, so the channel is removed regardless of
    /// which local id established it. Returns whether the peer was present in the table.
    ///
    /// # Errors
    ///
    /// Fails if `local_id` has no peer table.
    pub fn remove_peer(&self, local_id: &NodeIdShort, peer_id: &NodeIdShort) -> Result<bool> {
        let peers = self.get_peers(local_id)?;
        self.remove_channels(peer_id);
        Ok(peers.remove(peer_id).is_some())
    }

    /// Returns the known address of `peer_id` for `local_id`.
    ///
    /// Returns `None` when either the local id or the peer is unknown.
    pub fn get_peer_address(
        &self,
        local_id: &NodeIdShort,
        peer_id: &NodeIdShort,
    ) -> Option<SocketAddrV4> {
        let peers = self.get_peers(local_id).ok()?;
        let peer = peers.get(peer_id)?;
        Some(peer.addr())
    }

    /// Maps values of `entries` to the ids of `local_id`'s peers at the corresponding addresses.
    ///
    /// Entries whose address matches no peer are left out of the result. Returns `None` when
    /// `local_id` is unknown.
    pub fn match_peer_addresses<T>(
        &self,
        local_id: &NodeIdShort,
        mut entries: FastHashMap<SocketAddrV4, T>,
    ) -> Option<FastHashMap<T, NodeIdShort>>
    where
        T: std::hash::Hash + Eq,
    {
        let peers = self.get_peers(local_id).ok()?;
        let mut result = FastHashMap::with_capacity_and_hasher(entries.len(), Default::default());
        for peer in peers.iter() {
            // Everything is matched already; stop scanning the peer table early.
            if entries.is_empty() {
                break;
            }
            if let Some(key) = entries.remove(&peer.addr()) {
                result.insert(key, *peer.key());
            }
        }
        Some(result)
    }

    /// Sends a TL query to `peer_id` on behalf of `local_id` and deserializes the boxed answer.
    ///
    /// Returns `Ok(None)` if no answer arrives within `timeout` milliseconds (see
    /// [`Node::query_raw`]).
    ///
    /// # Errors
    ///
    /// Fails if sending fails, the answer cannot be deserialized, or a triggered peer reset
    /// fails.
    pub async fn query<Q, A>(
        &self,
        local_id: &NodeIdShort,
        peer_id: &NodeIdShort,
        query: Q,
        timeout: Option<u64>,
    ) -> Result<Option<A>>
    where
        Q: TlWrite,
        for<'a> A: TlRead<'a, Repr = tl_proto::Boxed> + 'static,
    {
        self.query_with_optional_prefix(local_id, peer_id, None, query, timeout)
            .await
    }

    /// Like [`Node::query`], but the raw `prefix` bytes are sent before the serialized query.
    pub async fn query_with_prefix<Q, A>(
        &self,
        local_id: &NodeIdShort,
        peer_id: &NodeIdShort,
        prefix: &[u8],
        query: Q,
        timeout: Option<u64>,
    ) -> Result<Option<A>>
    where
        Q: TlWrite,
        for<'a> A: TlRead<'a, Repr = tl_proto::Boxed> + 'static,
    {
        self.query_with_optional_prefix(local_id, peer_id, Some(prefix), query, timeout)
            .await
    }

    /// Sends pre-serialized `query` bytes to `peer_id` and waits for the raw answer.
    ///
    /// The query is registered under an id from `gen_fast_bytes` and sent via `send_message`
    /// with its final flag set to `true`. The call then waits up to `timeout` milliseconds
    /// (default: `options.query_default_timeout_ms`).
    ///
    /// If no answer arrives and a channel to the peer existed right after sending, the
    /// channel's drop timeout is updated. When `update_drop_timeout` reports `true`, the peer
    /// pair is reset.
    ///
    /// # Errors
    ///
    /// Fails if sending fails or if a triggered peer reset fails.
    pub async fn query_raw(
        &self,
        local_id: &NodeIdShort,
        peer_id: &NodeIdShort,
        query: Bytes,
        timeout: Option<u64>,
    ) -> Result<Option<Vec<u8>>> {
        let query_id: QueryId = gen_fast_bytes();
        let pending_query = self.queries.add_query(query_id);

        self.send_message(
            local_id,
            peer_id,
            proto::adnl::Message::Query {
                query_id: &query_id,
                query: &query,
            },
            true,
        )?;
        // The payload is no longer needed; release it before the (potentially long) wait.
        drop(query);

        let channel = self
            .channels_by_peers
            .get(peer_id)
            .map(|entry| entry.value().clone());

        let timeout = timeout.unwrap_or(self.options.query_default_timeout_ms);
        let answer = tokio::time::timeout(Duration::from_millis(timeout), pending_query.wait())
            .await
            .ok()
            .flatten();

        if answer.is_none() {
            if let Some(channel) = channel {
                if channel.update_drop_timeout(now(), self.options.channel_reset_timeout_sec) {
                    self.reset_peer(local_id, peer_id)?;
                }
            }
        }

        Ok(answer)
    }

    /// Sends a custom message with `data` as payload to `peer_id` on behalf of `local_id`.
    ///
    /// The final `send_message` flag is taken from `options.force_use_priority_channels`.
    pub fn send_custom_message(
        &self,
        local_id: &NodeIdShort,
        peer_id: &NodeIdShort,
        data: &[u8],
    ) -> Result<()> {
        self.send_message(
            local_id,
            peer_id,
            proto::adnl::Message::Custom { data },
            self.options.force_use_priority_channels,
        )
    }

    /// Applies `f` to the pre-start initialization state.
    ///
    /// Fails with `NodeError::AlreadyRunning` once [`Node::start`] has consumed it.
    fn with_init_state<F>(&self, f: F) -> Result<()>
    where
        F: FnOnce(&mut InitializationState),
    {
        let mut guard = self.init_state.lock();
        match &mut *guard {
            Some(init) => {
                f(init);
                Ok(())
            }
            None => Err(NodeError::AlreadyRunning.into()),
        }
    }

    /// Shared implementation of [`Node::query`] and [`Node::query_with_prefix`].
    async fn query_with_optional_prefix<Q, A>(
        &self,
        local_id: &NodeIdShort,
        peer_id: &NodeIdShort,
        prefix: Option<&[u8]>,
        query: Q,
        timeout: Option<u64>,
    ) -> Result<Option<A>>
    where
        Q: TlWrite,
        for<'a> A: TlRead<'a, Repr = tl_proto::Boxed> + 'static,
    {
        let data = make_query(prefix, query);
        match self.query_raw(local_id, peer_id, data, timeout).await? {
            Some(answer) => Ok(Some(tl_proto::deserialize(&answer)?)),
            None => Ok(None),
        }
    }

    /// Returns the peer table of `local_id`, or `NodeError::PeersNotFound`.
    fn get_peers(&self, local_id: &NodeIdShort) -> Result<&Peers> {
        let peers = self.peers.get(local_id).ok_or(NodeError::PeersNotFound)?;
        Ok(peers)
    }

    /// Forgets the channel established with `peer_id`, if any.
    ///
    /// The channel is removed from `channels_by_peers`, and both of its incoming ids
    /// (ordinary and priority) are removed from `channels_by_id`.
    fn remove_channels(&self, peer_id: &NodeIdShort) {
        if let Some((_, channel)) = self.channels_by_peers.remove(peer_id) {
            self.channels_by_id.remove(channel.ordinary_channel_in_id());
            self.channels_by_id.remove(channel.priority_channel_in_id());
        }
    }

    /// Drops the channel to `peer_id` and calls `Peer::reset` on its entry in `local_id`'s table.
    ///
    /// # Errors
    ///
    /// Fails with `PeersNotFound` for an unknown local id, or `UnknownPeer` if `peer_id`
    /// is not in that table.
    fn reset_peer(&self, local_id: &NodeIdShort, peer_id: &NodeIdShort) -> Result<()> {
        let peers = self.get_peers(local_id)?;
        let mut peer = peers.get_mut(peer_id).ok_or(NodeError::UnknownPeer)?;

        tracing::trace!(%local_id, %peer_id, "resetting peer pair");

        self.remove_channels(peer_id);
        peer.reset();
        Ok(())
    }
}
impl Drop for Node {
fn drop(&mut self) {
self.shutdown()
}
}
/// Point-in-time snapshot of the sizes of a node's internal tables, as produced by
/// `Node::metrics`.
///
/// `Default` yields an all-zero snapshot. `PartialEq`/`Eq` allow snapshots to be compared
/// directly, e.g. in tests or when detecting changes.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)]
pub struct NodeMetrics {
    /// Total number of peers, summed over the peer tables of all local ids.
    pub peer_count: usize,
    /// Number of entries in the channel-id index.
    pub channels_by_id_len: usize,
    /// Number of entries in the per-peer channel index.
    pub channels_by_peers_len: usize,
    /// Number of incoming transfers currently tracked.
    pub incoming_transfers_len: usize,
    /// Number of entries in the pending-queries cache.
    pub query_count: usize,
}
/// Resources created by `Node::new` and handed off exactly once by `Node::start`.
///
/// Stored as `Some` in `Node::init_state` until `start` takes it out.
struct InitializationState {
    /// The bound UDP socket; `start` gives a clone to the sender and the original to the receiver.
    socket: Arc<tokio::net::UdpSocket>,
    /// Receiving half of the sender queue created in `Node::new`; passed to the sender on start.
    sender_queue_rx: SenderQueueRx,
    /// Subscribers registered via `Node::add_message_subscriber`; passed to the receiver on start.
    message_subscribers: Vec<Arc<dyn MessageSubscriber>>,
    /// Subscribers registered via `Node::add_query_subscriber`; `start` appends a
    /// `PingSubscriber` before passing them to the receiver.
    query_subscribers: Vec<Arc<dyn QuerySubscriber>>,
}
fn make_query<T>(prefix: Option<&[u8]>, query: T) -> Bytes
where
T: TlWrite,
{
let prefix_len = match prefix {
Some(prefix) => prefix.len(),
None => 0,
};
let mut data = Vec::with_capacity(prefix_len + query.max_size_hint());
if let Some(prefix) = prefix {
data.extend_from_slice(prefix);
}
query.write_to(&mut data);
data.into()
}
/// Internal error conditions raised by `Node` operations.
#[derive(thiserror::Error, Debug)]
enum NodeError {
    /// The initialization state was already consumed by `Node::start`, e.g. when adding
    /// subscribers after start or calling `start` twice.
    #[error("ADNL node is already running")]
    AlreadyRunning,
    /// The given local id has no peer table (it was not a keystore key at construction).
    #[error("Local id peers not found")]
    PeersNotFound,
    /// The peer is not present in the local id's peer table (e.g. in `reset_peer`).
    #[error("Unknown peer")]
    UnknownPeer,
}