Skip to main content

dht_crawler/
metadata.rs

1use crate::types::FileInfo;
2use bytes::Bytes;
3#[cfg(feature = "metrics")]
4use metrics::{counter, histogram};
5use rbit::peer::ExtensionMessage;
6use rbit::{
7    ExtensionHandshake, Message, MetadataMessage, MetadataMessageType, PeerConnection, PeerId,
8    metadata_piece_count,
9};
10use sha1::{Digest, Sha1};
11use std::collections::BTreeMap;
12use std::net::SocketAddr;
13use std::time::Duration;
14use tokio::time::timeout;
15
/// Fetches torrent metadata (the bencoded `info` dictionary) from a single
/// peer.
///
/// The value is cheap to clone; it only carries the overall time budget that
/// [`RbitFetcher::fetch`] applies to the metadata exchange.
#[derive(Debug, Clone)]
pub struct RbitFetcher {
    /// Upper bound on the metadata exchange performed by `fetch`, measured
    /// after the peer connection has been established.
    timeout: Duration,
}
20
21impl RbitFetcher {
22    pub fn new(timeout_secs: u64) -> Self {
23        Self {
24            timeout: Duration::from_secs(if timeout_secs == 0 { 15 } else { timeout_secs }),
25        }
26    }
27
28    pub async fn fetch(
29        &self,
30        info_hash: &[u8; 20],
31        peer_addr: SocketAddr,
32    ) -> Option<(String, u64, Vec<FileInfo>, u64)> {
33        #[cfg(feature = "metrics")]
34        counter!("dht_metadata_fetch_attempts_total").increment(1);
35
36        let peer_id = PeerId::generate();
37
38        let mut conn = match timeout(
39            Duration::from_secs(3),
40            PeerConnection::connect(peer_addr, *info_hash, *peer_id.as_bytes()),
41        )
42        .await
43        {
44            Ok(Ok(c)) => {
45                #[cfg(feature = "metrics")]
46                counter!("dht_metadata_connection_result_total", "result" => "success")
47                    .increment(1);
48                c
49            }
50            Ok(Err(_)) => {
51                #[cfg(feature = "metrics")]
52                counter!("dht_metadata_connection_result_total", "result" => "failed").increment(1);
53                return None;
54            }
55            Err(_) => {
56                #[cfg(feature = "metrics")]
57                counter!("dht_metadata_connection_result_total", "result" => "timeout")
58                    .increment(1);
59                return None;
60            }
61        };
62
63        if !conn.supports_extension {
64            #[cfg(feature = "metrics")]
65            counter!("dht_metadata_handshake_result_total", "result" => "no_extension_support")
66                .increment(1);
67            return None;
68        }
69
70        let my_ut_metadata_id = 1;
71        let handshake = ExtensionHandshake::with_extensions(&[("ut_metadata", my_ut_metadata_id)]);
72
73        if let Ok(handshake_bytes) = handshake.encode() {
74            let _ = conn
75                .send(Message::Extended {
76                    id: 0,
77                    payload: handshake_bytes,
78                })
79                .await;
80        } else {
81            return None;
82        }
83
84        let mut metadata_size = 0;
85        let mut remote_ut_metadata_id = 0;
86        let mut pieces: BTreeMap<u32, Bytes> = BTreeMap::new();
87        let mut request_sent = false;
88
89        let result = timeout(self.timeout, async {
90            loop {
91                let msg = conn.receive().await.ok()?;
92                if let Message::Extended { id, payload } = msg {
93                    if id == 0 {
94                            if let Ok(ExtensionMessage::Handshake(remote_hs)) = ExtensionMessage::decode(id, &payload) {
95                                if let Some(size) = remote_hs.metadata_size {
96                                    metadata_size = size as u32;
97                                }
98                                if let Some(ext_id) = remote_hs.get_extension_id("ut_metadata") {
99                                    remote_ut_metadata_id = ext_id;
100                                }
101                            }
102                            if metadata_size > 0 && remote_ut_metadata_id > 0 && !request_sent {
103                                if metadata_size > 10 * 1024 * 1024 {
104                                    #[cfg(feature = "metrics")]
105                                    counter!("dht_metadata_fetch_fail_total", "reason" => "size_limit").increment(1);
106                                    return None;
107                                }
108
109                                let count = metadata_piece_count(metadata_size as usize);
110                                for i in 0..count {
111                                    let req = MetadataMessage::request(i as u32);
112                                    if let Ok(encoded) = req.encode() {
113                                        let _ = conn.send(Message::Extended { id: remote_ut_metadata_id, payload: encoded }).await;
114                                    }
115                                }
116                                request_sent = true;
117                            }
118                        } else if id == my_ut_metadata_id {
119                            if let Ok(meta_msg) = MetadataMessage::decode(&payload)
120                                && meta_msg.msg_type == MetadataMessageType::Data
121                                && let Some(data) = meta_msg.data {
122                                #[cfg(feature = "metrics")]
123                                counter!("dht_metadata_bytes_downloaded_total").increment(data.len() as u64);
124                                pieces.insert(meta_msg.piece, data);
125                            }
126                            if metadata_size > 0 {
127                                let total_received: usize = pieces.values().map(|p| p.len()).sum();
128                                if total_received >= metadata_size as usize {
129                                    let mut full_data = Vec::with_capacity(metadata_size as usize);
130                                    let count = metadata_piece_count(metadata_size as usize);
131                                    let mut success = true;
132                                    for i in 0..count {
133                                        if let Some(p) = pieces.get(&(i as u32)) {
134                                            full_data.extend_from_slice(p);
135                                        } else {
136                                            success = false; break;
137                                        }
138                                    }
139                                    if success {
140                                        let info_hash_copy = *info_hash;
141                                        let validated = tokio::task::spawn_blocking(move || {
142                                            let mut hasher = Sha1::new();
143                                            hasher.update(&full_data);
144                                            let digest: [u8; 20] = hasher.finalize().into();
145                                            if digest == info_hash_copy {
146                                                Some(full_data)
147                                            } else {
148                                                None
149                                            }
150                                        }).await.unwrap_or(None);
151
152                                        if validated.is_some() {
153                                            #[cfg(feature = "metrics")]
154                                            counter!("dht_metadata_handshake_result_total", "result" => "success").increment(1);
155                                            return validated;
156                                        }
157                                        #[cfg(feature = "metrics")]
158                                        counter!("dht_metadata_fetch_fail_total", "reason" => "sha1_mismatch").increment(1);
159                                        return None;
160                                    }
161                                }
162                            }
163                        }
164                    }
165            }
166        }).await;
167
168        match result {
169            Ok(Some(info_bytes)) => {
170                if let Ok(value) = rbit::decode(&info_bytes)
171                    && let Some(dict) = value.as_dict()
172                {
173                    let name = dict
174                        .get(&b"name"[..])
175                        .and_then(|v| v.as_str())
176                        .unwrap_or("Unknown")
177                        .to_string();
178                    let piece_length = dict
179                        .get(&b"piece length"[..])
180                        .and_then(|v| v.as_integer())
181                        .unwrap_or(0) as u64;
182                    let mut total_size = 0;
183                    let mut file_list = Vec::new();
184                    if let Some(files) = dict.get(&b"files"[..]).and_then(|v| v.as_list()) {
185                        for file in files {
186                            if let Some(f_dict) = file.as_dict()
187                                && let Some(len) =
188                                    f_dict.get(&b"length"[..]).and_then(|v| v.as_integer())
189                            {
190                                let len = len as u64;
191                                total_size += len;
192                                let mut path_parts = Vec::new();
193                                if let Some(path_list) =
194                                    f_dict.get(&b"path"[..]).and_then(|v| v.as_list())
195                                {
196                                    for p in path_list {
197                                        if let Some(p_str) = p.as_str() {
198                                            path_parts.push(p_str);
199                                        }
200                                    }
201                                }
202                                file_list.push(FileInfo {
203                                    path: path_parts.join("/"),
204                                    size: len,
205                                });
206                            }
207                        }
208                    } else if let Some(len) = dict.get(&b"length"[..]).and_then(|v| v.as_integer())
209                    {
210                        total_size = len as u64;
211                        file_list.push(FileInfo {
212                            path: name.clone(),
213                            size: total_size,
214                        });
215                    }
216                    if total_size > 0 {
217                        #[cfg(feature = "metrics")]
218                        {
219                            counter!("dht_metadata_fetch_success_total").increment(1);
220                            histogram!("dht_metadata_size_bytes").record(total_size as f64);
221                        }
222                        return Some((name, total_size, file_list, piece_length));
223                    }
224                }
225                #[cfg(feature = "metrics")]
226                counter!("dht_metadata_fetch_fail_total", "reason" => "parse_error").increment(1);
227                None
228            }
229            _ => {
230                #[cfg(feature = "metrics")]
231                counter!("dht_metadata_fetch_fail_total", "reason" => "timeout").increment(1);
232                None
233            }
234        }
235    }
236}