1use std::collections::{HashSet, VecDeque};
2use std::net::SocketAddr;
3use std::sync::Mutex;
4
/// Number of shards per address family; each shard sits behind its own `Mutex`.
const QUEUE_SHARD_COUNT: usize = 16;
6
/// A queued node: its raw identifier bytes together with the socket address it is reachable at.
#[derive(Debug, Clone)]
pub struct NodeTuple {
    /// Raw node identifier bytes.
    pub id: Vec<u8>,
    /// Socket address of the node; the queue deduplicates on this value.
    pub addr: SocketAddr,
}
12
13struct NodeQueueShard {
14 queue: VecDeque<NodeTuple>,
15 index: HashSet<SocketAddr>,
16 capacity: usize,
17}
18
19impl NodeQueueShard {
20 fn new(capacity: usize) -> Self {
21 Self {
22 queue: VecDeque::with_capacity(capacity),
23 index: HashSet::with_capacity(capacity),
24 capacity,
25 }
26 }
27
28 fn push(&mut self, node: NodeTuple) {
29 if self.index.contains(&node.addr) {
30 return;
31 }
32
33 if self.queue.len() >= self.capacity
34 && let Some(removed) = self.queue.pop_front()
35 {
36 self.index.remove(&removed.addr);
37 }
38
39 self.index.insert(node.addr);
40 self.queue.push_back(node);
41 }
42
43 fn pop_batch(&mut self, count: usize) -> Vec<NodeTuple> {
44 let actual_count = count.min(self.queue.len());
45 let mut nodes = Vec::with_capacity(actual_count);
46
47 for _ in 0..actual_count {
48 if let Some(node) = self.queue.pop_front() {
49 self.index.remove(&node.addr);
50 nodes.push(node);
51 }
52 }
53 nodes
54 }
55
56 fn len(&self) -> usize {
57 self.queue.len()
58 }
59
60 fn is_empty(&self) -> bool {
61 self.queue.is_empty()
62 }
63}
64
65pub struct ShardedNodeQueue {
66 shards_v4: Vec<Mutex<NodeQueueShard>>,
67 shards_v6: Vec<Mutex<NodeQueueShard>>,
68}
69
70impl ShardedNodeQueue {
71 pub fn new(total_capacity: usize) -> Self {
72 #[allow(clippy::manual_div_ceil)]
73 let capacity_per_shard = (total_capacity + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
74
75 let shards_v4 = (0..QUEUE_SHARD_COUNT)
76 .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
77 .collect();
78
79 let shards_v6 = (0..QUEUE_SHARD_COUNT)
80 .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
81 .collect();
82
83 Self {
84 shards_v4,
85 shards_v6,
86 }
87 }
88
89 pub fn push(&self, node: NodeTuple) {
90 let shard_idx = self.addr_to_shard(&node.addr);
91
92 if node.addr.is_ipv6() {
93 let mut shard = self.shards_v6[shard_idx].lock().unwrap_or_else(|e| e.into_inner());
94 shard.push(node);
95 } else {
96 let mut shard = self.shards_v4[shard_idx].lock().unwrap_or_else(|e| e.into_inner());
97 shard.push(node);
98 }
99 }
100
101 pub fn pop_batch(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
102 let mut result = Vec::with_capacity(count);
103 #[allow(clippy::manual_div_ceil)]
104 let per_shard = (count + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
105
106 match filter_ipv6 {
107 Some(true) => {
108 for shard in &self.shards_v6 {
109 if result.len() >= count {
110 break;
111 }
112 let mut s = shard.lock().unwrap_or_else(|e| e.into_inner());
113 let nodes = s.pop_batch(per_shard);
114 result.extend(nodes);
115 }
116 }
117 Some(false) => {
118 for shard in &self.shards_v4 {
119 if result.len() >= count {
120 break;
121 }
122 let mut s = shard.lock().unwrap_or_else(|e| e.into_inner());
123 let nodes = s.pop_batch(per_shard);
124 result.extend(nodes);
125 }
126 }
127 None => {
128 for i in 0..QUEUE_SHARD_COUNT {
129 if result.len() >= count {
130 break;
131 }
132
133 let mut s4 = self.shards_v4[i].lock().unwrap_or_else(|e| e.into_inner());
134 let nodes4 = s4.pop_batch(per_shard / 2);
135 result.extend(nodes4);
136 drop(s4);
137
138 if result.len() >= count {
139 break;
140 }
141
142 let mut s6 = self.shards_v6[i].lock().unwrap_or_else(|e| e.into_inner());
143 let nodes6 = s6.pop_batch(per_shard / 2);
144 result.extend(nodes6);
145 drop(s6);
146 }
147 }
148 }
149
150 result
151 }
152
153 pub fn get_random_nodes(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
154 match filter_ipv6 {
155 Some(true) => self.get_random_nodes_from_shards(&self.shards_v6, count),
156 Some(false) => self.get_random_nodes_from_shards(&self.shards_v4, count),
157 None => {
158 let count_v4 = count / 2;
159 let count_v6 = count - count_v4;
160 let mut result = Vec::with_capacity(count);
161
162 result.extend(self.get_random_nodes_from_shards(&self.shards_v4, count_v4));
163 result.extend(self.get_random_nodes_from_shards(&self.shards_v6, count_v6));
164
165 result
166 }
167 }
168 }
169
170 fn get_random_nodes_from_shards(
171 &self,
172 shards: &[Mutex<NodeQueueShard>],
173 count: usize,
174 ) -> Vec<NodeTuple> {
175 use rand::Rng;
176 let mut rng = rand::thread_rng();
177
178 if count <= 16 {
179 let mut result = Vec::with_capacity(count);
180 #[allow(clippy::manual_div_ceil)]
181 let per_shard = (count + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
182
183 for shard in shards {
184 if result.len() >= count {
185 break;
186 }
187
188 let s = shard.lock().unwrap_or_else(|e| e.into_inner());
189 let shard_len = s.queue.len();
190
191 if shard_len == 0 {
192 continue;
193 }
194
195 let to_take = per_shard.min(shard_len).min(count - result.len());
196
197 let mut indices: Vec<usize> = (0..shard_len).collect();
198
199 for i in 0..to_take {
200 let j = rng.gen_range(i..shard_len);
201 indices.swap(i, j);
202 }
203
204 for &idx in indices.iter().take(to_take) {
205 if let Some(node) = s.queue.get(idx) {
206 result.push(node.clone());
207 }
208 }
209 }
210
211 result
212 } else {
213 let mut result = Vec::with_capacity(count);
214 let mut seen = 0usize;
215
216 for shard in shards {
217 let s = shard.lock().unwrap_or_else(|e| e.into_inner());
218
219 for node in s.queue.iter() {
220 seen += 1;
221
222 if result.len() < count {
223 result.push(node.clone());
224 } else {
225 let j = rng.gen_range(0..seen);
226 if j < count {
227 result[j] = node.clone();
228 }
229 }
230 }
231 }
232
233 result
234 }
235 }
236
237 pub fn len(&self) -> usize {
238 let len_v4: usize = self
239 .shards_v4
240 .iter()
241 .map(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).len())
242 .sum();
243 let len_v6: usize = self
244 .shards_v6
245 .iter()
246 .map(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).len())
247 .sum();
248 len_v4 + len_v6
249 }
250
251 pub fn is_empty(&self) -> bool {
252 let empty_v4 = self
253 .shards_v4
254 .iter()
255 .all(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).is_empty());
256 let empty_v6 = self
257 .shards_v6
258 .iter()
259 .all(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).is_empty());
260 empty_v4 && empty_v6
261 }
262
263 pub fn is_empty_for(&self, filter_ipv6: Option<bool>) -> bool {
264 match filter_ipv6 {
265 Some(true) => self
266 .shards_v6
267 .iter()
268 .all(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).is_empty()),
269 Some(false) => self
270 .shards_v4
271 .iter()
272 .all(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).is_empty()),
273 None => self.is_empty(),
274 }
275 }
276
277 #[inline]
278 fn addr_to_shard(&self, addr: &SocketAddr) -> usize {
279 let hash = match addr.ip() {
280 std::net::IpAddr::V4(ip) => {
281 let octets = ip.octets();
282 (octets[3] as usize) ^ (addr.port() as usize)
283 }
284 std::net::IpAddr::V6(ip) => {
285 let octets = ip.octets();
286 (octets[15] as usize) ^ (addr.port() as usize)
287 }
288 };
289 hash % QUEUE_SHARD_COUNT
290 }
291}