1use crate::domain::DOMAIN_MAX_LENGTH;
2use crate::Domain;
3
4use std::hash::{Hash, Hasher};
5
6use parking_lot::Mutex;
7
8const DEFAULT_SHARDS: usize = 1024;
9type DefaultHasher = std::collections::hash_map::RandomState;
10
11pub type DomainSetShardedDefault = DomainSetSharded<DefaultHasher>;
12
13pub struct DomainSetSharded<H: std::hash::BuildHasher> {
14 shards: Vec<Mutex<DomainSet>>,
15 hasher: H,
16}
17
18impl<H: std::hash::BuildHasher + Default> DomainSetSharded<H> {
19 pub fn new() -> Self {
20 Self::with_shards_and_hasher(DEFAULT_SHARDS, H::default())
21 }
22 pub fn with_shards(shard_count: usize) -> Self {
23 Self::with_shards_and_hasher(shard_count, H::default())
24 }
25}
26
27impl<H: std::hash::BuildHasher + Default> Default for DomainSetSharded<H> {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl<T: std::hash::BuildHasher> DomainSetSharded<T> {
34 pub fn with_shards_and_hasher(shard_count: usize, hasher: T) -> Self {
35 let mut shards = Vec::with_capacity(shard_count);
36 for _ in 0..shard_count {
37 shards.push(Mutex::new(DomainSet::new()));
38 }
39 Self { shards, hasher }
40 }
41 fn get_location(&self, data: &[u8]) -> usize {
42 let mut hasher = self.hasher.build_hasher();
43 data.hash(&mut hasher);
44 let hash = hasher.finish();
45 hash as usize % self.shards.len()
46 }
47
48 pub fn contains(&self, data: &[u8]) -> bool {
49 assert!(data.len() <= DOMAIN_MAX_LENGTH);
50 self.shards[self.get_location(data)].lock().contains(data)
51 }
52 pub fn contains_str(&self, data: &str) -> bool {
53 self.contains(data.as_bytes())
54 }
55
56 pub fn insert(&self, data: &[u8]) -> bool {
57 assert!(data.len() <= DOMAIN_MAX_LENGTH);
58 self.shards[self.get_location(data)].lock().insert(data)
59 }
60 pub fn insert_str(&self, data: &str) -> bool {
61 self.insert(data.as_bytes())
62 }
63
64 pub fn remove(&self, data: &[u8]) -> bool {
65 assert!(data.len() <= DOMAIN_MAX_LENGTH);
66 self.shards[self.get_location(data)].lock().remove(data)
67 }
68 pub fn remove_str(&self, data: &str) -> bool {
69 self.remove(data.as_bytes())
70 }
71
72 pub fn into_iter(self) -> impl Iterator<Item = Vec<u8>> {
73 self.shards.into_iter().flat_map(|shard| {
74 let shard_iter = std::mem::take(&mut *shard.lock());
75 shard_iter.into_iter()
76 })
77 }
78
79 pub fn into_iter_string(self) -> impl Iterator<Item = String> {
80 self.into_iter()
81 .filter_map(|element| String::from_utf8(element).ok())
82 }
83
84 pub fn into_iter_domains(self) -> impl Iterator<Item = Domain> {
85 self.into_iter_string()
86 .filter_map(|slice| slice.parse::<Domain>().ok())
87 }
88
89 pub fn shrink_to_fit(&self) {
90 for shard in self.shards.iter() {
91 shard.lock().shrink_to_fit();
92 }
93 }
94
95 pub fn len(&self) -> usize {
96 self.shards.iter().map(|shard| shard.lock().len()).sum()
97 }
98
99 pub fn is_empty(&self) -> bool {
100 self.shards.iter().all(|shard| shard.lock().is_empty())
101 }
102}
103
104pub struct DomainSetIter<'a> {
105 domain_set: &'a DomainSet,
106 has_empty_string: bool,
107 subset: usize,
108 index: usize,
109}
110
111impl<'a> DomainSetIter<'a> {
112 fn new(domain_set: &'a DomainSet) -> Self {
113 Self {
114 has_empty_string: domain_set.has_empty_string,
115 domain_set,
116 subset: 0,
117 index: 0,
118 }
119 }
120}
121
122impl<'a> Iterator for DomainSetIter<'a> {
123 type Item = &'a [u8];
124 fn next(&mut self) -> Option<Self::Item> {
125 if self.has_empty_string {
126 self.has_empty_string = false;
127 Some(&[])
128 } else if self.subset < self.domain_set.subsets.len() {
129 let subset = &self.domain_set.subsets[self.subset];
130 if self.index * (self.subset + 1) < subset.len() {
131 let result =
132 &subset[self.index * (self.subset + 1)..(self.index + 1) * (self.subset + 1)];
133 self.index += 1;
134 Some(result)
135 } else {
136 self.subset += 1;
137 self.index = 0;
138 self.next()
139 }
140 } else {
141 None
142 }
143 }
144}
145
146pub struct DomainSetIntoIter {
147 domain_set: DomainSet,
148 has_empty_string: bool,
149 subset: usize,
150 index: usize,
151}
152
153impl DomainSetIntoIter {
154 fn new(domain_set: DomainSet) -> Self {
155 Self {
156 has_empty_string: domain_set.has_empty_string,
157 domain_set,
158 subset: 0,
159 index: 0,
160 }
161 }
162}
163
164impl Iterator for DomainSetIntoIter {
165 type Item = Vec<u8>;
166 fn next(&mut self) -> Option<Self::Item> {
167 if self.has_empty_string {
168 self.has_empty_string = false;
169 Some(Vec::new())
170 } else if self.subset < self.domain_set.subsets.len() {
171 let subset = &self.domain_set.subsets[self.subset];
172 if self.index * (self.subset + 1) < subset.len() {
173 let result = subset
174 [self.index * (self.subset + 1)..(self.index + 1) * (self.subset + 1)]
175 .to_vec();
176 self.index += 1;
177 Some(result)
178 } else {
179 drop(subset);
180 self.domain_set.subsets[self.subset] = Vec::new();
181 self.subset += 1;
182 self.index = 0;
183 self.next()
184 }
185 } else {
186 None
187 }
188 }
189}
190
191#[derive(Clone)]
192pub struct DomainSet {
193 subsets: [Vec<u8>; DOMAIN_MAX_LENGTH],
194 has_empty_string: bool,
195 length: usize,
196}
197
198impl Default for DomainSet {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl DomainSet {
205 pub fn new() -> Self {
206 let mut subsets: [std::mem::MaybeUninit<Vec<u8>>; DOMAIN_MAX_LENGTH] =
207 unsafe { std::mem::MaybeUninit::uninit().assume_init() };
208 for elem in &mut subsets {
209 *elem = std::mem::MaybeUninit::new(Vec::new());
210 }
211 Self {
212 subsets: unsafe { std::mem::transmute::<_, _>(subsets) },
213 has_empty_string: false,
214 length: 0,
215 }
216 }
217 fn find_index(&self, data: &[u8]) -> Result<usize, usize> {
218 let len = data.len();
219 assert!(len != 0);
220 let subset = &self.subsets[len - 1];
221 assert_eq!(subset.len() % len, 0);
222 let chunk_count = subset.len() / len;
223 if chunk_count == 0 {
224 return Err(0);
225 }
226
227 let mut size = chunk_count;
228 let mut base = 0;
229 while size > 1 {
230 let half = size / 2;
231 let mid = base + half;
232 let slice = &subset[mid * len..(mid + 1) * len];
233 let cmp = data.cmp(slice);
234 base = if cmp == std::cmp::Ordering::Greater {
235 base
236 } else {
237 mid
238 };
239 size -= half;
240 }
241 let slice = &subset[base * len..(base + 1) * len];
242 let cmp = data.cmp(slice);
243 if cmp == std::cmp::Ordering::Equal {
244 Ok(base)
245 } else {
246 Err(base + (cmp == std::cmp::Ordering::Less) as usize)
247 }
248 }
249 pub fn contains(&self, data: &[u8]) -> bool {
250 if data.len() == 0 {
251 self.has_empty_string
252 } else {
253 self.find_index(data).is_ok()
254 }
255 }
256 pub fn contains_str(&self, data: &str) -> bool {
257 self.contains(data.as_bytes())
258 }
259
260 pub fn insert(&mut self, data: &[u8]) -> bool {
261 let len = data.len();
262 if len == 0 {
263 let old = self.has_empty_string;
264 self.has_empty_string = true;
265 if !old {
266 self.length += 1;
267 }
268 !old
269 } else if let Err(index) = self.find_index(data) {
270 let subset = &mut self.subsets[len - 1];
271 let removed: Vec<_> = subset
272 .splice(index * len..index * len, data.iter().cloned())
273 .collect();
274 assert_eq!(removed.len(), 0);
275 self.length += 1;
276 true
277 } else {
278 false
279 }
280 }
281 pub fn insert_str(&mut self, data: &str) -> bool {
282 self.insert(data.as_bytes())
283 }
284
285 pub fn remove(&mut self, data: &[u8]) -> bool {
286 let len = data.len();
287 if len == 0 {
288 let old = self.has_empty_string;
289 self.has_empty_string = false;
290 if self.has_empty_string {
291 self.length -= 1;
292 }
293 old
294 } else if let Ok(index) = self.find_index(data) {
295 let subset = &mut self.subsets[len - 1];
296 let removed: Vec<_> = subset
297 .splice(index * len..(index + 1) * len, std::iter::empty())
298 .collect();
299 assert_eq!(removed.len(), len);
300 self.length -= 1;
301 if subset.len() == 0 {
302 *subset = Vec::new();
303 } else if subset.len() * 4 < subset.capacity() {
304 }
306 true
307 } else {
308 false
309 }
310 }
311
312 pub fn remove_str(&mut self, data: &str) -> bool {
313 self.remove(data.as_bytes())
314 }
315
316 pub fn iter(&self) -> impl Iterator<Item = &[u8]> {
317 DomainSetIter::new(self)
318 }
319
320 pub fn into_iter(mut self) -> impl Iterator<Item = Vec<u8>> {
321 self.shrink_to_fit();
322 DomainSetIntoIter::new(self)
323 }
324 pub fn into_iter_string(self) -> impl Iterator<Item = String> {
325 self.into_iter()
326 .filter_map(|slice| String::from_utf8(slice).ok())
327 }
328
329 pub fn into_iter_domains(self) -> impl Iterator<Item = Domain> {
330 self.into_iter_string()
331 .filter_map(|slice| slice.parse::<Domain>().ok())
332 }
333
334 pub fn shrink_to_fit(&mut self) {
335 if self.length != 0 {
336 for subset in self.subsets.iter_mut() {
337 subset.shrink_to_fit();
338 }
339 }
340 }
341
342 pub fn len(&self) -> usize {
343 debug_assert_eq!(
344 self.length,
345 self.has_empty_string as usize
346 + self
347 .subsets
348 .iter()
349 .enumerate()
350 .map(|(len, subset)| subset.len() / (len + 1))
351 .sum::<usize>()
352 );
353 self.length
354 }
355
356 pub fn is_empty(&self) -> bool {
357 debug_assert_eq!(
358 self.length == 0,
359 !self.has_empty_string && self.subsets.iter().all(|subset| subset.is_empty()),
360 );
361 self.length == 0
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Draining a sharded set as strings returns exactly the distinct strings
    /// that were inserted.
    #[quickcheck]
    fn test_sharded_into_iter_string_is_original(mut strings: Vec<String>) {
        let set = DomainSetShardedDefault::default();
        strings.retain(|candidate| candidate.len() <= DOMAIN_MAX_LENGTH);
        for name in &strings {
            set.insert_str(name);
        }

        let mut generated: Vec<String> = set.into_iter_string().collect();
        generated.sort();
        strings.sort();
        strings.dedup();
        assert_eq!(strings, generated);
    }

    /// Draining a plain set as strings returns exactly the distinct strings
    /// that were inserted.
    #[quickcheck]
    fn test_domain_set_into_iter_string_is_original(mut strings: Vec<String>) {
        let mut set = DomainSet::default();
        strings.retain(|candidate| candidate.len() <= DOMAIN_MAX_LENGTH);
        for name in &strings {
            set.insert_str(name);
        }

        let mut generated: Vec<String> = set.into_iter_string().collect();
        generated.sort();
        strings.sort();
        strings.dedup();
        assert_eq!(strings, generated);
    }

    /// Draining a sharded set as bytes returns exactly the distinct byte
    /// strings that were inserted.
    #[quickcheck]
    fn test_into_iter_is_original(mut slices: Vec<Vec<u8>>) {
        let set = DomainSetShardedDefault::default();
        slices.retain(|candidate| candidate.len() <= DOMAIN_MAX_LENGTH);
        for entry in &slices {
            set.insert(entry);
        }

        let mut generated: Vec<Vec<u8>> = set.into_iter().collect();
        generated.sort();
        slices.sort();
        slices.dedup();
        assert_eq!(slices, generated);
    }

    /// Borrowing iteration over a plain set visits exactly the distinct byte
    /// strings that were inserted.
    #[quickcheck]
    fn test_domain_set_iter_is_original(mut slices: Vec<Vec<u8>>) {
        let mut set = DomainSet::default();
        slices.retain(|candidate| candidate.len() <= DOMAIN_MAX_LENGTH);
        for entry in &slices {
            set.insert(entry);
        }

        let mut generated: Vec<&[u8]> = set.iter().collect();
        generated.sort();
        slices.sort();
        slices.dedup();
        assert_eq!(slices, generated);
    }

    /// Removing an entry makes it disappear and lowers the count by one.
    #[test]
    fn test_domain_set_can_have_elements_removed() {
        let mut domains = vec!["google.com", "en.m.wikipedia.org", "example.tk"];
        domains.sort();

        let set = DomainSetShardedDefault::default();
        for name in &domains {
            set.insert_str(name);
        }

        set.insert_str("youtube.com");
        assert_eq!(set.len(), 4);
        assert!(set.contains_str("youtube.com"));

        set.remove_str("youtube.com");
        assert_eq!(set.len(), 3);
        assert!(!set.contains_str("youtube.com"));

        let mut generated: Vec<String> = set.into_iter_string().collect();
        generated.sort();
        assert_eq!(domains, generated);
    }

    /// Entries of every length from zero upwards can live side by side.
    #[test]
    fn test_domain_set_can_multiple_sizes() {
        let mut domains = vec![
            "",
            "e",
            "ex",
            "exa",
            "exam",
            "examp",
            "exampl",
            "example",
            "example.",
            "example.c",
            "example.co",
            "example.com",
        ];
        domains.sort();

        let set = DomainSetShardedDefault::default();
        for (already_inserted, name) in domains.iter().enumerate() {
            assert!(!set.contains_str(name));
            assert_eq!(set.len(), already_inserted);
            set.insert_str(name);
            assert!(set.contains_str(name));
            assert_eq!(set.len(), already_inserted + 1);
        }

        let mut generated: Vec<String> = set.into_iter_string().collect();
        generated.sort();
        assert_eq!(domains, generated);
    }

    /// Inserting the same entry twice stores it only once.
    #[test]
    fn test_domain_set_removes_duplicates() {
        let mut domains = vec![
            "google.com",
            "en.m.wikipedia.org",
            "example.tk",
            "google.com",
        ];

        let set = DomainSetShardedDefault::default();
        for name in &domains {
            set.insert_str(name);
        }

        let mut generated: Vec<String> = set.into_iter_string().collect();
        generated.sort();
        domains.sort();
        domains.dedup();
        assert_eq!(domains, generated);
    }
}