Skip to main content

dht_crawler/
server.rs

1use crate::error::Result;
2use crate::metadata::RbitFetcher;
3use crate::protocol::{DhtArgs, DhtMessage, DhtResponse};
4use crate::scheduler::MetadataScheduler;
5use crate::sharded::{NodeTuple, ShardedNodeQueue};
6use crate::types::{DHTOptions, NetMode, TorrentInfo};
7#[cfg(feature = "metrics")]
8use metrics::{counter, gauge};
9use rand::Rng;
10use socket2::{Domain, Protocol, Socket, Type};
11use std::collections::HashMap;
12use std::future::Future;
13use std::hash::{Hash, Hasher};
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
15use std::pin::Pin;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::{Arc, RwLock};
18use std::time::Duration;
19use tokio::net::UdpSocket;
20use tokio::sync::{Semaphore, mpsc};
21use tokio_util::sync::CancellationToken;
22
/// Well-known public DHT routers used to seed the node queue at startup
/// and whenever the crawler's queue runs dry (see `bootstrap`).
const BOOTSTRAP_NODES: &[&str] = &[
    "router.bittorrent.com:6881",
    "dht.transmissionbt.com:6881",
    "router.utorrent.com:6881",
    "dht.aelitis.com:6881",
];
29
/// Boxed future returned by the metadata-fetch gate callback.
pub type BoxedBoolFuture = Pin<Box<dyn Future<Output = bool> + Send>>;
/// User callback: given an info-hash (hex), decide whether metadata should be fetched.
pub type MetadataFetchCallback = Arc<dyn Fn(String) -> BoxedBoolFuture + Send + Sync>;
/// Sender half of a worker channel: (packet bytes, remote addr, local bind addr).
type WorkerHandle = mpsc::Sender<(Box<[u8]>, SocketAddr, SocketAddr)>;

/// An info-hash observed via an `announce_peer` query, queued for metadata fetching.
#[derive(Debug, Clone)]
pub struct HashDiscovered {
    /// Hex-encoded 20-byte info-hash.
    pub info_hash: String,
    /// Announce source IP combined with the announced port.
    pub peer_addr: SocketAddr,
    /// When the announce was observed.
    pub discovered_at: std::time::Instant,
}

