//! Sharded, de-duplicating node queue for the DHT crawler (`sharded.rs`).
1use std::collections::{HashSet, VecDeque};
2use std::net::SocketAddr;
3use std::sync::Mutex;
4
/// Number of mutex-guarded shards per address family (IPv4 and IPv6 each
/// get this many); more shards means less lock contention between callers.
const QUEUE_SHARD_COUNT: usize = 16;
6
/// A node discovered during the crawl: its raw ID bytes plus the socket
/// address it was seen at.
#[derive(Debug, Clone)]
pub struct NodeTuple {
    /// Raw node identifier (presumably a 20-byte DHT node ID — TODO confirm).
    pub id: Vec<u8>,
    /// Network address; also the key used for queue de-duplication.
    pub addr: SocketAddr,
}
12
/// One shard of the node queue: FIFO storage plus an address index that
/// rejects duplicates in O(1). When `queue` reaches `capacity`, the oldest
/// entry is evicted to make room for a new one.
struct NodeQueueShard {
    queue: VecDeque<NodeTuple>,  // nodes in discovery (FIFO) order
    index: HashSet<SocketAddr>,  // addresses currently present in `queue`
    capacity: usize,             // max entries before eviction kicks in
}
18
19impl NodeQueueShard {
20    fn new(capacity: usize) -> Self {
21        Self {
22            queue: VecDeque::with_capacity(capacity),
23            index: HashSet::with_capacity(capacity),
24            capacity,
25        }
26    }
27
28    fn push(&mut self, node: NodeTuple) {
29        if self.index.contains(&node.addr) {
30            return;
31        }
32
33        if self.queue.len() >= self.capacity
34            && let Some(removed) = self.queue.pop_front()
35        {
36            self.index.remove(&removed.addr);
37        }
38
39        self.index.insert(node.addr);
40        self.queue.push_back(node);
41    }
42
43    fn pop_batch(&mut self, count: usize) -> Vec<NodeTuple> {
44        let actual_count = count.min(self.queue.len());
45        let mut nodes = Vec::with_capacity(actual_count);
46
47        for _ in 0..actual_count {
48            if let Some(node) = self.queue.pop_front() {
49                self.index.remove(&node.addr);
50                nodes.push(node);
51            }
52        }
53        nodes
54    }
55
56    fn len(&self) -> usize {
57        self.queue.len()
58    }
59
60    fn is_empty(&self) -> bool {
61        self.queue.is_empty()
62    }
63}
64
/// Node queue split into `QUEUE_SHARD_COUNT` mutex-guarded shards per
/// address family, so concurrent callers mostly contend on different locks.
pub struct ShardedNodeQueue {
    shards_v4: Vec<Mutex<NodeQueueShard>>, // shards holding IPv4 nodes
    shards_v6: Vec<Mutex<NodeQueueShard>>, // shards holding IPv6 nodes
}
69
70impl ShardedNodeQueue {
71    pub fn new(total_capacity: usize) -> Self {
72        #[allow(clippy::manual_div_ceil)]
73        let capacity_per_shard = (total_capacity + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
74
75        let shards_v4 = (0..QUEUE_SHARD_COUNT)
76            .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
77            .collect();
78
79        let shards_v6 = (0..QUEUE_SHARD_COUNT)
80            .map(|_| Mutex::new(NodeQueueShard::new(capacity_per_shard)))
81            .collect();
82
83        Self {
84            shards_v4,
85            shards_v6,
86        }
87    }
88
89    pub fn push(&self, node: NodeTuple) {
90        let shard_idx = self.addr_to_shard(&node.addr);
91
92        if node.addr.is_ipv6() {
93            let mut shard = self.shards_v6[shard_idx].lock().unwrap_or_else(|e| e.into_inner());
94            shard.push(node);
95        } else {
96            let mut shard = self.shards_v4[shard_idx].lock().unwrap_or_else(|e| e.into_inner());
97            shard.push(node);
98        }
99    }
100
101    pub fn pop_batch(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
102        let mut result = Vec::with_capacity(count);
103        #[allow(clippy::manual_div_ceil)]
104        let per_shard = (count + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
105
106        match filter_ipv6 {
107            Some(true) => {
108                for shard in &self.shards_v6 {
109                    if result.len() >= count {
110                        break;
111                    }
112                    let mut s = shard.lock().unwrap_or_else(|e| e.into_inner());
113                    let nodes = s.pop_batch(per_shard);
114                    result.extend(nodes);
115                }
116            }
117            Some(false) => {
118                for shard in &self.shards_v4 {
119                    if result.len() >= count {
120                        break;
121                    }
122                    let mut s = shard.lock().unwrap_or_else(|e| e.into_inner());
123                    let nodes = s.pop_batch(per_shard);
124                    result.extend(nodes);
125                }
126            }
127            None => {
128                for i in 0..QUEUE_SHARD_COUNT {
129                    if result.len() >= count {
130                        break;
131                    }
132
133                    let mut s4 = self.shards_v4[i].lock().unwrap_or_else(|e| e.into_inner());
134                    let nodes4 = s4.pop_batch(per_shard / 2);
135                    result.extend(nodes4);
136                    drop(s4);
137
138                    if result.len() >= count {
139                        break;
140                    }
141
142                    let mut s6 = self.shards_v6[i].lock().unwrap_or_else(|e| e.into_inner());
143                    let nodes6 = s6.pop_batch(per_shard / 2);
144                    result.extend(nodes6);
145                    drop(s6);
146                }
147            }
148        }
149
150        result
151    }
152
153    pub fn get_random_nodes(&self, count: usize, filter_ipv6: Option<bool>) -> Vec<NodeTuple> {
154        match filter_ipv6 {
155            Some(true) => self.get_random_nodes_from_shards(&self.shards_v6, count),
156            Some(false) => self.get_random_nodes_from_shards(&self.shards_v4, count),
157            None => {
158                let count_v4 = count / 2;
159                let count_v6 = count - count_v4;
160                let mut result = Vec::with_capacity(count);
161
162                result.extend(self.get_random_nodes_from_shards(&self.shards_v4, count_v4));
163                result.extend(self.get_random_nodes_from_shards(&self.shards_v6, count_v6));
164
165                result
166            }
167        }
168    }
169
170    fn get_random_nodes_from_shards(
171        &self,
172        shards: &[Mutex<NodeQueueShard>],
173        count: usize,
174    ) -> Vec<NodeTuple> {
175        use rand::Rng;
176        let mut rng = rand::thread_rng();
177
178        if count <= 16 {
179            let mut result = Vec::with_capacity(count);
180            #[allow(clippy::manual_div_ceil)]
181            let per_shard = (count + QUEUE_SHARD_COUNT - 1) / QUEUE_SHARD_COUNT;
182
183            for shard in shards {
184                if result.len() >= count {
185                    break;
186                }
187
188                let s = shard.lock().unwrap_or_else(|e| e.into_inner());
189                let shard_len = s.queue.len();
190
191                if shard_len == 0 {
192                    continue;
193                }
194
195                let to_take = per_shard.min(shard_len).min(count - result.len());
196
197                let mut indices: Vec<usize> = (0..shard_len).collect();
198
199                for i in 0..to_take {
200                    let j = rng.gen_range(i..shard_len);
201                    indices.swap(i, j);
202                }
203
204                for &idx in indices.iter().take(to_take) {
205                    if let Some(node) = s.queue.get(idx) {
206                        result.push(node.clone());
207                    }
208                }
209            }
210
211            result
212        } else {
213            let mut result = Vec::with_capacity(count);
214            let mut seen = 0usize;
215
216            for shard in shards {
217                let s = shard.lock().unwrap_or_else(|e| e.into_inner());
218
219                for node in s.queue.iter() {
220                    seen += 1;
221
222                    if result.len() < count {
223                        result.push(node.clone());
224                    } else {
225                        let j = rng.gen_range(0..seen);
226                        if j < count {
227                            result[j] = node.clone();
228                        }
229                    }
230                }
231            }
232
233            result
234        }
235    }
236
237    pub fn len(&self) -> usize {
238        let len_v4: usize = self
239            .shards_v4
240            .iter()
241            .map(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).len())
242            .sum();
243        let len_v6: usize = self
244            .shards_v6
245            .iter()
246            .map(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).len())
247            .sum();
248        len_v4 + len_v6
249    }
250
251    pub fn is_empty(&self) -> bool {
252        let empty_v4 = self
253            .shards_v4
254            .iter()
255            .all(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).is_empty());
256        let empty_v6 = self
257            .shards_v6
258            .iter()
259            .all(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).is_empty());
260        empty_v4 && empty_v6
261    }
262
263    pub fn is_empty_for(&self, filter_ipv6: Option<bool>) -> bool {
264        match filter_ipv6 {
265            Some(true) => self
266                .shards_v6
267                .iter()
268                .all(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).is_empty()),
269            Some(false) => self
270                .shards_v4
271                .iter()
272                .all(|shard| shard.lock().unwrap_or_else(|e| e.into_inner()).is_empty()),
273            None => self.is_empty(),
274        }
275    }
276
277    #[inline]
278    fn addr_to_shard(&self, addr: &SocketAddr) -> usize {
279        let hash = match addr.ip() {
280            std::net::IpAddr::V4(ip) => {
281                let octets = ip.octets();
282                (octets[3] as usize) ^ (addr.port() as usize)
283            }
284            std::net::IpAddr::V6(ip) => {
285                let octets = ip.octets();
286                (octets[15] as usize) ^ (addr.port() as usize)
287            }
288        };
289        hash % QUEUE_SHARD_COUNT
290    }
291}