provenant/license_detection/
token_set.rs1use smallvec::SmallVec;
2use std::cmp::Ordering;
3use std::ops::Deref;
4
5use rkyv::Archive;
6
7use crate::license_detection::index::dictionary::{TokenDictionary, TokenId, TokenKind};
8
9#[derive(Clone, Debug, PartialEq, Eq, Archive, rkyv::Serialize, rkyv::Deserialize)]
15pub struct TokenSet(SmallVec<[u16; 64]>);
16
17impl TokenSet {
18 pub fn from_u16_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
21 let mut inner: SmallVec<[u16; 64]> = iter.into_iter().collect();
22 inner.sort_unstable();
23 inner.dedup();
24 Self(inner)
25 }
26
27 pub fn from_token_ids<I: IntoIterator<Item = TokenId>>(iter: I) -> Self {
29 Self::from_u16_iter(iter.into_iter().map(|tid| tid.raw()))
30 }
31
32 pub fn new() -> Self {
34 Self(SmallVec::new())
35 }
36
37 pub fn len(&self) -> usize {
39 self.0.len()
40 }
41
42 pub fn is_empty(&self) -> bool {
44 self.0.is_empty()
45 }
46
47 pub fn contains_token_id(&self, token_id: TokenId) -> bool {
49 self.0.contains(&token_id.raw())
50 }
51
52 pub fn high_subset(&self, dictionary: &TokenDictionary) -> Self {
54 Self::from_u16_iter(
55 self.iter()
56 .filter(|&tid| dictionary.token_kind(TokenId::new(tid)) == TokenKind::Legalese),
57 )
58 }
59
60 pub fn intersection_count(&self, other: &TokenSet) -> usize {
62 let (mut i, mut j, mut count) = (0, 0, 0);
63 while i < self.0.len() && j < other.0.len() {
64 match self.0[i].cmp(&other.0[j]) {
65 Ordering::Less => i += 1,
66 Ordering::Greater => j += 1,
67 Ordering::Equal => {
68 count += 1;
69 i += 1;
70 j += 1;
71 }
72 }
73 }
74 count
75 }
76
77 pub fn intersection(&self, other: &TokenSet) -> TokenSet {
79 let mut result = SmallVec::new();
80 let (mut i, mut j) = (0, 0);
81 while i < self.0.len() && j < other.0.len() {
82 match self.0[i].cmp(&other.0[j]) {
83 Ordering::Less => i += 1,
84 Ordering::Greater => j += 1,
85 Ordering::Equal => {
86 result.push(self.0[i]);
87 i += 1;
88 j += 1;
89 }
90 }
91 }
92 Self(result)
93 }
94
95 pub fn iter(&self) -> impl Iterator<Item = u16> + '_ {
97 self.0.iter().copied()
98 }
99}
100
101impl Deref for TokenSet {
102 type Target = [u16];
103 fn deref(&self) -> &Self::Target {
104 &self.0
105 }
106}
107
108impl Default for TokenSet {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl std::iter::FromIterator<u16> for TokenSet {
115 fn from_iter<T: IntoIterator<Item = u16>>(iter: T) -> Self {
116 Self::from_u16_iter(iter)
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license_detection::index::dictionary::tid;

    #[test]
    fn test_from_token_ids() {
        let set = TokenSet::from_token_ids([tid(4), tid(2), tid(4), tid(1)]);

        assert_eq!(set.iter().collect::<Vec<_>>(), vec![1, 2, 4]);
        assert!(set.contains_token_id(tid(1)));
        assert!(set.contains_token_id(tid(2)));
        assert!(set.contains_token_id(tid(4)));
    }

    #[test]
    fn test_contains_token_id_missing() {
        let set = TokenSet::from_token_ids([tid(1), tid(5)]);

        assert!(!set.contains_token_id(tid(2)));
        assert!(!set.contains_token_id(tid(3)));
        assert!(!set.contains_token_id(tid(9)));
    }

    #[test]
    fn test_from_u16_iter_sorts_and_dedups() {
        let set = TokenSet::from_u16_iter([9, 3, 3, 7, 1, 9]);

        assert_eq!(set.iter().collect::<Vec<_>>(), vec![1, 3, 7, 9]);
        assert_eq!(set.len(), 4);
        assert!(!set.is_empty());
    }

    #[test]
    fn test_high_subset() {
        let set = TokenSet::from_u16_iter([1, 2, 5, 10]);
        let dict = TokenDictionary::new_with_legalese(&[("one", 1), ("two", 2)]);

        let high_set = set.high_subset(&dict);

        assert_eq!(high_set.iter().collect::<Vec<_>>(), vec![1, 2]);
    }

    #[test]
    fn test_intersection() {
        let a = TokenSet::from_u16_iter([1, 3, 5, 7]);
        let b = TokenSet::from_u16_iter([3, 4, 5, 8]);

        assert_eq!(a.intersection(&b).iter().collect::<Vec<_>>(), vec![3, 5]);
        assert_eq!(a.intersection_count(&b), 2);
        assert_eq!(b.intersection_count(&a), 2);
    }

    #[test]
    fn test_intersection_with_empty_and_disjoint() {
        let a = TokenSet::from_u16_iter([1, 2, 3]);
        let empty = TokenSet::new();
        let disjoint = TokenSet::from_u16_iter([4, 5]);

        assert!(a.intersection(&empty).is_empty());
        assert_eq!(a.intersection_count(&empty), 0);
        assert!(a.intersection(&disjoint).is_empty());
        assert_eq!(a.intersection_count(&disjoint), 0);
        assert_eq!(empty, TokenSet::default());
    }

    #[test]
    fn test_spills_past_inline_capacity() {
        let set: TokenSet = (0..200u16).rev().collect();
        let other = TokenSet::from_u16_iter(100..300);

        assert_eq!(set.len(), 200);
        assert_eq!(set.intersection_count(&other), 100);
        assert_eq!(set.intersection(&other).len(), 100);
    }
}