/// Invoked with each successfully fetched torrent.
type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
/// Returns `false` to drop an info-hash (hex) before it is queued.
type FilterCallback = Arc<dyn Fn(&str) -> bool + Send + Sync>;
/// Invoked with errors raised while processing inbound packets.
type ErrorCallback = Arc<dyn Fn(crate::error::DHTError) + Send + Sync>;
44
/// A BitTorrent DHT crawler node: it answers queries with IDs derived
/// from the asker (see `generate_neighbor_target`) and harvests
/// info-hashes from `announce_peer` traffic.
///
/// Cloning is cheap — all shared state sits behind `Arc` — and clones
/// share the same sockets, queues, callbacks, and shutdown token.
#[derive(Clone)]
pub struct DHTServer {
    #[allow(dead_code)]
    options: DHTOptions,
    /// Our base 20-byte node ID, randomly generated in `new`.
    node_id: [u8; 20],
    /// Bound UDP sockets keyed by local bind address (one per IP family).
    socket_providers: Arc<HashMap<SocketAddr, Arc<UdpSocket>>>,
    /// Secret mixed into announce tokens; fixed for the server's lifetime.
    token_secret: [u8; 10],
    /// Optional user callback fired for each fetched torrent.
    callback: Arc<RwLock<Option<TorrentCallback>>>,
    /// Optional info-hash filter applied before queuing.
    filter: Arc<RwLock<Option<FilterCallback>>>,
    /// Optional gate deciding whether to fetch metadata for a hash.
    on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
    /// Optional sink for processing errors.
    on_error_cb: Arc<RwLock<Option<ErrorCallback>>>,
    /// Discovered DHT nodes awaiting crawling, sharded by address.
    node_queue: Arc<ShardedNodeQueue>,
    /// Channel feeding discovered hashes to the metadata scheduler.
    hash_tx: mpsc::Sender<HashDiscovered>,
    /// Scheduler's current queue depth, used for crawl-rate backpressure.
    metadata_queue_len: Arc<AtomicUsize>,
    max_metadata_queue_size: usize,
    /// Cancels every background task spawned by this server.
    shutdown: CancellationToken,
}
62
/// Create a non-blocking UDP socket bound to `addr`, tuned for the
/// crawler's high packet rate.
///
/// On non-Windows platforms this enables `SO_REUSEPORT` and marks IPv6
/// sockets v6-only so the v4 and v6 sockets can be bound side by side.
/// Reuse-address and buffer-size tweaks are best-effort (`let _ =`),
/// since OS limits may reject them.
///
/// # Errors
/// Returns the underlying I/O error if creation, mandatory
/// configuration, or binding fails.
fn create_udp_sock(domain: Domain, ty: Type, addr: SocketAddr) -> std::io::Result<UdpSocket> {
    let sock = Socket::new(domain, ty, Some(Protocol::UDP))?;
    #[cfg(not(windows))]
    {
        sock.set_reuse_port(true)?;
        if addr.is_ipv6() {
            sock.set_only_v6(true)?;
        }
    }
    let _ = sock.set_reuse_address(true);
    // Must be non-blocking before handing the fd to tokio's reactor.
    sock.set_nonblocking(true)?;

    // Large kernel buffers absorb traffic bursts: 32 MiB recv, 8 MiB send.
    let _ = sock.set_recv_buffer_size(32 * 1024 * 1024);
    let _ = sock.set_send_buffer_size(8 * 1024 * 1024);

    sock.bind(&addr.into())?;
    UdpSocket::from_std(sock.into())
}
81
82impl DHTServer {
    /// Build a `DHTServer`: bind UDP sockets per `options.netmode`, wire
    /// up the metadata scheduler, and spawn it as a background task.
    ///
    /// The port `8080` in the `ANY_*` constants is a placeholder — the
    /// real port is always overwritten from `options.port` before binding.
    ///
    /// # Errors
    /// Fails if any required socket cannot be created or bound.
    pub async fn new(options: DHTOptions) -> Result<Self> {
        const ANY_V4_ADDR: SocketAddr =
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8080);
        const ANY_V6_ADDR: SocketAddr =
            SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 8080);
        let mut socket_providers = HashMap::new();
        // TODO: Check address reachability
        match options.netmode {
            NetMode::Ipv4Only => {
                let mut addr = ANY_V4_ADDR;
                addr.set_port(options.port);
                let sock = create_udp_sock(Domain::IPV4, Type::DGRAM, addr)?;
                socket_providers.insert(addr, Arc::new(sock));
            }
            NetMode::Ipv6Only => {
                let mut addr = ANY_V6_ADDR;
                addr.set_port(options.port);
                let sock = create_udp_sock(Domain::IPV6, Type::DGRAM, addr)?;
                socket_providers.insert(addr, Arc::new(sock));
            }
            // Dual stack binds one socket per family on the same port.
            NetMode::DualStack => {
                let mut addr = ANY_V4_ADDR;
                addr.set_port(options.port);
                let sock = create_udp_sock(Domain::IPV4, Type::DGRAM, addr)?;
                socket_providers.insert(addr, Arc::new(sock));
                let mut addr = ANY_V6_ADDR;
                addr.set_port(options.port);
                let sock = create_udp_sock(Domain::IPV6, Type::DGRAM, addr)?;
                socket_providers.insert(addr, Arc::new(sock));
            }
        };

        let node_id = generate_random_id();
        let mut token_secret = [0u8; 10];
        rand::thread_rng().fill(&mut token_secret);

        let node_queue = ShardedNodeQueue::new(options.node_queue_capacity);

        // Bounded channel from announce handling to the metadata scheduler.
        let (hash_tx, hash_rx) = mpsc::channel::<HashDiscovered>(options.hash_queue_capacity);

        let fetcher = Arc::new(RbitFetcher::new(options.metadata_timeout));

        // These slots are shared with the scheduler, so they are created
        // before the server struct and cloned into it below.
        let callback = Arc::new(RwLock::new(None));
        let on_metadata_fetch = Arc::new(RwLock::new(None));

        let metadata_queue_len = Arc::new(AtomicUsize::new(0));

        let shutdown = CancellationToken::new();
        let shutdown_for_scheduler = shutdown.clone();

        let scheduler = MetadataScheduler::new(
            hash_rx,
            fetcher,
            options.max_metadata_queue_size,
            options.max_metadata_worker_count,
            callback.clone(),
            on_metadata_fetch.clone(),
            metadata_queue_len.clone(),
            shutdown_for_scheduler,
        );

        // The scheduler runs until the shutdown token is cancelled.
        tokio::spawn(async move {
            scheduler.run().await;
        });

        let max_metadata_queue_size = options.max_metadata_queue_size;
        let server = Self {
            options,
            node_id,
            socket_providers: Arc::new(socket_providers),
            token_secret,
            callback,
            on_metadata_fetch,
            node_queue: Arc::new(node_queue),
            filter: Arc::new(RwLock::new(None)),
            on_error_cb: Arc::new(RwLock::new(None)),
            hash_tx,
            metadata_queue_len,
            max_metadata_queue_size,
            shutdown,
        };

        Ok(server)
    }
167
168    pub fn on_metadata_fetch<F, Fut>(&self, callback: F)
169    where
170        F: Fn(String) -> Fut + Send + Sync + 'static,
171        Fut: Future<Output = bool> + Send + 'static,
172    {
173        *self.on_metadata_fetch.write().unwrap_or_else(|e| e.into_inner()) =
174            Some(Arc::new(move |hash| Box::pin(callback(hash))));
175    }
176
177    pub fn on_torrent<F>(&self, callback: F)
178    where
179        F: Fn(TorrentInfo) + Send + Sync + 'static,
180    {
181        *self.callback.write().unwrap_or_else(|e| e.into_inner()) = Some(Arc::new(callback));
182    }
183
184    pub fn set_filter<F>(&self, filter: F)
185    where
186        F: Fn(&str) -> bool + Send + Sync + 'static,
187    {
188        *self.filter.write().unwrap_or_else(|e| e.into_inner()) = Some(Arc::new(filter));
189    }
190
191    pub fn on_error<F>(&self, callback: F)
192    where
193        F: Fn(crate::error::DHTError) + Send + Sync + 'static,
194    {
195        *self.on_error_cb.write().unwrap_or_else(|e| e.into_inner()) = Some(Arc::new(callback));
196    }
197
198    fn emit_error(&self, error: crate::error::DHTError) {
199        if let Ok(cb) = self.on_error_cb.read() {
200            if let Some(f) = cb.as_ref() {
201                f(error);
202            }
203        }
204    }
205
    /// Current number of discovered nodes waiting in the crawl queue.
    pub fn get_node_pool_size(&self) -> usize {
        self.node_queue.len()
    }
