1use crate::error::Result;
2use crate::metadata::RbitFetcher;
3use crate::protocol::{DhtArgs, DhtMessage, DhtResponse};
4use crate::scheduler::MetadataScheduler;
5use crate::sharded::{NodeTuple, ShardedNodeQueue};
6use crate::types::{DHTOptions, NetMode, TorrentInfo};
7use ahash::AHasher;
8#[cfg(feature = "metrics")]
9use metrics::{counter, gauge};
10use rand::Rng;
11use socket2::{Domain, Protocol, Socket, Type};
12use std::future::Future;
13use std::hash::{Hash, Hasher};
14use std::net::{IpAddr, Ipv6Addr, SocketAddr};
15use std::pin::Pin;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::{Arc, RwLock};
18use std::time::Duration;
19use tokio::net::UdpSocket;
20use tokio::sync::{Semaphore, mpsc};
21use tokio_util::sync::CancellationToken;
22
/// Public DHT routers, as `host:port` strings, that `bootstrap` resolves and sends an initial
/// `find_node` query to.
const BOOTSTRAP_NODES: &[&str] = &[
    "router.bittorrent.com:6881",
    "dht.transmissionbt.com:6881",
    "router.utorrent.com:6881",
    "dht.aelitis.com:6881",
];
29
/// Heap-allocated, pinned, `Send` future that resolves to a `bool`.
pub type BoxedBoolFuture = Pin<Box<dyn Future<Output = bool> + Send>>;

