Skip to main content

provenant/license_detection/
token_set.rs

1use smallvec::SmallVec;
2use std::cmp::Ordering;
3use std::ops::Deref;
4
5/// A set of token IDs stored as a sorted SmallVec.
6///
7/// Invariant: elements are always sorted and deduplicated.
8/// Construct via `TokenSet::from_u16_iter()` or `.collect()` from an iterator of u16.
9#[derive(Clone, Debug, PartialEq, Eq)]
10pub struct TokenSet(SmallVec<[u16; 64]>);
11
12impl TokenSet {
13    /// Create a TokenSet from an iterator of u16 token IDs.
14    /// Sorts and deduplicates the input.
15    pub fn from_u16_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
16        let mut inner: SmallVec<[u16; 64]> = iter.into_iter().collect();
17        inner.sort_unstable();
18        inner.dedup();
19        Self(inner)
20    }
21
22    /// Create an empty TokenSet.
23    pub fn new() -> Self {
24        Self(SmallVec::new())
25    }
26
27    /// Number of tokens in the set.
28    pub fn len(&self) -> usize {
29        self.0.len()
30    }
31
32    /// Is the set empty?
33    pub fn is_empty(&self) -> bool {
34        self.0.is_empty()
35    }
36
37    /// Count intersection with another TokenSet (no allocation).
38    pub fn intersection_count(&self, other: &TokenSet) -> usize {
39        let (mut i, mut j, mut count) = (0, 0, 0);
40        while i < self.0.len() && j < other.0.len() {
41            match self.0[i].cmp(&other.0[j]) {
42                Ordering::Less => i += 1,
43                Ordering::Greater => j += 1,
44                Ordering::Equal => {
45                    count += 1;
46                    i += 1;
47                    j += 1;
48                }
49            }
50        }
51        count
52    }
53
54    /// Materialize intersection with another TokenSet.
55    pub fn intersection(&self, other: &TokenSet) -> TokenSet {
56        let mut result = SmallVec::new();
57        let (mut i, mut j) = (0, 0);
58        while i < self.0.len() && j < other.0.len() {
59            match self.0[i].cmp(&other.0[j]) {
60                Ordering::Less => i += 1,
61                Ordering::Greater => j += 1,
62                Ordering::Equal => {
63                    result.push(self.0[i]);
64                    i += 1;
65                    j += 1;
66                }
67            }
68        }
69        Self(result)
70    }
71
72    /// Iterate over the sorted token IDs.
73    pub fn iter(&self) -> impl Iterator<Item = u16> + '_ {
74        self.0.iter().copied()
75    }
76}
77
78impl Deref for TokenSet {
79    type Target = [u16];
80    fn deref(&self) -> &Self::Target {
81        &self.0
82    }
83}
84
85impl Default for TokenSet {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl std::iter::FromIterator<u16> for TokenSet {
92    fn from_iter<T: IntoIterator<Item = u16>>(iter: T) -> Self {
93        Self::from_u16_iter(iter)
94    }
95}