Skip to main content

dht_crawler/
server.rs

1use crate::error::Result;
2use crate::metadata::RbitFetcher;
3use crate::protocol::{DhtArgs, DhtMessage, DhtResponse};
4use crate::scheduler::MetadataScheduler;
5use crate::sharded::{NodeTuple, ShardedNodeQueue};
6use crate::types::{DHTOptions, NetMode, TorrentInfo};
7use ahash::AHasher;
8#[cfg(feature = "metrics")]
9use metrics::{counter, gauge};
10use rand::Rng;
11use socket2::{Domain, Protocol, Socket, Type};
12use std::future::Future;
13use std::hash::{Hash, Hasher};
14use std::net::{IpAddr, Ipv6Addr, SocketAddr};
15use std::pin::Pin;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::{Arc, RwLock};
18use std::time::Duration;
19use tokio::net::UdpSocket;
20use tokio::sync::{Semaphore, mpsc};
21use tokio_util::sync::CancellationToken;
22
/// Public DHT routers, as `host:port` strings, that `DHTServer::bootstrap`
/// resolves and sends an initial `find_node` query to.
const BOOTSTRAP_NODES: &[&str] = &[
    "router.bittorrent.com:6881",
    "dht.transmissionbt.com:6881",
    "router.utorrent.com:6881",
    "dht.aelitis.com:6881",
];
29
/// Heap-allocated, `Send` future that resolves to a `bool`.
pub type BoxedBoolFuture = Pin<Box<dyn Future<Output = bool> + Send>>;

