dht-crawler 0.1.2

高性能的 Rust DHT (Distributed Hash Table) 爬虫库 | A high-performance Rust DHT crawler library for fetching torrent information from the BitTorrent DHT network
Documentation
use crate::metadata::RbitFetcher;
use crate::server::HashDiscovered;
use crate::types::TorrentInfo;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

/// Shared, thread-safe callback that receives a [`TorrentInfo`] by value.
///
/// `MetadataScheduler::process_hash` invokes it once after a successful
/// metadata fetch.
type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
/// Shared, thread-safe asynchronous predicate taking a `String` and resolving
/// to a `bool`.
///
/// `MetadataScheduler::process_hash` calls it with the info hash before
/// fetching; a `false` result makes the worker skip that hash.
type MetadataFetchCallback = Arc<
    dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
        + Send
        + Sync,
>;

/// Forwards discovered info hashes to a pool of worker tasks that fetch
/// torrent metadata and report results through optional callbacks.
pub struct MetadataScheduler {
    /// Incoming stream of discovered hashes, drained by the main loop in `run`.
    hash_rx: mpsc::Receiver<HashDiscovered>,
    /// Capacity passed to the bounded internal task queue; also the
    /// denominator of the queue-pressure percentage in the statistics log.
    max_queue_size: usize,
    /// Number of worker tasks spawned by `run`.
    max_concurrent: usize,
    /// Fetcher shared by all workers; its `fetch` is awaited once per hash.
    fetcher: Arc<RbitFetcher>,
    /// Optional callback invoked with the assembled `TorrentInfo` after a
    /// successful fetch.
    callback: Arc<RwLock<Option<TorrentCallback>>>,
    /// Optional asynchronous check awaited before fetching; when it resolves
    /// to `false` the hash is skipped.
    on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
    /// Count of hashes received from `hash_rx`.
    total_received: Arc<AtomicU64>,
    /// Count of hashes discarded because the task queue was full.
    total_dropped: Arc<AtomicU64>,
    /// Count of hashes taken off the task queue by a worker.
    total_dispatched: Arc<AtomicU64>,
    /// Caller-supplied counter, incremented when a hash is queued and
    /// decremented when a worker dequeues it.
    queue_len: Arc<AtomicUsize>,
    /// Cancellation token observed by the main loop and by every worker.
    shutdown: CancellationToken,
}

