1use crate::types::FileInfo;
2use bytes::Bytes;
3#[cfg(feature = "metrics")]
4use metrics::{counter, histogram};
5use rbit::peer::ExtensionMessage;
6use rbit::{
7 ExtensionHandshake, Message, MetadataMessage, MetadataMessageType, PeerConnection, PeerId,
8 metadata_piece_count,
9};
10use sha1::{Digest, Sha1};
11use std::collections::BTreeMap;
12use std::net::SocketAddr;
13use std::time::Duration;
14use tokio::time::timeout;
15
/// Fetches torrent metadata (the info dictionary) from a single peer using
/// the `ut_metadata` extension.
///
/// The fetcher holds only configuration, so cloning it is cheap and a clone
/// can be handed to each concurrent fetch task.
#[derive(Clone, Debug)]
pub struct RbitFetcher {
    /// Time budget for the extension handshake and piece download phase of a
    /// single fetch (the initial connection attempt has its own limit).
    timeout: Duration,
}
20
21impl RbitFetcher {
22 pub fn new(timeout_secs: u64) -> Self {
23 Self {
24 timeout: Duration::from_secs(if timeout_secs == 0 { 15 } else { timeout_secs }),
25 }
26 }
27
28 pub async fn fetch(
29 &self,
30 info_hash: &[u8; 20],
31 peer_addr: SocketAddr,
32 ) -> Option<(String, u64, Vec<FileInfo>)> {
33 #[cfg(feature = "metrics")]
34 counter!("dht_metadata_fetch_attempts_total").increment(1);
35
36 let peer_id = PeerId::generate();
37
38 let mut conn = match timeout(
39 Duration::from_secs(3),
40 PeerConnection::connect(peer_addr, *info_hash, *peer_id.as_bytes()),
41 )
42 .await
43 {
44 Ok(Ok(c)) => {
45 #[cfg(feature = "metrics")]
46 counter!("dht_metadata_connection_result_total", "result" => "success")
47 .increment(1);
48 c
49 }
50 Ok(Err(_)) => {
51 #[cfg(feature = "metrics")]
52 counter!("dht_metadata_connection_result_total", "result" => "failed").increment(1);
53 return None;
54 }
55 Err(_) => {
56 #[cfg(feature = "metrics")]
57 counter!("dht_metadata_connection_result_total", "result" => "timeout")
58 .increment(1);
59 return None;
60 }
61 };
62
63 if !conn.supports_extension {
64 #[cfg(feature = "metrics")]
65 counter!("dht_metadata_handshake_result_total", "result" => "no_extension_support")
66 .increment(1);
67 return None;
68 }
69
70 let my_ut_metadata_id = 1;
71 let handshake = ExtensionHandshake::with_extensions(&[("ut_metadata", my_ut_metadata_id)]);
72
73 if let Ok(handshake_bytes) = handshake.encode() {
74 let _ = conn
75 .send(Message::Extended {
76 id: 0,
77 payload: handshake_bytes,
78 })
79 .await;
80 } else {
81 return None;
82 }
83
84 let mut metadata_size = 0;
85 let mut remote_ut_metadata_id = 0;
86 let mut pieces: BTreeMap<u32, Bytes> = BTreeMap::new();
87 let mut request_sent = false;
88
89 let result = timeout(self.timeout, async {
90 loop {
91 let msg = conn.receive().await.ok()?;
92 if let Message::Extended { id, payload } = msg {
93 if id == 0 {
94 if let Ok(ExtensionMessage::Handshake(remote_hs)) = ExtensionMessage::decode(id, &payload) {
95 if let Some(size) = remote_hs.metadata_size {
96 metadata_size = size as u32;
97 }
98 if let Some(ext_id) = remote_hs.get_extension_id("ut_metadata") {
99 remote_ut_metadata_id = ext_id;
100 }
101 }
102 if metadata_size > 0 && remote_ut_metadata_id > 0 && !request_sent {
103 if metadata_size > 10 * 1024 * 1024 {
104 #[cfg(feature = "metrics")]
105 counter!("dht_metadata_fetch_fail_total", "reason" => "size_limit").increment(1);
106 return None;
107 }
108
109 let count = metadata_piece_count(metadata_size as usize);
110 for i in 0..count {
111 let req = MetadataMessage::request(i as u32);
112 if let Ok(encoded) = req.encode() {
113 let _ = conn.send(Message::Extended { id: remote_ut_metadata_id, payload: encoded }).await;
114 }
115 }
116 request_sent = true;
117 }
118 } else if id == my_ut_metadata_id {
119 if let Ok(meta_msg) = MetadataMessage::decode(&payload)
120 && meta_msg.msg_type == MetadataMessageType::Data
121 && let Some(data) = meta_msg.data {
122 #[cfg(feature = "metrics")]
123 counter!("dht_metadata_bytes_downloaded_total").increment(data.len() as u64);
124 pieces.insert(meta_msg.piece, data);
125 }
126 if metadata_size > 0 {
127 let total_received: usize = pieces.values().map(|p| p.len()).sum();
128 if total_received >= metadata_size as usize {
129 let mut full_data = Vec::with_capacity(metadata_size as usize);
130 let count = metadata_piece_count(metadata_size as usize);
131 let mut success = true;
132 for i in 0..count {
133 if let Some(p) = pieces.get(&(i as u32)) {
134 full_data.extend_from_slice(p);
135 } else {
136 success = false; break;
137 }
138 }
139 if success {
140 let info_hash_copy = *info_hash;
141 let full_data_clone = full_data.clone();
142 let is_valid = tokio::task::spawn_blocking(move || {
143 let mut hasher = Sha1::new();
144 hasher.update(&full_data_clone);
145 let digest: [u8; 20] = hasher.finalize().into();
146 digest == info_hash_copy
147 }).await.unwrap_or(false);
148
149 if is_valid {
150 #[cfg(feature = "metrics")]
151 counter!("dht_metadata_handshake_result_total", "result" => "success").increment(1);
152 return Some(full_data);
153 }
154 #[cfg(feature = "metrics")]
155 counter!("dht_metadata_fetch_fail_total", "reason" => "sha1_mismatch").increment(1);
156 return None;
157 }
158 }
159 }
160 }
161 }
162 }
163 }).await;
164
165 match result {
166 Ok(Some(info_bytes)) => {
167 if let Ok(value) = rbit::decode(&info_bytes)
168 && let Some(dict) = value.as_dict() {
169 let name = dict
170 .get(&b"name"[..])
171 .and_then(|v| v.as_str())
172 .unwrap_or("Unknown")
173 .to_string();
174 let mut total_size = 0;
175 let mut file_list = Vec::new();
176 if let Some(files) = dict.get(&b"files"[..]).and_then(|v| v.as_list()) {
177 for file in files {
178 if let Some(f_dict) = file.as_dict()
179 && let Some(len) = f_dict.get(&b"length"[..]).and_then(|v| v.as_integer()) {
180 let len = len as u64;
181 total_size += len;
182 let mut path_parts = Vec::new();
183 if let Some(path_list) =
184 f_dict.get(&b"path"[..]).and_then(|v| v.as_list())
185 {
186 for p in path_list {
187 if let Some(p_str) = p.as_str() {
188 path_parts.push(p_str);
189 }
190 }
191 }
192 file_list.push(FileInfo {
193 path: path_parts.join("/"),
194 size: len,
195 });
196 }
197 }
198 } else if let Some(len) =
199 dict.get(&b"length"[..]).and_then(|v| v.as_integer())
200 {
201 total_size = len as u64;
202 file_list.push(FileInfo {
203 path: name.clone(),
204 size: total_size,
205 });
206 }
207 if total_size > 0 {
208 #[cfg(feature = "metrics")]
209 {
210 counter!("dht_metadata_fetch_success_total").increment(1);
211 histogram!("dht_metadata_size_bytes").record(total_size as f64);
212 }
213 return Some((name, total_size, file_list));
214 }
215 }
216 #[cfg(feature = "metrics")]
217 counter!("dht_metadata_fetch_fail_total", "reason" => "parse_error").increment(1);
218 None
219 }
220 _ => {
221 #[cfg(feature = "metrics")]
222 counter!("dht_metadata_fetch_fail_total", "reason" => "timeout").increment(1);
223 None
224 }
225 }
226 }
227}