dht_crawler/scheduler.rs
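
//! Metadata-fetch scheduler: receives info-hashes discovered by the DHT
//! server, buffers them in a bounded queue, and fans them out to a pool of
//! workers that fetch torrent metadata.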

use crate::server::HashDiscovered;
use crate::types::TorrentInfo;
use crate::metadata::RbitFetcher;
use std::sync::{Arc, RwLock};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use tokio::sync::{mpsc, Mutex};
#[cfg(debug_assertions)]
use std::time::Duration;

/// Callback invoked with the metadata of a successfully fetched torrent.
type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
/// Async pre-fetch check: given an info-hash (hex string), resolves to `true`
/// if the metadata should be fetched.
type MetadataFetchCallback = Arc<dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> + Send + Sync>;

/// Metadata scheduler (worker pool + channel design).
/// Manages the metadata fetch queue and dispatches fetch tasks to the workers.
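///
/// # Example (sketch)
///
/// A minimal wiring sketch, assuming the `RbitFetcher` and the hash channel
/// are built elsewhere in the crate; the buffer sizes are illustrative:
///
/// ```ignore
/// use std::sync::{Arc, RwLock};
/// use std::sync::atomic::AtomicUsize;
///
/// let (hash_tx, hash_rx) = tokio::sync::mpsc::channel(1024); // hash_tx goes to the DHT server
/// let queue_len = Arc::new(AtomicUsize::new(0));
/// let scheduler = MetadataScheduler::new(
///     hash_rx,
///     fetcher,                     // Arc<RbitFetcher>, built elsewhere
///     1024,                        // max_queue_size
///     64,                          // max_concurrent workers
///     Arc::new(RwLock::new(None)), // torrent callback slot
///     Arc::new(RwLock::new(None)), // pre-fetch check slot
///     queue_len.clone(),
/// );
/// tokio::spawn(scheduler.run());
/// ```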
pub struct MetadataScheduler {
    // Input channel
    hash_rx: mpsc::Receiver<HashDiscovered>,

    // Configuration
    max_queue_size: usize,
    max_concurrent: usize,

    // Metadata fetcher
    fetcher: Arc<RbitFetcher>,

    // Callbacks
    callback: Arc<RwLock<Option<TorrentCallback>>>,
    on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,

    // Statistics (atomics so they can be read from multiple threads)
    total_received: Arc<AtomicU64>,
    total_dropped: Arc<AtomicU64>,
    total_dispatched: Arc<AtomicU64>,

    // Shared queue-length counter (reports backpressure back to the server)
    queue_len: Arc<AtomicUsize>,
}

impl MetadataScheduler {
    pub fn new(
        hash_rx: mpsc::Receiver<HashDiscovered>,
        fetcher: Arc<RbitFetcher>,
        max_queue_size: usize,
        max_concurrent: usize,
        callback: Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
        queue_len: Arc<AtomicUsize>, // shared with the server for backpressure feedback
    ) -> Self {
        Self {
            hash_rx,
            max_queue_size,
            max_concurrent,
            fetcher,
            callback,
            on_metadata_fetch,
            total_received: Arc::new(AtomicU64::new(0)),
            total_dropped: Arc::new(AtomicU64::new(0)),
            total_dispatched: Arc::new(AtomicU64::new(0)),
            queue_len,
        }
    }

    /// Set the torrent callback. Best-effort: if the lock is currently held,
    /// the new callback is silently ignored.
    pub fn set_callback(&mut self, callback: TorrentCallback) {
        if let Ok(mut guard) = self.callback.try_write() {
            *guard = Some(callback);
        }
    }

    /// Set the pre-fetch check callback, consulted before fetching metadata
    /// for a discovered hash. Best-effort, like [`Self::set_callback`].
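    ///
    /// # Example (sketch)
    ///
    /// Wrapping an async check in the boxed-future callback type. The `seen`
    /// set here is a hypothetical stand-in for a real dedup store (e.g. a
    /// database lookup), and `scheduler` is a `MetadataScheduler` built as in
    /// the struct-level example:
    ///
    /// ```ignore
    /// use std::collections::HashSet;
    /// use std::sync::Arc;
    ///
    /// let seen: Arc<tokio::sync::Mutex<HashSet<String>>> = Default::default();
    /// scheduler.set_metadata_fetch_callback(Arc::new(move |info_hash: String| {
    ///     let seen = seen.clone();
    ///     Box::pin(async move {
    ///         // Fetch only hashes we have not seen before.
    ///         seen.lock().await.insert(info_hash)
    ///     }) as std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
    /// }));
    /// ```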
    pub fn set_metadata_fetch_callback(&mut self, callback: MetadataFetchCallback) {
        if let Ok(mut guard) = self.on_metadata_fetch.try_write() {
            *guard = Some(callback);
        }
    }

    /// Run the scheduler (fully event-driven).
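    ///
    /// Spawns `max_concurrent` workers, then loops forwarding each received
    /// `HashDiscovered` to a bounded worker queue with `try_send`; when that
    /// queue is full the hash is dropped (and counted in `total_dropped`)
    /// rather than blocking the DHT server. Typically driven as
    /// `tokio::spawn(scheduler.run())`.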
    pub async fn run(mut self) {
        // Create the task queue (the bounded channel provides backpressure)
        let (task_tx, task_rx) = mpsc::channel::<HashDiscovered>(self.max_queue_size);
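        // tokio's mpsc::Receiver is single-consumer, so the workers share one
        // receiver behind an async Mutex; each worker holds the lock only long
        // enough to pull a single task.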
        let task_rx = Arc::new(Mutex::new(task_rx));

        // Start the worker pool
        for worker_id in 0..self.max_concurrent {
            let task_rx = task_rx.clone();
            let fetcher = self.fetcher.clone();
            let callback = self.callback.clone();
            let on_metadata_fetch = self.on_metadata_fetch.clone();
            let total_dispatched = self.total_dispatched.clone();
            let queue_len = self.queue_len.clone(); // hand the shared counter to each worker

            tokio::spawn(async move {
                log::trace!("Worker {} started", worker_id);

                loop {
                    // Take a task from the queue (awaits until one is available, no polling delay)
                    let hash = {
                        let mut rx = task_rx.lock().await;
                        let h = rx.recv().await;
                        // Decrement the shared counter once a task has been taken
                        if h.is_some() {
                            queue_len.fetch_sub(1, Ordering::Relaxed);
                        }
                        h
                    };

                    let hash = match hash {
                        Some(h) => h,
                        None => break, // channel closed, worker exits
                    };

                    total_dispatched.fetch_add(1, Ordering::Relaxed);

                    // Process the task
                    Self::process_hash(
                        hash,
                        &fetcher,
                        &callback,
                        &on_metadata_fetch,
                    ).await;
                }

                log::trace!("Worker {} exited", worker_id);
            });
        }

        // Main loop: receive hashes and forward them to the worker queue
        #[cfg(debug_assertions)]
        let mut stats_interval = tokio::time::interval(Duration::from_secs(60));
        // The first tick completes immediately; consume it so the first stats
        // report comes after a full interval.
        #[cfg(debug_assertions)]
        stats_interval.tick().await;

        loop {
            #[cfg(debug_assertions)]
            {
                tokio::select! {
                    maybe_hash = self.hash_rx.recv() => {
                        let hash = match maybe_hash {
                            Some(h) => h,
                            None => break, // channel closed
                        };
                        self.total_received.fetch_add(1, Ordering::Relaxed);

                        // Try to push onto the worker queue
                        match task_tx.try_send(hash) {
                            Ok(_) => {
                                // Enqueued successfully; bump the shared counter
                                self.queue_len.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(mpsc::error::TrySendError::Full(_)) => {
                                // Queue full; drop the hash
                                self.total_dropped.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(_) => break, // channel closed
                        }
                    }

                    _ = stats_interval.tick() => {
                        self.print_stats(&task_tx);
                    }
                }
            }

            #[cfg(not(debug_assertions))]
            {
                match self.hash_rx.recv().await {
                    Some(hash) => {
                        self.total_received.fetch_add(1, Ordering::Relaxed);

                        // Try to push onto the worker queue
                        match task_tx.try_send(hash) {
                            Ok(_) => {
                                // Enqueued successfully; bump the shared counter
                                self.queue_len.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(mpsc::error::TrySendError::Full(_)) => {
                                // Queue full; drop the hash
                                self.total_dropped.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(_) => break, // channel closed
                        }
                    }
                    None => break, // channel closed
                }
            }
        }
    }

    /// Process a single hash (called by the workers).
    async fn process_hash(
        hash: HashDiscovered,
        fetcher: &Arc<RbitFetcher>,
        callback: &Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: &Arc<RwLock<Option<MetadataFetchCallback>>>,
    ) {
        let info_hash = hash.info_hash.clone();
        let peer_addr = hash.peer_addr;

        // Check whether this hash should be fetched (take a snapshot of the
        // callback so the lock is released before awaiting)
        let maybe_check_fn = {
            match on_metadata_fetch.read() {
                Ok(guard) => guard.clone(),
                Err(_) => return, // lock poisoned
            }
        };

        if let Some(f) = maybe_check_fn {
            if !f(info_hash.clone()).await {
                return;
            }
        }

        // Decode the hex info_hash into raw bytes
        let info_hash_bytes: [u8; 20] = match hex::decode(&info_hash) {
            Ok(bytes) if bytes.len() == 20 => {
                let mut arr = [0u8; 20];
                arr.copy_from_slice(&bytes);
                arr
            }
            _ => return,
        };

        // Fetch the metadata
        if let Some((name, total_size, files)) = fetcher.fetch(&info_hash_bytes, peer_addr).await {
            let metadata = TorrentInfo {
                info_hash,
                name,
                total_size,
                files,
                magnet_link: format!("magnet:?xt=urn:btih:{}", hash.info_hash),
                peers: vec![peer_addr.to_string()],
                piece_length: 0,
                timestamp: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            };

            // Take a snapshot of the callback and release the lock
            let maybe_torrent_cb = {
                match callback.read() {
                    Ok(guard) => guard.clone(),
                    Err(_) => return, // lock poisoned
                }
            };

            if let Some(cb) = maybe_torrent_cb {
                cb(metadata);
            }
        }
    }

    /// Print statistics (compiled only in debug builds).
    #[cfg(debug_assertions)]
    fn print_stats(&self, task_tx: &mpsc::Sender<HashDiscovered>) {
        let received = self.total_received.load(Ordering::Relaxed);
        let dropped = self.total_dropped.load(Ordering::Relaxed);
        let dispatched = self.total_dispatched.load(Ordering::Relaxed);

        let drop_rate = if received > 0 {
            dropped as f64 / received as f64 * 100.0
        } else {
            0.0
        };

        // `Sender::capacity()` is the number of free slots, so the difference
        // from the buffer size approximates the current queue length.
        let queue_size = self.max_queue_size - task_tx.capacity();
        let queue_pressure = (queue_size as f64 / self.max_queue_size as f64) * 100.0;

        // Choose the log level based on queue pressure
        if queue_pressure > 80.0 {
            log::warn!(
                "⚠️ Metadata queue under pressure: queue={}/{} ({:.1}%), received={}, dispatched={}, dropped={} ({:.2}%)",
                queue_size,
                self.max_queue_size,
                queue_pressure,
                received,
                dispatched,
                dropped,
                drop_rate
            );
        } else {
            log::info!(
                "📊 Metadata scheduler stats: queue={}/{} ({:.1}%), received={}, dispatched={}, dropped={} ({:.2}%)",
                queue_size,
                self.max_queue_size,
                queue_pressure,
                received,
                dispatched,
                dropped,
                drop_rate
            );
        }
    }
}