use crate::metadata::RbitFetcher;
use crate::server::HashDiscovered;
use crate::types::TorrentInfo;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

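/// Callback invoked with the metadata of each successfully fetched torrent.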
type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync>;
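/// Optional async pre-fetch filter: given a hex info-hash, it decides whether
/// the metadata should be fetched (returning `false` skips the hash).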
type MetadataFetchCallback = Arc<
    dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
        + Send
        + Sync,
>;

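/// Schedules metadata fetches for discovered info-hashes.
///
/// Hashes arriving on `hash_rx` are handed to a fixed pool of `max_concurrent`
/// worker tasks through a bounded queue of `max_queue_size` entries; when the
/// queue is full, new hashes are dropped rather than blocking the sender. The
/// externally supplied `queue_len` gauge mirrors how many hashes are currently
/// waiting for a worker.
///
/// Rough wiring sketch (the `fetcher` value and the buffer sizes below are
/// placeholders, not prescribed values):
///
/// ```ignore
/// let (hash_tx, hash_rx) = tokio::sync::mpsc::channel(1024);
/// let shutdown = CancellationToken::new();
/// let scheduler = MetadataScheduler::new(
///     hash_rx,
///     fetcher,                       // Arc<RbitFetcher>, constructed elsewhere
///     1024,                          // max_queue_size
///     8,                             // max_concurrent
///     Arc::new(RwLock::new(None)),   // torrent callback slot
///     Arc::new(RwLock::new(None)),   // pre-fetch filter slot
///     Arc::new(AtomicUsize::new(0)), // shared queue-length gauge
///     shutdown.clone(),
/// );
/// tokio::spawn(scheduler.run());
/// // ...later, to stop everything: shutdown.cancel();
/// ```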
pub struct MetadataScheduler {
    hash_rx: mpsc::Receiver<HashDiscovered>,
    max_queue_size: usize,
    max_concurrent: usize,
    fetcher: Arc<RbitFetcher>,
    callback: Arc<RwLock<Option<TorrentCallback>>>,
    on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
    total_received: Arc<AtomicU64>,
    total_dropped: Arc<AtomicU64>,
    total_dispatched: Arc<AtomicU64>,
    queue_len: Arc<AtomicUsize>,
    shutdown: CancellationToken,
}

impl MetadataScheduler {
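    /// Builds a scheduler from externally owned channels, callback slots and
    /// the shared queue-length gauge; the internal counters start at zero.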
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        hash_rx: mpsc::Receiver<HashDiscovered>,
        fetcher: Arc<RbitFetcher>,
        max_queue_size: usize,
        max_concurrent: usize,
        callback: Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
        queue_len: Arc<AtomicUsize>,
        shutdown: CancellationToken,
    ) -> Self {
        Self {
            hash_rx,
            max_queue_size,
            max_concurrent,
            fetcher,
            callback,
            on_metadata_fetch,
            total_received: Arc::new(AtomicU64::new(0)),
            total_dropped: Arc::new(AtomicU64::new(0)),
            total_dispatched: Arc::new(AtomicU64::new(0)),
            queue_len,
            shutdown,
        }
    }

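    /// Registers the callback that receives each fetched `TorrentInfo`.
    /// Uses `try_write`, so registration is silently skipped if the lock is
    /// currently held or poisoned.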
    pub fn set_callback(&mut self, callback: TorrentCallback) {
        if let Ok(mut guard) = self.callback.try_write() {
            *guard = Some(callback);
        }
    }

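    /// Registers the async pre-fetch filter consulted before each fetch.
    /// As with `set_callback`, a contended or poisoned lock means the
    /// registration is silently skipped.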
    pub fn set_metadata_fetch_callback(&mut self, callback: MetadataFetchCallback) {
        if let Ok(mut guard) = self.on_metadata_fetch.try_write() {
            *guard = Some(callback);
        }
    }

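    /// Runs the scheduler until the shutdown token is cancelled or `hash_rx`
    /// closes: spawns the worker pool, then forwards incoming hashes into the
    /// bounded worker queue, dropping them when it is full. Dropping the
    /// internal sender on exit lets workers drain the remaining queue and stop.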
    pub async fn run(mut self) {
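        // Bounded hand-off queue between the dispatch loop below and the
        // worker pool; its capacity doubles as the backpressure limit.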
        let (task_tx, task_rx) = async_channel::bounded::<HashDiscovered>(self.max_queue_size);

        let shutdown = self.shutdown.clone();
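        // Spawn a fixed pool of worker tasks, each pulling hashes from the
        // shared queue and fetching metadata until shutdown is signalled.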
        #[cfg_attr(not(debug_assertions), allow(unused_variables))]
        for worker_id in 0..self.max_concurrent {
            let task_rx = task_rx.clone();
            let fetcher = self.fetcher.clone();
            let callback = self.callback.clone();
            let on_metadata_fetch = self.on_metadata_fetch.clone();
            let total_dispatched = self.total_dispatched.clone();
            let queue_len = self.queue_len.clone();
            let shutdown_worker = shutdown.clone();

            tokio::spawn(async move {
                #[cfg(debug_assertions)]
                log::trace!("Worker {} started", worker_id);

                loop {
                    tokio::select! {
                        _ = shutdown_worker.cancelled() => {
                            #[cfg(debug_assertions)]
                            log::trace!("Worker {} received shutdown signal, exiting", worker_id);
                            break;
                        }
                        result = task_rx.recv() => {
                            let hash = match result {
                                Ok(h) => {
                                    queue_len.fetch_sub(1, Ordering::Relaxed);
                                    h
                                }
                                Err(_) => break,
                            };

                            total_dispatched.fetch_add(1, Ordering::Relaxed);

                            Self::process_hash(
                                hash,
                                &fetcher,
                                &callback,
                                &on_metadata_fetch,
                            ).await;
                        }
                    }
                }

                #[cfg(debug_assertions)]
                log::trace!("Worker {} exited", worker_id);
            });
        }

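        // Statistics reporting is only armed in debug builds (the timer stays
        // `None` in release); the initial immediate tick is consumed so the
        // first report fires after a full interval.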
        let mut stats_interval = if cfg!(debug_assertions) {
            Some(tokio::time::interval(std::time::Duration::from_secs(60)))
        } else {
            None
        };
        if let Some(ref mut interval) = stats_interval {
            interval.tick().await;
        }

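        // Main dispatch loop: receive discovered hashes and hand them to the
        // workers without ever blocking; exits on shutdown or when the hash
        // channel closes.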
        let shutdown = self.shutdown.clone();
        loop {
            tokio::select! {
                _ = shutdown.cancelled() => {
                    #[cfg(debug_assertions)]
                    log::trace!("MetadataScheduler main loop received shutdown signal, exiting");
                    break;
                }
                result = self.hash_rx.recv() => {
                    match result {
                        Some(hash) => {
                            self.total_received.fetch_add(1, Ordering::Relaxed);

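                            // Non-blocking enqueue: a full worker queue means
                            // the hash is dropped and counted instead of
                            // stalling the sender; a closed queue ends the loop.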
                            match task_tx.try_send(hash) {
                                Ok(_) => {
                                    self.queue_len.fetch_add(1, Ordering::Relaxed);
                                }
                                Err(async_channel::TrySendError::Full(_)) => {
                                    self.total_dropped.fetch_add(1, Ordering::Relaxed);
                                }
                                Err(_) => break,
                            }
                        }
                        None => break,
                    }
                }
                _ = async {
                    match stats_interval.as_mut() {
                        Some(interval) => interval.tick().await,
                        None => std::future::pending().await,
                    }
                } => {
                    self.print_stats_inline();
                }
            }
        }

        drop(task_tx);
        #[cfg(debug_assertions)]
        log::trace!("MetadataScheduler main loop exited, waiting for worker tasks to finish");
    }

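    /// Handles a single discovered hash: consults the optional pre-fetch
    /// filter, decodes the hex info-hash, asks the fetcher to pull the
    /// metadata from the announcing peer, and forwards the resulting
    /// `TorrentInfo` to the registered callback.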
    async fn process_hash(
        hash: HashDiscovered,
        fetcher: &Arc<RbitFetcher>,
        callback: &Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: &Arc<RwLock<Option<MetadataFetchCallback>>>,
    ) {
        let info_hash = hash.info_hash.clone();
        let peer_addr = hash.peer_addr;

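        // Ask the optional pre-fetch filter whether this hash should be
        // fetched at all; bail out if it declines or the lock is poisoned.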
        let maybe_check_fn = {
            match on_metadata_fetch.read() {
                Ok(guard) => guard.clone(),
                Err(_) => return,
            }
        };

        if let Some(f) = maybe_check_fn
            && !f(info_hash.clone()).await
        {
            return;
        }

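        // Decode the 40-character hex info-hash into its raw 20-byte form,
        // silently skipping anything malformed.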
        let info_hash_bytes: [u8; 20] = match hex::decode(&info_hash) {
            Ok(bytes) if bytes.len() == 20 => {
                let mut arr = [0u8; 20];
                arr.copy_from_slice(&bytes);
                arr
            }
            _ => return,
        };

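        // Fetch the metadata from the announcing peer and, on success, hand
        // the assembled `TorrentInfo` to the registered callback.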
        if let Some((name, total_size, files, piece_length)) =
            fetcher.fetch(&info_hash_bytes, peer_addr).await
        {
            let metadata = TorrentInfo {
                info_hash,
                name,
                total_size,
                files,
                magnet_link: format!("magnet:?xt=urn:btih:{}", hash.info_hash),
                peers: vec![peer_addr.to_string()],
                piece_length,
                timestamp: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs(),
            };

            let maybe_torrent_cb = {
                match callback.read() {
                    Ok(guard) => guard.clone(),
                    Err(_) => return,
                }
            };

            if let Some(cb) = maybe_torrent_cb {
                cb(metadata);
            }
        }
    }

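    /// Logs queue pressure and throughput counters (debug builds only),
    /// escalating to a warning once the queue is more than 80% full.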
    fn print_stats_inline(&self) {
        #[cfg(debug_assertions)]
        {
            let received = self.total_received.load(Ordering::Relaxed);
            let dropped = self.total_dropped.load(Ordering::Relaxed);
            let dispatched = self.total_dispatched.load(Ordering::Relaxed);

            let drop_rate = if received > 0 {
                dropped as f64 / received as f64 * 100.0
            } else {
                0.0
            };

            let queue_len = self.queue_len.load(Ordering::Relaxed);
            let queue_pressure = (queue_len as f64 / self.max_queue_size as f64) * 100.0;

            if queue_pressure > 80.0 {
                log::warn!(
                    "Metadata queue under pressure: queue={}/{} ({:.1}%), received={}, dispatched={}, dropped={} ({:.2}%)",
                    queue_len,
                    self.max_queue_size,
                    queue_pressure,
                    received,
                    dispatched,
                    dropped,
                    drop_rate
                );
            } else {
                log::info!(
                    "Metadata scheduler stats: queue={}/{} ({:.1}%), received={}, dispatched={}, dropped={} ({:.2}%)",
                    queue_len,
                    self.max_queue_size,
                    queue_pressure,
                    received,
                    dispatched,
                    dropped,
                    drop_rate
                );
            }
        }
    }
}