use crate::metadata::RbitFetcher;
use crate::server::HashDiscovered;
use crate::types::TorrentInfo;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
/// Callback invoked with the metadata of every torrent that was fetched successfully.
type TorrentCallback = Arc<dyn Fn(TorrentInfo) + Send + Sync + 'static>;

/// Asynchronous gate consulted with the hex-encoded info-hash before its metadata is
/// fetched; a future resolving to `false` skips the fetch.
type MetadataFetchCallback = Arc<
    dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
pub struct MetadataScheduler {
hash_rx: mpsc::Receiver<HashDiscovered>,
max_queue_size: usize,
max_concurrent: usize,
fetcher: Arc<RbitFetcher>,
callback: Arc<RwLock<Option<TorrentCallback>>>,
on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
total_received: Arc<AtomicU64>,
total_dropped: Arc<AtomicU64>,
total_dispatched: Arc<AtomicU64>,
queue_len: Arc<AtomicUsize>,
shutdown: CancellationToken,
}
impl MetadataScheduler {
    /// Creates a scheduler; nothing runs until [`MetadataScheduler::run`] is awaited.
    ///
    /// * `hash_rx` - channel delivering discovered info-hashes.
    /// * `fetcher` - metadata fetcher shared by every worker.
    /// * `max_queue_size` - capacity of the internal queue feeding the workers; hashes that
    ///   arrive while it is full are counted as dropped. `0` is clamped to `1` because a
    ///   zero-capacity `async_channel::bounded` channel panics on creation.
    /// * `max_concurrent` - number of worker tasks spawned by `run`.
    /// * `callback` - shared slot for the per-torrent callback (see [`Self::set_callback`]).
    /// * `on_metadata_fetch` - shared slot for the pre-fetch gate
    ///   (see [`Self::set_metadata_fetch_callback`]).
    /// * `queue_len` - shared gauge of hashes currently waiting in the internal queue.
    /// * `shutdown` - cancelling it stops the receive loop and the workers.
    // Every shared handle is injected by the owner, hence the long parameter list.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        hash_rx: mpsc::Receiver<HashDiscovered>,
        fetcher: Arc<RbitFetcher>,
        max_queue_size: usize,
        max_concurrent: usize,
        callback: Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: Arc<RwLock<Option<MetadataFetchCallback>>>,
        queue_len: Arc<AtomicUsize>,
        shutdown: CancellationToken,
    ) -> Self {
        Self {
            hash_rx,
            // `async_channel::bounded(0)` panics and the stats divide by this value,
            // so never store a zero capacity.
            max_queue_size: max_queue_size.max(1),
            max_concurrent,
            fetcher,
            callback,
            on_metadata_fetch,
            total_received: Arc::new(AtomicU64::new(0)),
            total_dropped: Arc::new(AtomicU64::new(0)),
            total_dispatched: Arc::new(AtomicU64::new(0)),
            queue_len,
            shutdown,
        }
    }

    /// Installs (or replaces) the callback invoked with every successfully fetched torrent.
    ///
    /// Waits for readers to release the lock instead of silently giving up. It also
    /// recovers a poisoned lock, because the slot only holds an `Option<Arc<..>>` and
    /// cannot be left half-updated.
    pub fn set_callback(&mut self, callback: TorrentCallback) {
        let mut slot = self
            .callback
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *slot = Some(callback);
    }

    /// Installs (or replaces) the asynchronous gate consulted before each metadata fetch.
    ///
    /// Lock handling matches [`Self::set_callback`].
    pub fn set_metadata_fetch_callback(&mut self, callback: MetadataFetchCallback) {
        let mut slot = self
            .on_metadata_fetch
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *slot = Some(callback);
    }

    /// Runs the scheduler until `shutdown` is cancelled or `hash_rx` is closed.
    ///
    /// Spawns `max_concurrent` worker tasks fed by a bounded queue of `max_queue_size`
    /// entries. The receive loop enqueues without waiting and counts a hash as dropped when
    /// the queue is full. In debug builds, queue statistics are logged every 60 seconds.
    ///
    /// Workers are detached, not joined. When this returns, the queue sender has been
    /// dropped. Each worker then exits on shutdown, or once the remaining queue is drained.
    pub async fn run(mut self) {
        let (task_tx, task_rx) = async_channel::bounded::<HashDiscovered>(self.max_queue_size);
        self.spawn_workers(&task_rx);

        let mut stats_interval = cfg!(debug_assertions)
            .then(|| tokio::time::interval(std::time::Duration::from_secs(60)));
        // The first tick of a tokio interval completes immediately; consume it so the first
        // report appears after a full period.
        if let Some(interval) = stats_interval.as_mut() {
            interval.tick().await;
        }

        let shutdown = self.shutdown.clone();
        loop {
            tokio::select! {
                _ = shutdown.cancelled() => {
                    #[cfg(debug_assertions)]
                    log::trace!("MetadataScheduler 主循环收到关闭信号,退出");
                    break;
                }
                received = self.hash_rx.recv() => {
                    let Some(hash) = received else { break };
                    if !self.enqueue(&task_tx, hash) {
                        break;
                    }
                }
                _ = async {
                    match stats_interval.as_mut() {
                        Some(interval) => interval.tick().await,
                        // Release builds: this branch never fires.
                        None => std::future::pending().await,
                    }
                } => {
                    self.print_stats_inline();
                }
            }
        }

        // Closing the sender lets workers finish the queued hashes and then exit.
        drop(task_tx);
        #[cfg(debug_assertions)]
        log::trace!("MetadataScheduler 主循环退出,等待 worker 任务完成");
    }

    /// Spawns `max_concurrent` detached worker tasks that pull hashes from `task_rx`.
    ///
    /// A worker stops when `shutdown` is cancelled or when the queue is closed and empty.
    /// Cancellation is only observed between hashes, not while a fetch is in progress.
    fn spawn_workers(&self, task_rx: &async_channel::Receiver<HashDiscovered>) {
        #[cfg_attr(not(debug_assertions), allow(unused_variables))]
        for worker_id in 0..self.max_concurrent {
            let task_rx = task_rx.clone();
            let fetcher = Arc::clone(&self.fetcher);
            let callback = Arc::clone(&self.callback);
            let on_metadata_fetch = Arc::clone(&self.on_metadata_fetch);
            let total_dispatched = Arc::clone(&self.total_dispatched);
            let queue_len = Arc::clone(&self.queue_len);
            let shutdown = self.shutdown.clone();

            tokio::spawn(async move {
                #[cfg(debug_assertions)]
                log::trace!("Worker {} 启动", worker_id);
                loop {
                    tokio::select! {
                        _ = shutdown.cancelled() => {
                            #[cfg(debug_assertions)]
                            log::trace!("Worker {} 收到关闭信号,退出", worker_id);
                            break;
                        }
                        received = task_rx.recv() => {
                            // `Err` means every sender is gone and the queue is empty.
                            let Ok(hash) = received else { break };
                            queue_len.fetch_sub(1, Ordering::Relaxed);
                            total_dispatched.fetch_add(1, Ordering::Relaxed);
                            Self::process_hash(hash, &fetcher, &callback, &on_metadata_fetch)
                                .await;
                        }
                    }
                }
                #[cfg(debug_assertions)]
                log::trace!("Worker {} 退出", worker_id);
            });
        }
    }

    /// Hands `hash` to the workers without waiting.
    ///
    /// Returns `false` only when the queue is closed. In that case the receive loop
    /// should stop.
    fn enqueue(
        &self,
        task_tx: &async_channel::Sender<HashDiscovered>,
        hash: HashDiscovered,
    ) -> bool {
        self.total_received.fetch_add(1, Ordering::Relaxed);
        // Count the entry *before* it becomes visible to workers. A worker may dequeue it
        // and decrement `queue_len` before `try_send` even returns. An increment done
        // afterwards could therefore wrap the unsigned gauge below zero.
        self.queue_len.fetch_add(1, Ordering::Relaxed);
        match task_tx.try_send(hash) {
            Ok(()) => true,
            Err(async_channel::TrySendError::Full(_)) => {
                self.queue_len.fetch_sub(1, Ordering::Relaxed);
                self.total_dropped.fetch_add(1, Ordering::Relaxed);
                true
            }
            Err(async_channel::TrySendError::Closed(_)) => {
                self.queue_len.fetch_sub(1, Ordering::Relaxed);
                false
            }
        }
    }

    /// Clones the value out of a shared callback slot, tolerating lock poisoning.
    ///
    /// The guard never escapes this synchronous function, so it can't be held across an
    /// `.await` by callers.
    fn read_slot<T: Clone>(slot: &RwLock<Option<T>>) -> Option<T> {
        slot.read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }

    /// Fetches and reports metadata for a single discovered hash.
    ///
    /// 1. Asks the optional `on_metadata_fetch` gate whether to proceed; `false` skips the hash.
    /// 2. Hex-decodes the info-hash; anything that is not exactly 20 bytes is ignored.
    /// 3. Fetches metadata from the announcing peer via `fetcher`.
    /// 4. On success, builds a [`TorrentInfo`] and passes it to the optional torrent callback.
    async fn process_hash(
        hash: HashDiscovered,
        fetcher: &Arc<RbitFetcher>,
        callback: &Arc<RwLock<Option<TorrentCallback>>>,
        on_metadata_fetch: &Arc<RwLock<Option<MetadataFetchCallback>>>,
    ) {
        let info_hash = hash.info_hash.clone();
        let peer_addr = hash.peer_addr;

        if let Some(should_fetch) = Self::read_slot(on_metadata_fetch)
            && !should_fetch(info_hash.clone()).await
        {
            return;
        }

        let Some(info_hash_bytes) = hex::decode(&info_hash)
            .ok()
            .and_then(|bytes| <[u8; 20]>::try_from(bytes).ok())
        else {
            return;
        };

        let Some((name, total_size, files, piece_length)) =
            fetcher.fetch(&info_hash_bytes, peer_addr).await
        else {
            return;
        };

        let metadata = TorrentInfo {
            // Evaluated before `info_hash` is moved into the struct below.
            magnet_link: format!("magnet:?xt=urn:btih:{info_hash}"),
            info_hash,
            name,
            total_size,
            files,
            peers: vec![peer_addr.to_string()],
            piece_length,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
        };

        if let Some(on_torrent) = Self::read_slot(callback) {
            on_torrent(metadata);
        }
    }

    /// Logs queue statistics. It warns when the queue is more than 80% full.
    ///
    /// Compiled to a no-op in release builds.
    fn print_stats_inline(&self) {
        #[cfg(debug_assertions)]
        {
            let received = self.total_received.load(Ordering::Relaxed);
            let dropped = self.total_dropped.load(Ordering::Relaxed);
            let dispatched = self.total_dispatched.load(Ordering::Relaxed);
            let drop_rate = if received > 0 {
                dropped as f64 / received as f64 * 100.0
            } else {
                0.0
            };
            let queue_len = self.queue_len.load(Ordering::Relaxed);
            // `new` guarantees `max_queue_size >= 1`, so this ratio is always finite.
            let queue_pressure = (queue_len as f64 / self.max_queue_size as f64) * 100.0;
            if queue_pressure > 80.0 {
                log::warn!(
                    "Metadata 队列高压:队列={}/{}({:.1}%), 接收={}, 调度={}, 丢弃={}({:.2}%)",
                    queue_len,
                    self.max_queue_size,
                    queue_pressure,
                    received,
                    dispatched,
                    dropped,
                    drop_rate
                );
            } else {
                log::info!(
                    "Metadata 调度器统计:队列={}/{}({:.1}%), 接收={}, 调度={}, 丢弃={}({:.2}%)",
                    queue_len,
                    self.max_queue_size,
                    queue_pressure,
                    received,
                    dispatched,
                    dropped,
                    drop_rate
                );
            }
        }
    }
}