Skip to main content

dht_crawler/
metadata.rs

1use crate::types::FileInfo;
2use bytes::Bytes;
3#[cfg(feature = "metrics")]
4use metrics::{counter, histogram};
5use rbit::peer::ExtensionMessage;
6use rbit::{
7    ExtensionHandshake, Message, MetadataMessage, MetadataMessageType, PeerConnection, PeerId,
8    metadata_piece_count,
9};
10use sha1::{Digest, Sha1};
11use std::collections::BTreeMap;
12use std::net::SocketAddr;
13use std::time::Duration;
14use tokio::time::timeout;
15
/// Fetches torrent metadata (the bencoded info dictionary) from individual peers
/// via the `ut_metadata` extension.
#[derive(Clone, Debug)]
pub struct RbitFetcher {
    /// Upper bound on the metadata exchange with one peer, measured after the
    /// connection has been established and our extension handshake sent.
    timeout: Duration,
}
20
21impl RbitFetcher {
22    pub fn new(timeout_secs: u64) -> Self {
23        Self {
24            timeout: Duration::from_secs(if timeout_secs == 0 { 15 } else { timeout_secs }),
25        }
26    }
27
28    pub async fn fetch(
29        &self,
30        info_hash: &[u8; 20],
31        peer_addr: SocketAddr,
32    ) -> Option<(String, u64, Vec<FileInfo>)> {
33        #[cfg(feature = "metrics")]
34        counter!("dht_metadata_fetch_attempts_total").increment(1);
35
36        let peer_id = PeerId::generate();
37
38        let mut conn = match timeout(
39            Duration::from_secs(3),
40            PeerConnection::connect(peer_addr, *info_hash, *peer_id.as_bytes()),
41        )
42        .await
43        {
44            Ok(Ok(c)) => {
45                #[cfg(feature = "metrics")]
46                counter!("dht_metadata_connection_result_total", "result" => "success")
47                    .increment(1);
48                c
49            }
50            Ok(Err(_)) => {
51                #[cfg(feature = "metrics")]
52                counter!("dht_metadata_connection_result_total", "result" => "failed").increment(1);
53                return None;
54            }
55            Err(_) => {
56                #[cfg(feature = "metrics")]
57                counter!("dht_metadata_connection_result_total", "result" => "timeout")
58                    .increment(1);
59                return None;
60            }
61        };
62
63        if !conn.supports_extension {
64            #[cfg(feature = "metrics")]
65            counter!("dht_metadata_handshake_result_total", "result" => "no_extension_support")
66                .increment(1);
67            return None;
68        }
69
70        let my_ut_metadata_id = 1;
71        let handshake = ExtensionHandshake::with_extensions(&[("ut_metadata", my_ut_metadata_id)]);
72
73        if let Ok(handshake_bytes) = handshake.encode() {
74            let _ = conn
75                .send(Message::Extended {
76                    id: 0,
77                    payload: handshake_bytes,
78                })
79                .await;
80        } else {
81            return None;
82        }
83
84        let mut metadata_size = 0;
85        let mut remote_ut_metadata_id = 0;
86        let mut pieces: BTreeMap<u32, Bytes> = BTreeMap::new();
87        let mut request_sent = false;
88
89        let result = timeout(self.timeout, async {
90            loop {
91                let msg = conn.receive().await.ok()?;
92                if let Message::Extended { id, payload } = msg {
93                    if id == 0 {
94                            if let Ok(ExtensionMessage::Handshake(remote_hs)) = ExtensionMessage::decode(id, &payload) {
95                                if let Some(size) = remote_hs.metadata_size {
96                                    metadata_size = size as u32;
97                                }
98                                if let Some(ext_id) = remote_hs.get_extension_id("ut_metadata") {
99                                    remote_ut_metadata_id = ext_id;
100                                }
101                            }
102                            if metadata_size > 0 && remote_ut_metadata_id > 0 && !request_sent {
103                                if metadata_size > 10 * 1024 * 1024 {
104                                    #[cfg(feature = "metrics")]
105                                    counter!("dht_metadata_fetch_fail_total", "reason" => "size_limit").increment(1);
106                                    return None;
107                                }
108
109                                let count = metadata_piece_count(metadata_size as usize);
110                                for i in 0..count {
111                                    let req = MetadataMessage::request(i as u32);
112                                    if let Ok(encoded) = req.encode() {
113                                        let _ = conn.send(Message::Extended { id: remote_ut_metadata_id, payload: encoded }).await;
114                                    }
115                                }
116                                request_sent = true;
117                            }
118                        } else if id == my_ut_metadata_id {
119                            if let Ok(meta_msg) = MetadataMessage::decode(&payload)
120                                && meta_msg.msg_type == MetadataMessageType::Data
121                                && let Some(data) = meta_msg.data {
122                                #[cfg(feature = "metrics")]
123                                counter!("dht_metadata_bytes_downloaded_total").increment(data.len() as u64);
124                                pieces.insert(meta_msg.piece, data);
125                            }
126                            if metadata_size > 0 {
127                                let total_received: usize = pieces.values().map(|p| p.len()).sum();
128                                if total_received >= metadata_size as usize {
129                                    let mut full_data = Vec::with_capacity(metadata_size as usize);
130                                    let count = metadata_piece_count(metadata_size as usize);
131                                    let mut success = true;
132                                    for i in 0..count {
133                                        if let Some(p) = pieces.get(&(i as u32)) {
134                                            full_data.extend_from_slice(p);
135                                        } else {
136                                            success = false; break;
137                                        }
138                                    }
139                                    if success {
140                                        let info_hash_copy = *info_hash;
141                                        let full_data_clone = full_data.clone();
142                                        let is_valid = tokio::task::spawn_blocking(move || {
143                                            let mut hasher = Sha1::new();
144                                            hasher.update(&full_data_clone);
145                                            let digest: [u8; 20] = hasher.finalize().into();
146                                            digest == info_hash_copy
147                                        }).await.unwrap_or(false);
148
149                                        if is_valid {
150                                            #[cfg(feature = "metrics")]
151                                            counter!("dht_metadata_handshake_result_total", "result" => "success").increment(1);
152                                            return Some(full_data);
153                                        }
154                                        #[cfg(feature = "metrics")]
155                                        counter!("dht_metadata_fetch_fail_total", "reason" => "sha1_mismatch").increment(1);
156                                        return None;
157                                    }
158                                }
159                            }
160                        }
161                    }
162            }
163        }).await;
164
165        match result {
166            Ok(Some(info_bytes)) => {
167                if let Ok(value) = rbit::decode(&info_bytes)
168                    && let Some(dict) = value.as_dict() {
169                        let name = dict
170                            .get(&b"name"[..])
171                            .and_then(|v| v.as_str())
172                            .unwrap_or("Unknown")
173                            .to_string();
174                        let mut total_size = 0;
175                        let mut file_list = Vec::new();
176                        if let Some(files) = dict.get(&b"files"[..]).and_then(|v| v.as_list()) {
177                            for file in files {
178                                if let Some(f_dict) = file.as_dict()
179                                    && let Some(len) = f_dict.get(&b"length"[..]).and_then(|v| v.as_integer()) {
180                                    let len = len as u64;
181                                    total_size += len;
182                                    let mut path_parts = Vec::new();
183                                    if let Some(path_list) =
184                                        f_dict.get(&b"path"[..]).and_then(|v| v.as_list())
185                                    {
186                                        for p in path_list {
187                                            if let Some(p_str) = p.as_str() {
188                                                path_parts.push(p_str);
189                                            }
190                                        }
191                                    }
192                                    file_list.push(FileInfo {
193                                        path: path_parts.join("/"),
194                                        size: len,
195                                    });
196                                }
197                            }
198                        } else if let Some(len) =
199                            dict.get(&b"length"[..]).and_then(|v| v.as_integer())
200                        {
201                            total_size = len as u64;
202                            file_list.push(FileInfo {
203                                path: name.clone(),
204                                size: total_size,
205                            });
206                        }
207                        if total_size > 0 {
208                            #[cfg(feature = "metrics")]
209                            {
210                                counter!("dht_metadata_fetch_success_total").increment(1);
211                                histogram!("dht_metadata_size_bytes").record(total_size as f64);
212                            }
213                            return Some((name, total_size, file_list));
214                        }
215                    }
216                #[cfg(feature = "metrics")]
217                counter!("dht_metadata_fetch_fail_total", "reason" => "parse_error").increment(1);
218                None
219            }
220            _ => {
221                #[cfg(feature = "metrics")]
222                counter!("dht_metadata_fetch_fail_total", "reason" => "timeout").increment(1);
223                None
224            }
225        }
226    }
227}