blockconvert/
domain_set.rs

1use crate::domain::DOMAIN_MAX_LENGTH;
2use crate::Domain;
3
4use std::hash::{Hash, Hasher};
5
6use parking_lot::Mutex;
7
8const DEFAULT_SHARDS: usize = 1024;
9type DefaultHasher = std::collections::hash_map::RandomState;
10
11pub type DomainSetShardedDefault = DomainSetSharded<DefaultHasher>;
12
13pub struct DomainSetSharded<H: std::hash::BuildHasher> {
14    shards: Vec<Mutex<DomainSet>>,
15    hasher: H,
16}
17
18impl<H: std::hash::BuildHasher + Default> DomainSetSharded<H> {
19    pub fn new() -> Self {
20        Self::with_shards_and_hasher(DEFAULT_SHARDS, H::default())
21    }
22    pub fn with_shards(shard_count: usize) -> Self {
23        Self::with_shards_and_hasher(shard_count, H::default())
24    }
25}
26
27impl<H: std::hash::BuildHasher + Default> Default for DomainSetSharded<H> {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl<T: std::hash::BuildHasher> DomainSetSharded<T> {
34    pub fn with_shards_and_hasher(shard_count: usize, hasher: T) -> Self {
35        let mut shards = Vec::with_capacity(shard_count);
36        for _ in 0..shard_count {
37            shards.push(Mutex::new(DomainSet::new()));
38        }
39        Self { shards, hasher }
40    }
41    fn get_location(&self, data: &[u8]) -> usize {
42        let mut hasher = self.hasher.build_hasher();
43        data.hash(&mut hasher);
44        let hash = hasher.finish();
45        hash as usize % self.shards.len()
46    }
47
48    pub fn contains(&self, data: &[u8]) -> bool {
49        assert!(data.len() <= DOMAIN_MAX_LENGTH);
50        self.shards[self.get_location(data)].lock().contains(data)
51    }
52    pub fn contains_str(&self, data: &str) -> bool {
53        self.contains(data.as_bytes())
54    }
55
56    pub fn insert(&self, data: &[u8]) -> bool {
57        assert!(data.len() <= DOMAIN_MAX_LENGTH);
58        self.shards[self.get_location(data)].lock().insert(data)
59    }
60    pub fn insert_str(&self, data: &str) -> bool {
61        self.insert(data.as_bytes())
62    }
63
64    pub fn remove(&self, data: &[u8]) -> bool {
65        assert!(data.len() <= DOMAIN_MAX_LENGTH);
66        self.shards[self.get_location(data)].lock().remove(data)
67    }
68    pub fn remove_str(&self, data: &str) -> bool {
69        self.remove(data.as_bytes())
70    }
71
72    pub fn into_iter(self) -> impl Iterator<Item = Vec<u8>> {
73        self.shards.into_iter().flat_map(|shard| {
74            let shard_iter = std::mem::take(&mut *shard.lock());
75            shard_iter.into_iter()
76        })
77    }
78
79    pub fn into_iter_string(self) -> impl Iterator<Item = String> {
80        self.into_iter()
81            .filter_map(|element| String::from_utf8(element).ok())
82    }
83
84    pub fn into_iter_domains(self) -> impl Iterator<Item = Domain> {
85        self.into_iter_string()
86            .filter_map(|slice| slice.parse::<Domain>().ok())
87    }
88
89    pub fn shrink_to_fit(&self) {
90        for shard in self.shards.iter() {
91            shard.lock().shrink_to_fit();
92        }
93    }
94
95    pub fn len(&self) -> usize {
96        self.shards.iter().map(|shard| shard.lock().len()).sum()
97    }
98
99    pub fn is_empty(&self) -> bool {
100        self.shards.iter().all(|shard| shard.lock().is_empty())
101    }
102}
103
104pub struct DomainSetIter<'a> {
105    domain_set: &'a DomainSet,
106    has_empty_string: bool,
107    subset: usize,
108    index: usize,
109}
110
111impl<'a> DomainSetIter<'a> {
112    fn new(domain_set: &'a DomainSet) -> Self {
113        Self {
114            has_empty_string: domain_set.has_empty_string,
115            domain_set,
116            subset: 0,
117            index: 0,
118        }
119    }
120}
121
122impl<'a> Iterator for DomainSetIter<'a> {
123    type Item = &'a [u8];
124    fn next(&mut self) -> Option<Self::Item> {
125        if self.has_empty_string {
126            self.has_empty_string = false;
127            Some(&[])
128        } else if self.subset < self.domain_set.subsets.len() {
129            let subset = &self.domain_set.subsets[self.subset];
130            if self.index * (self.subset + 1) < subset.len() {
131                let result =
132                    &subset[self.index * (self.subset + 1)..(self.index + 1) * (self.subset + 1)];
133                self.index += 1;
134                Some(result)
135            } else {
136                self.subset += 1;
137                self.index = 0;
138                self.next()
139            }
140        } else {
141            None
142        }
143    }
144}
145
146pub struct DomainSetIntoIter {
147    domain_set: DomainSet,
148    has_empty_string: bool,
149    subset: usize,
150    index: usize,
151}
152
153impl DomainSetIntoIter {
154    fn new(domain_set: DomainSet) -> Self {
155        Self {
156            has_empty_string: domain_set.has_empty_string,
157            domain_set,
158            subset: 0,
159            index: 0,
160        }
161    }
162}
163
164impl Iterator for DomainSetIntoIter {
165    type Item = Vec<u8>;
166    fn next(&mut self) -> Option<Self::Item> {
167        if self.has_empty_string {
168            self.has_empty_string = false;
169            Some(Vec::new())
170        } else if self.subset < self.domain_set.subsets.len() {
171            let subset = &self.domain_set.subsets[self.subset];
172            if self.index * (self.subset + 1) < subset.len() {
173                let result = subset
174                    [self.index * (self.subset + 1)..(self.index + 1) * (self.subset + 1)]
175                    .to_vec();
176                self.index += 1;
177                Some(result)
178            } else {
179                drop(subset);
180                self.domain_set.subsets[self.subset] = Vec::new();
181                self.subset += 1;
182                self.index = 0;
183                self.next()
184            }
185        } else {
186            None
187        }
188    }
189}
190
191#[derive(Clone)]
192pub struct DomainSet {
193    subsets: [Vec<u8>; DOMAIN_MAX_LENGTH],
194    has_empty_string: bool,
195    length: usize,
196}
197
198impl Default for DomainSet {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204impl DomainSet {
205    pub fn new() -> Self {
206        let mut subsets: [std::mem::MaybeUninit<Vec<u8>>; DOMAIN_MAX_LENGTH] =
207            unsafe { std::mem::MaybeUninit::uninit().assume_init() };
208        for elem in &mut subsets {
209            *elem = std::mem::MaybeUninit::new(Vec::new());
210        }
211        Self {
212            subsets: unsafe { std::mem::transmute::<_, _>(subsets) },
213            has_empty_string: false,
214            length: 0,
215        }
216    }
217    fn find_index(&self, data: &[u8]) -> Result<usize, usize> {
218        let len = data.len();
219        assert!(len != 0);
220        let subset = &self.subsets[len - 1];
221        assert_eq!(subset.len() % len, 0);
222        let chunk_count = subset.len() / len;
223        if chunk_count == 0 {
224            return Err(0);
225        }
226
227        let mut size = chunk_count;
228        let mut base = 0;
229        while size > 1 {
230            let half = size / 2;
231            let mid = base + half;
232            let slice = &subset[mid * len..(mid + 1) * len];
233            let cmp = data.cmp(slice);
234            base = if cmp == std::cmp::Ordering::Greater {
235                base
236            } else {
237                mid
238            };
239            size -= half;
240        }
241        let slice = &subset[base * len..(base + 1) * len];
242        let cmp = data.cmp(slice);
243        if cmp == std::cmp::Ordering::Equal {
244            Ok(base)
245        } else {
246            Err(base + (cmp == std::cmp::Ordering::Less) as usize)
247        }
248    }
249    pub fn contains(&self, data: &[u8]) -> bool {
250        if data.len() == 0 {
251            self.has_empty_string
252        } else {
253            self.find_index(data).is_ok()
254        }
255    }
256    pub fn contains_str(&self, data: &str) -> bool {
257        self.contains(data.as_bytes())
258    }
259
260    pub fn insert(&mut self, data: &[u8]) -> bool {
261        let len = data.len();
262        if len == 0 {
263            let old = self.has_empty_string;
264            self.has_empty_string = true;
265            if !old {
266                self.length += 1;
267            }
268            !old
269        } else if let Err(index) = self.find_index(data) {
270            let subset = &mut self.subsets[len - 1];
271            let removed: Vec<_> = subset
272                .splice(index * len..index * len, data.iter().cloned())
273                .collect();
274            assert_eq!(removed.len(), 0);
275            self.length += 1;
276            true
277        } else {
278            false
279        }
280    }
281    pub fn insert_str(&mut self, data: &str) -> bool {
282        self.insert(data.as_bytes())
283    }
284
285    pub fn remove(&mut self, data: &[u8]) -> bool {
286        let len = data.len();
287        if len == 0 {
288            let old = self.has_empty_string;
289            self.has_empty_string = false;
290            if self.has_empty_string {
291                self.length -= 1;
292            }
293            old
294        } else if let Ok(index) = self.find_index(data) {
295            let subset = &mut self.subsets[len - 1];
296            let removed: Vec<_> = subset
297                .splice(index * len..(index + 1) * len, std::iter::empty())
298                .collect();
299            assert_eq!(removed.len(), len);
300            self.length -= 1;
301            if subset.len() == 0 {
302                *subset = Vec::new();
303            } else if subset.len() * 4 < subset.capacity() {
304                //subset.shrink_to_fit();
305            }
306            true
307        } else {
308            false
309        }
310    }
311
312    pub fn remove_str(&mut self, data: &str) -> bool {
313        self.remove(data.as_bytes())
314    }
315
316    pub fn iter(&self) -> impl Iterator<Item = &[u8]> {
317        DomainSetIter::new(self)
318    }
319
320    pub fn into_iter(mut self) -> impl Iterator<Item = Vec<u8>> {
321        self.shrink_to_fit();
322        DomainSetIntoIter::new(self)
323    }
324    pub fn into_iter_string(self) -> impl Iterator<Item = String> {
325        self.into_iter()
326            .filter_map(|slice| String::from_utf8(slice).ok())
327    }
328
329    pub fn into_iter_domains(self) -> impl Iterator<Item = Domain> {
330        self.into_iter_string()
331            .filter_map(|slice| slice.parse::<Domain>().ok())
332    }
333
334    pub fn shrink_to_fit(&mut self) {
335        if self.length != 0 {
336            for subset in self.subsets.iter_mut() {
337                subset.shrink_to_fit();
338            }
339        }
340    }
341
342    pub fn len(&self) -> usize {
343        debug_assert_eq!(
344            self.length,
345            self.has_empty_string as usize
346                + self
347                    .subsets
348                    .iter()
349                    .enumerate()
350                    .map(|(len, subset)| subset.len() / (len + 1))
351                    .sum::<usize>()
352        );
353        self.length
354    }
355
356    pub fn is_empty(&self) -> bool {
357        debug_assert_eq!(
358            self.length == 0,
359            !self.has_empty_string && self.subsets.iter().all(|subset| subset.is_empty()),
360        );
361        self.length == 0
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[quickcheck]
    fn test_sharded_into_iter_string_is_original(mut strings: Vec<String>) {
        let set = DomainSetShardedDefault::default();
        strings.retain(|string| string.len() <= DOMAIN_MAX_LENGTH);
        for domain in &strings {
            set.insert_str(domain);
        }
        let mut generated = set.into_iter_string().collect::<Vec<_>>();
        generated.sort();
        strings.sort();
        strings.dedup();
        assert_eq!(strings, generated);
    }

    #[quickcheck]
    fn test_domain_set_into_iter_string_is_original(mut strings: Vec<String>) {
        let mut set = DomainSet::default();
        strings.retain(|string| string.len() <= DOMAIN_MAX_LENGTH);
        for domain in &strings {
            set.insert_str(domain);
        }
        let mut generated = set.into_iter_string().collect::<Vec<_>>();
        generated.sort();
        strings.sort();
        strings.dedup();
        assert_eq!(strings, generated);
    }

    #[quickcheck]
    fn test_into_iter_is_original(mut slices: Vec<Vec<u8>>) {
        let set = DomainSetShardedDefault::default();
        slices.retain(|slice| slice.len() <= DOMAIN_MAX_LENGTH);
        for domain in &slices {
            set.insert(domain);
        }
        let mut generated = set.into_iter().collect::<Vec<_>>();
        generated.sort();
        slices.sort();
        slices.dedup();
        assert_eq!(slices, generated);
    }

    #[quickcheck]
    fn test_domain_set_iter_is_original(mut slices: Vec<Vec<u8>>) {
        let mut set = DomainSet::default();
        slices.retain(|slice| slice.len() <= DOMAIN_MAX_LENGTH);
        for domain in &slices {
            set.insert(domain);
        }
        let mut generated = set.iter().collect::<Vec<_>>();
        generated.sort();
        slices.sort();
        slices.dedup();
        assert_eq!(slices, generated);
    }

    #[test]
    fn test_domain_set_can_have_elements_removed() {
        let mut domains = vec!["google.com", "en.m.wikipedia.org", "example.tk"];
        domains.sort();
        let set = DomainSetShardedDefault::default();
        for domain in &domains {
            set.insert_str(domain);
        }
        set.insert_str("youtube.com");
        assert_eq!(set.len(), 4);
        assert!(set.contains_str("youtube.com"));
        set.remove_str("youtube.com");
        assert_eq!(set.len(), 3);
        assert!(!set.contains_str("youtube.com"));
        let mut generated = set.into_iter_string().collect::<Vec<_>>();
        generated.sort();
        assert_eq!(domains, generated);
    }

    #[test]
    fn test_domain_set_remove_reports_membership() {
        let mut set = DomainSet::new();
        assert!(set.insert_str("example.com"));
        assert!(set.insert_str("example.org"));
        assert!(!set.remove_str("example.net"));
        assert!(set.remove_str("example.com"));
        assert!(!set.remove_str("example.com"));
        assert!(!set.contains_str("example.com"));
        assert!(set.contains_str("example.org"));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_domain_set_can_multiple_sizes() {
        let mut domains = vec![
            "",
            "e",
            "ex",
            "exa",
            "exam",
            "examp",
            "exampl",
            "example",
            "example.",
            "example.c",
            "example.co",
            "example.com",
        ];
        domains.sort();
        let set = DomainSetShardedDefault::default();
        for (i, domain) in domains.iter().enumerate() {
            assert!(!set.contains_str(domain));
            assert_eq!(set.len(), i);
            set.insert_str(domain);
            assert!(set.contains_str(domain));
            assert_eq!(set.len(), i + 1);
        }
        let mut generated = set.into_iter_string().collect::<Vec<_>>();
        generated.sort();
        assert_eq!(domains, generated);
    }

    #[test]
    fn test_domain_set_removes_duplicates() {
        let mut domains = vec![
            "google.com",
            "en.m.wikipedia.org",
            "example.tk",
            "google.com",
        ];
        let set = DomainSetShardedDefault::default();
        for domain in &domains {
            set.insert_str(domain);
        }
        let mut generated = set.into_iter_string().collect::<Vec<_>>();
        generated.sort();
        domains.sort();
        domains.dedup();
        assert_eq!(domains, generated);
    }
}
492}