/// Shared, thread-safe callback that receives a `String` and returns a [`BoxedBoolFuture`].
///
/// Registered through `DHTServer::on_metadata_fetch` and handed to the metadata scheduler.
/// NOTE(review): the argument is presumably the hex info-hash and the `bool` presumably decides
/// whether metadata is fetched; confirm against the scheduler.
pub type MetadataFetchCallback = Arc<dyn Fn(String) -> BoxedBoolFuture + Send + Sync>;
32
/// Event sent to the metadata scheduler when a validated `announce_peer` reveals an info-hash.
#[derive(Debug, Clone)]
pub struct HashDiscovered {
    /// Hex encoding of the 20-byte info-hash.
    pub info_hash: String,
    /// Address of the announcing peer: the sender's IP plus the announced (or implied) port.
    pub peer_addr: SocketAddr,
    /// Monotonic timestamp captured when the event was created.
    pub discovered_at: std::time::Instant,
}
39
40type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
/// Shared, thread-safe predicate over a string slice; installed via `set_filter`.
///
/// `handle_announce_peer` calls it with the hex info-hash and drops the announce on `false`.
type FilterCallback = Arc<dyn Fn(&str) -> bool + Send + Sync>;
42
43#[derive(Clone)]
44pub struct DHTServer {
45 #[allow(dead_code)]
46 options: DHTOptions,
47 node_id: Vec<u8>,
48 socket: Arc<UdpSocket>,
49 socket_v6: Option<Arc<UdpSocket>>,
50 token_secret: Vec<u8>,
51 callback: Arc<RwLock<Option<TorrentCallback>>>,
52 filter: Arc<RwLock<Option<FilterCallback>>>,
53 on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
54 node_queue: Arc<ShardedNodeQueue>,
55 hash_tx: mpsc::Sender<HashDiscovered>,
56 metadata_queue_len: Arc<AtomicUsize>,
57 max_metadata_queue_size: usize,
58 shutdown: CancellationToken,
59}
60
61impl DHTServer {
62 pub async fn new(options: DHTOptions) -> Result<Self> {
63 let (socket, socket_v6) = match options.netmode {
64 NetMode::Ipv4Only => {
65 let sock = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
66 #[cfg(not(windows))]
67 {
68 let _ = sock.set_reuse_port(true);
69 }
70 let _ = sock.set_reuse_address(true);
71 sock.set_nonblocking(true)?;
72
73 let _ = sock.set_recv_buffer_size(32 * 1024 * 1024);
74 let _ = sock.set_send_buffer_size(8 * 1024 * 1024);
75
76 let addr: SocketAddr = format!("0.0.0.0:{}", options.port).parse().unwrap();
77 sock.bind(&addr.into())?;
78 (Arc::new(UdpSocket::from_std(sock.into())?), None)
79 }
80 NetMode::Ipv6Only => {
81 let sock = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
82 #[cfg(not(windows))]
83 {
84 let _ = sock.set_reuse_port(true);
85 }
86 let _ = sock.set_reuse_address(true);
87 #[cfg(not(windows))]
88 {
89 let _ = sock.set_only_v6(true);
90 }
91 sock.set_nonblocking(true)?;
92
93 let _ = sock.set_recv_buffer_size(32 * 1024 * 1024);
94 let _ = sock.set_send_buffer_size(8 * 1024 * 1024);
95
96 let addr: SocketAddr = format!("[::]:{}", options.port).parse().unwrap();
97 sock.bind(&addr.into())?;
98 (Arc::new(UdpSocket::from_std(sock.into())?), None)
99 }
100 NetMode::DualStack => {
101 let sock_v4 = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
102 #[cfg(not(windows))]
103 {
104 let _ = sock_v4.set_reuse_port(true);
105 }
106 let _ = sock_v4.set_reuse_address(true);
107 sock_v4.set_nonblocking(true)?;
108 let _ = sock_v4.set_recv_buffer_size(32 * 1024 * 1024);
109 let _ = sock_v4.set_send_buffer_size(8 * 1024 * 1024);
110 let addr_v4: SocketAddr = format!("0.0.0.0:{}", options.port).parse().unwrap();
111 sock_v4.bind(&addr_v4.into())?;
112 let socket = Arc::new(UdpSocket::from_std(sock_v4.into())?);
113
114 let sock_v6 = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
115 #[cfg(not(windows))]
116 {
117 let _ = sock_v6.set_reuse_port(true);
118 }
119 let _ = sock_v6.set_reuse_address(true);
120 #[cfg(not(windows))]
121 {
122 let _ = sock_v6.set_only_v6(true);
123 }
124 sock_v6.set_nonblocking(true)?;
125 let _ = sock_v6.set_recv_buffer_size(32 * 1024 * 1024);
126 let _ = sock_v6.set_send_buffer_size(8 * 1024 * 1024);
127 let addr_v6: SocketAddr = format!("[::]:{}", options.port).parse().unwrap();
128 sock_v6.bind(&addr_v6.into())?;
129 let socket_v6 = Some(Arc::new(UdpSocket::from_std(sock_v6.into())?));
130
131 (socket, socket_v6)
132 }
133 };
134
135 let node_id = generate_random_id();
136 let mut rng = rand::thread_rng();
137 let token_secret: Vec<u8> = (0..10).map(|_| rng.r#gen::<u8>()).collect();
138
139 let node_queue = ShardedNodeQueue::new(options.node_queue_capacity);
140
141 let (hash_tx, hash_rx) = mpsc::channel::<HashDiscovered>(options.hash_queue_capacity);
142
143 let fetcher = Arc::new(RbitFetcher::new(options.metadata_timeout));
144
145 let callback = Arc::new(RwLock::new(None));
146 let on_metadata_fetch = Arc::new(RwLock::new(None));
147
148 let metadata_queue_len = Arc::new(AtomicUsize::new(0));
149
150 let shutdown = CancellationToken::new();
151 let shutdown_for_scheduler = shutdown.clone();
152
153 let scheduler = MetadataScheduler::new(
154 hash_rx,
155 fetcher,
156 options.max_metadata_queue_size,
157 options.max_metadata_worker_count,
158 callback.clone(),
159 on_metadata_fetch.clone(),
160 metadata_queue_len.clone(),
161 shutdown_for_scheduler,
162 );
163
164 tokio::spawn(async move {
165 scheduler.run().await;
166 });
167
168 let max_metadata_queue_size = options.max_metadata_queue_size;
169 let server = Self {
170 options,
171 node_id: node_id.clone(),
172 socket,
173 socket_v6,
174 token_secret,
175 callback,
176 on_metadata_fetch,
177 node_queue: Arc::new(node_queue),
178 filter: Arc::new(RwLock::new(None)),
179 hash_tx,
180 metadata_queue_len,
181 max_metadata_queue_size,
182 shutdown,
183 };
184
185 Ok(server)
186 }
187
188 pub fn local_addr(&self) -> Result<SocketAddr> {
189 Ok(self.socket.local_addr()?)
190 }
191
192 fn is_addr_allowed(&self, addr: &SocketAddr) -> bool {
193 match self.options.netmode {
194 NetMode::Ipv4Only => addr.is_ipv4(),
195 NetMode::Ipv6Only => addr.is_ipv6(),
196 NetMode::DualStack => true,
197 }
198 }
199
200 fn select_socket(&self, addr: &SocketAddr) -> &Arc<UdpSocket> {
201 match self.options.netmode {
202 NetMode::Ipv4Only => &self.socket,
203 NetMode::Ipv6Only => &self.socket,
204 NetMode::DualStack => {
205 if addr.is_ipv6() {
206 self.socket_v6.as_ref().unwrap_or(&self.socket)
207 } else {
208 &self.socket
209 }
210 }
211 }
212 }
213
214 pub fn on_metadata_fetch<F, Fut>(&self, callback: F)
215 where
216 F: Fn(String) -> Fut + Send + Sync + 'static,
217 Fut: Future<Output = bool> + Send + 'static,
218 {
219 *self.on_metadata_fetch.write().unwrap() =
220 Some(Arc::new(move |hash| Box::pin(callback(hash))));
221 }
222
223 pub fn on_torrent<F>(&self, callback: F)
224 where
225 F: Fn(TorrentInfo) + Send + Sync + 'static,
226 {
227 *self.callback.write().unwrap() = Some(Arc::new(callback));
228 }
229
230 pub fn set_filter<F>(&self, filter: F)
231 where
232 F: Fn(&str) -> bool + Send + Sync + 'static,
233 {
234 *self.filter.write().unwrap() = Some(Arc::new(filter));
235 }
236
237 pub fn get_node_pool_size(&self) -> usize {
238 self.node_queue.len()
239 }
240
241 pub async fn start(&self) -> Result<()> {
242 if self.shutdown.is_cancelled() {
244 log::warn!("⚠️ 尝试启动已关闭的服务器");
245 return Err(crate::error::DHTError::Other("服务器已关闭".to_string()));
246 }
247
248 self.start_receiver();
249 self.bootstrap().await;
250
251 let server = self.clone();
252 let shutdown = self.shutdown.clone();
253
254 tokio::spawn(async move {
255 let semaphore = Arc::new(Semaphore::new(2000));
256 let mut loop_tick = 0;
257
258 loop {
259 if shutdown.is_cancelled() {
261 #[cfg(debug_assertions)]
262 log::trace!("主循环收到关闭信号,退出");
263 break;
264 }
265
266 let queue_len = server.metadata_queue_len.load(Ordering::Relaxed);
267 let queue_pressure = queue_len as f64 / server.max_metadata_queue_size as f64;
268
269 #[cfg(feature = "metrics")]
270 {
271 gauge!("dht_metadata_queue_size").set(queue_len as f64);
272 gauge!("dht_metadata_worker_pressure").set(queue_pressure);
273 gauge!("dht_node_queue_size").set(server.node_queue.len() as f64);
274 }
275
276 let (batch_size, sleep_duration) = if queue_pressure < 0.8 {
277 (200, Duration::from_millis(10))
278 } else if queue_pressure < 0.95 {
279 (20, Duration::from_millis(500))
280 } else {
281 (0, Duration::from_millis(1000))
282 };
283
284 let filter_ipv6 = match server.options.netmode {
285 NetMode::Ipv4Only => Some(false),
286 NetMode::Ipv6Only => Some(true),
287 NetMode::DualStack => None,
288 };
289
290 let queue_empty = server.node_queue.is_empty_for(filter_ipv6);
291
292 let nodes_batch = {
293 if queue_empty || batch_size == 0 {
294 None
295 } else {
296 Some(server.node_queue.pop_batch(batch_size, filter_ipv6))
297 }
298 };
299
300 loop_tick += 1;
301 if nodes_batch.is_none() || loop_tick % 50 == 0 {
302 server.bootstrap().await;
303 if nodes_batch.is_none() {
304 tokio::select! {
305 _ = shutdown.cancelled() => break,
306 _ = tokio::time::sleep(sleep_duration) => {},
307 }
308 continue;
309 }
310 }
311
312 if let Some(nodes) = nodes_batch {
313 let node_id = server.node_id.clone();
314 let socket = server.socket.clone();
315 let socket_v6 = server.socket_v6.clone();
316 let netmode = server.options.netmode;
317
318 for node in nodes {
319 let permit = semaphore.clone().acquire_owned().await.unwrap();
320 let node_id_clone = node_id.clone();
321 let socket_clone = socket.clone();
322 let socket_v6_clone = socket_v6.clone();
323 let node_addr = node.addr;
324 let node_id_for_target = node.id;
325
326 tokio::spawn(async move {
327 let neighbor_id =
328 generate_neighbor_target(&node_id_for_target, &node_id_clone);
329 let random_target = generate_random_id();
330 let _ = send_find_node_impl(
331 node_addr,
332 &random_target,
333 &neighbor_id,
334 &socket_clone,
335 socket_v6_clone.as_ref(),
336 netmode,
337 )
338 .await;
339 drop(permit);
340 });
341 }
342 }
343
344 tokio::select! {
345 _ = shutdown.cancelled() => break,
346 _ = tokio::time::sleep(sleep_duration) => {},
347 }
348 }
349 });
350 self.shutdown.cancelled().await;
351 Ok(())
352 }
353
354 pub fn shutdown(&self) {
356 self.shutdown.cancel();
357 }
358
359 fn start_receiver(&self) {
360 let socket = self.socket.clone();
361 let socket_v6 = self.socket_v6.clone();
362 let server = self.clone();
363 let shutdown = self.shutdown.clone();
364
365 let num_workers = std::thread::available_parallelism()
366 .map(|n| n.get())
367 .unwrap_or(8);
368
369 let queue_size = 5000;
370
371 let mut senders = Vec::with_capacity(num_workers);
372 for _ in 0..num_workers {
373 let (tx, mut rx) = mpsc::channel::<(Vec<u8>, SocketAddr)>(queue_size);
374 senders.push(tx);
375
376 let server_clone = server.clone();
377 let shutdown_worker = shutdown.clone();
378
379 tokio::spawn(async move {
380 loop {
381 tokio::select! {
382 _ = shutdown_worker.cancelled() => {
383 #[cfg(debug_assertions)]
384 log::trace!("Worker 收到关闭信号,退出");
385 break;
386 }
387 msg = rx.recv() => {
388 match msg {
389 Some((data, addr)) => {
390 let _ = server_clone.handle_message(&data, addr).await;
391 }
392 None => break,
393 }
394 }
395 }
396 }
397 });
398 }
399
400 Self::spawn_udp_reader(socket, senders.clone(), shutdown.clone());
401
402 if let Some(socket_v6) = socket_v6 {
403 Self::spawn_udp_reader(socket_v6, senders, shutdown);
404 }
405 }
406
407 fn spawn_udp_reader(
408 socket: Arc<UdpSocket>,
409 senders: Vec<mpsc::Sender<(Vec<u8>, SocketAddr)>>,
410 shutdown: CancellationToken,
411 ) {
412 let num_workers = senders.len();
413
414 tokio::spawn(async move {
415 let mut buf = [0u8; 65536];
416 let mut next_worker_idx = 0;
417
418 loop {
419 tokio::select! {
420 _ = shutdown.cancelled() => {
421 #[cfg(debug_assertions)]
422 log::trace!("UDP 读取循环收到关闭信号,退出");
423 break;
424 }
425 result = socket.recv_from(&mut buf) => {
426 match result {
427 Ok((size, addr)) => {
428 #[cfg(feature = "metrics")]
429 counter!("dht_udp_bytes_received_total").increment(size as u64);
430
431 if size > 8192 {
432 #[cfg(feature = "metrics")]
433 counter!("dht_udp_packets_received_total", "status" => "dropped_size").increment(1);
434
435 #[cfg(debug_assertions)]
436 log::trace!("⚠️ 拒绝异常大的 UDP 包: {} 字节 from {}", size, addr);
437 continue;
438 }
439
440 if size == 0 || buf[0] != b'd' {
441 #[cfg(feature = "metrics")]
442 counter!("dht_udp_packets_received_total", "status" => "dropped_magic").increment(1);
443 continue;
444 }
445
446 let data = buf[..size].to_vec();
447
448 let tx = &senders[next_worker_idx];
449 next_worker_idx = (next_worker_idx + 1) % num_workers;
450
451 match tx.try_send((data, addr)) {
452 Ok(_) => {
453 #[cfg(feature = "metrics")]
454 counter!("dht_udp_packets_received_total", "status" => "ok").increment(1);
455 },
456 Err(mpsc::error::TrySendError::Full(_)) => {
457 #[cfg(feature = "metrics")]
458 counter!("dht_udp_packets_received_total", "status" => "queue_full").increment(1);
459
460 #[cfg(debug_assertions)]
461 log::trace!("UDP worker queue full, dropping packet");
462 },
463 Err(_) => { break; }
464 }
465 }
466 Err(_e) => {
467 tokio::select! {
468 _ = shutdown.cancelled() => break,
469 _ = tokio::time::sleep(Duration::from_millis(1)) => {},
470 }
471 }
472 }
473 }
474 }
475 }
476 });
477 }
478
479 async fn handle_message(&self, data: &[u8], addr: SocketAddr) -> Result<()> {
480 if !self.is_addr_allowed(&addr) {
481 #[cfg(debug_assertions)]
482 log::trace!(
483 "⚠️ 拒绝不匹配的地址类型: {} (当前模式: {:?})",
484 addr,
485 self.options.netmode
486 );
487 return Ok(());
488 }
489
490 let msg: DhtMessage = match serde_bencode::from_bytes(data) {
491 Ok(m) => m,
492 Err(_) => {
493 #[cfg(feature = "metrics")]
494 counter!("dht_messages_parse_error_total").increment(1);
495 return Ok(());
496 }
497 };
498
499 #[cfg(feature = "metrics")]
500 {
501 let label = match msg.y.as_str() {
503 "q" => "q",
504 "r" => "r",
505 "e" => "e",
506 _ => "unknown", };
508 counter!("dht_messages_processed_total", "type" => label).increment(1);
509 }
510
511 match msg.y.as_str() {
512 "q" => {
513 if let Some(q_type) = &msg.q {
514 self.handle_query(&msg, q_type.as_bytes(), addr).await?;
515 }
516 }
517 "r" => {
518 if let Some(response) = &msg.r {
519 self.handle_response(response).await?;
520 }
521 }
522 _ => {}
523 }
524 Ok(())
525 }
526
527 async fn handle_query(
528 &self,
529 msg: &DhtMessage,
530 query_type: &[u8],
531 addr: SocketAddr,
532 ) -> Result<()> {
533 let args = match &msg.a {
534 Some(a) => a,
535 None => return Ok(()),
536 };
537
538 let transaction_id = &msg.t;
539 let sender_id: Option<&[u8]> = args.id.as_deref().map(|v| v.as_slice());
540 let target_id_fallback: Option<&[u8]> = args
541 .target
542 .as_deref()
543 .or(args.info_hash.as_deref())
544 .map(|v| v.as_slice());
545
546 let q_str = std::str::from_utf8(query_type).unwrap_or("");
547
548 #[cfg(feature = "metrics")]
549 {
550 let label = match q_str {
551 "ping" => "ping",
552 "find_node" => "find_node",
553 "get_peers" => "get_peers",
554 "announce_peer" => "announce_peer",
555 "vote" => "vote",
556 _ => "other_or_invalid",
557 };
558 counter!("dht_queries_total", "q" => label).increment(1);
559 }
560
561 if q_str == "announce_peer" {
562 self.handle_announce_peer(args, addr).await?;
563 }
564
565 self.send_response(transaction_id, addr, q_str, sender_id, target_id_fallback)
566 .await?;
567 Ok(())
568 }
569
570 async fn handle_announce_peer(&self, args: &DhtArgs, addr: SocketAddr) -> Result<()> {
571 if let Some(token) = &args.token {
572 if !self.validate_token(token, addr) {
573 #[cfg(feature = "metrics")]
574 counter!("dht_announce_peer_blocked_total", "reason" => "invalid_token")
575 .increment(1);
576 return Ok(());
577 }
578 } else {
579 return Ok(());
580 }
581
582 if let Some(info_hash) = &args.info_hash {
583 let info_hash_arr: [u8; 20] = match info_hash.as_ref().try_into() {
584 Ok(arr) => arr,
585 Err(_) => return Ok(()),
586 };
587 let hash_hex = hex::encode(info_hash_arr);
588
589 let filter_cb = self.filter.read().unwrap().clone();
590 if let Some(f) = filter_cb
591 && !f(&hash_hex) {
592 #[cfg(feature = "metrics")]
593 counter!("dht_announce_peer_blocked_total", "reason" => "filtered")
594 .increment(1);
595 return Ok(());
596 }
597
598 #[cfg(feature = "metrics")]
599 counter!("dht_info_hashes_discovered_total").increment(1);
600
601 #[cfg(debug_assertions)]
602 log::debug!("🔥 新 Hash: {} 来自 {}", hash_hex, addr);
603
604 let port = if let Some(implied) = args.implied_port {
605 if implied != 0 {
606 addr.port()
607 } else {
608 args.port.unwrap_or(0)
609 }
610 } else {
611 args.port.unwrap_or(addr.port())
612 };
613
614 if port > 0 {
615 let event = HashDiscovered {
616 info_hash: hash_hex,
617 peer_addr: SocketAddr::new(addr.ip(), port),
618 discovered_at: std::time::Instant::now(),
619 };
620
621 if self.hash_tx.try_send(event).is_err() {
622 #[cfg(debug_assertions)]
623 log::debug!("⚠️ Hash 队列满,丢弃 hash");
624 }
625 }
626 }
627 Ok(())
628 }
629
630 async fn handle_response(&self, response: &DhtResponse) -> Result<()> {
631 if let Some(nodes_bytes) = &response.nodes {
632 self.process_compact_nodes(nodes_bytes);
633 }
634 if let Some(nodes6_bytes) = &response.nodes6 {
635 self.process_compact_nodes_v6(nodes6_bytes);
636 }
637 Ok(())
638 }
639
640 fn process_compact_nodes(&self, nodes_bytes: &[u8]) {
641 if self.options.netmode == NetMode::Ipv6Only {
642 return;
643 }
644
645 #[allow(clippy::manual_is_multiple_of)]
646 if nodes_bytes.len() % 26 != 0 {
647 return;
648 }
649
650 for chunk in nodes_bytes.chunks(26) {
651 let id = chunk[0..20].to_vec();
652 let port = u16::from_be_bytes([chunk[24], chunk[25]]);
653
654 let ip = std::net::Ipv4Addr::new(chunk[20], chunk[21], chunk[22], chunk[23]);
655 let addr = SocketAddr::new(std::net::IpAddr::V4(ip), port);
656
657 #[cfg(feature = "metrics")]
658 counter!("dht_nodes_discovered_total", "ip_version" => "v4").increment(1);
659
660 self.node_queue.push(NodeTuple { id, addr });
661 }
662 }
663
664 fn process_compact_nodes_v6(&self, nodes_bytes: &[u8]) {
665 if self.options.netmode == NetMode::Ipv4Only {
666 return;
667 }
668
669 #[allow(clippy::manual_is_multiple_of)]
670 if nodes_bytes.len() % 38 != 0 {
671 return;
672 }
673 for chunk in nodes_bytes.chunks(38) {
674 let id = chunk[0..20].to_vec();
675 let port = u16::from_be_bytes([chunk[36], chunk[37]]);
676 let ip_bytes: [u8; 16] = match chunk[20..36].try_into() {
677 Ok(b) => b,
678 Err(_) => continue,
679 };
680 let ip = Ipv6Addr::from(ip_bytes);
681 if !ip.is_unspecified() && !ip.is_multicast() {
682 let addr = SocketAddr::new(IpAddr::V6(ip), port);
683
684 #[cfg(feature = "metrics")]
685 counter!("dht_nodes_discovered_total", "ip_version" => "v6").increment(1);
686
687 self.node_queue.push(NodeTuple { id, addr });
688 }
689 }
690 }
691
692 async fn send_response(
693 &self,
694 tid: &[u8],
695 addr: SocketAddr,
696 query_type: &str,
697 sender_id: Option<&[u8]>,
698 target_id_fallback: Option<&[u8]>,
699 ) -> Result<()> {
700 let mut r_dict = std::collections::HashMap::new();
701
702 let reference_id = sender_id.or(target_id_fallback);
703 let my_id = if let Some(target) = reference_id {
704 generate_neighbor_target(target, &self.node_id)
705 } else {
706 self.node_id.clone()
707 };
708
709 r_dict.insert(b"id".to_vec(), serde_bencode::value::Value::Bytes(my_id));
710 let token = self.generate_token(addr);
711 r_dict.insert(b"token".to_vec(), serde_bencode::value::Value::Bytes(token));
712
713 if query_type == "get_peers" || query_type == "find_node" {
714 let requestor_is_ipv6 = addr.is_ipv6();
715 let filter_ipv6 = match self.options.netmode {
716 NetMode::Ipv4Only => Some(false),
717 NetMode::Ipv6Only => Some(true),
718 NetMode::DualStack => Some(requestor_is_ipv6),
719 };
720
721 let nodes = self.node_queue.get_random_nodes(8, filter_ipv6);
722
723 let mut nodes_data = Vec::new();
724 let mut nodes6_data = Vec::new();
725
726 for node in nodes {
727 match node.addr.ip() {
728 IpAddr::V4(ip) => {
729 nodes_data.extend_from_slice(&node.id);
730 nodes_data.extend_from_slice(&ip.octets());
731 nodes_data.extend_from_slice(&node.addr.port().to_be_bytes());
732 }
733 IpAddr::V6(ip) => {
734 nodes6_data.extend_from_slice(&node.id);
735 nodes6_data.extend_from_slice(&ip.octets());
736 nodes6_data.extend_from_slice(&node.addr.port().to_be_bytes());
737 }
738 }
739 }
740
741 if requestor_is_ipv6 {
742 if !nodes6_data.is_empty() {
743 r_dict.insert(
744 b"nodes6".to_vec(),
745 serde_bencode::value::Value::Bytes(nodes6_data),
746 );
747 }
748 } else if !nodes_data.is_empty() {
749 r_dict.insert(
750 b"nodes".to_vec(),
751 serde_bencode::value::Value::Bytes(nodes_data),
752 );
753 }
754 }
755
756 let mut response: std::collections::HashMap<String, serde_bencode::value::Value> =
757 std::collections::HashMap::new();
758 response.insert(
759 "t".to_string(),
760 serde_bencode::value::Value::Bytes(tid.to_vec()),
761 );
762 response.insert(
763 "y".to_string(),
764 serde_bencode::value::Value::Bytes(b"r".to_vec()),
765 );
766 response.insert("r".to_string(), serde_bencode::value::Value::Dict(r_dict));
767
768 if let Ok(encoded) = serde_bencode::to_bytes(&response) {
769 #[cfg(feature = "metrics")]
770 {
771 counter!("dht_udp_bytes_sent_total").increment(encoded.len() as u64);
772 counter!("dht_udp_packets_sent_total", "type" => "response").increment(1);
773 }
774 let _ = self.select_socket(&addr).send_to(&encoded, addr).await;
775 }
776 Ok(())
777 }
778
779 async fn bootstrap(&self) {
780 let target = generate_random_id();
781 for node in BOOTSTRAP_NODES {
782 if let Ok(addrs) = tokio::net::lookup_host(node).await {
783 for addr in addrs {
784 match self.options.netmode {
785 NetMode::Ipv4Only => {
786 if addr.is_ipv6() {
787 continue;
788 }
789 }
790 NetMode::Ipv6Only => {
791 if addr.is_ipv4() {
792 continue;
793 }
794 }
795 NetMode::DualStack => {}
796 }
797 let _ = self.send_find_node(addr, &target, &self.node_id).await;
798 }
799 }
800 }
801 }
802
803 async fn send_find_node(
804 &self,
805 addr: SocketAddr,
806 target: &[u8],
807 sender_id: &[u8],
808 ) -> Result<()> {
809 send_find_node_impl(
810 addr,
811 target,
812 sender_id,
813 &self.socket,
814 self.socket_v6.as_ref(),
815 self.options.netmode,
816 )
817 .await
818 }
819
820 fn generate_token(&self, addr: SocketAddr) -> Vec<u8> {
821 let mut hasher = AHasher::default();
822
823 match addr.ip() {
824 IpAddr::V4(ip) => ip.octets().hash(&mut hasher),
825 IpAddr::V6(ip) => ip.octets().hash(&mut hasher),
826 }
827
828 self.token_secret.hash(&mut hasher);
829
830 let hash = hasher.finish();
831 hash.to_le_bytes().to_vec()
832 }
833
834 fn validate_token(&self, token: &[u8], addr: SocketAddr) -> bool {
835 if token.len() != 8 {
836 return false;
837 }
838 let expected = self.generate_token(addr);
839 token == expected.as_slice()
840 }
841}
842
843async fn send_find_node_impl(
882 addr: SocketAddr,
883 target: &[u8],
884 sender_id: &[u8],
885 socket: &Arc<UdpSocket>,
886 socket_v6: Option<&Arc<UdpSocket>>,
887 netmode: NetMode,
888) -> Result<()> {
889 let mut args = std::collections::HashMap::new();
891 args.insert(
892 b"id".to_vec(),
893 serde_bencode::value::Value::Bytes(sender_id.to_vec()),
894 );
895 args.insert(
896 b"target".to_vec(),
897 serde_bencode::value::Value::Bytes(target.to_vec()),
898 );
899
900 let mut msg: std::collections::HashMap<String, serde_bencode::value::Value> =
902 std::collections::HashMap::new();
903 msg.insert(
904 "t".to_string(),
905 serde_bencode::value::Value::Bytes(vec![0, 1]),
906 ); msg.insert(
908 "y".to_string(),
909 serde_bencode::value::Value::Bytes(b"q".to_vec()),
910 ); msg.insert(
912 "q".to_string(),
913 serde_bencode::value::Value::Bytes(b"find_node".to_vec()),
914 ); msg.insert("a".to_string(), serde_bencode::value::Value::Dict(args)); if let Ok(encoded) = serde_bencode::to_bytes(&msg) {
919 let selected_socket = match netmode {
921 NetMode::Ipv4Only => socket,
922 NetMode::Ipv6Only => socket,
923 NetMode::DualStack => {
924 if addr.is_ipv6() {
925 socket_v6.unwrap_or(socket)
926 } else {
927 socket
928 }
929 }
930 };
931 #[cfg(feature = "metrics")]
933 {
934 counter!("dht_udp_bytes_sent_total").increment(encoded.len() as u64);
935 counter!("dht_udp_packets_sent_total", "type" => "query").increment(1);
936 }
937 let _ = selected_socket.send_to(&encoded, addr).await;
938 }
939 Ok(())
940}
941
942fn generate_random_id() -> Vec<u8> {
943 let mut rng = rand::thread_rng();
944 (0..20).map(|_| rng.r#gen::<u8>()).collect()
945}
946
947fn generate_neighbor_target(remote_id: &[u8], local_id: &[u8]) -> Vec<u8> {
987 let mut id = Vec::with_capacity(20);
988 let prefix_len = std::cmp::min(remote_id.len(), 6);
989 id.extend_from_slice(&remote_id[..prefix_len]);
990
991 if local_id.len() > prefix_len {
992 id.extend_from_slice(&local_id[prefix_len..]);
993 } else {
994 while id.len() < 20 {
995 id.push(rand::random());
996 }
997 }
998 id
999}