1use crate::metadata::RbitFetcher;
2use crate::server::HashDiscovered;
3use crate::types::TorrentInfo;
4use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
5use std::sync::{Arc, RwLock};
6#[cfg(debug_assertions)]
7use std::time::Duration;
8use tokio::sync::{Mutex, mpsc};
9use tokio_util::sync::CancellationToken;
10
11type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
12type MetadataFetchCallback = Arc<
13 dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
14 + Send
15 + Sync,
16>;
17
18pub struct MetadataScheduler {
19 hash_rx: mpsc::Receiver<HashDiscovered>,
20 max_queue_size: usize,
21 max_concurrent: usize,
22 fetcher: Arc<RbitFetcher>,
23 callback: Arc<RwLock<Option<TorrentCallback>>>,
24 on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
25 total_received: Arc<AtomicU64>,
26 total_dropped: Arc<AtomicU64>,
27 total_dispatched: Arc<AtomicU64>,
28 queue_len: Arc<AtomicUsize>,
29 shutdown: CancellationToken,
30}
31
32impl MetadataScheduler {
33 #[allow(clippy::too_many_arguments)]
34 pub fn new(
35 hash_rx: mpsc::Receiver<HashDiscovered>,
36 fetcher: Arc<RbitFetcher>,
37 max_queue_size: usize,
38 max_concurrent: usize,
39 callback: Arc<RwLock<Option<TorrentCallback>>>,
40 on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
41 queue_len: Arc<AtomicUsize>,
42 shutdown: CancellationToken,
43 ) -> Self {
44 Self {
45 hash_rx,
46 max_queue_size,
47 max_concurrent,
48 fetcher,
49 callback,
50 on_metadata_fetch,
51 total_received: Arc::new(AtomicU64::new(0)),
52 total_dropped: Arc::new(AtomicU64::new(0)),
53 total_dispatched: Arc::new(AtomicU64::new(0)),
54 queue_len,
55 shutdown,
56 }
57 }
58
59 pub fn set_callback(&mut self, callback: TorrentCallback) {
60 if let Ok(mut guard) = self.callback.try_write() {
61 *guard = Some(callback);
62 }
63 }
64
65 pub fn set_metadata_fetch_callback(&mut self, callback: MetadataFetchCallback) {
66 if let Ok(mut guard) = self.on_metadata_fetch.try_write() {
67 *guard = Some(callback);
68 }
69 }
70
71 pub async fn run(mut self) {
72 let (task_tx, task_rx) = mpsc::channel::<HashDiscovered>(self.max_queue_size);
73 let task_rx = Arc::new(Mutex::new(task_rx));
74
75 let shutdown = self.shutdown.clone();
76 #[cfg_attr(not(debug_assertions), allow(unused_variables))]
77 for worker_id in 0..self.max_concurrent {
78 let task_rx = task_rx.clone();
79 let fetcher = self.fetcher.clone();
80 let callback = self.callback.clone();
81 let on_metadata_fetch = self.on_metadata_fetch.clone();
82 let total_dispatched = self.total_dispatched.clone();
83 let queue_len = self.queue_len.clone();
84 let shutdown_worker = shutdown.clone();
85
86 tokio::spawn(async move {
87 #[cfg(debug_assertions)]
88 log::trace!("Worker {} 启动", worker_id);
89
90 loop {
91 tokio::select! {
92 _ = shutdown_worker.cancelled() => {
93 #[cfg(debug_assertions)]
94 log::trace!("Worker {} 收到关闭信号,退出", worker_id);
95 break;
96 }
97 result = async {
98 let mut rx = task_rx.lock().await;
99 rx.recv().await
100 } => {
101 let hash = {
102 if result.is_some() {
103 queue_len.fetch_sub(1, Ordering::Relaxed);
104 }
105 result
106 };
107
108 let hash = match hash {
109 Some(h) => h,
110 None => break,
111 };
112
113 total_dispatched.fetch_add(1, Ordering::Relaxed);
114
115 Self::process_hash(
116 hash,
117 &fetcher,
118 &callback,
119 &on_metadata_fetch,
120 ).await;
121 }
122 }
123 }
124
125 #[cfg(debug_assertions)]
126 log::trace!("Worker {} 退出", worker_id);
127 });
128 }
129
130 #[cfg(debug_assertions)]
131 let mut stats_interval = tokio::time::interval(Duration::from_secs(60));
132 #[cfg(debug_assertions)]
133 stats_interval.tick().await;
134
135 let shutdown = self.shutdown.clone();
136 loop {
137 #[cfg(debug_assertions)]
138 {
139 tokio::select! {
140 _ = shutdown.cancelled() => {
141 #[cfg(debug_assertions)]
142 log::trace!("MetadataScheduler 主循环收到关闭信号,退出");
143 break;
144 }
145 Some(hash) = self.hash_rx.recv() => {
146 self.total_received.fetch_add(1, Ordering::Relaxed);
147
148 match task_tx.try_send(hash) {
149 Ok(_) => {
150 self.queue_len.fetch_add(1, Ordering::Relaxed);
151 }
152 Err(mpsc::error::TrySendError::Full(_)) => {
153 self.total_dropped.fetch_add(1, Ordering::Relaxed);
154 }
155 Err(_) => break,
156 }
157 }
158
159 _ = stats_interval.tick() => {
160 self.print_stats(&task_tx);
161 }
162
163 else => break,
164 }
165 }
166
167 #[cfg(not(debug_assertions))]
168 {
169 tokio::select! {
170 _ = shutdown.cancelled() => {
171 #[cfg(debug_assertions)]
172 log::trace!("MetadataScheduler 主循环收到关闭信号,退出");
173 break;
174 }
175 result = self.hash_rx.recv() => {
176 match result {
177 Some(hash) => {
178 self.total_received.fetch_add(1, Ordering::Relaxed);
179
180 match task_tx.try_send(hash) {
181 Ok(_) => {
182 self.queue_len.fetch_add(1, Ordering::Relaxed);
183 }
184 Err(mpsc::error::TrySendError::Full(_)) => {
185 self.total_dropped.fetch_add(1, Ordering::Relaxed);
186 }
187 Err(_) => break,
188 }
189 }
190 None => break,
191 }
192 }
193 }
194 }
195 }
196
197 drop(task_tx);
199 #[cfg(debug_assertions)]
200 log::trace!("MetadataScheduler 主循环退出,等待 worker 任务完成");
201 }
202
203 async fn process_hash(
204 hash: HashDiscovered,
205 fetcher: &Arc<RbitFetcher>,
206 callback: &Arc<RwLock<Option<TorrentCallback>>>,
207 on_metadata_fetch: &Arc<RwLock<Option<MetadataFetchCallback>>>,
208 ) {
209 let info_hash = hash.info_hash.clone();
210 let peer_addr = hash.peer_addr;
211
212 let maybe_check_fn = {
213 match on_metadata_fetch.read() {
214 Ok(guard) => guard.clone(),
215 Err(_) => return,
216 }
217 };
218
219 if let Some(f) = maybe_check_fn
220 && !f(info_hash.clone()).await {
221 return;
222 }
223
224 let info_hash_bytes: [u8; 20] = match hex::decode(&info_hash) {
225 Ok(bytes) if bytes.len() == 20 => {
226 let mut arr = [0u8; 20];
227 arr.copy_from_slice(&bytes);
228 arr
229 }
230 _ => return,
231 };
232
233 if let Some((name, total_size, files)) = fetcher.fetch(&info_hash_bytes, peer_addr).await {
234 let metadata = TorrentInfo {
235 info_hash,
236 name,
237 total_size,
238 files,
239 magnet_link: format!("magnet:?xt=urn:btih:{}", hash.info_hash),
240 peers: vec![peer_addr.to_string()],
241 piece_length: 0,
242 timestamp: std::time::SystemTime::now()
243 .duration_since(std::time::UNIX_EPOCH)
244 .unwrap()
245 .as_secs(),
246 };
247
248 let maybe_torrent_cb = {
249 match callback.read() {
250 Ok(guard) => guard.clone(),
251 Err(_) => return,
252 }
253 };
254
255 if let Some(cb) = maybe_torrent_cb {
256 cb(metadata);
257 }
258 }
259 }
260
261 #[cfg(debug_assertions)]
262 fn print_stats(&self, task_tx: &mpsc::Sender<HashDiscovered>) {
263 let received = self.total_received.load(Ordering::Relaxed);
264 let dropped = self.total_dropped.load(Ordering::Relaxed);
265 let dispatched = self.total_dispatched.load(Ordering::Relaxed);
266
267 let drop_rate = if received > 0 {
268 dropped as f64 / received as f64 * 100.0
269 } else {
270 0.0
271 };
272
273 let queue_size = self.max_queue_size - task_tx.capacity();
274 let queue_pressure = (queue_size as f64 / self.max_queue_size as f64) * 100.0;
275
276 if queue_pressure > 80.0 {
277 log::warn!(
278 "⚠️ Metadata 队列高压:队列={}/{}({:.1}%), 接收={}, 调度={}, 丢弃={}({:.2}%)",
279 queue_size,
280 self.max_queue_size,
281 queue_pressure,
282 received,
283 dispatched,
284 dropped,
285 drop_rate
286 );
287 } else {
288 log::info!(
289 "📊 Metadata 调度器统计:队列={}/{}({:.1}%), 接收={}, 调度={}, 丢弃={}({:.2}%)",
290 queue_size,
291 self.max_queue_size,
292 queue_pressure,
293 received,
294 dispatched,
295 dropped,
296 drop_rate
297 );
298 }
299 }
300}