//! provenant/license_detection/index/dictionary.rs

use std::collections::HashMap;

/// Compact, copyable identifier for an interned token.
///
/// Wraps a `u16`, so a dictionary can hold at most 65 536 distinct tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TokenId(u16);
10
impl TokenId {
    /// Creates a `TokenId` from a raw `u16` value.
    pub const fn new(raw: u16) -> Self {
        Self(raw)
    }

    /// Returns the underlying `u16` value.
    pub const fn raw(self) -> u16 {
        self.0
    }

    /// Returns the id widened to `usize`, e.g. for indexing a metadata table.
    pub const fn as_usize(self) -> usize {
        self.0 as usize
    }

    /// Returns the id as two little-endian bytes.
    pub const fn to_le_bytes(self) -> [u8; 2] {
        self.0.to_le_bytes()
    }
}
28
/// Test-only shorthand for constructing a [`TokenId`].
#[cfg(test)]
pub const fn tid(raw: u16) -> TokenId {
    TokenId::new(raw)
}
33
34impl From<u16> for TokenId {
35 fn from(value: u16) -> Self {
36 Self(value)
37 }
38}
39
40impl From<TokenId> for u16 {
41 fn from(value: TokenId) -> Self {
42 value.0
43 }
44}
45
46impl PartialEq<u16> for TokenId {
47 fn eq(&self, other: &u16) -> bool {
48 self.0 == *other
49 }
50}
51
52impl PartialOrd<u16> for TokenId {
53 fn partial_cmp(&self, other: &u16) -> Option<std::cmp::Ordering> {
54 self.0.partial_cmp(other)
55 }
56}
57
58impl PartialEq<TokenId> for u16 {
59 fn eq(&self, other: &TokenId) -> bool {
60 *self == other.0
61 }
62}
63
64impl PartialOrd<TokenId> for u16 {
65 fn partial_cmp(&self, other: &TokenId) -> Option<std::cmp::Ordering> {
66 self.partial_cmp(&other.0)
67 }
68}
69
/// Classification of a dictionary token.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenKind {
    /// A token from the pre-seeded legalese vocabulary.
    Legalese,
    /// Any other interned token.
    Regular,
}
75
/// A token present in the dictionary, together with its cached metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KnownToken {
    /// The token's interned id.
    pub id: TokenId,
    /// Whether the token is legalese or regular.
    pub kind: TokenKind,
    /// True when the token consists solely of ASCII digits.
    pub is_digit_only: bool,
    /// True when the token is a single byte long or all ASCII digits.
    pub is_short_or_digit: bool,
}
83
/// Result of classifying a query-side token against the dictionary.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryToken {
    /// The token exists in the dictionary.
    Known(KnownToken),
    /// The token has never been interned.
    Unknown,
    /// The token was filtered as a stopword. Not produced by
    /// [`TokenDictionary::classify_query_token`]; presumably assigned by a
    /// caller's stopword pass — verify against call sites.
    Stopword,
}
90
/// Internal per-id metadata stored by [`TokenDictionary`].
#[derive(Debug, Clone, Copy)]
struct TokenMetadata {
    // Legalese vs. regular classification of the token.
    kind: TokenKind,
    // Token text is all ASCII digits.
    is_digit_only: bool,
    // Token text is one byte long or all ASCII digits.
    is_short_or_digit: bool,
}
97
/// Interning dictionary mapping token strings to stable [`TokenId`]s.
#[derive(Debug, Clone)]
pub struct TokenDictionary {
    // Token string -> assigned id.
    tokens_to_ids: HashMap<String, TokenId>,

    // Per-id metadata, indexed by `TokenId::as_usize()`; `None` for ids that
    // were reserved but never registered with a word.
    token_metadata: Vec<Option<TokenMetadata>>,

    // Number of legalese entries this dictionary was created with.
    len_legalese: usize,

