Skip to main content

mtorrent_dht/
queries.rs

1mod manager;
2
3#[cfg(test)]
4mod tests;
5
6use super::error::Error;
7use super::msgs::*;
8use super::u160::U160;
9use super::udp;
10use derive_more::derive::From;
11use futures_util::StreamExt;
12use local_async_utils::prelude::*;
13use manager::{OutgoingQuery, QueryManager};
14use mtorrent_utils::{debug_stopwatch, trace_stopwatch};
15use std::fmt::Debug;
16use std::future::pending;
17use std::marker::PhantomData;
18use std::mem;
19use std::net::SocketAddr;
20use std::rc::Rc;
21use tokio::select;
22use tokio::sync::{Semaphore, mpsc};
23use tokio::time::{Instant, sleep_until};
24
25/// Client for sending outgoing queries to different nodes.
26#[derive(Clone)]
27pub struct OutboundQueries {
28    channel: local_unbounded::Sender<OutgoingQuery>,
29    query_slots: Rc<Semaphore>,
30}
31
32impl OutboundQueries {
33    pub(super) async fn ping(
34        &self,
35        destination: SocketAddr,
36        query: PingArgs,
37    ) -> Result<PingResponse, Error> {
38        let _sw = trace_stopwatch!("Ping query to {destination}");
39        self.do_query(destination, query).await
40    }
41
42    pub(super) async fn find_node(
43        &self,
44        destination: SocketAddr,
45        query: FindNodeArgs,
46    ) -> Result<FindNodeResponse, Error> {
47        let _sw = trace_stopwatch!("FindNode query to {destination}");
48        self.do_query(destination, query).await
49    }
50
51    pub(super) async fn get_peers(
52        &self,
53        destination: SocketAddr,
54        query: GetPeersArgs,
55    ) -> Result<GetPeersResponse, Error> {
56        let _sw = trace_stopwatch!("GetPeers query to {destination}");
57        self.do_query(destination, query).await
58    }
59
60    pub(super) async fn announce_peer(
61        &self,
62        destination: SocketAddr,
63        query: AnnouncePeerArgs,
64    ) -> Result<AnnouncePeerResponse, Error> {
65        let _sw = trace_stopwatch!("AnnouncePeer query to {destination}");
66        self.do_query(destination, query).await
67    }
68
69    async fn do_query<Q, R>(&self, dst_addr: SocketAddr, args: Q) -> Result<R, Error>
70    where
71        Q: Into<QueryMsg> + Debug,
72        R: TryFrom<ResponseMsg, Error = Error> + Debug,
73    {
74        let _slot = self.query_slots.acquire().await;
75        let (tx, rx) = local_oneshot::channel();
76        log::trace!("[{dst_addr}] <= {args:?}");
77        self.channel.send(OutgoingQuery {
78            query: args.into(),
79            destination_addr: dst_addr,
80            response_sink: tx,
81        })?;
82        let result = rx.await.ok_or(Error::ChannelClosed)?.and_then(R::try_from);
83        match &result {
84            Ok(response) => log::trace!("[{dst_addr}] => {response:?}"),
85            Err(Error::ErrorResponse(msg)) => log::debug!("[{dst_addr}] => {msg:?}"),
86            Err(Error::Timeout) => log::trace!("Query to {dst_addr} timed out"),
87            Err(e) => log::debug!("Query to {dst_addr} failed: {e}"),
88        }
89        result
90    }
91}
92
93/// Server for receiving incoming queries from different nodes.
94pub struct InboundQueries(pub(super) local_unbounded::Receiver<IncomingQuery>);
95
96/// Actor that routes queries between [`crate::Processor`] and [`crate::IoDriver`],
97/// performs retries and matches requests and responses.
98pub struct QueryRouter {
99    queries: QueryManager,
100    outgoing_queries_source: local_unbounded::Receiver<OutgoingQuery>,
101    incoming_msgs_source: mpsc::Receiver<(Message, SocketAddr)>,
102}
103
104impl QueryRouter {
105    pub async fn run(mut self) {
106        let _sw = debug_stopwatch!("Queries runner");
107        loop {
108            let next_timeout = self.queries.next_timeout();
109            select! {
110                biased;
111                outgoing = self.outgoing_queries_source.next() => {
112                    let Some(query) = outgoing else { break };
113                    if let Err(e) = self.queries.handle_one_outgoing(query).await {
114                        log::warn!("Error while handling outbound query: {e}");
115                        break;
116                    }
117                }
118                incoming = self.incoming_msgs_source.recv() => {
119                    let Some(msg) = incoming else { break };
120                    if let Err(e) = self.queries.handle_one_incoming(msg).await {
121                        log::warn!("Error while handling inbound query: {e}");
122                        break;
123                    }
124                }
125                _ = Self::sleep_until(next_timeout), if next_timeout.is_some() => {
126                    if let Err(e) = self.queries.handle_timeouts().await {
127                        log::warn!("Error while handling timeouts: {e}");
128                        break;
129                    }
130                }
131            }
132        }
133    }
134
135    async fn sleep_until(deadline: Option<Instant>) {
136        #[cfg(not(test))]
137        match deadline {
138            Some(deadline) => sleep_until(deadline).await,
139            _ => pending::<()>().await,
140        }
141
142        #[cfg(test)]
143        match deadline {
144            Some(deadline) if tests::SLEEP_ENABLED.get() => sleep_until(deadline).await,
145            _ => pending::<()>().await,
146        }
147    }
148}
149
150// ------------------------------------------------------------------------------------------------
151
152/// Create the layer that facilitates inbound and outbound transactions (queries).
153pub fn setup_queries(
154    udp::MessageChannelSender(outgoing_msgs_sink): udp::MessageChannelSender,
155    udp::MessageChannelReceiver(incoming_msgs_source): udp::MessageChannelReceiver,
156    max_concurrent_queries: Option<usize>,
157) -> (OutboundQueries, InboundQueries, QueryRouter) {
158    let (outgoing_queries_sink, outgoing_queries_source) = local_unbounded::channel();
159    let (incoming_queries_sink, incoming_queries_source) = local_unbounded::channel();
160    let max_in_flight = max_concurrent_queries.unwrap_or(udp::MSG_QUEUE_LEN);
161
162    let runner = QueryRouter {
163        queries: QueryManager::new(outgoing_msgs_sink, incoming_queries_sink),
164        outgoing_queries_source,
165        incoming_msgs_source,
166    };
167    let client = OutboundQueries {
168        channel: outgoing_queries_sink,
169        query_slots: Rc::new(Semaphore::const_new(max_in_flight)),
170    };
171    if max_in_flight == 0 {
172        client.query_slots.close();
173    }
174    (client, InboundQueries(incoming_queries_source), runner)
175}
176
177// ------------------------------------------------------------------------------------------------
178
179#[cfg_attr(test, derive(Debug))]
180#[derive(From)]
181pub(super) enum IncomingQuery {
182    Ping(IncomingPingQuery),
183    FindNode(IncomingFindNodeQuery),
184    GetPeers(IncomingGetPeersQuery),
185    AnnouncePeer(IncomingAnnouncePeerQuery),
186}
187
188#[cfg_attr(test, derive(Debug))]
189pub(super) struct IncomingGenericQuery<Q, R> {
190    transaction_id: Vec<u8>,
191    query: Q,
192    response_sink: Option<mpsc::OwnedPermit<(Message, SocketAddr)>>,
193    source_addr: SocketAddr,
194    _stopwatch: Stopwatch,
195    _response_type: PhantomData<R>,
196}
197
198pub(super) type IncomingPingQuery = IncomingGenericQuery<PingArgs, PingResponse>;
199pub(super) type IncomingFindNodeQuery = IncomingGenericQuery<FindNodeArgs, FindNodeResponse>;
200pub(super) type IncomingGetPeersQuery = IncomingGenericQuery<GetPeersArgs, GetPeersResponse>;
201pub(super) type IncomingAnnouncePeerQuery =
202    IncomingGenericQuery<AnnouncePeerArgs, AnnouncePeerResponse>;
203
204impl IncomingQuery {
205    fn new(
206        incoming: QueryMsg,
207        tid: Vec<u8>,
208        sink: mpsc::OwnedPermit<(Message, SocketAddr)>,
209        remote_addr: SocketAddr,
210    ) -> IncomingQuery {
211        macro_rules! construct {
212            ($query_args:expr, $name:literal) => {{
213                log::trace!("[{}] => {:?}", remote_addr, $query_args);
214                IncomingQuery::from(IncomingGenericQuery {
215                    transaction_id: tid,
216                    query: $query_args,
217                    response_sink: Some(sink),
218                    source_addr: remote_addr,
219                    _stopwatch: trace_stopwatch!("{} query from {}", $name, remote_addr),
220                    _response_type: PhantomData,
221                })
222            }};
223        }
224        match incoming {
225            QueryMsg::Ping(args) => construct!(args, "Ping"),
226            QueryMsg::FindNode(args) => construct!(args, "FindNode"),
227            QueryMsg::GetPeers(args) => construct!(args, "GetPeers"),
228            QueryMsg::AnnouncePeer(args) => construct!(args, "AnnouncePeer"),
229        }
230    }
231
232    pub(super) fn node_id(&self) -> &U160 {
233        match self {
234            IncomingQuery::Ping(q) => &q.args().id,
235            IncomingQuery::FindNode(q) => &q.args().id,
236            IncomingQuery::GetPeers(q) => &q.args().id,
237            IncomingQuery::AnnouncePeer(q) => &q.args().id,
238        }
239    }
240
241    pub(super) fn source_addr(&self) -> &SocketAddr {
242        match self {
243            IncomingQuery::Ping(q) => q.source_addr(),
244            IncomingQuery::FindNode(q) => q.source_addr(),
245            IncomingQuery::GetPeers(q) => q.source_addr(),
246            IncomingQuery::AnnouncePeer(q) => q.source_addr(),
247        }
248    }
249}
250
251impl<Q, R> IncomingGenericQuery<Q, R> {
252    pub(super) fn args(&self) -> &Q {
253        &self.query
254    }
255
256    pub(super) fn source_addr(&self) -> &SocketAddr {
257        &self.source_addr
258    }
259
260    pub(super) fn respond(mut self, response: R) -> Result<(), Error>
261    where
262        R: Into<ResponseMsg> + Debug,
263    {
264        log::trace!("[{}] <= {:?}", self.source_addr, response);
265        let sender = self.response_sink.take().unwrap_or_else(|| unreachable!()).send((
266            Message {
267                transaction_id: mem::take(&mut self.transaction_id),
268                version: None,
269                data: MessageData::Response(response.into()),
270            },
271            self.source_addr,
272        ));
273        if sender.is_closed() {
274            Err(Error::ChannelClosed)
275        } else {
276            Ok(())
277        }
278    }
279
280    pub(super) fn respond_error(mut self, error: ErrorMsg) -> Result<(), Error> {
281        log::debug!("[{}] <= {:?}", self.source_addr, error);
282        let sender = self.response_sink.take().unwrap_or_else(|| unreachable!()).send((
283            Message {
284                transaction_id: mem::take(&mut self.transaction_id),
285                version: None,
286                data: MessageData::Error(error),
287            },
288            self.source_addr,
289        ));
290        if sender.is_closed() {
291            Err(Error::ChannelClosed)
292        } else {
293            Ok(())
294        }
295    }
296}
297
298impl<Q, R> Drop for IncomingGenericQuery<Q, R> {
299    fn drop(&mut self) {
300        if let Some(sink) = self.response_sink.take() {
301            let error_msg = ErrorMsg {
302                error_code: ErrorCode::Server,
303                error_msg: "Unable to handle query".to_string(),
304            };
305            log::warn!("[{}] <= {:?}", self.source_addr, error_msg);
306            sink.send((
307                Message {
308                    transaction_id: mem::take(&mut self.transaction_id),
309                    version: None,
310                    data: MessageData::Error(error_msg),
311                },
312                self.source_addr,
313            ));
314        }
315    }
316}