1use std::net::SocketAddrV4;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::{Context, Result};
6use bytes::Bytes;
7use parking_lot::Mutex;
8use serde::{Deserialize, Serialize};
9use tl_proto::{TlRead, TlWrite};
10use tokio::sync::mpsc;
11use tokio_util::sync::CancellationToken;
12
13use self::receiver::*;
14use self::sender::*;
15use super::channel::{AdnlChannelId, Channel};
16use super::keystore::{Key, Keystore, KeystoreError};
17use super::node_id::{NodeIdFull, NodeIdShort};
18use super::peer::{NewPeerContext, Peer, PeerFilter, Peers};
19use super::ping_subscriber::PingSubscriber;
20use super::queries_cache::{QueriesCache, QueryId};
21use super::socket::make_udp_socket;
22use super::transfer::*;
23use crate::proto;
24use crate::subscriber::*;
25use crate::util::*;
26
27mod receiver;
28mod sender;
29
/// Tunable parameters of an ADNL [`Node`].
///
/// Because of `#[serde(default)]`, any field missing from a deserialized
/// config falls back to the value from [`NodeOptions::default`].
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct NodeOptions {
    /// Lower bound (milliseconds) enforced by `Node::compute_query_timeout`.
    pub query_min_timeout_ms: u64,

    /// Query timeout (milliseconds) used when the caller supplies neither an
    /// explicit timeout (`Node::query_raw`) nor a roundtrip estimate
    /// (`Node::compute_query_timeout`).
    pub query_default_timeout_ms: u64,

    /// Transfer timeout in seconds.
    /// NOTE(review): not read in this module; presumably used by the
    /// transfer/receiver logic — confirm.
    pub transfer_timeout_sec: u64,

    /// Clock tolerance in seconds.
    /// NOTE(review): not read in this module; presumably used when validating
    /// packet timestamps in the receiver — confirm.
    pub clock_tolerance_sec: u32,

    /// Passed to `Channel::update_drop_timeout` after an unanswered query;
    /// when that call reports `true`, the peer pair is reset.
    pub channel_reset_timeout_sec: u32,

    /// Address list timeout in seconds.
    /// NOTE(review): not read in this module — confirm where it is consumed.
    pub address_list_timeout_sec: u32,

    /// Whether packet history is enabled.
    /// NOTE(review): not read in this module; presumably consulted by the
    /// receiver — confirm.
    pub packet_history_enabled: bool,

    /// Whether incoming packets must be signed.
    /// NOTE(review): not read in this module; presumably enforced by the
    /// receiver — confirm.
    pub packet_signature_required: bool,

    /// Passed as the priority flag to `send_message` for queries and custom
    /// messages sent by this node.
    pub force_use_priority_channels: bool,

    /// Whether loopback addresses are used for neighbours.
    /// NOTE(review): not read in this module — confirm semantics.
    pub use_loopback_for_neighbours: bool,

    /// Optional protocol version.
    /// NOTE(review): not read in this module — confirm how `None` is treated.
    pub version: Option<u16>,
}
94
95impl Default for NodeOptions {
96 fn default() -> Self {
97 Self {
98 query_min_timeout_ms: 500,
99 query_default_timeout_ms: 5000,
100 transfer_timeout_sec: 3,
101 clock_tolerance_sec: 60,
102 channel_reset_timeout_sec: 30,
103 address_list_timeout_sec: 1000,
104 packet_history_enabled: false,
105 packet_signature_required: true,
106 force_use_priority_channels: true,
107 use_loopback_for_neighbours: false,
108 version: None,
109 }
110 }
111}
112
/// ADNL node: owns the local socket address, key store, per-key peer tables,
/// channels, pending queries and the outgoing packet queue.
pub struct Node {
    /// Local socket address; if port `0` was requested in `Node::new`, this
    /// holds the port actually assigned to the socket.
    socket_addr: SocketAddrV4,
    /// Local keys; `Node::new` creates one peer table per key.
    keystore: Keystore,
    /// Node configuration.
    options: NodeOptions,

    /// Optional filter consulted by `add_peer` before a peer is accepted.
    peer_filter: Option<Arc<dyn PeerFilter>>,

    /// Peer tables keyed by local id. Filled once in `Node::new`; the set of
    /// local ids never changes afterwards.
    peers: FastHashMap<NodeIdShort, Peers>,

    /// Channel receivers keyed by inbound channel id (the
    /// `ordinary_channel_in_id` / `priority_channel_in_id` of a channel).
    channels_by_id: FastDashMap<AdnlChannelId, ChannelReceiver>,
    /// Channels keyed by remote peer id.
    channels_by_peers: FastDashMap<NodeIdShort, Arc<Channel>>,