209
    /// Start the crawler: spawn packet workers, attach a listener to each
    /// bound socket, bootstrap, then run the adaptive crawl loop in a
    /// background task.
    ///
    /// This future resolves only after `shutdown()` cancels the token.
    ///
    /// # Errors
    /// Fails if the server was already shut down or a listener cannot be
    /// spawned.
    pub async fn start(&self) -> Result<()> {
        // Refuse to start a server that has already been shut down.
        if self.shutdown.is_cancelled() {
            log::warn!("⚠️ 尝试启动已关闭的服务器");
            return Err(crate::error::DHTError::Other("服务器已关闭".to_string()));
        }

        let workers = self.spawn_workers();
        for sock in self.socket_providers.values().cloned() {
            spawn_udp_listener(sock, workers.clone(), self.shutdown.clone())?;
        }
        self.bootstrap().await;

        let server = self.clone();
        let shutdown = self.shutdown.clone();

        tokio::spawn(async move {
            // Bounds the number of concurrently in-flight find_node sends.
            let semaphore = Arc::new(Semaphore::new(2000));
            let mut loop_tick = 0;

            loop {
                // Check for the shutdown signal.
                if shutdown.is_cancelled() {
                    #[cfg(debug_assertions)]
                    log::trace!("主循环收到关闭信号,退出");
                    break;
                }

                // Pressure (0.0–1.0) = how full the metadata queue is;
                // used below to throttle crawling when fetchers lag.
                let queue_len = server.metadata_queue_len.load(Ordering::Relaxed);
                let queue_pressure = queue_len as f64 / server.max_metadata_queue_size as f64;

                #[cfg(feature = "metrics")]
                {
                    gauge!("dht_metadata_queue_size").set(queue_len as f64);
                    gauge!("dht_metadata_worker_pressure").set(queue_pressure);
                    gauge!("dht_node_queue_size").set(server.node_queue.len() as f64);
                }

                // Adaptive pacing: crawl hard under 80% pressure, ease off
                // under 95%, pause entirely above that.
                let (batch_size, sleep_duration) = if queue_pressure < 0.8 {
                    (200, Duration::from_millis(10))
                } else if queue_pressure < 0.95 {
                    (20, Duration::from_millis(500))
                } else {
                    (0, Duration::from_millis(1000))
                };

                // Only pop nodes whose family matches the configured mode.
                let filter_ipv6 = match server.options.netmode {
                    NetMode::Ipv4Only => Some(false),
                    NetMode::Ipv6Only => Some(true),
                    NetMode::DualStack => None,
                };

                let queue_empty = server.node_queue.is_empty_for(filter_ipv6);

                let nodes_batch = {
                    if queue_empty || batch_size == 0 {
                        None
                    } else {
                        Some(server.node_queue.pop_batch(batch_size, filter_ipv6))
                    }
                };

                // Re-bootstrap whenever starved, plus periodically every 50
                // ticks to keep discovering fresh nodes.
                loop_tick += 1;
                if nodes_batch.is_none() || loop_tick % 50 == 0 {
                    server.bootstrap().await;
                    if nodes_batch.is_none() {
                        tokio::select! {
                            _ = shutdown.cancelled() => break,
                            _ = tokio::time::sleep(sleep_duration) => {},
                        }
                        continue;
                    }
                }

                if let Some(nodes) = nodes_batch {
                    let node_id = server.node_id;

                    for node in nodes {
                        // acquire_owned only errors once the semaphore is
                        // closed, which never happens here; break defensively.
                        let permit = match semaphore.clone().acquire_owned().await {
                            Ok(p) => p,
                            Err(_) => break,
                        };
                        let node_id_clone = node_id;
                        let socket = match server.socket_for_addr(&node.addr) {
                            Some(sock) => sock,
                            None => {
                                log::warn!("未绑定任何地址");
                                break;
                            }
                        };
                        let node_addr = node.addr;
                        let node_id_for_target = node.id;

                        tokio::spawn(async move {
                            // Present an ID close to the remote node's so it
                            // replies with more relevant neighbors.
                            let neighbor_id =
                                generate_neighbor_target(&node_id_for_target, &node_id_clone);
                            let random_target = generate_random_id();
                            let _ = send_find_node_impl(
                                &node_addr,
                                &random_target,
                                &neighbor_id,
                                socket,
                            )
                            .await;
                            // Releases the concurrency permit.
                            drop(permit);
                        });
                    }
                }

                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = tokio::time::sleep(sleep_duration) => {},
                }
            }
        });
        // Block the caller until an explicit shutdown() is requested.
        self.shutdown.cancelled().await;
        Ok(())
    }
328
    /// Explicitly shut down the server, stopping all background tasks.
    ///
    /// Idempotent: cancelling an already-cancelled token is a no-op.
    pub fn shutdown(&self) {
        self.shutdown.cancel();
    }
333
    /// Spawn one packet-processing task per available CPU core and return
    /// the channel senders that listeners use to distribute inbound
    /// packets across them.
    ///
    /// Each worker exits when the shutdown token fires or when every
    /// sender for its channel has been dropped.
    fn spawn_workers(&self) -> Vec<WorkerHandle> {
        let server = self.clone();
        let shutdown = self.shutdown.clone();

        // One worker per core; fall back to 8 if parallelism is unknown.
        let num_workers = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(8);

        // Bounded per-worker channel depth.
        let queue_size = 5000;

        let mut workers: Vec<WorkerHandle> = Vec::with_capacity(num_workers);
        for _ in 0..num_workers {
            let (tx, mut rx) = mpsc::channel(queue_size);
            workers.push(tx);

            let server_clone = server.clone();
            let cancellation_token = shutdown.clone();

            tokio::spawn(async move {
                loop {
                    tokio::select! {
                        // Shutdown takes priority over queued packets.
                        _ = cancellation_token.cancelled() => {
                            #[cfg(debug_assertions)]
                            log::trace!("Worker 收到关闭信号,退出");
                            break;
                        }
                        msg = rx.recv() => {
                            match msg {
                                Some((data, remote_addr, local_addr)) => {
                                    // Processing errors are routed to the user's
                                    // error callback, never propagated upward.
                                    if let Err(e) = server_clone.handle_message(data.as_ref(), remote_addr, local_addr).await {
                                        server_clone.emit_error(e);
                                    }
                                }
                                // All senders dropped: no more packets will come.
                                None => break,
                            }
                        }
                    }
                }
            });
        }
        workers
    }
