1use bloomfilter::Bloom;
2use std::collections::{HashSet, VecDeque};
3use std::net::SocketAddr;
4use std::sync::Mutex;
5use std::sync::atomic::{AtomicUsize, Ordering};
6
/// Number of independently locked shards that make up a `ShardedBloom`.
const BLOOM_SHARD_COUNT: usize = 1 << 5;
/// Number of independently locked shards per address family in a `ShardedNodeQueue`.
const QUEUE_SHARD_COUNT: usize = 1 << 4;
9
10pub struct ShardedBloom {
11 shards: Vec<Mutex<Bloom<[u8; 20]>>>,
12 count: AtomicUsize,
13}
14
15impl ShardedBloom {
16 pub fn new_for_fp_rate(expected_items: usize, fp_rate: f64) -> Self {
17 #[allow(clippy::manual_div_ceil)]
18 let items_per_shard = (expected_items + BLOOM_SHARD_COUNT - 1) / BLOOM_SHARD_COUNT;
19
20 let shards = (0..BLOOM_SHARD_COUNT)
21 .map(|_| Mutex::new(Bloom::new_for_fp_rate(items_per_shard, fp_rate)))
22 .collect();
23
24 Self {
25 shards,
26 count: AtomicUsize::new(0),
27 }
28 }
29
30 pub fn check_and_set(&self, hash: &[u8; 20]) -> bool {
31 let shard_idx = self.hash_to_shard(hash);
32 let mut shard = self.shards[shard_idx].lock().unwrap();
33 let present = shard.check_and_set(hash);
34
35 if !present {
36 self.count.fetch_add(1, Ordering::Relaxed);
37 }
38 present
39 }
40
41 pub fn number_of_bits(&self) -> u64 {
42 self.count.load(Ordering::Relaxed) as u64
43 }
44
45 #[inline]
46 fn hash_to_shard(&self, hash: &[u8; 20]) -> usize {
47 let idx = (hash[0] as usize) | ((hash[1] as usize) << 8);
48 idx % BLOOM_SHARD_COUNT
49 }
50}
51
/// A node identifier paired with its socket address.
///
/// The node queues in this module de-duplicate entries by `addr` alone; `id` is
/// carried along untouched.
#[derive(Clone, Debug)]
pub struct NodeTuple {
    /// Raw identifier bytes; the length is not validated here.
    pub id: Vec<u8>,
    /// Socket address associated with the node.
    pub addr: SocketAddr,
}
57
58struct NodeQueueShard {
59 queue: VecDeque<NodeTuple>,
60 index: HashSet<SocketAddr>,
61 capacity: usize,
62}
63
64impl NodeQueueShard {
65 fn new(capacity: usize) -> Self {
66 Self {
67 queue: VecDeque::with_capacity(capacity),
68 index: HashSet::with_capacity(capacity),
69 capacity,
70 }
71 }
72
73 fn push(&mut self, node: NodeTuple) {
74 if self.index.contains(&node.addr) {
75 return;
76 }
77
78 if self.queue.len() >= self.capacity
79 && let Some(removed) = self.queue.pop_front() {
80 self.index.remove(&removed.addr);
81 }
82
83 self.index.insert(node.addr);
84 self.queue.push_back(node);
85 }
86
87 fn pop_batch(&mut self, count: usize) -> Vec<NodeTuple> {
88 let actual_count = count.min(self.queue.len());
89 let mut nodes = Vec::with_capacity(actual_count);
90
91 for _ in 0..actual_count {
92 if let Some(node) = self.queue.pop_front() {
93 self.index.remove(&node.addr);
94 nodes.push(node);
95 }
96 }
97 nodes
98 }
99
100 fn len(&self) -> usize {
101 self.queue.len()
102 }
103
104 fn is_empty(&self) -> bool {
105 self.queue.is_empty()
106 }
107}
108
109pub struct ShardedNodeQueue {
110 shards_v4: Vec<Mutex<NodeQueueShard>>,
111 shards_v6: Vec<Mutex<NodeQueueShard>>,
112}
113
114impl ShardedNodeQueue {
115 pub fn new(total_capacity: usize) -> Self {
116 #[allow(clippy::manual_div_ceil)]
117 let capacity_per_shard = (total_capacity + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
118
119 let shards_v4 = (0..QUEUE_SHARD_COUNT)
120 .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
121 .collect();
122
123 let shards_v6 = (0..QUEUE_SHARD_COUNT)
124 .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
125 .collect();
126
127 Self {
128 shards_v4,
129 shards_v6,
130 }
131 }
132
133 pub fn push(&self, node: NodeTuple) {
134 let shard_idx = self.addr_to_shard(&node.addr);
135
136 if node.addr.is_ipv6() {
137 let mut shard = self.shards_v6[shard_idx].lock().unwrap();
138 shard.push(node);
139 } else {
140 let mut shard = self.shards_v4[shard_idx].lock().unwrap();
141 shard.push(node);
142 }
143 }
144
145 pub fn pop_batch(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
146 let mut result = Vec::with_capacity(count);
147 #[allow(clippy::manual_div_ceil)]
148 let per_shard = (count + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
149
150 match filter_ipv6 {
151 Some(true) => {
152 for shard in &self.shards_v6 {
153 if result.len() >= count {
154 break;
155 }
156 let mut s = shard.lock().unwrap();
157 let nodes = s.pop_batch(per_shard);
158 result.extend(nodes);
159 }
160 }
161 Some(false) => {
162 for shard in &self.shards_v4 {
163 if result.len() >= count {
164 break;
165 }
166 let mut s = shard.lock().unwrap();
167 let nodes = s.pop_batch(per_shard);
168 result.extend(nodes);
169 }
170 }
171 None => {
172 for i in 0..QUEUE_SHARD_COUNT {
173 if result.len() >= count {
174 break;
175 }
176
177 let mut s4 = self.shards_v4[i].lock().unwrap();
178 let nodes4 = s4.pop_batch(per_shard / 2);
179 result.extend(nodes4);
180 drop(s4);
181
182 if result.len() >= count {
183 break;
184 }
185
186 let mut s6 = self.shards_v6[i].lock().unwrap();
187 let nodes6 = s6.pop_batch(per_shard / 2);
188 result.extend(nodes6);
189 drop(s6);
190 }
191 }
192 }
193
194 result
195 }
196
197 pub fn get_random_nodes(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
198 match filter_ipv6 {
199 Some(true) => self.get_random_nodes_from_shards(&self.shards_v6, count),
200 Some(false) => self.get_random_nodes_from_shards(&self.shards_v4, count),
201 None => {
202 let count_v4 = count / 2;
203 let count_v6 = count - count_v4;
204 let mut result = Vec::with_capacity(count);
205
206 result.extend(self.get_random_nodes_from_shards(&self.shards_v4, count_v4));
207 result.extend(self.get_random_nodes_from_shards(&self.shards_v6, count_v6));
208
209 result
210 }
211 }
212 }
213
214 fn get_random_nodes_from_shards(
215 &self,
216 shards: &[Mutex<NodeQueueShard>],
217 count: usize,
218 ) -> Vec<NodeTuple> {
219 use rand::Rng;
220 let mut rng = rand::thread_rng();
221
222 if count <= 16 {
223 let mut result = Vec::with_capacity(count);
224 #[allow(clippy::manual_div_ceil)]
225 let per_shard = (count + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
226
227 for shard in shards {
228 if result.len() >= count {
229 break;
230 }
231
232 let s = shard.lock().unwrap();
233 let shard_len = s.queue.len();
234
235 if shard_len == 0 {
236 continue;
237 }
238
239 let to_take = per_shard.min(shard_len).min(count - result.len());
240
241 let mut indices: Vec<usize> = (0..shard_len).collect();
242
243 for i in 0..to_take {
244 let j = rng.gen_range(i..shard_len);
245 indices.swap(i, j);
246 }
247
248 for &idx in indices.iter().take(to_take) {
249 if let Some(node) = s.queue.get(idx) {
250 result.push(node.clone());
251 }
252 }
253 }
254
255 result
256 } else {
257 let mut result = Vec::with_capacity(count);
258 let mut seen = 0usize;
259
260 for shard in shards {
261 let s = shard.lock().unwrap();
262
263 for node in s.queue.iter() {
264 seen += 1;
265
266 if result.len() < count {
267 result.push(node.clone());
268 } else {
269 let j = rng.gen_range(0..seen);
270 if j < count {
271 result[j] = node.clone();
272 }
273 }
274 }
275 }
276
277 result
278 }
279 }
280
281 pub fn len(&self) -> usize {
282 let len_v4: usize = self
283 .shards_v4
284 .iter()
285 .map(|shard| shard.lock().unwrap().len())
286 .sum();
287 let len_v6: usize = self
288 .shards_v6
289 .iter()
290 .map(|shard| shard.lock().unwrap().len())
291 .sum();
292 len_v4 + len_v6
293 }
294
295 pub fn is_empty(&self) -> bool {
296 let empty_v4 = self
297 .shards_v4
298 .iter()
299 .all(|shard| shard.lock().unwrap().is_empty());
300 let empty_v6 = self
301 .shards_v6
302 .iter()
303 .all(|shard| shard.lock().unwrap().is_empty());
304 empty_v4 && empty_v6
305 }
306
307 pub fn is_empty_for(&self, filter_ipv6: Option<bool>) -> bool {
308 match filter_ipv6 {
309 Some(true) => self
310 .shards_v6
311 .iter()
312 .all(|shard| shard.lock().unwrap().is_empty()),
313 Some(false) => self
314 .shards_v4
315 .iter()
316 .all(|shard| shard.lock().unwrap().is_empty()),
317 None => self.is_empty(),
318 }
319 }
320
321 #[inline]
322 fn addr_to_shard(&self, addr: &SocketAddr) -> usize {
323 let hash = match addr.ip() {
324 std::net::IpAddr::V4(ip) => {
325 let octets = ip.octets();
326 (octets[3] as usize) ^ (addr.port() as usize)
327 }
328 std::net::IpAddr::V6(ip) => {
329 let octets = ip.octets();
330 (octets[15] as usize) ^ (addr.port() as usize)
331 }
332 };
333 hash % QUEUE_SHARD_COUNT
334 }
335}