    // Next id to hand out when interning an unseen token.
    next_id: TokenId,
}
122
123impl TokenDictionary {
124 const DEFAULT_METADATA: TokenMetadata = TokenMetadata {
125 kind: TokenKind::Regular,
126 is_digit_only: false,
127 is_short_or_digit: false,
128 };
129
130 pub fn new_with_legalese(legalese_entries: &[(&str, u16)]) -> Self {
141 let mut tokens_to_ids = HashMap::new();
142 let max_existing_id = legalese_entries
143 .iter()
144 .map(|(_, token_id)| *token_id as usize)
145 .max()
146 .unwrap_or(0);
147 let mut token_metadata = vec![None; max_existing_id.saturating_add(1)];
148
149 for (word, token_id) in legalese_entries {
150 let id = TokenId::from(*token_id);
151 tokens_to_ids.insert(word.to_string(), id);
152 token_metadata[id.as_usize()] = Some(TokenMetadata {
153 kind: TokenKind::Legalese,
154 is_digit_only: word.chars().all(|c| c.is_ascii_digit()),
155 is_short_or_digit: word.len() == 1 || word.chars().all(|c| c.is_ascii_digit()),
156 });
157 }
158
159 let len_legalese = legalese_entries.len();
160 let next_id = TokenId::new((max_existing_id + 1).max(len_legalese) as u16);
161
162 Self {
163 tokens_to_ids,
164 token_metadata,
165 len_legalese,
166 next_id,
167 }
168 }
169
170 pub fn new(legalese_count: usize) -> Self {
178 Self {
179 tokens_to_ids: HashMap::new(),
180 token_metadata: Vec::new(),
181 len_legalese: legalese_count,
182 next_id: TokenId::new(legalese_count as u16),
183 }
184 }
185
186 fn metadata_for(&self, id: TokenId) -> TokenMetadata {
187 self.token_metadata
188 .get(id.as_usize())
189 .and_then(|meta| *meta)
190 .unwrap_or(Self::DEFAULT_METADATA)
191 }
192
193 fn build_known_token(&self, id: TokenId) -> KnownToken {
194 let metadata = self.metadata_for(id);
195 KnownToken {
196 id,
197 kind: metadata.kind,
198 is_digit_only: metadata.is_digit_only,
199 is_short_or_digit: metadata.is_short_or_digit,
200 }
201 }
202
203 fn insert_metadata(&mut self, id: TokenId, kind: TokenKind, token: &str) {
204 let raw = id.as_usize();
205 if self.token_metadata.len() <= raw {
206 self.token_metadata.resize(raw + 1, None);
207 }
208 self.token_metadata[raw] = Some(TokenMetadata {
209 kind,
210 is_digit_only: token.chars().all(|c| c.is_ascii_digit()),
211 is_short_or_digit: token.len() == 1 || token.chars().all(|c| c.is_ascii_digit()),
212 });
213 }
214
215 pub fn intern(&mut self, token: &str) -> KnownToken {
216 if let Some(&id) = self.tokens_to_ids.get(token) {
217 return self.build_known_token(id);
218 }
219
220 let id = self.next_id;
221 self.next_id = TokenId::new(self.next_id.raw() + 1);
222 self.tokens_to_ids.insert(token.to_string(), id);
223 self.insert_metadata(id, TokenKind::Regular, token);
224 self.build_known_token(id)
225 }
226
227 pub fn lookup(&self, token: &str) -> Option<KnownToken> {
228 self.tokens_to_ids
229 .get(token)
230 .copied()
231 .map(|id| self.build_known_token(id))
232 }
233
234 pub fn classify_query_token(&self, token: &str) -> QueryToken {
235 self.lookup(token)
236 .map_or(QueryToken::Unknown, QueryToken::Known)
237 }
238
239 pub fn token_kind(&self, token_id: TokenId) -> TokenKind {
240 self.metadata_for(token_id).kind
241 }
242
243 pub fn is_digit_only_token(&self, token_id: TokenId) -> bool {
244 self.metadata_for(token_id).is_digit_only
245 }
246
247 #[cfg(test)]
248 pub fn get_or_assign(&mut self, token: &str) -> TokenId {
249 self.intern(token).id
250 }
251
252 pub fn get_token_id(&self, token: &str) -> Option<TokenId> {
260 self.lookup(token).map(|token| token.id)
261 }
262
263 #[inline]
265 pub fn get(&self, token: &str) -> Option<TokenId> {
266 self.get_token_id(token)
267 }
268
269 pub const fn legalese_count(&self) -> usize {
271 self.len_legalese
272 }
273
274 #[cfg(test)]
276 pub fn tokens_to_ids(&self) -> impl Iterator<Item = (&String, &TokenId)> {
277 self.tokens_to_ids.iter()
278 }
279}
280
281impl Default for TokenDictionary {
282 fn default() -> Self {
283 Self::new(0)
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_dictionary_new() {
        let dict = TokenDictionary::new(10);
        assert_eq!(dict.legalese_count(), 10);
        assert_eq!(dict.tokens_to_ids.len(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_new_with_legalese() {
        let mut dict = TokenDictionary::new_with_legalese(&[
            ("license", 0u16),
            ("copyright", 1u16),
            ("permission", 2u16),
        ]);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);
        assert!(!dict.tokens_to_ids.is_empty());

        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("copyright"), Some(tid(1)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(2)));

        // The next regular id follows the dense legalese range.
        let test_id = dict.get_or_assign("test");
        assert_eq!(test_id, 3);
    }

    #[test]
    fn test_new_with_legalese_sorted() {
        // Sparse, unsorted ids are allowed; interning continues past the max.
        let mut dict = TokenDictionary::new_with_legalese(&[
            ("copyright", 5u16),
            ("license", 0u16),
            ("permission", 10u16),
        ]);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);

        assert_eq!(dict.get_token_id("copyright"), Some(tid(5)));
        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(10)));

        let test_id = dict.get_or_assign("test");
        assert_eq!(test_id, tid(11));
    }

    #[test]
    fn test_get_or_assign_new_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("world");

        assert_eq!(id1, 5);
        assert_eq!(id2, 6);
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_or_assign_existing_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("hello");

        assert_eq!(id1, id2);
        assert_eq!(dict.tokens_to_ids.len(), 1);
    }

    #[test]
    fn test_get_or_assign_with_preexisting_legalese() {
        let mut dict = TokenDictionary::new_with_legalese(&[("license", 0u16)]);

        // Re-assigning a legalese word returns its pre-seeded id.
        let id = dict.get_or_assign("license");
        assert_eq!(id, 0);
        assert_eq!(dict.tokens_to_ids.len(), 1);

        let new_id = dict.get_or_assign("new");
        assert_eq!(new_id, 1);
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_existing_token() {
        let mut dict = TokenDictionary::new(5);

        dict.get_or_assign("hello");
        assert_eq!(dict.get_token_id("hello"), Some(tid(5)));
    }

    #[test]
    fn test_get_nonexistent_token() {
        let dict = TokenDictionary::new(5);
        assert_eq!(dict.get_token_id("hello"), None);
    }

    #[test]
    fn test_legalese_range() {
        let dict = TokenDictionary::new(10);

        // Ids below the reserved count fall in the legalese range...
        assert!(0 < dict.legalese_count() as u16);
        assert!(5 < dict.legalese_count() as u16);
        assert!(9 < dict.legalese_count() as u16);

        // ...ids at or above it do not.
        assert!(10 >= dict.legalese_count() as u16);
        assert!(100 >= dict.legalese_count() as u16);
    }

    #[test]
    fn test_legalese_range_with_actual_legalese() {
        let mut dict = TokenDictionary::new_with_legalese(&[
            ("license", 0u16),
            ("copyright", 1u16),
        ]);

        assert!(dict.get_token_id("license").unwrap() < dict.legalese_count() as u16);
        assert!(dict.get_token_id("copyright").unwrap() < dict.legalese_count() as u16);

        let regular_id = dict.get_or_assign("regular");
        assert!(regular_id >= dict.legalese_count() as u16);
    }

    #[test]
    fn test_token_dictionary_default() {
        let dict = TokenDictionary::default();
        assert_eq!(dict.legalese_count(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_get_alias() {
        let mut dict = TokenDictionary::new(5);
        dict.get_or_assign("hello");

        assert_eq!(dict.get("hello"), dict.get_token_id("hello"));
    }

    #[test]
    fn test_with_actual_legalese_module() {
        use crate::license_detection::rules::legalese;

        let legalese_words = legalese::get_legalese_words();
        assert!(!legalese_words.is_empty(), "Should have legalese words");

        let mut dict = TokenDictionary::new_with_legalese(&legalese_words);

        assert_eq!(dict.legalese_count(), legalese_words.len());
        assert_eq!(dict.tokens_to_ids.len(), legalese_words.len());

        let license_id = dict.get_token_id("license");
        assert!(license_id.is_some(), "License should be in dictionary");
        assert!(
            license_id.unwrap() < dict.legalese_count() as u16,
            "License should be a legalese token"
        );

        let copyrighted_id = dict.get_token_id("copyrighted");
        assert!(
            copyrighted_id.is_some(),
            "Copyrighted should be in dictionary"
        );
        assert!(
            copyrighted_id.unwrap() < dict.legalese_count() as u16,
            "Copyrighted should be a legalese token"
        );

        // Regular tokens are assigned ids above the legalese range.
        // (The original repeated this assertion twice; once is enough.)
        let hello_id = dict.get_or_assign("hello");
        assert!(hello_id >= dict.legalese_count() as u16);
    }
}
507}