376
377    async fn handle_message(
378        &self,
379        data: &[u8],
380        remote_addr: SocketAddr,
381        local_addr: SocketAddr,
382    ) -> Result<()> {
383        if self.socket_providers.get(&local_addr).is_none() {
384            #[cfg(debug_assertions)]
385            log::trace!(
386                "⚠️ 拒绝未绑定的地址: {} (当前模式: {:?})",
387                remote_addr,
388                self.options.netmode
389            );
390            return Ok(());
391        }
392
393        let msg: DhtMessage = match serde_bencode::from_bytes(data) {
394            Ok(m) => m,
395            Err(_) => {
396                #[cfg(feature = "metrics")]
397                counter!("dht_messages_parse_error_total").increment(1);
398                return Ok(());
399            }
400        };
401
402        #[cfg(feature = "metrics")]
403        {
404            // 使用 match 映射到静态字符串,避免 clone(),同时防止恶意 tag
405            let label = match msg.y.as_str() {
406                "q" => "q",
407                "r" => "r",
408                "e" => "e",
409                _ => "unknown", // 将所有非法/未知类型归一化
410            };
411            counter!("dht_messages_processed_total", "type" => label).increment(1);
412        }
413
414        match msg.y.as_str() {
415            "q" => {
416                if let Some(q_type) = &msg.q {
417                    self.handle_query(&msg, q_type.as_bytes(), remote_addr, local_addr)
418                        .await?;
419                }
420            }
421            "r" => {
422                if let Some(response) = &msg.r {
423                    self.handle_response(response).await?;
424                }
425            }
426            _ => {}
427        }
428        Ok(())
429    }
430
    /// Dispatch an incoming DHT query (`y == "q"`).
    ///
    /// Extracts the transaction ID, the sender's node ID, and a fallback
    /// reference ID (`target` or `info_hash`), runs `announce_peer` side
    /// effects, and always answers the sender via `send_response`.
    async fn handle_query(
        &self,
        msg: &DhtMessage,
        query_type: &[u8],
        remote_addr: SocketAddr,
        local_addr: SocketAddr,
    ) -> Result<()> {
        // Queries without an argument dict are malformed; ignore them.
        let args = match &msg.a {
            Some(a) => a,
            None => return Ok(()),
        };

        let transaction_id = &msg.t;
        // The querying node's own ID, if it supplied one.
        let sender_id: Option<&[u8]> = args.id.as_deref().map(|v| v.as_slice());
        // Used to craft a "nearby" response ID when the sender sent no id:
        // prefer `target`, then `info_hash`.
        let target_id_fallback: Option<&[u8]> = args
            .target
            .as_deref()
            .or(args.info_hash.as_deref())
            .map(|v| v.as_slice());

        // Non-UTF-8 query names degrade to "" and fall through as unknown.
        let q_str = std::str::from_utf8(query_type).unwrap_or("");

        #[cfg(feature = "metrics")]
        {
            // Whitelist known query names so hostile input cannot blow up
            // metric label cardinality.
            let label = match q_str {
                "ping" => "ping",
                "find_node" => "find_node",
                "get_peers" => "get_peers",
                "announce_peer" => "announce_peer",
                "vote" => "vote",
                _ => "other_or_invalid",
            };
            counter!("dht_queries_total", "q" => label).increment(1);
        }

        // announce_peer is the crawler's payload: it reveals info-hashes.
        if q_str == "announce_peer" {
            self.handle_announce_peer(args, remote_addr).await?;
        }

        // Reply to every query type (including ping and unknown) so remote
        // nodes keep us in their routing tables.
        self.send_response(
            transaction_id,
            remote_addr,
            local_addr,
            q_str,
            sender_id,
            target_id_fallback,
        )
        .await?;
        Ok(())
    }
481
    /// Process an `announce_peer` query: validate its token, apply the
    /// user filter, and enqueue the discovered info-hash for metadata
    /// fetching.
    ///
    /// Every rejection path returns `Ok(())` silently — remote peers are
    /// never penalized for bad announces.
    async fn handle_announce_peer(&self, args: &DhtArgs, addr: SocketAddr) -> Result<()> {
        // A valid token proves this IP previously received one from us
        // (see generate_token); announces without one are dropped.
        if let Some(token) = &args.token {
            if !self.validate_token(token, addr) {
                #[cfg(feature = "metrics")]
                counter!("dht_announce_peer_blocked_total", "reason" => "invalid_token")
                    .increment(1);
                return Ok(());
            }
        } else {
            return Ok(());
        }

        if let Some(info_hash) = &args.info_hash {
            // Info-hashes must be exactly 20 bytes (SHA-1 digest).
            let info_hash_arr: [u8; 20] = match info_hash.as_ref().try_into() {
                Ok(arr) => arr,
                Err(_) => return Ok(()),
            };
            let hash_hex = hex::encode(info_hash_arr);

            // Clone the callback out of the lock so user code never runs
            // while the guard is held.
            let filter_cb = self.filter.read().unwrap_or_else(|e| e.into_inner()).clone();
            if let Some(f) = filter_cb
                && !f(&hash_hex)
            {
                #[cfg(feature = "metrics")]
                counter!("dht_announce_peer_blocked_total", "reason" => "filtered").increment(1);
                return Ok(());
            }

            #[cfg(feature = "metrics")]
            counter!("dht_info_hashes_discovered_total").increment(1);

            #[cfg(debug_assertions)]
            log::debug!("🔥 新 Hash: {} 来自 {}", hash_hex, addr);

            // BEP 5: a non-zero implied_port means "use the UDP source
            // port", overriding the port argument.
            // NOTE(review): when implied_port is absent and no port was
            // given this falls back to the source port instead of
            // rejecting — deliberately lenient; confirm intended.
            let port = if let Some(implied) = args.implied_port {
                if implied != 0 {
                    addr.port()
                } else {
                    args.port.unwrap_or(0)
                }
            } else {
                args.port.unwrap_or(addr.port())
            };

            // Port 0 is unusable for a later metadata connection.
            if port > 0 {
                let event = HashDiscovered {
                    info_hash: hash_hex,
                    peer_addr: SocketAddr::new(addr.ip(), port),
                    discovered_at: std::time::Instant::now(),
                };

                // Best-effort enqueue: drop the hash when the fetch queue
                // is saturated rather than blocking this worker.
                if self.hash_tx.try_send(event).is_err() {
                    #[cfg(debug_assertions)]
                    log::debug!("⚠️ Hash 队列满,丢弃 hash");
                }
            }
        }
        Ok(())
    }
