1use crate::error::Result;
2use crate::metadata::RbitFetcher;
3use crate::protocol::{DhtArgs, DhtMessage, DhtResponse};
4use crate::scheduler::MetadataScheduler;
5use crate::sharded::{NodeTuple, ShardedNodeQueue};
6use crate::types::{DHTOptions, NetMode, TorrentInfo};
7#[cfg(feature = "metrics")]
8use metrics::{counter, gauge};
9use rand::Rng;
10use socket2::{Domain, Protocol, Socket, Type};
11use std::collections::HashMap;
12use std::future::Future;
13use std::hash::{Hash, Hasher};
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
15use std::pin::Pin;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::{Arc, RwLock};
18use std::time::Duration;
19use tokio::net::UdpSocket;
20use tokio::sync::{Semaphore, mpsc};
21use tokio_util::sync::CancellationToken;
22
/// Well-known public DHT routers used to seed the node queue.
///
/// Each entry is a `host:port` string; `DHTServer::bootstrap` resolves them
/// through DNS every time it runs and sends a `find_node` query to every
/// address of an allowed family.
const BOOTSTRAP_NODES: &[&str] = &[
    "router.bittorrent.com:6881", // BitTorrent Inc.
    "dht.transmissionbt.com:6881", // Transmission
    "router.utorrent.com:6881", // uTorrent
    "dht.aelitis.com:6881", // Vuze/Azureus
];
29
30pub type BoxedBoolFuture = Pin<Box<dyn Future<Output = bool> + Send>>;
31pub type MetadataFetchCallback = Arc<dyn Fn(String) -> BoxedBoolFuture + Send + Sync>;
32type WorkerHandle = mpsc::Sender<(Box<[u8]>, SocketAddr, SocketAddr)>;
33
/// An info-hash announced to this node via `announce_peer`.
///
/// Created after token validation and filtering, then pushed (best-effort,
/// non-blocking) onto the channel consumed by the metadata scheduler.
#[derive(Clone, Debug)]
pub struct HashDiscovered {
    /// Hex encoding of the 20-byte info-hash.
    pub info_hash: String,
    /// Announcing peer: its source IP together with the announced port, or its
    /// source port when `implied_port` was set.
    pub peer_addr: SocketAddr,
    /// Moment the announce was processed.
    pub discovered_at: std::time::Instant,
}
40
41type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
42type FilterCallback = Arc<dyn Fn(&str) -> bool + Send + Sync>;
43type ErrorCallback = Arc<dyn Fn(crate::error::DHTError) + Send + Sync>;
44
45#[derive(Clone)]
46pub struct DHTServer {
47 #[allow(dead_code)]
48 options: DHTOptions,
49 node_id: [u8; 20],
50 socket_providers: Arc<HashMap<SocketAddr, Arc<UdpSocket>>>,
51 token_secret: [u8; 10],
52 callback: Arc<RwLock<Option<TorrentCallback>>>,
53 filter: Arc<RwLock<Option<FilterCallback>>>,
54 on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
55 on_error_cb: Arc<RwLock<Option<ErrorCallback>>>,
56 node_queue: Arc<ShardedNodeQueue>,
57 hash_tx: mpsc::Sender<HashDiscovered>,
58 metadata_queue_len: Arc<AtomicUsize>,
59 max_metadata_queue_size: usize,
60 shutdown: CancellationToken,
61}
62
63fn create_udp_sock(domain: Domain, ty: Type, addr: SocketAddr) -> std::io::Result<UdpSocket> {
64 let sock = Socket::new(domain, ty, Some(Protocol::UDP))?;
65 #[cfg(not(windows))]
66 {
67 sock.set_reuse_port(true)?;
68 if addr.is_ipv6() {
69 sock.set_only_v6(true)?;
70 }
71 }
72 let _ = sock.set_reuse_address(true);
73 sock.set_nonblocking(true)?;
74
75 let _ = sock.set_recv_buffer_size(32 * 1024 * 1024);
76 let _ = sock.set_send_buffer_size(8 * 1024 * 1024);
77
78 sock.bind(&addr.into())?;
79 UdpSocket::from_std(sock.into())
80}
81
82impl DHTServer {
83 pub async fn new(options: DHTOptions) -> Result<Self> {
84 const ANY_V4_ADDR: SocketAddr =
85 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8080);
86 const ANY_V6_ADDR: SocketAddr =
87 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 8080);
88 let mut socket_providers = HashMap::new();
89 match options.netmode {
91 NetMode::Ipv4Only => {
92 let mut addr = ANY_V4_ADDR;
93 addr.set_port(options.port);
94 let sock = create_udp_sock(Domain::IPV4, Type::DGRAM, addr)?;
95 socket_providers.insert(addr, Arc::new(sock));
96 }
97 NetMode::Ipv6Only => {
98 let mut addr = ANY_V6_ADDR;
99 addr.set_port(options.port);
100 let sock = create_udp_sock(Domain::IPV6, Type::DGRAM, addr)?;
101 socket_providers.insert(addr, Arc::new(sock));
102 }
103 NetMode::DualStack => {
104 let mut addr = ANY_V4_ADDR;
105 addr.set_port(options.port);
106 let sock = create_udp_sock(Domain::IPV4, Type::DGRAM, addr)?;
107 socket_providers.insert(addr, Arc::new(sock));
108 let mut addr = ANY_V6_ADDR;
109 addr.set_port(options.port);
110 let sock = create_udp_sock(Domain::IPV6, Type::DGRAM, addr)?;
111 socket_providers.insert(addr, Arc::new(sock));
112 }
113 };
114
115 let node_id = generate_random_id();
116 let mut token_secret = [0u8; 10];
117 rand::thread_rng().fill(&mut token_secret);
118
119 let node_queue = ShardedNodeQueue::new(options.node_queue_capacity);
120
121 let (hash_tx, hash_rx) = mpsc::channel::<HashDiscovered>(options.hash_queue_capacity);
122
123 let fetcher = Arc::new(RbitFetcher::new(options.metadata_timeout));
124
125 let callback = Arc::new(RwLock::new(None));
126 let on_metadata_fetch = Arc::new(RwLock::new(None));
127
128 let metadata_queue_len = Arc::new(AtomicUsize::new(0));
129
130 let shutdown = CancellationToken::new();
131 let shutdown_for_scheduler = shutdown.clone();
132
133 let scheduler = MetadataScheduler::new(
134 hash_rx,
135 fetcher,
136 options.max_metadata_queue_size,
137 options.max_metadata_worker_count,
138 callback.clone(),
139 on_metadata_fetch.clone(),
140 metadata_queue_len.clone(),
141 shutdown_for_scheduler,
142 );
143
144 tokio::spawn(async move {
145 scheduler.run().await;
146 });
147
148 let max_metadata_queue_size = options.max_metadata_queue_size;
149 let server = Self {
150 options,
151 node_id,
152 socket_providers: Arc::new(socket_providers),
153 token_secret,
154 callback,
155 on_metadata_fetch,
156 node_queue: Arc::new(node_queue),
157 filter: Arc::new(RwLock::new(None)),
158 on_error_cb: Arc::new(RwLock::new(None)),
159 hash_tx,
160 metadata_queue_len,
161 max_metadata_queue_size,
162 shutdown,
163 };
164
165 Ok(server)
166 }
167
168 pub fn on_metadata_fetch<F, Fut>(&self, callback: F)
169 where
170 F: Fn(String) -> Fut + Send + Sync + 'static,
171 Fut: Future<Output = bool> + Send + 'static,
172 {
173 *self.on_metadata_fetch.write().unwrap_or_else(|e| e.into_inner()) =
174 Some(Arc::new(move |hash| Box::pin(callback(hash))));
175 }
176
177 pub fn on_torrent<F>(&self, callback: F)
178 where
179 F: Fn(TorrentInfo) + Send + Sync + 'static,
180 {
181 *self.callback.write().unwrap_or_else(|e| e.into_inner()) = Some(Arc::new(callback));
182 }
183
184 pub fn set_filter<F>(&self, filter: F)
185 where
186 F: Fn(&str) -> bool + Send + Sync + 'static,
187 {
188 *self.filter.write().unwrap_or_else(|e| e.into_inner()) = Some(Arc::new(filter));
189 }
190
191 pub fn on_error<F>(&self, callback: F)
192 where
193 F: Fn(crate::error::DHTError) + Send + Sync + 'static,
194 {
195 *self.on_error_cb.write().unwrap_or_else(|e| e.into_inner()) = Some(Arc::new(callback));
196 }
197
198 fn emit_error(&self, error: crate::error::DHTError) {
199 if let Ok(cb) = self.on_error_cb.read() {
200 if let Some(f) = cb.as_ref() {
201 f(error);
202 }
203 }
204 }
205
206 pub fn get_node_pool_size(&self) -> usize {
207 self.node_queue.len()
208 }
209
210 pub async fn start(&self) -> Result<()> {
211 if self.shutdown.is_cancelled() {
213 log::warn!("⚠️ 尝试启动已关闭的服务器");
214 return Err(crate::error::DHTError::Other("服务器已关闭".to_string()));
215 }
216
217 let workers = self.spawn_workers();
218 for sock in self.socket_providers.values().cloned() {
219 spawn_udp_listener(sock, workers.clone(), self.shutdown.clone())?;
220 }
221 self.bootstrap().await;
222
223 let server = self.clone();
224 let shutdown = self.shutdown.clone();
225
226 tokio::spawn(async move {
227 let semaphore = Arc::new(Semaphore::new(2000));
228 let mut loop_tick = 0;
229
230 loop {
231 if shutdown.is_cancelled() {
233 #[cfg(debug_assertions)]
234 log::trace!("主循环收到关闭信号,退出");
235 break;
236 }
237
238 let queue_len = server.metadata_queue_len.load(Ordering::Relaxed);
239 let queue_pressure = queue_len as f64 / server.max_metadata_queue_size as f64;
240
241 #[cfg(feature = "metrics")]
242 {
243 gauge!("dht_metadata_queue_size").set(queue_len as f64);
244 gauge!("dht_metadata_worker_pressure").set(queue_pressure);
245 gauge!("dht_node_queue_size").set(server.node_queue.len() as f64);
246 }
247
248 let (batch_size, sleep_duration) = if queue_pressure < 0.8 {
249 (200, Duration::from_millis(10))
250 } else if queue_pressure < 0.95 {
251 (20, Duration::from_millis(500))
252 } else {
253 (0, Duration::from_millis(1000))
254 };
255
256 let filter_ipv6 = match server.options.netmode {
257 NetMode::Ipv4Only => Some(false),
258 NetMode::Ipv6Only => Some(true),
259 NetMode::DualStack => None,
260 };
261
262 let queue_empty = server.node_queue.is_empty_for(filter_ipv6);
263
264 let nodes_batch = {
265 if queue_empty || batch_size == 0 {
266 None
267 } else {
268 Some(server.node_queue.pop_batch(batch_size, filter_ipv6))
269 }
270 };
271
272 loop_tick += 1;
273 if nodes_batch.is_none() || loop_tick % 50 == 0 {
274 server.bootstrap().await;
275 if nodes_batch.is_none() {
276 tokio::select! {
277 _ = shutdown.cancelled() => break,
278 _ = tokio::time::sleep(sleep_duration) => {},
279 }
280 continue;
281 }
282 }
283
284 if let Some(nodes) = nodes_batch {
285 let node_id = server.node_id;
286
287 for node in nodes {
288 let permit = match semaphore.clone().acquire_owned().await {
289 Ok(p) => p,
290 Err(_) => break,
291 };
292 let node_id_clone = node_id;
293 let socket = match server.socket_for_addr(&node.addr) {
294 Some(sock) => sock,
295 None => {
296 log::warn!("未绑定任何地址");
297 break;
298 }
299 };
300 let node_addr = node.addr;
301 let node_id_for_target = node.id;
302
303 tokio::spawn(async move {
304 let neighbor_id =
305 generate_neighbor_target(&node_id_for_target, &node_id_clone);
306 let random_target = generate_random_id();
307 let _ = send_find_node_impl(
308 &node_addr,
309 &random_target,
310 &neighbor_id,
311 socket,
312 )
313 .await;
314 drop(permit);
315 });
316 }
317 }
318
319 tokio::select! {
320 _ = shutdown.cancelled() => break,
321 _ = tokio::time::sleep(sleep_duration) => {},
322 }
323 }
324 });
325 self.shutdown.cancelled().await;
326 Ok(())
327 }
328
329 pub fn shutdown(&self) {
331 self.shutdown.cancel();
332 }
333
334 fn spawn_workers(&self) -> Vec<WorkerHandle> {
335 let server = self.clone();
336 let shutdown = self.shutdown.clone();
337
338 let num_workers = std::thread::available_parallelism()
339 .map(|n| n.get())
340 .unwrap_or(8);
341
342 let queue_size = 5000;
343
344 let mut workers: Vec<WorkerHandle> = Vec::with_capacity(num_workers);
345 for _ in 0..num_workers {
346 let (tx, mut rx) = mpsc::channel(queue_size);
347 workers.push(tx);
348
349 let server_clone = server.clone();
350 let cancellation_token = shutdown.clone();
351
352 tokio::spawn(async move {
353 loop {
354 tokio::select! {
355 _ = cancellation_token.cancelled() => {
356 #[cfg(debug_assertions)]
357 log::trace!("Worker 收到关闭信号,退出");
358 break;
359 }
360 msg = rx.recv() => {
361 match msg {
362 Some((data, remote_addr, local_addr)) => {
363 if let Err(e) = server_clone.handle_message(data.as_ref(), remote_addr, local_addr).await {
364 server_clone.emit_error(e);
365 }
366 }
367 None => break,
368 }
369 }
370 }
371 }
372 });
373 }
374 workers
375 }
376
377 async fn handle_message(
378 &self,
379 data: &[u8],
380 remote_addr: SocketAddr,
381 local_addr: SocketAddr,
382 ) -> Result<()> {
383 if self.socket_providers.get(&local_addr).is_none() {
384 #[cfg(debug_assertions)]
385 log::trace!(
386 "⚠️ 拒绝未绑定的地址: {} (当前模式: {:?})",
387 remote_addr,
388 self.options.netmode
389 );
390 return Ok(());
391 }
392
393 let msg: DhtMessage = match serde_bencode::from_bytes(data) {
394 Ok(m) => m,
395 Err(_) => {
396 #[cfg(feature = "metrics")]
397 counter!("dht_messages_parse_error_total").increment(1);
398 return Ok(());
399 }
400 };
401
402 #[cfg(feature = "metrics")]
403 {
404 let label = match msg.y.as_str() {
406 "q" => "q",
407 "r" => "r",
408 "e" => "e",
409 _ => "unknown", };
411 counter!("dht_messages_processed_total", "type" => label).increment(1);
412 }
413
414 match msg.y.as_str() {
415 "q" => {
416 if let Some(q_type) = &msg.q {
417 self.handle_query(&msg, q_type.as_bytes(), remote_addr, local_addr)
418 .await?;
419 }
420 }
421 "r" => {
422 if let Some(response) = &msg.r {
423 self.handle_response(response).await?;
424 }
425 }
426 _ => {}
427 }
428 Ok(())
429 }
430
431 async fn handle_query(
432 &self,
433 msg: &DhtMessage,
434 query_type: &[u8],
435 remote_addr: SocketAddr,
436 local_addr: SocketAddr,
437 ) -> Result<()> {
438 let args = match &msg.a {
439 Some(a) => a,
440 None => return Ok(()),
441 };
442
443 let transaction_id = &msg.t;
444 let sender_id: Option<&[u8]> = args.id.as_deref().map(|v| v.as_slice());
445 let target_id_fallback: Option<&[u8]> = args
446 .target
447 .as_deref()
448 .or(args.info_hash.as_deref())
449 .map(|v| v.as_slice());
450
451 let q_str = std::str::from_utf8(query_type).unwrap_or("");
452
453 #[cfg(feature = "metrics")]
454 {
455 let label = match q_str {
456 "ping" => "ping",
457 "find_node" => "find_node",
458 "get_peers" => "get_peers",
459 "announce_peer" => "announce_peer",
460 "vote" => "vote",
461 _ => "other_or_invalid",
462 };
463 counter!("dht_queries_total", "q" => label).increment(1);
464 }
465
466 if q_str == "announce_peer" {
467 self.handle_announce_peer(args, remote_addr).await?;
468 }
469
470 self.send_response(
471 transaction_id,
472 remote_addr,
473 local_addr,
474 q_str,
475 sender_id,
476 target_id_fallback,
477 )
478 .await?;
479 Ok(())
480 }
481
482 async fn handle_announce_peer(&self, args: &DhtArgs, addr: SocketAddr) -> Result<()> {
483 if let Some(token) = &args.token {
484 if !self.validate_token(token, addr) {
485 #[cfg(feature = "metrics")]
486 counter!("dht_announce_peer_blocked_total", "reason" => "invalid_token")
487 .increment(1);
488 return Ok(());
489 }
490 } else {
491 return Ok(());
492 }
493
494 if let Some(info_hash) = &args.info_hash {
495 let info_hash_arr: [u8; 20] = match info_hash.as_ref().try_into() {
496 Ok(arr) => arr,
497 Err(_) => return Ok(()),
498 };
499 let hash_hex = hex::encode(info_hash_arr);
500
501 let filter_cb = self.filter.read().unwrap_or_else(|e| e.into_inner()).clone();
502 if let Some(f) = filter_cb
503 && !f(&hash_hex)
504 {
505 #[cfg(feature = "metrics")]
506 counter!("dht_announce_peer_blocked_total", "reason" => "filtered").increment(1);
507 return Ok(());
508 }
509
510 #[cfg(feature = "metrics")]
511 counter!("dht_info_hashes_discovered_total").increment(1);
512
513 #[cfg(debug_assertions)]
514 log::debug!("🔥 新 Hash: {} 来自 {}", hash_hex, addr);
515
516 let port = if let Some(implied) = args.implied_port {
517 if implied != 0 {
518 addr.port()
519 } else {
520 args.port.unwrap_or(0)
521 }
522 } else {
523 args.port.unwrap_or(addr.port())
524 };
525
526 if port > 0 {
527 let event = HashDiscovered {
528 info_hash: hash_hex,
529 peer_addr: SocketAddr::new(addr.ip(), port),
530 discovered_at: std::time::Instant::now(),
531 };
532
533 if self.hash_tx.try_send(event).is_err() {
534 #[cfg(debug_assertions)]
535 log::debug!("⚠️ Hash 队列满,丢弃 hash");
536 }
537 }
538 }
539 Ok(())
540 }
541
542 async fn handle_response(&self, response: &DhtResponse) -> Result<()> {
543 if let Some(nodes_bytes) = &response.nodes {
544 self.process_compact_nodes(nodes_bytes);
545 }
546 if let Some(nodes6_bytes) = &response.nodes6 {
547 self.process_compact_nodes_v6(nodes6_bytes);
548 }
549 Ok(())
550 }
551
552 fn process_compact_nodes(&self, nodes_bytes: &[u8]) {
553 if self.options.netmode == NetMode::Ipv6Only {
554 return;
555 }
556
557 #[allow(clippy::manual_is_multiple_of)]
558 if nodes_bytes.len() % 26 != 0 {
559 return;
560 }
561
562 for chunk in nodes_bytes.chunks(26) {
563 let id = chunk[0..20].to_vec();
564 let port = u16::from_be_bytes([chunk[24], chunk[25]]);
565
566 let ip = std::net::Ipv4Addr::new(chunk[20], chunk[21], chunk[22], chunk[23]);
567 let addr = SocketAddr::new(std::net::IpAddr::V4(ip), port);
568
569 #[cfg(feature = "metrics")]
570 counter!("dht_nodes_discovered_total", "ip_version" => "v4").increment(1);
571
572 self.node_queue.push(NodeTuple { id, addr });
573 }
574 }
575
576 fn process_compact_nodes_v6(&self, nodes_bytes: &[u8]) {
577 if self.options.netmode == NetMode::Ipv4Only {
578 return;
579 }
580
581 #[allow(clippy::manual_is_multiple_of)]
582 if nodes_bytes.len() % 38 != 0 {
583 return;
584 }
585 for chunk in nodes_bytes.chunks(38) {
586 let id = chunk[0..20].to_vec();
587 let port = u16::from_be_bytes([chunk[36], chunk[37]]);
588 let ip_bytes: [u8; 16] = match chunk[20..36].try_into() {
589 Ok(b) => b,
590 Err(_) => continue,
591 };
592 let ip = Ipv6Addr::from(ip_bytes);
593 if !ip.is_unspecified() && !ip.is_multicast() {
594 let addr = SocketAddr::new(IpAddr::V6(ip), port);
595
596 #[cfg(feature = "metrics")]
597 counter!("dht_nodes_discovered_total", "ip_version" => "v6").increment(1);
598
599 self.node_queue.push(NodeTuple { id, addr });
600 }
601 }
602 }
603
604 async fn send_response(
605 &self,
606 tid: &[u8],
607 remote_addr: SocketAddr,
608 local_addr: SocketAddr,
609 query_type: &str,
610 sender_id: Option<&[u8]>,
611 target_id_fallback: Option<&[u8]>,
612 ) -> Result<()> {
613 let socket = match self.socket_providers.get(&local_addr) {
614 Some(sock) => sock,
615 None => return Ok(()), };
617
618 let mut r_dict = std::collections::HashMap::new();
619
620 let reference_id = sender_id.or(target_id_fallback);
621 let my_id = if let Some(target) = reference_id {
622 generate_neighbor_target(target, &self.node_id)
623 } else {
624 self.node_id.to_vec()
625 };
626
627 r_dict.insert(b"id".to_vec(), serde_bencode::value::Value::Bytes(my_id));
628 let token = self.generate_token(remote_addr);
629 r_dict.insert(
630 b"token".to_vec(),
631 serde_bencode::value::Value::Bytes(token.to_vec()),
632 );
633
634 if query_type == "get_peers" || query_type == "find_node" {
635 let requestor_is_ipv6 = remote_addr.is_ipv6();
636 let filter_ipv6 = match self.options.netmode {
637 NetMode::Ipv4Only => Some(false),
638 NetMode::Ipv6Only => Some(true),
639 NetMode::DualStack => Some(requestor_is_ipv6),
640 };
641
642 let nodes = self.node_queue.get_random_nodes(8, filter_ipv6);
643
644 let mut nodes_data = Vec::new();
645 let mut nodes6_data = Vec::new();
646
647 for node in nodes {
648 match node.addr.ip() {
649 IpAddr::V4(ip) => {
650 nodes_data.extend_from_slice(&node.id);
651 nodes_data.extend_from_slice(&ip.octets());
652 nodes_data.extend_from_slice(&node.addr.port().to_be_bytes());
653 }
654 IpAddr::V6(ip) => {
655 nodes6_data.extend_from_slice(&node.id);
656 nodes6_data.extend_from_slice(&ip.octets());
657 nodes6_data.extend_from_slice(&node.addr.port().to_be_bytes());
658 }
659 }
660 }
661
662 if requestor_is_ipv6 {
663 if !nodes6_data.is_empty() {
664 r_dict.insert(
665 b"nodes6".to_vec(),
666 serde_bencode::value::Value::Bytes(nodes6_data),
667 );
668 }
669 } else if !nodes_data.is_empty() {
670 r_dict.insert(
671 b"nodes".to_vec(),
672 serde_bencode::value::Value::Bytes(nodes_data),
673 );
674 }
675 }
676
677 let mut response: std::collections::HashMap<String, serde_bencode::value::Value> =
678 std::collections::HashMap::new();
679 response.insert(
680 "t".to_string(),
681 serde_bencode::value::Value::Bytes(tid.to_vec()),
682 );
683 response.insert(
684 "y".to_string(),
685 serde_bencode::value::Value::Bytes(b"r".to_vec()),
686 );
687 response.insert("r".to_string(), serde_bencode::value::Value::Dict(r_dict));
688
689 if let Ok(encoded) = serde_bencode::to_bytes(&response) {
690 #[allow(unused)]
691 if let Ok(len) = socket.send_to(&encoded, remote_addr).await {
692 #[cfg(feature = "metrics")]
693 {
694 counter!("dht_udp_bytes_sent_total").increment(len as u64);
695 counter!("dht_udp_packets_sent_total", "type" => "response").increment(1);
696 }
697 }
698 }
699 Ok(())
700 }
701
702 async fn bootstrap(&self) {
703 let target = generate_random_id();
704 for node in BOOTSTRAP_NODES {
705 if let Ok(addrs) = tokio::net::lookup_host(node).await {
706 for addr in addrs {
707 match self.options.netmode {
708 NetMode::Ipv4Only => {
709 if addr.is_ipv6() {
710 continue;
711 }
712 }
713 NetMode::Ipv6Only => {
714 if addr.is_ipv4() {
715 continue;
716 }
717 }
718 NetMode::DualStack => {}
719 }
720 let _ = self.send_find_node(&addr, &target, &self.node_id).await;
721 }
722 }
723 }
724 }
725
726 fn socket_for_addr(&self, addr: &SocketAddr) -> Option<Arc<UdpSocket>> {
727 self.socket_providers
728 .iter()
729 .find(|(bind_addr, _)| bind_addr.is_ipv4() == addr.is_ipv4())
730 .map(|(_, sock)| sock.clone())
731 }
732
733 async fn send_find_node(&self, target_addr: &SocketAddr, target: &[u8], sender_id: &[u8]) {
734 if let Some(sock) = self.socket_for_addr(target_addr) {
735 send_find_node_impl(target_addr, target, sender_id, sock).await
736 }
737 }
738
739 fn generate_token(&self, addr: SocketAddr) -> [u8; 8] {
740 let mut hasher = ahash::AHasher::default();
741
742 match addr.ip() {
743 IpAddr::V4(ip) => ip.octets().hash(&mut hasher),
744 IpAddr::V6(ip) => ip.octets().hash(&mut hasher),
745 }
746
747 self.token_secret.hash(&mut hasher);
748
749 hasher.finish().to_le_bytes()
750 }
751
752 fn validate_token(&self, token: &[u8], addr: SocketAddr) -> bool {
753 if token.len() != 8 {
754 return false;
755 }
756 let expected = self.generate_token(addr);
757 token == expected
758 }
759}
760
761async fn send_find_node_impl(
800 addr: &SocketAddr,
801 target: &[u8],
802 sender_id: &[u8],
803 socket: Arc<UdpSocket>,
804) {
805 let mut args = std::collections::HashMap::new();
807 args.insert(
808 b"id".to_vec(),
809 serde_bencode::value::Value::Bytes(sender_id.to_vec()),
810 );
811 args.insert(
812 b"target".to_vec(),
813 serde_bencode::value::Value::Bytes(target.to_vec()),
814 );
815
816 let mut msg: std::collections::HashMap<String, serde_bencode::value::Value> =
818 std::collections::HashMap::new();
819 msg.insert(
820 "t".to_string(),
821 serde_bencode::value::Value::Bytes(vec![0, 1]),
822 ); msg.insert(
824 "y".to_string(),
825 serde_bencode::value::Value::Bytes(b"q".to_vec()),
826 ); msg.insert(
828 "q".to_string(),
829 serde_bencode::value::Value::Bytes(b"find_node".to_vec()),
830 ); msg.insert("a".to_string(), serde_bencode::value::Value::Dict(args)); if let Ok(encoded) = serde_bencode::to_bytes(&msg) {
835 #[cfg(feature = "metrics")]
837 {
838 counter!("dht_udp_bytes_sent_total").increment(encoded.len() as u64);
839 counter!("dht_udp_packets_sent_total", "type" => "query").increment(1);
840 }
841 let _ = socket.send_to(&encoded, addr).await;
842 }
843}
844
845fn generate_random_id() -> [u8; 20] {
846 let mut id = [0u8; 20];
847 rand::thread_rng().fill(&mut id);
848 id
849}
850
851fn generate_neighbor_target(remote_id: &[u8], local_id: &[u8]) -> Vec<u8> {
891 let mut id = Vec::with_capacity(20);
892 let prefix_len = std::cmp::min(remote_id.len(), 6);
893 id.extend_from_slice(&remote_id[..prefix_len]);
894
895 if local_id.len() > prefix_len {
896 id.extend_from_slice(&local_id[prefix_len..]);
897 } else {
898 while id.len() < 20 {
899 id.push(rand::random());
900 }
901 }
902 id
903}
904
905fn spawn_udp_listener(
906 socket: Arc<UdpSocket>,
907 mut workers: Vec<WorkerHandle>,
908 shutdown: CancellationToken,
909) -> crate::error::Result<()> {
910 let local_addr = socket
911 .local_addr()
912 .map_err(|e| crate::error::DHTError::Init(format!("socket 无法获取本地地址: {e}")))?;
913 if workers.is_empty() {
914 return Err(crate::error::DHTError::Init(
915 "spawn_udp_listener: 未提供任何 worker".to_string(),
916 ));
917 }
918 tokio::spawn(async move {
919 let mut buffer = [0u8; 65536];
920 let mut worker_index = 0;
921
922 loop {
923 tokio::select! {
924 _ = shutdown.cancelled() => {
925 #[cfg(debug_assertions)]
926 log::trace!("UDP 读取循环收到关闭信号,退出");
927 break;
928 }
929 result = socket.recv_from(&mut buffer) => {
930 match result {
931 Ok((size, origin_addr)) => {
932 if let Err(ProcessUdpPacketError::NoLiveWorkers) = process_udp_packet(size, origin_addr, local_addr, &buffer, &mut workers, &mut worker_index){
933 log::warn!("Socket {socket:?} is closing because no worker can process packets.");
934 break
936 }
937 },
938 Err(_e) => {
939 tokio::select! {
940 _ = shutdown.cancelled() => break,
941 _ = tokio::time::sleep(Duration::from_millis(1)) => {},
942 }
943 }
944 }
945 }
946 }
947 }
948 });
949 Ok(())
950}
951
/// Why `process_udp_packet` did not hand a datagram to a worker.
///
/// Derives `Debug`/`PartialEq` so outcomes can be logged with `{:?}` and
/// compared directly instead of only pattern-matched, and `Copy` because the
/// variants carry no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProcessUdpPacketError {
    /// Datagram exceeded the accepted size limit.
    PacketTooLarge,
    /// Empty datagram, or first byte is not `b'd'` (not a bencoded dictionary).
    InvalidPacket,
    /// Every worker queue refused the packet as full; the packet was dropped.
    ChokedWorkers,
    /// Every worker channel is closed; the listener should stop.
    NoLiveWorkers,
}
958
959fn process_udp_packet(
960 size: usize,
961 origin_addr: SocketAddr,
962 local_addr: SocketAddr,
963 buffer: &[u8],
964 workers: &mut Vec<WorkerHandle>,
965 worker_index: &mut usize,
966) -> std::result::Result<(), ProcessUdpPacketError> {
967 #[cfg(feature = "metrics")]
968 counter!("dht_udp_bytes_received_total").increment(size as u64);
969
970 if size > 8192 {
971 #[cfg(feature = "metrics")]
972 counter!("dht_udp_packets_received_total", "status" => "dropped_size").increment(1);
973
974 #[cfg(debug_assertions)]
975 log::trace!("⚠️ 拒绝异常大的 UDP 包: {} 字节 from {}", size, origin_addr);
976 return Err(ProcessUdpPacketError::PacketTooLarge);
977 }
978
979 if size == 0 || buffer[0] != b'd' {
980 #[cfg(feature = "metrics")]
981 counter!("dht_udp_packets_received_total", "status" => "dropped_magic").increment(1);
982 return Err(ProcessUdpPacketError::InvalidPacket);
983 }
984 let mut data = Some(buffer[..size].to_owned().into_boxed_slice());
985 let mut attempts = 0;
986 let max_attempts = workers.len();
987
988 while let Some(packet) = data.take() {
989 let worker = &workers[*worker_index];
990 match worker.try_send((packet, origin_addr, local_addr)) {
991 Ok(_) => {
992 #[cfg(feature = "metrics")]
993 counter!("dht_udp_packets_received_total", "status" => "ok").increment(1);
994 break;
995 }
996 Err(mpsc::error::TrySendError::Full((packet, _, _))) => {
997 attempts += 1;
998 if attempts >= max_attempts {
999 #[cfg(feature = "metrics")]
1000 counter!("dht_udp_packets_received_total", "status" => "queue_full")
1001 .increment(1);
1002
1003 #[cfg(debug_assertions)]
1004 log::trace!("UDP worker queue full, dropping packet");
1005 return Err(ProcessUdpPacketError::ChokedWorkers);
1006 }
1007 let _ = data.insert(packet);
1008 }
1009 Err(mpsc::error::TrySendError::Closed((packet, _, _))) => {
1010 log::warn!("UDP worker dropped.");
1011 workers.swap_remove(*worker_index);
1012 let _ = data.insert(packet);
1013 }
1014 }
1015 if workers.is_empty() {
1016 return Err(ProcessUdpPacketError::NoLiveWorkers);
1017 }
1018 *worker_index = (*worker_index + 1) % workers.len();
1019 }
1020
1021 Ok(())
1022}