provenant/license_detection/
token_set.rs1use smallvec::SmallVec;
5use std::cmp::Ordering;
6use std::ops::Deref;
7
8use rkyv::Archive;
9
10use crate::license_detection::index::dictionary::{TokenDictionary, TokenId, TokenKind};
11
12#[derive(Clone, Debug, PartialEq, Eq, Archive, rkyv::Serialize, rkyv::Deserialize)]
18pub struct TokenSet(SmallVec<[u16; 64]>);
19
20impl TokenSet {
21 pub fn from_u16_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
24 let mut inner: SmallVec<[u16; 64]> = iter.into_iter().collect();
25 inner.sort_unstable();
26 inner.dedup();
27 Self(inner)
28 }
29
30 pub fn from_token_ids<I: IntoIterator<Item = TokenId>>(iter: I) -> Self {
32 Self::from_u16_iter(iter.into_iter().map(|tid| tid.raw()))
33 }
34
35 pub fn new() -> Self {
37 Self(SmallVec::new())
38 }
39
40 pub fn len(&self) -> usize {
42 self.0.len()
43 }
44
45 pub fn is_empty(&self) -> bool {
47 self.0.is_empty()
48 }
49
50 pub fn contains_token_id(&self, token_id: TokenId) -> bool {
52 self.0.contains(&token_id.raw())
53 }
54
55 pub fn high_subset(&self, dictionary: &TokenDictionary) -> Self {
57 Self::from_u16_iter(
58 self.iter()
59 .filter(|&tid| dictionary.token_kind(TokenId::new(tid)) == TokenKind::Legalese),
60 )
61 }
62
63 pub fn intersection_count(&self, other: &TokenSet) -> usize {
65 let (mut i, mut j, mut count) = (0, 0, 0);
66 while i < self.0.len() && j < other.0.len() {
67 match self.0[i].cmp(&other.0[j]) {
68 Ordering::Less => i += 1,
69 Ordering::Greater => j += 1,
70 Ordering::Equal => {
71 count += 1;
72 i += 1;
73 j += 1;
74 }
75 }
76 }
77 count
78 }
79
80 pub fn intersection(&self, other: &TokenSet) -> TokenSet {
82 let mut result = SmallVec::new();
83 let (mut i, mut j) = (0, 0);
84 while i < self.0.len() && j < other.0.len() {
85 match self.0[i].cmp(&other.0[j]) {
86 Ordering::Less => i += 1,
87 Ordering::Greater => j += 1,
88 Ordering::Equal => {
89 result.push(self.0[i]);
90 i += 1;
91 j += 1;
92 }
93 }
94 }
95 Self(result)
96 }
97
98 pub fn iter(&self) -> impl Iterator<Item = u16> + '_ {
100 self.0.iter().copied()
101 }
102}
103
104impl Deref for TokenSet {
105 type Target = [u16];
106 fn deref(&self) -> &Self::Target {
107 &self.0
108 }
109}
110
111impl Default for TokenSet {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117impl std::iter::FromIterator<u16> for TokenSet {
118 fn from_iter<T: IntoIterator<Item = u16>>(iter: T) -> Self {
119 Self::from_u16_iter(iter)
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license_detection::index::dictionary::tid;

    #[test]
    fn test_from_token_ids() {
        let set = TokenSet::from_token_ids([tid(4), tid(2), tid(4), tid(1)]);

        assert_eq!(set.iter().collect::<Vec<_>>(), vec![1, 2, 4]);
        assert!(set.contains_token_id(tid(1)));
        assert!(set.contains_token_id(tid(2)));
        assert!(set.contains_token_id(tid(4)));
    }

    #[test]
    fn test_contains_token_id_rejects_absent_ids() {
        let set = TokenSet::from_token_ids([tid(9), tid(1), tid(4)]);

        // Probe below, between, and above the stored ids.
        assert!(!set.contains_token_id(tid(2)));
        assert!(!set.contains_token_id(tid(5)));
        assert!(!set.contains_token_id(tid(10)));
        assert!(!TokenSet::new().contains_token_id(tid(1)));
    }

    #[test]
    fn test_high_subset() {
        let set = TokenSet::from_u16_iter([1, 2, 5, 10]);
        let dict = TokenDictionary::new_with_legalese_pairs(&[("one", 1), ("two", 2)]);

        let high_set = set.high_subset(&dict);

        assert_eq!(high_set.iter().collect::<Vec<_>>(), vec![1, 2]);
    }

    #[test]
    fn test_high_subset_from_unsorted_duplicated_input() {
        let set = TokenSet::from_u16_iter([10, 2, 5, 1, 2]);
        let dict = TokenDictionary::new_with_legalese_pairs(&[("one", 1), ("two", 2)]);

        let high_set = set.high_subset(&dict);

        assert_eq!(high_set.iter().collect::<Vec<_>>(), vec![1, 2]);
        assert!(high_set.contains_token_id(tid(2)));
        assert!(!high_set.contains_token_id(tid(5)));
    }

    #[test]
    fn test_high_subset_of_empty_set_is_empty() {
        let dict = TokenDictionary::new_with_legalese_pairs(&[("one", 1), ("two", 2)]);

        assert!(TokenSet::new().high_subset(&dict).is_empty());
    }

    #[test]
    fn test_intersection_and_count_agree() {
        let a = TokenSet::from_u16_iter([9, 1, 3, 7, 5]);
        let b = TokenSet::from_u16_iter([10, 3, 2, 7, 4]);

        let common = a.intersection(&b);

        assert_eq!(common.iter().collect::<Vec<_>>(), vec![3, 7]);
        assert_eq!(a.intersection_count(&b), 2);
        assert_eq!(b.intersection_count(&a), 2);
        assert_eq!(b.intersection(&a), common);
    }

    #[test]
    fn test_intersection_with_empty_set() {
        let a = TokenSet::from_u16_iter([1, 2, 3]);
        let empty = TokenSet::default();

        assert!(a.intersection(&empty).is_empty());
        assert_eq!(a.intersection_count(&empty), 0);
        assert_eq!(empty.intersection_count(&a), 0);
    }

    #[test]
    fn test_collect_sorts_and_dedups() {
        let set: TokenSet = [5u16, 1, 5, 3, 1].iter().copied().collect();

        assert_eq!(set.len(), 3);
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![1, 3, 5]);
    }
}