// dht_crawler/scheduler.rs

1use crate::metadata::RbitFetcher;
2use crate::server::HashDiscovered;
3use crate::types::TorrentInfo;
4use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
5use std::sync::{Arc, RwLock};
6use tokio::sync::mpsc;
7use tokio_util::sync::CancellationToken;
8
9type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
10type MetadataFetchCallback = Arc<
11    dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
12        + Send
13        + Sync,
14>;
15
16pub struct MetadataScheduler {
17    hash_rx: mpsc::Receiver<HashDiscovered>,
18    max_queue_size: usize,
19    max_concurrent: usize,
20    fetcher: Arc<RbitFetcher>,
21    callback: Arc<RwLock<Option<TorrentCallback>>>,
22    on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
23    total_received: Arc<AtomicU64>,
24    total_dropped: Arc<AtomicU64>,
25    total_dispatched: Arc<AtomicU64>,
26    queue_len: Arc<AtomicUsize>,
27    shutdown: CancellationToken,
28}
29
30impl MetadataScheduler {
31    #[allow(clippy::too_many_arguments)]
32    pub fn new(
33        hash_rx: mpsc::Receiver<HashDiscovered>,
34        fetcher: Arc<RbitFetcher>,
35        max_queue_size: usize,
36        max_concurrent: usize,
37        callback: Arc<RwLock<Option<TorrentCallback>>>,
38        on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
39        queue_len: Arc<AtomicUsize>,
40        shutdown: CancellationToken,
41    ) -> Self {
42        Self {
43            hash_rx,
44            max_queue_size,
45            max_concurrent,
46            fetcher,
47            callback,
48            on_metadata_fetch,
49            total_received: Arc::new(AtomicU64::new(0)),
50            total_dropped: Arc::new(AtomicU64::new(0)),
51            total_dispatched: Arc::new(AtomicU64::new(0)),
52            queue_len,
53            shutdown,
54        }
55    }
56
57    pub fn set_callback(&mut self, callback: TorrentCallback) {
58        if let Ok(mut guard) = self.callback.try_write() {
59            *guard = Some(callback);
60        }
61    }
62
63    pub fn set_metadata_fetch_callback(&mut self, callback: MetadataFetchCallback) {
64        if let Ok(mut guard) = self.on_metadata_fetch.try_write() {
65            *guard = Some(callback);
66        }
67    }
68
69    pub async fn run(mut self) {
70        let (task_tx, task_rx) = async_channel::bounded::<HashDiscovered>(self.max_queue_size);
71
72        let shutdown = self.shutdown.clone();
73        #[cfg_attr(not(debug_assertions), allow(unused_variables))]
74        for worker_id in 0..self.max_concurrent {
75            let task_rx = task_rx.clone();
76            let fetcher = self.fetcher.clone();
77            let callback = self.callback.clone();
78            let on_metadata_fetch = self.on_metadata_fetch.clone();
79            let total_dispatched = self.total_dispatched.clone();
80            let queue_len = self.queue_len.clone();
81            let shutdown_worker = shutdown.clone();
82
83            tokio::spawn(async move {
84                #[cfg(debug_assertions)]
85                log::trace!("Worker {} 启动", worker_id);
86
87                loop {
88                    tokio::select! {
89                        _ = shutdown_worker.cancelled() => {
90                            #[cfg(debug_assertions)]
91                            log::trace!("Worker {} 收到关闭信号,退出", worker_id);
92                            break;
93                        }
94                        result = task_rx.recv() => {
95                            let hash = match result {
96                                Ok(h) => {
97                                    queue_len.fetch_sub(1, Ordering::Relaxed);
98                                    h
99                                }
100                                Err(_) => break,
101                            };
102
103                            total_dispatched.fetch_add(1, Ordering::Relaxed);
104
105                            Self::process_hash(
106                                hash,
107                                &fetcher,
108                                &callback,
109                                &on_metadata_fetch,
110                            ).await;
111                        }
112                    }
113                }
114
115                #[cfg(debug_assertions)]
116                log::trace!("Worker {} 退出", worker_id);
117            });
118        }
119
120        let mut stats_interval = if cfg!(debug_assertions) {
121            Some(tokio::time::interval(std::time::Duration::from_secs(60)))
122        } else {
123            None
124        };
125        if let Some(ref mut interval) = stats_interval {
126            interval.tick().await;
127        }
128
129        let shutdown = self.shutdown.clone();
130        loop {
131            tokio::select! {
132                _ = shutdown.cancelled() => {
133                    #[cfg(debug_assertions)]
134                    log::trace!("MetadataScheduler 主循环收到关闭信号,退出");
135                    break;
136                }
137                result = self.hash_rx.recv() => {
138                    match result {
139                        Some(hash) => {
140                            self.total_received.fetch_add(1, Ordering::Relaxed);
141
142                            match task_tx.try_send(hash) {
143                                Ok(_) => {
144                                    self.queue_len.fetch_add(1, Ordering::Relaxed);
145                                }
146                                Err(async_channel::TrySendError::Full(_)) => {
147                                    self.total_dropped.fetch_add(1, Ordering::Relaxed);
148                                }
149                                Err(_) => break,
150                            }
151                        }
152                        None => break,
153                    }
154                }
155                _ = async {
156                    match stats_interval.as_mut() {
157                        Some(interval) => interval.tick().await,
158                        None => std::future::pending().await,
159                    }
160                } => {
161                    self.print_stats_inline();
162                }
163            }
164        }
165
166        drop(task_tx);
167        #[cfg(debug_assertions)]
168        log::trace!("MetadataScheduler 主循环退出,等待 worker 任务完成");
169    }
170
171    async fn process_hash(
172        hash: HashDiscovered,
173        fetcher: &Arc<RbitFetcher>,
174        callback: &Arc<RwLock<Option<TorrentCallback>>>,
175        on_metadata_fetch: &Arc<RwLock<Option<MetadataFetchCallback>>>,
176    ) {
177        let info_hash = hash.info_hash.clone();
178        let peer_addr = hash.peer_addr;
179
180        let maybe_check_fn = {
181            match on_metadata_fetch.read() {
182                Ok(guard) => guard.clone(),
183                Err(_) => return,
184            }
185        };
186
187        if let Some(f) = maybe_check_fn
188            && !f(info_hash.clone()).await
189        {
190            return;
191        }
192
193        let info_hash_bytes: [u8; 20] = match hex::decode(&info_hash) {
194            Ok(bytes) if bytes.len() == 20 => {
195                let mut arr = [0u8; 20];
196                arr.copy_from_slice(&bytes);
197                arr
198            }
199            _ => return,
200        };
201
202        if let Some((name, total_size, files, piece_length)) =
203            fetcher.fetch(&info_hash_bytes, peer_addr).await
204        {
205            let metadata = TorrentInfo {
206                info_hash,
207                name,
208                total_size,
209                files,
210                magnet_link: format!("magnet:?xt=urn:btih:{}", hash.info_hash),
211                peers: vec![peer_addr.to_string()],
212                piece_length,
213                timestamp: std::time::SystemTime::now()
214                    .duration_since(std::time::UNIX_EPOCH)
215                    .unwrap_or_default()
216                    .as_secs(),
217            };
218
219            let maybe_torrent_cb = {
220                match callback.read() {
221                    Ok(guard) => guard.clone(),
222                    Err(_) => return,
223                }
224            };
225
226            if let Some(cb) = maybe_torrent_cb {
227                cb(metadata);
228            }
229        }
230    }
231
232    fn print_stats_inline(&self) {
233        #[cfg(debug_assertions)]
234        {
235            let received = self.total_received.load(Ordering::Relaxed);
236            let dropped = self.total_dropped.load(Ordering::Relaxed);
237            let dispatched = self.total_dispatched.load(Ordering::Relaxed);
238
239            let drop_rate = if received > 0 {
240                dropped as f64 / received as f64 * 100.0
241            } else {
242                0.0
243            };
244
245            let queue_len = self.queue_len.load(Ordering::Relaxed);
246            let queue_pressure = (queue_len as f64 / self.max_queue_size as f64) * 100.0;
247
248            if queue_pressure > 80.0 {
249                log::warn!(
250                    "Metadata 队列高压:队列={}/{}({:.1}%), 接收={}, 调度={}, 丢弃={}({:.2}%)",
251                    queue_len,
252                    self.max_queue_size,
253                    queue_pressure,
254                    received,
255                    dispatched,
256                    dropped,
257                    drop_rate
258                );
259            } else {
260                log::info!(
261                    "Metadata 调度器统计:队列={}/{}({:.1}%), 接收={}, 调度={}, 丢弃={}({:.2}%)",
262                    queue_len,
263                    self.max_queue_size,
264                    queue_pressure,
265                    received,
266                    dispatched,
267                    dropped,
268                    drop_rate
269                );
270            }
271        }
272    }
273}