1mod manager;
2
3#[cfg(test)]
4mod tests;
5
6use super::error::Error;
7use super::msgs::*;
8use super::u160::U160;
9use super::udp;
10use derive_more::derive::From;
11use futures_util::StreamExt;
12use local_async_utils::prelude::*;
13use manager::{OutgoingQuery, QueryManager};
14use mtorrent_utils::{debug_stopwatch, trace_stopwatch};
15use std::fmt::Debug;
16use std::future::pending;
17use std::marker::PhantomData;
18use std::mem;
19use std::net::SocketAddr;
20use std::rc::Rc;
21use tokio::select;
22use tokio::sync::{Semaphore, mpsc};
23use tokio::time::{Instant, sleep_until};
24
/// Client-side handle used to send queries to remote nodes.
///
/// Cloneable; clones share the same `Rc<Semaphore>` and a clone of the same
/// outgoing-query sender.
#[derive(Clone)]
pub struct OutboundQueries {
    /// Outgoing queries; the receiving end is owned by the [`QueryRouter`] (see `setup_queries`).
    channel: local_unbounded::Sender<OutgoingQuery>,
    /// Limits how many queries may be awaiting a response at once. When the semaphore is
    /// closed (`max_concurrent_queries == Some(0)` in `setup_queries`) no limit is applied.
    query_slots: Rc<Semaphore>,
}
31
32impl OutboundQueries {
33 pub(super) async fn ping(
34 &self,
35 destination: SocketAddr,
36 query: PingArgs,
37 ) -> Result<PingResponse, Error> {
38 let _sw = trace_stopwatch!("Ping query to {destination}");
39 self.do_query(destination, query).await
40 }
41
42 pub(super) async fn find_node(
43 &self,
44 destination: SocketAddr,
45 query: FindNodeArgs,
46 ) -> Result<FindNodeResponse, Error> {
47 let _sw = trace_stopwatch!("FindNode query to {destination}");
48 self.do_query(destination, query).await
49 }
50
51 pub(super) async fn get_peers(
52 &self,
53 destination: SocketAddr,
54 query: GetPeersArgs,
55 ) -> Result<GetPeersResponse, Error> {
56 let _sw = trace_stopwatch!("GetPeers query to {destination}");
57 self.do_query(destination, query).await
58 }
59
60 pub(super) async fn announce_peer(
61 &self,
62 destination: SocketAddr,
63 query: AnnouncePeerArgs,
64 ) -> Result<AnnouncePeerResponse, Error> {
65 let _sw = trace_stopwatch!("AnnouncePeer query to {destination}");
66 self.do_query(destination, query).await
67 }
68
69 async fn do_query<Q, R>(&self, dst_addr: SocketAddr, args: Q) -> Result<R, Error>
70 where
71 Q: Into<QueryMsg> + Debug,
72 R: TryFrom<ResponseMsg, Error = Error> + Debug,
73 {
74 let _slot = self.query_slots.acquire().await;
75 let (tx, rx) = local_oneshot::channel();
76 log::trace!("[{dst_addr}] <= {args:?}");
77 self.channel.send(OutgoingQuery {
78 query: args.into(),
79 destination_addr: dst_addr,
80 response_sink: tx,
81 })?;
82 let result = rx.await.ok_or(Error::ChannelClosed)?.and_then(R::try_from);
83 match &result {
84 Ok(response) => log::trace!("[{dst_addr}] => {response:?}"),
85 Err(Error::ErrorResponse(msg)) => log::debug!("[{dst_addr}] => {msg:?}"),
86 Err(Error::Timeout) => log::trace!("Query to {dst_addr} timed out"),
87 Err(e) => log::debug!("Query to {dst_addr} failed: {e}"),
88 }
89 result
90 }
91}
92
/// Queries received from remote nodes. This is the receiving end of the channel whose sender
/// is handed to the router's `QueryManager` in `setup_queries`.
pub struct InboundQueries(pub(super) local_unbounded::Receiver<IncomingQuery>);
95
/// Event loop that multiplexes outgoing queries, incoming UDP messages and query timeouts.
///
/// Created by `setup_queries` and driven by [`QueryRouter::run`].
pub struct QueryRouter {
    /// Handles outgoing/incoming traffic and timeouts (see the `manager` module).
    queries: QueryManager,
    /// Queries submitted through [`OutboundQueries`].
    outgoing_queries_source: local_unbounded::Receiver<OutgoingQuery>,
    /// Messages from the UDP layer; the address is presumably the remote peer's — confirm.
    incoming_msgs_source: mpsc::Receiver<(Message, SocketAddr)>,
}
103
104impl QueryRouter {
105 pub async fn run(mut self) {
106 let _sw = debug_stopwatch!("Queries runner");
107 loop {
108 let next_timeout = self.queries.next_timeout();
109 select! {
110 biased;
111 outgoing = self.outgoing_queries_source.next() => {
112 let Some(query) = outgoing else { break };
113 if let Err(e) = self.queries.handle_one_outgoing(query).await {
114 log::warn!("Error while handling outbound query: {e}");
115 break;
116 }
117 }
118 incoming = self.incoming_msgs_source.recv() => {
119 let Some(msg) = incoming else { break };
120 if let Err(e) = self.queries.handle_one_incoming(msg).await {
121 log::warn!("Error while handling inbound query: {e}");
122 break;
123 }
124 }
125 _ = Self::sleep_until(next_timeout), if next_timeout.is_some() => {
126 if let Err(e) = self.queries.handle_timeouts().await {
127 log::warn!("Error while handling timeouts: {e}");
128 break;
129 }
130 }
131 }
132 }
133 }
134
135 async fn sleep_until(deadline: Option<Instant>) {
136 #[cfg(not(test))]
137 match deadline {
138 Some(deadline) => sleep_until(deadline).await,
139 _ => pending::<()>().await,
140 }
141
142 #[cfg(test)]
143 match deadline {
144 Some(deadline) if tests::SLEEP_ENABLED.get() => sleep_until(deadline).await,
145 _ => pending::<()>().await,
146 }
147 }
148}
149
150pub fn setup_queries(
154 udp::MessageChannelSender(outgoing_msgs_sink): udp::MessageChannelSender,
155 udp::MessageChannelReceiver(incoming_msgs_source): udp::MessageChannelReceiver,
156 max_concurrent_queries: Option<usize>,
157) -> (OutboundQueries, InboundQueries, QueryRouter) {
158 let (outgoing_queries_sink, outgoing_queries_source) = local_unbounded::channel();
159 let (incoming_queries_sink, incoming_queries_source) = local_unbounded::channel();
160 let max_in_flight = max_concurrent_queries.unwrap_or(udp::MSG_QUEUE_LEN);
161
162 let runner = QueryRouter {
163 queries: QueryManager::new(outgoing_msgs_sink, incoming_queries_sink),
164 outgoing_queries_source,
165 incoming_msgs_source,
166 };
167 let client = OutboundQueries {
168 channel: outgoing_queries_sink,
169 query_slots: Rc::new(Semaphore::const_new(max_in_flight)),
170 };
171 if max_in_flight == 0 {
172 client.query_slots.close();
173 }
174 (client, InboundQueries(incoming_queries_source), runner)
175}
176
/// A query received from a remote node, tagged by query type.
///
/// `From` is derived so that each typed query converts into its variant; `IncomingQuery::new`
/// relies on this.
#[cfg_attr(test, derive(Debug))]
#[derive(From)]
pub(super) enum IncomingQuery {
    Ping(IncomingPingQuery),
    FindNode(IncomingFindNodeQuery),
    GetPeers(IncomingGetPeersQuery),
    AnnouncePeer(IncomingAnnouncePeerQuery),
}
187
/// A single incoming query with arguments of type `Q`, to be answered with an `R`.
///
/// If it is dropped without `respond`/`respond_error` being called, the `Drop` impl sends a
/// generic server error to the querying node instead.
#[cfg_attr(test, derive(Debug))]
pub(super) struct IncomingGenericQuery<Q, R> {
    /// Transaction id of the query; copied into the reply message.
    transaction_id: Vec<u8>,
    /// Decoded query arguments.
    query: Q,
    /// Reserved slot in the outgoing message channel; `None` once it has been used.
    response_sink: Option<mpsc::OwnedPermit<(Message, SocketAddr)>>,
    /// Address the query came from; replies are sent there.
    source_addr: SocketAddr,
    /// Trace-level stopwatch started when the query was constructed.
    _stopwatch: Stopwatch,
    /// Binds the response type `R` without storing a value of it.
    _response_type: PhantomData<R>,
}
197
/// Incoming `Ping` query, answered with a [`PingResponse`].
pub(super) type IncomingPingQuery = IncomingGenericQuery<PingArgs, PingResponse>;
/// Incoming `FindNode` query, answered with a [`FindNodeResponse`].
pub(super) type IncomingFindNodeQuery = IncomingGenericQuery<FindNodeArgs, FindNodeResponse>;
/// Incoming `GetPeers` query, answered with a [`GetPeersResponse`].
pub(super) type IncomingGetPeersQuery = IncomingGenericQuery<GetPeersArgs, GetPeersResponse>;
/// Incoming `AnnouncePeer` query, answered with an [`AnnouncePeerResponse`].
pub(super) type IncomingAnnouncePeerQuery =
    IncomingGenericQuery<AnnouncePeerArgs, AnnouncePeerResponse>;