541
542    async fn handle_response(&self, response: &DhtResponse) -> Result<()> {
543        if let Some(nodes_bytes) = &response.nodes {
544            self.process_compact_nodes(nodes_bytes);
545        }
546        if let Some(nodes6_bytes) = &response.nodes6 {
547            self.process_compact_nodes_v6(nodes6_bytes);
548        }
549        Ok(())
550    }
551
552    fn process_compact_nodes(&self, nodes_bytes: &[u8]) {
553        if self.options.netmode == NetMode::Ipv6Only {
554            return;
555        }
556
557        #[allow(clippy::manual_is_multiple_of)]
558        if nodes_bytes.len() % 26 != 0 {
559            return;
560        }
561
562        for chunk in nodes_bytes.chunks(26) {
563            let id = chunk[0..20].to_vec();
564            let port = u16::from_be_bytes([chunk[24], chunk[25]]);
565
566            let ip = std::net::Ipv4Addr::new(chunk[20], chunk[21], chunk[22], chunk[23]);
567            let addr = SocketAddr::new(std::net::IpAddr::V4(ip), port);
568
569            #[cfg(feature = "metrics")]
570            counter!("dht_nodes_discovered_total", "ip_version" => "v4").increment(1);
571
572            self.node_queue.push(NodeTuple { id, addr });
573        }
574    }
575
576    fn process_compact_nodes_v6(&self, nodes_bytes: &[u8]) {
577        if self.options.netmode == NetMode::Ipv4Only {
578            return;
579        }
580
581        #[allow(clippy::manual_is_multiple_of)]
582        if nodes_bytes.len() % 38 != 0 {
583            return;
584        }
585        for chunk in nodes_bytes.chunks(38) {
586            let id = chunk[0..20].to_vec();
587            let port = u16::from_be_bytes([chunk[36], chunk[37]]);
588            let ip_bytes: [u8; 16] = match chunk[20..36].try_into() {
589                Ok(b) => b,
590                Err(_) => continue,
591            };
592            let ip = Ipv6Addr::from(ip_bytes);
593            if !ip.is_unspecified() && !ip.is_multicast() {
594                let addr = SocketAddr::new(IpAddr::V6(ip), port);
595
596                #[cfg(feature = "metrics")]
597                counter!("dht_nodes_discovered_total", "ip_version" => "v6").increment(1);
598
599                self.node_queue.push(NodeTuple { id, addr });
600            }
601        }
602    }
603
    /// Build and send a bencoded reply (`y == "r"`) to a query.
    ///
    /// The reply always carries an `id` (a neighbor ID derived from the
    /// sender so we appear adjacent in the keyspace) and a `token`;
    /// `get_peers`/`find_node` replies additionally carry compact node
    /// lists matching the requester's address family.
    ///
    /// NOTE(review): BEP 5 only requires `token` in get_peers replies;
    /// sending it unconditionally is harmless but non-standard — confirm.
    async fn send_response(
        &self,
        tid: &[u8],
        remote_addr: SocketAddr,
        local_addr: SocketAddr,
        query_type: &str,
        sender_id: Option<&[u8]>,
        target_id_fallback: Option<&[u8]>,
    ) -> Result<()> {
        let socket = match self.socket_providers.get(&local_addr) {
            Some(sock) => sock,
            None => return Ok(()), // Silent failure when the socket is not present
        };

        let mut r_dict = std::collections::HashMap::new();

        // Prefer the sender's own ID, falling back to the queried target,
        // to forge an ID that looks close to the requester.
        let reference_id = sender_id.or(target_id_fallback);
        let my_id = if let Some(target) = reference_id {
            generate_neighbor_target(target, &self.node_id)
        } else {
            self.node_id.to_vec()
        };

        r_dict.insert(b"id".to_vec(), serde_bencode::value::Value::Bytes(my_id));
        let token = self.generate_token(remote_addr);
        r_dict.insert(
            b"token".to_vec(),
            serde_bencode::value::Value::Bytes(token.to_vec()),
        );

        if query_type == "get_peers" || query_type == "find_node" {
            // Reply only with nodes of the requester's address family.
            let requestor_is_ipv6 = remote_addr.is_ipv6();
            let filter_ipv6 = match self.options.netmode {
                NetMode::Ipv4Only => Some(false),
                NetMode::Ipv6Only => Some(true),
                NetMode::DualStack => Some(requestor_is_ipv6),
            };

            let nodes = self.node_queue.get_random_nodes(8, filter_ipv6);

            let mut nodes_data = Vec::new();
            let mut nodes6_data = Vec::new();

            // Compact encoding: 20-byte ID + packed IP + 2-byte BE port.
            for node in nodes {
                match node.addr.ip() {
                    IpAddr::V4(ip) => {
                        nodes_data.extend_from_slice(&node.id);
                        nodes_data.extend_from_slice(&ip.octets());
                        nodes_data.extend_from_slice(&node.addr.port().to_be_bytes());
                    }
                    IpAddr::V6(ip) => {
                        nodes6_data.extend_from_slice(&node.id);
                        nodes6_data.extend_from_slice(&ip.octets());
                        nodes6_data.extend_from_slice(&node.addr.port().to_be_bytes());
                    }
                }
            }

            if requestor_is_ipv6 {
                if !nodes6_data.is_empty() {
                    r_dict.insert(
                        b"nodes6".to_vec(),
                        serde_bencode::value::Value::Bytes(nodes6_data),
                    );
                }
            } else if !nodes_data.is_empty() {
                r_dict.insert(
                    b"nodes".to_vec(),
                    serde_bencode::value::Value::Bytes(nodes_data),
                );
            }
        }

        // Top-level KRPC envelope: t (transaction), y ("r"), r (payload).
        let mut response: std::collections::HashMap<String, serde_bencode::value::Value> =
            std::collections::HashMap::new();
        response.insert(
            "t".to_string(),
            serde_bencode::value::Value::Bytes(tid.to_vec()),
        );
        response.insert(
            "y".to_string(),
            serde_bencode::value::Value::Bytes(b"r".to_vec()),
        );
        response.insert("r".to_string(), serde_bencode::value::Value::Dict(r_dict));

        // Encoding or send errors are dropped: replies are best-effort.
        if let Ok(encoded) = serde_bencode::to_bytes(&response) {
            #[allow(unused)]
            if let Ok(len) = socket.send_to(&encoded, remote_addr).await {
                #[cfg(feature = "metrics")]
                {
                    counter!("dht_udp_bytes_sent_total").increment(len as u64);
                    counter!("dht_udp_packets_sent_total", "type" => "response").increment(1);
                }
            }
        }
        Ok(())
    }
