dht_crawler/
sharded.rs

1// 分片锁实现 - 大幅减少锁竞争,提升并发性能
2//
3// 核心思想:1个大锁 → N个小锁
4use bloomfilter::Bloom;
5use std::collections::{HashSet, VecDeque};
6use std::net::SocketAddr;
7use std::sync::Mutex;
8use std::sync::atomic::{AtomicUsize, Ordering};
9
10// 配置:分片数量
11const BLOOM_SHARD_COUNT: usize = 32;  // 32个布隆过滤器分片
12const QUEUE_SHARD_COUNT: usize = 16;  // 16个队列分片
13
14// ==================== 分片布隆过滤器 ====================
15
16/// 分片布隆过滤器 - 减少锁竞争
17/// 
18/// 将单个布隆过滤器拆分为32个分片,每个分片独立锁
19/// 不同的hash会落到不同的分片上,大幅减少竞争
20pub struct ShardedBloom {
21    shards: Vec<Mutex<Bloom<[u8; 20]>>>,
22    count: AtomicUsize,
23}
24
25impl ShardedBloom {
26    /// 创建新的分片布隆过滤器
27    pub fn new_for_fp_rate(expected_items: usize, fp_rate: f64) -> Self {
28        let items_per_shard = (expected_items + BLOOM_SHARD_COUNT - 1) / BLOOM_SHARD_COUNT;
29        
30        let shards = (0..BLOOM_SHARD_COUNT)
31            .map(|_| Mutex::new(Bloom::new_for_fp_rate(items_per_shard, fp_rate)))
32            .collect();
33        
34        Self { 
35            shards,
36            count: AtomicUsize::new(0),
37        }
38    }
39    
40    /// 检查并设置元素(原子操作)
41    pub fn check_and_set(&self, hash: &[u8; 20]) -> bool {
42        let shard_idx = self.hash_to_shard(hash);
43        let mut shard = self.shards[shard_idx].lock().unwrap();
44        let present = shard.check_and_set(hash);
45        
46        // 如果之前不存在,增加计数
47        if !present {
48            self.count.fetch_add(1, Ordering::Relaxed);
49        }
50        present
51    }
52    
53    /// 获取实际发现的唯一 InfoHash 数量
54    pub fn number_of_bits(&self) -> u64 {
55        self.count.load(Ordering::Relaxed) as u64
56    }
57    
58    /// 根据hash计算分片索引
59    #[inline]
60    fn hash_to_shard(&self, hash: &[u8; 20]) -> usize {
61        // 使用hash的前两个字节计算分片
62        let idx = (hash[0] as usize) | ((hash[1] as usize) << 8);
63        idx % BLOOM_SHARD_COUNT
64    }
65}
66
67// ==================== 分片节点队列 ====================
68
69/// 节点信息
70#[derive(Debug, Clone)]
71pub struct NodeTuple {
72    pub id: Vec<u8>,
73    pub addr: SocketAddr,
74}
75
76/// 单个队列分片
77struct NodeQueueShard {
78    queue: VecDeque<NodeTuple>,
79    index: HashSet<SocketAddr>,
80    capacity: usize,
81}
82
83impl NodeQueueShard {
84    fn new(capacity: usize) -> Self {
85        Self {
86            queue: VecDeque::with_capacity(capacity),
87            index: HashSet::with_capacity(capacity),
88            capacity,
89        }
90    }
91    
92    fn push(&mut self, node: NodeTuple) {
93        if self.index.contains(&node.addr) {
94            return;
95        }
96
97        // 如果满了,移除最早的一个(保持流动性,优胜劣汰)
98        if self.queue.len() >= self.capacity {
99            if let Some(removed) = self.queue.pop_front() {
100                self.index.remove(&removed.addr);
101            }
102        }
103
104        self.index.insert(node.addr);
105        self.queue.push_back(node);
106    }
107    
108    fn pop_batch(&mut self, count: usize) -> Vec<NodeTuple> {
109        let actual_count = count.min(self.queue.len());
110        let mut nodes = Vec::with_capacity(actual_count);
111        
112        for _ in 0..actual_count {
113            if let Some(node) = self.queue.pop_front() {
114                self.index.remove(&node.addr);
115                nodes.push(node);
116            }
117        }
118        nodes
119    }
120    
121    fn len(&self) -> usize {
122        self.queue.len()
123    }
124    
125    fn is_empty(&self) -> bool {
126        self.queue.is_empty()
127    }
128}
129
130/// 分片节点队列 - 支持高并发,IPv4 和 IPv6 节点分开存储
131pub struct ShardedNodeQueue {
132    shards_v4: Vec<Mutex<NodeQueueShard>>,  // IPv4 节点分片
133    shards_v6: Vec<Mutex<NodeQueueShard>>,  // IPv6 节点分片
134}
135
136impl ShardedNodeQueue {
137    /// 创建新的分片队列
138    pub fn new(total_capacity: usize) -> Self {
139        let capacity_per_shard = (total_capacity + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
140        
141        let shards_v4 = (0..QUEUE_SHARD_COUNT)
142            .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
143            .collect();
144        
145        let shards_v6 = (0..QUEUE_SHARD_COUNT)
146            .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
147            .collect();
148        
149        Self { shards_v4, shards_v6 }
150    }
151    
152    /// 添加节点(根据地址类型自动存入对应队列)
153    pub fn push(&self, node: NodeTuple) {
154        let shard_idx = self.addr_to_shard(&node.addr);
155        
156        if node.addr.is_ipv6() {
157            let mut shard = self.shards_v6[shard_idx].lock().unwrap();
158            shard.push(node);
159        } else {
160            let mut shard = self.shards_v4[shard_idx].lock().unwrap();
161            shard.push(node);
162        }
163    }
164    
165    /// 批量弹出节点
166    /// 
167    /// # Arguments
168    /// * `count` - 需要获取的节点数量
169    /// * `filter_ipv6` - 如果为 `Some(true)`,只从 IPv6 队列获取;如果为 `Some(false)`,只从 IPv4 队列获取;如果为 `None`,从两个队列混合获取
170    pub fn pop_batch(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
171        let mut result = Vec::with_capacity(count);
172        let per_shard = (count + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
173        
174        match filter_ipv6 {
175            Some(true) => {
176                // 只从 IPv6 队列获取
177                for shard in &self.shards_v6 {
178                    if result.len() >= count {
179                        break;
180                    }
181                    let mut s = shard.lock().unwrap();
182                    let nodes = s.pop_batch(per_shard);
183                    result.extend(nodes);
184                }
185            },
186            Some(false) => {
187                // 只从 IPv4 队列获取
188                for shard in &self.shards_v4 {
189                    if result.len() >= count {
190                        break;
191                    }
192                    let mut s = shard.lock().unwrap();
193                    let nodes = s.pop_batch(per_shard);
194                    result.extend(nodes);
195                }
196            },
197            None => {
198                // 混合模式:从两个队列交替获取
199                for i in 0..QUEUE_SHARD_COUNT {
200                    if result.len() >= count {
201                        break;
202                    }
203                    
204                    // 从 IPv4 分片获取
205                    let mut s4 = self.shards_v4[i].lock().unwrap();
206                    let nodes4 = s4.pop_batch(per_shard / 2);
207                    result.extend(nodes4);
208                    drop(s4);
209                    
210                    if result.len() >= count {
211                        break;
212                    }
213                    
214                    // 从 IPv6 分片获取
215                    let mut s6 = self.shards_v6[i].lock().unwrap();
216                    let nodes6 = s6.pop_batch(per_shard / 2);
217                    result.extend(nodes6);
218                    drop(s6);
219                }
220            },
221        }
222        
223        result
224    }
225    
226    /// 获取随机节点(用于DHT响应)
227    /// # Arguments
228    /// * `count` - 需要获取的节点数量
229    /// * `filter_ipv6` - 如果为 `Some(true)`,只返回 IPv6 节点;如果为 `Some(false)`,只返回 IPv4 节点;如果为 `None`,返回所有节点(混合)
230    pub fn get_random_nodes(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
231        match filter_ipv6 {
232            Some(true) => {
233                // 只要 IPv6 节点
234                self.get_random_nodes_from_shards(&self.shards_v6, count)
235            },
236            Some(false) => {
237                // 只要 IPv4 节点
238                self.get_random_nodes_from_shards(&self.shards_v4, count)
239            },
240            None => {
241                // 混合模式:从两个队列各取一半
242                let count_v4 = count / 2;
243                let count_v6 = count - count_v4;
244                let mut result = Vec::with_capacity(count);
245                
246                result.extend(self.get_random_nodes_from_shards(&self.shards_v4, count_v4));
247                result.extend(self.get_random_nodes_from_shards(&self.shards_v6, count_v6));
248                
249                result
250            },
251        }
252    }
253    
254    /// 从指定的分片组中获取随机节点
255    fn get_random_nodes_from_shards(&self, shards: &[Mutex<NodeQueueShard>], count: usize) -> Vec<NodeTuple> {
256        use rand::Rng;
257        let mut rng = rand::thread_rng();
258        
259        // 🚀 策略1:小规模请求用快速路径(最常见:8个节点)
260        if count <= 16 {
261            let mut result = Vec::with_capacity(count);
262            let per_shard = (count + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
263            
264            for shard in shards {
265                if result.len() >= count {
266                    break;
267                }
268                
269                let s = shard.lock().unwrap();
270                let shard_len = s.queue.len();
271                
272                if shard_len == 0 {
273                    continue;
274                }
275                
276                // 从当前分片随机选择最多 per_shard 个节点
277                let to_take = per_shard.min(shard_len).min(count - result.len());
278                
279                // 生成随机索引(不重复)
280                let mut indices: Vec<usize> = (0..shard_len).collect();
281                
282                // 只 shuffle 前 to_take 个(部分 shuffle,Fisher-Yates 优化)
283                for i in 0..to_take {
284                    let j = rng.gen_range(i..shard_len);
285                    indices.swap(i, j);
286                }
287                
288                // 取前 to_take 个索引对应的节点
289                for i in 0..to_take {
290                    if let Some(node) = s.queue.get(indices[i]) {
291                        result.push(node.clone());
292                    }
293                }
294            }
295            
296            result
297        } else {
298            // 🚀 策略2:大规模请求用储层采样
299            let mut result = Vec::with_capacity(count);
300            let mut seen = 0usize;
301            
302            // 储层采样算法
303            for shard in shards {
304                let s = shard.lock().unwrap();
305                
306                for node in s.queue.iter() {
307                    seen += 1;
308                    
309                    if result.len() < count {
310                        // 前 count 个直接加入
311                        result.push(node.clone());
312                    } else {
313                        // 后续以 count/seen 的概率替换
314                        let j = rng.gen_range(0..seen);
315                        if j < count {
316                            result[j] = node.clone();
317                        }
318                    }
319                }
320            }
321            
322            result
323        }
324    }
325    
326    
327    /// 获取总长度(IPv4 + IPv6)
328    pub fn len(&self) -> usize {
329        let len_v4: usize = self.shards_v4
330            .iter()
331            .map(|shard| shard.lock().unwrap().len())
332            .sum();
333        let len_v6: usize = self.shards_v6
334            .iter()
335            .map(|shard| shard.lock().unwrap().len())
336            .sum();
337        len_v4 + len_v6
338    }
339    
340    /// 检查是否为空
341    pub fn is_empty(&self) -> bool {
342        let empty_v4 = self.shards_v4
343            .iter()
344            .all(|shard| shard.lock().unwrap().is_empty());
345        let empty_v6 = self.shards_v6
346            .iter()
347            .all(|shard| shard.lock().unwrap().is_empty());
348        empty_v4 && empty_v6
349    }
350    
351    /// 检查指定地址族的队列是否为空
352    /// 
353    /// # Arguments
354    /// * `filter_ipv6` - 如果为 `Some(true)`,检查 IPv6 队列;如果为 `Some(false)`,检查 IPv4 队列;如果为 `None`,检查两个队列
355    pub fn is_empty_for(&self, filter_ipv6: Option<bool>) -> bool {
356        match filter_ipv6 {
357            Some(true) => {
358                // 检查 IPv6 队列
359                self.shards_v6
360                    .iter()
361                    .all(|shard| shard.lock().unwrap().is_empty())
362            },
363            Some(false) => {
364                // 检查 IPv4 队列
365                self.shards_v4
366                    .iter()
367                    .all(|shard| shard.lock().unwrap().is_empty())
368            },
369            None => self.is_empty(),
370        }
371    }
372    
373    /// 根据地址计算分片索引
374    #[inline]
375    fn addr_to_shard(&self, addr: &SocketAddr) -> usize {
376        // 使用端口和IP最后一个字节
377        let hash = match addr.ip() {
378            std::net::IpAddr::V4(ip) => {
379                let octets = ip.octets();
380                (octets[3] as usize) ^ (addr.port() as usize)
381            }
382            std::net::IpAddr::V6(ip) => {
383                let octets = ip.octets();
384                (octets[15] as usize) ^ (addr.port() as usize)
385            }
386        };
387        hash % QUEUE_SHARD_COUNT
388    }
389}
390