provenant/license_detection/
token_set.rs1use smallvec::SmallVec;
2use std::cmp::Ordering;
3use std::ops::Deref;
4
5use crate::license_detection::index::dictionary::{TokenDictionary, TokenId, TokenKind};
6
7#[derive(Clone, Debug, PartialEq, Eq)]
13pub struct TokenSet(SmallVec<[u16; 64]>);
14
15impl TokenSet {
16 pub fn from_u16_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
19 let mut inner: SmallVec<[u16; 64]> = iter.into_iter().collect();
20 inner.sort_unstable();
21 inner.dedup();
22 Self(inner)
23 }
24
25 pub fn from_token_ids<I: IntoIterator<Item = TokenId>>(iter: I) -> Self {
27 Self::from_u16_iter(iter.into_iter().map(|tid| tid.raw()))
28 }
29
30 pub fn new() -> Self {
32 Self(SmallVec::new())
33 }
34
35 pub fn len(&self) -> usize {
37 self.0.len()
38 }
39
40 pub fn is_empty(&self) -> bool {
42 self.0.is_empty()
43 }
44
45 pub fn contains_token_id(&self, token_id: TokenId) -> bool {
47 self.0.contains(&token_id.raw())
48 }
49
50 pub fn high_subset(&self, dictionary: &TokenDictionary) -> Self {
52 Self::from_u16_iter(
53 self.iter()
54 .filter(|&tid| dictionary.token_kind(TokenId::new(tid)) == TokenKind::Legalese),
55 )
56 }
57
58 pub fn intersection_count(&self, other: &TokenSet) -> usize {
60 let (mut i, mut j, mut count) = (0, 0, 0);
61 while i < self.0.len() && j < other.0.len() {
62 match self.0[i].cmp(&other.0[j]) {
63 Ordering::Less => i += 1,
64 Ordering::Greater => j += 1,
65 Ordering::Equal => {
66 count += 1;
67 i += 1;
68 j += 1;
69 }
70 }
71 }
72 count
73 }
74
75 pub fn intersection(&self, other: &TokenSet) -> TokenSet {
77 let mut result = SmallVec::new();
78 let (mut i, mut j) = (0, 0);
79 while i < self.0.len() && j < other.0.len() {
80 match self.0[i].cmp(&other.0[j]) {
81 Ordering::Less => i += 1,
82 Ordering::Greater => j += 1,
83 Ordering::Equal => {
84 result.push(self.0[i]);
85 i += 1;
86 j += 1;
87 }
88 }
89 }
90 Self(result)
91 }
92
93 pub fn iter(&self) -> impl Iterator<Item = u16> + '_ {
95 self.0.iter().copied()
96 }
97}
98
99impl Deref for TokenSet {
100 type Target = [u16];
101 fn deref(&self) -> &Self::Target {
102 &self.0
103 }
104}
105
106impl Default for TokenSet {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
112impl std::iter::FromIterator<u16> for TokenSet {
113 fn from_iter<T: IntoIterator<Item = u16>>(iter: T) -> Self {
114 Self::from_u16_iter(iter)
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license_detection::index::dictionary::tid;

    /// Building from token ids sorts them, drops duplicates, and keeps every
    /// distinct id findable.
    #[test]
    fn test_from_token_ids() {
        let set = TokenSet::from_token_ids([tid(4), tid(2), tid(4), tid(1)]);

        let ids: Vec<_> = set.iter().collect();
        assert_eq!(ids, vec![1, 2, 4]);

        for raw in [1, 2, 4] {
            assert!(set.contains_token_id(tid(raw)));
        }
    }

    /// Only ids that the dictionary classifies as legalese survive `high_subset`.
    #[test]
    fn test_high_subset() {
        let dict = TokenDictionary::new_with_legalese(&[("one", 1), ("two", 2)]);
        let set = TokenSet::from_u16_iter([1, 2, 5, 10]);

        let high_ids: Vec<_> = set.high_subset(&dict).iter().collect();

        assert_eq!(high_ids, vec![1, 2]);
    }
}