701
702    async fn bootstrap(&self) {
703        let target = generate_random_id();
704        for node in BOOTSTRAP_NODES {
705            if let Ok(addrs) = tokio::net::lookup_host(node).await {
706                for addr in addrs {
707                    match self.options.netmode {
708                        NetMode::Ipv4Only => {
709                            if addr.is_ipv6() {
710                                continue;
711                            }
712                        }
713                        NetMode::Ipv6Only => {
714                            if addr.is_ipv4() {
715                                continue;
716                            }
717                        }
718                        NetMode::DualStack => {}
719                    }
720                    let _ = self.send_find_node(&addr, &target, &self.node_id).await;
721                }
722            }
723        }
724    }
725
726    fn socket_for_addr(&self, addr: &SocketAddr) -> Option<Arc<UdpSocket>> {
727        self.socket_providers
728            .iter()
729            .find(|(bind_addr, _)| bind_addr.is_ipv4() == addr.is_ipv4())
730            .map(|(_, sock)| sock.clone())
731    }
732
733    async fn send_find_node(&self, target_addr: &SocketAddr, target: &[u8], sender_id: &[u8]) {
734        if let Some(sock) = self.socket_for_addr(target_addr) {
735            send_find_node_impl(target_addr, target, sender_id, sock).await
736        }
737    }
738
    /// Derive the 8-byte announce token for `addr`.
    ///
    /// Token = AHash(ip_octets || token_secret), little-endian. Only the
    /// IP is hashed (not the port), so a token survives source-port
    /// changes by the same host.
    ///
    /// NOTE(review): AHash is not a cryptographic MAC and `token_secret`
    /// never rotates (BEP 5 suggests rotating roughly every 10 minutes) —
    /// confirm this trade-off is acceptable for a crawler.
    fn generate_token(&self, addr: SocketAddr) -> [u8; 8] {
        let mut hasher = ahash::AHasher::default();

        match addr.ip() {
            IpAddr::V4(ip) => ip.octets().hash(&mut hasher),
            IpAddr::V6(ip) => ip.octets().hash(&mut hasher),
        }

        self.token_secret.hash(&mut hasher);

        hasher.finish().to_le_bytes()
    }