    /// Incoming transfers keyed by transfer id (shared via `Arc`; presumably
    /// with the receiver task — confirm).
    incoming_transfers: Arc<FastDashMap<TransferId, Arc<Transfer>>>,

    /// Pending outgoing queries awaiting an answer.
    queries: Arc<QueriesCache>,

    /// Sending half of the outgoing packet queue; the receiving half lives in
    /// `init_state` until `start` hands it to the sender.
    sender_queue_tx: SenderQueueTx,
    /// Resources consumed by `start`; `None` once the node has been started.
    init_state: Mutex<Option<InitializationState>>,

    /// Value of `now()` at construction; advertised as `reinit_date`.
    start_time: u32,

    /// Cancelled by `shutdown` (and therefore on drop).
    cancellation_token: CancellationToken,
}
150
151impl Node {
152 pub fn new(
154 mut socket_addr: SocketAddrV4,
155 keystore: Keystore,
156 options: NodeOptions,
157 peer_filter: Option<Arc<dyn PeerFilter>>,
158 ) -> Result<Arc<Self>> {
159 let socket = make_udp_socket(socket_addr.port())?;
161
162 if socket_addr.port() == 0 {
164 let local_addr = socket.local_addr().context("Failed to select UDP port")?;
165 socket_addr.set_port(local_addr.port());
166 }
167
168 let (sender_queue_tx, sender_queue_rx) = mpsc::unbounded_channel();
169
170 let mut peers =
172 FastHashMap::with_capacity_and_hasher(keystore.keys().len(), Default::default());
173 for key in keystore.keys().keys() {
174 peers.insert(*key, Peers::default());
175 }
176
177 Ok(Arc::new(Self {
178 socket_addr,
179 keystore,
180 options,
181 peer_filter,
182 peers,
183 channels_by_id: Default::default(),
184 channels_by_peers: Default::default(),
185 incoming_transfers: Default::default(),
186 queries: Default::default(),
187 sender_queue_tx,
188 init_state: Mutex::new(Some(InitializationState {
189 socket,
190 sender_queue_rx,
191 message_subscribers: Default::default(),
192 query_subscribers: Default::default(),
193 })),
194 start_time: now(),
195 cancellation_token: Default::default(),
196 }))
197 }
198
199 #[inline(always)]
201 pub fn options(&self) -> &NodeOptions {
202 &self.options
203 }
204
205 pub fn metrics(&self) -> NodeMetrics {
207 NodeMetrics {
208 peer_count: self.peers.values().map(|peers| peers.len()).sum(),
209 channels_by_id_len: self.channels_by_id.len(),
210 channels_by_peers_len: self.channels_by_peers.len(),
211 incoming_transfers_len: self.incoming_transfers.len(),
212 query_count: self.queries.len(),
213 }
214 }
215
216 pub fn add_message_subscriber(
218 &self,
219 message_subscriber: Arc<dyn MessageSubscriber>,
220 ) -> Result<()> {
221 let mut init = self.init_state.lock();
222 match &mut *init {
223 Some(init) => {
224 init.message_subscribers.push(message_subscriber);
225 Ok(())
226 }
227 None => Err(NodeError::AlreadyRunning.into()),
228 }
229 }
230
231 pub fn add_query_subscriber(&self, query_subscriber: Arc<dyn QuerySubscriber>) -> Result<()> {
233 let mut init = self.init_state.lock();
234 match &mut *init {
235 Some(init) => {
236 init.query_subscribers.push(query_subscriber);
237 Ok(())
238 }
239 None => Err(NodeError::AlreadyRunning.into()),
240 }
241 }
242
243 pub fn start(self: &Arc<Self>) -> Result<()> {
245 let mut init = match self.init_state.lock().take() {
247 Some(init) => init,
248 None => return Err(NodeError::AlreadyRunning.into()),
249 };
250
251 init.query_subscribers.push(Arc::new(PingSubscriber));
252
253 self.start_sender(init.socket.clone(), init.sender_queue_rx);
255 self.start_receiver(
256 init.socket,
257 init.message_subscribers,
258 init.query_subscribers,
259 );
260
261 Ok(())
263 }
264
265 pub fn shutdown(&self) {
267 self.cancellation_token.cancel();
268 }
269
270 pub fn compute_query_timeout(&self, roundtrip: Option<u64>) -> u64 {
272 let timeout = roundtrip.unwrap_or(self.options.query_default_timeout_ms);
273 std::cmp::max(self.options.query_min_timeout_ms, timeout)
274 }
275
276 #[inline(always)]
278 pub fn socket_addr(&self) -> SocketAddrV4 {
279 self.socket_addr
280 }
281
282 #[inline(always)]
284 pub fn start_time(&self) -> u32 {
285 self.start_time
286 }
287
288 pub fn build_address_list(&self) -> proto::adnl::AddressList {
290 proto::adnl::AddressList {
291 address: Some(proto::adnl::Address::from(&self.socket_addr)),
292 version: now(),
293 reinit_date: self.start_time,
294 expire_at: 0,
295 }
296 }
297
298 pub fn key_by_id(&self, id: &NodeIdShort) -> Result<&Arc<Key>, KeystoreError> {
302 self.keystore.key_by_id(id)
303 }
304
305 pub fn key_by_tag(&self, tag: usize) -> Result<&Arc<Key>, KeystoreError> {
309 self.keystore.key_by_tag(tag)
310 }
311
312 pub fn add_peer(
316 &self,
317 ctx: NewPeerContext,
318 local_id: &NodeIdShort,
319 peer_id: &NodeIdShort,
320 addr: SocketAddrV4,
321 peer_id_full: NodeIdFull,
322 ) -> Result<bool> {
323 use dashmap::mapref::entry::Entry;
324
325 if peer_id == local_id || addr == self.socket_addr {
327 return Ok(false);
328 }
329
330 if let Some(filter) = &self.peer_filter {
332 if !filter.check(ctx, addr, peer_id) {
333 return Ok(false);
334 }
335 }
336
337 match self.get_peers(local_id)?.entry(*peer_id) {
339 Entry::Occupied(entry) => entry.get().set_addr(addr),
341 Entry::Vacant(entry) => {
343 entry.insert(Peer::new(self.start_time, addr, peer_id_full));
344 tracing::trace!(%local_id, %peer_id, %addr, "added ADNL peer");
345 }
346 };
347
348 Ok(true)
349 }
350
351 pub fn remove_peer(&self, local_id: &NodeIdShort, peer_id: &NodeIdShort) -> Result<bool> {
358 let peers = self.get_peers(local_id)?;
359
360 self.channels_by_peers
361 .remove(peer_id)
362 .and_then(|(_, removed)| {
363 self.channels_by_id.remove(removed.ordinary_channel_in_id());
364 self.channels_by_id.remove(removed.priority_channel_in_id())
365 });
366
367 Ok(peers.remove(peer_id).is_some())
368 }
369
370 pub fn get_peer_address(
372 &self,
373 local_id: &NodeIdShort,
374 peer_id: &NodeIdShort,
375 ) -> Option<SocketAddrV4> {
376 let peers = self.get_peers(local_id).ok()?;
377 let peer = peers.get(peer_id)?;
378 Some(peer.addr())
379 }
380
381 pub fn match_peer_addresses<T>(
387 &self,
388 local_id: &NodeIdShort,
389 mut entries: FastHashMap<SocketAddrV4, T>,
390 ) -> Option<FastHashMap<T, NodeIdShort>>
391 where
392 T: std::hash::Hash + Eq,
393 {
394 let peers = self.get_peers(local_id).ok()?;
395
396 let mut result = FastHashMap::with_capacity_and_hasher(entries.len(), Default::default());
397 for peer in peers.iter() {
398 if let Some(key) = entries.remove(&peer.addr()) {
399 result.insert(key, *peer.key());
400 }
401 }
402
403 Some(result)
404 }
405
406 pub async fn query<Q, A>(
410 &self,
411 local_id: &NodeIdShort,
412 peer_id: &NodeIdShort,
413 query: Q,
414 timeout: Option<u64>,
415 ) -> Result<Option<A>>
416 where
417 Q: TlWrite,
418 for<'a> A: TlRead<'a, Repr = tl_proto::Boxed> + 'static,
419 {
420 match self
421 .query_raw(local_id, peer_id, make_query(None, query), timeout)
422 .await?
423 {
424 Some(answer) => Ok(Some(tl_proto::deserialize(&answer)?)),
425 None => Ok(None),
426 }
427 }
428
429 pub async fn query_with_prefix<Q, A>(
433 &self,
434 local_id: &NodeIdShort,
435 peer_id: &NodeIdShort,
436 prefix: &[u8],
437 query: Q,
438 timeout: Option<u64>,
439 ) -> Result<Option<A>>
440 where
441 Q: TlWrite,
442 for<'a> A: TlRead<'a, Repr = tl_proto::Boxed> + 'static,
443 {
444 match self
445 .query_raw(local_id, peer_id, make_query(Some(prefix), query), timeout)
446 .await?
447 {
448 Some(answer) => Ok(Some(tl_proto::deserialize(&answer)?)),
449 None => Ok(None),
450 }
451 }
452
453 pub async fn query_raw(
457 &self,
458 local_id: &NodeIdShort,
459 peer_id: &NodeIdShort,
460 query: Bytes,
461 timeout: Option<u64>,
462 ) -> Result<Option<Vec<u8>>> {
463 let query_id: QueryId = gen_fast_bytes();
464
465 let pending_query = self.queries.add_query(query_id);
466 self.send_message(
467 local_id,
468 peer_id,
469 proto::adnl::Message::Query {
470 query_id: &query_id,
471 query: &query,
472 },
473 self.options.force_use_priority_channels,
474 )?;
475 drop(query);
476
477 let channel = self
478 .channels_by_peers
479 .get(peer_id)
480 .map(|entry| entry.value().clone());
481
482 let timeout = timeout.unwrap_or(self.options.query_default_timeout_ms);
483 let answer = tokio::time::timeout(Duration::from_millis(timeout), pending_query.wait())
484 .await
485 .ok()
486 .flatten();
487
488 if answer.is_none() {
489 if let Some(channel) = channel {
490 if channel.update_drop_timeout(now(), self.options.channel_reset_timeout_sec) {
491 self.reset_peer(local_id, peer_id)?;
492 }
493 }
494 }
495
496 Ok(answer)
497 }
498
499 pub fn send_custom_message(
501 &self,
502 local_id: &NodeIdShort,
503 peer_id: &NodeIdShort,
504 data: &[u8],
505 ) -> Result<()> {
506 self.send_message(
507 local_id,
508 peer_id,
509 proto::adnl::Message::Custom { data },
510 self.options.force_use_priority_channels,
511 )
512 }
513
514 fn get_peers(&self, local_id: &NodeIdShort) -> Result<&Peers> {
515 if let Some(peers) = self.peers.get(local_id) {
516 Ok(peers)
517 } else {
518 Err(NodeError::PeersNotFound.into())
519 }
520 }
521
522 fn reset_peer(&self, local_id: &NodeIdShort, peer_id: &NodeIdShort) -> Result<()> {
523 let peers = self.get_peers(local_id)?;
524 let mut peer = peers.get_mut(peer_id).ok_or(NodeError::UnknownPeer)?;
525
526 tracing::trace!(%local_id, %peer_id, "resetting peer pair");
527
528 self.channels_by_peers
529 .remove(peer_id)
530 .and_then(|(_, removed)| {
531 self.channels_by_id.remove(removed.ordinary_channel_in_id());
532 self.channels_by_id.remove(removed.priority_channel_in_id())
533 });
534
535 peer.reset();
536
537 Ok(())
538 }
539}
540
541impl Drop for Node {
542 fn drop(&mut self) {
543 self.shutdown()
545 }
546}
547
/// Snapshot of the node's internal table sizes, produced by `Node::metrics`.
#[derive(Debug, Copy, Clone)]
pub struct NodeMetrics {
    /// Total number of peers summed over the peer tables of all local ids.
    pub peer_count: usize,
    /// Number of channel receivers keyed by channel id.
    pub channels_by_id_len: usize,
    /// Number of channels keyed by peer id.
    pub channels_by_peers_len: usize,
    /// Number of tracked incoming transfers.
    pub incoming_transfers_len: usize,
    /// Number of entries in the queries cache.
    pub query_count: usize,
}
562
/// Resources created by `Node::new` and consumed once by `Node::start`.
struct InitializationState {
    /// UDP socket shared between the sender and the receiver.
    socket: Arc<tokio::net::UdpSocket>,
    /// Receiving half of the outgoing packet queue, handed to the sender.
    sender_queue_rx: SenderQueueRx,
    /// Message subscribers registered before start, handed to the receiver.
    message_subscribers: Vec<Arc<dyn MessageSubscriber>>,
    /// Query subscribers registered before start, handed to the receiver.
    query_subscribers: Vec<Arc<dyn QuerySubscriber>>,
}
570
571fn make_query<T>(prefix: Option<&[u8]>, query: T) -> Bytes
572where
573 T: TlWrite,
574{
575 let prefix_len = match prefix {
576 Some(prefix) => prefix.len(),
577 None => 0,
578 };
579 let mut data = Vec::with_capacity(prefix_len + query.max_size_hint());
580 if let Some(prefix) = prefix {
581 data.extend_from_slice(prefix);
582 }
583 query.write_to(&mut data);
584 data.into()
585}
586
/// Errors produced by [`Node`] operations.
#[derive(thiserror::Error, Debug)]
enum NodeError {
    /// `start` or `add_*_subscriber` was called after the node was started.
    #[error("ADNL node is already running")]
    AlreadyRunning,
    /// The given local id has no peer table (it is not a local key).
    #[error("Local id peers not found")]
    PeersNotFound,
    /// The peer is not present in the local id's peer table.
    #[error("Unknown peer")]
    UnknownPeer,
}