use std::sync::Arc;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::sync::Semaphore;
use super::compression;
use super::transfers_cache::*;
use crate::adnl;
use crate::proto;
use crate::subscriber::*;
use crate::util::*;
/// Tuning parameters for an RLDP [`Node`] and its transfers cache.
///
/// Because of `#[serde(default)]`, any field missing from deserialized input is taken
/// from `NodeOptions::default()`.
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct NodeOptions {
/// Maximum answer size advertised in every outgoing RLDP query (bytes; default 10 MiB).
pub max_answer_size: u32,
/// Maximum number of concurrent `Node::query` calls per peer (per-peer semaphore size).
pub max_peer_queries: usize,
/// Lower query timeout bound in milliseconds; consumed by the transfers cache
/// (exact semantics not visible here — TODO confirm).
pub query_min_timeout_ms: u64,
/// Upper query timeout bound in milliseconds. Also used, rounded down to whole seconds,
/// to compute the `timeout` field of outgoing queries.
pub query_max_timeout_ms: u64,
/// Wave length used by the transfers cache (semantics not visible here — TODO confirm).
pub query_wave_len: u32,
/// Interval between waves in milliseconds, used by the transfers cache
/// (semantics not visible here — TODO confirm).
pub query_wave_interval_ms: u64,
/// Compress the payload of every outgoing query before sending it.
pub force_compression: bool,
}
impl Default for NodeOptions {
fn default() -> Self {
Self {
max_answer_size: 10 * 1024 * 1024,
max_peer_queries: 16,
query_min_timeout_ms: 500,
query_max_timeout_ms: 10000,
query_wave_len: 10,
query_wave_interval_ms: 10,
force_compression: false,
}
}
}
/// RLDP node: sends queries to peers over ADNL and routes transfer parts through a
/// shared transfers cache.
pub struct Node {
/// Underlying ADNL node used as transport.
adnl: Arc<adnl::Node>,
/// Per-peer semaphores bounding concurrent outgoing queries
/// (sized from `NodeOptions::max_peer_queries`); pruned by `Node::gc`.
semaphores: FastDashMap<adnl::NodeIdShort, Arc<Semaphore>>,
/// Transfers cache; registered as an ADNL message subscriber in `Node::new`.
transfers: Arc<TransfersCache>,
/// Options this node was created with (also passed to the transfers cache).
options: NodeOptions,
}
impl Node {
pub fn new(
adnl: Arc<adnl::Node>,
subscribers: Vec<Arc<dyn QuerySubscriber>>,
options: NodeOptions,
) -> Result<Arc<Self>> {
let transfers = Arc::new(TransfersCache::new(subscribers, options));
adnl.add_message_subscriber(transfers.clone())?;
Ok(Arc::new(Self {
adnl,
semaphores: Default::default(),
transfers,
options,
}))
}
#[inline(always)]
pub fn adnl(&self) -> &Arc<adnl::Node> {
&self.adnl
}
#[inline(always)]
pub fn options(&self) -> &NodeOptions {
&self.options
}
pub fn metrics(&self) -> NodeMetrics {
NodeMetrics {
peer_count: self.semaphores.len(),
transfers_cache_len: self.transfers.len(),
}
}
pub fn gc(&self) {
let max_permits = self.options.max_peer_queries;
self.semaphores
.retain(|_, semaphore| semaphore.available_permits() < max_permits);
}
#[tracing::instrument(level = "debug", name = "rldp_query", skip_all, fields(%local_id, %peer_id, ?roundtrip))]
pub async fn query(
&self,
local_id: &adnl::NodeIdShort,
peer_id: &adnl::NodeIdShort,
data: Vec<u8>,
roundtrip: Option<u64>,
) -> Result<(Option<Vec<u8>>, u64)> {
let (query_id, query) = self.make_query(data);
let peer = self
.semaphores
.entry(*peer_id)
.or_insert_with(|| Arc::new(Semaphore::new(self.options.max_peer_queries)))
.value()
.clone();
let result = {
let _permit = peer.acquire().await.ok();
self.transfers
.query(&self.adnl, local_id, peer_id, query, roundtrip)
.await
};
match result? {
(Some(answer), roundtrip) => match tl_proto::deserialize(&answer) {
Ok(proto::rldp::Message::Answer {
query_id: answer_id,
data,
}) if answer_id == &query_id => Ok((
Some(compression::decompress(data).unwrap_or_else(|| data.to_vec())),
roundtrip,
)),
Ok(proto::rldp::Message::Answer { .. }) => Err(NodeError::QueryIdMismatch.into()),
Ok(proto::rldp::Message::Message { .. }) => {
Err(NodeError::UnexpectedAnswer("RldpMessageView::Message").into())
}
Ok(proto::rldp::Message::Query { .. }) => {
Err(NodeError::UnexpectedAnswer("RldpMessageView::Query").into())
}
Err(e) => Err(NodeError::InvalidPacketContent(e).into()),
},
(None, roundtrip) => Ok((None, roundtrip)),
}
}
fn make_query(&self, mut data: Vec<u8>) -> ([u8; 32], Vec<u8>) {
if self.options.force_compression {
if let Err(e) = compression::compress(&mut data) {
tracing::warn!("failed to compress RLDP query: {e:?}");
}
}
let query_id = gen_fast_bytes();
let data = proto::rldp::Message::Query {
query_id: &query_id,
max_answer_size: self.options.max_answer_size as u64,
timeout: now() + self.options.query_max_timeout_ms as u32 / 1000,
data: &data,
};
(query_id, tl_proto::serialize(data))
}
}
#[async_trait::async_trait]
impl MessageSubscriber for TransfersCache {
    /// Consumes RLDP transfer parts (`messagePart`, `confirm`, `complete`) received over ADNL.
    ///
    /// Any other constructor yields `Ok(false)` (presumably leaving it to other
    /// subscribers — TODO confirm). A recognized part that fails `is_valid` is reported as
    /// consumed but not handled.
    ///
    /// # Errors
    ///
    /// Fails if a recognized part cannot be deserialized or if handling it fails.
    async fn try_consume_custom<'a>(
        &self,
        ctx: SubscriberContext<'a>,
        constructor: u32,
        data: &'a [u8],
    ) -> Result<bool> {
        let is_rldp_part = [
            proto::rldp::MessagePart::TL_ID_MESSAGE_PART,
            proto::rldp::MessagePart::TL_ID_CONFIRM,
            proto::rldp::MessagePart::TL_ID_COMPLETE,
        ]
        .contains(&constructor);
        if !is_rldp_part {
            return Ok(false);
        }

        let part = tl_proto::deserialize::<proto::rldp::MessagePart<'_>>(data)?;
        if !part.is_valid() {
            return Ok(true);
        }

        self.handle_message(ctx.adnl, ctx.local_id, ctx.peer_id, part)
            .await?;
        Ok(true)
    }
}
/// Point-in-time counters returned by `Node::metrics`.
#[derive(Debug, Copy, Clone)]
pub struct NodeMetrics {
/// Number of peers that currently have a per-peer query semaphore.
pub peer_count: usize,
/// Length of the transfers cache as reported by `TransfersCache::len`.
pub transfers_cache_len: usize,
}
#[derive(thiserror::Error, Debug)]
enum NodeError {
#[error("Unexpected answer: {0}")]
UnexpectedAnswer(&'static str),
#[error("Invalid packet content: {0:?}")]
InvalidPacketContent(tl_proto::TlError),
#[error("Unknown query id")]
QueryIdMismatch,
}