Skip to main content

mtorrent_dht/
processor.rs

1use super::cmds::{Command, CommandSource};
2use super::kademlia;
3use super::msgs::*;
4use super::peers::TokenManager;
5use super::queries::{InboundQueries, IncomingQuery, OutboundQueries};
6use super::tasks::*;
7use super::u160::U160;
8use crate::config::Config;
9use crate::kademlia::Node;
10use crate::peers::PeerTable;
11use futures_util::StreamExt;
12use local_async_utils::prelude::*;
13use mtorrent_utils::{info_stopwatch, warn_stopwatch};
14use std::collections::hash_map::Entry;
15use std::collections::{HashMap, HashSet};
16use std::net::{SocketAddr, ToSocketAddrs};
17use std::ops::ControlFlow::{self, *};
18use std::path::PathBuf;
19use std::rc::Rc;
20use std::time::Duration;
21use tokio::sync::mpsc;
22use tokio::sync::mpsc::error::TrySendError;
23use tokio::{select, task, time};
24use tokio_util::sync::CancellationToken;
25
/// Concrete routing table used by the processor: `kademlia::RoutingTable` with its const
/// generic parameter fixed to 16.
// NOTE(review): 16 is presumably the bucket size (cf. `RoutingTable::BUCKET_SIZE`) - confirm.
type RoutingTable = kademlia::RoutingTable<16>;
27
28/// Actor that maintains the routing table, responds to incoming queries,
29/// keeps track of discovered peers, and handles commands from the user.
30pub struct Processor {
31    node_table: Box<RoutingTable>,
32    peers: PeerTable,
33    token_mgr: TokenManager,
34    known_nodes: HashSet<SocketAddr>,
35    search_callbacks: HashMap<U160, mpsc::Sender<SocketAddr>>,
36    task_ctx: Rc<Ctx>,
37
38    peer_sender: local_unbounded::Sender<(SocketAddr, U160)>,
39    peer_receiver: local_unbounded::Receiver<(SocketAddr, U160)>,
40    node_event_receiver: mpsc::Receiver<NodeEvent>,
41
42    config: Config,
43    config_dir: PathBuf,
44    canceller: CancellationToken,
45}
46
47impl Processor {
48    pub fn new(config_dir: PathBuf, client: OutboundQueries) -> Self {
49        let (peer_sender, peer_receiver) = local_unbounded::channel();
50        let (node_event_sender, node_event_receiver) = mpsc::channel(1024);
51
52        let config = Config::load(&config_dir).unwrap_or_else(|e| {
53            log::info!("Failed to load config ({e}), using defaults");
54            Config::default()
55        });
56        let node_ctx = Rc::new(Ctx {
57            client: client.clone(),
58            event_reporter: node_event_sender,
59            local_id: config.local_id,
60        });
61        let node_table = RoutingTable::new_boxed(config.local_id);
62
63        Self {
64            peers: PeerTable::new(),
65            token_mgr: TokenManager::new(),
66            peer_sender,
67            peer_receiver,
68            config,
69            config_dir,
70            task_ctx: node_ctx,
71            node_event_receiver,
72            node_table,
73            known_nodes: HashSet::with_capacity(512),
74            search_callbacks: HashMap::new(),
75            canceller: CancellationToken::new(),
76        }
77    }
78
79    pub fn set_bootstrap_nodes(&mut self, nodes: Vec<String>) {
80        self.config.nodes = nodes;
81    }
82
83    pub async fn run(mut self, mut queries: InboundQueries, mut commands: CommandSource) {
84        macro_rules! handle_next_event {
85            ($queries:expr $(,$commands:expr)?) => {
86                select! {
87                    biased;
88                    $(cmd = $commands.next() => match cmd {
89                        Some(cmd) => self.handle_command(cmd),
90                        None => Break(()),
91                    },)?
92                    event = self.node_event_receiver.recv() => match event {
93                        Some(event) => Continue(self.handle_node_event(event)),
94                        None => Break(()),
95                    },
96                    query = $queries.0.next() => match query {
97                        Some(query) => Continue(self.handle_query(query)),
98                        None => Break(()),
99                    },
100                    peer_target = self.peer_receiver.next() => match peer_target {
101                        Some((peer, target)) => Continue(self.peers.add_record(target, peer)),
102                        None => unreachable!(),
103                    },
104                }
105            };
106        }
107
108        const BOOTSTRAP_TIMEOUT: Duration = sec!(10);
109        const BOOTSTRAP_TARGET: usize = 200;
110
111        // do all DNS resolution first because it will block the thread
112        let sw = warn_stopwatch!("DNS resolution of bootstrapping nodes");
113        let bootstrapping_nodes: Vec<SocketAddr> = self
114            .config
115            .nodes
116            .iter()
117            .filter_map(|node| node.to_socket_addrs().ok())
118            .flatten()
119            .filter(SocketAddr::is_ipv4)
120            .collect();
121        drop(sw);
122
123        // ignore commands until we have a certain number of nodes (or the timeout occurs)
124        if !bootstrapping_nodes.is_empty() {
125            let _sw = info_stopwatch!("Bootstrapping");
126
127            for addr in bootstrapping_nodes {
128                if self.known_nodes.insert(addr) {
129                    task::spawn_local(
130                        self.canceller
131                            .clone()
132                            .run_until_cancelled_owned(probe_node(addr, self.task_ctx.clone())),
133                    );
134                }
135            }
136
137            _ = time::timeout(BOOTSTRAP_TIMEOUT, async {
138                while self.node_table.iter().count() < BOOTSTRAP_TARGET
139                    && handle_next_event!(queries).is_continue()
140                {}
141            })
142            .await;
143        }
144
145        // now we're ready to handle commands
146        while handle_next_event!(queries, commands).is_continue() {}
147    }
148
149    fn handle_node_event(&mut self, event: NodeEvent) {
150        match event {
151            NodeEvent::Discovered(node) => {
152                if self.node_table.can_insert(&node.id) && self.known_nodes.insert(node.addr) {
153                    task::spawn_local(
154                        self.canceller.clone().run_until_cancelled_owned(probe_node(
155                            node.addr,
156                            self.task_ctx.clone(),
157                        )),
158                    );
159                }
160            }
161            NodeEvent::Connected(node) => {
162                if self.node_table.insert_node(&node.id, &node.addr) {
163                    task::spawn_local(
164                        self.canceller.clone().run_until_cancelled_owned(keep_alive_node(
165                            node,
166                            self.task_ctx.clone(),
167                        )),
168                    );
169                } else {
170                    self.known_nodes.remove(&node.addr);
171                }
172            }
173            NodeEvent::Disconnected(node) => {
174                self.known_nodes.remove(&node.addr);
175                self.node_table.remove_node(&node.id);
176            }
177            NodeEvent::Unreachable(addr) => {
178                self.known_nodes.remove(&addr);
179            }
180        }
181    }
182
183    fn handle_query(&mut self, query: IncomingQuery) {
184        let node = Node {
185            id: *query.node_id(),
186            addr: *query.source_addr(),
187        };
188
189        let result = match query {
190            IncomingQuery::Ping(ping) => ping.respond(PingResponse {
191                id: self.task_ctx.local_id,
192            }),
193            IncomingQuery::FindNode(find_node) => {
194                let mut nodes: Vec<_> = self
195                    .node_table
196                    .get_closest_nodes(&find_node.args().target, 8)
197                    .filter_map(|node| match node.addr {
198                        SocketAddr::V4(socket_addr_v4) => Some((node.id, socket_addr_v4)),
199                        SocketAddr::V6(_) => None,
200                    })
201                    .take(8)
202                    .collect();
203                if let Some(&exact_match) =
204                    nodes.iter().find(|(id, _)| *id == find_node.args().target)
205                {
206                    nodes.clear();
207                    nodes.push(exact_match);
208                }
209                find_node.respond(FindNodeResponse {
210                    id: self.task_ctx.local_id,
211                    nodes,
212                })
213            }
214            IncomingQuery::GetPeers(get_peers) => {
215                let token = self.token_mgr.generate_token_for(get_peers.source_addr());
216                let peer_addrs: Vec<_> = self
217                    .peers
218                    .get_ipv4_peers(get_peers.args().info_hash)
219                    .take(128) // apprx to fit MTU
220                    .cloned()
221                    .collect();
222                let response_data = if !peer_addrs.is_empty() {
223                    GetPeersResponseData::Peers(peer_addrs)
224                } else {
225                    GetPeersResponseData::Nodes(
226                        self.node_table
227                            .get_closest_nodes(&get_peers.args().info_hash, 8)
228                            .filter_map(|node| match node.addr {
229                                SocketAddr::V4(socket_addr_v4) => Some((node.id, socket_addr_v4)),
230                                SocketAddr::V6(_) => None,
231                            })
232                            .take(8)
233                            .collect(),
234                    )
235                };
236                get_peers.respond(GetPeersResponse {
237                    id: self.task_ctx.local_id,
238                    token: Some(token),
239                    data: response_data,
240                })
241            }
242            IncomingQuery::AnnouncePeer(announce_peer) => {
243                if !self
244                    .token_mgr
245                    .validate_token_from(announce_peer.source_addr(), &announce_peer.args().token)
246                {
247                    announce_peer.respond_error(ErrorMsg {
248                        error_code: ErrorCode::Generic,
249                        error_msg: "Invalid token".to_owned(),
250                    })
251                } else {
252                    // construct peer address
253                    let mut peer_addr = *announce_peer.source_addr();
254                    if let Some(port) = announce_peer.args().port {
255                        peer_addr.set_port(port);
256                    }
257                    // add to peer table
258                    self.peers.add_record(announce_peer.args().info_hash, peer_addr);
259                    // notify active search if any
260                    if let Entry::Occupied(entry) =
261                        self.search_callbacks.entry(announce_peer.args().info_hash)
262                    {
263                        match entry.get().try_send(peer_addr) {
264                            Ok(_) => (),
265                            Err(TrySendError::Closed(_)) => {
266                                entry.remove();
267                            }
268                            Err(TrySendError::Full(_)) => {
269                                log::error!(
270                                    "Failed to report announced peer: callback channel full"
271                                );
272                            }
273                        }
274                    }
275                    // respond to the query
276                    announce_peer.respond(AnnouncePeerResponse {
277                        id: self.task_ctx.local_id,
278                    })
279                }
280            }
281        };
282        if let Err(e) = result {
283            log::warn!("Failed to respond to query: {e}");
284        }
285
286        if self.node_table.can_insert(&node.id) && self.known_nodes.insert(node.addr) {
287            task::spawn_local(
288                self.canceller
289                    .clone()
290                    .run_until_cancelled_owned(probe_node(node.addr, self.task_ctx.clone())),
291            );
292        }
293    }
294
295    fn handle_command(&mut self, cmd: Command) -> ControlFlow<()> {
296        log::info!("Processing command: {cmd:?}");
297        match cmd {
298            Command::AddNode { addr } => {
299                if self.known_nodes.insert(addr) {
300                    task::spawn_local(
301                        self.canceller
302                            .clone()
303                            .run_until_cancelled_owned(probe_node(addr, self.task_ctx.clone())),
304                    );
305                }
306                Continue(())
307            }
308            Command::FindPeers {
309                info_hash,
310                callback,
311                local_peer_port,
312            } => {
313                if self
314                    .peers
315                    .get_peers(info_hash.into())
316                    .try_for_each(|addr| callback.try_send(*addr))
317                    .is_ok()
318                {
319                    let search_data = SearchTaskData {
320                        target: info_hash.into(),
321                        local_peer_port,
322                        ctx: self.task_ctx.clone(),
323                        cmd_result_sender: callback.clone(),
324                        peer_sender: self.peer_sender.clone(),
325                    };
326                    let initial_nodes: Vec<Node> = self
327                        .node_table
328                        .get_closest_nodes(&info_hash.into(), RoutingTable::BUCKET_SIZE * 3)
329                        .cloned()
330                        .collect();
331                    if !initial_nodes.is_empty() {
332                        self.search_callbacks.insert(search_data.target, callback);
333                        self.search_callbacks.retain(|_, cb| !cb.is_closed());
334                        task::spawn_local(run_search(
335                            search_data,
336                            self.canceller.child_token(),
337                            initial_nodes.into_iter(),
338                        ));
339                    } else {
340                        log::warn!("Search can't proceed - no initial nodes");
341                    }
342                }
343                Continue(())
344            }
345            Command::Shutdown => Break(()),
346        }
347    }
348}
349
350impl Drop for Processor {
351    fn drop(&mut self) {
352        let connected_nodes: Vec<String> =
353            self.node_table.iter().map(|node| node.addr.to_string()).collect();
354
355        log::info!("Processor shutting down, node_count = {}", connected_nodes.len());
356
357        if !connected_nodes.is_empty() {
358            self.config.nodes = connected_nodes;
359            self.config.nodes.extend(Config::default().nodes);
360        }
361        if let Err(e) = self.config.save(&self.config_dir) {
362            log::error!("Failed to save config: {e}");
363        }
364
365        self.canceller.cancel();
366    }
367}