impl MetadataScheduler {
    /// Creates a scheduler from its input channel, fetcher, limits, shared
    /// callback slots, shared queue-length counter and shutdown token.
    ///
    /// The received / dropped / dispatched statistics counters are created
    /// here and all start at zero.
    // Every argument maps one-to-one onto a field, so the long parameter list
    // is accepted rather than hidden behind a builder.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        hash_rx: mpsc::Receiver<HashDiscovered>,
        fetcher: Arc<RbitFetcher>,
        max_queue_size: usize,
        max_concurrent: usize,
        callback: Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
        queue_len: Arc<AtomicUsize>,
        shutdown: CancellationToken,
    ) -> Self {
        // Each call yields an independent counter initialised to zero.
        let fresh_counter = || Arc::new(AtomicU64::new(0));

        Self {
            total_received: fresh_counter(),
            total_dropped: fresh_counter(),
            total_dispatched: fresh_counter(),
            shutdown,
            queue_len,
            on_metadata_fetch,
            callback,
            fetcher,
            max_concurrent,
            max_queue_size,
            hash_rx,
        }
    }

    /// Installs the callback that receives every successfully fetched
    /// [`TorrentInfo`], replacing any previously installed one.
    ///
    /// Waits for the write lock instead of giving up when it is busy, so the
    /// callback is not silently discarded just because a worker happens to be
    /// reading the slot at that instant. The caller must not hold a guard on
    /// the same lock while calling this, or the call cannot complete.
    ///
    /// If the lock is poisoned the callback is not installed; readers of a
    /// poisoned slot bail out as well, so there is nothing useful to store.
    pub fn set_callback(&mut self, callback: TorrentCallback) {
        if let Ok(mut slot) = self.callback.write() {
            *slot = Some(callback);
        }
    }

    /// Installs the asynchronous pre-fetch check, replacing any previously
    /// installed one. Workers await it with the info hash and skip the hash
    /// when it resolves to `false`.
    ///
    /// Waits for the write lock instead of giving up when it is busy, so the
    /// callback is not silently discarded just because a worker happens to be
    /// reading the slot at that instant. The caller must not hold a guard on
    /// the same lock while calling this, or the call cannot complete.
    ///
    /// If the lock is poisoned the callback is not installed; readers of a
    /// poisoned slot bail out as well, so there is nothing useful to store.
    pub fn set_metadata_fetch_callback(&mut self, callback: MetadataFetchCallback) {
        if let Ok(mut slot) = self.on_metadata_fetch.write() {
            *slot = Some(callback);
        }
    }

    /// Drives the scheduler until shutdown is requested or the hash source
    /// closes.
    ///
    /// Spawns `max_concurrent` detached worker tasks that pull hashes from a
    /// bounded internal queue and pass them to `process_hash`, then forwards
    /// every hash received on `hash_rx` into that queue. A hash that arrives
    /// while the queue is full is discarded and counted in `total_dropped`.
    /// In debug builds a statistics line is logged every 60 seconds.
    ///
    /// The worker tasks are not joined: this future completes as soon as the
    /// forwarding loop ends and the queue sender has been dropped. Workers
    /// stop on cancellation, or once the queue is closed and drained.
    pub async fn run(mut self) {
        // `async_channel::bounded` panics on a zero capacity, so a configured
        // `max_queue_size` of 0 is clamped to 1 instead of aborting the task.
        let (task_tx, task_rx) =
            async_channel::bounded::<HashDiscovered>(self.max_queue_size.max(1));

        let shutdown = self.shutdown.clone();
        // `worker_id` is only read by the debug-build trace logging below.
        #[cfg_attr(not(debug_assertions), allow(unused_variables))]
        for worker_id in 0..self.max_concurrent {
            let task_rx = task_rx.clone();
            let fetcher = self.fetcher.clone();
            let callback = self.callback.clone();
            let on_metadata_fetch = self.on_metadata_fetch.clone();
            let total_dispatched = self.total_dispatched.clone();
            let queue_len = self.queue_len.clone();
            let shutdown_worker = shutdown.clone();

            tokio::spawn(async move {
                #[cfg(debug_assertions)]
                log::trace!("Worker {} 启动", worker_id);

                loop {
                    tokio::select! {
                        _ = shutdown_worker.cancelled() => {
                            #[cfg(debug_assertions)]
                            log::trace!("Worker {} 收到关闭信号,退出", worker_id);
                            break;
                        }
                        result = task_rx.recv() => {
                            let hash = match result {
                                Ok(h) => {
                                    // The producer counted this item before
                                    // sending it, so this cannot underflow.
                                    queue_len.fetch_sub(1, Ordering::Relaxed);
                                    h
                                }
                                // Queue closed and empty: no more work.
                                Err(_) => break,
                            };

                            total_dispatched.fetch_add(1, Ordering::Relaxed);

                            Self::process_hash(
                                hash,
                                &fetcher,
                                &callback,
                                &on_metadata_fetch,
                            ).await;
                        }
                    }
                }

                #[cfg(debug_assertions)]
                log::trace!("Worker {} 退出", worker_id);
            });
        }

        // Statistics are only reported in debug builds.
        let mut stats_interval = if cfg!(debug_assertions) {
            Some(tokio::time::interval(std::time::Duration::from_secs(60)))
        } else {
            None
        };
        if let Some(ref mut interval) = stats_interval {
            // Await the first tick up front so the first report is not
            // emitted right at startup.
            interval.tick().await;
        }

        let shutdown = self.shutdown.clone();
        loop {
            tokio::select! {
                _ = shutdown.cancelled() => {
                    #[cfg(debug_assertions)]
                    log::trace!("MetadataScheduler 主循环收到关闭信号,退出");
                    break;
                }
                result = self.hash_rx.recv() => {
                    match result {
                        Some(hash) => {
                            self.total_received.fetch_add(1, Ordering::Relaxed);

                            // Count the item *before* publishing it. If the
                            // increment came after a successful send, a worker
                            // could dequeue and decrement first, wrapping the
                            // shared counter around to `usize::MAX`.
                            self.queue_len.fetch_add(1, Ordering::Relaxed);
                            if let Err(err) = task_tx.try_send(hash) {
                                // Nothing was queued: undo the optimistic count.
                                self.queue_len.fetch_sub(1, Ordering::Relaxed);
                                match err {
                                    async_channel::TrySendError::Full(_) => {
                                        self.total_dropped.fetch_add(1, Ordering::Relaxed);
                                    }
                                    // Queue closed: no worker can receive anymore.
                                    _ => break,
                                }
                            }
                        }
                        // Every sender of `hash_rx` is gone.
                        None => break,
                    }
                }
                _ = async {
                    match stats_interval.as_mut() {
                        Some(interval) => interval.tick().await,
                        // No interval configured: this branch never fires.
                        None => std::future::pending().await,
                    }
                } => {
                    self.print_stats_inline();
                }
            }
        }

        // Closing the sending side lets workers finish once the queue drains.
        drop(task_tx);
        #[cfg(debug_assertions)]
        log::trace!("MetadataScheduler 主循环退出,等待 worker 任务完成");
    }

    /// Handles one discovered hash: consults the optional pre-fetch check,
    /// fetches the metadata from the announcing peer and hands the resulting
    /// [`TorrentInfo`] to the optional torrent callback.
    ///
    /// Returns without doing anything further when either callback lock is
    /// poisoned, when the pre-fetch check resolves to `false`, when the info
    /// hash is not the hex encoding of exactly 20 bytes, or when the fetcher
    /// yields nothing.
    async fn process_hash(
        hash: HashDiscovered,
        fetcher: &Arc<RbitFetcher>,
        callback: &Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: &Arc<RwLock<Option<MetadataFetchCallback>>>,
    ) {
        let info_hash = hash.info_hash.clone();
        let peer_addr = hash.peer_addr;

        // Copy the callback out of the lock inside its own scope so the guard
        // is released before anything is awaited.
        let pre_fetch_check = {
            let Ok(slot) = on_metadata_fetch.read() else {
                return;
            };
            slot.clone()
        };

        if let Some(should_fetch) = pre_fetch_check {
            if !should_fetch(info_hash.clone()).await {
                return;
            }
        }

        // Only a hex string that decodes to exactly 20 bytes is accepted.
        let Some(info_hash_bytes) = hex::decode(&info_hash)
            .ok()
            .and_then(|raw| <[u8; 20]>::try_from(raw).ok())
        else {
            return;
        };

        let Some((name, total_size, files, piece_length)) =
            fetcher.fetch(&info_hash_bytes, peer_addr).await
        else {
            return;
        };

        let metadata = TorrentInfo {
            info_hash,
            name,
            total_size,
            files,
            magnet_link: format!("magnet:?xt=urn:btih:{}", hash.info_hash),
            peers: vec![peer_addr.to_string()],
            piece_length,
            // Seconds since the Unix epoch; falls back to 0 if the system
            // clock reports a time before the epoch.
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
        };

        // Same pattern as above: clone under the lock, call outside of it.
        let on_torrent = {
            let Ok(slot) = callback.read() else {
                return;
            };
            slot.clone()
        };

        if let Some(deliver) = on_torrent {
            deliver(metadata);
        }
    }

    /// Logs a snapshot of the scheduler counters; compiled to a no-op unless
    /// `debug_assertions` is enabled.
    ///
    /// The snapshot is logged at `warn` level when the queue is more than 80%
    /// full relative to `max_queue_size`, and at `info` level otherwise.
    fn print_stats_inline(&self) {
        #[cfg(debug_assertions)]
        {
            let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);

            let received = read(&self.total_received);
            let dropped = read(&self.total_dropped);
            let dispatched = read(&self.total_dispatched);

            // Percentage of received hashes that were dropped; defined as 0
            // while nothing has been received yet.
            let drop_rate = match received {
                0 => 0.0,
                total => dropped as f64 / total as f64 * 100.0,
            };

            let queue_len = self.queue_len.load(Ordering::Relaxed);
            let capacity = self.max_queue_size;
            let queue_pressure = (queue_len as f64 / capacity as f64) * 100.0;
            let high_pressure = queue_pressure > 80.0;

            if !high_pressure {
                log::info!(
                    "Metadata 调度器统计:队列={}/{}({:.1}%), 接收={}, 调度={}, 丢弃={}({:.2}%)",
                    queue_len,
                    capacity,
                    queue_pressure,
                    received,
                    dispatched,
                    dropped,
                    drop_rate
                );
            } else {
                log::warn!(
                    "Metadata 队列高压:队列={}/{}({:.1}%), 接收={}, 调度={}, 丢弃={}({:.2}%)",
                    queue_len,
                    capacity,
                    queue_pressure,
                    received,
                    dispatched,
                    dropped,
                    drop_rate
                );
            }
        }
    }
}