1use crate::types::FileInfo;
2use bytes::Bytes;
3#[cfg(feature = "metrics")]
4use metrics::{counter, histogram};
5use rbit::peer::ExtensionMessage;
6use rbit::{
7 ExtensionHandshake, Message, MetadataMessage, MetadataMessageType, PeerConnection, PeerId,
8 metadata_piece_count,
9};
10use sha1::{Digest, Sha1};
11use std::collections::BTreeMap;
12use std::net::SocketAddr;
13use std::time::Duration;
14use tokio::time::timeout;
15
/// Fetches torrent metadata (the bencoded info dictionary) directly from peers via the
/// `ut_metadata` extension.
///
/// Cheap to clone: it only carries the configured overall timeout for the metadata exchange.
#[derive(Clone, Debug)]
pub struct RbitFetcher {
    /// Upper bound on the extension handshake plus metadata download, per peer.
    timeout: Duration,
}
20
21impl RbitFetcher {
22 pub fn new(timeout_secs: u64) -> Self {
23 Self {
24 timeout: Duration::from_secs(if timeout_secs == 0 { 15 } else { timeout_secs }),
25 }
26 }
27
28 pub async fn fetch(
29 &self,
30 info_hash: &[u8; 20],
31 peer_addr: SocketAddr,
32 ) -> Option<(String, u64, Vec<FileInfo>, u64)> {
33 #[cfg(feature = "metrics")]
34 counter!("dht_metadata_fetch_attempts_total").increment(1);
35
36 let peer_id = PeerId::generate();
37
38 let mut conn = match timeout(
39 Duration::from_secs(3),
40 PeerConnection::connect(peer_addr, *info_hash, *peer_id.as_bytes()),
41 )
42 .await
43 {
44 Ok(Ok(c)) => {
45 #[cfg(feature = "metrics")]
46 counter!("dht_metadata_connection_result_total", "result" => "success")
47 .increment(1);
48 c
49 }
50 Ok(Err(_)) => {
51 #[cfg(feature = "metrics")]
52 counter!("dht_metadata_connection_result_total", "result" => "failed").increment(1);
53 return None;
54 }
55 Err(_) => {
56 #[cfg(feature = "metrics")]
57 counter!("dht_metadata_connection_result_total", "result" => "timeout")
58 .increment(1);
59 return None;
60 }
61 };
62
63 if !conn.supports_extension {
64 #[cfg(feature = "metrics")]
65 counter!("dht_metadata_handshake_result_total", "result" => "no_extension_support")
66 .increment(1);
67 return None;
68 }
69
70 let my_ut_metadata_id = 1;
71 let handshake = ExtensionHandshake::with_extensions(&[("ut_metadata", my_ut_metadata_id)]);
72
73 if let Ok(handshake_bytes) = handshake.encode() {
74 let _ = conn
75 .send(Message::Extended {
76 id: 0,
77 payload: handshake_bytes,
78 })
79 .await;
80 } else {
81 return None;
82 }
83
84 let mut metadata_size = 0;
85 let mut remote_ut_metadata_id = 0;
86 let mut pieces: BTreeMap<u32, Bytes> = BTreeMap::new();
87 let mut request_sent = false;
88
89 let result = timeout(self.timeout, async {
90 loop {
91 let msg = conn.receive().await.ok()?;
92 if let Message::Extended { id, payload } = msg {
93 if id == 0 {
94 if let Ok(ExtensionMessage::Handshake(remote_hs)) = ExtensionMessage::decode(id, &payload) {
95 if let Some(size) = remote_hs.metadata_size {
96 metadata_size = size as u32;
97 }
98 if let Some(ext_id) = remote_hs.get_extension_id("ut_metadata") {
99 remote_ut_metadata_id = ext_id;
100 }
101 }
102 if metadata_size > 0 && remote_ut_metadata_id > 0 && !request_sent {
103 if metadata_size > 10 * 1024 * 1024 {
104 #[cfg(feature = "metrics")]
105 counter!("dht_metadata_fetch_fail_total", "reason" => "size_limit").increment(1);
106 return None;
107 }
108
109 let count = metadata_piece_count(metadata_size as usize);
110 for i in 0..count {
111 let req = MetadataMessage::request(i as u32);
112 if let Ok(encoded) = req.encode() {
113 let _ = conn.send(Message::Extended { id: remote_ut_metadata_id, payload: encoded }).await;
114 }
115 }
116 request_sent = true;
117 }
118 } else if id == my_ut_metadata_id {
119 if let Ok(meta_msg) = MetadataMessage::decode(&payload)
120 && meta_msg.msg_type == MetadataMessageType::Data
121 && let Some(data) = meta_msg.data {
122 #[cfg(feature = "metrics")]
123 counter!("dht_metadata_bytes_downloaded_total").increment(data.len() as u64);
124 pieces.insert(meta_msg.piece, data);
125 }
126 if metadata_size > 0 {
127 let total_received: usize = pieces.values().map(|p| p.len()).sum();
128 if total_received >= metadata_size as usize {
129 let mut full_data = Vec::with_capacity(metadata_size as usize);
130 let count = metadata_piece_count(metadata_size as usize);
131 let mut success = true;
132 for i in 0..count {
133 if let Some(p) = pieces.get(&(i as u32)) {
134 full_data.extend_from_slice(p);
135 } else {
136 success = false; break;
137 }
138 }
139 if success {
140 let info_hash_copy = *info_hash;
141 let validated = tokio::task::spawn_blocking(move || {
142 let mut hasher = Sha1::new();
143 hasher.update(&full_data);
144 let digest: [u8; 20] = hasher.finalize().into();
145 if digest == info_hash_copy {
146 Some(full_data)
147 } else {
148 None
149 }
150 }).await.unwrap_or(None);
151
152 if validated.is_some() {
153 #[cfg(feature = "metrics")]
154 counter!("dht_metadata_handshake_result_total", "result" => "success").increment(1);
155 return validated;
156 }
157 #[cfg(feature = "metrics")]
158 counter!("dht_metadata_fetch_fail_total", "reason" => "sha1_mismatch").increment(1);
159 return None;
160 }
161 }
162 }
163 }
164 }
165 }
166 }).await;
167
168 match result {
169 Ok(Some(info_bytes)) => {
170 if let Ok(value) = rbit::decode(&info_bytes)
171 && let Some(dict) = value.as_dict()
172 {
173 let name = dict
174 .get(&b"name"[..])
175 .and_then(|v| v.as_str())
176 .unwrap_or("Unknown")
177 .to_string();
178 let piece_length = dict
179 .get(&b"piece length"[..])
180 .and_then(|v| v.as_integer())
181 .unwrap_or(0) as u64;
182 let mut total_size = 0;
183 let mut file_list = Vec::new();
184 if let Some(files) = dict.get(&b"files"[..]).and_then(|v| v.as_list()) {
185 for file in files {
186 if let Some(f_dict) = file.as_dict()
187 && let Some(len) =
188 f_dict.get(&b"length"[..]).and_then(|v| v.as_integer())
189 {
190 let len = len as u64;
191 total_size += len;
192 let mut path_parts = Vec::new();
193 if let Some(path_list) =
194 f_dict.get(&b"path"[..]).and_then(|v| v.as_list())
195 {
196 for p in path_list {
197 if let Some(p_str) = p.as_str() {
198 path_parts.push(p_str);
199 }
200 }
201 }
202 file_list.push(FileInfo {
203 path: path_parts.join("/"),
204 size: len,
205 });
206 }
207 }
208 } else if let Some(len) = dict.get(&b"length"[..]).and_then(|v| v.as_integer())
209 {
210 total_size = len as u64;
211 file_list.push(FileInfo {
212 path: name.clone(),
213 size: total_size,
214 });
215 }
216 if total_size > 0 {
217 #[cfg(feature = "metrics")]
218 {
219 counter!("dht_metadata_fetch_success_total").increment(1);
220 histogram!("dht_metadata_size_bytes").record(total_size as f64);
221 }
222 return Some((name, total_size, file_list, piece_length));
223 }
224 }
225 #[cfg(feature = "metrics")]
226 counter!("dht_metadata_fetch_fail_total", "reason" => "parse_error").increment(1);
227 None
228 }
229 _ => {
230 #[cfg(feature = "metrics")]
231 counter!("dht_metadata_fetch_fail_total", "reason" => "timeout").increment(1);
232 None
233 }
234 }
235 }
236}