751
752    fn validate_token(&self, token: &[u8], addr: SocketAddr) -> bool {
753        if token.len() != 8 {
754            return false;
755        }
756        let expected = self.generate_token(addr);
757        token == expected
758    }
759}
760
761/// 发送 DHT find_node 查询消息
762///
763/// 这是 DHT 协议中的核心操作之一,用于向指定节点查询包含目标 ID 的节点信息。
764/// 该方法构建符合 BEP5 (BitTorrent DHT Protocol) 规范的消息并异步发送。
765///
766/// # 参数
767///
768/// * `addr` - 目标节点的 Socket 地址
769/// * `target` - 要查找的目标节点 ID (20 字节)
770/// * `sender_id` - 发送者的节点 ID (20 字节),用于标识自己
771/// * `socket` - IPv4 UDP socket 的引用
772/// * `socket_v6` - IPv6 UDP socket 的可选引用(仅在双栈模式下需要)
773/// * `netmode` - 网络模式:仅 IPv4、仅 IPv6 或双栈模式
774///
775/// # 返回值
776///
777/// 返回 `Result<()>`,成功时返回 `Ok(())`,失败时返回错误信息
778///
779/// # 消息格式
780///
781/// 构建的 DHT 消息格式如下:
782/// ```bencode
783/// {
784///   "t": [0, 1],           // 事务 ID (transaction ID)
785///   "y": "q",              // 消息类型:查询 (query)
786///   "q": "find_node",      // 查询类型:查找节点
787///   "a": {                 // 参数 (arguments)
788///     "id": <sender_id>,   // 发送者节点 ID
789///     "target": <target>   // 目标节点 ID
790///   }
791/// }
792/// ```
793///
794/// # 网络模式处理
795///
796/// * `Ipv4Only`: 始终使用 IPv4 socket
797/// * `Ipv6Only`: 始终使用 IPv4 socket(IPv6 模式下 socket 实际是 IPv6)
798/// * `DualStack`: 根据目标地址类型自动选择 IPv4 或 IPv6 socket
799async fn send_find_node_impl(
800    addr: &SocketAddr,
801    target: &[u8],
802    sender_id: &[u8],
803    socket: Arc<UdpSocket>,
804) {
805    // 构建查询参数
806    let mut args = std::collections::HashMap::new();
807    args.insert(
808        b"id".to_vec(),
809        serde_bencode::value::Value::Bytes(sender_id.to_vec()),
810    );
811    args.insert(
812        b"target".to_vec(),
813        serde_bencode::value::Value::Bytes(target.to_vec()),
814    );
815
816    // 构建完整的 DHT 消息
817    let mut msg: std::collections::HashMap<String, serde_bencode::value::Value> =
818        std::collections::HashMap::new();
819    msg.insert(
820        "t".to_string(),
821        serde_bencode::value::Value::Bytes(vec![0, 1]),
822    ); // 事务 ID
823    msg.insert(
824        "y".to_string(),
825        serde_bencode::value::Value::Bytes(b"q".to_vec()),
826    ); // 消息类型:查询
827    msg.insert(
828        "q".to_string(),
829        serde_bencode::value::Value::Bytes(b"find_node".to_vec()),
830    ); // 查询类型
831    msg.insert("a".to_string(), serde_bencode::value::Value::Dict(args)); // 参数字典
832
833    // 将消息编码为 bencode 格式并发送
834    if let Ok(encoded) = serde_bencode::to_bytes(&msg) {
835        // 异步发送 UDP 数据包
836        #[cfg(feature = "metrics")]
837        {
838            counter!("dht_udp_bytes_sent_total").increment(encoded.len() as u64);
839            counter!("dht_udp_packets_sent_total", "type" => "query").increment(1);
840        }
841        let _ = socket.send_to(&encoded, addr).await;
842    }
843}
844
845fn generate_random_id() -> [u8; 20] {
846    let mut id = [0u8; 20];
847    rand::thread_rng().fill(&mut id);
848    id
849}
850
851/// 生成邻居目标节点 ID
852///
853/// 该方法用于生成一个"看起来像"远程节点 ID 但实际基于本地节点 ID 的邻居节点 ID。
854/// 这是 DHT 协议中的一个重要优化策略,用于提高查询成功率和保护节点 ID 隐私。
855///
856/// # 工作原理
857///
858/// 1. 取远程节点 ID 的前 6 个字节作为前缀(如果远程 ID 长度足够)
859/// 2. 用本地节点 ID 的剩余部分填充
860/// 3. 如果本地 ID 不够长,用随机字节填充到 20 字节(标准 DHT 节点 ID 长度)
861///
862/// 这样生成的 ID 在 ID 空间中既接近远程节点(前 6 字节相同),又基于本地节点
863/// (后续字节来自本地 ID),从而在 DHT 路由时更容易获得相关响应。
864///
865/// # 参数
866///
867/// * `remote_id` - 远程节点的 ID(通常是查询目标节点或请求方的 ID)
868/// * `local_id` - 本地节点的 ID(通常是自己真实的节点 ID)
869///
870/// # 返回值
871///
872/// 返回一个 20 字节的节点 ID Vec,其前 6 字节来自 `remote_id`,后续字节来自 `local_id`
873///
874/// # 使用场景
875///
876/// 1. **发送查询时**:使用邻居 ID 作为发送者 ID,让远程节点认为查询来自一个接近目标 ID 的节点,
877///    从而返回更相关的节点列表
878/// 2. **发送响应时**:使用邻居 ID 作为响应中的节点 ID,保护真实本地 ID 的隐私,
879///    同时提高返回节点的相关性
880///
881/// # 示例
882///
883/// ```
884/// // 假设:
885/// // remote_id = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, ...]
886/// // local_id  = [0xAA, 0xBB, 0xCC, 0xDD, ...]
887/// // 生成结果 = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0xCC, 0xDD, ...]
888/// //           (前6字节来自remote_id,后续来自local_id)
889/// ```
890fn generate_neighbor_target(remote_id: &[u8], local_id: &[u8]) -> Vec<u8> {
891    let mut id = Vec::with_capacity(20);
892    let prefix_len = std::cmp::min(remote_id.len(), 6);
893    id.extend_from_slice(&remote_id[..prefix_len]);
894
895    if local_id.len() > prefix_len {
896        id.extend_from_slice(&local_id[prefix_len..]);
897    } else {
898        while id.len() < 20 {
899            id.push(rand::random());
900        }
901    }
902    id
903}
904
/// Spawn a background task that reads UDP datagrams from `socket` and hands
/// them to `workers` via `process_udp_packet`.
///
/// Returns `Err` immediately if the socket's local address cannot be read or
/// if `workers` is empty. Otherwise the spawned task runs until `shutdown`
/// is cancelled, or until every worker channel has closed
/// (`ProcessUdpPacketError::NoLiveWorkers`).
fn spawn_udp_listener(
    socket: Arc<UdpSocket>,
    mut workers: Vec<WorkerHandle>,
    shutdown: CancellationToken,
) -> crate::error::Result<()> {
    let local_addr = socket
        .local_addr()
        .map_err(|e| crate::error::DHTError::Init(format!("socket 无法获取本地地址: {e}")))?;
    if workers.is_empty() {
        return Err(crate::error::DHTError::Init(
            "spawn_udp_listener: 未提供任何 worker".to_string(),
        ));
    }
    tokio::spawn(async move {
        // 64 KiB covers the maximum UDP payload; oversized DHT packets are
        // rejected again inside process_udp_packet (> 8192 bytes).
        let mut buffer = [0u8; 65536];
        // Dispatch cursor shared across packets; advanced by
        // process_udp_packet when a worker is full or closed.
        let mut worker_index = 0;

        loop {
            tokio::select! {
                _ = shutdown.cancelled() => {
                    #[cfg(debug_assertions)]
                    log::trace!("UDP 读取循环收到关闭信号,退出");
                    break;
                }
                result = socket.recv_from(&mut buffer) => {
                    match result {
                        Ok((size, origin_addr)) => {
                            // Only NoLiveWorkers is fatal for this socket;
                            // every other packet error is drop-and-continue.
                            if let Err(ProcessUdpPacketError::NoLiveWorkers) = process_udp_packet(size, origin_addr, local_addr, &buffer, &mut workers, &mut worker_index){
                                log::warn!("Socket {socket:?} is closing because no worker can process packets.");
                                // TODO: Remove the dead socket, or find a way to supply workers.
                                break
                            }
                        },
                        Err(_e) => {
                            // Brief backoff after a recv error, while staying
                            // responsive to the shutdown signal during the sleep.
                            tokio::select! {
                                _ = shutdown.cancelled() => break,
                                _ = tokio::time::sleep(Duration::from_millis(1)) => {},
                            }
                        }
                    }
                }
            }
        }
    });
    Ok(())
}
951
/// Reasons `process_udp_packet` can refuse or fail to dispatch a datagram.
// Debug derived so these can be logged/asserted in error paths and tests.
#[derive(Debug)]
enum ProcessUdpPacketError {
    /// Datagram exceeded the 8192-byte limit and was dropped.
    PacketTooLarge,
    /// Datagram was empty or did not start with `d` (bencode dict marker).
    InvalidPacket,
    /// Every live worker's queue was full; the packet was dropped.
    ChokedWorkers,
    /// All worker channels are closed; the caller should stop reading.
    NoLiveWorkers,
}
958
/// Validate one received UDP datagram and dispatch it to a worker channel.
///
/// Dispatch starts at `*worker_index`; if that worker's queue is full, the
/// next worker (wrapping) is tried, for at most one full pass over the
/// workers that existed when the packet arrived. Workers whose channel has
/// closed are removed from `workers` in place via `swap_remove`.
///
/// NOTE(review): `*worker_index` is only advanced on the Full/Closed paths —
/// a successful `try_send` breaks out before the modulo update, so successive
/// packets keep going to the same worker until its queue fills. Confirm
/// whether strict per-packet round-robin was intended.
///
/// # Errors
///
/// * `PacketTooLarge` — `size` > 8192 bytes.
/// * `InvalidPacket` — empty packet, or first byte is not `d` (bencode dict).
/// * `ChokedWorkers` — every worker tried had a full queue; packet dropped.
/// * `NoLiveWorkers` — all worker channels closed; caller should stop reading.
fn process_udp_packet(
    size: usize,
    origin_addr: SocketAddr,
    local_addr: SocketAddr,
    buffer: &[u8],
    workers: &mut Vec<WorkerHandle>,
    worker_index: &mut usize,
) -> std::result::Result<(), ProcessUdpPacketError> {
    #[cfg(feature = "metrics")]
    counter!("dht_udp_bytes_received_total").increment(size as u64);

    // Reject abnormally large packets; legitimate DHT messages are far smaller.
    if size > 8192 {
        #[cfg(feature = "metrics")]
        counter!("dht_udp_packets_received_total", "status" => "dropped_size").increment(1);

        #[cfg(debug_assertions)]
        log::trace!("⚠️ 拒绝异常大的 UDP 包: {} 字节 from {}", size, origin_addr);
        return Err(ProcessUdpPacketError::PacketTooLarge);
    }

    // Cheap sanity check: every bencoded KRPC message starts with 'd'.
    if size == 0 || buffer[0] != b'd' {
        #[cfg(feature = "metrics")]
        counter!("dht_udp_packets_received_total", "status" => "dropped_magic").increment(1);
        return Err(ProcessUdpPacketError::InvalidPacket);
    }
    // Option lets the packet be moved into try_send and recovered from the
    // error variant for a retry with the next worker.
    let mut data = Some(buffer[..size].to_owned().into_boxed_slice());
    let mut attempts = 0;
    // One full pass over the workers present at entry; closed-worker removals
    // below do not extend this budget.
    let max_attempts = workers.len();

    while let Some(packet) = data.take() {
        let worker = &workers[*worker_index];
        match worker.try_send((packet, origin_addr, local_addr)) {
            Ok(_) => {
                #[cfg(feature = "metrics")]
                counter!("dht_udp_packets_received_total", "status" => "ok").increment(1);
                break;
            }
            Err(mpsc::error::TrySendError::Full((packet, _, _))) => {
                attempts += 1;
                if attempts >= max_attempts {
                    #[cfg(feature = "metrics")]
                    counter!("dht_udp_packets_received_total", "status" => "queue_full")
                        .increment(1);

                    #[cfg(debug_assertions)]
                    log::trace!("UDP worker queue full, dropping packet");
                    return Err(ProcessUdpPacketError::ChokedWorkers);
                }
                // Put the packet back so the loop retries the next worker.
                let _ = data.insert(packet);
            }
            Err(mpsc::error::TrySendError::Closed((packet, _, _))) => {
                log::warn!("UDP worker dropped.");
                // Closed channel: evict this worker (O(1), order not needed)
                // and retry the packet elsewhere.
                workers.swap_remove(*worker_index);
                let _ = data.insert(packet);
            }
        }
        if workers.is_empty() {
            return Err(ProcessUdpPacketError::NoLiveWorkers);
        }
        // Advance and wrap against the *current* length, which may have
        // shrunk after a swap_remove above.
        *worker_index = (*worker_index + 1) % workers.len();
    }

    Ok(())
}