// dht_crawler/sharded.rs

use bloomfilter::Bloom;
use std::collections::{HashSet, VecDeque};
use std::net::SocketAddr;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};

// Power-of-two shard counts: sharding trades a little memory overhead for
// much lower lock contention under concurrent crawler workers.
const BLOOM_SHARD_COUNT: usize = 32;
const QUEUE_SHARD_COUNT: usize = 16;

/// A bloom filter split into independently locked shards so that concurrent
/// insertions rarely contend on the same mutex.
pub struct ShardedBloom {
    shards: Vec<Mutex<Bloom<[u8; 20]>>>,
    count: AtomicUsize,
}

impl ShardedBloom {
    pub fn new_for_fp_rate(expected_items: usize, fp_rate: f64) -> Self {
        // Size each shard for its slice of the expected items, rounding up.
        let items_per_shard = expected_items.div_ceil(BLOOM_SHARD_COUNT);

        let shards = (0..BLOOM_SHARD_COUNT)
            .map(|_| Mutex::new(Bloom::new_for_fp_rate(items_per_shard, fp_rate)))
            .collect();

        Self {
            shards,
            count: AtomicUsize::new(0),
        }
    }

    /// Inserts `hash` and returns `true` if it was already present (which
    /// includes bloom false positives). Newly seen hashes bump the counter.
    pub fn check_and_set(&self, hash: &[u8; 20]) -> bool {
        let shard_idx = self.hash_to_shard(hash);
        let mut shard = self.shards[shard_idx].lock().unwrap();
        let present = shard.check_and_set(hash);

        if !present {
            self.count.fetch_add(1, Ordering::Relaxed);
        }
        present
    }

    /// Approximate number of distinct hashes inserted so far. False positives
    /// in a shard suppress some increments, so this can undercount slightly.
    pub fn item_count(&self) -> u64 {
        self.count.load(Ordering::Relaxed) as u64
    }

    #[inline]
    fn hash_to_shard(&self, hash: &[u8; 20]) -> usize {
        // XOR the first two bytes so both influence the result. Because
        // BLOOM_SHARD_COUNT is a small power of two, any bits shifted above
        // the low byte would vanish under the modulo.
        let idx = (hash[0] ^ hash[1]) as usize;
        idx % BLOOM_SHARD_COUNT
    }
}
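
// A minimal usage sketch, not part of the crawler itself: it assumes the
// `bloomfilter` crate's `check_and_set` semantics (returns `true` when the
// item was probably already present) and exercises the wrapper end to end.
#[cfg(test)]
mod bloom_sketch {
    use super::*;

    #[test]
    fn check_and_set_reports_prior_membership() {
        let bloom = ShardedBloom::new_for_fp_rate(1_000, 0.01);
        let hash = [0xAB_u8; 20];

        // First sighting: not present yet (modulo false positives).
        assert!(!bloom.check_and_set(&hash));
        // Second sighting: the owning shard has recorded it.
        assert!(bloom.check_and_set(&hash));
        assert_eq!(bloom.item_count(), 1);
    }
}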

/// A discovered DHT node: its (typically 20-byte) node ID and UDP address.
#[derive(Debug, Clone)]
pub struct NodeTuple {
    pub id: Vec<u8>,
    pub addr: SocketAddr,
}

/// One bounded FIFO shard; `index` mirrors the queued addresses for O(1)
/// duplicate checks.
struct NodeQueueShard {
    queue: VecDeque<NodeTuple>,
    index: HashSet<SocketAddr>,
    capacity: usize,
}

impl NodeQueueShard {
    fn new(capacity: usize) -> Self {
        Self {
            queue: VecDeque::with_capacity(capacity),
            index: HashSet::with_capacity(capacity),
            capacity,
        }
    }

    fn push(&mut self, node: NodeTuple) {
        // Drop duplicates of an address that is already queued.
        if self.index.contains(&node.addr) {
            return;
        }

        // At capacity: evict the oldest entry and keep the index in sync.
        if self.queue.len() >= self.capacity
            && let Some(removed) = self.queue.pop_front()
        {
            self.index.remove(&removed.addr);
        }

        self.index.insert(node.addr);
        self.queue.push_back(node);
    }

    fn pop_batch(&mut self, count: usize) -> Vec<NodeTuple> {
        let actual_count = count.min(self.queue.len());
        let mut nodes = Vec::with_capacity(actual_count);

        for _ in 0..actual_count {
            if let Some(node) = self.queue.pop_front() {
                self.index.remove(&node.addr);
                nodes.push(node);
            }
        }
        nodes
    }

    fn len(&self) -> usize {
        self.queue.len()
    }

    fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}
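
// A hedged sketch of the shard's contract: duplicates are ignored and, at
// capacity, the oldest entry is evicted FIFO-style. The addresses are made
// up for illustration.
#[cfg(test)]
mod shard_sketch {
    use super::*;

    #[test]
    fn push_dedups_and_evicts_oldest() {
        let mut shard = NodeQueueShard::new(2);
        let node = |port: u16| NodeTuple {
            id: vec![0u8; 20],
            addr: format!("10.0.0.1:{port}").parse().unwrap(),
        };

        shard.push(node(1));
        shard.push(node(1)); // duplicate address: ignored
        assert_eq!(shard.len(), 1);

        shard.push(node(2));
        shard.push(node(3)); // at capacity: evicts the node on port 1

        let ports: Vec<u16> = shard
            .pop_batch(10)
            .iter()
            .map(|n| n.addr.port())
            .collect();
        assert_eq!(ports, vec![2, 3]);
        assert!(shard.is_empty());
    }
}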

/// Node queues sharded twice: by address family (v4/v6) and then by a cheap
/// address hash, so producers and consumers spread across many locks.
pub struct ShardedNodeQueue {
    shards_v4: Vec<Mutex<NodeQueueShard>>,
    shards_v6: Vec<Mutex<NodeQueueShard>>,
}

impl ShardedNodeQueue {
    pub fn new(total_capacity: usize) -> Self {
        // Each family gets its own shard set, each shard sized for its slice
        // of the requested capacity, rounded up.
        let capacity_per_shard = total_capacity.div_ceil(QUEUE_SHARD_COUNT);

        let shards_v4 = (0..QUEUE_SHARD_COUNT)
            .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
            .collect();

        let shards_v6 = (0..QUEUE_SHARD_COUNT)
            .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
            .collect();

        Self {
            shards_v4,
            shards_v6,
        }
    }

    pub fn push(&self, node: NodeTuple) {
        // Route by address family first, then by shard index within it.
        let shard_idx = self.addr_to_shard(&node.addr);

        if node.addr.is_ipv6() {
            let mut shard = self.shards_v6[shard_idx].lock().unwrap();
            shard.push(node);
        } else {
            let mut shard = self.shards_v4[shard_idx].lock().unwrap();
            shard.push(node);
        }
    }

    pub fn pop_batch(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
        let mut result = Vec::with_capacity(count);
        let per_shard = count.div_ceil(QUEUE_SHARD_COUNT);

        match filter_ipv6 {
            Some(true) => {
                for shard in &self.shards_v6 {
                    if result.len() >= count {
                        break;
                    }
                    let mut s = shard.lock().unwrap();
                    // Never take more than the batch still needs.
                    let take = per_shard.min(count - result.len());
                    result.extend(s.pop_batch(take));
                }
            }
            Some(false) => {
                for shard in &self.shards_v4 {
                    if result.len() >= count {
                        break;
                    }
                    let mut s = shard.lock().unwrap();
                    let take = per_shard.min(count - result.len());
                    result.extend(s.pop_batch(take));
                }
            }
            None => {
                // Interleave the v4 and v6 shards at each index. Round the
                // per-family quota up so it stays at least 1: truncating
                // division would yield 0 whenever count <= QUEUE_SHARD_COUNT,
                // and the mixed batch would pop nothing at all.
                let half = per_shard.div_ceil(2);
                for i in 0..QUEUE_SHARD_COUNT {
                    if result.len() >= count {
                        break;
                    }

                    let take = half.min(count - result.len());
                    let mut s4 = self.shards_v4[i].lock().unwrap();
                    result.extend(s4.pop_batch(take));
                    drop(s4);

                    if result.len() >= count {
                        break;
                    }

                    let take = half.min(count - result.len());
                    let mut s6 = self.shards_v6[i].lock().unwrap();
                    result.extend(s6.pop_batch(take));
                    drop(s6);
                }
            }
        }

        result
    }

    pub fn get_random_nodes(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
        match filter_ipv6 {
            Some(true) => self.get_random_nodes_from_shards(&self.shards_v6, count),
            Some(false) => self.get_random_nodes_from_shards(&self.shards_v4, count),
            None => {
                let count_v4 = count / 2;
                let count_v6 = count - count_v4;
                let mut result = Vec::with_capacity(count);

                result.extend(self.get_random_nodes_from_shards(&self.shards_v4, count_v4));
                result.extend(self.get_random_nodes_from_shards(&self.shards_v6, count_v6));

                result
            }
        }
    }

    /// Samples up to `count` nodes without removing them from the queues.
    fn get_random_nodes_from_shards(
        &self,
        shards: &[Mutex<NodeQueueShard>],
        count: usize,
    ) -> Vec<NodeTuple> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        if count <= 16 {
            // Small requests: take a few uniformly chosen entries from each
            // shard via a partial Fisher-Yates shuffle of the indices.
            let mut result = Vec::with_capacity(count);
            let per_shard = count.div_ceil(QUEUE_SHARD_COUNT);

            for shard in shards {
                if result.len() >= count {
                    break;
                }

                let s = shard.lock().unwrap();
                let shard_len = s.queue.len();

                if shard_len == 0 {
                    continue;
                }

                let to_take = per_shard.min(shard_len).min(count - result.len());

                // Shuffle only the first `to_take` positions; the rest of the
                // index vector never needs to be permuted.
                let mut indices: Vec<usize> = (0..shard_len).collect();
                for i in 0..to_take {
                    let j = rng.gen_range(i..shard_len);
                    indices.swap(i, j);
                }

                for &idx in indices.iter().take(to_take) {
                    if let Some(node) = s.queue.get(idx) {
                        result.push(node.clone());
                    }
                }
            }

            result
        } else {
            // Large requests: reservoir sampling (Algorithm R) across all
            // shards, giving every queued node an equal chance of selection.
            let mut result = Vec::with_capacity(count);
            let mut seen = 0usize;

            for shard in shards {
                let s = shard.lock().unwrap();

                for node in s.queue.iter() {
                    seen += 1;

                    if result.len() < count {
                        result.push(node.clone());
                    } else {
                        let j = rng.gen_range(0..seen);
                        if j < count {
                            result[j] = node.clone();
                        }
                    }
                }
            }

            result
        }
    }

    pub fn len(&self) -> usize {
        let len_v4: usize = self
            .shards_v4
            .iter()
            .map(|shard| shard.lock().unwrap().len())
            .sum();
        let len_v6: usize = self
            .shards_v6
            .iter()
            .map(|shard| shard.lock().unwrap().len())
            .sum();
        len_v4 + len_v6
    }

    pub fn is_empty(&self) -> bool {
        let empty_v4 = self
            .shards_v4
            .iter()
            .all(|shard| shard.lock().unwrap().is_empty());
        let empty_v6 = self
            .shards_v6
            .iter()
            .all(|shard| shard.lock().unwrap().is_empty());
        empty_v4 && empty_v6
    }

    pub fn is_empty_for(&self, filter_ipv6: Option<bool>) -> bool {
        match filter_ipv6 {
            Some(true) => self
                .shards_v6
                .iter()
                .all(|shard| shard.lock().unwrap().is_empty()),
            Some(false) => self
                .shards_v4
                .iter()
                .all(|shard| shard.lock().unwrap().is_empty()),
            None => self.is_empty(),
        }
    }

    #[inline]
    fn addr_to_shard(&self, addr: &SocketAddr) -> usize {
        // Shard choice only spreads lock contention, so a cheap mix of the
        // last IP octet and the port is good enough here.
        let hash = match addr.ip() {
            std::net::IpAddr::V4(ip) => {
                let octets = ip.octets();
                (octets[3] as usize) ^ (addr.port() as usize)
            }
            std::net::IpAddr::V6(ip) => {
                let octets = ip.octets();
                (octets[15] as usize) ^ (addr.port() as usize)
            }
        };
        hash % QUEUE_SHARD_COUNT
    }
}
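
// An end-to-end sketch of the public queue API: mixed v4/v6 pushes land in
// separate shard sets, family filters see only their own set, and a mixed
// `pop_batch` drains both. All addresses below are illustrative.
#[cfg(test)]
mod queue_sketch {
    use super::*;

    #[test]
    fn push_and_pop_across_families() {
        let queue = ShardedNodeQueue::new(1_024);

        for port in 1u16..=8 {
            queue.push(NodeTuple {
                id: vec![port as u8; 20],
                addr: format!("192.0.2.1:{port}").parse().unwrap(),
            });
            queue.push(NodeTuple {
                id: vec![port as u8; 20],
                addr: format!("[2001:db8::1]:{port}").parse().unwrap(),
            });
        }
        assert_eq!(queue.len(), 16);
        assert!(!queue.is_empty_for(Some(true)));
        assert!(!queue.is_empty_for(Some(false)));

        // A family filter returns only nodes from that shard set.
        let v4_only = queue.pop_batch(4, Some(false));
        assert!(v4_only.iter().all(|n| !n.addr.is_ipv6()));

        // A mixed pop drains whatever remains from both families.
        let mixed = queue.pop_batch(100, None);
        assert_eq!(v4_only.len() + mixed.len(), 16);
        assert!(queue.is_empty());
    }
}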