use crate::server::HashDiscovered;
use crate::types::TorrentInfo;
use crate::metadata::RbitFetcher;
use std::sync::{Arc, RwLock};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use tokio::sync::{mpsc, Mutex};
#[cfg(debug_assertions)]
use std::time::Duration;

type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
type MetadataFetchCallback =
    Arc<dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> + Send + Sync>;
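
/// Schedules metadata fetches for info-hashes discovered by the DHT server.
///
/// Hashes arrive on `hash_rx`, pass through a bounded internal queue, and are
/// drained by a fixed pool of worker tasks that drive [`RbitFetcher`]. When
/// the queue is full, new hashes are dropped rather than applying
/// backpressure to the discovery path.
///
/// Illustrative usage sketch (construction of `fetcher` is assumed from the
/// surrounding crate):
///
/// ```ignore
/// let (tx, rx) = tokio::sync::mpsc::channel(1024);
/// let scheduler = MetadataScheduler::new(
///     rx,
///     fetcher,                          // Arc<RbitFetcher>
///     1024,                             // max_queue_size
///     8,                                // max_concurrent
///     Arc::new(RwLock::new(None)),      // torrent callback slot
///     Arc::new(RwLock::new(None)),      // metadata-fetch filter slot
///     Arc::new(AtomicUsize::new(0)),    // shared queue-length gauge
/// );
/// tokio::spawn(scheduler.run());
/// ```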
pub struct MetadataScheduler {
    /// Incoming info-hashes from the discovery layer.
    hash_rx: mpsc::Receiver<HashDiscovered>,

    /// Capacity of the internal task queue.
    max_queue_size: usize,
    /// Number of concurrent fetch workers.
    max_concurrent: usize,

    /// Fetches torrent metadata from peers.
    fetcher: Arc<RbitFetcher>,

    /// Invoked with the parsed `TorrentInfo` after a successful fetch.
    callback: Arc<RwLock<Option<TorrentCallback>>>,
    /// Async filter consulted before each fetch; returning `false` skips the hash.
    on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,

    /// Counters: hashes received, dropped (queue full), and dispatched to workers.
    total_received: Arc<AtomicU64>,
    total_dropped: Arc<AtomicU64>,
    total_dispatched: Arc<AtomicU64>,

    /// Current depth of the internal queue, shared with the owner for monitoring.
    queue_len: Arc<AtomicUsize>,
}

impl MetadataScheduler {
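    /// Builds a scheduler; nothing runs until [`run`](Self::run) is awaited.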
    pub fn new(
        hash_rx: mpsc::Receiver<HashDiscovered>,
        fetcher: Arc<RbitFetcher>,
        max_queue_size: usize,
        max_concurrent: usize,
        callback: Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
        queue_len: Arc<AtomicUsize>,
    ) -> Self {
        Self {
            hash_rx,
            max_queue_size,
            max_concurrent,
            fetcher,
            callback,
            on_metadata_fetch,
            total_received: Arc::new(AtomicU64::new(0)),
            total_dropped: Arc::new(AtomicU64::new(0)),
            total_dispatched: Arc::new(AtomicU64::new(0)),
            queue_len,
        }
    }
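
    /// Registers the callback invoked for each successfully fetched torrent.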
    pub fn set_callback(&mut self, callback: TorrentCallback) {
        // Best-effort: if the lock is contended or poisoned, the update is skipped.
        if let Ok(mut guard) = self.callback.try_write() {
            *guard = Some(callback);
        }
    }
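
    /// Registers the async filter consulted before fetching each hash.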
    pub fn set_metadata_fetch_callback(&mut self, callback: MetadataFetchCallback) {
        // Best-effort: if the lock is contended or poisoned, the update is skipped.
        if let Ok(mut guard) = self.on_metadata_fetch.try_write() {
            *guard = Some(callback);
        }
    }
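
    /// Runs the scheduler: spawns the worker pool, then forwards incoming
    /// hashes into the bounded queue until the input channel closes.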
    pub async fn run(mut self) {
        let (task_tx, task_rx) = mpsc::channel::<HashDiscovered>(self.max_queue_size);
        let task_rx = Arc::new(Mutex::new(task_rx));

        // Spawn the worker pool. Each worker pulls hashes off the shared
        // queue and fetches metadata until the queue channel closes.
        for worker_id in 0..self.max_concurrent {
            let task_rx = task_rx.clone();
            let fetcher = self.fetcher.clone();
            let callback = self.callback.clone();
            let on_metadata_fetch = self.on_metadata_fetch.clone();
            let total_dispatched = self.total_dispatched.clone();
            let queue_len = self.queue_len.clone();

            tokio::spawn(async move {
                log::trace!("Worker {} started", worker_id);

                loop {
                    // Hold the receiver lock only long enough to take one task.
                    let hash = {
                        let mut rx = task_rx.lock().await;
                        let h = rx.recv().await;
                        if h.is_some() {
                            queue_len.fetch_sub(1, Ordering::Relaxed);
                        }
                        h
                    };

                    let hash = match hash {
                        Some(h) => h,
                        None => break, // channel closed: scheduler shut down
                    };

                    total_dispatched.fetch_add(1, Ordering::Relaxed);

                    Self::process_hash(hash, &fetcher, &callback, &on_metadata_fetch).await;
                }

                log::trace!("Worker {} exiting", worker_id);
            });
        }

        // In debug builds, log scheduler statistics once a minute.
        #[cfg(debug_assertions)]
        let mut stats_interval = tokio::time::interval(Duration::from_secs(60));
        #[cfg(debug_assertions)]
        stats_interval.tick().await; // consume the immediate first tick

        loop {
            #[cfg(debug_assertions)]
            {
                tokio::select! {
                    maybe_hash = self.hash_rx.recv() => {
                        // Match on the result so the loop actually exits when
                        // the channel closes, instead of spinning on ticks.
                        let hash = match maybe_hash {
                            Some(h) => h,
                            None => break,
                        };
                        self.total_received.fetch_add(1, Ordering::Relaxed);

                        match task_tx.try_send(hash) {
                            Ok(_) => {
                                self.queue_len.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(mpsc::error::TrySendError::Full(_)) => {
                                // Queue is full: drop the hash rather than block.
                                self.total_dropped.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(_) => break, // all workers are gone
                        }
                    }

                    _ = stats_interval.tick() => {
                        self.print_stats(&task_tx);
                    }
                }
            }

            #[cfg(not(debug_assertions))]
            {
                match self.hash_rx.recv().await {
                    Some(hash) => {
                        self.total_received.fetch_add(1, Ordering::Relaxed);

                        match task_tx.try_send(hash) {
                            Ok(_) => {
                                self.queue_len.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(mpsc::error::TrySendError::Full(_)) => {
                                self.total_dropped.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(_) => break,
                        }
                    }
                    None => break,
                }
            }
        }
    }
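
    /// Handles a single hash: consults the filter callback, decodes the hex
    /// info-hash, fetches metadata from the peer, and reports the result
    /// through the torrent callback.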
    async fn process_hash(
        hash: HashDiscovered,
        fetcher: &Arc<RbitFetcher>,
        callback: &Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: &Arc<RwLock<Option<MetadataFetchCallback>>>,
    ) {
        let info_hash = hash.info_hash.clone();
        let peer_addr = hash.peer_addr;

        // Clone the filter out of the lock so the guard is not held across an await.
        let maybe_check_fn = {
            match on_metadata_fetch.read() {
                Ok(guard) => guard.clone(),
                Err(_) => return, // lock poisoned: give up on this hash
            }
        };

        // Let the host application veto the fetch (e.g. for deduplication).
        if let Some(f) = maybe_check_fn {
            if !f(info_hash.clone()).await {
                return;
            }
        }

        // The info-hash arrives hex-encoded; the fetcher expects 20 raw bytes.
        let info_hash_bytes: [u8; 20] = match hex::decode(&info_hash) {
            Ok(bytes) if bytes.len() == 20 => {
                let mut arr = [0u8; 20];
                arr.copy_from_slice(&bytes);
                arr
            }
            _ => return, // malformed hash
        };

        if let Some((name, total_size, files)) = fetcher.fetch(&info_hash_bytes, peer_addr).await {
            let metadata = TorrentInfo {
                info_hash,
                name,
                total_size,
                files,
                magnet_link: format!("magnet:?xt=urn:btih:{}", hash.info_hash),
                peers: vec![peer_addr.to_string()],
                piece_length: 0,
                timestamp: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            };

            let maybe_torrent_cb = {
                match callback.read() {
                    Ok(guard) => guard.clone(),
                    Err(_) => return,
                }
            };

            if let Some(cb) = maybe_torrent_cb {
                cb(metadata);
            }
        }
    }
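
    /// Logs queue depth and counters (debug builds only), escalating to a
    /// warning when the queue is more than 80% full.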
    #[cfg(debug_assertions)]
    fn print_stats(&self, task_tx: &mpsc::Sender<HashDiscovered>) {
        let received = self.total_received.load(Ordering::Relaxed);
        let dropped = self.total_dropped.load(Ordering::Relaxed);
        let dispatched = self.total_dispatched.load(Ordering::Relaxed);

        let drop_rate = if received > 0 {
            dropped as f64 / received as f64 * 100.0
        } else {
            0.0
        };

        // `capacity()` reports remaining slots, so occupancy is max - remaining.
        let queue_size = self.max_queue_size - task_tx.capacity();
        let queue_pressure = (queue_size as f64 / self.max_queue_size as f64) * 100.0;

        if queue_pressure > 80.0 {
            log::warn!(
                "⚠️ Metadata queue under pressure: queue={}/{} ({:.1}%), received={}, dispatched={}, dropped={} ({:.2}%)",
                queue_size,
                self.max_queue_size,
                queue_pressure,
                received,
                dispatched,
                dropped,
                drop_rate
            );
        } else {
            log::info!(
                "📊 Metadata scheduler stats: queue={}/{} ({:.1}%), received={}, dispatched={}, dropped={} ({:.2}%)",
                queue_size,
                self.max_queue_size,
                queue_pressure,
                received,
                dispatched,
                dropped,
                drop_rate
            );
        }
    }
}