
dht_crawler/scheduler.rs

use crate::metadata::RbitFetcher;
use crate::server::HashDiscovered;
use crate::types::TorrentInfo;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
#[cfg(debug_assertions)]
use std::time::Duration;
use tokio::sync::{Mutex, mpsc};
use tokio_util::sync::CancellationToken;

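/// Callback invoked with each successfully fetched [`TorrentInfo`].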
type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
type MetadataFetchCallback = Arc<
    dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
        + Send
        + Sync,
>;

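/// Receives info-hashes discovered by the DHT crawler, queues them in a bounded
/// channel, and dispatches them to a pool of worker tasks that fetch torrent metadata.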
pub struct MetadataScheduler {
    hash_rx: mpsc::Receiver<HashDiscovered>,
    max_queue_size: usize,
    max_concurrent: usize,
    fetcher: Arc<RbitFetcher>,
    callback: Arc<RwLock<Option<TorrentCallback>>>,
    on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
    total_received: Arc<AtomicU64>,
    total_dropped: Arc<AtomicU64>,
    total_dispatched: Arc<AtomicU64>,
    queue_len: Arc<AtomicUsize>,
    shutdown: CancellationToken,
}

impl MetadataScheduler {
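    /// Creates a scheduler wired to the given hash channel, fetcher, shared callbacks, and counters.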
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        hash_rx: mpsc::Receiver<HashDiscovered>,
        fetcher: Arc<RbitFetcher>,
        max_queue_size: usize,
        max_concurrent: usize,
        callback: Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
        queue_len: Arc<AtomicUsize>,
        shutdown: CancellationToken,
    ) -> Self {
        Self {
            hash_rx,
            max_queue_size,
            max_concurrent,
            fetcher,
            callback,
            on_metadata_fetch,
            total_received: Arc::new(AtomicU64::new(0)),
            total_dropped: Arc::new(AtomicU64::new(0)),
            total_dispatched: Arc::new(AtomicU64::new(0)),
            queue_len,
            shutdown,
        }
    }

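    /// Registers the callback invoked for each fetched torrent; a no-op if the lock cannot be acquired.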
    pub fn set_callback(&mut self, callback: TorrentCallback) {
        if let Ok(mut guard) = self.callback.try_write() {
            *guard = Some(callback);
        }
    }

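    /// Registers the pre-fetch filter consulted before each metadata fetch; a no-op if the lock cannot be acquired.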
    pub fn set_metadata_fetch_callback(&mut self, callback: MetadataFetchCallback) {
        if let Ok(mut guard) = self.on_metadata_fetch.try_write() {
            *guard = Some(callback);
        }
    }

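    /// Runs the scheduler: spawns `max_concurrent` worker tasks, then forwards discovered
    /// hashes into the bounded task queue, dropping hashes when the queue is full, until
    /// shutdown is requested or the input channel closes.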
    pub async fn run(mut self) {
        let (task_tx, task_rx) = mpsc::channel::<HashDiscovered>(self.max_queue_size);
        let task_rx = Arc::new(Mutex::new(task_rx));

        let shutdown = self.shutdown.clone();
        #[cfg_attr(not(debug_assertions), allow(unused_variables))]
        for worker_id in 0..self.max_concurrent {
            let task_rx = task_rx.clone();
            let fetcher = self.fetcher.clone();
            let callback = self.callback.clone();
            let on_metadata_fetch = self.on_metadata_fetch.clone();
            let total_dispatched = self.total_dispatched.clone();
            let queue_len = self.queue_len.clone();
            let shutdown_worker = shutdown.clone();

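            // Each worker repeatedly pulls one discovered hash from the shared queue and
            // fetches its metadata, until the queue closes or shutdown is signalled.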
            tokio::spawn(async move {
                #[cfg(debug_assertions)]
                log::trace!("Worker {} started", worker_id);

                loop {
                    tokio::select! {
                        _ = shutdown_worker.cancelled() => {
                            #[cfg(debug_assertions)]
                            log::trace!("Worker {} received shutdown signal, exiting", worker_id);
                            break;
                        }
                        result = async {
                            let mut rx = task_rx.lock().await;
                            rx.recv().await
                        } => {
                            let hash = {
                                if result.is_some() {
                                    queue_len.fetch_sub(1, Ordering::Relaxed);
                                }
                                result
                            };

                            let hash = match hash {
                                Some(h) => h,
                                None => break,
                            };

                            total_dispatched.fetch_add(1, Ordering::Relaxed);

                            Self::process_hash(
                                hash,
                                &fetcher,
                                &callback,
                                &on_metadata_fetch,
                            ).await;
                        }
                    }
                }

                #[cfg(debug_assertions)]
                log::trace!("Worker {} exited", worker_id);
            });
        }

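        // Periodic statistics reporting is compiled in only for debug builds.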
        #[cfg(debug_assertions)]
        let mut stats_interval = tokio::time::interval(Duration::from_secs(60));
        #[cfg(debug_assertions)]
        stats_interval.tick().await;

        let shutdown = self.shutdown.clone();
        loop {
            #[cfg(debug_assertions)]
            {
                tokio::select! {
                    _ = shutdown.cancelled() => {
                        #[cfg(debug_assertions)]
                        log::trace!("MetadataScheduler main loop received shutdown signal, exiting");
                        break;
                    }
                    Some(hash) = self.hash_rx.recv() => {
                        self.total_received.fetch_add(1, Ordering::Relaxed);

                        match task_tx.try_send(hash) {
                            Ok(_) => {
                                self.queue_len.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(mpsc::error::TrySendError::Full(_)) => {
                                self.total_dropped.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(_) => break,
                        }
                    }

                    _ = stats_interval.tick() => {
                        self.print_stats(&task_tx);
                    }

                    else => break,
                }
            }

            #[cfg(not(debug_assertions))]
            {
                tokio::select! {
                    _ = shutdown.cancelled() => {
                        #[cfg(debug_assertions)]
                        log::trace!("MetadataScheduler main loop received shutdown signal, exiting");
                        break;
                    }
                    result = self.hash_rx.recv() => {
                        match result {
                            Some(hash) => {
                                self.total_received.fetch_add(1, Ordering::Relaxed);

                                match task_tx.try_send(hash) {
                                    Ok(_) => {
                                        self.queue_len.fetch_add(1, Ordering::Relaxed);
                                    }
                                    Err(mpsc::error::TrySendError::Full(_)) => {
                                        self.total_dropped.fetch_add(1, Ordering::Relaxed);
                                    }
                                    Err(_) => break,
                                }
                            }
                            None => break,
                        }
                    }
                }
            }
        }

        // Explicitly drop task_tx so that all worker tasks can exit
        drop(task_tx);
        #[cfg(debug_assertions)]
        log::trace!("MetadataScheduler main loop exited, waiting for worker tasks to finish");
    }

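    /// Fetches metadata for a single discovered hash: applies the optional pre-fetch
    /// filter, decodes the hex info-hash, asks the fetcher for the torrent's metadata,
    /// and forwards the result to the registered torrent callback.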
    async fn process_hash(
        hash: HashDiscovered,
        fetcher: &Arc<RbitFetcher>,
        callback: &Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: &Arc<RwLock<Option<MetadataFetchCallback>>>,
    ) {
        let info_hash = hash.info_hash.clone();
        let peer_addr = hash.peer_addr;

        let maybe_check_fn = {
            match on_metadata_fetch.read() {
                Ok(guard) => guard.clone(),
                Err(_) => return,
            }
        };

        if let Some(f) = maybe_check_fn
            && !f(info_hash.clone()).await
        {
            return;
        }

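        // Decode the hex-encoded info-hash into a raw 20-byte array; skip it if malformed.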
        let info_hash_bytes: [u8; 20] = match hex::decode(&info_hash) {
            Ok(bytes) if bytes.len() == 20 => {
                let mut arr = [0u8; 20];
                arr.copy_from_slice(&bytes);
                arr
            }
            _ => return,
        };

        if let Some((name, total_size, files)) = fetcher.fetch(&info_hash_bytes, peer_addr).await {
            let metadata = TorrentInfo {
                info_hash,
                name,
                total_size,
                files,
                magnet_link: format!("magnet:?xt=urn:btih:{}", hash.info_hash),
                peers: vec![peer_addr.to_string()],
                piece_length: 0,
                timestamp: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            };

            let maybe_torrent_cb = {
                match callback.read() {
                    Ok(guard) => guard.clone(),
                    Err(_) => return,
                }
            };

            if let Some(cb) = maybe_torrent_cb {
                cb(metadata);
            }
        }
    }

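    /// Logs queue occupancy and throughput counters, warning when queue pressure exceeds 80% (debug builds only).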
    #[cfg(debug_assertions)]
    fn print_stats(&self, task_tx: &mpsc::Sender<HashDiscovered>) {
        let received = self.total_received.load(Ordering::Relaxed);
        let dropped = self.total_dropped.load(Ordering::Relaxed);
        let dispatched = self.total_dispatched.load(Ordering::Relaxed);

        let drop_rate = if received > 0 {
            dropped as f64 / received as f64 * 100.0
        } else {
            0.0
        };

        let queue_size = self.max_queue_size - task_tx.capacity();
        let queue_pressure = (queue_size as f64 / self.max_queue_size as f64) * 100.0;

        if queue_pressure > 80.0 {
            log::warn!(
                "⚠️ Metadata queue under high pressure: queue={}/{} ({:.1}%), received={}, dispatched={}, dropped={} ({:.2}%)",
                queue_size,
                self.max_queue_size,
                queue_pressure,
                received,
                dispatched,
                dropped,
                drop_rate
            );
        } else {
            log::info!(
                "📊 Metadata scheduler stats: queue={}/{} ({:.1}%), received={}, dispatched={}, dropped={} ({:.2}%)",
                queue_size,
                self.max_queue_size,
                queue_pressure,
                received,
                dispatched,
                dropped,
                drop_rate
            );
        }
    }
}