1use std::collections::BTreeMap;
2use std::net::SocketAddr;
3use std::time::Duration;
4use bytes::Bytes;
5use sha1::{Digest, Sha1};
6use tokio::time::timeout;
7use rbit::{
8 metadata_piece_count, ExtensionHandshake, Message, MetadataMessage,
9 MetadataMessageType, PeerConnection, PeerId,
10};
11use rbit::peer::ExtensionMessage;
12use crate::types::FileInfo;
13
/// Fetches torrent metadata (the info dictionary) directly from a peer using the
/// BitTorrent `ut_metadata` extension.
///
/// The fetcher only carries its configured timeout, so cloning it is cheap.
#[derive(Clone, Debug)]
pub struct RbitFetcher {
    /// Upper bound on the metadata exchange with a single peer, applied after the
    /// connection has been established.
    timeout: Duration,
}
18
19impl RbitFetcher {
20 pub fn new(timeout_secs: u64) -> Self {
21 Self {
22 timeout: Duration::from_secs(if timeout_secs == 0 { 15 } else { timeout_secs }),
23 }
24 }
25
26 pub async fn fetch(
27 &self,
28 info_hash: &[u8; 20],
29 peer_addr: SocketAddr,
30 ) -> Option<(String, u64, Vec<FileInfo>)> {
31 let info_hash_hex = hex::encode(info_hash);
32 log::debug!("[Metadata] 开始获取: {} @ {}", info_hash_hex, peer_addr);
33
34 let peer_id = PeerId::generate();
35
36 let mut conn = match timeout(
38 Duration::from_secs(3),
39 PeerConnection::connect(peer_addr, *info_hash, *peer_id.as_bytes()),
40 ).await {
41 Ok(Ok(c)) => c,
42 Ok(Err(_)) => return None,
43 Err(_) => return None,
44 };
45
46 if !conn.supports_extension {
47 return None;
48 }
49
50 let my_ut_metadata_id = 1;
51 let handshake = ExtensionHandshake::with_extensions(&[("ut_metadata", my_ut_metadata_id)]);
52
53 if let Ok(handshake_bytes) = handshake.encode() {
54 let _ = conn.send(Message::Extended { id: 0, payload: handshake_bytes }).await;
55 } else {
56 return None;
57 }
58
59 let mut metadata_size = 0;
60 let mut remote_ut_metadata_id = 0;
61 let mut pieces: BTreeMap<u32, Bytes> = BTreeMap::new();
62 let mut request_sent = false;
63
64 let result = timeout(self.timeout, async {
65 loop {
66 let msg = conn.receive().await.ok()?;
67 match msg {
68 Message::Extended { id, payload } => {
69 if id == 0 {
70 if let Ok(ExtensionMessage::Handshake(remote_hs)) = ExtensionMessage::decode(id, &payload) {
71 if let Some(size) = remote_hs.metadata_size {
72 metadata_size = size as u32;
73 }
74 if let Some(ext_id) = remote_hs.get_extension_id("ut_metadata") {
75 remote_ut_metadata_id = ext_id;
76 }
77 }
78 if metadata_size > 0 && remote_ut_metadata_id > 0 && !request_sent {
79 if metadata_size > 10 * 1024 * 1024 { return None; }
80
81 let count = metadata_piece_count(metadata_size as usize);
82 for i in 0..count {
83 let req = MetadataMessage::request(i as u32);
84 if let Ok(encoded) = req.encode() {
85 let _ = conn.send(Message::Extended { id: remote_ut_metadata_id, payload: encoded }).await;
86 }
87 }
88 request_sent = true;
89 }
90 } else if id == my_ut_metadata_id {
91 if let Ok(meta_msg) = MetadataMessage::decode(&payload) {
92 if meta_msg.msg_type == MetadataMessageType::Data {
93 if let Some(data) = meta_msg.data {
94 pieces.insert(meta_msg.piece, data);
95 }
96 }
97 }
98 if metadata_size > 0 {
99 let total_received: usize = pieces.values().map(|p| p.len()).sum();
100 if total_received >= metadata_size as usize {
101 let mut full_data = Vec::with_capacity(metadata_size as usize);
102 let count = metadata_piece_count(metadata_size as usize);
103 let mut success = true;
104 for i in 0..count {
105 if let Some(p) = pieces.get(&(i as u32)) {
106 full_data.extend_from_slice(p);
107 } else {
108 success = false; break;
109 }
110 }
111 if success {
112 let info_hash_copy = *info_hash;
113 let full_data_clone = full_data.clone();
114 let is_valid = tokio::task::spawn_blocking(move || {
115 let mut hasher = Sha1::new();
116 hasher.update(&full_data_clone);
117 let digest: [u8; 20] = hasher.finalize().into();
118 digest == info_hash_copy
119 }).await.unwrap_or(false);
120
121 if is_valid {
122 return Some(full_data);
123 }
124 return None;
125 }
126 }
127 }
128 }
129 }
130 _ => {}
131 }
132 }
133 }).await;
134
135 match result {
136 Ok(Some(info_bytes)) => {
137 if let Ok(value) = rbit::decode(&info_bytes) {
138 if let Some(dict) = value.as_dict() {
139 let name = dict.get(&b"name"[..]).and_then(|v| v.as_str()).unwrap_or("Unknown").to_string();
140 let mut total_size = 0;
141 let mut file_list = Vec::new();
142 if let Some(files) = dict.get(&b"files"[..]).and_then(|v| v.as_list()) {
143 for file in files {
144 if let Some(f_dict) = file.as_dict() {
145 if let Some(len) = f_dict.get(&b"length"[..]).and_then(|v| v.as_integer()) {
146 let len = len as u64;
147 total_size += len;
148 let mut path_parts = Vec::new();
149 if let Some(path_list) = f_dict.get(&b"path"[..]).and_then(|v| v.as_list()) {
150 for p in path_list {
151 if let Some(p_str) = p.as_str() { path_parts.push(p_str); }
152 }
153 }
154 file_list.push(FileInfo { path: path_parts.join("/"), size: len });
155 }
156 }
157 }
158 } else if let Some(len) = dict.get(&b"length"[..]).and_then(|v| v.as_integer()) {
159 total_size = len as u64;
160 file_list.push(FileInfo { path: name.clone(), size: total_size });
161 }
162 if total_size > 0 { return Some((name, total_size, file_list)); }
163 }
164 }
165 None
166 }
167 _ => None,
168 }
169 }
170}