Skip to main content

provenant/license_detection/
token_set.rs

1use smallvec::SmallVec;
2use std::cmp::Ordering;
3use std::ops::Deref;
4
5use rkyv::Archive;
6
7use crate::license_detection::index::dictionary::{TokenDictionary, TokenId, TokenKind};
8
/// A set of token IDs stored as a sorted SmallVec.
///
/// Invariant: elements are always sorted in ascending order and deduplicated.
/// The merge-based intersection methods depend on this ordering. Because
/// the representation is canonical, the derived `PartialEq`/`Eq` compare
/// set membership.
///
/// Construct via `TokenSet::from_token_ids()`, `TokenSet::from_u16_iter()`,
/// or `.collect()` from an iterator of u16. `TokenSet::new()` and
/// `TokenSet::default()` produce an empty set.
///
/// Up to 64 IDs are stored inline in the `SmallVec`. Larger sets spill to
/// the heap.
///
/// NOTE(review): the derived rkyv `Deserialize` copies the inner vector as-is
/// and does not re-check the sorted/deduplicated invariant. Archived data is
/// presumably always produced from a valid `TokenSet`; confirm with the index
/// loader.
#[derive(Clone, Debug, PartialEq, Eq, Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct TokenSet(SmallVec<[u16; 64]>);
16
17impl TokenSet {
18    /// Create a TokenSet from an iterator of u16 token IDs.
19    /// Sorts and deduplicates the input.
20    pub fn from_u16_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
21        let mut inner: SmallVec<[u16; 64]> = iter.into_iter().collect();
22        inner.sort_unstable();
23        inner.dedup();
24        Self(inner)
25    }
26
27    /// Create a TokenSet from an iterator of TokenId values.
28    pub fn from_token_ids<I: IntoIterator<Item = TokenId>>(iter: I) -> Self {
29        Self::from_u16_iter(iter.into_iter().map(|tid| tid.raw()))
30    }
31
32    /// Create an empty TokenSet.
33    pub fn new() -> Self {
34        Self(SmallVec::new())
35    }
36
37    /// Number of tokens in the set.
38    pub fn len(&self) -> usize {
39        self.0.len()
40    }
41
42    /// Is the set empty?
43    pub fn is_empty(&self) -> bool {
44        self.0.is_empty()
45    }
46
47    /// Return true if the set contains the given token ID.
48    pub fn contains_token_id(&self, token_id: TokenId) -> bool {
49        self.0.contains(&token_id.raw())
50    }
51
52    /// Get the subset containing only high-value (legalese) tokens.
53    pub fn high_subset(&self, dictionary: &TokenDictionary) -> Self {
54        Self::from_u16_iter(
55            self.iter()
56                .filter(|&tid| dictionary.token_kind(TokenId::new(tid)) == TokenKind::Legalese),
57        )
58    }
59
60    /// Count intersection with another TokenSet (no allocation).
61    pub fn intersection_count(&self, other: &TokenSet) -> usize {
62        let (mut i, mut j, mut count) = (0, 0, 0);
63        while i < self.0.len() && j < other.0.len() {
64            match self.0[i].cmp(&other.0[j]) {
65                Ordering::Less => i += 1,
66                Ordering::Greater => j += 1,
67                Ordering::Equal => {
68                    count += 1;
69                    i += 1;
70                    j += 1;
71                }
72            }
73        }
74        count
75    }
76
77    /// Materialize intersection with another TokenSet.
78    pub fn intersection(&self, other: &TokenSet) -> TokenSet {
79        let mut result = SmallVec::new();
80        let (mut i, mut j) = (0, 0);
81        while i < self.0.len() && j < other.0.len() {
82            match self.0[i].cmp(&other.0[j]) {
83                Ordering::Less => i += 1,
84                Ordering::Greater => j += 1,
85                Ordering::Equal => {
86                    result.push(self.0[i]);
87                    i += 1;
88                    j += 1;
89                }
90            }
91        }
92        Self(result)
93    }
94
95    /// Iterate over the sorted token IDs.
96    pub fn iter(&self) -> impl Iterator<Item = u16> + '_ {
97        self.0.iter().copied()
98    }
99}
100
101impl Deref for TokenSet {
102    type Target = [u16];
103    fn deref(&self) -> &Self::Target {
104        &self.0
105    }
106}
107
108impl Default for TokenSet {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl std::iter::FromIterator<u16> for TokenSet {
115    fn from_iter<T: IntoIterator<Item = u16>>(iter: T) -> Self {
116        Self::from_u16_iter(iter)
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license_detection::index::dictionary::tid;

    #[test]
    fn test_from_token_ids() {
        let set = TokenSet::from_token_ids([tid(4), tid(2), tid(4), tid(1)]);

        assert_eq!(set.iter().collect::<Vec<_>>(), vec![1, 2, 4]);
        assert!(set.contains_token_id(tid(1)));
        assert!(set.contains_token_id(tid(2)));
        assert!(set.contains_token_id(tid(4)));
    }

    #[test]
    fn test_collect_sorts_and_dedups() {
        let set: TokenSet = vec![7u16, 3, 7, 1, 3].into_iter().collect();

        assert_eq!(set.iter().collect::<Vec<_>>(), vec![1, 3, 7]);
        assert_eq!(set.len(), 3);
        assert!(!set.is_empty());
    }

    #[test]
    fn test_empty_set() {
        let set = TokenSet::default();

        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert!(!set.contains_token_id(tid(0)));
        assert_eq!(set, TokenSet::new());
    }

    #[test]
    fn test_contains_token_id_absent() {
        let set = TokenSet::from_u16_iter([1, 3, 5]);

        // Below, between, and above the stored IDs.
        assert!(!set.contains_token_id(tid(0)));
        assert!(!set.contains_token_id(tid(2)));
        assert!(!set.contains_token_id(tid(6)));
    }

    #[test]
    fn test_intersection() {
        let a = TokenSet::from_u16_iter([1, 2, 3, 5, 8, 13]);
        let b = TokenSet::from_u16_iter([2, 3, 4, 8, 21]);

        assert_eq!(a.intersection_count(&b), 3);
        assert_eq!(a.intersection(&b).iter().collect::<Vec<_>>(), vec![2, 3, 8]);
        // Intersection is symmetric.
        assert_eq!(b.intersection(&a), a.intersection(&b));
        assert_eq!(b.intersection_count(&a), 3);

        let empty = TokenSet::new();
        assert_eq!(a.intersection_count(&empty), 0);
        assert!(a.intersection(&empty).is_empty());
        assert!(empty.intersection(&a).is_empty());
    }

    #[test]
    fn test_set_larger_than_inline_capacity() {
        // More than the 64 inline slots, supplied in descending order.
        let set = TokenSet::from_u16_iter((0..200u16).rev());

        assert_eq!(set.len(), 200);
        assert_eq!(set.iter().next(), Some(0));
        assert!(set.contains_token_id(tid(199)));
        assert!(!set.contains_token_id(tid(200)));
        assert_eq!(set.intersection_count(&set), 200);
    }

    #[test]
    fn test_high_subset() {
        let set = TokenSet::from_u16_iter([1, 2, 5, 10]);
        let dict = TokenDictionary::new_with_legalese(&[("one", 1), ("two", 2)]);

        let high_set = set.high_subset(&dict);

        assert_eq!(high_set.iter().collect::<Vec<_>>(), vec![1, 2]);
    }
}