Skip to main content

provenant/license_detection/
token_set.rs

1use smallvec::SmallVec;
2use std::cmp::Ordering;
3use std::ops::Deref;
4
5use crate::license_detection::index::dictionary::{TokenDictionary, TokenId, TokenKind};
6
7/// A set of token IDs stored as a sorted SmallVec.
8///
9/// Invariant: elements are always sorted and deduplicated.
10/// Construct via `TokenSet::from_token_ids()`, `TokenSet::from_u16_iter()`,
11/// or `.collect()` from an iterator of u16.
12#[derive(Clone, Debug, PartialEq, Eq)]
13pub struct TokenSet(SmallVec<[u16; 64]>);
14
15impl TokenSet {
16    /// Create a TokenSet from an iterator of u16 token IDs.
17    /// Sorts and deduplicates the input.
18    pub fn from_u16_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
19        let mut inner: SmallVec<[u16; 64]> = iter.into_iter().collect();
20        inner.sort_unstable();
21        inner.dedup();
22        Self(inner)
23    }
24
25    /// Create a TokenSet from an iterator of TokenId values.
26    pub fn from_token_ids<I: IntoIterator<Item = TokenId>>(iter: I) -> Self {
27        Self::from_u16_iter(iter.into_iter().map(|tid| tid.raw()))
28    }
29
30    /// Create an empty TokenSet.
31    pub fn new() -> Self {
32        Self(SmallVec::new())
33    }
34
35    /// Number of tokens in the set.
36    pub fn len(&self) -> usize {
37        self.0.len()
38    }
39
40    /// Is the set empty?
41    pub fn is_empty(&self) -> bool {
42        self.0.is_empty()
43    }
44
45    /// Return true if the set contains the given token ID.
46    pub fn contains_token_id(&self, token_id: TokenId) -> bool {
47        self.0.contains(&token_id.raw())
48    }
49
50    /// Get the subset containing only high-value (legalese) tokens.
51    pub fn high_subset(&self, dictionary: &TokenDictionary) -> Self {
52        Self::from_u16_iter(
53            self.iter()
54                .filter(|&tid| dictionary.token_kind(TokenId::new(tid)) == TokenKind::Legalese),
55        )
56    }
57
58    /// Count intersection with another TokenSet (no allocation).
59    pub fn intersection_count(&self, other: &TokenSet) -> usize {
60        let (mut i, mut j, mut count) = (0, 0, 0);
61        while i < self.0.len() && j < other.0.len() {
62            match self.0[i].cmp(&other.0[j]) {
63                Ordering::Less => i += 1,
64                Ordering::Greater => j += 1,
65                Ordering::Equal => {
66                    count += 1;
67                    i += 1;
68                    j += 1;
69                }
70            }
71        }
72        count
73    }
74
75    /// Materialize intersection with another TokenSet.
76    pub fn intersection(&self, other: &TokenSet) -> TokenSet {
77        let mut result = SmallVec::new();
78        let (mut i, mut j) = (0, 0);
79        while i < self.0.len() && j < other.0.len() {
80            match self.0[i].cmp(&other.0[j]) {
81                Ordering::Less => i += 1,
82                Ordering::Greater => j += 1,
83                Ordering::Equal => {
84                    result.push(self.0[i]);
85                    i += 1;
86                    j += 1;
87                }
88            }
89        }
90        Self(result)
91    }
92
93    /// Iterate over the sorted token IDs.
94    pub fn iter(&self) -> impl Iterator<Item = u16> + '_ {
95        self.0.iter().copied()
96    }
97}
98
99impl Deref for TokenSet {
100    type Target = [u16];
101    fn deref(&self) -> &Self::Target {
102        &self.0
103    }
104}
105
106impl Default for TokenSet {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
112impl std::iter::FromIterator<u16> for TokenSet {
113    fn from_iter<T: IntoIterator<Item = u16>>(iter: T) -> Self {
114        Self::from_u16_iter(iter)
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license_detection::index::dictionary::tid;

    /// Construction from `TokenId`s sorts, deduplicates, and supports lookup.
    #[test]
    fn test_from_token_ids() {
        let set = TokenSet::from_token_ids([tid(4), tid(2), tid(4), tid(1)]);

        assert_eq!(set.iter().collect::<Vec<_>>(), vec![1, 2, 4]);
        assert_eq!(set.len(), 3);
        assert!(set.contains_token_id(tid(1)));
        assert!(set.contains_token_id(tid(2)));
        assert!(set.contains_token_id(tid(4)));
        // IDs below, between, and above the stored values must be rejected.
        assert!(!set.contains_token_id(tid(0)));
        assert!(!set.contains_token_id(tid(3)));
        assert!(!set.contains_token_id(tid(5)));
    }

    /// Only tokens the dictionary classifies as legalese survive.
    #[test]
    fn test_high_subset() {
        let set = TokenSet::from_u16_iter([1, 2, 5, 10]);
        let dict = TokenDictionary::new_with_legalese(&[("one", 1), ("two", 2)]);

        let high_set = set.high_subset(&dict);

        assert_eq!(high_set.iter().collect::<Vec<_>>(), vec![1, 2]);
    }

    /// `.collect()` goes through the same sort-and-dedup path.
    #[test]
    fn test_collect_sorts_and_dedups() {
        let set: TokenSet = vec![3u16, 1, 3, 2, 1].into_iter().collect();

        assert_eq!(set.iter().collect::<Vec<_>>(), vec![1, 2, 3]);
        assert_eq!(set, TokenSet::from_u16_iter([1, 2, 3]));
    }

    /// `intersection` and `intersection_count` agree with each other.
    #[test]
    fn test_intersection_and_count() {
        let left = TokenSet::from_u16_iter([1, 3, 5, 7, 9]);
        let right = TokenSet::from_u16_iter([2, 3, 4, 7, 10]);
        let disjoint = TokenSet::from_u16_iter([100, 200]);

        let common = left.intersection(&right);
        assert_eq!(common.iter().collect::<Vec<_>>(), vec![3, 7]);
        assert_eq!(left.intersection_count(&right), 2);
        assert_eq!(right.intersection_count(&left), 2);

        assert!(left.intersection(&disjoint).is_empty());
        assert_eq!(left.intersection_count(&disjoint), 0);

        assert_eq!(left.intersection(&left), left);
        assert_eq!(left.intersection_count(&left), left.len());
    }

    /// The empty set behaves as the identity for the queries above.
    #[test]
    fn test_empty_set() {
        let empty = TokenSet::new();
        let other = TokenSet::from_u16_iter([1, 2, 3]);

        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert_eq!(empty, TokenSet::default());
        assert!(!empty.contains_token_id(tid(1)));
        assert!(empty.intersection(&other).is_empty());
        assert_eq!(other.intersection_count(&empty), 0);
    }
}