1use bloomfilter::Bloom;
5use std::collections::{HashSet, VecDeque};
6use std::net::SocketAddr;
7use std::sync::Mutex;
8use std::sync::atomic::{AtomicUsize, Ordering};
9
10const BLOOM_SHARD_COUNT: usize = 32; const QUEUE_SHARD_COUNT: usize = 16; pub struct ShardedBloom {
21 shards: Vec<Mutex<Bloom<[u8; 20]>>>,
22 count: AtomicUsize,
23}
24
25impl ShardedBloom {
26 pub fn new_for_fp_rate(expected_items: usize, fp_rate: f64) -> Self {
28 let items_per_shard = (expected_items + BLOOM_SHARD_COUNT - 1) / BLOOM_SHARD_COUNT;
29
30 let shards = (0..BLOOM_SHARD_COUNT)
31 .map(|_| Mutex::new(Bloom::new_for_fp_rate(items_per_shard, fp_rate)))
32 .collect();
33
34 Self {
35 shards,
36 count: AtomicUsize::new(0),
37 }
38 }
39
40 pub fn check_and_set(&self, hash: &[u8; 20]) -> bool {
42 let shard_idx = self.hash_to_shard(hash);
43 let mut shard = self.shards[shard_idx].lock().unwrap();
44 let present = shard.check_and_set(hash);
45
46 if !present {
48 self.count.fetch_add(1, Ordering::Relaxed);
49 }
50 present
51 }
52
53 pub fn number_of_bits(&self) -> u64 {
55 self.count.load(Ordering::Relaxed) as u64
56 }
57
58 #[inline]
60 fn hash_to_shard(&self, hash: &[u8; 20]) -> usize {
61 let idx = (hash[0] as usize) | ((hash[1] as usize) << 8);
63 idx % BLOOM_SHARD_COUNT
64 }
65}
66
/// A node described by its identifier bytes and its socket address.
#[derive(Clone, Debug)]
pub struct NodeTuple {
    /// Raw identifier bytes of the node.
    pub id: Vec<u8>,
    /// Socket address (IP and port) of the node.
    pub addr: SocketAddr,
}
75
76struct NodeQueueShard {
78 queue: VecDeque<NodeTuple>,
79 index: HashSet<SocketAddr>,
80 capacity: usize,
81}
82
83impl NodeQueueShard {
84 fn new(capacity: usize) -> Self {
85 Self {
86 queue: VecDeque::with_capacity(capacity),
87 index: HashSet::with_capacity(capacity),
88 capacity,
89 }
90 }
91
92 fn push(&mut self, node: NodeTuple) {
93 if self.index.contains(&node.addr) {
94 return;
95 }
96
97 if self.queue.len() >= self.capacity {
99 if let Some(removed) = self.queue.pop_front() {
100 self.index.remove(&removed.addr);
101 }
102 }
103
104 self.index.insert(node.addr);
105 self.queue.push_back(node);
106 }
107
108 fn pop_batch(&mut self, count: usize) -> Vec<NodeTuple> {
109 let actual_count = count.min(self.queue.len());
110 let mut nodes = Vec::with_capacity(actual_count);
111
112 for _ in 0..actual_count {
113 if let Some(node) = self.queue.pop_front() {
114 self.index.remove(&node.addr);
115 nodes.push(node);
116 }
117 }
118 nodes
119 }
120
121 fn len(&self) -> usize {
122 self.queue.len()
123 }
124
125 fn is_empty(&self) -> bool {
126 self.queue.is_empty()
127 }
128}
129
130pub struct ShardedNodeQueue {
132 shards_v4: Vec<Mutex<NodeQueueShard>>, shards_v6: Vec<Mutex<NodeQueueShard>>, }
135
136impl ShardedNodeQueue {
137 pub fn new(total_capacity: usize) -> Self {
139 let capacity_per_shard = (total_capacity + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
140
141 let shards_v4 = (0..QUEUE_SHARD_COUNT)
142 .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
143 .collect();
144
145 let shards_v6 = (0..QUEUE_SHARD_COUNT)
146 .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
147 .collect();
148
149 Self { shards_v4, shards_v6 }
150 }
151
152 pub fn push(&self, node: NodeTuple) {
154 let shard_idx = self.addr_to_shard(&node.addr);
155
156 if node.addr.is_ipv6() {
157 let mut shard = self.shards_v6[shard_idx].lock().unwrap();
158 shard.push(node);
159 } else {
160 let mut shard = self.shards_v4[shard_idx].lock().unwrap();
161 shard.push(node);
162 }
163 }
164
165 pub fn pop_batch(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
171 let mut result = Vec::with_capacity(count);
172 let per_shard = (count + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
173
174 match filter_ipv6 {
175 Some(true) => {
176 for shard in &self.shards_v6 {
178 if result.len() >= count {
179 break;
180 }
181 let mut s = shard.lock().unwrap();
182 let nodes = s.pop_batch(per_shard);
183 result.extend(nodes);
184 }
185 },
186 Some(false) => {
187 for shard in &self.shards_v4 {
189 if result.len() >= count {
190 break;
191 }
192 let mut s = shard.lock().unwrap();
193 let nodes = s.pop_batch(per_shard);
194 result.extend(nodes);
195 }
196 },
197 None => {
198 for i in 0..QUEUE_SHARD_COUNT {
200 if result.len() >= count {
201 break;
202 }
203
204 let mut s4 = self.shards_v4[i].lock().unwrap();
206 let nodes4 = s4.pop_batch(per_shard / 2);
207 result.extend(nodes4);
208 drop(s4);
209
210 if result.len() >= count {
211 break;
212 }
213
214 let mut s6 = self.shards_v6[i].lock().unwrap();
216 let nodes6 = s6.pop_batch(per_shard / 2);
217 result.extend(nodes6);
218 drop(s6);
219 }
220 },
221 }
222
223 result
224 }
225
226 pub fn get_random_nodes(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
231 match filter_ipv6 {
232 Some(true) => {
233 self.get_random_nodes_from_shards(&self.shards_v6, count)
235 },
236 Some(false) => {
237 self.get_random_nodes_from_shards(&self.shards_v4, count)
239 },
240 None => {
241 let count_v4 = count / 2;
243 let count_v6 = count - count_v4;
244 let mut result = Vec::with_capacity(count);
245
246 result.extend(self.get_random_nodes_from_shards(&self.shards_v4, count_v4));
247 result.extend(self.get_random_nodes_from_shards(&self.shards_v6, count_v6));
248
249 result
250 },
251 }
252 }
253
254 fn get_random_nodes_from_shards(&self, shards: &[Mutex<NodeQueueShard>], count: usize) -> Vec<NodeTuple> {
256 use rand::Rng;
257 let mut rng = rand::thread_rng();
258
259 if count <= 16 {
261 let mut result = Vec::with_capacity(count);
262 let per_shard = (count + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
263
264 for shard in shards {
265 if result.len() >= count {
266 break;
267 }
268
269 let s = shard.lock().unwrap();
270 let shard_len = s.queue.len();
271
272 if shard_len == 0 {
273 continue;
274 }
275
276 let to_take = per_shard.min(shard_len).min(count - result.len());
278
279 let mut indices: Vec<usize> = (0..shard_len).collect();
281
282 for i in 0..to_take {
284 let j = rng.gen_range(i..shard_len);
285 indices.swap(i, j);
286 }
287
288 for i in 0..to_take {
290 if let Some(node) = s.queue.get(indices[i]) {
291 result.push(node.clone());
292 }
293 }
294 }
295
296 result
297 } else {
298 let mut result = Vec::with_capacity(count);
300 let mut seen = 0usize;
301
302 for shard in shards {
304 let s = shard.lock().unwrap();
305
306 for node in s.queue.iter() {
307 seen += 1;
308
309 if result.len() < count {
310 result.push(node.clone());
312 } else {
313 let j = rng.gen_range(0..seen);
315 if j < count {
316 result[j] = node.clone();
317 }
318 }
319 }
320 }
321
322 result
323 }
324 }
325
326
327 pub fn len(&self) -> usize {
329 let len_v4: usize = self.shards_v4
330 .iter()
331 .map(|shard| shard.lock().unwrap().len())
332 .sum();
333 let len_v6: usize = self.shards_v6
334 .iter()
335 .map(|shard| shard.lock().unwrap().len())
336 .sum();
337 len_v4 + len_v6
338 }
339
340 pub fn is_empty(&self) -> bool {
342 let empty_v4 = self.shards_v4
343 .iter()
344 .all(|shard| shard.lock().unwrap().is_empty());
345 let empty_v6 = self.shards_v6
346 .iter()
347 .all(|shard| shard.lock().unwrap().is_empty());
348 empty_v4 && empty_v6
349 }
350
351 pub fn is_empty_for(&self, filter_ipv6: Option<bool>) -> bool {
356 match filter_ipv6 {
357 Some(true) => {
358 self.shards_v6
360 .iter()
361 .all(|shard| shard.lock().unwrap().is_empty())
362 },
363 Some(false) => {
364 self.shards_v4
366 .iter()
367 .all(|shard| shard.lock().unwrap().is_empty())
368 },
369 None => self.is_empty(),
370 }
371 }
372
373 #[inline]
375 fn addr_to_shard(&self, addr: &SocketAddr) -> usize {
376 let hash = match addr.ip() {
378 std::net::IpAddr::V4(ip) => {
379 let octets = ip.octets();
380 (octets[3] as usize) ^ (addr.port() as usize)
381 }
382 std::net::IpAddr::V6(ip) => {
383 let octets = ip.octets();
384 (octets[15] as usize) ^ (addr.port() as usize)
385 }
386 };
387 hash % QUEUE_SHARD_COUNT
388 }
389}
390