/// Shared async callback that is handed a `String` and returns a
/// [`BoxedBoolFuture`].
///
/// NOTE(review): presumably the string is the hex info hash and the boolean
/// decides whether metadata should be fetched; confirm against the scheduler.
pub type MetadataFetchCallback = Arc<dyn Fn(String) -> BoxedBoolFuture + Send + Sync>;
32
/// Event pushed onto the hash channel when an `announce_peer` query yields
/// an info hash.
#[derive(Clone, Debug)]
pub struct HashDiscovered {
    /// Hex encoding of the 20-byte info hash.
    pub info_hash: String,
    /// IP of the announcing node combined with the announced (or implied) port.
    pub peer_addr: SocketAddr,
    /// Monotonic timestamp taken when the event was created.
    pub discovered_at: std::time::Instant,
}
39
40type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
41type FilterCallback = Arc<dyn Fn(&str) -> bool + Send + Sync>;
42
43#[derive(Clone)]
44pub struct DHTServer {
45    #[allow(dead_code)]
46    options: DHTOptions,
47    node_id: Vec<u8>,
48    socket: Arc<UdpSocket>,
49    socket_v6: Option<Arc<UdpSocket>>,
50    token_secret: Vec<u8>,
51    callback: Arc<RwLock<Option<TorrentCallback>>>,
52    filter: Arc<RwLock<Option<FilterCallback>>>,
53    on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
54    node_queue: Arc<ShardedNodeQueue>,
55    hash_tx: mpsc::Sender<HashDiscovered>,
56    metadata_queue_len: Arc<AtomicUsize>,
57    max_metadata_queue_size: usize,
58    shutdown: CancellationToken,
59}
60
61impl DHTServer {
62    pub async fn new(options: DHTOptions) -> Result<Self> {
63        let (socket, socket_v6) = match options.netmode {
64            NetMode::Ipv4Only => {
65                let sock = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
66                #[cfg(not(windows))]
67                {
68                    let _ = sock.set_reuse_port(true);
69                }
70                let _ = sock.set_reuse_address(true);
71                sock.set_nonblocking(true)?;
72
73                let _ = sock.set_recv_buffer_size(32 * 1024 * 1024);
74                let _ = sock.set_send_buffer_size(8 * 1024 * 1024);
75
76                let addr: SocketAddr = format!("0.0.0.0:{}", options.port).parse().unwrap();
77                sock.bind(&addr.into())?;
78                (Arc::new(UdpSocket::from_std(sock.into())?), None)
79            }
80            NetMode::Ipv6Only => {
81                let sock = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
82                #[cfg(not(windows))]
83                {
84                    let _ = sock.set_reuse_port(true);
85                }
86                let _ = sock.set_reuse_address(true);
87                #[cfg(not(windows))]
88                {
89                    let _ = sock.set_only_v6(true);
90                }
91                sock.set_nonblocking(true)?;
92
93                let _ = sock.set_recv_buffer_size(32 * 1024 * 1024);
94                let _ = sock.set_send_buffer_size(8 * 1024 * 1024);
95
96                let addr: SocketAddr = format!("[::]:{}", options.port).parse().unwrap();
97                sock.bind(&addr.into())?;
98                (Arc::new(UdpSocket::from_std(sock.into())?), None)
99            }
100            NetMode::DualStack => {
101                let sock_v4 = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
102                #[cfg(not(windows))]
103                {
104                    let _ = sock_v4.set_reuse_port(true);
105                }
106                let _ = sock_v4.set_reuse_address(true);
107                sock_v4.set_nonblocking(true)?;
108                let _ = sock_v4.set_recv_buffer_size(32 * 1024 * 1024);
109                let _ = sock_v4.set_send_buffer_size(8 * 1024 * 1024);
110                let addr_v4: SocketAddr = format!("0.0.0.0:{}", options.port).parse().unwrap();
111                sock_v4.bind(&addr_v4.into())?;
112                let socket = Arc::new(UdpSocket::from_std(sock_v4.into())?);
113
114                let sock_v6 = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
115                #[cfg(not(windows))]
116                {
117                    let _ = sock_v6.set_reuse_port(true);
118                }
119                let _ = sock_v6.set_reuse_address(true);
120                #[cfg(not(windows))]
121                {
122                    let _ = sock_v6.set_only_v6(true);
123                }
124                sock_v6.set_nonblocking(true)?;
125                let _ = sock_v6.set_recv_buffer_size(32 * 1024 * 1024);
126                let _ = sock_v6.set_send_buffer_size(8 * 1024 * 1024);
127                let addr_v6: SocketAddr = format!("[::]:{}", options.port).parse().unwrap();
128                sock_v6.bind(&addr_v6.into())?;
129                let socket_v6 = Some(Arc::new(UdpSocket::from_std(sock_v6.into())?));
130
131                (socket, socket_v6)
132            }
133        };
134
135        let node_id = generate_random_id();
136        let mut rng = rand::thread_rng();
137        let token_secret: Vec<u8> = (0..10).map(|_| rng.r#gen::<u8>()).collect();
138
139        let node_queue = ShardedNodeQueue::new(options.node_queue_capacity);
140
141        let (hash_tx, hash_rx) = mpsc::channel::<HashDiscovered>(options.hash_queue_capacity);
142
143        let fetcher = Arc::new(RbitFetcher::new(options.metadata_timeout));
144
145        let callback = Arc::new(RwLock::new(None));
146        let on_metadata_fetch = Arc::new(RwLock::new(None));
147
148        let metadata_queue_len = Arc::new(AtomicUsize::new(0));
149
150        let shutdown = CancellationToken::new();
151        let shutdown_for_scheduler = shutdown.clone();
152
153        let scheduler = MetadataScheduler::new(
154            hash_rx,
155            fetcher,
156            options.max_metadata_queue_size,
157            options.max_metadata_worker_count,
158            callback.clone(),
159            on_metadata_fetch.clone(),
160            metadata_queue_len.clone(),
161            shutdown_for_scheduler,
162        );
163
164        tokio::spawn(async move {
165            scheduler.run().await;
166        });
167
168        let max_metadata_queue_size = options.max_metadata_queue_size;
169        let server = Self {
170            options,
171            node_id: node_id.clone(),
172            socket,
173            socket_v6,
174            token_secret,
175            callback,
176            on_metadata_fetch,
177            node_queue: Arc::new(node_queue),
178            filter: Arc::new(RwLock::new(None)),
179            hash_tx,
180            metadata_queue_len,
181            max_metadata_queue_size,
182            shutdown,
183        };
184
185        Ok(server)
186    }
187
188    pub fn local_addr(&self) -> Result<SocketAddr> {
189        Ok(self.socket.local_addr()?)
190    }
191
192    fn is_addr_allowed(&self, addr: &SocketAddr) -> bool {
193        match self.options.netmode {
194            NetMode::Ipv4Only => addr.is_ipv4(),
195            NetMode::Ipv6Only => addr.is_ipv6(),
196            NetMode::DualStack => true,
197        }
198    }
199
200    fn select_socket(&self, addr: &SocketAddr) -> &Arc<UdpSocket> {
201        match self.options.netmode {
202            NetMode::Ipv4Only => &self.socket,
203            NetMode::Ipv6Only => &self.socket,
204            NetMode::DualStack => {
205                if addr.is_ipv6() {
206                    self.socket_v6.as_ref().unwrap_or(&self.socket)
207                } else {
208                    &self.socket
209                }
210            }
211        }
212    }
213
214    pub fn on_metadata_fetch<F, Fut>(&self, callback: F)
215    where
216        F: Fn(String) -> Fut + Send + Sync + 'static,
217        Fut: Future<Output = bool> + Send + 'static,
218    {
219        *self.on_metadata_fetch.write().unwrap() =
220            Some(Arc::new(move |hash| Box::pin(callback(hash))));
221    }
222
223    pub fn on_torrent<F>(&self, callback: F)
224    where
225        F: Fn(TorrentInfo) + Send + Sync + 'static,
226    {
227        *self.callback.write().unwrap() = Some(Arc::new(callback));
228    }
229
230    pub fn set_filter<F>(&self, filter: F)
231    where
232        F: Fn(&str) -> bool + Send + Sync + 'static,
233    {
234        *self.filter.write().unwrap() = Some(Arc::new(filter));
235    }
236
237    pub fn get_node_pool_size(&self) -> usize {
238        self.node_queue.len()
239    }
240
241    pub async fn start(&self) -> Result<()> {
242        // 检查是否已经被关闭
243        if self.shutdown.is_cancelled() {
244            log::warn!("⚠️ 尝试启动已关闭的服务器");
245            return Err(crate::error::DHTError::Other("服务器已关闭".to_string()));
246        }
247
248        self.start_receiver();
249        self.bootstrap().await;
250
251        let server = self.clone();
252        let shutdown = self.shutdown.clone();
253
254        tokio::spawn(async move {
255            let semaphore = Arc::new(Semaphore::new(2000));
256            let mut loop_tick = 0;
257
258            loop {
259                // 检查关闭信号
260                if shutdown.is_cancelled() {
261                    #[cfg(debug_assertions)]
262                    log::trace!("主循环收到关闭信号,退出");
263                    break;
264                }
265
266                let queue_len = server.metadata_queue_len.load(Ordering::Relaxed);
267                let queue_pressure = queue_len as f64 / server.max_metadata_queue_size as f64;
268
269                #[cfg(feature = "metrics")]
270                {
271                    gauge!("dht_metadata_queue_size").set(queue_len as f64);
272                    gauge!("dht_metadata_worker_pressure").set(queue_pressure);
273                    gauge!("dht_node_queue_size").set(server.node_queue.len() as f64);
274                }
275
276                let (batch_size, sleep_duration) = if queue_pressure < 0.8 {
277                    (200, Duration::from_millis(10))
278                } else if queue_pressure < 0.95 {
279                    (20, Duration::from_millis(500))
280                } else {
281                    (0, Duration::from_millis(1000))
282                };
283
284                let filter_ipv6 = match server.options.netmode {
285                    NetMode::Ipv4Only => Some(false),
286                    NetMode::Ipv6Only => Some(true),
287                    NetMode::DualStack => None,
288                };
289
290                let queue_empty = server.node_queue.is_empty_for(filter_ipv6);
291
292                let nodes_batch = {
293                    if queue_empty || batch_size == 0 {
294                        None
295                    } else {
296                        Some(server.node_queue.pop_batch(batch_size, filter_ipv6))
297                    }
298                };
299
300                loop_tick += 1;
301                if nodes_batch.is_none() || loop_tick % 50 == 0 {
302                    server.bootstrap().await;
303                    if nodes_batch.is_none() {
304                        tokio::select! {
305                            _ = shutdown.cancelled() => break,
306                            _ = tokio::time::sleep(sleep_duration) => {},
307                        }
308                        continue;
309                    }
310                }
311
312                if let Some(nodes) = nodes_batch {
313                    let node_id = server.node_id.clone();
314                    let socket = server.socket.clone();
315                    let socket_v6 = server.socket_v6.clone();
316                    let netmode = server.options.netmode;
317
318                    for node in nodes {
319                        let permit = semaphore.clone().acquire_owned().await.unwrap();
320                        let node_id_clone = node_id.clone();
321                        let socket_clone = socket.clone();
322                        let socket_v6_clone = socket_v6.clone();
323                        let node_addr = node.addr;
324                        let node_id_for_target = node.id;
325
326                        tokio::spawn(async move {
327                            let neighbor_id =
328                                generate_neighbor_target(&node_id_for_target, &node_id_clone);
329                            let random_target = generate_random_id();
330                            let _ = send_find_node_impl(
331                                node_addr,
332                                &random_target,
333                                &neighbor_id,
334                                &socket_clone,
335                                socket_v6_clone.as_ref(),
336                                netmode,
337                            )
338                            .await;
339                            drop(permit);
340                        });
341                    }
342                }
343
344                tokio::select! {
345                    _ = shutdown.cancelled() => break,
346                    _ = tokio::time::sleep(sleep_duration) => {},
347                }
348            }
349        });
350        self.shutdown.cancelled().await;
351        Ok(())
352    }
353
354    /// 显式关闭服务器,停止所有后台任务
355    pub fn shutdown(&self) {
356        self.shutdown.cancel();
357    }
358
359    fn start_receiver(&self) {
360        let socket = self.socket.clone();
361        let socket_v6 = self.socket_v6.clone();
362        let server = self.clone();
363        let shutdown = self.shutdown.clone();
364
365        let num_workers = std::thread::available_parallelism()
366            .map(|n| n.get())
367            .unwrap_or(8);
368
369        let queue_size = 5000;
370
371        let mut senders = Vec::with_capacity(num_workers);
372        for _ in 0..num_workers {
373            let (tx, mut rx) = mpsc::channel::<(Vec<u8>, SocketAddr)>(queue_size);
374            senders.push(tx);
375
376            let server_clone = server.clone();
377            let shutdown_worker = shutdown.clone();
378
379            tokio::spawn(async move {
380                loop {
381                    tokio::select! {
382                        _ = shutdown_worker.cancelled() => {
383                            #[cfg(debug_assertions)]
384                            log::trace!("Worker 收到关闭信号,退出");
385                            break;
386                        }
387                        msg = rx.recv() => {
388                            match msg {
389                                Some((data, addr)) => {
390                                    let _ = server_clone.handle_message(&data, addr).await;
391                                }
392                                None => break,
393                            }
394                        }
395                    }
396                }
397            });
398        }
399
400        Self::spawn_udp_reader(socket, senders.clone(), shutdown.clone());
401
402        if let Some(socket_v6) = socket_v6 {
403            Self::spawn_udp_reader(socket_v6, senders, shutdown);
404        }
405    }
406
407    fn spawn_udp_reader(
408        socket: Arc<UdpSocket>,
409        senders: Vec<mpsc::Sender<(Vec<u8>, SocketAddr)>>,
410        shutdown: CancellationToken,
411    ) {
412        let num_workers = senders.len();
413
414        tokio::spawn(async move {
415            let mut buf = [0u8; 65536];
416            let mut next_worker_idx = 0;
417
418            loop {
419                tokio::select! {
420                    _ = shutdown.cancelled() => {
421                        #[cfg(debug_assertions)]
422                        log::trace!("UDP 读取循环收到关闭信号,退出");
423                        break;
424                    }
425                    result = socket.recv_from(&mut buf) => {
426                        match result {
427                            Ok((size, addr)) => {
428                                #[cfg(feature = "metrics")]
429                                counter!("dht_udp_bytes_received_total").increment(size as u64);
430
431                                if size > 8192 {
432                                    #[cfg(feature = "metrics")]
433                                    counter!("dht_udp_packets_received_total", "status" => "dropped_size").increment(1);
434
435                                    #[cfg(debug_assertions)]
436                                    log::trace!("⚠️ 拒绝异常大的 UDP 包: {} 字节 from {}", size, addr);
437                                    continue;
438                                }
439
440                                if size == 0 || buf[0] != b'd' {
441                                    #[cfg(feature = "metrics")]
442                                    counter!("dht_udp_packets_received_total", "status" => "dropped_magic").increment(1);
443                                    continue;
444                                }
445
446                                let data = buf[..size].to_vec();
447
448                                let tx = &senders[next_worker_idx];
449                                next_worker_idx = (next_worker_idx + 1) % num_workers;
450
451                                match tx.try_send((data, addr)) {
452                                    Ok(_) => {
453                                        #[cfg(feature = "metrics")]
454                                        counter!("dht_udp_packets_received_total", "status" => "ok").increment(1);
455                                    },
456                                    Err(mpsc::error::TrySendError::Full(_)) => {
457                                        #[cfg(feature = "metrics")]
458                                        counter!("dht_udp_packets_received_total", "status" => "queue_full").increment(1);
459
460                                        #[cfg(debug_assertions)]
461                                        log::trace!("UDP worker queue full, dropping packet");
462                                    },
463                                    Err(_) => { break; }
464                                }
465                            }
466                            Err(_e) => {
467                                tokio::select! {
468                                    _ = shutdown.cancelled() => break,
469                                    _ = tokio::time::sleep(Duration::from_millis(1)) => {},
470                                }
471                            }
472                        }
473                    }
474                }
475            }
476        });
477    }
478
479    async fn handle_message(&self, data: &[u8], addr: SocketAddr) -> Result<()> {
480        if !self.is_addr_allowed(&addr) {
481            #[cfg(debug_assertions)]
482            log::trace!(
483                "⚠️ 拒绝不匹配的地址类型: {} (当前模式: {:?})",
484                addr,
485                self.options.netmode
486            );
487            return Ok(());
488        }
489
490        let msg: DhtMessage = match serde_bencode::from_bytes(data) {
491            Ok(m) => m,
492            Err(_) => {
493                #[cfg(feature = "metrics")]
494                counter!("dht_messages_parse_error_total").increment(1);
495                return Ok(());
496            }
497        };
498
499        #[cfg(feature = "metrics")]
500        {
501            // 使用 match 映射到静态字符串,避免 clone(),同时防止恶意 tag
502            let label = match msg.y.as_str() {
503                "q" => "q",
504                "r" => "r",
505                "e" => "e",
506                _ => "unknown", // 将所有非法/未知类型归一化
507            };
508            counter!("dht_messages_processed_total", "type" => label).increment(1);
509        }
510
511        match msg.y.as_str() {
512            "q" => {
513                if let Some(q_type) = &msg.q {
514                    self.handle_query(&msg, q_type.as_bytes(), addr).await?;
515                }
516            }
517            "r" => {
518                if let Some(response) = &msg.r {
519                    self.handle_response(response).await?;
520                }
521            }
522            _ => {}
523        }
524        Ok(())
525    }
526
527    async fn handle_query(
528        &self,
529        msg: &DhtMessage,
530        query_type: &[u8],
531        addr: SocketAddr,
532    ) -> Result<()> {
533        let args = match &msg.a {
534            Some(a) => a,
535            None => return Ok(()),
536        };
537
538        let transaction_id = &msg.t;
539        let sender_id: Option<&[u8]> = args.id.as_deref().map(|v| v.as_slice());
540        let target_id_fallback: Option<&[u8]> = args
541            .target
542            .as_deref()
543            .or(args.info_hash.as_deref())
544            .map(|v| v.as_slice());
545
546        let q_str = std::str::from_utf8(query_type).unwrap_or("");
547
548        #[cfg(feature = "metrics")]
549        {
550            let label = match q_str {
551                "ping" => "ping",
552                "find_node" => "find_node",
553                "get_peers" => "get_peers",
554                "announce_peer" => "announce_peer",
555                "vote" => "vote",
556                _ => "other_or_invalid",
557            };
558            counter!("dht_queries_total", "q" => label).increment(1);
559        }
560
561        if q_str == "announce_peer" {
562            self.handle_announce_peer(args, addr).await?;
563        }
564
565        self.send_response(transaction_id, addr, q_str, sender_id, target_id_fallback)
566            .await?;
567        Ok(())
568    }
569
570    async fn handle_announce_peer(&self, args: &DhtArgs, addr: SocketAddr) -> Result<()> {
571        if let Some(token) = &args.token {
572            if !self.validate_token(token, addr) {
573                #[cfg(feature = "metrics")]
574                counter!("dht_announce_peer_blocked_total", "reason" => "invalid_token")
575                    .increment(1);
576                return Ok(());
577            }
578        } else {
579            return Ok(());
580        }
581
582        if let Some(info_hash) = &args.info_hash {
583            let info_hash_arr: [u8; 20] = match info_hash.as_ref().try_into() {
584                Ok(arr) => arr,
585                Err(_) => return Ok(()),
586            };
587            let hash_hex = hex::encode(info_hash_arr);
588
589            let filter_cb = self.filter.read().unwrap().clone();
590            if let Some(f) = filter_cb
591                && !f(&hash_hex) {
592                #[cfg(feature = "metrics")]
593                counter!("dht_announce_peer_blocked_total", "reason" => "filtered")
594                    .increment(1);
595                return Ok(());
596            }
597
598            #[cfg(feature = "metrics")]
599            counter!("dht_info_hashes_discovered_total").increment(1);
600
601            #[cfg(debug_assertions)]
602            log::debug!("🔥 新 Hash: {} 来自 {}", hash_hex, addr);
603
604            let port = if let Some(implied) = args.implied_port {
605                if implied != 0 {
606                    addr.port()
607                } else {
608                    args.port.unwrap_or(0)
609                }
610            } else {
611                args.port.unwrap_or(addr.port())
612            };
613
614            if port > 0 {
615                let event = HashDiscovered {
616                    info_hash: hash_hex,
617                    peer_addr: SocketAddr::new(addr.ip(), port),
618                    discovered_at: std::time::Instant::now(),
619                };
620
621                if self.hash_tx.try_send(event).is_err() {
622                    #[cfg(debug_assertions)]
623                    log::debug!("⚠️ Hash 队列满,丢弃 hash");
624                }
625            }
626        }
627        Ok(())
628    }
629
630    async fn handle_response(&self, response: &DhtResponse) -> Result<()> {
631        if let Some(nodes_bytes) = &response.nodes {
632            self.process_compact_nodes(nodes_bytes);
633        }
634        if let Some(nodes6_bytes) = &response.nodes6 {
635            self.process_compact_nodes_v6(nodes6_bytes);
636        }
637        Ok(())
638    }
639
640    fn process_compact_nodes(&self, nodes_bytes: &[u8]) {
641        if self.options.netmode == NetMode::Ipv6Only {
642            return;
643        }
644
645        #[allow(clippy::manual_is_multiple_of)]
646        if nodes_bytes.len() % 26 != 0 {
647            return;
648        }
649
650        for chunk in nodes_bytes.chunks(26) {
651            let id = chunk[0..20].to_vec();
652            let port = u16::from_be_bytes([chunk[24], chunk[25]]);
653
654            let ip = std::net::Ipv4Addr::new(chunk[20], chunk[21], chunk[22], chunk[23]);
655            let addr = SocketAddr::new(std::net::IpAddr::V4(ip), port);
656
657            #[cfg(feature = "metrics")]
658            counter!("dht_nodes_discovered_total", "ip_version" => "v4").increment(1);
659
660            self.node_queue.push(NodeTuple { id, addr });
661        }
662    }
663
664    fn process_compact_nodes_v6(&self, nodes_bytes: &[u8]) {
665        if self.options.netmode == NetMode::Ipv4Only {
666            return;
667        }
668
669        #[allow(clippy::manual_is_multiple_of)]
670        if nodes_bytes.len() % 38 != 0 {
671            return;
672        }
673        for chunk in nodes_bytes.chunks(38) {
674            let id = chunk[0..20].to_vec();
675            let port = u16::from_be_bytes([chunk[36], chunk[37]]);
676            let ip_bytes: [u8; 16] = match chunk[20..36].try_into() {
677                Ok(b) => b,
678                Err(_) => continue,
679            };
680            let ip = Ipv6Addr::from(ip_bytes);
681            if !ip.is_unspecified() && !ip.is_multicast() {
682                let addr = SocketAddr::new(IpAddr::V6(ip), port);
683
684                #[cfg(feature = "metrics")]
685                counter!("dht_nodes_discovered_total", "ip_version" => "v6").increment(1);
686
687                self.node_queue.push(NodeTuple { id, addr });
688            }
689        }
690    }
691
692    async fn send_response(
693        &self,
694        tid: &[u8],
695        addr: SocketAddr,
696        query_type: &str,
697        sender_id: Option<&[u8]>,
698        target_id_fallback: Option<&[u8]>,
699    ) -> Result<()> {
700        let mut r_dict = std::collections::HashMap::new();
701
702        let reference_id = sender_id.or(target_id_fallback);
703        let my_id = if let Some(target) = reference_id {
704            generate_neighbor_target(target, &self.node_id)
705        } else {
706            self.node_id.clone()
707        };
708
709        r_dict.insert(b"id".to_vec(), serde_bencode::value::Value::Bytes(my_id));
710        let token = self.generate_token(addr);
711        r_dict.insert(b"token".to_vec(), serde_bencode::value::Value::Bytes(token));
712
713        if query_type == "get_peers" || query_type == "find_node" {
714            let requestor_is_ipv6 = addr.is_ipv6();
715            let filter_ipv6 = match self.options.netmode {
716                NetMode::Ipv4Only => Some(false),
717                NetMode::Ipv6Only => Some(true),
718                NetMode::DualStack => Some(requestor_is_ipv6),
719            };
720
721            let nodes = self.node_queue.get_random_nodes(8, filter_ipv6);
722
723            let mut nodes_data = Vec::new();
724            let mut nodes6_data = Vec::new();
725
726            for node in nodes {
727                match node.addr.ip() {
728                    IpAddr::V4(ip) => {
729                        nodes_data.extend_from_slice(&node.id);
730                        nodes_data.extend_from_slice(&ip.octets());
731                        nodes_data.extend_from_slice(&node.addr.port().to_be_bytes());
732                    }
733                    IpAddr::V6(ip) => {
734                        nodes6_data.extend_from_slice(&node.id);
735                        nodes6_data.extend_from_slice(&ip.octets());
736                        nodes6_data.extend_from_slice(&node.addr.port().to_be_bytes());
737                    }
738                }
739            }
740
741            if requestor_is_ipv6 {
742                if !nodes6_data.is_empty() {
743                    r_dict.insert(
744                        b"nodes6".to_vec(),
745                        serde_bencode::value::Value::Bytes(nodes6_data),
746                    );
747                }
748            } else if !nodes_data.is_empty() {
749                r_dict.insert(
750                    b"nodes".to_vec(),
751                    serde_bencode::value::Value::Bytes(nodes_data),
752                );
753            }
754        }
755
756        let mut response: std::collections::HashMap<String, serde_bencode::value::Value> =
757            std::collections::HashMap::new();
758        response.insert(
759            "t".to_string(),
760            serde_bencode::value::Value::Bytes(tid.to_vec()),
761        );
762        response.insert(
763            "y".to_string(),
764            serde_bencode::value::Value::Bytes(b"r".to_vec()),
765        );
766        response.insert("r".to_string(), serde_bencode::value::Value::Dict(r_dict));
767
768        if let Ok(encoded) = serde_bencode::to_bytes(&response) {
769            #[cfg(feature = "metrics")]
770            {
771                counter!("dht_udp_bytes_sent_total").increment(encoded.len() as u64);
772                counter!("dht_udp_packets_sent_total", "type" => "response").increment(1);
773            }
774            let _ = self.select_socket(&addr).send_to(&encoded, addr).await;
775        }
776        Ok(())
777    }
778
779    async fn bootstrap(&self) {
780        let target = generate_random_id();
781        for node in BOOTSTRAP_NODES {
782            if let Ok(addrs) = tokio::net::lookup_host(node).await {
783                for addr in addrs {
784                    match self.options.netmode {
785                        NetMode::Ipv4Only => {
786                            if addr.is_ipv6() {
787                                continue;
788                            }
789                        }
790                        NetMode::Ipv6Only => {
791                            if addr.is_ipv4() {
792                                continue;
793                            }
794                        }
795                        NetMode::DualStack => {}
796                    }
797                    let _ = self.send_find_node(addr, &target, &self.node_id).await;
798                }
799            }
800        }
801    }
802
803    async fn send_find_node(
804        &self,
805        addr: SocketAddr,
806        target: &[u8],
807        sender_id: &[u8],
808    ) -> Result<()> {
809        send_find_node_impl(
810            addr,
811            target,
812            sender_id,
813            &self.socket,
814            self.socket_v6.as_ref(),
815            self.options.netmode,
816        )
817        .await
818    }
819
820    fn generate_token(&self, addr: SocketAddr) -> Vec<u8> {
821        let mut hasher = AHasher::default();
822
823        match addr.ip() {
824            IpAddr::V4(ip) => ip.octets().hash(&mut hasher),
825            IpAddr::V6(ip) => ip.octets().hash(&mut hasher),
826        }
827
828        self.token_secret.hash(&mut hasher);
829
830        let hash = hasher.finish();
831        hash.to_le_bytes().to_vec()
832    }
833
834    fn validate_token(&self, token: &[u8], addr: SocketAddr) -> bool {
835        if token.len() != 8 {
836            return false;
837        }
838        let expected = self.generate_token(addr);
839        token == expected.as_slice()
840    }
841}
842
843/// 发送 DHT find_node 查询消息
844///
845/// 这是 DHT 协议中的核心操作之一,用于向指定节点查询包含目标 ID 的节点信息。
846/// 该方法构建符合 BEP5 (BitTorrent DHT Protocol) 规范的消息并异步发送。
847///
848/// # 参数
849///
850/// * `addr` - 目标节点的 Socket 地址
851/// * `target` - 要查找的目标节点 ID (20 字节)
852/// * `sender_id` - 发送者的节点 ID (20 字节),用于标识自己
853/// * `socket` - IPv4 UDP socket 的引用
854/// * `socket_v6` - IPv6 UDP socket 的可选引用(仅在双栈模式下需要)
855/// * `netmode` - 网络模式:仅 IPv4、仅 IPv6 或双栈模式
856///
857/// # 返回值
858///
859/// 返回 `Result<()>`,成功时返回 `Ok(())`,失败时返回错误信息
860///
861/// # 消息格式
862///
863/// 构建的 DHT 消息格式如下:
864/// ```bencode
865/// {
866///   "t": [0, 1],           // 事务 ID (transaction ID)
867///   "y": "q",              // 消息类型:查询 (query)
868///   "q": "find_node",      // 查询类型:查找节点
869///   "a": {                 // 参数 (arguments)
870///     "id": <sender_id>,   // 发送者节点 ID
871///     "target": <target>   // 目标节点 ID
872///   }
873/// }
874/// ```
875///
876/// # 网络模式处理
877///
878/// * `Ipv4Only`: 始终使用 IPv4 socket
879/// * `Ipv6Only`: 始终使用 IPv4 socket(IPv6 模式下 socket 实际是 IPv6)
880/// * `DualStack`: 根据目标地址类型自动选择 IPv4 或 IPv6 socket
881async fn send_find_node_impl(
882    addr: SocketAddr,
883    target: &[u8],
884    sender_id: &[u8],
885    socket: &Arc<UdpSocket>,
886    socket_v6: Option<&Arc<UdpSocket>>,
887    netmode: NetMode,
888) -> Result<()> {
889    // 构建查询参数
890    let mut args = std::collections::HashMap::new();
891    args.insert(
892        b"id".to_vec(),
893        serde_bencode::value::Value::Bytes(sender_id.to_vec()),
894    );
895    args.insert(
896        b"target".to_vec(),
897        serde_bencode::value::Value::Bytes(target.to_vec()),
898    );
899
900    // 构建完整的 DHT 消息
901    let mut msg: std::collections::HashMap<String, serde_bencode::value::Value> =
902        std::collections::HashMap::new();
903    msg.insert(
904        "t".to_string(),
905        serde_bencode::value::Value::Bytes(vec![0, 1]),
906    ); // 事务 ID
907    msg.insert(
908        "y".to_string(),
909        serde_bencode::value::Value::Bytes(b"q".to_vec()),
910    ); // 消息类型:查询
911    msg.insert(
912        "q".to_string(),
913        serde_bencode::value::Value::Bytes(b"find_node".to_vec()),
914    ); // 查询类型
915    msg.insert("a".to_string(), serde_bencode::value::Value::Dict(args)); // 参数字典
916
917    // 将消息编码为 bencode 格式并发送
918    if let Ok(encoded) = serde_bencode::to_bytes(&msg) {
919        // 根据网络模式选择合适的 socket
920        let selected_socket = match netmode {
921            NetMode::Ipv4Only => socket,
922            NetMode::Ipv6Only => socket,
923            NetMode::DualStack => {
924                if addr.is_ipv6() {
925                    socket_v6.unwrap_or(socket)
926                } else {
927                    socket
928                }
929            }
930        };
931        // 异步发送 UDP 数据包
932        #[cfg(feature = "metrics")]
933        {
934            counter!("dht_udp_bytes_sent_total").increment(encoded.len() as u64);
935            counter!("dht_udp_packets_sent_total", "type" => "query").increment(1);
936        }
937        let _ = selected_socket.send_to(&encoded, addr).await;
938    }
939    Ok(())
940}
941
942fn generate_random_id() -> Vec<u8> {
943    let mut rng = rand::thread_rng();
944    (0..20).map(|_| rng.r#gen::<u8>()).collect()
945}
946
947/// 生成邻居目标节点 ID
948///
949/// 该方法用于生成一个"看起来像"远程节点 ID 但实际基于本地节点 ID 的邻居节点 ID。
950/// 这是 DHT 协议中的一个重要优化策略,用于提高查询成功率和保护节点 ID 隐私。
951///
952/// # 工作原理
953///
954/// 1. 取远程节点 ID 的前 6 个字节作为前缀(如果远程 ID 长度足够)
955/// 2. 用本地节点 ID 的剩余部分填充
956/// 3. 如果本地 ID 不够长,用随机字节填充到 20 字节(标准 DHT 节点 ID 长度)
957///
958/// 这样生成的 ID 在 ID 空间中既接近远程节点(前 6 字节相同),又基于本地节点
959/// (后续字节来自本地 ID),从而在 DHT 路由时更容易获得相关响应。
960///
961/// # 参数
962///
963/// * `remote_id` - 远程节点的 ID(通常是查询目标节点或请求方的 ID)
964/// * `local_id` - 本地节点的 ID(通常是自己真实的节点 ID)
965///
966/// # 返回值
967///
968/// 返回一个 20 字节的节点 ID Vec,其前 6 字节来自 `remote_id`,后续字节来自 `local_id`
969///
970/// # 使用场景
971///
972/// 1. **发送查询时**:使用邻居 ID 作为发送者 ID,让远程节点认为查询来自一个接近目标 ID 的节点,
973///    从而返回更相关的节点列表
974/// 2. **发送响应时**:使用邻居 ID 作为响应中的节点 ID,保护真实本地 ID 的隐私,
975///    同时提高返回节点的相关性
976///
977/// # 示例
978///
979/// ```
980/// // 假设:
981/// // remote_id = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, ...]
982/// // local_id  = [0xAA, 0xBB, 0xCC, 0xDD, ...]
983/// // 生成结果 = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0xCC, 0xDD, ...]
984/// //           (前6字节来自remote_id,后续来自local_id)
985/// ```
986fn generate_neighbor_target(remote_id: &[u8], local_id: &[u8]) -> Vec<u8> {
987    let mut id = Vec::with_capacity(20);
988    let prefix_len = std::cmp::min(remote_id.len(), 6);
989    id.extend_from_slice(&remote_id[..prefix_len]);
990
991    if local_id.len() > prefix_len {
992        id.extend_from_slice(&local_id[prefix_len..]);
993    } else {
994        while id.len() < 20 {
995            id.push(rand::random());
996        }
997    }
998    id
999}