203
204impl IncomingQuery {
205 fn new(
206 incoming: QueryMsg,
207 tid: Vec<u8>,
208 sink: mpsc::OwnedPermit<(Message, SocketAddr)>,
209 remote_addr: SocketAddr,
210 ) -> IncomingQuery {
211 macro_rules! construct {
212 ($query_args:expr, $name:literal) => {{
213 log::trace!("[{}] => {:?}", remote_addr, $query_args);
214 IncomingQuery::from(IncomingGenericQuery {
215 transaction_id: tid,
216 query: $query_args,
217 response_sink: Some(sink),
218 source_addr: remote_addr,
219 _stopwatch: trace_stopwatch!("{} query from {}", $name, remote_addr),
220 _response_type: PhantomData,
221 })
222 }};
223 }
224 match incoming {
225 QueryMsg::Ping(args) => construct!(args, "Ping"),
226 QueryMsg::FindNode(args) => construct!(args, "FindNode"),
227 QueryMsg::GetPeers(args) => construct!(args, "GetPeers"),
228 QueryMsg::AnnouncePeer(args) => construct!(args, "AnnouncePeer"),
229 }
230 }
231
232 pub(super) fn node_id(&self) -> &U160 {
233 match self {
234 IncomingQuery::Ping(q) => &q.args().id,
235 IncomingQuery::FindNode(q) => &q.args().id,
236 IncomingQuery::GetPeers(q) => &q.args().id,
237 IncomingQuery::AnnouncePeer(q) => &q.args().id,
238 }
239 }
240
241 pub(super) fn source_addr(&self) -> &SocketAddr {
242 match self {
243 IncomingQuery::Ping(q) => q.source_addr(),
244 IncomingQuery::FindNode(q) => q.source_addr(),
245 IncomingQuery::GetPeers(q) => q.source_addr(),
246 IncomingQuery::AnnouncePeer(q) => q.source_addr(),
247 }
248 }
249}
250
251impl<Q, R> IncomingGenericQuery<Q, R> {
252 pub(super) fn args(&self) -> &Q {
253 &self.query
254 }
255
256 pub(super) fn source_addr(&self) -> &SocketAddr {
257 &self.source_addr
258 }
259
260 pub(super) fn respond(mut self, response: R) -> Result<(), Error>
261 where
262 R: Into<ResponseMsg> + Debug,
263 {
264 log::trace!("[{}] <= {:?}", self.source_addr, response);
265 let sender = self.response_sink.take().unwrap_or_else(|| unreachable!()).send((
266 Message {
267 transaction_id: mem::take(&mut self.transaction_id),
268 version: None,
269 data: MessageData::Response(response.into()),
270 },
271 self.source_addr,
272 ));
273 if sender.is_closed() {
274 Err(Error::ChannelClosed)
275 } else {
276 Ok(())
277 }
278 }
279
280 pub(super) fn respond_error(mut self, error: ErrorMsg) -> Result<(), Error> {
281 log::debug!("[{}] <= {:?}", self.source_addr, error);
282 let sender = self.response_sink.take().unwrap_or_else(|| unreachable!()).send((
283 Message {
284 transaction_id: mem::take(&mut self.transaction_id),
285 version: None,
286 data: MessageData::Error(error),
287 },
288 self.source_addr,
289 ));
290 if sender.is_closed() {
291 Err(Error::ChannelClosed)
292 } else {
293 Ok(())
294 }
295 }
296}
297
298impl<Q, R> Drop for IncomingGenericQuery<Q, R> {
299 fn drop(&mut self) {
300 if let Some(sink) = self.response_sink.take() {
301 let error_msg = ErrorMsg {
302 error_code: ErrorCode::Server,
303 error_msg: "Unable to handle query".to_string(),
304 };
305 log::warn!("[{}] <= {:?}", self.source_addr, error_msg);
306 sink.send((
307 Message {
308 transaction_id: mem::take(&mut self.transaction_id),
309 version: None,
310 data: MessageData::Error(error_msg),
311 },
312 self.source_addr,
313 ));
314 }
315 }
316}